From 3d626bf6df761dc4a6d3ac306fbf7d3fc4743a86 Mon Sep 17 00:00:00 2001
From: LoserCheems
Date: Mon, 8 Dec 2025 22:09:54 +0800
Subject: [PATCH 1/3] Documents GERU kernel usage

Clarifies GERU math intuition and shared API across backends

Guides contributors to available implementations and tests to keep validation consistent
---
 docs/geru.md | 35 +++++++++++++++++++++++++++++++++++
 1 file changed, 35 insertions(+)
 create mode 100644 docs/geru.md

diff --git a/docs/geru.md b/docs/geru.md
new file mode 100644
index 0000000..34c3567
--- /dev/null
+++ b/docs/geru.md
@@ -0,0 +1,35 @@
+# GERU Kernel
+
+The `geru` operator computes the outer product of two vectors and adds the result to a matrix.
+
+## Mathematical Definition
+
+Given an input matrix `A` and input vectors `x` and `y`, along with a scalar `α`, the kernel evaluates
+
+$$
+A = A + \alpha x y^\top
+$$
+
+The kernel forms the outer product of `x` with the transpose of `y`, scales it by `α`, and adds the scaled result to `A` to obtain the updated matrix.
+
+## Kernel Implementations
+
+- [Python Implementation](../kernel_course/python_ops/geru.py)
+- [PyTorch Implementation](../kernel_course/pytorch_ops/geru.py)
+- [Triton Implementation](../kernel_course/triton_ops/geru.py)
+- [CuTe Implementation](../kernel_course/cute_ops/geru.py)
+
+All backends share the interface:
+
+```python
+def geru(A: torch.Tensor, x: torch.Tensor, y: torch.Tensor, alpha: float) -> torch.Tensor:
+    ...
+```
+
+## Testing
+
+See the [test suite](../tests/test_geru.py) for the validation harness that exercises every backend.
+
+```bash
+pytest tests/test_geru.py -s
+```
\ No newline at end of file

From e51c60849e2964cff215b6711ff52ae08614421c Mon Sep 17 00:00:00 2001
From: LoserCheems
Date: Mon, 8 Dec 2025 22:10:06 +0800
Subject: [PATCH 2/3] Adds tensor GERU update helper

Expands python ops with a torch-based GERU to support scaled outer-product updates
---
 kernel_course/python_ops/geru.py | 25 +++++++++++++++++++++++++
 1 file changed, 25 insertions(+)
 create mode 100644 kernel_course/python_ops/geru.py

diff --git a/kernel_course/python_ops/geru.py b/kernel_course/python_ops/geru.py
new file mode 100644
index 0000000..c2e99b2
--- /dev/null
+++ b/kernel_course/python_ops/geru.py
@@ -0,0 +1,25 @@
+import torch
+
+
+def geru(
+    A: torch.Tensor,
+    x: torch.Tensor,
+    y: torch.Tensor,
+    alpha: float,
+):
+    """
+    Returns `A` plus the outer product of vectors `x` and `y` scaled by `alpha`.
+
+    Args:
+        A (torch.Tensor): Matrix tensor combined with the scaled outer product; it is not modified in place.
+        x (torch.Tensor): Vector tensor supplying the rows of the outer product.
+        y (torch.Tensor): Vector tensor supplying the columns of the outer product.
+        alpha (float): Scaling factor for the outer product of `x` and `y`.
+
+    Returns:
+        torch.Tensor: A new tensor equal to `A + alpha * x y^T`.
+    """
+
+    A = A + alpha * x[:, None] * y[None, :]
+
+    return A

From ad45217922eba1efd766edf73ad477f246a8ef1f Mon Sep 17 00:00:00 2001
From: LoserCheems
Date: Mon, 8 Dec 2025 22:10:13 +0800
Subject: [PATCH 3/3] Updates README to include GERU kernel documentation and status

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index b14c23a..8f15583 100644
--- a/README.md
+++ b/README.md
@@ -23,7 +23,7 @@ The following common BLAS kernels have been implemented in multiple frameworks.
| [axpby](./docs/axpby.md) | update vector| $y = \alpha x + \beta y$ | $3n$ | $3n$ | [✅](./kernel_course/python_ops/axpby.py) | [✅](./kernel_course/pytorch_ops/axpby.py) | [✅](./kernel_course/triton_ops/axpby.py) | ❌ | [✅](./tests/test_axpby.py) | | [dot](./docs/dot.md) | dot product | $z = x^\top y$ | $2n$ | $2n$ | [✅](./kernel_course/python_ops/dot.py) | [✅](./kernel_course/pytorch_ops/dot.py) | [✅](./kernel_course/triton_ops/dot.py) | ❌ | [✅](./tests/test_dot.py) | | [gemv](./docs/gemv.md) | general matrix-vector multiply | $y = \alpha A x + \beta y$ | $2mn$ | $mn + n + 2m$ | [✅](./kernel_course/python_ops/gemv.py) | [✅](./kernel_course/pytorch_ops/gemv.py) | [✅](./kernel_course/triton_ops/gemv.py) | ❌ | [✅](./tests/test_gemv.py) | -| geru | general rank-1 update | $A = A + \alpha x y^\top$ | $2mn$ | $2mn + m + n$ | ❌ | ❌ | ❌ | ❌ | ❌ | +| [geru](./docs/geru.md) | general rank-1 update | $A = A + \alpha x y^\top$ | $2mn$ | $2mn + m + n$ | [✅](./kernel_course/python_ops/geru.py) | ❌ | ❌ | ❌ | ❌ | | gemm | general matrix-matrix multiply | $C = \alpha A B + \beta C$ | $2mnk$ | $mk + nk + 2mn$ | ❌ | ❌ | ❌ | ❌ | ❌ |