From 3059ab4a0044cb664dd8f2bfbb1728b1d4a37876 Mon Sep 17 00:00:00 2001
From: Sami Jaghouar
Date: Sat, 1 Mar 2025 02:15:49 +0000
Subject: [PATCH] torch compiled flex attention

Signed-off-by: Sami Jaghouar

torch compiled flash attention

Signed-off-by: Sami Jaghouar
---
 src/transformers/integrations/flex_attention.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py
index 5181b2c1a0ad..1d5f35d9fe55 100644
--- a/src/transformers/integrations/flex_attention.py
+++ b/src/transformers/integrations/flex_attention.py
@@ -7,6 +7,8 @@
 
 if is_torch_flex_attn_available():
     from torch.nn.attention.flex_attention import flex_attention
 
+    flex_attention = torch.compile(flex_attention, dynamic=False)
+
 
 def flex_attention_forward(
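
Note (not part of the patch): below is a minimal standalone sketch of what the change does, namely wrapping torch.nn.attention.flex_attention in torch.compile(dynamic=False) at import time so subsequent calls go through the fused compiled kernel instead of the slow eager fallback. The causal score_mod, tensor shapes, and CUDA device are illustrative assumptions, not the transformers code path; it assumes PyTorch >= 2.5 with flex attention available.

import torch
from torch.nn.attention.flex_attention import flex_attention

# Same pattern as the patch: compile the function once, reuse the compiled callable.
flex_attention = torch.compile(flex_attention, dynamic=False)


def causal(score, b, h, q_idx, kv_idx):
    # Illustrative score_mod: mask out future key positions (causal attention).
    return torch.where(q_idx >= kv_idx, score, -float("inf"))


if __name__ == "__main__":
    # Hypothetical shapes: (batch, heads, seq_len, head_dim).
    q, k, v = (torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16) for _ in range(3))
    out = flex_attention(q, k, v, score_mod=causal)
    print(out.shape)  # torch.Size([1, 8, 128, 64])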