This repository implements a federated learning system for medical image classification using PyTorch and the Flower (flwr) framework. It supports several medical imaging datasets and preserves data privacy through federated learning: each client's images stay on that client, and only model parameters are exchanged with the server.
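Because only parameters leave the clients, the server needs a rule for combining them. FedAvg, the standard rule and the default strategy in Flower, takes an average weighted by each client's number of training examples. This is a minimal NumPy sketch of that aggregation step. It does not use Flower's API, the function name `fedavg` is hypothetical, and whether `fl_simulation.py` uses exactly this rule is an assumption.

```python
import numpy as np
from typing import List, Tuple

# One client update: a list of weight arrays (one per layer) plus the
# number of local training examples used to produce them.
ClientUpdate = Tuple[List[np.ndarray], int]


def fedavg(updates: List[ClientUpdate]) -> List[np.ndarray]:
    """Hypothetical helper: example-weighted average of client weights."""
    total = sum(num_examples for _, num_examples in updates)
    num_layers = len(updates[0][0])
    averaged = []
    for layer in range(num_layers):
        # Scale each client's layer by its share of all training examples.
        averaged.append(sum(weights[layer] * (n / total) for weights, n in updates))
    return averaged


# Client A trained on 100 images, client B on 300, so B gets 75% of the weight.
client_a = ([np.array([1.0, 1.0]), np.array([0.0])], 100)
client_b = ([np.array([5.0, 5.0]), np.array([4.0])], 300)
global_weights = fedavg([client_a, client_b])
print(global_weights)  # layer 0 -> [4.0, 4.0], layer 1 -> [3.0]
```

Weighting by example count keeps a small client from pulling the global model as hard as a large one.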
- Federated learning implementation using Flower framework
- Support for multiple medical imaging datasets:
  - Brain MRI (tumor detection)
  - Chest X-ray (COVID-19 detection)
  - Retinal images (diabetic retinopathy detection)
- FedBN (Federated Batch Normalization) support: batch-normalization layers stay local to each client, which helps when image distributions differ across clients
- Plots of per-client training loss and test accuracy
- CSV export of training metrics and test accuracies
Requirements:
- Python 3.8+
- PyTorch
- torchvision
- Flower (flwr)
- matplotlib
- numpy
- Pillow
- kaggle
Install dependencies with:
pip install torch torchvision flwr matplotlib numpy pillow kaggle
Project structure:
├── fl_simulation.py # Main federated learning implementation
├── datasets/ # Directory for medical datasets
│ ├── brain_mri/ # Brain MRI dataset
│ ├── chest_xray/ # Chest X-ray dataset
│ └── retina/ # Retinal images dataset
├── training_results/ # Generated training visualizations and metrics
└── README.md
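The internal layout of each dataset folder is not documented here. If each one follows the class-per-subfolder convention used by torchvision's `ImageFolder` (for example `datasets/chest_xray/<class_name>/<image>.png`), labels can be read straight from the folder names. This is a stdlib-only sketch under that assumption. The function name `index_dataset` and the accepted file extensions are hypothetical.

```python
from pathlib import Path
from typing import Dict, List, Tuple

# Assumed image formats; extend as needed.
IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg"}


def index_dataset(root: str) -> Tuple[List[Tuple[Path, int]], Dict[str, int]]:
    """Hypothetical helper: list (image_path, label) pairs, one subfolder per class."""
    root_path = Path(root)
    # Sorted folder names give a stable class -> index mapping, as ImageFolder does.
    class_names = sorted(p.name for p in root_path.iterdir() if p.is_dir())
    class_to_idx = {name: i for i, name in enumerate(class_names)}
    samples = []
    for name in class_names:
        # Walk each class folder recursively and keep only image files.
        for path in sorted((root_path / name).rglob("*")):
            if path.is_file() and path.suffix.lower() in IMAGE_SUFFIXES:
                samples.append((path, class_to_idx[name]))
    return samples, class_to_idx
```

Calling `index_dataset("./datasets/chest_xray")` would return every image path with its integer label, plus the mapping from folder name to label.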
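- Set up your Kaggle API credentials:
  - Create a Kaggle account if you don't have one.
  - In your Kaggle account settings, create a new API token; this downloads a kaggle.json file.
  - Save kaggle.json in ~/.kaggle/ and restrict its permissions with chmod 600 ~/.kaggle/kaggle.json.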
- Run the federated learning simulation:
  python fl_simulation.py --data_dir ./datasets/chest_xray --clients 3 --rounds 3 --use_fedbn
Arguments:
- --data_dir: Path to the dataset directory
- --clients: Number of federated clients
- --rounds: Number of training rounds
- --use_fedbn: Enable the FedBN strategy (optional; omit it to train without FedBN)
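With --use_fedbn, aggregation changes in one respect. FedBN (Li et al., ICLR 2021) excludes batch-normalization parameters and running statistics from averaging, so each client keeps its own. This is a NumPy sketch of one such round over name-keyed state dicts. The rule that BN entries are recognized by a dotted name segment starting with `bn`, and the function names, are hypothetical assumptions and do not come from `fl_simulation.py`.

```python
import numpy as np
from typing import Dict, List

# A model state modeled as {parameter_name: array}, like a PyTorch state_dict.
StateDict = Dict[str, np.ndarray]


def is_bn_param(name: str) -> bool:
    # Hypothetical naming rule: any dotted segment starting with "bn" is batch norm.
    return any(part.startswith("bn") for part in name.split("."))


def fedbn_round(client_states: List[StateDict], num_examples: List[int]) -> List[StateDict]:
    """Average shared (non-BN) parameters; leave each client's BN parameters untouched."""
    total = sum(num_examples)
    shared = [k for k in client_states[0] if not is_bn_param(k)]
    # Example-weighted average of the shared parameters (FedAvg on the non-BN part).
    averaged = {
        k: sum(state[k] * (n / total) for state, n in zip(client_states, num_examples))
        for k in shared
    }
    new_states = []
    for state in client_states:
        updated = dict(state)  # keeps this client's own BN entries
        updated.update({k: v.copy() for k, v in averaged.items()})
        new_states.append(updated)
    return new_states


clients = [
    {"conv1.weight": np.array([2.0]), "bn1.weight": np.array([0.5])},
    {"conv1.weight": np.array([4.0]), "bn1.weight": np.array([1.5])},
]
updated = fedbn_round(clients, num_examples=[50, 50])
# conv1.weight becomes 3.0 on both clients; bn1.weight stays 0.5 and 1.5.
```

Keeping BN statistics local lets each client normalize features with its own scanner- or site-specific statistics while it still shares the convolutional filters.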
The CNN model architecture includes:
- Multiple convolutional blocks with increasing channels
- Batch normalization for training stability
- Dropout for regularization
- Adaptive pooling, so the classifier receives a fixed-size feature map regardless of input image size
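A PyTorch sketch of a CNN with these four ingredients follows. The number of blocks, the channel widths (32, 64, 128), the 4x4 pooled size, and the dropout rate are illustrative assumptions, not the values used in `fl_simulation.py`. The class name `MedicalCNN` is hypothetical.

```python
import torch
from torch import nn


def conv_block(in_ch: int, out_ch: int) -> nn.Sequential:
    # Conv -> BatchNorm -> ReLU, then halve the spatial size.
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2),
    )


class MedicalCNN(nn.Module):
    def __init__(self, num_classes: int = 2, in_channels: int = 3):
        super().__init__()
        # Channel count increases block by block: 32 -> 64 -> 128.
        self.features = nn.Sequential(
            conv_block(in_channels, 32),
            conv_block(32, 64),
            conv_block(64, 128),
        )
        # Output is always 128 x 4 x 4, whatever the input resolution.
        self.pool = nn.AdaptiveAvgPool2d((4, 4))
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(p=0.5),
            nn.Linear(128 * 4 * 4, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.pool(self.features(x)))


model = MedicalCNN(num_classes=2).eval()
with torch.no_grad():
    for size in (64, 224):
        print(size, tuple(model(torch.randn(1, 3, size, size)).shape))  # (1, 2) both times
```

Because of the adaptive pooling layer, the same `nn.Linear` input size works for 64x64 and 224x224 images without any change to the model.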
Training results are saved in the training_results directory:
- Individual client training loss curves (client_X_training.png)
- Test accuracies for each client (test_accuracies.png)
- Combined visualization (combined_results.png)
- CSV files with detailed metrics
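The CSV column layout is not documented. One plausible shape is one row per round and client, which can be written with `csv.DictWriter`. This is a sketch under that assumption. The function name `save_round_metrics`, the column names, and the file name `metrics_example.csv` are hypothetical, and the values in the usage block are placeholders, not real results.

```python
import csv
from pathlib import Path
from typing import Dict, List


def save_round_metrics(rows: List[Dict[str, object]], out_path: Path) -> None:
    """Hypothetical helper: write metric rows to CSV, taking columns from the first row."""
    if not rows:
        raise ValueError("no metrics to write")
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # newline="" stops the csv module from writing blank lines on Windows.
    with out_path.open("w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
        writer.writeheader()
        writer.writerows(rows)


if __name__ == "__main__":
    placeholder_rows = [  # placeholder values only
        {"round": 1, "client": 0, "train_loss": 0.5, "test_accuracy": 0.5},
        {"round": 1, "client": 1, "train_loss": 0.5, "test_accuracy": 0.5},
    ]
    save_round_metrics(placeholder_rows, Path("training_results") / "metrics_example.csv")
```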
If you use this code in your research, please cite:
@misc{pohiva2025federated,
author = {Akuila Pohiva},
title = {Federated Learning for Medical Image Classification},
year = {2025},
publisher = {GitHub},
url = {https://github.com/username/fl-medical-federated}
}
This project is licensed under the MIT License - see the LICENSE file for details.
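Acknowledgments:
- Flower framework
- FedBN paper: Li et al., "FedBN: Federated Learning on Non-IID Features via Local Batch Normalization", ICLR 2021
- Kaggle for providing the medical imaging datasets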