Install ohara:

```shell
pip install ohara
```

To train MLA:

```shell
python train_mla.py --attn_type=mla
```

For the MHA baseline:

```shell
python train_mla.py --attn_type=mha
```

If you want to calculate the number of parameters and check what percentage of KV cache you'll save, visit this link: https://joey00072.github.io/Multi-Head-Latent-Attention-MLA-/
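The KV-cache saving behind that calculator can also be estimated directly. The sketch below is an assumption-laden back-of-envelope, not the repo's actual code: MHA caches full keys and values (2 × n_heads × head_dim per token per layer), while MLA caches only a compressed KV latent plus a small decoupled RoPE key; the dimensions used (`d_latent=512`, `d_rope=64`, etc.) are illustrative placeholders, not ohara's defaults.

```python
def mha_cache_per_token(n_layers: int, n_heads: int, head_dim: int) -> int:
    # MHA stores both K and V for every head in every layer
    return 2 * n_layers * n_heads * head_dim

def mla_cache_per_token(n_layers: int, d_latent: int, d_rope: int) -> int:
    # MLA stores one compressed KV latent plus a decoupled RoPE key per layer
    return n_layers * (d_latent + d_rope)

# Illustrative Llama-7B-ish shapes (assumed, not ohara's config)
n_layers, n_heads, head_dim = 32, 32, 128
d_latent, d_rope = 512, 64

mha = mha_cache_per_token(n_layers, n_heads, head_dim)
mla = mla_cache_per_token(n_layers, d_latent, d_rope)
print(f"KV cache saving: {1 - mla / mha:.1%}")  # → KV cache saving: 93.0%
```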
- Write blog post
- Add JAX version
- Add GQA and MQA in calculation (index.html)
- Distill Llama to an MLA version (maybe)
