Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 91 additions & 35 deletions src/sax/models/couplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,6 @@ def coupler(
)


@jax.jit
@validate_call
def grating_coupler(
*,
Expand All @@ -202,7 +201,7 @@ def grating_coupler(
reflection_fiber: sax.FloatArrayLike = 0.0,
bandwidth: sax.FloatArrayLike = 40e-3,
) -> sax.SDict:
"""Grating coupler model for fiber-chip coupling.
"""Grating coupler model for fiber-chip coupling (2-port reciprocal S-matrix).

```{svgbob}
out0
Expand All @@ -216,25 +215,24 @@ def grating_coupler(
```

Args:
wl: Operating wavelength in micrometers. Can be a scalar or array for
spectral analysis. Defaults to 1.55 μm.
wl0: Center wavelength in micrometers where peak transmission occurs.
This is the design wavelength of the grating. Defaults to 1.55 μm.
loss: Insertion loss in dB at the center wavelength. Includes coupling
efficiency losses, scattering, and mode mismatch. Defaults to 0.0 dB.
reflection: Reflection coefficient from the waveguide side (chip interface).
Represents reflections back into the waveguide from grating discontinuities.
Range: 0 to 1. Defaults to 0.0.
reflection_fiber: Reflection coefficient from the fiber side (top interface).
Represents reflections back toward the fiber from the grating surface.
Range: 0 to 1. Defaults to 0.0.
bandwidth: 3dB bandwidth in micrometers. Determines the spectral width
of the Gaussian transmission profile. Typical values: 20-50 nm.
Defaults to 40e-3 μm (40 nm).
wl: Operating wavelength(s) in micrometers. Scalar or array.
wl0: Center wavelength (μm) where peak transmission occurs.
loss: Insertion loss (dB) at wl0, applied to *power*. Must be in [0, 20].
Converted internally to an *amplitude* factor A0 = 10^(-loss/20).
reflection: **Amplitude** reflection coefficient seen from the waveguide
side (port in0). Must be in [0, 1]. (If you have reflection specified
as *power*, convert via sqrt first.)
reflection_fiber: **Amplitude** reflection coefficient seen from the
fiber side (port out0). Must be in [0, 1].
bandwidth: 3 dB bandwidth (FWHM) in micrometers (e.g., 40e-3 = 40 nm).

Returns:
The grating coupler s-matrix

Raises:
ValueError: If reflection or reflection_fiber is outside [0, 1],
or if loss is outside [0, 20] dB.

Examples:
Basic grating coupler:

Expand All @@ -260,13 +258,9 @@ def grating_coupler(
```

Note:
The transmission power profile follows a Gaussian shape:
P(λ) = P₀ * exp(-((λ-λ₀)/σ)²)

The amplitude transmission is:
A(λ) = A₀ * exp(-((λ-λ₀)/σ)²/2)

Where σ = bandwidth / (2*√(2*ln(2))) converts FWHM to Gaussian width.
Amplitude transmission has a Gaussian spectral envelope:
A(λ) = A0 * exp(-(λ-λ0)^2 / (4σ^2))
with σ = FWHM / (2*sqrt(2*ln(2))).

This model assumes:

Expand All @@ -277,19 +271,81 @@ def grating_coupler(
- No higher-order diffraction effects

"""
one = jnp.ones_like(wl)
reflection = jnp.asarray(reflection) * one
reflection_fiber = jnp.asarray(reflection_fiber) * one
amplitude = jnp.asarray(10 ** (-loss / 20)) * one
wl0_array = jnp.asarray(wl0) * one
sigma = jnp.asarray(bandwidth / (2 * jnp.sqrt(2 * jnp.log(2)))) * one
transmission = amplitude * jnp.exp(-((wl - wl0_array) ** 2) / (4 * sigma**2))
import numpy as np

# Validate reflection coefficients (amplitude, must be in [0, 1])
r_arr = np.asarray(reflection)
if np.any(r_arr < 0) or np.any(r_arr > 1):
raise ValueError(
f"reflection must be in [0, 1] (amplitude), got {reflection}. "
"If you have power reflection, use sqrt(R_power) instead."
)

r_fib_arr = np.asarray(reflection_fiber)
if np.any(r_fib_arr < 0) or np.any(r_fib_arr > 1):
raise ValueError(
f"reflection_fiber must be in [0, 1] (amplitude), got {reflection_fiber}. "
"If you have power reflection, use sqrt(R_power) instead."
)

# Validate loss (dB, typically 0-20 dB for grating couplers)
loss_arr = np.asarray(loss)
if np.any(loss_arr < 0) or np.any(loss_arr > 20):
raise ValueError(
f"loss must be in [0, 20] dB, got {loss}. "
"Note: loss is in dB (e.g., 3.0 for 3 dB loss), not linear."
)

return _grating_coupler_impl(
wl=wl,
wl0=wl0,
loss=loss,
reflection=reflection,
reflection_fiber=reflection_fiber,
bandwidth=bandwidth,
)


@jax.jit
def _grating_coupler_impl(
wl: sax.FloatArrayLike,
wl0: sax.FloatArrayLike,
loss: sax.FloatArrayLike,
reflection: sax.FloatArrayLike,
reflection_fiber: sax.FloatArrayLike,
bandwidth: sax.FloatArrayLike,
) -> sax.SDict:
"""JIT-compiled implementation of grating_coupler."""
wl = jnp.asarray(wl)
wl0 = jnp.asarray(wl0)
loss = jnp.asarray(loss)
reflection = jnp.asarray(reflection)
reflection_fiber = jnp.asarray(reflection_fiber)
bandwidth = jnp.asarray(bandwidth)

# Broadcast scalars/short arrays to wl shape
wl_shape = wl.shape
wl0_b = jnp.broadcast_to(wl0, wl_shape)
loss_b = jnp.broadcast_to(loss, wl_shape)
r_wg = jnp.broadcast_to(reflection, wl_shape)
r_fib = jnp.broadcast_to(reflection_fiber, wl_shape)
bw_b = jnp.broadcast_to(bandwidth, wl_shape)

# Constants + conversions
ln2 = jnp.log(jnp.asarray(2.0, dtype=wl.dtype))
sigma = bw_b / (2.0 * jnp.sqrt(2.0 * ln2)) # FWHM -> σ
a0 = 10.0 ** (-loss_b / 20.0) # dB (power) -> amplitude

# Gaussian amplitude envelope
d = wl - wl0_b
t = a0 * jnp.exp(-(d * d) / (4.0 * sigma * sigma))

p = sax.PortNamer(1, 1)
return sax.reciprocal(
{
(p.in0, p.in0): reflection,
(p.in0, p.out0): transmission,
(p.out0, p.in0): transmission,
(p.out0, p.out0): reflection_fiber,
(p.in0, p.in0): r_wg,
(p.in0, p.out0): t,
(p.out0, p.in0): t,
(p.out0, p.out0): r_fib,
}
)
Loading