
[Cherry-pick][Optimization] enable trtllm_all_reduce fusion kernel in glm model #7228

Open

BingooYang wants to merge 17 commits into PaddlePaddle:release/2.6 from BingooYang:2.6/trtllm_allreduce

Conversation

@BingooYang Contributor

Motivation

Integrate the trtllm_allreduce_fusion operator into FastDeploy (FD).

Modifications

  1. Add the flashinfer all-reduce fusion operator integration to FD.
  2. Rework the GLM-4.5-Air model graph to hook in the trtllm_allreduce_fusion operator (disabled by default; see the sketch after this list).
  3. Add the command-line flag --enable-flashinfer-allreduce-fusion to enable trtllm_allreduce_fusion.
  4. Add a unit test for the trtllm_allreduce_fusion operator.
  5. Move the has_flashinfer() function into utils.py.
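
For orientation, here is a minimal sketch of the three steps the fused kernel collapses into one. The unfused reference uses plain Paddle ops; the fused call signature is the one quoted in the review below, and everything else is illustrative:

# Sketch of the computation replaced by trtllm_allreduce_fusion.
# In the unfused path, a RowParallelLinear output is (1) all-reduced across
# tensor-parallel ranks, (2) added to the residual, (3) RMS-normalized.
import paddle
import paddle.distributed as dist

def unfused_allreduce_residual_rmsnorm(x, residual, weight, eps):
    dist.all_reduce(x)                                   # (1) cross-rank sum
    hidden = x + residual                                # (2) residual add
    var = hidden.astype("float32").pow(2).mean(axis=-1, keepdim=True)
    norm = (hidden.astype("float32") * paddle.rsqrt(var + eps)).astype(x.dtype) * weight
    return norm, hidden                                  # (3) RMSNorm output + new residual

# The fused path replaces all three steps with a single kernel call
# (signature as quoted in the review discussion below):
#   norm_out, residual_out = flashinfer_allreduce_residual_rmsnorm(
#       fd_config=fd_config, input_tensor=x, residual=residual, weight=weight, eps=eps)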

Usage or Command

Local tests pass on both H-series and B-series GPUs:
python -m fastdeploy.entrypoints.openai.api_server --model /root/paddlejob/workspace/bingoo/model/GLM-4.5-Air --tensor-parallel-size 4 --port 8185 --max-num-batched-tokens 2048 --enable-flashinfer-allreduce-fusion
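
Once the server is up, a quick smoke test can be sent against the OpenAI-compatible endpoint; a minimal sketch, assuming the standard /v1/chat/completions route and payload (not taken from this PR):

# Hypothetical smoke test for the server launched above.
import requests

resp = requests.post(
    "http://localhost:8185/v1/chat/completions",
    json={
        "model": "GLM-4.5-Air",  # model name is illustrative
        "messages": [{"role": "user", "content": "hello"}],
    },
    timeout=60,
)
print(resp.json()["choices"][0]["message"]["content"])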

Accuracy Tests

python -m paddle.distributed.launch --gpus=0,1 ./FastDeploy/tests/layers/test_rms_allreduce_fusion.py
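
At its core, such a test compares the fused kernel's output against an unfused reference; a minimal sketch of that check, assuming both outputs are Paddle tensors (helper name and tolerances are illustrative, not the actual test file):

import numpy as np

def assert_fusion_matches_reference(fused_out, ref_out, rtol=1e-2, atol=1e-2):
    # Half-precision fused kernels typically need loose tolerances.
    np.testing.assert_allclose(
        fused_out.astype("float32").numpy(),
        ref_out.astype("float32").numpy(),
        rtol=rtol, atol=atol,
    )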

Checklist

  • Add at least one tag to the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code and run pre-commit before committing.
  • Add unit tests, or explain in this PR why none are provided.
  • Provide accuracy results.
  • If this PR targets a release branch, make sure it has first been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

@paddle-bot

paddle-bot commented Apr 7, 2026

Thanks for your contribution!

@codecov-commenter

Codecov Report

❌ Patch coverage is 51.00000% with 49 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (release/2.6@b24765a). Learn more about missing BASE report.

Files with missing lines | Patch % | Lines
...oy/model_executor/layers/flashinfer_comm_fusion.py | 43.37% | 38 Missing and 9 partials ⚠️
fastdeploy/model_executor/layers/normalization.py | 50.00% | 1 Missing and 1 partial ⚠️
Additional details and impacted files
@@              Coverage Diff               @@
##             release/2.6    #7228   +/-   ##
==============================================
  Coverage               ?   73.22%           
==============================================
  Files                  ?      377           
  Lines                  ?    52980           
  Branches               ?     8262           
==============================================
  Hits                   ?    38796           
  Misses                 ?    11466           
  Partials               ?     2718           
Flag | Coverage Δ
GPU | 73.22% <51.00%> (?)

Flags with carried forward coverage won't be shown.

☔ View full report in Codecov by Sentry.


@fastdeploy-bot left a comment


🤖 AI Code Review | 2026-04-08

📋 Review Summary

PR overview: Enable the TRT-LLM all_reduce fusion kernel in the GLM model, using flashinfer to provide a fused all-reduce + residual + RMSNorm operation for better performance.

Scope of changes: model_executor/layers/, model_executor/models/, config/, engine/

Impact tags: [Optimization] [OP] [Models]

Issues found

Severity | File | Summary
🔴 Bug | normalization.py:249 | Subsequent access raises a TypeError when flashinfer returns None
🔴 Bug | glm4_moe.py:226,301,309 | The prefix changes break model weight loading
🟡 Suggestion | flashinfer_comm_fusion.py:74 | cleanup should use the same group as initialization
🟡 Suggestion | linear.py:864, glm4_moe.py:226 | Prefer a config parameter over prefix string matching

Overall assessment

The overall design is sound: the all-reduce fusion optimization is implemented through flashinfer. However, there are two serious bugs: the unhandled case where flashinfer returns None causes a runtime error, and the prefix changes break the weight-loading mechanism. Both must be fixed before this can be merged.

return norm_out.astype(x_dtype), residual_out
norm_out = self.norm_func(x, residual_input, self.weight, self.eps)
# enable trtllm all reduce fusion
elif self.enable_all_reduce_fusion and x.shape[0] <= 2048:

🔴 Bug: When flashinfer_allreduce_residual_rmsnorm returns (None, None) (flashinfer unavailable or not initialized), the subsequent access to norm_out raises a TypeError.

Suggested fix: add fallback logic that checks whether the return value is None:

elif self.enable_all_reduce_fusion and x.shape[0] <= 2048:
    fusion_result = flashinfer_allreduce_residual_rmsnorm(
        fd_config=self.fd_config, input_tensor=x, residual=residual_input, weight=self.weight, eps=self.eps
    )
    if fusion_result[0] is not None:
        norm_out = fusion_result
    else:
        # Fallback to standard implementation
        norm_out = self.norm_func(
            x, norm_weight=self.weight, norm_bias=None, epsilon=self.eps,
            begin_norm_axis=self.begin_norm_axis, bias=self.bias, residual=residual_input,
            quant_scale=(-1 if self.quant_scale is None else self.quant_scale),
            quant_round_type=self.quant_round_type,
            quant_max_bound=self.quant_max_bound,
            quant_min_bound=self.quant_min_bound,
        )

self.o_proj = RowParallelLinear(
    fd_config,
-   prefix=f"{prefix}.o_proj",
+   prefix=f"{prefix}.enable_all_reduce.o_proj",

🔴 Bug: The o_proj prefix was changed from f"{prefix}.o_proj" to f"{prefix}.enable_all_reduce.o_proj", but the weight mapping in _get_tensor_parallel_mappings still uses layers.0.self_attn.o_proj.weight.

This breaks model weight loading: the model parameter is now named layers.{i}.self_attn.enable_all_reduce.o_proj.weight, while the key in the weight file is layers.{i}.self_attn.o_proj.weight (see the toy illustration after this list).

Suggestions:

  1. Do not modify the prefix; use a dedicated configuration parameter to control the fusion behavior.
  2. Or update the weight-mapping logic to match the new prefix.
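
A toy illustration of the mismatch (all names hypothetical; not the actual loader code):

# Checkpoint keys follow the on-disk layout; parameter names follow the prefix.
ckpt_keys = {"layers.0.self_attn.o_proj.weight"}        # key stored in the weight file

prefix = "layers.0.self_attn.enable_all_reduce"         # prefix after this PR's change
param_name = f"{prefix}.o_proj.weight"                  # name the model registers

print(param_name in ckpt_keys)  # False -> the loader cannot find this weight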

@BingooYang (Contributor, Author)

This should not affect model loading; please re-confirm.

    hidden_size=fd_config.model_config.hidden_size,
    eps=fd_config.model_config.rms_norm_eps,
-   prefix=f"{prefix}.input_layernorm",
+   prefix=f"{prefix}.enable_all_reduce_fusion.input_layernorm",

🔴 Bug: Same as above. The input_layernorm prefix was changed to include enable_all_reduce_fusion, but the weight-mapping logic was not updated accordingly, which will break weight loading.

    hidden_size=fd_config.model_config.hidden_size,
    eps=fd_config.model_config.rms_norm_eps,
-   prefix=f"{prefix}.post_attention_layernorm",
+   prefix=f"{prefix}.enable_all_reduce_fusion.post_attention_layernorm",

🔴 Bug: Same as above. The post_attention_layernorm prefix was changed to include enable_all_reduce_fusion, but the weight-mapping logic was not updated accordingly.

"""Clean up workspace"""
if self.initialized and self.ipc_handles is not None:
try:
_flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(self.ipc_handles, group=dist.get_group())

🟡 Suggestion: cleanup uses dist.get_group() to obtain the default group, but initialization (line 55) did not specify one, which means a non-default communication group may have been in use.

Suggestions:

  1. Record the group used at initialization in FlashInferWorkspaceManager.
  2. Use that same group during cleanup:
class FlashInferWorkspaceManager:
    def __init__(self):
        self.group = None  # record the group used at initialization
        ...

    def initialize(self, ..., group=None, ...):
        self.group = group  # save the group
        ...

    def cleanup(self):
        ...
        _flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(
            self.ipc_handles, group=self.group or dist.get_group()
        )
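
A note on the design choice behind this suggestion: an IPC workspace is tied to the communicator it was created on, so destroying it against a different group than the one used at initialization risks releasing the wrong handles; storing self.group makes the pairing explicit.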

"""
self.fd_config = fd_config
self.enable_all_reduce_fusion = (
fd_config.parallel_config.enable_flashinfer_allreduce_fusion and "enable_all_reduce" in prefix

🟡 Suggestion: Whether fusion is enabled is determined by the string match "enable_all_reduce" in prefix, which is unintuitive and error-prone (as the weight-loading problem in this PR shows).

A dedicated, explicit configuration parameter would be clearer, for example:

# Add in ParallelConfig
self.enable_o_proj_fusion: bool = False

# In linear.py
self.enable_all_reduce_fusion = (
    fd_config.parallel_config.enable_flashinfer_allreduce_fusion
    and fd_config.parallel_config.enable_o_proj_fusion
)
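
A side benefit of an explicit flag, if adopted: it is discoverable in the configuration surface and leaves parameter names, and therefore weight loading, untouched.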

