Skip to content

Commit b95389e

Browse files
committed
Add n-dim transpose removal opt
1 parent 9744194 commit b95389e

File tree

7 files changed

+130
-23
lines changed

7 files changed

+130
-23
lines changed

src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,25 @@ PartialSymmetryAnnotation::getDimensionSets() const {
386386
return result;
387387
}
388388

389+
/// Builds an annotation for a tensor of rank `rank` from a list of symmetric
/// dimension sets.
///
/// `dimensionSets` is not assumed to be a complete partition: every dimension
/// that is not mentioned stays in its own singleton set. Sets that share a
/// dimension are merged (swaps inside two overlapping sets generate every
/// permutation of their union), so the result does not depend on the order in
/// which overlapping sets are listed. Indices outside [0, rank) are ignored
/// instead of being written through, which keeps the result conservative for
/// malformed input.
PartialSymmetryAnnotation PartialSymmetryAnnotation::fromDimensionSets(
    int64_t rank, ArrayRef<ArrayRef<int64_t>> dimensionSets) {
  // Start with every dimension in its own set, identified by its own index.
  SmallVector<int64_t> dimensionSetIDs(rank);
  for (int64_t i = 0; i < rank; ++i)
    dimensionSetIDs[i] = i;

  for (ArrayRef<int64_t> dims : dimensionSets) {
    // ID shared by the members of this set; taken from the first in-range
    // dimension. IDs are never negative, so -1 means "not chosen yet".
    int64_t representative = -1;
    for (int64_t dim : dims) {
      if (dim < 0 || dim >= rank)
        continue;

      if (representative < 0) {
        representative = dimensionSetIDs[dim];
        continue;
      }

      const int64_t previous = dimensionSetIDs[dim];
      if (previous == representative)
        continue;

      // Relabel the whole set `dim` currently belongs to, not only `dim`
      // itself, so that earlier merges involving `dim` are preserved.
      for (int64_t &id : dimensionSetIDs) {
        if (id == previous)
          id = representative;
      }
    }
  }

  return PartialSymmetryAnnotation(dimensionSetIDs);
}
407+
389408
void PartialSymmetryAnnotation::print(raw_ostream &os) const {
390409
auto dimensionSets = getDimensionSets();
391410
os << "{";

src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ class PartialSymmetryAnnotation {
6868

6969
SmallVector<SmallVector<int64_t>> getDimensionSets() const;
7070

71+
static PartialSymmetryAnnotation
72+
fromDimensionSets(int64_t rank, ArrayRef<ArrayRef<int64_t>> dimensionSets);
73+
7174
void print(raw_ostream &os) const;
7275

7376
private:

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include "src/enzyme_ad/jax/Implementations/WhileLoopInfo.h"
3434
#include "src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h"
3535
#include "src/enzyme_ad/jax/Passes/Passes.h"
36+
#include "src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h"
3637
#include "src/enzyme_ad/jax/Passes/StructuredTensors.h"
3738
#include "src/enzyme_ad/jax/Utils.h"
3839
#include "stablehlo/dialect/Base.h"
@@ -55,6 +56,7 @@
5556
#include "llvm/ADT/MapVector.h"
5657
#include <iterator>
5758
#include <numeric>
59+
#include <optional>
5860
#define DEBUG_TYPE "enzymehloopt"
5961

6062
namespace mlir {
@@ -6953,6 +6955,82 @@ struct TransposeSymmetricSimplify
69536955
}
69546956
};
69556957

6958+
static std::optional<enzyme::PartialSymmetryAnnotation>
6959+
getPartialSymmetryFromAttr(Value val) {
6960+
auto op = val.getDefiningOp();
6961+
if (!op)
6962+
return std::nullopt;
6963+
6964+
auto arrayAttr =
6965+
op->getAttrOfType<ArrayAttr>("enzymexla.partial_symmetry");
6966+
if (!arrayAttr || arrayAttr.empty())
6967+
return std::nullopt;
6968+
6969+
// Get the result number for this value
6970+
auto opResult = dyn_cast<OpResult>(val);
6971+
if (!opResult)
6972+
return std::nullopt;
6973+
6974+
auto resultNumber = opResult.getResultNumber();
6975+
if (resultNumber >= arrayAttr.size())
6976+
return std::nullopt;
6977+
6978+
auto partialSymmetryAttr = dyn_cast<enzymexla::PartialSymmetryAnalysisResultAttr>(
6979+
arrayAttr[resultNumber]);
6980+
if (!partialSymmetryAttr)
6981+
return std::nullopt;
6982+
6983+
auto dimensionSetAttrs = partialSymmetryAttr.getValues();
6984+
auto rank = cast<RankedTensorType>(val.getType()).getRank();
6985+
6986+
SmallVector<ArrayRef<int64_t>> dimensionSets;
6987+
for (auto dimensionSetAttr : dimensionSetAttrs) {
6988+
auto dims = dimensionSetAttr.getDimensions().asArrayRef();
6989+
dimensionSets.push_back(dims);
6990+
}
6991+
6992+
return enzyme::PartialSymmetryAnnotation::fromDimensionSets(rank, dimensionSets);
6993+
}
6994+
6995+
struct TransposePartialSymmetrySimplify
6996+
: public CheckedOpRewritePattern<stablehlo::TransposeOp,
6997+
TransposePartialSymmetrySimplify> {
6998+
using CheckedOpRewritePattern<
6999+
stablehlo::TransposeOp,
7000+
TransposePartialSymmetrySimplify>::CheckedOpRewritePattern;
7001+
7002+
LogicalResult matchAndRewriteImpl(stablehlo::TransposeOp op,
7003+
PatternRewriter &rewriter) const {
7004+
auto operand = op.getOperand();
7005+
auto perm = op.getPermutation();
7006+
7007+
// Get partial symmetry annotation from the operand
7008+
auto annotationOpt = getPartialSymmetryFromAttr(operand);
7009+
if (!annotationOpt.has_value())
7010+
return failure();
7011+
7012+
auto annotation = annotationOpt.value();
7013+
7014+
// Check if the transpose is an identity based on partial symmetry
7015+
// A transpose is identity if permuting dimensions doesn't change which
7016+
// dimensions are in the same symmetric set
7017+
bool isIdentity = true;
7018+
for (int64_t i = 0; i < (int64_t)perm.size(); ++i) {
7019+
if (annotation.getSetId(i) != annotation.getSetId(perm[i])) {
7020+
isIdentity = false;
7021+
break;
7022+
}
7023+
}
7024+
7025+
if (isIdentity) {
7026+
rewriter.replaceOp(op, operand);
7027+
return success();
7028+
}
7029+
7030+
return failure();
7031+
}
7032+
};
7033+
69567034
struct NoNanSelfSubSimplify
69577035
: public NoNanCheckedOpRewritePattern<stablehlo::SubtractOp,
69587036
NoNanSelfSubSimplify> {
@@ -26641,7 +26719,8 @@ struct EnzymeHLOOptPass
2664126719
NoNanAddSubSimplify, NoNanMulSimplify, NoNanDivSimplify>(
2664226720
(no_nan || all_finite), context);
2664326721

26644-
patterns.add<TransposeSymmetricSimplify>(context);
26722+
patterns.add<TransposeSymmetricSimplify, TransposePartialSymmetrySimplify>(
26723+
context);
2664526724
patterns.add<FactorScalarsInDotGeneral>(context);
2664626725

2664726726
// syrk patterns

src/enzyme_ad/jax/Passes/PartialSymmetrySimplify.cpp renamed to src/enzyme_ad/jax/Passes/PartialSymmetryAnnotate.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@
1717
#include "src/enzyme_ad/jax/Dialect/Ops.h"
1818
#include "stablehlo/dialect/StablehloOps.h"
1919

20-
#define DEBUG_TYPE "partial-symmetry-simplify"
20+
#define DEBUG_TYPE "partial-symmetry-annotate"
2121

2222
namespace mlir {
2323
namespace enzyme {
24-
#define GEN_PASS_DEF_PARTIALSYMMETRYSIMPLIFYPASS
24+
#define GEN_PASS_DEF_PARTIALSYMMETRYANNOTATEPASS
2525
#include "src/enzyme_ad/jax/Passes/Passes.h.inc"
2626
} // namespace enzyme
2727
} // namespace mlir
@@ -32,9 +32,9 @@ using namespace mlir::enzyme;
3232

3333
namespace {
3434

35-
class PartialSymmetrySimplifyPass
36-
: public enzyme::impl::PartialSymmetrySimplifyPassBase<
37-
PartialSymmetrySimplifyPass> {
35+
class PartialSymmetryAnnotatePass
36+
: public enzyme::impl::PartialSymmetryAnnotatePassBase<
37+
PartialSymmetryAnnotatePass> {
3838
public:
3939
using Base::Base;
4040

@@ -51,6 +51,7 @@ class PartialSymmetrySimplifyPass
5151

5252
auto mod = getOperation();
5353

54+
// Annotate all operations with partial symmetry information
5455
mod->walk([&](Operation *op) {
5556
SmallVector<Attribute> partialSymmetryAttrs;
5657
bool anyKnown = false;
@@ -92,8 +93,6 @@ class PartialSymmetrySimplifyPass
9293

9394
return WalkResult::advance();
9495
});
95-
96-
// TODO: do things here
9796
}
9897
};
9998

src/enzyme_ad/jax/Passes/Passes.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1077,8 +1077,8 @@ def EnzymeBatchToStableHLOPass : Pass<"enzyme-batch-to-stablehlo"> {
10771077
];
10781078
}
10791079

1080-
def PartialSymmetrySimplifyPass : Pass<"partial-symmetry-simplify", "ModuleOp"> {
1081-
let summary = "Simplify operations using partial symmetry analysis";
1080+
def PartialSymmetryAnnotatePass : Pass<"partial-symmetry-annotate", "ModuleOp"> {
1081+
let summary = "Annotate operations using partial symmetry analysis";
10821082
let dependentDialects = [
10831083
"stablehlo::StablehloDialect",
10841084
"enzymexla::EnzymeXLADialect",

src/enzyme_ad/jax/TransformOps/TransformOps.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -608,6 +608,11 @@ def ApplyTransposeSymmetricSimplify : EnzymeHLOPatternOp<
608608
let patterns = ["TransposeSymmetricSimplify"];
609609
}
610610

611+
// Pattern op named "transpose_partial_symmetry_simplify"; its pattern list
// contains the single rewrite pattern TransposePartialSymmetrySimplify.
def ApplyTransposePartialSymmetrySimplify : EnzymeHLOPatternOp<
612+
"transpose_partial_symmetry_simplify"> {
613+
let patterns = ["TransposePartialSymmetrySimplify"];
614+
}
615+
611616
def ApplyFactorScalarsInDotGeneral : EnzymeHLOPatternOp<
612617
"factor_scalars_in_dot_general"> {
613618
let patterns = ["FactorScalarsInDotGeneral"];

test/lit_tests/structured_tensors/partial_symmetry.mlir

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,25 @@
1-
// RUN: enzymexlamlir-opt --partial-symmetry-simplify %s | FileCheck %s
1+
// RUN: enzymexlamlir-opt --partial-symmetry-annotate --enzyme-hlo-generate-td="patterns=transpose_partial_symmetry_simplify" --transform-interpreter --enzyme-hlo-remove-transform %s | FileCheck %s
22

3-
func.func @test1() -> tensor<2x2xf32> {
3+
func.func @test_constant() -> tensor<2x2xf32> {
44
%cst = stablehlo.constant dense<[[1.0, 2.0], [2.0, 3.0]]> : tensor<2x2xf32>
5-
return %cst : tensor<2x2xf32>
5+
%0 = stablehlo.transpose %cst, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32>
6+
return %0 : tensor<2x2xf32>
67
}
7-
// CHECK: func.func @test1() -> tensor<2x2xf32> {
8+
// CHECK: func.func @test_constant() -> tensor<2x2xf32> {
89
// CHECK-NEXT: %cst = stablehlo.constant {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]} dense<{{.*}}> : tensor<2x2xf32>
910
// CHECK-NEXT: return %cst : tensor<2x2xf32>
1011
// CHECK-NEXT: }
1112

12-
func.func @test2() -> tensor<2x2x2x3xf32> {
13+
func.func @test_propagate() -> tensor<2x2x2x3xf32> {
1314
%cst0 = stablehlo.constant dense<[[[1.0, 2.0], [3.0, 4.0]], [[3.0, 4.0], [5.0, 6.0]]]> : tensor<2x2x2xf32>
1415
%cst1 = stablehlo.constant dense<[[[1.0, 2.0], [2.0, 3.0]], [[2.0, 3.0], [3.0, 4.0]]]> : tensor<2x2x2xf32>
1516
%0 = stablehlo.add %cst0, %cst1 : tensor<2x2x2xf32>
1617
%1 = stablehlo.transpose %0, dims = [0, 2, 1] : (tensor<2x2x2xf32>) -> tensor<2x2x2xf32>
1718
%2 = stablehlo.broadcast_in_dim %1, dims = [1, 0, 2] : (tensor<2x2x2xf32>) -> tensor<2x2x2x3xf32>
18-
return %2 : tensor<2x2x2x3xf32>
19+
%3 = stablehlo.transpose %2, dims = [0, 2, 1, 3] : (tensor<2x2x2x3xf32>) -> tensor<2x2x2x3xf32>
20+
return %3 : tensor<2x2x2x3xf32>
1921
}
20-
// CHECK: func.func @test2() -> tensor<2x2x2x3xf32> {
22+
// CHECK: func.func @test_propagate() -> tensor<2x2x2x3xf32> {
2123
// CHECK-NEXT: %cst = stablehlo.constant {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]} dense<{{.*}}> : tensor<2x2x2xf32>
2224
// CHECK-NEXT: %cst_0 = stablehlo.constant {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1, 2]>>]} dense<{{.*}}> : tensor<2x2x2xf32>
2325
// CHECK-NEXT: %0 = stablehlo.add %cst, %cst_0 {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]} : tensor<2x2x2xf32>
@@ -26,36 +28,36 @@ func.func @test2() -> tensor<2x2x2x3xf32> {
2628
// CHECK-NEXT: return %2 : tensor<2x2x2x3xf32>
2729
// CHECK-NEXT: }
2830

29-
func.func @test3(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x3xf32> {
31+
func.func @test_add_generate_symmetry(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x3xf32> {
3032
%0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<3x2x3xf32>) -> tensor<3x2x3xf32>
3133
%1 = stablehlo.add %0, %arg0 : tensor<3x2x3xf32>
3234
return %1 : tensor<3x2x3xf32>
3335
}
34-
// CHECK: func.func @test3(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x3xf32> {
36+
// CHECK: func.func @test_add_generate_symmetry(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x3xf32> {
3537
// CHECK-NEXT: %0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<3x2x3xf32>) -> tensor<3x2x3xf32>
3638
// CHECK-NEXT: %1 = stablehlo.add %0, %arg0 {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 2]>>]} : tensor<3x2x3xf32>
3739
// CHECK-NEXT: return %1 : tensor<3x2x3xf32>
3840
// CHECK-NEXT: }
3941

40-
func.func @test4() -> tensor<2x2xf32> {
42+
func.func @test_dot_propagate() -> tensor<2x2xf32> {
4143
%cst0 = stablehlo.constant dense<[[[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]], [[2.0, 3.0, 4.0], [3.0, 4.0, 5.0]]]> : tensor<2x2x3xf32>
4244
%cst1 = stablehlo.constant dense<[[[1.0, 2.0], [2.0, 3.0]], [[2.0, 3.0], [3.0, 4.0]], [[2.0, 3.0], [3.0, 4.0]]]> : tensor<3x2x2xf32>
4345
%0 = stablehlo.dot_general %cst0, %cst1, batching_dims = [0, 1] x [1, 2], contracting_dims = [2] x [0] : (tensor<2x2x3xf32>, tensor<3x2x2xf32>) -> tensor<2x2xf32>
4446
return %0 : tensor<2x2xf32>
4547
}
46-
// CHECK: func.func @test4() -> tensor<2x2xf32> {
48+
// CHECK: func.func @test_dot_propagate() -> tensor<2x2xf32> {
4749
// CHECK-NEXT: %cst = stablehlo.constant {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]} dense<{{.*}}> : tensor<2x2x3xf32>
4850
// CHECK-NEXT: %cst_0 = stablehlo.constant {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[1, 2]>>]} dense<{{.*}}> : tensor<3x2x2xf32>
4951
// CHECK-NEXT: %0 = stablehlo.dot_general %cst, %cst_0, batching_dims = [0, 1] x [1, 2], contracting_dims = [2] x [0] {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]} : (tensor<2x2x3xf32>, tensor<3x2x2xf32>) -> tensor<2x2xf32>
5052
// CHECK-NEXT: return %0 : tensor<2x2xf32>
5153
// CHECK-NEXT: }
5254

53-
func.func @test5(%arg0: tensor<3x3x3xf32>) -> tensor<3x3x3xf32> {
55+
func.func @test_dot_generate_symmetry(%arg0: tensor<3x3x3xf32>) -> tensor<3x3x3xf32> {
5456
%0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<3x3x3xf32>) -> tensor<3x3x3xf32>
5557
%1 = stablehlo.dot_general %arg0, %0, batching_dims = [1] x [1], contracting_dims = [0] x [2] : (tensor<3x3x3xf32>, tensor<3x3x3xf32>) -> tensor<3x3x3xf32>
5658
return %1 : tensor<3x3x3xf32>
5759
}
58-
// CHECK: func.func @test5(%arg0: tensor<3x3x3xf32>) -> tensor<3x3x3xf32> {
60+
// CHECK: func.func @test_dot_generate_symmetry(%arg0: tensor<3x3x3xf32>) -> tensor<3x3x3xf32> {
5961
// CHECK-NEXT: %0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<3x3x3xf32>) -> tensor<3x3x3xf32>
6062
// CHECK-NEXT: %1 = stablehlo.dot_general %arg0, %0, batching_dims = [1] x [1], contracting_dims = [0] x [2] {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[1, 2]>>]} : (tensor<3x3x3xf32>, tensor<3x3x3xf32>) -> tensor<3x3x3xf32>
6163
// CHECK-NEXT: return %1 : tensor<3x3x3xf32>

0 commit comments

Comments
 (0)