TinyIR

TinyIR is a minimal intermediate representation and builder DSL for functional tensor computation, with multiple execution backends (NumPy, JAX, and PyTorch).

Install

Create a virtual environment, then install the dependencies (the commands below call the environment's pip directly, so activation is optional):

python3 -m venv .venv
.venv/bin/pip install -U pip setuptools wheel
.venv/bin/pip install -r requirements-cpu.txt
.venv/bin/pip install -e .

On GPU machines (CUDA 12.x), install the GPU requirements plus the vendor wheels:

.venv/bin/pip install -r requirements-gpu.txt
.venv/bin/pip install "jax[cuda12]==0.4.28" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
.venv/bin/pip install torch==2.2.2 --index-url https://download.pytorch.org/whl/cu121
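
To confirm the accelerator is visible after installation, you can query each framework directly (these are standard JAX and PyTorch calls):

.venv/bin/python -c "import jax; print(jax.devices())"
.venv/bin/python -c "import torch; print(torch.cuda.is_available())"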

Usage

Build a simple module with the builder:

import numpy as np
from tinyir import Builder, TensorType, run_module
import tinyir.backends.numpy_backend  # registers the backend

b = Builder()
x = b.input(TensorType((2, 3), "float32"), "x")
y = b.input(TensorType((2, 3), "float32"), "y")
z = b.add(x, y)
mod = b.build(z)

out = run_module(mod, {"x": np.ones((2, 3), dtype="float32"), "y": np.ones((2, 3), dtype="float32")}, backend="numpy")

To switch backends, import the backend module (which registers it) and select it by name:

import tinyir.backends.torch_backend  # registers the "torch" backend
import tinyir.backends.jax_backend    # registers the "jax" backend
from tinyir.backends.base import get_backend

torch_backend = get_backend("torch")
jax_backend = get_backend("jax")
fn = torch_backend.compile(mod)
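
Alternatively, run_module accepts the backend name directly once the corresponding backend module has been imported; this is the same call as in the first example, switching only the backend name:

out = run_module(mod, {"x": np.ones((2, 3), dtype="float32"), "y": np.ones((2, 3), dtype="float32")}, backend="torch")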

The examples/ directory contains complete runnable examples, including higher-order ops such as vmap and scan, and backend-specific demos (examples/torch_test.py, examples/torch_compile_demo.py).
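
For orientation, a vmap-style run might look roughly like the sketch below; the b.vmap call and its signature are assumptions made for illustration only (the real API is shown in examples/ and defined under tinyir/primitives/higher_order.py):

# Hypothetical sketch: b.vmap and its per-element-function argument are assumed,
# not taken from the actual TinyIR API.
b = Builder()
xs = b.input(TensorType((4, 3), "float32"), "xs")
ys = b.vmap(lambda row: b.add(row, row), xs)  # assumed signature
mod = b.build(ys)
out = run_module(mod, {"xs": np.ones((4, 3), dtype="float32")}, backend="numpy")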

Testing

Run the full test suite with pytest:

.venv/bin/python -m pytest

Tests are parameterized across backends (NumPy, JAX, and PyTorch) and cover higher-order ops such as vmap, scan, and reduce.
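
A backend-parameterized test might look like the following sketch (the exact fixtures and helpers used in tests/ may differ; this relies only on the public API shown above, and assumes run_module's result converts with np.asarray):

import numpy as np
import pytest
from tinyir import Builder, TensorType, run_module
import tinyir.backends.numpy_backend  # registers "numpy"
import tinyir.backends.jax_backend    # registers "jax"
import tinyir.backends.torch_backend  # registers "torch"

@pytest.mark.parametrize("backend", ["numpy", "jax", "torch"])
def test_add_matches_numpy(backend):
    b = Builder()
    x = b.input(TensorType((2, 3), "float32"), "x")
    y = b.input(TensorType((2, 3), "float32"), "y")
    mod = b.build(b.add(x, y))
    inputs = {"x": np.ones((2, 3), dtype="float32"), "y": np.ones((2, 3), dtype="float32")}
    out = run_module(mod, inputs, backend=backend)
    # Assumes the result is array-like (NumPy array, JAX array, or CPU torch tensor).
    np.testing.assert_allclose(np.asarray(out), np.full((2, 3), 2.0))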

Project Structure

tinyir-lang
├── examples/
├── tests/
├── tinyir/
│   ├── __init__.py             # public API, registry refresh, primitive registration
│   ├── builder.py              # Builder class for IR construction
│   ├── dsl.py                  # DSL wrappers and context manager
│   ├── interpreter.py          # run_module/run_fn utilities and interpreter
│   ├── ir.py                   # IR dataclasses (Module, Fn, Prim, Let, Types)
│   ├── schema_registry.py      # primitive registry and lowering lookup
│   ├── backends/
│   │   ├── __init__.py         # backend exports
│   │   ├── base.py             # backend registry + base interface
│   │   ├── eager_base.py       # eager execution helper backend
│   │   ├── jax_backend.py      # JAX backend implementation
│   │   ├── numpy_backend.py    # NumPy backend implementation
│   │   └── torch_backend.py    # PyTorch backend implementation
│   └── primitives/
│       ├── core.py             # core primitive specs and registrations
│       ├── elementwise.py      # elementwise primitive specs
│       └── higher_order.py     # vmap/scan/reduce/gather/scatter specs
├── requirements-cpu.txt        # pinned CPU dependencies
├── requirements-gpu.txt        # pinned GPU dependencies
├── pyproject.toml              # build metadata
└── README.md                   # this file :)
