Skip to content

Commit e116334

Browse files
(refactor): moving unsafe panic optimization into new dataflow format
1 parent db28b67 commit e116334

File tree

3 files changed

+186
-48
lines changed

3 files changed

+186
-48
lines changed

crates/cairo-lang-lowering/src/analysis/backward.rs

Lines changed: 99 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
44
use std::collections::HashMap;
55

6-
use crate::analysis::Analyzer;
7-
use crate::{BlockEnd, BlockId, Lowered};
6+
use crate::analysis::{Analyzer, DataflowAnalyzer, Direction, StatementLocation};
7+
use crate::{Block, BlockEnd, BlockId, Lowered, MatchInfo, Statement, VarRemapping, VarUsage};
88

99
/// Main analysis type that allows traversing the flow backwards.
1010
pub struct BackAnalysis<'db, 'a, TAnalyzer: Analyzer<'db, 'a>> {
@@ -103,3 +103,100 @@ impl<'db, 'a, TAnalyzer: Analyzer<'db, 'a>> BackAnalysis<'db, 'a, TAnalyzer> {
103103
}
104104
}
105105
}
106+
107+
/// Backward analysis runner using `DataflowAnalyzer`.
108+
///
109+
/// This is an adapter that wraps `BackAnalysis` internally, translating
110+
/// between the new `DataflowAnalyzer` trait and the legacy `Analyzer` trait.
111+
/// Once all analyses are migrated, this can be simplified to inline the
112+
/// traversal logic directly.
113+
pub struct DataflowBackAnalysis<'db, 'a, TAnalyzer: DataflowAnalyzer<'db, 'a>> {
114+
inner: BackAnalysis<'db, 'a, AnalyzerAdapter<'db, 'a, TAnalyzer>>,
115+
}
116+
117+
impl<'db, 'a, TAnalyzer: DataflowAnalyzer<'db, 'a>> DataflowBackAnalysis<'db, 'a, TAnalyzer> {
118+
/// Creates a new DataflowBackAnalysis instance.
119+
pub fn new(lowered: &'a Lowered<'db>, analyzer: TAnalyzer) -> Self {
120+
assert!(
121+
TAnalyzer::DIRECTION == Direction::Backward,
122+
"DataflowBackAnalysis requires a backward analyzer"
123+
);
124+
let adapter = AnalyzerAdapter { analyzer, lowered };
125+
Self { inner: BackAnalysis::new(lowered, adapter) }
126+
}
127+
128+
/// Runs the analysis and returns the result.
129+
///
130+
/// For backward analysis, returns the info at the function entry (root block).
131+
pub fn run(mut self) -> TAnalyzer::Info {
132+
self.inner.get_root_info()
133+
}
134+
}
135+
136+
/// Adapter that implements the legacy `Analyzer` trait by delegating to `DataflowAnalyzer`.
137+
pub struct AnalyzerAdapter<'db, 'a, TAnalyzer: DataflowAnalyzer<'db, 'a>> {
138+
pub analyzer: TAnalyzer,
139+
lowered: &'a Lowered<'db>,
140+
}
141+
142+
impl<'db, 'a, TAnalyzer: DataflowAnalyzer<'db, 'a>> Analyzer<'db, 'a>
143+
for AnalyzerAdapter<'db, 'a, TAnalyzer>
144+
{
145+
type Info = TAnalyzer::Info;
146+
147+
fn visit_block_start(&mut self, info: &mut Self::Info, block_id: BlockId, _block: &Block<'db>) {
148+
// Get block from lowered with correct lifetime 'a.
149+
let block = &self.lowered.blocks[block_id];
150+
// First apply transfer_block (which processes statements in reverse for backward).
151+
self.analyzer.transfer_block(info, block_id, block);
152+
// Then call the block start hook.
153+
self.analyzer.visit_block_start(info, block_id, block);
154+
}
155+
156+
fn visit_stmt(
157+
&mut self,
158+
_info: &mut Self::Info,
159+
_statement_location: StatementLocation,
160+
_stmt: &'a Statement<'db>,
161+
) {
162+
// Statements are handled by transfer_block in visit_block_start.
163+
// This is intentionally empty.
164+
}
165+
166+
fn visit_goto(
167+
&mut self,
168+
info: &mut Self::Info,
169+
statement_location: StatementLocation,
170+
target_block_id: BlockId,
171+
remapping: &'a VarRemapping<'db>,
172+
) {
173+
self.analyzer.apply_remapping(info, statement_location, target_block_id, remapping);
174+
}
175+
176+
fn merge_match(
177+
&mut self,
178+
statement_location: StatementLocation,
179+
match_info: &'a MatchInfo<'db>,
180+
infos: impl Iterator<Item = Self::Info>,
181+
) -> Self::Info {
182+
self.analyzer.merge_match(statement_location, match_info, infos)
183+
}
184+
185+
fn info_from_return(
186+
&mut self,
187+
statement_location: StatementLocation,
188+
_vars: &'a [VarUsage<'db>],
189+
) -> Self::Info {
190+
let block_end = &self.lowered.blocks[statement_location.0].end;
191+
self.analyzer.initial_info(statement_location.0, block_end)
192+
}
193+
194+
fn info_from_panic(
195+
&mut self,
196+
statement_location: StatementLocation,
197+
_var: &VarUsage<'db>,
198+
) -> Self::Info {
199+
let block_end = &self.lowered.blocks[statement_location.0].end;
200+
self.analyzer.initial_info(statement_location.0, block_end)
201+
}
202+
}

crates/cairo-lang-lowering/src/analysis/mod.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,9 @@
66
pub mod backward;
77
pub mod core;
88

9-
// Re-export commonly used types at the module level for convenience.
109
pub use core::{DataflowAnalyzer, Direction};
1110

12-
pub use backward::BackAnalysis;
11+
pub use backward::{BackAnalysis, DataflowBackAnalysis};
1312

1413
use crate::{Block, BlockId, MatchInfo, Statement, VarRemapping, VarUsage};
1514

crates/cairo-lang-lowering/src/optimizations/early_unsafe_panic.rs

Lines changed: 86 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,12 @@ use std::collections::HashSet;
77
use cairo_lang_defs::ids::ExternFunctionId;
88
use cairo_lang_filesystem::flag::FlagsGroup;
99
use cairo_lang_semantic::helper::ModuleHelper;
10-
use itertools::zip_eq;
1110
use salsa::Database;
1211

13-
use crate::analysis::{Analyzer, BackAnalysis, StatementLocation};
12+
use crate::analysis::core::StatementLocation;
13+
use crate::analysis::{DataflowAnalyzer, DataflowBackAnalysis, Direction};
1414
use crate::ids::{LocationId, SemanticFunctionIdEx};
15-
use crate::{
16-
BlockEnd, BlockId, Lowered, MatchExternInfo, MatchInfo, Statement, StatementCall, VarUsage,
17-
};
15+
use crate::{BlockEnd, BlockId, Lowered, MatchExternInfo, MatchInfo, Statement, StatementCall};
1816

1917
/// Adds an early unsafe_panic when we detect that `return` is unreachable from a certain point in
2018
/// the code. This step is needed to avoid issues with undroppable references in Sierra to CASM.
@@ -32,16 +30,17 @@ pub fn early_unsafe_panic<'db>(db: &'db dyn Database, lowered: &mut Lowered<'db>
3230
core.submodule("internal").extern_function_id("trace"),
3331
]);
3432

35-
let ctx = UnsafePanicContext { db, fixes: vec![], libfuncs_with_sideffect };
36-
let mut analysis = BackAnalysis::new(lowered, ctx);
37-
let fixes = if let ReachableSideEffects::Unreachable(location) = analysis.get_root_info() {
38-
vec![((BlockId::root(), 0), location)]
39-
} else {
40-
analysis.analyzer.fixes
41-
};
33+
let ctx = UnsafePanicContext { db, libfuncs_with_sideffect };
34+
let analysis = DataflowBackAnalysis::new(lowered, ctx);
35+
let mut result = analysis.run();
36+
37+
// If the entry point itself is unreachable, add a fix for it.
38+
if let Reachability::Unreachable(location) = result.reachability {
39+
result.fixes.push(((BlockId::root(), 0), location));
40+
}
4241

4342
let panic_func_id = core.submodule("panics").function_id("unsafe_panic", vec![]).lowered(db);
44-
for ((block_id, statement_idx), location) in fixes {
43+
for ((block_id, statement_idx), location) in result.fixes {
4544
let block = &mut lowered.blocks[block_id];
4645
block.statements.truncate(statement_idx);
4746

@@ -59,9 +58,6 @@ pub fn early_unsafe_panic<'db>(db: &'db dyn Database, lowered: &mut Lowered<'db>
5958
pub struct UnsafePanicContext<'db> {
6059
db: &'db dyn Database,
6160

62-
/// The list of blocks where we can insert unsafe_panic.
63-
fixes: Vec<(StatementLocation, LocationId<'db>)>,
64-
6561
/// libfuncs with side effects that we need to ignore.
6662
libfuncs_with_sideffect: HashSet<ExternFunctionId<'db>>,
6763
}
@@ -83,58 +79,104 @@ impl<'db> UnsafePanicContext<'db> {
8379
}
8480
}
8581

86-
/// Can this state lead to a return or a statement with side effect.
87-
#[derive(Clone, Default, PartialEq, Debug)]
88-
pub enum ReachableSideEffects<'db> {
82+
/// Reachability state for a point in the program.
83+
#[derive(Clone, Copy, Default, PartialEq, Debug)]
84+
pub enum Reachability<'db> {
8985
/// Some return statement or statement with side effect is reachable.
9086
#[default]
9187
Reachable,
9288
/// No return statement or statement with side effect is reachable.
93-
/// holds the location of the closest match with no returning arms.
89+
/// Holds the location of the closest match with no returning arms.
9490
Unreachable(LocationId<'db>),
9591
}
9692

97-
impl<'db> Analyzer<'db, '_> for UnsafePanicContext<'db> {
98-
type Info = ReachableSideEffects<'db>;
93+
/// Analysis info containing reachability state and accumulated fixes.
94+
#[derive(Clone, Default, Debug)]
95+
pub struct AnalysisInfo<'db> {
96+
/// The reachability state at this program point.
97+
pub reachability: Reachability<'db>,
98+
/// Locations where we need to insert unsafe_panic.
99+
pub fixes: Vec<(StatementLocation, LocationId<'db>)>,
100+
}
101+
102+
impl<'db, 'a> DataflowAnalyzer<'db, 'a> for UnsafePanicContext<'db> {
103+
type Info = AnalysisInfo<'db>;
104+
const DIRECTION: Direction = Direction::Backward;
105+
106+
fn initial_info(&mut self, _block_id: BlockId, _block_end: &'a BlockEnd<'db>) -> Self::Info {
107+
AnalysisInfo::default()
108+
}
99109

100-
fn visit_stmt(
110+
fn merge(
101111
&mut self,
102-
info: &mut Self::Info,
103-
statement_location: StatementLocation,
104-
stmt: &Statement<'db>,
105-
) {
106-
if self.has_side_effects(stmt)
107-
&& let ReachableSideEffects::Unreachable(locations) = *info
108-
{
109-
self.fixes.push((statement_location, locations));
110-
*info = ReachableSideEffects::Reachable
112+
_statement_location: StatementLocation,
113+
infos: impl Iterator<Item = (BlockId, Self::Info)>,
114+
) -> Self::Info {
115+
let mut result = AnalysisInfo::default();
116+
let mut all_unreachable = true;
117+
let mut unreachable_location = None;
118+
119+
for (src, info) in infos {
120+
result.fixes.extend(info.fixes);
121+
if let Reachability::Unreachable(loc) = info.reachability {
122+
// Fix at the entry of this unreachable branch.
123+
result.fixes.push(((src, 0), loc));
124+
unreachable_location.get_or_insert(loc);
125+
} else {
126+
all_unreachable = false;
127+
}
128+
}
129+
130+
if all_unreachable && let Some(loc) = unreachable_location {
131+
result.reachability = Reachability::Unreachable(loc);
111132
}
133+
134+
result
112135
}
113136

114137
fn merge_match(
115138
&mut self,
116139
statement_location: StatementLocation,
117-
match_info: &MatchInfo<'db>,
140+
match_info: &'a MatchInfo<'db>,
118141
infos: impl Iterator<Item = Self::Info>,
119142
) -> Self::Info {
120-
let mut res = ReachableSideEffects::Unreachable(*match_info.location());
121-
for (arm, info) in zip_eq(match_info.arms(), infos) {
122-
match info {
123-
ReachableSideEffects::Reachable => {
124-
res = ReachableSideEffects::Reachable;
143+
let mut result = AnalysisInfo::default();
144+
let mut all_unreachable = true;
145+
146+
for (arm, info) in match_info.arms().iter().zip(infos) {
147+
result.fixes.extend(info.fixes);
148+
match info.reachability {
149+
Reachability::Reachable => {
150+
all_unreachable = false;
151+
}
152+
Reachability::Unreachable(loc) => {
153+
// Fix at the entry of this unreachable arm.
154+
result.fixes.push(((arm.block_id, 0), loc));
125155
}
126-
ReachableSideEffects::Unreachable(l) => self.fixes.push(((arm.block_id, 0), l)),
127156
}
128157
}
129158

130-
if let ReachableSideEffects::Unreachable(location) = res {
131-
self.fixes.push((statement_location, location));
159+
if all_unreachable {
160+
let loc = *match_info.location();
161+
result.reachability = Reachability::Unreachable(loc);
162+
// Fix at the match statement itself.
163+
result.fixes.push((statement_location, loc));
132164
}
133165

134-
res
166+
result
135167
}
136168

137-
fn info_from_return(&mut self, _: StatementLocation, _vars: &[VarUsage<'db>]) -> Self::Info {
138-
ReachableSideEffects::Reachable
169+
fn transfer_stmt(
170+
&mut self,
171+
info: &mut Self::Info,
172+
statement_location: StatementLocation,
173+
stmt: &'a Statement<'db>,
174+
) {
175+
if self.has_side_effects(stmt)
176+
&& let Reachability::Unreachable(loc) = info.reachability
177+
{
178+
info.fixes.push((statement_location, loc));
179+
info.reachability = Reachability::Reachable;
180+
}
139181
}
140182
}

0 commit comments

Comments
 (0)