|
334 | 334 | Add two sparse matrices in COO format. Both matrices must have the same dimensions |
335 | 335 | and be on the same backend (device). |
336 | 336 |
|
337 | | -The result is a COO matrix with entries from both A and B concatenated. Note that |
338 | | -duplicate entries (same row and column) are not combined, which is valid for COO format. |
| 337 | +The result is a COO matrix with entries from both A and B properly merged, |
| 338 | +with duplicate entries (same row and column) combined by summing their values. |
339 | 339 |
|
340 | 340 | # Examples |
341 | 341 | ```jldoctest |
@@ -366,27 +366,68 @@ function Base.:+(A::DeviceSparseMatrixCOO, B::DeviceSparseMatrixCOO) |
366 | 366 | throw(ArgumentError("Both matrices must have the same backend")) |
367 | 367 |
|
368 | 368 | m, n = size(A) |
| 369 | + Ti = eltype(getrowind(A)) |
369 | 370 | Tv = promote_type(eltype(nonzeros(A)), eltype(nonzeros(B))) |
370 | 371 |
|
371 | 372 | # Concatenate the coordinate arrays |
372 | 373 | nnz_A = nnz(A) |
373 | 374 | nnz_B = nnz(B) |
374 | | - nnz_total = nnz_A + nnz_B |
375 | | - |
376 | | - # Allocate result arrays |
377 | | - rowind_C = similar(getrowind(A), nnz_total) |
378 | | - colind_C = similar(getcolind(A), nnz_total) |
379 | | - nzval_C = similar(nonzeros(A), Tv, nnz_total) |
380 | | - |
381 | | - # Copy entries from A |
382 | | - rowind_C[1:nnz_A] .= getrowind(A) |
383 | | - colind_C[1:nnz_A] .= getcolind(A) |
384 | | - nzval_C[1:nnz_A] .= nonzeros(A) |
385 | | - |
386 | | - # Copy entries from B |
387 | | - rowind_C[(nnz_A+1):end] .= getrowind(B) |
388 | | - colind_C[(nnz_A+1):end] .= getcolind(B) |
389 | | - nzval_C[(nnz_A+1):end] .= nonzeros(B) |
| 375 | + nnz_concat = nnz_A + nnz_B |
| 376 | + |
| 377 | + # Allocate concatenated arrays |
| 378 | + rowind_concat = similar(getrowind(A), nnz_concat) |
| 379 | + colind_concat = similar(getcolind(A), nnz_concat) |
| 380 | + nzval_concat = similar(nonzeros(A), Tv, nnz_concat) |
| 381 | + |
| 382 | + # Copy entries from A and B |
| 383 | + rowind_concat[1:nnz_A] .= getrowind(A) |
| 384 | + colind_concat[1:nnz_A] .= getcolind(A) |
| 385 | + nzval_concat[1:nnz_A] .= nonzeros(A) |
| 386 | + rowind_concat[(nnz_A+1):end] .= getrowind(B) |
| 387 | + colind_concat[(nnz_A+1):end] .= getcolind(B) |
| 388 | + nzval_concat[(nnz_A+1):end] .= nonzeros(B) |
| 389 | + |
| 390 | + # Sort by (row, col) using keys similar to COO->CSC conversion |
| 391 | + backend = backend_A |
| 392 | + keys = similar(rowind_concat, Ti, nnz_concat) |
| 393 | + kernel_make_keys! = kernel_make_csc_keys!(backend) |
| 394 | + kernel_make_keys!(keys, rowind_concat, colind_concat, m; ndrange = (nnz_concat,)) |
| 395 | + |
| 396 | + # Sort using AcceleratedKernels |
| 397 | + perm = _sortperm_AK(keys) |
| 398 | + |
| 399 | + # Apply permutation to get sorted arrays |
| 400 | + rowind_sorted = rowind_concat[perm] |
| 401 | + colind_sorted = colind_concat[perm] |
| 402 | + nzval_sorted = nzval_concat[perm] |
| 403 | + |
| 404 | + # Mark unique entries (first occurrence of each (row, col) pair) |
| 405 | + keep_mask = similar(rowind_sorted, Bool, nnz_concat) |
| 406 | + kernel_mark! = kernel_mark_unique_coo!(backend) |
| 407 | + kernel_mark!(keep_mask, rowind_sorted, colind_sorted, nnz_concat; ndrange = (nnz_concat,)) |
| 408 | + |
| 409 | + # Compute write indices using cumsum |
| 410 | + write_indices = _cumsum_AK(keep_mask) |
| 411 | + nnz_final = allowed_getindex(write_indices, nnz_concat) |
| 412 | + |
| 413 | + # Allocate final arrays |
| 414 | + rowind_C = similar(getrowind(A), nnz_final) |
| 415 | + colind_C = similar(getcolind(A), nnz_final) |
| 416 | + nzval_C = similar(nonzeros(A), Tv, nnz_final) |
| 417 | + |
| 418 | + # Compact: merge duplicates by summing |
| 419 | + kernel_compact! = kernel_compact_coo!(backend) |
| 420 | + kernel_compact!( |
| 421 | + rowind_C, |
| 422 | + colind_C, |
| 423 | + nzval_C, |
| 424 | + rowind_sorted, |
| 425 | + colind_sorted, |
| 426 | + nzval_sorted, |
| 427 | + write_indices, |
| 428 | + nnz_concat; |
| 429 | + ndrange = (nnz_concat,), |
| 430 | + ) |
390 | 431 |
|
391 | 432 | return DeviceSparseMatrixCOO(m, n, rowind_C, colind_C, nzval_C) |
392 | 433 | end |
|
0 commit comments