vit-trainer

PyPI · CI · Open In Colab · License: MIT · Python 3.8+

A simple, educational package for fine-tuning Vision Transformer (ViT) models using PyTorch. Achieves 97.65% accuracy on CIFAR-10 with modern training techniques.


Why vit-trainer?

timm / transformers           vit-trainer
1000+ model architectures     Focused on ViT fine-tuning
Complex APIs                  Simple, readable code
Research-oriented             Educational and production-ready

Features:

  • Mixed precision training (AMP) for a 2-3x speedup
  • AdamW optimizer with warmup + cosine annealing (see the sketch below)
  • Attention visualization for interpretability
  • ONNX export for deployment
  • CLI and Python API
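
The Trainer wires all of this up for you. For orientation, here is a rough sketch of what a mixed-precision training step with warmup + cosine annealing looks like in plain PyTorch. It is illustrative only, not the package's actual implementation, and it assumes a CUDA device:

import torch
import torch.nn.functional as F

model = torch.nn.Linear(768, 10).cuda()  # stand-in for a ViT
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.05)
warmup = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01, total_iters=2)
cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=8)
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup, cosine], milestones=[2]
)
scaler = torch.cuda.amp.GradScaler()

for epoch in range(10):
    x = torch.randn(64, 768).cuda()            # dummy batch
    y = torch.randint(0, 10, (64,)).cuda()
    optimizer.zero_grad(set_to_none=True)
    with torch.cuda.amp.autocast():            # forward pass in mixed precision
        loss = F.cross_entropy(model(x), y)
    scaler.scale(loss).backward()              # scale loss to avoid fp16 underflow
    scaler.step(optimizer)                     # unscale gradients, then step
    scaler.update()
    scheduler.step()                           # per-epoch LR schedule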

Installation

pip install vit-trainer

Optional Dependencies

# Gradio web demo
pip install "vit-trainer[demo]"

# ONNX export
pip install "vit-trainer[export]"

# Everything
pip install "vit-trainer[all]"

Install from Source

git clone https://github.com/jman4162/PyTorch-Vision-Transformers-ViT.git
cd PyTorch-Vision-Transformers-ViT
pip install -e ".[dev]"

Quick Start

Python API

from vit_trainer import Trainer, load_model, get_cifar10_loaders

# Load data and model
train_loader, val_loader, test_loader = get_cifar10_loaders(batch_size=64)
model = load_model("vit_b_16", num_classes=10)

# Train
trainer = Trainer(model, lr=1e-4, use_amp=True)
history = trainer.fit(train_loader, val_loader, epochs=10)

# Evaluate
loss, accuracy = trainer.evaluate(test_loader)
print(f"Test Accuracy: {accuracy:.2f}%")

Command Line Interface

# Train a model
vit-train train --model vit_b_16 --dataset cifar10 --epochs 10

# Evaluate a trained model
vit-train eval --checkpoint best_model.pt --dataset cifar10 --plot-confusion

# Predict on a single image
vit-train predict --checkpoint best_model.pt --image cat.jpg --show-attention

# Export to ONNX
vit-train export --checkpoint best_model.pt --output model.onnx

Configuration Files

# Use YAML config
vit-train train --config configs/default.yaml
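
A config file mirrors the TrainingConfig fields shown under Usage Examples below. The exact schema of configs/default.yaml may differ, but a hypothetical file along these lines should convey the idea:

# Hypothetical config; field names assumed to mirror the TrainingConfig dataclass.
model_variant: vit_b_16
batch_size: 64
epochs: 10
lr: 1.0e-4
weight_decay: 0.05
warmup_epochs: 2
patience: 3
use_amp: true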

Usage Examples

Training with Custom Settings

from vit_trainer import Trainer, load_model, get_cifar10_loaders, TrainingConfig

# Create config
config = TrainingConfig(
    model_variant="vit_b_16",
    batch_size=64,
    epochs=10,
    lr=1e-4,
    weight_decay=0.05,
    warmup_epochs=2,
    patience=3,
    use_amp=True,
)

# Train
train_loader, val_loader, _ = get_cifar10_loaders(batch_size=config.batch_size)
model = load_model(config.model_variant, num_classes=10)
trainer = Trainer(
    model,
    lr=config.lr,
    weight_decay=config.weight_decay,
    warmup_epochs=config.warmup_epochs,
    use_amp=config.use_amp,
)
trainer.fit(train_loader, val_loader, epochs=config.epochs, patience=config.patience)

Attention Visualization

from vit_trainer import visualize_samples_with_attention, CIFAR10_CLASSES

visualize_samples_with_attention(
    model,
    test_loader.dataset,
    CIFAR10_CLASSES,
    num_samples=4,
)
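
For context, a common way to turn ViT attention into a per-patch heatmap is attention rollout (Abnar & Zuidema, 2020): average each layer's attention over heads, add the identity for the residual connection, compose the layers by matrix multiplication, and read off the CLS-token row. The package's internal method may differ; this is a self-contained sketch:

import torch

def attention_rollout(attentions):
    """attentions: list of (batch, heads, tokens, tokens) maps, one per layer.
    Returns (batch, patches) CLS-to-patch relevance scores."""
    result = None
    for attn in attentions:
        a = attn.mean(dim=1)                            # average over heads
        a = a + torch.eye(a.size(-1), device=a.device)  # identity for residual path
        a = a / a.sum(dim=-1, keepdim=True)             # re-normalize rows
        result = a if result is None else a @ result    # compose layer by layer
    return result[:, 0, 1:]                             # CLS row, drop CLS column

# Toy usage with random "attention" for a 197-token ViT-B/16 input:
attns = [torch.rand(1, 12, 197, 197).softmax(dim=-1) for _ in range(12)]
heatmap = attention_rollout(attns).reshape(1, 14, 14)   # 196 patches -> 14x14 grid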

Evaluation Metrics

from vit_trainer import get_predictions, compute_metrics, plot_confusion_matrix

y_pred, y_true, probs = get_predictions(model, test_loader)
metrics = compute_metrics(y_true, y_pred, CIFAR10_CLASSES)

print(metrics["classification_report"])
plot_confusion_matrix(y_true, y_pred, CIFAR10_CLASSES)

Loading Trained Models

from vit_trainer import load_model

# Load from checkpoint
model = load_model(
    "vit_b_16",
    num_classes=10,
    checkpoint_path="best_model.pt",
)
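
If you need the weights outside the package, the checkpoint can usually be read with plain torch.load. The "model_state_dict" key below is an assumption about this package's checkpoint layout:

import torch
from vit_trainer import load_model

model = load_model("vit_b_16", num_classes=10)
ckpt = torch.load("best_model.pt", map_location="cpu")
# A checkpoint is often either a bare state_dict or a dict wrapping one;
# the "model_state_dict" key is an assumption about this package's format.
state = ckpt.get("model_state_dict", ckpt)
model.load_state_dict(state)
model.eval()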

ONNX Export

from vit_trainer import load_model, ExportConfig

# Load trained model
model = load_model("vit_b_16", num_classes=10, checkpoint_path="best_model.pt")

# Export to ONNX
config = ExportConfig(output_path="model.onnx", opset_version=14)
config.export(model)

# Or use CLI
# vit-train export --checkpoint best_model.pt --output model.onnx
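
Once exported, the model runs without PyTorch via onnxruntime (installed separately). A minimal sanity check, assuming the standard 224x224 ViT input:

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("model.onnx")
input_name = sess.get_inputs()[0].name                   # query instead of hard-coding
x = np.random.rand(1, 3, 224, 224).astype(np.float32)    # dummy 224x224 RGB batch
logits = sess.run(None, {input_name: x})[0]              # first (assumed only) output
print(logits.shape)                                      # expected: (1, 10)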

API Reference

from vit_trainer import (
    # Configuration
    TrainingConfig,           # Training hyperparameters
    ExportConfig,             # ONNX export settings

    # Models
    load_model,               # Load ViT with pretrained weights
    VIT_VARIANTS,             # Available model variants

    # Data
    get_cifar10_loaders,      # CIFAR-10 data loaders
    get_cifar100_loaders,     # CIFAR-100 data loaders
    CIFAR10_CLASSES,          # Class names

    # Training
    Trainer,                  # Training loop with AMP
    EarlyStopping,            # Early stopping callback
    ModelCheckpoint,          # Save best model

    # Evaluation
    evaluate_model,           # Loss and accuracy
    compute_metrics,          # Precision, recall, F1
    plot_confusion_matrix,    # Visualization

    # Visualization
    visualize_attention,      # Attention heatmaps
)

Project Structure

vit-trainer/
├── vit_trainer/
│   ├── __init__.py         # Public API
│   ├── config.py           # TrainingConfig dataclass
│   ├── cli.py              # Command-line interface
│   ├── data/               # Data loaders and transforms
│   ├── models/             # Model registry and factory
│   ├── training/           # Trainer and callbacks
│   ├── evaluation/         # Metrics and plotting
│   └── visualization/      # Attention maps
├── tests/                  # Unit tests (44 tests)
├── configs/                # YAML configurations
├── notebooks/              # Tutorial notebooks
├── app.py                  # Gradio demo
└── pyproject.toml          # Package configuration

ViT Variants

Variant    Patch Size   Parameters   ImageNet Acc.   Use Case
vit_b_16   16x16        86M          81.1%           Best accuracy/speed trade-off
vit_b_32   32x32        88M          75.9%           Faster inference
vit_l_16   16x16        304M         79.7%           Larger capacity
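
Variant names map directly to load_model. A quick way to compare sizes (VIT_VARIANTS appears in the API reference; its exact contents are assumed here to list these names):

from vit_trainer import load_model, VIT_VARIANTS

print(VIT_VARIANTS)                                # available variant names
model = load_model("vit_b_32", num_classes=10)
n_params = sum(p.numel() for p in model.parameters())
print(f"vit_b_32: {n_params / 1e6:.0f}M parameters")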

Training Results

Metric          Value
Test Accuracy   97.65%
Model           vit_b_16
Training Time   ~11 min/epoch (GPU)

Gradio Demo

# Launch interactive web interface
python app.py
# Opens at http://localhost:7860
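
For orientation, a minimal Gradio classifier interface typically looks like the sketch below. This is not the repo's actual app.py, and the preprocessing transforms are standard ImageNet-style assumptions rather than the package's documented pipeline:

import gradio as gr
import torch
from torchvision import transforms
from vit_trainer import load_model, CIFAR10_CLASSES

model = load_model("vit_b_16", num_classes=10, checkpoint_path="best_model.pt")
model.eval()

# Assumed preprocessing; the package's exact transforms may differ.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def predict(image):
    x = preprocess(image).unsqueeze(0)
    with torch.no_grad():
        probs = model(x).softmax(dim=-1)[0]
    return {c: float(p) for c, p in zip(CIFAR10_CLASSES, probs)}

gr.Interface(fn=predict, inputs=gr.Image(type="pil"),
             outputs=gr.Label(num_top_classes=3)).launch()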

Development

# Install dev dependencies
pip install -e ".[dev]"

# Run tests
pytest tests/

# Format code
black vit_trainer/
ruff check vit_trainer/

# Type check
mypy vit_trainer/

Troubleshooting

CUDA Out of Memory

  • Reduce batch size: --batch-size 32 or 16
  • AMP is enabled by default

Slow Training on CPU

  • Use Google Colab (free GPU)
  • Training on CPU is very slow (~60 min/epoch)

Import Errors

  • Make sure to install the package: pip install vit-trainer

Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

License

Distributed under the MIT License. See LICENSE for more information.
