@@ -2253,6 +2253,66 @@ IMPLEMENT_GRADIENT_BUILDER(GetGlobalMaxPoolGradient) {
 
   result.push_back(NodeDef("Expand", {GO(0), IA("X_shape")}, {IA("expanded_dY")}));
   result.push_back(NodeDef("Mul", {IA("mask_cast"), IA("expanded_dY")}, {GI(0)}));
+  return result;
+}
+
+IMPLEMENT_GRADIENT_BUILDER(GetReduceMaxGradient) {
+  std::vector<NodeDef> result;
+  auto attributes = SrcNodeAttributes();
+  bool keepdims = true;
+
+  // Check the "keepdims" attribute
+  if (attributes.find("keepdims") != attributes.end() &&
+      attributes.at("keepdims").has_i()) {
+    keepdims = static_cast<bool>(attributes.at("keepdims").i());
+  }
+
+  ArgDef grad = GO(0);
+  ArgDef reduced_output = O(0);
+
+  if (!keepdims) {
+    size_t numInputs = GetSrcNodeInputSize();
+    ArgDef unsqueeze_axes_arg;
+    bool axes_provided = false;
+
+    // Handle "axes" as attribute or input
+    if (attributes.find("axes") != attributes.end()) {
+      axes_provided = true;
+      std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
+      if (SrcNodeOpsetVersion() >= 13) {
+        NodeDef axes_values_node = ConstantVectorNode(axes_values, Name("axes_values"));
+        result.push_back(axes_values_node);
+        unsqueeze_axes_arg = axes_values_node.output_args[0];
+      }
+    } else if (numInputs == 2) {
+      axes_provided = true;
+      unsqueeze_axes_arg = I(1);
+    }
+
+    if (axes_provided) {
+      grad = IA("Unsqueezed_Grad");
+      reduced_output = IA("Unsqueezed_Output");
+      if (SrcNodeOpsetVersion() < 13 && attributes.find("axes") != attributes.end()) {
+        std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
+        result.push_back(NodeDef("Unsqueeze", {GO(0)}, {grad}, {MakeAttribute("axes", axes_values)}));
+        result.push_back(NodeDef("Unsqueeze", {O(0)}, {reduced_output}, {MakeAttribute("axes", axes_values)}));
+      } else {
+        result.push_back(NodeDef(OpDef{"Unsqueeze", kOnnxDomain, 13}, {GO(0), unsqueeze_axes_arg}, {grad}));
+        result.push_back(NodeDef(OpDef{"Unsqueeze", kOnnxDomain, 13}, {O(0), unsqueeze_axes_arg}, {reduced_output}));
+      }
+    }
+  }
+
+  // Step 1: Recreate the boolean mask tensor indicating max positions
+  result.push_back(NodeDef("Shape", {I(0)}, {IA("Shaped_X")}));
+  result.push_back(NodeDef("Expand", {reduced_output, IA("Shaped_X")}, {IA("Expanded_Output")}));
+  result.push_back(NodeDef("Equal", {I(0), IA("Expanded_Output")}, {IA("Mask")}));
+  // Step 2: Convert the boolean mask to a float tensor (0.0 and 1.0)
+  result.push_back(NodeDef("Cast", {IA("Mask")}, {IA("Mask_Float")}, {MakeAttribute("to", static_cast<int64_t>(OElemType(0)))}));
+  // Step 3: Multiply the input gradient by the mask
+  result.push_back(NodeDef("Mul", {grad, IA("Mask_Float")}, {IA("Masked_Grad")}));
+  // Step 4: Ensure the output gradient has the same shape as the input
+  result.push_back(NodeDef("Expand", {IA("Masked_Grad"), IA("Shaped_X")}, {GI(0)}));
 
   return result;
 }
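For reference, the subgraph emitted above implements the standard mask-based ReduceMax gradient: dX = Cast(X == Expand(O)) * dY, with the Unsqueeze nodes restoring the reduced axes when keepdims=0 so that dY and O broadcast against X. Below is a minimal standalone sketch of the same numerics, assuming a row-major [rows x cols] matrix reduced over its last axis; plain loops stand in for the ONNX nodes, and reduce_max_grad is a hypothetical helper for illustration, not part of this change.

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

// Hypothetical sketch: ReduceMax over the last axis of a row-major
// [rows x cols] matrix, followed by the mask-based backward pass the
// builder above emits (Equal -> Cast -> Mul -> Expand).
std::vector<float> reduce_max_grad(const std::vector<float>& x,
                                   const std::vector<float>& dy,  // one value per row
                                   size_t rows, size_t cols) {
  std::vector<float> dx(rows * cols, 0.0f);
  for (size_t r = 0; r < rows; ++r) {
    // Forward: the reduced output, i.e. the row maximum.
    float row_max = x[r * cols];
    for (size_t c = 1; c < cols; ++c) row_max = std::max(row_max, x[r * cols + c]);
    // Backward: mask = (x == expanded max); dx = mask * expanded dy.
    // Indexing dy[r] for every column plays the role of the Unsqueeze +
    // broadcast that the keepdims=0 path sets up in the graph.
    for (size_t c = 0; c < cols; ++c) {
      float mask = (x[r * cols + c] == row_max) ? 1.0f : 0.0f;
      dx[r * cols + c] = mask * dy[r];
    }
  }
  return dx;
}

int main() {
  // x = [[1, 3, 3], [2, 0, 1]], dy = [10, 20]
  std::vector<float> x = {1, 3, 3, 2, 0, 1};
  std::vector<float> dy = {10, 20};
  for (float v : reduce_max_grad(x, dy, 2, 3)) std::cout << v << ' ';
  std::cout << '\n';  // prints: 0 10 10 20 0 0
}

As in GetGlobalMaxPoolGradient above, ties are not split: every element equal to the maximum receives the full upstream gradient (10 appears twice in the example output), because the Equal mask marks all positions that attain the maximum.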