Skip to content

Commit c541f53

Browse files
cocotdfjatinwadhwa921
authored andcommitted
Add of ReduceMax Gradient (microsoft#23501)
1 parent e03ef74 commit c541f53

File tree

4 files changed

+86
-0
lines changed

4 files changed

+86
-0
lines changed

orttraining/orttraining/core/graph/gradient_builder.cc

+60
Original file line numberDiff line numberDiff line change
@@ -2253,6 +2253,66 @@ IMPLEMENT_GRADIENT_BUILDER(GetGlobalMaxPoolGradient) {
22532253

22542254
result.push_back(NodeDef("Expand", {GO(0), IA("X_shape")}, {IA("expanded_dY")}));
22552255
result.push_back(NodeDef("Mul", {IA("mask_cast"), IA("expanded_dY")}, {GI(0)}));
2256+
return result;
2257+
}
2258+
2259+
IMPLEMENT_GRADIENT_BUILDER(GetReduceMaxGradient) {
2260+
std::vector<NodeDef> result;
2261+
auto attributes = SrcNodeAttributes();
2262+
bool keepdims = true;
2263+
2264+
// Check the "keepdims" attribute
2265+
if (attributes.find("keepdims") != attributes.end() &&
2266+
attributes.at("keepdims").has_i()) {
2267+
keepdims = static_cast<bool>(attributes.at("keepdims").i());
2268+
}
2269+
2270+
ArgDef grad = GO(0);
2271+
ArgDef reduced_output = O(0);
2272+
2273+
if (!keepdims) {
2274+
size_t numInputs = GetSrcNodeInputSize();
2275+
ArgDef unsqueeze_axes_arg;
2276+
bool axes_provided = false;
2277+
2278+
// Handle "axes" as attribute or input
2279+
if (attributes.find("axes") != attributes.end()) {
2280+
axes_provided = true;
2281+
std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
2282+
if (SrcNodeOpsetVersion() >= 13) {
2283+
NodeDef axes_values_node = ConstantVectorNode(axes_values, Name("axes_values"));
2284+
result.push_back(axes_values_node);
2285+
unsqueeze_axes_arg = axes_values_node.output_args[0];
2286+
}
2287+
} else if (numInputs == 2) {
2288+
axes_provided = true;
2289+
unsqueeze_axes_arg = I(1);
2290+
}
2291+
2292+
if (axes_provided) {
2293+
grad = IA("Unsqueezed_Grad");
2294+
reduced_output = IA("Unsqueezed_Output");
2295+
if (SrcNodeOpsetVersion() < 13 && attributes.find("axes") != attributes.end()) {
2296+
std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
2297+
result.push_back(NodeDef("Unsqueeze", {GO(0)}, {grad}, {MakeAttribute("axes", axes_values)}));
2298+
result.push_back(NodeDef("Unsqueeze", {O(0)}, {reduced_output}, {MakeAttribute("axes", axes_values)}));
2299+
} else {
2300+
result.push_back(NodeDef(OpDef{"Unsqueeze", kOnnxDomain, 13}, {GO(0), unsqueeze_axes_arg}, {grad}));
2301+
result.push_back(NodeDef(OpDef{"Unsqueeze", kOnnxDomain, 13}, {O(0), unsqueeze_axes_arg}, {reduced_output}));
2302+
}
2303+
}
2304+
}
2305+
2306+
// Step 1: Recreate the boolean mask tensor indicating max positions
2307+
result.push_back(NodeDef("Shape", {I(0)}, {IA("Shaped_X")}));
2308+
result.push_back(NodeDef("Expand", {reduced_output, IA("Shaped_X")}, {IA("Expanded_Output")}));
2309+
result.push_back(NodeDef("Equal", {I(0), IA("Expanded_Output")}, {IA("Mask")}));
2310+
// Step 2: Convert the boolean mask to a float tensor (0.0 and 1.0)
2311+
result.push_back(NodeDef("Cast", {IA("Mask")}, {IA("Mask_Float")}, {MakeAttribute("to", static_cast<int64_t>(OElemType(0)))}));
2312+
// Step 3: Multiply the input gradient by the mask
2313+
result.push_back(NodeDef("Mul", {grad, IA("Mask_Float")}, {IA("Masked_Grad")}));
2314+
// Step 4: Ensure the output gradient has the same shape as the input
2315+
result.push_back(NodeDef("Expand", {IA("Masked_Grad"), IA("Shaped_X")}, {GI(0)}));
22562316

22572317
return result;
22582318
}

orttraining/orttraining/core/graph/gradient_builder.h

+1
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ DECLARE_GRADIENT_BUILDER(GetConvTransposeGradient)
9595
DECLARE_GRADIENT_BUILDER(GetResizeGradient)
9696
DECLARE_GRADIENT_BUILDER(GetAtanGradient)
9797
DECLARE_GRADIENT_BUILDER(GetGlobalMaxPoolGradient)
98+
DECLARE_GRADIENT_BUILDER(GetReduceMaxGradient)
9899

99100
DECLARE_GRADIENT_BUILDER(GetExternalGradient)
100101

orttraining/orttraining/core/graph/gradient_builder_registry.cc

+1
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
127127
REGISTER_GRADIENT_BUILDER("Resize", GetResizeGradient);
128128
REGISTER_GRADIENT_BUILDER("Atan", GetAtanGradient);
129129
REGISTER_GRADIENT_BUILDER("GlobalMaxPool", GetGlobalMaxPoolGradient);
130+
REGISTER_GRADIENT_BUILDER("ReduceMax", GetReduceMaxGradient);
130131

131132
REGISTER_GRADIENT_BUILDER("ExternalGradient", GetExternalGradient);
132133
};

orttraining/orttraining/test/gradient/gradient_ops_test.cc

+24
Original file line numberDiff line numberDiff line change
@@ -3379,6 +3379,30 @@ TEST(GradientCheckerTest, GlobalMaxPoolGrad) {
33793379
}
33803380
}
33813381

3382+
TEST(GradientCheckerTest, ReduceMaxGrad) {
3383+
// Attribute axes supports negative values from opset 11.
3384+
OpDef op_def_11{"ReduceMax", kOnnxDomain, 11};
3385+
3386+
RunReductionTests(op_def_11, false, true);
3387+
3388+
OpDef op_def_12{"ReduceMax", kOnnxDomain, 12};
3389+
3390+
RunReductionTests(op_def_12, false, true);
3391+
3392+
OpDef op_def_13{"ReduceMax", kOnnxDomain, 13};
3393+
3394+
RunReductionTests(op_def_13, false, true);
3395+
3396+
// axes is input from opset 18.
3397+
OpDef op_def_18{"ReduceMax", kOnnxDomain, 18};
3398+
3399+
RunReductionTests(op_def_18, true, true);
3400+
3401+
OpDef op_def_20{"ReduceMax", kOnnxDomain, 20};
3402+
3403+
RunReductionTests(op_def_20, true, true);
3404+
}
3405+
33823406
} // namespace test
33833407
} // namespace onnxruntime
33843408

0 commit comments

Comments
 (0)