Skip to content

Commit 2240e24

Browse files
authored
Fix bug where servers rejected event stream content type (#4322)
## Motivation and Context There is a bug in the generated server SDK where for operations with an output event stream, it rejects `application/vnd.amazon.eventstream`, instead expecting `application/cbor`. This is incorrect. However, we need to maintain backwards compatibility with existing clients. This adds a new fallback option that protocols can set to handle this for RpcV2Cbor. ## Description - Small refactoring to extract a mime-type generator - Allow accepting multiple headers as required. ## Testing I briefly went down the path of trying to get this to work as a protocol test but ultimately it became too much of a yak shave. I added a unit test that should be sufficient. ## Checklist <!--- If a checkbox below is not applicable, then please DELETE it rather than leaving it unchecked --> - [ ] For changes to the smithy-rs codegen or runtime crates, I have created a changelog entry Markdown file in the `.changelog` directory, specifying "client," "server," or both in the `applies_to` key. - [ ] For changes to the AWS SDK, generated SDK code, or SDK runtime crates, I have created a changelog entry Markdown file in the `.changelog` directory, specifying "aws-sdk-rust" in the `applies_to` key. ---- _By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice._
1 parent 71c69f7 commit 2240e24

File tree

7 files changed

+221
-40
lines changed

7 files changed

+221
-40
lines changed

.changelog/1759353705.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
---
2+
applies_to:
3+
- server
4+
authors:
5+
- rcoh
6+
references: [ ]
7+
breaking: false
8+
new_feature: false
9+
bug_fix: true
10+
---
11+
12+
Fix bug where servers rejected `application/vnd.amazon.evenstream` ACCEPT header for RPCv2Cbor
13+
14+
This change allows this header while also allowing `application/cbor` for backwards compatibility.

codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBindingResolver.kt

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,13 @@ interface HttpBindingResolver {
102102
*/
103103
fun responseContentType(operationShape: OperationShape): String?
104104

105+
/**
106+
* Due to an initial implementation bug, were accepting `application/cbor` when we should have been accepting
107+
* an event stream header. This allows configuring an optional fallback header to support backwards compatibility
108+
* in these cases.
109+
*/
110+
fun legacyBackwardsCompatContentType(operationShape: OperationShape): String? = null
111+
105112
/**
106113
* Determines the value of the event stream `:content-type` header based on union member
107114
*/
@@ -203,7 +210,11 @@ open class HttpTraitHttpBindingResolver(
203210
).orNull()
204211

205212
override fun eventStreamMessageContentType(memberShape: MemberShape): String? =
206-
ProtocolContentTypes.eventStreamMemberContentType(model, memberShape, contentTypes.eventStreamMessageContentType)
213+
ProtocolContentTypes.eventStreamMemberContentType(
214+
model,
215+
memberShape,
216+
contentTypes.eventStreamMessageContentType,
217+
)
207218

208219
// Sort the members after extracting them from the map to have a consistent order
209220
private fun mappedBindings(bindings: Map<String, HttpBinding>): List<HttpBindingDescriptor> =

codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,18 @@ class RpcV2CborHttpBindingResolver(
9292
*/
9393
override fun responseContentType(operationShape: OperationShape): String? =
9494
if (OperationNormalizer.hadUserModeledOperationOutput(operationShape, model)) {
95+
if (operationShape.isOutputEventStream(model)) {
96+
contentTypes.eventStreamContentType
97+
} else {
98+
contentTypes.responseDocument
99+
}
100+
} else {
101+
null
102+
}
103+
104+
override fun legacyBackwardsCompatContentType(operationShape: OperationShape): String? =
105+
if (operationShape.isOutputEventStream(model)) {
106+
// Return "application/cbor" for backwards compatibility
95107
contentTypes.responseDocument
96108
} else {
97109
null

codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/util/StringsTest.kt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@ internal class StringsTest {
3232
"NotificationARNs".toSnakeCase() shouldBe "notification_arns"
3333
}
3434

35+
@Test
36+
fun handleDashes() {
37+
"application/x-amzn-json-1.1".toSnakeCase() shouldBe "application_x_amzn_json_1_1"
38+
}
39+
3540
@Test
3641
fun testAllNames() {
3742
// Set this to true to write a new test expectation file

codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt

Lines changed: 47 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import software.amazon.smithy.model.traits.HttpPayloadTrait
2323
import software.amazon.smithy.model.traits.HttpTrait
2424
import software.amazon.smithy.model.traits.MediaTypeTrait
2525
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
26+
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
2627
import software.amazon.smithy.rust.codegen.core.rustlang.RustType
2728
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
2829
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
@@ -66,6 +67,7 @@ import software.amazon.smithy.rust.codegen.core.util.hasTrait
6667
import software.amazon.smithy.rust.codegen.core.util.inputShape
6768
import software.amazon.smithy.rust.codegen.core.util.isStreaming
6869
import software.amazon.smithy.rust.codegen.core.util.outputShape
70+
import software.amazon.smithy.rust.codegen.core.util.toSnakeCase
6971
import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency
7072
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
7173
import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape
@@ -228,43 +230,34 @@ class ServerHttpBoundProtocolTraitImplGenerator(
228230
outputSymbol: Symbol,
229231
operationShape: OperationShape,
230232
) {
231-
val operationName = symbolProvider.toSymbol(operationShape).name
232-
val staticContentType = "CONTENT_TYPE_${operationName.uppercase()}"
233233
val verifyAcceptHeader =
234234
writable {
235235
httpBindingResolver.responseContentType(operationShape)?.also { contentType ->
236-
rustTemplate(
237-
"""
238-
if !#{SmithyHttpServer}::protocol::accept_header_classifier(request.headers(), &$staticContentType) {
239-
return Err(#{RequestRejection}::NotAcceptable);
240-
}
241-
""",
242-
*codegenScope,
243-
)
244-
}
245-
}
246-
val verifyAcceptHeaderStaticContentTypeInit =
247-
writable {
248-
httpBindingResolver.responseContentType(operationShape)?.also { contentType ->
249-
val init =
250-
when (contentType) {
251-
"application/json" ->
252-
"const $staticContentType: #{Mime}::Mime = #{Mime}::APPLICATION_JSON;"
253-
254-
"application/octet-stream" ->
255-
"const $staticContentType: #{Mime}::Mime = #{Mime}::APPLICATION_OCTET_STREAM;"
256-
257-
"application/x-www-form-urlencoded" ->
258-
"const $staticContentType: #{Mime}::Mime = #{Mime}::APPLICATION_WWW_FORM_URLENCODED;"
259-
260-
else ->
261-
"""
262-
static $staticContentType: std::sync::LazyLock<#{Mime}::Mime> = std::sync::LazyLock::new(|| {
263-
${contentType.dq()}.parse::<#{Mime}::Mime>().expect("BUG: MIME parsing failed, content_type is not valid")
264-
});
265-
"""
266-
}
267-
rustTemplate(init, *codegenScope)
236+
val legacyContentType = httpBindingResolver.legacyBackwardsCompatContentType(operationShape)
237+
if (legacyContentType != null) {
238+
// For operations with legacy backwards compatibility, accept both content types
239+
rustTemplate(
240+
"""
241+
if !#{SmithyHttpServer}::protocol::accept_header_classifier(request.headers(), &#{ContentType}) &&
242+
!#{SmithyHttpServer}::protocol::accept_header_classifier(request.headers(), &#{FallbackContentType}) {
243+
return Err(#{RequestRejection}::NotAcceptable);
244+
}
245+
""",
246+
"ContentType" to mimeType(contentType),
247+
"FallbackContentType" to mimeType(legacyContentType),
248+
*codegenScope,
249+
)
250+
} else {
251+
rustTemplate(
252+
"""
253+
if !#{SmithyHttpServer}::protocol::accept_header_classifier(request.headers(), &#{ContentType}) {
254+
return Err(#{RequestRejection}::NotAcceptable);
255+
}
256+
""",
257+
"ContentType" to mimeType(contentType),
258+
*codegenScope,
259+
)
260+
}
268261
}
269262
}
270263

@@ -273,7 +266,6 @@ class ServerHttpBoundProtocolTraitImplGenerator(
273266
// TODO(https://github.com/smithy-lang/smithy-rs/issues/2238): Remove the `Pin<Box<dyn Future>>` and replace with thin wrapper around `Collect`.
274267
rustTemplate(
275268
"""
276-
#{verifyAcceptHeaderStaticContentTypeInit:W}
277269
#{PinProjectLite}::pin_project! {
278270
/// A [`Future`](std::future::Future) aggregating the body bytes of a [`Request`] and constructing the
279271
/// [`${inputSymbol.name}`](#{I}) using modelled bindings.
@@ -324,8 +316,6 @@ class ServerHttpBoundProtocolTraitImplGenerator(
324316
"Marker" to protocol.markerStruct(),
325317
"parse_request" to serverParseRequest(operationShape),
326318
"verifyAcceptHeader" to verifyAcceptHeader,
327-
"verifyAcceptHeaderStaticContentTypeInit" to
328-
verifyAcceptHeaderStaticContentTypeInit,
329319
)
330320

331321
// Implement `into_response` for output types.
@@ -379,6 +369,26 @@ class ServerHttpBoundProtocolTraitImplGenerator(
379369
}
380370
}
381371

372+
/**
373+
* Generates `pub(crate) static CONTENT_TYPE_<MIME_TYPE> = ....
374+
*
375+
* Usage: In templates, #{MimeType}, "MimeType" to mimeType("yourDesiredType")
376+
*/
377+
private fun mimeType(type: String): RuntimeType {
378+
val variableName = type.toSnakeCase().uppercase()
379+
val typeName = "CONTENT_TYPE_$variableName"
380+
return RuntimeType.forInlineFun(typeName, RustModule.private("mimes")) {
381+
rustTemplate(
382+
"""
383+
pub(crate) static $typeName: std::sync::LazyLock<#{Mime}::Mime> = std::sync::LazyLock::new(|| {
384+
${type.dq()}.parse::<#{Mime}::Mime>().expect("BUG: MIME parsing failed, content_type is not valid")
385+
});
386+
""",
387+
*codegenScope,
388+
)
389+
}
390+
}
391+
382392
private fun serverParseRequest(operationShape: OperationShape): RuntimeType {
383393
val inputShape = operationShape.inputShape(model)
384394
val inputSymbol = symbolProvider.toSymbol(inputShape)
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package software.amazon.smithy.rust.codegen.server.smithy.generators
7+
8+
import org.junit.jupiter.api.Test
9+
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
10+
import software.amazon.smithy.rust.codegen.core.rustlang.rust
11+
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
12+
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
13+
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
14+
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
15+
import software.amazon.smithy.rust.codegen.core.testutil.testModule
16+
import software.amazon.smithy.rust.codegen.core.testutil.tokioTest
17+
import software.amazon.smithy.rust.codegen.core.util.dq
18+
import software.amazon.smithy.rust.codegen.core.util.toSnakeCase
19+
import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest
20+
21+
internal class EventStreamAcceptHeaderTest {
22+
private val model =
23+
"""
24+
${'$'}version: "2.0"
25+
namespace test
26+
27+
use smithy.protocols#rpcv2Cbor
28+
use smithy.framework#ValidationException
29+
30+
@rpcv2Cbor
31+
service TestService {
32+
operations: [StreamingOutputOperation]
33+
}
34+
35+
operation StreamingOutputOperation {
36+
input: StreamingOutputOperationInput
37+
output: StreamingOutputOperationOutput
38+
errors: [ValidationException]
39+
}
40+
41+
structure StreamingOutputOperationInput {
42+
message: String
43+
}
44+
45+
structure StreamingOutputOperationOutput {
46+
events: Events
47+
}
48+
49+
@streaming
50+
union Events {
51+
event: StreamingEvent
52+
}
53+
54+
structure StreamingEvent {
55+
data: String
56+
}
57+
""".asSmithyModel()
58+
59+
@Test
60+
fun acceptHeaderTests() {
61+
serverIntegrationTest(model) { codegenContext, rustCrate ->
62+
rustCrate.testModule {
63+
generateAcceptHeaderTest(
64+
acceptHeader = "application/vnd.amazon.eventstream",
65+
shouldFail = false,
66+
codegenContext = codegenContext,
67+
)
68+
generateAcceptHeaderTest(
69+
acceptHeader = "application/cbor",
70+
shouldFail = false,
71+
codegenContext = codegenContext,
72+
)
73+
generateAcceptHeaderTest(
74+
acceptHeader = "application/invalid",
75+
shouldFail = true,
76+
codegenContext = codegenContext,
77+
)
78+
generateAcceptHeaderTest(
79+
acceptHeader = "application/json, application/cbor",
80+
shouldFail = false,
81+
codegenContext = codegenContext,
82+
testName = "combined_header",
83+
)
84+
}
85+
}
86+
}
87+
88+
private fun RustWriter.generateAcceptHeaderTest(
89+
acceptHeader: String,
90+
shouldFail: Boolean,
91+
codegenContext: CodegenContext,
92+
testName: String = acceptHeader.toSnakeCase(),
93+
) {
94+
tokioTest("test_header_$testName") {
95+
rustTemplate(
96+
"""
97+
use aws_smithy_http_server::body::Body;
98+
use aws_smithy_http_server::request::FromRequest;
99+
let cbor_empty_bytes = #{Bytes}::copy_from_slice(&#{decode_body_data}(
100+
"oA==".as_bytes(),
101+
#{MediaType}::from("application/cbor"),
102+
));
103+
104+
let http_request = ::http::Request::builder()
105+
.uri("/service/TestService/operation/StreamingOutputOperation")
106+
.method("POST")
107+
.header("Accept", ${acceptHeader.dq()})
108+
.header("Content-Type", "application/cbor")
109+
.header("smithy-protocol", "rpc-v2-cbor")
110+
.body(Body::from(cbor_empty_bytes))
111+
.unwrap();
112+
let parsed = crate::input::StreamingOutputOperationInput::from_request(http_request).await;
113+
""",
114+
"Bytes" to RuntimeType.Bytes,
115+
"MediaType" to RuntimeType.protocolTest(codegenContext.runtimeConfig, "MediaType"),
116+
"decode_body_data" to
117+
RuntimeType.protocolTest(
118+
codegenContext.runtimeConfig,
119+
"decode_body_data",
120+
),
121+
)
122+
123+
if (shouldFail) {
124+
rust("""parsed.expect_err("header should be rejected");""")
125+
} else {
126+
rust("""parsed.expect("header should be accepted");""")
127+
}
128+
}
129+
}
130+
}

gradle.properties

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,5 @@ kotlin.code.style=official
1616
allowLocalDeps=false
1717
# Avoid registering dependencies/plugins/tasks that are only used for testing purposes
1818
isTestingEnabled=true
19-
2019
# codegen publication version
21-
codegenVersion=0.1.1
20+
codegenVersion=0.1.2

0 commit comments

Comments
 (0)