11function sparse_release_matrix_handle(A:: oneAbstractSparseMatrix )
2- queue = global_queue(context(A. nzVal), device(A. nzVal))
3- handle_ptr = Ref{matrix_handle_t}(A. handle)
4- onemklXsparse_release_matrix_handle(sycl_queue(queue), handle_ptr)
2+ return if A. handle != = nothing
3+ try
4+ queue = global_queue(context(A. nzVal), device(A. nzVal))
5+ handle_ptr = Ref{matrix_handle_t}(A. handle)
6+ onemklXsparse_release_matrix_handle(sycl_queue(queue), handle_ptr)
7+ # Only synchronize after successful release to ensure completion
8+ synchronize(queue)
9+ catch err
10+ # Don't let finalizer errors crash the program
11+ @warn " Error releasing sparse matrix handle" exception = err
12+ end
13+ end
514end
615
716for (fname, elty, intty) in ((:onemklSsparse_set_csr_data , :Float32 , :Int32),
@@ -13,20 +22,55 @@ for (fname, elty, intty) in ((:onemklSsparse_set_csr_data , :Float32 , :Int3
1322 (:onemklZsparse_set_csr_data , :ComplexF64, :Int32),
1423 (:onemklZsparse_set_csr_data_64, :ComplexF64, :Int64))
1524 @eval begin
16- function oneSparseMatrixCSR(A:: SparseMatrixCSC{$elty, $intty} )
25+
26+ function oneSparseMatrixCSR(
27+ rowPtr:: oneVector{$intty} , colVal:: oneVector{$intty} ,
28+ nzVal:: oneVector{$elty} , dims:: NTuple{2, Int}
29+ )
30+ handle_ptr = Ref{matrix_handle_t}()
31+ onemklXsparse_init_matrix_handle(handle_ptr)
32+ m, n = dims
33+ nnzA = length(nzVal)
34+ queue = global_queue(context(nzVal), device(nzVal))
35+ # Don't update handle if matrix is empty
36+ if m != 0 && n != 0
37+ $ fname(sycl_queue(queue), handle_ptr[], m, n, ' O' , rowPtr, colVal, nzVal)
38+ dA = oneSparseMatrixCSR{$ elty, $ intty}(handle_ptr[], rowPtr, colVal, nzVal, (m, n), nnzA)
39+ finalizer(sparse_release_matrix_handle, dA)
40+ else
41+ dA = oneSparseMatrixCSR{$ elty, $ intty}(nothing , rowPtr, colVal, nzVal, (m, n), nnzA)
42+ end
43+ return dA
44+ end
45+
46+ function oneSparseMatrixCSC(
47+ colPtr:: oneVector{$intty} , rowVal:: oneVector{$intty} ,
48+ nzVal:: oneVector{$elty} , dims:: NTuple{2, Int}
49+ )
50+ queue = global_queue(context(nzVal), device(nzVal))
1751 handle_ptr = Ref{matrix_handle_t}()
1852 onemklXsparse_init_matrix_handle(handle_ptr)
53+ m, n = dims
54+ nnzA = length(nzVal)
55+ # Don't update handle if matrix is empty
56+ if m != 0 && n != 0
57+ $ fname(sycl_queue(queue), handle_ptr[], n, m, ' O' , colPtr, rowVal, nzVal) # CSC of A is CSR of Aᵀ
58+ dA = oneSparseMatrixCSC{$ elty, $ intty}(handle_ptr[], colPtr, rowVal, nzVal, (m, n), nnzA)
59+ finalizer(sparse_release_matrix_handle, dA)
60+ else
61+ dA = oneSparseMatrixCSC{$ elty, $ intty}(nothing , colPtr, rowVal, nzVal, (m, n), nnzA)
62+ end
63+ return dA
64+ end
65+
66+
67+ function oneSparseMatrixCSR(A:: SparseMatrixCSC{$elty, $intty} )
1968 m, n = size(A)
2069 At = SparseMatrixCSC(A |> transpose)
2170 rowPtr = oneVector{$ intty}(At. colptr)
2271 colVal = oneVector{$ intty}(At. rowval)
2372 nzVal = oneVector{$ elty}(At. nzval)
24- nnzA = length(At. nzval)
25- queue = global_queue(context(nzVal), device())
26- $ fname(sycl_queue(queue), handle_ptr[], m, n, ' O' , rowPtr, colVal, nzVal)
27- dA = oneSparseMatrixCSR{$ elty, $ intty}(handle_ptr[], rowPtr, colVal, nzVal, (m,n), nnzA)
28- finalizer(sparse_release_matrix_handle, dA)
29- return dA
73+ return oneSparseMatrixCSR(rowPtr, colVal, nzVal, (m, n))
3074 end
3175
3276 function SparseArrays. SparseMatrixCSC(A:: oneSparseMatrixCSR{$elty, $intty} )
@@ -37,18 +81,11 @@ for (fname, elty, intty) in ((:onemklSsparse_set_csr_data , :Float32 , :Int3
3781 end
3882
3983 function oneSparseMatrixCSC(A:: SparseMatrixCSC{$elty, $intty} )
40- handle_ptr = Ref{matrix_handle_t}()
41- onemklXsparse_init_matrix_handle(handle_ptr)
4284 m, n = size(A)
4385 colPtr = oneVector{$ intty}(A. colptr)
4486 rowVal = oneVector{$ intty}(A. rowval)
4587 nzVal = oneVector{$ elty}(A. nzval)
46- nnzA = length(A. nzval)
47- queue = global_queue(context(nzVal), device())
48- $ fname(sycl_queue(queue), handle_ptr[], n, m, ' O' , colPtr, rowVal, nzVal) # CSC of A is CSR of Aᵀ
49- dA = oneSparseMatrixCSC{$ elty, $ intty}(handle_ptr[], colPtr, rowVal, nzVal, (m,n), nnzA)
50- finalizer(sparse_release_matrix_handle, dA)
51- return dA
88+ return oneSparseMatrixCSC(colPtr, rowVal, nzVal, (m, n))
5289 end
5390
5491 function SparseArrays. SparseMatrixCSC(A:: oneSparseMatrixCSC{$elty, $intty} )
@@ -77,10 +114,14 @@ for (fname, elty, intty) in ((:onemklSsparse_set_coo_data , :Float32 , :Int3
77114 colInd = oneVector{$ intty}(col)
78115 nzVal = oneVector{$ elty}(val)
79116 nnzA = length(val)
80- queue = global_queue(context(nzVal), device())
81- $ fname(sycl_queue(queue), handle_ptr[], m, n, nnzA, ' O' , rowInd, colInd, nzVal)
82- dA = oneSparseMatrixCOO{$ elty, $ intty}(handle_ptr[], rowInd, colInd, nzVal, (m,n), nnzA)
83- finalizer(sparse_release_matrix_handle, dA)
117+ queue = global_queue(context(nzVal), device(nzVal))
118+ if m != 0 && n != 0
119+ $ fname(sycl_queue(queue), handle_ptr[], m, n, nnzA, ' O' , rowInd, colInd, nzVal)
120+ dA = oneSparseMatrixCOO{$ elty, $ intty}(handle_ptr[], rowInd, colInd, nzVal, (m, n), nnzA)
121+ finalizer(sparse_release_matrix_handle, dA)
122+ else
123+ dA = oneSparseMatrixCOO{$ elty, $ intty}(nothing , rowInd, colInd, nzVal, (m, n), nnzA)
124+ end
84125 return dA
85126 end
86127
@@ -105,7 +146,7 @@ for SparseMatrix in (:oneSparseMatrixCSR, :oneSparseMatrixCOO)
105146 beta:: Number ,
106147 y:: oneStridedVector{$elty} )
107148
108- queue = global_queue(context(x), device())
149+ queue = global_queue(context(x), device(x ))
109150 $ fname(sycl_queue(queue), trans, alpha, A. handle, x, beta, y)
110151 y
111152 end
@@ -140,8 +181,11 @@ for SparseMatrix in (:oneSparseMatrixCSC,)
140181 beta:: Number ,
141182 y:: oneStridedVector{$elty} )
142183
143- queue = global_queue(context(x), device())
144- $ fname(sycl_queue(queue), flip_trans(trans), alpha, A. handle, x, beta, y)
184+ queue = global_queue(context(x), device(x))
185+ m, n = size(A)
186+ if m != 0 && n != 0
187+ $ fname(sycl_queue(queue), flip_trans(trans), alpha, A. handle, x, beta, y)
188+ end
145189 y
146190 end
147191 end
@@ -173,7 +217,7 @@ for SparseMatrix in (:oneSparseMatrixCSC,)
173217 beta = conj(beta)
174218 end
175219
176- queue = global_queue(context(x), device())
220+ queue = global_queue(context(x), device(x ))
177221 $ fname(sycl_queue(queue), flip_trans(trans), alpha, A. handle, x, beta, y)
178222
179223 if trans == ' C'
@@ -217,7 +261,7 @@ for (fname, elty) in ((:onemklSsparse_gemm, :Float32),
217261 nrhs = size(B, 2 )
218262 ldb = max(1 ,stride(B,2 ))
219263 ldc = max(1 ,stride(C,2 ))
220- queue = global_queue(context(C), device())
264+ queue = global_queue(context(C), device(C ))
221265 $ fname(sycl_queue(queue), ' C' , transa, transb, alpha, A. handle, B, nrhs, ldb, beta, C, ldc)
222266 C
223267 end
@@ -254,7 +298,7 @@ for (fname, elty) in ((:onemklSsparse_gemm, :Float32),
254298 nrhs = size(B, 2 )
255299 ldb = max(1 ,stride(B,2 ))
256300 ldc = max(1 ,stride(C,2 ))
257- queue = global_queue(context(C), device())
301+ queue = global_queue(context(C), device(C ))
258302 $ fname(sycl_queue(queue), ' C' , flip_trans(transa), transb, alpha, A. handle, B, nrhs, ldb, beta, C, ldc)
259303 C
260304 end
@@ -289,7 +333,7 @@ for (fname, elty) in (
289333 nrhs = size(B, 2 )
290334 ldb = max(1 , stride(B, 2 ))
291335 ldc = max(1 , stride(C, 2 ))
292- queue = global_queue(context(C), device())
336+ queue = global_queue(context(C), device(C ))
293337
294338 # Use identity: conj(C_new) = conj(alpha) * S * conj(opB(B)) + conj(beta) * conj(C)
295339 # Prepare conj(C) in-place and conj(B) into a temporary if needed
@@ -359,7 +403,7 @@ for (fname, elty) in ((:onemklSsparse_symv, :Float32),
359403 beta:: Number ,
360404 y:: oneStridedVector{$elty} )
361405
362- queue = global_queue(context(y), device())
406+ queue = global_queue(context(y), device(y ))
363407 $ fname(sycl_queue(queue), uplo, alpha, A. handle, x, beta, y)
364408 y
365409 end
@@ -379,7 +423,7 @@ for (fname, elty) in ((:onemklSsparse_symv, :Float32),
379423 beta:: Number ,
380424 y:: oneStridedVector{$elty} )
381425
382- queue = global_queue(context(y), device())
426+ queue = global_queue(context(y), device(y ))
383427 $ fname(sycl_queue(queue), flip_uplo(uplo), alpha, A. handle, x, beta, y)
384428 y
385429 end
@@ -400,7 +444,7 @@ for (fname, elty) in ((:onemklSsparse_trmv, :Float32),
400444 beta:: Number ,
401445 y:: oneStridedVector{$elty} )
402446
403- queue = global_queue(context(y), device())
447+ queue = global_queue(context(y), device(y ))
404448 $ fname(sycl_queue(queue), uplo, trans, diag, alpha, A. handle, x, beta, y)
405449 y
406450 end
@@ -442,7 +486,7 @@ for (fname, elty) in (
442486 " Convert to oneSparseMatrixCSR format instead."
443487 )
444488 )
445- queue = global_queue(context(y), device())
489+ queue = global_queue(context(y), device(y ))
446490 $ fname(sycl_queue(queue), uplo, flip_trans(trans), diag, alpha, A. handle, x, beta, y)
447491 return y
448492 end
@@ -475,7 +519,7 @@ for (fname, elty) in ((:onemklSsparse_trsv, :Float32),
475519 x:: oneStridedVector{$elty} ,
476520 y:: oneStridedVector{$elty} )
477521
478- queue = global_queue(context(y), device())
522+ queue = global_queue(context(y), device(y ))
479523 $ fname(sycl_queue(queue), uplo, trans, diag, alpha, A. handle, x, y)
480524 y
481525 end
@@ -512,7 +556,7 @@ for (fname, elty) in (
512556 " Convert to oneSparseMatrixCSR format instead."
513557 )
514558 )
515- queue = global_queue(context(y), device())
559+ queue = global_queue(context(y), device(y ))
516560 onemklXsparse_optimize_trsv(sycl_queue(queue), uplo, flip_trans(trans), diag, A. handle)
517561 return A
518562 end
@@ -555,7 +599,7 @@ for (fname, elty) in ((:onemklSsparse_trsm, :Float32),
555599 nrhs = size(X, 2 )
556600 ldx = max(1 ,stride(X,2 ))
557601 ldy = max(1 ,stride(Y,2 ))
558- queue = global_queue(context(Y), device())
602+ queue = global_queue(context(Y), device(Y ))
559603 $ fname(sycl_queue(queue), ' C' , transA, transX, uplo, diag, alpha, A. handle, X, nrhs, ldx, Y, ldy)
560604 Y
561605 end
@@ -614,7 +658,7 @@ for (fname, elty) in (
614658 nrhs = size(X, 2 )
615659 ldx = max(1 , stride(X, 2 ))
616660 ldy = max(1 , stride(Y, 2 ))
617- queue = global_queue(context(Y), device())
661+ queue = global_queue(context(Y), device(Y ))
618662 $ fname(sycl_queue(queue), ' C' , flip_trans(transA), transX, uplo, diag, alpha, A. handle, X, nrhs, ldx, Y, ldy)
619663 return Y
620664 end
0 commit comments