Skip to content

Commit aff5086

Browse files
authored
Update ReplaceSqueezeAndUnsqueezeWithViewPass to use new pass interface
Differential Revision: D86785126 Pull Request resolved: #15757
1 parent 131d1f4 commit aff5086

File tree

2 files changed

+57
-24
lines changed

2 files changed

+57
-24
lines changed

backends/cadence/aot/replace_ops.py

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -193,39 +193,40 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
193193

194194

195195
@register_cadence_pass(CadencePassAttribute(opt_level=0))
196-
class ReplaceSqueezeAndUnsqueezeWithViewPass(ExportPass):
196+
class ReplaceSqueezeAndUnsqueezeWithViewPass(RemoveOrReplacePassInterface):
197197
"""
198198
When the shape is static, replace squeeze_copy and unsqueeze_copy ops with
199199
view_copy op
200200
"""
201201

202-
def call_operator(
203-
self,
204-
op,
205-
args: Tuple[Argument, ...],
206-
kwargs: Dict[str, Argument],
207-
meta: NodeMetadata,
208-
) -> ProxyValue:
209-
# Instead of testing EdgeOpOverload, test EdgeOpOverloadPacket,
210-
# which allows us to cover all overloads.
211-
if get_edge_overload_packet(op) not in {
212-
exir_ops.edge.aten.squeeze_copy,
213-
exir_ops.edge.aten.unsqueeze_copy,
214-
}:
215-
return super().call_operator(op, args, kwargs, meta)
202+
@property
203+
def targets(self) -> list[EdgeOpOverload]:
204+
return [
205+
exir_ops.edge.aten.squeeze_copy.default,
206+
exir_ops.edge.aten.squeeze_copy.dim,
207+
exir_ops.edge.aten.squeeze_copy.dims,
208+
exir_ops.edge.aten.unsqueeze_copy.default,
209+
]
210+
211+
def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
216212
# Get the output tensor shape
217-
out_shape = meta["val"].shape
213+
out_shape = node.meta["val"].shape
218214

219215
# Bail out if any dim is not an int (dynamic shape)
220216
for dim in list(out_shape):
221217
if not isinstance(dim, int):
222-
return super().call_operator(op, args, kwargs, meta)
218+
return False
223219

224-
# Return a view op with the new shape
225-
view_args = (args[0], list(out_shape))
226-
return super().call_operator(
227-
exir_ops.edge.aten.view_copy.default, view_args, kwargs, meta
228-
)
220+
# Replace with view op with the new shape
221+
with node.graph.inserting_before(node):
222+
new_node = node.graph.call_function(
223+
exir_ops.edge.aten.view_copy.default,
224+
args=(node.args[0], list(out_shape)),
225+
)
226+
# Do not remove the metadata copy!
227+
new_node.meta = node.meta
228+
node.replace_all_uses_with(new_node)
229+
return True
229230

230231

231232
@register_cadence_pass(CadencePassAttribute(opt_level=0))

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -972,7 +972,12 @@ def test_replace_squeeze_with_view(
972972
args=(x,),
973973
)
974974
p = ReplaceSqueezeAndUnsqueezeWithViewPass()
975-
graph_after_passes = cast(PassResult, p(original_gm)).graph_module
975+
result = cast(PassResult, p(original_gm))
976+
977+
# Assert: Verify the pass modified the graph
978+
self.assertTrue(result.modified)
979+
graph_after_passes = result.graph_module
980+
976981
self.assertIsNotNone(graph_after_passes)
977982
self.assertEqual(
978983
count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default),
@@ -1007,7 +1012,12 @@ def test_replace_unsqueeze_with_view(self, shape: Tuple[int], dim: int) -> None:
10071012
args=(x, dim),
10081013
)
10091014
p = ReplaceSqueezeAndUnsqueezeWithViewPass()
1010-
graph_after_passes = cast(PassResult, p(original_gm)).graph_module
1015+
result = cast(PassResult, p(original_gm))
1016+
1017+
# Assert: Verify the pass modified the graph
1018+
self.assertTrue(result.modified)
1019+
graph_after_passes = result.graph_module
1020+
10111021
self.assertIsNotNone(graph_after_passes)
10121022
self.assertEqual(
10131023
count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default),
@@ -1018,6 +1028,28 @@ def test_replace_unsqueeze_with_view(self, shape: Tuple[int], dim: int) -> None:
10181028
0,
10191029
)
10201030

1031+
@torch.no_grad()
1032+
def test_replace_squeeze_and_unsqueeze_with_view_no_modification(self) -> None:
1033+
"""Negative test: pass doesn't modify graphs without squeeze/unsqueeze ops."""
1034+
x = torch.randn(2, 3, 4)
1035+
original_gm = single_op_builder(
1036+
placeholders=(x,),
1037+
op=exir_ops.edge.aten.view_copy.default,
1038+
args=(x, [2, 12]),
1039+
)
1040+
p = ReplaceSqueezeAndUnsqueezeWithViewPass()
1041+
result = cast(PassResult, p(original_gm))
1042+
1043+
# Assert: Verify the pass did NOT modify the graph
1044+
self.assertFalse(result.modified)
1045+
graph_after_passes = result.graph_module
1046+
1047+
# Verify the original view_copy operation is still there
1048+
self.assertEqual(
1049+
count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default),
1050+
1,
1051+
)
1052+
10211053
@torch.no_grad()
10221054
def test_replace_conv1d_with_linear(self) -> None:
10231055
x = torch.randn(1, 96, 7)

0 commit comments

Comments
 (0)