Commit 4aa0b53

Merge pull request #59 from flash-algo/add-geru-test-script
[FEATURE SUPPORT] add geru test script
2 parents: 96a46e8 + ed5ef56

File tree: 5 files changed, +96 / -4 lines

README.md

Lines changed: 1 addition & 1 deletion

@@ -23,7 +23,7 @@ The following common BLAS kernels have been implemented in multiple frameworks.
 | [axpby](./docs/axpby.md) | update vector | $y = \alpha x + \beta y$ | $3n$ | $3n$ | [](./kernel_course/python_ops/axpby.py) | [](./kernel_course/pytorch_ops/axpby.py) | [](./kernel_course/triton_ops/axpby.py) || [](./tests/test_axpby.py) |
 | [dot](./docs/dot.md) | dot product | $z = x^\top y$ | $2n$ | $2n$ | [](./kernel_course/python_ops/dot.py) | [](./kernel_course/pytorch_ops/dot.py) | [](./kernel_course/triton_ops/dot.py) || [](./tests/test_dot.py) |
 | [gemv](./docs/gemv.md) | general matrix-vector multiply | $y = \alpha A x + \beta y$ | $2mn$ | $mn + n + 2m$ | [](./kernel_course/python_ops/gemv.py) | [](./kernel_course/pytorch_ops/gemv.py) | [](./kernel_course/triton_ops/gemv.py) || [](./tests/test_gemv.py) |
-| [geru](./docs/geru.md) | general rank-1 update | $A = A + \alpha x y^\top$ | $2mn$ | $2mn + m + n$ | [](./kernel_course/python_ops/geru.py) | [](./kernel_course/pytorch_ops/geru.py) | [](./kernel_course/triton_ops/geru.py) || |
+| [geru](./docs/geru.md) | general rank-1 update | $A = A + \alpha x y^\top$ | $2mn$ | $2mn + m + n$ | [](./kernel_course/python_ops/geru.py) | [](./kernel_course/pytorch_ops/geru.py) | [](./kernel_course/triton_ops/geru.py) || [](./tests/test_geru.py) |
 | gemm | general matrix-matrix multiply | $C = \alpha A B + \beta C$ | $2mnk$ | $mk + nk + 2mn$ ||||||
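
For reference, the cost columns in the geru row follow from the update rule $A = A + \alpha x y^\top$ itself: each of the $mn$ entries of the update needs one multiply and one add into $A$ (the $\alpha$ scaling can be folded into $x$ at only $m$ extra multiplies), giving $2mn$ FLOPs, while the memory estimate counts reading and writing $A$ ($2mn$ elements) plus reading $x$ ($m$) and $y$ ($n$), i.e. $2mn + m + n$.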

kernel_course/python_ops/geru.py

Lines changed: 1 addition & 1 deletion

@@ -20,6 +20,6 @@ def geru(
        torch.Tensor: The updated tensor `A`.
    """

-    A = A + alpha * x[:, None] * y[None, :]
+    A = A + alpha * (x[:, None] * y[None, :])

    return A
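
The added parentheses only make the grouping explicit (form the outer product $x y^\top$ first, then scale by $\alpha$ and add); the result is unchanged up to floating-point rounding. Read together with the new test's factory, which passes `(A, x, y, alpha)` positionally, the broadcasting reference works roughly as in the sketch below. The signature is an assumption taken from that call order; only the parenthesized line is actually part of the diff.

import torch


def geru(A: torch.Tensor, x: torch.Tensor, y: torch.Tensor, alpha: float) -> torch.Tensor:
    # x[:, None] has shape (len(x), 1) and y[None, :] has shape (1, len(y));
    # broadcasting their product builds the outer product x y^T, which is then
    # scaled by alpha and added out-of-place so the caller's A is left untouched.
    return A + alpha * (x[:, None] * y[None, :])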

kernel_course/pytorch_ops/geru.py

Lines changed: 1 addition & 1 deletion

@@ -20,6 +20,6 @@ def geru(
        torch.Tensor: The updated tensor `A`.
    """

-    A += torch.mul(torch.ger(x, y), alpha)
+    A = torch.add(A, torch.mul(torch.outer(x, y), alpha))

    return A
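
Two things change here: `torch.ger`, which PyTorch has deprecated, is replaced by its current name `torch.outer`, and the in-place `+=` becomes an out-of-place `torch.add`, so the caller's `A` is no longer mutated. That matters when the same inputs are reused across repeated benchmark runs, as in the new test. A small sanity check of the equivalence (illustrative only, not part of the commit):

import torch

A = torch.randn(4, 3)
x = torch.randn(4)
y = torch.randn(3)
alpha = 3.14

# torch.outer(x, y)[i, j] == x[i] * y[j], the same grid the broadcasting form builds.
assert torch.allclose(torch.outer(x, y), x[:, None] * y[None, :])

# The out-of-place update returns a new tensor and leaves A itself unchanged.
out = torch.add(A, torch.mul(torch.outer(x, y), alpha))
assert torch.allclose(out, A + alpha * torch.outer(x, y))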

tests/test_gemv.py

Lines changed: 1 addition & 1 deletion

@@ -71,7 +71,7 @@ def factory(
        (1 << 8, 1 << 8),
    ],
)
-def test_gemv(
+def test_gemv_benchmark(
    device: torch.device,
    dtype: torch.dtype,
    MN: tuple[int, int],

tests/test_geru.py

Lines changed: 92 additions & 0 deletions

@@ -0,0 +1,92 @@

import pytest
import torch

from kernel_course import testing
from kernel_course.python_ops import geru as python_geru

try:
    from kernel_course.pytorch_ops import geru as pytorch_geru

    HAS_PYTORCH = True
except Exception:
    pytorch_geru = None
    HAS_PYTORCH = False

try:
    from kernel_course.triton_ops import geru as triton_geru

    HAS_TRITON = True
except Exception:
    triton_geru = None
    HAS_TRITON = False

try:
    from kernel_course.cute_ops import geru as cute_geru

    HAS_CUTE = True
except Exception:
    cute_geru = None
    HAS_CUTE = False


def factory(
    MN: tuple[int, int],
    device: torch.device,
    dtype: torch.dtype = torch.float32,
):
    M, N = MN
    A = torch.linspace(0.0, 1.0, steps=M * N, device=device, dtype=dtype).view(M, N)
    x = torch.linspace(0.0, 1.0, steps=N, device=device, dtype=dtype)
    y = torch.linspace(0.0, 1.0, steps=M, device=device, dtype=dtype)
    alpha = 3.14
    return (A, x, y, alpha), {}


@pytest.mark.parametrize(
    "device",
    [
        pytest.param(
            torch.device("cuda"),
            marks=pytest.mark.skipif(
                not torch.cuda.is_available(), reason="requires CUDA"
            ),
        ),
        pytest.param(
            torch.device("mps"),
            marks=pytest.mark.skipif(
                not torch.backends.mps.is_available(), reason="requires MPS"
            ),
        ),
    ],
)
@pytest.mark.parametrize(
    "dtype",
    [torch.float32, torch.float16, torch.bfloat16],
)
@pytest.mark.parametrize(
    "numel",
    [
        (1 << 4, 1 << 4),
        (1 << 8, 1 << 8),
    ],
)
def test_geru_benchmark(
    device: torch.device, dtype: torch.dtype, numel: tuple[int, int]
) -> None:
    impls = testing.get_impls(
        python_impl=python_geru.geru,
        pytorch_impl=pytorch_geru.geru if HAS_PYTORCH else None,
        triton_impl=triton_geru.geru if HAS_TRITON else None,
        cute_impl=cute_geru.geru if HAS_CUTE else None,
    )

    # Benchmark each implementation
    config = testing.BenchmarkConfig(warmup=3, repeat=1_000)
    results = testing.run_benchmarks(
        impls,
        lambda: factory(numel, device, dtype),
        flops=2 * numel[0] * numel[1],
        config=config,
    )

    testing.show_benchmarks(results)
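
The new test is a pure benchmark. A standalone correctness cross-check between the Python and PyTorch backends could look like the sketch below; it is illustrative only, mirrors the square shapes used by the factory above, and assumes the positional argument order `(A, x, y, alpha)` that the factory returns.

import torch

from kernel_course.python_ops import geru as python_geru
from kernel_course.pytorch_ops import geru as pytorch_geru

# Square shapes, as in the factory, so the x/y lengths line up with A either way.
M = N = 1 << 4
A = torch.linspace(0.0, 1.0, steps=M * N).view(M, N)
x = torch.linspace(0.0, 1.0, steps=N)
y = torch.linspace(0.0, 1.0, steps=M)
alpha = 3.14

ref = python_geru.geru(A, x, y, alpha)   # broadcasting reference
out = pytorch_geru.geru(A, x, y, alpha)  # torch.outer-based implementation
torch.testing.assert_close(out, ref)

Running `pytest tests/test_geru.py` exercises the benchmark path added in this commit; the CUDA and MPS parametrizations are skipped automatically on machines without those backends.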
