From 2e929292bf0dbd72ddd82d25e841888ca49552e6 Mon Sep 17 00:00:00 2001
From: pass_lin <935499957@qq.com>
Date: Wed, 2 Apr 2025 08:36:49 +0800
Subject: [PATCH 1/2] make torch use flash-attn

---
 keras_hub/src/utils/keras_utils.py | 43 ++++++++++++++++++++----------
 1 file changed, 29 insertions(+), 14 deletions(-)

diff --git a/keras_hub/src/utils/keras_utils.py b/keras_hub/src/utils/keras_utils.py
index d247f7c254..34c4daf59c 100644
--- a/keras_hub/src/utils/keras_utils.py
+++ b/keras_hub/src/utils/keras_utils.py
@@ -56,20 +56,35 @@ def standardize_data_format(data_format):
 
 
 def has_flash_attention_support():
-    if (
-        hasattr(keras.config, "is_flash_attention_enabled")
-        and keras.config.backend() == "jax"
-    ):
-        try:
-            from jax.nn import dot_product_attention as dot_product_attention
-        except ImportError:
-            logging.warning(
-                "Flash attention is not supported in your current JAX version. "
-                "Please update it by following the official guide: "
-                "https://jax.readthedocs.io/en/latest/installation.html"
-            )
-            return False
-        return True
+    if hasattr(keras.config, "is_flash_attention_enabled"):
+        if keras.config.backend() == "jax":
+            try:
+                from jax.nn import (
+                    dot_product_attention as dot_product_attention,
+                )
+            except ImportError:
+                logging.warning(
+                    "Flash-attn is not supported in your current JAX version. "
+                    "Please update it by following the official guide: "
+                    "https://jax.readthedocs.io/en/latest/installation.html"
+                )
+                return False
+            return True
+        elif keras.config.backend() == "torch":
+            try:
+                from torch.backends.cuda import SDPAParams  # noqa: F401
+                from torch.backends.cuda import (
+                    can_use_flash_attention,  # noqa: F401
+                )
+            except ImportError:
+                logging.warning(
+                    "Flash-attn is not supported in your current PyTorch "
+                    "version. Update it by following the official guide: "
+                    "https://pytorch.org/get-started/locally/"
+                )
+                return False
+            return True
+
     else:
         return False
 

From bd287b51e01dee94a1ff0e69deb3c772cb343dcb Mon Sep 17 00:00:00 2001
From: pass_lin <935499957@qq.com>
Date: Wed, 2 Apr 2025 08:44:06 +0800
Subject: [PATCH 2/2] make torch use flash-attn

---
 keras_hub/src/utils/keras_utils.py | 56 ++++++++++++++----------
 1 file changed, 27 insertions(+), 29 deletions(-)

diff --git a/keras_hub/src/utils/keras_utils.py b/keras_hub/src/utils/keras_utils.py
index 34c4daf59c..cf0e41d01f 100644
--- a/keras_hub/src/utils/keras_utils.py
+++ b/keras_hub/src/utils/keras_utils.py
@@ -56,35 +56,33 @@ def standardize_data_format(data_format):
 
 
 def has_flash_attention_support():
-    if hasattr(keras.config, "is_flash_attention_enabled"):
-        if keras.config.backend() == "jax":
-            try:
-                from jax.nn import (
-                    dot_product_attention as dot_product_attention,
-                )
-            except ImportError:
-                logging.warning(
-                    "Flash-attn is not supported in your current JAX version. "
-                    "Please update it by following the official guide: "
-                    "https://jax.readthedocs.io/en/latest/installation.html"
-                )
-                return False
-            return True
-        elif keras.config.backend() == "torch":
-            try:
-                from torch.backends.cuda import SDPAParams  # noqa: F401
-                from torch.backends.cuda import (
-                    can_use_flash_attention,  # noqa: F401
-                )
-            except ImportError:
-                logging.warning(
-                    "Flash-attn is not supported in your current PyTorch "
-                    "version. Update it by following the official guide: "
-                    "https://pytorch.org/get-started/locally/"
-                )
-                return False
-            return True
-
+    if not hasattr(keras.config, "is_flash_attention_enabled"):
+        return False
+    if keras.config.backend() == "jax":
+        try:
+            from jax.nn import dot_product_attention as dot_product_attention
+        except ImportError:
+            logging.warning(
+                "Flash attention is not supported in your current JAX version. "
+                "Please update it by following the official guide: "
+                "https://jax.readthedocs.io/en/latest/installation.html"
+            )
+            return False
+        return True
+    elif keras.config.backend() == "torch":
+        try:
+            from torch.backends.cuda import SDPAParams  # noqa: F401
+            from torch.backends.cuda import (
+                can_use_flash_attention,  # noqa: F401
+            )
+        except ImportError:
+            logging.warning(
+                "Flash attention is not supported in your current PyTorch "
+                "version. Please update it by following the official guide: "
+                "https://pytorch.org/get-started/locally/"
+            )
+            return False
+        return True
     else:
         return False
 
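
For context, below is a minimal usage sketch of how the new torch branch is expected to be exercised; it is not part of the patch. It assumes the torch backend and Keras 3's public flash-attention toggles (keras.config.enable_flash_attention() / keras.config.disable_flash_attention()), which are the counterpart of the is_flash_attention_enabled attribute the helper checks for; the import path of has_flash_attention_support is taken from the diff.

    # Usage sketch (not part of the patch): gate the Keras flash-attention
    # toggle on the helper modified above. Assumes the torch backend and
    # Keras 3's keras.config flash-attention API.
    import keras

    from keras_hub.src.utils.keras_utils import has_flash_attention_support

    if has_flash_attention_support():
        # With this patch, the torch backend qualifies whenever
        # torch.backends.cuda exposes SDPAParams and can_use_flash_attention,
        # i.e. a sufficiently recent PyTorch build.
        keras.config.enable_flash_attention()
    else:
        # Opt out explicitly so attention layers use the standard
        # (unfused) scaled dot-product computation.
        keras.config.disable_flash_attention()

The second commit also flattens the hasattr guard into an early return, which keeps the per-backend branches at a single indentation level.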