Skip to content

Commit bb7ddb4

Browse files
authored
Rename TrainState.epoch to TrainState.iteration to better represent its actual meaning. (#1962)
1 parent 9546d7e commit bb7ddb4

File tree

4 files changed

+27
-17
lines changed

4 files changed

+27
-17
lines changed

deepxde/callbacks.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def on_epoch_end(self):
151151
if self.verbose > 0:
152152
print(
153153
"Epoch {}: {} improved from {:.2e} to {:.2e}, saving model to {} ...\n".format(
154-
self.model.train_state.epoch,
154+
self.model.train_state.iteration,
155155
self.monitor,
156156
self.best,
157157
current,
@@ -224,7 +224,7 @@ def on_train_begin(self):
224224
self.best = np.inf if self.monitor_op == np.less else -np.inf
225225

226226
def on_epoch_end(self):
227-
if self.model.train_state.epoch < self.start_from_epoch:
227+
if self.model.train_state.iteration < self.start_from_epoch:
228228
return
229229
current = self.get_monitor_value()
230230
if self.monitor_op(current - self.min_delta, self.best):
@@ -233,7 +233,7 @@ def on_epoch_end(self):
233233
else:
234234
self.wait += 1
235235
if self.wait >= self.patience:
236-
self.stopped_epoch = self.model.train_state.epoch
236+
self.stopped_epoch = self.model.train_state.iteration
237237
self.model.stop_training = True
238238

239239
def on_train_end(self):
@@ -274,7 +274,7 @@ def on_epoch_end(self):
274274
self.model.stop_training = True
275275
print(
276276
"\nStop training as time used up. time used: {:.1f} mins, epoch trained: {}".format(
277-
(time.time() - self.t_start) / 60, self.model.train_state.epoch
277+
(time.time() - self.t_start) / 60, self.model.train_state.iteration
278278
)
279279
)
280280

@@ -347,7 +347,7 @@ def on_train_begin(self):
347347
self.value = [var.value for var in self.var_list]
348348

349349
print(
350-
self.model.train_state.epoch,
350+
self.model.train_state.iteration,
351351
utils.list_to_str(self.value, precision=self.precision),
352352
file=self.file,
353353
)
@@ -420,7 +420,7 @@ def op(inputs, params):
420420
def on_train_begin(self):
421421
self.on_predict_end()
422422
print(
423-
self.model.train_state.epoch,
423+
self.model.train_state.iteration,
424424
utils.list_to_str(self.value.flatten().tolist(), precision=self.precision),
425425
file=self.file,
426426
)

deepxde/model.py

+19-9
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
__all__ = ["LossHistory", "Model", "TrainState"]
22

33
import pickle
4+
import warnings
45
from collections import OrderedDict
56

67
import numpy as np
@@ -715,7 +716,7 @@ def _train_sgd(self, iterations, display_every, verbose=1):
715716
self.train_state.train_aux_vars,
716717
)
717718

718-
self.train_state.epoch += 1
719+
self.train_state.iteration += 1
719720
self.train_state.step += 1
720721
if self.train_state.step % display_every == 0 or i + 1 == iterations:
721722
self._test(verbose=verbose)
@@ -728,7 +729,7 @@ def _train_sgd(self, iterations, display_every, verbose=1):
728729

729730
def _train_tensorflow_compat_v1_scipy(self, display_every, verbose=1):
730731
def loss_callback(loss_train, loss_test, *args):
731-
self.train_state.epoch += 1
732+
self.train_state.iteration += 1
732733
self.train_state.step += 1
733734
if self.train_state.step % display_every == 0:
734735
self.train_state.loss_train = loss_train
@@ -749,7 +750,7 @@ def loss_callback(loss_train, loss_test, *args):
749750
cb.epochs_since_last = 0
750751

751752
print(
752-
cb.model.train_state.epoch,
753+
cb.model.train_state.iteration,
753754
list_to_str(
754755
[float(arg) for arg in args],
755756
precision=cb.precision,
@@ -792,7 +793,7 @@ def _train_tensorflow_tfp(self, verbose=1):
792793
self.train_state.train_aux_vars,
793794
)
794795
n_iter += results.num_iterations.numpy()
795-
self.train_state.epoch += results.num_iterations.numpy()
796+
self.train_state.iteration += results.num_iterations.numpy()
796797
self.train_state.step += results.num_iterations.numpy()
797798
self._test(verbose=verbose)
798799

@@ -819,7 +820,7 @@ def _train_pytorch_lbfgs(self, verbose=1):
819820
# Converged
820821
break
821822

822-
self.train_state.epoch += n_iter - prev_n_iter
823+
self.train_state.iteration += n_iter - prev_n_iter
823824
self.train_state.step += n_iter - prev_n_iter
824825
prev_n_iter = n_iter
825826
self._test(verbose=verbose)
@@ -851,7 +852,7 @@ def _train_paddle_lbfgs(self, verbose=1):
851852
# Converged
852853
break
853854

854-
self.train_state.epoch += n_iter - prev_n_iter
855+
self.train_state.iteration += n_iter - prev_n_iter
855856
self.train_state.step += n_iter - prev_n_iter
856857
prev_n_iter = n_iter
857858
self._test(verbose=verbose)
@@ -1071,7 +1072,7 @@ def save(self, save_path, protocol="backend", verbose=0):
10711072
Returns:
10721073
string: Path where model is saved.
10731074
"""
1074-
save_path = f"{save_path}-{self.train_state.epoch}"
1075+
save_path = f"{save_path}-{self.train_state.iteration}"
10751076
if protocol == "pickle":
10761077
save_path += ".pkl"
10771078
with open(save_path, "wb") as f:
@@ -1104,7 +1105,7 @@ def save(self, save_path, protocol="backend", verbose=0):
11041105
if verbose > 0:
11051106
print(
11061107
"Epoch {}: saving model to {} ...\n".format(
1107-
self.train_state.epoch, save_path
1108+
self.train_state.iteration, save_path
11081109
)
11091110
)
11101111
return save_path
@@ -1159,7 +1160,7 @@ def print_model(self):
11591160

11601161
class TrainState:
11611162
def __init__(self):
1162-
self.epoch = 0
1163+
self.iteration = 0
11631164
self.step = 0
11641165

11651166
# Current data
@@ -1188,6 +1189,15 @@ def __init__(self):
11881189
self.best_ystd = None
11891190
self.best_metrics = None
11901191

1192+
@property
1193+
def epoch(self):
1194+
warnings.warn(
1195+
"TrainState.epoch is deprecated and will be removed in a future version. Use TrainState.iteration instead.",
1196+
DeprecationWarning,
1197+
stacklevel=2,
1198+
)
1199+
return self.iteration
1200+
11911201
def set_data_train(self, X_train, y_train, train_aux_vars=None):
11921202
self.X_train = X_train
11931203
self.y_train = y_train

docs/demos/pinn_forward/elasticity.plate.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ We then train the model for 5000 iterations:
257257

258258
.. code-block:: python
259259
260-
losshistory, train_state = model.train(epochs=5000)
260+
losshistory, train_state = model.train(iterations=5000)
261261
262262
Complete code
263263
--------------

docs/demos/pinn_forward/helmholtz.2d.neumann.hole.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ First, the DeepXDE, Numpy and Matplotlib modules are imported:
5454
import matplotlib.pyplot as plt
5555
import numpy as np
5656
57-
We begin by defining the general parameters for the problem. We use a collocation points density of 15 (resp. 30) points per wavelength for the training (resp. testing) data along each direction. The PINN will be trained over 5000 epochs. We define the learning rate, the number of dense layers and nodes, and the activation function.
57+
We begin by defining the general parameters for the problem. We use a collocation points density of 15 (resp. 30) points per wavelength for the training (resp. testing) data along each direction. The PINN will be trained over 5000 iterations. We define the learning rate, the number of dense layers and nodes, and the activation function.
5858

5959
.. code-block:: python
6060

0 commit comments

Comments (0)