Feat ring flash attn #1305
Conversation
Force-pushed from 6c64e0b to c749b20.
def substitute_prime_rl_flash_attn(process_group: torch.distributed.ProcessGroup, heads_k_stride: int) -> None:
    from ring_flash_attn import llama3_flash_attn_varlen_func

    global ATTN_IMPL2CLASS

    class RingFlashAttention(FlashAttention):
        def forward(
            self,
            hidden_states: torch.Tensor,
            position_embeddings: tuple[torch.Tensor, torch.Tensor],
            cu_seqlens: torch.LongTensor | None = None,
            max_seqlen: int | None = None,
        ) -> tuple[torch.Tensor, torch.Tensor | None]:
            input_shape = hidden_states.shape[:-1]
            hidden_shape = (*input_shape, -1, self.head_dim)

            query_states = self.q_proj(hidden_states).view(hidden_shape)
            key_states = self.k_proj(hidden_states).view(hidden_shape)
            value_states = self.v_proj(hidden_states).view(hidden_shape)

            if self.use_qk_norm:  # main diff from Llama
                query_states = self.q_norm(query_states)
                key_states = self.k_norm(key_states)

            query_states = query_states.transpose(1, 2)
            key_states = key_states.transpose(1, 2)
            value_states = value_states.transpose(1, 2)

            cos, sin = position_embeddings
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

            # Per-step varlen metadata prepared by the ring_flash_attn HF adapter.
            from ring_flash_attn.adapters.hf_adapter import DATA_PARAMS

            cu_seqlens_q = DATA_PARAMS["cu_seqlens_q"]
            cu_seqlens_k = DATA_PARAMS["cu_seqlens_k"]
            max_seqlen_q = DATA_PARAMS["max_seqlen_q"]
            max_seqlen_k = DATA_PARAMS["max_seqlen_k"]
            local_k_slice = DATA_PARAMS["local_k_slice"]

            # TODO: Can we optimize the rotary application instead of double transpose?
            query_states = query_states.transpose(1, 2)
            key_states = key_states.transpose(1, 2)
            value_states = value_states.transpose(1, 2)

            # Varlen flash attention across the context-parallel process group.
            out = llama3_flash_attn_varlen_func(
                query_states[0],
                key_states[0],
                value_states[0],
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_k=max_seqlen_k,
                local_k_slice=local_k_slice,
                causal=True,
                group=process_group,
                heads_k_stride=heads_k_stride,
            )
            out = out.contiguous()
            attn_output = out.view(1, out.shape[0], -1)
            attn_weights = None

            # attn_output = attn_output.reshape(*input_shape, -1).contiguous()
            attn_output = self.o_proj(attn_output)
            return attn_output, attn_weights

    # Patch the registered flash_attention_2 impl to use the ring forward.
    ATTN_IMPL2CLASS["flash_attention_2"].forward = RingFlashAttention.forward
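For illustration, a minimal sketch of how the substitution might be wired up before the model is built. The process-group construction and the heads_k_stride value are assumptions for the example, not taken from the PR:

import torch.distributed as dist

# Assumed setup: one context-parallel group spanning all ranks, patched in before
# the model is instantiated so that flash_attention_2 resolves to the ring forward.
dist.init_process_group(backend="nccl")
cp_group = dist.new_group(ranks=list(range(dist.get_world_size())))

substitute_prime_rl_flash_attn(process_group=cp_group, heads_k_stride=1)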
Not the biggest fan of this pattern; it smells like transformers magic.
What about just having flash_attention_2_cp in attn.impl? Then setting parallel.cp to 2 would automatically switch the attention to it.
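Roughly, that suggestion amounts to registering the ring variant as its own implementation and resolving it from config. A minimal sketch with a stand-in registry; the names flash_attention_2_cp, register_attn_impl, and resolve_attn_impl are hypothetical, not prime-rl's actual API:

ATTN_IMPL2CLASS: dict[str, type] = {}  # stand-in for the real attention registry


def register_attn_impl(name: str):
    # Register an attention class under an impl name so config can select it.
    def _register(cls: type) -> type:
        ATTN_IMPL2CLASS[name] = cls
        return cls

    return _register


@register_attn_impl("flash_attention_2_cp")
class RingFlashAttentionCP:
    """Would subclass FlashAttention and carry the ring forward from above."""


def resolve_attn_impl(attn_impl: str, cp_world_size: int) -> str:
    # Setting parallel.cp > 1 transparently upgrades flash_attention_2 to the CP-aware impl.
    if cp_world_size > 1 and attn_impl == "flash_attention_2":
        return "flash_attention_2_cp"
    return attn_impl


assert resolve_attn_impl("flash_attention_2", cp_world_size=2) == "flash_attention_2_cp"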
input_ids = maybe_shard_for_cp(
    micro_batch["input_ids"].to("cuda"), cp_rank=cp_rank, cp_world_size=config.model.cp
)
advantages = maybe_shard_for_cp(
    micro_batch["advantages"].to("cuda"), cp_rank=cp_rank, cp_world_size=config.model.cp
)
loss_mask = maybe_shard_for_cp(
    micro_batch["loss_mask"].to("cuda"), cp_rank=cp_rank, cp_world_size=config.model.cp
)
inference_logprobs = maybe_shard_for_cp(
    micro_batch["inference_logprobs"].to("cuda"), cp_rank=cp_rank, cp_world_size=config.model.cp
)
I would rather do something like this instead of the maybe_shard_for_cp calls:

    if config.model.cp > 1:
        shard_for_cp(...)
Not a fan of the maybe_ pattern.
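For context, a minimal sketch of the sequence sharding being discussed. The signature mirrors the maybe_shard_for_cp call sites above, but the sequence-dim slicing is an assumption about what the helper does, not the PR's actual implementation:

import torch


def shard_for_cp(tensor: torch.Tensor, cp_rank: int, cp_world_size: int) -> torch.Tensor:
    # Keep this context-parallel rank's contiguous slice of the sequence dimension (dim 1).
    seq_len = tensor.shape[1]
    assert seq_len % cp_world_size == 0, "sequence length must divide evenly across CP ranks"
    chunk = seq_len // cp_world_size
    return tensor[:, cp_rank * chunk : (cp_rank + 1) * chunk]


# Toy check: a [1, 8] sequence split across 4 CP ranks gives each rank 2 tokens.
batch = torch.arange(8).unsqueeze(0)
print(shard_for_cp(batch, cp_rank=1, cp_world_size=4))  # tensor([[2, 3]])

Under the reviewer's version, the explicit if config.model.cp > 1: guard at the call site replaces the maybe_ wrapper, so the unsharded path stays visibly untouched.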
I was too lazy to explain my changes, decreasing my chances of this ever getting merged.
GitHub Issue: [Issue ID]
Linear Issue: Resolves [Issue ID]