
Exploring Sparse Autoencoders for Mechanistic Interpretability

Overview

Research project investigating how Sparse Autoencoders (SAEs) learn and represent features from transformer models. Focuses on understanding concept emergence, activation patterns, and neuron specialization through multiple activation functions and comprehensive analysis tools.

Methodology

Decomposing the "Neural Soup" of GPT-2

Welcome to a from-scratch implementation of Sparse Autoencoders designed to interpret the internal activations of Large Language Models (specifically GPT-2 Small).

This project is an exploration into Mechanistic Interpretability—the science of reverse-engineering neural networks to understand not just what they do, but how they think.

The Core Concepts

  • Sparse Autoencoders (SAEs): Neural networks that learn efficient representations of input data by enforcing sparsity in the hidden layers. Sparsity means that only a small fraction of neurons are active at any given time, which can lead to more interpretable features.
  • Activation Functions: Different functions (ReLU, JumpReLU, TopK) are employed to enforce sparsity and to study their effect on feature learning; a minimal sketch of the three appears after this list.
  • Concept Emergence: Tracking how distinct features or "concepts" arise in the hidden layers during training.
  • Neuron Specialization: Analyzing how individual neurons develop specific roles based on their activation patterns.
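
For concreteness, here is a minimal sketch of how the three sparsity mechanisms might be implemented; it is illustrative rather than the repository's code, and the default threshold and k values are arbitrary placeholders:

import torch

def relu(pre_acts: torch.Tensor) -> torch.Tensor:
    # Standard ReLU: zero out negative pre-activations.
    return torch.relu(pre_acts)

def jump_relu(pre_acts: torch.Tensor, threshold: float = 0.1) -> torch.Tensor:
    # JumpReLU: also zero out activations at or below a fixed threshold,
    # suppressing small, noisy activations.
    return torch.where(pre_acts > threshold, pre_acts, torch.zeros_like(pre_acts))

def topk(pre_acts: torch.Tensor, k: int = 32) -> torch.Tensor:
    # TopK: keep only the k largest activations per sample and zero the rest,
    # fixing the number of active features directly.
    values, indices = torch.topk(pre_acts, k, dim=-1)
    return torch.zeros_like(pre_acts).scatter_(-1, indices, torch.relu(values))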

1. The Problem: Polysemanticity ("The Soup")

  • Large language models exhibit polysemanticity (Many-Meanings): a single neuron responds to multiple unrelated features. For instance, a neuron might activate for both "cat" and "satellite," making it hard to decipher its true function.
  • This happens because the model compresses far more concepts than it has neurons into a limited space, a phenomenon known as "Superposition."
  • The result is that looking at raw neuron activations is like looking at a bowl of soup—you see a mix of ingredients but can't easily identify each one. The internal representations are entangled and hard to interpret.

2. The Solution: Monosemanticity ("The Ingredients")

  • We want to map the network to a state of Monosemanticity (One-Meaning), where each neuron corresponds to a single, distinct concept. This makes it easier to understand what each neuron is doing.
  • We want to find specific features in the model's activations, such as "cat," "satellite," or "the concept of 'being in space,'" and have individual neurons represent these features clearly.

3. The Tool: Sparse Autoencoders (SAEs)

  • The SAE acts like a Prism.

    • It takes the "white light" of raw activations (the soup)
    • It expands them into a much larger, sparsely activated hidden dimension
    • It forces the data to separate into distinct, interpretable "rays" (the features)

Key Research Questions

  • How do different activation functions affect feature learning in SAEs?
  • What drives concept emergence in hidden layers when training on transformer activations?
  • How reliable are activation frequency patterns as indicators of neuron specialization?
  • Can we quantify and visualize neuron behavior during training?
  • How do different sparsity mechanisms impact feature interpretability?

Installation

Requirements

  • Python 3.8+
  • CUDA capable GPU (recommended)
  • 8GB+ RAM

Setup

# Create virtual environment
python -m venv venv

# Activate environment
.\venv\Scripts\activate  # Windows
source venv/bin/activate  # Linux/Mac

# Install dependencies
pip install -r requirements.txt

# Install additional visualization tools
pip install umap-learn wandb

Project Structure

The Stack

  • transformer_data.py (The Harvester): Hooks into a pre-trained transformer model (GPT-2 Small) using TransformerLens to extract raw "resid_pre" activations from specific layers.

  • models/autoencoder.py (The Prism): A custom PyTorch implementation of the SAE (see the sketch after this list).

    • Encoder: supports standard ReLU, JumpReLU, and TopK activation functions to enforce sparsity.
    • Decoder: reconstructs the original activations from the sparse code learned by the encoder, so reconstruction fidelity can be measured.
  • experiments/ (The Laboratory): Scripts for experiments on activation functions, concept emergence, frequency patterns, and more, plus the training loops that manage the trade-off between reconstruction loss (MSE) and sparsity (an L1 penalty): MSE measures how faithfully the SAE reconstructs the original input, while the L1 norm of the hidden activations encourages sparse codes.

  • visualization/ (The Microscope): Tools for visualizing training progress and results, including ASCII terminal outputs and W&B dashboards.
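
A minimal sketch of the core SAE, assuming a plain linear encoder/decoder pair with a ReLU sparsity activation; the class and attribute names are illustrative and may not match models/autoencoder.py exactly:

import torch
import torch.nn as nn

class SparseAutoencoder(nn.Module):
    # Illustrative SAE: expand activations into a larger, sparse code, then reconstruct.
    def __init__(self, input_dim: int, hidden_dim: int):
        super().__init__()
        self.encoder = nn.Linear(input_dim, hidden_dim)   # d_model -> (over)complete hidden dim
        self.decoder = nn.Linear(hidden_dim, input_dim)   # hidden dim -> d_model

    def forward(self, x: torch.Tensor):
        code = torch.relu(self.encoder(x))    # sparse feature activations (the "rays")
        recon = self.decoder(code)            # reconstruction of the original activations
        return recon, code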

exploring_saes/
├── experiments/
│   ├── activation_study.py   # Activation function analysis
│   ├── concept_emergence.py  # Feature learning tracking
│   ├── frequency_analysis.py # Neuron firing patterns
│   ├── checkpointing.py     # Experiment state management
│   └── transformer_data.py   # Model integration
├── models/
│   └── autoencoder.py       # SAE implementation
├── visualization/
│   ├── ascii_viz.py         # Terminal visualizations
│   └── wandb_viz.py         # W&B dashboard integration
├── config/
│   └── config.py            # Configuration management
└── run_experiments.py       # Main entry point

Usage

Basic Training

python run_experiments.py --hidden-dim 256 --epochs 100

Transformer Analysis

python run_experiments.py --model-name gpt2-small --layer 0 --n-samples 10000 --seq-len 20 --use-wandb

Notes:

  • Default n_samples is now 10000 (total tokens ~= n_samples * seq_len).
  • Use --seq-len to control the token sequence length sampled from the transformer (default: 20).

Programmatic Usage

If you prefer to harvest activations and use the dataset directly from Python, here's a minimal example:

from experiments.transformer_data import TransformerActivationDataset
import torch

# Create dataset (harvests activations and caches them in memory)
ds = TransformerActivationDataset(model_name='gpt2-small', layer=0, n_samples=1000)

# Use a PyTorch DataLoader for batching
loader = torch.utils.data.DataLoader(ds, batch_size=64, shuffle=True)

for batch in loader:
  x = batch['pixel_values']  # shape: (batch_size, features)
  # pass `x` into your SAE training loop
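
Continuing the example, here is a hypothetical training step that feeds these batches into the illustrative SparseAutoencoder sketched earlier; the 768 input dimension (GPT-2 Small's residual width) and the l1_coeff value are assumptions, not values taken from this repository:

import torch
import torch.nn.functional as F

sae = SparseAutoencoder(input_dim=768, hidden_dim=256)   # 768 = GPT-2 Small residual width
optimizer = torch.optim.Adam(sae.parameters(), lr=1e-3)
l1_coeff = 1e-3                                          # assumed sparsity weight

for batch in loader:
    x = batch['pixel_values']                # (batch_size, d_model) activations
    recon, code = sae(x)
    mse_loss = F.mse_loss(recon, x)          # reconstruction fidelity
    l1_loss = code.abs().mean()              # sparsity penalty on the hidden code
    loss = mse_loss + l1_coeff * l1_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()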

Configuration Options

Parameter                 Description                                                   Default
--hidden-dim              Hidden layer size                                             256
--lr                      Learning rate                                                 0.001
--epochs                  Training epochs                                               100
--batch-size              Batch size                                                    64
--activation              Activation type [relu/jump_relu/topk]                         relu
--model-name              Transformer model                                             gpt2-small
--layer                   Layer to analyze                                              0
--n-samples               Number of samples                                             10000
--use-wandb               Enable W&B logging                                            False
--seq-len                 Token sequence length used when sampling                      20
--normalize-activations   Normalize transformer activations (zero-mean, unit-std)       True
--no-normalize            Disable activation normalization                              False
--dead-resample           Enable dead-neuron resampling during training                 True
--no-dead-resample        Disable dead-neuron resampling                                False
--dead-threshold          Threshold for considering a neuron "dead" (mean firing rate)  0.001

Quick Smoke Test

Run a very short experiment to validate the code path and flags without waiting for a full training run:

python run_experiments.py --n-samples 100 --seq-len 10 --epochs 2 --activation relu

This performs a small harvest (100 samples) and trains for 2 epochs — useful for confirming everything runs end-to-end.

Recommended Hyperparameters

These settings are a starting point for avoiding "dead" features and getting stable SAE training on transformer activations.

  • Data volume:

    • --n-samples: 10000 or higher (default now 10000). Aim for 10k–100k samples depending on compute.
    • --seq-len: 20 (default). Increasing seq-len multiplies token diversity.
  • Optimization:

    • --lr (learning rate): start at 1e-3; reduce to 5e-4 or 1e-4 if the loss collapses (a sharp cliff followed by a flatline).
    • --batch-size: 64 (default). Increase to 128–256 if you have GPU memory.
    • --epochs: 100 (default). For large n-samples, scale epochs down or use more data per epoch.
  • Sparsity and activations:

    • --activation: relu is stable; topk enforces strong sparsity but may need more data and a larger hidden dim.
    • For topk, increase k (in config) to allow more active features when reconstruction is poor.
    • Keep --normalize-activations enabled to stabilize training.
  • Dead-neuron handling:

    • --dead-resample (enabled by default) reinitializes low-firing neurons during training (a simplified sketch appears after this list).
    • --dead-threshold default 0.001 is conservative; adjust upward (e.g., 0.005) if many neurons are flagged as dead too early.
  • Practical tips:

    • If you see the loss drop then flatline: reduce --lr and increase data (--n-samples).
    • Monitor mean_activation_rate and high_freq_neurons in W&B — aim for non-zero mean activation and a small but non-trivial number of high-frequency neurons.
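
To make the dead-neuron mechanism concrete, here is a simplified sketch of detection and resampling; the function name and the random reinitialization are assumptions (published resampling schemes typically reinitialize toward poorly reconstructed inputs), not the repository's actual logic:

import torch

def resample_dead_neurons(sae, codes: torch.Tensor, dead_threshold: float = 1e-3):
    # Fraction of samples on which each hidden unit fired.
    firing_rate = (codes > 0).float().mean(dim=0)
    dead = torch.nonzero(firing_rate < dead_threshold).squeeze(-1)
    with torch.no_grad():
        for idx in dead.tolist():
            # Reinitialize the encoder row and decoder column of the dead unit.
            sae.encoder.weight[idx] = torch.randn_like(sae.encoder.weight[idx]) * 0.01
            sae.encoder.bias[idx] = 0.0
            sae.decoder.weight[:, idx] = torch.randn_like(sae.decoder.weight[:, idx]) * 0.01
    return dead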

Features

Analysis Tools

  • Multiple activation function comparison
  • Neuron frequency analysis
  • Concept emergence tracking
  • Feature clustering
  • Attribution scoring
  • Sparsity measurements

Visualization

  • Real-time ASCII training metrics
  • W&B experiment tracking
  • Feature map visualization
  • Activation heatmaps
  • Concept embedding plots

Checkpointing

Automatic experiment state saving enables the following (a minimal save/restore sketch follows the list):

  • Recovery from interruptions
  • Continuation of training
  • Progress tracking
  • Result caching
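
A minimal sketch of the save/restore pattern; the repository's checkpointing.py likely stores additional experiment state, so the dictionary keys here are assumptions:

import torch

def save_checkpoint(path, sae, optimizer, epoch):
    # Persist everything needed to resume: model weights, optimizer state, progress.
    torch.save({
        "epoch": epoch,
        "model_state": sae.state_dict(),
        "optimizer_state": optimizer.state_dict(),
    }, path)

def load_checkpoint(path, sae, optimizer):
    ckpt = torch.load(path)
    sae.load_state_dict(ckpt["model_state"])
    optimizer.load_state_dict(ckpt["optimizer_state"])
    return ckpt["epoch"]    # resume from the epoch after this one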

Results Visualization

Terminal Output

╔════════════════════ EXPERIMENT RESULTS ════════════════════╗
║ Activation Function Comparison:
║   ReLU     - Loss: 0.9870, Sparsity: 27.87%
║   JumpReLU - Loss: 0.9700, Sparsity: 30.15%
║   TopK     - Loss: 1.0427, Sparsity: 98.05%
╚═══════════════════════════════════════════════════════════╝

W&B Dashboard

Access experiment tracking at: https://wandb.ai/ashioyajotham/sae-interpretability

Features:

  • Loss curves
  • Activation patterns
  • Feature maps
  • Concept embeddings
  • Neuron statistics

Weights & Biases (W&B) Usage

Quick start:

  • Install and login:
pip install wandb
wandb login
  • Run an experiment with W&B enabled:
python run_experiments.py --use-wandb --model-name gpt2-small --n-samples 10000
  • Recommended W&B options:

    • Use --use-wandb to enable logging.
    • The run is created under project sae-interpretability by default; change the project in config/config.py or pass a wandb config programmatically if you use the API.

Notes:

  • W&B stores model metrics, activation snapshots and images (feature maps, UMAP plots). Large runs may produce many artifacts; consider sampling or limiting visualization frequency with config.log_freq.
  • If you want reproducible runs, set the seed and record it in wandb.config (see the sketch below).
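
A minimal sketch of seeding and recording the seed in wandb.config; the project name matches the default mentioned above, while the other config values are placeholders:

import random

import numpy as np
import torch
import wandb

seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

run = wandb.init(
    project="sae-interpretability",
    config={"seed": seed, "hidden_dim": 256, "lr": 1e-3, "activation": "relu"},
)
# wandb.config now records the seed alongside the hyperparameters for this run.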

Error Recovery

Training can be resumed using checkpoints:

# Training will continue from last successful state
python run_experiments.py [previous-args] --resume

Roadmap: From Understanding to Control

Currently, this project focuses on Feature Discovery (finding the dictionary): identifying and understanding the features learned by SAEs. The next phase is Steering: developing methods to manipulate and control these features within the model, for example by targeted editing of feature activations to guide the model's behavior.

  • Phase 1: Feature Discovery (Current)

    • Train SAEs to decompose transformer activations into sparse, interpretable features.
    • Identify and analyze features learned by SAEs.
    • Understand how different activation functions impact feature learning.
  • Phase 2: Feature Identification (In Progress)

    • find_feature.py: identify specific features corresponding to concepts, e.g. "The Golden Gate Bridge" or "Quantum Mechanics".
  • Phase 3: Steering (Future Work)

    • steer_model.py: "clamp" these features during inference to control model behavior, e.g., forcing the model to discuss "Quantum Mechanics" or describe "The Golden Gate Bridge" regardless of the prompt (see the illustrative sketch after this list).
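
Because steering is future work, the following is only an illustrative sketch of one plausible approach: a TransformerLens forward hook that adds a feature's decoder direction to the residual stream during generation. The hook point, feature index, and clamp strength are hypothetical, and sae refers to the illustrative SparseAutoencoder sketched earlier.

from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")        # GPT-2 Small
feature_idx, clamp_value = 123, 10.0                     # hypothetical feature and strength
feature_direction = sae.decoder.weight[:, feature_idx]   # (d_model,) decoder column

def clamp_feature(resid_pre, hook):
    # Add the feature's direction at every position of the residual stream.
    return resid_pre + clamp_value * feature_direction

tokens = model.to_tokens("The weather today is")
with model.hooks(fwd_hooks=[("blocks.0.hook_resid_pre", clamp_feature)]):
    output = model.generate(tokens, max_new_tokens=20)
print(model.to_string(output))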

References

[1] "Towards Monosemanticity: Decomposing Language Models With Dictionary Learning", Anthropic (2023)

[2] "Sparse Autoencoders Find Highly Interpretable Features in Language Models", Lee et al. (2023)

[3] "Scaling and evaluating sparse autoencoders", OpenAI (2022)

[4] "How to Read an Artificial Neural Brain", Joshua Placidi (Medium) — a practical, accessible discussion of techniques for probing neural representations, including an intuitive treatment of superposition and feature entanglement.
