Skip to content

Commit 0cf9660

Browse files
committed
fix: Specialized ReshapedArray dispatch to resolve setindex! ambiguities
1 parent f516fc2 commit 0cf9660

File tree

2 files changed

+73
-2
lines changed

2 files changed

+73
-2
lines changed

src/host/indexing.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,14 @@ end
167167
function Base._unsafe_setindex!(::IndexStyle, A::WrappedGPUArray, x, Is::Vararg{Union{Real,AbstractArray}, N}) where N
168168
return vectorized_setindex!(A, x, Base.ensure_indexable(Is)...)
169169
end
170-
# And allow one more `ReshapedArray` wrapper to handle the `_maybe_reshape` optimization.
171-
function Base._unsafe_setindex!(::IndexStyle, A::Base.ReshapedArray{<:Any, <:Any, <:WrappedGPUArray}, x, Is::Vararg{Union{Real,AbstractArray}, N}) where N
170+
171+
#Implementation for ReshapedArrays using Cartesian indexing to resolve dispatch ties.
172+
function Base._unsafe_setindex!(::Base.IndexCartesian, A::Base.ReshapedArray{T, N, <:WrappedGPUArray}, x, Is::Vararg{Union{Real, AbstractArray}, N}) where {T, N}
173+
return vectorized_setindex!(A, x, Base.ensure_indexable(Is)...)
174+
end
175+
176+
#Implementation for ReshapedArrays using Linear indexing to resolve dispatch ties.
177+
function Base._unsafe_setindex!(::Base.IndexLinear, A::Base.ReshapedArray{T, N, <:WrappedGPUArray}, x, Is::Vararg{Union{Real, AbstractArray}, N}) where {T, N}
172178
return vectorized_setindex!(A, x, Base.ensure_indexable(Is)...)
173179
end
174180

test/testsuite/indexing.jl

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,3 +284,68 @@ end
284284
@test compare(argmin, AT, -rand(Int, 10))
285285
end
286286
end
287+
288+
@testsuite "indexing combinatorial" (AT, eltypes) -> begin
289+
@testset "Reshaped SubArray dispatch" for T in eltypes
290+
@testset "3D slice assignment" begin
291+
A = AT(ones(T, 4, 4, 4))
292+
@views V = A[:, :, 1:2]
293+
@allowscalar begin
294+
@test_nowarn V .= zero(T)
295+
@test all(Array(V) .== zero(T))
296+
end
297+
end
298+
299+
@testset "Logical mask view (dim = 3)" begin
300+
A = AT(ones(T, 4, 4, 4))
301+
mask = Bool[true, false, true, false]
302+
@views V = A[:, :, mask]
303+
@allowscalar begin
304+
@test_nowarn V .+= T(2)
305+
@test all(Array(V) .== T(3))
306+
end
307+
end
308+
309+
@testset "Nested Reshape" begin
310+
A = AT(ones(T, 4, 4, 4))
311+
V = view(A, 1:2, 1:2, 1:2)
312+
R1 = reshape(V, 4, 2)
313+
R2 = reshape(R1, :)
314+
@allowscalar begin
315+
@test_nowarn R2 .+= one(T)
316+
@test all(Array(R2) .== T(2))
317+
end
318+
end
319+
end
320+
321+
@testset "Permuted and Reinterpreted Views" for T in eltypes
322+
@testset "Reshaped PermutedDims" begin
323+
A = AT(ones(T, 4, 4))
324+
P = PermutedDimsArray(A, (2, 1))
325+
R = reshape(P, :)
326+
@allowscalar begin
327+
@test_nowarn R[1:2] .= zero(T)
328+
@test Array(R)[1] == zero(T)
329+
end
330+
end
331+
332+
@testset "Reshaped Reinterpreted" begin
333+
A = AT(ones(T, 4, 4))
334+
IT = T <: Complex ? Complex{Int16} : Int16
335+
R = reshape(reinterpret(IT, A), :)
336+
@allowscalar begin
337+
@test_nowarn R[1] = zero(IT)
338+
@test Array(R)[1] == zero(IT)
339+
end
340+
end
341+
end
342+
343+
@testset "Data parity with compare()" for T in eltypes
344+
@test compare(AT, rand(T, 8, 8, 8)) do A
345+
mask = isodd.(1:size(A, 2))
346+
@views V = A[:, mask, :]
347+
@allowscalar V .+= one(T)
348+
A
349+
end
350+
end
351+
end

0 commit comments

Comments
 (0)