Skip to content

Commit 0a63d5b

Browse files
drganjooFahad Zubair
andauthored
Pass UnionShape for union type discrimination (#3984)
UnionShape needs to be passed so that the customization code can detect which Union type is being generated. Co-authored-by: Fahad Zubair <fahadzub@amazon.com>
1 parent 3d801c4 commit 0a63d5b

File tree

1 file changed

+7
-4
lines changed
  • codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse

1 file changed

+7
-4
lines changed

codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,10 @@ abstract class CborParserCustomization : NamedCustomization<CborParserSection>()
7878
* @param defaultContext The default discrimination context containing decoder symbol and discriminator method.
7979
* @return UnionVariantDiscriminatorContext that defines how to discriminate union variants.
8080
*/
81-
open fun getUnionVariantDiscriminator(defaultContext: CborParserGenerator.UnionVariantDiscriminatorContext) =
82-
defaultContext
81+
open fun getUnionVariantDiscriminator(
82+
unionShape: UnionShape,
83+
defaultContext: CborParserGenerator.UnionVariantDiscriminatorContext,
84+
) = defaultContext
8385
}
8486

8587
class CborParserGenerator(
@@ -330,7 +332,7 @@ class CborParserGenerator(
330332
val returnSymbolToParse = returnSymbolToParse(shape)
331333
// Get actual decoder type to use and the discriminating function to call to extract
332334
// the variant of the union that has been encoded in the data.
333-
val discriminatorContext = getUnionDiscriminatorContext("Decoder", "decoder.str()?.as_ref()")
335+
val discriminatorContext = getUnionDiscriminatorContext(shape, "Decoder", "decoder.str()?.as_ref()")
334336

335337
rustBlockTemplate(
336338
"""
@@ -394,6 +396,7 @@ class CborParserGenerator(
394396
}
395397

396398
private fun getUnionDiscriminatorContext(
399+
unionShape: UnionShape,
397400
decoderType: String,
398401
callMethod: String,
399402
): UnionVariantDiscriminatorContext {
@@ -403,7 +406,7 @@ class CborParserGenerator(
403406
writable { rustTemplate(callMethod) },
404407
)
405408
return customizations.fold(defaultUnionPairContext) { context, customization ->
406-
customization.getUnionVariantDiscriminator(context)
409+
customization.getUnionVariantDiscriminator(unionShape, context)
407410
}
408411
}
409412

0 commit comments

Comments
 (0)