Skip to content

Commit 653f121

Browse files
sbodensteinTorax team
authored andcommitted
Stricter typing for geometry.
PiperOrigin-RevId: 800410735
1 parent a3ca1bb commit 653f121

File tree

3 files changed

+87
-67
lines changed

3 files changed

+87
-67
lines changed

torax/_src/array_typing.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# ============================================================================
1515
"""Common types for using jaxtyping in TORAX."""
1616

17+
import dataclasses
1718
from typing import TypeAlias, TypeVar
1819
import jax
1920
import jaxtyping as jt
@@ -50,6 +51,20 @@ def jaxtyped(fn: T) -> T:
5051
The decorated function.
5152
"""
5253
runtime_checking = jax_utils.env_bool(name="TORAX_JAXTYPING", default=False)
54+
runtime_checking_dataclasses = jax_utils.env_bool(
55+
name="TORAX_JAXTYPING_DATACLASSES", default=False
56+
)
57+
58+
# Dataclasses are dangerous to decorate with jaxtyping because the shapes
59+
# are checked during __init__, which is called during PyTree unpflattening.
60+
# This can cause timeouts in large tests. Allow only specific tests to enable
61+
# this behavior.
62+
if dataclasses.is_dataclass(fn):
63+
if runtime_checking_dataclasses:
64+
return jt.jaxtyped(fn, typechecker=typeguard.typechecked)
65+
else:
66+
return fn
67+
5368
if runtime_checking:
5469
return jt.jaxtyped(fn, typechecker=typeguard.typechecked)
5570
else:

torax/_src/geometry/geometry.py

Lines changed: 67 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,16 @@
2020
import chex
2121
import jax
2222
import jax.numpy as jnp
23+
import jaxtyping as jt
2324
import numpy as np
24-
from torax._src import array_typing
25+
from torax._src import array_typing as at
2526
from torax._src.torax_pydantic import torax_pydantic
2627

2728

29+
@at.jaxtyped
2830
def face_to_cell(
29-
face: array_typing.FloatVectorFace,
30-
) -> array_typing.FloatVectorCell:
31+
face: jt.Float[at.Array, 'rhon'],
32+
) -> jt.Float[at.Array, 'rhon-1']:
3133
"""Infers cell values corresponding to a vector of face values.
3234
3335
Simply a linear interpolation between face values.
@@ -60,6 +62,7 @@ class GeometryType(enum.IntEnum):
6062
# pylint: disable=invalid-name
6163

6264

65+
@at.jaxtyped
6366
@jax.tree_util.register_dataclass
6467
@dataclasses.dataclass(frozen=True)
6568
class Geometry:
@@ -179,51 +182,51 @@ class Geometry:
179182
[:math:`\mathrm{m}`].
180183
"""
181184

182-
geometry_type: GeometryType
185+
geometry_type: GeometryType = dataclasses.field(metadata=dict(static=True))
183186
torax_mesh: torax_pydantic.Grid1D
184-
Phi: array_typing.Array
185-
Phi_face: array_typing.Array
186-
R_major: array_typing.FloatScalar
187-
a_minor: array_typing.FloatScalar
188-
B_0: array_typing.FloatScalar
189-
volume: array_typing.Array
190-
volume_face: array_typing.Array
191-
area: array_typing.Array
192-
area_face: array_typing.Array
193-
vpr: array_typing.Array
194-
vpr_face: array_typing.Array
195-
spr: array_typing.Array
196-
spr_face: array_typing.Array
197-
delta_face: array_typing.Array
198-
elongation: array_typing.Array
199-
elongation_face: array_typing.Array
200-
g0: array_typing.Array
201-
g0_face: array_typing.Array
202-
g1: array_typing.Array
203-
g1_face: array_typing.Array
204-
g2: array_typing.Array
205-
g2_face: array_typing.Array
206-
g3: array_typing.Array
207-
g3_face: array_typing.Array
208-
gm4: array_typing.Array
209-
gm4_face: array_typing.Array
210-
gm5: array_typing.Array
211-
gm5_face: array_typing.Array
212-
g2g3_over_rhon: array_typing.Array
213-
g2g3_over_rhon_face: array_typing.Array
214-
g2g3_over_rhon_hires: array_typing.Array
215-
F: array_typing.Array
216-
F_face: array_typing.Array
217-
F_hires: array_typing.Array
218-
R_in: array_typing.Array
219-
R_in_face: array_typing.Array
220-
R_out: array_typing.Array
221-
R_out_face: array_typing.Array
222-
spr_hires: array_typing.Array
223-
rho_hires_norm: array_typing.Array
224-
rho_hires: array_typing.Array
225-
Phi_b_dot: array_typing.FloatScalar
226-
_z_magnetic_axis: array_typing.FloatScalar | None
187+
Phi: jt.Float[at.Array, '*stack rhon']
188+
Phi_face: jt.Float[at.Array, '*stack rhon+1']
189+
R_major: jt.Float[at.Array | float, '*stack']
190+
a_minor: jt.Float[at.Array | float, '*stack']
191+
B_0: jt.Float[at.Array, '*stack']
192+
volume: jt.Float[at.Array, '*stack rhon']
193+
volume_face: jt.Float[at.Array, '*stack rhon+1']
194+
area: jt.Float[at.Array, '*stack rhon']
195+
area_face: jt.Float[at.Array, '*stack rhon+1']
196+
vpr: jt.Float[at.Array, '*stack rhon']
197+
vpr_face: jt.Float[at.Array, '*stack rhon+1']
198+
spr: jt.Float[at.Array, '*stack rhon']
199+
spr_face: jt.Float[at.Array, '*stack rhon+1']
200+
delta_face: jt.Float[at.Array, '*stack rhon+1']
201+
elongation: jt.Float[at.Array, '*stack rhon']
202+
elongation_face: jt.Float[at.Array, '*stack rhon+1']
203+
g0: jt.Float[at.Array, '*stack rhon']
204+
g0_face: jt.Float[at.Array, '*stack rhon+1']
205+
g1: jt.Float[at.Array, '*stack rhon']
206+
g1_face: jt.Float[at.Array, '*stack rhon+1']
207+
g2: jt.Float[at.Array, '*stack rhon']
208+
g2_face: jt.Float[at.Array, '*stack rhon+1']
209+
g3: jt.Float[at.Array, '*stack rhon']
210+
g3_face: jt.Float[at.Array, '*stack rhon+1']
211+
gm4: jt.Float[at.Array, '*stack rhon']
212+
gm4_face: jt.Float[at.Array, '*stack rhon+1']
213+
gm5: jt.Float[at.Array, '*stack rhon']
214+
gm5_face: jt.Float[at.Array, '*stack rhon+1']
215+
g2g3_over_rhon: jt.Float[at.Array, '*stack rhon']
216+
g2g3_over_rhon_face: jt.Float[at.Array, '*stack rhon+1']
217+
g2g3_over_rhon_hires: jt.Float[at.Array, '*stack rhon_hires']
218+
F: jt.Float[at.Array, '*stack rhon']
219+
F_face: jt.Float[at.Array, '*stack rhon+1']
220+
F_hires: jt.Float[at.Array, '*stack rhon_hires']
221+
R_in: jt.Float[at.Array, '*stack rhon']
222+
R_in_face: jt.Float[at.Array, '*stack rhon+1']
223+
R_out: jt.Float[at.Array, '*stack rhon']
224+
R_out_face: jt.Float[at.Array, '*stack rhon+1']
225+
spr_hires: jt.Float[at.Array, '*stack rhon_hires']
226+
rho_hires_norm: jt.Float[at.Array, '*stack rhon_hires']
227+
rho_hires: jt.Float[at.Array, '*stack rhon_hires']
228+
Phi_b_dot: jt.Float[at.Array | float, '*stack']
229+
_z_magnetic_axis: jt.Float[at.Array, '*stack'] | None
227230

228231
def __eq__(self, other: 'Geometry') -> bool:
229232
try:
@@ -245,27 +248,27 @@ def q_correction_factor(self) -> chex.Numeric:
245248
)
246249

247250
@property
248-
def rho_norm(self) -> array_typing.Array:
251+
def rho_norm(self) -> at.Array:
249252
r"""Normalized toroidal flux coordinate on cell grid [dimensionless]."""
250253
return self.torax_mesh.cell_centers
251254

252255
@property
253-
def rho_face_norm(self) -> array_typing.Array:
256+
def rho_face_norm(self) -> at.Array:
254257
r"""Normalized toroidal flux coordinate on face grid [dimensionless]."""
255258
return self.torax_mesh.face_centers
256259

257260
@property
258-
def drho_norm(self) -> array_typing.Array:
261+
def drho_norm(self) -> jax.Array:
259262
r"""Grid size for rho_norm [dimensionless]."""
260263
return jnp.array(self.torax_mesh.dx)
261264

262265
@property
263-
def rho_face(self) -> array_typing.Array:
266+
def rho_face(self) -> jax.Array:
264267
r"""Toroidal flux coordinate on face grid :math:`\mathrm{m}`."""
265268
return self.rho_face_norm * jnp.expand_dims(self.rho_b, axis=-1)
266269

267270
@property
268-
def rho(self) -> array_typing.Array:
271+
def rho(self) -> jax.Array:
269272
r"""Toroidal flux coordinate on cell grid :math:`\mathrm{m}`.
270273
271274
The toroidal flux coordinate is defined as
@@ -276,49 +279,49 @@ def rho(self) -> array_typing.Array:
276279
return self.rho_norm * jnp.expand_dims(self.rho_b, axis=-1)
277280

278281
@property
279-
def r_mid(self) -> array_typing.Array:
282+
def r_mid(self) -> at.Array:
280283
"""Midplane radius of the plasma [m], defined as (Rout-Rin)/2."""
281284
return (self.R_out - self.R_in) / 2
282285

283286
@property
284-
def r_mid_face(self) -> array_typing.Array:
287+
def r_mid_face(self) -> at.Array:
285288
"""Midplane radius of the plasma on the face grid [m]."""
286289
return (self.R_out_face - self.R_in_face) / 2
287290

288291
@property
289-
def epsilon(self) -> array_typing.Array:
292+
def epsilon(self) -> at.Array:
290293
"""Local midplane inverse aspect ratio [dimensionless]."""
291294
return (self.R_out - self.R_in) / (self.R_out + self.R_in)
292295

293296
@property
294-
def epsilon_face(self) -> array_typing.Array:
297+
def epsilon_face(self) -> at.Array:
295298
"""Local midplane inverse aspect ratio on the face grid [dimensionless]."""
296299
return (self.R_out_face - self.R_in_face) / (
297300
self.R_out_face + self.R_in_face
298301
)
299302

300303
@property
301-
def drho(self) -> array_typing.Array:
304+
def drho(self) -> at.Array:
302305
"""Grid size for rho [m]."""
303306
return self.drho_norm * self.rho_b
304307

305308
@property
306-
def rho_b(self) -> array_typing.FloatScalar:
309+
def rho_b(self) -> jax.Array:
307310
"""Toroidal flux coordinate [m] at boundary (LCFS)."""
308311
return jnp.sqrt(self.Phi_b / np.pi / self.B_0)
309312

310313
@property
311-
def Phi_b(self) -> array_typing.FloatScalar:
314+
def Phi_b(self) -> at.Array:
312315
r"""Toroidal flux at boundary (LCFS) :math:`\mathrm{Wb}`."""
313316
return self.Phi_face[..., -1]
314317

315318
@property
316-
def g1_over_vpr(self) -> array_typing.Array:
319+
def g1_over_vpr(self) -> at.Array:
317320
r"""g1/vpr [:math:`\mathrm{m}`]."""
318321
return self.g1 / self.vpr
319322

320323
@property
321-
def g1_over_vpr2(self) -> array_typing.Array:
324+
def g1_over_vpr2(self) -> at.Array:
322325
r"""g1/vpr**2 [:math:`\mathrm{m}^{-2}`]."""
323326
return self.g1 / self.vpr**2
324327

@@ -366,7 +369,7 @@ def gm9_face(self) -> jax.Array:
366369
[jnp.expand_dims(first_element, axis=-1), bulk], axis=-1
367370
)
368371

369-
def z_magnetic_axis(self) -> chex.Numeric:
372+
def z_magnetic_axis(self) -> at.Array:
370373
"""z position of magnetic axis [m]."""
371374
z_magnetic_axis = self._z_magnetic_axis
372375
if z_magnetic_axis is not None:
@@ -404,7 +407,7 @@ def stack_geometries(geometries: Sequence[Geometry]) -> Geometry:
404407
field_name = field.name
405408
field_value = getattr(first_geo, field_name)
406409
# Stack stackable fields. Save first geo's value for non-stackable fields.
407-
if isinstance(field_value, (array_typing.Array, array_typing.FloatScalar)):
410+
if isinstance(field_value, (at.Array, at.FloatScalar)):
408411
field_values = [getattr(geo, field_name) for geo in geometries]
409412
stacked_data[field_name] = np.stack(field_values)
410413
else:
@@ -416,7 +419,7 @@ def stack_geometries(geometries: Sequence[Geometry]) -> Geometry:
416419

417420
def update_geometries_with_Phibdot(
418421
*,
419-
dt: chex.Numeric,
422+
dt: at.FloatScalar,
420423
geo_t: Geometry,
421424
geo_t_plus_dt: Geometry,
422425
) -> tuple[Geometry, Geometry]:

torax/_src/geometry/tests/geometry_test.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,8 @@ def test_stack_geometries_error_handling_different_geometry_types(self):
151151
def test_update_phibdot(self):
152152
"""Test update_phibdot for circular geometries."""
153153
geo = geometry_pydantic_model.CircularConfig().build_geometry()
154-
geo0 = dataclasses.replace(geo, Phi_face=np.array([1.0]))
155-
geo1 = dataclasses.replace(geo, Phi_face=np.array([2.0]))
154+
geo0 = dataclasses.replace(geo, Phi_face=np.ones_like(geo.Phi_face))
155+
geo1 = dataclasses.replace(geo, Phi_face=np.full_like(geo.Phi_face, 2.0))
156156
geo0_updated, geo1_updated = geometry.update_geometries_with_Phibdot(
157157
dt=0.1, geo_t=geo0, geo_t_plus_dt=geo1
158158
)
@@ -166,7 +166,9 @@ def test_geometry_eq(self):
166166
self.assertEqual(geo1, geo2)
167167

168168
with self.subTest('different_geometries_are_not_equal'):
169-
geo3 = dataclasses.replace(geo1, Phi_face=np.array([2.0]))
169+
geo3 = dataclasses.replace(
170+
geo1, Phi_face=np.full_like(geo1.Phi_face, 2.0)
171+
)
170172
self.assertNotEqual(geo1, geo3)
171173

172174

0 commit comments

Comments
 (0)