20
20
FuseFullThenReshapePass ,
21
21
FuseMulIntoDequantPass ,
22
22
FuseQuantDequantToRequantizePass ,
23
- FuseTransposeOpPairsPass ,
23
+ FuseTransposeOrPermuteOpPairsPass ,
24
24
)
25
25
from executorch .backends .cadence .aot .graph_builder import GraphBuilder
26
26
from executorch .backends .cadence .aot .pass_utils import count_node , op_counts_match
@@ -510,7 +510,7 @@ def test_fuse_then_transpose_pass(self):
510
510
)
511
511
512
512
513
- class TestFuseTransposeOpPairsPass (TestFusionPassesBase ):
513
+ class TestFuseTransposeOrPermuteOpPairsPass (TestFusionPassesBase ):
514
514
def _create_operator (
515
515
self , builder : GraphBuilder , op : torch ._ops .OpOverload , x : ProxyValue
516
516
) -> ProxyValue :
@@ -536,17 +536,17 @@ def _create_operator(
536
536
def test_fuse_transpose_pairs (self , op : torch ._ops .OpOverload ):
537
537
# Create a graph with transpose -> quant -> transpose.
538
538
builder = GraphBuilder ()
539
- x = builder .placeholder ("x" , torch .randn (2 , 3 ))
540
- transpose_node = builder .call_operator (
539
+ x = builder .placeholder ("x" , torch .randn (2 , 3 , 4 ))
540
+ transpose_node0 = builder .call_operator (
541
541
op = exir_ops .edge .aten .transpose_copy .int ,
542
542
args = (x , 0 , 1 ),
543
543
)
544
- quant_node = self ._create_operator (builder , op , transpose_node )
545
- transpose_node = builder .call_operator (
544
+ quant_node = self ._create_operator (builder , op , transpose_node0 )
545
+ transpose_node1 = builder .call_operator (
546
546
op = exir_ops .edge .aten .transpose_copy .int ,
547
- args = (quant_node , 0 , 1 ),
547
+ args = (quant_node , 1 , 2 ),
548
548
)
549
- builder .output ([transpose_node ])
549
+ builder .output ([transpose_node1 ])
550
550
gm = builder .get_graph_module ()
551
551
self .check_op_counts (
552
552
gm ,
@@ -557,7 +557,7 @@ def test_fuse_transpose_pairs(self, op: torch._ops.OpOverload):
557
557
)
558
558
559
559
# Check that the pass fuses the two transpose ops.
560
- fusion_pass_result = FuseTransposeOpPairsPass ()(gm )
560
+ fusion_pass_result = FuseTransposeOrPermuteOpPairsPass ()(gm )
561
561
self .assertIsNotNone (fusion_pass_result )
562
562
gm_after_pass = fusion_pass_result .graph_module
563
563
self .check_op_counts (
@@ -568,6 +568,47 @@ def test_fuse_transpose_pairs(self, op: torch._ops.OpOverload):
568
568
},
569
569
)
570
570
571
+ @parameterized .expand (
572
+ [
573
+ exir_ops .edge .quantized_decomposed .quantize_per_tensor .default ,
574
+ exir_ops .edge .cadence .quantized_relu .per_tensor ,
575
+ ],
576
+ )
577
+ def test_fuse_permute_pairs (self , op : torch ._ops .OpOverload ):
578
+ # Create a graph with permute -> quant -> permute.
579
+ builder = GraphBuilder ()
580
+ x = builder .placeholder ("x" , torch .randn (8 , 2 , 3 , 4 ))
581
+ permute_node0 = builder .call_operator (
582
+ op = exir_ops .edge .aten .permute_copy .default ,
583
+ args = (x , [0 , 3 , 1 , 2 ]),
584
+ )
585
+ quant_node = self ._create_operator (builder , op , permute_node0 )
586
+ permute_node1 = builder .call_operator (
587
+ op = exir_ops .edge .aten .permute_copy .default ,
588
+ args = (quant_node , [0 , 2 , 3 , 1 ]),
589
+ )
590
+ builder .output ([permute_node1 ])
591
+ gm = builder .get_graph_module ()
592
+ self .check_op_counts (
593
+ gm ,
594
+ expected_op_counts = {
595
+ exir_ops .edge .aten .permute_copy .default : 2 ,
596
+ op : 1 ,
597
+ },
598
+ )
599
+
600
+ # Check that the pass fuses the two transpose ops.
601
+ fusion_pass_result = FuseTransposeOrPermuteOpPairsPass ()(gm )
602
+ self .assertIsNotNone (fusion_pass_result )
603
+ gm_after_pass = fusion_pass_result .graph_module
604
+ self .check_op_counts (
605
+ gm_after_pass ,
606
+ expected_op_counts = {
607
+ exir_ops .edge .aten .permute_copy .default : 0 ,
608
+ op : 1 ,
609
+ },
610
+ )
611
+
571
612
def test_no_fusion_for_transpose_pairs (self ):
572
613
# Create a graph with transpose -> quant -> transpose.
573
614
builder = GraphBuilder ()
@@ -595,7 +636,7 @@ def test_no_fusion_for_transpose_pairs(self):
595
636
)
596
637
597
638
# No fusion.
598
- gm_after_pass = FuseTransposeOpPairsPass ()(gm ).graph_module
639
+ gm_after_pass = FuseTransposeOrPermuteOpPairsPass ()(gm ).graph_module
599
640
self .check_op_counts (
600
641
gm_after_pass ,
601
642
expected_op_counts = {
@@ -604,6 +645,42 @@ def test_no_fusion_for_transpose_pairs(self):
604
645
},
605
646
)
606
647
648
+ def test_no_fusion_for_permute_pairs (self ):
649
+ # Create a graph with permute -> quant -> permute.
650
+ builder = GraphBuilder ()
651
+ x = builder .placeholder ("x" , torch .randn (2 , 3 , 4 ))
652
+ permute_node = builder .call_operator (
653
+ op = exir_ops .edge .aten .permute_copy .default ,
654
+ args = (x , [2 , 0 , 1 ]),
655
+ )
656
+ quant_node = builder .call_operator (
657
+ op = exir_ops .edge .quantized_decomposed .quantize_per_tensor .default ,
658
+ args = (permute_node , 1.2 , 3 , 0 , 127 , torch .int8 ),
659
+ )
660
+ permute_node = builder .call_operator (
661
+ op = exir_ops .edge .aten .permute_copy .default ,
662
+ args = (quant_node , [2 , 0 , 1 ]),
663
+ )
664
+ builder .output (permute_node )
665
+ gm = builder .get_graph_module ()
666
+ self .check_op_counts (
667
+ gm ,
668
+ expected_op_counts = {
669
+ exir_ops .edge .aten .permute_copy .default : 2 ,
670
+ exir_ops .edge .quantized_decomposed .quantize_per_tensor .default : 1 ,
671
+ },
672
+ )
673
+
674
+ # No fusion.
675
+ gm_after_pass = FuseTransposeOrPermuteOpPairsPass ()(gm ).graph_module
676
+ self .check_op_counts (
677
+ gm_after_pass ,
678
+ expected_op_counts = {
679
+ exir_ops .edge .aten .permute_copy .default : 2 ,
680
+ exir_ops .edge .quantized_decomposed .quantize_per_tensor .default : 1 ,
681
+ },
682
+ )
683
+
607
684
def test_fusion_for_forked_transposes (self ):
608
685
# Create a graph with transpose -> quant -> transpose.
609
686
builder = GraphBuilder ()
@@ -636,7 +713,7 @@ def test_fusion_for_forked_transposes(self):
636
713
)
637
714
638
715
# Fuse the all the transpose ops.
639
- gm_after_pass = FuseTransposeOpPairsPass ()(gm ).graph_module
716
+ gm_after_pass = FuseTransposeOrPermuteOpPairsPass ()(gm ).graph_module
640
717
self .check_op_counts (
641
718
gm_after_pass ,
642
719
expected_op_counts = {
0 commit comments