From 912e81812c6d5d3f33e70ac0fcd8db1b89e1d268 Mon Sep 17 00:00:00 2001
From: LoserCheems
Date: Thu, 4 Dec 2025 18:45:54 +0800
Subject: [PATCH 1/4] Adds Triton GEMV kernel
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Introduces an autotuned, block-tiled matvec kernel so matrix–vector
updates run on the GPU with configurable alpha/beta scaling and a
tunable launch grid.
---
 kernel_course/triton_ops/gemv.py | 121 +++++++++++++++++++++++++++++++
 1 file changed, 121 insertions(+)
 create mode 100644 kernel_course/triton_ops/gemv.py

diff --git a/kernel_course/triton_ops/gemv.py b/kernel_course/triton_ops/gemv.py
new file mode 100644
index 0000000..71b2013
--- /dev/null
+++ b/kernel_course/triton_ops/gemv.py
@@ -0,0 +1,121 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=2),
+    ],
+    key=["n_elements_M", "n_elements_N"],
+)
+@triton.heuristics(
+    {
+        "EVEN_M": lambda args: args["n_elements_M"] % args["BLOCK_M"] == 0,
+        "EVEN_N": lambda args: args["n_elements_N"] % args["BLOCK_N"] == 0,
+    }
+)
+@triton.jit
+def gemv_kernel(
+    A_ptr,
+    x_ptr,
+    y_ptr,
+    alpha,
+    beta,
+    stride_am,
+    stride_an,
+    stride_x,
+    stride_y,
+    n_elements_M,
+    n_elements_N,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    EVEN_M: tl.constexpr,
+    EVEN_N: tl.constexpr,
+):
+    # There are multiple programs, each processing a different block of rows
+    # We identify which program we are in using program_id
+    start_m = tl.program_id(0)
+    # This program will process inputs that are offset from the initial pointer
+    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    offs_n = tl.arange(0, BLOCK_N)
+    # Create a mask to guard memory operations against out-of-bounds accesses
+    mask_m = offs_m < n_elements_M
+    # Initialize the accumulator to zero for each row
+    acc = tl.zeros((BLOCK_M,), dtype=tl.float32)
+    end_n = n_elements_N
+    # Loop over the N dimension in blocks of BLOCK_N
+    for start_n in range(0, end_n, BLOCK_N):
+        # Hint to the compiler that start_n is a multiple of BLOCK_N so it can optimize memory accesses
+        start_n = tl.multiple_of(start_n, BLOCK_N)
+        # This program will process inputs that are offset from the initial pointer
+        offs_n = start_n + tl.arange(0, BLOCK_N)
+        # Create a mask to guard memory operations against out-of-bounds accesses
+        mask_n = offs_n < n_elements_N
+        # Load a block of A and x from DRAM, masking out any extra elements in case the input is not a multiple of the block size
+        if EVEN_N & EVEN_M:
+            a = tl.load(A_ptr + offs_m[:, None] * stride_am + offs_n[None, :] * stride_an)
+        else:
+            a = tl.load(
+                A_ptr + offs_m[:, None] * stride_am + offs_n[None, :] * stride_an,
+                mask=mask_m[:, None] & mask_n[None, :],
+                other=0.0,
+            )
+        if EVEN_N:
+            x = tl.load(x_ptr + offs_n * stride_x)
+        else:
+            x = tl.load(x_ptr + offs_n * stride_x, mask=mask_n, other=0.0)
+        # Perform the matrix-vector multiplication for this block and accumulate the results
+        acc += tl.sum(a * x[None, :], axis=1)
+    # Load y from DRAM, masking out any extra elements in case the input is not a multiple of the block size
+    if EVEN_M:
+        y = tl.load(y_ptr + offs_m * stride_y)
+    else:
+        y = tl.load(y_ptr + offs_m * stride_y, mask=mask_m, other=0.0)
+    # Compute y = alpha * A * x + beta * y
+    y_new = (alpha * acc + beta * y).to(y.dtype)
+    # Write y back to DRAM
+    if EVEN_M:
+        tl.store(y_ptr + offs_m * stride_y, y_new)
+    else:
+        tl.store(y_ptr + offs_m * stride_y, y_new, mask=mask_m)
+
+
+def gemv(
+    A: torch.Tensor,
+    x: torch.Tensor,
+    y: torch.Tensor,
+    alpha: float,
+    beta: float,
+) -> torch.Tensor:
+    """
+    In-place GEMV update: computes `y = alpha * (A @ x) + beta * y`
+    using a Triton kernel.
+
+    Args:
+        A (torch.Tensor): Matrix tensor of shape (M, N).
+        x (torch.Tensor): Vector tensor of shape (N,) to be multiplied with `A`.
+        y (torch.Tensor): Vector tensor of shape (M,) to be updated.
+        alpha (float): Scaling factor for the product of `A` and `x`.
+        beta (float): Scaling factor for `y`.
+
+    Returns:
+        torch.Tensor: The updated tensor `y`.
+    """
+
+    # Read the matrix dimensions: M rows, N columns
+    n_elements_M, n_elements_N = A.shape
+
+    # The SPMD launch grid is one-dimensional, with each program processing a block of rows of A
+    def grid(meta):
+        return (triton.cdiv(n_elements_M, meta["BLOCK_M"]),)
+
+    # Launch the Triton kernel
+    gemv_kernel[grid](
+        A, x, y, alpha, beta,
+        A.stride(0), A.stride(1),
+        x.stride(0), y.stride(0),
+        n_elements_M, n_elements_N,
+    )
+
+    return y
\ No newline at end of file

From 5e706518aba5eecf3de2cc8ffd6510a165a20359 Mon Sep 17 00:00:00 2001
From: LoserCheems
Date: Thu, 4 Dec 2025 18:46:31 +0800
Subject: [PATCH 2/4] Format code for better readability in gemv.py

---
 kernel_course/triton_ops/gemv.py | 23 ++++++++++++++++-------
 1 file changed, 16 insertions(+), 7 deletions(-)

diff --git a/kernel_course/triton_ops/gemv.py b/kernel_course/triton_ops/gemv.py
index 71b2013..1cb00b6 100644
--- a/kernel_course/triton_ops/gemv.py
+++ b/kernel_course/triton_ops/gemv.py
@@ -54,7 +54,9 @@ def gemv_kernel(
         mask_n = offs_n < n_elements_N
         # Load a block of A and x from DRAM, masking out any extra elements in case the input is not a multiple of the block size
         if EVEN_N & EVEN_M:
-            a = tl.load(A_ptr + offs_m[:, None] * stride_am + offs_n[None, :] * stride_an)
+            a = tl.load(
+                A_ptr + offs_m[:, None] * stride_am + offs_n[None, :] * stride_an
+            )
         else:
             a = tl.load(
                 A_ptr + offs_m[:, None] * stride_am + offs_n[None, :] * stride_an,
@@ -109,13 +111,20 @@ def gemv(
     # The SPMD launch grid is one-dimensional, with each program processing a block of rows of A
     def grid(meta):
         return (triton.cdiv(n_elements_M, meta["BLOCK_M"]),)
-    
+
     # Launch the Triton kernel
     gemv_kernel[grid](
-        A, x, y, alpha, beta,
-        A.stride(0), A.stride(1),
-        x.stride(0), y.stride(0),
-        n_elements_M, n_elements_N,
+        A,
+        x,
+        y,
+        alpha,
+        beta,
+        A.stride(0),
+        A.stride(1),
+        x.stride(0),
+        y.stride(0),
+        n_elements_M,
+        n_elements_N,
     )
 
-    return y
\ No newline at end of file
+    return y

From 7304ff3c16a421377e699188e98b3d744c2331dd Mon Sep 17 00:00:00 2001
From: LoserCheems
Date: Thu, 4 Dec 2025 18:46:58 +0800
Subject: [PATCH 3/4] Adds GEMV performance tests

Introduces parameterized GEMV benchmarks covering CUDA, MPS, and multiple
dtypes to compare the Python, PyTorch, Triton, and CUTe implementations.
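
For reference, a minimal correctness smoke test of the Triton path looks
like the sketch below (CUDA device and float32 assumed; the expected
result follows the BLAS definition y = alpha * A @ x + beta * y):

    import torch
    from kernel_course.triton_ops.gemv import gemv

    M, N = 256, 256
    A = torch.randn(M, N, device="cuda")
    x = torch.randn(N, device="cuda")
    y = torch.randn(M, device="cuda")

    expected = 1.14 * (A @ x) + 5.14 * y
    gemv(A, x, y, 1.14, 5.14)  # updates y in place
    torch.testing.assert_close(y, expected, rtol=1e-4, atol=1e-4)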
---
 tests/test_gemv.py | 95 ++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 95 insertions(+)
 create mode 100644 tests/test_gemv.py

diff --git a/tests/test_gemv.py b/tests/test_gemv.py
new file mode 100644
index 0000000..bca3241
--- /dev/null
+++ b/tests/test_gemv.py
@@ -0,0 +1,95 @@
+import pytest
+import torch
+
+from kernel_course import testing
+from kernel_course.python_ops import gemv as python_gemv
+
+try:
+    from kernel_course.pytorch_ops import gemv as pytorch_gemv
+
+    HAS_PYTORCH = True
+except Exception:
+    pytorch_gemv = None
+    HAS_PYTORCH = False
+
+try:
+    from kernel_course.triton_ops import gemv as triton_gemv
+
+    HAS_TRITON = True
+except Exception:
+    triton_gemv = None
+    HAS_TRITON = False
+
+try:
+    from kernel_course.cute_ops import gemv as cute_gemv
+
+    HAS_CUTE = True
+except Exception:
+    cute_gemv = None
+    HAS_CUTE = False
+
+
+def factory(
+    MN: tuple[int, int],
+    device: torch.device,
+    dtype: torch.dtype = torch.float32,
+):
+    M, N = MN
+    A = torch.linspace(0.0, 1.0, steps=M * N, device=device, dtype=dtype).view(M, N)
+    x = torch.linspace(0.0, 1.0, steps=N, device=device, dtype=dtype)
+    y = torch.linspace(0.0, 1.0, steps=M, device=device, dtype=dtype)
+    alpha = 1.14
+    beta = 5.14
+    return (A, x, y, alpha, beta), {}
+
+
+@pytest.mark.parametrize(
+    "device",
+    [
+        pytest.param(
+            torch.device("cuda"),
+            marks=pytest.mark.skipif(
+                not torch.cuda.is_available(), reason="requires CUDA"
+            ),
+        ),
+        pytest.param(
+            torch.device("mps"),
+            marks=pytest.mark.skipif(
+                not torch.backends.mps.is_available(), reason="requires MPS"
+            ),
+        ),
+    ],
+)
+@pytest.mark.parametrize(
+    "dtype",
+    [torch.float32, torch.float16, torch.bfloat16],
+)
+@pytest.mark.parametrize(
+    "MN",
+    [
+        (1 << 4, 1 << 4),
+        (1 << 8, 1 << 8),
+    ],
+)
+def test_gemv(
+    device: torch.device,
+    dtype: torch.dtype,
+    MN: tuple[int, int],
+) -> None:
+    impls = testing.get_impls(
+        python_impl=python_gemv.gemv,
+        pytorch_impl=pytorch_gemv.gemv if HAS_PYTORCH else None,
+        triton_impl=triton_gemv.gemv if HAS_TRITON else None,
+        cute_impl=cute_gemv.gemv if HAS_CUTE else None,
+    )
+
+    # Benchmark each implementation
+    config = testing.BenchmarkConfig(warmup=3, repeat=100)
+    results = testing.run_benchmarks(
+        impls,
+        lambda: factory(MN, device, dtype),
+        flops=2 * MN[0] * MN[1],
+        config=config,
+    )
+
+    testing.show_benchmarks(results)

From 27b5143d20f9dd238d5b70d1e4bdbf52af7e06fd Mon Sep 17 00:00:00 2001
From: LoserCheems
Date: Thu, 4 Dec 2025 18:48:54 +0800
Subject: [PATCH 4/4] Update README to reflect GEMV Triton kernel
 implementation

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 09134c5..b14c23a 100644
--- a/README.md
+++ b/README.md
@@ -22,7 +22,7 @@ The following common BLAS kernels have been implemented in multiple frameworks.
 | [scal](./docs/scal.md) | scale vector | $y = \alpha y$ | $n$ | $2n$ | [✅](./kernel_course/python_ops/scal.py) | [✅](./kernel_course/pytorch_ops/scal.py) | [✅](./kernel_course/triton_ops/scal.py) | ❌ | [✅](./tests/test_scal.py) |
 | [axpby](./docs/axpby.md) | update vector | $y = \alpha x + \beta y$ | $3n$ | $3n$ | [✅](./kernel_course/python_ops/axpby.py) | [✅](./kernel_course/pytorch_ops/axpby.py) | [✅](./kernel_course/triton_ops/axpby.py) | ❌ | [✅](./tests/test_axpby.py) |
 | [dot](./docs/dot.md) | dot product | $z = x^\top y$ | $2n$ | $2n$ | [✅](./kernel_course/python_ops/dot.py) | [✅](./kernel_course/pytorch_ops/dot.py) | [✅](./kernel_course/triton_ops/dot.py) | ❌ | [✅](./tests/test_dot.py) |
-| [gemv](./docs/gemv.md) | general matrix-vector multiply | $y = \alpha A x + \beta y$ | $2mn$ | $mn + n + 2m$ | [✅](./kernel_course/python_ops/gemv.py) | [✅](./kernel_course/pytorch_ops/gemv.py) | ❌ | ❌ | ❌ |
+| [gemv](./docs/gemv.md) | general matrix-vector multiply | $y = \alpha A x + \beta y$ | $2mn$ | $mn + n + 2m$ | [✅](./kernel_course/python_ops/gemv.py) | [✅](./kernel_course/pytorch_ops/gemv.py) | [✅](./kernel_course/triton_ops/gemv.py) | ❌ | [✅](./tests/test_gemv.py) |
 | geru | general rank-1 update | $A = A + \alpha x y^\top$ | $2mn$ | $2mn + m + n$ | ❌ | ❌ | ❌ | ❌ | ❌ |
 | gemm | general matrix-matrix multiply | $C = \alpha A B + \beta C$ | $2mnk$ | $mk + nk + 2mn$ | ❌ | ❌ | ❌ | ❌ | ❌ |
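
As a closing note, the gemv row's update, y = alpha * A x + beta * y, maps
directly onto PyTorch's torch.addmv, which makes a handy oracle for the
kernels in this series. A minimal reference sketch (the function name is
illustrative, not part of the patches):

    import torch

    def gemv_reference(A: torch.Tensor, x: torch.Tensor, y: torch.Tensor,
                       alpha: float, beta: float) -> torch.Tensor:
        # beta * y + alpha * (A @ x), per the BLAS GEMV definition above
        return torch.addmv(y, A, x, beta=beta, alpha=alpha)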