Skip to content

Commit 4a68de3

Browse files
committed
Recognize existing partial symmetry annotations in IR
1 parent b95389e commit 4a68de3

File tree

4 files changed

+104
-60
lines changed

4 files changed

+104
-60
lines changed

src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp

Lines changed: 94 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
66
#include "mlir/IR/Matchers.h"
77
#include "mlir/IR/PatternMatch.h"
8+
#include "mlir/Interfaces/FunctionInterfaces.h"
9+
#include "src/enzyme_ad/jax/Dialect/Dialect.h"
810
#include "stablehlo/dialect/StablehloOps.h"
911
#include "llvm/ADT/DenseMap.h"
1012

@@ -387,7 +389,7 @@ PartialSymmetryAnnotation::getDimensionSets() const {
387389
}
388390

389391
PartialSymmetryAnnotation
390-
PartialSymmetryAnnotation::fromDimensionSets(int64_t rank,
392+
PartialSymmetryAnnotation::createFromDimensionSets(int64_t rank,
391393
ArrayRef<ArrayRef<int64_t>> dimensionSets) {
392394
SmallVector<int64_t> dimensionSetIDs(rank);
393395
for (int64_t i = 0; i < rank; ++i) {
@@ -405,6 +407,78 @@ PartialSymmetryAnnotation::fromDimensionSets(int64_t rank,
405407
return PartialSymmetryAnnotation(dimensionSetIDs);
406408
}
407409

410+
std::optional<PartialSymmetryAnnotation>
411+
PartialSymmetryAnnotation::createFromIR(Value val) {
412+
auto blockArg = dyn_cast<BlockArgument>(val);
413+
if (blockArg) {
414+
auto op = blockArg.getOwner()->getParentOp();
415+
auto funcOpInterface = dyn_cast<FunctionOpInterface>(op);
416+
if (!funcOpInterface) {
417+
return std::nullopt;
418+
}
419+
420+
auto argAttrs = funcOpInterface.getArgAttrs(blockArg.getArgNumber());
421+
for (auto attr : argAttrs) {
422+
if (attr.getName() == "enzymexla.partial_symmetry") {
423+
auto arrayAttr = dyn_cast<ArrayAttr>(attr.getValue());
424+
if (!arrayAttr || arrayAttr.empty()) {
425+
continue;
426+
}
427+
428+
auto partialSymmetryAttr = dyn_cast<enzymexla::PartialSymmetryAnalysisResultAttr>(
429+
arrayAttr[0]);
430+
431+
if (!partialSymmetryAttr) {
432+
continue;
433+
}
434+
435+
auto dimensionSetAttrs = partialSymmetryAttr.getValues();
436+
auto rank = cast<RankedTensorType>(val.getType()).getRank();
437+
438+
SmallVector<ArrayRef<int64_t>> dimensionSets;
439+
for (auto dimensionSetAttr : dimensionSetAttrs) {
440+
auto dims = dimensionSetAttr.getDimensions().asArrayRef();
441+
dimensionSets.push_back(dims);
442+
}
443+
444+
return PartialSymmetryAnnotation::createFromDimensionSets(rank, dimensionSets);
445+
}
446+
}
447+
return std::nullopt;
448+
}
449+
450+
auto op = val.getDefiningOp();
451+
if (!op)
452+
return std::nullopt;
453+
454+
auto arrayAttr =
455+
op->getAttrOfType<ArrayAttr>("enzymexla.partial_symmetry");
456+
if (!arrayAttr || arrayAttr.empty())
457+
return std::nullopt;
458+
459+
auto opResult = dyn_cast<OpResult>(val);
460+
if (!opResult)
461+
return std::nullopt;
462+
463+
auto resultNumber = opResult.getResultNumber();
464+
465+
auto partialSymmetryAttr = dyn_cast<enzymexla::PartialSymmetryAnalysisResultAttr>(
466+
arrayAttr[resultNumber]);
467+
if (!partialSymmetryAttr)
468+
return std::nullopt;
469+
470+
auto dimensionSetAttrs = partialSymmetryAttr.getValues();
471+
auto rank = cast<RankedTensorType>(val.getType()).getRank();
472+
473+
SmallVector<ArrayRef<int64_t>> dimensionSets;
474+
for (auto dimensionSetAttr : dimensionSetAttrs) {
475+
auto dims = dimensionSetAttr.getDimensions().asArrayRef();
476+
dimensionSets.push_back(dims);
477+
}
478+
479+
return PartialSymmetryAnnotation::createFromDimensionSets(rank, dimensionSets);
480+
}
481+
408482
void PartialSymmetryAnnotation::print(raw_ostream &os) const {
409483
auto dimensionSets = getDimensionSets();
410484
os << "{";
@@ -430,6 +504,19 @@ void PartialSymmetryAnnotation::print(raw_ostream &os) const {
430504
// PartialSymmetryLattice Implementation
431505
//===----------------------------------------------------------------------===//
432506

507+
PartialSymmetryLattice::PartialSymmetryLattice(Value v) : AbstractSparseLattice(v) {
508+
if (auto type = dyn_cast<RankedTensorType>(v.getType())) {
509+
// Trust existing IR annotations if present.
510+
auto annotation = PartialSymmetryAnnotation::createFromIR(v);
511+
if (annotation.has_value()) {
512+
value = annotation.value();
513+
return;
514+
}
515+
516+
value = PartialSymmetryAnnotation::createFullySymmetric(type.getRank());
517+
}
518+
}
519+
433520
ChangeResult PartialSymmetryLattice::join(const AbstractSparseLattice &rhs) {
434521
const auto *rhsStruct =
435522
reinterpret_cast<const PartialSymmetryLattice *>(&rhs);
@@ -452,6 +539,12 @@ void PartialSymmetryLattice::print(raw_ostream &os) const { value.print(os); }
452539
//===----------------------------------------------------------------------===//
453540

454541
void PartialSymmetryAnalysis::setToEntryState(PartialSymmetryLattice *lattice) {
542+
auto annotation = PartialSymmetryAnnotation::createFromIR(lattice->getAnchor());
543+
if (annotation.has_value()) {
544+
lattice->setValue(annotation.value());
545+
return;
546+
}
547+
455548
lattice->setValue(PartialSymmetryAnnotation::createNotSymmetric(
456549
lattice->getValue().getRank()));
457550
}

src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,7 @@ class PartialSymmetryAnnotation {
2626
static PartialSymmetryAnnotation createFullySymmetric(int64_t rank);
2727

2828
bool isSymmetric(int64_t i, int64_t j) const;
29-
3029
int64_t getSetId(int64_t i) const { return dimensionSetIDs[i]; }
31-
3230
int64_t getRank() const { return dimensionSetIDs.size(); }
3331

3432
static PartialSymmetryAnnotation join(const PartialSymmetryAnnotation &lhs,
@@ -67,9 +65,8 @@ class PartialSymmetryAnnotation {
6765
}
6866

6967
SmallVector<SmallVector<int64_t>> getDimensionSets() const;
70-
71-
static PartialSymmetryAnnotation
72-
fromDimensionSets(int64_t rank, ArrayRef<ArrayRef<int64_t>> dimensionSets);
68+
static PartialSymmetryAnnotation createFromDimensionSets(int64_t rank, ArrayRef<ArrayRef<int64_t>> dimensionSets);
69+
static std::optional<PartialSymmetryAnnotation> createFromIR(Value val);
7370

7471
void print(raw_ostream &os) const;
7572

@@ -84,11 +81,7 @@ class PartialSymmetryLattice : public dataflow::AbstractSparseLattice {
8481
public:
8582
using AbstractSparseLattice::AbstractSparseLattice;
8683

87-
PartialSymmetryLattice(Value v) : AbstractSparseLattice(v) {
88-
if (auto type = dyn_cast<RankedTensorType>(v.getType())) {
89-
value = PartialSymmetryAnnotation::createFullySymmetric(type.getRank());
90-
}
91-
}
84+
PartialSymmetryLattice(Value v);
9285

9386
ChangeResult join(const AbstractSparseLattice &rhs) override;
9487
ChangeResult join(const PartialSymmetryLattice &rhs);

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 2 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "mlir/Dialect/Func/IR/FuncOps.h"
2020
#include "mlir/Dialect/Tensor/IR/Tensor.h"
2121
#include "mlir/IR/Builders.h"
22+
#include "mlir/Interfaces/FunctionInterfaces.h"
2223
#include "mlir/IR/Dominance.h"
2324
#include "mlir/IR/IRMapping.h"
2425
#include "mlir/IR/Matchers.h"
@@ -6955,43 +6956,6 @@ struct TransposeSymmetricSimplify
69556956
}
69566957
};
69576958

6958-
static std::optional<enzyme::PartialSymmetryAnnotation>
6959-
getPartialSymmetryFromAttr(Value val) {
6960-
auto op = val.getDefiningOp();
6961-
if (!op)
6962-
return std::nullopt;
6963-
6964-
auto arrayAttr =
6965-
op->getAttrOfType<ArrayAttr>("enzymexla.partial_symmetry");
6966-
if (!arrayAttr || arrayAttr.empty())
6967-
return std::nullopt;
6968-
6969-
// Get the result number for this value
6970-
auto opResult = dyn_cast<OpResult>(val);
6971-
if (!opResult)
6972-
return std::nullopt;
6973-
6974-
auto resultNumber = opResult.getResultNumber();
6975-
if (resultNumber >= arrayAttr.size())
6976-
return std::nullopt;
6977-
6978-
auto partialSymmetryAttr = dyn_cast<enzymexla::PartialSymmetryAnalysisResultAttr>(
6979-
arrayAttr[resultNumber]);
6980-
if (!partialSymmetryAttr)
6981-
return std::nullopt;
6982-
6983-
auto dimensionSetAttrs = partialSymmetryAttr.getValues();
6984-
auto rank = cast<RankedTensorType>(val.getType()).getRank();
6985-
6986-
SmallVector<ArrayRef<int64_t>> dimensionSets;
6987-
for (auto dimensionSetAttr : dimensionSetAttrs) {
6988-
auto dims = dimensionSetAttr.getDimensions().asArrayRef();
6989-
dimensionSets.push_back(dims);
6990-
}
6991-
6992-
return enzyme::PartialSymmetryAnnotation::fromDimensionSets(rank, dimensionSets);
6993-
}
6994-
69956959
struct TransposePartialSymmetrySimplify
69966960
: public CheckedOpRewritePattern<stablehlo::TransposeOp,
69976961
TransposePartialSymmetrySimplify> {
@@ -7004,16 +6968,12 @@ struct TransposePartialSymmetrySimplify
70046968
auto operand = op.getOperand();
70056969
auto perm = op.getPermutation();
70066970

7007-
// Get partial symmetry annotation from the operand
7008-
auto annotationOpt = getPartialSymmetryFromAttr(operand);
6971+
auto annotationOpt = enzyme::PartialSymmetryAnnotation::createFromIR(operand);
70096972
if (!annotationOpt.has_value())
70106973
return failure();
70116974

70126975
auto annotation = annotationOpt.value();
70136976

7014-
// Check if the transpose is an identity based on partial symmetry
7015-
// A transpose is identity if permuting dimensions doesn't change which
7016-
// dimensions are in the same symmetric set
70176977
bool isIdentity = true;
70186978
for (int64_t i = 0; i < (int64_t)perm.size(); ++i) {
70196979
if (annotation.getSetId(i) != annotation.getSetId(perm[i])) {

test/lit_tests/structured_tensors/partial_symmetry.mlir

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,14 @@ func.func @test_add_generate_symmetry(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x3x
3939
// CHECK-NEXT: return %1 : tensor<3x2x3xf32>
4040
// CHECK-NEXT: }
4141

42-
func.func @test_dot_propagate() -> tensor<2x2xf32> {
43-
%cst0 = stablehlo.constant dense<[[[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]], [[2.0, 3.0, 4.0], [3.0, 4.0, 5.0]]]> : tensor<2x2x3xf32>
42+
func.func @test_dot_propagate(%arg0: tensor<2x2x3xf32> {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]}) -> tensor<2x2xf32> {
4443
%cst1 = stablehlo.constant dense<[[[1.0, 2.0], [2.0, 3.0]], [[2.0, 3.0], [3.0, 4.0]], [[2.0, 3.0], [3.0, 4.0]]]> : tensor<3x2x2xf32>
45-
%0 = stablehlo.dot_general %cst0, %cst1, batching_dims = [0, 1] x [1, 2], contracting_dims = [2] x [0] : (tensor<2x2x3xf32>, tensor<3x2x2xf32>) -> tensor<2x2xf32>
44+
%0 = stablehlo.dot_general %arg0, %cst1, batching_dims = [0, 1] x [1, 2], contracting_dims = [2] x [0] : (tensor<2x2x3xf32>, tensor<3x2x2xf32>) -> tensor<2x2xf32>
4645
return %0 : tensor<2x2xf32>
4746
}
48-
// CHECK: func.func @test_dot_propagate() -> tensor<2x2xf32> {
49-
// CHECK-NEXT: %cst = stablehlo.constant {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]} dense<{{.*}}> : tensor<2x2x3xf32>
50-
// CHECK-NEXT: %cst_0 = stablehlo.constant {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[1, 2]>>]} dense<{{.*}}> : tensor<3x2x2xf32>
51-
// CHECK-NEXT: %0 = stablehlo.dot_general %cst, %cst_0, batching_dims = [0, 1] x [1, 2], contracting_dims = [2] x [0] {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]} : (tensor<2x2x3xf32>, tensor<3x2x2xf32>) -> tensor<2x2xf32>
47+
// CHECK: func.func @test_dot_propagate(%arg0: tensor<2x2x3xf32> {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]}) -> tensor<2x2xf32> {
48+
// CHECK-NEXT: %cst = stablehlo.constant {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[1, 2]>>]} dense<{{.*}}> : tensor<3x2x2xf32>
49+
// CHECK-NEXT: %0 = stablehlo.dot_general %arg0, %cst, batching_dims = [0, 1] x [1, 2], contracting_dims = [2] x [0] {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]} : (tensor<2x2x3xf32>, tensor<3x2x2xf32>) -> tensor<2x2xf32>
5250
// CHECK-NEXT: return %0 : tensor<2x2xf32>
5351
// CHECK-NEXT: }
5452

0 commit comments

Comments
 (0)