FSDP Torch XLA vs. FSDPv2 (SPMD) Torch XLA checkpoint saving bug #36004
cc @muellerzr @SunMarc for Trainer
Could you try this PR that was never finished? #29780
Thank you @SunMarc and @Rocketknight1 for the help! The code now works with the solution you suggested. I made the following changes to the current version of Transformers, as #29780 suggests: take transformers/src/transformers/modeling_utils.py, line 5636 at 8d73a38,
and replace it with the function from that PR;
and for the trainer:
I implemented the unwrap_model(model).save_pretrained call directly under _save_tpu. For reference, see the two checks at transformers/src/transformers/trainer.py, line 3777 at 8d73a38,
and transformers/src/transformers/trainer.py, line 3811 at 8d73a38.
Neither statement holds true when FSDPv2 runs on TPU, so we end up here: transformers/src/transformers/trainer.py, line 3821 at 8d73a38,
and we never call the unwrap_model(model).save_pretrained function. FYI, in case it helps, self.accelerator.unwrap_model(model) returns this value:
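To make the fall-through described above concrete, here is a paraphrased sketch of that branching, simplified from the quoted region of trainer.py; it is not a verbatim copy, and the helper names are abbreviated:

```python
# Paraphrased sketch of the _save_tpu control flow discussed above
# (simplified; not a verbatim copy of trainer.py). Requires torch_xla.
import os
import torch_xla.core.xla_model as xm

def _save_tpu_sketch(model, output_dir, supported_classes, unwrap_model):
    if isinstance(model, supported_classes):
        # GPU / FSDPv1 runs take this branch (the line ~3777 check).
        model.save_pretrained(output_dir)
    elif isinstance(unwrap_model(model), supported_classes):
        # The unwrapped model is a PreTrainedModel (the line ~3811 check).
        unwrap_model(model).save_pretrained(output_dir)
    else:
        # Under FSDPv2 (SPMD) on TPU both checks above fail, so we land
        # here (line ~3821) and dump a raw state dict instead of a proper,
        # PEFT-aware checkpoint.
        xm.save(model.state_dict(), os.path.join(output_dir, "pytorch_model.bin"))
```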
Glad that it works! Would you like to open a PR to propose your fix? That would be nice, as I don't have access to a TPU to verify that it actually works. Note that you can pass
Hi @SunMarc,
System Info
There is a bug in how the trainer (SFTTrainer) saves the checkpoint when we use FSDPv2 (SPMD) on TPU. The behavior does not show up with the old way of running Torch XLA code (xla_spawn.py). It causes the new checkpoint to be almost exactly the same as the base model, throwing this error with PEFT:
Found missing adapter keys while loading the checkpoint: {missing_keys}
Even without PEFT, the model weights seem unaffected by the training process.
The problem is likely related to how the saving function for FSDPv2 Torch XLA works in the trainer file. The same code works 100% on GPU and also with the xla_spawn.py FSDP method.
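One illustrative way to observe that symptom, comparing the saved checkpoint against the base model; the paths and model id below are placeholders, not taken from the issue:

```python
# Illustrative check for the symptom above: the "trained" checkpoint's
# weights come out identical to the base model's. Paths and the model id
# are placeholders.
import torch
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
saved_sd = torch.load("output/pytorch_model.bin", map_location="cpu")

identical = all(
    name in saved_sd and torch.equal(param, saved_sd[name])
    for name, param in base.state_dict().items()
)
print("checkpoint identical to base model:", identical)  # True reproduces the bug
```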
Who can help?
No response
Information
Tasks
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
Reproduction
To replicate, save the code below as sft.py and run it with PJRT_DEVICE=TPU XLA_USE_SPMD=1 python3 sft.py:
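The original sft.py is not reproduced above; a minimal sketch of such a script follows. The model name, dataset, and hyperparameters are illustrative only, and the exact SFTTrainer keyword arguments vary across trl versions:

```python
# Hypothetical minimal sft.py sketch; the issue's actual script is not shown.
# Model/dataset names and hyperparameters are placeholders.
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, TrainingArguments
from trl import SFTTrainer

model_id = "meta-llama/Llama-2-7b-hf"               # placeholder model
model = AutoModelForCausalLM.from_pretrained(model_id)
dataset = load_dataset("imdb", split="train[:1%]")  # placeholder dataset

lora_config = LoraConfig(r=8, lora_alpha=16, task_type="CAUSAL_LM")

args = TrainingArguments(
    output_dir="output",
    per_device_train_batch_size=1,
    num_train_epochs=1,
    # These flags route sharding through torch_xla's FSDPv2 (SPMD):
    fsdp="full_shard",
    fsdp_config={"xla": True, "xla_fsdp_v2": True, "xla_fsdp_grad_ckpt": True},
)

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset,
    peft_config=lora_config,  # drop this line to reproduce without PEFT
)
trainer.train()
trainer.save_model()  # the checkpoint written here exhibits the bug
```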
You will notice in the output folder that the saved model is not in LoRA format (there are no adapter files adapter_config.json and adapter_model.safetensors). This is because with FSDPv2 we end up here (you can check by adding a print statement):
transformers/src/transformers/trainer.py, line 3821 at 62db3e6
However, if we use the same code on GPU or with the old xla_spawn (FSDP) method, the issue disappears. To replicate the same code with FSDP, first run
wget https://raw.githubusercontent.com/huggingface/transformers/refs/heads/main/examples/pytorch/xla_spawn.py
then save the code below and run it with python3 xla_spawn.py --num_cores x sft.py:
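The modified script is likewise not reproduced here; a sketch of the relevant delta follows, using the same placeholder names as above and assuming the standard xla_spawn.py convention of importing the target script and calling its module-level _mp_fn in each spawned process:

```python
# Sketch of the xla_spawn (FSDPv1) variant of sft.py; placeholder names as
# in the SPMD sketch above. xla_spawn.py spawns _mp_fn once per TPU core.
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, TrainingArguments
from trl import SFTTrainer

def main():
    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
    args = TrainingArguments(
        output_dir="output",
        per_device_train_batch_size=1,
        num_train_epochs=1,
        fsdp="full_shard",
        fsdp_config={"xla": True, "xla_fsdp_grad_ckpt": True},  # no xla_fsdp_v2
    )
    trainer = SFTTrainer(
        model=model,
        args=args,
        train_dataset=load_dataset("imdb", split="train[:1%]"),
        peft_config=LoraConfig(r=8, lora_alpha=16, task_type="CAUSAL_LM"),
    )
    trainer.train()
    trainer.save_model()

def _mp_fn(index):
    # Entry point that examples/pytorch/xla_spawn.py calls in each process.
    main()
```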
With this code everything works great, because the saving function ends up here:
transformers/src/transformers/trainer.py, line 3824 at 62db3e6
I merged the LoRA adapter with the base model, and the generated output is what you would expect from a fine-tuned model!
Finally, please note that this issue is not related to PEFT: even if you use SFTTrainer without PEFT, the issue still exists. I believe it has to do with how we save checkpoints with FSDPv2 on TPUs.
Expected behavior
The model with LoRA should save the two adapter files, and when we merge LoRA with the base model we should not get this message (you should update PEFT to the latest version (0.14.0), as it adds additional checks to detect problems with LoRA checkpoints):
Found missing adapter keys while loading the checkpoint: {missing_keys}
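For reference, a sketch of the merge step used to verify a good checkpoint; the paths and model id are placeholders:

```python
# Illustrative merge check: load the base model, apply the saved LoRA
# adapter, and merge it in. Paths/names are placeholders.
from peft import PeftModel
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
merged = PeftModel.from_pretrained(base, "output")  # dir with the adapter files
merged = merged.merge_and_unload()
merged.save_pretrained("merged-model")
```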