MultiHeadAttentionWrapper should instantiate CausalSelfAttention with d_out = d_out // num_heads? #609
Replies: 2 comments 3 replies
I believe the confusion lies in how we are interpreting it. In your impl, it's true that it's clearer in the sense that the output width matches what was specified.
@rasbt I hope it's OK if I borrow this thread to make a comment related to that code, specifically the `MultiHeadAttentionWrapper`. The book states that the single-head attention modules are processed sequentially via `[head(x) for head in self.heads]`. Or did you mean in terms of kernel launches? Like, a single batched matmul across all heads in the optimised `MultiHeadAttention` class allows you to process all heads with one kernel launch, compared to multiple ones with `MultiHeadAttentionWrapper`?
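To make the kernel-launch point concrete, here is a minimal sketch (not from the book) contrasting the two styles: one matmul call per head, as the wrapper's list comprehension effectively does, versus stacking the heads into a single tensor and issuing one batched matmul. The tensor shapes and names here are illustrative assumptions, not the book's actual code.

```python
import torch

num_heads, b, t, head_dim = 4, 2, 6, 8
qs = [torch.randn(b, t, head_dim) for _ in range(num_heads)]
ks = [torch.randn(b, t, head_dim) for _ in range(num_heads)]

# Wrapper-style: one matmul call per head (num_heads separate launches).
per_head = [q @ k.transpose(1, 2) for q, k in zip(qs, ks)]

# Optimised-style: stack heads into one tensor, then a single batched matmul.
q = torch.stack(qs, dim=1)       # (b, num_heads, t, head_dim)
k = torch.stack(ks, dim=1)
batched = q @ k.transpose(2, 3)  # all heads computed in one call

print(torch.allclose(torch.stack(per_head, dim=1), batched))  # True
```

Both produce the same attention scores; the batched form simply lets the backend fuse the per-head work into one operation.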

Since the `MultiHeadAttentionWrapper` class calls `torch.cat([head(x) for head in self.heads], dim=-1)`, shouldn't we be instantiating `CausalSelfAttention` with `d_out = d_out // num_heads`, so that the final `MultiHeadAttentionWrapper` output has the same shape and `d_out` as was specified in the input? In other words, is this a clearer implementation?