diff --git a/templates/ray-train-deepspeed/README.md b/templates/ray-train-deepspeed/README.md new file mode 100644 index 000000000..630f4f68e --- /dev/null +++ b/templates/ray-train-deepspeed/README.md @@ -0,0 +1,436 @@ +# Getting Started with DeepSpeed ZeRO and Ray Train + +This template demonstrates how to combine DeepSpeed ZeRO with Ray Train to efficiently scale PyTorch training across GPUs and nodes while minimizing memory usage. + +DeepSpeed is a deep learning optimization library designed for scalability and efficiency. Its ZeRO (Zero Redundancy Optimizer) family partitions model states, gradients, and optimizer states across workers to drastically reduce memory consumption while preserving data-parallel semantics. + +This tutorial provides a step-by-step guide to integrating DeepSpeed ZeRO with Ray Train. Specifically, it covers: +- A hands-on example of fine-tuning a LLM +- Checkpoint saving and resuming with Ray Train +- Launching a distributed training job +- Configuring DeepSpeed for memory and performance (stages, mixed precision, CPU offload) + +Note: This template is optimized for the Anyscale platform. When running on open source Ray, you must configure a Ray cluster, install dependencies on all nodes, and set up storage for checkpoints. + +**Anyscale Specific Configuration** + +Note: This tutorial is optimized for the Anyscale platform. When running on open source Ray, additional configuration is required. For example, you will need to manually: + +- **Configure your Ray Cluster**: Set up your multi-node environment and manage resource allocation without Anyscale's automation. +- **Manage Dependencies**: Manually install and manage dependencies on each node. +- **Set Up Storage**: Configure your own distributed or shared storage system for model checkpointing. + +## Step by Step Guide + +In this example, we will demonstrate how to fine-tune an LLM with Ray Train and DeepSpeed in a multi-GPU, multi-node environment. +Before writing a Python script for fine-tuning, install the required dependencies: + +```bash +%%bash +pip install torch torchvision +pip install transformers datasets==3.6.0 trl +pip install deepspeed +``` + +### 1. Import Packages + +We start by importing the required libraries. These include Ray Train APIs for distributed training, PyTorch utilities for model and data handling, Transformers and Datasets from Hugging Face, and DeepSpeed. + +```python +import os +import tempfile +import uuid +import logging + +import argparse +from typing import Dict, Any + +import ray +import ray.train +import ray.train.torch +from ray.train.torch import TorchTrainer +from ray.train import ScalingConfig, RunConfig, Checkpoint + +import torch +from torch.utils.data import DataLoader +from transformers import AutoModelForCausalLM, AutoTokenizer +from datasets import load_dataset, DownloadConfig + +import deepspeed + +logger = logging.getLogger(__name__) +``` + + +### 2. Set up dataloader + +We now define a dataset and a dataloader. The function below: + +1. Downloads a tokenizer from the Hugging Face Hub (AutoTokenizer). +1. Loads a dataset using Hugging Face’s load_dataset. +1. Applies tokenization with padding and truncation using map. +1. Converts the dataset into a PyTorch DataLoader, which handles batching and shuffling. +1. Finally, use ray.train.torch.prepare_data_loader to make the dataloader distributed-ready. + + +```python +def setup_dataloader(model_name: str, dataset_name: str, seq_length: int, batch_size: int) -> DataLoader: + # (1) Download tokenizer + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + + # (2) Load dataset + dataset = load_dataset(dataset_name, split="train[:100%]") + + # (3) Apply tokenization + def tokenize_function(examples): + return tokenizer(examples['text'], padding='max_length', max_length=seq_length, truncation=True) + + tokenized_dataset = dataset.map(tokenize_function, batched=True, num_proc=1, keep_in_memory=True) + tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask']) + + # (4) Create DataLoader + data_loader = DataLoader( + tokenized_dataset, + batch_size=batch_size, + shuffle=True + ) + + # (5) Use prepare_data_loader for distributed training + return ray.train.torch.prepare_data_loader(data_loader) +``` + +**Making the dataloader distributed-ready with Ray:** + +When training across multiple GPUs, the most common strategy is data parallelism: +- Each GPU worker gets a shard of the dataset. +- Workers process their batches independently. +- After each step, gradients are synchronized across workers to keep model parameters aligned. + +Normally, you’d need to manually configure PyTorch’s DistributedSampler for this. Ray’s prepare_data_loader automates that setup: +- Ensures each worker only sees its own shard. +- Avoids overlapping samples across GPUs. +- Handles epoch boundaries automatically. +This makes distributed training easier, while still relying on familiar PyTorch APIs. + + +### 3. Model and optimizer initialization + +We now set up the model and optimizer. The function below: + +1. Downloads a pretrained model from the Hugging Face Hub (AutoModelForCausalLM). +1. Defines the optimizer (AdamW). +1. Wraps the model and optimizer with DeepSpeed’s initialize, which applies ZeRO optimizations and returns a DeepSpeedEngine. + + +```python +def setup_model_and_optimizer(model_name: str, learning_rate: float, ds_config: Dict[str, Any]) -> deepspeed.runtime.engine.DeepSpeedEngine: + # (1) Load pretrained model + model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True) + + # (2) Define optimizer + optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate) + + # (3) Initialize with DeepSpeed (distributed + memory optimizations) + ds_engine, _, _, _ = deepspeed.initialize( + model=model, + optimizer=optimizer, + config=ds_config, + ) + return ds_engine +``` + +**Making the model distributed-ready with Ray and DeepSpeed** + +In distributed training, every worker needs its own copy of the model and optimizer, but memory can quickly become a bottleneck. +DeepSpeed’s `initialize` always partitions the optimizer states across workers (ZeRO Stage 1). Depending on the chosen stage, it can also partition gradients (Stage 2) and model parameters/weights (Stage 3). This staged approach lets you balance memory savings with communication overhead while still applying additional optimizations for performance. We will describe these ZeRO stages in more detail [later in the tutorial](#deepspeed-zero-stages). + +## 4. Checkpointing and Loading + +Checkpointing is crucial for fault tolerance and for resuming training after interruptions. The functions below: + +1. Create a temporary directory for storing checkpoints. +1. Save the partitioned model and optimizer states with DeepSpeed’s `save_checkpoint`. +1. Synchronize all workers with `torch.distributed.barrier` to ensure every process finishes saving. +1. Report metrics and checkpoint location to Ray with `ray.train.report`. +1. Restore a previously saved checkpoint into the DeepSpeed engine using `load_checkpoint`. + + +```python +def report_metrics_and_save_checkpoint( + ds_engine: deepspeed.runtime.engine.DeepSpeedEngine, + metrics: Dict[str, Any] +) -> None: + # (1) Create temporary directory + with tempfile.TemporaryDirectory() as tmp: + tmp_epoch = os.path.join(tmp, "epoch") + os.makedirs(tmp_epoch, exist_ok=True) + + # (2) Save checkpoint (partitioned across workers) + ds_engine.save_checkpoint(tmp_epoch) + + # (3) Synchronize workers + torch.distributed.barrier() + + # (4) Report metrics and checkpoint to Ray + ray.train.report(metrics, checkpoint=Checkpoint.from_directory(tmp)) + + +def load_checkpoint(ds_engine: deepspeed.runtime.engine.DeepSpeedEngine, ckpt: ray.train.Checkpoint): + try: + # (5) Restore checkpoint into DeepSpeed engine + with ckpt.as_directory() as checkpoint_dir: + ds_engine.load_checkpoint(checkpoint_dir) + except Exception as e: + raise RuntimeError(f"Checkpoint loading failed: {e}") from e +``` + +**Making checkpoints distributed-ready with Ray and DeepSpeed** + +DeepSpeed saves model and optimizer states in a partitioned format, where each worker stores only its shard. This requires synchronization across processes, so it’s important to ensure that all workers reach the same checkpointing point before proceeding. We use `torch.distributed.barrier()` to guarantee that every worker finishes saving before moving on. + +Finally, `ray.train.report` both reports training metrics and saves the checkpoint to persistent storage, making it accessible for resuming training later. + + +### 5. Training Loop + +In Ray Train, we define a training loop function that orchestrates the entire process on each GPU worker. The function below: + +1. Restores training from a checkpoint if one is available. +1. Sets up the dataloader with setup_dataloader. +1. Initializes the model and optimizer with DeepSpeed. +1. Gets the device assigned to this worker. +1. Iterates through the specified number of epochs. +1. For multi-GPU training, ensures each worker sees a unique data shard each epoch. +1. For each batch: + - Moves inputs to the device. + - Runs the forward pass to compute loss. + - Logs the loss. +1. Performs backward pass and optimizer step with DeepSpeed. +1. Aggregates average loss and reports metrics, saving a checkpoint at the end of each epoch. + +```python +def train_loop(config: Dict[str, Any]) -> None: + + # (1) Load checkpoint if exists + ckpt = ray.train.get_checkpoint() + if ckpt: + load_checkpoint(ds_engine, ckpt) + + # (2) Set up dataloader + train_loader = setup_dataloader(config["model_name"], config["seq_length"], config["batch_size"]) + + # (3) Initialize model + optimizer with DeepSpeed + ds_engine = setup_model_and_optimizer(config["model_name"], config["learning_rate"], config["ds_config"]) + + # (4) Get device for this worker + device = ray.train.torch.get_device() + + for epoch in range(config["epochs"]): + # (6) Ensure unique shard per worker when using multiple GPUs + if ray.train.get_context().get_world_size() > 1: + train_loader.sampler.set_epoch(epoch) + + running_loss = 0.0 + num_batches = 0 + + # (7) Iterate over batches + for step, batch in enumerate(train_loader): + input_ids = batch['input_ids'].to(device) + attention_mask = batch['attention_mask'].to(device) + + # Forward pass + outputs = ds_engine( + input_ids=input_ids, + attention_mask=attention_mask, + labels=input_ids, + use_cache=False + ) + loss = outputs.loss + print(f"step {step} loss: {loss.item()}") + + # Backward pass + optimizer step + ds_engine.backward(loss) + ds_engine.step() + + running_loss += loss.item() + num_batches += 1 + + # (8) Report metrics + save checkpoint + report_metrics_and_save_checkpoint(ds_engine, {"loss": running_loss / num_batches, "epoch": epoch}) +``` + +**Coordinating distributed training with Ray and DeepSpeed** + +Ray launches this training loop on each GPU worker, while DeepSpeed handles partitioning and optimization under the hood. Each worker processes a unique shard of the data (data parallelism), computes local gradients, and synchronizes with others. +By combining Ray’s orchestration with DeepSpeed’s memory-efficient engine, you get distributed training that scales smoothly across multiple GPUs and nodes — with automatic checkpointing and metric reporting built in. + +## 6. Configure DeepSpeed and Launch Trainer + + +The final step is to configure parameters and launch the distributed training job with Ray’s TorchTrainer. The function below: +1. Parses command-line arguments for training and model settings. +1. Defines the Ray scaling configuration (e.g., number of workers, GPU usage). +1. Builds the DeepSpeed configuration dictionary (ds_config). +1. Prepares the training loop configuration with hyperparameters and model details. +1. Sets up the Ray RunConfig to manage storage and experiment metadata. +1. Creates a TorchTrainer that launches the training loop on multiple GPU workers. +1. Starts training with trainer.fit() and prints the result. + +```python +def main(): + # (1) Parse arguments + args = get_args() + print(args) + + # (2) Ray scaling configuration + scaling_config = ScalingConfig(num_workers=2, use_gpu=True) + + # (3) DeepSpeed configuration + ds_config = { + "train_micro_batch_size_per_gpu": args.batch_size, + "bf16": {"enabled": True}, + "grad_accum_dtype": "bf16", + "zero_optimization": { + "stage": args.zero_stage, + "overlap_comm": True, + "contiguous_gradients": True, + }, + "gradient_clipping": 1.0, + } + + # (4) Training loop configuration + train_loop_config = { + "epochs": args.num_epochs, + "learning_rate": args.learning_rate, + "batch_size": args.batch_size, + "ds_config": ds_config, + "model_name": args.model_name, + "seq_length": args.seq_length, + "dataset_name": args.dataset_name, + } + + # (5) Ray run configuration + run_config = RunConfig( + storage_path="/mnt/cluster_storage/", + name=f"deepspeed_sample_{uuid.uuid4().hex[:8]}", + ) + + # (6) Create trainer + trainer = TorchTrainer( + train_loop_per_worker=train_loop, + scaling_config=scaling_config, + train_loop_config=train_loop_config, + run_config=run_config, + ) + + # (7) Launch training + result = trainer.fit() + print(f"Training finished. Result: {result}") + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--model_name", type=str, default="MiniLLM/MiniPLM-Qwen-500M") + parser.add_argument("--dataset_name", type=str, default="ag_news") + parser.add_argument("--batch_size", type=int, default=1) + parser.add_argument("--num_epochs", type=int, default=1) + parser.add_argument("--seq_length", type=int, default=512) + parser.add_argument("--learning_rate", type=float, default=1e-6) + parser.add_argument("--zero_stage", type=int, default=3) + return parser.parse_args() + + +if __name__ == "__main__": + main() +``` + +The complete working script is available as `train.py`. + + +**Launching distributed training with Ray and DeepSpeed** + +Ray’s TorchTrainer automatically launches multiple workers (one per GPU) and runs the `train_loop` on each. The scaling configuration controls how many workers to start, while the run configuration handles logging, storage, and experiment tracking. + +DeepSpeed’s `ds_config` ensures that the right ZeRO stage and optimizations are applied inside each worker. Together, this setup makes it easy to scale from a single GPU to a multi-node cluster without changing your training loop code. + + +## Advanced Usage + +DeepSpeed has many other configuration options to tune performance and memory usage. +Here we introduce some of the most commonly used options. +Please refer to the [DeepSpeed documentation](https://www.deepspeed.ai/docs/config-json/) for more details. + + +### DeepSpeed ZeRO Stages + +DeepSpeed ZeRO has three stages, each providing different levels of memory optimization and performance trade-offs. + +- **Stage 1**: This stage focuses on optimizer state partitioning. It reduces memory usage by partitioning the optimizer states across data parallel workers. This is the least aggressive stage and is suitable for most models without significant changes. +- **Stage 2**: In addition to optimizer state partitioning, this stage also partitions the gradients. This further reduces memory usage but may introduce some communication overhead. It's a good choice for larger models that can benefit from additional memory savings. +- **Stage 3**: This is the most aggressive stage, which partitions both the optimizer states and the model parameters. It provides the highest memory savings but may require more careful tuning of the training process. This stage is recommended for very large models that cannot fit into the memory of a single GPU. + +The higher the stage, the more memory savings you get, but it may also introduce more communication overhead and complexity in training. +You can select the desired ZeRO stage by setting the `zero_stage` parameter in the DeepSpeed configuration dictionary passed to `deepspeed.initialize`. + +```python +ds_config = { + "zero_optimization": { + "stage": 2, # or 1 or 3 +... + }, +} +``` + + +### Mixed Precision Training + +Mixed precision training is a technique that uses both 16-bit and 32-bit floating-point types in a single network. This can lead to faster training times and reduced memory usage. DeepSpeed has built-in support for mixed precision training using either FP16 or BF16. + +To enable mixed precision training, you can set the `bf16` or `fp16` parameters in the DeepSpeed configuration dictionary. For example: + +```python +ds_config = { + "bf16": {"enabled": True}, # or "fp16": {"enabled": True} +} +``` + +Note that these options keep the clone of weights/gradients and optimizer states in 32-bit precision to maintain numerical stability. + + +### CPU Offloading + +DeepSpeed supports offloading model states and optimizer states to CPU memory. +Offloading these causes a certain amount of overhead due to data transfer between CPU and GPU, but it significantly reduces GPU memory usage, which can be beneficial when training very large models that do not fit into GPU memory. + +To enable CPU offloading, you can set the `offload` parameters in the DeepSpeed configuration dictionary. For example: + +```python +ds_config = { + "offload_param": { + "device": "cpu", + "pin_memory": True, + } +} +``` + +You can also offload only optimizer states similarly by using the `offload_optimizer` parameter. + +```python +ds_config = { + "offload_optimizer": { + "device": "cpu", + "pin_memory": True, + } +} +``` + +### Convert Checkpoint for Inference + +As the checkpoint of DeepSpeed ZeRO Stage 3 is partitioned across multiple GPUs, it cannot be directly used for inference. To convert a ZeRO Stage 3 checkpoint to a standard model checkpoint that can be loaded for inference, you can use `get_fp32_state_dict_from_zero_checkpoint` API. + +```python +from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint +# do the training and checkpoint saving +state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) +torch.save(state_dict, "model_fp32.pt") +``` diff --git a/templates/ray-train-deepspeed/ray_deepspeed_tutorial.ipynb b/templates/ray-train-deepspeed/ray_deepspeed_tutorial.ipynb new file mode 100644 index 000000000..7c30e1624 --- /dev/null +++ b/templates/ray-train-deepspeed/ray_deepspeed_tutorial.ipynb @@ -0,0 +1,592 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "\n", + "# Getting Started with DeepSpeed ZeRO and Ray Train (Notebook)\n", + "\n", + "This notebook walks through how to combine **DeepSpeed ZeRO** with **Ray Train** to efficiently scale PyTorch training across GPUs and nodes while minimizing memory usage.\n", + "\n", + "It includes:\n", + "- A hands-on example of fine-tuning an LLM\n", + "- Checkpoint saving and resuming with Ray Train\n", + "- Configuring ZeRO for memory and performance (stages, mixed precision, CPU offload)\n", + "- Launching a distributed training job\n", + "\n", + "> **Note**: This template is optimized for the Anyscale platform. When running on open-source Ray, you must configure a Ray cluster, install dependencies on all nodes, and set up storage for checkpoints.\n" + ] + }, + { + "cell_type": "markdown", + "id": "1", + "metadata": {}, + "source": [ + "\n", + "## Anyscale-Specific Configuration\n", + "\n", + "On Anyscale, most configuration is automated. When running on open-source Ray, you will need to manually:\n", + "- **Configure your Ray Cluster** (multi-node setup, resource allocation)\n", + "- **Manage Dependencies** (install on each node)\n", + "- **Set Up Storage** (shared or distributed storage for checkpoints)\n" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "\n", + "## Install Dependencies (if needed)\n", + "\n", + "Uncomment and run the cell below if your environment doesn't already have these packages installed.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "%%bash\n", + "pip install torch torchvision\n", + "pip install transformers datasets==3.6.0 trl\n", + "pip install deepspeed ray[train]" + ] + }, + { + "cell_type": "markdown", + "id": "4", + "metadata": {}, + "source": [ + "\n", + "## Configuration Constants\n", + "\n", + "We use simple constants instead of `argparse` so this notebook is easier to run. Adjust these as needed.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.environ[\"RAY_TRAIN_V2_ENABLED\"] = \"1\" # Ensure Ray Train v2 APIs\n", + "\n", + "# ---- Training constants (edit these) ----\n", + "MODEL_NAME = \"MiniLLM/MiniPLM-Qwen-500M\"\n", + "DATASET_NAME = \"ag_news\"\n", + "BATCH_SIZE = 1\n", + "NUM_EPOCHS = 1\n", + "SEQ_LENGTH = 512\n", + "LEARNING_RATE = 1e-6\n", + "ZERO_STAGE = 3\n", + "\n", + "# Ray scaling settings\n", + "NUM_WORKERS = 2\n", + "USE_GPU = True\n", + "\n", + "# Storage\n", + "STORAGE_PATH = \"/mnt/cluster_storage/\"\n", + "EXPERIMENT_PREFIX = \"deepspeed_sample\"" + ] + }, + { + "cell_type": "markdown", + "id": "6", + "metadata": {}, + "source": [ + "\n", + "### 1. Import Packages\n", + "We import **Ray Train** for distributed orchestration, **PyTorch** for modeling, **Hugging Face Transformers/Datasets** for models and data, and **DeepSpeed** for ZeRO-based optimization.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "import uuid\n", + "import logging\n", + "from typing import Dict, Any\n", + "import tempfile\n", + "\n", + "import ray\n", + "import ray.train\n", + "import ray.train.torch\n", + "from ray.train.torch import TorchTrainer\n", + "from ray.train import ScalingConfig, RunConfig, Checkpoint\n", + "\n", + "import torch\n", + "from torch.utils.data import DataLoader\n", + "from transformers import AutoModelForCausalLM, AutoTokenizer\n", + "from datasets import load_dataset, DownloadConfig\n", + "\n", + "import deepspeed\n", + "\n", + "logger = logging.getLogger(__name__)" + ] + }, + { + "cell_type": "markdown", + "id": "8", + "metadata": {}, + "source": [ + "\n", + "### 2. Set Up the Dataloader\n", + "\n", + "The code below:\n", + "\n", + "1. Downloads a tokenizer from the Hugging Face Hub (`AutoTokenizer`). \n", + "2. Loads the `ag_news` dataset using Hugging Face’s `load_dataset`. \n", + "3. Applies tokenization with padding and truncation via `map`. \n", + "4. Converts the dataset into a PyTorch `DataLoader`, which handles batching and shuffling. \n", + "5. Finally, use `ray.train.torch.prepare_data_loader` to make the dataloader distributed-ready.\n", + "\n", + "Here we use only 1% of the dataset for quick testing. Adjust as needed." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "def setup_dataloader(model_name: str, dataset_name: str, seq_length: int, batch_size: int) -> DataLoader:\n", + " # (1) Get tokenizer\n", + " tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n", + "\n", + " # (2) Load dataset\n", + " dataset = load_dataset(dataset_name, split=\"train[:1%]\", download_config=DownloadConfig(disable_tqdm=True))\n", + "\n", + " # (3) Tokenize\n", + " def tokenize_function(examples):\n", + " return tokenizer(examples['text'], padding='max_length', max_length=seq_length, truncation=True)\n", + " tokenized_dataset = dataset.map(tokenize_function, batched=True, num_proc=1, keep_in_memory=True)\n", + " tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])\n", + "\n", + " # (4) Create DataLoader\n", + " data_loader = DataLoader(tokenized_dataset, batch_size=batch_size, shuffle=True)\n", + "\n", + " # (5) Use prepare_data_loader for distributed training\n", + " return ray.train.torch.prepare_data_loader(data_loader)" + ] + }, + { + "cell_type": "markdown", + "id": "10", + "metadata": {}, + "source": [ + "The following code demonstrates how to use the tokenizer to encode a sample string. \n", + "- `AutoTokenizer.from_pretrained` downloads and configures the tokenizer for your model.\n", + "- You can encode any text string and inspect the resulting token IDs and attention mask." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "# Example usage of get_tokenizer\n", + "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)\n", + "sample_text = \"Ray Train and DeepSpeed make distributed training easy!\"\n", + "encoded = tokenizer(sample_text, padding='max_length', max_length=32, truncation=True)\n", + "print(encoded)" + ] + }, + { + "cell_type": "markdown", + "id": "14", + "metadata": {}, + "source": [ + "\n", + "**Making the dataloader distributed-ready with Ray** \n", + "In **data parallelism**, each GPU worker trains on a unique shard of the dataset while holding its own copy of the model; gradients are synchronized after each step. \n", + "Ray’s `prepare_data_loader` wraps PyTorch’s `DataLoader` and automatically applies a `DistributedSampler`, ensuring workers see disjoint data, splits are balanced, and epoch boundaries are handled correctly." + ] + }, + { + "cell_type": "markdown", + "id": "15", + "metadata": {}, + "source": [ + "\n", + "### 3. Initialize Model and Optimizer\n", + "\n", + "The function below:\n", + "\n", + "1. Loads a pretrained model from the Hugging Face Hub (`AutoModelForCausalLM`). \n", + "2. Defines the optimizer (`AdamW`). \n", + "3. Initializes DeepSpeed with ZeRO options and returns a `DeepSpeedEngine`.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "def setup_model_and_optimizer(model_name: str, learning_rate: float, ds_config: Dict[str, Any]) -> deepspeed.runtime.engine.DeepSpeedEngine:\n", + " # (1) Load pretrained model\n", + " model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)\n", + "\n", + " # (2) Define optimizer\n", + " optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)\n", + "\n", + " # (3) Initialize with DeepSpeed (distributed + memory optimizations)\n", + " ds_engine, _, _, _ = deepspeed.initialize(model=model, optimizer=optimizer, config=ds_config)\n", + " return ds_engine\n" + ] + }, + { + "cell_type": "markdown", + "id": "17", + "metadata": {}, + "source": [ + "\n", + "**Making the model distributed-ready with Ray and DeepSpeed** \n", + "DeepSpeed’s `initialize` always partitions **optimizer states** (ZeRO Stage 1). Depending on the chosen stage, it can also partition **gradients** (Stage 2) and **model parameters/weights** (Stage 3). This staged approach balances memory savings and communication overhead, and we’ll describe these stages in more detail [later in the tutorial](#deepspeed-zero-stages)." + ] + }, + { + "cell_type": "markdown", + "id": "18", + "metadata": {}, + "source": [ + "\n", + "### 4. Checkpointing and Loading\n", + "\n", + "\n", + "Checkpointing is crucial for fault tolerance and for resuming training after interruptions. The functions below:\n", + "\n", + "1. Create a temporary directory for storing checkpoints.\n", + "1. Save the partitioned model and optimizer states with DeepSpeed’s `save_checkpoint`.\n", + "1. Synchronize all workers with `torch.distributed.barrier` to ensure every process finishes saving.\n", + "1. Report metrics and checkpoint location to Ray with `ray.train.report`.\n", + "1. Restore a previously saved checkpoint into the DeepSpeed engine using `load_checkpoint`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "19", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "def report_metrics_and_save_checkpoint(\n", + " ds_engine: deepspeed.runtime.engine.DeepSpeedEngine,\n", + " metrics: Dict[str, Any]\n", + ") -> None:\n", + " \"\"\"Save worker checkpoints and report metrics to Ray.\n", + " Each rank writes its shard to a temp directory so Ray bundles all of them.\n", + " \"\"\"\n", + " ctx = ray.train.get_context()\n", + " epoch_value = metrics[\"epoch\"]\n", + "\n", + " with tempfile.TemporaryDirectory() as tmp_dir:\n", + " checkpoint_dir = os.path.join(tmp_dir, \"checkpoint\")\n", + " os.makedirs(checkpoint_dir, exist_ok=True)\n", + "\n", + " ds_engine.save_checkpoint(checkpoint_dir)\n", + "\n", + " epoch_file = os.path.join(checkpoint_dir, \"epoch.txt\")\n", + " with open(epoch_file, \"w\", encoding=\"utf-8\") as f:\n", + " f.write(str(epoch_value))\n", + "\n", + " checkpoint = Checkpoint.from_directory(tmp_dir)\n", + " ray.train.report(metrics, checkpoint=checkpoint)\n", + "\n", + " if ctx.get_world_rank() == 0:\n", + " experiment_name = ctx.get_experiment_name()\n", + " print(\n", + " f\"Checkpoint saved successfully for experiment {experiment_name} at {checkpoint_dir}. Metrics: {metrics}\"\n", + " )\n", + "\n", + "\n", + "def load_checkpoint(ds_engine: deepspeed.runtime.engine.DeepSpeedEngine, ckpt: ray.train.Checkpoint) -> int:\n", + " \"\"\"Restore DeepSpeed state and determine next epoch.\"\"\"\n", + " next_epoch = 0\n", + " try:\n", + " with ckpt.as_directory() as checkpoint_dir:\n", + " print(f\"Loading checkpoint from {checkpoint_dir}\")\n", + " epoch_dir = os.path.join(checkpoint_dir, \"checkpoint\")\n", + " if not os.path.isdir(epoch_dir):\n", + " epoch_dir = checkpoint_dir\n", + "\n", + " ds_engine.load_checkpoint(epoch_dir)\n", + "\n", + " epoch_file = os.path.join(epoch_dir, \"epoch.txt\")\n", + " if os.path.isfile(epoch_file):\n", + " with open(epoch_file, \"r\", encoding=\"utf-8\") as f:\n", + " last_epoch = int(f.read().strip())\n", + " next_epoch = last_epoch + 1\n", + "\n", + " if torch.distributed.is_available() and torch.distributed.is_initialized():\n", + " torch.distributed.barrier()\n", + " except Exception as e:\n", + " raise RuntimeError(f\"Checkpoint loading failed: {e}\") from e\n", + " return next_epoch\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "20", + "metadata": {}, + "source": [ + "\n", + "**Making checkpoints distributed-ready with Ray and DeepSpeed** \n", + "DeepSpeed saves model and optimizer states in a **partitioned format**, where each worker stores only its shard. This requires synchronization across processes, so all workers must reach the same checkpointing point before proceeding. We use `torch.distributed.barrier()` to ensure that every worker finishes saving before moving on. \n", + "Finally, `ray.train.report` both reports training metrics and saves the checkpoint to persistent storage, making it accessible for resuming training later.\n" + ] + }, + { + "cell_type": "markdown", + "id": "21", + "metadata": {}, + "source": [ + "\n", + "### 5. Training Iteration\n", + "\n", + "In Ray Train, we define a training loop function that orchestrates the entire process on each GPU worker. The function below:\n", + "\n", + "1. Initializes the model and optimizer with DeepSpeed.\n", + "1. Restores training from a checkpoint if one is available.\n", + "1. Sets up the dataloader with setup_dataloader.\n", + "1. Gets the device assigned to this worker.\n", + "1. Iterates through the specified number of epochs.\n", + "1. For multi-GPU training, ensures each worker sees a unique data shard each epoch.\n", + "1. For each batch:\n", + " - Moves inputs to the device.\n", + " - Runs the forward pass to compute loss.\n", + " - Logs the loss.\n", + "1. Performs backward pass and optimizer step with DeepSpeed.\n", + "1. Aggregates average loss and reports metrics, saving a checkpoint at the end of each epoch." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "22", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "def train_loop(config: Dict[str, Any]) -> None:\n", + " # (1) Initialize model + optimizer with DeepSpeed\n", + " ds_engine = setup_model_and_optimizer(config[\"model_name\"], config[\"learning_rate\"], config[\"ds_config\"])\n", + "\n", + " # (2) Load checkpoint if exists\n", + " ckpt = ray.train.get_checkpoint()\n", + " start_epoch = 0\n", + " if ckpt:\n", + " start_epoch = load_checkpoint(ds_engine, ckpt)\n", + "\n", + " # (3) Set up dataloader\n", + " train_loader = setup_dataloader(config[\"model_name\"], config[\"dataset_name\"], config[\"seq_length\"], config[\"batch_size\"])\n", + " total_steps = len(train_loader) * config[\"epochs\"]\n", + "\n", + " # (4) Get device for this worker\n", + " device = ray.train.torch.get_device()\n", + "\n", + " for epoch in range(start_epoch, config[\"epochs\"]):\n", + " # (6) Ensure unique shard per worker when using multiple GPUs\n", + " if ray.train.get_context().get_world_size() > 1 and hasattr(train_loader, \"sampler\"):\n", + " sampler = getattr(train_loader, \"sampler\", None)\n", + " if sampler and hasattr(sampler, \"set_epoch\"):\n", + " sampler.set_epoch(epoch)\n", + "\n", + " running_loss = 0.0\n", + " num_batches = 0\n", + "\n", + " # (7) Iterate over batches\n", + " for step, batch in enumerate(train_loader):\n", + " input_ids = batch['input_ids'].to(device)\n", + " attention_mask = batch['attention_mask'].to(device)\n", + "\n", + " # Forward pass\n", + " outputs = ds_engine(\n", + " input_ids=input_ids,\n", + " attention_mask=attention_mask,\n", + " labels=input_ids,\n", + " use_cache=False\n", + " )\n", + " loss = outputs.loss\n", + " print(f\"Epoch: {epoch} Step: {step + 1}/{total_steps} Loss: {loss.item()}\")\n", + "\n", + " # Backward pass + optimizer step\n", + " ds_engine.backward(loss)\n", + " ds_engine.step()\n", + "\n", + " running_loss += loss.item()\n", + " num_batches += 1\n", + "\n", + " # Stop early in the tutorial so runs finish quickly\n", + " if step + 1 >= 30:\n", + " print(\"Stopping early at 30 steps for the tutorial\")\n", + " break\n", + "\n", + " # (8) Report metrics + save checkpoint\n", + " report_metrics_and_save_checkpoint(ds_engine, {\"loss\": running_loss / num_batches, \"epoch\": epoch})\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "23", + "metadata": {}, + "source": [ + "\n", + "**Coordinating distributed training with Ray and DeepSpeed** \n", + "Ray launches this `train_loop` on each worker, while DeepSpeed manages partitioning and memory optimizations. With **data parallelism**, each worker processes a unique shard of data, computes gradients locally, and participates in synchronization so parameters stay in sync.\n" + ] + }, + { + "cell_type": "markdown", + "id": "24", + "metadata": {}, + "source": [ + "\n", + "### 6. Configure DeepSpeed and Launch Trainer\n", + "\n", + "The final step is to configure parameters and launch the distributed training job with Ray’s `TorchTrainer`. The function below:\n", + "1. Parses command-line arguments for training and model settings.\n", + "1. Defines the Ray scaling configuration (e.g., number of workers, GPU usage).\n", + "1. Builds the DeepSpeed configuration dictionary (ds_config).\n", + "1. Prepares the training loop configuration with hyperparameters and model details.\n", + "1. Sets up the Ray RunConfig to manage storage and experiment metadata. Here we set a random experiment name, but you can specify the name of a previous experiment to load the checkpoint.\n", + "1. Creates a TorchTrainer that launches the training loop on multiple GPU workers.\n", + "1. Starts training with trainer.fit() and prints the result." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "25", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# Ray scaling configuration\n", + "scaling_config = ScalingConfig(num_workers=NUM_WORKERS, use_gpu=USE_GPU)\n", + "\n", + "# DeepSpeed configuration\n", + "ds_config = {\n", + " \"train_micro_batch_size_per_gpu\": BATCH_SIZE,\n", + " \"bf16\": {\"enabled\": True},\n", + " \"grad_accum_dtype\": \"bf16\",\n", + " \"zero_optimization\": {\n", + " \"stage\": ZERO_STAGE,\n", + " \"overlap_comm\": True,\n", + " \"contiguous_gradients\": True,\n", + " },\n", + " \"gradient_clipping\": 1.0,\n", + "}\n", + "\n", + "# Training loop configuration\n", + "train_loop_config = {\n", + " \"epochs\": NUM_EPOCHS,\n", + " \"learning_rate\": LEARNING_RATE,\n", + " \"batch_size\": BATCH_SIZE,\n", + " \"ds_config\": ds_config,\n", + " \"model_name\": MODEL_NAME,\n", + " \"dataset_name\": DATASET_NAME,\n", + " \"seq_length\": SEQ_LENGTH,\n", + "}\n", + "\n", + "# Ray run configuration\n", + "run_config = RunConfig(\n", + " storage_path=STORAGE_PATH,\n", + " # Set the name of the previous experiment when resuming from a checkpoint\n", + " name=f\"{EXPERIMENT_PREFIX}_{uuid.uuid4().hex[:8]}\",\n", + ")\n", + "\n", + "# Create and launch the trainer\n", + "trainer = TorchTrainer(\n", + " train_loop_per_worker=train_loop,\n", + " scaling_config=scaling_config,\n", + " train_loop_config=train_loop_config,\n", + " run_config=run_config,\n", + ")\n", + "\n", + "# To actually run training, execute the following:\n", + "result = trainer.fit()\n", + "print(f\"Training finished. Result: {result}\")" + ] + }, + { + "cell_type": "markdown", + "id": "26", + "metadata": {}, + "source": [ + "**Launching distributed training with Ray and DeepSpeed**\n", + "\n", + "Ray’s TorchTrainer automatically launches multiple workers (one per GPU) and runs the `train_loop` on each. The scaling configuration controls how many workers to start, while the run configuration handles logging, storage, and experiment tracking.\n", + "\n", + "DeepSpeed’s `ds_config` ensures that the right ZeRO stage and optimizations are applied inside each worker. Together, this setup makes it easy to scale from a single GPU to a multi-node cluster without changing your training loop code.\n", + "\n", + "\n", + "## Advanced Configurations\n", + "\n", + "DeepSpeed has many other configuration options to tune performance and memory usage.\n", + "Here we introduce some of the most commonly used options.\n", + "Please refer to the [DeepSpeed documentation](https://www.deepspeed.ai/docs/config-json/) for more details.\n", + "\n", + "### DeepSpeed ZeRO Stages\n", + "- **Stage 1**: Partitions optimizer states (always on when using ZeRO). \n", + "- **Stage 2**: Additionally partitions gradients. \n", + "- **Stage 3**: Additionally partitions model parameters/weights.\n", + "\n", + "The higher the stage, the more memory savings you get, but it may also introduce more communication overhead and complexity in training.\n", + "You can select the stage via `ds_config[\"zero_optimization\"][\"stage\"]`. See the DeepSpeed docs for more details.\n", + "\n", + "```python\n", + "ds_config = {\n", + " \"zero_optimization\": {\n", + " \"stage\": 2, # or 1 or 3\n", + "...\n", + " },\n", + "}\n", + "```\n", + "\n", + "### Mixed Precision\n", + "Enable BF16 or FP16:\n", + "```python\n", + "ds_config = {\n", + " \"bf16\": {\"enabled\": True}, # or \"fp16\": {\"enabled\": True}\n", + "}\n", + "```\n", + "\n", + "### CPU Offloading\n", + "Reduce GPU memory pressure by offloading to CPU (at the cost of PCIe traffic):\n", + "```python\n", + "ds_config = {\n", + " \"offload_param\": {\"device\": \"cpu\", \"pin_memory\": True},\n", + " # or\n", + " \"offload_optimizer\": {\"device\": \"cpu\", \"pin_memory\": True},\n", + "}\n", + "```\n" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/templates/ray-train-deepspeed/train.py b/templates/ray-train-deepspeed/train.py new file mode 100644 index 000000000..9bcc9a435 --- /dev/null +++ b/templates/ray-train-deepspeed/train.py @@ -0,0 +1,250 @@ +import os +import tempfile +import uuid +import logging + +import argparse +from typing import Dict, Any + +os.environ["RAY_TRAIN_V2_ENABLED"] = "1" + +import ray +import ray.train +import ray.train.torch +from ray.train.torch import TorchTrainer +from ray.train import ScalingConfig, RunConfig, Checkpoint + +import torch +from torch.utils.data import DataLoader +from transformers import AutoModelForCausalLM, AutoTokenizer +from datasets import load_dataset, DownloadConfig + +import deepspeed + + +logger = logging.getLogger(__name__) + + +def log_rank0(message: str) -> None: + if ray.train.get_context().get_world_rank() == 0: + logger.info(message) + + +def get_tokenizer(model_name: str, trust_remote_code: bool = True) -> Any: + """ + Load and configure the tokenizer for the given model. + + Args: + model_name: Name of the model to load tokenizer for + trust_remote_code: Whether to trust remote code + + Returns: + Configured tokenizer + """ + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code) + + # Set pad token if not already set + if tokenizer.pad_token is None: + if tokenizer.eos_token is not None: + tokenizer.pad_token = tokenizer.eos_token + else: + # Fallback for models without eos_token + tokenizer.pad_token = tokenizer.convert_ids_to_tokens(2) + + return tokenizer + + +def setup_dataloader(model_name: str, dataset_name: str, seq_length: int, batch_size: int) -> DataLoader: + tokenizer = get_tokenizer(model_name, trust_remote_code=True) + + dataset = load_dataset(dataset_name, split=f"train[:1%]", download_config=DownloadConfig(disable_tqdm=True)) + + def tokenize_function(examples): + return tokenizer(examples['text'], padding='max_length', max_length=seq_length, truncation=True) + + tokenized_dataset = dataset.map(tokenize_function, batched=True, num_proc=1, keep_in_memory=True) + tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask']) + data_loader = DataLoader( + tokenized_dataset, + batch_size=batch_size, + shuffle=True + ) + + return ray.train.torch.prepare_data_loader(data_loader) + + +def setup_model_and_optimizer(model_name: str, learning_rate: float, ds_config: Dict[str, Any]) -> deepspeed.runtime.engine.DeepSpeedEngine: + model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True) + log_rank0(f"Model loaded: {model_name} (#parameters: {sum(p.numel() for p in model.parameters())})") + + optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate) + ds_engine, optimizer, _, _ = deepspeed.initialize( + model=model, + optimizer=optimizer, + config=ds_config, + ) + return ds_engine + + +def report_metrics_and_save_checkpoint( + ds_engine: deepspeed.runtime.engine.DeepSpeedEngine, + metrics: Dict[str, Any] +) -> None: + ctx = ray.train.get_context() + epoch_value = metrics["epoch"] + + with tempfile.TemporaryDirectory() as tmp_dir: + checkpoint_dir = os.path.join(tmp_dir, "checkpoint") + os.makedirs(checkpoint_dir, exist_ok=True) + + ds_engine.save_checkpoint(checkpoint_dir) + + epoch_file = os.path.join(checkpoint_dir, "epoch.txt") + with open(epoch_file, "w", encoding="utf-8") as f: + f.write(str(epoch_value)) + + checkpoint = Checkpoint.from_directory(tmp_dir) + ray.train.report(metrics, checkpoint=checkpoint) + + if ctx.get_world_rank() == 0: + experiment_name = ctx.get_experiment_name() + log_rank0( + f"Checkpoint saved successfully for experiment {experiment_name} at {checkpoint_dir}. Metrics: {metrics}" + ) + + +def load_checkpoint(ds_engine: deepspeed.runtime.engine.DeepSpeedEngine, ckpt: ray.train.Checkpoint) -> int: + next_epoch = 0 + try: + with ckpt.as_directory() as checkpoint_dir: + log_rank0(f"Loading checkpoint from {checkpoint_dir}") + epoch_dir = os.path.join(checkpoint_dir, "checkpoint") + if not os.path.isdir(epoch_dir): + epoch_dir = checkpoint_dir + + ds_engine.load_checkpoint(epoch_dir) + + epoch_file = os.path.join(epoch_dir, "epoch.txt") + assert os.path.isfile(epoch_file), f"Epoch file not found in checkpoint: {epoch_file}" + with open(epoch_file, "r", encoding="utf-8") as f: + last_epoch = int(f.read().strip()) + next_epoch = last_epoch + 1 + + torch.distributed.barrier() + log_rank0("Successfully loaded distributed checkpoint") + except Exception as e: + logger.error(f"Failed to load checkpoint: {e}") + raise RuntimeError(f"Checkpoint loading failed: {e}") from e + return next_epoch + + +def train_loop(config: Dict[str, Any]) -> None: + + ds_engine = setup_model_and_optimizer(config["model_name"], config["learning_rate"], config["ds_config"]) + + # Load checkpoint if exists + ckpt = ray.train.get_checkpoint() + start_epoch = 0 + if ckpt: + start_epoch = load_checkpoint(ds_engine, ckpt) + + if start_epoch > 0: + log_rank0(f"Resuming training from epoch {start_epoch}") + + train_loader = setup_dataloader(config["model_name"], config["dataset_name"], config["seq_length"], config["batch_size"]) + total_steps = len(train_loader) * config["epochs"] + device = ray.train.torch.get_device() + + for epoch in range(start_epoch, config["epochs"]): + if ray.train.get_context().get_world_size() > 1: + train_loader.sampler.set_epoch(epoch) + + running_loss = 0.0 + num_batches = 0 + for step, batch in enumerate(train_loader): + input_ids = batch['input_ids'].to(device) + attention_mask = batch['attention_mask'].to(device) + outputs = ds_engine(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids, use_cache=False) + loss = outputs.loss + log_rank0(f"Epoch: {epoch} Step: {step + 1}/{total_steps} Loss: {loss.item()}") + + ds_engine.backward(loss) + ds_engine.step() + + running_loss += loss.item() + num_batches += 1 + + if config["debug_steps"] > 0 and step + 1 >= config["debug_steps"]: + log_rank0(f"Debug steps finished. Stopping epoch {epoch}.") + break + + report_metrics_and_save_checkpoint( + ds_engine, + {"loss": running_loss / num_batches, "epoch": epoch}, + ) + + +def main(): + args = get_args() + print(args) + + scaling_config = ScalingConfig(num_workers=2, use_gpu=True) + + ds_config = { + "train_micro_batch_size_per_gpu": args.batch_size, + "bf16": {"enabled": True}, + "grad_accum_dtype": "bf16", + "zero_optimization": { + "stage": args.zero_stage, + "overlap_comm": True, + "contiguous_gradients": True, + }, + "gradient_clipping": 1.0, + } + + train_loop_config = { + "epochs": args.num_epochs, + "learning_rate": args.learning_rate, + "batch_size": args.batch_size, + "ds_config": ds_config, + "model_name": args.model_name, + "seq_length": args.seq_length, + "dataset_name": args.dataset_name, + "debug_steps": args.debug_steps, + } + + name = f"deepspeed_sample_{uuid.uuid4().hex[:8]}" if args.resume_experiment is None else args.resume_experiment + print(f"Experiment name: {name}") + run_config = RunConfig( + storage_path="/mnt/cluster_storage/", + name=name, + ) + + trainer = TorchTrainer( + train_loop_per_worker=train_loop, + scaling_config=scaling_config, + train_loop_config=train_loop_config, + run_config=run_config, + ) + + result = trainer.fit() + print(f"Training finished. Result: {result}") + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--model_name", type=str, default="MiniLLM/MiniPLM-Qwen-500M") + parser.add_argument("--dataset_name", type=str, default="ag_news") + parser.add_argument("--batch_size", type=int, default=1) + parser.add_argument("--num_epochs", type=int, default=1) + parser.add_argument("--seq_length", type=int, default=512) + parser.add_argument("--learning_rate", type=float, default=1e-6) + parser.add_argument("--zero_stage", type=int, default=3) + parser.add_argument("--resume_experiment", type=str, default=None, help="Path to the experiment to resume from") + parser.add_argument("--debug_steps", type=int, default=0) + + return parser.parse_args() + + +if __name__ == "__main__": + main()