Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/src/darray.md
Original file line number Diff line number Diff line change
Expand Up @@ -704,7 +704,7 @@ From `LinearAlgebra`:
- `*` (Out-of-place Matrix-(Matrix/Vector) multiply)
- `mul!` (In-place Matrix-Matrix and Matrix-Vector multiply)
- `cholesky`/`cholesky!` (In-place/Out-of-place Cholesky factorization)
- `lu`/`lu!` (In-place/Out-of-place LU factorization (`NoPivot` only))
- `lu`/`lu!` (In-place/Out-of-place LU factorization (`NoPivot` and `RowMaximum` only))

From `AbstractFFTs`:
- `fft`/`fft!`
Expand Down
5 changes: 3 additions & 2 deletions src/Dagger.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ import SparseArrays: sprand, SparseMatrixCSC
import MemPool
import MemPool: DRef, FileRef, poolget, poolset

import Base: collect, reduce
import Base: collect, reduce, view

import LinearAlgebra
import LinearAlgebra: Adjoint, BLAS, Diagonal, Bidiagonal, Tridiagonal, LAPACK, LowerTriangular, PosDefException, Transpose, UpperTriangular, UnitLowerTriangular, UnitUpperTriangular, diagind, ishermitian, issymmetric
import LinearAlgebra: Adjoint, BLAS, Diagonal, Bidiagonal, Tridiagonal, LAPACK, LU, LowerTriangular, PosDefException, Transpose, UpperTriangular, UnitLowerTriangular, UnitUpperTriangular, Cholesky, diagind, ishermitian, issymmetric, I
import Random
import Random: AbstractRNG

Expand Down Expand Up @@ -120,6 +120,7 @@ include("array/sort.jl")
include("array/linalg.jl")
include("array/mul.jl")
include("array/cholesky.jl")
include("array/trsm.jl")
include("array/lu.jl")

import KernelAbstractions, Adapt
Expand Down
14 changes: 14 additions & 0 deletions src/array/alloc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,18 @@ function Base.zero(x::DArray{T,N}) where {T,N}
return _to_darray(a)
end

# Weird LinearAlgebra dispatch in `\` needs this.
# Allocates a zero-filled distributed vector of eltype `T` whose length is the
# larger of `size(B, 1)` and `n`, partitioned by `auto_blocks`.
function LinearAlgebra._zeros(::Type{T}, B::DVector, n::Integer) where T
    dims = (max(size(B, 1), n),)
    return zeros(auto_blocks(dims), T, dims)
end
# Matrix counterpart of the `DVector` method: allocates a zero-filled distributed
# matrix of eltype `T` with `max(size(B, 1), n)` rows and as many columns as `B`.
function LinearAlgebra._zeros(::Type{T}, B::DMatrix, n::Integer) where T
    dims = (max(size(B, 1), n), size(B, 2))
    return zeros(auto_blocks(dims), T, dims)
end

function Base.view(A::AbstractArray{T,N}, p::Blocks{N}) where {T,N}
d = ArrayDomain(Base.index_shape(A))
dc = partition(p, d)
Expand All @@ -192,3 +204,5 @@ function Base.view(A::AbstractArray{T,N}, p::Blocks{N}) where {T,N}
chunks = [tochunk(view(A, x.indexes...)) for x in dc]
return DArray(T, d, dc, chunks, p)
end
# Partitioned view of `A` where the block layout is chosen by `auto_blocks`
# from the size of `A`, then forwarded to the `view(A, ::Blocks)` method.
function Base.view(A::AbstractArray, ::AutoBlocks)
    blocks = auto_blocks(size(A))
    return view(A, blocks)
end
7 changes: 5 additions & 2 deletions src/array/darray.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import Base: ==, fetch

export DArray, DVector, DMatrix, Blocks, AutoBlocks
export DArray, DVector, DMatrix, DVecOrMat, Blocks, AutoBlocks
export distribute


Expand Down Expand Up @@ -146,6 +146,7 @@ const WrappedDMatrix{T} = WrappedDArray{T,2}
# One-dimensional case of `WrappedDArray`.
const WrappedDVector{T} = WrappedDArray{T,1}
# Two-dimensional `DArray`.
const DMatrix{T} = DArray{T,2}
# One-dimensional `DArray`.
const DVector{T} = DArray{T,1}
# Either a `DVector` or a `DMatrix` with element type `T`.
const DVecOrMat{T} = Union{DVector{T}, DMatrix{T}}

# mainly for backwards-compatibility
DArray{T, N}(domain, subdomains, chunks, partitioning, concat=cat) where {T,N} =
Expand Down Expand Up @@ -250,7 +251,9 @@ function Base.getindex(A::ColorArray{T,N}, idxs::NTuple{N,Int}) where {T,N}
if !haskey(A.seen_values, idxs)
chunk = A.A.chunks[sd_idx]
if chunk isa Chunk || isready(chunk)
value = A.seen_values[idxs] = Some(getindex(A.A, idxs))
value = A.seen_values[idxs] = allowscalar() do
Some(getindex(A.A, idxs))
end
else
# Show a placeholder instead
value = A.seen_values[idxs] = nothing
Expand Down
125 changes: 125 additions & 0 deletions src/array/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,128 @@ function LinearAlgebra.ishermitian(A::DArray{T,2}) where T

return all(fetch, to_check)
end

# Run `LAPACK.chkfinite` on every chunk of `A` inside a datadeps region, storing
# each chunk's result in its own `Ref`, and return whether all results are true.
function LinearAlgebra.LAPACK.chkfinite(A::DArray)
    tiles = A.chunks
    # One result slot per chunk, same shape as the chunk array
    flags = map(_ -> Ref(true), tiles)
    chkfinite!(flag, tile) = flag[] = LinearAlgebra.LAPACK.chkfinite(tile)
    Dagger.spawn_datadeps() do
        for i in eachindex(tiles)
            Dagger.@spawn chkfinite!(Out(flags[i]), In(tiles[i]))
        end
    end
    return all(flag -> flag[], flags)
end

# Build an `m`-by-`n` distributed identity matrix: a dense `Matrix{T}` identity
# is created first and then wrapped with the partitioning `IBlocks`.
function DMatrix{T}(::LinearAlgebra.UniformScaling, m::Int, n::Int, IBlocks::Blocks) where T
    dense = Matrix{T}(I, m, n)
    return DMatrix(dense, IBlocks)
end

# Same as the `(m, n)` constructor, but the dimensions are given as a tuple.
function DMatrix{T}(::LinearAlgebra.UniformScaling, size::Tuple, IBlocks::Blocks) where T
    dense = Matrix{T}(I, size)
    return DMatrix(dense, IBlocks)
end

# Inverse from an LU factorization of a `DMatrix`: start from a distributed
# identity with the same partitioning as the factors and solve against it in place.
function LinearAlgebra.inv(F::LU{T,<:DMatrix}) where T
    order = size(F, 1)
    result = DMatrix{T}(I, order, order, F.factors.partitioning)
    LinearAlgebra.ldiv!(F, result)
    return result
end

# Inverse of a lower-triangular wrapped `DMatrix`: solve against a distributed
# identity (eltype taken from `inv(oneunit(T))`) and re-wrap as `LowerTriangular`.
function LinearAlgebra.inv(A::LowerTriangular{T,<:DMatrix}) where T
    S = typeof(LinearAlgebra.inv(oneunit(T)))
    identity_tiles = DMatrix{S}(I, size(A), A.data.partitioning)
    LinearAlgebra.ldiv!(convert(AbstractArray{S}, A), identity_tiles)
    return LowerTriangular(identity_tiles)
end

# Inverse of an upper-triangular wrapped `DMatrix`: solve against a distributed
# identity (eltype taken from `inv(oneunit(T))`) and re-wrap as `UpperTriangular`.
function LinearAlgebra.inv(A::UpperTriangular{T,<:DMatrix}) where T
    S = typeof(LinearAlgebra.inv(oneunit(T)))
    identity_tiles = DMatrix{S}(I, size(A), A.data.partitioning)
    LinearAlgebra.ldiv!(convert(AbstractArray{S}, A), identity_tiles)
    return UpperTriangular(identity_tiles)
end

# Inverse of a unit-lower-triangular wrapped `DMatrix`: solve against a distributed
# identity (eltype taken from `inv(oneunit(T))`) and re-wrap as `UnitLowerTriangular`.
function LinearAlgebra.inv(A::UnitLowerTriangular{T,<:DMatrix}) where T
    S = typeof(LinearAlgebra.inv(oneunit(T)))
    identity_tiles = DMatrix{S}(I, size(A), A.data.partitioning)
    LinearAlgebra.ldiv!(convert(AbstractArray{S}, A), identity_tiles)
    return UnitLowerTriangular(identity_tiles)
end

# Inverse of a unit-upper-triangular wrapped `DMatrix`: solve against a distributed
# identity (eltype taken from `inv(oneunit(T))`) and re-wrap as `UnitUpperTriangular`.
function LinearAlgebra.inv(A::UnitUpperTriangular{T,<:DMatrix}) where T
    S = typeof(LinearAlgebra.inv(oneunit(T)))
    identity_tiles = DMatrix{S}(I, size(A), A.data.partitioning)
    LinearAlgebra.ldiv!(convert(AbstractArray{S}, A), identity_tiles)
    return UnitUpperTriangular(identity_tiles)
end

# Inverse of a square `DMatrix`.
#
# `checksquare` validates the shape and yields the order `n`. A distributed
# identity with the partitioning of `A` is then overwritten in place with the
# solution of `A * X = I`, using a factorization of `A`.
function LinearAlgebra.inv(A::DMatrix{T}) where T
    n = LinearAlgebra.checksquare(A)
    S = typeof(zero(T)/one(T))      # dimensionful
    S0 = typeof(zero(T)/oneunit(T)) # dimensionless
    dest = DMatrix{S0}(I, n, n, A.partitioning)
    # Qualified like every other LinearAlgebra call in this file; `factorize` is
    # not among the names explicitly imported from LinearAlgebra.
    F = LinearAlgebra.factorize(convert(AbstractMatrix{S}, A))
    LinearAlgebra.ldiv!(F, dest)
    return dest
end


# In-place solve with an LU factorization of a `DMatrix`: apply the row pivots
# to `B` (scalar indexing is enabled for that step), then run the unit-lower
# and the upper triangular solves on the shared `factors` storage.
function LinearAlgebra.ldiv!(A::LU{<:Any,<:DMatrix}, B::AbstractVecOrMat)
    allowscalar(true) do
        LinearAlgebra._apply_ipiv_rows!(A, B)
    end
    factors = A.factors
    LinearAlgebra.ldiv!(UnitLowerTriangular(factors), B)
    return LinearAlgebra.ldiv!(UpperTriangular(factors), B)
end

# In-place triangular solve `A \ B` where `A` wraps a `DMatrix`.
#
# The wrapper type of `A` selects the BLAS-style `uplo` and `diag` flags. A
# non-distributed `B` is first viewed as a partitioned array. Both operands are
# then passed through `maybe_copy_buffered` with square blocks of a common size,
# and the solve is delegated to `Dagger.trsv!` (vector `B`) or `Dagger.trsm!`
# (matrix `B`).
function LinearAlgebra.ldiv!(A::Union{LowerTriangular{<:Any,<:DMatrix},UnitLowerTriangular{<:Any,<:DMatrix},UpperTriangular{<:Any,<:DMatrix},UnitUpperTriangular{<:Any,<:DMatrix}}, B::AbstractVecOrMat)
    scale = one(eltype(A))
    trans_flag = 'N'
    diag_flag = A isa Union{UnitUpperTriangular,UnitLowerTriangular} ? 'U' : 'N'
    # `A` is always one of the four wrappers, so "not upper" means lower
    uplo_flag = A isa Union{UpperTriangular,UnitUpperTriangular} ? 'U' : 'L'

    if B isa DVecOrMat
        tiles_B = B
    elseif B isa AbstractMatrix
        tiles_B = view(B, A.data.partitioning)
    else
        tiles_B = view(B, AutoBlocks())
    end

    tiles_A = parent(A)
    if B isa AbstractVector
        bs = min(tiles_A.partitioning.blocksize..., tiles_B.partitioning.blocksize[1])
        return Dagger.maybe_copy_buffered(tiles_A => Blocks(bs, bs), tiles_B => Blocks(bs)) do a, b
            Dagger.trsv!(uplo_flag, trans_flag, diag_flag, scale, a, b)
        end
    else
        # `B` is an `AbstractVecOrMat`, so this branch is the matrix case
        bs = min(tiles_A.partitioning.blocksize...)
        return Dagger.maybe_copy_buffered(tiles_A => Blocks(bs, bs), tiles_B => Blocks(bs, bs)) do a, b
            Dagger.trsm!('L', uplo_flag, trans_flag, diag_flag, scale, a, b)
        end
    end
end

# Three-argument solve: copy `B` into `Y`, then solve in place on that copy.
function LinearAlgebra.ldiv!(Y::DArray, A::DMatrix, B::DArray)
    rhs = copyto!(Y, B)
    return LinearAlgebra.ldiv!(A, rhs)
end

# Solve with a general `DMatrix` by LU-factorizing it (default pivoting) and
# delegating to the `LU` solve.
function LinearAlgebra.ldiv!(A::DMatrix, B::DArray)
    F = LinearAlgebra.lu(A)
    return LinearAlgebra.ldiv!(F, B)
end

# Three-argument triangular solve: copy `B` into `C`, then solve in place on `C`.
function LinearAlgebra.ldiv!(C::DVecOrMat, A::Union{LowerTriangular{<:Any,<:DMatrix},UnitLowerTriangular{<:Any,<:DMatrix},UpperTriangular{<:Any,<:DMatrix},UnitUpperTriangular{<:Any,<:DMatrix}}, B::DVecOrMat)
    rhs = copyto!(C, B)
    return LinearAlgebra.ldiv!(A, rhs)
end

# In-place solve with a Cholesky factorization of a `DMatrix`; `B` is overwritten
# with the solution and returned.
function LinearAlgebra.ldiv!(C::Cholesky{<:Any,<:DMatrix}, B::DVecOrMat)
    # Forward solve on a scratch copy: L * y = B
    scratch = similar(B)
    y = copyto!(scratch, B)
    LinearAlgebra.ldiv!(C.L, y)

    # Backward solve in `B` itself: L' * x = y (the solve uses the upper factor `C.U`)
    copyto!(B, y)
    LinearAlgebra.ldiv!(C.U, B)

    return B
end
151 changes: 134 additions & 17 deletions src/array/lu.jl
Original file line number Diff line number Diff line change
@@ -1,31 +1,148 @@
function LinearAlgebra.lu(A::DMatrix{T}, ::LinearAlgebra.NoPivot; check::Bool=true) where T
# Entry point supplying the default pivoting strategy (`RowMaximum`); the call
# below dispatches on the concrete type of `pivot`.
function LinearAlgebra.lu(
    A::DMatrix{T},
    pivot::Union{LinearAlgebra.RowMaximum,LinearAlgebra.NoPivot} = LinearAlgebra.RowMaximum();
    check::Bool=true,
    allowsingular::Bool=false,
) where {T<:LinearAlgebra.BlasFloat}
    return LinearAlgebra.lu(A, pivot; check, allowsingular)
end

# In-place entry point supplying the default pivoting strategy (`RowMaximum`).
# It must forward to `lu!`, not `lu`: forwarding to the out-of-place `lu` would
# factorize a copy and leave `A` untouched, breaking the in-place contract of
# `lu!(A)`. The call below dispatches on the concrete type of `pivot`.
function LinearAlgebra.lu!(
    A::DMatrix{T},
    pivot::Union{LinearAlgebra.RowMaximum,LinearAlgebra.NoPivot} = LinearAlgebra.RowMaximum();
    check::Bool=true,
    allowsingular::Bool=false,
) where {T<:LinearAlgebra.BlasFloat}
    return LinearAlgebra.lu!(A, pivot; check, allowsingular)
end

# Out-of-place LU factorization without pivoting: copy `A` to the LU element
# type and factorize the copy in place. Both `check` and `allowsingular` are
# forwarded, matching the `RowMaximum` method.
function LinearAlgebra.lu(A::DMatrix{T}, ::LinearAlgebra.NoPivot; check::Bool = true, allowsingular::Bool = false) where {T<:LinearAlgebra.BlasFloat}
    A_copy = LinearAlgebra._lucopy(A, LinearAlgebra.lutype(T))
    return LinearAlgebra.lu!(A_copy, LinearAlgebra.NoPivot(); check, allowsingular)
end
function LinearAlgebra.lu!(A::DMatrix{T}, ::LinearAlgebra.NoPivot; check::Bool=true) where T
function LinearAlgebra.lu!(A::DMatrix{T}, ::LinearAlgebra.NoPivot; check::Bool = true, allowsingular::Bool = false) where {T<:LinearAlgebra.BlasFloat}
check && LinearAlgebra.LAPACK.chkfinite(A)

zone = one(T)
mzone = -one(T)
Ac = A.chunks
mt, nt = size(Ac)
iscomplex = T <: Complex
trans = iscomplex ? 'C' : 'T'

Dagger.spawn_datadeps() do
for k in range(1, min(mt, nt))
Dagger.@spawn LinearAlgebra.generic_lufact!(InOut(Ac[k, k]), LinearAlgebra.NoPivot(); check)
for m in range(k+1, mt)
Dagger.@spawn BLAS.trsm!('R', 'U', 'N', 'N', zone, In(Ac[k, k]), InOut(Ac[m, k]))
end
for n in range(k+1, nt)
Dagger.@spawn BLAS.trsm!('L', 'L', 'N', 'U', zone, In(Ac[k, k]), InOut(Ac[k, n]))

mb, nb = A.partitioning.blocksize

min_mb_nb = min(mb, nb)
maybe_copy_buffered(A => Blocks(min_mb_nb, min_mb_nb)) do A
Ac = A.chunks
mt, nt = size(Ac)

Dagger.spawn_datadeps() do
for k in range(1, min(mt, nt))
Dagger.@spawn LinearAlgebra.generic_lufact!(InOut(Ac[k, k]), LinearAlgebra.NoPivot(); check, allowsingular)
for m in range(k+1, mt)
Dagger.@spawn BLAS.gemm!('N', 'N', mzone, In(Ac[m, k]), In(Ac[k, n]), zone, InOut(Ac[m, n]))
Dagger.@spawn BLAS.trsm!('R', 'U', 'N', 'N', zone, In(Ac[k, k]), InOut(Ac[m, k]))
end
for n in range(k+1, nt)
Dagger.@spawn BLAS.trsm!('L', 'L', 'N', 'U', zone, In(Ac[k, k]), InOut(Ac[k, n]))
for m in range(k+1, mt)
Dagger.@spawn BLAS.gemm!('N', 'N', mzone, In(Ac[m, k]), In(Ac[k, n]), zone, InOut(Ac[m, n]))
end
end
end
end

check && LinearAlgebra._check_lu_success(0, allowsingular)
end

ipiv = DVector([i for i in 1:min(size(A)...)])

return LinearAlgebra.LU{T,DMatrix{T},DVector{Int}}(A, ipiv, 0)
end

"""
    searchmax_pivot!(piv_idx, piv_val, A, offset=0)

Find the entry of `A` selected by `BLAS.iamax` over its flattened elements.
Store its linear index plus `offset` in `piv_idx[1]` and its value in
`piv_val[1]`; return that value.
"""
function searchmax_pivot!(piv_idx::AbstractVector{Int}, piv_val::AbstractVector{T}, A::AbstractMatrix{T}, offset::Int=0) where T
    local_idx = LinearAlgebra.BLAS.iamax(A[:])
    piv_idx[1] = offset + local_idx
    candidate = A[local_idx]
    piv_val[1] = candidate
    return candidate
end

"""
    update_ipiv!(ipivl, info, piv_idx, piv_val, k, nb)

Pick the best candidate in `piv_val` using `BLAS.iamax` and store its row
index in `ipivl[1]`, computed as `(best + k - 2) * nb + piv_idx[best]`.
If the candidate's magnitude (`|re| + |im|` for complex values) is within
`eps(real(T))` of zero, `info[]` is set to `k`.
"""
function update_ipiv!(ipivl::AbstractVector{Int}, info::Ref{Int}, piv_idx::AbstractVector{Int}, piv_val::AbstractVector{T}, k::Int, nb::Int) where T
    best = LinearAlgebra.BLAS.iamax(piv_val)
    candidate = piv_val[best]
    magnitude = if candidate isa Real
        abs(candidate)
    else
        abs(real(candidate)) + abs(imag(candidate))
    end
    if isapprox(magnitude, zero(T); atol=eps(real(T)))
        info[] = k
    end
    row = (best + k - 2) * nb + piv_idx[best]
    ipivl[1] = row
    return row
end

"""
    swaprows_panel!(A, M, ipivl, m, p, nb)

Decode the row index `ipivl[1]` into a tile number and a row within that tile
(tiles hold `nb` rows). When the tile number equals `m`, exchange row `p` of
`A` with that row of `M`; otherwise do nothing.
"""
function swaprows_panel!(A::AbstractMatrix{T}, M::AbstractMatrix{T}, ipivl::AbstractVector{Int}, m::Int, p::Int, nb::Int) where T
    tile0, row0 = divrem(ipivl[1] - 1, nb)
    tile = tile0 + 1
    row = row0 + 1
    m == tile || return nothing
    A[p, :], M[row, :] = M[row, :], A[p, :]
end

"""
    update_panel!(M, A, p)

Scale column `p` of `M` by `1 / A[p, p]`, then subtract from the columns of
`M` to the right of `p` the rank-one product of that scaled column with the
part of row `p` of `A` to the right of the diagonal.
"""
function update_panel!(M::AbstractMatrix{T}, A::AbstractMatrix{T}, p::Int) where T
    pivot_col = view(M, :, p)
    inv_pivot = one(T) / A[p, p]
    LinearAlgebra.BLAS.scal!(inv_pivot, pivot_col)
    # `ger!` receives the conjugated row, as in the original call
    row_tail = conj.(view(A, p, p+1:size(A, 2)))
    trailing = view(M, :, p+1:size(M, 2))
    return LinearAlgebra.BLAS.ger!(-one(T), pivot_col, row_tail, trailing)
end

"""
    swaprows_trail!(A, M, ipiv, m, nb)

For every entry of `ipiv`, decode the row index into a tile number and a row
within that tile (tiles hold `nb` rows). When the tile number equals `m`,
exchange row `p` of `A` (`p` being the entry's position in `ipiv`) with that
row of `M`.
"""
function swaprows_trail!(A::AbstractMatrix{T}, M::AbstractMatrix{T}, ipiv::AbstractVector{Int}, m::Int, nb::Int) where T
    for (p, target) in pairs(ipiv)
        tile0, row0 = divrem(target - 1, nb)
        m == tile0 + 1 || continue
        row = row0 + 1
        A[p, :], M[row, :] = M[row, :], A[p, :]
    end
    return nothing
end

# Out-of-place LU factorization with row-maximum pivoting: copy `A` to the LU
# element type and factorize the copy in place.
function LinearAlgebra.lu(A::DMatrix{T}, ::LinearAlgebra.RowMaximum; check::Bool = true, allowsingular::Bool = false) where {T<:LinearAlgebra.BlasFloat}
    lu_eltype = LinearAlgebra.lutype(T)
    work = LinearAlgebra._lucopy(A, lu_eltype)
    return LinearAlgebra.lu!(work, LinearAlgebra.RowMaximum(); check=check, allowsingular=allowsingular)
end
# In-place tiled LU factorization of a `DMatrix` with row-maximum pivoting.
#
# Keywords:
# - `check`: when true, run `LAPACK.chkfinite` on `A` first and call
#   `_check_lu_success` with the recorded `info` after the tasks are spawned.
# - `allowsingular`: forwarded to `_check_lu_success`.
#
# Returns a `LinearAlgebra.LU` wrapping `A`, the pivot vector `ipiv` (a
# `DVector{Int}`) and `info[]`.
function LinearAlgebra.lu!(A::DMatrix{T}, ::LinearAlgebra.RowMaximum; check::Bool = true, allowsingular::Bool = false) where {T<:LinearAlgebra.BlasFloat}
check && LinearAlgebra.LAPACK.chkfinite(A)

# Scalars for the BLAS updates below
zone = one(T)
mzone = -one(T)

# Set to a tile index by `update_ipiv!` when a near-zero pivot is found
info = Ref(0)

mb, nb = A.partitioning.blocksize
min_mb_nb = min(mb, nb)
# Assigned inside the closure, read by the final `return`
local ipiv
# The closure receives `A` with square blocks of side `min_mb_nb`
# NOTE(review): presumably `maybe_copy_buffered` repartitions via a temporary copy
# and writes results back — confirm against its definition.
maybe_copy_buffered(A => Blocks(min_mb_nb, min_mb_nb)) do A
Ac = A.chunks
# Re-read the block size of the (possibly repartitioned) `A`
mb, nb = A.partitioning.blocksize
mt, nt = size(Ac)
m, n = size(A)

# Pivot vector starts as 1:min(m, n), partitioned in blocks of `nb`
ipiv = DVector(collect(1:min(m, n)), Blocks(nb))
ipivc = ipiv.chunks

# Per-tile-row scratch: best pivot candidate (index and value) in the current column
max_piv_idx = zeros(Int, mt)
max_piv_val = zeros(T, mt)

Dagger.spawn_datadeps() do
for k in 1:min(mt, nt)
# Factor the column panel of tile-column k, one column p at a time
for p in 1:min(nb, m-(k-1)*nb, n-(k-1)*nb)
# Candidate from the diagonal tile: only rows p and below, so the index is offset by p-1
Dagger.@spawn searchmax_pivot!(Out(view(max_piv_idx, k:k)), Out(view(max_piv_val, k:k)), In(view(Ac[k,k],p:min(nb,m-(k-1)*nb),p:p)), p-1)
# Candidates from the tiles below the diagonal tile
for i in k+1:mt
Dagger.@spawn searchmax_pivot!(Out(view(max_piv_idx, i:i)), Out(view(max_piv_val, i:i)), In(view(Ac[i,k],:,p:p)))
end
# Choose among the candidates of tiles k:mt and record the pivot row for column p
Dagger.@spawn update_ipiv!(InOut(view(ipivc[k],p:p)), InOut(info), In(view(max_piv_idx, k:mt)), In(view(max_piv_val, k:mt)), k, nb)
# Swap row p of the diagonal tile with the pivot row; only the tile owning it acts
for i in k:mt
Dagger.@spawn swaprows_panel!(InOut(Ac[k, k]), InOut(Ac[i, k]), In(view(ipivc[k],p:p)), i, p, nb)
end
# Scale column p and update the columns to its right: first the rows below p
# in the diagonal tile (if any), then every tile below it
if length(p+1:min(nb,m-(k-1)*nb)) > 0
Dagger.@spawn update_panel!(InOut(view(Ac[k,k],p+1:min(nb,m-(k-1)*nb),:)), In(Ac[k,k]), p)
end
for i in k+1:mt
Dagger.@spawn update_panel!(InOut(Ac[i, k]), In(Ac[k,k]), p)
end
end
# Apply the row swaps of panel k to every other tile-column (left and right of k)
for j in Iterators.flatten((1:k-1, k+1:nt))
for i in k:mt
Dagger.@spawn swaprows_trail!(InOut(Ac[k, j]), InOut(Ac[i, j]), In(ipivc[k]), i, mb)
end
end
# Triangular solve on tile-row k, then update the trailing tiles
for j in k+1:nt
Dagger.@spawn BLAS.trsm!('L', 'L', 'N', 'U', zone, In(Ac[k, k]), InOut(Ac[k, j]))
for i in k+1:mt
Dagger.@spawn BLAS.gemm!('N', 'N', mzone, In(Ac[i, k]), In(Ac[k, j]), zone, InOut(Ac[i, j]))
end
end
end
end

check && LinearAlgebra._check_lu_success(info[], allowsingular)
end

return LinearAlgebra.LU{T,DMatrix{T},DVector{Int}}(A, ipiv, info[])
end
Loading
Loading