@@ -388,9 +388,8 @@ PartialSymmetryAnnotation::getDimensionSets() const {
388388 return result;
389389}
390390
391- PartialSymmetryAnnotation
392- PartialSymmetryAnnotation::createFromDimensionSets (int64_t rank,
393- ArrayRef<ArrayRef<int64_t >> dimensionSets) {
391+ PartialSymmetryAnnotation PartialSymmetryAnnotation::createFromDimensionSets (
392+ int64_t rank, ArrayRef<ArrayRef<int64_t >> dimensionSets) {
394393 SmallVector<int64_t > dimensionSetIDs (rank);
395394 for (int64_t i = 0 ; i < rank; ++i) {
396395 dimensionSetIDs[i] = i;
@@ -424,14 +423,15 @@ PartialSymmetryAnnotation::createFromIR(Value val) {
424423 if (!arrayAttr || arrayAttr.empty ()) {
425424 continue ;
426425 }
427-
428- auto partialSymmetryAttr = dyn_cast<enzymexla::PartialSymmetryAnalysisResultAttr>(
429- arrayAttr[0 ]);
430-
426+
427+ auto partialSymmetryAttr =
428+ dyn_cast<enzymexla::PartialSymmetryAnalysisResultAttr>(
429+ arrayAttr[0 ]);
430+
431431 if (!partialSymmetryAttr) {
432432 continue ;
433433 }
434-
434+
435435 auto dimensionSetAttrs = partialSymmetryAttr.getValues ();
436436 auto rank = cast<RankedTensorType>(val.getType ()).getRank ();
437437
@@ -441,7 +441,8 @@ PartialSymmetryAnnotation::createFromIR(Value val) {
441441 dimensionSets.push_back (dims);
442442 }
443443
444- return PartialSymmetryAnnotation::createFromDimensionSets (rank, dimensionSets);
444+ return PartialSymmetryAnnotation::createFromDimensionSets (
445+ rank, dimensionSets);
445446 }
446447 }
447448 return std::nullopt ;
@@ -451,8 +452,7 @@ PartialSymmetryAnnotation::createFromIR(Value val) {
451452 if (!op)
452453 return std::nullopt ;
453454
454- auto arrayAttr =
455- op->getAttrOfType <ArrayAttr>(" enzymexla.partial_symmetry" );
455+ auto arrayAttr = op->getAttrOfType <ArrayAttr>(" enzymexla.partial_symmetry" );
456456 if (!arrayAttr || arrayAttr.empty ())
457457 return std::nullopt ;
458458
@@ -462,8 +462,9 @@ PartialSymmetryAnnotation::createFromIR(Value val) {
462462
463463 auto resultNumber = opResult.getResultNumber ();
464464
465- auto partialSymmetryAttr = dyn_cast<enzymexla::PartialSymmetryAnalysisResultAttr>(
466- arrayAttr[resultNumber]);
465+ auto partialSymmetryAttr =
466+ dyn_cast<enzymexla::PartialSymmetryAnalysisResultAttr>(
467+ arrayAttr[resultNumber]);
467468 if (!partialSymmetryAttr)
468469 return std::nullopt ;
469470
@@ -476,7 +477,8 @@ PartialSymmetryAnnotation::createFromIR(Value val) {
476477 dimensionSets.push_back (dims);
477478 }
478479
479- return PartialSymmetryAnnotation::createFromDimensionSets (rank, dimensionSets);
480+ return PartialSymmetryAnnotation::createFromDimensionSets (rank,
481+ dimensionSets);
480482}
481483
482484void PartialSymmetryAnnotation::print (raw_ostream &os) const {
@@ -504,15 +506,16 @@ void PartialSymmetryAnnotation::print(raw_ostream &os) const {
504506// PartialSymmetryLattice Implementation
505507// ===----------------------------------------------------------------------===//
506508
507- PartialSymmetryLattice::PartialSymmetryLattice (Value v) : AbstractSparseLattice(v) {
509+ PartialSymmetryLattice::PartialSymmetryLattice (Value v)
510+ : AbstractSparseLattice(v) {
508511 if (auto type = dyn_cast<RankedTensorType>(v.getType ())) {
509512 // Trust existing IR annotations if present.
510513 auto annotation = PartialSymmetryAnnotation::createFromIR (v);
511514 if (annotation.has_value ()) {
512515 value = annotation.value ();
513516 return ;
514517 }
515-
518+
516519 value = PartialSymmetryAnnotation::createFullySymmetric (type.getRank ());
517520 }
518521}
@@ -539,12 +542,13 @@ void PartialSymmetryLattice::print(raw_ostream &os) const { value.print(os); }
539542// ===----------------------------------------------------------------------===//
540543
541544void PartialSymmetryAnalysis::setToEntryState (PartialSymmetryLattice *lattice) {
542- auto annotation = PartialSymmetryAnnotation::createFromIR (lattice->getAnchor ());
545+ auto annotation =
546+ PartialSymmetryAnnotation::createFromIR (lattice->getAnchor ());
543547 if (annotation.has_value ()) {
544548 lattice->setValue (annotation.value ());
545549 return ;
546550 }
547-
551+
548552 lattice->setValue (PartialSymmetryAnnotation::createNotSymmetric (
549553 lattice->getValue ().getRank ()));
550554}
0 commit comments