
Refactor temporal processes to use innovations-first (non-centered) parameterization pattern #672

@cdc-mitzimorris

Description


Summary

Adopt an architectural pattern where temporal processes sample all innovations (iid Normal noise) upfront, then apply deterministic differentiable transforms to produce trajectories. This separates the stochastic and deterministic components.

Motivation

Discussion on PR #650 proposed the following pattern for time series in probabilistic models:

Layer 1: Sample ε[1..n] ~ iid Normal(0, 1)     ← all randomness upfront
Layer 2: trajectory = transform(ε, params)     ← deterministic, differentiable

AR(1), ARIMA(p,d,q), random walks, GPs, and renewal processes can all be defined as differentiable transforms stacked on top of a common noise-sampling primitive.
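
As a minimal sketch of that shared primitive (illustrative names, not pyrenew's API), the same standard-normal innovation vector can drive different processes purely through the choice of transform:

import jax.numpy as jnp

def random_walk_transform(eps, sigma, init):
    # Random walk: cumulative sum of scaled innovations.
    return init + jnp.cumsum(sigma * eps)

def gp_transform(eps, cov):
    # Zero-mean GP: rotate innovations by a Cholesky factor of the covariance.
    return jnp.linalg.cholesky(cov) @ eps

Because both are pure functions of the innovations, gradients flow from the trajectory back to ε, and the sampling statement never changes when the process type does.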

Benefits

  1. HMC/NUTS inference — Non-centered parameterization avoids funnel pathologies and provides cleaner gradients (the latent infection model becomes analogous to a neural network where we backpropagate toward random inputs); see the sketch after this list

  2. Compositionality — Transforms are easier to write, test, and compose than transition functions with embedded sampling

  3. Clearer residuals — The innovations ε are the first thing sampled, making "what are the residuals?" unambiguous

  4. Extensibility — Swapping process types (AR → GP → HSGP) becomes a matter of changing the transform, not the sampling structure

  5. Reduced innovation vectors — For GPs, HSGP or inducing point methods reduce the innovation vector length from n to m << n
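
The sketch below illustrates benefit 1 with a random walk whose step size has a hyperprior; it assumes plain numpyro and is not pyrenew code.

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def centered_walk(n_timepoints):
    sigma = numpyro.sample("sigma", dist.HalfNormal(1.0))
    # Each latent step depends on sigma directly -> funnel geometry
    # when the posterior pulls sigma toward zero.
    steps = numpyro.sample(
        "steps", dist.Normal(0.0, sigma).expand([n_timepoints]).to_event(1)
    )
    return jnp.cumsum(steps)

def noncentered_walk(n_timepoints):
    sigma = numpyro.sample("sigma", dist.HalfNormal(1.0))
    # Innovations are a priori independent of sigma; the scale enters only
    # through a deterministic transform, which HMC/NUTS handles far better.
    eps = numpyro.sample(
        "eps", dist.Normal(0.0, 1.0).expand([n_timepoints]).to_event(1)
    )
    return jnp.cumsum(sigma * eps)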

Current State

Component | Pattern | Notes
--- | --- | ---
RandomWalk._sample_vectorized | ✅ Non-centered | Samples increments_raw ~ N(0, 1), applies cumsum transform
RandomWalk._sample_single | ❌ Centered | Delegates to PyRenewRandomWalk
AR1 | ❌ Centered | Delegates to ARProcess, which samples inside scan
DifferencedAR1 | ❌ Centered | Delegates to DifferencedProcess
pyrenew.process.ARProcess | ❌ Centered | Samples noise inside scan transition function
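
For contrast, the "centered, samples inside scan" shape in the table looks roughly like the sketch below (illustrative only, not the actual ARProcess implementation); the proposed pattern in the next section moves the sample statement out of the scan body.

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan

def centered_ar1(n_timepoints, rho, sigma, init):
    def step(x_prev, _):
        # Sampling happens inside the transition function: the latent state
        # itself is the sampled quantity, conditioned on the previous state.
        x_t = numpyro.sample("x", dist.Normal(rho * x_prev, sigma))
        return x_t, x_t

    _, trajectory = scan(step, init, jnp.arange(n_timepoints))
    return trajectory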

Proposed Pattern

# Conceptual structure for AR(1)
import jax
import numpyro
import numpyro.distributions as dist
from jax.typing import ArrayLike


class AR1:
    def sample_innovations(self, n_timepoints, n_processes, name_prefix) -> ArrayLike:
        """Sample raw ε ~ N(0, 1) — all stochasticity lives here."""
        # Outer plate over time, inner plate over processes, so the draw has
        # time on the leading axis and one column per process.
        with numpyro.plate(f"{name_prefix}_time", n_timepoints):
            with numpyro.plate(f"{name_prefix}_proc", n_processes):
                return numpyro.sample(f"{name_prefix}_innovations", dist.Normal(0, 1))

    def transform(self, innovations, rho, sigma, init) -> ArrayLike:
        """Pure deterministic map: ε → trajectory (no sample statements)."""
        def step(x_prev, eps_t):
            x_t = rho * x_prev + sigma * eps_t
            return x_t, x_t

        # Scan over the leading (time) axis; init is the starting state.
        _, trajectory = jax.lax.scan(step, init, innovations)
        return trajectory
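
A possible usage sketch for the class above (the seed handler is only there to make the draw concrete outside of inference):

import jax.numpy as jnp
import numpyro

ar1 = AR1()
with numpyro.handlers.seed(rng_seed=0):
    eps = ar1.sample_innovations(n_timepoints=50, n_processes=3, name_prefix="rt")

# eps has time on the leading axis and one column per process;
# the transform is an ordinary pure function of it.
trajectory = ar1.transform(eps, rho=0.9, sigma=0.2, init=jnp.zeros(3))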

For renewal:

# Pseudocode showing composition
rho = sample("rho", some_prior)
sigma = sample("sigma", some_prior)
epsilon = sample_innovations(n_timepoints)

log_rt = ar1_transform(epsilon, rho, sigma, init)
infections = renewal_transform(jnp.exp(log_rt), I0, generation_interval)
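
One possible shape for renewal_transform, assuming the standard discrete renewal equation I[t] = Rt[t] * sum_s I[t-s] * g[s]; this is a sketch, not pyrenew's existing infection machinery:

import jax
import jax.numpy as jnp

def renewal_transform(rt, i0, generation_interval):
    """Deterministic map: Rt trajectory + recent infection history -> infections.

    rt: (n_timepoints,) reproduction numbers
    i0: (s_max,) recent infections, most recent last
    generation_interval: (s_max,) pmf, entry 0 = lag of one time step
    """
    gi_reversed = generation_interval[::-1]

    def step(history, rt_t):
        # I[t] = Rt[t] * sum_s I[t - s] * g[s]
        infections_t = rt_t * jnp.dot(history, gi_reversed)
        history = jnp.concatenate([history[1:], infections_t[None]])
        return history, infections_t

    _, infections = jax.lax.scan(step, i0, rt)
    return infections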

Scope Options

Option A: pyrenew.latent.temporal_processes only

Make these changes now as part of #650, which is in progress.

  • Refactor AR1, DifferencedAR1, RandomWalk to bypass pyrenew.process classes
  • Self-contained change, no upstream dependencies
  • Demonstrates pattern for future adoption

Option B: Upstream pyrenew.process refactor

  • Add sample_innovations() + transform() split to ARProcess, DifferencedProcess, etc.
  • Enables pattern across all pyrenew users
  • Larger scope, requires coordination

Option C: Full noise-first architecture

  • All noise sampled at model level (in HierarchicalInfections.sample())
  • Passed down to transforms
  • Most compositional but largest refactor
