diff --git a/paddle/fluid/pir/dialect/kernel/ir/kernel_op.h b/paddle/fluid/pir/dialect/kernel/ir/kernel_op.h index 56abb10f16189d..b358a1387ff920 100644 --- a/paddle/fluid/pir/dialect/kernel/ir/kernel_op.h +++ b/paddle/fluid/pir/dialect/kernel/ir/kernel_op.h @@ -109,6 +109,10 @@ class OneDNNLegacyKernelOp : public pir::Op { } // namespace dialect } // namespace paddle +namespace pir { +using paddle::dialect::PhiKernelOp; +} + IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::PhiKernelOp) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::LegacyKernelOp) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::CustomKernelOp) diff --git a/paddle/fluid/pir/transforms/general/remove_shadow_feed_pass.cc b/paddle/fluid/pir/transforms/general/remove_shadow_feed_pass.cc index 5ef855f66ace32..52041e5c7ebcf4 100644 --- a/paddle/fluid/pir/transforms/general/remove_shadow_feed_pass.cc +++ b/paddle/fluid/pir/transforms/general/remove_shadow_feed_pass.cc @@ -28,13 +28,12 @@ #include "paddle/pir/include/pass/pass_registry.h" #include "paddle/pir/include/pattern_rewrite/pattern_match.h" -namespace { +namespace pir { -std::unique_ptr GetParser( - pir::Operation *op) { +std::unique_ptr GetParser(Operation *op) { std::unique_ptr op_info_parser(nullptr); - std::string op_name = op->dyn_cast().op_name(); - auto op_info = pir::IrContext::Instance()->GetRegisteredOpInfo(op_name); + std::string op_name = op->dyn_cast().op_name(); + auto op_info = IrContext::Instance()->GetRegisteredOpInfo(op_name); if (op_info.HasInterface()) { auto impl = op_info.GetInterfaceImpl(); @@ -58,15 +57,13 @@ phi::Place GetVarPlace(const paddle::framework::Variable *var, return place; } -class RemoveShadowFeedPattern - : public pir::OpRewritePattern { +class RemoveShadowFeedPattern : public OpRewritePattern { public: - explicit RemoveShadowFeedPattern(pir::IrContext *context, - const pir::Block *block, + explicit RemoveShadowFeedPattern(IrContext *context, + const Block *block, const phi::Place &place, const paddle::framework::Scope *scope) - : pir::OpRewritePattern::OpRewritePattern( - context), + : OpRewritePattern::OpRewritePattern(context), place_(place), scope_(scope), kwargs_map_() { @@ -75,18 +72,17 @@ class RemoveShadowFeedPattern } } - bool IsSamePlaceShadowFeed(paddle::dialect::PhiKernelOp op) const { + bool IsSamePlaceShadowFeed(PhiKernelOp op) const { if (op.op_name() == "pd_op.shadow_feed") { auto in = op.operand_source(0); auto *var = [&]() -> paddle::framework::Variable * { auto *defined_op = in.defining_op(); - if (defined_op && defined_op->isa()) { - if (defined_op->dyn_cast() - .kernel_name() != "data") + if (defined_op && defined_op->isa()) { + if (defined_op->dyn_cast().kernel_name() != "data") return nullptr; const auto &name = defined_op->attributes() .at("name") - .dyn_cast() + .dyn_cast() .AsString(); return scope_->FindVar(name); } @@ -114,7 +110,7 @@ class RemoveShadowFeedPattern } int dst_place_type = - op.attribute("dst_place_type").dyn_cast().data(); + op.attribute("dst_place_type").dyn_cast().data(); if (dst_place_type == 0) { dst_place = phi::CPUPlace(); } else { @@ -126,7 +122,7 @@ class RemoveShadowFeedPattern return false; } - bool IsTensorAttrShadowFeed(paddle::dialect::PhiKernelOp op) const { + bool IsTensorAttrShadowFeed(PhiKernelOp op) const { if (op.op_name() == "pd_op.shadow_feed") { auto in = op.operand_source(0); if (!kwargs_map_.count(in)) { @@ -135,7 +131,7 @@ class RemoveShadowFeedPattern auto out = op.result(0); if (out.use_count() == 1) { auto use_op = out.first_use().owner(); - if (!use_op->isa()) { + if (!use_op->isa()) { return false; } auto op_info_parser = GetParser(use_op); @@ -150,12 +146,12 @@ class RemoveShadowFeedPattern return false; } - bool Match(paddle::dialect::PhiKernelOp op) const override { + bool Match(PhiKernelOp op) const override { return IsSamePlaceShadowFeed(op) || IsTensorAttrShadowFeed(op); } - void Rewrite(paddle::dialect::PhiKernelOp op, - pir::PatternRewriter &rewriter) const override { // NOLINT + void Rewrite(PhiKernelOp op, + PatternRewriter &rewriter) const override { // NOLINT auto in = op.operand_source(0); auto out = op.result(0); in.set_type(out.type()); @@ -166,22 +162,20 @@ class RemoveShadowFeedPattern private: const phi::Place place_; const paddle::framework::Scope *scope_; - std::unordered_map<::pir::Value, std::string> kwargs_map_; + std::unordered_map kwargs_map_; }; -class RemoveShadowFeedPatternInference - : public pir::OpRewritePattern { +class RemoveShadowFeedPatternInference : public OpRewritePattern { public: - explicit RemoveShadowFeedPatternInference(pir::IrContext *context) - : pir::OpRewritePattern::OpRewritePattern( - context) {} + explicit RemoveShadowFeedPatternInference(IrContext *context) + : OpRewritePattern::OpRewritePattern(context) {} - bool Match(paddle::dialect::PhiKernelOp op) const override { + bool Match(PhiKernelOp op) const override { return op.op_name() == "pd_op.shadow_feed"; } - void Rewrite(paddle::dialect::PhiKernelOp op, - pir::PatternRewriter &rewriter) const override { // NOLINT + void Rewrite(PhiKernelOp op, + PatternRewriter &rewriter) const override { // NOLINT auto in = op.operand_source(0); auto out = op.result(0); in.set_type(out.type()); @@ -190,13 +184,12 @@ class RemoveShadowFeedPatternInference } }; -class RemoveShadowFeedPass : public pir::PatternRewritePass { +class RemoveShadowFeedPass : public PatternRewritePass { public: - RemoveShadowFeedPass() - : pir::PatternRewritePass("remove_shadow_feed_pass", 0) {} + RemoveShadowFeedPass() : PatternRewritePass("remove_shadow_feed_pass", 0) {} - pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override { - pir::RewritePatternSet ps(context); + RewritePatternSet InitializePatterns(IrContext *context) override { + RewritePatternSet ps(context); if (Has("used_for_inference") && Get("used_for_inference")) { ps.Add(context); } else { @@ -208,23 +201,22 @@ class RemoveShadowFeedPass : public pir::PatternRewritePass { "When using RemoveShadowFeedPass, block attribute is required!" "Use Set method to set the place attribute.")); PADDLE_ENFORCE_EQ( - Has(pir::Pass::kPlaceAttr), + Has(Pass::kPlaceAttr), true, common::errors::InvalidArgument( "Pass initialize failed." "When using RemoveShadowFeedPass, place attribute is required!" "Use Set method to set the place attribute.")); PADDLE_ENFORCE_EQ( - Has(pir::Pass::kParamScopeAttr), + Has(Pass::kParamScopeAttr), true, common::errors::InvalidArgument( "Pass initialize failed." "When using RemoveShadowFeedPass, scope attribute is required!" "Use Set method to set the scope attribute.")); - auto block = &Get("top_block"); - auto &place = Get(pir::Pass::kPlaceAttr); - auto scope = - &Get(pir::Pass::kParamScopeAttr); + auto block = &Get("top_block"); + auto &place = Get(Pass::kPlaceAttr); + auto scope = &Get(Pass::kParamScopeAttr); PADDLE_ENFORCE_NOT_NULL( block, common::errors::InvalidArgument("block can not be nullptr")); PADDLE_ENFORCE_NOT_NULL( @@ -236,11 +228,7 @@ class RemoveShadowFeedPass : public pir::PatternRewritePass { } }; -} // namespace - -namespace pir { - -std::unique_ptr CreateRemoveShadowFeedPass() { +std::unique_ptr CreateRemoveShadowFeedPass() { return std::make_unique(); }