diff --git a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp new file mode 100644 index 000000000..99c1bcb71 --- /dev/null +++ b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp @@ -0,0 +1,699 @@ +#include "src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h" +#include "src/enzyme_ad/jax/Utils.h" + +#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "src/enzyme_ad/jax/Dialect/Dialect.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "llvm/ADT/DenseMap.h" + +using namespace mlir; +using namespace mlir::dataflow; + +namespace mlir { +namespace enzyme { + +//===----------------------------------------------------------------------===// +// PartialSymmetryAnnotation Implementation +//===----------------------------------------------------------------------===// + +PartialSymmetryAnnotation::PartialSymmetryAnnotation( + ArrayRef dimensionSetIDs) { + this->dimensionSetIDs.assign(dimensionSetIDs.begin(), dimensionSetIDs.end()); + canonicalize(); +} + +PartialSymmetryAnnotation +PartialSymmetryAnnotation::createFullySymmetric(int64_t rank) { + PartialSymmetryAnnotation annotation; + for (int64_t i = 0; i < rank; ++i) { + annotation.dimensionSetIDs.push_back(0); + } + return annotation; +} + +PartialSymmetryAnnotation +PartialSymmetryAnnotation::createNotSymmetric(int64_t rank) { + PartialSymmetryAnnotation annotation; + for (int64_t i = 0; i < rank; ++i) { + annotation.dimensionSetIDs.push_back(i); + } + return annotation; +} + +PartialSymmetryAnnotation +PartialSymmetryAnnotation::createUninitialized(int64_t rank) { + PartialSymmetryAnnotation annotation; + annotation.dimensionSetIDs.resize(rank); + return annotation; +} + +bool PartialSymmetryAnnotation::isSymmetric(int64_t i, int64_t j) const { + return dimensionSetIDs[i] == dimensionSetIDs[j]; +} + +void PartialSymmetryAnnotation::canonicalize() { + llvm::SmallDenseMap map; + int64_t nextId = 0; + for (auto &id : dimensionSetIDs) { + if (map.find(id) == map.end()) { + map[id] = nextId++; + } + id = map[id]; + } +} + +void PartialSymmetryAnnotation::uniteDimensionSets(int64_t rank, int64_t i, + int64_t j) { + if (dimensionSetIDs[i] == dimensionSetIDs[j]) + return; + + int64_t oldId = dimensionSetIDs[i]; + int64_t newId = dimensionSetIDs[j]; + for (int64_t k = 0; k < (int64_t)dimensionSetIDs.size(); ++k) { + if (dimensionSetIDs[k] == oldId) { + dimensionSetIDs[k] = newId; + } + } + + canonicalize(); +} + +PartialSymmetryAnnotation +PartialSymmetryAnnotation::meet(const PartialSymmetryAnnotation &lhs, + const PartialSymmetryAnnotation &rhs) { + PartialSymmetryAnnotation result = createNotSymmetric(lhs.getRank()); + + for (int64_t i = 0; i < lhs.getRank(); ++i) { + bool found = false; + for (int64_t j = 0; j < i; ++j) { + if (lhs.getSetId(i) == lhs.getSetId(j) && + rhs.getSetId(i) == rhs.getSetId(j)) { + result.uniteDimensionSets(lhs.getRank(), i, j); + } + } + } + + result.canonicalize(); + return result; +} + +PartialSymmetryAnnotation +PartialSymmetryAnnotation::join(const PartialSymmetryAnnotation &lhs, + const PartialSymmetryAnnotation &rhs) { + PartialSymmetryAnnotation result = createNotSymmetric(lhs.getRank()); + + for (int64_t i = 0; i < lhs.getRank(); ++i) { + for (int64_t j = 0; j < i; ++j) { + if (lhs.getSetId(i) == lhs.getSetId(j) || + rhs.getSetId(i) == rhs.getSetId(j)) { + result.uniteDimensionSets(lhs.getRank(), i, j); + } + } + } + + result.canonicalize(); + return result; +} + +PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateTranspose( + const PartialSymmetryAnnotation &annotation, + ArrayRef permutation) { + + PartialSymmetryAnnotation result = createUninitialized(annotation.getRank()); + + for (int64_t i = 0; i < annotation.getRank(); ++i) { + result.dimensionSetIDs[i] = annotation.getSetId(permutation[i]); + } + + result.canonicalize(); + return result; +} + +PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateBroadcastInDim( + const PartialSymmetryAnnotation &annotation, int64_t outputRank, + ArrayRef broadcastDimensions) { + + PartialSymmetryAnnotation result = createUninitialized(outputRank); + + llvm::SmallDenseMap outputToInput; + for (int64_t i = 0; i < (int64_t)broadcastDimensions.size(); ++i) { + outputToInput[broadcastDimensions[i]] = i; + } + + int64_t maxSetId = -1; + for (int64_t i = 0; i < annotation.getRank(); ++i) { + maxSetId = std::max(maxSetId, (int64_t)annotation.getSetId(i)); + } + + int64_t nextSetId = maxSetId + 1; + for (int64_t outputDim = 0; outputDim < outputRank; ++outputDim) { + if (outputToInput.find(outputDim) != outputToInput.end()) { + // dimension is preserved => use old ID + int64_t inputDim = outputToInput[outputDim]; + result.dimensionSetIDs[outputDim] = annotation.getSetId(inputDim); + } else { + // result is constant in each broadcasted dimension, + // so they are partially symmetric with each other + result.dimensionSetIDs[outputDim] = nextSetId; + } + } + + result.canonicalize(); + return result; +} + +PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateElementwiseBinary( + const PartialSymmetryAnnotation &lhsAnnotation, + const PartialSymmetryAnnotation &rhsAnnotation, int64_t resultRank, + bool rhsAliasesLhs, ArrayRef rhsDimToLhs) { + + PartialSymmetryAnnotation result = meet(lhsAnnotation, rhsAnnotation); + + if (rhsAliasesLhs) { + int64_t changed_dim = -1; + int changed_dims = 0; + for (int64_t i = 0; i < resultRank; ++i) { + if (rhsDimToLhs[i] != i) { + changed_dim = i; + changed_dims++; + } + } + if (changed_dims == 2) { + result.uniteDimensionSets(resultRank, changed_dim, + rhsDimToLhs[changed_dim]); + } + result.canonicalize(); + } + + return result; +} + +PartialSymmetryAnnotation PartialSymmetryAnnotation::propagateDotGeneral( + const PartialSymmetryAnnotation &lhsAnnotation, + const PartialSymmetryAnnotation &rhsAnnotation, int64_t resultRank, + ArrayRef lhsBatchingDims, ArrayRef rhsBatchingDims, + ArrayRef lhsContractingDims, ArrayRef rhsContractingDims, + bool rhsAliasesLhs, ArrayRef rhsDimToLhs) { + + PartialSymmetryAnnotation result = createNotSymmetric(resultRank); + + // Symmetry between batching dimensions + for (int64_t i = 0; i < (int64_t)lhsBatchingDims.size(); ++i) { + for (int64_t j = 0; j < i; ++j) { + if (lhsAnnotation.getSetId(lhsBatchingDims[i]) == + lhsAnnotation.getSetId(lhsBatchingDims[j]) && + rhsAnnotation.getSetId(rhsBatchingDims[i]) == + rhsAnnotation.getSetId(rhsBatchingDims[j])) { + result.uniteDimensionSets(resultRank, i, j); + } + } + } + + // Calculate free (non-contracting, non-batching) dimensions + SmallVector lhsFreeDims; + for (int64_t i = 0; i < lhsAnnotation.getRank(); ++i) { + if (!llvm::is_contained(lhsBatchingDims, i) && + !llvm::is_contained(lhsContractingDims, i)) { + lhsFreeDims.push_back(i); + } + } + + SmallVector rhsFreeDims; + for (int64_t i = 0; i < rhsAnnotation.getRank(); ++i) { + if (!llvm::is_contained(rhsBatchingDims, i) && + !llvm::is_contained(rhsContractingDims, i)) { + rhsFreeDims.push_back(i); + } + } + + // Symmetry between free dimensions from LHS + for (int64_t i = 0; i < (int64_t)lhsFreeDims.size(); ++i) { + for (int64_t j = 0; j < i; ++j) { + if (lhsAnnotation.getSetId(lhsFreeDims[i]) == + lhsAnnotation.getSetId(lhsFreeDims[j])) { + result.uniteDimensionSets(resultRank, lhsBatchingDims.size() + i, + lhsBatchingDims.size() + j); + } + } + } + + // Symmetry between free dimensions from RHS + for (int64_t i = 0; i < (int64_t)rhsFreeDims.size(); ++i) { + for (int64_t j = 0; j < i; ++j) { + if (rhsAnnotation.getSetId(rhsFreeDims[i]) == + rhsAnnotation.getSetId(rhsFreeDims[j])) { + result.uniteDimensionSets( + resultRank, lhsBatchingDims.size() + lhsFreeDims.size() + i, + lhsBatchingDims.size() + lhsFreeDims.size() + j); + } + } + } + + // Symmetry between free dimensions of LHS and free dimensions of RHS + if (rhsAliasesLhs) { + bool exchange_valid = true; + + // check that each batching dimension has same ID for LHS and RHS + for (int64_t i = 0; i < (int64_t)lhsBatchingDims.size(); ++i) { + if (lhsAnnotation.getSetId(lhsBatchingDims[i]) != + lhsAnnotation.getSetId(rhsDimToLhs[rhsBatchingDims[i]])) { + exchange_valid = false; + } + } + + // check that the multiset of IDs for contracting dimensions are equal for + // LHS and RHS + SmallVector lhsContractingIds, rhsContractingIds; + for (int64_t dim : lhsContractingDims) { + lhsContractingIds.push_back(lhsAnnotation.getSetId(dim)); + } + for (int64_t dim : rhsContractingDims) { + rhsContractingIds.push_back(lhsAnnotation.getSetId(rhsDimToLhs[dim])); + } + llvm::sort(lhsContractingIds); + llvm::sort(rhsContractingIds); + if (lhsContractingIds != rhsContractingIds) { + exchange_valid = false; + } + + if (exchange_valid) { + for (int64_t i = 0; i < (int64_t)lhsFreeDims.size(); ++i) { + for (int64_t j = 0; j < (int64_t)rhsFreeDims.size(); ++j) { + if (lhsAnnotation.getSetId(lhsFreeDims[i]) == + lhsAnnotation.getSetId(rhsDimToLhs[rhsFreeDims[j]])) { + result.uniteDimensionSets(resultRank, lhsBatchingDims.size() + i, + lhsBatchingDims.size() + + lhsFreeDims.size() + j); + } + } + } + } + } + + result.canonicalize(); + return result; +} + +template +static bool checkPairwiseSymmetry(DenseElementsAttr attr, int64_t dimA, + int64_t dimB) { + auto type = cast(attr.getType()); + auto shape = type.getShape(); + int64_t rank = type.getRank(); + + if (shape[dimA] != shape[dimB]) + return false; + + if (attr.isSplat()) + return true; + + auto values = attr.getValues(); + auto it = values.begin(); + + SmallVector strides(rank); + int64_t currentStride = 1; + for (int64_t i = rank - 1; i >= 0; --i) { + strides[i] = currentStride; + currentStride *= shape[i]; + } + + int64_t numElements = 1; + for (int64_t s : shape) + numElements *= s; + + for (int64_t i = 0; i < numElements; ++i) { + SmallVector coords(rank); + int64_t temp = i; + for (int64_t d = 0; d < rank; ++d) { + coords[d] = temp / strides[d]; + temp %= strides[d]; + } + + std::swap(coords[dimA], coords[dimB]); + + int64_t swappedIdx = 0; + for (int64_t d = 0; d < rank; ++d) { + swappedIdx += coords[d] * strides[d]; + } + + auto a = *(it + i); + auto b = *(it + swappedIdx); + if (checkNotEqual(a, b)) + return false; + } + return true; +} + +PartialSymmetryAnnotation +PartialSymmetryAnnotation::checkConstant(DenseElementsAttr attr) { + if (auto type = dyn_cast(attr.getType())) { + int64_t rank = type.getRank(); + PartialSymmetryAnnotation result = createNotSymmetric(rank); + + for (int64_t i = 0; i < rank; ++i) { + for (int64_t j = 0; j < i; ++j) { + bool isSymmetric = false; + if (isa(attr.getElementType())) { + isSymmetric = checkPairwiseSymmetry(attr, i, j); + } else if (isa(attr.getElementType())) { + isSymmetric = checkPairwiseSymmetry(attr, i, j); + } + + if (isSymmetric) { + result.uniteDimensionSets(rank, i, j); + continue; + } + } + } + return result; + } + return PartialSymmetryAnnotation(); +} + +SmallVector> +PartialSymmetryAnnotation::getDimensionSets() const { + llvm::SmallDenseMap> sets; + for (int64_t i = 0; i < (int64_t)dimensionSetIDs.size(); ++i) { + sets[dimensionSetIDs[i]].push_back(i); + } + + SmallVector sortedKeys; + for (auto &kv : sets) + sortedKeys.push_back(kv.first); + std::sort(sortedKeys.begin(), sortedKeys.end(), + [&](int64_t a, int64_t b) { return sets[a][0] < sets[b][0]; }); + + SmallVector> result; + for (int64_t key : sortedKeys) { + result.push_back(sets[key]); + } + return result; +} + +PartialSymmetryAnnotation PartialSymmetryAnnotation::createFromDimensionSets( + int64_t rank, ArrayRef> dimensionSets) { + SmallVector dimensionSetIDs(rank); + for (int64_t i = 0; i < rank; ++i) { + dimensionSetIDs[i] = i; + } + + // Note that dimensionSets is not assumed to be a complete partition. + // Missing dimensions are treated as separate sets. + for (auto dims : dimensionSets) { + for (int64_t i = 1; i < (int64_t)dims.size(); ++i) { + dimensionSetIDs[dims[i]] = dimensionSetIDs[dims[0]]; + } + } + + return PartialSymmetryAnnotation(dimensionSetIDs); +} + +std::optional +PartialSymmetryAnnotation::createFromIR(Value val) { + auto blockArg = dyn_cast(val); + if (blockArg) { + auto op = blockArg.getOwner()->getParentOp(); + auto funcOpInterface = dyn_cast(op); + if (!funcOpInterface) { + return std::nullopt; + } + + auto argAttrs = funcOpInterface.getArgAttrs(blockArg.getArgNumber()); + for (auto attr : argAttrs) { + if (attr.getName() == "enzymexla.partial_symmetry") { + auto arrayAttr = dyn_cast(attr.getValue()); + if (!arrayAttr || arrayAttr.empty()) { + continue; + } + + auto partialSymmetryAttr = + dyn_cast( + arrayAttr[0]); + + if (!partialSymmetryAttr) { + continue; + } + + auto dimensionSetAttrs = partialSymmetryAttr.getValues(); + auto rank = cast(val.getType()).getRank(); + + SmallVector> dimensionSets; + for (auto dimensionSetAttr : dimensionSetAttrs) { + auto dims = dimensionSetAttr.getDimensions().asArrayRef(); + dimensionSets.push_back(dims); + } + + return PartialSymmetryAnnotation::createFromDimensionSets( + rank, dimensionSets); + } + } + return std::nullopt; + } + + auto op = val.getDefiningOp(); + if (!op) + return std::nullopt; + + auto arrayAttr = op->getAttrOfType("enzymexla.partial_symmetry"); + if (!arrayAttr || arrayAttr.empty()) + return std::nullopt; + + auto opResult = dyn_cast(val); + if (!opResult) + return std::nullopt; + + auto resultNumber = opResult.getResultNumber(); + + auto partialSymmetryAttr = + dyn_cast( + arrayAttr[resultNumber]); + if (!partialSymmetryAttr) + return std::nullopt; + + auto dimensionSetAttrs = partialSymmetryAttr.getValues(); + auto rank = cast(val.getType()).getRank(); + + SmallVector> dimensionSets; + for (auto dimensionSetAttr : dimensionSetAttrs) { + auto dims = dimensionSetAttr.getDimensions().asArrayRef(); + dimensionSets.push_back(dims); + } + + return PartialSymmetryAnnotation::createFromDimensionSets(rank, + dimensionSets); +} + +void PartialSymmetryAnnotation::print(raw_ostream &os) const { + auto dimensionSets = getDimensionSets(); + os << "{"; + bool firstSet = true; + for (const auto &set : dimensionSets) { + if (!firstSet) + os << ", "; + os << "{"; + bool firstElem = true; + for (int64_t dim : set) { + if (!firstElem) + os << ","; + os << dim; + firstElem = false; + } + os << "}"; + firstSet = false; + } + os << "}"; +} + +//===----------------------------------------------------------------------===// +// PartialSymmetryLattice Implementation +//===----------------------------------------------------------------------===// + +PartialSymmetryLattice::PartialSymmetryLattice(Value v) + : AbstractSparseLattice(v) { + if (auto type = dyn_cast(v.getType())) { + // Trust existing IR annotations if present. + auto annotation = PartialSymmetryAnnotation::createFromIR(v); + if (annotation.has_value()) { + value = annotation.value(); + return; + } + + value = PartialSymmetryAnnotation::createFullySymmetric(type.getRank()); + } +} + +ChangeResult PartialSymmetryLattice::meet(const AbstractSparseLattice &rhs) { + const auto *rhsStruct = + reinterpret_cast(&rhs); + return meet(*rhsStruct); +} + +ChangeResult PartialSymmetryLattice::meet(const PartialSymmetryLattice &rhs) { + auto newValue = PartialSymmetryAnnotation::meet(getValue(), rhs.getValue()); + if (getValue() == newValue) + return ChangeResult::NoChange; + + setValue(newValue); + return ChangeResult::Change; +} + +void PartialSymmetryLattice::print(raw_ostream &os) const { value.print(os); } + +//===----------------------------------------------------------------------===// +// PartialSymmetryAnalysis Implementation +//===----------------------------------------------------------------------===// + +void PartialSymmetryAnalysis::setToEntryState(PartialSymmetryLattice *lattice) { + auto annotation = + PartialSymmetryAnnotation::createFromIR(lattice->getAnchor()); + if (annotation.has_value()) { + lattice->setValue(annotation.value()); + return; + } + + lattice->setValue(PartialSymmetryAnnotation::createNotSymmetric( + lattice->getValue().getRank())); +} + +LogicalResult PartialSymmetryAnalysis::visitOperation( + Operation *op, ArrayRef operands, + ArrayRef results) { + + SmallVector updatedAnnotation(results.size(), false); + SmallVector propagatedAnnotation(results.size()); + + SmallVector operandAnnotations(operands.size()); + for (int64_t i = 0; i < (int64_t)operands.size(); i++) { + operandAnnotations[i] = operands[i]->getValue(); + } + + if (auto transposeOp = dyn_cast(op)) { + updatedAnnotation[0] = true; + propagatedAnnotation[0] = PartialSymmetryAnnotation::propagateTranspose( + operandAnnotations[0], transposeOp.getPermutation()); + } + + if (auto bcastOp = dyn_cast(op)) { + if (auto resultType = + dyn_cast(op->getResult(0).getType())) { + updatedAnnotation[0] = true; + propagatedAnnotation[0] = + PartialSymmetryAnnotation::propagateBroadcastInDim( + operandAnnotations[0], resultType.getRank(), + bcastOp.getBroadcastDimensions()); + } + } + + if (auto dotGeneralOp = dyn_cast(op)) { + if (auto resultType = + dyn_cast(op->getResult(0).getType())) { + auto dotDimNumbers = dotGeneralOp.getDotDimensionNumbers(); + auto lhs = dotGeneralOp.getLhs(); + auto rhs = dotGeneralOp.getRhs(); + + // Check for aliasing between LHS and RHS (up to transpose) + bool rhsAliasesLhs = false; + SmallVector rhsDimToLhs; + if (lhs == rhs) { + auto lhsType = cast(lhs.getType()); + rhsDimToLhs.resize(lhsType.getRank()); + for (int64_t i = 0; i < lhsType.getRank(); ++i) + rhsDimToLhs[i] = i; // Identity mapping + rhsAliasesLhs = true; + } else if (auto lhsT = lhs.getDefiningOp(); + lhsT && rhs == lhsT.getOperand()) { + rhsDimToLhs.assign(lhsT.getPermutation().begin(), + lhsT.getPermutation().end()); + rhsAliasesLhs = true; + } else if (auto rhsT = rhs.getDefiningOp(); + rhsT && lhs == rhsT.getOperand()) { + rhsDimToLhs.resize(rhsT.getPermutation().size()); + for (int64_t i = 0; i < (int64_t)rhsT.getPermutation().size(); ++i) + rhsDimToLhs[rhsT.getPermutation()[i]] = i; + rhsAliasesLhs = true; + } + + // Propagate symmetry through dotGeneral + propagatedAnnotation[0] = PartialSymmetryAnnotation::propagateDotGeneral( + operandAnnotations[0], operandAnnotations[1], resultType.getRank(), + dotDimNumbers.getLhsBatchingDimensions(), + dotDimNumbers.getRhsBatchingDimensions(), + dotDimNumbers.getLhsContractingDimensions(), + dotDimNumbers.getRhsContractingDimensions(), rhsAliasesLhs, + rhsDimToLhs); + + updatedAnnotation[0] = true; + } + } + + if (stablehlo::hasTraitElementwise(op)) { + if (auto resultType = + dyn_cast(op->getResult(0).getType())) { + if (operands.size() == 1) { + propagatedAnnotation[0] = operandAnnotations[0]; + updatedAnnotation[0] = true; + } else if (operands.size() == 2 && + (op->hasTrait() || + op->hasTrait())) { + auto lhs = op->getOperand(0); + auto rhs = op->getOperand(1); + + // Check for aliasing between LHS and RHS (up to transpose) + bool rhsAliasesLhs = false; + SmallVector rhsDimToLhs; + if (lhs == rhs) { + auto lhsType = cast(lhs.getType()); + rhsDimToLhs.resize(lhsType.getRank()); + for (int64_t i = 0; i < lhsType.getRank(); ++i) + rhsDimToLhs[i] = i; // Identity mapping + rhsAliasesLhs = true; + } else if (auto lhsT = lhs.getDefiningOp(); + lhsT && rhs == lhsT.getOperand()) { + rhsDimToLhs.assign(lhsT.getPermutation().begin(), + lhsT.getPermutation().end()); + rhsAliasesLhs = true; + } else if (auto rhsT = rhs.getDefiningOp(); + rhsT && lhs == rhsT.getOperand()) { + rhsDimToLhs.resize(rhsT.getPermutation().size()); + for (int64_t i = 0; i < (int64_t)rhsT.getPermutation().size(); ++i) + rhsDimToLhs[rhsT.getPermutation()[i]] = i; + rhsAliasesLhs = true; + } + + propagatedAnnotation[0] = + PartialSymmetryAnnotation::propagateElementwiseBinary( + operandAnnotations[0], operandAnnotations[1], + resultType.getRank(), rhsAliasesLhs, rhsDimToLhs); + updatedAnnotation[0] = true; + } + } + } + + DenseElementsAttr denseAttr; + if (matchPattern(op->getResult(0), m_Constant(&denseAttr))) { + updatedAnnotation[0] = true; + propagatedAnnotation[0] = + PartialSymmetryAnnotation::checkConstant(denseAttr); + } + + for (int64_t i = 0; i < (int64_t)results.size(); i++) { + if (updatedAnnotation[i]) { + auto resultOrig = results[i]->getValue(); + auto resultNew = + PartialSymmetryAnnotation::meet(resultOrig, propagatedAnnotation[i]); + results[i]->setValue(resultNew); + propagateIfChanged(results[i], resultNew == resultOrig + ? ChangeResult::NoChange + : ChangeResult::Change); + } + } + + return success(); +} + +} // namespace enzyme +} // namespace mlir diff --git a/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h new file mode 100644 index 000000000..3c20ae6d3 --- /dev/null +++ b/src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h @@ -0,0 +1,114 @@ +#pragma once + +#include +#include +#include + +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Value.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir { +namespace enzyme { + +// Represents the partial symmetry of a tensor as a partition of its dimensions. +class PartialSymmetryAnnotation { +public: + PartialSymmetryAnnotation() : dimensionSetIDs() {} + + explicit PartialSymmetryAnnotation(ArrayRef dimensionSetIDs); + + static PartialSymmetryAnnotation createUninitialized(int64_t rank); + static PartialSymmetryAnnotation createNotSymmetric(int64_t rank); + static PartialSymmetryAnnotation createFullySymmetric(int64_t rank); + + bool isSymmetric(int64_t i, int64_t j) const; + int64_t getSetId(int64_t i) const { return dimensionSetIDs[i]; } + int64_t getRank() const { return dimensionSetIDs.size(); } + + static PartialSymmetryAnnotation meet(const PartialSymmetryAnnotation &lhs, + const PartialSymmetryAnnotation &rhs); + static PartialSymmetryAnnotation join(const PartialSymmetryAnnotation &lhs, + const PartialSymmetryAnnotation &rhs); + + static PartialSymmetryAnnotation + propagateTranspose(const PartialSymmetryAnnotation &annotation, + ArrayRef permutation); + + static PartialSymmetryAnnotation + propagateBroadcastInDim(const PartialSymmetryAnnotation &annotation, + int64_t outputRank, + ArrayRef broadcastDimensions); + + static PartialSymmetryAnnotation + propagateDotGeneral(const PartialSymmetryAnnotation &lhsAnnotation, + const PartialSymmetryAnnotation &rhsAnnotation, + int64_t resultRank, ArrayRef lhsBatchingDims, + ArrayRef rhsBatchingDims, + ArrayRef lhsContractingDims, + ArrayRef rhsContractingDims, bool rhsAliasesLhs, + ArrayRef rhsDimToLhs); + + static PartialSymmetryAnnotation checkConstant(DenseElementsAttr attr); + + static PartialSymmetryAnnotation + propagateElementwiseBinary(const PartialSymmetryAnnotation &lhsAnnotation, + const PartialSymmetryAnnotation &rhsAnnotation, + int64_t resultRank, bool rhsAliasesLhs, + ArrayRef rhsDimToLhs); + + bool operator==(const PartialSymmetryAnnotation &other) const { + return dimensionSetIDs == other.dimensionSetIDs; + } + + SmallVector> getDimensionSets() const; + static PartialSymmetryAnnotation + createFromDimensionSets(int64_t rank, + ArrayRef> dimensionSets); + static std::optional createFromIR(Value val); + + void print(raw_ostream &os) const; + +private: + SmallVector dimensionSetIDs; + + void canonicalize(); + void uniteDimensionSets(int64_t rank, int64_t i, int64_t j); +}; + +class PartialSymmetryLattice : public dataflow::AbstractSparseLattice { +public: + using AbstractSparseLattice::AbstractSparseLattice; + + PartialSymmetryLattice(Value v); + + ChangeResult meet(const AbstractSparseLattice &rhs) override; + ChangeResult meet(const PartialSymmetryLattice &rhs); + + void print(raw_ostream &os) const override; + + const PartialSymmetryAnnotation &getValue() const { return value; } + void setValue(const PartialSymmetryAnnotation &v) { value = v; } + +private: + PartialSymmetryAnnotation value; +}; + +class PartialSymmetryAnalysis + : public dataflow::SparseForwardDataFlowAnalysis { +public: + using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis; + + void setToEntryState(PartialSymmetryLattice *lattice) override; + + LogicalResult + visitOperation(Operation *op, + ArrayRef operands, + ArrayRef results) override; +}; + +} // namespace enzyme +} // namespace mlir diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD index fd02c7a6b..1b33e3722 100644 --- a/src/enzyme_ad/jax/BUILD +++ b/src/enzyme_ad/jax/BUILD @@ -840,6 +840,7 @@ cc_library( cc_library( name = "XLADerivatives", srcs = glob([ + "Analysis/*.cpp", "Implementations/*.cpp", "Passes/*.cpp", "Dialect/*.cpp", @@ -849,6 +850,7 @@ cc_library( "Utils.cpp", ], hdrs = glob([ + "Analysis/*.h", "Implementations/*.h", "Passes/*.h", "Dialect/*.h", diff --git a/src/enzyme_ad/jax/Dialect/EnzymeXLAAttrs.td b/src/enzyme_ad/jax/Dialect/EnzymeXLAAttrs.td index 11bf8ca25..2e1a43230 100644 --- a/src/enzyme_ad/jax/Dialect/EnzymeXLAAttrs.td +++ b/src/enzyme_ad/jax/Dialect/EnzymeXLAAttrs.td @@ -133,4 +133,36 @@ def EnzymeXLA_GuaranteedAnalysisResult : I32EnumAttr<"GuaranteedAnalysisResult", def EnzymeXLA_GuaranteedAnalysisResultAttr : EnumAttr; +def EnzymeXLA_SymmetricDimensionSetAttr : AttrDef { + let summary = "A set of symmetric dimension indices"; + let cppNamespace = "::mlir::enzymexla"; + + let parameters = (ins + "DenseI64ArrayAttr":$dimensions + ); + + let assemblyFormat = [{ + `<` $dimensions `>` + }]; + + let mnemonic = "symmetric_dimension_set"; +} + +def EnzymeXLA_PartialSymmetryAnalysisResultAttr : AttrDef { + let summary = "Sets of partially symmetric dimensions"; + let cppNamespace = "::mlir::enzymexla"; + + let parameters = (ins + ArrayRefParameter<"SymmetricDimensionSetAttr">:$values + ); + + let assemblyFormat = [{ + `<` $values `>` + }]; + + let mnemonic = "partial_symmetry"; +} + #endif // ENZYMEXLA_ATTRS diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index faf7d607d..806aa372c 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -24,9 +24,11 @@ #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Visitors.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "shardy/dialect/sdy/ir/utils.h" +#include "src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h" #include "src/enzyme_ad/jax/CheckedRewrite.h" #include "src/enzyme_ad/jax/Dialect/Dialect.h" #include "src/enzyme_ad/jax/Dialect/Ops.h" @@ -55,6 +57,7 @@ #include "llvm/ADT/MapVector.h" #include #include +#include #define DEBUG_TYPE "enzymehloopt" namespace mlir { @@ -6957,6 +6960,42 @@ struct TransposeSymmetricSimplify } }; +struct TransposePartialSymmetrySimplify + : public CheckedOpRewritePattern { + using CheckedOpRewritePattern< + stablehlo::TransposeOp, + TransposePartialSymmetrySimplify>::CheckedOpRewritePattern; + + LogicalResult matchAndRewriteImpl(stablehlo::TransposeOp op, + PatternRewriter &rewriter) const { + auto operand = op.getOperand(); + auto perm = op.getPermutation(); + + auto annotationOpt = + enzyme::PartialSymmetryAnnotation::createFromIR(operand); + if (!annotationOpt.has_value()) + return failure(); + + auto annotation = annotationOpt.value(); + + bool isIdentity = true; + for (int64_t i = 0; i < (int64_t)perm.size(); ++i) { + if (annotation.getSetId(i) != annotation.getSetId(perm[i])) { + isIdentity = false; + break; + } + } + + if (isIdentity) { + rewriter.replaceOp(op, operand); + return success(); + } + + return failure(); + } +}; + struct NoNanSelfSubSimplify : public NoNanCheckedOpRewritePattern { @@ -26653,7 +26692,8 @@ struct EnzymeHLOOptPass NoNanAddSubSimplify, NoNanMulSimplify, NoNanDivSimplify>( (no_nan || all_finite), context); - patterns.add(context); + patterns.add( + context); patterns.add(context); // syrk patterns diff --git a/src/enzyme_ad/jax/Passes/PartialSymmetryAnnotate.cpp b/src/enzyme_ad/jax/Passes/PartialSymmetryAnnotate.cpp new file mode 100644 index 000000000..2bfd814e4 --- /dev/null +++ b/src/enzyme_ad/jax/Passes/PartialSymmetryAnnotate.cpp @@ -0,0 +1,99 @@ +#include "src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h" +#include "src/enzyme_ad/jax/Passes/Passes.h" + +#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" +#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Pass/Pass.h" + +#include "mlir/Support/LLVM.h" +#include "llvm/Support/raw_ostream.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "src/enzyme_ad/jax/Dialect/Dialect.h" +#include "src/enzyme_ad/jax/Dialect/Ops.h" +#include "stablehlo/dialect/StablehloOps.h" + +#define DEBUG_TYPE "partial-symmetry-annotate" + +namespace mlir { +namespace enzyme { +#define GEN_PASS_DEF_PARTIALSYMMETRYANNOTATEPASS +#include "src/enzyme_ad/jax/Passes/Passes.h.inc" +} // namespace enzyme +} // namespace mlir + +using namespace mlir; +using namespace mlir::dataflow; +using namespace mlir::enzyme; + +namespace { + +class PartialSymmetryAnnotatePass + : public enzyme::impl::PartialSymmetryAnnotatePassBase< + PartialSymmetryAnnotatePass> { +public: + using Base::Base; + + void runOnOperation() override { + DataFlowSolver solver; + + solver.load(); + solver.load(); + solver.load(); + + if (failed(solver.initializeAndRun(getOperation()))) { + return signalPassFailure(); + } + + auto mod = getOperation(); + + // Annotate all operations with partial symmetry information + mod->walk([&](Operation *op) { + SmallVector partialSymmetryAttrs; + bool anyKnown = false; + + for (auto result : op->getResults()) { + auto *state = + solver.lookupState(result); + if (!state) { + continue; + } + + auto dimensionSets = state->getValue().getDimensionSets(); + + SmallVector dimensionSetAttrs; + for (const auto &set : dimensionSets) { + if (set.size() > 1) { + anyKnown = true; + auto denseAttr = DenseI64ArrayAttr::get(mod.getContext(), set); + auto dimensionSetAttr = enzymexla::SymmetricDimensionSetAttr::get( + mod.getContext(), denseAttr); + dimensionSetAttrs.push_back(dimensionSetAttr); + } + } + + if (dimensionSetAttrs.empty()) { + continue; + } + + auto partialSymmetry = + enzymexla::PartialSymmetryAnalysisResultAttr::get( + mod.getContext(), dimensionSetAttrs); + partialSymmetryAttrs.push_back(partialSymmetry); + } + + if (anyKnown) { + op->setAttr("enzymexla.partial_symmetry", + ArrayAttr::get(mod.getContext(), partialSymmetryAttrs)); + } + + return WalkResult::advance(); + }); + } +}; + +} // namespace diff --git a/src/enzyme_ad/jax/Passes/Passes.td b/src/enzyme_ad/jax/Passes/Passes.td index ad0e1d653..03afbf77e 100644 --- a/src/enzyme_ad/jax/Passes/Passes.td +++ b/src/enzyme_ad/jax/Passes/Passes.td @@ -1083,6 +1083,15 @@ def EnzymeBatchToStableHLOPass : Pass<"enzyme-batch-to-stablehlo"> { ]; } +def PartialSymmetryAnnotatePass : Pass<"partial-symmetry-annotate", "ModuleOp"> { + let summary = "Annotate operations using partial symmetry analysis"; + let dependentDialects = [ + "stablehlo::StablehloDialect", + "enzymexla::EnzymeXLADialect", + "func::FuncDialect", + ]; +} + def ConvertAllConstantsToSplattedConstantPass : Pass<"convert-all-constants-to-splatted-constant", "ModuleOp"> { let summary = "Convert all constants to splatted constants. This is supposed to be used for debugging purposes when dumping the module."; let dependentDialects = [ diff --git a/src/enzyme_ad/jax/TransformOps/TransformOps.td b/src/enzyme_ad/jax/TransformOps/TransformOps.td index 03bce861b..26398050c 100644 --- a/src/enzyme_ad/jax/TransformOps/TransformOps.td +++ b/src/enzyme_ad/jax/TransformOps/TransformOps.td @@ -608,6 +608,11 @@ def ApplyTransposeSymmetricSimplify : EnzymeHLOPatternOp< let patterns = ["TransposeSymmetricSimplify"]; } +def ApplyTransposePartialSymmetrySimplify : EnzymeHLOPatternOp< + "transpose_partial_symmetry_simplify"> { + let patterns = ["TransposePartialSymmetrySimplify"]; +} + def ApplyFactorScalarsInDotGeneral : EnzymeHLOPatternOp< "factor_scalars_in_dot_general"> { let patterns = ["FactorScalarsInDotGeneral"]; diff --git a/src/enzyme_ad/jax/Utils.h b/src/enzyme_ad/jax/Utils.h index 314e69cdc..07f35477d 100644 --- a/src/enzyme_ad/jax/Utils.h +++ b/src/enzyme_ad/jax/Utils.h @@ -319,6 +319,9 @@ bool mayAlias(mlir::MemoryEffects::EffectInstance a, bool mayAlias(mlir::MemoryEffects::EffectInstance a, mlir::Value b); +bool checkNotEqual(llvm::APInt a, llvm::APInt b); +bool checkNotEqual(llvm::APFloat a, llvm::APFloat b); + bool canApplyNoNanPattern(bool allowOnFloatingPointMath, Type Ty); bool canApplyNoNanPattern(bool allowOnFloatingPointMath, Type Ty, mlir::Operation *op, PatternRewriter &rewriter); diff --git a/test/lit_tests/structured_tensors/partial_symmetry.mlir b/test/lit_tests/structured_tensors/partial_symmetry.mlir new file mode 100644 index 000000000..c4c007d91 --- /dev/null +++ b/test/lit_tests/structured_tensors/partial_symmetry.mlir @@ -0,0 +1,86 @@ +// RUN: enzymexlamlir-opt --partial-symmetry-annotate --enzyme-hlo-generate-td="patterns=transpose_partial_symmetry_simplify" --transform-interpreter --enzyme-hlo-remove-transform %s | FileCheck %s + +func.func @test_constant() -> tensor<2x2xf32> { + %cst = stablehlo.constant dense<[[1.0, 2.0], [2.0, 3.0]]> : tensor<2x2xf32> + %0 = stablehlo.transpose %cst, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} +// CHECK: func.func @test_constant() -> tensor<2x2xf32> { +// CHECK-NEXT: %cst = stablehlo.constant {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]} dense<{{.*}}> : tensor<2x2xf32> +// CHECK-NEXT: return %cst : tensor<2x2xf32> +// CHECK-NEXT: } + +func.func @test_propagate() -> tensor<2x2x2x3xf32> { + %cst0 = stablehlo.constant dense<[[[1.0, 2.0], [3.0, 4.0]], [[3.0, 4.0], [5.0, 6.0]]]> : tensor<2x2x2xf32> + %cst1 = stablehlo.constant dense<[[[1.0, 2.0], [2.0, 3.0]], [[2.0, 3.0], [3.0, 4.0]]]> : tensor<2x2x2xf32> + %0 = stablehlo.add %cst0, %cst1 : tensor<2x2x2xf32> + %1 = stablehlo.transpose %0, dims = [0, 2, 1] : (tensor<2x2x2xf32>) -> tensor<2x2x2xf32> + %2 = stablehlo.broadcast_in_dim %1, dims = [1, 0, 2] : (tensor<2x2x2xf32>) -> tensor<2x2x2x3xf32> + %3 = stablehlo.transpose %2, dims = [0, 2, 1, 3] : (tensor<2x2x2x3xf32>) -> tensor<2x2x2x3xf32> + return %3 : tensor<2x2x2x3xf32> +} +// CHECK: func.func @test_propagate() -> tensor<2x2x2x3xf32> { +// CHECK-NEXT: %cst = stablehlo.constant {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]} dense<{{.*}}> : tensor<2x2x2xf32> +// CHECK-NEXT: %cst_0 = stablehlo.constant {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1, 2]>>]} dense<{{.*}}> : tensor<2x2x2xf32> +// CHECK-NEXT: %0 = stablehlo.add %cst, %cst_0 {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]} : tensor<2x2x2xf32> +// CHECK-NEXT: %1 = stablehlo.transpose %0, dims = [0, 2, 1] {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 2]>>]} : (tensor<2x2x2xf32>) -> tensor<2x2x2xf32> +// CHECK-NEXT: %2 = stablehlo.broadcast_in_dim %1, dims = [1, 0, 2] {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[1, 2]>>]} : (tensor<2x2x2xf32>) -> tensor<2x2x2x3xf32> +// CHECK-NEXT: return %2 : tensor<2x2x2x3xf32> +// CHECK-NEXT: } + +func.func @test_add_generate_symmetry(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x3xf32> { + %0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<3x2x3xf32>) -> tensor<3x2x3xf32> + %1 = stablehlo.add %0, %arg0 : tensor<3x2x3xf32> + return %1 : tensor<3x2x3xf32> +} +// CHECK: func.func @test_add_generate_symmetry(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x3xf32> { +// CHECK-NEXT: %0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<3x2x3xf32>) -> tensor<3x2x3xf32> +// CHECK-NEXT: %1 = stablehlo.add %0, %arg0 {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 2]>>]} : tensor<3x2x3xf32> +// CHECK-NEXT: return %1 : tensor<3x2x3xf32> +// CHECK-NEXT: } + +func.func @test_dot_general_propagate(%arg0: tensor<2x2x3xf32> {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]}) -> tensor<2x2x2x2xf32> { + %cst1 = stablehlo.constant dense<[[[1.0, 2.0], [2.0, 3.0]], [[2.0, 3.0], [3.0, 4.0]], [[2.0, 3.0], [3.0, 4.0]]]> : tensor<3x2x2xf32> + %0 = stablehlo.dot_general %arg0, %cst1, batching_dims = [0, 1] x [1, 2], contracting_dims = [2] x [0] : (tensor<2x2x3xf32>, tensor<3x2x2xf32>) -> tensor<2x2xf32> + %1 = stablehlo.dot_general %0, %0, contracting_dims = [1] x [0] : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + %tmp = stablehlo.broadcast_in_dim %1, dims = [0, 1] : (tensor<2x2xf32>) -> tensor<2x2x3xf32> + %2 = stablehlo.dot_general %arg0, %tmp, contracting_dims = [2] x [2] : (tensor<2x2x3xf32>, tensor<2x2x3xf32>) -> tensor<2x2x2x2xf32> + return %2 : tensor<2x2x2x2xf32> +} +// CHECK: func.func @test_dot_general_propagate(%arg0: tensor<2x2x3xf32> {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]}) -> tensor<2x2x2x2xf32> { +// CHECK-NEXT: %cst = stablehlo.constant {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[1, 2]>>]} dense<{{.*}}> : tensor<3x2x2xf32> +// CHECK-NEXT: %0 = stablehlo.dot_general %arg0, %cst, batching_dims = [0, 1] x [1, 2], contracting_dims = [2] x [0] {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]} : (tensor<2x2x3xf32>, tensor<3x2x2xf32>) -> tensor<2x2xf32> +// CHECK-NEXT: %1 = stablehlo.dot_general %0, %0, contracting_dims = [1] x [0] {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> +// CHECK-NEXT: %2 = stablehlo.broadcast_in_dim %1, dims = [0, 1] {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]} : (tensor<2x2xf32>) -> tensor<2x2x3xf32> +// CHECK-NEXT: %3 = stablehlo.dot_general %arg0, %2, contracting_dims = [2] x [2] {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>, <[2, 3]>>]} : (tensor<2x2x3xf32>, tensor<2x2x3xf32>) -> tensor<2x2x2x2xf32> +// CHECK-NEXT: return %3 : tensor<2x2x2x2xf32> +// CHECK-NEXT: } + +func.func @test_dot_general_generate_symmetry(%arg0: tensor<3x3x3xf32>) -> tensor<3x3x3xf32> { + %0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<3x3x3xf32>) -> tensor<3x3x3xf32> + %1 = stablehlo.dot_general %arg0, %0, batching_dims = [1] x [1], contracting_dims = [0] x [2] : (tensor<3x3x3xf32>, tensor<3x3x3xf32>) -> tensor<3x3x3xf32> + return %1 : tensor<3x3x3xf32> +} +// CHECK: func.func @test_dot_general_generate_symmetry(%arg0: tensor<3x3x3xf32>) -> tensor<3x3x3xf32> { +// CHECK-NEXT: %0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<3x3x3xf32>) -> tensor<3x3x3xf32> +// CHECK-NEXT: %1 = stablehlo.dot_general %arg0, %0, batching_dims = [1] x [1], contracting_dims = [0] x [2] {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[1, 2]>>]} : (tensor<3x3x3xf32>, tensor<3x3x3xf32>) -> tensor<3x3x3xf32> +// CHECK-NEXT: return %1 : tensor<3x3x3xf32> +// CHECK-NEXT: } + +func.func @test_scalar_multiply(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + %cst = stablehlo.constant dense<9.900000e-01> : tensor + %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32> + %1 = stablehlo.add %0, %arg0 : tensor<2x2xf32> + %2 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<2x2xf32> + %3 = stablehlo.multiply %1, %2 : tensor<2x2xf32> + return %3 : tensor<2x2xf32> +} +// CHECK: func.func @test_scalar_multiply(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { +// CHECK-NEXT: %cst = stablehlo.constant dense<9.900000e-01> : tensor +// CHECK-NEXT: %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32> +// CHECK-NEXT: %1 = stablehlo.add %0, %arg0 {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]} : tensor<2x2xf32> +// CHECK-NEXT: %2 = stablehlo.broadcast_in_dim %cst, dims = [] {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]} : (tensor) -> tensor<2x2xf32> +// CHECK-NEXT: %3 = stablehlo.multiply %1, %2 {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]} : tensor<2x2xf32> +// CHECK-NEXT: return %3 : tensor<2x2xf32> +// CHECK-NEXT: } +