diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/eventstream/ClientEventStreamMarshallerGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/eventstream/ClientEventStreamMarshallerGeneratorTest.kt index 32631a01221..8450fbb30fe 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/eventstream/ClientEventStreamMarshallerGeneratorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/eventstream/ClientEventStreamMarshallerGeneratorTest.kt @@ -155,7 +155,11 @@ class ClientEventStreamMarshallerGeneratorTest { .is_some()); assert_eq!( msg.payload(), - &bytes::Bytes::from_static(${testCase.generateRustPayloadInitializer(rpcEventStreamTestCase.expectedInInitialRequest)}) + &bytes::Bytes::from_static(${ + testCase.generateRustPayloadInitializer( + rpcEventStreamTestCase.expectedInInitialRequest, + ) + }) ); """, ) @@ -178,7 +182,7 @@ class ClientEventStreamMarshallerGeneratorTest { class TestCasesProvider : ArgumentsProvider { override fun provideArguments(context: ExtensionContext?): Stream = - EventStreamTestModels.TEST_CASES.map { Arguments.of(it) }.stream() + EventStreamTestModels.TEST_CASES.map { Arguments.of(it.withEnumMembers()) }.stream() } /** diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt index b2fc9098be6..28737bf353c 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt @@ -9,6 +9,7 @@ import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.shapes.BlobShape import software.amazon.smithy.model.shapes.BooleanShape import software.amazon.smithy.model.shapes.ByteShape +import software.amazon.smithy.model.shapes.EnumShape import software.amazon.smithy.model.shapes.IntegerShape import software.amazon.smithy.model.shapes.LongShape import software.amazon.smithy.model.shapes.MemberShape @@ -261,6 +262,7 @@ class EventStreamUnmarshallerGenerator( is IntegerShape -> rustTemplate("#{expect_fns}::expect_int32(header)?", *codegenScope) is LongShape -> rustTemplate("#{expect_fns}::expect_int64(header)?", *codegenScope) is BlobShape -> rustTemplate("#{expect_fns}::expect_byte_array(header)?", *codegenScope) + is EnumShape -> rustTemplate("#{expect_fns}::expect_string(header)?.as_str().into()", *codegenScope) is StringShape -> rustTemplate("#{expect_fns}::expect_string(header)?", *codegenScope) is TimestampShape -> rustTemplate("#{expect_fns}::expect_timestamp(header)?", *codegenScope) else -> throw IllegalStateException("unsupported event stream header shape type: $target") @@ -383,7 +385,8 @@ class EventStreamUnmarshallerGenerator( "builder", target, mapErr = { rustTemplate( - """|err|#{Error}::unmarshalling(format!("{}", err))""", *codegenScope, + """|err|#{Error}::unmarshalling(format!("{}", err))""", + *codegenScope, ) }, ), diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/EventStreamMarshallerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/EventStreamMarshallerGenerator.kt index 326f1203475..4fbbbfaf226 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/EventStreamMarshallerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/EventStreamMarshallerGenerator.kt @@ -10,6 +10,7 @@ import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.BlobShape import software.amazon.smithy.model.shapes.BooleanShape import software.amazon.smithy.model.shapes.ByteShape +import software.amazon.smithy.model.shapes.EnumShape import software.amazon.smithy.model.shapes.IntegerShape import software.amazon.smithy.model.shapes.LongShape import software.amazon.smithy.model.shapes.MemberShape @@ -243,6 +244,7 @@ open class EventStreamMarshallerGenerator( is IntegerShape -> "Int32($inputName)" is LongShape -> "Int64($inputName)" is BlobShape -> "ByteArray($inputName.into_inner().into())" + is EnumShape -> "String($inputName.to_string().into())" is StringShape -> "String($inputName.into())" is TimestampShape -> "Timestamp($inputName)" else -> throw IllegalStateException("unsupported event stream header shape type: $target") diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt index 39abeb551e7..acf8be80d0a 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt @@ -21,8 +21,11 @@ private fun fillInBaseModel( namespacedProtocolName: String, extraServiceAnnotations: String = "", nonEventStreamMembers: String = "", + extraEventHeaderMembers: String = "", + extraShapes: String = "", ): String = """ + ${"\$version: \"2\""} namespace test use smithy.framework#ValidationException @@ -42,6 +45,8 @@ private fun fillInBaseModel( Message: String, } + $extraShapes + structure MessageWithBlob { @eventPayload data: Blob } structure MessageWithString { @eventPayload data: String } structure MessageWithStruct { @eventPayload someStruct: TestStruct } @@ -55,6 +60,7 @@ private fun fillInBaseModel( @eventHeader short: Short, @eventHeader string: String, @eventHeader timestamp: Timestamp, + $extraEventHeaderMembers } structure MessageWithHeaderAndPayload { @eventHeader header: String, @@ -129,6 +135,26 @@ object EventStreamTestModels { fun withNonEventStreamMembers(nonEventStreamMembers: String): TestCase = this.copy(model = fillInBaseModel(this.protocolShapeId, "", nonEventStreamMembers).asSmithyModel()) + + // Server doesn't support enum members in event streams, so this util allows Clients to test with those shapes + fun withEnumMembers(): TestCase = + this.copy( + model = + fillInBaseModel( + this.protocolShapeId, + extraEventHeaderMembers = "@eventHeader enum: TheEnum,\n@eventHeader intEnum: FaceCard,", + extraShapes = """ enum TheEnum { + FOO + BAR + } + + intEnum FaceCard { + JACK = 1 + QUEEN = 2 + KING = 3 + }""", + ).asSmithyModel(), + ) } private fun base64Encode(input: ByteArray): String {