diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py
index 5181b2c1a0ad..1d5f35d9fe55 100644
--- a/src/transformers/integrations/flex_attention.py
+++ b/src/transformers/integrations/flex_attention.py
@@ -7,6 +7,8 @@
 
 if is_torch_flex_attn_available():
     from torch.nn.attention.flex_attention import flex_attention
 
+    flex_attention = torch.compile(flex_attention, dynamic=False)
+
 
 def flex_attention_forward(
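A minimal, self-contained sketch (not part of the patch) of what the added lines do: compile `flex_attention` once at import time, as the patch does inside the `is_torch_flex_attn_available()` guard. Per the PyTorch docs, `flex_attention` is meant to be used under `torch.compile`, and `dynamic=False` specializes the compiled kernel to the observed shapes. The causal `score_mod`, tensor sizes, and CUDA device below are illustrative assumptions, not code from this file (requires PyTorch >= 2.5):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# Mirror the patch: compile once at module level rather than per call.
# dynamic=False specializes the kernel to the first shapes seen;
# PyTorch recompiles if shapes change later.
flex_attention = torch.compile(flex_attention, dynamic=False)


def causal(score, b, h, q_idx, kv_idx):
    # Standard causal score_mod from the PyTorch docs: mask future positions.
    return torch.where(q_idx >= kv_idx, score, -float("inf"))


# Hypothetical (batch, heads, seq_len, head_dim) sizes for illustration.
q, k, v = (torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16) for _ in range(3))

out = flex_attention(q, k, v, score_mod=causal)
print(out.shape)  # torch.Size([2, 8, 1024, 64])
```

Compiling at import time means the one-time compilation cost is paid when the integration module is first loaded, and every subsequent `flex_attention_forward` call reuses the same compiled kernel instead of falling back to the slower eager path.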