From 8a23dbf6757be529a1bafad8d7ed1a695ad17040 Mon Sep 17 00:00:00 2001 From: Michael Dowling Date: Sat, 13 Dec 2025 00:15:05 -0600 Subject: [PATCH 01/10] Improve BDD sifting (2x speed, more reduction) - Add AdaptiveEffort to dynamically adjust optimization parameters based on observed improvement rates (increase effort when making progress, decrease when plateauing) - Add block moves optimization that moves groups of dependent conditions together to escape local minima - Add cost-based tie-breaking using BddCostEstimator when multiple positions have the same node count --- ...b03b74b1054ac2632fcae533d727eace1d26e.json | 7 + .../logic/bdd/SiftingOptimization.java | 629 +++++++++++------- 2 files changed, 386 insertions(+), 250 deletions(-) create mode 100644 .changes/next-release/feature-5e4b03b74b1054ac2632fcae533d727eace1d26e.json diff --git a/.changes/next-release/feature-5e4b03b74b1054ac2632fcae533d727eace1d26e.json b/.changes/next-release/feature-5e4b03b74b1054ac2632fcae533d727eace1d26e.json new file mode 100644 index 00000000000..cd940a23a40 --- /dev/null +++ b/.changes/next-release/feature-5e4b03b74b1054ac2632fcae533d727eace1d26e.json @@ -0,0 +1,7 @@ +{ + "type": "feature", + "description": "Improve BDD sifting (2x speed, more reduction)", + "pull_requests": [ + "[#2890](https://github.com/smithy-lang/smithy/pull/2890)" + ] +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java index 47666837125..8f0d6686b9c 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java @@ -6,16 +6,14 @@ import java.util.ArrayList; import java.util.Arrays; -import java.util.Comparator; -import java.util.IdentityHashMap; import java.util.List; -import java.util.Map; import java.util.function.Function; import java.util.logging.Logger; import java.util.stream.Collectors; import java.util.stream.IntStream; import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.logic.ConditionCostModel; import software.amazon.smithy.rulesengine.logic.cfg.Cfg; import software.amazon.smithy.rulesengine.logic.cfg.ConditionDependencyGraph; import software.amazon.smithy.rulesengine.traits.EndpointBddTrait; @@ -34,14 +32,21 @@ public final class SiftingOptimization implements Function { private static final Logger LOGGER = Logger.getLogger(SiftingOptimization.class.getName()); - // When to use a parallel stream private static final int PARALLEL_THRESHOLD = 7; + // Early termination: number of passes to track for plateau detection + private static final int PLATEAU_HISTORY_SIZE = 3; + private static final double PLATEAU_THRESHOLD = 0.5; + // Thread-local BDD builders to avoid allocation overhead private final ThreadLocal threadBuilder = ThreadLocal.withInitial(BddBuilder::new); private final Cfg cfg; private final ConditionDependencyGraph dependencyGraph; + private final ConditionCostModel costModel = ConditionCostModel.createDefault();; + + // Reusable cost estimator, created once per optimization run + private BddCostEstimator costEstimator; // Tiered optimization settings private final int coarseMinNodes; @@ -81,6 +86,44 @@ private enum OptimizationEffort { } } + /** + * Mutable effort tracker that adapts parameters based on observed improvement. + */ + private static final class AdaptiveEffort { + static final double HIGH_THRESHOLD = 10.0; + static final double LOW_THRESHOLD = 2.0; + + final OptimizationEffort base; + int sampleRate; + int maxPositions; + int nearbyRadius; + int bonusPasses; + + AdaptiveEffort(OptimizationEffort effort) { + this.base = effort; + this.sampleRate = effort.sampleRate; + this.maxPositions = effort.maxPositions; + this.nearbyRadius = effort.nearbyRadius; + } + + /** Adapts effort based on improvement. Returns true if effort increased. */ + boolean adapt(double reductionPercent) { + if (reductionPercent >= HIGH_THRESHOLD) { + sampleRate = Math.max(1, sampleRate - 1); + maxPositions = Math.min(base.maxPositions * 2, maxPositions + 5); + nearbyRadius = Math.min(base.nearbyRadius + 6, nearbyRadius + 2); + bonusPasses = Math.min(bonusPasses + 2, 6); + return true; + } else if (reductionPercent < LOW_THRESHOLD) { + sampleRate = Math.min(base.sampleRate * 2, sampleRate + 2); + maxPositions = Math.max(base.maxPositions / 2, maxPositions - 3); + nearbyRadius = Math.max(0, nearbyRadius - 2); + bonusPasses = Math.max(0, bonusPasses - 2); + } + return false; + } + } + private SiftingOptimization(Builder builder) { this.cfg = SmithyBuilder.requiredState("cfg", builder.cfg); this.coarseMinNodes = builder.coarseMinNodes; @@ -108,382 +151,468 @@ public EndpointBddTrait apply(EndpointBddTrait trait) { private EndpointBddTrait doApply(EndpointBddTrait trait) { LOGGER.info("Starting BDD sifting optimization"); long startTime = System.currentTimeMillis(); - OptimizationState state = initializeOptimization(trait); + State state = initializeOptimization(trait); LOGGER.info(String.format("Initial size: %d nodes", state.initialSize)); - state = runOptimizationStage("Coarse", state, OptimizationEffort.COARSE, coarseMinNodes, coarseMaxPasses, 4.0); - state = runOptimizationStage("Medium", state, OptimizationEffort.MEDIUM, mediumMinNodes, mediumMaxPasses, 1.5); + // Create cost estimator once for the entire optimization run + this.costEstimator = new BddCostEstimator(state.orderView, costModel, null); + + runOptimizationStage("Coarse", state, OptimizationEffort.COARSE, coarseMinNodes, coarseMaxPasses, 4.0); + runOptimizationStage("Medium", state, OptimizationEffort.MEDIUM, mediumMinNodes, mediumMaxPasses, 1.5); if (state.currentSize <= granularMaxNodes) { - state = runOptimizationStage("Granular", state, OptimizationEffort.GRANULAR, 0, granularMaxPasses, 0.0); - } else { - LOGGER.info("Skipping granular stage - too large"); + runOptimizationStage("Granular", state, OptimizationEffort.GRANULAR, 0, granularMaxPasses, 0.0); } - state = runAdjacentSwaps(state); + runBlockMoves(state); + runAdjacentSwaps(state); double totalTimeInSeconds = (System.currentTimeMillis() - startTime) / 1000.0; - if (state.bestSize >= state.initialSize) { + if (state.currentSize >= state.initialSize) { LOGGER.info(String.format("No improvements found in %fs", totalTimeInSeconds)); return trait; } LOGGER.info(String.format("Optimization complete: %d -> %d nodes (%.1f%% total reduction) in %fs", state.initialSize, - state.bestSize, - (1.0 - (double) state.bestSize / state.initialSize) * 100, + state.currentSize, + (1.0 - (double) state.currentSize / state.initialSize) * 100, totalTimeInSeconds)); return trait.toBuilder().conditions(state.orderView).results(state.results).bdd(state.bestBdd).build(); } - private OptimizationState initializeOptimization(EndpointBddTrait trait) { - // Use the trait's existing ordering as the starting point + private State initializeOptimization(EndpointBddTrait trait) { List initialOrder = new ArrayList<>(trait.getConditions()); Condition[] order = initialOrder.toArray(new Condition[0]); List orderView = Arrays.asList(order); Bdd bdd = trait.getBdd(); int initialSize = bdd.getNodeCount() - 1; - return new OptimizationState(order, orderView, bdd, initialSize, initialSize, trait.getResults()); + return new State(order, orderView, bdd, initialSize, trait.getResults()); } - private OptimizationState runOptimizationStage( + private void runOptimizationStage( String stageName, - OptimizationState state, + State state, OptimizationEffort effort, - int targetNodeCount, + int targetNodes, int maxPasses, - double minReductionPercent + double minReduction ) { - if (targetNodeCount > 0 && state.currentSize <= targetNodeCount) { - return state; + if (targetNodes > 0 && state.currentSize <= targetNodes) { + return; } - LOGGER.info(String.format("Stage: %s optimization (%d nodes%s)", - stageName, - state.currentSize, - targetNodeCount > 0 ? String.format(", target < %d", targetNodeCount) : "")); + LOGGER.info(String.format("Stage: %s (%d nodes)", stageName, state.currentSize)); + + AdaptiveEffort ae = new AdaptiveEffort(effort); + double[] history = new double[PLATEAU_HISTORY_SIZE]; + int historyIdx = 0, consecutiveLow = 0; + + for (int pass = 1; pass <= maxPasses + ae.bonusPasses; pass++) { + if (targetNodes > 0 && state.currentSize <= targetNodes) { + break; + } - OptimizationState currentState = state; - for (int pass = 1; pass <= maxPasses; pass++) { - if (targetNodeCount > 0 && currentState.currentSize <= targetNodeCount) { + int startSize = state.currentSize; + PassContext result = runPass(state, ae); + if (result.improvements == 0) { break; } - int passStartSize = currentState.currentSize; - OptimizationResult result = runPass(currentState, effort); - if (result.improved) { - currentState = currentState.withResult(result.bdd, result.size, result.results); - double reduction = (1.0 - (double) result.size / passStartSize) * 100; - LOGGER.fine(String.format("%s pass %d: %d -> %d nodes (%.1f%% reduction)", - stageName, - pass, - passStartSize, - result.size, - reduction)); - if (minReductionPercent > 0 && reduction < minReductionPercent) { - LOGGER.fine(String.format("%s optimization yielding diminishing returns", stageName)); + state.update(result.bestBdd, result.bestSize, result.bestResults); + double reduction = (1.0 - (double) result.bestSize / startSize) * 100; + + history[historyIdx++ % PLATEAU_HISTORY_SIZE] = reduction; + if (historyIdx >= PLATEAU_HISTORY_SIZE) { + boolean plateau = true; + for (double r : history) { + if (r >= PLATEAU_THRESHOLD) { + plateau = false; + break; + } + } + if (plateau) { break; } - } else { - LOGGER.fine(String.format("%s pass %d found no improvements", stageName, pass)); + } + + consecutiveLow = ae.adapt(reduction) ? 0 : (reduction < 2.0 ? consecutiveLow + 1 : 0); + if (consecutiveLow >= 2 || (minReduction > 0 && reduction < minReduction)) { break; } } - - return currentState; } - private OptimizationState runAdjacentSwaps(OptimizationState state) { + private void runBlockMoves(State state) { if (state.currentSize > granularMaxNodes) { - return state; + return; } + LOGGER.info("Running block moves"); - LOGGER.info("Running adjacent swaps optimization"); - OptimizationState currentState = state; - - // Run multiple sweeps until no improvement - for (int sweep = 1; sweep <= 3; sweep++) { - OptimizationContext context = new OptimizationContext(currentState, dependencyGraph); - int startSize = currentState.currentSize; + List> blocks = findDependencyBlocks(state.orderView).stream() + .filter(b -> b.size() >= 2 && b.size() <= 5) + .collect(Collectors.toList()); - for (int i = 0; i < currentState.order.length - 1; i++) { - // Adjacent swap requires both elements to be able to occupy each other's positions - if (context.constraints.canMove(i, i + 1) && context.constraints.canMove(i + 1, i)) { - BddCompilerSupport.move(currentState.order, i, i + 1); - BddCompilerSupport.BddCompilationResult compilationResult = - BddCompilerSupport.compile(cfg, currentState.orderView, threadBuilder.get()); - int swappedSize = compilationResult.bdd.getNodeCount() - 1; - if (swappedSize < context.bestSize) { - context = context.withImprovement( - new PositionResult(i + 1, - swappedSize, - compilationResult.bdd, - compilationResult.results)); - } else { - BddCompilerSupport.move(currentState.order, i + 1, i); // Swap back - } - } + for (List block : blocks) { + PassContext ctx = new PassContext(state, dependencyGraph); + Result r = tryBlockMove(block, ctx); + if (r != null && r.size < ctx.bestSize) { + state.update(r.bdd, r.size, r.results); } + } + } + + private List> findDependencyBlocks(List ordering) { + List> blocks = new ArrayList<>(); + if (ordering.isEmpty()) { + return blocks; + } - if (context.improvements > 0) { - currentState = currentState.withResult(context.bestBdd, context.bestSize, context.bestResults); - LOGGER.fine(String.format("Adjacent swaps sweep %d: %d -> %d nodes", - sweep, - startSize, - context.bestSize)); + List curr = new ArrayList<>(); + curr.add(0); + for (int i = 1; i < ordering.size(); i++) { + if (dependencyGraph.getDependencies(ordering.get(i)).contains(ordering.get(i - 1))) { + curr.add(i); } else { - break; + if (curr.size() >= 2) { + blocks.add(curr); + } + curr = new ArrayList<>(); + curr.add(i); } } - return currentState; + if (curr.size() >= 2) { + blocks.add(curr); + } + + return blocks; } - private OptimizationResult runPass(OptimizationState state, OptimizationEffort effort) { - OptimizationContext context = new OptimizationContext(state, dependencyGraph); + private Result tryBlockMove(List block, PassContext ctx) { + int blockStart = block.get(0), blockEnd = block.get(block.size() - 1), blockSize = block.size(); - List selectedConditions = IntStream.range(0, state.orderView.size()) - .filter(i -> i % effort.sampleRate == 0) - .mapToObj(state.orderView::get) - .collect(Collectors.toList()); + // Compute valid range considering all block members' constraints + int minPos = 0, maxPos = ctx.order.length - blockSize; + for (int idx : block) { + int offset = idx - blockStart; + minPos = Math.max(minPos, ctx.constraints.getMinValidPosition(idx) - offset); + maxPos = Math.min(maxPos, ctx.constraints.getMaxValidPosition(idx) - offset); + } + + if (minPos >= maxPos) { + return null; + } - for (Condition condition : selectedConditions) { - Integer varIdx = context.liveIndex.get(condition); - if (varIdx == null) { + // Try a few strategic positions: min, max, mid + int[] targets = {minPos, maxPos, minPos + (maxPos - minPos) / 2}; + Result best = null; + + for (int target : targets) { + if (target == blockStart) { continue; } - List positions = getStrategicPositions(varIdx, context.constraints, effort); - if (positions.isEmpty()) { + Condition[] candidate = ctx.order.clone(); + moveBlock(candidate, blockStart, blockEnd, target); + List candidateList = Arrays.asList(candidate); + + // Validate constraints + ConditionDependencyGraph.OrderConstraints nc = dependencyGraph.createOrderConstraints(candidateList); + boolean valid = true; + for (int j = 0; j < candidate.length; j++) { + if (nc.getMinValidPosition(j) > j || nc.getMaxValidPosition(j) < j) { + valid = false; + break; + } + } + + if (!valid) { continue; } - context = tryImprovePosition(context, varIdx, positions); + BddCompilerSupport.BddCompilationResult cr = + BddCompilerSupport.compile(cfg, candidateList, threadBuilder.get()); + int size = cr.bdd.getNodeCount() - 1; + double cost = computeCost(cr.bdd, candidateList); + if (best == null || size < best.size || (size == best.size && cost < best.cost)) { + best = new Result(target, size, cost, cr.bdd, cr.results); + } + } + return best; + } + + /** + * Moves a contiguous block of elements from [start, end] to begin at targetStart. + */ + private static void moveBlock(Condition[] order, int start, int end, int targetStart) { + if (targetStart == start) { + return; } - return context.toResult(); + int blockSize = end - start + 1; + Condition[] block = new Condition[blockSize]; + System.arraycopy(order, start, block, 0, blockSize); + + if (targetStart < start) { + // Move block earlier: shift elements [targetStart, start) to the right + System.arraycopy(order, targetStart, order, targetStart + blockSize, start - targetStart); + System.arraycopy(block, 0, order, targetStart, blockSize); + } else { + // Move block later: shift elements (end, targetStart + blockSize) to the left + int shiftStart = end + 1; + int shiftEnd = targetStart + blockSize; + if (shiftEnd > order.length) { + shiftEnd = order.length; + } + System.arraycopy(order, shiftStart, order, start, shiftEnd - shiftStart); + System.arraycopy(block, 0, order, targetStart, blockSize); + } } - private OptimizationContext tryImprovePosition(OptimizationContext context, int varIdx, List positions) { - PositionResult best = findBestPosition(positions, context, varIdx); - if (best != null && best.count <= context.bestSize) { // Accept ties - BddCompilerSupport.move(context.order, varIdx, best.position); - return context.withImprovement(best); + private void runAdjacentSwaps(State state) { + if (state.currentSize > granularMaxNodes) { + return; } - return context; + for (int sweep = 0; sweep < 3; sweep++) { + PassContext ctx = new PassContext(state, dependencyGraph); + for (int i = 0; i < state.order.length - 1; i++) { + // Adjacent swap requires both elements to be able to occupy each other's positions + if (ctx.constraints.canMove(i, i + 1) && ctx.constraints.canMove(i + 1, i)) { + BddCompilerSupport.move(state.order, i, i + 1); + BddCompilerSupport.BddCompilationResult cr = BddCompilerSupport.compile( + cfg, + state.orderView, + threadBuilder.get()); + int size = cr.bdd.getNodeCount() - 1; + if (size < ctx.bestSize) { + ctx.recordImprovement(new Result(i + 1, size, cr.bdd, cr.results, null)); + } else { + BddCompilerSupport.move(state.order, i + 1, i); + } + } + } + if (ctx.improvements == 0) { + break; + } + state.update(ctx.bestBdd, ctx.bestSize, ctx.bestResults); + } + } + + private PassContext runPass(State state, AdaptiveEffort effort) { + PassContext ctx = new PassContext(state, dependencyGraph); + int[] nodeCounts = computeNodeCountsPerVariable(state.bestBdd); + int[] selectedIndices = selectConditionsByPriority(state.orderView.size(), nodeCounts, effort.sampleRate); + + for (int varIdx : selectedIndices) { + List positions = getStrategicPositions(varIdx, ctx.constraints, effort, state.orderView.size()); + if (positions.isEmpty()) { + continue; + } + Result best = findBestPosition(positions, ctx, varIdx); + if (best != null && best.size <= ctx.bestSize) { + BddCompilerSupport.move(ctx.order, varIdx, best.position); + ctx.recordImprovement(best); + } + } + return ctx; + } + + /** + * Computes the number of BDD nodes testing each variable. + */ + private static int[] computeNodeCountsPerVariable(Bdd bdd) { + int[] counts = new int[bdd.getConditionCount()]; + for (int i = 0; i < bdd.getNodeCount(); i++) { + int v = bdd.getVariable(i); + if (v >= 0 && v < counts.length) { + counts[v]++; + } + } + return counts; + } + + private static int[] selectConditionsByPriority(int n, int[] nodeCounts, int sampleRate) { + int[] indices = IntStream.range(0, n) + .boxed() + .sorted((a, b) -> Integer.compare(nodeCounts[b], nodeCounts[a])) + .mapToInt(i -> i) + .toArray(); + return sampleRate <= 1 ? indices : Arrays.copyOf(indices, Math.max(1, n / sampleRate)); } - private PositionResult findBestPosition(List positions, OptimizationContext ctx, int varIdx) { - return (positions.size() > PARALLEL_THRESHOLD ? positions.parallelStream() : positions.stream()) + /** Two-pass position finder: compile candidates, then cost-break ties among min-size. */ + private Result findBestPosition(List positions, PassContext ctx, int varIdx) { + // First pass: compile all candidates + List candidates = (positions.size() > PARALLEL_THRESHOLD + ? positions.parallelStream() + : positions.stream()) .map(pos -> { Condition[] order = ctx.order.clone(); BddCompilerSupport.move(order, varIdx, pos); + List orderList = Arrays.asList(order); BddCompilerSupport.BddCompilationResult cr = - BddCompilerSupport.compile(cfg, Arrays.asList(order), threadBuilder.get()); - return new PositionResult(pos, cr.bdd.getNodeCount() - 1, cr.bdd, cr.results); + BddCompilerSupport.compile(cfg, orderList, threadBuilder.get()); + return new Result(pos, cr.bdd.getNodeCount() - 1, cr.bdd, cr.results, orderList); }) - .filter(pr -> pr.count <= ctx.bestSize) - .min(Comparator.comparingInt((PositionResult pr) -> pr.count).thenComparingInt(pr -> pr.position)) - .orElse(null); + .filter(c -> c.size <= ctx.bestSize) + .collect(Collectors.toList()); + + if (candidates.isEmpty()) { + return null; + } + + // Second pass: among min-size candidates, pick lowest cost + int minSize = candidates.stream().mapToInt(c -> c.size).min().orElse(Integer.MAX_VALUE); + Result best = null; + for (Result c : candidates) { + if (c.size == minSize) { + double cost = computeCost(c.bdd, c.orderList); + if (best == null || cost < best.cost || (cost == best.cost && c.position < best.position)) { + best = new Result(c.position, c.size, cost, c.bdd, c.results); + } + } + } + return best; + } + + private double computeCost(Bdd bdd, List ordering) { + return costEstimator.expectedCost(bdd, ordering); } private static List getStrategicPositions( int varIdx, - ConditionDependencyGraph.OrderConstraints constraints, - OptimizationEffort effort + ConditionDependencyGraph.OrderConstraints c, + AdaptiveEffort ae, + int orderSize ) { - int min = constraints.getMinValidPosition(varIdx); - int max = constraints.getMaxValidPosition(varIdx); + int min = c.getMinValidPosition(varIdx); + int max = c.getMaxValidPosition(varIdx); int range = max - min; - if (range <= effort.exhaustiveThreshold) { - List positions = new ArrayList<>(range); + // Exhaustive for small ranges + if (range <= ae.base.exhaustiveThreshold) { + List pos = new ArrayList<>(range); for (int p = min; p < max; p++) { - if (p != varIdx && constraints.canMove(varIdx, p)) { - positions.add(p); + if (p != varIdx && c.canMove(varIdx, p)) { + pos.add(p); } } - return positions; + return pos; } - List positions = new ArrayList<>(effort.maxPositions); + List pos = new ArrayList<>(ae.maxPositions); + boolean[] seen = new boolean[orderSize]; - // Test extremes first since they often yield the best improvements - if (min != varIdx && constraints.canMove(varIdx, min)) { - positions.add(min); - } - if (positions.size() >= effort.maxPositions) { - return positions; + // Extremes + if (min != varIdx && c.canMove(varIdx, min)) { + pos.add(min); + seen[min] = true; } - if (max - 1 != varIdx && constraints.canMove(varIdx, max - 1)) { - positions.add(max - 1); - } - if (positions.size() >= effort.maxPositions) { - return positions; + if (max - 1 != varIdx && c.canMove(varIdx, max - 1)) { + pos.add(max - 1); + seen[max - 1] = true; } - // Test local moves that preserve relative ordering with neighbors - for (int offset = -effort.nearbyRadius; offset <= effort.nearbyRadius; offset++) { - if (offset != 0) { - if (positions.size() >= effort.maxPositions) { - return positions; - } - int p = varIdx + offset; - if (p >= min && p < max && !positions.contains(p) && constraints.canMove(varIdx, p)) { - positions.add(p); - } + // Global sampling + int step = Math.max(1, range / Math.min(15, ae.maxPositions / 2)); + for (int p = min + step; p < max - step && pos.size() < ae.maxPositions; p += step) { + if (p != varIdx && !seen[p] && c.canMove(varIdx, p)) { + pos.add(p); + seen[p] = true; } } - // Sample intermediate positions to find global improvements - if (positions.size() >= effort.maxPositions) { - return positions; - } - - int maxSamples = Math.min(15, effort.maxPositions / 2); - int samples = Math.min(maxSamples, Math.max(2, range / 4)); - int step = Math.max(1, range / samples); - - for (int p = min + step; p < max - step && positions.size() < effort.maxPositions; p += step) { - if (p != varIdx && !positions.contains(p) && constraints.canMove(varIdx, p)) { - positions.add(p); + // Local neighborhood + for (int off = -ae.nearbyRadius; off <= ae.nearbyRadius && pos.size() < ae.maxPositions; off++) { + int p = varIdx + off; + if (off != 0 && p >= min && p < max && !seen[p] && c.canMove(varIdx, p)) { + pos.add(p); + seen[p] = true; } } - return positions; - } - - private static Map rebuildIndex(List orderView) { - Map index = new IdentityHashMap<>(); - for (int i = 0; i < orderView.size(); i++) { - index.put(orderView.get(i), i); - } - return index; + return pos; } - // Helper class to track optimization context within a pass - private static final class OptimizationContext { + /** Mutable context for tracking optimization progress within a pass. */ + private static final class PassContext { final Condition[] order; final List orderView; final ConditionDependencyGraph dependencyGraph; - final ConditionDependencyGraph.OrderConstraints constraints; - final Map liveIndex; - final Bdd bestBdd; - final int bestSize; - final List bestResults; - final int improvements; - - OptimizationContext(OptimizationState state, ConditionDependencyGraph dependencyGraph) { + ConditionDependencyGraph.OrderConstraints constraints; + Bdd bestBdd; + int bestSize; + List bestResults; + int improvements; + + PassContext(State state, ConditionDependencyGraph dependencyGraph) { this.order = state.order; this.orderView = state.orderView; - this.dependencyGraph = dependencyGraph; - this.constraints = dependencyGraph.createOrderConstraints(orderView); - this.liveIndex = rebuildIndex(orderView); - this.bestBdd = null; this.bestSize = state.currentSize; - this.bestResults = null; - this.improvements = 0; - } - - private OptimizationContext( - Condition[] order, - List orderView, - ConditionDependencyGraph dependencyGraph, - ConditionDependencyGraph.OrderConstraints constraints, - Map liveIndex, - Bdd bestBdd, - int bestSize, - List bestResults, - int improvements - ) { - this.order = order; - this.orderView = orderView; this.dependencyGraph = dependencyGraph; - this.constraints = constraints; - this.liveIndex = liveIndex; - this.bestBdd = bestBdd; - this.bestSize = bestSize; - this.bestResults = bestResults; - this.improvements = improvements; - } - - OptimizationContext withImprovement(PositionResult result) { - ConditionDependencyGraph.OrderConstraints newConstraints = - dependencyGraph.createOrderConstraints(orderView); - Map newIndex = rebuildIndex(orderView); - return new OptimizationContext(order, - orderView, - dependencyGraph, - newConstraints, - newIndex, - result.bdd, - result.count, - result.results, - improvements + 1); - } - - OptimizationResult toResult() { - return new OptimizationResult(bestBdd, bestSize, improvements > 0, bestResults); + this.constraints = dependencyGraph.createOrderConstraints(orderView); + } + + void recordImprovement(Result result) { + this.bestBdd = result.bdd; + this.bestSize = result.size; + this.bestResults = result.results; + this.constraints = dependencyGraph.createOrderConstraints(orderView); + this.improvements++; } } - private static final class PositionResult { + /** Result holder for BDD compilation with optional position/cost metadata. */ + private static final class Result { final int position; - final int count; + final int size; + final double cost; final Bdd bdd; final List results; + final List orderList; // For deferred cost computation - PositionResult(int position, int count, Bdd bdd, List results) { - this.position = position; - this.count = count; - this.bdd = bdd; - this.results = results; + Result(int position, int size, Bdd bdd, List results, List orderList) { + this(position, size, Double.MAX_VALUE, bdd, results, orderList); } - } - private static final class OptimizationResult { - final Bdd bdd; - final int size; - final boolean improved; - final List results; + Result(int position, int size, double cost, Bdd bdd, List results) { + this(position, size, cost, bdd, results, null); + } - OptimizationResult(Bdd bdd, int size, boolean improved, List results) { - this.bdd = bdd; + Result(int position, int size, double cost, Bdd bdd, List results, List orderList) { + this.position = position; this.size = size; - this.improved = improved; + this.cost = cost; + this.bdd = bdd; this.results = results; + this.orderList = orderList; } } - private static final class OptimizationState { + /** Tracks overall optimization state across stages. */ + private static final class State { final Condition[] order; final List orderView; - final Bdd bestBdd; - final int currentSize; - final int bestSize; final int initialSize; - final List results; + Bdd bestBdd; + int currentSize; + List results; - OptimizationState( - Condition[] order, - List orderView, - Bdd bestBdd, - int currentSize, - int initialSize, - List results - ) { + State(Condition[] order, List orderView, Bdd bdd, int size, List results) { this.order = order; this.orderView = orderView; - this.bestBdd = bestBdd; - this.currentSize = currentSize; - this.bestSize = currentSize; - this.initialSize = initialSize; + this.bestBdd = bdd; + this.currentSize = size; + this.initialSize = size; this.results = results; } - OptimizationState withResult(Bdd newBdd, int newSize, List newResults) { - return new OptimizationState(order, orderView, newBdd, newSize, initialSize, newResults); + void update(Bdd bdd, int size, List results) { + this.bestBdd = bdd; + this.currentSize = size; + this.results = results; } } From 0aaea908f13798436aba568af078c7639362186a Mon Sep 17 00:00:00 2001 From: Michael Dowling Date: Wed, 17 Dec 2025 11:27:08 -0600 Subject: [PATCH 02/10] Implement rules engine ITE fn and S3 tree transform This commit adds a new function to the rules engine, ite, that performs an if-then-else check on a boolean expression without branching. By not needing to branch in the decision tree, we avoid SSA transforms on divergent branches which would create syntactically different but semantically identical expressions that the BDD cannot deduplicate. This commit also adds an S3-specific decision tree transform that canonicalizes S3Express rules for better BDD compilation: 1. AZ extraction: Rewrites position-dependent substring operations to use a single split(Bucket, "--")[1] expression across all branches 2. URL canonicalization: Uses ITE to compute FIPS/DualStack URL segments, collapsing 4 URL variants into a single template with {_s3e_fips} and {_s3e_ds} placeholders 3. Auth scheme canonicalization: Uses ITE to select sigv4 vs sigv4-s3express based on DisableS3ExpressSessionAuth The transform makes the rules tree ~30% larger but enables dramatic BDD compression by making URL templates identical across FIPS/DualStack/auth variants. Endpoints that previously appeared distinct now collapse into single BDD results, reducing nodes and results by ~43%. --- ...6b965bf2f51d85aa0171be6206ae2029137c4.json | 7 + .../rules-engine/standard-library.rst | 97 +++ smithy-aws-endpoints/build.gradle.kts | 54 ++ .../functions/S3TreeRewriterTest.java | 52 ++ .../aws/AwsConditionProbability.java | 18 +- .../language/functions/S3TreeRewriter.java | 633 ++++++++++++++++++ .../rulesengine/language/CoreExtension.java | 2 + .../language/evaluation/RuleEvaluator.java | 6 + .../syntax/expressions/ExpressionVisitor.java | 19 + .../syntax/expressions/functions/Ite.java | 174 +++++ .../logic/bdd/CostOptimization.java | 4 +- .../logic/bdd/SiftingOptimization.java | 2 +- .../cfg/VariableConsolidationTransform.java | 47 +- .../rulesengine/traits/EndpointBddTrait.java | 10 + .../RuleSetAuthSchemesValidator.java | 15 +- .../language/syntax/functions/IteTest.java | 234 +++++++ .../errorfiles/valid/ite-basic.errors | 1 + .../errorfiles/valid/ite-basic.smithy | 80 +++ 18 files changed, 1436 insertions(+), 19 deletions(-) create mode 100644 .changes/next-release/feature-1c36b965bf2f51d85aa0171be6206ae2029137c4.json create mode 100644 smithy-aws-endpoints/src/it/java/software/amazon/smithy/rulesengine/aws/language/functions/S3TreeRewriterTest.java create mode 100644 smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/S3TreeRewriter.java create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Ite.java create mode 100644 smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/IteTest.java create mode 100644 smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/ite-basic.errors create mode 100644 smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/ite-basic.smithy diff --git a/.changes/next-release/feature-1c36b965bf2f51d85aa0171be6206ae2029137c4.json b/.changes/next-release/feature-1c36b965bf2f51d85aa0171be6206ae2029137c4.json new file mode 100644 index 00000000000..3153bd1730d --- /dev/null +++ b/.changes/next-release/feature-1c36b965bf2f51d85aa0171be6206ae2029137c4.json @@ -0,0 +1,7 @@ +{ + "type": "feature", + "description": "Implement rules engine ITE fn and S3 tree transform", + "pull_requests": [ + "[#2903](https://github.com/smithy-lang/smithy/pull/2903)" + ] +} diff --git a/docs/source-2.0/additional-specs/rules-engine/standard-library.rst b/docs/source-2.0/additional-specs/rules-engine/standard-library.rst index 6950b914468..8cb5f7d51d8 100644 --- a/docs/source-2.0/additional-specs/rules-engine/standard-library.rst +++ b/docs/source-2.0/additional-specs/rules-engine/standard-library.rst @@ -208,6 +208,103 @@ The following example uses ``isValidHostLabel`` to check if the value of the } +.. _rules-engine-standard-library-ite: + +``ite`` function +================ + +Summary + An if-then-else function that returns one of two values based on a boolean condition. +Argument types + * condition: ``bool`` + * trueValue: ``T`` or ``option`` + * falseValue: ``T`` or ``option`` +Return type + * ``ite(bool, T, T)`` → ``T`` (both non-optional, result is non-optional) + * ``ite(bool, T, option)`` → ``option`` (any optional makes result optional) + * ``ite(bool, option, T)`` → ``option`` (any optional makes result optional) + * ``ite(bool, option, option)`` → ``option`` (both optional, result is optional) +Since + 1.1 + +The ``ite`` (if-then-else) function evaluates a boolean condition and returns one of two values based on +the result. If the condition is ``true``, it returns ``trueValue``; if ``false``, it returns ``falseValue``. +This function is particularly useful for computing conditional values without branching in the rule tree, resulting +in fewer result nodes, and enabling better BDD optimizations as a result of reduced fragmentation. + +.. important:: + Both ``trueValue`` and ``falseValue`` must have the same base type ``T``. The result type follows + the "least upper bound" rule: if either branch is optional, the result is optional. + +The following example uses ``ite`` to compute a URL suffix based on whether FIPS is enabled: + +.. code-block:: json + + { + "fn": "ite", + "argv": [ + {"ref": "UseFIPS"}, + "-fips", + "" + ], + "assign": "fipsSuffix" + } + +The following example uses ``ite`` with ``coalesce`` to handle an optional boolean parameter: + +.. code-block:: json + + { + "fn": "ite", + "argv": [ + { + "fn": "coalesce", + "argv": [ + {"ref": "DisableFeature"}, + false + ] + }, + "disabled", + "enabled" + ], + "assign": "featureState" + } + + +.. _rules-engine-standard-library-ite-examples: + +-------- +Examples +-------- + +The following table shows various inputs and their corresponding outputs for the ``ite`` function: + +.. list-table:: + :header-rows: 1 + :widths: 20 25 25 30 + + * - Condition + - True Value + - False Value + - Output + * - ``true`` + - ``"-fips"`` + - ``""`` + - ``"-fips"`` + * - ``false`` + - ``"-fips"`` + - ``""`` + - ``""`` + * - ``true`` + - ``"sigv4"`` + - ``"sigv4-s3express"`` + - ``"sigv4"`` + * - ``false`` + - ``"sigv4"`` + - ``"sigv4-s3express"`` + - ``"sigv4-s3express"`` + + .. _rules-engine-standard-library-not: ``not`` function diff --git a/smithy-aws-endpoints/build.gradle.kts b/smithy-aws-endpoints/build.gradle.kts index c142cd4ee32..35c213e1e70 100644 --- a/smithy-aws-endpoints/build.gradle.kts +++ b/smithy-aws-endpoints/build.gradle.kts @@ -11,10 +11,64 @@ description = "AWS specific components for managing endpoints in Smithy" extra["displayName"] = "Smithy :: AWS Endpoints Components" extra["moduleName"] = "software.amazon.smithy.aws.endpoints" +// Custom configuration for S3 model - kept separate from test classpath to avoid +// polluting other tests with S3 model discovery +val s3Model: Configuration by configurations.creating + dependencies { api(project(":smithy-aws-traits")) api(project(":smithy-diff")) api(project(":smithy-rules-engine")) api(project(":smithy-model")) api(project(":smithy-utils")) + + s3Model("software.amazon.api.models:s3:1.0.11") +} + +// Integration test source set for tests that require the S3 model +// These tests require JDK 17+ due to the S3 model dependency +sourceSets { + create("it") { + compileClasspath += sourceSets["main"].output + sourceSets["test"].output + runtimeClasspath += sourceSets["main"].output + sourceSets["test"].output + } +} + +configurations["itImplementation"].extendsFrom(configurations["testImplementation"]) +configurations["itRuntimeOnly"].extendsFrom(configurations["testRuntimeOnly"]) +configurations["itImplementation"].extendsFrom(s3Model) + +// Configure IT source set to compile with JDK 17 +tasks.named("compileItJava") { + javaCompiler.set( + javaToolchains.compilerFor { + languageVersion.set(JavaLanguageVersion.of(17)) + }, + ) + sourceCompatibility = "17" + targetCompatibility = "17" +} + +val integrationTest by tasks.registering(Test::class) { + description = "Runs integration tests that require external models like S3" + group = "verification" + testClassesDirs = sourceSets["it"].output.classesDirs + classpath = sourceSets["it"].runtimeClasspath + dependsOn(tasks.jar) + shouldRunAfter(tasks.test) + + // Run with JDK 17 + javaLauncher.set( + javaToolchains.launcherFor { + languageVersion.set(JavaLanguageVersion.of(17)) + }, + ) +} + +tasks.test { + finalizedBy(integrationTest) +} + +tasks.named("check") { + dependsOn(integrationTest) } diff --git a/smithy-aws-endpoints/src/it/java/software/amazon/smithy/rulesengine/aws/language/functions/S3TreeRewriterTest.java b/smithy-aws-endpoints/src/it/java/software/amazon/smithy/rulesengine/aws/language/functions/S3TreeRewriterTest.java new file mode 100644 index 00000000000..dd5e88140a7 --- /dev/null +++ b/smithy-aws-endpoints/src/it/java/software/amazon/smithy/rulesengine/aws/language/functions/S3TreeRewriterTest.java @@ -0,0 +1,52 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.aws.language.functions; + +import static org.junit.jupiter.api.Assertions.assertFalse; + +import java.util.List; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import software.amazon.smithy.model.Model; +import software.amazon.smithy.model.shapes.ServiceShape; +import software.amazon.smithy.model.shapes.ShapeId; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.evaluation.TestEvaluator; +import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait; +import software.amazon.smithy.rulesengine.traits.EndpointTestCase; +import software.amazon.smithy.rulesengine.traits.EndpointTestsTrait; + +/** + * Runs the endpoint test cases against the transformed S3 model. We're fixed to a specific version for this test, + * but could periodically bump the version if needed. + */ +class S3TreeRewriterTest { + private static final ShapeId S3_SERVICE_ID = ShapeId.from("com.amazonaws.s3#AmazonS3"); + + private static EndpointRuleSet originalRules; + private static List testCases; + + @BeforeAll + static void loadS3Model() { + Model model = Model.assembler() + .discoverModels() + .assemble() + .unwrap(); + + ServiceShape s3Service = model.expectShape(S3_SERVICE_ID, ServiceShape.class); + originalRules = s3Service.expectTrait(EndpointRuleSetTrait.class).getEndpointRuleSet(); + testCases = s3Service.expectTrait(EndpointTestsTrait.class).getTestCases(); + } + + @Test + void transformPreservesEndpointTestSemantics() { + assertFalse(testCases.isEmpty(), "S3 model should have endpoint test cases"); + + EndpointRuleSet transformed = S3TreeRewriter.transform(originalRules); + for (EndpointTestCase testCase : testCases) { + TestEvaluator.evaluate(transformed, testCase); + } + } +} diff --git a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/AwsConditionProbability.java b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/AwsConditionProbability.java index 42b2344e8ef..02dbe32fa4e 100644 --- a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/AwsConditionProbability.java +++ b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/AwsConditionProbability.java @@ -26,12 +26,21 @@ public double applyAsDouble(Condition condition) { // Region is almost always provided if (s.contains("isSet(Region)")) { - return 0.95; + return 0.96; } // Endpoint override is rare if (s.contains("isSet(Endpoint)")) { - return 0.1; + return 0.2; + } + + // S3 Express is rare (includes ITE variables from S3TreeRewriter) + if (s.contains("S3Express") || s.contains("--x-s3") + || s.contains("--xa-s3") + || s.contains("s3e_fips") + || s.contains("s3e_ds") + || s.contains("s3e_auth")) { + return 0.001; } // Most isSet checks on optional params succeed moderately @@ -48,11 +57,6 @@ public double applyAsDouble(Condition condition) { return 0.05; } - // S3 Express is relatively rare - if (s.contains("S3Express") || s.contains("--x-s3") || s.contains("--xa-s3")) { - return 0.1; - } - // ARN-based buckets are uncommon if (s.contains("parseArn") || s.contains("arn:")) { return 0.15; diff --git a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/S3TreeRewriter.java b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/S3TreeRewriter.java new file mode 100644 index 00000000000..f748a289f56 --- /dev/null +++ b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/S3TreeRewriter.java @@ -0,0 +1,633 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.aws.language.functions; + +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.logging.Logger; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import software.amazon.smithy.model.node.StringNode; +import software.amazon.smithy.rulesengine.language.Endpoint; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Reference; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Template; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Coalesce; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.GetAttr; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Ite; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Split; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Substring; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; +import software.amazon.smithy.utils.SmithyInternalApi; + +/** + * Rewrites S3 endpoint rules to use canonical, position-independent expressions. + * + *

This is a BDD pre-processing transform that makes the rules tree larger but enables dramatically better + * BDD compilation. It solves the "SSA Trap" problem where semantically identical operations appear as syntactically + * different expressions, preventing the BDD compiler from recognizing sharing opportunities. + * + *

Internal use only

+ *

Ideally this transform is deleted one day, and the rules that source it adopt these techniques (hopefully we + * don't look back on this comment and laugh in 5 years). If/when that happens, this class will be deleted, whether + * it breaks a consumer that uses it or not. + * + *

Trade-off: Larger Rules, Smaller BDD

+ *

This transform would be counterproductive for rule tree interpretation, but is highly beneficial when a + * BDD compiler processes the output. It adds ITE (if-then-else) conditions to compute URL segments and auth scheme + * names, increasing rule tree size by ~30%. However, this enables the BDD compiler to deduplicate endpoints that + * were previously considered distinct, as of writing, reducing BDD results and node counts both by ~43%. + * + *

The key insight is that the BDD deduplicates by endpoint identity (URL template + properties). By making + * URL templates identical through variable substitution, endpoints that differed only in FIPS/DualStack/auth variants + * collapse into a single BDD result. + * + *

Transformations performed:

+ * + *

AZ Extraction Canonicalization

+ * + *

The original rules extract the availability zone ID using position-dependent substring operations. + * Different bucket name lengths result in different extraction positions, creating 10+ SSA variants that can't + * be shared in the BDD. + * + *

Before: Position-dependent substring extraction + *

{@code
+ * {
+ *   "conditions": [
+ *     {
+ *       "fn": "substring",
+ *       "argv": [{"ref": "Bucket"}, 6, 14, true],
+ *       "assign": "s3expressAvailabilityZoneId"
+ *     }
+ *   ],
+ *   "rules": [...]
+ * }
+ * // Another branch with different positions:
+ * {
+ *   "conditions": [
+ *     {
+ *       "fn": "substring",
+ *       "argv": [{"ref": "Bucket"}, 6, 20, true],
+ *       "assign": "s3expressAvailabilityZoneId"
+ *     }
+ *   ],
+ *   "rules": [...]
+ * }
+ * }
+ * + *

After: Position-independent split-based extraction + *

{@code
+ * {
+ *   "conditions": [
+ *     {
+ *       "fn": "getAttr",
+ *       "argv": [
+ *         {"fn": "split", "argv": [{"ref": "Bucket"}, "--", 0]},
+ *         "[1]"
+ *       ],
+ *       "assign": "s3expressAvailabilityZoneId"
+ *     }
+ *   ],
+ *   "rules": [...]
+ * }
+ * }
+ * + *

All branches now use the identical expression {@code split(Bucket, "--")[1]}, enabling + * the BDD compiler to share nodes across all S3Express bucket handling paths. Because the expression only interacts + * with Bucket, a constant value, there's no SSA transform performed on these expressions. + * + *

URL Canonicalization

+ * + *

S3Express endpoints (currently) have 4 URL variants based on UseFIPS and UseDualStack flags. This creates + * duplicate endpoints that differ only in URL structure. + * + *

Before: Separate endpoints for each FIPS/DualStack combination + *

{@code
+ * // Branch 1: FIPS + DualStack
+ * {
+ *   "conditions": [
+ *     {"fn": "booleanEquals", "argv": [{"ref": "UseFIPS"}, true]},
+ *     {"fn": "booleanEquals", "argv": [{"ref": "UseDualStack"}, true]}
+ *   ],
+ *   "endpoint": {
+ *     "url": "https://{Bucket}.s3express-fips-{s3expressAvailabilityZoneId}.dualstack.{Region}.amazonaws.com"
+ *   }
+ * }
+ * // Branch 2: FIPS only
+ * {
+ *   "conditions": [
+ *     {"fn": "booleanEquals", "argv": [{"ref": "UseFIPS"}, true]}
+ *   ],
+ *   "endpoint": {
+ *     "url": "https://{Bucket}.s3express-fips-{s3expressAvailabilityZoneId}.{Region}.amazonaws.com"
+ *   }
+ * }
+ * // Branch 3: DualStack only
+ * // Branch 4: Neither
+ * }
+ * + *

After: Single endpoint with ITE-computed URL segments + *

{@code
+ * {
+ *   "conditions": [
+ *     {"fn": "ite", "argv": [{"ref": "UseFIPS"}, "-fips", ""], "assign": "_s3e_fips"},
+ *     {"fn": "ite", "argv": [{"ref": "UseDualStack"}, ".dualstack", ""], "assign": "_s3e_ds"}
+ *   ],
+ *   "endpoint": {
+ *     "url": "https://{Bucket}.s3express{_s3e_fips}-{s3expressAvailabilityZoneId}{_s3e_ds}.{Region}.amazonaws.com"
+ *   }
+ * }
+ * }
+ * + *

The ITE conditions compute values branchlessly. The BDD sifting optimization naturally places these rare + * S3Express-specific conditions late in the decision tree. + * + *

Auth Scheme Canonicalization

+ * + *

S3Express endpoints use different auth schemes based on DisableS3ExpressSessionAuth. + * This creates duplicate endpoints differing only in auth scheme name. + * + *

Before: Separate auth scheme names + *

{@code
+ * // When DisableS3ExpressSessionAuth is true:
+ * "authSchemes": [{"name": "sigv4", "signingName": "s3express", ...}]
+ *
+ * // When DisableS3ExpressSessionAuth is false/unset:
+ * "authSchemes": [{"name": "sigv4-s3express", "signingName": "s3express", ...}]
+ * }
+ * + *

After: ITE-computed auth scheme name + *

{@code
+ * {
+ *   "conditions": [
+ *     {
+ *       "fn": "ite",
+ *       "argv": [
+ *         {"fn": "coalesce", "argv": [{"ref": "DisableS3ExpressSessionAuth"}, false]},
+ *         "sigv4",
+ *         "sigv4-s3express"
+ *       ],
+ *       "assign": "_s3e_auth"
+ *     }
+ *   ],
+ *   "endpoint": {
+ *     "properties": {
+ *       "authSchemes": [{"name": "{_s3e_auth}", "signingName": "s3express", ...}]
+ *     }
+ *   }
+ * }
+ * }
+ */ +@SmithyInternalApi +public final class S3TreeRewriter { + private static final Logger LOGGER = Logger.getLogger(S3TreeRewriter.class.getName()); + + // Variable names for the computed suffixes + private static final String VAR_FIPS = "_s3e_fips"; + private static final String VAR_DS = "_s3e_ds"; + private static final String VAR_AUTH = "_s3e_auth"; + + // Suffix values used in the URI templates + private static final String FIPS_SUFFIX = "-fips"; + private static final String DS_SUFFIX = ".dualstack"; + private static final String EMPTY_SUFFIX = ""; + + // Auth scheme values used with s3-express + private static final String AUTH_SIGV4 = "sigv4"; + private static final String AUTH_SIGV4_S3EXPRESS = "sigv4-s3express"; + + // Property and parameter identifiers + private static final Identifier ID_AUTH_SCHEMES = Identifier.of("authSchemes"); + private static final Identifier ID_NAME = Identifier.of("name"); + private static final Identifier ID_BACKEND = Identifier.of("backend"); + private static final Identifier ID_BUCKET = Identifier.of("Bucket"); + private static final Identifier ID_AZ_ID = Identifier.of("s3expressAvailabilityZoneId"); + private static final Identifier ID_USE_FIPS = Identifier.of("UseFIPS"); + private static final Identifier ID_USE_DUAL_STACK = Identifier.of("UseDualStack"); + private static final Identifier ID_DISABLE_S3EXPRESS_SESSION_AUTH = Identifier.of("DisableS3ExpressSessionAuth"); + + // Auth scheme name literal shared across all rewritten endpoints + private static final Literal AUTH_NAME_LITERAL = Literal.stringLiteral(Template.fromString("{" + VAR_AUTH + "}")); + + // Patterns to match S3Express bucket endpoint URLs (with AZ) + // Format: https://{Bucket}.s3express[-fips]-{AZ}[.dualstack].{Region}.amazonaws.com + // (negative lookahead (?!dualstack) prevents matching dualstack variants in non-DS patterns) + private static final Pattern S3EXPRESS_FIPS_DS = Pattern.compile("(s3express)-fips-([^.]+)\\.dualstack\\.(.+)$"); + private static final Pattern S3EXPRESS_FIPS = Pattern.compile("(s3express)-fips-([^.]+)\\.(?!dualstack)(.+)$"); + private static final Pattern S3EXPRESS_DS = Pattern.compile("(s3express)-([^.]+)\\.dualstack\\.(.+)$"); + private static final Pattern S3EXPRESS_PLAIN = Pattern.compile("(s3express)-([^.]+)\\.(?!dualstack)(.+)$"); + + // Patterns to match S3Express control plane URLs (no AZ) + // Format: https://s3express-control[-fips][.dualstack].{Region}.amazonaws.com + private static final Pattern S3EXPRESS_CONTROL_FIPS_DS = Pattern.compile( + "(s3express-control)-fips\\.dualstack\\.(.+)$"); + private static final Pattern S3EXPRESS_CONTROL_FIPS = Pattern.compile( + "(s3express-control)-fips\\.(?!dualstack)(.+)$"); + private static final Pattern S3EXPRESS_CONTROL_DS = Pattern.compile( + "(s3express-control)\\.dualstack\\.(.+)$"); + private static final Pattern S3EXPRESS_CONTROL_PLAIN = Pattern.compile( + "(s3express-control)\\.(?!dualstack)(.+)$"); + + // Cached canonical expression for AZ extraction: split(Bucket, "--", 0) + private static final Split BUCKET_SPLIT = Split.ofExpressions( + Expression.getReference(ID_BUCKET), + Expression.of("--"), + Expression.of(0)); + + private int rewrittenCount = 0; + private int totalS3ExpressCount = 0; + + private S3TreeRewriter() {} + + /** + * Transforms the given endpoint rule set using canonical expressions. + * + * @param ruleSet the rule set to transform + * @return the transformed rule set + */ + public static EndpointRuleSet transform(EndpointRuleSet ruleSet) { + return new S3TreeRewriter().run(ruleSet); + } + + private EndpointRuleSet run(EndpointRuleSet ruleSet) { + List transformedRules = new ArrayList<>(); + for (Rule rule : ruleSet.getRules()) { + transformedRules.add(transformRule(rule)); + } + + LOGGER.info(() -> String.format( + "S3 tree rewriter: %s/%s S3Express endpoints rewritten", + rewrittenCount, + totalS3ExpressCount)); + + return EndpointRuleSet.builder() + .sourceLocation(ruleSet.getSourceLocation()) + .parameters(ruleSet.getParameters()) + .rules(transformedRules) + .version(ruleSet.getVersion()) + .build(); + } + + private Rule transformRule(Rule rule) { + if (rule instanceof TreeRule) { + TreeRule tr = (TreeRule) rule; + // Transform conditions + List transformedConditions = transformConditions(tr.getConditions()); + List transformedChildren = new ArrayList<>(); + for (Rule child : tr.getRules()) { + transformedChildren.add(transformRule(child)); + } + return Rule.builder().conditions(transformedConditions).treeRule(transformedChildren); + } else if (rule instanceof EndpointRule) { + return rewriteEndpoint((EndpointRule) rule); + } else { + // Error rules pass through unchanged + return rule; + } + } + + private List transformConditions(List conditions) { + List result = new ArrayList<>(conditions.size()); + for (Condition cond : conditions) { + result.add(transformCondition(cond)); + } + return result; + } + + /** + * Transforms a single condition. + * + *

Handles: + *

+     * AZ extraction: substring(Bucket, N, M) -> split(Bucket, "--")[1]
+     * 
+ * + *

Note: Delimiter checks (s3expressAvailabilityZoneDelim) are not currently transformed because they're part + * of a complex fallback structure, and changing them breaks control flow. Possibly something we can improve, or + * wait until the upstream rules are optimized. + */ + private Condition transformCondition(Condition cond) { + // Is this a condition fishing for delimiters? + if (cond.getResult().isPresent() + && ID_AZ_ID.equals(cond.getResult().get()) + && cond.getFunction() instanceof Substring + && isSubstringOnBucket((Substring) cond.getFunction())) { + // Replace with split-based extraction: split(Bucket, "--")[1] + GetAttr azExpr = GetAttr.ofExpressions(BUCKET_SPLIT, "[1]"); + return cond.toBuilder().fn(azExpr).build(); + } + + return cond; + } + + private boolean isSubstringOnBucket(Substring substring) { + List args = substring.getArguments(); + if (args.isEmpty()) { + return false; + } + + Expression target = args.get(0); + return target instanceof Reference && ID_BUCKET.equals(((Reference) target).getName()); + } + + // Creates ITE conditions for branchless S3Express variable computation. + private List createIteConditions() { + List conditions = new ArrayList<>(); + conditions.add(createIteAssignment(VAR_FIPS, Expression.getReference(ID_USE_FIPS), FIPS_SUFFIX, EMPTY_SUFFIX)); + conditions.add(createIteAssignment( + VAR_DS, + Expression.getReference(ID_USE_DUAL_STACK), + DS_SUFFIX, + EMPTY_SUFFIX)); + // Auth scheme: sigv4 when session auth disabled, sigv4-s3express otherwise + Expression sessionAuthDisabled = Coalesce.ofExpressions( + Expression.getReference(ID_DISABLE_S3EXPRESS_SESSION_AUTH), + Expression.of(false)); + conditions.add(createIteAssignment(VAR_AUTH, sessionAuthDisabled, AUTH_SIGV4, AUTH_SIGV4_S3EXPRESS)); + return conditions; + } + + // Creates an ITE-based assignment condition. + private Condition createIteAssignment(String varName, Expression condition, String trueValue, String falseValue) { + return Condition.builder() + .fn(Ite.ofStrings(condition, trueValue, falseValue)) + .result(varName) + .build(); + } + + // Rewrites an endpoint rule to use canonical S3Express URLs and auth schemes. + private Rule rewriteEndpoint(EndpointRule rule) { + Endpoint endpoint = rule.getEndpoint(); + Expression urlExpr = endpoint.getUrl(); + + // Extract the raw URL string from the expression (IFF it's a static string, rarely is anything else). + String urlStr = extractUrlString(urlExpr); + if (urlStr == null) { + return rule; + } + + // Check if this is an S3Express endpoint by URL or backend property. + // Note: while `contains("s3express")` is broad and could theoretically match path/query components, + // the subsequent matchUrl() call validates the hostname pattern before any rewriting occurs. + boolean isS3ExpressUrl = urlStr.contains("s3express"); + boolean isS3ExpressBackend = isS3ExpressBackend(endpoint); + + if (!isS3ExpressUrl && !isS3ExpressBackend) { + return rule; + } + + totalS3ExpressCount++; + + // For URL override endpoints (backend=S3Express but URL doesn't match s3express hostname), + // just canonicalize the auth scheme - no URL rewriting needed + if (isS3ExpressBackend && !isS3ExpressUrl) { + // Canonicalize auth scheme to use {_s3e_auth} + Map newProperties = canonicalizeAuthScheme(endpoint.getProperties()); + + if (newProperties == endpoint.getProperties()) { + // No changes needed + return rule; + } + + rewrittenCount++; + + Endpoint newEndpoint = Endpoint.builder() + .url(urlExpr) + .headers(endpoint.getHeaders()) + .properties(newProperties) + .sourceLocation(endpoint.getSourceLocation()) + .build(); + + // Add auth ITE condition for URL override endpoints + List allConditions = new ArrayList<>(rule.getConditions()); + allConditions.add(createAuthIteCondition()); + + return Rule.builder() + .conditions(allConditions) + .endpoint(newEndpoint); + } + + // Standard S3Express URL - match and rewrite + UrlMatchResult match = matchUrl(urlStr); + if (match == null) { + return rule; + } + + rewrittenCount++; + + // Rewrite the URL to use the ITE-assigned variables + String newUrl = match.rewriteUrl(); + + // Canonicalize auth scheme for bucket endpoints (not control plane) + // Control plane always uses sigv4, bucket endpoints vary based on DisableS3ExpressSessionAuth + Map newProperties = endpoint.getProperties(); + if (match instanceof BucketUrlMatchResult) { + newProperties = canonicalizeAuthScheme(endpoint.getProperties()); + } + + // Build the new endpoint with canonicalized URL and properties + Endpoint newEndpoint = Endpoint.builder() + .url(Expression.of(newUrl)) + .headers(endpoint.getHeaders()) + .properties(newProperties) + .sourceLocation(endpoint.getSourceLocation()) + .build(); + + // Add ITE conditions: original conditions first, then ITE conditions at the end. + List allConditions = new ArrayList<>(rule.getConditions()); + allConditions.addAll(createIteConditions()); + + return Rule.builder() + .conditions(allConditions) + .endpoint(newEndpoint); + } + + // Checks if the endpoint has `backend` property set to "S3Express". + private boolean isS3ExpressBackend(Endpoint endpoint) { + Literal backend = endpoint.getProperties().get(ID_BACKEND); + if (backend == null) { + return false; + } + + return backend.asStringLiteral() + .filter(Template::isStatic) + .map(t -> "S3Express".equalsIgnoreCase(t.expectLiteral())) + .orElse(false); + } + + // Creates just the auth ITE condition for URL override endpoints. + private Condition createAuthIteCondition() { + // `DisableS3ExpressSessionAuth` is nullable, so we need to coalesce it to have a false default. Fix upstream? + Expression isSessionAuthDisabled = Coalesce.ofExpressions( + Expression.getReference(ID_DISABLE_S3EXPRESS_SESSION_AUTH), + Expression.of(false)); + return createIteAssignment(VAR_AUTH, isSessionAuthDisabled, AUTH_SIGV4, AUTH_SIGV4_S3EXPRESS); + } + + // Canonicalizes the authScheme name in endpoint properties to use the ITE variable. + private Map canonicalizeAuthScheme(Map properties) { + Literal authSchemes = properties.get(ID_AUTH_SCHEMES); + if (authSchemes == null) { + return properties; + } + + List schemes = authSchemes.asTupleLiteral().orElse(null); + if (schemes == null || schemes.isEmpty()) { + return properties; + } + + // Rewrite each auth scheme's name field + List newSchemes = new ArrayList<>(); + for (Literal scheme : schemes) { + Map record = scheme.asRecordLiteral().orElse(null); + if (record == null) { + // Auth is always a record, but maybe that changes in the future, so pass it through. + newSchemes.add(scheme); + continue; + } + + Literal nameLiteral = record.get(ID_NAME); + if (nameLiteral == null) { + // "name" should always be set, but pass through if not. + newSchemes.add(scheme); + continue; + } + + // Only transform string literals we recognize. + String name = nameLiteral.asStringLiteral() + .filter(Template::isStatic) + .map(Template::expectLiteral) + .orElse(null); + + // Only rewrite if it's one of the S3Express auth schemes + if (AUTH_SIGV4.equals(name) || AUTH_SIGV4_S3EXPRESS.equals(name)) { + Map newRecord = new LinkedHashMap<>(record); + newRecord.put(ID_NAME, AUTH_NAME_LITERAL); + newSchemes.add(Literal.recordLiteral(newRecord)); + } else { + newSchemes.add(scheme); + } + } + + Map newProperties = new LinkedHashMap<>(properties); + newProperties.put(ID_AUTH_SCHEMES, Literal.tupleLiteral(newSchemes)); + return newProperties; + } + + // Extracts the raw URL string from a URL expression. + private String extractUrlString(Expression urlExpr) { + return urlExpr.toNode().asStringNode().map(StringNode::getValue).orElse(null); + } + + // Matches an S3Express URL and returns the pattern match info. Tries to match in most specific order. + private UrlMatchResult matchUrl(String url) { + Matcher m; + + // First try control plane patterns (no AZ) since these are more specific + m = S3EXPRESS_CONTROL_FIPS_DS.matcher(url); + if (m.find()) { + return new ControlPlaneUrlMatchResult(url, m); + } + + m = S3EXPRESS_CONTROL_FIPS.matcher(url); + if (m.find()) { + return new ControlPlaneUrlMatchResult(url, m); + } + + m = S3EXPRESS_CONTROL_DS.matcher(url); + if (m.find()) { + return new ControlPlaneUrlMatchResult(url, m); + } + + m = S3EXPRESS_CONTROL_PLAIN.matcher(url); + if (m.find()) { + return new ControlPlaneUrlMatchResult(url, m); + } + + // Next, try bucket endpoint patterns (with AZ) + m = S3EXPRESS_FIPS_DS.matcher(url); + if (m.find()) { + return new BucketUrlMatchResult(url, m); + } + + m = S3EXPRESS_FIPS.matcher(url); + if (m.find()) { + return new BucketUrlMatchResult(url, m); + } + + m = S3EXPRESS_DS.matcher(url); + if (m.find()) { + return new BucketUrlMatchResult(url, m); + } + + m = S3EXPRESS_PLAIN.matcher(url); + if (m.find()) { + return new BucketUrlMatchResult(url, m); + } + + return null; + } + + /** + * Result of matching an S3Express URL pattern. + */ + private abstract static class UrlMatchResult { + protected final String prefix; + + UrlMatchResult(String prefix) { + this.prefix = prefix; + } + + abstract String rewriteUrl(); + } + + /** + * Match result for bucket endpoints (with AZ): {prefix}s3express{fips}-{AZ}{ds}.{region} + */ + private static final class BucketUrlMatchResult extends UrlMatchResult { + private final String s3express; + private final String az; + private final String regionSuffix; + + BucketUrlMatchResult(String url, Matcher m) { + super(url.substring(0, m.start())); + this.s3express = m.group(1); + this.az = m.group(2); + this.regionSuffix = m.group(3); + } + + @Override + String rewriteUrl() { + return String.format("%s%s{%s}-%s{%s}.%s", prefix, s3express, VAR_FIPS, az, VAR_DS, regionSuffix); + } + } + + /** + * Match result for control plane endpoints (no AZ): {prefix}s3express-control{fips}{ds}.{region} + */ + private static final class ControlPlaneUrlMatchResult extends UrlMatchResult { + private final String s3expressControl; + private final String regionSuffix; + + ControlPlaneUrlMatchResult(String url, Matcher m) { + super(url.substring(0, m.start())); + this.s3expressControl = m.group(1); + this.regionSuffix = m.group(2); + } + + @Override + String rewriteUrl() { + return String.format("%s%s{%s}{%s}.%s", prefix, s3expressControl, VAR_FIPS, VAR_DS, regionSuffix); + } + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/CoreExtension.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/CoreExtension.java index 9e0dfb0fd39..2dda13db308 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/CoreExtension.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/CoreExtension.java @@ -11,6 +11,7 @@ import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.GetAttr; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.IsSet; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.IsValidHostLabel; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Ite; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Not; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.ParseUrl; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Split; @@ -43,6 +44,7 @@ public List getLibraryFunctions() { Split.getDefinition(), StringEquals.getDefinition(), Substring.getDefinition(), + Ite.getDefinition(), UriEncode.getDefinition()); } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/RuleEvaluator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/RuleEvaluator.java index 6e8d70a8771..efc82026a30 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/RuleEvaluator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/RuleEvaluator.java @@ -215,6 +215,12 @@ public Value visitStringEquals(Expression left, Expression right) { .equals(right.accept(this).expectStringValue())); } + @Override + public Value visitIte(Expression condition, Expression trueValue, Expression falseValue) { + boolean cond = condition.accept(this).expectBooleanValue().getValue(); + return cond ? trueValue.accept(this) : falseValue.accept(this); + } + @Override public Value visitGetAttr(GetAttr getAttr) { return getAttr.evaluate(getAttr.getTarget().accept(this)); diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/ExpressionVisitor.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/ExpressionVisitor.java index 1557b529b52..b4bbc93868f 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/ExpressionVisitor.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/ExpressionVisitor.java @@ -4,10 +4,12 @@ */ package software.amazon.smithy.rulesengine.language.syntax.expressions; +import java.util.Arrays; import java.util.List; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Coalesce; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.FunctionDefinition; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.GetAttr; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Ite; import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; import software.amazon.smithy.utils.SmithyUnstableApi; @@ -86,6 +88,18 @@ default R visitCoalesce(List expressions) { */ R visitStringEquals(Expression left, Expression right); + /** + * Visits an if-then-else (ITE) function. + * + * @param condition the boolean condition expression. + * @param trueValue the value if condition is true. + * @param falseValue the value if condition is false. + * @return the value from the visitor. + */ + default R visitIte(Expression condition, Expression trueValue, Expression falseValue) { + return visitLibraryFunction(Ite.getDefinition(), Arrays.asList(condition, trueValue, falseValue)); + } + /** * Visits a library function. * @@ -138,6 +152,11 @@ public R visitStringEquals(Expression left, Expression right) { return getDefault(); } + @Override + public R visitIte(Expression condition, Expression trueValue, Expression falseValue) { + return getDefault(); + } + @Override public R visitLibraryFunction(FunctionDefinition fn, List args) { return getDefault(); diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Ite.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Ite.java new file mode 100644 index 00000000000..30d383cbcd9 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Ite.java @@ -0,0 +1,174 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.language.syntax.expressions.functions; + +import java.util.Arrays; +import java.util.List; +import software.amazon.smithy.rulesengine.language.RulesVersion; +import software.amazon.smithy.rulesengine.language.evaluation.Scope; +import software.amazon.smithy.rulesengine.language.evaluation.type.OptionalType; +import software.amazon.smithy.rulesengine.language.evaluation.type.Type; +import software.amazon.smithy.rulesengine.language.evaluation.value.Value; +import software.amazon.smithy.rulesengine.language.syntax.ToExpression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.ExpressionVisitor; +import software.amazon.smithy.utils.SmithyUnstableApi; + +/** + * An if-then-else (ITE) function that returns one of two values based on a boolean condition. + * + *

This function is critical for avoiding SSA (Static Single Assignment) fragmentation in BDD compilation. + * By computing conditional values atomically without branching, it prevents the graph explosion that occurs when + * boolean flags like UseFips or UseDualStack create divergent paths with distinct variable identities. + * + *

Semantics: {@code ite(condition, trueValue, falseValue)} + *

    + *
  • If condition is true, returns trueValue
  • + *
  • If condition is false, returns falseValue
  • + *
  • The condition must be a non-optional boolean (use coalesce to provide a default if needed)
  • + *
+ * + *

Type checking rules (least upper bound of nullability): + *

    + *
  • {@code ite(Boolean, T, T) => T} - both non-optional, result is non-optional
  • + *
  • {@code ite(Boolean, T, Optional) => Optional} - any optional makes result optional
  • + *
  • {@code ite(Boolean, Optional, T) => Optional} - any optional makes result optional
  • + *
  • {@code ite(Boolean, Optional, Optional) => Optional} - both optional, result is optional
  • + *
+ * + *

Available since: rules engine 1.1. + */ +@SmithyUnstableApi +public final class Ite extends LibraryFunction { + public static final String ID = "ite"; + private static final Definition DEFINITION = new Definition(); + + private Ite(FunctionNode functionNode) { + super(DEFINITION, functionNode); + } + + /** + * Gets the {@link FunctionDefinition} implementation. + * + * @return the function definition. + */ + public static Definition getDefinition() { + return DEFINITION; + } + + /** + * Creates a {@link Ite} function from the given expressions. + * + * @param condition the boolean condition to evaluate + * @param trueValue the value to return if condition is true + * @param falseValue the value to return if condition is false + * @return The resulting {@link Ite} function. + */ + public static Ite ofExpressions(ToExpression condition, ToExpression trueValue, ToExpression falseValue) { + return DEFINITION.createFunction(FunctionNode.ofExpressions(ID, condition, trueValue, falseValue)); + } + + /** + * Creates a {@link Ite} function with a reference condition and string values. + * + * @param conditionRef the reference to a boolean parameter + * @param trueValue the string value if condition is true + * @param falseValue the string value if condition is false + * @return The resulting {@link Ite} function. + */ + public static Ite ofStrings(ToExpression conditionRef, String trueValue, String falseValue) { + return ofExpressions(conditionRef, Expression.of(trueValue), Expression.of(falseValue)); + } + + @Override + public RulesVersion availableSince() { + return RulesVersion.V1_1; + } + + @Override + public R accept(ExpressionVisitor visitor) { + return visitor.visitIte(getArguments().get(0), getArguments().get(1), getArguments().get(2)); + } + + @Override + public Type typeCheck(Scope scope) { + List args = getArguments(); + if (args.size() != 3) { + throw new IllegalArgumentException("ITE requires exactly 3 arguments, got " + args.size()); + } + + // Check condition is a boolean (non-optional) + Type conditionType = args.get(0).typeCheck(scope); + if (!conditionType.equals(Type.booleanType())) { + throw new IllegalArgumentException(String.format( + "ITE condition must be a non-optional Boolean, got %s. " + + "Use coalesce to provide a default for optional booleans.", + conditionType)); + } + + // Get trueValue and falseValue types + Type trueType = args.get(1).typeCheck(scope); + Type falseType = args.get(2).typeCheck(scope); + + // Extract base types (unwrap Optional if present) + Type trueBaseType = getInnerType(trueType); + Type falseBaseType = getInnerType(falseType); + + // Base types must match + if (!trueBaseType.equals(falseBaseType)) { + throw new IllegalArgumentException(String.format( + "ITE branches must have the same base type: true branch is %s, false branch is %s", + trueBaseType, + falseBaseType)); + } + + // Result is optional if EITHER branch is optional (least upper bound) + boolean resultIsOptional = (trueType instanceof OptionalType) || (falseType instanceof OptionalType); + return resultIsOptional ? Type.optionalType(trueBaseType) : trueBaseType; + } + + private static Type getInnerType(Type t) { + return (t instanceof OptionalType) ? ((OptionalType) t).inner() : t; + } + + /** + * A {@link FunctionDefinition} for the {@link Ite} function. + */ + public static final class Definition implements FunctionDefinition { + private Definition() {} + + @Override + public String getId() { + return ID; + } + + @Override + public List getArguments() { + // Actual type checking is done in typeCheck override + return Arrays.asList(Type.booleanType(), Type.anyType(), Type.anyType()); + } + + @Override + public Type getReturnType() { + // Actual return type is computed in typeCheck override + return Type.anyType(); + } + + @Override + public Value evaluate(List arguments) { + throw new UnsupportedOperationException("ITE evaluation is handled by ExpressionVisitor"); + } + + @Override + public Ite createFunction(FunctionNode functionNode) { + return new Ite(functionNode); + } + + @Override + public int getCost() { + return 10; + } + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/CostOptimization.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/CostOptimization.java index 11dbcc8b119..ea80fb56fbb 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/CostOptimization.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/CostOptimization.java @@ -293,7 +293,7 @@ public static final class Builder implements SmithyBuilder { private Cfg cfg; private ConditionCostModel costModel; private ToDoubleFunction trueProbability; - private double maxAllowedGrowth = 0.1; + private double maxAllowedGrowth = 0.08; private int maxRounds = 30; private int topK = 50; @@ -333,7 +333,7 @@ public Builder trueProbability(ToDoubleFunction trueProbability) { } /** - * Sets the maximum allowed node growth as a fraction (default 0.1 or 10%). + * Sets the maximum allowed node growth as a fraction (default 0.08 or 8%). * * @param maxAllowedGrowth maximum growth (0.0 = no growth, 0.1 = 10% growth) * @return the builder diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java index 8f0d6686b9c..eda9157d1d0 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java @@ -59,7 +59,7 @@ public final class SiftingOptimization implements Function '%s' (different base names) for: %s", + varName, + globalVar, + canonical)); + } else if (!wouldCauseShadowing(globalVar, path, ancestorVars)) { variableRenameMap.put(varName, globalVar); consolidatedCount++; LOGGER.info(String.format("Consolidating '%s' -> '%s' for: %s", @@ -177,6 +184,42 @@ private void discoverBindingsInRule( } } + /** + * Checks if two variable names have the same base name. + * For SSA-style variables like "foo_1" and "foo_2", the base name is "foo". + * Variables without SSA suffix (like "s3e_fips" and "s3e_ds") are considered + * to have their full name as the base. + */ + private boolean hasSameBaseName(String var1, String var2) { + String base1 = getSsaBaseName(var1); + String base2 = getSsaBaseName(var2); + return base1.equals(base2); + } + + /** + * Extracts the SSA base name from a variable. + * If the variable ends with _N (where N is a number), strips the suffix. + * Otherwise returns the full name. + */ + private String getSsaBaseName(String varName) { + int lastUnderscore = varName.lastIndexOf('_'); + if (lastUnderscore > 0 && lastUnderscore < varName.length() - 1) { + String suffix = varName.substring(lastUnderscore + 1); + // Check if suffix is all digits + boolean allDigits = true; + for (int i = 0; i < suffix.length(); i++) { + if (!Character.isDigit(suffix.charAt(i))) { + allDigits = false; + break; + } + } + if (allDigits) { + return varName.substring(0, lastUnderscore); + } + } + return varName; + } + private boolean wouldCauseShadowing(String varName, String currentPath, Set ancestorVars) { // Check if using this variable name would shadow an ancestor variable if (ancestorVars.contains(varName)) { diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/EndpointBddTrait.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/EndpointBddTrait.java index 9cd9d627c36..619ca5994ee 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/EndpointBddTrait.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/EndpointBddTrait.java @@ -206,6 +206,16 @@ public static EndpointBddTrait fromNode(Node node) { results.add(NoMatchRule.INSTANCE); // Always add no-match at index 0 results.addAll(serializedResults); + // Validate that results have no conditions (all conditions are hoisted into the BDD) + for (int i = 1; i < results.size(); i++) { + Rule rule = results.get(i); + if (!rule.getConditions().isEmpty()) { + throw new IllegalArgumentException( + "BDD result at index " + i + " has conditions, but BDD results must not have conditions. " + + "All conditions should be hoisted into the BDD decision structure."); + } + } + String nodesBase64 = obj.expectStringMember("nodes").getValue(); int nodeCount = obj.expectNumberMember("nodeCount").getValue().intValue(); int rootRef = obj.expectNumberMember("root").getValue().intValue(); diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetAuthSchemesValidator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetAuthSchemesValidator.java index c79ab5fbeb8..62e06b43742 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetAuthSchemesValidator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetAuthSchemesValidator.java @@ -124,7 +124,7 @@ private String validateAuthSchemeName( FromSourceLocation sourceLocation ) { Literal nameLiteral = authScheme.get(NAME); - if (nameLiteral == null) { + if (nameLiteral == null || nameLiteral.asStringLiteral().isEmpty()) { events.add(error(service, sourceLocation, String.format( @@ -133,13 +133,14 @@ private String validateAuthSchemeName( return null; } - String name = nameLiteral.asStringLiteral().map(s -> s.expectLiteral()).orElse(null); + // Try to get the name as a literal string. If the template contains variables + // (e.g., from branchless transforms like "{s3e_auth}"), we can't statically validate. + String name = nameLiteral.asStringLiteral() + .filter(t -> t.isStatic()) + .map(t -> t.expectLiteral()) + .orElse(null); if (name == null) { - events.add(error(service, - sourceLocation, - String.format( - "Expected `authSchemes` to have a `name` key with a string value but it did not: `%s`", - authScheme))); + // String literal with template variables - skip static validation return null; } diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/IteTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/IteTest.java new file mode 100644 index 00000000000..580fa7168e6 --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/IteTest.java @@ -0,0 +1,234 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.language.syntax.functions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.evaluation.Scope; +import software.amazon.smithy.rulesengine.language.evaluation.type.Type; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Ite; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; + +public class IteTest { + + @Test + void testIteBothBranchesNonOptionalString() { + Expression condition = Expression.getReference(Identifier.of("flag")); + Expression trueValue = Literal.of("-fips"); + Expression falseValue = Literal.of(""); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + scope.insert("flag", Type.booleanType()); + + Type resultType = ite.typeCheck(scope); + + // Both non-optional String => non-optional String + assertEquals(Type.stringType(), resultType); + } + + @Test + void testIteBothBranchesNonOptionalInteger() { + Expression condition = Expression.getReference(Identifier.of("useNewValue")); + Expression trueValue = Literal.of(100); + Expression falseValue = Literal.of(0); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + scope.insert("useNewValue", Type.booleanType()); + + Type resultType = ite.typeCheck(scope); + + assertEquals(Type.integerType(), resultType); + } + + @Test + void testIteTrueBranchOptional() { + Expression condition = Expression.getReference(Identifier.of("flag")); + Expression trueValue = Expression.getReference(Identifier.of("maybeValue")); + Expression falseValue = Literal.of("default"); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + scope.insert("flag", Type.booleanType()); + scope.insert("maybeValue", Type.optionalType(Type.stringType())); + + Type resultType = ite.typeCheck(scope); + + // True branch optional => result is optional + assertEquals(Type.optionalType(Type.stringType()), resultType); + } + + @Test + void testIteFalseBranchOptional() { + Expression condition = Expression.getReference(Identifier.of("flag")); + Expression trueValue = Literal.of("value"); + Expression falseValue = Expression.getReference(Identifier.of("maybeDefault")); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + scope.insert("flag", Type.booleanType()); + scope.insert("maybeDefault", Type.optionalType(Type.stringType())); + + Type resultType = ite.typeCheck(scope); + + // False branch optional => result is optional + assertEquals(Type.optionalType(Type.stringType()), resultType); + } + + @Test + void testIteBothBranchesOptional() { + Expression condition = Expression.getReference(Identifier.of("flag")); + Expression trueValue = Expression.getReference(Identifier.of("maybe1")); + Expression falseValue = Expression.getReference(Identifier.of("maybe2")); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + scope.insert("flag", Type.booleanType()); + scope.insert("maybe1", Type.optionalType(Type.stringType())); + scope.insert("maybe2", Type.optionalType(Type.stringType())); + + Type resultType = ite.typeCheck(scope); + + // Both optional => result is optional + assertEquals(Type.optionalType(Type.stringType()), resultType); + } + + @Test + void testIteWithOfStringsHelper() { + Expression condition = Expression.getReference(Identifier.of("UseFIPS")); + Ite ite = Ite.ofStrings(condition, "-fips", ""); + + Scope scope = new Scope<>(); + scope.insert("UseFIPS", Type.booleanType()); + + Type resultType = ite.typeCheck(scope); + + // Both literal strings => non-optional String + assertEquals(Type.stringType(), resultType); + } + + @Test + void testIteTypeMismatchBetweenBranches() { + Expression condition = Expression.getReference(Identifier.of("flag")); + Expression trueValue = Literal.of("string"); + Expression falseValue = Literal.of(42); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + scope.insert("flag", Type.booleanType()); + + IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> ite.typeCheck(scope)); + assertTrue(ex.getMessage().contains("same base type")); + assertTrue(ex.getMessage().contains("true branch")); + assertTrue(ex.getMessage().contains("false branch")); + } + + @Test + void testIteConditionMustBeBoolean() { + Expression condition = Literal.of("not a boolean"); + Expression trueValue = Literal.of("yes"); + Expression falseValue = Literal.of("no"); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + + IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> ite.typeCheck(scope)); + assertTrue(ex.getMessage().contains("non-optional Boolean")); + } + + @Test + void testIteConditionCannotBeOptionalBoolean() { + Expression condition = Expression.getReference(Identifier.of("maybeFlag")); + Expression trueValue = Literal.of("yes"); + Expression falseValue = Literal.of("no"); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + scope.insert("maybeFlag", Type.optionalType(Type.booleanType())); + + IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> ite.typeCheck(scope)); + assertTrue(ex.getMessage().contains("non-optional Boolean")); + assertTrue(ex.getMessage().contains("coalesce")); + } + + @Test + void testIteWithArrayTypes() { + Expression condition = Expression.getReference(Identifier.of("useFirst")); + Expression trueValue = Expression.getReference(Identifier.of("array1")); + Expression falseValue = Expression.getReference(Identifier.of("array2")); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + scope.insert("useFirst", Type.booleanType()); + scope.insert("array1", Type.arrayType(Type.stringType())); + scope.insert("array2", Type.arrayType(Type.stringType())); + + Type resultType = ite.typeCheck(scope); + + assertEquals(Type.arrayType(Type.stringType()), resultType); + } + + @Test + void testIteWithOptionalArrayType() { + Expression condition = Expression.getReference(Identifier.of("flag")); + Expression trueValue = Expression.getReference(Identifier.of("maybeArray")); + Expression falseValue = Expression.getReference(Identifier.of("definiteArray")); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + scope.insert("flag", Type.booleanType()); + scope.insert("maybeArray", Type.optionalType(Type.arrayType(Type.integerType()))); + scope.insert("definiteArray", Type.arrayType(Type.integerType())); + + Type resultType = ite.typeCheck(scope); + + // One optional array => result is optional array + assertEquals(Type.optionalType(Type.arrayType(Type.integerType())), resultType); + } + + @Test + void testIteWithBooleanValues() { + Expression condition = Expression.getReference(Identifier.of("invertFlag")); + Expression trueValue = Literal.of(false); + Expression falseValue = Literal.of(true); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + scope.insert("invertFlag", Type.booleanType()); + + Type resultType = ite.typeCheck(scope); + + assertEquals(Type.booleanType(), resultType); + } + + @Test + void testIteTypeMismatchWithOptionalUnwrapping() { + // Even with optional wrapping, base types must match + Expression condition = Expression.getReference(Identifier.of("flag")); + Expression trueValue = Expression.getReference(Identifier.of("maybeString")); + Expression falseValue = Expression.getReference(Identifier.of("maybeInt")); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + scope.insert("flag", Type.booleanType()); + scope.insert("maybeString", Type.optionalType(Type.stringType())); + scope.insert("maybeInt", Type.optionalType(Type.integerType())); + + IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> ite.typeCheck(scope)); + assertTrue(ex.getMessage().contains("same base type")); + } + + @Test + void testIteReturnsCorrectId() { + assertEquals("ite", Ite.ID); + assertEquals("ite", Ite.getDefinition().getId()); + } +} diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/ite-basic.errors b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/ite-basic.errors new file mode 100644 index 00000000000..8b137891791 --- /dev/null +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/ite-basic.errors @@ -0,0 +1 @@ + diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/ite-basic.smithy b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/ite-basic.smithy new file mode 100644 index 00000000000..75a4d4d050c --- /dev/null +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/ite-basic.smithy @@ -0,0 +1,80 @@ +$version: "2.0" + +namespace example + +use smithy.rules#clientContextParams +use smithy.rules#endpointRuleSet +use smithy.rules#endpointTests + +@clientContextParams( + useFips: {type: "boolean", documentation: "Use FIPS endpoints"} +) +@endpointRuleSet({ + version: "1.1", + parameters: { + useFips: { + type: "boolean", + documentation: "Use FIPS endpoints", + default: false, + required: true + } + }, + rules: [ + { + "documentation": "Use ite to select endpoint suffix" + "conditions": [ + { + "fn": "ite" + "argv": [{"ref": "useFips"}, "-fips", ""] + "assign": "suffix" + } + ] + "endpoint": { + "url": "https://example{suffix}.com" + } + "type": "endpoint" + } + ] +}) +@endpointTests({ + "version": "1.0", + "testCases": [ + { + "documentation": "When useFips is true, returns trueValue" + "params": { + "useFips": true + } + "operationInputs": [{ + "operationName": "GetThing" + }], + "expect": { + "endpoint": { + "url": "https://example-fips.com" + } + } + } + { + "documentation": "When useFips is false, returns falseValue" + "params": { + "useFips": false + } + "operationInputs": [{ + "operationName": "GetThing" + }], + "expect": { + "endpoint": { + "url": "https://example.com" + } + } + } + ] +}) +@suppress(["UnstableTrait.smithy"]) +service FizzBuzz { + version: "2022-01-01", + operations: [GetThing] +} + +operation GetThing { + input := {} +} From cffa45f14ef8b6a1b4c99da96869e9aae3a74e64 Mon Sep 17 00:00:00 2001 From: Michael Dowling Date: Thu, 18 Dec 2025 10:27:47 -0600 Subject: [PATCH 03/10] Run type checking on BDD so type() works --- .../expressions/functions/Coalesce.java | 7 +-- .../syntax/expressions/functions/Ite.java | 9 ++-- .../language/syntax/rule/NoMatchRule.java | 2 +- .../rulesengine/traits/EndpointBddTrait.java | 18 ++++++- .../syntax/functions/CoalesceTest.java | 28 ++++++++-- .../language/syntax/functions/IteTest.java | 52 +++++++++++++++++-- .../rulesengine/traits/BddTraitTest.java | 48 ++++++++++++++++- 7 files changed, 146 insertions(+), 18 deletions(-) diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Coalesce.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Coalesce.java index 931e5d9f9dd..039d461350e 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Coalesce.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Coalesce.java @@ -8,6 +8,7 @@ import java.util.List; import java.util.Optional; import software.amazon.smithy.rulesengine.language.RulesVersion; +import software.amazon.smithy.rulesengine.language.error.InnerParseError; import software.amazon.smithy.rulesengine.language.evaluation.Scope; import software.amazon.smithy.rulesengine.language.evaluation.type.OptionalType; import software.amazon.smithy.rulesengine.language.evaluation.type.Type; @@ -81,10 +82,10 @@ public R accept(ExpressionVisitor visitor) { } @Override - public Type typeCheck(Scope scope) { + protected Type typeCheckLocal(Scope scope) throws InnerParseError { List args = getArguments(); if (args.size() < 2) { - throw new IllegalArgumentException("Coalesce requires at least 2 arguments, got " + args.size()); + throw new InnerParseError("Coalesce requires at least 2 arguments, got " + args.size()); } // Get the first argument's type as the baseline @@ -98,7 +99,7 @@ public Type typeCheck(Scope scope) { Type innerType = getInnerType(argType); if (!innerType.equals(baseInnerType)) { - throw new IllegalArgumentException(String.format( + throw new InnerParseError(String.format( "Type mismatch in coalesce at argument %d: expected %s but got %s", i + 1, baseInnerType, diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Ite.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Ite.java index 30d383cbcd9..90acc71da01 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Ite.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Ite.java @@ -7,6 +7,7 @@ import java.util.Arrays; import java.util.List; import software.amazon.smithy.rulesengine.language.RulesVersion; +import software.amazon.smithy.rulesengine.language.error.InnerParseError; import software.amazon.smithy.rulesengine.language.evaluation.Scope; import software.amazon.smithy.rulesengine.language.evaluation.type.OptionalType; import software.amazon.smithy.rulesengine.language.evaluation.type.Type; @@ -93,16 +94,16 @@ public R accept(ExpressionVisitor visitor) { } @Override - public Type typeCheck(Scope scope) { + protected Type typeCheckLocal(Scope scope) throws InnerParseError { List args = getArguments(); if (args.size() != 3) { - throw new IllegalArgumentException("ITE requires exactly 3 arguments, got " + args.size()); + throw new InnerParseError("ITE requires exactly 3 arguments, got " + args.size()); } // Check condition is a boolean (non-optional) Type conditionType = args.get(0).typeCheck(scope); if (!conditionType.equals(Type.booleanType())) { - throw new IllegalArgumentException(String.format( + throw new InnerParseError(String.format( "ITE condition must be a non-optional Boolean, got %s. " + "Use coalesce to provide a default for optional booleans.", conditionType)); @@ -118,7 +119,7 @@ public Type typeCheck(Scope scope) { // Base types must match if (!trueBaseType.equals(falseBaseType)) { - throw new IllegalArgumentException(String.format( + throw new InnerParseError(String.format( "ITE branches must have the same base type: true branch is %s, false branch is %s", trueBaseType, falseBaseType)); diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/NoMatchRule.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/NoMatchRule.java index d7c76f7feec..be58bb2415d 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/NoMatchRule.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/NoMatchRule.java @@ -28,7 +28,7 @@ public T accept(RuleValueVisitor visitor) { @Override protected Type typecheckValue(Scope scope) { - throw new UnsupportedOperationException("NO_MATCH is a sentinel"); + return Type.anyType(); } @Override diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/EndpointBddTrait.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/EndpointBddTrait.java index 619ca5994ee..52ffd2b4f36 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/EndpointBddTrait.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/EndpointBddTrait.java @@ -23,6 +23,8 @@ import software.amazon.smithy.model.traits.AbstractTraitBuilder; import software.amazon.smithy.model.traits.Trait; import software.amazon.smithy.rulesengine.language.RulesVersion; +import software.amazon.smithy.rulesengine.language.evaluation.Scope; +import software.amazon.smithy.rulesengine.language.evaluation.type.Type; import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; import software.amazon.smithy.rulesengine.language.syntax.rule.NoMatchRule; @@ -360,7 +362,21 @@ public Builder bdd(Bdd bdd) { @Override public EndpointBddTrait build() { - return new EndpointBddTrait(this); + EndpointBddTrait trait = new EndpointBddTrait(this); + + // Type-check conditions and results so expression.type() works. Note that using a shared scope across + // each check is ok, because BDD evaluation always runs conditions in a fixed order and could in theory + // try every condition for a single path to a result. + Scope scope = new Scope<>(); + trait.getParameters().writeToScope(scope); + for (Condition condition : trait.getConditions()) { + condition.typeCheck(scope); + } + for (Rule result : trait.getResults()) { + result.typeCheck(scope); + } + + return trait; } } diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/CoalesceTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/CoalesceTest.java index bf6ac4bb9da..7dd1d118883 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/CoalesceTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/CoalesceTest.java @@ -10,6 +10,7 @@ import java.util.Arrays; import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.error.RuleError; import software.amazon.smithy.rulesengine.language.evaluation.Scope; import software.amazon.smithy.rulesengine.language.evaluation.type.Type; import software.amazon.smithy.rulesengine.language.syntax.Identifier; @@ -135,7 +136,7 @@ void testCoalesceWithIncompatibleTypes() { Scope scope = new Scope<>(); - IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> coalesce.typeCheck(scope)); + RuleError ex = assertThrows(RuleError.class, () -> coalesce.typeCheck(scope)); assertTrue(ex.getMessage().contains("Type mismatch in coalesce")); assertTrue(ex.getMessage().contains("argument 2")); } @@ -151,7 +152,7 @@ void testCoalesceWithIncompatibleTypesInMiddle() { Scope scope = new Scope<>(); - IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> coalesce.typeCheck(scope)); + RuleError ex = assertThrows(RuleError.class, () -> coalesce.typeCheck(scope)); assertTrue(ex.getMessage().contains("Type mismatch in coalesce")); assertTrue(ex.getMessage().contains("argument 3")); } @@ -160,8 +161,7 @@ void testCoalesceWithIncompatibleTypesInMiddle() { void testCoalesceWithLessThanTwoArguments() { Expression single = Literal.of("only"); - IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, - () -> Coalesce.ofExpressions(single).typeCheck(new Scope<>())); + RuleError ex = assertThrows(RuleError.class, () -> Coalesce.ofExpressions(single).typeCheck(new Scope<>())); assertTrue(ex.getMessage().contains("at least 2 arguments")); } @@ -215,4 +215,24 @@ void testCoalesceWithBooleanTypes() { assertEquals(Type.booleanType(), resultType); } + + @Test + void testTypeMethodReturnsInferredTypeAfterTypeCheck() { + // Verify that type() returns the correct inferred type after typeCheck() + Expression optional1 = Expression.getReference(Identifier.of("maybeValue1")); + Expression optional2 = Expression.getReference(Identifier.of("maybeValue2")); + Expression definite = Literal.of("default"); + Coalesce coalesce = Coalesce.ofExpressions(optional1, optional2, definite); + + Scope scope = new Scope<>(); + scope.insert("maybeValue1", Type.optionalType(Type.stringType())); + scope.insert("maybeValue2", Type.optionalType(Type.stringType())); + + // Call typeCheck to cache the type + coalesce.typeCheck(scope); + + // Now type() should return the inferred type (non-optional since last arg is definite) + Type cachedType = coalesce.type(); + assertEquals(Type.stringType(), cachedType); + } } diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/IteTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/IteTest.java index 580fa7168e6..5c57ed7dc6a 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/IteTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/IteTest.java @@ -9,6 +9,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.error.RuleError; import software.amazon.smithy.rulesengine.language.evaluation.Scope; import software.amazon.smithy.rulesengine.language.evaluation.type.Type; import software.amazon.smithy.rulesengine.language.syntax.Identifier; @@ -125,7 +126,7 @@ void testIteTypeMismatchBetweenBranches() { Scope scope = new Scope<>(); scope.insert("flag", Type.booleanType()); - IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> ite.typeCheck(scope)); + RuleError ex = assertThrows(RuleError.class, () -> ite.typeCheck(scope)); assertTrue(ex.getMessage().contains("same base type")); assertTrue(ex.getMessage().contains("true branch")); assertTrue(ex.getMessage().contains("false branch")); @@ -140,7 +141,7 @@ void testIteConditionMustBeBoolean() { Scope scope = new Scope<>(); - IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> ite.typeCheck(scope)); + RuleError ex = assertThrows(RuleError.class, () -> ite.typeCheck(scope)); assertTrue(ex.getMessage().contains("non-optional Boolean")); } @@ -154,7 +155,7 @@ void testIteConditionCannotBeOptionalBoolean() { Scope scope = new Scope<>(); scope.insert("maybeFlag", Type.optionalType(Type.booleanType())); - IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> ite.typeCheck(scope)); + RuleError ex = assertThrows(RuleError.class, () -> ite.typeCheck(scope)); assertTrue(ex.getMessage().contains("non-optional Boolean")); assertTrue(ex.getMessage().contains("coalesce")); } @@ -222,7 +223,7 @@ void testIteTypeMismatchWithOptionalUnwrapping() { scope.insert("maybeString", Type.optionalType(Type.stringType())); scope.insert("maybeInt", Type.optionalType(Type.integerType())); - IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> ite.typeCheck(scope)); + RuleError ex = assertThrows(RuleError.class, () -> ite.typeCheck(scope)); assertTrue(ex.getMessage().contains("same base type")); } @@ -231,4 +232,47 @@ void testIteReturnsCorrectId() { assertEquals("ite", Ite.ID); assertEquals("ite", Ite.getDefinition().getId()); } + + @Test + void testTypeMethodReturnsInferredTypeAfterTypeCheck() { + // Verify that type() returns the correct inferred type after typeCheck() + Expression condition = Expression.getReference(Identifier.of("flag")); + Expression trueValue = Expression.getReference(Identifier.of("maybeValue")); + Expression falseValue = Literal.of("default"); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + scope.insert("flag", Type.booleanType()); + scope.insert("maybeValue", Type.optionalType(Type.stringType())); + + // Call typeCheck to cache the type + ite.typeCheck(scope); + + // Now type() should return the inferred type + Type cachedType = ite.type(); + assertEquals(Type.optionalType(Type.stringType()), cachedType); + } + + @Test + void testNestedIteTypeInference() { + // Test that nested Ite expressions have correct type inference + Expression outerCondition = Expression.getReference(Identifier.of("outer")); + Expression innerCondition = Expression.getReference(Identifier.of("inner")); + + // Inner ITE: ite(inner, "a", "b") => String + Ite innerIte = Ite.ofExpressions(innerCondition, Literal.of("a"), Literal.of("b")); + + // Outer ITE: ite(outer, innerIte, "c") => String + Ite outerIte = Ite.ofExpressions(outerCondition, innerIte, Literal.of("c")); + + Scope scope = new Scope<>(); + scope.insert("outer", Type.booleanType()); + scope.insert("inner", Type.booleanType()); + + outerIte.typeCheck(scope); + + // Both inner and outer should have String type + assertEquals(Type.stringType(), innerIte.type()); + assertEquals(Type.stringType(), outerIte.type()); + } } diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/traits/BddTraitTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/traits/BddTraitTest.java index 88c2aab778e..a4defa53c43 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/traits/BddTraitTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/traits/BddTraitTest.java @@ -12,6 +12,13 @@ import java.util.List; import org.junit.jupiter.api.Test; import software.amazon.smithy.model.node.Node; +import software.amazon.smithy.rulesengine.language.evaluation.type.Type; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Coalesce; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; +import software.amazon.smithy.rulesengine.language.syntax.parameters.ParameterType; import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; @@ -25,7 +32,11 @@ public class BddTraitTest { @Test void testBddTraitSerialization() { // Create a BddTrait with full context - Parameters params = Parameters.builder().build(); + Parameter regionParam = Parameter.builder() + .name("Region") + .type(ParameterType.STRING) + .build(); + Parameters params = Parameters.builder().addParameter(regionParam).build(); Condition cond = Condition.builder().fn(TestHelpers.isSet("Region")).build(); Rule endpoint = EndpointRule.builder().endpoint(TestHelpers.endpoint("https://example.com")); @@ -99,4 +110,39 @@ void testEmptyBddTrait() { assertEquals(1, trait.getResults().size()); assertEquals(-1, trait.getBdd().getRootRef()); // FALSE terminal } + + @Test + void testBuildTypeChecksExpressionsForCodegen() { + // Verify that after building an EndpointBddTrait, expression.type() works + // This is important for codegen to infer types without a scope + Parameter regionParam = Parameter.builder() + .name("Region") + .type(ParameterType.STRING) + .build(); + Parameters params = Parameters.builder().addParameter(regionParam).build(); + + // Create a condition with a coalesce that infers to String + Expression regionRef = Expression.getReference(Identifier.of("Region")); + Expression fallback = Literal.of("us-east-1"); + Coalesce coalesce = Coalesce.ofExpressions(regionRef, fallback); + Condition cond = Condition.builder().fn(coalesce).result(Identifier.of("resolvedRegion")).build(); + + Rule endpoint = EndpointRule.builder().endpoint(TestHelpers.endpoint("https://example.com")); + + List results = new ArrayList<>(); + results.add(NoMatchRule.INSTANCE); + results.add(endpoint); + + EndpointBddTrait trait = EndpointBddTrait.builder() + .parameters(params) + .conditions(ListUtils.of(cond)) + .results(results) + .bdd(createSimpleBdd()) + .build(); + + // After build(), type() should work on the coalesce expression + // Region is Optional, fallback is String, so result is String (non-optional) + Coalesce builtCoalesce = (Coalesce) trait.getConditions().get(0).getFunction(); + assertEquals(Type.stringType(), builtCoalesce.type()); + } } From 3f2c363a46b72c767cfb3d1e2df5f5b5e8b1bb30 Mon Sep 17 00:00:00 2001 From: Michael Dowling Date: Fri, 19 Dec 2025 10:31:03 -0600 Subject: [PATCH 04/10] Improve some loops --- .../logic/bdd/SiftingOptimization.java | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java index eda9157d1d0..b779eceb4f3 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java @@ -249,9 +249,12 @@ private void runBlockMoves(State state) { } LOGGER.info("Running block moves"); - List> blocks = findDependencyBlocks(state.orderView).stream() - .filter(b -> b.size() >= 2 && b.size() <= 5) - .collect(Collectors.toList()); + List> blocks = new ArrayList<>(); + for (List b : findDependencyBlocks(state.orderView)) { + if (b.size() >= 2 && b.size() <= 5) { + blocks.add(b); + } + } for (List block : blocks) { PassContext ctx = new PassContext(state, dependencyGraph); @@ -464,7 +467,13 @@ private Result findBestPosition(List positions, PassContext ctx, int va } // Second pass: among min-size candidates, pick lowest cost - int minSize = candidates.stream().mapToInt(c -> c.size).min().orElse(Integer.MAX_VALUE); + int minSize = Integer.MAX_VALUE; + for (Result c : candidates) { + if (c.size < minSize) { + minSize = c.size; + } + } + Result best = null; for (Result c : candidates) { if (c.size == minSize) { From 789c130c8cbb287e8eb4b0900f6a2a739b651747 Mon Sep 17 00:00:00 2001 From: Michael Dowling Date: Fri, 19 Dec 2025 10:56:39 -0600 Subject: [PATCH 05/10] Attempt to fix the windows build --- settings.gradle.kts | 2 ++ smithy-aws-endpoints/build.gradle.kts | 14 +++++++------- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/settings.gradle.kts b/settings.gradle.kts index f3c9eba093a..bffb67b89a7 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -6,6 +6,8 @@ pluginManagement { } } + + rootProject.name = "smithy" include(":smithy-aws-iam-traits") diff --git a/smithy-aws-endpoints/build.gradle.kts b/smithy-aws-endpoints/build.gradle.kts index 35c213e1e70..559731a51a3 100644 --- a/smithy-aws-endpoints/build.gradle.kts +++ b/smithy-aws-endpoints/build.gradle.kts @@ -26,7 +26,7 @@ dependencies { } // Integration test source set for tests that require the S3 model -// These tests require JDK 17+ due to the S3 model dependency +// These tests require JDK 21+ due to the S3 model dependency sourceSets { create("it") { compileClasspath += sourceSets["main"].output + sourceSets["test"].output @@ -38,15 +38,15 @@ configurations["itImplementation"].extendsFrom(configurations["testImplementatio configurations["itRuntimeOnly"].extendsFrom(configurations["testRuntimeOnly"]) configurations["itImplementation"].extendsFrom(s3Model) -// Configure IT source set to compile with JDK 17 +// Configure IT source set to compile with JDK 21 tasks.named("compileItJava") { javaCompiler.set( javaToolchains.compilerFor { - languageVersion.set(JavaLanguageVersion.of(17)) + languageVersion.set(JavaLanguageVersion.of(21)) }, ) - sourceCompatibility = "17" - targetCompatibility = "17" + sourceCompatibility = "21" + targetCompatibility = "21" } val integrationTest by tasks.registering(Test::class) { @@ -57,10 +57,10 @@ val integrationTest by tasks.registering(Test::class) { dependsOn(tasks.jar) shouldRunAfter(tasks.test) - // Run with JDK 17 + // Run with JDK 21 javaLauncher.set( javaToolchains.launcherFor { - languageVersion.set(JavaLanguageVersion.of(17)) + languageVersion.set(JavaLanguageVersion.of(21)) }, ) } From 72ffbb598f041b2ad3b4879119a87ef2b06aab99 Mon Sep 17 00:00:00 2001 From: Michael Dowling Date: Mon, 29 Dec 2025 21:58:55 -0600 Subject: [PATCH 06/10] Optimize rules further Add nullable boolean transform. Improve variable and expression sharing. Add S3 BDD integration test. Add TreeMapper to simplify transforms. --- .../aws/language/functions/S3BddTest.java | 76 ++++ .../language/functions/S3TreeRewriter.java | 225 ++++------ .../functions/LibraryFunction.java | 31 +- .../rulesengine/logic/cfg/CfgBuilder.java | 66 ++- .../logic/cfg/CoalesceTransform.java | 16 +- .../cfg/DeadStoreEliminationTransform.java | 51 +++ .../cfg/IsSetBooleanCoalesceTransform.java | 144 +++++++ .../rulesengine/logic/cfg/SsaTransform.java | 299 ++++--------- .../logic/cfg/SyntheticBindingTransform.java | 76 ++++ .../{TreeRewriter.java => TreeMapper.java} | 408 +++++++++++------- .../logic/cfg/VariableAnalysis.java | 61 +-- .../cfg/VariableConsolidationTransform.java | 338 ++++++--------- .../logic/cfg/ReferenceRewriterTest.java | 36 +- .../logic/cfg/SsaTransformTest.java | 104 ++++- .../cfg/SyntheticBindingTransformTest.java | 148 +++++++ 15 files changed, 1265 insertions(+), 814 deletions(-) create mode 100644 smithy-aws-endpoints/src/it/java/software/amazon/smithy/rulesengine/aws/language/functions/S3BddTest.java create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/DeadStoreEliminationTransform.java create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/IsSetBooleanCoalesceTransform.java create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/SyntheticBindingTransform.java rename smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/{TreeRewriter.java => TreeMapper.java} (50%) create mode 100644 smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/SyntheticBindingTransformTest.java diff --git a/smithy-aws-endpoints/src/it/java/software/amazon/smithy/rulesengine/aws/language/functions/S3BddTest.java b/smithy-aws-endpoints/src/it/java/software/amazon/smithy/rulesengine/aws/language/functions/S3BddTest.java new file mode 100644 index 00000000000..bc4006878aa --- /dev/null +++ b/smithy-aws-endpoints/src/it/java/software/amazon/smithy/rulesengine/aws/language/functions/S3BddTest.java @@ -0,0 +1,76 @@ +package software.amazon.smithy.rulesengine.aws.language.functions; + +import static org.junit.jupiter.api.Assertions.assertFalse; + +import java.util.List; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import software.amazon.smithy.model.Model; +import software.amazon.smithy.model.shapes.ServiceShape; +import software.amazon.smithy.model.shapes.ShapeId; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.evaluation.TestEvaluator; +import software.amazon.smithy.rulesengine.logic.bdd.CostOptimization; +import software.amazon.smithy.rulesengine.logic.bdd.SiftingOptimization; +import software.amazon.smithy.rulesengine.logic.cfg.Cfg; +import software.amazon.smithy.rulesengine.traits.EndpointBddTrait; +import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait; +import software.amazon.smithy.rulesengine.traits.EndpointTestCase; +import software.amazon.smithy.rulesengine.traits.EndpointTestsTrait; + +class S3BddTest { + private static final ShapeId S3_SERVICE_ID = ShapeId.from("com.amazonaws.s3#AmazonS3"); + private static EndpointRuleSet originalRules; + private static EndpointRuleSet rules; + private static List testCases; + + @BeforeAll + static void loadS3Model() { + Model model = Model.assembler() + .discoverModels() + .assemble() + .unwrap(); + + ServiceShape s3Service = model.expectShape(S3_SERVICE_ID, ServiceShape.class); + originalRules = s3Service.expectTrait(EndpointRuleSetTrait.class).getEndpointRuleSet(); + rules = S3TreeRewriter.transform(originalRules); + testCases = s3Service.expectTrait(EndpointTestsTrait.class).getTestCases(); + } + + @Test + void compileToBddWithOptimizations() { + // Verify transforms preserve semantics by running all test cases + assertFalse(testCases.isEmpty(), "S3 model should have endpoint test cases"); + for (EndpointTestCase testCase : testCases) { + TestEvaluator.evaluate(rules, testCase); + } + + // Build CFG and compile to BDD + Cfg cfg = Cfg.from(rules); + EndpointBddTrait trait = EndpointBddTrait.from(cfg); + + StringBuilder sb = new StringBuilder(); + sb.append("\n=== BDD STATS ===\n"); + sb.append("Conditions: ").append(trait.getConditions().size()).append("\n"); + sb.append("Results: ").append(trait.getResults().size()).append("\n"); + sb.append("Initial BDD nodes: ").append(trait.getBdd().getNodeCount()).append("\n"); + + // Apply sifting optimization + SiftingOptimization sifting = SiftingOptimization.builder().cfg(cfg).build(); + EndpointBddTrait siftedTrait = sifting.apply(trait); + sb.append("After sifting - nodes: ").append(siftedTrait.getBdd().getNodeCount()).append("\n"); + + // Apply cost optimization + CostOptimization cost = CostOptimization.builder().cfg(cfg).build(); + EndpointBddTrait optimizedTrait = cost.apply(siftedTrait); + sb.append("After cost opt - nodes: ").append(optimizedTrait.getBdd().getNodeCount()).append("\n"); + + // Print conditions for analysis + sb.append("\n=== CONDITIONS ===\n"); + for (int i = 0; i < trait.getConditions().size(); i++) { + sb.append(i).append(": ").append(trait.getConditions().get(i)).append("\n"); + } + + System.out.println(sb); + } +} diff --git a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/S3TreeRewriter.java b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/S3TreeRewriter.java index f748a289f56..71aca7412cc 100644 --- a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/S3TreeRewriter.java +++ b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/S3TreeRewriter.java @@ -219,30 +219,21 @@ public final class S3TreeRewriter { // Auth scheme name literal shared across all rewritten endpoints private static final Literal AUTH_NAME_LITERAL = Literal.stringLiteral(Template.fromString("{" + VAR_AUTH + "}")); - // Patterns to match S3Express bucket endpoint URLs (with AZ) - // Format: https://{Bucket}.s3express[-fips]-{AZ}[.dualstack].{Region}.amazonaws.com - // (negative lookahead (?!dualstack) prevents matching dualstack variants in non-DS patterns) - private static final Pattern S3EXPRESS_FIPS_DS = Pattern.compile("(s3express)-fips-([^.]+)\\.dualstack\\.(.+)$"); - private static final Pattern S3EXPRESS_FIPS = Pattern.compile("(s3express)-fips-([^.]+)\\.(?!dualstack)(.+)$"); - private static final Pattern S3EXPRESS_DS = Pattern.compile("(s3express)-([^.]+)\\.dualstack\\.(.+)$"); - private static final Pattern S3EXPRESS_PLAIN = Pattern.compile("(s3express)-([^.]+)\\.(?!dualstack)(.+)$"); - - // Patterns to match S3Express control plane URLs (no AZ) - // Format: https://s3express-control[-fips][.dualstack].{Region}.amazonaws.com - private static final Pattern S3EXPRESS_CONTROL_FIPS_DS = Pattern.compile( - "(s3express-control)-fips\\.dualstack\\.(.+)$"); - private static final Pattern S3EXPRESS_CONTROL_FIPS = Pattern.compile( - "(s3express-control)-fips\\.(?!dualstack)(.+)$"); - private static final Pattern S3EXPRESS_CONTROL_DS = Pattern.compile( - "(s3express-control)\\.dualstack\\.(.+)$"); - private static final Pattern S3EXPRESS_CONTROL_PLAIN = Pattern.compile( - "(s3express-control)\\.(?!dualstack)(.+)$"); - - // Cached canonical expression for AZ extraction: split(Bucket, "--", 0) - private static final Split BUCKET_SPLIT = Split.ofExpressions( - Expression.getReference(ID_BUCKET), - Expression.of("--"), - Expression.of(0)); + // URL pattern matchers, ordered from most specific to least specific. + // Control plane patterns (no AZ) come first, then bucket patterns (with AZ). + // Negative lookahead (?!dualstack) prevents matching dualstack variants in non-DS patterns. + private static final UrlPatternMatcher[] URL_PATTERNS = { + // Control plane: https://s3express-control[-fips][.dualstack].{Region}.amazonaws.com + new UrlPatternMatcher("(s3express-control)-fips\\.dualstack\\.(.+)$", false), + new UrlPatternMatcher("(s3express-control)-fips\\.(?!dualstack)(.+)$", false), + new UrlPatternMatcher("(s3express-control)\\.dualstack\\.(.+)$", false), + new UrlPatternMatcher("(s3express-control)\\.(?!dualstack)(.+)$", false), + // Bucket: https://{Bucket}.s3express[-fips]-{AZ}[.dualstack].{Region}.amazonaws.com + new UrlPatternMatcher("(s3express)-fips-([^.]+)\\.dualstack\\.(.+)$", true), + new UrlPatternMatcher("(s3express)-fips-([^.]+)\\.(?!dualstack)(.+)$", true), + new UrlPatternMatcher("(s3express)-([^.]+)\\.dualstack\\.(.+)$", true), + new UrlPatternMatcher("(s3express)-([^.]+)\\.(?!dualstack)(.+)$", true), + }; private int rewrittenCount = 0; private int totalS3ExpressCount = 0; @@ -256,7 +247,15 @@ private S3TreeRewriter() {} * @return the transformed rule set */ public static EndpointRuleSet transform(EndpointRuleSet ruleSet) { - return new S3TreeRewriter().run(ruleSet); + S3TreeRewriter rewriter = new S3TreeRewriter(); + EndpointRuleSet result = rewriter.run(ruleSet); + + LOGGER.info(() -> String.format( + "S3 tree rewriter: %s/%s S3Express endpoints rewritten", + rewriter.rewrittenCount, + rewriter.totalS3ExpressCount)); + + return result; } private EndpointRuleSet run(EndpointRuleSet ruleSet) { @@ -265,11 +264,6 @@ private EndpointRuleSet run(EndpointRuleSet ruleSet) { transformedRules.add(transformRule(rule)); } - LOGGER.info(() -> String.format( - "S3 tree rewriter: %s/%s S3Express endpoints rewritten", - rewrittenCount, - totalS3ExpressCount)); - return EndpointRuleSet.builder() .sourceLocation(ruleSet.getSourceLocation()) .parameters(ruleSet.getParameters()) @@ -281,7 +275,6 @@ private EndpointRuleSet run(EndpointRuleSet ruleSet) { private Rule transformRule(Rule rule) { if (rule instanceof TreeRule) { TreeRule tr = (TreeRule) rule; - // Transform conditions List transformedConditions = transformConditions(tr.getConditions()); List transformedChildren = new ArrayList<>(); for (Rule child : tr.getRules()) { @@ -304,29 +297,20 @@ private List transformConditions(List conditions) { return result; } - /** - * Transforms a single condition. - * - *

Handles: - *

-     * AZ extraction: substring(Bucket, N, M) -> split(Bucket, "--")[1]
-     * 
- * - *

Note: Delimiter checks (s3expressAvailabilityZoneDelim) are not currently transformed because they're part - * of a complex fallback structure, and changing them breaks control flow. Possibly something we can improve, or - * wait until the upstream rules are optimized. - */ private Condition transformCondition(Condition cond) { - // Is this a condition fishing for delimiters? + // Transform AZ extraction: substring(Bucket, N, M) -> split(Bucket, "--")[1] if (cond.getResult().isPresent() && ID_AZ_ID.equals(cond.getResult().get()) && cond.getFunction() instanceof Substring && isSubstringOnBucket((Substring) cond.getFunction())) { - // Replace with split-based extraction: split(Bucket, "--")[1] - GetAttr azExpr = GetAttr.ofExpressions(BUCKET_SPLIT, "[1]"); + // Create fresh expression each time to avoid type-checking conflicts + Split bucketSplit = Split.ofExpressions( + Expression.getReference(ID_BUCKET), + Expression.of("--"), + Expression.of(0)); + GetAttr azExpr = GetAttr.ofExpressions(bucketSplit, "[1]"); return cond.toBuilder().fn(azExpr).build(); } - return cond; } @@ -335,50 +319,19 @@ private boolean isSubstringOnBucket(Substring substring) { if (args.isEmpty()) { return false; } - Expression target = args.get(0); return target instanceof Reference && ID_BUCKET.equals(((Reference) target).getName()); } - // Creates ITE conditions for branchless S3Express variable computation. - private List createIteConditions() { - List conditions = new ArrayList<>(); - conditions.add(createIteAssignment(VAR_FIPS, Expression.getReference(ID_USE_FIPS), FIPS_SUFFIX, EMPTY_SUFFIX)); - conditions.add(createIteAssignment( - VAR_DS, - Expression.getReference(ID_USE_DUAL_STACK), - DS_SUFFIX, - EMPTY_SUFFIX)); - // Auth scheme: sigv4 when session auth disabled, sigv4-s3express otherwise - Expression sessionAuthDisabled = Coalesce.ofExpressions( - Expression.getReference(ID_DISABLE_S3EXPRESS_SESSION_AUTH), - Expression.of(false)); - conditions.add(createIteAssignment(VAR_AUTH, sessionAuthDisabled, AUTH_SIGV4, AUTH_SIGV4_S3EXPRESS)); - return conditions; - } - - // Creates an ITE-based assignment condition. - private Condition createIteAssignment(String varName, Expression condition, String trueValue, String falseValue) { - return Condition.builder() - .fn(Ite.ofStrings(condition, trueValue, falseValue)) - .result(varName) - .build(); - } - - // Rewrites an endpoint rule to use canonical S3Express URLs and auth schemes. private Rule rewriteEndpoint(EndpointRule rule) { Endpoint endpoint = rule.getEndpoint(); Expression urlExpr = endpoint.getUrl(); - // Extract the raw URL string from the expression (IFF it's a static string, rarely is anything else). String urlStr = extractUrlString(urlExpr); if (urlStr == null) { return rule; } - // Check if this is an S3Express endpoint by URL or backend property. - // Note: while `contains("s3express")` is broad and could theoretically match path/query components, - // the subsequent matchUrl() call validates the hostname pattern before any rewriting occurs. boolean isS3ExpressUrl = urlStr.contains("s3express"); boolean isS3ExpressBackend = isS3ExpressBackend(endpoint); @@ -391,11 +344,9 @@ private Rule rewriteEndpoint(EndpointRule rule) { // For URL override endpoints (backend=S3Express but URL doesn't match s3express hostname), // just canonicalize the auth scheme - no URL rewriting needed if (isS3ExpressBackend && !isS3ExpressUrl) { - // Canonicalize auth scheme to use {_s3e_auth} Map newProperties = canonicalizeAuthScheme(endpoint.getProperties()); if (newProperties == endpoint.getProperties()) { - // No changes needed return rule; } @@ -408,7 +359,6 @@ private Rule rewriteEndpoint(EndpointRule rule) { .sourceLocation(endpoint.getSourceLocation()) .build(); - // Add auth ITE condition for URL override endpoints List allConditions = new ArrayList<>(rule.getConditions()); allConditions.add(createAuthIteCondition()); @@ -425,17 +375,13 @@ private Rule rewriteEndpoint(EndpointRule rule) { rewrittenCount++; - // Rewrite the URL to use the ITE-assigned variables String newUrl = match.rewriteUrl(); - // Canonicalize auth scheme for bucket endpoints (not control plane) - // Control plane always uses sigv4, bucket endpoints vary based on DisableS3ExpressSessionAuth Map newProperties = endpoint.getProperties(); if (match instanceof BucketUrlMatchResult) { newProperties = canonicalizeAuthScheme(endpoint.getProperties()); } - // Build the new endpoint with canonicalized URL and properties Endpoint newEndpoint = Endpoint.builder() .url(Expression.of(newUrl)) .headers(endpoint.getHeaders()) @@ -443,7 +389,6 @@ private Rule rewriteEndpoint(EndpointRule rule) { .sourceLocation(endpoint.getSourceLocation()) .build(); - // Add ITE conditions: original conditions first, then ITE conditions at the end. List allConditions = new ArrayList<>(rule.getConditions()); allConditions.addAll(createIteConditions()); @@ -452,29 +397,42 @@ private Rule rewriteEndpoint(EndpointRule rule) { .endpoint(newEndpoint); } - // Checks if the endpoint has `backend` property set to "S3Express". + private List createIteConditions() { + List conditions = new ArrayList<>(); + conditions.add(createIteAssignment(VAR_FIPS, Expression.getReference(ID_USE_FIPS), FIPS_SUFFIX, EMPTY_SUFFIX)); + conditions.add(createIteAssignment(VAR_DS, Expression.getReference(ID_USE_DUAL_STACK), DS_SUFFIX, EMPTY_SUFFIX)); + Expression sessionAuthDisabled = Coalesce.ofExpressions( + Expression.getReference(ID_DISABLE_S3EXPRESS_SESSION_AUTH), + Expression.of(false)); + conditions.add(createIteAssignment(VAR_AUTH, sessionAuthDisabled, AUTH_SIGV4, AUTH_SIGV4_S3EXPRESS)); + return conditions; + } + + private Condition createIteAssignment(String varName, Expression condition, String trueValue, String falseValue) { + return Condition.builder() + .fn(Ite.ofStrings(condition, trueValue, falseValue)) + .result(varName) + .build(); + } + private boolean isS3ExpressBackend(Endpoint endpoint) { Literal backend = endpoint.getProperties().get(ID_BACKEND); if (backend == null) { return false; } - return backend.asStringLiteral() .filter(Template::isStatic) .map(t -> "S3Express".equalsIgnoreCase(t.expectLiteral())) .orElse(false); } - // Creates just the auth ITE condition for URL override endpoints. private Condition createAuthIteCondition() { - // `DisableS3ExpressSessionAuth` is nullable, so we need to coalesce it to have a false default. Fix upstream? Expression isSessionAuthDisabled = Coalesce.ofExpressions( Expression.getReference(ID_DISABLE_S3EXPRESS_SESSION_AUTH), Expression.of(false)); return createIteAssignment(VAR_AUTH, isSessionAuthDisabled, AUTH_SIGV4, AUTH_SIGV4_S3EXPRESS); } - // Canonicalizes the authScheme name in endpoint properties to use the ITE variable. private Map canonicalizeAuthScheme(Map properties) { Literal authSchemes = properties.get(ID_AUTH_SCHEMES); if (authSchemes == null) { @@ -486,30 +444,25 @@ private Map canonicalizeAuthScheme(Map return properties; } - // Rewrite each auth scheme's name field List newSchemes = new ArrayList<>(); for (Literal scheme : schemes) { Map record = scheme.asRecordLiteral().orElse(null); if (record == null) { - // Auth is always a record, but maybe that changes in the future, so pass it through. newSchemes.add(scheme); continue; } Literal nameLiteral = record.get(ID_NAME); if (nameLiteral == null) { - // "name" should always be set, but pass through if not. newSchemes.add(scheme); continue; } - // Only transform string literals we recognize. String name = nameLiteral.asStringLiteral() .filter(Template::isStatic) .map(Template::expectLiteral) .orElse(null); - // Only rewrite if it's one of the S3Express auth schemes if (AUTH_SIGV4.equals(name) || AUTH_SIGV4_S3EXPRESS.equals(name)) { Map newRecord = new LinkedHashMap<>(record); newRecord.put(ID_NAME, AUTH_NAME_LITERAL); @@ -524,63 +477,20 @@ private Map canonicalizeAuthScheme(Map return newProperties; } - // Extracts the raw URL string from a URL expression. private String extractUrlString(Expression urlExpr) { return urlExpr.toNode().asStringNode().map(StringNode::getValue).orElse(null); } - // Matches an S3Express URL and returns the pattern match info. Tries to match in most specific order. private UrlMatchResult matchUrl(String url) { - Matcher m; - - // First try control plane patterns (no AZ) since these are more specific - m = S3EXPRESS_CONTROL_FIPS_DS.matcher(url); - if (m.find()) { - return new ControlPlaneUrlMatchResult(url, m); - } - - m = S3EXPRESS_CONTROL_FIPS.matcher(url); - if (m.find()) { - return new ControlPlaneUrlMatchResult(url, m); - } - - m = S3EXPRESS_CONTROL_DS.matcher(url); - if (m.find()) { - return new ControlPlaneUrlMatchResult(url, m); - } - - m = S3EXPRESS_CONTROL_PLAIN.matcher(url); - if (m.find()) { - return new ControlPlaneUrlMatchResult(url, m); - } - - // Next, try bucket endpoint patterns (with AZ) - m = S3EXPRESS_FIPS_DS.matcher(url); - if (m.find()) { - return new BucketUrlMatchResult(url, m); - } - - m = S3EXPRESS_FIPS.matcher(url); - if (m.find()) { - return new BucketUrlMatchResult(url, m); - } - - m = S3EXPRESS_DS.matcher(url); - if (m.find()) { - return new BucketUrlMatchResult(url, m); - } - - m = S3EXPRESS_PLAIN.matcher(url); - if (m.find()) { - return new BucketUrlMatchResult(url, m); + for (UrlPatternMatcher matcher : URL_PATTERNS) { + UrlMatchResult result = matcher.match(url); + if (result != null) { + return result; + } } - return null; } - /** - * Result of matching an S3Express URL pattern. - */ private abstract static class UrlMatchResult { protected final String prefix; @@ -591,9 +501,6 @@ private abstract static class UrlMatchResult { abstract String rewriteUrl(); } - /** - * Match result for bucket endpoints (with AZ): {prefix}s3express{fips}-{AZ}{ds}.{region} - */ private static final class BucketUrlMatchResult extends UrlMatchResult { private final String s3express; private final String az; @@ -612,9 +519,6 @@ String rewriteUrl() { } } - /** - * Match result for control plane endpoints (no AZ): {prefix}s3express-control{fips}{ds}.{region} - */ private static final class ControlPlaneUrlMatchResult extends UrlMatchResult { private final String s3expressControl; private final String regionSuffix; @@ -630,4 +534,25 @@ String rewriteUrl() { return String.format("%s%s{%s}{%s}.%s", prefix, s3expressControl, VAR_FIPS, VAR_DS, regionSuffix); } } + + private static final class UrlPatternMatcher { + private final Pattern pattern; + private final boolean isBucketPattern; + + UrlPatternMatcher(String regex, boolean isBucketPattern) { + this.pattern = Pattern.compile(regex); + this.isBucketPattern = isBucketPattern; + } + + UrlMatchResult match(String url) { + Matcher m = pattern.matcher(url); + if (!m.find()) { + return null; + } else if (isBucketPattern) { + return new BucketUrlMatchResult(url, m); + } else { + return new ControlPlaneUrlMatchResult(url, m); + } + } + } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/LibraryFunction.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/LibraryFunction.java index c5daeab89d1..dd543b5dba4 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/LibraryFunction.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/LibraryFunction.java @@ -20,6 +20,7 @@ import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; import software.amazon.smithy.rulesengine.language.syntax.expressions.Reference; import software.amazon.smithy.rulesengine.language.syntax.expressions.Template; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.StringLiteral; import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; import software.amazon.smithy.utils.SmithyUnstableApi; @@ -231,18 +232,29 @@ public String toString() { * @return true if arguments should be swapped */ protected static boolean shouldSwapArgs(Expression arg0, Expression arg1) { - boolean arg0IsRef = isReference(arg0); - boolean arg1IsRef = isReference(arg1); + boolean arg0IsLiteral = isStaticLiteral(arg0); + boolean arg1IsLiteral = isStaticLiteral(arg1); - // Always put References before literals to make things consistent - if (arg0IsRef != arg1IsRef) { - return !arg0IsRef; // Swap if arg0 is literal and arg1 is reference + // Always put non-literals (expressions) before literals to make things consistent + if (arg0IsLiteral != arg1IsLiteral) { + return arg0IsLiteral; // Swap if arg0 is literal and arg1 is not } // Both same type, use string comparison for deterministic order return arg0.toString().compareTo(arg1.toString()) > 0; } + /** + * Returns true if the expression is a static literal (constant value). + * Dynamic string literals (templates with variables) are not considered static. + */ + private static boolean isStaticLiteral(Expression arg) { + if (arg instanceof StringLiteral) { + return ((StringLiteral) arg).value().isStatic(); + } + return arg instanceof Literal; + } + /** * Strips single-variable template wrappers if present. * Converts "{varName}" to just varName reference. @@ -264,13 +276,4 @@ static Expression stripSingleVariableTemplate(Expression expr) { return expr; } - private static boolean isReference(Expression arg) { - if (arg instanceof Reference) { - return true; - } else if (arg instanceof StringLiteral) { - StringLiteral s = (StringLiteral) arg; - return !s.value().isStatic(); - } - return false; - } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilder.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilder.java index 1cadcce2525..95933785140 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilder.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilder.java @@ -15,6 +15,7 @@ import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; import software.amazon.smithy.rulesengine.language.syntax.expressions.Reference; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.BooleanEquals; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.IsSet; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.LibraryFunction; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Not; import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; @@ -41,6 +42,9 @@ public final class CfgBuilder { private final Map resultCache = new HashMap<>(); private final Map resultNodeCache = new HashMap<>(); + // Track function expressions that have bindings (for isSet consolidation) + private final Map functionBindings = new HashMap<>(); + public CfgBuilder(EndpointRuleSet ruleSet) { // Apply SSA transform to ensure globally unique variable names this.ruleSet = SsaTransform.transform(ruleSet); @@ -115,9 +119,25 @@ public ConditionReference createConditionReference(Condition condition) { negated = !negated; } + // Consolidate isSet(f(x)) with existing v = f(x) bindings + canonical = consolidateIsSetWithBinding(canonical); + + // Deep-copy via serialization to get fresh Expression objects. + // This avoids sharing expressions that may have cached types from + // being type-checked in different scopes during EndpointRuleSet.build(). + canonical = Condition.fromNode(canonical.toNode()); + + // Track bindings for future isSet consolidation + if (canonical.getResult().isPresent()) { + String fnKey = canonical.getFunction().toString(); + functionBindings.putIfAbsent(fnKey, canonical); + } + ConditionReference reference = new ConditionReference(canonical, negated); conditionToReference.put(condition, reference); + // Also cache the canonical form so equivalent conditions from different branches + // that might have different original objects will still hit the cache. if (!negated && !condition.equals(canonical)) { conditionToReference.put(canonical, reference); } @@ -129,6 +149,32 @@ private Rule intern(Rule rule) { return resultCache.computeIfAbsent(canonicalizeResult(rule), k -> k); } + /** + * Consolidates {@code isSet(f(x))} with an existing {@code v = f(x)} binding. + * + *

This catches patterns that tree-level optimization cannot handle. Specifically, when the tree + * contains {@code not(isSet(f(x)))} in one branch and {@code v = f(x)} in another branch. + * Tree-level transforms can't handle this because they don't have visibility across sibling branches. + * The CFG builder's global {@code functionBindings} map provides this cross-branch visibility. + * + * @see SsaTransform for the full explanation of why both tree-level and CFG-level optimization are needed + */ + private Condition consolidateIsSetWithBinding(Condition condition) { + if (!(condition.getFunction() instanceof IsSet) || condition.getResult().isPresent()) { + return condition; + } + Expression inner = condition.getFunction().getArguments().get(0); + if (!(inner instanceof LibraryFunction)) { + return condition; + } + String fnKey = inner.toString(); + Condition existingBinding = functionBindings.get(fnKey); + if (existingBinding != null) { + return existingBinding; + } + return condition; + } + private Rule canonicalizeResult(Rule rule) { return rule == null ? null : rule.withConditions(Collections.emptyList()); } @@ -139,19 +185,23 @@ private Condition canonicalizeBooleanEquals(Condition condition) { } List args = condition.getFunction().getArguments(); - if (args.size() != 2 || !(args.get(0) instanceof Reference) || !(args.get(1) instanceof Literal)) { + if (args.size() != 2) { + return condition; + } + + // After canonicalization, literals should be in arg1 position + // Check if arg1 is a boolean literal with value false + if (!(args.get(1) instanceof Literal)) { return condition; } - Reference ref = (Reference) args.get(0); Boolean literalValue = ((Literal) args.get(1)).asBooleanLiteral().orElse(null); - if (literalValue != null && !literalValue && ruleSet != null) { - String varName = ref.getName().toString(); - Optional param = ruleSet.getParameters().get(Identifier.of(varName)); - if (param.isPresent() && param.get().getDefault().isPresent()) { - return condition.toBuilder().fn(BooleanEquals.ofExpressions(ref, true)).build(); - } + // Normalize booleanEquals(X, false) to booleanEquals(X, true) with negation + if (literalValue != null && !literalValue) { + return condition.toBuilder() + .fn(BooleanEquals.ofExpressions(args.get(0), true)) + .build(); } return condition; diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CoalesceTransform.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CoalesceTransform.java index 365ebe3e77c..17196dbc788 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CoalesceTransform.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CoalesceTransform.java @@ -39,14 +39,13 @@ final class CoalesceTransform { private int cacheHits = 0; private int skippedNoZeroValue = 0; private int skippedMultipleUses = 0; - private final Set skippedRecordTypes = new HashSet<>(); static EndpointRuleSet transform(EndpointRuleSet ruleSet) { CoalesceTransform transform = new CoalesceTransform(); List transformedRules = new ArrayList<>(); - for (int i = 0; i < ruleSet.getRules().size(); i++) { - transformedRules.add(transform.transformRule(ruleSet.getRules().get(i), "root/rule[" + i + "]")); + for (Rule rule : ruleSet.getRules()) { + transformedRules.add(transform.transformRule(rule)); } if (LOGGER.isLoggable(Level.INFO)) { @@ -65,7 +64,7 @@ static EndpointRuleSet transform(EndpointRuleSet ruleSet) { .build(); } - private Rule transformRule(Rule rule, String rulePath) { + private Rule transformRule(Rule rule) { // Count local usage for THIS rule's conditions Map localVarUsage = new HashMap<>(); for (Condition condition : rule.getConditions()) { @@ -82,10 +81,14 @@ private Rule transformRule(Rule rule, String rulePath) { if (rule instanceof TreeRule) { TreeRule treeRule = (TreeRule) rule; + List transformedNested = new ArrayList<>(); + for (Rule nested : treeRule.getRules()) { + transformedNested.add(transformRule(nested)); + } return TreeRule.builder() .description(rule.getDocumentation().orElse(null)) .conditions(transformedConditions) - .treeRule(TreeRewriter.transformNestedRules(treeRule, rulePath, this::transformRule)); + .treeRule(transformedNested); } // CoalesceTransform only modifies conditions, not endpoints/errors @@ -144,7 +147,6 @@ private boolean canCoalesce(String var, Condition bind, Condition use, Map replacements = new HashMap<>(); replacements.put(var, coalesced); - Expression replaced = TreeRewriter.forReplacements(replacements).rewrite(useExpr); + Expression replaced = TreeMapper.newReferenceReplacingMapper(replacements).expression(useExpr); LibraryFunction canonicalized = ((LibraryFunction) replaced).canonicalize(); Condition.Builder builder = Condition.builder().fn(canonicalized); diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/DeadStoreEliminationTransform.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/DeadStoreEliminationTransform.java new file mode 100644 index 00000000000..58cb097ccf6 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/DeadStoreEliminationTransform.java @@ -0,0 +1,51 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.cfg; + +import java.util.logging.Logger; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; + +/** + * Removes bindings from conditions when the bound variable is never used. + * + *

As part of the optimization pipeline, we rewrite some expressions to create a + * variable binding in case the expressions can be consolidated with + * VariableConsolidationTransform. When they can't, it would leave a dangling binding + * that isn't used, so this transform rewrites those conditions to not create + * these dead stores. + */ +final class DeadStoreEliminationTransform extends TreeMapper { + private static final Logger LOGGER = Logger.getLogger(DeadStoreEliminationTransform.class.getName()); + + private int eliminated = 0; + private final VariableAnalysis analysis; + + private DeadStoreEliminationTransform(VariableAnalysis analysis) { + this.analysis = analysis; + } + + static EndpointRuleSet transform(EndpointRuleSet ruleSet) { + DeadStoreEliminationTransform t = new DeadStoreEliminationTransform(VariableAnalysis.analyze(ruleSet)); + EndpointRuleSet result = t.endpointRuleSet(ruleSet); + if (t.eliminated > 0) { + LOGGER.info(() -> "Dead store elimination: " + t.eliminated + " bindings removed"); + } + return result; + } + + @Override + public Condition condition(Rule rule, Condition cond) { + if (cond.getResult().isPresent()) { + String varName = cond.getResult().get().toString(); + if (analysis.getReferenceCount(varName) == 0) { + eliminated++; + return Condition.builder().fn(cond.getFunction()).build(); + } + } + return cond; + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/IsSetBooleanCoalesceTransform.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/IsSetBooleanCoalesceTransform.java new file mode 100644 index 00000000000..de855ed8618 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/IsSetBooleanCoalesceTransform.java @@ -0,0 +1,144 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.cfg; + +import java.util.ArrayList; +import java.util.List; +import java.util.logging.Logger; +import software.amazon.smithy.model.node.ArrayNode; +import software.amazon.smithy.model.node.ObjectNode; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Reference; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.BooleanEquals; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Coalesce; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.IsSet; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.LibraryFunction; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; + +/** + * Coalesces consecutive isSet + booleanEquals patterns into a single coalesced check. + * + *

This transform identifies patterns where: + *

    + *
  • {@code isSet(X)} is immediately followed by {@code booleanEquals(X, true)}
  • + *
  • {@code isSet(X)} is immediately followed by {@code booleanEquals(X, false)}
  • + *
+ * + *

These patterns are replaced with: + *

    + *
  • {@code booleanEquals(coalesce(X, false), true)} - equivalent to "X is set and true"
  • + *
  • {@code booleanEquals(coalesce(X, true), false)} - equivalent to "X is set and false"
  • + *
+ * + *

This reduces the number of conditions in the BDD, improving both space and potentially node count. + */ +final class IsSetBooleanCoalesceTransform extends TreeMapper { + private static final Logger LOGGER = Logger.getLogger(IsSetBooleanCoalesceTransform.class.getName()); + + private int coalescedCount = 0; + + static EndpointRuleSet transform(EndpointRuleSet ruleSet) { + IsSetBooleanCoalesceTransform t = new IsSetBooleanCoalesceTransform(); + List transformedRules = new ArrayList<>(); + for (Rule rule : ruleSet.getRules()) { + transformedRules.add(t.rule(rule)); + } + + if (t.coalescedCount > 0) { + LOGGER.info(() -> String.format("IsSet+boolean coalesce: %d patterns collapsed", t.coalescedCount)); + } + + // Build the node representation manually to avoid type-checking, then deserialize + // to get fresh Expression objects without cached types. + ObjectNode.Builder builder = ruleSet.toNode().expectObjectNode().toBuilder(); + ArrayNode.Builder rulesBuilder = ArrayNode.builder(); + for (Rule rule : transformedRules) { + rulesBuilder.withValue(rule.toNode()); + } + builder.withMember("rules", rulesBuilder.build()); + + return EndpointRuleSet.fromNode(builder.build()); + } + + /** + * Overrides conditions processing to look at pairs of consecutive conditions. + * This is necessary because we need to merge isSet(X) + booleanEquals(X, val) patterns. + */ + @Override + public List conditions(Rule rule, List conditions) { + List result = new ArrayList<>(); + int size = conditions.size(); + + for (int i = 0; i < size; i++) { + Condition current = conditions.get(i); + // Check for isSet(X) pattern + if (i + 1 < size && !current.getResult().isPresent() && current.getFunction() instanceof IsSet) { + Expression isSetArg = current.getFunction().getArguments().get(0); + if (isSetArg instanceof Reference) { + Condition next = conditions.get(i + 1); + + // Check if next is booleanEquals(X, true/false) + Condition coalesced = tryCoalesce((Reference) isSetArg, next); + if (coalesced != null) { + result.add(coalesced); + coalescedCount++; + i++; // Skip the next condition since we ate it + continue; + } + } + } + + result.add(current); + } + + return result; + } + + private Condition tryCoalesce(Reference isSetRef, Condition next) { + if (next.getResult().isPresent()) { + return null; // Can't coalesce if next has a binding + } + + LibraryFunction fn = next.getFunction(); + if (!(fn instanceof BooleanEquals)) { + return null; + } + + List args = fn.getArguments(); + Expression arg1 = args.get(0); + Expression arg2 = args.get(1); + + // Ensure the first argument is a reference to the isSet(X) condition + if (!(arg1 instanceof Reference)) { + return null; + } + Reference boolRef = (Reference) arg1; + if (!boolRef.getName().equals(isSetRef.getName())) { + return null; + } + + // Ensure the second argument is a boolean literal + if (!(arg2 instanceof Literal)) { + return null; + } + Boolean literalValue = ((Literal) arg2).asBooleanLiteral().orElse(null); + if (literalValue == null) { + return null; + } + + // Create a coalesced condition that sets the default to the opposite of the literal value: + // isSet(X) && booleanEquals(X, true) -> booleanEquals(coalesce(X, false), true) + // isSet(X) && booleanEquals(X, false) -> booleanEquals(coalesce(X, true), false) + boolean defaultValue = !literalValue; + + // Create a fresh Reference to avoid sharing cached type information with other conditions + Reference freshRef = Expression.getReference(isSetRef.getName()); + Expression coalesced = Coalesce.ofExpressions(freshRef, Expression.of(defaultValue)); + return Condition.builder().fn(BooleanEquals.ofExpressions(coalesced, literalValue)).build(); + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransform.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransform.java index f1da1f16520..62659aa46af 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransform.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransform.java @@ -5,114 +5,104 @@ package software.amazon.smithy.rulesengine.logic.cfg; import java.util.ArrayDeque; -import java.util.ArrayList; import java.util.Deque; import java.util.HashMap; -import java.util.HashSet; -import java.util.IdentityHashMap; -import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Set; -import software.amazon.smithy.rulesengine.language.Endpoint; import software.amazon.smithy.rulesengine.language.EndpointRuleSet; import software.amazon.smithy.rulesengine.language.syntax.Identifier; import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; import software.amazon.smithy.rulesengine.language.syntax.expressions.Reference; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.LibraryFunction; import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; -import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; -import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; -import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; /** - * Transforms a decision tree into Static Single Assignment (SSA) form. + * Transforms a decision tree into Static Single Assignment (SSA) form and orchestrates the pre-BDD optimization + * pipeline (see transform()). * - *

This transformation ensures that each variable is assigned exactly once by renaming variables when they are - * reassigned in different parts of the tree. For example, if variable "x" is assigned in multiple branches, they - * become "x_ssa_1", "x_ssa_2", "x_ssa_3", etc. Without this transform, the BDD compilation would confuse divergent - * paths that have the same variable name. + *

Why Tree-Level AND CFG-Level Optimization?

* - *

Note that this transform is only applied when the reassignment is done using different - * arguments than previously seen assignments of the same variable name. + *

Tree-level transforms provide the bulk of optimization, but {@link CfgBuilder} performs additional + * consolidation via {@code consolidateIsSetWithBinding}. This catches cross-branch patterns that tree-level + * transforms cannot see. + * + *

Specifically, when one branch contains {@code not(isSet(f(x)))} and another branch contains + * {@code v = f(x)}, tree-level transforms can't consolidate them because: + *

    + *
  • {@link SyntheticBindingTransform} doesn't see the inner {@code isSet} - it checks the outer function + * which is {@code Not}, not {@code IsSet}
  • + *
  • Tree transforms don't have visibility across sibling branches
  • + *
+ * + *

During CFG construction, {@code CfgBuilder#isNegationWrapper} unwraps the {@code Not} to get + * {@code isSet(f(x))}, and {@code consolidateIsSetWithBinding} can then consolidate it with the existing + * binding from the other branch using its global {@code functionBindings} map. + * + *

SSA Renaming

+ * + *

The SSA portion ensures each variable is assigned exactly once by renaming variables when they are + * reassigned in different parts of the tree. For example, if variable "x" is assigned in multiple branches + * with different expressions, they become "x_ssa_1", "x_ssa_2", etc. Without this, BDD compilation would + * incorrectly share nodes for divergent paths that happen to use the same variable name. + * + *

SSA renaming is only applied when reassignment uses different arguments than previously seen assignments. + * + * @see CfgBuilder#createConditionReference for the CFG-level consolidation that catches cross-branch patterns */ -final class SsaTransform { +final class SsaTransform extends TreeMapper { private final Deque> scopeStack = new ArrayDeque<>(); - private final Map rewrittenConditions = new IdentityHashMap<>(); - private final Map rewrittenRules = new IdentityHashMap<>(); - private final VariableAnalysis variableAnalysis; - private final TreeRewriter referenceRewriter; - - private SsaTransform(VariableAnalysis variableAnalysis) { - scopeStack.push(new HashMap<>()); - this.variableAnalysis = variableAnalysis; - this.referenceRewriter = new TreeRewriter(this::referenceRewriter, this::needsRewriting); - } + private final VariableAnalysis analysis; - private Expression referenceRewriter(Reference ref) { - String originalName = ref.getName().toString(); - String uniqueName = resolveReference(originalName); - return Expression.getReference(Identifier.of(uniqueName)); + private SsaTransform(VariableAnalysis analysis) { + this.analysis = analysis; + // Seed initial scope with input parameters. + Map initialScope = new HashMap<>(); + for (String param : analysis.getInputParams()) { + initialScope.put(param, param); + } + scopeStack.push(initialScope); } static EndpointRuleSet transform(EndpointRuleSet ruleSet) { + // Collapse isSet(X) + booleanEquals(X, true/false) into coalesced checks + ruleSet = IsSetBooleanCoalesceTransform.transform(ruleSet); + // Assign synthetic bindings to enable variable consolidation + ruleSet = SyntheticBindingTransform.transform(ruleSet); + // Consolidate variables and eliminate redundant bindings ruleSet = VariableConsolidationTransform.transform(ruleSet); + // Remove bindings that are never used (before coalescing inlines them) + ruleSet = DeadStoreEliminationTransform.transform(ruleSet); + // Coalesces bind-then-use patterns to reduce condition count and branching. ruleSet = CoalesceTransform.transform(ruleSet); - VariableAnalysis variableAnalysis = VariableAnalysis.analyze(ruleSet); - SsaTransform ssaTransform = new SsaTransform(variableAnalysis); - - List rewrittenRules = new ArrayList<>(ruleSet.getRules().size()); - for (Rule original : ruleSet.getRules()) { - rewrittenRules.add(ssaTransform.processRule(original)); - } - - return EndpointRuleSet.builder() - .parameters(ruleSet.getParameters()) - .rules(rewrittenRules) - .version(ruleSet.getVersion()) - .build(); + // Do an SSA transform so divergent paths in the BDD remain logically divergent and thereby correct. + return new SsaTransform(VariableAnalysis.analyze(ruleSet)).endpointRuleSet(ruleSet); } - private Rule processRule(Rule rule) { - enterScope(); - Rule rewrittenRule = rewriteRule(rule); - exitScope(); - return rewrittenRule; - } - - private void enterScope() { - scopeStack.push(new HashMap<>(scopeStack.peek())); - } - - private void exitScope() { - if (scopeStack.size() <= 1) { - throw new IllegalStateException("Cannot exit global scope"); + @Override + public Rule rule(Rule r) { + scopeStack.push(new HashMap<>(peekScope())); + try { + return super.rule(r); + } finally { + scopeStack.pop(); } - scopeStack.pop(); } - private Condition rewriteCondition(Condition condition) { - boolean hasBinding = condition.getResult().isPresent(); - - if (!hasBinding) { - Condition cached = rewrittenConditions.get(condition); - if (cached != null) { - return cached; - } - } - + @Override + public Condition condition(Rule rule, Condition condition) { LibraryFunction fn = condition.getFunction(); - Set rewritableRefs = filterOutInputParameters(fn.getReferences()); String uniqueBindingName = null; boolean needsUniqueBinding = false; - if (hasBinding) { + if (condition.getResult().isPresent()) { String varName = condition.getResult().get().toString(); // Only need SSA rename if variable has multiple bindings - if (variableAnalysis.hasMultipleBindings(varName)) { - Map expressionMap = variableAnalysis.getExpressionMappings().get(varName); + if (analysis.hasMultipleBindings(varName)) { + Map expressionMap = analysis.getExpressionMappings().get(varName); if (expressionMap != null) { uniqueBindingName = expressionMap.get(fn.toString()); needsUniqueBinding = uniqueBindingName != null && !uniqueBindingName.equals(varName); @@ -120,158 +110,59 @@ private Condition rewriteCondition(Condition condition) { } } - if (!needsRewriting(rewritableRefs) && !needsUniqueBinding) { - if (!hasBinding) { - rewrittenConditions.put(condition, condition); - } + if (doesNotNeedRewriting(fn.getReferences()) && !needsUniqueBinding) { return condition; } - LibraryFunction rewrittenExpr = (LibraryFunction) referenceRewriter.rewrite(fn); - boolean exprChanged = rewrittenExpr != fn; + LibraryFunction rewrittenFn = libraryFunction(fn); + boolean fnChanged = rewrittenFn != fn; - Condition rewritten; - if (hasBinding && uniqueBindingName != null) { - scopeStack.peek().put(condition.getResult().get().toString(), uniqueBindingName); - if (needsUniqueBinding || exprChanged) { - rewritten = condition.toBuilder().fn(rewrittenExpr).result(Identifier.of(uniqueBindingName)).build(); - } else { - rewritten = condition; + if (condition.getResult().isPresent() && uniqueBindingName != null) { + bindVariable(condition.getResult().get().toString(), uniqueBindingName); + if (needsUniqueBinding || fnChanged) { + return condition.toBuilder().fn(rewrittenFn).result(Identifier.of(uniqueBindingName)).build(); } - } else if (exprChanged) { - rewritten = condition.toBuilder().fn(rewrittenExpr).build(); - } else { - rewritten = condition; - } - - if (!hasBinding) { - rewrittenConditions.put(condition, rewritten); + } else if (fnChanged) { + return condition.toBuilder().fn(rewrittenFn).build(); } - return rewritten; + return condition; } - private Set filterOutInputParameters(Set references) { - if (references.isEmpty() || variableAnalysis.getInputParams().isEmpty()) { - return references; - } - - Set filtered = new HashSet<>(references); - filtered.removeAll(variableAnalysis.getInputParams()); - return filtered; + Map peekScope() { + return Objects.requireNonNull(scopeStack.peek(), "Scope stack is empty"); } - private boolean needsRewriting(Set references) { - if (references.isEmpty()) { - return false; - } - - Map currentScope = scopeStack.peek(); - for (String ref : references) { - String mapped = currentScope.get(ref); - if (mapped != null && !mapped.equals(ref)) { - return true; - } - } - return false; - } - - private boolean needsRewriting(Expression expression) { - return needsRewriting(filterOutInputParameters(expression.getReferences())); - } - - private Rule rewriteRule(Rule rule) { - Rule cached = rewrittenRules.get(rule); - if (cached != null) { - return cached; - } - - List rewrittenConditions = rewriteConditions(rule.getConditions()); - boolean conditionsChanged = !rewrittenConditions.equals(rule.getConditions()); - - Rule result; - if (rule instanceof EndpointRule) { - result = rewriteEndpointRule((EndpointRule) rule, rewrittenConditions, conditionsChanged); - } else if (rule instanceof ErrorRule) { - result = rewriteErrorRule((ErrorRule) rule, rewrittenConditions, conditionsChanged); - } else if (rule instanceof TreeRule) { - result = rewriteTreeRule((TreeRule) rule, rewrittenConditions, conditionsChanged); - } else if (conditionsChanged) { - throw new UnsupportedOperationException("Cannot change rule: " + rule); - } else { - result = rule; - } - - rewrittenRules.put(rule, result); - return result; - } - - private List rewriteConditions(List conditions) { - List rewritten = new ArrayList<>(conditions.size()); - for (Condition condition : conditions) { - rewritten.add(rewriteCondition(condition)); - } - return rewritten; + @Override + public Expression expression(Expression expression) { + return doesNotNeedRewriting(expression.getReferences()) ? expression : super.expression(expression); } - private Rule rewriteEndpointRule( - EndpointRule rule, - List rewrittenConditions, - boolean conditionsChanged - ) { - Endpoint rewrittenEndpoint = referenceRewriter.rewriteEndpoint(rule.getEndpoint()); - - if (conditionsChanged || rewrittenEndpoint != rule.getEndpoint()) { - return EndpointRule.builder() - .description(rule.getDocumentation().orElse(null)) - .conditions(rewrittenConditions) - .endpoint(rewrittenEndpoint); + @Override + public Reference reference(Reference ref) { + String originalName = ref.getName().toString(); + String uniqueName = peekScope().getOrDefault(originalName, originalName); + if (uniqueName.equals(originalName)) { + return ref; } - - return rule; + return Expression.getReference(Identifier.of(uniqueName)); } - private Rule rewriteErrorRule(ErrorRule rule, List rewrittenConditions, boolean conditionsChanged) { - Expression rewrittenError = referenceRewriter.rewrite(rule.getError()); - - if (conditionsChanged || rewrittenError != rule.getError()) { - return ErrorRule.builder() - .description(rule.getDocumentation().orElse(null)) - .conditions(rewrittenConditions) - .error(rewrittenError); + private void bindVariable(String oldName, String newName) { + Map scope = peekScope(); + String existing = scope.put(oldName, newName); + if (existing != null && !existing.equals(oldName)) { + throw new IllegalStateException("Cannot shadow variable: " + oldName + ", conflicts with " + existing); } - - return rule; } - private Rule rewriteTreeRule(TreeRule rule, List rewrittenConditions, boolean conditionsChanged) { - List rewrittenNestedRules = new ArrayList<>(); - boolean nestedChanged = false; - - for (Rule nestedRule : rule.getRules()) { - enterScope(); - Rule rewritten = rewriteRule(nestedRule); - rewrittenNestedRules.add(rewritten); - if (rewritten != nestedRule) { - nestedChanged = true; + private boolean doesNotNeedRewriting(Set references) { + Map scope = peekScope(); + for (String ref : references) { + if (!ref.equals(scope.getOrDefault(ref, ref))) { + return false; } - exitScope(); - } - - if (conditionsChanged || nestedChanged) { - return TreeRule.builder() - .description(rule.getDocumentation().orElse(null)) - .conditions(rewrittenConditions) - .treeRule(rewrittenNestedRules); } - - return rule; - } - - private String resolveReference(String originalName) { - // Input parameters are never rewritten - return variableAnalysis.getInputParams().contains(originalName) - ? originalName - : scopeStack.peek().getOrDefault(originalName, originalName); + return true; } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/SyntheticBindingTransform.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/SyntheticBindingTransform.java new file mode 100644 index 00000000000..264c32db880 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/SyntheticBindingTransform.java @@ -0,0 +1,76 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.cfg; + +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.evaluation.type.Type; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.IsSet; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.LibraryFunction; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; + +import java.util.logging.Logger; + +/** + * Assigns synthetic bindings to conditions that could benefit from variable consolidation. + * + *

This transform handles two cases: + *

    + *
  1. {@code isSet(f(x))} is rewritten to {@code _synthetic_N = f(x)}, unwrapping the isSet
  2. + *
  3. Bare function calls like {@code f(x)} become {@code _synthetic_N = f(x)}
  4. + *
+ * + *

This enables {@link VariableConsolidationTransform} to consolidate these synthetic bindings + * with real bindings like {@code url = f(x)}, eliminating redundant function calls. If no consolidation later + * occurs, then {@link DeadStoreEliminationTransform} can remove the unnecessary synthetic bindings. + */ +final class SyntheticBindingTransform extends TreeMapper { + private static final Logger LOGGER = Logger.getLogger(SyntheticBindingTransform.class.getName()); + + private int syntheticCounter = 0; + private int transformedCount = 0; + + static EndpointRuleSet transform(EndpointRuleSet ruleSet) { + SyntheticBindingTransform t = new SyntheticBindingTransform(); + EndpointRuleSet result = t.endpointRuleSet(ruleSet); + if (t.transformedCount > 0) { + LOGGER.info(() -> String.format("Synthetic binding: %d conditions transformed", t.transformedCount)); + } + return result; + } + + @Override + public Condition condition(Rule rule, Condition cond) { + // If it already has a binding, then nothing to do + if (cond.getResult().isPresent()) { + return cond; + } + + LibraryFunction fn = cond.getFunction(); + + // isSet(f(x)) where f(x) is a function call - unwrap and bind + if (fn instanceof IsSet) { + Expression inner = fn.getArguments().get(0); + if (inner instanceof LibraryFunction) { + transformedCount++; + return cond.toBuilder().fn((LibraryFunction) inner).result(createSyntheticName(fn)).build(); + } + return cond; + } + + // Bare function call that doesn't return a boolean? add binding + if (fn.getFunctionDefinition().getReturnType() != Type.booleanType()) { + transformedCount++; + return cond.toBuilder().result(createSyntheticName(fn)).build(); + } + + return cond; + } + + private String createSyntheticName(LibraryFunction fn) { + return "_synthetic_" + fn.getName() + "_" + (syntheticCounter++); + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/TreeRewriter.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/TreeMapper.java similarity index 50% rename from smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/TreeRewriter.java rename to smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/TreeMapper.java index 7690a1cffd9..ed46642de8b 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/TreeRewriter.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/TreeMapper.java @@ -8,11 +8,9 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.function.BiFunction; -import java.util.function.Function; -import java.util.function.Predicate; import software.amazon.smithy.model.node.Node; import software.amazon.smithy.rulesengine.language.Endpoint; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; import software.amazon.smithy.rulesengine.language.syntax.Identifier; import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; import software.amazon.smithy.rulesengine.language.syntax.expressions.Reference; @@ -23,92 +21,276 @@ import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.RecordLiteral; import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.StringLiteral; import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.TupleLiteral; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; -/** - * Utility for rewriting references within expression trees. - */ -final class TreeRewriter { - // A no-op rewriter that returns expressions unchanged. - static final TreeRewriter IDENTITY = new TreeRewriter(ref -> ref, expr -> false); - - private final Function referenceTransformer; - private final Predicate shouldRewrite; - +public abstract class TreeMapper { /** - * Creates a new reference rewriter. - * - * @param referenceTransformer function to transform references - * @param shouldRewrite predicate to determine if an expression needs rewriting + * A no-op mapper that returns expressions unchanged. */ - TreeRewriter( - Function referenceTransformer, - Predicate shouldRewrite - ) { - this.referenceTransformer = referenceTransformer; - this.shouldRewrite = shouldRewrite; - } + private static final TreeMapper IDENTITY = new TreeMapper() {}; /** - * Creates a simple rewriter that replaces specific references. + * Creates a mapper that replaces specific references. * * @param replacements map of variable names to replacement expressions - * @return a reference rewriter that performs the replacements + * @return a mapper that performs the replacements */ - static TreeRewriter forReplacements(Map replacements) { + public static TreeMapper newReferenceReplacingMapper(Map replacements) { if (replacements.isEmpty()) { return IDENTITY; } - return new TreeRewriter( - ref -> replacements.getOrDefault(ref.getName().toString(), ref), - expr -> expr.getReferences().stream().anyMatch(replacements::containsKey)); + + return new TreeMapper() { + @Override + public Expression expression(Expression expression) { + // Handle Reference -> non-Reference replacement here since reference() must return Reference + if (expression instanceof Reference) { + Expression replacement = replacements.get(((Reference) expression).getName().toString()); + if (replacement != null) { + return replacement; + } + return expression; + } + + // Only do deeper replacements if the expression references a relevant variable. + for (String ref : expression.getReferences()) { + if (replacements.containsKey(ref)) { + return super.expression(expression); + } + } + + return expression; + } + + @Override + public Reference reference(Reference ref) { + // This is only called from super.expression() for nested references + Expression replacement = replacements.get(ref.getName().toString()); + if (replacement instanceof Reference) { + return (Reference) replacement; + } + // Non-reference replacements are handled in expression() above + return ref; + } + }; } - static List transformNestedRules( - TreeRule tree, - String parentPath, - BiFunction transformer - ) { - List result = new ArrayList<>(); - for (int i = 0; i < tree.getRules().size(); i++) { - Rule transformed = transformer.apply( - tree.getRules().get(i), - parentPath + "/tree/rule[" + i + "]"); - if (transformed != null) { - result.add(transformed); + public EndpointRuleSet endpointRuleSet(EndpointRuleSet ruleSet) { + List transformed = new ArrayList<>(); + for (Rule r : ruleSet.getRules()) { + Rule mapped = rule(r); + if (mapped != null) { + transformed.add(mapped); } } + return ruleSet.toBuilder().rules(transformed).build(); + } + + public Rule rule(Rule r) { + if (r instanceof TreeRule) { + return treeRule((TreeRule) r); + } else if (r instanceof EndpointRule) { + return endpointRule((EndpointRule) r); + } else if (r instanceof ErrorRule) { + return errorRule((ErrorRule) r); + } + return r; + } + + public Rule treeRule(TreeRule tr) { + return Rule.builder() + .description(tr.getDocumentation().orElse(null)) + .conditions(conditions(tr, tr.getConditions())) + .treeRule(rules(tr, tr.getRules())); + } + + public List rules(TreeRule tr, List rules) { + List updated = new ArrayList<>(tr.getRules().size()); + for (Rule rule : rules) { + Rule mapped = rule(rule); + if (mapped != null) { + updated.add(mapped); + } + } + return updated; + } + + public List conditions(Rule rule, List conditions) { + List updated = new ArrayList<>(conditions.size()); + for (Condition condition : conditions) { + Condition mapped = condition(rule, condition); + if (mapped != null) { + updated.add(mapped); + } + } + return updated; + } + + public Condition condition(Rule rule, Condition condition) { + return Condition.builder() + .fn(libraryFunction(condition.getFunction())) + .result(result(rule, condition, condition.getResult().orElse(null))) + .build(); + } + + public Identifier result(Rule rule, Condition condition, Identifier result) { return result; } - /** - * Rewrites references within an expression tree. - * - * @param expression the expression to rewrite - * @return the rewritten expression, or the original if no changes needed - */ - Expression rewrite(Expression expression) { - if (!shouldRewrite.test(expression)) { - return expression; + public Rule endpointRule(EndpointRule er) { + return Rule.builder() + .description(er.getDocumentation().orElse(null)) + .conditions(conditions(er, er.getConditions())) + .endpoint(endpoint(er.getEndpoint())); + } + + public Rule errorRule(ErrorRule er) { + return Rule.builder() + .description(er.getDocumentation().orElse(null)) + .conditions(conditions(er, er.getConditions())) + .error(error(er, er.getError())); + } + + public Expression error(ErrorRule er, Expression e) { + return expression(e); + } + + public LibraryFunction libraryFunction(LibraryFunction fn) { + boolean changed = false; + List rewrittenArgs = new ArrayList<>(fn.getArguments().size()); + for (int i = 0; i < fn.getArguments().size(); i++) { + Expression argument = fn.getArguments().get(i); + Expression rewritten = argument(fn, i, argument); + rewrittenArgs.add(rewritten); + if (rewritten != argument) { + changed = true; + } + } + + if (!changed) { + return fn; } - if (expression instanceof StringLiteral) { - return rewriteStringLiteral((StringLiteral) expression); - } else if (expression instanceof TupleLiteral) { - return rewriteTupleLiteral((TupleLiteral) expression); - } else if (expression instanceof RecordLiteral) { - return rewriteRecordLiteral((RecordLiteral) expression); + return fn.getFunctionDefinition() + .createFunction(FunctionNode.builder() + .name(Node.from(fn.getName())) + .arguments(rewrittenArgs) + .build()); + } + + public Expression argument(LibraryFunction fn, int position, Expression argument) { + return expression(argument); + } + + public Expression expression(Expression expression) { + if (expression instanceof Literal) { + return literal((Literal) expression); } else if (expression instanceof Reference) { - return referenceTransformer.apply((Reference) expression); + return reference((Reference) expression); } else if (expression instanceof LibraryFunction) { - return rewriteLibraryFunction((LibraryFunction) expression); + return libraryFunction((LibraryFunction) expression); + } else { + return expression; + } + } + + public Reference reference(Reference reference) { + return reference; + } + + public Literal literal(Literal literal) { + if (literal instanceof StringLiteral) { + return stringLiteral((StringLiteral) literal); + } else if (literal instanceof TupleLiteral) { + return tupleLiteral((TupleLiteral) literal); + } else if (literal instanceof RecordLiteral) { + return recordLiteral((RecordLiteral) literal); + } else { + return literal; + } + } + + public Literal tupleLiteral(TupleLiteral tuple) { + List rewrittenMembers = new ArrayList<>(); + boolean changed = false; + + for (Literal member : tuple.members()) { + Literal rewritten = (Literal) expression(member); + rewrittenMembers.add(rewritten); + if (rewritten != member) { + changed = true; + } + } + + return changed ? Literal.tupleLiteral(rewrittenMembers) : tuple; + } + + public Literal recordLiteral(RecordLiteral record) { + Map rewrittenMembers = new LinkedHashMap<>(); + boolean changed = false; + + for (Map.Entry entry : record.members().entrySet()) { + Literal original = entry.getValue(); + Literal rewritten = (Literal) expression(original); + rewrittenMembers.put(entry.getKey(), rewritten); + if (rewritten != original) { + changed = true; + } + } + + return changed ? Literal.recordLiteral(rewrittenMembers) : record; + } + + public Literal stringLiteral(StringLiteral str) { + Template template = str.value(); + if (template.isStatic()) { + return str; + } + + StringBuilder templateBuilder = new StringBuilder(); + boolean changed = false; + + for (Template.Part part : template.getParts()) { + if (part instanceof Template.Dynamic) { + Template.Dynamic dynamic = (Template.Dynamic) part; + Expression original = dynamic.toExpression(); + Expression rewritten = expression(original); + if (rewritten != original) { + changed = true; + } + templateBuilder.append('{').append(rewritten).append('}'); + } else { + templateBuilder.append(((Template.Literal) part).getValue()); + } } - return expression; + return changed ? Literal.stringLiteral(Template.fromString(templateBuilder.toString())) : str; } - Map> rewriteHeaders(Map> headers) { + public Endpoint endpoint(Endpoint endpoint) { + Expression rewrittenUrl = expression(endpoint.getUrl()); + Map> rewrittenHeaders = rewriteHeaders(endpoint.getHeaders()); + Map rewrittenProperties = rewriteProperties(endpoint.getProperties()); + + // Only create new endpoint if something changed + if (rewrittenUrl != endpoint.getUrl() + || rewrittenHeaders != endpoint.getHeaders() + || rewrittenProperties != endpoint.getProperties()) { + return Endpoint.builder() + .url(rewrittenUrl) + .headers(rewrittenHeaders) + .properties(rewrittenProperties) + .build(); + } + + return endpoint; + } + + private Map> rewriteHeaders(Map> headers) { if (headers.isEmpty()) { return headers; } @@ -122,7 +304,7 @@ Map> rewriteHeaders(Map> heade for (int i = 0; i < originalValues.size(); i++) { Expression original = originalValues.get(i); - Expression rewrittenExpr = rewrite(original); + Expression rewrittenExpr = expression(original); if (rewrittenExpr != original) { if (rewrittenValues == null) { @@ -155,7 +337,7 @@ Map> rewriteHeaders(Map> heade return changed ? rewritten : headers; } - Map rewriteProperties(Map properties) { + private Map rewriteProperties(Map properties) { if (properties.isEmpty()) { return properties; } @@ -164,7 +346,7 @@ Map rewriteProperties(Map properties) boolean changed = false; for (Map.Entry entry : properties.entrySet()) { - Expression rewrittenExpr = rewrite(entry.getValue()); + Expression rewrittenExpr = expression(entry.getValue()); if (rewrittenExpr != entry.getValue()) { if (!(rewrittenExpr instanceof Literal)) { @@ -191,102 +373,4 @@ Map rewriteProperties(Map properties) return changed ? rewritten : properties; } - - Endpoint rewriteEndpoint(Endpoint endpoint) { - Expression rewrittenUrl = rewrite(endpoint.getUrl()); - Map> rewrittenHeaders = rewriteHeaders(endpoint.getHeaders()); - Map rewrittenProperties = rewriteProperties(endpoint.getProperties()); - - // Only create new endpoint if something changed - if (rewrittenUrl != endpoint.getUrl() - || rewrittenHeaders != endpoint.getHeaders() - || rewrittenProperties != endpoint.getProperties()) { - return Endpoint.builder() - .url(rewrittenUrl) - .headers(rewrittenHeaders) - .properties(rewrittenProperties) - .build(); - } - return endpoint; - } - - private Expression rewriteStringLiteral(StringLiteral str) { - Template template = str.value(); - if (template.isStatic()) { - return str; - } - - StringBuilder templateBuilder = new StringBuilder(); - boolean changed = false; - - for (Template.Part part : template.getParts()) { - if (part instanceof Template.Dynamic) { - Template.Dynamic dynamic = (Template.Dynamic) part; - Expression original = dynamic.toExpression(); - Expression rewritten = rewrite(original); - if (rewritten != original) { - changed = true; - } - templateBuilder.append('{').append(rewritten).append('}'); - } else { - templateBuilder.append(((Template.Literal) part).getValue()); - } - } - - return changed ? Literal.stringLiteral(Template.fromString(templateBuilder.toString())) : str; - } - - private Expression rewriteTupleLiteral(TupleLiteral tuple) { - List rewrittenMembers = new ArrayList<>(); - boolean changed = false; - - for (Literal member : tuple.members()) { - Literal rewritten = (Literal) rewrite(member); - rewrittenMembers.add(rewritten); - if (rewritten != member) { - changed = true; - } - } - - return changed ? Literal.tupleLiteral(rewrittenMembers) : tuple; - } - - private Expression rewriteRecordLiteral(RecordLiteral record) { - Map rewrittenMembers = new LinkedHashMap<>(); - boolean changed = false; - - for (Map.Entry entry : record.members().entrySet()) { - Literal original = entry.getValue(); - Literal rewritten = (Literal) rewrite(original); - rewrittenMembers.put(entry.getKey(), rewritten); - if (rewritten != original) { - changed = true; - } - } - - return changed ? Literal.recordLiteral(rewrittenMembers) : record; - } - - private Expression rewriteLibraryFunction(LibraryFunction fn) { - List rewrittenArgs = new ArrayList<>(); - boolean changed = false; - - for (Expression arg : fn.getArguments()) { - Expression rewritten = rewrite(arg); - rewrittenArgs.add(rewritten); - if (rewritten != arg) { - changed = true; - } - } - - if (!changed) { - return fn; - } - - FunctionNode node = FunctionNode.builder() - .name(Node.from(fn.getName())) - .arguments(rewrittenArgs) - .build(); - return fn.getFunctionDefinition().createFunction(node); - } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableAnalysis.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableAnalysis.java index 7d477a72b51..b2508f665ca 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableAnalysis.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableAnalysis.java @@ -33,35 +33,33 @@ */ final class VariableAnalysis { private final Set inputParams; - private final Map> bindings; + private final Map bindingCounts; private final Map referenceCounts; private final Map> expressionMappings; private VariableAnalysis( Set inputParams, - Map> bindings, + Map bindingCounts, Map referenceCounts, Map> expressionMappings ) { this.inputParams = inputParams; - this.bindings = bindings; + this.bindingCounts = bindingCounts; this.referenceCounts = referenceCounts; this.expressionMappings = expressionMappings; } static VariableAnalysis analyze(EndpointRuleSet ruleSet) { - Set inputParameters = extractInputParameters(ruleSet); - - AnalysisVisitor visitor = new AnalysisVisitor(inputParameters); + AnalysisVisitor visitor = new AnalysisVisitor(); for (Rule rule : ruleSet.getRules()) { visitor.visitRule(rule); } return new VariableAnalysis( - inputParameters, - visitor.bindings, + extractInputParameters(ruleSet), + visitor.bindingCounts, visitor.referenceCounts, - createExpressionMappings(visitor.bindings)); + createExpressionMappings(visitor.bindings, visitor.bindingCounts)); } Set getInputParams() { @@ -81,13 +79,20 @@ boolean isReferencedOnce(String variableName) { } boolean hasSingleBinding(String variableName) { - Set expressions = bindings.get(variableName); - return expressions != null && expressions.size() == 1; + Integer count = bindingCounts.get(variableName); + return count != null && count == 1; } boolean hasMultipleBindings(String variableName) { - Set expressions = bindings.get(variableName); - return expressions != null && expressions.size() > 1; + // Check if variable is bound more than once, regardless of whether expressions are the same. + // This is important because SSA may rewrite references in the expressions, making originally + // identical expressions different after SSA. For example: + // Branch A: outpostId = getAttr(parsed, ...) + // Branch B: outpostId = getAttr(parsed, ...) + // If "parsed" is renamed to "parsed_ssa_1" in branch A and "parsed_ssa_2" in branch B, + // the expressions become different, but we've already decided not to rename "outpostId". + Integer count = bindingCounts.get(variableName); + return count != null && count > 1; } boolean isSafeToInline(String variableName) { @@ -103,29 +108,38 @@ private static Set extractInputParameters(EndpointRuleSet ruleSet) { } private static Map> createExpressionMappings( - Map> bindings + Map> bindings, + Map bindingCounts ) { Map> result = new HashMap<>(); for (Map.Entry> entry : bindings.entrySet()) { String varName = entry.getKey(); Set expressions = entry.getValue(); - result.put(varName, createMappingForVariable(varName, expressions)); + int bindingCount = bindingCounts.getOrDefault(varName, 0); + result.put(varName, createMappingForVariable(varName, expressions, bindingCount)); } return result; } private static Map createMappingForVariable( String varName, - Set expressions + Set expressions, + int bindingCount ) { Map mapping = new HashMap<>(); - if (expressions.size() == 1) { + if (bindingCount <= 1) { // Single binding: no SSA rename needed String expression = expressions.iterator().next(); mapping.put(expression, varName); + } else if (expressions.size() == 1) { + // Multiple bindings with the same expression: still need SSA rename because + // references in the expression may be renamed differently in each scope. + // Use a special suffix that indicates it's the same expression. + String expression = expressions.iterator().next(); + mapping.put(expression, varName + "_ssa_1"); } else { - // Multiple bindings: use SSA naming convention + // Multiple bindings with different expressions: use SSA naming convention List sortedExpressions = new ArrayList<>(expressions); sortedExpressions.sort(String::compareTo); for (int i = 0; i < sortedExpressions.size(); i++) { @@ -140,12 +154,8 @@ private static Map createMappingForVariable( private static class AnalysisVisitor { final Map> bindings = new HashMap<>(); + final Map bindingCounts = new HashMap<>(); final Map referenceCounts = new HashMap<>(); - private final Set inputParams; - - AnalysisVisitor(Set inputParams) { - this.inputParams = inputParams; - } void visitRule(Rule rule) { for (Condition condition : rule.getConditions()) { @@ -154,8 +164,9 @@ void visitRule(Rule rule) { LibraryFunction fn = condition.getFunction(); String expression = fn.toString(); - bindings.computeIfAbsent(varName, k -> new HashSet<>()) - .add(expression); + bindings.computeIfAbsent(varName, k -> new HashSet<>()).add(expression); + // Track number of times variable is bound (not just unique expressions) + bindingCounts.merge(varName, 1, Integer::sum); } countReferences(condition.getFunction()); diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableConsolidationTransform.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableConsolidationTransform.java index dd23a537cc1..041c3546b1b 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableConsolidationTransform.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableConsolidationTransform.java @@ -4,32 +4,45 @@ */ package software.amazon.smithy.rulesengine.logic.cfg; -import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.HashSet; -import java.util.List; +import java.util.IdentityHashMap; import java.util.Map; import java.util.Set; import java.util.logging.Logger; import software.amazon.smithy.rulesengine.language.EndpointRuleSet; import software.amazon.smithy.rulesengine.language.syntax.Identifier; import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Reference; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.LibraryFunction; import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; -import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; -import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; /** - * Consolidates variable names for identical expressions and eliminates redundant bindings. + * Consolidates variable bindings for identical expressions. * - *

This transform identifies conditions that compute the same expression but assign - * the result to different variable names, and either consolidates them to use the same - * name or eliminates redundant bindings when the same expression is already bound in - * an ancestor scope. + *

This transform performs two optimizations: elimination and consolidation. + * + *

Elimination: If an expression is already bound in a parent scope, child bindings are removed. + *

+ *     url = parseURL(Endpoint)
+ *     _synthetic = parseURL(Endpoint) // eliminated, references rewritten to 'url'
+ * 
+ * + *

Consolidation: If the same expression appears in sibling scopes, all bindings use the first name. + *

+ *     branch1: url = parseURL(Endpoint)
+ *     branch2: parsed = parseURL(Endpoint) // renamed to 'url'
+ * 
+ * + *

The transform avoids renaming when it would cause variable shadowing. + * + * @see SyntheticBindingTransform + * @see DeadStoreEliminationTransform */ -final class VariableConsolidationTransform { +final class VariableConsolidationTransform extends TreeMapper { private static final Logger LOGGER = Logger.getLogger(VariableConsolidationTransform.class.getName()); // Global map of canonical expressions to their first variable name seen @@ -38,266 +51,148 @@ final class VariableConsolidationTransform { // Maps old variable names to new canonical names for rewriting references private final Map variableRenameMap = new HashMap<>(); - // Tracks conditions to eliminate (by their path in the tree) - private final Set conditionsToEliminate = new HashSet<>(); + // Tracks conditions to eliminate (using identity for exact instance matching) + private final Set conditionsToEliminate = Collections.newSetFromMap(new IdentityHashMap<>()); - // Tracks all variables defined at each scope level to check for conflicts - private final Map> scopeDefinedVars = new HashMap<>(); + // Tracks all variables defined at each rule to check for conflicts + private final Map> ruleDefinedVars = new IdentityHashMap<>(); private int consolidatedCount = 0; private int eliminatedCount = 0; private int skippedDueToShadowing = 0; public static EndpointRuleSet transform(EndpointRuleSet ruleSet) { - VariableConsolidationTransform transform = new VariableConsolidationTransform(); - return transform.consolidate(ruleSet); - } - - private EndpointRuleSet consolidate(EndpointRuleSet ruleSet) { - LOGGER.info("Starting variable consolidation transform"); + VariableConsolidationTransform t = new VariableConsolidationTransform(); - for (int i = 0; i < ruleSet.getRules().size(); i++) { - collectDefinitions(ruleSet.getRules().get(i), "rule[" + i + "]"); + // Pass 1: Collect all variable definitions per rule + for (Rule rule : ruleSet.getRules()) { + t.collectDefinitions(rule); } - for (int i = 0; i < ruleSet.getRules().size(); i++) { - discoverBindingsInRule(ruleSet.getRules().get(i), "rule[" + i + "]", new HashMap<>(), new HashSet<>()); - } - - List transformedRules = new ArrayList<>(); - for (int i = 0; i < ruleSet.getRules().size(); i++) { - transformedRules.add(transformRule(ruleSet.getRules().get(i), "rule[" + i + "]")); + // Pass 2: Discover bindings to consolidate/eliminate + for (Rule rule : ruleSet.getRules()) { + t.discoverBindings(rule, new HashMap<>(), new HashSet<>()); } LOGGER.info(String.format("Variable consolidation: %d consolidated, %d eliminated, %d skipped due to shadowing", - consolidatedCount, - eliminatedCount, - skippedDueToShadowing)); + t.consolidatedCount, + t.eliminatedCount, + t.skippedDueToShadowing)); - return EndpointRuleSet.builder() - .parameters(ruleSet.getParameters()) - .rules(transformedRules) - .version(ruleSet.getVersion()) - .build(); + // Pass 3: Transform using TreeMapper + return t.endpointRuleSet(ruleSet); } - private void collectDefinitions(Rule rule, String path) { + private void collectDefinitions(Rule rule) { Set definedVars = new HashSet<>(); - - // Collect all variables defined at this scope level for (Condition condition : rule.getConditions()) { - if (condition.getResult().isPresent()) { - definedVars.add(condition.getResult().get().toString()); - } + condition.getResult().ifPresent(id -> definedVars.add(id.toString())); } - - scopeDefinedVars.put(path, definedVars); + ruleDefinedVars.put(rule, definedVars); if (rule instanceof TreeRule) { - TreeRule treeRule = (TreeRule) rule; - for (int i = 0; i < treeRule.getRules().size(); i++) { - collectDefinitions(treeRule.getRules().get(i), path + "/tree/rule[" + i + "]"); + for (Rule nested : ((TreeRule) rule).getRules()) { + collectDefinitions(nested); } } } - private void discoverBindingsInRule( + private void discoverBindings( Rule rule, - String path, Map parentBindings, Set ancestorVars ) { - // Track bindings at current scope (inherits parent bindings) Map currentBindings = new HashMap<>(parentBindings); - // Track all variables visible from ancestors (for shadowing check) Set visibleAncestorVars = new HashSet<>(ancestorVars); - for (int i = 0; i < rule.getConditions().size(); i++) { - Condition condition = rule.getConditions().get(i); - String condPath = path + "/cond[" + i + "]"; - - if (condition.getResult().isPresent()) { - String varName = condition.getResult().get().toString(); - LibraryFunction fn = condition.getFunction(); - String canonical = fn.canonicalize().toString(); - - // Check if this expression is already bound in parent scope - String parentVar = parentBindings.get(canonical); - if (parentVar != null) { - // Found duplicate in parent, eliminate this binding - variableRenameMap.put(varName, parentVar); - conditionsToEliminate.add(condPath); - eliminatedCount++; - LOGGER.info(String.format("Eliminating redundant binding at %s: '%s' -> '%s' for: %s", - condPath, - varName, - parentVar, - canonical)); - } else { - // Not bound in parent, add to current scope - currentBindings.put(canonical, varName); - visibleAncestorVars.add(varName); + for (Condition condition : rule.getConditions()) { + if (!condition.getResult().isPresent()) { + continue; + } - // Check for global consolidation opportunity - String globalVar = globalExpressionToVar.get(canonical); - if (globalVar != null && !globalVar.equals(varName)) { - // Same expression elsewhere with different name - // Only consolidate if both variables follow SSA naming (same base, different suffix) - // This prevents consolidating semantically different variables that happen to have the same value - if (!hasSameBaseName(varName, globalVar)) { - LOGGER.fine( - String.format("Skipping consolidation '%s' -> '%s' (different base names) for: %s", - varName, - globalVar, - canonical)); - } else if (!wouldCauseShadowing(globalVar, path, ancestorVars)) { - variableRenameMap.put(varName, globalVar); - consolidatedCount++; - LOGGER.info(String.format("Consolidating '%s' -> '%s' for: %s", - varName, - globalVar, - canonical)); - } else { - skippedDueToShadowing++; - LOGGER.fine(String.format("Cannot consolidate '%s' -> '%s' (would shadow) for: %s", - varName, - globalVar, - canonical)); - } - } else if (globalVar == null) { - // First time seeing this expression globally - globalExpressionToVar.put(canonical, varName); + String varName = condition.getResult().get().toString(); + String canonical = condition.getFunction().canonicalize().toString(); + + // Check if already bound in parent scope + String parentVar = parentBindings.get(canonical); + if (parentVar != null) { + variableRenameMap.put(varName, parentVar); + conditionsToEliminate.add(condition); + eliminatedCount++; + LOGGER.fine(() -> String.format("Eliminating redundant binding: '%s' -> '%s' for: %s", + varName, parentVar, canonical)); + } else { + currentBindings.put(canonical, varName); + visibleAncestorVars.add(varName); + + // Check for global consolidation opportunity + String globalVar = globalExpressionToVar.get(canonical); + if (globalVar != null && !globalVar.equals(varName)) { + boolean wouldShadow = wouldCauseShadowing(globalVar, rule, ancestorVars); + if (!wouldShadow) { + // No shadowing - safe to rename the binding + variableRenameMap.put(varName, globalVar); + consolidatedCount++; + LOGGER.fine(() -> String.format("Consolidating '%s' -> '%s' for: %s", + varName, globalVar, canonical)); + } else { + skippedDueToShadowing++; + LOGGER.info(() -> String.format("Shadowing skip: '%s' -> '%s' for expr: %s", + varName, globalVar, canonical)); } + } else if (globalVar == null) { + globalExpressionToVar.put(canonical, varName); } } } if (rule instanceof TreeRule) { - TreeRule treeRule = (TreeRule) rule; - for (int i = 0; i < treeRule.getRules().size(); i++) { - discoverBindingsInRule( - treeRule.getRules().get(i), - path + "/tree/rule[" + i + "]", - currentBindings, - visibleAncestorVars); - } - } - } - - /** - * Checks if two variable names have the same base name. - * For SSA-style variables like "foo_1" and "foo_2", the base name is "foo". - * Variables without SSA suffix (like "s3e_fips" and "s3e_ds") are considered - * to have their full name as the base. - */ - private boolean hasSameBaseName(String var1, String var2) { - String base1 = getSsaBaseName(var1); - String base2 = getSsaBaseName(var2); - return base1.equals(base2); - } - - /** - * Extracts the SSA base name from a variable. - * If the variable ends with _N (where N is a number), strips the suffix. - * Otherwise returns the full name. - */ - private String getSsaBaseName(String varName) { - int lastUnderscore = varName.lastIndexOf('_'); - if (lastUnderscore > 0 && lastUnderscore < varName.length() - 1) { - String suffix = varName.substring(lastUnderscore + 1); - // Check if suffix is all digits - boolean allDigits = true; - for (int i = 0; i < suffix.length(); i++) { - if (!Character.isDigit(suffix.charAt(i))) { - allDigits = false; - break; - } - } - if (allDigits) { - return varName.substring(0, lastUnderscore); + for (Rule nested : ((TreeRule) rule).getRules()) { + discoverBindings(nested, currentBindings, visibleAncestorVars); } } - return varName; } - private boolean wouldCauseShadowing(String varName, String currentPath, Set ancestorVars) { - // Check if using this variable name would shadow an ancestor variable + private boolean wouldCauseShadowing(String varName, Rule currentRule, Set ancestorVars) { if (ancestorVars.contains(varName)) { return true; } - // Check if any child scope already defines this variable - // (which would be shadowed if we use it here) - for (Map.Entry> entry : scopeDefinedVars.entrySet()) { - String scopePath = entry.getKey(); - Set scopeVars = entry.getValue(); - // Check if this scope is a descendant of current path - if (scopePath.startsWith(currentPath + "/") && scopeVars.contains(varName)) { - return true; - } - } - - return false; + // Check if any descendant rule defines this variable + return wouldShadowInDescendants(varName, currentRule); } - private Rule transformRule(Rule rule, String path) { - List transformedConditions = new ArrayList<>(); - - for (int i = 0; i < rule.getConditions().size(); i++) { - String condPath = path + "/cond[" + i + "]"; - - if (conditionsToEliminate.contains(condPath)) { - // Skip this condition entirely since it's redundant - continue; + private boolean wouldShadowInDescendants(String varName, Rule rule) { + if (rule instanceof TreeRule) { + for (Rule nested : ((TreeRule) rule).getRules()) { + Set nestedVars = ruleDefinedVars.get(nested); + if (nestedVars != null && nestedVars.contains(varName)) { + return true; + } + if (wouldShadowInDescendants(varName, nested)) { + return true; + } } - - Condition condition = rule.getConditions().get(i); - transformedConditions.add(transformCondition(condition)); } + return false; + } - if (rule instanceof TreeRule) { - TreeRule treeRule = (TreeRule) rule; - return TreeRule.builder() - .description(rule.getDocumentation().orElse(null)) - .conditions(transformedConditions) - .treeRule(TreeRewriter.transformNestedRules(treeRule, path, this::transformRule)); - - } else if (rule instanceof EndpointRule) { - EndpointRule endpointRule = (EndpointRule) rule; - TreeRewriter rewriter = createRewriter(); - - return EndpointRule.builder() - .description(rule.getDocumentation().orElse(null)) - .conditions(transformedConditions) - .endpoint(rewriter.rewriteEndpoint(endpointRule.getEndpoint())); - - } else if (rule instanceof ErrorRule) { - ErrorRule errorRule = (ErrorRule) rule; - TreeRewriter rewriter = createRewriter(); - - return ErrorRule.builder() - .description(rule.getDocumentation().orElse(null)) - .conditions(transformedConditions) - .error(rewriter.rewrite(errorRule.getError())); + @Override + public Condition condition(Rule rule, Condition condition) { + // Eliminate redundant conditions + if (conditionsToEliminate.contains(condition)) { + return null; } - return rule.withConditions(transformedConditions); - } - - private Condition transformCondition(Condition condition) { - // Rewrite any references in the function - TreeRewriter rewriter = createRewriter(); LibraryFunction fn = condition.getFunction(); - LibraryFunction rewrittenFn = (LibraryFunction) rewriter.rewrite(fn); + LibraryFunction rewrittenFn = libraryFunction(fn); - // If this condition assigns to a variable that should be renamed, - // use the canonical name instead + // Check if binding needs renaming if (condition.getResult().isPresent()) { String varName = condition.getResult().get().toString(); String canonicalName = variableRenameMap.get(varName); if (canonicalName != null) { - // This variable is being consolidated, use the canonical name return Condition.builder() .fn(rewrittenFn) .result(Identifier.of(canonicalName)) @@ -305,7 +200,7 @@ private Condition transformCondition(Condition condition) { } } - // No consolidation needed, but may still need reference rewriting + // Only rebuild if function changed if (rewrittenFn != fn) { return condition.toBuilder().fn(rewrittenFn).build(); } @@ -313,16 +208,23 @@ private Condition transformCondition(Condition condition) { return condition; } - private TreeRewriter createRewriter() { - if (variableRenameMap.isEmpty()) { - return TreeRewriter.IDENTITY; + @Override + public Expression expression(Expression expression) { + // Short-circuit if no references need rewriting + for (String ref : expression.getReferences()) { + if (variableRenameMap.containsKey(ref)) { + return super.expression(expression); + } } + return expression; + } - Map replacements = new HashMap<>(); - for (Map.Entry entry : variableRenameMap.entrySet()) { - replacements.put(entry.getKey(), Expression.getReference(Identifier.of(entry.getValue()))); + @Override + public Reference reference(Reference ref) { + String canonicalName = variableRenameMap.get(ref.getName().toString()); + if (canonicalName != null) { + return Expression.getReference(Identifier.of(canonicalName)); } - - return TreeRewriter.forReplacements(replacements); + return ref; } } diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/ReferenceRewriterTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/ReferenceRewriterTest.java index 29b9c87f03e..9d3909871b3 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/ReferenceRewriterTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/ReferenceRewriterTest.java @@ -34,11 +34,11 @@ void testSimpleReferenceReplacement() { Map replacements = new HashMap<>(); replacements.put("x", Expression.getReference(Identifier.of("y"))); - TreeRewriter rewriter = TreeRewriter.forReplacements(replacements); + TreeMapper mapper = TreeMapper.newReferenceReplacingMapper(replacements); // Test rewriting a simple reference Reference original = Expression.getReference(Identifier.of("x")); - Expression rewritten = rewriter.rewrite(original); + Expression rewritten = mapper.expression(original); assertEquals("y", ((Reference) rewritten).getName().toString()); } @@ -49,11 +49,11 @@ void testNoRewriteNeeded() { Map replacements = new HashMap<>(); replacements.put("x", Expression.getReference(Identifier.of("y"))); - TreeRewriter rewriter = TreeRewriter.forReplacements(replacements); + TreeMapper mapper = TreeMapper.newReferenceReplacingMapper(replacements); // Reference to "z" should not be rewritten Reference original = Expression.getReference(Identifier.of("z")); - Expression rewritten = rewriter.rewrite(original); + Expression rewritten = mapper.expression(original); assertEquals(original, rewritten); } @@ -67,8 +67,8 @@ void testRewriteInStringLiteral() { Map replacements = new HashMap<>(); replacements.put("x", Expression.getReference(Identifier.of("newVar"))); - TreeRewriter rewriter = TreeRewriter.forReplacements(replacements); - Expression rewritten = rewriter.rewrite(original); + TreeMapper mapper = TreeMapper.newReferenceReplacingMapper(replacements); + Expression rewritten = mapper.expression(original); assertInstanceOf(StringLiteral.class, rewritten); StringLiteral rewrittenStr = (StringLiteral) rewritten; @@ -84,8 +84,8 @@ void testRewriteInTupleLiteral() { Map replacements = new HashMap<>(); replacements.put("x", Expression.getReference(Identifier.of("replaced"))); - TreeRewriter rewriter = TreeRewriter.forReplacements(replacements); - Expression rewritten = rewriter.rewrite(original); + TreeMapper mapper = TreeMapper.newReferenceReplacingMapper(replacements); + Expression rewritten = mapper.expression(original); assertInstanceOf(TupleLiteral.class, rewritten); TupleLiteral rewrittenTuple = (TupleLiteral) rewritten; @@ -105,8 +105,8 @@ void testRewriteInRecordLiteral() { Map replacements = new HashMap<>(); replacements.put("x", Expression.getReference(Identifier.of("newX"))); - TreeRewriter rewriter = TreeRewriter.forReplacements(replacements); - Expression rewritten = rewriter.rewrite(original); + TreeMapper mapper = TreeMapper.newReferenceReplacingMapper(replacements); + Expression rewritten = mapper.expression(original); assertInstanceOf(RecordLiteral.class, rewritten); RecordLiteral rewrittenRecord = (RecordLiteral) rewritten; @@ -124,8 +124,8 @@ void testRewriteInLibraryFunction() { Map replacements = new HashMap<>(); replacements.put("x", Expression.getReference(Identifier.of("replacedVar"))); - TreeRewriter rewriter = TreeRewriter.forReplacements(replacements); - Expression rewritten = rewriter.rewrite(original); + TreeMapper mapper = TreeMapper.newReferenceReplacingMapper(replacements); + Expression rewritten = mapper.expression(original); assertTrue(rewritten.toString().contains("replacedVar")); assertNotEquals(original, rewritten); @@ -142,8 +142,8 @@ void testMultipleReplacements() { replacements.put("a", Expression.getReference(Identifier.of("x"))); replacements.put("b", Expression.getReference(Identifier.of("y"))); - TreeRewriter rewriter = TreeRewriter.forReplacements(replacements); - Expression rewritten = rewriter.rewrite(original); + TreeMapper mapper = TreeMapper.newReferenceReplacingMapper(replacements); + Expression rewritten = mapper.expression(original); assertTrue(rewritten.toString().contains("x")); assertTrue(rewritten.toString().contains("y")); @@ -158,8 +158,8 @@ void testNestedRewriting() { Map replacements = new HashMap<>(); replacements.put("x", Expression.getReference(Identifier.of("newVar"))); - TreeRewriter rewriter = TreeRewriter.forReplacements(replacements); - Expression rewritten = rewriter.rewrite(original); + TreeMapper mapper = TreeMapper.newReferenceReplacingMapper(replacements); + Expression rewritten = mapper.expression(original); assertTrue(rewritten.toString().contains("newVar")); assertNotEquals(original, rewritten); @@ -173,8 +173,8 @@ void testStaticStringNotRewritten() { Map replacements = new HashMap<>(); replacements.put("x", Expression.getReference(Identifier.of("y"))); - TreeRewriter rewriter = TreeRewriter.forReplacements(replacements); - Expression rewritten = rewriter.rewrite(original); + TreeMapper mapper = TreeMapper.newReferenceReplacingMapper(replacements); + Expression rewritten = mapper.expression(original); assertEquals(original, rewritten); } diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransformTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransformTest.java index a1594d0d1ec..6e07b291836 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransformTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransformTest.java @@ -16,6 +16,7 @@ import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; import software.amazon.smithy.rulesengine.language.syntax.expressions.Template; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.StringEquals; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.UriEncode; import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; import software.amazon.smithy.rulesengine.language.syntax.parameters.ParameterType; @@ -31,6 +32,7 @@ public class SsaTransformTest { @Test void testNoDisambiguationNeeded() { // When variables are not shadowed, they should remain unchanged + // Note: Dead store elimination will remove unused bindings Parameter bucketParam = Parameter.builder() .name("Bucket") .type(ParameterType.STRING) @@ -53,12 +55,15 @@ void testNoDisambiguationNeeded() { EndpointRuleSet result = SsaTransform.transform(original); - assertEquals(original, result); + // Binding is removed since it's not used (dead store elimination) + EndpointRule resultRule = (EndpointRule) result.getRules().get(0); + assertEquals(false, resultRule.getConditions().get(0).getResult().isPresent()); } @Test void testSimpleShadowing() { // Test when the same variable name is bound to different expressions + // Note: Dead store elimination removes unused bindings Parameter param = Parameter.builder() .name("Input") .type(ParameterType.STRING) @@ -80,16 +85,18 @@ void testSimpleShadowing() { List resultRules = result.getRules(); assertEquals(2, resultRules.size()); + // Bindings are removed since they're not used (dead store elimination) EndpointRule resultRule1 = (EndpointRule) resultRules.get(0); - assertEquals("temp_ssa_1", resultRule1.getConditions().get(0).getResult().get().toString()); + assertEquals(false, resultRule1.getConditions().get(0).getResult().isPresent()); EndpointRule resultRule2 = (EndpointRule) resultRules.get(1); - assertEquals("temp_ssa_2", resultRule2.getConditions().get(0).getResult().get().toString()); + assertEquals(false, resultRule2.getConditions().get(0).getResult().isPresent()); } @Test void testMultipleShadowsOfSameVariable() { // Test when a variable is shadowed multiple times + // Note: Dead store elimination removes unused bindings Parameter param = Parameter.builder() .name("Input") .type(ParameterType.STRING) @@ -109,9 +116,10 @@ void testMultipleShadowsOfSameVariable() { EndpointRuleSet result = SsaTransform.transform(original); List resultRules = result.getRules(); - assertEquals("temp_ssa_1", resultRules.get(0).getConditions().get(0).getResult().get().toString()); - assertEquals("temp_ssa_2", resultRules.get(1).getConditions().get(0).getResult().get().toString()); - assertEquals("temp_ssa_3", resultRules.get(2).getConditions().get(0).getResult().get().toString()); + // Bindings are removed since they're not used (dead store elimination) + assertEquals(false, resultRules.get(0).getConditions().get(0).getResult().isPresent()); + assertEquals(false, resultRules.get(1).getConditions().get(0).getResult().isPresent()); + assertEquals(false, resultRules.get(2).getConditions().get(0).getResult().isPresent()); } @Test @@ -146,6 +154,7 @@ void testErrorRuleHandling() { @Test void testTreeRuleHandling() { // Test tree rules with unique variable names at each level + // Note: Dead store elimination removes unused bindings Parameter param = Parameter.builder() .name("Region") .type(ParameterType.STRING) @@ -182,6 +191,7 @@ void testTreeRuleHandling() { @Test void testParameterShadowingAttempt() { // Test that attempting to shadow a parameter gets disambiguated + // Note: Dead store elimination removes unused bindings Parameter bucketParam = Parameter.builder() .name("Bucket") .type(ParameterType.STRING) @@ -205,9 +215,9 @@ void testParameterShadowingAttempt() { EndpointRuleSet result = SsaTransform.transform(original); - // Should handle without issues + // Binding is removed since it's not used (dead store elimination) EndpointRule resultRule = (EndpointRule) result.getRules().get(0); - assertEquals("Bucket_shadow", resultRule.getConditions().get(0).getResult().get().toString()); + assertEquals(false, resultRule.getConditions().get(0).getResult().isPresent()); } private static EndpointRule createRuleWithBinding(String param, String value, String resultVar, String url) { @@ -228,4 +238,82 @@ private static Expression expr(String value) { private static Endpoint endpoint(String value) { return Endpoint.builder().url(expr(value)).build(); } + + private static Endpoint endpoint(Expression url) { + return Endpoint.builder().url(url).build(); + } + + @Test + void testMultipleBindingsWithUsedVariablesAreSsaRenamed() { + // When the same variable is bound in multiple sibling branches AND is used, + // each binding should get a unique SSA name to avoid shadowing conflicts + Parameter param = Parameter.builder() + .name("Input") + .type(ParameterType.STRING) + .build(); + + // Create two sibling rules that: + // 1. Both bind 'myVar' (same variable name, different expressions) + // 2. Use 'myVar' in their endpoints (so it won't be eliminated) + // Using UriEncode which returns a string (not boolean like StringEquals) + Condition cond1 = Condition.builder() + .fn(UriEncode.ofExpressions(Expression.of("Input"))) + .result("myVar") + .build(); + EndpointRule rule1 = (EndpointRule) EndpointRule.builder() + .conditions(Collections.singletonList(cond1)) + .endpoint(Endpoint.builder() + .url(Literal.stringLiteral(Template.fromString("https://{myVar}.example.com"))) + .build()); + + // Second rule uses a different expression (uriEncode of a literal) + Condition cond2 = Condition.builder() + .fn(UriEncode.ofExpressions(Expression.of("other"))) + .result("myVar") + .build(); + EndpointRule rule2 = (EndpointRule) EndpointRule.builder() + .conditions(Collections.singletonList(cond2)) + .endpoint(Endpoint.builder() + .url(Literal.stringLiteral(Template.fromString("https://{myVar}.other.com"))) + .build()); + + EndpointRuleSet original = EndpointRuleSet.builder() + .parameters(Parameters.builder().addParameter(param).build()) + .rules(Arrays.asList(rule1, rule2)) + .version("1.0") + .build(); + + EndpointRuleSet result = SsaTransform.transform(original); + + // Verify both rules still exist + assertEquals(2, result.getRules().size()); + + // Get the result variable names from the transformed conditions + EndpointRule resultRule1 = (EndpointRule) result.getRules().get(0); + EndpointRule resultRule2 = (EndpointRule) result.getRules().get(1); + + String resultVar1 = resultRule1.getConditions().get(0).getResult() + .map(Object::toString).orElse(null); + String resultVar2 = resultRule2.getConditions().get(0).getResult() + .map(Object::toString).orElse(null); + + System.out.println("resultVar1=" + resultVar1); + System.out.println("resultVar2=" + resultVar2); + + // Both should have bindings (since they're used) + assertEquals(true, resultRule1.getConditions().get(0).getResult().isPresent(), + "First binding should be present since myVar is used"); + assertEquals(true, resultRule2.getConditions().get(0).getResult().isPresent(), + "Second binding should be present since myVar is used"); + + // They should be SSA-renamed to unique names + assertEquals(true, resultVar1.contains("_ssa_"), + "First binding should have SSA suffix, got: " + resultVar1); + assertEquals(true, resultVar2.contains("_ssa_"), + "Second binding should have SSA suffix, got: " + resultVar2); + + // They should NOT be the same (unique SSA names) + assertEquals(false, resultVar1.equals(resultVar2), + "SSA names should be unique: " + resultVar1 + " vs " + resultVar2); + } } diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/SyntheticBindingTransformTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/SyntheticBindingTransformTest.java new file mode 100644 index 00000000000..7e46b3b422e --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/SyntheticBindingTransformTest.java @@ -0,0 +1,148 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.cfg; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.IsSet; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; +import software.amazon.smithy.rulesengine.language.syntax.parameters.ParameterType; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.logic.TestHelpers; + +class SyntheticBindingTransformTest { + + private static final Parameters PARAMS = Parameters.builder() + .addParameter(Parameter.builder().name("Input").type(ParameterType.STRING).build()) + .build(); + + @Test + void unwrapsIsSetFunctionCall() { + // isSet(substring(Input, 0, 5, false)) should become _synthetic_0 = substring(...) + Condition isSetSubstring = Condition.builder() + .fn(IsSet.ofExpressions(TestHelpers.substring("Input", 0, 5, false))) + .build(); + + Rule rule = EndpointRule.builder() + .conditions(TestHelpers.isSet("Input"), isSetSubstring) + .endpoint(TestHelpers.endpoint("https://example.com")); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(PARAMS) + .addRule(rule) + .build(); + + EndpointRuleSet result = SyntheticBindingTransform.transform(ruleSet); + + Condition transformed = result.getRules().get(0).getConditions().get(1); + assertEquals("substring", transformed.getFunction().getName()); + assertTrue(transformed.getResult().isPresent()); + // Name includes the outer function (isSet) that was unwrapped + assertTrue(transformed.getResult().get().toString().startsWith("_synthetic_isSet_")); + } + + @Test + void doesNotUnwrapIsSetReference() { + // isSet(Input) should remain unchanged - it's checking a reference, not a function + Condition isSetRef = Condition.builder() + .fn(TestHelpers.isSet("Input")) + .build(); + + Rule rule = EndpointRule.builder() + .conditions(isSetRef) + .endpoint(TestHelpers.endpoint("https://example.com")); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(PARAMS) + .addRule(rule) + .build(); + + EndpointRuleSet result = SyntheticBindingTransform.transform(ruleSet); + + Condition transformed = result.getRules().get(0).getConditions().get(0); + assertEquals("isSet", transformed.getFunction().getName()); + assertFalse(transformed.getResult().isPresent()); + } + + @Test + void addsBindingToBareFunctionCall() { + // substring(Input, 0, 5, false) should become _synthetic_0 = substring(...) + Condition bareSubstring = Condition.builder() + .fn(TestHelpers.substring("Input", 0, 5, false)) + .build(); + + Rule rule = EndpointRule.builder() + .conditions(TestHelpers.isSet("Input"), bareSubstring) + .endpoint(TestHelpers.endpoint("https://example.com")); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(PARAMS) + .addRule(rule) + .build(); + + EndpointRuleSet result = SyntheticBindingTransform.transform(ruleSet); + + Condition transformed = result.getRules().get(0).getConditions().get(1); + assertEquals("substring", transformed.getFunction().getName()); + assertTrue(transformed.getResult().isPresent()); + // Name includes the function name + assertTrue(transformed.getResult().get().toString().startsWith("_synthetic_substring_")); + } + + @Test + void doesNotModifyExistingBinding() { + // prefix = substring(Input, 0, 5, false) should remain unchanged + Condition binding = Condition.builder() + .fn(TestHelpers.substring("Input", 0, 5, false)) + .result(Identifier.of("prefix")) + .build(); + + Rule rule = EndpointRule.builder() + .conditions(TestHelpers.isSet("Input"), binding) + .endpoint(TestHelpers.endpoint("https://example.com")); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(PARAMS) + .addRule(rule) + .build(); + + EndpointRuleSet result = SyntheticBindingTransform.transform(ruleSet); + + Condition transformed = result.getRules().get(0).getConditions().get(1); + assertEquals("substring", transformed.getFunction().getName()); + assertEquals("prefix", transformed.getResult().get().toString()); + } + + @Test + void doesNotAddBindingToSimpleChecks() { + // isSet, booleanEquals, stringEquals, not, isValidHostLabel should not get bindings + Condition isSet = Condition.builder() + .fn(TestHelpers.isSet("Input")) + .build(); + + Rule rule = EndpointRule.builder() + .conditions(isSet) + .endpoint(TestHelpers.endpoint("https://example.com")); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(PARAMS) + .addRule(rule) + .build(); + + EndpointRuleSet result = SyntheticBindingTransform.transform(ruleSet); + + Condition transformed = result.getRules().get(0).getConditions().get(0); + assertEquals("isSet", transformed.getFunction().getName()); + assertFalse(transformed.getResult().isPresent()); + } +} From d228501f33554d2ec877064a246016d343210766 Mon Sep 17 00:00:00 2001 From: Michael Dowling Date: Tue, 30 Dec 2025 14:32:13 -0600 Subject: [PATCH 07/10] Reduce SSA noise --- .../functions/LibraryFunction.java | 1 - .../rulesengine/logic/cfg/CfgBuilder.java | 4 --- .../logic/cfg/SyntheticBindingTransform.java | 3 +- .../logic/cfg/VariableAnalysis.java | 14 ++++----- .../cfg/VariableConsolidationTransform.java | 12 ++++++-- .../logic/cfg/SsaTransformTest.java | 29 +++++++++++++------ 6 files changed, 36 insertions(+), 27 deletions(-) diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/LibraryFunction.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/LibraryFunction.java index dd543b5dba4..51ebd435407 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/LibraryFunction.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/LibraryFunction.java @@ -18,7 +18,6 @@ import software.amazon.smithy.rulesengine.language.evaluation.Scope; import software.amazon.smithy.rulesengine.language.evaluation.type.Type; import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; -import software.amazon.smithy.rulesengine.language.syntax.expressions.Reference; import software.amazon.smithy.rulesengine.language.syntax.expressions.Template; import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.StringLiteral; diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilder.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilder.java index 95933785140..ac057d61179 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilder.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilder.java @@ -9,17 +9,13 @@ import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.Optional; import software.amazon.smithy.rulesengine.language.EndpointRuleSet; -import software.amazon.smithy.rulesengine.language.syntax.Identifier; import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; -import software.amazon.smithy.rulesengine.language.syntax.expressions.Reference; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.BooleanEquals; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.IsSet; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.LibraryFunction; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Not; import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; -import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; import software.amazon.smithy.rulesengine.logic.ConditionReference; diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/SyntheticBindingTransform.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/SyntheticBindingTransform.java index 264c32db880..7f0f38c1690 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/SyntheticBindingTransform.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/SyntheticBindingTransform.java @@ -4,6 +4,7 @@ */ package software.amazon.smithy.rulesengine.logic.cfg; +import java.util.logging.Logger; import software.amazon.smithy.rulesengine.language.EndpointRuleSet; import software.amazon.smithy.rulesengine.language.evaluation.type.Type; import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; @@ -12,8 +13,6 @@ import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; -import java.util.logging.Logger; - /** * Assigns synthetic bindings to conditions that could benefit from variable consolidation. * diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableAnalysis.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableAnalysis.java index b2508f665ca..842bc2aadd8 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableAnalysis.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableAnalysis.java @@ -128,16 +128,14 @@ private static Map createMappingForVariable( ) { Map mapping = new HashMap<>(); - if (bindingCount <= 1) { - // Single binding: no SSA rename needed + if (bindingCount <= 1 || expressions.size() == 1) { + // Single binding or multiple bindings with the same expression: no SSA rename needed. + // When multiple bindings have the same expression text, references inside may get + // SSA-renamed differently in each scope, but that's fine: the resulting expressions + // will differ and be treated as distinct BDD conditions. The binding name being the + // same doesn't cause collisions since conditions are identified by their full content. String expression = expressions.iterator().next(); mapping.put(expression, varName); - } else if (expressions.size() == 1) { - // Multiple bindings with the same expression: still need SSA rename because - // references in the expression may be renamed differently in each scope. - // Use a special suffix that indicates it's the same expression. - String expression = expressions.iterator().next(); - mapping.put(expression, varName + "_ssa_1"); } else { // Multiple bindings with different expressions: use SSA naming convention List sortedExpressions = new ArrayList<>(expressions); diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableConsolidationTransform.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableConsolidationTransform.java index 041c3546b1b..b9c27f5f5ae 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableConsolidationTransform.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableConsolidationTransform.java @@ -120,7 +120,9 @@ private void discoverBindings( conditionsToEliminate.add(condition); eliminatedCount++; LOGGER.fine(() -> String.format("Eliminating redundant binding: '%s' -> '%s' for: %s", - varName, parentVar, canonical)); + varName, + parentVar, + canonical)); } else { currentBindings.put(canonical, varName); visibleAncestorVars.add(varName); @@ -134,11 +136,15 @@ private void discoverBindings( variableRenameMap.put(varName, globalVar); consolidatedCount++; LOGGER.fine(() -> String.format("Consolidating '%s' -> '%s' for: %s", - varName, globalVar, canonical)); + varName, + globalVar, + canonical)); } else { skippedDueToShadowing++; LOGGER.info(() -> String.format("Shadowing skip: '%s' -> '%s' for expr: %s", - varName, globalVar, canonical)); + varName, + globalVar, + canonical)); } } else if (globalVar == null) { globalExpressionToVar.put(canonical, varName); diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransformTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransformTest.java index 6e07b291836..0b3a8e66e61 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransformTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransformTest.java @@ -292,28 +292,39 @@ void testMultipleBindingsWithUsedVariablesAreSsaRenamed() { EndpointRule resultRule1 = (EndpointRule) result.getRules().get(0); EndpointRule resultRule2 = (EndpointRule) result.getRules().get(1); - String resultVar1 = resultRule1.getConditions().get(0).getResult() - .map(Object::toString).orElse(null); - String resultVar2 = resultRule2.getConditions().get(0).getResult() - .map(Object::toString).orElse(null); + String resultVar1 = resultRule1.getConditions() + .get(0) + .getResult() + .map(Object::toString) + .orElse(null); + String resultVar2 = resultRule2.getConditions() + .get(0) + .getResult() + .map(Object::toString) + .orElse(null); System.out.println("resultVar1=" + resultVar1); System.out.println("resultVar2=" + resultVar2); // Both should have bindings (since they're used) - assertEquals(true, resultRule1.getConditions().get(0).getResult().isPresent(), + assertEquals(true, + resultRule1.getConditions().get(0).getResult().isPresent(), "First binding should be present since myVar is used"); - assertEquals(true, resultRule2.getConditions().get(0).getResult().isPresent(), + assertEquals(true, + resultRule2.getConditions().get(0).getResult().isPresent(), "Second binding should be present since myVar is used"); // They should be SSA-renamed to unique names - assertEquals(true, resultVar1.contains("_ssa_"), + assertEquals(true, + resultVar1.contains("_ssa_"), "First binding should have SSA suffix, got: " + resultVar1); - assertEquals(true, resultVar2.contains("_ssa_"), + assertEquals(true, + resultVar2.contains("_ssa_"), "Second binding should have SSA suffix, got: " + resultVar2); // They should NOT be the same (unique SSA names) - assertEquals(false, resultVar1.equals(resultVar2), + assertEquals(false, + resultVar1.equals(resultVar2), "SSA names should be unique: " + resultVar1 + " vs " + resultVar2); } } From 21cce54c6d6a34208c35b6b89ea822d80f7e7c76 Mon Sep 17 00:00:00 2001 From: Michael Dowling Date: Tue, 30 Dec 2025 23:30:04 -0600 Subject: [PATCH 08/10] Refactor S3 tree rewriter into composable transforms MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Split the monolithic S3TreeRewriter into three independent transforms: - S3AzCanonicalizerTransform: substring → split for AZ extraction - S3RegionUnifierTransform: unify Region/bucketArn#region/us-east-1 - S3ExpressEndpointTransform: canonicalize FIPS/DualStack/auth Added region unification pass that rewrites all region references to _effective_std_region (for aws-global → us-east-1 mapping) or _effective_arn_region (for UseArnRegion logic). This enables additional BDD sharing across endpoints with different region sources. S3 BDD stats: 77 conditions, 97 results, 484 nodes after sifting. --- smithy-aws-endpoints/build.gradle.kts | 25 +- .../aws/language/functions/S3BddTest.java | 75 ++- .../functions/S3TreeRewriterTest.java | 1 + .../language/functions/S3TreeRewriter.java | 558 ------------------ .../aws/s3/S3AzCanonicalizerTransform.java | 89 +++ .../aws/s3/S3ExpressEndpointTransform.java | 232 ++++++++ .../aws/s3/S3ExpressUrlCanonicalizer.java | 151 +++++ .../aws/s3/S3RegionUnifierTransform.java | 405 +++++++++++++ .../rulesengine/aws/s3/S3TreeRewriter.java | 64 ++ .../rulesengine/logic/cfg/CfgBuilderTest.java | 14 +- 10 files changed, 1030 insertions(+), 584 deletions(-) delete mode 100644 smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/S3TreeRewriter.java create mode 100644 smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/s3/S3AzCanonicalizerTransform.java create mode 100644 smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/s3/S3ExpressEndpointTransform.java create mode 100644 smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/s3/S3ExpressUrlCanonicalizer.java create mode 100644 smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/s3/S3RegionUnifierTransform.java create mode 100644 smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/s3/S3TreeRewriter.java diff --git a/smithy-aws-endpoints/build.gradle.kts b/smithy-aws-endpoints/build.gradle.kts index 559731a51a3..22f6c9843aa 100644 --- a/smithy-aws-endpoints/build.gradle.kts +++ b/smithy-aws-endpoints/build.gradle.kts @@ -26,7 +26,7 @@ dependencies { } // Integration test source set for tests that require the S3 model -// These tests require JDK 21+ due to the S3 model dependency +// These tests require JDK 17+ due to the S3 model dependency sourceSets { create("it") { compileClasspath += sourceSets["main"].output + sourceSets["test"].output @@ -38,15 +38,11 @@ configurations["itImplementation"].extendsFrom(configurations["testImplementatio configurations["itRuntimeOnly"].extendsFrom(configurations["testRuntimeOnly"]) configurations["itImplementation"].extendsFrom(s3Model) -// Configure IT source set to compile with JDK 21 +// Configure IT source set to compile with current JDK (17+) tasks.named("compileItJava") { - javaCompiler.set( - javaToolchains.compilerFor { - languageVersion.set(JavaLanguageVersion.of(21)) - }, - ) - sourceCompatibility = "21" - targetCompatibility = "21" + // Use current Java version instead of hardcoding to allow flexibility in CI + sourceCompatibility = "17" + targetCompatibility = "17" } val integrationTest by tasks.registering(Test::class) { @@ -57,11 +53,12 @@ val integrationTest by tasks.registering(Test::class) { dependsOn(tasks.jar) shouldRunAfter(tasks.test) - // Run with JDK 21 - javaLauncher.set( - javaToolchains.launcherFor { - languageVersion.set(JavaLanguageVersion.of(21)) - }, + // Pass build directory to tests + systemProperty( + "buildDir", + layout.buildDirectory + .get() + .asFile.absolutePath, ) } diff --git a/smithy-aws-endpoints/src/it/java/software/amazon/smithy/rulesengine/aws/language/functions/S3BddTest.java b/smithy-aws-endpoints/src/it/java/software/amazon/smithy/rulesengine/aws/language/functions/S3BddTest.java index bc4006878aa..7fac4904424 100644 --- a/smithy-aws-endpoints/src/it/java/software/amazon/smithy/rulesengine/aws/language/functions/S3BddTest.java +++ b/smithy-aws-endpoints/src/it/java/software/amazon/smithy/rulesengine/aws/language/functions/S3BddTest.java @@ -1,13 +1,24 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ package software.amazon.smithy.rulesengine.aws.language.functions; import static org.junit.jupiter.api.Assertions.assertFalse; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; import java.util.List; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import software.amazon.smithy.model.Model; +import software.amazon.smithy.model.node.Node; +import software.amazon.smithy.model.shapes.ModelSerializer; import software.amazon.smithy.model.shapes.ServiceShape; import software.amazon.smithy.model.shapes.ShapeId; +import software.amazon.smithy.rulesengine.aws.s3.S3TreeRewriter; import software.amazon.smithy.rulesengine.language.EndpointRuleSet; import software.amazon.smithy.rulesengine.language.evaluation.TestEvaluator; import software.amazon.smithy.rulesengine.logic.bdd.CostOptimization; @@ -20,18 +31,20 @@ class S3BddTest { private static final ShapeId S3_SERVICE_ID = ShapeId.from("com.amazonaws.s3#AmazonS3"); + private static Model model; + private static ServiceShape s3Service; private static EndpointRuleSet originalRules; private static EndpointRuleSet rules; private static List testCases; @BeforeAll static void loadS3Model() { - Model model = Model.assembler() + model = Model.assembler() .discoverModels() .assemble() .unwrap(); - ServiceShape s3Service = model.expectShape(S3_SERVICE_ID, ServiceShape.class); + s3Service = model.expectShape(S3_SERVICE_ID, ServiceShape.class); originalRules = s3Service.expectTrait(EndpointRuleSetTrait.class).getEndpointRuleSet(); rules = S3TreeRewriter.transform(originalRules); testCases = s3Service.expectTrait(EndpointTestsTrait.class).getTestCases(); @@ -44,33 +57,79 @@ void compileToBddWithOptimizations() { for (EndpointTestCase testCase : testCases) { TestEvaluator.evaluate(rules, testCase); } - + // Build CFG and compile to BDD Cfg cfg = Cfg.from(rules); EndpointBddTrait trait = EndpointBddTrait.from(cfg); - + StringBuilder sb = new StringBuilder(); sb.append("\n=== BDD STATS ===\n"); sb.append("Conditions: ").append(trait.getConditions().size()).append("\n"); sb.append("Results: ").append(trait.getResults().size()).append("\n"); sb.append("Initial BDD nodes: ").append(trait.getBdd().getNodeCount()).append("\n"); - + // Apply sifting optimization SiftingOptimization sifting = SiftingOptimization.builder().cfg(cfg).build(); EndpointBddTrait siftedTrait = sifting.apply(trait); sb.append("After sifting - nodes: ").append(siftedTrait.getBdd().getNodeCount()).append("\n"); - + // Apply cost optimization CostOptimization cost = CostOptimization.builder().cfg(cfg).build(); EndpointBddTrait optimizedTrait = cost.apply(siftedTrait); sb.append("After cost opt - nodes: ").append(optimizedTrait.getBdd().getNodeCount()).append("\n"); - + // Print conditions for analysis sb.append("\n=== CONDITIONS ===\n"); for (int i = 0; i < trait.getConditions().size(); i++) { sb.append(i).append(": ").append(trait.getConditions().get(i)).append("\n"); } - + + // Print results (endpoints) for analysis + sb.append("\n=== RESULTS ===\n"); + for (int i = 0; i < trait.getResults().size(); i++) { + sb.append(i).append(": ").append(trait.getResults().get(i)).append("\n"); + } + System.out.println(sb); + + // Write model with BDD trait to build directory + writeModelWithBddTrait(optimizedTrait); + } + + private void writeModelWithBddTrait(EndpointBddTrait bddTrait) { + String buildDir = System.getProperty("buildDir"); + if (buildDir == null) { + System.out.println("buildDir system property not set, skipping model output"); + return; + } + + // Create updated service with BDD trait instead of RuleSet trait + ServiceShape updatedService = s3Service.toBuilder() + .removeTrait(EndpointRuleSetTrait.ID) + .addTrait(bddTrait) + .build(); + + // Build updated model + Model updatedModel = model.toBuilder() + .removeShape(s3Service.getId()) + .addShape(updatedService) + .build(); + + // Serialize to JSON + ModelSerializer serializer = ModelSerializer.builder().build(); + String json = Node.prettyPrintJson(serializer.serialize(updatedModel)); + + // Write to build directory + Path outputPath = Paths.get(buildDir, "s3-bdd-model.json"); + try { + Path parentDir = outputPath.getParent(); + if (parentDir != null) { + Files.createDirectories(parentDir); + } + Files.writeString(outputPath, json); + System.out.println("Wrote S3 BDD model to: " + outputPath); + } catch (IOException e) { + throw new RuntimeException("Failed to write S3 BDD model", e); + } } } diff --git a/smithy-aws-endpoints/src/it/java/software/amazon/smithy/rulesengine/aws/language/functions/S3TreeRewriterTest.java b/smithy-aws-endpoints/src/it/java/software/amazon/smithy/rulesengine/aws/language/functions/S3TreeRewriterTest.java index dd5e88140a7..48a83c36000 100644 --- a/smithy-aws-endpoints/src/it/java/software/amazon/smithy/rulesengine/aws/language/functions/S3TreeRewriterTest.java +++ b/smithy-aws-endpoints/src/it/java/software/amazon/smithy/rulesengine/aws/language/functions/S3TreeRewriterTest.java @@ -12,6 +12,7 @@ import software.amazon.smithy.model.Model; import software.amazon.smithy.model.shapes.ServiceShape; import software.amazon.smithy.model.shapes.ShapeId; +import software.amazon.smithy.rulesengine.aws.s3.S3TreeRewriter; import software.amazon.smithy.rulesengine.language.EndpointRuleSet; import software.amazon.smithy.rulesengine.language.evaluation.TestEvaluator; import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait; diff --git a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/S3TreeRewriter.java b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/S3TreeRewriter.java deleted file mode 100644 index 71aca7412cc..00000000000 --- a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/S3TreeRewriter.java +++ /dev/null @@ -1,558 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ -package software.amazon.smithy.rulesengine.aws.language.functions; - -import java.util.ArrayList; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; -import java.util.logging.Logger; -import java.util.regex.Matcher; -import java.util.regex.Pattern; -import software.amazon.smithy.model.node.StringNode; -import software.amazon.smithy.rulesengine.language.Endpoint; -import software.amazon.smithy.rulesengine.language.EndpointRuleSet; -import software.amazon.smithy.rulesengine.language.syntax.Identifier; -import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; -import software.amazon.smithy.rulesengine.language.syntax.expressions.Reference; -import software.amazon.smithy.rulesengine.language.syntax.expressions.Template; -import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Coalesce; -import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.GetAttr; -import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Ite; -import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Split; -import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Substring; -import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; -import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; -import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; -import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; -import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; -import software.amazon.smithy.utils.SmithyInternalApi; - -/** - * Rewrites S3 endpoint rules to use canonical, position-independent expressions. - * - *

This is a BDD pre-processing transform that makes the rules tree larger but enables dramatically better - * BDD compilation. It solves the "SSA Trap" problem where semantically identical operations appear as syntactically - * different expressions, preventing the BDD compiler from recognizing sharing opportunities. - * - *

Internal use only

- *

Ideally this transform is deleted one day, and the rules that source it adopt these techniques (hopefully we - * don't look back on this comment and laugh in 5 years). If/when that happens, this class will be deleted, whether - * it breaks a consumer that uses it or not. - * - *

Trade-off: Larger Rules, Smaller BDD

- *

This transform would be counterproductive for rule tree interpretation, but is highly beneficial when a - * BDD compiler processes the output. It adds ITE (if-then-else) conditions to compute URL segments and auth scheme - * names, increasing rule tree size by ~30%. However, this enables the BDD compiler to deduplicate endpoints that - * were previously considered distinct, as of writing, reducing BDD results and node counts both by ~43%. - * - *

The key insight is that the BDD deduplicates by endpoint identity (URL template + properties). By making - * URL templates identical through variable substitution, endpoints that differed only in FIPS/DualStack/auth variants - * collapse into a single BDD result. - * - *

Transformations performed:

- * - *

AZ Extraction Canonicalization

- * - *

The original rules extract the availability zone ID using position-dependent substring operations. - * Different bucket name lengths result in different extraction positions, creating 10+ SSA variants that can't - * be shared in the BDD. - * - *

Before: Position-dependent substring extraction - *

{@code
- * {
- *   "conditions": [
- *     {
- *       "fn": "substring",
- *       "argv": [{"ref": "Bucket"}, 6, 14, true],
- *       "assign": "s3expressAvailabilityZoneId"
- *     }
- *   ],
- *   "rules": [...]
- * }
- * // Another branch with different positions:
- * {
- *   "conditions": [
- *     {
- *       "fn": "substring",
- *       "argv": [{"ref": "Bucket"}, 6, 20, true],
- *       "assign": "s3expressAvailabilityZoneId"
- *     }
- *   ],
- *   "rules": [...]
- * }
- * }
- * - *

After: Position-independent split-based extraction - *

{@code
- * {
- *   "conditions": [
- *     {
- *       "fn": "getAttr",
- *       "argv": [
- *         {"fn": "split", "argv": [{"ref": "Bucket"}, "--", 0]},
- *         "[1]"
- *       ],
- *       "assign": "s3expressAvailabilityZoneId"
- *     }
- *   ],
- *   "rules": [...]
- * }
- * }
- * - *

All branches now use the identical expression {@code split(Bucket, "--")[1]}, enabling - * the BDD compiler to share nodes across all S3Express bucket handling paths. Because the expression only interacts - * with Bucket, a constant value, there's no SSA transform performed on these expressions. - * - *

URL Canonicalization

- * - *

S3Express endpoints (currently) have 4 URL variants based on UseFIPS and UseDualStack flags. This creates - * duplicate endpoints that differ only in URL structure. - * - *

Before: Separate endpoints for each FIPS/DualStack combination - *

{@code
- * // Branch 1: FIPS + DualStack
- * {
- *   "conditions": [
- *     {"fn": "booleanEquals", "argv": [{"ref": "UseFIPS"}, true]},
- *     {"fn": "booleanEquals", "argv": [{"ref": "UseDualStack"}, true]}
- *   ],
- *   "endpoint": {
- *     "url": "https://{Bucket}.s3express-fips-{s3expressAvailabilityZoneId}.dualstack.{Region}.amazonaws.com"
- *   }
- * }
- * // Branch 2: FIPS only
- * {
- *   "conditions": [
- *     {"fn": "booleanEquals", "argv": [{"ref": "UseFIPS"}, true]}
- *   ],
- *   "endpoint": {
- *     "url": "https://{Bucket}.s3express-fips-{s3expressAvailabilityZoneId}.{Region}.amazonaws.com"
- *   }
- * }
- * // Branch 3: DualStack only
- * // Branch 4: Neither
- * }
- * - *

After: Single endpoint with ITE-computed URL segments - *

{@code
- * {
- *   "conditions": [
- *     {"fn": "ite", "argv": [{"ref": "UseFIPS"}, "-fips", ""], "assign": "_s3e_fips"},
- *     {"fn": "ite", "argv": [{"ref": "UseDualStack"}, ".dualstack", ""], "assign": "_s3e_ds"}
- *   ],
- *   "endpoint": {
- *     "url": "https://{Bucket}.s3express{_s3e_fips}-{s3expressAvailabilityZoneId}{_s3e_ds}.{Region}.amazonaws.com"
- *   }
- * }
- * }
- * - *

The ITE conditions compute values branchlessly. The BDD sifting optimization naturally places these rare - * S3Express-specific conditions late in the decision tree. - * - *

Auth Scheme Canonicalization

- * - *

S3Express endpoints use different auth schemes based on DisableS3ExpressSessionAuth. - * This creates duplicate endpoints differing only in auth scheme name. - * - *

Before: Separate auth scheme names - *

{@code
- * // When DisableS3ExpressSessionAuth is true:
- * "authSchemes": [{"name": "sigv4", "signingName": "s3express", ...}]
- *
- * // When DisableS3ExpressSessionAuth is false/unset:
- * "authSchemes": [{"name": "sigv4-s3express", "signingName": "s3express", ...}]
- * }
- * - *

After: ITE-computed auth scheme name - *

{@code
- * {
- *   "conditions": [
- *     {
- *       "fn": "ite",
- *       "argv": [
- *         {"fn": "coalesce", "argv": [{"ref": "DisableS3ExpressSessionAuth"}, false]},
- *         "sigv4",
- *         "sigv4-s3express"
- *       ],
- *       "assign": "_s3e_auth"
- *     }
- *   ],
- *   "endpoint": {
- *     "properties": {
- *       "authSchemes": [{"name": "{_s3e_auth}", "signingName": "s3express", ...}]
- *     }
- *   }
- * }
- * }
- */ -@SmithyInternalApi -public final class S3TreeRewriter { - private static final Logger LOGGER = Logger.getLogger(S3TreeRewriter.class.getName()); - - // Variable names for the computed suffixes - private static final String VAR_FIPS = "_s3e_fips"; - private static final String VAR_DS = "_s3e_ds"; - private static final String VAR_AUTH = "_s3e_auth"; - - // Suffix values used in the URI templates - private static final String FIPS_SUFFIX = "-fips"; - private static final String DS_SUFFIX = ".dualstack"; - private static final String EMPTY_SUFFIX = ""; - - // Auth scheme values used with s3-express - private static final String AUTH_SIGV4 = "sigv4"; - private static final String AUTH_SIGV4_S3EXPRESS = "sigv4-s3express"; - - // Property and parameter identifiers - private static final Identifier ID_AUTH_SCHEMES = Identifier.of("authSchemes"); - private static final Identifier ID_NAME = Identifier.of("name"); - private static final Identifier ID_BACKEND = Identifier.of("backend"); - private static final Identifier ID_BUCKET = Identifier.of("Bucket"); - private static final Identifier ID_AZ_ID = Identifier.of("s3expressAvailabilityZoneId"); - private static final Identifier ID_USE_FIPS = Identifier.of("UseFIPS"); - private static final Identifier ID_USE_DUAL_STACK = Identifier.of("UseDualStack"); - private static final Identifier ID_DISABLE_S3EXPRESS_SESSION_AUTH = Identifier.of("DisableS3ExpressSessionAuth"); - - // Auth scheme name literal shared across all rewritten endpoints - private static final Literal AUTH_NAME_LITERAL = Literal.stringLiteral(Template.fromString("{" + VAR_AUTH + "}")); - - // URL pattern matchers, ordered from most specific to least specific. - // Control plane patterns (no AZ) come first, then bucket patterns (with AZ). - // Negative lookahead (?!dualstack) prevents matching dualstack variants in non-DS patterns. - private static final UrlPatternMatcher[] URL_PATTERNS = { - // Control plane: https://s3express-control[-fips][.dualstack].{Region}.amazonaws.com - new UrlPatternMatcher("(s3express-control)-fips\\.dualstack\\.(.+)$", false), - new UrlPatternMatcher("(s3express-control)-fips\\.(?!dualstack)(.+)$", false), - new UrlPatternMatcher("(s3express-control)\\.dualstack\\.(.+)$", false), - new UrlPatternMatcher("(s3express-control)\\.(?!dualstack)(.+)$", false), - // Bucket: https://{Bucket}.s3express[-fips]-{AZ}[.dualstack].{Region}.amazonaws.com - new UrlPatternMatcher("(s3express)-fips-([^.]+)\\.dualstack\\.(.+)$", true), - new UrlPatternMatcher("(s3express)-fips-([^.]+)\\.(?!dualstack)(.+)$", true), - new UrlPatternMatcher("(s3express)-([^.]+)\\.dualstack\\.(.+)$", true), - new UrlPatternMatcher("(s3express)-([^.]+)\\.(?!dualstack)(.+)$", true), - }; - - private int rewrittenCount = 0; - private int totalS3ExpressCount = 0; - - private S3TreeRewriter() {} - - /** - * Transforms the given endpoint rule set using canonical expressions. - * - * @param ruleSet the rule set to transform - * @return the transformed rule set - */ - public static EndpointRuleSet transform(EndpointRuleSet ruleSet) { - S3TreeRewriter rewriter = new S3TreeRewriter(); - EndpointRuleSet result = rewriter.run(ruleSet); - - LOGGER.info(() -> String.format( - "S3 tree rewriter: %s/%s S3Express endpoints rewritten", - rewriter.rewrittenCount, - rewriter.totalS3ExpressCount)); - - return result; - } - - private EndpointRuleSet run(EndpointRuleSet ruleSet) { - List transformedRules = new ArrayList<>(); - for (Rule rule : ruleSet.getRules()) { - transformedRules.add(transformRule(rule)); - } - - return EndpointRuleSet.builder() - .sourceLocation(ruleSet.getSourceLocation()) - .parameters(ruleSet.getParameters()) - .rules(transformedRules) - .version(ruleSet.getVersion()) - .build(); - } - - private Rule transformRule(Rule rule) { - if (rule instanceof TreeRule) { - TreeRule tr = (TreeRule) rule; - List transformedConditions = transformConditions(tr.getConditions()); - List transformedChildren = new ArrayList<>(); - for (Rule child : tr.getRules()) { - transformedChildren.add(transformRule(child)); - } - return Rule.builder().conditions(transformedConditions).treeRule(transformedChildren); - } else if (rule instanceof EndpointRule) { - return rewriteEndpoint((EndpointRule) rule); - } else { - // Error rules pass through unchanged - return rule; - } - } - - private List transformConditions(List conditions) { - List result = new ArrayList<>(conditions.size()); - for (Condition cond : conditions) { - result.add(transformCondition(cond)); - } - return result; - } - - private Condition transformCondition(Condition cond) { - // Transform AZ extraction: substring(Bucket, N, M) -> split(Bucket, "--")[1] - if (cond.getResult().isPresent() - && ID_AZ_ID.equals(cond.getResult().get()) - && cond.getFunction() instanceof Substring - && isSubstringOnBucket((Substring) cond.getFunction())) { - // Create fresh expression each time to avoid type-checking conflicts - Split bucketSplit = Split.ofExpressions( - Expression.getReference(ID_BUCKET), - Expression.of("--"), - Expression.of(0)); - GetAttr azExpr = GetAttr.ofExpressions(bucketSplit, "[1]"); - return cond.toBuilder().fn(azExpr).build(); - } - return cond; - } - - private boolean isSubstringOnBucket(Substring substring) { - List args = substring.getArguments(); - if (args.isEmpty()) { - return false; - } - Expression target = args.get(0); - return target instanceof Reference && ID_BUCKET.equals(((Reference) target).getName()); - } - - private Rule rewriteEndpoint(EndpointRule rule) { - Endpoint endpoint = rule.getEndpoint(); - Expression urlExpr = endpoint.getUrl(); - - String urlStr = extractUrlString(urlExpr); - if (urlStr == null) { - return rule; - } - - boolean isS3ExpressUrl = urlStr.contains("s3express"); - boolean isS3ExpressBackend = isS3ExpressBackend(endpoint); - - if (!isS3ExpressUrl && !isS3ExpressBackend) { - return rule; - } - - totalS3ExpressCount++; - - // For URL override endpoints (backend=S3Express but URL doesn't match s3express hostname), - // just canonicalize the auth scheme - no URL rewriting needed - if (isS3ExpressBackend && !isS3ExpressUrl) { - Map newProperties = canonicalizeAuthScheme(endpoint.getProperties()); - - if (newProperties == endpoint.getProperties()) { - return rule; - } - - rewrittenCount++; - - Endpoint newEndpoint = Endpoint.builder() - .url(urlExpr) - .headers(endpoint.getHeaders()) - .properties(newProperties) - .sourceLocation(endpoint.getSourceLocation()) - .build(); - - List allConditions = new ArrayList<>(rule.getConditions()); - allConditions.add(createAuthIteCondition()); - - return Rule.builder() - .conditions(allConditions) - .endpoint(newEndpoint); - } - - // Standard S3Express URL - match and rewrite - UrlMatchResult match = matchUrl(urlStr); - if (match == null) { - return rule; - } - - rewrittenCount++; - - String newUrl = match.rewriteUrl(); - - Map newProperties = endpoint.getProperties(); - if (match instanceof BucketUrlMatchResult) { - newProperties = canonicalizeAuthScheme(endpoint.getProperties()); - } - - Endpoint newEndpoint = Endpoint.builder() - .url(Expression.of(newUrl)) - .headers(endpoint.getHeaders()) - .properties(newProperties) - .sourceLocation(endpoint.getSourceLocation()) - .build(); - - List allConditions = new ArrayList<>(rule.getConditions()); - allConditions.addAll(createIteConditions()); - - return Rule.builder() - .conditions(allConditions) - .endpoint(newEndpoint); - } - - private List createIteConditions() { - List conditions = new ArrayList<>(); - conditions.add(createIteAssignment(VAR_FIPS, Expression.getReference(ID_USE_FIPS), FIPS_SUFFIX, EMPTY_SUFFIX)); - conditions.add(createIteAssignment(VAR_DS, Expression.getReference(ID_USE_DUAL_STACK), DS_SUFFIX, EMPTY_SUFFIX)); - Expression sessionAuthDisabled = Coalesce.ofExpressions( - Expression.getReference(ID_DISABLE_S3EXPRESS_SESSION_AUTH), - Expression.of(false)); - conditions.add(createIteAssignment(VAR_AUTH, sessionAuthDisabled, AUTH_SIGV4, AUTH_SIGV4_S3EXPRESS)); - return conditions; - } - - private Condition createIteAssignment(String varName, Expression condition, String trueValue, String falseValue) { - return Condition.builder() - .fn(Ite.ofStrings(condition, trueValue, falseValue)) - .result(varName) - .build(); - } - - private boolean isS3ExpressBackend(Endpoint endpoint) { - Literal backend = endpoint.getProperties().get(ID_BACKEND); - if (backend == null) { - return false; - } - return backend.asStringLiteral() - .filter(Template::isStatic) - .map(t -> "S3Express".equalsIgnoreCase(t.expectLiteral())) - .orElse(false); - } - - private Condition createAuthIteCondition() { - Expression isSessionAuthDisabled = Coalesce.ofExpressions( - Expression.getReference(ID_DISABLE_S3EXPRESS_SESSION_AUTH), - Expression.of(false)); - return createIteAssignment(VAR_AUTH, isSessionAuthDisabled, AUTH_SIGV4, AUTH_SIGV4_S3EXPRESS); - } - - private Map canonicalizeAuthScheme(Map properties) { - Literal authSchemes = properties.get(ID_AUTH_SCHEMES); - if (authSchemes == null) { - return properties; - } - - List schemes = authSchemes.asTupleLiteral().orElse(null); - if (schemes == null || schemes.isEmpty()) { - return properties; - } - - List newSchemes = new ArrayList<>(); - for (Literal scheme : schemes) { - Map record = scheme.asRecordLiteral().orElse(null); - if (record == null) { - newSchemes.add(scheme); - continue; - } - - Literal nameLiteral = record.get(ID_NAME); - if (nameLiteral == null) { - newSchemes.add(scheme); - continue; - } - - String name = nameLiteral.asStringLiteral() - .filter(Template::isStatic) - .map(Template::expectLiteral) - .orElse(null); - - if (AUTH_SIGV4.equals(name) || AUTH_SIGV4_S3EXPRESS.equals(name)) { - Map newRecord = new LinkedHashMap<>(record); - newRecord.put(ID_NAME, AUTH_NAME_LITERAL); - newSchemes.add(Literal.recordLiteral(newRecord)); - } else { - newSchemes.add(scheme); - } - } - - Map newProperties = new LinkedHashMap<>(properties); - newProperties.put(ID_AUTH_SCHEMES, Literal.tupleLiteral(newSchemes)); - return newProperties; - } - - private String extractUrlString(Expression urlExpr) { - return urlExpr.toNode().asStringNode().map(StringNode::getValue).orElse(null); - } - - private UrlMatchResult matchUrl(String url) { - for (UrlPatternMatcher matcher : URL_PATTERNS) { - UrlMatchResult result = matcher.match(url); - if (result != null) { - return result; - } - } - return null; - } - - private abstract static class UrlMatchResult { - protected final String prefix; - - UrlMatchResult(String prefix) { - this.prefix = prefix; - } - - abstract String rewriteUrl(); - } - - private static final class BucketUrlMatchResult extends UrlMatchResult { - private final String s3express; - private final String az; - private final String regionSuffix; - - BucketUrlMatchResult(String url, Matcher m) { - super(url.substring(0, m.start())); - this.s3express = m.group(1); - this.az = m.group(2); - this.regionSuffix = m.group(3); - } - - @Override - String rewriteUrl() { - return String.format("%s%s{%s}-%s{%s}.%s", prefix, s3express, VAR_FIPS, az, VAR_DS, regionSuffix); - } - } - - private static final class ControlPlaneUrlMatchResult extends UrlMatchResult { - private final String s3expressControl; - private final String regionSuffix; - - ControlPlaneUrlMatchResult(String url, Matcher m) { - super(url.substring(0, m.start())); - this.s3expressControl = m.group(1); - this.regionSuffix = m.group(2); - } - - @Override - String rewriteUrl() { - return String.format("%s%s{%s}{%s}.%s", prefix, s3expressControl, VAR_FIPS, VAR_DS, regionSuffix); - } - } - - private static final class UrlPatternMatcher { - private final Pattern pattern; - private final boolean isBucketPattern; - - UrlPatternMatcher(String regex, boolean isBucketPattern) { - this.pattern = Pattern.compile(regex); - this.isBucketPattern = isBucketPattern; - } - - UrlMatchResult match(String url) { - Matcher m = pattern.matcher(url); - if (!m.find()) { - return null; - } else if (isBucketPattern) { - return new BucketUrlMatchResult(url, m); - } else { - return new ControlPlaneUrlMatchResult(url, m); - } - } - } -} diff --git a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/s3/S3AzCanonicalizerTransform.java b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/s3/S3AzCanonicalizerTransform.java new file mode 100644 index 00000000000..e09c805f7ec --- /dev/null +++ b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/s3/S3AzCanonicalizerTransform.java @@ -0,0 +1,89 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.aws.s3; + +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Reference; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.GetAttr; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.LibraryFunction; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Split; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Substring; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.logic.cfg.TreeMapper; + +/** + * Canonicalizes S3Express availability zone extraction. + * + *

Rewrites position-dependent substring operations to position-independent split operations: + *

{@code
+ * substring(Bucket, N, M) → split(Bucket, "--")[1]
+ * }
+ * + *

This enables BDD sharing across endpoints that extract the AZ from different + * bucket name positions. + */ +final class S3AzCanonicalizerTransform extends TreeMapper { + + private static final Identifier ID_BUCKET = Identifier.of("Bucket"); + private static final Identifier ID_AZ_ID = Identifier.of("s3expressAvailabilityZoneId"); + + private int rewriteCount = 0; + + private S3AzCanonicalizerTransform() {} + + /** + * Creates a new transform instance. + * + * @return a new transform. + */ + static S3AzCanonicalizerTransform create() { + return new S3AzCanonicalizerTransform(); + } + + /** + * Returns the number of AZ extractions that were canonicalized. + * + * @return rewrite count. + */ + int getRewriteCount() { + return rewriteCount; + } + + @Override + public Condition condition(Rule rule, Condition cond) { + if (isAzIdSubstringBinding(cond)) { + rewriteCount++; + return createCanonicalAzCondition(cond); + } + return super.condition(rule, cond); + } + + // Matches: s3expressAvailabilityZoneId = substring(Bucket, N, M) + private static boolean isAzIdSubstringBinding(Condition cond) { + if (!ID_AZ_ID.equals(cond.getResult().orElse(null))) { + return false; + } + + LibraryFunction fn = cond.getFunction(); + if (!(fn instanceof Substring) || fn.getArguments().isEmpty()) { + return false; + } + + Expression target = fn.getArguments().get(0); + return target instanceof Reference && ID_BUCKET.equals(((Reference) target).getName()); + } + + // Creates: s3expressAvailabilityZoneId = split(Bucket, "--", 0)[1] + private static Condition createCanonicalAzCondition(Condition original) { + Split split = Split.ofExpressions( + Expression.getReference(ID_BUCKET), + Expression.of("--"), + Expression.of(0)); + GetAttr azExpr = GetAttr.ofExpressions(split, "[1]"); + return original.toBuilder().fn(azExpr).build(); + } +} diff --git a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/s3/S3ExpressEndpointTransform.java b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/s3/S3ExpressEndpointTransform.java new file mode 100644 index 00000000000..bc13f74a9bc --- /dev/null +++ b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/s3/S3ExpressEndpointTransform.java @@ -0,0 +1,232 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.aws.s3; + +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import software.amazon.smithy.model.node.StringNode; +import software.amazon.smithy.rulesengine.language.Endpoint; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Template; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Coalesce; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Ite; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.logic.cfg.TreeMapper; + +/** + * Canonicalizes S3Express endpoint URLs and auth schemes. + * + *

S3Express endpoints have multiple URL variants based on FIPS and DualStack settings. + * This transform rewrites them to use ITE-computed URL segments, enabling BDD sharing + * across all variants. + * + *

Transformations: + *

    + *
  • URL patterns: {@code s3express-fips-{az}.dualstack.{region}} → + * {@code s3express{_fips}-{az}{_ds}.{region}} with ITE-computed variables
  • + *
  • Auth schemes: {@code sigv4}/{@code sigv4-s3express} → {@code {_s3e_auth}} + * with ITE based on DisableS3ExpressSessionAuth
  • + *
+ */ +final class S3ExpressEndpointTransform extends TreeMapper { + + // Computed variable names + private static final String VAR_FIPS = "_s3e_fips"; + private static final String VAR_DS = "_s3e_ds"; + private static final String VAR_AUTH = "_s3e_auth"; + + // Identifiers + private static final Identifier ID_USE_FIPS = Identifier.of("UseFIPS"); + private static final Identifier ID_USE_DUAL_STACK = Identifier.of("UseDualStack"); + private static final Identifier ID_DISABLE_S3EXPRESS_SESSION_AUTH = Identifier.of("DisableS3ExpressSessionAuth"); + private static final Identifier ID_AUTH_SCHEMES = Identifier.of("authSchemes"); + private static final Identifier ID_NAME = Identifier.of("name"); + private static final Identifier ID_BACKEND = Identifier.of("backend"); + + // Auth scheme values + private static final String AUTH_SIGV4 = "sigv4"; + private static final String AUTH_SIGV4_S3EXPRESS = "sigv4-s3express"; + private static final Literal AUTH_NAME_TEMPLATE = Literal.stringLiteral(Template.fromString("{" + VAR_AUTH + "}")); + + // Metrics + private int rewriteCount = 0; + private int totalCount = 0; + + private S3ExpressEndpointTransform() {} + + static S3ExpressEndpointTransform create() { + return new S3ExpressEndpointTransform(); + } + + int getRewriteCount() { + return rewriteCount; + } + + int getTotalCount() { + return totalCount; + } + + @Override + public Rule endpointRule(EndpointRule er) { + Endpoint endpoint = er.getEndpoint(); + String url = extractUrlString(endpoint.getUrl()); + if (url == null) { + return er; + } + + boolean isS3ExpressUrl = S3ExpressUrlCanonicalizer.isS3ExpressUrl(url); + boolean isS3ExpressBackend = hasS3ExpressBackend(endpoint); + + if (!isS3ExpressUrl && !isS3ExpressBackend) { + return er; + } + + totalCount++; + + // Custom endpoint with S3Express backend: just canonicalize auth + if (isS3ExpressBackend && !isS3ExpressUrl) { + return rewriteS3ExpressAuth(er); + } + + // Standard S3Express URL: rewrite URL pattern + return rewriteS3ExpressUrl(er, url); + } + + private Rule rewriteS3ExpressAuth(EndpointRule rule) { + Endpoint endpoint = rule.getEndpoint(); + Map newProps = canonicalizeAuthScheme(endpoint.getProperties()); + + if (newProps == endpoint.getProperties()) { + return rule; + } + + rewriteCount++; + List conditions = new ArrayList<>(rule.getConditions()); + conditions.add(createAuthIte()); + + return Rule.builder() + .conditions(conditions) + .endpoint(Endpoint.builder() + .url(endpoint.getUrl()) + .headers(endpoint.getHeaders()) + .properties(newProps) + .build()); + } + + // Adds: _s3e_fips = ite(UseFIPS, "-fips", ""), _s3e_ds = ite(UseDualStack, ".dualstack", ""), _s3e_auth = ... + // Then, rewrites URL to use {_s3e_fips} and {_s3e_ds}, and auth schemes to use {_s3e_auth}. + private Rule rewriteS3ExpressUrl(EndpointRule rule, String url) { + S3ExpressUrlCanonicalizer.CanonicalizedUrl canonicalized = S3ExpressUrlCanonicalizer.canonicalize(url); + if (canonicalized == null) { + return rule; + } + + rewriteCount++; + Endpoint endpoint = rule.getEndpoint(); + + // Note: _s3e_auth could technically be omitted for control plane, but including it reduces sifted + // BDD size, so just keeping it as-is for now. + List conditions = new ArrayList<>(rule.getConditions()); + conditions.add(createIte(VAR_FIPS, Expression.getReference(ID_USE_FIPS), "-fips", "")); + conditions.add(createIte(VAR_DS, Expression.getReference(ID_USE_DUAL_STACK), ".dualstack", "")); + conditions.add(createAuthIte()); + + Map newProps = canonicalized.isBucketPattern() + ? canonicalizeAuthScheme(endpoint.getProperties()) + : endpoint.getProperties(); + + return Rule.builder() + .conditions(conditions) + .endpoint(Endpoint.builder() + .url(Expression.of(canonicalized.toCanonicalUrl())) + .headers(endpoint.getHeaders()) + .properties(newProps) + .build()); + } + + // Creates: {varName} = ite(condition, trueVal, falseVal) + private Condition createIte(String varName, Expression condition, String trueVal, String falseVal) { + return Condition.builder() + .fn(Ite.ofStrings(condition, trueVal, falseVal)) + .result(varName) + .build(); + } + + // Creates: _s3e_auth = ite(coalesce(DisableS3ExpressSessionAuth, false), "sigv4", "sigv4-s3express") + private Condition createAuthIte() { + Expression disabled = Coalesce.ofExpressions( + Expression.getReference(ID_DISABLE_S3EXPRESS_SESSION_AUTH), + Expression.of(false)); + return createIte(VAR_AUTH, disabled, AUTH_SIGV4, AUTH_SIGV4_S3EXPRESS); + } + + // Checks for: backend = "S3Express" + private boolean hasS3ExpressBackend(Endpoint endpoint) { + Literal backend = endpoint.getProperties().get(ID_BACKEND); + if (backend == null) { + return false; + } + return backend.asStringLiteral() + .filter(Template::isStatic) + .map(t -> "S3Express".equalsIgnoreCase(t.expectLiteral())) + .orElse(false); + } + + // Rewrites: authSchemes[].name: "sigv4"/"sigv4-s3express" → "{_s3e_auth}" + private Map canonicalizeAuthScheme(Map properties) { + Literal authSchemes = properties.get(ID_AUTH_SCHEMES); + if (authSchemes == null || !authSchemes.asTupleLiteral().isPresent()) { + return properties; + } + + List schemes = authSchemes.asTupleLiteral().get(); + List newSchemes = new ArrayList<>(schemes.size()); + boolean changed = false; + + for (Literal scheme : schemes) { + if (!scheme.asRecordLiteral().isPresent()) { + newSchemes.add(scheme); + continue; + } + + Map record = scheme.asRecordLiteral().get(); + Literal nameLit = record.get(ID_NAME); + String name = null; + if (nameLit != null && nameLit.asStringLiteral().isPresent()) { + Template template = nameLit.asStringLiteral().get(); + if (template.isStatic()) { + name = template.expectLiteral(); + } + } + + if (AUTH_SIGV4.equals(name) || AUTH_SIGV4_S3EXPRESS.equals(name)) { + Map newRecord = new LinkedHashMap<>(record); + newRecord.put(ID_NAME, AUTH_NAME_TEMPLATE); + newSchemes.add(Literal.recordLiteral(newRecord)); + changed = true; + } else { + newSchemes.add(scheme); + } + } + + if (!changed) { + return properties; + } + + Map newProps = new LinkedHashMap<>(properties); + newProps.put(ID_AUTH_SCHEMES, Literal.tupleLiteral(newSchemes)); + return newProps; + } + + private String extractUrlString(Expression urlExpr) { + return urlExpr.toNode().asStringNode().map(StringNode::getValue).orElse(null); + } +} diff --git a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/s3/S3ExpressUrlCanonicalizer.java b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/s3/S3ExpressUrlCanonicalizer.java new file mode 100644 index 00000000000..6a393c55d69 --- /dev/null +++ b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/s3/S3ExpressUrlCanonicalizer.java @@ -0,0 +1,151 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.aws.s3; + +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import software.amazon.smithy.utils.SmithyInternalApi; + +/** + * Handles URL pattern matching and canonicalization for S3Express endpoints. + * + *

S3Express URLs have multiple variants based on FIPS and DualStack settings. + * This class detects these variants and produces canonical URLs that use ITE + * variable references instead of hardcoded segments. + * + *

Supported patterns:

+ *
    + *
  • Control plane: {@code s3express-control[-fips][.dualstack].{Region}...}
  • + *
  • Bucket operations: {@code s3express[-fips]-{AZ}[.dualstack].{Region}...}
  • + *
+ */ +@SmithyInternalApi +final class S3ExpressUrlCanonicalizer { + + // ITE variable names injected by S3TreeRewriter + private static final String VAR_FIPS = "_s3e_fips"; + private static final String VAR_DS = "_s3e_ds"; + + // URL patterns ordered most specific first to avoid partial matches + private static final UrlPattern[] PATTERNS = { + // Control plane: s3express-control[-fips][.dualstack].{Region} + new UrlPattern("(s3express-control)-fips\\.dualstack\\.(.+)$", false), + new UrlPattern("(s3express-control)-fips\\.(?!dualstack)(.+)$", false), + new UrlPattern("(s3express-control)\\.dualstack\\.(.+)$", false), + new UrlPattern("(s3express-control)\\.(?!dualstack)(.+)$", false), + // Bucket: s3express[-fips]-{AZ}[.dualstack].{Region} + new UrlPattern("(s3express)-fips-([^.]+)\\.dualstack\\.(.+)$", true), + new UrlPattern("(s3express)-fips-([^.]+)\\.(?!dualstack)(.+)$", true), + new UrlPattern("(s3express)-([^.]+)\\.dualstack\\.(.+)$", true), + new UrlPattern("(s3express)-([^.]+)\\.(?!dualstack)(.+)$", true), + }; + + private S3ExpressUrlCanonicalizer() {} + + /** + * Checks if a URL is an S3Express URL that can be canonicalized. + * + * @param url The URL to check. + * @return true if the URL contains S3Express patterns. + */ + static boolean isS3ExpressUrl(String url) { + return url != null && url.contains("s3express"); + } + + /** + * Attempts to match and canonicalize an S3Express URL. + * + * @param url The URL to canonicalize. + * @return A {@link CanonicalizedUrl} if the URL matched a pattern, or null if no match. + */ + static CanonicalizedUrl canonicalize(String url) { + if (url == null) { + return null; + } + for (UrlPattern pattern : PATTERNS) { + Matcher m = pattern.pattern.matcher(url); + if (m.find()) { + return new CanonicalizedUrl(url, m, pattern.isBucketPattern); + } + } + return null; + } + + /** + * Holds a regex pattern and whether it matches bucket-level operations. + */ + private static final class UrlPattern { + final Pattern pattern; + final boolean isBucketPattern; + + UrlPattern(String regex, boolean isBucketPattern) { + this.pattern = Pattern.compile(regex); + this.isBucketPattern = isBucketPattern; + } + } + + /** + * Represents a successfully matched and canonicalized S3Express URL. + */ + static final class CanonicalizedUrl { + private final String prefix; + private final String service; + private final String az; // null for control plane patterns + private final String regionSuffix; + private final boolean isBucketPattern; + + private CanonicalizedUrl(String url, Matcher m, boolean isBucketPattern) { + this.prefix = url.substring(0, m.start()); + this.isBucketPattern = isBucketPattern; + if (isBucketPattern) { + this.service = m.group(1); + this.az = m.group(2); + this.regionSuffix = m.group(3); + } else { + this.service = m.group(1); + this.az = null; + this.regionSuffix = m.group(2); + } + } + + /** + * Returns whether this is a bucket-level URL pattern (vs control plane). + * + * @return true if bucket pattern. + */ + public boolean isBucketPattern() { + return isBucketPattern; + } + + /** + * Builds the canonicalized URL string using ITE variable references. + * + *

Example: {@code s3express-fips-use1-az1.dualstack.us-east-1...} → + * {@code s3express{_s3e_fips}-use1-az1{_s3e_ds}.us-east-1...} + * + * @return The canonicalized URL string. + */ + public String toCanonicalUrl() { + if (!isBucketPattern) { + return String.format("%s%s{%s}{%s}.%s", + prefix, + service, + VAR_FIPS, + VAR_DS, + regionSuffix); + } else if (az == null) { + throw new IllegalStateException("az must be non-null for bucket patterns"); + } else { + return String.format("%s%s{%s}-%s{%s}.%s", + prefix, + service, + VAR_FIPS, + az, + VAR_DS, + regionSuffix); + } + } + } +} diff --git a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/s3/S3RegionUnifierTransform.java b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/s3/S3RegionUnifierTransform.java new file mode 100644 index 00000000000..a3f4252da11 --- /dev/null +++ b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/s3/S3RegionUnifierTransform.java @@ -0,0 +1,405 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.aws.s3; + +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import software.amazon.smithy.rulesengine.aws.language.functions.AwsPartition; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Reference; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Template; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Coalesce; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.FunctionNode; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.GetAttr; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.IsSet; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Ite; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.LibraryFunction; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.StringEquals; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.RecordLiteral; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.StringLiteral; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; +import software.amazon.smithy.rulesengine.logic.cfg.TreeMapper; + +/** + * Unifies region references across S3 endpoint rules. + * + *

This transform solves the problem where the same logical region appears in + * syntactically different forms: + *

    + *
  • {@code Region} - the input parameter
  • + *
  • {@code bucketArn#region} - region from a parsed bucket ARN
  • + *
  • Hardcoded values like {@code us-east-1} (for aws-global)
  • + *
+ * + *

The transform injects computed variables that unify these: + *

    + *
  • {@code _effective_std_region} = ite(Region == "aws-global", "us-east-1", Region)
  • + *
  • {@code _effective_arn_region} = ite(coalesce(UseArnRegion, true), bucketArn#region, Region)
  • + *
+ */ +final class S3RegionUnifierTransform extends TreeMapper { + + // Computed variable names + static final String VAR_EFFECTIVE_ARN_REGION = "_effective_arn_region"; + static final String VAR_EFFECTIVE_STD_REGION = "_effective_std_region"; + + // Identifiers + private static final Identifier ID_BUCKET_ARN = Identifier.of("bucketArn"); + private static final Identifier ID_USE_ARN_REGION = Identifier.of("UseArnRegion"); + private static final Identifier ID_REGION = Identifier.of("Region"); + private static final Identifier ID_SIGNING_REGION = Identifier.of("signingRegion"); + + // Function names + private static final String AWS_PARSE_ARN = "aws.parseArn"; + private static final String AWS_PARTITION = "aws.partition"; + private static final String IS_VALID_HOST_LABEL = "isValidHostLabel"; + + // Scope tracking + private boolean inBucketArnScope = false; + private boolean signingRegionDefined = false; + + private int rewriteCount = 0; + + private S3RegionUnifierTransform() {} + + static S3RegionUnifierTransform create() { + return new S3RegionUnifierTransform(); + } + + /** + * Returns the number of region references that were unified. + * + * @return rewrite count. + */ + int getRewriteCount() { + return rewriteCount; + } + + @Override + public Rule treeRule(TreeRule tr) { + // Save scope state before descending + boolean savedArnScope = inBucketArnScope; + boolean savedSigningScope = signingRegionDefined; + + Rule result = super.treeRule(tr); + + // Restore scope state when exiting branch + inBucketArnScope = savedArnScope; + signingRegionDefined = savedSigningScope; + + return result; + } + + @Override + public List conditions(Rule rule, List conditions) { + List result = new ArrayList<>(conditions.size() + 2); + + for (Condition cond : conditions) { + Condition transformed = condition(rule, cond); + if (transformed == null) { + continue; + } + + result.add(transformed); + + // Inject _signing_region after isSet(Region) + if (!signingRegionDefined && isIsSetRegion(transformed)) { + result.add(createSigningRegionIte()); + signingRegionDefined = true; + } else if (isBucketArnBinding(transformed)) { + // Inject _effective_region after bucketArn binding + result.add(createEffectiveRegionIte()); + inBucketArnScope = true; + } + } + + return result; + } + + @Override + public Condition condition(Rule rule, Condition cond) { + // In ARN scope, rewrite bucketArn#region in partition/validation calls + if (inBucketArnScope) { + Condition rewritten = rewriteBucketArnRegionInCondition(cond); + if (rewritten != cond) { + return rewritten; + } + } + + return super.condition(rule, cond); + } + + @Override + public Expression error(ErrorRule er, Expression e) { + // Don't rewrite error messages + return e; + } + + @Override + public Literal stringLiteral(StringLiteral str) { + Template template = str.value(); + + // Handle static URL strings with region patterns + if (template.isStatic()) { + String value = template.expectLiteral(); + String rewritten = rewriteUrlRegionPatterns(value); + if (rewritten != null) { + return Literal.stringLiteral(Template.fromString(rewritten)); + } + return str; + } + + // Handle dynamic templates: check for bucketArn#region + return rewriteDynamicTemplate(str, template); + } + + @Override + public Literal recordLiteral(RecordLiteral record) { + Map members = record.members(); + Map newMembers = new LinkedHashMap<>(); + boolean changed = false; + + for (Map.Entry entry : members.entrySet()) { + Identifier key = entry.getKey(); + Literal value = entry.getValue(); + Literal rewritten; + + // Special handling for signingRegion property + if (ID_SIGNING_REGION.equals(key) && signingRegionDefined) { + rewritten = rewriteSigningRegionValue(value); + } else { + rewritten = (Literal) expression(value); + } + + newMembers.put(key, rewritten); + if (rewritten != value) { + changed = true; + } + } + + return changed ? Literal.recordLiteral(newMembers) : record; + } + + // ========== Condition detection helpers ========== + + // Matches: isSet(Region) + private boolean isIsSetRegion(Condition cond) { + if (cond.getResult().isPresent()) { + return false; + } + + LibraryFunction fn = cond.getFunction(); + if (!fn.getFunctionDefinition().equals(IsSet.getDefinition()) || fn.getArguments().isEmpty()) { + return false; + } + + Expression arg = fn.getArguments().get(0); + return arg instanceof Reference && ID_REGION.equals(((Reference) arg).getName()); + } + + // Matches: bucketArn = aws.parseArn(...) + private boolean isBucketArnBinding(Condition cond) { + return cond.getResult().isPresent() + && ID_BUCKET_ARN.equals(cond.getResult().get()) + && AWS_PARSE_ARN.equals(cond.getFunction().getName()); + } + + // Matches: bucketArn#region + private boolean isBucketArnRegion(Expression expr) { + if (!(expr instanceof GetAttr)) { + return false; + } + GetAttr getAttr = (GetAttr) expr; + List args = getAttr.getArguments(); + if (args.isEmpty() || !(args.get(0) instanceof Reference)) { + return false; + } + Reference ref = (Reference) args.get(0); + if (!ID_BUCKET_ARN.equals(ref.getName())) { + return false; + } + List path = getAttr.getPath(); + return path.size() == 1 + && path.get(0) instanceof GetAttr.Part.Key + && "region".equals(((GetAttr.Part.Key) path.get(0)).key().toString()); + } + + // ========== ITE condition creation ========== + + // Creates: _effective_std_region = ite(Region == "aws-global", "us-east-1", Region) + private Condition createSigningRegionIte() { + Expression isGlobal = StringEquals.ofExpressions(Expression.getReference(ID_REGION), "aws-global"); + Ite ite = Ite.ofExpressions(isGlobal, + Expression.of("us-east-1"), + Expression.getReference(ID_REGION)); + return Condition.builder().fn(ite).result(VAR_EFFECTIVE_STD_REGION).build(); + } + + /** + * Creates the effective region ITE for ARN scope. + * + *

This is only called after bucketArn is successfully bound, so we know + * the bucket IS an ARN. The ITE selects between the ARN's region and the + * input region based on UseArnRegion (defaulting to true). + */ + private Condition createEffectiveRegionIte() { + Expression useArnRegion = Coalesce.ofExpressions( + Expression.getReference(ID_USE_ARN_REGION), + Expression.of(true)); + Expression arnRegion = GetAttr.ofExpressions(Expression.getReference(ID_BUCKET_ARN), "region"); + Expression inputRegion = Expression.getReference(ID_REGION); + + Ite ite = Ite.ofExpressions(useArnRegion, arnRegion, inputRegion); + return Condition.builder().fn(ite).result(VAR_EFFECTIVE_ARN_REGION).build(); + } + + // ========== Region rewriting ========== + + private Condition rewriteBucketArnRegionInCondition(Condition cond) { + LibraryFunction fn = cond.getFunction(); + String fnName = fn.getName(); + + if (!AWS_PARTITION.equals(fnName) && !IS_VALID_HOST_LABEL.equals(fnName)) { + return cond; + } + if (fn.getArguments().isEmpty() || !isBucketArnRegion(fn.getArguments().get(0))) { + return cond; + } + + rewriteCount++; + List newArgs = new ArrayList<>(fn.getArguments()); + newArgs.set(0, Expression.getReference(Identifier.of(VAR_EFFECTIVE_ARN_REGION))); + + LibraryFunction newFn = fn.getFunctionDefinition() + .createFunction(FunctionNode.ofExpressions(fnName, newArgs.toArray(new Expression[0]))); + + return cond.toBuilder().fn(newFn).build(); + } + + /** + * Rewrites static URL strings to unify region patterns. + * + *

Pattern order matters: more specific patterns must come first to avoid + * partial matches. + */ + private String rewriteUrlRegionPatterns(String url) { + if (!signingRegionDefined) { + return null; + } + + String targetVar = inBucketArnScope ? VAR_EFFECTIVE_ARN_REGION : VAR_EFFECTIVE_STD_REGION; + String result = url; + boolean changed = false; + + // Order matters: replace more specific patterns first + if (result.contains(".us-east-1.")) { + result = result.replace(".us-east-1.", ".{" + targetVar + "}."); + changed = true; + } + if (result.contains(".{Region}.")) { + result = result.replace(".{Region}.", ".{" + targetVar + "}."); + changed = true; + } + if (result.contains("{Region}")) { + result = result.replace("{Region}", "{" + targetVar + "}"); + changed = true; + } + if (inBucketArnScope && result.contains("{bucketArn#region}")) { + result = result.replace("{bucketArn#region}", "{" + VAR_EFFECTIVE_ARN_REGION + "}"); + changed = true; + } + + if (changed) { + rewriteCount++; + return result; + } + return null; + } + + /** + * Rewrites dynamic template strings to unify region references. + */ + private Literal rewriteDynamicTemplate(StringLiteral str, Template template) { + List parts = template.getParts(); + StringBuilder sb = new StringBuilder(); + boolean changed = false; + String targetVar = inBucketArnScope ? VAR_EFFECTIVE_ARN_REGION : VAR_EFFECTIVE_STD_REGION; + + for (Template.Part part : parts) { + if (part instanceof Template.Dynamic) { + Expression expr = ((Template.Dynamic) part).toExpression(); + if (isBucketArnRegion(expr)) { + sb.append("{").append(VAR_EFFECTIVE_ARN_REGION).append("}"); + rewriteCount++; + changed = true; + continue; + } + if (signingRegionDefined && expr instanceof Reference + && ID_REGION.equals(((Reference) expr).getName())) { + sb.append("{").append(targetVar).append("}"); + rewriteCount++; + changed = true; + continue; + } + sb.append(part); + } else if (part instanceof Template.Literal) { + String literal = ((Template.Literal) part).getValue(); + if (signingRegionDefined && literal.contains(".us-east-1.")) { + literal = literal.replace(".us-east-1.", ".{" + targetVar + "}."); + rewriteCount++; + changed = true; + } + sb.append(literal); + } else { + sb.append(part); + } + } + + return changed ? Literal.stringLiteral(Template.fromString(sb.toString())) : str; + } + + private Literal rewriteSigningRegionValue(Literal value) { + if (!value.asStringLiteral().isPresent()) { + return (Literal) expression(value); + } + + Template template = value.asStringLiteral().get(); + String targetVar = inBucketArnScope ? VAR_EFFECTIVE_ARN_REGION : VAR_EFFECTIVE_STD_REGION; + + // Dynamic template: {Region} or {bucketArn#region} + if (!template.isStatic()) { + List parts = template.getParts(); + if (parts.size() == 1 && parts.get(0) instanceof Template.Dynamic) { + Expression expr = ((Template.Dynamic) parts.get(0)).toExpression(); + if (isBucketArnRegion(expr) + || (expr instanceof Reference && ID_REGION.equals(((Reference) expr).getName()))) { + rewriteCount++; + return Literal.stringLiteral(Template.fromString("{" + targetVar + "}")); + } + } + return (Literal) expression(value); + } + + // Static region value + String staticValue = template.expectLiteral(); + if (isKnownRegion(staticValue)) { + rewriteCount++; + return Literal.stringLiteral(Template.fromString("{" + targetVar + "}")); + } + + return value; + } + + private boolean isKnownRegion(String value) { + return value != null && ("aws-global".equals(value) || AwsPartition.findPartition(value) != null); + } +} diff --git a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/s3/S3TreeRewriter.java b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/s3/S3TreeRewriter.java new file mode 100644 index 00000000000..8c00d7a6880 --- /dev/null +++ b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/s3/S3TreeRewriter.java @@ -0,0 +1,64 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.aws.s3; + +import java.util.logging.Logger; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.utils.SmithyInternalApi; + +/** + * Rewrites S3 endpoint rules to create a dramatically smaller and more efficient BDD. + * + *

This is a BDD pre-processing transform that makes the decision tree larger but enables dramatically better + * BDD compilation. It solves the "SSA Trap" problem where semantically identical operations appear as syntactically + * different expressions. + * + *

This class composes three separate transforms: + *

    + *
  1. {@link S3AzCanonicalizerTransform} - Canonicalizes AZ extraction: + * {@code substring(Bucket, N, M)} → {@code split(Bucket, "--")[1]}
  2. + *
  3. {@link S3RegionUnifierTransform} - Unifies region references: + * {@code Region}/{@code bucketArn#region} → {@code _signing_region}/{@code _effective_region}
  4. + *
  5. {@link S3ExpressEndpointTransform} - Canonicalizes S3Express endpoints: + * FIPS/DualStack URL variants → ITE-computed segments
  6. + *
+ * + *

Each transform is independent and can be applied separately if needed. + */ +@SmithyInternalApi +public final class S3TreeRewriter { + private static final Logger LOGGER = Logger.getLogger(S3TreeRewriter.class.getName()); + + private S3TreeRewriter() {} + + /** + * Transforms the given endpoint rule set by applying all S3 canonicalization transforms. + * + * @param ruleSet Rules to transform. + * @return the transformed rule set. + */ + public static EndpointRuleSet transform(EndpointRuleSet ruleSet) { + // Pass 1: Canonicalize AZ extraction + S3AzCanonicalizerTransform azTransform = S3AzCanonicalizerTransform.create(); + EndpointRuleSet afterAz = azTransform.endpointRuleSet(ruleSet); + + // Pass 2: Unify region references + S3RegionUnifierTransform regionTransform = S3RegionUnifierTransform.create(); + EndpointRuleSet afterRegion = regionTransform.endpointRuleSet(afterAz); + + // Pass 3: Canonicalize S3Express endpoints + S3ExpressEndpointTransform s3ExpressTransform = S3ExpressEndpointTransform.create(); + EndpointRuleSet result = s3ExpressTransform.endpointRuleSet(afterRegion); + + LOGGER.info(() -> String.format( + "S3 tree rewriter: %d AZ, %d region, %d/%d S3Express rewrites", + azTransform.getRewriteCount(), + regionTransform.getRewriteCount(), + s3ExpressTransform.getRewriteCount(), + s3ExpressTransform.getTotalCount())); + + return result; + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilderTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilderTest.java index 8fb61f6b89a..b477efa1432 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilderTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilderTest.java @@ -202,15 +202,21 @@ void createConditionReferenceHandlesBooleanEqualsCanonicalizations() { } @Test - void createConditionReferenceDoesNotCanonicalizeWithoutDefault() { - // Test that booleanEquals(region, false) is not canonicalized (no default) + void createConditionReferenceCanonicalizesEvenWithoutDefault() { + // booleanEquals(X, false) -> booleanEquals(X, true) with negation is a valid + // algebraic transformation regardless of whether the parameter has a default Expression ref = Expression.getReference(Identifier.of("region")); Condition cond = Condition.builder().fn(BooleanEquals.ofExpressions(ref, false)).build(); ConditionReference condRef = builder.createConditionReference(cond); - assertFalse(condRef.isNegated()); - assertEquals(cond.getFunction(), condRef.getCondition().getFunction()); + // Should be canonicalized to booleanEquals(region, true) with negation + assertTrue(condRef.isNegated()); + assertInstanceOf(BooleanEquals.class, condRef.getCondition().getFunction()); + + BooleanEquals fn = (BooleanEquals) condRef.getCondition().getFunction(); + assertEquals(ref, fn.getArguments().get(0)); + assertEquals(Literal.booleanLiteral(true), fn.getArguments().get(1)); } @Test From afb092f280a99a2910d6706cd606e9afd2e8b0ac Mon Sep 17 00:00:00 2001 From: Michael Dowling Date: Thu, 8 Jan 2026 17:16:29 -0600 Subject: [PATCH 09/10] Address PR feedback on gradle and test case --- settings.gradle.kts | 2 -- smithy-aws-endpoints/build.gradle.kts | 3 +-- .../amazon/smithy/rulesengine/logic/cfg/SsaTransformTest.java | 2 -- 3 files changed, 1 insertion(+), 6 deletions(-) diff --git a/settings.gradle.kts b/settings.gradle.kts index bffb67b89a7..f3c9eba093a 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -6,8 +6,6 @@ pluginManagement { } } - - rootProject.name = "smithy" include(":smithy-aws-iam-traits") diff --git a/smithy-aws-endpoints/build.gradle.kts b/smithy-aws-endpoints/build.gradle.kts index 22f6c9843aa..b4cfdb81ccd 100644 --- a/smithy-aws-endpoints/build.gradle.kts +++ b/smithy-aws-endpoints/build.gradle.kts @@ -38,9 +38,8 @@ configurations["itImplementation"].extendsFrom(configurations["testImplementatio configurations["itRuntimeOnly"].extendsFrom(configurations["testRuntimeOnly"]) configurations["itImplementation"].extendsFrom(s3Model) -// Configure IT source set to compile with current JDK (17+) +// Configure IT source set to compile with JDK (17+) since the models it uses require it. tasks.named("compileItJava") { - // Use current Java version instead of hardcoding to allow flexibility in CI sourceCompatibility = "17" targetCompatibility = "17" } diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransformTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransformTest.java index 0b3a8e66e61..7fe4e52b988 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransformTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransformTest.java @@ -303,8 +303,6 @@ void testMultipleBindingsWithUsedVariablesAreSsaRenamed() { .map(Object::toString) .orElse(null); - System.out.println("resultVar1=" + resultVar1); - System.out.println("resultVar2=" + resultVar2); // Both should have bindings (since they're used) assertEquals(true, From 99e402ababad7a72afa6b20aa57c12cc58e2c84a Mon Sep 17 00:00:00 2001 From: Michael Dowling Date: Fri, 9 Jan 2026 17:40:22 -0600 Subject: [PATCH 10/10] Add BDD coverage testing and unref condition tx --- .../aws/language/functions/S3BddTest.java | 38 ++++++++++--- .../analysis/BddCoverageChecker.java | 23 ++++++++ .../rulesengine/traits/EndpointBddTrait.java | 54 +++++++++++++++++++ .../logic/cfg/SsaTransformTest.java | 1 - 4 files changed, 107 insertions(+), 9 deletions(-) diff --git a/smithy-aws-endpoints/src/it/java/software/amazon/smithy/rulesengine/aws/language/functions/S3BddTest.java b/smithy-aws-endpoints/src/it/java/software/amazon/smithy/rulesengine/aws/language/functions/S3BddTest.java index 7fac4904424..fccf607f251 100644 --- a/smithy-aws-endpoints/src/it/java/software/amazon/smithy/rulesengine/aws/language/functions/S3BddTest.java +++ b/smithy-aws-endpoints/src/it/java/software/amazon/smithy/rulesengine/aws/language/functions/S3BddTest.java @@ -4,8 +4,6 @@ */ package software.amazon.smithy.rulesengine.aws.language.functions; -import static org.junit.jupiter.api.Assertions.assertFalse; - import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; @@ -18,6 +16,7 @@ import software.amazon.smithy.model.shapes.ModelSerializer; import software.amazon.smithy.model.shapes.ServiceShape; import software.amazon.smithy.model.shapes.ShapeId; +import software.amazon.smithy.rulesengine.analysis.BddCoverageChecker; import software.amazon.smithy.rulesengine.aws.s3.S3TreeRewriter; import software.amazon.smithy.rulesengine.language.EndpointRuleSet; import software.amazon.smithy.rulesengine.language.evaluation.TestEvaluator; @@ -53,7 +52,6 @@ static void loadS3Model() { @Test void compileToBddWithOptimizations() { // Verify transforms preserve semantics by running all test cases - assertFalse(testCases.isEmpty(), "S3 model should have endpoint test cases"); for (EndpointTestCase testCase : testCases) { TestEvaluator.evaluate(rules, testCase); } @@ -77,23 +75,47 @@ void compileToBddWithOptimizations() { CostOptimization cost = CostOptimization.builder().cfg(cfg).build(); EndpointBddTrait optimizedTrait = cost.apply(siftedTrait); sb.append("After cost opt - nodes: ").append(optimizedTrait.getBdd().getNodeCount()).append("\n"); + System.out.println("Unreferenced BDD conditions before dead condition elimination: " + + new BddCoverageChecker(optimizedTrait).getUnreferencedConditions()); + + EndpointBddTrait finalizedTrait = optimizedTrait.removeUnreferencedConditions(); + System.out.println("Unreferenced BDD conditions after dead condition elimination: " + + new BddCoverageChecker(optimizedTrait).getUnreferencedConditions()); // Print conditions for analysis sb.append("\n=== CONDITIONS ===\n"); - for (int i = 0; i < trait.getConditions().size(); i++) { - sb.append(i).append(": ").append(trait.getConditions().get(i)).append("\n"); + for (int i = 0; i < finalizedTrait.getConditions().size(); i++) { + sb.append(i).append(": ").append(finalizedTrait.getConditions().get(i)).append("\n"); } // Print results (endpoints) for analysis sb.append("\n=== RESULTS ===\n"); - for (int i = 0; i < trait.getResults().size(); i++) { - sb.append(i).append(": ").append(trait.getResults().get(i)).append("\n"); + for (int i = 0; i < finalizedTrait.getResults().size(); i++) { + sb.append(i).append(": ").append(finalizedTrait.getResults().get(i)).append("\n"); } System.out.println(sb); + // Verify transforms preserve semantics by running all test cases on the BDD -and- ensuring 100% coverage. + BddCoverageChecker coverageCheckerBdd = new BddCoverageChecker(finalizedTrait); + for (EndpointTestCase testCase : testCases) { + coverageCheckerBdd.evaluateTestCase(testCase); + } + + if (coverageCheckerBdd.getConditionCoverage() < 100) { + throw new RuntimeException("Condition coverage < 100%: " + + coverageCheckerBdd.getConditionCoverage() + + " : " + coverageCheckerBdd.getUnevaluatedConditions()); + } + + if (coverageCheckerBdd.getResultCoverage() < 100) { + throw new RuntimeException("Result coverage < 100%: " + + coverageCheckerBdd.getResultCoverage() + + " : " + coverageCheckerBdd.getUnevaluatedResults()); + } + // Write model with BDD trait to build directory - writeModelWithBddTrait(optimizedTrait); + writeModelWithBddTrait(finalizedTrait); } private void writeModelWithBddTrait(EndpointBddTrait bddTrait) { diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/analysis/BddCoverageChecker.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/analysis/BddCoverageChecker.java index d2b3443163d..e1b10e5130f 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/analysis/BddCoverageChecker.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/analysis/BddCoverageChecker.java @@ -138,6 +138,29 @@ public double getResultCoverage() { return relevantResults == 0 ? 100.0 : (100.0 * coveredRelevantResults / relevantResults); } + /** + * Returns conditions that exist in the conditions list but are not referenced by any BDD node. + * + * @return set of unreferenced conditions + */ + public Set getUnreferencedConditions() { + BitSet referencedByBdd = new BitSet(conditions.size()); + for (int i = 0; i < bdd.getNodeCount(); i++) { + int varIdx = bdd.getVariable(i); + if (varIdx >= 0 && varIdx < conditions.size()) { + referencedByBdd.set(varIdx); + } + } + + Set unreferenced = new HashSet<>(); + for (int i = 0; i < conditions.size(); i++) { + if (!referencedByBdd.get(i)) { + unreferenced.add(conditions.get(i)); + } + } + return unreferenced; + } + // Evaluator that tracks what gets visited during BDD evaluation. private final class TestEvaluator implements ConditionEvaluator { private final RuleEvaluator ruleEvaluator; diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/EndpointBddTrait.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/EndpointBddTrait.java index 52ffd2b4f36..997532932f7 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/EndpointBddTrait.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/EndpointBddTrait.java @@ -100,6 +100,60 @@ public static EndpointBddTrait from(Cfg cfg) { .build(); } + /** + * Removes conditions that are not referenced by any BDD node and remaps indices. + * + *

This should be called after all BDD optimizations (sifting, cost optimization) are complete. + * A condition could become unreferenced if it is optimized away by the BDD compiler because hi==lo. + * + * @return a new trait with unreferenced conditions removed, or this trait if none were removed + */ + public EndpointBddTrait removeUnreferencedConditions() { + // Find which conditions are actually referenced + int[] refCount = new int[conditions.size()]; + for (int i = 0; i < bdd.getNodeCount(); i++) { + int var = bdd.getVariable(i); + if (var >= 0 && var < refCount.length) { + refCount[var]++; + } + } + + // Build mapping from old to new indices + int[] oldToNew = new int[conditions.size()]; + List newConditions = new ArrayList<>(); + for (int i = 0; i < conditions.size(); i++) { + if (refCount[i] > 0) { + oldToNew[i] = newConditions.size(); + newConditions.add(conditions.get(i)); + } else { + oldToNew[i] = -1; + } + } + + // No change needed + if (newConditions.size() == conditions.size()) { + return this; + } + + // Remap BDD nodes + int[] newNodes = new int[bdd.getNodeCount() * 3]; + for (int i = 0; i < bdd.getNodeCount(); i++) { + int oldVar = bdd.getVariable(i); + newNodes[i * 3] = oldVar < 0 ? oldVar : oldToNew[oldVar]; + newNodes[i * 3 + 1] = bdd.getHigh(i); + newNodes[i * 3 + 2] = bdd.getLow(i); + } + + return builder() + .version(version) + .parameters(parameters) + .conditions(newConditions) + .results(results) + .bdd(new Bdd(bdd + .getRootRef(), newConditions.size(), bdd.getResultCount(), bdd.getNodeCount(), newNodes)) + .build(); + } + /** * Gets the parameters for the endpoint rules. * diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransformTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransformTest.java index 7fe4e52b988..4e0510e55a8 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransformTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransformTest.java @@ -303,7 +303,6 @@ void testMultipleBindingsWithUsedVariablesAreSsaRenamed() { .map(Object::toString) .orElse(null); - // Both should have bindings (since they're used) assertEquals(true, resultRule1.getConditions().get(0).getResult().isPresent(),