Skip to content

Commit fd8b4f6

Browse files
authored
Fix protocol selection behavior in ClientProtocolLoader (#4165)
## Description The bugs include - The default SDK-supported protocols (`DefaultProtocols`) were listed in an incorrect priority order. - Protocol resolution logic incorrectly iterated over service-applied protocols, which is returned by `getProtocols()` whose result may not reflect the intended priority. This PR addresses these issues. ## Testing - Existing CI - `ClientProtocolLoaderTest.kt` ## Checklist - [x] For changes to the smithy-rs codegen or runtime crates, I have created a changelog entry Markdown file in the `.changelog` directory, specifying "client," "server," or both in the `applies_to` key. - [x] For changes to the AWS SDK, generated SDK code, or SDK runtime crates, I have created a changelog entry Markdown file in the `.changelog` directory, specifying "aws-sdk-rust" in the `applies_to` key. ---- _By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice._
1 parent dc480da commit fd8b4f6

File tree

4 files changed

+144
-6
lines changed

4 files changed

+144
-6
lines changed

.changelog/1749155846.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
---
2+
applies_to:
3+
- client
4+
- aws-sdk-rust
5+
authors:
6+
- ysaito1001
7+
references:
8+
- smithy-rs#4165
9+
breaking: false
10+
new_feature: false
11+
bug_fix: true
12+
---
13+
Fix default supported protocols incorrectly ordered in `ClientProtocolLoader`.

codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/ClientProtocolLoader.kt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,13 @@ class ClientProtocolLoader(supportedProtocols: ProtocolMap<OperationGenerator, C
3737
companion object {
3838
val DefaultProtocols =
3939
mapOf(
40+
Rpcv2CborTrait.ID to ClientRpcV2CborFactory(),
4041
AwsJson1_0Trait.ID to ClientAwsJsonFactory(AwsJsonVersion.Json10),
4142
AwsJson1_1Trait.ID to ClientAwsJsonFactory(AwsJsonVersion.Json11),
42-
AwsQueryTrait.ID to ClientAwsQueryFactory(),
43-
Ec2QueryTrait.ID to ClientEc2QueryFactory(),
4443
RestJson1Trait.ID to ClientRestJsonFactory(),
4544
RestXmlTrait.ID to ClientRestXmlFactory(),
46-
Rpcv2CborTrait.ID to ClientRpcV2CborFactory(),
45+
AwsQueryTrait.ID to ClientAwsQueryFactory(),
46+
Ec2QueryTrait.ID to ClientEc2QueryFactory(),
4747
)
4848
val Default = ClientProtocolLoader(DefaultProtocols)
4949
}
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package software.amazon.smithy.rust.codegen.client.smithy.protocols
7+
8+
import org.junit.jupiter.api.Assertions.assertEquals
9+
import org.junit.jupiter.api.Test
10+
import org.junit.jupiter.api.assertThrows
11+
import org.junit.jupiter.api.extension.ExtensionContext
12+
import org.junit.jupiter.params.ParameterizedTest
13+
import org.junit.jupiter.params.provider.Arguments
14+
import org.junit.jupiter.params.provider.ArgumentsProvider
15+
import org.junit.jupiter.params.provider.ArgumentsSource
16+
import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait
17+
import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait
18+
import software.amazon.smithy.aws.traits.protocols.AwsQueryTrait
19+
import software.amazon.smithy.aws.traits.protocols.Ec2QueryTrait
20+
import software.amazon.smithy.aws.traits.protocols.RestJson1Trait
21+
import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
22+
import software.amazon.smithy.codegen.core.CodegenException
23+
import software.amazon.smithy.model.Model
24+
import software.amazon.smithy.model.shapes.ServiceShape
25+
import software.amazon.smithy.model.shapes.ShapeId
26+
import software.amazon.smithy.protocol.traits.Rpcv2CborTrait
27+
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
28+
import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationGenerator
29+
import software.amazon.smithy.rust.codegen.client.smithy.protocols.ClientProtocolLoader.Companion.DefaultProtocols
30+
import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolMap
31+
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
32+
import java.util.stream.Stream
33+
34+
data class TestCase(
35+
val supportedProtocols: ProtocolMap<OperationGenerator, ClientCodegenContext>,
36+
val model: Model,
37+
val resolvedProtocol: String?,
38+
)
39+
40+
class ClientProtocolLoaderTest {
41+
@Test
42+
fun `test priority order of default supported protocols`() {
43+
val expectedOrder =
44+
listOf(
45+
Rpcv2CborTrait.ID,
46+
AwsJson1_0Trait.ID,
47+
AwsJson1_1Trait.ID,
48+
RestJson1Trait.ID,
49+
RestXmlTrait.ID,
50+
AwsQueryTrait.ID,
51+
Ec2QueryTrait.ID,
52+
)
53+
assertEquals(expectedOrder, DefaultProtocols.keys.toList())
54+
}
55+
56+
// Although the test function name appears generic, its purpose is to verify whether
57+
// the RPCv2Cbor protocol is selected based on specific contexts.
58+
@ParameterizedTest
59+
@ArgumentsSource(ProtocolSelectionTestCaseProvider::class)
60+
fun `should resolve expected protocol`(testCase: TestCase) {
61+
val protocolLoader = ClientProtocolLoader(testCase.supportedProtocols)
62+
val serviceShape = testCase.model.expectShape(ShapeId.from("test#TestService"), ServiceShape::class.java)
63+
if (testCase.resolvedProtocol.isNullOrEmpty()) {
64+
assertThrows<CodegenException> {
65+
protocolLoader.protocolFor(testCase.model, serviceShape)
66+
}
67+
} else {
68+
val actual = protocolLoader.protocolFor(testCase.model, serviceShape).first.name
69+
assertEquals(testCase.resolvedProtocol, actual)
70+
}
71+
}
72+
}
73+
74+
class ProtocolSelectionTestCaseProvider : ArgumentsProvider {
75+
override fun provideArguments(p0: ExtensionContext?): Stream<out Arguments> {
76+
val protocolsWithoutRpcv2Cbor = LinkedHashMap(DefaultProtocols)
77+
protocolsWithoutRpcv2Cbor.remove(Rpcv2CborTrait.ID)
78+
79+
return arrayOf(
80+
TestCase(DefaultProtocols, model(listOf("rpcv2Cbor", "awsJson1_0")), "rpcv2Cbor"),
81+
TestCase(DefaultProtocols, model(listOf("rpcv2Cbor")), "rpcv2Cbor"),
82+
TestCase(DefaultProtocols, model(listOf("rpcv2Cbor", "awsJson1_0", "awsQuery")), "rpcv2Cbor"),
83+
TestCase(DefaultProtocols, model(listOf("awsJson1_0", "awsQuery")), "awsJson1_0"),
84+
TestCase(DefaultProtocols, model(listOf("awsQuery")), "awsQuery"),
85+
TestCase(protocolsWithoutRpcv2Cbor, model(listOf("rpcv2Cbor", "awsJson1_0")), "awsJson1_0"),
86+
TestCase(protocolsWithoutRpcv2Cbor, model(listOf("rpcv2Cbor")), null),
87+
TestCase(protocolsWithoutRpcv2Cbor, model(listOf("rpcv2Cbor", "awsJson1_0", "awsQuery")), "awsJson1_0"),
88+
TestCase(protocolsWithoutRpcv2Cbor, model(listOf("awsJson1_0", "awsQuery")), "awsJson1_0"),
89+
TestCase(protocolsWithoutRpcv2Cbor, model(listOf("awsQuery")), "awsQuery"),
90+
).map { Arguments.of(it) }.stream()
91+
}
92+
93+
private fun model(protocols: List<String>) =
94+
(
95+
"""
96+
namespace test
97+
""" + renderProtocols(protocols) +
98+
"""
99+
@xmlNamespace(uri: "http://test.com") // required for @awsQuery
100+
service TestService {
101+
version: "1.0.0"
102+
}
103+
"""
104+
).asSmithyModel(smithyVersion = "2.0")
105+
}
106+
107+
private fun renderProtocols(protocols: List<String>): String {
108+
val (rpcProtocols, awsProtocols) = protocols.partition { it == "rpcv2Cbor" }
109+
110+
val uses =
111+
buildList {
112+
rpcProtocols.forEach { add("use smithy.protocols#$it") }
113+
awsProtocols.forEach { add("use aws.protocols#$it") }
114+
}.joinToString("\n")
115+
116+
val annotations = protocols.joinToString("\n") { "@$it" }
117+
118+
return """
119+
$uses
120+
121+
$annotations
122+
"""
123+
}

codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/ProtocolLoader.kt

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,13 @@ open class ProtocolLoader<T, C : CodegenContext>(private val supportedProtocols:
1818
model: Model,
1919
serviceShape: ServiceShape,
2020
): Pair<ShapeId, ProtocolGeneratorFactory<T, C>> {
21-
val protocols: MutableMap<ShapeId, Trait> = ServiceIndex.of(model).getProtocols(serviceShape)
21+
val serviceProtocols: MutableMap<ShapeId, Trait> = ServiceIndex.of(model).getProtocols(serviceShape)
2222
val matchingProtocols =
23-
protocols.keys.mapNotNull { protocolId -> supportedProtocols[protocolId]?.let { protocolId to it } }
23+
supportedProtocols.mapNotNull { (protocolId, factory) ->
24+
serviceProtocols[protocolId]?.let { protocolId to factory }
25+
}
2426
if (matchingProtocols.isEmpty()) {
25-
throw CodegenException("No matching protocol — service offers: ${protocols.keys}. We offer: ${supportedProtocols.keys}")
27+
throw CodegenException("No matching protocol — service offers: ${serviceProtocols.keys}. We offer: ${supportedProtocols.keys}")
2628
}
2729
return matchingProtocols.first()
2830
}

0 commit comments

Comments
 (0)