Skip to content

Commit 91c5086

Browse files
committed
Add XTensorLike type alias
1 parent c647bb2 commit 91c5086

File tree

1 file changed

+20
-12
lines changed

1 file changed

+20
-12
lines changed

pytensor/xtensor/type.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import typing
21
import warnings
32
from types import EllipsisType
43

@@ -11,6 +10,7 @@
1110
)
1211
from pytensor.scalar import ScalarType
1312
from pytensor.tensor import (
13+
TensorLike,
1414
TensorType,
1515
_as_tensor_variable,
1616
as_tensor_variable,
@@ -20,14 +20,16 @@
2020

2121

2222
try:
23-
import xarray as xr
23+
import xarray
2424

25+
DataArray = xarray.DataArray
2526
XARRAY_AVAILABLE = True
2627
except ModuleNotFoundError:
2728
XARRAY_AVAILABLE = False
2829

2930
from collections.abc import Sequence
30-
from typing import Any, Literal, TypeVar
31+
from typing import Any, Literal, TypeAlias, TypeVar, Union
32+
from typing import cast as typing_cast
3133

3234
import numpy as np
3335

@@ -95,7 +97,7 @@ def clone(
9597

9698
def filter(self, value, strict=False, allow_downcast=None):
9799
# XTensorType behaves like TensorType at runtime, so we filter the same way.
98-
if XARRAY_AVAILABLE and isinstance(value, xr.DataArray):
100+
if XARRAY_AVAILABLE and isinstance(value, DataArray):
99101
value = value.transpose(*self.dims).values
100102
return TensorType.filter(
101103
self, value, strict=strict, allow_downcast=allow_downcast
@@ -109,7 +111,7 @@ def filter_variable(self, other, allow_convert=True):
109111
if not isinstance(other, Variable):
110112
# The value is not a Variable: we cast it into
111113
# a Constant of the appropriate Type.
112-
if XARRAY_AVAILABLE and isinstance(other, xr.DataArray):
114+
if XARRAY_AVAILABLE and isinstance(other, DataArray):
113115
other = other.transpose(*self.dims).values
114116
other = XTensorConstant(type=self, data=other)
115117

@@ -369,7 +371,7 @@ def __trunc__(self):
369371
@property
370372
def values(self) -> TensorVariable:
371373
"""Convert to a TensorVariable with the same data."""
372-
return typing.cast(TensorVariable, px.basic.tensor_from_xtensor(self))
374+
return typing_cast(TensorVariable, px.basic.tensor_from_xtensor(self))
373375

374376
# Can't provide property data because that's already taken by Constants!
375377
# data = values
@@ -409,7 +411,7 @@ def shape(self) -> tuple[TensorVariable, ...]:
409411
@property
410412
def size(self) -> TensorVariable:
411413
"""The total number of elements in the variable."""
412-
return typing.cast(TensorVariable, variadic_mul(*self.shape))
414+
return typing_cast(TensorVariable, variadic_mul(*self.shape))
413415

414416
@property
415417
def dtype(self) -> str:
@@ -904,6 +906,12 @@ def broadcast_like(self, other, exclude=None):
904906
return self_bcast
905907

906908

909+
if XARRAY_AVAILABLE:
910+
XTensorLike: TypeAlias = Union[TensorLike, XTensorVariable, "DataArray"]
911+
else:
912+
XTensorLike: TypeAlias = TensorLike | XTensorVariable
913+
914+
907915
class XTensorConstantSignature(TensorConstantSignature):
908916
pass
909917

@@ -939,18 +947,18 @@ def _extract_data_and_dims(
939947
x, dims: None | Sequence[str] = None
940948
) -> tuple[np.ndarray, tuple[str, ...]]:
941949
x_dims: tuple[str, ...]
942-
if XARRAY_AVAILABLE and isinstance(x, xr.DataArray):
950+
if XARRAY_AVAILABLE and isinstance(x, DataArray):
943951
xarray_dims = x.dims
944952
if not all(isinstance(dim, str) for dim in xarray_dims):
945953
raise NotImplementedError(
946954
"DataArray can only be converted to xtensor if all dims are of string type"
947955
)
948-
x_dims = tuple(typing.cast(typing.Iterable[str], xarray_dims))
956+
x_dims = tuple(typing_cast(Sequence[str], xarray_dims))
949957
x_data = x.values
950958

951959
if dims is not None and dims != x_dims:
952960
raise ValueError(
953-
f"xr.DataArray dims {x_dims} don't match requested specified {dims}. "
961+
f"xarray.DataArray dims {x_dims} don't match requested specified {dims}. "
954962
"Use transpose or rename"
955963
)
956964
else:
@@ -1015,8 +1023,8 @@ def xtensor_shared(
10151023

10161024

10171025
if XARRAY_AVAILABLE:
1018-
_as_symbolic.register(xr.DataArray, xtensor_constant)
1019-
shared_constructor.register(xr.DataArray, xtensor_shared)
1026+
_as_symbolic.register(DataArray, xtensor_constant)
1027+
shared_constructor.register(DataArray, xtensor_shared)
10201028

10211029

10221030
def as_xtensor(x, dims: Sequence[str] | None = None, *, name: str | None = None):

0 commit comments

Comments
 (0)