Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
756 changes: 756 additions & 0 deletions src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.cpp

Large diffs are not rendered by default.

292 changes: 292 additions & 0 deletions src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,292 @@
#pragma once

#include "mlir/Analysis/DataFlow/SparseAnalysis.h"

#include "src/enzyme_ad/jax/Dialect/Ops.h"

#include <algorithm>
#include <cstdint>

namespace mlir {
namespace structure_analysis {

namespace utils {

static bool isZero(APInt v) { return v.isZero(); }
static bool isZero(APFloat v) { return v.isZero(); }
static bool isZero(Attribute v) {
if (auto intAttr = dyn_cast<IntegerAttr>(v))
return isZero(intAttr.getValue());
if (auto floatAttr = dyn_cast<FloatAttr>(v))
return isZero(floatAttr.getValue());
return false;
}

static bool isOne(APInt v) { return v.isOne(); }
static bool isOne(APFloat v) { return v.isExactlyValue(1.0); }
static bool isOne(Attribute v) {
if (auto intAttr = dyn_cast<IntegerAttr>(v))
return isOne(intAttr.getValue());
if (auto floatAttr = dyn_cast<FloatAttr>(v))
return isOne(floatAttr.getValue());
return false;
}

static bool areEqual(APInt a, APInt b) { return a == b; }
static bool areEqual(APFloat a, APFloat b) {
return a.compare(b) == llvm::APFloat::cmpEqual;
}

} // namespace utils

//===----------------------------------------------------------------------===//
// Structured Sparsity Pattern Implementation
//===----------------------------------------------------------------------===//

enum class StructuredSparsityKind {
Dense,
Band,
UpperTriangular,
UpperBidiagonal,
LowerTriangular,
LowerBidiagonal,
Tridiagonal,
Diagonal,
Empty, // denotes that all elements are structural zeros
Unknown,
};

// TODO: currently only legal negative value is -1, which means "unknown"
// we should support negative bandwidths
class StructuredSparsityPattern {
public:
StructuredSparsityPattern()
: kind(StructuredSparsityKind::Unknown), lowerBandwidth(-1),
upperBandwidth(-1) {}

explicit StructuredSparsityPattern(StructuredSparsityKind kind)
: kind(kind), lowerBandwidth(-1), upperBandwidth(-1) {
initializeBandwidths();
}

StructuredSparsityPattern(Value v);

StructuredSparsityPattern(int64_t lowerBandwidth, int64_t upperBandwidth)
: kind(StructuredSparsityKind::Band), lowerBandwidth(lowerBandwidth),
upperBandwidth(upperBandwidth) {
refineKind();
}

StructuredSparsityKind getKind() const { return kind; }
int64_t getLowerBandwidth() const { return lowerBandwidth; }
int64_t getUpperBandwidth() const { return upperBandwidth; }

static StructuredSparsityPattern meet(const StructuredSparsityPattern &lhs,
const StructuredSparsityPattern &rhs);

static StructuredSparsityPattern join(const StructuredSparsityPattern &lhs,
const StructuredSparsityPattern &rhs);

bool operator==(const StructuredSparsityPattern &other) const {
return kind == other.kind && lowerBandwidth == other.lowerBandwidth &&
upperBandwidth == other.upperBandwidth;
}

void print(raw_ostream &os) const;
raw_ostream &operator<<(raw_ostream &os) const {
print(os);
return os;
}

// propagation rules
static StructuredSparsityPattern
propagateTranspose(Value val, const StructuredSparsityPattern &op);

private:
void initializeBandwidths();
void refineKind();

void setUnknown() {
kind = StructuredSparsityKind::Unknown;
lowerBandwidth = -1;
upperBandwidth = -1;
}

StructuredSparsityKind kind;
int64_t lowerBandwidth;
int64_t upperBandwidth;
};

//===----------------------------------------------------------------------===//
// Value Properties Implementation
//===----------------------------------------------------------------------===//

enum class ValueProperty {
UnitDiagonal = 1 << 0,
Symmetric = 1 << 1,
Hermitian = 1 << 2,
BroadcastedScalar = 1 << 3,
};

class ValueProperties {
public:
ValueProperties() = default;
explicit ValueProperties(uint32_t flags) : flags(flags) {}

ValueProperties(Value v);

void set(ValueProperty property) { flags |= static_cast<uint32_t>(property); }
void clear(ValueProperty property) {
flags &= ~static_cast<uint32_t>(property);
}
bool has(ValueProperty property) const {
return flags & static_cast<uint32_t>(property);
}

bool hasUnitDiagonal() const { return has(ValueProperty::UnitDiagonal); }
bool isSymmetric() const { return has(ValueProperty::Symmetric); }
bool isHermitian() const { return has(ValueProperty::Hermitian); }
bool isBroadcastedScalar() const {
return has(ValueProperty::BroadcastedScalar);
}

void print(raw_ostream &os) const;
raw_ostream &operator<<(raw_ostream &os) const {
print(os);
return os;
}

uint32_t getFlags() const { return flags; }
void setFlags(uint32_t f) { flags = f; }

static ValueProperties meet(const ValueProperties &lhs,
const ValueProperties &rhs);

static ValueProperties join(const ValueProperties &lhs,
const ValueProperties &rhs);

bool operator==(const ValueProperties &other) const {
return flags == other.flags;
}

private:
static ValueProperties getPropertiesFromDenseAttr(DenseElementsAttr attr);

static bool isUnitDiagonal(DenseElementsAttr attr, int64_t nrows,
int64_t ncols);
static std::tuple<int64_t, int64_t>
isSymmetricOrHermitian(DenseElementsAttr, int64_t nrows, int64_t ncols);

uint32_t flags = 0;
};

//===----------------------------------------------------------------------===//
// Structured Matrix Type
//===----------------------------------------------------------------------===//

class StructuredMatrixType {
public:
StructuredMatrixType() = default;
StructuredMatrixType(StructuredSparsityPattern sparsityPattern,
ValueProperties valueProperties)
: sparsityPattern(sparsityPattern), valueProperties(valueProperties) {}

StructuredMatrixType(Value v)
: StructuredMatrixType(StructuredSparsityPattern(v), ValueProperties(v)) {
}

const StructuredSparsityPattern &getSparsityPattern() const {
return sparsityPattern;
}
const ValueProperties &getProperties() const { return valueProperties; }

static StructuredMatrixType meet(const StructuredMatrixType &lhs,
const StructuredMatrixType &rhs);

static StructuredMatrixType join(const StructuredMatrixType &lhs,
const StructuredMatrixType &rhs);

bool operator==(const StructuredMatrixType &other) const {
return sparsityPattern == other.sparsityPattern &&
valueProperties == other.valueProperties;
}

void print(raw_ostream &os) const;
raw_ostream &operator<<(raw_ostream &os) const {
print(os);
return os;
}

// propagation rules
static StructuredMatrixType
propagateTranspose(Value val, const StructuredMatrixType &op);

static StructuredMatrixType propagateAdd(Value lhs, Value rhs,
const StructuredMatrixType &lhsType,
const StructuredMatrixType &rhsType);

static StructuredMatrixType
propagateMultiply(Value lhs, Value rhs, const StructuredMatrixType &lhsType,
const StructuredMatrixType &rhsType);

static StructuredMatrixType
propagateElementwise(ArrayRef<Value> operands,
SmallVectorImpl<StructuredMatrixType> &operandsType);

// TODO: implement queries that check both the sparsity pattern and value
// properties and return specific matrix kinds

private:
StructuredSparsityPattern sparsityPattern;
ValueProperties valueProperties;
};

//===----------------------------------------------------------------------===//
// Lattice Element
//===----------------------------------------------------------------------===//

class StructuredMatrixLattice : public dataflow::AbstractSparseLattice {
public:
using AbstractSparseLattice::AbstractSparseLattice;

StructuredMatrixLattice(Value v)
: AbstractSparseLattice(v), value(StructuredMatrixType(v)) {}

ChangeResult meet(const AbstractSparseLattice &rhs) override;
ChangeResult meet(StructuredMatrixLattice rhs);

ChangeResult join(const AbstractSparseLattice &rhs) override;
ChangeResult join(StructuredMatrixLattice rhs);

void print(raw_ostream &os) const override;
raw_ostream &operator<<(raw_ostream &os) const {
print(os);
return os;
}

const StructuredMatrixType &getValue() const { return value; }
void setValue(const StructuredMatrixType &v) { value = v; }

private:
StructuredMatrixType value;
};

//===----------------------------------------------------------------------===//
// Dataflow Analysis
//===----------------------------------------------------------------------===//

class StructuredMatrixAnalysis
: public dataflow::SparseForwardDataFlowAnalysis<StructuredMatrixLattice> {
public:
using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis;

void setToEntryState(StructuredMatrixLattice *lattice) override;

LogicalResult
visitOperation(Operation *op,
ArrayRef<const StructuredMatrixLattice *> operands,
ArrayRef<StructuredMatrixLattice *> results) override;
};

} // namespace structure_analysis
} // namespace mlir
2 changes: 2 additions & 0 deletions src/enzyme_ad/jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -839,6 +839,7 @@ cc_library(
cc_library(
name = "XLADerivatives",
srcs = glob([
"Analysis/*.cpp",
"Implementations/*.cpp",
"Passes/*.cpp",
"Dialect/*.cpp",
Expand All @@ -848,6 +849,7 @@ cc_library(
"Utils.cpp",
],
hdrs = glob([
"Analysis/*.h",
"Implementations/*.h",
"Passes/*.h",
"Dialect/*.h",
Expand Down
Loading
Loading