
Commit 3087578

Add pass to remove unused parameters in to_executorch
Summary: Currently, ExecuTorch serializes every parameter in the exported program, regardless of whether it is actually used. Exporting with strict=True removes unused parameters, but strict=False does not, and export recently switched to non-strict as the default behavior. This causes PTE bloat when doing pt2e quantization (unquantized weights are left in the graph) and sometimes when exporting multiple methods (an encoder and a decoder, for example).

This PR adds a new pass (`remove_unused_parameters_pass`) to strip unused parameters from the `ExportedProgram`. It runs as part of `to_executorch`. A parameter is considered unused if its placeholder node has no uses; it is removed by stripping it from the state_dict, the input specs, and the graph.

As a question for reviewers: should we run this pass earlier, as part of `to_edge`? My rationale for running it in `to_executorch` was that it could theoretically clean up anything else left behind by partitioning and lowering, but I'm not aware of any concrete use cases for this.

Differential Revision: D73654202
1 parent 7e034ca commit 3087578

File tree

6 files changed: +203 −0 lines
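Before the diffs, a minimal sketch (not part of the commit; the Toy model and its attribute names are invented for illustration) of what "unused" means here. After a non-strict export, every parameter still shows up as a placeholder node in the graph, and an unused one simply has no users, which is exactly the condition the new pass checks:

import torch

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.used = torch.nn.Linear(4, 4)
        self.unused = torch.nn.Linear(8, 8)  # never referenced in forward

    def forward(self, x):
        return self.used(x)

ep = torch.export.export(Toy(), (torch.randn(1, 4),), strict=False)

# Parameters appear as placeholder nodes; unused ones have zero users.
for node in ep.graph_module.graph.nodes:
    if node.op == "placeholder":
        print(node.target, "users:", len(node.users))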

exir/passes/TARGETS (+12)
@@ -21,6 +21,7 @@ python_library(
         ":quant_fusion_pass",
         ":quantize_io_pass",
         ":remove_noop_pass",
+        ":remove_unused_parameters_pass",
         ":replace_aten_with_edge_pass",
         ":replace_broken_ops_with_function_ops_pass",
         ":replace_edge_with_backend_pass",
@@ -386,3 +387,14 @@ python_library(
         "//executorch/exir/dialects:lib",
     ],
 )
+
+python_library(
+    name = "remove_unused_parameters_pass",
+    srcs = [
+        "remove_unused_parameters_pass.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/exir/dialects:lib",
+    ],
+)

exir/passes/__init__.py (+4)
@@ -45,6 +45,9 @@
 from executorch.exir.passes.prune_empty_tensors_pass import PruneEmptyTensorsPass
 from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
 from executorch.exir.passes.remove_noop_pass import RemoveNoopPass, RemoveToCopyPass
+from executorch.exir.passes.remove_unused_parameters_pass import (
+    remove_unused_parameters_pass,
+)
 from executorch.exir.passes.replace_aten_with_edge_pass import OpReplacePass
 from executorch.exir.passes.replace_broken_ops_with_function_ops_pass import (
     ReplaceBrokenOpsWithFunctionalOpsPass,
@@ -71,6 +74,7 @@
     "MemoryPlanningPass",
     "HintBasedSymShapeEvalPass",
     "insert_write_back_for_buffers_pass",
+    "remove_unused_parameters_pass",
     "weights_to_outputs_pass",
 ]

exir/passes/remove_unused_parameters_pass.py (new file, +54)
@@ -0,0 +1,54 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy

import torch

from torch.export.exported_program import ExportedProgram, InputKind


def remove_unused_parameters_pass(
    ep: ExportedProgram,
) -> ExportedProgram:
    """
    Remove unused parameters from the exported program.
    """

    placeholder_nodes = {
        node.target: node
        for node in ep.graph_module.graph.nodes
        if node.op == "placeholder"
    }

    unused_parameters = [
        s
        for s in ep.graph_signature.input_specs
        if s.kind == InputKind.PARAMETER
        and not _is_parameter_used(ep, s.arg.name, placeholder_nodes)
    ]

    # Remove params from the state dict, graph, and signature.
    new_signature = copy.deepcopy(ep.graph_signature)
    for param in unused_parameters:
        new_signature.input_specs.remove(param)
        del ep._state_dict[param.target]
        ep.graph_module.graph.erase_node(placeholder_nodes[param.arg.name])

    ep._graph_signature = new_signature
    ep.graph_module.recompile()
    return ep


def _is_parameter_used(
    ep: ExportedProgram, parameter: str, placeholder_nodes: dict[str, torch.fx.Node]
) -> bool:
    placeholder_node = placeholder_nodes.get(parameter)
    if placeholder_node is None:
        # Shouldn't happen, but in this case, leave the parameter to be safe.
        return True

    return len(placeholder_node.users) > 0
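Note that the pass mutates the ExportedProgram (state dict, signature, and graph) in place and also returns it. Although it runs automatically inside to_executorch, the export added to exir/passes/__init__.py above means it can also be invoked directly; a minimal sketch, assuming an ExportedProgram ep already in hand:

from executorch.exir.passes import remove_unused_parameters_pass

ep = remove_unused_parameters_pass(ep)  # ep is modified in place and returned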

exir/program/_program.py (+2)
@@ -41,6 +41,7 @@
     EdgeToBackendOpsPass,
     MemoryFormatOpsPass,
     OpReplacePass,
+    remove_unused_parameters_pass,
 )
 from executorch.exir.passes.external_constants_pass import (
     external_constants_pass,
@@ -1529,6 +1530,7 @@ def to_executorch(
         for name, program in self._edge_programs.items():
             program = weights_to_outputs_pass(program)
             program = unsafe_remove_auto_functionalized_pass(program)
+            program = remove_unused_parameters_pass(program)
             gm, new_signature = insert_write_back_for_buffers_pass(program)
             new_gm = program.graph_module
             for p in edge_to_executorch_passes(config, name):
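For callers, nothing about the API changes; a hedged sketch of the flow (model and example_inputs are placeholder names, and the calls are the same ones the tests below exercise) showing where the new pass now runs:

import torch
from executorch.exir import to_edge_transform_and_lower

ep = torch.export.export(model, example_inputs, strict=False)
# remove_unused_parameters_pass now runs automatically inside to_executorch().
pte = to_edge_transform_and_lower(ep).to_executorch()
with open("model.pte", "wb") as f:
    f.write(pte.buffer)  # serialized program, minus any unused parameters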

exir/tests/TARGETS (+16)
@@ -432,6 +432,22 @@ python_unittest(
     ],
 )
 
+python_unittest(
+    name = "test_remove_unused_parameters_pass",
+    srcs = [
+        "test_remove_unused_parameters_pass.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/backends/xnnpack:xnnpack_delegate",
+        "//executorch/exir:lib",
+        "//executorch/exir:memory",
+        "//executorch/exir/capture:config",
+        "//executorch/exir/passes:lib",
+        "//executorch/runtime:runtime",
+    ],
+)
+
 python_unittest(
     name = "test_remove_view_copy",
     srcs = [
exir/tests/test_remove_unused_parameters_pass.py (new file, +115)
@@ -0,0 +1,115 @@

import unittest
from typing import Sequence

import torch

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower
from executorch.exir.passes import remove_unused_parameters_pass
from executorch.runtime import Runtime
from torch.export import ExportedProgram


class TestRemoveUnusedParametersPass(unittest.TestCase):
    class ModelWithUnusedParameters(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = torch.nn.Linear(16, 16)
            self.unused_linear = torch.nn.Linear(1024, 1024)

        def forward(self, x):
            return self.linear1(x)

    def _test_pass(
        self,
        ep: ExportedProgram,
        unused_param_names_and_args: dict[str, str],
        example_inputs: Sequence[torch.Tensor],
        expected_outputs: torch.Tensor,
    ):
        # Verify EP state before running the pass.
        placeholders = set(
            n.target for n in ep.graph_module.graph.nodes if n.op == "placeholder"
        )
        for param_name, param_arg in unused_param_names_and_args.items():
            self.assertIn(param_name, ep.state_dict.keys())
            self.assertIn(param_name, ep.graph_signature.parameters)
            self.assertIn(param_arg, placeholders)

        new_ep = remove_unused_parameters_pass(ep)

        # Verify that the unused params are not in the state dict,
        # graph signature, or graph.
        new_placeholders = set(
            n.target for n in new_ep.graph_module.graph.nodes if n.op == "placeholder"
        )
        for param_name, param_arg in unused_param_names_and_args.items():
            self.assertNotIn(param_name, new_ep.state_dict.keys())
            self.assertNotIn(param_name, new_ep.graph_signature.parameters)
            self.assertNotIn(param_arg, new_placeholders)

        # Verify that the outputs are unchanged.
        new_outputs = new_ep.module()(*example_inputs)
        self.assertTrue(torch.allclose(new_outputs, expected_outputs))

    def test_remove_unused_parameters_simple(self):
        model = self.ModelWithUnusedParameters()
        model.eval()
        example_inputs = (torch.randn(1, 16),)
        eager_outputs = model(*example_inputs)
        ep = torch.export.export(model, example_inputs, strict=False)

        unused_param_names_and_args = {
            "unused_linear.weight": "p_unused_linear_weight",
            "unused_linear.bias": "p_unused_linear_bias",
        }

        self._test_pass(ep, unused_param_names_and_args, example_inputs, eager_outputs)

    def test_remove_unused_parameters_simple_edge_dialect(self):
        model = self.ModelWithUnusedParameters()
        model.eval()
        example_inputs = (torch.randn(1, 16),)
        eager_outputs = model(*example_inputs)

        unused_param_names_and_args = {
            "unused_linear.weight": "p_unused_linear_weight",
            "unused_linear.bias": "p_unused_linear_bias",
        }

        for delegated in [False, True]:
            lowered = to_edge_transform_and_lower(
                torch.export.export(model, example_inputs, strict=False),
                partitioner=[XnnpackPartitioner()] if delegated else [],
            )

            self._test_pass(
                lowered.exported_program(),
                unused_param_names_and_args,
                example_inputs,
                eager_outputs,
            )

    def test_remove_unused_parameters_serialized_e2e(self):
        model = self.ModelWithUnusedParameters()
        model.eval()
        example_inputs = (torch.randn(1, 16),)
        eager_outputs = model(*example_inputs)

        # Pass is expected to run as part of to_executorch().
        lowered = to_edge_transform_and_lower(
            torch.export.export(model, example_inputs, strict=False),
        ).to_executorch()

        # There are approximately 1M unused fp32 parameters - ~4Mb.
        # Without the unused params, the expected size is ~2.5Kb.
        self.assertLess(len(lowered.buffer), 10000)

        # Make sure we can load and run the serialized .pte.
        runtime = Runtime.get()
        program = runtime.load_program(lowered.buffer)
        method = program.load_method("forward")
        runtime_outputs = method.execute([*example_inputs])

        self.assertEqual(1, len(runtime_outputs))
        self.assertTrue(torch.allclose(runtime_outputs[0], eager_outputs))
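For reference on the size assertion above: the unused torch.nn.Linear(1024, 1024) holds 1024 × 1024 + 1024 = 1,049,600 fp32 parameters, roughly 4.2 MB at 4 bytes each, while the used Linear(16, 16) holds only 16 × 16 + 16 = 272. Dropping the unused layer is what brings the serialized .pte under the 10,000-byte bound.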
