File tree Expand file tree Collapse file tree 1 file changed +5
-2
lines changed
Expand file tree Collapse file tree 1 file changed +5
-2
lines changed Original file line number Diff line number Diff line change @@ -1305,8 +1305,8 @@ def __init__(self, n_inputs):
13051305 input_sig = "," .join (f"(m{ i } ,n{ i } )" for i in range (n_inputs ))
13061306 self .gufunc_signature = f"{ input_sig } ->(m,n)"
13071307
1308- if n_inputs == 0 :
1309- raise ValueError ("n_inputs must be greater than 0 " )
1308+ if n_inputs <= 1 :
1309+ raise ValueError ("n_inputs must be greater than 1 " )
13101310 self .n_inputs = n_inputs
13111311
13121312 def grad (self , inputs , gout ):
@@ -1402,6 +1402,9 @@ def block_diag(*matrices: TensorVariable):
14021402 [0, 0, 5, 6],
14031403 [0, 0, 7, 8]])
14041404 """
1405+ if len (matrices ) == 1 :
1406+ return matrices [0 ]
1407+
14051408 _block_diagonal_matrix = Blockwise (BlockDiagonal (n_inputs = len (matrices )))
14061409 return _block_diagonal_matrix (* matrices )
14071410
You can’t perform that action at this time.
0 commit comments