Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP Add various llvm raising passes #265

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/enzyme_ad/jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -319,10 +319,20 @@ cc_library(
":chlo-derivatives",
"@enzyme//:EnzymeMLIR",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:AffineAnalysis",
"@llvm-project//mlir:AffineTransforms",
"@llvm-project//mlir:MemRefUtils",
"@llvm-project//mlir:MemorySlotInterfaces",
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:AffineUtils",
"@llvm-project//mlir:ControlFlowDialect",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:ControlFlowToSCF",
"@llvm-project//mlir:ArithUtils",
"@llvm-project//mlir:SCFUtils",
"@llvm-project//mlir:UBDialect",
"@llvm-project//mlir:DataLayoutInterfaces",
"@llvm-project//mlir:CommonFolders",
"@llvm-project//mlir:ComplexDialect",
"@llvm-project//mlir:ControlFlowInterfaces",
Expand Down
37 changes: 37 additions & 0 deletions src/enzyme_ad/jax/Dialect/Dialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

include "mlir/IR/OpBase.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

//===----------------------------------------------------------------------===//
// Enzyme dialect definition.
Expand All @@ -33,4 +34,40 @@ class EnzymeXLA_Op<string mnemonic, list<Trait> traits = []>

class EnzymeXLA_Type<string name> : TypeDef<EnzymeXLA_Dialect, name>;

//===----------------------------------------------------------------------===//
// AtAddrOp
//===----------------------------------------------------------------------===//

def AtAddrOp : EnzymeXLA_Op<"ataddr", [Pure]> {
let summary =
"Construct a c-style memref at an addr";
let description = [{
}];
let arguments = (ins AnyType:$addr);
let results = (outs AnyRankedOrUnrankedMemRef:$dest);
let builders = [
OpBuilder<(ins "Value":$source), [{
return build($_builder, $_state,
TypeRange(MemRefType::get({ShapedType::kDynamic}, $_builder.getI8Type())),
ValueRange(source));
}]>,
OpBuilder<(ins "Value":$source, "Type":$dstType), [{
return build($_builder, $_state,
TypeRange(dstType),
ValueRange(source));
}]>
];
}

def AffineScopeOp : EnzymeXLA_Op<"scope", [
AffineScope,
AutomaticAllocationScope,
RecursiveMemoryEffects,
]>,
Arguments<(ins Variadic<AnyType>:$operands)>,
Results<(outs Variadic<AnyType>:$results)> {
let summary = "Inline affine scope";
let regions = (region SizedRegion<1>:$region);
}

#endif // ENZYMEXLA_DIALECT
110 changes: 110 additions & 0 deletions src/enzyme_ad/jax/Passes/BarrierUtils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
//===- BarrierUtil.h - Utilities for barrier removal --------* C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_LIB_DIALECT_SCF_TRANSFORMS_BARRIERUTILS_H_
#define MLIR_LIB_DIALECT_SCF_TRANSFORMS_BARRIERUTILS_H_

#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Block.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/ErrorHandling.h"

std::pair<mlir::Block *, mlir::Block::iterator>
findInsertionPointAfterLoopOperands(mlir::scf::ParallelOp op);

/// Emits the IR computing the total number of iterations in the loop. We don't
/// need to linearize them since we can allocate an nD array instead.
static llvm::SmallVector<mlir::Value>
emitIterationCounts(mlir::OpBuilder &rewriter, mlir::scf::ParallelOp op) {
using namespace mlir;
SmallVector<Value> iterationCounts;
for (auto bounds :
llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep())) {
Value lowerBound = std::get<0>(bounds);
Value upperBound = std::get<1>(bounds);
Value step = std::get<2>(bounds);
Value diff =
rewriter.create<arith::SubIOp>(op.getLoc(), upperBound, lowerBound);
Value count = rewriter.create<arith::CeilDivSIOp>(op.getLoc(), diff, step);
iterationCounts.push_back(count);
}
return iterationCounts;
}

mlir::Value callMalloc(mlir::OpBuilder &builder, mlir::ModuleOp module,
mlir::Location loc, mlir::Value arg);
mlir::LLVM::LLVMFuncOp GetOrCreateFreeFunction(mlir::ModuleOp module);

template <typename T>
static mlir::Value
allocateTemporaryBuffer(mlir::OpBuilder &rewriter, mlir::Value value,
mlir::ValueRange iterationCounts, bool alloca = true,
mlir::DataLayout *DLI = nullptr,
mlir::Attribute memorySpace = nullptr);
template <>
mlir::Value allocateTemporaryBuffer<mlir::memref::AllocaOp>(
mlir::OpBuilder &rewriter, mlir::Value value,
mlir::ValueRange iterationCounts, bool alloca, mlir::DataLayout *DLI,
mlir::Attribute memorySpace) {
using namespace mlir;
SmallVector<int64_t> bufferSize(iterationCounts.size(), ShapedType::kDynamic);
mlir::Type elty = value.getType();
mlir::Type ty = rewriter.getI8Type();
if (alloca) {
if (auto allocaOp = value.getDefiningOp<memref::AllocaOp>()) {
auto mt = allocaOp.getType();
bool hasDynamicSize = false;
for (auto s : mt.getShape()) {
if (s == ShapedType::kDynamic) {
hasDynamicSize = true;
break;
}
}
if (!hasDynamicSize) {
for (auto s : mt.getShape()) {
bufferSize.push_back(s);
}
ty = mt.getElementType();
}
} else {
bufferSize.push_back(DLI->getTypeSize(elty));
}
}
auto type =
MemRefType::get(bufferSize, ty, MemRefLayoutAttrInterface{}, memorySpace);
return rewriter.create<memref::AllocaOp>(value.getLoc(), type,
iterationCounts);
}

template <>
mlir::Value allocateTemporaryBuffer<mlir::LLVM::AllocaOp>(
mlir::OpBuilder &rewriter, mlir::Value value,
mlir::ValueRange iterationCounts, bool alloca, mlir::DataLayout *DLI,
mlir::Attribute memorySpace) {
llvm_unreachable("llvm alloca");
using namespace mlir;
auto val = value.getDefiningOp<LLVM::AllocaOp>();
auto sz = val.getArraySize();
assert(DLI);
for (auto iter : iterationCounts) {
sz = cast<TypedValue<IntegerType>>(
rewriter
.create<arith::MulIOp>(value.getLoc(), sz,
rewriter.create<arith::IndexCastOp>(
value.getLoc(), sz.getType(), iter))
.getResult());
}
return rewriter.create<LLVM::AllocaOp>(value.getLoc(), val.getType(), sz);
}

#endif // MLIR_LIB_DIALECT_SCF_TRANSFORMS_BARRIERUTILS_H_
41 changes: 41 additions & 0 deletions src/enzyme_ad/jax/Passes/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
add_flag_if_supported("-Wno-global-constructors" WNO_GLOBAL_CONSTRUCTOR_MLIR_IR)
add_mlir_conversion_library(MLIRGPULaunchToCall
LoopDistribute.cpp
GPULaunchToCall.cpp
OutlineGpuJitRegions.cpp
LLVMToMemref.cpp
PromoteWhile.cpp
GPUAffineOpt.cpp
ReshapeMemrefs.cpp
DependenceInfo.cpp
ISLUtils.cpp
GPULowering.cpp
LoopUndistribute.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/GPULaunchToCall

DEPENDS
MLIRConversionPassIncGen
intrinsics_gen

LINK_COMPONENTS
Core

LINK_LIBS PUBLIC
MLIRGPUDialect
MLIRAffineDialect
MLIRSCFDialect
MLIRPass
PolymerTransforms
PolymerSupport
PolymerTargetISL
)

target_include_directories(MLIRGPULaunchToCall
PRIVATE
${POLYMER_MAIN_INCLUDE_DIR}/polymer/Target/ISL
${MLIR_MAIN_INCLUDE_DIR}/../../polly/include/
${MLIR_MAIN_INCLUDE_DIR}/../../polly/lib/External/isl/include/
${LLVM_BINARY_DIR}/tools/polly/lib/External/isl/include
)
77 changes: 77 additions & 0 deletions src/enzyme_ad/jax/Passes/ControlFlowToSCF.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
//===- ControlFlowToSCF.h - ControlFlow to SCF -------------*- C++ ------*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Define conversions from the ControlFlow dialect to the SCF dialect.
//
//===----------------------------------------------------------------------===//

#include "Passes.h"

#include "mlir/Conversion/ControlFlowToSCF/ControlFlowToSCF.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/CFGToSCF.h"

namespace mlir {
#define GEN_PASS_DEF_ENZYMELIFTCONTROLFLOWTOSCFPASS
#include "src/enzyme_ad/jax/Passes/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

struct EnzymeLiftControlFlowToSCF
: public impl::EnzymeLiftControlFlowToSCFPassBase<EnzymeLiftControlFlowToSCF> {

using Base::Base;

void runOnOperation() override {
ControlFlowToSCFTransformation transformation;

bool changed = false;
Operation *op = getOperation();
WalkResult result = op->walk([&](Region *region) {
if (region->empty())
return WalkResult::advance();

Operation *regionParent = region->getParentOp();
auto &domInfo = regionParent != op
? getChildAnalysis<DominanceInfo>(regionParent)
: getAnalysis<DominanceInfo>();

auto visitor = [&](Operation *innerOp) -> WalkResult {
for (Region &reg : innerOp->getRegions()) {
FailureOr<bool> changedFunc =
transformCFGToSCF(reg, transformation, domInfo);
if (failed(changedFunc))
return WalkResult::interrupt();

changed |= *changedFunc;
}
return WalkResult::advance();
};

if (region->walk<WalkOrder::PostOrder>(visitor).wasInterrupted())
return WalkResult::interrupt();

return WalkResult::advance();
});
if (result.wasInterrupted())
return signalPassFailure();

if (!changed)
markAllAnalysesPreserved();
}
};
} // namespace
Loading
Loading