57 commits
4892c78
aligned linear & matmul_with_bias, need check incompatible cases.
A-nnonymous Nov 21, 2025
595da6b
aligned high-dim matmul operation, adding flags control and controlli…
A-nnonymous Nov 21, 2025
312b6ff
aligned matmul, start aligning einsum
A-nnonymous Nov 21, 2025
65bfd9c
Add flag
A-nnonymous Nov 21, 2025
c91f672
polish
A-nnonymous Nov 21, 2025
2c306be
fix shape related issues.
A-nnonymous Nov 24, 2025
8c62a41
Finish crash handling
A-nnonymous Nov 24, 2025
3f18079
revert redundant diff
A-nnonymous Nov 24, 2025
fd03fa5
revert redundant diff
A-nnonymous Nov 24, 2025
2595025
restrict influence to only CUDA
A-nnonymous Nov 24, 2025
92f3cb3
Optimized CPU overhead, bypass windows.
A-nnonymous Nov 26, 2025
d7a556a
optimize branch cost
A-nnonymous Nov 27, 2025
93d2451
disable dist tensor case
A-nnonymous Nov 27, 2025
6439f4c
add GPUPlace check
A-nnonymous Nov 28, 2025
ed20e33
fix matmul diff
A-nnonymous Dec 8, 2025
3ac50bc
add flags related logic
A-nnonymous Dec 8, 2025
c2842b8
polish
A-nnonymous Dec 8, 2025
9eaf718
polish
A-nnonymous Dec 8, 2025
bfe5d6e
polish
A-nnonymous Dec 8, 2025
2858364
polish
A-nnonymous Dec 8, 2025
9f5c03d
polish
A-nnonymous Dec 8, 2025
7ee0a72
polish
A-nnonymous Dec 8, 2025
5452edc
polish
A-nnonymous Dec 8, 2025
7b8ad72
polish
A-nnonymous Dec 8, 2025
6666585
Merge remote-tracking branch 'origin/develop' into align_matmul
A-nnonymous Dec 8, 2025
708db20
bypass win32
A-nnonymous Dec 8, 2025
e2b0e9a
Merge branch 'align_matmul' into align_linear
A-nnonymous Dec 8, 2025
034a087
stash
A-nnonymous Dec 8, 2025
a98a620
stash
A-nnonymous Dec 10, 2025
5c1810a
merge align_linear
A-nnonymous Dec 23, 2025
028ae96
align fwd in gpu
A-nnonymous Dec 25, 2025
83a7bd2
fix fwd miscs
A-nnonymous Dec 26, 2025
487024b
fix shape miscs
A-nnonymous Dec 26, 2025
7c9b52c
clean code
A-nnonymous Dec 26, 2025
813713a
fix grad
A-nnonymous Dec 29, 2025
12884e4
recover redundant diff
A-nnonymous Dec 29, 2025
d3d7f58
using legacy linear
A-nnonymous Dec 29, 2025
7a49003
add multi-platform support, polish
A-nnonymous Dec 29, 2025
a45acea
refractor
A-nnonymous Dec 29, 2025
f15e3ad
fix flag and amp rules
A-nnonymous Dec 29, 2025
ddd1508
fix CI
A-nnonymous Dec 30, 2025
83b14e8
polish
A-nnonymous Dec 30, 2025
708855f
fix miscs
A-nnonymous Dec 30, 2025
a7b25fc
add infersymbolics
A-nnonymous Dec 30, 2025
442ddc4
Add metaconfig
A-nnonymous Dec 30, 2025
3b709f1
fix symbolic, move flags
A-nnonymous Dec 31, 2025
049d62f
fix bwd infermeta
A-nnonymous Dec 31, 2025
3eac5fe
Add prim linear_v2_grad
A-nnonymous Dec 31, 2025
bfc67e8
fix fwd decomp
A-nnonymous Dec 31, 2025
4f5146b
add proper fallback to fulfill legacy promise
A-nnonymous Jan 4, 2026
c472892
tmp restrict prim
A-nnonymous Jan 4, 2026
ef08d31
Add inferSPMD, fix CI
A-nnonymous Jan 5, 2026
429d39f
fix ci
A-nnonymous Jan 6, 2026
bc36cc4
use legacy by default.
A-nnonymous Jan 6, 2026
a561325
add tag
A-nnonymous Jan 6, 2026
bfa6c0b
force error
A-nnonymous Jan 6, 2026
e315a08
fix np
A-nnonymous Jan 7, 2026
14 changes: 13 additions & 1 deletion paddle/common/flags.cc
@@ -2370,7 +2370,7 @@ PHI_DEFINE_EXPORTED_bool(use_accuracy_compatible_kernel,
/**
* Legacy gemm related FLAG
* Name: FLAGS_use_legacy_gemm
- * Since Version: 3.3.0
+ * Since Version: 3.2.2
* Value Range: bool, default=false
* Example:
* Note: Whether use legacy gemm kernel.
@@ -2379,6 +2379,18 @@ PHI_DEFINE_EXPORTED_bool(use_legacy_gemm,
false,
"Whether use legacy gemm dispatch logics.");

/**
* Legacy linear related FLAG
* Name: FLAGS_use_legacy_linear
* Since Version: 3.3.1
* Value Range: bool, default=true
* Example:
* Note: Whether to use the legacy linear kernel.
*/
PHI_DEFINE_EXPORTED_bool(use_legacy_linear,
true,
"Whether use legacy linear dispatch logics.");

/**
* Allocator Compact related FLAG
* Name: FLAGS_enable_compact_mem
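Assuming this flag is exported to Python like the other PHI_DEFINE_EXPORTED flags, it could be toggled at runtime; a minimal sketch (the flag name comes from the diff above, everything else is illustrative):

```python
import paddle

# The diff defaults FLAGS_use_legacy_linear to true, so the legacy linear
# dispatch stays active unless the flag is flipped off explicitly.
paddle.set_flags({'FLAGS_use_legacy_linear': False})
print(paddle.get_flags(['FLAGS_use_legacy_linear']))
```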
@@ -86,6 +86,7 @@
"where_double_grad",
"bmm_double_grad",
"index_put_double_grad",
"linear_v2_double_grad",
"gather_nd_double_grad",
"reshape_double_grad",
"take_along_axis_double_grad",
@@ -57,6 +57,7 @@
"lerp",
"log_loss",
"log_softmax",
"linear_v2",
"mean",
"mean_all",
"meshgrid",
@@ -112,6 +113,7 @@
'layer_norm_grad',
'log_grad',
'matmul_grad',
'linear_v2_grad',
'max_grad',
'maximum_grad',
'mean_grad',
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/op_generator/op_build_gen.py
@@ -68,6 +68,7 @@
'FusedGateAttentionGradInferMeta',
'ResnetBasicBlockInferMeta',
'ResnetBasicBlockGradInferMeta',
'LinearV2InferMeta',
# multiary.h
'AddNInferMeta',
'ApVariadicInferMeta',
@@ -44,5 +44,6 @@
'index_elementwise_get_grad',
'index_elementwise_put_with_tensor_grad',
'index_elementwise_put_grad',
'linear_v2_grad',
'view_shape_grad',
]
@@ -2281,6 +2281,49 @@ bool FusedGemmEpilogueOpInferSymbolicShape(
return true;
}

bool LinearV2OpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const auto &x_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(0));
const auto &x_dims = x_shape_or_data.shape();

const auto &y_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(1));
const auto &y_dims = y_shape_or_data.shape();

const auto &bias_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(2));
const auto &bias_dims = bias_shape_or_data.shape();

size_t x_rank = x_dims.size();
size_t y_rank = y_dims.size();

std::vector<symbol::DimExpr> out_shape;
out_shape.reserve(x_rank);

for (size_t i = 0; i + 2 < x_rank; ++i) {
out_shape.emplace_back(x_dims[i]);
}

symbol::DimExpr out_M = x_dims[x_rank - 2];
symbol::DimExpr out_N = y_dims[y_rank - 1];

out_shape.emplace_back(out_M);
out_shape.emplace_back(out_N);

symbol::DimExpr x_K = x_dims[x_rank - 1];
symbol::DimExpr y_K = y_dims[y_rank - 2];

infer_context->AddEqualCstr(x_K, y_K);
// The 1-D bias length (bias_dims[0]) must equal out_N.
infer_context->AddEqualCstr(out_N, bias_dims[0]);

infer_context->SetShapeOrDataForValue(op->result(0),
ShapeOrData{TensorExprs(out_shape)});

return true;
}

bool FusedMultiTransformerOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const auto &x_shape_or_data =
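LinearV2OpInferSymbolicShape above encodes the usual linear shape rule: the leading x dims pass through, M comes from x, N from y, K must match between x and y, and the 1-D bias must match N. A plain-Python sketch of the same rule (hypothetical helper, for illustration only):

```python
def linear_v2_out_shape(x_shape, y_shape, bias_shape):
    """Mirrors LinearV2OpInferSymbolicShape for out = x @ y + bias."""
    assert len(x_shape) >= 2 and len(y_shape) >= 2
    # Equality constraint: contraction dims must agree (x_K == y_K).
    assert x_shape[-1] == y_shape[-2]
    # Equality constraint: the 1-D bias must match out_N.
    assert bias_shape[0] == y_shape[-1]
    # Leading x dims pass through, then [out_M, out_N].
    return list(x_shape[:-2]) + [x_shape[-2], y_shape[-1]]

assert linear_v2_out_shape([8, 16, 32], [32, 64], [64]) == [8, 16, 64]
```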
@@ -67,6 +67,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(FusedBatchNormAct_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FusedBnAddActivation)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FusedBnAddActivation_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FusedGemmEpilogue)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(LinearV2)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FusedMultiTransformer)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(GenerateProposals)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(GraphKhopSampler)
1 change: 1 addition & 0 deletions paddle/fluid/prim/api/api.yaml
@@ -15,6 +15,7 @@
- exp
- scale
- matmul
- linear_v2
- expand
- sum
- abs
@@ -849,6 +849,33 @@ void add_triple_grad(const paddle::optional<Tensor>& grad_grad_x,
}
}

template <typename T>
void linear_v2_double_grad(const Tensor& input,
const Tensor& weight,
const Tensor& bias,
const Tensor& grad_out,
const paddle::optional<Tensor>& grad_input_grad,
const paddle::optional<Tensor>& grad_weight_grad,
const paddle::optional<Tensor>& grad_bias_grad,
Tensor* input_grad,
Tensor* weight_grad,
Tensor* bias_grad,
Tensor* grad_out_grad) {
matmul_double_grad<T>(input,
weight,
grad_out,
grad_input_grad,
grad_weight_grad,
false,
false,
input_grad,
weight_grad,
grad_out_grad);
if (bias_grad) {
add_double_grad<T>(bias, grad_out, nullptr, grad_bias_grad, -1, bias_grad);
}
}

template <typename T>
void subtract_double_grad(const Tensor& y,
const Tensor& grad_out,
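For reference, with out = input @ weight + bias the composition above corresponds to the identity ddout = ddinput @ weight + input @ ddweight + ddbias (the matmul double-grad plus the bias term). A numpy sketch checking that identity numerically (illustrative only, not Paddle code):

```python
import numpy as np

rng = np.random.default_rng(0)
x, w, b = (rng.standard_normal(s) for s in [(4, 3), (3, 5), (5,)])
ddx, ddw, ddb = (rng.standard_normal(s) for s in [(4, 3), (3, 5), (5,)])

# Directional derivative of out = x @ w + b along (ddx, ddw, ddb).
ddout = ddx @ w + x @ ddw + ddb

# Finite-difference check of the same directional derivative.
eps = 1e-6
numeric = ((x + eps * ddx) @ (w + eps * ddw) + (b + eps * ddb)
           - (x @ w + b)) / eps
assert np.allclose(ddout, numeric, atol=1e-4)
```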
1 change: 1 addition & 0 deletions paddle/fluid/primitive/codegen/decomp_vjp_gen.py
@@ -103,6 +103,7 @@
'logsumexp_grad',
'masked_select_grad',
'matmul_grad',
'linear_v2_grad',
'max_grad',
'maximum_grad',
'minimum_grad',
9 changes: 9 additions & 0 deletions paddle/fluid/primitive/decomp_rule/decomp_rule/composite.h
@@ -256,6 +256,15 @@ Tensor bmm_decomp(const Tensor& x, const Tensor& y) {
return matmul<T>(x, y, false, false);
}

template <typename T>
Tensor linear_v2_decomp(const Tensor& input,
const Tensor& weight,
const Tensor& bias) {
Tensor result = matmul<T>(input, weight, false, false);
result = result + bias;
return result;
}

template <typename T>
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> batch_norm_decomp(
const Tensor& x,
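The decomposition above is just a matmul followed by a broadcast add; in numpy terms (illustrative sketch):

```python
import numpy as np

def linear_v2_ref(x, weight, bias):
    # matmul over the trailing two dims, then broadcast the 1-D bias over N.
    return x @ weight + bias

x = np.ones((2, 4, 8))   # [..., M, K]
w = np.ones((8, 16))     # [K, N]
b = np.zeros(16)         # [N]
assert linear_v2_ref(x, w, b).shape == (2, 4, 16)
```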
15 changes: 15 additions & 0 deletions paddle/fluid/primitive/decomp_rule/decomp_vjp/details.h
@@ -1438,6 +1438,21 @@ void matmul_grad(const Tensor& x,
}
}

template <typename T>
void linear_v2_grad(const Tensor& input,
const Tensor& weight,
const Tensor& bias,
const Tensor& out_grad,
Tensor* input_grad,
Tensor* weight_grad,
Tensor* bias_grad) {
matmul_grad<T>(
input, weight, out_grad, false, false, input_grad, weight_grad);
if (bias_grad) {
add_grad<T>(bias, bias, out_grad, -1, nullptr, bias_grad);
}
}

template <typename T>
void maximum_grad(const Tensor& x,
const Tensor& y,
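With no transposes, the VJP composed above reduces to the standard dense-layer gradients: dX = dOut @ W^T, dW accumulates X^T @ dOut over the flattened leading dims, and dBias reduces dOut over every broadcast dimension. A numpy sketch (illustrative only):

```python
import numpy as np

def linear_v2_grad_ref(x, w, dout):
    dx = dout @ w.T                        # same shape as x
    # Flatten leading dims so dw comes out with weight's [K, N] shape.
    x2 = x.reshape(-1, x.shape[-1])
    dout2 = dout.reshape(-1, dout.shape[-1])
    dw = x2.T @ dout2
    db = dout2.sum(axis=0)                 # bias grad: reduce broadcast dims
    return dx, dw, db

x, w = np.ones((2, 4, 8)), np.ones((8, 16))
dx, dw, db = linear_v2_grad_ref(x, w, np.ones((2, 4, 16)))
assert dx.shape == x.shape and dw.shape == w.shape and db.shape == (16,)
```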
62 changes: 62 additions & 0 deletions paddle/phi/infermeta/backward.cc
@@ -492,6 +492,68 @@ void CudnnLSTMGradInferMeta(
UnchangedMultiInferMeta(weight_list.get(), weight_list_grad);
}
}
void LinearV2GradInferMeta(const MetaTensor& input,
const MetaTensor& weight,
const MetaTensor& bias,
const MetaTensor& out_grad,
MetaTensor* input_grad,
MetaTensor* weight_grad,
MetaTensor* bias_grad) {
auto input_dims = input.dims();
auto weight_dims = weight.dims();
auto bias_dims = bias.dims();
auto dout_dims = out_grad.dims();

auto dout_mat_dims = common::flatten_to_2d(dout_dims, dout_dims.size() - 1);

const int64_t input_ndim = input_dims.size();
auto k_from_dout = input_ndim >= 2 ? dout_dims[input_ndim - 2] : 1;
auto k_from_input = input_ndim >= 2 ? input_dims[input_ndim - 2] : 1;

bool check_k =
(k_from_dout < 0 || k_from_input < 0) || (k_from_dout == k_from_input);

if (check_k) {
PADDLE_ENFORCE_EQ(
dout_mat_dims[1],
weight_dims[1],
common::errors::InvalidArgument(
"The last dimension of DOut should be equal with Y's last "
"dimension. But received DOut[-1] = [%d], Y[1] = [%d].",
dout_mat_dims[1],
weight_dims[1]));
}

for (int32_t i = 0; i + 2 < input_dims.size(); ++i) {
if (dout_dims[i] > 0 && input_dims[i] > 0) {
PADDLE_ENFORCE_EQ(
dout_dims[i],
input_dims[i],
common::errors::InvalidArgument(
"The i dimension of DOut should be equal with i dimension of X."
"But received DOut[%d] = [%d], Y[%d] = [%d].",
i,
dout_dims[i],
i,
input_dims[i]));
}
}

if (input_grad) {
input_grad->set_dims(input_dims);
input_grad->set_dtype(input.dtype());
}

if (weight_grad) {
weight_grad->set_dims(weight_dims);
weight_grad->set_dtype(weight.dtype());
}

if (bias_grad) {
bias_grad->set_dims(bias_dims);
bias_grad->set_dtype(bias.dtype());
}
}

void LSTMGradInferMeta(const MetaTensor& input,
const MetaTensor& h0,
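The meta rule above boils down to: validate that dOut agrees with Weight on N and with X on the leading dims (skipping any dynamic dim, encoded as a negative value), then give each requested gradient the dims and dtype of its forward input. A toy sketch of the check (hypothetical helper):

```python
def linear_v2_grad_meta(x_shape, w_shape, b_shape, dout_shape):
    # Dynamic dims (< 0) skip the equality checks, mirroring the guards above.
    if dout_shape[-1] > 0 and w_shape[-1] > 0:
        assert dout_shape[-1] == w_shape[-1]
    for i in range(len(x_shape) - 2):
        if dout_shape[i] > 0 and x_shape[i] > 0:
            assert dout_shape[i] == x_shape[i]
    # Each gradient mirrors its forward input's shape.
    return x_shape, w_shape, b_shape

assert linear_v2_grad_meta([2, 4, 8], [8, 16], [16], [-1, 4, 16]) == \
    ([2, 4, 8], [8, 16], [16])
```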
8 changes: 8 additions & 0 deletions paddle/phi/infermeta/backward.h
@@ -183,6 +183,14 @@ PADDLE_API void CudnnLSTMGradInferMeta(
MetaTensor* init_c_grad,
std::vector<MetaTensor*> weight_list_grad);

PADDLE_API void LinearV2GradInferMeta(const MetaTensor& input,
const MetaTensor& weight,
const MetaTensor& bias,
const MetaTensor& out_grad,
MetaTensor* input_grad,
MetaTensor* weight_grad,
MetaTensor* bias_grad);

PADDLE_API void LSTMGradInferMeta(const MetaTensor& input,
const MetaTensor& h0,
const MetaTensor& c0,
13 changes: 13 additions & 0 deletions paddle/phi/infermeta/spmd_rules/fused_gemm_epilogue.h
@@ -19,6 +19,19 @@ limitations under the License. */

namespace phi {
namespace distributed {
void FillMatmulPartOperandNotation(const int x_ndim,
const int y_ndim,
std::string* x_axes,
std::string* y_axes,
std::string* out_axes);
TensorDistAttr GetMatmulPartInferredDistAttr(
const TensorDistAttr& origin_dist_attr,
const std::vector<int64_t>& shape,
const std::string& tensor_axis,
const std::unordered_map<std::string, int64_t>& axis_to_dim_map,
bool trans_axis);
void SetTensorDistAttrReplicated(TensorDistAttr* dist_attr, const int ndim);

SpmdInfo FusedGemmEpilogueInferSpmdBase(const DistMetaTensor& x,
const DistMetaTensor& y,
const DistMetaTensor& bias,
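These declarations factor the matmul part of the einsum-style SPMD rule out of FusedGemmEpilogue so linear_v2 can reuse it: operands get axis strings (roughly `...mk, kn -> ...mn`) and sharding is propagated by matching letters. A rough sketch of the notation-building idea (hypothetical, not the actual implementation):

```python
import string

def fill_matmul_part_operand_notation(x_ndim, y_ndim):
    # Sketch assumes a 2-D weight, as in the linear case.
    assert y_ndim == 2
    # One unique letter per leading batch axis of x, avoiding m/n/k.
    letters = [c for c in string.ascii_lowercase if c not in 'mnk']
    batch = ''.join(letters[: x_ndim - 2])
    return batch + 'mk', 'kn', batch + 'mn'   # x_axes, y_axes, out_axes

assert fill_matmul_part_operand_notation(3, 2) == ('amk', 'kn', 'amn')
```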