When training some models with an initially high learning rate, the ReduceLRUponNan callback kicks in and seems to load the weights properly, but training then fails on the next epoch with the error above.
There is a minimal working example Python script below; I have also attached a conda environment file for reproducing the error, as well as the script's output when I run it on my machine.
I have a working solution to this in a fork and will submit a PR after posting. The problem seems to arise from recompiling the model during training.
"""MWE showing learning rate reduction method instability."""
import numpy as np
import tensorflow as tf
from matminer.datasets import load_dataset
from megnet.data.crystal import CrystalGraph
from megnet.models import MEGNetModel
from sklearn.model_selection import train_test_split
RANDOM_SEED = 2021
def get_default_megnet_args(
nfeat_bond: int = 10, r_cutoff: float = 5.0, gaussian_width: float = 0.5
) -> dict:
gaussian_centers = np.linspace(0, r_cutoff + 1, nfeat_bond)
graph_converter = CrystalGraph(cutoff=r_cutoff)
return {
"graph_converter": graph_converter,
"centers": gaussian_centers,
"width": gaussian_width,
}
if __name__ == "__main__":
# For reproducability
tf.random.set_seed(RANDOM_SEED)
data = load_dataset("matbench_jdft2d")
train, test = train_test_split(data, random_state=RANDOM_SEED)
meg_model = MEGNetModel(**get_default_megnet_args(), lr=1e-2)
meg_model.train(
train["structure"],
train["exfoliation_en"],
test["structure"],
test["exfoliation_en"],
epochs=8,
verbose=2,
)
@a-ws-m Thanks. In fact, this part was initially written the same way as your change in #221, back in the TensorFlow 1.x era (a4e446e).
I changed it later when upgrading to TensorFlow 2.x because, without resetting the optimizer's state, reloading the weights and adjusting the learning rate does not correct the NaN loss. This has something to do with the TensorFlow version, and it is quite annoying.
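To illustrate what that reset involves, here is a minimal TF 2.x sketch, not MEGNet's actual internals: the helper name, the checkpoint path argument, and the Adam/MSE defaults are assumptions on my part.

import tensorflow as tf

def reset_and_reload(model: tf.keras.Model, weights_path: str, new_lr: float) -> None:
    """Restore pre-NaN weights and discard stale optimizer state."""
    model.load_weights(weights_path)  # roll back to the last good weights
    # Lowering the learning rate alone keeps the optimizer's accumulated
    # moment estimates, which may already be NaN; recompiling with a fresh
    # optimizer discards them.
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=new_lr),
        loss="mse",  # assuming the default MSE loss
    )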
I think here we can just set automatic_correction=False when constructing the MEGNetModel. I have recently improved the model training stability in the latest GitHub repo, so the previous NaN loss should be gone. In that case, there is no need for automatic correction.
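For anyone following along, that would look something like the sketch below, reusing get_default_megnet_args() from the MWE above (assuming the constructor accepts the keyword as described).

# Build the model with automatic NaN correction disabled.
meg_model = MEGNetModel(
    **get_default_megnet_args(),
    lr=1e-2,
    automatic_correction=False,  # skip the ReduceLRUponNan callback
)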
That's interesting; it's a shame about the functionality change there. I noticed the recent stability improvements, and they'll be very useful, cheers. In fact, I first noticed the issue using MEGNet 1.2.5 and then realised I had to set a really high learning rate to reproduce it in an MWE on 1.2.6!
That being said, as you can see in the log files, it does at least seem to correct the loss shooting up (training continues and the loss decreases), if not the NaN loss issue. EDIT: unless that's expected behaviour with the callback disabled?
@a-ws-m Yes indeed. In this case, we do not really need to correct the training, since the loss will come down by itself even without the correction. To be honest, I have not extensively tried a learning rate of 1e-2, and even if I did, that was long ago in the TensorFlow 1.x era, when things were quite different.
Also, adding tf.compat.v1.disable_eager_execution() may solve the issue you had earlier.
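For completeness, that call has to run before any TensorFlow models or ops are created, e.g. at the very top of the MWE (a sketch under that assumption).

import tensorflow as tf

# Must be called before any models, tensors, or ops are created.
tf.compat.v1.disable_eager_execution()

# ...then build and train the MEGNetModel exactly as in the MWE above.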