Qwen-Inspired Transformer with Mixture of Experts

A from-scratch implementation of a Qwen-inspired transformer language model featuring:

  • RMS Layer Normalization for efficient normalization
  • Multi-Head Self-Attention with causal masking
  • Sparse Mixture of Experts (MoE) for scalable capacity
  • Pre-normalization architecture similar to modern LLMs

🏗️ Architecture

This implementation follows modern transformer design principles used in models like Qwen and LLaMA:

Model Components

  1. Token & Positional Embeddings

    • Learned token embeddings for vocabulary
    • Learned positional embeddings for sequence positions
  2. Transformer Blocks (block.py)

    • Pre-normalization with RMS LayerNorm
    • Multi-head self-attention
    • Sparse Mixture of Experts (MoE)
    • Residual connections (the wiring is sketched right after this list)
  3. RMS Layer Normalization (rms.py)

    • More efficient than standard LayerNorm
    • Used in models like LLaMA and Qwen
    • Normalizes based on root mean square
  4. Multi-Head Attention (multi_head.py)

    • Parallel attention heads for different representation subspaces
    • Causal masking for autoregressive generation
    • Scaled dot-product attention
  5. Sparse Mixture of Experts (moe.py)

    • Router network with top-k selection
    • Multiple expert networks (feed-forward layers)
    • Optional noisy routing for exploration
    • Sparse activation: only top-k experts per token
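
The pieces above combine into the pre-norm residual structure sketched below. This is only an illustration of the wiring (class and attribute names are assumptions, not the repository's block.py API), using PyTorch's built-in RMSNorm and scaled dot-product attention as stand-ins for rms.py and multi_head.py:

import torch
import torch.nn as nn
import torch.nn.functional as F

class TransformerBlockSketch(nn.Module):
    """Illustrative pre-norm block: RMSNorm -> causal attention -> residual,
    then RMSNorm -> sparse MoE -> residual."""
    def __init__(self, emb_dim, num_heads, moe_layer):
        super().__init__()
        self.norm1 = nn.RMSNorm(emb_dim)   # PyTorch >= 2.4; the repo implements its own in rms.py
        self.norm2 = nn.RMSNorm(emb_dim)
        self.qkv = nn.Linear(emb_dim, 3 * emb_dim, bias=False)
        self.out_proj = nn.Linear(emb_dim, emb_dim, bias=False)
        self.num_heads = num_heads
        self.moe = moe_layer               # any module mapping (B, T, C) -> (B, T, C)

    def forward(self, x):
        B, T, C = x.shape
        # Attention sub-layer: pre-normalize, attend with a causal mask, add residual
        h = self.norm1(x)
        q, k, v = self.qkv(h).chunk(3, dim=-1)
        q, k, v = (t.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)
                   for t in (q, k, v))
        attn = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        x = x + self.out_proj(attn.transpose(1, 2).reshape(B, T, C))
        # MoE sub-layer: pre-normalize, route through experts, add residual
        return x + self.moe(self.norm2(x))

In the actual code, block.py composes the attention module from multi_head.py and the MoE module from moe.py; the sketch only shows the pre-norm residual wiring and causal masking.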

📁 Project Structure

qwen-from-scratch/
├── model.py        # Main transformer model
├── block.py        # Transformer block (attention + MoE)
├── rms.py          # RMS Layer Normalization
├── multi_head.py   # Multi-head self-attention
├── moe.py          # Sparse Mixture of Experts
├── pyproject.toml  # Project dependencies
└── README.md       # This file

🚀 Installation

This project uses uv for dependency management. If you don't have it installed:

# Install uv
curl -LsSf https://astral.sh/uv/install.sh | sh

Then install dependencies:

# Install dependencies
uv sync

Alternatively, with pip:

pip install torch numpy tiktoken

💡 Usage

Custom Model

import torch
from model import Model

# Define model configuration
config = {
    "vocab_size": 50257,      # Vocabulary size (e.g., GPT-2 tokenizer)
    "context_length": 256,     # Maximum sequence length
    "emb_dim": 512,            # Embedding dimension
    "num_layers": 6,           # Number of transformer blocks
    "num_heads": 8,            # Number of attention heads
    "dropout": 0.1,            # Dropout probability
    "qkv_bias": False          # Use bias in attention projections
}

# MoE configuration
num_experts = 8      # Number of expert networks
top_k = 2            # Number of experts to activate per token
use_noisy_top_k = True  # Add noise to routing (helpful during training)

# Create model
model = Model(config, num_experts, top_k, use_noisy_top_k)

# Forward pass
input_ids = torch.randint(0, config["vocab_size"], (2, 128))  # (batch_size, seq_len)
logits = model(input_ids)  # (batch_size, seq_len, vocab_size)

print(f"Input shape: {input_ids.shape}")
print(f"Output shape: {logits.shape}")

Text Generation Example

import torch
import torch.nn.functional as F
from model import Model
import tiktoken

# Initialize model
config = {
    "vocab_size": 50257,
    "context_length": 256,
    "emb_dim": 512,
    "num_layers": 6,
    "num_heads": 8,
    "dropout": 0.1,
    "qkv_bias": False
}

model = Model(config, num_experts=8, top_k=2, use_noisy_top_k=True)
model.eval()

# Initialize tokenizer
tokenizer = tiktoken.get_encoding("gpt2")

# Generate text
def generate(model, prompt, max_tokens=50, temperature=1.0):
    """Generate text using the model."""
    tokens = tokenizer.encode(prompt)
    tokens = torch.tensor(tokens).unsqueeze(0)  # Add batch dimension
    
    for _ in range(max_tokens):
        # Get predictions (crop to the model's context window so long prompts don't overflow it)
        with torch.no_grad():
            logits = model(tokens[:, -config["context_length"]:])
        
        # Get logits for last token
        logits = logits[:, -1, :] / temperature
        
        # Sample next token
        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        
        # Append to sequence
        tokens = torch.cat([tokens, next_token], dim=1)
        
        # Check if we've reached max context length
        if tokens.size(1) >= config["context_length"]:
            break
    
    return tokenizer.decode(tokens[0].tolist())

# Generate text
prompt = "Once upon a time"
generated_text = generate(model, prompt)
print(generated_text)

🔧 Configuration

Model Hyperparameters

Parameter        Description                      Typical Values
vocab_size       Size of vocabulary               50257 (GPT-2), 32000 (LLaMA)
context_length   Maximum sequence length          256, 512, 1024, 2048
emb_dim          Embedding dimension              256, 512, 768, 1024
num_layers       Number of transformer blocks     6, 12, 24, 32
num_heads        Number of attention heads        8, 12, 16, 32
dropout          Dropout probability              0.0, 0.1, 0.2
qkv_bias         Bias in attention projections    False (LLaMA/Qwen), True (GPT)

MoE Hyperparameters

Parameter         Description                     Typical Values
num_experts       Total number of experts         4, 8, 16, 32
top_k             Experts activated per token     2, 4
use_noisy_top_k   Add noise to routing            True (training), False (inference)

🧠 How It Works

Sparse Mixture of Experts (MoE)

The MoE layer allows the model to scale capacity without proportionally increasing computation:

  1. Router Network: For each token, a router network computes scores for all experts
  2. Top-k Selection: Only the top-k experts with highest scores are selected
  3. Sparse Routing: Selected experts process the token in parallel
  4. Weighted Combination: Expert outputs are weighted by routing probabilities and summed
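
A minimal sketch of these four steps, written as a simplified stand-in rather than the repository's moe.py (the expert width of 4 × emb_dim and the softplus-scaled noise are assumptions):

import torch
import torch.nn as nn
import torch.nn.functional as F

class SparseMoESketch(nn.Module):
    def __init__(self, emb_dim, num_experts=8, top_k=2, use_noisy_top_k=True):
        super().__init__()
        self.top_k = top_k
        self.use_noisy_top_k = use_noisy_top_k
        self.router = nn.Linear(emb_dim, num_experts)   # step 1: scores for all experts
        self.noise = nn.Linear(emb_dim, num_experts)    # learned noise scale (optional)
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(emb_dim, 4 * emb_dim), nn.GELU(),
                          nn.Linear(4 * emb_dim, emb_dim))
            for _ in range(num_experts)
        )

    def forward(self, x):                               # x: (batch, seq_len, emb_dim)
        scores = self.router(x)                         # (B, T, num_experts)
        if self.use_noisy_top_k and self.training:
            scores = scores + torch.randn_like(scores) * F.softplus(self.noise(x))
        topk_scores, topk_idx = scores.topk(self.top_k, dim=-1)   # step 2: top-k selection
        weights = F.softmax(topk_scores, dim=-1)        # routing probabilities
        out = torch.zeros_like(x)
        for i, expert in enumerate(self.experts):       # step 3: sparse routing
            selected = (topk_idx == i)                  # (B, T, top_k) selection mask
            token_mask = selected.any(dim=-1)           # tokens routed to expert i
            if token_mask.any():
                w = (weights * selected).sum(dim=-1)[token_mask]            # weight for expert i
                out[token_mask] += w.unsqueeze(-1) * expert(x[token_mask])  # step 4: weighted sum
        return out

The loop over experts keeps the sketch readable; efficient implementations batch the tokens routed to each expert instead.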

Benefits:

  • Larger model capacity at roughly the same per-token computational cost
  • Different tokens can be routed to different, specialized experts
  • Better parameter efficiency

RMS Layer Normalization

RMS normalization is simpler and more efficient than standard LayerNorm:

RMSNorm(x) = (x / RMS(x)) * γ

where RMS(x) = sqrt(mean(x²) + ε)

Advantages:

  • No mean subtraction: unlike LayerNorm, only the root mean square is computed
  • Fewer operations → faster training and inference
  • Used in modern LLMs like LLaMA and Qwen
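
In code, the formula above comes down to a few lines. A minimal sketch (the repository's rms.py may differ in details such as the ε value):

import torch
import torch.nn as nn

class RMSNormSketch(nn.Module):
    def __init__(self, emb_dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(emb_dim))  # learnable scale γ

    def forward(self, x):
        # RMS(x) = sqrt(mean(x²) + ε), taken over the embedding dimension
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (x / rms) * self.gamma

Note that, unlike LayerNorm, no mean is subtracted, which is where the saving in operations comes from.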

📊 Model Size Estimation

Use this formula to estimate model parameters:

Total Parameters ≈ 
  vocab_size × emb_dim × 2                    # Embeddings
  + num_layers × (
      4 × emb_dim²                            # Attention
      + 2 × emb_dim                           # RMS Norms
      + num_experts × 8 × emb_dim²            # MoE Experts
      + emb_dim × num_experts                 # Router
    )
  + emb_dim × vocab_size                      # Output projection

Example: With emb_dim=512, num_layers=6, num_experts=8, vocab_size=50257:

  • Approximately 184M parameters in total
  • But only ~109M active per token (due to sparse MoE with top_k=2, each token uses 2 of the 8 experts)
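
A quick script that evaluates the formula above for this configuration (the 8 × emb_dim² expert term assumes a feed-forward hidden size of 4 × emb_dim):

emb_dim, num_layers, num_experts, vocab_size, top_k = 512, 6, 8, 50257, 2

embeddings  = vocab_size * emb_dim * 2
per_layer   = (4 * emb_dim**2                     # attention
               + 2 * emb_dim                      # RMS norms
               + num_experts * 8 * emb_dim**2     # MoE experts
               + emb_dim * num_experts)           # router
output_proj = emb_dim * vocab_size

total  = embeddings + num_layers * per_layer + output_proj
active = total - num_layers * (num_experts - top_k) * 8 * emb_dim**2

print(f"Total parameters:  {total / 1e6:.0f}M")   # ~184M
print(f"Active per token:  {active / 1e6:.0f}M")  # ~109M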

🤝 Contributing

Feel free to open issues or submit pull requests for:

  • Bug fixes
  • Documentation improvements
  • New features
  • Performance optimizations

📝 License

This is an educational implementation. Feel free to use it for learning and research purposes.

🙏 Acknowledgments

This implementation is inspired by:

  • Qwen by Alibaba Cloud
  • LLaMA by Meta AI
  • Various open-source transformer implementations

Note: This is a from-scratch educational implementation. For production use, consider using established libraries like Hugging Face Transformers or PyTorch's built-in modules.
