## Summary

Adopt an architectural pattern where temporal processes sample all innovations (iid Normal noise) upfront, then apply deterministic, differentiable transforms to produce trajectories. This separates the stochastic and deterministic components.
## Motivation

From the discussion on PR #650, a pattern for time series in probabilistic models was proposed:
```
Layer 1: Sample ε[1..n] ~ iid Normal(0, 1)    ← all randomness upfront
Layer 2: trajectory = transform(ε, params)    ← deterministic, differentiable
```
AR(1), ARIMA(p,d,q), random walks, GPs, and renewal processes can all be defined as differentiable transforms stacked on top of a common noise-sampling primitive.
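As a minimal illustration of the two layers (a sketch, not project code; `random_walk` and its arguments are hypothetical), a random walk is just a cumulative-sum transform of iid innovations:

```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist


def random_walk(n_timepoints, init, sigma, name="rw"):
    # Layer 1: all randomness upfront, iid standard-normal innovations
    eps = numpyro.sample(
        f"{name}_innovations",
        dist.Normal(0, 1).expand([n_timepoints]).to_event(1),
    )
    # Layer 2: deterministic, differentiable transform ε → trajectory
    return init + sigma * jnp.cumsum(eps)
```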
## Benefits
- **HMC/NUTS inference** — Non-centered parameterization avoids funnel pathologies and yields cleaner gradients (the latent infection model becomes analogous to a neural network where we backpropagate toward the random inputs); see the sketch after this list
- **Compositionality** — Transforms are easier to write, test, and compose than transition functions with embedded sampling
- **Clearer residuals** — The innovations ε are the first thing sampled, making "what are the residuals?" unambiguous
- **Extensibility** — Swapping process types (AR → GP → HSGP) becomes a matter of changing the transform, not the sampling structure
- **Reduced innovation vectors** — For GPs, HSGP or inducing-point methods reduce the innovation vector length from n to m << n
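To make the first benefit concrete, here is a sketch (hypothetical names) contrasting a centered and a non-centered AR(1) step; only the second keeps all stochasticity in a unit-scale site:

```python
import numpyro
import numpyro.distributions as dist


def centered_step(x_prev, rho, sigma, t):
    # The latent state's scale depends on sigma, coupling state and
    # scale parameters (the classic funnel geometry under HMC/NUTS).
    return numpyro.sample(f"x_{t}", dist.Normal(rho * x_prev, sigma))


def noncentered_step(x_prev, rho, sigma, t):
    # Sample a unit-scale innovation, then shift/scale deterministically;
    # gradients flow back to eps without pathological curvature.
    eps = numpyro.sample(f"eps_{t}", dist.Normal(0, 1))
    return rho * x_prev + sigma * eps
```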
## Current State
| Component | Pattern | Notes |
|---|---|---|
| `RandomWalk._sample_vectorized` | ✅ Non-centered | Samples `increments_raw ~ N(0, 1)`, applies cumsum transform |
| `RandomWalk._sample_single` | ❌ Centered | Delegates to `PyRenewRandomWalk` |
| `AR1` | ❌ Centered | Delegates to `ARProcess`, which samples inside `scan` |
| `DifferencedAR1` | ❌ Centered | Delegates to `DifferencedProcess` |
| `pyrenew.process.ARProcess` | ❌ Centered | Samples noise inside `scan` transition function |
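For contrast with the table's ❌ rows, the centered pattern interleaves sampling with the transition inside `scan` (a schematic, not the actual `ARProcess` source):

```python
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan


def centered_ar1(n_timepoints, rho, sigma, init):
    def transition(x_prev, _):
        # Sampling inside the transition: stochastic and deterministic
        # parts are entangled, so the transform cannot be tested or
        # swapped independently of the sampling structure.
        x_t = numpyro.sample("x", dist.Normal(rho * x_prev, sigma))
        return x_t, x_t

    _, trajectory = scan(transition, init, None, length=n_timepoints)
    return trajectory
```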
## Proposed Pattern
```python
# Conceptual structure for AR(1)
import jax
import numpyro
import numpyro.distributions as dist
from jax.typing import ArrayLike


class AR1:
    def sample_innovations(self, n_timepoints, n_processes, name_prefix) -> ArrayLike:
        """Sample raw ε ~ N(0, 1) — all stochasticity here."""
        with numpyro.plate(f"{name_prefix}_time", n_timepoints):
            with numpyro.plate(f"{name_prefix}_proc", n_processes):
                return numpyro.sample(f"{name_prefix}_innovations", dist.Normal(0, 1))

    def transform(self, innovations, rho, sigma, init) -> ArrayLike:
        """Pure deterministic: ε → trajectory."""
        def step(x_prev, eps_t):
            x_t = rho * x_prev + sigma * eps_t
            return x_t, x_t

        _, trajectory = jax.lax.scan(step, init, innovations)
        return trajectory
```
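Usage might then look like this (a sketch; the seed handler and shapes are assumptions about how the class above would be called):

```python
import jax.numpy as jnp
from numpyro.handlers import seed

ar1 = AR1()
with seed(rng_seed=0):
    # Nested plates give shape (n_timepoints, n_processes); take one process
    eps = ar1.sample_innovations(n_timepoints=50, n_processes=1, name_prefix="rt")
trajectory = ar1.transform(eps[:, 0], rho=0.9, sigma=0.1, init=jnp.array(0.0))
```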
For renewal:

```python
# Pseudocode showing composition
rho = sample("rho", some_prior)
sigma = sample("sigma", some_prior)
epsilon = sample_innovations(n_timepoints)
log_rt = ar1_transform(epsilon, rho, sigma, init)
infections = renewal_transform(jnp.exp(log_rt), I0, generation_interval)
```
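A deterministic `renewal_transform` could be sketched as below (hypothetical helper; assumes `i0` holds one generation interval's worth of incidence history, most recent last, and that `gi_pmf[0]` is the probability of a one-step interval):

```python
import jax
import jax.numpy as jnp


def renewal_transform(rt, i0, gi_pmf):
    """Renewal equation I_t = R_t * sum_s I_{t-s} g(s), with no sampling."""
    def step(history, rt_t):
        # Reverse the pmf so the oldest incidence meets the longest lag
        i_t = rt_t * jnp.dot(history, gi_pmf[::-1])
        # Slide the fixed-length incidence window forward one step
        return jnp.append(history[1:], i_t), i_t

    _, infections = jax.lax.scan(step, i0, rt)
    return infections
```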
## Scope Options

### Option A: `pyrenew.latent.temporal_processes` only

*Make changes now as part of #650; in progress.*

- Refactor `AR1`, `DifferencedAR1`, and `RandomWalk` to bypass the `pyrenew.process` classes
- Self-contained change, no upstream dependencies
- Demonstrates the pattern for future adoption
### Option B: Upstream `pyrenew.process` refactor

- Add a `sample_innovations()` + `transform()` split to `ARProcess`, `DifferencedProcess`, etc.
- Enables the pattern across all pyrenew users
- Larger scope; requires coordination
### Option C: Full noise-first architecture

- All noise sampled at the model level (in `HierarchicalInfections.sample()`) and passed down to transforms; see the sketch after this list
- Most compositional, but the largest refactor
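A sketch of what Option C could look like, with hypothetical names and priors, reusing the `AR1.transform` and `renewal_transform` sketches above so that everything below the sample statements is deterministic:

```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist


def noise_first_model(n_timepoints, i0, gi_pmf):
    # All stochasticity at the top of the model
    rho = numpyro.sample("rho", dist.Beta(2, 2))
    sigma = numpyro.sample("sigma", dist.HalfNormal(0.2))
    eps = numpyro.sample(
        "eps", dist.Normal(0, 1).expand([n_timepoints]).to_event(1)
    )

    # Everything below is deterministic and differentiable
    log_rt = AR1().transform(eps, rho, sigma, init=jnp.array(0.0))
    infections = renewal_transform(jnp.exp(log_rt), i0, gi_pmf)
    return numpyro.deterministic("infections", infections)
```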
## Related

- PR #650 ("Model builder, latent infection process components and tutorials") — hierarchical infections implementation, where this feedback originated
- `pyrenew.process.ar.ARProcess` — current scan-based implementation with `LocScaleReparam`
## References

- Non-centered parameterization: Stan User's Guide, "Reparameterization"
- The pattern is standard in normalizing flows and VAEs, where all stochasticity is in the base distribution