Commit 6ce5b33

Support local kernels in benchmark (#265)
* feat: load local kernel for benchmark when path: prefixed
* feat: update activation bench
* fix: cleanup activation bench typos and improve
* fix: avoid prefix and handle relative paths
* fix: prefer python warning
1 parent 5a70f5e commit 6ce5b33

2 files changed: +80, -9 lines

kernels/src/kernels/benchmark.py

Lines changed: 33 additions & 3 deletions
@@ -7,6 +7,7 @@
 import subprocess
 import sys
 import time
+import warnings
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Any
@@ -478,6 +479,7 @@ def run_benchmark_class(
     iterations: int,
     warmup: int,
     repo_id: str,
+    is_local: bool,
     revision: str,
 ) -> tuple[dict[str, TimingResults], str]:
     results = {}
@@ -493,9 +495,13 @@ def run_benchmark_class(
         raise RuntimeError(f"No benchmark_* methods found in {benchmark_cls.__name__}")

     # Load kernel once for all workloads
-    from kernels import get_kernel
+    from kernels import get_local_kernel, get_kernel
+
+    if is_local:
+        kernel = get_local_kernel(Path(repo_id), "activation")
+    else:
+        kernel = get_kernel(repo_id, revision=revision)

-    kernel = get_kernel(repo_id, revision=revision)
     kernel_sha = get_kernel_sha_from_build_name(kernel)
     backend_name = backend() if TORCH_AVAILABLE else "cpu"
     # Map backend names to torch device names
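
The hunk above is the core of the change: kernel loading now branches on is_local. A standalone sketch of that dispatch, runnable outside the harness; the repo id and path in the comments are hypothetical, while the hardcoded "activation" package name is taken verbatim from the hunk:

from pathlib import Path

from kernels import get_kernel, get_local_kernel


def load_bench_kernel(repo_id: str, revision: str, is_local: bool):
    if is_local:
        # repo_id is a filesystem path to a local checkout; "activation" is
        # the package name the benchmark hardcodes for local loads.
        return get_local_kernel(Path(repo_id), "activation")
    # repo_id is a hub id; the revision is resolved through the hub as before.
    return get_kernel(repo_id, revision=revision)


# Hypothetical calls:
#   load_bench_kernel("./activation", "local", is_local=True)
#   load_bench_kernel("some-org/activation", "main", is_local=False)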
@@ -654,6 +660,7 @@ def run_benchmark_script(
     warmup: int,
     cwd: Path,
     repo_id: str,
+    is_local: bool,
     revision: str,
 ) -> tuple[dict[str, TimingResults], str]:
     print(f"Running {script_path.name}...", file=sys.stderr)
@@ -681,6 +688,7 @@
         iterations=iterations,
         warmup=warmup,
         repo_id=repo_id,
+        is_local=is_local,
         revision=revision,
     )
     for name, timing in results.items():
@@ -734,6 +742,24 @@ def run_benchmark(
     # Suppress progress bars for cleaner output (files are often cached)
     disable_progress_bars()

+    repo_id_path = Path(repo_id)
+
+    if repo_id_path.is_absolute():
+        is_local = repo_id_path.exists()
+    else:
+        is_local = (Path.cwd() / repo_id_path).exists()
+        repo_id_path = Path.cwd() / repo_id_path
+
+    if is_local:
+        if repo_id.count("/") == 1 and not repo_id.startswith(("./", "../")):
+            warnings.warn(
+                f"'{repo_id}' exists locally but looks like a repo_id. "
+                f"Use './{repo_id}' to be explicit.",
+                stacklevel=2,
+            )
+        branch = "local"
+        version = None
+
     # Requires either branch or version or parses from repo_id
     if branch is None and version is None:
         if "@" not in repo_id:
@@ -756,7 +782,10 @@ def run_benchmark(
     assert revision is not None  # Guaranteed by parsing logic above

     print(f"Downloading {repo_id}@{revision}...", file=sys.stderr)
-    repo_path = Path(snapshot_download(repo_id=repo_id, revision=revision))
+    if is_local:
+        repo_path = repo_id_path.resolve()
+    else:
+        repo_path = Path(snapshot_download(repo_id=repo_id, revision=revision))

     scripts = discover_benchmark_scripts(repo_id, repo_path)

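One note on the hunk above: the Downloading message still prints for local paths, but the local branch skips snapshot_download and just resolves the path. Resolving matters because the next hunk runs scripts with cwd=repo_path, so a cwd-relative input has to be pinned to an absolute path first. A tiny illustration with a hypothetical directory name:

from pathlib import Path

repo_id_path = Path.cwd() / "activation"  # as built by the detection hunk
repo_path = repo_id_path.resolve()        # absolute, with symlinks resolved
assert repo_path.is_absolute()
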
@@ -770,6 +799,7 @@ def run_benchmark(
         warmup=warmup,
         cwd=repo_path,
         repo_id=repo_id,
+        is_local=is_local,
         revision=revision,
     )
     timing_results.update(results)

kernels/src/kernels/benchmarks/activation.py

Lines changed: 47 additions & 6 deletions
@@ -9,8 +9,8 @@ class SiluAndMulBenchmark(Benchmark):

     # Workload: small
     def setup_small(self):
-        self.x = torch.randn(1, 128, 512, device="cuda", dtype=torch.float16)
-        self.out = torch.empty(1, 128, 256, device="cuda", dtype=torch.float16)
+        self.x = torch.randn(8, 1024, 2048, device=self.device, dtype=torch.float16)
+        self.out = torch.empty(8, 1024, 1024, device=self.device, dtype=torch.float16)

     def benchmark_small(self):
         self.kernel.silu_and_mul(self.out, self.x)
@@ -21,8 +21,8 @@ def verify_small(self) -> torch.Tensor:

     # Workload: medium
     def setup_medium(self):
-        self.x = torch.randn(4, 512, 1024, device="cuda", dtype=torch.float16)
-        self.out = torch.empty(4, 512, 512, device="cuda", dtype=torch.float16)
+        self.x = torch.randn(8, 2048, 4096, device=self.device, dtype=torch.float16)
+        self.out = torch.empty(8, 2048, 2048, device=self.device, dtype=torch.float16)

     def benchmark_medium(self):
         self.kernel.silu_and_mul(self.out, self.x)
@@ -33,12 +33,53 @@ def verify_medium(self) -> torch.Tensor:

     # Workload: large
     def setup_large(self):
-        self.x = torch.randn(8, 1024, 2048, device="cuda", dtype=torch.float16)
-        self.out = torch.empty(8, 1024, 1024, device="cuda", dtype=torch.float16)
+        self.x = torch.randn(8, 4096, 8192, device=self.device, dtype=torch.float16)
+        self.out = torch.empty(8, 4096, 4096, device=self.device, dtype=torch.float16)

     def benchmark_large(self):
         self.kernel.silu_and_mul(self.out, self.x)
+        self.kernel.silu_and_mul(self.out, self.x)

     def verify_large(self) -> torch.Tensor:
         d = self.x.shape[-1] // 2
         return F.silu(self.x[..., :d]) * self.x[..., d:]
+
+
+class GeluAndMulBenchmark(Benchmark):
+    seed: int = 42
+
+    # Workload: small
+    def setup_small(self):
+        self.x = torch.randn(8, 1024, 2048, device=self.device, dtype=torch.float16)
+        self.out = torch.empty(8, 1024, 1024, device=self.device, dtype=torch.float16)
+
+    def benchmark_small(self):
+        self.kernel.gelu_and_mul(self.out, self.x)
+
+    def verify_small(self) -> torch.Tensor:
+        d = self.x.shape[-1] // 2
+        return F.gelu(self.x[..., :d]) * self.x[..., d:]
+
+    # Workload: medium
+    def setup_medium(self):
+        self.x = torch.randn(8, 2048, 4096, device=self.device, dtype=torch.float16)
+        self.out = torch.empty(8, 2048, 2048, device=self.device, dtype=torch.float16)
+
+    def benchmark_medium(self):
+        self.kernel.gelu_and_mul(self.out, self.x)
+
+    def verify_medium(self) -> torch.Tensor:
+        d = self.x.shape[-1] // 2
+        return F.gelu(self.x[..., :d]) * self.x[..., d:]
+
+    # Workload: large
+    def setup_large(self):
+        self.x = torch.randn(8, 4096, 8192, device=self.device, dtype=torch.float16)
+        self.out = torch.empty(8, 4096, 4096, device=self.device, dtype=torch.float16)
+
+    def benchmark_large(self):
+        self.kernel.gelu_and_mul(self.out, self.x)
+
+    def verify_large(self) -> torch.Tensor:
+        d = self.x.shape[-1] // 2
+        return F.gelu(self.x[..., :d]) * self.x[..., d:]
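
Both classes verify against the same pure-PyTorch reference: split the last dimension in half, apply the activation to the first half, multiply elementwise by the second half. That is also why every out tensor above has half the last dimension of its x. A self-contained sketch of the reference (CPU, default dtype, shapes from the "small" workload):

import torch
import torch.nn.functional as F


def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # Matches SiluAndMulBenchmark.verify_*: silu(x1) * x2 on a split input.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]


def gelu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # Matches GeluAndMulBenchmark.verify_*: gelu(x1) * x2 on a split input.
    d = x.shape[-1] // 2
    return F.gelu(x[..., :d]) * x[..., d:]


x = torch.randn(8, 1024, 2048)
assert silu_and_mul_ref(x).shape == (8, 1024, 1024)
assert gelu_and_mul_ref(x).shape == (8, 1024, 1024)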
