1- import typing
21import warnings
32from types import EllipsisType
43
1110)
1211from pytensor .scalar import ScalarType
1312from pytensor .tensor import (
13+ TensorLike ,
1414 TensorType ,
1515 _as_tensor_variable ,
1616 as_tensor_variable ,
2020
2121
2222try :
23- import xarray as xr
23+ import xarray
2424
25+ DataArray = xarray .DataArray
2526 XARRAY_AVAILABLE = True
2627except ModuleNotFoundError :
2728 XARRAY_AVAILABLE = False
2829
2930from collections .abc import Sequence
30- from typing import Any , Literal , TypeVar
31+ from typing import Any , Literal , TypeAlias , TypeVar , Union
32+ from typing import cast as typing_cast
3133
3234import numpy as np
3335
@@ -95,7 +97,7 @@ def clone(
9597
9698 def filter (self , value , strict = False , allow_downcast = None ):
9799 # XTensorType behaves like TensorType at runtime, so we filter the same way.
98- if XARRAY_AVAILABLE and isinstance (value , xr . DataArray ):
100+ if XARRAY_AVAILABLE and isinstance (value , DataArray ):
99101 value = value .transpose (* self .dims ).values
100102 return TensorType .filter (
101103 self , value , strict = strict , allow_downcast = allow_downcast
@@ -109,7 +111,7 @@ def filter_variable(self, other, allow_convert=True):
109111 if not isinstance (other , Variable ):
110112 # The value is not a Variable: we cast it into
111113 # a Constant of the appropriate Type.
112- if XARRAY_AVAILABLE and isinstance (other , xr . DataArray ):
114+ if XARRAY_AVAILABLE and isinstance (other , DataArray ):
113115 other = other .transpose (* self .dims ).values
114116 other = XTensorConstant (type = self , data = other )
115117
@@ -369,7 +371,7 @@ def __trunc__(self):
369371 @property
370372 def values (self ) -> TensorVariable :
371373 """Convert to a TensorVariable with the same data."""
372- return typing . cast (TensorVariable , px .basic .tensor_from_xtensor (self ))
374+ return typing_cast (TensorVariable , px .basic .tensor_from_xtensor (self ))
373375
374376 # Can't provide property data because that's already taken by Constants!
375377 # data = values
@@ -409,7 +411,7 @@ def shape(self) -> tuple[TensorVariable, ...]:
409411 @property
410412 def size (self ) -> TensorVariable :
411413 """The total number of elements in the variable."""
412- return typing . cast (TensorVariable , variadic_mul (* self .shape ))
414+ return typing_cast (TensorVariable , variadic_mul (* self .shape ))
413415
414416 @property
415417 def dtype (self ) -> str :
@@ -904,6 +906,12 @@ def broadcast_like(self, other, exclude=None):
904906 return self_bcast
905907
906908
909+ if XARRAY_AVAILABLE :
910+ XTensorLike : TypeAlias = Union [TensorLike , XTensorVariable , "DataArray" ]
911+ else :
912+ XTensorLike : TypeAlias = TensorLike | XTensorVariable
913+
914+
907915class XTensorConstantSignature (TensorConstantSignature ):
908916 pass
909917
@@ -939,18 +947,18 @@ def _extract_data_and_dims(
939947 x , dims : None | Sequence [str ] = None
940948) -> tuple [np .ndarray , tuple [str , ...]]:
941949 x_dims : tuple [str , ...]
942- if XARRAY_AVAILABLE and isinstance (x , xr . DataArray ):
950+ if XARRAY_AVAILABLE and isinstance (x , DataArray ):
943951 xarray_dims = x .dims
944952 if not all (isinstance (dim , str ) for dim in xarray_dims ):
945953 raise NotImplementedError (
946954 "DataArray can only be converted to xtensor if all dims are of string type"
947955 )
948- x_dims = tuple (typing . cast ( typing . Iterable [str ], xarray_dims ))
956+ x_dims = tuple (typing_cast ( Sequence [str ], xarray_dims ))
949957 x_data = x .values
950958
951959 if dims is not None and dims != x_dims :
952960 raise ValueError (
953- f"xr .DataArray dims { x_dims } don't match requested specified { dims } . "
961+ f"xarray .DataArray dims { x_dims } don't match requested specified { dims } . "
954962 "Use transpose or rename"
955963 )
956964 else :
@@ -1015,8 +1023,8 @@ def xtensor_shared(
10151023
10161024
10171025if XARRAY_AVAILABLE :
1018- _as_symbolic .register (xr . DataArray , xtensor_constant )
1019- shared_constructor .register (xr . DataArray , xtensor_shared )
1026+ _as_symbolic .register (DataArray , xtensor_constant )
1027+ shared_constructor .register (DataArray , xtensor_shared )
10201028
10211029
10221030def as_xtensor (x , dims : Sequence [str ] | None = None , * , name : str | None = None ):
0 commit comments