1- // RUN: enzymexlamlir-opt --partial-symmetry-simplify %s | FileCheck %s
1+ // RUN: enzymexlamlir-opt --partial-symmetry-annotate --enzyme-hlo-generate-td="patterns=transpose_partial_symmetry_simplify" --transform-interpreter --enzyme-hlo-remove-transform %s | FileCheck %s
22
3- func.func @test1 () -> tensor <2 x2 xf32 > {
3+ func.func @test_constant () -> tensor <2 x2 xf32 > {
44 %cst = stablehlo.constant dense <[[1.0 , 2.0 ], [2.0 , 3.0 ]]> : tensor <2 x2 xf32 >
5- return %cst : tensor <2 x2 xf32 >
5+ %0 = stablehlo.transpose %cst , dims = [1 , 0 ] : (tensor <2 x2 xf32 >) -> tensor <2 x2 xf32 >
6+ return %0 : tensor <2 x2 xf32 >
67}
7- // CHECK: func.func @test1 () -> tensor<2x2xf32> {
8+ // CHECK: func.func @test_constant () -> tensor<2x2xf32> {
89// CHECK-NEXT: %cst = stablehlo.constant {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]} dense<{{.*}}> : tensor<2x2xf32>
910// CHECK-NEXT: return %cst : tensor<2x2xf32>
1011// CHECK-NEXT: }
1112
12- func.func @test2 () -> tensor <2 x2 x2 x3 xf32 > {
13+ func.func @test_propagate () -> tensor <2 x2 x2 x3 xf32 > {
1314 %cst0 = stablehlo.constant dense <[[[1.0 , 2.0 ], [3.0 , 4.0 ]], [[3.0 , 4.0 ], [5.0 , 6.0 ]]]> : tensor <2 x2 x2 xf32 >
1415 %cst1 = stablehlo.constant dense <[[[1.0 , 2.0 ], [2.0 , 3.0 ]], [[2.0 , 3.0 ], [3.0 , 4.0 ]]]> : tensor <2 x2 x2 xf32 >
1516 %0 = stablehlo.add %cst0 , %cst1 : tensor <2 x2 x2 xf32 >
1617 %1 = stablehlo.transpose %0 , dims = [0 , 2 , 1 ] : (tensor <2 x2 x2 xf32 >) -> tensor <2 x2 x2 xf32 >
1718 %2 = stablehlo.broadcast_in_dim %1 , dims = [1 , 0 , 2 ] : (tensor <2 x2 x2 xf32 >) -> tensor <2 x2 x2 x3 xf32 >
18- return %2 : tensor <2 x2 x2 x3 xf32 >
19+ %3 = stablehlo.transpose %2 , dims = [0 , 2 , 1 , 3 ] : (tensor <2 x2 x2 x3 xf32 >) -> tensor <2 x2 x2 x3 xf32 >
20+ return %3 : tensor <2 x2 x2 x3 xf32 >
1921}
20- // CHECK: func.func @test2 () -> tensor<2x2x2x3xf32> {
22+ // CHECK: func.func @test_propagate () -> tensor<2x2x2x3xf32> {
2123// CHECK-NEXT: %cst = stablehlo.constant {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]} dense<{{.*}}> : tensor<2x2x2xf32>
2224// CHECK-NEXT: %cst_0 = stablehlo.constant {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1, 2]>>]} dense<{{.*}}> : tensor<2x2x2xf32>
2325// CHECK-NEXT: %0 = stablehlo.add %cst, %cst_0 {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]} : tensor<2x2x2xf32>
@@ -26,36 +28,36 @@ func.func @test2() -> tensor<2x2x2x3xf32> {
2628// CHECK-NEXT: return %2 : tensor<2x2x2x3xf32>
2729// CHECK-NEXT: }
2830
29- func.func @test3 (%arg0: tensor <3 x2 x3 xf32 >) -> tensor <3 x2 x3 xf32 > {
31+ func.func @test_add_generate_symmetry (%arg0: tensor <3 x2 x3 xf32 >) -> tensor <3 x2 x3 xf32 > {
3032 %0 = stablehlo.transpose %arg0 , dims = [2 , 1 , 0 ] : (tensor <3 x2 x3 xf32 >) -> tensor <3 x2 x3 xf32 >
3133 %1 = stablehlo.add %0 , %arg0 : tensor <3 x2 x3 xf32 >
3234 return %1 : tensor <3 x2 x3 xf32 >
3335}
34- // CHECK: func.func @test3 (%arg0: tensor<3x2x3xf32>) -> tensor<3x2x3xf32> {
36+ // CHECK: func.func @test_add_generate_symmetry (%arg0: tensor<3x2x3xf32>) -> tensor<3x2x3xf32> {
3537// CHECK-NEXT: %0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<3x2x3xf32>) -> tensor<3x2x3xf32>
3638// CHECK-NEXT: %1 = stablehlo.add %0, %arg0 {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 2]>>]} : tensor<3x2x3xf32>
3739// CHECK-NEXT: return %1 : tensor<3x2x3xf32>
3840// CHECK-NEXT: }
3941
40- func.func @test4 () -> tensor <2 x2 xf32 > {
42+ func.func @test_dot_propagate () -> tensor <2 x2 xf32 > {
4143 %cst0 = stablehlo.constant dense <[[[1.0 , 2.0 , 3.0 ], [2.0 , 3.0 , 4.0 ]], [[2.0 , 3.0 , 4.0 ], [3.0 , 4.0 , 5.0 ]]]> : tensor <2 x2 x3 xf32 >
4244 %cst1 = stablehlo.constant dense <[[[1.0 , 2.0 ], [2.0 , 3.0 ]], [[2.0 , 3.0 ], [3.0 , 4.0 ]], [[2.0 , 3.0 ], [3.0 , 4.0 ]]]> : tensor <3 x2 x2 xf32 >
4345 %0 = stablehlo.dot_general %cst0 , %cst1 , batching_dims = [0 , 1 ] x [1 , 2 ], contracting_dims = [2 ] x [0 ] : (tensor <2 x2 x3 xf32 >, tensor <3 x2 x2 xf32 >) -> tensor <2 x2 xf32 >
4446 return %0 : tensor <2 x2 xf32 >
4547}
46- // CHECK: func.func @test4 () -> tensor<2x2xf32> {
48+ // CHECK: func.func @test_dot_propagate () -> tensor<2x2xf32> {
4749// CHECK-NEXT: %cst = stablehlo.constant {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]} dense<{{.*}}> : tensor<2x2x3xf32>
4850// CHECK-NEXT: %cst_0 = stablehlo.constant {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[1, 2]>>]} dense<{{.*}}> : tensor<3x2x2xf32>
4951// CHECK-NEXT: %0 = stablehlo.dot_general %cst, %cst_0, batching_dims = [0, 1] x [1, 2], contracting_dims = [2] x [0] {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]} : (tensor<2x2x3xf32>, tensor<3x2x2xf32>) -> tensor<2x2xf32>
5052// CHECK-NEXT: return %0 : tensor<2x2xf32>
5153// CHECK-NEXT: }
5254
53- func.func @test5 (%arg0: tensor <3 x3 x3 xf32 >) -> tensor <3 x3 x3 xf32 > {
55+ func.func @test_dot_generate_symmetry (%arg0: tensor <3 x3 x3 xf32 >) -> tensor <3 x3 x3 xf32 > {
5456 %0 = stablehlo.transpose %arg0 , dims = [2 , 1 , 0 ] : (tensor <3 x3 x3 xf32 >) -> tensor <3 x3 x3 xf32 >
5557 %1 = stablehlo.dot_general %arg0 , %0 , batching_dims = [1 ] x [1 ], contracting_dims = [0 ] x [2 ] : (tensor <3 x3 x3 xf32 >, tensor <3 x3 x3 xf32 >) -> tensor <3 x3 x3 xf32 >
5658 return %1 : tensor <3 x3 x3 xf32 >
5759}
58- // CHECK: func.func @test5 (%arg0: tensor<3x3x3xf32>) -> tensor<3x3x3xf32> {
60+ // CHECK: func.func @test_dot_generate_symmetry (%arg0: tensor<3x3x3xf32>) -> tensor<3x3x3xf32> {
5961// CHECK-NEXT: %0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<3x3x3xf32>) -> tensor<3x3x3xf32>
6062// CHECK-NEXT: %1 = stablehlo.dot_general %arg0, %0, batching_dims = [1] x [1], contracting_dims = [0] x [2] {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[1, 2]>>]} : (tensor<3x3x3xf32>, tensor<3x3x3xf32>) -> tensor<3x3x3xf32>
6163// CHECK-NEXT: return %1 : tensor<3x3x3xf32>
0 commit comments