Skip to content

Commit a7a1b30

Browse files
authored
add GenericModelType to fix serialization (#1672)
* fix: add GenericModelType to fix serialization also sets all GenericModelType fields to Optional with default=None * chore: linting * fix: whitespace
1 parent c183ea3 commit a7a1b30

File tree

10 files changed

+40
-30
lines changed

10 files changed

+40
-30
lines changed

src/aind_data_schema/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
Field,
1515
NaiveDatetime,
1616
PrivateAttr,
17+
SerializeAsAny,
1718
ValidationError,
1819
ValidatorFunctionWrapHandler,
1920
create_model,
@@ -79,10 +80,7 @@ def has_corrupt_keys(input) -> bool:
7980

8081
class GenericModel(BaseModel, extra="allow"):
8182
"""Base class for generic types that can be used in AIND schema"""
82-
8383
# extra="allow" is needed because BaseModel by default drops extra parameters.
84-
# Alternatively, consider using 'SerializeAsAny' once this issue is resolved
85-
# https://github.com/pydantic/pydantic/issues/6423
8684

8785
@model_validator(mode="after")
8886
def validate_fieldnames(self):
@@ -95,6 +93,8 @@ def validate_fieldnames(self):
9593
return self
9694

9795

96+
GenericModelType = SerializeAsAny[GenericModel]
97+
9898
T = TypeVar("T")
9999
Discriminated = Annotated[T, Field(discriminator="object_type")]
100100
DiscriminatedList = List[Discriminated[T]]

src/aind_data_schema/components/configs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from pydantic import Field, field_validator, model_validator
2323
from pydantic_core.core_schema import ValidationInfo
2424

25-
from aind_data_schema.base import AwareDatetimeWithDefault, DataModel, DiscriminatedList, GenericModel
25+
from aind_data_schema.base import AwareDatetimeWithDefault, DataModel, DiscriminatedList, GenericModelType
2626
from aind_data_schema.components.coordinates import (
2727
TRANSFORM_TYPES,
2828
AtlasCoordinate,
@@ -542,7 +542,7 @@ class MRIScan(DeviceConfig):
542542
# other fields
543543
resolution: Optional[Scale] = Field(default=None, title="Voxel resolution")
544544
resolution_unit: Optional[SizeUnit] = Field(default=None, title="Voxel resolution unit")
545-
additional_scan_parameters: GenericModel = Field(..., title="Parameters")
545+
additional_scan_parameters: Optional[GenericModelType] = Field(default=None, title="Parameters")
546546
notes: Optional[str] = Field(default=None, title="Notes", validate_default=True)
547547

548548
@field_validator("notes", mode="after")

src/aind_data_schema/components/devices.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
)
4242
from pydantic import Field, ValidationInfo, field_validator, model_validator
4343

44-
from aind_data_schema.base import DataModel, Discriminated, GenericModel
44+
from aind_data_schema.base import DataModel, Discriminated, GenericModelType
4545
from aind_data_schema.components.coordinates import TRANSFORM_TYPES, AxisName, CoordinateSystem, Scale
4646
from aind_data_schema.components.identifiers import Software
4747

@@ -78,7 +78,7 @@ class Device(DataModel):
7878
model: Optional[str] = Field(default=None, title="Model")
7979

8080
# Additional fields
81-
additional_settings: Optional[GenericModel] = Field(default=None, title="Additional parameters")
81+
additional_settings: Optional[GenericModelType] = Field(default=None, title="Additional parameters")
8282
notes: Optional[str] = Field(default=None, title="Notes")
8383

8484
@model_validator(mode="after")

src/aind_data_schema/components/identifiers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from aind_data_schema_models.registries import Registry
88
from pydantic import Field
99

10-
from aind_data_schema.base import DataModel, DiscriminatedList, GenericModel
10+
from aind_data_schema.base import DataModel, DiscriminatedList, GenericModelType
1111

1212

1313
class Database(str, Enum):
@@ -83,7 +83,7 @@ class Code(DataModel):
8383
input_data: Optional[DiscriminatedList[DataAsset | CombinedData]] = Field(
8484
default=None, title="Input data", description="Input data used in the code or script"
8585
)
86-
parameters: Optional[GenericModel] = Field(
86+
parameters: Optional[GenericModelType] = Field(
8787
default=None, title="Parameters", description="Parameters used in the code or script"
8888
)
8989

src/aind_data_schema/components/measurements.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from aind_data_schema_models.units import UNITS, PowerUnit, TimeUnit, VolumeUnit, VoltageUnit
77
from pydantic import model_validator
88

9-
from aind_data_schema.base import AwareDatetimeWithDefault, DataModel, Discriminated, Field, GenericModel
9+
from aind_data_schema.base import AwareDatetimeWithDefault, DataModel, Discriminated, Field, GenericModelType
1010
from aind_data_schema.components.configs import DeviceConfig
1111
from aind_data_schema.components.reagent import Reagent
1212
from aind_data_schema.utils.validators import TimeValidation
@@ -27,7 +27,7 @@ class CalibrationFit(DataModel):
2727
...,
2828
title="Fit type",
2929
)
30-
fit_parameters: Optional[GenericModel] = Field(
30+
fit_parameters: Optional[GenericModelType] = Field(
3131
default=None,
3232
title="Fit parameters",
3333
description="Parameters of the fit equation, e.g. slope and intercept for linear fit",

src/aind_data_schema/components/stimulus.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from aind_data_schema_models.units import FrequencyUnit, PowerUnit, TimeUnit, ConcentrationUnit
88
from pydantic import Field, model_validator
99

10-
from aind_data_schema.base import DataModel, GenericModel
10+
from aind_data_schema.base import DataModel, GenericModel, GenericModelType
1111

1212

1313
class PulseShape(str, Enum):
@@ -48,16 +48,16 @@ class OptoStimulation(GenericModel):
4848
description="Duration of baseline recording prior to first pulse train",
4949
)
5050
baseline_duration_unit: TimeUnit = Field(default=TimeUnit.S, title="Baseline duration unit")
51-
other_parameters: GenericModel = Field(GenericModel(), title="Other parameters")
51+
other_parameters: Optional[GenericModelType] = Field(default=None, title="Other parameters")
5252
notes: Optional[str] = Field(default=None, title="Notes")
5353

5454

5555
class VisualStimulation(GenericModel):
5656
"""Description of visual stimulus parameters. Provides a high level description of stimulus."""
5757

5858
stimulus_name: str = Field(..., title="Stimulus name")
59-
stimulus_parameters: GenericModel = Field(
60-
GenericModel(),
59+
stimulus_parameters: Optional[GenericModelType] = Field(
60+
default=None,
6161
title="Stimulus parameters",
6262
description="Define and list the parameter values used (e.g. all TF or orientation values)",
6363
)
@@ -82,7 +82,7 @@ class PhotoStimulationGroup(DataModel):
8282
spiral_duration_unit: TimeUnit = Field(default=TimeUnit.S, title="Spiral duration unit")
8383
inter_spiral_interval: Decimal = Field(..., title="Inter trial interval (s)")
8484
inter_spiral_interval_unit: TimeUnit = Field(default=TimeUnit.S, title="Inter trial interval unit")
85-
other_parameters: GenericModel = Field(GenericModel(), title="Other parameters")
85+
other_parameters: Optional[GenericModelType] = Field(default=None, title="Other parameters")
8686
notes: Optional[str] = Field(default=None, title="Notes")
8787

8888

@@ -94,7 +94,7 @@ class PhotoStimulation(GenericModel):
9494
groups: List[PhotoStimulationGroup] = Field(..., title="Groups")
9595
inter_trial_interval: Decimal = Field(..., title="Inter trial interval (s)")
9696
inter_trial_interval_unit: TimeUnit = Field(default=TimeUnit.S, title="Inter trial interval unit")
97-
other_parameters: GenericModel = Field(GenericModel(), title="Other parameters")
97+
other_parameters: Optional[GenericModelType] = Field(default=None, title="Other parameters")
9898
notes: Optional[str] = Field(default=None, title="Notes")
9999

100100

src/aind_data_schema/core/acquisition.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,13 @@
1010
from aind_data_schema.utils.validators import TimeValidation
1111
from pydantic import Field, SkipValidation, model_validator
1212

13-
from aind_data_schema.base import AwareDatetimeWithDefault, DataCoreModel, DataModel, DiscriminatedList, GenericModel
13+
from aind_data_schema.base import (
14+
AwareDatetimeWithDefault,
15+
DataCoreModel,
16+
DataModel,
17+
DiscriminatedList,
18+
GenericModelType
19+
)
1420
from aind_data_schema.components.configs import (
1521
AirPuffConfig,
1622
CatheterConfig,
@@ -83,7 +89,7 @@ class AcquisitionSubjectDetails(DataModel):
8389
class PerformanceMetrics(DataModel):
8490
"""Summary of a StimulusEpoch"""
8591

86-
output_parameters: GenericModel = Field(default=GenericModel(), title="Additional metrics")
92+
output_parameters: Optional[GenericModelType] = Field(default=None, title="Additional metrics")
8793
reward_consumed_during_epoch: Optional[Decimal] = Field(default=None, title="Reward consumed during training (uL)")
8894
reward_consumed_unit: Optional[VolumeUnit] = Field(default=None, title="Reward consumed unit")
8995
trials_total: Optional[int] = Field(default=None, title="Total trials")

src/aind_data_schema/core/model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from aind_data_schema_models.system_architecture import ModelArchitecture
66
from pydantic import Field
77

8-
from aind_data_schema.base import DataCoreModel, DataModel, DiscriminatedList, GenericModel
8+
from aind_data_schema.base import DataCoreModel, DataModel, DiscriminatedList, GenericModelType
99
from aind_data_schema.components.identifiers import Code, Software
1010
from aind_data_schema.core.processing import DataProcess, ProcessName
1111

@@ -59,8 +59,8 @@ class Model(DataCoreModel):
5959
)
6060
architecture: ModelArchitecture = Field(..., title="architecture", description="Model architecture / type of model")
6161
software_framework: Optional[Software] = Field(default=None, title="Software framework")
62-
architecture_parameters: GenericModel = Field(
63-
default=GenericModel(),
62+
architecture_parameters: Optional[GenericModelType] = Field(
63+
default=None,
6464
title="Architecture parameters",
6565
description="Parameters of model architecture, such as input signature or number of layers.",
6666
)

src/aind_data_schema/core/processing.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from aind_data_schema_models.units import MemoryUnit, UnitlessUnit
1010
from pydantic import Field, SkipValidation, ValidationInfo, field_validator, model_validator
1111

12-
from aind_data_schema.base import AwareDatetimeWithDefault, DataCoreModel, DataModel, GenericModel
12+
from aind_data_schema.base import AwareDatetimeWithDefault, DataCoreModel, DataModel, GenericModelType
1313
from aind_data_schema.components.identifiers import Code
1414
from aind_data_schema.components.wrappers import AssetPath
1515
from aind_data_schema.utils.merge import merge_notes, merge_optional_list, merge_process_graph
@@ -71,7 +71,11 @@ class DataProcess(DataModel):
7171
output_path: Optional[AssetPath] = Field(
7272
default=None, title="Output path", description="Path to processing outputs, if stored."
7373
)
74-
output_parameters: GenericModel = Field(default=GenericModel(), description="Output parameters", title="Outputs")
74+
output_parameters: Optional[GenericModelType] = Field(
75+
default=None,
76+
description="Output parameters",
77+
title="Outputs"
78+
)
7579
notes: Optional[str] = Field(default=None, title="Notes", validate_default=True)
7680
resources: Optional[ResourceUsage] = Field(default=None, title="Process resource usage")
7781

tests/test_aind_generic.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,19 @@
22

33
import unittest
44

5-
from pydantic import BaseModel, Field
5+
from pydantic import Field
66

7-
from aind_data_schema.base import DataModel, GenericModel
7+
from aind_data_schema.base import DataModel, GenericModel, GenericModelType
88

99

1010
class GenericContainer(DataModel):
1111
"""Represents a generic container"""
1212

13-
contains_model: GenericModel
14-
contains_dict: GenericModel
13+
contains_model: GenericModelType
14+
contains_dict: GenericModelType
1515

1616

17-
class Bar(BaseModel):
17+
class Bar(GenericModel):
1818
"""Represents a mock model"""
1919

2020
bar: str = Field(default="bar")
@@ -58,7 +58,7 @@ def test_sub_container_from_container(self):
5858
contains_dict={"foodict": 1, "bardict": "bar"},
5959
)
6060
parent_container = GenericContainer(
61-
contains_model=Bar(bar="baz", foo=2).model_dump(),
61+
contains_model=Bar(bar="baz", foo=2),
6262
contains_dict={"foodict": 1, "bardict": "bar"},
6363
)
6464
deserialized = SubGenericContainer.model_validate_json(parent_container.model_dump_json())

0 commit comments

Comments
 (0)