Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
b43ce1b
Change nthreads default to nothing in solver
matthieugomez Jan 28, 2026
d54b08a
Update AbstractFixedEffectSolver to use optional nthreads
matthieugomez Jan 29, 2026
c5b7c2c
Allow optional nthreads parameter in AbstractFixedEffectSolver
matthieugomez Jan 29, 2026
116cacd
Change nthreads parameter to 'nothing' in solver functions
matthieugomez Jan 29, 2026
c82f58b
Bump version from 2.5.2 to 2.6.0
matthieugomez Jan 29, 2026
7787b6a
Update method and double_precision arguments in functions
matthieugomez Jan 29, 2026
c255451
Fix capitalization of 'Metal' in method argument
matthieugomez Jan 29, 2026
3c2eb65
Update benchmark_Metal.jl
matthieugomez Jan 29, 2026
f1a7738
Reformat function signatures for consistency
matthieugomez Jan 29, 2026
67eba75
Decrease maxiter by 1 in lsmr! call
matthieugomez Jan 29, 2026
ca26e10
safer to use Int for big arrays and not more costly
matthieugomez Jan 29, 2026
7d3c214
Merge branch 'patch-15' of https://github.com/matthieugomez/FixedEffe…
matthieugomez Jan 29, 2026
34ff447
Update MetalExt.jl
matthieugomez Jan 29, 2026
647ea08
better to do chunkis of 100_000 even if it means more threads than Th…
matthieugomez Jan 29, 2026
1b8d7de
Update SolverCPU.jl
matthieugomez Jan 29, 2026
f4dc918
used shared arrays for Metal
matthieugomez Jan 30, 2026
c365731
Update Project.toml
matthieugomez Jan 30, 2026
ea5b1ab
Update AbstractFixedEffectSolver.jl
matthieugomez Jan 30, 2026
3528e2b
Update MetalExt.jl
matthieugomez Jan 30, 2026
3de3559
Merge branch 'master' into patch-15
matthieugomez Jan 30, 2026
38b861b
rmv nthreads
matthieugomez Jan 30, 2026
bea3c57
Merge branch 'patch-15' of https://github.com/matthieugomez/FixedEffe…
matthieugomez Jan 30, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
name = "FixedEffects"
uuid = "c8885935-8500-56a7-9867-7708b20db0eb"
version = "2.6.0"
version = "2.7.0"


[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
38 changes: 23 additions & 15 deletions ext/CUDAExt.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module CUDAExt
using FixedEffects, CUDA
using FixedEffects: FixedEffectCoefficients, AbstractWeights, UnitWeights, LinearAlgebra, Adjoint, mul!, rmul!, lsmr!, AbstractFixedEffectLinearMap
using FixedEffects: FixedEffectCoefficients, AbstractWeights, UnitWeights, LinearAlgebra, Adjoint, mul!, rmul!, lsmr!, AbstractFixedEffectLinearMap, copy_internal!
CUDA.allowscalar(false)

##############################################################################
Expand Down Expand Up @@ -36,17 +36,17 @@ mutable struct FixedEffectLinearMapCUDA{T} <: AbstractFixedEffectLinearMap{T}
fes::Vector{<:FixedEffect}
scales::Vector{<:AbstractVector}
caches::Vector{<:AbstractVector}
nthreads::Int
end

function FixedEffectLinearMapCUDA{T}(fes::Vector{<:FixedEffect}, nthreads) where {T}
function FixedEffectLinearMapCUDA{T}(fes::Vector{<:FixedEffect}) where {T}
fes = [_cu(T, fe) for fe in fes]
scales = [CUDA.zeros(T, fe.n) for fe in fes]
caches = [CUDA.zeros(T, length(fes[1].interaction)) for fe in fes]
return FixedEffectLinearMapCUDA{T}(fes, scales, caches, nthreads)
return FixedEffectLinearMapCUDA{T}(fes, scales, caches)
end

function FixedEffects.gather!(fecoef::CuVector, refs::CuVector, α::Number, y::CuVector, cache::CuVector, nthreads::Integer)
function FixedEffects.gather!(fecoef::CuVector, refs::CuVector, α::Number, y::CuVector, cache::CuVector)
nthreads = 256
nblocks = cld(length(y), nthreads)
@cuda threads=nthreads blocks=nblocks gather_kernel!(fecoef, refs, α, y, cache)
end
Expand All @@ -61,7 +61,8 @@ function gather_kernel!(fecoef, refs, α, y, cache)
end
end

function FixedEffects.scatter!(y::CuVector, α::Number, fecoef::CuVector, refs::CuVector, cache::CuVector, nthreads::Integer)
function FixedEffects.scatter!(y::CuVector, α::Number, fecoef::CuVector, refs::CuVector, cache::CuVector)
nthreads = 256
nblocks = cld(length(y), nthreads)
@cuda threads=nthreads blocks=nblocks scatter_kernel!(y, α, fecoef, refs, cache)
end
Expand Down Expand Up @@ -101,11 +102,7 @@ function FixedEffects.AbstractFixedEffectSolver{T}(fes::Vector{<:FixedEffect}, w
end

function FixedEffects.AbstractFixedEffectSolver{T}(fes::Vector{<:FixedEffect}, weights::AbstractWeights, ::Type{Val{:CUDA}}, nthreads = nothing) where {T}
if nthreads === nothing
nthreads = 256
end
nthreads = prevpow(2, nthreads)
m = FixedEffectLinearMapCUDA{T}(fes, nthreads)
m = FixedEffectLinearMapCUDA{T}(fes)
b = CUDA.zeros(T, length(weights))
r = CUDA.zeros(T, length(weights))
x = FixedEffectCoefficients([CUDA.zeros(T, fe.n) for fe in fes])
Expand All @@ -120,15 +117,16 @@ end
function FixedEffects.update_weights!(feM::FixedEffectSolverCUDA{T}, weights::AbstractWeights) where {T}
copyto!(feM.weights, _cu(T, weights))
for (scale, fe) in zip(feM.m.scales, feM.m.fes)
scale!(scale, fe.refs, fe.interaction, feM.weights, feM.m.nthreads)
scale!(scale, fe.refs, fe.interaction, feM.weights)
end
for (cache, scale, fe) in zip(feM.m.caches, feM.m.scales, feM.m.fes)
cache!(cache, fe.refs, fe.interaction, feM.weights, scale, feM.m.nthreads)
cache!(cache, fe.refs, fe.interaction, feM.weights, scale)
end
return feM
end

function scale!(scale::CuVector, refs::CuVector, interaction::CuVector, weights::CuVector, nthreads::Integer)
function scale!(scale::CuVector, refs::CuVector, interaction::CuVector, weights::CuVector)
nthreads = 256
nblocks = cld(length(refs), nthreads)
fill!(scale, 0)
@cuda threads=nthreads blocks=nblocks scale_kernel!(scale, refs, interaction, weights)
Expand All @@ -145,7 +143,8 @@ function scale_kernel!(scale, refs, interaction, weights)
end
end

function cache!(cache::CuVector, refs::CuVector, interaction::CuVector, weights::CuVector, scale::CuVector, nthreads::Integer)
function cache!(cache::CuVector, refs::CuVector, interaction::CuVector, weights::CuVector, scale::CuVector)
nthreads = 256
nblocks = cld(length(cache), nthreads)
@cuda threads=nthreads blocks=nblocks cache!_kernel!(cache, refs, interaction, weights, scale)
end
Expand All @@ -160,6 +159,15 @@ function cache!_kernel!(cache, refs, interaction, weights, scale)
end
end

# Load the vector `r` into the solver field named by `field` (the generic solver
# code calls this with `:r` and `:b`).
# The data is staged through `feM.tmp` rather than copied straight into the field;
# the generic solver code notes that a view of a Vector cannot be copied to the GPU
# directly.
# NOTE(review): presumably `feM.tmp` is a plain host Vector — confirm against the struct.
# Returns the destination field (the value returned by the final `copyto!`).
function FixedEffects.copy_internal!(feM::FixedEffectSolverCUDA, field::Symbol, r::AbstractVector)
    staging = feM.tmp
    copyto!(staging, r)
    target = getfield(feM, field)
    return copyto!(target, staging)
end

# Write the solver field named by `field` back into the vector `r`.
# The field is first copied into `feM.tmp`, and `feM.tmp` is then copied into `r`,
# so `r` never receives data from the field directly.
# NOTE(review): presumably `feM.tmp` is a plain host Vector — confirm against the struct.
# Returns `r` (the value returned by the final `copyto!`).
function FixedEffects.copy_internal!(r::AbstractVector, feM::FixedEffectSolverCUDA, field::Symbol)
    staging = feM.tmp
    source = getfield(feM, field)
    copyto!(staging, source)
    return copyto!(r, staging)
end


end
73 changes: 43 additions & 30 deletions ext/MetalExt.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module MetalExt
using FixedEffects, Metal
using FixedEffects: FixedEffectCoefficients, AbstractWeights, UnitWeights, LinearAlgebra, Adjoint, mul!, rmul!, lsmr!, AbstractFixedEffectLinearMap
using FixedEffects: FixedEffectCoefficients, AbstractWeights, UnitWeights, LinearAlgebra, Adjoint, mul!, rmul!, lsmr!, AbstractFixedEffectLinearMap, copy_internal!
Metal.allowscalar(false)

##############################################################################
Expand Down Expand Up @@ -35,50 +35,53 @@ mutable struct FixedEffectLinearMapMetal{T} <: AbstractFixedEffectLinearMap{T}
fes::Vector{<:FixedEffect}
scales::Vector{<:AbstractVector}
caches::Vector
nthreads::Int
end

function bucketize_refs(refs::Vector, n::Int)
# count the number of obs per group
counts = zeros(Int, n)
@inbounds for r in refs
counts[r] += 1
end
counts = zeros(Int, n)
@inbounds for r in refs
counts[r] += 1
end
# offsets is vcat(1, cumsum(counts))
offsets = Vector{Int}(undef, n + 1)
offsets_mtl = Metal.@sync Metal.zeros(Int, n + 1; storage = Metal.SharedStorage)
offsets = unsafe_wrap(Array{Int}, offsets_mtl, size(offsets_mtl))
offsets[1] = 1
@inbounds for k in 1:n
offsets[k+1] = offsets[k] + counts[k]
end

perm_mtl = Metal.@sync Metal.zeros(Int, length(refs); storage = Metal.SharedStorage)
perm = unsafe_wrap(Array{Int}, perm_mtl, size(perm_mtl))
next = offsets[1:n]
perm = Vector{Int}(undef, length(refs))
@inbounds for i in eachindex(refs)
r = refs[i]
p = next[r]
perm[p] = i
next[r] = p + 1
end
return perm, offsets
return perm_mtl, offsets_mtl
end

function FixedEffectLinearMapMetal{T}(fes::Vector{<:FixedEffect}, nthreads) where {T}
function FixedEffectLinearMapMetal{T}(fes::Vector{<:FixedEffect}) where {T}
fes2 = [_mtl(T, fe) for fe in fes]
scales = [Metal.zeros(T, fe.n) for fe in fes]
caches = [[Metal.zeros(T, length(fe.refs)), Metal.zeros(Int, 1), Metal.zeros(Int, 1)] for fe in fes]
caches = [Any[Metal.zeros(T, length(fe.refs)), Metal.zeros(Int, 1), Metal.zeros(Int, 1)] for fe in fes]
Threads.@threads for i in 1:length(fes)
refs = fes[i].refs
n = fes[i].n
if n < min(100_000, div(length(refs), 16))
out = bucketize_refs(refs, n)
caches[i][2] = MtlArray(out[1])
caches[i][3] = MtlArray(out[2])
caches[i][2] = out[1]
caches[i][3] = out[2]
end
end
return FixedEffectLinearMapMetal{T}(fes2, scales, caches, nthreads)
return FixedEffectLinearMapMetal{T}(fes2, scales, caches)
end

function FixedEffects.gather!(fecoef::MtlVector, refs::MtlVector, α::Number, y::MtlVector, cache::Vector, nthreads::Integer)
function FixedEffects.gather!(fecoef::MtlVector, refs::MtlVector, α::Number, y::MtlVector, cache::Vector)
n = length(fecoef)
nthreads = Int(device().maxThreadsPerThreadgroup.width)
if n < min(100_000, div(length(refs), 16))
Metal.@sync @metal threads=nthreads groups=n gather_kernel_bin!(fecoef, refs, α, y, cache[1], cache[2], cache[3], Val(nthreads))
else
Expand Down Expand Up @@ -138,7 +141,8 @@ function gather_kernel!(fecoef, refs, α, y, cache)
return nothing
end

function FixedEffects.scatter!(y::MtlVector, α::Number, fecoef::MtlVector, refs::MtlVector, cache::Vector, nthreads::Integer)
function FixedEffects.scatter!(y::MtlVector, α::Number, fecoef::MtlVector, refs::MtlVector, cache::Vector)
nthreads = Int(device().maxThreadsPerThreadgroup.width)
nblocks = cld(length(y), nthreads)
Metal.@sync @metal threads=nthreads groups=nblocks scatter_kernel!(y, α, fecoef, refs, cache[1])
end
Expand Down Expand Up @@ -168,40 +172,36 @@ mutable struct FixedEffectSolverMetal{T} <: FixedEffects.AbstractFixedEffectSolv
v::FixedEffectCoefficients{<: AbstractVector{T}}
h::FixedEffectCoefficients{<: AbstractVector{T}}
hbar::FixedEffectCoefficients{<: AbstractVector{T}}
tmp::Vector{T} # used to convert AbstractVector to Vector{T}
fes::Vector{<:FixedEffect}
end


function FixedEffects.AbstractFixedEffectSolver{T}(fes::Vector{<:FixedEffect}, weights::AbstractWeights, ::Type{Val{:Metal}}, nthreads = nothing) where {T}
if nthreads === nothing
nthreads = Int(device().maxThreadsPerThreadgroup.width)
end
nthreads = prevpow(2, nthreads)
m = FixedEffectLinearMapMetal{T}(fes, nthreads)
b = Metal.zeros(T, length(weights))
r = Metal.zeros(T, length(weights))
m = FixedEffectLinearMapMetal{T}(fes)
b = Metal.zeros(T, length(weights); storage = Metal.SharedStorage)
r = Metal.zeros(T, length(weights); storage = Metal.SharedStorage)
x = FixedEffectCoefficients([Metal.zeros(T, fe.n) for fe in fes])
v = FixedEffectCoefficients([Metal.zeros(T, fe.n) for fe in fes])
h = FixedEffectCoefficients([Metal.zeros(T, fe.n) for fe in fes])
hbar = FixedEffectCoefficients([Metal.zeros(T, fe.n) for fe in fes])
tmp = zeros(T, length(weights))
feM = FixedEffectSolverMetal{T}(m, Metal.zeros(T, length(weights)), b, r, x, v, h, hbar, tmp, fes)
feM = FixedEffectSolverMetal{T}(m, Metal.zeros(T, length(weights)), b, r, x, v, h, hbar, fes)
FixedEffects.update_weights!(feM, weights)
end


function FixedEffects.update_weights!(feM::FixedEffectSolverMetal{T}, weights::AbstractWeights) where {T}
copyto!(feM.weights, _mtl(T, weights))
for (scale, fe) in zip(feM.m.scales, feM.m.fes)
scale!(scale, fe.refs, fe.interaction, feM.weights, feM.m.nthreads)
scale!(scale, fe.refs, fe.interaction, feM.weights)
end
for (cache, scale, fe) in zip(feM.m.caches, feM.m.scales, feM.m.fes)
cache!(cache, fe.refs, fe.interaction, feM.weights, scale, feM.m.nthreads)
cache!(cache, fe.refs, fe.interaction, feM.weights, scale)
end
return feM
end

function scale!(scale::MtlVector, refs::MtlVector, interaction::MtlVector, weights::MtlVector, nthreads::Integer)
function scale!(scale::MtlVector, refs::MtlVector, interaction::MtlVector, weights::MtlVector)
nthreads = Int(device().maxThreadsPerThreadgroup.width)
nblocks = cld(length(refs), nthreads)
fill!(scale, 0)
Metal.@sync @metal threads=nthreads groups=nblocks scale_kernel!(scale, refs, interaction, weights)
Expand All @@ -224,7 +224,8 @@ function inv_kernel!(scale, T)
return nothing
end

function cache!(cache, refs::MtlVector, interaction::MtlVector, weights::MtlVector, scale::MtlVector, nthreads::Integer)
function cache!(cache, refs::MtlVector, interaction::MtlVector, weights::MtlVector, scale::MtlVector)
nthreads = Int(device().maxThreadsPerThreadgroup.width)
nblocks = cld(length(cache[1]), nthreads)
Metal.@sync @metal threads=nthreads groups=nblocks cache!_kernel!(cache[1], refs, interaction, weights, scale)
end
Expand All @@ -237,5 +238,17 @@ function cache!_kernel!(cache, refs, interaction, weights, scale)
return nothing
end

# Load the vector `r` into the solver field named by `field`, without an
# intermediate buffer: the field's memory is wrapped as an `Array{T}` via
# `unsafe_wrap` and `r` is copied into that wrapper.
# `synchronize()` is called first — presumably to wait for outstanding GPU work
# before the memory is touched from the CPU side (confirm).
# NOTE(review): wrapping presumably requires the field to live in shared storage;
# the solver constructor allocates `b` and `r` with `storage = Metal.SharedStorage`.
# Confirm before calling this with any other field.
# Returns the wrapping `Array{T}` (the value returned by `copyto!`).
function FixedEffects.copy_internal!(feM::FixedEffectSolverMetal{T}, field::Symbol, r::AbstractVector) where {T}
    synchronize()
    buf = getfield(feM, field)
    host_view = unsafe_wrap(Array{T}, buf, size(buf))
    return copyto!(host_view, r)
end

# Write the solver field named by `field` back into the vector `r`, without an
# intermediate buffer: the field's memory is wrapped as an `Array{T}` via
# `unsafe_wrap` and that wrapper is copied into `r`.
# `synchronize()` is called first — presumably so that pending GPU work has
# finished before the memory is read from the CPU side (confirm).
# NOTE(review): wrapping presumably requires the field to live in shared storage;
# the solver constructor allocates `b` and `r` with `storage = Metal.SharedStorage`.
# Confirm before calling this with any other field.
# Returns `r` (the value returned by `copyto!`).
function FixedEffects.copy_internal!(r::AbstractVector, feM::FixedEffectSolverMetal{T}, field::Symbol) where {T}
    synchronize()
    buf = getfield(feM, field)
    host_view = unsafe_wrap(Array{T}, buf, size(buf))
    return copyto!(r, host_view)
end


end
4 changes: 2 additions & 2 deletions src/AbstractFixedEffectLinearMap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ function LinearAlgebra.mul!(fecoefs::FixedEffectCoefficients,
fem = adjoint(Cfem)
rmul!(fecoefs, β)
for (fecoef, fe, cache) in zip(fecoefs.x, fem.fes, fem.caches)
gather!(fecoef, fe.refs, α, y, cache, fem.nthreads)
gather!(fecoef, fe.refs, α, y, cache)
end
return fecoefs
end
Expand All @@ -38,7 +38,7 @@ function LinearAlgebra.mul!(y::AbstractVector, fem::AbstractFixedEffectLinearMap
fecoefs::FixedEffectCoefficients, α::Number, β::Number)
rmul!(y, β)
for (fecoef, fe, cache) in zip(fecoefs.x, fem.fes, fem.caches)
scatter!(y, α, fecoef, fe.refs, cache, fem.nthreads)
scatter!(y, α, fecoef, fe.refs, cache)
end
return y
end
Expand Down
28 changes: 6 additions & 22 deletions src/AbstractFixedEffectSolver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
##
##############################################################################
abstract type AbstractFixedEffectSolver{T} end
works_with_view(::AbstractFixedEffectSolver) = false

"""
`solve_residuals!(y, fes, w; method = :cpu, double_precision = method == :cpu, tol = 1e-8, maxiter = 10000)`
Expand Down Expand Up @@ -43,22 +42,17 @@ function solve_residuals!(y::Union{AbstractVector{<: Real}, AbstractMatrix{<: Re
nthreads = nothing)
any((length(fe) != size(y, 1) for fe in fes)) && throw("FixedEffects must have the same length as y")
any(ismissing.(fes)) && throw("FixedEffects must not have missing values")
feM = AbstractFixedEffectSolver{double_precision ? Float64 : Float32}(fes, w, Val{method}, nthreads)
feM = AbstractFixedEffectSolver{double_precision ? Float64 : Float32}(fes, w, Val{method})
solve_residuals!(y, feM; maxiter = maxiter, tol = tol)
end



function solve_residuals!(r::AbstractVector{<:Real}, feM::AbstractFixedEffectSolver{T}; tol::Real = sqrt(eps(T)), maxiter::Integer = 100_000) where {T}
# One cannot copy view of Vector (r) on GPU, so first collect the vector
if works_with_view(feM)
copyto!(feM.r, r)
else
copyto!(feM.tmp, r)
copyto!(feM.r, feM.tmp)
end
copy_internal!(feM, :r, r)
if !(feM.weights isa UnitWeights)
feM.r .*= sqrt.(feM.weights)
feM.r .*= sqrt.(feM.weights)
end
copyto!(feM.b, feM.r)
mul!(feM.x, feM.m', feM.b, 1, 0)
Expand All @@ -71,12 +65,7 @@ function solve_residuals!(r::AbstractVector{<:Real}, feM::AbstractFixedEffectSol
if !(feM.weights isa UnitWeights)
feM.r ./= sqrt.(feM.weights)
end
if works_with_view(feM)
copyto!(r, feM.r)
else
copyto!(feM.tmp, feM.r)
copyto!(r, feM.tmp)
end
copy_internal!(r, feM, :r)
return r, iter, converged
end

Expand Down Expand Up @@ -160,18 +149,13 @@ function solve_coefficients!(y::AbstractVector{<: Number}, fes::AbstractVector{<
nthreads = nothing)
any(ismissing.(fes)) && throw("Some FixedEffect has a missing value for reference or interaction")
any((length(fe) != length(y) for fe in fes)) && throw("FixedEffects must have the same length as y")
feM = AbstractFixedEffectSolver{double_precision ? Float64 : Float32}(fes, w, Val{method}, nthreads)
feM = AbstractFixedEffectSolver{double_precision ? Float64 : Float32}(fes, w, Val{method})
solve_coefficients!(y, feM; maxiter = maxiter, tol = tol)
end

function FixedEffects.solve_coefficients!(r::AbstractVector, feM::AbstractFixedEffectSolver{T}; tol::Real = sqrt(eps(T)), maxiter::Integer = 100_000) where {T}
# One cannot copy view of Vector (r) on GPU, so first collect the vector
if works_with_view(feM)
copyto!(feM.b, r)
else
copyto!(feM.tmp, r)
copyto!(feM.b, feM.tmp)
end
copy_internal!(feM, :b, r)
if !(feM.weights isa UnitWeights)
feM.b .*= sqrt.(feM.weights)
end
Expand Down
Loading