|
| 1 | +/* |
| 2 | + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 3 | + * SPDX-License-Identifier: Apache-2.0 |
| 4 | + */ |
| 5 | + |
| 6 | +package software.amazon.smithy.rustsdk |
| 7 | + |
| 8 | +import org.junit.jupiter.api.Test |
| 9 | +import software.amazon.smithy.rust.codegen.core.rustlang.Feature |
| 10 | +import software.amazon.smithy.rust.codegen.core.rustlang.rust |
| 11 | +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate |
| 12 | +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType |
| 13 | +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope |
| 14 | +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel |
| 15 | +import software.amazon.smithy.rust.codegen.core.testutil.integrationTest |
| 16 | +import software.amazon.smithy.rust.codegen.core.testutil.tokioTest |
| 17 | + |
| 18 | +class EndpointOverrideMetricDecoratorTest { |
| 19 | + companion object { |
| 20 | + private const val PREFIX = "\$version: \"2\"" |
| 21 | + val model = |
| 22 | + """ |
| 23 | + $PREFIX |
| 24 | + namespace test |
| 25 | +
|
| 26 | + use aws.api#service |
| 27 | + use aws.auth#sigv4 |
| 28 | + use aws.protocols#restJson1 |
| 29 | + use smithy.rules#endpointRuleSet |
| 30 | +
|
| 31 | + @service(sdkId: "dontcare") |
| 32 | + @restJson1 |
| 33 | + @sigv4(name: "dontcare") |
| 34 | + @auth([sigv4]) |
| 35 | + @endpointRuleSet({ |
| 36 | + "version": "1.0", |
| 37 | + "rules": [ |
| 38 | + { |
| 39 | + "type": "endpoint", |
| 40 | + "conditions": [ |
| 41 | + { "fn": "isSet", "argv": [{ "ref": "Endpoint" }] } |
| 42 | + ], |
| 43 | + "endpoint": { "url": { "ref": "Endpoint" } } |
| 44 | + }, |
| 45 | + { |
| 46 | + "type": "endpoint", |
| 47 | + "conditions": [], |
| 48 | + "endpoint": { "url": "https://example.com" } |
| 49 | + } |
| 50 | + ], |
| 51 | + "parameters": { |
| 52 | + "Region": { "required": false, "type": "String", "builtIn": "AWS::Region" }, |
| 53 | + "Endpoint": { "required": false, "type": "String", "builtIn": "SDK::Endpoint" } |
| 54 | + } |
| 55 | + }) |
| 56 | + service TestService { |
| 57 | + version: "2023-01-01", |
| 58 | + operations: [SomeOperation] |
| 59 | + } |
| 60 | +
|
| 61 | + @http(uri: "/SomeOperation", method: "GET") |
| 62 | + @optionalAuth |
| 63 | + operation SomeOperation { |
| 64 | + input: SomeInput, |
| 65 | + output: SomeOutput |
| 66 | + } |
| 67 | +
|
| 68 | + @input |
| 69 | + structure SomeInput {} |
| 70 | +
|
| 71 | + @output |
| 72 | + structure SomeOutput {} |
| 73 | + """.asSmithyModel() |
| 74 | + } |
| 75 | + |
| 76 | + @Test |
| 77 | + fun `decorator is registered in AwsCodegenDecorator list`() { |
| 78 | + val decoratorNames = DECORATORS.map { it.name } |
| 79 | + assert(decoratorNames.contains("EndpointOverrideMetric")) { |
| 80 | + "EndpointOverrideMetricDecorator should be registered in DECORATORS list. Found: $decoratorNames" |
| 81 | + } |
| 82 | + } |
| 83 | + |
| 84 | + @Test |
| 85 | + fun `endpoint override metric appears when set via SdkConfig`() { |
| 86 | + val testParams = awsIntegrationTestParams() |
| 87 | + |
| 88 | + awsSdkIntegrationTest( |
| 89 | + model, |
| 90 | + testParams, |
| 91 | + environment = mapOf("RUSTUP_TOOLCHAIN" to "1.88.0"), |
| 92 | + ) { context, rustCrate -> |
| 93 | + val rc = context.runtimeConfig |
| 94 | + val moduleName = context.moduleUseName() |
| 95 | + |
| 96 | + // Enable test-util feature for aws-runtime |
| 97 | + rustCrate.mergeFeature(Feature("test-util", true, listOf("aws-runtime/test-util"))) |
| 98 | + |
| 99 | + rustCrate.integrationTest("endpoint_override_via_sdk_config") { |
| 100 | + tokioTest("metric_tracked_when_endpoint_set_via_sdk_config") { |
| 101 | + rustTemplate( |
| 102 | + """ |
| 103 | + use $moduleName::config::Region; |
| 104 | + use $moduleName::Client; |
| 105 | + use #{capture_request}; |
| 106 | + use #{assert_ua_contains_metric_values}; |
| 107 | + |
| 108 | + let (http_client, rcvr) = capture_request(None); |
| 109 | + |
| 110 | + // Create SdkConfig with endpoint URL |
| 111 | + let sdk_config = #{SdkConfig}::builder() |
| 112 | + .region(Region::new("us-east-1")) |
| 113 | + .endpoint_url("https://sdk-custom.example.com") |
| 114 | + .http_client(http_client.clone()) |
| 115 | + .build(); |
| 116 | + |
| 117 | + // Create client from SdkConfig |
| 118 | + let client = Client::new(&sdk_config); |
| 119 | +
|
| 120 | + // Make a request |
| 121 | + let _ = client.some_operation().send().await; |
| 122 | +
|
| 123 | + // Verify the request |
| 124 | + let request = rcvr.expect_request(); |
| 125 | +
|
| 126 | + // Verify endpoint was overridden |
| 127 | + let uri = request.uri().to_string(); |
| 128 | + assert!( |
| 129 | + uri.starts_with("https://sdk-custom.example.com"), |
| 130 | + "Expected SDK custom endpoint, got: {}", |
| 131 | + uri |
| 132 | + ); |
| 133 | +
|
| 134 | + // Verify metric 'N' is present in x-amz-user-agent header |
| 135 | + let user_agent = request |
| 136 | + .headers() |
| 137 | + .get("x-amz-user-agent") |
| 138 | + .expect("x-amz-user-agent header missing"); |
| 139 | + |
| 140 | + assert_ua_contains_metric_values(user_agent, &["N"]); |
| 141 | + """, |
| 142 | + *preludeScope, |
| 143 | + "capture_request" to RuntimeType.captureRequest(rc), |
| 144 | + "assert_ua_contains_metric_values" to AwsRuntimeType.awsRuntime(rc).resolve("user_agent::test_util::assert_ua_contains_metric_values"), |
| 145 | + "SdkConfig" to AwsRuntimeType.awsTypes(rc).resolve("sdk_config::SdkConfig"), |
| 146 | + ) |
| 147 | + } |
| 148 | + } |
| 149 | + } |
| 150 | + } |
| 151 | + |
| 152 | + @Test |
| 153 | + fun `no endpoint override metric when endpoint not set`() { |
| 154 | + val testParams = awsIntegrationTestParams() |
| 155 | + |
| 156 | + awsSdkIntegrationTest( |
| 157 | + model, |
| 158 | + testParams, |
| 159 | + environment = mapOf("RUSTUP_TOOLCHAIN" to "1.88.0"), |
| 160 | + ) { context, rustCrate -> |
| 161 | + val rc = context.runtimeConfig |
| 162 | + val moduleName = context.moduleUseName() |
| 163 | + |
| 164 | + // Enable test-util feature for aws-runtime |
| 165 | + rustCrate.mergeFeature(Feature("test-util", true, listOf("aws-runtime/test-util"))) |
| 166 | + |
| 167 | + rustCrate.integrationTest("no_endpoint_override") { |
| 168 | + tokioTest("no_metric_when_endpoint_not_overridden") { |
| 169 | + rustTemplate( |
| 170 | + """ |
| 171 | + use $moduleName::config::{Credentials, Region, SharedCredentialsProvider}; |
| 172 | + use $moduleName::{Config, Client}; |
| 173 | + use #{capture_request}; |
| 174 | + |
| 175 | + let (http_client, rcvr) = capture_request(None); |
| 176 | + |
| 177 | + // Create config WITHOUT endpoint override |
| 178 | + let config = Config::builder() |
| 179 | + .credentials_provider(SharedCredentialsProvider::new(Credentials::for_tests())) |
| 180 | + .region(Region::new("us-east-1")) |
| 181 | + .http_client(http_client.clone()) |
| 182 | + .build(); |
| 183 | + let client = Client::from_conf(config); |
| 184 | +
|
| 185 | + // Make a request |
| 186 | + let _ = client.some_operation().send().await; |
| 187 | +
|
| 188 | + // Verify the request |
| 189 | + let request = rcvr.expect_request(); |
| 190 | +
|
| 191 | + // Verify default endpoint was used |
| 192 | + let uri = request.uri().to_string(); |
| 193 | + assert!( |
| 194 | + uri.starts_with("https://example.com"), |
| 195 | + "Expected default endpoint, got: {}", |
| 196 | + uri |
| 197 | + ); |
| 198 | +
|
| 199 | + // Verify metric 'N' is NOT present |
| 200 | + let user_agent = request |
| 201 | + .headers() |
| 202 | + .get("x-amz-user-agent") |
| 203 | + .expect("x-amz-user-agent header should be present"); |
| 204 | +
|
| 205 | + assert!( |
| 206 | + !user_agent.contains("m/N"), |
| 207 | + "Metric 'N' should NOT be present when endpoint not overridden" |
| 208 | + ); |
| 209 | + """, |
| 210 | + *preludeScope, |
| 211 | + "capture_request" to RuntimeType.captureRequest(rc), |
| 212 | + ) |
| 213 | + } |
| 214 | + |
| 215 | + // Add a should_panic test to verify assert_ua_contains_metric_values panics when metric is not present |
| 216 | + rust("##[should_panic(expected = \"metric values\")]") |
| 217 | + tokioTest("assert_panics_when_metric_not_present") { |
| 218 | + rustTemplate( |
| 219 | + """ |
| 220 | + use $moduleName::config::{Credentials, Region, SharedCredentialsProvider}; |
| 221 | + use $moduleName::{Config, Client}; |
| 222 | + use #{capture_request}; |
| 223 | + use #{assert_ua_contains_metric_values}; |
| 224 | +
|
| 225 | + let (http_client, rcvr) = capture_request(None); |
| 226 | +
|
| 227 | + let config = Config::builder() |
| 228 | + .credentials_provider(SharedCredentialsProvider::new(Credentials::for_tests())) |
| 229 | + .region(Region::new("us-east-1")) |
| 230 | + .http_client(http_client.clone()) |
| 231 | + .build(); |
| 232 | + let client = Client::from_conf(config); |
| 233 | +
|
| 234 | + let _ = client.some_operation().send().await; |
| 235 | + let request = rcvr.expect_request(); |
| 236 | + let user_agent = request.headers().get("x-amz-user-agent").unwrap(); |
| 237 | +
|
| 238 | + // This should panic because 'N' is not present |
| 239 | + assert_ua_contains_metric_values(user_agent, &["N"]); |
| 240 | + """, |
| 241 | + *preludeScope, |
| 242 | + "capture_request" to RuntimeType.captureRequest(rc), |
| 243 | + "assert_ua_contains_metric_values" to AwsRuntimeType.awsRuntime(rc).resolve("user_agent::test_util::assert_ua_contains_metric_values"), |
| 244 | + ) |
| 245 | + } |
| 246 | + } |
| 247 | + } |
| 248 | + } |
| 249 | +} |
0 commit comments