Merged
2 changes: 1 addition & 1 deletion README.md
@@ -23,7 +23,7 @@ The following common BLAS kernels have been implemented in multiple frameworks.
| [axpby](./docs/axpby.md) | update vector| $y = \alpha x + \beta y$ | $3n$ | $3n$ | [✅](./kernel_course/python_ops/axpby.py) | [✅](./kernel_course/pytorch_ops/axpby.py) | [✅](./kernel_course/triton_ops/axpby.py) | ❌ | [✅](./tests/test_axpby.py) |
| [dot](./docs/dot.md) | dot product | $z = x^\top y$ | $2n$ | $2n$ | [✅](./kernel_course/python_ops/dot.py) | [✅](./kernel_course/pytorch_ops/dot.py) | [✅](./kernel_course/triton_ops/dot.py) | ❌ | [✅](./tests/test_dot.py) |
| [gemv](./docs/gemv.md) | general matrix-vector multiply | $y = \alpha A x + \beta y$ | $2mn$ | $mn + n + 2m$ | [✅](./kernel_course/python_ops/gemv.py) | [✅](./kernel_course/pytorch_ops/gemv.py) | [✅](./kernel_course/triton_ops/gemv.py) | ❌ | [✅](./tests/test_gemv.py) |
| geru | general rank-1 update | $A = A + \alpha x y^\top$ | $2mn$ | $2mn + m + n$ | | ❌ | ❌ | ❌ | ❌ |
| [geru](./docs/geru.md) | general rank-1 update | $A = A + \alpha x y^\top$ | $2mn$ | $2mn + m + n$ | [✅](./kernel_course/python_ops/geru.py) | ❌ | ❌ | ❌ | ❌ |
| gemm | general matrix-matrix multiply | $C = \alpha A B + \beta C$ | $2mnk$ | $mk + nk + 2mn$ | ❌ | ❌ | ❌ | ❌ | ❌ |


35 changes: 35 additions & 0 deletions docs/geru.md
@@ -0,0 +1,35 @@
# GERU Kernel

The `geru` operator computes the outer product of two vectors and adds the result to a matrix.

## Mathematical Definition

Given an input matrix `A` and input vectors `x` and `y`, along with a scalar `α`, the kernel evaluates

$$
A = A + \alpha x y^\top
$$

The update forms the outer product $x y^\top$, scales it by `α`, and adds the result to `A`, producing the updated matrix. For an $m \times n$ matrix `A`, `x` has length $m$ and `y` has length $n$, and entry $(i, j)$ receives $\alpha x_i y_j$.
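The entrywise form of the update can be sketched in plain Python (a minimal illustration with made-up values, not one of the course backends):

```python
# Rank-1 update A = A + alpha * x y^T, written as explicit loops.
m, n = 2, 3
A = [[1.0] * n for _ in range(m)]
x = [1.0, 2.0]
y = [3.0, 4.0, 5.0]
alpha = 0.5

for i in range(m):
    for j in range(n):
        # Entry (i, j) receives alpha * x_i * y_j.
        A[i][j] += alpha * x[i] * y[j]

# A is now [[2.5, 3.0, 3.5], [4.0, 5.0, 6.0]]
```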

## Kernel Implementations

- [Python Implementation](../kernel_course/python_ops/geru.py)
- [PyTorch Implementation](../kernel_course/pytorch_ops/geru.py)
- [Triton Implementation](../kernel_course/triton_ops/geru.py)
- [CuTe Implementation](../kernel_course/cute_ops/geru.py)

All backends share the interface:

```python
def geru(A: torch.Tensor, x: torch.Tensor, y: torch.Tensor, alpha: float) -> torch.Tensor:
...
```

## Testing

See the [test suite](../tests/test_geru.py) for the validation harness that exercises every backend.

```bash
pytest tests/test_geru.py -s
```
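Independently of the test suite, any backend can be sanity-checked against PyTorch's built-in rank-1 update `torch.addr`, which computes `beta * A + alpha * outer(x, y)`. A quick sketch (assuming PyTorch is installed; shapes are illustrative):

```python
import torch

A = torch.randn(4, 5)
x = torch.randn(4)
y = torch.randn(5)
alpha = 0.5

# Reference result from PyTorch's built-in rank-1 update.
expected = torch.addr(A, x, y, beta=1.0, alpha=alpha)

# The same update via broadcasting, as the course kernels compute it.
manual = A + alpha * x[:, None] * y[None, :]

assert torch.allclose(expected, manual)
```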
25 changes: 25 additions & 0 deletions kernel_course/python_ops/geru.py
@@ -0,0 +1,25 @@
import torch


def geru(
    A: torch.Tensor,
    x: torch.Tensor,
    y: torch.Tensor,
    alpha: float,
) -> torch.Tensor:
    """
    Update tensor `A` by adding the outer product of vectors `x` and `y` scaled by `alpha`.

    Args:
        A (torch.Tensor): Matrix of shape (m, n) to be updated.
        x (torch.Tensor): Vector of length m.
        y (torch.Tensor): Vector of length n.
        alpha (float): Scaling factor for the outer product of `x` and `y`.

    Returns:
        torch.Tensor: The updated tensor `A`.
    """
    # Broadcast x to a column (m, 1) and y to a row (1, n);
    # their elementwise product is the (m, n) outer product.
    return A + alpha * x[:, None] * y[None, :]
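To see the broadcasting at work, here is a small worked example (the function is redefined inline so the snippet is self-contained; the values are illustrative):

```python
import torch


def geru(A, x, y, alpha):
    # Same broadcasting update as the implementation above.
    return A + alpha * x[:, None] * y[None, :]


A = torch.ones(2, 2)
x = torch.tensor([1.0, 2.0])
y = torch.tensor([10.0, 20.0])

out = geru(A, x, y, 0.1)
# out is [[2.0, 3.0], [3.0, 5.0]]: entry (i, j) is 1 + 0.1 * x[i] * y[j]
```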