2020import chex
2121import jax
2222import jax .numpy as jnp
23+ import jaxtyping as jt
2324import numpy as np
24- from torax ._src import array_typing
25+ from torax ._src import array_typing as at
2526from torax ._src .torax_pydantic import torax_pydantic
2627
2728
29+ @at .jaxtyped
2830def face_to_cell (
29- face : array_typing . FloatVectorFace ,
30- ) -> array_typing . FloatVectorCell :
31+ face : jt . Float [ at . Array , 'rhon' ] ,
32+ ) -> jt . Float [ at . Array , 'rhon-1' ] :
3133 """Infers cell values corresponding to a vector of face values.
3234
3335 Simply a linear interpolation between face values.
@@ -60,6 +62,7 @@ class GeometryType(enum.IntEnum):
6062# pylint: disable=invalid-name
6163
6264
65+ @at .jaxtyped
6366@jax .tree_util .register_dataclass
6467@dataclasses .dataclass (frozen = True )
6568class Geometry :
@@ -179,51 +182,51 @@ class Geometry:
179182 [:math:`\mathrm{m}`].
180183 """
181184
182- geometry_type : GeometryType
185+ geometry_type : GeometryType = dataclasses . field ( metadata = dict ( static = True ))
183186 torax_mesh : torax_pydantic .Grid1D
184- Phi : array_typing . Array
185- Phi_face : array_typing . Array
186- R_major : array_typing . FloatScalar
187- a_minor : array_typing . FloatScalar
188- B_0 : array_typing . FloatScalar
189- volume : array_typing . Array
190- volume_face : array_typing . Array
191- area : array_typing . Array
192- area_face : array_typing . Array
193- vpr : array_typing . Array
194- vpr_face : array_typing . Array
195- spr : array_typing . Array
196- spr_face : array_typing . Array
197- delta_face : array_typing . Array
198- elongation : array_typing . Array
199- elongation_face : array_typing . Array
200- g0 : array_typing . Array
201- g0_face : array_typing . Array
202- g1 : array_typing . Array
203- g1_face : array_typing . Array
204- g2 : array_typing . Array
205- g2_face : array_typing . Array
206- g3 : array_typing . Array
207- g3_face : array_typing . Array
208- gm4 : array_typing . Array
209- gm4_face : array_typing . Array
210- gm5 : array_typing . Array
211- gm5_face : array_typing . Array
212- g2g3_over_rhon : array_typing . Array
213- g2g3_over_rhon_face : array_typing . Array
214- g2g3_over_rhon_hires : array_typing . Array
215- F : array_typing . Array
216- F_face : array_typing . Array
217- F_hires : array_typing . Array
218- R_in : array_typing . Array
219- R_in_face : array_typing . Array
220- R_out : array_typing . Array
221- R_out_face : array_typing . Array
222- spr_hires : array_typing . Array
223- rho_hires_norm : array_typing . Array
224- rho_hires : array_typing . Array
225- Phi_b_dot : array_typing . FloatScalar
226- _z_magnetic_axis : array_typing . FloatScalar | None
187+ Phi : jt . Float [ at . Array , '*stack rhon' ]
188+ Phi_face : jt . Float [ at . Array , '*stack rhon+1' ]
189+ R_major : jt . Float [ at . Array | float , '*stack' ]
190+ a_minor : jt . Float [ at . Array | float , '*stack' ]
191+ B_0 : jt . Float [ at . Array , '*stack' ]
192+ volume : jt . Float [ at . Array , '*stack rhon' ]
193+ volume_face : jt . Float [ at . Array , '*stack rhon+1' ]
194+ area : jt . Float [ at . Array , '*stack rhon' ]
195+ area_face : jt . Float [ at . Array , '*stack rhon+1' ]
196+ vpr : jt . Float [ at . Array , '*stack rhon' ]
197+ vpr_face : jt . Float [ at . Array , '*stack rhon+1' ]
198+ spr : jt . Float [ at . Array , '*stack rhon' ]
199+ spr_face : jt . Float [ at . Array , '*stack rhon+1' ]
200+ delta_face : jt . Float [ at . Array , '*stack rhon+1' ]
201+ elongation : jt . Float [ at . Array , '*stack rhon' ]
202+ elongation_face : jt . Float [ at . Array , '*stack rhon+1' ]
203+ g0 : jt . Float [ at . Array , '*stack rhon' ]
204+ g0_face : jt . Float [ at . Array , '*stack rhon+1' ]
205+ g1 : jt . Float [ at . Array , '*stack rhon' ]
206+ g1_face : jt . Float [ at . Array , '*stack rhon+1' ]
207+ g2 : jt . Float [ at . Array , '*stack rhon' ]
208+ g2_face : jt . Float [ at . Array , '*stack rhon+1' ]
209+ g3 : jt . Float [ at . Array , '*stack rhon' ]
210+ g3_face : jt . Float [ at . Array , '*stack rhon+1' ]
211+ gm4 : jt . Float [ at . Array , '*stack rhon' ]
212+ gm4_face : jt . Float [ at . Array , '*stack rhon+1' ]
213+ gm5 : jt . Float [ at . Array , '*stack rhon' ]
214+ gm5_face : jt . Float [ at . Array , '*stack rhon+1' ]
215+ g2g3_over_rhon : jt . Float [ at . Array , '*stack rhon' ]
216+ g2g3_over_rhon_face : jt . Float [ at . Array , '*stack rhon+1' ]
217+ g2g3_over_rhon_hires : jt . Float [ at . Array , '*stack rhon_hires' ]
218+ F : jt . Float [ at . Array , '*stack rhon' ]
219+ F_face : jt . Float [ at . Array , '*stack rhon+1' ]
220+ F_hires : jt . Float [ at . Array , '*stack rhon_hires' ]
221+ R_in : jt . Float [ at . Array , '*stack rhon' ]
222+ R_in_face : jt . Float [ at . Array , '*stack rhon+1' ]
223+ R_out : jt . Float [ at . Array , '*stack rhon' ]
224+ R_out_face : jt . Float [ at . Array , '*stack rhon+1' ]
225+ spr_hires : jt . Float [ at . Array , '*stack rhon_hires' ]
226+ rho_hires_norm : jt . Float [ at . Array , '*stack rhon_hires' ]
227+ rho_hires : jt . Float [ at . Array , '*stack rhon_hires' ]
228+ Phi_b_dot : jt . Float [ at . Array | float , '*stack' ]
229+ _z_magnetic_axis : jt . Float [ at . Array , '*stack' ] | None
227230
228231 def __eq__ (self , other : 'Geometry' ) -> bool :
229232 try :
@@ -245,27 +248,27 @@ def q_correction_factor(self) -> chex.Numeric:
245248 )
246249
247250 @property
248- def rho_norm (self ) -> array_typing .Array :
251+ def rho_norm (self ) -> at .Array :
249252 r"""Normalized toroidal flux coordinate on cell grid [dimensionless]."""
250253 return self .torax_mesh .cell_centers
251254
252255 @property
253- def rho_face_norm (self ) -> array_typing .Array :
256+ def rho_face_norm (self ) -> at .Array :
254257 r"""Normalized toroidal flux coordinate on face grid [dimensionless]."""
255258 return self .torax_mesh .face_centers
256259
257260 @property
258- def drho_norm (self ) -> array_typing .Array :
261+ def drho_norm (self ) -> jax .Array :
259262 r"""Grid size for rho_norm [dimensionless]."""
260263 return jnp .array (self .torax_mesh .dx )
261264
262265 @property
263- def rho_face (self ) -> array_typing .Array :
266+ def rho_face (self ) -> jax .Array :
264267 r"""Toroidal flux coordinate on face grid :math:`\mathrm{m}`."""
265268 return self .rho_face_norm * jnp .expand_dims (self .rho_b , axis = - 1 )
266269
267270 @property
268- def rho (self ) -> array_typing .Array :
271+ def rho (self ) -> jax .Array :
269272 r"""Toroidal flux coordinate on cell grid :math:`\mathrm{m}`.
270273
271274 The toroidal flux coordinate is defined as
@@ -276,49 +279,49 @@ def rho(self) -> array_typing.Array:
276279 return self .rho_norm * jnp .expand_dims (self .rho_b , axis = - 1 )
277280
278281 @property
279- def r_mid (self ) -> array_typing .Array :
282+ def r_mid (self ) -> at .Array :
280283 """Midplane radius of the plasma [m], defined as (Rout-Rin)/2."""
281284 return (self .R_out - self .R_in ) / 2
282285
283286 @property
284- def r_mid_face (self ) -> array_typing .Array :
287+ def r_mid_face (self ) -> at .Array :
285288 """Midplane radius of the plasma on the face grid [m]."""
286289 return (self .R_out_face - self .R_in_face ) / 2
287290
288291 @property
289- def epsilon (self ) -> array_typing .Array :
292+ def epsilon (self ) -> at .Array :
290293 """Local midplane inverse aspect ratio [dimensionless]."""
291294 return (self .R_out - self .R_in ) / (self .R_out + self .R_in )
292295
293296 @property
294- def epsilon_face (self ) -> array_typing .Array :
297+ def epsilon_face (self ) -> at .Array :
295298 """Local midplane inverse aspect ratio on the face grid [dimensionless]."""
296299 return (self .R_out_face - self .R_in_face ) / (
297300 self .R_out_face + self .R_in_face
298301 )
299302
300303 @property
301- def drho (self ) -> array_typing .Array :
304+ def drho (self ) -> at .Array :
302305 """Grid size for rho [m]."""
303306 return self .drho_norm * self .rho_b
304307
305308 @property
306- def rho_b (self ) -> array_typing . FloatScalar :
309+ def rho_b (self ) -> jax . Array :
307310 """Toroidal flux coordinate [m] at boundary (LCFS)."""
308311 return jnp .sqrt (self .Phi_b / np .pi / self .B_0 )
309312
310313 @property
311- def Phi_b (self ) -> array_typing . FloatScalar :
314+ def Phi_b (self ) -> at . Array :
312315 r"""Toroidal flux at boundary (LCFS) :math:`\mathrm{Wb}`."""
313316 return self .Phi_face [..., - 1 ]
314317
315318 @property
316- def g1_over_vpr (self ) -> array_typing .Array :
319+ def g1_over_vpr (self ) -> at .Array :
317320 r"""g1/vpr [:math:`\mathrm{m}`]."""
318321 return self .g1 / self .vpr
319322
320323 @property
321- def g1_over_vpr2 (self ) -> array_typing .Array :
324+ def g1_over_vpr2 (self ) -> at .Array :
322325 r"""g1/vpr**2 [:math:`\mathrm{m}^{-2}`]."""
323326 return self .g1 / self .vpr ** 2
324327
@@ -366,7 +369,7 @@ def gm9_face(self) -> jax.Array:
366369 [jnp .expand_dims (first_element , axis = - 1 ), bulk ], axis = - 1
367370 )
368371
369- def z_magnetic_axis (self ) -> chex . Numeric :
372+ def z_magnetic_axis (self ) -> at . Array :
370373 """z position of magnetic axis [m]."""
371374 z_magnetic_axis = self ._z_magnetic_axis
372375 if z_magnetic_axis is not None :
@@ -404,7 +407,7 @@ def stack_geometries(geometries: Sequence[Geometry]) -> Geometry:
404407 field_name = field .name
405408 field_value = getattr (first_geo , field_name )
406409 # Stack stackable fields. Save first geo's value for non-stackable fields.
407- if isinstance (field_value , (array_typing .Array , array_typing .FloatScalar )):
410+ if isinstance (field_value , (at .Array , at .FloatScalar )):
408411 field_values = [getattr (geo , field_name ) for geo in geometries ]
409412 stacked_data [field_name ] = np .stack (field_values )
410413 else :
@@ -416,7 +419,7 @@ def stack_geometries(geometries: Sequence[Geometry]) -> Geometry:
416419
417420def update_geometries_with_Phibdot (
418421 * ,
419- dt : chex . Numeric ,
422+ dt : at . FloatScalar ,
420423 geo_t : Geometry ,
421424 geo_t_plus_dt : Geometry ,
422425) -> tuple [Geometry , Geometry ]:
0 commit comments