diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD index 24ac998d4d..70a02a2c43 100644 --- a/src/enzyme_ad/jax/BUILD +++ b/src/enzyme_ad/jax/BUILD @@ -569,11 +569,11 @@ gentbl_cc_library( name = "DistributedInterfacesIncGen", tbl_outs = [ ( - ["--gen-interface-decls"], + ["--gen-op-interface-decls"], "Dialect/Distributed/DistributedInterfaces.h.inc", ), ( - ["--gen-interface-defs"], + ["--gen-op-interface-defs"], "Dialect/Distributed/DistributedInterfaces.cpp.inc", ), ], @@ -584,6 +584,31 @@ gentbl_cc_library( ], ) +td_library( + name = "DistributedPassesTdFiles", + srcs = [ + ], + deps = [ + "@llvm-project//mlir:PassBaseTdFiles", + ], +) + +gentbl_cc_library( + name = "DistributedPassesIncGen", + tbl_outs = [ + ( + [ + "-gen-pass-decls", + "-name=distributed", + ], + "Passes/Distributed/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "Passes/Distributed/Passes.td", + deps = [":DistributedPassesTdFiles"], +) + td_library( name = "TesseraDialectTdFiles", srcs = [ @@ -717,6 +742,7 @@ cc_library( srcs = glob([ "Implementations/*.cpp", "Passes/*.cpp", + "Passes/Distributed/*.cpp", "Dialect/*.cpp", "Dialect/Distributed/*.cpp", "Dialect/Tessera/*.cpp", @@ -726,6 +752,7 @@ cc_library( hdrs = glob([ "Implementations/*.h", "Passes/*.h", + "Passes/Distributed/*.h", "Dialect/*.h", "Dialect/Distributed/*.h", "Dialect/Tessera/*.h", @@ -744,7 +771,9 @@ cc_library( deps = [ ":CheckedRewrite", ":DistributedDialectIncGen", + ":DistributedInterfacesIncGen", ":DistributedOpsIncGen", + ":DistributedPassesIncGen", ":DistributedTypesIncGen", ":EnzymeHLOPatternsIncGen", ":EnzymeXLAAttrsIncGen", diff --git a/src/enzyme_ad/jax/Dialect/Distributed/Dialect.cpp b/src/enzyme_ad/jax/Dialect/Distributed/Dialect.cpp index 343d327678..969b513e6e 100644 --- a/src/enzyme_ad/jax/Dialect/Distributed/Dialect.cpp +++ b/src/enzyme_ad/jax/Dialect/Distributed/Dialect.cpp @@ -8,6 +8,8 @@ #define GET_TYPEDEF_CLASSES #include "src/enzyme_ad/jax/Dialect/Distributed/DistributedTypes.cpp.inc" +#include "src/enzyme_ad/jax/Dialect/Distributed/DistributedInterfaces.cpp.inc" + // Initialize the dialect void mlir::enzyme::distributed::DistributedDialect::initialize() { addTypes< diff --git a/src/enzyme_ad/jax/Dialect/Distributed/Dialect.h b/src/enzyme_ad/jax/Dialect/Distributed/Dialect.h index e8277364c6..e432c4c334 100644 --- a/src/enzyme_ad/jax/Dialect/Distributed/Dialect.h +++ b/src/enzyme_ad/jax/Dialect/Distributed/Dialect.h @@ -10,15 +10,40 @@ #include "mlir/IR/SymbolTable.h" #include "mlir/IR/Types.h" -// Include the dialect -#include "src/enzyme_ad/jax/Dialect/Distributed/DistributedDialect.h.inc" -// Traits and interfaces #include "Traits.h" -// Types +#include "src/enzyme_ad/jax/Dialect/Distributed/DistributedDialect.h.inc" + #define GET_TYPEDEF_CLASSES #include "src/enzyme_ad/jax/Dialect/Distributed/DistributedTypes.h.inc" -// Operations + +#include "src/enzyme_ad/jax/Dialect/Distributed/DistributedInterfaces.h.inc" + #define GET_OP_CLASSES #include "src/enzyme_ad/jax/Dialect/Distributed/DistributedOps.h.inc" +/** + * Convenience class to manage tokens, which are sometimes used as + * block args and other time as typed values. + */ +namespace mlir::enzyme::distributed { +class Token { + mlir::TypedValue typedValue; + mlir::BlockArgument blockArg; + +public: + Token(mlir::BlockArgument arg) : blockArg(arg) { + typedValue = dyn_cast>(arg); + assert(typedValue && "Block arg is not a token"); + } + Token(mlir::TypedValue val) : typedValue(val) { + assert(val && "Typed value is null"); + blockArg = dyn_cast(val); + assert(blockArg && "Typed value is not a block argument"); + } + + const mlir::TypedValue asTypedValue() const { return typedValue; } + const mlir::BlockArgument asBlockArg() const { return blockArg; } +}; +} // namespace mlir::enzyme::distributed + #endif // ENZYME_AD_JAX_DIALECT_DISTRIBUTED_DIALECT_H diff --git a/src/enzyme_ad/jax/Dialect/Distributed/Interfaces.td b/src/enzyme_ad/jax/Dialect/Distributed/Interfaces.td index 8fd74fd4e2..2908c5fb50 100644 --- a/src/enzyme_ad/jax/Dialect/Distributed/Interfaces.td +++ b/src/enzyme_ad/jax/Dialect/Distributed/Interfaces.td @@ -6,4 +6,29 @@ include "mlir/IR/OpBase.td" def DeviceDefTrait : NativeOpTrait<"enzyme::distributed::DeviceDefTrait">; def ChannelDefTrait : NativeOpTrait<"enzyme::distributed::ChannelDefTrait">; +def TokenReaderOpInterface : OpInterface<"TokenReaderOpInterface"> { + let cppNamespace = "::mlir::enzyme::distributed"; + let description = [{ + An interface to determine which ops can read from a channel and what type they expect. + Ops may read from multiple channels. + }]; + let methods = [ + InterfaceMethod<"Returns the SSA values of tokens read from this op.", "::llvm::ArrayRef<::mlir::TypedValue<::mlir::enzyme::distributed::TokenType>>", "getReadTokens">, + InterfaceMethod<"Returns the types of tokens read from this op, parallel to getReadTokens.", "::llvm::ArrayRef<::mlir::Type>", "getReadTokenTypes"> + ]; +} + +def TokenWriterOpInterface : OpInterface<"TokenWriterOpInterface"> { + let cppNamespace = "::mlir::enzyme::distributed"; + let description = [{ + An interface to determine which ops can write to a channel and what type they provide. + Ops may write to multiple channels. + }]; + let methods = [ + InterfaceMethod<"Returns the SSA values of tokens written from this op.", "::llvm::ArrayRef<::mlir::TypedValue<::mlir::enzyme::distributed::TokenType>>", "getWriteTokens">, + InterfaceMethod<"Returns the types of tokens written from this op, parallel to getWriteTokens.", "::llvm::ArrayRef<::mlir::Type>", "getWriteTokenTypes"> + ]; +} + + #endif // ENZYME_AD_JAX_DIALECT_DISTRIBUTED_INTERFACES \ No newline at end of file diff --git a/src/enzyme_ad/jax/Dialect/Distributed/Ops.cpp b/src/enzyme_ad/jax/Dialect/Distributed/Ops.cpp index 21d62cecf6..894f989315 100644 --- a/src/enzyme_ad/jax/Dialect/Distributed/Ops.cpp +++ b/src/enzyme_ad/jax/Dialect/Distributed/Ops.cpp @@ -2,6 +2,7 @@ #include "llvm/ADT/TypeSwitch.h" #include "Dialect.h" +#include "Utils.h" using mlir::OpTrait::enzyme::distributed::ChannelDefTrait; using mlir::OpTrait::enzyme::distributed::DeviceDefTrait; @@ -98,38 +99,108 @@ DeviceMeshOp::verifySymbolUses(::mlir::SymbolTableCollection &symbol_table) { getDeviceType()); } -LogicalResult -MeshForOp::verifySymbolUses(::mlir::SymbolTableCollection &symbol_table) { - // Mesh for ops apply only to meshes - return checkSymbolIsA(symbol_table, *this, getMeshAttr()); +Operation *DeviceParallelOp::getEnclosingDeviceOp() { + return mlir::SymbolTable::lookupNearestSymbolFrom(*this, + getEnclosingDeviceAttr()); } -LogicalResult -GroupSplitOp::verifySymbolUses(::mlir::SymbolTableCollection &symbol_table) { - // Group splits apply only to device groups - return checkSymbolIsA(symbol_table, *this, - getDeviceGroupAttr()); +LogicalResult DeviceParallelOp::verifySymbolUses( + ::mlir::SymbolTableCollection &symbol_table) { + Operation *device_op = this->getEnclosingDeviceOp(); + if (isa(device_op) || isa(device_op)) { + return mlir::success(); + } + return emitOpError() + << "enclosing device symbol must refer to a device group or mesh"; } -LogicalResult -SplitBranchOp::verifySymbolUses(::mlir::SymbolTableCollection &symbol_table) { - // Split branches have programs for individual devices or channels - Operation *dev_or_chan = - symbol_table.lookupNearestSymbolFrom(*this, getDeviceOrChannelAttr()); - if (!dev_or_chan || !(dev_or_chan->hasTrait() || - dev_or_chan->hasTrait())) { - mlir::emitError(getLoc()) - << "branches must reference a valid device or channel"; - return mlir::failure(); +LogicalResult DeviceParallelOp::verify() { + // Check number of branches matches number of assignments + + if (getNumRegions() != getBranchAssignments().size()) { + return emitOpError() + << "number of regions must match number of branch assignments"; } + + // Look at device type to determine number of branches + auto device_op = mlir::SymbolTable::lookupNearestSymbolFrom( + *this, getEnclosingDeviceAttr()); + if (!device_op) { + return emitOpError() << "could not find enclosing device symbol"; + } + + if (DeviceGroupOp deviceGroup = dyn_cast(device_op)) { + // Device group: number of branches must match number of devices in group + auto devices = deviceGroup.getDevices(); + auto channels = deviceGroup.getChannels(); + if (getNumRegions() != devices.size() + channels.size()) { + return emitOpError() << "number of regions must match number of devices " + "and channels in device group"; + } + } else if (DeviceMeshOp mesh = dyn_cast(device_op)) { + // Exactly one branch for the mesth type + if (getNumRegions() != 1) { + return emitOpError() + << "device mesh must have exactly one region for its single type"; + } + } else { + return emitOpError() + << "enclosing device symbol must refer to a device group or mesh"; + } + return mlir::success(); } -LogicalResult -DefineTokenOp::verifySymbolUses(::mlir::SymbolTableCollection &symbol_table) { - // Tokens need to indicate which channel they communicate over - return checkSymbolHasTrait(symbol_table, *this, - getChannelAttr()); +// Printer/parser for subdevice branches +mlir::ParseResult parseDeviceBranches( + OpAsmParser &parser, mlir::ArrayAttr &branchAssignments, + llvm::SmallVector, 2> &branchesRegions) { + // Expect 0 or more `branch` $symbol_name $symbol_region + // While next token is `branch`: + llvm::SmallVector assignment_symbols; + while (parser.parseOptionalKeyword("branch").succeeded()) { + // Parse symbol name + mlir::SymbolRefAttr sym; + auto sym_parse_failed = parser.parseAttribute(sym); + if (sym_parse_failed) + return mlir::failure(); + assignment_symbols.push_back(sym); + + // Put placeholder region in list and parse into it + branchesRegions.push_back(std::make_unique()); + auto parse_region_failed = parser.parseRegion(*branchesRegions.back()); + if (parse_region_failed) + return mlir::failure(); + } + + branchAssignments = mlir::ArrayAttr::get(parser.getBuilder().getContext(), + assignment_symbols); + return mlir::success(); +} + +void printDeviceBranches(OpAsmPrinter &printer, const DeviceParallelOp &op, + const mlir::ArrayAttr branchAssignments, + const llvm::MutableArrayRef branches) { + // Print each branch as `branch` $symbol_name $symbol_region + for (size_t i = 0; i < branches.size(); i++) { + printer << " branch "; + printer.printAttribute(branchAssignments[i]); + printer.printRegion(branches[i]); + } +} + +llvm::ArrayRef> SendOp::getWriteTokens() { + return llvm::SmallVector, 1>{getToken()}; +} +llvm::ArrayRef SendOp::getWriteTokenTypes() { + return llvm::SmallVector{getValue().getType()}; +} + +llvm::ArrayRef> RecvOp::getReadTokens() { + return llvm::SmallVector, 1>{getToken()}; +} +llvm::ArrayRef RecvOp::getReadTokenTypes() { + return llvm::SmallVector{getValue().getType()}; } } // namespace mlir::enzyme::distributed diff --git a/src/enzyme_ad/jax/Dialect/Distributed/Ops.td b/src/enzyme_ad/jax/Dialect/Distributed/Ops.td index 0857ce2f2c..b330d53813 100644 --- a/src/enzyme_ad/jax/Dialect/Distributed/Ops.td +++ b/src/enzyme_ad/jax/Dialect/Distributed/Ops.td @@ -10,7 +10,7 @@ include "Interfaces.td" // Device definition ops -def ChannelOp : DistributedOp<"Channel", [Symbol, ChannelDefTrait, DeclareOpInterfaceMethods]>{ +def ChannelOp : DistributedOp<"channel", [Symbol, ChannelDefTrait, DeclareOpInterfaceMethods]>{ let arguments = (ins SymbolNameAttr:$sym_name, // a variadic list of devices connected by this channel @@ -21,7 +21,7 @@ def ChannelOp : DistributedOp<"Channel", [Symbol, ChannelDefTrait, DeclareOpInte let assemblyFormat = "$sym_name $sending_devices $receiving_devices attr-dict"; } -def LeafDeviceOp : DistributedOp<"LeafDevice", [Symbol, DeviceDefTrait]>{ +def LeafDeviceOp : DistributedOp<"leaf_device", [Symbol, DeviceDefTrait]>{ let arguments = (ins SymbolNameAttr:$sym_name // TODO: device type, e.g. TPU, GPU, CPU, and other attributes @@ -29,7 +29,7 @@ def LeafDeviceOp : DistributedOp<"LeafDevice", [Symbol, DeviceDefTrait]>{ let assemblyFormat = "$sym_name attr-dict"; } -def DeviceGroupOp : DistributedOp<"DeviceGroup", [Symbol, DeviceDefTrait, DeclareOpInterfaceMethods]>{ +def DeviceGroupOp : DistributedOp<"device_group", [Symbol, DeviceDefTrait, DeclareOpInterfaceMethods]>{ let arguments = (ins SymbolNameAttr:$sym_name, // a variadic list of devices in the group @@ -39,7 +39,7 @@ def DeviceGroupOp : DistributedOp<"DeviceGroup", [Symbol, DeviceDefTrait, Declar ); let assemblyFormat = "$sym_name $devices $channels attr-dict"; } -def DeviceMeshOp : DistributedOp<"DeviceMesh", [Symbol, DeviceDefTrait, DeclareOpInterfaceMethods]>{ +def DeviceMeshOp : DistributedOp<"device_mesh", [Symbol, DeviceDefTrait, DeclareOpInterfaceMethods]>{ let arguments = (ins SymbolNameAttr:$sym_name, SymbolRefAttr:$device_type, @@ -49,44 +49,56 @@ def DeviceMeshOp : DistributedOp<"DeviceMesh", [Symbol, DeviceDefTrait, DeclareO let assemblyFormat = "$sym_name $device_type $shape attr-dict"; } -// Ops for breaking down computation across the device hierarchy +// def ContinueOp : DistributedOp<"continue", [Terminator]> { +// let description = [{ +// A terminator for DeviceParallelOp regions. Takes as arguments the tokens to be passed to the +// continuation of the DeviceParallelOp. These values can then be used in a subsequent DeviceParallelOp +// that is a sibling to the original DeviceParallelOp by referencing the returned tokens. +// }]; +// let arguments = (ins Variadic:$operands); +// let results = (outs ); // No outputs for terminators, the token is output by the parent DeviceParallelOp. +// let assemblyFormat = "$operands type($operands) attr-dict"; +// } -def MeshForOp : DistributedOp<"MeshFor", [DeclareOpInterfaceMethods, NoTerminator, SingleBlock]>{ - let arguments = (ins SymbolRefAttr:$mesh); // TODO: verify it's a mesh - let regions = (region MaxSizedRegion<1>:$body); // TODO: body's block args are device type and mesh index - let results = (outs ); // TODO - // let hasVerifier = 1; // TODO: verify body's block args take mesh index - let assemblyFormat = "$mesh $body attr-dict"; -} +def DeviceParallelOp : DistributedOp<"device_parallel", [DeclareOpInterfaceMethods, NoTerminator]>{ + let description = [{ + An op for mapping computations to subdevices. Serves both for homogenous device meshes as well + as explicitly enumerated device groups. In the case of device meshes, this op should contain + a single region to be executed in parallel on each device. In the case of device groups, this + op should contain one region per device and channel in the group. -def GroupSplitOp : DistributedOp<"GroupSplit", [DeclareOpInterfaceMethods, NoTerminator, SingleBlock]>{ + In either case, regions must take as argument one device index within the parent device followed + by a number of token arguments. Tokens are matched by positionally between different branches, + and all branches must have the same number and type of token arguments (though they may be unused). + }]; + let arguments = (ins - SymbolRefAttr:$device_group // TODO: verify it's a group + SymbolRefAttr:$enclosing_device, + ArrayAttr:$branch_assignments // the device components for each region (device-specific branch) ); - let regions = (region SizedRegion<1>:$declarations); // Takes as args the devices and channels in the group - let results = (outs ); // TODO - // let hasVerifier = 1; // TODO - // let hasCanonicalizer = 1; // TODO: token declarations up front, followed by device and channel branches in order of listing in the group - let assemblyFormat = "$device_group $declarations attr-dict"; + let regions = (region VariadicRegion>:$branches); + // let results = (outs Variadic:$continuation_tokens); + let results = (outs ); + let hasVerifier = 1; // TODO + let assemblyFormat = "$enclosing_device `{` custom($branch_assignments, $branches) `}` attr-dict"; + let extraClassDeclaration = [{ + Operation* getEnclosingDeviceOp(); + }]; } -def SplitBranchOp : DistributedOp<"SplitBranch", [DeclareOpInterfaceMethods, NoTerminator, SingleBlock]>{ +def SendOp : DistributedOp<"send", [DeclareOpInterfaceMethods]>{ let arguments = (ins - SymbolRefAttr:$device_or_channel // TODO: verify it's a device or channel - ); - let regions = (region MaxSizedRegion<1>:$body); // Takes as args the device or channel - let results = (outs ); // TODO - // let hasVerifier = 1; // TODO: parent is a groupsplitop - let assemblyFormat = "$device_or_channel $body attr-dict"; + TokenType:$token, + // value to send + AnyType:$value); + let assemblyFormat = "$token type($value) $value attr-dict"; } -def DefineTokenOp : DistributedOp<"DefineToken", [DeclareOpInterfaceMethods]>{ +def RecvOp : DistributedOp<"recv", [DeclareOpInterfaceMethods]>{ let arguments = (ins - SymbolRefAttr:$channel - ); - let results = (outs TokenType:$token_out); - // let hasVerifier = 1; // TODO: verify writers and readers are connected to the channel - let assemblyFormat = "$channel attr-dict"; + TokenType:$token); + let results = (outs AnyType:$value); + let assemblyFormat = "$token type($value) attr-dict"; } #endif // ENZYME_AD_JAX_DIALECT_DISTRIBUTED_OPS_TD \ No newline at end of file diff --git a/src/enzyme_ad/jax/Dialect/Distributed/Utils.cpp b/src/enzyme_ad/jax/Dialect/Distributed/Utils.cpp new file mode 100644 index 0000000000..d347370554 --- /dev/null +++ b/src/enzyme_ad/jax/Dialect/Distributed/Utils.cpp @@ -0,0 +1,99 @@ +#include "Utils.h" +#include "Dialect.h" +namespace mlir::enzyme::distributed { + +Region *getEnclosingDeviceParallelBranch(DeviceParallelOp parent, + Operation *op) { + auto region = op->getParentRegion(); + while (region->getParentOp() != parent) { + auto region_parent = + region->getParentOp(); // All regions have parent ops... + if (!region_parent->getParentRegion()) // But not all ops have parent + // regions (e.g. top level ops) + return nullptr; + region = region_parent->getParentRegion(); + } + return region; +} + +int getDeviceParallelBranchIndex(DeviceParallelOp parent, Region *branch) { + assert(branch->getParentOp() == parent && "branch is not a region of parent"); + for (int i = 0; i < parent.getNumRegions(); i++) { + if (&parent.getRegion(i) == branch) + return i; + } + llvm_unreachable("branch not found in parent regions"); + return -1; +} + +mlir::Operation *getExecutingDevice(mlir::Operation *op) { + // Find current branch + auto parent = op->getParentOfType(); + auto branch = getEnclosingDeviceParallelBranch(parent, op); + if (!branch) + return nullptr; + // Find index of branch and cross-reference to parent device symbol + int branch_idx = getDeviceParallelBranchIndex(parent, branch); + auto device_sym = llvm::cast( + parent.getBranchAssignments()[branch_idx]); + + return SymbolTable::lookupNearestSymbolFrom(parent, device_sym); +} + +llvm::SmallVector getCorrespondingTokens(Token token) { + unsigned idx = token.asBlockArg().getArgNumber(); + auto op = token.asBlockArg().getOwner()->getParentOp(); + DeviceParallelOp parent = llvm::cast(op); + llvm::SmallVector results; + results.reserve(parent.getNumRegions()); + for (auto region : parent.getRegions()) { + results.push_back(Token(region->getArgument(idx))); + } + return results; +} + +llvm::SmallVector getTokenUsers(Token token) { + auto all_tokens = getCorrespondingTokens(token); + llvm::SmallVector results; + // Concatenate all users of all corresponding tokens. + // Due to scoping rules and since each token is a block arg to a + // different region, there should be no duplicates here. + for (auto t : all_tokens) { + for (auto user : t.asBlockArg().getUsers()) { + results.push_back(user); + } + } + return results; +} + +bool isSoleSender(TokenWriterOpInterface writer) { + auto tokens = writer.getWriteTokens(); + // Check for conflicts on all tokens + for (auto token : tokens) { + auto users = getTokenUsers(token); + if (!isSoleSender(writer, token, users)) { + return false; + } + } + return true; +} + +bool isSoleSender(TokenWriterOpInterface writer, Token token, + llvm::ArrayRef others) { + for (auto user : others) { + TypedValue as_val = token.asTypedValue(); + TokenWriterOpInterface other = dyn_cast(user); + if (other && other != writer) { + // Found another writer using the same token. Check if it uses + // the token to write, or only for something else: + auto other_write_tokens = other.getWriteTokens(); + for (auto t : other_write_tokens) { + if (t == as_val) { + return false; // Found another op writing to the same token + } + } + } + } + return true; +} +} // namespace mlir::enzyme::distributed \ No newline at end of file diff --git a/src/enzyme_ad/jax/Dialect/Distributed/Utils.h b/src/enzyme_ad/jax/Dialect/Distributed/Utils.h new file mode 100644 index 0000000000..1b402af68c --- /dev/null +++ b/src/enzyme_ad/jax/Dialect/Distributed/Utils.h @@ -0,0 +1,57 @@ +#ifndef ENZYME_AD_JAX_DIALECT_DISTRIBUTED_UTILS_H +#define ENZYME_AD_JAX_DIALECT_DISTRIBUTED_UTILS_H + +#include "Dialect.h" +#include "Traits.h" + +namespace mlir::enzyme::distributed { + +/** + * Get the enclosing device parallel branch for a given operation, or nullptr + * if the provided deviceParallelOp is not an ancestor of op. + */ +Region *getEnclosingDeviceParallelBranch(DeviceParallelOp parent, + Operation *op); + +/** + * Get the index of a device parallel branch within its parent operation. + * Parent op must be the direct parent of the branch region. + */ +int getDeviceParallelBranchIndex(DeviceParallelOp parent, Region *branch); + +/** + * Returns the defining op of the enclosing device of a given computational op + * (e.g. not the parent of a device defintion op). Returns nullptr if no such + * device can be found (not inside a device parallel region). + */ +mlir::Operation *getExecutingDevice(mlir::Operation *op); + +/** + * Returns the counterpart tokens across all branches for the provided token. + * Each token here corresponds to the same logical token, but passed as a + * different block argument to each branch. Tokens are ordered in the same order + * as the branches of the parent DeviceParallelOp. Includes token itself. + */ +llvm::SmallVector getCorrespondingTokens(Token token); + +/** + * Returns all users of the provided token or its counterpart across all + * branches, including readers, writers, and any other op that takes the token + * as an operand. + */ +llvm::SmallVector getTokenUsers(Token token); + +/** + * Returns true if no other ops ever write to any token written by the + * provided op. + */ +bool isSoleSender(TokenWriterOpInterface writer); + +/** + * Returns true if no other ops in the provided list send on the same channel. + */ +bool isSoleSender(TokenWriterOpInterface writer, Token token, + llvm::ArrayRef others); +} // namespace mlir::enzyme::distributed + +#endif // ENZYME_AD_JAX_DIALECT_DISTRIBUTED_UTILS_H \ No newline at end of file diff --git a/src/enzyme_ad/jax/Passes/Distributed/EliminateConstantCommunication.cpp b/src/enzyme_ad/jax/Passes/Distributed/EliminateConstantCommunication.cpp new file mode 100644 index 0000000000..ae2ff84a8d --- /dev/null +++ b/src/enzyme_ad/jax/Passes/Distributed/EliminateConstantCommunication.cpp @@ -0,0 +1,58 @@ +/** + * Replaces send(constant); recv(); with just constant. + */ + +#include "Passes.h" +#include "src/enzyme_ad/jax/Dialect/Distributed/Dialect.h" +#include "src/enzyme_ad/jax/Dialect/Distributed/Utils.h" +#include "stablehlo/dialect/StablehloOps.h" + +namespace mlir::enzyme::distributed { +#define GEN_PASS_DEF_ELIMINATECONSTANTCOMMUNICATIONPASS +#include "src/enzyme_ad/jax/Passes/Distributed/Passes.h.inc" + +bool isConstantOp(Operation *op) { return isa(op); } +bool isConstant(Value val) { + if (auto op = val.getDefiningOp()) { + return isConstantOp(op); + } + return false; +} + +struct EliminateConstantCommunicationPass + : public impl::EliminateConstantCommunicationPassBase< + EliminateConstantCommunicationPass> { + using EliminateConstantCommunicationPassBase:: + EliminateConstantCommunicationPassBase; + void runOnOperation() override { + Operation *op = getOperation(); + // Post-order walk is allowed to erase the sends. Less sure if we + // are permitted to erase the recvs during the walk. + op->walk([&](enzyme::distributed::SendOp send) { + if (isConstant(send.getValue())) { + // Check that we are the only sender on this channel, and get + // the corresponding recvs. + auto users = getTokenUsers(send.getToken()); + if (!isSoleSender(send, send.getToken(), users)) { + // If we're not the sole sender, we can't eliminate the communication. + return; + } + // If we are the sole sender, we can replace all recvs with a copy of + // the constant value. However, since the recv may be in a different + // scope, we need to replace it with a clone of the constant op. + for (auto user : users) { + if (auto recv = dyn_cast(user)) { + auto cloned_const = send.getValue().getDefiningOp()->clone(); + // Insert the cloned constant right before the recv + recv->getBlock()->getOperations().insert(recv->getIterator(), + cloned_const); + recv.getResult().replaceAllUsesWith(cloned_const->getResult(0)); + recv.erase(); + } + } + send.erase(); + } + }); + } +}; +} // namespace mlir::enzyme::distributed \ No newline at end of file diff --git a/src/enzyme_ad/jax/Passes/Distributed/Passes.h b/src/enzyme_ad/jax/Passes/Distributed/Passes.h new file mode 100644 index 0000000000..52ad040d0d --- /dev/null +++ b/src/enzyme_ad/jax/Passes/Distributed/Passes.h @@ -0,0 +1,16 @@ +#ifndef ENZYMEXLA_DISTRIBUTED_PASSES_H +#define ENZYMEXLA_DISTRIBUTED_PASSES_H + +#include "mlir/Pass/Pass.h" + +namespace mlir::enzyme::distributed { + +#define GEN_PASS_DECL +#include "src/enzyme_ad/jax/Passes/Distributed/Passes.h.inc" + +#define GEN_PASS_REGISTRATION +#include "src/enzyme_ad/jax/Passes/Distributed/Passes.h.inc" + +} // namespace mlir::enzyme::distributed + +#endif // ENZYMEXLA_DISTRIBUTED_PASSES_H \ No newline at end of file diff --git a/src/enzyme_ad/jax/Passes/Distributed/Passes.td b/src/enzyme_ad/jax/Passes/Distributed/Passes.td new file mode 100644 index 0000000000..430d833024 --- /dev/null +++ b/src/enzyme_ad/jax/Passes/Distributed/Passes.td @@ -0,0 +1,18 @@ +#ifndef ENZYMEXLA_DISTRIBUTED_PASSES +#define ENZYMEXLA_DISTRIBUTED_PASSES + +include "mlir/Pass/PassBase.td" + +def EliminateConstantCommunicationPass : Pass<"eliminate-constant-communication"> { + let summary = "Replaces communicated constants with local constants"; + let description = [{ + This pass identifies send instructions with constant operands and replaces + the corresponding receive instructions with local constants. + }]; + let dependentDialects = [ + "enzyme::distributed::DistributedDialect", + "stablehlo::StablehloDialect" + ]; +} + +#endif // ENZYMEXLA_DISTRIBUTED_PASSES \ No newline at end of file diff --git a/src/enzyme_ad/jax/RegistryUtils.cpp b/src/enzyme_ad/jax/RegistryUtils.cpp index 150d74914f..03a7046f97 100644 --- a/src/enzyme_ad/jax/RegistryUtils.cpp +++ b/src/enzyme_ad/jax/RegistryUtils.cpp @@ -86,6 +86,7 @@ #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" #include "src/enzyme_ad/jax/Dialect/Ops.h" +#include "src/enzyme_ad/jax/Passes/Distributed/Passes.h" #include "src/enzyme_ad/jax/Passes/Passes.h" #include "src/enzyme_ad/jax/Dialect/Distributed/Dialect.h" @@ -294,6 +295,7 @@ void registerInterfaces(mlir::DialectRegistry ®istry) { void initializePasses() { registerenzymePasses(); enzyme::registerenzymexlaPasses(); + enzyme::distributed::registerdistributedPasses(); // Register the standard passes we want. mlir::registerCSEPass(); diff --git a/test/lit_tests/distributed/eliminateconstants.mlir b/test/lit_tests/distributed/eliminateconstants.mlir new file mode 100644 index 0000000000..5e47024824 --- /dev/null +++ b/test/lit_tests/distributed/eliminateconstants.mlir @@ -0,0 +1,57 @@ +// RUN: enzymexlamlir-opt --eliminate-constant-communication %s | FileCheck %s +distributed.leaf_device @myGpu +distributed.device_mesh @gpuMesh @myGpu [2, 2] +distributed.leaf_device @myCpu +distributed.channel @chan1 [@myCpu, @gpuMesh] [@gpuMesh, @myCpu] +distributed.device_group @gpusWithHost [@myGpu, @myCpu] [@chan1] + +func.func @foo() { + distributed.device_parallel @gpusWithHost { + branch @myGpu { + ^entry(%1: !distributed.token): + distributed.device_parallel @gpuMesh { + branch @myGpu { + ^entry(): + } + } + } + branch @myCpu { + ^entry(%1: !distributed.token): + %output = stablehlo.constant() { + value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32> + } : () -> tensor<2x2xf32> + distributed.send %1 tensor<2x2xf32> %output + + } + branch @chan1 { + ^entry(%1: !distributed.token): + %input = distributed.recv %1 tensor<2x2xf32> + %sum = stablehlo.add %input, %input : tensor<2x2xf32> + } + } + + func.return +} + +//CHECK: module { +//CHECK-NEXT: distributed.leaf_device @myGpu +//CHECK-NEXT: distributed.device_mesh @gpuMesh @myGpu [2, 2] +//CHECK-NEXT: distributed.leaf_device @myCpu +//CHECK-NEXT: distributed.channel @chan1 [@myCpu, @gpuMesh] [@gpuMesh, @myCpu] +//CHECK-NEXT: distributed.device_group @gpusWithHost [@myGpu, @myCpu] [@chan1] +//CHECK-NEXT: func.func @foo() { +//CHECK-NEXT: distributed.device_parallel @gpusWithHost{ branch @myGpu{ +//CHECK-NEXT: ^bb0(%arg0: !distributed.token): +//CHECK-NEXT: distributed.device_parallel @gpuMesh{ branch @myGpu{ +//CHECK-NEXT: }} +//CHECK-NEXT: } branch @myCpu{ +//CHECK-NEXT: ^bb0(%arg0: !distributed.token): +//CHECK-NEXT{LITERAL}: %cst = stablehlo.constant dense<[[0.000000e+00, 1.000000e+00], [2.000000e+00, 3.000000e+00]]> : tensor<2x2xf32> +//CHECK-NEXT: } branch @chan1{ +//CHECK-NEXT: ^bb0(%arg0: !distributed.token): +//CHECK-NEXT{LITERAL}: %cst = stablehlo.constant dense<[[0.000000e+00, 1.000000e+00], [2.000000e+00, 3.000000e+00]]> : tensor<2x2xf32> +//CHECK-NEXT: %0 = stablehlo.add %cst, %cst : tensor<2x2xf32> +//CHECK-NEXT: }} +//CHECK-NEXT: return +//CHECK-NEXT: } +//CHECK-NEXT:} \ No newline at end of file diff --git a/test/lit_tests/distributed/roundtrip.mlir b/test/lit_tests/distributed/roundtrip.mlir index e4c54e5b2c..0d29a6c814 100644 --- a/test/lit_tests/distributed/roundtrip.mlir +++ b/test/lit_tests/distributed/roundtrip.mlir @@ -1,42 +1,44 @@ // RUN: enzymexlamlir-opt %s | FileCheck %s -distributed.LeafDevice @myGpu -distributed.DeviceMesh @gpuMesh @myGpu [2, 2] -distributed.LeafDevice @myCpu -distributed.Channel @chan1 [@myCpu, @gpuMesh] [@gpuMesh, @myCpu] -distributed.DeviceGroup @gpusWithHost [@myGpu, @myCpu] [@chan1] +distributed.leaf_device @myGpu +distributed.device_mesh @gpuMesh @myGpu [2, 2] +distributed.leaf_device @myCpu +distributed.channel @chan1 [@myCpu, @gpuMesh] [@gpuMesh, @myCpu] +distributed.device_group @gpusWithHost [@myGpu, @myCpu] [@chan1] func.func @foo() { - distributed.GroupSplit @gpusWithHost { - %tok = distributed.DefineToken @chan1 - distributed.SplitBranch @chan1 { } - distributed.SplitBranch @myCpu {} - distributed.SplitBranch @gpuMesh { - distributed.MeshFor @gpuMesh { + distributed.device_parallel @gpusWithHost { + branch @myGpu { + ^entry(): + distributed.device_parallel @gpuMesh { + branch @myGpu { + ^entry(): + } + } + } + branch @myCpu { + ^entry(): + } + branch @chan1 { + ^entry(): + } + } - } - } - } func.return } -// CHECK: module { -// CHECK-NEXT: distributed.LeafDevice @myGpu -// CHECK-NEXT: distributed.DeviceMesh @gpuMesh @myGpu [2, 2] -// CHECK-NEXT: distributed.LeafDevice @myCpu -// CHECK-NEXT: distributed.Channel @chan1 [@myCpu, @gpuMesh] [@gpuMesh, @myCpu] -// CHECK-NEXT: distributed.DeviceGroup @gpusWithHost [@myGpu, @myCpu] [@chan1] -// CHECK-NEXT: func.func @foo() { -// CHECK-NEXT: distributed.GroupSplit @gpusWithHost { -// CHECK-NEXT: %0 = distributed.DefineToken @chan1 -// CHECK-NEXT: distributed.SplitBranch @chan1 { -// CHECK-NEXT: } -// CHECK-NEXT: distributed.SplitBranch @myCpu { -// CHECK-NEXT: } -// CHECK-NEXT: distributed.SplitBranch @gpuMesh { -// CHECK-NEXT: distributed.MeshFor @gpuMesh { -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: return -// CHECK-NEXT: } -// CHECK-NEXT: } \ No newline at end of file +//CHECK: module { +//CHECK-NEXT: distributed.leaf_device @myGpu +//CHECK-NEXT: distributed.device_mesh @gpuMesh @myGpu [2, 2] +//CHECK-NEXT: distributed.leaf_device @myCpu +//CHECK-NEXT: distributed.channel @chan1 [@myCpu, @gpuMesh] [@gpuMesh, @myCpu] +//CHECK-NEXT: distributed.device_group @gpusWithHost [@myGpu, @myCpu] [@chan1] +//CHECK-NEXT: func.func @foo() { +//CHECK-NEXT: distributed.device_parallel @gpusWithHost{ branch @myGpu{ +//CHECK-NEXT: distributed.device_parallel @gpuMesh{ branch @myGpu{ +//CHECK-NEXT: }} +//CHECK-NEXT: } branch @myCpu{ +//CHECK-NEXT: } branch @chan1{ +//CHECK-NEXT: }} +//CHECK-NEXT: return +//CHECK-NEXT: } +//CHECK-NEXT:} \ No newline at end of file