Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions src/enzyme_ad/jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -569,11 +569,11 @@ gentbl_cc_library(
name = "DistributedInterfacesIncGen",
tbl_outs = [
(
["--gen-interface-decls"],
["--gen-op-interface-decls"],
"Dialect/Distributed/DistributedInterfaces.h.inc",
),
(
["--gen-interface-defs"],
["--gen-op-interface-defs"],
"Dialect/Distributed/DistributedInterfaces.cpp.inc",
),
],
Expand All @@ -584,6 +584,31 @@ gentbl_cc_library(
],
)

td_library(
name = "DistributedPassesTdFiles",
srcs = [
],
deps = [
"@llvm-project//mlir:PassBaseTdFiles",
],
)

gentbl_cc_library(
name = "DistributedPassesIncGen",
tbl_outs = [
(
[
"-gen-pass-decls",
"-name=distributed",
],
"Passes/Distributed/Passes.h.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "Passes/Distributed/Passes.td",
deps = [":DistributedPassesTdFiles"],
)

td_library(
name = "TesseraDialectTdFiles",
srcs = [
Expand Down Expand Up @@ -717,6 +742,7 @@ cc_library(
srcs = glob([
"Implementations/*.cpp",
"Passes/*.cpp",
"Passes/Distributed/*.cpp",
"Dialect/*.cpp",
"Dialect/Distributed/*.cpp",
"Dialect/Tessera/*.cpp",
Expand All @@ -726,6 +752,7 @@ cc_library(
hdrs = glob([
"Implementations/*.h",
"Passes/*.h",
"Passes/Distributed/*.h",
"Dialect/*.h",
"Dialect/Distributed/*.h",
"Dialect/Tessera/*.h",
Expand All @@ -744,7 +771,9 @@ cc_library(
deps = [
":CheckedRewrite",
":DistributedDialectIncGen",
":DistributedInterfacesIncGen",
":DistributedOpsIncGen",
":DistributedPassesIncGen",
":DistributedTypesIncGen",
":EnzymeHLOPatternsIncGen",
":EnzymeXLAAttrsIncGen",
Expand Down
2 changes: 2 additions & 0 deletions src/enzyme_ad/jax/Dialect/Distributed/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#define GET_TYPEDEF_CLASSES
#include "src/enzyme_ad/jax/Dialect/Distributed/DistributedTypes.cpp.inc"

#include "src/enzyme_ad/jax/Dialect/Distributed/DistributedInterfaces.cpp.inc"

// Initialize the dialect
void mlir::enzyme::distributed::DistributedDialect::initialize() {
addTypes<
Expand Down
35 changes: 30 additions & 5 deletions src/enzyme_ad/jax/Dialect/Distributed/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,40 @@
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Types.h"

// Include the dialect
#include "src/enzyme_ad/jax/Dialect/Distributed/DistributedDialect.h.inc"
// Traits and interfaces
#include "Traits.h"
// Types
#include "src/enzyme_ad/jax/Dialect/Distributed/DistributedDialect.h.inc"

#define GET_TYPEDEF_CLASSES
#include "src/enzyme_ad/jax/Dialect/Distributed/DistributedTypes.h.inc"
// Operations

#include "src/enzyme_ad/jax/Dialect/Distributed/DistributedInterfaces.h.inc"

#define GET_OP_CLASSES
#include "src/enzyme_ad/jax/Dialect/Distributed/DistributedOps.h.inc"

/**
* Convenience class to manage tokens, which are sometimes used as
* block args and other time as typed values.
*/
namespace mlir::enzyme::distributed {
class Token {
mlir::TypedValue<TokenType> typedValue;
mlir::BlockArgument blockArg;

public:
Token(mlir::BlockArgument arg) : blockArg(arg) {
typedValue = dyn_cast<mlir::TypedValue<TokenType>>(arg);
assert(typedValue && "Block arg is not a token");
}
Token(mlir::TypedValue<TokenType> val) : typedValue(val) {
assert(val && "Typed value is null");
blockArg = dyn_cast<mlir::BlockArgument>(val);
assert(blockArg && "Typed value is not a block argument");
}

const mlir::TypedValue<TokenType> asTypedValue() const { return typedValue; }
const mlir::BlockArgument asBlockArg() const { return blockArg; }
};
} // namespace mlir::enzyme::distributed

#endif // ENZYME_AD_JAX_DIALECT_DISTRIBUTED_DIALECT_H
25 changes: 25 additions & 0 deletions src/enzyme_ad/jax/Dialect/Distributed/Interfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,29 @@ include "mlir/IR/OpBase.td"
def DeviceDefTrait : NativeOpTrait<"enzyme::distributed::DeviceDefTrait">;
def ChannelDefTrait : NativeOpTrait<"enzyme::distributed::ChannelDefTrait">;

def TokenReaderOpInterface : OpInterface<"TokenReaderOpInterface"> {
let cppNamespace = "::mlir::enzyme::distributed";
let description = [{
An interface to determine which ops can read from a channel and what type they expect.
Ops may read from multiple channels.
}];
let methods = [
InterfaceMethod<"Returns the SSA values of tokens read from this op.", "::llvm::ArrayRef<::mlir::TypedValue<::mlir::enzyme::distributed::TokenType>>", "getReadTokens">,
InterfaceMethod<"Returns the types of tokens read from this op, parallel to getReadTokens.", "::llvm::ArrayRef<::mlir::Type>", "getReadTokenTypes">
];
}

def TokenWriterOpInterface : OpInterface<"TokenWriterOpInterface"> {
let cppNamespace = "::mlir::enzyme::distributed";
let description = [{
An interface to determine which ops can write to a channel and what type they provide.
Ops may write to multiple channels.
}];
let methods = [
InterfaceMethod<"Returns the SSA values of tokens written from this op.", "::llvm::ArrayRef<::mlir::TypedValue<::mlir::enzyme::distributed::TokenType>>", "getWriteTokens">,
InterfaceMethod<"Returns the types of tokens written from this op, parallel to getWriteTokens.", "::llvm::ArrayRef<::mlir::Type>", "getWriteTokenTypes">
];
}


#endif // ENZYME_AD_JAX_DIALECT_DISTRIBUTED_INTERFACES
119 changes: 95 additions & 24 deletions src/enzyme_ad/jax/Dialect/Distributed/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "llvm/ADT/TypeSwitch.h"

#include "Dialect.h"
#include "Utils.h"

using mlir::OpTrait::enzyme::distributed::ChannelDefTrait;
using mlir::OpTrait::enzyme::distributed::DeviceDefTrait;
Expand Down Expand Up @@ -98,38 +99,108 @@ DeviceMeshOp::verifySymbolUses(::mlir::SymbolTableCollection &symbol_table) {
getDeviceType());
}

LogicalResult
MeshForOp::verifySymbolUses(::mlir::SymbolTableCollection &symbol_table) {
// Mesh for ops apply only to meshes
return checkSymbolIsA<DeviceMeshOp>(symbol_table, *this, getMeshAttr());
Operation *DeviceParallelOp::getEnclosingDeviceOp() {
return mlir::SymbolTable::lookupNearestSymbolFrom(*this,
getEnclosingDeviceAttr());
}

LogicalResult
GroupSplitOp::verifySymbolUses(::mlir::SymbolTableCollection &symbol_table) {
// Group splits apply only to device groups
return checkSymbolIsA<DeviceGroupOp>(symbol_table, *this,
getDeviceGroupAttr());
LogicalResult DeviceParallelOp::verifySymbolUses(
::mlir::SymbolTableCollection &symbol_table) {
Operation *device_op = this->getEnclosingDeviceOp();
if (isa<DeviceGroupOp>(device_op) || isa<DeviceMeshOp>(device_op)) {
return mlir::success();
}
return emitOpError()
<< "enclosing device symbol must refer to a device group or mesh";
}

LogicalResult
SplitBranchOp::verifySymbolUses(::mlir::SymbolTableCollection &symbol_table) {
// Split branches have programs for individual devices or channels
Operation *dev_or_chan =
symbol_table.lookupNearestSymbolFrom(*this, getDeviceOrChannelAttr());
if (!dev_or_chan || !(dev_or_chan->hasTrait<DeviceDefTrait>() ||
dev_or_chan->hasTrait<ChannelDefTrait>())) {
mlir::emitError(getLoc())
<< "branches must reference a valid device or channel";
return mlir::failure();
LogicalResult DeviceParallelOp::verify() {
// Check number of branches matches number of assignments

if (getNumRegions() != getBranchAssignments().size()) {
return emitOpError()
<< "number of regions must match number of branch assignments";
}

// Look at device type to determine number of branches
auto device_op = mlir::SymbolTable::lookupNearestSymbolFrom(
*this, getEnclosingDeviceAttr());
if (!device_op) {
return emitOpError() << "could not find enclosing device symbol";
}

if (DeviceGroupOp deviceGroup = dyn_cast<DeviceGroupOp>(device_op)) {
// Device group: number of branches must match number of devices in group
auto devices = deviceGroup.getDevices();
auto channels = deviceGroup.getChannels();
if (getNumRegions() != devices.size() + channels.size()) {
return emitOpError() << "number of regions must match number of devices "
"and channels in device group";
}
} else if (DeviceMeshOp mesh = dyn_cast<DeviceMeshOp>(device_op)) {
// Exactly one branch for the mesth type
if (getNumRegions() != 1) {
return emitOpError()
<< "device mesh must have exactly one region for its single type";
}
} else {
return emitOpError()
<< "enclosing device symbol must refer to a device group or mesh";
}

return mlir::success();
}

LogicalResult
DefineTokenOp::verifySymbolUses(::mlir::SymbolTableCollection &symbol_table) {
// Tokens need to indicate which channel they communicate over
return checkSymbolHasTrait<ChannelDefTrait>(symbol_table, *this,
getChannelAttr());
// Printer/parser for subdevice branches
mlir::ParseResult parseDeviceBranches(
OpAsmParser &parser, mlir::ArrayAttr &branchAssignments,
llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> &branchesRegions) {
// Expect 0 or more `branch` $symbol_name $symbol_region
// While next token is `branch`:
llvm::SmallVector<mlir::Attribute, 2> assignment_symbols;
while (parser.parseOptionalKeyword("branch").succeeded()) {
// Parse symbol name
mlir::SymbolRefAttr sym;
auto sym_parse_failed = parser.parseAttribute<mlir::SymbolRefAttr>(sym);
if (sym_parse_failed)
return mlir::failure();
assignment_symbols.push_back(sym);

// Put placeholder region in list and parse into it
branchesRegions.push_back(std::make_unique<mlir::Region>());
auto parse_region_failed = parser.parseRegion(*branchesRegions.back());
if (parse_region_failed)
return mlir::failure();
}

branchAssignments = mlir::ArrayAttr::get(parser.getBuilder().getContext(),
assignment_symbols);
return mlir::success();
}

void printDeviceBranches(OpAsmPrinter &printer, const DeviceParallelOp &op,
const mlir::ArrayAttr branchAssignments,
const llvm::MutableArrayRef<mlir::Region> branches) {
// Print each branch as `branch` $symbol_name $symbol_region
for (size_t i = 0; i < branches.size(); i++) {
printer << " branch ";
printer.printAttribute(branchAssignments[i]);
printer.printRegion(branches[i]);
}
}

llvm::ArrayRef<mlir::TypedValue<TokenType>> SendOp::getWriteTokens() {
return llvm::SmallVector<mlir::TypedValue<TokenType>, 1>{getToken()};
}
llvm::ArrayRef<mlir::Type> SendOp::getWriteTokenTypes() {
return llvm::SmallVector<mlir::Type, 1>{getValue().getType()};
}

llvm::ArrayRef<mlir::TypedValue<TokenType>> RecvOp::getReadTokens() {
return llvm::SmallVector<mlir::TypedValue<TokenType>, 1>{getToken()};
}
llvm::ArrayRef<mlir::Type> RecvOp::getReadTokenTypes() {
return llvm::SmallVector<mlir::Type, 1>{getValue().getType()};
}

} // namespace mlir::enzyme::distributed
Expand Down
Loading
Loading