Llama-3.2-11B-Vision-Instruct (mllama) FSDP fails if grad checkpointing is enabled #36040

xrdaukar commented Feb 5, 2025

System Info

1 node with 4 A100 40GB GPUs launched by SkyPilot (A100:4) on GCP

Who can help?

What happened?

FSDP SFT fine-tuning of meta-llama/Llama-3.2-90B-Vision-Instruct on 1 node with 4 A100-40GB GPUs with the TRL trainer (trl.SFTTrainer) started failing for us after upgrading to transformers>=4.46, including transformers==4.48.2:

Sample error for sdpa attention:

[rank2]:     return self._call_impl(*args, **kwargs)
[rank2]:   File "/home/gcpuser/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank2]:     return forward_call(*args, **kwargs)
[rank2]:   File "/home/gcpuser/miniconda3/lib/python3.10/site-packages/transformers/models/mllama/modeling_mllama.py", line 798, in forward
[rank2]:     attn_output = torch.nn.functional.scaled_dot_product_attention(
[rank2]: RuntimeError: The expanded size of the tensor (46) must match the existing size (23) at non-singleton dimension 3.  Target sizes: [2, 32, 23, 46].  Tensor sizes: [2, 1, 23, 23]

It fails with similar error messages for eager attention as well.
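For illustration only: the shapes in the traceback suggest the 4-D attention mask is still sized for a key length of 23 while the key/value tensors have length 46, which scaled_dot_product_attention rejects when broadcasting the mask. A minimal sketch of the same shape mismatch (batch, head count, and sequence lengths are taken from the traceback above; head_dim is arbitrary, not from the model):

import torch
import torch.nn.functional as F

q = torch.randn(2, 32, 23, 64)   # [batch, heads, q_len, head_dim]
k = torch.randn(2, 32, 46, 64)   # [batch, heads, kv_len, head_dim]
v = torch.randn(2, 32, 46, 64)
mask = torch.zeros(2, 1, 23, 23, dtype=torch.bool)  # last dim is 23, but kv_len is 46

# Raises: RuntimeError: The expanded size of the tensor (46) must match the
# existing size (23) at non-singleton dimension 3.
F.scaled_dot_product_attention(q, k, v, attn_mask=mask)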

This affects both full-finetuning and LoRA tuning.

Disabling grad checkpointing (together with a smaller batch size) resolves the error.
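For reference, a sketch of that workaround on the trainer side (assuming trl's SFTConfig; with FSDP, activation checkpointing would also be turned off via FSDP_ACTIVATION_CHECKPOINTING=false in the accelerate environment listed further below):

from trl import SFTConfig

# Workaround sketch, not a fix: turn off Trainer-level gradient checkpointing and
# use a smaller per-device batch to compensate for the extra activation memory.
args = SFTConfig(
    output_dir="mllama-fsdp-no-ckpt",
    gradient_checkpointing=False,
    per_device_train_batch_size=1,
    bf16=True,
)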

Note that if we install transformers>=4.45.2,<4.46, training works without the error under the same settings, with gradient checkpointing either on or off. The regression is likely related to this attention refactor: #35235

Steps to reproduce the bug

  1. Install transformers>=4.48.2,<4.49 and trl>=0.13.0,<0.14
  2. FSDP fine-tune meta-llama/Llama-3.2-90B-Vision-Instruct using torchrun (a minimal training-script sketch follows the FSDP settings below)

Accelerate environment variables for FSDP:

{'ACCELERATE_DYNAMO_BACKEND': 'NO',
 'ACCELERATE_DYNAMO_MODE': 'default',
 'ACCELERATE_DYNAMO_USE_FULLGRAPH': 'False',
 'ACCELERATE_DYNAMO_USE_DYNAMIC': 'False',
 'FSDP_CPU_RAM_EFFICIENT_LOADING': 'true',
 'FSDP_USE_ORIG_PARAMS': 'true',
 'ACCELERATE_USE_FSDP': 'true',
 'FSDP_SHARDING_STRATEGY': 'HYBRID_SHARD',
 'FSDP_OFFLOAD_PARAMS': 'false',
 'FSDP_BACKWARD_PREFETCH': 'BACKWARD_PRE',
 'FSDP_FORWARD_PREFETCH': 'false',
 'FSDP_STATE_DICT_TYPE': 'FULL_STATE_DICT',
 'FSDP_AUTO_WRAP_POLICY': 'TRANSFORMER_BASED_WRAP',
 'FSDP_MIN_NUM_PARAMS': '100000',
 'FSDP_TRANSFORMER_CLS_TO_WRAP': 'MllamaSelfAttentionDecoderLayer,MllamaCrossAttentionDecoderLayer,MllamaVisionEncoderLayer',
 'FSDP_SYNC_MODULE_STATES': 'true',
 'FSDP_ACTIVATION_CHECKPOINTING': 'true'}
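A minimal training-script sketch along these lines (not the exact script that produced the failure: the dataset and collator are placeholders, and only the model/processor identifiers above are real), launched with torchrun or accelerate launch under the FSDP environment shown above:

import torch
from transformers import AutoProcessor, MllamaForConditionalGeneration
from trl import SFTConfig, SFTTrainer

model_id = "meta-llama/Llama-3.2-90B-Vision-Instruct"

model = MllamaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16)
processor = AutoProcessor.from_pretrained(model_id)

args = SFTConfig(
    output_dir="mllama-fsdp-repro",
    per_device_train_batch_size=2,
    gradient_checkpointing=True,   # enabling this triggers the failure under FSDP
    bf16=True,
    max_steps=10,
)

trainer = SFTTrainer(
    model=model,
    args=args,
    processing_class=processor,
    train_dataset=train_dataset,     # placeholder: any small image+text SFT dataset
    data_collator=vision_collator,   # placeholder: collator producing input_ids, pixel_values, etc.
)
trainer.train()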

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I don't yet have a standalone repro script for this issue (it was reproduced as part of a different system). If a standalone script is required and you can't easily reproduce the issue with your own scripts based on the description above, please let me know.

Expected behavior

No error
