Skip to content

Commit 5532c13

Browse files
authored
Merge branch 'feature/IMAS_coupling' into subfeature/core_profiles_to_imas
2 parents ee71f89 + 768e292 commit 5532c13

8 files changed

Lines changed: 147 additions & 105 deletions

File tree

README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,9 @@ The following command will run TORAX using the default configuration file
126126
run_torax --config='examples/basic_config.py'
127127
```
128128

129-
Simulation progress is shown by a terminal progress bar indicating the current time and percentage completed.
129+
130+
Simulation progress is shown by a terminal progress bar indicating the current
131+
time and percentage completed.
130132

131133
To run more involved, ITER-inspired simulations, run:
132134

docs/conf.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,9 @@
310310
.. |QLKNN| replace:: `QLKNN <QLKNN_target_>`_
311311
.. _QLKNN_target: {github_base_url}/torax/_src/transport_model/qlknn_transport_model.py
312312
313+
.. |QuaLiKiz| replace:: `QuaLiKiz <qualikiz_target_>`_
314+
.. _qualikiz_target: {github_base_url}/torax/_src/transport_model/qualikiz_transport_model.py
315+
313316
.. |transport_model| replace:: `transport_model <torax_src_transport_model_target_>`_
314317
.. _torax_src_transport_model_target: {github_base_url}/torax/_src/transport_model
315318

docs/interfacing_with_surrogates.rst

Lines changed: 67 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
.. _interfacing_with_surrogates:
22

3-
JAX-compatible interfaces with ML surrogates of physics models
3+
JAX-compatible interfaces with ML-surrogates of physics models
44
##############################################################
55

6-
This section discusses a variety of options for building JAX-friendly interfaces to surrogate models.
6+
This section discusses a variety of options for building JAX-friendly interfaces
7+
to surrogate models.
8+
9+
As an illustrative example, suppose we have a new neural network surrogate
10+
transport model that we would like to use in TORAX. Assume that all the
11+
boilerplate described in the previous sections has been taken care of, as well
12+
as the definition of some functions to convert between TORAX structures and
13+
tensors for the neural network.
714

8-
As an illustrative example, suppose we have a new neural network surrogate transport model that we would like to use in TORAX.
9-
Assume that all the boilerplate described in the previous sections has been taken care of, as well as the definition of some functions to convert between TORAX structures and tensors for the neural network.
1015

1116
.. code-block:: python
1217
@@ -31,23 +36,33 @@ Assume that all the boilerplate described in the previous sections has been take
3136
v_e=v_e,
3237
)
3338
34-
In this guide, we explore a few options for how you could make the ``_call_surrogate_model`` function for an existing surrogate, while maintaining the full power of JAX:
39+
In this guide, we explore a few options for how you could make the
40+
``_call_surrogate_model`` function for an existing surrogate, while maintaining
41+
the full power of JAX:
3542

36-
1. **Manually reimplementing the model in JAX**,
37-
2. **Converting a Pytorch model to a JAX model**,
43+
1. **Manually reimplementing the model in JAX**.
44+
2. **Converting a Pytorch model to a JAX model**.
3845
3. **Using an ONNX model**.
3946

4047
.. note::
41-
These conversion methods are necessary in order to make an external model compatible with JAX's autodiff and JIT functionality, which is required for using TORAX's gradient-driven nonlinear solvers (e.g. Newton-Raphson).
42-
Interfacing with non-differentiable, non-JITtable models is possible (for an example, see the `QuaLiKiz transport model`_) if the linear solver is used. However, note that if the model is called within the step function JIT will need to be disabled with ``TORAX_COMPILATION_ENABLED=0``.
48+
These conversion methods are necessary in order to make an external model
49+
compatible with JAX's autodiff and JIT functionality, which is required for
50+
using TORAX's gradient-driven nonlinear solvers (e.g. Newton-Raphson).
51+
Interfacing with non-differentiable, non-JITtable models is possible
52+
(for an example, see the |QuaLiKiz| transport model implementation) if the
53+
linear solver is used. However, note that if the model is called within the
54+
step function, JIT will need to be disabled with
55+
``TORAX_COMPILATION_ENABLED=0``.
4356

4457

4558
Option 1: manually reimplementing the model in JAX
4659
==================================================
4760

48-
If the architecture of the surrogate is sufficiently simple, you might consider reimplementing the model in JAX.
49-
The surrogates in TORAX are mostly implemented using `Flax Linen`_, and can be found in the |fusion_surrogates|_ repository.
50-
If you're not familiar with Flax, you can check out the `Flax documentation`_ on how to define your own models.
61+
If the architecture of the surrogate is sufficiently simple, you might consider
62+
reimplementing the model in JAX. The surrogates in TORAX are mostly implemented
63+
using `Flax Linen`_, and can be found in the |fusion_surrogates|_ repository.
64+
If you're not familiar with Flax, you can check out the `Flax documentation`_
65+
on how to define your own models.
5166

5267
Consider a PyTorch neural network,
5368

@@ -97,8 +112,9 @@ This model can be replicated in Flax as follows:
97112
98113
flax_model = FlaxMLP(hidden_dim, n_hidden, output_dim, input_dim)
99114
100-
As this is only the model architecture, we need to load the trained weights separately.
101-
This can be a bit fiddly as you have to map from the parameter names in the weights checkpoint file to the parameter names in the Flax model.
115+
As this is only the model architecture, we need to load the trained weights
116+
separately. This can be a bit fiddly as you have to map from the parameter names
117+
in the weights checkpoint file to the parameter names in the Flax model.
102118

103119
For loading weights from a PyTorch checkpoint, you might do something like:
104120

@@ -130,24 +146,30 @@ The model can then be called like any Flax model,
130146
131147
132148
.. warning::
133-
You need to be very careful when loading from a PyTorch state dict, as Flax and PyTorch may have slightly different representations of the weights (for example, one could be the transpose of the other). It's worth validating the output of your PyTorch model against your JAX model to make sure.
134-
149+
You need to be very careful when loading from a PyTorch state dict, as
150+
Flax and PyTorch may have slightly different representations of the weights
151+
(for example, one could be the transpose of the other). It's worth
152+
validating the output of your PyTorch model against your JAX model to make
153+
sure.
135154

136155

137156
Option 2: converting a PyTorch model to a JAX model
138157
===================================================
139158

140159
.. warning::
141-
The `torch_xla2`_ package is still evolving, which means there may be unexpected breaking changes. Some of the methods described in this section may become deprecated with little warning.
160+
The `torch_xla2`_ package is still evolving, which means there may be
161+
unexpected breaking changes. Some of the methods described in this section
162+
may become deprecated with little warning.
142163

143-
If your model is in PyTorch, you could also consider using the `torch_xla2`_ package to do the conversion to JAX automatically.
164+
If your model is in PyTorch, you could also consider using the `torch_xla2`_
165+
package to do the conversion to JAX automatically.
144166

145167
.. code-block:: python
146168
147169
import torch
148170
import torch_xla2 as tx
149171
150-
trained_model = torch.load(PYTORCH_MODEL_PATH, weights_only=False) # Use weights_only=False if you want to load the full model
172+
trained_model = torch.load(PYTORCH_MODEL_PATH, weights_only=False) # Use weights_only=False if you want to load the full model
151173
params, jax_model_from_torch = tx.extract_jax(model)
152174
153175
The model can then be called as a pure JAX function:
@@ -156,7 +178,9 @@ The model can then be called as a pure JAX function:
156178
157179
output_tensor = jax.jit(jax_model_from_torch)(params, input_tensor)
158180
159-
To remove the need for performing the conversion every time the model is loaded, you might want to save a JAX-compatible version of the weights and model to disk:
181+
To remove the need for performing the conversion every time the model is loaded,
182+
you might want to save a JAX-compatible version of the weights and model to
183+
disk:
160184

161185
.. code-block:: python
162186
@@ -181,16 +205,19 @@ The model can then be loaded and run as follows:
181205
model = jax.export.deserialize(model_as_bytes)
182206
183207
# Load the weights
184-
weights_as_npz = jnp.load('weights.npz')
208+
weights_as_npz = np.load('weights.npz')
185209
weights = [jnp.array(v) for v in weights_as_npz.values()]
186210
187211
188212
Option 3: using an ONNX model
189213
=============================
190214

191-
The `Open Neural Network Exchange`_ format (ONNX) is a highly interoperable format for sharing neural network models. ONNX files include the model architecture and weights bundled together.
215+
The `Open Neural Network Exchange`_ format (ONNX) is a highly interoperable
216+
format for sharing neural network models. ONNX files include the model
217+
architecture and weights bundled together.
192218

193-
An ONNX model can be loaded and called as follows, making sure to specify the correct input and output node names for your specific model:
219+
An ONNX model can be loaded and called as follows, making sure to specify the
220+
correct input and output node names for your specific model:
194221

195222
.. code-block:: python
196223
@@ -207,7 +234,8 @@ An ONNX model can be loaded and called as follows, making sure to specify the co
207234
)
208235
209236
However, JAX will not be able to differentiate through the InferenceSession.
210-
To convert the ONNX model to a JAX representation, you can use the `jaxonnxruntime`_ package:
237+
To convert the ONNX model to a JAX representation, you can use the
238+
`jaxonnxruntime`_ package:
211239

212240
.. code-block:: python
213241
@@ -225,10 +253,12 @@ To convert the ONNX model to a JAX representation, you can use the `jaxonnxrunti
225253
Best practices
226254
==============
227255

228-
**Caching and lazy loading**: Ideally, the model should be constructed and weights loaded once only, on the first call to the function.
229-
The loaded model should be cached and reused for subsequent calls.
256+
**Caching and lazy loading**: Ideally, the model should be constructed and
257+
weights loaded once only, on the first call to the function. The loaded model
258+
should be cached and reused for subsequent calls.
230259

231-
For example, in the ``_combined`` function of the QLKNN transport model (the function that actually evaluates this model), we have:
260+
For example, in the ``_combined`` function of the QLKNN transport model (the
261+
function that actually evaluates this model), we have:
232262

233263
.. code-block:: python
234264
@@ -242,18 +272,21 @@ where
242272
243273
@functools.lru_cache(maxsize=1)
244274
def get_model(path: str) -> base_qlknn_model.BaseQLKNNModel:
245-
"""Load the model."""
246-
...
247-
return qlknn_10d.QLKNN10D(path)
275+
"""Load the model."""
276+
...
277+
return qlknn_10d.QLKNN10D(path)
248278
249-
By decorating with ``functools.lru_cache(maxsize=1)``, the result of this function - the loaded model - is stored in the cache and is only re-loaded if the function is called with a different ``path``.
279+
By decorating with ``functools.lru_cache(maxsize=1)``, the result of this
280+
function - the loaded model - is stored in the cache and is only re-loaded if
281+
the function is called with a different ``path``.
250282

251-
**JITting model calls**: In general, you should make sure that your forward call of the model is JITted:
283+
**JITting model calls**: In general, you should make sure that your forward call
284+
of the model is JITted:
252285

253286
.. code-block:: python
254287
255-
output_tensor = jax.jit(flax_model.apply)(params, input_tensor) # Good
256-
output_tensor = flax_model.apply(params, input_tensor) # Bad
288+
output_tensor = jax.jit(flax_model.apply)(params, input_tensor) # Good
289+
output_tensor = flax_model.apply(params, input_tensor) # Bad
257290
258291
This is vital to ensure fast performance.
259292

@@ -262,6 +295,5 @@ This is vital to ensure fast performance.
262295
.. _torch_xla2: https://pytorch.org/xla/master/features/stablehlo.html
263296
.. _Open Neural Network Exchange: https://onnx.ai/
264297
.. _jaxonnxruntime: https://github.com/google/jaxonnxruntime
265-
.. _QuaLiKiz transport model: https://github.com/google-deepmind/torax/blob/main/torax/transport_model/qualikiz_transport_model.py
266298
.. |fusion_surrogates| replace:: ``google-deepmind/fusion_surrogates``
267299
.. _fusion_surrogates: https://github.com/google-deepmind/fusion_surrogates

docs/model_integration.rst

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@ exposed this as part of the TORAX API.
1212

1313
If you would like to use this please reach out to us. We aim to expose this
1414
functionality as part of the TORAX API in the very near future to further
15-
facilitate the integration of custom models.
15+
facilitate the integration of custom models, and further expand the
16+
documentation.
1617

17-
For information on JAX-friendly interfacing with ML-surrogates of physics models, see :ref:`interfacing_with_surrogates`.
18+
.. toctree::
19+
:maxdepth: 1
20+
:caption: Model Integration Topics
21+
22+
interfacing_with_surrogates

torax/_src/orchestration/run_simulation.py

Lines changed: 61 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -44,66 +44,66 @@ def prepare_simulation(
4444
transport_model = torax_config.transport.build_transport_model()
4545
pedestal_model = torax_config.pedestal.build_pedestal_model()
4646

47-
geometry_provider = torax_config.geometry.build_provider
48-
source_models = source_models_lib.SourceModels(
49-
torax_config.sources, neoclassical=torax_config.neoclassical
50-
)
51-
52-
static_runtime_params_slice = (
53-
build_runtime_params.build_static_params_from_config(torax_config)
54-
)
55-
56-
solver = torax_config.solver.build_solver(
57-
static_runtime_params_slice=static_runtime_params_slice,
58-
transport_model=transport_model,
59-
source_models=source_models,
60-
pedestal_model=pedestal_model,
61-
)
62-
63-
mhd_models = torax_config.mhd.build_mhd_models(
64-
static_runtime_params_slice=static_runtime_params_slice,
65-
transport_model=transport_model,
66-
source_models=source_models,
67-
pedestal_model=pedestal_model,
68-
)
69-
70-
step_fn = step_function.SimulationStepFn(
71-
solver=solver,
72-
time_step_calculator=torax_config.time_step_calculator.time_step_calculator,
73-
transport_model=transport_model,
74-
pedestal_model=pedestal_model,
75-
mhd_models=mhd_models,
76-
)
77-
78-
dynamic_runtime_params_slice_provider = (
79-
build_runtime_params.DynamicRuntimeParamsSliceProvider.from_config(
80-
torax_config
81-
)
82-
)
83-
84-
if torax_config.restart and torax_config.restart.do_restart:
85-
initial_state, post_processed_outputs = (
86-
initial_state_lib.get_initial_state_and_post_processed_outputs_from_file(
87-
t_initial=torax_config.numerics.t_initial,
88-
file_restart=torax_config.restart,
89-
static_runtime_params_slice=static_runtime_params_slice,
90-
dynamic_runtime_params_slice_provider=dynamic_runtime_params_slice_provider,
91-
geometry_provider=geometry_provider,
92-
step_fn=step_fn,
93-
)
47+
geometry_provider = torax_config.geometry.build_provider
48+
source_models = source_models_lib.SourceModels(
49+
torax_config.sources, neoclassical=torax_config.neoclassical
50+
)
51+
52+
static_runtime_params_slice = (
53+
build_runtime_params.build_static_params_from_config(torax_config)
54+
)
55+
56+
solver = torax_config.solver.build_solver(
57+
static_runtime_params_slice=static_runtime_params_slice,
58+
transport_model=transport_model,
59+
source_models=source_models,
60+
pedestal_model=pedestal_model,
61+
)
62+
63+
mhd_models = torax_config.mhd.build_mhd_models(
64+
static_runtime_params_slice=static_runtime_params_slice,
65+
transport_model=transport_model,
66+
source_models=source_models,
67+
pedestal_model=pedestal_model,
9468
)
95-
restart_case = True
96-
else:
97-
initial_state, post_processed_outputs = (
98-
initial_state_lib.get_initial_state_and_post_processed_outputs(
99-
t=torax_config.numerics.t_initial,
100-
static_runtime_params_slice=static_runtime_params_slice,
101-
dynamic_runtime_params_slice_provider=dynamic_runtime_params_slice_provider,
102-
geometry_provider=geometry_provider,
103-
step_fn=step_fn,
69+
70+
step_fn = step_function.SimulationStepFn(
71+
solver=solver,
72+
time_step_calculator=torax_config.time_step_calculator.time_step_calculator,
73+
transport_model=transport_model,
74+
pedestal_model=pedestal_model,
75+
mhd_models=mhd_models,
76+
)
77+
78+
dynamic_runtime_params_slice_provider = (
79+
build_runtime_params.DynamicRuntimeParamsSliceProvider.from_config(
80+
torax_config
10481
)
10582
)
106-
restart_case = False
83+
84+
if torax_config.restart and torax_config.restart.do_restart:
85+
initial_state, post_processed_outputs = (
86+
initial_state_lib.get_initial_state_and_post_processed_outputs_from_file(
87+
t_initial=torax_config.numerics.t_initial,
88+
file_restart=torax_config.restart,
89+
static_runtime_params_slice=static_runtime_params_slice,
90+
dynamic_runtime_params_slice_provider=dynamic_runtime_params_slice_provider,
91+
geometry_provider=geometry_provider,
92+
step_fn=step_fn,
93+
)
94+
)
95+
restart_case = True
96+
else:
97+
initial_state, post_processed_outputs = (
98+
initial_state_lib.get_initial_state_and_post_processed_outputs(
99+
t=torax_config.numerics.t_initial,
100+
static_runtime_params_slice=static_runtime_params_slice,
101+
dynamic_runtime_params_slice_provider=dynamic_runtime_params_slice_provider,
102+
geometry_provider=geometry_provider,
103+
step_fn=step_fn,
104+
)
105+
)
106+
restart_case = False
107107

108108
return (
109109
static_runtime_params_slice,
@@ -162,7 +162,7 @@ def run_simulation(
162162
torax_config=torax_config,
163163
)
164164

165-
return (
166-
state_history.simulation_output_to_xr(torax_config.restart),
167-
state_history,
168-
)
165+
return (
166+
state_history.simulation_output_to_xr(torax_config.restart),
167+
state_history,
168+
)

torax/imas_tools/equilibrium.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def geometry_from_IMAS(
144144
rhon = IMAS_data.profiles_1d.rho_tor_norm
145145
vpr = 4 * np.pi * Phi[-1] * rhon / (F * flux_surf_avg_1_over_R2)
146146
spr = vpr / (2 * np.pi * R_major)
147-
Ip_profile_unscaled = scipy.integrate.cumulative_trapezoid(y=spr * jtor, x=rhon, initial=0.0) # this Ip_profile by integration results in a discrepancy between this term and the total ip from IDSAdd commentMore actions
147+
Ip_profile_unscaled = scipy.integrate.cumulative_trapezoid(y=spr * jtor, x=rhon, initial=0.0) # this Ip_profile by integration results in a discrepancy between this term and the total ip from IDS
148148

149149
# Because of the discrepancy between Ip_profile[-1] (computed by integration) and global_quantities.ip, here we will scale Ip_profile such that the total plasma current is equal
150150
Ip_total = -1 * IMAS_data.global_quantities.ip

torax/imas_tools/util.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ def update_dict(old_dict:dict, updates:dict) -> dict:
162162
return new_dict
163163

164164

165+
165166
# todo check if we can copy form geometry without weird dependency loops
166167
def face_to_cell(face):
167168
"""Infers cell values corresponding to a vector of face values.

0 commit comments

Comments
 (0)