@@ -385,7 +385,13 @@ function Base.:+(A::DeviceSparseMatrixCOO, B::DeviceSparseMatrixCOO)
385385 # Mark unique entries (first occurrence of each (row, col) pair)
386386 keep_mask = similar (rowind_sorted, Bool, nnz_concat)
387387 kernel_mark! = kernel_mark_unique_coo! (backend)
388- kernel_mark! (keep_mask, rowind_sorted, colind_sorted, nnz_concat; ndrange = (nnz_concat,))
388+ kernel_mark! (
389+ keep_mask,
390+ rowind_sorted,
391+ colind_sorted,
392+ nnz_concat;
393+ ndrange = (nnz_concat,),
394+ )
389395
390396 # Compute write indices using cumsum
391397 write_indices = _cumsum_AK (keep_mask)
@@ -415,42 +421,43 @@ end
415421
416422# Addition with transpose/adjoint support
417423for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers (:DeviceSparseMatrixCOO )
418- for (wrapb, transb, conjb, unwrapb, whereT2) in trans_adj_wrappers (:DeviceSparseMatrixCOO )
424+ for (wrapb, transb, conjb, unwrapb, whereT2) in
425+ trans_adj_wrappers (:DeviceSparseMatrixCOO )
419426 # Skip the case where both are not transposed (already handled above)
420427 (transa == false && transb == false ) && continue
421-
428+
422429 TypeA = wrapa (:(T1))
423430 TypeB = wrapb (:(T2))
424-
431+
425432 @eval function Base.:+ (A:: $TypeA , B:: $TypeB ) where {$ (whereT1 (:T1 )),$ (whereT2 (:T2 ))}
426433 size (A) == size (B) || throw (
427434 DimensionMismatch (
428435 " dimensions must match: A has dims $(size (A)) , B has dims $(size (B)) " ,
429436 ),
430437 )
431-
438+
432439 _A = $ (unwrapa (:A ))
433440 _B = $ (unwrapb (:B ))
434-
441+
435442 backend_A = get_backend (_A)
436443 backend_B = get_backend (_B)
437444 backend_A == backend_B ||
438445 throw (ArgumentError (" Both matrices must have the same backend" ))
439-
446+
440447 m, n = size (A)
441448 Ti = eltype (getrowind (_A))
442449 Tv = promote_type (eltype (nonzeros (_A)), eltype (nonzeros (_B)))
443-
450+
444451 # For transposed COO, swap row and column indices
445452 nnz_A = nnz (_A)
446453 nnz_B = nnz (_B)
447454 nnz_concat = nnz_A + nnz_B
448-
455+
449456 # Allocate concatenated arrays
450457 rowind_concat = similar (getrowind (_A), nnz_concat)
451458 colind_concat = similar (getcolind (_A), nnz_concat)
452459 nzval_concat = similar (nonzeros (_A), Tv, nnz_concat)
453-
460+
454461 # Copy entries from A (potentially swapping row/col for transpose)
455462 if $ transa
456463 rowind_concat[1 : nnz_A] .= getcolind (_A) # Swap for transpose
@@ -464,7 +471,7 @@ for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:DeviceSparse
464471 else
465472 nzval_concat[1 : nnz_A] .= nonzeros (_A)
466473 end
467-
474+
468475 # Copy entries from B (potentially swapping row/col for transpose)
469476 if $ transb
470477 rowind_concat[(nnz_A+ 1 ): end ] .= getcolind (_B) # Swap for transpose
@@ -478,29 +485,41 @@ for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:DeviceSparse
478485 else
479486 nzval_concat[(nnz_A+ 1 ): end ] .= nonzeros (_B)
480487 end
481-
488+
482489 # Sort and compact (same as before)
483490 backend = backend_A
484491 keys = similar (rowind_concat, Ti, nnz_concat)
485492 kernel_make_keys! = kernel_make_csc_keys! (backend)
486- kernel_make_keys! (keys, rowind_concat, colind_concat, m; ndrange = (nnz_concat,))
487-
493+ kernel_make_keys! (
494+ keys,
495+ rowind_concat,
496+ colind_concat,
497+ m;
498+ ndrange = (nnz_concat,),
499+ )
500+
488501 perm = _sortperm_AK (keys)
489502 rowind_sorted = rowind_concat[perm]
490503 colind_sorted = colind_concat[perm]
491504 nzval_sorted = nzval_concat[perm]
492-
505+
493506 keep_mask = similar (rowind_sorted, Bool, nnz_concat)
494507 kernel_mark! = kernel_mark_unique_coo! (backend)
495- kernel_mark! (keep_mask, rowind_sorted, colind_sorted, nnz_concat; ndrange = (nnz_concat,))
496-
508+ kernel_mark! (
509+ keep_mask,
510+ rowind_sorted,
511+ colind_sorted,
512+ nnz_concat;
513+ ndrange = (nnz_concat,),
514+ )
515+
497516 write_indices = _cumsum_AK (keep_mask)
498517 nnz_final = @allowscalar write_indices[nnz_concat]
499-
518+
500519 rowind_C = similar (getrowind (_A), nnz_final)
501520 colind_C = similar (getcolind (_A), nnz_final)
502521 nzval_C = similar (nonzeros (_A), Tv, nnz_final)
503-
522+
504523 kernel_compact! = kernel_compact_coo! (backend)
505524 kernel_compact! (
506525 rowind_C,
@@ -513,7 +532,7 @@ for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:DeviceSparse
513532 nnz_concat;
514533 ndrange = (nnz_concat,),
515534 )
516-
535+
517536 return DeviceSparseMatrixCOO (m, n, rowind_C, colind_C, nzval_C)
518537 end
519538 end
@@ -587,3 +606,86 @@ function LinearAlgebra.kron(
587606
588607 return DeviceSparseMatrixCOO (m_C, n_C, rowind_C, colind_C, nzval_C)
589608end
609+
"""
    *(A::DeviceSparseMatrixCOO, B::DeviceSparseMatrixCOO)

Multiply two sparse matrices in COO format. Both matrices must have compatible dimensions
(number of columns of `A` equals number of rows of `B`) and be on the same backend
(device).

The operands are converted to `DeviceSparseMatrixCSC`, multiplied in that format, and the
product is converted back to `DeviceSparseMatrixCOO`. The same route is used for the
transpose/adjoint methods, since COO has no efficient direct multiplication algorithm.

# Throws
- `DimensionMismatch`: if `size(A, 2) != size(B, 1)`.
- `ArgumentError`: if `get_backend(A) != get_backend(B)`.

# Examples
```jldoctest
julia> using DeviceSparseArrays, SparseArrays

julia> A = DeviceSparseMatrixCOO(sparse([1, 2], [1, 2], [2.0, 3.0], 2, 2));

julia> B = DeviceSparseMatrixCOO(sparse([1, 2], [1, 2], [4.0, 5.0], 2, 2));

julia> C = A * B;

julia> collect(C)
2×2 Matrix{Float64}:
 8.0   0.0
 0.0  15.0
```
"""
function Base.:(*)(A::DeviceSparseMatrixCOO, B::DeviceSparseMatrixCOO)
    # Validate shapes before any conversion work is done.
    size(A, 2) == size(B, 1) || throw(
        DimensionMismatch(
            "second dimension of A, $(size(A,2)), does not match first dimension of B, $(size(B,1))",
        ),
    )

    # Operands living on different backends cannot be combined.
    backend_A = get_backend(A)
    backend_B = get_backend(B)
    backend_A == backend_B ||
        throw(ArgumentError("Both matrices must have the same backend"))

    # Convert to CSC, multiply, convert back to COO.
    # COO has no efficient direct multiplication algorithm, so the product is delegated
    # to the `*` method for the CSC operands.
    A_csc = DeviceSparseMatrixCSC(A)
    B_csc = DeviceSparseMatrixCSC(B)
    C_csc = A_csc * B_csc
    return DeviceSparseMatrixCOO(C_csc)
end
656+
# Multiplication with transpose/adjoint support.
# One `*` method is generated for every pairing of operand wrappers returned by
# `trans_adj_wrappers`; all of them use the same CSC round-trip as the unwrapped method.
for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:DeviceSparseMatrixCOO)
    for (wrapb, transb, conjb, unwrapb, whereT2) in
        trans_adj_wrappers(:DeviceSparseMatrixCOO)
        # Skip the case where both are not transposed (already handled above)
        (transa == false && transb == false) && continue

        # Type expressions spliced into the generated method signature.
        TypeA = wrapa(:(T1))
        TypeB = wrapb(:(T2))

        @eval function Base.:(*)(
            A::$TypeA,
            B::$TypeB,
        ) where {$(whereT1(:T1)),$(whereT2(:T2))}
            # `size` is queried on the wrapped operands, so the check uses the logical
            # (already transposed/adjointed) dimensions.
            size(A, 2) == size(B, 1) || throw(
                DimensionMismatch(
                    "second dimension of A, $(size(A,2)), does not match first dimension of B, $(size(B,1))",
                ),
            )

            # The backend is looked up on the expression produced by `unwrapa`/`unwrapb`
            # (presumably the parent matrix of the wrapper — confirm against
            # `trans_adj_wrappers`).
            backend_A = get_backend($(unwrapa(:A)))
            backend_B = get_backend($(unwrapb(:B)))
            backend_A == backend_B ||
                throw(ArgumentError("Both matrices must have the same backend"))

            # Convert to CSC, multiply, convert back to COO. The wrapped operands are
            # handed to `DeviceSparseMatrixCSC` as-is, so that constructor is relied on
            # to resolve the transpose/adjoint.
            # NOTE(review): confirm a `DeviceSparseMatrixCSC` method exists for
            # transpose/adjoint-wrapped COO inputs.
            A_csc = DeviceSparseMatrixCSC(A)
            B_csc = DeviceSparseMatrixCSC(B)
            C_csc = A_csc * B_csc
            return DeviceSparseMatrixCOO(C_csc)
        end
    end
end
0 commit comments