
Llama-3.2-11B-Vision-Instruct (mllama) FSDP fails if grad checkpointing is enabled #36040

Closed

@nikg4

System Info

1 node with 4 A100 40GB GPUs launched by SkyPilot (A100:4) on GCP

Who can help?

What happened?

FSDP SFT fine-tuning of meta-llama/Llama-3.2-90B-Vision-Instruct on 1 node with 4 A100-40GB GPUs with the TRL trainer (trl.SFTTrainer) started to fail for us after upgrading to transformers>=4.46, including transformers==4.48.2.

Sample error for sdpa attention:

[rank2]:     return self._call_impl(*args, **kwargs)
[rank2]:   File "/home/gcpuser/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank2]:     return forward_call(*args, **kwargs)
[rank2]:   File "/home/gcpuser/miniconda3/lib/python3.10/site-packages/transformers/models/mllama/modeling_mllama.py", line 798, in forward
[rank2]:     attn_output = torch.nn.functional.scaled_dot_product_attention(
[rank2]: RuntimeError: The expanded size of the tensor (46) must match the existing size (23) at non-singleton dimension 3.  Target sizes: [2, 32, 23, 46].  Tensor sizes: [2, 1, 23, 23]

It fails with similar error messages for eager attention as well.
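
Judging from the shapes in the message, the attn_mask passed to scaled_dot_product_attention is still sized for the query length (23) in its last dimension, while the key/value length is 46; SDPA requires the mask to be broadcastable to (batch, num_heads, query_len, key_len). The following standalone snippet with the same shapes (head_dim chosen arbitrarily) fails with the same kind of broadcast error and is included only to illustrate what the message means:

import torch
import torch.nn.functional as F

batch, heads, q_len, kv_len, head_dim = 2, 32, 23, 46, 128

q = torch.randn(batch, heads, q_len, head_dim)
k = torch.randn(batch, heads, kv_len, head_dim)
v = torch.randn(batch, heads, kv_len, head_dim)

# Mask shaped (batch, 1, q_len, q_len) instead of (batch, 1, q_len, kv_len):
# dimension 3 cannot be broadcast from 23 to 46, so SDPA raises a shape error.
bad_mask = torch.ones(batch, 1, q_len, q_len, dtype=torch.bool)

F.scaled_dot_product_attention(q, k, v, attn_mask=bad_mask)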

This affects both full-finetuning and LoRA tuning.

Disabling gradient checkpointing (with a smaller batch size) resolves the error.
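
For reference, this is the model-level switch in transformers that the setting corresponds to (in our runs it is driven through the trainer config and the FSDP_ACTIVATION_CHECKPOINTING variable listed below, so this is only an illustration, and use_reentrant=False is an assumption rather than something we verified):

import torch
from transformers import MllamaForConditionalGeneration

model = MllamaForConditionalGeneration.from_pretrained(
    "meta-llama/Llama-3.2-90B-Vision-Instruct", torch_dtype=torch.bfloat16
)

# With checkpointing enabled, the attention forward hits the shape error above;
# calling gradient_checkpointing_disable() instead (plus a smaller batch size)
# lets training run without the error.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
# model.gradient_checkpointing_disable()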

Note that if we install transformers>=4.45.2,<4.46, training works without the error under the same settings, with gradient checkpointing either on or off. The regression is likely related to this attention refactor: #35235

Steps to reproduce the bug

  1. Install transformers>=4.48.2,<4.49 and trl>=0.13.0,<0.14
  2. FSDP-tune meta-llama/Llama-3.2-90B-Vision-Instruct using torchrun (a condensed sketch of the trainer wiring follows after the environment variables below)

Accelerate environment variables for FSDP:

{
    'ACCELERATE_DYNAMO_BACKEND': 'NO',
    'ACCELERATE_DYNAMO_MODE': 'default',
    'ACCELERATE_DYNAMO_USE_FULLGRAPH': 'False',
    'ACCELERATE_DYNAMO_USE_DYNAMIC': 'False',
    'FSDP_CPU_RAM_EFFICIENT_LOADING': 'true',
    'FSDP_USE_ORIG_PARAMS': 'true',
    'ACCELERATE_USE_FSDP': 'true',
    'FSDP_SHARDING_STRATEGY': 'HYBRID_SHARD',
    'FSDP_OFFLOAD_PARAMS': 'false',
    'FSDP_BACKWARD_PREFETCH': 'BACKWARD_PRE',
    'FSDP_FORWARD_PREFETCH': 'false',
    'FSDP_STATE_DICT_TYPE': 'FULL_STATE_DICT',
    'FSDP_AUTO_WRAP_POLICY': 'TRANSFORMER_BASED_WRAP',
    'FSDP_MIN_NUM_PARAMS': '100000',
    'FSDP_TRANSFORMER_CLS_TO_WRAP': 'MllamaSelfAttentionDecoderLayer,MllamaCrossAttentionDecoderLayer,MllamaVisionEncoderLayer',
    'FSDP_SYNC_MODULE_STATES': 'true',
    'FSDP_ACTIVATION_CHECKPOINTING': 'true',
}
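
For step 2, here is a condensed sketch of how the trainer is wired (launched with torchrun and the accelerate FSDP variables above). Our real runs use an in-house multimodal dataset and collator, so the tiny text-only dataset and the hyperparameter values below are placeholders that keep the sketch self-contained; it is not a verified standalone repro:

import torch
from datasets import Dataset
from transformers import AutoTokenizer, MllamaForConditionalGeneration
from trl import SFTConfig, SFTTrainer

model_id = "meta-llama/Llama-3.2-90B-Vision-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = MllamaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16)

# Placeholder data; the real runs use our own vision SFT dataset and collator.
train_dataset = Dataset.from_list([{"text": "Hello, world."}] * 64)

args = SFTConfig(
    output_dir="out",
    per_device_train_batch_size=2,   # illustrative value
    gradient_checkpointing=True,     # checkpointing on; see also FSDP_ACTIVATION_CHECKPOINTING above
    bf16=True,
)

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    processing_class=tokenizer,
)
trainer.train()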

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I don't yet have a standalone repro script for this issue (it was reproduced as part of a different system). If one is required and you can't easily reproduce the issue with your own scripts based on the description above, please let me know.

Expected behavior

No error: training should run with gradient checkpointing enabled, as it does under the same settings with transformers>=4.45.2,<4.46.
