diff --git a/tools/pnnx/src/pass_level5/fuse_static_prelu.cpp b/tools/pnnx/src/pass_level5/fuse_static_prelu.cpp index 1d61af6b9cfd..1b5e7289bec2 100644 --- a/tools/pnnx/src/pass_level5/fuse_static_prelu.cpp +++ b/tools/pnnx/src/pass_level5/fuse_static_prelu.cpp @@ -35,12 +35,44 @@ pnnx.Output output 1 0 out } }; +class convert_prelu_to_leakyrelu : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.PReLU op_0 1 1 input out num_parameters=1 +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "nn.LeakyReLU"; + } + + const char* name_str() const + { + return "leakyrelu"; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + const Attribute& weight = captured_attrs.at("op_0.weight"); + op->params["negative_slope"] = weight.get_float32_data()[0]; + } +}; + void fuse_static_prelu(Graph& graph) { fuse_static_Fprelu_pass a; + convert_prelu_to_leakyrelu b; int opindex = 0; pnnx_graph_rewrite(graph, &a, opindex); + pnnx_graph_rewrite(graph, &b, opindex); } } // namespace pnnx