Skip to content

Commit eaedd4d

Browse files
mcremon-meta and facebook-github-bot
authored and committed
Allow removing permute pairs in addition to transpose pairs (#10501)
Summary: Pull Request resolved: #10501. As titled. Gets us 27% better cycles on Activity Classification (at opt level 3). Can be improved further; the follow-up task is T222295719. Reviewed By: zonglinpeng. Differential Revision: D73619452.
1 parent 9cc9f82 commit eaedd4d

File tree

3 files changed

+120
-50
lines changed

3 files changed

+120
-50
lines changed

backends/cadence/aot/fuse_ops.py

+30-37
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import operator
1515
from collections import deque
1616
from numbers import Number
17-
from typing import cast, Sequence
17+
from typing import cast
1818

1919
# Import these for the cadence function signatures.
2020
import executorch.backends.cadence.aot.ops_registrations # noqa: F401
@@ -881,9 +881,9 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
881881

882882

883883
@register_cadence_pass(CadencePassAttribute(opt_level=1))
884-
class FuseTransposeOpPairsPass(FuseOpPairsAcrossBranchesPass):
884+
class FuseTransposeOrPermuteOpPairsPass(FuseOpPairsAcrossBranchesPass):
885885
"""
886-
Fuse transpose op pairs to a single view op.
886+
Fuse transpose or permute op pairs to a single view op.
887887
"""
888888

889889
# A list of ops that can be bypassed when looking for a
@@ -907,42 +907,28 @@ def can_fuse_for_chain(
907907
if not super().can_fuse_for_chain(producer, consumer, consumer_op_packets):
908908
return False
909909

910-
def get_dims(node: torch.fx.Node) -> tuple[int, int]:
911-
def canonicalize(dim: int) -> int:
912-
if dim < 0:
913-
dim += len(node.meta["val"].shape)
914-
return dim
915-
916-
return tuple(canonicalize(cast(int, d)) for d in node.args[1:3])
917-
918-
def is_equivalent(
919-
shape: Sequence[int],
920-
transpose0: tuple[int, int],
921-
transpose1: tuple[int, int],
922-
) -> bool:
923-
def permute_order(
924-
order: Sequence[int], dims: tuple[int, int]
925-
) -> Sequence[int]:
926-
new_order = list(order)
927-
new_order[dims[0]], new_order[dims[1]] = (
928-
new_order[dims[1]],
929-
new_order[dims[0]],
930-
)
931-
return new_order
910+
input_shape = list(cast(torch.fx.Node, producer.args[0]).meta["val"].shape)
932911

933-
order = permute_order(range(len(shape)), transpose0)
934-
order = permute_order(order, transpose1)
912+
intermediate_shape = (
913+
get_transposed_dims(producer, input_shape)
914+
if producer.target == exir_ops.edge.aten.transpose_copy.int
915+
else get_permuted_dims(producer, input_shape)
916+
)
935917

936-
non_unit_dims = [dim for dim in range(len(shape)) if shape[dim] != 1]
937-
non_unit_dims_permuted = [dim for dim in order if shape[dim] != 1]
918+
final_shape = (
919+
get_transposed_dims(consumer, intermediate_shape)
920+
if consumer.target == exir_ops.edge.aten.transpose_copy.int
921+
else get_permuted_dims(consumer, intermediate_shape)
922+
)
938923

939-
return non_unit_dims == non_unit_dims_permuted
924+
non_unit_dims = [
925+
input_shape[dim] for dim in range(len(input_shape)) if input_shape[dim] != 1
926+
]
927+
non_unit_dims_permuted = [
928+
final_shape[dim] for dim in range(len(final_shape)) if final_shape[dim] != 1
929+
]
940930

941-
return is_equivalent(
942-
cast(torch.fx.Node, producer.args[0]).meta["val"].shape,
943-
get_dims(producer),
944-
get_dims(consumer),
945-
)
931+
return non_unit_dims == non_unit_dims_permuted
946932

947933
def get_fused_node(
948934
self,
@@ -960,13 +946,20 @@ def get_fused_node(
960946
return view
961947

962948
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
963-
# Remove any dequantize op that has only quantize ops as its users.
949+
# Remove any transpose op pairs that cancel each other out.
964950
self.find_and_fuse(
965951
graph_module,
966952
producer_op_packets={exir_ops.edge.aten.transpose_copy},
967953
consumer_op_packets={exir_ops.edge.aten.transpose_copy},
968954
bypass_ops=self.bypass_ops,
969955
)
956+
# Remove any permute op pairs that cancel each other out.
957+
self.find_and_fuse(
958+
graph_module,
959+
producer_op_packets={exir_ops.edge.aten.permute_copy},
960+
consumer_op_packets={exir_ops.edge.aten.permute_copy},
961+
bypass_ops=self.bypass_ops,
962+
)
970963
result = super().call(graph_module)
971964
return result
972965

@@ -1028,5 +1021,5 @@ class CadenceFuseOpsInGraph:
10281021
FuseQuantDequantToRequantizePass,
10291022
FuseMulIntoDequantPass,
10301023
FuseFullThenReshapePass,
1031-
FuseTransposeOpPairsPass,
1024+
FuseTransposeOrPermuteOpPairsPass,
10321025
]

backends/cadence/aot/passes.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from executorch.backends.cadence.aot.fuse_ops import (
1515
CadenceFuseOpsInGraph,
1616
FuseFullThenReshapePass,
17-
FuseTransposeOpPairsPass,
17+
FuseTransposeOrPermuteOpPairsPass,
1818
)
1919
from executorch.backends.cadence.aot.pass_utils import (
2020
CadencePassAttribute,
@@ -83,7 +83,7 @@ def get_passes_in_default_order() -> List[ExportPass]:
8383
CadenceSimplifyOpsInGraph.passes,
8484
FinalizePipeline,
8585
FuseFullThenReshapePass,
86-
FuseTransposeOpPairsPass,
86+
FuseTransposeOrPermuteOpPairsPass,
8787
RemoveNopSliceOrViewOpPass,
8888
]
8989
return pytree.tree_flatten(passes)[0]

backends/cadence/aot/tests/test_fusion_ops_passes.py

+88-11
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
FuseFullThenReshapePass,
2121
FuseMulIntoDequantPass,
2222
FuseQuantDequantToRequantizePass,
23-
FuseTransposeOpPairsPass,
23+
FuseTransposeOrPermuteOpPairsPass,
2424
)
2525
from executorch.backends.cadence.aot.graph_builder import GraphBuilder
2626
from executorch.backends.cadence.aot.pass_utils import count_node, op_counts_match
@@ -510,7 +510,7 @@ def test_fuse_then_transpose_pass(self):
510510
)
511511

512512

513-
class TestFuseTransposeOpPairsPass(TestFusionPassesBase):
513+
class TestFuseTransposeOrPermuteOpPairsPass(TestFusionPassesBase):
514514
def _create_operator(
515515
self, builder: GraphBuilder, op: torch._ops.OpOverload, x: ProxyValue
516516
) -> ProxyValue:
@@ -536,17 +536,17 @@ def _create_operator(
536536
def test_fuse_transpose_pairs(self, op: torch._ops.OpOverload):
537537
# Create a graph with transpose -> quant -> transpose.
538538
builder = GraphBuilder()
539-
x = builder.placeholder("x", torch.randn(2, 3))
540-
transpose_node = builder.call_operator(
539+
x = builder.placeholder("x", torch.randn(2, 3, 4))
540+
transpose_node0 = builder.call_operator(
541541
op=exir_ops.edge.aten.transpose_copy.int,
542542
args=(x, 0, 1),
543543
)
544-
quant_node = self._create_operator(builder, op, transpose_node)
545-
transpose_node = builder.call_operator(
544+
quant_node = self._create_operator(builder, op, transpose_node0)
545+
transpose_node1 = builder.call_operator(
546546
op=exir_ops.edge.aten.transpose_copy.int,
547-
args=(quant_node, 0, 1),
547+
args=(quant_node, 1, 2),
548548
)
549-
builder.output([transpose_node])
549+
builder.output([transpose_node1])
550550
gm = builder.get_graph_module()
551551
self.check_op_counts(
552552
gm,
@@ -557,7 +557,7 @@ def test_fuse_transpose_pairs(self, op: torch._ops.OpOverload):
557557
)
558558

559559
# Check that the pass fuses the two transpose ops.
560-
fusion_pass_result = FuseTransposeOpPairsPass()(gm)
560+
fusion_pass_result = FuseTransposeOrPermuteOpPairsPass()(gm)
561561
self.assertIsNotNone(fusion_pass_result)
562562
gm_after_pass = fusion_pass_result.graph_module
563563
self.check_op_counts(
@@ -568,6 +568,47 @@ def test_fuse_transpose_pairs(self, op: torch._ops.OpOverload):
568568
},
569569
)
570570

571+
@parameterized.expand(
572+
[
573+
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
574+
exir_ops.edge.cadence.quantized_relu.per_tensor,
575+
],
576+
)
577+
def test_fuse_permute_pairs(self, op: torch._ops.OpOverload):
578+
# Create a graph with permute -> quant -> permute.
579+
builder = GraphBuilder()
580+
x = builder.placeholder("x", torch.randn(8, 2, 3, 4))
581+
permute_node0 = builder.call_operator(
582+
op=exir_ops.edge.aten.permute_copy.default,
583+
args=(x, [0, 3, 1, 2]),
584+
)
585+
quant_node = self._create_operator(builder, op, permute_node0)
586+
permute_node1 = builder.call_operator(
587+
op=exir_ops.edge.aten.permute_copy.default,
588+
args=(quant_node, [0, 2, 3, 1]),
589+
)
590+
builder.output([permute_node1])
591+
gm = builder.get_graph_module()
592+
self.check_op_counts(
593+
gm,
594+
expected_op_counts={
595+
exir_ops.edge.aten.permute_copy.default: 2,
596+
op: 1,
597+
},
598+
)
599+
600+
# Check that the pass fuses the two permute ops.
601+
fusion_pass_result = FuseTransposeOrPermuteOpPairsPass()(gm)
602+
self.assertIsNotNone(fusion_pass_result)
603+
gm_after_pass = fusion_pass_result.graph_module
604+
self.check_op_counts(
605+
gm_after_pass,
606+
expected_op_counts={
607+
exir_ops.edge.aten.permute_copy.default: 0,
608+
op: 1,
609+
},
610+
)
611+
571612
def test_no_fusion_for_transpose_pairs(self):
572613
# Create a graph with transpose -> quant -> transpose.
573614
builder = GraphBuilder()
@@ -595,7 +636,7 @@ def test_no_fusion_for_transpose_pairs(self):
595636
)
596637

597638
# No fusion.
598-
gm_after_pass = FuseTransposeOpPairsPass()(gm).graph_module
639+
gm_after_pass = FuseTransposeOrPermuteOpPairsPass()(gm).graph_module
599640
self.check_op_counts(
600641
gm_after_pass,
601642
expected_op_counts={
@@ -604,6 +645,42 @@ def test_no_fusion_for_transpose_pairs(self):
604645
},
605646
)
606647

648+
def test_no_fusion_for_permute_pairs(self):
649+
# Create a graph with permute -> quant -> permute.
650+
builder = GraphBuilder()
651+
x = builder.placeholder("x", torch.randn(2, 3, 4))
652+
permute_node = builder.call_operator(
653+
op=exir_ops.edge.aten.permute_copy.default,
654+
args=(x, [2, 0, 1]),
655+
)
656+
quant_node = builder.call_operator(
657+
op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
658+
args=(permute_node, 1.2, 3, 0, 127, torch.int8),
659+
)
660+
permute_node = builder.call_operator(
661+
op=exir_ops.edge.aten.permute_copy.default,
662+
args=(quant_node, [2, 0, 1]),
663+
)
664+
builder.output(permute_node)
665+
gm = builder.get_graph_module()
666+
self.check_op_counts(
667+
gm,
668+
expected_op_counts={
669+
exir_ops.edge.aten.permute_copy.default: 2,
670+
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1,
671+
},
672+
)
673+
674+
# No fusion.
675+
gm_after_pass = FuseTransposeOrPermuteOpPairsPass()(gm).graph_module
676+
self.check_op_counts(
677+
gm_after_pass,
678+
expected_op_counts={
679+
exir_ops.edge.aten.permute_copy.default: 2,
680+
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1,
681+
},
682+
)
683+
607684
def test_fusion_for_forked_transposes(self):
608685
# Create a graph with transpose -> quant -> transpose.
609686
builder = GraphBuilder()
@@ -636,7 +713,7 @@ def test_fusion_for_forked_transposes(self):
636713
)
637714

638715
# Fuse all the transpose ops.
639-
gm_after_pass = FuseTransposeOpPairsPass()(gm).graph_module
716+
gm_after_pass = FuseTransposeOrPermuteOpPairsPass()(gm).graph_module
640717
self.check_op_counts(
641718
gm_after_pass,
642719
expected_op_counts={

0 commit comments

Comments
 (0)