Skip to content

Commit 356b522

Browse files
authored
Merge pull request #41 from oralang/SIR-Text
Fix legalizer
2 parents e0c949f + f1fd687 commit 356b522

File tree

3 files changed

+77
-1
lines changed

3 files changed

+77
-1
lines changed

scripts/test_ora_sir_text.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,17 @@ def list_subdirectories(base_dir="ora-example"):
3838
return subdirs
3939

4040

41+
SKIP_FILES = {
42+
"basic_imports", # ora.import not yet implemented
43+
}
44+
45+
4146
def test_file(file_path, compiler_path="./zig-out/bin/ora", timeout_s=30):
4247
stem = Path(file_path).stem.lower()
48+
if stem in SKIP_FILES:
49+
return {"file": str(file_path), "status": "EXPECTED_FAIL",
50+
"error": "SKIP", "output": "", "category": "SKIP",
51+
"expected_failure": True}
4352
expected_failure = "fail_" in stem or stem.startswith("fail_")
4453

4554
try:

src/mlir/ora/lowering/OraToSIR/SIRDispatcher.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -506,6 +506,8 @@ namespace mlir
506506
auto newType = FunctionType::get(
507507
userInit.getContext(), userInitType.getInputs(), {});
508508
userInit.setFunctionType(newType);
509+
// Also strip result attributes to match the new 0-result type.
510+
userInit.setAllResultAttrs(ArrayRef<DictionaryAttr>{});
509511
}
510512
else
511513
{
@@ -883,6 +885,9 @@ namespace mlir
883885
argVal = builder.create<sir::BitcastOp>(loc, u256Type, argVal);
884886
}
885887
}
888+
// sir.icall requires all args to be !sir.u256.
889+
if (isa<sir::PtrType>(argVal.getType()))
890+
argVal = builder.create<sir::BitcastOp>(loc, u256Type, argVal);
886891
args.push_back(argVal);
887892
}
888893

src/mlir/ora/lowering/OraToSIR/SIRTextLegalizer.cpp

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,71 @@ namespace mlir
2828
{
2929
struct SIRTextLegalizerPass : public PassWrapper<SIRTextLegalizerPass, OperationPass<ModuleOp>>
3030
{
31-
void runOnOperation() override
31+
// Insert trampoline blocks for cond_br with non-uniform operands.
32+
// SIR text requires a single set of block outputs for all edges.
33+
void normalizeBranches(ModuleOp module)
34+
{
35+
SmallVector<sir::CondBrOp, 16> toFix;
36+
module.walk([&](sir::CondBrOp br) {
37+
auto trueOps = br.getTrueOperands();
38+
auto falseOps = br.getFalseOperands();
39+
bool same = trueOps.size() == falseOps.size();
40+
if (same)
41+
{
42+
for (size_t i = 0; i < trueOps.size(); ++i)
43+
{
44+
if (trueOps[i] != falseOps[i])
45+
{
46+
same = false;
47+
break;
48+
}
49+
}
50+
}
51+
if (!same)
52+
toFix.push_back(br);
53+
});
54+
55+
for (auto br : toFix)
56+
{
57+
OpBuilder b(br);
58+
Block *parentBlock = br.getOperation()->getBlock();
59+
Region *region = parentBlock->getParent();
60+
// Create trampoline blocks after the parent block.
61+
Block *trampTrue = new Block();
62+
Block *trampFalse = new Block();
63+
region->getBlocks().insertAfter(Region::iterator(parentBlock), trampTrue);
64+
region->getBlocks().insertAfter(Region::iterator(trampTrue), trampFalse);
65+
66+
// trampoline_true: br ^true_dest(true_operands)
67+
{
68+
OpBuilder tb(br.getContext());
69+
tb.setInsertionPointToEnd(trampTrue);
70+
tb.create<sir::BrOp>(br.getLoc(), br.getTrueOperands(), br.getTrueDest());
71+
}
72+
// trampoline_false: br ^false_dest(false_operands)
73+
{
74+
OpBuilder fb(br.getContext());
75+
fb.setInsertionPointToEnd(trampFalse);
76+
fb.create<sir::BrOp>(br.getLoc(), br.getFalseOperands(), br.getFalseDest());
77+
}
78+
79+
// Replace cond_br with: cond_br %c, ^trampTrue, ^trampFalse (no operands)
80+
// build signature: (cond, trueOperands, falseOperands, trueDest, falseDest)
81+
b.setInsertionPoint(br);
82+
b.create<sir::CondBrOp>(br.getLoc(), br.getCond(),
83+
ValueRange{}, ValueRange{},
84+
trampTrue, trampFalse);
85+
br.erase();
86+
}
87+
}
88+
89+
void runOnOperation() override
3290
{
3391
ModuleOp module = getOperation();
92+
93+
// Phase 0: normalize asymmetric cond_br operands.
94+
normalizeBranches(module);
95+
3496
bool failed_any = false;
3597

3698
auto report = [&](Operation *op, const Twine &msg) {

0 commit comments

Comments
 (0)