diff --git a/paddle/fluid/pir/dialect/operator/interface/layout_transformation.cc b/paddle/fluid/pir/dialect/operator/interface/layout_transformation.cc
index a1fffa7991750d..231c8b26cdb1b6 100644
--- a/paddle/fluid/pir/dialect/operator/interface/layout_transformation.cc
+++ b/paddle/fluid/pir/dialect/operator/interface/layout_transformation.cc
@@ -39,13 +39,13 @@ std::vector RelevantOutputsImpl(
 }
 
 template <>
-common::DataLayout PreferLayoutImpl(pir::Operation* op) {
+DataLayout PreferLayoutImpl(pir::Operation* op) {
   // Note(bukejiyu): add_group_norm_silu only supports NHWC layout now.
-  return common::DataLayout::NHWC;
+  return DataLayout::NHWC;
 }
 
 template <>
-common::DataLayout PreferLayoutImpl(pir::Operation* op) {
+DataLayout PreferLayoutImpl(pir::Operation* op) {
   auto data_format_attr = op->attribute("data_format");
   if (!data_format_attr) {
     PADDLE_THROW(common::errors::InvalidArgument(
@@ -60,7 +60,7 @@ common::DataLayout PreferLayoutImpl(pir::Operation* op) {
   if (in_type.isa()) {
     if (auto tensor_type = in_type.dyn_cast()) {
       if (tensor_type.dtype().isa()) {
-        return common::DataLayout::NHWC;
+        return DataLayout::NHWC;
       }
     }
   }
@@ -71,7 +71,7 @@ common::DataLayout PreferLayoutImpl(pir::Operation* op) {
 }
 
 template <>
-common::DataLayout PreferLayoutImpl(pir::Operation* op) {
+DataLayout PreferLayoutImpl(pir::Operation* op) {
   auto data_format_attr = op->attribute("data_format");
   if (!data_format_attr) {
     PADDLE_THROW(common::errors::InvalidArgument(
@@ -86,7 +86,7 @@ common::DataLayout PreferLayoutImpl(pir::Operation* op) {
   if (in_type.isa()) {
     if (auto tensor_type = in_type.dyn_cast()) {
       if (tensor_type.dtype().isa()) {
-        return common::DataLayout::NHWC;
+        return DataLayout::NHWC;
       }
     }
   }
@@ -102,7 +102,7 @@ bool CanBeModifiedImpl(pir::Operation* op) {
 }
 
 template <>
-common::DataLayout PreferLayoutImpl(pir::Operation* op) {
+DataLayout PreferLayoutImpl(pir::Operation* op) {
   auto data_format_attr = op->attribute("data_format");
   if (!data_format_attr) {
     PADDLE_THROW(common::errors::InvalidArgument(
@@ -119,7 +119,7 @@ common::DataLayout PreferLayoutImpl(pir::Operation* op) {
           .at(kForceBackendAttr)
           .dyn_cast()
          .AsString() == "gpu") {
-    return common::DataLayout::NHWC;
+    return DataLayout::NHWC;
   }
 
   auto concrete_op = op->dyn_cast();
@@ -146,7 +146,7 @@ common::DataLayout PreferLayoutImpl(pir::Operation* op) {
       auto dims = tensor_type.dims();
       if (dims.size() == 4 && (dims[0] % CUDNN_ALIGNMENT == 0) &&
           (dims[1] % CUDNN_ALIGNMENT == 0)) {
-        return common::DataLayout::NHWC;
+        return DataLayout::NHWC;
       }
     }
   }
@@ -227,8 +227,7 @@ bool CanBeModifiedImpl(pir::Operation* op) {
 }
 
 template <>
-void RewriteByLayoutImpl(pir::Operation* op,
-                         common::DataLayout new_layout) {
+void RewriteByLayoutImpl(pir::Operation* op, DataLayout new_layout) {
   PADDLE_THROW(common::errors::Unimplemented(
       "Op %s should have a specialized RewriteByLayout function", op->name()));
   return;
@@ -275,8 +274,7 @@ std::vector RelevantInputsImpl(pir::Operation* op) {
 }
 
 template <>
-void RewriteByLayoutImpl(pir::Operation* op,
-                         common::DataLayout new_layout) {
+void RewriteByLayoutImpl(pir::Operation* op, DataLayout new_layout) {
   // we must the value of concat axis, but this is an input
   // which is really hard to process.
   // here we handle the simple case like pd_op.full and throw
@@ -313,8 +311,7 @@ void RewriteByLayoutImpl(pir::Operation* op,
 }
 
 template <>
-void RewriteByLayoutImpl(pir::Operation* op,
-                         common::DataLayout new_layout) {
+void RewriteByLayoutImpl(pir::Operation* op, DataLayout new_layout) {
   auto concrete_op = op->dyn_cast();
   auto shape = concrete_op.shape();
 
@@ -333,10 +330,10 @@ void RewriteByLayoutImpl(pir::Operation* op,
           "Reshape's shape size was expected as 4, but got %d",
           value_attr.size()));
   std::vector new_value_attr;
-  if (new_layout == common::DataLayout::NHWC) {
+  if (new_layout == DataLayout::NHWC) {
     new_value_attr = std::vector{
         value_attr[0], value_attr[2], value_attr[3], value_attr[1]};
-  } else if (new_layout == common::DataLayout::NCHW) {
+  } else if (new_layout == DataLayout::NCHW) {
     new_value_attr = std::vector{
         value_attr[0], value_attr[3], value_attr[1], value_attr[2]};
   } else {
@@ -353,8 +350,7 @@ void RewriteByLayoutImpl(pir::Operation* op,
 }
 
 template <>
-void RewriteByLayoutImpl(pir::Operation* op,
-                         common::DataLayout new_layout) {
+void RewriteByLayoutImpl(pir::Operation* op, DataLayout new_layout) {
   auto concrete_op = op->dyn_cast();
   auto axis = concrete_op.axis();
   if (!axis || !(axis.defining_op()->isa())) {
@@ -381,7 +377,7 @@ void RewriteByLayoutImpl(pir::Operation* op,
 
 template <>
 void RewriteByLayoutImpl(pir::Operation* op,
-                         common::DataLayout new_layout) {
+                         DataLayout new_layout) {
   auto concrete_op = op->dyn_cast();
   auto out = concrete_op.out();
   if (!out) return;
@@ -403,7 +399,7 @@ std::vector RelevantInputsImpl(pir::Operation* op) {
 }
 
 template <>
-common::DataLayout PreferLayoutImpl(pir::Operation* op) {
+DataLayout PreferLayoutImpl(pir::Operation* op) {
   auto concrete_op = op->dyn_cast();
   auto data_format_attr = op->attribute("data_format");
   auto origin_format = common::StringToDataLayout(data_format_attr.AsString());
@@ -420,11 +416,11 @@ common::DataLayout PreferLayoutImpl(pir::Operation* op) {
 
   // get input dims h, w, c
   int32_t h, w, c;
-  if (origin_format == common::DataLayout::NHWC) {
+  if (origin_format == DataLayout::NHWC) {
     h = input.dims().at(1);
     w = input.dims().at(2);
     c = input.dims().at(3);
-  } else if (origin_format == common::DataLayout::NCHW) {
+  } else if (origin_format == DataLayout::NCHW) {
     h = input.dims().at(2);
     w = input.dims().at(3);
     c = input.dims().at(1);
@@ -453,61 +449,61 @@ common::DataLayout PreferLayoutImpl(pir::Operation* op) {
     };
     // TODO(liujinnan): need to test the prefer layout if kernel_size is not
     // aligned.
-    if (!AllEqual(kernel_size)) return common::DataLayout::NCHW;
+    if (!AllEqual(kernel_size)) return DataLayout::NCHW;
 
     int k = kernel_size[0];
     // kernel size is all 1, prefer NCHW.
-    if (k == 1 || k == 2) return common::DataLayout::NCHW;
+    if (k == 1 || k == 2) return DataLayout::NCHW;
     if (pool_type == "max") {
       if (h * w <= 64 * 64) {
-        if (k <= 3) return common::DataLayout::NHWC;
-        return common::DataLayout::NCHW;
+        if (k <= 3) return DataLayout::NHWC;
+        return DataLayout::NCHW;
       } else {
         if (c <= 16) {
-          if (k <= 5) return common::DataLayout::NHWC;
-          return common::DataLayout::NCHW;
+          if (k <= 5) return DataLayout::NHWC;
+          return DataLayout::NCHW;
         }
         // when c > 16, all kernel_size return NHWC
-        return common::DataLayout::NHWC;
+        return DataLayout::NHWC;
       }
     } else if (pool_type == "avg") {
       if (h * w <= 64 * 64) {
        if (c <= 16) {
          if (k < 7)
-            return common::DataLayout::NCHW;
+            return DataLayout::NCHW;
          else
-            return common::DataLayout::NHWC;
+            return DataLayout::NHWC;
        } else if (c > 16 && c <= 32) {
          if (k <= 7)
-            return common::DataLayout::NHWC;
+            return DataLayout::NHWC;
          else
-            return common::DataLayout::NCHW;
+            return DataLayout::NCHW;
        } else if (c > 32 && c <= 64) {
          if (k < 5)
-            return common::DataLayout::NCHW;
+            return DataLayout::NCHW;
          else
-            return common::DataLayout::NHWC;
+            return DataLayout::NHWC;
        } else if (c > 64 && c <= 128) {
          if (k < 7)
-            return common::DataLayout::NHWC;
+            return DataLayout::NHWC;
          else
-            return common::DataLayout::NCHW;
+            return DataLayout::NCHW;
        }
        // when c > 128, all kernel_size return NHWC
-        return common::DataLayout::NHWC;
+        return DataLayout::NHWC;
      } else {
        if (c < 64) {
-          return common::DataLayout::NHWC;
+          return DataLayout::NHWC;
        } else {
          if (k < 7)
-            return common::DataLayout::NHWC;
+            return DataLayout::NHWC;
          else
-            return common::DataLayout::NCHW;
+            return DataLayout::NCHW;
        }
      }
    }
-    return common::DataLayout::NCHW;
+    return DataLayout::NCHW;
   };
 
   return PreferLayout(c, h, w, kernel_size, pool_type);
 }
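
Note on the rename above: every change in this diff drops the common:: qualifier from DataLayout, which only compiles if an equivalent using-declaration is visible in this translation unit. That declaration is not shown in these hunks, so it is treated here as an assumption. The standalone sketch below is not Paddle code; the enum and the PreferNHWC() helper are hypothetical stand-ins used only to show why the qualified and unqualified spellings are interchangeable once such a declaration is in scope.

// Minimal sketch (assumed context, not part of the patch): with a
// using-declaration in scope, "DataLayout" and "common::DataLayout"
// name the same enum, so the rewritten returns are equivalent.
#include <iostream>
#include <type_traits>

namespace common {
enum class DataLayout { NCHW, NHWC };  // stand-in for paddle's common::DataLayout
}  // namespace common

using common::DataLayout;  // the declaration the shortened spelling relies on

static_assert(std::is_same_v<DataLayout, common::DataLayout>,
              "unqualified and qualified names refer to the same type");

// Hypothetical stand-in for a PreferLayoutImpl specialization's return path.
DataLayout PreferNHWC() { return DataLayout::NHWC; }

int main() {
  std::cout << (PreferNHWC() == common::DataLayout::NHWC) << "\n";  // prints 1
}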