55#include " mlir/Analysis/DataFlow/SparseAnalysis.h"
66#include " mlir/IR/Matchers.h"
77#include " mlir/IR/PatternMatch.h"
8+ #include " mlir/Interfaces/FunctionInterfaces.h"
9+ #include " src/enzyme_ad/jax/Dialect/Dialect.h"
810#include " stablehlo/dialect/StablehloOps.h"
911#include " llvm/ADT/DenseMap.h"
1012
@@ -387,7 +389,7 @@ PartialSymmetryAnnotation::getDimensionSets() const {
387389}
388390
389391PartialSymmetryAnnotation
390- PartialSymmetryAnnotation::fromDimensionSets (int64_t rank,
392+ PartialSymmetryAnnotation::createFromDimensionSets (int64_t rank,
391393 ArrayRef<ArrayRef<int64_t >> dimensionSets) {
392394 SmallVector<int64_t > dimensionSetIDs (rank);
393395 for (int64_t i = 0 ; i < rank; ++i) {
@@ -405,6 +407,78 @@ PartialSymmetryAnnotation::fromDimensionSets(int64_t rank,
405407 return PartialSymmetryAnnotation (dimensionSetIDs);
406408}
407409
410+ std::optional<PartialSymmetryAnnotation>
411+ PartialSymmetryAnnotation::createFromIR (Value val) {
412+ auto blockArg = dyn_cast<BlockArgument>(val);
413+ if (blockArg) {
414+ auto op = blockArg.getOwner ()->getParentOp ();
415+ auto funcOpInterface = dyn_cast<FunctionOpInterface>(op);
416+ if (!funcOpInterface) {
417+ return std::nullopt ;
418+ }
419+
420+ auto argAttrs = funcOpInterface.getArgAttrs (blockArg.getArgNumber ());
421+ for (auto attr : argAttrs) {
422+ if (attr.getName () == " enzymexla.partial_symmetry" ) {
423+ auto arrayAttr = dyn_cast<ArrayAttr>(attr.getValue ());
424+ if (!arrayAttr || arrayAttr.empty ()) {
425+ continue ;
426+ }
427+
428+ auto partialSymmetryAttr = dyn_cast<enzymexla::PartialSymmetryAnalysisResultAttr>(
429+ arrayAttr[0 ]);
430+
431+ if (!partialSymmetryAttr) {
432+ continue ;
433+ }
434+
435+ auto dimensionSetAttrs = partialSymmetryAttr.getValues ();
436+ auto rank = cast<RankedTensorType>(val.getType ()).getRank ();
437+
438+ SmallVector<ArrayRef<int64_t >> dimensionSets;
439+ for (auto dimensionSetAttr : dimensionSetAttrs) {
440+ auto dims = dimensionSetAttr.getDimensions ().asArrayRef ();
441+ dimensionSets.push_back (dims);
442+ }
443+
444+ return PartialSymmetryAnnotation::createFromDimensionSets (rank, dimensionSets);
445+ }
446+ }
447+ return std::nullopt ;
448+ }
449+
450+ auto op = val.getDefiningOp ();
451+ if (!op)
452+ return std::nullopt ;
453+
454+ auto arrayAttr =
455+ op->getAttrOfType <ArrayAttr>(" enzymexla.partial_symmetry" );
456+ if (!arrayAttr || arrayAttr.empty ())
457+ return std::nullopt ;
458+
459+ auto opResult = dyn_cast<OpResult>(val);
460+ if (!opResult)
461+ return std::nullopt ;
462+
463+ auto resultNumber = opResult.getResultNumber ();
464+
465+ auto partialSymmetryAttr = dyn_cast<enzymexla::PartialSymmetryAnalysisResultAttr>(
466+ arrayAttr[resultNumber]);
467+ if (!partialSymmetryAttr)
468+ return std::nullopt ;
469+
470+ auto dimensionSetAttrs = partialSymmetryAttr.getValues ();
471+ auto rank = cast<RankedTensorType>(val.getType ()).getRank ();
472+
473+ SmallVector<ArrayRef<int64_t >> dimensionSets;
474+ for (auto dimensionSetAttr : dimensionSetAttrs) {
475+ auto dims = dimensionSetAttr.getDimensions ().asArrayRef ();
476+ dimensionSets.push_back (dims);
477+ }
478+
479+ return PartialSymmetryAnnotation::createFromDimensionSets (rank, dimensionSets);
480+ }
481+
408482void PartialSymmetryAnnotation::print (raw_ostream &os) const {
409483 auto dimensionSets = getDimensionSets ();
410484 os << " {" ;
@@ -430,6 +504,19 @@ void PartialSymmetryAnnotation::print(raw_ostream &os) const {
430504// PartialSymmetryLattice Implementation
431505// ===----------------------------------------------------------------------===//
432506
507+ PartialSymmetryLattice::PartialSymmetryLattice (Value v) : AbstractSparseLattice(v) {
508+ if (auto type = dyn_cast<RankedTensorType>(v.getType ())) {
509+ // Trust existing IR annotations if present.
510+ auto annotation = PartialSymmetryAnnotation::createFromIR (v);
511+ if (annotation.has_value ()) {
512+ value = annotation.value ();
513+ return ;
514+ }
515+
516+ value = PartialSymmetryAnnotation::createFullySymmetric (type.getRank ());
517+ }
518+ }
519+
433520ChangeResult PartialSymmetryLattice::join (const AbstractSparseLattice &rhs) {
434521 const auto *rhsStruct =
435522 reinterpret_cast <const PartialSymmetryLattice *>(&rhs);
@@ -452,6 +539,12 @@ void PartialSymmetryLattice::print(raw_ostream &os) const { value.print(os); }
452539// ===----------------------------------------------------------------------===//
453540
454541void PartialSymmetryAnalysis::setToEntryState (PartialSymmetryLattice *lattice) {
542+ auto annotation = PartialSymmetryAnnotation::createFromIR (lattice->getAnchor ());
543+ if (annotation.has_value ()) {
544+ lattice->setValue (annotation.value ());
545+ return ;
546+ }
547+
455548 lattice->setValue (PartialSymmetryAnnotation::createNotSymmetric (
456549 lattice->getValue ().getRank ()));
457550}
0 commit comments