if self.rope_interleaved:
    query_states = self.rotary_embedding(query_states, position_ids=position_ids)
    key_states = self.rotary_embedding(key_states, position_ids=position_ids)
else:
    cos, sin = self.rotary_embedding(value_states, position_ids)
    query_states, key_states = self.rotary_embedding.apply_rotary_pos_emb(query_states, key_states, cos, sin)

# Compute rotary embeddings
# Note: keep track of old rotary embedding end to check if we need to enlarge k_cache and v_cache
old_rotary_embed_end = self.rotary_embedding.end
# interleaved version.
if self.rope_interleaved:
    query_states = self.rotary_embedding(query_states, position_ids=position_ids)
    key_states = self.rotary_embedding(key_states, position_ids=position_ids)
# non interleaved version.
else:
    cos, sin = self.rotary_embedding(value_states, position_ids)
    query_states, key_states = self.rotary_embedding.apply_rotary_pos_emb(
        query_states, key_states, cos, sin
    )

if "key" not in store:
    # First inference iteration (Prefill)
    # TODO @nouamane: support custom masking
    # assert that [False, False, False, False, True, True, True, True, True, True] is accepted
    # but [False, False, False, False, True, True, False, False, True, True] is not (can't mask in the middle of the sequence)
    assert ~(
        sequence_mask[:, :-1] & (~sequence_mask[:, 1:])  # True is never followed by False
    ).any(), "Can't mask in the middle of sequence, please make sure that pads are at the left of the sequence if existing"

    # preallocate k_cache, v_cache to self.prefill_kv_len
    k_cache = torch.zeros(
        (
            batch_size,
            self.prefill_kv_len,
            self.n_local_kv_heads,
            self.d_qk,
        ),
        dtype=query_states.dtype,
        device=query_states.device,
    )
    v_cache = torch.zeros(
        (batch_size, self.prefill_kv_len, self.n_local_kv_heads, self.d_v),
        dtype=query_states.dtype,
        device=query_states.device,
    )
    # Remove pad tokens from key_states and concatenate samples in key_unpad
    # cu_seqlens_k is the cumulative sequence lengths of key_states
    (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(
        query_states,
        sequence_mask,
    )
    (key_unpad, indices_k, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(
        key_states, sequence_mask
    )
    (value_unpad, _, _, _) = bert_padding.unpad_input(value_states, sequence_mask)

    # NOTE: this scale is for µTransfer,
    # in SP, we use sqrt(1/d_h)
    softmax_scale = 1 / query_states.shape[-1] if self.is_using_mup else None
    output_unpad = flash_attn_varlen_func(
        q=query_unpad,  # (total_q, n_local_q_heads, d_qk)
        k=key_unpad,  # (total_kv, n_local_kv_heads, d_qk)
        v=value_unpad,  # (total_kv, n_local_kv_heads, d_v)
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        dropout_p=0.0,
        softmax_scale=softmax_scale,
        causal=True,  # True in prefill phase, False in subsequent phases
        return_attn_probs=False,
    )  # (total_unpadded, n_local_q_heads, d_v)

    attention_output = bert_padding.pad_input(
        output_unpad, indices_q, batch_size, q_length
    )  # (batch_size, q_length, n_local_q_heads, d_v)

    pad_to_right(key_states, sequence_mask, new_tensor=k_cache)
    pad_to_right(value_states, sequence_mask, new_tensor=v_cache)

else:
    # Pull pre-computed key/value states
    # Subsequent inference iterations (q_length=1)
    k_cache = store["key"]
    v_cache = store["value"]

    # NOTE(fmom): According to flash_attn_with_kvcache, "If you pass in k / v, you must make sure that the cache is large enough to hold the new values"
    # Since rotary embedding has changed (to enable larger context), we need to enlarge k_cache and v_cache
    if self.rotary_embedding.end > old_rotary_embed_end:
        k_cache = torch.cat(
            [
                k_cache,
                torch.zeros(
                    (
                        batch_size,
                        self.rotary_embedding.end - old_rotary_embed_end,
                        self.n_local_kv_heads,
                        self.d_qk,
                    ),
                    dtype=query_states.dtype,
                    device=query_states.device,
                ),
            ],
            dim=1,
        )

        v_cache = torch.cat(
            [
                v_cache,
                torch.zeros(
                    (
                        batch_size,
                        self.rotary_embedding.end - old_rotary_embed_end,
                        self.n_local_kv_heads,
                        self.d_v,
                    ),
                    dtype=query_states.dtype,
                    device=query_states.device,
                ),
            ],
            dim=1,
        )

    assert (
        k_cache.shape[1] == self.rotary_embedding.end
    ), f"Cache size {k_cache.shape[1]} is smaller than rotary embedding end {self.rotary_embedding.end}"
    assert (
        v_cache.shape[1] == self.rotary_embedding.end
    ), f"Cache size {v_cache.shape[1]} is smaller than rotary embedding end {self.rotary_embedding.end}"

    # [batch_size, seq_length, num_heads, d_qk]
    query_states = query_states.view(
        batch_size, q_length, self.n_local_q_heads, self.d_qk
    )  # [batch_size, q_length, self.n_heads, d_qk]
    kv_length = key_states.shape[1]
    key_states = key_states.view(
        batch_size, kv_length, self.n_local_kv_heads, self.d_qk
    )  # [batch_size, kv_length, self.n_heads, d_qk]
    value_states = value_states.view(
        batch_size, kv_length, self.n_local_kv_heads, self.d_v
    )  # [batch_size, kv_length, self.n_heads, d_v]

    # NOTE: this scale is for µTransfer,
    # in SP, we use sqrt(1/d_h)
    softmax_scale = 1 / query_states.shape[-1] if self.is_using_mup else None
    attention_output = flash_attn_with_kvcache(
        query_states,
        k_cache,
        v_cache,
        key_states,
        value_states,
        rotary_cos=None,
        rotary_sin=None,
        # TODO @nouamane: seems like this doesn't help to indicate padding in (for first iteration it's just 0)
        cache_seqlens=position_offsets.contiguous(),
        softmax_scale=softmax_scale,
        causal=True,
        rotary_interleaved=False,  # the value is not used unless rotary_cos/sin is provided. https://github.com/Dao-AILab/flash-attention
    )

store.update(
    {
        "key": k_cache,  # flash-attn has updated with new key_states using cache_seqlens
        "value": v_cache,
        "position_offsets": position_offsets,
    }
)
There is a bug of duplicate code with a wrong indentation level when computing attention_output in CausalSelfAttention._forward_inference; as the code stands, attention_output is never computed. See nanotron/src/nanotron/models/llama.py, lines 499 to 666 at commit c737f00 (quoted above).
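
For illustration only, here is a minimal sketch of the failure pattern being described, not the actual nanotron code; the function and variable names are made up. A duplicated block ends up nested one indentation level too deep, under a branch that never runs on the affected path, so the name it was supposed to bind is never assigned there:

# Hypothetical sketch of the pattern (illustrative names, not nanotron code).
def forward_inference(is_prefill: bool) -> str:
    if is_prefill:
        # ... prefill-only setup ...
        if not is_prefill:  # duplicated block, indented one level too deep: dead code here
            attention_output = "prefill attention"
    else:
        attention_output = "decode attention"

    # On the prefill path, attention_output was never bound.
    return attention_output


forward_inference(is_prefill=False)   # fine: returns "decode attention"
# forward_inference(is_prefill=True)  # would raise UnboundLocalError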
There is also a wrong argument passed to parametrizator_cls when init_model_randomly is used for testing. See nanotron/src/nanotron/models/llama.py, line 1095 at commit c737f00.
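
As a purely hypothetical sketch of this kind of mismatch (the class name, constructor signature, and config layout below are assumptions for illustration, not the actual nanotron API): a parametrizator that expects one config object breaks, or silently misbehaves, when handed a different one.

# Hypothetical sketch (assumed names and signature, not the nanotron API).
from dataclasses import dataclass, field


@dataclass
class ModelConfig:
    init_std: float = 0.02


@dataclass
class FullConfig:
    model: ModelConfig = field(default_factory=ModelConfig)


class Parametrizator:
    def __init__(self, config: ModelConfig):
        # Expects the model sub-config; a FullConfig has no init_std attribute.
        self.std = config.init_std


cfg = FullConfig()
Parametrizator(config=cfg.model)   # intended argument
# Parametrizator(config=cfg)       # the kind of wrong argument the report describes -> AttributeError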