
Commit dd4d106

Implement sparse-sparse matrix multiplication with GPU-compatible kernels (#36)
* Initial plan
* Implement sparse-sparse matrix multiplication for all formats
* Add comprehensive tests for sparse-sparse multiplication
* Fix code review comments - improve docstrings and remove misleading comment
* Implement proper GPU-compatible SpGEMM kernels for CSC and CSR formats
* Simplify COO multiplication and remove redundant tests
* Fix Metal errors

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: albertomercurio <61953577+albertomercurio@users.noreply.github.com>
Co-authored-by: Alberto Mercurio <alberto.mercurio96@gmail.com>
1 parent 5352275 commit dd4d106

13 files changed: +748 / -83 lines changed

ext/DeviceSparseArraysJLArraysExt.jl

Lines changed: 2 additions & 0 deletions
@@ -5,5 +5,7 @@ import DeviceSparseArrays
 
 DeviceSparseArrays._sortperm_AK(x::JLArray) = JLArray(sortperm(collect(x)))
 DeviceSparseArrays._cumsum_AK(x::JLArray) = JLArray(cumsum(collect(x)))
+DeviceSparseArrays._searchsortedfirst_AK(v::JLArray, x::JLArray) =
+    JLArray(searchsortedfirst.(Ref(collect(v)), collect(x)))
 
 end
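
For context, `_searchsortedfirst_AK(v, x)` is a vectorized binary search: for each needle in `x` it returns the index of the first element of the sorted vector `v` that is not less than that needle. A minimal CPU sketch of the expected behaviour (plain `Vector`s and the concrete values are illustrative only, not package code):

```julia
# Vectorized searchsortedfirst: one result per needle.
v = [1, 1, 2, 4, 4, 4]          # sorted "haystack"
x = [1, 2, 3, 4]                # needles
searchsortedfirst.(Ref(v), x)   # => [1, 3, 4, 4]

# The JLArray method above does the same on the host via collect(...) and then
# wraps the result back into a JLArray.
```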

src/conversions/conversion_kernels.jl

Lines changed: 0 additions & 14 deletions
@@ -55,17 +55,3 @@ end
     i = @index(Global)
     keys[i] = rowind[i] * n + colind[i]
 end
-
-# Kernel for counting entries per column (for COO → CSC)
-@kernel inbounds=true function kernel_count_per_col!(colptr, @Const(colind_sorted))
-    i = @index(Global)
-    col = colind_sorted[i]
-    @atomic colptr[col+1] += 1
-end
-
-# Kernel for counting entries per row (for COO → CSR)
-@kernel inbounds=true function kernel_count_per_row!(rowptr, @Const(rowind_sorted))
-    i = @index(Global)
-    row = rowind_sorted[i]
-    @atomic rowptr[row+1] += 1
-end

src/conversions/conversions.jl

Lines changed: 35 additions & 20 deletions
@@ -165,21 +165,38 @@ function DeviceSparseMatrixCSC(A::DeviceSparseMatrixCOO{Tv,Ti}) where {Tv,Ti}
     colind_sorted = A.colind[perm]
     nzval_sorted = A.nzval[perm]
 
-    # Build colptr on device using a histogram approach
-    colptr = similar(A.colind, Ti, n + 1)
-    fill!(colptr, zero(Ti))
-
-    # Count entries per column
-    kernel! = kernel_count_per_col!(backend)
-    kernel!(colptr, colind_sorted; ndrange = (nnz_count,))
+    # Build colptr on device using searchsortedfirst approach
+    # Since colind_sorted is sorted, find where each column starts
+    col_indices = similar(A.colind, Ti, n)
+    col_indices .= Ti(1):Ti(n)
 
-    # Compute cumulative sum
-    @allowscalar colptr[1] = 1 # TODO: Is there a better way to do this?
-    colptr[2:end] .= _cumsum_AK(colptr[2:end]) .+ 1
+    # Find start positions for each column
+    colptr = similar(A.colind, Ti, n + 1)
+    colptr[1:n] .= _searchsortedfirst_AK(colind_sorted, col_indices)
+    @allowscalar colptr[n+1] = Ti(nnz_count + 1)
 
     return DeviceSparseMatrixCSC(m, n, colptr, rowind_sorted, nzval_sorted)
 end
 
+# Transpose and Adjoint conversions for COO to CSC
+DeviceSparseMatrixCSC(A::Transpose{Tv,<:DeviceSparseMatrixCOO}) where {Tv} =
+    DeviceSparseMatrixCSC(DeviceSparseMatrixCOO(
+        size(A, 1),
+        size(A, 2),
+        A.parent.colind,
+        A.parent.rowind,
+        A.parent.nzval,
+    ))
+
+DeviceSparseMatrixCSC(A::Adjoint{Tv,<:DeviceSparseMatrixCOO}) where {Tv} =
+    DeviceSparseMatrixCSC(DeviceSparseMatrixCOO(
+        size(A, 1),
+        size(A, 2),
+        A.parent.colind,
+        A.parent.rowind,
+        conj.(A.parent.nzval),
+    ))
+
 # ============================================================================
 # CSR ↔ COO Conversions
 # ============================================================================
@@ -223,17 +240,15 @@ function DeviceSparseMatrixCSR(A::DeviceSparseMatrixCOO{Tv,Ti}) where {Tv,Ti}
     colind_sorted = A.colind[perm]
     nzval_sorted = A.nzval[perm]
 
-    # Build rowptr on device using a histogram approach
-    rowptr = similar(A.rowind, Ti, m + 1)
-    fill!(rowptr, zero(Ti))
-
-    # Count entries per row
-    kernel! = kernel_count_per_row!(backend)
-    kernel!(rowptr, rowind_sorted; ndrange = (nnz_count,))
+    # Build rowptr on device using searchsortedfirst approach
+    # Since rowind_sorted is sorted, find where each row starts
+    row_indices = similar(A.rowind, Ti, m)
+    row_indices .= Ti(1):Ti(m)
 
-    # Compute cumulative sum
-    @allowscalar rowptr[1] = 1 # TODO: Is there a better way to do this?
-    rowptr[2:end] .= _cumsum_AK(rowptr[2:end]) .+ 1
+    # Find start positions for each row
+    rowptr = similar(A.rowind, Ti, m + 1)
+    rowptr[1:m] .= _searchsortedfirst_AK(rowind_sorted, row_indices)
+    @allowscalar rowptr[m+1] = Ti(nnz_count + 1)
 
     return DeviceSparseMatrixCSR(m, n, rowptr, colind_sorted, nzval_sorted)
 end
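
The new colptr/rowptr construction relies on the entries being sorted: all entries of column j are then contiguous, so column j starts at the first position of colind_sorted whose value is at least j (an empty column simply repeats the next column's start). A small CPU sketch of the same computation, with plain `Vector`s and illustrative values standing in for device arrays:

```julia
# Sorted COO column indices of a 4-column matrix with 6 stored entries.
colind_sorted = [1, 1, 3, 3, 3, 4]
n = 4
nnz_count = length(colind_sorted)

# Start of column j = first index i with colind_sorted[i] >= j.
colptr = similar(colind_sorted, n + 1)
colptr[1:n] .= searchsortedfirst.(Ref(colind_sorted), 1:n)
colptr[n+1] = nnz_count + 1

colptr  # => [1, 3, 3, 6, 7]; column 2 is empty, so colptr[2] == colptr[3] == 3
```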

src/helpers.jl

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
 # Helper functions to call AcceleratedKernels methods
 _sortperm_AK(x) = AcceleratedKernels.sortperm(x)
 _cumsum_AK(x) = AcceleratedKernels.cumsum(x)
+_searchsortedfirst_AK(v, x) = AcceleratedKernels.searchsortedfirst(v, x)

src/matrix_coo/matrix_coo.jl

Lines changed: 122 additions & 20 deletions
@@ -385,7 +385,13 @@ function Base.:+(A::DeviceSparseMatrixCOO, B::DeviceSparseMatrixCOO)
     # Mark unique entries (first occurrence of each (row, col) pair)
     keep_mask = similar(rowind_sorted, Bool, nnz_concat)
     kernel_mark! = kernel_mark_unique_coo!(backend)
-    kernel_mark!(keep_mask, rowind_sorted, colind_sorted, nnz_concat; ndrange = (nnz_concat,))
+    kernel_mark!(
+        keep_mask,
+        rowind_sorted,
+        colind_sorted,
+        nnz_concat;
+        ndrange = (nnz_concat,),
+    )
 
     # Compute write indices using cumsum
     write_indices = _cumsum_AK(keep_mask)
@@ -415,42 +421,43 @@ end
 
 # Addition with transpose/adjoint support
 for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:DeviceSparseMatrixCOO)
-    for (wrapb, transb, conjb, unwrapb, whereT2) in trans_adj_wrappers(:DeviceSparseMatrixCOO)
+    for (wrapb, transb, conjb, unwrapb, whereT2) in
+        trans_adj_wrappers(:DeviceSparseMatrixCOO)
         # Skip the case where both are not transposed (already handled above)
         (transa == false && transb == false) && continue
-
+
         TypeA = wrapa(:(T1))
         TypeB = wrapb(:(T2))
-
+
         @eval function Base.:+(A::$TypeA, B::$TypeB) where {$(whereT1(:T1)),$(whereT2(:T2))}
             size(A) == size(B) || throw(
                 DimensionMismatch(
                     "dimensions must match: A has dims $(size(A)), B has dims $(size(B))",
                 ),
             )
-
+
             _A = $(unwrapa(:A))
             _B = $(unwrapb(:B))
-
+
             backend_A = get_backend(_A)
             backend_B = get_backend(_B)
             backend_A == backend_B ||
                 throw(ArgumentError("Both matrices must have the same backend"))
-
+
             m, n = size(A)
             Ti = eltype(getrowind(_A))
             Tv = promote_type(eltype(nonzeros(_A)), eltype(nonzeros(_B)))
-
+
             # For transposed COO, swap row and column indices
             nnz_A = nnz(_A)
             nnz_B = nnz(_B)
             nnz_concat = nnz_A + nnz_B
-
+
             # Allocate concatenated arrays
             rowind_concat = similar(getrowind(_A), nnz_concat)
             colind_concat = similar(getcolind(_A), nnz_concat)
             nzval_concat = similar(nonzeros(_A), Tv, nnz_concat)
-
+
             # Copy entries from A (potentially swapping row/col for transpose)
             if $transa
                 rowind_concat[1:nnz_A] .= getcolind(_A) # Swap for transpose
@@ -464,7 +471,7 @@ for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:DeviceSparse
             else
                 nzval_concat[1:nnz_A] .= nonzeros(_A)
             end
-
+
             # Copy entries from B (potentially swapping row/col for transpose)
             if $transb
                 rowind_concat[(nnz_A+1):end] .= getcolind(_B) # Swap for transpose
@@ -478,29 +485,41 @@ for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:DeviceSparse
             else
                 nzval_concat[(nnz_A+1):end] .= nonzeros(_B)
             end
-
+
             # Sort and compact (same as before)
             backend = backend_A
             keys = similar(rowind_concat, Ti, nnz_concat)
             kernel_make_keys! = kernel_make_csc_keys!(backend)
-            kernel_make_keys!(keys, rowind_concat, colind_concat, m; ndrange = (nnz_concat,))
-
+            kernel_make_keys!(
+                keys,
+                rowind_concat,
+                colind_concat,
+                m;
+                ndrange = (nnz_concat,),
+            )
+
             perm = _sortperm_AK(keys)
             rowind_sorted = rowind_concat[perm]
             colind_sorted = colind_concat[perm]
             nzval_sorted = nzval_concat[perm]
-
+
             keep_mask = similar(rowind_sorted, Bool, nnz_concat)
             kernel_mark! = kernel_mark_unique_coo!(backend)
-            kernel_mark!(keep_mask, rowind_sorted, colind_sorted, nnz_concat; ndrange = (nnz_concat,))
-
+            kernel_mark!(
+                keep_mask,
+                rowind_sorted,
+                colind_sorted,
+                nnz_concat;
+                ndrange = (nnz_concat,),
+            )
+
             write_indices = _cumsum_AK(keep_mask)
             nnz_final = @allowscalar write_indices[nnz_concat]
-
+
             rowind_C = similar(getrowind(_A), nnz_final)
             colind_C = similar(getcolind(_A), nnz_final)
             nzval_C = similar(nonzeros(_A), Tv, nnz_final)
-
+
             kernel_compact! = kernel_compact_coo!(backend)
             kernel_compact!(
                 rowind_C,
@@ -513,7 +532,7 @@ for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:DeviceSparse
                 nnz_concat;
                 ndrange = (nnz_concat,),
             )
-
+
             return DeviceSparseMatrixCOO(m, n, rowind_C, colind_C, nzval_C)
         end
     end
@@ -587,3 +606,86 @@ function LinearAlgebra.kron(
 
     return DeviceSparseMatrixCOO(m_C, n_C, rowind_C, colind_C, nzval_C)
 end
+
+"""
+    *(A::DeviceSparseMatrixCOO, B::DeviceSparseMatrixCOO)
+
+Multiply two sparse matrices in COO format. Both matrices must have compatible dimensions
+(number of columns of A equals number of rows of B) and be on the same backend (device).
+
+The multiplication converts to CSC format, performs the multiplication with GPU-compatible
+kernels, and converts back to COO format. This approach is used for all cases including
+transpose/adjoint since COO doesn't have an efficient direct multiplication algorithm.
+
+# Examples
+```jldoctest
+julia> using DeviceSparseArrays, SparseArrays
+
+julia> A = DeviceSparseMatrixCOO(sparse([1, 2], [1, 2], [2.0, 3.0], 2, 2));
+
+julia> B = DeviceSparseMatrixCOO(sparse([1, 2], [1, 2], [4.0, 5.0], 2, 2));
+
+julia> C = A * B;
+
+julia> collect(C)
+2×2 Matrix{Float64}:
+ 8.0   0.0
+ 0.0  15.0
+```
+"""
+function Base.:(*)(A::DeviceSparseMatrixCOO, B::DeviceSparseMatrixCOO)
+    size(A, 2) == size(B, 1) || throw(
+        DimensionMismatch(
+            "second dimension of A, $(size(A,2)), does not match first dimension of B, $(size(B,1))",
+        ),
+    )
+
+    backend_A = get_backend(A)
+    backend_B = get_backend(B)
+    backend_A == backend_B ||
+        throw(ArgumentError("Both matrices must have the same backend"))
+
+    # Convert to CSC, multiply, convert back to COO
+    # This is acceptable as COO doesn't have an efficient direct multiplication algorithm
+    # and CSC provides the sorted structure needed for efficient SpGEMM
+    A_csc = DeviceSparseMatrixCSC(A)
+    B_csc = DeviceSparseMatrixCSC(B)
+    C_csc = A_csc * B_csc
+    return DeviceSparseMatrixCOO(C_csc)
+end
+
+# Multiplication with transpose/adjoint support - all cases use the same approach
+for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:DeviceSparseMatrixCOO)
+    for (wrapb, transb, conjb, unwrapb, whereT2) in
+        trans_adj_wrappers(:DeviceSparseMatrixCOO)
+        # Skip the case where both are not transposed (already handled above)
+        (transa == false && transb == false) && continue
+
+        TypeA = wrapa(:(T1))
+        TypeB = wrapb(:(T2))
+
+        @eval function Base.:(*)(
+            A::$TypeA,
+            B::$TypeB,
+        ) where {$(whereT1(:T1)),$(whereT2(:T2))}
+            size(A, 2) == size(B, 1) || throw(
+                DimensionMismatch(
+                    "second dimension of A, $(size(A,2)), does not match first dimension of B, $(size(B,1))",
+                ),
+            )
+
+            backend_A = get_backend($(unwrapa(:A)))
+            backend_B = get_backend($(unwrapb(:B)))
+            backend_A == backend_B ||
+                throw(ArgumentError("Both matrices must have the same backend"))
+
+            # Convert to CSC (handles transpose/adjoint), multiply, convert back to COO
+            # Same approach as the base case since COO doesn't have an efficient
+            # direct multiplication algorithm
+            A_csc = DeviceSparseMatrixCSC(A)
+            B_csc = DeviceSparseMatrixCSC(B)
+            C_csc = A_csc * B_csc
+            return DeviceSparseMatrixCOO(C_csc)
+        end
+    end
+end
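
As a rough usage check of the new methods (an illustrative sketch, not part of this commit's test suite; it assumes a CPU-backed `DeviceSparseMatrixCOO` constructed from a `SparseMatrixCSC`, as in the docstring above, and that `collect` densifies the result):

```julia
using DeviceSparseArrays, SparseArrays, LinearAlgebra

A_sp = sprand(20, 30, 0.2)
B_sp = sprand(30, 10, 0.2)
A = DeviceSparseMatrixCOO(A_sp)
B = DeviceSparseMatrixCOO(B_sp)

# Plain product: internally COO -> CSC, SpGEMM in CSC, back to COO.
C = A * B
@assert collect(C) ≈ Matrix(A_sp) * Matrix(B_sp)

# The transpose/adjoint wrappers generated above route through the same CSC path.
D_sp = sprand(20, 10, 0.2)
D = DeviceSparseMatrixCOO(D_sp)
@assert collect(A' * D) ≈ Matrix(A_sp)' * Matrix(D_sp)
```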

src/matrix_coo/matrix_coo_kernels.jl

Lines changed: 5 additions & 3 deletions
@@ -216,16 +216,18 @@ end
 
     if i <= nnz_in
         out_idx = write_indices[i]
-
+
         # If this is a new entry (or first of duplicates), write it
        if i == 1 || (rowind_in[i] != rowind_in[i-1] || colind_in[i] != colind_in[i-1])
            rowind_out[out_idx] = rowind_in[i]
            colind_out[out_idx] = colind_in[i]
-
+
            # Sum all duplicates
            val_sum = nzval_in[i]
            j = i + 1
-            while j <= nnz_in && rowind_in[j] == rowind_in[i] && colind_in[j] == colind_in[i]
+            while j <= nnz_in &&
+                  rowind_in[j] == rowind_in[i] &&
+                  colind_in[j] == colind_in[i]
                val_sum += nzval_in[j]
                j += 1
            end
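
For orientation, this kernel is the last stage of the three-step duplicate-merging pipeline used by the COO routines: mark the first occurrence of each (row, col) pair in the sorted arrays, prefix-sum the mask to get output positions, then let each first occurrence write its indices and the sum of its run of duplicates. A serial CPU sketch of the same logic (illustrative only; the package runs these steps as KernelAbstractions kernels on the device):

```julia
# Sorted COO data with duplicate (row, col) pairs to be merged.
rowind = [1, 1, 2, 2, 3]
colind = [1, 1, 1, 2, 2]
nzval  = [1.0, 2.0, 3.0, 4.0, 5.0]
n_in   = length(nzval)

# 1) Mark first occurrences (kernel_mark_unique_coo! on device).
keep = [i == 1 || rowind[i] != rowind[i-1] || colind[i] != colind[i-1] for i = 1:n_in]

# 2) Prefix sum gives each kept entry its output slot (_cumsum_AK on device).
write_indices = cumsum(keep)
n_out = write_indices[end]

# 3) Compact: each first occurrence writes its indices and the sum of its run
#    of duplicates (kernel_compact_coo! on device).
rowind_out = zeros(Int, n_out); colind_out = zeros(Int, n_out); nzval_out = zeros(n_out)
for i = 1:n_in
    keep[i] || continue
    out = write_indices[i]
    rowind_out[out] = rowind[i]
    colind_out[out] = colind[i]
    s, j = nzval[i], i + 1
    while j <= n_in && rowind[j] == rowind[i] && colind[j] == colind[i]
        s += nzval[j]
        j += 1
    end
    nzval_out[out] = s
end

(rowind_out, colind_out, nzval_out)  # => ([1, 2, 2, 3], [1, 1, 2, 2], [3.0, 3.0, 4.0, 5.0])
```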
