Skip to content

Commit b21af25

Browse files
committed
Pull thunder PR "Remove the --profile option"
Lightning-AI/lightning-thunder#2715
1 parent e181595 commit b21af25

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

benchmarks/python/benchmark_inference.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,6 @@ class InferenceBenchmarkConfig:
224224
mode: str
225225
disable_moe_replacement: bool
226226
attn_implementation: str | None
227-
profile: bool
228227
thunder_cache: str | None
229228
enable_thunder_cudagraph: bool
230229

@@ -557,10 +556,17 @@ def run_benchmark(self) -> InferenceMetrics:
557556
for _ in tqdm(range(self.config.num_iterations), disable=LOCAL_RANK != 0):
558557
past_key_values.reset()
559558

560-
if self.config.profile:
559+
is_under_nsys = bool(os.environ.get("NSYS_PROFILING_SESSION_ID"))
560+
# Wrap each non-warmup iteration with cudaProfilerStart() and
561+
# cudaProfilerStop(). This allows the user to run
562+
# ```shell
563+
# nsys profile --capture-range=cudaProfilerApi --capture-range-end=repeat:<N> ...
564+
# ```
565+
# to record only the non-warmup iterations.
566+
if is_under_nsys:
561567
torch.cuda.cudart().cudaProfilerStart()
562568
iter_metrics = self.measure_inference_step(input_ids, past_key_values, self.config.output_length)
563-
if self.config.profile:
569+
if is_under_nsys:
564570
torch.cuda.cudart().cudaProfilerStop()
565571

566572
all_metrics.append(iter_metrics)
@@ -748,11 +754,6 @@ def parse_args() -> argparse.Namespace:
748754
action="store_true",
749755
help="let nvfuser take care of linear and matmul, note that this might fail with distributed run. See: https://github.com/NVIDIA/Fuser/issues/4507",
750756
)
751-
parser.add_argument(
752-
"--profile",
753-
action="store_true",
754-
help="Wrap each non-warmup iteration with cudaProfilerStart() and cudaProfilerStop(). This allows us to run `nsys profile --capture-range=cudaProfilerApi --capture-range-end=repeat:<N> ... --profile` to record only the non-warmup iterations.",
755-
)
756757

757758
parser.add_argument(
758759
"--thunder-trace",
@@ -801,7 +802,6 @@ def main():
801802
enable_nv_linear=args.enable_nv_linear,
802803
disable_moe_replacement=args.disable_moe_replacement,
803804
attn_implementation=args.attn_implementation,
804-
profile=args.profile,
805805
thunder_cache=args.thunder_cache,
806806
enable_thunder_cudagraph=args.enable_thunder_cudagraph,
807807
)

0 commit comments

Comments
 (0)