diff --git a/paddle/common/flags.cc b/paddle/common/flags.cc index a5f2fefa323492..b2c091abb62653 100644 --- a/paddle/common/flags.cc +++ b/paddle/common/flags.cc @@ -2370,7 +2370,7 @@ PHI_DEFINE_EXPORTED_bool(use_accuracy_compatible_kernel, /** * Legacy gemm related FLAG * Name: FLAGS_use_legacy_gemm - * Since Version: 3.3.0 + * Since Version: 3.2.2 * Value Range: bool, default=false * Example: * Note: Whether use legacy gemm kernel. @@ -2379,6 +2379,18 @@ PHI_DEFINE_EXPORTED_bool(use_legacy_gemm, false, "Whether use legacy gemm dispatch logics."); +/** + * Legacy gemm related FLAG + * Name: FLAGS_use_legacy_linear + * Since Version: 3.3.1 + * Value Range: bool, default=false + * Example: + * Note: Whether use legacy linear kernel. + */ +PHI_DEFINE_EXPORTED_bool(use_legacy_linear, + true, + "Whether use legacy linear dispatch logics."); + /** * Allocator Compact related FLAG * Name: FLAGS_enable_compact_mem diff --git a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py index 756dd9d38c40a0..ffaae7de12c378 100644 --- a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py @@ -86,6 +86,7 @@ "where_double_grad", "bmm_double_grad", "index_put_double_grad", + "linear_v2_double_grad", "gather_nd_double_grad", "reshape_double_grad", "take_along_axis_double_grad", diff --git a/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py b/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py index 72774939dac6d3..36e0271aeb1f16 100644 --- a/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py +++ b/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py @@ -57,6 +57,7 @@ "lerp", "log_loss", "log_softmax", + "linear_v2", "mean", "mean_all", "meshgrid", @@ -112,6 +113,7 @@ 'layer_norm_grad', 'log_grad', 'matmul_grad', + 'linear_v2_grad', 'max_grad', 'maximum_grad', 'mean_grad', diff --git a/paddle/fluid/pir/dialect/op_generator/op_build_gen.py b/paddle/fluid/pir/dialect/op_generator/op_build_gen.py index 60840cc60ec5e9..9aa280181ba1f1 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_build_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_build_gen.py @@ -68,6 +68,7 @@ 'FusedGateAttentionGradInferMeta', 'ResnetBasicBlockInferMeta', 'ResnetBasicBlockGradInferMeta', + 'LinearV2InferMeta', # multiary.h 'AddNInferMeta', 'ApVariadicInferMeta', diff --git a/paddle/fluid/pir/dialect/op_generator/vjp_interface_black_list.py b/paddle/fluid/pir/dialect/op_generator/vjp_interface_black_list.py index 4ffcf5a51224b0..fd9d770f4e2b87 100644 --- a/paddle/fluid/pir/dialect/op_generator/vjp_interface_black_list.py +++ b/paddle/fluid/pir/dialect/op_generator/vjp_interface_black_list.py @@ -44,5 +44,6 @@ 'index_elementwise_get_grad', 'index_elementwise_put_with_tensor_grad', 'index_elementwise_put_grad', + 'linear_v2_grad', 'view_shape_grad', ] diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index cdbb0515674b0e..1671cd86ba00a6 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -2281,6 +2281,49 @@ bool FusedGemmEpilogueOpInferSymbolicShape( return true; } +bool LinearV2OpInferSymbolicShape( + pir::Operation *op, 
pir::InferSymbolicShapeContext *infer_context) { + const auto &x_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(0)); + const auto &x_dims = x_shape_or_data.shape(); + + const auto &y_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(1)); + const auto &y_dims = y_shape_or_data.shape(); + + const auto &bias_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(2)); + const auto &bias_dims = bias_shape_or_data.shape(); + + size_t x_rank = x_dims.size(); + size_t y_rank = y_dims.size(); + + std::vector out_shape; + out_shape.reserve(x_rank); + + for (size_t i = 0; i + 2 < x_rank; ++i) { + out_shape.emplace_back(x_dims[i]); + } + + symbol::DimExpr out_M = x_dims[x_rank - 2]; + symbol::DimExpr out_N = y_dims[y_rank - 1]; + + out_shape.emplace_back(out_M); + out_shape.emplace_back(out_N); + + symbol::DimExpr x_K = x_dims[x_rank - 1]; + symbol::DimExpr y_K = y_dims[y_rank - 2]; + + infer_context->AddEqualCstr(x_K, y_K); + // bias_dims[0] equal to out_N + infer_context->AddEqualCstr(out_N, bias_dims[0]); + + infer_context->SetShapeOrDataForValue(op->result(0), + ShapeOrData{TensorExprs(out_shape)}); + + return true; +} + bool FusedMultiTransformerOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { const auto &x_shape_or_data = diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h index b1d193e44163ea..d71b46967ce162 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h @@ -67,6 +67,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(FusedBatchNormAct_) OP_DECLARE_INFER_SYMBOLIC_SHAPE(FusedBnAddActivation) OP_DECLARE_INFER_SYMBOLIC_SHAPE(FusedBnAddActivation_) OP_DECLARE_INFER_SYMBOLIC_SHAPE(FusedGemmEpilogue) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(LinearV2) OP_DECLARE_INFER_SYMBOLIC_SHAPE(FusedMultiTransformer) OP_DECLARE_INFER_SYMBOLIC_SHAPE(GenerateProposals) OP_DECLARE_INFER_SYMBOLIC_SHAPE(GraphKhopSampler) diff --git a/paddle/fluid/prim/api/api.yaml b/paddle/fluid/prim/api/api.yaml index f4c7a145ef73b6..74e1649ebccd0c 100644 --- a/paddle/fluid/prim/api/api.yaml +++ b/paddle/fluid/prim/api/api.yaml @@ -15,6 +15,7 @@ - exp - scale - matmul +- linear_v2 - expand - sum - abs diff --git a/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h index f612c8d2ec2a34..3d11fec8558b35 100644 --- a/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h +++ b/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h @@ -849,6 +849,33 @@ void add_triple_grad(const paddle::optional& grad_grad_x, } } +template +void linear_v2_double_grad(const Tensor& input, + const Tensor& weight, + const Tensor& bias, + const Tensor& grad_out, + const paddle::optional& grad_input_grad, + const paddle::optional& grad_weight_grad, + const paddle::optional& grad_bias_grad, + Tensor* input_grad, + Tensor* weight_grad, + Tensor* bias_grad, + Tensor* grad_out_grad) { + matmul_double_grad(input, + weight, + grad_out, + grad_input_grad, + grad_weight_grad, + false, + false, + input_grad, + weight_grad, + grad_out_grad); + if (bias_grad) { + add_double_grad(bias, grad_out, nullptr, grad_bias_grad, -1, bias_grad); + } +} + template void 
subtract_double_grad(const Tensor& y, const Tensor& grad_out, diff --git a/paddle/fluid/primitive/codegen/decomp_vjp_gen.py b/paddle/fluid/primitive/codegen/decomp_vjp_gen.py index db177fec29198a..2345ddfa1caf62 100644 --- a/paddle/fluid/primitive/codegen/decomp_vjp_gen.py +++ b/paddle/fluid/primitive/codegen/decomp_vjp_gen.py @@ -103,6 +103,7 @@ 'logsumexp_grad', 'masked_select_grad', 'matmul_grad', + 'linear_v2_grad', 'max_grad', 'maximum_grad', 'minimum_grad', diff --git a/paddle/fluid/primitive/decomp_rule/decomp_rule/composite.h b/paddle/fluid/primitive/decomp_rule/decomp_rule/composite.h index 830b301524bf42..c1bd3eb36a9bbb 100644 --- a/paddle/fluid/primitive/decomp_rule/decomp_rule/composite.h +++ b/paddle/fluid/primitive/decomp_rule/decomp_rule/composite.h @@ -256,6 +256,15 @@ Tensor bmm_decomp(const Tensor& x, const Tensor& y) { return matmul(x, y, false, false); } +template +Tensor linear_v2_decomp(const Tensor& input, + const Tensor& weight, + const Tensor& bias) { + Tensor result = matmul(input, weight, false, false); + result = result + bias; + return result; +} + template std::tuple batch_norm_decomp( const Tensor& x, diff --git a/paddle/fluid/primitive/decomp_rule/decomp_vjp/details.h b/paddle/fluid/primitive/decomp_rule/decomp_vjp/details.h index f6bb5fa99de6df..a21fa67fae969b 100644 --- a/paddle/fluid/primitive/decomp_rule/decomp_vjp/details.h +++ b/paddle/fluid/primitive/decomp_rule/decomp_vjp/details.h @@ -1438,6 +1438,21 @@ void matmul_grad(const Tensor& x, } } +template +void linear_v2_grad(const Tensor& input, + const Tensor& weight, + const Tensor& bias, + const Tensor& out_grad, + Tensor* input_grad, + Tensor* weight_grad, + Tensor* bias_grad) { + matmul_grad( + input, weight, out_grad, false, false, input_grad, weight_grad); + if (bias_grad) { + add_grad(bias, bias, out_grad, -1, nullptr, bias_grad); + } +} + template void maximum_grad(const Tensor& x, const Tensor& y, diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 94750577a5debc..71013a6b42a0f6 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -492,6 +492,68 @@ void CudnnLSTMGradInferMeta( UnchangedMultiInferMeta(weight_list.get(), weight_list_grad); } } +void LinearV2GradInferMeta(const MetaTensor& input, + const MetaTensor& weight, + const MetaTensor& bias, + const MetaTensor& out_grad, + MetaTensor* input_grad, + MetaTensor* weight_grad, + MetaTensor* bias_grad) { + auto input_dims = input.dims(); + auto weight_dims = weight.dims(); + auto bias_dims = bias.dims(); + auto dout_dims = out_grad.dims(); + + auto dout_mat_dims = common::flatten_to_2d(dout_dims, dout_dims.size() - 1); + + const int64_t input_ndim = input_dims.size(); + auto k_from_dout = input_ndim >= 2 ? dout_dims[input_ndim - 2] : 1; + auto k_from_input = input_ndim >= 2 ? input_dims[input_ndim - 2] : 1; + + bool check_k = + (k_from_dout < 0 || k_from_input < 0) || (k_from_dout == k_from_input); + + if (check_k) { + PADDLE_ENFORCE_EQ( + dout_mat_dims[1], + weight_dims[1], + common::errors::InvalidArgument( + "The last dimension of DOut should be equal with Y's last " + "dimension. But received DOut[-1] = [%d], Y[1] = [%d].", + dout_mat_dims[1], + weight_dims[1])); + } + + for (int32_t i = 0; i + 2 < input_dims.size(); ++i) { + if (dout_dims[i] > 0 && input_dims[i] > 0) { + PADDLE_ENFORCE_EQ( + dout_dims[i], + input_dims[i], + common::errors::InvalidArgument( + "The i dimension of DOut should be equal with i dimension of X." 
+ "But received DOut[%d] = [%d], Y[%d] = [%d].", + i, + dout_dims[i], + i, + input_dims[i])); + } + } + + if (input_grad) { + input_grad->set_dims(input_dims); + input_grad->set_dtype(input.dtype()); + } + + if (weight_grad) { + weight_grad->set_dims(weight_dims); + weight_grad->set_dtype(weight.dtype()); + } + + if (bias_grad) { + bias_grad->set_dims(bias_dims); + bias_grad->set_dtype(bias.dtype()); + } +} void LSTMGradInferMeta(const MetaTensor& input, const MetaTensor& h0, diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index d92cfb5ca9592d..f96ded2c536b16 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -183,6 +183,14 @@ PADDLE_API void CudnnLSTMGradInferMeta( MetaTensor* init_c_grad, std::vector weight_list_grad); +PADDLE_API void LinearV2GradInferMeta(const MetaTensor& input, + const MetaTensor& weight, + const MetaTensor& bias, + const MetaTensor& out_grad, + MetaTensor* input_grad, + MetaTensor* weight_grad, + MetaTensor* bias_grad); + PADDLE_API void LSTMGradInferMeta(const MetaTensor& input, const MetaTensor& h0, const MetaTensor& c0, diff --git a/paddle/phi/infermeta/spmd_rules/fused_gemm_epilogue.h b/paddle/phi/infermeta/spmd_rules/fused_gemm_epilogue.h index 3bf7fbe25b0225..7821b13f4597ff 100644 --- a/paddle/phi/infermeta/spmd_rules/fused_gemm_epilogue.h +++ b/paddle/phi/infermeta/spmd_rules/fused_gemm_epilogue.h @@ -19,6 +19,19 @@ limitations under the License. */ namespace phi { namespace distributed { +void FillMatmulPartOperandNotation(const int x_ndim, + const int y_ndim, + std::string* x_axes, + std::string* y_axes, + std::string* out_axes); +TensorDistAttr GetMatmulPartInferredDistAttr( + const TensorDistAttr& origin_dist_attr, + const std::vector& shape, + const std::string& tensor_axis, + const std::unordered_map& axis_to_dim_map, + bool trans_axis); +void SetTensorDistAttrReplicated(TensorDistAttr* dist_attr, const int ndim); + SpmdInfo FusedGemmEpilogueInferSpmdBase(const DistMetaTensor& x, const DistMetaTensor& y, const DistMetaTensor& bias, diff --git a/paddle/phi/infermeta/spmd_rules/linear_v2.cc b/paddle/phi/infermeta/spmd_rules/linear_v2.cc new file mode 100644 index 00000000000000..95b82b9a4dac84 --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/linear_v2.cc @@ -0,0 +1,153 @@ +/* Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "glog/logging.h" + +#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" +#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h" +#include "paddle/phi/core/distributed/auto_parallel/utils.h" +#include "paddle/phi/infermeta/spmd_rules/linear_v2.h" +#include "paddle/phi/infermeta/spmd_rules/utils.h" + +namespace phi::distributed { + +using phi::distributed::auto_parallel::str_join; + +SpmdInfo LinearV2InferSpmdBase(const DistMetaTensor& input, + const DistMetaTensor& weight, + const DistMetaTensor& bias) { + // Step0: verify input args based on matmul logic + auto ori_input_shape = common::vectorize(input.dims()); + auto ori_weight_shape = common::vectorize(weight.dims()); + auto ori_bias_shape = common::vectorize(bias.dims()); + int input_ndim = static_cast(ori_input_shape.size()); + int weight_ndim = static_cast(ori_weight_shape.size()); + int bias_ndim = static_cast(ori_bias_shape.size()); + const auto& input_dist_attr_src = input.dist_attr(); + const auto& weight_dist_attr_src = weight.dist_attr(); + const auto& bias_dist_attr_src = bias.dist_attr(); + std::vector input_dims_mapping = input_dist_attr_src.dims_mapping(); + std::vector weight_dims_mapping = + weight_dist_attr_src.dims_mapping(); + std::vector bias_dims_mapping = bias_dist_attr_src.dims_mapping(); + + PADDLE_ENFORCE_EQ(input_ndim, + input_dims_mapping.size(), + common::errors::InvalidArgument( + "LinearV2, The Tensor input's rank [%d] and input's " + "dims_mapping size [%d] are not matched.", + input_ndim, + input_dims_mapping.size())); + PADDLE_ENFORCE_EQ(weight_ndim, + weight_dims_mapping.size(), + common::errors::InvalidArgument( + "LinearV2, The Tensor weight's rank [%d] and weight's " + "dims_mapping size [%d] are not matched.", + weight_ndim, + weight_dims_mapping.size())); + + PADDLE_ENFORCE_EQ( + bias_ndim, + 1, + common::errors::InvalidArgument( + "LinearV2, The ndim of bias should be 1, but got [%d].", bias_ndim)); + + VLOG(4) << "LinearV2SPMDRule InferForward Inputs: "; + VLOG(4) << "input shape: [" << str_join(ori_input_shape) + << "], input_dims_mapping: [" << str_join(input_dims_mapping) << "];"; + VLOG(4) << "weight shape: [" << str_join(ori_weight_shape) + << "], weight_dims_mapping: [" << str_join(weight_dims_mapping) + << "];"; + VLOG(4) << "bias shape: [" << str_join(ori_bias_shape) + << "], bias_dims_mapping: [" << str_join(bias_dims_mapping) << "];"; + // Step1: build Einsum Notation + std::string input_axes; + std::string weight_axes; + std::string out_axes; + FillMatmulPartOperandNotation( + input_ndim, weight_ndim, &input_axes, &weight_axes, &out_axes); + + // Step2.1: Sharding Merge + std::pair> x_pair(input_axes, + input_dims_mapping); + std::pair> y_pair(weight_axes, + weight_dims_mapping); + auto axis_to_dim_map = ShardingMergeForTensors({x_pair, y_pair}); + + // Step2.2: Infer Output's Dims Mapping. + TensorDistAttr output_dist_attr_dst = + CopyTensorDistAttrForOutput(input_dist_attr_src); + std::vector out_dims_mapping; + out_dims_mapping.reserve(out_axes.size()); + for (size_t i = 0; i < out_axes.size(); ++i) { + out_dims_mapping.push_back(axis_to_dim_map[out_axes.substr(i, 1)]); + } + output_dist_attr_dst.set_dims_mapping(out_dims_mapping); + + // Step2.3: Merge and get Inputs' New Dims Mapping. 
+ auto x_shape = common::vectorize(input.dims()); + auto y_shape = common::vectorize(weight.dims()); + + TensorDistAttr x_dist_attr_dst = GetMatmulPartInferredDistAttr( + input_dist_attr_src, x_shape, input_axes, axis_to_dim_map, false); + TensorDistAttr y_dist_attr_dst = GetMatmulPartInferredDistAttr( + weight_dist_attr_src, y_shape, weight_axes, axis_to_dim_map, false); + TensorDistAttr bias_dist_attr_dst = + CopyTensorDistAttrForOutput(bias_dist_attr_src); + bias_dist_attr_dst.set_dims_mapping( + std::vector{output_dist_attr_dst.dims_mapping().back()}); + + // Step2.3: Handle Partial + // Step2.3.1 Output Partial + std::vector partial_on_dims = + ResoluteOutputPartialDimension(axis_to_dim_map, out_axes); + output_dist_attr_dst.set_partial_status(partial_on_dims); + + if (output_dist_attr_dst.is_partial()) { + bias_dist_attr_dst.set_partial_status( + output_dist_attr_dst.partial_status()); + if (!IsPartialLegal(bias_dist_attr_dst) || + !IsPartialLegal(output_dist_attr_dst)) { + VLOG(4) << "LinearV2 partial output illegal, force set output " + "to replicated."; + output_dist_attr_dst.clean_partial_status(); + bias_dist_attr_dst.clean_partial_status(); + SetTensorDistAttrReplicated(&x_dist_attr_dst, input_ndim); + SetTensorDistAttrReplicated(&y_dist_attr_dst, weight_ndim); + SetTensorDistAttrReplicated(&bias_dist_attr_dst, bias_ndim); + SetTensorDistAttrReplicated(&output_dist_attr_dst, out_axes.size()); + } + } + TensorDistAttr output_reserve_dist_attr_dst = + CopyTensorDistAttrForOutput(output_dist_attr_dst); + VLOG(4) << "LinearV2SPMDRule InferForward: " + << "Einsum notation: [" << input_axes << "," << weight_axes << " --> " + << out_axes << "+" << out_axes.back() << "]. " << std::endl; + LogInputDistAttr( + "input", ori_input_shape, input_dist_attr_src, x_dist_attr_dst); + LogInputDistAttr( + "weight", ori_weight_shape, weight_dist_attr_src, y_dist_attr_dst); + LogInputDistAttr( + "Bias", ori_bias_shape, bias_dist_attr_src, bias_dist_attr_dst); + LogOutputDistAttr("Output", output_dist_attr_dst); + + return {{x_dist_attr_dst, y_dist_attr_dst, bias_dist_attr_dst}, + {output_dist_attr_dst, output_reserve_dist_attr_dst}}; +} +SpmdInfo LinearV2InferSpmd(const DistMetaTensor& input, + const DistMetaTensor& weight, + const DistMetaTensor& bias) { + return LinearV2InferSpmdBase(input, weight, bias); +} +} // namespace phi::distributed diff --git a/paddle/phi/infermeta/spmd_rules/linear_v2.h b/paddle/phi/infermeta/spmd_rules/linear_v2.h new file mode 100644 index 00000000000000..ed52f61cc8054a --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/linear_v2.h @@ -0,0 +1,30 @@ +/* Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include + +#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h" +#include "paddle/phi/core/distributed/type_defs.h" +#include "paddle/phi/infermeta/spmd_rules/fused_gemm_epilogue.h" + +namespace phi { +namespace distributed { +SpmdInfo LinearV2InferSpmdBase(const DistMetaTensor& input, + const DistMetaTensor& weight, + const DistMetaTensor& bias); +SpmdInfo LinearV2InferSpmd(const DistMetaTensor& input, + const DistMetaTensor& weight, + const DistMetaTensor& bias); +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/rules.cc b/paddle/phi/infermeta/spmd_rules/rules.cc index ae7af0f90f2c03..d5f29f1e509d13 100644 --- a/paddle/phi/infermeta/spmd_rules/rules.cc +++ b/paddle/phi/infermeta/spmd_rules/rules.cc @@ -814,6 +814,10 @@ PD_REGISTER_SPMD_RULE( fused_gemm_epilogue, PD_INFER_SPMD(phi::distributed::FusedGemmEpilogueInferSpmdBase)); +// linear_v2 +PD_REGISTER_SPMD_RULE(linear_v2, + PD_INFER_SPMD(phi::distributed::LinearV2InferSpmdBase)); + // take_along_axis PD_REGISTER_SPMD_RULE( take_along_axis, diff --git a/paddle/phi/infermeta/spmd_rules/rules.h b/paddle/phi/infermeta/spmd_rules/rules.h index ff47ee4acea09f..2c3c4bdce04f87 100644 --- a/paddle/phi/infermeta/spmd_rules/rules.h +++ b/paddle/phi/infermeta/spmd_rules/rules.h @@ -59,6 +59,7 @@ limitations under the License. */ #include "paddle/phi/infermeta/spmd_rules/instance_norm.h" #include "paddle/phi/infermeta/spmd_rules/label_smooth.h" #include "paddle/phi/infermeta/spmd_rules/layer_norm.h" +#include "paddle/phi/infermeta/spmd_rules/linear_v2.h" #include "paddle/phi/infermeta/spmd_rules/logsumexp.h" #include "paddle/phi/infermeta/spmd_rules/matmul.h" #include "paddle/phi/infermeta/spmd_rules/mean_all.h" diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index 20ac55035f6fcc..2a114bb96b47ce 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -1550,6 +1550,79 @@ void LerpInferMeta(const MetaTensor& x, out->share_lod(x); } +void LinearV2InferMeta(const MetaTensor& input, + const MetaTensor& weight, + const MetaTensor& bias, + MetaTensor* out, + MetaConfig config) { + const auto& input_dims = input.dims(); + const auto& weight_dims = weight.dims(); + const int64_t weight_ndim = weight.dims().size(); + const bool is_bias_need_broadcast = bias.numel() == 1; + const bool is_valid_bias = + is_bias_need_broadcast || bias.numel() == weight.dims()[weight_ndim - 1]; + + PADDLE_ENFORCE_EQ(weight_dims.size(), + 2, + common::errors::InvalidArgument( + "The Input tensor Y's dimension of LinearV2Op " + "should be 2, but got %d.", + weight_dims.size())); + PADDLE_ENFORCE_GE(input_dims.size(), + 1, + common::errors::InvalidArgument( + "The Input tensor X's dimension of LinearV2Op " + "should be >= 1, but got %d.", + input_dims.size())); + PADDLE_ENFORCE_LE( + bias.dims().size(), + 1, + common::errors::InvalidArgument("Bias must be at most 1-D.")); + + PADDLE_ENFORCE_EQ(is_valid_bias, + true, + common::errors::InvalidArgument( + "Bias must be equal to (or broadcastable to) the " + "last dimension of weight.")); + + // regard [k] x [k, n] -> [n] + if (input_dims.size() == 1) { + out->set_dims(common::make_ddim({weight_dims[1]})); + out->set_dtype(input.dtype()); + return; + } + + auto input_mat_dims = + common::flatten_to_2d(input_dims, input_dims.size() - 1); + + auto input_rank = input_dims.size(); + int64_t K_from_input = input_mat_dims[1]; + int64_t K_from_weight = weight_dims[0]; + const
bool check_dim = + (!config.is_runtime && K_from_input != -1) || config.is_runtime; + if (check_dim) { + PADDLE_ENFORCE_EQ( + K_from_input, + K_from_weight, + common::errors::InvalidArgument( + "The last dimension of X should be equal with Y's first dimension." + "But received X[-1] = [%d], Y[0] = [%d].", + K_from_input, + K_from_weight)); + } + std::vector out_dims; + out_dims.reserve(input_rank); + + for (int i = 0; i + 2 < input_rank; ++i) { + out_dims.push_back(input_dims[i]); + } + out_dims.push_back(input_dims[input_rank - 2]); + + out_dims.push_back(weight_dims[1]); + out->set_dims(common::make_ddim(out_dims)); + out->set_dtype(input.dtype()); +} + void LinspaceInferMeta(const MetaTensor& start, const MetaTensor& stop, const MetaTensor& number, diff --git a/paddle/phi/infermeta/ternary.h b/paddle/phi/infermeta/ternary.h index d3b35e5462e49d..2cf3f0374d1183 100644 --- a/paddle/phi/infermeta/ternary.h +++ b/paddle/phi/infermeta/ternary.h @@ -274,6 +274,12 @@ PADDLE_API void LerpInferMeta(const MetaTensor& x, const MetaTensor& weight, MetaTensor* out); +PADDLE_API void LinearV2InferMeta(const MetaTensor& input, + const MetaTensor& weight, + const MetaTensor& bias, + MetaTensor* out, + MetaConfig config = MetaConfig()); + PADDLE_API void LinspaceRawInferMeta(const MetaTensor& start, const MetaTensor& stop, const MetaTensor& number, diff --git a/paddle/phi/kernels/cpu/linear_v2_grad_kernel.cc b/paddle/phi/kernels/cpu/linear_v2_grad_kernel.cc new file mode 100644 index 00000000000000..51199623e92be8 --- /dev/null +++ b/paddle/phi/kernels/cpu/linear_v2_grad_kernel.cc @@ -0,0 +1,18 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/linear_v2_grad_kernel.h" + +PD_REGISTER_KERNEL( + linear_v2_grad, CPU, ALL_LAYOUT, phi::LinearV2GradKernel, float, double) {} diff --git a/paddle/phi/kernels/cpu/linear_v2_kernel.cc b/paddle/phi/kernels/cpu/linear_v2_kernel.cc new file mode 100644 index 00000000000000..ca0ada016ad2ca --- /dev/null +++ b/paddle/phi/kernels/cpu/linear_v2_kernel.cc @@ -0,0 +1,67 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#include +#include +#include +#include "glog/logging.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" + +#include "paddle/common/enforce.h" +#include "paddle/phi/kernels/addmm_kernel.h" +#include "paddle/phi/kernels/elementwise_add_kernel.h" +#include "paddle/phi/kernels/impl/matmul_kernel_impl.h" +#include "paddle/phi/kernels/linear_v2_kernel.h" +#include "paddle/phi/kernels/reshape_kernel.h" +#include "paddle/phi/kernels/tile_kernel.h" + +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/scope_guard.h" + +namespace phi { + +template +void LinearV2Kernel(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& weight, + const DenseTensor& bias, + DenseTensor* out) { + dev_ctx.template Alloc(out); + if (out->numel() == 0) { + return; + } + + // When in CPU, we use legacy linear_logic by default. + // TODO(Pan Zhaowu): Adding more efficient kernel for CPU. + std::vector input_dims_vec = common::vectorize(input.dims()); + std::vector weight_dims_vec = common::vectorize(weight.dims()); + + MatMulFunction(dev_ctx, + input, + weight, + input_dims_vec, + weight_dims_vec, + out, + false, + false); + AddKernel(dev_ctx, *out, bias, out); +} + +} // namespace phi + +PD_REGISTER_KERNEL( + linear_v2, CPU, ALL_LAYOUT, phi::LinearV2Kernel, float, double) {} diff --git a/paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h b/paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h index 8f51275603e0e2..6c5016513d531e 100644 --- a/paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h +++ b/paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h @@ -32,6 +32,7 @@ limitations under the License. */ COMMON_DECLARE_int64(cublaslt_exhaustive_search_times); COMMON_DECLARE_bool(enable_blaslt_global_search); +COMMON_DECLARE_bool(use_legacy_linear); #endif namespace phi { @@ -466,7 +467,9 @@ struct CublasLtBase { // NOTE(limingshu): As workspace_size varies from different DL framework, // I wonder is there any smarter idea for workspace setting, currently I // just followed the settings from the NVIDIA colleague`s setting. - size_t workspace_size = static_cast(4) * 1024 * 1024; + size_t workspace_size = FLAGS_use_legacy_linear + ? 
static_cast(4) * 1024 * 1024 + : static_cast(1) * 1024 * 1024; phi::Allocator::AllocationPtr workspace = GetWorkspace(dev_ctx, workspace_size); @@ -490,25 +493,56 @@ struct CublasLtBase { cache.SetSubKey(sub_key, reinterpret_cast(best_desc)); } } + cublasLtMatmulHeuristicResult_t heuristic_results = {}; + if (!FLAGS_use_legacy_linear) { + cublasLtMatmulPreference_t preference; + PADDLE_ENFORCE_GPU_SUCCESS( + dynload::cublasLtMatmulPreferenceCreate(&preference)); + PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulPreferenceSetAttribute( + preference, + CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &workspace_size, + sizeof(workspace_size))); + + int returned_results = 0; + PADDLE_ENFORCE_GPU_SUCCESS( + dynload::cublasLtMatmulAlgoGetHeuristic(cublaslt_handle, + desc->op_desc, + desc->y_desc, + desc->x_desc, + desc->out_desc, + desc->out_desc, + preference, + 1, + &heuristic_results, + &returned_results)); + PADDLE_ENFORCE_GT( + returned_results, + 0, + common::errors::Unavailable("No GEMM algorithm available.")); + + PADDLE_ENFORCE_GPU_SUCCESS( + dynload::cublasLtMatmulPreferenceDestroy(preference)); + } VLOG(7) << "[Impl CublasltDescriptor] "; - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmul(cublaslt_handle, - desc->op_desc, - static_cast(&alpha), - y_ptr, - desc->y_desc, - x_ptr, - desc->x_desc, - static_cast(&beta), - out_ptr, - desc->out_desc, - out_ptr, - desc->out_desc, - desc->algo, - workspace->ptr(), - workspace_size, - dev_ctx.stream())); + PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmul( + cublaslt_handle, + desc->op_desc, + static_cast(&alpha), + y_ptr, + desc->y_desc, + x_ptr, + desc->x_desc, + static_cast(&beta), + out_ptr, + desc->out_desc, + out_ptr, + desc->out_desc, + FLAGS_use_legacy_linear ? desc->algo : &heuristic_results.algo, + workspace->ptr(), + workspace_size, + dev_ctx.stream())); } static void SearchBestAlgo(const phi::GPUContext& dev_ctx, diff --git a/paddle/phi/kernels/gpu/linear_v2_grad_kernel.cu b/paddle/phi/kernels/gpu/linear_v2_grad_kernel.cu new file mode 100644 index 00000000000000..e010d6460fa64b --- /dev/null +++ b/paddle/phi/kernels/gpu/linear_v2_grad_kernel.cu @@ -0,0 +1,23 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "paddle/phi/kernels/linear_v2_grad_kernel.h" + +PD_REGISTER_KERNEL(linear_v2_grad, + GPU, + ALL_LAYOUT, + phi::LinearV2GradKernel, + float, + double, + phi::float16, + phi::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/linear_v2_kernel.cu b/paddle/phi/kernels/gpu/linear_v2_kernel.cu new file mode 100644 index 00000000000000..7f000ab1c83ef7 --- /dev/null +++ b/paddle/phi/kernels/gpu/linear_v2_kernel.cu @@ -0,0 +1,173 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include +#include +#include +#include "glog/logging.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" + +#include "paddle/common/enforce.h" +#include "paddle/phi/kernels/addmm_kernel.h" +#include "paddle/phi/kernels/elementwise_add_kernel.h" +#include "paddle/phi/kernels/impl/matmul_kernel_impl.h" +#include "paddle/phi/kernels/linear_v2_kernel.h" +#include "paddle/phi/kernels/reshape_kernel.h" +#include "paddle/phi/kernels/tile_kernel.h" + +#ifdef PADDLE_WITH_HIP +#include +#include +#else +#include // NOLINT +#include "cuda.h" // NOLINT +#endif + +#if (defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060) || \ + defined(PADDLE_WITH_HIP) + +#include "paddle/common/flags.h" +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/common/memory_utils.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/scope_guard.h" +#include "paddle/utils/optional.h" +#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060 +#include "paddle/phi/backends/dynload/cublasLt.h" +#include "paddle/phi/backends/gpu/cuda/cuda_helper.h" +#include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h" +#elif defined(PADDLE_WITH_HIP) +#include "paddle/phi/backends/dynload/hipblasLt.h" +#include "paddle/phi/backends/gpu/rocm/rocm_helper.h" +#include "paddle/phi/kernels/funcs/blas/blaslt_impl.hip.h" +#endif + +#endif +COMMON_DECLARE_bool(use_legacy_linear); + +namespace phi { + +template +void LinearV2Kernel(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& weight, + const DenseTensor& bias, + DenseTensor* out) { + dev_ctx.template Alloc(out); + if (out->numel() == 0) { + return; + } + +// broadcast bias, reshape input, run_fuse, reshape output +#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION > 11060 && \ + !defined(PADDLE_WITH_HIP) && !defined(_WIN32) + if (!FLAGS_use_legacy_linear) { + VLOG(10) << "Use LinearV2Kernel with cublaslt"; + const auto out_dim_original = out->dims(); + const auto [M, N, K] = canonicalize_dims(input, weight); + VLOG(10) << "M: " << M << ", N: " << N << ", K: " << K; + + DenseTensor input_processed; + DenseTensor weight_processed; + DenseTensor output_processed; + phi::ReshapeKernel(dev_ctx, input, {M, K}, &input_processed); + phi::ReshapeKernel(dev_ctx, weight, {K, N}, &weight_processed); + out->Resize(common::make_ddim({M, N})); + VLOG(10) << "input_processed: " << input_processed.dims() + << ", weight_processed: " << weight_processed.dims() + << ", output_processed: " << out->dims(); + + if (N > 1 && K > 1) { + DenseTensor bias_processed; + if (bias.numel() != N) { + // only broadcast to 1D bias whatsoever + // pass1: scalar to 1D + phi::TileKernel(dev_ctx, bias, {N}, &bias_processed); + } else { + bias_processed = bias; + } + // CublasLt path with bias add epilogue + phi::funcs::LinearWithCublasLt::Run( + dev_ctx, + &input, + &weight, + out, + static_cast(bias_processed.data()), + nullptr, + M, + N, + K, + false, + false, + 
phi::funcs::MatmulFusedType::kMatmulBias); + } else { + DenseTensor bias_processed; + if (bias.numel() != (M * N)) { + phi::ReshapeKernel( + dev_ctx, bias, {1, bias.numel()}, &bias_processed); + VLOG(10) << "bias.dim(): " << bias.dims(); + VLOG(10) << "M*N: " << M * N; + VLOG(10) << "bias tiling and addmm calculating"; + // only broadcast to 1D bias whatsoever + phi::TileKernel( + dev_ctx, bias_processed, {M, 1}, &bias_processed); + VLOG(10) << "bias_processed.dims(): " << bias_processed.dims(); + } else { + bias_processed = bias; + } + phi::AddmmKernel(dev_ctx, + bias_processed, + input_processed, + weight_processed, + 1.0f, + 1.0f, + out); + } + VLOG(10) << "linear calculate complete"; + out->Resize(out_dim_original); + } else // NOLINT +#endif + // Fallback logic for legacy CUDA version or other hardware. + // Or specified by user to use a legacy behaviour. + { // NOLINT + // NOTE(Pan Zhaowu): Fallback logic for legacy CUDA version or DCU. + std::vector input_dims_vec = common::vectorize(input.dims()); + std::vector weight_dims_vec = + common::vectorize(weight.dims()); + + MatMulFunction(dev_ctx, + input, + weight, + input_dims_vec, + weight_dims_vec, + out, + false, + false); + AddKernel(dev_ctx, *out, bias, out); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(linear_v2, + GPU, + ALL_LAYOUT, + phi::LinearV2Kernel, + float, + double, + phi::float16, + phi::bfloat16) {} diff --git a/paddle/phi/kernels/linear_v2_grad_kernel.h b/paddle/phi/kernels/linear_v2_grad_kernel.h new file mode 100644 index 00000000000000..6fab0e08665166 --- /dev/null +++ b/paddle/phi/kernels/linear_v2_grad_kernel.h @@ -0,0 +1,60 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#pragma once +#include +#include +#include +#include "glog/logging.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/funcs/elementwise_base.h" +#include "paddle/phi/kernels/funcs/reduce_function.h" +#include "paddle/phi/kernels/impl/matmul_grad_kernel_impl.h" +#include "paddle/phi/kernels/reduce_sum_kernel.h" + +#include "paddle/common/flags.h" +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void LinearV2GradKernel(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& weight, + const DenseTensor& bias, + const DenseTensor& out_grad, + DenseTensor* input_grad, + DenseTensor* weight_grad, + DenseTensor* bias_grad) { + phi::MatmulGradKernel( + dev_ctx, input, weight, out_grad, false, false, input_grad, weight_grad); + + if (bias_grad && bias.numel() != 0) { + if (out_grad.numel() != bias_grad->numel()) { + dev_ctx.template Alloc(bias_grad); + std::vector reduce_dims = + funcs::GetReduceDim(bias.dims(), out_grad.dims(), -1); + phi::SumKernel( + dev_ctx, out_grad, reduce_dims, out_grad.dtype(), false, bias_grad); + bias_grad->Resize(bias.dims()); + } else { + phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, bias_grad); + bias_grad->Resize(bias.dims()); + } + } +} + +} // namespace phi diff --git a/paddle/phi/kernels/linear_v2_kernel.h b/paddle/phi/kernels/linear_v2_kernel.h new file mode 100644 index 00000000000000..8eb17f9ca329e0 --- /dev/null +++ b/paddle/phi/kernels/linear_v2_kernel.h @@ -0,0 +1,47 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/enforce.h" + +namespace phi { + +// we don't receive 2+d tensor as weight +inline std::tuple canonicalize_dims( + const DenseTensor& input, const DenseTensor& weight) { + const auto x_dims = input.dims(); + const auto y_dims = weight.dims(); + const int64_t N = y_dims.size() < 2 ? 1 : y_dims[y_dims.size() - 1]; + const int64_t K = y_dims.size() < 2 ? y_dims[0] : y_dims[y_dims.size() - 2]; + + int64_t M = x_dims.size() >= 2 ? 
x_dims[x_dims.size() - 2] : 1; + if (x_dims.size() > 2) { + // Accumulate the batch dims for input + for (int64_t i = 0; i < x_dims.size() - 2; ++i) { + M *= x_dims[i]; + } + } + + return {M, N, K}; +} + +template +void LinearV2Kernel(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& weight, + const DenseTensor& bias, + DenseTensor* out); +} // namespace phi diff --git a/paddle/phi/ops/yaml/backward.yaml b/paddle/phi/ops/yaml/backward.yaml index 53e87720ffb065..c581103e62895b 100644 --- a/paddle/phi/ops/yaml/backward.yaml +++ b/paddle/phi/ops/yaml/backward.yaml @@ -2054,6 +2054,25 @@ data_transform : skip_transform : out_size, size_tensor, scale_tensor +- backward_op : linear_v2_double_grad + forward: linear_v2_grad (Tensor input, Tensor weight, Tensor bias, Tensor grad_out) -> Tensor(grad_input), Tensor(grad_weight), Tensor(grad_bias) + args : (Tensor input, Tensor weight, Tensor bias, Tensor grad_out, Tensor grad_input_grad, Tensor grad_weight_grad, Tensor grad_bias_grad) + output: Tensor(input_grad), Tensor(weight_grad), Tensor(bias_grad), Tensor(grad_out_grad) + infer_meta : + func : GeneralQuaternaryGradInferMeta + param : [input, weight, bias, grad_out] + composite: linear_v2_double_grad(input, weight, bias, grad_out, grad_input_grad, grad_weight_grad, grad_bias_grad, input_grad, weight_grad, bias_grad, grad_out_grad) + +- backward_op : linear_v2_grad + forward : linear_v2 (Tensor input, Tensor weight, Tensor bias) -> Tensor(out) + args: (Tensor input, Tensor weight, Tensor bias, Tensor out_grad) + output: Tensor(input_grad), Tensor(weight_grad), Tensor(bias_grad) + infer_meta : + func : LinearV2GradInferMeta + kernel : + func : linear_v2_grad + backward: linear_v2_double_grad + - backward_op : log10_grad forward : log10 (Tensor x) -> Tensor(out) args : (Tensor x, Tensor out_grad) diff --git a/paddle/phi/ops/yaml/op_compat.yaml b/paddle/phi/ops/yaml/op_compat.yaml index fa4999de08ee90..13af4c098d70c8 100755 --- a/paddle/phi/ops/yaml/op_compat.yaml +++ b/paddle/phi/ops/yaml/op_compat.yaml @@ -2307,6 +2307,9 @@ extra : attrs : [bool use_mkldnn = false, bool use_onednn = false] +- op : linear_v2 + backward : linear_v2_grad, linear_v2_double_grad + - op : linspace inputs : {start : Start, stop : Stop, number : Num} diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index 7ffec5e5a4deeb..db1023fa7229ad 100644 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -3180,6 +3180,18 @@ skip_transform : out_size, size_tensor, scale_tensor interfaces : paddle::dialect::InferSymbolicShapeInterface +- op : linear_v2 + args : (Tensor input, Tensor weight, Tensor bias) + output : Tensor(out) + infer_meta : + func : LinearV2InferMeta + spmd_rule : LinearV2InferSpmd + kernel : + func : linear_v2 + data_type : input + backward : linear_v2_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface + - op : linspace args : (Tensor start, Tensor stop, Tensor number, DataType dtype, Place place) output : Tensor(out) diff --git a/python/paddle/amp/amp_lists.py b/python/paddle/amp/amp_lists.py index 4b7914a0acd1cd..09ef995c51c027 100644 --- a/python/paddle/amp/amp_lists.py +++ b/python/paddle/amp/amp_lists.py @@ -22,6 +22,7 @@ 'einsum', 'matmul', 'matmul_v2', + 'linear_v2', 'max_pool2d_with_index', 'mul', 'fused_gemm_epilogue', diff --git a/python/paddle/autograd/backward_utils.py b/python/paddle/autograd/backward_utils.py index f103e1ae7e6f62..32c5c40c3ed053 100644 --- a/python/paddle/autograd/backward_utils.py +++ 
b/python/paddle/autograd/backward_utils.py @@ -66,6 +66,7 @@ "pd_op.log", "pd_op.logcumsumexp", "pd_op.logsumexp", + "pd_op.linear_v2", "pd_op.matmul", "pd_op.max", "pd_op.maximum", diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index 675208d17faae1..3d3e8e5027e4bc 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -2529,50 +2529,61 @@ def linear( [ 1.08524013, 1.08524013, 1.08524013, 1.08524013], [-0.67769694, -0.67769694, -0.67769694, -0.67769694]]) """ - if in_dynamic_mode(): - # TODO(jiabin): using addmm for fast forward route - return _C_ops.linear(x, weight, bias) + # If not specified by user to use legacy linear, or not CUDA compatible, we fallback. - elif in_pir_mode(): - out = _C_ops.matmul(x, weight, False, False) - if bias is not None: - return _C_ops.add(out, bias) + if ( + paddle.get_flags("FLAGS_use_legacy_linear")["FLAGS_use_legacy_linear"] + or not paddle.is_compiled_with_cuda() + or not in_dynamic_or_pir_mode() + ): + if in_dynamic_mode(): + return _C_ops.linear(x, weight, bias) + + elif in_pir_mode(): + out = _C_ops.matmul(x, weight, False, False) + if bias is not None: + return _C_ops.add(out, bias) + else: + return out else: - return out - else: - helper = LayerHelper('linear', **locals()) - dtype = x.dtype + helper = LayerHelper('linear', **locals()) + dtype = x.dtype - check_variable_and_dtype( - x, 'x', ["uint16", 'float16', 'float32', 'float64'], 'linear' - ) - check_dtype( - dtype, - 'dtype', - ["uint16", 'float16', 'float32', 'float64'], - 'linear', - ) + check_variable_and_dtype( + x, 'x', ["uint16", 'float16', 'float32', 'float64'], 'linear' + ) + check_dtype( + dtype, + 'dtype', + ["uint16", 'float16', 'float32', 'float64'], + 'linear', + ) - inputs = {'X': [x], 'Y': [weight]} - attrs = {'trans_x': False, 'trans_y': False} - tmp = helper.create_variable_for_type_inference(dtype) - helper.append_op( - type='matmul_v2', - inputs=inputs, - outputs={'Out': tmp}, - attrs=attrs, - ) - if bias is not None: - res = helper.create_variable_for_type_inference(dtype) + inputs = {'X': [x], 'Y': [weight]} + attrs = {'trans_x': False, 'trans_y': False} + tmp = helper.create_variable_for_type_inference(dtype) helper.append_op( - type='elementwise_add', - inputs={'X': [tmp], 'Y': [bias]}, - outputs={'Out': [res]}, - attrs={'axis': -1}, + type='matmul_v2', + inputs=inputs, + outputs={'Out': tmp}, + attrs=attrs, ) + if bias is not None: + res = helper.create_variable_for_type_inference(dtype) + helper.append_op( + type='elementwise_add', + inputs={'X': [tmp], 'Y': [bias]}, + outputs={'Out': [res]}, + attrs={'axis': -1}, + ) + else: + res = tmp + return res + else: + if bias is not None: + return _C_ops.linear_v2(x, weight, bias) else: - res = tmp - return res + return _C_ops.matmul(x, weight) def label_smooth( diff --git a/test/amp/test_amp_api.py b/test/amp/test_amp_api.py index 10843a00f99298..691f38e0c6b67b 100644 --- a/test/amp/test_amp_api.py +++ b/test/amp/test_amp_api.py @@ -23,6 +23,8 @@ from paddle.base import core from paddle.static import amp +paddle.set_flags({"FLAGS_use_legacy_linear": True}) + @unittest.skipIf( not core.is_compiled_with_cuda() and not core.is_compiled_with_xpu(), diff --git a/test/amp/test_amp_master_grad.py b/test/amp/test_amp_master_grad.py index 9e646ef575d50b..8eb50136dbb6b5 100644 --- a/test/amp/test_amp_master_grad.py +++ b/test/amp/test_amp_master_grad.py @@ -19,6 +19,8 @@ import paddle from paddle.base import core 
+paddle.set_flags({"FLAGS_use_legacy_linear": True}) + class SimpleNet(paddle.nn.Layer): def __init__(self, input_size, output_size): diff --git a/test/amp/test_amp_promote.py b/test/amp/test_amp_promote.py index 76d48e66ca4314..66fc89e376d61f 100644 --- a/test/amp/test_amp_promote.py +++ b/test/amp/test_amp_promote.py @@ -21,6 +21,8 @@ from paddle.base import core from paddle.static import amp +paddle.set_flags({"FLAGS_use_legacy_linear": True}) + @unittest.skipIf( not core.is_compiled_with_cuda() and not core.is_compiled_with_xpu(), diff --git a/test/amp/test_pir_amp.py b/test/amp/test_pir_amp.py index 517e7814d7d493..019f8ae1149408 100644 --- a/test/amp/test_pir_amp.py +++ b/test/amp/test_pir_amp.py @@ -62,9 +62,11 @@ def test_linear_amp_o1(self): for op in main.global_block().ops: if op.name() == 'pd_op.cast': cast_op_count += 1 - np.testing.assert_equal(out1.dtype, core.DataType.FLOAT32) + # NOTE(Pan Zhaowu): After implementation of linear_v2, there's no + # need for mix-precision add op applies to intermediate result. + np.testing.assert_equal(out1.dtype, core.DataType.FLOAT16) np.testing.assert_equal(out2.dtype, core.DataType.FLOAT32) - np.testing.assert_equal(cast_op_count, 3) + np.testing.assert_equal(cast_op_count, 4) _white_list, _black_list = core._get_amp_op_list() np.testing.assert_equal(len(_white_list), 0) np.testing.assert_equal(len(_black_list), 0) @@ -88,9 +90,11 @@ def test_linear_amp_bf16_o1(self): for op in main.global_block().ops: if op.name() == 'pd_op.cast': cast_op_count += 1 - np.testing.assert_equal(out1.dtype, core.DataType.FLOAT32) + # NOTE(Pan Zhaowu): After implementation of linear_v2, there's no + # need for mix-precision add op applies to intermediate result. + np.testing.assert_equal(out1.dtype, core.DataType.BFLOAT16) np.testing.assert_equal(out2.dtype, core.DataType.FLOAT32) - np.testing.assert_equal(cast_op_count, 3) + np.testing.assert_equal(cast_op_count, 4) _white_list, _black_list = core._get_amp_op_list() np.testing.assert_equal(len(_white_list), 0) np.testing.assert_equal(len(_black_list), 0) diff --git a/test/collective/fleet/test_dygraph_sharding_stage2.py b/test/collective/fleet/test_dygraph_sharding_stage2.py index f300e4d7bc25f2..e48b63ad0cc067 100644 --- a/test/collective/fleet/test_dygraph_sharding_stage2.py +++ b/test/collective/fleet/test_dygraph_sharding_stage2.py @@ -18,6 +18,11 @@ TestMultipleAccelerators, ) +import paddle + +# NOTE(Pan Zhaowu): using legacy linear to fulfill promise of array equal. 
+paddle.set_flags({"FLAGS_use_legacy_linear": True}) + class TestDygraphShardingStage2(TestMultipleAccelerators): # check sharding logic as well as the accuracy with single mode diff --git a/test/dygraph_to_static/test_gradname_parse.py b/test/dygraph_to_static/test_gradname_parse.py index 4dbcdb44a7138b..8e7fb017362b6e 100644 --- a/test/dygraph_to_static/test_gradname_parse.py +++ b/test/dygraph_to_static/test_gradname_parse.py @@ -21,6 +21,9 @@ ) import paddle + +# NOTE(Pan Zhaowu): Using decomp rules to fulfill promise of high-level grad, +paddle.core._set_prim_all_enabled(True) from paddle.nn import BatchNorm, Linear diff --git a/test/ir/pir/cinn/sub_graphs/base.py b/test/ir/pir/cinn/sub_graphs/base.py index 0c3ac360e46e31..c2f2c3616e48ff 100644 --- a/test/ir/pir/cinn/sub_graphs/base.py +++ b/test/ir/pir/cinn/sub_graphs/base.py @@ -18,6 +18,9 @@ import paddle +# NOTE(Pan Zhaowu): using legacy linear to fulfill promise of array equal in test_ast_prim_cinn +paddle.set_flags({"FLAGS_use_legacy_linear": True}) + class TestBase(unittest.TestCase): def setUp(self): diff --git a/test/ir/pir/test_ir_backward.py b/test/ir/pir/test_ir_backward.py index 9e560aba43ce8d..4c7ddefcb1e53d 100644 --- a/test/ir/pir/test_ir_backward.py +++ b/test/ir/pir/test_ir_backward.py @@ -17,6 +17,11 @@ import numpy as np import paddle + +# NOTE(Pan Zhaowu): Using legacy_linear to fulfill promise of high-level grad, +# with no side-effects to other ops. +# linear_v2's decomposed grad is fully tested in test_gradname_parse.py +paddle.set_flags({"FLAGS_use_legacy_linear": True}) from paddle.autograd.backward_utils import ValueDict, ValueSet from paddle.autograd.ir_backward import grad from paddle.base.wrapped_decorator import signature_safe_contextmanager diff --git a/test/legacy_test/check_nan_inf_base_dygraph.py b/test/legacy_test/check_nan_inf_base_dygraph.py index dc83c23a90fb26..7edbafb61cc1eb 100644 --- a/test/legacy_test/check_nan_inf_base_dygraph.py +++ b/test/legacy_test/check_nan_inf_base_dygraph.py @@ -19,6 +19,9 @@ import paddle from paddle import nn +# NOTE(Pan Zhaowu): Using legacy linear to fulfill the hard-coded op_count in test_nan_inf.py, +# which summon this script individually, with horrible design. +paddle.set_flags({"FLAGS_use_legacy_linear": True}) # os.environ["GLOG_vmodule"] = "nan_inf_utils_detail=10" diff --git a/test/legacy_test/test_cumsum_op.py b/test/legacy_test/test_cumsum_op.py index 497e41f606ea43..f00fb54c114e94 100644 --- a/test/legacy_test/test_cumsum_op.py +++ b/test/legacy_test/test_cumsum_op.py @@ -31,6 +31,9 @@ ) import paddle + +# NOTE(Pan Zhaowu): using legacy linear to fulfill promise of hard-coded op numbers in TestTensorAxis. 
+paddle.set_flags({"FLAGS_use_legacy_linear": True}) import paddle.inference as paddle_infer from paddle import base from paddle.base import core diff --git a/test/legacy_test/test_jit_save_load.py b/test/legacy_test/test_jit_save_load.py index 04598ecdbcc6bc..74393718173cd0 100644 --- a/test/legacy_test/test_jit_save_load.py +++ b/test/legacy_test/test_jit_save_load.py @@ -2331,13 +2331,13 @@ def test_save_dtype(self): out = model(data) save_dir = os.path.join(self.temp_dir.name, "test_save_dtype") path = save_dir + "/model" - paddle.jit.save( - model, path, input_spec=[InputSpec([None, 32], dtype='float32')] - ) + with paddle.amp.auto_cast(level='O2'): + paddle.jit.save( + model, path, input_spec=[InputSpec([None, 32], dtype='float32')] + ) loaded_model = paddle.jit.load(path) loaded_model = paddle.amp.decorate(models=loaded_model, level='O2') - with paddle.amp.auto_cast(level='O2'): - loaded_out = loaded_model(data) + loaded_out = loaded_model(data) np.testing.assert_allclose(out.numpy(), loaded_out.numpy(), atol=1e-5) diff --git a/test/legacy_test/test_lookahead.py b/test/legacy_test/test_lookahead.py index 4b095191df1b64..2b7b2080901958 100644 --- a/test/legacy_test/test_lookahead.py +++ b/test/legacy_test/test_lookahead.py @@ -17,6 +17,10 @@ import numpy as np import paddle + +# NOTE(Pan Zhaowu): using legacy linear to fulfill the promise of add_grad op. +paddle.set_flags({"FLAGS_use_legacy_linear": True}) + from paddle import base, nn from paddle.base.framework import in_pir_mode diff --git a/test/legacy_test/test_paddlescience.py b/test/legacy_test/test_paddlescience.py index 0bddb3c4952640..0d424b14ef3c34 100644 --- a/test/legacy_test/test_paddlescience.py +++ b/test/legacy_test/test_paddlescience.py @@ -37,7 +37,7 @@ class TestPaddleSciencemodel(unittest.TestCase): def test_concat(self): - @jit.to_static + @jit.to_static(full_graph=True) def concat(x, y): """abc""" z = paddle.concat([x, y], 0) @@ -54,7 +54,7 @@ def concat(x, y): class TestEularBeam(unittest.TestCase): def test_eular_beam(self): - @jit.to_static + @jit.to_static(full_graph=True) def eular_beam(x): """abc""" z_ = model(x) diff --git a/test/legacy_test/test_pir_translated_layer.py b/test/legacy_test/test_pir_translated_layer.py index bae22c0fb760a6..b72e7c2a7d1626 100644 --- a/test/legacy_test/test_pir_translated_layer.py +++ b/test/legacy_test/test_pir_translated_layer.py @@ -30,6 +30,10 @@ IMAGE_SIZE = 784 CLASS_NUM = 10 +# NOTE(Pan Zhaowu): using legacy linear to fulfill promise of array equal +# in test_inference_and_fine_tuning. 
+paddle.set_flags({"FLAGS_use_legacy_linear": True}) + # define a random dataset class RandomDataset(paddle.io.Dataset): diff --git a/test/tensorrt/test_converter_model_resnet50_move.py b/test/tensorrt/test_converter_model_resnet50_move.py index 0ab1d6b18a5bd5..fec1d3db32d317 100644 --- a/test/tensorrt/test_converter_model_resnet50_move.py +++ b/test/tensorrt/test_converter_model_resnet50_move.py @@ -148,6 +148,18 @@ def test_engine_serialized_path_move(self): output_expected = standardize(output_expected[0]) output_trt = standardize(output_converted[0].numpy()) + print("$#$#!$!#$!#!@$!#$!@#!@$!@$#!@#!") + print("++++++++++++++++++++++++ expected") + print(output_expected) + print("++++++++++++++++++++++++ trt") + print(output_trt) + raise ValueError( + f"\n\n[DEBUG ERROR] 打印两个 Tensor 数据:\n" + f"--- Tensor output_expect (Shape: {output_expected.shape}) ---\n" + f"{output_expected}\n\n" + f"--- Tensor output_trt (Shape: {output_trt.shape}) ---\n" + f"{output_trt}" + ) np.testing.assert_allclose( output_expected, output_trt,