Skip to content

Commit f4196f9

Browse files
block_diag of one matrix is Identity (#1865)
1 parent 557307a commit f4196f9

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

pytensor/tensor/slinalg.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1305,8 +1305,8 @@ def __init__(self, n_inputs):
13051305
input_sig = ",".join(f"(m{i},n{i})" for i in range(n_inputs))
13061306
self.gufunc_signature = f"{input_sig}->(m,n)"
13071307

1308-
if n_inputs == 0:
1309-
raise ValueError("n_inputs must be greater than 0")
1308+
if n_inputs <= 1:
1309+
raise ValueError("n_inputs must be greater than 1")
13101310
self.n_inputs = n_inputs
13111311

13121312
def grad(self, inputs, gout):
@@ -1402,6 +1402,9 @@ def block_diag(*matrices: TensorVariable):
14021402
[0, 0, 5, 6],
14031403
[0, 0, 7, 8]])
14041404
"""
1405+
if len(matrices) == 1:
1406+
return matrices[0]
1407+
14051408
_block_diagonal_matrix = Blockwise(BlockDiagonal(n_inputs=len(matrices)))
14061409
return _block_diagonal_matrix(*matrices)
14071410

0 commit comments

Comments
 (0)