You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
This section discusses a variety of options for building JAX-friendly interfaces to surrogate models.
6
+
This section discusses a variety of options for building JAX-friendly interfaces
7
+
to surrogate models.
8
+
9
+
As an illustrative example, suppose we have a new neural network surrogate
10
+
transport model that we would like to use in TORAX. Assume that all the
11
+
boilerplate described in the previous sections has been taken care of, as well
12
+
as the definition of some functions to convert between TORAX structures and
13
+
tensors for the neural network.
7
14
8
-
As an illustrative example, suppose we have a new neural network surrogate transport model that we would like to use in TORAX.
9
-
Assume that all the boilerplate described in the previous sections has been taken care of, as well as the definition of some functions to convert between TORAX structures and tensors for the neural network.
10
15
11
16
.. code-block:: python
12
17
@@ -31,23 +36,33 @@ Assume that all the boilerplate described in the previous sections has been take
31
36
v_e=v_e,
32
37
)
33
38
34
-
In this guide, we explore a few options for how you could make the ``_call_surrogate_model`` function for an existing surrogate, while maintaining the full power of JAX:
39
+
In this guide, we explore a few options for how you could make the
40
+
``_call_surrogate_model`` function for an existing surrogate, while maintaining
41
+
the full power of JAX:
35
42
36
-
1. **Manually reimplementing the model in JAX**,
37
-
2. **Converting a Pytorch model to a JAX model**,
43
+
1. **Manually reimplementing the model in JAX**.
44
+
2. **Converting a Pytorch model to a JAX model**.
38
45
3. **Using an ONNX model**.
39
46
40
47
.. note::
41
-
These conversion methods are necessary in order to make an external model compatible with JAX's autodiff and JIT functionality, which is required for using TORAX's gradient-driven nonlinear solvers (e.g. Newton-Raphson).
42
-
Interfacing with non-differentiable, non-JITtable models is possible (for an example, see the `QuaLiKiz transport model`_) if the linear solver is used. However, note that if the model is called within the step function JIT will need to be disabled with ``TORAX_COMPILATION_ENABLED=0``.
48
+
These conversion methods are necessary in order to make an external model
49
+
compatible with JAX's autodiff and JIT functionality, which is required for
50
+
using TORAX's gradient-driven nonlinear solvers (e.g. Newton-Raphson).
51
+
Interfacing with non-differentiable, non-JITtable models is possible
52
+
(for an example, see the |QuaLiKiz| transport model implementation) if the
53
+
linear solver is used. However, note that if the model is called within the
54
+
step function, JIT will need to be disabled with
55
+
``TORAX_COMPILATION_ENABLED=0``.
43
56
44
57
45
58
Option 1: manually reimplementing the model in JAX
As this is only the model architecture, we need to load the trained weights separately.
101
-
This can be a bit fiddly as you have to map from the parameter names in the weights checkpoint file to the parameter names in the Flax model.
115
+
As this is only the model architecture, we need to load the trained weights
116
+
separately. This can be a bit fiddly as you have to map from the parameter names
117
+
in the weights checkpoint file to the parameter names in the Flax model.
102
118
103
119
For loading weights from a PyTorch checkpoint, you might do something like:
104
120
@@ -130,24 +146,30 @@ The model can then be called like any Flax model,
130
146
131
147
132
148
.. warning::
133
-
You need to be very careful when loading from a PyTorch state dict, as Flax and PyTorch may have slightly different representations of the weights (for example, one could be the transpose of the other). It's worth validating the output of your PyTorch model against your JAX model to make sure.
134
-
149
+
You need to be very careful when loading from a PyTorch state dict, as
150
+
Flax and PyTorch may have slightly different representations of the weights
151
+
(for example, one could be the transpose of the other). It's worth
152
+
validating the output of your PyTorch model against your JAX model to make
153
+
sure.
135
154
136
155
137
156
Option 2: converting a PyTorch model to a JAX model
The `torch_xla2`_ package is still evolving, which means there may be unexpected breaking changes. Some of the methods described in this section may become deprecated with little warning.
160
+
The `torch_xla2`_ package is still evolving, which means there may be
161
+
unexpected breaking changes. Some of the methods described in this section
162
+
may become deprecated with little warning.
142
163
143
-
If your model is in PyTorch, you could also consider using the `torch_xla2`_ package to do the conversion to JAX automatically.
164
+
If your model is in PyTorch, you could also consider using the `torch_xla2`_
165
+
package to do the conversion to JAX automatically.
144
166
145
167
.. code-block:: python
146
168
147
169
import torch
148
170
import torch_xla2 as tx
149
171
150
-
trained_model = torch.load(PYTORCH_MODEL_PATH, weights_only=False) # Use weights_only=False if you want to load the full model
172
+
trained_model = torch.load(PYTORCH_MODEL_PATH, weights_only=False) # Use weights_only=False if you want to load the full model
To remove the need for performing the conversion every time the model is loaded, you might want to save a JAX-compatible version of the weights and model to disk:
181
+
To remove the need for performing the conversion every time the model is loaded,
182
+
you might want to save a JAX-compatible version of the weights and model to
183
+
disk:
160
184
161
185
.. code-block:: python
162
186
@@ -181,16 +205,19 @@ The model can then be loaded and run as follows:
181
205
model = jax.export.deserialize(model_as_bytes)
182
206
183
207
# Load the weights
184
-
weights_as_npz =jnp.load('weights.npz')
208
+
weights_as_npz =np.load('weights.npz')
185
209
weights = [jnp.array(v) for v in weights_as_npz.values()]
186
210
187
211
188
212
Option 3: using an ONNX model
189
213
=============================
190
214
191
-
The `Open Neural Network Exchange`_ format (ONNX) is a highly interoperable format for sharing neural network models. ONNX files include the model architecture and weights bundled together.
215
+
The `Open Neural Network Exchange`_ format (ONNX) is a highly interoperable
216
+
format for sharing neural network models. ONNX files include the model
217
+
architecture and weights bundled together.
192
218
193
-
An ONNX model can be loaded and called as follows, making sure to specify the correct input and output node names for your specific model:
219
+
An ONNX model can be loaded and called as follows, making sure to specify the
220
+
correct input and output node names for your specific model:
194
221
195
222
.. code-block:: python
196
223
@@ -207,7 +234,8 @@ An ONNX model can be loaded and called as follows, making sure to specify the co
207
234
)
208
235
209
236
However, JAX will not be able to differentiate through the InferenceSession.
210
-
To convert the ONNX model to a JAX representation, you can use the `jaxonnxruntime`_ package:
237
+
To convert the ONNX model to a JAX representation, you can use the
238
+
`jaxonnxruntime`_ package:
211
239
212
240
.. code-block:: python
213
241
@@ -225,10 +253,12 @@ To convert the ONNX model to a JAX representation, you can use the `jaxonnxrunti
225
253
Best practices
226
254
==============
227
255
228
-
**Caching and lazy loading**: Ideally, the model should be constructed and weights loaded once only, on the first call to the function.
229
-
The loaded model should be cached and reused for subsequent calls.
256
+
**Caching and lazy loading**: Ideally, the model should be constructed and
257
+
weights loaded once only, on the first call to the function. The loaded model
258
+
should be cached and reused for subsequent calls.
230
259
231
-
For example, in the ``_combined`` function of the QLKNN transport model (the function that actually evaluates this model), we have:
260
+
For example, in the ``_combined`` function of the QLKNN transport model (the
261
+
function that actually evaluates this model), we have:
By decorating with ``functools.lru_cache(maxsize=1)``, the result of this function - the loaded model - is stored in the cache and is only re-loaded if the function is called with a different ``path``.
279
+
By decorating with ``functools.lru_cache(maxsize=1)``, the result of this
280
+
function - the loaded model - is stored in the cache and is only re-loaded if
281
+
the function is called with a different ``path``.
250
282
251
-
**JITting model calls**: In general, you should make sure that your forward call of the model is JITted:
283
+
**JITting model calls**: In general, you should make sure that your forward call
284
+
of the model is JITted:
252
285
253
286
.. code-block:: python
254
287
255
-
output_tensor = jax.jit(flax_model.apply)(params, input_tensor) # Good
256
-
output_tensor = flax_model.apply(params, input_tensor) # Bad
288
+
output_tensor = jax.jit(flax_model.apply)(params, input_tensor) # Good
289
+
output_tensor = flax_model.apply(params, input_tensor) # Bad
257
290
258
291
This is vital to ensure fast performance.
259
292
@@ -262,6 +295,5 @@ This is vital to ensure fast performance.
Ip_profile_unscaled=scipy.integrate.cumulative_trapezoid(y=spr*jtor, x=rhon, initial=0.0) # this Ip_profile by integration results in a discrepancy between this term and the total ip from IDSAdd commentMore actions
147
+
Ip_profile_unscaled=scipy.integrate.cumulative_trapezoid(y=spr*jtor, x=rhon, initial=0.0) # this Ip_profile by integration results in a discrepancy between this term and the total ip from IDS
148
148
149
149
# Because of the discrepancy between Ip_profile[-1] (computed by integration) and global_quantities.ip, here we will scale Ip_profile such that the total plasma current is equal
0 commit comments