Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions .changelog/1762379492.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
---
applies_to:
- client
- aws-sdk-rust
authors:
- aajtodd
references: []
breaking: false
new_feature: false
bug_fix: false
---
Validate `Region` is a valid host label when constructing endpoints.
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,16 @@ package software.amazon.smithy.rustsdk
import software.amazon.smithy.aws.traits.auth.SigV4Trait
import software.amazon.smithy.model.knowledge.ServiceIndex
import software.amazon.smithy.model.node.Node
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.rulesengine.aws.language.functions.AwsBuiltIns
import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.configReexport
import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointCustomization
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointsLib
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.generators.EndpointParamsGenerator
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.memberName
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ServiceConfig
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
Expand Down Expand Up @@ -143,6 +147,35 @@ class RegionDecorator : ClientCodegenDecorator {
)
}
}

override fun endpointParamsBuilderValidator(
codegenContext: ClientCodegenContext,
parameter: Parameter,
): Writable? {
if (endpointTestsValidatesRegionAlready(codegenContext)) {
return null
}

return when (parameter.builtIn) {
AwsBuiltIns.REGION.builtIn ->
writable {
rustTemplate(
"""
if let Some(region) = &self.${parameter.memberName()} {
if !#{is_valid_host_label}(region.as_ref() as &str, true, &mut #{DiagnosticCollector}::new()) {
return Err(#{ParamsError}::invalid_value(${parameter.memberName().dq()}, "must be a valid host label"))
}
}
""",
"is_valid_host_label" to EndpointsLib.isValidHostLabel,
"ParamsError" to EndpointParamsGenerator.paramsError(),
"DiagnosticCollector" to EndpointsLib.DiagnosticCollector,
)
}

else -> null
}
}
},
)
}
Expand Down Expand Up @@ -238,3 +271,14 @@ fun usesRegion(codegenContext: ClientCodegenContext) =
codegenContext.getBuiltIn(AwsBuiltIns.REGION) != null ||
ServiceIndex.of(codegenContext.model)
.getEffectiveAuthSchemes(codegenContext.serviceShape).containsKey(SigV4Trait.ID)

/**
* Test if region is already validated via endpoint rules tests. Validating region when building parameters
* will break endpoint tests which validate during resolution and expect specific errors.
*/
fun endpointTestsValidatesRegionAlready(codegenContext: ClientCodegenContext): Boolean =
codegenContext.serviceShape.id in
setOf(
ShapeId.from("com.amazonaws.s3#AmazonS3"),
ShapeId.from("com.amazonaws.s3control#AWSS3ControlServiceV20180820"),
)
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@ package software.amazon.smithy.rustsdk
import org.junit.jupiter.api.Assertions.assertFalse
import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.Test
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.core.testutil.integrationTest
import software.amazon.smithy.rust.codegen.core.testutil.tokioTest
import kotlin.io.path.readText

class RegionDecoratorTest {
Expand Down Expand Up @@ -105,4 +109,48 @@ class RegionDecoratorTest {
val configContents = path.resolve("src/config.rs").readText()
assertTrue(configContents.contains("fn set_region("))
}

// V1988105516
@Test
fun `models with region built-in params should validate host label`() {
awsSdkIntegrationTest(modelWithRegionParam) { ctx, rustCrate ->
val rc = ctx.runtimeConfig
val codegenScope =
arrayOf(
*RuntimeType.preludeScope,
"capture_request" to RuntimeType.captureRequest(rc),
"Region" to AwsRuntimeType.awsTypes(rc).resolve("region::Region"),
)

rustCrate.integrationTest("endpoint_params_validation") {
tokioTest("region_must_be_valid_host_label") {
val moduleName = ctx.moduleUseName()
rustTemplate(
"""
let (http_client, _rx) = #{capture_request}(#{None});
let client_config = $moduleName::Config::builder()
.http_client(http_client)
.region(#{Region}::new("@controlled-proxy.com##"))
.build();

let client = $moduleName::Client::from_conf(client_config);

let err = client
.some_operation()
.send()
.await
.expect_err("error");

let err_str = format!("{}", $moduleName::error::DisplayErrorContext(&err));
dbg!(&err_str);
let expected = "invalid value for field: `region` - must be a valid host label";
assert!(err_str.contains(expected));

""",
*codegenScope,
)
}
}
}
}
}
8 changes: 4 additions & 4 deletions aws/rust-runtime/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,43 @@ interface EndpointCustomization {
codegenContext: ClientCodegenContext,
params: String,
): Writable? = null

/**
* Allows injecting validation logic for endpoint parameters into the `ParamsBuilder::build` method.
*
* e.g. when generating the builder for the endpoint parameters this allows you to insert validation logic before
* being finalizing the parameters.
*
* ```rs
* impl ParamsBuilder {
* pub fn build(self) -> ::std::result::Result<crate::config::endpoint::Params, crate::config::endpoint::InvalidParams> {
* <validation logic>
* ...
* }
* }
*
* Example:
* ```kotlin
*
* override fun endpointParamsBuilderValidator(codegenContext: ClientCodegenContext, parameter: Parameter): Writable? {
* rustTemplate("""
* if let Some(region) = self.${parameter.memberName()} {
* if #{is_valid_host_label}(region.as_ref() as &str, false, #{DiagnosticCollector}::new()) {
* return Err(#{ParamsError}::invalid_value(${parameter.memberName().dq()}, "must be a valid host label"))
* }
* }
* """,
* "is_valid_host_label" to EndpointsLib.isValidHostLabel,
* "ParamsError" to EndpointParamsGenerator.paramsError(),
* "DiagnosticCollector" to EndpointsLib.DiagnosticCollector,
* )
* }
* ```
*/
fun endpointParamsBuilderValidator(
codegenContext: ClientCodegenContext,
parameter: Parameter,
): Writable? = null
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ val EndpointStdLib = RustModule.private("endpoint_lib")
* ```
*/

internal class EndpointParamsGenerator(
class EndpointParamsGenerator(
private val codegenContext: ClientCodegenContext,
private val parameters: Parameters,
) {
Expand All @@ -122,6 +122,52 @@ internal class EndpointParamsGenerator(
fun setterName(parameterName: String) = "set_${memberName(parameterName)}"

fun getterName(parameterName: String) = "get_${memberName(parameterName)}"

fun paramsError(): RuntimeType =
RuntimeType.forInlineFun("InvalidParams", ClientRustModule.Config.endpoint) {
rust(
"""
/// An error that occurred during endpoint resolution
##[derive(Debug)]
pub struct InvalidParams {
field: std::borrow::Cow<'static, str>,
kind: InvalidParamsErrorKind,
}

/// The kind of invalid parameter error
##[derive(Debug)]
enum InvalidParamsErrorKind {
MissingField,
InvalidValue {
message: &'static str,
}
}

impl InvalidParams {
##[allow(dead_code)]
fn missing(field: &'static str) -> Self {
Self { field: field.into(), kind: InvalidParamsErrorKind::MissingField }
}

##[allow(dead_code)]
fn invalid_value(field: &'static str, message: &'static str) -> Self {
Self { field: field.into(), kind: InvalidParamsErrorKind::InvalidValue { message }}
}
}

impl std::fmt::Display for InvalidParams {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self.kind {
InvalidParamsErrorKind::MissingField => write!(f, "a required field was missing: `{}`", self.field),
InvalidParamsErrorKind::InvalidValue { message} => write!(f, "invalid value for field: `{}` - {}", self.field, message),
}
}
}

impl std::error::Error for InvalidParams { }
""",
)
}
}

fun paramsStruct(): RuntimeType =
Expand All @@ -134,34 +180,6 @@ internal class EndpointParamsGenerator(
generateEndpointParamsBuilder(this)
}

private fun paramsError(): RuntimeType =
RuntimeType.forInlineFun("InvalidParams", ClientRustModule.Config.endpoint) {
rust(
"""
/// An error that occurred during endpoint resolution
##[derive(Debug)]
pub struct InvalidParams {
field: std::borrow::Cow<'static, str>
}

impl InvalidParams {
##[allow(dead_code)]
fn missing(field: &'static str) -> Self {
Self { field: field.into() }
}
}

impl std::fmt::Display for InvalidParams {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "a required field was missing: `{}`", self.field)
}
}

impl std::error::Error for InvalidParams { }
""",
)
}

/**
* Generates an endpoints struct based on the provided endpoint rules. The struct fields are `pub(crate)`
* with optionality as indicated by the required status of the parameter.
Expand Down Expand Up @@ -251,6 +269,17 @@ internal class EndpointParamsGenerator(
"Params" to paramsStruct(),
"ParamsError" to paramsError(),
) {
// additional validation for endpoint parameters during construction
parameters.toList().forEach { parameter ->
val validators =
codegenContext.rootDecorator.endpointCustomizations(codegenContext)
.mapNotNull { it.endpointParamsBuilderValidator(codegenContext, parameter) }

validators.forEach { validator ->
rust("#W;", validator)
}
}

val params =
writable {
Attribute.AllowClippyUnnecessaryLazyEvaluations.render(this)
Expand Down
2 changes: 1 addition & 1 deletion gradle.properties
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ allowLocalDeps=false
# Avoid registering dependencies/plugins/tasks that are only used for testing purposes
isTestingEnabled=true
# codegen publication version
codegenVersion=0.1.4
codegenVersion=0.1.5
Loading