CUDA kernel sync fixes, test_layer enhancement #3132
Cydral wants to merge 1 commit into davisking:master
Conversation
@davisking, Hi Davis, once PR #3132 is resolved, I will be able to share a full update for the ACT processing layer.
@davisking, Because the ACT layer update represents a significant refactoring that also introduces the foundational architecture for managing sub-networks within computational layers, could you please find time to review and merge this pending change? This would allow me to sync my dev repository with the latest version before submitting the subsequent PR for the ACT layer. Thanks in advance.
__global__ void _cuda_inverse_norms_accumulate(
    float* invnorms,
    const float* data,
    size_t nr,
    size_t nc
)
{
    // initialize invnorms before we begin.
    for (auto i : grid_stride_range_y(0, nr))
        for (auto j : grid_stride_range(0, 1))
            invnorms[i] = eps;
    __syncthreads();
|
|
This was correct and so should not be changed. The launch_kernel function sets it up so that this kind of thing is correct. This comment in launch_kernel explains it:
/*
In general, the reason m.num_y!=1 (i.e. the reason you are in this
code path) is because we are using nested grid-stride loops. There are
two important things to note about what we are doing here. To
illustrate them we will talk about this little CUDA code snippet:
        // initialize out before we begin.
        for (auto i : grid_stride_range_y(0, nr))
            for (auto j : grid_stride_range(0, 1))
                out[i] = 0;
        __syncthreads(); // synchronize threads in block
        // loop over some 2D thing and sum and store things into out.
        for (auto i : grid_stride_range_y(0, nr))
        {
            float temp = 0;
            for (auto j : grid_stride_range(0, nc))
                temp += whatever[i*nc+j];
            // store the sum into out[i]
            warp_reduce_atomic_add(out[i], temp);
        }
First, we make sure the number of x threads is a multiple of 32 so that
you can use warp_reduce_atomic_add() inside the y loop.
Second, we put the x block size to 1 so inter-block synchronization is
easier. For example, if the number of x blocks wasn't 1 the above code
would have a race condition in it. This is because the execution of
out[i]=0 would be done by blocks with blockIdx.x==0, but then in the
second set of loops, *all* the x blocks use out[i]. Since
__syncthreads() doesn't do any synchronization between blocks some of
the blocks might begin before the out[i]=0 statements finished and that
would be super bad.
*/
So the existing uses of __syncthreads() are correct (other than that new buggy one we were talking about the other day that precipitated these changes).
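For reference, a minimal sketch of how a kernel with these nested grid-stride loops typically reaches that code path (the wrapper name and call below are assumptions for illustration, not dlib's exact code):

// Illustrative wrapper only: max_jobs(nc, nr) makes m.num_y != 1, which selects
// the launch configuration described above (number of x blocks kept at 1, x
// thread count a multiple of 32), so the nested loops plus __syncthreads()
// behave as intended.
void inverse_norms_accumulate_sketch(float* invnorms, const float* data,
                                     size_t nr, size_t nc)
{
    launch_kernel(_cuda_inverse_norms_accumulate, max_jobs(nc, nr),
                  invnorms, data, nr, nc);
}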
Thank you for the detailed explanation about the launch_kernel design and the idempotent initialization pattern.
You're right that the existing pattern works correctly for idempotent operations like out[i] = 0. However, I'd like to make a case for the kernel decomposition approach as a safer and more maintainable alternative:
The core issue with rms_normalize:
- The operation scale[n] = 1/sqrt(scale[n] + eps) is non-idempotent, which causes incorrect results when multiple warp threads execute it via grid_stride_range(0, 1). This is a subtle bug that's easy to introduce and hard to detect, as the sketch below shows.
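To make the non-idempotence concrete, here is a tiny standalone host-side sketch (not dlib's kernel code) showing that a duplicated execution changes the stored value, whereas an initialization like out[i] = 0 would not:

#include <cmath>
#include <cstdio>

int main()
{
    const float eps = 1e-5f;
    float scale = 4.0f;  // pretend this holds the accumulated sum of squares

    scale = 1.0f / std::sqrt(scale + eps);   // intended single update: ~0.5
    std::printf("after one update:  %f\n", scale);

    scale = 1.0f / std::sqrt(scale + eps);   // duplicate update: ~1.414 -- wrong
    std::printf("after two updates: %f\n", scale);
    return 0;
}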
Why kernel decomposition is beneficial:
- Eliminates idempotence reasoning: developers don't need to verify whether each operation is safe for multiple executions. The sequential kernel approach is always correct.
- More explicit phase separation: the code clearly shows the computational phases, making it easier to understand and maintain.
- Future-proof: if someone later modifies an initialization to be non-idempotent (e.g., changing out[i] = 0 to out[i] = some_computation()), the code still works correctly.
- No performance penalty: sequential launch_kernel calls provide implicit synchronization with similar performance characteristics.
While the other functions (inverse_norms, dot_prods, layer_normalize) currently use operations that appear to be idempotent and therefore work correctly, the kernel decomposition approach provides a consistent, safer pattern across all of them.
That said, I understand if you prefer to keep the existing pattern for functions that are already correct. But rms_normalize genuinely needs this fix.
Eh, splitting this stuff up into multiple kernels is slower since kernel launches aren't free. There isn't anything wrong with __syncthreads(). It just has to be used correctly. That goes for a bunch of stuff about CUDA.
But sometimes you need multiple kernels and that's fine too when it's really needed.
Summary
This PR addresses critical CUDA synchronization issues and enhances the test_layer utility function.

CUDA Kernel Fixes

Several CUDA kernels were using __syncthreads() for cross-block synchronization, which is incorrect since __syncthreads() only synchronizes threads within the same block, not across different blocks. When grid_stride_range_y distributes work across multiple blocks, these synchronization barriers fail silently.

Affected functions decomposed into separate kernels:
- inverse_norms()
- dot_prods()
- multiply_conv()
- layer_normalize()
- rms_normalize()
- compute_act_halt_probabilities()

The fix replaces intra-kernel __syncthreads() with sequential launch_kernel() calls, which provide implicit synchronization between kernel executions.
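As an illustration of the decomposition pattern (a minimal sketch only: the kernel and wrapper names are hypothetical and the include path is assumed, not the PR's actual code):

#include <dlib/cuda/cuda_utils.h>  // assumed path; provides launch_kernel,
                                   // max_jobs and the grid-stride helpers
using namespace dlib::cuda;

// Phase 1: initialize every accumulator exactly once per element.
__global__ void _example_inverse_norms_init(float* invnorms, size_t nr, float eps)
{
    for (auto i : grid_stride_range(0, nr))
        invnorms[i] = eps;
}

// Phase 2: accumulate the squared row sums.  No cross-block barrier is needed
// because phase 1 has fully completed before this kernel starts.
__global__ void _example_inverse_norms_accumulate(
    float* invnorms, const float* data, size_t nr, size_t nc)
{
    for (auto i : grid_stride_range_y(0, nr))
    {
        float temp = 0;
        for (auto j : grid_stride_range(0, nc))
            temp += data[i*nc + j] * data[i*nc + j];
        warp_reduce_atomic_add(invnorms[i], temp);
    }
}

// Host side: back-to-back launches on the same stream are serialized by the
// device, which is the implicit synchronization the PR relies on.
void example_inverse_norms(float* invnorms, const float* data,
                           size_t nr, size_t nc, float eps)
{
    launch_kernel(_example_inverse_norms_init, max_jobs(nr), invnorms, nr, eps);
    launch_kernel(_example_inverse_norms_accumulate, max_jobs(nc, nr),
                  invnorms, data, nr, nc);
}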
test_layer Enhancement

Modified test_layer to accept optional parameters for testing layers with specific tensor input constraints, enabling proper gradient verification for layers that require particular input dimensions.
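For example, the existing gradient-checking usage looks like the snippet below; the commented-out call only indicates the general shape of the enhancement, with the extra arguments shown as hypothetical placeholders since the exact new parameter list is defined by this PR:

#include <dlib/dnn.h>
#include <iostream>

int main()
{
    // Current usage: gradient-check a layer against a randomly sized input.
    dlib::layer_norm_ l;
    auto res = dlib::test_layer(l);
    std::cout << res << std::endl;

    // Hypothetical call shape for the enhancement: supplying explicit input
    // dimensions so layers with input-size constraints can be verified.
    // auto res2 = dlib::test_layer(l, /*num_samples*/2, /*k*/4, /*nr*/8, /*nc*/8);
    return 0;
}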
Related Discussion

Follow-up to #3128