
Conversation

@A-nnonymous
Contributor

@A-nnonymous A-nnonymous commented Dec 23, 2025

PR Category

Operator Mechanism

PR Types

Improvements

Description

This work aligns the linear operator in PaddlePaddle with PyTorch v2.9.1 by introducing the linear_v2 architecture, which aligns the underlying operator execution mechanism, the autograd logic, and the distributed (auto-parallel) strategy.

Note: because this change to the operator has a very broad impact, every mechanism introduced in this PR can be rolled back with a single switch for compatibility. Setting FLAGS_use_legacy_linear=1 restores the existing logic and keeps the operator behavior exactly as it was before this PR; the fallback itself adds zero GPU overhead and only a single branch of extra CPU overhead.
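
A minimal sketch of how the rollback switch can be enabled. The environment-variable form is what the PR text describes; toggling it at runtime via paddle.set_flags is an assumption based on how other exported framework flags behave, not something this PR states.

```python
# Sketch: roll back to the pre-PR linear behaviour.
# Environment-variable form (as described above):
#   FLAGS_use_legacy_linear=1 python train.py
# Runtime form (assumed to work like other exported framework flags):
import paddle

paddle.set_flags({"FLAGS_use_legacy_linear": True})
```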


I. Changes

1. Performance

  • Deep integration with cublasLt via linear_v2 achieves better matmul + bias fusion on GPU.
  • Aligned the cublasLt heuristic-search parameters and implemented low-overhead runtime heuristics, matching torch.
  • Fewer kernel launches; the reduced launch overhead improves overall throughput.
  • Fused matmul + add, removing redundant reads and writes of intermediate tensors.

Note: when the composite-operator mechanism is enabled manually (i.e. when using the compiler), linear_v2 is decomposed into matmul + add. If the fused form must be preserved, use the blacklist mechanism to prevent linear_v2 from being decomposed (a hedged example follows).
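
A hedged sketch of such a blacklist entry. The flag name FLAGS_prim_forward_blacklist and the "pd_op.linear_v2" op-name format are assumptions based on how other ops are commonly excluded from decomposition; this PR does not specify the exact mechanism.

```python
# Sketch: keep linear_v2 fused when composite operators are enabled.
# FLAGS_prim_forward_blacklist and the "pd_op.linear_v2" naming are assumptions;
# check the framework's prim flags for the exact names.
import paddle

paddle.set_flags({"FLAGS_prim_forward_blacklist": "pd_op.linear_v2"})
```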

Measured data (warmup 100, repeat 1000; a sketch of a matching timing harness follows the analysis bullets below)

Note: in the figures below, L, N, and T denote Legacy linear, New linear (linear_v2), and torch (v2.9.1), respectively.

Hopper fp16 (benchmark figure)

Hopper bf16 (benchmark figure)

Hopper fp32, no tf32_override (benchmark figure)

Across typical Hopper scenarios covering both small batch sizes (inference) and large batch sizes (training), the average forward/backward speedups are:

  • fp16: 19.25% / 6.92% over the legacy linear mechanism

  • bf16: 19.26% / 6.84% over the legacy linear mechanism

  • fp32: 6.79% / 5.69% over the legacy linear mechanism

  • For small shapes, it has been confirmed that most of the gain over legacy comes from saved scheduling overhead, and the remaining gap to torch comes from the same source.

  • For medium and large shapes, it has been confirmed that the gain comes mainly from eliminating the elementwise add; even in cases that regress slightly relative to legacy, performance is essentially on par with torch.

  • It has been confirmed that the remaining gap to torch v2.9.1 comes mainly from framework scheduling overhead and reduce performance (this mostly affects the backward pass, since the test environment for this PR did not include the related reduce fixes).
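
For reference, the warmup-100 / repeat-1000 protocol above can be reproduced with a harness along the following lines. This is a sketch rather than the authors' benchmark script, and the shape and dtype below are placeholders.

```python
# Sketch of a warmup/repeat timing harness for the linear forward pass.
import time

import paddle


def bench(fn, warmup=100, repeat=1000):
    for _ in range(warmup):
        fn()
    paddle.device.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(repeat):
        fn()
    paddle.device.cuda.synchronize()
    return (time.perf_counter() - start) / repeat


x = paddle.randn([4096, 4096], dtype="float16")
w = paddle.randn([4096, 4096], dtype="float16")
b = paddle.randn([4096], dtype="float16")
print("avg linear fwd (s):", bench(lambda: paddle.nn.functional.linear(x, w, b)))
```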

2. Precision

  • After merging, linear precision in the Paddle framework matches torch v2.9.1 by default.
  • In below-fp32 precision cases, the fused operator has a precision advantage (it saves one downcast rounding step; see the sketch below). After merging, linear_v2 precision is strictly better than the old linear.
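
To illustrate the saved-downcast point: in the unfused path the matmul output is rounded to the storage dtype before the bias add, while a fused epilogue adds the bias in the higher-precision accumulator and rounds once. A small numpy sketch, purely illustrative and not the actual kernel path:

```python
# Compare one rounding step (fused) vs two (unfused) at fp16 storage precision.
import numpy as np

x = np.random.rand(64, 64).astype(np.float32)
w = np.random.rand(64, 64).astype(np.float32)
b = np.random.rand(64).astype(np.float32)

# Unfused: the matmul result is rounded to fp16 before the bias add.
unfused = ((x @ w).astype(np.float16).astype(np.float32) + b).astype(np.float16)

# Fused: the bias is added in fp32, then the result is rounded once.
fused = ((x @ w) + b).astype(np.float16)

print("max abs diff:", np.abs(fused.astype(np.float32) - unfused.astype(np.float32)).max())
```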

II. Main File Changes and Notes

1. Infrastructure and code generation

File path | Description
common/flags.cc | Add FLAGS_use_legacy_linear to allow falling back to the old linear logic (compatibility path)
fluid/eager/auto_code_generator/generator/eager_gen.py | Adapt the linear_v2 decomposition logic under dygraph
fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py | Add linear_v2 to the composite-operator list
fluid/pir/dialect/op_generator/op_build_gen.py | Add linear_v2 to the composite-operator list
fluid/pir/dialect/op_generator/vjp_interface_black_list.py | Add the related ops to the VJP blacklist to support a custom backward
fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc/h | linear_v2 InferSymbolicShape
fluid/prim/api/composite_backward/composite_double_backward_api.h | Define the composite-operator double-backward API
fluid/prim/api/composite_backward/api.yaml | API mapping for the linear_v2 composite operator
fluid/primitive/codegen/decomp_vjp_gen.py | Implement the linear_v2 VJP
fluid/primitive/decomp_rule/decomp_rule/composite.h | Define the linear_v2 decomposition rule
fluid/primitive/decomp_vjp/details.h | Concrete implementation of the linear_v2 VJP decomposition

2. Operator logic

File path | Description
phi/infermeta/spmd_rules/fused_gemm_epilogue.h | Expose two additional generic helpers for reuse by the linear_v2 SPMD rules
phi/infermeta/spmd_rules/linear_v2.cc/h | Auto-parallel sharding (SPMD) rules for linear_v2
phi/infermeta/ternary.cc/h | Forward shape/dtype inference for linear_v2
phi/kernels/cpu/linear_v2_grad_kernel.cc | CPU backward kernel for Linear V2
phi/kernels/cpu/linear_v2_kernel.cc | CPU forward kernel for Linear V2
phi/kernels/funcs/blas/blaslt_impl.cu.h | Add a heuristic path to cublasLt GEMM execution and align the workspace size
phi/kernels/gpu/linear_v2_grad_kernel.cu/h | GPU backward kernel for Linear V2 and its header (the standard forward/VJP formulas are summarized after this table)
phi/kernels/gpu/linear_v2_kernel.cu/h | GPU forward kernel for Linear V2 and its header
phi/ops/yaml/ops.yaml / backward.yaml / op_compat.yaml | Operator-related YAML definitions
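
For reference, the forward and VJP that the decomposition rule and grad kernels need to express are the standard ones for a bias-adding linear, written here in Paddle's y = xW + b convention; bias broadcasting and batched leading dimensions, which the kernels also handle, are omitted.

```latex
Y = XW + b, \qquad
\frac{\partial L}{\partial X} = \frac{\partial L}{\partial Y}\,W^{\top}, \qquad
\frac{\partial L}{\partial W} = X^{\top}\,\frac{\partial L}{\partial Y}, \qquad
\frac{\partial L}{\partial b} = \sum_i \left(\frac{\partial L}{\partial Y}\right)_{i,:}
```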

3. Python API and tests

File path | Description
python/paddle/amp/amp_lists.py | Update the AMP operator black/white lists
python/paddle/autograd/backward_utils.py | Update ALLOW_DYNAMIC_SHAPE_VJP_OPS
python/paddle/nn/functional/common.py | Wrap the Linear-related functions at the Python layer and add the fallback mechanism (a sketch follows this table)
test/amp/ | Adapt several AMP unit tests (API, Master Grad, Promote, PIR) to linear_v2
test/collective/fleet/test_dygraph_sharding_stage2.py | Fleet unit test; falls back to legacy where array_equal against hard-coded values is required
test/ir/pir/ | Adapt backward and subgraph unit tests under PIR mode to linear_v2
test/legacy_test/ | linear_v2 adaptations for the Nan/Inf check, JIT loading, Lookahead optimizer, and PaddleScience tests
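
A minimal sketch of what the Python-layer fallback in python/paddle/nn/functional/common.py could look like. The _C_ops.linear_v2 entry point and reading the flag via paddle.get_flags are assumptions for illustration, not the actual diff.

```python
# Sketch of a flag-guarded linear wrapper (dygraph path only).
import paddle
from paddle import _C_ops


def linear(x, weight, bias=None, name=None):
    flags = paddle.get_flags(["FLAGS_use_legacy_linear"])
    if flags["FLAGS_use_legacy_linear"]:
        # Pre-PR behaviour: plain matmul followed by an elementwise add.
        out = paddle.matmul(x, weight)
        return paddle.add(out, bias) if bias is not None else out
    # New fused path: the bias is folded into the GEMM epilogue.
    return _C_ops.linear_v2(x, weight, bias)
```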

III. Follow-up Work

  • Align einsum
  • Retire the legacy linear mechanism

pcard-91067

wanghuancoder
wanghuancoder previously approved these changes Jan 9, 2026
Contributor

@wanghuancoder wanghuancoder left a comment

LGTM

@codecov-commenter

codecov-commenter commented Jan 9, 2026

Codecov Report

❌ Patch coverage is 97.18310% with 6 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@d2bc32e). Learn more about missing BASE report.

Files with missing lines | Patch % | Lines
...terface/infer_symbolic_shape/multiary_infer_sym.cc | 88.88% | 3 Missing ⚠️
paddle/phi/infermeta/spmd_rules/linear_v2.cc | 98.68% | 1 Missing ⚠️
paddle/phi/kernels/cpu/linear_v2_kernel.cc | 90.90% | 1 Missing ⚠️
python/paddle/nn/functional/common.py | 95.65% | 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             develop   #77039   +/-   ##
==========================================
  Coverage           ?   97.18%           
==========================================
  Files              ?        9           
  Lines              ?      213           
  Branches           ?        0           
==========================================
  Hits               ?      207           
  Misses             ?        6           
  Partials           ?        0           

☔ View full report in Codecov by Sentry.

JiabinYang
JiabinYang previously approved these changes Jan 12, 2026
Contributor

@JiabinYang JiabinYang left a comment

LGTM for auto_code_generator/eager_gen.py

wanghuancoder
wanghuancoder previously approved these changes Jan 13, 2026
Contributor

@wanghuancoder wanghuancoder left a comment

LGTM

@A-nnonymous
Contributor Author

/re-run all-failed

# NOTE(Pan Zhaowu): disable linear_v2 decomp to test infersymbolics
paddle.set_flags(
    {
        "FLAGS_deny_cinn_ops": "linear_v2",
Contributor

linear_v2 is not in the set of ops CINN supports, so setting FLAGS_deny_cinn_ops here probably has no effect.

Contributor Author

OK, got it. I'll remove this part in a follow-up. Thanks.

Contributor

@wanghuancoder wanghuancoder left a comment

LGTM

Contributor

@XiaoguangHu01 XiaoguangHu01 left a comment

LGTM

@wanghuancoder wanghuancoder merged commit 8f6f5d4 into PaddlePaddle:develop Jan 13, 2026
119 of 131 checks passed
)

# NOTE(Pan Zhaowu): using legacy linear to fulfill promise of tensorrt graph capturing
# and converting.
Contributor

This also needs to be handled for TensorRT. If it is only added in the unit test, running a pir-trt model will still use the new linear; please add this environment variable before `def convert` in python/paddle/tensorrt/export.py.

A-nnonymous added a commit to A-nnonymous/Paddle that referenced this pull request Jan 15, 2026
* aligned linear & matmul_with_bias, need check incompatible cases.

* aligned high-dim matmul operation, adding flags control and controlling failed cases

* aligned matmul, start aligning einsum

* Add flag

* polish

* fix shape related issues.

* Finish crash handling

* revert redundant diff

* revert redundant diff

* restrict influence to only CUDA

* Optimized CPU overhead, bypass windows.

* optimize branch cost

* disable dist tensor case

* add GPUPlace check

* fix matmul diff

* add flags related logic

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* bypass win32

* stash

* stash

* align fwd in gpu

* fix fwd miscs

* fix shape miscs

* clean code

* fix grad

* recover redundant diff

* using legacy linear

* add multi-platform support, polish

* refractor

* fix flag and amp rules

* fix CI

* polish

* fix miscs

* add infersymbolics

* Add metaconfig

* fix symbolic, move flags

* fix bwd infermeta

* Add prim linear_v2_grad

* fix fwd decomp

* add proper fallback to fulfill legacy promise

* tmp restrict prim

* Add inferSPMD, fix CI

* fix ci

* fix auto parallel.

* fix TRT and prec test

* add infer_symbolic instance, remove glog including in header

* fix reduncant DtoD cpy

* coverage linear_v2 symbolics
