
Commit 9a70afe: Implement dim-aware vectorize_graph
Parent: c193b25

5 files changed: 200 additions & 61 deletions

pytensor/graph/replace.py
Lines changed: 7 additions & 0 deletions

@@ -283,6 +283,13 @@ def vectorize_graph(
     # [array([-10., -11.]), array([10., 11.])]

     """
+    # TODO: Move this to tensor.vectorize, and make this helper type agnostic.
+    #
+    # This helper may dispatch to tensor.vectorize_graph or xtensor.vectorize_graph depending on the replacement types
+    # The behavior is distinct, because tensor vectorization depends on axis-position while xtensor depends on dimension labels
+    #
+    # xtensor.vectorize_graph will be able to handle batched inner tensor operations, while tensor.vectorize_graph won't,
+    # as it is by design unaware of xtensors and their semantics.
     if isinstance(outputs, Sequence):
         seq_outputs = outputs
     else:
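
The distinction the TODO draws can be seen with the existing tensor-level entry point. A minimal sketch of the positional behavior (only pytensor.graph.replace.vectorize_graph is exercised here; the dim-aware counterpart mentioned in the comment is not):

import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph

# Tensor vectorization is positional: replacing the vector input with a
# matrix prepends a batch axis, and outputs are batched along axis 0.
x = pt.vector("x")
xs = pt.matrix("xs")
batched_sum = vectorize_graph(x.sum(), replace={x: xs})
# batched_sum.eval({xs: [[1.0, 2.0], [3.0, 4.0]]}) -> array([3., 7.])

# An xtensor-aware vectorize_graph would instead key batching off a new
# dimension label, independent of axis order; that is the dispatch the
# TODO above anticipates.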

pytensor/xtensor/basic.py
Lines changed: 15 additions & 13 deletions

@@ -17,7 +17,7 @@ def perform(self, node, inputs, outputs):
     def do_constant_folding(self, fgraph, node):
         return False

-    def vectorize_node(self, node, *new_inputs):
+    def vectorize_node(self, node, *new_inputs, new_dim: str | None):
         raise NotImplementedError(f"Vectorized node not implemented for {self}")


@@ -30,7 +30,7 @@ class XTypeCastOp(TypeCastingOp):
     def infer_shape(self, fgraph, node, input_shapes):
         return input_shapes

-    def vectorize_node(self, node, *new_inputs):
+    def vectorize_node(self, node, *new_inputs, new_dim: str | None):
         raise NotImplementedError(f"Vectorized node not implemented for {self}")


@@ -48,11 +48,11 @@ def L_op(self, inputs, outs, g_outs):
         [g_out] = g_outs
         return [xtensor_from_tensor(g_out, dims=x.type.dims)]

-    def vectorize_node(self, node, new_x):
+    def vectorize_node(self, node, new_x, new_dim):
         [old_x] = node.inputs
         # We transpose batch dims to the left, for consistency with tensor vectorization
         new_x = new_x.transpose(..., *old_x.dims)
-        return self.make_node(new_x)
+        return [self(new_x)]


 tensor_from_xtensor = TensorFromXTensor()

@@ -75,15 +75,17 @@ def L_op(self, inputs, outs, g_outs):
         [g_out] = g_outs
         return [tensor_from_xtensor(g_out)]

-    def vectorize_node(self, node, new_x):
+    def vectorize_node(self, node, new_x, new_dim):
         [old_x] = node.inputs
         if new_x.ndim != old_x.ndim:
-            # TODO: Figure out API for this?
-            raise NotImplementedError(
-                f"Vectorization of {self} with batched inputs not implemented, "
-                "as it can't infer new dimension labels"
-            )
-        return self().make_node(new_x)
+            if new_dim is None:
+                raise NotImplementedError(
+                    f"Vectorization of {self} is not well defined because it can't infer the new dimension labels. "
+                    f"Use pytensor.xtensor.vectorization.vectorize_graph instead."
+                )
+            return [type(self)(dims=(new_dim, *self.dims))(new_x)]
+        else:
+            return [self(new_x)]


 def xtensor_from_tensor(x, dims, name=None):

@@ -107,15 +109,15 @@ def L_op(self, inputs, outs, g_outs):
         [g_out] = g_outs
         return [rename(g_out, dims=x.type.dims)]

-    def vectorize_node(self, node, new_x):
+    def vectorize_node(self, node, new_x, new_dim):
         [old_x] = node.inputs
         old_dim_mapping = dict(zip(old_x.dims, self.new_dims, strict=True))

         # new_dims may include a mix of old dims (possibly re-ordered), and new dims which won't be renamed
         new_dims = tuple(
             old_dim_mapping.get(new_dim, new_dim) for new_dim in new_x.dims
         )
-        return type(self)(new_dims).make_node(new_x)
+        return [type(self)(new_dims)(new_x)]


 def rename(x, name_dict: dict[str, str] | None = None, **names: str):
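
Two things change in every implementation above: vectorize_node now returns a list of output variables ([self(new_x)]) rather than an Apply node, and it accepts a new_dim keyword so the caller can name the batch dimension that XTensorFromTensor cannot infer on its own. A minimal sketch of why the label is needed (dimension names are hypothetical; xtensor_from_tensor is the helper defined in this file):

import pytensor.tensor as pt
from pytensor.xtensor.basic import xtensor_from_tensor

# A graph that labels a vector's only axis "a".
x = pt.vector("x")
xx = xtensor_from_tensor(x, dims=("a",))

# Replacing x with a matrix adds a leading axis that needs a label the op
# cannot invent. Per the diff: with new_dim="batch" the op is rebuilt as
# XTensorFromTensor(dims=("batch", "a")); with new_dim=None it raises
# NotImplementedError and points to the dim-aware vectorize_graph instead.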

pytensor/xtensor/reduction.py
Lines changed: 4 additions & 4 deletions

@@ -46,8 +46,8 @@ def make_node(self, x):
         output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims)
         return Apply(self, [x], [output])

-    def vectorize_node(self, node, new_x):
-        return self.make_node(new_x)
+    def vectorize_node(self, node, new_x, new_dim):
+        return [self(new_x)]


 def _process_user_dims(x: XTensorVariable, dim: REDUCE_DIM) -> Sequence[str]:

@@ -120,8 +120,8 @@ def make_node(self, x):
         out = x.type()
         return Apply(self, [x], [out])

-    def vectorize_node(self, node, new_x):
-        return self.make_node(new_x)
+    def vectorize_node(self, node, new_x, new_dim):
+        return [self(new_x)]


 def cumreduce(x, dim: REDUCE_DIM, *, binary_op):
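
For reductions the change is purely the new calling convention. For a single-output Op, self(new_x) and self.make_node(new_x).outputs[0] build the same variable, because Op.__call__ wraps make_node and unwraps the result; returning [self(new_x)] simply hands back output variables directly. A sketch of that equivalence using a tensor-level reduction, since the behavior comes from the base Op class (illustrative only, not from the diff):

import pytensor.tensor as pt

x = pt.matrix("x")
op = pt.sum(x, axis=0).owner.op     # a concrete reduction Op instance
node = op.make_node(x)              # Apply node (the old return style)
(out,) = node.outputs               # the variable the new style returns directly
assert out.type == op(x).type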

pytensor/xtensor/shape.py
Lines changed: 13 additions & 18 deletions

@@ -68,8 +68,8 @@ def make_node(self, x):
         )
         return Apply(self, [x], [output])

-    def vectorize_node(self, node, new_x):
-        return self.make_node(new_x)
+    def vectorize_node(self, node, new_x, new_dim):
+        return [self(new_x)]


 def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str]):

@@ -149,18 +149,13 @@ def make_node(self, x, *unstacked_length):
         )
         return Apply(self, [x, *unstacked_lengths], [output])

-    def vectorize_node(self, node, new_x, *new_unstacked_length):
-        if len(new_unstacked_length) != len(self.unstacked_dims):
-            raise NotImplementedError(
-                f"Vectorization of {self} with additional unstacked_length not implemented, "
-                "as it can't infer new dimension labels"
-            )
+    def vectorize_node(self, node, new_x, *new_unstacked_length, new_dim):
         new_unstacked_length = [ul.squeeze() for ul in new_unstacked_length]
         if not all(ul.type.ndim == 0 for ul in new_unstacked_length):
             raise NotImplementedError(
                 f"Vectorization of {self} with batched unstacked_length not implemented, "
             )
-        return self.make_node(new_x, *new_unstacked_length)
+        return [self(new_x, *new_unstacked_length)]


 def unstack(x, dim: dict[str, dict[str, int]] | None = None, **dims: dict[str, int]):

@@ -205,10 +200,10 @@ def make_node(self, x):
         )
         return Apply(self, [x], [output])

-    def vectorize_node(self, node, new_x):
+    def vectorize_node(self, node, new_x, new_dim):
         old_dims = self.dims
         new_dims = tuple(dim for dim in new_x.dims if dim not in old_dims)
-        return type(self)(dims=(*new_dims, *old_dims)).make_node(new_x)
+        return [type(self)(dims=(*new_dims, *old_dims))(new_x)]


 def transpose(

@@ -323,8 +318,8 @@ def make_node(self, *inputs):
         output = xtensor(dtype=dtype, dims=dims, shape=shape)
         return Apply(self, inputs, [output])

-    def vectorize_node(self, node, *new_inputs):
-        return self.make_node(*new_inputs)
+    def vectorize_node(self, node, *new_inputs, new_dim):
+        return [self(*new_inputs)]


 def concat(xtensors, dim: str):

@@ -407,8 +402,8 @@ def make_node(self, x):
         )
         return Apply(self, [x], [out])

-    def vectorize_node(self, node, new_x):
-        return self.make_node(new_x)
+    def vectorize_node(self, node, new_x, new_dim):
+        return [self(new_x)]


 def squeeze(x, dim: str | Sequence[str] | None = None):

@@ -469,7 +464,7 @@ def make_node(self, x, size):
         )
         return Apply(self, [x, size], [out])

-    def vectorize_node(self, node, new_x, new_size):
+    def vectorize_node(self, node, new_x, new_size, new_dim):
         new_size = new_size.squeeze()
         if new_size.type.ndim != 0:
             raise NotImplementedError(

@@ -572,7 +567,7 @@ def make_node(self, *inputs):

         return Apply(self, inputs, outputs)

-    def vectorize_node(self, node, *new_inputs):
+    def vectorize_node(self, node, *new_inputs, new_dim):
         if exclude_set := set(self.exclude):
             for new_x, old_x in zip(node.inputs, new_inputs, strict=True):
                 if invalid_excluded := (

@@ -583,7 +578,7 @@ def vectorize_node(self, node, *new_inputs):
                     f"has an excluded dimension {sorted(invalid_excluded)} that it did not have before."
                 )

-        return self.make_node(*new_inputs)
+        return self(*new_inputs, return_list=True)


 def broadcast(
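
The last op above can produce multiple outputs, which is why its rewritten vectorize_node uses return_list=True instead of wrapping a single call in a list: Op.__call__ normally unwraps single-output ops to a bare variable, and the flag forces a list in all cases, keeping the return type consistent with the other implementations. A small sketch of the flag on the single-output op from basic.py (assuming the standard return_list behavior of Op.__call__):

import pytensor.tensor as pt
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor

xx = xtensor_from_tensor(pt.vector("x"), dims=("a",))

y = tensor_from_xtensor(xx)                     # bare Variable
ys = tensor_from_xtensor(xx, return_list=True)  # [Variable], same graph
assert ys[0].type == y.type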
