Add return_attention_scores support to CachedMultiHeadAttention #2213

Open · wants to merge 2 commits into master
Conversation

DakshBegani

This PR addresses #2055, where attention_scores was always None due to the _return_attention_scores flag not being set in the CachedMultiHeadAttention subclass.

In recent Keras versions, the base MultiHeadAttention layer uses a private flag self._return_attention_scores to decide whether or not to return attention scores from _compute_attention.

However, CachedMultiHeadAttention was not passing or setting this flag at all, which meant attention_scores were silently dropped — making them inaccessible for debugging or analysis.

This PR does the following (see the sketch after this list):
1. Adds `return_attention_scores` as an optional constructor argument (default `False`, matching base MHA).
2. Sets `self._return_attention_scores` accordingly.
3. Updates the `call()` method to return `attention_scores` alongside `attention_output` and `cache` when requested, fully preserving existing behavior otherwise.
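
Below is a rough, self-contained sketch of the approach described above. It is illustrative only, not the PR's actual diff: it mirrors the idea on a toy subclass of the base Keras layer, stores the flag under a different attribute name to avoid clashing with base-layer internals, and leaves out the cache/cache_update_index handling that CachedMultiHeadAttention actually does.

import keras
import numpy as np

# Toy sketch of the described change (not the keras-hub implementation): accept
# the flag in __init__, store it, and return scores from call() when it is set.
class ToyCachedAttention(keras.layers.MultiHeadAttention):
    def __init__(self, num_heads, key_dim, return_attention_scores=False, **kwargs):
        super().__init__(num_heads, key_dim, **kwargs)
        self._want_attention_scores = return_attention_scores

    def call(self, query, value, **kwargs):
        if self._want_attention_scores:
            # The base layer exposes return_attention_scores as a call argument.
            return super().call(query, value, return_attention_scores=True, **kwargs)
        return super().call(query, value, **kwargs)

layer = ToyCachedAttention(num_heads=2, key_dim=16, return_attention_scores=True)
x = np.random.rand(1, 8, 16).astype("float32")
out, scores = layer(x, x)
print(out.shape, scores.shape)  # (1, 8, 16) and (1, 2, 8, 8)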

divyashreepathihalli added the kokoro:force-run (Runs Tests on GPU) label on Apr 15, 2025
kokoro-team removed the kokoro:force-run (Runs Tests on GPU) label on Apr 15, 2025
# Call the parent class constructor
super().__init__(num_heads, key_dim, **kwargs)
# New flag to optionally return attention scores
self._return_attention_scores = return_attention_scores
@divyashreepathihalli (Collaborator) commented on Apr 15, 2025:

This is a call arg in the super class here: https://github.com/keras-team/keras/blob/44a655bdb28037046ab279a49d4cd679fea7ca50/keras/src/layers/attention/multi_head_attention.py#L523

Also, if flash attention is used (via ops.dot_product_attention), attention scores will not be returned.
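
For reference, a quick check against the base layer shows the call-argument usage (shapes are illustrative; on the flash-attention path the scores may not be available, as noted above):

import keras
import numpy as np

mha = keras.layers.MultiHeadAttention(num_heads=2, key_dim=16)
q = np.random.rand(1, 8, 16).astype("float32")
v = np.random.rand(1, 8, 16).astype("float32")
# return_attention_scores is a call argument on the base MultiHeadAttention layer.
out, scores = mha(q, v, return_attention_scores=True)
print(out.shape, scores.shape)  # (1, 8, 16) and (1, 2, 8, 8)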

@DakshBegani (Author) replied:

> this is a call arg in the super class here

Makes sense now: instead of manually setting the flag, it made sense to just pass return_attention_scores into super().__init__(), since the base MHA layer handles it internally.

I've pushed the fix with that change; let me know if further changes are needed.

@@ -63,7 +63,13 @@ class CachedMultiHeadAttention(keras.layers.MultiHeadAttention):
projected to the shape specified by `output_shape`. `cache` is the
Member commented:

We should probably update the returns section with this. Maybe format as a list so it reads easier?
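
One possible rewrite of that Returns section as a list (wording illustrative, not the PR's final docstring):

Returns:
    An `(attention_output, cache)` tuple, or
    `(attention_output, attention_scores, cache)` when attention scores are
    requested, where:

    - `attention_output`: the result of the computation, projected to the
      shape specified by `output_shape`.
    - `attention_scores`: (optional) the attention coefficients.
    - `cache`: the updated cache.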

@@ -63,7 +63,13 @@ class CachedMultiHeadAttention(keras.layers.MultiHeadAttention):
projected to the shape specified by `output_shape`. `cache` is the
updated cache.
"""

def __init__(self, num_heads, key_dim, return_attention_scores=False, **kwargs):
@mattdangerw (Member) commented:

I don't think we need this, right? The super class's __init__ does not take in return_attention_scores as far as I can tell. It's an argument to call, so we'd need to add it there instead.

@DakshBegani (Author) replied:

Thanks for the review @mattdangerw!

I just wanted to double-check before making changes:
As per the current Keras MultiHeadAttention implementation (link here), it looks like return_attention_scores is indeed accepted as a constructor (__init__) argument now.

So in this case, it seemed appropriate to forward it via super().__init__() rather than manually setting the private attribute.

Could you please specify the adjustments needed?

@mattdangerw (Member) replied:

I'm confused. The piece of code you are linking is showing return_attention_scores as a call argument, not an __init__ argument. We should do the same here. Nowhere in __init__ does MultiHeadAttention take return_attention_scores, so this would crash.

https://github.com/keras-team/keras/blob/37eacb012dff7bc6ffece2d5992faa4d279ed244/keras/src/layers/attention/multi_head_attention.py#L103-L122
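
A minimal sketch of the call-argument approach being suggested here (illustrative only; the real CachedMultiHeadAttention.call also handles cache and cache_update_index, which are omitted):

import keras
import numpy as np

# Sketch: expose return_attention_scores as a *call* argument, mirroring the base
# layer, rather than accepting it in __init__ (which MultiHeadAttention does not).
class SketchCachedAttention(keras.layers.MultiHeadAttention):
    def call(self, query, value, return_attention_scores=False, **kwargs):
        # Cache handling from the real CachedMultiHeadAttention is omitted; this
        # just forwards the flag to the base layer's call().
        return super().call(
            query, value, return_attention_scores=return_attention_scores, **kwargs
        )

layer = SketchCachedAttention(num_heads=2, key_dim=16)
x = np.random.rand(1, 8, 16).astype("float32")
out, scores = layer(x, x, return_attention_scores=True)
print(out.shape, scores.shape)  # (1, 8, 16) and (1, 2, 8, 8)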
