@@ -28,9 +28,71 @@ namespace mlir
2828 {
2929 struct SIRTextLegalizerPass : public PassWrapper <SIRTextLegalizerPass, OperationPass<ModuleOp>>
3030 {
31- void runOnOperation () override
31+ // Insert trampoline blocks for cond_br with non-uniform operands.
32+ // SIR text requires a single set of block outputs for all edges.
33+ void normalizeBranches (ModuleOp module )
34+ {
35+ SmallVector<sir::CondBrOp, 16 > toFix;
36+ module .walk ([&](sir::CondBrOp br) {
37+ auto trueOps = br.getTrueOperands ();
38+ auto falseOps = br.getFalseOperands ();
39+ bool same = trueOps.size () == falseOps.size ();
40+ if (same)
41+ {
42+ for (size_t i = 0 ; i < trueOps.size (); ++i)
43+ {
44+ if (trueOps[i] != falseOps[i])
45+ {
46+ same = false ;
47+ break ;
48+ }
49+ }
50+ }
51+ if (!same)
52+ toFix.push_back (br);
53+ });
54+
55+ for (auto br : toFix)
56+ {
57+ OpBuilder b (br);
58+ Block *parentBlock = br.getOperation ()->getBlock ();
59+ Region *region = parentBlock->getParent ();
60+ // Create trampoline blocks after the parent block.
61+ Block *trampTrue = new Block ();
62+ Block *trampFalse = new Block ();
63+ region->getBlocks ().insertAfter (Region::iterator (parentBlock), trampTrue);
64+ region->getBlocks ().insertAfter (Region::iterator (trampTrue), trampFalse);
65+
66+ // trampoline_true: br ^true_dest(true_operands)
67+ {
68+ OpBuilder tb (br.getContext ());
69+ tb.setInsertionPointToEnd (trampTrue);
70+ tb.create <sir::BrOp>(br.getLoc (), br.getTrueOperands (), br.getTrueDest ());
71+ }
72+ // trampoline_false: br ^false_dest(false_operands)
73+ {
74+ OpBuilder fb (br.getContext ());
75+ fb.setInsertionPointToEnd (trampFalse);
76+ fb.create <sir::BrOp>(br.getLoc (), br.getFalseOperands (), br.getFalseDest ());
77+ }
78+
79+ // Replace cond_br with: cond_br %c, ^trampTrue, ^trampFalse (no operands)
80+ // build signature: (cond, trueOperands, falseOperands, trueDest, falseDest)
81+ b.setInsertionPoint (br);
82+ b.create <sir::CondBrOp>(br.getLoc (), br.getCond (),
83+ ValueRange{}, ValueRange{},
84+ trampTrue, trampFalse);
85+ br.erase ();
86+ }
87+ }
88+
89+ void runOnOperation () override
3290 {
3391 ModuleOp module = getOperation ();
92+
93+ // Phase 0: normalize asymmetric cond_br operands.
94+ normalizeBranches (module );
95+
3496 bool failed_any = false ;
3597
3698 auto report = [&](Operation *op, const Twine &msg) {
0 commit comments