Skip to content

Commit e40e347

Browse files
committed
optimize paddle::Tensor in fluid/prim
1 parent 3d889e5 commit e40e347

File tree

4 files changed

+84
-93
lines changed

4 files changed

+84
-93
lines changed

paddle/fluid/prim/api/manual_prim/static_prim_api.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ Tensor full<DescTensor>(const IntArray& shape,
108108

109109
template <>
110110
Tensor cast<DescTensor>(const Tensor& x, DataType dtype) {
111-
Tensor out = empty<DescTensor>({}, DataType::FLOAT32, paddle::Place());
111+
Tensor out = empty<DescTensor>({}, DataType::FLOAT32, Place());
112112
framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
113113
framework::OpDesc* op = block->AppendOp();
114114
op->SetType("cast");
@@ -137,7 +137,7 @@ Tensor slice<DescTensor>(const Tensor& input,
137137
op->SetInput(
138138
"Input",
139139
{std::static_pointer_cast<prim::DescTensor>(input.impl())->Name()});
140-
auto out = empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place());
140+
auto out = empty<DescTensor>({}, phi::DataType::FLOAT32, Place());
141141
op->SetOutput(
142142
"Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
143143
op->SetAttr("axes", unsafe_vector_cast<int64_t, int>(axes));

paddle/fluid/prim/api/manual_prim/utils/eager_utils.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,31 +21,31 @@ namespace paddle::prim {
2121
template <>
2222
Tensor empty<Tensor>(const paddle::experimental::IntArray& shape,
2323
phi::DataType dtype,
24-
const paddle::Place& place) {
24+
const Place& place) {
2525
if (dtype == phi::DataType::UNDEFINED) {
2626
dtype = phi::DataType::FLOAT32;
2727
}
2828
return empty_ad_func(shape, dtype, place);
2929
}
3030

3131
template <>
32-
Tensor empty_like<Tensor>(const paddle::Tensor& x,
32+
Tensor empty_like<Tensor>(const Tensor& x,
3333
phi::DataType dtype,
34-
const paddle::Place& place) {
34+
const Place& place) {
3535
if (dtype == phi::DataType::UNDEFINED) {
3636
dtype = phi::DataType::FLOAT32;
3737
}
3838
return empty_like_ad_func(x, dtype, place);
3939
}
4040

4141
template <>
42-
void set_output<Tensor>(const paddle::Tensor& x_tmp, paddle::Tensor* x) {
42+
void set_output<Tensor>(const Tensor& x_tmp, Tensor* x) {
4343
x->set_impl(x_tmp.impl());
4444
x->set_autograd_meta(x_tmp.mutable_autograd_meta());
4545
}
4646

4747
template <>
48-
void by_pass<Tensor>(const paddle::Tensor& x, Tensor* out) {
48+
void by_pass<Tensor>(const Tensor& x, Tensor* out) {
4949
set_output<Tensor>(x, out);
5050
}
5151

paddle/fluid/prim/api/manual_prim/utils/static_utils.cc

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@
2424
#include "paddle/phi/api/include/tensor.h"
2525
#include "paddle/phi/core/utils/data_type.h"
2626
namespace paddle::prim {
27-
using Tensor = paddle::Tensor;
27+
using Tensor = Tensor;
2828
template <>
2929
TEST_API Tensor empty<DescTensor>(const paddle::experimental::IntArray& shape,
3030
phi::DataType dtype,
31-
const paddle::Place& place) {
31+
const Place& place) {
3232
framework::VarDesc* new_var =
3333
StaticCompositeContext::Instance().GetBlock()->Var(
3434
StaticCompositeContext::Instance().GenerateUniqueName());
@@ -41,24 +41,24 @@ TEST_API Tensor empty<DescTensor>(const paddle::experimental::IntArray& shape,
4141
template <>
4242
Tensor empty_like<DescTensor>(const Tensor& x,
4343
phi::DataType dtype,
44-
const paddle::Place& place) {
44+
const Place& place) {
4545
return empty<prim::DescTensor>(
46-
paddle::experimental::IntArray(x.shape()), x.dtype(), paddle::Place());
46+
paddle::experimental::IntArray(x.shape()), x.dtype(), Place());
4747
}
4848

4949
template <>
50-
void set_output<DescTensor>(const paddle::Tensor& x_tmp, paddle::Tensor* x) {
50+
void set_output<DescTensor>(const Tensor& x_tmp, Tensor* x) {
5151
x->set_impl(x_tmp.impl());
5252
}
5353

5454
template <>
55-
void by_pass<DescTensor>(const paddle::Tensor& x, paddle::Tensor* real_out) {
55+
void by_pass<DescTensor>(const Tensor& x, Tensor* real_out) {
5656
framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
5757
framework::OpDesc* op = block->AppendOp();
5858
op->SetType("assign");
5959
op->SetInput("X",
6060
{std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
61-
auto out = empty<DescTensor>({}, x.dtype(), paddle::Place());
61+
auto out = empty<DescTensor>({}, x.dtype(), Place());
6262
op->SetOutput(
6363
"Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
6464
op->CheckAttrs();

0 commit comments

Comments
 (0)