3 changes: 2 additions & 1 deletion nemo/lightning/pytorch/callbacks/__init__.py
@@ -27,7 +27,7 @@
from nemo.lightning.pytorch.callbacks.preemption import PreemptionCallback
from nemo.lightning.pytorch.callbacks.progress_bar import MegatronProgressBar
from nemo.lightning.pytorch.callbacks.progress_printer import ProgressPrinter
from nemo.lightning.pytorch.callbacks.pytorch_profiler import PytorchProfilerCallback
from nemo.lightning.pytorch.callbacks.pytorch_profiler import PytorchProfilerCallback, PyTorchProfilerCallback
from nemo.lightning.pytorch.callbacks.runtime_estimator import RuntimeEstimator
from nemo.lightning.pytorch.callbacks.speed_monitor import SpeedMonitor

@@ -38,6 +38,7 @@
"PEFT",
"NsysCallback",
"PytorchProfilerCallback",
"PyTorchProfilerCallback",
"MegatronProgressBar",
"ProgressPrinter",
"PreemptionCallback",
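For reference, a quick sketch of what the re-export gives callers (hypothetical snippet, not part of the diff): both names resolve to the same class, so existing imports keep working.

from nemo.lightning.pytorch.callbacks import PytorchProfilerCallback, PyTorchProfilerCallback

# The new name is an alias, not a separate class.
assert PyTorchProfilerCallback is PytorchProfilerCallback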
68 changes: 31 additions & 37 deletions nemo/lightning/pytorch/callbacks/pytorch_profiler.py
@@ -33,7 +33,7 @@ def trace_handler(prof, chakra_device_trace_path):
chakra_device_trace_path: The path where the trace file will be saved.
"""
rank = get_rank()
trace_file = chakra_device_trace_path / f"rank-{rank}.json"
trace_file = chakra_device_trace_path / f"rank-{rank}.json.gz"
prof.export_chrome_trace(str(trace_file))
logging.info(f"Kineto trace saved: {trace_file}")

@@ -52,16 +52,17 @@ class PytorchProfilerCallback(Callback, IOMixin):
warmup_steps (int): Number of warmup steps before profiling starts.
active_steps (int): Number of active profiling steps.
trace_dir (str): Directory where traces will be saved.
collect_et (bool): Whether to also collect a PyTorch Execution Trace (host-side trace).
profiler_kwargs (dict, optional): Additional keyword arguments passed to torch.profiler.profile.
"""

def __init__(
self,
start_step: int,
end_step: int,
warmup_steps: int = 0,
active_steps: int = 1,
warmup_steps: int = 2,
trace_dir: str = None,
collect_et: bool = False,
profiler_kwargs: Optional[Dict[str, Any]] = None,
):
if trace_dir is None:
@@ -79,23 +80,29 @@ def __init__(
self.start_step = start_step
self.end_step = end_step
self.warmup_steps = warmup_steps
self.active_steps = active_steps
self.active_steps = max(self.end_step - self.start_step, 1)

wait_steps = max(self.start_step - self.warmup_steps, 0)

self.trace_dir = Path(trace_dir)
self.chakra_host_trace_path = self.trace_dir / "host"
self.chakra_device_trace_path = self.trace_dir / "device"
self.chakra_device_trace_path = self.trace_dir / "torch_profiler"

self.chakra_host_trace_path.mkdir(parents=True, exist_ok=True)
self.chakra_device_trace_path.mkdir(parents=True, exist_ok=True)

self.trace_observer = torch.profiler.ExecutionTraceObserver()
self.trace_observer = None
if collect_et:
self.chakra_host_trace_path.mkdir(parents=True, exist_ok=True)
self.trace_observer = torch.profiler.ExecutionTraceObserver()

base_kwargs = {
"activities": [
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
"schedule": torch.profiler.schedule(wait=0, warmup=self.warmup_steps, active=self.active_steps),
"schedule": torch.profiler.schedule(
wait=wait_steps, warmup=self.warmup_steps, active=self.active_steps, repeat=1
),
"on_trace_ready": lambda prof: trace_handler(prof, self.chakra_device_trace_path),
"execution_trace_observer": self.trace_observer,
}
@@ -105,55 +112,42 @@ def __init__(
base_kwargs.update(profiler_kwargs)

self.profiler = torch.profiler.profile(**base_kwargs)
self.is_profiling = False

logging.info(
"Chakra profiling initialized:\n"
"PyTorch profiling initialized:\n"
f" - Start Step: {self.start_step}\n"
f" - End Step: {self.end_step}\n"
f" - Warmup Steps: {self.warmup_steps}\n"
f" - Active Steps: {self.active_steps}\n"
f" - Trace Directory: {self.trace_dir}\n"
f" - Collect Execution Trace: {collect_et}\n"
f" - Extra profiler kwargs: {profiler_kwargs or {}}"
)

def on_train_batch_start(self, trainer, pl_module, batch, batch_idx: int) -> None:
"""Chakra trace collection starts."""
if trainer.global_step == self.start_step:
if self.is_profiling:
logging.warning(
f"Attempted to start Chakra profiler multiple times at step {trainer.global_step}. Skipping."
)
return

logging.info(f"====== Start Chakra profiling at global_step {trainer.global_step} ======")

trace_file = self.chakra_host_trace_path / f"rank-{get_rank()}.json"
self.trace_observer.register_callback(str(trace_file))
if self.trace_observer is not None:
# Set the trace file name during the training run, once distributed initialization has completed.
trace_file = self.chakra_host_trace_path / f"rank-{get_rank()}.json.gz"
self.trace_observer.register_callback(str(trace_file))

self.profiler.start()
self.is_profiling = True

logging.info("Chakra Profiler Started.\n")
logging.info("PyTorch/Chakra Profiler Started.\n")

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx: int) -> None:
"""Chakra trace collection ends."""
if self.is_profiling:
if trainer.global_step < self.end_step:
self.profiler.step()
logging.info(f"Profiler step executed at global_step {trainer.global_step}")
else:
logging.info(f"====== End Chakra profiling at global_step {trainer.global_step} ======")
self._stop_profiler()

def _stop_profiler(self):
if self.is_profiling:
logging.info("Stopping Chakra Profiler...")
self.profiler.stop()
self.is_profiling = False
# Step the profiler after each training batch
if self.profiler:
self.profiler.step()

def on_train_end(self, trainer, pl_module):
if self.trace_observer:
try:
logging.info("Unregistering ExecutionTraceObserver...")
self.trace_observer.unregister_callback()
except RuntimeError as e:
logging.warning(f"ExecutionTraceObserver cleanup failed: {e}")


# Alias with the standard "PyTorch" capitalization; the lowercase variant is kept for backward compatibility.
PyTorchProfilerCallback = PytorchProfilerCallback
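For context, a minimal usage sketch of the callback outside NeMo-Run; the trainer settings, step numbers, and paths below are illustrative, not part of this change.

from lightning.pytorch import Trainer

from nemo.lightning.pytorch.callbacks import PyTorchProfilerCallback

# Hypothetical values: profile global steps 10-12 and write traces under ./torch_traces.
profiler_cb = PyTorchProfilerCallback(
    start_step=10,
    end_step=12,
    trace_dir="./torch_traces",  # Kineto traces land in <trace_dir>/torch_profiler/rank-<N>.json.gz
    collect_et=False,  # set True to also record a host-side Execution Trace under <trace_dir>/host
    profiler_kwargs={"with_stack": True},  # extra kwargs forwarded to torch.profiler.profile
)

trainer = Trainer(callbacks=[profiler_cb], max_steps=20)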
45 changes: 43 additions & 2 deletions nemo/lightning/run/plugins.py
@@ -17,15 +17,20 @@
import signal
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Optional
from typing import Any, Callable, Dict, Optional

import nemo_run as run
import yaml
from lightning.pytorch import Callback
from lightning.pytorch.loggers import WandbLogger
from nemo_run.core.serialization.yaml import YamlSerializer

from nemo.lightning.pytorch.callbacks import MemoryProfileCallback, NsysCallback, PreemptionCallback
from nemo.lightning.pytorch.callbacks import (
MemoryProfileCallback,
NsysCallback,
PreemptionCallback,
PyTorchProfilerCallback,
)
from nemo.lightning.pytorch.strategies.megatron_strategy import MegatronStrategy
from nemo.utils import logging
from nemo.utils.import_utils import safe_import
@@ -215,6 +220,42 @@ def setup(self, task: run.Partial | run.Script, executor: run.Executor):
_merge_callbacks(task, callbacks=callbacks)


@dataclass(kw_only=True)
class PyTorchProfilerPlugin(run.Plugin):
"""
A plugin for torch profiling.

You can specify when to start and end the profiling and what to trace; traces are
written per rank under the given output path.

Args:
output_path (str): Directory under which the profiler traces are written.
start_step (int): The step at which to start the torch profiling.
end_step (int): The step at which to end the torch profiling.
with_stack (bool): Record source stack traces for profiled ops.
collect_et (bool): Also collect a PyTorch Execution Trace (host trace).
profiler_kwargs (dict, optional): Additional keyword arguments passed to torch.profiler.profile.
"""

output_path: str
start_step: int
end_step: int
with_stack: bool = False
collect_et: bool = False
profiler_kwargs: Optional[Dict[str, Any]] = None

def setup(self, task: run.Partial | run.Script, executor: run.Executor):
"""Set up the torch profiling plugin."""
if isinstance(task, run.Partial):
# Fold the with_stack flag into profiler_kwargs so it actually reaches torch.profiler.profile.
profiler_kwargs = {"with_stack": self.with_stack, **(self.profiler_kwargs or {})}
profiler_callback = run.Config(
PyTorchProfilerCallback,
trace_dir=self.output_path,
start_step=self.start_step,
end_step=self.end_step,
collect_et=self.collect_et,
profiler_kwargs=profiler_kwargs,
)
callbacks: list[run.Config[Callback]] = [profiler_callback] # type: ignore
_merge_callbacks(task, callbacks=callbacks)


@dataclass(kw_only=True)
class WandbPlugin(run.Plugin):
"""
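A usage sketch for the new plugin, assuming the standard NeMo-Run entry points; the recipe, executor, and paths below are illustrative stand-ins, not prescribed by this PR.

import nemo_run as run

from nemo.collections import llm
from nemo.lightning.run.plugins import PyTorchProfilerPlugin

# Hypothetical recipe and executor; any run.Partial task works the same way.
recipe = llm.llama3_8b.pretrain_recipe(name="llama3_8b_profiled", num_nodes=1, num_gpus_per_node=8)
executor = run.LocalExecutor(ntasks_per_node=8, launcher="torchrun")

profiler_plugin = PyTorchProfilerPlugin(
    output_path="/results/torch_traces",  # the callback creates "torch_profiler" (and optionally "host") here
    start_step=10,
    end_step=12,
    collect_et=False,
    profiler_kwargs={"with_stack": True},
)

run.run(recipe, executor=executor, plugins=[profiler_plugin], name="llama3_8b_torch_profiler")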
14 changes: 14 additions & 0 deletions scripts/performance/argument_parser.py
@@ -128,6 +128,20 @@ def parse_cli_args():
required=False,
default=None,
)
parser.add_argument(
"-etp",
"--enable_torch_profiler",
help="Enable torch profiler. Disabled by default",
action="store_true",
)
parser.add_argument(
"-tpo",
"--torch_profile_out_path",
type=str,
help="Path to the output file of torch profiling",
required=False,
default=None,
)
parser.add_argument(
"-tb",
"--tensorboard",
29 changes: 29 additions & 0 deletions scripts/performance/helpers.py
@@ -609,3 +609,32 @@ def build_perf_env_plugin(args, pp_size: int | None = None, user_buffer_registra
gpu_sm100_or_newer=gpu_sm100_or_newer,
user_buffer_registration=user_buf,
)


def build_torch_profiler_plugin(args):
"""
Build a PyTorchProfilerPlugin with consistent defaults across scripts.
"""
from nemo.lightning.run.plugins import PyTorchProfilerPlugin

enable_torch_profiler = args.enable_torch_profiler or os.environ.get('ENABLE_TORCH_PROFILER', '0') == '1'

if args.enable_nsys and enable_torch_profiler:
logging.warning("Cannot run both Nsys and PyTorch profiler at the same time.")
return None

if enable_torch_profiler:
start_iter = int(os.environ.get('TORCH_PROFILER_START_ITER', args.profiling_start_step))
end_iter = int(os.environ.get('TORCH_PROFILER_END_ITER', args.profiling_stop_step))
return PyTorchProfilerPlugin(
start_step=start_iter,
end_step=end_iter,
output_path=os.environ.get(
'TORCH_PROFILES_DIR', args.torch_profile_out_path
),  # the callback creates a "torch_profiler" (and optionally "host") subdir under this path.
profiler_kwargs={
"with_stack": os.environ.get('TORCH_PROFILER_WITH_STACK', '0') == '1',
},
collect_et=os.environ.get('TORCH_PROFILER_COLLECT_ET', '0') == '1',
)
return None
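A small sketch of how the helper resolves CLI flags and environment overrides; the args namespace and paths are illustrative stand-ins for what parse_cli_args() produces.

import os
from types import SimpleNamespace

from scripts.performance.helpers import build_torch_profiler_plugin  # import path assumes the repo root is on PYTHONPATH

# Only the fields read by build_torch_profiler_plugin are mocked here.
args = SimpleNamespace(
    enable_torch_profiler=True,
    enable_nsys=False,
    profiling_start_step=10,
    profiling_stop_step=12,
    torch_profile_out_path="/results/torch_traces",
)

# Environment variables take precedence over the CLI values.
os.environ["TORCH_PROFILER_START_ITER"] = "20"
os.environ["TORCH_PROFILER_WITH_STACK"] = "1"

plugin = build_torch_profiler_plugin(args)  # returns None when profiling is disabled or Nsys is enabled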
12 changes: 11 additions & 1 deletion scripts/performance/llm/finetune_deepseek_v3.py
@@ -26,7 +26,13 @@

from ..argument_parser import parse_additional_slurm_params, parse_cli_args
from ..executors import slurm_executor
from ..helpers import args_sanity_check, build_perf_env_plugin, get_user_configs, set_primary_perf_configs
from ..helpers import (
args_sanity_check,
build_perf_env_plugin,
build_torch_profiler_plugin,
get_user_configs,
set_primary_perf_configs,
)
from ..utils import hf_tokenizer, import_ckpt_experiment, isfile_train_pack_metadata

HF_MODEL_URI = "deepseek-ai/DeepSeek-V3-Base"
@@ -177,6 +183,10 @@ def override_recipe_configs(
plugins = [build_perf_env_plugin(args, pp_size=pp_size)]
if args.enable_nsys:
plugins.append(NsysPlugin(start_step=10, end_step=12, gen_shape=True))

if torch_profiler_plugin := build_torch_profiler_plugin(args):
plugins.append(torch_profiler_plugin)

if args.enable_memory_profile:
assert args.memory_profile_out_path is not None
plugins.append(MemoryProfilePlugin(dir=args.memory_profile_out_path))
5 changes: 5 additions & 0 deletions scripts/performance/llm/finetune_llama31_405b.py
@@ -29,6 +29,7 @@
from ..helpers import (
args_sanity_check,
build_perf_env_plugin,
build_torch_profiler_plugin,
get_user_configs,
set_exp_logging_configs,
set_primary_perf_configs,
@@ -207,6 +208,10 @@ def override_recipe_configs(
plugins = [build_perf_env_plugin(args, pp_size=pp_size)]
if args.enable_nsys:
plugins.append(NsysPlugin(start_step=5, end_step=6))

if torch_profiler_plugin := build_torch_profiler_plugin(args):
plugins.append(torch_profiler_plugin)

if args.enable_memory_profile:
assert args.memory_profile_out_path is not None
plugins.append(MemoryProfilePlugin(dir=args.memory_profile_out_path))
3 changes: 3 additions & 0 deletions scripts/performance/llm/finetune_llama3_70b.py
@@ -29,6 +29,7 @@
from ..helpers import (
args_sanity_check,
build_perf_env_plugin,
build_torch_profiler_plugin,
get_user_configs,
set_exp_logging_configs,
set_primary_perf_configs,
@@ -214,6 +215,8 @@ def override_recipe_configs(
plugins = [build_perf_env_plugin(args, pp_size=pp_size)]
if args.enable_nsys:
plugins.append(NsysPlugin(start_step=5, end_step=6))
if torch_profiler_plugin := build_torch_profiler_plugin(args):
plugins.append(torch_profiler_plugin)
if args.enable_memory_profile:
assert args.memory_profile_out_path is not None
plugins.append(MemoryProfilePlugin(dir=args.memory_profile_out_path))
5 changes: 5 additions & 0 deletions scripts/performance/llm/finetune_llama3_8b.py
@@ -24,6 +24,7 @@
from ..helpers import (
args_sanity_check,
build_perf_env_plugin,
build_torch_profiler_plugin,
get_user_configs,
set_exp_logging_configs,
set_primary_perf_configs,
@@ -152,6 +153,10 @@ def override_recipe_configs(
plugins = [build_perf_env_plugin(args, pp_size=pp_size)]
if args.enable_nsys:
plugins.append(NsysPlugin(start_step=5, end_step=6))

if torch_profiler_plugin := build_torch_profiler_plugin(args):
plugins.append(torch_profiler_plugin)

if args.enable_memory_profile:
assert args.memory_profile_out_path is not None
plugins.append(MemoryProfilePlugin(dir=args.memory_profile_out_path))
3 changes: 3 additions & 0 deletions scripts/performance/llm/finetune_llama4_e128.py
@@ -25,6 +25,7 @@
from ..helpers import (
args_sanity_check,
build_perf_env_plugin,
build_torch_profiler_plugin,
get_user_configs,
set_exp_logging_configs,
set_primary_perf_configs,
@@ -180,6 +181,8 @@ def override_recipe_configs(

if args.enable_nsys:
plugins.append(NsysPlugin(start_step=5, end_step=6))
if torch_profiler_plugin := build_torch_profiler_plugin(args):
plugins.append(torch_profiler_plugin)
if args.enable_memory_profile:
assert args.memory_profile_out_path is not None
plugins.append(MemoryProfilePlugin(dir=args.memory_profile_out_path))
6 changes: 5 additions & 1 deletion scripts/performance/llm/mlperf_lora_llama2_70b.py
@@ -28,7 +28,7 @@

from ..argument_parser import parse_additional_slurm_params, parse_cli_args
from ..executors import slurm_executor
from ..helpers import args_sanity_check, build_perf_env_plugin
from ..helpers import args_sanity_check, build_perf_env_plugin, build_torch_profiler_plugin
from ..utils import import_ckpt_experiment

NUM_NODES = 1
@@ -353,6 +353,10 @@ def mlperf_lora_llama2_70b_recipe(
plugins = [build_perf_env_plugin(args, pp_size=PP_SIZE)]
if args.enable_nsys:
plugins.append(NsysPlugin(start_step=5, end_step=6))

if torch_profiler_plugin := build_torch_profiler_plugin(args):
plugins.append(torch_profiler_plugin)

if args.enable_memory_profile:
assert args.memory_profile_out_path is not None
plugins.append(MemoryProfilePlugin(dir=args.memory_profile_out_path))