
Commit 171a3c6 breaks checkpoint conversion in to_huggingface.py #2914

@panalexeu

Description


Bug report

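import subprocess

# CNF_PATH (the MaxText config file) and MAXTEXT_PATH (the repository root)
# are assumed to be defined earlier in the script.

# Training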
subprocess.run(
    [
        '/workspace/.venv/bin/python',
        '-m',
        'MaxText.train',
        CNF_PATH,
        'base_output_directory=/data/qwen3-0.6b-fine-tune',
        'model_name=qwen3-0.6b',
        'load_parameters_path=/data/qwen3-0.6b-ckpt/0/items',
        'dataset_type=hf',
        'hf_path=parquet',
        'hf_train_files=/data/gsm8k/train_full.parquet',
        'run_name=gsm8k_finetune',
        'tokenizer_path=/data/Qwen3-0.6B',
        'tokenizer_type=huggingface',
        'enable_checkpointing=true',
        'train_data_columns=["question","answer"]',
        'use_sft=true'
    ],
    cwd=MAXTEXT_PATH,
    check=True
)

# Convert back from checkpoint to safetensors format
subprocess.run(
    [
        '/workspace/.venv/bin/python',
        'src/MaxText/utils/ckpt_conversion/to_huggingface.py',
        CNF_PATH,
        'enable_checkpointing=true',
        'model_name=qwen3-0.6b',
        'load_parameters_path=/data/qwen3-0.6b-fine-tune/gsm8k_finetune/checkpoints/0/items',
        'base_output_directory=/data/Qwen3-0.6B-fine-tune',
        'scan_layers=true'
    ],
    cwd=MAXTEXT_PATH,
    check=True
)

Logs/Output

Failure Logs

Error Message:

ValueError: Shape mismatch for model.embed_tokens.weight: Expect [151936, 1024], got (151936, 1024)

Stack Trace:

Traceback (most recent call last):
  File "/workdir/maxtext/src/MaxText/utils/ckpt_conversion/to_huggingface.py", line 237, in <module>
    app.run(main)
  File "/workdir/.venv/lib/python3.12/site-packages/absl/app.py", line 316, in run
    _run_main(main, args)
  File "/workdir/.venv/lib/python3.12/site-packages/absl/app.py", line 261, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/workdir/maxtext/src/MaxText/utils/ckpt_conversion/to_huggingface.py", line 211, in main
    processed_params = process_maxtext_param(key, weight, param_map, hook_fn_map, shape_map, config)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workdir/maxtext/src/MaxText/utils/ckpt_conversion/utils/utils.py", line 264, in process_maxtext_param
    _process(hf_path, maxtext_param_weight, output_weights, current_hook_fns, hf_shape_map)
  File "/workdir/maxtext/src/MaxText/utils/ckpt_conversion/utils/utils.py", line 203, in _process
    raise ValueError(f"Shape mismatch for {hf_path}: Expect {target_hf_shape}, got {numpy_slice.shape}")
ValueError: Shape mismatch for model.embed_tokens.weight: Expect [151936, 1024], got (151936, 1024)

Root Cause:
In utils/utils.py, line 203, the shape comparison fails because of a type mismatch:

  • target_hf_shape is a Python list: [151936, 1024]
  • numpy_slice.shape is a plain Python tuple (NumPy always reports array shapes as tuples): (151936, 1024)

The comparison [151936, 1024] != (151936, 1024) evaluates to True because a Python list never compares equal to a tuple, even when their elements are identical. The check therefore raises even though the dimensions match.
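The comparison can be reproduced in plain Python using the values from the error message. The helper `shapes_match` below is hypothetical: a sketch of one way the check in `_process` could normalize both sides before comparing. It is not the actual MaxText code or its upstream fix.

```python
# Values taken from the error message in this report.
target_hf_shape = [151936, 1024]   # expected HF shape, stored as a list
actual_shape = (151936, 1024)      # ndarray.shape is always a plain tuple

# The failing check: a list never compares equal to a tuple,
# so this is True and the converter raises ValueError.
print(target_hf_shape != actual_shape)   # True


def shapes_match(expected, actual):
    """Hypothetical helper: compare shapes element by element,
    ignoring whether each one is a list or a tuple."""
    return tuple(expected) == tuple(actual)


print(shapes_match(target_hf_shape, actual_shape))   # True
print(shapes_match([151936, 1024], (151936, 512)))   # False
```

Converting both sides with tuple() keeps the check strict about dimension values and order while ignoring the container type.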

Environment Information

  • Python 3.12
  • Model: Qwen3-0.6B
  • Platform: TPU v5e 2x2

I think the problem was introduced in commit 171a3c6, specifically by the changes to src/MaxText/utils/ckpt_conversion/utils/utils.py. I did not hit this issue on earlier commits.

Additional Context

No response


Labels: bug (Something isn't working)
