Skip to content

Commit 7614754

Browse files
authored
Add map datatype to the Ibis engine implementation (#2206)
* Add map datatype to the Ibis engine implementation
  Signed-off-by: Deepyaman Datta <deepyaman.datta@utexas.edu>
* Limit generic bases that get looked up in registry
  Signed-off-by: Deepyaman Datta <deepyaman.datta@utexas.edu>
* Revert to map implementation with default datatype
  Signed-off-by: Deepyaman Datta <deepyaman.datta@utexas.edu>
* Unit test equivalence of newly-added map datatypes
  Signed-off-by: Deepyaman Datta <deepyaman.datta@utexas.edu>
* Fix the module name lookup logic for generic bases
  Signed-off-by: Deepyaman Datta <deepyaman.datta@utexas.edu>
* Undo reversion to default datatypes implementation
  This reverts commit b058433.
  Signed-off-by: Deepyaman Datta <deepyaman.datta@utexas.edu>
* Test checking all parametrizations of `Map` dtypes
  Signed-off-by: Deepyaman Datta <deepyaman.datta@utexas.edu>
* Add test cases corresponding to scenarios in issue
  Signed-off-by: Deepyaman Datta <deepyaman.datta@utexas.edu>
---------
Signed-off-by: Deepyaman Datta <deepyaman.datta@utexas.edu>
1 parent 8b87d0f commit 7614754

File tree

5 files changed

+109
-12
lines changed

5 files changed

+109
-12
lines changed

pandera/engines/engine.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -260,11 +260,13 @@ def dtype(cls: "Engine", data_type: Any) -> DataType:
260260
or ((NamedTuple,) if _is_namedtuple(data_type) else ())
261261
or typing_inspect.get_generic_bases(data_type)
262262
)
263-
if datatype_generic_bases:
264-
equivalent_data_type = None
265-
for base in datatype_generic_bases:
266-
equivalent_data_type = registry.get_equivalent(base)
267-
break
263+
if datatype_generic_bases and inspect.getmodule(
264+
base := datatype_generic_bases[0]
265+
).__name__ in { # type: ignore[union-attr]
266+
*sys.stdlib_module_names,
267+
"typing_extensions",
268+
}:
269+
equivalent_data_type = registry.get_equivalent(base)
268270
if equivalent_data_type is None:
269271
raise TypeError(
270272
f"Type '{data_type}' not understood by {cls.__name__}."

pandera/engines/ibis_engine.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def check(
5858
data_container: ibis.Table | None = None,
5959
) -> Union[bool, Iterable[bool]]:
6060
try:
61-
return self.type == pandera_dtype.type
61+
return self.type == Engine.dtype(pandera_dtype).type
6262
except TypeError:
6363
return False
6464

@@ -484,3 +484,38 @@ def from_parametrized_dtype(cls, ibis_dtype: dt.Interval):
484484
"""Convert a :class:`dt.Interval` to a Pandera
485485
:class:`~pandera.engines.ibis_engine.Timedelta`."""
486486
return cls(unit=ibis_dtype.unit)
487+
488+
489+
###############################################################################
490+
# nested
491+
###############################################################################
492+
493+
494+
@Engine.register_dtype(
495+
equivalents=[
496+
dict,
497+
dt.Map,
498+
]
499+
)
500+
@immutable(init=True)
501+
class Map(DataType):
502+
"""Semantic representation of a :class:`dt.Map`."""
503+
504+
type: dt.Map
505+
506+
def __init__(
507+
self,
508+
key_type: dt.DataType | None = None,
509+
value_type: dt.DataType | None = None,
510+
):
511+
if key_type is not None and value_type is not None:
512+
object.__setattr__(self, "type", dt.Map(key_type, value_type))
513+
514+
@classmethod
515+
def from_parametrized_dtype(cls, ibis_dtype: dt.Map):
516+
"""Convert a :class:`dt.Map` to a Pandera
517+
:class:`~pandera.engines.ibis_engine.Map`."""
518+
return cls(
519+
key_type=ibis_dtype.key_type,
520+
value_type=ibis_dtype.value_type,
521+
)

pandera/engines/polars_engine.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -623,7 +623,7 @@ def __init__(
623623
elif shape is not None:
624624
kwargs["shape"] = shape
625625

626-
if inner:
626+
if inner is not None:
627627
object.__setattr__(self, "type", pl.Array(inner=inner, **kwargs))
628628

629629
@classmethod
@@ -655,7 +655,7 @@ def __init__(
655655
self,
656656
inner: PolarsDataType | None = None,
657657
) -> None:
658-
if inner:
658+
if inner is not None:
659659
object.__setattr__(self, "type", pl.List(inner=inner))
660660

661661
@classmethod
@@ -674,7 +674,7 @@ def __init__(
674674
self,
675675
fields: Union[Sequence[pl.Field], SchemaDict] | None = None,
676676
) -> None:
677-
if fields:
677+
if fields is not None:
678678
object.__setattr__(self, "type", pl.Struct(fields=fields))
679679

680680
@classmethod

tests/ibis/test_ibis_dtypes.py

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55
import pytest
66
from hypothesis import given, settings
77
from hypothesis import strategies as st
8+
from ibis import _
89
from polars.testing import assert_frame_equal
910
from polars.testing.parametric import dataframes
1011

12+
import pandera.ibis as pa
1113
from pandera.engines import ibis_engine as ie
1214

1315
NUMERIC_TYPES = [
@@ -22,6 +24,10 @@
2224
ie.String,
2325
]
2426

27+
SPECIAL_TYPES = [
28+
ie.Map,
29+
]
30+
2531
ALL_TYPES = NUMERIC_TYPES + TEMPORAL_TYPES + OTHER_TYPES
2632

2733

@@ -71,7 +77,7 @@ def test_coerce_cast(from_dtype, to_dtype, strategy, data):
7177
assert dtype == to_dtype.type
7278

7379

74-
@pytest.mark.parametrize("dtype", ALL_TYPES)
80+
@pytest.mark.parametrize("dtype", ALL_TYPES + SPECIAL_TYPES)
7581
def test_check_not_equivalent(dtype):
7682
"""Test that check() rejects non-equivalent dtypes."""
7783
if str(ie.Engine.dtype(dtype)) == "string":
@@ -82,7 +88,7 @@ def test_check_not_equivalent(dtype):
8288
assert not actual_dtype.check(expected_dtype)
8389

8490

85-
@pytest.mark.parametrize("dtype", ALL_TYPES)
91+
@pytest.mark.parametrize("dtype", ALL_TYPES + SPECIAL_TYPES)
8692
def test_check_equivalent(dtype):
8793
"""Test that check() accepts equivalent dtypes."""
8894
actual_dtype = ie.Engine.dtype(dtype)
@@ -130,3 +136,57 @@ def test_ibis_decimal_from_parametrized_dtype(ibis_dtype, expected_dtype):
130136

131137
assert pandera_dtype.precision == expected_dtype.precision
132138
assert pandera_dtype.scale == expected_dtype.scale
139+
140+
141+
@pytest.mark.parametrize("key_dtype", ALL_TYPES)
142+
@pytest.mark.parametrize("value_dtype", ALL_TYPES)
143+
def test_ibis_map_nested_type(key_dtype, value_dtype):
144+
ibis_dtype = dt.Map(key_dtype.type(), value_dtype.type())
145+
pandera_dtype = ie.Engine.dtype(ibis_dtype)
146+
147+
assert pandera_dtype.check(ibis_dtype)
148+
assert pandera_dtype.check(pandera_dtype)
149+
150+
151+
def test_ibis_map_type():
152+
# https://github.com/unionai-oss/pandera/issues/2201
153+
data = {
154+
"id": ["01", "02", "03"],
155+
"key_col": [
156+
[1, 2, 3],
157+
[
158+
1,
159+
2,
160+
],
161+
None,
162+
],
163+
"value_col": [
164+
["value_1A", "value_1B", "value_1C"],
165+
["value_2A", "value_2B"],
166+
None,
167+
],
168+
}
169+
df = (
170+
ibis.memtable(data)
171+
.mutate(dict_col=ibis.map(_.key_col, _.value_col))
172+
.drop("key_col", "value_col")
173+
)
174+
175+
class ValidateSchema(pa.DataFrameModel):
176+
id: str = pa.Field(nullable=False)
177+
dict_col: ibis.dtype("map<int64, string>") = pa.Field(nullable=True)
178+
179+
ValidateSchema.validate(df)
180+
181+
class ValidateSchema(pa.DataFrameModel):
182+
id: str = pa.Field(nullable=False)
183+
dict_col: dt.Map = pa.Field(nullable=True)
184+
185+
ValidateSchema.validate(df)
186+
187+
class ValidateSchema(pa.DataFrameModel):
188+
id: str = pa.Field(nullable=False)
189+
dict_col: ibis.dtype("map<string, string>") = pa.Field(nullable=True)
190+
191+
with pytest.raises(pa.errors.SchemaError):
192+
ValidateSchema.validate(df)

tests/polars/test_polars_dtypes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ def test_polars_decimal_from_parametrized_dtype(polars_dtype, expected_dtype):
354354
)
355355
@given(st.integers(min_value=2, max_value=10))
356356
@settings(max_examples=5)
357-
def test_polars_nested_array_type_check(inner_dtype_cls, width):
357+
def test_polars_array_nested_type(inner_dtype_cls, width):
358358
polars_dtype = pl.Array(inner_dtype_cls(), width)
359359
pandera_dtype = pe.Engine.dtype(polars_dtype)
360360

0 commit comments

Comments
 (0)