A PyTorch implementation of a GPT-2 style non-causal transformer for solving Sudoku puzzles. The model reaches 99.95% accuracy and completely solves 98.92% of puzzles on the validation set.
Note: While backtracking algorithms can solve Sudoku puzzles extremely fast (microseconds), this project serves as an important case study in ML research for understanding the strengths and limitations of transformer architectures. It provides insights into how transformers learn constraint satisfaction problems, their sample efficiency across different model scales, and emergent capabilities in structured reasoning tasks.
Models of 3 different sizes are trained:
- Small (25M): 6 layers, 384 embedding, 6 heads
- Medium (40M): 8 layers, 512 embedding, 8 heads
- Large (80M): 12 layers, 768 embedding, 12 heads
Each model was trained for 100,000 iterations with a batch size of 256 on an RTX 3090 GPU, with the longest run taking about 12 hours. The largest model reaches the peak accuracy of 99.95%, solving 98.92% of puzzles completely. The training graphs show scaling laws in action.
Training Cost: All three experiments were run on a rented RTX 3090 machine via Vast.ai at $0.210/hour for approximately 20 hours total, costing around $4.20 for the complete set of experiments.
We can see some overfitting on the largest model despite our relabeling augmentation.
The medium model surpasses the small model after about 1 hour of training, while the large model surpasses the medium model after about 3 hours on the same RTX 3090.
Bigger models learn faster on a per-sample basis, demonstrating higher sample efficiency.
The dataset is the "1 million Sudoku games" dataset from Kaggle: https://www.kaggle.com/datasets/bryanpark. The data is split into training and validation sets, with 10,000 examples held out for validation. The data is processed from CSV into `.npz` format with `np.uint8` data type (a sketch of this step follows the list below).
- Training set: 990,000 puzzles
- Validation set: 10,000 puzzles
- Format: 81 np.uint8 digits (0-9) where 0 represents empty cells
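For illustration, here is a minimal sketch of that preprocessing step, assuming the Kaggle CSV has `quizzes` and `solutions` columns of 81-character digit strings and that the archive keys are named as below (the real script is `src/process.py`):

```python
# Hypothetical sketch of the CSV -> npz conversion; see src/process.py for the actual script.
import numpy as np
import pandas as pd

df = pd.read_csv("data/sudoku.csv")

def to_array(column: pd.Series) -> np.ndarray:
    # Each row is an 81-character digit string; 0 marks an empty cell.
    return np.array([list(map(int, s)) for s in column], dtype=np.uint8)

puzzles = to_array(df["quizzes"])      # assumed column name, shape (1_000_000, 81)
solutions = to_array(df["solutions"])  # assumed column name, shape (1_000_000, 81)

# Hold out the last 10,000 examples for validation (key names are assumptions).
np.savez_compressed(
    "data/sudoku.npz",
    train_x=puzzles[:-10_000], train_y=solutions[:-10_000],
    valid_x=puzzles[-10_000:], valid_y=solutions[-10_000:],
)
```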
- Type: GPT-2 style transformer without causal masking (bidirectional attention)
- Input: Token embeddings (vocab size 10 for digits 0-9) + positional embeddings (81 positions)
- Output: Per-position classification into 9 classes (digits 1-9)
- Loss: CrossEntropyLoss computed only on non-hint positions (configurable via `mask_hints`)
- Relabeling: Random permutation of digits 1-9 during training
- Example: All 1s → 5s, all 2s → 7s, etc.
- Preserves Sudoku structure while increasing data diversity
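As a concrete example, the relabeling augmentation amounts to something like the following sketch (the actual implementation lives in `src/data.py`; the function name here is illustrative):

```python
import numpy as np

def relabel(puzzle: np.ndarray, solution: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Apply one random permutation of the digits 1-9 to a puzzle/solution pair.

    Index 0 maps to itself so that empty cells (0) stay empty; the same
    permutation is applied to puzzle and solution, preserving Sudoku validity.
    """
    mapping = np.concatenate(([0], np.random.permutation(np.arange(1, 10)))).astype(np.uint8)
    return mapping[puzzle], mapping[solution]
```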
```
sudoku/
├── config/
│   ├── gpt2_sm.json        # Small model configuration (25M params)
│   ├── gpt2_md.json        # Medium model configuration (40M params)
│   ├── gpt2_lg.json        # Large model configuration (80M params)
│   └── gpt2_xl.json        # Extra-large model configuration
├── data/
│   ├── sudoku.csv          # Raw dataset (1M puzzles)
│   └── sudoku.npz          # Processed data (train/valid splits)
├── img/                    # Training result visualizations
├── models/
│   └── {experiment_name}/  # Experiment outputs
│       ├── config.json     # Copied config
│       ├── metrics.csv     # Training metrics
│       └── best_model.pt   # Best checkpoint by validation loss
├── notebooks/
│   ├── analyze.ipynb       # Results analysis and visualization
│   └── explore.ipynb       # Dataset exploration
├── src/
│   ├── data.py             # Dataset and dataloader implementation
│   ├── gpt.py              # GPT model architecture
│   ├── main.py             # Training script
│   ├── process.py          # Data preprocessing
│   ├── training.py         # Training loop and metrics
│   └── utils.py            # Config classes and utilities
├── LICENSE                 # MIT License
├── pyproject.toml          # Python project configuration
└── README.md
```
- Install dependencies using uv:

  ```bash
  uv sync
  ```

- Download the dataset from Kaggle and place `sudoku.csv` in the `data/` directory.

- Process the dataset:

  ```bash
  uv run python src/process.py
  ```

  This will create `data/sudoku.npz` with train/validation splits.

Start training with the default configuration:

```bash
uv run python src/main.py
```

Or specify a custom config:

```bash
uv run python src/main.py --config config/my_config.json
```

Example configuration (`config/gpt2_sm.json`):

```json
{
"experiment_name": "gpt2_sm_00",
"max_iters": 100000,
"batch_size": 256,
"num_workers": 4,
"learning_rate": 1e-4,
"eval_interval": 500,
"n_embd": 384,
"n_layer": 6,
"n_head": 6,
"dropout": 0,
"mask_hints": true,
"use_compile": true,
"compile_mode": "reduce-overhead"
}
```

Configuration Parameters:
- `experiment_name`: Name of the experiment (used for output directory)
- `max_iters`: Maximum training iterations
- `batch_size`: Batch size for training and validation
- `num_workers`: Number of data loader workers
- `learning_rate`: Learning rate for the AdamW optimizer
- `eval_interval`: Evaluate on the validation set every N iterations
- `mask_hints`: If `true`, only compute loss on non-hint positions (where input == 0). If `false`, compute loss on all 81 positions (default: `false`)
- `n_embd`: Embedding dimension
- `n_layer`: Number of transformer layers
- `n_head`: Number of attention heads
- `dropout`: Dropout probability
- `use_compile`: Enable PyTorch 2.0+ compilation for faster training (default: `true`)
- `compile_mode`: Compilation mode, one of `"default"`, `"reduce-overhead"`, or `"max-autotune"` (default: `"reduce-overhead"`)
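Since pydantic is listed as a dependency and `src/utils.py` provides the config classes, loading a config presumably resembles the sketch below (the class name and defaults are assumptions that mirror the example JSON):

```python
# Illustrative pydantic-based config loader; the real classes live in src/utils.py.
import json
from pydantic import BaseModel

class TrainConfig(BaseModel):
    experiment_name: str
    max_iters: int = 100_000
    batch_size: int = 256
    num_workers: int = 4
    learning_rate: float = 1e-4
    eval_interval: int = 500
    n_embd: int = 384
    n_layer: int = 6
    n_head: int = 6
    dropout: float = 0.0
    mask_hints: bool = False
    use_compile: bool = True
    compile_mode: str = "reduce-overhead"

with open("config/gpt2_sm.json") as f:
    cfg = TrainConfig(**json.load(f))
```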
The training loop tracks:
- `loss`: CrossEntropyLoss computed on all 81 positions (when `mask_hints=false`) or only on non-hint positions (when `mask_hints=true`)
- `acc`: Accuracy on non-hint positions only (where input == 0)
- `acc_full`: Per-puzzle accuracy (1.0 if all 81 positions are correct, 0.0 otherwise)
Both training and validation metrics are saved to `models/{experiment_name}/metrics.csv`.
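To make these definitions concrete, here is a hedged sketch of how the masked loss and the two accuracies can be computed (tensor and function names are illustrative, not the exact code in `src/training.py`):

```python
import torch
import torch.nn.functional as F

def compute_metrics(logits, puzzles, solutions, mask_hints=True):
    """logits: (B, 81, 9); puzzles: (B, 81) with 0 = empty; solutions: (B, 81) digits 1-9."""
    targets = solutions.long() - 1        # digits 1-9 -> class labels 0-8
    non_hint = puzzles == 0               # cells the model has to fill in

    if mask_hints:
        loss = F.cross_entropy(logits[non_hint], targets[non_hint])
    else:
        loss = F.cross_entropy(logits.reshape(-1, 9), targets.reshape(-1))

    preds = logits.argmax(dim=-1)
    correct = preds == targets
    acc = correct[non_hint].float().mean()        # accuracy on non-hint cells only
    acc_full = correct.all(dim=1).float().mean()  # fraction of fully solved puzzles
    return loss, acc, acc_full
```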
- Best model (by validation loss) is saved to `models/{experiment_name}/best_model.pt`
- Checkpoint includes:
  - Model state dict
  - Optimizer state dict
  - Iteration number
  - Validation loss
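A minimal sketch of saving such a checkpoint (the dictionary key names are assumptions):

```python
import torch

def save_checkpoint(path, model, optimizer, iteration, valid_loss):
    # Bundle everything needed to resume training or run inference later.
    torch.save(
        {
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "iteration": iteration,
            "valid_loss": valid_loss,
        },
        path,
    )
```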
- Input: `(batch_size, 81)` integers 0-9
- Output: `(batch_size, 81, 9)` logits for classes 0-8 (representing digits 1-9)
- Targets: Solution digits 1-9 converted to class labels 0-8 (see the model sketch below)
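The sketch below shows one way such a bidirectional model could be wired up, assuming standard pre-norm transformer blocks (the actual architecture is defined in `src/gpt.py` and may differ in details such as MLP width, initialization, and normalization placement):

```python
import torch
import torch.nn as nn

class SudokuTransformer(nn.Module):
    """GPT-2 style encoder without causal masking: every cell attends to every other cell."""

    def __init__(self, n_embd=384, n_layer=6, n_head=6, dropout=0.0):
        super().__init__()
        self.tok_emb = nn.Embedding(10, n_embd)   # digits 0-9 (0 = empty cell)
        self.pos_emb = nn.Embedding(81, n_embd)   # one learned embedding per grid position
        block = nn.TransformerEncoderLayer(
            d_model=n_embd, nhead=n_head, dim_feedforward=4 * n_embd,
            dropout=dropout, activation="gelu", batch_first=True, norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(block, num_layers=n_layer)
        self.head = nn.Linear(n_embd, 9)          # per-cell logits over digits 1-9

    def forward(self, puzzles: torch.Tensor) -> torch.Tensor:
        # puzzles: (batch_size, 81) integers 0-9 -> logits: (batch_size, 81, 9)
        pos = torch.arange(81, device=puzzles.device)
        x = self.tok_emb(puzzles) + self.pos_emb(pos)
        return self.head(self.encoder(x))         # no attention mask, so attention is bidirectional
```

Here `nn.TransformerEncoderLayer` stands in for hand-rolled GPT-2 blocks; the essential point is that no attention mask is passed, so every cell attends to the full board.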
- Bidirectional attention (no causal masking)
- Relabeling augmentation for training data
- Optional hint masking: compute loss only on non-hint positions via the `mask_hints` config
- Exponential moving average (EMA) for training metrics (see the one-line sketch after this list)
- Arithmetic mean for validation metrics
- Automatic experiment directory management
- Best model checkpointing
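For reference, the EMA smoothing of training metrics boils down to a one-liner like the following (the smoothing factor is an assumption):

```python
def update_ema(prev: float, value: float, beta: float = 0.99) -> float:
    # Exponentially weighted running average used to smooth noisy per-batch training metrics.
    return beta * prev + (1 - beta) * value
```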
- Python >= 3.13
- PyTorch >= 2.5.0
- NumPy >= 2.0.0
- Pandas >= 2.2.0
- tqdm >= 4.66.0
- torchmetrics >= 1.5.0
- pydantic >= 2.0.0
- ipykernel >= 7.1.0 (for notebooks)
- matplotlib >= 3.10.7 (for visualization)
All dependencies are managed via uv and specified in pyproject.toml.
The project includes Jupyter notebooks for analysis and exploration:
- notebooks/analyze.ipynb - Training results analysis and visualization
- notebooks/explore.ipynb - Dataset exploration and statistics
This project is licensed under the MIT License - see the LICENSE file for details.