diff --git a/.changes/next-release/feature-1c36b965bf2f51d85aa0171be6206ae2029137c4.json b/.changes/next-release/feature-1c36b965bf2f51d85aa0171be6206ae2029137c4.json new file mode 100644 index 00000000000..3153bd1730d --- /dev/null +++ b/.changes/next-release/feature-1c36b965bf2f51d85aa0171be6206ae2029137c4.json @@ -0,0 +1,7 @@ +{ + "type": "feature", + "description": "Implement rules engine ITE fn and S3 tree transform", + "pull_requests": [ + "[#2903](https://github.com/smithy-lang/smithy/pull/2903)" + ] +} diff --git a/.changes/next-release/feature-5e4b03b74b1054ac2632fcae533d727eace1d26e.json b/.changes/next-release/feature-5e4b03b74b1054ac2632fcae533d727eace1d26e.json new file mode 100644 index 00000000000..cd940a23a40 --- /dev/null +++ b/.changes/next-release/feature-5e4b03b74b1054ac2632fcae533d727eace1d26e.json @@ -0,0 +1,7 @@ +{ + "type": "feature", + "description": "Improve BDD sifting (2x speed, more reduction)", + "pull_requests": [ + "[#2890](https://github.com/smithy-lang/smithy/pull/2890)" + ] +} diff --git a/docs/source-2.0/additional-specs/rules-engine/standard-library.rst b/docs/source-2.0/additional-specs/rules-engine/standard-library.rst index 6950b914468..8cb5f7d51d8 100644 --- a/docs/source-2.0/additional-specs/rules-engine/standard-library.rst +++ b/docs/source-2.0/additional-specs/rules-engine/standard-library.rst @@ -208,6 +208,103 @@ The following example uses ``isValidHostLabel`` to check if the value of the } +.. _rules-engine-standard-library-ite: + +``ite`` function +================ + +Summary + An if-then-else function that returns one of two values based on a boolean condition. 
+Argument types + * condition: ``bool`` + * trueValue: ``T`` or ``option<T>`` + * falseValue: ``T`` or ``option<T>`` +Return type + * ``ite(bool, T, T)`` → ``T`` (both non-optional, result is non-optional) + * ``ite(bool, T, option<T>)`` → ``option<T>`` (any optional makes result optional) + * ``ite(bool, option<T>, T)`` → ``option<T>`` (any optional makes result optional) + * ``ite(bool, option<T>, option<T>)`` → ``option<T>`` (both optional, result is optional) +Since + 1.1 + +The ``ite`` (if-then-else) function evaluates a boolean condition and returns one of two values based on +the result. If the condition is ``true``, it returns ``trueValue``; if ``false``, it returns ``falseValue``. +This function is particularly useful for computing conditional values without branching in the rule tree, resulting +in fewer result nodes, and enabling better BDD optimizations as a result of reduced fragmentation. + +.. important:: + Both ``trueValue`` and ``falseValue`` must have the same base type ``T``. The result type follows + the "least upper bound" rule: if either branch is optional, the result is optional. + +The following example uses ``ite`` to compute a URL suffix based on whether FIPS is enabled: + +.. code-block:: json + + { + "fn": "ite", + "argv": [ + {"ref": "UseFIPS"}, + "-fips", + "" + ], + "assign": "fipsSuffix" + } + +The following example uses ``ite`` with ``coalesce`` to handle an optional boolean parameter: + +.. code-block:: json + + { + "fn": "ite", + "argv": [ + { + "fn": "coalesce", + "argv": [ + {"ref": "DisableFeature"}, + false + ] + }, + "disabled", + "enabled" + ], + "assign": "featureState" + } + + +.. _rules-engine-standard-library-ite-examples: + +-------- +Examples +-------- + +The following table shows various inputs and their corresponding outputs for the ``ite`` function: + +.. 
list-table:: + :header-rows: 1 + :widths: 20 25 25 30 + + * - Condition + - True Value + - False Value + - Output + * - ``true`` + - ``"-fips"`` + - ``""`` + - ``"-fips"`` + * - ``false`` + - ``"-fips"`` + - ``""`` + - ``""`` + * - ``true`` + - ``"sigv4"`` + - ``"sigv4-s3express"`` + - ``"sigv4"`` + * - ``false`` + - ``"sigv4"`` + - ``"sigv4-s3express"`` + - ``"sigv4-s3express"`` + + .. _rules-engine-standard-library-not: ``not`` function diff --git a/smithy-aws-endpoints/build.gradle.kts b/smithy-aws-endpoints/build.gradle.kts index c142cd4ee32..b4cfdb81ccd 100644 --- a/smithy-aws-endpoints/build.gradle.kts +++ b/smithy-aws-endpoints/build.gradle.kts @@ -11,10 +11,60 @@ description = "AWS specific components for managing endpoints in Smithy" extra["displayName"] = "Smithy :: AWS Endpoints Components" extra["moduleName"] = "software.amazon.smithy.aws.endpoints" +// Custom configuration for S3 model - kept separate from test classpath to avoid +// polluting other tests with S3 model discovery +val s3Model: Configuration by configurations.creating + dependencies { api(project(":smithy-aws-traits")) api(project(":smithy-diff")) api(project(":smithy-rules-engine")) api(project(":smithy-model")) api(project(":smithy-utils")) + + s3Model("software.amazon.api.models:s3:1.0.11") +} + +// Integration test source set for tests that require the S3 model +// These tests require JDK 17+ due to the S3 model dependency +sourceSets { + create("it") { + compileClasspath += sourceSets["main"].output + sourceSets["test"].output + runtimeClasspath += sourceSets["main"].output + sourceSets["test"].output + } +} + +configurations["itImplementation"].extendsFrom(configurations["testImplementation"]) +configurations["itRuntimeOnly"].extendsFrom(configurations["testRuntimeOnly"]) +configurations["itImplementation"].extendsFrom(s3Model) + +// Configure IT source set to compile with JDK (17+) since the models it uses require it. 
+tasks.named<JavaCompile>("compileItJava") { + sourceCompatibility = "17" + targetCompatibility = "17" +} + +val integrationTest by tasks.registering(Test::class) { + description = "Runs integration tests that require external models like S3" + group = "verification" + testClassesDirs = sourceSets["it"].output.classesDirs + classpath = sourceSets["it"].runtimeClasspath + dependsOn(tasks.jar) + shouldRunAfter(tasks.test) + + // Pass build directory to tests + systemProperty( + "buildDir", + layout.buildDirectory + .get() + .asFile.absolutePath, + ) +} + +tasks.test { + finalizedBy(integrationTest) +} + +tasks.named("check") { + dependsOn(integrationTest) } diff --git a/smithy-aws-endpoints/src/it/java/software/amazon/smithy/rulesengine/aws/language/functions/S3BddTest.java b/smithy-aws-endpoints/src/it/java/software/amazon/smithy/rulesengine/aws/language/functions/S3BddTest.java new file mode 100644 index 00000000000..fccf607f251 --- /dev/null +++ b/smithy-aws-endpoints/src/it/java/software/amazon/smithy/rulesengine/aws/language/functions/S3BddTest.java @@ -0,0 +1,157 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.aws.language.functions; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.List; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import software.amazon.smithy.model.Model; +import software.amazon.smithy.model.node.Node; +import software.amazon.smithy.model.shapes.ModelSerializer; +import software.amazon.smithy.model.shapes.ServiceShape; +import software.amazon.smithy.model.shapes.ShapeId; +import software.amazon.smithy.rulesengine.analysis.BddCoverageChecker; +import software.amazon.smithy.rulesengine.aws.s3.S3TreeRewriter; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.evaluation.TestEvaluator; +import software.amazon.smithy.rulesengine.logic.bdd.CostOptimization; +import software.amazon.smithy.rulesengine.logic.bdd.SiftingOptimization; +import software.amazon.smithy.rulesengine.logic.cfg.Cfg; +import software.amazon.smithy.rulesengine.traits.EndpointBddTrait; +import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait; +import software.amazon.smithy.rulesengine.traits.EndpointTestCase; +import software.amazon.smithy.rulesengine.traits.EndpointTestsTrait; + +/** + * Compiles the rewritten S3 endpoint rules to a BDD, applies sifting and cost optimizations, removes + * unreferenced conditions, and runs the S3 endpoint test cases against the result while requiring 100% + * condition and result coverage. 
+ */ +class S3BddTest { + private static final ShapeId S3_SERVICE_ID = ShapeId.from("com.amazonaws.s3#AmazonS3"); + private static Model model; + private static ServiceShape s3Service; + private static EndpointRuleSet originalRules; + private static EndpointRuleSet rules; + private static List<EndpointTestCase> testCases; + + @BeforeAll + static void loadS3Model() { + model = Model.assembler() + .discoverModels() + .assemble() + .unwrap(); + + s3Service = model.expectShape(S3_SERVICE_ID, ServiceShape.class); + originalRules = s3Service.expectTrait(EndpointRuleSetTrait.class).getEndpointRuleSet(); + rules = S3TreeRewriter.transform(originalRules); 
testCases = s3Service.expectTrait(EndpointTestsTrait.class).getTestCases(); + } + + @Test + void compileToBddWithOptimizations() { + // Verify transforms preserve semantics by running all test cases + for (EndpointTestCase testCase : testCases) { + TestEvaluator.evaluate(rules, testCase); + } + + // Build CFG and compile to BDD + Cfg cfg = Cfg.from(rules); + EndpointBddTrait trait = EndpointBddTrait.from(cfg); + + StringBuilder sb = new StringBuilder(); + sb.append("\n=== BDD STATS ===\n"); + sb.append("Conditions: ").append(trait.getConditions().size()).append("\n"); + sb.append("Results: ").append(trait.getResults().size()).append("\n"); + sb.append("Initial BDD nodes: ").append(trait.getBdd().getNodeCount()).append("\n"); + + // Apply sifting optimization + SiftingOptimization sifting = SiftingOptimization.builder().cfg(cfg).build(); + EndpointBddTrait siftedTrait = sifting.apply(trait); + sb.append("After sifting - nodes: ").append(siftedTrait.getBdd().getNodeCount()).append("\n"); + + // Apply cost optimization + CostOptimization cost = CostOptimization.builder().cfg(cfg).build(); + EndpointBddTrait optimizedTrait = cost.apply(siftedTrait); + sb.append("After cost opt - nodes: ").append(optimizedTrait.getBdd().getNodeCount()).append("\n"); + System.out.println("Unreferenced BDD conditions before dead condition elimination: " + + new BddCoverageChecker(optimizedTrait).getUnreferencedConditions()); + + EndpointBddTrait finalizedTrait = optimizedTrait.removeUnreferencedConditions(); + // Inspect the finalized trait here; checking optimizedTrait again would just repeat the "before" output. 
+ System.out.println("Unreferenced BDD conditions after dead condition elimination: " + + new BddCoverageChecker(finalizedTrait).getUnreferencedConditions()); + + // Print conditions for analysis + sb.append("\n=== CONDITIONS ===\n"); + for (int i = 0; i < finalizedTrait.getConditions().size(); i++) { + sb.append(i).append(": ").append(finalizedTrait.getConditions().get(i)).append("\n"); + } + + // Print results (endpoints) for analysis + sb.append("\n=== RESULTS ===\n"); + for (int i 
= 0; i < finalizedTrait.getResults().size(); i++) { + sb.append(i).append(": ").append(finalizedTrait.getResults().get(i)).append("\n"); + } + + System.out.println(sb); + + // Verify transforms preserve semantics by running all test cases on the BDD -and- ensuring 100% coverage. + BddCoverageChecker coverageCheckerBdd = new BddCoverageChecker(finalizedTrait); + for (EndpointTestCase testCase : testCases) { + coverageCheckerBdd.evaluateTestCase(testCase); + } + + if (coverageCheckerBdd.getConditionCoverage() < 100) { + throw new RuntimeException("Condition coverage < 100%: " + + coverageCheckerBdd.getConditionCoverage() + + " : " + coverageCheckerBdd.getUnevaluatedConditions()); + } + + if (coverageCheckerBdd.getResultCoverage() < 100) { + throw new RuntimeException("Result coverage < 100%: " + + coverageCheckerBdd.getResultCoverage() + + " : " + coverageCheckerBdd.getUnevaluatedResults()); + } + + // Write model with BDD trait to build directory + writeModelWithBddTrait(finalizedTrait); + } + + private void writeModelWithBddTrait(EndpointBddTrait bddTrait) { + String buildDir = System.getProperty("buildDir"); + if (buildDir == null) { + System.out.println("buildDir system property not set, skipping model output"); + return; + } + + // Create updated service with BDD trait instead of RuleSet trait + ServiceShape updatedService = s3Service.toBuilder() + .removeTrait(EndpointRuleSetTrait.ID) + .addTrait(bddTrait) + .build(); + + // Build updated model + Model updatedModel = model.toBuilder() + .removeShape(s3Service.getId()) + .addShape(updatedService) + .build(); + + // Serialize to JSON + ModelSerializer serializer = ModelSerializer.builder().build(); + String json = Node.prettyPrintJson(serializer.serialize(updatedModel)); + + // Write to build directory + Path outputPath = Paths.get(buildDir, "s3-bdd-model.json"); + try { + Path parentDir = outputPath.getParent(); + if (parentDir != null) { + Files.createDirectories(parentDir); + } + Files.writeString(outputPath, 
json); + System.out.println("Wrote S3 BDD model to: " + outputPath); + } catch (IOException e) { + throw new RuntimeException("Failed to write S3 BDD model", e); + } + } +} diff --git a/smithy-aws-endpoints/src/it/java/software/amazon/smithy/rulesengine/aws/language/functions/S3TreeRewriterTest.java b/smithy-aws-endpoints/src/it/java/software/amazon/smithy/rulesengine/aws/language/functions/S3TreeRewriterTest.java new file mode 100644 index 00000000000..48a83c36000 --- /dev/null +++ b/smithy-aws-endpoints/src/it/java/software/amazon/smithy/rulesengine/aws/language/functions/S3TreeRewriterTest.java @@ -0,0 +1,53 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.aws.language.functions; + +import static org.junit.jupiter.api.Assertions.assertFalse; + +import java.util.List; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import software.amazon.smithy.model.Model; +import software.amazon.smithy.model.shapes.ServiceShape; +import software.amazon.smithy.model.shapes.ShapeId; +import software.amazon.smithy.rulesengine.aws.s3.S3TreeRewriter; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.evaluation.TestEvaluator; +import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait; +import software.amazon.smithy.rulesengine.traits.EndpointTestCase; +import software.amazon.smithy.rulesengine.traits.EndpointTestsTrait; + +/** + * Runs the endpoint test cases against the transformed S3 model. We're fixed to a specific version for this test, + * but could periodically bump the version if needed. 
+ */ +class S3TreeRewriterTest { + private static final ShapeId S3_SERVICE_ID = ShapeId.from("com.amazonaws.s3#AmazonS3"); + + private static EndpointRuleSet originalRules; + private static List<EndpointTestCase> testCases; + + @BeforeAll + static void loadS3Model() { + Model model = Model.assembler() + .discoverModels() + .assemble() + .unwrap(); + + ServiceShape s3Service = model.expectShape(S3_SERVICE_ID, ServiceShape.class); + originalRules = s3Service.expectTrait(EndpointRuleSetTrait.class).getEndpointRuleSet(); + testCases = s3Service.expectTrait(EndpointTestsTrait.class).getTestCases(); + } + + @Test + void transformPreservesEndpointTestSemantics() { + assertFalse(testCases.isEmpty(), "S3 model should have endpoint test cases"); + + EndpointRuleSet transformed = S3TreeRewriter.transform(originalRules); + for (EndpointTestCase testCase : testCases) { + TestEvaluator.evaluate(transformed, testCase); + } + } +} diff --git a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/AwsConditionProbability.java b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/AwsConditionProbability.java index 42b2344e8ef..02dbe32fa4e 100644 --- a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/AwsConditionProbability.java +++ b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/AwsConditionProbability.java @@ -26,12 +26,21 @@ public double applyAsDouble(Condition condition) { // Region is almost always provided if (s.contains("isSet(Region)")) { - return 0.95; + return 0.96; } // Endpoint override is rare if (s.contains("isSet(Endpoint)")) { - return 0.1; + return 0.2; + } + + // S3 Express is rare (includes ITE variables from S3TreeRewriter) + if (s.contains("S3Express") || s.contains("--x-s3") + || s.contains("--xa-s3") + || s.contains("s3e_fips") + || s.contains("s3e_ds") + || s.contains("s3e_auth")) { + return 0.001; } // Most isSet checks on optional params succeed moderately @@ -48,11 +57,6 @@ 
public 
double applyAsDouble(Condition condition) { return 0.05; } - // S3 Express is relatively rare - if (s.contains("S3Express") || s.contains("--x-s3") || s.contains("--xa-s3")) { - return 0.1; - } - // ARN-based buckets are uncommon if (s.contains("parseArn") || s.contains("arn:")) { return 0.15; diff --git a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/s3/S3AzCanonicalizerTransform.java b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/s3/S3AzCanonicalizerTransform.java new file mode 100644 index 00000000000..e09c805f7ec --- /dev/null +++ b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/s3/S3AzCanonicalizerTransform.java @@ -0,0 +1,89 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.aws.s3; + +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Reference; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.GetAttr; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.LibraryFunction; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Split; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Substring; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.logic.cfg.TreeMapper; + +/** + * Canonicalizes S3Express availability zone extraction. + * + *

Rewrites position-dependent substring operations to position-independent split operations: + *

{@code
+ * substring(Bucket, N, M) → split(Bucket, "--", 0)[1]
+ * }
+ * + *

This enables BDD sharing across endpoints that extract the AZ from different + * bucket name positions. + */ +final class S3AzCanonicalizerTransform extends TreeMapper { + + private static final Identifier ID_BUCKET = Identifier.of("Bucket"); + private static final Identifier ID_AZ_ID = Identifier.of("s3expressAvailabilityZoneId"); + + private int rewriteCount = 0; + + private S3AzCanonicalizerTransform() {} + + /** + * Creates a new transform instance. + * + * @return a new transform. + */ + static S3AzCanonicalizerTransform create() { + return new S3AzCanonicalizerTransform(); + } + + /** + * Returns the number of AZ extractions that were canonicalized. + * + * @return rewrite count. + */ + int getRewriteCount() { + return rewriteCount; + } + + @Override + public Condition condition(Rule rule, Condition cond) { + if (isAzIdSubstringBinding(cond)) { + rewriteCount++; + return createCanonicalAzCondition(cond); + } + return super.condition(rule, cond); + } + + // Matches: s3expressAvailabilityZoneId = substring(Bucket, N, M) + private static boolean isAzIdSubstringBinding(Condition cond) { + if (!ID_AZ_ID.equals(cond.getResult().orElse(null))) { + return false; + } + + LibraryFunction fn = cond.getFunction(); + if (!(fn instanceof Substring) || fn.getArguments().isEmpty()) { + return false; + } + + Expression target = fn.getArguments().get(0); + return target instanceof Reference && ID_BUCKET.equals(((Reference) target).getName()); + } + + // Creates: s3expressAvailabilityZoneId = split(Bucket, "--", 0)[1] + private static Condition createCanonicalAzCondition(Condition original) { + Split split = Split.ofExpressions( + Expression.getReference(ID_BUCKET), + Expression.of("--"), + Expression.of(0)); + GetAttr azExpr = GetAttr.ofExpressions(split, "[1]"); + return original.toBuilder().fn(azExpr).build(); + } +} diff --git a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/s3/S3ExpressEndpointTransform.java 
b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/s3/S3ExpressEndpointTransform.java new file mode 100644 index 00000000000..bc13f74a9bc --- /dev/null +++ b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/s3/S3ExpressEndpointTransform.java @@ -0,0 +1,232 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.aws.s3; + +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import software.amazon.smithy.model.node.StringNode; +import software.amazon.smithy.rulesengine.language.Endpoint; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Template; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Coalesce; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Ite; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.logic.cfg.TreeMapper; + +/** + * Canonicalizes S3Express endpoint URLs and auth schemes. + * + *

S3Express endpoints have multiple URL variants based on FIPS and DualStack settings. + * This transform rewrites them to use ITE-computed URL segments, enabling BDD sharing + * across all variants. + * + *

Transformations: + *

    + *
  • URL patterns: {@code s3express-fips-{az}.dualstack.{region}} → + * {@code s3express{_fips}-{az}{_ds}.{region}} with ITE-computed variables
  • + *
  • Auth schemes: {@code sigv4}/{@code sigv4-s3express} → {@code {_s3e_auth}} + * with ITE based on DisableS3ExpressSessionAuth
  • + *
+ */ +final class S3ExpressEndpointTransform extends TreeMapper { + + // Computed variable names + private static final String VAR_FIPS = "_s3e_fips"; + private static final String VAR_DS = "_s3e_ds"; + private static final String VAR_AUTH = "_s3e_auth"; + + // Identifiers + private static final Identifier ID_USE_FIPS = Identifier.of("UseFIPS"); + private static final Identifier ID_USE_DUAL_STACK = Identifier.of("UseDualStack"); + private static final Identifier ID_DISABLE_S3EXPRESS_SESSION_AUTH = Identifier.of("DisableS3ExpressSessionAuth"); + private static final Identifier ID_AUTH_SCHEMES = Identifier.of("authSchemes"); + private static final Identifier ID_NAME = Identifier.of("name"); + private static final Identifier ID_BACKEND = Identifier.of("backend"); + + // Auth scheme values + private static final String AUTH_SIGV4 = "sigv4"; + private static final String AUTH_SIGV4_S3EXPRESS = "sigv4-s3express"; + private static final Literal AUTH_NAME_TEMPLATE = Literal.stringLiteral(Template.fromString("{" + VAR_AUTH + "}")); + + // Metrics + private int rewriteCount = 0; + private int totalCount = 0; + + private S3ExpressEndpointTransform() {} + + static S3ExpressEndpointTransform create() { + return new S3ExpressEndpointTransform(); + } + + int getRewriteCount() { + return rewriteCount; + } + + int getTotalCount() { + return totalCount; + } + + @Override + public Rule endpointRule(EndpointRule er) { + Endpoint endpoint = er.getEndpoint(); + String url = extractUrlString(endpoint.getUrl()); + if (url == null) { + return er; + } + + boolean isS3ExpressUrl = S3ExpressUrlCanonicalizer.isS3ExpressUrl(url); + boolean isS3ExpressBackend = hasS3ExpressBackend(endpoint); + + if (!isS3ExpressUrl && !isS3ExpressBackend) { + return er; + } + + totalCount++; + + // Custom endpoint with S3Express backend: just canonicalize auth + if (isS3ExpressBackend && !isS3ExpressUrl) { + return rewriteS3ExpressAuth(er); + } + + // Standard S3Express URL: rewrite URL pattern + return 
rewriteS3ExpressUrl(er, url); + } + + private Rule rewriteS3ExpressAuth(EndpointRule rule) { + Endpoint endpoint = rule.getEndpoint(); + Map newProps = canonicalizeAuthScheme(endpoint.getProperties()); + + if (newProps == endpoint.getProperties()) { + return rule; + } + + rewriteCount++; + List conditions = new ArrayList<>(rule.getConditions()); + conditions.add(createAuthIte()); + + return Rule.builder() + .conditions(conditions) + .endpoint(Endpoint.builder() + .url(endpoint.getUrl()) + .headers(endpoint.getHeaders()) + .properties(newProps) + .build()); + } + + // Adds: _s3e_fips = ite(UseFIPS, "-fips", ""), _s3e_ds = ite(UseDualStack, ".dualstack", ""), _s3e_auth = ... + // Then, rewrites URL to use {_s3e_fips} and {_s3e_ds}, and auth schemes to use {_s3e_auth}. + private Rule rewriteS3ExpressUrl(EndpointRule rule, String url) { + S3ExpressUrlCanonicalizer.CanonicalizedUrl canonicalized = S3ExpressUrlCanonicalizer.canonicalize(url); + if (canonicalized == null) { + return rule; + } + + rewriteCount++; + Endpoint endpoint = rule.getEndpoint(); + + // Note: _s3e_auth could technically be omitted for control plane, but including it reduces sifted + // BDD size, so just keeping it as-is for now. + List conditions = new ArrayList<>(rule.getConditions()); + conditions.add(createIte(VAR_FIPS, Expression.getReference(ID_USE_FIPS), "-fips", "")); + conditions.add(createIte(VAR_DS, Expression.getReference(ID_USE_DUAL_STACK), ".dualstack", "")); + conditions.add(createAuthIte()); + + Map newProps = canonicalized.isBucketPattern() + ? 
canonicalizeAuthScheme(endpoint.getProperties()) + : endpoint.getProperties(); + + return Rule.builder() + .conditions(conditions) + .endpoint(Endpoint.builder() + .url(Expression.of(canonicalized.toCanonicalUrl())) + .headers(endpoint.getHeaders()) + .properties(newProps) + .build()); + } + + // Creates: {varName} = ite(condition, trueVal, falseVal) + private Condition createIte(String varName, Expression condition, String trueVal, String falseVal) { + return Condition.builder() + .fn(Ite.ofStrings(condition, trueVal, falseVal)) + .result(varName) + .build(); + } + + // Creates: _s3e_auth = ite(coalesce(DisableS3ExpressSessionAuth, false), "sigv4", "sigv4-s3express") + private Condition createAuthIte() { + Expression disabled = Coalesce.ofExpressions( + Expression.getReference(ID_DISABLE_S3EXPRESS_SESSION_AUTH), + Expression.of(false)); + return createIte(VAR_AUTH, disabled, AUTH_SIGV4, AUTH_SIGV4_S3EXPRESS); + } + + // Checks for: backend = "S3Express" + private boolean hasS3ExpressBackend(Endpoint endpoint) { + Literal backend = endpoint.getProperties().get(ID_BACKEND); + if (backend == null) { + return false; + } + return backend.asStringLiteral() + .filter(Template::isStatic) + .map(t -> "S3Express".equalsIgnoreCase(t.expectLiteral())) + .orElse(false); + } + + // Rewrites: authSchemes[].name: "sigv4"/"sigv4-s3express" → "{_s3e_auth}" + private Map canonicalizeAuthScheme(Map properties) { + Literal authSchemes = properties.get(ID_AUTH_SCHEMES); + if (authSchemes == null || !authSchemes.asTupleLiteral().isPresent()) { + return properties; + } + + List schemes = authSchemes.asTupleLiteral().get(); + List newSchemes = new ArrayList<>(schemes.size()); + boolean changed = false; + + for (Literal scheme : schemes) { + if (!scheme.asRecordLiteral().isPresent()) { + newSchemes.add(scheme); + continue; + } + + Map record = scheme.asRecordLiteral().get(); + Literal nameLit = record.get(ID_NAME); + String name = null; + if (nameLit != null && 
nameLit.asStringLiteral().isPresent()) { + Template template = nameLit.asStringLiteral().get(); + if (template.isStatic()) { + name = template.expectLiteral(); + } + } + + if (AUTH_SIGV4.equals(name) || AUTH_SIGV4_S3EXPRESS.equals(name)) { + Map newRecord = new LinkedHashMap<>(record); + newRecord.put(ID_NAME, AUTH_NAME_TEMPLATE); + newSchemes.add(Literal.recordLiteral(newRecord)); + changed = true; + } else { + newSchemes.add(scheme); + } + } + + if (!changed) { + return properties; + } + + Map newProps = new LinkedHashMap<>(properties); + newProps.put(ID_AUTH_SCHEMES, Literal.tupleLiteral(newSchemes)); + return newProps; + } + + private String extractUrlString(Expression urlExpr) { + return urlExpr.toNode().asStringNode().map(StringNode::getValue).orElse(null); + } +} diff --git a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/s3/S3ExpressUrlCanonicalizer.java b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/s3/S3ExpressUrlCanonicalizer.java new file mode 100644 index 00000000000..6a393c55d69 --- /dev/null +++ b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/s3/S3ExpressUrlCanonicalizer.java @@ -0,0 +1,151 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.aws.s3; + +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import software.amazon.smithy.utils.SmithyInternalApi; + +/** + * Handles URL pattern matching and canonicalization for S3Express endpoints. + * + *

S3Express URLs have multiple variants based on FIPS and DualStack settings. + * This class detects these variants and produces canonical URLs that use ITE + * variable references instead of hardcoded segments. + * + *

Supported patterns:

+ *
    + *
  • Control plane: {@code s3express-control[-fips][.dualstack].{Region}...}
  • + *
  • Bucket operations: {@code s3express[-fips]-{AZ}[.dualstack].{Region}...}
  • + *
+ */ +@SmithyInternalApi +final class S3ExpressUrlCanonicalizer { + + // ITE variable names injected by S3TreeRewriter + private static final String VAR_FIPS = "_s3e_fips"; + private static final String VAR_DS = "_s3e_ds"; + + // URL patterns ordered most specific first to avoid partial matches + private static final UrlPattern[] PATTERNS = { + // Control plane: s3express-control[-fips][.dualstack].{Region} + new UrlPattern("(s3express-control)-fips\\.dualstack\\.(.+)$", false), + new UrlPattern("(s3express-control)-fips\\.(?!dualstack)(.+)$", false), + new UrlPattern("(s3express-control)\\.dualstack\\.(.+)$", false), + new UrlPattern("(s3express-control)\\.(?!dualstack)(.+)$", false), + // Bucket: s3express[-fips]-{AZ}[.dualstack].{Region} + new UrlPattern("(s3express)-fips-([^.]+)\\.dualstack\\.(.+)$", true), + new UrlPattern("(s3express)-fips-([^.]+)\\.(?!dualstack)(.+)$", true), + new UrlPattern("(s3express)-([^.]+)\\.dualstack\\.(.+)$", true), + new UrlPattern("(s3express)-([^.]+)\\.(?!dualstack)(.+)$", true), + }; + + private S3ExpressUrlCanonicalizer() {} + + /** + * Checks if a URL is an S3Express URL that can be canonicalized. + * + * @param url The URL to check. + * @return true if the URL contains S3Express patterns. + */ + static boolean isS3ExpressUrl(String url) { + return url != null && url.contains("s3express"); + } + + /** + * Attempts to match and canonicalize an S3Express URL. + * + * @param url The URL to canonicalize. + * @return A {@link CanonicalizedUrl} if the URL matched a pattern, or null if no match. + */ + static CanonicalizedUrl canonicalize(String url) { + if (url == null) { + return null; + } + for (UrlPattern pattern : PATTERNS) { + Matcher m = pattern.pattern.matcher(url); + if (m.find()) { + return new CanonicalizedUrl(url, m, pattern.isBucketPattern); + } + } + return null; + } + + /** + * Holds a regex pattern and whether it matches bucket-level operations. 
+ */ + private static final class UrlPattern { + final Pattern pattern; + final boolean isBucketPattern; + + UrlPattern(String regex, boolean isBucketPattern) { + this.pattern = Pattern.compile(regex); + this.isBucketPattern = isBucketPattern; + } + } + + /** + * Represents a successfully matched and canonicalized S3Express URL. + */ + static final class CanonicalizedUrl { + private final String prefix; + private final String service; + private final String az; // null for control plane patterns + private final String regionSuffix; + private final boolean isBucketPattern; + + private CanonicalizedUrl(String url, Matcher m, boolean isBucketPattern) { + this.prefix = url.substring(0, m.start()); + this.isBucketPattern = isBucketPattern; + if (isBucketPattern) { + this.service = m.group(1); + this.az = m.group(2); + this.regionSuffix = m.group(3); + } else { + this.service = m.group(1); + this.az = null; + this.regionSuffix = m.group(2); + } + } + + /** + * Returns whether this is a bucket-level URL pattern (vs control plane). + * + * @return true if bucket pattern. + */ + public boolean isBucketPattern() { + return isBucketPattern; + } + + /** + * Builds the canonicalized URL string using ITE variable references. + * + *

Example: {@code s3express-fips-use1-az1.dualstack.us-east-1...} → + * {@code s3express{_s3e_fips}-use1-az1{_s3e_ds}.us-east-1...} + * + * @return The canonicalized URL string. + */ + public String toCanonicalUrl() { + if (!isBucketPattern) { + return String.format("%s%s{%s}{%s}.%s", + prefix, + service, + VAR_FIPS, + VAR_DS, + regionSuffix); + } else if (az == null) { + throw new IllegalStateException("az must be non-null for bucket patterns"); + } else { + return String.format("%s%s{%s}-%s{%s}.%s", + prefix, + service, + VAR_FIPS, + az, + VAR_DS, + regionSuffix); + } + } + } +} diff --git a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/s3/S3RegionUnifierTransform.java b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/s3/S3RegionUnifierTransform.java new file mode 100644 index 00000000000..a3f4252da11 --- /dev/null +++ b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/s3/S3RegionUnifierTransform.java @@ -0,0 +1,405 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.aws.s3; + +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import software.amazon.smithy.rulesengine.aws.language.functions.AwsPartition; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Reference; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Template; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Coalesce; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.FunctionNode; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.GetAttr; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.IsSet; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Ite; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.LibraryFunction; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.StringEquals; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.RecordLiteral; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.StringLiteral; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; +import software.amazon.smithy.rulesengine.logic.cfg.TreeMapper; + +/** + * Unifies region references across S3 endpoint rules. + * + *

This transform solves the problem where the same logical region appears in + * syntactically different forms: + *

    + *
  • {@code Region} - the input parameter
  • + *
  • {@code bucketArn#region} - region from a parsed bucket ARN
  • + *
  • Hardcoded values like {@code us-east-1} (for aws-global)
  • + *
+ * + *

The transform injects computed variables that unify these: + *

    + *
  • {@code _effective_std_region} = ite(Region == "aws-global", "us-east-1", Region)
  • + *
  • {@code _effective_arn_region} = ite(coalesce(UseArnRegion, true), bucketArn#region, Region)
  • + *
+ */
+final class S3RegionUnifierTransform extends TreeMapper {
+
+    // Computed variable names
+    static final String VAR_EFFECTIVE_ARN_REGION = "_effective_arn_region";
+    static final String VAR_EFFECTIVE_STD_REGION = "_effective_std_region";
+
+    // Identifiers
+    private static final Identifier ID_BUCKET_ARN = Identifier.of("bucketArn");
+    private static final Identifier ID_USE_ARN_REGION = Identifier.of("UseArnRegion");
+    private static final Identifier ID_REGION = Identifier.of("Region");
+    private static final Identifier ID_SIGNING_REGION = Identifier.of("signingRegion");
+
+    // Function names
+    private static final String AWS_PARSE_ARN = "aws.parseArn";
+    private static final String AWS_PARTITION = "aws.partition";
+    private static final String IS_VALID_HOST_LABEL = "isValidHostLabel";
+
+    // Scope tracking: set while walking conditions, saved/restored around tree rules.
+    // inBucketArnScope     -> _effective_arn_region has been injected on the current path
+    // signingRegionDefined -> _effective_std_region has been injected on the current path
+    private boolean inBucketArnScope = false;
+    private boolean signingRegionDefined = false;
+
+    private int rewriteCount = 0;
+
+    private S3RegionUnifierTransform() {}
+
+    static S3RegionUnifierTransform create() {
+        return new S3RegionUnifierTransform();
+    }
+
+    /**
+     * Returns the number of region references that were unified.
+     *
+     * @return rewrite count.
+     */
+    int getRewriteCount() {
+        return rewriteCount;
+    }
+
+    /**
+     * Maps a tree rule while keeping the scope flags local to that subtree.
+     *
+     * @param tr the tree rule to transform.
+     * @return the transformed rule produced by the superclass traversal.
+     */
+    @Override
+    public Rule treeRule(TreeRule tr) {
+        // Save scope state before descending
+        boolean savedArnScope = inBucketArnScope;
+        boolean savedSigningScope = signingRegionDefined;
+
+        Rule result = super.treeRule(tr);
+
+        // Restore scope state when exiting branch
+        inBucketArnScope = savedArnScope;
+        signingRegionDefined = savedSigningScope;
+
+        return result;
+    }
+
+    /**
+     * Transforms a rule's conditions and injects the computed region variables
+     * immediately after the condition that makes their inputs available.
+     *
+     * @param rule the rule owning the conditions.
+     * @param conditions the original conditions, in evaluation order.
+     * @return a new list holding the transformed and injected conditions.
+     */
+    @Override
+    public List<Condition> conditions(Rule rule, List<Condition> conditions) {
+        // At most two conditions are injected per list (one per computed variable).
+        List<Condition> result = new ArrayList<>(conditions.size() + 2);
+
+        for (Condition cond : conditions) {
+            Condition transformed = condition(rule, cond);
+            if (transformed == null) {
+                continue;
+            }
+
+            result.add(transformed);
+
+            // NOTE(review): the flags set below are only restored by treeRule(); presumably these
+            // bindings never occur in a non-tree rule that has later siblings referencing Region.
+            // Confirm against the TreeMapper traversal order.
+            if (!signingRegionDefined && isIsSetRegion(transformed)) {
+                // Inject _effective_std_region after isSet(Region)
+                result.add(createSigningRegionIte());
+                signingRegionDefined = true;
+            } else if (isBucketArnBinding(transformed)) {
+                // Inject _effective_arn_region after the bucketArn binding
+                result.add(createEffectiveRegionIte());
+                inBucketArnScope = true;
+            }
+        }
+
+        return result;
+    }
+
+    /**
+     * Transforms a single condition, unifying {@code bucketArn#region} arguments
+     * when the traversal is inside the bucket ARN scope.
+     *
+     * @param rule the rule owning the condition.
+     * @param cond the condition to transform.
+     * @return the rewritten condition, or the superclass result when nothing was rewritten.
+     */
+    @Override
+    public Condition condition(Rule rule, Condition cond) {
+        // In ARN scope, rewrite bucketArn#region in partition/validation calls
+        if (inBucketArnScope) {
+            Condition rewritten = rewriteBucketArnRegionInCondition(cond);
+            // Identity comparison: the helper returns the same instance when it changed nothing.
+            if (rewritten != cond) {
+                return rewritten;
+            }
+        }
+
+        return super.condition(rule, cond);
+    }
+
+    @Override
+    public Expression error(ErrorRule er, Expression e) {
+        // Don't rewrite error messages
+        return e;
+    }
+
+    /**
+     * Rewrites region references inside a string literal.
+     *
+     * @param str the string literal to inspect.
+     * @return a rewritten literal, or {@code str} itself when nothing matched.
+     */
+    @Override
+    public Literal stringLiteral(StringLiteral str) {
+        Template template = str.value();
+
+        // Handle static URL strings with region patterns
+        if (template.isStatic()) {
+            String value = template.expectLiteral();
+            String rewritten = rewriteUrlRegionPatterns(value);
+            if (rewritten != null) {
+                return Literal.stringLiteral(Template.fromString(rewritten));
+            }
+            return str;
+        }
+
+        // Handle dynamic templates: check for bucketArn#region
+        return rewriteDynamicTemplate(str, template);
+    }
+
+    /**
+     * Rewrites the members of a record literal, giving the {@code signingRegion}
+     * property dedicated handling once the standard region variable is defined.
+     *
+     * @param record the record literal to inspect.
+     * @return a new record when any member changed, otherwise {@code record} itself.
+     */
+    @Override
+    public Literal recordLiteral(RecordLiteral record) {
+        Map<Identifier, Literal> members = record.members();
+        // LinkedHashMap keeps the member order of the original record.
+        Map<Identifier, Literal> newMembers = new LinkedHashMap<>();
+        boolean changed = false;
+
+        for (Map.Entry<Identifier, Literal> entry : members.entrySet()) {
+            Identifier key = entry.getKey();
+            Literal value = entry.getValue();
+            Literal rewritten;
+
+            // Special handling for signingRegion property
+            if (ID_SIGNING_REGION.equals(key) && signingRegionDefined) {
+                rewritten = rewriteSigningRegionValue(value);
+            } else {
+                rewritten = (Literal) expression(value);
+            }
+
+            newMembers.put(key, rewritten);
+            if (rewritten != value) {
+                changed = true;
+            }
+        }
+
+        return changed ? Literal.recordLiteral(newMembers) : record;
+    }
+
+    // ========== Condition detection helpers ==========
+
+    // Matches: isSet(Region)
+    private boolean isIsSetRegion(Condition cond) {
+        // A condition that assigns its result is a binding, not a bare isSet check.
+        if (cond.getResult().isPresent()) {
+            return false;
+        }
+
+        LibraryFunction fn = cond.getFunction();
+        if (!fn.getFunctionDefinition().equals(IsSet.getDefinition()) || fn.getArguments().isEmpty()) {
+            return false;
+        }
+
+        Expression arg = fn.getArguments().get(0);
+        return arg instanceof Reference && ID_REGION.equals(((Reference) arg).getName());
+    }
+
+    // Matches: bucketArn = aws.parseArn(...)
+    private boolean isBucketArnBinding(Condition cond) {
+        return cond.getResult().isPresent()
+                && ID_BUCKET_ARN.equals(cond.getResult().get())
+                && AWS_PARSE_ARN.equals(cond.getFunction().getName());
+    }
+
+    // Matches: bucketArn#region
+    private boolean isBucketArnRegion(Expression expr) {
+        if (!(expr instanceof GetAttr)) {
+            return false;
+        }
+        GetAttr getAttr = (GetAttr) expr;
+        List<Expression> args = getAttr.getArguments();
+        if (args.isEmpty() || !(args.get(0) instanceof Reference)) {
+            return false;
+        }
+        Reference ref = (Reference) args.get(0);
+        if (!ID_BUCKET_ARN.equals(ref.getName())) {
+            return false;
+        }
+        // Only the exact single-segment path "region" qualifies.
+        List<GetAttr.Part> path = getAttr.getPath();
+        return path.size() == 1
+                && path.get(0) instanceof GetAttr.Part.Key
+                && "region".equals(((GetAttr.Part.Key) path.get(0)).key().toString());
+    }
+
+    // ========== ITE condition creation ==========
+
+    // Creates: _effective_std_region = ite(Region == "aws-global", "us-east-1", Region)
+    private Condition createSigningRegionIte() {
+        Expression isGlobal = StringEquals.ofExpressions(Expression.getReference(ID_REGION), "aws-global");
+        Ite ite = Ite.ofExpressions(isGlobal,
+                Expression.of("us-east-1"),
+                Expression.getReference(ID_REGION));
+        return Condition.builder().fn(ite).result(VAR_EFFECTIVE_STD_REGION).build();
+    }
+
+    /**
+     * Creates the effective region ITE for ARN scope.
+     *
+     * <p>This is only called after bucketArn is successfully bound, so we know
+     * the bucket IS an ARN. The ITE selects between the ARN's region and the
+     * input region based on UseArnRegion (defaulting to true).
+     */
+    private Condition createEffectiveRegionIte() {
+        Expression useArnRegion = Coalesce.ofExpressions(
+                Expression.getReference(ID_USE_ARN_REGION),
+                Expression.of(true));
+        Expression arnRegion = GetAttr.ofExpressions(Expression.getReference(ID_BUCKET_ARN), "region");
+        Expression inputRegion = Expression.getReference(ID_REGION);
+
+        Ite ite = Ite.ofExpressions(useArnRegion, arnRegion, inputRegion);
+        return Condition.builder().fn(ite).result(VAR_EFFECTIVE_ARN_REGION).build();
+    }
+
+    // ========== Region rewriting ==========
+
+    /**
+     * Replaces a leading {@code bucketArn#region} argument of {@code aws.partition} or
+     * {@code isValidHostLabel} with a reference to {@code _effective_arn_region}.
+     *
+     * @param cond the condition to inspect.
+     * @return a rebuilt condition, or the same instance when it does not qualify
+     *         (callers rely on identity to detect "unchanged").
+     */
+    private Condition rewriteBucketArnRegionInCondition(Condition cond) {
+        LibraryFunction fn = cond.getFunction();
+        String fnName = fn.getName();
+
+        if (!AWS_PARTITION.equals(fnName) && !IS_VALID_HOST_LABEL.equals(fnName)) {
+            return cond;
+        }
+        if (fn.getArguments().isEmpty() || !isBucketArnRegion(fn.getArguments().get(0))) {
+            return cond;
+        }
+
+        rewriteCount++;
+        // Copy so the original function's argument list is never mutated.
+        List<Expression> newArgs = new ArrayList<>(fn.getArguments());
+        newArgs.set(0, Expression.getReference(Identifier.of(VAR_EFFECTIVE_ARN_REGION)));
+
+        LibraryFunction newFn = fn.getFunctionDefinition()
+                .createFunction(FunctionNode.ofExpressions(fnName, newArgs.toArray(new Expression[0])));
+
+        return cond.toBuilder().fn(newFn).build();
+    }
+
+    /**
+     * Rewrites static URL strings to unify region patterns.
+     *
+     * <p>Pattern order matters: more specific patterns must come first to avoid
+     * partial matches.
+     *
+     * @param url the static string to inspect.
+     * @return the rewritten string, or null when nothing was replaced or the
+     *         standard region variable has not been defined yet.
+     */
+    private String rewriteUrlRegionPatterns(String url) {
+        if (!signingRegionDefined) {
+            return null;
+        }
+
+        String targetVar = inBucketArnScope ? VAR_EFFECTIVE_ARN_REGION : VAR_EFFECTIVE_STD_REGION;
+        String result = url;
+        boolean changed = false;
+
+        // Order matters: replace more specific patterns first
+        if (result.contains(".us-east-1.")) {
+            result = result.replace(".us-east-1.", ".{" + targetVar + "}.");
+            changed = true;
+        }
+        if (result.contains(".{Region}.")) {
+            result = result.replace(".{Region}.", ".{" + targetVar + "}.");
+            changed = true;
+        }
+        if (result.contains("{Region}")) {
+            result = result.replace("{Region}", "{" + targetVar + "}");
+            changed = true;
+        }
+        if (inBucketArnScope && result.contains("{bucketArn#region}")) {
+            result = result.replace("{bucketArn#region}", "{" + VAR_EFFECTIVE_ARN_REGION + "}");
+            changed = true;
+        }
+
+        if (changed) {
+            // Counted once per string, regardless of how many patterns were replaced.
+            rewriteCount++;
+            return result;
+        }
+        return null;
+    }
+
+    /**
+     * Rewrites dynamic template strings to unify region references.
+     *
+     * @param str the original literal, returned as-is when nothing changes.
+     * @param template the parsed template of {@code str}.
+     * @return a literal re-parsed from the rewritten template text, or {@code str}.
+     */
+    private Literal rewriteDynamicTemplate(StringLiteral str, Template template) {
+        List<Template.Part> parts = template.getParts();
+        StringBuilder sb = new StringBuilder();
+        boolean changed = false;
+        String targetVar = inBucketArnScope ? VAR_EFFECTIVE_ARN_REGION : VAR_EFFECTIVE_STD_REGION;
+
+        for (Template.Part part : parts) {
+            if (part instanceof Template.Dynamic) {
+                Expression expr = ((Template.Dynamic) part).toExpression();
+                if (isBucketArnRegion(expr)) {
+                    sb.append("{").append(VAR_EFFECTIVE_ARN_REGION).append("}");
+                    rewriteCount++;
+                    changed = true;
+                    continue;
+                }
+                if (signingRegionDefined && expr instanceof Reference
+                        && ID_REGION.equals(((Reference) expr).getName())) {
+                    sb.append("{").append(targetVar).append("}");
+                    rewriteCount++;
+                    changed = true;
+                    continue;
+                }
+                // NOTE(review): relies on Template.Part#toString() producing re-parseable
+                // template text for dynamic parts - confirm.
+                sb.append(part);
+            } else if (part instanceof Template.Literal) {
+                String literal = ((Template.Literal) part).getValue();
+                if (signingRegionDefined && literal.contains(".us-east-1.")) {
+                    literal = literal.replace(".us-east-1.", ".{" + targetVar + "}.");
+                    rewriteCount++;
+                    changed = true;
+                }
+                sb.append(literal);
+            } else {
+                sb.append(part);
+            }
+        }
+
+        return changed ? Literal.stringLiteral(Template.fromString(sb.toString())) : str;
+    }
+
+    /**
+     * Rewrites the value of a {@code signingRegion} property to the computed region variable.
+     *
+     * @param value the property value.
+     * @return a template referencing the computed variable when the value is a lone
+     *         {@code Region} / {@code bucketArn#region} reference or a known static region;
+     *         otherwise the generically transformed or unchanged value.
+     */
+    private Literal rewriteSigningRegionValue(Literal value) {
+        if (!value.asStringLiteral().isPresent()) {
+            return (Literal) expression(value);
+        }
+
+        Template template = value.asStringLiteral().get();
+        String targetVar = inBucketArnScope ? VAR_EFFECTIVE_ARN_REGION : VAR_EFFECTIVE_STD_REGION;
+
+        // Dynamic template: {Region} or {bucketArn#region}
+        if (!template.isStatic()) {
+            List<Template.Part> parts = template.getParts();
+            if (parts.size() == 1 && parts.get(0) instanceof Template.Dynamic) {
+                Expression expr = ((Template.Dynamic) parts.get(0)).toExpression();
+                if (isBucketArnRegion(expr)
+                        || (expr instanceof Reference && ID_REGION.equals(((Reference) expr).getName()))) {
+                    rewriteCount++;
+                    return Literal.stringLiteral(Template.fromString("{" + targetVar + "}"));
+                }
+            }
+            return (Literal) expression(value);
+        }
+
+        // Static region value
+        String staticValue = template.expectLiteral();
+        if (isKnownRegion(staticValue)) {
+            rewriteCount++;
+            return Literal.stringLiteral(Template.fromString("{" + targetVar + "}"));
+        }
+
+        return value;
+    }
+
+    // True for "aws-global" or any value for which AwsPartition.findPartition returns non-null.
+    private boolean isKnownRegion(String value) {
+        return value != null && ("aws-global".equals(value) || AwsPartition.findPartition(value) != null);
+    }
+}
diff --git a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/s3/S3TreeRewriter.java b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/s3/S3TreeRewriter.java
new file mode 100644
index 00000000000..8c00d7a6880
--- /dev/null
+++ b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/s3/S3TreeRewriter.java
@@ -0,0 +1,64 @@
+/*
+ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package software.amazon.smithy.rulesengine.aws.s3;
+
+import java.util.logging.Logger;
+import software.amazon.smithy.rulesengine.language.EndpointRuleSet;
+import software.amazon.smithy.utils.SmithyInternalApi;
+
+/**
+ * Rewrites S3 endpoint rules to create a dramatically smaller and more efficient BDD.
+ *
+ *

This is a BDD pre-processing transform that makes the decision tree larger but enables dramatically better + * BDD compilation. It solves the "SSA Trap" problem where semantically identical operations appear as syntactically + * different expressions. + * + *

This class composes three separate transforms: + *

    + *
  1. {@link S3AzCanonicalizerTransform} - Canonicalizes AZ extraction: + * {@code substring(Bucket, N, M)} → {@code split(Bucket, "--")[1]}
  2. + *
  3. {@link S3RegionUnifierTransform} - Unifies region references: + * {@code Region}/{@code bucketArn#region} → {@code _effective_std_region}/{@code _effective_arn_region}
  4. + *
  5. {@link S3ExpressEndpointTransform} - Canonicalizes S3Express endpoints: + * FIPS/DualStack URL variants → ITE-computed segments
  6. + *
+ * + *

Each transform is independent and can be applied separately if needed. + */ +@SmithyInternalApi +public final class S3TreeRewriter { + private static final Logger LOGGER = Logger.getLogger(S3TreeRewriter.class.getName()); + + private S3TreeRewriter() {} + + /** + * Transforms the given endpoint rule set by applying all S3 canonicalization transforms. + * + * @param ruleSet Rules to transform. + * @return the transformed rule set. + */ + public static EndpointRuleSet transform(EndpointRuleSet ruleSet) { + // Pass 1: Canonicalize AZ extraction + S3AzCanonicalizerTransform azTransform = S3AzCanonicalizerTransform.create(); + EndpointRuleSet afterAz = azTransform.endpointRuleSet(ruleSet); + + // Pass 2: Unify region references + S3RegionUnifierTransform regionTransform = S3RegionUnifierTransform.create(); + EndpointRuleSet afterRegion = regionTransform.endpointRuleSet(afterAz); + + // Pass 3: Canonicalize S3Express endpoints + S3ExpressEndpointTransform s3ExpressTransform = S3ExpressEndpointTransform.create(); + EndpointRuleSet result = s3ExpressTransform.endpointRuleSet(afterRegion); + + LOGGER.info(() -> String.format( + "S3 tree rewriter: %d AZ, %d region, %d/%d S3Express rewrites", + azTransform.getRewriteCount(), + regionTransform.getRewriteCount(), + s3ExpressTransform.getRewriteCount(), + s3ExpressTransform.getTotalCount())); + + return result; + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/analysis/BddCoverageChecker.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/analysis/BddCoverageChecker.java index d2b3443163d..e1b10e5130f 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/analysis/BddCoverageChecker.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/analysis/BddCoverageChecker.java @@ -138,6 +138,29 @@ public double getResultCoverage() { return relevantResults == 0 ? 
100.0 : (100.0 * coveredRelevantResults / relevantResults); } + /** + * Returns conditions that exist in the conditions list but are not referenced by any BDD node. + * + * @return set of unreferenced conditions + */ + public Set getUnreferencedConditions() { + BitSet referencedByBdd = new BitSet(conditions.size()); + for (int i = 0; i < bdd.getNodeCount(); i++) { + int varIdx = bdd.getVariable(i); + if (varIdx >= 0 && varIdx < conditions.size()) { + referencedByBdd.set(varIdx); + } + } + + Set unreferenced = new HashSet<>(); + for (int i = 0; i < conditions.size(); i++) { + if (!referencedByBdd.get(i)) { + unreferenced.add(conditions.get(i)); + } + } + return unreferenced; + } + // Evaluator that tracks what gets visited during BDD evaluation. private final class TestEvaluator implements ConditionEvaluator { private final RuleEvaluator ruleEvaluator; diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/CoreExtension.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/CoreExtension.java index 9e0dfb0fd39..2dda13db308 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/CoreExtension.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/CoreExtension.java @@ -11,6 +11,7 @@ import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.GetAttr; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.IsSet; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.IsValidHostLabel; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Ite; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Not; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.ParseUrl; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Split; @@ -43,6 +44,7 @@ public List 
getLibraryFunctions() { Split.getDefinition(), StringEquals.getDefinition(), Substring.getDefinition(), + Ite.getDefinition(), UriEncode.getDefinition()); } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/RuleEvaluator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/RuleEvaluator.java index 6e8d70a8771..efc82026a30 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/RuleEvaluator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/RuleEvaluator.java @@ -215,6 +215,12 @@ public Value visitStringEquals(Expression left, Expression right) { .equals(right.accept(this).expectStringValue())); } + @Override + public Value visitIte(Expression condition, Expression trueValue, Expression falseValue) { + boolean cond = condition.accept(this).expectBooleanValue().getValue(); + return cond ? trueValue.accept(this) : falseValue.accept(this); + } + @Override public Value visitGetAttr(GetAttr getAttr) { return getAttr.evaluate(getAttr.getTarget().accept(this)); diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/ExpressionVisitor.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/ExpressionVisitor.java index 1557b529b52..b4bbc93868f 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/ExpressionVisitor.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/ExpressionVisitor.java @@ -4,10 +4,12 @@ */ package software.amazon.smithy.rulesengine.language.syntax.expressions; +import java.util.Arrays; import java.util.List; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Coalesce; import 
software.amazon.smithy.rulesengine.language.syntax.expressions.functions.FunctionDefinition; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.GetAttr; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Ite; import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; import software.amazon.smithy.utils.SmithyUnstableApi; @@ -86,6 +88,18 @@ default R visitCoalesce(List expressions) { */ R visitStringEquals(Expression left, Expression right); + /** + * Visits an if-then-else (ITE) function. + * + * @param condition the boolean condition expression. + * @param trueValue the value if condition is true. + * @param falseValue the value if condition is false. + * @return the value from the visitor. + */ + default R visitIte(Expression condition, Expression trueValue, Expression falseValue) { + return visitLibraryFunction(Ite.getDefinition(), Arrays.asList(condition, trueValue, falseValue)); + } + /** * Visits a library function. 
* @@ -138,6 +152,11 @@ public R visitStringEquals(Expression left, Expression right) { return getDefault(); } + @Override + public R visitIte(Expression condition, Expression trueValue, Expression falseValue) { + return getDefault(); + } + @Override public R visitLibraryFunction(FunctionDefinition fn, List args) { return getDefault(); diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Coalesce.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Coalesce.java index 931e5d9f9dd..039d461350e 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Coalesce.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Coalesce.java @@ -8,6 +8,7 @@ import java.util.List; import java.util.Optional; import software.amazon.smithy.rulesengine.language.RulesVersion; +import software.amazon.smithy.rulesengine.language.error.InnerParseError; import software.amazon.smithy.rulesengine.language.evaluation.Scope; import software.amazon.smithy.rulesengine.language.evaluation.type.OptionalType; import software.amazon.smithy.rulesengine.language.evaluation.type.Type; @@ -81,10 +82,10 @@ public R accept(ExpressionVisitor visitor) { } @Override - public Type typeCheck(Scope scope) { + protected Type typeCheckLocal(Scope scope) throws InnerParseError { List args = getArguments(); if (args.size() < 2) { - throw new IllegalArgumentException("Coalesce requires at least 2 arguments, got " + args.size()); + throw new InnerParseError("Coalesce requires at least 2 arguments, got " + args.size()); } // Get the first argument's type as the baseline @@ -98,7 +99,7 @@ public Type typeCheck(Scope scope) { Type innerType = getInnerType(argType); if (!innerType.equals(baseInnerType)) { - throw new IllegalArgumentException(String.format( + throw new 
InnerParseError(String.format( "Type mismatch in coalesce at argument %d: expected %s but got %s", i + 1, baseInnerType, diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Ite.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Ite.java new file mode 100644 index 00000000000..90acc71da01 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Ite.java @@ -0,0 +1,175 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.language.syntax.expressions.functions; + +import java.util.Arrays; +import java.util.List; +import software.amazon.smithy.rulesengine.language.RulesVersion; +import software.amazon.smithy.rulesengine.language.error.InnerParseError; +import software.amazon.smithy.rulesengine.language.evaluation.Scope; +import software.amazon.smithy.rulesengine.language.evaluation.type.OptionalType; +import software.amazon.smithy.rulesengine.language.evaluation.type.Type; +import software.amazon.smithy.rulesengine.language.evaluation.value.Value; +import software.amazon.smithy.rulesengine.language.syntax.ToExpression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.ExpressionVisitor; +import software.amazon.smithy.utils.SmithyUnstableApi; + +/** + * An if-then-else (ITE) function that returns one of two values based on a boolean condition. + * + *

This function is critical for avoiding SSA (Static Single Assignment) fragmentation in BDD compilation. + * By computing conditional values atomically without branching, it prevents the graph explosion that occurs when + * boolean flags like UseFips or UseDualStack create divergent paths with distinct variable identities. + * + *

Semantics: {@code ite(condition, trueValue, falseValue)} + *

    + *
  • If condition is true, returns trueValue
  • + *
  • If condition is false, returns falseValue
  • + *
  • The condition must be a non-optional boolean (use coalesce to provide a default if needed)
  • + *
+ * + *

Type checking rules (least upper bound of nullability): + *

    + *
  • {@code ite(Boolean, T, T) => T} - both non-optional, result is non-optional
  • + *
  • {@code ite(Boolean, T, Optional<T>) => Optional<T>} - any optional makes result optional
  • + *
  • {@code ite(Boolean, Optional<T>, T) => Optional<T>} - any optional makes result optional
  • + *
  • {@code ite(Boolean, Optional<T>, Optional<T>) => Optional<T>} - both optional, result is optional
  • + *
+ * + *

Available since: rules engine 1.1. + */ +@SmithyUnstableApi +public final class Ite extends LibraryFunction { + public static final String ID = "ite"; + private static final Definition DEFINITION = new Definition(); + + private Ite(FunctionNode functionNode) { + super(DEFINITION, functionNode); + } + + /** + * Gets the {@link FunctionDefinition} implementation. + * + * @return the function definition. + */ + public static Definition getDefinition() { + return DEFINITION; + } + + /** + * Creates a {@link Ite} function from the given expressions. + * + * @param condition the boolean condition to evaluate + * @param trueValue the value to return if condition is true + * @param falseValue the value to return if condition is false + * @return The resulting {@link Ite} function. + */ + public static Ite ofExpressions(ToExpression condition, ToExpression trueValue, ToExpression falseValue) { + return DEFINITION.createFunction(FunctionNode.ofExpressions(ID, condition, trueValue, falseValue)); + } + + /** + * Creates a {@link Ite} function with a reference condition and string values. + * + * @param conditionRef the reference to a boolean parameter + * @param trueValue the string value if condition is true + * @param falseValue the string value if condition is false + * @return The resulting {@link Ite} function. 
+     */
+    public static Ite ofStrings(ToExpression conditionRef, String trueValue, String falseValue) {
+        return ofExpressions(conditionRef, Expression.of(trueValue), Expression.of(falseValue));
+    }
+
+    @Override
+    public RulesVersion availableSince() {
+        return RulesVersion.V1_1;
+    }
+
+    /**
+     * Dispatches to {@link ExpressionVisitor#visitIte} with the condition, true value
+     * and false value taken positionally from the argument list.
+     */
+    @Override
+    public <R> R accept(ExpressionVisitor<R> visitor) {
+        return visitor.visitIte(getArguments().get(0), getArguments().get(1), getArguments().get(2));
+    }
+
+    /**
+     * Type checks the ITE call.
+     *
+     * <p>The condition must be exactly a non-optional boolean; both branches must share
+     * the same base type once any optional wrapper is removed. The result is optional
+     * when either branch is optional.
+     *
+     * @param scope the scope used to type check the arguments.
+     * @return the base type of the branches, wrapped in an optional if either branch is optional.
+     * @throws InnerParseError on a wrong argument count, a non-boolean or optional condition,
+     *         or branches whose base types differ.
+     */
+    @Override
+    protected Type typeCheckLocal(Scope scope) throws InnerParseError {
+        List<Expression> args = getArguments();
+        if (args.size() != 3) {
+            throw new InnerParseError("ITE requires exactly 3 arguments, got " + args.size());
+        }
+
+        // Check condition is a boolean (non-optional)
+        Type conditionType = args.get(0).typeCheck(scope);
+        if (!conditionType.equals(Type.booleanType())) {
+            throw new InnerParseError(String.format(
+                    "ITE condition must be a non-optional Boolean, got %s. "
+                            + "Use coalesce to provide a default for optional booleans.",
+                    conditionType));
+        }
+
+        // Get trueValue and falseValue types
+        Type trueType = args.get(1).typeCheck(scope);
+        Type falseType = args.get(2).typeCheck(scope);
+
+        // Extract base types (unwrap Optional if present)
+        Type trueBaseType = getInnerType(trueType);
+        Type falseBaseType = getInnerType(falseType);
+
+        // Base types must match
+        if (!trueBaseType.equals(falseBaseType)) {
+            throw new InnerParseError(String.format(
+                    "ITE branches must have the same base type: true branch is %s, false branch is %s",
+                    trueBaseType,
+                    falseBaseType));
+        }
+
+        // Result is optional if EITHER branch is optional (least upper bound)
+        boolean resultIsOptional = (trueType instanceof OptionalType) || (falseType instanceof OptionalType);
+        return resultIsOptional ? Type.optionalType(trueBaseType) : trueBaseType;
+    }
+
+    // Unwraps one level of OptionalType; any other type is returned unchanged.
+    private static Type getInnerType(Type t) {
+        return (t instanceof OptionalType) ? ((OptionalType) t).inner() : t;
+    }
+
+    /**
+     * A {@link FunctionDefinition} for the {@link Ite} function.
+ */ + public static final class Definition implements FunctionDefinition { + private Definition() {} + + @Override + public String getId() { + return ID; + } + + @Override + public List getArguments() { + // Actual type checking is done in typeCheck override + return Arrays.asList(Type.booleanType(), Type.anyType(), Type.anyType()); + } + + @Override + public Type getReturnType() { + // Actual return type is computed in typeCheck override + return Type.anyType(); + } + + @Override + public Value evaluate(List arguments) { + throw new UnsupportedOperationException("ITE evaluation is handled by ExpressionVisitor"); + } + + @Override + public Ite createFunction(FunctionNode functionNode) { + return new Ite(functionNode); + } + + @Override + public int getCost() { + return 10; + } + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/LibraryFunction.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/LibraryFunction.java index c5daeab89d1..51ebd435407 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/LibraryFunction.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/LibraryFunction.java @@ -18,8 +18,8 @@ import software.amazon.smithy.rulesengine.language.evaluation.Scope; import software.amazon.smithy.rulesengine.language.evaluation.type.Type; import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; -import software.amazon.smithy.rulesengine.language.syntax.expressions.Reference; import software.amazon.smithy.rulesengine.language.syntax.expressions.Template; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.StringLiteral; import 
software.amazon.smithy.rulesengine.language.syntax.rule.Condition; import software.amazon.smithy.utils.SmithyUnstableApi; @@ -231,18 +231,29 @@ public String toString() { * @return true if arguments should be swapped */ protected static boolean shouldSwapArgs(Expression arg0, Expression arg1) { - boolean arg0IsRef = isReference(arg0); - boolean arg1IsRef = isReference(arg1); + boolean arg0IsLiteral = isStaticLiteral(arg0); + boolean arg1IsLiteral = isStaticLiteral(arg1); - // Always put References before literals to make things consistent - if (arg0IsRef != arg1IsRef) { - return !arg0IsRef; // Swap if arg0 is literal and arg1 is reference + // Always put non-literals (expressions) before literals to make things consistent + if (arg0IsLiteral != arg1IsLiteral) { + return arg0IsLiteral; // Swap if arg0 is literal and arg1 is not } // Both same type, use string comparison for deterministic order return arg0.toString().compareTo(arg1.toString()) > 0; } + /** + * Returns true if the expression is a static literal (constant value). + * Dynamic string literals (templates with variables) are not considered static. + */ + private static boolean isStaticLiteral(Expression arg) { + if (arg instanceof StringLiteral) { + return ((StringLiteral) arg).value().isStatic(); + } + return arg instanceof Literal; + } + /** * Strips single-variable template wrappers if present. * Converts "{varName}" to just varName reference. 
@@ -264,13 +275,4 @@ static Expression stripSingleVariableTemplate(Expression expr) { return expr; } - private static boolean isReference(Expression arg) { - if (arg instanceof Reference) { - return true; - } else if (arg instanceof StringLiteral) { - StringLiteral s = (StringLiteral) arg; - return !s.value().isStatic(); - } - return false; - } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/NoMatchRule.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/NoMatchRule.java index d7c76f7feec..be58bb2415d 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/NoMatchRule.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/NoMatchRule.java @@ -28,7 +28,7 @@ public T accept(RuleValueVisitor visitor) { @Override protected Type typecheckValue(Scope scope) { - throw new UnsupportedOperationException("NO_MATCH is a sentinel"); + return Type.anyType(); } @Override diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/CostOptimization.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/CostOptimization.java index 11dbcc8b119..ea80fb56fbb 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/CostOptimization.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/CostOptimization.java @@ -293,7 +293,7 @@ public static final class Builder implements SmithyBuilder { private Cfg cfg; private ConditionCostModel costModel; private ToDoubleFunction trueProbability; - private double maxAllowedGrowth = 0.1; + private double maxAllowedGrowth = 0.08; private int maxRounds = 30; private int topK = 50; @@ -333,7 +333,7 @@ public Builder trueProbability(ToDoubleFunction trueProbability) { } /** - * Sets the maximum allowed node growth as a fraction 
(default 0.1 or 10%). + * Sets the maximum allowed node growth as a fraction (default 0.08 or 8%). * * @param maxAllowedGrowth maximum growth (0.0 = no growth, 0.1 = 10% growth) * @return the builder diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java index 47666837125..b779eceb4f3 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java @@ -6,16 +6,14 @@ import java.util.ArrayList; import java.util.Arrays; -import java.util.Comparator; -import java.util.IdentityHashMap; import java.util.List; -import java.util.Map; import java.util.function.Function; import java.util.logging.Logger; import java.util.stream.Collectors; import java.util.stream.IntStream; import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.logic.ConditionCostModel; import software.amazon.smithy.rulesengine.logic.cfg.Cfg; import software.amazon.smithy.rulesengine.logic.cfg.ConditionDependencyGraph; import software.amazon.smithy.rulesengine.traits.EndpointBddTrait; @@ -34,14 +32,21 @@ public final class SiftingOptimization implements Function { private static final Logger LOGGER = Logger.getLogger(SiftingOptimization.class.getName()); - // When to use a parallel stream private static final int PARALLEL_THRESHOLD = 7; + // Early termination: number of passes to track for plateau detection + private static final int PLATEAU_HISTORY_SIZE = 3; + private static final double PLATEAU_THRESHOLD = 0.5; + // Thread-local BDD builders to avoid allocation overhead private final ThreadLocal threadBuilder = ThreadLocal.withInitial(BddBuilder::new); 
private final Cfg cfg; private final ConditionDependencyGraph dependencyGraph; + private final ConditionCostModel costModel = ConditionCostModel.createDefault();; + + // Reusable cost estimator, created once per optimization run + private BddCostEstimator costEstimator; // Tiered optimization settings private final int coarseMinNodes; @@ -54,7 +59,7 @@ public final class SiftingOptimization implements Function= HIGH_THRESHOLD) { + sampleRate = Math.max(1, sampleRate - 1); + maxPositions = Math.min(base.maxPositions * 2, maxPositions + 5); + nearbyRadius = Math.min(base.nearbyRadius + 6, nearbyRadius + 2); + bonusPasses = Math.min(bonusPasses + 2, 6); + return true; + } else if (reductionPercent < LOW_THRESHOLD) { + sampleRate = Math.min(base.sampleRate * 2, sampleRate + 2); + maxPositions = Math.max(base.maxPositions / 2, maxPositions - 3); + nearbyRadius = Math.max(0, nearbyRadius - 2); + bonusPasses = Math.max(0, bonusPasses - 2); + } + return false; + } + } + private SiftingOptimization(Builder builder) { this.cfg = SmithyBuilder.requiredState("cfg", builder.cfg); this.coarseMinNodes = builder.coarseMinNodes; @@ -108,382 +151,477 @@ public EndpointBddTrait apply(EndpointBddTrait trait) { private EndpointBddTrait doApply(EndpointBddTrait trait) { LOGGER.info("Starting BDD sifting optimization"); long startTime = System.currentTimeMillis(); - OptimizationState state = initializeOptimization(trait); + State state = initializeOptimization(trait); LOGGER.info(String.format("Initial size: %d nodes", state.initialSize)); - state = runOptimizationStage("Coarse", state, OptimizationEffort.COARSE, coarseMinNodes, coarseMaxPasses, 4.0); - state = runOptimizationStage("Medium", state, OptimizationEffort.MEDIUM, mediumMinNodes, mediumMaxPasses, 1.5); + // Create cost estimator once for the entire optimization run + this.costEstimator = new BddCostEstimator(state.orderView, costModel, null); + + runOptimizationStage("Coarse", state, OptimizationEffort.COARSE, coarseMinNodes, 
coarseMaxPasses, 4.0); + runOptimizationStage("Medium", state, OptimizationEffort.MEDIUM, mediumMinNodes, mediumMaxPasses, 1.5); if (state.currentSize <= granularMaxNodes) { - state = runOptimizationStage("Granular", state, OptimizationEffort.GRANULAR, 0, granularMaxPasses, 0.0); - } else { - LOGGER.info("Skipping granular stage - too large"); + runOptimizationStage("Granular", state, OptimizationEffort.GRANULAR, 0, granularMaxPasses, 0.0); } - state = runAdjacentSwaps(state); + runBlockMoves(state); + runAdjacentSwaps(state); double totalTimeInSeconds = (System.currentTimeMillis() - startTime) / 1000.0; - if (state.bestSize >= state.initialSize) { + if (state.currentSize >= state.initialSize) { LOGGER.info(String.format("No improvements found in %fs", totalTimeInSeconds)); return trait; } LOGGER.info(String.format("Optimization complete: %d -> %d nodes (%.1f%% total reduction) in %fs", state.initialSize, - state.bestSize, - (1.0 - (double) state.bestSize / state.initialSize) * 100, + state.currentSize, + (1.0 - (double) state.currentSize / state.initialSize) * 100, totalTimeInSeconds)); return trait.toBuilder().conditions(state.orderView).results(state.results).bdd(state.bestBdd).build(); } - private OptimizationState initializeOptimization(EndpointBddTrait trait) { - // Use the trait's existing ordering as the starting point + private State initializeOptimization(EndpointBddTrait trait) { List initialOrder = new ArrayList<>(trait.getConditions()); Condition[] order = initialOrder.toArray(new Condition[0]); List orderView = Arrays.asList(order); Bdd bdd = trait.getBdd(); int initialSize = bdd.getNodeCount() - 1; - return new OptimizationState(order, orderView, bdd, initialSize, initialSize, trait.getResults()); + return new State(order, orderView, bdd, initialSize, trait.getResults()); } - private OptimizationState runOptimizationStage( + private void runOptimizationStage( String stageName, - OptimizationState state, + State state, OptimizationEffort effort, - int 
targetNodeCount, + int targetNodes, int maxPasses, - double minReductionPercent + double minReduction ) { - if (targetNodeCount > 0 && state.currentSize <= targetNodeCount) { - return state; + if (targetNodes > 0 && state.currentSize <= targetNodes) { + return; } - LOGGER.info(String.format("Stage: %s optimization (%d nodes%s)", - stageName, - state.currentSize, - targetNodeCount > 0 ? String.format(", target < %d", targetNodeCount) : "")); + LOGGER.info(String.format("Stage: %s (%d nodes)", stageName, state.currentSize)); + + AdaptiveEffort ae = new AdaptiveEffort(effort); + double[] history = new double[PLATEAU_HISTORY_SIZE]; + int historyIdx = 0, consecutiveLow = 0; + + for (int pass = 1; pass <= maxPasses + ae.bonusPasses; pass++) { + if (targetNodes > 0 && state.currentSize <= targetNodes) { + break; + } - OptimizationState currentState = state; - for (int pass = 1; pass <= maxPasses; pass++) { - if (targetNodeCount > 0 && currentState.currentSize <= targetNodeCount) { + int startSize = state.currentSize; + PassContext result = runPass(state, ae); + if (result.improvements == 0) { break; } - int passStartSize = currentState.currentSize; - OptimizationResult result = runPass(currentState, effort); - if (result.improved) { - currentState = currentState.withResult(result.bdd, result.size, result.results); - double reduction = (1.0 - (double) result.size / passStartSize) * 100; - LOGGER.fine(String.format("%s pass %d: %d -> %d nodes (%.1f%% reduction)", - stageName, - pass, - passStartSize, - result.size, - reduction)); - if (minReductionPercent > 0 && reduction < minReductionPercent) { - LOGGER.fine(String.format("%s optimization yielding diminishing returns", stageName)); + state.update(result.bestBdd, result.bestSize, result.bestResults); + double reduction = (1.0 - (double) result.bestSize / startSize) * 100; + + history[historyIdx++ % PLATEAU_HISTORY_SIZE] = reduction; + if (historyIdx >= PLATEAU_HISTORY_SIZE) { + boolean plateau = true; + for (double r : 
history) { + if (r >= PLATEAU_THRESHOLD) { + plateau = false; + break; + } + } + if (plateau) { break; } - } else { - LOGGER.fine(String.format("%s pass %d found no improvements", stageName, pass)); + } + + consecutiveLow = ae.adapt(reduction) ? 0 : (reduction < 2.0 ? consecutiveLow + 1 : 0); + if (consecutiveLow >= 2 || (minReduction > 0 && reduction < minReduction)) { break; } } - - return currentState; } - private OptimizationState runAdjacentSwaps(OptimizationState state) { + private void runBlockMoves(State state) { if (state.currentSize > granularMaxNodes) { - return state; + return; } + LOGGER.info("Running block moves"); - LOGGER.info("Running adjacent swaps optimization"); - OptimizationState currentState = state; - - // Run multiple sweeps until no improvement - for (int sweep = 1; sweep <= 3; sweep++) { - OptimizationContext context = new OptimizationContext(currentState, dependencyGraph); - int startSize = currentState.currentSize; + List> blocks = new ArrayList<>(); + for (List b : findDependencyBlocks(state.orderView)) { + if (b.size() >= 2 && b.size() <= 5) { + blocks.add(b); + } + } - for (int i = 0; i < currentState.order.length - 1; i++) { - // Adjacent swap requires both elements to be able to occupy each other's positions - if (context.constraints.canMove(i, i + 1) && context.constraints.canMove(i + 1, i)) { - BddCompilerSupport.move(currentState.order, i, i + 1); - BddCompilerSupport.BddCompilationResult compilationResult = - BddCompilerSupport.compile(cfg, currentState.orderView, threadBuilder.get()); - int swappedSize = compilationResult.bdd.getNodeCount() - 1; - if (swappedSize < context.bestSize) { - context = context.withImprovement( - new PositionResult(i + 1, - swappedSize, - compilationResult.bdd, - compilationResult.results)); - } else { - BddCompilerSupport.move(currentState.order, i + 1, i); // Swap back - } - } + for (List block : blocks) { + PassContext ctx = new PassContext(state, dependencyGraph); + Result r = tryBlockMove(block, 
ctx); + if (r != null && r.size < ctx.bestSize) { + state.update(r.bdd, r.size, r.results); } + } + } + + private List> findDependencyBlocks(List ordering) { + List> blocks = new ArrayList<>(); + if (ordering.isEmpty()) { + return blocks; + } - if (context.improvements > 0) { - currentState = currentState.withResult(context.bestBdd, context.bestSize, context.bestResults); - LOGGER.fine(String.format("Adjacent swaps sweep %d: %d -> %d nodes", - sweep, - startSize, - context.bestSize)); + List curr = new ArrayList<>(); + curr.add(0); + for (int i = 1; i < ordering.size(); i++) { + if (dependencyGraph.getDependencies(ordering.get(i)).contains(ordering.get(i - 1))) { + curr.add(i); } else { - break; + if (curr.size() >= 2) { + blocks.add(curr); + } + curr = new ArrayList<>(); + curr.add(i); } } - return currentState; + if (curr.size() >= 2) { + blocks.add(curr); + } + + return blocks; } - private OptimizationResult runPass(OptimizationState state, OptimizationEffort effort) { - OptimizationContext context = new OptimizationContext(state, dependencyGraph); + private Result tryBlockMove(List block, PassContext ctx) { + int blockStart = block.get(0), blockEnd = block.get(block.size() - 1), blockSize = block.size(); - List selectedConditions = IntStream.range(0, state.orderView.size()) - .filter(i -> i % effort.sampleRate == 0) - .mapToObj(state.orderView::get) - .collect(Collectors.toList()); + // Compute valid range considering all block members' constraints + int minPos = 0, maxPos = ctx.order.length - blockSize; + for (int idx : block) { + int offset = idx - blockStart; + minPos = Math.max(minPos, ctx.constraints.getMinValidPosition(idx) - offset); + maxPos = Math.min(maxPos, ctx.constraints.getMaxValidPosition(idx) - offset); + } + + if (minPos >= maxPos) { + return null; + } - for (Condition condition : selectedConditions) { - Integer varIdx = context.liveIndex.get(condition); - if (varIdx == null) { + // Try a few strategic positions: min, max, mid + int[] targets = 
{minPos, maxPos, minPos + (maxPos - minPos) / 2}; + Result best = null; + + for (int target : targets) { + if (target == blockStart) { continue; } - List positions = getStrategicPositions(varIdx, context.constraints, effort); - if (positions.isEmpty()) { + Condition[] candidate = ctx.order.clone(); + moveBlock(candidate, blockStart, blockEnd, target); + List candidateList = Arrays.asList(candidate); + + // Validate constraints + ConditionDependencyGraph.OrderConstraints nc = dependencyGraph.createOrderConstraints(candidateList); + boolean valid = true; + for (int j = 0; j < candidate.length; j++) { + if (nc.getMinValidPosition(j) > j || nc.getMaxValidPosition(j) < j) { + valid = false; + break; + } + } + + if (!valid) { continue; } - context = tryImprovePosition(context, varIdx, positions); + BddCompilerSupport.BddCompilationResult cr = + BddCompilerSupport.compile(cfg, candidateList, threadBuilder.get()); + int size = cr.bdd.getNodeCount() - 1; + double cost = computeCost(cr.bdd, candidateList); + if (best == null || size < best.size || (size == best.size && cost < best.cost)) { + best = new Result(target, size, cost, cr.bdd, cr.results); + } } + return best; + } + + /** + * Moves a contiguous block of elements from [start, end] to begin at targetStart. 
+ */ + private static void moveBlock(Condition[] order, int start, int end, int targetStart) { + if (targetStart == start) { + return; + } + + int blockSize = end - start + 1; + Condition[] block = new Condition[blockSize]; + System.arraycopy(order, start, block, 0, blockSize); - return context.toResult(); + if (targetStart < start) { + // Move block earlier: shift elements [targetStart, start) to the right + System.arraycopy(order, targetStart, order, targetStart + blockSize, start - targetStart); + System.arraycopy(block, 0, order, targetStart, blockSize); + } else { + // Move block later: shift elements (end, targetStart + blockSize) to the left + int shiftStart = end + 1; + int shiftEnd = targetStart + blockSize; + if (shiftEnd > order.length) { + shiftEnd = order.length; + } + System.arraycopy(order, shiftStart, order, start, shiftEnd - shiftStart); + System.arraycopy(block, 0, order, targetStart, blockSize); + } } - private OptimizationContext tryImprovePosition(OptimizationContext context, int varIdx, List positions) { - PositionResult best = findBestPosition(positions, context, varIdx); - if (best != null && best.count <= context.bestSize) { // Accept ties - BddCompilerSupport.move(context.order, varIdx, best.position); - return context.withImprovement(best); + private void runAdjacentSwaps(State state) { + if (state.currentSize > granularMaxNodes) { + return; } - return context; + for (int sweep = 0; sweep < 3; sweep++) { + PassContext ctx = new PassContext(state, dependencyGraph); + for (int i = 0; i < state.order.length - 1; i++) { + // Adjacent swap requires both elements to be able to occupy each other's positions + if (ctx.constraints.canMove(i, i + 1) && ctx.constraints.canMove(i + 1, i)) { + BddCompilerSupport.move(state.order, i, i + 1); + BddCompilerSupport.BddCompilationResult cr = BddCompilerSupport.compile( + cfg, + state.orderView, + threadBuilder.get()); + int size = cr.bdd.getNodeCount() - 1; + if (size < ctx.bestSize) { + 
ctx.recordImprovement(new Result(i + 1, size, cr.bdd, cr.results, null)); + } else { + BddCompilerSupport.move(state.order, i + 1, i); + } + } + } + if (ctx.improvements == 0) { + break; + } + state.update(ctx.bestBdd, ctx.bestSize, ctx.bestResults); + } + } + + private PassContext runPass(State state, AdaptiveEffort effort) { + PassContext ctx = new PassContext(state, dependencyGraph); + int[] nodeCounts = computeNodeCountsPerVariable(state.bestBdd); + int[] selectedIndices = selectConditionsByPriority(state.orderView.size(), nodeCounts, effort.sampleRate); + + for (int varIdx : selectedIndices) { + List positions = getStrategicPositions(varIdx, ctx.constraints, effort, state.orderView.size()); + if (positions.isEmpty()) { + continue; + } + Result best = findBestPosition(positions, ctx, varIdx); + if (best != null && best.size <= ctx.bestSize) { + BddCompilerSupport.move(ctx.order, varIdx, best.position); + ctx.recordImprovement(best); + } + } + return ctx; + } + + /** + * Computes the number of BDD nodes testing each variable. + */ + private static int[] computeNodeCountsPerVariable(Bdd bdd) { + int[] counts = new int[bdd.getConditionCount()]; + for (int i = 0; i < bdd.getNodeCount(); i++) { + int v = bdd.getVariable(i); + if (v >= 0 && v < counts.length) { + counts[v]++; + } + } + return counts; + } + + private static int[] selectConditionsByPriority(int n, int[] nodeCounts, int sampleRate) { + int[] indices = IntStream.range(0, n) + .boxed() + .sorted((a, b) -> Integer.compare(nodeCounts[b], nodeCounts[a])) + .mapToInt(i -> i) + .toArray(); + return sampleRate <= 1 ? indices : Arrays.copyOf(indices, Math.max(1, n / sampleRate)); } - private PositionResult findBestPosition(List positions, OptimizationContext ctx, int varIdx) { - return (positions.size() > PARALLEL_THRESHOLD ? positions.parallelStream() : positions.stream()) + /** Two-pass position finder: compile candidates, then cost-break ties among min-size. 
*/ + private Result findBestPosition(List positions, PassContext ctx, int varIdx) { + // First pass: compile all candidates + List candidates = (positions.size() > PARALLEL_THRESHOLD + ? positions.parallelStream() + : positions.stream()) .map(pos -> { Condition[] order = ctx.order.clone(); BddCompilerSupport.move(order, varIdx, pos); + List orderList = Arrays.asList(order); BddCompilerSupport.BddCompilationResult cr = - BddCompilerSupport.compile(cfg, Arrays.asList(order), threadBuilder.get()); - return new PositionResult(pos, cr.bdd.getNodeCount() - 1, cr.bdd, cr.results); + BddCompilerSupport.compile(cfg, orderList, threadBuilder.get()); + return new Result(pos, cr.bdd.getNodeCount() - 1, cr.bdd, cr.results, orderList); }) - .filter(pr -> pr.count <= ctx.bestSize) - .min(Comparator.comparingInt((PositionResult pr) -> pr.count).thenComparingInt(pr -> pr.position)) - .orElse(null); + .filter(c -> c.size <= ctx.bestSize) + .collect(Collectors.toList()); + + if (candidates.isEmpty()) { + return null; + } + + // Second pass: among min-size candidates, pick lowest cost + int minSize = Integer.MAX_VALUE; + for (Result c : candidates) { + if (c.size < minSize) { + minSize = c.size; + } + } + + Result best = null; + for (Result c : candidates) { + if (c.size == minSize) { + double cost = computeCost(c.bdd, c.orderList); + if (best == null || cost < best.cost || (cost == best.cost && c.position < best.position)) { + best = new Result(c.position, c.size, cost, c.bdd, c.results); + } + } + } + return best; + } + + private double computeCost(Bdd bdd, List ordering) { + return costEstimator.expectedCost(bdd, ordering); } private static List getStrategicPositions( int varIdx, - ConditionDependencyGraph.OrderConstraints constraints, - OptimizationEffort effort + ConditionDependencyGraph.OrderConstraints c, + AdaptiveEffort ae, + int orderSize ) { - int min = constraints.getMinValidPosition(varIdx); - int max = constraints.getMaxValidPosition(varIdx); + int min = 
c.getMinValidPosition(varIdx); + int max = c.getMaxValidPosition(varIdx); int range = max - min; - if (range <= effort.exhaustiveThreshold) { - List positions = new ArrayList<>(range); + // Exhaustive for small ranges + if (range <= ae.base.exhaustiveThreshold) { + List pos = new ArrayList<>(range); for (int p = min; p < max; p++) { - if (p != varIdx && constraints.canMove(varIdx, p)) { - positions.add(p); + if (p != varIdx && c.canMove(varIdx, p)) { + pos.add(p); } } - return positions; + return pos; } - List positions = new ArrayList<>(effort.maxPositions); + List pos = new ArrayList<>(ae.maxPositions); + boolean[] seen = new boolean[orderSize]; - // Test extremes first since they often yield the best improvements - if (min != varIdx && constraints.canMove(varIdx, min)) { - positions.add(min); - } - if (positions.size() >= effort.maxPositions) { - return positions; + // Extremes + if (min != varIdx && c.canMove(varIdx, min)) { + pos.add(min); + seen[min] = true; } - if (max - 1 != varIdx && constraints.canMove(varIdx, max - 1)) { - positions.add(max - 1); - } - if (positions.size() >= effort.maxPositions) { - return positions; + if (max - 1 != varIdx && c.canMove(varIdx, max - 1)) { + pos.add(max - 1); + seen[max - 1] = true; } - // Test local moves that preserve relative ordering with neighbors - for (int offset = -effort.nearbyRadius; offset <= effort.nearbyRadius; offset++) { - if (offset != 0) { - if (positions.size() >= effort.maxPositions) { - return positions; - } - int p = varIdx + offset; - if (p >= min && p < max && !positions.contains(p) && constraints.canMove(varIdx, p)) { - positions.add(p); - } + // Global sampling + int step = Math.max(1, range / Math.min(15, ae.maxPositions / 2)); + for (int p = min + step; p < max - step && pos.size() < ae.maxPositions; p += step) { + if (p != varIdx && !seen[p] && c.canMove(varIdx, p)) { + pos.add(p); + seen[p] = true; } } - // Sample intermediate positions to find global improvements - if (positions.size() >= 
effort.maxPositions) { - return positions; - } - - int maxSamples = Math.min(15, effort.maxPositions / 2); - int samples = Math.min(maxSamples, Math.max(2, range / 4)); - int step = Math.max(1, range / samples); - - for (int p = min + step; p < max - step && positions.size() < effort.maxPositions; p += step) { - if (p != varIdx && !positions.contains(p) && constraints.canMove(varIdx, p)) { - positions.add(p); + // Local neighborhood + for (int off = -ae.nearbyRadius; off <= ae.nearbyRadius && pos.size() < ae.maxPositions; off++) { + int p = varIdx + off; + if (off != 0 && p >= min && p < max && !seen[p] && c.canMove(varIdx, p)) { + pos.add(p); + seen[p] = true; } } - return positions; - } - - private static Map rebuildIndex(List orderView) { - Map index = new IdentityHashMap<>(); - for (int i = 0; i < orderView.size(); i++) { - index.put(orderView.get(i), i); - } - return index; + return pos; } - // Helper class to track optimization context within a pass - private static final class OptimizationContext { + /** Mutable context for tracking optimization progress within a pass. 
*/ + private static final class PassContext { final Condition[] order; final List orderView; final ConditionDependencyGraph dependencyGraph; - final ConditionDependencyGraph.OrderConstraints constraints; - final Map liveIndex; - final Bdd bestBdd; - final int bestSize; - final List bestResults; - final int improvements; - - OptimizationContext(OptimizationState state, ConditionDependencyGraph dependencyGraph) { + ConditionDependencyGraph.OrderConstraints constraints; + Bdd bestBdd; + int bestSize; + List bestResults; + int improvements; + + PassContext(State state, ConditionDependencyGraph dependencyGraph) { this.order = state.order; this.orderView = state.orderView; - this.dependencyGraph = dependencyGraph; - this.constraints = dependencyGraph.createOrderConstraints(orderView); - this.liveIndex = rebuildIndex(orderView); - this.bestBdd = null; this.bestSize = state.currentSize; - this.bestResults = null; - this.improvements = 0; - } - - private OptimizationContext( - Condition[] order, - List orderView, - ConditionDependencyGraph dependencyGraph, - ConditionDependencyGraph.OrderConstraints constraints, - Map liveIndex, - Bdd bestBdd, - int bestSize, - List bestResults, - int improvements - ) { - this.order = order; - this.orderView = orderView; this.dependencyGraph = dependencyGraph; - this.constraints = constraints; - this.liveIndex = liveIndex; - this.bestBdd = bestBdd; - this.bestSize = bestSize; - this.bestResults = bestResults; - this.improvements = improvements; - } - - OptimizationContext withImprovement(PositionResult result) { - ConditionDependencyGraph.OrderConstraints newConstraints = - dependencyGraph.createOrderConstraints(orderView); - Map newIndex = rebuildIndex(orderView); - return new OptimizationContext(order, - orderView, - dependencyGraph, - newConstraints, - newIndex, - result.bdd, - result.count, - result.results, - improvements + 1); - } - - OptimizationResult toResult() { - return new OptimizationResult(bestBdd, bestSize, improvements > 0, 
bestResults); + this.constraints = dependencyGraph.createOrderConstraints(orderView); + } + + void recordImprovement(Result result) { + this.bestBdd = result.bdd; + this.bestSize = result.size; + this.bestResults = result.results; + this.constraints = dependencyGraph.createOrderConstraints(orderView); + this.improvements++; } } - private static final class PositionResult { + /** Result holder for BDD compilation with optional position/cost metadata. */ + private static final class Result { final int position; - final int count; + final int size; + final double cost; final Bdd bdd; final List results; + final List orderList; // For deferred cost computation - PositionResult(int position, int count, Bdd bdd, List results) { - this.position = position; - this.count = count; - this.bdd = bdd; - this.results = results; + Result(int position, int size, Bdd bdd, List results, List orderList) { + this(position, size, Double.MAX_VALUE, bdd, results, orderList); } - } - private static final class OptimizationResult { - final Bdd bdd; - final int size; - final boolean improved; - final List results; + Result(int position, int size, double cost, Bdd bdd, List results) { + this(position, size, cost, bdd, results, null); + } - OptimizationResult(Bdd bdd, int size, boolean improved, List results) { - this.bdd = bdd; + Result(int position, int size, double cost, Bdd bdd, List results, List orderList) { + this.position = position; this.size = size; - this.improved = improved; + this.cost = cost; + this.bdd = bdd; this.results = results; + this.orderList = orderList; } } - private static final class OptimizationState { + /** Tracks overall optimization state across stages. 
*/ + private static final class State { final Condition[] order; final List orderView; - final Bdd bestBdd; - final int currentSize; - final int bestSize; final int initialSize; - final List results; + Bdd bestBdd; + int currentSize; + List results; - OptimizationState( - Condition[] order, - List orderView, - Bdd bestBdd, - int currentSize, - int initialSize, - List results - ) { + State(Condition[] order, List orderView, Bdd bdd, int size, List results) { this.order = order; this.orderView = orderView; - this.bestBdd = bestBdd; - this.currentSize = currentSize; - this.bestSize = currentSize; - this.initialSize = initialSize; + this.bestBdd = bdd; + this.currentSize = size; + this.initialSize = size; this.results = results; } - OptimizationState withResult(Bdd newBdd, int newSize, List newResults) { - return new OptimizationState(order, orderView, newBdd, newSize, initialSize, newResults); + void update(Bdd bdd, int size, List results) { + this.bestBdd = bdd; + this.currentSize = size; + this.results = results; } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilder.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilder.java index 1cadcce2525..ac057d61179 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilder.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilder.java @@ -9,16 +9,13 @@ import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.Optional; import software.amazon.smithy.rulesengine.language.EndpointRuleSet; -import software.amazon.smithy.rulesengine.language.syntax.Identifier; import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; -import software.amazon.smithy.rulesengine.language.syntax.expressions.Reference; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.BooleanEquals; +import 
software.amazon.smithy.rulesengine.language.syntax.expressions.functions.IsSet; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.LibraryFunction; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Not; import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; -import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; import software.amazon.smithy.rulesengine.logic.ConditionReference; @@ -41,6 +38,9 @@ public final class CfgBuilder { private final Map resultCache = new HashMap<>(); private final Map resultNodeCache = new HashMap<>(); + // Track function expressions that have bindings (for isSet consolidation) + private final Map functionBindings = new HashMap<>(); + public CfgBuilder(EndpointRuleSet ruleSet) { // Apply SSA transform to ensure globally unique variable names this.ruleSet = SsaTransform.transform(ruleSet); @@ -115,9 +115,25 @@ public ConditionReference createConditionReference(Condition condition) { negated = !negated; } + // Consolidate isSet(f(x)) with existing v = f(x) bindings + canonical = consolidateIsSetWithBinding(canonical); + + // Deep-copy via serialization to get fresh Expression objects. + // This avoids sharing expressions that may have cached types from + // being type-checked in different scopes during EndpointRuleSet.build(). 
+ canonical = Condition.fromNode(canonical.toNode()); + + // Track bindings for future isSet consolidation + if (canonical.getResult().isPresent()) { + String fnKey = canonical.getFunction().toString(); + functionBindings.putIfAbsent(fnKey, canonical); + } + ConditionReference reference = new ConditionReference(canonical, negated); conditionToReference.put(condition, reference); + // Also cache the canonical form so equivalent conditions from different branches + // that might have different original objects will still hit the cache. if (!negated && !condition.equals(canonical)) { conditionToReference.put(canonical, reference); } @@ -129,6 +145,32 @@ private Rule intern(Rule rule) { return resultCache.computeIfAbsent(canonicalizeResult(rule), k -> k); } + /** + * Consolidates {@code isSet(f(x))} with an existing {@code v = f(x)} binding. + * + *

This catches patterns that tree-level optimization cannot handle. Specifically, when the tree + * contains {@code not(isSet(f(x)))} in one branch and {@code v = f(x)} in another branch. + * Tree-level transforms can't handle this because they don't have visibility across sibling branches. + * The CFG builder's global {@code functionBindings} map provides this cross-branch visibility. + * + * @see SsaTransform for the full explanation of why both tree-level and CFG-level optimization are needed + */ + private Condition consolidateIsSetWithBinding(Condition condition) { + if (!(condition.getFunction() instanceof IsSet) || condition.getResult().isPresent()) { + return condition; + } + Expression inner = condition.getFunction().getArguments().get(0); + if (!(inner instanceof LibraryFunction)) { + return condition; + } + String fnKey = inner.toString(); + Condition existingBinding = functionBindings.get(fnKey); + if (existingBinding != null) { + return existingBinding; + } + return condition; + } + private Rule canonicalizeResult(Rule rule) { return rule == null ? 
null : rule.withConditions(Collections.emptyList()); } @@ -139,19 +181,23 @@ private Condition canonicalizeBooleanEquals(Condition condition) { } List args = condition.getFunction().getArguments(); - if (args.size() != 2 || !(args.get(0) instanceof Reference) || !(args.get(1) instanceof Literal)) { + if (args.size() != 2) { + return condition; + } + + // After canonicalization, literals should be in arg1 position + // Check if arg1 is a boolean literal with value false + if (!(args.get(1) instanceof Literal)) { return condition; } - Reference ref = (Reference) args.get(0); Boolean literalValue = ((Literal) args.get(1)).asBooleanLiteral().orElse(null); - if (literalValue != null && !literalValue && ruleSet != null) { - String varName = ref.getName().toString(); - Optional param = ruleSet.getParameters().get(Identifier.of(varName)); - if (param.isPresent() && param.get().getDefault().isPresent()) { - return condition.toBuilder().fn(BooleanEquals.ofExpressions(ref, true)).build(); - } + // Normalize booleanEquals(X, false) to booleanEquals(X, true) with negation + if (literalValue != null && !literalValue) { + return condition.toBuilder() + .fn(BooleanEquals.ofExpressions(args.get(0), true)) + .build(); } return condition; diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CoalesceTransform.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CoalesceTransform.java index 365ebe3e77c..17196dbc788 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CoalesceTransform.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CoalesceTransform.java @@ -39,14 +39,13 @@ final class CoalesceTransform { private int cacheHits = 0; private int skippedNoZeroValue = 0; private int skippedMultipleUses = 0; - private final Set skippedRecordTypes = new HashSet<>(); static EndpointRuleSet transform(EndpointRuleSet ruleSet) { 
CoalesceTransform transform = new CoalesceTransform(); List transformedRules = new ArrayList<>(); - for (int i = 0; i < ruleSet.getRules().size(); i++) { - transformedRules.add(transform.transformRule(ruleSet.getRules().get(i), "root/rule[" + i + "]")); + for (Rule rule : ruleSet.getRules()) { + transformedRules.add(transform.transformRule(rule)); } if (LOGGER.isLoggable(Level.INFO)) { @@ -65,7 +64,7 @@ static EndpointRuleSet transform(EndpointRuleSet ruleSet) { .build(); } - private Rule transformRule(Rule rule, String rulePath) { + private Rule transformRule(Rule rule) { // Count local usage for THIS rule's conditions Map localVarUsage = new HashMap<>(); for (Condition condition : rule.getConditions()) { @@ -82,10 +81,14 @@ private Rule transformRule(Rule rule, String rulePath) { if (rule instanceof TreeRule) { TreeRule treeRule = (TreeRule) rule; + List transformedNested = new ArrayList<>(); + for (Rule nested : treeRule.getRules()) { + transformedNested.add(transformRule(nested)); + } return TreeRule.builder() .description(rule.getDocumentation().orElse(null)) .conditions(transformedConditions) - .treeRule(TreeRewriter.transformNestedRules(treeRule, rulePath, this::transformRule)); + .treeRule(transformedNested); } // CoalesceTransform only modifies conditions, not endpoints/errors @@ -144,7 +147,6 @@ private boolean canCoalesce(String var, Condition bind, Condition use, Map replacements = new HashMap<>(); replacements.put(var, coalesced); - Expression replaced = TreeRewriter.forReplacements(replacements).rewrite(useExpr); + Expression replaced = TreeMapper.newReferenceReplacingMapper(replacements).expression(useExpr); LibraryFunction canonicalized = ((LibraryFunction) replaced).canonicalize(); Condition.Builder builder = Condition.builder().fn(canonicalized); diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/DeadStoreEliminationTransform.java 
b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/DeadStoreEliminationTransform.java new file mode 100644 index 00000000000..58cb097ccf6 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/DeadStoreEliminationTransform.java @@ -0,0 +1,51 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.cfg; + +import java.util.logging.Logger; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; + +/** + * Removes bindings from conditions when the bound variable is never used. + * + *

As part of the optimization pipeline, we rewrite some expressions to create a + * variable binding in case the expressions can be consolidated with + * VariableConsolidationTransform. When they can't, it would leave a dangling binding + * that isn't used, so this transform rewrites those conditions to not create + * these dead stores. + */ +final class DeadStoreEliminationTransform extends TreeMapper { + private static final Logger LOGGER = Logger.getLogger(DeadStoreEliminationTransform.class.getName()); + + private int eliminated = 0; + private final VariableAnalysis analysis; + + private DeadStoreEliminationTransform(VariableAnalysis analysis) { + this.analysis = analysis; + } + + static EndpointRuleSet transform(EndpointRuleSet ruleSet) { + DeadStoreEliminationTransform t = new DeadStoreEliminationTransform(VariableAnalysis.analyze(ruleSet)); + EndpointRuleSet result = t.endpointRuleSet(ruleSet); + if (t.eliminated > 0) { + LOGGER.info(() -> "Dead store elimination: " + t.eliminated + " bindings removed"); + } + return result; + } + + @Override + public Condition condition(Rule rule, Condition cond) { + if (cond.getResult().isPresent()) { + String varName = cond.getResult().get().toString(); + if (analysis.getReferenceCount(varName) == 0) { + eliminated++; + return Condition.builder().fn(cond.getFunction()).build(); + } + } + return cond; + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/IsSetBooleanCoalesceTransform.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/IsSetBooleanCoalesceTransform.java new file mode 100644 index 00000000000..de855ed8618 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/IsSetBooleanCoalesceTransform.java @@ -0,0 +1,144 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.cfg; + +import java.util.ArrayList; +import java.util.List; +import java.util.logging.Logger; +import software.amazon.smithy.model.node.ArrayNode; +import software.amazon.smithy.model.node.ObjectNode; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Reference; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.BooleanEquals; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Coalesce; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.IsSet; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.LibraryFunction; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; + +/** + * Coalesces consecutive isSet + booleanEquals patterns into a single coalesced check. + * + *

This transform identifies patterns where: + *

    + *
  • {@code isSet(X)} is immediately followed by {@code booleanEquals(X, true)} + *
  • {@code isSet(X)} is immediately followed by {@code booleanEquals(X, false)} + *
+ * + *

These patterns are replaced with: + *

    + *
  • {@code booleanEquals(coalesce(X, false), true)} - equivalent to "X is set and true" + *
  • {@code booleanEquals(coalesce(X, true), false)} - equivalent to "X is set and false" + *
+ * + *

This reduces the number of conditions in the BDD, improving both space and potentially node count. + */ +final class IsSetBooleanCoalesceTransform extends TreeMapper { + private static final Logger LOGGER = Logger.getLogger(IsSetBooleanCoalesceTransform.class.getName()); + + private int coalescedCount = 0; + + static EndpointRuleSet transform(EndpointRuleSet ruleSet) { + IsSetBooleanCoalesceTransform t = new IsSetBooleanCoalesceTransform(); + List transformedRules = new ArrayList<>(); + for (Rule rule : ruleSet.getRules()) { + transformedRules.add(t.rule(rule)); + } + + if (t.coalescedCount > 0) { + LOGGER.info(() -> String.format("IsSet+boolean coalesce: %d patterns collapsed", t.coalescedCount)); + } + + // Build the node representation manually to avoid type-checking, then deserialize + // to get fresh Expression objects without cached types. + ObjectNode.Builder builder = ruleSet.toNode().expectObjectNode().toBuilder(); + ArrayNode.Builder rulesBuilder = ArrayNode.builder(); + for (Rule rule : transformedRules) { + rulesBuilder.withValue(rule.toNode()); + } + builder.withMember("rules", rulesBuilder.build()); + + return EndpointRuleSet.fromNode(builder.build()); + } + + /** + * Overrides conditions processing to look at pairs of consecutive conditions. + * This is necessary because we need to merge isSet(X) + booleanEquals(X, val) patterns. 
+ */ + @Override + public List conditions(Rule rule, List conditions) { + List result = new ArrayList<>(); + int size = conditions.size(); + + for (int i = 0; i < size; i++) { + Condition current = conditions.get(i); + // Check for isSet(X) pattern + if (i + 1 < size && !current.getResult().isPresent() && current.getFunction() instanceof IsSet) { + Expression isSetArg = current.getFunction().getArguments().get(0); + if (isSetArg instanceof Reference) { + Condition next = conditions.get(i + 1); + + // Check if next is booleanEquals(X, true/false) + Condition coalesced = tryCoalesce((Reference) isSetArg, next); + if (coalesced != null) { + result.add(coalesced); + coalescedCount++; + i++; // Skip the next condition since we ate it + continue; + } + } + } + + result.add(current); + } + + return result; + } + + private Condition tryCoalesce(Reference isSetRef, Condition next) { + if (next.getResult().isPresent()) { + return null; // Can't coalesce if next has a binding + } + + LibraryFunction fn = next.getFunction(); + if (!(fn instanceof BooleanEquals)) { + return null; + } + + List args = fn.getArguments(); + Expression arg1 = args.get(0); + Expression arg2 = args.get(1); + + // Ensure the first argument is a reference to the isSet(X) condition + if (!(arg1 instanceof Reference)) { + return null; + } + Reference boolRef = (Reference) arg1; + if (!boolRef.getName().equals(isSetRef.getName())) { + return null; + } + + // Ensure the second argument is a boolean literal + if (!(arg2 instanceof Literal)) { + return null; + } + Boolean literalValue = ((Literal) arg2).asBooleanLiteral().orElse(null); + if (literalValue == null) { + return null; + } + + // Create a coalesced condition that sets the default to the opposite of the literal value: + // isSet(X) && booleanEquals(X, true) -> booleanEquals(coalesce(X, false), true) + // isSet(X) && booleanEquals(X, false) -> booleanEquals(coalesce(X, true), false) + boolean defaultValue = !literalValue; + + // Create a fresh 
Reference to avoid sharing cached type information with other conditions + Reference freshRef = Expression.getReference(isSetRef.getName()); + Expression coalesced = Coalesce.ofExpressions(freshRef, Expression.of(defaultValue)); + return Condition.builder().fn(BooleanEquals.ofExpressions(coalesced, literalValue)).build(); + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransform.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransform.java index f1da1f16520..62659aa46af 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransform.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransform.java @@ -5,114 +5,104 @@ package software.amazon.smithy.rulesengine.logic.cfg; import java.util.ArrayDeque; -import java.util.ArrayList; import java.util.Deque; import java.util.HashMap; -import java.util.HashSet; -import java.util.IdentityHashMap; -import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Set; -import software.amazon.smithy.rulesengine.language.Endpoint; import software.amazon.smithy.rulesengine.language.EndpointRuleSet; import software.amazon.smithy.rulesengine.language.syntax.Identifier; import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; import software.amazon.smithy.rulesengine.language.syntax.expressions.Reference; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.LibraryFunction; import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; -import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; -import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; -import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; /** - * Transforms a decision tree into Static 
Single Assignment (SSA) form. + * Transforms a decision tree into Static Single Assignment (SSA) form and orchestrates the pre-BDD optimization + * pipeline (see transform()). * - *

This transformation ensures that each variable is assigned exactly once by renaming variables when they are - * reassigned in different parts of the tree. For example, if variable "x" is assigned in multiple branches, they - * become "x_ssa_1", "x_ssa_2", "x_ssa_3", etc. Without this transform, the BDD compilation would confuse divergent - * paths that have the same variable name. + *

Why Tree-Level AND CFG-Level Optimization?

* - *

Note that this transform is only applied when the reassignment is done using different - * arguments than previously seen assignments of the same variable name. + *

Tree-level transforms provide the bulk of optimization, but {@link CfgBuilder} performs additional + * consolidation via {@code consolidateIsSetWithBinding}. This catches cross-branch patterns that tree-level + * transforms cannot see. + * + *

Specifically, when one branch contains {@code not(isSet(f(x)))} and another branch contains + * {@code v = f(x)}, tree-level transforms can't consolidate them because: + *

    + *
  • {@link SyntheticBindingTransform} doesn't see the inner {@code isSet} - it checks the outer function + * which is {@code Not}, not {@code IsSet} + *
  • Tree transforms don't have visibility across sibling branches + *
+ * + *

During CFG construction, {@code CfgBuilder#isNegationWrapper} unwraps the {@code Not} to get + * {@code isSet(f(x))}, and {@code consolidateIsSetWithBinding} can then consolidate it with the existing + * binding from the other branch using its global {@code functionBindings} map. + * + *

SSA Renaming

+ * + *

The SSA portion ensures each variable is assigned exactly once by renaming variables when they are + * reassigned in different parts of the tree. For example, if variable "x" is assigned in multiple branches + * with different expressions, they become "x_ssa_1", "x_ssa_2", etc. Without this, BDD compilation would + * incorrectly share nodes for divergent paths that happen to use the same variable name. + * + *

SSA renaming is only applied when reassignment uses different arguments than previously seen assignments. + * + * @see CfgBuilder#createConditionReference for the CFG-level consolidation that catches cross-branch patterns */ -final class SsaTransform { +final class SsaTransform extends TreeMapper { private final Deque> scopeStack = new ArrayDeque<>(); - private final Map rewrittenConditions = new IdentityHashMap<>(); - private final Map rewrittenRules = new IdentityHashMap<>(); - private final VariableAnalysis variableAnalysis; - private final TreeRewriter referenceRewriter; - - private SsaTransform(VariableAnalysis variableAnalysis) { - scopeStack.push(new HashMap<>()); - this.variableAnalysis = variableAnalysis; - this.referenceRewriter = new TreeRewriter(this::referenceRewriter, this::needsRewriting); - } + private final VariableAnalysis analysis; - private Expression referenceRewriter(Reference ref) { - String originalName = ref.getName().toString(); - String uniqueName = resolveReference(originalName); - return Expression.getReference(Identifier.of(uniqueName)); + private SsaTransform(VariableAnalysis analysis) { + this.analysis = analysis; + // Seed initial scope with input parameters. 
+ Map initialScope = new HashMap<>(); + for (String param : analysis.getInputParams()) { + initialScope.put(param, param); + } + scopeStack.push(initialScope); } static EndpointRuleSet transform(EndpointRuleSet ruleSet) { + // Collapse isSet(X) + booleanEquals(X, true/false) into coalesced checks + ruleSet = IsSetBooleanCoalesceTransform.transform(ruleSet); + // Assign synthetic bindings to enable variable consolidation + ruleSet = SyntheticBindingTransform.transform(ruleSet); + // Consolidate variables and eliminate redundant bindings ruleSet = VariableConsolidationTransform.transform(ruleSet); + // Remove bindings that are never used (before coalescing inlines them) + ruleSet = DeadStoreEliminationTransform.transform(ruleSet); + // Coalesces bind-then-use patterns to reduce condition count and branching. ruleSet = CoalesceTransform.transform(ruleSet); - VariableAnalysis variableAnalysis = VariableAnalysis.analyze(ruleSet); - SsaTransform ssaTransform = new SsaTransform(variableAnalysis); - - List rewrittenRules = new ArrayList<>(ruleSet.getRules().size()); - for (Rule original : ruleSet.getRules()) { - rewrittenRules.add(ssaTransform.processRule(original)); - } - - return EndpointRuleSet.builder() - .parameters(ruleSet.getParameters()) - .rules(rewrittenRules) - .version(ruleSet.getVersion()) - .build(); + // Do an SSA transform so divergent paths in the BDD remain logically divergent and thereby correct. 
+ return new SsaTransform(VariableAnalysis.analyze(ruleSet)).endpointRuleSet(ruleSet); } - private Rule processRule(Rule rule) { - enterScope(); - Rule rewrittenRule = rewriteRule(rule); - exitScope(); - return rewrittenRule; - } - - private void enterScope() { - scopeStack.push(new HashMap<>(scopeStack.peek())); - } - - private void exitScope() { - if (scopeStack.size() <= 1) { - throw new IllegalStateException("Cannot exit global scope"); + @Override + public Rule rule(Rule r) { + scopeStack.push(new HashMap<>(peekScope())); + try { + return super.rule(r); + } finally { + scopeStack.pop(); } - scopeStack.pop(); } - private Condition rewriteCondition(Condition condition) { - boolean hasBinding = condition.getResult().isPresent(); - - if (!hasBinding) { - Condition cached = rewrittenConditions.get(condition); - if (cached != null) { - return cached; - } - } - + @Override + public Condition condition(Rule rule, Condition condition) { LibraryFunction fn = condition.getFunction(); - Set rewritableRefs = filterOutInputParameters(fn.getReferences()); String uniqueBindingName = null; boolean needsUniqueBinding = false; - if (hasBinding) { + if (condition.getResult().isPresent()) { String varName = condition.getResult().get().toString(); // Only need SSA rename if variable has multiple bindings - if (variableAnalysis.hasMultipleBindings(varName)) { - Map expressionMap = variableAnalysis.getExpressionMappings().get(varName); + if (analysis.hasMultipleBindings(varName)) { + Map expressionMap = analysis.getExpressionMappings().get(varName); if (expressionMap != null) { uniqueBindingName = expressionMap.get(fn.toString()); needsUniqueBinding = uniqueBindingName != null && !uniqueBindingName.equals(varName); @@ -120,158 +110,59 @@ private Condition rewriteCondition(Condition condition) { } } - if (!needsRewriting(rewritableRefs) && !needsUniqueBinding) { - if (!hasBinding) { - rewrittenConditions.put(condition, condition); - } + if (doesNotNeedRewriting(fn.getReferences()) && 
!needsUniqueBinding) { return condition; } - LibraryFunction rewrittenExpr = (LibraryFunction) referenceRewriter.rewrite(fn); - boolean exprChanged = rewrittenExpr != fn; + LibraryFunction rewrittenFn = libraryFunction(fn); + boolean fnChanged = rewrittenFn != fn; - Condition rewritten; - if (hasBinding && uniqueBindingName != null) { - scopeStack.peek().put(condition.getResult().get().toString(), uniqueBindingName); - if (needsUniqueBinding || exprChanged) { - rewritten = condition.toBuilder().fn(rewrittenExpr).result(Identifier.of(uniqueBindingName)).build(); - } else { - rewritten = condition; + if (condition.getResult().isPresent() && uniqueBindingName != null) { + bindVariable(condition.getResult().get().toString(), uniqueBindingName); + if (needsUniqueBinding || fnChanged) { + return condition.toBuilder().fn(rewrittenFn).result(Identifier.of(uniqueBindingName)).build(); } - } else if (exprChanged) { - rewritten = condition.toBuilder().fn(rewrittenExpr).build(); - } else { - rewritten = condition; - } - - if (!hasBinding) { - rewrittenConditions.put(condition, rewritten); + } else if (fnChanged) { + return condition.toBuilder().fn(rewrittenFn).build(); } - return rewritten; + return condition; } - private Set filterOutInputParameters(Set references) { - if (references.isEmpty() || variableAnalysis.getInputParams().isEmpty()) { - return references; - } - - Set filtered = new HashSet<>(references); - filtered.removeAll(variableAnalysis.getInputParams()); - return filtered; + Map peekScope() { + return Objects.requireNonNull(scopeStack.peek(), "Scope stack is empty"); } - private boolean needsRewriting(Set references) { - if (references.isEmpty()) { - return false; - } - - Map currentScope = scopeStack.peek(); - for (String ref : references) { - String mapped = currentScope.get(ref); - if (mapped != null && !mapped.equals(ref)) { - return true; - } - } - return false; - } - - private boolean needsRewriting(Expression expression) { - return 
needsRewriting(filterOutInputParameters(expression.getReferences())); - } - - private Rule rewriteRule(Rule rule) { - Rule cached = rewrittenRules.get(rule); - if (cached != null) { - return cached; - } - - List rewrittenConditions = rewriteConditions(rule.getConditions()); - boolean conditionsChanged = !rewrittenConditions.equals(rule.getConditions()); - - Rule result; - if (rule instanceof EndpointRule) { - result = rewriteEndpointRule((EndpointRule) rule, rewrittenConditions, conditionsChanged); - } else if (rule instanceof ErrorRule) { - result = rewriteErrorRule((ErrorRule) rule, rewrittenConditions, conditionsChanged); - } else if (rule instanceof TreeRule) { - result = rewriteTreeRule((TreeRule) rule, rewrittenConditions, conditionsChanged); - } else if (conditionsChanged) { - throw new UnsupportedOperationException("Cannot change rule: " + rule); - } else { - result = rule; - } - - rewrittenRules.put(rule, result); - return result; - } - - private List rewriteConditions(List conditions) { - List rewritten = new ArrayList<>(conditions.size()); - for (Condition condition : conditions) { - rewritten.add(rewriteCondition(condition)); - } - return rewritten; + @Override + public Expression expression(Expression expression) { + return doesNotNeedRewriting(expression.getReferences()) ? 
expression : super.expression(expression); } - private Rule rewriteEndpointRule( - EndpointRule rule, - List rewrittenConditions, - boolean conditionsChanged - ) { - Endpoint rewrittenEndpoint = referenceRewriter.rewriteEndpoint(rule.getEndpoint()); - - if (conditionsChanged || rewrittenEndpoint != rule.getEndpoint()) { - return EndpointRule.builder() - .description(rule.getDocumentation().orElse(null)) - .conditions(rewrittenConditions) - .endpoint(rewrittenEndpoint); + @Override + public Reference reference(Reference ref) { + String originalName = ref.getName().toString(); + String uniqueName = peekScope().getOrDefault(originalName, originalName); + if (uniqueName.equals(originalName)) { + return ref; } - - return rule; + return Expression.getReference(Identifier.of(uniqueName)); } - private Rule rewriteErrorRule(ErrorRule rule, List rewrittenConditions, boolean conditionsChanged) { - Expression rewrittenError = referenceRewriter.rewrite(rule.getError()); - - if (conditionsChanged || rewrittenError != rule.getError()) { - return ErrorRule.builder() - .description(rule.getDocumentation().orElse(null)) - .conditions(rewrittenConditions) - .error(rewrittenError); + private void bindVariable(String oldName, String newName) { + Map scope = peekScope(); + String existing = scope.put(oldName, newName); + if (existing != null && !existing.equals(oldName)) { + throw new IllegalStateException("Cannot shadow variable: " + oldName + ", conflicts with " + existing); } - - return rule; } - private Rule rewriteTreeRule(TreeRule rule, List rewrittenConditions, boolean conditionsChanged) { - List rewrittenNestedRules = new ArrayList<>(); - boolean nestedChanged = false; - - for (Rule nestedRule : rule.getRules()) { - enterScope(); - Rule rewritten = rewriteRule(nestedRule); - rewrittenNestedRules.add(rewritten); - if (rewritten != nestedRule) { - nestedChanged = true; + private boolean doesNotNeedRewriting(Set references) { + Map scope = peekScope(); + for (String ref : 
references) { + if (!ref.equals(scope.getOrDefault(ref, ref))) { + return false; } - exitScope(); - } - - if (conditionsChanged || nestedChanged) { - return TreeRule.builder() - .description(rule.getDocumentation().orElse(null)) - .conditions(rewrittenConditions) - .treeRule(rewrittenNestedRules); } - - return rule; - } - - private String resolveReference(String originalName) { - // Input parameters are never rewritten - return variableAnalysis.getInputParams().contains(originalName) - ? originalName - : scopeStack.peek().getOrDefault(originalName, originalName); + return true; } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/SyntheticBindingTransform.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/SyntheticBindingTransform.java new file mode 100644 index 00000000000..7f0f38c1690 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/SyntheticBindingTransform.java @@ -0,0 +1,75 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.cfg; + +import java.util.logging.Logger; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.evaluation.type.Type; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.IsSet; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.LibraryFunction; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; + +/** + * Assigns synthetic bindings to conditions that could benefit from variable consolidation. + * + *

This transform handles two cases: + *

    + *
  1. {@code isSet(f(x))} is rewritten to {@code _synthetic_isSet_N = f(x)}, unwrapping the isSet + *
  2. Bare function calls like {@code f(x)} become {@code _synthetic_f_N = f(x)} + *
+ * + *

This enables {@link VariableConsolidationTransform} to consolidate these synthetic bindings + * with real bindings like {@code url = f(x)}, eliminating redundant function calls. If no consolidation later + * occurs, then {@link DeadStoreEliminationTransform} can remove the unnecessary synthetic bindings. + */ +final class SyntheticBindingTransform extends TreeMapper { + private static final Logger LOGGER = Logger.getLogger(SyntheticBindingTransform.class.getName()); + + private int syntheticCounter = 0; + private int transformedCount = 0; + + static EndpointRuleSet transform(EndpointRuleSet ruleSet) { + SyntheticBindingTransform t = new SyntheticBindingTransform(); + EndpointRuleSet result = t.endpointRuleSet(ruleSet); + if (t.transformedCount > 0) { + LOGGER.info(() -> String.format("Synthetic binding: %d conditions transformed", t.transformedCount)); + } + return result; + } + + @Override + public Condition condition(Rule rule, Condition cond) { + // If it already has a binding, then nothing to do + if (cond.getResult().isPresent()) { + return cond; + } + + LibraryFunction fn = cond.getFunction(); + + // isSet(f(x)) where f(x) is a function call - unwrap and bind + if (fn instanceof IsSet) { + Expression inner = fn.getArguments().get(0); + if (inner instanceof LibraryFunction) { + transformedCount++; + return cond.toBuilder().fn((LibraryFunction) inner).result(createSyntheticName(fn)).build(); + } + return cond; + } + + // Bare function call that doesn't return a boolean? 
add binding + if (fn.getFunctionDefinition().getReturnType() != Type.booleanType()) { + transformedCount++; + return cond.toBuilder().result(createSyntheticName(fn)).build(); + } + + return cond; + } + + private String createSyntheticName(LibraryFunction fn) { + return "_synthetic_" + fn.getName() + "_" + (syntheticCounter++); + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/TreeRewriter.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/TreeMapper.java similarity index 50% rename from smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/TreeRewriter.java rename to smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/TreeMapper.java index 7690a1cffd9..ed46642de8b 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/TreeRewriter.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/TreeMapper.java @@ -8,11 +8,9 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.function.BiFunction; -import java.util.function.Function; -import java.util.function.Predicate; import software.amazon.smithy.model.node.Node; import software.amazon.smithy.rulesengine.language.Endpoint; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; import software.amazon.smithy.rulesengine.language.syntax.Identifier; import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; import software.amazon.smithy.rulesengine.language.syntax.expressions.Reference; @@ -23,92 +21,276 @@ import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.RecordLiteral; import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.StringLiteral; import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.TupleLiteral; +import 
software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; -/** - * Utility for rewriting references within expression trees. - */ -final class TreeRewriter { - // A no-op rewriter that returns expressions unchanged. - static final TreeRewriter IDENTITY = new TreeRewriter(ref -> ref, expr -> false); - - private final Function referenceTransformer; - private final Predicate shouldRewrite; - +public abstract class TreeMapper { /** - * Creates a new reference rewriter. - * - * @param referenceTransformer function to transform references - * @param shouldRewrite predicate to determine if an expression needs rewriting + * A no-op mapper that returns expressions unchanged. */ - TreeRewriter( - Function referenceTransformer, - Predicate shouldRewrite - ) { - this.referenceTransformer = referenceTransformer; - this.shouldRewrite = shouldRewrite; - } + private static final TreeMapper IDENTITY = new TreeMapper() {}; /** - * Creates a simple rewriter that replaces specific references. + * Creates a mapper that replaces specific references. 
* * @param replacements map of variable names to replacement expressions - * @return a reference rewriter that performs the replacements + * @return a mapper that performs the replacements */ - static TreeRewriter forReplacements(Map replacements) { + public static TreeMapper newReferenceReplacingMapper(Map replacements) { if (replacements.isEmpty()) { return IDENTITY; } - return new TreeRewriter( - ref -> replacements.getOrDefault(ref.getName().toString(), ref), - expr -> expr.getReferences().stream().anyMatch(replacements::containsKey)); + + return new TreeMapper() { + @Override + public Expression expression(Expression expression) { + // Handle Reference -> non-Reference replacement here since reference() must return Reference + if (expression instanceof Reference) { + Expression replacement = replacements.get(((Reference) expression).getName().toString()); + if (replacement != null) { + return replacement; + } + return expression; + } + + // Only do deeper replacements if the expression references a relevant variable. 
+ for (String ref : expression.getReferences()) { + if (replacements.containsKey(ref)) { + return super.expression(expression); + } + } + + return expression; + } + + @Override + public Reference reference(Reference ref) { + // This is only called from super.expression() for nested references + Expression replacement = replacements.get(ref.getName().toString()); + if (replacement instanceof Reference) { + return (Reference) replacement; + } + // Non-reference replacements are handled in expression() above + return ref; + } + }; } - static List transformNestedRules( - TreeRule tree, - String parentPath, - BiFunction transformer - ) { - List result = new ArrayList<>(); - for (int i = 0; i < tree.getRules().size(); i++) { - Rule transformed = transformer.apply( - tree.getRules().get(i), - parentPath + "/tree/rule[" + i + "]"); - if (transformed != null) { - result.add(transformed); + public EndpointRuleSet endpointRuleSet(EndpointRuleSet ruleSet) { + List transformed = new ArrayList<>(); + for (Rule r : ruleSet.getRules()) { + Rule mapped = rule(r); + if (mapped != null) { + transformed.add(mapped); } } + return ruleSet.toBuilder().rules(transformed).build(); + } + + public Rule rule(Rule r) { + if (r instanceof TreeRule) { + return treeRule((TreeRule) r); + } else if (r instanceof EndpointRule) { + return endpointRule((EndpointRule) r); + } else if (r instanceof ErrorRule) { + return errorRule((ErrorRule) r); + } + return r; + } + + public Rule treeRule(TreeRule tr) { + return Rule.builder() + .description(tr.getDocumentation().orElse(null)) + .conditions(conditions(tr, tr.getConditions())) + .treeRule(rules(tr, tr.getRules())); + } + + public List rules(TreeRule tr, List rules) { + List updated = new ArrayList<>(tr.getRules().size()); + for (Rule rule : rules) { + Rule mapped = rule(rule); + if (mapped != null) { + updated.add(mapped); + } + } + return updated; + } + + public List conditions(Rule rule, List conditions) { + List updated = new 
ArrayList<>(conditions.size()); + for (Condition condition : conditions) { + Condition mapped = condition(rule, condition); + if (mapped != null) { + updated.add(mapped); + } + } + return updated; + } + + public Condition condition(Rule rule, Condition condition) { + return Condition.builder() + .fn(libraryFunction(condition.getFunction())) + .result(result(rule, condition, condition.getResult().orElse(null))) + .build(); + } + + public Identifier result(Rule rule, Condition condition, Identifier result) { return result; } - /** - * Rewrites references within an expression tree. - * - * @param expression the expression to rewrite - * @return the rewritten expression, or the original if no changes needed - */ - Expression rewrite(Expression expression) { - if (!shouldRewrite.test(expression)) { - return expression; + public Rule endpointRule(EndpointRule er) { + return Rule.builder() + .description(er.getDocumentation().orElse(null)) + .conditions(conditions(er, er.getConditions())) + .endpoint(endpoint(er.getEndpoint())); + } + + public Rule errorRule(ErrorRule er) { + return Rule.builder() + .description(er.getDocumentation().orElse(null)) + .conditions(conditions(er, er.getConditions())) + .error(error(er, er.getError())); + } + + public Expression error(ErrorRule er, Expression e) { + return expression(e); + } + + public LibraryFunction libraryFunction(LibraryFunction fn) { + boolean changed = false; + List rewrittenArgs = new ArrayList<>(fn.getArguments().size()); + for (int i = 0; i < fn.getArguments().size(); i++) { + Expression argument = fn.getArguments().get(i); + Expression rewritten = argument(fn, i, argument); + rewrittenArgs.add(rewritten); + if (rewritten != argument) { + changed = true; + } + } + + if (!changed) { + return fn; } - if (expression instanceof StringLiteral) { - return rewriteStringLiteral((StringLiteral) expression); - } else if (expression instanceof TupleLiteral) { - return rewriteTupleLiteral((TupleLiteral) expression); - } else if 
(expression instanceof RecordLiteral) { - return rewriteRecordLiteral((RecordLiteral) expression); + return fn.getFunctionDefinition() + .createFunction(FunctionNode.builder() + .name(Node.from(fn.getName())) + .arguments(rewrittenArgs) + .build()); + } + + public Expression argument(LibraryFunction fn, int position, Expression argument) { + return expression(argument); + } + + public Expression expression(Expression expression) { + if (expression instanceof Literal) { + return literal((Literal) expression); } else if (expression instanceof Reference) { - return referenceTransformer.apply((Reference) expression); + return reference((Reference) expression); } else if (expression instanceof LibraryFunction) { - return rewriteLibraryFunction((LibraryFunction) expression); + return libraryFunction((LibraryFunction) expression); + } else { + return expression; + } + } + + public Reference reference(Reference reference) { + return reference; + } + + public Literal literal(Literal literal) { + if (literal instanceof StringLiteral) { + return stringLiteral((StringLiteral) literal); + } else if (literal instanceof TupleLiteral) { + return tupleLiteral((TupleLiteral) literal); + } else if (literal instanceof RecordLiteral) { + return recordLiteral((RecordLiteral) literal); + } else { + return literal; + } + } + + public Literal tupleLiteral(TupleLiteral tuple) { + List rewrittenMembers = new ArrayList<>(); + boolean changed = false; + + for (Literal member : tuple.members()) { + Literal rewritten = (Literal) expression(member); + rewrittenMembers.add(rewritten); + if (rewritten != member) { + changed = true; + } + } + + return changed ? 
Literal.tupleLiteral(rewrittenMembers) : tuple; + } + + public Literal recordLiteral(RecordLiteral record) { + Map rewrittenMembers = new LinkedHashMap<>(); + boolean changed = false; + + for (Map.Entry entry : record.members().entrySet()) { + Literal original = entry.getValue(); + Literal rewritten = (Literal) expression(original); + rewrittenMembers.put(entry.getKey(), rewritten); + if (rewritten != original) { + changed = true; + } + } + + return changed ? Literal.recordLiteral(rewrittenMembers) : record; + } + + public Literal stringLiteral(StringLiteral str) { + Template template = str.value(); + if (template.isStatic()) { + return str; + } + + StringBuilder templateBuilder = new StringBuilder(); + boolean changed = false; + + for (Template.Part part : template.getParts()) { + if (part instanceof Template.Dynamic) { + Template.Dynamic dynamic = (Template.Dynamic) part; + Expression original = dynamic.toExpression(); + Expression rewritten = expression(original); + if (rewritten != original) { + changed = true; + } + templateBuilder.append('{').append(rewritten).append('}'); + } else { + templateBuilder.append(((Template.Literal) part).getValue()); + } } - return expression; + return changed ? 
Literal.stringLiteral(Template.fromString(templateBuilder.toString())) : str; } - Map> rewriteHeaders(Map> headers) { + public Endpoint endpoint(Endpoint endpoint) { + Expression rewrittenUrl = expression(endpoint.getUrl()); + Map> rewrittenHeaders = rewriteHeaders(endpoint.getHeaders()); + Map rewrittenProperties = rewriteProperties(endpoint.getProperties()); + + // Only create new endpoint if something changed + if (rewrittenUrl != endpoint.getUrl() + || rewrittenHeaders != endpoint.getHeaders() + || rewrittenProperties != endpoint.getProperties()) { + return Endpoint.builder() + .url(rewrittenUrl) + .headers(rewrittenHeaders) + .properties(rewrittenProperties) + .build(); + } + + return endpoint; + } + + private Map> rewriteHeaders(Map> headers) { if (headers.isEmpty()) { return headers; } @@ -122,7 +304,7 @@ Map> rewriteHeaders(Map> heade for (int i = 0; i < originalValues.size(); i++) { Expression original = originalValues.get(i); - Expression rewrittenExpr = rewrite(original); + Expression rewrittenExpr = expression(original); if (rewrittenExpr != original) { if (rewrittenValues == null) { @@ -155,7 +337,7 @@ Map> rewriteHeaders(Map> heade return changed ? rewritten : headers; } - Map rewriteProperties(Map properties) { + private Map rewriteProperties(Map properties) { if (properties.isEmpty()) { return properties; } @@ -164,7 +346,7 @@ Map rewriteProperties(Map properties) boolean changed = false; for (Map.Entry entry : properties.entrySet()) { - Expression rewrittenExpr = rewrite(entry.getValue()); + Expression rewrittenExpr = expression(entry.getValue()); if (rewrittenExpr != entry.getValue()) { if (!(rewrittenExpr instanceof Literal)) { @@ -191,102 +373,4 @@ Map rewriteProperties(Map properties) return changed ? 
rewritten : properties; } - - Endpoint rewriteEndpoint(Endpoint endpoint) { - Expression rewrittenUrl = rewrite(endpoint.getUrl()); - Map> rewrittenHeaders = rewriteHeaders(endpoint.getHeaders()); - Map rewrittenProperties = rewriteProperties(endpoint.getProperties()); - - // Only create new endpoint if something changed - if (rewrittenUrl != endpoint.getUrl() - || rewrittenHeaders != endpoint.getHeaders() - || rewrittenProperties != endpoint.getProperties()) { - return Endpoint.builder() - .url(rewrittenUrl) - .headers(rewrittenHeaders) - .properties(rewrittenProperties) - .build(); - } - return endpoint; - } - - private Expression rewriteStringLiteral(StringLiteral str) { - Template template = str.value(); - if (template.isStatic()) { - return str; - } - - StringBuilder templateBuilder = new StringBuilder(); - boolean changed = false; - - for (Template.Part part : template.getParts()) { - if (part instanceof Template.Dynamic) { - Template.Dynamic dynamic = (Template.Dynamic) part; - Expression original = dynamic.toExpression(); - Expression rewritten = rewrite(original); - if (rewritten != original) { - changed = true; - } - templateBuilder.append('{').append(rewritten).append('}'); - } else { - templateBuilder.append(((Template.Literal) part).getValue()); - } - } - - return changed ? Literal.stringLiteral(Template.fromString(templateBuilder.toString())) : str; - } - - private Expression rewriteTupleLiteral(TupleLiteral tuple) { - List rewrittenMembers = new ArrayList<>(); - boolean changed = false; - - for (Literal member : tuple.members()) { - Literal rewritten = (Literal) rewrite(member); - rewrittenMembers.add(rewritten); - if (rewritten != member) { - changed = true; - } - } - - return changed ? 
Literal.tupleLiteral(rewrittenMembers) : tuple; - } - - private Expression rewriteRecordLiteral(RecordLiteral record) { - Map rewrittenMembers = new LinkedHashMap<>(); - boolean changed = false; - - for (Map.Entry entry : record.members().entrySet()) { - Literal original = entry.getValue(); - Literal rewritten = (Literal) rewrite(original); - rewrittenMembers.put(entry.getKey(), rewritten); - if (rewritten != original) { - changed = true; - } - } - - return changed ? Literal.recordLiteral(rewrittenMembers) : record; - } - - private Expression rewriteLibraryFunction(LibraryFunction fn) { - List rewrittenArgs = new ArrayList<>(); - boolean changed = false; - - for (Expression arg : fn.getArguments()) { - Expression rewritten = rewrite(arg); - rewrittenArgs.add(rewritten); - if (rewritten != arg) { - changed = true; - } - } - - if (!changed) { - return fn; - } - - FunctionNode node = FunctionNode.builder() - .name(Node.from(fn.getName())) - .arguments(rewrittenArgs) - .build(); - return fn.getFunctionDefinition().createFunction(node); - } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableAnalysis.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableAnalysis.java index 7d477a72b51..842bc2aadd8 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableAnalysis.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableAnalysis.java @@ -33,35 +33,33 @@ */ final class VariableAnalysis { private final Set inputParams; - private final Map> bindings; + private final Map bindingCounts; private final Map referenceCounts; private final Map> expressionMappings; private VariableAnalysis( Set inputParams, - Map> bindings, + Map bindingCounts, Map referenceCounts, Map> expressionMappings ) { this.inputParams = inputParams; - this.bindings = bindings; + this.bindingCounts = bindingCounts; this.referenceCounts = 
referenceCounts; this.expressionMappings = expressionMappings; } static VariableAnalysis analyze(EndpointRuleSet ruleSet) { - Set inputParameters = extractInputParameters(ruleSet); - - AnalysisVisitor visitor = new AnalysisVisitor(inputParameters); + AnalysisVisitor visitor = new AnalysisVisitor(); for (Rule rule : ruleSet.getRules()) { visitor.visitRule(rule); } return new VariableAnalysis( - inputParameters, - visitor.bindings, + extractInputParameters(ruleSet), + visitor.bindingCounts, visitor.referenceCounts, - createExpressionMappings(visitor.bindings)); + createExpressionMappings(visitor.bindings, visitor.bindingCounts)); } Set getInputParams() { @@ -81,13 +79,20 @@ boolean isReferencedOnce(String variableName) { } boolean hasSingleBinding(String variableName) { - Set expressions = bindings.get(variableName); - return expressions != null && expressions.size() == 1; + Integer count = bindingCounts.get(variableName); + return count != null && count == 1; } boolean hasMultipleBindings(String variableName) { - Set expressions = bindings.get(variableName); - return expressions != null && expressions.size() > 1; + // Check if variable is bound more than once, regardless of whether expressions are the same. + // This is important because SSA may rewrite references in the expressions, making originally + // identical expressions different after SSA. For example: + // Branch A: outpostId = getAttr(parsed, ...) + // Branch B: outpostId = getAttr(parsed, ...) + // If "parsed" is renamed to "parsed_ssa_1" in branch A and "parsed_ssa_2" in branch B, + // the expressions become different, but we've already decided not to rename "outpostId". 
+ Integer count = bindingCounts.get(variableName); + return count != null && count > 1; } boolean isSafeToInline(String variableName) { @@ -103,29 +108,36 @@ private static Set extractInputParameters(EndpointRuleSet ruleSet) { } private static Map> createExpressionMappings( - Map> bindings + Map> bindings, + Map bindingCounts ) { Map> result = new HashMap<>(); for (Map.Entry> entry : bindings.entrySet()) { String varName = entry.getKey(); Set expressions = entry.getValue(); - result.put(varName, createMappingForVariable(varName, expressions)); + int bindingCount = bindingCounts.getOrDefault(varName, 0); + result.put(varName, createMappingForVariable(varName, expressions, bindingCount)); } return result; } private static Map createMappingForVariable( String varName, - Set expressions + Set expressions, + int bindingCount ) { Map mapping = new HashMap<>(); - if (expressions.size() == 1) { - // Single binding: no SSA rename needed + if (bindingCount <= 1 || expressions.size() == 1) { + // Single binding or multiple bindings with the same expression: no SSA rename needed. + // When multiple bindings have the same expression text, references inside may get + // SSA-renamed differently in each scope, but that's fine: the resulting expressions + // will differ and be treated as distinct BDD conditions. The binding name being the + // same doesn't cause collisions since conditions are identified by their full content. 
String expression = expressions.iterator().next(); mapping.put(expression, varName); } else { - // Multiple bindings: use SSA naming convention + // Multiple bindings with different expressions: use SSA naming convention List sortedExpressions = new ArrayList<>(expressions); sortedExpressions.sort(String::compareTo); for (int i = 0; i < sortedExpressions.size(); i++) { @@ -140,12 +152,8 @@ private static Map createMappingForVariable( private static class AnalysisVisitor { final Map> bindings = new HashMap<>(); + final Map bindingCounts = new HashMap<>(); final Map referenceCounts = new HashMap<>(); - private final Set inputParams; - - AnalysisVisitor(Set inputParams) { - this.inputParams = inputParams; - } void visitRule(Rule rule) { for (Condition condition : rule.getConditions()) { @@ -154,8 +162,9 @@ void visitRule(Rule rule) { LibraryFunction fn = condition.getFunction(); String expression = fn.toString(); - bindings.computeIfAbsent(varName, k -> new HashSet<>()) - .add(expression); + bindings.computeIfAbsent(varName, k -> new HashSet<>()).add(expression); + // Track number of times variable is bound (not just unique expressions) + bindingCounts.merge(varName, 1, Integer::sum); } countReferences(condition.getFunction()); diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableConsolidationTransform.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableConsolidationTransform.java index b0846e786b3..b9c27f5f5ae 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableConsolidationTransform.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableConsolidationTransform.java @@ -4,32 +4,45 @@ */ package software.amazon.smithy.rulesengine.logic.cfg; -import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.HashSet; -import java.util.List; +import 
java.util.IdentityHashMap; import java.util.Map; import java.util.Set; import java.util.logging.Logger; import software.amazon.smithy.rulesengine.language.EndpointRuleSet; import software.amazon.smithy.rulesengine.language.syntax.Identifier; import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Reference; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.LibraryFunction; import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; -import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; -import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; /** - * Consolidates variable names for identical expressions and eliminates redundant bindings. + * Consolidates variable bindings for identical expressions. * - *

This transform identifies conditions that compute the same expression but assign - * the result to different variable names, and either consolidates them to use the same - * name or eliminates redundant bindings when the same expression is already bound in - * an ancestor scope. + *

This transform performs two optimizations: elimination and consolidation. + * + *

Elimination: If an expression is already bound in a parent scope, child bindings are removed. + *

+ *     url = parseURL(Endpoint)
+ *     _synthetic = parseURL(Endpoint) // eliminated, references rewritten to 'url'
+ * 
+ * + *

Consolidation: If the same expression appears in sibling scopes, all bindings use the first name. + *

+ *     branch1: url = parseURL(Endpoint)
+ *     branch2: parsed = parseURL(Endpoint) // renamed to 'url'
+ * 
+ * + *

The transform avoids renaming when it would cause variable shadowing. + * + * @see SyntheticBindingTransform + * @see DeadStoreEliminationTransform */ -final class VariableConsolidationTransform { +final class VariableConsolidationTransform extends TreeMapper { private static final Logger LOGGER = Logger.getLogger(VariableConsolidationTransform.class.getName()); // Global map of canonical expressions to their first variable name seen @@ -38,223 +51,154 @@ final class VariableConsolidationTransform { // Maps old variable names to new canonical names for rewriting references private final Map variableRenameMap = new HashMap<>(); - // Tracks conditions to eliminate (by their path in the tree) - private final Set conditionsToEliminate = new HashSet<>(); + // Tracks conditions to eliminate (using identity for exact instance matching) + private final Set conditionsToEliminate = Collections.newSetFromMap(new IdentityHashMap<>()); - // Tracks all variables defined at each scope level to check for conflicts - private final Map> scopeDefinedVars = new HashMap<>(); + // Tracks all variables defined at each rule to check for conflicts + private final Map> ruleDefinedVars = new IdentityHashMap<>(); private int consolidatedCount = 0; private int eliminatedCount = 0; private int skippedDueToShadowing = 0; public static EndpointRuleSet transform(EndpointRuleSet ruleSet) { - VariableConsolidationTransform transform = new VariableConsolidationTransform(); - return transform.consolidate(ruleSet); - } + VariableConsolidationTransform t = new VariableConsolidationTransform(); - private EndpointRuleSet consolidate(EndpointRuleSet ruleSet) { - LOGGER.info("Starting variable consolidation transform"); - - for (int i = 0; i < ruleSet.getRules().size(); i++) { - collectDefinitions(ruleSet.getRules().get(i), "rule[" + i + "]"); - } - - for (int i = 0; i < ruleSet.getRules().size(); i++) { - discoverBindingsInRule(ruleSet.getRules().get(i), "rule[" + i + "]", new HashMap<>(), new 
HashSet<>()); + // Pass 1: Collect all variable definitions per rule + for (Rule rule : ruleSet.getRules()) { + t.collectDefinitions(rule); } - List transformedRules = new ArrayList<>(); - for (int i = 0; i < ruleSet.getRules().size(); i++) { - transformedRules.add(transformRule(ruleSet.getRules().get(i), "rule[" + i + "]")); + // Pass 2: Discover bindings to consolidate/eliminate + for (Rule rule : ruleSet.getRules()) { + t.discoverBindings(rule, new HashMap<>(), new HashSet<>()); } LOGGER.info(String.format("Variable consolidation: %d consolidated, %d eliminated, %d skipped due to shadowing", - consolidatedCount, - eliminatedCount, - skippedDueToShadowing)); + t.consolidatedCount, + t.eliminatedCount, + t.skippedDueToShadowing)); - return EndpointRuleSet.builder() - .parameters(ruleSet.getParameters()) - .rules(transformedRules) - .version(ruleSet.getVersion()) - .build(); + // Pass 3: Transform using TreeMapper + return t.endpointRuleSet(ruleSet); } - private void collectDefinitions(Rule rule, String path) { + private void collectDefinitions(Rule rule) { Set definedVars = new HashSet<>(); - - // Collect all variables defined at this scope level for (Condition condition : rule.getConditions()) { - if (condition.getResult().isPresent()) { - definedVars.add(condition.getResult().get().toString()); - } + condition.getResult().ifPresent(id -> definedVars.add(id.toString())); } - - scopeDefinedVars.put(path, definedVars); + ruleDefinedVars.put(rule, definedVars); if (rule instanceof TreeRule) { - TreeRule treeRule = (TreeRule) rule; - for (int i = 0; i < treeRule.getRules().size(); i++) { - collectDefinitions(treeRule.getRules().get(i), path + "/tree/rule[" + i + "]"); + for (Rule nested : ((TreeRule) rule).getRules()) { + collectDefinitions(nested); } } } - private void discoverBindingsInRule( + private void discoverBindings( Rule rule, - String path, Map parentBindings, Set ancestorVars ) { - // Track bindings at current scope (inherits parent bindings) Map 
currentBindings = new HashMap<>(parentBindings); - // Track all variables visible from ancestors (for shadowing check) Set visibleAncestorVars = new HashSet<>(ancestorVars); - for (int i = 0; i < rule.getConditions().size(); i++) { - Condition condition = rule.getConditions().get(i); - String condPath = path + "/cond[" + i + "]"; - - if (condition.getResult().isPresent()) { - String varName = condition.getResult().get().toString(); - LibraryFunction fn = condition.getFunction(); - String canonical = fn.canonicalize().toString(); - - // Check if this expression is already bound in parent scope - String parentVar = parentBindings.get(canonical); - if (parentVar != null) { - // Found duplicate in parent, eliminate this binding - variableRenameMap.put(varName, parentVar); - conditionsToEliminate.add(condPath); - eliminatedCount++; - LOGGER.info(String.format("Eliminating redundant binding at %s: '%s' -> '%s' for: %s", - condPath, - varName, - parentVar, - canonical)); - } else { - // Not bound in parent, add to current scope - currentBindings.put(canonical, varName); - visibleAncestorVars.add(varName); + for (Condition condition : rule.getConditions()) { + if (!condition.getResult().isPresent()) { + continue; + } - // Check for global consolidation opportunity - String globalVar = globalExpressionToVar.get(canonical); - if (globalVar != null && !globalVar.equals(varName)) { - // Same expression elsewhere with different name - // Check if consolidation would cause shadowing - if (!wouldCauseShadowing(globalVar, path, ancestorVars)) { - variableRenameMap.put(varName, globalVar); - consolidatedCount++; - LOGGER.info(String.format("Consolidating '%s' -> '%s' for: %s", - varName, - globalVar, - canonical)); - } else { - skippedDueToShadowing++; - LOGGER.fine(String.format("Cannot consolidate '%s' -> '%s' (would shadow) for: %s", - varName, - globalVar, - canonical)); - } - } else if (globalVar == null) { - // First time seeing this expression globally - 
globalExpressionToVar.put(canonical, varName); + String varName = condition.getResult().get().toString(); + String canonical = condition.getFunction().canonicalize().toString(); + + // Check if already bound in parent scope + String parentVar = parentBindings.get(canonical); + if (parentVar != null) { + variableRenameMap.put(varName, parentVar); + conditionsToEliminate.add(condition); + eliminatedCount++; + LOGGER.fine(() -> String.format("Eliminating redundant binding: '%s' -> '%s' for: %s", + varName, + parentVar, + canonical)); + } else { + currentBindings.put(canonical, varName); + visibleAncestorVars.add(varName); + + // Check for global consolidation opportunity + String globalVar = globalExpressionToVar.get(canonical); + if (globalVar != null && !globalVar.equals(varName)) { + boolean wouldShadow = wouldCauseShadowing(globalVar, rule, ancestorVars); + if (!wouldShadow) { + // No shadowing - safe to rename the binding + variableRenameMap.put(varName, globalVar); + consolidatedCount++; + LOGGER.fine(() -> String.format("Consolidating '%s' -> '%s' for: %s", + varName, + globalVar, + canonical)); + } else { + skippedDueToShadowing++; + LOGGER.info(() -> String.format("Shadowing skip: '%s' -> '%s' for expr: %s", + varName, + globalVar, + canonical)); } + } else if (globalVar == null) { + globalExpressionToVar.put(canonical, varName); } } } if (rule instanceof TreeRule) { - TreeRule treeRule = (TreeRule) rule; - for (int i = 0; i < treeRule.getRules().size(); i++) { - discoverBindingsInRule( - treeRule.getRules().get(i), - path + "/tree/rule[" + i + "]", - currentBindings, - visibleAncestorVars); + for (Rule nested : ((TreeRule) rule).getRules()) { + discoverBindings(nested, currentBindings, visibleAncestorVars); } } } - private boolean wouldCauseShadowing(String varName, String currentPath, Set ancestorVars) { - // Check if using this variable name would shadow an ancestor variable + private boolean wouldCauseShadowing(String varName, Rule currentRule, Set 
ancestorVars) { if (ancestorVars.contains(varName)) { return true; } - // Check if any child scope already defines this variable - // (which would be shadowed if we use it here) - for (Map.Entry> entry : scopeDefinedVars.entrySet()) { - String scopePath = entry.getKey(); - Set scopeVars = entry.getValue(); - // Check if this scope is a descendant of current path - if (scopePath.startsWith(currentPath + "/") && scopeVars.contains(varName)) { - return true; - } - } - - return false; + // Check if any descendant rule defines this variable + return wouldShadowInDescendants(varName, currentRule); } - private Rule transformRule(Rule rule, String path) { - List transformedConditions = new ArrayList<>(); - - for (int i = 0; i < rule.getConditions().size(); i++) { - String condPath = path + "/cond[" + i + "]"; - - if (conditionsToEliminate.contains(condPath)) { - // Skip this condition entirely since it's redundant - continue; + private boolean wouldShadowInDescendants(String varName, Rule rule) { + if (rule instanceof TreeRule) { + for (Rule nested : ((TreeRule) rule).getRules()) { + Set nestedVars = ruleDefinedVars.get(nested); + if (nestedVars != null && nestedVars.contains(varName)) { + return true; + } + if (wouldShadowInDescendants(varName, nested)) { + return true; + } } - - Condition condition = rule.getConditions().get(i); - transformedConditions.add(transformCondition(condition)); } + return false; + } - if (rule instanceof TreeRule) { - TreeRule treeRule = (TreeRule) rule; - return TreeRule.builder() - .description(rule.getDocumentation().orElse(null)) - .conditions(transformedConditions) - .treeRule(TreeRewriter.transformNestedRules(treeRule, path, this::transformRule)); - - } else if (rule instanceof EndpointRule) { - EndpointRule endpointRule = (EndpointRule) rule; - TreeRewriter rewriter = createRewriter(); - - return EndpointRule.builder() - .description(rule.getDocumentation().orElse(null)) - .conditions(transformedConditions) - 
.endpoint(rewriter.rewriteEndpoint(endpointRule.getEndpoint())); - - } else if (rule instanceof ErrorRule) { - ErrorRule errorRule = (ErrorRule) rule; - TreeRewriter rewriter = createRewriter(); - - return ErrorRule.builder() - .description(rule.getDocumentation().orElse(null)) - .conditions(transformedConditions) - .error(rewriter.rewrite(errorRule.getError())); + @Override + public Condition condition(Rule rule, Condition condition) { + // Eliminate redundant conditions + if (conditionsToEliminate.contains(condition)) { + return null; } - return rule.withConditions(transformedConditions); - } - - private Condition transformCondition(Condition condition) { - // Rewrite any references in the function - TreeRewriter rewriter = createRewriter(); LibraryFunction fn = condition.getFunction(); - LibraryFunction rewrittenFn = (LibraryFunction) rewriter.rewrite(fn); + LibraryFunction rewrittenFn = libraryFunction(fn); - // If this condition assigns to a variable that should be renamed, - // use the canonical name instead + // Check if binding needs renaming if (condition.getResult().isPresent()) { String varName = condition.getResult().get().toString(); String canonicalName = variableRenameMap.get(varName); if (canonicalName != null) { - // This variable is being consolidated, use the canonical name return Condition.builder() .fn(rewrittenFn) .result(Identifier.of(canonicalName)) @@ -262,7 +206,7 @@ private Condition transformCondition(Condition condition) { } } - // No consolidation needed, but may still need reference rewriting + // Only rebuild if function changed if (rewrittenFn != fn) { return condition.toBuilder().fn(rewrittenFn).build(); } @@ -270,16 +214,23 @@ private Condition transformCondition(Condition condition) { return condition; } - private TreeRewriter createRewriter() { - if (variableRenameMap.isEmpty()) { - return TreeRewriter.IDENTITY; + @Override + public Expression expression(Expression expression) { + // Short-circuit if no references need rewriting 
+ for (String ref : expression.getReferences()) { + if (variableRenameMap.containsKey(ref)) { + return super.expression(expression); + } } + return expression; + } - Map replacements = new HashMap<>(); - for (Map.Entry entry : variableRenameMap.entrySet()) { - replacements.put(entry.getKey(), Expression.getReference(Identifier.of(entry.getValue()))); + @Override + public Reference reference(Reference ref) { + String canonicalName = variableRenameMap.get(ref.getName().toString()); + if (canonicalName != null) { + return Expression.getReference(Identifier.of(canonicalName)); } - - return TreeRewriter.forReplacements(replacements); + return ref; } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/EndpointBddTrait.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/EndpointBddTrait.java index 9cd9d627c36..997532932f7 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/EndpointBddTrait.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/EndpointBddTrait.java @@ -23,6 +23,8 @@ import software.amazon.smithy.model.traits.AbstractTraitBuilder; import software.amazon.smithy.model.traits.Trait; import software.amazon.smithy.rulesengine.language.RulesVersion; +import software.amazon.smithy.rulesengine.language.evaluation.Scope; +import software.amazon.smithy.rulesengine.language.evaluation.type.Type; import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; import software.amazon.smithy.rulesengine.language.syntax.rule.NoMatchRule; @@ -98,6 +100,60 @@ public static EndpointBddTrait from(Cfg cfg) { .build(); } + /** + * Removes conditions that are not referenced by any BDD node and remaps indices. + * + *

This should be called after all BDD optimizations (sifting, cost optimization) are complete. + * A condition could become unreferenced if it is optimized away by the BDD compiler because hi==lo. + * + * @return a new trait with unreferenced conditions removed, or this trait if none were removed + */ + public EndpointBddTrait removeUnreferencedConditions() { + // Find which conditions are actually referenced + int[] refCount = new int[conditions.size()]; + for (int i = 0; i < bdd.getNodeCount(); i++) { + int var = bdd.getVariable(i); + if (var >= 0 && var < refCount.length) { + refCount[var]++; + } + } + + // Build mapping from old to new indices + int[] oldToNew = new int[conditions.size()]; + List newConditions = new ArrayList<>(); + for (int i = 0; i < conditions.size(); i++) { + if (refCount[i] > 0) { + oldToNew[i] = newConditions.size(); + newConditions.add(conditions.get(i)); + } else { + oldToNew[i] = -1; + } + } + + // No change needed + if (newConditions.size() == conditions.size()) { + return this; + } + + // Remap BDD nodes + int[] newNodes = new int[bdd.getNodeCount() * 3]; + for (int i = 0; i < bdd.getNodeCount(); i++) { + int oldVar = bdd.getVariable(i); + newNodes[i * 3] = oldVar < 0 ? oldVar : oldToNew[oldVar]; + newNodes[i * 3 + 1] = bdd.getHigh(i); + newNodes[i * 3 + 2] = bdd.getLow(i); + } + + return builder() + .version(version) + .parameters(parameters) + .conditions(newConditions) + .results(results) + .bdd(new Bdd(bdd + .getRootRef(), newConditions.size(), bdd.getResultCount(), bdd.getNodeCount(), newNodes)) + .build(); + } + /** * Gets the parameters for the endpoint rules. 
* @@ -206,6 +262,16 @@ public static EndpointBddTrait fromNode(Node node) { results.add(NoMatchRule.INSTANCE); // Always add no-match at index 0 results.addAll(serializedResults); + // Validate that results have no conditions (all conditions are hoisted into the BDD) + for (int i = 1; i < results.size(); i++) { + Rule rule = results.get(i); + if (!rule.getConditions().isEmpty()) { + throw new IllegalArgumentException( + "BDD result at index " + i + " has conditions, but BDD results must not have conditions. " + + "All conditions should be hoisted into the BDD decision structure."); + } + } + String nodesBase64 = obj.expectStringMember("nodes").getValue(); int nodeCount = obj.expectNumberMember("nodeCount").getValue().intValue(); int rootRef = obj.expectNumberMember("root").getValue().intValue(); @@ -350,7 +416,21 @@ public Builder bdd(Bdd bdd) { @Override public EndpointBddTrait build() { - return new EndpointBddTrait(this); + EndpointBddTrait trait = new EndpointBddTrait(this); + + // Type-check conditions and results so expression.type() works. Note that using a shared scope across + // each check is ok, because BDD evaluation always runs conditions in a fixed order and could in theory + // try every condition for a single path to a result. 
+ Scope scope = new Scope<>(); + trait.getParameters().writeToScope(scope); + for (Condition condition : trait.getConditions()) { + condition.typeCheck(scope); + } + for (Rule result : trait.getResults()) { + result.typeCheck(scope); + } + + return trait; } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetAuthSchemesValidator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetAuthSchemesValidator.java index c79ab5fbeb8..62e06b43742 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetAuthSchemesValidator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetAuthSchemesValidator.java @@ -124,7 +124,7 @@ private String validateAuthSchemeName( FromSourceLocation sourceLocation ) { Literal nameLiteral = authScheme.get(NAME); - if (nameLiteral == null) { + if (nameLiteral == null || nameLiteral.asStringLiteral().isEmpty()) { events.add(error(service, sourceLocation, String.format( @@ -133,13 +133,14 @@ private String validateAuthSchemeName( return null; } - String name = nameLiteral.asStringLiteral().map(s -> s.expectLiteral()).orElse(null); + // Try to get the name as a literal string. If the template contains variables + // (e.g., from branchless transforms like "{s3e_auth}"), we can't statically validate. 
+ String name = nameLiteral.asStringLiteral() + .filter(t -> t.isStatic()) + .map(t -> t.expectLiteral()) + .orElse(null); if (name == null) { - events.add(error(service, - sourceLocation, - String.format( - "Expected `authSchemes` to have a `name` key with a string value but it did not: `%s`", - authScheme))); + // String literal with template variables - skip static validation return null; } diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/CoalesceTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/CoalesceTest.java index bf6ac4bb9da..7dd1d118883 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/CoalesceTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/CoalesceTest.java @@ -10,6 +10,7 @@ import java.util.Arrays; import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.error.RuleError; import software.amazon.smithy.rulesengine.language.evaluation.Scope; import software.amazon.smithy.rulesengine.language.evaluation.type.Type; import software.amazon.smithy.rulesengine.language.syntax.Identifier; @@ -135,7 +136,7 @@ void testCoalesceWithIncompatibleTypes() { Scope scope = new Scope<>(); - IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> coalesce.typeCheck(scope)); + RuleError ex = assertThrows(RuleError.class, () -> coalesce.typeCheck(scope)); assertTrue(ex.getMessage().contains("Type mismatch in coalesce")); assertTrue(ex.getMessage().contains("argument 2")); } @@ -151,7 +152,7 @@ void testCoalesceWithIncompatibleTypesInMiddle() { Scope scope = new Scope<>(); - IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> coalesce.typeCheck(scope)); + RuleError ex = assertThrows(RuleError.class, () -> coalesce.typeCheck(scope)); 
assertTrue(ex.getMessage().contains("Type mismatch in coalesce")); assertTrue(ex.getMessage().contains("argument 3")); } @@ -160,8 +161,7 @@ void testCoalesceWithIncompatibleTypesInMiddle() { void testCoalesceWithLessThanTwoArguments() { Expression single = Literal.of("only"); - IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, - () -> Coalesce.ofExpressions(single).typeCheck(new Scope<>())); + RuleError ex = assertThrows(RuleError.class, () -> Coalesce.ofExpressions(single).typeCheck(new Scope<>())); assertTrue(ex.getMessage().contains("at least 2 arguments")); } @@ -215,4 +215,24 @@ void testCoalesceWithBooleanTypes() { assertEquals(Type.booleanType(), resultType); } + + @Test + void testTypeMethodReturnsInferredTypeAfterTypeCheck() { + // Verify that type() returns the correct inferred type after typeCheck() + Expression optional1 = Expression.getReference(Identifier.of("maybeValue1")); + Expression optional2 = Expression.getReference(Identifier.of("maybeValue2")); + Expression definite = Literal.of("default"); + Coalesce coalesce = Coalesce.ofExpressions(optional1, optional2, definite); + + Scope scope = new Scope<>(); + scope.insert("maybeValue1", Type.optionalType(Type.stringType())); + scope.insert("maybeValue2", Type.optionalType(Type.stringType())); + + // Call typeCheck to cache the type + coalesce.typeCheck(scope); + + // Now type() should return the inferred type (non-optional since last arg is definite) + Type cachedType = coalesce.type(); + assertEquals(Type.stringType(), cachedType); + } } diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/IteTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/IteTest.java new file mode 100644 index 00000000000..5c57ed7dc6a --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/IteTest.java @@ -0,0 +1,278 @@ +/* + * 
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.language.syntax.functions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.error.RuleError; +import software.amazon.smithy.rulesengine.language.evaluation.Scope; +import software.amazon.smithy.rulesengine.language.evaluation.type.Type; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Ite; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; + +public class IteTest { + + @Test + void testIteBothBranchesNonOptionalString() { + Expression condition = Expression.getReference(Identifier.of("flag")); + Expression trueValue = Literal.of("-fips"); + Expression falseValue = Literal.of(""); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + scope.insert("flag", Type.booleanType()); + + Type resultType = ite.typeCheck(scope); + + // Both non-optional String => non-optional String + assertEquals(Type.stringType(), resultType); + } + + @Test + void testIteBothBranchesNonOptionalInteger() { + Expression condition = Expression.getReference(Identifier.of("useNewValue")); + Expression trueValue = Literal.of(100); + Expression falseValue = Literal.of(0); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + scope.insert("useNewValue", Type.booleanType()); + + Type resultType = ite.typeCheck(scope); + + assertEquals(Type.integerType(), resultType); + } + + @Test + void testIteTrueBranchOptional() 
{ + Expression condition = Expression.getReference(Identifier.of("flag")); + Expression trueValue = Expression.getReference(Identifier.of("maybeValue")); + Expression falseValue = Literal.of("default"); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + scope.insert("flag", Type.booleanType()); + scope.insert("maybeValue", Type.optionalType(Type.stringType())); + + Type resultType = ite.typeCheck(scope); + + // True branch optional => result is optional + assertEquals(Type.optionalType(Type.stringType()), resultType); + } + + @Test + void testIteFalseBranchOptional() { + Expression condition = Expression.getReference(Identifier.of("flag")); + Expression trueValue = Literal.of("value"); + Expression falseValue = Expression.getReference(Identifier.of("maybeDefault")); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + scope.insert("flag", Type.booleanType()); + scope.insert("maybeDefault", Type.optionalType(Type.stringType())); + + Type resultType = ite.typeCheck(scope); + + // False branch optional => result is optional + assertEquals(Type.optionalType(Type.stringType()), resultType); + } + + @Test + void testIteBothBranchesOptional() { + Expression condition = Expression.getReference(Identifier.of("flag")); + Expression trueValue = Expression.getReference(Identifier.of("maybe1")); + Expression falseValue = Expression.getReference(Identifier.of("maybe2")); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + scope.insert("flag", Type.booleanType()); + scope.insert("maybe1", Type.optionalType(Type.stringType())); + scope.insert("maybe2", Type.optionalType(Type.stringType())); + + Type resultType = ite.typeCheck(scope); + + // Both optional => result is optional + assertEquals(Type.optionalType(Type.stringType()), resultType); + } + + @Test + void testIteWithOfStringsHelper() { + Expression condition = 
Expression.getReference(Identifier.of("UseFIPS")); + Ite ite = Ite.ofStrings(condition, "-fips", ""); + + Scope scope = new Scope<>(); + scope.insert("UseFIPS", Type.booleanType()); + + Type resultType = ite.typeCheck(scope); + + // Both literal strings => non-optional String + assertEquals(Type.stringType(), resultType); + } + + @Test + void testIteTypeMismatchBetweenBranches() { + Expression condition = Expression.getReference(Identifier.of("flag")); + Expression trueValue = Literal.of("string"); + Expression falseValue = Literal.of(42); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + scope.insert("flag", Type.booleanType()); + + RuleError ex = assertThrows(RuleError.class, () -> ite.typeCheck(scope)); + assertTrue(ex.getMessage().contains("same base type")); + assertTrue(ex.getMessage().contains("true branch")); + assertTrue(ex.getMessage().contains("false branch")); + } + + @Test + void testIteConditionMustBeBoolean() { + Expression condition = Literal.of("not a boolean"); + Expression trueValue = Literal.of("yes"); + Expression falseValue = Literal.of("no"); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + + RuleError ex = assertThrows(RuleError.class, () -> ite.typeCheck(scope)); + assertTrue(ex.getMessage().contains("non-optional Boolean")); + } + + @Test + void testIteConditionCannotBeOptionalBoolean() { + Expression condition = Expression.getReference(Identifier.of("maybeFlag")); + Expression trueValue = Literal.of("yes"); + Expression falseValue = Literal.of("no"); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + scope.insert("maybeFlag", Type.optionalType(Type.booleanType())); + + RuleError ex = assertThrows(RuleError.class, () -> ite.typeCheck(scope)); + assertTrue(ex.getMessage().contains("non-optional Boolean")); + assertTrue(ex.getMessage().contains("coalesce")); + } + + @Test + void 
testIteWithArrayTypes() { + Expression condition = Expression.getReference(Identifier.of("useFirst")); + Expression trueValue = Expression.getReference(Identifier.of("array1")); + Expression falseValue = Expression.getReference(Identifier.of("array2")); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + scope.insert("useFirst", Type.booleanType()); + scope.insert("array1", Type.arrayType(Type.stringType())); + scope.insert("array2", Type.arrayType(Type.stringType())); + + Type resultType = ite.typeCheck(scope); + + assertEquals(Type.arrayType(Type.stringType()), resultType); + } + + @Test + void testIteWithOptionalArrayType() { + Expression condition = Expression.getReference(Identifier.of("flag")); + Expression trueValue = Expression.getReference(Identifier.of("maybeArray")); + Expression falseValue = Expression.getReference(Identifier.of("definiteArray")); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + scope.insert("flag", Type.booleanType()); + scope.insert("maybeArray", Type.optionalType(Type.arrayType(Type.integerType()))); + scope.insert("definiteArray", Type.arrayType(Type.integerType())); + + Type resultType = ite.typeCheck(scope); + + // One optional array => result is optional array + assertEquals(Type.optionalType(Type.arrayType(Type.integerType())), resultType); + } + + @Test + void testIteWithBooleanValues() { + Expression condition = Expression.getReference(Identifier.of("invertFlag")); + Expression trueValue = Literal.of(false); + Expression falseValue = Literal.of(true); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + scope.insert("invertFlag", Type.booleanType()); + + Type resultType = ite.typeCheck(scope); + + assertEquals(Type.booleanType(), resultType); + } + + @Test + void testIteTypeMismatchWithOptionalUnwrapping() { + // Even with optional wrapping, base types must match + Expression 
condition = Expression.getReference(Identifier.of("flag")); + Expression trueValue = Expression.getReference(Identifier.of("maybeString")); + Expression falseValue = Expression.getReference(Identifier.of("maybeInt")); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + scope.insert("flag", Type.booleanType()); + scope.insert("maybeString", Type.optionalType(Type.stringType())); + scope.insert("maybeInt", Type.optionalType(Type.integerType())); + + RuleError ex = assertThrows(RuleError.class, () -> ite.typeCheck(scope)); + assertTrue(ex.getMessage().contains("same base type")); + } + + @Test + void testIteReturnsCorrectId() { + assertEquals("ite", Ite.ID); + assertEquals("ite", Ite.getDefinition().getId()); + } + + @Test + void testTypeMethodReturnsInferredTypeAfterTypeCheck() { + // Verify that type() returns the correct inferred type after typeCheck() + Expression condition = Expression.getReference(Identifier.of("flag")); + Expression trueValue = Expression.getReference(Identifier.of("maybeValue")); + Expression falseValue = Literal.of("default"); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + scope.insert("flag", Type.booleanType()); + scope.insert("maybeValue", Type.optionalType(Type.stringType())); + + // Call typeCheck to cache the type + ite.typeCheck(scope); + + // Now type() should return the inferred type + Type cachedType = ite.type(); + assertEquals(Type.optionalType(Type.stringType()), cachedType); + } + + @Test + void testNestedIteTypeInference() { + // Test that nested Ite expressions have correct type inference + Expression outerCondition = Expression.getReference(Identifier.of("outer")); + Expression innerCondition = Expression.getReference(Identifier.of("inner")); + + // Inner ITE: ite(inner, "a", "b") => String + Ite innerIte = Ite.ofExpressions(innerCondition, Literal.of("a"), Literal.of("b")); + + // Outer ITE: ite(outer, innerIte, "c") => 
String + Ite outerIte = Ite.ofExpressions(outerCondition, innerIte, Literal.of("c")); + + Scope scope = new Scope<>(); + scope.insert("outer", Type.booleanType()); + scope.insert("inner", Type.booleanType()); + + outerIte.typeCheck(scope); + + // Both inner and outer should have String type + assertEquals(Type.stringType(), innerIte.type()); + assertEquals(Type.stringType(), outerIte.type()); + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilderTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilderTest.java index 8fb61f6b89a..b477efa1432 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilderTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilderTest.java @@ -202,15 +202,21 @@ void createConditionReferenceHandlesBooleanEqualsCanonicalizations() { } @Test - void createConditionReferenceDoesNotCanonicalizeWithoutDefault() { - // Test that booleanEquals(region, false) is not canonicalized (no default) + void createConditionReferenceCanonicalizesEvenWithoutDefault() { + // booleanEquals(X, false) -> booleanEquals(X, true) with negation is a valid + // algebraic transformation regardless of whether the parameter has a default Expression ref = Expression.getReference(Identifier.of("region")); Condition cond = Condition.builder().fn(BooleanEquals.ofExpressions(ref, false)).build(); ConditionReference condRef = builder.createConditionReference(cond); - assertFalse(condRef.isNegated()); - assertEquals(cond.getFunction(), condRef.getCondition().getFunction()); + // Should be canonicalized to booleanEquals(region, true) with negation + assertTrue(condRef.isNegated()); + assertInstanceOf(BooleanEquals.class, condRef.getCondition().getFunction()); + + BooleanEquals fn = (BooleanEquals) condRef.getCondition().getFunction(); + assertEquals(ref, fn.getArguments().get(0)); + 
assertEquals(Literal.booleanLiteral(true), fn.getArguments().get(1)); } @Test diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/ReferenceRewriterTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/ReferenceRewriterTest.java index 29b9c87f03e..9d3909871b3 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/ReferenceRewriterTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/ReferenceRewriterTest.java @@ -34,11 +34,11 @@ void testSimpleReferenceReplacement() { Map replacements = new HashMap<>(); replacements.put("x", Expression.getReference(Identifier.of("y"))); - TreeRewriter rewriter = TreeRewriter.forReplacements(replacements); + TreeMapper mapper = TreeMapper.newReferenceReplacingMapper(replacements); // Test rewriting a simple reference Reference original = Expression.getReference(Identifier.of("x")); - Expression rewritten = rewriter.rewrite(original); + Expression rewritten = mapper.expression(original); assertEquals("y", ((Reference) rewritten).getName().toString()); } @@ -49,11 +49,11 @@ void testNoRewriteNeeded() { Map replacements = new HashMap<>(); replacements.put("x", Expression.getReference(Identifier.of("y"))); - TreeRewriter rewriter = TreeRewriter.forReplacements(replacements); + TreeMapper mapper = TreeMapper.newReferenceReplacingMapper(replacements); // Reference to "z" should not be rewritten Reference original = Expression.getReference(Identifier.of("z")); - Expression rewritten = rewriter.rewrite(original); + Expression rewritten = mapper.expression(original); assertEquals(original, rewritten); } @@ -67,8 +67,8 @@ void testRewriteInStringLiteral() { Map replacements = new HashMap<>(); replacements.put("x", Expression.getReference(Identifier.of("newVar"))); - TreeRewriter rewriter = TreeRewriter.forReplacements(replacements); - Expression rewritten = rewriter.rewrite(original); + 
TreeMapper mapper = TreeMapper.newReferenceReplacingMapper(replacements); + Expression rewritten = mapper.expression(original); assertInstanceOf(StringLiteral.class, rewritten); StringLiteral rewrittenStr = (StringLiteral) rewritten; @@ -84,8 +84,8 @@ void testRewriteInTupleLiteral() { Map replacements = new HashMap<>(); replacements.put("x", Expression.getReference(Identifier.of("replaced"))); - TreeRewriter rewriter = TreeRewriter.forReplacements(replacements); - Expression rewritten = rewriter.rewrite(original); + TreeMapper mapper = TreeMapper.newReferenceReplacingMapper(replacements); + Expression rewritten = mapper.expression(original); assertInstanceOf(TupleLiteral.class, rewritten); TupleLiteral rewrittenTuple = (TupleLiteral) rewritten; @@ -105,8 +105,8 @@ void testRewriteInRecordLiteral() { Map replacements = new HashMap<>(); replacements.put("x", Expression.getReference(Identifier.of("newX"))); - TreeRewriter rewriter = TreeRewriter.forReplacements(replacements); - Expression rewritten = rewriter.rewrite(original); + TreeMapper mapper = TreeMapper.newReferenceReplacingMapper(replacements); + Expression rewritten = mapper.expression(original); assertInstanceOf(RecordLiteral.class, rewritten); RecordLiteral rewrittenRecord = (RecordLiteral) rewritten; @@ -124,8 +124,8 @@ void testRewriteInLibraryFunction() { Map replacements = new HashMap<>(); replacements.put("x", Expression.getReference(Identifier.of("replacedVar"))); - TreeRewriter rewriter = TreeRewriter.forReplacements(replacements); - Expression rewritten = rewriter.rewrite(original); + TreeMapper mapper = TreeMapper.newReferenceReplacingMapper(replacements); + Expression rewritten = mapper.expression(original); assertTrue(rewritten.toString().contains("replacedVar")); assertNotEquals(original, rewritten); @@ -142,8 +142,8 @@ void testMultipleReplacements() { replacements.put("a", Expression.getReference(Identifier.of("x"))); replacements.put("b", Expression.getReference(Identifier.of("y"))); - 
TreeRewriter rewriter = TreeRewriter.forReplacements(replacements); - Expression rewritten = rewriter.rewrite(original); + TreeMapper mapper = TreeMapper.newReferenceReplacingMapper(replacements); + Expression rewritten = mapper.expression(original); assertTrue(rewritten.toString().contains("x")); assertTrue(rewritten.toString().contains("y")); @@ -158,8 +158,8 @@ void testNestedRewriting() { Map replacements = new HashMap<>(); replacements.put("x", Expression.getReference(Identifier.of("newVar"))); - TreeRewriter rewriter = TreeRewriter.forReplacements(replacements); - Expression rewritten = rewriter.rewrite(original); + TreeMapper mapper = TreeMapper.newReferenceReplacingMapper(replacements); + Expression rewritten = mapper.expression(original); assertTrue(rewritten.toString().contains("newVar")); assertNotEquals(original, rewritten); @@ -173,8 +173,8 @@ void testStaticStringNotRewritten() { Map replacements = new HashMap<>(); replacements.put("x", Expression.getReference(Identifier.of("y"))); - TreeRewriter rewriter = TreeRewriter.forReplacements(replacements); - Expression rewritten = rewriter.rewrite(original); + TreeMapper mapper = TreeMapper.newReferenceReplacingMapper(replacements); + Expression rewritten = mapper.expression(original); assertEquals(original, rewritten); } diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransformTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransformTest.java index a1594d0d1ec..4e0510e55a8 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransformTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransformTest.java @@ -16,6 +16,7 @@ import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; import software.amazon.smithy.rulesengine.language.syntax.expressions.Template; import 
software.amazon.smithy.rulesengine.language.syntax.expressions.functions.StringEquals; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.UriEncode; import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; import software.amazon.smithy.rulesengine.language.syntax.parameters.ParameterType; @@ -31,6 +32,7 @@ public class SsaTransformTest { @Test void testNoDisambiguationNeeded() { // When variables are not shadowed, they should remain unchanged + // Note: Dead store elimination will remove unused bindings Parameter bucketParam = Parameter.builder() .name("Bucket") .type(ParameterType.STRING) @@ -53,12 +55,15 @@ void testNoDisambiguationNeeded() { EndpointRuleSet result = SsaTransform.transform(original); - assertEquals(original, result); + // Binding is removed since it's not used (dead store elimination) + EndpointRule resultRule = (EndpointRule) result.getRules().get(0); + assertEquals(false, resultRule.getConditions().get(0).getResult().isPresent()); } @Test void testSimpleShadowing() { // Test when the same variable name is bound to different expressions + // Note: Dead store elimination removes unused bindings Parameter param = Parameter.builder() .name("Input") .type(ParameterType.STRING) @@ -80,16 +85,18 @@ void testSimpleShadowing() { List resultRules = result.getRules(); assertEquals(2, resultRules.size()); + // Bindings are removed since they're not used (dead store elimination) EndpointRule resultRule1 = (EndpointRule) resultRules.get(0); - assertEquals("temp_ssa_1", resultRule1.getConditions().get(0).getResult().get().toString()); + assertEquals(false, resultRule1.getConditions().get(0).getResult().isPresent()); EndpointRule resultRule2 = (EndpointRule) resultRules.get(1); - assertEquals("temp_ssa_2", resultRule2.getConditions().get(0).getResult().get().toString()); + assertEquals(false, 
resultRule2.getConditions().get(0).getResult().isPresent()); } @Test void testMultipleShadowsOfSameVariable() { // Test when a variable is shadowed multiple times + // Note: Dead store elimination removes unused bindings Parameter param = Parameter.builder() .name("Input") .type(ParameterType.STRING) @@ -109,9 +116,10 @@ void testMultipleShadowsOfSameVariable() { EndpointRuleSet result = SsaTransform.transform(original); List resultRules = result.getRules(); - assertEquals("temp_ssa_1", resultRules.get(0).getConditions().get(0).getResult().get().toString()); - assertEquals("temp_ssa_2", resultRules.get(1).getConditions().get(0).getResult().get().toString()); - assertEquals("temp_ssa_3", resultRules.get(2).getConditions().get(0).getResult().get().toString()); + // Bindings are removed since they're not used (dead store elimination) + assertEquals(false, resultRules.get(0).getConditions().get(0).getResult().isPresent()); + assertEquals(false, resultRules.get(1).getConditions().get(0).getResult().isPresent()); + assertEquals(false, resultRules.get(2).getConditions().get(0).getResult().isPresent()); } @Test @@ -146,6 +154,7 @@ void testErrorRuleHandling() { @Test void testTreeRuleHandling() { // Test tree rules with unique variable names at each level + // Note: Dead store elimination removes unused bindings Parameter param = Parameter.builder() .name("Region") .type(ParameterType.STRING) @@ -182,6 +191,7 @@ void testTreeRuleHandling() { @Test void testParameterShadowingAttempt() { // Test that attempting to shadow a parameter gets disambiguated + // Note: Dead store elimination removes unused bindings Parameter bucketParam = Parameter.builder() .name("Bucket") .type(ParameterType.STRING) @@ -205,9 +215,9 @@ void testParameterShadowingAttempt() { EndpointRuleSet result = SsaTransform.transform(original); - // Should handle without issues + // Binding is removed since it's not used (dead store elimination) EndpointRule resultRule = (EndpointRule) 
result.getRules().get(0); - assertEquals("Bucket_shadow", resultRule.getConditions().get(0).getResult().get().toString()); + assertEquals(false, resultRule.getConditions().get(0).getResult().isPresent()); } private static EndpointRule createRuleWithBinding(String param, String value, String resultVar, String url) { @@ -228,4 +238,90 @@ private static Expression expr(String value) { private static Endpoint endpoint(String value) { return Endpoint.builder().url(expr(value)).build(); } + + private static Endpoint endpoint(Expression url) { + return Endpoint.builder().url(url).build(); + } + + @Test + void testMultipleBindingsWithUsedVariablesAreSsaRenamed() { + // When the same variable is bound in multiple sibling branches AND is used, + // each binding should get a unique SSA name to avoid shadowing conflicts + Parameter param = Parameter.builder() + .name("Input") + .type(ParameterType.STRING) + .build(); + + // Create two sibling rules that: + // 1. Both bind 'myVar' (same variable name, different expressions) + // 2. 
Use 'myVar' in their endpoints (so it won't be eliminated) + // Using UriEncode which returns a string (not boolean like StringEquals) + Condition cond1 = Condition.builder() + .fn(UriEncode.ofExpressions(Expression.of("Input"))) + .result("myVar") + .build(); + EndpointRule rule1 = (EndpointRule) EndpointRule.builder() + .conditions(Collections.singletonList(cond1)) + .endpoint(Endpoint.builder() + .url(Literal.stringLiteral(Template.fromString("https://{myVar}.example.com"))) + .build()); + + // Second rule uses a different expression (uriEncode of a literal) + Condition cond2 = Condition.builder() + .fn(UriEncode.ofExpressions(Expression.of("other"))) + .result("myVar") + .build(); + EndpointRule rule2 = (EndpointRule) EndpointRule.builder() + .conditions(Collections.singletonList(cond2)) + .endpoint(Endpoint.builder() + .url(Literal.stringLiteral(Template.fromString("https://{myVar}.other.com"))) + .build()); + + EndpointRuleSet original = EndpointRuleSet.builder() + .parameters(Parameters.builder().addParameter(param).build()) + .rules(Arrays.asList(rule1, rule2)) + .version("1.0") + .build(); + + EndpointRuleSet result = SsaTransform.transform(original); + + // Verify both rules still exist + assertEquals(2, result.getRules().size()); + + // Get the result variable names from the transformed conditions + EndpointRule resultRule1 = (EndpointRule) result.getRules().get(0); + EndpointRule resultRule2 = (EndpointRule) result.getRules().get(1); + + String resultVar1 = resultRule1.getConditions() + .get(0) + .getResult() + .map(Object::toString) + .orElse(null); + String resultVar2 = resultRule2.getConditions() + .get(0) + .getResult() + .map(Object::toString) + .orElse(null); + + // Both should have bindings (since they're used) + assertEquals(true, + resultRule1.getConditions().get(0).getResult().isPresent(), + "First binding should be present since myVar is used"); + assertEquals(true, + resultRule2.getConditions().get(0).getResult().isPresent(), + "Second 
binding should be present since myVar is used"); + + // They should be SSA-renamed to unique names + assertEquals(true, + resultVar1.contains("_ssa_"), + "First binding should have SSA suffix, got: " + resultVar1); + assertEquals(true, + resultVar2.contains("_ssa_"), + "Second binding should have SSA suffix, got: " + resultVar2); + + // They should NOT be the same (unique SSA names) + assertEquals(false, + resultVar1.equals(resultVar2), + "SSA names should be unique: " + resultVar1 + " vs " + resultVar2); + } } diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/SyntheticBindingTransformTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/SyntheticBindingTransformTest.java new file mode 100644 index 00000000000..7e46b3b422e --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/SyntheticBindingTransformTest.java @@ -0,0 +1,148 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.cfg; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.IsSet; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; +import software.amazon.smithy.rulesengine.language.syntax.parameters.ParameterType; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.logic.TestHelpers; + +class SyntheticBindingTransformTest { + + private static final Parameters PARAMS = Parameters.builder() + .addParameter(Parameter.builder().name("Input").type(ParameterType.STRING).build()) + .build(); + + @Test + void unwrapsIsSetFunctionCall() { + // isSet(substring(Input, 0, 5, false)) should become _synthetic_isSet_N = substring(...) 
+ Condition isSetSubstring = Condition.builder() + .fn(IsSet.ofExpressions(TestHelpers.substring("Input", 0, 5, false))) + .build(); + + Rule rule = EndpointRule.builder() + .conditions(TestHelpers.isSet("Input"), isSetSubstring) + .endpoint(TestHelpers.endpoint("https://example.com")); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(PARAMS) + .addRule(rule) + .build(); + + EndpointRuleSet result = SyntheticBindingTransform.transform(ruleSet); + + Condition transformed = result.getRules().get(0).getConditions().get(1); + assertEquals("substring", transformed.getFunction().getName()); + assertTrue(transformed.getResult().isPresent()); + // Name includes the outer function (isSet) that was unwrapped + assertTrue(transformed.getResult().get().toString().startsWith("_synthetic_isSet_")); + } + + @Test + void doesNotUnwrapIsSetReference() { + // isSet(Input) should remain unchanged - it's checking a reference, not a function + Condition isSetRef = Condition.builder() + .fn(TestHelpers.isSet("Input")) + .build(); + + Rule rule = EndpointRule.builder() + .conditions(isSetRef) + .endpoint(TestHelpers.endpoint("https://example.com")); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(PARAMS) + .addRule(rule) + .build(); + + EndpointRuleSet result = SyntheticBindingTransform.transform(ruleSet); + + Condition transformed = result.getRules().get(0).getConditions().get(0); + assertEquals("isSet", transformed.getFunction().getName()); + assertFalse(transformed.getResult().isPresent()); + } + + @Test + void addsBindingToBareFunctionCall() { + // substring(Input, 0, 5, false) should become _synthetic_substring_N = substring(...) 
+ Condition bareSubstring = Condition.builder() + .fn(TestHelpers.substring("Input", 0, 5, false)) + .build(); + + Rule rule = EndpointRule.builder() + .conditions(TestHelpers.isSet("Input"), bareSubstring) + .endpoint(TestHelpers.endpoint("https://example.com")); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(PARAMS) + .addRule(rule) + .build(); + + EndpointRuleSet result = SyntheticBindingTransform.transform(ruleSet); + + Condition transformed = result.getRules().get(0).getConditions().get(1); + assertEquals("substring", transformed.getFunction().getName()); + assertTrue(transformed.getResult().isPresent()); + // Name includes the function name + assertTrue(transformed.getResult().get().toString().startsWith("_synthetic_substring_")); + } + + @Test + void doesNotModifyExistingBinding() { + // prefix = substring(Input, 0, 5, false) should remain unchanged + Condition binding = Condition.builder() + .fn(TestHelpers.substring("Input", 0, 5, false)) + .result(Identifier.of("prefix")) + .build(); + + Rule rule = EndpointRule.builder() + .conditions(TestHelpers.isSet("Input"), binding) + .endpoint(TestHelpers.endpoint("https://example.com")); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(PARAMS) + .addRule(rule) + .build(); + + EndpointRuleSet result = SyntheticBindingTransform.transform(ruleSet); + + Condition transformed = result.getRules().get(0).getConditions().get(1); + assertEquals("substring", transformed.getFunction().getName()); + assertEquals("prefix", transformed.getResult().get().toString()); + } + + @Test + void doesNotAddBindingToSimpleChecks() { + // Simple checks (isSet, booleanEquals, stringEquals, not, isValidHostLabel) should not get bindings; only isSet is exercised here + Condition isSet = Condition.builder() + .fn(TestHelpers.isSet("Input")) + .build(); + + Rule rule = EndpointRule.builder() + .conditions(isSet) + .endpoint(TestHelpers.endpoint("https://example.com")); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(PARAMS) 
+ .addRule(rule) + .build(); + + EndpointRuleSet result = SyntheticBindingTransform.transform(ruleSet); + + Condition transformed = result.getRules().get(0).getConditions().get(0); + assertEquals("isSet", transformed.getFunction().getName()); + assertFalse(transformed.getResult().isPresent()); + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/traits/BddTraitTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/traits/BddTraitTest.java index 88c2aab778e..a4defa53c43 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/traits/BddTraitTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/traits/BddTraitTest.java @@ -12,6 +12,13 @@ import java.util.List; import org.junit.jupiter.api.Test; import software.amazon.smithy.model.node.Node; +import software.amazon.smithy.rulesengine.language.evaluation.type.Type; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Coalesce; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; +import software.amazon.smithy.rulesengine.language.syntax.parameters.ParameterType; import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; @@ -25,7 +32,11 @@ public class BddTraitTest { @Test void testBddTraitSerialization() { // Create a BddTrait with full context - Parameters params = Parameters.builder().build(); + Parameter regionParam = Parameter.builder() + .name("Region") + .type(ParameterType.STRING) + .build(); + Parameters params = 
Parameters.builder().addParameter(regionParam).build(); Condition cond = Condition.builder().fn(TestHelpers.isSet("Region")).build(); Rule endpoint = EndpointRule.builder().endpoint(TestHelpers.endpoint("https://example.com")); @@ -99,4 +110,39 @@ void testEmptyBddTrait() { assertEquals(1, trait.getResults().size()); assertEquals(-1, trait.getBdd().getRootRef()); // FALSE terminal } + + @Test + void testBuildTypeChecksExpressionsForCodegen() { + // Verify that after building an EndpointBddTrait, expression.type() works + // This is important for codegen to infer types without a scope + Parameter regionParam = Parameter.builder() + .name("Region") + .type(ParameterType.STRING) + .build(); + Parameters params = Parameters.builder().addParameter(regionParam).build(); + + // Create a condition with a coalesce that infers to String + Expression regionRef = Expression.getReference(Identifier.of("Region")); + Expression fallback = Literal.of("us-east-1"); + Coalesce coalesce = Coalesce.ofExpressions(regionRef, fallback); + Condition cond = Condition.builder().fn(coalesce).result(Identifier.of("resolvedRegion")).build(); + + Rule endpoint = EndpointRule.builder().endpoint(TestHelpers.endpoint("https://example.com")); + + List results = new ArrayList<>(); + results.add(NoMatchRule.INSTANCE); + results.add(endpoint); + + EndpointBddTrait trait = EndpointBddTrait.builder() + .parameters(params) + .conditions(ListUtils.of(cond)) + .results(results) + .bdd(createSimpleBdd()) + .build(); + + // After build(), type() should work on the coalesce expression + // Region is Optional, fallback is String, so result is String (non-optional) + Coalesce builtCoalesce = (Coalesce) trait.getConditions().get(0).getFunction(); + assertEquals(Type.stringType(), builtCoalesce.type()); + } } diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/ite-basic.errors 
b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/ite-basic.errors new file mode 100644 index 00000000000..8b137891791 --- /dev/null +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/ite-basic.errors @@ -0,0 +1 @@ + diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/ite-basic.smithy b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/ite-basic.smithy new file mode 100644 index 00000000000..75a4d4d050c --- /dev/null +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/ite-basic.smithy @@ -0,0 +1,80 @@ +$version: "2.0" + +namespace example + +use smithy.rules#clientContextParams +use smithy.rules#endpointRuleSet +use smithy.rules#endpointTests + +@clientContextParams( + useFips: {type: "boolean", documentation: "Use FIPS endpoints"} +) +@endpointRuleSet({ + version: "1.1", + parameters: { + useFips: { + type: "boolean", + documentation: "Use FIPS endpoints", + default: false, + required: true + } + }, + rules: [ + { + "documentation": "Use ite to select endpoint suffix" + "conditions": [ + { + "fn": "ite" + "argv": [{"ref": "useFips"}, "-fips", ""] + "assign": "suffix" + } + ] + "endpoint": { + "url": "https://example{suffix}.com" + } + "type": "endpoint" + } + ] +}) +@endpointTests({ + "version": "1.0", + "testCases": [ + { + "documentation": "When useFips is true, returns trueValue" + "params": { + "useFips": true + } + "operationInputs": [{ + "operationName": "GetThing" + }], + "expect": { + "endpoint": { + "url": "https://example-fips.com" + } + } + } + { + "documentation": "When useFips is false, returns falseValue" + "params": { + "useFips": false + } + "operationInputs": [{ + "operationName": "GetThing" + }], + "expect": { + "endpoint": { + "url": "https://example.com" + } + } + } + ] +}) 
+@suppress(["UnstableTrait.smithy"]) +service FizzBuzz { + version: "2022-01-01", + operations: [GetThing] +} + +operation GetThing { + input := {} +}