Description
The `spread_cube_count_plan` function in cubek-matmul uses u32 arithmetic when computing `total_cubes = m_cubes * n_cubes * batch_cubes`. This overflows for large matrices, causing a panic or incorrect workgroup allocation.
Impact
I ran into this while implementing the VAE for Stable Diffusion. Specifically, VAE decoding at 512×512 image resolution fails because the VAE's mid-block attention layer operates on 4096 tokens (the 64×64 spatial dimensions of the latent), creating a 4096×4096 attention matrix during the Q·Kᵀ computation.
Root Cause
In `crates/cubek-matmul/src/definition/hypercube/cube_count/plan.rs`, line 271:

```rust
let total_cubes = m_cubes * n_cubes * batch_cubes; // u32 * u32 * u32 → overflow
```
When `m_cubes` and `n_cubes` are large, the product exceeds `u32::MAX`: in release builds the multiplication wraps silently, while in debug builds it panics. (Thankfully, I got the panic.)
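
A standalone snippet that reproduces the arithmetic failure (the cube counts here are made-up values chosen only to push the product past `u32::MAX`, not numbers taken from cubek-matmul):

```rust
fn main() {
    // Hypothetical cube counts; any combination whose product exceeds
    // u32::MAX (~4.29 billion) triggers the same failure.
    let m_cubes: u32 = 65_536;
    let n_cubes: u32 = 65_536;
    let batch_cubes: u32 = 2;

    // checked_mul surfaces the overflow in both debug and release builds,
    // unlike the bare `*` in plan.rs.
    let total = m_cubes
        .checked_mul(n_cubes)
        .and_then(|mn| mn.checked_mul(batch_cubes));
    assert_eq!(total, None); // 2^16 * 2^16 * 2 = 2^33 > u32::MAX
}
```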
Fix
Use u64 arithmetic for the intermediate calculations, truncating back to u32 only for the final (x, y, z) dimensions, which are bounded by GPU limits:

```rust
let total_cubes = m_cubes as u64 * n_cubes as u64 * batch_cubes as u64;
```
The same pattern should be applied to the `volume` and intermediate `x`/`xy_cubes` calculations.
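
For concreteness, here is a minimal sketch of the pattern. The function name, signature, and spreading strategy are illustrative, not the actual cubek-matmul code; `max_per_dim` stands in for the GPU's per-dimension dispatch limit (commonly 65 535):

```rust
// Minimal sketch, assuming the planner ultimately produces a (x, y, z)
// dispatch triple. Not the real spread_cube_count_plan implementation.
fn spread_cube_count(
    m_cubes: u32,
    n_cubes: u32,
    batch_cubes: u32,
    max_per_dim: u64,
) -> (u32, u32, u32) {
    // All intermediate products in u64 so they cannot wrap; .max(1)
    // keeps the divisions below well-defined for empty problems.
    let total_cubes = (m_cubes as u64 * n_cubes as u64 * batch_cubes as u64).max(1);

    // Spread the total volume across x, then y, then z.
    let x = total_cubes.min(max_per_dim);
    let y = total_cubes.div_ceil(x).min(max_per_dim);
    let z = total_cubes.div_ceil(x * y);
    debug_assert!(z <= max_per_dim, "dispatch exceeds GPU limits");

    // Safe truncation: each dimension is <= max_per_dim, which fits in u32.
    (x as u32, y as u32, z as u32)
}
```

With a 65 535 per-dimension limit this covers totals up to roughly 2.8 × 10¹⁴ cubes, far beyond `u32::MAX`.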
Reproduction
Any matmul where `m_cubes * n_cubes * batch_cubes > u32::MAX` (a quick standalone check is sketched after this list). In practice, this occurs with:
- Attention mechanisms with seq_len ≥ 4096
- Large batch sizes on moderate sequence lengths
- Any combination where the cube count product exceeds ~4.3 billion
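
For triage, a hypothetical helper (not part of cubek-matmul) makes the threshold explicit:

```rust
/// True if the cube-count product would wrap in u32 arithmetic.
/// Hypothetical helper for triage, not part of cubek-matmul.
fn would_overflow(m_cubes: u32, n_cubes: u32, batch_cubes: u32) -> bool {
    m_cubes as u64 * n_cubes as u64 * batch_cubes as u64 > u32::MAX as u64
}
```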