Open
Labels
bug: Something isn't working
Description
Bug report
I followed the tutorial for doing SFT on MaxText (with Gemma3)
https://maxtext.readthedocs.io/en/latest/tutorials/posttraining/sft.html
Unfortunately, I encountered a ValueError when converting the MaxText checkpoint back to HuggingFace format:
python3 -m MaxText.utils.ckpt_conversion.to_huggingface src/MaxText/configs/base.yml \
model_name='gemma3-4b' \
hf_access_token=${some_token} \
load_parameters_path=gs://${somewhere}/${run_name}/checkpoints/1/model_params \
base_output_directory=/tmp/gemma3/hf_converted/ \
use_multimodal=false \
scan_layers=false \
skip_jax_distributed_system=true
Logs/Output
I0123 09:00:33.342407 124540818566272 checkpointer.py:318] Finished restoring checkpoint in 16.29 seconds from gs://${somewhere}/gemma3-4b-sft-v1/train_from_hf_gemma3_4b_it/checkpoints/1/model_params.
I0123 09:00:33.344157 124540818566272 to_huggingface.py:140] Elapse for checkpoint load: 0.28 min
I0123 09:00:34.637320 124540818566272 utils.py:884] Detected NNX-SFT checkpoint structure
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/jed351/maxtext/src/MaxText/utils/ckpt_conversion/to_huggingface.py", line 216, in <module>
    app.run(main)
  File "/home/jed351/maxtext/maxtext_venv/lib/python3.12/site-packages/absl/app.py", line 316, in run
    _run_main(main, args)
  File "/home/jed351/maxtext/maxtext_venv/lib/python3.12/site-packages/absl/app.py", line 261, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/home/jed351/maxtext/src/MaxText/utils/ckpt_conversion/to_huggingface.py", line 173, in main
    filtered_map_keys = validate_and_filter_param_map_keys(param_map.keys(), maxtext_state_dict.keys())
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jed351/maxtext/src/MaxText/utils/ckpt_conversion/utils/utils.py", line 132, in validate_and_filter_param_map_keys
    raise ValueError(
ValueError: maxtext_state_dict must be a subset of flattened param_map
param map
dict_keys(['params-token_embedder-embedding', 'params-decoder-decoder_norm-scale', ...])
maxtext:
dict_keys(['params-decoder-decoder_norm-scale', 'params-decoder-layers-layers_0-mlp-wi_0-kernel', ...])
missing keys:
{'params-decoder-layers-layers_5-mlp-wi_1-kernel', 'params-decoder-layers-layers_2-self_attention-key_norm-scale', ...}
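For anyone trying to reproduce the comparison outside the converter, here is a minimal sketch of the subset check that the error message describes. It is not the MaxText implementation: `find_unmapped_keys` is a hypothetical helper, and the two key lists are the truncated examples from the log above, so the result only illustrates the idea and does not reflect the full checkpoint.

```python
def find_unmapped_keys(param_map_keys, state_dict_keys):
    """Return checkpoint keys that have no entry in the parameter map, sorted.

    Hypothetical helper for illustration; an empty result means the
    checkpoint keys are a subset of the parameter-map keys.
    """
    return sorted(set(state_dict_keys) - set(param_map_keys))


# Truncated key lists copied from the log output (illustration only).
param_map_keys = [
    "params-token_embedder-embedding",
    "params-decoder-decoder_norm-scale",
]
state_dict_keys = [
    "params-decoder-decoder_norm-scale",
    "params-decoder-layers-layers_0-mlp-wi_0-kernel",
]

missing = find_unmapped_keys(param_map_keys, state_dict_keys)
print(missing)
```

Running the same comparison on the full key lists from both sides would show exactly which checkpoint keys the parameter map does not cover, which may help tell a naming mismatch apart from genuinely missing mappings.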
Environment Information
TPU v6e
v2-alpha-tpuv6e
MaxText installed from source
Additional Context
No response