@@ -39,13 +39,13 @@ std::vector<pir::Value> RelevantOutputsImpl<AddGroupNormSiluOp>(
}

template <>
-common::DataLayout PreferLayoutImpl<AddGroupNormSiluOp>(pir::Operation* op) {
+DataLayout PreferLayoutImpl<AddGroupNormSiluOp>(pir::Operation* op) {
// Note(bukejiyu): add_group_norm_silu only supports NHWC layout now.
-return common::DataLayout::NHWC;
+return DataLayout::NHWC;
}

template <>
-common::DataLayout PreferLayoutImpl<Conv2dOp>(pir::Operation* op) {
+DataLayout PreferLayoutImpl<Conv2dOp>(pir::Operation* op) {
auto data_format_attr = op->attribute<pir::StrAttribute>("data_format");
if (!data_format_attr) {
PADDLE_THROW(common::errors::InvalidArgument(
@@ -60,7 +60,7 @@ common::DataLayout PreferLayoutImpl<Conv2dOp>(pir::Operation* op) {
if (in_type.isa<DenseTensorType>()) {
if (auto tensor_type = in_type.dyn_cast<DenseTensorType>()) {
if (tensor_type.dtype().isa<pir::Float16Type>()) {
-return common::DataLayout::NHWC;
+return DataLayout::NHWC;
}
}
}
@@ -71,7 +71,7 @@ common::DataLayout PreferLayoutImpl<Conv2dOp>(pir::Operation* op) {
}

template <>
-common::DataLayout PreferLayoutImpl<Conv2dTransposeOp>(pir::Operation* op) {
+DataLayout PreferLayoutImpl<Conv2dTransposeOp>(pir::Operation* op) {
auto data_format_attr = op->attribute<pir::StrAttribute>("data_format");
if (!data_format_attr) {
PADDLE_THROW(common::errors::InvalidArgument(
@@ -86,7 +86,7 @@ common::DataLayout PreferLayoutImpl<Conv2dTransposeOp>(pir::Operation* op) {
if (in_type.isa<DenseTensorType>()) {
if (auto tensor_type = in_type.dyn_cast<DenseTensorType>()) {
if (tensor_type.dtype().isa<pir::Float16Type>()) {
-return common::DataLayout::NHWC;
+return DataLayout::NHWC;
}
}
}
@@ -102,7 +102,7 @@ bool CanBeModifiedImpl<Conv2dOp>(pir::Operation* op) {
}

template <>
-common::DataLayout PreferLayoutImpl<FusedConv2dAddActOp>(pir::Operation* op) {
+DataLayout PreferLayoutImpl<FusedConv2dAddActOp>(pir::Operation* op) {
auto data_format_attr = op->attribute<pir::StrAttribute>("data_format");
if (!data_format_attr) {
PADDLE_THROW(common::errors::InvalidArgument(
@@ -119,7 +119,7 @@ common::DataLayout PreferLayoutImpl<FusedConv2dAddActOp>(pir::Operation* op) {
.at(kForceBackendAttr)
.dyn_cast<pir::StrAttribute>()
.AsString() == "gpu") {
-return common::DataLayout::NHWC;
+return DataLayout::NHWC;
}

auto concrete_op = op->dyn_cast<FusedConv2dAddActOp>();
@@ -146,7 +146,7 @@ common::DataLayout PreferLayoutImpl<FusedConv2dAddActOp>(pir::Operation* op) {
auto dims = tensor_type.dims();
if (dims.size() == 4 && (dims[0] % CUDNN_ALIGNMENT == 0) &&
(dims[1] % CUDNN_ALIGNMENT == 0)) {
-return common::DataLayout::NHWC;
+return DataLayout::NHWC;
}
}
}
@@ -227,8 +227,7 @@ bool CanBeModifiedImpl<ReshapeOp>(pir::Operation* op) {
}

template <>
-void RewriteByLayoutImpl<SqueezeOp>(pir::Operation* op,
-common::DataLayout new_layout) {
+void RewriteByLayoutImpl<SqueezeOp>(pir::Operation* op, DataLayout new_layout) {
PADDLE_THROW(common::errors::Unimplemented(
"Op %s should have a specialized RewriteByLayout function", op->name()));
return;
@@ -275,8 +274,7 @@ std::vector<pir::Value> RelevantInputsImpl<ConcatOp>(pir::Operation* op) {
}

template <>
-void RewriteByLayoutImpl<ConcatOp>(pir::Operation* op,
-common::DataLayout new_layout) {
+void RewriteByLayoutImpl<ConcatOp>(pir::Operation* op, DataLayout new_layout) {
// we must get the value of the concat axis, but this is an input
// which is really hard to process.
// here we handle the simple case like pd_op.full and throw
@@ -313,8 +311,7 @@ void RewriteByLayoutImpl<ConcatOp>(pir::Operation* op,
}

template <>
-void RewriteByLayoutImpl<ReshapeOp>(pir::Operation* op,
-common::DataLayout new_layout) {
+void RewriteByLayoutImpl<ReshapeOp>(pir::Operation* op, DataLayout new_layout) {
auto concrete_op = op->dyn_cast<ReshapeOp>();

auto shape = concrete_op.shape();
@@ -333,10 +330,10 @@ void RewriteByLayoutImpl<ReshapeOp>(pir::Operation* op,
"Reshape's shape size was expected as 4, but got %d",
value_attr.size()));
std::vector<pir::Attribute> new_value_attr;
-if (new_layout == common::DataLayout::NHWC) {
+if (new_layout == DataLayout::NHWC) {
new_value_attr = std::vector<pir::Attribute>{
value_attr[0], value_attr[2], value_attr[3], value_attr[1]};
-} else if (new_layout == common::DataLayout::NCHW) {
+} else if (new_layout == DataLayout::NCHW) {
new_value_attr = std::vector<pir::Attribute>{
value_attr[0], value_attr[3], value_attr[1], value_attr[2]};
} else {
@@ -353,8 +350,7 @@ void RewriteByLayoutImpl<ReshapeOp>(pir::Operation* op,
}

template <>
-void RewriteByLayoutImpl<ArgmaxOp>(pir::Operation* op,
-common::DataLayout new_layout) {
+void RewriteByLayoutImpl<ArgmaxOp>(pir::Operation* op, DataLayout new_layout) {
auto concrete_op = op->dyn_cast<ArgmaxOp>();
auto axis = concrete_op.axis();
if (!axis || !(axis.defining_op()->isa<FullOp>())) {
@@ -381,7 +377,7 @@ void RewriteByLayoutImpl<ArgmaxOp>(pir::Operation* op,

template <>
void RewriteByLayoutImpl<pir::CombineOp>(pir::Operation* op,
-common::DataLayout new_layout) {
+DataLayout new_layout) {
auto concrete_op = op->dyn_cast<pir::CombineOp>();
auto out = concrete_op.out();
if (!out) return;
@@ -403,7 +399,7 @@ std::vector<pir::Value> RelevantInputsImpl<Pool2dOp>(pir::Operation* op) {
}

template <>
-common::DataLayout PreferLayoutImpl<Pool2dOp>(pir::Operation* op) {
+DataLayout PreferLayoutImpl<Pool2dOp>(pir::Operation* op) {
auto concrete_op = op->dyn_cast<Pool2dOp>();
auto data_format_attr = op->attribute<pir::StrAttribute>("data_format");
auto origin_format = common::StringToDataLayout(data_format_attr.AsString());
@@ -420,11 +416,11 @@ common::DataLayout PreferLayoutImpl<Pool2dOp>(pir::Operation* op) {

// get input dims h, w, c
int32_t h, w, c;
-if (origin_format == common::DataLayout::NHWC) {
+if (origin_format == DataLayout::NHWC) {
h = input.dims().at(1);
w = input.dims().at(2);
c = input.dims().at(3);
-} else if (origin_format == common::DataLayout::NCHW) {
+} else if (origin_format == DataLayout::NCHW) {
h = input.dims().at(2);
w = input.dims().at(3);
c = input.dims().at(1);
@@ -453,61 +449,61 @@ common::DataLayout PreferLayoutImpl<Pool2dOp>(pir::Operation* op) {
};
// TODO(liujinnan): need to test the prefer layout if kernel_size is not
// aligned.
-if (!AllEqual(kernel_size)) return common::DataLayout::NCHW;
+if (!AllEqual(kernel_size)) return DataLayout::NCHW;

int k = kernel_size[0];
// kernel size is 1 or 2, prefer NCHW.
-if (k == 1 || k == 2) return common::DataLayout::NCHW;
+if (k == 1 || k == 2) return DataLayout::NCHW;

if (pool_type == "max") {
if (h * w <= 64 * 64) {
-if (k <= 3) return common::DataLayout::NHWC;
-return common::DataLayout::NCHW;
+if (k <= 3) return DataLayout::NHWC;
+return DataLayout::NCHW;
} else {
if (c <= 16) {
-if (k <= 5) return common::DataLayout::NHWC;
-return common::DataLayout::NCHW;
+if (k <= 5) return DataLayout::NHWC;
+return DataLayout::NCHW;
}
// when c > 16, all kernel_size return NHWC
-return common::DataLayout::NHWC;
+return DataLayout::NHWC;
}
} else if (pool_type == "avg") {
if (h * w <= 64 * 64) {
if (c <= 16) {
if (k < 7)
-return common::DataLayout::NCHW;
+return DataLayout::NCHW;
else
-return common::DataLayout::NHWC;
+return DataLayout::NHWC;
} else if (c > 16 && c <= 32) {
if (k <= 7)
-return common::DataLayout::NHWC;
+return DataLayout::NHWC;
else
-return common::DataLayout::NCHW;
+return DataLayout::NCHW;
} else if (c > 32 && c <= 64) {
if (k < 5)
-return common::DataLayout::NCHW;
+return DataLayout::NCHW;
else
-return common::DataLayout::NHWC;
+return DataLayout::NHWC;
} else if (c > 64 && c <= 128) {
if (k < 7)
-return common::DataLayout::NHWC;
+return DataLayout::NHWC;
else
-return common::DataLayout::NCHW;
+return DataLayout::NCHW;
}
// when c > 128, all kernel_size return NHWC
-return common::DataLayout::NHWC;
+return DataLayout::NHWC;
} else {
if (c < 64) {
-return common::DataLayout::NHWC;
+return DataLayout::NHWC;
} else {
if (k < 7)
-return common::DataLayout::NHWC;
+return DataLayout::NHWC;
else
-return common::DataLayout::NCHW;
+return DataLayout::NCHW;
}
}
}
-return common::DataLayout::NCHW;
+return DataLayout::NCHW;
};
return PreferLayout(c, h, w, kernel_size, pool_type);
}
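
The hunks above consistently drop the common:: qualifier from DataLayout. That shorter spelling only resolves if the name is already visible in this translation unit; the sketch below shows one way that can be arranged. It is an illustration only: the header path and the enclosing namespace are assumptions, not taken from this diff.

// Hypothetical sketch, not part of the PR: bring common::DataLayout into the
// local namespace so the unqualified DataLayout::NHWC / DataLayout::NCHW
// spellings used above compile.
#include "paddle/common/layout.h"  // assumed location of common::DataLayout

namespace paddle {
namespace dialect {

using common::DataLayout;

}  // namespace dialect
}  // namespace paddle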