@@ -68,6 +68,9 @@ def make_node(self, x):
6868 )
6969 return Apply (self , [x ], [output ])
7070
71+ def vectorize_node (self , node , new_x ):
72+ return self .make_node (new_x )
73+
7174
7275def stack (x , dim : dict [str , Sequence [str ]] | None = None , ** dims : Sequence [str ]):
7376 if dim is not None :
@@ -146,6 +149,19 @@ def make_node(self, x, *unstacked_length):
146149 )
147150 return Apply (self , [x , * unstacked_lengths ], [output ])
148151
152+ def vectorize_node (self , node , new_x , * new_unstacked_length ):
153+ if len (new_unstacked_length ) != len (self .unstacked_dims ):
154+ raise NotImplementedError (
155+ f"Vectorization of { self } with additional unstacked_length not implemented, "
156+ "as it can't infer new dimension labels"
157+ )
158+ new_unstacked_length = [ul .squeeze () for ul in new_unstacked_length ]
159+ if not all (ul .type .ndim == 0 for ul in new_unstacked_length ):
160+ raise NotImplementedError (
161+ f"Vectorization of { self } with batched unstacked_length not implemented, "
162+ )
163+ return self .make_node (new_x , * new_unstacked_length )
164+
149165
150166def unstack (x , dim : dict [str , dict [str , int ]] | None = None , ** dims : dict [str , int ]):
151167 if dim is not None :
@@ -189,6 +205,11 @@ def make_node(self, x):
189205 )
190206 return Apply (self , [x ], [output ])
191207
208+ def vectorize_node (self , node , new_x ):
209+ old_dims = self .dims
210+ new_dims = tuple (dim for dim in new_x .dims if dim not in old_dims )
211+ return type (self )(dims = (* new_dims , * old_dims )).make_node (new_x )
212+
192213
193214def transpose (
194215 x ,
@@ -302,6 +323,9 @@ def make_node(self, *inputs):
302323 output = xtensor (dtype = dtype , dims = dims , shape = shape )
303324 return Apply (self , inputs , [output ])
304325
326+ def vectorize_node (self , node , * new_inputs ):
327+ return self .make_node (* new_inputs )
328+
305329
306330def concat (xtensors , dim : str ):
307331 """Concatenate a sequence of XTensorVariables along a specified dimension.
@@ -383,6 +407,9 @@ def make_node(self, x):
383407 )
384408 return Apply (self , [x ], [out ])
385409
410+ def vectorize_node (self , node , new_x ):
411+ return self .make_node (new_x )
412+
386413
387414def squeeze (x , dim : str | Sequence [str ] | None = None ):
388415 """Remove dimensions of size 1 from an XTensorVariable."""
@@ -442,6 +469,14 @@ def make_node(self, x, size):
442469 )
443470 return Apply (self , [x , size ], [out ])
444471
472+ def vectorize_node (self , node , new_x , new_size ):
473+ new_size = new_size .squeeze ()
474+ if new_size .type .ndim != 0 :
475+ raise NotImplementedError (
476+ f"Vectorization of { self } with batched new_size not implemented, "
477+ )
478+ return self .make_node (new_x , new_size )
479+
445480
446481def expand_dims (x , dim = None , axis = None , ** dim_kwargs ):
447482 """Add one or more new dimensions to an XTensorVariable."""
@@ -537,6 +572,19 @@ def make_node(self, *inputs):
537572
538573 return Apply (self , inputs , outputs )
539574
575+ def vectorize_node (self , node , * new_inputs ):
576+ if exclude_set := set (self .exclude ):
577+ for new_x , old_x in zip (node .inputs , new_inputs , strict = True ):
578+ if invalid_excluded := (
579+ (set (new_x .dims ) - set (old_x .dims )) & exclude_set
580+ ):
581+ raise NotImplementedError (
582+ f"Vectorize of { self } is undefined because one of the inputs { new_x } "
583+ f"has an excluded dimension { sorted (invalid_excluded )} that it did not have before."
584+ )
585+
586+ return self .make_node (* new_inputs )
587+
540588
541589def broadcast (
542590 * args , exclude : str | Sequence [str ] | None = None
0 commit comments