
u32 overflow in spread_cube_count_plan breaks large matmul operations #66

@jeff-hiner

Description


The spread_cube_count_plan function in cubek-matmul uses u32 arithmetic when computing total_cubes = m_cubes * n_cubes * batch_cubes. This overflows for large matrices, causing a panic or incorrect workgroup allocation.

Impact

I ran into this while implementing the VAE for Stable Diffusion. Specifically, VAE decoding at 512×512 resolution fails because the mid-block attention layer operates on 4096 tokens (a 64×64 latent), creating a 4096×4096 attention matrix during the Q×Kᵀ computation.

Root Cause

In crates/cubek-matmul/src/definition/hypercube/cube_count/plan.rs, line 271:

let total_cubes = m_cubes * n_cubes * batch_cubes;  // u32 * u32 * u32 → overflow

When m_cubes and n_cubes are large, the product exceeds u32::MAX, causing either silent wraparound (release builds) or a panic (debug builds, where overflow checks are on). (Thankfully, I saw a panic.)
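To make the failure mode concrete, here is a standalone sketch with hypothetical cube counts (the tiling that produces them is an assumption for illustration, not the actual cubek-matmul plan): a u32 product of 1024 × 1024 × 4096 lands exactly at 2^32, one past u32::MAX.

```rust
fn main() {
    // Hypothetical cube counts for a large batched matmul.
    let m_cubes: u32 = 1024;
    let n_cubes: u32 = 1024;
    let batch_cubes: u32 = 4096;

    // checked_mul returns None when the u32 product would wrap.
    let checked = m_cubes
        .checked_mul(n_cubes)
        .and_then(|mn| mn.checked_mul(batch_cubes));
    assert_eq!(checked, None); // overflow detected

    // With wrapping semantics (what a release build does for plain `*`),
    // 2^32 silently becomes 0 cubes — i.e. nothing gets dispatched.
    let wrapped = m_cubes.wrapping_mul(n_cubes).wrapping_mul(batch_cubes);
    assert_eq!(wrapped, 0);
}
```

The debug-build panic comes from the same wrap: plain `*` on u32 traps on overflow when debug assertions are enabled.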

Fix

Use u64 arithmetic for intermediate calculations, only truncating back to u32 for the final (x, y, z) dimensions, each of which is bounded by GPU dispatch limits:

let total_cubes = m_cubes as u64 * n_cubes as u64 * batch_cubes as u64;

The same pattern should be applied to volume and intermediate x/xy_cubes calculations.
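A minimal sketch of the proposed pattern (the function name, signature, and spreading strategy are illustrative, not the actual cubek-matmul API): compute the total in u64, spread it across the x, y, z axes so no single axis exceeds the per-axis dispatch limit, and only then truncate each axis to u32.

```rust
/// Widen to u64 for the product, then split into (x, y, z) axes that
/// each fit within `max_per_axis` (e.g. 65535 on many backends).
/// May over-dispatch slightly; kernels are assumed to bounds-check.
fn plan_cube_count(m_cubes: u32, n_cubes: u32, batch_cubes: u32, max_per_axis: u64) -> (u32, u32, u32) {
    // u64 product cannot overflow: three u32 factors fit in 96 bits... but
    // each factor < 2^32, so the product of any two < 2^64 after clamping x.
    let total = m_cubes as u64 * n_cubes as u64 * batch_cubes as u64;
    let x = total.min(max_per_axis).max(1);
    let xy = total.div_ceil(x); // cubes still to cover after fixing x
    let y = xy.min(max_per_axis).max(1);
    let z = xy.div_ceil(y);
    // Safe truncation: each axis is already clamped to GPU limits.
    (x as u32, y as u32, z as u32)
}

fn main() {
    // The 2^32 case from above now plans without wrapping.
    let (x, y, z) = plan_cube_count(1024, 1024, 4096, 65535);
    assert!(x as u64 * y as u64 * z as u64 >= 1u64 << 32);
}
```

The key property is that only per-axis values, never the total, are stored in u32.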

Reproduction

Any matmul where m_cubes * n_cubes * batch_cubes > u32::MAX. In practice, this occurs with:

  • Attention mechanisms with seq_len ≥ 4096
  • Large batch sizes on moderate sequence lengths
  • Any combination where the cube count product exceeds ~4.3 billion
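The repro condition above can be expressed as a small predicate (the helper name is made up for illustration; the cube counts you plug in depend on the tiling the planner actually picks):

```rust
/// Returns true when the u32 product m_cubes * n_cubes * batch_cubes
/// would wrap, i.e. when this bug would trigger.
fn cube_product_overflows(m_cubes: u32, n_cubes: u32, batch_cubes: u32) -> bool {
    m_cubes
        .checked_mul(n_cubes)
        .and_then(|mn| mn.checked_mul(batch_cubes))
        .is_none()
}

fn main() {
    // A single modest matmul is fine...
    assert!(!cube_product_overflows(64, 64, 1));
    // ...but large cube counts across a big batch cross u32::MAX.
    assert!(cube_product_overflows(1024, 1024, 4096));
}
```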
