Skip to content

Commit 1d6bdee

Browse files
authored
Fix aliased imports not applied to base classes and non-matching fields (#2981)
1 parent 365419e commit 1d6bdee

File tree

5 files changed

+146
-15
lines changed

5 files changed

+146
-15
lines changed

src/datamodel_code_generator/parser/base.py

Lines changed: 48 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,29 @@ def iter_models_field_data_types(
362362
yield model, field, data_type
363363

364364

365+
def _alias_base_class_imports(
366+
model: DataModel,
367+
aliased_imports: dict[tuple[str | None, str], Import],
368+
) -> None:
369+
"""Apply aliased imports to a model's base classes and their _additional_imports."""
370+
for base_class in model.base_classes:
371+
if not base_class.import_:
372+
continue
373+
key = (base_class.import_.from_, base_class.import_.import_)
374+
if key not in aliased_imports:
375+
continue
376+
old_import = base_class.import_
377+
aliased_import = aliased_imports[key]
378+
base_class.type = aliased_import.alias # type: ignore[assignment]
379+
base_class.import_ = aliased_import
380+
for i, additional_import in enumerate(model._additional_imports): # pragma: no branch # noqa: SLF001
381+
if (
382+
additional_import.from_ == old_import.from_ and additional_import.import_ == old_import.import_
383+
): # pragma: no branch
384+
model._additional_imports[i] = aliased_import # noqa: SLF001
385+
break
386+
387+
365388
ReferenceMapSet = dict[str, set[str]]
366389
SortedDataModels = dict[str, DataModel]
367390

@@ -2388,21 +2411,31 @@ def __alias_shadowed_imports( # noqa: PLR6301
23882411
models: list[DataModel],
23892412
all_model_field_names: set[str],
23902413
) -> None:
2391-
for _, model_field, data_type in iter_models_field_data_types(models):
2392-
if (
2393-
data_type
2394-
and data_type.import_
2395-
and data_type.type in all_model_field_names
2396-
and data_type.type == model_field.name
2397-
):
2398-
alias = data_type.type + "_aliased"
2399-
data_type.type = alias
2400-
data_type.import_ = Import(
2401-
from_=data_type.import_.from_,
2402-
import_=data_type.import_.import_,
2403-
alias=alias,
2404-
reference_path=data_type.import_.reference_path,
2405-
)
2414+
aliased_imports: dict[tuple[str | None, str], Import] = {}
2415+
for _, _model_field, data_type in iter_models_field_data_types(models):
2416+
if data_type and data_type.import_ and data_type.type in all_model_field_names:
2417+
key = (data_type.import_.from_, data_type.import_.import_)
2418+
if key not in aliased_imports:
2419+
aliased_imports[key] = Import(
2420+
from_=data_type.import_.from_,
2421+
import_=data_type.import_.import_,
2422+
alias=data_type.type + "_aliased",
2423+
reference_path=data_type.import_.reference_path,
2424+
)
2425+
2426+
if not aliased_imports:
2427+
return
2428+
2429+
for _, _model_field, data_type in iter_models_field_data_types(models):
2430+
if data_type and data_type.import_:
2431+
key = (data_type.import_.from_, data_type.import_.import_)
2432+
if key in aliased_imports:
2433+
aliased_import = aliased_imports[key]
2434+
data_type.type = aliased_import.alias # type: ignore[assignment]
2435+
data_type.import_ = aliased_import
2436+
2437+
for model in models:
2438+
_alias_base_class_imports(model, aliased_imports)
24062439

24072440
def __apply_generic_base_class( # noqa: PLR0912, PLR0914, PLR0915
24082441
self,
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# generated by datamodel-codegen:
2+
# filename: shadowed_imports_base_and_fields.yaml
3+
# timestamp: 2019-07-26T00:00:00+00:00
4+
5+
from __future__ import annotations
6+
7+
from mymodule.node import Node as Node_aliased
8+
from mymodule.other import Other
9+
from pydantic import BaseModel, ConfigDict
10+
11+
12+
class BaseItem(BaseModel):
13+
id: str | None = None
14+
15+
16+
class MyModel(BaseItem):
17+
model_config = ConfigDict(
18+
arbitrary_types_allowed=True,
19+
)
20+
Node: Node_aliased | None = None
21+
Node2: Node_aliased | None = None
22+
other: Other | None = None
23+
name: str | None = None
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# generated by datamodel-codegen:
2+
# filename: shadowed_imports_base_and_fields.yaml
3+
# timestamp: 2019-07-26T00:00:00+00:00
4+
5+
from __future__ import annotations
6+
7+
from mymodule.node import Node as Node_aliased
8+
from mymodule.other import Other
9+
from pydantic import ConfigDict
10+
11+
12+
class BaseItem(Node_aliased):
13+
id: str | None = None
14+
15+
16+
class MyModel(BaseItem):
17+
model_config = ConfigDict(
18+
arbitrary_types_allowed=True,
19+
)
20+
Node: Node_aliased | None = None
21+
Node2: Node_aliased | None = None
22+
other: Other | None = None
23+
name: str | None = None
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
openapi: 3.0.0
2+
info:
3+
title: Shadowed Imports Base and Fields Test
4+
version: 0.0.1
5+
paths: {}
6+
components:
7+
schemas:
8+
BaseItem:
9+
type: object
10+
properties:
11+
id:
12+
type: string
13+
MyModel:
14+
allOf:
15+
- $ref: '#/components/schemas/BaseItem'
16+
- type: object
17+
properties:
18+
Node:
19+
type: object
20+
customTypePath: mymodule.node.Node
21+
Node2:
22+
type: object
23+
customTypePath: mymodule.node.Node
24+
other:
25+
type: object
26+
customTypePath: mymodule.other.Other
27+
name:
28+
type: string

tests/main/openapi/test_main_openapi.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3712,6 +3712,30 @@ def test_main_openapi_shadowed_imports(output_file: Path) -> None:
37123712
)
37133713

37143714

3715+
def test_main_openapi_shadowed_imports_base_and_fields(output_file: Path) -> None:
3716+
"""Test that aliased imports are applied to all fields, not just matching field names."""
3717+
run_main_and_assert(
3718+
input_path=OPEN_API_DATA_PATH / "shadowed_imports_base_and_fields.yaml",
3719+
output_path=output_file,
3720+
input_file_type="openapi",
3721+
assert_func=assert_file_content,
3722+
expected_file="shadowed_imports_base_and_fields.py",
3723+
extra_args=["--output-model-type", "pydantic_v2.BaseModel"],
3724+
)
3725+
3726+
3727+
def test_main_openapi_shadowed_imports_base_and_fields_custom_base(output_file: Path) -> None:
3728+
"""Test that aliased imports are applied to custom base classes."""
3729+
run_main_and_assert(
3730+
input_path=OPEN_API_DATA_PATH / "shadowed_imports_base_and_fields.yaml",
3731+
output_path=output_file,
3732+
input_file_type="openapi",
3733+
assert_func=assert_file_content,
3734+
expected_file="shadowed_imports_base_and_fields_custom_base.py",
3735+
extra_args=["--output-model-type", "pydantic_v2.BaseModel", "--base-class", "mymodule.node.Node"],
3736+
)
3737+
3738+
37153739
def test_main_openapi_extra_fields_forbid(output_file: Path) -> None:
37163740
"""Test OpenAPI generation with extra fields forbidden."""
37173741
run_main_and_assert(

0 commit comments

Comments
 (0)