Skip to content

Commit 99c3626

Browse files
authored
Add mixed precision support (#1650)
1 parent e104988 commit 99c3626

File tree

3 files changed

+28
-3
lines changed

3 files changed

+28
-3
lines changed

deepxde/config.py

+21-2
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@
4040

4141
# Default float type
4242
real = Real(32)
43+
# Using mixed precision
44+
mixed = False
4345
# Random seed
4446
random_seed = None
4547
if backend_name == "jax":
@@ -71,11 +73,14 @@ def default_float():
7173
def set_default_float(value):
7274
"""Sets the default float type.
7375
74-
The default floating point type is 'float32'.
76+
The default floating point type is 'float32'. Mixed precision uses the method in the paper:
77+
`J. Hayford, J. Goldman-Wetzler, E. Wang, & L. Lu. Speeding up and reducing memory usage for scientific machine learning via mixed precision.
78+
Computer Methods in Applied Mechanics and Engineering, 428, 117093, 2024 <https://doi.org/10.1016/j.cma.2024.117093>`_.
7579
7680
Args:
77-
value (String): 'float16', 'float32', or 'float64'.
81+
value (String): 'float16', 'float32', 'float64', or 'mixed' (mixed precision).
7882
"""
83+
global mixed
7984
if value == "float16":
8085
print("Set the default float type to float16")
8186
real.set_float16()
@@ -85,6 +90,20 @@ def set_default_float(value):
8590
elif value == "float64":
8691
print("Set the default float type to float64")
8792
real.set_float64()
93+
elif value == "mixed":
94+
print("Set the float type to mixed precision of float16 and float32")
95+
mixed = True
96+
if backend_name == "tensorflow":
97+
real.set_float16()
98+
tf.keras.mixed_precision.set_global_policy("mixed_float16")
99+
return # don't try to set it again below
100+
if backend_name == "pytorch":
101+
# Use float16 during the forward and backward passes, but store in float32
102+
real.set_float32()
103+
else:
104+
raise ValueError(
105+
f"{backend_name} backend does not currently support mixed precision."
106+
)
88107
else:
89108
raise ValueError(f"{value} not supported in deepXDE")
90109
if backend_name in ["tensorflow.compat.v1", "tensorflow"]:

deepxde/model.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,11 @@ def closure():
374374
total_loss.backward()
375375
return total_loss
376376

377-
self.opt.step(closure)
377+
def closure_mixed():
378+
with torch.autocast(device_type=torch.get_default_device().type, dtype=torch.float16):
379+
return closure()
380+
381+
self.opt.step(closure if not config.mixed else closure_mixed)
378382
if self.lr_scheduler is not None:
379383
self.lr_scheduler.step()
380384

docs/user/faq.rst

+2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ General usage
1010
| **A**: `#5`_
1111
- | **Q**: By default, DeepXDE uses ``float32``. How can I use ``float64``?
1212
| **A**: `#28`_
13+
- | **Q**: How can I use mixed precision training?
14+
| **A**: Use ``dde.config.set_default_float("mixed")`` with the ``tensorflow`` or ``pytorch`` backends. See `this paper <https://doi.org/10.1016/j.cma.2024.117093>`_ for more information.
1315
- | **Q**: I want to set the global random seeds.
1416
| **A**: `#353`_
1517
- | **Q**: GPU.

0 commit comments

Comments
 (0)