19 commits
8745833  Add Custom Pedestal Model API for user-defined scaling laws (vector-one, Nov 5, 2025)
7bba2ab  Refactor to use transport model registration pattern (vector-one, Nov 16, 2025)
2252755  Merge branch 'google-deepmind:main' into main (vector-one, Nov 16, 2025)
29f5ec8  Address PR review feedback (vector-one, Nov 17, 2025)
e353edd  Address second round of PR review feedback (vector-one, Nov 17, 2025)
fc93e9f  Fix pedestal model tests to use ModelConfig API and add pedestal modu… (vector-one, Nov 28, 2025)
fcc33aa  Fix test failures in register_model_test.py (vector-one, Dec 3, 2025)
c7b9d29  Merge branch 'main' into custom-pedestal-api (vector-one, Dec 3, 2025)
4e4668c  Fix test failures: handle JAX tracers, add custom_pedestal_example co… (vector-one, Jan 2, 2026)
c605609  Merge branch 'main' into custom-pedestal-api (vector-one, Jan 2, 2026)
5b7be89  Fix test failures: syntax errors, Pydantic discriminator conflicts, a… (vector-one, Jan 28, 2026)
1f1e884  Merge branch 'temp' into main (vector-one, Jan 28, 2026)
157836f  Merge pull request #1 from Aaryan-549/main (vector-one, Jan 28, 2026)
cc1e7c2  Merge pull request #2 from Aaryan-549/temp (vector-one, Jan 28, 2026)
c4840e7  fixed: transport_model_test.py & sim_time_dependence_test.py (vector-one, Jan 28, 2026)
b91c94e  Merge pull request #3 from Aaryan-549/temp (vector-one, Jan 28, 2026)
82d8954  removed pycache & added it in gitignore (vector-one, Jan 28, 2026)
ae20c85  Remove __pycache__ files from git tracking (vector-one, Jan 28, 2026)
3d772c5  removed formatting changes (vector-one, Jan 29, 2026)
docs/custom_pedestal_models.rst (120 additions, 0 deletions)
@@ -0,0 +1,120 @@
.. _custom-pedestal-models:

Custom Pedestal Models
######################

TORAX allows you to define custom pedestal scaling laws without modifying the source code.
This is useful for machine-specific models like those used for STEP.

Quick Start
===========

Follow these four steps to create and use a custom pedestal model:

1. Define Your JAX Pedestal Model
----------------------------------

Create a class that inherits from ``PedestalModel``:

.. code-block:: python

   import dataclasses
   import jax.numpy as jnp
   from torax import Geometry
   from torax import CoreProfiles
   from torax import pedestal

   @dataclasses.dataclass(frozen=True)
   class MyPedestalModel(pedestal.PedestalModel):
     """My custom pedestal model with EPED-like scaling."""

     def _call_implementation(
         self,
         runtime_params: pedestal.RuntimeParams,
         geo: Geometry,
         core_profiles: CoreProfiles,
     ) -> pedestal.PedestalModelOutput:
       # Extract plasma parameters
       Ip_MA = runtime_params.profile_conditions.Ip / 1e6
       B0 = geo.B0

       # Your custom scaling laws
       T_e_ped = 0.5 * (Ip_MA ** 0.2) * (B0 ** 0.8)
       T_i_ped = 1.2 * T_e_ped
       n_e_ped = 0.7e20
       rho_norm_ped_top = 0.91

       # Find mesh index
       rho_norm_ped_top_idx = jnp.argmin(
           jnp.abs(geo.rho_norm - rho_norm_ped_top)
       )

       return pedestal.PedestalModelOutput(
           rho_norm_ped_top=rho_norm_ped_top,
           rho_norm_ped_top_idx=rho_norm_ped_top_idx,
           T_i_ped=T_i_ped,
           T_e_ped=T_e_ped,
           n_e_ped=n_e_ped,
       )

2. Define Your Pydantic Configuration
--------------------------------------

.. code-block:: python

   from typing import Annotated, Literal
   from torax import JAX_STATIC

   class MyPedestal(pedestal.BasePedestal):
     """Configuration for my custom pedestal model."""

     model_name: Annotated[
         Literal['my_pedestal'],
         JAX_STATIC
     ] = 'my_pedestal'

     def build_pedestal_model(self) -> MyPedestalModel:
       return MyPedestalModel()

     def build_runtime_params(self, t) -> pedestal.RuntimeParams:
       return pedestal.RuntimeParams(
           set_pedestal=self.set_pedestal.get_value(t),
       )
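
If your model needs tunable inputs, any extra Pydantic fields on the config
class become configurable under the ``pedestal`` section of the config dict.
A minimal sketch (``scaling_factor`` is an illustrative parameter, and it
assumes a matching dataclass field is added to ``MyPedestalModel`` from step 1):

.. code-block:: python

   class MyPedestal(pedestal.BasePedestal):
     model_name: Annotated[Literal['my_pedestal'], JAX_STATIC] = 'my_pedestal'
     # Illustrative knob, settable from the config dict as
     # {'pedestal': {'model_name': 'my_pedestal', 'scaling_factor': 1.5}}.
     scaling_factor: float = 1.0

     def build_pedestal_model(self) -> MyPedestalModel:
       # Forward the configured value into the JAX model at build time.
       return MyPedestalModel(scaling_factor=self.scaling_factor)

     # build_runtime_params is unchanged from step 2.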

3. Register Your Model
----------------------

.. code-block:: python

   pedestal.register_pedestal_model(MyPedestal)
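
The call must run before TORAX parses your ``CONFIG`` dict, because registration
extends the set of pedestal models the config accepts. Placing it at import time
in your config module is one way to guarantee that ordering (a sketch; the module
names are hypothetical):

.. code-block:: python

   # my_config.py -- hypothetical user config module
   from torax import pedestal
   from my_pedestal import MyPedestal  # hypothetical module holding the step 1-2 code

   # Runs at import time, before TORAX validates the CONFIG dict defined
   # later in this module.
   pedestal.register_pedestal_model(MyPedestal)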

4. Use in Configuration
-----------------------

Now use the model in your simulation config. Note that this is a minimal example
showing only the pedestal section; see the full example below for a complete
runnable config.

.. code-block:: python

   CONFIG = {
       # ... other config sections ...
       'pedestal': {
           'model_name': 'my_pedestal',
           'set_pedestal': True,
       },
   }

Example
=======

See ``torax/examples/custom_pedestal_example.py`` for a complete working example
with EPED-like scaling that can be run directly.

Key Points
==========

* ``PedestalModel`` already inherits from ``StaticDataclass``, so do not inherit from it twice
* Use the public API (``from torax import ...``), not ``torax._src``
* Models must be JAX-compatible (use ``jax.numpy``); see the sketch below
* Choose a unique ``model_name``
* Register the model before it is referenced in a configuration
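
For example, conditional logic inside ``_call_implementation`` should use
``jax.numpy`` operations rather than Python ``if`` statements on traced values,
so the model stays traceable under ``jax.jit``. A minimal sketch (the 10 keV cap
is purely illustrative, not part of the TORAX API):

.. code-block:: python

   import jax.numpy as jnp

   # Inside _call_implementation: cap the pedestal temperature without a
   # Python `if`, keeping the function traceable.
   T_e_ped = jnp.minimum(T_e_ped, 10.0)  # illustrative 10 keV ceiling
   # For branching logic, prefer jnp.where(cond, value_if_true, value_if_false).
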
torax/_src/pedestal_model/register_model.py (107 additions, 0 deletions)
@@ -0,0 +1,107 @@
# Copyright 2025 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Register a pedestal model with TORAX."""
from typing import Union, get_args

from torax._src.torax_pydantic import model_config
from torax._src.pedestal_model import pydantic_model


def register_pedestal_model(
    pydantic_model_class: type[pydantic_model.BasePedestal],
):
  """Registers a pedestal model with TORAX.

  This function adds the pedestal model to the config model such that it can
  be configured via pydantic. The pydantic model class should inherit from
  BasePedestal and should have a distinct model_name. It should also define a
  build_pedestal_model method which returns a PedestalModel.

  Example:

  ```python
  from typing import Annotated, Literal

  import chex

  # Added for the annotations below; module paths may differ across TORAX versions.
  from torax._src import state
  from torax._src.geometry import geometry
  from torax._src.pedestal_model import pedestal_model as pm
  from torax._src.pedestal_model import pydantic_model
  from torax._src.pedestal_model import register_model
  from torax._src.pedestal_model import runtime_params
  from torax._src.torax_pydantic import torax_pydantic

  # Define your custom JAX pedestal model
  @chex.dataclass(frozen=True)
  class MyPedestalModel(pm.PedestalModel):
    def _call_implementation(
        self,
        runtime_params: runtime_params.RuntimeParams,
        geo: geometry.Geometry,
        core_profiles: state.CoreProfiles,
    ) -> pm.PedestalModelOutput:
      # Your custom pedestal calculation logic here
      T_e_ped = 5.0  # keV
      T_i_ped = 6.0  # keV
      n_e_ped = 0.7e20  # m^-3
      rho_norm_ped_top = 0.91

      return pm.PedestalModelOutput(
          rho_norm_ped_top=rho_norm_ped_top,
          rho_norm_ped_top_idx=...,  # compute from geo
          T_i_ped=T_i_ped,
          T_e_ped=T_e_ped,
          n_e_ped=n_e_ped,
      )

  # Define your Pydantic config class
  class MyPedestalConfig(pydantic_model.BasePedestal):
    model_name: Annotated[
        Literal['my_pedestal'], torax_pydantic.JAX_STATIC
    ] = 'my_pedestal'

    # Add any configuration parameters you need
    scaling_factor: float = 1.0

    def build_pedestal_model(self) -> MyPedestalModel:
      return MyPedestalModel()

    def build_runtime_params(
        self, t: chex.Numeric
    ) -> runtime_params.RuntimeParams:
      return runtime_params.RuntimeParams(
          set_pedestal=self.set_pedestal.get_value(t),
      )

  # Register your model
  register_model.register_pedestal_model(MyPedestalConfig)

  # Now you can use it in your config
  CONFIG = {
      'pedestal': {
          'model_name': 'my_pedestal',
          'set_pedestal': True,
          'scaling_factor': 1.5,
      },
  }
  ```

  Args:
    pydantic_model_class: The pydantic model class to register.
  """
  # Get the current PedestalConfig union types
  current_types = get_args(
      model_config.ToraxConfig.model_fields['pedestal'].annotation
  )

  # Create a new union with the additional pedestal model
  type_tuple = (*current_types, pydantic_model_class)
  model_config.ToraxConfig.model_fields['pedestal'].annotation = Union[
      *type_tuple
  ]

  # Rebuild the model to incorporate the new type
  model_config.ToraxConfig.model_rebuild(force=True)