Add return_attention_scores support to CachedMultiHeadAttention #2213
Conversation
# Call the parent class constructor
super().__init__(num_heads, key_dim, **kwargs)
# New flag to optionally return attention scores
self._return_attention_scores = return_attention_scores
this is a call arg in the super class here - https://github.com/keras-team/keras/blob/44a655bdb28037046ab279a49d4cd679fea7ca50/keras/src/layers/attention/multi_head_attention.py#L523
Also, if flash attention is used via `ops.dot_product_attention`, attention scores will not be returned.
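For reference, a minimal sketch of how `return_attention_scores` behaves as a call-time argument on the base Keras layer (shapes shown in comments are what Keras documents; exact behavior may vary across Keras versions, and scores may be unavailable when flash attention kernels are used):

```python
import numpy as np
import keras

# return_attention_scores is passed per call, not to the constructor.
mha = keras.layers.MultiHeadAttention(num_heads=2, key_dim=16)

query = np.random.rand(1, 4, 16).astype("float32")
value = np.random.rand(1, 4, 16).astype("float32")

output, scores = mha(query, value, return_attention_scores=True)
print(output.shape)  # (1, 4, 16)
print(scores.shape)  # (1, 2, 4, 4) -> (batch, num_heads, query_len, key_len)
```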
this is a call arg in the super class here
Makes sense now. Instead of manually setting the flag, it made sense to just pass `return_attention_scores` into `super().__init__()`, since the base MHA layer handles it internally.
I've pushed the fix with that change; let me know if further changes are needed.
… updating call logic
@@ -63,7 +63,13 @@ class CachedMultiHeadAttention(keras.layers.MultiHeadAttention):
projected to the shape specified by `output_shape`. `cache` is the
We should probably update the returns section with this. Maybe format it as a list so it reads more easily?
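For example, the returns section could be rewritten as a list along these lines (illustrative wording only, not the final docstring):

```python
"""
Returns:
    A tuple of:
    - `attention_output`: the attention result, with the last dimension
      projected to the shape specified by `output_shape`.
    - `cache`: the updated key/value cache.
    - `attention_scores`: (optional) the per-head attention coefficients,
      returned only when attention scores are requested.
"""
```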
@@ -63,7 +63,13 @@ class CachedMultiHeadAttention(keras.layers.MultiHeadAttention):
projected to the shape specified by `output_shape`. `cache` is the
updated cache.
"""

def __init__(self, num_heads, key_dim, return_attention_scores=False, **kwargs):
I don't think we need this section, right? The super class's `__init__` does not take `return_attention_scores` as far as I can tell. It's an argument to `call`, so we'd need to add it there instead.
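Roughly, a sketch of the call-side change being suggested; the internal helper names (`_compute_attention`, `_output_dense`), the existing parameters, and the return ordering are assumed from the current layer and may not match exactly:

```python
def call(
    self,
    query,
    value,
    key=None,
    attention_mask=None,
    cache=None,
    cache_update_index=None,
    return_attention_scores=False,  # new call argument, mirroring keras MHA
    training=None,
):
    # ... existing projection and cache-update logic stays unchanged ...
    attention_output, attention_scores = self._compute_attention(
        query, key, value, attention_mask=attention_mask, training=training
    )
    attention_output = self._output_dense(attention_output)
    if return_attention_scores:
        # Ordering here is a placeholder; the PR would fix the actual contract.
        return attention_output, cache, attention_scores
    return attention_output, cache
```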
Thanks for the review @mattdangerw!
I just wanted to double-check before making changes:
Per the current Keras `MultiHeadAttention` implementation (link here), it looks like `return_attention_scores` is indeed accepted as a constructor (`__init__`) argument now.
So in this case, it seemed appropriate to forward it via `super().__init__()` rather than manually setting the private attribute.
Could you please specify the adjustments needed?
I'm confused. The piece of code you are linking shows `return_attention_scores` as a `call` argument, not an `__init__` argument. We should do the same here. Nowhere in `__init__` does `MultiHeadAttention` take `return_attention_scores`, so this would crash.
This PR addresses #2055, where `attention_scores` was always `None` because the `_return_attention_scores` flag was never set in the `CachedMultiHeadAttention` subclass.
In recent Keras versions, the base `MultiHeadAttention` layer uses a private flag, `self._return_attention_scores`, to decide whether or not to return attention scores from `_compute_attention`.
However, `CachedMultiHeadAttention` was not passing or setting this flag at all, so attention scores were silently dropped, making them inaccessible for debugging or analysis.
This PR does the following (see the usage sketch after this list):
1. Adds `return_attention_scores` as an optional argument to the constructor (default `False`, just like in base MHA).
2. Sets `self._return_attention_scores` appropriately.
3. Updates the `call()` method to return `attention_scores` alongside `attention_output` and `cache` when requested, fully preserving existing behavior otherwise.
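A minimal usage sketch of the behavior this PR aims to enable. The import path and the exact return signature are assumptions, and whether the flag ends up on the constructor or on `call()` is still being discussed in the review above:

```python
import numpy as np

# Import path is an assumption; the layer lives under
# keras_hub/src/layers/modeling/ in the repository.
from keras_hub.src.layers.modeling.cached_multi_head_attention import (
    CachedMultiHeadAttention,
)

# As proposed in this PR, the flag is accepted by the constructor.
layer = CachedMultiHeadAttention(
    num_heads=4, key_dim=16, return_attention_scores=True
)

x = np.random.rand(2, 8, 64).astype("float32")  # (batch, seq_len, features)

# Without the change, only (attention_output, cache) is returned and the
# scores are silently dropped; with it, the scores are surfaced as well.
outputs = layer(x, x)
```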