A multimodal transformer research project combining BitNet 1.58-bit ternary quantization with an embedding-prediction architecture, optimized for consumer GPU hardware.
Current Status: Active Development (v0.2.0)
| Component | Status |
|---|---|
| Core configuration | Implemented |
| Model architecture (attention, MLP, layers) | Implemented |
| BitNet 1.58-bit quantization | Implemented |
| Multimodal tokenization | Implemented |
| Training loop (BitNet QAT) | Implemented |
| Inference engine (streaming) | Implemented |
| Progressive layer loading | Implemented |
| Dataset curation pipeline | Implemented |
| Quality gates (security and code quality) | Implemented |
| Development tooling | Implemented |
Tritter implements a decoder-only transformer with the following design goals:
- BitNet b1.58 quantization: Ternary weights {-1, 0, +1} for ~10x memory reduction (see the sketch after this list)
- Multimodal tokenization: Unified vocabulary for text, code, image, and audio
- Embedding-prediction paradigm: Operations in continuous embedding space rather than discrete token space
- Consumer GPU target: Designed for RTX 5080 16GB VRAM constraints
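For intuition, BitNet b1.58's absmean quantization maps a weight tensor onto the ternary set by scaling it with its mean absolute value, then rounding and clipping. The sketch below is illustrative only (the function name is ours); the project's actual implementation lives in `src/tritter/quantization/bitnet.py`:

```python
import torch

def absmean_ternary(w: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Map weights to {-1, 0, +1} via absmean scaling (BitNet b1.58)."""
    scale = w.abs().mean().clamp(min=eps)  # mean absolute value of the tensor
    return (w / scale).round().clamp(-1, 1)

w = torch.randn(4, 4)
print(absmean_ternary(w))  # every entry is -1., 0., or 1.
```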
The architecture operates in continuous embedding space:
- Entry point: Tokenization converts discrete tokens to embeddings
- Core computation: Transformer layers operate on continuous embeddings
- Exit point: Output projection to logits (temporary for training compatibility)
Production inference will use KNN/VQ rounding instead of argmax token selection, enabling continuous latent reasoning.
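As a rough illustration of embedding-space rounding, the sketch below snaps a predicted embedding to its nearest entry in the token embedding table. The metric (L2), sizes, and names here are assumptions, and the production path may use a VQ codebook instead:

```python
import torch

def knn_round(pred_emb: torch.Tensor, emb_table: torch.Tensor) -> torch.Tensor:
    """Snap predicted embeddings (B, D) to nearest rows of emb_table (V, D)."""
    dists = torch.cdist(pred_emb, emb_table)  # (B, V) pairwise L2 distances
    return dists.argmin(dim=-1)               # (B,) nearest-token indices

table = torch.randn(32_000, 512)  # hypothetical 32K vocab, 512-dim embeddings
pred = torch.randn(2, 512)
print(knn_round(pred, table).shape)  # torch.Size([2])
```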
Primary target: NVIDIA RTX 5080 with 16GB GDDR7
| Component | Memory Budget |
|---|---|
| 7B BitNet weights | ~1.4 GB |
| INT4 KV-cache (128K context) | ~8-10 GB |
| Activations + overhead | ~2-3 GB |
| Vision encoder | ~0.4 GB |
| Total | ~12-15 GB |
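The budget can be sanity-checked with back-of-envelope arithmetic. The model shape below (32 layers, 16 KV heads of dimension 128) is an illustrative assumption, not Tritter's actual 7B configuration:

```python
params = 7e9
weights_gb = params * 1.58 / 8 / 1e9           # 1.58 bits per ternary weight
print(f"BitNet weights: {weights_gb:.2f} GB")  # ~1.38 GB

layers, kv_heads, head_dim = 32, 16, 128       # assumed shape
context = 128 * 1024
int4_bytes = 0.5                               # 4 bits per element
kv_gb = 2 * layers * kv_heads * head_dim * context * int4_bytes / 1e9
print(f"INT4 KV cache: {kv_gb:.1f} GB")        # ~8.6 GB, inside the 8-10 GB row
```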
Requires Python 3.12 or 3.13. GPU acceleration needs CUDA 12.1+.
```bash
# Clone repository
git clone https://github.com/tzervas/tritter.git
cd tritter

# Create virtual environment (Python 3.13 recommended)
uv venv --python 3.13 .venv
source .venv/bin/activate

# Install PyTorch (standard CUDA)
uv pip install torch==2.5.1+cu121 --index-url https://download.pytorch.org/whl/cu121

# RTX 50-series (SM_120) pinned nightly
# export TRITTER_BLACKWELL_TORCH_VERSION="2.11.0.dev20260123+cu128"
# export TRITTER_BLACKWELL_TRITON_VERSION="3.6.0+git9844da95"
# uv pip install "torch==${TRITTER_BLACKWELL_TORCH_VERSION}" --pre --index-url https://download.pytorch.org/whl/nightly/cu128
# uv pip install "triton==${TRITTER_BLACKWELL_TRITON_VERSION}" --pre

# Install with development dependencies (extras recommended)
uv pip install -e ".[dev,training,inference,curation,extras]"
```

Verify the installation:

```bash
python -c "import torch; print(f'CUDA: {torch.cuda.is_available()}'); print(f'GPU: {torch.cuda.get_device_name(0)}')"
```

Pip fallback:

```bash
python -m venv .venv
source .venv/bin/activate
pip install -e ".[dev,training,inference,curation,extras]"
```

See docs/CUDA_SETUP.md for detailed CUDA configuration and troubleshooting.
```python
import torch

from tritter import TritterConfig, TritterModel
from tritter.tokenization import MultiModalTokenizer, ModalityType

# Initialize configuration
config = TritterConfig(
    model_size="3B",  # Auto-configures architecture
    use_bitnet=True,
    use_flash_attention=True,
)

# Create model
model = TritterModel(config)
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

# Initialize tokenizer
tokenizer = MultiModalTokenizer(vocab_size=config.vocab_size)

# Encode text
text = "def hello_world():\n    print('Hello, World!')"
tokens = tokenizer.encode(text, ModalityType.CODE)
input_ids = torch.tensor([tokens])

# Forward pass
with torch.no_grad():
    logits = model(input_ids)

print(f"Input shape: {input_ids.shape}")
print(f"Output shape: {logits.shape}")
```

See examples/basic_usage.py for a complete demonstration.
```bash
# Run all tests
pytest

# Run specific test file
pytest tests/unit/test_config.py

# Run with coverage
pytest --cov=src/tritter --cov-report=html
```

```bash
# Format code
ruff format .

# Lint
ruff check .

# Type check (strict mode)
mypy src/tritter

# Verify imports
python -c "from tritter import *; print('OK')"
```

Prepare training data and run pretraining:
```bash
# Curate training data from source code
python scripts/prepare_pretrain_data.py \
    --input-dir /path/to/code \
    --output-dir data/pretrain

# Train model (requires GPU)
python scripts/train_pretrain.py --model 1B --data-dir data/pretrain
```

See CLAUDE.md for full training pipeline documentation.
The devtools/ module provides development utilities:
```bash
# Run full validation suite
python -m devtools validate

# Quick validation (skip tests)
python -m devtools validate --quick

# Project status
python -m devtools status

# Implementation roadmap
python -m devtools status --roadmap
```

All contributions must follow the standards documented in docs/DEVELOPMENT_STANDARDS.md:
- Google-style docstrings with "Why" sections explaining design decisions
- Tensor shapes documented in comments: `x = proj(hidden)  # (B, L, D)`
- Tests use config values (never hardcoded magic numbers)
- Parameter count tests include bounds checking (see the test sketch below)
- `__all__` exports must match imports in `__init__.py` files
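A hypothetical sketch of a bounds-checked parameter count test in this style (the tolerance is an assumption, not a documented project requirement):

```python
from tritter import TritterConfig, TritterModel

def test_parameter_count_within_bounds() -> None:
    config = TritterConfig(model_size="3B", use_bitnet=True)  # values from config
    model = TritterModel(config)
    n_params = sum(p.numel() for p in model.parameters())
    nominal = 3e9  # implied by model_size above
    assert 0.9 * nominal <= n_params <= 1.1 * nominal  # +/-10% tolerance (assumed)
```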
| Component | File | Purpose |
|---|---|---|
| `TritterConfig` | `src/tritter/core/config.py` | Configuration with auto-scaling for 3B/7B |
| `TritterModel` | `src/tritter/models/architecture.py` | Full transformer model |
| `TritterAttention` | `src/tritter/models/architecture.py` | Multi-head attention with QK-Norm |
| `TritterMLP` | `src/tritter/models/architecture.py` | FFN with Squared ReLU |
| `TernaryWeight` | `src/tritter/quantization/bitnet.py` | BitNet quantization with STE |
| `MultiModalTokenizer` | `src/tritter/tokenization/multimodal.py` | Unified multimodal tokenization |
The architecture follows BitNet b1.58 constraints:
- Squared ReLU (`x * ReLU(x)`) activation in MLP layers
- QK-Norm for attention stability
- Post-FFN LayerNorm (Chameleon-style placement)
- Shadow weights in full precision for STE training (see the sketch after this list)
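A minimal sketch of how shadow weights, STE, and Squared ReLU typically fit together (class and function names here are ours; the project's actual version is `TernaryWeight` in `src/tritter/quantization/bitnet.py`):

```python
import torch
import torch.nn as nn

def squared_relu(x: torch.Tensor) -> torch.Tensor:
    return x * torch.relu(x)  # equivalent to relu(x) ** 2

class TernaryLinear(nn.Module):
    """STE-trained ternary linear layer (illustrative only)."""

    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        # Full-precision "shadow" weights receive the gradient updates.
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        scale = self.weight.abs().mean().clamp(min=1e-5)
        w_q = (self.weight / scale).round().clamp(-1, 1) * scale
        # STE: forward uses w_q; backward treats the rounding as identity.
        w = self.weight + (w_q - self.weight).detach()
        return x @ w.t()
```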
The current implementation uses PyTorch SDPA with `is_causal=True` for FlashAttention-2 optimization (a minimal call pattern is sketched after the list below). Planned enhancements:
- FlexAttention for dynamic masking (sliding window, document boundaries)
- Multiple attention modes (causal, bidirectional, prefix-lm)
- StreamingLLM attention sinks for streaming inference
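For reference, the current SDPA call pattern looks like the following; PyTorch dispatches to a FlashAttention-2 kernel when dtypes, shapes, and hardware allow:

```python
import torch
import torch.nn.functional as F

B, H, L, D = 2, 8, 16, 64  # batch, heads, sequence length, head dim
q, k, v = (torch.randn(B, H, L, D) for _ in range(3))
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 16, 64])
```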
| Document | Description |
|---|---|
| `docs/project-plan.md` | Technical blueprint and research foundations |
| `docs/DEVELOPMENT_STANDARDS.md` | Code standards and requirements |
| `docs/API_CONVENTIONS.md` | Interface patterns and conventions |
| `docs/CONTRIBUTING.md` | Contribution guidelines |
| `docs/clean-datasets.md` | Training data strategy |
| `docs/considerations.md` | Research on alternative architectures |
| `CLAUDE.md` | AI assistant guidelines |
```
tritter/
├── src/tritter/       # Core model implementation
│   ├── core/          # Configuration
│   ├── models/        # Architecture components
│   ├── quantization/  # BitNet implementation
│   ├── tokenization/  # Multimodal tokenization
│   ├── training/      # Training loop
│   ├── inference/     # Inference engine
│   └── utils/         # Utilities
├── devtools/          # Development tooling
├── tests/             # Test suite
├── examples/          # Usage examples
└── docs/              # Documentation
```
This project builds on published research:
- BitNet b1.58: Microsoft's ternary quantization achieving ~10x memory reduction
- Chameleon: Meta's early-fusion multimodal architecture
- Coconut/LCM: Embedding-prediction paradigm for continuous latent reasoning
- FlashAttention-2: Memory-efficient attention with tiled computation
See docs/project-plan.md for detailed citations and technical analysis.
- No pretrained weights available
- Multimodal capabilities (image, audio) require additional encoder integration
- The RTX 5080 16GB memory budget has been validated on real hardware, but some CUDA kernels may fall back to unoptimized paths until SM_120 compute capability support lands in stable PyTorch releases
MIT License. See LICENSE for details.
Tyler Zervas (tz-dev@vectorweight.com)