Skip to content

Sleep Classifier implemented in PyTorch featuring a hybrid transformer and lstm model.

License

Notifications You must be signed in to change notification settings

danielgavrila2/Sleep-Hybrid-Classifier

Folders and files

NameName
Last commit message
Last commit date

Latest commit

Β 

History

11 Commits
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 

Repository files navigation

Sleep Stage Classification System

A real-time sleep stage prediction system using deep learning to classify sleep stages (Wake, Light, Deep, REM) from wearable device data (heart rate, motion, steps).

πŸ“‹ Table of Contents

🎯 Overview

This project implements two neural network architectures for sleep stage classification:

  1. LSTM Model: Bidirectional LSTM with attention mechanism
  2. Hybrid CNN-Transformer: CNN front-end for local patterns + Transformer encoder for long-range dependencies

Both models process time-series data from consumer-grade wearables (e.g., Apple Watch) to predict sleep stages in real-time.

✨ Features

  • Multiple Model Architectures: LSTM and CNN-Transformer hybrid models
  • Real-time Inference: Predict sleep stages from live sensor data
  • Robust Data Pipeline: Handles missing data, resampling, and feature engineering
  • REM Sleep Focus: Special attention to improving REM detection accuracy
  • Docker Support: Fully containerized for easy deployment
  • Class Imbalance Handling: Focal loss, class weighting, and data augmentation
  • Comprehensive Metrics: Per-class accuracy, confusion matrices, and detailed reports

πŸ—οΈ Architecture

LSTM Model

  • Bidirectional LSTM layers (2-3 layers, 256-320 hidden units)
  • Attention mechanism for temporal pattern focus
  • Layer normalization for training stability
  • Residual connections to prevent gradient vanishing

Hybrid CNN-Transformer Model

  • CNN Front-end: Captures local temporal patterns (2 layers, 64 channels)
  • Transformer Encoder: Models long-range dependencies (3 layers, 4 attention heads)
  • Positional Encoding: Preserves temporal order information
  • Global Pooling: Aggregates sequence information for classification

πŸ“Š Dataset

Source: Sleep-Accel Dataset

Input Features:

  • Heart rate (BPM)
  • Motion magnitude (accelerometer x, y, z)
  • Steps count (prior day activity)

Labels:

  • 0: Wake
  • 1: Light Sleep (combined NREM1 + NREM2)
  • 2: Deep Sleep (NREM3)
  • 3: REM Sleep

Data Processing:

  • 30-second epochs
  • Sequence length: 25 epochs (12.5 minutes of context)
  • StandardScaler normalization
  • Forward/backward fill for missing values

πŸš€ Installation

Option 1: Docker (Recommended)

# Clone the repository
git clone <repository-url>
cd sleep-staging

# Build and run with docker-compose
docker-compose up

# For training
docker-compose --profile train up train-hybrid

Option 2: Local Installation

# Create virtual environment
python -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate

# Install dependencies
pip install -r requirements.txt

# Download dataset from PhysioNet
# Extract to ./data/ directory following the structure:
# data/
#   labels/
#   heart_rate/
#   motion/
#   steps/

πŸ’» Usage

Training

Train LSTM Model:

python train.py

Train Hybrid CNN-Transformer:

python train_hybrid.py

Training Configuration (modify in script):

  • SEQ_LEN: Sequence length (default: 25)
  • BATCH_SIZE: Batch size (default: 96)
  • EPOCHS: Number of epochs (default: 100)
  • LR: Learning rate (default: 8e-4)

Inference

Run Inference on Test Set:

# LSTM model
python inference.py

# Hybrid model
python inference_hybrid.py

Programmatic Usage:

from inference_hybrid import load_model, predict_label_from_triplet

# Load model
model, scaler = load_model()

# Predict from single reading
pred, probs = predict_label_from_triplet(
    model, scaler,
    hr=70.0,      # heart rate
    motion=0.05,  # motion magnitude
    steps=0.0     # cumulative steps
)

print(f"Predicted stage: {pred}")  # 0=Wake, 1=Light, 2=Deep, 3=REM
print(f"Probabilities: {probs}")

Docker Usage

# Build image
docker build -t sleep-staging .

# Run inference
docker run -v $(pwd)/data:/app/data sleep-staging python inference_hybrid.py

# Run training
docker run -v $(pwd)/data:/app/data sleep-staging python train_hybrid.py

# Interactive shell
docker run -it -v $(pwd)/data:/app/data sleep-staging /bin/bash

πŸ“ˆ Model Performance

πŸŽ‰ Bonus Achievement: REM Accuracy > 70%

Both models successfully achieved the bonus criteria!

Hybrid CNN-Transformer Model ⭐ BEST

Test Set Performance (4 subjects, 3,538 sequences):

Metric Score
Overall Accuracy 33.97%
REM Accuracy 82.34% βœ… BONUS ACHIEVED
Wake Accuracy 31.01%
Light Sleep Accuracy 11.76%
Deep Sleep Accuracy 60.74%

Detailed Metrics:

              precision    recall  f1-score   support
        Wake     0.4534    0.3101    0.3683       345
       Light     0.8511    0.1176    0.2066      2041
        Deep     0.2906    0.6074    0.3931       433
         REM     0.2799    0.8234    0.4178       719

REM Confusion Analysis:

  • Correct REM predictions: 592/719 (82.3%)
  • REM β†’ Deep: 77 (10.7%) - most common error
  • REM β†’ Wake: 26 (3.6%)
  • REM β†’ Light: 24 (3.3%)

LSTM Model

Test Set Performance (4 subjects, 3,538 sequences):

Metric Score
Overall Accuracy 35.08%
REM Accuracy 75.80% βœ… BONUS ACHIEVED
Wake Accuracy 43.19%
Light Sleep Accuracy 13.57%
Deep Sleep Accuracy 62.36%

Detailed Metrics:

              precision    recall  f1-score   support
        Wake     0.3260    0.4319    0.3716       345
       Light     0.6942    0.1357    0.2270      2041
        Deep     0.2885    0.6236    0.3944       433
         REM     0.3121    0.7580    0.4422       719

REM Confusion Analysis:

  • Correct REM predictions: 545/719 (75.8%)
  • REM β†’ Light: 71 (9.9%)
  • REM β†’ Wake: 59 (8.2%)
  • REM β†’ Deep: 44 (6.1%)

Key Insights

Model Comparison:

  • Hybrid model wins on REM detection: 82.34% vs 75.80% (+6.5%)
  • Both models prioritize REM recall over overall accuracy (by design)
  • The aggressive REM optimization trades off Light sleep accuracy

Why Low Overall Accuracy? The models were specifically optimized for REM detection (the bonus criteria) using:

  • Heavy REM class weight boosting (1.6-2.5Γ—)
  • REM-specific data augmentation
  • Focal loss prioritizing hard examples (REM)

This aggressive optimization maximizes REM recall at the cost of other classes, particularly Light sleep. For balanced performance, reduce REM weight multiplier.

πŸ“ Project Structure

sleep-staging/
β”œβ”€β”€ data/                      # Dataset directory (not in repo)
β”‚   β”œβ”€β”€ labels/
β”‚   β”œβ”€β”€ heart_rate/
β”‚   β”œβ”€β”€ motion/
β”‚   └── steps/
β”œβ”€β”€ data_loader.py            # Data loading and preprocessing
β”œβ”€β”€ model.py                  # LSTM model architecture
β”œβ”€β”€ model_hybrid.py           # CNN-Transformer hybrid architecture
β”œβ”€β”€ train.py                  # LSTM training script
β”œβ”€β”€ train_hybrid.py           # Hybrid model training script
β”œβ”€β”€ inference.py              # LSTM inference script
β”œβ”€β”€ inference_hybrid.py       # Hybrid inference script
β”œβ”€β”€ requirements.txt          # Python dependencies
β”œβ”€β”€ Dockerfile                # Docker configuration
β”œβ”€β”€ docker-compose.yml        # Docker Compose configuration
β”œβ”€β”€ .dockerignore            # Docker ignore patterns
└── README.md                # This file

πŸ”¬ Technical Details

Data Preprocessing

  1. Temporal Alignment: All signals aligned to 30-second epochs
  2. Feature Engineering:
    • Motion magnitude: sqrt(xΒ² + yΒ² + zΒ²)
    • Steps: Cumulative count from previous day
  3. Normalization: StandardScaler (fitted on training data only)
  4. Sequence Creation: Sliding window of 25 epochs

Training Techniques

Addressing Class Imbalance:

  • Focal Loss (Ξ³=2.0) to focus on hard examples
  • Balanced class weights with REM boost (1.6-2.5Γ—)
  • REM data augmentation (2Γ— with gaussian noise)

Regularization:

  • Dropout (0.2-0.3)
  • Weight decay (5e-5)
  • Gradient clipping (max_norm=0.5)

Optimization:

  • AdamW optimizer
  • Cosine annealing with warm restarts
  • Learning rate: 8e-4 β†’ 1e-6

Model Selection

Models are selected based on validation REM accuracy rather than overall accuracy, as REM detection is the primary challenge and most valuable for sleep analysis applications.

Subject-Level Splitting

To prevent data leakage, the dataset is split at the subject level:

  • Training: 80% of subjects
  • Validation: 10% of subjects
  • Test: 10% of subjects

This ensures the model generalizes to new individuals, not just new nights from the same person.

πŸ› οΈ Troubleshooting

CUDA Out of Memory:

  • Reduce BATCH_SIZE in training scripts
  • Reduce SEQ_LEN (at cost of context)

Low REM Accuracy:

  • Increase REM class weight multiplier in training script
  • Increase augmentation factor
  • Try longer sequences for more context

Model Not Loading:

  • Ensure model_hybrid_best.pth and scaler.pkl exist
  • Check model architecture matches saved checkpoint

πŸ“ License

This project is for academic and research purposes. Dataset is from PhysioNet and subject to their terms of use.

πŸ™ Acknowledgments

πŸ“§ Contact

For questions or issues, please open an issue in the repository.

About

Sleep Classifier implemented in PyTorch featuring a hybrid transformer and lstm model.

Topics

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published