
Add return_attention_scores support to CachedMultiHeadAttention #2213


Status: Open · wants to merge 2 commits into base: master
20 changes: 16 additions & 4 deletions keras_hub/src/layers/modeling/cached_multi_head_attention.py
@@ -63,7 +63,13 @@ class CachedMultiHeadAttention(keras.layers.MultiHeadAttention):
        projected to the shape specified by `output_shape`. `cache` is the
mattdangerw (Member) commented:

We should probably update the Returns section of the docstring with this. Maybe format it as a list so it reads more easily?
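
For example, the Returns section could read as a list along these lines (a wording sketch, not final docstring text):

    Returns:
        One of the following, depending on the arguments:
        - `attention_output` if no `cache` is passed and attention scores are
          not requested.
        - `(attention_output, cache)` if a `cache` is passed.
        - `(attention_output, attention_scores)` if scores are requested.
        - `(attention_output, attention_scores, cache)` if both.
        `cache` is the updated cache.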

        updated cache.
    """

+    def __init__(self, num_heads, key_dim, return_attention_scores=False, **kwargs):
mattdangerw (Member) commented:

I don't think we need this section, right? The superclass `__init__` does not take `return_attention_scores` as far as I can tell. It's an argument to `call`, so we'd need to add it there instead.

Author commented:

Thanks for the review @mattdangerw!

I just wanted to double-check before making changes:
As per the current Keras `MultiHeadAttention` implementation (link here), it looks like `return_attention_scores` is indeed accepted as a constructor (`__init__`) argument now.

So in this case, it seemed appropriate to forward it via `super().__init__()` rather than manually setting the private attribute.

Could you please specify the adjustments needed?

mattdangerw (Member) commented:

I'm confused. The piece of code you are linking shows `return_attention_scores` as a `call` argument, not an `__init__` argument. We should do the same here. Nowhere in `__init__` does `MultiHeadAttention` take `return_attention_scores`, so this would crash.

https://github.com/keras-team/keras/blob/37eacb012dff7bc6ffece2d5992faa4d279ed244/keras/src/layers/attention/multi_head_attention.py#L103-L122
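
For reference, the linked `call` signature looks roughly like this (paraphrased from the Keras source; the exact parameter list may differ between versions):

    def call(
        self,
        query,
        value,
        key=None,
        query_mask=None,
        value_mask=None,
        key_mask=None,
        attention_mask=None,
        return_attention_scores=False,
        training=None,
        use_causal_mask=False,
    ):
        ...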

Author commented:

I see where the confusion came in now. Thanks for the clarification, @mattdangerw!

The version I was looking at previously must have been an earlier iteration where `return_attention_scores` was part of `__init__`. Totally makes sense now that it's handled in `call()` instead.

I'll update the PR accordingly: drop the `__init__` arg, set the flag in `call()` instead, and revise the docstring to reflect the conditional returns.

+        super().__init__(
+            num_heads,
+            key_dim,
+            return_attention_scores=return_attention_scores,
+            **kwargs,
+        )
    def call(
        self,
        query,

@@ -118,6 +124,12 @@ def call(

        attention_output = self._output_dense(attention_output)

-        if cache is not None:
-            return attention_output, cache
-        return attention_output
+        # Also return the attention scores when they were requested.
+        if self._return_attention_scores:
+            if cache is not None:
+                return attention_output, attention_scores, cache
+            return attention_output, attention_scores
+        else:
+            if cache is not None:
+                return attention_output, cache
+            return attention_output
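
For illustration, the intended usage after the revision discussed above would look roughly like this (a sketch: the internal import path is taken from this diff's file name, and the `call`-level flag follows the reviewer's suggestion rather than the current diff):

    import numpy as np
    from keras_hub.src.layers.modeling.cached_multi_head_attention import (
        CachedMultiHeadAttention,
    )

    layer = CachedMultiHeadAttention(num_heads=2, key_dim=16)
    x = np.random.uniform(size=(1, 4, 16)).astype("float32")

    # With no cache and scores requested, the layer would return a 2-tuple.
    outputs, scores = layer(query=x, value=x, return_attention_scores=True)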