Positional Encoding

Positional encoding is used in transformers to give the model some notion of the position of the tokens in a sequence. This is done by adding a positional encoding vector to the input embedding vector. In an RNN, the position of a token is implicitly encoded in the hidden state, because RNNs process the sequence one token at a time. Transformers are not sequential models, so we need to add some notion of position to the input embedding vectors.

So how can we add this notion of position? How about just adding an integer that gets incremented for each token in the sequence?

embedding = [0.1, 0.2, 0.3, 0.4]
positional_encoding = [0, 1, 2, 3]
embedding + positional_encoding = [0.1, 1.2, 2.3, 3.4]

An issue with this is that sequences can be very long, so the added values become large compared to the embedding values, which hurts training and the gradients. Perhaps the integer could be normalized to a value between 0 and 1 instead, but then the encoding depends on the sequence length: the same position gets a different value in sequences of different lengths.
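A quick illustration of that problem (a throwaway snippet, not one of the scripts in this repo): the same token position gets a different normalized value depending on how long the sequence happens to be.

def normalized_position(pos, seq_len):
    # Scale the raw position into the range [0, 1].
    return pos / (seq_len - 1)

print(normalized_position(2, 4))   # ≈ 0.667 for the 3rd token in a 4-token sequence
print(normalized_position(2, 10))  # ≈ 0.222 for the 3rd token in a 10-token sequence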

So instead of adding a single number to the embedding vector, we can add a whole vector that is calculated using a formula. This vector is called the positional encoding vector, and it is added element-wise to the input embedding vector.

The positional encoding matrix is calculated using the following formula:

PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))

Where pos is the position of the token in the sequence. The i variable confused me for a while as I had read that it was the index into the token sequence. But it is not: it indexes the columns of the positional encoding matrix, with column 2i getting the sine value and column 2i+1 getting the cosine value. For example, if we have four tokens and an embedding size (d_model) of four, then the positional encoding will be a matrix with four rows (one per position) and four columns (one per embedding dimension):

d_model = 4

           i=0 (sin)   i=0 (cos)   i=1 (sin)   i=1 (cos)
pos  0   [ 0.          1.          0.          1.        ]
pos  1   [ 0.84147098  0.54030231  0.00999983  0.99995   ]
pos  2   [ 0.90929743 -0.41614684  0.01999867  0.99980001]
pos  3   [ 0.14112001 -0.9899925   0.0299955   0.99955003]

pos₀₀ = sin(0 / 10000^((2*0)/4)) = sin(0 / 10000^(0/4)) = sin(0 / 10000^0)   = sin(0/1)   = sin(0)    = 0
pos₀₁ = cos(0 / 10000^((2*0)/4)) = cos(0 / 10000^(0/4)) = cos(0 / 10000^0)   = cos(0/1)   = cos(0)    = 1
pos₀₂ = sin(0 / 10000^((2*1)/4)) = sin(0 / 10000^(2/4)) = sin(0 / 10000^0.5) = sin(0/100) = sin(0)    = 0
pos₀₃ = cos(0 / 10000^((2*1)/4)) = cos(0 / 10000^(2/4)) = cos(0 / 10000^0.5) = cos(0/100) = cos(0)    = 1

pos₁₀ = sin(1 / 10000^((2*0)/4)) = sin(1 / 10000^(0/4)) = sin(1 / 10000^0)   = sin(1/1)   = sin(1)    = 0.84147098
pos₁₁ = cos(1 / 10000^((2*0)/4)) = cos(1 / 10000^(0/4)) = cos(1 / 10000^0)   = cos(1/1)   = cos(1)    = 0.54030231
pos₁₂ = sin(1 / 10000^((2*1)/4)) = sin(1 / 10000^(2/4)) = sin(1 / 10000^0.5) = sin(1/100) = sin(0.01) = 0.00999983
pos₁₃ = cos(1 / 10000^((2*1)/4)) = cos(1 / 10000^(2/4)) = cos(1 / 10000^0.5) = cos(1/100) = cos(0.01) = 0.99995

pos₂₀ = sin(2 / 10000^((2*0)/4)) = sin(2 / 10000^(0/4)) = sin(2 / 10000^0)   = sin(2/1)   = sin(2)    = 0.90929743
pos₂₁ = cos(2 / 10000^((2*0)/4)) = cos(2 / 10000^(0/4)) = cos(2 / 10000^0)   = cos(2/1)   = cos(2)    = -0.41614684
pos₂₂ = sin(2 / 10000^((2*1)/4)) = sin(2 / 10000^(2/4)) = sin(2 / 10000^0.5) = sin(2/100) = sin(0.02) = 0.01999867
pos₂₃ = cos(2 / 10000^((2*1)/4)) = cos(2 / 10000^(2/4)) = cos(2 / 10000^0.5) = cos(2/100) = cos(0.02) = 0.99980001

pos₃₀ = sin(3 / 10000^((2*0)/4)) = sin(3 / 10000^(0/4)) = sin(3 / 10000^0)   = sin(3/1)   = sin(3)    = 0.14112001
pos₃₁ = cos(3 / 10000^((2*0)/4)) = cos(3 / 10000^(0/4)) = cos(3 / 10000^0)   = cos(3/1)   = cos(3)    = -0.9899925
pos₃₂ = sin(3 / 10000^((2*1)/4)) = sin(3 / 10000^(2/4)) = sin(3 / 10000^0.5) = sin(3/100) = sin(0.03) = 0.0299955
pos₃₃ = cos(3 / 10000^((2*1)/4)) = cos(3 / 10000^(2/4)) = cos(3 / 10000^0.5) = cos(3/100) = cos(0.03) = 0.99955003
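A minimal numpy sketch of the same computation (my own, separate from the positional-encoding.py example mentioned below) that reproduces the matrix above:

import numpy as np

def positional_encoding(seq_len, d_model):
    # pe[pos, 2i]   = sin(pos / 10000^(2i/d_model))
    # pe[pos, 2i+1] = cos(pos / 10000^(2i/d_model))
    pe = np.zeros((seq_len, d_model))
    for pos in range(seq_len):
        for i in range(d_model // 2):
            angle = pos / 10000 ** (2 * i / d_model)
            pe[pos, 2 * i] = np.sin(angle)
            pe[pos, 2 * i + 1] = np.cos(angle)
    return pe

print(positional_encoding(seq_len=4, d_model=4))
# [[ 0.          1.          0.          1.        ]
#  [ 0.84147098  0.54030231  0.00999983  0.99995   ]
#  [ 0.90929743 -0.41614684  0.01999867  0.99980001]
#  [ 0.14112001 -0.9899925   0.0299955   0.99955003]]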

So I can understand the formula, and there is an example in positional-encoding.py, but I'm not sure about the intuition behind it. If we visualize the sine and cosine waves stacked upon each other, there is one wave (either sine or cosine) for each dimension of the encoding vector, and each token reads its values off these waves at its position:

Sine/Cosine Waves

(The image was generated by src/positional-encoding-waves.py)

If we look at the waves above and think of the first (bottom-most) graph, the y value for x=0 represents the first positional encoding value for the first token in the sequence (which is sin(0)=0):

Input sentence: "Dan loves icecream"

+---+ +---+ +---+     +---+ +---+ +---+     +---+ +---+ +---+
| 0 | |   | |   |     |   | |   | |   |     |   | |   | |   |
+---+ +---+ +---+     +---+ +---+ +---+     +---+ +---+ +---+
  0     1     2         0     1     2         0     1     2
Token 1               Token 2               Token 3

  D     l     i         D     l     i         D     l     i
  a     o     c         a     o     c         a     o     c
  n     v     e         n     v     e         n     v     e
        e     c               e     c               e     c
        s     r               s     r               s     r
              e                     e                     e
              a                     a                     a
              m                     m                     m

Position 1 on that same first sine wave provides the first positional encoding value of the second token (which is sin(1)=0.84147098). Likewise, the third token's first value will be sin(2)=0.90929743. So all tokens get their first encoding value from the first sine wave.

The second positional encoding value is taken from the cosine wave above it, also at position 0, which gives cos(0)=1:

+---+ +---+ +---+     +---+ +---+ +---+     +---+ +---+ +---+
| 0 | | 1 | |   |     |   | |   | |   |     |   | |   | |   |
+---+ +---+ +---+     +---+ +---+ +---+     +---+ +---+ +---+
  0     1     2         0     1     2         0     1     2
Token 1               Token 2               Token 3

  D     l     i         D     l     i         D     l     i
  a     o     c         a     o     c         a     o     c
  n     v     e         n     v     e         n     v     e
        e     c               e     c               e     c
        s     r               s     r               s     r
              e                     e                     e
              a                     a                     a
              m                     m                     m

Now, if we had another sine wave above the cosine wave, it would have a lower frequency (a longer wavelength). The wave would be more stretched out, so the positional encoding values taken from it would change more slowly from one position to the next.
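For our d_model = 4 example, that slower wave is sin(pos/100), which feeds the third column of the matrix above. A quick check (plain Python, nothing repo-specific) shows how little it moves per position:

import math

# Third column of the 4x4 matrix above: the stretched-out wave sin(pos / 100).
for pos in range(4):
    print(pos, math.sin(pos / 100))
# 0  0.0
# 1  ≈ 0.00999983
# 2  ≈ 0.01999867
# 3  ≈ 0.0299955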

If we turn our attention to the second entry in our token sequence, Token 2 above, it will also get its first positional encoding value from the sine wave at the bottom, but this time it will get the y value for x=1 (so that will be sin(1)=0.84147):

+---+ +---+ +---+     +-------+ +-------+ +---+     +---+ +---+ +---+
| 0 | | 1 | |   |     |0.84147| |0.54030| |   |     |   | |   | |   |
+---+ +---+ +---+     +-------+ +-------+ +---+     +---+ +---+ +---+
  0     1     2         0     1     2               0     1     2
Token 1               Token 2                       Token 3

  D     l     i         D     l     i               D     l     i
  a     o     c         a     o     c               a     o     c
  n     v     e         n     v     e               n     v     e
        e     c               e     c                     e     c
        s     r               s     r                     s     r
              e                     e                           e
              a                     a                           a
              m                     m                           m

For token 2's second positional encoding value, it will get that value from the cosine wave above, but at position 1, which gives cos(1)=0.54030231.
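The same walkthrough in code, assuming the positional_encoding helper sketched earlier: row pos holds the encoding for the token at position pos, even columns come from sine waves and odd columns from cosine waves.

pe = positional_encoding(seq_len=3, d_model=4)   # "Dan loves icecream" -> 3 tokens

# Token 1 (pos=0): sine and cosine waves read off at x=0.
print(pe[0, 0], pe[0, 1])   # 0.0  1.0
# Token 2 (pos=1): same waves, read off at x=1.
print(pe[1, 0], pe[1, 1])   # ≈ 0.84147098  0.54030231
# Token 3 (pos=2): read off at x=2.
print(pe[2, 0], pe[2, 1])   # ≈ 0.90929743  -0.41614684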

If we imagine this in binary, we are trying to do something similar to this:

pos  0   [0000]
pos  1   [0001]
pos  2   [0010]
pos  3   [0011]
pos  4   [0100]
pos  5   [0101]
pos  6   [0110]
pos  7   [0111]
pos  8   [1000]
pos  9   [1001]
pos 10   [1010]
pos 11   [1011]
pos 12   [1100]
pos 13   [1101]
pos 14   [1110]
pos 15   [1111]

One thing to notice is that the columns also have a frequency, that is, they repeat. The last (least significant) bit column repeats every 2 entries, the next column repeats every 4 entries, the next every 8, and so on. One problem with plain position codes like this is that there is no straightforward way to work with relative positions: the code 100 does not, by itself, tell the model how far it is from 101. With sine and cosine waves we are able to represent relative positions, because the encoding at position pos+k can be expressed as a linear function (a rotation) of the encoding at position pos.

The sine and cosine waves give us a continuous version of this counting, and the encoding stays compact: for a sequence of 16 tokens, 4 sine/cosine dimensions (like the 4 bits above) are enough to give every position a distinct encoding, rather than needing, say, a 16-dimensional one-hot vector per position.
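A tiny sketch of the analogy (plain Python, illustrative only): each bit column flips with its own period, the same way each sine/cosine dimension oscillates at its own frequency.

n_bits = 4
for pos in range(16):
    # Most significant bit first, matching the table above.
    bits = [(pos >> b) & 1 for b in reversed(range(n_bits))]
    print(f"pos {pos:2d}  {bits}")

for b in range(n_bits):
    print(f"bit {b} repeats every {2 ** (b + 1)} positions")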


           i=0 (sin)   i=0 (cos)   i=1 (sin)   i=1 (cos)
pos  0   [ 0.          1.          0.          1.        ]
pos  1   [ 0.84147098  0.54030231  0.00999983  0.99995   ]
pos  2   [ 0.90929743 -0.41614684  0.01999867  0.99980001]
pos  3   [ 0.14112001 -0.9899925   0.0299955   0.99955003]

We start with a sine function, which is 2π periodic.

The denominator 10000^(2i/d_model) grows as i increases, so the frequency of the wave decreases as we go from lower to higher dimensions. The wave gets more and more stretched out as we go from lower to higher dimensions.
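A quick check of how the denominator, and with it the wavelength 2π · 10000^(2i/d_model), grows with i (my own numbers; d_model = 512 is just an illustrative choice, not tied to the 4-dimensional example above):

import math

d_model = 512
for i in [0, 1, 64, 128, 255]:
    denominator = 10000 ** (2 * i / d_model)
    wavelength = 2 * math.pi * denominator
    print(f"i={i:3d}  denominator={denominator:12.2f}  wavelength={wavelength:14.2f}")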

Evolution of Positional Encoding strategies

  • Original absolute positional encoding: difficult to have the model generalize to sequences of lengths different from those seen during training.

  • Relative positional encoding: the shortcoming here is that the relative positions change each time a new token is added to the sequence, which makes implementing a KV-cache difficult.

  • Learned positional encoding

  • RoPE

RoPE encoding:

[cos(nθ₀), sin(nθ₀), cos(nθ₁), sin(nθ₁), ..., cos(nθ_(d/2-1)), sin(nθ_(d/2-1))]

n   = token position
d   = embedding dimension
θ_i = 10000^(-2i/d)         // rotation frequencies, for i = 0 .. d/2-1

Let's look at a few values of θ_i for different i (the dimension pair index, not the token position), using d = 2048:

i = 0     10000^(-2*0/2048)    = 1.0
i = 1     10000^(-2*1/2048)    ≈ 0.99105
i = 2     10000^(-2*2/2048)    ≈ 0.98217
...
i = 1022  10000^(-2*1022/2048) ≈ 0.00010182
i = 1023  10000^(-2*1023/2048) ≈ 0.00010090

With lower values of i, θ_i is closer to 1, which corresponds to a higher frequency: the angle n·θ_i changes quickly as the position n increases (shorter wavelength, more oscillations). With higher values of i, θ_i is closer to 0, which corresponds to a lower frequency: the angle changes only slowly with the position (longer wavelength, fewer oscillations).
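A minimal sketch of how these θ_i values are used (my own illustration; in practice RoPE is applied to the query and key vectors inside attention, and implementations differ in how they pair up dimensions): each consecutive pair of dimensions is rotated by the angle n·θ_i.

import numpy as np

def rope(x, n, base=10000):
    # Rotate each consecutive pair (x[2i], x[2i+1]) by the angle n * theta_i.
    d = x.shape[0]
    out = np.empty_like(x, dtype=np.float64)
    for i in range(d // 2):
        theta = base ** (-2 * i / d)          # rotation frequency for this pair
        cos, sin = np.cos(n * theta), np.sin(n * theta)
        x0, x1 = x[2 * i], x[2 * i + 1]
        out[2 * i]     = x0 * cos - x1 * sin
        out[2 * i + 1] = x0 * sin + x1 * cos
    return out

x = np.array([0.1, 0.2, 0.3, 0.4])
print(rope(x, n=0))   # position 0: no rotation, x comes back unchanged
print(rope(x, n=5))   # position 5: each pair rotated by 5 * theta_i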

Extrapolating this to a context window larger than the one the model was trained with is what the following RoPE variants address.

  • RoPE with Frequency Scaling
[cos(n * freq_scale * θ₀), sin(n * freq_scale * θ₀), cos(n * freq_scale * θ₁), sin(n * freq_scale * θ₁), ...]
  • RoPE with Position Interpolation (PI): to handle the extended context window, we need to adjust (or interpolate) the positional encodings. This involves rescaling the rotation frequencies to fit the new context length.

L is the context window size, which is the number of tokens the model can process at once. If we need or want to extend the context window, we call the new size L prime (L'). The extension ratio is defined as:

s = L'/L

s = extension ratio (L'/L)
β = 10000^(2/d)        // the base raised to 2/d, so that θ_i = β^(-i)
λ = s                  // scaling factor

[cos(n/(λβ⁰)), sin(n/(λβ⁰)), cos(n/(λβ¹)), sin(n/(λβ¹)), ...]

n = token position

So the scaling factor λ = s is the same for every dimension: each angle is divided by s, which squeezes the L' positions back into the range the model saw during training.
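A sketch of what Position Interpolation does to the rotation angles (reusing the rope helper sketched above; the numbers are illustrative): every position is effectively divided by the same factor s.

L, L_prime = 2048, 4096
s = L_prime / L                 # extension ratio (lambda = s for every dimension)

def rope_with_pi(x, n):
    # Position Interpolation: rotate as if the token sat at position n / s.
    return rope(x, n / s)

x = np.array([0.1, 0.2, 0.3, 0.4])
print(rope_with_pi(x, n=4095))  # behaves like position 2047.5 under plain RoPE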

  • NTK Positional Encoding (Neural Tangent Kernel): splits the dimensions into a lower and a higher range and uses a different scaling factor for each, so the scaling factor λ depends on where in the dimension range we are, i.e. λ_i. The high-frequency (lower i) dimensions are scaled less, staying close to plain extrapolation, while the low-frequency (higher i) dimensions are scaled more, closer to full interpolation.

  • YaRN (Yet Another RoPE extensioN method): an extension of NTK, which (recall) has two sections of dimensions that are scaled differently. With YaRN there are three sections (see the sketch after this list):

Low    frequencies: Position Interpolation (λ = s)
Middle frequencies: an NTK-style blend (scaled somewhere between the two, depending on the dimension)
High   frequencies: Extrapolation, λ = 1
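A rough sketch of that idea (the exact ramp and the frequency cut-offs come from the YaRN paper; the cut-off values below are placeholders I picked for illustration, not the paper's): each dimension gets its own scaling factor depending on which frequency band its θ_i falls into.

def yarn_like_scale(theta_i, s, low_cutoff=1e-3, high_cutoff=1e-1):
    # Returns the per-dimension factor lambda_i that the angle n * theta_i is divided by.
    if theta_i >= high_cutoff:      # high frequencies: extrapolate, no scaling
        return 1.0
    if theta_i <= low_cutoff:       # low frequencies: full position interpolation
        return s
    # middle band: ramp linearly between extrapolation and interpolation
    t = (high_cutoff - theta_i) / (high_cutoff - low_cutoff)
    return 1.0 + t * (s - 1.0)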

Since each position has a fixed encoding in the original absolute scheme, positions beyond the training range would have encodings that the model has never seen, making it difficult for the model to interpret those positions accurately. This is the problem that RoPE, together with the extensions above, tries to address.