Commit c193b25

Implement vectorize_node for XOps
1 parent b07c2c5 commit c193b25
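
In effect, graphs of XOps can now be rewritten by pytensor.graph.replace.vectorize_graph when an input is swapped for one carrying extra batch dimensions. A minimal sketch of the intended usage (the dim names, shapes, and the xarray-style .sum method are illustrative assumptions, not part of this diff):

    from pytensor.graph.replace import vectorize_graph
    from pytensor.xtensor.type import xtensor

    x = xtensor("x", dims=("a",), shape=(3,))
    out = x.sum("a")  # scalar xtensor, dims ()

    # Swap in a replacement carrying an extra "batch" dim;
    # vectorize_graph dispatches to each op's new vectorize_node method
    x_batched = xtensor("x_batched", dims=("batch", "a"), shape=(5, 3))
    out_batched = vectorize_graph(out, replace={x: x_batched})
    assert out_batched.type.dims == ("batch",)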

File tree: 16 files changed, +513 −12 lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -163,7 +163,7 @@ lines-after-imports = 2
 "tests/link/numba/**/test_*.py" = ["E402"]
 "tests/link/pytorch/**/test_*.py" = ["E402"]
 "tests/link/mlx/**/test_*.py" = ["E402"]
-"tests/xtensor/**/test_*.py" = ["E402"]
+"tests/xtensor/**/*.py" = ["E402"]

pytensor/xtensor/basic.py

Lines changed: 32 additions & 0 deletions

@@ -17,6 +17,9 @@ def perform(self, node, inputs, outputs):
     def do_constant_folding(self, fgraph, node):
         return False

+    def vectorize_node(self, node, *new_inputs):
+        raise NotImplementedError(f"Vectorized node not implemented for {self}")
+

 class XTypeCastOp(TypeCastingOp):
     """Base class for Ops that type cast between TensorType and XTensorType.
@@ -27,6 +30,9 @@ class XTypeCastOp(TypeCastingOp):
     def infer_shape(self, fgraph, node, input_shapes):
         return input_shapes

+    def vectorize_node(self, node, *new_inputs):
+        raise NotImplementedError(f"Vectorized node not implemented for {self}")
+

 class TensorFromXTensor(XTypeCastOp):
     __props__ = ()
@@ -42,6 +48,12 @@ def L_op(self, inputs, outs, g_outs):
         [g_out] = g_outs
         return [xtensor_from_tensor(g_out, dims=x.type.dims)]

+    def vectorize_node(self, node, new_x):
+        [old_x] = node.inputs
+        # We transpose batch dims to the left, for consistency with tensor vectorization
+        new_x = new_x.transpose(..., *old_x.dims)
+        return self.make_node(new_x)
+

 tensor_from_xtensor = TensorFromXTensor()

@@ -63,6 +75,16 @@ def L_op(self, inputs, outs, g_outs):
         [g_out] = g_outs
         return [tensor_from_xtensor(g_out)]

+    def vectorize_node(self, node, new_x):
+        [old_x] = node.inputs
+        if new_x.ndim != old_x.ndim:
+            # TODO: Figure out API for this?
+            raise NotImplementedError(
+                f"Vectorization of {self} with batched inputs not implemented, "
+                "as it can't infer new dimension labels"
+            )
+        return self.make_node(new_x)
+

 def xtensor_from_tensor(x, dims, name=None):
     return XTensorFromTensor(dims=dims)(x, name=name)
@@ -85,6 +107,16 @@ def L_op(self, inputs, outs, g_outs):
         [g_out] = g_outs
         return [rename(g_out, dims=x.type.dims)]

+    def vectorize_node(self, node, new_x):
+        [old_x] = node.inputs
+        old_dim_mapping = dict(zip(old_x.dims, self.new_dims, strict=True))
+
+        # new_dims may include a mix of old dims (possibly re-ordered), and new dims which won't be renamed
+        new_dims = tuple(
+            old_dim_mapping.get(new_dim, new_dim) for new_dim in new_x.dims
+        )
+        return type(self)(new_dims).make_node(new_x)
+

 def rename(x, name_dict: dict[str, str] | None = None, **names: str):
     if name_dict is not None:
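
A short sketch of the batch-dims-to-the-left convention that TensorFromXTensor.vectorize_node comments on above (names and shapes are illustrative, not from this diff):

    from pytensor.graph.replace import vectorize_graph
    from pytensor.xtensor.basic import tensor_from_xtensor
    from pytensor.xtensor.type import xtensor

    x = xtensor("x", dims=("a", "b"), shape=(2, 3))
    core = tensor_from_xtensor(x)  # plain tensor, static shape (2, 3)

    # The batched replacement may carry its dims in any order
    x_new = xtensor("x_new", dims=("a", "batch", "b"), shape=(2, 5, 3))

    # vectorize_node transposes to ("batch", "a", "b") before dropping the labels,
    # so the batch dimension ends up leading, as in tensor vectorization
    core_new = vectorize_graph(core, replace={x: x_new})
    assert core_new.type.shape == (5, 2, 3)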

pytensor/xtensor/indexing.py

Lines changed: 27 additions & 0 deletions

@@ -4,13 +4,15 @@
 # https://numpy.org/neps/nep-0021-advanced-indexing.html
 # https://docs.xarray.dev/en/latest/user-guide/indexing.html
 # https://tutorial.xarray.dev/intermediate/indexing/advanced-indexing.html
+from itertools import chain
 from typing import Literal

 from pytensor.graph.basic import Apply, Constant, Variable
 from pytensor.scalar.basic import discrete_dtypes
 from pytensor.tensor.basic import as_tensor
 from pytensor.tensor.type_other import NoneTypeT, SliceType, make_slice
 from pytensor.xtensor.basic import XOp, xtensor_from_tensor
+from pytensor.xtensor.shape import broadcast
 from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor


@@ -195,6 +197,15 @@ def combine_dim_info(idx_dim, idx_dim_shape):
         output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims)
         return Apply(self, [x, *idxs], [output])

+    def vectorize_node(self, node, new_x, *new_idxs):
+        # new_x may have dims in a different order;
+        # we pair each pre-existing dim with its respective index,
+        # while new dims simply get a slice(None)
+        old_x, *_ = node.inputs
+        dims_to_idxs = dict(zip(old_x.dims, new_idxs, strict=True))
+        new_idxs = tuple(dims_to_idxs.get(dim, slice(None)) for dim in new_x.dims)
+        return self.make_node(new_x, *new_idxs)
+

 index = Index()

@@ -226,6 +237,22 @@ def make_node(self, x, y, *idxs):
         out = x.type()
         return Apply(self, [x, y, *idxs], [out])

+    def vectorize_node(self, node, *new_inputs):
+        # If y or the indices have new dimensions, we need to broadcast x
+        exclude: set[str] = set(
+            chain.from_iterable(old_inp.dims for old_inp in node.inputs)
+        )
+        for new_inp, old_inp in zip(new_inputs, node.inputs, strict=True):
+            # Note: This check may be too conservative
+            if invalid_new_dims := ((set(new_inp.dims) - set(old_inp.dims)) & exclude):
+                raise NotImplementedError(
+                    f"Vectorize of {self} is undefined because one of the inputs {new_inp} "
+                    f"has new dimensions that were present in the old inputs: {sorted(invalid_new_dims)}"
+                )
+        new_x, *_ = broadcast(*new_inputs, exclude=tuple(exclude))
+        _, new_y, *new_idxs = new_inputs
+        return self.make_node(new_x, new_y, *new_idxs)
+

 index_assignment = IndexUpdate("set")
 index_increment = IndexUpdate("inc")
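
A sketch of how Index.vectorize_node pairs dims with indices, assuming the xarray-style .isel method (names are illustrative):

    from pytensor.graph.replace import vectorize_graph
    from pytensor.xtensor.type import xtensor

    x = xtensor("x", dims=("a",), shape=(4,))
    out = x.isel(a=0)  # an Index node under the hood

    # The replacement's new "batch" dim is paired with slice(None),
    # while the pre-existing "a" dim keeps its original index
    x_new = xtensor("x_new", dims=("batch", "a"), shape=(2, 4))
    out_new = vectorize_graph(out, replace={x: x_new})
    assert out_new.type.dims == ("batch",)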

pytensor/xtensor/reduction.py

Lines changed: 6 additions & 0 deletions

@@ -46,6 +46,9 @@ def make_node(self, x):
         output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims)
         return Apply(self, [x], [output])

+    def vectorize_node(self, node, new_x):
+        return self.make_node(new_x)
+

 def _process_user_dims(x: XTensorVariable, dim: REDUCE_DIM) -> Sequence[str]:
     if isinstance(dim, str):
@@ -117,6 +120,9 @@ def make_node(self, x):
         out = x.type()
         return Apply(self, [x], [out])

+    def vectorize_node(self, node, new_x):
+        return self.make_node(new_x)
+

 def cumreduce(x, dim: REDUCE_DIM, *, binary_op):
     x = as_xtensor(x)
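
Both reduction ops can simply re-make the node because reductions address dimensions by name, so a batch dimension on the input flows through untouched. A sketch, assuming a .cumsum method that wraps cumreduce (names are illustrative):

    from pytensor.graph.replace import vectorize_graph
    from pytensor.xtensor.type import xtensor

    x = xtensor("x", dims=("a",), shape=(3,))
    out = x.cumsum("a")

    x_new = xtensor("x_new", dims=("batch", "a"), shape=(5, 3))
    out_new = vectorize_graph(out, replace={x: x_new})
    assert out_new.type.dims == ("batch", "a")  # batch dim passes through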

pytensor/xtensor/shape.py

Lines changed: 48 additions & 0 deletions

@@ -68,6 +68,9 @@ def make_node(self, x):
         )
         return Apply(self, [x], [output])

+    def vectorize_node(self, node, new_x):
+        return self.make_node(new_x)
+

 def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str]):
     if dim is not None:
@@ -146,6 +149,19 @@ def make_node(self, x, *unstacked_length):
         )
         return Apply(self, [x, *unstacked_lengths], [output])

+    def vectorize_node(self, node, new_x, *new_unstacked_length):
+        if len(new_unstacked_length) != len(self.unstacked_dims):
+            raise NotImplementedError(
+                f"Vectorization of {self} with additional unstacked_length not implemented, "
+                "as it can't infer new dimension labels"
+            )
+        new_unstacked_length = [ul.squeeze() for ul in new_unstacked_length]
+        if not all(ul.type.ndim == 0 for ul in new_unstacked_length):
+            raise NotImplementedError(
+                f"Vectorization of {self} with batched unstacked_length not implemented"
+            )
+        return self.make_node(new_x, *new_unstacked_length)
+

 def unstack(x, dim: dict[str, dict[str, int]] | None = None, **dims: dict[str, int]):
     if dim is not None:
@@ -189,6 +205,11 @@ def make_node(self, x):
         )
         return Apply(self, [x], [output])

+    def vectorize_node(self, node, new_x):
+        old_dims = self.dims
+        new_dims = tuple(dim for dim in new_x.dims if dim not in old_dims)
+        return type(self)(dims=(*new_dims, *old_dims)).make_node(new_x)
+

 def transpose(
     x,
@@ -302,6 +323,9 @@ def make_node(self, *inputs):
         output = xtensor(dtype=dtype, dims=dims, shape=shape)
         return Apply(self, inputs, [output])

+    def vectorize_node(self, node, *new_inputs):
+        return self.make_node(*new_inputs)
+

 def concat(xtensors, dim: str):
     """Concatenate a sequence of XTensorVariables along a specified dimension.
@@ -383,6 +407,9 @@ def make_node(self, x):
         )
         return Apply(self, [x], [out])

+    def vectorize_node(self, node, new_x):
+        return self.make_node(new_x)
+

 def squeeze(x, dim: str | Sequence[str] | None = None):
     """Remove dimensions of size 1 from an XTensorVariable."""
@@ -442,6 +469,14 @@ def make_node(self, x, size):
         )
         return Apply(self, [x, size], [out])

+    def vectorize_node(self, node, new_x, new_size):
+        new_size = new_size.squeeze()
+        if new_size.type.ndim != 0:
+            raise NotImplementedError(
+                f"Vectorization of {self} with batched new_size not implemented"
+            )
+        return self.make_node(new_x, new_size)
+

 def expand_dims(x, dim=None, axis=None, **dim_kwargs):
     """Add one or more new dimensions to an XTensorVariable."""
@@ -537,6 +572,19 @@ def make_node(self, *inputs):

         return Apply(self, inputs, outputs)

+    def vectorize_node(self, node, *new_inputs):
+        if exclude_set := set(self.exclude):
+            for new_x, old_x in zip(new_inputs, node.inputs, strict=True):
+                if invalid_excluded := (
+                    (set(new_x.dims) - set(old_x.dims)) & exclude_set
+                ):
+                    raise NotImplementedError(
+                        f"Vectorize of {self} is undefined because one of the inputs {new_x} "
+                        f"has an excluded dimension {sorted(invalid_excluded)} that it did not have before."
+                    )
+
+        return self.make_node(*new_inputs)
+

 def broadcast(
     *args, exclude: str | Sequence[str] | None = None
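
Among the shape ops, Transpose is the interesting case: its vectorize_node prepends any new dims to the stored transpose order. A sketch (names are illustrative):

    from pytensor.graph.replace import vectorize_graph
    from pytensor.xtensor.type import xtensor

    x = xtensor("x", dims=("a", "b"), shape=(2, 3))
    out = x.transpose("b", "a")

    x_new = xtensor("x_new", dims=("batch", "a", "b"), shape=(5, 2, 3))
    out_new = vectorize_graph(out, replace={x: x_new})
    # the new "batch" dim is prepended to the original order ("b", "a")
    assert out_new.type.dims == ("batch", "b", "a")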

pytensor/xtensor/type.py

Lines changed: 1 addition & 1 deletion

@@ -1044,7 +1044,7 @@ def as_xtensor(x, dims: Sequence[str] | None = None, *, name: str | None = None)

     if isinstance(x, Variable):
         if isinstance(x.type, XTensorType):
-            if (dims is None) or (x.type.dims == dims):
+            if (dims is None) or (x.type.dims == tuple(dims)):
                 return x
             else:
                 raise ValueError(
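
This one-line fix lets as_xtensor accept dims as any sequence; previously a list never compared equal to the tuple stored in x.type.dims, so a correct dims argument could raise a spurious ValueError. A sketch:

    from pytensor.xtensor.type import as_xtensor, xtensor

    x = xtensor("x", dims=("a", "b"), shape=(2, 3))
    assert as_xtensor(x, dims=["a", "b"]) is x  # previously raised ValueError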

pytensor/xtensor/vectorization.py

Lines changed: 69 additions & 2 deletions

@@ -6,6 +6,7 @@
 from pytensor import scalar as ps
 from pytensor import shared
 from pytensor.graph import Apply, Op
+from pytensor.graph.replace import _vectorize_node
 from pytensor.scalar import discrete_dtypes
 from pytensor.tensor import tensor
 from pytensor.tensor.random.op import RNGConsumerOp
@@ -14,8 +15,11 @@
     get_static_shape_from_size_variables,
 )
 from pytensor.utils import unzip
-from pytensor.xtensor.basic import XOp
-from pytensor.xtensor.type import XTensorVariable, as_xtensor, xtensor
+from pytensor.xtensor.basic import (
+    XOp,
+    XTypeCastOp,
+)
+from pytensor.xtensor.type import XTensorType, XTensorVariable, as_xtensor, xtensor


 def combine_dims_and_shape(
@@ -69,6 +73,9 @@ def make_node(self, *inputs):
         ]
         return Apply(self, inputs, outputs)

+    def vectorize_node(self, node, *new_inputs):
+        return self.make_node(*new_inputs)
+

 class XBlockwise(XOp):
     __props__ = ("core_op", "core_dims")
@@ -136,6 +143,9 @@ def make_node(self, *inputs):
         ]
         return Apply(self, inputs, outputs)

+    def vectorize_node(self, node, *new_inputs):
+        return self.make_node(*new_inputs)
+

 class XRV(XOp, RNGConsumerOp):
     """Wrapper for RandomVariable operations that follows xarray-like broadcasting semantics.
@@ -283,3 +293,60 @@ def make_node(self, rng, *extra_dim_lengths_and_params):
         )

         return Apply(self, [rng, *extra_dim_lengths, *params], [rng.type(), out])
+
+    def vectorize_node(self, node, *new_inputs):
+        if len(new_inputs) != len(node.inputs):
+            # TODO: Figure out API to allow this
+            raise NotImplementedError(
+                f"Vectorization of {self} with additional extra_dim_lengths not implemented, "
+                "as it can't infer new dimension labels"
+            )
+        new_rng, *new_extra_dim_lengths_and_params = new_inputs
+        new_extra_dim_lengths, new_params = (
+            new_extra_dim_lengths_and_params[: len(self.extra_dims)],
+            new_extra_dim_lengths_and_params[len(self.extra_dims) :],
+        )
+
+        new_extra_dim_lengths = [dl.squeeze() for dl in new_extra_dim_lengths]
+        if not all(dl.type.ndim == 0 for dl in new_extra_dim_lengths):
+            raise NotImplementedError(
+                f"Vectorization of {self} with batched extra_dim_lengths not implemented"
+            )
+
+        return self.make_node(new_rng, *new_extra_dim_lengths, *new_params)
+
+
+@_vectorize_node.register(XOp)
+@_vectorize_node.register(XTypeCastOp)
+def vectorize_xop(op: XOp, node, *new_inputs) -> Apply:
+    old_inp_dims = [
+        inp.dims for inp in node.inputs if isinstance(inp.type, XTensorType)
+    ]
+    old_out_dims = [
+        out.dims for out in node.outputs if isinstance(out.type, XTensorType)
+    ]
+    all_old_dims_set = set(chain.from_iterable((*old_inp_dims, *old_out_dims)))
+
+    for new_inp, old_inp in zip(new_inputs, node.inputs, strict=True):
+        if not (
+            isinstance(new_inp.type, XTensorType)
+            and isinstance(old_inp.type, XTensorType)
+        ):
+            continue
+
+        old_dims_set = set(old_inp.dims)
+        new_dims_set = set(new_inp.dims)

+        # Validate that new inputs didn't drop pre-existing dims
+        if missing_dims := old_dims_set - new_dims_set:
+            raise ValueError(
+                f"Vectorized input {new_inp} is missing pre-existing dims: {sorted(missing_dims)}"
+            )
+        # Or gain new dimensions that were already present in the graph
+        if new_core_dims := ((new_dims_set - old_dims_set) & all_old_dims_set):
+            raise ValueError(
+                f"Vectorized input {new_inp} has new dimensions that were present in the original graph: {new_core_dims}"
+            )
+
+    # TODO: Once we stop having to return an Apply, automatically align batch dimensions in the order they were first seen
+    return op.vectorize_node(node, *new_inputs)
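
The vectorize_xop dispatcher is what vectorize_graph reaches for every XOp (and for XTypeCastOp, which does not subclass XOp, hence the double registration); it runs the shared dim validation before deferring to each op's vectorize_node. A sketch of one guarded failure mode (names are illustrative):

    from pytensor.graph.replace import vectorize_graph
    from pytensor.xtensor.type import xtensor

    x = xtensor("x", dims=("a",), shape=(3,))
    out = x.sum("a")

    # A replacement that drops the pre-existing "a" dim is rejected
    bad = xtensor("bad", dims=("batch",), shape=(5,))
    try:
        vectorize_graph(out, replace={x: bad})
    except ValueError as err:
        print(err)  # Vectorized input bad is missing pre-existing dims: ['a']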
