It's worth not hiding exceptions #21432

Open
wants to merge 1 commit into master
43 changes: 23 additions & 20 deletions keras/src/backend/jax/nn.py
@@ -1,6 +1,8 @@
 import builtins
 import math
 
+from absl import logging
+
 import jax
 import jax.experimental.sparse as jax_sparse
 import jax.numpy as jnp
@@ -1256,6 +1258,9 @@ def dot_product_attention(
     # TPU-specific flash attention path
     if is_tpu and flash_attention:
         # Get sharding parameters from distribution context
+        head_shards = 1
+        # Typically keep q_seq_shards=1 for best performance
+        q_seq_shards = 1
         try:
             from keras.src.distribution.distribution_lib import ModelParallel
             from keras.src.distribution.distribution_lib import (
@@ -1270,12 +1275,12 @@ def dot_product_attention(
                     model_dim_index = mesh.axis_names.index("model")
                     # Set head_shards based on the model dimension of the mesh
                     head_shards = mesh.shape[model_dim_index]
-                    # Typically keep q_seq_shards=1 for best performance
-                    q_seq_shards = 1
         except (ImportError, ValueError, AttributeError):
             # Use default values if detection fails
-            head_shards = 1
-            q_seq_shards = 1
+            logging.exception(
+                "Failed to determine distribution context for sharding. "
+                "Using default head_shards=1 and q_seq_shards=1."
+            )
         # Transpose to ('batch', 'heads', 'length', 'head_dim')
         query_tpu_layout = jnp.transpose(query, axes=(0, 2, 1, 3))
         key_tpu_layout = jnp.transpose(key, axes=(0, 2, 1, 3))
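This hunk moves the head_shards / q_seq_shards defaults ahead of the try block (previous hunk) and logs a detection failure instead of silently re-assigning the same defaults. For context, absl's logging.exception is meant to be called from inside an except handler: it logs at ERROR severity and appends the active traceback. A minimal standalone sketch of the pattern, where detect_head_shards() is a hypothetical stand-in for the mesh lookup above:

```python
from absl import logging


def detect_head_shards():
    """Hypothetical stand-in for the ModelParallel mesh lookup."""
    raise ValueError("no 'model' axis in the current device mesh")


# Defaults are assigned up front, so the handler only needs to log.
head_shards = 1
q_seq_shards = 1
try:
    head_shards = detect_head_shards()
except (ImportError, ValueError, AttributeError):
    # Logged at ERROR severity with the active traceback; execution
    # continues with the defaults instead of hiding the failure.
    logging.exception(
        "Failed to determine distribution context for sharding. "
        "Using default head_shards=1 and q_seq_shards=1."
    )
```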
@@ -1328,24 +1333,17 @@ def dot_product_attention(
             # Transpose output back to Keras layout
             return jnp.transpose(output, axes=(0, 2, 1, 3))
         except Exception:
+            logging.exception(
+                "Failed to apply Splash kernel for flash attention. "
+                "Falling back to JAX native dot_product_attention."
+            )
             flash_attention = False
 
     # JAX native dot_product_attention for GPU or fallback for TPU
     if hasattr(jax.nn, "dot_product_attention"):
-        try:
-            return jax.nn.dot_product_attention(
-                query,
-                key,
-                value,
-                bias=bias,
-                mask=mask,
-                scale=scale,
-                is_causal=is_causal,
-                implementation="cudnn" if flash_attention else "xla",
-            )
-        except Exception:
-            # If flash attention fails, fall back to XLA implementation
-            if flash_attention:
+        impls = ["cudnn", "xla"] if flash_attention else ["xla"]
+        for impl in impls:
+            try:
                 return jax.nn.dot_product_attention(
                     query,
                     key,
@@ -1354,9 +1352,14 @@
                     mask=mask,
                     scale=scale,
                     is_causal=is_causal,
-                    implementation="xla",
+                    implementation=impl,
                 )
-            raise
+            except Exception:
+                logging.exception(
+                    f"Failed to apply {impl} implementation of "
+                    "jax.nn.dot_product_attention."
+                )
+
 
     if flash_attention:
         raise RuntimeError(
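Taken together, the last two hunks replace the nested try/except (a cuDNN attempt with a hand-written XLA retry) with a loop over candidate implementations, logging each failure before moving on rather than discarding it. A rough standalone sketch of that control flow, assuming a JAX version that ships jax.nn.dot_product_attention; attention_with_fallback and the tensor shapes are illustrative, not Keras code:

```python
import jax
import jax.numpy as jnp
from absl import logging


def attention_with_fallback(query, key, value, flash_attention=False):
    # Try cuDNN first only when flash attention was requested.
    impls = ["cudnn", "xla"] if flash_attention else ["xla"]
    for impl in impls:
        try:
            return jax.nn.dot_product_attention(
                query, key, value, implementation=impl
            )
        except Exception:
            # Keep the traceback in the logs, then try the next backend.
            logging.exception(f"Failed to apply {impl} implementation.")
    return None  # the caller decides how to handle total failure


# (batch, seq_len, num_heads, head_dim); cuDNN typically fails on CPU,
# gets logged, and the call falls back to the XLA implementation.
q = k = v = jnp.ones((1, 8, 4, 16))
out = attention_with_fallback(q, k, v, flash_attention=True)
```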