Temporal Variational Autoencoder Model for Rapid Response System

nghianguyen7171/TVAE-RRS

TVAE-RRS: Temporal Variational Autoencoder Model for Rapid Response System

License: MIT | Python 3.8+ | TensorFlow 2.11+

📖 Abstract

Early recognition of clinical deterioration plays a pivotal role in a Rapid Response System (RRS) and is a crucial step in reducing inpatient morbidity and mortality. Traditional Early Warning Scores (EWS) and Deep Early Warning Scores (DEWS) for identifying at-risk patients remain limited by the challenges of imbalanced multivariate temporal data. The typical consequences are low sensitivity, a high late alarm rate, and a lack of interpretability, which make such systems difficult to deploy in clinical settings.

This study develops an early warning system based on a Temporal Variational Autoencoder (TVAE) and a window interval learning framework. The model uses latent space features generated from the input multivariate temporal features to learn the distribution of temporal dependence between the target labels (clinical deterioration probability). Incorporating the target information into the Fully Connected Network (FCN) architecture of the decoder, together with a dedicated loss function, helps address the class imbalance problem and improves performance on the time series classification task.

🎯 Key Features

  • TVAE Architecture: 3-layer LSTM encoder (100/50/25 hidden units) with VAE latent space and dual decoders
  • Window Interval Processing (WIP): Advanced temporal feature extraction framework
  • Comprehensive Baselines: RNN, BiLSTM+Attention, DCNN, FCNN, and XGBM models
  • Robust Evaluation: K-Fold CV, VTSA, LOOCV with AUROC, AUPRC, F1, Kappa metrics
  • Late Alarm Analysis: Critical metric for clinical deployment
  • Production Ready: Modular, well-documented, and reproducible codebase

📊 Model Architecture

TVAE Architecture Overview

Input Sequence (T × F)
         ↓
┌─────────────────────┐
│   3-Layer LSTM      │
│   Encoder           │
│   (100→50→25)       │
└─────────────────────┘
         ↓
┌─────────────────────┐
│   VAE Latent Space  │
│   (μ, σ, z)         │
└─────────────────────┘
         ↓
    ┌─────────┐
    │  Dual   │
    │Decoders │
    └─────────┘
         ↓
┌─────────┐ ┌─────────┐
│LSTM Rec.│ │FCN Class│
│Decoder  │ │Decoder  │
└─────────┘ └─────────┘
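
The latent block in the diagram lists a mean, a standard deviation, and a sample z. In a standard VAE, z is drawn with the reparameterization trick, z = mu + sigma * eps with eps ~ N(0, 1). The sketch below shows that sampling step in plain Python. It assumes the encoder emits a mean and a log-variance per latent dimension; the name `sample_latent` is hypothetical and not taken from this repository.

```python
import math
import random

def sample_latent(mu, log_var, rng=None):
    """Draw z = mu + sigma * eps with eps ~ N(0, 1), one value per latent dimension."""
    rng = rng or random.Random()
    z = []
    for m, lv in zip(mu, log_var):
        sigma = math.exp(0.5 * lv)   # log-variance -> standard deviation
        eps = rng.gauss(0.0, 1.0)    # noise drawn independently of the encoder output
        z.append(m + sigma * eps)
    return z

# Example with 8 latent dimensions (the latent_dim used in the sample configuration)
mu = [0.0] * 8
log_var = [0.0] * 8                  # sigma = 1 in every dimension
z = sample_latent(mu, log_var, rng=random.Random(42))
print(len(z))
```

Because the randomness enters only through `eps`, gradients can flow through `mu` and `log_var` when the same expression is written with a deep learning framework.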

Loss Function

The TVAE model uses a combined loss function:

L_final = L_vae + L_clinical + L_imbalance

Where:
- L_vae = L_reconstruction + β × L_KL_divergence
- L_clinical = Binary Cross-Entropy + Temporal Consistency
- L_imbalance = Focal Loss with Class Weighting
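
To make the three terms concrete, here is a minimal plain-Python sketch for a single sample. It is an illustration under stated assumptions, not the code in `src/utils/losses.py`: mean squared error stands in for L_reconstruction, the focal loss uses illustrative values `gamma=2.0` and `alpha=0.75`, and the temporal consistency part of L_clinical is omitted because its definition is not given here. All function names are hypothetical.

```python
import math

def kl_divergence(mu, log_var):
    """KL(N(mu, sigma^2) || N(0, 1)), summed over latent dimensions."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, log_var))

def mse(x, x_hat):
    """Mean squared reconstruction error (assumed choice for L_reconstruction)."""
    return sum((a - b) ** 2 for a, b in zip(x, x_hat)) / len(x)

def bce(y_true, p, eps=1e-7):
    """Binary cross-entropy for one sample."""
    p = min(max(p, eps), 1.0 - eps)
    return -(y_true * math.log(p) + (1 - y_true) * math.log(1.0 - p))

def focal_loss(y_true, p, gamma=2.0, alpha=0.75, eps=1e-7):
    """Binary focal loss for one sample; alpha weights the positive class."""
    p = min(max(p, eps), 1.0 - eps)
    p_t = p if y_true == 1 else 1.0 - p
    alpha_t = alpha if y_true == 1 else 1.0 - alpha
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)

def combined_loss(x, x_hat, mu, log_var, y_true, p, beta=1.0):
    """L_final = L_vae + L_clinical + L_imbalance for one sample."""
    l_vae = mse(x, x_hat) + beta * kl_divergence(mu, log_var)
    l_clinical = bce(y_true, p)          # temporal consistency term omitted (assumption)
    l_imbalance = focal_loss(y_true, p)
    return l_vae + l_clinical + l_imbalance

loss = combined_loss(x=[1.0, 2.0], x_hat=[1.0, 2.0], mu=[0.0], log_var=[0.0], y_true=1, p=0.5)
print(round(loss, 4))
```

The focal term shrinks the contribution of well-classified samples through the factor (1 - p_t)^gamma, which is why it is commonly paired with class weighting on imbalanced data.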

🚀 Installation

Prerequisites

  • Python 3.8+
  • CUDA 12.1+ (for GPU support)
  • TensorFlow 2.11+

Option 1: Using Conda (Recommended)

# Clone the repository
git clone https://github.com/nghianguyen7171/TVAE-RRS.git
cd TVAE-RRS

# Create conda environment
conda env create -f environment.yml
conda activate tvae-rrs

# Install the package
pip install -e .

Option 2: Using pip

# Clone the repository
git clone https://github.com/nghianguyen7171/TVAE-RRS.git
cd TVAE-RRS

# Create virtual environment
python -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate

# Install dependencies
pip install -r requirements.txt

# Install the package
pip install -e .

πŸ“ Project Structure

TVAE-RRS/
├── README.md
├── requirements.txt
├── environment.yml
├── setup.py
│
├── data/
│   ├── raw/                    # Raw data files
│   ├── processed/              # Processed data files
│   └── external/               # External datasets
│
├── src/
│   ├── data_preprocessing.py   # Data preprocessing utilities
│   ├── dataset_loader.py       # Dataset loading and preparation
│   ├── main.py                 # Main entry point
│   │
│   ├── models/
│   │   ├── tvae.py             # TVAE model implementation
│   │   ├── rnn_baseline.py     # RNN baseline
│   │   ├── bilstm_attention.py # BiLSTM+Attention baseline
│   │   ├── dcnn.py             # DCNN baseline
│   │   ├── fcnn.py             # FCNN baseline
│   │   └── xgbm_baseline.py    # XGBM baseline
│   │
│   ├── training/
│   │   ├── train_tvae.py       # TVAE training
│   │   ├── train_baselines.py  # Baseline training
│   │   └── utils_train.py      # Training utilities
│   │
│   ├── evaluation/
│   │   ├── evaluate_metrics.py  # Metrics evaluation
│   │   ├── visualize_results.py # Results visualization
│   │   └── t_sne_latent.py      # t-SNE visualization
│   │
│   └── utils/
│       ├── losses.py            # Loss functions
│       ├── window_processing.py # Window processing (WIP)
│       └── config.py            # Configuration management
│
├── experiments/
│   ├── config_experiments.yaml # Experiment configuration
│   ├── results/                # Experiment results
│   └── logs/                   # Training logs
│
└── notebooks/
    ├── 1_data_exploration.ipynb
    ├── 2_model_training.ipynb
    └── 3_evaluation_visualization.ipynb

📊 Data Preparation

CNUH Dataset

The CNUH dataset contains clinical data from Chonnam National University Hospitals with the following features:

  • Vital Signs: SBP, BT, SaO2, RR, HR
  • Laboratory Values: Albumin, Hgb, BUN, WBC Count, Creatinine, etc.
  • Demographics: Age, Gender
  • Target: Clinical deterioration (0: Normal, 1: Abnormal)

UV Dataset

The UV dataset is a public dataset from the University of Virginia with similar clinical features.

Data Format

# Expected data format
data = {
    'Patient': [1, 1, 1, ...],           # Patient ID
    'measurement_time': [...],           # Timestamp
    'target': [0, 0, 1, ...],           # Target labels
    'SBP': [120, 125, 110, ...],         # Systolic Blood Pressure
    'HR': [80, 85, 95, ...],            # Heart Rate
    # ... other features
}
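
Records in this long format have to be cut into fixed-length sequences before they can feed a model with a `--window 16` style input. The sketch below shows one way to do that in plain Python: group rows by patient, sort by time, and slide a window over each series. It is not the implementation in `src/utils/window_processing.py`. The function name `make_windows` is hypothetical, and labelling each window with the target at its last time step is an assumption.

```python
from collections import defaultdict

def make_windows(rows, feature_names, window=16):
    """Slice each patient's time-ordered rows into overlapping windows.

    rows: list of dicts with 'Patient', 'measurement_time', 'target' and feature keys.
    Returns (X, y): X[i] is a list of `window` feature vectors and y[i] is the
    target at the last time step of that window (assumed labelling rule).
    """
    by_patient = defaultdict(list)
    for row in rows:
        by_patient[row["Patient"]].append(row)

    X, y = [], []
    for patient_rows in by_patient.values():
        patient_rows.sort(key=lambda r: r["measurement_time"])
        # Patients with fewer than `window` rows produce no windows.
        for end in range(window, len(patient_rows) + 1):
            chunk = patient_rows[end - window:end]
            X.append([[r[name] for name in feature_names] for r in chunk])
            y.append(chunk[-1]["target"])
    return X, y

# Toy example: one patient with 5 rows and a window of 3 gives 3 windows
rows = [
    {"Patient": 1, "measurement_time": t, "target": int(t >= 4), "SBP": 120 - t, "HR": 80 + t}
    for t in range(5)
]
X, y = make_windows(rows, ["SBP", "HR"], window=3)
print(len(X), y)
```

Windows never span two patients because the slicing happens inside each patient's group.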

πŸƒβ€β™‚οΈ Quick Start

1. Basic Training

# Train TVAE model on CNUH dataset
python src/main.py \
    --model tvae \
    --dataset CNUH \
    --train_path data/raw/cnuh_train.csv \
    --test_path data/raw/cnuh_test.csv \
    --window 16 \
    --epochs 100 \
    --batch_size 32

2. Train All Models

# Train TVAE and all baseline models
python src/main.py \
    --model all \
    --dataset CNUH \
    --train_path data/raw/cnuh_train.csv \
    --test_path data/raw/cnuh_test.csv \
    --window 16 \
    --epochs 100

3. Cross-Validation

# Run 5-fold cross-validation
python src/main.py \
    --model tvae \
    --dataset CNUH \
    --train_path data/raw/cnuh_train.csv \
    --cv_folds 5 \
    --cv_strategy stratified_kfold

🔧 Configuration

Using Configuration File

# experiments/config_experiments.yaml
model:
  latent_dim: 8
  encoder_lstm_layers: [100, 50, 25]
  beta: 1.0

training:
  epochs: 100
  batch_size: 32
  learning_rate: 0.001

evaluation:
  cv_folds: 5
  primary_metrics: ["auroc", "auprc", "f1", "kappa"]

# Use configuration file
python src/main.py \
    --config experiments/config_experiments.yaml \
    --train_path data/raw/cnuh_train.csv

📈 Results

Performance Comparison

Model             AUROC  AUPRC  F1 Score  Kappa  Late Alarm Rate
TVAE              0.973  0.887  0.856     0.812  0.124
RNN               0.945  0.823  0.798     0.756  0.189
BiLSTM+Attention  0.952  0.834  0.812     0.768  0.167
DCNN              0.938  0.815  0.785     0.742  0.201
FCNN              0.931  0.806  0.778     0.735  0.213
XGBM              0.927  0.798  0.771     0.728  0.225

Results on CNUH dataset with 16-hour window size

Key Findings

  1. TVAE outperforms all baselines across primary metrics
  2. Lower late alarm rate: 0.124 versus 0.167 for the best baseline (BiLSTM+Attention), a relative reduction of about 26%; the reduction relative to RNN (0.189) is about 34%
  3. Robust performance across different validation strategies
  4. Stable performance with limited data samples
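
The late alarm comparison can be checked directly against the performance table. The short calculation below uses only the reported rates. It gives a relative reduction of about 25.7% against BiLSTM+Attention, the baseline with the lowest late alarm rate, and about 34.4% against RNN, so the 34% figure corresponds to the RNN comparison. The helper name `relative_reduction` is illustrative.

```python
# Late alarm rates copied from the performance table (CNUH, 16-hour window)
late_alarm = {
    "TVAE": 0.124,
    "RNN": 0.189,
    "BiLSTM+Attention": 0.167,
    "DCNN": 0.201,
    "FCNN": 0.213,
    "XGBM": 0.225,
}

def relative_reduction(baseline, model):
    """Fractional reduction of `model` relative to `baseline` (a lower rate is better)."""
    return (baseline - model) / baseline

tvae = late_alarm["TVAE"]
for name, rate in late_alarm.items():
    if name != "TVAE":
        print(f"{name}: {relative_reduction(rate, tvae):.1%}")
```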

🧪 Advanced Usage

Custom Model Training

from src.models.tvae import build_tvae_model
from src.training.utils_train import train_tvae_model

# Build custom TVAE model
model = build_tvae_model(
    input_shape=(16, 25),  # 16-hour window, 25 features
    latent_dim=16,
    encoder_lstm_layers=[128, 64, 32],
    learning_rate=0.0005
)

# Train with custom parameters
history = train_tvae_model(
    model=model,
    X_train=X_train,
    y_train=y_train,
    epochs=200,
    batch_size=64
)

Hyperparameter Tuning

from src.models.tvae import build_tvae_model
from src.training.utils_train import hyperparameter_tuning

# Define parameter grid
param_grid = {
    'latent_dim': [8, 16, 32],
    'learning_rate': [0.001, 0.0005, 0.0001],
    'beta': [0.5, 1.0, 2.0]
}

# Perform hyperparameter tuning
results = hyperparameter_tuning(
    model_builder=lambda **params: build_tvae_model(**params),
    param_grid=param_grid,
    X_train=X_train,
    y_train=y_train,
    cv_folds=3
)

Custom Evaluation

from src.evaluation.evaluate_metrics import ModelEvaluator

# Initialize evaluator
evaluator = ModelEvaluator(
    primary_metrics=['auroc', 'auprc', 'f1'],
    threshold_optimization='youden'
)

# Calculate metrics
metrics = evaluator.calculate_metrics(y_test, y_pred_proba)

# Generate comprehensive report
report = evaluator.generate_report(
    y_true=y_test,
    y_pred_proba=y_pred_proba,
    model_name="Custom_TVAE"
)
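
The `threshold_optimization='youden'` option presumably refers to Youden's J statistic, J = TPR - FPR, maximised over candidate thresholds. How `ModelEvaluator` implements it is not shown here, so the following is only a plain-Python sketch of the standard definition, with a hypothetical function name and toy data.

```python
def youden_threshold(y_true, y_score):
    """Return (threshold, J) maximising J = TPR - FPR.

    A sample is predicted positive when its score is >= threshold.
    Ties are resolved in favour of the lowest threshold.
    """
    positives = sum(1 for y in y_true if y == 1)
    negatives = len(y_true) - positives
    if positives == 0 or negatives == 0:
        raise ValueError("both classes must be present")

    best_threshold, best_j = None, -1.0
    for threshold in sorted(set(y_score)):
        tp = sum(1 for y, s in zip(y_true, y_score) if y == 1 and s >= threshold)
        fp = sum(1 for y, s in zip(y_true, y_score) if y == 0 and s >= threshold)
        j = tp / positives - fp / negatives
        if j > best_j:
            best_threshold, best_j = threshold, j
    return best_threshold, best_j

# Toy scores: four negatives and four positives
y_true = [0, 0, 0, 1, 0, 1, 1, 1]
y_score = [0.05, 0.10, 0.30, 0.35, 0.40, 0.60, 0.80, 0.90]
threshold, j = youden_threshold(y_true, y_score)
print(threshold, j)
```

On this toy data the thresholds 0.35 and 0.60 both reach J = 0.75, and the sketch keeps the lower one, which favours sensitivity.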

📊 Visualization

ROC Curves

# Plot ROC curve with optimal thresholds
evaluator.plot_roc_curve(
    y_true=y_test,
    y_pred_proba=y_pred_proba,
    save_path="results/roc_curve.png"
)

t-SNE Visualization

# Visualize latent space
evaluator.plot_tsne(
    X=latent_features,
    y=y_test,
    save_path="results/tsne_visualization.png"
)

DEWS Score Distribution

# Plot DEWS score distribution
evaluator.plot_dews_scores(
    y_true=y_test,
    y_pred_proba=y_pred_proba,
    save_path="results/dews_distribution.png"
)

🧪 Validation Strategies

1. K-Fold Cross-Validation

python src/main.py \
    --model tvae \
    --cv_folds 5 \
    --cv_strategy stratified_kfold
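
For readers unfamiliar with the `stratified_kfold` strategy, the idea is to keep the class ratio roughly equal in every fold, which matters when positive cases are rare. The plain-Python sketch below illustrates the splitting logic only. It is not the repository's implementation, and the function name is hypothetical.

```python
import random
from collections import defaultdict

def stratified_kfold_indices(labels, n_splits=5, seed=42):
    """Yield (train_idx, test_idx) pairs with class ratios roughly preserved per fold."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    folds = [[] for _ in range(n_splits)]
    for indices in by_class.values():
        rng.shuffle(indices)
        # Deal each class round-robin so every fold receives a similar share.
        for position, idx in enumerate(indices):
            folds[position % n_splits].append(idx)

    for k in range(n_splits):
        test_idx = sorted(folds[k])
        train_idx = sorted(i for j, fold in enumerate(folds) if j != k for i in fold)
        yield train_idx, test_idx

# Imbalanced toy labels: 10% positive
labels = [0] * 90 + [1] * 10
for train_idx, test_idx in stratified_kfold_indices(labels, n_splits=5):
    positives = sum(labels[i] for i in test_idx)
    print(len(train_idx), len(test_idx), positives)
```

With 90 negatives and 10 positives, each of the five test folds holds 20 samples, 2 of them positive.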

2. Leave-One-Out Cross-Validation (LOOCV)

python src/main.py \
    --model tvae \
    --cv_strategy loocv

3. Variation Test Sensitivity Analysis (VTSA)

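# VTSA implementation
import numpy as np

def vtsa_analysis(model, X, y, n_iterations=100, noise_std=0.01, seed=42):
    """Re-evaluate the model on inputs perturbed with Gaussian noise."""
    rng = np.random.default_rng(seed)
    results = []
    for _ in range(n_iterations):
        # Add zero-mean Gaussian noise to the input
        X_noisy = X + rng.normal(0.0, noise_std, X.shape)
        y_pred = model.predict(X_noisy)
        # calculate_metrics is not defined in this snippet and must be supplied by the caller
        results.append(calculate_metrics(y, y_pred))
    return results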

🔬 Reproducibility

Environment Setup

# Create exact environment
conda env create -f environment.yml
conda activate tvae-rrs

# Fix the hash seed and request deterministic TensorFlow ops
export PYTHONHASHSEED=42
export TF_DETERMINISTIC_OPS=1
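
The environment variables above do not seed the random number generators themselves. A minimal sketch of in-process seeding follows, assuming the training code does not already do this when `--seed` is passed. The helper name `seed_everything` is hypothetical, and NumPy and TensorFlow are seeded only if they are installed.

```python
import os
import random

def seed_everything(seed=42):
    """Seed the Python, NumPy and TensorFlow RNGs when those libraries are available."""
    # PYTHONHASHSEED only affects hashing if it is set before the interpreter
    # starts; it is exported here so that child processes inherit it.
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import tensorflow as tf
        tf.random.set_seed(seed)
    except ImportError:
        pass

# Seeding twice with the same value reproduces the same draw
seed_everything(42)
first = random.random()
seed_everything(42)
second = random.random()
print(first == second)
```

Calling the helper once at the start of a script, before any model or data pipeline is built, keeps weight initialisation and shuffling repeatable across runs on the same hardware and library versions.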

Reproducing Results

# Reproduce paper results
python src/main.py \
    --model tvae \
    --dataset CNUH \
    --train_path data/raw/cnuh_train.csv \
    --test_path data/raw/cnuh_test.csv \
    --window 16 \
    --seed 42 \
    --epochs 100

📚 Citation

If you use this code in your research, please cite our paper:

@article{Nguyen2025,
  author = {Nguyen, Trong-nghia and Kim, Soo-hyung and Kho, Bo-gun and Do, Nhu-tai},
  doi = {10.1016/j.bspc.2024.106975},
  issn = {1746-8094},
  journal = {Biomedical Signal Processing and Control},
  keywords = {Clinical deterioration,Rapid response system,Deep learning,Clinical medical signal},
  number = {PC},
  pages = {106975},
  publisher = {Elsevier Ltd},
  title = {{Temporal variational autoencoder model for in-hospital clinical emergency prediction}},
  url = {https://doi.org/10.1016/j.bspc.2024.106975},
  volume = {100},
  year = {2025}
}

🤝 Contributing

We welcome contributions! Please see our Contributing Guidelines for details.

Development Setup

# Install development dependencies
pip install -e ".[dev]"

# Run tests
pytest tests/

# Format code
black src/
flake8 src/

📄 License

This project is licensed under the MIT License - see the LICENSE file for details.

πŸ™ Acknowledgments

  • Chonnam National University Hospitals for providing clinical data
  • University of Virginia for the public dataset
  • TensorFlow and Keras teams for the deep learning framework
  • The open-source community for various libraries and tools

📞 Contact

🔗 Related Work


⚠️ Disclaimer: This software is for research purposes only. It should not be used for clinical decision-making without proper validation and regulatory approval.
