i-colbert commented on Feb 2, 2026

Reason for this PR

This PR extends graph equalization to enable MixQuant, which calibrates permutations to improve quantization accuracy when using block rotations. If calibrated intentionally (e.g., with mass diffusion), permutations can help balance the distribution of activation magnitudes across blocks prior to rotation. This is particularly beneficial for low-bit quantization (e.g., INT4, FP4) where outlier management is critical.

@article{sanjeet2026mixquant,
  title={MixQuant: Pushing the Limits of Block Rotations in Post-Training Quantization},
  author={Sai Sanjeet and Ian Colbert and Pablo Monteagudo-Lago and Giuseppe Franco and Yaman Umuroglu and Nicholas J. Fraser},
  year={2026},
  eprint={2601.22347},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2601.22347},
}
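To make the intuition concrete, here is a minimal, illustrative sketch of a greedy assignment that balances per-channel magnitudes across rotation blocks. It is not the actual massdiff implementation; balanced_block_permutation is a hypothetical name, and it assumes per-channel absmax statistics and a channel count divisible by the block size:

import torch

def balanced_block_permutation(channel_mags: torch.Tensor, block_size: int) -> torch.Tensor:
    # Greedily assign channels to blocks so each block carries a similar total magnitude
    num_blocks = channel_mags.numel() // block_size
    order = torch.argsort(channel_mags, descending=True)  # largest channels placed first
    block_mass = torch.zeros(num_blocks)
    block_fill = [[] for _ in range(num_blocks)]
    for ch in order.tolist():
        # Drop the next-largest channel into the lightest block that still has room
        open_blocks = [b for b in range(num_blocks) if len(block_fill[b]) < block_size]
        target = min(open_blocks, key=lambda b: block_mass[b].item())
        block_fill[target].append(ch)
        block_mass[target] += channel_mags[ch]
    # Flatten the block assignments into a single channel permutation
    return torch.tensor([ch for block in block_fill for ch in block])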

Changes Made in this PR

  • New Permutation Infrastructure (src/brevitas/graph/equalize.py):

    • Added PermuteGraph class to manage permutation computation and application
    • Implemented 4 permutation strategies: massdiff, zigzag, absmax, and random (illustrative sketches follow this list)
    • Added apply_permute method to Region class for applying permutations
    • Created rotate_permute_mode context manager for unified rotation+permutation workflow
  • CLI Integration (src/brevitas_examples/llm/llm_args.py):

    • Added --apply-permute flag to enable permutation equalization
    • Added --permute-fn argument to select permutation strategy (default: massdiff)
  • Main Workflow Updates (src/brevitas_examples/llm/main.py):

    • Modified fused_rotation_no_fx() to support permutation mode
    • Integrated permutation computation during calibration phase
    • Added proper flow control for rotation+permutation combinations
  • Utility Functions (src/brevitas/graph/utils.py):

    • Added find_node_for_module helper function for graph traversal
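To picture the strategy interface, each option maps per-channel magnitude statistics to a permutation. The bodies below are illustrative stand-ins under that assumed interface, not the Brevitas implementations (massdiff is sketched in the Reason section above):

import torch

def permute_random(mags: torch.Tensor) -> torch.Tensor:
    # Baseline: an arbitrary shuffle of the channel indices
    return torch.randperm(mags.numel())

def permute_absmax(mags: torch.Tensor) -> torch.Tensor:
    # Order channels by descending absolute magnitude
    return torch.argsort(mags.abs(), descending=True)

def permute_zigzag(mags: torch.Tensor) -> torch.Tensor:
    # Interleave large and small channels: largest, smallest, second largest, ...
    order = torch.argsort(mags.abs(), descending=True)
    n = order.numel()
    zig = torch.empty_like(order)
    zig[0::2] = order[:(n + 1) // 2]
    zig[1::2] = order[(n + 1) // 2:].flip(0)
    return zig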

Context Manager Design: The rotate_permute_mode context manager encapsulates the entire workflow:

with rotate_permute_mode(model, permute_fn='massdiff', ...) as rpm:
    # Calibration forward passes happen here to collect activation stats
    model(**calibration_data)
# Permutations are computed and applied on exit
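For readers unfamiliar with the pattern, a stripped-down sketch of how such a context manager can be structured follows: forward hooks collect statistics while the caller calibrates, and cleanup plus permutation fusion happen on exit. permute_mode_sketch and apply_permutation_ are hypothetical placeholders, not the Brevitas API:

from contextlib import contextmanager

import torch

@contextmanager
def permute_mode_sketch(model: torch.nn.Module, permute_fn):
    stats, handles = {}, []

    def make_hook(name):
        def hook(module, inputs, output):
            # Track per-channel absmax on CPU to limit accelerator memory pressure
            mags = inputs[0].detach().abs().amax(dim=tuple(range(inputs[0].dim() - 1))).cpu()
            stats[name] = torch.maximum(stats[name], mags) if name in stats else mags
        return hook

    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            handles.append(module.register_forward_hook(make_hook(name)))
    try:
        yield stats  # the caller runs calibration forward passes here
    finally:
        for h in handles:
            h.remove()  # remove hooks so re-entering never stacks duplicates
        # On exit, each permutation would be computed and fused into the weights:
        # perm = permute_fn(stats[name]); apply_permutation_(model, name, perm)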

Notable Implementation Details:

  1. Permutations are always fused (applied directly to weights), unlike rotations which can be parametrized for optimization
  2. Hook management: Careful tracking of hooked modules to avoid duplicate hooks
  3. Device handling: Activation statistics collected on CPU to avoid OOM issues
  4. SDPA region filtering: SDPA (Scaled Dot-Product Attention) regions are excluded from permutation (at least for now) to avoid head alignment issues. Experiments suggest this is not critical.
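On detail 1, fusing a permutation into weights can be pictured on a simple Linear-to-Linear pair (a sketch under that assumption, not the general Region machinery): permuting the producer's output channels and the consumer's input channels by the same indices leaves the network function unchanged:

import torch

def fuse_permutation(prev: torch.nn.Linear, nxt: torch.nn.Linear, perm: torch.Tensor):
    with torch.no_grad():
        prev.weight.copy_(prev.weight[perm, :])  # permute producer output channels
        if prev.bias is not None:
            prev.bias.copy_(prev.bias[perm])
        nxt.weight.copy_(nxt.weight[:, perm])    # permute consumer input channels to match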

Expected Results

MixQuant demonstrates significant improvements over block rotations alone on Llama-3.2-1B-Instruct with W4A4 per-channel quantization. Using block rotations with block_rotation_dim: 32 and the massdiff permutation strategy, MixQuant achieves:

  • 37% reduction in perplexity (26.9 → 17.0 on WikiText-2)
  • +3.1 pp improvement in average downstream task accuracy
  • Consistent gains across all zero-shot tasks (ARC, HellaSwag, PIQA, Winogrande)

The improvements stem from better activation outlier management through channel permutations that balance magnitude distributions within rotation blocks.
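A toy numeric illustration of this effect (an assumed setup, not a result from the paper): with a 2x2 Hadamard block, stacking both outlier channels into one block concentrates the post-rotation absmax, while permuting them into separate blocks roughly halves it:

import torch

H = torch.tensor([[1., 1.], [1., -1.]]) / 2 ** 0.5  # 2x2 Hadamard rotation block
x = torch.tensor([10., 10., 0.1, 0.1])              # two outlier channels, two small ones

def block_absmax(v):
    # Rotate each 2-channel block independently and report its absmax
    return [round(float((H @ v[i:i + 2]).abs().max()), 2) for i in (0, 2)]

print(block_absmax(x))                # outliers share a block: [14.14, 0.14]
print(block_absmax(x[[0, 2, 1, 3]]))  # outliers split across blocks: [7.14, 7.14]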

| Method     | WikiText-2 PPL ↓ | ARC-C | ARC-E | HellaSwag | PIQA  | Winogrande | Avg ↑ |
|------------|------------------|-------|-------|-----------|-------|------------|-------|
| No Permute | 26.9             | 25.6% | 48.4% | 36.3%     | 64.9% | 50.4%      | 45.1% |
| MixQuant   | 17.0             | 27.7% | 52.4% | 38.3%     | 68.9% | 53.8%      | 48.2% |

Configuration: Both methods use Qronos for error correction with dynamic per-row activations, MSE weight scales, and fused Hadamard rotations. See src/brevitas_examples/papers/mixquant/llama3-mixquant-int4.yml for the full config. You can run this as:

brevitas_ptq_llm --config=llama3-mixquant-int4.yml

Please use https://github.com/i-colbert/brevitas/tree/mixquant/src/brevitas_examples/papers/mixquant to reproduce the experiments from the paper.

Testing Summary

Added test_rotate_permute_mode to tests/brevitas/graph/test_equalization.py

  • Tests with multiple permutation functions: massdiff, zigzag, absmax, and random
  • Tests interop with block_rotation_dim, disable_block_rotation_for_fused, and expansion_step
  • Tests CPU and GPU (if available)
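As a flavor of what such tests can assert (a hedged sketch, not the actual test_rotate_permute_mode body): any strategy must return a valid permutation, i.e., each channel index exactly once, so fusing it into the weights is information-preserving:

import pytest
import torch

def _random_strategy(mags: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for any of the four strategies
    return torch.randperm(mags.numel())

@pytest.mark.parametrize('permute_fn', [_random_strategy])
def test_strategy_returns_valid_permutation(permute_fn):
    mags = torch.rand(64)
    perm = permute_fn(mags)
    # A valid permutation contains each channel index exactly once
    assert torch.equal(torch.sort(perm).values, torch.arange(64))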

Risk Highlight

  • This PR includes code from another work (please detail).
  • This PR contains API-breaking changes.
  • This PR depends on work in another PR (please provide links/details).
  • This PR introduces new dependencies (please detail).
  • There are coverage gaps not covered by tests.
  • Documentation updates required in subsequent PR.

Checklist

  • Code comments added to any hard-to-understand areas, if applicable.
  • Changes generate no new warnings.
  • Updated any relevant tests, if applicable.
  • No conflicts with destination dev branch.
  • I reviewed my own code changes.
  • Initial CI/CD passing.
  • 1+ reviews given, and any review issues addressed and approved.
  • Post-review full CI/CD passing.

# When both rotation and permutation are enabled, use the unified context manager
if args.apply_permute:
    print("Applying permutations...")
    with rotate_permute_mode(
A collaborator commented on this snippet:
If we decouple rotations and permutations, we could first compute all rotations unconditionally and then enter rotate_permute_mode. In that case, there should be an option to pass the rotation rewriters to the class so that it only takes care of applying them (rather than computing them).

