Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions bin/katana/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ katana-primitives.workspace = true
katana-rpc-client.workspace = true
katana-rpc-types.workspace = true
katana-utils.workspace = true
katana-stage.workspace = true

anyhow.workspace = true
async-trait.workspace = true
Expand Down
6 changes: 6 additions & 0 deletions bin/katana/src/cli/stage/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use anyhow::Result;
use clap::{Args, Subcommand};

use crate::cli::execute_async;

mod checkpoint;
mod unwind;

#[derive(Debug, Args)]
#[cfg_attr(test, derive(PartialEq))]
Expand All @@ -15,12 +18,15 @@ pub struct StageArgs {
enum Commands {
/// Manage stage checkpoints
Checkpoint(checkpoint::CheckpointArgs),
/// Unwind a stage to a previous state
Unwind(unwind::UnwindArgs),
}

impl StageArgs {
pub fn execute(self) -> Result<()> {
match self.commands {
Commands::Checkpoint(args) => args.execute(),
Commands::Unwind(args) => execute_async(args.execute())?,
}
}
}
38 changes: 38 additions & 0 deletions bin/katana/src/cli/stage/unwind.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
use anyhow::Result;
use clap::Args;
use katana_primitives::block::BlockNumber;
use katana_provider::api::stage::StageCheckpointProvider;
use katana_provider::providers::db::DbProvider;
use katana_stage::Stage;

use crate::cli::db::open_db_rw;

#[derive(Debug, Args)]
#[cfg_attr(test, derive(PartialEq))]
pub struct UnwindArgs {
/// The stage ID to unwind
#[arg(value_name = "STAGE_ID")]
stage_id: String,

/// The stage ID to unwind to
#[arg(value_name = "UNWIND_TO")]
unwind_to: BlockNumber,

/// Path to the database directory.
#[arg(short, long)]
path: String,
}

impl UnwindArgs {
pub async fn execute(self) -> Result<()> {
use katana_stage::StateTrie;

let provider = DbProvider::new(open_db_rw(&self.path)?);
let mut stage = StateTrie::new(&provider);

stage.unwind(self.unwind_to).await?;
provider.set_checkpoint(stage.id(), self.unwind_to)?;

Ok(())
}
}
13 changes: 13 additions & 0 deletions crates/core/src/backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -687,4 +687,17 @@ impl TrieWriter for GenesisTrieWriter {
trie.commit(block_number);
Ok(trie.root())
}

fn unwind_classes_trie(&self, unwind_to: BlockNumber) -> katana_provider::ProviderResult<Felt> {
let _ = unwind_to;
unimplemented!("unwinding not supported for genesis trie")
}

fn unwind_contracts_trie(
&self,
unwind_to: BlockNumber,
) -> katana_provider::ProviderResult<Felt> {
let _ = unwind_to;
unimplemented!("unwinding not supported for genesis trie")
}
}
4 changes: 4 additions & 0 deletions crates/storage/provider/provider-api/src/trie.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,8 @@ pub trait TrieWriter: Send + Sync {
block_number: BlockNumber,
state_updates: &StateUpdates,
) -> ProviderResult<Felt>;

fn unwind_classes_trie(&self, unwind_to: BlockNumber) -> ProviderResult<Felt>;

fn unwind_contracts_trie(&self, unwind_to: BlockNumber) -> ProviderResult<Felt>;
}
41 changes: 39 additions & 2 deletions crates/storage/provider/provider/src/providers/db/trie.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use std::collections::{BTreeMap, HashMap};
use std::collections::{BTreeMap, BTreeSet, HashMap};

use katana_db::abstraction::DbTxMut;
use katana_db::abstraction::{DbCursor, DbTxMut};
use katana_db::tables;
use katana_db::trie::TrieDbMut;
use katana_primitives::block::BlockNumber;
use katana_primitives::class::{ClassHash, CompiledClassHash};
use katana_primitives::state::StateUpdates;
use katana_primitives::{ContractAddress, Felt};
use katana_provider_api::block::BlockNumberProvider;
use katana_provider_api::state::{StateFactoryProvider, StateProvider};
use katana_provider_api::trie::TrieWriter;
use katana_provider_api::ProviderError;
Expand Down Expand Up @@ -105,6 +106,42 @@ impl<Tx: DbTxMut> TrieWriter for DbProvider<Tx> {
contract_trie_db.commit(block_number);
Ok(contract_trie_db.root())
}

fn unwind_classes_trie(&self, unwind_to: BlockNumber) -> ProviderResult<Felt> {
let latest_block_number = self.latest_number()?;
let mut trie = ClassesTrie::new(TrieDbMut::<tables::ClassesTrie, _>::new(self.0.clone()));
trie.revert_to(unwind_to, latest_block_number);
Ok(trie.root())
}

fn unwind_contracts_trie(&self, unwind_to: BlockNumber) -> ProviderResult<Felt> {
let latest_block_number = self.latest_number()?;

let mut cursor = self.0.cursor_dup::<tables::StorageChangeHistory>()?;
let iterator = cursor.walk(Some(unwind_to))?;

let mut addresses = BTreeSet::new();

for entry in iterator {
let (block, change_entry) = entry?;

if block > unwind_to {
addresses.insert(change_entry.key.contract_address);
}
}

for addr in addresses {
let trie_db = TrieDbMut::<tables::StoragesTrie, _>::new(self.0.clone());
let mut storage_trie = StoragesTrie::new(trie_db, addr);
storage_trie.revert_to(unwind_to, latest_block_number);
}

let mut contract_trie_db =
ContractsTrie::new(TrieDbMut::<tables::ContractsTrie, _>::new(self.0.clone()));
contract_trie_db.revert_to(unwind_to, latest_block_number);

Ok(contract_trie_db.root())
}
}

// computes the contract state leaf hash
Expand Down
10 changes: 10 additions & 0 deletions crates/storage/provider/provider/src/providers/fork/trie.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,14 @@ impl<Tx1: DbTxMut> TrieWriter for ForkedProvider<Tx1> {
let _ = updates;
Ok(Felt::ZERO)
}

fn unwind_classes_trie(&self, unwind_to: BlockNumber) -> ProviderResult<Felt> {
let _ = unwind_to;
Ok(Felt::ZERO)
}

fn unwind_contracts_trie(&self, unwind_to: BlockNumber) -> ProviderResult<Felt> {
let _ = unwind_to;
Ok(Felt::ZERO)
}
}
Loading