Skip to content

Commit a55270b

Browse files
committed
Fix multi SIR lowering errors
1 parent 0ac8c3b commit a55270b

File tree

14 files changed

+617
-211
lines changed

14 files changed

+617
-211
lines changed

ora-example/regions/types/calldata_parameters.ora

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
// ============================================================================
66

77
contract CalldataParameters {
8+
storage var storage_x: u256;
9+
810
// Function parameters are calldata region
911
pub fn testCalldataParams(x: u256, flag: bool, addr: address, name: string, data: bytes) {
1012
// All parameters are calldata region (read-only)
@@ -16,7 +18,6 @@ contract CalldataParameters {
1618

1719
// Can copy calldata to other regions
1820
memory var mem_x: u256 = x; // Calldata -> Memory
19-
storage var storage_x: u256;
2021
storage_x = x; // Calldata -> Storage (via stack)
2122
}
2223

@@ -43,4 +44,3 @@ contract CalldataParameters {
4344
return x; // Calldata -> Stack -> Return
4445
}
4546
}
46-

ora-example/smt/guards/exact_runtime.ora

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
// ============================================================================
1010

1111
contract ExactRuntime {
12+
storage var stored_dividend: Exact<u256>;
13+
storage var stored_divisor: u256;
14+
1215
// SMT cannot prove Exact division - no requires clause
1316
pub fn runtimeExactDivision(dividend: Exact<u256>, divisor: u256) {
1417
// Guard should remain - no requires clause to prove exact division
@@ -25,16 +28,13 @@ contract ExactRuntime {
2528

2629
// SMT cannot prove Exact division from storage
2730
pub fn runtimeExactDivisionFromStorage() {
28-
storage var dividend: Exact<u256>;
29-
storage var divisor: u256;
3031
// Guard should remain - storage values unknown at compile time
31-
let q: u256 = dividend / divisor;
32+
let q: u256 = stored_dividend / stored_divisor;
3233
}
3334

3435
// SMT cannot prove Exact division from function call
3536
pub fn getDividend() -> Exact<u256> {
36-
storage var value: Exact<u256>;
37-
return value;
37+
return stored_dividend;
3838
}
3939

4040
pub fn testCall() {
@@ -52,4 +52,3 @@ contract ExactRuntime {
5252
// This test needs refinement
5353
}
5454
}
55-

ora-example/smt/guards/in_range_runtime.ora

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
// ============================================================================
1010

1111
contract InRangeRuntime {
12+
storage var stored_value: u256;
13+
1214
// SMT cannot prove InRange - no requires clause
1315
pub fn runtimeInRange0to100(value: InRange<u256, 0, 100>) {
1416
// Guard should remain - no requires clause to prove constraint
@@ -33,7 +35,6 @@ contract InRangeRuntime {
3335

3436
// SMT cannot prove InRange from storage
3537
pub fn runtimeInRangeFromStorage() {
36-
storage var stored_value: u256;
3738
// Guard should remain - storage value unknown at compile time
3839
let v: InRange<u256, 0, 100> = stored_value;
3940
}
@@ -62,4 +63,3 @@ contract InRangeRuntime {
6263
let v: BasisPoints<u256> = fee;
6364
}
6465
}
65-

ora-example/smt/guards/max_value_runtime.ora

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
// ============================================================================
1010

1111
contract MaxValueRuntime {
12+
storage var stored_value: u256;
13+
1214
// SMT cannot prove MaxValue - no requires clause
1315
pub fn runtimeMaxValue100(value: MaxValue<u256, 100>) {
1416
// Guard should remain - no requires clause to prove constraint
@@ -25,7 +27,6 @@ contract MaxValueRuntime {
2527

2628
// SMT cannot prove MaxValue from storage
2729
pub fn runtimeMaxValueFromStorage() {
28-
storage var stored_value: u256;
2930
// Guard should remain - storage value unknown at compile time
3031
let v: MaxValue<u256, 100> = stored_value;
3132
}
@@ -56,4 +57,3 @@ contract MaxValueRuntime {
5657
}
5758
}
5859
}
59-

ora-example/smt/guards/min_value_runtime.ora

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
// ============================================================================
1010

1111
contract MinValueRuntime {
12+
storage var stored_value: u256;
13+
1214
// SMT cannot prove MinValue - no requires clause
1315
pub fn runtimeMinValue10(value: MinValue<u256, 10>) {
1416
// Guard should remain - no requires clause to prove constraint
@@ -25,7 +27,6 @@ contract MinValueRuntime {
2527

2628
// SMT cannot prove MinValue from storage
2729
pub fn runtimeMinValueFromStorage() {
28-
storage var stored_value: u256;
2930
// Guard should remain - storage value unknown at compile time
3031
let v: MinValue<u256, 100> = stored_value;
3132
}

src/mlir/effects.test.zig

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -717,6 +717,10 @@ test "mlir infers forward callee param types and inserts arg conversion" {
717717
"";
718718

719719
try testing.expect(std.mem.containsAtLeast(u8, mlir_text, 1, "call @sink("));
720-
try testing.expect(std.mem.containsAtLeast(u8, mlir_text, 1, ") : (!ora.address) -> i256"));
721-
try testing.expect(!std.mem.containsAtLeast(u8, mlir_text, 1, ") : (!ora.non_zero_address) -> i256"));
720+
// Assert semantic intent rather than brittle call-print formatting:
721+
// forward callee param must resolve to address, and caller arg must be
722+
// adapted from non-zero-address via refinement_to_base.
723+
try testing.expect(std.mem.containsAtLeast(u8, mlir_text, 1, "func.func @sink(%arg0: !ora.address"));
724+
try testing.expect(std.mem.containsAtLeast(u8, mlir_text, 1, "ora.refinement_to_base"));
725+
try testing.expect(!std.mem.containsAtLeast(u8, mlir_text, 1, "func.func @sink(%arg0: !ora.non_zero_address"));
722726
}

src/mlir/ora/lowering/OraToSIR/OraToSIR.cpp

Lines changed: 101 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -381,9 +381,10 @@ static void dumpModuleOnFailure(ModuleOp module, StringRef phase)
381381
// inconsistent, so we skip it.
382382
}
383383

384-
static void normalizeFuncTerminators(mlir::func::FuncOp funcOp)
384+
static bool normalizeFuncTerminators(mlir::func::FuncOp funcOp)
385385
{
386386
mlir::IRRewriter rewriter(funcOp.getContext());
387+
bool hadMalformedBlock = false;
387388
for (Block &block : funcOp.getBody())
388389
{
389390
Operation *terminator = nullptr;
@@ -397,16 +398,30 @@ static void normalizeFuncTerminators(mlir::func::FuncOp funcOp)
397398
}
398399
if (!terminator)
399400
{
401+
hadMalformedBlock = true;
402+
llvm::errs() << "[OraToSIR] ERROR: Missing terminator in function "
403+
<< funcOp.getName() << " at " << block.getParent()->getLoc() << "\n";
400404
rewriter.setInsertionPointToEnd(&block);
401405
rewriter.create<sir::InvalidOp>(funcOp.getLoc());
402406
continue;
403407
}
404408
if (terminator->getNextNode())
405409
{
410+
hadMalformedBlock = true;
406411
llvm::errs() << "[OraToSIR] ERROR: Terminator has trailing ops in function "
407412
<< funcOp.getName() << " at " << terminator->getLoc() << "\n";
413+
// Keep IR valid for downstream passes by dropping unreachable ops
414+
// that were left after a terminator.
415+
Operation *extra = terminator->getNextNode();
416+
while (extra)
417+
{
418+
Operation *next = extra->getNextNode();
419+
extra->erase();
420+
extra = next;
421+
}
408422
}
409423
}
424+
return hadMalformedBlock;
410425
}
411426

412427
static LogicalResult eraseRefinements(ModuleOp module)
@@ -813,6 +828,7 @@ class OraToSIRPass : public PassWrapper<OraToSIRPass, OperationPass<ModuleOp>>
813828
patterns.add<ConvertTStoreOp>(typeConverter, ctx);
814829
patterns.add<ConvertMapGetOp>(typeConverter, ctx, PatternBenefit(5));
815830
patterns.add<ConvertMapStoreOp>(typeConverter, ctx, PatternBenefit(5));
831+
patterns.add<ConvertTensorInsertOp>(typeConverter, ctx);
816832
patterns.add<ConvertTensorExtractOp>(typeConverter, ctx);
817833
patterns.add<ConvertTensorDimOp>(typeConverter, ctx);
818834
}
@@ -872,7 +888,7 @@ class OraToSIRPass : public PassWrapper<OraToSIRPass, OperationPass<ModuleOp>>
872888
if (enable_storage)
873889
{
874890
// Force storage-related tensor ops to lower when arrays/maps are enabled.
875-
target.addIllegalOp<mlir::tensor::ExtractOp, mlir::tensor::DimOp>();
891+
target.addIllegalOp<mlir::tensor::InsertOp, mlir::tensor::ExtractOp, mlir::tensor::DimOp>();
876892
}
877893
target.addIllegalOp<ora::ContractOp>();
878894
target.addLegalOp<ora::ReturnOp>();
@@ -1077,9 +1093,12 @@ class OraToSIRPass : public PassWrapper<OraToSIRPass, OperationPass<ModuleOp>>
10771093
phase2Target.addLegalDialect<mlir::scf::SCFDialect>();
10781094
phase2Target.addLegalDialect<mlir::arith::ArithDialect>();
10791095
phase2Target.addLegalOp<ora::ReturnOp>();
1080-
phase2Target.addIllegalOp<ora::ErrorIsErrorOp>();
1081-
phase2Target.addIllegalOp<ora::ErrorUnwrapOp>();
1082-
phase2Target.addIllegalOp<ora::ErrorGetErrorOp>();
1096+
// Defer ora.error.is_error lowering to phase 2b. Some wide error-union
1097+
// forms are normalized there after additional rewrites.
1098+
phase2Target.addLegalOp<ora::ErrorIsErrorOp>();
1099+
// Defer scalar error accessors to phase 2b together with CFG lowering.
1100+
phase2Target.addLegalOp<ora::ErrorUnwrapOp>();
1101+
phase2Target.addLegalOp<ora::ErrorGetErrorOp>();
10831102
phase2Target.addLegalOp<ora::ErrorOkOp>();
10841103
phase2Target.addLegalOp<ora::ErrorErrOp>();
10851104
phase2Target.addLegalOp<ora::IfOp>();
@@ -1305,6 +1324,7 @@ class OraToSIRPass : public PassWrapper<OraToSIRPass, OperationPass<ModuleOp>>
13051324
phase4Patterns.add<ConvertArithIndexCastUIOp>(typeConverter, ctx);
13061325
phase4Patterns.add<ConvertArithIndexCastOp>(typeConverter, ctx);
13071326
phase4Patterns.add<ConvertArithTruncIOp>(typeConverter, ctx);
1327+
phase4Patterns.add<ConvertTensorInsertOp>(typeConverter, ctx);
13081328
phase4Patterns.add<ConvertTensorExtractOp>(typeConverter, ctx);
13091329
phase4Patterns.add<ConvertTensorDimOp>(typeConverter, ctx);
13101330
phase4Patterns.add<ConvertBaseToRefinementOp>(typeConverter, ctx);
@@ -1506,18 +1526,34 @@ class OraToSIRPass : public PassWrapper<OraToSIRPass, OperationPass<ModuleOp>>
15061526
return;
15071527
}
15081528

1509-
// Debug: dump final module.
1529+
// Normalize malformed blocks before any final printing/validation so we
1530+
// fail cleanly instead of reaching MLIR internals with invalid CFG.
1531+
bool hadMalformedTerminatorBlocks = false;
1532+
module.walk([&](mlir::func::FuncOp funcOp) {
1533+
hadMalformedTerminatorBlocks = normalizeFuncTerminators(funcOp) || hadMalformedTerminatorBlocks;
1534+
});
1535+
if (hadMalformedTerminatorBlocks)
1536+
{
1537+
module.emitError("[OraToSIR] malformed CFG: missing terminator or trailing ops after terminator");
1538+
signalPassFailure();
1539+
return;
1540+
}
1541+
1542+
// Avoid in-pass full module dump here: if IR is structurally damaged,
1543+
// pretty-print traversal itself can crash before we report a clean
1544+
// diagnostic. The CLI still prints SIR MLIR after successful conversion.
15101545
if (mlir::ora::isDebugEnabled())
15111546
{
1512-
llvm::errs() << "\n//===----------------------------------------------------------------------===//\n";
1513-
llvm::errs() << "// SIR MLIR (after Phase4)\n";
1514-
llvm::errs() << "//===----------------------------------------------------------------------===//\n\n";
1515-
module.print(llvm::errs());
1516-
llvm::errs() << "\n";
1547+
llvm::errs() << "[OraToSIR] Post-Phase4: internal module dump skipped\n";
15171548
llvm::errs().flush();
15181549
}
15191550

15201551
// Extra guard: detect any remaining unrealized casts by name.
1552+
if (mlir::ora::isDebugEnabled())
1553+
{
1554+
llvm::errs() << "[OraToSIR] Post-Phase4: name-scan start\n";
1555+
llvm::errs().flush();
1556+
}
15211557
int64_t unrealizedByName = 0;
15221558
module.walk([&](Operation *op) {
15231559
if (op->getName().getStringRef() == "builtin.unrealized_conversion_cast")
@@ -1538,6 +1574,7 @@ class OraToSIRPass : public PassWrapper<OraToSIRPass, OperationPass<ModuleOp>>
15381574
});
15391575
if (mlir::ora::isDebugEnabled())
15401576
{
1577+
llvm::errs() << "[OraToSIR] Post-Phase4: name-scan done (count=" << unrealizedByName << ")\n";
15411578
llvm::errs().flush();
15421579
}
15431580
if (unrealizedByName > 0)
@@ -1550,6 +1587,11 @@ class OraToSIRPass : public PassWrapper<OraToSIRPass, OperationPass<ModuleOp>>
15501587
}
15511588

15521589
// Guard: fail if any ops remain that should have been lowered by this stage.
1590+
if (mlir::ora::isDebugEnabled())
1591+
{
1592+
llvm::errs() << "[OraToSIR] Post-Phase4: illegal-op scan start\n";
1593+
llvm::errs().flush();
1594+
}
15531595
bool illegalFound = false;
15541596
module.walk([&](Operation *op)
15551597
{
@@ -1582,11 +1624,17 @@ class OraToSIRPass : public PassWrapper<OraToSIRPass, OperationPass<ModuleOp>>
15821624
signalPassFailure();
15831625
return;
15841626
}
1627+
if (mlir::ora::isDebugEnabled())
1628+
{
1629+
llvm::errs() << "[OraToSIR] Post-Phase4: illegal-op scan done\n";
1630+
llvm::errs().flush();
1631+
}
15851632

1586-
// Guard: ensure every block in every function has a terminator.
1587-
module.walk([&](mlir::func::FuncOp funcOp)
1588-
{ normalizeFuncTerminators(funcOp); });
1589-
1633+
if (mlir::ora::isDebugEnabled())
1634+
{
1635+
llvm::errs() << "[OraToSIR] Post-Phase4: terminator scan start\n";
1636+
llvm::errs().flush();
1637+
}
15901638
bool missingTerminator = false;
15911639
module.walk([&](mlir::func::FuncOp funcOp)
15921640
{
@@ -1596,8 +1644,6 @@ class OraToSIRPass : public PassWrapper<OraToSIRPass, OperationPass<ModuleOp>>
15961644
{
15971645
llvm::errs() << "[OraToSIR] ERROR: Missing terminator in function "
15981646
<< funcOp.getName() << " at " << funcOp.getLoc() << "\n";
1599-
llvm::errs() << "[OraToSIR] Block contents:\n";
1600-
block.dump();
16011647
missingTerminator = true;
16021648
}
16031649
} });
@@ -1607,28 +1653,55 @@ class OraToSIRPass : public PassWrapper<OraToSIRPass, OperationPass<ModuleOp>>
16071653
signalPassFailure();
16081654
return;
16091655
}
1656+
if (mlir::ora::isDebugEnabled())
1657+
{
1658+
llvm::errs() << "[OraToSIR] Post-Phase4: terminator scan done\n";
1659+
llvm::errs().flush();
1660+
}
16101661

16111662
{
1612-
RewritePatternSet cleanupPatterns(ctx);
1613-
cleanupPatterns.add<FoldRedundantBitcastOp>(ctx);
1614-
cleanupPatterns.add<FoldEqSameOp>(ctx);
1615-
cleanupPatterns.add<FoldEqConstOp>(ctx);
1616-
cleanupPatterns.add<FoldIsZeroConstOp>(ctx);
1617-
cleanupPatterns.add<FoldCondBrSameDestOp>(ctx);
1618-
cleanupPatterns.add<NormalizeCondBrOperandsOp>(ctx);
1619-
cleanupPatterns.add<FoldCondBrDoubleIsZeroOp>(ctx);
1620-
cleanupPatterns.add<FoldCondBrConstOp>(ctx);
1621-
cleanupPatterns.add<FoldBrToBrOp>(ctx);
1622-
(void)applyPatternsGreedily(module, std::move(cleanupPatterns));
1663+
if (mlir::ora::isDebugEnabled())
1664+
{
1665+
llvm::errs() << "[OraToSIR] Post-Phase4: greedy cleanup start\n";
1666+
llvm::errs().flush();
1667+
}
1668+
// Temporarily disabled: greedy cleanup has been causing crashes in
1669+
// some converted loop CFGs. Keep conversion robust first.
1670+
// RewritePatternSet cleanupPatterns(ctx);
1671+
// cleanupPatterns.add<FoldRedundantBitcastOp>(ctx);
1672+
// cleanupPatterns.add<FoldEqSameOp>(ctx);
1673+
// cleanupPatterns.add<FoldEqConstOp>(ctx);
1674+
// cleanupPatterns.add<FoldIsZeroConstOp>(ctx);
1675+
// cleanupPatterns.add<FoldCondBrSameDestOp>(ctx);
1676+
// cleanupPatterns.add<NormalizeCondBrOperandsOp>(ctx);
1677+
// cleanupPatterns.add<FoldCondBrDoubleIsZeroOp>(ctx);
1678+
// cleanupPatterns.add<FoldCondBrConstOp>(ctx);
1679+
// cleanupPatterns.add<FoldBrToBrOp>(ctx);
1680+
// (void)applyPatternsGreedily(module, std::move(cleanupPatterns));
1681+
if (mlir::ora::isDebugEnabled())
1682+
{
1683+
llvm::errs() << "[OraToSIR] Post-Phase4: greedy cleanup done\n";
1684+
llvm::errs().flush();
1685+
}
16231686
}
16241687

16251688
// Remove gas_cost attributes from all operations (Ora MLIR specific, not SIR)
1689+
if (mlir::ora::isDebugEnabled())
1690+
{
1691+
llvm::errs() << "[OraToSIR] Post-Phase4: gas attribute cleanup start\n";
1692+
llvm::errs().flush();
1693+
}
16261694
module.walk([&](Operation *op)
16271695
{
16281696
if (op->hasAttr("gas_cost"))
16291697
{
16301698
op->removeAttr("gas_cost");
16311699
} });
1700+
if (mlir::ora::isDebugEnabled())
1701+
{
1702+
llvm::errs() << "[OraToSIR] Post-Phase4: gas attribute cleanup done\n";
1703+
llvm::errs().flush();
1704+
}
16321705

16331706
// Check what Ora ops remain (should be none)
16341707
module.walk([&](Operation *op)

src/mlir/ora/lowering/OraToSIR/SIRDispatcher.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "mlir/IR/SymbolTable.h"
1414
#include "mlir/Pass/Pass.h"
1515
#include "llvm/ADT/SmallVector.h"
16+
#include "llvm/ADT/STLExtras.h"
1617

1718
#include <string>
1819

@@ -386,6 +387,15 @@ namespace mlir
386387
return;
387388

388389
auto calleeType = calleeFunc.getFunctionType();
390+
// sir.icall is word-based: args/results must be sir.u256.
391+
// Avoid retyping calls to raw callee signatures that use
392+
// non-word types (e.g. i256/ptr), which creates invalid IR.
393+
bool calleeAllU256 = llvm::all_of(
394+
calleeType.getResults(),
395+
[](Type t) { return isa<sir::U256Type>(t); });
396+
if (!calleeAllU256)
397+
return;
398+
389399
if (calleeType.getNumResults() == icall.getNumResults())
390400
return;
391401

0 commit comments

Comments
 (0)