@@ -1159,16 +1159,16 @@ def mult(self, mat, x, y):
11591159 # Row matrix
11601160 out = v .dot (x )
11611161 if y .comm .rank == 0 :
1162- y .array [0 ] = out
1162+ y .array [... ] = out
11631163 else :
11641164 y .array [...]
11651165 else :
11661166 # Column matrix
11671167 if x .sizes [1 ] == 1 :
11681168 v .copy (y )
1169- a = np .zeros (1 , dtype = dtypes .ScalarType )
1169+ a = np .zeros (() , dtype = dtypes .ScalarType )
11701170 if x .comm .rank == 0 :
1171- a [0 ] = x .array_r
1171+ a [... ] = x .array_r
11721172 else :
11731173 x .array_r
11741174 with mpi .temp_internal_comm (x .comm ) as comm :
@@ -1183,9 +1183,9 @@ def multTranspose(self, mat, x, y):
11831183 # Row matrix
11841184 if x .sizes [1 ] == 1 :
11851185 v .copy (y )
1186- a = np .zeros (1 , dtype = dtypes .ScalarType )
1186+ a = np .zeros (() , dtype = dtypes .ScalarType )
11871187 if x .comm .rank == 0 :
1188- a [0 ] = x .array_r
1188+ a [... ] = x .array_r
11891189 else :
11901190 x .array_r
11911191 with mpi .temp_internal_comm (x .comm ) as comm :
@@ -1197,7 +1197,7 @@ def multTranspose(self, mat, x, y):
11971197 # Column matrix
11981198 out = v .dot (x )
11991199 if y .comm .rank == 0 :
1200- y .array [0 ] = out
1200+ y .array [... ] = out
12011201 else :
12021202 y .array [...]
12031203
@@ -1208,9 +1208,9 @@ def multTransposeAdd(self, mat, x, y, z):
12081208 # Row matrix
12091209 if x .sizes [1 ] == 1 :
12101210 v .copy (z )
1211- a = np .zeros (1 , dtype = dtypes .ScalarType )
1211+ a = np .zeros (() , dtype = dtypes .ScalarType )
12121212 if x .comm .rank == 0 :
1213- a [0 ] = x .array_r
1213+ a [... ] = x .array_r
12141214 else :
12151215 x .array_r
12161216 with mpi .temp_internal_comm (x .comm ) as comm :
@@ -1235,7 +1235,7 @@ def multTransposeAdd(self, mat, x, y, z):
12351235 out = v .dot (x )
12361236 y = y .array_r
12371237 if z .comm .rank == 0 :
1238- z .array [0 ] = out + y [ 0 ]
1238+ z .array [... ] = out + y
12391239 else :
12401240 z .array [...]
12411241
0 commit comments