SFT Model Conversion to HuggingFace Error #3003

@jedcheng

Description

Bug report

I followed the tutorial for running SFT on MaxText (with Gemma3):
https://maxtext.readthedocs.io/en/latest/tutorials/posttraining/sft.html

Unfortunately, I encountered a ValueError when converting the MaxText checkpoint back to the HuggingFace format with:

python3 -m MaxText.utils.ckpt_conversion.to_huggingface src/MaxText/configs/base.yml \
    model_name='gemma3-4b' \
    hf_access_token=${some_token} \
    load_parameters_path=gs://${somewhere}/${run_name}/checkpoints/1/model_params \
    base_output_directory=/tmp/gemma3/hf_converted/ \
    use_multimodal=false \
    scan_layers=false \
    skip_jax_distributed_system=true

Logs/Output

I0123 09:00:33.342407 124540818566272 checkpointer.py:318] Finished restoring checkpoint in 16.29 seconds from gs://${somewhere}/gemma3-4b-sft-v1/train_from_hf_gemma3_4b_it/checkpoints/1/model_params.
I0123 09:00:33.344157 124540818566272 to_huggingface.py:140] Elapse for checkpoint load: 0.28 min
I0123 09:00:34.637320 124540818566272 utils.py:884] Detected NNX-SFT checkpoint structure
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/jed351/maxtext/src/MaxText/utils/ckpt_conversion/to_huggingface.py", line 216, in <module>
    app.run(main)
  File "/home/jed351/maxtext/maxtext_venv/lib/python3.12/site-packages/absl/app.py", line 316, in run
    _run_main(main, args)
  File "/home/jed351/maxtext/maxtext_venv/lib/python3.12/site-packages/absl/app.py", line 261, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/home/jed351/maxtext/src/MaxText/utils/ckpt_conversion/to_huggingface.py", line 173, in main
    filtered_map_keys = validate_and_filter_param_map_keys(param_map.keys(), maxtext_state_dict.keys())
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jed351/maxtext/src/MaxText/utils/ckpt_conversion/utils/utils.py", line 132, in validate_and_filter_param_map_keys
    raise ValueError(
ValueError: maxtext_state_dict must be a subset of flattened param_map
param map
dict_keys(['params-token_embedder-embedding', 'params-decoder-decoder_norm-scale', ...'])
maxtext:
dict_keys(['params-decoder-decoder_norm-scale', 'params-decoder-layers-layers_0-mlp-wi_0-kernel', ...])
missing keys:
{'params-decoder-layers-layers_5-mlp-wi_1-kernel', 'params-decoder-layers-layers_2-self_attention-key_norm-scale', ...}
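In case it helps triage, here is a rough diagnostic sketch (not part of MaxText; it assumes the model_params directory can be restored directly with Orbax, and the bucket path is a placeholder) that prints the flattened parameter names stored in the checkpoint, so they can be compared against the param_map keys listed in the error:

# Hypothetical diagnostic, not part of MaxText: list the flattened parameter
# names in the saved checkpoint to compare with the param_map keys above.
import jax
import orbax.checkpoint as ocp

CKPT_PATH = "gs://<bucket>/<run_name>/checkpoints/1/model_params"  # placeholder

restored = ocp.PyTreeCheckpointer().restore(CKPT_PATH)

# Join the pytree key-path segments with '-' so the names match the
# 'params-decoder-...' style shown in the ValueError output.
flat_with_paths, _ = jax.tree_util.tree_flatten_with_path(restored)
for key_path, leaf in flat_with_paths:
    name = "-".join(str(getattr(k, "key", getattr(k, "name", k))) for k in key_path)
    print(name, getattr(leaf, "shape", None))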

Environment Information

TPU v6e (runtime version: v2-alpha-tpuv6e)

MaxText installed from source

Additional Context

No response
