TinyIR is a minimal intermediate representation and builder DSL for functional tensor computation, with multiple execution backends (NumPy, JAX, and PyTorch).
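Create a virtual environment, then install the CPU dependencies and TinyIR itself in editable mode. The commands call the virtual environment's pip directly, so activating it is optional: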
```sh
python3 -m venv .venv
.venv/bin/pip install -U pip setuptools wheel
.venv/bin/pip install -r requirements-cpu.txt
.venv/bin/pip install -e .
```

GPU machines (CUDA 12.x) can use the GPU requirements and the vendor CUDA wheels:
```sh
.venv/bin/pip install -r requirements-gpu.txt
.venv/bin/pip install "jax[cuda12]==0.4.28" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
.venv/bin/pip install torch==2.2.2 --index-url https://download.pytorch.org/whl/cu121
```

Build a simple module with the builder and run it on the NumPy backend:
```python
import numpy as np

from tinyir import Builder, TensorType, run_module
import tinyir.backends.numpy_backend  # noqa: F401  (importing registers the "numpy" backend)

b = Builder()
x = b.input(TensorType((2, 3), "float32"), "x")
y = b.input(TensorType((2, 3), "float32"), "y")
z = b.add(x, y)
mod = b.build(z)

# Match the declared float32 dtype (np.ones defaults to float64).
inputs = {"x": np.ones((2, 3), dtype=np.float32), "y": np.ones((2, 3), dtype=np.float32)}
out = run_module(mod, inputs, backend="numpy")
```

To switch backends, import the backend module (which registers it) and select it by name:
```python
import tinyir.backends.torch_backend  # noqa: F401  (registers "torch")
import tinyir.backends.jax_backend    # noqa: F401  (registers "jax")
from tinyir.backends.base import get_backend

torch_backend = get_backend("torch")
jax_backend = get_backend("jax")
fn = torch_backend.compile(mod)  # `mod` is the module built in the previous example
```

The examples/ directory contains complete runs, including higher-order ops such as vmap and scan, plus backend-specific demos (examples/torch_test.py and examples/torch_compile_demo.py).
Run the full test suite with pytest:

```sh
.venv/bin/python -m pytest
```

Tests are parameterized across the backends (NumPy, JAX, and PyTorch) and cover higher-order ops such as vmap, scan, and reduce.
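Cross-backend tests of higher-order ops usually compare each backend against a simple reference. Below is a hypothetical pure-Python reference for scan and vmap, assuming JAX-style semantics (scan threads a carry through a step function and collects per-step outputs; vmap maps a function over the leading axis of its arguments). The names `scan_ref` and `vmap_ref` are invented, and TinyIR's actual signatures and semantics may differ.

```python
from typing import Callable, List, Sequence, Tuple, TypeVar

C = TypeVar("C")  # carry type
X = TypeVar("X")  # per-step input type
Y = TypeVar("Y")  # per-step output type


def scan_ref(f: Callable[[C, X], Tuple[C, Y]], init: C, xs: Sequence[X]) -> Tuple[C, List[Y]]:
    """Reference scan: thread `carry` through `f`, collecting each step's output."""
    carry = init
    ys: List[Y] = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys


def vmap_ref(f: Callable[..., Y]) -> Callable[..., List[Y]]:
    """Reference vmap: apply `f` to matching slices along each argument's leading axis."""
    def batched(*args: Sequence) -> List[Y]:
        sizes = {len(a) for a in args}
        if len(sizes) > 1:
            raise ValueError(f"mismatched batch sizes: {sorted(sizes)}")
        return [f(*slices) for slices in zip(*args)]
    return batched
```

For example, a running sum is `scan_ref(lambda c, x: (c + x, c + x), 0, [1, 2, 3])`, which returns `(6, [1, 3, 6])`.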
Repository layout:

```
tinyir-lang
├── examples/
├── tests/
├── tinyir/
│   ├── __init__.py           # public API, registry refresh, primitive registration
│   ├── builder.py            # Builder class for IR construction
│   ├── dsl.py                # DSL wrappers and context manager
│   ├── interpreter.py        # run_module/run_fn utilities and interpreter
│   ├── ir.py                 # IR dataclasses (Module, Fn, Prim, Let, Types)
│   ├── schema_registry.py    # primitive registry and lowering lookup
│   ├── backends/
│   │   ├── __init__.py       # backend exports
│   │   ├── base.py           # backend registry + base interface
│   │   ├── eager_base.py     # eager execution helper backend
│   │   ├── jax_backend.py    # JAX backend implementation
│   │   ├── numpy_backend.py  # NumPy backend implementation
│   │   └── torch_backend.py  # PyTorch backend implementation
│   └── primitives/
│       ├── core.py           # core primitive specs and registrations
│       ├── elementwise.py    # elementwise primitive specs
│       └── higher_order.py   # vmap/scan/reduce/gather/scatter specs
├── requirements-cpu.txt      # pinned CPU dependencies
├── requirements-gpu.txt      # pinned GPU dependencies
├── pyproject.toml            # build metadata
└── README.md                 # this file :)
```
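As a rough mental model of how builder.py, ir.py and interpreter.py fit together, here is a deliberately tiny, hypothetical sketch: a builder that records each primitive call as a let-binding to a fresh name, and an interpreter that walks those bindings and dispatches each primitive through a lookup table. `ToyBuilder`, `ToyLet`, `ToyModule`, `Var`, `interpret` and `PY_OPS` are invented for illustration; TinyIR's real `Module`, `Fn`, `Prim` and `Let` carry types and richer structure, and primitive lookup presumably goes through schema_registry.py rather than a plain dict.

```python
import operator
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class ToyLet:
    """Hypothetical binding: `target = op(*args)`."""
    target: Var
    op: str
    args: Tuple[Var, ...]


@dataclass
class ToyModule:
    inputs: List[Var]
    body: List[ToyLet]
    result: Var


class ToyBuilder:
    """Records each primitive call as a ToyLet and returns the fresh Var it binds."""

    def __init__(self) -> None:
        self._inputs: List[Var] = []
        self._body: List[ToyLet] = []
        self._counter = 0

    def input(self, name: str) -> Var:
        v = Var(name)
        self._inputs.append(v)
        return v

    def _emit(self, op: str, *args: Var) -> Var:
        self._counter += 1
        v = Var(f"%{self._counter}")
        self._body.append(ToyLet(v, op, args))
        return v

    def add(self, a: Var, b: Var) -> Var:
        return self._emit("add", a, b)

    def mul(self, a: Var, b: Var) -> Var:
        return self._emit("mul", a, b)

    def build(self, result: Var) -> ToyModule:
        return ToyModule(list(self._inputs), list(self._body), result)


# A toy "backend": a table from primitive name to implementation.
PY_OPS: Dict[str, Callable] = {"add": operator.add, "mul": operator.mul}


def interpret(mod: ToyModule, env: Dict[str, object], ops: Dict[str, Callable]) -> object:
    """Evaluate bindings in order, looking each primitive up in `ops`."""
    values = {v: env[v.name] for v in mod.inputs}
    for let in mod.body:
        values[let.target] = ops[let.op](*(values[a] for a in let.args))
    return values[mod.result]


b = ToyBuilder()
x, y = b.input("x"), b.input("y")
mod = b.build(b.mul(b.add(x, y), y))
print(interpret(mod, {"x": 2, "y": 3}, PY_OPS))  # (2 + 3) * 3 = 15
```

In this sketch, binding every intermediate to a fresh name keeps the interpreter a single linear loop, and passing a different `ops` table is the toy equivalent of switching backends.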