Commit d969f38

fix for PR #275
Signed-off-by: PatrykWo <[email protected]>
1 parent 88aa7f9 commit d969f38

File tree

1 file changed (+6 −6 lines)


docs/features/unified_attn.md

Lines changed: 6 additions & 6 deletions
@@ -22,16 +22,16 @@ To get the main idea behind the algorithm, let's work on a concrete example. Ass
 * we're using scaled dot product attention:
 $$\text{Attention}(Q, K, V, B) = \text{softmax}\left( s \cdot QK^\top + B \right) V$$
 
-![](../../docs/assets/unified_attn/block_table.png)
+![](../assets/unified_attn/block_table.png)
 
 We can observe two things:
 
 1. some of the blocks are only used by a single token, and some are shared
-1. some of the key values have just been calculated and are available alongside the queries, so they don't need to be fetched from the cache
+2. some of the key values have just been calculated and are available alongside the queries, so they don't need to be fetched from the cache
 
 In a naive implementation we would simply multiply the whole query by the key and value and use an appropriate bias to mask out unused fields, but that would be very inefficient, especially for decodes, where we usually have only a single token per sample in a batch and there is almost no overlap between the blocks used. We could slice the query and key into chunks and multiply only the regions that hold relevant data, but that is currently difficult to achieve for technical reasons. Instead, we can divide the work into 3 separate parts and merge the results at the end.
 
-![](../../docs/assets/unified_attn/block_table_annotated.png)
+![](../assets/unified_attn/block_table_annotated.png)
 
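For contrast, here is a tiny NumPy sketch of the naive approach described above: multiply the whole Q by the whole K and V, and rely on the bias B to mask unused fields with -inf. The shapes and the masking pattern are illustrative only, not the plugin's actual code.

```python
import numpy as np

def naive_attention(q, k, v, bias, scale):
    # softmax(s * Q K^T + B) V, computed over the full score matrix
    scores = scale * (q @ k.T) + bias
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)
    return probs @ v

T_q, T_kv, D = 4, 8, 16
rng = np.random.default_rng(0)
q = rng.normal(size=(T_q, D))
k = rng.normal(size=(T_kv, D))
v = rng.normal(size=(T_kv, D))
bias = np.zeros((T_q, T_kv))
bias[:, 6:] = -np.inf  # fields no query uses, masked only via the bias
out = naive_attention(q, k, v, bias, scale=1.0 / np.sqrt(D))
```

Every score is computed and then thrown away by the mask, which is exactly the waste the three-way split avoids.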
 ## Splitting softmax

@@ -40,7 +40,7 @@ $$\text{softmax}(x_i) = \frac{e^{x_i-c}}{\sum_{j} e^{x_j-c}}, c = max(x_i)$$
 The problem here lies in the denominator, as it contains the sum of all terms. Fortunately, we can split the calculation into two separate softmaxes, then readjust and combine the results. Let's say we have:
 $$z_1, z_2\text{ - local softmax results} \\ c_1, c_2 \text{ - local maxima} \\ s_1, s_2 \text{ - local sums}$$
 We can then calculate:
-$$c = max(c_1, c_2) \\ adj_i = e^{c_i-c} \\ s = s_1 * adj_1 + s_2 * adj_2\\ z_i\prime = \frac{z_i*s_i*adj_i}{s} $$
+$$c = max(c_1, c_2) \\ adj_i = e^{c_i-c} \\ s = s_1 *adj_1 + s_2* adj_2\\ z_i\prime = \frac{z_i*s_i*adj_i}{s} $$
 
 This way we can calculate parts of the softmax and later readjust and recombine the values into the final result. There are two more tricks we can use. Since we're going to divide by the global sum anyway, we can skip dividing by the local sums (and then multiplying by them again during readjustment) and keep the intermediate 'softmax' values unnormalized. Additionally, since readjustment is multiplication by a constant, we can use the facts that:
 $$(sA)B=s(AB) \\ [A; B; C+D] \times [A; C+D; E] = [A; B; C] \times [A; C; E] + [A; B; D] \times [A; D; E] = [A; B; E]$$
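To make the merge concrete, here is a minimal NumPy sketch (an illustration, not the plugin's code) that splits a vector in two, computes the local softmax statistics, and recombines them with the formulas above:

```python
import numpy as np

def local_softmax(x):
    c = x.max()          # local maximum
    z = np.exp(x - c)    # shifted exponentials
    s = z.sum()          # local sum
    return z / s, c, s   # local softmax result, maximum, sum

def merge(z1, c1, s1, z2, c2, s2):
    c = max(c1, c2)                      # global maximum
    adj1, adj2 = np.exp(c1 - c), np.exp(c2 - c)
    s = s1 * adj1 + s2 * adj2            # global sum
    # z_i' = z_i * s_i * adj_i / s
    return np.concatenate((z1 * s1 * adj1, z2 * s2 * adj2)) / s

x = np.random.default_rng(0).normal(size=8)
merged = merge(*local_softmax(x[:4]), *local_softmax(x[4:]))
full = np.exp(x - x.max()) / np.exp(x - x.max()).sum()
assert np.allclose(merged, full)
```

The 'skip the local division' trick amounts to returning z instead of z / s from local_softmax and dropping the * s_i factors in merge, since they cancel out.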
@@ -50,11 +50,11 @@ and move softmax readjustment after multiplication by V in attention calculation
 
 Causal attention is used to calculate attention values between the currently computed Q, K and V. Since this data has just been calculated, we don't need to fetch it from the kv-cache. Prompt lengths are usually much longer than max_num_seqs. This means that, in practice, we don't need to distinguish which tokens belong to prompts and which to decodes, and can use the whole Q, relying on the attention bias to mask out unnecessary tokens. Since we're using all query tokens one after another, this works similarly to the merged prefill feature. Here's an example of how the computed causal bias might look:
 
-![](../../docs/assets/unified_attn/causal.png)
+![](../assets/unified_attn/causal.png)
 
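For illustration, a small NumPy sketch (assumed shapes, not the plugin's actual code) of such a bias for queries packed one sequence after another: inside each sequence the bias is lower-triangular, and everything across sequences is masked.

```python
import numpy as np

def packed_causal_bias(seq_lens):
    """Causal bias for sequences packed back to back along one axis."""
    total = sum(seq_lens)
    bias = np.full((total, total), -np.inf)
    start = 0
    for n in seq_lens:
        idx = np.arange(n)
        # within a sequence, token i may attend to tokens 0..i
        bias[start:start + n, start:start + n] = np.where(
            idx[None, :] <= idx[:, None], 0.0, -np.inf
        )
        start += n
    return bias

print(packed_causal_bias([3, 2]))  # two packed sequences of lengths 3 and 2
```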
 One optimization used here is that we can divide the query into equal slices that use different lengths of the key:
 
-![](../../docs/assets/unified_attn/causal_sliced.png)
+![](../assets/unified_attn/causal_sliced.png)
 
 This way we can skip the parts of the computation where index(key) > index(query). In the current implementation the slice size is constant and is set to 512 based on experimental results.
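A sketch of this slicing (NumPy, with a toy slice size instead of 512; illustrative only): each query slice multiplies only the keys up to the end of that slice, so the region where index(key) > index(query) is never computed.

```python
import numpy as np

SLICE = 2  # toy value; the implementation uses 512

def sliced_causal_scores(q, k, scale):
    T = q.shape[0]
    scores = np.full((T, T), -np.inf)
    for start in range(0, T, SLICE):
        end = min(start + SLICE, T)
        # this query slice never looks past key index `end`
        s = scale * (q[start:end] @ k[:end].T)
        rows = np.arange(start, end)[:, None]
        cols = np.arange(end)[None, :]
        scores[start:end, :end] = np.where(cols <= rows, s, -np.inf)
    return scores

rng = np.random.default_rng(0)
q, k = rng.normal(size=(6, 4)), rng.normal(size=(6, 4))
print(sliced_causal_scores(q, k, scale=0.5))
```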
