Align linear op with torch v2.9.1 #77039
Conversation
wanghuancoder
left a comment
LGTM
Codecov Report

```
@@            Coverage Diff            @@
##             develop   #77039   +/- ##
=========================================
  Coverage           ?   97.18%
  Files              ?        9
  Lines              ?      213
  Branches           ?        0
=========================================
  Hits               ?      207
  Misses             ?        6
  Partials           ?        0
```
JiabinYang
left a comment
LGTM for auto_code_generator/eager_gen.py
014eea2
wanghuancoder
left a comment
LGTM
/re-run all-failed
```python
# NOTE(Pan Zhaowu): disable linear_v2 decomp to test infersymbolics
paddle.set_flags(
    {
        "FLAGS_deny_cinn_ops": "linear_v2",
```
`linear_v2` is not in the set of ops supported by CINN, so setting `FLAGS_deny_cinn_ops` here should have no effect.
Got it, I will remove this part in a follow-up. Thanks!
wanghuancoder
left a comment
LGTM
XiaoguangHu01
left a comment
LGTM
```python
)
```
```python
# NOTE(Pan Zhaowu): using legacy linear to fulfill promise of tensorrt graph
# capturing and converting.
```
This also needs to be handled for TensorRT. If you only add the flag in the unit test, running PIR-TRT models will still use the new linear. Please set this environment variable before `def convert` in python/paddle/tensorrt/export.py.
* aligned linear & matmul_with_bias, need check incompatible cases.
* aligned high-dim matmul operation, adding flags control and controlling failed cases
* aligned matmul, start aligning einsum
* Add flag
* polish
* fix shape related issues.
* Finish crash handling
* revert redundant diff
* revert redundant diff
* restrict influence to only CUDA
* Optimized CPU overhead, bypass windows.
* optimize branch cost
* disable dist tensor case
* add GPUPlace check
* fix matmul diff
* add flags related logic
* polish
* polish
* polish
* polish
* polish
* polish
* polish
* polish
* bypass win32
* stash
* stash
* align fwd in gpu
* fix fwd miscs
* fix shape miscs
* clean code
* fix grad
* recover redundant diff
* using legacy linear
* add multi-platform support, polish
* refractor
* fix flag and amp rules
* fix CI
* polish
* fix miscs
* add infersymbolics
* Add metaconfig
* fix symbolic, move flags
* fix bwd infermeta
* Add prim linear_v2_grad
* fix fwd decomp
* add proper fallback to fulfill legacy promise
* tmp restrict prim
* Add inferSPMD, fix CI
* fix ci
* fix auto parallel.
* fix TRT and prec test
* add infer_symbolic instance, remove glog including in header
* fix reduncant DtoD cpy
* coverage linear_v2 symbolics
**PR Category**
Operator Mechanism

**PR Types**
Improvements

**Description**
This PR aligns PaddlePaddle's `linear` operator with PyTorch v2.9.1 by introducing a `linear_v2` architecture that aligns the underlying operator execution mechanism, the autodiff logic, and the distributed strategies.

Note: because changes to this operator have a very broad impact, for compatibility every mechanism introduced by this PR can be rolled back with a single switch. Setting `FLAGS_use_legacy_linear=1` reverts to the exact pre-merge operator behavior; the rollback itself incurs zero GPU overhead and only a single extra branch on the CPU.

I. Changes
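The single-switch rollback can be exercised in two ways. A minimal sketch, assuming the flag follows Paddle's usual `FLAGS_*` convention of being readable from the environment before import, or settable at runtime via `paddle.set_flags`:

```python
import os

# Option 1: set the flag in the environment before `import paddle`;
# GLOG-style FLAGS_* variables are picked up at framework startup.
os.environ["FLAGS_use_legacy_linear"] = "1"

# Option 2 (at runtime, after `import paddle`) would be:
#   paddle.set_flags({"FLAGS_use_legacy_linear": True})

print(os.environ["FLAGS_use_legacy_linear"])
```

Either way, subsequent `paddle.nn.functional.linear` calls should take the legacy code path.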
1. Performance
`linear_v2` is deeply integrated with cuBLASLt, achieving better matmul-plus-bias fusion on GPU. Note that when the composite-operator (prim) mechanism is manually enabled (e.g. when using the compiler), `linear_v2` is decomposed into `matmul + add`; if needed, use the blacklist mechanism to prevent this decomposition.
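Semantically, the decomposition above splits the fused op into a plain matmul followed by a broadcasted add. A reference sketch in NumPy (not the actual kernel; the function name is illustrative):

```python
import numpy as np

def linear_reference(x, w, b=None):
    # What the fused op computes: x @ w (+ b). Under the composite
    # (prim) mechanism it is split into exactly these two primitives:
    # a matmul, then an elementwise add.
    y = np.matmul(x, w)
    if b is not None:
        y = y + b  # bias is broadcast over the leading dimensions
    return y

x = np.random.randn(8, 16).astype(np.float32)
w = np.random.randn(16, 32).astype(np.float32)
b = np.random.randn(32).astype(np.float32)
out = linear_reference(x, w, b)
print(out.shape)  # (8, 32)
```

The fused cuBLASLt path avoids materializing the intermediate `x @ w` result before the add, which is where the elementwise-add savings cited below come from.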
Measured data (warmup 100, repeat 1000)
Note: in the figures below, L, N, and T stand for Legacy linear, New linear (`linear_v2`), and torch (v2.9.1), respectively.
Hopper fp16
Hopper bf16
Hopper fp32 (no tf32_override)
Across typical Hopper scenarios covering both small batch sizes (inference) and large batch sizes (training), the average forward/backward speedups over the legacy linear mechanism are:

| dtype | forward | backward |
|-------|---------|----------|
| fp16  | 19.25%  | 6.92%    |
| bf16  | 19.26%  | 6.84%    |
| fp32  | 6.79%   | 5.69%    |

We have confirmed that for small shapes, most of the gain over legacy comes from saved scheduling overhead, and the remaining gap to torch comes from the same source.
We have confirmed that for medium and large shapes, the gain mainly comes from eliminating the elementwise add; even in the cases that regress relative to legacy, performance is basically on par with torch.
We have confirmed that the remaining gap to torch v2.9.1 mostly comes from framework scheduling overhead or reduce performance (mainly affecting backward, since the test environment for this PR did not include the reduce-related fixes).
2. Accuracy
II. Main File Changes and Notes

1. Infrastructure and code generation

- common/flags.cc
- fluid/eager/auto_code_generator/generator/eager_gen.py
- fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py
- fluid/pir/dialect/op_generator/op_build_gen.py
- fluid/pir/dialect/op_generator/vjp_interface_black_list.py
- fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc/h
- fluid/prim/api/composite_backward/composite_double_backward_api.h
- fluid/prim/api/composite_backward/api.yaml
- fluid/primitive/codegen/decomp_vjp_gen.py
- fluid/primitive/decomp_rule/decomp_rule/composite.h
- fluid/primitive/decomp_vjp/details.h

2. Operator logic

- phi/infermeta/spmd_rules/fused_gemm_epilogue.h
- phi/infermeta/spmd_rules/linear_v2.cc/h
- phi/infermeta/ternary.cc/h
- phi/kernels/cpu/linear_v2_grad_kernel.cc
- phi/kernels/cpu/linear_v2_kernel.cc
- phi/kernels/funcs/blas/blaslt_impl.cu.h
- phi/kernels/gpu/linear_v2_grad_kernel.cu/h
- phi/kernels/gpu/linear_v2_kernel.cu/h
- phi/ops/yaml/ops.yaml / backward.yaml / op_compat.yaml

3. Python API and tests

- python/paddle/amp/amp_lists.py
- python/paddle/autograd/backward_utils.py
- python/paddle/nn/functional/common.py
- test/amp/
- test/collective/fleet/test_dygraph_sharding_stage2.py
- test/ir/pir/
- test/legacy_test/

III. Follow-up Work
pcard-91067