Skip to content

Commit cd30435

Browse files
authored
fix: recursive_get_all_names should only get names from DataModel subclasses (#1722)
1 parent afc045e commit cd30435

File tree

2 files changed

+30
-4
lines changed

2 files changed

+30
-4
lines changed

src/aind_data_schema/utils/validators.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ def recursive_coord_system_check(data, coordinate_system_name: Optional[str], ax
238238

239239

240240
def recursive_get_all_names(obj: Any) -> List[str]:
241-
"""Recursively extract all 'name' fields from an object and its nested fields."""
241+
"""Recursively extract all 'name' fields from a DataModel object and its nested fields."""
242242
names = []
243243

244244
if obj is None or isinstance(obj, Enum): # Skip None and Enums
@@ -249,8 +249,13 @@ def recursive_get_all_names(obj: Any) -> List[str]:
249249
names.extend(recursive_get_all_names(item))
250250

251251
elif hasattr(obj, "__dict__"): # Handle objects (including Pydantic models)
252+
if not hasattr(obj, "object_type"):
253+
# All DataModel objects should have an object_type attribute
254+
return names
252255
if hasattr(obj, "name") and isinstance(obj.name, str): # Ensure name is a string
253256
names.append(obj.name)
257+
258+
# Continue recursion into fields
254259
for field_value in vars(obj).values(): # Use vars() for robustness
255260
names.extend(recursive_get_all_names(field_value))
256261

tests/test_utils_validators.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -291,14 +291,14 @@ class MockEnum(Enum):
291291
VALUE2 = "value2"
292292

293293

294-
class NestedModel(BaseModel):
294+
class NestedModel(DataModel):
295295
"""Nested model for testing"""
296296

297297
name: str
298298
value: int
299299

300300

301-
class ComplexModel(BaseModel):
301+
class ComplexModel(DataModel):
302302
"""Complex model for testing"""
303303

304304
name: str
@@ -307,13 +307,27 @@ class ComplexModel(BaseModel):
307307
enum_field: MockEnum
308308

309309

310-
class ComplexParentModel(BaseModel):
310+
class ComplexParentModel(DataModel):
311311
"""Multi-level model for testing"""
312312

313313
name: str
314314
child: ComplexModel
315315

316316

317+
class NonDataModel(BaseModel):
318+
"""BaseModel (not DataModel) for testing that names are not extracted from non-DataModel objects"""
319+
320+
name: str
321+
value: int
322+
323+
324+
class ModelWithNonDataModelChild(DataModel):
325+
"""DataModel containing a non-DataModel child"""
326+
327+
name: str
328+
non_data_child: NonDataModel
329+
330+
317331
class TestRecursiveGetAllNames(unittest.TestCase):
318332
"""Tests for recursive_get_all_names function"""
319333

@@ -380,6 +394,13 @@ def test_multi_level(self):
380394
result = recursive_get_all_names(parent_model)
381395
self.assertEqual(result, ["parent_name", "complex_name", "nested_name"])
382396

397+
def test_non_data_model_child_ignored(self):
398+
"""Test that names from non-DataModel objects (without object_type) are not extracted"""
399+
non_data_child = NonDataModel(name="should_be_ignored", value=42)
400+
model = ModelWithNonDataModelChild(name="parent_name", non_data_child=non_data_child)
401+
result = recursive_get_all_names(model)
402+
self.assertEqual(result, ["parent_name"])
403+
383404

384405
class TestRecursiveCheckPaths(unittest.TestCase):
385406
"""Tests for recursive_check_paths function"""

0 commit comments

Comments
 (0)