
This is an implementation of privacy-preserving federated learning for medical image classification. The project demonstrates how multiple medical institutions (acting as nodes, or clients) can collaborate to train a shared ML model without exchanging sensitive patient data. Built with PyTorch and the Flower framework, it supports three medical imaging datasets.


skypo9/Medical-Imaging-Federated-Learning


Federated Learning for Medical Image Classification

This repository implements a federated learning system for medical image classification using PyTorch and the Flower (flwr) framework. The system handles multiple medical imaging datasets and preserves data privacy by training on each client's data locally rather than pooling it.

Features

  • Federated learning implementation using the Flower framework
  • Support for multiple medical imaging datasets:
    • Brain MRI (tumor detection)
    • Chest X-ray (COVID-19 detection)
    • Retinal images (diabetic retinopathy detection)
  • FedBN (Federated Batch Normalization) support for handling distribution shifts
  • Comprehensive visualization of training progress and results
  • CSV export of training metrics and test accuracies
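FedBN keeps each client's batch normalization parameters local and averages only the remaining parameters across clients. The following is a minimal, dependency-free sketch of that idea, not the code in fl_simulation.py. It assumes parameters are stored as name-to-list mappings and that batch-norm layers carry "bn" in their names; the helper names `is_bn_param`, `fedbn_aggregate`, and `apply_global` are hypothetical.

```python
from typing import Dict, List

Params = Dict[str, List[float]]


def is_bn_param(name: str) -> bool:
    # Assumption: batch-norm parameters are identifiable by "bn" in their name.
    # A substring check can misfire on unrelated names, so adapt it to your model.
    return "bn" in name


def fedbn_aggregate(client_params: List[Params], num_examples: List[int]) -> Params:
    """Average non-BN parameters, weighting each client by its example count."""
    total = sum(num_examples)
    shared_names = [n for n in client_params[0] if not is_bn_param(n)]
    aggregated: Params = {}
    for name in shared_names:
        length = len(client_params[0][name])
        aggregated[name] = [
            sum(p[name][i] * n for p, n in zip(client_params, num_examples)) / total
            for i in range(length)
        ]
    return aggregated


def apply_global(local: Params, aggregated: Params) -> Params:
    """Overwrite shared parameters; the client's own BN parameters are untouched."""
    updated = dict(local)
    updated.update(aggregated)
    return updated


# Toy values for illustration only.
clients = [
    {"conv1.weight": [1.0, 2.0], "bn1.weight": [0.5]},
    {"conv1.weight": [3.0, 4.0], "bn1.weight": [1.5]},
]
global_params = fedbn_aggregate(clients, num_examples=[100, 300])
print(global_params)                            # {'conv1.weight': [2.5, 3.5]}
print(apply_global(clients[0], global_params))  # bn1.weight stays [0.5]
```

In a PyTorch client the same filter would typically be applied to the entries of `model.state_dict()` before sending parameters to the server.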

Requirements

  • Python 3.8+
  • PyTorch
  • torchvision
  • Flower (flwr)
  • matplotlib
  • numpy
  • Pillow
  • kaggle

Install dependencies using:

pip install torch torchvision flwr matplotlib numpy pillow kaggle

Project Structure

.
├── fl_simulation.py      # Main federated learning implementation
├── datasets/             # Directory for medical datasets
│   ├── brain_mri/        # Brain MRI dataset
│   ├── chest_xray/       # Chest X-ray dataset
│   └── retina/           # Retinal images dataset
├── training_results/     # Generated training visualizations and metrics
└── README.md

Usage

  1. Set up your Kaggle API credentials:

    • Create a Kaggle account if you don't have one
    • Go to Account settings and create a new API token
    • Save the kaggle.json file in ~/.kaggle/
  2. Run the federated learning simulation:

python fl_simulation.py --data_dir ./datasets/chest_xray --clients 3 --rounds 3 --use_fedbn

Arguments:

  • --data_dir: Path to dataset directory
  • --clients: Number of federated clients
  • --rounds: Number of training rounds
  • --use_fedbn: Enable FedBN strategy (optional)
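For readers who want to see how the flags above could be wired up, here is a minimal argparse sketch. The flag names come from the list above; the defaults, the help strings, and the `build_parser` helper are assumptions and may differ from what fl_simulation.py actually does.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        description="Federated learning simulation (illustrative sketch)"
    )
    parser.add_argument("--data_dir", type=str, required=True,
                        help="Path to dataset directory")
    # Defaults below are assumptions, chosen to match the example command.
    parser.add_argument("--clients", type=int, default=3,
                        help="Number of federated clients")
    parser.add_argument("--rounds", type=int, default=3,
                        help="Number of training rounds")
    # store_true makes the flag optional: absent means FedBN is disabled.
    parser.add_argument("--use_fedbn", action="store_true",
                        help="Enable FedBN strategy")
    return parser


# Parse the example command line from the Usage section.
args = build_parser().parse_args(
    ["--data_dir", "./datasets/chest_xray",
     "--clients", "3", "--rounds", "3", "--use_fedbn"]
)
print(args.data_dir, args.clients, args.rounds, args.use_fedbn)
```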

Model Architecture

The CNN model architecture includes:

  • Multiple convolutional blocks with increasing channels
  • Batch normalization for training stability
  • Dropout for regularization
  • Adaptive pooling for handling variable input sizes
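The components listed above can be combined roughly as follows. This is a sketch under stated assumptions rather than the model defined in fl_simulation.py: the class name `SketchCNN`, the channel widths (16, 32, 64), the kernel size, the dropout rate, and the two-class output are all illustrative choices.

```python
import torch
from torch import nn


def conv_block(in_ch: int, out_ch: int) -> nn.Sequential:
    # Convolution, batch normalization, non-linearity, then 2x downsampling.
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2),
    )


class SketchCNN(nn.Module):
    def __init__(self, in_channels: int = 3, num_classes: int = 2,
                 dropout: float = 0.5) -> None:
        super().__init__()
        # Channel count grows from block to block (assumed widths).
        self.features = nn.Sequential(
            conv_block(in_channels, 16),
            conv_block(16, 32),
            conv_block(32, 64),
        )
        # Adaptive pooling collapses any spatial size to 1x1, so the
        # classifier input size does not depend on the image resolution.
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(dropout),
            nn.Linear(64, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.pool(self.features(x)))


model = SketchCNN().eval()
with torch.no_grad():
    for size in (64, 96):
        logits = model(torch.randn(2, 3, size, size))
        print(size, tuple(logits.shape))
```

Both input sizes produce logits of the same shape, which is what the adaptive pooling layer is there for.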

Results

Training results are saved in the training_results directory:

  • Individual client training loss curves (client_X_training.png)
  • Test accuracies for each client (test_accuracies.png)
  • Combined visualization (combined_results.png)
  • CSV files with detailed metrics
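As an illustration of the CSV export, the sketch below writes per-round, per-client metrics with the standard library. The column names, the file name metrics.csv, and the `write_metrics_csv` helper are assumptions, and the rows are placeholder values, not results from this project.

```python
import csv
import tempfile
from pathlib import Path

# Assumed column layout; the real files may use different names.
FIELDNAMES = ["round", "client_id", "train_loss", "test_accuracy"]


def write_metrics_csv(rows, path) -> Path:
    """Write a list of metric dicts to a CSV file, creating parent folders."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=FIELDNAMES)
        writer.writeheader()
        writer.writerows(rows)
    return path


# Placeholder rows for illustration only (not measured values).
placeholder_rows = [
    {"round": 1, "client_id": 0, "train_loss": 0.0, "test_accuracy": 0.0},
    {"round": 1, "client_id": 1, "train_loss": 0.0, "test_accuracy": 0.0},
]

# A temporary directory stands in for the project folder in this demo.
with tempfile.TemporaryDirectory() as tmp:
    out = write_metrics_csv(
        placeholder_rows, Path(tmp) / "training_results" / "metrics.csv"
    )
    csv_text = out.read_text()

print(csv_text)
```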

Citation

If you use this code in your research, please cite:

@misc{pohiva2025federated,
  author = {Akuila Pohiva},
  title = {Federated Learning for Medical Image Classification},
  year = {2025},
  publisher = {GitHub},
  url = {https://github.com/skypo9/Medical-Imaging-Federated-Learning}
}

License

This project is licensed under the MIT License - see the LICENSE file for details.

Acknowledgments
