14 changes: 13 additions & 1 deletion paddle/common/flags.cc
@@ -2370,7 +2370,7 @@ PHI_DEFINE_EXPORTED_bool(use_accuracy_compatible_kernel,
/**
* Legacy gemm related FLAG
* Name: FLAGS_use_legacy_gemm
* Since Version: 3.3.0
* Since Version: 3.2.2
* Value Range: bool, default=false
* Example:
* Note: Whether use legacy gemm kernel.
@@ -2379,6 +2379,18 @@ PHI_DEFINE_EXPORTED_bool(use_legacy_gemm,
false,
"Whether use legacy gemm dispatch logics.");

/**
* Legacy gemm related FLAG
* Name: FLAGS_use_legacy_linear
* Since Version: 3.3.1
* Value Range: bool, default=false
* Example:
* Note: Whether use legacy linear kernel.
*/
PHI_DEFINE_EXPORTED_bool(use_legacy_linear,
false,
"Whether use legacy linear dispatch logics.");

/**
* Allocator Compact related FLAG
* Name: FLAGS_enable_compact_mem
@@ -86,6 +86,7 @@
"where_double_grad",
"bmm_double_grad",
"index_put_double_grad",
"linear_v2_double_grad",
"gather_nd_double_grad",
"reshape_double_grad",
"take_along_axis_double_grad",
@@ -57,6 +57,7 @@
"lerp",
"log_loss",
"log_softmax",
"linear_v2",
"mean",
"mean_all",
"meshgrid",
@@ -112,6 +113,7 @@
'layer_norm_grad',
'log_grad',
'matmul_grad',
'linear_v2_grad',
'max_grad',
'maximum_grad',
'mean_grad',
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/op_generator/op_build_gen.py
@@ -70,6 +70,7 @@
'FusedGateAttentionGradInferMeta',
'ResnetBasicBlockInferMeta',
'ResnetBasicBlockGradInferMeta',
'LinearV2InferMeta',
# multiary.h
'AddNInferMeta',
'ApVariadicInferMeta',
@@ -44,5 +44,6 @@
'index_elementwise_get_grad',
'index_elementwise_put_with_tensor_grad',
'index_elementwise_put_grad',
'linear_v2_grad',
'view_shape_grad',
]
@@ -2282,6 +2282,49 @@ bool FusedGemmEpilogueOpInferSymbolicShape(
return true;
}

bool LinearV2OpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const auto &x_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(0));
const auto &x_dims = x_shape_or_data.shape();

const auto &y_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(1));
const auto &y_dims = y_shape_or_data.shape();

const auto &bias_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(2));
const auto &bias_dims = bias_shape_or_data.shape();

size_t x_rank = x_dims.size();
size_t y_rank = y_dims.size();

std::vector<symbol::DimExpr> out_shape;
out_shape.reserve(x_rank);

for (size_t i = 0; i + 2 < x_rank; ++i) {
out_shape.emplace_back(x_dims[i]);
}

symbol::DimExpr out_M = x_dims[x_rank - 2];
symbol::DimExpr out_N = y_dims[y_rank - 1];

out_shape.emplace_back(out_M);
out_shape.emplace_back(out_N);

symbol::DimExpr x_K = x_dims[x_rank - 1];
symbol::DimExpr y_K = y_dims[y_rank - 2];

infer_context->AddEqualCstr(x_K, y_K);
// bias_dims[0] must equal out_N
infer_context->AddEqualCstr(out_N, bias_dims[0]);

infer_context->SetShapeOrDataForValue(op->result(0),
ShapeOrData{TensorExprs(out_shape)});

return true;
}
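
As a concrete reading of the rule above: input [B, M, K], weight [K, N] and bias [N] infer an output of shape [B, M, N], under the constraints K(input) == K(weight) and len(bias) == N. A small numpy illustration of the same shape algebra (not code from this PR):

import numpy as np

x = np.zeros((4, 8, 16))   # [B, M, K]
w = np.zeros((16, 32))     # [K, N]
b = np.zeros((32,))        # [N]

assert x.shape[-1] == w.shape[-2]  # mirrors AddEqualCstr(x_K, y_K)
assert b.shape[0] == w.shape[-1]   # mirrors AddEqualCstr(out_N, bias_dims[0])

out = x @ w + b
print(out.shape)  # (4, 8, 32), i.e. x.shape[:-1] + (w.shape[-1],)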

bool FusedMultiTransformerOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const auto &x_shape_or_data =
@@ -67,6 +67,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(FusedBatchNormAct_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FusedBnAddActivation)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FusedBnAddActivation_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FusedGemmEpilogue)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(LinearV2)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FusedMultiTransformer)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(GenerateProposals)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(GraphKhopSampler)
1 change: 1 addition & 0 deletions paddle/fluid/prim/api/api.yaml
@@ -15,6 +15,7 @@
- exp
- scale
- matmul
- linear_v2
- expand
- sum
- abs
@@ -849,6 +849,33 @@ void add_triple_grad(const paddle::optional<Tensor>& grad_grad_x,
}
}

template <typename T>
void linear_v2_double_grad(const Tensor& input,
const Tensor& weight,
const Tensor& bias,
const Tensor& grad_out,
const paddle::optional<Tensor>& grad_input_grad,
const paddle::optional<Tensor>& grad_weight_grad,
const paddle::optional<Tensor>& grad_bias_grad,
Tensor* input_grad,
Tensor* weight_grad,
Tensor* bias_grad,
Tensor* grad_out_grad) {
matmul_double_grad<T>(input,
weight,
grad_out,
grad_input_grad,
grad_weight_grad,
false,
false,
input_grad,
weight_grad,
grad_out_grad);
if (bias_grad) {
add_double_grad<T>(bias, grad_out, nullptr, grad_bias_grad, -1, bias_grad);
}
}
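
For orientation, the rule above composes matmul_double_grad with the bias term. For out = x @ w + b with incoming perturbations ddx, ddw, ddb, the extra first-order gradients are dout @ ddw^T (into x) and ddx^T @ dout (into w), and the forward-mode derivative of the output is ddx @ w + x @ ddw + ddb. A numpy sketch of those relations (illustrative only, not PR code):

import numpy as np

rng = np.random.default_rng(0)
x, w = rng.standard_normal((3, 4)), rng.standard_normal((4, 5))
dout = rng.standard_normal((3, 5))                    # upstream gradient of out
ddx, ddw = rng.standard_normal((3, 4)), rng.standard_normal((4, 5))
ddb = rng.standard_normal((5,))

x_grad_extra = dout @ ddw.T          # contribution flowing back into input_grad
w_grad_extra = ddx.T @ dout          # contribution flowing back into weight_grad
dout_grad = ddx @ w + x @ ddw + ddb  # grad_out_grad; ddb broadcasts over rows
print(x_grad_extra.shape, w_grad_extra.shape, dout_grad.shape)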

template <typename T>
void subtract_double_grad(const Tensor& y,
const Tensor& grad_out,
1 change: 1 addition & 0 deletions paddle/fluid/primitive/codegen/decomp_vjp_gen.py
@@ -103,6 +103,7 @@
'logsumexp_grad',
'masked_select_grad',
'matmul_grad',
'linear_v2_grad',
'max_grad',
'maximum_grad',
'minimum_grad',
9 changes: 9 additions & 0 deletions paddle/fluid/primitive/decomp_rule/decomp_rule/composite.h
@@ -256,6 +256,15 @@ Tensor bmm_decomp(const Tensor& x, const Tensor& y) {
return matmul<T>(x, y, false, false);
}

template <typename T>
Tensor linear_v2_decomp(const Tensor& input,
const Tensor& weight,
const Tensor& bias) {
Tensor result = matmul<T>(input, weight, false, false);
result = result + bias;
return result;
}
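
The decomposition is simply the affine map out = matmul(input, weight) + bias with a non-transposed [in_features, out_features] weight, i.e. the layout paddle.nn.functional.linear already uses. A quick numeric sketch of that equivalence (illustrative, not part of the diff):

import paddle

x = paddle.randn([4, 8, 16])
w = paddle.randn([16, 32])
b = paddle.randn([32])

decomposed = paddle.matmul(x, w) + b              # what linear_v2_decomp computes
reference = paddle.nn.functional.linear(x, w, b)  # existing functional linear
print(paddle.allclose(decomposed, reference))     # expected: True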

template <typename T>
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> batch_norm_decomp(
const Tensor& x,
15 changes: 15 additions & 0 deletions paddle/fluid/primitive/decomp_rule/decomp_vjp/details.h
@@ -1438,6 +1438,21 @@ void matmul_grad(const Tensor& x,
}
}

template <typename T>
void linear_v2_grad(const Tensor& input,
const Tensor& weight,
const Tensor& bias,
const Tensor& out_grad,
Tensor* input_grad,
Tensor* weight_grad,
Tensor* bias_grad) {
matmul_grad<T>(
input, weight, out_grad, false, false, input_grad, weight_grad);
if (bias_grad) {
add_grad<T>(bias, bias, out_grad, -1, nullptr, bias_grad);
}
}
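
The backward decomposition defers to matmul_grad for input/weight and to add_grad for the bias, which reduces the upstream gradient over the broadcast dimensions. Concretely, for out = x @ w + b: dx = dout @ w^T, dw = x^T @ dout (with leading batch dimensions folded into rows), and db = sum of dout over every axis but the last. A numpy sketch of those formulas, assuming the no-transpose layout (not PR code):

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8, 16))      # [..., M, K]
w = rng.standard_normal((16, 32))        # [K, N]
dout = rng.standard_normal((4, 8, 32))   # shape of the forward output

dx = dout @ w.T                                   # [..., M, K]
dw = x.reshape(-1, 16).T @ dout.reshape(-1, 32)   # [K, N]
db = dout.sum(axis=(0, 1))                        # [N]
print(dx.shape, dw.shape, db.shape)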

template <typename T>
void maximum_grad(const Tensor& x,
const Tensor& y,
62 changes: 62 additions & 0 deletions paddle/phi/infermeta/backward.cc
@@ -492,6 +492,68 @@ void CudnnLSTMGradInferMeta(
UnchangedMultiInferMeta(weight_list.get(), weight_list_grad);
}
}
void LinearV2GradInferMeta(const MetaTensor& input,
const MetaTensor& weight,
const MetaTensor& bias,
const MetaTensor& out_grad,
MetaTensor* input_grad,
MetaTensor* weight_grad,
MetaTensor* bias_grad) {
auto input_dims = input.dims();
auto weight_dims = weight.dims();
auto bias_dims = bias.dims();
auto dout_dims = out_grad.dims();

auto dout_mat_dims = common::flatten_to_2d(dout_dims, dout_dims.size() - 1);

const int64_t input_ndim = input_dims.size();
auto k_from_dout = input_ndim >= 2 ? dout_dims[input_ndim - 2] : 1;
auto k_from_input = input_ndim >= 2 ? input_dims[input_ndim - 2] : 1;

bool check_k =
(k_from_dout < 0 || k_from_input < 0) || (k_from_dout == k_from_input);

if (check_k) {
PADDLE_ENFORCE_EQ(
dout_mat_dims[1],
weight_dims[1],
common::errors::InvalidArgument(
"The last dimension of DOut should be equal with Y's last "
"dimension. But received DOut[-1] = [%d], Y[1] = [%d].",
dout_mat_dims[1],
weight_dims[1]));
}

for (int32_t i = 0; i + 2 < input_dims.size(); ++i) {
if (dout_dims[i] > 0 && input_dims[i] > 0) {
PADDLE_ENFORCE_EQ(
dout_dims[i],
input_dims[i],
common::errors::InvalidArgument(
"The i dimension of DOut should be equal with i dimension of X."
"But received DOut[%d] = [%d], Y[%d] = [%d].",
i,
dout_dims[i],
i,
input_dims[i]));
}
}

if (input_grad) {
input_grad->set_dims(input_dims);
input_grad->set_dtype(input.dtype());
}

if (weight_grad) {
weight_grad->set_dims(weight_dims);
weight_grad->set_dtype(weight.dtype());
}

if (bias_grad) {
bias_grad->set_dims(bias_dims);
bias_grad->set_dtype(bias.dtype());
}
}

void LSTMGradInferMeta(const MetaTensor& input,
const MetaTensor& h0,
8 changes: 8 additions & 0 deletions paddle/phi/infermeta/backward.h
@@ -183,6 +183,14 @@ PADDLE_API void CudnnLSTMGradInferMeta(
MetaTensor* init_c_grad,
std::vector<MetaTensor*> weight_list_grad);

PADDLE_API void LinearV2GradInferMeta(const MetaTensor& input,
const MetaTensor& weight,
const MetaTensor& bias,
const MetaTensor& out_grad,
MetaTensor* input_grad,
MetaTensor* weight_grad,
MetaTensor* bias_grad);

PADDLE_API void LSTMGradInferMeta(const MetaTensor& input,
const MetaTensor& h0,
const MetaTensor& c0,
13 changes: 13 additions & 0 deletions paddle/phi/infermeta/spmd_rules/fused_gemm_epilogue.h
@@ -19,6 +19,19 @@ limitations under the License. */

namespace phi {
namespace distributed {
void FillMatmulPartOperandNotation(const int x_ndim,
const int y_ndim,
std::string* x_axes,
std::string* y_axes,
std::string* out_axes);
TensorDistAttr GetMatmulPartInferredDistAttr(
const TensorDistAttr& origin_dist_attr,
const std::vector<int64_t>& shape,
const std::string& tensor_axis,
const std::unordered_map<std::string, int64_t>& axis_to_dim_map,
bool trans_axis);
void SetTensorDistAttrReplicated(TensorDistAttr* dist_attr, const int ndim);

SpmdInfo FusedGemmEpilogueInferSpmdBase(const DistMetaTensor& x,
const DistMetaTensor& y,
const DistMetaTensor& bias,