From 4b6eac70ee4f262258bf29735dc8fac6a6907672 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Thu, 27 Nov 2025 20:29:16 +0000 Subject: [PATCH 01/21] Add partial symmetry detection --- .../jax/Analysis/PartialSymmetryAnalysis.cpp | 563 ++++++++++++++++++ .../jax/Analysis/PartialSymmetryAnalysis.h | 116 ++++ src/enzyme_ad/jax/BUILD | 2 + src/enzyme_ad/jax/Dialect/EnzymeXLAAttrs.td | 32 + .../jax/Passes/PartialSymmetrySimplify.cpp | 100 ++++ src/enzyme_ad/jax/Passes/Passes.td | 9 + .../structured_tensors/partial_symmetry.mlir | 52 ++ 7 files changed, 874 insertions(+) create mode 100644 src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp create mode 100644 src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h create mode 100644 src/enzyme_ad/jax/Passes/PartialSymmetrySimplify.cpp create mode 100644 test/lit_tests/structured_tensors/partial_symmetry.mlir diff --git a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp new file mode 100644 index 000000000..6589c84c9 --- /dev/null +++ b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp @@ -0,0 +1,563 @@ +#include "src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h" +#include "src/enzyme_ad/jax/Utils.h" + +#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "llvm/ADT/DenseMap.h" + +using namespace mlir; +using namespace mlir::dataflow; + +namespace mlir { +namespace enzyme { + +//===----------------------------------------------------------------------===// +// PartialSymmetryAnnotation Implementation +//===----------------------------------------------------------------------===// + +PartialSymmetryAnnotation::PartialSymmetryAnnotation(ArrayRef s) + : known(true) { + storage.assign(s.begin(), s.end()); + canonicalize(); +} + +PartialSymmetryAnnotation +PartialSymmetryAnnotation::createFullySymmetric(int64_t rank) { + PartialSymmetryAnnotation annotation; + annotation.known = true; + for (int64_t i = 0; i < rank; ++i) { + annotation.storage.push_back(0); + } + return annotation; +} + +PartialSymmetryAnnotation +PartialSymmetryAnnotation::createNotSymmetric(int64_t rank) { + PartialSymmetryAnnotation annotation; + annotation.known = true; + for (int64_t i = 0; i < rank; ++i) { + annotation.storage.push_back(i); + } + return annotation; +} + +PartialSymmetryAnnotation +PartialSymmetryAnnotation::createKnownUninitialized(int64_t rank) { + PartialSymmetryAnnotation annotation; + annotation.known = true; + annotation.storage.resize(rank); + return annotation; +} + +bool PartialSymmetryAnnotation::isSymmetric(int64_t i, int64_t j) const { + if (i < 0 || i >= (int64_t)storage.size() || j < 0 || + j >= (int64_t)storage.size()) + return false; + return storage[i] == storage[j]; +} + +void PartialSymmetryAnnotation::canonicalize() { + llvm::SmallDenseMap map; + int nextId = 0; + for (auto &id : storage) { + if (map.find(id) == map.end()) { + map[id] = nextId++; + } + id = map[id]; + } +} + +PartialSymmetryAnnotation +PartialSymmetryAnnotation::join(const PartialSymmetryAnnotation &lhs, + const PartialSymmetryAnnotation &rhs) { + if (lhs.isUnknown() || rhs.isUnknown()) + return PartialSymmetryAnnotation(); + + PartialSymmetryAnnotation result = createKnownUninitialized(lhs.getRank()); + int nextId = 0; + + for (int64_t i = 0; i < lhs.getRank(); ++i) { + bool found = false; + for (int64_t j = 0; j < i; ++j) { + if (lhs.getSetId(i) == lhs.getSetId(j) && + rhs.getSetId(i) == rhs.getSetId(j)) { + result.storage[i] = result.storage[j]; + found = true; + break; + } + } + if (!found) { + result.storage[i] = nextId++; + } + } + + result.canonicalize(); + return result; +} + +PartialSymmetryAnnotation +PartialSymmetryAnnotation::meet(const PartialSymmetryAnnotation &lhs, + const PartialSymmetryAnnotation &rhs) { + if (lhs.isUnknown()) + return rhs; + if (rhs.isUnknown()) + return lhs; + + PartialSymmetryAnnotation result = createKnownUninitialized(lhs.getRank()); + int nextId = 0; + + for (int64_t i = 0; i < lhs.getRank(); ++i) { + bool found = false; + for (int64_t j = 0; j < i; ++j) { + if (lhs.getSetId(i) == lhs.getSetId(j) || + rhs.getSetId(i) == rhs.getSetId(j)) { + result.storage[i] = result.storage[j]; + found = true; + break; + } + } + if (!found) { + result.storage[i] = nextId++; + } + } + + result.canonicalize(); + return result; +} + +PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateTranspose( + const PartialSymmetryAnnotation &annotation, + ArrayRef permutation) { + if (annotation.isUnknown()) + return PartialSymmetryAnnotation(); + + PartialSymmetryAnnotation result = + createKnownUninitialized(annotation.getRank()); + + for (int64_t i = 0; i < annotation.getRank(); ++i) { + result.storage[i] = annotation.getSetId(permutation[i]); + } + + result.canonicalize(); + return result; +} + +PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateBroadcastInDim( + const PartialSymmetryAnnotation &annotation, int64_t outputRank, + ArrayRef broadcastDimensions) { + + if (annotation.isUnknown()) + return PartialSymmetryAnnotation(); + + PartialSymmetryAnnotation result = createKnownUninitialized(outputRank); + + llvm::SmallDenseMap outputToInput; + for (size_t i = 0; i < broadcastDimensions.size(); ++i) { + outputToInput[broadcastDimensions[i]] = i; + } + + int maxSetId = -1; + for (int64_t i = 0; i < annotation.getRank(); ++i) { + maxSetId = std::max(maxSetId, annotation.getSetId(i)); + } + + int nextNewSetId = maxSetId + 1; + for (int64_t outputDim = 0; outputDim < outputRank; ++outputDim) { + if (outputToInput.find(outputDim) != outputToInput.end()) { + // dimension is preserved => use old ID + int64_t inputDim = outputToInput[outputDim]; + result.storage[outputDim] = annotation.getSetId(inputDim); + } else { + // broadcasted dimension => new ID + result.storage[outputDim] = nextNewSetId++; + } + } + + result.canonicalize(); + return result; +} + +PartialSymmetryAnnotation +PartialSymmetryAnnotation::generateSymmetryFromBilinearTranspose( + const PartialSymmetryAnnotation &annotation, + ArrayRef permutation) { + int64_t rank = permutation.size(); + + // Each pair (i, j) where perm[i] = j and perm[j] = i is symmetric + PartialSymmetryAnnotation transposeSymmetry = createKnownUninitialized(rank); + SmallVector assigned(rank, false); + int nextId = 0; + + for (int64_t i = 0; i < rank; ++i) { + if (assigned[i]) + continue; + + int64_t j = permutation[i]; + if (j != i && permutation[j] == i) { + // i and j are swapped, so assign them the same ID + transposeSymmetry.storage[i] = nextId; + transposeSymmetry.storage[j] = nextId; + assigned[i] = true; + assigned[j] = true; + } else { + // dimension i is not swapped, so assign it a new ID + transposeSymmetry.storage[i] = nextId; + assigned[i] = true; + } + nextId++; + } + + transposeSymmetry.canonicalize(); + + // Meet the existing annotation with the transpose symmetry + return meet(annotation, transposeSymmetry); +} + +PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateDotGeneral( + const PartialSymmetryAnnotation &lhsAnnotation, + const PartialSymmetryAnnotation &rhsAnnotation, int64_t resultRank, + ArrayRef lhsBatchingDims, ArrayRef rhsBatchingDims, + ArrayRef lhsContractingDims, ArrayRef rhsContractingDims, + bool lhsEqualsRhs) { + + if (lhsAnnotation.isUnknown() || rhsAnnotation.isUnknown()) + return PartialSymmetryAnnotation(); + + PartialSymmetryAnnotation result = createNotSymmetric(resultRank); + + for (int i = 0; i < lhsBatchingDims.size(); ++i) { + for (int j = 0; j < i; ++j) { + if (lhsAnnotation.getSetId(lhsBatchingDims[i]) == + lhsAnnotation.getSetId(lhsBatchingDims[j]) && + rhsAnnotation.getSetId(rhsBatchingDims[i]) == + rhsAnnotation.getSetId(rhsBatchingDims[j])) { + result.storage[i] = result.storage[j]; + } + } + } + + if (lhsEqualsRhs && lhsBatchingDims == rhsBatchingDims && + lhsContractingDims == rhsContractingDims) { + // Also preserve symmetry in non-contracting, non-batching dimensions + // TODO + } + + result.canonicalize(); + return result; +} + +static bool checkPairwiseSymmetry(DenseElementsAttr attr, int64_t dimA, + int64_t dimB) { + auto type = cast(attr.getType()); + auto shape = type.getShape(); + int64_t rank = type.getRank(); + + if (shape[dimA] != shape[dimB]) + return false; + + int64_t numElements = type.getNumElements(); + + if (auto intAttr = dyn_cast(attr)) { + auto values = intAttr.getValues(); + SmallVector strides(rank); + int64_t currentStride = 1; + for (int i = rank - 1; i >= 0; --i) { + strides[i] = currentStride; + currentStride *= shape[i]; + } + + for (int64_t i = 0; i < numElements; ++i) { + SmallVector coords(rank); + int64_t temp = i; + for (int d = 0; d < rank; ++d) { + coords[d] = temp / strides[d]; + temp %= strides[d]; + } + + std::swap(coords[dimA], coords[dimB]); + + int64_t swappedIdx = 0; + for (int d = 0; d < rank; ++d) { + swappedIdx += coords[d] * strides[d]; + } + + if (values[i] != values[swappedIdx]) + return false; + } + return true; + } else if (auto floatAttr = dyn_cast(attr)) { + auto values = floatAttr.getValues(); + SmallVector strides(rank); + int64_t currentStride = 1; + for (int i = rank - 1; i >= 0; --i) { + strides[i] = currentStride; + currentStride *= shape[i]; + } + + for (int64_t i = 0; i < numElements; ++i) { + SmallVector coords(rank); + int64_t temp = i; + for (int d = 0; d < rank; ++d) { + coords[d] = temp / strides[d]; + temp %= strides[d]; + } + + std::swap(coords[dimA], coords[dimB]); + + int64_t swappedIdx = 0; + for (int d = 0; d < rank; ++d) { + swappedIdx += coords[d] * strides[d]; + } + + if (values[i].compare(values[swappedIdx]) != APFloat::cmpEqual) + return false; + } + return true; + } + return false; +} + +PartialSymmetryAnnotation +PartialSymmetryAnnotation::checkConstant(DenseElementsAttr attr) { + if (auto type = dyn_cast(attr.getType())) { + int64_t rank = type.getRank(); + SmallVector storage(rank); + for (int i = 0; i < rank; ++i) + storage[i] = i; + + for (int i = 0; i < rank; ++i) { + for (int j = i + 1; j < rank; ++j) { + if (storage[i] == storage[j]) + continue; + + if (checkPairwiseSymmetry(attr, i, j)) { + int oldId = storage[j]; + int newId = storage[i]; + for (int k = 0; k < rank; ++k) { + if (storage[k] == oldId) + storage[k] = newId; + } + } + } + } + return PartialSymmetryAnnotation(storage); + } + return PartialSymmetryAnnotation(); +} + +SmallVector> +PartialSymmetryAnnotation::getDimensionSets() const { + llvm::SmallDenseMap> sets; + for (int64_t i = 0; i < (int64_t)storage.size(); ++i) { + sets[storage[i]].push_back(i); + } + + SmallVector sortedKeys; + for (auto &kv : sets) + sortedKeys.push_back(kv.first); + std::sort(sortedKeys.begin(), sortedKeys.end(), + [&](int a, int b) { return sets[a][0] < sets[b][0]; }); + + SmallVector> result; + for (int key : sortedKeys) { + result.push_back(sets[key]); + } + return result; +} + +void PartialSymmetryAnnotation::print(raw_ostream &os) const { + auto dimensionSets = getDimensionSets(); + os << "{"; + bool firstSet = true; + for (const auto &set : dimensionSets) { + if (!firstSet) + os << ", "; + os << "{"; + bool firstElem = true; + for (int64_t dim : set) { + if (!firstElem) + os << ","; + os << dim; + firstElem = false; + } + os << "}"; + firstSet = false; + } + os << "}"; +} + +//===----------------------------------------------------------------------===// +// PartialSymmetryLattice Implementation +//===----------------------------------------------------------------------===// + +ChangeResult PartialSymmetryLattice::join(const AbstractSparseLattice &rhs) { + const auto *rhsStruct = + reinterpret_cast(&rhs); + return join(*rhsStruct); +} + +ChangeResult PartialSymmetryLattice::join(const PartialSymmetryLattice &rhs) { + auto newValue = PartialSymmetryAnnotation::join(getValue(), rhs.getValue()); + if (getValue() == newValue) + return ChangeResult::NoChange; + + setValue(newValue); + return ChangeResult::Change; +} + +void PartialSymmetryLattice::print(raw_ostream &os) const { value.print(os); } + +//===----------------------------------------------------------------------===// +// PartialSymmetryAnalysis Implementation +//===----------------------------------------------------------------------===// + +void PartialSymmetryAnalysis::setToEntryState(PartialSymmetryLattice *lattice) { + lattice->setValue(PartialSymmetryAnnotation()); +} + +LogicalResult PartialSymmetryAnalysis::visitOperation( + Operation *op, ArrayRef operands, + ArrayRef results) { + + SmallVector updatedAnnotation(results.size(), false); + SmallVector propagatedAnnotation(results.size()); + + SmallVector operandAnnotations(operands.size()); + for (size_t i = 0; i < operands.size(); i++) { + operandAnnotations[i] = operands[i]->getValue(); + } + + if (auto transposeOp = dyn_cast(op)) { + updatedAnnotation[0] = true; + propagatedAnnotation[0] = PartialSymmetryAnnotation::propagateTranspose( + operandAnnotations[0], transposeOp.getPermutation()); + } + + if (auto bcastOp = dyn_cast(op)) { + if (results.size() > 0) { + if (auto resultType = + dyn_cast(op->getResult(0).getType())) { + updatedAnnotation[0] = true; + propagatedAnnotation[0] = + PartialSymmetryAnnotation::propagateBroadcastInDim( + operandAnnotations[0], resultType.getRank(), + bcastOp.getBroadcastDimensions()); + } + } + } + + if (auto dotGeneralOp = dyn_cast(op)) { + if (results.size() > 0) { + if (auto resultType = + dyn_cast(op->getResult(0).getType())) { + auto dotDimNumbers = dotGeneralOp.getDotDimensionNumbers(); + auto lhs = dotGeneralOp.getLhs(); + auto rhs = dotGeneralOp.getRhs(); + bool lhsEqualsRhs = (lhs == rhs); + + // Check for transpose pattern: A x A^T or A^T x A + bool transposePatternDetected = false; + ArrayRef transposePermutation; + + if (auto lhsT = lhs.getDefiningOp()) { + if (rhs == lhsT.getOperand()) { + transposePatternDetected = true; + transposePermutation = lhsT.getPermutation(); + } + } + if (auto rhsT = rhs.getDefiningOp()) { + if (lhs == rhsT.getOperand()) { + transposePatternDetected = true; + transposePermutation = rhsT.getPermutation(); + } + } + + // Propagate symmetry through dotGeneral + propagatedAnnotation[0] = + PartialSymmetryAnnotation::propagateDotGeneral( + operandAnnotations[0], operandAnnotations[1], + resultType.getRank(), dotDimNumbers.getLhsBatchingDimensions(), + dotDimNumbers.getRhsBatchingDimensions(), + dotDimNumbers.getLhsContractingDimensions(), + dotDimNumbers.getRhsContractingDimensions(), lhsEqualsRhs); + + // If transpose pattern detected, add symmetry from transpose + if (transposePatternDetected) { + propagatedAnnotation[0] = + PartialSymmetryAnnotation::generateSymmetryFromBilinearTranspose( + propagatedAnnotation[0], transposePermutation); + } + + updatedAnnotation[0] = true; + } + } + } + + if (stablehlo::hasTraitElementwise(op)) { + if (results.size() == 1 && operands.size() > 0) { + propagatedAnnotation[0] = operandAnnotations[0]; + for (size_t i = 1; i < operands.size(); ++i) { + propagatedAnnotation[0] = PartialSymmetryAnnotation::join( + propagatedAnnotation[0], operandAnnotations[i]); + } + updatedAnnotation[0] = true; + + // Generate symmetry from commutative operation with transpose argument + if (op->hasTrait() || + op->hasTrait() && + op->getNumOperands() == 2) { + auto lhs = op->getOperand(0); + auto rhs = op->getOperand(1); + + bool transposePatternDetected = false; + ArrayRef transposePermutation; + + if (auto lhsT = lhs.getDefiningOp()) { + if (rhs == lhsT.getOperand()) { + transposePatternDetected = true; + transposePermutation = lhsT.getPermutation(); + } + } + if (auto rhsT = rhs.getDefiningOp()) { + if (lhs == rhsT.getOperand()) { + transposePatternDetected = true; + transposePermutation = rhsT.getPermutation(); + } + } + + if (transposePatternDetected) { + propagatedAnnotation[0] = + PartialSymmetryAnnotation::generateSymmetryFromBilinearTranspose( + propagatedAnnotation[0], transposePermutation); + } + } + } + } + + DenseElementsAttr denseAttr; + if (matchPattern(op->getResult(0), m_Constant(&denseAttr))) { + updatedAnnotation[0] = true; + propagatedAnnotation[0] = + PartialSymmetryAnnotation::checkConstant(denseAttr); + } + + for (size_t i = 0; i < results.size(); i++) { + if (updatedAnnotation[i]) { + auto resultOrig = results[i]->getValue(); + auto resultNew = + PartialSymmetryAnnotation::join(resultOrig, propagatedAnnotation[i]); + results[i]->setValue(resultNew); + propagateIfChanged(results[i], resultNew == resultOrig + ? ChangeResult::NoChange + : ChangeResult::Change); + } + } + + return success(); +} + +} // namespace enzyme +} // namespace mlir diff --git a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h new file mode 100644 index 000000000..64a6dce9c --- /dev/null +++ b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h @@ -0,0 +1,116 @@ +#pragma once + +#include +#include +#include + +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Value.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir { +namespace enzyme { + +// Represents the partial symmetry of a tensor as a partition of its dimensions. +class PartialSymmetryAnnotation { +public: + PartialSymmetryAnnotation() : known(false), storage() {} + + explicit PartialSymmetryAnnotation(ArrayRef storage); + + static PartialSymmetryAnnotation createKnownUninitialized(int64_t rank); + static PartialSymmetryAnnotation createNotSymmetric(int64_t rank); + static PartialSymmetryAnnotation createFullySymmetric(int64_t rank); + + bool isSymmetric(int64_t i, int64_t j) const; + + int getSetId(int64_t i) const { return storage[i]; } + + int64_t getRank() const { return storage.size(); } + + bool isUnknown() const { return !known; } + + static PartialSymmetryAnnotation join(const PartialSymmetryAnnotation &lhs, + const PartialSymmetryAnnotation &rhs); + static PartialSymmetryAnnotation meet(const PartialSymmetryAnnotation &lhs, + const PartialSymmetryAnnotation &rhs); + + static PartialSymmetryAnnotation + propagateTranspose(const PartialSymmetryAnnotation &annotation, + ArrayRef permutation); + + static PartialSymmetryAnnotation + propagateBroadcastInDim(const PartialSymmetryAnnotation &annotation, + int64_t outputRank, + ArrayRef broadcastDimensions); + + static PartialSymmetryAnnotation + propagateDotGeneral(const PartialSymmetryAnnotation &lhsAnnotation, + const PartialSymmetryAnnotation &rhsAnnotation, + int64_t resultRank, ArrayRef lhsBatchingDims, + ArrayRef rhsBatchingDims, + ArrayRef lhsContractingDims, + ArrayRef rhsContractingDims, bool lhsEqualsRhs); + + static PartialSymmetryAnnotation checkConstant(DenseElementsAttr attr); + + static PartialSymmetryAnnotation generateSymmetryFromBilinearTranspose( + const PartialSymmetryAnnotation &annotation, + ArrayRef permutation); + + bool operator==(const PartialSymmetryAnnotation &other) const { + return (!known && !other.known) || storage == other.storage; + } + + SmallVector> getDimensionSets() const; + + void print(raw_ostream &os) const; + +private: + bool known; + SmallVector storage; + + void canonicalize(); +}; + +class PartialSymmetryLattice : public dataflow::AbstractSparseLattice { +public: + using AbstractSparseLattice::AbstractSparseLattice; + + PartialSymmetryLattice(Value v) : AbstractSparseLattice(v) { + if (auto type = dyn_cast(v.getType())) { + value = PartialSymmetryAnnotation::createFullySymmetric(type.getRank()); + } + } + + ChangeResult join(const AbstractSparseLattice &rhs) override; + ChangeResult join(const PartialSymmetryLattice &rhs); + + void print(raw_ostream &os) const override; + + const PartialSymmetryAnnotation &getValue() const { return value; } + void setValue(const PartialSymmetryAnnotation &v) { value = v; } + +private: + bool isUnknown; + PartialSymmetryAnnotation value; +}; + +class PartialSymmetryAnalysis + : public dataflow::SparseForwardDataFlowAnalysis { +public: + using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis; + + void setToEntryState(PartialSymmetryLattice *lattice) override; + + LogicalResult + visitOperation(Operation *op, + ArrayRef operands, + ArrayRef results) override; +}; + +} // namespace enzyme +} // namespace mlir diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD index 975f166f9..cf247db58 100644 --- a/src/enzyme_ad/jax/BUILD +++ b/src/enzyme_ad/jax/BUILD @@ -839,6 +839,7 @@ cc_library( cc_library( name = "XLADerivatives", srcs = glob([ + "Analysis/*.cpp", "Implementations/*.cpp", "Passes/*.cpp", "Dialect/*.cpp", @@ -848,6 +849,7 @@ cc_library( "Utils.cpp", ], hdrs = glob([ + "Analysis/*.h", "Implementations/*.h", "Passes/*.h", "Dialect/*.h", diff --git a/src/enzyme_ad/jax/Dialect/EnzymeXLAAttrs.td b/src/enzyme_ad/jax/Dialect/EnzymeXLAAttrs.td index 11bf8ca25..2e1a43230 100644 --- a/src/enzyme_ad/jax/Dialect/EnzymeXLAAttrs.td +++ b/src/enzyme_ad/jax/Dialect/EnzymeXLAAttrs.td @@ -133,4 +133,36 @@ def EnzymeXLA_GuaranteedAnalysisResult : I32EnumAttr<"GuaranteedAnalysisResult", def EnzymeXLA_GuaranteedAnalysisResultAttr : EnumAttr; +def EnzymeXLA_SymmetricDimensionSetAttr : AttrDef { + let summary = "A set of symmetric dimension indices"; + let cppNamespace = "::mlir::enzymexla"; + + let parameters = (ins + "DenseI64ArrayAttr":$dimensions + ); + + let assemblyFormat = [{ + `<` $dimensions `>` + }]; + + let mnemonic = "symmetric_dimension_set"; +} + +def EnzymeXLA_PartialSymmetryAnalysisResultAttr : AttrDef { + let summary = "Sets of partially symmetric dimensions"; + let cppNamespace = "::mlir::enzymexla"; + + let parameters = (ins + ArrayRefParameter<"SymmetricDimensionSetAttr">:$values + ); + + let assemblyFormat = [{ + `<` $values `>` + }]; + + let mnemonic = "partial_symmetry"; +} + #endif // ENZYMEXLA_ATTRS diff --git a/src/enzyme_ad/jax/Passes/PartialSymmetrySimplify.cpp b/src/enzyme_ad/jax/Passes/PartialSymmetrySimplify.cpp new file mode 100644 index 000000000..baba354aa --- /dev/null +++ b/src/enzyme_ad/jax/Passes/PartialSymmetrySimplify.cpp @@ -0,0 +1,100 @@ +#include "src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h" +#include "src/enzyme_ad/jax/Passes/Passes.h" + +#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" +#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Pass/Pass.h" + +#include "mlir/Support/LLVM.h" +#include "llvm/Support/raw_ostream.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "src/enzyme_ad/jax/Dialect/Dialect.h" +#include "src/enzyme_ad/jax/Dialect/Ops.h" +#include "stablehlo/dialect/StablehloOps.h" + +#define DEBUG_TYPE "partial-symmetry-simplify" + +namespace mlir { +namespace enzyme { +#define GEN_PASS_DEF_PARTIALSYMMETRYSIMPLIFYPASS +#include "src/enzyme_ad/jax/Passes/Passes.h.inc" +} // namespace enzyme +} // namespace mlir + +using namespace mlir; +using namespace mlir::dataflow; +using namespace mlir::enzyme; + +namespace { + +class PartialSymmetrySimplifyPass + : public enzyme::impl::PartialSymmetrySimplifyPassBase< + PartialSymmetrySimplifyPass> { +public: + using Base::Base; + + void runOnOperation() override { + DataFlowSolver solver; + + solver.load(); + solver.load(); + solver.load(); + + if (failed(solver.initializeAndRun(getOperation()))) { + return signalPassFailure(); + } + + auto mod = getOperation(); + + mod->walk([&](Operation *op) { + SmallVector partialSymmetryAttrs; + bool anyKnown = false; + + for (auto result : op->getResults()) { + auto *state = + solver.lookupState(result); + if (!state) { + continue; + } + + auto dimensionSets = state->getValue().getDimensionSets(); + + SmallVector dimensionSetAttrs; + for (const auto &set : dimensionSets) { + if (set.size() > 1) { + anyKnown = true; + auto denseAttr = DenseI64ArrayAttr::get(mod.getContext(), set); + auto dimensionSetAttr = enzymexla::SymmetricDimensionSetAttr::get( + mod.getContext(), denseAttr); + dimensionSetAttrs.push_back(dimensionSetAttr); + } + } + + if (dimensionSetAttrs.empty()) { + continue; + } + + auto partialSymmetry = + enzymexla::PartialSymmetryAnalysisResultAttr::get( + mod.getContext(), dimensionSetAttrs); + partialSymmetryAttrs.push_back(partialSymmetry); + } + + if (anyKnown) { + op->setAttr("enzymexla.partial_symmetry", + ArrayAttr::get(mod.getContext(), partialSymmetryAttrs)); + } + + return WalkResult::advance(); + }); + + // TODO: do things here + } +}; + +} // namespace diff --git a/src/enzyme_ad/jax/Passes/Passes.td b/src/enzyme_ad/jax/Passes/Passes.td index 5bf009201..f96d863f0 100644 --- a/src/enzyme_ad/jax/Passes/Passes.td +++ b/src/enzyme_ad/jax/Passes/Passes.td @@ -1077,4 +1077,13 @@ def EnzymeBatchToStableHLOPass : Pass<"enzyme-batch-to-stablehlo"> { ]; } +def PartialSymmetrySimplifyPass : Pass<"partial-symmetry-simplify", "ModuleOp"> { + let summary = "Simplify operations using partial symmetry analysis"; + let dependentDialects = [ + "stablehlo::StablehloDialect", + "enzymexla::EnzymeXLADialect", + "func::FuncDialect", + ]; +} + #endif diff --git a/test/lit_tests/structured_tensors/partial_symmetry.mlir b/test/lit_tests/structured_tensors/partial_symmetry.mlir new file mode 100644 index 000000000..01c5a94a2 --- /dev/null +++ b/test/lit_tests/structured_tensors/partial_symmetry.mlir @@ -0,0 +1,52 @@ +// RUN: enzymexlamlir-opt --partial-symmetry-simplify %s | FileCheck %s + +func.func @test1() -> tensor<2x2xf32> { + %cst = stablehlo.constant dense<[[1.0, 2.0], [2.0, 3.0]]> : tensor<2x2xf32> + return %cst : tensor<2x2xf32> +} +// CHECK: func.func @test1() -> tensor<2x2xf32> { +// CHECK-NEXT: %cst = stablehlo.constant {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]} dense<{{.*}}> : tensor<2x2xf32> +// CHECK-NEXT: return %cst : tensor<2x2xf32> +// CHECK-NEXT: } + +func.func @test2() -> tensor<2x2x2x3xf32> { + %cst0 = stablehlo.constant dense<[[[1.0, 2.0], [3.0, 4.0]], [[3.0, 4.0], [5.0, 6.0]]]> : tensor<2x2x2xf32> + %cst1 = stablehlo.constant dense<[[[1.0, 2.0], [2.0, 3.0]], [[2.0, 3.0], [3.0, 4.0]]]> : tensor<2x2x2xf32> + %0 = stablehlo.add %cst0, %cst1 : tensor<2x2x2xf32> + %1 = stablehlo.transpose %0, dims = [0, 2, 1] : (tensor<2x2x2xf32>) -> tensor<2x2x2xf32> + %2 = stablehlo.broadcast_in_dim %1, dims = [1, 0, 2] : (tensor<2x2x2xf32>) -> tensor<2x2x2x3xf32> + return %2 : tensor<2x2x2x3xf32> +} +// CHECK: func.func @test2() -> tensor<2x2x2x3xf32> { +// CHECK-NEXT: %cst = stablehlo.constant {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]} dense<{{.*}}> : tensor<2x2x2xf32> +// CHECK-NEXT: %cst_0 = stablehlo.constant {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1, 2]>>]} dense<{{.*}}> : tensor<2x2x2xf32> +// CHECK-NEXT: %0 = stablehlo.add %cst, %cst_0 {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]} : tensor<2x2x2xf32> +// CHECK-NEXT: %1 = stablehlo.transpose %0, dims = [0, 2, 1] {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 2]>>]} : (tensor<2x2x2xf32>) -> tensor<2x2x2xf32> +// CHECK-NEXT: %2 = stablehlo.broadcast_in_dim %1, dims = [1, 0, 2] {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[1, 2]>>]} : (tensor<2x2x2xf32>) -> tensor<2x2x2x3xf32> +// CHECK-NEXT: return %2 : tensor<2x2x2x3xf32> +// CHECK-NEXT: } + +func.func @test3(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x3xf32> { + %0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<3x2x3xf32>) -> tensor<3x2x3xf32> + %1 = stablehlo.add %0, %arg0 : tensor<3x2x3xf32> + return %1 : tensor<3x2x3xf32> +} +// CHECK: func.func @test3(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x3xf32> { +// CHECK-NEXT: %0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<3x2x3xf32>) -> tensor<3x2x3xf32> +// CHECK-NEXT: %1 = stablehlo.add %0, %arg0 {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 2]>>]} : tensor<3x2x3xf32> +// CHECK-NEXT: return %1 : tensor<3x2x3xf32> +// CHECK-NEXT: } + +func.func @test4() -> tensor<2x2xf32> { + %cst0 = stablehlo.constant dense<[[[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]], [[2.0, 3.0, 4.0], [3.0, 4.0, 5.0]]]> : tensor<2x2x3xf32> + %cst1 = stablehlo.constant dense<[[[1.0, 2.0], [2.0, 3.0]], [[2.0, 3.0], [3.0, 4.0]], [[2.0, 3.0], [3.0, 4.0]]]> : tensor<3x2x2xf32> + %0 = stablehlo.dot_general %cst0, %cst1, batching_dims = [0, 1] x [1, 2], contracting_dims = [2] x [0] : (tensor<2x2x3xf32>, tensor<3x2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} +// CHECK: func.func @test4() -> tensor<2x2xf32> { +// CHECK-NEXT: %cst = stablehlo.constant {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]} dense<{{.*}}> : tensor<2x2x3xf32> +// CHECK-NEXT: %cst_0 = stablehlo.constant {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[1, 2]>>]} dense<{{.*}}> : tensor<3x2x2xf32> +// CHECK-NEXT: %0 = stablehlo.dot_general %cst, %cst_0, batching_dims = [0, 1] x [1, 2], contracting_dims = [2] x [0] {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]} : (tensor<2x2x3xf32>, tensor<3x2x2xf32>) -> tensor<2x2xf32> +// CHECK-NEXT: return %0 : tensor<2x2xf32> +// CHECK-NEXT: } + From ddcbcafb1aa2df749efcbc0c8903c95222797e6f Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Fri, 28 Nov 2025 07:57:37 +0000 Subject: [PATCH 02/21] Remove generalized A * A^T handling (soundness unclear) --- src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp index 6589c84c9..66bce03a5 100644 --- a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp +++ b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp @@ -486,9 +486,7 @@ LogicalResult PartialSymmetryAnalysis::visitOperation( // If transpose pattern detected, add symmetry from transpose if (transposePatternDetected) { - propagatedAnnotation[0] = - PartialSymmetryAnnotation::generateSymmetryFromBilinearTranspose( - propagatedAnnotation[0], transposePermutation); + // TODO } updatedAnnotation[0] = true; From a955487a5981eb01a2f9d7e38f37e89ae55562bf Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Sat, 29 Nov 2025 17:47:57 +0000 Subject: [PATCH 03/21] Add general dot_general logic for case where lhs = rhs --- .../jax/Analysis/PartialSymmetryAnalysis.cpp | 131 +++++++++++++----- .../jax/Analysis/PartialSymmetryAnalysis.h | 3 +- 2 files changed, 99 insertions(+), 35 deletions(-) diff --git a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp index 66bce03a5..676a4c097 100644 --- a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp +++ b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp @@ -70,28 +70,37 @@ void PartialSymmetryAnnotation::canonicalize() { } } +void PartialSymmetryAnnotation::uniteDimensionSets(int i, int j) { + if (storage[i] == storage[j]) + return; + + int oldId = storage[i]; + int newId = storage[j]; + for (size_t k = 0; k < storage.size(); ++k) { + if (storage[k] == oldId) { + storage[k] = newId; + } + } + + canonicalize(); +} + PartialSymmetryAnnotation PartialSymmetryAnnotation::join(const PartialSymmetryAnnotation &lhs, const PartialSymmetryAnnotation &rhs) { if (lhs.isUnknown() || rhs.isUnknown()) return PartialSymmetryAnnotation(); - PartialSymmetryAnnotation result = createKnownUninitialized(lhs.getRank()); - int nextId = 0; + PartialSymmetryAnnotation result = createNotSymmetric(lhs.getRank()); for (int64_t i = 0; i < lhs.getRank(); ++i) { bool found = false; for (int64_t j = 0; j < i; ++j) { if (lhs.getSetId(i) == lhs.getSetId(j) && rhs.getSetId(i) == rhs.getSetId(j)) { - result.storage[i] = result.storage[j]; - found = true; - break; + result.uniteDimensionSets(i, j); } } - if (!found) { - result.storage[i] = nextId++; - } } result.canonicalize(); @@ -106,22 +115,15 @@ PartialSymmetryAnnotation::meet(const PartialSymmetryAnnotation &lhs, if (rhs.isUnknown()) return lhs; - PartialSymmetryAnnotation result = createKnownUninitialized(lhs.getRank()); - int nextId = 0; + PartialSymmetryAnnotation result = createNotSymmetric(lhs.getRank()); for (int64_t i = 0; i < lhs.getRank(); ++i) { - bool found = false; for (int64_t j = 0; j < i; ++j) { if (lhs.getSetId(i) == lhs.getSetId(j) || rhs.getSetId(i) == rhs.getSetId(j)) { - result.storage[i] = result.storage[j]; - found = true; - break; + result.uniteDimensionSets(i, j); } } - if (!found) { - result.storage[i] = nextId++; - } } result.canonicalize(); @@ -134,8 +136,7 @@ PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateTranspose( if (annotation.isUnknown()) return PartialSymmetryAnnotation(); - PartialSymmetryAnnotation result = - createKnownUninitialized(annotation.getRank()); + PartialSymmetryAnnotation result = createKnownUninitialized(annotation.getRank()); for (int64_t i = 0; i < annotation.getRank(); ++i) { result.storage[i] = annotation.getSetId(permutation[i]); @@ -228,21 +229,90 @@ PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateDotGeneral( PartialSymmetryAnnotation result = createNotSymmetric(resultRank); + // Preserve symmetry in batching dimensions for (int i = 0; i < lhsBatchingDims.size(); ++i) { for (int j = 0; j < i; ++j) { if (lhsAnnotation.getSetId(lhsBatchingDims[i]) == lhsAnnotation.getSetId(lhsBatchingDims[j]) && rhsAnnotation.getSetId(rhsBatchingDims[i]) == rhsAnnotation.getSetId(rhsBatchingDims[j])) { - result.storage[i] = result.storage[j]; + result.uniteDimensionSets(i, j); } } } + // Preserve symmetry in free (non-contracting, non-batching) dimensions if (lhsEqualsRhs && lhsBatchingDims == rhsBatchingDims && lhsContractingDims == rhsContractingDims) { - // Also preserve symmetry in non-contracting, non-batching dimensions - // TODO + + // annotations must be equal + PartialSymmetryAnnotation annotation = lhsAnnotation; + + bool exchange_valid = true; + + // check that each batching dimension has same ID for LHS and RHS + for (int i = 0; i < lhsBatchingDims.size(); ++i) { + if (lhsAnnotation.getSetId(lhsBatchingDims[i]) != rhsAnnotation.getSetId(rhsBatchingDims[i])) { + exchange_valid = false; + } + } + + // check that the multiset of IDs for contracting dimensions are equal for LHS and RHS + SmallVector lhsContractingIds, rhsContractingIds; + for (int64_t dim : lhsContractingDims) { + lhsContractingIds.push_back(lhsAnnotation.getSetId(dim)); + } + for (int64_t dim : rhsContractingDims) { + rhsContractingIds.push_back(rhsAnnotation.getSetId(dim)); + } + llvm::sort(lhsContractingIds); + llvm::sort(rhsContractingIds); + if (lhsContractingIds != rhsContractingIds) { + exchange_valid = false; + } + + if (exchange_valid) { + SmallVector lhsResultDims; + for (int64_t i = 0; i < annotation.getRank(); ++i) { + if (!llvm::is_contained(lhsBatchingDims, i) && !llvm::is_contained(lhsContractingDims, i)) { + lhsResultDims.push_back(i); + } + } + + SmallVector rhsResultDims; + for (int64_t i = 0; i < annotation.getRank(); ++i) { + if (!llvm::is_contained(rhsBatchingDims, i) && !llvm::is_contained(rhsContractingDims, i)) { + rhsResultDims.push_back(i); + } + } + + // Symmetry within free dimensions of LHS + for (int i = 0; i < lhsResultDims.size(); ++i) { + for (int j = 0; j < i; ++j) { + if (annotation.getSetId(lhsResultDims[i]) == annotation.getSetId(lhsResultDims[j])) { + result.uniteDimensionSets(lhsBatchingDims.size() + i, lhsBatchingDims.size() + j); + } + } + } + + // Symmetry between free dimensions of RHS + for (int i = 0; i < rhsResultDims.size(); ++i) { + for (int j = 0; j < i; ++j) { + if (annotation.getSetId(rhsResultDims[i]) == annotation.getSetId(rhsResultDims[j])) { + result.uniteDimensionSets(lhsBatchingDims.size() + lhsResultDims.size() + i, lhsBatchingDims.size() + lhsResultDims.size() + j); + } + } + } + + // Symmetry between free dimensions of LHS and RHS + for (int i = 0; i < lhsResultDims.size(); ++i) { + for (int j = 0; j < rhsResultDims.size(); ++j) { + if (annotation.getSetId(lhsResultDims[i]) == annotation.getSetId(rhsResultDims[j])) { + result.uniteDimensionSets(lhsBatchingDims.size() + lhsResultDims.size() + i, lhsBatchingDims.size() + j); + } + } + } + } } result.canonicalize(); @@ -324,26 +394,19 @@ PartialSymmetryAnnotation PartialSymmetryAnnotation::checkConstant(DenseElementsAttr attr) { if (auto type = dyn_cast(attr.getType())) { int64_t rank = type.getRank(); - SmallVector storage(rank); - for (int i = 0; i < rank; ++i) - storage[i] = i; - + PartialSymmetryAnnotation result = createNotSymmetric(rank); + for (int i = 0; i < rank; ++i) { for (int j = i + 1; j < rank; ++j) { - if (storage[i] == storage[j]) + if (result.getSetId(i) == result.getSetId(j)) continue; if (checkPairwiseSymmetry(attr, i, j)) { - int oldId = storage[j]; - int newId = storage[i]; - for (int k = 0; k < rank; ++k) { - if (storage[k] == oldId) - storage[k] = newId; - } + result.uniteDimensionSets(i, j); } } } - return PartialSymmetryAnnotation(storage); + return result; } return PartialSymmetryAnnotation(); } diff --git a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h index 64a6dce9c..3707ce82b 100644 --- a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h +++ b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h @@ -32,7 +32,7 @@ class PartialSymmetryAnnotation { int64_t getRank() const { return storage.size(); } bool isUnknown() const { return !known; } - + static PartialSymmetryAnnotation join(const PartialSymmetryAnnotation &lhs, const PartialSymmetryAnnotation &rhs); static PartialSymmetryAnnotation meet(const PartialSymmetryAnnotation &lhs, @@ -74,6 +74,7 @@ class PartialSymmetryAnnotation { SmallVector storage; void canonicalize(); + void uniteDimensionSets(int i, int j); }; class PartialSymmetryLattice : public dataflow::AbstractSparseLattice { From be1229b77f4357e4aae7f4bdaa92fc9e12e2d8fc Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Sat, 29 Nov 2025 18:19:15 +0000 Subject: [PATCH 04/21] Add transpose symmetry generation logic for dot general --- .../jax/Analysis/PartialSymmetryAnalysis.cpp | 51 ++++++++----------- .../jax/Analysis/PartialSymmetryAnalysis.h | 3 +- 2 files changed, 23 insertions(+), 31 deletions(-) diff --git a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp index 676a4c097..ab5afad13 100644 --- a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp +++ b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp @@ -222,7 +222,7 @@ PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateDotGeneral( const PartialSymmetryAnnotation &rhsAnnotation, int64_t resultRank, ArrayRef lhsBatchingDims, ArrayRef rhsBatchingDims, ArrayRef lhsContractingDims, ArrayRef rhsContractingDims, - bool lhsEqualsRhs) { + bool rhsAliasesLhs, ArrayRef rhsDimToLhs) { if (lhsAnnotation.isUnknown() || rhsAnnotation.isUnknown()) return PartialSymmetryAnnotation(); @@ -242,17 +242,13 @@ PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateDotGeneral( } // Preserve symmetry in free (non-contracting, non-batching) dimensions - if (lhsEqualsRhs && lhsBatchingDims == rhsBatchingDims && - lhsContractingDims == rhsContractingDims) { - - // annotations must be equal - PartialSymmetryAnnotation annotation = lhsAnnotation; - + if (rhsAliasesLhs) { + bool exchange_valid = true; // check that each batching dimension has same ID for LHS and RHS for (int i = 0; i < lhsBatchingDims.size(); ++i) { - if (lhsAnnotation.getSetId(lhsBatchingDims[i]) != rhsAnnotation.getSetId(rhsBatchingDims[i])) { + if (lhsAnnotation.getSetId(lhsBatchingDims[i]) != lhsAnnotation.getSetId(rhsDimToLhs[rhsBatchingDims[i]])) { exchange_valid = false; } } @@ -263,7 +259,7 @@ PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateDotGeneral( lhsContractingIds.push_back(lhsAnnotation.getSetId(dim)); } for (int64_t dim : rhsContractingDims) { - rhsContractingIds.push_back(rhsAnnotation.getSetId(dim)); + rhsContractingIds.push_back(lhsAnnotation.getSetId(rhsDimToLhs[dim])); } llvm::sort(lhsContractingIds); llvm::sort(rhsContractingIds); @@ -273,14 +269,14 @@ PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateDotGeneral( if (exchange_valid) { SmallVector lhsResultDims; - for (int64_t i = 0; i < annotation.getRank(); ++i) { + for (int64_t i = 0; i < lhsAnnotation.getRank(); ++i) { if (!llvm::is_contained(lhsBatchingDims, i) && !llvm::is_contained(lhsContractingDims, i)) { lhsResultDims.push_back(i); } } SmallVector rhsResultDims; - for (int64_t i = 0; i < annotation.getRank(); ++i) { + for (int64_t i = 0; i < rhsAnnotation.getRank(); ++i) { if (!llvm::is_contained(rhsBatchingDims, i) && !llvm::is_contained(rhsContractingDims, i)) { rhsResultDims.push_back(i); } @@ -289,7 +285,7 @@ PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateDotGeneral( // Symmetry within free dimensions of LHS for (int i = 0; i < lhsResultDims.size(); ++i) { for (int j = 0; j < i; ++j) { - if (annotation.getSetId(lhsResultDims[i]) == annotation.getSetId(lhsResultDims[j])) { + if (lhsAnnotation.getSetId(lhsResultDims[i]) == lhsAnnotation.getSetId(lhsResultDims[j])) { result.uniteDimensionSets(lhsBatchingDims.size() + i, lhsBatchingDims.size() + j); } } @@ -298,7 +294,7 @@ PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateDotGeneral( // Symmetry between free dimensions of RHS for (int i = 0; i < rhsResultDims.size(); ++i) { for (int j = 0; j < i; ++j) { - if (annotation.getSetId(rhsResultDims[i]) == annotation.getSetId(rhsResultDims[j])) { + if (rhsAnnotation.getSetId(rhsResultDims[i]) == rhsAnnotation.getSetId(rhsResultDims[j])) { result.uniteDimensionSets(lhsBatchingDims.size() + lhsResultDims.size() + i, lhsBatchingDims.size() + lhsResultDims.size() + j); } } @@ -307,8 +303,8 @@ PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateDotGeneral( // Symmetry between free dimensions of LHS and RHS for (int i = 0; i < lhsResultDims.size(); ++i) { for (int j = 0; j < rhsResultDims.size(); ++j) { - if (annotation.getSetId(lhsResultDims[i]) == annotation.getSetId(rhsResultDims[j])) { - result.uniteDimensionSets(lhsBatchingDims.size() + lhsResultDims.size() + i, lhsBatchingDims.size() + j); + if (lhsAnnotation.getSetId(lhsResultDims[i]) == lhsAnnotation.getSetId(rhsDimToLhs[rhsResultDims[j]])) { + result.uniteDimensionSets(lhsBatchingDims.size() + i, lhsBatchingDims.size() + lhsResultDims.size() + j); } } } @@ -519,22 +515,22 @@ LogicalResult PartialSymmetryAnalysis::visitOperation( auto dotDimNumbers = dotGeneralOp.getDotDimensionNumbers(); auto lhs = dotGeneralOp.getLhs(); auto rhs = dotGeneralOp.getRhs(); - bool lhsEqualsRhs = (lhs == rhs); - - // Check for transpose pattern: A x A^T or A^T x A - bool transposePatternDetected = false; - ArrayRef transposePermutation; + // Check for aliasing between LHS and RHS (up to transpose) + bool rhsAliasesLhs = false; + SmallVector rhsDimToLhs; if (auto lhsT = lhs.getDefiningOp()) { if (rhs == lhsT.getOperand()) { - transposePatternDetected = true; - transposePermutation = lhsT.getPermutation(); + rhsDimToLhs.assign(lhsT.getPermutation().begin(), lhsT.getPermutation().end()); + rhsAliasesLhs = true; } } if (auto rhsT = rhs.getDefiningOp()) { if (lhs == rhsT.getOperand()) { - transposePatternDetected = true; - transposePermutation = rhsT.getPermutation(); + rhsDimToLhs.resize(rhsT.getPermutation().size()); + for (size_t i = 0; i < rhsT.getPermutation().size(); ++i) + rhsDimToLhs[rhsT.getPermutation()[i]] = i; + rhsAliasesLhs = true; } } @@ -545,12 +541,7 @@ LogicalResult PartialSymmetryAnalysis::visitOperation( resultType.getRank(), dotDimNumbers.getLhsBatchingDimensions(), dotDimNumbers.getRhsBatchingDimensions(), dotDimNumbers.getLhsContractingDimensions(), - dotDimNumbers.getRhsContractingDimensions(), lhsEqualsRhs); - - // If transpose pattern detected, add symmetry from transpose - if (transposePatternDetected) { - // TODO - } + dotDimNumbers.getRhsContractingDimensions(), rhsAliasesLhs, rhsDimToLhs); updatedAnnotation[0] = true; } diff --git a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h index 3707ce82b..112ed70dd 100644 --- a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h +++ b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h @@ -53,7 +53,8 @@ class PartialSymmetryAnnotation { int64_t resultRank, ArrayRef lhsBatchingDims, ArrayRef rhsBatchingDims, ArrayRef lhsContractingDims, - ArrayRef rhsContractingDims, bool lhsEqualsRhs); + ArrayRef rhsContractingDims, bool rhsAliasesLhs, + ArrayRef rhsDimToLhs); static PartialSymmetryAnnotation checkConstant(DenseElementsAttr attr); From 41530c833935cf4e3abc8a07743d36aa1af40f3b Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Sat, 29 Nov 2025 18:51:41 +0000 Subject: [PATCH 05/21] Progress refactoring symmetry generation logic for transpose --- .../jax/Analysis/PartialSymmetryAnalysis.cpp | 210 +++++++++--------- .../jax/Analysis/PartialSymmetryAnalysis.h | 14 +- 2 files changed, 112 insertions(+), 112 deletions(-) diff --git a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp index ab5afad13..907ea4e11 100644 --- a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp +++ b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp @@ -70,7 +70,11 @@ void PartialSymmetryAnnotation::canonicalize() { } } -void PartialSymmetryAnnotation::uniteDimensionSets(int i, int j) { +void PartialSymmetryAnnotation::uniteDimensionSets(int64_t rank, int i, int j) { + if (isUnknown()) { + *this = createNotSymmetric(rank); + } + if (storage[i] == storage[j]) return; @@ -98,7 +102,7 @@ PartialSymmetryAnnotation::join(const PartialSymmetryAnnotation &lhs, for (int64_t j = 0; j < i; ++j) { if (lhs.getSetId(i) == lhs.getSetId(j) && rhs.getSetId(i) == rhs.getSetId(j)) { - result.uniteDimensionSets(i, j); + result.uniteDimensionSets(lhs.getRank(), i, j); } } } @@ -121,7 +125,7 @@ PartialSymmetryAnnotation::meet(const PartialSymmetryAnnotation &lhs, for (int64_t j = 0; j < i; ++j) { if (lhs.getSetId(i) == lhs.getSetId(j) || rhs.getSetId(i) == rhs.getSetId(j)) { - result.uniteDimensionSets(i, j); + result.uniteDimensionSets(lhs.getRank(), i, j); } } } @@ -182,39 +186,44 @@ PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateBroadcastInDim( } PartialSymmetryAnnotation -PartialSymmetryAnnotation::generateSymmetryFromBilinearTranspose( - const PartialSymmetryAnnotation &annotation, - ArrayRef permutation) { - int64_t rank = permutation.size(); - - // Each pair (i, j) where perm[i] = j and perm[j] = i is symmetric - PartialSymmetryAnnotation transposeSymmetry = createKnownUninitialized(rank); - SmallVector assigned(rank, false); - int nextId = 0; - - for (int64_t i = 0; i < rank; ++i) { - if (assigned[i]) - continue; - - int64_t j = permutation[i]; - if (j != i && permutation[j] == i) { - // i and j are swapped, so assign them the same ID - transposeSymmetry.storage[i] = nextId; - transposeSymmetry.storage[j] = nextId; - assigned[i] = true; - assigned[j] = true; - } else { - // dimension i is not swapped, so assign it a new ID - transposeSymmetry.storage[i] = nextId; - assigned[i] = true; +PartialSymmetryAnnotation::propagateElementwiseBinary( + const PartialSymmetryAnnotation &lhsAnnotation, + const PartialSymmetryAnnotation &rhsAnnotation, + int64_t resultRank, + bool rhsAliasesLhs, + ArrayRef rhsDimToLhs) { + + PartialSymmetryAnnotation result = join(lhsAnnotation, rhsAnnotation); + + if (rhsAliasesLhs) { + int64_t rank = result.getRank(); + + PartialSymmetryAnnotation transposeSymmetry = createKnownUninitialized(rank); + SmallVector assigned(rank, false); + int nextId = 0; + + for (int64_t i = 0; i < rank; ++i) { + if (assigned[i]) + continue; + + int64_t j = rhsDimToLhs[i]; + if (j != i && rhsDimToLhs[j] == i) { + transposeSymmetry.storage[i] = nextId; + transposeSymmetry.storage[j] = nextId; + assigned[i] = true; + assigned[j] = true; + } else { + transposeSymmetry.storage[i] = nextId; + assigned[i] = true; + } } - nextId++; + + transposeSymmetry.canonicalize(); + + result = meet(result, transposeSymmetry); } - - transposeSymmetry.canonicalize(); - - // Meet the existing annotation with the transpose symmetry - return meet(annotation, transposeSymmetry); + + return result; } PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateDotGeneral( @@ -236,7 +245,7 @@ PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateDotGeneral( lhsAnnotation.getSetId(lhsBatchingDims[j]) && rhsAnnotation.getSetId(rhsBatchingDims[i]) == rhsAnnotation.getSetId(rhsBatchingDims[j])) { - result.uniteDimensionSets(i, j); + result.uniteDimensionSets(resultRank, i, j); } } } @@ -286,7 +295,7 @@ PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateDotGeneral( for (int i = 0; i < lhsResultDims.size(); ++i) { for (int j = 0; j < i; ++j) { if (lhsAnnotation.getSetId(lhsResultDims[i]) == lhsAnnotation.getSetId(lhsResultDims[j])) { - result.uniteDimensionSets(lhsBatchingDims.size() + i, lhsBatchingDims.size() + j); + result.uniteDimensionSets(resultRank, lhsBatchingDims.size() + i, lhsBatchingDims.size() + j); } } } @@ -295,7 +304,7 @@ PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateDotGeneral( for (int i = 0; i < rhsResultDims.size(); ++i) { for (int j = 0; j < i; ++j) { if (rhsAnnotation.getSetId(rhsResultDims[i]) == rhsAnnotation.getSetId(rhsResultDims[j])) { - result.uniteDimensionSets(lhsBatchingDims.size() + lhsResultDims.size() + i, lhsBatchingDims.size() + lhsResultDims.size() + j); + result.uniteDimensionSets(resultRank, lhsBatchingDims.size() + lhsResultDims.size() + i, lhsBatchingDims.size() + lhsResultDims.size() + j); } } } @@ -304,7 +313,7 @@ PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateDotGeneral( for (int i = 0; i < lhsResultDims.size(); ++i) { for (int j = 0; j < rhsResultDims.size(); ++j) { if (lhsAnnotation.getSetId(lhsResultDims[i]) == lhsAnnotation.getSetId(rhsDimToLhs[rhsResultDims[j]])) { - result.uniteDimensionSets(lhsBatchingDims.size() + i, lhsBatchingDims.size() + lhsResultDims.size() + j); + result.uniteDimensionSets(resultRank, lhsBatchingDims.size() + i, lhsBatchingDims.size() + lhsResultDims.size() + j); } } } @@ -398,7 +407,7 @@ PartialSymmetryAnnotation::checkConstant(DenseElementsAttr attr) { continue; if (checkPairwiseSymmetry(attr, i, j)) { - result.uniteDimensionSets(i, j); + result.uniteDimensionSets(rank, i, j); } } } @@ -496,95 +505,84 @@ LogicalResult PartialSymmetryAnalysis::visitOperation( } if (auto bcastOp = dyn_cast(op)) { - if (results.size() > 0) { - if (auto resultType = - dyn_cast(op->getResult(0).getType())) { - updatedAnnotation[0] = true; - propagatedAnnotation[0] = - PartialSymmetryAnnotation::propagateBroadcastInDim( - operandAnnotations[0], resultType.getRank(), - bcastOp.getBroadcastDimensions()); - } + if (auto resultType = dyn_cast(op->getResult(0).getType())) { + updatedAnnotation[0] = true; + propagatedAnnotation[0] = + PartialSymmetryAnnotation::propagateBroadcastInDim( + operandAnnotations[0], resultType.getRank(), + bcastOp.getBroadcastDimensions()); } } if (auto dotGeneralOp = dyn_cast(op)) { - if (results.size() > 0) { - if (auto resultType = - dyn_cast(op->getResult(0).getType())) { - auto dotDimNumbers = dotGeneralOp.getDotDimensionNumbers(); - auto lhs = dotGeneralOp.getLhs(); - auto rhs = dotGeneralOp.getRhs(); - - // Check for aliasing between LHS and RHS (up to transpose) - bool rhsAliasesLhs = false; - SmallVector rhsDimToLhs; - if (auto lhsT = lhs.getDefiningOp()) { - if (rhs == lhsT.getOperand()) { - rhsDimToLhs.assign(lhsT.getPermutation().begin(), lhsT.getPermutation().end()); - rhsAliasesLhs = true; - } + if (auto resultType = dyn_cast(op->getResult(0).getType())) { + auto dotDimNumbers = dotGeneralOp.getDotDimensionNumbers(); + auto lhs = dotGeneralOp.getLhs(); + auto rhs = dotGeneralOp.getRhs(); + + // Check for aliasing between LHS and RHS (up to transpose) + bool rhsAliasesLhs = false; + SmallVector rhsDimToLhs; + if (auto lhsT = lhs.getDefiningOp()) { + if (rhs == lhsT.getOperand()) { + rhsDimToLhs.assign(lhsT.getPermutation().begin(), lhsT.getPermutation().end()); + rhsAliasesLhs = true; } - if (auto rhsT = rhs.getDefiningOp()) { - if (lhs == rhsT.getOperand()) { - rhsDimToLhs.resize(rhsT.getPermutation().size()); - for (size_t i = 0; i < rhsT.getPermutation().size(); ++i) - rhsDimToLhs[rhsT.getPermutation()[i]] = i; - rhsAliasesLhs = true; - } + } + if (auto rhsT = rhs.getDefiningOp()) { + if (lhs == rhsT.getOperand()) { + rhsDimToLhs.resize(rhsT.getPermutation().size()); + for (size_t i = 0; i < rhsT.getPermutation().size(); ++i) + rhsDimToLhs[rhsT.getPermutation()[i]] = i; + rhsAliasesLhs = true; } + } - // Propagate symmetry through dotGeneral - propagatedAnnotation[0] = - PartialSymmetryAnnotation::propagateDotGeneral( - operandAnnotations[0], operandAnnotations[1], - resultType.getRank(), dotDimNumbers.getLhsBatchingDimensions(), - dotDimNumbers.getRhsBatchingDimensions(), - dotDimNumbers.getLhsContractingDimensions(), - dotDimNumbers.getRhsContractingDimensions(), rhsAliasesLhs, rhsDimToLhs); + // Propagate symmetry through dotGeneral + propagatedAnnotation[0] = + PartialSymmetryAnnotation::propagateDotGeneral( + operandAnnotations[0], operandAnnotations[1], + resultType.getRank(), dotDimNumbers.getLhsBatchingDimensions(), + dotDimNumbers.getRhsBatchingDimensions(), + dotDimNumbers.getLhsContractingDimensions(), + dotDimNumbers.getRhsContractingDimensions(), rhsAliasesLhs, rhsDimToLhs); - updatedAnnotation[0] = true; - } + updatedAnnotation[0] = true; } } if (stablehlo::hasTraitElementwise(op)) { - if (results.size() == 1 && operands.size() > 0) { - propagatedAnnotation[0] = operandAnnotations[0]; - for (size_t i = 1; i < operands.size(); ++i) { - propagatedAnnotation[0] = PartialSymmetryAnnotation::join( - propagatedAnnotation[0], operandAnnotations[i]); - } - updatedAnnotation[0] = true; - - // Generate symmetry from commutative operation with transpose argument - if (op->hasTrait() || - op->hasTrait() && - op->getNumOperands() == 2) { + if (auto resultType = dyn_cast(op->getResult(0).getType())) { + if (operands.size() == 1) { + propagatedAnnotation[0] = operandAnnotations[0]; + updatedAnnotation[0] = true; + } else if (operands.size() == 2 && + (op->hasTrait() || + op->hasTrait())) { auto lhs = op->getOperand(0); auto rhs = op->getOperand(1); - - bool transposePatternDetected = false; - ArrayRef transposePermutation; - + + bool rhsAliasesLhs = false; + SmallVector rhsDimToLhs; + if (auto lhsT = lhs.getDefiningOp()) { if (rhs == lhsT.getOperand()) { - transposePatternDetected = true; - transposePermutation = lhsT.getPermutation(); + rhsDimToLhs.assign(lhsT.getPermutation().begin(), lhsT.getPermutation().end()); + rhsAliasesLhs = true; } } if (auto rhsT = rhs.getDefiningOp()) { if (lhs == rhsT.getOperand()) { - transposePatternDetected = true; - transposePermutation = rhsT.getPermutation(); + rhsDimToLhs.resize(rhsT.getPermutation().size()); + for (size_t i = 0; i < rhsT.getPermutation().size(); ++i) + rhsDimToLhs[rhsT.getPermutation()[i]] = i; + rhsAliasesLhs = true; } } - - if (transposePatternDetected) { - propagatedAnnotation[0] = - PartialSymmetryAnnotation::generateSymmetryFromBilinearTranspose( - propagatedAnnotation[0], transposePermutation); - } + + // propagatedAnnotation[0] = PartialSymmetryAnnotation::propagateElementwiseBinary( + // operandAnnotations[0], operandAnnotations[1], resultType.getRank(), rhsAliasesLhs, rhsDimToLhs); + // updatedAnnotation[0] = true; } } } diff --git a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h index 112ed70dd..000521292 100644 --- a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h +++ b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h @@ -53,14 +53,16 @@ class PartialSymmetryAnnotation { int64_t resultRank, ArrayRef lhsBatchingDims, ArrayRef rhsBatchingDims, ArrayRef lhsContractingDims, - ArrayRef rhsContractingDims, bool rhsAliasesLhs, - ArrayRef rhsDimToLhs); + ArrayRef rhsContractingDims, + bool rhsAliasesLhs, ArrayRef rhsDimToLhs); static PartialSymmetryAnnotation checkConstant(DenseElementsAttr attr); - static PartialSymmetryAnnotation generateSymmetryFromBilinearTranspose( - const PartialSymmetryAnnotation &annotation, - ArrayRef permutation); + static PartialSymmetryAnnotation + propagateElementwiseBinary(const PartialSymmetryAnnotation &lhsAnnotation, + const PartialSymmetryAnnotation &rhsAnnotation, + int64_t resultRank, bool rhsAliasesLhs, + ArrayRef rhsDimToLhs); bool operator==(const PartialSymmetryAnnotation &other) const { return (!known && !other.known) || storage == other.storage; @@ -75,7 +77,7 @@ class PartialSymmetryAnnotation { SmallVector storage; void canonicalize(); - void uniteDimensionSets(int i, int j); + void uniteDimensionSets(int64_t rank, int i, int j); }; class PartialSymmetryLattice : public dataflow::AbstractSparseLattice { From 37ccbde5cdcae6afd04765c29429c8135d14efed Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Sat, 29 Nov 2025 19:02:26 +0000 Subject: [PATCH 06/21] Fix issue with rank computation --- .../jax/Analysis/PartialSymmetryAnalysis.cpp | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp index 907ea4e11..c7d13e39e 100644 --- a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp +++ b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp @@ -196,7 +196,7 @@ PartialSymmetryAnnotation::propagateElementwiseBinary( PartialSymmetryAnnotation result = join(lhsAnnotation, rhsAnnotation); if (rhsAliasesLhs) { - int64_t rank = result.getRank(); + int64_t rank = resultRank; PartialSymmetryAnnotation transposeSymmetry = createKnownUninitialized(rank); SmallVector assigned(rank, false); @@ -216,6 +216,8 @@ PartialSymmetryAnnotation::propagateElementwiseBinary( transposeSymmetry.storage[i] = nextId; assigned[i] = true; } + + nextId++; } transposeSymmetry.canonicalize(); @@ -580,9 +582,11 @@ LogicalResult PartialSymmetryAnalysis::visitOperation( } } - // propagatedAnnotation[0] = PartialSymmetryAnnotation::propagateElementwiseBinary( - // operandAnnotations[0], operandAnnotations[1], resultType.getRank(), rhsAliasesLhs, rhsDimToLhs); - // updatedAnnotation[0] = true; + llvm::errs() << "handling elementwise op" << "\n"; + + propagatedAnnotation[0] = PartialSymmetryAnnotation::propagateElementwiseBinary( + operandAnnotations[0], operandAnnotations[1], resultType.getRank(), rhsAliasesLhs, rhsDimToLhs); + updatedAnnotation[0] = true; } } } From f6d911a5316c8d8372a464d1dacaf5810652815a Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Sat, 29 Nov 2025 19:03:36 +0000 Subject: [PATCH 07/21] Simplify elementwise propagation logic --- .../jax/Analysis/PartialSymmetryAnalysis.cpp | 23 ++++--------------- 1 file changed, 4 insertions(+), 19 deletions(-) diff --git a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp index c7d13e39e..503ea8996 100644 --- a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp +++ b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp @@ -199,30 +199,15 @@ PartialSymmetryAnnotation::propagateElementwiseBinary( int64_t rank = resultRank; PartialSymmetryAnnotation transposeSymmetry = createKnownUninitialized(rank); - SmallVector assigned(rank, false); - int nextId = 0; for (int64_t i = 0; i < rank; ++i) { - if (assigned[i]) - continue; - int64_t j = rhsDimToLhs[i]; - if (j != i && rhsDimToLhs[j] == i) { - transposeSymmetry.storage[i] = nextId; - transposeSymmetry.storage[j] = nextId; - assigned[i] = true; - assigned[j] = true; - } else { - transposeSymmetry.storage[i] = nextId; - assigned[i] = true; - } - - nextId++; + if (rhsDimToLhs[j] == i) { + result.uniteDimensionSets(rank, i, j); + } } - transposeSymmetry.canonicalize(); - - result = meet(result, transposeSymmetry); + result.canonicalize(); } return result; From d8336d1f28a98348df06a2fb8a3248711abb675e Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Sat, 29 Nov 2025 19:26:25 +0000 Subject: [PATCH 08/21] Some code cleanup --- .../jax/Analysis/PartialSymmetryAnalysis.cpp | 188 ++++++++---------- .../jax/Analysis/PartialSymmetryAnalysis.h | 14 +- src/enzyme_ad/jax/Utils.h | 3 + 3 files changed, 93 insertions(+), 112 deletions(-) diff --git a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp index 503ea8996..6a19a563a 100644 --- a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp +++ b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp @@ -18,9 +18,9 @@ namespace enzyme { // PartialSymmetryAnnotation Implementation //===----------------------------------------------------------------------===// -PartialSymmetryAnnotation::PartialSymmetryAnnotation(ArrayRef s) +PartialSymmetryAnnotation::PartialSymmetryAnnotation(ArrayRef dimensionSetIDs) : known(true) { - storage.assign(s.begin(), s.end()); + this->dimensionSetIDs.assign(dimensionSetIDs.begin(), dimensionSetIDs.end()); canonicalize(); } @@ -29,7 +29,7 @@ PartialSymmetryAnnotation::createFullySymmetric(int64_t rank) { PartialSymmetryAnnotation annotation; annotation.known = true; for (int64_t i = 0; i < rank; ++i) { - annotation.storage.push_back(0); + annotation.dimensionSetIDs.push_back(0); } return annotation; } @@ -39,7 +39,7 @@ PartialSymmetryAnnotation::createNotSymmetric(int64_t rank) { PartialSymmetryAnnotation annotation; annotation.known = true; for (int64_t i = 0; i < rank; ++i) { - annotation.storage.push_back(i); + annotation.dimensionSetIDs.push_back(i); } return annotation; } @@ -48,21 +48,18 @@ PartialSymmetryAnnotation PartialSymmetryAnnotation::createKnownUninitialized(int64_t rank) { PartialSymmetryAnnotation annotation; annotation.known = true; - annotation.storage.resize(rank); + annotation.dimensionSetIDs.resize(rank); return annotation; } bool PartialSymmetryAnnotation::isSymmetric(int64_t i, int64_t j) const { - if (i < 0 || i >= (int64_t)storage.size() || j < 0 || - j >= (int64_t)storage.size()) - return false; - return storage[i] == storage[j]; + return dimensionSetIDs[i] == dimensionSetIDs[j]; } void PartialSymmetryAnnotation::canonicalize() { - llvm::SmallDenseMap map; - int nextId = 0; - for (auto &id : storage) { + llvm::SmallDenseMap map; + int64_t nextId = 0; + for (auto &id : dimensionSetIDs) { if (map.find(id) == map.end()) { map[id] = nextId++; } @@ -70,19 +67,19 @@ void PartialSymmetryAnnotation::canonicalize() { } } -void PartialSymmetryAnnotation::uniteDimensionSets(int64_t rank, int i, int j) { +void PartialSymmetryAnnotation::uniteDimensionSets(int64_t rank, int64_t i, int64_t j) { if (isUnknown()) { *this = createNotSymmetric(rank); } - if (storage[i] == storage[j]) + if (dimensionSetIDs[i] == dimensionSetIDs[j]) return; - int oldId = storage[i]; - int newId = storage[j]; - for (size_t k = 0; k < storage.size(); ++k) { - if (storage[k] == oldId) { - storage[k] = newId; + int64_t oldId = dimensionSetIDs[i]; + int64_t newId = dimensionSetIDs[j]; + for (int64_t k = 0; k < (int64_t)dimensionSetIDs.size(); ++k) { + if (dimensionSetIDs[k] == oldId) { + dimensionSetIDs[k] = newId; } } @@ -143,7 +140,7 @@ PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateTranspose( PartialSymmetryAnnotation result = createKnownUninitialized(annotation.getRank()); for (int64_t i = 0; i < annotation.getRank(); ++i) { - result.storage[i] = annotation.getSetId(permutation[i]); + result.dimensionSetIDs[i] = annotation.getSetId(permutation[i]); } result.canonicalize(); @@ -160,24 +157,24 @@ PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateBroadcastInDim( PartialSymmetryAnnotation result = createKnownUninitialized(outputRank); llvm::SmallDenseMap outputToInput; - for (size_t i = 0; i < broadcastDimensions.size(); ++i) { + for (int64_t i = 0; i < (int64_t)broadcastDimensions.size(); ++i) { outputToInput[broadcastDimensions[i]] = i; } - int maxSetId = -1; + int64_t maxSetId = -1; for (int64_t i = 0; i < annotation.getRank(); ++i) { - maxSetId = std::max(maxSetId, annotation.getSetId(i)); + maxSetId = std::max(maxSetId, (int64_t)annotation.getSetId(i)); } - int nextNewSetId = maxSetId + 1; + int64_t nextSetId = maxSetId + 1; for (int64_t outputDim = 0; outputDim < outputRank; ++outputDim) { if (outputToInput.find(outputDim) != outputToInput.end()) { // dimension is preserved => use old ID int64_t inputDim = outputToInput[outputDim]; - result.storage[outputDim] = annotation.getSetId(inputDim); + result.dimensionSetIDs[outputDim] = annotation.getSetId(inputDim); } else { // broadcasted dimension => new ID - result.storage[outputDim] = nextNewSetId++; + result.dimensionSetIDs[outputDim] = nextSetId++; } } @@ -226,8 +223,8 @@ PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateDotGeneral( PartialSymmetryAnnotation result = createNotSymmetric(resultRank); // Preserve symmetry in batching dimensions - for (int i = 0; i < lhsBatchingDims.size(); ++i) { - for (int j = 0; j < i; ++j) { + for (int64_t i = 0; i < (int64_t)lhsBatchingDims.size(); ++i) { + for (int64_t j = 0; j < i; ++j) { if (lhsAnnotation.getSetId(lhsBatchingDims[i]) == lhsAnnotation.getSetId(lhsBatchingDims[j]) && rhsAnnotation.getSetId(rhsBatchingDims[i]) == @@ -239,18 +236,17 @@ PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateDotGeneral( // Preserve symmetry in free (non-contracting, non-batching) dimensions if (rhsAliasesLhs) { - bool exchange_valid = true; // check that each batching dimension has same ID for LHS and RHS - for (int i = 0; i < lhsBatchingDims.size(); ++i) { + for (int64_t i = 0; i < (int64_t)lhsBatchingDims.size(); ++i) { if (lhsAnnotation.getSetId(lhsBatchingDims[i]) != lhsAnnotation.getSetId(rhsDimToLhs[rhsBatchingDims[i]])) { exchange_valid = false; } } // check that the multiset of IDs for contracting dimensions are equal for LHS and RHS - SmallVector lhsContractingIds, rhsContractingIds; + SmallVector lhsContractingIds, rhsContractingIds; for (int64_t dim : lhsContractingDims) { lhsContractingIds.push_back(lhsAnnotation.getSetId(dim)); } @@ -279,8 +275,8 @@ PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateDotGeneral( } // Symmetry within free dimensions of LHS - for (int i = 0; i < lhsResultDims.size(); ++i) { - for (int j = 0; j < i; ++j) { + for (int64_t i = 0; i < (int64_t)lhsResultDims.size(); ++i) { + for (int64_t j = 0; j < i; ++j) { if (lhsAnnotation.getSetId(lhsResultDims[i]) == lhsAnnotation.getSetId(lhsResultDims[j])) { result.uniteDimensionSets(resultRank, lhsBatchingDims.size() + i, lhsBatchingDims.size() + j); } @@ -288,8 +284,8 @@ PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateDotGeneral( } // Symmetry between free dimensions of RHS - for (int i = 0; i < rhsResultDims.size(); ++i) { - for (int j = 0; j < i; ++j) { + for (int64_t i = 0; i < (int64_t)rhsResultDims.size(); ++i) { + for (int64_t j = 0; j < i; ++j) { if (rhsAnnotation.getSetId(rhsResultDims[i]) == rhsAnnotation.getSetId(rhsResultDims[j])) { result.uniteDimensionSets(resultRank, lhsBatchingDims.size() + lhsResultDims.size() + i, lhsBatchingDims.size() + lhsResultDims.size() + j); } @@ -297,8 +293,8 @@ PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateDotGeneral( } // Symmetry between free dimensions of LHS and RHS - for (int i = 0; i < lhsResultDims.size(); ++i) { - for (int j = 0; j < rhsResultDims.size(); ++j) { + for (int64_t i = 0; i < (int64_t)lhsResultDims.size(); ++i) { + for (int64_t j = 0; j < (int64_t)rhsResultDims.size(); ++j) { if (lhsAnnotation.getSetId(lhsResultDims[i]) == lhsAnnotation.getSetId(rhsDimToLhs[rhsResultDims[j]])) { result.uniteDimensionSets(resultRank, lhsBatchingDims.size() + i, lhsBatchingDims.size() + lhsResultDims.size() + j); } @@ -311,6 +307,7 @@ PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateDotGeneral( return result; } +template static bool checkPairwiseSymmetry(DenseElementsAttr attr, int64_t dimA, int64_t dimB) { auto type = cast(attr.getType()); @@ -320,66 +317,44 @@ static bool checkPairwiseSymmetry(DenseElementsAttr attr, int64_t dimA, if (shape[dimA] != shape[dimB]) return false; - int64_t numElements = type.getNumElements(); + if (attr.isSplat()) + return true; - if (auto intAttr = dyn_cast(attr)) { - auto values = intAttr.getValues(); - SmallVector strides(rank); - int64_t currentStride = 1; - for (int i = rank - 1; i >= 0; --i) { - strides[i] = currentStride; - currentStride *= shape[i]; - } + auto values = attr.getValues(); + auto it = values.begin(); - for (int64_t i = 0; i < numElements; ++i) { - SmallVector coords(rank); - int64_t temp = i; - for (int d = 0; d < rank; ++d) { - coords[d] = temp / strides[d]; - temp %= strides[d]; - } - - std::swap(coords[dimA], coords[dimB]); + SmallVector strides(rank); + int64_t currentStride = 1; + for (int64_t i = rank - 1; i >= 0; --i) { + strides[i] = currentStride; + currentStride *= shape[i]; + } - int64_t swappedIdx = 0; - for (int d = 0; d < rank; ++d) { - swappedIdx += coords[d] * strides[d]; - } + int64_t numElements = 1; + for (int64_t s : shape) + numElements *= s; - if (values[i] != values[swappedIdx]) - return false; - } - return true; - } else if (auto floatAttr = dyn_cast(attr)) { - auto values = floatAttr.getValues(); - SmallVector strides(rank); - int64_t currentStride = 1; - for (int i = rank - 1; i >= 0; --i) { - strides[i] = currentStride; - currentStride *= shape[i]; + for (int64_t i = 0; i < numElements; ++i) { + SmallVector coords(rank); + int64_t temp = i; + for (int64_t d = 0; d < rank; ++d) { + coords[d] = temp / strides[d]; + temp %= strides[d]; } - for (int64_t i = 0; i < numElements; ++i) { - SmallVector coords(rank); - int64_t temp = i; - for (int d = 0; d < rank; ++d) { - coords[d] = temp / strides[d]; - temp %= strides[d]; - } - - std::swap(coords[dimA], coords[dimB]); - - int64_t swappedIdx = 0; - for (int d = 0; d < rank; ++d) { - swappedIdx += coords[d] * strides[d]; - } + std::swap(coords[dimA], coords[dimB]); - if (values[i].compare(values[swappedIdx]) != APFloat::cmpEqual) - return false; + int64_t swappedIdx = 0; + for (int64_t d = 0; d < rank; ++d) { + swappedIdx += coords[d] * strides[d]; } - return true; + + auto a = *(it + i); + auto b = *(it + swappedIdx); + if (checkNotEqual(a, b)) + return false; } - return false; + return true; } PartialSymmetryAnnotation @@ -388,13 +363,18 @@ PartialSymmetryAnnotation::checkConstant(DenseElementsAttr attr) { int64_t rank = type.getRank(); PartialSymmetryAnnotation result = createNotSymmetric(rank); - for (int i = 0; i < rank; ++i) { - for (int j = i + 1; j < rank; ++j) { - if (result.getSetId(i) == result.getSetId(j)) - continue; + for (int64_t i = 0; i < rank; ++i) { + for (int64_t j = 0; j < i; ++j) { + bool isSymmetric = false; + if (isa(attr.getElementType())) { + isSymmetric = checkPairwiseSymmetry(attr, i, j); + } else if (isa(attr.getElementType())) { + isSymmetric = checkPairwiseSymmetry(attr, i, j); + } - if (checkPairwiseSymmetry(attr, i, j)) { + if (isSymmetric) { result.uniteDimensionSets(rank, i, j); + continue; } } } @@ -405,19 +385,19 @@ PartialSymmetryAnnotation::checkConstant(DenseElementsAttr attr) { SmallVector> PartialSymmetryAnnotation::getDimensionSets() const { - llvm::SmallDenseMap> sets; - for (int64_t i = 0; i < (int64_t)storage.size(); ++i) { - sets[storage[i]].push_back(i); + llvm::SmallDenseMap> sets; + for (int64_t i = 0; i < (int64_t)dimensionSetIDs.size(); ++i) { + sets[dimensionSetIDs[i]].push_back(i); } - SmallVector sortedKeys; + SmallVector sortedKeys; for (auto &kv : sets) sortedKeys.push_back(kv.first); std::sort(sortedKeys.begin(), sortedKeys.end(), - [&](int a, int b) { return sets[a][0] < sets[b][0]; }); + [&](int64_t a, int64_t b) { return sets[a][0] < sets[b][0]; }); SmallVector> result; - for (int key : sortedKeys) { + for (int64_t key : sortedKeys) { result.push_back(sets[key]); } return result; @@ -481,7 +461,7 @@ LogicalResult PartialSymmetryAnalysis::visitOperation( SmallVector propagatedAnnotation(results.size()); SmallVector operandAnnotations(operands.size()); - for (size_t i = 0; i < operands.size(); i++) { + for (int64_t i = 0; i < (int64_t)operands.size(); i++) { operandAnnotations[i] = operands[i]->getValue(); } @@ -519,7 +499,7 @@ LogicalResult PartialSymmetryAnalysis::visitOperation( if (auto rhsT = rhs.getDefiningOp()) { if (lhs == rhsT.getOperand()) { rhsDimToLhs.resize(rhsT.getPermutation().size()); - for (size_t i = 0; i < rhsT.getPermutation().size(); ++i) + for (int64_t i = 0; i < (int64_t)rhsT.getPermutation().size(); ++i) rhsDimToLhs[rhsT.getPermutation()[i]] = i; rhsAliasesLhs = true; } @@ -561,14 +541,12 @@ LogicalResult PartialSymmetryAnalysis::visitOperation( if (auto rhsT = rhs.getDefiningOp()) { if (lhs == rhsT.getOperand()) { rhsDimToLhs.resize(rhsT.getPermutation().size()); - for (size_t i = 0; i < rhsT.getPermutation().size(); ++i) + for (int64_t i = 0; i < (int64_t)rhsT.getPermutation().size(); ++i) rhsDimToLhs[rhsT.getPermutation()[i]] = i; rhsAliasesLhs = true; } } - llvm::errs() << "handling elementwise op" << "\n"; - propagatedAnnotation[0] = PartialSymmetryAnnotation::propagateElementwiseBinary( operandAnnotations[0], operandAnnotations[1], resultType.getRank(), rhsAliasesLhs, rhsDimToLhs); updatedAnnotation[0] = true; @@ -583,7 +561,7 @@ LogicalResult PartialSymmetryAnalysis::visitOperation( PartialSymmetryAnnotation::checkConstant(denseAttr); } - for (size_t i = 0; i < results.size(); i++) { + for (int64_t i = 0; i < (int64_t)results.size(); i++) { if (updatedAnnotation[i]) { auto resultOrig = results[i]->getValue(); auto resultNew = diff --git a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h index 000521292..34cff4b7e 100644 --- a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h +++ b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h @@ -17,9 +17,9 @@ namespace enzyme { // Represents the partial symmetry of a tensor as a partition of its dimensions. class PartialSymmetryAnnotation { public: - PartialSymmetryAnnotation() : known(false), storage() {} + PartialSymmetryAnnotation() : known(false), dimensionSetIDs() {} - explicit PartialSymmetryAnnotation(ArrayRef storage); + explicit PartialSymmetryAnnotation(ArrayRef dimensionSetIDs); static PartialSymmetryAnnotation createKnownUninitialized(int64_t rank); static PartialSymmetryAnnotation createNotSymmetric(int64_t rank); @@ -27,9 +27,9 @@ class PartialSymmetryAnnotation { bool isSymmetric(int64_t i, int64_t j) const; - int getSetId(int64_t i) const { return storage[i]; } + int64_t getSetId(int64_t i) const { return dimensionSetIDs[i]; } - int64_t getRank() const { return storage.size(); } + int64_t getRank() const { return dimensionSetIDs.size(); } bool isUnknown() const { return !known; } @@ -65,7 +65,7 @@ class PartialSymmetryAnnotation { ArrayRef rhsDimToLhs); bool operator==(const PartialSymmetryAnnotation &other) const { - return (!known && !other.known) || storage == other.storage; + return (!known && !other.known) || dimensionSetIDs == other.dimensionSetIDs; } SmallVector> getDimensionSets() const; @@ -74,10 +74,10 @@ class PartialSymmetryAnnotation { private: bool known; - SmallVector storage; + SmallVector dimensionSetIDs; void canonicalize(); - void uniteDimensionSets(int64_t rank, int i, int j); + void uniteDimensionSets(int64_t rank, int64_t i, int64_t j); }; class PartialSymmetryLattice : public dataflow::AbstractSparseLattice { diff --git a/src/enzyme_ad/jax/Utils.h b/src/enzyme_ad/jax/Utils.h index ccbe670a7..8791349b0 100644 --- a/src/enzyme_ad/jax/Utils.h +++ b/src/enzyme_ad/jax/Utils.h @@ -307,6 +307,9 @@ bool mayAlias(mlir::MemoryEffects::EffectInstance a, bool mayAlias(mlir::MemoryEffects::EffectInstance a, mlir::Value b); +bool checkNotEqual(llvm::APInt a, llvm::APInt b); +bool checkNotEqual(llvm::APFloat a, llvm::APFloat b); + bool canApplyNoNanPattern(bool allowOnFloatingPointMath, Type Ty); bool canApplyNoNanPattern(bool allowOnFloatingPointMath, Type Ty, mlir::Operation *op, PatternRewriter &rewriter); From a1721182dce39e5e6d9ab0f231f15d2dd98adf45 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Sat, 29 Nov 2025 19:59:44 +0000 Subject: [PATCH 09/21] Add test of dot_general symm gen + remove seemingly unnecessary "Unknown" lattice element --- .../jax/Analysis/PartialSymmetryAnalysis.cpp | 50 ++++++------------- .../jax/Analysis/PartialSymmetryAnalysis.h | 11 ++-- .../structured_tensors/partial_symmetry.mlir | 11 ++++ 3 files changed, 29 insertions(+), 43 deletions(-) diff --git a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp index 6a19a563a..1f01b1cca 100644 --- a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp +++ b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp @@ -18,8 +18,7 @@ namespace enzyme { // PartialSymmetryAnnotation Implementation //===----------------------------------------------------------------------===// -PartialSymmetryAnnotation::PartialSymmetryAnnotation(ArrayRef dimensionSetIDs) - : known(true) { +PartialSymmetryAnnotation::PartialSymmetryAnnotation(ArrayRef dimensionSetIDs) { this->dimensionSetIDs.assign(dimensionSetIDs.begin(), dimensionSetIDs.end()); canonicalize(); } @@ -27,7 +26,6 @@ PartialSymmetryAnnotation::PartialSymmetryAnnotation(ArrayRef dimension PartialSymmetryAnnotation PartialSymmetryAnnotation::createFullySymmetric(int64_t rank) { PartialSymmetryAnnotation annotation; - annotation.known = true; for (int64_t i = 0; i < rank; ++i) { annotation.dimensionSetIDs.push_back(0); } @@ -37,7 +35,6 @@ PartialSymmetryAnnotation::createFullySymmetric(int64_t rank) { PartialSymmetryAnnotation PartialSymmetryAnnotation::createNotSymmetric(int64_t rank) { PartialSymmetryAnnotation annotation; - annotation.known = true; for (int64_t i = 0; i < rank; ++i) { annotation.dimensionSetIDs.push_back(i); } @@ -45,9 +42,8 @@ PartialSymmetryAnnotation::createNotSymmetric(int64_t rank) { } PartialSymmetryAnnotation -PartialSymmetryAnnotation::createKnownUninitialized(int64_t rank) { +PartialSymmetryAnnotation::createUninitialized(int64_t rank) { PartialSymmetryAnnotation annotation; - annotation.known = true; annotation.dimensionSetIDs.resize(rank); return annotation; } @@ -68,10 +64,6 @@ void PartialSymmetryAnnotation::canonicalize() { } void PartialSymmetryAnnotation::uniteDimensionSets(int64_t rank, int64_t i, int64_t j) { - if (isUnknown()) { - *this = createNotSymmetric(rank); - } - if (dimensionSetIDs[i] == dimensionSetIDs[j]) return; @@ -89,9 +81,6 @@ void PartialSymmetryAnnotation::uniteDimensionSets(int64_t rank, int64_t i, int6 PartialSymmetryAnnotation PartialSymmetryAnnotation::join(const PartialSymmetryAnnotation &lhs, const PartialSymmetryAnnotation &rhs) { - if (lhs.isUnknown() || rhs.isUnknown()) - return PartialSymmetryAnnotation(); - PartialSymmetryAnnotation result = createNotSymmetric(lhs.getRank()); for (int64_t i = 0; i < lhs.getRank(); ++i) { @@ -111,11 +100,6 @@ PartialSymmetryAnnotation::join(const PartialSymmetryAnnotation &lhs, PartialSymmetryAnnotation PartialSymmetryAnnotation::meet(const PartialSymmetryAnnotation &lhs, const PartialSymmetryAnnotation &rhs) { - if (lhs.isUnknown()) - return rhs; - if (rhs.isUnknown()) - return lhs; - PartialSymmetryAnnotation result = createNotSymmetric(lhs.getRank()); for (int64_t i = 0; i < lhs.getRank(); ++i) { @@ -134,10 +118,8 @@ PartialSymmetryAnnotation::meet(const PartialSymmetryAnnotation &lhs, PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateTranspose( const PartialSymmetryAnnotation &annotation, ArrayRef permutation) { - if (annotation.isUnknown()) - return PartialSymmetryAnnotation(); - PartialSymmetryAnnotation result = createKnownUninitialized(annotation.getRank()); + PartialSymmetryAnnotation result = createUninitialized(annotation.getRank()); for (int64_t i = 0; i < annotation.getRank(); ++i) { result.dimensionSetIDs[i] = annotation.getSetId(permutation[i]); @@ -151,10 +133,7 @@ PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateBroadcastInDim( const PartialSymmetryAnnotation &annotation, int64_t outputRank, ArrayRef broadcastDimensions) { - if (annotation.isUnknown()) - return PartialSymmetryAnnotation(); - - PartialSymmetryAnnotation result = createKnownUninitialized(outputRank); + PartialSymmetryAnnotation result = createUninitialized(outputRank); llvm::SmallDenseMap outputToInput; for (int64_t i = 0; i < (int64_t)broadcastDimensions.size(); ++i) { @@ -193,14 +172,10 @@ PartialSymmetryAnnotation::propagateElementwiseBinary( PartialSymmetryAnnotation result = join(lhsAnnotation, rhsAnnotation); if (rhsAliasesLhs) { - int64_t rank = resultRank; - - PartialSymmetryAnnotation transposeSymmetry = createKnownUninitialized(rank); - - for (int64_t i = 0; i < rank; ++i) { + for (int64_t i = 0; i < resultRank; ++i) { int64_t j = rhsDimToLhs[i]; if (rhsDimToLhs[j] == i) { - result.uniteDimensionSets(rank, i, j); + result.uniteDimensionSets(resultRank, i, j); } } @@ -216,10 +191,7 @@ PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateDotGeneral( ArrayRef lhsBatchingDims, ArrayRef rhsBatchingDims, ArrayRef lhsContractingDims, ArrayRef rhsContractingDims, bool rhsAliasesLhs, ArrayRef rhsDimToLhs) { - - if (lhsAnnotation.isUnknown() || rhsAnnotation.isUnknown()) - return PartialSymmetryAnnotation(); - + PartialSymmetryAnnotation result = createNotSymmetric(resultRank); // Preserve symmetry in batching dimensions @@ -244,6 +216,8 @@ PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateDotGeneral( exchange_valid = false; } } + + llvm::errs() << "still ok 1\n"; // check that the multiset of IDs for contracting dimensions are equal for LHS and RHS SmallVector lhsContractingIds, rhsContractingIds; @@ -258,6 +232,8 @@ PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateDotGeneral( if (lhsContractingIds != rhsContractingIds) { exchange_valid = false; } + + llvm::errs() << "still ok 2\n"; if (exchange_valid) { SmallVector lhsResultDims; @@ -450,7 +426,7 @@ void PartialSymmetryLattice::print(raw_ostream &os) const { value.print(os); } //===----------------------------------------------------------------------===// void PartialSymmetryAnalysis::setToEntryState(PartialSymmetryLattice *lattice) { - lattice->setValue(PartialSymmetryAnnotation()); + lattice->setValue(PartialSymmetryAnnotation::createNotSymmetric(lattice->getValue().getRank())); } LogicalResult PartialSymmetryAnalysis::visitOperation( @@ -504,6 +480,8 @@ LogicalResult PartialSymmetryAnalysis::visitOperation( rhsAliasesLhs = true; } } + + llvm::errs() << "dotGeneral rhsAliasesLhs: " << rhsAliasesLhs << "\n"; // Propagate symmetry through dotGeneral propagatedAnnotation[0] = diff --git a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h index 34cff4b7e..b7bafb99c 100644 --- a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h +++ b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h @@ -17,11 +17,11 @@ namespace enzyme { // Represents the partial symmetry of a tensor as a partition of its dimensions. class PartialSymmetryAnnotation { public: - PartialSymmetryAnnotation() : known(false), dimensionSetIDs() {} + PartialSymmetryAnnotation() : dimensionSetIDs() {} explicit PartialSymmetryAnnotation(ArrayRef dimensionSetIDs); - static PartialSymmetryAnnotation createKnownUninitialized(int64_t rank); + static PartialSymmetryAnnotation createUninitialized(int64_t rank); static PartialSymmetryAnnotation createNotSymmetric(int64_t rank); static PartialSymmetryAnnotation createFullySymmetric(int64_t rank); @@ -31,8 +31,7 @@ class PartialSymmetryAnnotation { int64_t getRank() const { return dimensionSetIDs.size(); } - bool isUnknown() const { return !known; } - + static PartialSymmetryAnnotation join(const PartialSymmetryAnnotation &lhs, const PartialSymmetryAnnotation &rhs); static PartialSymmetryAnnotation meet(const PartialSymmetryAnnotation &lhs, @@ -65,7 +64,7 @@ class PartialSymmetryAnnotation { ArrayRef rhsDimToLhs); bool operator==(const PartialSymmetryAnnotation &other) const { - return (!known && !other.known) || dimensionSetIDs == other.dimensionSetIDs; + return dimensionSetIDs == other.dimensionSetIDs; } SmallVector> getDimensionSets() const; @@ -73,7 +72,6 @@ class PartialSymmetryAnnotation { void print(raw_ostream &os) const; private: - bool known; SmallVector dimensionSetIDs; void canonicalize(); @@ -99,7 +97,6 @@ class PartialSymmetryLattice : public dataflow::AbstractSparseLattice { void setValue(const PartialSymmetryAnnotation &v) { value = v; } private: - bool isUnknown; PartialSymmetryAnnotation value; }; diff --git a/test/lit_tests/structured_tensors/partial_symmetry.mlir b/test/lit_tests/structured_tensors/partial_symmetry.mlir index 01c5a94a2..5206523a6 100644 --- a/test/lit_tests/structured_tensors/partial_symmetry.mlir +++ b/test/lit_tests/structured_tensors/partial_symmetry.mlir @@ -50,3 +50,14 @@ func.func @test4() -> tensor<2x2xf32> { // CHECK-NEXT: return %0 : tensor<2x2xf32> // CHECK-NEXT: } +func.func @test5(%arg0: tensor<3x3x3xf32>) -> tensor<3x3x3xf32> { + %0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<3x3x3xf32>) -> tensor<3x3x3xf32> + %1 = stablehlo.dot_general %arg0, %0, batching_dims = [1] x [1], contracting_dims = [0] x [2] : (tensor<3x3x3xf32>, tensor<3x3x3xf32>) -> tensor<3x3x3xf32> + return %1 : tensor<3x3x3xf32> +} +// CHECK: func.func @test5(%arg0: tensor<3x3x3xf32>) -> tensor<3x3x3xf32> { +// CHECK-NEXT: %0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<3x3x3xf32>) -> tensor<3x3x3xf32> +// CHECK-NEXT: %1 = stablehlo.dot_general %arg0, %0, batching_dims = [1] x [1], contracting_dims = [0] x [2] {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[1, 2]>>]} : (tensor<3x3x3xf32>, tensor<3x3x3xf32>) -> tensor<3x3x3xf32> +// CHECK-NEXT: return %1 : tensor<3x3x3xf32> +// CHECK-NEXT: } + From 058234448ab2d1170dc697f642389798db1ef651 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Sat, 29 Nov 2025 20:03:14 +0000 Subject: [PATCH 10/21] Format --- .../jax/Analysis/PartialSymmetryAnalysis.cpp | 125 ++++++++++-------- .../jax/Analysis/PartialSymmetryAnalysis.h | 7 +- 2 files changed, 75 insertions(+), 57 deletions(-) diff --git a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp index 1f01b1cca..0967eba5a 100644 --- a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp +++ b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp @@ -18,7 +18,8 @@ namespace enzyme { // PartialSymmetryAnnotation Implementation //===----------------------------------------------------------------------===// -PartialSymmetryAnnotation::PartialSymmetryAnnotation(ArrayRef dimensionSetIDs) { +PartialSymmetryAnnotation::PartialSymmetryAnnotation( + ArrayRef dimensionSetIDs) { this->dimensionSetIDs.assign(dimensionSetIDs.begin(), dimensionSetIDs.end()); canonicalize(); } @@ -63,10 +64,11 @@ void PartialSymmetryAnnotation::canonicalize() { } } -void PartialSymmetryAnnotation::uniteDimensionSets(int64_t rank, int64_t i, int64_t j) { +void PartialSymmetryAnnotation::uniteDimensionSets(int64_t rank, int64_t i, + int64_t j) { if (dimensionSetIDs[i] == dimensionSetIDs[j]) return; - + int64_t oldId = dimensionSetIDs[i]; int64_t newId = dimensionSetIDs[j]; for (int64_t k = 0; k < (int64_t)dimensionSetIDs.size(); ++k) { @@ -74,7 +76,7 @@ void PartialSymmetryAnnotation::uniteDimensionSets(int64_t rank, int64_t i, int6 dimensionSetIDs[k] = newId; } } - + canonicalize(); } @@ -161,27 +163,24 @@ PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateBroadcastInDim( return result; } -PartialSymmetryAnnotation -PartialSymmetryAnnotation::propagateElementwiseBinary( +PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateElementwiseBinary( const PartialSymmetryAnnotation &lhsAnnotation, - const PartialSymmetryAnnotation &rhsAnnotation, - int64_t resultRank, - bool rhsAliasesLhs, - ArrayRef rhsDimToLhs) { - + const PartialSymmetryAnnotation &rhsAnnotation, int64_t resultRank, + bool rhsAliasesLhs, ArrayRef rhsDimToLhs) { + PartialSymmetryAnnotation result = join(lhsAnnotation, rhsAnnotation); - + if (rhsAliasesLhs) { for (int64_t i = 0; i < resultRank; ++i) { int64_t j = rhsDimToLhs[i]; if (rhsDimToLhs[j] == i) { result.uniteDimensionSets(resultRank, i, j); - } + } } - + result.canonicalize(); } - + return result; } @@ -191,7 +190,7 @@ PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateDotGeneral( ArrayRef lhsBatchingDims, ArrayRef rhsBatchingDims, ArrayRef lhsContractingDims, ArrayRef rhsContractingDims, bool rhsAliasesLhs, ArrayRef rhsDimToLhs) { - + PartialSymmetryAnnotation result = createNotSymmetric(resultRank); // Preserve symmetry in batching dimensions @@ -209,17 +208,19 @@ PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateDotGeneral( // Preserve symmetry in free (non-contracting, non-batching) dimensions if (rhsAliasesLhs) { bool exchange_valid = true; - + // check that each batching dimension has same ID for LHS and RHS for (int64_t i = 0; i < (int64_t)lhsBatchingDims.size(); ++i) { - if (lhsAnnotation.getSetId(lhsBatchingDims[i]) != lhsAnnotation.getSetId(rhsDimToLhs[rhsBatchingDims[i]])) { + if (lhsAnnotation.getSetId(lhsBatchingDims[i]) != + lhsAnnotation.getSetId(rhsDimToLhs[rhsBatchingDims[i]])) { exchange_valid = false; } } llvm::errs() << "still ok 1\n"; - - // check that the multiset of IDs for contracting dimensions are equal for LHS and RHS + + // check that the multiset of IDs for contracting dimensions are equal for + // LHS and RHS SmallVector lhsContractingIds, rhsContractingIds; for (int64_t dim : lhsContractingDims) { lhsContractingIds.push_back(lhsAnnotation.getSetId(dim)); @@ -234,18 +235,20 @@ PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateDotGeneral( } llvm::errs() << "still ok 2\n"; - + if (exchange_valid) { SmallVector lhsResultDims; for (int64_t i = 0; i < lhsAnnotation.getRank(); ++i) { - if (!llvm::is_contained(lhsBatchingDims, i) && !llvm::is_contained(lhsContractingDims, i)) { + if (!llvm::is_contained(lhsBatchingDims, i) && + !llvm::is_contained(lhsContractingDims, i)) { lhsResultDims.push_back(i); } } - + SmallVector rhsResultDims; for (int64_t i = 0; i < rhsAnnotation.getRank(); ++i) { - if (!llvm::is_contained(rhsBatchingDims, i) && !llvm::is_contained(rhsContractingDims, i)) { + if (!llvm::is_contained(rhsBatchingDims, i) && + !llvm::is_contained(rhsContractingDims, i)) { rhsResultDims.push_back(i); } } @@ -253,26 +256,34 @@ PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateDotGeneral( // Symmetry within free dimensions of LHS for (int64_t i = 0; i < (int64_t)lhsResultDims.size(); ++i) { for (int64_t j = 0; j < i; ++j) { - if (lhsAnnotation.getSetId(lhsResultDims[i]) == lhsAnnotation.getSetId(lhsResultDims[j])) { - result.uniteDimensionSets(resultRank, lhsBatchingDims.size() + i, lhsBatchingDims.size() + j); + if (lhsAnnotation.getSetId(lhsResultDims[i]) == + lhsAnnotation.getSetId(lhsResultDims[j])) { + result.uniteDimensionSets(resultRank, lhsBatchingDims.size() + i, + lhsBatchingDims.size() + j); } } } - + // Symmetry between free dimensions of RHS for (int64_t i = 0; i < (int64_t)rhsResultDims.size(); ++i) { for (int64_t j = 0; j < i; ++j) { - if (rhsAnnotation.getSetId(rhsResultDims[i]) == rhsAnnotation.getSetId(rhsResultDims[j])) { - result.uniteDimensionSets(resultRank, lhsBatchingDims.size() + lhsResultDims.size() + i, lhsBatchingDims.size() + lhsResultDims.size() + j); + if (rhsAnnotation.getSetId(rhsResultDims[i]) == + rhsAnnotation.getSetId(rhsResultDims[j])) { + result.uniteDimensionSets( + resultRank, lhsBatchingDims.size() + lhsResultDims.size() + i, + lhsBatchingDims.size() + lhsResultDims.size() + j); } } } - + // Symmetry between free dimensions of LHS and RHS for (int64_t i = 0; i < (int64_t)lhsResultDims.size(); ++i) { for (int64_t j = 0; j < (int64_t)rhsResultDims.size(); ++j) { - if (lhsAnnotation.getSetId(lhsResultDims[i]) == lhsAnnotation.getSetId(rhsDimToLhs[rhsResultDims[j]])) { - result.uniteDimensionSets(resultRank, lhsBatchingDims.size() + i, lhsBatchingDims.size() + lhsResultDims.size() + j); + if (lhsAnnotation.getSetId(lhsResultDims[i]) == + lhsAnnotation.getSetId(rhsDimToLhs[rhsResultDims[j]])) { + result.uniteDimensionSets(resultRank, lhsBatchingDims.size() + i, + lhsBatchingDims.size() + + lhsResultDims.size() + j); } } } @@ -338,7 +349,7 @@ PartialSymmetryAnnotation::checkConstant(DenseElementsAttr attr) { if (auto type = dyn_cast(attr.getType())) { int64_t rank = type.getRank(); PartialSymmetryAnnotation result = createNotSymmetric(rank); - + for (int64_t i = 0; i < rank; ++i) { for (int64_t j = 0; j < i; ++j) { bool isSymmetric = false; @@ -426,7 +437,8 @@ void PartialSymmetryLattice::print(raw_ostream &os) const { value.print(os); } //===----------------------------------------------------------------------===// void PartialSymmetryAnalysis::setToEntryState(PartialSymmetryLattice *lattice) { - lattice->setValue(PartialSymmetryAnnotation::createNotSymmetric(lattice->getValue().getRank())); + lattice->setValue(PartialSymmetryAnnotation::createNotSymmetric( + lattice->getValue().getRank())); } LogicalResult PartialSymmetryAnalysis::visitOperation( @@ -448,7 +460,8 @@ LogicalResult PartialSymmetryAnalysis::visitOperation( } if (auto bcastOp = dyn_cast(op)) { - if (auto resultType = dyn_cast(op->getResult(0).getType())) { + if (auto resultType = + dyn_cast(op->getResult(0).getType())) { updatedAnnotation[0] = true; propagatedAnnotation[0] = PartialSymmetryAnnotation::propagateBroadcastInDim( @@ -458,7 +471,8 @@ LogicalResult PartialSymmetryAnalysis::visitOperation( } if (auto dotGeneralOp = dyn_cast(op)) { - if (auto resultType = dyn_cast(op->getResult(0).getType())) { + if (auto resultType = + dyn_cast(op->getResult(0).getType())) { auto dotDimNumbers = dotGeneralOp.getDotDimensionNumbers(); auto lhs = dotGeneralOp.getLhs(); auto rhs = dotGeneralOp.getRhs(); @@ -468,7 +482,8 @@ LogicalResult PartialSymmetryAnalysis::visitOperation( SmallVector rhsDimToLhs; if (auto lhsT = lhs.getDefiningOp()) { if (rhs == lhsT.getOperand()) { - rhsDimToLhs.assign(lhsT.getPermutation().begin(), lhsT.getPermutation().end()); + rhsDimToLhs.assign(lhsT.getPermutation().begin(), + lhsT.getPermutation().end()); rhsAliasesLhs = true; } } @@ -480,39 +495,41 @@ LogicalResult PartialSymmetryAnalysis::visitOperation( rhsAliasesLhs = true; } } - + llvm::errs() << "dotGeneral rhsAliasesLhs: " << rhsAliasesLhs << "\n"; // Propagate symmetry through dotGeneral - propagatedAnnotation[0] = - PartialSymmetryAnnotation::propagateDotGeneral( - operandAnnotations[0], operandAnnotations[1], - resultType.getRank(), dotDimNumbers.getLhsBatchingDimensions(), - dotDimNumbers.getRhsBatchingDimensions(), - dotDimNumbers.getLhsContractingDimensions(), - dotDimNumbers.getRhsContractingDimensions(), rhsAliasesLhs, rhsDimToLhs); + propagatedAnnotation[0] = PartialSymmetryAnnotation::propagateDotGeneral( + operandAnnotations[0], operandAnnotations[1], resultType.getRank(), + dotDimNumbers.getLhsBatchingDimensions(), + dotDimNumbers.getRhsBatchingDimensions(), + dotDimNumbers.getLhsContractingDimensions(), + dotDimNumbers.getRhsContractingDimensions(), rhsAliasesLhs, + rhsDimToLhs); updatedAnnotation[0] = true; } } if (stablehlo::hasTraitElementwise(op)) { - if (auto resultType = dyn_cast(op->getResult(0).getType())) { + if (auto resultType = + dyn_cast(op->getResult(0).getType())) { if (operands.size() == 1) { propagatedAnnotation[0] = operandAnnotations[0]; updatedAnnotation[0] = true; } else if (operands.size() == 2 && - (op->hasTrait() || + (op->hasTrait() || op->hasTrait())) { auto lhs = op->getOperand(0); auto rhs = op->getOperand(1); - + bool rhsAliasesLhs = false; SmallVector rhsDimToLhs; - + if (auto lhsT = lhs.getDefiningOp()) { if (rhs == lhsT.getOperand()) { - rhsDimToLhs.assign(lhsT.getPermutation().begin(), lhsT.getPermutation().end()); + rhsDimToLhs.assign(lhsT.getPermutation().begin(), + lhsT.getPermutation().end()); rhsAliasesLhs = true; } } @@ -524,9 +541,11 @@ LogicalResult PartialSymmetryAnalysis::visitOperation( rhsAliasesLhs = true; } } - - propagatedAnnotation[0] = PartialSymmetryAnnotation::propagateElementwiseBinary( - operandAnnotations[0], operandAnnotations[1], resultType.getRank(), rhsAliasesLhs, rhsDimToLhs); + + propagatedAnnotation[0] = + PartialSymmetryAnnotation::propagateElementwiseBinary( + operandAnnotations[0], operandAnnotations[1], + resultType.getRank(), rhsAliasesLhs, rhsDimToLhs); updatedAnnotation[0] = true; } } diff --git a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h index b7bafb99c..32701d06c 100644 --- a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h +++ b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h @@ -31,7 +31,6 @@ class PartialSymmetryAnnotation { int64_t getRank() const { return dimensionSetIDs.size(); } - static PartialSymmetryAnnotation join(const PartialSymmetryAnnotation &lhs, const PartialSymmetryAnnotation &rhs); static PartialSymmetryAnnotation meet(const PartialSymmetryAnnotation &lhs, @@ -52,12 +51,12 @@ class PartialSymmetryAnnotation { int64_t resultRank, ArrayRef lhsBatchingDims, ArrayRef rhsBatchingDims, ArrayRef lhsContractingDims, - ArrayRef rhsContractingDims, - bool rhsAliasesLhs, ArrayRef rhsDimToLhs); + ArrayRef rhsContractingDims, bool rhsAliasesLhs, + ArrayRef rhsDimToLhs); static PartialSymmetryAnnotation checkConstant(DenseElementsAttr attr); - static PartialSymmetryAnnotation + static PartialSymmetryAnnotation propagateElementwiseBinary(const PartialSymmetryAnnotation &lhsAnnotation, const PartialSymmetryAnnotation &rhsAnnotation, int64_t resultRank, bool rhsAliasesLhs, From 8f93a7085654bca03b96d1479c968f98e7fb97ba Mon Sep 17 00:00:00 2001 From: gaurav-arya Date: Sat, 29 Nov 2025 15:13:37 -0500 Subject: [PATCH 11/21] Remove debug messages --- src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp index 0967eba5a..544f1e4db 100644 --- a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp +++ b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp @@ -217,8 +217,6 @@ PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateDotGeneral( } } - llvm::errs() << "still ok 1\n"; - // check that the multiset of IDs for contracting dimensions are equal for // LHS and RHS SmallVector lhsContractingIds, rhsContractingIds; @@ -234,8 +232,6 @@ PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateDotGeneral( exchange_valid = false; } - llvm::errs() << "still ok 2\n"; - if (exchange_valid) { SmallVector lhsResultDims; for (int64_t i = 0; i < lhsAnnotation.getRank(); ++i) { @@ -496,8 +492,6 @@ LogicalResult PartialSymmetryAnalysis::visitOperation( } } - llvm::errs() << "dotGeneral rhsAliasesLhs: " << rhsAliasesLhs << "\n"; - // Propagate symmetry through dotGeneral propagatedAnnotation[0] = PartialSymmetryAnnotation::propagateDotGeneral( operandAnnotations[0], operandAnnotations[1], resultType.getRank(), From 5606fafc7b80a7cff9655b891e5dafda9420737b Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Sat, 29 Nov 2025 18:04:34 -0500 Subject: [PATCH 12/21] Add n-dim transpose removal opt --- .../jax/Analysis/PartialSymmetryAnalysis.cpp | 19 +++++ .../jax/Analysis/PartialSymmetryAnalysis.h | 3 + src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 81 ++++++++++++++++++- ...mplify.cpp => PartialSymmetryAnnotate.cpp} | 13 ++- src/enzyme_ad/jax/Passes/Passes.td | 4 +- .../jax/TransformOps/TransformOps.td | 5 ++ .../structured_tensors/partial_symmetry.mlir | 28 ++++--- 7 files changed, 130 insertions(+), 23 deletions(-) rename src/enzyme_ad/jax/Passes/{PartialSymmetrySimplify.cpp => PartialSymmetryAnnotate.cpp} (89%) diff --git a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp index 544f1e4db..dbc594ea3 100644 --- a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp +++ b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp @@ -386,6 +386,25 @@ PartialSymmetryAnnotation::getDimensionSets() const { return result; } +PartialSymmetryAnnotation +PartialSymmetryAnnotation::fromDimensionSets(int64_t rank, + ArrayRef> dimensionSets) { + SmallVector dimensionSetIDs(rank); + for (int64_t i = 0; i < rank; ++i) { + dimensionSetIDs[i] = i; + } + + // Note that dimensionSets is not assumed to be a complete partition. + // Missing dimensions are treated as separate sets. + for (auto dims : dimensionSets) { + for (int64_t i = 1; i < (int64_t)dims.size(); ++i) { + dimensionSetIDs[dims[i]] = dimensionSetIDs[dims[0]]; + } + } + + return PartialSymmetryAnnotation(dimensionSetIDs); +} + void PartialSymmetryAnnotation::print(raw_ostream &os) const { auto dimensionSets = getDimensionSets(); os << "{"; diff --git a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h index 32701d06c..2e4ca163a 100644 --- a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h +++ b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h @@ -68,6 +68,9 @@ class PartialSymmetryAnnotation { SmallVector> getDimensionSets() const; + static PartialSymmetryAnnotation + fromDimensionSets(int64_t rank, ArrayRef> dimensionSets); + void print(raw_ostream &os) const; private: diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index 97a70af80..c853d0a63 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -33,6 +33,7 @@ #include "src/enzyme_ad/jax/Implementations/WhileLoopInfo.h" #include "src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h" #include "src/enzyme_ad/jax/Passes/Passes.h" +#include "src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h" #include "src/enzyme_ad/jax/Passes/StructuredTensors.h" #include "src/enzyme_ad/jax/Utils.h" #include "stablehlo/dialect/Base.h" @@ -55,6 +56,7 @@ #include "llvm/ADT/MapVector.h" #include #include +#include #define DEBUG_TYPE "enzymehloopt" namespace mlir { @@ -6953,6 +6955,82 @@ struct TransposeSymmetricSimplify } }; +static std::optional +getPartialSymmetryFromAttr(Value val) { + auto op = val.getDefiningOp(); + if (!op) + return std::nullopt; + + auto arrayAttr = + op->getAttrOfType("enzymexla.partial_symmetry"); + if (!arrayAttr || arrayAttr.empty()) + return std::nullopt; + + // Get the result number for this value + auto opResult = dyn_cast(val); + if (!opResult) + return std::nullopt; + + auto resultNumber = opResult.getResultNumber(); + if (resultNumber >= arrayAttr.size()) + return std::nullopt; + + auto partialSymmetryAttr = dyn_cast( + arrayAttr[resultNumber]); + if (!partialSymmetryAttr) + return std::nullopt; + + auto dimensionSetAttrs = partialSymmetryAttr.getValues(); + auto rank = cast(val.getType()).getRank(); + + SmallVector> dimensionSets; + for (auto dimensionSetAttr : dimensionSetAttrs) { + auto dims = dimensionSetAttr.getDimensions().asArrayRef(); + dimensionSets.push_back(dims); + } + + return enzyme::PartialSymmetryAnnotation::fromDimensionSets(rank, dimensionSets); +} + +struct TransposePartialSymmetrySimplify + : public CheckedOpRewritePattern { + using CheckedOpRewritePattern< + stablehlo::TransposeOp, + TransposePartialSymmetrySimplify>::CheckedOpRewritePattern; + + LogicalResult matchAndRewriteImpl(stablehlo::TransposeOp op, + PatternRewriter &rewriter) const { + auto operand = op.getOperand(); + auto perm = op.getPermutation(); + + // Get partial symmetry annotation from the operand + auto annotationOpt = getPartialSymmetryFromAttr(operand); + if (!annotationOpt.has_value()) + return failure(); + + auto annotation = annotationOpt.value(); + + // Check if the transpose is an identity based on partial symmetry + // A transpose is identity if permuting dimensions doesn't change which + // dimensions are in the same symmetric set + bool isIdentity = true; + for (int64_t i = 0; i < (int64_t)perm.size(); ++i) { + if (annotation.getSetId(i) != annotation.getSetId(perm[i])) { + isIdentity = false; + break; + } + } + + if (isIdentity) { + rewriter.replaceOp(op, operand); + return success(); + } + + return failure(); + } +}; + struct NoNanSelfSubSimplify : public NoNanCheckedOpRewritePattern { @@ -26758,7 +26836,8 @@ struct EnzymeHLOOptPass NoNanAddSubSimplify, NoNanMulSimplify, NoNanDivSimplify>( (no_nan || all_finite), context); - patterns.add(context); + patterns.add( + context); patterns.add(context); // syrk patterns diff --git a/src/enzyme_ad/jax/Passes/PartialSymmetrySimplify.cpp b/src/enzyme_ad/jax/Passes/PartialSymmetryAnnotate.cpp similarity index 89% rename from src/enzyme_ad/jax/Passes/PartialSymmetrySimplify.cpp rename to src/enzyme_ad/jax/Passes/PartialSymmetryAnnotate.cpp index baba354aa..2bfd814e4 100644 --- a/src/enzyme_ad/jax/Passes/PartialSymmetrySimplify.cpp +++ b/src/enzyme_ad/jax/Passes/PartialSymmetryAnnotate.cpp @@ -17,11 +17,11 @@ #include "src/enzyme_ad/jax/Dialect/Ops.h" #include "stablehlo/dialect/StablehloOps.h" -#define DEBUG_TYPE "partial-symmetry-simplify" +#define DEBUG_TYPE "partial-symmetry-annotate" namespace mlir { namespace enzyme { -#define GEN_PASS_DEF_PARTIALSYMMETRYSIMPLIFYPASS +#define GEN_PASS_DEF_PARTIALSYMMETRYANNOTATEPASS #include "src/enzyme_ad/jax/Passes/Passes.h.inc" } // namespace enzyme } // namespace mlir @@ -32,9 +32,9 @@ using namespace mlir::enzyme; namespace { -class PartialSymmetrySimplifyPass - : public enzyme::impl::PartialSymmetrySimplifyPassBase< - PartialSymmetrySimplifyPass> { +class PartialSymmetryAnnotatePass + : public enzyme::impl::PartialSymmetryAnnotatePassBase< + PartialSymmetryAnnotatePass> { public: using Base::Base; @@ -51,6 +51,7 @@ class PartialSymmetrySimplifyPass auto mod = getOperation(); + // Annotate all operations with partial symmetry information mod->walk([&](Operation *op) { SmallVector partialSymmetryAttrs; bool anyKnown = false; @@ -92,8 +93,6 @@ class PartialSymmetrySimplifyPass return WalkResult::advance(); }); - - // TODO: do things here } }; diff --git a/src/enzyme_ad/jax/Passes/Passes.td b/src/enzyme_ad/jax/Passes/Passes.td index f96d863f0..f7a5e9307 100644 --- a/src/enzyme_ad/jax/Passes/Passes.td +++ b/src/enzyme_ad/jax/Passes/Passes.td @@ -1077,8 +1077,8 @@ def EnzymeBatchToStableHLOPass : Pass<"enzyme-batch-to-stablehlo"> { ]; } -def PartialSymmetrySimplifyPass : Pass<"partial-symmetry-simplify", "ModuleOp"> { - let summary = "Simplify operations using partial symmetry analysis"; +def PartialSymmetryAnnotatePass : Pass<"partial-symmetry-annotate", "ModuleOp"> { + let summary = "Annotate operations using partial symmetry analysis"; let dependentDialects = [ "stablehlo::StablehloDialect", "enzymexla::EnzymeXLADialect", diff --git a/src/enzyme_ad/jax/TransformOps/TransformOps.td b/src/enzyme_ad/jax/TransformOps/TransformOps.td index cac919513..f2a67bdb4 100644 --- a/src/enzyme_ad/jax/TransformOps/TransformOps.td +++ b/src/enzyme_ad/jax/TransformOps/TransformOps.td @@ -608,6 +608,11 @@ def ApplyTransposeSymmetricSimplify : EnzymeHLOPatternOp< let patterns = ["TransposeSymmetricSimplify"]; } +def ApplyTransposePartialSymmetrySimplify : EnzymeHLOPatternOp< + "transpose_partial_symmetry_simplify"> { + let patterns = ["TransposePartialSymmetrySimplify"]; +} + def ApplyFactorScalarsInDotGeneral : EnzymeHLOPatternOp< "factor_scalars_in_dot_general"> { let patterns = ["FactorScalarsInDotGeneral"]; diff --git a/test/lit_tests/structured_tensors/partial_symmetry.mlir b/test/lit_tests/structured_tensors/partial_symmetry.mlir index 5206523a6..9985086e6 100644 --- a/test/lit_tests/structured_tensors/partial_symmetry.mlir +++ b/test/lit_tests/structured_tensors/partial_symmetry.mlir @@ -1,23 +1,25 @@ -// RUN: enzymexlamlir-opt --partial-symmetry-simplify %s | FileCheck %s +// RUN: enzymexlamlir-opt --partial-symmetry-annotate --enzyme-hlo-generate-td="patterns=transpose_partial_symmetry_simplify" --transform-interpreter --enzyme-hlo-remove-transform %s | FileCheck %s -func.func @test1() -> tensor<2x2xf32> { +func.func @test_constant() -> tensor<2x2xf32> { %cst = stablehlo.constant dense<[[1.0, 2.0], [2.0, 3.0]]> : tensor<2x2xf32> - return %cst : tensor<2x2xf32> + %0 = stablehlo.transpose %cst, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> } -// CHECK: func.func @test1() -> tensor<2x2xf32> { +// CHECK: func.func @test_constant() -> tensor<2x2xf32> { // CHECK-NEXT: %cst = stablehlo.constant {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]} dense<{{.*}}> : tensor<2x2xf32> // CHECK-NEXT: return %cst : tensor<2x2xf32> // CHECK-NEXT: } -func.func @test2() -> tensor<2x2x2x3xf32> { +func.func @test_propagate() -> tensor<2x2x2x3xf32> { %cst0 = stablehlo.constant dense<[[[1.0, 2.0], [3.0, 4.0]], [[3.0, 4.0], [5.0, 6.0]]]> : tensor<2x2x2xf32> %cst1 = stablehlo.constant dense<[[[1.0, 2.0], [2.0, 3.0]], [[2.0, 3.0], [3.0, 4.0]]]> : tensor<2x2x2xf32> %0 = stablehlo.add %cst0, %cst1 : tensor<2x2x2xf32> %1 = stablehlo.transpose %0, dims = [0, 2, 1] : (tensor<2x2x2xf32>) -> tensor<2x2x2xf32> %2 = stablehlo.broadcast_in_dim %1, dims = [1, 0, 2] : (tensor<2x2x2xf32>) -> tensor<2x2x2x3xf32> - return %2 : tensor<2x2x2x3xf32> + %3 = stablehlo.transpose %2, dims = [0, 2, 1, 3] : (tensor<2x2x2x3xf32>) -> tensor<2x2x2x3xf32> + return %3 : tensor<2x2x2x3xf32> } -// CHECK: func.func @test2() -> tensor<2x2x2x3xf32> { +// CHECK: func.func @test_propagate() -> tensor<2x2x2x3xf32> { // CHECK-NEXT: %cst = stablehlo.constant {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]} dense<{{.*}}> : tensor<2x2x2xf32> // CHECK-NEXT: %cst_0 = stablehlo.constant {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1, 2]>>]} dense<{{.*}}> : tensor<2x2x2xf32> // CHECK-NEXT: %0 = stablehlo.add %cst, %cst_0 {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]} : tensor<2x2x2xf32> @@ -26,36 +28,36 @@ func.func @test2() -> tensor<2x2x2x3xf32> { // CHECK-NEXT: return %2 : tensor<2x2x2x3xf32> // CHECK-NEXT: } -func.func @test3(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x3xf32> { +func.func @test_add_generate_symmetry(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x3xf32> { %0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<3x2x3xf32>) -> tensor<3x2x3xf32> %1 = stablehlo.add %0, %arg0 : tensor<3x2x3xf32> return %1 : tensor<3x2x3xf32> } -// CHECK: func.func @test3(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x3xf32> { +// CHECK: func.func @test_add_generate_symmetry(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x3xf32> { // CHECK-NEXT: %0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<3x2x3xf32>) -> tensor<3x2x3xf32> // CHECK-NEXT: %1 = stablehlo.add %0, %arg0 {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 2]>>]} : tensor<3x2x3xf32> // CHECK-NEXT: return %1 : tensor<3x2x3xf32> // CHECK-NEXT: } -func.func @test4() -> tensor<2x2xf32> { +func.func @test_dot_propagate() -> tensor<2x2xf32> { %cst0 = stablehlo.constant dense<[[[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]], [[2.0, 3.0, 4.0], [3.0, 4.0, 5.0]]]> : tensor<2x2x3xf32> %cst1 = stablehlo.constant dense<[[[1.0, 2.0], [2.0, 3.0]], [[2.0, 3.0], [3.0, 4.0]], [[2.0, 3.0], [3.0, 4.0]]]> : tensor<3x2x2xf32> %0 = stablehlo.dot_general %cst0, %cst1, batching_dims = [0, 1] x [1, 2], contracting_dims = [2] x [0] : (tensor<2x2x3xf32>, tensor<3x2x2xf32>) -> tensor<2x2xf32> return %0 : tensor<2x2xf32> } -// CHECK: func.func @test4() -> tensor<2x2xf32> { +// CHECK: func.func @test_dot_propagate() -> tensor<2x2xf32> { // CHECK-NEXT: %cst = stablehlo.constant {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]} dense<{{.*}}> : tensor<2x2x3xf32> // CHECK-NEXT: %cst_0 = stablehlo.constant {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[1, 2]>>]} dense<{{.*}}> : tensor<3x2x2xf32> // CHECK-NEXT: %0 = stablehlo.dot_general %cst, %cst_0, batching_dims = [0, 1] x [1, 2], contracting_dims = [2] x [0] {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]} : (tensor<2x2x3xf32>, tensor<3x2x2xf32>) -> tensor<2x2xf32> // CHECK-NEXT: return %0 : tensor<2x2xf32> // CHECK-NEXT: } -func.func @test5(%arg0: tensor<3x3x3xf32>) -> tensor<3x3x3xf32> { +func.func @test_dot_generate_symmetry(%arg0: tensor<3x3x3xf32>) -> tensor<3x3x3xf32> { %0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<3x3x3xf32>) -> tensor<3x3x3xf32> %1 = stablehlo.dot_general %arg0, %0, batching_dims = [1] x [1], contracting_dims = [0] x [2] : (tensor<3x3x3xf32>, tensor<3x3x3xf32>) -> tensor<3x3x3xf32> return %1 : tensor<3x3x3xf32> } -// CHECK: func.func @test5(%arg0: tensor<3x3x3xf32>) -> tensor<3x3x3xf32> { +// CHECK: func.func @test_dot_generate_symmetry(%arg0: tensor<3x3x3xf32>) -> tensor<3x3x3xf32> { // CHECK-NEXT: %0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<3x3x3xf32>) -> tensor<3x3x3xf32> // CHECK-NEXT: %1 = stablehlo.dot_general %arg0, %0, batching_dims = [1] x [1], contracting_dims = [0] x [2] {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[1, 2]>>]} : (tensor<3x3x3xf32>, tensor<3x3x3xf32>) -> tensor<3x3x3xf32> // CHECK-NEXT: return %1 : tensor<3x3x3xf32> From 77e5b3afc5fbc1538fb4b2fc460d412c0c096afa Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Mon, 1 Dec 2025 20:01:48 +0000 Subject: [PATCH 13/21] Recognize existing partial symmetry annotations in IR --- .../jax/Analysis/PartialSymmetryAnalysis.cpp | 95 ++++++++++++++++++- .../jax/Analysis/PartialSymmetryAnalysis.h | 13 +-- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 44 +-------- .../structured_tensors/partial_symmetry.mlir | 12 +-- 4 files changed, 104 insertions(+), 60 deletions(-) diff --git a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp index dbc594ea3..b45381a8b 100644 --- a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp +++ b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp @@ -5,6 +5,8 @@ #include "mlir/Analysis/DataFlow/SparseAnalysis.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "src/enzyme_ad/jax/Dialect/Dialect.h" #include "stablehlo/dialect/StablehloOps.h" #include "llvm/ADT/DenseMap.h" @@ -387,7 +389,7 @@ PartialSymmetryAnnotation::getDimensionSets() const { } PartialSymmetryAnnotation -PartialSymmetryAnnotation::fromDimensionSets(int64_t rank, +PartialSymmetryAnnotation::createFromDimensionSets(int64_t rank, ArrayRef> dimensionSets) { SmallVector dimensionSetIDs(rank); for (int64_t i = 0; i < rank; ++i) { @@ -405,6 +407,78 @@ PartialSymmetryAnnotation::fromDimensionSets(int64_t rank, return PartialSymmetryAnnotation(dimensionSetIDs); } +std::optional +PartialSymmetryAnnotation::createFromIR(Value val) { + auto blockArg = dyn_cast(val); + if (blockArg) { + auto op = blockArg.getOwner()->getParentOp(); + auto funcOpInterface = dyn_cast(op); + if (!funcOpInterface) { + return std::nullopt; + } + + auto argAttrs = funcOpInterface.getArgAttrs(blockArg.getArgNumber()); + for (auto attr : argAttrs) { + if (attr.getName() == "enzymexla.partial_symmetry") { + auto arrayAttr = dyn_cast(attr.getValue()); + if (!arrayAttr || arrayAttr.empty()) { + continue; + } + + auto partialSymmetryAttr = dyn_cast( + arrayAttr[0]); + + if (!partialSymmetryAttr) { + continue; + } + + auto dimensionSetAttrs = partialSymmetryAttr.getValues(); + auto rank = cast(val.getType()).getRank(); + + SmallVector> dimensionSets; + for (auto dimensionSetAttr : dimensionSetAttrs) { + auto dims = dimensionSetAttr.getDimensions().asArrayRef(); + dimensionSets.push_back(dims); + } + + return PartialSymmetryAnnotation::createFromDimensionSets(rank, dimensionSets); + } + } + return std::nullopt; + } + + auto op = val.getDefiningOp(); + if (!op) + return std::nullopt; + + auto arrayAttr = + op->getAttrOfType("enzymexla.partial_symmetry"); + if (!arrayAttr || arrayAttr.empty()) + return std::nullopt; + + auto opResult = dyn_cast(val); + if (!opResult) + return std::nullopt; + + auto resultNumber = opResult.getResultNumber(); + + auto partialSymmetryAttr = dyn_cast( + arrayAttr[resultNumber]); + if (!partialSymmetryAttr) + return std::nullopt; + + auto dimensionSetAttrs = partialSymmetryAttr.getValues(); + auto rank = cast(val.getType()).getRank(); + + SmallVector> dimensionSets; + for (auto dimensionSetAttr : dimensionSetAttrs) { + auto dims = dimensionSetAttr.getDimensions().asArrayRef(); + dimensionSets.push_back(dims); + } + + return PartialSymmetryAnnotation::createFromDimensionSets(rank, dimensionSets); +} + void PartialSymmetryAnnotation::print(raw_ostream &os) const { auto dimensionSets = getDimensionSets(); os << "{"; @@ -430,6 +504,19 @@ void PartialSymmetryAnnotation::print(raw_ostream &os) const { // PartialSymmetryLattice Implementation //===----------------------------------------------------------------------===// +PartialSymmetryLattice::PartialSymmetryLattice(Value v) : AbstractSparseLattice(v) { + if (auto type = dyn_cast(v.getType())) { + // Trust existing IR annotations if present. + auto annotation = PartialSymmetryAnnotation::createFromIR(v); + if (annotation.has_value()) { + value = annotation.value(); + return; + } + + value = PartialSymmetryAnnotation::createFullySymmetric(type.getRank()); + } +} + ChangeResult PartialSymmetryLattice::join(const AbstractSparseLattice &rhs) { const auto *rhsStruct = reinterpret_cast(&rhs); @@ -452,6 +539,12 @@ void PartialSymmetryLattice::print(raw_ostream &os) const { value.print(os); } //===----------------------------------------------------------------------===// void PartialSymmetryAnalysis::setToEntryState(PartialSymmetryLattice *lattice) { + auto annotation = PartialSymmetryAnnotation::createFromIR(lattice->getAnchor()); + if (annotation.has_value()) { + lattice->setValue(annotation.value()); + return; + } + lattice->setValue(PartialSymmetryAnnotation::createNotSymmetric( lattice->getValue().getRank())); } diff --git a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h index 2e4ca163a..7f22b51d2 100644 --- a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h +++ b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h @@ -26,9 +26,7 @@ class PartialSymmetryAnnotation { static PartialSymmetryAnnotation createFullySymmetric(int64_t rank); bool isSymmetric(int64_t i, int64_t j) const; - int64_t getSetId(int64_t i) const { return dimensionSetIDs[i]; } - int64_t getRank() const { return dimensionSetIDs.size(); } static PartialSymmetryAnnotation join(const PartialSymmetryAnnotation &lhs, @@ -67,9 +65,8 @@ class PartialSymmetryAnnotation { } SmallVector> getDimensionSets() const; - - static PartialSymmetryAnnotation - fromDimensionSets(int64_t rank, ArrayRef> dimensionSets); + static PartialSymmetryAnnotation createFromDimensionSets(int64_t rank, ArrayRef> dimensionSets); + static std::optional createFromIR(Value val); void print(raw_ostream &os) const; @@ -84,11 +81,7 @@ class PartialSymmetryLattice : public dataflow::AbstractSparseLattice { public: using AbstractSparseLattice::AbstractSparseLattice; - PartialSymmetryLattice(Value v) : AbstractSparseLattice(v) { - if (auto type = dyn_cast(v.getType())) { - value = PartialSymmetryAnnotation::createFullySymmetric(type.getRank()); - } - } + PartialSymmetryLattice(Value v); ChangeResult join(const AbstractSparseLattice &rhs) override; ChangeResult join(const PartialSymmetryLattice &rhs); diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index c853d0a63..4fe204a93 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -19,6 +19,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Builders.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/IR/Dominance.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/Matchers.h" @@ -6955,43 +6956,6 @@ struct TransposeSymmetricSimplify } }; -static std::optional -getPartialSymmetryFromAttr(Value val) { - auto op = val.getDefiningOp(); - if (!op) - return std::nullopt; - - auto arrayAttr = - op->getAttrOfType("enzymexla.partial_symmetry"); - if (!arrayAttr || arrayAttr.empty()) - return std::nullopt; - - // Get the result number for this value - auto opResult = dyn_cast(val); - if (!opResult) - return std::nullopt; - - auto resultNumber = opResult.getResultNumber(); - if (resultNumber >= arrayAttr.size()) - return std::nullopt; - - auto partialSymmetryAttr = dyn_cast( - arrayAttr[resultNumber]); - if (!partialSymmetryAttr) - return std::nullopt; - - auto dimensionSetAttrs = partialSymmetryAttr.getValues(); - auto rank = cast(val.getType()).getRank(); - - SmallVector> dimensionSets; - for (auto dimensionSetAttr : dimensionSetAttrs) { - auto dims = dimensionSetAttr.getDimensions().asArrayRef(); - dimensionSets.push_back(dims); - } - - return enzyme::PartialSymmetryAnnotation::fromDimensionSets(rank, dimensionSets); -} - struct TransposePartialSymmetrySimplify : public CheckedOpRewritePattern { @@ -7004,16 +6968,12 @@ struct TransposePartialSymmetrySimplify auto operand = op.getOperand(); auto perm = op.getPermutation(); - // Get partial symmetry annotation from the operand - auto annotationOpt = getPartialSymmetryFromAttr(operand); + auto annotationOpt = enzyme::PartialSymmetryAnnotation::createFromIR(operand); if (!annotationOpt.has_value()) return failure(); auto annotation = annotationOpt.value(); - // Check if the transpose is an identity based on partial symmetry - // A transpose is identity if permuting dimensions doesn't change which - // dimensions are in the same symmetric set bool isIdentity = true; for (int64_t i = 0; i < (int64_t)perm.size(); ++i) { if (annotation.getSetId(i) != annotation.getSetId(perm[i])) { diff --git a/test/lit_tests/structured_tensors/partial_symmetry.mlir b/test/lit_tests/structured_tensors/partial_symmetry.mlir index 9985086e6..2ee9d7d8f 100644 --- a/test/lit_tests/structured_tensors/partial_symmetry.mlir +++ b/test/lit_tests/structured_tensors/partial_symmetry.mlir @@ -39,16 +39,14 @@ func.func @test_add_generate_symmetry(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x3x // CHECK-NEXT: return %1 : tensor<3x2x3xf32> // CHECK-NEXT: } -func.func @test_dot_propagate() -> tensor<2x2xf32> { - %cst0 = stablehlo.constant dense<[[[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]], [[2.0, 3.0, 4.0], [3.0, 4.0, 5.0]]]> : tensor<2x2x3xf32> +func.func @test_dot_propagate(%arg0: tensor<2x2x3xf32> {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]}) -> tensor<2x2xf32> { %cst1 = stablehlo.constant dense<[[[1.0, 2.0], [2.0, 3.0]], [[2.0, 3.0], [3.0, 4.0]], [[2.0, 3.0], [3.0, 4.0]]]> : tensor<3x2x2xf32> - %0 = stablehlo.dot_general %cst0, %cst1, batching_dims = [0, 1] x [1, 2], contracting_dims = [2] x [0] : (tensor<2x2x3xf32>, tensor<3x2x2xf32>) -> tensor<2x2xf32> + %0 = stablehlo.dot_general %arg0, %cst1, batching_dims = [0, 1] x [1, 2], contracting_dims = [2] x [0] : (tensor<2x2x3xf32>, tensor<3x2x2xf32>) -> tensor<2x2xf32> return %0 : tensor<2x2xf32> } -// CHECK: func.func @test_dot_propagate() -> tensor<2x2xf32> { -// CHECK-NEXT: %cst = stablehlo.constant {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]} dense<{{.*}}> : tensor<2x2x3xf32> -// CHECK-NEXT: %cst_0 = stablehlo.constant {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[1, 2]>>]} dense<{{.*}}> : tensor<3x2x2xf32> -// CHECK-NEXT: %0 = stablehlo.dot_general %cst, %cst_0, batching_dims = [0, 1] x [1, 2], contracting_dims = [2] x [0] {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]} : (tensor<2x2x3xf32>, tensor<3x2x2xf32>) -> tensor<2x2xf32> +// CHECK: func.func @test_dot_propagate(%arg0: tensor<2x2x3xf32> {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]}) -> tensor<2x2xf32> { +// CHECK-NEXT: %cst = stablehlo.constant {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[1, 2]>>]} dense<{{.*}}> : tensor<3x2x2xf32> +// CHECK-NEXT: %0 = stablehlo.dot_general %arg0, %cst, batching_dims = [0, 1] x [1, 2], contracting_dims = [2] x [0] {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]} : (tensor<2x2x3xf32>, tensor<3x2x2xf32>) -> tensor<2x2xf32> // CHECK-NEXT: return %0 : tensor<2x2xf32> // CHECK-NEXT: } From 1f40a6d82709ec1567d3845208337f7a81a7aba2 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Mon, 1 Dec 2025 21:06:04 +0000 Subject: [PATCH 14/21] Format --- .../jax/Analysis/PartialSymmetryAnalysis.cpp | 40 ++++++++++--------- .../jax/Analysis/PartialSymmetryAnalysis.h | 4 +- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 11 ++--- 3 files changed, 31 insertions(+), 24 deletions(-) diff --git a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp index b45381a8b..19f489ceb 100644 --- a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp +++ b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp @@ -388,9 +388,8 @@ PartialSymmetryAnnotation::getDimensionSets() const { return result; } -PartialSymmetryAnnotation -PartialSymmetryAnnotation::createFromDimensionSets(int64_t rank, - ArrayRef> dimensionSets) { +PartialSymmetryAnnotation PartialSymmetryAnnotation::createFromDimensionSets( + int64_t rank, ArrayRef> dimensionSets) { SmallVector dimensionSetIDs(rank); for (int64_t i = 0; i < rank; ++i) { dimensionSetIDs[i] = i; @@ -424,14 +423,15 @@ PartialSymmetryAnnotation::createFromIR(Value val) { if (!arrayAttr || arrayAttr.empty()) { continue; } - - auto partialSymmetryAttr = dyn_cast( - arrayAttr[0]); - + + auto partialSymmetryAttr = + dyn_cast( + arrayAttr[0]); + if (!partialSymmetryAttr) { continue; } - + auto dimensionSetAttrs = partialSymmetryAttr.getValues(); auto rank = cast(val.getType()).getRank(); @@ -441,7 +441,8 @@ PartialSymmetryAnnotation::createFromIR(Value val) { dimensionSets.push_back(dims); } - return PartialSymmetryAnnotation::createFromDimensionSets(rank, dimensionSets); + return PartialSymmetryAnnotation::createFromDimensionSets( + rank, dimensionSets); } } return std::nullopt; @@ -451,8 +452,7 @@ PartialSymmetryAnnotation::createFromIR(Value val) { if (!op) return std::nullopt; - auto arrayAttr = - op->getAttrOfType("enzymexla.partial_symmetry"); + auto arrayAttr = op->getAttrOfType("enzymexla.partial_symmetry"); if (!arrayAttr || arrayAttr.empty()) return std::nullopt; @@ -462,8 +462,9 @@ PartialSymmetryAnnotation::createFromIR(Value val) { auto resultNumber = opResult.getResultNumber(); - auto partialSymmetryAttr = dyn_cast( - arrayAttr[resultNumber]); + auto partialSymmetryAttr = + dyn_cast( + arrayAttr[resultNumber]); if (!partialSymmetryAttr) return std::nullopt; @@ -476,7 +477,8 @@ PartialSymmetryAnnotation::createFromIR(Value val) { dimensionSets.push_back(dims); } - return PartialSymmetryAnnotation::createFromDimensionSets(rank, dimensionSets); + return PartialSymmetryAnnotation::createFromDimensionSets(rank, + dimensionSets); } void PartialSymmetryAnnotation::print(raw_ostream &os) const { @@ -504,7 +506,8 @@ void PartialSymmetryAnnotation::print(raw_ostream &os) const { // PartialSymmetryLattice Implementation //===----------------------------------------------------------------------===// -PartialSymmetryLattice::PartialSymmetryLattice(Value v) : AbstractSparseLattice(v) { +PartialSymmetryLattice::PartialSymmetryLattice(Value v) + : AbstractSparseLattice(v) { if (auto type = dyn_cast(v.getType())) { // Trust existing IR annotations if present. auto annotation = PartialSymmetryAnnotation::createFromIR(v); @@ -512,7 +515,7 @@ PartialSymmetryLattice::PartialSymmetryLattice(Value v) : AbstractSparseLattice( value = annotation.value(); return; } - + value = PartialSymmetryAnnotation::createFullySymmetric(type.getRank()); } } @@ -539,12 +542,13 @@ void PartialSymmetryLattice::print(raw_ostream &os) const { value.print(os); } //===----------------------------------------------------------------------===// void PartialSymmetryAnalysis::setToEntryState(PartialSymmetryLattice *lattice) { - auto annotation = PartialSymmetryAnnotation::createFromIR(lattice->getAnchor()); + auto annotation = + PartialSymmetryAnnotation::createFromIR(lattice->getAnchor()); if (annotation.has_value()) { lattice->setValue(annotation.value()); return; } - + lattice->setValue(PartialSymmetryAnnotation::createNotSymmetric( lattice->getValue().getRank())); } diff --git a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h index 7f22b51d2..54ccdc9e5 100644 --- a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h +++ b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h @@ -65,7 +65,9 @@ class PartialSymmetryAnnotation { } SmallVector> getDimensionSets() const; - static PartialSymmetryAnnotation createFromDimensionSets(int64_t rank, ArrayRef> dimensionSets); + static PartialSymmetryAnnotation + createFromDimensionSets(int64_t rank, + ArrayRef> dimensionSets); static std::optional createFromIR(Value val); void print(raw_ostream &os) const; diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index 4fe204a93..97f3ff6d5 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -19,22 +19,22 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Builders.h" -#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/IR/Dominance.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Visitors.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "shardy/dialect/sdy/ir/utils.h" +#include "src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h" #include "src/enzyme_ad/jax/CheckedRewrite.h" #include "src/enzyme_ad/jax/Dialect/Dialect.h" #include "src/enzyme_ad/jax/Dialect/Ops.h" #include "src/enzyme_ad/jax/Implementations/WhileLoopInfo.h" #include "src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h" #include "src/enzyme_ad/jax/Passes/Passes.h" -#include "src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h" #include "src/enzyme_ad/jax/Passes/StructuredTensors.h" #include "src/enzyme_ad/jax/Utils.h" #include "stablehlo/dialect/Base.h" @@ -6964,14 +6964,15 @@ struct TransposePartialSymmetrySimplify TransposePartialSymmetrySimplify>::CheckedOpRewritePattern; LogicalResult matchAndRewriteImpl(stablehlo::TransposeOp op, - PatternRewriter &rewriter) const { + PatternRewriter &rewriter) const { auto operand = op.getOperand(); auto perm = op.getPermutation(); - auto annotationOpt = enzyme::PartialSymmetryAnnotation::createFromIR(operand); + auto annotationOpt = + enzyme::PartialSymmetryAnnotation::createFromIR(operand); if (!annotationOpt.has_value()) return failure(); - + auto annotation = annotationOpt.value(); bool isIdentity = true; From e62ad338343502b8d7a58a2ef5991314a59292e4 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Mon, 1 Dec 2025 21:10:43 +0000 Subject: [PATCH 15/21] Fix test func naming --- test/lit_tests/structured_tensors/partial_symmetry.mlir | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/lit_tests/structured_tensors/partial_symmetry.mlir b/test/lit_tests/structured_tensors/partial_symmetry.mlir index 2ee9d7d8f..5303cfd2b 100644 --- a/test/lit_tests/structured_tensors/partial_symmetry.mlir +++ b/test/lit_tests/structured_tensors/partial_symmetry.mlir @@ -39,23 +39,23 @@ func.func @test_add_generate_symmetry(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x3x // CHECK-NEXT: return %1 : tensor<3x2x3xf32> // CHECK-NEXT: } -func.func @test_dot_propagate(%arg0: tensor<2x2x3xf32> {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]}) -> tensor<2x2xf32> { +func.func @test_dot_general_propagate(%arg0: tensor<2x2x3xf32> {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]}) -> tensor<2x2xf32> { %cst1 = stablehlo.constant dense<[[[1.0, 2.0], [2.0, 3.0]], [[2.0, 3.0], [3.0, 4.0]], [[2.0, 3.0], [3.0, 4.0]]]> : tensor<3x2x2xf32> %0 = stablehlo.dot_general %arg0, %cst1, batching_dims = [0, 1] x [1, 2], contracting_dims = [2] x [0] : (tensor<2x2x3xf32>, tensor<3x2x2xf32>) -> tensor<2x2xf32> return %0 : tensor<2x2xf32> } -// CHECK: func.func @test_dot_propagate(%arg0: tensor<2x2x3xf32> {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]}) -> tensor<2x2xf32> { +// CHECK: func.func @test_dot_general_propagate(%arg0: tensor<2x2x3xf32> {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]}) -> tensor<2x2xf32> { // CHECK-NEXT: %cst = stablehlo.constant {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[1, 2]>>]} dense<{{.*}}> : tensor<3x2x2xf32> // CHECK-NEXT: %0 = stablehlo.dot_general %arg0, %cst, batching_dims = [0, 1] x [1, 2], contracting_dims = [2] x [0] {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]} : (tensor<2x2x3xf32>, tensor<3x2x2xf32>) -> tensor<2x2xf32> // CHECK-NEXT: return %0 : tensor<2x2xf32> // CHECK-NEXT: } -func.func @test_dot_generate_symmetry(%arg0: tensor<3x3x3xf32>) -> tensor<3x3x3xf32> { +func.func @test_dot_general_generate_symmetry(%arg0: tensor<3x3x3xf32>) -> tensor<3x3x3xf32> { %0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<3x3x3xf32>) -> tensor<3x3x3xf32> %1 = stablehlo.dot_general %arg0, %0, batching_dims = [1] x [1], contracting_dims = [0] x [2] : (tensor<3x3x3xf32>, tensor<3x3x3xf32>) -> tensor<3x3x3xf32> return %1 : tensor<3x3x3xf32> } -// CHECK: func.func @test_dot_generate_symmetry(%arg0: tensor<3x3x3xf32>) -> tensor<3x3x3xf32> { +// CHECK: func.func @test_dot_general_generate_symmetry(%arg0: tensor<3x3x3xf32>) -> tensor<3x3x3xf32> { // CHECK-NEXT: %0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<3x3x3xf32>) -> tensor<3x3x3xf32> // CHECK-NEXT: %1 = stablehlo.dot_general %arg0, %0, batching_dims = [1] x [1], contracting_dims = [0] x [2] {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[1, 2]>>]} : (tensor<3x3x3xf32>, tensor<3x3x3xf32>) -> tensor<3x3x3xf32> // CHECK-NEXT: return %1 : tensor<3x3x3xf32> From 62b180ee7b0d7a04e1fd9e73983e60c0dcb0fac3 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Tue, 2 Dec 2025 18:39:05 +0000 Subject: [PATCH 16/21] Add missing symmetry detection for broadcast in dim --- .../jax/Analysis/PartialSymmetryAnalysis.cpp | 5 +++-- .../structured_tensors/partial_symmetry.mlir | 17 +++++++++++++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp index 19f489ceb..7aaf89261 100644 --- a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp +++ b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp @@ -156,8 +156,9 @@ PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateBroadcastInDim( int64_t inputDim = outputToInput[outputDim]; result.dimensionSetIDs[outputDim] = annotation.getSetId(inputDim); } else { - // broadcasted dimension => new ID - result.dimensionSetIDs[outputDim] = nextSetId++; + // result is constant in each broadcasted dimension, + // so they are partially symmetric with each other + result.dimensionSetIDs[outputDim] = nextSetId; } } diff --git a/test/lit_tests/structured_tensors/partial_symmetry.mlir b/test/lit_tests/structured_tensors/partial_symmetry.mlir index 5303cfd2b..f22176468 100644 --- a/test/lit_tests/structured_tensors/partial_symmetry.mlir +++ b/test/lit_tests/structured_tensors/partial_symmetry.mlir @@ -61,3 +61,20 @@ func.func @test_dot_general_generate_symmetry(%arg0: tensor<3x3x3xf32>) -> tenso // CHECK-NEXT: return %1 : tensor<3x3x3xf32> // CHECK-NEXT: } +func.func @test_scalar_multiply(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + %cst = stablehlo.constant dense<9.900000e-01> : tensor + %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32> + %1 = stablehlo.add %0, %arg0 : tensor<2x2xf32> + %2 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<2x2xf32> + %3 = stablehlo.multiply %1, %2 : tensor<2x2xf32> + return %3 : tensor<2x2xf32> +} +// CHECK: func.func @test_scalar_multiply(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { +// CHECK-NEXT: %cst = stablehlo.constant dense<9.900000e-01> : tensor +// CHECK-NEXT: %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32> +// CHECK-NEXT: %1 = stablehlo.add %0, %arg0 {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]} : tensor<2x2xf32> +// CHECK-NEXT: %2 = stablehlo.broadcast_in_dim %cst, dims = [] {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]} : (tensor) -> tensor<2x2xf32> +// CHECK-NEXT: %3 = stablehlo.multiply %1, %2 {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]} : tensor<2x2xf32> +// CHECK-NEXT: return %3 : tensor<2x2xf32> +// CHECK-NEXT: } + From 86e02cb88488813c510a6e83408701ded2d13f81 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Tue, 2 Dec 2025 21:08:07 +0000 Subject: [PATCH 17/21] Check for lhs == rhs in aliasing check --- .../jax/Analysis/PartialSymmetryAnalysis.cpp | 64 ++++++++++--------- .../structured_tensors/partial_symmetry.mlir | 6 +- 2 files changed, 39 insertions(+), 31 deletions(-) diff --git a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp index 7aaf89261..1f90079c8 100644 --- a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp +++ b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp @@ -593,20 +593,23 @@ LogicalResult PartialSymmetryAnalysis::visitOperation( // Check for aliasing between LHS and RHS (up to transpose) bool rhsAliasesLhs = false; SmallVector rhsDimToLhs; - if (auto lhsT = lhs.getDefiningOp()) { - if (rhs == lhsT.getOperand()) { - rhsDimToLhs.assign(lhsT.getPermutation().begin(), - lhsT.getPermutation().end()); - rhsAliasesLhs = true; - } - } - if (auto rhsT = rhs.getDefiningOp()) { - if (lhs == rhsT.getOperand()) { - rhsDimToLhs.resize(rhsT.getPermutation().size()); - for (int64_t i = 0; i < (int64_t)rhsT.getPermutation().size(); ++i) - rhsDimToLhs[rhsT.getPermutation()[i]] = i; - rhsAliasesLhs = true; - } + if (lhs == rhs) { + auto lhsType = cast(lhs.getType()); + rhsDimToLhs.resize(lhsType.getRank()); + for (int64_t i = 0; i < lhsType.getRank(); ++i) + rhsDimToLhs[i] = i; // Identity mapping + rhsAliasesLhs = true; + } else if (auto lhsT = lhs.getDefiningOp(); + lhsT && rhs == lhsT.getOperand()) { + rhsDimToLhs.assign(lhsT.getPermutation().begin(), + lhsT.getPermutation().end()); + rhsAliasesLhs = true; + } else if (auto rhsT = rhs.getDefiningOp(); + rhsT && lhs == rhsT.getOperand()) { + rhsDimToLhs.resize(rhsT.getPermutation().size()); + for (int64_t i = 0; i < (int64_t)rhsT.getPermutation().size(); ++i) + rhsDimToLhs[rhsT.getPermutation()[i]] = i; + rhsAliasesLhs = true; } // Propagate symmetry through dotGeneral @@ -634,23 +637,26 @@ LogicalResult PartialSymmetryAnalysis::visitOperation( auto lhs = op->getOperand(0); auto rhs = op->getOperand(1); + // Check for aliasing between LHS and RHS (up to transpose) bool rhsAliasesLhs = false; SmallVector rhsDimToLhs; - - if (auto lhsT = lhs.getDefiningOp()) { - if (rhs == lhsT.getOperand()) { - rhsDimToLhs.assign(lhsT.getPermutation().begin(), - lhsT.getPermutation().end()); - rhsAliasesLhs = true; - } - } - if (auto rhsT = rhs.getDefiningOp()) { - if (lhs == rhsT.getOperand()) { - rhsDimToLhs.resize(rhsT.getPermutation().size()); - for (int64_t i = 0; i < (int64_t)rhsT.getPermutation().size(); ++i) - rhsDimToLhs[rhsT.getPermutation()[i]] = i; - rhsAliasesLhs = true; - } + if (lhs == rhs) { + auto lhsType = cast(lhs.getType()); + rhsDimToLhs.resize(lhsType.getRank()); + for (int64_t i = 0; i < lhsType.getRank(); ++i) + rhsDimToLhs[i] = i; // Identity mapping + rhsAliasesLhs = true; + } else if (auto lhsT = lhs.getDefiningOp(); + lhsT && rhs == lhsT.getOperand()) { + rhsDimToLhs.assign(lhsT.getPermutation().begin(), + lhsT.getPermutation().end()); + rhsAliasesLhs = true; + } else if (auto rhsT = rhs.getDefiningOp(); + rhsT && lhs == rhsT.getOperand()) { + rhsDimToLhs.resize(rhsT.getPermutation().size()); + for (int64_t i = 0; i < (int64_t)rhsT.getPermutation().size(); ++i) + rhsDimToLhs[rhsT.getPermutation()[i]] = i; + rhsAliasesLhs = true; } propagatedAnnotation[0] = diff --git a/test/lit_tests/structured_tensors/partial_symmetry.mlir b/test/lit_tests/structured_tensors/partial_symmetry.mlir index f22176468..1869df03b 100644 --- a/test/lit_tests/structured_tensors/partial_symmetry.mlir +++ b/test/lit_tests/structured_tensors/partial_symmetry.mlir @@ -42,12 +42,14 @@ func.func @test_add_generate_symmetry(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x3x func.func @test_dot_general_propagate(%arg0: tensor<2x2x3xf32> {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]}) -> tensor<2x2xf32> { %cst1 = stablehlo.constant dense<[[[1.0, 2.0], [2.0, 3.0]], [[2.0, 3.0], [3.0, 4.0]], [[2.0, 3.0], [3.0, 4.0]]]> : tensor<3x2x2xf32> %0 = stablehlo.dot_general %arg0, %cst1, batching_dims = [0, 1] x [1, 2], contracting_dims = [2] x [0] : (tensor<2x2x3xf32>, tensor<3x2x2xf32>) -> tensor<2x2xf32> - return %0 : tensor<2x2xf32> + %1 = stablehlo.dot_general %0, %0, contracting_dims = [1] x [0] : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %1 : tensor<2x2xf32> } // CHECK: func.func @test_dot_general_propagate(%arg0: tensor<2x2x3xf32> {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]}) -> tensor<2x2xf32> { // CHECK-NEXT: %cst = stablehlo.constant {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[1, 2]>>]} dense<{{.*}}> : tensor<3x2x2xf32> // CHECK-NEXT: %0 = stablehlo.dot_general %arg0, %cst, batching_dims = [0, 1] x [1, 2], contracting_dims = [2] x [0] {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]} : (tensor<2x2x3xf32>, tensor<3x2x2xf32>) -> tensor<2x2xf32> -// CHECK-NEXT: return %0 : tensor<2x2xf32> +// CHECK-NEXT: %1 = stablehlo.dot_general %0, %0, contracting_dims = [1] x [0] {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> +// CHECK-NEXT: return %1 : tensor<2x2xf32> // CHECK-NEXT: } func.func @test_dot_general_generate_symmetry(%arg0: tensor<3x3x3xf32>) -> tensor<3x3x3xf32> { From aa46d9935cf3c2c4782fbcb0954aa9d92bde0e7f Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Wed, 3 Dec 2025 19:02:47 +0000 Subject: [PATCH 18/21] Make symmetry detection within DotGeneral LHS/RHS free dims run even when LHS and RHS do not alias --- .../jax/Analysis/PartialSymmetryAnalysis.cpp | 94 +++++++++---------- .../structured_tensors/partial_symmetry.mlir | 12 ++- 2 files changed, 55 insertions(+), 51 deletions(-) diff --git a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp index 1f90079c8..ece1d24e4 100644 --- a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp +++ b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp @@ -196,7 +196,7 @@ PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateDotGeneral( PartialSymmetryAnnotation result = createNotSymmetric(resultRank); - // Preserve symmetry in batching dimensions + // Symmetry between batching dimensions for (int64_t i = 0; i < (int64_t)lhsBatchingDims.size(); ++i) { for (int64_t j = 0; j < i; ++j) { if (lhsAnnotation.getSetId(lhsBatchingDims[i]) == @@ -208,7 +208,47 @@ PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateDotGeneral( } } - // Preserve symmetry in free (non-contracting, non-batching) dimensions + // Calculate free (non-contracting, non-batching) dimensions + SmallVector lhsFreeDims; + for (int64_t i = 0; i < lhsAnnotation.getRank(); ++i) { + if (!llvm::is_contained(lhsBatchingDims, i) && + !llvm::is_contained(lhsContractingDims, i)) { + lhsFreeDims.push_back(i); + } + } + + SmallVector rhsFreeDims; + for (int64_t i = 0; i < rhsAnnotation.getRank(); ++i) { + if (!llvm::is_contained(rhsBatchingDims, i) && + !llvm::is_contained(rhsContractingDims, i)) { + rhsFreeDims.push_back(i); + } + } + + // Symmetry between free dimensions from LHS + for (int64_t i = 0; i < (int64_t)lhsFreeDims.size(); ++i) { + for (int64_t j = 0; j < i; ++j) { + if (lhsAnnotation.getSetId(lhsFreeDims[i]) == + lhsAnnotation.getSetId(lhsFreeDims[j])) { + result.uniteDimensionSets(resultRank, lhsBatchingDims.size() + i, + lhsBatchingDims.size() + j); + } + } + } + + // Symmetry between free dimensions from RHS + for (int64_t i = 0; i < (int64_t)rhsFreeDims.size(); ++i) { + for (int64_t j = 0; j < i; ++j) { + if (rhsAnnotation.getSetId(rhsFreeDims[i]) == + rhsAnnotation.getSetId(rhsFreeDims[j])) { + result.uniteDimensionSets( + resultRank, lhsBatchingDims.size() + lhsFreeDims.size() + i, + lhsBatchingDims.size() + lhsFreeDims.size() + j); + } + } + } + + // Symmetry between free dimensions of LHS and free dimensions of RHS if (rhsAliasesLhs) { bool exchange_valid = true; @@ -236,53 +276,13 @@ PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateDotGeneral( } if (exchange_valid) { - SmallVector lhsResultDims; - for (int64_t i = 0; i < lhsAnnotation.getRank(); ++i) { - if (!llvm::is_contained(lhsBatchingDims, i) && - !llvm::is_contained(lhsContractingDims, i)) { - lhsResultDims.push_back(i); - } - } - - SmallVector rhsResultDims; - for (int64_t i = 0; i < rhsAnnotation.getRank(); ++i) { - if (!llvm::is_contained(rhsBatchingDims, i) && - !llvm::is_contained(rhsContractingDims, i)) { - rhsResultDims.push_back(i); - } - } - - // Symmetry within free dimensions of LHS - for (int64_t i = 0; i < (int64_t)lhsResultDims.size(); ++i) { - for (int64_t j = 0; j < i; ++j) { - if (lhsAnnotation.getSetId(lhsResultDims[i]) == - lhsAnnotation.getSetId(lhsResultDims[j])) { - result.uniteDimensionSets(resultRank, lhsBatchingDims.size() + i, - lhsBatchingDims.size() + j); - } - } - } - - // Symmetry between free dimensions of RHS - for (int64_t i = 0; i < (int64_t)rhsResultDims.size(); ++i) { - for (int64_t j = 0; j < i; ++j) { - if (rhsAnnotation.getSetId(rhsResultDims[i]) == - rhsAnnotation.getSetId(rhsResultDims[j])) { - result.uniteDimensionSets( - resultRank, lhsBatchingDims.size() + lhsResultDims.size() + i, - lhsBatchingDims.size() + lhsResultDims.size() + j); - } - } - } - - // Symmetry between free dimensions of LHS and RHS - for (int64_t i = 0; i < (int64_t)lhsResultDims.size(); ++i) { - for (int64_t j = 0; j < (int64_t)rhsResultDims.size(); ++j) { - if (lhsAnnotation.getSetId(lhsResultDims[i]) == - lhsAnnotation.getSetId(rhsDimToLhs[rhsResultDims[j]])) { + for (int64_t i = 0; i < (int64_t)lhsFreeDims.size(); ++i) { + for (int64_t j = 0; j < (int64_t)rhsFreeDims.size(); ++j) { + if (lhsAnnotation.getSetId(lhsFreeDims[i]) == + lhsAnnotation.getSetId(rhsDimToLhs[rhsFreeDims[j]])) { result.uniteDimensionSets(resultRank, lhsBatchingDims.size() + i, lhsBatchingDims.size() + - lhsResultDims.size() + j); + lhsFreeDims.size() + j); } } } diff --git a/test/lit_tests/structured_tensors/partial_symmetry.mlir b/test/lit_tests/structured_tensors/partial_symmetry.mlir index 1869df03b..c4c007d91 100644 --- a/test/lit_tests/structured_tensors/partial_symmetry.mlir +++ b/test/lit_tests/structured_tensors/partial_symmetry.mlir @@ -39,17 +39,21 @@ func.func @test_add_generate_symmetry(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x3x // CHECK-NEXT: return %1 : tensor<3x2x3xf32> // CHECK-NEXT: } -func.func @test_dot_general_propagate(%arg0: tensor<2x2x3xf32> {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]}) -> tensor<2x2xf32> { +func.func @test_dot_general_propagate(%arg0: tensor<2x2x3xf32> {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]}) -> tensor<2x2x2x2xf32> { %cst1 = stablehlo.constant dense<[[[1.0, 2.0], [2.0, 3.0]], [[2.0, 3.0], [3.0, 4.0]], [[2.0, 3.0], [3.0, 4.0]]]> : tensor<3x2x2xf32> %0 = stablehlo.dot_general %arg0, %cst1, batching_dims = [0, 1] x [1, 2], contracting_dims = [2] x [0] : (tensor<2x2x3xf32>, tensor<3x2x2xf32>) -> tensor<2x2xf32> %1 = stablehlo.dot_general %0, %0, contracting_dims = [1] x [0] : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - return %1 : tensor<2x2xf32> + %tmp = stablehlo.broadcast_in_dim %1, dims = [0, 1] : (tensor<2x2xf32>) -> tensor<2x2x3xf32> + %2 = stablehlo.dot_general %arg0, %tmp, contracting_dims = [2] x [2] : (tensor<2x2x3xf32>, tensor<2x2x3xf32>) -> tensor<2x2x2x2xf32> + return %2 : tensor<2x2x2x2xf32> } -// CHECK: func.func @test_dot_general_propagate(%arg0: tensor<2x2x3xf32> {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]}) -> tensor<2x2xf32> { +// CHECK: func.func @test_dot_general_propagate(%arg0: tensor<2x2x3xf32> {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]}) -> tensor<2x2x2x2xf32> { // CHECK-NEXT: %cst = stablehlo.constant {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[1, 2]>>]} dense<{{.*}}> : tensor<3x2x2xf32> // CHECK-NEXT: %0 = stablehlo.dot_general %arg0, %cst, batching_dims = [0, 1] x [1, 2], contracting_dims = [2] x [0] {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]} : (tensor<2x2x3xf32>, tensor<3x2x2xf32>) -> tensor<2x2xf32> // CHECK-NEXT: %1 = stablehlo.dot_general %0, %0, contracting_dims = [1] x [0] {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> -// CHECK-NEXT: return %1 : tensor<2x2xf32> +// CHECK-NEXT: %2 = stablehlo.broadcast_in_dim %1, dims = [0, 1] {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]} : (tensor<2x2xf32>) -> tensor<2x2x3xf32> +// CHECK-NEXT: %3 = stablehlo.dot_general %arg0, %2, contracting_dims = [2] x [2] {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>, <[2, 3]>>]} : (tensor<2x2x3xf32>, tensor<2x2x3xf32>) -> tensor<2x2x2x2xf32> +// CHECK-NEXT: return %3 : tensor<2x2x2x2xf32> // CHECK-NEXT: } func.func @test_dot_general_generate_symmetry(%arg0: tensor<3x3x3xf32>) -> tensor<3x3x3xf32> { From c89295592f4839b218befe0818ac2132abaf3986 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Thu, 4 Dec 2025 04:29:43 +0000 Subject: [PATCH 19/21] Fix probably incorrect transfer function for elementwise --- .../jax/Analysis/PartialSymmetryAnalysis.cpp | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp index ece1d24e4..2c55e4c31 100644 --- a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp +++ b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp @@ -174,13 +174,17 @@ PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateElementwiseBinary( PartialSymmetryAnnotation result = join(lhsAnnotation, rhsAnnotation); if (rhsAliasesLhs) { + int64_t changed_dim = -1; + int changed_dims = 0; for (int64_t i = 0; i < resultRank; ++i) { - int64_t j = rhsDimToLhs[i]; - if (rhsDimToLhs[j] == i) { - result.uniteDimensionSets(resultRank, i, j); + if (rhsDimToLhs[i] != i) { + changed_dim = i; + changed_dims++; } } - + if (changed_dims == 2) { + result.uniteDimensionSets(resultRank, changed_dim, rhsDimToLhs[changed_dim]); + } result.canonicalize(); } From 4480fbcdc7d9436624751ba3eb8c674fca734a43 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Thu, 4 Dec 2025 05:51:45 +0000 Subject: [PATCH 20/21] Fix join/meet naming --- .../jax/Analysis/PartialSymmetryAnalysis.cpp | 16 ++++++++-------- .../jax/Analysis/PartialSymmetryAnalysis.h | 8 ++++---- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp index 2c55e4c31..e3b4e4769 100644 --- a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp +++ b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp @@ -83,7 +83,7 @@ void PartialSymmetryAnnotation::uniteDimensionSets(int64_t rank, int64_t i, } PartialSymmetryAnnotation -PartialSymmetryAnnotation::join(const PartialSymmetryAnnotation &lhs, +PartialSymmetryAnnotation::meet(const PartialSymmetryAnnotation &lhs, const PartialSymmetryAnnotation &rhs) { PartialSymmetryAnnotation result = createNotSymmetric(lhs.getRank()); @@ -102,7 +102,7 @@ PartialSymmetryAnnotation::join(const PartialSymmetryAnnotation &lhs, } PartialSymmetryAnnotation -PartialSymmetryAnnotation::meet(const PartialSymmetryAnnotation &lhs, +PartialSymmetryAnnotation::join(const PartialSymmetryAnnotation &lhs, const PartialSymmetryAnnotation &rhs) { PartialSymmetryAnnotation result = createNotSymmetric(lhs.getRank()); @@ -171,7 +171,7 @@ PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateElementwiseBinary( const PartialSymmetryAnnotation &rhsAnnotation, int64_t resultRank, bool rhsAliasesLhs, ArrayRef rhsDimToLhs) { - PartialSymmetryAnnotation result = join(lhsAnnotation, rhsAnnotation); + PartialSymmetryAnnotation result = meet(lhsAnnotation, rhsAnnotation); if (rhsAliasesLhs) { int64_t changed_dim = -1; @@ -525,14 +525,14 @@ PartialSymmetryLattice::PartialSymmetryLattice(Value v) } } -ChangeResult PartialSymmetryLattice::join(const AbstractSparseLattice &rhs) { +ChangeResult PartialSymmetryLattice::meet(const AbstractSparseLattice &rhs) { const auto *rhsStruct = reinterpret_cast(&rhs); - return join(*rhsStruct); + return meet(*rhsStruct); } -ChangeResult PartialSymmetryLattice::join(const PartialSymmetryLattice &rhs) { - auto newValue = PartialSymmetryAnnotation::join(getValue(), rhs.getValue()); +ChangeResult PartialSymmetryLattice::meet(const PartialSymmetryLattice &rhs) { + auto newValue = PartialSymmetryAnnotation::meet(getValue(), rhs.getValue()); if (getValue() == newValue) return ChangeResult::NoChange; @@ -683,7 +683,7 @@ LogicalResult PartialSymmetryAnalysis::visitOperation( if (updatedAnnotation[i]) { auto resultOrig = results[i]->getValue(); auto resultNew = - PartialSymmetryAnnotation::join(resultOrig, propagatedAnnotation[i]); + PartialSymmetryAnnotation::meet(resultOrig, propagatedAnnotation[i]); results[i]->setValue(resultNew); propagateIfChanged(results[i], resultNew == resultOrig ? ChangeResult::NoChange diff --git a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h index 54ccdc9e5..3c20ae6d3 100644 --- a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h +++ b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h @@ -29,10 +29,10 @@ class PartialSymmetryAnnotation { int64_t getSetId(int64_t i) const { return dimensionSetIDs[i]; } int64_t getRank() const { return dimensionSetIDs.size(); } - static PartialSymmetryAnnotation join(const PartialSymmetryAnnotation &lhs, - const PartialSymmetryAnnotation &rhs); static PartialSymmetryAnnotation meet(const PartialSymmetryAnnotation &lhs, const PartialSymmetryAnnotation &rhs); + static PartialSymmetryAnnotation join(const PartialSymmetryAnnotation &lhs, + const PartialSymmetryAnnotation &rhs); static PartialSymmetryAnnotation propagateTranspose(const PartialSymmetryAnnotation &annotation, @@ -85,8 +85,8 @@ class PartialSymmetryLattice : public dataflow::AbstractSparseLattice { PartialSymmetryLattice(Value v); - ChangeResult join(const AbstractSparseLattice &rhs) override; - ChangeResult join(const PartialSymmetryLattice &rhs); + ChangeResult meet(const AbstractSparseLattice &rhs) override; + ChangeResult meet(const PartialSymmetryLattice &rhs); void print(raw_ostream &os) const override; From 0bf5740aaeb12cecede53d29bd13d1871a41def0 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Wed, 10 Dec 2025 16:42:19 +0000 Subject: [PATCH 21/21] Format --- src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp index e3b4e4769..99c1bcb71 100644 --- a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp +++ b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp @@ -183,7 +183,8 @@ PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateElementwiseBinary( } } if (changed_dims == 2) { - result.uniteDimensionSets(resultRank, changed_dim, rhsDimToLhs[changed_dim]); + result.uniteDimensionSets(resultRank, changed_dim, + rhsDimToLhs[changed_dim]); } result.canonicalize(); }