Add return_attention_scores support to CachedMultiHeadAttention #2213
base: master
Changes from all commits
@@ -63,7 +63,13 @@ class CachedMultiHeadAttention(keras.layers.MultiHeadAttention):
         projected to the shape specified by `output_shape`. `cache` is the
         updated cache.
     """

+    def __init__(self, num_heads, key_dim, return_attention_scores=False, **kwargs):
Review thread on the added `__init__`:

mattdangerw: I don't think we need this section, right? The super class on ...

PR author: Thanks for the review @mattdangerw! I just wanted to double-check before making changes. In this case, it seemed appropriate to forward it via `super().__init__()` rather than manually setting the private attribute. Could you please specify the adjustments needed?

mattdangerw: I'm confused. The piece of code you are linking is showing ...

PR author: I see where the confusion came in now; thanks for the clarification, @mattdangerw! The version I was looking at previously must have been an earlier iteration where `return_attention_scores` was part of `__init__`. Totally makes sense now that it's handled in `call()` instead. I'll update the PR accordingly: drop the `__init__` arg, manually set the flag in `call()`, and revise the docstring to reflect the conditional returns.
+        super().__init__(
+            num_heads,
+            key_dim,
+            return_attention_scores=return_attention_scores,
+            **kwargs,
+        )

     def call(
         self,
         query,
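For reference, a minimal sketch of the direction agreed in the thread above: no `__init__` override, with the flag accepted per call and the private attribute set inside `call()`. This is an illustrative subclass of the stock Keras layer, not the keras-hub implementation; the `cache` handling is elided and the argument names other than `return_attention_scores` are assumptions.

```python
import numpy as np
import keras


class CachedAttentionSketch(keras.layers.MultiHeadAttention):
    """Illustrative only: mirrors the plan to handle the flag in call()."""

    def call(self, query, value, cache=None, return_attention_scores=False, **kwargs):
        # Record the request ("manually set the flag in call()"), per the thread above.
        self._return_attention_scores = return_attention_scores
        # A real cached layer would read and update `cache` here; this sketch
        # simply delegates to the stock attention computation.
        attention_output, attention_scores = super().call(
            query, value, return_attention_scores=True, **kwargs
        )
        # Same conditional returns as proposed in the second hunk below.
        if return_attention_scores:
            if cache is not None:
                return attention_output, attention_scores, cache
            return attention_output, attention_scores
        if cache is not None:
            return attention_output, cache
        return attention_output


layer = CachedAttentionSketch(num_heads=2, key_dim=8)
x = np.random.uniform(size=(1, 4, 16)).astype("float32")
out, scores = layer(query=x, value=x, return_attention_scores=True)
print(out.shape, scores.shape)  # (1, 4, 16) (1, 2, 4, 4)
```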
@@ -118,6 +124,12 @@ def call(
         attention_output = self._output_dense(attention_output)

-        if cache is not None:
-            return attention_output, cache
-        return attention_output
+        # Return the attention scores alongside the output when requested.
+        if self._return_attention_scores:
+            if cache is not None:
+                return attention_output, attention_scores, cache
+            return attention_output, attention_scores
+        else:
+            if cache is not None:
+                return attention_output, cache
+            return attention_output
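For comparison, the stock Keras layer already accepts this flag at call time and returns an `(output, scores)` pair when it is set; the hunk above extends that contract to the cached layer by additionally returning `cache` when one is passed. A quick check against plain `keras.layers.MultiHeadAttention` (not the patched layer):

```python
import numpy as np
import keras

layer = keras.layers.MultiHeadAttention(num_heads=4, key_dim=16)
x = np.random.uniform(size=(2, 8, 64)).astype("float32")

out = layer(query=x, value=x)  # attention output only
out, scores = layer(query=x, value=x, return_attention_scores=True)
print(out.shape, scores.shape)  # (2, 8, 64) (2, 4, 8, 8)
```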
Review comment: We should probably update the returns section with this. Maybe format as a list so it reads easier?
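One possible shape for that returns section, formatted as a list per the suggestion above (a wording sketch based on the branching in this PR, not the merged docstring):

```python
"""
Returns:
    One of the following, depending on the arguments:

    - `attention_output`, if no `cache` is passed and attention scores are
      not requested.
    - `(attention_output, cache)`, if a `cache` is passed and attention
      scores are not requested; `cache` is the updated key/value cache.
    - `(attention_output, attention_scores)`, if attention scores are
      requested and no `cache` is passed.
    - `(attention_output, attention_scores, cache)`, if attention scores
      are requested and a `cache` is passed.
"""
```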