LiteMIL: A Computationally Efficient Cross-Attention Multiple Instance Learning for Cancer Subtyping on Whole-Slide Images
A lightweight, transformer-style cross-attention MIL network for whole-slide image classification. It achieves performance competitive with transformer-based methods while using fewer parameters, running faster at inference, and consuming less GPU memory.
LiteMIL uses learnable query-based multi-head cross-attention to aggregate patch-level features from whole slide images (WSIs) into slide-level predictions.
- ✅ Query-based multi-head cross-attention aggregation with a configurable number of queries
- ✅ Multiple backbone support: ResNet50, Phikon v2, and UNI for feature extraction
- ✅ Support for `.h5`, `.pt`, and `.pth` feature formats
- ✅ Chunked data loading for memory-efficient training
- ✅ Nested cross-validation with comprehensive metrics
- ✅ End-to-end inference pipeline: Raw WSI → Features → Prediction → Visualization
- ✅ Attention visualization: interpretable heatmaps and top-patch extraction
- ✅ CLI and Jupyter notebook interfaces
- Python 3.13+
- uv package manager (recommended)
```bash
# Install uv
curl -LsSf https://astral.sh/uv/install.sh | sh             # Linux/macOS
# or
powershell -c "irm https://astral.sh/uv/install.ps1 | iex"  # Windows

# Clone and setup
git clone https://github.com/hkussaibi/LiteMIL.git
cd LiteMIL
uv sync

# Activate environment
source .venv/bin/activate    # Linux/macOS
# or
.venv\Scripts\activate       # Windows
```

Organize your datasets as follows:

```
LiteMIL/
├── datasets/
│   ├── breast/
│   │   ├── wsi/         # .svs files
│   │   ├── features/    # .h5, .pt, or .pth files
│   │   └── labels.csv
│   └── ...
```
Supported feature formats: HDF5 (`.h5`) and PyTorch (`.pt`, `.pth`).

Your `labels.csv` should contain:
- `case_id`: Patient/case identifier (used for patient-level splitting in cross-validation)
- `wsi_id`: Unique slide identifier (must match the feature filename without extension)
- `label`: Class label (string)
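For example, a `labels.csv` for the breast dataset could look like the sketch below (the slide and case IDs are illustrative placeholders, not real TCGA identifiers):

```
case_id,wsi_id,label
TCGA-A1-0001,TCGA-A1-0001-01Z-00-DX1,IDC
TCGA-A1-0001,TCGA-A1-0001-01Z-00-DX2,IDC
TCGA-A2-0002,TCGA-A2-0002-01Z-00-DX1,ILC
```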
```bash
# Single file
python extractFeatures.py --input testSlides/testSlide.svs --backbone resnet50 --patch_size 256 --stride 256

# Batch processing
python extractFeatures.py --input_dir testSlides/ --output_dir features/ --backbone phikon-v2 --patch_size 256 --level 0 --format h5
```

Supported Backbones:
| Backbone | Description | Output Dim | Best For |
|---|---|---|---|
| ResNet50 | ImageNet pretrained, fast | 1024 | Quick prototyping, limited compute |
| Phikon v2 | 450M pathology images | 1024 | Pathology-specific features |
| UNI | 100M+ pathology images | 1024 | Maximum performance |
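The extractor writes one feature file per slide. The exact on-disk layout depends on the extractor and format; as a rough sketch, assuming an `.h5` file that stores a `features` array and a `coords` array (these dataset keys are assumptions, check your extracted files), the features can be inspected like this:

```python
import h5py
import torch

# Inspect an extracted HDF5 feature file (dataset keys are assumed, not guaranteed).
with h5py.File("features/testSlide.h5", "r") as f:
    feats = torch.from_numpy(f["features"][:])   # (num_patches, feature_dim)
    coords = f["coords"][:]                      # (num_patches, 2) patch coordinates
print(feats.shape, coords.shape)

# .pt / .pth feature files can be loaded directly with torch.load
# feats = torch.load("features/testSlide.pt")
```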
Pre-extracted features for the four TCGA datasets can be downloaded from the MAD-MIL Zenodo repository (see Acknowledgments).
Supported Datasets:
- TCGA-BRCA (Breast): 875 WSIs, 2 classes (IDC vs ILC)
- TCGA-Kidney: 906 WSIs, 3 classes (PRCC, CCRCC, CHRCC)
- TCGA-Lung: 958 WSIs, 2 classes (LUAD vs LUSC)
- TUPAC16: 821 WSIs, 2 classes (Low vs High grade)
Ground truth labels available in: datasets/<dataset>/labels.csv
Chunked Mode (Recommended):
- Splits each slide into fixed-size chunks (e.g., 1000 instances)
- Uniform batch sizes → stable training
- Lower memory footprint
- Better for attention-based models
Full Mode:
- Loads entire slide as single bag (variable length)
- Requires more memory
- Better for small slides or models that need global context
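As an illustration of what chunked mode does conceptually (not the repository's actual data loader), a slide's feature matrix can be split into fixed-size bags like this:

```python
import torch

def chunk_features(feats: torch.Tensor, chunk_size: int = 1000) -> list[torch.Tensor]:
    """Split an (N, D) slide feature matrix into fixed-size chunks.

    The last chunk may be shorter than chunk_size; the repository's loader may
    handle the remainder differently. Illustrative sketch only.
    """
    return list(torch.split(feats, chunk_size, dim=0))

feats = torch.randn(4321, 1024)          # e.g. 4321 patches, 1024-dim features
chunks = chunk_features(feats, 1000)     # 4 chunks of 1000 patches + 1 chunk of 321
print([c.shape[0] for c in chunks])
```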
The model itself is constructed directly in Python:

```python
from MILS.LiteMIL import LiteMIL

model = LiteMIL(
    input_dim=1024,   # Input feature dimension
    hidden_dim=256,   # Hidden layer dimension
    num_classes=2,    # Number of output classes
    num_heads=4,      # Number of attention heads
    num_queries=1,    # Number of learnable queries (1 or 4)
    dropout=0.25      # Dropout rate
)
```

Recommended Configurations:
| Task Complexity | num_queries | hidden_dim | num_heads |
|---|---|---|---|
| Simple (2-class) | 1 | 256 | 4 |
| Medium (3-class) | 1-4 | 256-512 | 4-8 |
| Complex (4+ class) | 4 | 512 | 8 |
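To give an intuition for the aggregation step, the sketch below implements query-based multi-head cross-attention pooling with PyTorch's `nn.MultiheadAttention`. It mirrors the constructor arguments above but is an independent illustration, not the repository's `LiteMIL` implementation:

```python
import torch
import torch.nn as nn

class CrossAttentionPooling(nn.Module):
    """Illustrative query-based cross-attention pooling (not the official LiteMIL code)."""

    def __init__(self, input_dim=1024, hidden_dim=256, num_classes=2,
                 num_heads=4, num_queries=1, dropout=0.25):
        super().__init__()
        self.proj = nn.Linear(input_dim, hidden_dim)                        # project patch features
        self.queries = nn.Parameter(torch.randn(num_queries, hidden_dim))   # learnable queries
        self.attn = nn.MultiheadAttention(hidden_dim, num_heads,
                                          dropout=dropout, batch_first=True)
        self.classifier = nn.Linear(hidden_dim * num_queries, num_classes)

    def forward(self, bags):                                  # bags: (B, N, input_dim)
        x = self.proj(bags)                                   # (B, N, hidden_dim)
        q = self.queries.unsqueeze(0).expand(x.size(0), -1, -1)   # (B, Q, hidden_dim)
        pooled, attn_weights = self.attn(q, x, x)             # queries attend over patches
        logits = self.classifier(pooled.flatten(1))           # (B, num_classes)
        return logits, attn_weights

logits, attn = CrossAttentionPooling()(torch.randn(2, 1000, 1024))
print(logits.shape, attn.shape)   # torch.Size([2, 2]) torch.Size([2, 1, 1000])
```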
```bash
# Train on breast cancer dataset with default settings
python train.py --mil LiteMIL --dataset breast --mode chunked --epochs 1

# Train with custom architecture
python train.py --mil LiteMIL --dataset breast --num_queries 4 --hidden_dim 256 --dropout 0.25 --n_outer 5 --n_inner 4 --epochs 1

# Train with full-slide mode (if memory allows)
python train.py --mil LiteMIL --dataset breast --mode full --batch_size 16 --epochs 1
```

Key Training Arguments:
- `--mil`: `LiteMIL`, `ABMIL`, `ABMIL_Multihead`, `TransMIL`, `meanPool`, or `maxPool`
- `--dataset`: `breast`, `lung`, `kidney`, or `tupac`
- `--mode`: `chunked` (fixed-size chunks) or `full` (variable-length bags)
- `--num_queries`: Number of learnable queries (1 recommended, 4 for complex tasks)
- `--hidden_dim`: Hidden dimension (default: 256)
- `--num_heads`: Number of attention heads (default: 4)
- `--epochs`: Maximum training epochs (default: 100)
- `--patience`: Early stopping patience (default: 10)
Results are saved in `outputs/<mil>/<dataset>/`:

```
outputs/LiteMIL/breast/
├── all_folds_best.pth   # Best model across all folds
├── fold_1_best.pth      # Best model for fold 1
├── fold_2_best.pth
├── ...
└── outputs.json         # Aggregated metrics
```
outputs.json contains:
- Chunk-level Accuracy: Classification accuracy on chunks
- Slide-level Accuracy: Classification accuracy after aggregation
- F1 Score: Macro-averaged F1
- AUC-ROC: Area under ROC curve
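The distinction between chunk-level and slide-level accuracy comes from aggregation: in chunked mode each chunk gets its own prediction, and these are combined into a single slide-level prediction. A minimal sketch of one plausible aggregation (averaging chunk probabilities; the repository may aggregate differently):

```python
import torch

def aggregate_chunk_predictions(chunk_logits: torch.Tensor) -> int:
    """Combine per-chunk logits (num_chunks, num_classes) into one slide-level label.

    Averages softmax probabilities across chunks; majority voting is another common
    choice. This is an illustration, not necessarily train.py's exact method.
    """
    probs = torch.softmax(chunk_logits, dim=-1).mean(dim=0)
    return int(probs.argmax().item())

chunk_logits = torch.tensor([[2.0, -1.0], [0.5, 0.2], [-0.3, 1.1]])
print(aggregate_chunk_predictions(chunk_logits))   # slide-level class index
```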
```bash
# Basic inference
python predict.py --mil LiteMIL --dataset breast --checkpoint outputs/LiteMIL/breast/all_folds_best.pth --input testSlides/testSlide.h5

# Inference with chunked mode (for large slides)
python predict.py --mil LiteMIL --dataset breast --checkpoint outputs/LiteMIL/breast/all_folds_best.pth --input testSlides/testSlide.h5 --mode chunked --chunk_size 1000

# Direct inference from an SVS file with on-the-fly feature extraction
python predict.py --mil LiteMIL --dataset breast --checkpoint outputs/LiteMIL/breast/all_folds_best.pth --input testSlides/testSlide.svs --backbone resnet50 --patch_size 256
```

Generate interpretable attention heatmaps:
```bash
# Basic visualization
python predict.py --visualize --mil LiteMIL --dataset breast --checkpoint outputs/LiteMIL/breast/all_folds_best.pth --input testSlides/testSlide.h5 --wsi testSlides/testSlide.svs --output outputs/heatmaps/

# With custom parameters
python predict.py --visualize --mil LiteMIL --dataset breast --checkpoint outputs/LiteMIL/breast/all_folds_best.pth --input testSlides/testSlide.h5 --wsi testSlides/testSlide.svs --mode chunked --top_k 20 --cmap hot --output outputs/heatmaps/
```

Visualization Output:
- `slide_detailed.png` - Comprehensive view with heatmap, histogram, and stats
- `slide_heatmap.png` - Clean attention overlay
- `slide_patch_grid.png` - Grid of top-k patches
- `slide_top_patches/` - Individual high-attention patches
- `slide_attention.npz` - Raw attention data
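Under the hood, a heatmap of this kind is built by mapping each patch's attention weight back to its slide coordinates. The sketch below shows the general idea on a coarse patch grid; it does not reproduce `predict.py`'s actual rendering, and the coordinate convention (top-left x/y at level 0) is an assumption:

```python
import numpy as np
import matplotlib.pyplot as plt

def attention_heatmap(coords, attn, patch_size=256):
    """Paint per-patch attention scores onto a coarse patch-grid image.

    coords: (N, 2) top-left patch coordinates (x, y); attn: (N,) attention weights.
    Illustrative only; predict.py's visualization may differ.
    """
    grid = coords // patch_size                        # patch indices on the slide grid
    h, w = grid[:, 1].max() + 1, grid[:, 0].max() + 1
    heat = np.zeros((h, w))
    heat[grid[:, 1], grid[:, 0]] = attn                # one grid cell per patch
    return heat

coords = np.array([[0, 0], [256, 0], [0, 256], [256, 256]])
attn = np.array([0.10, 0.45, 0.20, 0.25])
plt.imshow(attention_heatmap(coords, attn), cmap="viridis")
plt.colorbar(label="attention")
plt.savefig("heatmap_demo.png")
```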
See Tutorial.ipynb for interactive examples including:
- Custom dataset loading
- Feature extraction from raw WSI
- Model training and evaluation
- Attention visualization
- Performance analysis
If feature extraction is too slow:

```python
# 1. Lower magnification (faster, less detail)
extractor = WSIFeatureExtractor(level=1)   # 20x instead of 40x

# 2. Increase batch size (if GPU memory allows)
extractor = WSIFeatureExtractor(batch_size=128)   # Default: 64

# 3. Increase stride (less overlap, faster)
extractor = WSIFeatureExtractor(stride=512, patch_size=256)
```

If you run out of GPU memory:

```python
# Solution 1: Reduce batch size
cv = NestedCrossValidation(..., batch_size=8)

# Solution 2: Use chunked mode with smaller chunks
result = predictor.predict('slide.h5', mode='chunked', chunk_size=500)

# Solution 3: Lower magnification for feature extraction
extractor = WSIFeatureExtractor(level=1)
```

- ✅ Always validate on a small subset first
- ✅ Save intermediate features to avoid re-extraction
- ✅ Use an appropriate backbone for your domain (Phikon/UNI for pathology)
- ✅ Monitor GPU memory when processing large slides
- ✅ Verify tissue detection on representative samples
- ✅ Use `chunked` mode for consistent batch sizes
- ✅ Start with `num_queries=1` for simplicity
- ✅ Enable mixed precision (`use_amp=True`) to save memory
- ✅ Monitor both chunk-level and slide-level accuracy
- ✅ Use patient-level splitting (via `case_id`) to avoid data leakage (see the sketch below)
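Patient-level splitting means all slides sharing a `case_id` land in the same fold. A minimal sketch with scikit-learn's `GroupKFold` (an external illustration; `train.py`'s nested cross-validation uses `case_id` for this internally, per the labels description above):

```python
import pandas as pd
from sklearn.model_selection import GroupKFold

# One row per slide; grouping by case_id keeps every patient's slides in a single fold.
df = pd.read_csv("datasets/breast/labels.csv")
gkf = GroupKFold(n_splits=5)
for fold, (train_idx, val_idx) in enumerate(
    gkf.split(df["wsi_id"], df["label"], groups=df["case_id"])
):
    train_cases = set(df.iloc[train_idx]["case_id"])
    val_cases = set(df.iloc[val_idx]["case_id"])
    assert train_cases.isdisjoint(val_cases)   # no patient appears in both splits
    print(f"fold {fold}: {len(train_idx)} train slides, {len(val_idx)} val slides")
```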
- ✅ Use `full` mode for slides with <10K patches
- ✅ Use `chunked` mode for slides with >10K patches
- ✅ Validate predictions with attention heatmaps
- ✅ Check attention statistics for model confidence
- ✅ Use colorblind-friendly colormaps (`viridis`, `plasma`) for publications
- ✅ Include the scale bar or patch size in captions
- ✅ Show multiple examples (correct, incorrect, uncertain)
- ✅ Validate attention patterns with domain experts
- ✅ Save attention data (`.npz`) for reproducibility
If you use the LiteMIL pipeline in your research, please cite:
```bibtex
@article{Kussaibi2026LiteMIL,
  title   = {LiteMIL: a computationally efficient cross-attention multiple instance learning for cancer subtyping on whole-slide images},
  author  = {Kussaibi, Haitham},
  journal = {Journal of Medical Imaging},
  year    = {2026},
  volume  = {13},
  number  = {1},
  pages   = {017501},
  doi     = {10.1117/1.JMI.13.1.017501}
}
```

Copyright © 2025 Haitham Kussaibi
This project is licensed under the Apache 2.0 License - see the LICENSE file for details.
- TCGA Research Network for providing datasets
- MAD-MIL study for pre-extracted features (Zenodo)
Dr. Haitham Kussaibi
- 📧 Email: kussaibi@gmail.com
- 🆔 ORCID: 0000-0002-9570-0768
- 💼 LinkedIn: linkedin.com/in/haithamkussaibi
If you find LiteMIL useful for your research, please consider giving it a star! ⭐
Made with ❤️ for the computational pathology community