Skip to content

Commit 7d2b8c6

Browse files
authored
Arm backend: Add correction for floor mode (#14776)
Correct the implementation of div.tensor_mode for the 'floor' case to make it numerically stable.

Signed-off-by: Elena Zhelezina <[email protected]>
1 parent 9be3aaa commit 7d2b8c6

File tree

2 files changed

+74
-4
lines changed

2 files changed

+74
-4
lines changed

backends/arm/_passes/decompose_div_tensor_mode.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
"full": exir_ops.edge.aten.full.default,
2323
"lt": exir_ops.edge.aten.lt.Tensor,
2424
"where": exir_ops.edge.aten.where.self,
25+
"mul": exir_ops.edge.aten.mul.Tensor,
26+
"sub": exir_ops.edge.aten.sub.Tensor,
2527
}
2628

2729
aten_unary = {
@@ -31,6 +33,8 @@
3133
"full": torch.ops.aten.full.default,
3234
"lt": torch.ops.aten.lt.Tensor,
3335
"where": torch.ops.aten.where.self,
36+
"mul": torch.ops.aten.mul.Tensor,
37+
"sub": torch.ops.aten.sub.Tensor,
3438
}
3539

3640

@@ -70,13 +74,57 @@ def call_operator(self, op, args, kwargs, meta):
7074
return q
7175

7276
if rounding_mode == "floor":
73-
return super().call_operator(opset["floor"], (q,), {}, meta)
77+
q_raw = q
78+
79+
# trunc(q_raw) = where(q_raw < 0, ceil(q_raw), floor(q_raw))
80+
q_floor = super().call_operator(opset["floor"], (q_raw,), {}, meta)
81+
q_ceil = super().call_operator(opset["ceil"], (q_raw,), {}, meta)
82+
83+
# a zero tensor of shape (1,) * rank (same rank as the output, every dim 1)
84+
out_shape = (1,) * len(meta["val"].size())
85+
zero = super().call_operator(
86+
opset["full"],
87+
args=(out_shape, 0.0),
88+
kwargs={},
89+
meta=meta,
90+
)
91+
92+
is_neg = super().call_operator(opset["lt"], (q_raw, zero), {}, meta)
93+
q_trunc = super().call_operator(
94+
opset["where"], (is_neg, q_ceil, q_floor), {}, meta
95+
)
96+
97+
# r = a - q_trunc * b (true remainder under truncation)
98+
q_times_b = super().call_operator(opset["mul"], (q_trunc, b), {}, meta)
99+
r = super().call_operator(opset["sub"], (a, q_times_b), {}, meta)
100+
101+
# Decide if we need to subtract 1:
102+
# for b > 0, adjust if r < 0; otherwise (b not > 0), adjust if r > 0.
103+
b_pos = super().call_operator(opset["lt"], (zero, b), {}, meta) # b > 0
104+
r_lt0 = super().call_operator(opset["lt"], (r, zero), {}, meta) # r < 0
105+
r_gt0 = super().call_operator(opset["lt"], (zero, r), {}, meta) # r > 0
106+
107+
adjust_if = super().call_operator(
108+
opset["where"], (b_pos, r_lt0, r_gt0), {}, meta
109+
)
110+
111+
one = super().call_operator(
112+
opset["full"],
113+
args=(out_shape, 1.0),
114+
kwargs={},
115+
meta=meta,
116+
)
117+
q_minus_1 = super().call_operator(opset["sub"], (q_trunc, one), {}, meta)
118+
119+
return super().call_operator(
120+
opset["where"], (adjust_if, q_minus_1, q_trunc), {}, meta
121+
)
74122

75123
if rounding_mode == "trunc":
76124
zero = super().call_operator(
77125
opset["full"],
78126
args=((1,) * len(meta["val"].size()), 0.0),
79-
kwargs={"dtype": torch.float32},
127+
kwargs={},
80128
meta=meta,
81129
)
82130
lt0 = self.call_operator(opset["lt"], (q, zero), {}, meta)

backends/arm/test/ops/test_div_tensor_mode.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,14 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
3636
return torch.div(x, y, rounding_mode=self.mode)
3737

3838

39+
def _rank4_large_randn_case():
40+
torch.manual_seed(0)
41+
x = 200 * torch.randn(5, 10, 25, 20) + 1
42+
torch.manual_seed(1)
43+
y = torch.rand(5, 10, 25, 20) + 1
44+
return x, y
45+
46+
3947
test_data = {
4048
"mode_none": lambda: (None, (torch.randn(4, 8), torch.randn(4, 8).abs() + 1e-3)),
4149
"mode_floor": lambda: (
@@ -47,6 +55,13 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
4755
(torch.randn(4, 8), torch.randn(4, 8).abs() + 1e-3),
4856
),
4957
"int_denominator": lambda: (None, (torch.randn(4, 8), 2)),
58+
"op_floor_div_rank4_large_randn": lambda: (
59+
"floor",
60+
(
61+
200 * torch.randn(5, 10, 25, 20) + 1,
62+
torch.rand(5, 10, 25, 20) + 1,
63+
),
64+
),
5065
}
5166

5267

@@ -84,7 +99,13 @@ def test_div_tensor_mode_tosa_INT(data):
8499

85100
@common.XfailIfNoCorstone300
86101
@common.parametrize(
87-
"data", test_data, xfails={"mode_trunc": "CPU op missing in unittests"}
102+
"data",
103+
test_data,
104+
xfails={
105+
"mode_trunc": "CPU op missing in unittests",
106+
"mode_floor": "Not supported",
107+
"op_floor_div_rank4_large_randn": "Not supported",
108+
},
88109
)
89110
def test_div_tensor_mode_u55_INT(data):
90111
mode, inputs = data()
@@ -94,9 +115,10 @@ def test_div_tensor_mode_u55_INT(data):
94115
model,
95116
inputs,
96117
aten_ops=model.aten_ops_int,
97-
exir_ops=[],
98118
use_to_edge_transform_and_lower=True,
99119
)
120+
pipeline.pop_stage("check_not.exir")
121+
pipeline.pop_stage("check_count.exir")
100122
pipeline.run()
101123

102124

0 commit comments

Comments
 (0)