Conversation
Did the workaround in FluxML/Zygote.jl#1378 not fix it? As mentioned in #684, it should ideally be fixed in ChainRules nevertheless, but I'm a bit curious.
Thanks for commenting. I think @willtebbutt said that he would have a look at these rules later on.
|
Hi all, I rewrote the rules and now all the tests pass. There is probably opportunity to optimize them; please let me know.
|
Ok, I did not test on Julia 1.6. Apparently this requires special care.
|
Why don't we see the full stack traces here? Is it due to using JuliaInterpreter?
|
Ok, I made the suggested changes and added tests to check the correct behavior of the projections. However, we have a type inference problem in the matrix-matrix case.
|
The problem is this:

```julia
julia> x = Diagonal(rand(2)); y = Diagonal(rand(2)); z, pb = rrule(kron, x, y);
julia> @code_warntype unthunk(pb(z)[2])
MethodInstance for ChainRulesCore.unthunk(::Thunk{ChainRules.var"#2318#2321"{Base.ReshapedArray{Float64, 4, Diagonal{Float64, Vector{Float64}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Diagonal{Float64, Vector{Float64}}, ProjectTo{Diagonal, NamedTuple{(:diag,), Tuple{ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}}}}}})
from unthunk(x::Thunk) @ ChainRulesCore ~/.julia/packages/ChainRulesCore/0t04l/src/tangent_types/thunks.jl:204
Arguments
#self#::Core.Const(ChainRulesCore.unthunk)
x::Thunk{ChainRules.var"#2318#2321"{Base.ReshapedArray{Float64, 4, Diagonal{Float64, Vector{Float64}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Diagonal{Float64, Vector{Float64}}, ProjectTo{Diagonal, NamedTuple{(:diag,), Tuple{ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}}}}}}
Body::Any
1 ─ nothing
│ %2 = Base.getproperty(x, :f)::ChainRules.var"#2318#2321"{Base.ReshapedArray{Float64, 4, Diagonal{Float64, Vector{Float64}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Diagonal{Float64, Vector{Float64}}, ProjectTo{Diagonal, NamedTuple{(:diag,), Tuple{ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}}}}}
│ %3 = (%2)()::Any
└── return %3
```

Any ideas how to make this type stable? The underlying problem seems to be the inner `dot` call on the lazy slice:

```julia
@code_warntype dot(y, first(eachslice(dz; dims = (2, 4))))
MethodInstance for LinearAlgebra.dot(::Diagonal{Float64, Vector{Float64}}, ::SubArray{Float64, 2, Base.ReshapedArray{Float64, 4, Diagonal{Float64, Vector{Float64}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64, Base.Slice{Base.OneTo{Int64}}, Int64}, false})
from dot(D::Diagonal, B::AbstractMatrix) @ LinearAlgebra ~/.julia/juliaup/julia-1.9.3+0.x64.linux.gnu/share/julia/stdlib/v1.9/LinearAlgebra/src/diagonal.jl:806
Arguments
#self#::Core.Const(LinearAlgebra.dot)
D::Diagonal{Float64, Vector{Float64}}
B::SubArray{Float64, 2, Base.ReshapedArray{Float64, 4, Diagonal{Float64, Vector{Float64}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64, Base.Slice{Base.OneTo{Int64}}, Int64}, false}
Body::Any
1 ─ %1 = LinearAlgebra.size(D)::Tuple{Int64, Int64}
│ %2 = LinearAlgebra.size(B)::Tuple{Int64, Int64}
│ %3 = (%1 == %2)::Bool
└── goto #3 if not %3
2 ─ goto #4
3 ─ %6 = LinearAlgebra.size(D)::Tuple{Int64, Int64}
│ %7 = LinearAlgebra.size(B)::Tuple{Int64, Int64}
│ %8 = Base.string("Matrix sizes ", %6, " and ", %7, " differ")::String
│ %9 = LinearAlgebra.DimensionMismatch(%8)::Any
└── LinearAlgebra.throw(%9)
4 ┄ %11 = Base.getproperty(D, :diag)::Vector{Float64}
│ %12 = LinearAlgebra.diagind(B)::Core.PartialStruct(StepRange{Int64, Int64}, Any[Core.Const(1), Int64, Int64])
│ %13 = LinearAlgebra.view(B, %12)::Core.PartialStruct(SubArray{Float64, 1, Base.ReshapedArray{Float64, 1, SubArray{Float64, 2, Base.ReshapedArray{Float64, 4, Diagonal{Float64, Vector{Float64}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64, Base.Slice{Base.OneTo{Int64}}, Int64}, false}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Tuple{StepRange{Int64, Int64}}, false}, Any[Base.ReshapedArray{Float64, 1, SubArray{Float64, 2, Base.ReshapedArray{Float64, 4, Diagonal{Float64, Vector{Float64}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64, Base.Slice{Base.OneTo{Int64}}, Int64}, false}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Core.PartialStruct(Tuple{StepRange{Int64, Int64}}, Any[Core.PartialStruct(StepRange{Int64, Int64}, Any[Core.Const(1), Int64, Int64])]), Core.Const(0), Core.Const(0)])
│ %14 = LinearAlgebra.dot(%11, %13)::Any
└── return %14
```

I cannot fix that without collecting, either.
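For reference, a minimal sketch of what such a collect-based helper could look like; the `_dot_collect` name matches the diff excerpt quoted below, but this is just the idea, not necessarily the exact definition used in the PR:

```julia
using LinearAlgebra

# Materialize the lazy reshaped/viewed slice before calling dot, so dispatch
# sees a plain Array and the result infers, at the cost of one allocation per slice.
_dot_collect(u::AbstractMatrix, v::AbstractMatrix) = dot(u, collect(v))
```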
```julia
function kron_pullback(z̄)
    dz = reshape(unthunk(z̄), size(y, 1), size(x, 1), size(y, 2), size(x, 2))
    x̄ = @thunk(project_x(_dot_collect.(Ref(y), eachslice(dz; dims = (2, 4)))))
    ȳ = @thunk(project_y(_dot_collect.(Ref(x), eachslice(dz; dims = (1, 3)))))
```
I was wondering if you have to make slices, given that `kron` is just `reshape` and `.*` (a quick check of that identity is included in the code below). So here's an attempt to do without:
```julia
using ChainRulesCore, LinearAlgebra

function pr_rule(x::AbstractMatrix{<:Number}, y::AbstractMatrix{<:Number}) # from https://github.com/JuliaDiff/ChainRules.jl/pull/741
    project_x = ProjectTo(x)
    project_y = ProjectTo(y)
    function kron_pullback(z̄)
        dz = reshape(unthunk(z̄), size(y, 1), size(x, 1), size(y, 2), size(x, 2))
        x̄ = @thunk(project_x(dot.(Ref(y), eachslice(dz; dims = (2, 4)))))
        ȳ = @thunk(project_y(dot.(Ref(x), eachslice(dz; dims = (1, 3)))))
        return NoTangent(), x̄, ȳ
    end
end
# using TensorCast
# mykron(x,y) = @cast z[(a,b), (c,d)] := x[b,d] * y[a,c]
# @pretty @cast z[(a,b), (c,d)] := x[b,d] * y[a,c]
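
# A quick sanity check of the claim above that kron is just reshape and .* ;
# kron_via_reshape is a throwaway name, and the 4-d layout matches shape_rule below:
function kron_via_reshape(x::AbstractMatrix, y::AbstractMatrix)
    x4 = reshape(x, 1, size(x, 1), 1, size(x, 2))
    y4 = reshape(y, size(y, 1), 1, size(y, 2), 1)
    reshape(x4 .* y4, size(x, 1) * size(y, 1), size(x, 2) * size(y, 2))
end
let x = rand(3, 4), y = rand(5, 6)
    @assert kron_via_reshape(x, y) ≈ kron(x, y)
end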
function shape_rule(x::AbstractMatrix, y::AbstractMatrix)
    function back(dz)
        x4 = reshape(x, 1, size(x,1), 1, size(x,2))
        y4 = reshape(y, size(y,1), 1, size(y,2), 1)
        dz4 = reshape(unthunk(dz), size(y,1), size(x,1), size(y,2), size(x,2))
        dx = @thunk ProjectTo(x)(reshape(sum(dz4 .* y4, dims=(1,3)), size(x))) # might be missing conj
        dy = @thunk ProjectTo(y)(reshape(sum(dz4 .* x4, dims=(2,4)), size(y)))
        0, dx, dy
    end
end
using BenchmarkTools  # needed for @btime

let x = rand(10,20), y = rand(30,10)
    b1 = pr_rule(x, y)
    b2 = shape_rule(x, y)
    z = kron(x,y)
    _, dx1, _ = @btime map(unthunk, $b1($z))
    _, dx2, _ = @btime map(unthunk, $b2($z))
    dx1 ≈ dx2
end
# min 181.458 μs, mean 185.668 μs (4 allocations, 4.39 KiB)
# min 80.583 μs, mean 169.305 μs (32 allocations, 943.05 KiB)
# true
```

It's a pity to allocate these big arrays `dz4 .* y4`, but this still seems quicker. Possibly we could use lazy broadcasting to avoid that:

```julia
bc = Broadcast.instantiate(Broadcast.broadcasted(*, [1 2 3], [4, 5]));
sum(bc) # OK
sum(bc; dims=1) # ERROR: MethodError: no method matching reducedim_init(::typeof(identity), ::typeof(Base.add_sum), ::Base.Broadcast.Broadcasted{…}, ::Int64)
sum!([0 0 0], bc) # ERROR: MethodError: no method matching sum!(::Matrix{Int64}, ::Base.Broadcast.Broadcasted
sum(bc; dims=1, init=0.0) # OK, not sure if it's fast or not
```

On StaticArrays (mentioned above) both at present make a SizedMatrix, which I think is ProjectTo's attempt to fix things up. Surely this reshaping could be done in a static-friendly way, but IDK exactly how.
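Setting the lazy-broadcast and StaticArrays questions aside, here is a rough, unbenchmarked sketch of how the `dz4 .* y4` temporary could be avoided on the CPU with an explicit loop (index layout follows `shape_rule` above, ignoring the possibly missing `conj`; a plain loop like this would not help StaticArrays or GPU arrays):

```julia
# dx[i, j] = Σ_{k,l} dz4[k, i, l, j] * y[k, l], i.e. the same reduction as
# reshape(sum(dz4 .* y4, dims = (1, 3)), size(x)) but without the temporary.
function kron_dx_loop(y::AbstractMatrix, dz4::AbstractArray{<:Number,4})
    dx = zeros(promote_type(eltype(y), eltype(dz4)), size(dz4, 2), size(dz4, 4))
    @inbounds for j in axes(dx, 2), i in axes(dx, 1)
        acc = zero(eltype(dx))
        for l in axes(y, 2), k in axes(y, 1)
            acc += dz4[k, i, l, j] * y[k, l]
        end
        dx[i, j] = acc
    end
    return dx
end
```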
```julia
julia> let x = @SMatrix(rand(5,5)), y = @SMatrix(rand(5,5))
           b1 = pr_rule(x, y)
           b2 = shape_rule(x, y)
           z = kron(x,y)
           _, dx1, _ = @btime map(unthunk, $b1($z))
           _, dx2, _ = @btime map(unthunk, $b2($z))
           dx1 ≈ dx2
       end
  min 2.458 μs, mean 2.558 μs (2 allocations, 512 bytes)
  min 4.006 μs, mean 5.198 μs (22 allocations, 11.38 KiB)
true
```
Does this result scale to larger arrays?
Result meaning the speed difference? It will vary with size and machine. On very small arrays reshaping is slower! (Like 3×3, I mean.)
Issues with StaticArrays will be similar at all sizes.
I think broadcasting over slices will work badly on CuArrays, and will tend to make Arrays. But right now neither idea seems to work, and I'm not sure why:

```julia
julia> using Metal
julia> bk = pr_rule(MtlArray(rand(Float32, 3,3)), MtlArray(rand(Float32, 3,3)));
julia> bk(MtlArray(rand(Float32, 9,9)))[2] |> unthunk
ERROR: GPU compilation of MethodInstance for (::GPUArrays.var"#broadcast_kernel#26")(::Metal.mtlKernelContext, ::MtlDeviceMatrix{…}, ::Base.Broadcast.Broadcasted{…}, ::Int64) failed
KernelError: passing and using non-bitstype argument
julia> bk2 = shape_rule(MtlArray(rand(Float32, 3,3)), MtlArray(rand(Float32, 3,3)));
julia> bk2(MtlArray(rand(Float32, 9,9)))[2] |> unthunk
ERROR: could not load symbol "LLVMExtraAddPropagateJuliaAddrspaces":
dlsym(RTLD_DEFAULT, LLVMExtraAddPropagateJuliaAddrspaces): symbol not found
```
If the reshape version is not strictly better than the current one, especially for large arrays, I would propose to keep the current version and put further optimizations in a separate PR.
A bit curious at what sizes it's slower for you?
But mainly I think the issue is less about the speed race than that simple solid-array operations have a better chance of behaving well with StaticArrays and CuArrays. I haven't taken another pass to see if the first draft can be improved on.
I haven't benchmarked anything myself yet. I will give it a go later.
Hmm, the results seem to be mixed. For larger sizes the allocations take their toll:
```julia
let x = rand(100,200), y = rand(300,100)
    b1 = pr_rule(x, y)
    b2 = shape_rule(x, y)
    z = kron(x,y)
    _, dx1, _ = @btime map(unthunk, $b1($z))
    _, dx2, _ = @btime map(unthunk, $b2($z))
    dx1 ≈ dx2
end
# 3.376 s (6 allocations: 390.84 KiB)
# 3.797 s (34 allocations: 8.94 GiB)
# true
```
I would suggest staying with the current implementation.
One way to ensure any implementation isn't excluding all GPU array types would be to toss a `@gpu` in front of the new tests, no?
In Julia 1.9 there was an internal change in `kron` that introduced some mutation, which has made Zygote unable to differentiate `kron`. Here, we add some rules to restore that ability. Discovered in JuliaGaussianProcesses/TemporalGPs.jl#115.
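For context, a minimal sketch of the failure mode being fixed (the exact error text depends on the Julia and Zygote versions):

```julia
using Zygote, LinearAlgebra

x, y = rand(2, 3), rand(4, 5)
# On Julia 1.9, without dedicated kron rules this call runs into Zygote's
# "mutating arrays is not supported" limitation; with the rrules added in this
# PR it returns gradients for both arguments.
dx, dy = gradient((a, b) -> sum(kron(a, b)), x, y)
```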