@@ -52,60 +52,20 @@ def configure_tagged_union(
5252 if is_type_alias (union ):
5353 union = union .__value__
5454 args = union .__args__
55+
5556 tag_to_hook = {}
5657 exact_cl_unstruct_hooks = {}
57- for cl in args :
58- tag = tag_generator (cl )
59- struct_handler = converter .get_structure_hook (cl )
60- unstruct_handler = converter .get_unstructure_hook (cl )
61-
62- def structure_union_member (val : dict , _cl = cl , _h = struct_handler ) -> cl :
63- return _h (val , _cl )
64-
65- def unstructure_union_member (val : union , _h = unstruct_handler ) -> dict :
66- return _h (val )
67-
68- tag_to_hook [tag ] = structure_union_member
69- exact_cl_unstruct_hooks [cl ] = unstructure_union_member
70-
71- cl_to_tag = {cl : tag_generator (cl ) for cl in args }
58+ cl_to_tag = {}
7259
7360 if default is not NOTHING :
7461 default_handler = converter .get_structure_hook (default )
7562
7663 def structure_default (val : dict , _cl = default , _h = default_handler ):
7764 return _h (val , _cl )
7865
79- tag_to_hook = defaultdict (lambda : structure_default , tag_to_hook )
80- cl_to_tag = defaultdict (lambda : default , cl_to_tag )
66+ tag_to_hook = defaultdict (lambda : structure_default )
67+ cl_to_tag = defaultdict (lambda : default )
8168
82- def unstructure_tagged_union (
83- val : union ,
84- _exact_cl_unstruct_hooks = exact_cl_unstruct_hooks ,
85- _cl_to_tag = cl_to_tag ,
86- _tag_name = tag_name ,
87- ) -> dict :
88- res = _exact_cl_unstruct_hooks [val .__class__ ](val )
89- res [_tag_name ] = _cl_to_tag [val .__class__ ]
90- return res
91-
92- if default is NOTHING :
93- if getattr (converter , "forbid_extra_keys" , False ):
94-
95- def structure_tagged_union (
96- val : dict , _ , _tag_to_cl = tag_to_hook , _tag_name = tag_name
97- ) -> union :
98- val = val .copy ()
99- return _tag_to_cl [val .pop (_tag_name )](val )
100-
101- else :
102-
103- def structure_tagged_union (
104- val : dict , _ , _tag_to_cl = tag_to_hook , _tag_name = tag_name
105- ) -> union :
106- return _tag_to_cl [val [_tag_name ]](val )
107-
108- else :
10969 if getattr (converter , "forbid_extra_keys" , False ):
11070
11171 def structure_tagged_union (
@@ -135,9 +95,50 @@ def structure_tagged_union(
13595 return _tag_to_hook [val [_tag_name ]](val )
13696 return _dh (val , _default )
13797
98+ else :
99+ if getattr (converter , "forbid_extra_keys" , False ):
100+
101+ def structure_tagged_union (
102+ val : dict , _ , _tag_to_cl = tag_to_hook , _tag_name = tag_name
103+ ) -> union :
104+ val = val .copy ()
105+ return _tag_to_cl [val .pop (_tag_name )](val )
106+
107+ else :
108+
109+ def structure_tagged_union (
110+ val : dict , _ , _tag_to_cl = tag_to_hook , _tag_name = tag_name
111+ ) -> union :
112+ return _tag_to_cl [val [_tag_name ]](val )
113+
114+ def unstructure_tagged_union (
115+ val : union ,
116+ _exact_cl_unstruct_hooks = exact_cl_unstruct_hooks ,
117+ _cl_to_tag = cl_to_tag ,
118+ _tag_name = tag_name ,
119+ ) -> dict :
120+ res = _exact_cl_unstruct_hooks [val .__class__ ](val )
121+ res [_tag_name ] = _cl_to_tag [val .__class__ ]
122+ return res
123+
138124 converter .register_unstructure_hook (union , unstructure_tagged_union )
139125 converter .register_structure_hook (union , structure_tagged_union )
140126
127+ for cl in args :
128+ tag = tag_generator (cl )
129+ struct_handler = converter .get_structure_hook (cl )
130+ unstruct_handler = converter .get_unstructure_hook (cl )
131+
132+ def structure_union_member (val : dict , _cl = cl , _h = struct_handler ) -> cl :
133+ return _h (val , _cl )
134+
135+ def unstructure_union_member (val : union , _h = unstruct_handler ) -> dict :
136+ return _h (val )
137+
138+ tag_to_hook [tag ] = structure_union_member
139+ exact_cl_unstruct_hooks [cl ] = unstructure_union_member
140+ cl_to_tag [cl ] = tag
141+
141142
142143def configure_union_passthrough (
143144 union : Any , converter : BaseConverter , accept_ints_as_floats : bool = True
0 commit comments