
Commit b7886d0

NXP backend: Remove IR optimization to remove dead branches. (#13574)
### Summary

This PR replaces an IR optimization that removes dead code from the model with an equivalent executorch call.

### Test plan

Unit test provided in `backends/nxp/tests/test_removing_dead_code.py`.

cc @digantdesai @JakeStevens @robert-kalmar
1 parent 52128bd commit b7886d0
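For context, the dead-branch elimination that used to live in the TFLite IR optimizer is now handled by the standard `torch.fx.Graph.eliminate_dead_code()` call. A minimal sketch of that behavior on a toy module (the module and the prints are illustrative only, not part of this PR):

```python
import torch
import torch.fx


class Demo(torch.nn.Module):  # toy module, not from this PR
    def forward(self, x):
        _ = torch.add(x, x)  # result is never used -> dead code
        return torch.mul(x, x)


gm = torch.fx.symbolic_trace(Demo())
print(any("add" in str(n.target) for n in gm.graph.nodes))  # True: dead `add` node present

# Remove nodes whose outputs are unused, then regenerate the module's Python code.
gm.graph.eliminate_dead_code()
gm.recompile()
print(any("add" in str(n.target) for n in gm.graph.nodes))  # False: dead `add` node removed
```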

4 files changed (+69, -92 lines)

backends/nxp/backend/ir/tflite_optimizer/optimizations/eliminate_dead_branches.py

Lines changed: 0 additions & 82 deletions
This file was deleted.

backends/nxp/backend/ir/tflite_optimizer/optimizer.py

Lines changed: 1 addition & 8 deletions
@@ -1,6 +1,6 @@
 #
 # Copyright 2023 Martin Pavella
-# Copyright 2024 NXP
+# Copyright 2024-2025 NXP
 #
 # License: MIT
 # See the LICENSE_MIT for more details.
@@ -14,9 +14,6 @@
 from executorch.backends.nxp.backend.ir.tflite_optimizer.optimizations.combine_hard_sigmoid_and_mul_to_hard_swish import (
     CombineHardSigmoidAndMulIntoHardSwish,
 )
-from executorch.backends.nxp.backend.ir.tflite_optimizer.optimizations.eliminate_dead_branches import (
-    EliminateDeadBranches,
-)
 from executorch.backends.nxp.backend.ir.tflite_optimizer.optimizations.fuse_activation_functions import (
     FuseActivationFunctions,
 )
@@ -57,7 +54,6 @@ class Optimization(Enum):
     FUSE_PARALLEL_QUANTIZE_OPERATORS = 8
 
     REMOVE_UNUSED_TENSORS = 10
-    ELIMINATE_DEAD_BRANCHES = 11
     PERMUTE_FULLY_CONNECTED_WEIGHTS_AFTER_RESHAPE = 12
 
     MOVE_ACTIVATION_BEFORE_CONCAT = 15
@@ -115,9 +111,6 @@ def __init__(
             Optimization.REMOVE_UNUSED_TENSORS: RemoveUnusedTensorsAndBuffers(
                 builder, conversion_config
             ),
-            Optimization.ELIMINATE_DEAD_BRANCHES: EliminateDeadBranches(
-                builder, conversion_config
-            ),
             Optimization.PERMUTE_FULLY_CONNECTED_WEIGHTS_AFTER_RESHAPE: PermuteFullyConnectedWeightsAfterReshape(
                 builder, conversion_config
             ),

backends/nxp/quantizer/neutron_quantizer.py

Lines changed: 8 additions & 2 deletions
@@ -7,6 +7,7 @@
 from typing import List, Optional, Tuple, Union
 
 import torch
+
 from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import (
     NeutronAtenPassManager,
 )
@@ -242,8 +243,13 @@ def __init__(self):
     def transform_for_annotation(
         self, model: torch.fx.GraphModule
     ) -> torch.fx.GraphModule:
-        pass_runner = NeutronAtenPassManager()
-        return pass_runner(model).graph_module
+        model.graph.eliminate_dead_code()  # Remove dead code to simplify the graph for the passes.
+
+        model = NeutronAtenPassManager()(model).graph_module
+
+        model.graph.eliminate_dead_code()  # Remove dead code again, in case it was created by the passes.
+
+        return model
 
     def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
         self._annotate_inputs(model)
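The second `eliminate_dead_code()` call guards against dead nodes introduced by the passes themselves, e.g. when a pass reroutes a node's users to a replacement but leaves the original node in the graph. A rough sketch of that pattern with a hypothetical rewrite (the relu-to-clamp pass below is illustrative only, not one of the Neutron passes):

```python
import torch
import torch.fx


class M(torch.nn.Module):  # toy module, not from this PR
    def forward(self, x):
        return torch.relu(x)


gm = torch.fx.symbolic_trace(M())

# Hypothetical pass: replace relu with clamp(min=0), but leave the old node behind.
for node in list(gm.graph.nodes):
    if node.op == "call_function" and "relu" in str(node.target):
        with gm.graph.inserting_after(node):
            clamp = gm.graph.call_function(
                torch.clamp, args=(node.args[0],), kwargs={"min": 0}
            )
        node.replace_all_uses_with(clamp)  # old relu node now has no users -> dead

gm.graph.eliminate_dead_code()  # the second call in transform_for_annotation cleans this up
gm.recompile()
```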
backends/nxp/tests/test_removing_dead_code.py

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
+# Copyright 2025 NXP
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import numpy as np
+import pytest
+import torch
+
+from executorch.backends.nxp.tests.executorch_pipeline import _quantize_model
+from executorch.backends.nxp.tests.executors import graph_contains_any_of_ops
+
+
+@pytest.fixture(autouse=True)
+def reseed_model_per_test_run():
+    torch.manual_seed(42)
+    np.random.seed(23)
+
+
+class DeadCodeModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.eval()
+
+    def forward(self, x):
+        _ = torch.add(x, x)  # Dead code
+        return torch.mul(x, x)
+
+
+class TestRemovingDeadCode(unittest.TestCase):
+    __test__ = False  # Prevent interfering with PyTest tests
+
+    def test_removing_dead_code(self):
+        input_shape = (42,)
+        example_inputs = (torch.ones(input_shape),)
+        model = DeadCodeModule()
+
+        exir_program_aten = torch.export.export(model, example_inputs, strict=True)
+
+        # Make sure the model contains the dead code.
+        assert graph_contains_any_of_ops(
+            exir_program_aten.module().graph, [torch.ops.aten.add.Tensor]
+        )
+
+        # The `NeutronQuantizer` should remove the dead code in the `transform_for_annotation()` method.
+        exir_program_aten_quant = _quantize_model(
+            exir_program_aten.module(), [example_inputs]
+        )
+
+        # Make sure there is no `add` operation in the graph anymore.
+        assert not any(
+            "add" in str(node.target) for node in exir_program_aten_quant.graph.nodes
+        )
+
+    @classmethod
+    def setUpClass(cls):
+        torch.manual_seed(23)
+        np.random.seed(23)
