
[BUG] DotProductAttention: Disabling FlashAttention as it does not support MLA #1698

@derby-ding

Description


Describe the bug
DEBUG:DotProductAttention:Available backends = {FlashAttention=False, FusedAttention=False, UnfusedDotProductAttention=False}
DEBUG:DotProductAttention:Selected backend = NoBackend
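
For reference, the per-backend rejection reasons can usually be surfaced with Transformer Engine's attention debug logging; the knobs below follow the TE documentation, though the exact verbosity level is worth double-checking against your TE version:

export NVTE_DEBUG=1
export NVTE_DEBUG_LEVEL=2   # prints why each attention backend is marked unavailable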

To Reproduce

KV_LORA_RANK=512
QK_NOPE_HEAD_DIM=128
QK_ROPE_HEAD_DIM=64
V_HEAD_DIM=128
NUM_EXPERTS=32
ROUTER_TOPK=3
NUM_SHARED_EXPERTS=1
MOE_LAYER_FREQ=1
MOE_FIRST_K_DENSE_REPLACE=2
RMS_NORM_EPS=1e-6

MOE_ARGS=(
--moe-ffn-hidden-size ${MOE_INTERMEDIATE_SIZE}
--moe-router-topk ${ROUTER_TOPK}
--num-experts ${NUM_EXPERTS}
--moe-layer-freq ${MOE_LAYER_FREQ}
--moe-aux-loss-coeff 0.001
--moe-shared-expert-intermediate-size $((${MOE_INTERMEDIATE_SIZE} * ${NUM_SHARED_EXPERTS} ))
--kv-lora-rank ${KV_LORA_RANK}
--qk-head-dim ${QK_NOPE_HEAD_DIM}
--qk-pos-emb-head-dim ${QK_ROPE_HEAD_DIM}
--v-head-dim ${V_HEAD_DIM}
--moe-grouped-gemm
)
export NVTE_FLASH_ATTN=1 NVTE_FUSED_ATTN=0
fl_options=" --attention-backend flash "
Additional flags passed to the training script:
--multi-latent-attention
--moe-router-dtype fp32
--moe-permute-fusion
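
For context, these settings give a query/key head size of QK_NOPE_HEAD_DIM + QK_ROPE_HEAD_DIM = 128 + 64 = 192 while V_HEAD_DIM = 128; the mismatched q/k vs. v head sizes are presumably what the FlashAttention check rejects here (an assumption on my part, based on flash-attn 2.x expecting equal head dimensions).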

Expected behavior
MLA should be able to run with FlashAttention; without it, the GPU hits an out-of-memory error once the sequence length exceeds 20K.
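
As a rough sanity check on the OOM claim (the head count is not part of this config, so 16 heads is only an illustrative assumption): the unfused fallback materializes a full seq x seq score matrix per head, i.e. roughly

20480^2 * 2 bytes ≈ 0.78 GiB per head per layer (assuming bf16 scores)

so with 16 heads that is ~12.5 GiB of attention scores for a single layer, before softmax buffers and the backward pass, which is consistent with exhausting 80 GB once FlashAttention is disabled.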


Environment (please complete the following information):

  • Megatron-LM commit ID
  • Container: nvcr.io/nvidia/pytorch:24.05-py3
  • A100 80 GB, CUDA 12.4, torch 2.7.1, Transformer Engine 2.4.0, flash-attn 2.4.2

