diff --git a/src/sax/models/couplers.py b/src/sax/models/couplers.py index fc3f6c8..d3f3ad3 100644 --- a/src/sax/models/couplers.py +++ b/src/sax/models/couplers.py @@ -191,7 +191,6 @@ def coupler( ) -@jax.jit @validate_call def grating_coupler( *, @@ -202,7 +201,7 @@ def grating_coupler( reflection_fiber: sax.FloatArrayLike = 0.0, bandwidth: sax.FloatArrayLike = 40e-3, ) -> sax.SDict: - """Grating coupler model for fiber-chip coupling. + """Grating coupler model for fiber-chip coupling (2-port reciprocal S-matrix). ```{svgbob} out0 @@ -216,25 +215,24 @@ def grating_coupler( ``` Args: - wl: Operating wavelength in micrometers. Can be a scalar or array for - spectral analysis. Defaults to 1.55 μm. - wl0: Center wavelength in micrometers where peak transmission occurs. - This is the design wavelength of the grating. Defaults to 1.55 μm. - loss: Insertion loss in dB at the center wavelength. Includes coupling - efficiency losses, scattering, and mode mismatch. Defaults to 0.0 dB. - reflection: Reflection coefficient from the waveguide side (chip interface). - Represents reflections back into the waveguide from grating discontinuities. - Range: 0 to 1. Defaults to 0.0. - reflection_fiber: Reflection coefficient from the fiber side (top interface). - Represents reflections back toward the fiber from the grating surface. - Range: 0 to 1. Defaults to 0.0. - bandwidth: 3dB bandwidth in micrometers. Determines the spectral width - of the Gaussian transmission profile. Typical values: 20-50 nm. - Defaults to 40e-3 μm (40 nm). + wl: Operating wavelength(s) in micrometers. Scalar or array. + wl0: Center wavelength (μm) where peak transmission occurs. + loss: Insertion loss (dB) at wl0, applied to *power*. Must be in [0, 20]. + Converted internally to an *amplitude* factor A0 = 10^(-loss/20). + reflection: **Amplitude** reflection coefficient seen from the waveguide + side (port in0). Must be in [0, 1]. (If you have reflection specified + as *power*, convert via sqrt first.) + reflection_fiber: **Amplitude** reflection coefficient seen from the + fiber side (port out0). Must be in [0, 1]. + bandwidth: 3 dB bandwidth (FWHM) in micrometers (e.g., 40e-3 = 40 nm). Returns: The grating coupler s-matrix + Raises: + ValueError: If reflection or reflection_fiber is outside [0, 1], + or if loss is outside [0, 20] dB. + Examples: Basic grating coupler: @@ -260,13 +258,9 @@ def grating_coupler( ``` Note: - The transmission power profile follows a Gaussian shape: - P(λ) = P₀ * exp(-((λ-λ₀)/σ)²) - - The amplitude transmission is: - A(λ) = A₀ * exp(-((λ-λ₀)/σ)²/2) - - Where σ = bandwidth / (2*√(2*ln(2))) converts FWHM to Gaussian width. + Amplitude transmission has a Gaussian spectral envelope: + A(λ) = A0 * exp(-(λ-λ0)^2 / (4σ^2)) + with σ = FWHM / (2*sqrt(2*ln(2))). This model assumes: @@ -277,19 +271,81 @@ def grating_coupler( - No higher-order diffraction effects """ - one = jnp.ones_like(wl) - reflection = jnp.asarray(reflection) * one - reflection_fiber = jnp.asarray(reflection_fiber) * one - amplitude = jnp.asarray(10 ** (-loss / 20)) * one - wl0_array = jnp.asarray(wl0) * one - sigma = jnp.asarray(bandwidth / (2 * jnp.sqrt(2 * jnp.log(2)))) * one - transmission = amplitude * jnp.exp(-((wl - wl0_array) ** 2) / (4 * sigma**2)) + import numpy as np + + # Validate reflection coefficients (amplitude, must be in [0, 1]) + r_arr = np.asarray(reflection) + if np.any(r_arr < 0) or np.any(r_arr > 1): + raise ValueError( + f"reflection must be in [0, 1] (amplitude), got {reflection}. " + "If you have power reflection, use sqrt(R_power) instead." + ) + + r_fib_arr = np.asarray(reflection_fiber) + if np.any(r_fib_arr < 0) or np.any(r_fib_arr > 1): + raise ValueError( + f"reflection_fiber must be in [0, 1] (amplitude), got {reflection_fiber}. " + "If you have power reflection, use sqrt(R_power) instead." + ) + + # Validate loss (dB, typically 0-20 dB for grating couplers) + loss_arr = np.asarray(loss) + if np.any(loss_arr < 0) or np.any(loss_arr > 20): + raise ValueError( + f"loss must be in [0, 20] dB, got {loss}. " + "Note: loss is in dB (e.g., 3.0 for 3 dB loss), not linear." + ) + + return _grating_coupler_impl( + wl=wl, + wl0=wl0, + loss=loss, + reflection=reflection, + reflection_fiber=reflection_fiber, + bandwidth=bandwidth, + ) + + +@jax.jit +def _grating_coupler_impl( + wl: sax.FloatArrayLike, + wl0: sax.FloatArrayLike, + loss: sax.FloatArrayLike, + reflection: sax.FloatArrayLike, + reflection_fiber: sax.FloatArrayLike, + bandwidth: sax.FloatArrayLike, +) -> sax.SDict: + """JIT-compiled implementation of grating_coupler.""" + wl = jnp.asarray(wl) + wl0 = jnp.asarray(wl0) + loss = jnp.asarray(loss) + reflection = jnp.asarray(reflection) + reflection_fiber = jnp.asarray(reflection_fiber) + bandwidth = jnp.asarray(bandwidth) + + # Broadcast scalars/short arrays to wl shape + wl_shape = wl.shape + wl0_b = jnp.broadcast_to(wl0, wl_shape) + loss_b = jnp.broadcast_to(loss, wl_shape) + r_wg = jnp.broadcast_to(reflection, wl_shape) + r_fib = jnp.broadcast_to(reflection_fiber, wl_shape) + bw_b = jnp.broadcast_to(bandwidth, wl_shape) + + # Constants + conversions + ln2 = jnp.log(jnp.asarray(2.0, dtype=wl.dtype)) + sigma = bw_b / (2.0 * jnp.sqrt(2.0 * ln2)) # FWHM -> σ + a0 = 10.0 ** (-loss_b / 20.0) # dB (power) -> amplitude + + # Gaussian amplitude envelope + d = wl - wl0_b + t = a0 * jnp.exp(-(d * d) / (4.0 * sigma * sigma)) + p = sax.PortNamer(1, 1) return sax.reciprocal( { - (p.in0, p.in0): reflection, - (p.in0, p.out0): transmission, - (p.out0, p.in0): transmission, - (p.out0, p.out0): reflection_fiber, + (p.in0, p.in0): r_wg, + (p.in0, p.out0): t, + (p.out0, p.in0): t, + (p.out0, p.out0): r_fib, } )