Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
4b6eac7
Add partial symmetry detection
gaurav-arya Nov 27, 2025
ddcbcaf
Remove generalized A * A^T handling (soundness unclear)
gaurav-arya Nov 28, 2025
a955487
Add general dot_general logic for case where lhs = rhs
gaurav-arya Nov 29, 2025
be1229b
Add transpose symmetry generation logic for dot general
gaurav-arya Nov 29, 2025
41530c8
Progress refactoring symmetry generation logic for transpose
gaurav-arya Nov 29, 2025
37ccbde
Fix issue with rank computation
gaurav-arya Nov 29, 2025
f6d911a
Simplify elementwise propagation logic
gaurav-arya Nov 29, 2025
d8336d1
Some code cleanup
gaurav-arya Nov 29, 2025
a172118
Add test of dot_general symm gen + remove seemingly unnecessary
gaurav-arya Nov 29, 2025
0582344
Format
gaurav-arya Nov 29, 2025
8f93a70
Remove debug messages
gaurav-arya Nov 29, 2025
5606faf
Add n-dim transpose removal opt
gaurav-arya Nov 29, 2025
77e5b3a
Recognize existing partial symmetry annotations in IR
gaurav-arya Dec 1, 2025
1f40a6d
Format
gaurav-arya Dec 1, 2025
e62ad33
Fix test func naming
gaurav-arya Dec 1, 2025
62b180e
Add missing symmetry detection for broadcast in dim
gaurav-arya Dec 2, 2025
86e02cb
Check for lhs == rhs in aliasing check
gaurav-arya Dec 2, 2025
aa46d99
Make symmetry detection within DotGeneral LHS/RHS free dims run even
gaurav-arya Dec 3, 2025
c892955
Fix probably incorrect transfer function for elementwise
gaurav-arya Dec 4, 2025
4480fbc
Fix join/meet naming
gaurav-arya Dec 4, 2025
0bf5740
Format
gaurav-arya Dec 10, 2025
9dca247
Merge branch 'main' into ag/ndim_symmetry_lattice
avik-pal Dec 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
699 changes: 699 additions & 0 deletions src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.cpp

Large diffs are not rendered by default.

114 changes: 114 additions & 0 deletions src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
#pragma once

#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>

#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"

namespace mlir {
namespace enzyme {

/// Represents the partial symmetry of a tensor as a partition of its
/// dimensions.
///
/// The partition is stored as one set identifier per dimension
/// (`dimensionSetIDs[d]` is the identifier of the set containing dimension
/// `d`); two dimensions carrying the same identifier belong to the same set.
///
/// Only the members with inline bodies are fully specified by this header.
/// The remaining members are defined out of line; the short notes attached to
/// them are inferred from their names and signatures and should be confirmed
/// against the definitions.
class PartialSymmetryAnnotation {
public:
  /// Creates an annotation with no dimensions (`getRank() == 0`).
  PartialSymmetryAnnotation() = default;

  /// Builds an annotation from one set identifier per dimension.
  explicit PartialSymmetryAnnotation(ArrayRef<int64_t> dimensionSetIDs);

  // Factories for the distinguished annotations of a given rank.
  // NOTE(review): presumably the lattice extremes plus an "unknown" state;
  // confirm the exact encoding in the implementation file.
  static PartialSymmetryAnnotation createUninitialized(int64_t rank);
  static PartialSymmetryAnnotation createNotSymmetric(int64_t rank);
  static PartialSymmetryAnnotation createFullySymmetric(int64_t rank);

  /// NOTE(review): presumably true when dimensions `i` and `j` share a set.
  bool isSymmetric(int64_t i, int64_t j) const;

  /// Returns the set identifier stored for dimension `i`. `i` is used
  /// directly as an index, so it must satisfy `0 <= i < getRank()`.
  int64_t getSetId(int64_t i) const { return dimensionSetIDs[i]; }

  /// Returns the number of dimensions described by this annotation.
  int64_t getRank() const {
    // size() is unsigned; convert explicitly instead of narrowing silently.
    return static_cast<int64_t>(dimensionSetIDs.size());
  }

  // Lattice combination of two annotations.
  static PartialSymmetryAnnotation meet(const PartialSymmetryAnnotation &lhs,
                                        const PartialSymmetryAnnotation &rhs);
  static PartialSymmetryAnnotation join(const PartialSymmetryAnnotation &lhs,
                                        const PartialSymmetryAnnotation &rhs);

  // Per-operation transfer helpers: each derives the annotation of an
  // operation result from the annotation(s) of its operand(s).
  // NOTE(review): the meaning of `rhsAliasesLhs` / `rhsDimToLhs` is not
  // visible from this header; confirm against the implementation.
  static PartialSymmetryAnnotation
  propagateTranspose(const PartialSymmetryAnnotation &annotation,
                     ArrayRef<int64_t> permutation);

  static PartialSymmetryAnnotation
  propagateBroadcastInDim(const PartialSymmetryAnnotation &annotation,
                          int64_t outputRank,
                          ArrayRef<int64_t> broadcastDimensions);

  static PartialSymmetryAnnotation
  propagateDotGeneral(const PartialSymmetryAnnotation &lhsAnnotation,
                      const PartialSymmetryAnnotation &rhsAnnotation,
                      int64_t resultRank, ArrayRef<int64_t> lhsBatchingDims,
                      ArrayRef<int64_t> rhsBatchingDims,
                      ArrayRef<int64_t> lhsContractingDims,
                      ArrayRef<int64_t> rhsContractingDims, bool rhsAliasesLhs,
                      ArrayRef<int64_t> rhsDimToLhs);

  /// NOTE(review): presumably inspects the constant's elements to detect
  /// symmetric dimensions; confirm.
  static PartialSymmetryAnnotation checkConstant(DenseElementsAttr attr);

  static PartialSymmetryAnnotation
  propagateElementwiseBinary(const PartialSymmetryAnnotation &lhsAnnotation,
                             const PartialSymmetryAnnotation &rhsAnnotation,
                             int64_t resultRank, bool rhsAliasesLhs,
                             ArrayRef<int64_t> rhsDimToLhs);

  /// Two annotations are equal when their per-dimension set identifiers are
  /// element-wise identical.
  bool operator==(const PartialSymmetryAnnotation &other) const {
    return dimensionSetIDs == other.dimensionSetIDs;
  }

  /// Negation of operator==; spelled out because C++17 does not synthesize
  /// it.
  bool operator!=(const PartialSymmetryAnnotation &other) const {
    return !(*this == other);
  }

  /// Returns the partition as explicit lists of dimension indices.
  SmallVector<SmallVector<int64_t>> getDimensionSets() const;

  /// Builds an annotation of the given rank from explicit dimension sets.
  static PartialSymmetryAnnotation
  createFromDimensionSets(int64_t rank,
                          ArrayRef<ArrayRef<int64_t>> dimensionSets);

  /// Tries to obtain an annotation for `val` from the IR; yields
  /// `std::nullopt` when none can be obtained.
  static std::optional<PartialSymmetryAnnotation> createFromIR(Value val);

  void print(raw_ostream &os) const;

private:
  /// One set identifier per tensor dimension.
  SmallVector<int64_t> dimensionSetIDs;

  // NOTE(review): presumably renumbers the identifiers into a canonical form
  // so that operator== compares partitions rather than labelings; confirm.
  void canonicalize();
  void uniteDimensionSets(int64_t rank, int64_t i, int64_t j);
};

/// Sparse dataflow lattice element that stores one PartialSymmetryAnnotation.
///
/// NOTE(review): only `meet` is overridden here. Upstream MLIR's sparse
/// forward analysis merges states at control-flow joins through
/// `AbstractSparseLattice::join`; confirm that relying on `meet` alone gives
/// the intended propagation.
class PartialSymmetryLattice : public dataflow::AbstractSparseLattice {
public:
  using AbstractSparseLattice::AbstractSparseLattice;

  /// Constructs the lattice element for the given SSA value (defined out of
  /// line).
  PartialSymmetryLattice(Value v);

  /// Read access to the stored annotation.
  const PartialSymmetryAnnotation &getValue() const { return value; }

  /// Overwrites the stored annotation with a copy of `newValue`.
  void setValue(const PartialSymmetryAnnotation &newValue) {
    value = newValue;
  }

  /// Combines this element with `rhs`; the returned ChangeResult tells the
  /// solver whether this element changed.
  ChangeResult meet(const AbstractSparseLattice &rhs) override;
  ChangeResult meet(const PartialSymmetryLattice &rhs);

  void print(raw_ostream &os) const override;

private:
  /// The annotation tracked by this lattice element.
  PartialSymmetryAnnotation value;
};

// Sparse forward dataflow analysis whose per-value state is a
// PartialSymmetryLattice. Both hooks below are defined out of line.
class PartialSymmetryAnalysis
: public dataflow::SparseForwardDataFlowAnalysis<PartialSymmetryLattice> {
public:
using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis;

// Framework hook: resets `lattice` to the state used for values the analysis
// cannot reason about (entry points).
void setToEntryState(PartialSymmetryLattice *lattice) override;

// Framework hook (transfer function): given the lattice elements of `op`'s
// operands, updates the lattice elements of its results.
LogicalResult
visitOperation(Operation *op,
ArrayRef<const PartialSymmetryLattice *> operands,
ArrayRef<PartialSymmetryLattice *> results) override;
};

} // namespace enzyme
} // namespace mlir
2 changes: 2 additions & 0 deletions src/enzyme_ad/jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -840,6 +840,7 @@ cc_library(
cc_library(
name = "XLADerivatives",
srcs = glob([
"Analysis/*.cpp",
"Implementations/*.cpp",
"Passes/*.cpp",
"Dialect/*.cpp",
Expand All @@ -849,6 +850,7 @@ cc_library(
"Utils.cpp",
],
hdrs = glob([
"Analysis/*.h",
"Implementations/*.h",
"Passes/*.h",
"Dialect/*.h",
Expand Down
32 changes: 32 additions & 0 deletions src/enzyme_ad/jax/Dialect/EnzymeXLAAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -133,4 +133,36 @@ def EnzymeXLA_GuaranteedAnalysisResult : I32EnumAttr<"GuaranteedAnalysisResult",
// Dialect attribute wrapping the EnzymeXLA_GuaranteedAnalysisResult enum,
// registered under the mnemonic `guaranteed`.
def EnzymeXLA_GuaranteedAnalysisResultAttr : EnumAttr<EnzymeXLA_Dialect,
EnzymeXLA_GuaranteedAnalysisResult, "guaranteed">;

// One set of symmetric dimension indices.
// Holds a single DenseI64ArrayAttr parameter (`dimensions`) and is printed as
// `symmetric_dimension_set<...>` with the dimensions between angle brackets.
def EnzymeXLA_SymmetricDimensionSetAttr : AttrDef<EnzymeXLA_Dialect,
"SymmetricDimensionSet"> {
let summary = "A set of symmetric dimension indices";
let cppNamespace = "::mlir::enzymexla";

let parameters = (ins
"DenseI64ArrayAttr":$dimensions
);

let assemblyFormat = [{
`<` $dimensions `>`
}];

let mnemonic = "symmetric_dimension_set";
}

// Partial symmetry result for one value: a list of SymmetricDimensionSetAttr
// (`values`), printed as `partial_symmetry<...>` with the list between angle
// brackets.
// NOTE(review): the assembly format has no special case for an empty `values`
// list; confirm whether an empty result needs to round-trip.
def EnzymeXLA_PartialSymmetryAnalysisResultAttr : AttrDef<EnzymeXLA_Dialect,
"PartialSymmetryAnalysisResult"> {
let summary = "Sets of partially symmetric dimensions";
let cppNamespace = "::mlir::enzymexla";

let parameters = (ins
ArrayRefParameter<"SymmetricDimensionSetAttr">:$values
);

let assemblyFormat = [{
`<` $values `>`
}];

let mnemonic = "partial_symmetry";
}

#endif // ENZYMEXLA_ATTRS
42 changes: 41 additions & 1 deletion src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "shardy/dialect/sdy/ir/utils.h"
#include "src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h"
#include "src/enzyme_ad/jax/CheckedRewrite.h"
#include "src/enzyme_ad/jax/Dialect/Dialect.h"
#include "src/enzyme_ad/jax/Dialect/Ops.h"
Expand Down Expand Up @@ -55,6 +57,7 @@
#include "llvm/ADT/MapVector.h"
#include <iterator>
#include <numeric>
#include <optional>
#define DEBUG_TYPE "enzymehloopt"

namespace mlir {
Expand Down Expand Up @@ -6957,6 +6960,42 @@ struct TransposeSymmetricSimplify
}
};

/// Removes a `stablehlo.transpose` that is a no-op because of the partial
/// symmetry of its operand.
///
/// The pattern fires when every output dimension `i` reads from an input
/// dimension `perm[i]` that lies in the same symmetric dimension set as `i`.
/// In that case the permutation only shuffles dimensions within sets, so the
/// op is replaced by its operand.
struct TransposePartialSymmetrySimplify
    : public CheckedOpRewritePattern<stablehlo::TransposeOp,
                                     TransposePartialSymmetrySimplify> {
  using CheckedOpRewritePattern<
      stablehlo::TransposeOp,
      TransposePartialSymmetrySimplify>::CheckedOpRewritePattern;

  /// Returns success after replacing `op` with its operand; returns failure
  /// (leaving the IR untouched) when no annotation is available, when the
  /// annotation does not fit the op, or when the permutation crosses
  /// symmetric sets.
  LogicalResult matchAndRewriteImpl(stablehlo::TransposeOp op,
                                    PatternRewriter &rewriter) const {
    auto operand = op.getOperand();
    auto perm = op.getPermutation();

    auto annotationOpt =
        enzyme::PartialSymmetryAnnotation::createFromIR(operand);
    if (!annotationOpt)
      return failure();
    const auto &annotation = *annotationOpt;

    // getSetId() indexes without bounds checking, so refuse annotations whose
    // rank does not match the permutation length.
    const int64_t rank = static_cast<int64_t>(perm.size());
    if (annotation.getRank() != rank)
      return failure();

    // The operand can only stand in for the result if both have the same
    // type; this protects against an annotation that does not actually hold.
    if (op.getType() != operand.getType())
      return failure();

    for (int64_t i = 0; i < rank; ++i) {
      if (annotation.getSetId(i) != annotation.getSetId(perm[i]))
        return failure();
    }

    rewriter.replaceOp(op, operand);
    return success();
  }
};

struct NoNanSelfSubSimplify
: public NoNanCheckedOpRewritePattern<stablehlo::SubtractOp,
NoNanSelfSubSimplify> {
Expand Down Expand Up @@ -26653,7 +26692,8 @@ struct EnzymeHLOOptPass
NoNanAddSubSimplify, NoNanMulSimplify, NoNanDivSimplify>(
(no_nan || all_finite), context);

patterns.add<TransposeSymmetricSimplify>(context);
patterns.add<TransposeSymmetricSimplify, TransposePartialSymmetrySimplify>(
context);
patterns.add<FactorScalarsInDotGeneral>(context);

// syrk patterns
Expand Down
99 changes: 99 additions & 0 deletions src/enzyme_ad/jax/Passes/PartialSymmetryAnnotate.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
#include "src/enzyme_ad/jax/Analysis/PartialSymmetryAnalysis.h"
#include "src/enzyme_ad/jax/Passes/Passes.h"

#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Pass/Pass.h"

#include "mlir/Support/LLVM.h"
#include "llvm/Support/raw_ostream.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "src/enzyme_ad/jax/Dialect/Dialect.h"
#include "src/enzyme_ad/jax/Dialect/Ops.h"
#include "stablehlo/dialect/StablehloOps.h"

#define DEBUG_TYPE "partial-symmetry-annotate"

namespace mlir {
namespace enzyme {
#define GEN_PASS_DEF_PARTIALSYMMETRYANNOTATEPASS
#include "src/enzyme_ad/jax/Passes/Passes.h.inc"
} // namespace enzyme
} // namespace mlir

using namespace mlir;
using namespace mlir::dataflow;
using namespace mlir::enzyme;

namespace {

/// Module pass that runs PartialSymmetryAnalysis and records its findings on
/// the IR as an `enzymexla.partial_symmetry` array attribute.
class PartialSymmetryAnnotatePass
    : public enzyme::impl::PartialSymmetryAnnotatePassBase<
          PartialSymmetryAnnotatePass> {
public:
  using Base::Base;

  /// Runs the dataflow solver over the module, then walks every operation and
  /// attaches the attribute to those with at least one result that has a
  /// symmetric dimension set containing more than one dimension.
  ///
  /// NOTE(review): results without a lattice state, or without any
  /// multi-dimension set, contribute no entry to the array. For multi-result
  /// operations the position in the array therefore need not equal the result
  /// index; confirm that consumers of the attribute account for this.
  void runOnOperation() override {
    auto module = getOperation();
    auto *context = module.getContext();

    DataFlowSolver solver;
    solver.load<enzyme::PartialSymmetryAnalysis>();
    solver.load<dataflow::DeadCodeAnalysis>();
    solver.load<dataflow::SparseConstantPropagation>();

    if (failed(solver.initializeAndRun(module))) {
      signalPassFailure();
      return;
    }

    module->walk([&](Operation *op) {
      // One entry per result that exhibits some symmetry.
      SmallVector<Attribute> perResultAttrs;

      for (auto result : op->getResults()) {
        const auto *lattice =
            solver.lookupState<enzyme::PartialSymmetryLattice>(result);
        if (!lattice)
          continue;

        // Singleton sets carry no information, so only larger sets are kept.
        SmallVector<enzymexla::SymmetricDimensionSetAttr> symmetricSets;
        auto partition = lattice->getValue().getDimensionSets();
        for (const auto &dims : partition) {
          if (dims.size() <= 1)
            continue;
          symmetricSets.push_back(enzymexla::SymmetricDimensionSetAttr::get(
              context, DenseI64ArrayAttr::get(context, dims)));
        }

        if (!symmetricSets.empty())
          perResultAttrs.push_back(
              enzymexla::PartialSymmetryAnalysisResultAttr::get(
                  context, symmetricSets));
      }

      // An entry exists exactly when some result had a multi-dimension set.
      if (!perResultAttrs.empty())
        op->setAttr("enzymexla.partial_symmetry",
                    ArrayAttr::get(context, perResultAttrs));

      return WalkResult::advance();
    });
  }
};

} // namespace
9 changes: 9 additions & 0 deletions src/enzyme_ad/jax/Passes/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -1083,6 +1083,15 @@ def EnzymeBatchToStableHLOPass : Pass<"enzyme-batch-to-stablehlo"> {
];
}

// Module-level pass, invoked as `partial-symmetry-annotate`, that annotates
// operations with the results of the partial symmetry analysis.
def PartialSymmetryAnnotatePass : Pass<"partial-symmetry-annotate", "ModuleOp"> {
let summary = "Annotate operations using partial symmetry analysis";
let dependentDialects = [
"stablehlo::StablehloDialect",
"enzymexla::EnzymeXLADialect",
"func::FuncDialect",
];
}

def ConvertAllConstantsToSplattedConstantPass : Pass<"convert-all-constants-to-splatted-constant", "ModuleOp"> {
let summary = "Convert all constants to splatted constants. This is supposed to be used for debugging purposes when dumping the module.";
let dependentDialects = [
Expand Down
5 changes: 5 additions & 0 deletions src/enzyme_ad/jax/TransformOps/TransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,11 @@ def ApplyTransposeSymmetricSimplify : EnzymeHLOPatternOp<
let patterns = ["TransposeSymmetricSimplify"];
}

// Pattern op named `transpose_partial_symmetry_simplify` that refers to the
// TransposePartialSymmetrySimplify rewrite pattern.
def ApplyTransposePartialSymmetrySimplify : EnzymeHLOPatternOp<
"transpose_partial_symmetry_simplify"> {
let patterns = ["TransposePartialSymmetrySimplify"];
}

def ApplyFactorScalarsInDotGeneral : EnzymeHLOPatternOp<
"factor_scalars_in_dot_general"> {
let patterns = ["FactorScalarsInDotGeneral"];
Expand Down
3 changes: 3 additions & 0 deletions src/enzyme_ad/jax/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,9 @@ bool mayAlias(mlir::MemoryEffects::EffectInstance a,

bool mayAlias(mlir::MemoryEffects::EffectInstance a, mlir::Value b);

bool checkNotEqual(llvm::APInt a, llvm::APInt b);
bool checkNotEqual(llvm::APFloat a, llvm::APFloat b);

bool canApplyNoNanPattern(bool allowOnFloatingPointMath, Type Ty);
bool canApplyNoNanPattern(bool allowOnFloatingPointMath, Type Ty,
mlir::Operation *op, PatternRewriter &rewriter);
Expand Down
Loading
Loading