Skip to content

Commit 4297205

Browse files
Use generic methods for * and / with a scalar (#43)
* Use generic methods for `*` and `/` with a scalar * Fix errors with Metal
1 parent cfa1d29 commit 4297205

File tree

4 files changed

+31
-46
lines changed

4 files changed

+31
-46
lines changed

src/core.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,17 @@ function LinearAlgebra.lmul!(x::Number, A::AbstractGenericSparseArray)
3232
return A
3333
end
3434

35-
function LinearAlgebra.rdiv!(A::AbstractGenericSparseArray, x::Number)
36-
rdiv!(nonzeros(A), x)
35+
function LinearAlgebra.rdiv!(A::AbstractGenericSparseArray{Tv}, x::Number) where {Tv}
36+
rmul!(A, inv(Tv(x)))
3737
return A
3838
end
3939

4040
Base.:+(A::AbstractGenericSparseArray) = copy(A)
4141

42+
Base.:*::Number, A::AbstractGenericSparseArray) = lmul!(α, copy(A))
43+
Base.:*(A::AbstractGenericSparseArray, α::Number) = rmul!(copy(A), α)
44+
Base.:(/)(A::AbstractGenericSparseArray, α::Number) = rdiv!(copy(A), α)
45+
4246
Base.:*(A::AbstractGenericSparseArray, J::UniformScaling) = A * J.λ
4347
Base.:*(J::UniformScaling, A::AbstractGenericSparseArray) = J.λ * A
4448

src/matrix_coo/matrix_coo.jl

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -104,18 +104,6 @@ function Base.zero(A::GenericSparseMatrixCOO)
104104
return GenericSparseMatrixCOO(A.m, A.n, rowind, colind, nzval)
105105
end
106106

107-
function Base.:(*)(α::Number, A::GenericSparseMatrixCOO)
108-
return GenericSparseMatrixCOO(
109-
A.m,
110-
A.n,
111-
copy(getrowind(A)),
112-
copy(getcolind(A)),
113-
α .* nonzeros(A),
114-
)
115-
end
116-
Base.:(*)(A::GenericSparseMatrixCOO, α::Number) = α * A
117-
Base.:(/)(A::GenericSparseMatrixCOO, α::Number) = (1 / α) * A
118-
119107
function Base.:-(A::GenericSparseMatrixCOO)
120108
return GenericSparseMatrixCOO(A.m, A.n, copy(A.rowind), copy(A.colind), -A.nzval)
121109
end
@@ -392,8 +380,7 @@ function Base.:+(A::GenericSparseMatrixCOO, B::GenericSparseMatrixCOO)
392380
)
393381

394382
C = GenericSparseMatrixCOO(m, n, rowind_C, colind_C, nzval_C)
395-
dropzeros!(C)
396-
return C
383+
return dropzeros(C)
397384
end
398385

399386
# Addition with transpose/adjoint support
@@ -511,8 +498,7 @@ for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:GenericSpars
511498
)
512499

513500
C = GenericSparseMatrixCOO(m, n, rowind_C, colind_C, nzval_C)
514-
dropzeros!(C)
515-
return C
501+
return dropzeros(C)
516502
end
517503

518504
@eval function Base.:-(A::$TypeA, B::$TypeB) where {$(whereT1(:T1)), $(whereT2(:T2))}
@@ -820,6 +806,13 @@ function SparseArrays.dropzeros!(A::GenericSparseMatrixCOO)
820806
return A
821807
end
822808

809+
if total_nnz == 0
810+
# All elements are zeros - some GPU backends (e.g., Metal) don't support
811+
# resize to 0. Keep the stored zeros; users can use dropzeros() (non-mutating)
812+
# which returns a new matrix with properly empty arrays.
813+
return A
814+
end
815+
823816
# Allocate temporary arrays for new data
824817
new_rowind = similar(rowind, total_nnz)
825818
new_colind = similar(colind, total_nnz)

src/matrix_csc/matrix_csc.jl

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -85,18 +85,6 @@ function Base.zero(A::GenericSparseMatrixCSC)
8585
return GenericSparseMatrixCSC(A.m, A.n, colptr, rowval, nzval)
8686
end
8787

88-
function Base.:(*)(α::Number, A::GenericSparseMatrixCSC)
89-
return GenericSparseMatrixCSC(
90-
A.m,
91-
A.n,
92-
copy(getcolptr(A)),
93-
copy(rowvals(A)),
94-
α .* nonzeros(A),
95-
)
96-
end
97-
Base.:(*)(A::GenericSparseMatrixCSC, α::Number) = α * A
98-
Base.:(/)(A::GenericSparseMatrixCSC, α::Number) = (1 / α) * A
99-
10088
function Base.:-(A::GenericSparseMatrixCSC)
10189
return GenericSparseMatrixCSC(A.m, A.n, copy(A.colptr), copy(A.rowval), -A.nzval)
10290
end
@@ -365,8 +353,7 @@ function Base.:+(A::GenericSparseMatrixCSC, B::GenericSparseMatrixCSC)
365353
)
366354

367355
C = GenericSparseMatrixCSC(m, n, colptr_C, rowval_C, nzval_C)
368-
dropzeros!(C)
369-
return C
356+
return dropzeros(C)
370357
end
371358

372359
# Addition with transpose/adjoint support
@@ -662,6 +649,13 @@ function SparseArrays.dropzeros!(A::GenericSparseMatrixCSC)
662649
cumsum_nnz = _cumsum_AK(nnz_per_col)
663650
total_nnz = @allowscalar cumsum_nnz[end]
664651

652+
if total_nnz == 0
653+
# All elements are zeros - some GPU backends (e.g., Metal) don't support
654+
# resize to 0. Keep the stored zeros; users can use dropzeros() (non-mutating)
655+
# which returns a new matrix with properly empty arrays.
656+
return A
657+
end
658+
665659
# Allocate temporary arrays for new data
666660
new_colptr = similar(getcolptr(A))
667661
new_rowval = similar(rowvals(A), total_nnz)

src/matrix_csr/matrix_csr.jl

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -85,18 +85,6 @@ function Base.zero(A::GenericSparseMatrixCSR)
8585
return GenericSparseMatrixCSR(A.m, A.n, rowptr, rowval, nzval)
8686
end
8787

88-
function Base.:(*)(α::Number, A::GenericSparseMatrixCSR)
89-
return GenericSparseMatrixCSR(
90-
A.m,
91-
A.n,
92-
copy(getrowptr(A)),
93-
copy(colvals(A)),
94-
α .* nonzeros(A),
95-
)
96-
end
97-
Base.:(*)(A::GenericSparseMatrixCSR, α::Number) = α * A
98-
Base.:(/)(A::GenericSparseMatrixCSR, α::Number) = (1 / α) * A
99-
10088
function Base.:-(A::GenericSparseMatrixCSR)
10189
return GenericSparseMatrixCSR(A.m, A.n, copy(A.rowptr), copy(A.colval), -A.nzval)
10290
end
@@ -363,8 +351,7 @@ function Base.:+(A::GenericSparseMatrixCSR, B::GenericSparseMatrixCSR)
363351
)
364352

365353
C = GenericSparseMatrixCSR(m, n, rowptr_C, colval_C, nzval_C)
366-
dropzeros!(C)
367-
return C
354+
return dropzeros(C)
368355
end
369356

370357
# Addition with transpose/adjoint support
@@ -657,6 +644,13 @@ function SparseArrays.dropzeros!(A::GenericSparseMatrixCSR)
657644
cumsum_nnz = _cumsum_AK(nnz_per_row)
658645
total_nnz = @allowscalar cumsum_nnz[end]
659646

647+
if total_nnz == 0
648+
# All elements are zeros - some GPU backends (e.g., Metal) don't support
649+
# resize to 0. Keep the stored zeros; users can use dropzeros() (non-mutating)
650+
# which returns a new matrix with properly empty arrays.
651+
return A
652+
end
653+
660654
# Allocate temporary arrays for new data
661655
new_rowptr = similar(getrowptr(A))
662656
new_colval = similar(colvals(A), total_nnz)

0 commit comments

Comments
 (0)