From 4892c78c11772c5fd7487a8022bf0a348e7467a8 Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Fri, 21 Nov 2025 10:43:31 +0800 Subject: [PATCH 01/55] aligned linear & matmul_with_bias, need check incompatible cases. --- paddle/phi/kernels/funcs/blas/blas_impl.cu.h | 2 ++ .../phi/kernels/funcs/blas/blaslt_impl.cu.h | 34 +++++++++++++++++-- python/paddle/nn/functional/common.py | 20 +++++++++-- 3 files changed, 52 insertions(+), 4 deletions(-) diff --git a/paddle/phi/kernels/funcs/blas/blas_impl.cu.h b/paddle/phi/kernels/funcs/blas/blas_impl.cu.h index ae7b67de6d642f..1fc053bd5d85cd 100644 --- a/paddle/phi/kernels/funcs/blas/blas_impl.cu.h +++ b/paddle/phi/kernels/funcs/blas/blas_impl.cu.h @@ -2620,6 +2620,7 @@ void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, int64_t batchCount, int64_t strideA, int64_t strideB) const { + std::cout << "####### HI BGEMM" << std::endl; // Note that cublas follows fortran order, so the order is different from // the cblas convention. int64_t lda = (transA == CblasNoTrans) ? K : M; @@ -2731,6 +2732,7 @@ void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, #endif // CUDA_VERSION >= 9010 T h_alpha = static_cast(alpha); T h_beta = static_cast(beta); + dev_ctx_.CublasCall([&](cublasHandle_t handle) { CUBlas::GEMM_STRIDED_BATCH(handle, cuTransB, diff --git a/paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h b/paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h index cdfe0fb6dc5e32..86a8ae74aaee3a 100644 --- a/paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h +++ b/paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h @@ -466,7 +466,7 @@ struct CublasLtBase { // NOTE(limingshu): As workspace_size varies from different DL framework, // I wonder is there any smarter idea for workspace setting, currently I // just followed the settings from the NVIDIA colleague`s setting. 
- size_t workspace_size = static_cast(4) * 1024 * 1024; + size_t workspace_size = static_cast(1) * 1024 * 1024; phi::Allocator::AllocationPtr workspace = GetWorkspace(dev_ctx, workspace_size); @@ -491,6 +491,36 @@ struct CublasLtBase { } } + cublasLtMatmulPreference_t preference; + PADDLE_ENFORCE_GPU_SUCCESS( + dynload::cublasLtMatmulPreferenceCreate(&preference)); + PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulPreferenceSetAttribute( + preference, + CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &workspace_size, + sizeof(workspace_size))); + + int returned_results = 0; + cublasLtMatmulHeuristicResult_t heuristic_results = {}; + PADDLE_ENFORCE_GPU_SUCCESS( + dynload::cublasLtMatmulAlgoGetHeuristic(cublaslt_handle, + desc->op_desc, + desc->y_desc, + desc->x_desc, + desc->out_desc, + desc->out_desc, + preference, + 1, + &heuristic_results, + &returned_results)); + PADDLE_ENFORCE_GT( + returned_results, + 0, + common::errors::Unavailable("No GEMM algorithm available.")); + + PADDLE_ENFORCE_GPU_SUCCESS( + dynload::cublasLtMatmulPreferenceDestroy(preference)); + VLOG(7) << "[Impl CublasltDescriptor] "; PADDLE_ENFORCE_GPU_SUCCESS( dynload::cublasLtMatmul(cublaslt_handle, @@ -505,7 +535,7 @@ struct CublasLtBase { desc->out_desc, out_ptr, desc->out_desc, - desc->algo, + &heuristic_results.algo, workspace->ptr(), workspace_size, dev_ctx.stream())); diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index 8a96f03d947f2a..07e986913edc50 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -2510,8 +2510,24 @@ def linear( [-0.67769694, -0.67769694, -0.67769694, -0.67769694]]) """ if in_dynamic_mode(): - # TODO(jiabin): using addmm for fast forward route - return _C_ops.linear(x, weight, bias) + if bias is not None: + assert len(bias.shape) == 1, "only support 1D bias" + if weight.shape[0] > 1 and weight.shape[1] > 1: + out, _ = _C_ops.fused_gemm_epilogue( + x, weight, bias, False, False, "none" + ) + else: + bias_reshaped = paddle.repeat_interleave( + bias, x.shape[0], axis=0 + ) + bias_reshaped = paddle.unsqueeze(bias_reshaped, axis=1) + out = paddle.addmm( + bias_reshaped, x, weight, alpha=1.0, beta=1.0 + ) + else: + out = _C_ops.matmul(x, weight, False, False) + + return out elif in_pir_mode(): out = _C_ops.matmul(x, weight, False, False) From 595da6bbae14f80ffb24a8c29579148b91cf939c Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Fri, 21 Nov 2025 13:54:04 +0800 Subject: [PATCH 02/55] aligned high-dim matmul operation, adding flags control and controlling failed cases --- paddle/phi/kernels/funcs/blas/blas_impl.cu.h | 30 +++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/paddle/phi/kernels/funcs/blas/blas_impl.cu.h b/paddle/phi/kernels/funcs/blas/blas_impl.cu.h index 1fc053bd5d85cd..565601c6b76157 100644 --- a/paddle/phi/kernels/funcs/blas/blas_impl.cu.h +++ b/paddle/phi/kernels/funcs/blas/blas_impl.cu.h @@ -2482,6 +2482,10 @@ void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, cublasOperation_t cuTransB = (transB == CblasNoTrans) ? 
CUBLAS_OP_N : CUBLAS_OP_T; const int64_t strideC = M * N; + std::cout << "@@@@m: " << M << ", n: " << N << ", k: " << K + << ", strideA: " << strideA << ", strideB: " << strideB + << "strideC: " << strideC << ", lda: " << lda << ", ldb: " << ldb + << ", ldc: " << ldc << std::endl; #if CUDA_VERSION >= 9010 if ((FLAGS_enable_cublas_tensor_op_math && (std::is_same::value)) || std::is_same::value) { @@ -2519,6 +2523,10 @@ void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, } if (M > INT_MAX_VALUE || N > INT_MAX_VALUE || K > INT_MAX_VALUE) { #if CUDA_VERSION >= 12030 && defined(__linux__) + std::cout << "!!!!m: " << M << ", n: " << N << ", k: " << K + << ", strideA: " << strideA << ", strideB: " << strideB + << "strideC: " << strideC << ", lda: " << lda + << ", ldb: " << ldb << ", ldc: " << ldc << std::endl; dev_ctx_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { PADDLE_ENFORCE_GPU_SUCCESS( phi::dynload::cublasGemmStridedBatchedEx_64(handle, @@ -2580,6 +2588,7 @@ void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, } else { #endif // CUDA_VERSION >= 9010 dev_ctx_.CublasCall([&](cublasHandle_t handle) { + /* CUBlas::GEMM_STRIDED_BATCH(handle, cuTransB, cuTransA, @@ -2598,6 +2607,26 @@ void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, static_cast(ldc), strideC, static_cast(batchCount)); + */ + CUBlas::GEMM_STRIDED_BATCH( + handle, + (cuTransA == CUBLAS_OP_T) ? CUBLAS_OP_N : CUBLAS_OP_T, + (cuTransB == CUBLAS_OP_T) ? CUBLAS_OP_N : CUBLAS_OP_T, + static_cast(M), + static_cast(N), + static_cast(K), + &alpha, + A, + static_cast(lda), + strideA, + B, + static_cast(ldb), + strideB, + &beta, + C, + static_cast(ldc), + strideC, + static_cast(batchCount)); }); #if CUDA_VERSION >= 9010 @@ -2620,7 +2649,6 @@ void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, int64_t batchCount, int64_t strideA, int64_t strideB) const { - std::cout << "####### HI BGEMM" << std::endl; // Note that cublas follows fortran order, so the order is different from // the cblas convention. int64_t lda = (transA == CblasNoTrans) ? K : M; From 312b6ff42af534fb2c44c65f5f2fb68059eb950c Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Fri, 21 Nov 2025 14:53:39 +0800 Subject: [PATCH 03/55] aligned matmul, start aligning einsum --- paddle/phi/kernels/funcs/blas/blas_impl.cu.h | 80 ++++++++++---------- 1 file changed, 41 insertions(+), 39 deletions(-) diff --git a/paddle/phi/kernels/funcs/blas/blas_impl.cu.h b/paddle/phi/kernels/funcs/blas/blas_impl.cu.h index 565601c6b76157..eae8783bfe25b4 100644 --- a/paddle/phi/kernels/funcs/blas/blas_impl.cu.h +++ b/paddle/phi/kernels/funcs/blas/blas_impl.cu.h @@ -2588,45 +2588,47 @@ void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, } else { #endif // CUDA_VERSION >= 9010 dev_ctx_.CublasCall([&](cublasHandle_t handle) { - /* - CUBlas::GEMM_STRIDED_BATCH(handle, - cuTransB, - cuTransA, - static_cast(N), - static_cast(M), - static_cast(K), - &alpha, - B, - static_cast(ldb), - strideB, - A, - static_cast(lda), - strideA, - &beta, - C, - static_cast(ldc), - strideC, - static_cast(batchCount)); - */ - CUBlas::GEMM_STRIDED_BATCH( - handle, - (cuTransA == CUBLAS_OP_T) ? CUBLAS_OP_N : CUBLAS_OP_T, - (cuTransB == CUBLAS_OP_T) ? CUBLAS_OP_N : CUBLAS_OP_T, - static_cast(M), - static_cast(N), - static_cast(K), - &alpha, - A, - static_cast(lda), - strideA, - B, - static_cast(ldb), - strideB, - &beta, - C, - static_cast(ldc), - strideC, - static_cast(batchCount)); + if (ldc == 1 && M >= 1) { + CUBlas::GEMM_STRIDED_BATCH( + handle, + (cuTransA == CUBLAS_OP_T) ? CUBLAS_OP_N : CUBLAS_OP_T, + (cuTransB == CUBLAS_OP_T) ? 
CUBLAS_OP_N : CUBLAS_OP_T, + static_cast(M), + static_cast(N), + static_cast(K), + &alpha, + A, + static_cast(lda), + strideA, + B, + static_cast(ldb), + strideB, + &beta, + C, + static_cast(ldc), + strideC, + static_cast(batchCount)); + + } else { + CUBlas::GEMM_STRIDED_BATCH(handle, + cuTransB, + cuTransA, + static_cast(N), + static_cast(M), + static_cast(K), + &alpha, + B, + static_cast(ldb), + strideB, + A, + static_cast(lda), + strideA, + &beta, + C, + static_cast(ldc), + strideC, + static_cast(batchCount)); + } }); #if CUDA_VERSION >= 9010 From 65bfd9c0da4579b2cfbf2147fb0bbb2a70209cfe Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Fri, 21 Nov 2025 16:36:42 +0800 Subject: [PATCH 04/55] Add flag --- paddle/common/flags.cc | 11 ++++ paddle/phi/kernels/funcs/blas/blas_impl.cu.h | 9 +-- .../phi/kernels/funcs/blas/blaslt_impl.cu.h | 65 ++++++++++--------- .../gpu/repeat_interleave_grad_kernel.cu | 6 +- .../kernels/gpu/repeat_interleave_kernel.cu | 3 +- python/paddle/nn/functional/common.py | 42 +++++++----- 6 files changed, 81 insertions(+), 55 deletions(-) diff --git a/paddle/common/flags.cc b/paddle/common/flags.cc index ff73c668439cb4..f68acef7c2ca91 100644 --- a/paddle/common/flags.cc +++ b/paddle/common/flags.cc @@ -2314,6 +2314,17 @@ PHI_DEFINE_EXPORTED_bool( PHI_DEFINE_EXPORTED_bool(use_accuracy_compatible_kernel, false, "Whether use torch compatible version kernel."); +/** + * Legacy gemm related FLAG + * Name: FLAGS_use_accuracy_compatible_kernel + * Since Version: 3.2.2 + * Value Range: bool, default=false + * Example: + * Note: Whether use legacy gemm kernel. + */ +PHI_DEFINE_EXPORTED_bool(use_legacy_gemm, + false, + "Whether use legacy gemm dispatch logics."); /** * Allocator Compact related FLAG diff --git a/paddle/phi/kernels/funcs/blas/blas_impl.cu.h b/paddle/phi/kernels/funcs/blas/blas_impl.cu.h index eae8783bfe25b4..5fc16c7949df27 100644 --- a/paddle/phi/kernels/funcs/blas/blas_impl.cu.h +++ b/paddle/phi/kernels/funcs/blas/blas_impl.cu.h @@ -28,6 +28,7 @@ COMMON_DECLARE_bool(enable_cublas_tensor_op_math); COMMON_DECLARE_bool(gemm_use_half_precision_compute_type); +COMMON_DECLARE_bool(use_legacy_gemm); namespace phi { namespace funcs { @@ -2482,10 +2483,6 @@ void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, cublasOperation_t cuTransB = (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; const int64_t strideC = M * N; - std::cout << "@@@@m: " << M << ", n: " << N << ", k: " << K - << ", strideA: " << strideA << ", strideB: " << strideB - << "strideC: " << strideC << ", lda: " << lda << ", ldb: " << ldb - << ", ldc: " << ldc << std::endl; #if CUDA_VERSION >= 9010 if ((FLAGS_enable_cublas_tensor_op_math && (std::is_same::value)) || std::is_same::value) { @@ -2588,7 +2585,8 @@ void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, } else { #endif // CUDA_VERSION >= 9010 dev_ctx_.CublasCall([&](cublasHandle_t handle) { - if (ldc == 1 && M >= 1) { + if (ldc == 1 && M >= 1 && !FLAGS_use_legacy_gemm) { + // No transpose result in this case, align with torch's behaviour. CUBlas::GEMM_STRIDED_BATCH( handle, (cuTransA == CUBLAS_OP_T) ? 
CUBLAS_OP_N : CUBLAS_OP_T, @@ -2608,7 +2606,6 @@ void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, static_cast(ldc), strideC, static_cast(batchCount)); - } else { CUBlas::GEMM_STRIDED_BATCH(handle, cuTransB, diff --git a/paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h b/paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h index 86a8ae74aaee3a..7d1d6b338a9b76 100644 --- a/paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h +++ b/paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h @@ -32,6 +32,7 @@ limitations under the License. */ COMMON_DECLARE_int64(cublaslt_exhaustive_search_times); COMMON_DECLARE_bool(enable_blaslt_global_search); +COMMON_DECLARE_bool(use_legacy_gemm); #endif namespace phi { @@ -466,7 +467,9 @@ struct CublasLtBase { // NOTE(limingshu): As workspace_size varies from different DL framework, // I wonder is there any smarter idea for workspace setting, currently I // just followed the settings from the NVIDIA colleague`s setting. - size_t workspace_size = static_cast(1) * 1024 * 1024; + size_t workspace_size = FLAGS_use_legacy_gemm + ? static_cast(4) * 1024 * 1024 + : static_cast(1) * 1024 * 1024; phi::Allocator::AllocationPtr workspace = GetWorkspace(dev_ctx, workspace_size); @@ -490,36 +493,38 @@ struct CublasLtBase { cache.SetSubKey(sub_key, reinterpret_cast(best_desc)); } } - - cublasLtMatmulPreference_t preference; - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulPreferenceCreate(&preference)); - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulPreferenceSetAttribute( - preference, - CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, - &workspace_size, - sizeof(workspace_size))); - - int returned_results = 0; cublasLtMatmulHeuristicResult_t heuristic_results = {}; - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulAlgoGetHeuristic(cublaslt_handle, - desc->op_desc, - desc->y_desc, - desc->x_desc, - desc->out_desc, - desc->out_desc, - preference, - 1, - &heuristic_results, - &returned_results)); - PADDLE_ENFORCE_GT( - returned_results, - 0, - common::errors::Unavailable("No GEMM algorithm available.")); + if (!FLAGS_use_legacy_gemm) { + cublasLtMatmulPreference_t preference; + PADDLE_ENFORCE_GPU_SUCCESS( + dynload::cublasLtMatmulPreferenceCreate(&preference)); + PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulPreferenceSetAttribute( + preference, + CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &workspace_size, + sizeof(workspace_size))); + + int returned_results = 0; + PADDLE_ENFORCE_GPU_SUCCESS( + dynload::cublasLtMatmulAlgoGetHeuristic(cublaslt_handle, + desc->op_desc, + desc->y_desc, + desc->x_desc, + desc->out_desc, + desc->out_desc, + preference, + 1, + &heuristic_results, + &returned_results)); + PADDLE_ENFORCE_GT( + returned_results, + 0, + common::errors::Unavailable("No GEMM algorithm available.")); - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulPreferenceDestroy(preference)); + PADDLE_ENFORCE_GPU_SUCCESS( + dynload::cublasLtMatmulPreferenceDestroy(preference)); + *desc->algo = heuristic_results.algo; + } VLOG(7) << "[Impl CublasltDescriptor] "; PADDLE_ENFORCE_GPU_SUCCESS( @@ -535,7 +540,7 @@ struct CublasLtBase { desc->out_desc, out_ptr, desc->out_desc, - &heuristic_results.algo, + desc->algo, workspace->ptr(), workspace_size, dev_ctx.stream())); diff --git a/paddle/phi/kernels/gpu/repeat_interleave_grad_kernel.cu b/paddle/phi/kernels/gpu/repeat_interleave_grad_kernel.cu index be10f7ca31025b..c6c6a64bbf2c64 100644 --- a/paddle/phi/kernels/gpu/repeat_interleave_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/repeat_interleave_grad_kernel.cu @@ -227,7 +227,8 @@ 
PD_REGISTER_KERNEL(repeat_interleave_with_tensor_index_grad, double, int, int64_t, - phi::bfloat16) {} + phi::bfloat16, + phi::float16) {} PD_REGISTER_KERNEL(repeat_interleave_grad, GPU, ALL_LAYOUT, @@ -236,4 +237,5 @@ PD_REGISTER_KERNEL(repeat_interleave_grad, double, int, int64_t, - phi::bfloat16) {} + phi::bfloat16, + phi::float16) {} diff --git a/paddle/phi/kernels/gpu/repeat_interleave_kernel.cu b/paddle/phi/kernels/gpu/repeat_interleave_kernel.cu index bb30ac7cb40bea..10bb77cf153c46 100644 --- a/paddle/phi/kernels/gpu/repeat_interleave_kernel.cu +++ b/paddle/phi/kernels/gpu/repeat_interleave_kernel.cu @@ -302,7 +302,8 @@ PD_REGISTER_KERNEL(repeat_interleave, double, int, int64_t, - phi::bfloat16) {} + phi::bfloat16, + phi::float16) {} PD_REGISTER_KERNEL(repeat_interleave_with_tensor_index, GPU, diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index 07e986913edc50..efbb2bd910f61c 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -2510,24 +2510,34 @@ def linear( [-0.67769694, -0.67769694, -0.67769694, -0.67769694]]) """ if in_dynamic_mode(): - if bias is not None: - assert len(bias.shape) == 1, "only support 1D bias" - if weight.shape[0] > 1 and weight.shape[1] > 1: - out, _ = _C_ops.fused_gemm_epilogue( - x, weight, bias, False, False, "none" - ) + if paddle.get_flags(["FLAGS_use_legacy_gemm"]).get( + "FLAGS_use_legacy_gemm", False + ): + if bias is not None: + assert len(bias.shape) == 1, "only support 1D bias" + if weight.shape[0] > 1 and weight.shape[1] > 1: + out, _ = _C_ops.fused_gemm_epilogue( + x.reshape(-1, x.shape[-1]), + weight.reshape(-1, weight.shape[-1]), + bias, + False, + False, + "none", + ) + else: + bias_reshaped = paddle.repeat_interleave( + bias, x.shape[0], axis=0 + ) + bias_reshaped = paddle.unsqueeze(bias_reshaped, axis=1) + out = paddle.addmm( + bias_reshaped, x, weight, alpha=1.0, beta=1.0 + ) else: - bias_reshaped = paddle.repeat_interleave( - bias, x.shape[0], axis=0 - ) - bias_reshaped = paddle.unsqueeze(bias_reshaped, axis=1) - out = paddle.addmm( - bias_reshaped, x, weight, alpha=1.0, beta=1.0 - ) + out = _C_ops.matmul(x, weight, False, False) + return out else: - out = _C_ops.matmul(x, weight, False, False) - - return out + # Fallback logic + return _C_ops.linear(x, weight, bias) elif in_pir_mode(): out = _C_ops.matmul(x, weight, False, False) From c91f672f80210f65b264e2105e3fb1b9376a1749 Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Fri, 21 Nov 2025 17:04:59 +0800 Subject: [PATCH 05/55] polish --- paddle/phi/kernels/funcs/blas/blas_impl.cu.h | 4 ---- paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h | 2 +- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/paddle/phi/kernels/funcs/blas/blas_impl.cu.h b/paddle/phi/kernels/funcs/blas/blas_impl.cu.h index 5fc16c7949df27..a670bb56c4dbfe 100644 --- a/paddle/phi/kernels/funcs/blas/blas_impl.cu.h +++ b/paddle/phi/kernels/funcs/blas/blas_impl.cu.h @@ -2520,10 +2520,6 @@ void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, } if (M > INT_MAX_VALUE || N > INT_MAX_VALUE || K > INT_MAX_VALUE) { #if CUDA_VERSION >= 12030 && defined(__linux__) - std::cout << "!!!!m: " << M << ", n: " << N << ", k: " << K - << ", strideA: " << strideA << ", strideB: " << strideB - << "strideC: " << strideC << ", lda: " << lda - << ", ldb: " << ldb << ", ldc: " << ldc << std::endl; dev_ctx_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { PADDLE_ENFORCE_GPU_SUCCESS( phi::dynload::cublasGemmStridedBatchedEx_64(handle, diff --git 
a/paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h b/paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h index 7d1d6b338a9b76..965cfc94f5cfd9 100644 --- a/paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h +++ b/paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h @@ -523,7 +523,7 @@ struct CublasLtBase { PADDLE_ENFORCE_GPU_SUCCESS( dynload::cublasLtMatmulPreferenceDestroy(preference)); - *desc->algo = heuristic_results.algo; + desc->algo = &heuristic_results.algo; } VLOG(7) << "[Impl CublasltDescriptor] "; From 2c306be02c7b26a142a24b766f4f5c4e6fc6906a Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Mon, 24 Nov 2025 11:43:04 +0800 Subject: [PATCH 06/55] fix shape related issues. --- .../phi/kernels/funcs/blas/blaslt_impl.cu.h | 35 +++++++++---------- python/paddle/nn/functional/common.py | 5 ++- 2 files changed, 21 insertions(+), 19 deletions(-) diff --git a/paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h b/paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h index 965cfc94f5cfd9..31338c225afebf 100644 --- a/paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h +++ b/paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h @@ -523,27 +523,26 @@ struct CublasLtBase { PADDLE_ENFORCE_GPU_SUCCESS( dynload::cublasLtMatmulPreferenceDestroy(preference)); - desc->algo = &heuristic_results.algo; } VLOG(7) << "[Impl CublasltDescriptor] "; - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmul(cublaslt_handle, - desc->op_desc, - static_cast(&alpha), - y_ptr, - desc->y_desc, - x_ptr, - desc->x_desc, - static_cast(&beta), - out_ptr, - desc->out_desc, - out_ptr, - desc->out_desc, - desc->algo, - workspace->ptr(), - workspace_size, - dev_ctx.stream())); + PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmul( + cublaslt_handle, + desc->op_desc, + static_cast(&alpha), + y_ptr, + desc->y_desc, + x_ptr, + desc->x_desc, + static_cast(&beta), + out_ptr, + desc->out_desc, + out_ptr, + desc->out_desc, + FLAGS_use_legacy_gemm ? 
desc->algo : &heuristic_results.algo, + workspace->ptr(), + workspace_size, + dev_ctx.stream())); } static void SearchBestAlgo(const phi::GPUContext& dev_ctx, diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index efbb2bd910f61c..b87850e768fce8 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -2510,12 +2510,13 @@ def linear( [-0.67769694, -0.67769694, -0.67769694, -0.67769694]]) """ if in_dynamic_mode(): - if paddle.get_flags(["FLAGS_use_legacy_gemm"]).get( + if not paddle.get_flags(["FLAGS_use_legacy_gemm"]).get( "FLAGS_use_legacy_gemm", False ): if bias is not None: assert len(bias.shape) == 1, "only support 1D bias" if weight.shape[0] > 1 and weight.shape[1] > 1: + x_shape_prefix = x.shape[:-1] out, _ = _C_ops.fused_gemm_epilogue( x.reshape(-1, x.shape[-1]), weight.reshape(-1, weight.shape[-1]), @@ -2524,6 +2525,8 @@ def linear( False, "none", ) + output_shape = x_shape_prefix + [weight.shape[-1]] # noqa:RUF005 + out = out.reshape(output_shape) else: bias_reshaped = paddle.repeat_interleave( bias, x.shape[0], axis=0 From 8c62a4146e22c08e047410333985e7a98157268b Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Mon, 24 Nov 2025 16:44:29 +0800 Subject: [PATCH 07/55] Finish crash handling --- paddle/phi/kernels/funcs/blas/blas_impl.cu.h | 6 ++-- python/paddle/nn/functional/common.py | 35 ++++++++++++++------ 2 files changed, 29 insertions(+), 12 deletions(-) diff --git a/paddle/phi/kernels/funcs/blas/blas_impl.cu.h b/paddle/phi/kernels/funcs/blas/blas_impl.cu.h index a670bb56c4dbfe..927deb193e73ea 100644 --- a/paddle/phi/kernels/funcs/blas/blas_impl.cu.h +++ b/paddle/phi/kernels/funcs/blas/blas_impl.cu.h @@ -2581,8 +2581,10 @@ void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, } else { #endif // CUDA_VERSION >= 9010 dev_ctx_.CublasCall([&](cublasHandle_t handle) { - if (ldc == 1 && M >= 1 && !FLAGS_use_legacy_gemm) { - // No transpose result in this case, align with torch's behaviour. + if (N == 1 && ldc >= std::max(1, M) && !FLAGS_use_legacy_gemm) { + // No transpose result in these case, align with torch's behaviour. + // TODO(Pan Zhaowu): Integrate proper stride support for arbitrary input + // tensor. CUBlas::GEMM_STRIDED_BATCH( handle, (cuTransA == CUBLAS_OP_T) ? 
CUBLAS_OP_N : CUBLAS_OP_T, diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index b87850e768fce8..09eab684b048fe 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -16,7 +16,9 @@ import inspect import math +import operator import warnings +from functools import reduce from typing import TYPE_CHECKING, Any, Literal import numpy @@ -2510,13 +2512,25 @@ def linear( [-0.67769694, -0.67769694, -0.67769694, -0.67769694]]) """ if in_dynamic_mode(): - if not paddle.get_flags(["FLAGS_use_legacy_gemm"]).get( - "FLAGS_use_legacy_gemm", False + if ( + not paddle.get_flags(["FLAGS_use_legacy_gemm"]).get( + "FLAGS_use_legacy_gemm", False + ) + and x.place.is_gpu_place() ): + if bias is not None and bias.shape == []: + if bias.numel() == 0: + bias = None + else: + # scalar bias + if bias.numel() != weight.shape[-1]: + bias = bias.expand([weight.shape[-1]]) + if bias is not None: assert len(bias.shape) == 1, "only support 1D bias" + x_shape_prefix = x.shape[:-1] + output_shape = x_shape_prefix + [weight.shape[-1]] # noqa:RUF005 if weight.shape[0] > 1 and weight.shape[1] > 1: - x_shape_prefix = x.shape[:-1] out, _ = _C_ops.fused_gemm_epilogue( x.reshape(-1, x.shape[-1]), weight.reshape(-1, weight.shape[-1]), @@ -2525,16 +2539,17 @@ def linear( False, "none", ) - output_shape = x_shape_prefix + [weight.shape[-1]] # noqa:RUF005 - out = out.reshape(output_shape) else: - bias_reshaped = paddle.repeat_interleave( - bias, x.shape[0], axis=0 - ) - bias_reshaped = paddle.unsqueeze(bias_reshaped, axis=1) + flattened_m = reduce(operator.mul, x_shape_prefix, 1) out = paddle.addmm( - bias_reshaped, x, weight, alpha=1.0, beta=1.0 + bias.expand([flattened_m, bias.shape[0]]), + x.reshape(-1, x.shape[-1]), + weight.reshape(-1, weight.shape[-1]), + alpha=1.0, + beta=1.0, ) + + out = out.reshape(output_shape) else: out = _C_ops.matmul(x, weight, False, False) return out From 3f180799d9eb67da6f57a4f1f29646fdd14bba32 Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Mon, 24 Nov 2025 17:02:58 +0800 Subject: [PATCH 08/55] revert redundant diff --- paddle/phi/kernels/gpu/repeat_interleave_grad_kernel.cu | 2 -- paddle/phi/kernels/gpu/repeat_interleave_kernel.cu | 1 - 2 files changed, 3 deletions(-) diff --git a/paddle/phi/kernels/gpu/repeat_interleave_grad_kernel.cu b/paddle/phi/kernels/gpu/repeat_interleave_grad_kernel.cu index c6c6a64bbf2c64..171a0c154bd824 100644 --- a/paddle/phi/kernels/gpu/repeat_interleave_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/repeat_interleave_grad_kernel.cu @@ -227,7 +227,6 @@ PD_REGISTER_KERNEL(repeat_interleave_with_tensor_index_grad, double, int, int64_t, - phi::bfloat16, phi::float16) {} PD_REGISTER_KERNEL(repeat_interleave_grad, GPU, @@ -237,5 +236,4 @@ PD_REGISTER_KERNEL(repeat_interleave_grad, double, int, int64_t, - phi::bfloat16, phi::float16) {} diff --git a/paddle/phi/kernels/gpu/repeat_interleave_kernel.cu b/paddle/phi/kernels/gpu/repeat_interleave_kernel.cu index 10bb77cf153c46..699de4c83cdcfc 100644 --- a/paddle/phi/kernels/gpu/repeat_interleave_kernel.cu +++ b/paddle/phi/kernels/gpu/repeat_interleave_kernel.cu @@ -302,7 +302,6 @@ PD_REGISTER_KERNEL(repeat_interleave, double, int, int64_t, - phi::bfloat16, phi::float16) {} PD_REGISTER_KERNEL(repeat_interleave_with_tensor_index, From fd03fa58b8fc9698ab4af9667aa31fa1837455b9 Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Mon, 24 Nov 2025 17:05:30 +0800 Subject: [PATCH 09/55] revert redundant diff --- 
paddle/phi/kernels/gpu/repeat_interleave_grad_kernel.cu | 4 ++-- paddle/phi/kernels/gpu/repeat_interleave_kernel.cu | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/paddle/phi/kernels/gpu/repeat_interleave_grad_kernel.cu b/paddle/phi/kernels/gpu/repeat_interleave_grad_kernel.cu index 171a0c154bd824..be10f7ca31025b 100644 --- a/paddle/phi/kernels/gpu/repeat_interleave_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/repeat_interleave_grad_kernel.cu @@ -227,7 +227,7 @@ PD_REGISTER_KERNEL(repeat_interleave_with_tensor_index_grad, double, int, int64_t, - phi::float16) {} + phi::bfloat16) {} PD_REGISTER_KERNEL(repeat_interleave_grad, GPU, ALL_LAYOUT, @@ -236,4 +236,4 @@ PD_REGISTER_KERNEL(repeat_interleave_grad, double, int, int64_t, - phi::float16) {} + phi::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/repeat_interleave_kernel.cu b/paddle/phi/kernels/gpu/repeat_interleave_kernel.cu index 699de4c83cdcfc..bb30ac7cb40bea 100644 --- a/paddle/phi/kernels/gpu/repeat_interleave_kernel.cu +++ b/paddle/phi/kernels/gpu/repeat_interleave_kernel.cu @@ -302,7 +302,7 @@ PD_REGISTER_KERNEL(repeat_interleave, double, int, int64_t, - phi::float16) {} + phi::bfloat16) {} PD_REGISTER_KERNEL(repeat_interleave_with_tensor_index, GPU, From 2595025fb71ac24893116e783501ae2baffeb96d Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Mon, 24 Nov 2025 19:24:41 +0800 Subject: [PATCH 10/55] restrict influence to only CUDA --- python/paddle/nn/functional/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index 09eab684b048fe..bbb2bb465deb8a 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -2516,7 +2516,7 @@ def linear( not paddle.get_flags(["FLAGS_use_legacy_gemm"]).get( "FLAGS_use_legacy_gemm", False ) - and x.place.is_gpu_place() + and core.is_compiled_with_cuda() ): if bias is not None and bias.shape == []: if bias.numel() == 0: From 92f3cb37dbc0b83ff7318d9d2f50ff6cf8fd7844 Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Wed, 26 Nov 2025 20:35:37 +0800 Subject: [PATCH 11/55] Optimized CPU overhead, bypass windows. 
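
Illustrative sketch (not part of the diff, and only an assumption about intended
usage): the C++ fast path added to eager_api_linear here targets a (B..., k)
input, a (k, n) weight with both dims > 1, and a 1D bias of length n; a 1D or
(k, 1) weight, or a scalar bias, takes the addmm / tiled-bias branch instead,
and FLAGS_use_legacy_gemm restores the old matmul + add dispatch. A minimal
Python example of the covered case, assuming a CUDA build with the flag left at
its default (false):

    import paddle
    import paddle.nn.functional as F

    x = paddle.randn([8, 16, 128])   # (B..., k) input
    w = paddle.randn([128, 256])     # (k, n) weight, both dims > 1
    b = paddle.randn([256])          # 1D bias of length n

    # Expected to hit the fused_gemm_epilogue branch on the C++ side;
    # other shapes fall back to addmm or the legacy matmul + add path.
    y = F.linear(x, w, bias=b)
    assert y.shape == [8, 16, 256]
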
--- paddle/fluid/pybind/eager_custom_python_api.h | 157 +++++++++++++++++- python/paddle/nn/functional/common.py | 47 +----- 2 files changed, 152 insertions(+), 52 deletions(-) diff --git a/paddle/fluid/pybind/eager_custom_python_api.h b/paddle/fluid/pybind/eager_custom_python_api.h index 32f7b9e8e4a954..300c82fa98bed2 100644 --- a/paddle/fluid/pybind/eager_custom_python_api.h +++ b/paddle/fluid/pybind/eager_custom_python_api.h @@ -15,10 +15,13 @@ #include +#include "paddle/common/ddim.h" +#include "paddle/common/flags.h" #include "paddle/fluid/eager/to_static/run_program_func.h" #include "paddle/fluid/eager/utils.h" #include "paddle/phi/core/enforce.h" +COMMON_DECLARE_bool(use_legacy_gemm); using egr::ConvertAllInputsToDistTensor; using egr::InputsContainDistTensor; @@ -36,17 +39,159 @@ static PyObject *eager_api_linear(PyObject *self, tstate = PyEval_SaveThread(); - if (bias.is_dist_tensor() || bias.has_allocation()) { + if (bias.is_dist_tensor() || (bias.has_allocation() && bias.numel() > 0)) { const phi::distributed::ProcessMesh *mesh = nullptr; if (InputsContainDistTensor(&mesh, x, weight, bias)) { ConvertAllInputsToDistTensor(mesh, x, weight, bias); } +#if (defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11000 && \ + !(defined(_WIN32) || defined(WIN32))) + if (!FLAGS_use_legacy_gemm) { + // TODO(Pan Zhaowu): Add proper broadcast logic for batchsize unaligned + // batch-gemm. Currently handles: (B..., k) x (k, n) -> (B..., n), with + // 1D or scalar bias. - auto mm_out = matmul_ad_func(x, weight, false, false); - auto out = add_ad_func(mm_out, bias); - PyEval_RestoreThread(tstate); - tstate = nullptr; - return ToPyObject(out); + // --- Original input tensor dimensions and values --- + const auto &x_original_shape = x.shape(); + const size_t x_ndim_original = x_original_shape.size(); + + const auto &weight_original_shape = weight.shape(); + const size_t weight_ndim_original = weight_original_shape.size(); + + // Determine the 'k' and 'n' dimensions based on original shapes. + // These values are crucial for potential 1D reshaping and output shape + // calculation. + const int64_t k_dim = + x_original_shape[x_ndim_original - 1]; // Last dimension of X + const int64_t n_dim = + weight_original_shape[weight_ndim_original - + 1]; // Last dimension of Weight + + // --- Process 1D x and weight tensors by reshaping them to 2D if + // necessary --- Subsequent operations will use these processed tensors. + paddle::Tensor x_processed = + x; // Start with original, possibly reassign if reshaped + paddle::Tensor weight_processed = + weight; // Start with original, possibly reassign if reshaped + + // If x is 1D (e.g., shape [k]), reshape it to [1, k] to fit the (B..., + // k) x (k, n) pattern. This effectively treats a 1D vector as a row + // vector for matrix multiplication. + if (x_ndim_original == 1) { + x_processed = reshape_ad_func(x, {1, k_dim}); + } + // If weight is 1D (e.g., shape [n]), reshape it to [k, 1]. + // This implies 'n' was 1 in the original context, and 'k' is determined + // by x. This effectively treats a 1D vector as a column vector. Note: + // This 'else if' means if both x and weight are 1D, only x gets + // reshaped currently. For (k) x (n) where n != k and both are 1D, the + // semantics are ambiguous and not directly covered by (B..., k) x (k, + // n). The current design implies weight is at least 2D or is treated as + // [k, 1] if 1D. 
+ else if (weight_ndim_original == 1) { // NOLINT + weight_processed = reshape_ad_func(weight, {k_dim, 1}); + } + + // --- Recalculate dimensions based on processed tensors --- + // These dimensions will be used for the actual GEMM operation. + const auto &x_shape_current = x_processed.shape(); + const size_t x_ndim_current = x_shape_current.size(); + + const auto &weight_shape_current = weight_processed.shape(); + const size_t weight_ndim_current = weight_shape_current.size(); + + // Effective 'k' and 'n' for GEMM. + const int64_t k_effective = x_shape_current[x_ndim_current - 1]; + const int64_t n_effective = + weight_shape_current[weight_ndim_current - 1]; + + // --- Determine the final output shape --- + // Start with the processed x's shape, then modify the last dimension. + std::vector output_shape_vec = x_shape_current; + output_shape_vec[x_ndim_current - 1] = n_effective; + + // If the original x was 1D, the processed x became [1, k]. + // The output_shape_vec would be [1, n]. + // For 1D input, we usually expect a 1D output (shape [n]) if possible. + if (x_ndim_original == 1 && output_shape_vec.size() > 1 && + output_shape_vec[0] == 1) { + output_shape_vec.erase( + output_shape_vec + .begin()); // Remove the artificial batch dimension + } + + // Calculate the total number of elements in the batch dimensions of X. + // This is used for reshaping X into a 2D matrix for addmm_ad_func. + const int64_t x_batch_numel = + std::accumulate(output_shape_vec.begin(), + output_shape_vec.end() - 1, + 1LL, + std::multiplies()); + + // --- Bias handling and GEMM execution --- + // The condition now uses the processed weight's shape. + // This branch typically handles (B..., k) x (k, n) where n > 1. + if (weight_shape_current[0] > 1 && weight_shape_current[1] > 1) { + paddle::Tensor bias_1d = + bias; // Create a mutable copy if modification is needed + // Align bias' shape to 'n_effective'. If bias.numel() != n_effective, + // tile it. + if (bias.numel() != n_effective) { + bias_1d = tile_ad_func(bias, {static_cast(n_effective)}); + } + // Execute fused GEMM with epilogue. + auto [out, _] = fused_gemm_epilogue_ad_func( + x_processed, weight_processed, bias_1d, false, false, "none"); + + // If original x was 1D and output_shape_vec is 1D (i.e., [n]), + // but fused_gemm_epilogue_ad_func returns a 2D tensor ([1, n]), + // reshape it back to the desired 1D output shape. + if (x_ndim_original == 1 && out.shape().size() == 2 && + output_shape_vec.size() == 1) { + out = reshape_ad_func(out, output_shape_vec); + } + + PyEval_RestoreThread(tstate); + tstate = nullptr; + return ToPyObject(out); + } else { + // This branch handles cases where weight_processed is effectively 2D + // with one dimension being 1, e.g., (B..., k) x (k, 1) resulting in + // (B..., 1). Or when weight_processed was originally 1D and reshaped + // to [k, 1]. + + // Reshape bias to [1, n_effective] then tile to [x_batch_numel, 1] + // for addmm_ad_func. + paddle::Tensor bias_2d = tile_ad_func( + reshape_ad_func(bias, {1, n_effective}), {x_batch_numel, 1}); + + // Perform matrix multiplication using addmm_ad_func. + // x_processed is reshaped to 2D [x_batch_numel, k_effective] for the + // multiplication. + auto out = addmm_ad_func( + bias_2d, + reshape_ad_func(x_processed, {x_batch_numel, k_effective}), + weight_processed, + 1.0, + 1.0); + + // Reshape the final output to the target output_shape_vec. 
+ out = reshape_ad_func(out, output_shape_vec); + + PyEval_RestoreThread(tstate); + tstate = nullptr; + return ToPyObject(out); + } + } else // NOLINT(readability/braces) +#endif + { + auto mm_out = matmul_ad_func(x, weight, false, false); + auto out = add_ad_func(mm_out, bias); + + PyEval_RestoreThread(tstate); + tstate = nullptr; + return ToPyObject(out); + } } else { const phi::distributed::ProcessMesh *mesh = nullptr; if (InputsContainDistTensor(&mesh, x, weight)) { diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index bbb2bb465deb8a..0060e060f939fd 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -16,9 +16,7 @@ import inspect import math -import operator import warnings -from functools import reduce from typing import TYPE_CHECKING, Any, Literal import numpy @@ -2512,50 +2510,7 @@ def linear( [-0.67769694, -0.67769694, -0.67769694, -0.67769694]]) """ if in_dynamic_mode(): - if ( - not paddle.get_flags(["FLAGS_use_legacy_gemm"]).get( - "FLAGS_use_legacy_gemm", False - ) - and core.is_compiled_with_cuda() - ): - if bias is not None and bias.shape == []: - if bias.numel() == 0: - bias = None - else: - # scalar bias - if bias.numel() != weight.shape[-1]: - bias = bias.expand([weight.shape[-1]]) - - if bias is not None: - assert len(bias.shape) == 1, "only support 1D bias" - x_shape_prefix = x.shape[:-1] - output_shape = x_shape_prefix + [weight.shape[-1]] # noqa:RUF005 - if weight.shape[0] > 1 and weight.shape[1] > 1: - out, _ = _C_ops.fused_gemm_epilogue( - x.reshape(-1, x.shape[-1]), - weight.reshape(-1, weight.shape[-1]), - bias, - False, - False, - "none", - ) - else: - flattened_m = reduce(operator.mul, x_shape_prefix, 1) - out = paddle.addmm( - bias.expand([flattened_m, bias.shape[0]]), - x.reshape(-1, x.shape[-1]), - weight.reshape(-1, weight.shape[-1]), - alpha=1.0, - beta=1.0, - ) - - out = out.reshape(output_shape) - else: - out = _C_ops.matmul(x, weight, False, False) - return out - else: - # Fallback logic - return _C_ops.linear(x, weight, bias) + return _C_ops.linear(x, weight, bias) elif in_pir_mode(): out = _C_ops.matmul(x, weight, False, False) From d7a556a442c282151389d211cc16cb2dea0b0183 Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Thu, 27 Nov 2025 15:15:17 +0800 Subject: [PATCH 12/55] optimize branch cost --- paddle/fluid/pybind/eager_custom_python_api.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/pybind/eager_custom_python_api.h b/paddle/fluid/pybind/eager_custom_python_api.h index 300c82fa98bed2..82bb48023505d6 100644 --- a/paddle/fluid/pybind/eager_custom_python_api.h +++ b/paddle/fluid/pybind/eager_custom_python_api.h @@ -39,14 +39,14 @@ static PyObject *eager_api_linear(PyObject *self, tstate = PyEval_SaveThread(); - if (bias.is_dist_tensor() || (bias.has_allocation() && bias.numel() > 0)) { + if (bias.has_allocation() || bias.is_dist_tensor()) { const phi::distributed::ProcessMesh *mesh = nullptr; if (InputsContainDistTensor(&mesh, x, weight, bias)) { ConvertAllInputsToDistTensor(mesh, x, weight, bias); } #if (defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11000 && \ !(defined(_WIN32) || defined(WIN32))) - if (!FLAGS_use_legacy_gemm) { + if (!FLAGS_use_legacy_gemm) [[likely]] { // NOLINT // TODO(Pan Zhaowu): Add proper broadcast logic for batchsize unaligned // batch-gemm. Currently handles: (B..., k) x (k, n) -> (B..., n), with // 1D or scalar bias. 
@@ -77,7 +77,7 @@ static PyObject *eager_api_linear(PyObject *self, // If x is 1D (e.g., shape [k]), reshape it to [1, k] to fit the (B..., // k) x (k, n) pattern. This effectively treats a 1D vector as a row // vector for matrix multiplication. - if (x_ndim_original == 1) { + if (x_ndim_original == 1) [[unlikely]] { x_processed = reshape_ad_func(x, {1, k_dim}); } // If weight is 1D (e.g., shape [n]), reshape it to [k, 1]. @@ -88,7 +88,7 @@ static PyObject *eager_api_linear(PyObject *self, // semantics are ambiguous and not directly covered by (B..., k) x (k, // n). The current design implies weight is at least 2D or is treated as // [k, 1] if 1D. - else if (weight_ndim_original == 1) { // NOLINT + else if (weight_ndim_original == 1) [[unlikely]] { // NOLINT weight_processed = reshape_ad_func(weight, {k_dim, 1}); } From 93d2451d8536cdc4ba854cf8f11d6af4455290be Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Thu, 27 Nov 2025 19:07:06 +0800 Subject: [PATCH 13/55] disable dist tensor case --- paddle/fluid/pybind/eager_custom_python_api.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/pybind/eager_custom_python_api.h b/paddle/fluid/pybind/eager_custom_python_api.h index 82bb48023505d6..707ba6e6d81c3e 100644 --- a/paddle/fluid/pybind/eager_custom_python_api.h +++ b/paddle/fluid/pybind/eager_custom_python_api.h @@ -46,7 +46,8 @@ static PyObject *eager_api_linear(PyObject *self, } #if (defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11000 && \ !(defined(_WIN32) || defined(WIN32))) - if (!FLAGS_use_legacy_gemm) [[likely]] { // NOLINT + if (!FLAGS_use_legacy_gemm && !bias.is_dist_tensor()) // NOLINT + [[likely]] { // NOLINT // TODO(Pan Zhaowu): Add proper broadcast logic for batchsize unaligned // batch-gemm. Currently handles: (B..., k) x (k, n) -> (B..., n), with // 1D or scalar bias. From 6439f4c3720908acbd3fa3a9482c9b91d630f47e Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Fri, 28 Nov 2025 11:38:25 +0800 Subject: [PATCH 14/55] add GPUPlace check --- paddle/fluid/pybind/eager_custom_python_api.h | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/pybind/eager_custom_python_api.h b/paddle/fluid/pybind/eager_custom_python_api.h index 707ba6e6d81c3e..d15aeee04f15d2 100644 --- a/paddle/fluid/pybind/eager_custom_python_api.h +++ b/paddle/fluid/pybind/eager_custom_python_api.h @@ -46,8 +46,12 @@ static PyObject *eager_api_linear(PyObject *self, } #if (defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11000 && \ !(defined(_WIN32) || defined(WIN32))) - if (!FLAGS_use_legacy_gemm && !bias.is_dist_tensor()) // NOLINT - [[likely]] { // NOLINT + if (!FLAGS_use_legacy_gemm && // NOLINT + x.place().GetType() == phi::AllocationType::GPU && + weight.place().GetType() == phi::AllocationType::GPU && + bias.place().GetType() == phi::AllocationType::GPU && + !bias.is_dist_tensor()) // NOLINT + [[likely]] { // NOLINT // TODO(Pan Zhaowu): Add proper broadcast logic for batchsize unaligned // batch-gemm. Currently handles: (B..., k) x (k, n) -> (B..., n), with // 1D or scalar bias. @@ -64,9 +68,6 @@ static PyObject *eager_api_linear(PyObject *self, // calculation. const int64_t k_dim = x_original_shape[x_ndim_original - 1]; // Last dimension of X - const int64_t n_dim = - weight_original_shape[weight_ndim_original - - 1]; // Last dimension of Weight // --- Process 1D x and weight tensors by reshaping them to 2D if // necessary --- Subsequent operations will use these processed tensors. 
From ed20e33c6989b4490d424ea3fb126e997a6dc993 Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Mon, 8 Dec 2025 10:47:19 +0800 Subject: [PATCH 15/55] fix matmul diff --- paddle/phi/kernels/funcs/blas/blas_impl.cu.h | 63 ++++++++++++++------ paddle/phi/kernels/impl/matmul_kernel_impl.h | 59 +++++++++++------- 2 files changed, 81 insertions(+), 41 deletions(-) diff --git a/paddle/phi/kernels/funcs/blas/blas_impl.cu.h b/paddle/phi/kernels/funcs/blas/blas_impl.cu.h index ae7b67de6d642f..927deb193e73ea 100644 --- a/paddle/phi/kernels/funcs/blas/blas_impl.cu.h +++ b/paddle/phi/kernels/funcs/blas/blas_impl.cu.h @@ -28,6 +28,7 @@ COMMON_DECLARE_bool(enable_cublas_tensor_op_math); COMMON_DECLARE_bool(gemm_use_half_precision_compute_type); +COMMON_DECLARE_bool(use_legacy_gemm); namespace phi { namespace funcs { @@ -2580,24 +2581,49 @@ void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, } else { #endif // CUDA_VERSION >= 9010 dev_ctx_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GEMM_STRIDED_BATCH(handle, - cuTransB, - cuTransA, - static_cast(N), - static_cast(M), - static_cast(K), - &alpha, - B, - static_cast(ldb), - strideB, - A, - static_cast(lda), - strideA, - &beta, - C, - static_cast(ldc), - strideC, - static_cast(batchCount)); + if (N == 1 && ldc >= std::max(1, M) && !FLAGS_use_legacy_gemm) { + // No transpose result in these case, align with torch's behaviour. + // TODO(Pan Zhaowu): Integrate proper stride support for arbitrary input + // tensor. + CUBlas::GEMM_STRIDED_BATCH( + handle, + (cuTransA == CUBLAS_OP_T) ? CUBLAS_OP_N : CUBLAS_OP_T, + (cuTransB == CUBLAS_OP_T) ? CUBLAS_OP_N : CUBLAS_OP_T, + static_cast(M), + static_cast(N), + static_cast(K), + &alpha, + A, + static_cast(lda), + strideA, + B, + static_cast(ldb), + strideB, + &beta, + C, + static_cast(ldc), + strideC, + static_cast(batchCount)); + } else { + CUBlas::GEMM_STRIDED_BATCH(handle, + cuTransB, + cuTransA, + static_cast(N), + static_cast(M), + static_cast(K), + &alpha, + B, + static_cast(ldb), + strideB, + A, + static_cast(lda), + strideA, + &beta, + C, + static_cast(ldc), + strideC, + static_cast(batchCount)); + } }); #if CUDA_VERSION >= 9010 @@ -2731,6 +2757,7 @@ void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, #endif // CUDA_VERSION >= 9010 T h_alpha = static_cast(alpha); T h_beta = static_cast(beta); + dev_ctx_.CublasCall([&](cublasHandle_t handle) { CUBlas::GEMM_STRIDED_BATCH(handle, cuTransB, diff --git a/paddle/phi/kernels/impl/matmul_kernel_impl.h b/paddle/phi/kernels/impl/matmul_kernel_impl.h index 3ff015aa6fe368..b30be1110b12d0 100644 --- a/paddle/phi/kernels/impl/matmul_kernel_impl.h +++ b/paddle/phi/kernels/impl/matmul_kernel_impl.h @@ -145,16 +145,20 @@ void MatMulFunctionImplWithBlas( VLOG(3) << "MatMul's case 1"; Out->Resize(common::make_ddim({})); dev_ctx.template Alloc(Out); - blas.GEMM(CblasNoTrans, - CblasTrans, - 1, - 1, - M, - static_cast(1), - y_data, - x_data, - static_cast(flag), - dev_ctx.template Alloc(Out)); + if constexpr (std::is_same::value) { + blas.CUDOT(M, X.data(), 1, Y.data(), 1, Out->data()); + } else { + blas.GEMM(CblasNoTrans, + CblasTrans, + 1, + 1, + M, + static_cast(1), + y_data, + x_data, + static_cast(flag), + dev_ctx.template Alloc(Out)); + } return; } @@ -407,19 +411,28 @@ void MatMulFunctionImplWithBlas( dev_ctx.template Alloc(Out)); } else { VLOG(3) << "MatMul's case 10"; - blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans, - trans_y ? 
CblasTrans : CblasNoTrans, - M, - N, - K, - static_cast(1), - x_data, - y_data, - static_cast(flag), - dev_ctx.template Alloc(Out), - out_batch_size, - 0, - K * N); + // x batch == 1 and y batch > 1, transpose y and fold batch + DenseTensor transposedY = phi::TransposeLast2Dim(dev_ctx, Y); + blas.GEMM(trans_x ? CblasTrans : CblasNoTrans, + trans_y ? CblasNoTrans : CblasTrans, + Y.numel() / K, + M, + K, + static_cast(1), + transposedY.data(), + x_data, + static_cast(flag), + dev_ctx.template Alloc(Out)); + // TODO(Pan Zhaowu): the actual layout is (B, N, M), need to reshape and + // transpose to (B, M, N) ty -> (B,N,K) + const auto out_original_shape = Out->dims(); + std::vector actual_dim = common::vectorize(transposedY.dims()); + actual_dim[actual_dim.size() - 1] = + out_original_shape[out_original_shape.size() - 2]; + Out->Resize(common::make_ddim(actual_dim)); + DenseTensor transposedOut = phi::TransposeLast2Dim(dev_ctx, *Out); + *Out = transposedOut; + Out->Resize(out_original_shape); } } else if (y_batch_size == 1) { if (!trans_x) { From 3ac50bc091d6480c806bdf6ebabd4f8ffb40c1c5 Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Mon, 8 Dec 2025 10:56:13 +0800 Subject: [PATCH 16/55] add flags related logic --- paddle/common/flags.cc | 11 +++ paddle/phi/kernels/impl/matmul_kernel_impl.h | 83 ++++++++++++++------ 2 files changed, 68 insertions(+), 26 deletions(-) diff --git a/paddle/common/flags.cc b/paddle/common/flags.cc index ff73c668439cb4..280964a4dc3f16 100644 --- a/paddle/common/flags.cc +++ b/paddle/common/flags.cc @@ -2314,6 +2314,17 @@ PHI_DEFINE_EXPORTED_bool( PHI_DEFINE_EXPORTED_bool(use_accuracy_compatible_kernel, false, "Whether use torch compatible version kernel."); +/** + * Legacy gemm related FLAG + * Name: FLAGS_use_legacy_gemm + * Since Version: 3.2.2 + * Value Range: bool, default=false + * Example: + * Note: Whether use legacy gemm kernel. + */ +PHI_DEFINE_EXPORTED_bool(use_legacy_gemm, + false, + "Whether use legacy gemm dispatch logics."); /** * Allocator Compact related FLAG diff --git a/paddle/phi/kernels/impl/matmul_kernel_impl.h b/paddle/phi/kernels/impl/matmul_kernel_impl.h index b30be1110b12d0..b7d9196a60f603 100644 --- a/paddle/phi/kernels/impl/matmul_kernel_impl.h +++ b/paddle/phi/kernels/impl/matmul_kernel_impl.h @@ -41,6 +41,7 @@ limitations under the License. */ #include "paddle/phi/kernels/full_kernel.h" COMMON_DECLARE_bool(cuda_core_int8_gemm); +COMMON_DECLARE_bool(use_legacy_gemm); namespace phi { @@ -145,9 +146,7 @@ void MatMulFunctionImplWithBlas( VLOG(3) << "MatMul's case 1"; Out->Resize(common::make_ddim({})); dev_ctx.template Alloc(Out); - if constexpr (std::is_same::value) { - blas.CUDOT(M, X.data(), 1, Y.data(), 1, Out->data()); - } else { + if (FLAGS_use_legacy_gemm) { blas.GEMM(CblasNoTrans, CblasTrans, 1, @@ -158,8 +157,24 @@ void MatMulFunctionImplWithBlas( x_data, static_cast(flag), dev_ctx.template Alloc(Out)); + return; + } else { + if constexpr (std::is_same::value) { + blas.CUDOT(M, X.data(), 1, Y.data(), 1, Out->data()); + } else { + blas.GEMM(CblasNoTrans, + CblasTrans, + 1, + 1, + M, + static_cast(1), + y_data, + x_data, + static_cast(flag), + dev_ctx.template Alloc(Out)); + } + return; } - return; } if (x_ndim == 1) { @@ -411,28 +426,44 @@ void MatMulFunctionImplWithBlas( dev_ctx.template Alloc(Out)); } else { VLOG(3) << "MatMul's case 10"; - // x batch == 1 and y batch > 1, transpose y and fold batch - DenseTensor transposedY = phi::TransposeLast2Dim(dev_ctx, Y); - blas.GEMM(trans_x ? CblasTrans : CblasNoTrans, - trans_y ? 
CblasNoTrans : CblasTrans, - Y.numel() / K, - M, - K, - static_cast(1), - transposedY.data(), - x_data, - static_cast(flag), - dev_ctx.template Alloc(Out)); - // TODO(Pan Zhaowu): the actual layout is (B, N, M), need to reshape and - // transpose to (B, M, N) ty -> (B,N,K) - const auto out_original_shape = Out->dims(); - std::vector actual_dim = common::vectorize(transposedY.dims()); - actual_dim[actual_dim.size() - 1] = - out_original_shape[out_original_shape.size() - 2]; - Out->Resize(common::make_ddim(actual_dim)); - DenseTensor transposedOut = phi::TransposeLast2Dim(dev_ctx, *Out); - *Out = transposedOut; - Out->Resize(out_original_shape); + if (FLAGS_use_legacy_gemm) { + // x batch == 1 and y batch > 1, transpose y and fold batch + DenseTensor transposedY = phi::TransposeLast2Dim(dev_ctx, Y); + blas.GEMM(trans_x ? CblasTrans : CblasNoTrans, + trans_y ? CblasNoTrans : CblasTrans, + Y.numel() / K, + M, + K, + static_cast(1), + transposedY.data(), + x_data, + static_cast(flag), + dev_ctx.template Alloc(Out)); + // TODO(Pan Zhaowu): the actual layout is (B, N, M), need to reshape and + // transpose to (B, M, N) ty -> (B,N,K) + const auto out_original_shape = Out->dims(); + std::vector actual_dim = common::vectorize(transposedY.dims()); + actual_dim[actual_dim.size() - 1] = + out_original_shape[out_original_shape.size() - 2]; + Out->Resize(common::make_ddim(actual_dim)); + DenseTensor transposedOut = phi::TransposeLast2Dim(dev_ctx, *Out); + *Out = transposedOut; + Out->Resize(out_original_shape); + } else { + blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans, + trans_y ? CblasTrans : CblasNoTrans, + M, + N, + K, + static_cast(1), + x_data, + y_data, + static_cast(flag), + dev_ctx.template Alloc(Out), + out_batch_size, + 0, + K * N); + } } } else if (y_batch_size == 1) { if (!trans_x) { From c2842b8f55db00b18c0ceddd00698bd81661269a Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Mon, 8 Dec 2025 11:00:07 +0800 Subject: [PATCH 17/55] polish --- paddle/phi/kernels/impl/matmul_kernel_impl.h | 28 ++++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/paddle/phi/kernels/impl/matmul_kernel_impl.h b/paddle/phi/kernels/impl/matmul_kernel_impl.h index b7d9196a60f603..5ef61ecd49d6c7 100644 --- a/paddle/phi/kernels/impl/matmul_kernel_impl.h +++ b/paddle/phi/kernels/impl/matmul_kernel_impl.h @@ -427,6 +427,20 @@ void MatMulFunctionImplWithBlas( } else { VLOG(3) << "MatMul's case 10"; if (FLAGS_use_legacy_gemm) { + blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans, + trans_y ? CblasTrans : CblasNoTrans, + M, + N, + K, + static_cast(1), + x_data, + y_data, + static_cast(flag), + dev_ctx.template Alloc(Out), + out_batch_size, + 0, + K * N); + } else { // x batch == 1 and y batch > 1, transpose y and fold batch DenseTensor transposedY = phi::TransposeLast2Dim(dev_ctx, Y); blas.GEMM(trans_x ? CblasTrans : CblasNoTrans, @@ -449,20 +463,6 @@ void MatMulFunctionImplWithBlas( DenseTensor transposedOut = phi::TransposeLast2Dim(dev_ctx, *Out); *Out = transposedOut; Out->Resize(out_original_shape); - } else { - blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans, - trans_y ? 
CblasTrans : CblasNoTrans, - M, - N, - K, - static_cast(1), - x_data, - y_data, - static_cast(flag), - dev_ctx.template Alloc(Out), - out_batch_size, - 0, - K * N); } } } else if (y_batch_size == 1) { From 9eaf7189102670959374773c0cc2ca123f9640d1 Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Mon, 8 Dec 2025 11:00:59 +0800 Subject: [PATCH 18/55] polish --- paddle/phi/kernels/impl/matmul_kernel_impl.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/paddle/phi/kernels/impl/matmul_kernel_impl.h b/paddle/phi/kernels/impl/matmul_kernel_impl.h index 5ef61ecd49d6c7..8e7ba5ec344bc9 100644 --- a/paddle/phi/kernels/impl/matmul_kernel_impl.h +++ b/paddle/phi/kernels/impl/matmul_kernel_impl.h @@ -453,8 +453,9 @@ void MatMulFunctionImplWithBlas( x_data, static_cast(flag), dev_ctx.template Alloc(Out)); - // TODO(Pan Zhaowu): the actual layout is (B, N, M), need to reshape and - // transpose to (B, M, N) ty -> (B,N,K) + // The actual layout is (B, N, M), need to reshape and + // transpose to (B, M, N) ty -> (B,N,K), this requires transpose kernel + // to be implemented in high efficiency. const auto out_original_shape = Out->dims(); std::vector actual_dim = common::vectorize(transposedY.dims()); actual_dim[actual_dim.size() - 1] = From bfe5d6ea44d2dbf62b80e71efa25a51606967462 Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Mon, 8 Dec 2025 11:01:39 +0800 Subject: [PATCH 19/55] polish --- paddle/phi/kernels/impl/matmul_kernel_impl.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/phi/kernels/impl/matmul_kernel_impl.h b/paddle/phi/kernels/impl/matmul_kernel_impl.h index 8e7ba5ec344bc9..027c395b0bdf25 100644 --- a/paddle/phi/kernels/impl/matmul_kernel_impl.h +++ b/paddle/phi/kernels/impl/matmul_kernel_impl.h @@ -454,7 +454,7 @@ void MatMulFunctionImplWithBlas( static_cast(flag), dev_ctx.template Alloc(Out)); // The actual layout is (B, N, M), need to reshape and - // transpose to (B, M, N) ty -> (B,N,K), this requires transpose kernel + // transpose to (B, M, N), this requires batched transpose kernel // to be implemented in high efficiency. 
const auto out_original_shape = Out->dims(); std::vector actual_dim = common::vectorize(transposedY.dims()); From 2858364a56ff1a416cd35b50d4cbb0be846069d4 Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Mon, 8 Dec 2025 11:36:08 +0800 Subject: [PATCH 20/55] polish --- paddle/phi/kernels/impl/matmul_kernel_impl.h | 62 +++++++++++--------- 1 file changed, 33 insertions(+), 29 deletions(-) diff --git a/paddle/phi/kernels/impl/matmul_kernel_impl.h b/paddle/phi/kernels/impl/matmul_kernel_impl.h index 027c395b0bdf25..b7363b56db17a0 100644 --- a/paddle/phi/kernels/impl/matmul_kernel_impl.h +++ b/paddle/phi/kernels/impl/matmul_kernel_impl.h @@ -159,20 +159,20 @@ void MatMulFunctionImplWithBlas( dev_ctx.template Alloc(Out)); return; } else { - if constexpr (std::is_same::value) { - blas.CUDOT(M, X.data(), 1, Y.data(), 1, Out->data()); - } else { - blas.GEMM(CblasNoTrans, - CblasTrans, - 1, - 1, - M, - static_cast(1), - y_data, - x_data, - static_cast(flag), - dev_ctx.template Alloc(Out)); - } +#if defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP) + blas.CUDOT(M, X.data(), 1, Y.data(), 1, Out->data()); +#else + blas.GEMM(CblasNoTrans, + CblasTrans, + 1, + 1, + M, + static_cast(1), + y_data, + x_data, + static_cast(flag), + dev_ctx.template Alloc(Out)); +#endif return; } } @@ -426,21 +426,8 @@ void MatMulFunctionImplWithBlas( dev_ctx.template Alloc(Out)); } else { VLOG(3) << "MatMul's case 10"; - if (FLAGS_use_legacy_gemm) { - blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans, - trans_y ? CblasTrans : CblasNoTrans, - M, - N, - K, - static_cast(1), - x_data, - y_data, - static_cast(flag), - dev_ctx.template Alloc(Out), - out_batch_size, - 0, - K * N); - } else { +#if defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP) + if (!FLAGS_use_legacy_gemm) { // x batch == 1 and y batch > 1, transpose y and fold batch DenseTensor transposedY = phi::TransposeLast2Dim(dev_ctx, Y); blas.GEMM(trans_x ? CblasTrans : CblasNoTrans, @@ -464,7 +451,24 @@ void MatMulFunctionImplWithBlas( DenseTensor transposedOut = phi::TransposeLast2Dim(dev_ctx, *Out); *Out = transposedOut; Out->Resize(out_original_shape); + } else // NOLINT +#else + { // NOLINT + blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans, + trans_y ? CblasTrans : CblasNoTrans, + M, + N, + K, + static_cast(1), + x_data, + y_data, + static_cast(flag), + dev_ctx.template Alloc(Out), + out_batch_size, + 0, + K * N); } +#endif } } else if (y_batch_size == 1) { if (!trans_x) { From 9f5c03d6fd824ec76bda12dd4ea02cf557b1403c Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Mon, 8 Dec 2025 11:44:50 +0800 Subject: [PATCH 21/55] polish --- paddle/phi/kernels/impl/matmul_kernel_impl.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/phi/kernels/impl/matmul_kernel_impl.h b/paddle/phi/kernels/impl/matmul_kernel_impl.h index b7363b56db17a0..61c6f471abc7ef 100644 --- a/paddle/phi/kernels/impl/matmul_kernel_impl.h +++ b/paddle/phi/kernels/impl/matmul_kernel_impl.h @@ -451,7 +451,7 @@ void MatMulFunctionImplWithBlas( DenseTensor transposedOut = phi::TransposeLast2Dim(dev_ctx, *Out); *Out = transposedOut; Out->Resize(out_original_shape); - } else // NOLINT + } else { // NOLINT #else { // NOLINT blas.BatchedGEMM(trans_x ? 
CblasTrans : CblasNoTrans, @@ -469,6 +469,7 @@ void MatMulFunctionImplWithBlas( K * N); } #endif + } } } else if (y_batch_size == 1) { if (!trans_x) { From 7ee0a723dc8fd9b626ed2b074e39742175e4e338 Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Mon, 8 Dec 2025 12:17:07 +0800 Subject: [PATCH 22/55] polish --- paddle/phi/kernels/impl/matmul_kernel_impl.h | 59 +++++++++++--------- 1 file changed, 32 insertions(+), 27 deletions(-) diff --git a/paddle/phi/kernels/impl/matmul_kernel_impl.h b/paddle/phi/kernels/impl/matmul_kernel_impl.h index 61c6f471abc7ef..f38ee4f70576fe 100644 --- a/paddle/phi/kernels/impl/matmul_kernel_impl.h +++ b/paddle/phi/kernels/impl/matmul_kernel_impl.h @@ -160,20 +160,25 @@ void MatMulFunctionImplWithBlas( return; } else { #if defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP) - blas.CUDOT(M, X.data(), 1, Y.data(), 1, Out->data()); + if (dev_ctx.GetPlace().type() == phi::PlaceType::GPU) { + blas.CUDOT(M, X.data(), 1, Y.data(), 1, Out->data()); + } else { #else - blas.GEMM(CblasNoTrans, - CblasTrans, - 1, - 1, - M, - static_cast(1), - y_data, - x_data, - static_cast(flag), - dev_ctx.template Alloc(Out)); + { + blas.GEMM(CblasNoTrans, + CblasTrans, + 1, + 1, + M, + static_cast(1), + y_data, + x_data, + static_cast(flag), + dev_ctx.template Alloc(Out)); + } #endif - return; + return; + } } } @@ -453,21 +458,21 @@ void MatMulFunctionImplWithBlas( Out->Resize(out_original_shape); } else { // NOLINT #else - { // NOLINT - blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans, - trans_y ? CblasTrans : CblasNoTrans, - M, - N, - K, - static_cast(1), - x_data, - y_data, - static_cast(flag), - dev_ctx.template Alloc(Out), - out_batch_size, - 0, - K * N); - } + { // NOLINT + blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans, + trans_y ? 
CblasTrans : CblasNoTrans, + M, + N, + K, + static_cast(1), + x_data, + y_data, + static_cast(flag), + dev_ctx.template Alloc(Out), + out_batch_size, + 0, + K * N); + } #endif } } From 5452edc449ce39fef9605f7556d38afc1c057f02 Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Mon, 8 Dec 2025 12:31:23 +0800 Subject: [PATCH 23/55] polish --- paddle/phi/kernels/impl/matmul_kernel_impl.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/phi/kernels/impl/matmul_kernel_impl.h b/paddle/phi/kernels/impl/matmul_kernel_impl.h index f38ee4f70576fe..266f97eedaa808 100644 --- a/paddle/phi/kernels/impl/matmul_kernel_impl.h +++ b/paddle/phi/kernels/impl/matmul_kernel_impl.h @@ -160,7 +160,7 @@ void MatMulFunctionImplWithBlas( return; } else { #if defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP) - if (dev_ctx.GetPlace().type() == phi::PlaceType::GPU) { + if (std::is_same::value) { blas.CUDOT(M, X.data(), 1, Y.data(), 1, Out->data()); } else { #else From 7b8ad72fd7a4b0561db20e28f707e23e2e6ca996 Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Mon, 8 Dec 2025 12:53:58 +0800 Subject: [PATCH 24/55] polish --- paddle/phi/kernels/impl/matmul_kernel_impl.h | 40 +++++++++----------- 1 file changed, 18 insertions(+), 22 deletions(-) diff --git a/paddle/phi/kernels/impl/matmul_kernel_impl.h b/paddle/phi/kernels/impl/matmul_kernel_impl.h index 266f97eedaa808..ef658418f13c94 100644 --- a/paddle/phi/kernels/impl/matmul_kernel_impl.h +++ b/paddle/phi/kernels/impl/matmul_kernel_impl.h @@ -162,8 +162,8 @@ void MatMulFunctionImplWithBlas( #if defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP) if (std::is_same::value) { blas.CUDOT(M, X.data(), 1, Y.data(), 1, Out->data()); - } else { -#else + } else // NOLINT +#endif { blas.GEMM(CblasNoTrans, CblasTrans, @@ -176,9 +176,7 @@ void MatMulFunctionImplWithBlas( static_cast(flag), dev_ctx.template Alloc(Out)); } -#endif - return; - } + return; } } @@ -456,24 +454,22 @@ void MatMulFunctionImplWithBlas( DenseTensor transposedOut = phi::TransposeLast2Dim(dev_ctx, *Out); *Out = transposedOut; Out->Resize(out_original_shape); - } else { // NOLINT -#else - { // NOLINT - blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans, - trans_y ? CblasTrans : CblasNoTrans, - M, - N, - K, - static_cast(1), - x_data, - y_data, - static_cast(flag), - dev_ctx.template Alloc(Out), - out_batch_size, - 0, - K * N); - } + } else // NOLINT #endif + { // NOLINT + blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans, + trans_y ? 
CblasTrans : CblasNoTrans, + M, + N, + K, + static_cast(1), + x_data, + y_data, + static_cast(flag), + dev_ctx.template Alloc(Out), + out_batch_size, + 0, + K * N); } } } else if (y_batch_size == 1) { From 708db20d331d8d22686878725d83b2be0fdcc1ce Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Mon, 8 Dec 2025 14:04:25 +0800 Subject: [PATCH 25/55] bypass win32 --- paddle/phi/kernels/impl/matmul_kernel_impl.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/phi/kernels/impl/matmul_kernel_impl.h b/paddle/phi/kernels/impl/matmul_kernel_impl.h index ef658418f13c94..0ad6ee9e69dd92 100644 --- a/paddle/phi/kernels/impl/matmul_kernel_impl.h +++ b/paddle/phi/kernels/impl/matmul_kernel_impl.h @@ -159,7 +159,7 @@ void MatMulFunctionImplWithBlas( dev_ctx.template Alloc(Out)); return; } else { -#if defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP) +#if defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP) && !defined(_WIN32) if (std::is_same::value) { blas.CUDOT(M, X.data(), 1, Y.data(), 1, Out->data()); } else // NOLINT From 034a087850f6e5979909046731e9fc349f2394da Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Mon, 8 Dec 2025 16:02:34 +0800 Subject: [PATCH 26/55] stash --- paddle/fluid/pybind/eager_custom_python_api.h | 153 +---------- paddle/phi/kernels/gpu/linear_v2_kernel.cu | 257 ++++++++++++++++++ 2 files changed, 262 insertions(+), 148 deletions(-) create mode 100644 paddle/phi/kernels/gpu/linear_v2_kernel.cu diff --git a/paddle/fluid/pybind/eager_custom_python_api.h b/paddle/fluid/pybind/eager_custom_python_api.h index 0c667f32529793..bfb885c499d876 100644 --- a/paddle/fluid/pybind/eager_custom_python_api.h +++ b/paddle/fluid/pybind/eager_custom_python_api.h @@ -44,156 +44,13 @@ static PyObject *eager_api_linear(PyObject *self, if (InputsContainDistTensor(&mesh, x, weight, bias)) { ConvertAllInputsToDistTensor(mesh, x, weight, bias); } -#if (defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11000 && \ - !(defined(_WIN32) || defined(WIN32))) - if (!FLAGS_use_legacy_gemm && // NOLINT - x.place().GetType() == phi::AllocationType::GPU && - weight.place().GetType() == phi::AllocationType::GPU && - bias.place().GetType() == phi::AllocationType::GPU && - !bias.is_dist_tensor()) // NOLINT - [[likely]] { // NOLINT - // TODO(Pan Zhaowu): Add proper broadcast logic for batchsize unaligned - // batch-gemm. Currently handles: (B..., k) x (k, n) -> (B..., n), with - // 1D or scalar bias. - // --- Original input tensor dimensions and values --- - const auto &x_original_shape = x.shape(); - const size_t x_ndim_original = x_original_shape.size(); - - const auto &weight_original_shape = weight.shape(); - const size_t weight_ndim_original = weight_original_shape.size(); - - // Determine the 'k' and 'n' dimensions based on original shapes. - // These values are crucial for potential 1D reshaping and output shape - // calculation. - const int64_t k_dim = - x_original_shape[x_ndim_original - 1]; // Last dimension of X - - // --- Process 1D x and weight tensors by reshaping them to 2D if - // necessary --- Subsequent operations will use these processed tensors. - paddle::Tensor x_processed = - x; // Start with original, possibly reassign if reshaped - paddle::Tensor weight_processed = - weight; // Start with original, possibly reassign if reshaped - - // If x is 1D (e.g., shape [k]), reshape it to [1, k] to fit the (B..., - // k) x (k, n) pattern. This effectively treats a 1D vector as a row - // vector for matrix multiplication. 
- if (x_ndim_original == 1) [[unlikely]] { - x_processed = reshape_ad_func(x, {1, k_dim}); - } - // If weight is 1D (e.g., shape [n]), reshape it to [k, 1]. - // This implies 'n' was 1 in the original context, and 'k' is determined - // by x. This effectively treats a 1D vector as a column vector. Note: - // This 'else if' means if both x and weight are 1D, only x gets - // reshaped currently. For (k) x (n) where n != k and both are 1D, the - // semantics are ambiguous and not directly covered by (B..., k) x (k, - // n). The current design implies weight is at least 2D or is treated as - // [k, 1] if 1D. - else if (weight_ndim_original == 1) [[unlikely]] { // NOLINT - weight_processed = reshape_ad_func(weight, {k_dim, 1}); - } - - // --- Recalculate dimensions based on processed tensors --- - // These dimensions will be used for the actual GEMM operation. - const auto &x_shape_current = x_processed.shape(); - const size_t x_ndim_current = x_shape_current.size(); - - const auto &weight_shape_current = weight_processed.shape(); - const size_t weight_ndim_current = weight_shape_current.size(); - - // Effective 'k' and 'n' for GEMM. - const int64_t k_effective = x_shape_current[x_ndim_current - 1]; - const int64_t n_effective = - weight_shape_current[weight_ndim_current - 1]; - - // --- Determine the final output shape --- - // Start with the processed x's shape, then modify the last dimension. - std::vector output_shape_vec = x_shape_current; - output_shape_vec[x_ndim_current - 1] = n_effective; - - // If the original x was 1D, the processed x became [1, k]. - // The output_shape_vec would be [1, n]. - // For 1D input, we usually expect a 1D output (shape [n]) if possible. - if (x_ndim_original == 1 && output_shape_vec.size() > 1 && - output_shape_vec[0] == 1) { - output_shape_vec.erase( - output_shape_vec - .begin()); // Remove the artificial batch dimension - } - - // Calculate the total number of elements in the batch dimensions of X. - // This is used for reshaping X into a 2D matrix for addmm_ad_func. - const int64_t x_batch_numel = - std::accumulate(output_shape_vec.begin(), - output_shape_vec.end() - 1, - 1LL, - std::multiplies()); - - // --- Bias handling and GEMM execution --- - // The condition now uses the processed weight's shape. - // This branch typically handles (B..., k) x (k, n) where n > 1. - if (weight_shape_current[0] > 1 && weight_shape_current[1] > 1) { - paddle::Tensor bias_1d = - bias; // Create a mutable copy if modification is needed - // Align bias' shape to 'n_effective'. If bias.numel() != n_effective, - // tile it. - if (bias.numel() != n_effective) { - bias_1d = tile_ad_func(bias, {static_cast(n_effective)}); - } - // Execute fused GEMM with epilogue. - auto [out, _] = fused_gemm_epilogue_ad_func( - x_processed, weight_processed, bias_1d, false, false, "none"); - - // If original x was 1D and output_shape_vec is 1D (i.e., [n]), - // but fused_gemm_epilogue_ad_func returns a 2D tensor ([1, n]), - // reshape it back to the desired 1D output shape. - if (x_ndim_original == 1 && out.shape().size() == 2 && - output_shape_vec.size() == 1) { - out = reshape_ad_func(out, output_shape_vec); - } - - PyEval_RestoreThread(tstate); - tstate = nullptr; - return ToPyObject(out); - } else { - // This branch handles cases where weight_processed is effectively 2D - // with one dimension being 1, e.g., (B..., k) x (k, 1) resulting in - // (B..., 1). Or when weight_processed was originally 1D and reshaped - // to [k, 1]. 
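One detail of the branch being deleted above that carries over to the new kernels: the bias is only ever a scalar or a 1-D vector of length n, and when it is not already of length n it is tiled up to n before the GEMM (tile_ad_func here, phi::TileKernel later in linear_v2_kernel.cu). A self-contained sketch of that broadcast rule; the helper name is illustrative only:

    #include <cassert>
    #include <vector>

    // Broadcast a bias holding either 1 element (scalar) or n elements
    // to exactly n elements, mirroring the tile-to-n rule above.
    std::vector<float> BroadcastBiasRef(const std::vector<float>& bias, int n) {
      if (static_cast<int>(bias.size()) == n) return bias;  // already aligned
      assert(bias.size() == 1);                             // scalar bias
      return std::vector<float>(n, bias[0]);                // tile to length n
    }

    int main() {
      assert(BroadcastBiasRef({0.5f}, 4) == std::vector<float>(4, 0.5f));
      assert(BroadcastBiasRef({1.f, 2.f, 3.f}, 3).size() == 3);
      return 0;
    }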
- - // Reshape bias to [1, n_effective] then tile to [x_batch_numel, 1] - // for addmm_ad_func. - paddle::Tensor bias_2d = tile_ad_func( - reshape_ad_func(bias, {1, n_effective}), {x_batch_numel, 1}); - - // Perform matrix multiplication using addmm_ad_func. - // x_processed is reshaped to 2D [x_batch_numel, k_effective] for the - // multiplication. - auto out = addmm_ad_func( - bias_2d, - reshape_ad_func(x_processed, {x_batch_numel, k_effective}), - weight_processed, - 1.0, - 1.0); - - // Reshape the final output to the target output_shape_vec. - out = reshape_ad_func(out, output_shape_vec); - - PyEval_RestoreThread(tstate); - tstate = nullptr; - return ToPyObject(out); - } - } else // NOLINT(readability/braces) -#endif - { - auto mm_out = matmul_ad_func(x, weight, false, false); - auto out = add_ad_func(mm_out, bias); + auto mm_out = matmul_ad_func(x, weight, false, false); + auto out = add_ad_func(mm_out, bias); - PyEval_RestoreThread(tstate); - tstate = nullptr; - return ToPyObject(out); - } + PyEval_RestoreThread(tstate); + tstate = nullptr; + return ToPyObject(out); } else { const phi::distributed::ProcessMesh *mesh = nullptr; if (InputsContainDistTensor(&mesh, x, weight)) { diff --git a/paddle/phi/kernels/gpu/linear_v2_kernel.cu b/paddle/phi/kernels/gpu/linear_v2_kernel.cu new file mode 100644 index 00000000000000..adba09b011cd60 --- /dev/null +++ b/paddle/phi/kernels/gpu/linear_v2_kernel.cu @@ -0,0 +1,257 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
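For reviewing the new linear_v2 kernel introduced below: whichever dispatch path is taken (fused GEMM epilogue, addmm, or plain matmul followed by an elementwise add), the numerics it must reproduce are out[..., m, n] = sum_k x[..., m, k] * w[k, n] + bias[n], with the bias broadcast over every leading dimension. A small standalone reference in plain C++, batch dimensions folded into rows; the names are illustrative and it is only meant as a host-side check:

    #include <cassert>
    #include <cstddef>
    #include <vector>

    // Reference: out (rows x n) = x (rows x k) * w (k x n) + bias (n), row-major.
    std::vector<float> LinearRef(const std::vector<float>& x,
                                 const std::vector<float>& w,
                                 const std::vector<float>& bias,
                                 int rows, int k, int n) {
      std::vector<float> out(static_cast<std::size_t>(rows) * n, 0.f);
      for (int r = 0; r < rows; ++r) {
        for (int j = 0; j < n; ++j) {
          float acc = bias[j];  // bias broadcast over the row dimension
          for (int c = 0; c < k; ++c) {
            acc += x[r * k + c] * w[c * n + j];
          }
          out[r * n + j] = acc;
        }
      }
      return out;
    }

    int main() {
      // x = [[1, 2]], w = [[1, 0, 1], [0, 1, 1]], bias = [10, 20, 30]
      auto out = LinearRef({1, 2}, {1, 0, 1, 0, 1, 1}, {10, 20, 30}, 1, 2, 3);
      assert(out[0] == 11.f && out[1] == 22.f && out[2] == 33.f);
      return 0;
    }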
+#include +#include +#include +#include "glog/logging.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + +#ifdef PADDLE_WITH_HIP +#include +#include +#else +#include // NOLINT +#include "cuda.h" // NOLINT +#endif + +#if (defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060) || \ + defined(PADDLE_WITH_HIP) + +#include "paddle/common/flags.h" +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/common/memory_utils.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/scope_guard.h" +#include "paddle/utils/optional.h" +#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060 +#include "paddle/phi/backends/dynload/cublasLt.h" +#include "paddle/phi/backends/gpu/cuda/cuda_helper.h" +#include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h" +#elif defined(PADDLE_WITH_HIP) +#include "paddle/phi/backends/dynload/hipblasLt.h" +#include "paddle/phi/backends/gpu/rocm/rocm_helper.h" +#include "paddle/phi/kernels/funcs/blas/blaslt_impl.hip.h" +#endif +namespace phi { + +/* +#if (defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11000 && \ + !(defined(_WIN32) || defined(WIN32))) + if (!FLAGS_use_legacy_gemm && // NOLINT + x.place().GetType() == phi::AllocationType::GPU && + weight.place().GetType() == phi::AllocationType::GPU && + bias.place().GetType() == phi::AllocationType::GPU && + !bias.is_dist_tensor()) // NOLINT + [[likely]] { // NOLINT + // TODO(Pan Zhaowu): Add proper broadcast logic for batchsize unaligned + // batch-gemm. Currently handles: (B..., k) x (k, n) -> (B..., n), with + // 1D or scalar bias. + + // --- Original input tensor dimensions and values --- + const auto &x_original_shape = x.shape(); + const size_t x_ndim_original = x_original_shape.size(); + + const auto &weight_original_shape = weight.shape(); + const size_t weight_ndim_original = weight_original_shape.size(); + + const int64_t k_dim = + x_original_shape[x_ndim_original - 1]; // Last dimension of X + + paddle::Tensor x_processed = + x; // Start with original, possibly reassign if reshaped + paddle::Tensor weight_processed = + weight; // Start with original, possibly reassign if reshaped + + // If x is 1D (e.g., shape [k]), reshape it to [1, k] to fit the (B..., + // k) x (k, n) pattern. This effectively treats a 1D vector as a row + // vector for matrix multiplication. + if (x_ndim_original == 1) [[unlikely]] { + x_processed = reshape_ad_func(x, {1, k_dim}); + } else if (weight_ndim_original == 1) [[unlikely]] { // NOLINT + weight_processed = reshape_ad_func(weight, {k_dim, 1}); + } + + // --- Recalculate dimensions based on processed tensors --- + const auto &x_shape_current = x_processed.shape(); + const size_t x_ndim_current = x_shape_current.size(); + + const auto &weight_shape_current = weight_processed.shape(); + const size_t weight_ndim_current = weight_shape_current.size(); + + const int64_t k_effective = x_shape_current[x_ndim_current - 1]; + const int64_t n_effective = + weight_shape_current[weight_ndim_current - 1]; + + // --- Determine the final output shape --- + std::vector output_shape_vec = x_shape_current; + output_shape_vec[x_ndim_current - 1] = n_effective; + + // If the original x was 1D, the processed x became [1, k]. 
+ if (x_ndim_original == 1 && output_shape_vec.size() > 1 && + output_shape_vec[0] == 1) { + output_shape_vec.erase( + output_shape_vec + .begin()); // Remove the artificial batch dimension + } + + // This is used for reshaping X into a 2D matrix for addmm_ad_func. + const int64_t x_batch_numel = + std::accumulate(output_shape_vec.begin(), + output_shape_vec.end() - 1, + 1LL, + std::multiplies()); + + // --- Bias handling and GEMM execution --- + // The condition now uses the processed weight's shape. + // This branch typically handles (B..., k) x (k, n) where n > 1. + if (weight_shape_current[0] > 1 && weight_shape_current[1] > 1) { + paddle::Tensor bias_1d = + bias; // Create a mutable copy if modification is needed + // Align bias' shape to 'n_effective'. If bias.numel() != n_effective, + // tile it. + if (bias.numel() != n_effective) { + bias_1d = tile_ad_func(bias, {static_cast(n_effective)}); + } + // Execute fused GEMM with epilogue. + auto [out, _] = fused_gemm_epilogue_ad_func( + x_processed, weight_processed, bias_1d, false, false, "none"); + + // If original x was 1D and output_shape_vec is 1D (i.e., [n]), + // but fused_gemm_epilogue_ad_func returns a 2D tensor ([1, n]), + // reshape it back to the desired 1D output shape. + if (x_ndim_original == 1 && out.shape().size() == 2 && + output_shape_vec.size() == 1) { + out = reshape_ad_func(out, output_shape_vec); + } + + PyEval_RestoreThread(tstate); + tstate = nullptr; + return ToPyObject(out); + } else { + // This branch handles cases where weight_processed is effectively 2D + // with one dimension being 1, e.g., (B..., k) x (k, 1) resulting in + // (B..., 1). Or when weight_processed was originally 1D and reshaped + // to [k, 1]. + + // Reshape bias to [1, n_effective] then tile to [x_batch_numel, 1] + // for addmm_ad_func. + paddle::Tensor bias_2d = tile_ad_func( + reshape_ad_func(bias, {1, n_effective}), {x_batch_numel, 1}); + + // Perform matrix multiplication using addmm_ad_func. + // x_processed is reshaped to 2D [x_batch_numel, k_effective] for the + // multiplication. + auto out = addmm_ad_func( + bias_2d, + reshape_ad_func(x_processed, {x_batch_numel, k_effective}), + weight_processed, + 1.0, + 1.0); + + out = reshape_ad_func(out, output_shape_vec); + + PyEval_RestoreThread(tstate); + tstate = nullptr; + return ToPyObject(out); + } + } else // NOLINT(readability/braces) +#endif +*/ +// we don't receive 2+d tensor as weight +inline std::tuple canonicalize_dims( + const DenseTensor& input, const DenseTensor& weight) { + const auto x_dims = input.dims(); + const auto y_dims = weight.dims(); + PADDLE_ENFORCE_LE_GE( + y_dims.size(), + 2, + platform::errors::InvalidArgument("Y must be at most 2D")); + + const int64_t N = y_dims.size() < 2 ? 1 : y_dims[y_dims.size() - 1]; + const int64_t K = y_dims.size() < 2 ? y_dims[0] : y_dims[y_dims.size() - 2]; + + int64_t M = x_dims.size() >= 2 ? 
x_dims[x_dims.size() - 2] : 1; + if (x_dims.size() > 2) { + // Accumulate the batch dims for input + for (int64_t i = 0; i < x_dims.size() - 2; ++i) { + M *= x_dims[i]; + } + } + + return {M, N, K}; +} +template +void LinearV2Kernel(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& weight, + const DenseTensor& bias, + DenseTensor* out) { + PADDLE_ENFORCE_LE(bias.dims().size(), + 1, + platform::errors::InvalidArgument("Bias must be 1D")); + if (out->numel() == 0) { + dev_ctx.template Alloc(out); + return; + } +#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION < 11060 || \ + defined(PADDLE_WITH_HIP) + // NOTE(Pan Zhaowu): Fallback logic for legacy CUDA version or DCU. + // TODO(Pan Zhaowu): Implement this + return; +#else + dev_ctx.template Alloc(out, out->numel() * sizeof(T)); + const auto [M, N, K] = canonicalize_dims(input, weight); + + if (bias.numel() != N) { + // only broadcast to 1D bias whatsoever + // pass1: scalar to 1D + } + if (N > 1 && K > 1) { + // CublasLt path with bias add epilogue + phi::funcs::LinearWithCublasLt::Run( + dev_ctx, + &input, + &weight, + out, + static_cast(bias.data()), + nullptr, + M, + N, + K, + false, + false, + phi::funcs::MatmulFusedType::kMatmulBias); + } else { + // Cublas path with beta==1 bias adding. + blas.GEMM(dev_ctx, ) + } +#endif +} + +} // namespace phi + +PD_REGISTER_KERNEL(linear_v2, + GPU, + ALL_LAYOUT, + phi::LinearV2Kernel, + float, + double, + phi::float16, + phi::bfloat16) {} From a98a620aff3130782d041fbe44e8836e6c063320 Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Wed, 10 Dec 2025 11:14:40 +0800 Subject: [PATCH 27/55] stash --- paddle/phi/kernels/gpu/linear_v2_kernel.cu | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/paddle/phi/kernels/gpu/linear_v2_kernel.cu b/paddle/phi/kernels/gpu/linear_v2_kernel.cu index adba09b011cd60..314fc47b306a85 100644 --- a/paddle/phi/kernels/gpu/linear_v2_kernel.cu +++ b/paddle/phi/kernels/gpu/linear_v2_kernel.cu @@ -206,6 +206,9 @@ void LinearV2Kernel(const Context& dev_ctx, PADDLE_ENFORCE_LE(bias.dims().size(), 1, platform::errors::InvalidArgument("Bias must be 1D")); + PADDLE_ENFORCE_LE(weight.dims().size(), + 2, + platform::errors::InvalidArgument("Weight must be 2D")); if (out->numel() == 0) { dev_ctx.template Alloc(out); return; @@ -222,6 +225,7 @@ void LinearV2Kernel(const Context& dev_ctx, if (bias.numel() != N) { // only broadcast to 1D bias whatsoever // pass1: scalar to 1D + phi::Tile(dev_ctx, bias, {N}, out); } if (N > 1 && K > 1) { // CublasLt path with bias add epilogue From 028ae960082ad1c3d46d236067b972866738e58b Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Thu, 25 Dec 2025 19:59:33 +0800 Subject: [PATCH 28/55] align fwd in gpu --- paddle/common/flags.cc | 12 + paddle/phi/infermeta/backward.cc | 97 ++++++ paddle/phi/infermeta/backward.h | 8 + paddle/phi/infermeta/ternary.cc | 68 +++++ paddle/phi/infermeta/ternary.h | 5 + .../phi/kernels/gpu/linear_v2_grad_kernel.cu | 275 ++++++++++++++++++ paddle/phi/kernels/gpu/linear_v2_kernel.cu | 262 ++++++----------- paddle/phi/ops/yaml/backward.yaml | 9 + paddle/phi/ops/yaml/ops.yaml | 10 + python/paddle/nn/functional/common.py | 82 +++--- 10 files changed, 626 insertions(+), 202 deletions(-) create mode 100644 paddle/phi/kernels/gpu/linear_v2_grad_kernel.cu diff --git a/paddle/common/flags.cc b/paddle/common/flags.cc index cfefab22493701..1755e8fcf0d311 100644 --- a/paddle/common/flags.cc +++ b/paddle/common/flags.cc @@ -2379,6 +2379,18 @@ PHI_DEFINE_EXPORTED_bool(use_legacy_gemm, false, "Whether use legacy gemm 
dispatch logics."); +/** + * Legacy gemm related FLAG + * Name: FLAGS_use_legacy_linear + * Since Version: 3.2.2 + * Value Range: bool, default=false + * Example: + * Note: Whether use legacy linear kernel. + */ +PHI_DEFINE_EXPORTED_bool(use_legacy_linear, + false, + "Whether use legacy linear dispatch logics."); + /** * Allocator Compact related FLAG * Name: FLAGS_enable_compact_mem diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 94750577a5debc..0c2a4eb9b1dd38 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -492,6 +492,103 @@ void CudnnLSTMGradInferMeta( UnchangedMultiInferMeta(weight_list.get(), weight_list_grad); } } +void LinearV2GradInferMeta(const MetaTensor& input, + const MetaTensor& weight, + const MetaTensor& bias, + const MetaTensor& out_grad, + MetaTensor* input_grad, + MetaTensor* weight_grad, + MetaTensor* bias_grad) { + auto input_dims = input.dims(); + auto weight_dims = weight.dims(); + auto bias_dims = bias.dims(); + auto dout_dims = out_grad.dims(); + + PADDLE_ENFORCE_GE(dout_dims.size(), + 2, + common::errors::InvalidArgument( + "The Input tensor DOut's dimension of LinearV2Op " + " should be >= 2, but got %d.", + dout_dims.size())); + + PADDLE_ENFORCE_EQ(weight_dims.size(), + 2, + common::errors::InvalidArgument( + "The Input tensor Y's dimension of LinearV2Op " + " should be 2, but got %d.", + weight_dims.size())); + + PADDLE_ENFORCE_GE(input_dims.size(), + 2, + common::errors::InvalidArgument( + "The Input tensor X's dimension of LinearV2Op " + " should be >= 2, but got %d.", + input_dims.size())); + + PADDLE_ENFORCE_EQ( + dout_dims.size(), + input_dims.size(), + common::errors::InvalidArgument( + "The Input tensor DOut's and X's dimension of " + "LinearV2Op " + " should be the same, but got DOut's dim = %d and X's = %d.", + dout_dims.size(), + input_dims.size())); + + auto dout_mat_dims = common::flatten_to_2d(dout_dims, dout_dims.size() - 1); + + PADDLE_ENFORCE_EQ( + dout_mat_dims[1], + weight_dims[1], + common::errors::InvalidArgument( + "The last dimension of DOut should be equal with Y's last " + "dimension. But received DOut[-1] = [%d], Y[1] = [%d].", + dout_mat_dims[1], + weight_dims[1])); + + for (int32_t i = 0; i + 2 < input_dims.size(); ++i) { + if (dout_dims[i] > 0 && input_dims[i] > 0) { + PADDLE_ENFORCE_EQ( + dout_dims[i], + input_dims[i], + common::errors::InvalidArgument( + "The i dimension of DOut should be equal with i dimension of X." 
+ "But received DOut[%d] = [%d], Y[%d] = [%d].", + i, + dout_dims[i], + i, + input_dims[i])); + } + } + + auto k_from_dout = dout_dims[input_dims.size() - 2]; + auto k_from_input = input_dims[input_dims.size() - 2]; + + bool check_k = + (k_from_dout < 0 || k_from_input < 0) || (k_from_dout == k_from_input); + PADDLE_ENFORCE_EQ( + check_k, + true, + common::errors::InvalidArgument( + "K from dout and x is not same, k_from_dout is [%d], k_from_x is[%d]", + k_from_dout, + k_from_input)); + + if (input_grad) { + input_grad->set_dims(input_dims); + input_grad->set_dtype(input.dtype()); + } + + if (weight_grad) { + weight_grad->set_dims(weight_dims); + weight_grad->set_dtype(weight.dtype()); + } + + if (bias_grad) { + bias_grad->set_dims(bias_dims); + bias_grad->set_dtype(bias.dtype()); + } +} void LSTMGradInferMeta(const MetaTensor& input, const MetaTensor& h0, diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index d92cfb5ca9592d..f96ded2c536b16 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -183,6 +183,14 @@ PADDLE_API void CudnnLSTMGradInferMeta( MetaTensor* init_c_grad, std::vector weight_list_grad); +PADDLE_API void LinearV2GradInferMeta(const MetaTensor& input, + const MetaTensor& weight, + const MetaTensor& bias, + const MetaTensor& out_grad, + MetaTensor* input_grad, + MetaTensor* weight_grad, + MetaTensor* bias_grad); + PADDLE_API void LSTMGradInferMeta(const MetaTensor& input, const MetaTensor& h0, const MetaTensor& c0, diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index 20ac55035f6fcc..afc997b24b2680 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -1550,6 +1550,74 @@ void LerpInferMeta(const MetaTensor& x, out->share_lod(x); } +void LinearV2InferMeta(const MetaTensor& input, + const MetaTensor& weight, + const MetaTensor& bias, + MetaTensor* out) { + const auto& input_dims = input.dims(); + const auto& weight_dims = weight.dims(); + const int64_t weight_ndim = weight.dims().size(); + const bool is_bias_need_broadcast = bias.numel() == 1; + const bool is_valid_bias = + is_bias_need_broadcast || bias.numel() == weight.dims()[weight_ndim - 1]; + + PADDLE_ENFORCE_EQ(weight_dims.size(), + 2, + common::errors::InvalidArgument( + "The Input tensor Y's dimension of FusedGemmEpilogueOp " + " should be 2, but got %d.", + weight_dims.size())); + PADDLE_ENFORCE_GE(input_dims.size(), + 1, + common::errors::InvalidArgument( + "The Input tensor X's dimension of FusedGemmEpilogueOp " + " should be >= 1, but got %d.", + input_dims.size())); + PADDLE_ENFORCE_LE( + bias.dims().size(), + 1, + common::errors::InvalidArgument("Bias must be lesser than 1D")); + + PADDLE_ENFORCE_EQ(is_valid_bias, + true, + common::errors::InvalidArgument( + "Bias must be equal (or can be broadcasted) to the " + "last dimension of weight")); + + // regard [k] x [k, n] -> [n] + if (input_dims.size() == 1) { + out->set_dims(common::make_ddim({weight_dims[1]})); + out->set_dtype(input.dtype()); + return; + } + + auto input_mat_dims = + common::flatten_to_2d(input_dims, input_dims.size() - 1); + + auto input_rank = input_dims.size(); + int64_t K_from_input = input_mat_dims[1]; + int64_t K_from_weight = weight_dims[0]; + PADDLE_ENFORCE_EQ( + K_from_input, + K_from_weight, + common::errors::InvalidArgument( + "The last dimension of X should be equal with Y's first dimension." 
+ "But received X[-1] = [%d], Y[0] = [%d].", + K_from_input, + K_from_weight)); + std::vector out_dims; + out_dims.reserve(input_rank); + + for (int i = 0; i + 2 < input_rank; ++i) { + out_dims.push_back(input_dims[i]); + } + out_dims.push_back(input_dims[input_rank - 2]); + + out_dims.push_back(weight_dims[1]); + out->set_dims(common::make_ddim(out_dims)); + out->set_dtype(input.dtype()); +} + void LinspaceInferMeta(const MetaTensor& start, const MetaTensor& stop, const MetaTensor& number, diff --git a/paddle/phi/infermeta/ternary.h b/paddle/phi/infermeta/ternary.h index d3b35e5462e49d..587e18f1324f16 100644 --- a/paddle/phi/infermeta/ternary.h +++ b/paddle/phi/infermeta/ternary.h @@ -274,6 +274,11 @@ PADDLE_API void LerpInferMeta(const MetaTensor& x, const MetaTensor& weight, MetaTensor* out); +PADDLE_API void LinearV2InferMeta(const MetaTensor& input, + const MetaTensor& weight, + const MetaTensor& bias, + MetaTensor* out); + PADDLE_API void LinspaceRawInferMeta(const MetaTensor& start, const MetaTensor& stop, const MetaTensor& number, diff --git a/paddle/phi/kernels/gpu/linear_v2_grad_kernel.cu b/paddle/phi/kernels/gpu/linear_v2_grad_kernel.cu new file mode 100644 index 00000000000000..4dcf98517930ca --- /dev/null +++ b/paddle/phi/kernels/gpu/linear_v2_grad_kernel.cu @@ -0,0 +1,275 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#include +#include +#include +#include "glog/logging.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/funcs/matrix_reduce.h" +#include "paddle/phi/kernels/impl/matmul_grad_kernel_impl.h" + +#ifdef PADDLE_WITH_HIP +#include +#include +#else +#include // NOLINT +#include "cuda.h" // NOLINT +#endif + +#if (defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060) || \ + defined(PADDLE_WITH_HIP) + +#include "paddle/common/flags.h" +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/common/memory_utils.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/scope_guard.h" +#include "paddle/utils/optional.h" +#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060 +#include "paddle/phi/backends/dynload/cublasLt.h" +#include "paddle/phi/backends/gpu/cuda/cuda_helper.h" +#include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h" +#elif defined(PADDLE_WITH_HIP) +#include "paddle/phi/backends/dynload/hipblasLt.h" +#include "paddle/phi/backends/gpu/rocm/rocm_helper.h" +#include "paddle/phi/kernels/funcs/blas/blaslt_impl.hip.h" +#endif + +#endif +namespace phi { + +/* +#if (defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11000 && \ + !(defined(_WIN32) || defined(WIN32))) + if (!FLAGS_use_legacy_gemm && // NOLINT + x.place().GetType() == phi::AllocationType::GPU && + weight.place().GetType() == phi::AllocationType::GPU && + bias.place().GetType() == phi::AllocationType::GPU && + !bias.is_dist_tensor()) // NOLINT + [[likely]] { // NOLINT + // TODO(Pan Zhaowu): Add proper broadcast logic for batchsize unaligned + // batch-gemm. Currently handles: (B..., k) x (k, n) -> (B..., n), with + // 1D or scalar bias. + + // --- Original input tensor dimensions and values --- + const auto &x_original_shape = x.shape(); + const size_t x_ndim_original = x_original_shape.size(); + + const auto &weight_original_shape = weight.shape(); + const size_t weight_ndim_original = weight_original_shape.size(); + + const int64_t k_dim = + x_original_shape[x_ndim_original - 1]; // Last dimension of X + + paddle::Tensor x_processed = + x; // Start with original, possibly reassign if reshaped + paddle::Tensor weight_processed = + weight; // Start with original, possibly reassign if reshaped + + // If x is 1D (e.g., shape [k]), reshape it to [1, k] to fit the (B..., + // k) x (k, n) pattern. This effectively treats a 1D vector as a row + // vector for matrix multiplication. + if (x_ndim_original == 1) [[unlikely]] { + x_processed = reshape_ad_func(x, {1, k_dim}); + } else if (weight_ndim_original == 1) [[unlikely]] { // NOLINT + weight_processed = reshape_ad_func(weight, {k_dim, 1}); + } + + // --- Recalculate dimensions based on processed tensors --- + const auto &x_shape_current = x_processed.shape(); + const size_t x_ndim_current = x_shape_current.size(); + + const auto &weight_shape_current = weight_processed.shape(); + const size_t weight_ndim_current = weight_shape_current.size(); + + const int64_t k_effective = x_shape_current[x_ndim_current - 1]; + const int64_t n_effective = + weight_shape_current[weight_ndim_current - 1]; + + // --- Determine the final output shape --- + std::vector output_shape_vec = x_shape_current; + output_shape_vec[x_ndim_current - 1] = n_effective; + + // If the original x was 1D, the processed x became [1, k]. 
+ if (x_ndim_original == 1 && output_shape_vec.size() > 1 && + output_shape_vec[0] == 1) { + output_shape_vec.erase( + output_shape_vec + .begin()); // Remove the artificial batch dimension + } + + // This is used for reshaping X into a 2D matrix for addmm_ad_func. + const int64_t x_batch_numel = + std::accumulate(output_shape_vec.begin(), + output_shape_vec.end() - 1, + 1LL, + std::multiplies()); + + // --- Bias handling and GEMM execution --- + // The condition now uses the processed weight's shape. + // This branch typically handles (B..., k) x (k, n) where n > 1. + if (weight_shape_current[0] > 1 && weight_shape_current[1] > 1) { + paddle::Tensor bias_1d = + bias; // Create a mutable copy if modification is needed + // Align bias' shape to 'n_effective'. If bias.numel() != n_effective, + // tile it. + if (bias.numel() != n_effective) { + bias_1d = tile_ad_func(bias, {static_cast(n_effective)}); + } + // Execute fused GEMM with epilogue. + auto [out, _] = fused_gemm_epilogue_ad_func( + x_processed, weight_processed, bias_1d, false, false, "none"); + + // If original x was 1D and output_shape_vec is 1D (i.e., [n]), + // but fused_gemm_epilogue_ad_func returns a 2D tensor ([1, n]), + // reshape it back to the desired 1D output shape. + if (x_ndim_original == 1 && out.shape().size() == 2 && + output_shape_vec.size() == 1) { + out = reshape_ad_func(out, output_shape_vec); + } + + PyEval_RestoreThread(tstate); + tstate = nullptr; + return ToPyObject(out); + } else { + // This branch handles cases where weight_processed is effectively 2D + // with one dimension being 1, e.g., (B..., k) x (k, 1) resulting in + // (B..., 1). Or when weight_processed was originally 1D and reshaped + // to [k, 1]. + + // Reshape bias to [1, n_effective] then tile to [x_batch_numel, 1] + // for addmm_ad_func. + paddle::Tensor bias_2d = tile_ad_func( + reshape_ad_func(bias, {1, n_effective}), {x_batch_numel, 1}); + + // Perform matrix multiplication using addmm_ad_func. + // x_processed is reshaped to 2D [x_batch_numel, k_effective] for the + // multiplication. + auto out = addmm_ad_func( + bias_2d, + reshape_ad_func(x_processed, {x_batch_numel, k_effective}), + weight_processed, + 1.0, + 1.0); + + out = reshape_ad_func(out, output_shape_vec); + + PyEval_RestoreThread(tstate); + tstate = nullptr; + return ToPyObject(out); + } + } else // NOLINT(readability/braces) +#endif +// we don't receive 2+d tensor as weight +inline std::tuple canonicalize_dims( + const DenseTensor& input, const DenseTensor& weight) { + const auto x_dims = input.dims(); + const auto y_dims = weight.dims(); + PADDLE_ENFORCE_LE_GE( + y_dims.size(), + 2, + platform::errors::InvalidArgument("Y must be at most 2D")); + + const int64_t N = y_dims.size() < 2 ? 1 : y_dims[y_dims.size() - 1]; + const int64_t K = y_dims.size() < 2 ? y_dims[0] : y_dims[y_dims.size() - 2]; + + int64_t M = x_dims.size() >= 2 ? 
x_dims[x_dims.size() - 2] : 1; + if (x_dims.size() > 2) { + // Accumulate the batch dims for input + for (int64_t i = 0; i < x_dims.size() - 2; ++i) { + M *= x_dims[i]; + } + } + + return {M, N, K}; +} +*/ +template +void LinearV2GradKernel(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& weight, + const DenseTensor& bias, + const DenseTensor& out_grad, + DenseTensor* input_grad, + DenseTensor* weight_grad, + DenseTensor* bias_grad) { + /* + reduce bias with dummy add_grad, perform matmul grad, reshape grad dim + #if defined(PADDLE_WITH_CUDA) && CUDA_VERSION > 11060 && + !defined(PADDLE_WITH_HIP) && !defined(_WIN32) dev_ctx.template Alloc(out, + out->numel() * sizeof(T)); const auto [M, N, K] = canonicalize_dims(input, + weight); + + if (bias.numel() != N) { + // only broadcast to 1D bias whatsoever + // pass1: scalar to 1D + phi::Tile(dev_ctx, bias, {N}, out); + } + if (N > 1 && K > 1) { + // CublasLt path with bias add epilogue + phi::funcs::LinearWithCublasLt::Run( + dev_ctx, + &input, + &weight, + out, + static_cast(bias.data()), + nullptr, + M, + N, + K, + false, + false, + phi::funcs::MatmulFusedType::kMatmulBias); + } else { + // Cublas path with beta==1 bias adding. + blas.GEMM(dev_ctx, ) + } + #else + */ + phi::MatmulGradKernel( + dev_ctx, input, weight, out_grad, false, false, input_grad, weight_grad); + /* + if (dy_bst.dims() == y.dims()) { + Copy(dev_ctx, dy_bst, dev_ctx.GetPlace(), false, dy); + } else { + funcs::MatrixReduceSumFunctor functor; + functor(dev_ctx, dy_bst, dy); + dy->Resize(y.dims()); + } + */ + dev_ctx.template Alloc(bias_grad); + if (out_grad.dims() != bias.dims()) { + funcs::MatrixReduceSumFunctor functor; + functor(dev_ctx, out_grad, bias_grad); + bias_grad->Resize(bias.dims()); + } else { + phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, bias_grad); + } + // #endif +} + +} // namespace phi + +PD_REGISTER_KERNEL(linear_v2_grad, + GPU, + ALL_LAYOUT, + phi::LinearV2GradKernel, + float, + double, + phi::float16, + phi::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/linear_v2_kernel.cu b/paddle/phi/kernels/gpu/linear_v2_kernel.cu index 314fc47b306a85..f142779abaf84a 100644 --- a/paddle/phi/kernels/gpu/linear_v2_kernel.cu +++ b/paddle/phi/kernels/gpu/linear_v2_kernel.cu @@ -19,7 +19,12 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_utils.h" -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +#include "paddle/common/enforce.h" +#include "paddle/phi/kernels/addmm_kernel.h" +#include "paddle/phi/kernels/elementwise_add_kernel.h" +#include "paddle/phi/kernels/impl/matmul_kernel_impl.h" +#include "paddle/phi/kernels/reshape_kernel.h" +#include "paddle/phi/kernels/tile_kernel.h" #ifdef PADDLE_WITH_HIP #include @@ -49,141 +54,22 @@ #include "paddle/phi/backends/gpu/rocm/rocm_helper.h" #include "paddle/phi/kernels/funcs/blas/blaslt_impl.hip.h" #endif -namespace phi { - -/* -#if (defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11000 && \ - !(defined(_WIN32) || defined(WIN32))) - if (!FLAGS_use_legacy_gemm && // NOLINT - x.place().GetType() == phi::AllocationType::GPU && - weight.place().GetType() == phi::AllocationType::GPU && - bias.place().GetType() == phi::AllocationType::GPU && - !bias.is_dist_tensor()) // NOLINT - [[likely]] { // NOLINT - // TODO(Pan Zhaowu): Add proper broadcast logic for batchsize unaligned - // batch-gemm. Currently handles: (B..., k) x (k, n) -> (B..., n), with - // 1D or scalar bias. 
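For reference, with the input flattened to X (M x K), W (K x N) and dOut (M x N), the gradients the kernel above has to produce are the standard linear-layer ones: dX = dOut * W^T with shape (M x K), dW = X^T * dOut with shape (K x N), and db[n] = sum over m of dOut[m][n]. In other words, the bias gradient is dOut reduced over every axis except the last, which is what the MatrixReduceSumFunctor / copy branch computes.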
- - // --- Original input tensor dimensions and values --- - const auto &x_original_shape = x.shape(); - const size_t x_ndim_original = x_original_shape.size(); - - const auto &weight_original_shape = weight.shape(); - const size_t weight_ndim_original = weight_original_shape.size(); - - const int64_t k_dim = - x_original_shape[x_ndim_original - 1]; // Last dimension of X - - paddle::Tensor x_processed = - x; // Start with original, possibly reassign if reshaped - paddle::Tensor weight_processed = - weight; // Start with original, possibly reassign if reshaped - - // If x is 1D (e.g., shape [k]), reshape it to [1, k] to fit the (B..., - // k) x (k, n) pattern. This effectively treats a 1D vector as a row - // vector for matrix multiplication. - if (x_ndim_original == 1) [[unlikely]] { - x_processed = reshape_ad_func(x, {1, k_dim}); - } else if (weight_ndim_original == 1) [[unlikely]] { // NOLINT - weight_processed = reshape_ad_func(weight, {k_dim, 1}); - } - - // --- Recalculate dimensions based on processed tensors --- - const auto &x_shape_current = x_processed.shape(); - const size_t x_ndim_current = x_shape_current.size(); - - const auto &weight_shape_current = weight_processed.shape(); - const size_t weight_ndim_current = weight_shape_current.size(); - - const int64_t k_effective = x_shape_current[x_ndim_current - 1]; - const int64_t n_effective = - weight_shape_current[weight_ndim_current - 1]; - - // --- Determine the final output shape --- - std::vector output_shape_vec = x_shape_current; - output_shape_vec[x_ndim_current - 1] = n_effective; - - // If the original x was 1D, the processed x became [1, k]. - if (x_ndim_original == 1 && output_shape_vec.size() > 1 && - output_shape_vec[0] == 1) { - output_shape_vec.erase( - output_shape_vec - .begin()); // Remove the artificial batch dimension - } - - // This is used for reshaping X into a 2D matrix for addmm_ad_func. - const int64_t x_batch_numel = - std::accumulate(output_shape_vec.begin(), - output_shape_vec.end() - 1, - 1LL, - std::multiplies()); - - // --- Bias handling and GEMM execution --- - // The condition now uses the processed weight's shape. - // This branch typically handles (B..., k) x (k, n) where n > 1. - if (weight_shape_current[0] > 1 && weight_shape_current[1] > 1) { - paddle::Tensor bias_1d = - bias; // Create a mutable copy if modification is needed - // Align bias' shape to 'n_effective'. If bias.numel() != n_effective, - // tile it. - if (bias.numel() != n_effective) { - bias_1d = tile_ad_func(bias, {static_cast(n_effective)}); - } - // Execute fused GEMM with epilogue. - auto [out, _] = fused_gemm_epilogue_ad_func( - x_processed, weight_processed, bias_1d, false, false, "none"); - - // If original x was 1D and output_shape_vec is 1D (i.e., [n]), - // but fused_gemm_epilogue_ad_func returns a 2D tensor ([1, n]), - // reshape it back to the desired 1D output shape. - if (x_ndim_original == 1 && out.shape().size() == 2 && - output_shape_vec.size() == 1) { - out = reshape_ad_func(out, output_shape_vec); - } - PyEval_RestoreThread(tstate); - tstate = nullptr; - return ToPyObject(out); - } else { - // This branch handles cases where weight_processed is effectively 2D - // with one dimension being 1, e.g., (B..., k) x (k, 1) resulting in - // (B..., 1). Or when weight_processed was originally 1D and reshaped - // to [k, 1]. - - // Reshape bias to [1, n_effective] then tile to [x_batch_numel, 1] - // for addmm_ad_func. 
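The addmm_ad_func referenced in the comment being removed here (and the AddmmKernel used by the replacement code) follows the usual addmm convention out = beta * input + alpha * (x @ y). With alpha = beta = 1.0 and input set to the bias tiled to the output's shape, that is exactly x @ W + b, which is why this path is used for the degenerate shapes (n == 1 or k == 1) that the fused-epilogue path does not cover.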
- paddle::Tensor bias_2d = tile_ad_func( - reshape_ad_func(bias, {1, n_effective}), {x_batch_numel, 1}); - - // Perform matrix multiplication using addmm_ad_func. - // x_processed is reshaped to 2D [x_batch_numel, k_effective] for the - // multiplication. - auto out = addmm_ad_func( - bias_2d, - reshape_ad_func(x_processed, {x_batch_numel, k_effective}), - weight_processed, - 1.0, - 1.0); +#endif +COMMON_DECLARE_bool(use_legacy_linear); - out = reshape_ad_func(out, output_shape_vec); +namespace phi { - PyEval_RestoreThread(tstate); - tstate = nullptr; - return ToPyObject(out); - } - } else // NOLINT(readability/braces) -#endif +/* + NOTE(Pan Zhaowu): There's a promise from API level that bias is exist, + and always(or always can be broadcasted) to be equal to the output's last dim. */ + // we don't receive 2+d tensor as weight inline std::tuple canonicalize_dims( const DenseTensor& input, const DenseTensor& weight) { const auto x_dims = input.dims(); const auto y_dims = weight.dims(); - PADDLE_ENFORCE_LE_GE( - y_dims.size(), - 2, - platform::errors::InvalidArgument("Y must be at most 2D")); - const int64_t N = y_dims.size() < 2 ? 1 : y_dims[y_dims.size() - 1]; const int64_t K = y_dims.size() < 2 ? y_dims[0] : y_dims[y_dims.size() - 2]; @@ -197,56 +83,102 @@ inline std::tuple canonicalize_dims( return {M, N, K}; } + template void LinearV2Kernel(const Context& dev_ctx, const DenseTensor& input, const DenseTensor& weight, const DenseTensor& bias, DenseTensor* out) { - PADDLE_ENFORCE_LE(bias.dims().size(), - 1, - platform::errors::InvalidArgument("Bias must be 1D")); - PADDLE_ENFORCE_LE(weight.dims().size(), - 2, - platform::errors::InvalidArgument("Weight must be 2D")); if (out->numel() == 0) { dev_ctx.template Alloc(out); return; } -#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION < 11060 || \ - defined(PADDLE_WITH_HIP) - // NOTE(Pan Zhaowu): Fallback logic for legacy CUDA version or DCU. - // TODO(Pan Zhaowu): Implement this - return; -#else - dev_ctx.template Alloc(out, out->numel() * sizeof(T)); - const auto [M, N, K] = canonicalize_dims(input, weight); - if (bias.numel() != N) { - // only broadcast to 1D bias whatsoever - // pass1: scalar to 1D - phi::Tile(dev_ctx, bias, {N}, out); - } - if (N > 1 && K > 1) { - // CublasLt path with bias add epilogue - phi::funcs::LinearWithCublasLt::Run( - dev_ctx, - &input, - &weight, - out, - static_cast(bias.data()), - nullptr, - M, - N, - K, - false, - false, - phi::funcs::MatmulFusedType::kMatmulBias); - } else { - // Cublas path with beta==1 bias adding. 
- blas.GEMM(dev_ctx, ) - } +// broadcast bias, reshape input, run_fuse, reshape output +#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION > 11060 && \ + !defined(PADDLE_WITH_HIP) && !defined(_WIN32) + if (!FLAGS_use_legacy_linear) { + VLOG(3) << "Use LinearV2Kernel with cublaslt"; + dev_ctx.template Alloc(out); + const auto out_dim_original = out->dims(); + const auto [M, N, K] = canonicalize_dims(input, weight); + VLOG(3) << "M: " << M << ", N: " << N << ", K: " << K; + + DenseTensor input_processed; + DenseTensor weight_processed; + DenseTensor output_processed; + phi::ReshapeKernel(dev_ctx, input, {M, K}, &input_processed); + phi::ReshapeKernel(dev_ctx, weight, {K, N}, &weight_processed); + out->Resize(common::make_ddim({M, N})); + VLOG(3) << "input_processed: " << input_processed.dims() + << ", weight_processed: " << weight_processed.dims() + << ", output_processed: " << out->dims(); + + if (N > 1 && K > 1) { + DenseTensor bias_processed; + if (bias.numel() != N) { + // only broadcast to 1D bias whatsoever + // pass1: scalar to 1D + phi::TileKernel(dev_ctx, bias, {N}, &bias_processed); + } else { + bias_processed = bias; + } + // CublasLt path with bias add epilogue + phi::funcs::LinearWithCublasLt::Run( + dev_ctx, + &input, + &weight, + out, + static_cast(bias_processed.data()), + nullptr, + M, + N, + K, + false, + false, + phi::funcs::MatmulFusedType::kMatmulBias); + } else { + DenseTensor bias_processed; + if (bias.numel() != (M * N)) { + VLOG(3) << "bias.numel(): " << bias.numel(); + VLOG(3) << "M*N: " << M * N; + VLOG(3) << "bias tiling and addmm calculating"; + // only broadcast to 1D bias whatsoever + phi::TileKernel(dev_ctx, bias, {M, N}, &bias_processed); + } else { + bias_processed = bias; + } + phi::AddmmKernel(dev_ctx, + bias_processed, + input_processed, + weight_processed, + 1.0f, + 1.0f, + out); + } + VLOG(3) << "linear calculate complete"; + out->Resize(out_dim_original); + } else // NOLINT #endif + // Fallback logic for legacy CUDA version or other hardware. + // Or specified by user to use a legacy behaviour. + { // NOLINT + // NOTE(Pan Zhaowu): Fallback logic for legacy CUDA version or DCU. 
+ std::vector input_dims_vec = common::vectorize(input.dims()); + std::vector weight_dims_vec = + common::vectorize(weight.dims()); + + MatMulFunction(dev_ctx, + input, + weight, + input_dims_vec, + weight_dims_vec, + out, + false, + false); + AddKernel(dev_ctx, *out, bias, out); + } } } // namespace phi diff --git a/paddle/phi/ops/yaml/backward.yaml b/paddle/phi/ops/yaml/backward.yaml index 53e87720ffb065..630fb30e277ffc 100644 --- a/paddle/phi/ops/yaml/backward.yaml +++ b/paddle/phi/ops/yaml/backward.yaml @@ -2054,6 +2054,15 @@ data_transform : skip_transform : out_size, size_tensor, scale_tensor +- backward_op : linear_v2_grad + forward : linear_v2 (Tensor input, Tensor weight, Tensor bias) -> Tensor(out) + args: (Tensor input, Tensor weight, Tensor bias, Tensor out_grad) + output: Tensor(input_grad), Tensor(weight_grad), Tensor(bias_grad) + infer_meta : + func : LinearV2GradInferMeta + kernel : + func : linear_v2_grad + - backward_op : log10_grad forward : log10 (Tensor x) -> Tensor(out) args : (Tensor x, Tensor out_grad) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index 7ffec5e5a4deeb..01ea8ad1019f00 100644 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -3180,6 +3180,16 @@ skip_transform : out_size, size_tensor, scale_tensor interfaces : paddle::dialect::InferSymbolicShapeInterface +- op : linear_v2 + args : (Tensor input, Tensor weight, Tensor bias) + output : Tensor(out) + infer_meta : + func : LinearV2InferMeta + kernel : + func : linear_v2 + data_type : input + backward : linear_v2_grad + - op : linspace args : (Tensor start, Tensor stop, Tensor number, DataType dtype, Place place) output : Tensor(out) diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index df517f6c6738e8..21cc0312728dea 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -16,6 +16,7 @@ import inspect import math +import os import warnings from typing import TYPE_CHECKING, Any, Literal @@ -2529,49 +2530,56 @@ def linear( [ 1.08524013, 1.08524013, 1.08524013, 1.08524013], [-0.67769694, -0.67769694, -0.67769694, -0.67769694]]) """ - if in_dynamic_mode(): - return _C_ops.linear(x, weight, bias) + if os.environ.get("FLAGS_use_legacy_gemm", False): + if in_dynamic_mode(): + return _C_ops.linear(x, weight, bias) - elif in_pir_mode(): - out = _C_ops.matmul(x, weight, False, False) - if bias is not None: - return _C_ops.add(out, bias) + elif in_pir_mode(): + out = _C_ops.matmul(x, weight, False, False) + if bias is not None: + return _C_ops.add(out, bias) + else: + return out else: - return out - else: - helper = LayerHelper('linear', **locals()) - dtype = x.dtype + helper = LayerHelper('linear', **locals()) + dtype = x.dtype - check_variable_and_dtype( - x, 'x', ["uint16", 'float16', 'float32', 'float64'], 'linear' - ) - check_dtype( - dtype, - 'dtype', - ["uint16", 'float16', 'float32', 'float64'], - 'linear', - ) + check_variable_and_dtype( + x, 'x', ["uint16", 'float16', 'float32', 'float64'], 'linear' + ) + check_dtype( + dtype, + 'dtype', + ["uint16", 'float16', 'float32', 'float64'], + 'linear', + ) - inputs = {'X': [x], 'Y': [weight]} - attrs = {'trans_x': False, 'trans_y': False} - tmp = helper.create_variable_for_type_inference(dtype) - helper.append_op( - type='matmul_v2', - inputs=inputs, - outputs={'Out': tmp}, - attrs=attrs, - ) - if bias is not None: - res = helper.create_variable_for_type_inference(dtype) + inputs = {'X': [x], 'Y': [weight]} + attrs = {'trans_x': 
False, 'trans_y': False} + tmp = helper.create_variable_for_type_inference(dtype) helper.append_op( - type='elementwise_add', - inputs={'X': [tmp], 'Y': [bias]}, - outputs={'Out': [res]}, - attrs={'axis': -1}, + type='matmul_v2', + inputs=inputs, + outputs={'Out': tmp}, + attrs=attrs, ) - else: - res = tmp - return res + if bias is not None: + res = helper.create_variable_for_type_inference(dtype) + helper.append_op( + type='elementwise_add', + inputs={'X': [tmp], 'Y': [bias]}, + outputs={'Out': [res]}, + attrs={'axis': -1}, + ) + else: + res = tmp + return res + else: + if in_dynamic_or_pir_mode(): + if bias is not None: + return _C_ops.linear_v2(x, weight, bias) + else: + return _C_ops.matmul(x, weight) def label_smooth( From 83a7bd2c74b07717352395763fdc3981a87f204a Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Fri, 26 Dec 2025 11:34:02 +0800 Subject: [PATCH 29/55] fix fwd miscs --- paddle/phi/kernels/funcs/matrix_reduce.cu | 2 ++ paddle/phi/kernels/gpu/linear_v2_kernel.cu | 8 ++++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/paddle/phi/kernels/funcs/matrix_reduce.cu b/paddle/phi/kernels/funcs/matrix_reduce.cu index 819822761d4408..380ed148dfd0ff 100644 --- a/paddle/phi/kernels/funcs/matrix_reduce.cu +++ b/paddle/phi/kernels/funcs/matrix_reduce.cu @@ -52,6 +52,8 @@ class MatrixReduceSumFunctor { template class MatrixReduceSumFunctor; template class MatrixReduceSumFunctor; +template class MatrixReduceSumFunctor; +template class MatrixReduceSumFunctor; template class MatrixReduceSumFunctor; template class MatrixReduceSumFunctor; diff --git a/paddle/phi/kernels/gpu/linear_v2_kernel.cu b/paddle/phi/kernels/gpu/linear_v2_kernel.cu index f142779abaf84a..28f9a73754443b 100644 --- a/paddle/phi/kernels/gpu/linear_v2_kernel.cu +++ b/paddle/phi/kernels/gpu/linear_v2_kernel.cu @@ -141,11 +141,15 @@ void LinearV2Kernel(const Context& dev_ctx, } else { DenseTensor bias_processed; if (bias.numel() != (M * N)) { - VLOG(3) << "bias.numel(): " << bias.numel(); + phi::ReshapeKernel( + dev_ctx, bias, {1, bias.numel()}, &bias_processed); + VLOG(3) << "bias.dim(): " << bias.dims(); VLOG(3) << "M*N: " << M * N; VLOG(3) << "bias tiling and addmm calculating"; // only broadcast to 1D bias whatsoever - phi::TileKernel(dev_ctx, bias, {M, N}, &bias_processed); + phi::TileKernel( + dev_ctx, bias_processed, {M, 1}, &bias_processed); + VLOG(3) << "bias_processed.dims(): " << bias_processed.dims(); } else { bias_processed = bias; } From 487024bac24e1e4f51c2c8c4ea7577321a70e5d2 Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Fri, 26 Dec 2025 17:13:53 +0800 Subject: [PATCH 30/55] fix shape miscs --- paddle/phi/infermeta/backward.cc | 36 +----------- .../phi/kernels/gpu/linear_v2_grad_kernel.cu | 55 +++++-------------- 2 files changed, 16 insertions(+), 75 deletions(-) diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 0c2a4eb9b1dd38..595a654f01552a 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -504,37 +504,6 @@ void LinearV2GradInferMeta(const MetaTensor& input, auto bias_dims = bias.dims(); auto dout_dims = out_grad.dims(); - PADDLE_ENFORCE_GE(dout_dims.size(), - 2, - common::errors::InvalidArgument( - "The Input tensor DOut's dimension of LinearV2Op " - " should be >= 2, but got %d.", - dout_dims.size())); - - PADDLE_ENFORCE_EQ(weight_dims.size(), - 2, - common::errors::InvalidArgument( - "The Input tensor Y's dimension of LinearV2Op " - " should be 2, but got %d.", - weight_dims.size())); - - 
PADDLE_ENFORCE_GE(input_dims.size(), - 2, - common::errors::InvalidArgument( - "The Input tensor X's dimension of LinearV2Op " - " should be >= 2, but got %d.", - input_dims.size())); - - PADDLE_ENFORCE_EQ( - dout_dims.size(), - input_dims.size(), - common::errors::InvalidArgument( - "The Input tensor DOut's and X's dimension of " - "LinearV2Op " - " should be the same, but got DOut's dim = %d and X's = %d.", - dout_dims.size(), - input_dims.size())); - auto dout_mat_dims = common::flatten_to_2d(dout_dims, dout_dims.size() - 1); PADDLE_ENFORCE_EQ( @@ -561,8 +530,9 @@ void LinearV2GradInferMeta(const MetaTensor& input, } } - auto k_from_dout = dout_dims[input_dims.size() - 2]; - auto k_from_input = input_dims[input_dims.size() - 2]; + const int64_t input_ndim = input_dims.size(); + auto k_from_dout = input_ndim >= 2 ? dout_dims[input_ndim - 2] : 1; + auto k_from_input = input_ndim >= 2 ? input_dims[input_ndim - 2] : 1; bool check_k = (k_from_dout < 0 || k_from_input < 0) || (k_from_dout == k_from_input); diff --git a/paddle/phi/kernels/gpu/linear_v2_grad_kernel.cu b/paddle/phi/kernels/gpu/linear_v2_grad_kernel.cu index 4dcf98517930ca..d53e6f0b352142 100644 --- a/paddle/phi/kernels/gpu/linear_v2_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/linear_v2_grad_kernel.cu @@ -18,8 +18,11 @@ #include "paddle/phi/common/data_type.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/funcs/elementwise_grad_base.h" #include "paddle/phi/kernels/funcs/matrix_reduce.h" +#include "paddle/phi/kernels/funcs/reduce_function.h" #include "paddle/phi/kernels/impl/matmul_grad_kernel_impl.h" +#include "paddle/phi/kernels/reduce_sum_kernel.h" #ifdef PADDLE_WITH_HIP #include @@ -211,54 +214,22 @@ void LinearV2GradKernel(const Context& dev_ctx, /* reduce bias with dummy add_grad, perform matmul grad, reshape grad dim #if defined(PADDLE_WITH_CUDA) && CUDA_VERSION > 11060 && - !defined(PADDLE_WITH_HIP) && !defined(_WIN32) dev_ctx.template Alloc(out, - out->numel() * sizeof(T)); const auto [M, N, K] = canonicalize_dims(input, - weight); - - if (bias.numel() != N) { - // only broadcast to 1D bias whatsoever - // pass1: scalar to 1D - phi::Tile(dev_ctx, bias, {N}, out); - } - if (N > 1 && K > 1) { - // CublasLt path with bias add epilogue - phi::funcs::LinearWithCublasLt::Run( - dev_ctx, - &input, - &weight, - out, - static_cast(bias.data()), - nullptr, - M, - N, - K, - false, - false, - phi::funcs::MatmulFusedType::kMatmulBias); - } else { - // Cublas path with beta==1 bias adding. 
- blas.GEMM(dev_ctx, ) - } + !defined(PADDLE_WITH_HIP) && !defined(_WIN32) #else */ phi::MatmulGradKernel( dev_ctx, input, weight, out_grad, false, false, input_grad, weight_grad); - /* - if (dy_bst.dims() == y.dims()) { - Copy(dev_ctx, dy_bst, dev_ctx.GetPlace(), false, dy); + + if (bias_grad && bias.numel() != 0) { + if (out_grad.numel() != bias_grad->numel()) { + dev_ctx.template Alloc(bias_grad); + std::vector reduce_dims = + funcs::GetReduceDim(bias.dims(), out_grad.dims(), -1); + phi::SumKernel( + dev_ctx, out_grad, reduce_dims, out_grad.dtype(), false, bias_grad); } else { - funcs::MatrixReduceSumFunctor functor; - functor(dev_ctx, dy_bst, dy); - dy->Resize(y.dims()); + phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, bias_grad); } - */ - dev_ctx.template Alloc(bias_grad); - if (out_grad.dims() != bias.dims()) { - funcs::MatrixReduceSumFunctor functor; - functor(dev_ctx, out_grad, bias_grad); - bias_grad->Resize(bias.dims()); - } else { - phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, bias_grad); } // #endif } From 7c9b52ce21cea28a86b0c1308f08749a9180d7e7 Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Fri, 26 Dec 2025 17:28:18 +0800 Subject: [PATCH 31/55] clean code --- .../phi/kernels/gpu/linear_v2_grad_kernel.cu | 152 ------------------ 1 file changed, 152 deletions(-) diff --git a/paddle/phi/kernels/gpu/linear_v2_grad_kernel.cu b/paddle/phi/kernels/gpu/linear_v2_grad_kernel.cu index d53e6f0b352142..757d6076a47679 100644 --- a/paddle/phi/kernels/gpu/linear_v2_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/linear_v2_grad_kernel.cu @@ -56,152 +56,6 @@ #endif namespace phi { -/* -#if (defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11000 && \ - !(defined(_WIN32) || defined(WIN32))) - if (!FLAGS_use_legacy_gemm && // NOLINT - x.place().GetType() == phi::AllocationType::GPU && - weight.place().GetType() == phi::AllocationType::GPU && - bias.place().GetType() == phi::AllocationType::GPU && - !bias.is_dist_tensor()) // NOLINT - [[likely]] { // NOLINT - // TODO(Pan Zhaowu): Add proper broadcast logic for batchsize unaligned - // batch-gemm. Currently handles: (B..., k) x (k, n) -> (B..., n), with - // 1D or scalar bias. - - // --- Original input tensor dimensions and values --- - const auto &x_original_shape = x.shape(); - const size_t x_ndim_original = x_original_shape.size(); - - const auto &weight_original_shape = weight.shape(); - const size_t weight_ndim_original = weight_original_shape.size(); - - const int64_t k_dim = - x_original_shape[x_ndim_original - 1]; // Last dimension of X - - paddle::Tensor x_processed = - x; // Start with original, possibly reassign if reshaped - paddle::Tensor weight_processed = - weight; // Start with original, possibly reassign if reshaped - - // If x is 1D (e.g., shape [k]), reshape it to [1, k] to fit the (B..., - // k) x (k, n) pattern. This effectively treats a 1D vector as a row - // vector for matrix multiplication. 
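The bias gradient path added to linear_v2_grad_kernel.cu above reduces out_grad over every leading axis and keeps only the last (bias) axis, matching the API-level promise that bias is 1D. A minimal NumPy sketch of that reduction, with illustrative shapes only:

import numpy as np

out_grad = np.random.rand(4, 3, 8).astype(np.float32)   # (batch, M, N)
bias = np.zeros(8, dtype=np.float32)                     # (N,)

if out_grad.size != bias.size:
    # GetReduceDim(bias.dims(), out_grad.dims(), -1) selects all leading axes.
    reduce_axes = tuple(range(out_grad.ndim - 1))
    bias_grad = out_grad.sum(axis=reduce_axes)           # SumKernel, keep_dim=False
else:
    bias_grad = out_grad.copy()                          # shapes already match: plain copy

assert bias_grad.shape == bias.shape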
- if (x_ndim_original == 1) [[unlikely]] { - x_processed = reshape_ad_func(x, {1, k_dim}); - } else if (weight_ndim_original == 1) [[unlikely]] { // NOLINT - weight_processed = reshape_ad_func(weight, {k_dim, 1}); - } - - // --- Recalculate dimensions based on processed tensors --- - const auto &x_shape_current = x_processed.shape(); - const size_t x_ndim_current = x_shape_current.size(); - - const auto &weight_shape_current = weight_processed.shape(); - const size_t weight_ndim_current = weight_shape_current.size(); - - const int64_t k_effective = x_shape_current[x_ndim_current - 1]; - const int64_t n_effective = - weight_shape_current[weight_ndim_current - 1]; - - // --- Determine the final output shape --- - std::vector output_shape_vec = x_shape_current; - output_shape_vec[x_ndim_current - 1] = n_effective; - - // If the original x was 1D, the processed x became [1, k]. - if (x_ndim_original == 1 && output_shape_vec.size() > 1 && - output_shape_vec[0] == 1) { - output_shape_vec.erase( - output_shape_vec - .begin()); // Remove the artificial batch dimension - } - - // This is used for reshaping X into a 2D matrix for addmm_ad_func. - const int64_t x_batch_numel = - std::accumulate(output_shape_vec.begin(), - output_shape_vec.end() - 1, - 1LL, - std::multiplies()); - - // --- Bias handling and GEMM execution --- - // The condition now uses the processed weight's shape. - // This branch typically handles (B..., k) x (k, n) where n > 1. - if (weight_shape_current[0] > 1 && weight_shape_current[1] > 1) { - paddle::Tensor bias_1d = - bias; // Create a mutable copy if modification is needed - // Align bias' shape to 'n_effective'. If bias.numel() != n_effective, - // tile it. - if (bias.numel() != n_effective) { - bias_1d = tile_ad_func(bias, {static_cast(n_effective)}); - } - // Execute fused GEMM with epilogue. - auto [out, _] = fused_gemm_epilogue_ad_func( - x_processed, weight_processed, bias_1d, false, false, "none"); - - // If original x was 1D and output_shape_vec is 1D (i.e., [n]), - // but fused_gemm_epilogue_ad_func returns a 2D tensor ([1, n]), - // reshape it back to the desired 1D output shape. - if (x_ndim_original == 1 && out.shape().size() == 2 && - output_shape_vec.size() == 1) { - out = reshape_ad_func(out, output_shape_vec); - } - - PyEval_RestoreThread(tstate); - tstate = nullptr; - return ToPyObject(out); - } else { - // This branch handles cases where weight_processed is effectively 2D - // with one dimension being 1, e.g., (B..., k) x (k, 1) resulting in - // (B..., 1). Or when weight_processed was originally 1D and reshaped - // to [k, 1]. - - // Reshape bias to [1, n_effective] then tile to [x_batch_numel, 1] - // for addmm_ad_func. - paddle::Tensor bias_2d = tile_ad_func( - reshape_ad_func(bias, {1, n_effective}), {x_batch_numel, 1}); - - // Perform matrix multiplication using addmm_ad_func. - // x_processed is reshaped to 2D [x_batch_numel, k_effective] for the - // multiplication. 
- auto out = addmm_ad_func( - bias_2d, - reshape_ad_func(x_processed, {x_batch_numel, k_effective}), - weight_processed, - 1.0, - 1.0); - - out = reshape_ad_func(out, output_shape_vec); - - PyEval_RestoreThread(tstate); - tstate = nullptr; - return ToPyObject(out); - } - } else // NOLINT(readability/braces) -#endif -// we don't receive 2+d tensor as weight -inline std::tuple canonicalize_dims( - const DenseTensor& input, const DenseTensor& weight) { - const auto x_dims = input.dims(); - const auto y_dims = weight.dims(); - PADDLE_ENFORCE_LE_GE( - y_dims.size(), - 2, - platform::errors::InvalidArgument("Y must be at most 2D")); - - const int64_t N = y_dims.size() < 2 ? 1 : y_dims[y_dims.size() - 1]; - const int64_t K = y_dims.size() < 2 ? y_dims[0] : y_dims[y_dims.size() - 2]; - - int64_t M = x_dims.size() >= 2 ? x_dims[x_dims.size() - 2] : 1; - if (x_dims.size() > 2) { - // Accumulate the batch dims for input - for (int64_t i = 0; i < x_dims.size() - 2; ++i) { - M *= x_dims[i]; - } - } - - return {M, N, K}; -} -*/ template void LinearV2GradKernel(const Context& dev_ctx, const DenseTensor& input, @@ -211,12 +65,6 @@ void LinearV2GradKernel(const Context& dev_ctx, DenseTensor* input_grad, DenseTensor* weight_grad, DenseTensor* bias_grad) { - /* - reduce bias with dummy add_grad, perform matmul grad, reshape grad dim - #if defined(PADDLE_WITH_CUDA) && CUDA_VERSION > 11060 && - !defined(PADDLE_WITH_HIP) && !defined(_WIN32) - #else - */ phi::MatmulGradKernel( dev_ctx, input, weight, out_grad, false, false, input_grad, weight_grad); From 813713aaafcb66f9c503f37f8d9a930173463e8b Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Mon, 29 Dec 2025 14:42:36 +0800 Subject: [PATCH 32/55] fix grad --- paddle/phi/kernels/funcs/reduce_gpu_kernel.h | 2 +- paddle/phi/kernels/gpu/linear_v2_grad_kernel.cu | 4 ++-- paddle/phi/kernels/gpu/linear_v2_kernel.cu | 3 +-- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/paddle/phi/kernels/funcs/reduce_gpu_kernel.h b/paddle/phi/kernels/funcs/reduce_gpu_kernel.h index 9592903ae7788d..4628a3f595c2c3 100644 --- a/paddle/phi/kernels/funcs/reduce_gpu_kernel.h +++ b/paddle/phi/kernels/funcs/reduce_gpu_kernel.h @@ -263,7 +263,7 @@ int GetOutputVecSize(const DenseTensorIterator& iter) { } // Simplify fraction by dividing both numerator and denominator by their GCD -// (Greatest Common Divisor). 
+// (Greatest Common Divisor) HOSTDEVICE static void ReduceFraction(size_t* numerator, size_t* denominator) { size_t a = *denominator; size_t b = *numerator; diff --git a/paddle/phi/kernels/gpu/linear_v2_grad_kernel.cu b/paddle/phi/kernels/gpu/linear_v2_grad_kernel.cu index 757d6076a47679..95ddc2d2819eaf 100644 --- a/paddle/phi/kernels/gpu/linear_v2_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/linear_v2_grad_kernel.cu @@ -18,8 +18,6 @@ #include "paddle/phi/common/data_type.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_utils.h" -#include "paddle/phi/kernels/funcs/elementwise_grad_base.h" -#include "paddle/phi/kernels/funcs/matrix_reduce.h" #include "paddle/phi/kernels/funcs/reduce_function.h" #include "paddle/phi/kernels/impl/matmul_grad_kernel_impl.h" #include "paddle/phi/kernels/reduce_sum_kernel.h" @@ -75,8 +73,10 @@ void LinearV2GradKernel(const Context& dev_ctx, funcs::GetReduceDim(bias.dims(), out_grad.dims(), -1); phi::SumKernel( dev_ctx, out_grad, reduce_dims, out_grad.dtype(), false, bias_grad); + bias_grad->Resize(bias.dims()); } else { phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, bias_grad); + bias_grad->Resize(bias.dims()); } } // #endif diff --git a/paddle/phi/kernels/gpu/linear_v2_kernel.cu b/paddle/phi/kernels/gpu/linear_v2_kernel.cu index 28f9a73754443b..7cfdd1f59cc56d 100644 --- a/paddle/phi/kernels/gpu/linear_v2_kernel.cu +++ b/paddle/phi/kernels/gpu/linear_v2_kernel.cu @@ -90,8 +90,8 @@ void LinearV2Kernel(const Context& dev_ctx, const DenseTensor& weight, const DenseTensor& bias, DenseTensor* out) { + dev_ctx.template Alloc(out); if (out->numel() == 0) { - dev_ctx.template Alloc(out); return; } @@ -100,7 +100,6 @@ void LinearV2Kernel(const Context& dev_ctx, !defined(PADDLE_WITH_HIP) && !defined(_WIN32) if (!FLAGS_use_legacy_linear) { VLOG(3) << "Use LinearV2Kernel with cublaslt"; - dev_ctx.template Alloc(out); const auto out_dim_original = out->dims(); const auto [M, N, K] = canonicalize_dims(input, weight); VLOG(3) << "M: " << M << ", N: " << N << ", K: " << K; From 12884e4e62d030ea8bc9c8bd4ceb3141200a7c86 Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Mon, 29 Dec 2025 14:47:42 +0800 Subject: [PATCH 33/55] recover redundant diff --- paddle/phi/kernels/funcs/matrix_reduce.cu | 2 -- paddle/phi/kernels/funcs/reduce_gpu_kernel.h | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/paddle/phi/kernels/funcs/matrix_reduce.cu b/paddle/phi/kernels/funcs/matrix_reduce.cu index 380ed148dfd0ff..819822761d4408 100644 --- a/paddle/phi/kernels/funcs/matrix_reduce.cu +++ b/paddle/phi/kernels/funcs/matrix_reduce.cu @@ -52,8 +52,6 @@ class MatrixReduceSumFunctor { template class MatrixReduceSumFunctor; template class MatrixReduceSumFunctor; -template class MatrixReduceSumFunctor; -template class MatrixReduceSumFunctor; template class MatrixReduceSumFunctor; template class MatrixReduceSumFunctor; diff --git a/paddle/phi/kernels/funcs/reduce_gpu_kernel.h b/paddle/phi/kernels/funcs/reduce_gpu_kernel.h index 4628a3f595c2c3..9592903ae7788d 100644 --- a/paddle/phi/kernels/funcs/reduce_gpu_kernel.h +++ b/paddle/phi/kernels/funcs/reduce_gpu_kernel.h @@ -263,7 +263,7 @@ int GetOutputVecSize(const DenseTensorIterator& iter) { } // Simplify fraction by dividing both numerator and denominator by their GCD -// (Greatest Common Divisor) +// (Greatest Common Divisor). 
HOSTDEVICE static void ReduceFraction(size_t* numerator, size_t* denominator) { size_t a = *denominator; size_t b = *numerator; From d3d7f5894864638f2c9d8266424d5fcda51ec9b3 Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Mon, 29 Dec 2025 14:58:51 +0800 Subject: [PATCH 34/55] using legacy linear --- paddle/fluid/pybind/eager_custom_python_api.h | 4 ---- paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h | 8 ++++---- python/paddle/nn/functional/common.py | 2 +- 3 files changed, 5 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/pybind/eager_custom_python_api.h b/paddle/fluid/pybind/eager_custom_python_api.h index bfb885c499d876..aa233b2ccc95fd 100644 --- a/paddle/fluid/pybind/eager_custom_python_api.h +++ b/paddle/fluid/pybind/eager_custom_python_api.h @@ -15,13 +15,10 @@ #include -#include "paddle/common/ddim.h" -#include "paddle/common/flags.h" #include "paddle/fluid/eager/to_static/run_program_func.h" #include "paddle/fluid/eager/utils.h" #include "paddle/phi/core/enforce.h" -COMMON_DECLARE_bool(use_legacy_gemm); using egr::ConvertAllInputsToDistTensor; using egr::InputsContainDistTensor; @@ -47,7 +44,6 @@ static PyObject *eager_api_linear(PyObject *self, auto mm_out = matmul_ad_func(x, weight, false, false); auto out = add_ad_func(mm_out, bias); - PyEval_RestoreThread(tstate); tstate = nullptr; return ToPyObject(out); diff --git a/paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h b/paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h index b4c5b1a59e2c41..6c5016513d531e 100644 --- a/paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h +++ b/paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h @@ -32,7 +32,7 @@ limitations under the License. */ COMMON_DECLARE_int64(cublaslt_exhaustive_search_times); COMMON_DECLARE_bool(enable_blaslt_global_search); -COMMON_DECLARE_bool(use_legacy_gemm); +COMMON_DECLARE_bool(use_legacy_linear); #endif namespace phi { @@ -467,7 +467,7 @@ struct CublasLtBase { // NOTE(limingshu): As workspace_size varies from different DL framework, // I wonder is there any smarter idea for workspace setting, currently I // just followed the settings from the NVIDIA colleague`s setting. - size_t workspace_size = FLAGS_use_legacy_gemm + size_t workspace_size = FLAGS_use_legacy_linear ? static_cast(4) * 1024 * 1024 : static_cast(1) * 1024 * 1024; phi::Allocator::AllocationPtr workspace = @@ -494,7 +494,7 @@ struct CublasLtBase { } } cublasLtMatmulHeuristicResult_t heuristic_results = {}; - if (!FLAGS_use_legacy_gemm) { + if (!FLAGS_use_legacy_linear) { cublasLtMatmulPreference_t preference; PADDLE_ENFORCE_GPU_SUCCESS( dynload::cublasLtMatmulPreferenceCreate(&preference)); @@ -539,7 +539,7 @@ struct CublasLtBase { desc->out_desc, out_ptr, desc->out_desc, - FLAGS_use_legacy_gemm ? desc->algo : &heuristic_results.algo, + FLAGS_use_legacy_linear ? 
desc->algo : &heuristic_results.algo, workspace->ptr(), workspace_size, dev_ctx.stream())); diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index 21cc0312728dea..9cf985c7ac6310 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -2530,7 +2530,7 @@ def linear( [ 1.08524013, 1.08524013, 1.08524013, 1.08524013], [-0.67769694, -0.67769694, -0.67769694, -0.67769694]]) """ - if os.environ.get("FLAGS_use_legacy_gemm", False): + if os.environ.get("FLAGS_use_legacy_linear", False): if in_dynamic_mode(): return _C_ops.linear(x, weight, bias) From 7a4900383784509e7cdb00ff58588590d73ebed0 Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Mon, 29 Dec 2025 15:44:41 +0800 Subject: [PATCH 35/55] add multi-platform support, polish --- paddle/phi/kernels/cpu/linear_v2_kernel.cc | 67 +++++++++++++++++++ paddle/phi/kernels/gpu/linear_v2_kernel.cu | 45 +++---------- ...grad_kernel.cu => linear_v2_grad_kernel.h} | 37 +++------- paddle/phi/kernels/linear_v2_kernel.h | 47 +++++++++++++ python/paddle/nn/functional/common.py | 6 +- 5 files changed, 140 insertions(+), 62 deletions(-) create mode 100644 paddle/phi/kernels/cpu/linear_v2_kernel.cc rename paddle/phi/kernels/{gpu/linear_v2_grad_kernel.cu => linear_v2_grad_kernel.h} (73%) create mode 100644 paddle/phi/kernels/linear_v2_kernel.h diff --git a/paddle/phi/kernels/cpu/linear_v2_kernel.cc b/paddle/phi/kernels/cpu/linear_v2_kernel.cc new file mode 100644 index 00000000000000..08637363ad314c --- /dev/null +++ b/paddle/phi/kernels/cpu/linear_v2_kernel.cc @@ -0,0 +1,67 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include +#include +#include +#include "glog/logging.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" + +#include "paddle/common/enforce.h" +#include "paddle/phi/kernels/addmm_kernel.h" +#include "paddle/phi/kernels/elementwise_add_kernel.h" +#include "paddle/phi/kernels/impl/matmul_kernel_impl.h" +#include "paddle/phi/kernels/linear_v2_kernel.h" +#include "paddle/phi/kernels/reshape_kernel.h" +#include "paddle/phi/kernels/tile_kernel.h" + +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/scope_guard.h" + +namespace phi { + +template +void LinearV2Kernel(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& weight, + const DenseTensor& bias, + DenseTensor* out) { + dev_ctx.template Alloc(out); + if (out->numel() == 0) { + return; + } + + // When in CPU, we use legacy linear_logic by default. + // TODO(Pan Zhaowu): Adding more efficient kernel for CPU. 
+ std::vector input_dims_vec = common::vectorize(input.dims()); + std::vector weight_dims_vec = common::vectorize(weight.dims()); + + MatMulFunction(dev_ctx, + input, + weight, + input_dims_vec, + weight_dims_vec, + out, + false, + false); + AddKernel(dev_ctx, *out, bias, out); +} + +} // namespace phi + +PD_REGISTER_KERNEL( + linear_v2, CPU, ALL_LAYOUT, phi::LinearV2Kernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/linear_v2_kernel.cu b/paddle/phi/kernels/gpu/linear_v2_kernel.cu index 7cfdd1f59cc56d..7f000ab1c83ef7 100644 --- a/paddle/phi/kernels/gpu/linear_v2_kernel.cu +++ b/paddle/phi/kernels/gpu/linear_v2_kernel.cu @@ -23,6 +23,7 @@ #include "paddle/phi/kernels/addmm_kernel.h" #include "paddle/phi/kernels/elementwise_add_kernel.h" #include "paddle/phi/kernels/impl/matmul_kernel_impl.h" +#include "paddle/phi/kernels/linear_v2_kernel.h" #include "paddle/phi/kernels/reshape_kernel.h" #include "paddle/phi/kernels/tile_kernel.h" @@ -60,30 +61,6 @@ COMMON_DECLARE_bool(use_legacy_linear); namespace phi { -/* - NOTE(Pan Zhaowu): There's a promise from API level that bias is exist, - and always(or always can be broadcasted) to be equal to the output's last dim. -*/ - -// we don't receive 2+d tensor as weight -inline std::tuple canonicalize_dims( - const DenseTensor& input, const DenseTensor& weight) { - const auto x_dims = input.dims(); - const auto y_dims = weight.dims(); - const int64_t N = y_dims.size() < 2 ? 1 : y_dims[y_dims.size() - 1]; - const int64_t K = y_dims.size() < 2 ? y_dims[0] : y_dims[y_dims.size() - 2]; - - int64_t M = x_dims.size() >= 2 ? x_dims[x_dims.size() - 2] : 1; - if (x_dims.size() > 2) { - // Accumulate the batch dims for input - for (int64_t i = 0; i < x_dims.size() - 2; ++i) { - M *= x_dims[i]; - } - } - - return {M, N, K}; -} - template void LinearV2Kernel(const Context& dev_ctx, const DenseTensor& input, @@ -99,10 +76,10 @@ void LinearV2Kernel(const Context& dev_ctx, #if defined(PADDLE_WITH_CUDA) && CUDA_VERSION > 11060 && \ !defined(PADDLE_WITH_HIP) && !defined(_WIN32) if (!FLAGS_use_legacy_linear) { - VLOG(3) << "Use LinearV2Kernel with cublaslt"; + VLOG(10) << "Use LinearV2Kernel with cublaslt"; const auto out_dim_original = out->dims(); const auto [M, N, K] = canonicalize_dims(input, weight); - VLOG(3) << "M: " << M << ", N: " << N << ", K: " << K; + VLOG(10) << "M: " << M << ", N: " << N << ", K: " << K; DenseTensor input_processed; DenseTensor weight_processed; @@ -110,9 +87,9 @@ void LinearV2Kernel(const Context& dev_ctx, phi::ReshapeKernel(dev_ctx, input, {M, K}, &input_processed); phi::ReshapeKernel(dev_ctx, weight, {K, N}, &weight_processed); out->Resize(common::make_ddim({M, N})); - VLOG(3) << "input_processed: " << input_processed.dims() - << ", weight_processed: " << weight_processed.dims() - << ", output_processed: " << out->dims(); + VLOG(10) << "input_processed: " << input_processed.dims() + << ", weight_processed: " << weight_processed.dims() + << ", output_processed: " << out->dims(); if (N > 1 && K > 1) { DenseTensor bias_processed; @@ -142,13 +119,13 @@ void LinearV2Kernel(const Context& dev_ctx, if (bias.numel() != (M * N)) { phi::ReshapeKernel( dev_ctx, bias, {1, bias.numel()}, &bias_processed); - VLOG(3) << "bias.dim(): " << bias.dims(); - VLOG(3) << "M*N: " << M * N; - VLOG(3) << "bias tiling and addmm calculating"; + VLOG(10) << "bias.dim(): " << bias.dims(); + VLOG(10) << "M*N: " << M * N; + VLOG(10) << "bias tiling and addmm calculating"; // only broadcast to 1D bias whatsoever phi::TileKernel( dev_ctx, bias_processed, {M, 1}, 
&bias_processed); - VLOG(3) << "bias_processed.dims(): " << bias_processed.dims(); + VLOG(10) << "bias_processed.dims(): " << bias_processed.dims(); } else { bias_processed = bias; } @@ -160,7 +137,7 @@ void LinearV2Kernel(const Context& dev_ctx, 1.0f, out); } - VLOG(3) << "linear calculate complete"; + VLOG(10) << "linear calculate complete"; out->Resize(out_dim_original); } else // NOLINT #endif diff --git a/paddle/phi/kernels/gpu/linear_v2_grad_kernel.cu b/paddle/phi/kernels/linear_v2_grad_kernel.h similarity index 73% rename from paddle/phi/kernels/gpu/linear_v2_grad_kernel.cu rename to paddle/phi/kernels/linear_v2_grad_kernel.h index 95ddc2d2819eaf..832915d4e0428b 100644 --- a/paddle/phi/kernels/gpu/linear_v2_grad_kernel.cu +++ b/paddle/phi/kernels/linear_v2_grad_kernel.h @@ -11,6 +11,7 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. +#pragma once #include #include #include @@ -22,36 +23,10 @@ #include "paddle/phi/kernels/impl/matmul_grad_kernel_impl.h" #include "paddle/phi/kernels/reduce_sum_kernel.h" -#ifdef PADDLE_WITH_HIP -#include -#include -#else -#include // NOLINT -#include "cuda.h" // NOLINT -#endif - -#if (defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060) || \ - defined(PADDLE_WITH_HIP) - #include "paddle/common/flags.h" #include "paddle/phi/backends/all_context.h" -#include "paddle/phi/common/amp_type_traits.h" -#include "paddle/phi/common/memory_utils.h" #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/core/enforce.h" -#include "paddle/phi/core/scope_guard.h" -#include "paddle/utils/optional.h" -#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060 -#include "paddle/phi/backends/dynload/cublasLt.h" -#include "paddle/phi/backends/gpu/cuda/cuda_helper.h" -#include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h" -#elif defined(PADDLE_WITH_HIP) -#include "paddle/phi/backends/dynload/hipblasLt.h" -#include "paddle/phi/backends/gpu/rocm/rocm_helper.h" -#include "paddle/phi/kernels/funcs/blas/blaslt_impl.hip.h" -#endif -#endif namespace phi { template @@ -79,7 +54,6 @@ void LinearV2GradKernel(const Context& dev_ctx, bias_grad->Resize(bias.dims()); } } - // #endif } } // namespace phi @@ -92,3 +66,12 @@ PD_REGISTER_KERNEL(linear_v2_grad, double, phi::float16, phi::bfloat16) {} + +PD_REGISTER_KERNEL(linear_v2_grad, + CPU, + ALL_LAYOUT, + phi::LinearV2GradKernel, + float, + double, + phi::float16, + phi::bfloat16) {} diff --git a/paddle/phi/kernels/linear_v2_kernel.h b/paddle/phi/kernels/linear_v2_kernel.h new file mode 100644 index 00000000000000..efce8a8cf51888 --- /dev/null +++ b/paddle/phi/kernels/linear_v2_kernel.h @@ -0,0 +1,47 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#pragma once + +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/enforce.h" + +namespace phi { + +// we don't receive 2+d tensor as weight +inline std::tuple canonicalize_dims( + const DenseTensor& input, const DenseTensor& weight) { + const auto x_dims = input.dims(); + const auto y_dims = weight.dims(); + const int64_t N = y_dims.size() < 2 ? 1 : y_dims[y_dims.size() - 1]; + const int64_t K = y_dims.size() < 2 ? y_dims[0] : y_dims[y_dims.size() - 2]; + + int64_t M = x_dims.size() >= 2 ? x_dims[x_dims.size() - 2] : 1; + if (x_dims.size() > 2) { + // Accumulate the batch dims for input + for (int64_t i = 0; i < x_dims.size() - 2; ++i) { + M *= x_dims[i]; + } + } + + return {M, N, K}; +} + +template +void LinearV2Kernel(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& weight, + const DenseTensor& bias, + DenseTensor* out); +} // namespace phi diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index 9cf985c7ac6310..4e8f3aaa06a3d5 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -2530,7 +2530,11 @@ def linear( [ 1.08524013, 1.08524013, 1.08524013, 1.08524013], [-0.67769694, -0.67769694, -0.67769694, -0.67769694]]) """ - if os.environ.get("FLAGS_use_legacy_linear", False): + # If not specified by user to use legacy linear, or not CUDA compatible, we fallback. + if ( + os.environ.get("FLAGS_use_legacy_linear", False) + or not paddle.is_compiled_with_cuda() + ): if in_dynamic_mode(): return _C_ops.linear(x, weight, bias) From a45acea363a3fcd6f82d15c4f2c9741fc14f8ab6 Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Mon, 29 Dec 2025 17:31:39 +0800 Subject: [PATCH 36/55] refractor --- .../phi/kernels/cpu/linear_v2_grad_kernel.cc | 18 +++++++++++++++ paddle/phi/kernels/cpu/linear_v2_kernel.cc | 2 +- .../phi/kernels/gpu/linear_v2_grad_kernel.cu | 23 +++++++++++++++++++ paddle/phi/kernels/linear_v2_grad_kernel.h | 21 ++--------------- paddle/phi/kernels/linear_v2_kernel.h | 2 +- 5 files changed, 45 insertions(+), 21 deletions(-) create mode 100644 paddle/phi/kernels/cpu/linear_v2_grad_kernel.cc create mode 100644 paddle/phi/kernels/gpu/linear_v2_grad_kernel.cu diff --git a/paddle/phi/kernels/cpu/linear_v2_grad_kernel.cc b/paddle/phi/kernels/cpu/linear_v2_grad_kernel.cc new file mode 100644 index 00000000000000..51199623e92be8 --- /dev/null +++ b/paddle/phi/kernels/cpu/linear_v2_grad_kernel.cc @@ -0,0 +1,18 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
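canonicalize_dims in the new linear_v2_kernel.h above folds the whole problem into a single (M, K) x (K, N) GEMM: every leading batch dimension of the input is absorbed into M, and a 1D weight is read as a (K, 1) column. A rough Python equivalent over static shape tuples (names here are illustrative only):

from math import prod

def canonicalize_dims(x_shape, w_shape):
    # Weight is at most 2D: (K, N), or (K,) treated as a (K, 1) column.
    n = 1 if len(w_shape) < 2 else w_shape[-1]
    k = w_shape[0] if len(w_shape) < 2 else w_shape[-2]
    # Fold every leading batch dim of x into M; a 1D x behaves like (1, K).
    m = x_shape[-2] if len(x_shape) >= 2 else 1
    if len(x_shape) > 2:
        m *= prod(x_shape[:-2])
    return m, n, k

print(canonicalize_dims((4, 3, 16), (16, 8)))  # (12, 8, 16)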
+ +#include "paddle/phi/kernels/linear_v2_grad_kernel.h" + +PD_REGISTER_KERNEL( + linear_v2_grad, CPU, ALL_LAYOUT, phi::LinearV2GradKernel, float, double) {} diff --git a/paddle/phi/kernels/cpu/linear_v2_kernel.cc b/paddle/phi/kernels/cpu/linear_v2_kernel.cc index 08637363ad314c..ca0ada016ad2ca 100644 --- a/paddle/phi/kernels/cpu/linear_v2_kernel.cc +++ b/paddle/phi/kernels/cpu/linear_v2_kernel.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/paddle/phi/kernels/gpu/linear_v2_grad_kernel.cu b/paddle/phi/kernels/gpu/linear_v2_grad_kernel.cu new file mode 100644 index 00000000000000..e010d6460fa64b --- /dev/null +++ b/paddle/phi/kernels/gpu/linear_v2_grad_kernel.cu @@ -0,0 +1,23 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "paddle/phi/kernels/linear_v2_grad_kernel.h" + +PD_REGISTER_KERNEL(linear_v2_grad, + GPU, + ALL_LAYOUT, + phi::LinearV2GradKernel, + float, + double, + phi::float16, + phi::bfloat16) {} diff --git a/paddle/phi/kernels/linear_v2_grad_kernel.h b/paddle/phi/kernels/linear_v2_grad_kernel.h index 832915d4e0428b..6fab0e08665166 100644 --- a/paddle/phi/kernels/linear_v2_grad_kernel.h +++ b/paddle/phi/kernels/linear_v2_grad_kernel.h @@ -1,4 +1,4 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ #include "paddle/phi/common/data_type.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/funcs/elementwise_base.h" #include "paddle/phi/kernels/funcs/reduce_function.h" #include "paddle/phi/kernels/impl/matmul_grad_kernel_impl.h" #include "paddle/phi/kernels/reduce_sum_kernel.h" @@ -57,21 +58,3 @@ void LinearV2GradKernel(const Context& dev_ctx, } } // namespace phi - -PD_REGISTER_KERNEL(linear_v2_grad, - GPU, - ALL_LAYOUT, - phi::LinearV2GradKernel, - float, - double, - phi::float16, - phi::bfloat16) {} - -PD_REGISTER_KERNEL(linear_v2_grad, - CPU, - ALL_LAYOUT, - phi::LinearV2GradKernel, - float, - double, - phi::float16, - phi::bfloat16) {} diff --git a/paddle/phi/kernels/linear_v2_kernel.h b/paddle/phi/kernels/linear_v2_kernel.h index efce8a8cf51888..8eb17f9ca329e0 100644 --- a/paddle/phi/kernels/linear_v2_kernel.h +++ b/paddle/phi/kernels/linear_v2_kernel.h @@ -1,4 +1,4 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
From f15e3ad1e439b1f0a9b30924d0ce0300db7d1d66 Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Mon, 29 Dec 2025 19:54:50 +0800 Subject: [PATCH 37/55] fix flag and amp rules --- paddle/common/flags.cc | 2 +- python/paddle/amp/amp_lists.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/common/flags.cc b/paddle/common/flags.cc index 1755e8fcf0d311..87fa6b76dc1e7a 100644 --- a/paddle/common/flags.cc +++ b/paddle/common/flags.cc @@ -2382,7 +2382,7 @@ PHI_DEFINE_EXPORTED_bool(use_legacy_gemm, /** * Legacy gemm related FLAG * Name: FLAGS_use_legacy_linear - * Since Version: 3.2.2 + * Since Version: 3.3.1 * Value Range: bool, default=false * Example: * Note: Whether use legacy linear kernel. diff --git a/python/paddle/amp/amp_lists.py b/python/paddle/amp/amp_lists.py index 4b7914a0acd1cd..09ef995c51c027 100644 --- a/python/paddle/amp/amp_lists.py +++ b/python/paddle/amp/amp_lists.py @@ -22,6 +22,7 @@ 'einsum', 'matmul', 'matmul_v2', + 'linear_v2', 'max_pool2d_with_index', 'mul', 'fused_gemm_epilogue', From ddd1508a1384e4b7b89e1edebf34963f951091a6 Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Tue, 30 Dec 2025 15:43:25 +0800 Subject: [PATCH 38/55] fix CI --- paddle/phi/infermeta/ternary.cc | 27 ++++++++++++++++----------- paddle/phi/infermeta/ternary.h | 3 ++- python/paddle/nn/functional/common.py | 12 +++++++----- test/amp/test_amp_api.py | 2 ++ test/amp/test_amp_master_grad.py | 2 ++ test/amp/test_amp_promote.py | 2 ++ 6 files changed, 31 insertions(+), 17 deletions(-) diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index afc997b24b2680..5a8f2e49fb1e73 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -1553,7 +1553,8 @@ void LerpInferMeta(const MetaTensor& x, void LinearV2InferMeta(const MetaTensor& input, const MetaTensor& weight, const MetaTensor& bias, - MetaTensor* out) { + MetaTensor* out, + MetaConfig config) { const auto& input_dims = input.dims(); const auto& weight_dims = weight.dims(); const int64_t weight_ndim = weight.dims().size(); @@ -1595,16 +1596,20 @@ void LinearV2InferMeta(const MetaTensor& input, common::flatten_to_2d(input_dims, input_dims.size() - 1); auto input_rank = input_dims.size(); - int64_t K_from_input = input_mat_dims[1]; - int64_t K_from_weight = weight_dims[0]; - PADDLE_ENFORCE_EQ( - K_from_input, - K_from_weight, - common::errors::InvalidArgument( - "The last dimension of X should be equal with Y's first dimension." - "But received X[-1] = [%d], Y[0] = [%d].", - K_from_input, - K_from_weight)); + const bool check_dim = + (!config.is_runtime && K_from_input != -1) || config.is_runtime; + if (check_dim) { + int64_t K_from_input = input_mat_dims[1]; + int64_t K_from_weight = weight_dims[0]; + PADDLE_ENFORCE_EQ( + K_from_input, + K_from_weight, + common::errors::InvalidArgument( + "The last dimension of X should be equal with Y's first dimension." 
+ "But received X[-1] = [%d], Y[0] = [%d].", + K_from_input, + K_from_weight)); + } std::vector out_dims; out_dims.reserve(input_rank); diff --git a/paddle/phi/infermeta/ternary.h b/paddle/phi/infermeta/ternary.h index 587e18f1324f16..2cf3f0374d1183 100644 --- a/paddle/phi/infermeta/ternary.h +++ b/paddle/phi/infermeta/ternary.h @@ -277,7 +277,8 @@ PADDLE_API void LerpInferMeta(const MetaTensor& x, PADDLE_API void LinearV2InferMeta(const MetaTensor& input, const MetaTensor& weight, const MetaTensor& bias, - MetaTensor* out); + MetaTensor* out, + MetaConfig config = MetaConfig()); PADDLE_API void LinspaceRawInferMeta(const MetaTensor& start, const MetaTensor& stop, diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index 4e8f3aaa06a3d5..97f8cc34c579b8 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -2534,6 +2534,7 @@ def linear( if ( os.environ.get("FLAGS_use_legacy_linear", False) or not paddle.is_compiled_with_cuda() + or not in_dynamic_or_pir_mode() ): if in_dynamic_mode(): return _C_ops.linear(x, weight, bias) @@ -2579,11 +2580,12 @@ def linear( res = tmp return res else: - if in_dynamic_or_pir_mode(): - if bias is not None: - return _C_ops.linear_v2(x, weight, bias) - else: - return _C_ops.matmul(x, weight) + if bias is not None: + print("exist bias") + return _C_ops.linear_v2(x, weight, bias) + else: + print("not exist bias") + return _C_ops.matmul(x, weight) def label_smooth( diff --git a/test/amp/test_amp_api.py b/test/amp/test_amp_api.py index 10843a00f99298..00c624fee8be5a 100644 --- a/test/amp/test_amp_api.py +++ b/test/amp/test_amp_api.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import unittest from contextlib import contextmanager @@ -455,4 +456,5 @@ def test_pir_op_called_as_expected(self): if __name__ == '__main__': + os.environ["FLAGS_use_legacy_linear"] = "True" unittest.main() diff --git a/test/amp/test_amp_master_grad.py b/test/amp/test_amp_master_grad.py index 9e646ef575d50b..149349f7537a9f 100644 --- a/test/amp/test_amp_master_grad.py +++ b/test/amp/test_amp_master_grad.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import unittest import numpy as np @@ -217,4 +218,5 @@ def test_pir_momentum_master_grad(self): if __name__ == '__main__': + os.environ["FLAGS_use_legacy_linear"] = "True" unittest.main() diff --git a/test/amp/test_amp_promote.py b/test/amp/test_amp_promote.py index 76d48e66ca4314..5172dfdd951d3f 100644 --- a/test/amp/test_amp_promote.py +++ b/test/amp/test_amp_promote.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import os import unittest import numpy as np @@ -419,4 +420,5 @@ def test_o2_use_promote_off(self): if __name__ == '__main__': + os.environ["FLAGS_use_legacy_linear"] = "True" unittest.main() From 83b14e8743415185eefc6feba7350221ee850a84 Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Tue, 30 Dec 2025 15:45:32 +0800 Subject: [PATCH 39/55] polish --- python/paddle/nn/functional/common.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index 97f8cc34c579b8..336fe2b0cad7fd 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -2581,10 +2581,8 @@ def linear( return res else: if bias is not None: - print("exist bias") return _C_ops.linear_v2(x, weight, bias) else: - print("not exist bias") return _C_ops.matmul(x, weight) From 708855fa417e153002223c5b8a9128e14374b7af Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Tue, 30 Dec 2025 16:07:57 +0800 Subject: [PATCH 40/55] fix miscs --- paddle/phi/infermeta/ternary.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index 5a8f2e49fb1e73..2a114bb96b47ce 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -1596,11 +1596,11 @@ void LinearV2InferMeta(const MetaTensor& input, common::flatten_to_2d(input_dims, input_dims.size() - 1); auto input_rank = input_dims.size(); + int64_t K_from_input = input_mat_dims[1]; + int64_t K_from_weight = weight_dims[0]; const bool check_dim = (!config.is_runtime && K_from_input != -1) || config.is_runtime; if (check_dim) { - int64_t K_from_input = input_mat_dims[1]; - int64_t K_from_weight = weight_dims[0]; PADDLE_ENFORCE_EQ( K_from_input, K_from_weight, From a7b25fc25d9bf1e429e56eabb30039b7fc83c03e Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Tue, 30 Dec 2025 16:19:32 +0800 Subject: [PATCH 41/55] add infersymbolics --- .../multiary_infer_sym.cc | 49 +++++++++++++++++++ .../infer_symbolic_shape/multiary_infer_sym.h | 1 + paddle/phi/ops/yaml/ops.yaml | 1 + 3 files changed, 51 insertions(+) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index cdbb0515674b0e..478c0a4b68eda6 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -2281,6 +2281,55 @@ bool FusedGemmEpilogueOpInferSymbolicShape( return true; } +bool LinearV2OpInferSymbolicShape( + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + const auto &x_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(0)); + const auto &x_dims = x_shape_or_data.shape(); + + const auto &y_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(1)); + const auto &y_dims = y_shape_or_data.shape(); + + const auto &bias_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(2)); + const auto &bias_dims = bias_shape_or_data.shape(); + + size_t x_rank = x_dims.size(); + size_t y_rank = y_dims.size(); + + std::vector out_shape; + out_shape.reserve(x_rank); + + for (size_t i = 0; i + 2 < x_rank; ++i) { + out_shape.emplace_back(x_dims[i]); + } + + symbol::DimExpr out_M = x_dims[x_rank - 2]; + symbol::DimExpr out_N = y_dims[y_rank - 1]; + + out_shape.emplace_back(out_M); + 
out_shape.emplace_back(out_N); + + symbol::DimExpr x_K = x_dims[x_rank - 1]; + symbol::DimExpr y_K = y_dims[y_rank - 2]; + + infer_context->AddEqualCstr(x_K, y_K); + // bias_dims[0] equal to out_N + infer_context->AddEqualCstr(out_N, bias_dims[0]); + + infer_context->SetShapeOrDataForValue(op->result(0), + ShapeOrData{TensorExprs(out_shape)}); + + // process reserve space + if (!paddle::dialect::details::IsFakeValue(op->result(1))) { + infer_context->SetShapeOrDataForValue(op->result(1), + ShapeOrData{TensorExprs(out_shape)}); + } + + return true; +} + bool FusedMultiTransformerOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { const auto &x_shape_or_data = diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h index b1d193e44163ea..d71b46967ce162 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h @@ -67,6 +67,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(FusedBatchNormAct_) OP_DECLARE_INFER_SYMBOLIC_SHAPE(FusedBnAddActivation) OP_DECLARE_INFER_SYMBOLIC_SHAPE(FusedBnAddActivation_) OP_DECLARE_INFER_SYMBOLIC_SHAPE(FusedGemmEpilogue) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(LinearV2) OP_DECLARE_INFER_SYMBOLIC_SHAPE(FusedMultiTransformer) OP_DECLARE_INFER_SYMBOLIC_SHAPE(GenerateProposals) OP_DECLARE_INFER_SYMBOLIC_SHAPE(GraphKhopSampler) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index 01ea8ad1019f00..b528c9b0084b4e 100644 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -3189,6 +3189,7 @@ func : linear_v2 data_type : input backward : linear_v2_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : linspace args : (Tensor start, Tensor stop, Tensor number, DataType dtype, Place place) From 442ddc4a7ad3dc6620051c97272f051693a92dea Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Tue, 30 Dec 2025 20:07:07 +0800 Subject: [PATCH 42/55] Add metaconfig --- paddle/fluid/pir/dialect/op_generator/op_build_gen.py | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/fluid/pir/dialect/op_generator/op_build_gen.py b/paddle/fluid/pir/dialect/op_generator/op_build_gen.py index 60840cc60ec5e9..9aa280181ba1f1 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_build_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_build_gen.py @@ -68,6 +68,7 @@ 'FusedGateAttentionGradInferMeta', 'ResnetBasicBlockInferMeta', 'ResnetBasicBlockGradInferMeta', + 'LinearV2InferMeta', # multiary.h 'AddNInferMeta', 'ApVariadicInferMeta', From 3b709f1688adf5c35fbe5b4b950dfa2c4bf5baab Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Wed, 31 Dec 2025 11:31:57 +0800 Subject: [PATCH 43/55] fix symbolic, move flags --- .../interface/infer_symbolic_shape/multiary_infer_sym.cc | 6 ------ test/amp/test_amp_api.py | 3 ++- test/amp/test_amp_master_grad.py | 3 ++- test/amp/test_amp_promote.py | 3 ++- 4 files changed, 6 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index 478c0a4b68eda6..1671cd86ba00a6 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -2321,12 
+2321,6 @@ bool LinearV2OpInferSymbolicShape( infer_context->SetShapeOrDataForValue(op->result(0), ShapeOrData{TensorExprs(out_shape)}); - // process reserve space - if (!paddle::dialect::details::IsFakeValue(op->result(1))) { - infer_context->SetShapeOrDataForValue(op->result(1), - ShapeOrData{TensorExprs(out_shape)}); - } - return true; } diff --git a/test/amp/test_amp_api.py b/test/amp/test_amp_api.py index 00c624fee8be5a..18cc25faae41dd 100644 --- a/test/amp/test_amp_api.py +++ b/test/amp/test_amp_api.py @@ -24,6 +24,8 @@ from paddle.base import core from paddle.static import amp +os.environ["FLAGS_use_legacy_linear"] = "True" + @unittest.skipIf( not core.is_compiled_with_cuda() and not core.is_compiled_with_xpu(), @@ -456,5 +458,4 @@ def test_pir_op_called_as_expected(self): if __name__ == '__main__': - os.environ["FLAGS_use_legacy_linear"] = "True" unittest.main() diff --git a/test/amp/test_amp_master_grad.py b/test/amp/test_amp_master_grad.py index 149349f7537a9f..c08013864e232e 100644 --- a/test/amp/test_amp_master_grad.py +++ b/test/amp/test_amp_master_grad.py @@ -20,6 +20,8 @@ import paddle from paddle.base import core +os.environ["FLAGS_use_legacy_linear"] = "True" + class SimpleNet(paddle.nn.Layer): def __init__(self, input_size, output_size): @@ -218,5 +220,4 @@ def test_pir_momentum_master_grad(self): if __name__ == '__main__': - os.environ["FLAGS_use_legacy_linear"] = "True" unittest.main() diff --git a/test/amp/test_amp_promote.py b/test/amp/test_amp_promote.py index 5172dfdd951d3f..3d0298ccc81f2a 100644 --- a/test/amp/test_amp_promote.py +++ b/test/amp/test_amp_promote.py @@ -22,6 +22,8 @@ from paddle.base import core from paddle.static import amp +os.environ["FLAGS_use_legacy_linear"] = "True" + @unittest.skipIf( not core.is_compiled_with_cuda() and not core.is_compiled_with_xpu(), @@ -420,5 +422,4 @@ def test_o2_use_promote_off(self): if __name__ == '__main__': - os.environ["FLAGS_use_legacy_linear"] = "True" unittest.main() From 049d62f37aeedf8558d77117fc22aa970da6a115 Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Wed, 31 Dec 2025 12:37:40 +0800 Subject: [PATCH 44/55] fix bwd infermeta --- paddle/phi/infermeta/backward.cc | 39 ++++++++++++++------------------ 1 file changed, 17 insertions(+), 22 deletions(-) diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 595a654f01552a..71013a6b42a0f6 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -506,14 +506,23 @@ void LinearV2GradInferMeta(const MetaTensor& input, auto dout_mat_dims = common::flatten_to_2d(dout_dims, dout_dims.size() - 1); - PADDLE_ENFORCE_EQ( - dout_mat_dims[1], - weight_dims[1], - common::errors::InvalidArgument( - "The last dimension of DOut should be equal with Y's last " - "dimension. But received DOut[-1] = [%d], Y[1] = [%d].", - dout_mat_dims[1], - weight_dims[1])); + const int64_t input_ndim = input_dims.size(); + auto k_from_dout = input_ndim >= 2 ? dout_dims[input_ndim - 2] : 1; + auto k_from_input = input_ndim >= 2 ? input_dims[input_ndim - 2] : 1; + + bool check_k = + (k_from_dout < 0 || k_from_input < 0) || (k_from_dout == k_from_input); + + if (check_k) { + PADDLE_ENFORCE_EQ( + dout_mat_dims[1], + weight_dims[1], + common::errors::InvalidArgument( + "The last dimension of DOut should be equal with Y's last " + "dimension. 
But received DOut[-1] = [%d], Y[1] = [%d].", + dout_mat_dims[1], + weight_dims[1])); + } for (int32_t i = 0; i + 2 < input_dims.size(); ++i) { if (dout_dims[i] > 0 && input_dims[i] > 0) { @@ -530,20 +539,6 @@ void LinearV2GradInferMeta(const MetaTensor& input, } } - const int64_t input_ndim = input_dims.size(); - auto k_from_dout = input_ndim >= 2 ? dout_dims[input_ndim - 2] : 1; - auto k_from_input = input_ndim >= 2 ? input_dims[input_ndim - 2] : 1; - - bool check_k = - (k_from_dout < 0 || k_from_input < 0) || (k_from_dout == k_from_input); - PADDLE_ENFORCE_EQ( - check_k, - true, - common::errors::InvalidArgument( - "K from dout and x is not same, k_from_dout is [%d], k_from_x is[%d]", - k_from_dout, - k_from_input)); - if (input_grad) { input_grad->set_dims(input_dims); input_grad->set_dtype(input.dtype()); From 3eac5fe5e4eeae0a880f9ef13b04ea1a60b8c2b8 Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Wed, 31 Dec 2025 15:37:39 +0800 Subject: [PATCH 45/55] Add prim linear_v2_grad --- .../generator/eager_gen.py | 1 + .../op_generator/vjp_interface_black_list.py | 1 + paddle/fluid/prim/api/api.yaml | 1 + .../composite_double_backward_api.h | 25 +++++++++++++++++++ paddle/phi/ops/yaml/backward.yaml | 10 ++++++++ paddle/phi/ops/yaml/op_compat.yaml | 3 +++ 6 files changed, 41 insertions(+) diff --git a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py index 756dd9d38c40a0..ffaae7de12c378 100644 --- a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py @@ -86,6 +86,7 @@ "where_double_grad", "bmm_double_grad", "index_put_double_grad", + "linear_v2_double_grad", "gather_nd_double_grad", "reshape_double_grad", "take_along_axis_double_grad", diff --git a/paddle/fluid/pir/dialect/op_generator/vjp_interface_black_list.py b/paddle/fluid/pir/dialect/op_generator/vjp_interface_black_list.py index 4ffcf5a51224b0..fd9d770f4e2b87 100644 --- a/paddle/fluid/pir/dialect/op_generator/vjp_interface_black_list.py +++ b/paddle/fluid/pir/dialect/op_generator/vjp_interface_black_list.py @@ -44,5 +44,6 @@ 'index_elementwise_get_grad', 'index_elementwise_put_with_tensor_grad', 'index_elementwise_put_grad', + 'linear_v2_grad', 'view_shape_grad', ] diff --git a/paddle/fluid/prim/api/api.yaml b/paddle/fluid/prim/api/api.yaml index f4c7a145ef73b6..74e1649ebccd0c 100644 --- a/paddle/fluid/prim/api/api.yaml +++ b/paddle/fluid/prim/api/api.yaml @@ -15,6 +15,7 @@ - exp - scale - matmul +- linear_v2 - expand - sum - abs diff --git a/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h index f612c8d2ec2a34..a83cf7ba362181 100644 --- a/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h +++ b/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h @@ -849,6 +849,31 @@ void add_triple_grad(const paddle::optional& grad_grad_x, } } +template +void linear_v2_double_grad(const Tensor& input, + const Tensor& weight, + const Tensor& bias, + const Tensor& grad_out, + const paddle::optional& grad_input_grad, + const paddle::optional& grad_weight_grad, + const paddle::optional& grad_bias_grad, + Tensor* input_grad, + Tensor* weight_grad, + Tensor* bias_grad, + Tensor* grad_out_grad) { + matmul_double_grad(input, + weight, + grad_out, + grad_input_grad, + grad_weight_grad, + false, + false, + input_grad, + weight_grad, + grad_out_grad); + 
add_double_grad(bias, grad_out, nullptr, grad_bias_grad, -1, bias_grad); +} + template void subtract_double_grad(const Tensor& y, const Tensor& grad_out, diff --git a/paddle/phi/ops/yaml/backward.yaml b/paddle/phi/ops/yaml/backward.yaml index 630fb30e277ffc..c581103e62895b 100644 --- a/paddle/phi/ops/yaml/backward.yaml +++ b/paddle/phi/ops/yaml/backward.yaml @@ -2054,6 +2054,15 @@ data_transform : skip_transform : out_size, size_tensor, scale_tensor +- backward_op : linear_v2_double_grad + forward: linear_v2_grad (Tensor input, Tensor weight, Tensor bias, Tensor grad_out) -> Tensor(grad_input), Tensor(grad_weight), Tensor(grad_bias) + args : (Tensor input, Tensor weight, Tensor bias, Tensor grad_out, Tensor grad_input_grad, Tensor grad_weight_grad, Tensor grad_bias_grad) + output: Tensor(input_grad), Tensor(weight_grad), Tensor(bias_grad), Tensor(grad_out_grad) + infer_meta : + func : GeneralQuaternaryGradInferMeta + param : [input, weight, bias, grad_out] + composite: linear_v2_double_grad(input, weight, bias, grad_out, grad_input_grad, grad_weight_grad, grad_bias_grad, input_grad, weight_grad, bias_grad, grad_out_grad) + - backward_op : linear_v2_grad forward : linear_v2 (Tensor input, Tensor weight, Tensor bias) -> Tensor(out) args: (Tensor input, Tensor weight, Tensor bias, Tensor out_grad) @@ -2062,6 +2071,7 @@ func : LinearV2GradInferMeta kernel : func : linear_v2_grad + backward: linear_v2_double_grad - backward_op : log10_grad forward : log10 (Tensor x) -> Tensor(out) diff --git a/paddle/phi/ops/yaml/op_compat.yaml b/paddle/phi/ops/yaml/op_compat.yaml index fa4999de08ee90..13af4c098d70c8 100755 --- a/paddle/phi/ops/yaml/op_compat.yaml +++ b/paddle/phi/ops/yaml/op_compat.yaml @@ -2307,6 +2307,9 @@ extra : attrs : [bool use_mkldnn = false, bool use_onednn = false] +- op : linear_v2 + backward : linear_v2_grad, linear_v2_double_grad + - op : linspace inputs : {start : Start, stop : Stop, number : Num} From bfc67e80f76ee6700a438bc26b659a456c0c40bf Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Wed, 31 Dec 2025 17:00:20 +0800 Subject: [PATCH 46/55] fix fwd decomp --- .../op_generator/decomp_interface_gen_op_list.py | 2 ++ .../composite_double_backward_api.h | 4 +++- paddle/fluid/primitive/codegen/decomp_vjp_gen.py | 1 + .../primitive/decomp_rule/decomp_rule/composite.h | 9 +++++++++ .../primitive/decomp_rule/decomp_vjp/details.h | 15 +++++++++++++++ python/paddle/autograd/backward_utils.py | 1 + 6 files changed, 31 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py b/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py index 72774939dac6d3..36e0271aeb1f16 100644 --- a/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py +++ b/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py @@ -57,6 +57,7 @@ "lerp", "log_loss", "log_softmax", + "linear_v2", "mean", "mean_all", "meshgrid", @@ -112,6 +113,7 @@ 'layer_norm_grad', 'log_grad', 'matmul_grad', + 'linear_v2_grad', 'max_grad', 'maximum_grad', 'mean_grad', diff --git a/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h index a83cf7ba362181..3d11fec8558b35 100644 --- a/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h +++ b/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h @@ -871,7 +871,9 @@ void linear_v2_double_grad(const Tensor& input, input_grad, weight_grad, grad_out_grad); - 
add_double_grad(bias, grad_out, nullptr, grad_bias_grad, -1, bias_grad); + if (bias_grad) { + add_double_grad(bias, grad_out, nullptr, grad_bias_grad, -1, bias_grad); + } } template diff --git a/paddle/fluid/primitive/codegen/decomp_vjp_gen.py b/paddle/fluid/primitive/codegen/decomp_vjp_gen.py index db177fec29198a..2345ddfa1caf62 100644 --- a/paddle/fluid/primitive/codegen/decomp_vjp_gen.py +++ b/paddle/fluid/primitive/codegen/decomp_vjp_gen.py @@ -103,6 +103,7 @@ 'logsumexp_grad', 'masked_select_grad', 'matmul_grad', + 'linear_v2_grad', 'max_grad', 'maximum_grad', 'minimum_grad', diff --git a/paddle/fluid/primitive/decomp_rule/decomp_rule/composite.h b/paddle/fluid/primitive/decomp_rule/decomp_rule/composite.h index 830b301524bf42..c1bd3eb36a9bbb 100644 --- a/paddle/fluid/primitive/decomp_rule/decomp_rule/composite.h +++ b/paddle/fluid/primitive/decomp_rule/decomp_rule/composite.h @@ -256,6 +256,15 @@ Tensor bmm_decomp(const Tensor& x, const Tensor& y) { return matmul(x, y, false, false); } +template +Tensor linear_v2_decomp(const Tensor& input, + const Tensor& weight, + const Tensor& bias) { + Tensor result = matmul(input, weight, false, false); + result = result + bias; + return result; +} + template std::tuple batch_norm_decomp( const Tensor& x, diff --git a/paddle/fluid/primitive/decomp_rule/decomp_vjp/details.h b/paddle/fluid/primitive/decomp_rule/decomp_vjp/details.h index f6bb5fa99de6df..a21fa67fae969b 100644 --- a/paddle/fluid/primitive/decomp_rule/decomp_vjp/details.h +++ b/paddle/fluid/primitive/decomp_rule/decomp_vjp/details.h @@ -1438,6 +1438,21 @@ void matmul_grad(const Tensor& x, } } +template +void linear_v2_grad(const Tensor& input, + const Tensor& weight, + const Tensor& bias, + const Tensor& out_grad, + Tensor* input_grad, + Tensor* weight_grad, + Tensor* bias_grad) { + matmul_grad( + input, weight, out_grad, false, false, input_grad, weight_grad); + if (bias_grad) { + add_grad(bias, bias, out_grad, -1, nullptr, bias_grad); + } +} + template void maximum_grad(const Tensor& x, const Tensor& y, diff --git a/python/paddle/autograd/backward_utils.py b/python/paddle/autograd/backward_utils.py index f103e1ae7e6f62..32c5c40c3ed053 100644 --- a/python/paddle/autograd/backward_utils.py +++ b/python/paddle/autograd/backward_utils.py @@ -66,6 +66,7 @@ "pd_op.log", "pd_op.logcumsumexp", "pd_op.logsumexp", + "pd_op.linear_v2", "pd_op.matmul", "pd_op.max", "pd_op.maximum", From 4f5146bc47ceecbbb3c3f13f801819e31136ac15 Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Sun, 4 Jan 2026 17:00:29 +0800 Subject: [PATCH 47/55] add proper fallback to fulfill legacy promise --- test/collective/fleet/test_dygraph_sharding_stage2.py | 4 ++++ test/dygraph_to_static/dygraph_to_static_utils.py | 3 +++ test/ir/pir/cinn/sub_graphs/base.py | 4 ++++ test/ir/pir/test_ir_backward.py | 3 +++ test/legacy_test/test_cumsum_op.py | 3 +++ test/legacy_test/test_lookahead.py | 5 +++++ test/legacy_test/test_pir_translated_layer.py | 4 ++++ 7 files changed, 26 insertions(+) diff --git a/test/collective/fleet/test_dygraph_sharding_stage2.py b/test/collective/fleet/test_dygraph_sharding_stage2.py index f300e4d7bc25f2..8a135a62909acb 100644 --- a/test/collective/fleet/test_dygraph_sharding_stage2.py +++ b/test/collective/fleet/test_dygraph_sharding_stage2.py @@ -12,12 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
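The decomposition rules registered above (linear_v2_decomp and the linear_v2_grad VJP in details.h) reduce the op to its two primitives: out = matmul(input, weight) + bias, so the input and weight gradients are the usual matmul gradients and the bias gradient is the broadcast-add gradient. A NumPy sketch of the same rule for the plain 2D, non-transposed case (shapes illustrative):

import numpy as np

x = np.random.rand(5, 16).astype(np.float32)   # (M, K)
w = np.random.rand(16, 8).astype(np.float32)   # (K, N)
b = np.random.rand(8).astype(np.float32)       # (N,)

out = x @ w + b                                # forward decomposition
dout = np.ones_like(out)                       # upstream gradient

dx = dout @ w.T                                # matmul_grad w.r.t. x
dw = x.T @ dout                                # matmul_grad w.r.t. w
db = dout.sum(axis=0)                          # add_grad: reduce the broadcast axis

assert dx.shape == x.shape and dw.shape == w.shape and db.shape == b.shape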
+import os import unittest from legacy_test.test_parallel_dygraph_dataparallel import ( TestMultipleAccelerators, ) +# NOTE(Pan Zhaowu): using legacy linear to fulfill promise of array equal. +os.environ["FLAGS_use_legacy_linear"] = "True" + class TestDygraphShardingStage2(TestMultipleAccelerators): # check sharding logic as well as the accuracy with single mode diff --git a/test/dygraph_to_static/dygraph_to_static_utils.py b/test/dygraph_to_static/dygraph_to_static_utils.py index d5444614ffb0aa..970730210eba17 100644 --- a/test/dygraph_to_static/dygraph_to_static_utils.py +++ b/test/dygraph_to_static/dygraph_to_static_utils.py @@ -27,6 +27,9 @@ from typing_extensions import TypeAlias import paddle + +# NOTE(Pan Zhaowu): enable prim all to support high-order gradients of linear_v2 +paddle.core._set_prim_all_enabled(True) from paddle import set_flags from paddle.jit.api import sot_mode_guard from paddle.jit.dy2static.utils import ( diff --git a/test/ir/pir/cinn/sub_graphs/base.py b/test/ir/pir/cinn/sub_graphs/base.py index 0c3ac360e46e31..c48ce029b0a620 100644 --- a/test/ir/pir/cinn/sub_graphs/base.py +++ b/test/ir/pir/cinn/sub_graphs/base.py @@ -12,12 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import unittest import numpy as np import paddle +# NOTE(Pan Zhaowu): using legacy linear to fulfill promise of array equal in test_ast_prim_cinn +os.environ["FLAGS_use_legacy_linear"] = "True" + class TestBase(unittest.TestCase): def setUp(self): diff --git a/test/ir/pir/test_ir_backward.py b/test/ir/pir/test_ir_backward.py index 9e560aba43ce8d..91940348c472a7 100644 --- a/test/ir/pir/test_ir_backward.py +++ b/test/ir/pir/test_ir_backward.py @@ -17,6 +17,9 @@ import numpy as np import paddle + +# NOTE(Pan Zhaowu): enable prim all to support high-order gradients of linear_v2 +paddle.core._set_prim_all_enabled(True) from paddle.autograd.backward_utils import ValueDict, ValueSet from paddle.autograd.ir_backward import grad from paddle.base.wrapped_decorator import signature_safe_contextmanager diff --git a/test/legacy_test/test_cumsum_op.py b/test/legacy_test/test_cumsum_op.py index 497e41f606ea43..c70e18265b8571 100644 --- a/test/legacy_test/test_cumsum_op.py +++ b/test/legacy_test/test_cumsum_op.py @@ -17,6 +17,9 @@ import tempfile import unittest +# NOTE(Pan Zhaowu): using legacy linear to fulfill promise of hard-coded op numbers in TestTensorAxis. +os.environ["FLAGS_use_legacy_linear"] = "True" + from paddle.framework import use_pir_api sys.path.append("../../legacy_test") diff --git a/test/legacy_test/test_lookahead.py b/test/legacy_test/test_lookahead.py index 4b095191df1b64..ccd09e2ec55c7a 100644 --- a/test/legacy_test/test_lookahead.py +++ b/test/legacy_test/test_lookahead.py @@ -12,11 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import unittest import numpy as np import paddle + +# NOTE(Pan Zhaowu): using legacy linear to fulfill the promise of add_grad op. 
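Several of the test adjustments in this patch exist to keep higher-order autodiff working now that F.linear lowers to linear_v2: either the composite double-grad registered earlier in the series is used, or the test pins the legacy matmul + add path. A hedged sketch of the second-order usage being protected — the layer sizes are made up, and whether the composite or the legacy route is taken depends on the flags set in these tests:

    import paddle

    paddle.core._set_prim_all_enabled(True)  # as some tests above do

    x = paddle.randn([4, 16])
    x.stop_gradient = False
    linear = paddle.nn.Linear(16, 8)

    y = (linear(x) ** 2).sum()
    (dx,) = paddle.grad(y, x, create_graph=True)  # first-order gradient
    (ddx,) = paddle.grad(dx.sum(), x)             # second-order gradient
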
+os.environ["FLAGS_use_legacy_linear"] = "True" + from paddle import base, nn from paddle.base.framework import in_pir_mode diff --git a/test/legacy_test/test_pir_translated_layer.py b/test/legacy_test/test_pir_translated_layer.py index bae22c0fb760a6..cc90bfdc28a0ed 100644 --- a/test/legacy_test/test_pir_translated_layer.py +++ b/test/legacy_test/test_pir_translated_layer.py @@ -30,6 +30,10 @@ IMAGE_SIZE = 784 CLASS_NUM = 10 +# NOTE(Pan Zhaowu): using legacy linear to fulfill promise of array equal +# in test_inference_and_fine_tuning. +os.environ["FLAGS_use_legacy_linear"] = "True" + # define a random dataset class RandomDataset(paddle.io.Dataset): From c472892b058098d64f70678a3389f8ac7461762b Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Sun, 4 Jan 2026 20:51:39 +0800 Subject: [PATCH 48/55] tmp restrict prim --- test/dygraph_to_static/dygraph_to_static_utils.py | 2 +- test/ir/pir/test_ir_backward.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/dygraph_to_static/dygraph_to_static_utils.py b/test/dygraph_to_static/dygraph_to_static_utils.py index 970730210eba17..f1ace02f8f4292 100644 --- a/test/dygraph_to_static/dygraph_to_static_utils.py +++ b/test/dygraph_to_static/dygraph_to_static_utils.py @@ -29,7 +29,7 @@ import paddle # NOTE(Pan Zhaowu): enable prim all to support high-order gradients of linear_v2 -paddle.core._set_prim_all_enabled(True) +# paddle.core._set_prim_all_enabled(True) from paddle import set_flags from paddle.jit.api import sot_mode_guard from paddle.jit.dy2static.utils import ( diff --git a/test/ir/pir/test_ir_backward.py b/test/ir/pir/test_ir_backward.py index 91940348c472a7..787cb232c6fa0c 100644 --- a/test/ir/pir/test_ir_backward.py +++ b/test/ir/pir/test_ir_backward.py @@ -19,7 +19,7 @@ import paddle # NOTE(Pan Zhaowu): enable prim all to support high-order gradients of linear_v2 -paddle.core._set_prim_all_enabled(True) +# paddle.core._set_prim_all_enabled(True) from paddle.autograd.backward_utils import ValueDict, ValueSet from paddle.autograd.ir_backward import grad from paddle.base.wrapped_decorator import signature_safe_contextmanager From ef08d31135e381b8a6734c13d54b3aff5499d3f8 Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Mon, 5 Jan 2026 21:08:13 +0800 Subject: [PATCH 49/55] Add inferSPMD, fix CI --- .../spmd_rules/fused_gemm_epilogue.h | 13 ++ paddle/phi/infermeta/spmd_rules/linear_v2.cc | 153 ++++++++++++++++++ paddle/phi/infermeta/spmd_rules/linear_v2.h | 30 ++++ paddle/phi/infermeta/spmd_rules/rules.cc | 4 + paddle/phi/infermeta/spmd_rules/rules.h | 1 + paddle/phi/ops/yaml/ops.yaml | 1 + python/paddle/nn/functional/common.py | 4 +- test/amp/test_pir_amp.py | 12 +- .../dygraph_to_static_utils.py | 3 - test/dygraph_to_static/test_gradname_parse.py | 3 + test/ir/pir/test_ir_backward.py | 6 +- .../legacy_test/check_nan_inf_base_dygraph.py | 3 + test/legacy_test/test_jit_save_load.py | 10 +- test/legacy_test/test_paddlescience.py | 4 +- 14 files changed, 229 insertions(+), 18 deletions(-) create mode 100644 paddle/phi/infermeta/spmd_rules/linear_v2.cc create mode 100644 paddle/phi/infermeta/spmd_rules/linear_v2.h diff --git a/paddle/phi/infermeta/spmd_rules/fused_gemm_epilogue.h b/paddle/phi/infermeta/spmd_rules/fused_gemm_epilogue.h index 3bf7fbe25b0225..7821b13f4597ff 100644 --- a/paddle/phi/infermeta/spmd_rules/fused_gemm_epilogue.h +++ b/paddle/phi/infermeta/spmd_rules/fused_gemm_epilogue.h @@ -19,6 +19,19 @@ limitations under the License. 
*/ namespace phi { namespace distributed { +void FillMatmulPartOperandNotation(const int x_ndim, + const int y_ndim, + std::string* x_axes, + std::string* y_axes, + std::string* out_axes); +TensorDistAttr GetMatmulPartInferredDistAttr( + const TensorDistAttr& origin_dist_attr, + const std::vector& shape, + const std::string& tensor_axis, + const std::unordered_map& axis_to_dim_map, + bool trans_axis); +void SetTensorDistAttrReplicated(TensorDistAttr* dist_attr, const int ndim); + SpmdInfo FusedGemmEpilogueInferSpmdBase(const DistMetaTensor& x, const DistMetaTensor& y, const DistMetaTensor& bias, diff --git a/paddle/phi/infermeta/spmd_rules/linear_v2.cc b/paddle/phi/infermeta/spmd_rules/linear_v2.cc new file mode 100644 index 00000000000000..95b82b9a4dac84 --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/linear_v2.cc @@ -0,0 +1,153 @@ +/* Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "glog/logging.h" + +#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" +#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h" +#include "paddle/phi/core/distributed/auto_parallel/utils.h" +#include "paddle/phi/infermeta/spmd_rules/linear_v2.h" +#include "paddle/phi/infermeta/spmd_rules/utils.h" + +namespace phi::distributed { + +using phi::distributed::auto_parallel::str_join; + +SpmdInfo LinearV2InferSpmdBase(const DistMetaTensor& input, + const DistMetaTensor& weight, + const DistMetaTensor& bias) { + // Step0: verify input args based on matmul logic + auto ori_input_shape = common::vectorize(input.dims()); + auto ori_weight_shape = common::vectorize(weight.dims()); + auto ori_bias_shape = common::vectorize(bias.dims()); + int input_ndim = static_cast(ori_input_shape.size()); + int weight_ndim = static_cast(ori_weight_shape.size()); + int bias_ndim = static_cast(ori_bias_shape.size()); + const auto& input_dist_attr_src = input.dist_attr(); + const auto& weight_dist_attr_src = weight.dist_attr(); + const auto& bias_dist_attr_src = bias.dist_attr(); + std::vector input_dims_mapping = input_dist_attr_src.dims_mapping(); + std::vector weight_dims_mapping = + weight_dist_attr_src.dims_mapping(); + std::vector bias_dims_mapping = bias_dist_attr_src.dims_mapping(); + + PADDLE_ENFORCE_EQ(input_ndim, + input_dims_mapping.size(), + common::errors::InvalidArgument( + "LinearV2, The Tensor input's rank [%d] and input's " + "dims_mapping size [%d] are not matched.", + input_ndim, + input_dims_mapping.size())); + PADDLE_ENFORCE_EQ(weight_ndim, + weight_dims_mapping.size(), + common::errors::InvalidArgument( + "LinearV2, The Tensor weight's rank [%d] and weight's " + "dims_mapping size [%d] are not matched.", + weight_ndim, + weight_dims_mapping.size())); + + PADDLE_ENFORCE_EQ( + bias_ndim, + 1, + common::errors::InvalidArgument( + "LinearV2, The ndim of bias should be 1, but got [%d].", bias_ndim)); + + VLOG(4) << "LinearV2SPMDRule InferForward Inputs: "; + VLOG(4) << "input shape: [" << str_join(ori_input_shape) + << "], 
input_dims_mapping: [" << str_join(input_dims_mapping) << "];"; + VLOG(4) << "weight shape: [" << str_join(ori_weight_shape) + << "], weight_dims_mapping: [" << str_join(weight_dims_mapping) + << "];"; + VLOG(4) << "bias shape: [" << str_join(ori_bias_shape) + << "], bias_dims_mapping: [" << str_join(bias_dims_mapping) << "];"; + // Step1: build Einsum Notation + std::string input_axes; + std::string weight_axes; + std::string out_axes; + FillMatmulPartOperandNotation( + input_ndim, weight_ndim, &input_axes, &weight_axes, &out_axes); + + // Step2.1: Sharding Merge + std::pair> x_pair(input_axes, + input_dims_mapping); + std::pair> y_pair(weight_axes, + weight_dims_mapping); + auto axis_to_dim_map = ShardingMergeForTensors({x_pair, y_pair}); + + // Step2.2: Infer Output's Dims Mapping. + TensorDistAttr output_dist_attr_dst = + CopyTensorDistAttrForOutput(input_dist_attr_src); + std::vector out_dims_mapping; + out_dims_mapping.reserve(out_axes.size()); + for (size_t i = 0; i < out_axes.size(); ++i) { + out_dims_mapping.push_back(axis_to_dim_map[out_axes.substr(i, 1)]); + } + output_dist_attr_dst.set_dims_mapping(out_dims_mapping); + + // Step2.3: Merge and get Inputs' New Dims Mapping. + auto x_shape = common::vectorize(input.dims()); + auto y_shape = common::vectorize(weight.dims()); + + TensorDistAttr x_dist_attr_dst = GetMatmulPartInferredDistAttr( + input_dist_attr_src, x_shape, input_axes, axis_to_dim_map, false); + TensorDistAttr y_dist_attr_dst = GetMatmulPartInferredDistAttr( + weight_dist_attr_src, y_shape, weight_axes, axis_to_dim_map, false); + TensorDistAttr bias_dist_attr_dst = + CopyTensorDistAttrForOutput(bias_dist_attr_src); + bias_dist_attr_dst.set_dims_mapping( + std::vector{output_dist_attr_dst.dims_mapping().back()}); + + // Step2.3: Handle Partial + // Step2.3.1 Output Partial + std::vector partial_on_dims = + ResoluteOutputPartialDimension(axis_to_dim_map, out_axes); + output_dist_attr_dst.set_partial_status(partial_on_dims); + + if (output_dist_attr_dst.is_partial()) { + bias_dist_attr_dst.set_partial_status( + output_dist_attr_dst.partial_status()); + if (!IsPartialLegal(bias_dist_attr_dst) || + !IsPartialLegal(output_dist_attr_dst)) { + VLOG(4) << "LinearV2 partial output illegal, force set output " + "to replicated."; + output_dist_attr_dst.clean_partial_status(); + bias_dist_attr_dst.clean_partial_status(); + SetTensorDistAttrReplicated(&x_dist_attr_dst, input_ndim); + SetTensorDistAttrReplicated(&y_dist_attr_dst, weight_ndim); + SetTensorDistAttrReplicated(&bias_dist_attr_dst, bias_ndim); + SetTensorDistAttrReplicated(&output_dist_attr_dst, out_axes.size()); + } + } + TensorDistAttr output_reserve_dist_attr_dst = + CopyTensorDistAttrForOutput(output_dist_attr_dst); + VLOG(4) << "LinearV2SPMDRule InferForward: " + << "Einsum notation: [" << input_axes << "," << weight_axes << " --> " + << out_axes << "+" << out_axes.back() << "]. 
" << std::endl; + LogInputDistAttr( + "input", ori_input_shape, input_dist_attr_src, x_dist_attr_dst); + LogInputDistAttr( + "weight", ori_weight_shape, weight_dist_attr_src, y_dist_attr_dst); + LogInputDistAttr( + "Bias", ori_bias_shape, bias_dist_attr_src, bias_dist_attr_dst); + LogOutputDistAttr("Output", output_dist_attr_dst); + + return {{x_dist_attr_dst, y_dist_attr_dst, bias_dist_attr_dst}, + {output_dist_attr_dst, output_reserve_dist_attr_dst}}; +} +SpmdInfo LinearV2InferSpmd(const DistMetaTensor& input, + const DistMetaTensor& weight, + const DistMetaTensor& bias) { + return LinearV2InferSpmdBase(input, weight, bias); +} +} // namespace phi::distributed diff --git a/paddle/phi/infermeta/spmd_rules/linear_v2.h b/paddle/phi/infermeta/spmd_rules/linear_v2.h new file mode 100644 index 00000000000000..ed52f61cc8054a --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/linear_v2.h @@ -0,0 +1,30 @@ +/* Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include + +#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h" +#include "paddle/phi/core/distributed/type_defs.h" +#include "paddle/phi/infermeta/spmd_rules/fused_gemm_epilogue.h" + +namespace phi { +namespace distributed { +SpmdInfo LinearV2InferSpmdBase(const DistMetaTensor& input, + const DistMetaTensor& weight, + const DistMetaTensor& bias); +SpmdInfo LinearV2InferSpmd(const DistMetaTensor& input, + const DistMetaTensor& weigh, + const DistMetaTensor& bias); +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/rules.cc b/paddle/phi/infermeta/spmd_rules/rules.cc index ae7af0f90f2c03..d5f29f1e509d13 100644 --- a/paddle/phi/infermeta/spmd_rules/rules.cc +++ b/paddle/phi/infermeta/spmd_rules/rules.cc @@ -814,6 +814,10 @@ PD_REGISTER_SPMD_RULE( fused_gemm_epilogue, PD_INFER_SPMD(phi::distributed::FusedGemmEpilogueInferSpmdBase)); +// linear_v2 +PD_REGISTER_SPMD_RULE(linear_v2, + PD_INFER_SPMD(phi::distributed::LinearV2InferSpmdBase)); + // take_along_axis PD_REGISTER_SPMD_RULE( take_along_axis, diff --git a/paddle/phi/infermeta/spmd_rules/rules.h b/paddle/phi/infermeta/spmd_rules/rules.h index ff47ee4acea09f..2c3c4bdce04f87 100644 --- a/paddle/phi/infermeta/spmd_rules/rules.h +++ b/paddle/phi/infermeta/spmd_rules/rules.h @@ -59,6 +59,7 @@ limitations under the License. 
*/ #include "paddle/phi/infermeta/spmd_rules/instance_norm.h" #include "paddle/phi/infermeta/spmd_rules/label_smooth.h" #include "paddle/phi/infermeta/spmd_rules/layer_norm.h" +#include "paddle/phi/infermeta/spmd_rules/linear_v2.h" #include "paddle/phi/infermeta/spmd_rules/logsumexp.h" #include "paddle/phi/infermeta/spmd_rules/matmul.h" #include "paddle/phi/infermeta/spmd_rules/mean_all.h" diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index b528c9b0084b4e..db1023fa7229ad 100644 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -3185,6 +3185,7 @@ output : Tensor(out) infer_meta : func : LinearV2InferMeta + spmd_rule : LinearV2InferSpmd kernel : func : linear_v2 data_type : input diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index 336fe2b0cad7fd..3d3e8e5027e4bc 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -16,7 +16,6 @@ import inspect import math -import os import warnings from typing import TYPE_CHECKING, Any, Literal @@ -2531,8 +2530,9 @@ def linear( [-0.67769694, -0.67769694, -0.67769694, -0.67769694]]) """ # If not specified by user to use legacy linear, or not CUDA compatible, we fallback. + if ( - os.environ.get("FLAGS_use_legacy_linear", False) + paddle.get_flags("FLAGS_use_legacy_linear")["FLAGS_use_legacy_linear"] or not paddle.is_compiled_with_cuda() or not in_dynamic_or_pir_mode() ): diff --git a/test/amp/test_pir_amp.py b/test/amp/test_pir_amp.py index 517e7814d7d493..019f8ae1149408 100644 --- a/test/amp/test_pir_amp.py +++ b/test/amp/test_pir_amp.py @@ -62,9 +62,11 @@ def test_linear_amp_o1(self): for op in main.global_block().ops: if op.name() == 'pd_op.cast': cast_op_count += 1 - np.testing.assert_equal(out1.dtype, core.DataType.FLOAT32) + # NOTE(Pan Zhaowu): After implementation of linear_v2, there's no + # need for mix-precision add op applies to intermediate result. + np.testing.assert_equal(out1.dtype, core.DataType.FLOAT16) np.testing.assert_equal(out2.dtype, core.DataType.FLOAT32) - np.testing.assert_equal(cast_op_count, 3) + np.testing.assert_equal(cast_op_count, 4) _white_list, _black_list = core._get_amp_op_list() np.testing.assert_equal(len(_white_list), 0) np.testing.assert_equal(len(_black_list), 0) @@ -88,9 +90,11 @@ def test_linear_amp_bf16_o1(self): for op in main.global_block().ops: if op.name() == 'pd_op.cast': cast_op_count += 1 - np.testing.assert_equal(out1.dtype, core.DataType.FLOAT32) + # NOTE(Pan Zhaowu): After implementation of linear_v2, there's no + # need for mix-precision add op applies to intermediate result. 
+ np.testing.assert_equal(out1.dtype, core.DataType.BFLOAT16) np.testing.assert_equal(out2.dtype, core.DataType.FLOAT32) - np.testing.assert_equal(cast_op_count, 3) + np.testing.assert_equal(cast_op_count, 4) _white_list, _black_list = core._get_amp_op_list() np.testing.assert_equal(len(_white_list), 0) np.testing.assert_equal(len(_black_list), 0) diff --git a/test/dygraph_to_static/dygraph_to_static_utils.py b/test/dygraph_to_static/dygraph_to_static_utils.py index f1ace02f8f4292..d5444614ffb0aa 100644 --- a/test/dygraph_to_static/dygraph_to_static_utils.py +++ b/test/dygraph_to_static/dygraph_to_static_utils.py @@ -27,9 +27,6 @@ from typing_extensions import TypeAlias import paddle - -# NOTE(Pan Zhaowu): enable prim all to support high-order gradients of linear_v2 -# paddle.core._set_prim_all_enabled(True) from paddle import set_flags from paddle.jit.api import sot_mode_guard from paddle.jit.dy2static.utils import ( diff --git a/test/dygraph_to_static/test_gradname_parse.py b/test/dygraph_to_static/test_gradname_parse.py index 4dbcdb44a7138b..8e7fb017362b6e 100644 --- a/test/dygraph_to_static/test_gradname_parse.py +++ b/test/dygraph_to_static/test_gradname_parse.py @@ -21,6 +21,9 @@ ) import paddle + +# NOTE(Pan Zhaowu): Using decomp rules to fulfill promise of high-level grad, +paddle.core._set_prim_all_enabled(True) from paddle.nn import BatchNorm, Linear diff --git a/test/ir/pir/test_ir_backward.py b/test/ir/pir/test_ir_backward.py index 787cb232c6fa0c..4c7ddefcb1e53d 100644 --- a/test/ir/pir/test_ir_backward.py +++ b/test/ir/pir/test_ir_backward.py @@ -18,8 +18,10 @@ import paddle -# NOTE(Pan Zhaowu): enable prim all to support high-order gradients of linear_v2 -# paddle.core._set_prim_all_enabled(True) +# NOTE(Pan Zhaowu): Using legacy_linear to fulfill promise of high-level grad, +# with no side-effects to other ops. +# linear_v2's decomposed grad is fully tested in test_gradname_parse.py +paddle.set_flags({"FLAGS_use_legacy_linear": True}) from paddle.autograd.backward_utils import ValueDict, ValueSet from paddle.autograd.ir_backward import grad from paddle.base.wrapped_decorator import signature_safe_contextmanager diff --git a/test/legacy_test/check_nan_inf_base_dygraph.py b/test/legacy_test/check_nan_inf_base_dygraph.py index dc83c23a90fb26..7edbafb61cc1eb 100644 --- a/test/legacy_test/check_nan_inf_base_dygraph.py +++ b/test/legacy_test/check_nan_inf_base_dygraph.py @@ -19,6 +19,9 @@ import paddle from paddle import nn +# NOTE(Pan Zhaowu): Using legacy linear to fulfill the hard-coded op_count in test_nan_inf.py, +# which summon this script individually, with horrible design. 
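The updated AMP expectations reflect that the bias add no longer runs as a separate, float32-promoted add op: the whole linear executes as one op on the low-precision path, so the expected output dtype and cast count above change accordingly. A rough dygraph illustration of the same effect, assuming a GPU build where the fused path is active (the fp16 case is shown):

    import paddle

    x = paddle.randn([4, 16])
    linear = paddle.nn.Linear(16, 8)

    with paddle.amp.auto_cast(level='O1'):
        out = linear(x)

    # With the fused op the result stays in the low-precision dtype instead of
    # being promoted back to float32 by a separate elementwise add.
    print(out.dtype)  # expected: paddle.float16 under the default O1 white list
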
+paddle.set_flags({"FLAGS_use_legacy_linear": True}) # os.environ["GLOG_vmodule"] = "nan_inf_utils_detail=10" diff --git a/test/legacy_test/test_jit_save_load.py b/test/legacy_test/test_jit_save_load.py index 04598ecdbcc6bc..74393718173cd0 100644 --- a/test/legacy_test/test_jit_save_load.py +++ b/test/legacy_test/test_jit_save_load.py @@ -2331,13 +2331,13 @@ def test_save_dtype(self): out = model(data) save_dir = os.path.join(self.temp_dir.name, "test_save_dtype") path = save_dir + "/model" - paddle.jit.save( - model, path, input_spec=[InputSpec([None, 32], dtype='float32')] - ) + with paddle.amp.auto_cast(level='O2'): + paddle.jit.save( + model, path, input_spec=[InputSpec([None, 32], dtype='float32')] + ) loaded_model = paddle.jit.load(path) loaded_model = paddle.amp.decorate(models=loaded_model, level='O2') - with paddle.amp.auto_cast(level='O2'): - loaded_out = loaded_model(data) + loaded_out = loaded_model(data) np.testing.assert_allclose(out.numpy(), loaded_out.numpy(), atol=1e-5) diff --git a/test/legacy_test/test_paddlescience.py b/test/legacy_test/test_paddlescience.py index 0bddb3c4952640..0d424b14ef3c34 100644 --- a/test/legacy_test/test_paddlescience.py +++ b/test/legacy_test/test_paddlescience.py @@ -37,7 +37,7 @@ class TestPaddleSciencemodel(unittest.TestCase): def test_concat(self): - @jit.to_static + @jit.to_static(full_graph=True) def concat(x, y): """abc""" z = paddle.concat([x, y], 0) @@ -54,7 +54,7 @@ def concat(x, y): class TestEularBeam(unittest.TestCase): def test_eular_beam(self): - @jit.to_static + @jit.to_static(full_graph=True) def eular_beam(x): """abc""" z_ = model(x) From 429d39f39bdca37633faf432cb2264495e3c6c31 Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Tue, 6 Jan 2026 13:13:10 +0800 Subject: [PATCH 50/55] fix ci --- test/amp/test_amp_api.py | 3 +-- test/amp/test_amp_master_grad.py | 3 +-- test/amp/test_amp_promote.py | 3 +-- test/collective/fleet/test_dygraph_sharding_stage2.py | 5 +++-- test/ir/pir/cinn/sub_graphs/base.py | 3 +-- test/legacy_test/test_cumsum_op.py | 6 +++--- test/legacy_test/test_lookahead.py | 3 +-- test/legacy_test/test_pir_translated_layer.py | 2 +- 8 files changed, 12 insertions(+), 16 deletions(-) diff --git a/test/amp/test_amp_api.py b/test/amp/test_amp_api.py index 18cc25faae41dd..691f38e0c6b67b 100644 --- a/test/amp/test_amp_api.py +++ b/test/amp/test_amp_api.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import unittest from contextlib import contextmanager @@ -24,7 +23,7 @@ from paddle.base import core from paddle.static import amp -os.environ["FLAGS_use_legacy_linear"] = "True" +paddle.set_flags({"FLAGS_use_legacy_linear": True}) @unittest.skipIf( diff --git a/test/amp/test_amp_master_grad.py b/test/amp/test_amp_master_grad.py index c08013864e232e..8eb50136dbb6b5 100644 --- a/test/amp/test_amp_master_grad.py +++ b/test/amp/test_amp_master_grad.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
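For reference, the flag these tests keep toggling is the one consulted in python/paddle/nn/functional/common.py earlier in this patch: F.linear only takes the fused linear_v2 route in dynamic/PIR mode on a CUDA build, and falls back to the legacy matmul + add whenever FLAGS_use_legacy_linear is set. A minimal sketch of opting back into the legacy path (tensor shapes are arbitrary):

    import paddle
    import paddle.nn.functional as F

    # Pin the legacy lowering, e.g. when a test hard-codes op-level expectations.
    paddle.set_flags({"FLAGS_use_legacy_linear": True})

    x = paddle.randn([4, 16])
    w = paddle.randn([16, 8])
    b = paddle.randn([8])
    out = F.linear(x, w, b)  # computed as matmul(x, w) + b

    # Read the flag back the same way the functional API does.
    use_legacy = paddle.get_flags("FLAGS_use_legacy_linear")["FLAGS_use_legacy_linear"]
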
-import os import unittest import numpy as np @@ -20,7 +19,7 @@ import paddle from paddle.base import core -os.environ["FLAGS_use_legacy_linear"] = "True" +paddle.set_flags({"FLAGS_use_legacy_linear": True}) class SimpleNet(paddle.nn.Layer): diff --git a/test/amp/test_amp_promote.py b/test/amp/test_amp_promote.py index 3d0298ccc81f2a..66fc89e376d61f 100644 --- a/test/amp/test_amp_promote.py +++ b/test/amp/test_amp_promote.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import unittest import numpy as np @@ -22,7 +21,7 @@ from paddle.base import core from paddle.static import amp -os.environ["FLAGS_use_legacy_linear"] = "True" +paddle.set_flags({"FLAGS_use_legacy_linear": True}) @unittest.skipIf( diff --git a/test/collective/fleet/test_dygraph_sharding_stage2.py b/test/collective/fleet/test_dygraph_sharding_stage2.py index 8a135a62909acb..e48b63ad0cc067 100644 --- a/test/collective/fleet/test_dygraph_sharding_stage2.py +++ b/test/collective/fleet/test_dygraph_sharding_stage2.py @@ -12,15 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import unittest from legacy_test.test_parallel_dygraph_dataparallel import ( TestMultipleAccelerators, ) +import paddle + # NOTE(Pan Zhaowu): using legacy linear to fulfill promise of array equal. -os.environ["FLAGS_use_legacy_linear"] = "True" +paddle.set_flags({"FLAGS_use_legacy_linear": True}) class TestDygraphShardingStage2(TestMultipleAccelerators): diff --git a/test/ir/pir/cinn/sub_graphs/base.py b/test/ir/pir/cinn/sub_graphs/base.py index c48ce029b0a620..c2f2c3616e48ff 100644 --- a/test/ir/pir/cinn/sub_graphs/base.py +++ b/test/ir/pir/cinn/sub_graphs/base.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import unittest import numpy as np @@ -20,7 +19,7 @@ import paddle # NOTE(Pan Zhaowu): using legacy linear to fulfill promise of array equal in test_ast_prim_cinn -os.environ["FLAGS_use_legacy_linear"] = "True" +paddle.set_flags({"FLAGS_use_legacy_linear": True}) class TestBase(unittest.TestCase): diff --git a/test/legacy_test/test_cumsum_op.py b/test/legacy_test/test_cumsum_op.py index c70e18265b8571..f00fb54c114e94 100644 --- a/test/legacy_test/test_cumsum_op.py +++ b/test/legacy_test/test_cumsum_op.py @@ -17,9 +17,6 @@ import tempfile import unittest -# NOTE(Pan Zhaowu): using legacy linear to fulfill promise of hard-coded op numbers in TestTensorAxis. -os.environ["FLAGS_use_legacy_linear"] = "True" - from paddle.framework import use_pir_api sys.path.append("../../legacy_test") @@ -34,6 +31,9 @@ ) import paddle + +# NOTE(Pan Zhaowu): using legacy linear to fulfill promise of hard-coded op numbers in TestTensorAxis. +paddle.set_flags({"FLAGS_use_legacy_linear": True}) import paddle.inference as paddle_infer from paddle import base from paddle.base import core diff --git a/test/legacy_test/test_lookahead.py b/test/legacy_test/test_lookahead.py index ccd09e2ec55c7a..2b7b2080901958 100644 --- a/test/legacy_test/test_lookahead.py +++ b/test/legacy_test/test_lookahead.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import unittest import numpy as np @@ -20,7 +19,7 @@ import paddle # NOTE(Pan Zhaowu): using legacy linear to fulfill the promise of add_grad op. 
-os.environ["FLAGS_use_legacy_linear"] = "True" +paddle.set_flags({"FLAGS_use_legacy_linear": True}) from paddle import base, nn from paddle.base.framework import in_pir_mode diff --git a/test/legacy_test/test_pir_translated_layer.py b/test/legacy_test/test_pir_translated_layer.py index cc90bfdc28a0ed..b72e7c2a7d1626 100644 --- a/test/legacy_test/test_pir_translated_layer.py +++ b/test/legacy_test/test_pir_translated_layer.py @@ -32,7 +32,7 @@ # NOTE(Pan Zhaowu): using legacy linear to fulfill promise of array equal # in test_inference_and_fine_tuning. -os.environ["FLAGS_use_legacy_linear"] = "True" +paddle.set_flags({"FLAGS_use_legacy_linear": True}) # define a random dataset From e47e7b6d7fc3c942a64df14aa782568efda23f3c Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Wed, 7 Jan 2026 16:27:36 +0800 Subject: [PATCH 51/55] fix auto parallel. --- paddle/phi/infermeta/spmd_rules/linear_v2.cc | 21 +++++++--------- test/auto_parallel/local_view_compute.py | 8 ++++--- .../static_reshard_api_cross_mesh.py | 24 +++++++------------ .../test_sub_graph_stable_diffusion_2_st.py | 4 ++++ 4 files changed, 25 insertions(+), 32 deletions(-) diff --git a/paddle/phi/infermeta/spmd_rules/linear_v2.cc b/paddle/phi/infermeta/spmd_rules/linear_v2.cc index 95b82b9a4dac84..ed5814cba05343 100644 --- a/paddle/phi/infermeta/spmd_rules/linear_v2.cc +++ b/paddle/phi/infermeta/spmd_rules/linear_v2.cc @@ -115,19 +115,14 @@ SpmdInfo LinearV2InferSpmdBase(const DistMetaTensor& input, output_dist_attr_dst.set_partial_status(partial_on_dims); if (output_dist_attr_dst.is_partial()) { - bias_dist_attr_dst.set_partial_status( - output_dist_attr_dst.partial_status()); - if (!IsPartialLegal(bias_dist_attr_dst) || - !IsPartialLegal(output_dist_attr_dst)) { - VLOG(4) << "LinearV2 partial output illegal, force set output " - "to replicated."; - output_dist_attr_dst.clean_partial_status(); - bias_dist_attr_dst.clean_partial_status(); - SetTensorDistAttrReplicated(&x_dist_attr_dst, input_ndim); - SetTensorDistAttrReplicated(&y_dist_attr_dst, weight_ndim); - SetTensorDistAttrReplicated(&bias_dist_attr_dst, bias_ndim); - SetTensorDistAttrReplicated(&output_dist_attr_dst, out_axes.size()); - } + // NOTE(Pan Zhaowu): linear_v2, as a fused matmul+elew op, which is + // different from legacy hacked behaviour, so disabled partial distribution + // strategy for now. + output_dist_attr_dst.clean_partial_status(); + SetTensorDistAttrReplicated(&x_dist_attr_dst, input_ndim); + SetTensorDistAttrReplicated(&y_dist_attr_dst, weight_ndim); + SetTensorDistAttrReplicated(&bias_dist_attr_dst, bias_ndim); + SetTensorDistAttrReplicated(&output_dist_attr_dst, out_axes.size()); } TensorDistAttr output_reserve_dist_attr_dst = CopyTensorDistAttrForOutput(output_dist_attr_dst); diff --git a/test/auto_parallel/local_view_compute.py b/test/auto_parallel/local_view_compute.py index 51a82eb5c489ae..bd72bcf32b8866 100644 --- a/test/auto_parallel/local_view_compute.py +++ b/test/auto_parallel/local_view_compute.py @@ -73,10 +73,12 @@ def masked_lm_loss_func(pred, label, global_local_loss_list_item=None): lossmask[8:16] = 1 pred_sub = pred[:, 0:1] # shape [B,1] - label_float = paddle.cast(label, 'float32') # shape [B,1] + # NOTE(Pan Zhaowu): Using float64 as golden to provide more + # persuasive result. 
+ label_float = paddle.cast(label, 'float64') # shape [B,1] raw_loss = paddle.abs(pred_sub - label_float) - lossmask_ = lossmask.reshape([-1]).cast('float32') - raw_loss_flat = raw_loss.reshape([-1]).cast('float32') + lossmask_ = lossmask.reshape([-1]).cast('float64') + raw_loss_flat = raw_loss.reshape([-1]).cast('float64') masked_lm_loss_sum = paddle.sum(raw_loss_flat * lossmask_) valid_count = paddle.sum(lossmask_) diff --git a/test/auto_parallel/static_reshard_api_cross_mesh.py b/test/auto_parallel/static_reshard_api_cross_mesh.py index e871d03a7f4059..d80dce3a82db0a 100644 --- a/test/auto_parallel/static_reshard_api_cross_mesh.py +++ b/test/auto_parallel/static_reshard_api_cross_mesh.py @@ -160,18 +160,14 @@ def test_reshard_mesh(self): 'builtin.parameter', 'builtin.parameter', 'pd_op.data', - 'pd_op.matmul', - 'pd_op.add', + 'pd_op.linear_v2', 'pd_op.relu', - 'pd_op.matmul', - 'pd_op.add', + 'pd_op.linear_v2', 'pd_op.send_v2', 'pd_op.recv_v2', - 'pd_op.add_grad', - 'pd_op.matmul_grad', + 'pd_op.linear_v2_grad', 'pd_op.relu_grad', - 'pd_op.add_grad', - 'pd_op.matmul_grad', + 'pd_op.linear_v2_grad', 'pd_op.sgd_', 'pd_op.sgd_', 'pd_op.sgd_', @@ -188,11 +184,9 @@ def test_reshard_mesh(self): 'builtin.parameter', 'pd_op.data', 'pd_op.recv_v2', - 'pd_op.matmul', - 'pd_op.add', + 'pd_op.linear_v2', 'pd_op.relu', - 'pd_op.matmul', - 'pd_op.add', + 'pd_op.linear_v2', 'pd_op.subtract', 'pd_op.square', 'pd_op.full_int_array', @@ -203,11 +197,9 @@ def test_reshard_mesh(self): 'pd_op.mean_grad', 'pd_op.square_grad', 'pd_op.subtract_grad', - 'pd_op.add_grad', - 'pd_op.matmul_grad', + 'pd_op.linear_v2_grad', 'pd_op.relu_grad', - 'pd_op.add_grad', - 'pd_op.matmul_grad', + 'pd_op.linear_v2_grad', 'pd_op.send_v2', 'pd_op.sgd_', 'pd_op.sgd_', diff --git a/test/ir/pir/cinn/symbolic/test_sub_graph_stable_diffusion_2_st.py b/test/ir/pir/cinn/symbolic/test_sub_graph_stable_diffusion_2_st.py index a98e025dd5b3c3..82f38ef91c9aa3 100644 --- a/test/ir/pir/cinn/symbolic/test_sub_graph_stable_diffusion_2_st.py +++ b/test/ir/pir/cinn/symbolic/test_sub_graph_stable_diffusion_2_st.py @@ -21,6 +21,10 @@ import paddle +# NOTE(Pan Zhaowu): Using legacy linear to maintain the same behavior as the +# matmul + add op. +paddle.set_flags({"FLAGS_use_legacy_linear": True}) + class LayerCase(paddle.nn.Layer): def __init__(self): From a3abc5eed2bcfc5af3ce99877ec46a01d5df294f Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Fri, 9 Jan 2026 15:06:55 +0800 Subject: [PATCH 52/55] fix TRT and prec test --- test/collective/fleet/dygraph_group_sharded_stage2.py | 2 ++ .../fleet/dygraph_group_sharded_stage2_comm_overlap.py | 2 ++ test/tensorrt/test_converter_model_resnet50.py | 4 ++++ test/tensorrt/test_converter_model_resnet50_move.py | 4 ++++ 4 files changed, 12 insertions(+) diff --git a/test/collective/fleet/dygraph_group_sharded_stage2.py b/test/collective/fleet/dygraph_group_sharded_stage2.py index 61033956bbaeaa..18bac8244148fe 100644 --- a/test/collective/fleet/dygraph_group_sharded_stage2.py +++ b/test/collective/fleet/dygraph_group_sharded_stage2.py @@ -35,6 +35,8 @@ np.random.seed(seed) paddle.seed(seed) +# NOTE(Pan Zhaowu): using legacy linear to fulfill promise of array equal. 
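The expected-op lists updated above show the program-level effect of this series: a linear layer now lowers to a single pd_op.linear_v2 (with pd_op.linear_v2_grad in the backward) where the old trace had pd_op.matmul followed by pd_op.add. A hedged way to observe this on a PIR program — it assumes a CUDA build with the fused path left enabled:

    import paddle

    paddle.enable_static()
    main = paddle.static.Program()
    with paddle.static.program_guard(main):
        x = paddle.static.data('x', [4, 16], 'float32')
        out = paddle.nn.Linear(16, 8)(x)

    op_names = [op.name() for op in main.global_block().ops]
    # Expect 'pd_op.linear_v2' here rather than the 'pd_op.matmul' / 'pd_op.add' pair.
    print([n for n in op_names if 'linear' in n or 'matmul' in n or n == 'pd_op.add'])
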
+paddle.set_flags({"FLAGS_use_legacy_linear": True}) class MLP(paddle.nn.Layer): diff --git a/test/collective/fleet/dygraph_group_sharded_stage2_comm_overlap.py b/test/collective/fleet/dygraph_group_sharded_stage2_comm_overlap.py index 573caa86eaa943..8cae0315119441 100644 --- a/test/collective/fleet/dygraph_group_sharded_stage2_comm_overlap.py +++ b/test/collective/fleet/dygraph_group_sharded_stage2_comm_overlap.py @@ -35,6 +35,8 @@ np.random.seed(seed) paddle.seed(seed) +# NOTE(Pan Zhaowu): using legacy linear to fulfill promise of array equal. +paddle.set_flags({"FLAGS_use_legacy_linear": True}) class MLP(paddle.nn.Layer): diff --git a/test/tensorrt/test_converter_model_resnet50.py b/test/tensorrt/test_converter_model_resnet50.py index 1ba466da90e108..a56ddbb0cc5253 100644 --- a/test/tensorrt/test_converter_model_resnet50.py +++ b/test/tensorrt/test_converter_model_resnet50.py @@ -38,6 +38,10 @@ ) from paddle.vision.models import resnet18 +# NOTE(Pan Zhaowu): using legacy linear to fulfill promise of tensorrt graph capturing +# and converting. +paddle.set_flags({"FLAGS_use_legacy_linear": True}) + def standardize(array): mean_val = np.mean(array) diff --git a/test/tensorrt/test_converter_model_resnet50_move.py b/test/tensorrt/test_converter_model_resnet50_move.py index 0ab1d6b18a5bd5..4d3419a93582e7 100644 --- a/test/tensorrt/test_converter_model_resnet50_move.py +++ b/test/tensorrt/test_converter_model_resnet50_move.py @@ -33,6 +33,10 @@ predict_program, ) +# NOTE(Pan Zhaowu): using legacy linear to fulfill promise of tensorrt graph capturing +# and converting. +paddle.set_flags({"FLAGS_use_legacy_linear": True}) + def standardize(array): mean_val = np.mean(array) From 014eea20aa2d7547fcbcd12c3553c6b65b9b249f Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Mon, 12 Jan 2026 14:24:21 +0800 Subject: [PATCH 53/55] add infer_symbolic instance, remove glog including in header --- paddle/phi/kernels/linear_v2_grad_kernel.h | 1 - .../symbolic/test_cinn_input_shape_symbolic.py | 16 ++++++++++++++-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/paddle/phi/kernels/linear_v2_grad_kernel.h b/paddle/phi/kernels/linear_v2_grad_kernel.h index 6fab0e08665166..ff43c282a1c935 100644 --- a/paddle/phi/kernels/linear_v2_grad_kernel.h +++ b/paddle/phi/kernels/linear_v2_grad_kernel.h @@ -15,7 +15,6 @@ #include #include #include -#include "glog/logging.h" #include "paddle/phi/common/data_type.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_utils.h" diff --git a/test/ir/pir/cinn/symbolic/test_cinn_input_shape_symbolic.py b/test/ir/pir/cinn/symbolic/test_cinn_input_shape_symbolic.py index 9178907b83d021..3b16e16f49c408 100644 --- a/test/ir/pir/cinn/symbolic/test_cinn_input_shape_symbolic.py +++ b/test/ir/pir/cinn/symbolic/test_cinn_input_shape_symbolic.py @@ -26,8 +26,14 @@ class LayerCase(paddle.nn.Layer): - def __init__(self): + def __init__(self, in_features=256, out_features=256): super().__init__() + self.parameter_2 = self.create_parameter( + shape=[in_features, out_features], dtype='float32' + ) + self.parameter_1 = self.create_parameter( + shape=[out_features], dtype='float32' + ) def forward( self, @@ -35,13 +41,19 @@ def forward( var_1, # (shape: [], dtype: paddle.int32, stop_gradient: True) ): var_2 = var_0.unsqueeze(axis=0) - var_3 = var_2.transpose( + + var_linear = paddle.nn.functional.linear( + x=var_2, weight=self.parameter_2, bias=self.parameter_1, name=None + ) + + var_3 = var_linear.transpose( ( 0, 2, 1, ) ) + var_4 = var_3.expand( ( var_1, 
From fb216798b51fa9a739d706418cdcd4cdf1c54c5c Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Mon, 12 Jan 2026 19:45:58 +0800 Subject: [PATCH 54/55] fix reduncant DtoD cpy --- paddle/phi/kernels/gpu/linear_v2_kernel.cu | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/paddle/phi/kernels/gpu/linear_v2_kernel.cu b/paddle/phi/kernels/gpu/linear_v2_kernel.cu index 7f000ab1c83ef7..020dbadc95c13b 100644 --- a/paddle/phi/kernels/gpu/linear_v2_kernel.cu +++ b/paddle/phi/kernels/gpu/linear_v2_kernel.cu @@ -80,12 +80,10 @@ void LinearV2Kernel(const Context& dev_ctx, const auto out_dim_original = out->dims(); const auto [M, N, K] = canonicalize_dims(input, weight); VLOG(10) << "M: " << M << ", N: " << N << ", K: " << K; - - DenseTensor input_processed; - DenseTensor weight_processed; - DenseTensor output_processed; - phi::ReshapeKernel(dev_ctx, input, {M, K}, &input_processed); - phi::ReshapeKernel(dev_ctx, weight, {K, N}, &weight_processed); + DenseTensor input_processed = input; + DenseTensor weight_processed = weight; + input_processed.Resize(common::make_ddim({M, K})); + weight_processed.Resize(common::make_ddim({K, N})); out->Resize(common::make_ddim({M, N})); VLOG(10) << "input_processed: " << input_processed.dims() << ", weight_processed: " << weight_processed.dims() @@ -103,8 +101,8 @@ void LinearV2Kernel(const Context& dev_ctx, // CublasLt path with bias add epilogue phi::funcs::LinearWithCublasLt::Run( dev_ctx, - &input, - &weight, + &input_processed, + &weight_processed, out, static_cast(bias_processed.data()), nullptr, @@ -115,10 +113,9 @@ void LinearV2Kernel(const Context& dev_ctx, false, phi::funcs::MatmulFusedType::kMatmulBias); } else { - DenseTensor bias_processed; + DenseTensor bias_processed = bias; if (bias.numel() != (M * N)) { - phi::ReshapeKernel( - dev_ctx, bias, {1, bias.numel()}, &bias_processed); + bias_processed.Resize(common::make_ddim({1, bias.numel()})); VLOG(10) << "bias.dim(): " << bias.dims(); VLOG(10) << "M*N: " << M * N; VLOG(10) << "bias tiling and addmm calculating"; From 72b70f78e20776055305ee3a35cb8a228f0356f0 Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Tue, 13 Jan 2026 10:50:02 +0800 Subject: [PATCH 55/55] coverage linear_v2 symbolics --- .../pir/cinn/symbolic/test_cinn_input_shape_symbolic.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/ir/pir/cinn/symbolic/test_cinn_input_shape_symbolic.py b/test/ir/pir/cinn/symbolic/test_cinn_input_shape_symbolic.py index 3b16e16f49c408..909648dc1d5169 100644 --- a/test/ir/pir/cinn/symbolic/test_cinn_input_shape_symbolic.py +++ b/test/ir/pir/cinn/symbolic/test_cinn_input_shape_symbolic.py @@ -24,6 +24,14 @@ import paddle from paddle.static import InputSpec +# NOTE(Pan Zhaowu): disable linear_v2 decomp to test infersymbolics +paddle.set_flags( + { + "FLAGS_deny_cinn_ops": "linear_v2", + "FLAGS_prim_forward_blacklist": "pd_op.linear_v2", + } +) + class LayerCase(paddle.nn.Layer): def __init__(self, in_features=256, out_features=256):