FSDP Torch XLA vs. FSDPv2 (SMPD) Torch XLA checkpoint saving bug #36004

Open
4 tasks
salrowili opened this issue Feb 1, 2025 · 5 comments
salrowili commented Feb 1, 2025

System Info

There is a bug in how the trainer (SFTTrainer) saves the checkpoint when we use FSDPv2 (SPMD) on TPU. This behavior does not show up with the old method of running Torch XLA code (xla_spawn.py). The bug causes the saved checkpoint to be almost identical to the base model, throwing this error with PEFT:

Found missing adapter keys while loading the checkpoint: {missing_keys}

Even without PEFT, the weights of the model do not seem to be affected by the training process.

The problem may be related to how the saving function for FSDPv2 Torch XLA works in the trainer file. The same code works correctly on GPU and also with the xla_spawn.py FSDP method.

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

To replicate, save the code below as sft.py and run it with PJRT_DEVICE=TPU XLA_USE_SPMD=1 python3 sft.py:

import torch
import torch_xla
import torch_xla.core.xla_model as xm
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import SFTTrainer, SFTConfig
import wandb

wandb.init(mode="disabled")
device = xm.xla_device()  # Set up TPU device.
print(device)

def train():
    model_id = "meta-llama/Llama-3.2-1B-Instruct"
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token
    data = load_dataset("philschmid/dolly-15k-oai-style", split="train")
    lora_config = LoraConfig(r=8, target_modules=["k_proj", "v_proj"], task_type="CAUSAL_LM")
    fsdp_config = {
        "fsdp_transformer_layer_cls_to_wrap": ["LlamaDecoderLayer"],
        "xla": True,
        "xla_fsdp_v2": True,
        "xla_fsdp_grad_ckpt": True,
    }
    args = SFTConfig(
        per_device_train_batch_size=8,
        num_train_epochs=1,
        max_steps=-1,
        output_dir="output",
        optim="adafactor",
        logging_steps=50,
        learning_rate=2e-5,
        max_seq_length=2048,
        packing=True,
        dataset_text_field=None,
        save_strategy="no",
        dataloader_drop_last=True,  # Required for SPMD.
        fsdp="full_shard",
        fsdp_config=fsdp_config,
    )
    trainer = SFTTrainer(
        model=model,
        train_dataset=data,
        tokenizer=tokenizer,
        args=args,
        peft_config=lora_config,
    )
    trainer.train()
    final_model = trainer.model
    final_model.to("cpu")
    final_model.save_pretrained("./LoRa")

if __name__ == "__main__":
    train()

You will notice that the saved model in the output folder is not in LoRA format (there are no adapter files, i.e. adapter_config.json and adapter_model.safetensors). This is because with FSDPv2 we end up here (you can check by adding a print statement):

state_dict = xm._maybe_convert_to_cpu(model.state_dict())
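
A quick way to confirm the symptom is to inspect the output directory after training; a correct PEFT save should contain the two adapter files, while the FSDPv2/SPMD run produces a full model dump instead:

import os

# A correct PEFT save of ./LoRa contains adapter_config.json and adapter_model.safetensors;
# under FSDPv2 (SPMD) the directory instead holds a full (base-sized) model checkpoint.
print(sorted(os.listdir("./LoRa")))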

However, if we use the same code on GPU or with the old xla_spawn (FSDP) method, this issue disappears. To replicate with FSDP, first run
wget https://raw.githubusercontent.com/huggingface/transformers/refs/heads/main/examples/pytorch/xla_spawn.py
then save the code below and run it with python3 xla_spawn.py --num_cores <num_cores> sft.py:

from datasets import load_dataset
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig
import wandb

wandb.init(mode="disabled")

def main():
    data = load_dataset("philschmid/dolly-15k-oai-style", split="train")
    model_id = "meta-llama/Llama-3.2-1B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.add_special_tokens({"pad_token": tokenizer.eos_token})
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

    # target_modules=["k_proj", "v_proj", "embed_tokens", "lm_head"]
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        target_modules=["q_proj", "k_proj", "v_proj", "embed_tokens", "lm_head"],
        task_type="CAUSAL_LM",
    )

    trainer = SFTTrainer(
        model=model,
        train_dataset=data,
        args=SFTConfig(
            per_device_train_batch_size=1,
            num_train_epochs=3,
            max_steps=-1,
            output_dir="./output",
            logging_steps=50,
            learning_rate=5e-5,
            max_seq_length=2048,
            save_steps=1000000,
            save_only_model=True,
            packing=True,
            dataset_num_proc=40,
        ),
        peft_config=lora_config,
    )

    trainer.train()
    final_model = trainer.model
    final_model.to("cpu")
    final_model.save_pretrained("./LoRa")

def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()

if __name__ == "__main__":
    main()

With this code everything works great, because the saving function ends up here:

model.save_pretrained(

I merged the LoRA adapter with the base model and the generated output is as expected from a fine-tuned model!
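
For reference, the merge can be done with PEFT along these lines (a minimal sketch; the paths are illustrative):

import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel

# Attach the saved adapter to the base model and fold the LoRA weights into it.
base = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct", torch_dtype=torch.bfloat16)
merged = PeftModel.from_pretrained(base, "./LoRa").merge_and_unload()
merged.save_pretrained("./merged")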

Finally, please note that this issue is not related to PEFT, because even if you use SFTTrainer without PEFT, the issue still exists. I believe it has to do with how we save checkpoints with FSDPv2 when using TPUs.

Expected behavior

The model with LoRA should be saved as two adapter files, and when we merge LoRA with the base model we should not get this message (you should update PEFT to the latest version (0.14.0), as it adds an additional check to detect problems with LoRA checkpoints):

Found missing adapter keys while loading the checkpoint: {missing_keys}

salrowili added the bug label on Feb 1, 2025
@Rocketknight1
Member

cc @muellerzr @SunMarc for Trainer

@SunMarc
Member

SunMarc commented Feb 4, 2025

Could you try this PR that was never finished? #29780
If this solves the issue, I will finish it. The potential issue was that the model was not unwrapped correctly for FSDPv2.

@salrowili
Author

Thank you @SunMarc and @Rocketknight1 for the help! The code is now working with the solution you suggested.

I made the following changes to the current version of Transformers, as #29780 suggests:

Take the existing function

def unwrap_model(model: nn.Module, recursive: bool = False) -> nn.Module:

and replace it with this one:

def unwrap_model(model: nn.Module, recursive: bool = False) -> nn.Module:
    """
    Recursively unwraps a model from potential containers (as used in distributed training).

    Args:
        model (`torch.nn.Module`): The model to unwrap.
        recursive (`bool`, *optional*, defaults to `False`):
            Whether to recursively extract all cases of `module.module` from `model` as well as unwrap child sublayers
            recursively, not just the top-level distributed containers.
    """
    # Use accelerate implementation if available (should always be the case when using torch)
    # This is for pytorch, as we also have to handle things like dynamo
    def recursive_unwrap(module):
        if hasattr(module, "module"):
            unwrapped_module = recursive_unwrap(getattr(module, "module"))
        else:
            unwrapped_module = module  # Handle cases where wrapped module is inaccessible

        # Unwrap child sublayers recursively
        for name, child in module.named_children():
            setattr(module, name, recursive_unwrap(child))

        return unwrapped_module

    # Start with top-level unwrapping
    unwrapped_model = recursive_unwrap(model)
    return unwrapped_model

and for the trainer:

    def _save_tpu(self, output_dir: Optional[str] = None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir

        logger.info(f"Saving model checkpoint to {output_dir}")
        model = self.model
        xm.mark_step()
        model.to("cpu")
        if xm.is_master_ordinal(local=False):
            os.makedirs(output_dir, exist_ok=True)
            torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        unwrap_model(model).save_pretrained(
            output_dir,
            is_main_process=self.args.should_save,
            state_dict=unwrap_model(model).state_dict(),
            save_function=xm.save,
            safe_serialization=self.args.save_safetensors,
        )
        exit()

I call unwrap_model(model).save_pretrained directly inside _save_tpu and added model.to("cpu"), because the current series of if statements in the _save_tpu function does not work properly for FSDPv2 on TPU:

if self.is_fsdp_xla_v1_enabled:

if isinstance(self.accelerator.unwrap_model(model), supported_classes):

Neither statement holds true when FSDPv2 is used on TPU, so we end up here:

state_dict = xm._maybe_convert_to_cpu(model.state_dict())

and unwrap_model(model).save_pretrained is never called.

FYI, in case it helps, self.accelerator.unwrap_model(model) returns this value:

SpmdFullyShardedDataParallel(                                                                                                                                                  
  (_orig_module): PeftModelForCausalLM(                                                                                                                                        
    (base_model): LoraModel(                                                                                                                                                   
      (model): LlamaForCausalLM(                                                                                                                                               
        (model): LlamaModel(                                                                                                                                                   
          (embed_tokens): Embedding(128256, 2048)                                                                                                                              
          (layers): ModuleList(                                                                                                                                                
            (0-15): 16 x SpmdFullyShardedDataParallel(                                                                                                                         
              (_orig_module): LlamaDecoderLayer(                                                                                                                               
                (self_attn): LlamaAttention(                                                                                                                                   
                  (q_proj): Linear(in_features=2048, out_features=2048, bias=False)                                                                                            
                  (k_proj): lora.Linear(                                                                                                                                       
                    (base_layer): Linear(in_features=2048, out_features=512, bias=False)                                                                                       
                    (lora_dropout): ModuleDict(                                                                                                                                
                      (default): Identity()                                                                                                                                    
                    )                                                                                                                                                          
                    (lora_A): ModuleDict(                                                                                                                                      
                      (default): Linear(in_features=2048, out_features=8, bias=False)                                                                                          
                    )                                                                                                                                                          
                    (lora_B): ModuleDict(                                                                                                                                      
                      (default): Linear(in_features=8, out_features=512, bias=False)                                                                                           
                    )
                    ..........................................
                    ..........................................                                  

@SunMarc
Member

SunMarc commented Feb 5, 2025

Glad that it works! Would you like to open a PR to propose your fix? That would be nice, as I don't have access to a TPU to verify that it actually works. Note that you can set recursive to True in the unwrap_model function, as this is now supported in accelerate. More details in this comment.
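
A minimal sketch of that suggestion, assuming accelerate's extract_model_from_parallel with its recursive flag as the underlying helper (the variable names follow the trainer code quoted above):

from accelerate.utils import extract_model_from_parallel

# `model` is the wrapped trainer model; recursive=True also strips the nested
# SpmdFullyShardedDataParallel wrappers around each decoder layer before saving.
unwrapped = extract_model_from_parallel(model, recursive=True)
unwrapped.save_pretrained(output_dir, save_function=xm.save)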

@salrowili
Author

Hi @SunMarc,
OK, I will work on this, but since it may affect other environments (e.g., FSDPv2 on GPU), I will spend more time testing different environments before proposing my fix.
