A hands-on deep dive into Vision Transformers for semantic segmentation, featuring SegFormer architecture and medical image fine-tuning.
This repository is a personal project aimed at understanding and applying transformer-based architectures in semantic segmentation tasks. Specifically, it focuses on the SegFormer architecture introduced by NVIDIA in the paper “SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers”.
- **Learning Purpose:** I built this project to gain a deep understanding of how transformer-based models like SegFormer are applied in computer vision, especially in image segmentation.
- **From Paper to Code:** After reading the SegFormer paper, I reimplemented the architecture from scratch in PyTorch to get hands-on experience with its design choices, such as hierarchical Transformer encoders, a lightweight all-MLP decoder, and a positional-encoding-free design.
- **Real-World Application with Medical Data:** I fine-tuned the pretrained SegFormer model (from Hugging Face Transformers) on the UW-Madison GI Tract Image Segmentation dataset, replacing the original decoder head (built for 150 classes) with a custom Conv2D head for 3-class segmentation.
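To make the encoder's key design choice concrete, here is a minimal, simplified sketch of the sequence-reduction ("efficient") self-attention described in the SegFormer paper. The class name, argument names, and shapes are my own illustrative choices, not code from this repository:

```python
import torch
import torch.nn as nn

class EfficientSelfAttention(nn.Module):
    """Self-attention with spatial sequence reduction, as in the SegFormer
    paper: keys and values are downsampled by a ratio R via a strided
    convolution, shrinking the attention matrix from N x N to N x N/R^2."""

    def __init__(self, dim: int, num_heads: int, sr_ratio: int):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.q = nn.Linear(dim, dim)
        self.kv = nn.Linear(dim, dim * 2)
        self.proj = nn.Linear(dim, dim)
        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            # Strided conv reduces the spatial sequence length by R^2.
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
        B, N, C = x.shape  # N = H * W (flattened spatial tokens)
        q = self.q(x).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        if self.sr_ratio > 1:
            # Fold tokens back to a feature map, downsample, flatten again.
            x_ = x.transpose(1, 2).reshape(B, C, H, W)
            x_ = self.sr(x_).reshape(B, C, -1).transpose(1, 2)
            x = self.norm(x_)
        kv = self.kv(x).reshape(B, -1, 2, self.num_heads, self.head_dim)
        kv = kv.permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]  # each: (B, heads, reduced N, head_dim)
        attn = (q @ k.transpose(-2, -1)) * self.head_dim ** -0.5
        out = (attn.softmax(dim=-1) @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(out)
```

The module expects tokens of shape `(B, H*W, dim)` plus the spatial size, since the reduction conv needs to reshape tokens back into a 2D feature map.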
- SegFormer implementation in PyTorch
- Fine-tuning on the UW-Madison GI Tract dataset
- PEFT (Parameter-Efficient Fine-Tuning) support
- Decoder head adapted for medical segmentation (3 classes)
- Built for experimentation and architecture understanding
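The decoder-head adaptation above can be sketched with Hugging Face Transformers. To keep the snippet offline and runnable, a b0-sized model is built randomly from a config; the commented `from_pretrained` call shows the pretrained-weights path. The checkpoint name and shapes are assumptions, not this repository's exact code:

```python
import torch
from transformers import SegformerConfig, SegformerForSemanticSegmentation

# Offline sketch: random b0-sized SegFormer. With pretrained weights one
# would typically call instead:
#   SegformerForSemanticSegmentation.from_pretrained(
#       "nvidia/segformer-b0-finetuned-ade-512-512",
#       num_labels=3, ignore_mismatched_sizes=True)
config = SegformerConfig(num_labels=3)  # large bowel, small bowel, stomach
model = SegformerForSemanticSegmentation(config)

# Equivalently, swap just the final classifier: a 1x1 Conv2D over the
# fused decoder features, mapping to 3 output channels.
model.decode_head.classifier = torch.nn.Conv2d(
    config.decoder_hidden_size, 3, kernel_size=1)

pixels = torch.randn(1, 3, 128, 128)
logits = model(pixel_values=pixels).logits  # logits at 1/4 input resolution
```

Note that SegFormer predicts at 1/4 of the input resolution, so the logits are upsampled to full size before computing the loss or masks.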
- **UW-Madison GI Tract Image Segmentation**: Used for fine-tuning the model in a real-world medical setting. The dataset includes annotated abdominal scans suited to semantic segmentation of organs such as the large bowel, small bowel, and stomach.
- **ImageNet-1K (10-Class Subset)**: A curated subset of the standard ImageNet dataset, used to pretrain the SegFormer encoder on a lightweight classification task. This helped the model learn low- and mid-level visual features efficiently before segmentation-specific training.
- **Pascal VOC 2012**: Used to train the complete SegFormer model (encoder + decoder) for generic object segmentation. The dataset provides diverse class labels and object boundaries, supporting the robustness and generalization of the architecture.