Official PyTorch implementation of PhysioWave, accepted at NeurIPS 2025
A novel wavelet-based architecture for physiological signal processing that leverages adaptive multi-scale decomposition and frequency-guided masking to advance self-supervised learning
Key features:

- Learnable Wavelet Decomposition
- Frequency-Guided Masking
- Cross-Scale Feature Fusion
- Multi-Modal Support
- Large-Scale Pretraining: models trained on 182GB of ECG and 823GB of EMG data
The PhysioWave pretraining pipeline consists of five key stages:
- Wavelet Initialization: Standard wavelet functions (e.g., 'db6', 'sym4') generate learnable low-pass and high-pass filters
- Multi-Scale Decomposition: Adaptive wavelet decomposition produces multi-scale frequency-band representations
- Patch Embedding: Decomposed features are processed into spatio-temporal patches with FFT-based importance scoring
- Masked Encoding: High-scoring patches are masked and processed through Transformer layers with rotary position embeddings
- Reconstruction: Lightweight decoder reconstructs masked patches for self-supervised learning
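The frequency-guided masking idea behind stages 3-4 can be sketched in NumPy. This is an illustrative simplification, not the repository's implementation; the function name and the exact rule for mixing importance-ranked and random patches are assumptions:

```python
import numpy as np

def frequency_guided_mask(signal, patch_size, mask_ratio=0.7,
                          importance_ratio=0.7, rng=None):
    """Score patches by spectral energy via FFT; mask mostly high-scoring ones.

    A fraction `importance_ratio` of the masked patches is chosen by score,
    the remainder at random (illustrative rule, not the repo's exact one).
    """
    if rng is None:
        rng = np.random.default_rng(0)
    patches = signal.reshape(-1, patch_size)          # (num_patches, patch_size)
    # Spectral energy per patch (DC bin excluded)
    spectra = np.abs(np.fft.rfft(patches, axis=1))
    scores = (spectra[:, 1:] ** 2).sum(axis=1)
    n = patches.shape[0]
    n_mask = int(round(mask_ratio * n))
    n_imp = int(round(importance_ratio * n_mask))
    # Top-scoring patches first ...
    by_score = np.argsort(scores)[::-1]
    masked = list(by_score[:n_imp])
    # ... then fill up with random patches from the remainder
    rest = rng.permutation(by_score[n_imp:])
    masked += list(rest[: n_mask - n_imp])
    mask = np.zeros(n, dtype=bool)
    mask[masked] = True
    return mask

t = np.linspace(0, 1, 512, endpoint=False)
x = np.sin(2 * np.pi * 50 * t)   # toy signal: 8 patches of length 64
mask = frequency_guided_mask(x, patch_size=64, mask_ratio=0.7)
print(mask.sum())  # 6 of 8 patches masked (70% of 8, rounded)
```

The encoder then sees only the unmasked patches, and the decoder reconstructs the masked ones.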
| Component | Description |
|---|---|
| Learnable Wavelet Decomposition | Adaptively selects optimal wavelet bases for input signals |
| Multi-Scale Feature Reconstruction | Hierarchical decomposition with soft gating between scales |
| Frequency-Guided Masking | Identifies and masks high-energy patches for self-supervised learning |
| Transformer Encoder/Decoder | Processes masked patches with rotary position embeddings |
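How a wavelet filter pair produces multi-scale frequency bands can be illustrated with a one-pair NumPy sketch. The repository initializes *learnable* filters from wavelets such as `db6` or `sym4`; this standalone sketch uses fixed Haar filters and valid-mode convolution for brevity:

```python
import numpy as np

# Haar analysis filters; PhysioWave initializes such filters from standard
# wavelets and treats them as learnable parameters (kept fixed here).
lo = np.array([1.0, 1.0]) / np.sqrt(2)   # low-pass
hi = np.array([1.0, -1.0]) / np.sqrt(2)  # high-pass

def decompose(x, levels=3):
    """One wavelet tree: repeatedly split the approximation into two bands."""
    bands = []
    approx = x
    for _ in range(levels):
        # Filter, then downsample by 2 (a stride-2 convolution)
        detail = np.convolve(approx, hi[::-1], mode='valid')[::2]
        approx = np.convolve(approx, lo[::-1], mode='valid')[::2]
        bands.append(detail)
    bands.append(approx)
    return bands  # `levels` detail bands + final approximation

x = np.arange(16, dtype=float)
bands = decompose(x, levels=3)
print([b.shape[0] for b in bands])  # [8, 4, 2, 2]
```

Each returned band covers a different frequency range at a different temporal resolution, which is what the cross-scale fusion operates on.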
| Task | Dataset | Metric | Performance |
|---|---|---|---|
| ECG Arrhythmia | PTB-XL | Accuracy | 73.1% |
| ECG Multi-Label | CPSC 2018 | F1-Micro | 77.1% |
| ECG Multi-Label | Shaoxing | F1-Micro | 94.6% |
| EMG Gesture | EPN-612 | Accuracy | 94.5% |
CPSC 2018 Dataset (9-Class Multi-Label)
| Metric | Micro-Average | Macro-Average |
|---|---|---|
| Precision | 0.7389 | 0.6173 |
| Recall | 0.8059 | 0.6883 |
| F1-Score | 0.7709 | 0.6500 |
| AUROC | 0.9584 | 0.9280 |
Dataset Details:
- 9 official diagnostic classes (SNR, AF, IAVB, LBBB, RBBB, PAC, PVC, STD, STE)
- 12-lead ECG signals at 500 Hz
- Record-level split to prevent data leakage
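Record-level splitting means every window cut from a given record lands in the same partition. A minimal NumPy sketch of the idea (illustrative only; the repository stores its actual splits in `record_splits.json`):

```python
import numpy as np

def record_level_split(record_ids, val_frac=0.1, test_frac=0.1, seed=42):
    """Split by unique record ID so no record spans two partitions."""
    rng = np.random.default_rng(seed)
    uniq = np.unique(record_ids)
    rng.shuffle(uniq)
    n_test = int(len(uniq) * test_frac)
    n_val = int(len(uniq) * val_frac)
    test_r = set(uniq[:n_test].tolist())
    val_r = set(uniq[n_test:n_test + n_val].tolist())
    return np.array(['test' if r in test_r else
                     'val' if r in val_r else 'train'
                     for r in record_ids])

# Two windows per record: both always land in the same partition
ids = np.repeat(np.arange(10), 2)
split = record_level_split(ids)
print(all(len(set(split[ids == r])) == 1 for r in range(10)))  # True
```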
Chapman-Shaoxing Dataset (4-Class Multi-Label)
| Metric | Micro-Average | Macro-Average |
|---|---|---|
| Precision | 0.9389 | 0.9361 |
| Recall | 0.9536 | 0.9470 |
| F1-Score | 0.9462 | 0.9413 |
| AUROC | 0.9949 | 0.9930 |
Dataset Details:
- 4 merged diagnostic classes (SB, AFIB, GSVT, SR)
- 12-lead ECG signals at 500 Hz
- Balanced multi-label distribution
| Model | Parameters | Training Data | Description |
|---|---|---|---|
| `ecg.pth` | 14M | 182GB ECG | ECG pretrained model |
| `emg.pth` | 5M | 823GB EMG | EMG pretrained model |
Usage:

```python
import torch

# Load a pretrained checkpoint (assumes `model` is an already-built PhysioWave model)
checkpoint = torch.load('ecg.pth', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
```

Installation:

```bash
# Clone repository
git clone https://github.com/ForeverBlue816/PhysioWave.git
cd PhysioWave

# Create conda environment
conda create -n physiowave python=3.11
conda activate physiowave

# Install PyTorch (CUDA 12.1)
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

# Install requirements
pip install -r requirements.txt
```

Dataset Download Links
- PTB-XL Database - 21,837 clinical ECG records
- MIMIC-IV-ECG - 800K+ ECG recordings
- PhysioNet Challenge 2021 - Multi-database ECG
- CPSC 2018 - Arrhythmia classification challenge
- Chapman-Shaoxing - Large-scale 12-lead ECG
- EPN-612 Dataset - sEMG gesture recordings from 612 subjects
- NinaPro Database DB6 - HD-sEMG recordings
Data Format Specifications

```python
# Single-label classification
{
    'data': (N, C, T),   # Signal data: float32
    'label': (N,)        # Labels: int64
}

# Multi-label classification
{
    'data': (N, C, T),   # Signal data: float32
    'label': (N, K)      # Multi-hot labels: float32
}
```

Dimensions:
- `N` = Number of samples
- `C` = Number of channels
- `T` = Time points
- `K` = Number of classes (multi-label only)
| Signal | Channels | Length | Sampling Rate | Normalization |
|---|---|---|---|---|
| ECG | 12 | 2048 | 500 Hz | MinMax [-1,1] or Z-score |
| EMG | 8 | 1024 | 200-2000 Hz | Max-abs or Z-score |
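A synthetic batch matching these specifications (NumPy sketch; the repository stores such arrays in `.h5` files, and the MinMax normalization shown is one of the options from the table above):

```python
import numpy as np

rng = np.random.default_rng(0)
N, C, T, K = 32, 12, 2048, 9

# Single-label ECG batch (cf. the format spec above)
single = {
    'data': rng.standard_normal((N, C, T)).astype(np.float32),
    'label': rng.integers(0, 5, size=N).astype(np.int64),
}

# Multi-label batch: K-dimensional multi-hot vectors
multi = {
    'data': rng.standard_normal((N, C, T)).astype(np.float32),
    'label': (rng.random((N, K)) < 0.2).astype(np.float32),
}

# MinMax-normalize each channel to [-1, 1]
x = single['data']
x_min = x.min(axis=-1, keepdims=True)
x_max = x.max(axis=-1, keepdims=True)
single['data'] = 2 * (x - x_min) / (x_max - x_min) - 1

print(single['data'].shape, single['label'].dtype, multi['label'].shape)
```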
ECG Preprocessing (PTB-XL - Single-Label)

```bash
# Download PTB-XL dataset
wget -r -N -c -np https://physionet.org/files/ptb-xl/1.0.3/

# Preprocess for single-label classification
python ECG/ptbxl_finetune.py
```

Output files:
- `train.h5` - Training data with shape `(N, 12, 2048)`
- `val.h5` - Validation data
- `test.h5` - Test data

Label format: `(N,)` with 5 superclasses (NORM, MI, STTC, CD, HYP)
ECG Preprocessing (CPSC 2018 - Multi-Label)

```bash
# Preprocess CPSC 2018 dataset
python ECG/cpsc_multilabel.py
```

Output files:
- `cpsc_9class_train.h5` - Training data
- `cpsc_9class_val.h5` - Validation data
- `cpsc_9class_test.h5` - Test data
- `cpsc_9class_info.json` - Dataset metadata
- `label_map.json` - Class mappings
- `record_splits.json` - Record-level split info

Label format: `(N, 9)` with 9 official CPSC classes
ECG Preprocessing (Chapman-Shaoxing - Multi-Label)

```bash
# Preprocess Chapman-Shaoxing dataset
python ECG/shaoxing_multilabel.py
```

Output files:
- `train.h5` - Training data
- `val.h5` - Validation data
- `test.h5` - Test data
- `dataset_info.json` - Metadata
- `record_splits.json` - Split information

Label format: `(N, 4)` with 4 merged classes (SB, AFIB, GSVT, SR)
EMG Preprocessing (EPN-612)

```bash
# Download from Zenodo and preprocess
python EMG/epn_finetune.py
```

Output files:
- `epn612_train_set.h5` - Training set `(N, 8, 1024)`
- `epn612_val_set.h5` - Validation set
- `epn612_test_set.h5` - Test set

Label format: `(N,)` with 6 gesture classes
ECG Pretraining

```bash
# Edit ECG/pretrain_ecg.sh to set data paths
bash ECG/pretrain_ecg.sh
```

Key parameters:

```bash
--mask_ratio 0.7                      # Mask 70% of patches
--masking_strategy frequency_guided   # Use frequency-guided masking
--importance_ratio 0.7                # Balance importance vs. randomness
--epochs 100                          # Pretraining epochs
```

EMG Pretraining

```bash
# Edit EMG/pretrain_emg.sh to set data paths
bash EMG/pretrain_emg.sh
```

Key parameters:

```bash
--mask_ratio 0.6        # Mask 60% of patches
--in_channels 8         # 8-channel EMG
--wave_kernel_size 16   # Smaller kernel for EMG
```

Standard Fine-tuning (ECG/EMG)
```bash
# ECG fine-tuning (PTB-XL)
bash ECG/finetune_ecg.sh

# EMG fine-tuning (EPN-612)
bash EMG/finetune_emg.sh
```

Example command:

```bash
torchrun --nproc_per_node=4 finetune.py \
    --train_file path/to/train.h5 \
    --val_file path/to/val.h5 \
    --test_file path/to/test.h5 \
    --pretrained_path pretrained/ecg.pth \
    --task_type classification \
    --num_classes 5 \
    --batch_size 16 \
    --epochs 50 \
    --lr 1e-4
```

Multi-Label Fine-tuning (CPSC/Shaoxing)
This repository uses finetune_multilabel.py for multi-label classification tasks. First, prepare your data using the corresponding preprocessing scripts.
CPSC 2018 Example:

```bash
# Edit paths in ECG/cpsc_multilabel.sh
bash ECG/cpsc_multilabel.sh
```

Shaoxing Example:

```bash
# Edit paths in ECG/shaoxing_multilabel.sh
bash ECG/shaoxing_multilabel.sh
```

Manual command:
```bash
NUM_GPUS=4
torchrun --nproc_per_node=${NUM_GPUS} finetune_multilabel.py \
    --train_file "path/to/train.h5" \
    --val_file "path/to/val.h5" \
    --test_file "path/to/test.h5" \
    --pretrained_path "path/to/pretrained_ecg/best_model.pth" \
    \
    `# Task Configuration` \
    --task_type multilabel \
    --threshold 0.3 \
    \
    `# Model Architecture` \
    --in_channels 12 \
    --max_level 3 \
    --wave_kernel_size 24 \
    --wavelet_names db4 db6 sym4 coif2 \
    --use_separate_channel \
    --patch_size 64 \
    --embed_dim 384 \
    --depth 8 \
    --num_heads 12 \
    --mlp_ratio 4.0 \
    --dropout 0.1 \
    \
    `# Training Parameters` \
    --batch_size 16 \
    --epochs 50 \
    --lr 1e-4 \
    --weight_decay 1e-4 \
    --scheduler cosine \
    --warmup_epochs 5 \
    --grad_clip 1.0 \
    --use_amp \
    \
    `# Classification Head` \
    --pooling mean \
    --head_hidden_dim 512 \
    --head_dropout 0.2 \
    --label_smoothing 0.1 \
    \
    `# Output` \
    --seed 42 \
    --output_dir "./checkpoints_multilabel"
```

Key Parameters for Multi-Label:
- `--task_type multilabel` - Enable multi-label classification
- `--threshold 0.3` - Decision threshold (adjust based on validation)
- `--label_smoothing 0.1` - Regularization for better generalization
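How `--threshold` turns per-class probabilities into multi-hot predictions, and how the micro and macro averages in the results tables differ, can be shown with a small NumPy sketch (function name and toy data are illustrative, not the repo's API):

```python
import numpy as np

def f1_scores(y_true, probs, threshold=0.3):
    """Micro- and macro-averaged F1 for multi-label predictions."""
    y_pred = probs >= threshold          # apply the decision threshold
    yt = y_true.astype(bool)
    tp = (y_pred & yt).sum(axis=0).astype(float)
    fp = (y_pred & ~yt).sum(axis=0).astype(float)
    fn = (~y_pred & yt).sum(axis=0).astype(float)
    # Micro: pool counts over all classes, then compute one F1
    micro = 2 * tp.sum() / (2 * tp.sum() + fp.sum() + fn.sum())
    # Macro: per-class F1, then average (rare classes weigh equally)
    per_class = 2 * tp / np.maximum(2 * tp + fp + fn, 1e-12)
    return micro, per_class.mean()

y_true = np.array([[1, 0, 1],
                   [0, 1, 0],
                   [1, 1, 0]])
probs = np.array([[0.9, 0.1, 0.4],
                  [0.2, 0.8, 0.1],
                  [0.6, 0.2, 0.7]])
micro, macro = f1_scores(y_true, probs, threshold=0.3)
print(round(micro, 3), round(macro, 3))  # 0.8 0.778
```

Macro-F1 drops below micro-F1 whenever some classes are predicted worse than others, which is exactly the pattern in the CPSC 2018 results above.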
Linear Probing
Evaluate pretrained representations by freezing the encoder and training only the classification head:
```bash
torchrun --nproc_per_node=4 finetune.py \
    --train_file path/to/train.h5 \
    --val_file path/to/val.h5 \
    --test_file path/to/test.h5 \
    --pretrained_path pretrained/ecg.pth \
    --freeze_encoder \
    --num_classes 5 \
    --epochs 10 \
    --lr 1e-3
```

Architecture Parameters
| Parameter | Description | Options | Recommendation |
|---|---|---|---|
| `--in_channels` | Input channels | 12 (ECG), 8 (EMG) | Match your data |
| `--max_level` | Wavelet decomposition levels | 2-4 | 3 (default) |
| `--wave_kernel_size` | Wavelet kernel size | 16-32 | 24 (ECG), 16 (EMG) |
| `--wavelet_names` | Wavelet families | db, sym, coif, bior | See tips below |
| `--embed_dim` | Embedding dimension | 128-768 | 256/384/512 |
| `--depth` | Transformer layers | 4-12 | 6/8/12 |
| `--num_heads` | Attention heads | 4-16 | 8/12 |
| `--patch_size` | Temporal patch size | 20-128 | 64 (ECG), 32 (EMG) |
Wavelet Selection Tips:

| Signal Type | Recommended Wavelets | Rationale |
|---|---|---|
| ECG | `db4 db6 sym4 coif2` | Optimal for QRS complex detection |
| EMG | `sym4 sym5 db6 coif3 bior4.4` | Best for muscle activation patterns |
| Custom | Experiment with combinations | Domain-specific optimization |
Training Configuration

| Parameter | Description | ECG | EMG |
|---|---|---|---|
| `--mask_ratio` | Masking ratio | 0.7 | 0.6 |
| `--masking_strategy` | Masking type | `frequency_guided` | `frequency_guided` |
| `--importance_ratio` | Importance weight | 0.7 | 0.6 |
| `--epochs` | Training epochs | 100 | 100 |
| `--lr` | Learning rate | 2e-5 | 5e-5 |
| Parameter | Description | Default | Range |
|---|---|---|---|
| `--batch_size` | Batch size per GPU | 16 | 8-64 |
| `--epochs` | Training epochs | 50 | 20-100 |
| `--lr` | Learning rate | 1e-4 | 1e-5 to 1e-3 |
| `--weight_decay` | L2 regularization | 1e-4 | 1e-5 to 1e-3 |
| `--scheduler` | LR scheduler | `cosine` | `cosine`/`step`/`plateau` |
| `--warmup_epochs` | Warmup epochs | 5 | 0-10 |
| `--grad_clip` | Gradient clipping | 1.0 | 0.5-2.0 |
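One common reading of `--scheduler cosine` combined with `--warmup_epochs 5` is linear warmup followed by cosine decay; the repository's exact schedule may differ, so treat this as an illustrative sketch:

```python
import math

def lr_at(epoch, base_lr=1e-4, warmup_epochs=5, total_epochs=50, min_lr=0.0):
    """Linear warmup to base_lr, then cosine decay toward min_lr."""
    if epoch < warmup_epochs:
        # Ramp up linearly over the warmup epochs
        return base_lr * (epoch + 1) / warmup_epochs
    # Cosine decay over the remaining epochs
    t = (epoch - warmup_epochs) / max(total_epochs - warmup_epochs, 1)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * t))

print(lr_at(0), lr_at(4), lr_at(49))  # 2e-05, 1e-04, then nearly 0
```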
| Parameter | Description | Default | Notes |
|---|---|---|---|
| `--threshold` | Decision threshold | 0.3-0.5 | Tune on validation set |
| `--label_smoothing` | Label smoothing | 0.1 | 0.0-0.2 for regularization |
| `--use_class_weights` | Class balancing | False | Enable for imbalanced data |
Hardware and Performance

```bash
# Enable mixed precision for ~2x speedup
--use_amp

# Increase batch size with gradient accumulation
--batch_size 8 --grad_accumulation_steps 4   # Effective batch size: 32

# Multi-GPU training
torchrun --nproc_per_node=4 [script.py]
```

If you find our work helpful, please cite:
```bibtex
@article{chen2025physiowave,
  title={PhysioWave: A Multi-Scale Wavelet-Transformer for Physiological Signal Representation},
  author={Chen, Yanlong and Orlandi, Mattia and Rapa, Pierangelo Maria and Benatti, Simone and Benini, Luca and Li, Yawei},
  journal={arXiv preprint arXiv:2506.10351},
  year={2025}
}
```

Lead Author: Yanlong Chen
Email: yanlchen@student.ethz.ch
We welcome contributions! Feel free to:

- Report bugs
- Suggest enhancements
- Submit Pull Requests
- Star this repository if you find it useful!
We thank:
- The authors of PTB-XL, MIMIC-IV-ECG, CPSC 2018, Chapman-Shaoxing, and EPN-612 datasets
- The PyTorch team for their excellent framework
- The open-source community for inspiration and tools
This project is licensed under the MIT License - see the LICENSE file for details.
