2 changes: 1 addition & 1 deletion README.md
@@ -22,7 +22,7 @@ The following common BLAS kernels have been implemented in multiple frameworks.
| [scal](./docs/scal.md) | scale vector | $y = \alpha y$ | $n$ | $2n$ | [✅](./kernel_course/python_ops/scal.py) | [✅](./kernel_course/pytorch_ops/scal.py) | [✅](./kernel_course/triton_ops/scal.py) | ❌ | [✅](./tests/test_scal.py) |
| [axpby](./docs/axpby.md) | update vector| $y = \alpha x + \beta y$ | $3n$ | $3n$ | [✅](./kernel_course/python_ops/axpby.py) | [✅](./kernel_course/pytorch_ops/axpby.py) | [✅](./kernel_course/triton_ops/axpby.py) | ❌ | [✅](./tests/test_axpby.py) |
| [dot](./docs/dot.md) | dot product | $z = x^\top y$ | $2n$ | $2n$ | [✅](./kernel_course/python_ops/dot.py) | [✅](./kernel_course/pytorch_ops/dot.py) | [✅](./kernel_course/triton_ops/dot.py) | ❌ | [✅](./tests/test_dot.py) |
| [gemv](./docs/gemv.md) | general matrix-vector multiply | $y = \alpha A x + \beta y$ | $2mn$ | $mn + n + 2m$ | [✅](./kernel_course/python_ops/gemv.py) | [✅](./kernel_course/pytorch_ops/gemv.py) | | ❌ | |
| [gemv](./docs/gemv.md) | general matrix-vector multiply | $y = \alpha A x + \beta y$ | $2mn$ | $mn + n + 2m$ | [✅](./kernel_course/python_ops/gemv.py) | [✅](./kernel_course/pytorch_ops/gemv.py) | [✅](./kernel_course/triton_ops/gemv.py) | ❌ | [✅](./tests/test_gemv.py) |
| geru | general rank-1 update | $A = A + \alpha x y^\top$ | $2mn$ | $2mn + m + n$ | ❌ | ❌ | ❌ | ❌ | ❌ |
| gemm | general matrix-matrix multiply | $C = \alpha A B + \beta C$ | $2mnk$ | $mk + nk + 2mn$ | ❌ | ❌ | ❌ | ❌ | ❌ |

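For reference, the `gemv` row added above implements $y = \alpha A x + \beta y$ ($2mn$ flops, $mn + n + 2m$ elements moved). A minimal PyTorch sketch of that update, purely illustrative and not the implementation linked in the table:

```python
import torch


def gemv_reference(
    A: torch.Tensor, x: torch.Tensor, y: torch.Tensor, alpha: float, beta: float
) -> torch.Tensor:
    # y_new = alpha * (A @ x) + beta * y, computed out of place for clarity
    return alpha * (A @ x) + beta * y
```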
130 changes: 130 additions & 0 deletions kernel_course/triton_ops/gemv.py
@@ -0,0 +1,130 @@
import torch
import triton
import triton.language as tl


@triton.autotune(
configs=[
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=2),
],
key=["n_elements_M", "n_elements_N"],
)
@triton.heuristics(
{
"EVEN_M": lambda args: args["n_elements_M"] % args["BLOCK_M"] == 0,
"EVEN_N": lambda args: args["n_elements_N"] % args["BLOCK_N"] == 0,
}
)
@triton.jit
def gemv_kernel(
A_ptr,
x_ptr,
y_ptr,
alpha,
beta,
stride_am,
stride_an,
stride_x,
stride_y,
n_elements_M,
n_elements_N,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
EVEN_M: tl.constexpr,
EVEN_N: tl.constexpr,
):
    # There are multiple programs, each processing a different block of data
    # We identify which program we are in using program_id
start_m = tl.program_id(0)
    # This program will process inputs that are offset from the initial pointer
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
# Create a mask to guard memory operations against out-of-bounds accesses
mask_m = offs_m < n_elements_M
# Initialize the accumulator to zero for each row
acc = tl.zeros((BLOCK_M,), dtype=tl.float32)
end_n = n_elements_N
# Loop over the N dimension in blocks of BLOCK_N
for start_n in range(0, end_n, BLOCK_N):
        # Hint to the compiler that start_n is a multiple of BLOCK_N so it can optimize memory accesses
start_n = tl.multiple_of(start_n, BLOCK_N)
        # Column offsets for the current block along the N dimension
offs_n = start_n + tl.arange(0, BLOCK_N)
# Create a mask to guard memory operations against out-of-bounds accesses
mask_n = offs_n < n_elements_N
# Load a block of A and x from DRAM, masking out any extra elements in case the input is not a multiple of the block size
if EVEN_N & EVEN_M:
a = tl.load(
A_ptr + offs_m[:, None] * stride_am + offs_n[None, :] * stride_an
)
else:
a = tl.load(
A_ptr + offs_m[:, None] * stride_am + offs_n[None, :] * stride_an,
mask=mask_m[:, None] & mask_n[None, :],
other=0.0,
)
if EVEN_N:
x = tl.load(x_ptr + offs_n * stride_x)
else:
x = tl.load(x_ptr + offs_n * stride_x, mask=mask_n, other=0.0)
# Perform the matrix-vector multiplication for this block and accumulate the results
acc += tl.sum(a * x[None, :], axis=1)
# Load y from DRAM, masking out any extra elements in case the input is not a multiple of the block size
if EVEN_M:
y = tl.load(y_ptr + offs_m * stride_y)
else:
y = tl.load(y_ptr + offs_m * stride_y, mask=mask_m, other=0.0)
# Compute y = alpha * A * x + beta * y
y_new = (alpha * acc + beta * y).to(y.dtype)
# Write y back to DRAM
if EVEN_M:
tl.store(y_ptr + offs_m * stride_y, y_new)
else:
tl.store(y_ptr + offs_m * stride_y, y_new, mask=mask_m)


def gemv(
A: torch.Tensor,
x: torch.Tensor,
y: torch.Tensor,
alpha: float,
beta: float,
) -> torch.Tensor:
"""
    Updates tensor `y` in place with `alpha * A @ x + beta * y` using a
    Triton kernel.

Args:
A (torch.Tensor): Matrix tensor.
x (torch.Tensor): Vector tensor to be multiplied with `A`.
y (torch.Tensor): Vector tensor to be updated.
alpha (float): Scaling factor for the product of `A` and `x`.
beta (float): Scaling factor for `y`.

Returns:
torch.Tensor: The updated tensor `y`.
"""

    # Read the matrix dimensions: M rows and N columns
n_elements_M, n_elements_N = A.shape

# The SPMD launch grid is one-dimensional, with each program processing a block of rows of A
def grid(meta):
return (triton.cdiv(n_elements_M, meta["BLOCK_M"]),)

# Launch the Triton kernel
gemv_kernel[grid](
A,
x,
y,
alpha,
beta,
A.stride(0),
A.stride(1),
x.stride(0),
y.stride(0),
n_elements_M,
n_elements_N,
)

return y
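A minimal usage sketch for the wrapper above, assuming a CUDA device and that the module is importable as `kernel_course.triton_ops.gemv`; the tolerances are illustrative guesses:

```python
import torch

from kernel_course.triton_ops.gemv import gemv

M, N = 256, 512
A = torch.randn(M, N, device="cuda", dtype=torch.float32)
x = torch.randn(N, device="cuda", dtype=torch.float32)
y = torch.randn(M, device="cuda", dtype=torch.float32)
alpha, beta = 1.5, 0.5

# Compute the reference before the call, since the kernel overwrites y in place.
expected = alpha * (A @ x) + beta * y
out = gemv(A, x, y, alpha, beta)

torch.testing.assert_close(out, expected, rtol=1e-3, atol=1e-3)
```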
95 changes: 95 additions & 0 deletions tests/test_gemv.py
@@ -0,0 +1,95 @@
import pytest
import torch

from kernel_course import testing
from kernel_course.python_ops import gemv as python_gemv

try:
from kernel_course.pytorch_ops import gemv as pytorch_gemv

HAS_PYTORCH = True
except Exception:
pytorch_gemv = None
HAS_PYTORCH = False

try:
from kernel_course.triton_ops import gemv as triton_gemv

HAS_TRITON = True
except Exception:
triton_gemv = None
HAS_TRITON = False

try:
from kernel_course.cute_ops import gemv as cute_gemv

HAS_CUTE = True
except Exception:
cute_gemv = None
HAS_CUTE = False


def factory(
MN: tuple[int, int],
device: torch.device,
dtype: torch.dtype = torch.float32,
):
M, N = MN
A = torch.linspace(0.0, 1.0, steps=M * N, device=device, dtype=dtype).view(M, N)
x = torch.linspace(0.0, 1.0, steps=N, device=device, dtype=dtype)
y = torch.linspace(0.0, 1.0, steps=M, device=device, dtype=dtype)
alpha = 1.14
beta = 5.14
return (A, x, y, alpha, beta), {}


@pytest.mark.parametrize(
"device",
[
pytest.param(
torch.device("cuda"),
marks=pytest.mark.skipif(
not torch.cuda.is_available(), reason="requires CUDA"
),
),
pytest.param(
torch.device("mps"),
marks=pytest.mark.skipif(
not torch.backends.mps.is_available(), reason="requires MPS"
),
),
],
)
@pytest.mark.parametrize(
"dtype",
[torch.float32, torch.float16, torch.bfloat16],
)
@pytest.mark.parametrize(
"MN",
[
(1 << 4, 1 << 4),
(1 << 8, 1 << 8),
],
)
def test_gemv(
device: torch.device,
dtype: torch.dtype,
MN: tuple[int, int],
) -> None:
impls = testing.get_impls(
python_impl=python_gemv.gemv,
pytorch_impl=pytorch_gemv.gemv if HAS_PYTORCH else None,
triton_impl=triton_gemv.gemv if HAS_TRITON else None,
cute_impl=cute_gemv.gemv if HAS_CUTE else None,
)

# Benchmark each implementation
config = testing.BenchmarkConfig(warmup=3, repeat=100)
results = testing.run_benchmarks(
impls,
lambda: factory(MN, device, dtype),
flops=2 * MN[0] * MN[1],
config=config,
)

testing.show_benchmarks(results)
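A possible follow-up not included in this PR: assert numerical agreement with a plain PyTorch expression before benchmarking. The helper below is a hypothetical sketch (it assumes Triton and a CUDA device are available); it reuses `factory`, clones `y` because the Triton kernel updates it in place, and uses loose tolerances chosen by guesswork for the fp16/bf16 cases:

```python
def _check_triton_matches_reference(
    device: torch.device,
    dtype: torch.dtype,
    MN: tuple[int, int],
) -> None:
    # Build the same inputs the benchmark uses.
    (A, x, y, alpha, beta), _ = factory(MN, device, dtype)
    # Reference in float32, computed before y is overwritten.
    expected = alpha * (A.float() @ x.float()) + beta * y.float()
    out = triton_gemv.gemv(A, x, y.clone(), alpha, beta)
    torch.testing.assert_close(out.float(), expected, rtol=1e-2, atol=1e-2)
```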