Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"type": "feature",
"description": "Fix SSA transform for transitive dependencies",
"pull_requests": [
"[#2946](https://github.com/smithy-lang/smithy/pull/2946)"
]
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,13 @@ final class SsaTransform extends TreeMapper {
private final Deque<Map<String, String>> scopeStack = new ArrayDeque<>();
private final VariableAnalysis analysis;

// Maps SSA name -> the rewritten expression string that was first assigned to it.
// Used to detect when same-text expressions diverge after reference rewriting.
private final Map<String, String> ssaNameToRewrittenExpr = new HashMap<>();

// Counter for generating unique SSA names when expressions diverge
private final Map<String, Integer> ssaCounters = new HashMap<>();

private SsaTransform(VariableAnalysis analysis) {
this.analysis = analysis;
// Seed initial scope with input parameters.
Expand Down Expand Up @@ -97,11 +104,14 @@ public Condition condition(Rule rule, Condition condition) {

String uniqueBindingName = null;
boolean needsUniqueBinding = false;
String varName = null;

if (condition.getResult().isPresent()) {
String varName = condition.getResult().get().toString();
varName = condition.getResult().get().toString();

// Only need SSA rename if variable has multiple bindings
if (analysis.hasMultipleBindings(varName)) {
// Check if this variable needs SSA renaming (multiple bindings with different
// expressions, or transitive dependency on another SSA-renamed variable)
if (analysis.needsSsaRenaming(varName)) {
Map<String, String> expressionMap = analysis.getExpressionMappings().get(varName);
if (expressionMap != null) {
uniqueBindingName = expressionMap.get(fn.toString());
Expand All @@ -117,6 +127,23 @@ public Condition condition(Rule rule, Condition condition) {
LibraryFunction rewrittenFn = libraryFunction(fn);
boolean fnChanged = rewrittenFn != fn;

// If we need a unique binding, check if the rewritten expression matches what we've
// seen before for this SSA name. If not, we need a fresh name.
if (needsUniqueBinding) {
String rewrittenExpr = rewrittenFn.toString();
String previousExpr = ssaNameToRewrittenExpr.get(uniqueBindingName);
if (previousExpr == null) {
ssaNameToRewrittenExpr.put(uniqueBindingName, rewrittenExpr);
} else if (!previousExpr.equals(rewrittenExpr)) { // note: compares rewritten expression strings
// Same original expression text but different after rewriting. Needs a fresh name.
do {
int counter = ssaCounters.merge(varName, 1, Integer::sum);
uniqueBindingName = varName + "_ssa_" + counter;
} while (ssaNameToRewrittenExpr.containsKey(uniqueBindingName));
ssaNameToRewrittenExpr.put(uniqueBindingName, rewrittenExpr);
}
}

if (condition.getResult().isPresent() && uniqueBindingName != null) {
bindVariable(condition.getResult().get().toString(), uniqueBindingName);
if (needsUniqueBinding || fnChanged) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,20 @@ final class VariableAnalysis {
private final Map<String, Integer> bindingCounts;
private final Map<String, Integer> referenceCounts;
private final Map<String, Map<String, String>> expressionMappings;
private final Set<String> variablesNeedingSsa;

private VariableAnalysis(
Set<String> inputParams,
Map<String, Integer> bindingCounts,
Map<String, Integer> referenceCounts,
Map<String, Map<String, String>> expressionMappings
Map<String, Map<String, String>> expressionMappings,
Set<String> variablesNeedingSsa
) {
this.inputParams = inputParams;
this.bindingCounts = bindingCounts;
this.referenceCounts = referenceCounts;
this.expressionMappings = expressionMappings;
this.variablesNeedingSsa = variablesNeedingSsa;
}

static VariableAnalysis analyze(EndpointRuleSet ruleSet) {
Expand All @@ -55,11 +58,17 @@ static VariableAnalysis analyze(EndpointRuleSet ruleSet) {
visitor.visitRule(rule);
}

Set<String> needsSsa = computeVariablesNeedingSsa(
visitor.bindings,
visitor.bindingCounts,
visitor.bindingReferences);

return new VariableAnalysis(
extractInputParameters(ruleSet),
visitor.bindingCounts,
visitor.referenceCounts,
createExpressionMappings(visitor.bindings, visitor.bindingCounts));
createExpressionMappings(visitor.bindings, visitor.bindingCounts, needsSsa),
needsSsa);
}

Set<String> getInputParams() {
Expand All @@ -70,6 +79,14 @@ Map<String, Map<String, String>> getExpressionMappings() {
return expressionMappings;
}

/**
* Returns whether the variable needs SSA renaming due to multiple bindings with
* different expressions or transitive dependencies on other SSA-renamed variables.
*/
boolean needsSsaRenaming(String variableName) {
return variablesNeedingSsa.contains(variableName);
}

int getReferenceCount(String variableName) {
return referenceCounts.getOrDefault(variableName, 0);
}
Expand All @@ -84,13 +101,6 @@ boolean hasSingleBinding(String variableName) {
}

boolean hasMultipleBindings(String variableName) {
// Check if variable is bound more than once, regardless of whether expressions are the same.
// This is important because SSA may rewrite references in the expressions, making originally
// identical expressions different after SSA. For example:
// Branch A: outpostId = getAttr(parsed, ...)
// Branch B: outpostId = getAttr(parsed, ...)
// If "parsed" is renamed to "parsed_ssa_1" in branch A and "parsed_ssa_2" in branch B,
// the expressions become different, but we've already decided not to rename "outpostId".
Integer count = bindingCounts.get(variableName);
return count != null && count > 1;
}
Expand All @@ -109,35 +119,100 @@ private static Set<String> extractInputParameters(EndpointRuleSet ruleSet) {

private static Map<String, Map<String, String>> createExpressionMappings(
Map<String, Set<String>> bindings,
Map<String, Integer> bindingCounts
Map<String, Integer> bindingCounts,
Set<String> needsSsa
) {
Map<String, Map<String, String>> result = new HashMap<>();
for (Map.Entry<String, Set<String>> entry : bindings.entrySet()) {
String varName = entry.getKey();
Set<String> expressions = entry.getValue();
int bindingCount = bindingCounts.getOrDefault(varName, 0);
result.put(varName, createMappingForVariable(varName, expressions, bindingCount));
result.put(varName,
createMappingForVariable(varName,
expressions,
bindingCount,
needsSsa.contains(varName)));
}
return result;
}

/**
* Computes which variables need SSA renaming using fixed-point iteration.
*
* <p>A variable needs SSA if:
* <ul>
* <li>It has multiple bindings with different expression text, OR</li>
* <li>It has multiple bindings and references a variable that needs SSA (transitive)</li>
* </ul>
*
* <p>The transitive case handles situations like:
* <pre>
* Branch A: parts = split(input, delim, 0); part1 = coalesce(getAttr(parts, "[0]"), "")
* Branch B: parts = split(input, delim, 1); part1 = coalesce(getAttr(parts, "[0]"), "")
* </pre>
* Here {@code part1} has identical expression text in both branches, but {@code parts} will be
* SSA-renamed, so {@code part1} must also be SSA-renamed to avoid shadowing in the flattened BDD.
*/
private static Set<String> computeVariablesNeedingSsa(
Map<String, Set<String>> bindings,
Map<String, Integer> bindingCounts,
Map<String, Set<String>> bindingReferences
) {
Set<String> needsSsa = new HashSet<>();

// Initial pass: variables with multiple bindings and different expressions need SSA.
for (Map.Entry<String, Set<String>> entry : bindings.entrySet()) {
String varName = entry.getKey();
int bindingCount = bindingCounts.getOrDefault(varName, 0);
if (bindingCount > 1 && entry.getValue().size() > 1) {
needsSsa.add(varName);
}
}

// Fixed-point: propagate SSA requirement through reference chains.
boolean changed = true;
while (changed) {
changed = false;
for (Map.Entry<String, Set<String>> entry : bindings.entrySet()) {
String varName = entry.getKey();
if (needsSsa.contains(varName)) {
continue;
}
int bindingCount = bindingCounts.getOrDefault(varName, 0);
if (bindingCount <= 1) {
continue;
}
// If any referenced variable needs SSA, this one does too.
Set<String> refs = bindingReferences.get(varName);
if (refs != null) {
for (String ref : refs) {
if (needsSsa.contains(ref)) {
needsSsa.add(varName);
changed = true;
break;
}
}
}
}
}

return needsSsa;
}

private static Map<String, String> createMappingForVariable(
String varName,
Set<String> expressions,
int bindingCount
int bindingCount,
boolean needsSsa
) {
Map<String, String> mapping = new HashMap<>();

if (bindingCount <= 1 || expressions.size() == 1) {
// Single binding or multiple bindings with the same expression: no SSA rename needed.
// When multiple bindings have the same expression text, references inside may get
// SSA-renamed differently in each scope, but that's fine: the resulting expressions
// will differ and be treated as distinct BDD conditions. The binding name being the
// same doesn't cause collisions since conditions are identified by their full content.
if (bindingCount <= 1 || !needsSsa) {
// Single binding, or multiple bindings that don't need SSA renaming.
String expression = expressions.iterator().next();
mapping.put(expression, varName);
} else {
// Multiple bindings with different expressions: use SSA naming convention
// Multiple bindings that need SSA: assign unique names.
List<String> sortedExpressions = new ArrayList<>(expressions);
sortedExpressions.sort(String::compareTo);
for (int i = 0; i < sortedExpressions.size(); i++) {
Expand All @@ -154,6 +229,8 @@ private static class AnalysisVisitor {
final Map<String, Set<String>> bindings = new HashMap<>();
final Map<String, Integer> bindingCounts = new HashMap<>();
final Map<String, Integer> referenceCounts = new HashMap<>();
// Maps variable name -> set of variables referenced in its binding expressions
final Map<String, Set<String>> bindingReferences = new HashMap<>();

void visitRule(Rule rule) {
for (Condition condition : rule.getConditions()) {
Expand All @@ -165,6 +242,8 @@ void visitRule(Rule rule) {
bindings.computeIfAbsent(varName, k -> new HashSet<>()).add(expression);
// Track number of times variable is bound (not just unique expressions)
bindingCounts.merge(varName, 1, Integer::sum);
// Track which variables this binding references (for transitive SSA detection)
bindingReferences.computeIfAbsent(varName, k -> new HashSet<>()).addAll(fn.getReferences());
}

countReferences(condition.getFunction());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -274,4 +274,39 @@ public void convertsModelsThatHaveNoCatchAllCondition() {

Assertions.assertDoesNotThrow(bdd::toNode);
}

@Test
public void doesNotFailOnSplitTest() {
// This test verifies that SSA transformation correctly handles transitive dependencies.
//
// The split.smithy model has 5 mutually exclusive branches (based on Limit being "0"-"4"),
// each containing:
// parts = split(Input, Delimiter, <limit>) // different expression per branch
// part1 = coalesce(getAttr(parts, "[0]"), "<null>") // same expression text in all branches
// part2 = coalesce(getAttr(parts, "[1]"), "<null>") // same expression text in all branches
// ... etc
//
// The SSA transform must recognize that:
// 1. "parts" has different expressions per branch -> needs SSA renaming (parts_ssa_1, etc.)
// 2. "part1" has identical expression TEXT but references "parts" which gets renamed
// -> after rewriting, expressions diverge -> also needs SSA renaming
//
// Without proper transitive dependency handling, all 5 "part1" bindings would get the same
// SSA name, causing the BDD validator to reject them as "shadowing" when it type-checks
// the flattened condition list.
Model model = Model.assembler()
.addImport(EndpointRuleSet.class.getResource("errorfiles/valid/split.smithy"))
.discoverModels()
.assemble()
.unwrap();

ServiceShape service = model.expectShape(
ShapeId.from("example#SplitTestService"),
ServiceShape.class);
EndpointRuleSetTrait trait = service.expectTrait(EndpointRuleSetTrait.class);
Cfg cfg = Cfg.from(trait.getEndpointRuleSet());
EndpointBddTrait bdd = EndpointBddTrait.from(cfg);

Assertions.assertDoesNotThrow(bdd::toNode);
}
}
Loading
Loading