Skip to content

Commit 778e354

Browse files
Fix COO addition to merge duplicates and remove unnecessary test dependency
Co-authored-by: albertomercurio <61953577+albertomercurio@users.noreply.github.com>
1 parent 0ced574 commit 778e354

File tree

3 files changed

+111
-19
lines changed

3 files changed

+111
-19
lines changed

src/matrix_coo/matrix_coo.jl

Lines changed: 59 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -334,8 +334,8 @@ end
334334
Add two sparse matrices in COO format. Both matrices must have the same dimensions
335335
and be on the same backend (device).
336336
337-
The result is a COO matrix with entries from both A and B concatenated. Note that
338-
duplicate entries (same row and column) are not combined, which is valid for COO format.
337+
The result is a COO matrix with entries from both A and B properly merged,
338+
with duplicate entries (same row and column) combined by summing their values.
339339
340340
# Examples
341341
```jldoctest
@@ -366,27 +366,68 @@ function Base.:+(A::DeviceSparseMatrixCOO, B::DeviceSparseMatrixCOO)
366366
throw(ArgumentError("Both matrices must have the same backend"))
367367

368368
m, n = size(A)
369+
Ti = eltype(getrowind(A))
369370
Tv = promote_type(eltype(nonzeros(A)), eltype(nonzeros(B)))
370371

371372
# Concatenate the coordinate arrays
372373
nnz_A = nnz(A)
373374
nnz_B = nnz(B)
374-
nnz_total = nnz_A + nnz_B
375-
376-
# Allocate result arrays
377-
rowind_C = similar(getrowind(A), nnz_total)
378-
colind_C = similar(getcolind(A), nnz_total)
379-
nzval_C = similar(nonzeros(A), Tv, nnz_total)
380-
381-
# Copy entries from A
382-
rowind_C[1:nnz_A] .= getrowind(A)
383-
colind_C[1:nnz_A] .= getcolind(A)
384-
nzval_C[1:nnz_A] .= nonzeros(A)
385-
386-
# Copy entries from B
387-
rowind_C[(nnz_A+1):end] .= getrowind(B)
388-
colind_C[(nnz_A+1):end] .= getcolind(B)
389-
nzval_C[(nnz_A+1):end] .= nonzeros(B)
375+
nnz_concat = nnz_A + nnz_B
376+
377+
# Allocate concatenated arrays
378+
rowind_concat = similar(getrowind(A), nnz_concat)
379+
colind_concat = similar(getcolind(A), nnz_concat)
380+
nzval_concat = similar(nonzeros(A), Tv, nnz_concat)
381+
382+
# Copy entries from A and B
383+
rowind_concat[1:nnz_A] .= getrowind(A)
384+
colind_concat[1:nnz_A] .= getcolind(A)
385+
nzval_concat[1:nnz_A] .= nonzeros(A)
386+
rowind_concat[(nnz_A+1):end] .= getrowind(B)
387+
colind_concat[(nnz_A+1):end] .= getcolind(B)
388+
nzval_concat[(nnz_A+1):end] .= nonzeros(B)
389+
390+
# Sort by (row, col) using keys similar to COO->CSC conversion
391+
backend = backend_A
392+
keys = similar(rowind_concat, Ti, nnz_concat)
393+
kernel_make_keys! = kernel_make_csc_keys!(backend)
394+
kernel_make_keys!(keys, rowind_concat, colind_concat, m; ndrange = (nnz_concat,))
395+
396+
# Sort using AcceleratedKernels
397+
perm = _sortperm_AK(keys)
398+
399+
# Apply permutation to get sorted arrays
400+
rowind_sorted = rowind_concat[perm]
401+
colind_sorted = colind_concat[perm]
402+
nzval_sorted = nzval_concat[perm]
403+
404+
# Mark unique entries (first occurrence of each (row, col) pair)
405+
keep_mask = similar(rowind_sorted, Bool, nnz_concat)
406+
kernel_mark! = kernel_mark_unique_coo!(backend)
407+
kernel_mark!(keep_mask, rowind_sorted, colind_sorted, nnz_concat; ndrange = (nnz_concat,))
408+
409+
# Compute write indices using cumsum
410+
write_indices = _cumsum_AK(keep_mask)
411+
nnz_final = allowed_getindex(write_indices, nnz_concat)
412+
413+
# Allocate final arrays
414+
rowind_C = similar(getrowind(A), nnz_final)
415+
colind_C = similar(getcolind(A), nnz_final)
416+
nzval_C = similar(nonzeros(A), Tv, nnz_final)
417+
418+
# Compact: merge duplicates by summing
419+
kernel_compact! = kernel_compact_coo!(backend)
420+
kernel_compact!(
421+
rowind_C,
422+
colind_C,
423+
nzval_C,
424+
rowind_sorted,
425+
colind_sorted,
426+
nzval_sorted,
427+
write_indices,
428+
nnz_concat;
429+
ndrange = (nnz_concat,),
430+
)
390431

391432
return DeviceSparseMatrixCOO(m, n, rowind_C, colind_C, nzval_C)
392433
end

src/matrix_coo/matrix_coo_kernels.jl

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,3 +181,55 @@ end
181181
nzval_C[idx] = val_A * val_B
182182
end
183183
end
184+
185+
# Kernel for marking the first occurrence of each (row, col) pair in sorted COO format.
# Only adjacent entries are compared, so equal pairs are recognised only when they are
# next to each other; the caller in `Base.:+` permutes the entries by sorted key first.
186+
# Writes keep_mask[i] = true when entry i differs from entry i-1 in row or column (entry 1
# is always kept) and false when it repeats entry i-1. The mask is filled in place; the
# kernel itself has no return value.
187+
@kernel inbounds=true function kernel_mark_unique_coo!(
188+
keep_mask,
189+
@Const(rowind),
190+
@Const(colind),
191+
@Const(nnz_total),
192+
)
193+
i = @index(Global)
194+
195+
if i == 1
196+
# Always keep the first entry
197+
keep_mask[i] = true
198+
# Work-items with i > nnz_total fall through and write nothing
elseif i <= nnz_total
199+
# Keep if different from previous entry
200+
keep_mask[i] = (rowind[i] != rowind[i-1] || colind[i] != colind[i-1])
201+
end
202+
end
203+
204+
# Kernel for compacting COO by summing duplicate entries
# Each work-item i that starts a run (i == 1, or entry i differs from entry i-1 in row or
# column) copies its row/column to slot write_indices[i] of the outputs and stores the sum
# of the values of the consecutive entries sharing that (row, col). Work-items that do not
# start a run write nothing.
# In `Base.:+`, write_indices is the cumulative sum of the keep mask, so run starts
# presumably map to output slots 1, 2, ... (assumes an inclusive cumsum; confirm).
205+
@kernel inbounds=true function kernel_compact_coo!(
206+
rowind_out,
207+
colind_out,
208+
nzval_out,
209+
@Const(rowind_in),
210+
@Const(colind_in),
211+
@Const(nzval_in),
212+
@Const(write_indices),
213+
@Const(nnz_in),
214+
)
215+
i = @index(Global)
216+
217+
if i <= nnz_in
218+
# Read for every in-range i, but only used by the run-start branch below
out_idx = write_indices[i]
219+
220+
# If this is a new entry (or first of duplicates), write it
# (same predicate that kernel_mark_unique_coo! stores in keep_mask, recomputed here)
221+
if i == 1 || (rowind_in[i] != rowind_in[i-1] || colind_in[i] != colind_in[i-1])
222+
rowind_out[out_idx] = rowind_in[i]
223+
colind_out[out_idx] = colind_in[i]
224+
225+
# Sum all duplicates
# Sequential forward scan by this work-item; stops at the first entry whose
# (row, col) differs from entry i, or after the last entry (j > nnz_in)
226+
val_sum = nzval_in[i]
227+
j = i + 1
228+
while j <= nnz_in && rowind_in[j] == rowind_in[i] && colind_in[j] == colind_in[i]
229+
val_sum += nzval_in[j]
230+
j += 1
231+
end
232+
nzval_out[out_idx] = val_sum
233+
end
234+
end
235+
end

test/Project.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
[deps]
22
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
33
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
4-
DeviceSparseArrays = "da3fe0eb-88a8-4d14-ae1a-857c283e9c70"
54
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
65
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
76
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

0 commit comments

Comments
 (0)