@@ -39,7 +39,7 @@ void MatmulKernelImpl(const Context& dev_ctx,
3939 const TensorType& x,
4040 const DenseTensor& y,
4141 DenseTensor* out) {
42- #if CUDA_VERSION >= 11000 || HIP_VERSION >= 402
42+ #if defined(PADDLE_WITH_CUDA) || HIP_VERSION >= 402
4343 std::vector<int64_t > xdim_vec = common::vectorize (x.dims ());
4444 std::vector<int64_t > ydim_vec = common::vectorize (y.dims ());
4545 auto x_ndims = xdim_vec.size ();
@@ -115,7 +115,7 @@ void MatmulCsrCsrKernel(const Context& dev_ctx,
115115 const SparseCsrTensor& x,
116116 const SparseCsrTensor& y,
117117 SparseCsrTensor* out) {
118- #if CUDA_VERSION >= 11000
118+ #if defined(PADDLE_WITH_CUDA)
119119 std::vector<int64_t > xdim_vec = phi::vectorize (x.dims ());
120120 std::vector<int64_t > ydim_vec = phi::vectorize (y.dims ());
121121 auto x_ndims = xdim_vec.size ();
@@ -152,13 +152,6 @@ void MatmulCsrCsrKernel(const Context& dev_ctx,
152152 auto sparse_blas = funcs::sparse::GetSparseBlas<Context, T>(dev_ctx);
153153 sparse_blas.SPGEMM (
154154 false , false , static_cast <T>(1 ), x, y, static_cast <T>(0 ), out);
155-
156- #else
157- #ifdef PADDLE_WITH_CUDA
158- PADDLE_THROW (common::errors::Unimplemented (
159- " forward of 'sparse.matmul' use cusparseSpGEMM, "
160- " which is supported from CUDA 11.0" ));
161- #endif
162155#endif
163156}
164157
@@ -182,7 +175,7 @@ void MaskedMatmulCsrKernel(const Context& dev_ctx,
182175 const DenseTensor& y,
183176 const SparseCsrTensor& mask,
184177 SparseCsrTensor* out) {
185- #if CUDA_VERSION >= 11030
178+ #if defined(PADDLE_WITH_CUDA)
186179 std::vector<int64_t > xdim_vec = common::vectorize (x.dims ());
187180 std::vector<int64_t > ydim_vec = common::vectorize (y.dims ());
188181 std::vector<int64_t > maskdim_vec = common::vectorize (mask.dims ());
@@ -252,10 +245,6 @@ void MaskedMatmulCsrKernel(const Context& dev_ctx,
252245 auto sparse_blas = funcs::sparse::GetSparseBlas<Context, T>(dev_ctx);
253246 sparse_blas.SDDMM (
254247 false , false , static_cast <T>(1 ), x, y, static_cast <T>(0 ), out);
255- #else
256- PADDLE_THROW (common::errors::Unimplemented (
257- " forward of 'sparse.masked_matmul' use cusparseSDDMM, which is supported "
258- " from CUDA 11.3" ));
259248#endif
260249}
261250
0 commit comments