Skip to content

Commit d164c18

Browse files
committed
Format
1 parent 4a68de3 commit d164c18

File tree

3 files changed

+31
-24
lines changed

3 files changed

+31
-24
lines changed

src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -388,9 +388,8 @@ PartialSymmetryAnnotation::getDimensionSets() const {
388388
return result;
389389
}
390390

391-
PartialSymmetryAnnotation
392-
PartialSymmetryAnnotation::createFromDimensionSets(int64_t rank,
393-
ArrayRef<ArrayRef<int64_t>> dimensionSets) {
391+
PartialSymmetryAnnotation PartialSymmetryAnnotation::createFromDimensionSets(
392+
int64_t rank, ArrayRef<ArrayRef<int64_t>> dimensionSets) {
394393
SmallVector<int64_t> dimensionSetIDs(rank);
395394
for (int64_t i = 0; i < rank; ++i) {
396395
dimensionSetIDs[i] = i;
@@ -424,14 +423,15 @@ PartialSymmetryAnnotation::createFromIR(Value val) {
424423
if (!arrayAttr || arrayAttr.empty()) {
425424
continue;
426425
}
427-
428-
auto partialSymmetryAttr = dyn_cast<enzymexla::PartialSymmetryAnalysisResultAttr>(
429-
arrayAttr[0]);
430-
426+
427+
auto partialSymmetryAttr =
428+
dyn_cast<enzymexla::PartialSymmetryAnalysisResultAttr>(
429+
arrayAttr[0]);
430+
431431
if (!partialSymmetryAttr) {
432432
continue;
433433
}
434-
434+
435435
auto dimensionSetAttrs = partialSymmetryAttr.getValues();
436436
auto rank = cast<RankedTensorType>(val.getType()).getRank();
437437

@@ -441,7 +441,8 @@ PartialSymmetryAnnotation::createFromIR(Value val) {
441441
dimensionSets.push_back(dims);
442442
}
443443

444-
return PartialSymmetryAnnotation::createFromDimensionSets(rank, dimensionSets);
444+
return PartialSymmetryAnnotation::createFromDimensionSets(
445+
rank, dimensionSets);
445446
}
446447
}
447448
return std::nullopt;
@@ -451,8 +452,7 @@ PartialSymmetryAnnotation::createFromIR(Value val) {
451452
if (!op)
452453
return std::nullopt;
453454

454-
auto arrayAttr =
455-
op->getAttrOfType<ArrayAttr>("enzymexla.partial_symmetry");
455+
auto arrayAttr = op->getAttrOfType<ArrayAttr>("enzymexla.partial_symmetry");
456456
if (!arrayAttr || arrayAttr.empty())
457457
return std::nullopt;
458458

@@ -462,8 +462,9 @@ PartialSymmetryAnnotation::createFromIR(Value val) {
462462

463463
auto resultNumber = opResult.getResultNumber();
464464

465-
auto partialSymmetryAttr = dyn_cast<enzymexla::PartialSymmetryAnalysisResultAttr>(
466-
arrayAttr[resultNumber]);
465+
auto partialSymmetryAttr =
466+
dyn_cast<enzymexla::PartialSymmetryAnalysisResultAttr>(
467+
arrayAttr[resultNumber]);
467468
if (!partialSymmetryAttr)
468469
return std::nullopt;
469470

@@ -476,7 +477,8 @@ PartialSymmetryAnnotation::createFromIR(Value val) {
476477
dimensionSets.push_back(dims);
477478
}
478479

479-
return PartialSymmetryAnnotation::createFromDimensionSets(rank, dimensionSets);
480+
return PartialSymmetryAnnotation::createFromDimensionSets(rank,
481+
dimensionSets);
480482
}
481483

482484
void PartialSymmetryAnnotation::print(raw_ostream &os) const {
@@ -504,15 +506,16 @@ void PartialSymmetryAnnotation::print(raw_ostream &os) const {
504506
// PartialSymmetryLattice Implementation
505507
//===----------------------------------------------------------------------===//
506508

507-
PartialSymmetryLattice::PartialSymmetryLattice(Value v) : AbstractSparseLattice(v) {
509+
PartialSymmetryLattice::PartialSymmetryLattice(Value v)
510+
: AbstractSparseLattice(v) {
508511
if (auto type = dyn_cast<RankedTensorType>(v.getType())) {
509512
// Trust existing IR annotations if present.
510513
auto annotation = PartialSymmetryAnnotation::createFromIR(v);
511514
if (annotation.has_value()) {
512515
value = annotation.value();
513516
return;
514517
}
515-
518+
516519
value = PartialSymmetryAnnotation::createFullySymmetric(type.getRank());
517520
}
518521
}
@@ -539,12 +542,13 @@ void PartialSymmetryLattice::print(raw_ostream &os) const { value.print(os); }
539542
//===----------------------------------------------------------------------===//
540543

541544
void PartialSymmetryAnalysis::setToEntryState(PartialSymmetryLattice *lattice) {
542-
auto annotation = PartialSymmetryAnnotation::createFromIR(lattice->getAnchor());
545+
auto annotation =
546+
PartialSymmetryAnnotation::createFromIR(lattice->getAnchor());
543547
if (annotation.has_value()) {
544548
lattice->setValue(annotation.value());
545549
return;
546550
}
547-
551+
548552
lattice->setValue(PartialSymmetryAnnotation::createNotSymmetric(
549553
lattice->getValue().getRank()));
550554
}

src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,9 @@ class PartialSymmetryAnnotation {
6565
}
6666

6767
SmallVector<SmallVector<int64_t>> getDimensionSets() const;
68-
static PartialSymmetryAnnotation createFromDimensionSets(int64_t rank, ArrayRef<ArrayRef<int64_t>> dimensionSets);
68+
static PartialSymmetryAnnotation
69+
createFromDimensionSets(int64_t rank,
70+
ArrayRef<ArrayRef<int64_t>> dimensionSets);
6971
static std::optional<PartialSymmetryAnnotation> createFromIR(Value val);
7072

7173
void print(raw_ostream &os) const;

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,22 +19,22 @@
1919
#include "mlir/Dialect/Func/IR/FuncOps.h"
2020
#include "mlir/Dialect/Tensor/IR/Tensor.h"
2121
#include "mlir/IR/Builders.h"
22-
#include "mlir/Interfaces/FunctionInterfaces.h"
2322
#include "mlir/IR/Dominance.h"
2423
#include "mlir/IR/IRMapping.h"
2524
#include "mlir/IR/Matchers.h"
2625
#include "mlir/IR/PatternMatch.h"
2726
#include "mlir/IR/Visitors.h"
27+
#include "mlir/Interfaces/FunctionInterfaces.h"
2828
#include "mlir/Pass/PassManager.h"
2929
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
3030
#include "shardy/dialect/sdy/ir/utils.h"
31+
#include "src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h"
3132
#include "src/enzyme_ad/jax/CheckedRewrite.h"
3233
#include "src/enzyme_ad/jax/Dialect/Dialect.h"
3334
#include "src/enzyme_ad/jax/Dialect/Ops.h"
3435
#include "src/enzyme_ad/jax/Implementations/WhileLoopInfo.h"
3536
#include "src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h"
3637
#include "src/enzyme_ad/jax/Passes/Passes.h"
37-
#include "src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h"
3838
#include "src/enzyme_ad/jax/Passes/StructuredTensors.h"
3939
#include "src/enzyme_ad/jax/Utils.h"
4040
#include "stablehlo/dialect/Base.h"
@@ -6964,14 +6964,15 @@ struct TransposePartialSymmetrySimplify
69646964
TransposePartialSymmetrySimplify>::CheckedOpRewritePattern;
69656965

69666966
LogicalResult matchAndRewriteImpl(stablehlo::TransposeOp op,
6967-
PatternRewriter &rewriter) const {
6967+
PatternRewriter &rewriter) const {
69686968
auto operand = op.getOperand();
69696969
auto perm = op.getPermutation();
69706970

6971-
auto annotationOpt = enzyme::PartialSymmetryAnnotation::createFromIR(operand);
6971+
auto annotationOpt =
6972+
enzyme::PartialSymmetryAnnotation::createFromIR(operand);
69726973
if (!annotationOpt.has_value())
69736974
return failure();
6974-
6975+
69756976
auto annotation = annotationOpt.value();
69766977

69776978
bool isIdentity = true;

0 commit comments

Comments
 (0)