Description
System Info
1 node with 4 A100 40GB GPUs launched by SkyPilot (`A100:4`) on GCP
Who can help?
What happened?
FSDP SFT fine-tuning of `meta-llama/Llama-3.2-90B-Vision-Instruct` on 1 node with 4 A100-40GB GPUs with the TRL trainer (`trl.SFTTrainer`) started to fail for us after upgrading to `transformers>=4.46`, including `transformers==4.48.2`:

Sample error for `sdpa` attention:
```
[rank2]: return self._call_impl(*args, **kwargs)
[rank2]: File "/home/gcpuser/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank2]: return forward_call(*args, **kwargs)
[rank2]: File "/home/gcpuser/miniconda3/lib/python3.10/site-packages/transformers/models/mllama/modeling_mllama.py", line 798, in forward
[rank2]: attn_output = torch.nn.functional.scaled_dot_product_attention(
[rank2]: RuntimeError: The expanded size of the tensor (46) must match the existing size (23) at non-singleton dimension 3. Target sizes: [2, 32, 23, 46]. Tensor sizes: [2, 1, 23, 23]
```
It fails with similar error messages for `eager` attention as well. This affects both full fine-tuning and LoRA tuning. Disabling gradient checkpointing (with a smaller batch size) resolves the error.

Note that with `transformers>=4.45.2,<4.46` installed, training works without the error under the same settings, with gradient checkpointing on or off. The regression is likely related to this attention refactor: #35235
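For context, this is roughly how the failing run is configured on our side. This is a minimal sketch only, not our standalone repro: the dataset, collator, and processor handling for the vision model are omitted, and `output_dir`/`train_dataset` are placeholders. The relevant knob is `gradient_checkpointing`.

```python
# Minimal sketch (assumptions noted above): the SDPA shape error is raised with
# gradient_checkpointing=True; with it disabled (and a smaller batch) training runs.
import torch
from transformers import MllamaForConditionalGeneration
from trl import SFTConfig, SFTTrainer

model = MllamaForConditionalGeneration.from_pretrained(
    "meta-llama/Llama-3.2-90B-Vision-Instruct",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",  # fails with "eager" as well
)

args = SFTConfig(
    output_dir="./out",                      # placeholder
    per_device_train_batch_size=2,
    gradient_checkpointing=True,             # True -> error; False (smaller batch) -> no error
    gradient_checkpointing_kwargs={"use_reentrant": False},
    bf16=True,
)

# train_dataset is a placeholder for our (vision) SFT dataset.
trainer = SFTTrainer(model=model, args=args, train_dataset=train_dataset)
trainer.train()
```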
Steps to reproduce the bug
- Install `transformers>=4.48.2,<4.49`, `trl>=0.13.0,<0.14`
- FSDP fine-tune `meta-llama/Llama-3.2-90B-Vision-Instruct` using `torchrun` (see the launch sketch after the environment variables below)
Accelerate environment variables for FSDP:
```
{'ACCELERATE_DYNAMO_BACKEND': 'NO', 'ACCELERATE_DYNAMO_MODE': 'default', 'ACCELERATE_DYNAMO_USE_FULLGRAPH': 'False', 'ACCELERATE_DYNAMO_USE_DYNAMIC': 'False', 'FSDP_CPU_RAM_EFFICIENT_LOADING': 'true', 'FSDP_USE_ORIG_PARAMS': 'true', 'ACCELERATE_USE_FSDP': 'true', 'FSDP_SHARDING_STRATEGY': 'HYBRID_SHARD', 'FSDP_OFFLOAD_PARAMS': 'false', 'FSDP_BACKWARD_PREFETCH': 'BACKWARD_PRE', 'FSDP_FORWARD_PREFETCH': 'false', 'FSDP_STATE_DICT_TYPE': 'FULL_STATE_DICT', 'FSDP_AUTO_WRAP_POLICY': 'TRANSFORMER_BASED_WRAP', 'FSDP_MIN_NUM_PARAMS': '100000', 'FSDP_TRANSFORMER_CLS_TO_WRAP': 'MllamaSelfAttentionDecoderLayer,MllamaCrossAttentionDecoderLayer,MllamaVisionEncoderLayer', 'FSDP_SYNC_MODULE_STATES': 'true', 'FSDP_ACTIVATION_CHECKPOINTING': 'true'}
```
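To make the launch setup concrete, here is a sketch of how the FSDP variables above are applied in a `torchrun`-launched script. In our actual setup the launcher exports them; `train.py` below is a placeholder name for the training script.

```python
# Sketch only: mirrors the FSDP environment listed above. These variables must be
# set before the trainer (and hence Accelerate) is initialized.
import os

fsdp_env = {
    "ACCELERATE_USE_FSDP": "true",
    "FSDP_SHARDING_STRATEGY": "HYBRID_SHARD",
    "FSDP_AUTO_WRAP_POLICY": "TRANSFORMER_BASED_WRAP",
    "FSDP_TRANSFORMER_CLS_TO_WRAP": (
        "MllamaSelfAttentionDecoderLayer,"
        "MllamaCrossAttentionDecoderLayer,"
        "MllamaVisionEncoderLayer"
    ),
    "FSDP_STATE_DICT_TYPE": "FULL_STATE_DICT",
    "FSDP_BACKWARD_PREFETCH": "BACKWARD_PRE",
    "FSDP_FORWARD_PREFETCH": "false",
    "FSDP_USE_ORIG_PARAMS": "true",
    "FSDP_CPU_RAM_EFFICIENT_LOADING": "true",
    "FSDP_SYNC_MODULE_STATES": "true",
    "FSDP_ACTIVATION_CHECKPOINTING": "true",
    "FSDP_OFFLOAD_PARAMS": "false",
    "FSDP_MIN_NUM_PARAMS": "100000",
}
os.environ.update(fsdp_env)

# Launched on the 4-GPU node with, e.g.:
#   torchrun --nproc_per_node=4 train.py
```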
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
I don't yet have a standalone repro script for this issue (it was reproduced as part of a different system). If a standalone script is required and you can't easily reproduce the issue with your own scripts based on the description above, please let me know.
Expected behavior
Training completes without the error, as it does with `transformers>=4.45.2,<4.46` under the same settings.