1. some of the blocks are only used by a single token, and some are shared
2. some of the key values have just been calculated and are available alongside the queries, so they don't need to be fetched from the cache
In a naive implementation we would simply multiply the whole query by the key and value tensors and use an appropriate bias to mask out the unused fields, but that would be very inefficient, especially for decodes, where we usually have only a single token per sample in a batch and there is almost no overlap between the blocks being used. We could slice the query and key into chunks and multiply only the regions that contain relevant data, but that is currently difficult to achieve for technical reasons. Instead, we can divide the work into 3 separate parts and merge the results at the end.
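To make the idea concrete, here is a minimal, hypothetical sketch (toy shapes, a single head, no 1/sqrt(d) scaling); the helper names `partial_attention` and `merge_partials` are made up for this example and do not come from the actual implementation. Each part computes unnormalized attention over its own keys and values together with a local maximum and local sum, and the partial results are merged at the end using the softmax-splitting technique described below.

```python
import torch

def partial_attention(q, k, v, bias=None):
    """Attention over one slice of keys/values, returning the unnormalized
    output plus the local max and local sum needed for merging later."""
    scores = q @ k.transpose(-1, -2)               # [num_q, num_k]
    if bias is not None:
        scores = scores + bias
    local_max = scores.max(dim=-1, keepdim=True).values
    exp_scores = torch.exp(scores - local_max)
    local_sum = exp_scores.sum(dim=-1, keepdim=True)
    out = exp_scores @ v                           # note: not divided by local_sum
    return out, local_max, local_sum

def merge_partials(parts):
    """Readjust each partial result to a common maximum and normalize once."""
    global_max = torch.stack([m for _, m, _ in parts]).max(dim=0).values
    out = sum(o * torch.exp(m - global_max) for o, m, _ in parts)
    denom = sum(s * torch.exp(m - global_max) for _, m, s in parts)
    return out / denom

# Toy decode step: 1 new token; keys split into a "causal" (freshly computed)
# part, a shared-blocks part and a unique-blocks part.
d = 64
q = torch.randn(1, d)
k_new, v_new = torch.randn(1, d), torch.randn(1, d)          # just computed
k_shared, v_shared = torch.randn(32, d), torch.randn(32, d)  # from shared blocks
k_unique, v_unique = torch.randn(16, d), torch.randn(16, d)  # from unique blocks

parts = [
    partial_attention(q, k_new, v_new),
    partial_attention(q, k_shared, v_shared),
    partial_attention(q, k_unique, v_unique),
]
attn_out = merge_partials(parts)

# Sanity check against attention over the full concatenated key/value.
k_all = torch.cat([k_new, k_shared, k_unique])
v_all = torch.cat([v_new, v_shared, v_unique])
ref = torch.softmax(q @ k_all.T, dim=-1) @ v_all
assert torch.allclose(attn_out, ref, atol=1e-5)
```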
The problem here lies in the denominator, as it contains the sum of all terms. Fortunately, we can split the calculation into two separate softmaxes and then readjust the results and combine them. Let's say we have:
$$z_1, z_2\text{ - local softmax results} \\ c_1, c_2 \text{ - local maxima} \\ s_1, s_2 \text{ - local sums}$$
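Written out with these symbols (this formula is a reconstruction for illustration, not copied from the original text), splitting the input $x$ into two parts $x_1, x_2$ and recombining the local results gives:

$$c = \max(c_1, c_2), \qquad \mathrm{softmax}(x) = \frac{\left[\, e^{c_1 - c}\, s_1\, z_1,\;\; e^{c_2 - c}\, s_2\, z_2 \,\right]}{e^{c_1 - c}\, s_1 + e^{c_2 - c}\, s_2}$$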
This way we can calculate parts of the softmax and later readjust and recombine the values into the final result. There are two other tricks we can use. Since we're going to divide by the global sum anyway, we can skip dividing by the local sums (and the corresponding multiplication by them during readjustment) and keep the intermediate 'softmax' values undivided. Additionally, since readjustment is just multiplication by a constant, it commutes with the multiplication by V, so we can move the softmax readjustment after the multiplication by V in the attention calculation.
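As a hedged reconstruction of the facts referenced above, moving the readjustment past the multiplication by V relies only on scalar multiplication commuting with the matrix product:

$$a \,(z_i\, V_i) \;=\; (a\, z_i)\, V_i \quad\text{for any scalar } a$$

so each partial product $z_i V_i$ can be formed first, scaled by its readjustment constant afterwards, and the single division by the global sum applied at the very end.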
Causal attention is used to calculate attention values between the currently computed Q, K and V. Since this data has been recently calculated, we don't need to fetch it from the kv-cache. Prompt lengths are usually much longer than max_num_seqs, which means that in practice we don't need to distinguish which tokens belong to prompts and which to decodes; we can use the whole Q and rely on the attn bias to mask out unnecessary tokens. Since we're using all query tokens one after another, it works similarly to the merged prefill feature. Here's an example of how the computed causal bias might look:
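The original example is not reproduced here; as a stand-in, here is a small hypothetical sketch of how such a bias could be built for a toy case of two sequences packed one after another (3 and 2 query tokens). The shapes and layout are illustrative only, not taken from the implementation.

```python
import torch

# Each position may only attend to earlier positions within its own
# sequence; everything else is masked out with -inf in the bias.
seq_lens = [3, 2]
total = sum(seq_lens)
bias = torch.full((total, total), float('-inf'))
offset = 0
for n in seq_lens:
    block = torch.zeros(n, n).masked_fill(
        torch.ones(n, n, dtype=torch.bool).triu(diagonal=1), float('-inf'))
    bias[offset:offset + n, offset:offset + n] = block
    offset += n
print(bias)
# tensor([[0., -inf, -inf, -inf, -inf],
#         [0., 0., -inf, -inf, -inf],
#         [0., 0., 0., -inf, -inf],
#         [-inf, -inf, -inf, 0., -inf],
#         [-inf, -inf, -inf, 0., 0.]])
```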
One optimization used here is that we can divide the query into equal slices that attend to different lengths of the key. This way we can skip the parts of the computation where index(key) > index(query). In the current implementation the slice size is constant and set to 512, based on experimental results.
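As a rough, hypothetical sketch of this idea (the sequence length, head size and slice size here are illustrative, not taken from the implementation): each query slice only multiplies against the keys up to the end of that slice, so the region where index(key) > index(query) is never computed.

```python
import torch

slice_size = 512
seq_len, d = 2048, 64
q, k, v = torch.randn(seq_len, d), torch.randn(seq_len, d), torch.randn(seq_len, d)

outputs = []
for start in range(0, seq_len, slice_size):
    end = start + slice_size
    q_slice = q[start:end]                  # [slice_size, d]
    k_slice, v_slice = k[:end], v[:end]     # only keys up to `end`
    scores = q_slice @ k_slice.T
    # Causal mask inside the slice: a query at absolute position i may
    # only attend to keys at positions <= i.
    q_pos = torch.arange(start, end).unsqueeze(1)
    k_pos = torch.arange(end).unsqueeze(0)
    scores = scores.masked_fill(k_pos > q_pos, float('-inf'))
    outputs.append(torch.softmax(scores, dim=-1) @ v_slice)

out = torch.cat(outputs)

# Reference: full causal attention over the whole sequence.
mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
ref = torch.softmax((q @ k.T).masked_fill(mask, float('-inf')), dim=-1) @ v
assert torch.allclose(out, ref, atol=1e-4)
```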