From f7ced7486d57716f2d63167ed85684ecc9b1c425 Mon Sep 17 00:00:00 2001 From: Egan Johnson Date: Fri, 19 Sep 2025 13:40:53 -0500 Subject: [PATCH 1/6] Distributed dialect- change subbranches to regions --- src/enzyme_ad/jax/Dialect/Distributed/Ops.cpp | 44 +++++++++++---- src/enzyme_ad/jax/Dialect/Distributed/Ops.td | 21 +++----- test/lit_tests/distributed/roundtrip.mlir | 54 +++++++++---------- 3 files changed, 64 insertions(+), 55 deletions(-) diff --git a/src/enzyme_ad/jax/Dialect/Distributed/Ops.cpp b/src/enzyme_ad/jax/Dialect/Distributed/Ops.cpp index 21d62cecf6..5cecd29131 100644 --- a/src/enzyme_ad/jax/Dialect/Distributed/Ops.cpp +++ b/src/enzyme_ad/jax/Dialect/Distributed/Ops.cpp @@ -111,20 +111,44 @@ GroupSplitOp::verifySymbolUses(::mlir::SymbolTableCollection &symbol_table) { getDeviceGroupAttr()); } -LogicalResult -SplitBranchOp::verifySymbolUses(::mlir::SymbolTableCollection &symbol_table) { - // Split branches have programs for individual devices or channels - Operation *dev_or_chan = - symbol_table.lookupNearestSymbolFrom(*this, getDeviceOrChannelAttr()); - if (!dev_or_chan || !(dev_or_chan->hasTrait() || - dev_or_chan->hasTrait())) { - mlir::emitError(getLoc()) - << "branches must reference a valid device or channel"; - return mlir::failure(); +// Printer/parser for GroupsplitOp branches +mlir::ParseResult parseSplitBranches( + OpAsmParser &parser, mlir::ArrayAttr &branchAssignments, + llvm::SmallVector, 2> &branchesRegions) { + // Expect 0 or more `branch` $symbol_name $symbol_region + // While next token is `branch`: + llvm::SmallVector assignment_symbols; + while (parser.parseOptionalKeyword("branch").succeeded()) { + // Parse symbol name + mlir::SymbolRefAttr sym; + auto sym_parse_failed = parser.parseAttribute(sym); + if (sym_parse_failed) + return mlir::failure(); + assignment_symbols.push_back(sym); + + // Put placeholder region in list and parse into it + branchesRegions.push_back(std::make_unique()); + auto parse_region_failed = parser.parseRegion(*branchesRegions.back()); + if (parse_region_failed) + return mlir::failure(); } + + branchAssignments = mlir::ArrayAttr::get(parser.getBuilder().getContext(), + assignment_symbols); return mlir::success(); } +void printSplitBranches(OpAsmPrinter &printer, const GroupSplitOp &op, + const mlir::ArrayAttr branchAssignments, + const llvm::MutableArrayRef branches) { + // Print each branch as `branch` $symbol_name $symbol_region + for (size_t i = 0; i < branches.size(); i++) { + printer << " branch "; + printer.printAttribute(branchAssignments[i]); + printer.printRegion(branches[i]); + } +} + LogicalResult DefineTokenOp::verifySymbolUses(::mlir::SymbolTableCollection &symbol_table) { // Tokens need to indicate which channel they communicate over diff --git a/src/enzyme_ad/jax/Dialect/Distributed/Ops.td b/src/enzyme_ad/jax/Dialect/Distributed/Ops.td index 0857ce2f2c..04da1452b2 100644 --- a/src/enzyme_ad/jax/Dialect/Distributed/Ops.td +++ b/src/enzyme_ad/jax/Dialect/Distributed/Ops.td @@ -59,25 +59,16 @@ def MeshForOp : DistributedOp<"MeshFor", [DeclareOpInterfaceMethods, NoTerminator, SingleBlock]>{ +def GroupSplitOp : DistributedOp<"GroupSplit", [DeclareOpInterfaceMethods, NoTerminator]>{ let arguments = (ins - SymbolRefAttr:$device_group // TODO: verify it's a group + SymbolRefAttr:$device_group, + ArrayAttr:$branch_assignments // Symbols mapping to each branch region ); - let regions = (region SizedRegion<1>:$declarations); // Takes as args the devices and channels in the group + // TODO check that declarations only declares 
tokens. + let regions = (region VariadicRegion>:$branches); let results = (outs ); // TODO // let hasVerifier = 1; // TODO - // let hasCanonicalizer = 1; // TODO: token declarations up front, followed by device and channel branches in order of listing in the group - let assemblyFormat = "$device_group $declarations attr-dict"; -} - -def SplitBranchOp : DistributedOp<"SplitBranch", [DeclareOpInterfaceMethods, NoTerminator, SingleBlock]>{ - let arguments = (ins - SymbolRefAttr:$device_or_channel // TODO: verify it's a device or channel - ); - let regions = (region MaxSizedRegion<1>:$body); // Takes as args the device or channel - let results = (outs ); // TODO - // let hasVerifier = 1; // TODO: parent is a groupsplitop - let assemblyFormat = "$device_or_channel $body attr-dict"; + let assemblyFormat = "$device_group custom($branch_assignments, $branches) attr-dict"; } def DefineTokenOp : DistributedOp<"DefineToken", [DeclareOpInterfaceMethods]>{ diff --git a/test/lit_tests/distributed/roundtrip.mlir b/test/lit_tests/distributed/roundtrip.mlir index e4c54e5b2c..e2d1ad7f77 100644 --- a/test/lit_tests/distributed/roundtrip.mlir +++ b/test/lit_tests/distributed/roundtrip.mlir @@ -6,37 +6,31 @@ distributed.Channel @chan1 [@myCpu, @gpuMesh] [@gpuMesh, @myCpu] distributed.DeviceGroup @gpusWithHost [@myGpu, @myCpu] [@chan1] func.func @foo() { - distributed.GroupSplit @gpusWithHost { - %tok = distributed.DefineToken @chan1 - distributed.SplitBranch @chan1 { } - distributed.SplitBranch @myCpu {} - distributed.SplitBranch @gpuMesh { - distributed.MeshFor @gpuMesh { + distributed.GroupSplit @gpusWithHost + branch @myGpu { + distributed.MeshFor @gpuMesh { + } + } + branch @myCpu { + distributed.DefineToken @chan1 + } - } - } - } func.return } -// CHECK: module { -// CHECK-NEXT: distributed.LeafDevice @myGpu -// CHECK-NEXT: distributed.DeviceMesh @gpuMesh @myGpu [2, 2] -// CHECK-NEXT: distributed.LeafDevice @myCpu -// CHECK-NEXT: distributed.Channel @chan1 [@myCpu, @gpuMesh] [@gpuMesh, @myCpu] -// CHECK-NEXT: distributed.DeviceGroup @gpusWithHost [@myGpu, @myCpu] [@chan1] -// CHECK-NEXT: func.func @foo() { -// CHECK-NEXT: distributed.GroupSplit @gpusWithHost { -// CHECK-NEXT: %0 = distributed.DefineToken @chan1 -// CHECK-NEXT: distributed.SplitBranch @chan1 { -// CHECK-NEXT: } -// CHECK-NEXT: distributed.SplitBranch @myCpu { -// CHECK-NEXT: } -// CHECK-NEXT: distributed.SplitBranch @gpuMesh { -// CHECK-NEXT: distributed.MeshFor @gpuMesh { -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: return -// CHECK-NEXT: } -// CHECK-NEXT: } \ No newline at end of file +//CHECK: module { +//CHECK-NEXT: distributed.LeafDevice @myGpu +//CHECK-NEXT: distributed.DeviceMesh @gpuMesh @myGpu [2, 2] +//CHECK-NEXT: distributed.LeafDevice @myCpu +//CHECK-NEXT: distributed.Channel @chan1 [@myCpu, @gpuMesh] [@gpuMesh, @myCpu] +//CHECK-NEXT: distributed.DeviceGroup @gpusWithHost [@myGpu, @myCpu] [@chan1] +//CHECK-NEXT: func.func @foo() { +//CHECK-NEXT: distributed.GroupSplit @gpusWithHost branch @myGpu{ +//CHECK-NEXT: distributed.MeshFor @gpuMesh { +//CHECK-NEXT: } +//CHECK-NEXT: } branch @myCpu{ +//CHECK-NEXT: %0 = distributed.DefineToken @chan1 +//CHECK-NEXT: } +//CHECK-NEXT: return +//CHECK-NEXT: } +//CHECK-NEXT: } \ No newline at end of file From 2fdaaf739d76a2fa0abc7c3755fb922757e69c3c Mon Sep 17 00:00:00 2001 From: Egan Johnson Date: Mon, 22 Sep 2025 22:19:42 -0500 Subject: [PATCH 2/6] Send, Recv ops and interfaces --- src/enzyme_ad/jax/BUILD | 5 ++-- .../jax/Dialect/Distributed/Dialect.cpp | 2 ++ 
.../jax/Dialect/Distributed/Dialect.h | 10 ++++---- .../jax/Dialect/Distributed/Interfaces.td | 25 +++++++++++++++++++ src/enzyme_ad/jax/Dialect/Distributed/Ops.cpp | 14 +++++++++++ src/enzyme_ad/jax/Dialect/Distributed/Ops.td | 17 ++++++++++++- .../jax/Dialect/Distributed/Types.td | 3 ++- 7 files changed, 67 insertions(+), 9 deletions(-) diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD index 24ac998d4d..78563f073a 100644 --- a/src/enzyme_ad/jax/BUILD +++ b/src/enzyme_ad/jax/BUILD @@ -569,11 +569,11 @@ gentbl_cc_library( name = "DistributedInterfacesIncGen", tbl_outs = [ ( - ["--gen-interface-decls"], + ["--gen-op-interface-decls"], "Dialect/Distributed/DistributedInterfaces.h.inc", ), ( - ["--gen-interface-defs"], + ["--gen-op-interface-defs"], "Dialect/Distributed/DistributedInterfaces.cpp.inc", ), ], @@ -744,6 +744,7 @@ cc_library( deps = [ ":CheckedRewrite", ":DistributedDialectIncGen", + ":DistributedInterfacesIncGen", ":DistributedOpsIncGen", ":DistributedTypesIncGen", ":EnzymeHLOPatternsIncGen", diff --git a/src/enzyme_ad/jax/Dialect/Distributed/Dialect.cpp b/src/enzyme_ad/jax/Dialect/Distributed/Dialect.cpp index 343d327678..969b513e6e 100644 --- a/src/enzyme_ad/jax/Dialect/Distributed/Dialect.cpp +++ b/src/enzyme_ad/jax/Dialect/Distributed/Dialect.cpp @@ -8,6 +8,8 @@ #define GET_TYPEDEF_CLASSES #include "src/enzyme_ad/jax/Dialect/Distributed/DistributedTypes.cpp.inc" +#include "src/enzyme_ad/jax/Dialect/Distributed/DistributedInterfaces.cpp.inc" + // Initialize the dialect void mlir::enzyme::distributed::DistributedDialect::initialize() { addTypes< diff --git a/src/enzyme_ad/jax/Dialect/Distributed/Dialect.h b/src/enzyme_ad/jax/Dialect/Distributed/Dialect.h index e8277364c6..62d8fef285 100644 --- a/src/enzyme_ad/jax/Dialect/Distributed/Dialect.h +++ b/src/enzyme_ad/jax/Dialect/Distributed/Dialect.h @@ -10,14 +10,14 @@ #include "mlir/IR/SymbolTable.h" #include "mlir/IR/Types.h" -// Include the dialect -#include "src/enzyme_ad/jax/Dialect/Distributed/DistributedDialect.h.inc" -// Traits and interfaces #include "Traits.h" -// Types +#include "src/enzyme_ad/jax/Dialect/Distributed/DistributedDialect.h.inc" + #define GET_TYPEDEF_CLASSES #include "src/enzyme_ad/jax/Dialect/Distributed/DistributedTypes.h.inc" -// Operations + +#include "src/enzyme_ad/jax/Dialect/Distributed/DistributedInterfaces.h.inc" + #define GET_OP_CLASSES #include "src/enzyme_ad/jax/Dialect/Distributed/DistributedOps.h.inc" diff --git a/src/enzyme_ad/jax/Dialect/Distributed/Interfaces.td b/src/enzyme_ad/jax/Dialect/Distributed/Interfaces.td index 8fd74fd4e2..59bd82f96f 100644 --- a/src/enzyme_ad/jax/Dialect/Distributed/Interfaces.td +++ b/src/enzyme_ad/jax/Dialect/Distributed/Interfaces.td @@ -6,4 +6,29 @@ include "mlir/IR/OpBase.td" def DeviceDefTrait : NativeOpTrait<"enzyme::distributed::DeviceDefTrait">; def ChannelDefTrait : NativeOpTrait<"enzyme::distributed::ChannelDefTrait">; +def TokenReaderOpInterface : OpInterface<"TokenReaderOpInterface"> { + let cppNamespace = "::mlir::enzyme::distributed"; + let description = [{ + An interface to determine which ops can read from a channel and what type they expect. + Ops may read from multiple channels. 
+ }]; + let methods = [ + InterfaceMethod<"Returns the SSA values of tokens read from this op.", "::llvm::ArrayRef<::mlir::TypedValue<::mlir::enzyme::distributed::ReadTokenType>>", "getReadTokens">, + InterfaceMethod<"Returns the types of tokens read from this op, parallel to getReadTokens.", "::llvm::ArrayRef<::mlir::Type>", "getReadTokenTypes"> + ]; +} + +def TokenWriterOpInterface : OpInterface<"TokenWriterOpInterface"> { + let cppNamespace = "::mlir::enzyme::distributed"; + let description = [{ + An interface to determine which ops can write to a channel and what type they provide. + Ops may write to multiple channels. + }]; + let methods = [ + InterfaceMethod<"Returns the SSA values of tokens written from this op.", "::llvm::ArrayRef<::mlir::TypedValue<::mlir::enzyme::distributed::WriteTokenType>>", "getWriteTokens">, + InterfaceMethod<"Returns the types of tokens written from this op, parallel to getWriteTokens.", "::llvm::ArrayRef<::mlir::Type>", "getWriteTokenTypes"> + ]; +} + + #endif // ENZYME_AD_JAX_DIALECT_DISTRIBUTED_INTERFACES \ No newline at end of file diff --git a/src/enzyme_ad/jax/Dialect/Distributed/Ops.cpp b/src/enzyme_ad/jax/Dialect/Distributed/Ops.cpp index 5cecd29131..8aade88adc 100644 --- a/src/enzyme_ad/jax/Dialect/Distributed/Ops.cpp +++ b/src/enzyme_ad/jax/Dialect/Distributed/Ops.cpp @@ -156,6 +156,20 @@ DefineTokenOp::verifySymbolUses(::mlir::SymbolTableCollection &symbol_table) { getChannelAttr()); } +llvm::ArrayRef> SendOp::getWriteTokens() { + return llvm::SmallVector, 1>{getToken()}; +} +llvm::ArrayRef SendOp::getWriteTokenTypes() { + return llvm::SmallVector{getValue().getType()}; +} + +llvm::ArrayRef> RecvOp::getReadTokens() { + return llvm::SmallVector, 1>{getToken()}; +} +llvm::ArrayRef RecvOp::getReadTokenTypes() { + return llvm::SmallVector{getValue().getType()}; +} + } // namespace mlir::enzyme::distributed #define GET_OP_CLASSES #include "src/enzyme_ad/jax/Dialect/Distributed/DistributedOps.cpp.inc" \ No newline at end of file diff --git a/src/enzyme_ad/jax/Dialect/Distributed/Ops.td b/src/enzyme_ad/jax/Dialect/Distributed/Ops.td index 04da1452b2..c4c10b5d89 100644 --- a/src/enzyme_ad/jax/Dialect/Distributed/Ops.td +++ b/src/enzyme_ad/jax/Dialect/Distributed/Ops.td @@ -75,9 +75,24 @@ def DefineTokenOp : DistributedOp<"DefineToken", [DeclareOpInterfaceMethods]>{ + let arguments = (ins + WriteTokenType:$token, + // value to send + AnyType:$value); + let assemblyFormat = "$token type($value) $value attr-dict"; +} + +def RecvOp : DistributedOp<"Recv", [DeclareOpInterfaceMethods]>{ + let arguments = (ins + ReadTokenType:$token); + let results = (outs AnyType:$value); + let assemblyFormat = "$token type($value) attr-dict"; +} + #endif // ENZYME_AD_JAX_DIALECT_DISTRIBUTED_OPS_TD \ No newline at end of file diff --git a/src/enzyme_ad/jax/Dialect/Distributed/Types.td b/src/enzyme_ad/jax/Dialect/Distributed/Types.td index 517c63753c..1e29602490 100644 --- a/src/enzyme_ad/jax/Dialect/Distributed/Types.td +++ b/src/enzyme_ad/jax/Dialect/Distributed/Types.td @@ -3,6 +3,7 @@ include "Dialect.td" -def TokenType : DistributedType<"Token", "token">; +def ReadTokenType : DistributedType<"ReadToken", "read_token">; +def WriteTokenType : DistributedType<"WriteToken", "write_token">; #endif // ENZYME_DISTRIBUTED_DIALECT_TYPES_H \ No newline at end of file From 2d2f9e6ccf1c57dd233346d5124e3ccebb5de570 Mon Sep 17 00:00:00 2001 From: Egan Johnson Date: Tue, 23 Sep 2025 21:23:57 -0500 Subject: [PATCH 3/6] Working commit: single token type --- 
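Note: with the separate read/write token types collapsed into a single TokenType, Send and Recv now operate on the same token value. A minimal usage sketch under that assumption (the channel @chan1, the tensor value, and the SSA names are illustrative; the Send/Recv assembly formats are the ones introduced in patch 2):

    %tok = distributed.DefineToken @chan1
    distributed.Send %tok tensor<2x2xf32> %value
    %received = distributed.Recv %tok tensor<2x2xf32>
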
src/enzyme_ad/jax/Dialect/Distributed/Interfaces.td | 4 ++-- src/enzyme_ad/jax/Dialect/Distributed/Ops.cpp | 8 ++++---- src/enzyme_ad/jax/Dialect/Distributed/Ops.td | 4 ++-- src/enzyme_ad/jax/Dialect/Distributed/Types.td | 3 +-- 4 files changed, 9 insertions(+), 10 deletions(-) diff --git a/src/enzyme_ad/jax/Dialect/Distributed/Interfaces.td b/src/enzyme_ad/jax/Dialect/Distributed/Interfaces.td index 59bd82f96f..2908c5fb50 100644 --- a/src/enzyme_ad/jax/Dialect/Distributed/Interfaces.td +++ b/src/enzyme_ad/jax/Dialect/Distributed/Interfaces.td @@ -13,7 +13,7 @@ def TokenReaderOpInterface : OpInterface<"TokenReaderOpInterface"> { Ops may read from multiple channels. }]; let methods = [ - InterfaceMethod<"Returns the SSA values of tokens read from this op.", "::llvm::ArrayRef<::mlir::TypedValue<::mlir::enzyme::distributed::ReadTokenType>>", "getReadTokens">, + InterfaceMethod<"Returns the SSA values of tokens read from this op.", "::llvm::ArrayRef<::mlir::TypedValue<::mlir::enzyme::distributed::TokenType>>", "getReadTokens">, InterfaceMethod<"Returns the types of tokens read from this op, parallel to getReadTokens.", "::llvm::ArrayRef<::mlir::Type>", "getReadTokenTypes"> ]; } @@ -25,7 +25,7 @@ def TokenWriterOpInterface : OpInterface<"TokenWriterOpInterface"> { Ops may write to multiple channels. }]; let methods = [ - InterfaceMethod<"Returns the SSA values of tokens written from this op.", "::llvm::ArrayRef<::mlir::TypedValue<::mlir::enzyme::distributed::WriteTokenType>>", "getWriteTokens">, + InterfaceMethod<"Returns the SSA values of tokens written from this op.", "::llvm::ArrayRef<::mlir::TypedValue<::mlir::enzyme::distributed::TokenType>>", "getWriteTokens">, InterfaceMethod<"Returns the types of tokens written from this op, parallel to getWriteTokens.", "::llvm::ArrayRef<::mlir::Type>", "getWriteTokenTypes"> ]; } diff --git a/src/enzyme_ad/jax/Dialect/Distributed/Ops.cpp b/src/enzyme_ad/jax/Dialect/Distributed/Ops.cpp index 8aade88adc..65fd3e04c1 100644 --- a/src/enzyme_ad/jax/Dialect/Distributed/Ops.cpp +++ b/src/enzyme_ad/jax/Dialect/Distributed/Ops.cpp @@ -156,15 +156,15 @@ DefineTokenOp::verifySymbolUses(::mlir::SymbolTableCollection &symbol_table) { getChannelAttr()); } -llvm::ArrayRef> SendOp::getWriteTokens() { - return llvm::SmallVector, 1>{getToken()}; +llvm::ArrayRef> SendOp::getWriteTokens() { + return llvm::SmallVector, 1>{getToken()}; } llvm::ArrayRef SendOp::getWriteTokenTypes() { return llvm::SmallVector{getValue().getType()}; } -llvm::ArrayRef> RecvOp::getReadTokens() { - return llvm::SmallVector, 1>{getToken()}; +llvm::ArrayRef> RecvOp::getReadTokens() { + return llvm::SmallVector, 1>{getToken()}; } llvm::ArrayRef RecvOp::getReadTokenTypes() { return llvm::SmallVector{getValue().getType()}; diff --git a/src/enzyme_ad/jax/Dialect/Distributed/Ops.td b/src/enzyme_ad/jax/Dialect/Distributed/Ops.td index c4c10b5d89..6faa959cac 100644 --- a/src/enzyme_ad/jax/Dialect/Distributed/Ops.td +++ b/src/enzyme_ad/jax/Dialect/Distributed/Ops.td @@ -75,7 +75,7 @@ def DefineTokenOp : DistributedOp<"DefineToken", [DeclareOpInterfaceMethods]>{ let arguments = (ins - ReadTokenType:$token); + TokenType:$token); let results = (outs AnyType:$value); let assemblyFormat = "$token type($value) attr-dict"; } diff --git a/src/enzyme_ad/jax/Dialect/Distributed/Types.td b/src/enzyme_ad/jax/Dialect/Distributed/Types.td index 1e29602490..517c63753c 100644 --- a/src/enzyme_ad/jax/Dialect/Distributed/Types.td +++ b/src/enzyme_ad/jax/Dialect/Distributed/Types.td @@ -3,7 +3,6 @@ include "Dialect.td" -def 
ReadTokenType : DistributedType<"ReadToken", "read_token">; -def WriteTokenType : DistributedType<"WriteToken", "write_token">; +def TokenType : DistributedType<"Token", "token">; #endif // ENZYME_DISTRIBUTED_DIALECT_TYPES_H \ No newline at end of file From f4bb2c69d32d647d14e3da8ee088fd7f64be48a5 Mon Sep 17 00:00:00 2001 From: Egan Johnson Date: Thu, 25 Sep 2025 14:15:02 -0500 Subject: [PATCH 4/6] Single device parallel op with block arg tokens --- src/enzyme_ad/jax/Dialect/Distributed/Ops.cpp | 75 +++++++++++++------ src/enzyme_ad/jax/Dialect/Distributed/Ops.td | 75 +++++++++++-------- .../jax/Dialect/Distributed/Utils.cpp | 62 +++++++++++++++ src/enzyme_ad/jax/Dialect/Distributed/Utils.h | 36 +++++++++ test/lit_tests/distributed/roundtrip.mlir | 29 ++++--- 5 files changed, 216 insertions(+), 61 deletions(-) create mode 100644 src/enzyme_ad/jax/Dialect/Distributed/Utils.cpp create mode 100644 src/enzyme_ad/jax/Dialect/Distributed/Utils.h diff --git a/src/enzyme_ad/jax/Dialect/Distributed/Ops.cpp b/src/enzyme_ad/jax/Dialect/Distributed/Ops.cpp index 65fd3e04c1..894f989315 100644 --- a/src/enzyme_ad/jax/Dialect/Distributed/Ops.cpp +++ b/src/enzyme_ad/jax/Dialect/Distributed/Ops.cpp @@ -2,6 +2,7 @@ #include "llvm/ADT/TypeSwitch.h" #include "Dialect.h" +#include "Utils.h" using mlir::OpTrait::enzyme::distributed::ChannelDefTrait; using mlir::OpTrait::enzyme::distributed::DeviceDefTrait; @@ -98,21 +99,60 @@ DeviceMeshOp::verifySymbolUses(::mlir::SymbolTableCollection &symbol_table) { getDeviceType()); } -LogicalResult -MeshForOp::verifySymbolUses(::mlir::SymbolTableCollection &symbol_table) { - // Mesh for ops apply only to meshes - return checkSymbolIsA(symbol_table, *this, getMeshAttr()); +Operation *DeviceParallelOp::getEnclosingDeviceOp() { + return mlir::SymbolTable::lookupNearestSymbolFrom(*this, + getEnclosingDeviceAttr()); } -LogicalResult -GroupSplitOp::verifySymbolUses(::mlir::SymbolTableCollection &symbol_table) { - // Group splits apply only to device groups - return checkSymbolIsA(symbol_table, *this, - getDeviceGroupAttr()); +LogicalResult DeviceParallelOp::verifySymbolUses( + ::mlir::SymbolTableCollection &symbol_table) { + Operation *device_op = this->getEnclosingDeviceOp(); + if (isa(device_op) || isa(device_op)) { + return mlir::success(); + } + return emitOpError() + << "enclosing device symbol must refer to a device group or mesh"; +} + +LogicalResult DeviceParallelOp::verify() { + // Check number of branches matches number of assignments + + if (getNumRegions() != getBranchAssignments().size()) { + return emitOpError() + << "number of regions must match number of branch assignments"; + } + + // Look at device type to determine number of branches + auto device_op = mlir::SymbolTable::lookupNearestSymbolFrom( + *this, getEnclosingDeviceAttr()); + if (!device_op) { + return emitOpError() << "could not find enclosing device symbol"; + } + + if (DeviceGroupOp deviceGroup = dyn_cast(device_op)) { + // Device group: number of branches must match number of devices in group + auto devices = deviceGroup.getDevices(); + auto channels = deviceGroup.getChannels(); + if (getNumRegions() != devices.size() + channels.size()) { + return emitOpError() << "number of regions must match number of devices " + "and channels in device group"; + } + } else if (DeviceMeshOp mesh = dyn_cast(device_op)) { + // Exactly one branch for the mesth type + if (getNumRegions() != 1) { + return emitOpError() + << "device mesh must have exactly one region for its single type"; + } + } else { + return 
emitOpError() + << "enclosing device symbol must refer to a device group or mesh"; + } + + return mlir::success(); } -// Printer/parser for GroupsplitOp branches -mlir::ParseResult parseSplitBranches( +// Printer/parser for subdevice branches +mlir::ParseResult parseDeviceBranches( OpAsmParser &parser, mlir::ArrayAttr &branchAssignments, llvm::SmallVector, 2> &branchesRegions) { // Expect 0 or more `branch` $symbol_name $symbol_region @@ -138,9 +178,9 @@ mlir::ParseResult parseSplitBranches( return mlir::success(); } -void printSplitBranches(OpAsmPrinter &printer, const GroupSplitOp &op, - const mlir::ArrayAttr branchAssignments, - const llvm::MutableArrayRef branches) { +void printDeviceBranches(OpAsmPrinter &printer, const DeviceParallelOp &op, + const mlir::ArrayAttr branchAssignments, + const llvm::MutableArrayRef branches) { // Print each branch as `branch` $symbol_name $symbol_region for (size_t i = 0; i < branches.size(); i++) { printer << " branch "; @@ -149,13 +189,6 @@ void printSplitBranches(OpAsmPrinter &printer, const GroupSplitOp &op, } } -LogicalResult -DefineTokenOp::verifySymbolUses(::mlir::SymbolTableCollection &symbol_table) { - // Tokens need to indicate which channel they communicate over - return checkSymbolHasTrait(symbol_table, *this, - getChannelAttr()); -} - llvm::ArrayRef> SendOp::getWriteTokens() { return llvm::SmallVector, 1>{getToken()}; } diff --git a/src/enzyme_ad/jax/Dialect/Distributed/Ops.td b/src/enzyme_ad/jax/Dialect/Distributed/Ops.td index 6faa959cac..57bfb214b8 100644 --- a/src/enzyme_ad/jax/Dialect/Distributed/Ops.td +++ b/src/enzyme_ad/jax/Dialect/Distributed/Ops.td @@ -10,7 +10,7 @@ include "Interfaces.td" // Device definition ops -def ChannelOp : DistributedOp<"Channel", [Symbol, ChannelDefTrait, DeclareOpInterfaceMethods]>{ +def ChannelOp : DistributedOp<"channel", [Symbol, ChannelDefTrait, DeclareOpInterfaceMethods]>{ let arguments = (ins SymbolNameAttr:$sym_name, // a variadic list of devices connected by this channel @@ -21,7 +21,7 @@ def ChannelOp : DistributedOp<"Channel", [Symbol, ChannelDefTrait, DeclareOpInte let assemblyFormat = "$sym_name $sending_devices $receiving_devices attr-dict"; } -def LeafDeviceOp : DistributedOp<"LeafDevice", [Symbol, DeviceDefTrait]>{ +def LeafDeviceOp : DistributedOp<"leaf_device", [Symbol, DeviceDefTrait]>{ let arguments = (ins SymbolNameAttr:$sym_name // TODO: device type, e.g. 
TPU, GPU, CPU, and other attributes @@ -29,7 +29,7 @@ def LeafDeviceOp : DistributedOp<"LeafDevice", [Symbol, DeviceDefTrait]>{ let assemblyFormat = "$sym_name attr-dict"; } -def DeviceGroupOp : DistributedOp<"DeviceGroup", [Symbol, DeviceDefTrait, DeclareOpInterfaceMethods]>{ +def DeviceGroupOp : DistributedOp<"device_group", [Symbol, DeviceDefTrait, DeclareOpInterfaceMethods]>{ let arguments = (ins SymbolNameAttr:$sym_name, // a variadic list of devices in the group @@ -39,7 +39,7 @@ def DeviceGroupOp : DistributedOp<"DeviceGroup", [Symbol, DeviceDefTrait, Declar ); let assemblyFormat = "$sym_name $devices $channels attr-dict"; } -def DeviceMeshOp : DistributedOp<"DeviceMesh", [Symbol, DeviceDefTrait, DeclareOpInterfaceMethods]>{ +def DeviceMeshOp : DistributedOp<"device_mesh", [Symbol, DeviceDefTrait, DeclareOpInterfaceMethods]>{ let arguments = (ins SymbolNameAttr:$sym_name, SymbolRefAttr:$device_type, @@ -49,50 +49,63 @@ def DeviceMeshOp : DistributedOp<"DeviceMesh", [Symbol, DeviceDefTrait, DeclareO let assemblyFormat = "$sym_name $device_type $shape attr-dict"; } -// Ops for breaking down computation across the device hierarchy +// def ContinueOp : DistributedOp<"continue", [Terminator]> { +// let description = [{ +// A terminator for DeviceParallelOp regions. Takes as arguments the tokens to be passed to the +// continuation of the DeviceParallelOp. These values can then be used in a subsequent DeviceParallelOp +// that is a sibling to the original DeviceParallelOp by referencing the returned tokens. +// }]; +// let arguments = (ins Variadic:$operands); +// let results = (outs ); // No outputs for terminators, the token is output by the parent DeviceParallelOp. +// let assemblyFormat = "$operands type($operands) attr-dict"; +// } -def MeshForOp : DistributedOp<"MeshFor", [DeclareOpInterfaceMethods, NoTerminator, SingleBlock]>{ - let arguments = (ins SymbolRefAttr:$mesh); // TODO: verify it's a mesh - let regions = (region MaxSizedRegion<1>:$body); // TODO: body's block args are device type and mesh index - let results = (outs ); // TODO - // let hasVerifier = 1; // TODO: verify body's block args take mesh index - let assemblyFormat = "$mesh $body attr-dict"; -} +def DeviceParallelOp : DistributedOp<"device_parallel", [DeclareOpInterfaceMethods, NoTerminator]>{ + let description = [{ + An op for mapping computations to subdevices. Serves both for homogenous device meshes as well + as explicitly enumerated device groups. In the case of device meshes, this op should contain + a single region to be executed in parallel on each device. In the case of device groups, this + op should contain one region per device and channel in the group. -def GroupSplitOp : DistributedOp<"GroupSplit", [DeclareOpInterfaceMethods, NoTerminator]>{ + In either case, regions must take as argument one device index within the parent device followed + by a number of token arguments. Tokens are matched by positionally between different branches, + and all branches must have the same number and type of token arguments (though they may be unused). + }]; + let arguments = (ins - SymbolRefAttr:$device_group, - ArrayAttr:$branch_assignments // Symbols mapping to each branch region + SymbolRefAttr:$enclosing_device, + ArrayAttr:$branch_assignments // the device components for each region (device-specific branch) ); - // TODO check that declarations only declares tokens. 
let regions = (region VariadicRegion>:$branches); - let results = (outs ); // TODO - // let hasVerifier = 1; // TODO - let assemblyFormat = "$device_group custom($branch_assignments, $branches) attr-dict"; -} - -def DefineTokenOp : DistributedOp<"DefineToken", [DeclareOpInterfaceMethods]>{ - let arguments = (ins - SymbolRefAttr:$channel - ); - let results = (outs TokenType:$token); - // let hasVerifier = 1; // TODO: verify writers and readers are connected to the channel - let assemblyFormat = "$channel attr-dict"; + // let results = (outs Variadic:$continuation_tokens); + let results = (outs ); + let hasVerifier = 1; // TODO + let assemblyFormat = "$enclosing_device `{` custom($branch_assignments, $branches) `}` attr-dict"; + let extraClassDeclaration = [{ + Operation* getEnclosingDeviceOp(); + }]; } -def SendOp : DistributedOp<"Send", [DeclareOpInterfaceMethods]>{ +def SendOp : DistributedOp<"send", [DeclareOpInterfaceMethods]>{ let arguments = (ins - WriteTokenType:$token, + TokenType:$token, // value to send AnyType:$value); let assemblyFormat = "$token type($value) $value attr-dict"; } -def RecvOp : DistributedOp<"Recv", [DeclareOpInterfaceMethods]>{ +def RecvOp : DistributedOp<"recv", [DeclareOpInterfaceMethods]>{ let arguments = (ins TokenType:$token); let results = (outs AnyType:$value); let assemblyFormat = "$token type($value) attr-dict"; } +def NoopOp : DistributedOp<"noop", []>{ + let description = [{ + A placeholder no-op. + }]; + let assemblyFormat = "attr-dict"; +} + #endif // ENZYME_AD_JAX_DIALECT_DISTRIBUTED_OPS_TD \ No newline at end of file diff --git a/src/enzyme_ad/jax/Dialect/Distributed/Utils.cpp b/src/enzyme_ad/jax/Dialect/Distributed/Utils.cpp new file mode 100644 index 0000000000..48efa39509 --- /dev/null +++ b/src/enzyme_ad/jax/Dialect/Distributed/Utils.cpp @@ -0,0 +1,62 @@ +#include "Utils.h" +namespace mlir::enzyme::distributed { +Region *getEnclosingDeviceParallelBranch(DeviceParallelOp parent, + Operation *op) { + auto region = op->getParentRegion(); + while (region->getParentOp() != parent) { + auto region_parent = + region->getParentOp(); // All regoins have parent ops... + if (!region_parent->getParentRegion()) // But not all ops have parent + // regions (e.g. 
top level ops) + return nullptr; + region = region_parent->getParentRegion(); + } + return region; +} + +int getDeviceParallelBranchIndex(DeviceParallelOp parent, Region *branch) { + assert(branch->getParentOp() == parent && "branch is not a region of parent"); + for (int i = 0; i < parent.getNumRegions(); i++) { + if (&parent.getRegion(i) == branch) + return i; + } + llvm_unreachable("branch not found in parent regions"); + return -1; +} + +mlir::Operation *getExecutingDevice(mlir::Operation *op) { + // Find current branch + auto parent = op->getParentOfType(); + auto branch = getEnclosingDeviceParallelBranch(parent, op); + if (!branch) + return nullptr; + // Find index of branch and cross-reference to parent device symbol + int branch_idx = getDeviceParallelBranchIndex(parent, branch); + auto device_sym = llvm::cast( + parent.getBranchAssignments()[branch_idx]); + + return SymbolTable::lookupNearestSymbolFrom(parent, device_sym); +} + +llvm::SmallVector +getCorrespondingTokens(mlir::BlockArgument token) { + unsigned idx = token.getArgNumber(); + auto op = token.getOwner()->getParentOp(); + DeviceParallelOp parent = llvm::cast(op); + llvm::SmallVector results; + results.reserve(parent.getNumRegions()); + for (auto region : parent.getRegions()) { + results.push_back(region->getArgument(idx)); + } + return results; +} + +llvm::SmallVector getTokenUsers(mlir::BlockArgument token) { + llvm::SmallVector results; + for (auto user : token.getUsers()) { + results.push_back(user); + } + return results; +} + +} // namespace mlir::enzyme::distributed \ No newline at end of file diff --git a/src/enzyme_ad/jax/Dialect/Distributed/Utils.h b/src/enzyme_ad/jax/Dialect/Distributed/Utils.h new file mode 100644 index 0000000000..659a34bbed --- /dev/null +++ b/src/enzyme_ad/jax/Dialect/Distributed/Utils.h @@ -0,0 +1,36 @@ +#ifndef ENZYME_AD_JAX_DIALECT_DISTRIBUTED_UTILS_H +#define ENZYME_AD_JAX_DIALECT_DISTRIBUTED_UTILS_H + +#include "Dialect.h" + +namespace mlir::enzyme::distributed { + +/** Get the enclosing device parallel branch for a given operation, or nullptr + * if the provided deviceParallelOp is not an ancestor of op. + */ +Region *getEnclosingDeviceParallelBranch(DeviceParallelOp parent, + Operation *op); + +/** Get the index of a device parallel branch within its parent operation. + * Parent op must be the direct parent of the branch region. + */ +int getDeviceParallelBranchIndex(DeviceParallelOp parent, Region *branch); + +/** + * Returns the defining op of the enclosing device of a given computational op + * (e.g. not the parent of a device defintion op). Returns nullptr if no such + * device can be found (not inside a device parallel region). + */ +mlir::Operation *getExecutingDevice(mlir::Operation *op); + +/** + * Returns all block arguments in the same device parallel region corresponding + * to the provided token, including the provided token itself. Will be provided + * in the same order as the branch assignments of the parent device parallel op. 
+ */ +llvm::SmallVector +getCorrespondingTokens(mlir::BlockArgument token); +llvm::SmallVector getTokenUsers(mlir::BlockArgument token); +} // namespace mlir::enzyme::distributed + +#endif // ENZYME_AD_JAX_DIALECT_DISTRIBUTED_UTILS_H \ No newline at end of file diff --git a/test/lit_tests/distributed/roundtrip.mlir b/test/lit_tests/distributed/roundtrip.mlir index e2d1ad7f77..99ebb0d0b2 100644 --- a/test/lit_tests/distributed/roundtrip.mlir +++ b/test/lit_tests/distributed/roundtrip.mlir @@ -1,19 +1,30 @@ // RUN: enzymexlamlir-opt %s | FileCheck %s -distributed.LeafDevice @myGpu -distributed.DeviceMesh @gpuMesh @myGpu [2, 2] -distributed.LeafDevice @myCpu -distributed.Channel @chan1 [@myCpu, @gpuMesh] [@gpuMesh, @myCpu] -distributed.DeviceGroup @gpusWithHost [@myGpu, @myCpu] [@chan1] +distributed.leaf_device @myGpu +distributed.device_mesh @gpuMesh @myGpu [2, 2] +distributed.leaf_device @myCpu +distributed.channel @chan1 [@myCpu, @gpuMesh] [@gpuMesh, @myCpu] +distributed.device_group @gpusWithHost [@myGpu, @myCpu] [@chan1] func.func @foo() { - distributed.GroupSplit @gpusWithHost + distributed.device_parallel @gpusWithHost { branch @myGpu { - distributed.MeshFor @gpuMesh { + ^entry(): + distributed.device_parallel @gpuMesh { + branch @myGpu { + ^entry(): + distributed.noop + } + } } - } branch @myCpu { - distributed.DefineToken @chan1 + ^entry(): + distributed.noop + } + branch @chan1 { + ^entry(): + distributed.noop } + } func.return } From d3ed39c07f343f66cfb40ed676480e2fad73c59d Mon Sep 17 00:00:00 2001 From: Egan Johnson Date: Thu, 25 Sep 2025 15:41:39 -0500 Subject: [PATCH 5/6] Distributed passes boilerplate --- src/enzyme_ad/jax/BUILD | 28 +++++++++++++++++++ src/enzyme_ad/jax/Passes/Distributed/Passes.h | 13 +++++++++ .../jax/Passes/Distributed/Passes.td | 15 ++++++++++ src/enzyme_ad/jax/RegistryUtils.cpp | 1 + 4 files changed, 57 insertions(+) create mode 100644 src/enzyme_ad/jax/Passes/Distributed/Passes.h create mode 100644 src/enzyme_ad/jax/Passes/Distributed/Passes.td diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD index 78563f073a..70a02a2c43 100644 --- a/src/enzyme_ad/jax/BUILD +++ b/src/enzyme_ad/jax/BUILD @@ -584,6 +584,31 @@ gentbl_cc_library( ], ) +td_library( + name = "DistributedPassesTdFiles", + srcs = [ + ], + deps = [ + "@llvm-project//mlir:PassBaseTdFiles", + ], +) + +gentbl_cc_library( + name = "DistributedPassesIncGen", + tbl_outs = [ + ( + [ + "-gen-pass-decls", + "-name=distributed", + ], + "Passes/Distributed/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "Passes/Distributed/Passes.td", + deps = [":DistributedPassesTdFiles"], +) + td_library( name = "TesseraDialectTdFiles", srcs = [ @@ -717,6 +742,7 @@ cc_library( srcs = glob([ "Implementations/*.cpp", "Passes/*.cpp", + "Passes/Distributed/*.cpp", "Dialect/*.cpp", "Dialect/Distributed/*.cpp", "Dialect/Tessera/*.cpp", @@ -726,6 +752,7 @@ cc_library( hdrs = glob([ "Implementations/*.h", "Passes/*.h", + "Passes/Distributed/*.h", "Dialect/*.h", "Dialect/Distributed/*.h", "Dialect/Tessera/*.h", @@ -746,6 +773,7 @@ cc_library( ":DistributedDialectIncGen", ":DistributedInterfacesIncGen", ":DistributedOpsIncGen", + ":DistributedPassesIncGen", ":DistributedTypesIncGen", ":EnzymeHLOPatternsIncGen", ":EnzymeXLAAttrsIncGen", diff --git a/src/enzyme_ad/jax/Passes/Distributed/Passes.h b/src/enzyme_ad/jax/Passes/Distributed/Passes.h new file mode 100644 index 0000000000..9ecc9b6c4e --- /dev/null +++ b/src/enzyme_ad/jax/Passes/Distributed/Passes.h @@ -0,0 +1,13 @@ +#ifndef 
ENZYMEXLA_DISTRIBUTED_PASSES_H +#define ENZYMEXLA_DISTRIBUTED_PASSES_H + +namespace mlir::enzyme::distributed { + +#define GEN_PASS_DECL +#include "src/enzyme_ad/jax/Passes/Distributed/Passes.h.inc" +#define GEN_PASS_REGISTRATION +#include "src/enzyme_ad/jax/Passes/Distributed/Passes.h.inc" + +} // namespace mlir::enzyme::distributed + +#endif // ENZYMEXLA_DISTRIBUTED_PASSES_H \ No newline at end of file diff --git a/src/enzyme_ad/jax/Passes/Distributed/Passes.td b/src/enzyme_ad/jax/Passes/Distributed/Passes.td new file mode 100644 index 0000000000..762a2e2a2f --- /dev/null +++ b/src/enzyme_ad/jax/Passes/Distributed/Passes.td @@ -0,0 +1,15 @@ +#ifndef ENZYMEXLA_DISTRIBUTED_PASSES +#define ENZYMEXLA_DISTRIBUTED_PASSES + +include "mlir/Pass/PassBase.td" + +def EliminateConstantCommunicationPass : Pass<"eliminate-constant-communication"> { + let summary = "Replaces communicated constants with local constants"; + let description = [{ + This pass identifies send instructions with constant operands and replaces + the corresponding receive instructions with local constants. + }]; + let dependentDialects = ["enzyme::distributed::DistributedDialect"]; +} + +#endif // ENZYMEXLA_DISTRIBUTED_PASSES \ No newline at end of file diff --git a/src/enzyme_ad/jax/RegistryUtils.cpp b/src/enzyme_ad/jax/RegistryUtils.cpp index 150d74914f..9a927ffad5 100644 --- a/src/enzyme_ad/jax/RegistryUtils.cpp +++ b/src/enzyme_ad/jax/RegistryUtils.cpp @@ -86,6 +86,7 @@ #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" #include "src/enzyme_ad/jax/Dialect/Ops.h" +#include "src/enzyme_ad/jax/Passes/Distributed/Passes.h" #include "src/enzyme_ad/jax/Passes/Passes.h" #include "src/enzyme_ad/jax/Dialect/Distributed/Dialect.h" From 9eb63bc1ab2d43d20b934dfe27a2e2de834093ad Mon Sep 17 00:00:00 2001 From: Egan Johnson Date: Mon, 6 Oct 2025 22:17:02 -0500 Subject: [PATCH 6/6] Demo pass eliminate constant communication --- .../jax/Dialect/Distributed/Dialect.h | 25 ++++++++ src/enzyme_ad/jax/Dialect/Distributed/Ops.td | 7 --- .../jax/Dialect/Distributed/Utils.cpp | 59 +++++++++++++++---- src/enzyme_ad/jax/Dialect/Distributed/Utils.h | 37 +++++++++--- .../EliminateConstantCommunication.cpp | 58 ++++++++++++++++++ src/enzyme_ad/jax/Passes/Distributed/Passes.h | 3 + .../jax/Passes/Distributed/Passes.td | 5 +- src/enzyme_ad/jax/RegistryUtils.cpp | 1 + .../distributed/eliminateconstants.mlir | 57 ++++++++++++++++++ test/lit_tests/distributed/roundtrip.mlir | 33 +++++------ 10 files changed, 240 insertions(+), 45 deletions(-) create mode 100644 src/enzyme_ad/jax/Passes/Distributed/EliminateConstantCommunication.cpp create mode 100644 test/lit_tests/distributed/eliminateconstants.mlir diff --git a/src/enzyme_ad/jax/Dialect/Distributed/Dialect.h b/src/enzyme_ad/jax/Dialect/Distributed/Dialect.h index 62d8fef285..e432c4c334 100644 --- a/src/enzyme_ad/jax/Dialect/Distributed/Dialect.h +++ b/src/enzyme_ad/jax/Dialect/Distributed/Dialect.h @@ -21,4 +21,29 @@ #define GET_OP_CLASSES #include "src/enzyme_ad/jax/Dialect/Distributed/DistributedOps.h.inc" +/** + * Convenience class to manage tokens, which are sometimes used as + * block args and other time as typed values. 
+ */ +namespace mlir::enzyme::distributed { +class Token { + mlir::TypedValue typedValue; + mlir::BlockArgument blockArg; + +public: + Token(mlir::BlockArgument arg) : blockArg(arg) { + typedValue = dyn_cast>(arg); + assert(typedValue && "Block arg is not a token"); + } + Token(mlir::TypedValue val) : typedValue(val) { + assert(val && "Typed value is null"); + blockArg = dyn_cast(val); + assert(blockArg && "Typed value is not a block argument"); + } + + const mlir::TypedValue asTypedValue() const { return typedValue; } + const mlir::BlockArgument asBlockArg() const { return blockArg; } +}; +} // namespace mlir::enzyme::distributed + #endif // ENZYME_AD_JAX_DIALECT_DISTRIBUTED_DIALECT_H diff --git a/src/enzyme_ad/jax/Dialect/Distributed/Ops.td b/src/enzyme_ad/jax/Dialect/Distributed/Ops.td index 57bfb214b8..b330d53813 100644 --- a/src/enzyme_ad/jax/Dialect/Distributed/Ops.td +++ b/src/enzyme_ad/jax/Dialect/Distributed/Ops.td @@ -101,11 +101,4 @@ def RecvOp : DistributedOp<"recv", [DeclareOpInterfaceMethods{ - let description = [{ - A placeholder no-op. - }]; - let assemblyFormat = "attr-dict"; -} - #endif // ENZYME_AD_JAX_DIALECT_DISTRIBUTED_OPS_TD \ No newline at end of file diff --git a/src/enzyme_ad/jax/Dialect/Distributed/Utils.cpp b/src/enzyme_ad/jax/Dialect/Distributed/Utils.cpp index 48efa39509..d347370554 100644 --- a/src/enzyme_ad/jax/Dialect/Distributed/Utils.cpp +++ b/src/enzyme_ad/jax/Dialect/Distributed/Utils.cpp @@ -1,11 +1,13 @@ #include "Utils.h" +#include "Dialect.h" namespace mlir::enzyme::distributed { + Region *getEnclosingDeviceParallelBranch(DeviceParallelOp parent, Operation *op) { auto region = op->getParentRegion(); while (region->getParentOp() != parent) { auto region_parent = - region->getParentOp(); // All regoins have parent ops... + region->getParentOp(); // All regions have parent ops... if (!region_parent->getParentRegion()) // But not all ops have parent // regions (e.g. top level ops) return nullptr; @@ -38,25 +40,60 @@ mlir::Operation *getExecutingDevice(mlir::Operation *op) { return SymbolTable::lookupNearestSymbolFrom(parent, device_sym); } -llvm::SmallVector -getCorrespondingTokens(mlir::BlockArgument token) { - unsigned idx = token.getArgNumber(); - auto op = token.getOwner()->getParentOp(); +llvm::SmallVector getCorrespondingTokens(Token token) { + unsigned idx = token.asBlockArg().getArgNumber(); + auto op = token.asBlockArg().getOwner()->getParentOp(); DeviceParallelOp parent = llvm::cast(op); - llvm::SmallVector results; + llvm::SmallVector results; results.reserve(parent.getNumRegions()); for (auto region : parent.getRegions()) { - results.push_back(region->getArgument(idx)); + results.push_back(Token(region->getArgument(idx))); } return results; } -llvm::SmallVector getTokenUsers(mlir::BlockArgument token) { - llvm::SmallVector results; - for (auto user : token.getUsers()) { - results.push_back(user); +llvm::SmallVector getTokenUsers(Token token) { + auto all_tokens = getCorrespondingTokens(token); + llvm::SmallVector results; + // Concatenate all users of all corresponding tokens. + // Due to scoping rules and since each token is a block arg to a + // different region, there should be no duplicates here. 
+ for (auto t : all_tokens) { + for (auto user : t.asBlockArg().getUsers()) { + results.push_back(user); + } } return results; } +bool isSoleSender(TokenWriterOpInterface writer) { + auto tokens = writer.getWriteTokens(); + // Check for conflicts on all tokens + for (auto token : tokens) { + auto users = getTokenUsers(token); + if (!isSoleSender(writer, token, users)) { + return false; + } + } + return true; +} + +bool isSoleSender(TokenWriterOpInterface writer, Token token, + llvm::ArrayRef others) { + for (auto user : others) { + TypedValue as_val = token.asTypedValue(); + TokenWriterOpInterface other = dyn_cast(user); + if (other && other != writer) { + // Found another writer using the same token. Check if it uses + // the token to write, or only for something else: + auto other_write_tokens = other.getWriteTokens(); + for (auto t : other_write_tokens) { + if (t == as_val) { + return false; // Found another op writing to the same token + } + } + } + } + return true; +} } // namespace mlir::enzyme::distributed \ No newline at end of file diff --git a/src/enzyme_ad/jax/Dialect/Distributed/Utils.h b/src/enzyme_ad/jax/Dialect/Distributed/Utils.h index 659a34bbed..1b402af68c 100644 --- a/src/enzyme_ad/jax/Dialect/Distributed/Utils.h +++ b/src/enzyme_ad/jax/Dialect/Distributed/Utils.h @@ -2,16 +2,19 @@ #define ENZYME_AD_JAX_DIALECT_DISTRIBUTED_UTILS_H #include "Dialect.h" +#include "Traits.h" namespace mlir::enzyme::distributed { -/** Get the enclosing device parallel branch for a given operation, or nullptr +/** + * Get the enclosing device parallel branch for a given operation, or nullptr * if the provided deviceParallelOp is not an ancestor of op. */ Region *getEnclosingDeviceParallelBranch(DeviceParallelOp parent, Operation *op); -/** Get the index of a device parallel branch within its parent operation. +/** + * Get the index of a device parallel branch within its parent operation. * Parent op must be the direct parent of the branch region. */ int getDeviceParallelBranchIndex(DeviceParallelOp parent, Region *branch); @@ -24,13 +27,31 @@ int getDeviceParallelBranchIndex(DeviceParallelOp parent, Region *branch); mlir::Operation *getExecutingDevice(mlir::Operation *op); /** - * Returns all block arguments in the same device parallel region corresponding - * to the provided token, including the provided token itself. Will be provided - * in the same order as the branch assignments of the parent device parallel op. + * Returns the counterpart tokens across all branches for the provided token. + * Each token here corresponds to the same logical token, but passed as a + * different block argument to each branch. Tokens are ordered in the same order + * as the branches of the parent DeviceParallelOp. Includes token itself. + */ +llvm::SmallVector getCorrespondingTokens(Token token); + +/** + * Returns all users of the provided token or its counterpart across all + * branches, including readers, writers, and any other op that takes the token + * as an operand. + */ +llvm::SmallVector getTokenUsers(Token token); + +/** + * Returns true if no other ops ever write to any token written by the + * provided op. + */ +bool isSoleSender(TokenWriterOpInterface writer); + +/** + * Returns true if no other ops in the provided list send on the same channel. 
*/ -llvm::SmallVector -getCorrespondingTokens(mlir::BlockArgument token); -llvm::SmallVector getTokenUsers(mlir::BlockArgument token); +bool isSoleSender(TokenWriterOpInterface writer, Token token, + llvm::ArrayRef others); } // namespace mlir::enzyme::distributed #endif // ENZYME_AD_JAX_DIALECT_DISTRIBUTED_UTILS_H \ No newline at end of file diff --git a/src/enzyme_ad/jax/Passes/Distributed/EliminateConstantCommunication.cpp b/src/enzyme_ad/jax/Passes/Distributed/EliminateConstantCommunication.cpp new file mode 100644 index 0000000000..ae2ff84a8d --- /dev/null +++ b/src/enzyme_ad/jax/Passes/Distributed/EliminateConstantCommunication.cpp @@ -0,0 +1,58 @@ +/** + * Replaces send(constant); recv(); with just constant. + */ + +#include "Passes.h" +#include "src/enzyme_ad/jax/Dialect/Distributed/Dialect.h" +#include "src/enzyme_ad/jax/Dialect/Distributed/Utils.h" +#include "stablehlo/dialect/StablehloOps.h" + +namespace mlir::enzyme::distributed { +#define GEN_PASS_DEF_ELIMINATECONSTANTCOMMUNICATIONPASS +#include "src/enzyme_ad/jax/Passes/Distributed/Passes.h.inc" + +bool isConstantOp(Operation *op) { return isa(op); } +bool isConstant(Value val) { + if (auto op = val.getDefiningOp()) { + return isConstantOp(op); + } + return false; +} + +struct EliminateConstantCommunicationPass + : public impl::EliminateConstantCommunicationPassBase< + EliminateConstantCommunicationPass> { + using EliminateConstantCommunicationPassBase:: + EliminateConstantCommunicationPassBase; + void runOnOperation() override { + Operation *op = getOperation(); + // Post-order walk is allowed to erase the sends. Less sure if we + // are permitted to erase the recvs during the walk. + op->walk([&](enzyme::distributed::SendOp send) { + if (isConstant(send.getValue())) { + // Check that we are the only sender on this channel, and get + // the corresponding recvs. + auto users = getTokenUsers(send.getToken()); + if (!isSoleSender(send, send.getToken(), users)) { + // If we're not the sole sender, we can't eliminate the communication. + return; + } + // If we are the sole sender, we can replace all recvs with a copy of + // the constant value. However, since the recv may be in a different + // scope, we need to replace it with a clone of the constant op. 
+ for (auto user : users) { + if (auto recv = dyn_cast(user)) { + auto cloned_const = send.getValue().getDefiningOp()->clone(); + // Insert the cloned constant right before the recv + recv->getBlock()->getOperations().insert(recv->getIterator(), + cloned_const); + recv.getResult().replaceAllUsesWith(cloned_const->getResult(0)); + recv.erase(); + } + } + send.erase(); + } + }); + } +}; +} // namespace mlir::enzyme::distributed \ No newline at end of file diff --git a/src/enzyme_ad/jax/Passes/Distributed/Passes.h b/src/enzyme_ad/jax/Passes/Distributed/Passes.h index 9ecc9b6c4e..52ad040d0d 100644 --- a/src/enzyme_ad/jax/Passes/Distributed/Passes.h +++ b/src/enzyme_ad/jax/Passes/Distributed/Passes.h @@ -1,10 +1,13 @@ #ifndef ENZYMEXLA_DISTRIBUTED_PASSES_H #define ENZYMEXLA_DISTRIBUTED_PASSES_H +#include "mlir/Pass/Pass.h" + namespace mlir::enzyme::distributed { #define GEN_PASS_DECL #include "src/enzyme_ad/jax/Passes/Distributed/Passes.h.inc" + #define GEN_PASS_REGISTRATION #include "src/enzyme_ad/jax/Passes/Distributed/Passes.h.inc" diff --git a/src/enzyme_ad/jax/Passes/Distributed/Passes.td b/src/enzyme_ad/jax/Passes/Distributed/Passes.td index 762a2e2a2f..430d833024 100644 --- a/src/enzyme_ad/jax/Passes/Distributed/Passes.td +++ b/src/enzyme_ad/jax/Passes/Distributed/Passes.td @@ -9,7 +9,10 @@ def EliminateConstantCommunicationPass : Pass<"eliminate-constant-communication" This pass identifies send instructions with constant operands and replaces the corresponding receive instructions with local constants. }]; - let dependentDialects = ["enzyme::distributed::DistributedDialect"]; + let dependentDialects = [ + "enzyme::distributed::DistributedDialect", + "stablehlo::StablehloDialect" + ]; } #endif // ENZYMEXLA_DISTRIBUTED_PASSES \ No newline at end of file diff --git a/src/enzyme_ad/jax/RegistryUtils.cpp b/src/enzyme_ad/jax/RegistryUtils.cpp index 9a927ffad5..03a7046f97 100644 --- a/src/enzyme_ad/jax/RegistryUtils.cpp +++ b/src/enzyme_ad/jax/RegistryUtils.cpp @@ -295,6 +295,7 @@ void registerInterfaces(mlir::DialectRegistry ®istry) { void initializePasses() { registerenzymePasses(); enzyme::registerenzymexlaPasses(); + enzyme::distributed::registerdistributedPasses(); // Register the standard passes we want. 
mlir::registerCSEPass(); diff --git a/test/lit_tests/distributed/eliminateconstants.mlir b/test/lit_tests/distributed/eliminateconstants.mlir new file mode 100644 index 0000000000..5e47024824 --- /dev/null +++ b/test/lit_tests/distributed/eliminateconstants.mlir @@ -0,0 +1,57 @@ +// RUN: enzymexlamlir-opt --eliminate-constant-communication %s | FileCheck %s +distributed.leaf_device @myGpu +distributed.device_mesh @gpuMesh @myGpu [2, 2] +distributed.leaf_device @myCpu +distributed.channel @chan1 [@myCpu, @gpuMesh] [@gpuMesh, @myCpu] +distributed.device_group @gpusWithHost [@myGpu, @myCpu] [@chan1] + +func.func @foo() { + distributed.device_parallel @gpusWithHost { + branch @myGpu { + ^entry(%1: !distributed.token): + distributed.device_parallel @gpuMesh { + branch @myGpu { + ^entry(): + } + } + } + branch @myCpu { + ^entry(%1: !distributed.token): + %output = stablehlo.constant() { + value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32> + } : () -> tensor<2x2xf32> + distributed.send %1 tensor<2x2xf32> %output + + } + branch @chan1 { + ^entry(%1: !distributed.token): + %input = distributed.recv %1 tensor<2x2xf32> + %sum = stablehlo.add %input, %input : tensor<2x2xf32> + } + } + + func.return +} + +//CHECK: module { +//CHECK-NEXT: distributed.leaf_device @myGpu +//CHECK-NEXT: distributed.device_mesh @gpuMesh @myGpu [2, 2] +//CHECK-NEXT: distributed.leaf_device @myCpu +//CHECK-NEXT: distributed.channel @chan1 [@myCpu, @gpuMesh] [@gpuMesh, @myCpu] +//CHECK-NEXT: distributed.device_group @gpusWithHost [@myGpu, @myCpu] [@chan1] +//CHECK-NEXT: func.func @foo() { +//CHECK-NEXT: distributed.device_parallel @gpusWithHost{ branch @myGpu{ +//CHECK-NEXT: ^bb0(%arg0: !distributed.token): +//CHECK-NEXT: distributed.device_parallel @gpuMesh{ branch @myGpu{ +//CHECK-NEXT: }} +//CHECK-NEXT: } branch @myCpu{ +//CHECK-NEXT: ^bb0(%arg0: !distributed.token): +//CHECK-NEXT{LITERAL}: %cst = stablehlo.constant dense<[[0.000000e+00, 1.000000e+00], [2.000000e+00, 3.000000e+00]]> : tensor<2x2xf32> +//CHECK-NEXT: } branch @chan1{ +//CHECK-NEXT: ^bb0(%arg0: !distributed.token): +//CHECK-NEXT{LITERAL}: %cst = stablehlo.constant dense<[[0.000000e+00, 1.000000e+00], [2.000000e+00, 3.000000e+00]]> : tensor<2x2xf32> +//CHECK-NEXT: %0 = stablehlo.add %cst, %cst : tensor<2x2xf32> +//CHECK-NEXT: }} +//CHECK-NEXT: return +//CHECK-NEXT: } +//CHECK-NEXT:} \ No newline at end of file diff --git a/test/lit_tests/distributed/roundtrip.mlir b/test/lit_tests/distributed/roundtrip.mlir index 99ebb0d0b2..0d29a6c814 100644 --- a/test/lit_tests/distributed/roundtrip.mlir +++ b/test/lit_tests/distributed/roundtrip.mlir @@ -12,17 +12,14 @@ func.func @foo() { distributed.device_parallel @gpuMesh { branch @myGpu { ^entry(): - distributed.noop } } } branch @myCpu { ^entry(): - distributed.noop } branch @chan1 { ^entry(): - distributed.noop } } @@ -30,18 +27,18 @@ func.func @foo() { } //CHECK: module { -//CHECK-NEXT: distributed.LeafDevice @myGpu -//CHECK-NEXT: distributed.DeviceMesh @gpuMesh @myGpu [2, 2] -//CHECK-NEXT: distributed.LeafDevice @myCpu -//CHECK-NEXT: distributed.Channel @chan1 [@myCpu, @gpuMesh] [@gpuMesh, @myCpu] -//CHECK-NEXT: distributed.DeviceGroup @gpusWithHost [@myGpu, @myCpu] [@chan1] -//CHECK-NEXT: func.func @foo() { -//CHECK-NEXT: distributed.GroupSplit @gpusWithHost branch @myGpu{ -//CHECK-NEXT: distributed.MeshFor @gpuMesh { -//CHECK-NEXT: } -//CHECK-NEXT: } branch @myCpu{ -//CHECK-NEXT: %0 = distributed.DefineToken @chan1 -//CHECK-NEXT: } -//CHECK-NEXT: return -//CHECK-NEXT: } -//CHECK-NEXT: } \ No newline at end 
of file +//CHECK-NEXT: distributed.leaf_device @myGpu +//CHECK-NEXT: distributed.device_mesh @gpuMesh @myGpu [2, 2] +//CHECK-NEXT: distributed.leaf_device @myCpu +//CHECK-NEXT: distributed.channel @chan1 [@myCpu, @gpuMesh] [@gpuMesh, @myCpu] +//CHECK-NEXT: distributed.device_group @gpusWithHost [@myGpu, @myCpu] [@chan1] +//CHECK-NEXT: func.func @foo() { +//CHECK-NEXT: distributed.device_parallel @gpusWithHost{ branch @myGpu{ +//CHECK-NEXT: distributed.device_parallel @gpuMesh{ branch @myGpu{ +//CHECK-NEXT: }} +//CHECK-NEXT: } branch @myCpu{ +//CHECK-NEXT: } branch @chan1{ +//CHECK-NEXT: }} +//CHECK-NEXT: return +//CHECK-NEXT: } +//CHECK-NEXT:} \ No newline at end of file
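
Usage note: the new pass is exercised through enzymexlamlir-opt, as in the RUN line of the lit test above; a standalone invocation sketch (the input path is illustrative):

    enzymexlamlir-opt --eliminate-constant-communication test/lit_tests/distributed/eliminateconstants.mlir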