Skip to content

Commit 15739f6

Browse files
committed
configure_tagged_unions: fix recursive type aliases
1 parent 31c19ce commit 15739f6

File tree

3 files changed

+86
-49
lines changed

3 files changed

+86
-49
lines changed

HISTORY.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,10 @@ Our backwards-compatibility policy can be found [here](https://github.com/python
3131
([#707](https://github.com/python-attrs/cattrs/issues/707) [#708](https://github.com/python-attrs/cattrs/pull/708))
3232
- Enum handling has been optimized by switching to hook factories, improving performance especially for plain enums.
3333
([#705](https://github.com/python-attrs/cattrs/pull/705))
34-
- Fix `include_subclasses` when used with `configure_tagged_union` and classes using diamond inheritance.
34+
- Fix {func}`cattrs.strategies.include_subclasses` when used with {func}`cattrs.strategies.configure_tagged_union` and classes using diamond inheritance.
3535
([#685](https://github.com/python-attrs/cattrs/issues/685) [#713](https://github.com/python-attrs/cattrs/pull/713))
36+
- Fix {func}`cattrs.strategies.configure_tagged_union` when used with recursive type aliases.
37+
([#678](https://github.com/python-attrs/cattrs/issues/678) [#714](https://github.com/python-attrs/cattrs/pull/714))
3638

3739
## 25.3.0 (2025-10-07)
3840

src/cattrs/strategies/_unions.py

Lines changed: 45 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -52,60 +52,20 @@ def configure_tagged_union(
5252
if is_type_alias(union):
5353
union = union.__value__
5454
args = union.__args__
55+
5556
tag_to_hook = {}
5657
exact_cl_unstruct_hooks = {}
57-
for cl in args:
58-
tag = tag_generator(cl)
59-
struct_handler = converter.get_structure_hook(cl)
60-
unstruct_handler = converter.get_unstructure_hook(cl)
61-
62-
def structure_union_member(val: dict, _cl=cl, _h=struct_handler) -> cl:
63-
return _h(val, _cl)
64-
65-
def unstructure_union_member(val: union, _h=unstruct_handler) -> dict:
66-
return _h(val)
67-
68-
tag_to_hook[tag] = structure_union_member
69-
exact_cl_unstruct_hooks[cl] = unstructure_union_member
70-
71-
cl_to_tag = {cl: tag_generator(cl) for cl in args}
58+
cl_to_tag = {}
7259

7360
if default is not NOTHING:
7461
default_handler = converter.get_structure_hook(default)
7562

7663
def structure_default(val: dict, _cl=default, _h=default_handler):
7764
return _h(val, _cl)
7865

79-
tag_to_hook = defaultdict(lambda: structure_default, tag_to_hook)
80-
cl_to_tag = defaultdict(lambda: default, cl_to_tag)
66+
tag_to_hook = defaultdict(lambda: structure_default)
67+
cl_to_tag = defaultdict(lambda: default)
8168

82-
def unstructure_tagged_union(
83-
val: union,
84-
_exact_cl_unstruct_hooks=exact_cl_unstruct_hooks,
85-
_cl_to_tag=cl_to_tag,
86-
_tag_name=tag_name,
87-
) -> dict:
88-
res = _exact_cl_unstruct_hooks[val.__class__](val)
89-
res[_tag_name] = _cl_to_tag[val.__class__]
90-
return res
91-
92-
if default is NOTHING:
93-
if getattr(converter, "forbid_extra_keys", False):
94-
95-
def structure_tagged_union(
96-
val: dict, _, _tag_to_cl=tag_to_hook, _tag_name=tag_name
97-
) -> union:
98-
val = val.copy()
99-
return _tag_to_cl[val.pop(_tag_name)](val)
100-
101-
else:
102-
103-
def structure_tagged_union(
104-
val: dict, _, _tag_to_cl=tag_to_hook, _tag_name=tag_name
105-
) -> union:
106-
return _tag_to_cl[val[_tag_name]](val)
107-
108-
else:
10969
if getattr(converter, "forbid_extra_keys", False):
11070

11171
def structure_tagged_union(
@@ -135,9 +95,50 @@ def structure_tagged_union(
13595
return _tag_to_hook[val[_tag_name]](val)
13696
return _dh(val, _default)
13797

98+
else:
99+
if getattr(converter, "forbid_extra_keys", False):
100+
101+
def structure_tagged_union(
102+
val: dict, _, _tag_to_cl=tag_to_hook, _tag_name=tag_name
103+
) -> union:
104+
val = val.copy()
105+
return _tag_to_cl[val.pop(_tag_name)](val)
106+
107+
else:
108+
109+
def structure_tagged_union(
110+
val: dict, _, _tag_to_cl=tag_to_hook, _tag_name=tag_name
111+
) -> union:
112+
return _tag_to_cl[val[_tag_name]](val)
113+
114+
def unstructure_tagged_union(
115+
val: union,
116+
_exact_cl_unstruct_hooks=exact_cl_unstruct_hooks,
117+
_cl_to_tag=cl_to_tag,
118+
_tag_name=tag_name,
119+
) -> dict:
120+
res = _exact_cl_unstruct_hooks[val.__class__](val)
121+
res[_tag_name] = _cl_to_tag[val.__class__]
122+
return res
123+
138124
converter.register_unstructure_hook(union, unstructure_tagged_union)
139125
converter.register_structure_hook(union, structure_tagged_union)
140126

127+
for cl in args:
128+
tag = tag_generator(cl)
129+
struct_handler = converter.get_structure_hook(cl)
130+
unstruct_handler = converter.get_unstructure_hook(cl)
131+
132+
def structure_union_member(val: dict, _cl=cl, _h=struct_handler) -> cl:
133+
return _h(val, _cl)
134+
135+
def unstructure_union_member(val: union, _h=unstruct_handler) -> dict:
136+
return _h(val)
137+
138+
tag_to_hook[tag] = structure_union_member
139+
exact_cl_unstruct_hooks[cl] = unstructure_union_member
140+
cl_to_tag[cl] = tag
141+
141142

142143
def configure_union_passthrough(
143144
union: Any, converter: BaseConverter, accept_ints_as_floats: bool = True
Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1-
import pytest
1+
from __future__ import annotations
22

3-
from cattrs import BaseConverter
3+
from attrs import define
4+
5+
from cattrs import BaseConverter, Converter
46
from cattrs.strategies import configure_tagged_union
57

6-
from .._compat import is_py312_plus
78
from .test_tagged_unions import A, B
89

910

10-
@pytest.mark.skipif(not is_py312_plus, reason="New type alias syntax")
1111
def test_type_alias(converter: BaseConverter):
1212
"""Type aliases to unions also work."""
1313
type AOrB = A | B
@@ -19,3 +19,37 @@ def test_type_alias(converter: BaseConverter):
1919

2020
assert converter.structure({"_type": "A", "a": 1}, AOrB) == A(1)
2121
assert converter.structure({"_type": "B", "a": 1}, AOrB) == B("1")
22+
23+
24+
@define
25+
class Lit:
26+
value: float
27+
28+
29+
@define
30+
class Add:
31+
left: Expr
32+
right: Expr
33+
34+
35+
type Expr = Add | Lit
36+
37+
38+
def test_recursive_type_alias(genconverter: Converter):
39+
"""Recursive type aliases to unions also work.
40+
41+
Only tests on the GenConverter since the BaseConverter doesn't support
42+
stringified annotations.
43+
"""
44+
45+
configure_tagged_union(Expr, genconverter)
46+
47+
val = Add(Lit(1.0), Lit(2.0))
48+
expected = {
49+
"_type": "Add",
50+
"left": {"_type": "Lit", "value": 1.0},
51+
"right": {"_type": "Lit", "value": 2.0},
52+
}
53+
54+
assert genconverter.unstructure(val, Expr) == expected
55+
assert genconverter.structure(expected, Expr) == val

0 commit comments

Comments
 (0)