Conversation

S1ro1 (Collaborator) commented on Nov 17, 2025

I was too lazy to explain my changes, decreasing my chances of this ever getting merged.


GitHub Issue: [Issue ID]
Linear Issue: Resolves [Issue ID]

S1ro1 force-pushed the feat-ring-flash-attn branch from 6c64e0b to c749b20 on November 19, 2025, 14:49
Comment on lines +178 to +242
def substitute_prime_rl_flash_attn(process_group: torch.distributed.ProcessGroup, heads_k_stride: int) -> None:
from ring_flash_attn import llama3_flash_attn_varlen_func

global ATTN_IMPL2CLASS

class RingFlashAttention(FlashAttention):
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
cu_seqlens: torch.LongTensor | None = None,
max_seqlen: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)

query_states = self.q_proj(hidden_states).view(hidden_shape)
key_states = self.k_proj(hidden_states).view(hidden_shape)
value_states = self.v_proj(hidden_states).view(hidden_shape)

if self.use_qk_norm: # main diff from Llama
query_states = self.q_norm(query_states)
key_states = self.k_norm(key_states)

query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)

cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

from ring_flash_attn.adapters.hf_adapter import DATA_PARAMS
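            # DATA_PARAMS is a module-level dict that the ring_flash_attn HF adapter
            # fills with the varlen metadata for the current step (cu_seqlens_q/k,
            # max_seqlen_q/k, local_k_slice) before each forward pass.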

cu_seqlens_q = DATA_PARAMS["cu_seqlens_q"]
cu_seqlens_k = DATA_PARAMS["cu_seqlens_k"]
max_seqlen_q = DATA_PARAMS["max_seqlen_q"]
max_seqlen_k = DATA_PARAMS["max_seqlen_k"]
local_k_slice = DATA_PARAMS["local_k_slice"]

            # TODO: Can we optimize the rotary application to avoid the double transpose?
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
out = llama3_flash_attn_varlen_func(
query_states[0],
key_states[0],
value_states[0],
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
local_k_slice=local_k_slice,
causal=True,
group=process_group,
heads_k_stride=heads_k_stride,
)
out = out.contiguous()
attn_output = out.view(1, out.shape[0], -1)
attn_weights = None

# attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights

ATTN_IMPL2CLASS["flash_attention_2"].forward = RingFlashAttention.forward
Member
Not the biggest fan of this pattern, it smells like transformers magic.

What about just having flash_attention_2_cp in attn.impl? Setting parallel.cp to 2 would then automatically switch the attention to it.
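A rough sketch of what that could look like (the flash_attention_2_cp key and the resolve_attn_impl helper are illustrative assumptions based on this comment, not code from the PR):

# Hypothetical: register the ring variant as its own attention implementation
# instead of overwriting flash_attention_2's forward in place.
ATTN_IMPL2CLASS["flash_attention_2_cp"] = RingFlashAttention

def resolve_attn_impl(attn_impl: str, cp: int) -> str:
    # With parallel.cp > 1, flash_attention_2 is transparently upgraded to the
    # CP-aware variant; all other settings pass through unchanged.
    if cp > 1 and attn_impl == "flash_attention_2":
        return "flash_attention_2_cp"
    return attn_impl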

Comment on lines +244 to +255
input_ids = maybe_shard_for_cp(
micro_batch["input_ids"].to("cuda"), cp_rank=cp_rank, cp_world_size=config.model.cp
)
advantages = maybe_shard_for_cp(
micro_batch["advantages"].to("cuda"), cp_rank=cp_rank, cp_world_size=config.model.cp
)
loss_mask = maybe_shard_for_cp(
micro_batch["loss_mask"].to("cuda"), cp_rank=cp_rank, cp_world_size=config.model.cp
)
inference_logprobs = maybe_shard_for_cp(
micro_batch["inference_logprobs"].to("cuda"), cp_rank=cp_rank, cp_world_size=config.model.cp
)
Member

I would rather do

Suggested change (replacing the maybe_shard_for_cp block above):

if config.model.cp > 1:
    shard_for_cp(...)

Not a fan of the maybe pattern.
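For reference, a minimal sketch of what an explicit shard_for_cp could look like (the signature and the assumption that packed tensors are split contiguously along the sequence dimension are illustrative, not taken from this PR):

import torch

def shard_for_cp(tensor: torch.Tensor, cp_rank: int, cp_world_size: int) -> torch.Tensor:
    # Split the sequence dimension of a [batch, seq, ...] tensor into
    # cp_world_size contiguous chunks and keep the chunk owned by this CP rank.
    chunks = tensor.chunk(cp_world_size, dim=1)
    return chunks[cp_rank].contiguous()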
