@@ -231,10 +231,11 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
231231 #
232232 # threads in a group work together to reduce values across the reduction dimensions;
233233 # we want as many as possible to improve algorithm efficiency and execution occupancy.
234- wanted_threads = shuffle ? nextwarp(kernel.kern.pipeline, length(Rreduce)) : length(Rreduce)
235- function compute_threads(max_threads)
234+ function compute_threads(kern)
235+ max_threads = KI.kernel_max_work_group_size(backend, kern)
236+ wanted_threads = shuffle ? nextwarp(kern.kern.pipeline, length(Rreduce)) : length(Rreduce)
236237 if wanted_threads > max_threads
237- shuffle ? prevwarp(kernel.kern.pipeline, max_threads) : max_threads
238+ shuffle ? prevwarp(kern.kern.pipeline, max_threads) : max_threads
238239 else
239240 wanted_threads
240241 end
@@ -244,7 +245,7 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
244245 # kernel above may be greater than the maxTotalThreadsPerThreadgroup of the eventually launched
245246 # kernel below, causing errors
246247 # reduce_threads = compute_threads(kernel.pipeline.maxTotalThreadsPerThreadgroup)
247- reduce_threads = compute_threads(KI.kernel_max_work_group_size(backend, kernel))
248+ reduce_threads = compute_threads(kernel)
248249
249250 # how many groups should we launch?
250251 #
@@ -265,21 +266,33 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
265266 Val(UInt64(length(Rother))), Val(grain), Val(shuffle), R, A;
266267 numworkgroups=groups, workgroupsize=threads)
267268 else
268- # we need multiple steps to cover all values to reduce
269- partial = similar(R, (size(R)..., reduce_groups))
269+ # temporary empty array whose type will match the final partial array
270+ partial = similar(R, ntuple(_ -> 0, Val(ndims(R) + 1)))
271+
272+ # NOTE: we can't use the previously-compiled kernel, or its launch configuration,
273+ # since the type of `partial` might not match the original output container
274+ # (e.g. if that was a view).
275+ partial_kernel = KI.KIKernel(backend, partial_mapreduce_device,
276+ f, op, init, Val(threads), Val(Rreduce),
277+ Val(Rother), Val(UInt64(length(Rother))),
278+ Val(grain), Val(shuffle), partial, A)
279+ partial_reduce_threads = compute_threads(partial_kernel)
280+ partial_reduce_groups = cld(length(Rreduce), partial_reduce_threads * grain)
281+
282+ partial_threads = partial_reduce_threads
283+ partial_groups = partial_reduce_groups * other_groups
284+
285+ partial = similar(R, (size(R)..., partial_reduce_groups))
270286 if init === nothing
271287 # without an explicit initializer we need to copy from the output container
272- # use broadcasting to extend singleton dimensions
273288 partial .= R
274289 end
275- # NOTE: we can't use the previously-compiled kernel, since the type of `partial`
276- # might not match the original output container (e.g. if that was a view).
277- KI.KIKernel(backend, partial_mapreduce_device,
278- f, op, init, Val(threads), Val(Rreduce), Val(Rother),
279- Val(UInt64(length(Rother))), Val(grain), Val(shuffle), partial, A)(
280- f, op, init, Val(threads), Val(Rreduce), Val(Rother),
281- Val(UInt64(length(Rother))), Val(grain), Val(shuffle), partial, A;
282- numworkgroups=groups, workgroupsize=threads)
290+
291+ partial_kernel(f, op, init, Val(threads), Val(Rreduce),
292+ Val(Rother), Val(UInt64(length(Rother))),
293+ Val(grain), Val(shuffle), partial, A;
294+ numworkgroups=partial_groups, workgroupsize=partial_threads)
295+
283296
284297 GPUArrays.mapreducedim!(identity, op, R, partial; init=init)
285298 end
0 commit comments