@@ -29,6 +29,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
2929import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
3030import software.amazon.smithy.rust.codegen.core.rustlang.Writable
3131import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlock
32+ import software.amazon.smithy.rust.codegen.core.rustlang.join
3233import software.amazon.smithy.rust.codegen.core.rustlang.rust
3334import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
3435import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
@@ -57,10 +58,29 @@ import software.amazon.smithy.rust.codegen.core.util.outputShape
5758/* * Class describing a CBOR parser section that can be used in a customization. */
5859sealed class CborParserSection (name : String ) : Section(name) {
5960 data class BeforeBoxingDeserializedMember (val shape : MemberShape ) : CborParserSection(" BeforeBoxingDeserializedMember" )
61+
62+ /* *
63+ * Represents a customization point in union deserialization that occurs before decoding the map structure.
64+ * This allows for custom handling of union variants before the standard map decoding logic is applied.
65+ * @property shape The union shape being deserialized.
66+ */
67+ data class UnionParserBeforeDecodingMap (val shape : UnionShape ) : CborParserSection(" UnionParserBeforeDecodingMap" )
6068}
6169
62- /* * Customization for the CBOR parser. */
63- typealias CborParserCustomization = NamedCustomization <CborParserSection >
70+ /* *
71+ * Customization class for CBOR parser generation that allows modification of union type deserialization behavior.
72+ * Previously, union variant discrimination was hardcoded to use `decoder.str()`. This has been made more flexible
73+ * to support different decoder implementations and discrimination methods.
74+ */
75+ abstract class CborParserCustomization : NamedCustomization <CborParserSection >() {
76+ /* *
77+ * Allows customization of how union variants are discriminated during deserialization.
78+ * @param defaultContext The default discrimination context containing decoder symbol and discriminator method.
79+ * @return UnionVariantDiscriminatorContext that defines how to discriminate union variants.
80+ */
81+ open fun getUnionVariantDiscriminator (defaultContext : CborParserGenerator .UnionVariantDiscriminatorContext ) =
82+ defaultContext
83+ }
6484
6585class CborParserGenerator (
6686 private val codegenContext : CodegenContext ,
@@ -75,6 +95,16 @@ class CborParserGenerator(
7595 private val shouldWrapBuilderMemberSetterInputWithOption : (MemberShape ) -> Boolean = { _ -> true },
7696 private val customizations : List <CborParserCustomization > = emptyList(),
7797) : StructuredDataParserGenerator {
98+ /* *
99+ * Context class that encapsulates the information needed to discriminate union variants during deserialization.
100+ * @property decoderSymbol The symbol representing the decoder type.
101+ * @property variantDiscriminatorExpression The method call expression to determine the union variant.
102+ */
103+ data class UnionVariantDiscriminatorContext (
104+ val decoderSymbol : Symbol ,
105+ val variantDiscriminatorExpression : Writable ,
106+ )
107+
78108 private val model = codegenContext.model
79109 private val symbolProvider = codegenContext.symbolProvider
80110 private val runtimeConfig = codegenContext.runtimeConfig
@@ -298,16 +328,26 @@ class CborParserGenerator(
298328 private fun unionPairParserFnWritable (shape : UnionShape ) =
299329 writable {
300330 val returnSymbolToParse = returnSymbolToParse(shape)
331+ // Get actual decoder type to use and the discriminating function to call to extract
332+ // the variant of the union that has been encoded in the data.
333+ val discriminatorContext = getUnionDiscriminatorContext(" Decoder" , " decoder.str()?.as_ref()" )
334+
301335 rustBlockTemplate(
302336 """
303337 fn pair(
304- decoder: &mut #{Decoder }
338+ decoder: &mut #{DecoderSymbol }
305339 ) -> #{Result}<#{UnionSymbol}, #{Error}>
306340 """ ,
307341 * codegenScope,
342+ " DecoderSymbol" to discriminatorContext.decoderSymbol,
308343 " UnionSymbol" to returnSymbolToParse.symbol,
309344 ) {
310- withBlock(" Ok(match decoder.str()?.as_ref() {" , " })" ) {
345+ rustTemplate(
346+ """
347+ Ok(match #{VariableDiscriminatingExpression} {
348+ """ ,
349+ " VariableDiscriminatingExpression" to discriminatorContext.variantDiscriminatorExpression,
350+ ).run {
311351 for (member in shape.members()) {
312352 val variantName = symbolProvider.toMemberName(member)
313353
@@ -349,9 +389,24 @@ class CborParserGenerator(
349389 )
350390 }
351391 }
392+ rust(" })" )
352393 }
353394 }
354395
396+ private fun getUnionDiscriminatorContext (
397+ decoderType : String ,
398+ callMethod : String ,
399+ ): UnionVariantDiscriminatorContext {
400+ val defaultUnionPairContext =
401+ UnionVariantDiscriminatorContext (
402+ smithyCbor.resolve(decoderType).toSymbol(),
403+ writable { rustTemplate(callMethod) },
404+ )
405+ return customizations.fold(defaultUnionPairContext) { context, customization ->
406+ customization.getUnionVariantDiscriminator(context)
407+ }
408+ }
409+
355410 enum class CollectionKind {
356411 Map ,
357412 List ,
@@ -677,12 +732,22 @@ class CborParserGenerator(
677732
678733 private fun RustWriter.deserializeUnion (shape : UnionShape ) {
679734 val returnSymbolToParse = returnSymbolToParse(shape)
735+ val beforeDecoderMapCustomization =
736+ customizations.map { customization ->
737+ customization.section(
738+ CborParserSection .UnionParserBeforeDecodingMap (
739+ shape,
740+ ),
741+ )
742+ }.join(" " )
743+
680744 val parser =
681745 protocolFunctions.deserializeFn(shape) { fnName ->
682746 rustTemplate(
683747 """
684748 pub(crate) fn $fnName (decoder: &mut #{Decoder}) -> #{Result}<#{UnionSymbol}, #{Error}> {
685749 #{UnionPairParserFnWritable}
750+ #{BeforeDecoderMapCustomization:W}
686751
687752 match decoder.map()? {
688753 None => {
@@ -707,6 +772,7 @@ class CborParserGenerator(
707772 """ ,
708773 " UnionSymbol" to returnSymbolToParse.symbol,
709774 " UnionPairParserFnWritable" to unionPairParserFnWritable(shape),
775+ " BeforeDecoderMapCustomization" to beforeDecoderMapCustomization,
710776 * codegenScope,
711777 )
712778 }
0 commit comments