From 74ff3e0dcd6d7ce89f3e3735abfff108f6efd363 Mon Sep 17 00:00:00 2001 From: Haowei Wen Date: Sun, 21 Apr 2024 11:10:46 +0800 Subject: [PATCH] Save the original model config in train script The train script removes the "oxe_kwargs" key from the config, and adds "dataset_kwargs_list" & "sample_weights" to the config. This causes a `TypeError: Object of type function is not JSON serializable` error when saving a model checkpoint, because the `standardize_fn` we put into the model config is not serializable. This commit fixes the issue by preserving the unmodified config. --- scripts/train.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/scripts/train.py b/scripts/train.py index d936083b..f81f7553 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -139,6 +139,9 @@ def process_batch(batch): del batch["dataset_name"] return batch + # copy the original config before we modify it + model_config = FLAGS.config.to_dict() + # load datasets if "oxe_kwargs" in FLAGS.config.dataset_kwargs: # create dataset_kwargs_list from oxe_kwargs @@ -180,7 +183,7 @@ def process_batch(batch): rng = jax.random.PRNGKey(FLAGS.config.seed) rng, init_rng = jax.random.split(rng) model = OctoModel.from_config( - FLAGS.config.to_dict(), + model_config, example_batch, text_processor, verbose=True,