make torch use flash-attn #2189

Closed · wants to merge 2 commits
21 changes: 17 additions & 4 deletions keras_hub/src/utils/keras_utils.py
@@ -56,10 +56,9 @@ def standardize_data_format(data_format):
 
 
 def has_flash_attention_support():
-    if (
-        hasattr(keras.config, "is_flash_attention_enabled")
-        and keras.config.backend() == "jax"
-    ):
+    if not hasattr(keras.config, "is_flash_attention_enabled"):
+        return False
+    if keras.config.backend() == "jax":
         try:
             from jax.nn import dot_product_attention as dot_product_attention
         except ImportError:
@@ -70,6 +69,20 @@ def has_flash_attention_support():
             )
             return False
         return True
+    elif keras.config.backend() == "torch":
+        try:
+            from torch.backends.cuda import SDPAParams  # noqa: F401
+            from torch.backends.cuda import (
+                can_use_flash_attention,  # noqa: F401
+            )
+        except ImportError:
+            logging.warning(
+                "Flash attention is not supported in your current PyTorch "
+                "version. Please update it by following the official guide:"
+                "https://pytorch.org/get-started/locally/"
+            )
+            return False
+        return True
     else:
         return False

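Reviewer note: the new torch branch only confirms that torch.backends.cuda.SDPAParams and can_use_flash_attention are importable; it does not guarantee that PyTorch will actually dispatch to its flash attention kernel for a given device, dtype, and shape. The sketch below is one way to probe that at runtime for concrete tensors. It is not part of this PR, and it deliberately uses torch.nn.attention.sdpa_kernel (available in PyTorch 2.3+) instead of the SDPAParams/can_use_flash_attention pair imported above, because the SDPAParams constructor signature has changed across PyTorch releases; flash_attention_fires is an illustrative name.

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel


def flash_attention_fires(q, k, v):
    """Best-effort check that PyTorch can run flash attention for q/k/v."""
    if not torch.cuda.is_available():
        return False
    try:
        # Restrict scaled_dot_product_attention to the flash attention backend;
        # if that kernel cannot handle these inputs, PyTorch raises RuntimeError.
        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
            torch.nn.functional.scaled_dot_product_attention(q, k, v)
        return True
    except RuntimeError:
        return False


if torch.cuda.is_available():
    # Flash attention expects half precision and a (batch, heads, seq_len, head_dim) layout.
    q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
    print(flash_attention_fires(q, k, v))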