@@ -68,8 +68,8 @@ def make_node(self, x):
6868 )
6969 return Apply (self , [x ], [output ])
7070
71- def vectorize_node (self , node , new_x ):
72- return self . make_node (new_x )
71+ def vectorize_node (self , node , new_x , new_dim ):
72+ return [ self (new_x )]
7373
7474
7575def stack (x , dim : dict [str , Sequence [str ]] | None = None , ** dims : Sequence [str ]):
@@ -149,18 +149,13 @@ def make_node(self, x, *unstacked_length):
149149 )
150150 return Apply (self , [x , * unstacked_lengths ], [output ])
151151
152- def vectorize_node (self , node , new_x , * new_unstacked_length ):
153- if len (new_unstacked_length ) != len (self .unstacked_dims ):
154- raise NotImplementedError (
155- f"Vectorization of { self } with additional unstacked_length not implemented, "
156- "as it can't infer new dimension labels"
157- )
152+ def vectorize_node (self , node , new_x , * new_unstacked_length , new_dim ):
158153 new_unstacked_length = [ul .squeeze () for ul in new_unstacked_length ]
159154 if not all (ul .type .ndim == 0 for ul in new_unstacked_length ):
160155 raise NotImplementedError (
161156 f"Vectorization of { self } with batched unstacked_length not implemented, "
162157 )
163- return self . make_node (new_x , * new_unstacked_length )
158+ return [ self (new_x , * new_unstacked_length )]
164159
165160
166161def unstack (x , dim : dict [str , dict [str , int ]] | None = None , ** dims : dict [str , int ]):
@@ -205,10 +200,10 @@ def make_node(self, x):
205200 )
206201 return Apply (self , [x ], [output ])
207202
208- def vectorize_node (self , node , new_x ):
203+ def vectorize_node (self , node , new_x , new_dim ):
209204 old_dims = self .dims
210205 new_dims = tuple (dim for dim in new_x .dims if dim not in old_dims )
211- return type (self )(dims = (* new_dims , * old_dims )). make_node (new_x )
206+ return [ type (self )(dims = (* new_dims , * old_dims ))(new_x )]
212207
213208
214209def transpose (
@@ -323,8 +318,8 @@ def make_node(self, *inputs):
323318 output = xtensor (dtype = dtype , dims = dims , shape = shape )
324319 return Apply (self , inputs , [output ])
325320
326- def vectorize_node (self , node , * new_inputs ):
327- return self . make_node (* new_inputs )
321+ def vectorize_node (self , node , * new_inputs , new_dim ):
322+ return [ self (* new_inputs )]
328323
329324
330325def concat (xtensors , dim : str ):
@@ -407,8 +402,8 @@ def make_node(self, x):
407402 )
408403 return Apply (self , [x ], [out ])
409404
410- def vectorize_node (self , node , new_x ):
411- return self . make_node (new_x )
405+ def vectorize_node (self , node , new_x , new_dim ):
406+ return [ self (new_x )]
412407
413408
414409def squeeze (x , dim : str | Sequence [str ] | None = None ):
@@ -469,7 +464,7 @@ def make_node(self, x, size):
469464 )
470465 return Apply (self , [x , size ], [out ])
471466
472- def vectorize_node (self , node , new_x , new_size ):
467+ def vectorize_node (self , node , new_x , new_size , new_dim ):
473468 new_size = new_size .squeeze ()
474469 if new_size .type .ndim != 0 :
475470 raise NotImplementedError (
@@ -572,7 +567,7 @@ def make_node(self, *inputs):
572567
573568 return Apply (self , inputs , outputs )
574569
575- def vectorize_node (self , node , * new_inputs ):
570+ def vectorize_node (self , node , * new_inputs , new_dim ):
576571 if exclude_set := set (self .exclude ):
577572 for new_x , old_x in zip (node .inputs , new_inputs , strict = True ):
578573 if invalid_excluded := (
@@ -583,7 +578,7 @@ def vectorize_node(self, node, *new_inputs):
583578 f"has an excluded dimension { sorted (invalid_excluded )} that it did not have before."
584579 )
585580
586- return self . make_node (* new_inputs )
581+ return self (* new_inputs , return_list = True )
587582
588583
589584def broadcast (
0 commit comments