diff --git a/src/enzyme_ad/jax/Dialect/EnzymeXLAAttrs.td b/src/enzyme_ad/jax/Dialect/EnzymeXLAAttrs.td index 11bf8ca254..cbb4235c48 100644 --- a/src/enzyme_ad/jax/Dialect/EnzymeXLAAttrs.td +++ b/src/enzyme_ad/jax/Dialect/EnzymeXLAAttrs.td @@ -133,4 +133,72 @@ def EnzymeXLA_GuaranteedAnalysisResult : I32EnumAttr<"GuaranteedAnalysisResult", def EnzymeXLA_GuaranteedAnalysisResultAttr : EnumAttr; +// MPI + +def EnzymeXLA_MPIDatatype : I32EnumAttr<"MPIDatatype", + "MPI Datatype", + [ + I32EnumAttrCase<"MPI_DATATYPE_NULL", 0>, + I32EnumAttrCase<"MPI_INT8_T", 1>, + I32EnumAttrCase<"MPI_UINT8_T", 2>, + I32EnumAttrCase<"MPI_INT16_T", 3>, + I32EnumAttrCase<"MPI_UINT16_T", 4>, + I32EnumAttrCase<"MPI_INT32_T", 5>, + I32EnumAttrCase<"MPI_UINT32_T", 6>, + I32EnumAttrCase<"MPI_INT64_T", 7>, + I32EnumAttrCase<"MPI_UINT64_T", 8>, + I32EnumAttrCase<"MPI_BYTE", 9>, + I32EnumAttrCase<"MPI_SHORT", 10>, + I32EnumAttrCase<"MPI_UNSIGNED_SHORT", 11>, + I32EnumAttrCase<"MPI_INT", 12>, + I32EnumAttrCase<"MPI_UNSIGNED", 13>, + I32EnumAttrCase<"MPI_LONG", 14>, + I32EnumAttrCase<"MPI_UNSIGNED_LONG", 15>, + I32EnumAttrCase<"MPI_LONG_LONG_INT", 16>, + I32EnumAttrCase<"MPI_UNSIGNED_LONG_LONG", 17>, + I32EnumAttrCase<"MPI_CHAR", 18>, + I32EnumAttrCase<"MPI_SIGNED_CHAR", 19>, + I32EnumAttrCase<"MPI_UNSIGNED_CHAR", 20>, + I32EnumAttrCase<"MPI_WCHAR", 21>, + I32EnumAttrCase<"MPI_FLOAT", 22>, + I32EnumAttrCase<"MPI_DOUBLE", 23>, + I32EnumAttrCase<"MPI_C_FLOAT_COMPLEX", 24>, + I32EnumAttrCase<"MPI_C_DOUBLE_COMPLEX", 25>, + I32EnumAttrCase<"MPI_C_BOOL", 26> + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::enzymexla"; +} + +def EnzymeXLA_MPIDatatypeAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +def EnzymeXLA_MPIOp : I32EnumAttr<"MPIOp", + "MPI Operator", + [ + I32EnumAttrCase<"MPI_OP_NULL", 0>, + I32EnumAttrCase<"MPI_BAND", 1>, + I32EnumAttrCase<"MPI_BOR", 2>, + I32EnumAttrCase<"MPI_BXOR", 3>, + I32EnumAttrCase<"MPI_LAND", 4>, + I32EnumAttrCase<"MPI_LOR", 5>, + I32EnumAttrCase<"MPI_LXOR", 6>, + I32EnumAttrCase<"MPI_MAX", 7>, + I32EnumAttrCase<"MPI_MIN", 8>, + I32EnumAttrCase<"MPI_PROD", 9>, + I32EnumAttrCase<"MPI_REPLACE", 10>, + I32EnumAttrCase<"MPI_SUM", 11>, + I32EnumAttrCase<"MPI_NO_OP", 12> + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::enzymexla"; +} + +def EnzymeXLA_MPIOpAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + #endif // ENZYMEXLA_ATTRS diff --git a/src/enzyme_ad/jax/Dialect/EnzymeXLAOps.td b/src/enzyme_ad/jax/Dialect/EnzymeXLAOps.td index 38315cc52a..7af220cb80 100644 --- a/src/enzyme_ad/jax/Dialect/EnzymeXLAOps.td +++ b/src/enzyme_ad/jax/Dialect/EnzymeXLAOps.td @@ -1088,4 +1088,130 @@ def AffineStoreVar : EnzymeXLA_Op<"store_var", [Pure]> { let summary = "Fake store an SSA value for conversion to ISL"; } +// MPI Ops + +def MPICommRankOp : EnzymeXLA_Op<"mpi.comm_rank", [Pure]> { + let summary = "Equivalent to " "`MPI_Comm_rank(MPI_COMM_WORLD, &rank)`"; + + let results = ( + outs TensorOf<[I32]> : $rank + ); + + let assemblyFormat = "attr-dict `:` type(results)"; +} + +def MPICommSizeOp : EnzymeXLA_Op<"mpi.comm_size", [Pure]> { + let summary = "Equivalent to MPI_Comm_size(MPI_COMM_WORLD, &size)"; + + let results = ( + outs TensorOf<[I32]> : $size + ); + + let assemblyFormat = "attr-dict `:` type(results)"; +} + +def MPIBarrierOp : EnzymeXLA_Op<"mpi.barrier", []> { + let summary = "Equivalent to MPI_Barrier(MPI_COMM_WORLD)"; + let assemblyFormat = "attr-dict"; +} + +def MPISendOp : EnzymeXLA_Op<"mpi.send", []> { + let summary = "Equivalent to " + 
"`MPI_Send(&buf, count, datatype, dest, tag, comm)`"; + + let arguments = ( + ins AnyTensor : $buf, + TensorOf<[I32]> : $count, + TensorOf<[I32]> : $dest, + TensorOf<[I32]> : $tag, + EnzymeXLA_MPIDatatypeAttr:$datatype + ); + + let assemblyFormat = "`(` operands `)` attr-dict `:` type(operands)"; +} + +def MPIRecvOp : EnzymeXLA_Op<"mpi.recv", []> { + let summary = "Equivalent to " + "`MPI_Recv(&buf, count, datatype, source, tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE)`"; + + let arguments = ( + ins AnyTensor : $inbuf, + TensorOf<[I32]> : $count, + TensorOf<[I32]> : $source, + TensorOf<[I32]> : $tag, + EnzymeXLA_MPIDatatypeAttr:$datatype + ); + + let results = ( + outs AnyTensor : $outbuf + ); + + let assemblyFormat = "`(` operands `)` attr-dict `:` functional-type(operands, results)"; +} + +def MPIIsendOp : EnzymeXLA_Op<"mpi.isend", []> { + let summary = "Equivalent to " + "`MPI_Isend(&buf, count, datatype, dest, tag, MPI_COMM_WORLD, &request)`"; + + let arguments = ( + ins AnyTensor : $buf, + TensorOf<[I32]> : $count, + TensorOf<[I32]> : $dest, + TensorOf<[I32]> : $tag, + EnzymeXLA_MPIDatatypeAttr:$datatype + ); + + let results = ( + outs TensorOf<[I64]> : $request + ); + + let assemblyFormat = "`(` operands `)` attr-dict `:` functional-type(operands, results)"; +} + +def MPIIrecvOp : EnzymeXLA_Op<"mpi.irecv", []> { + let summary = "Equivalent to " + "`MPI_Irecv(&buf, count, datatype, source, tag, MPI_COMM_WORLD, &request)`"; + + let arguments = ( + ins AnyTensor : $inbuf, + TensorOf<[I32]> : $count, + TensorOf<[I32]> : $source, + TensorOf<[I32]> : $tag, + EnzymeXLA_MPIDatatypeAttr:$datatype + ); + + let results = ( + outs AnyTensor : $outbuf, + TensorOf<[I64]> : $request + ); + + let assemblyFormat = "`(` operands `)` attr-dict `:` functional-type(operands, results)"; +} + +def MPIWaitOp : EnzymeXLA_Op<"mpi.wait", []> { + let summary = "Equivalent to " + "`MPI_Wait(&request, &status)`"; + let arguments = (ins TensorOf<[I64]> : $request); + let assemblyFormat = "`(` operands `)` attr-dict `:` type(operands)"; +} + +def MPIAllreduceOp : EnzymeXLA_Op<"mpi.allreduce", []> { + let summary = "Equivalent to " + "`MPI_Allreduce(&sendbuf, &recvbuf, count, datatype, op, MPI_COMM_WORLD)`"; + + let arguments = ( + ins AnyTensor : $sendbuf, + AnyTensor : $inbuf, + TensorOf<[I32]> : $count, + EnzymeXLA_MPIDatatypeAttr:$datatype, + EnzymeXLA_MPIOpAttr:$op + ); + + let results = ( + outs AnyTensor : $outbuf + ); + + let assemblyFormat = "`(` operands `)` attr-dict `:` functional-type(operands, results)"; +} + #endif // ENZYMEXLA_OPS diff --git a/src/enzyme_ad/jax/Integrations/c/EnzymeXLA.cpp b/src/enzyme_ad/jax/Integrations/c/EnzymeXLA.cpp index e4bfd4f4dc..6bf10beac0 100644 --- a/src/enzyme_ad/jax/Integrations/c/EnzymeXLA.cpp +++ b/src/enzyme_ad/jax/Integrations/c/EnzymeXLA.cpp @@ -122,3 +122,141 @@ MlirAttribute enzymexlaGuaranteedAnalysisResultAttrGet(MlirContext ctx, return wrap(mlir::enzymexla::GuaranteedAnalysisResultAttr::get(unwrap(ctx), analysis)); } + +MlirAttribute enzymexlaMPIDatatypeAttrGet(MlirContext ctx, int32_t mode) { + mlir::enzymexla::MPIDatatype datatype; + switch (mode) { + case 0: + datatype = mlir::enzymexla::MPIDatatype::MPI_DATATYPE_NULL; + break; + case 1: + datatype = mlir::enzymexla::MPIDatatype::MPI_INT8_T; + break; + case 2: + datatype = mlir::enzymexla::MPIDatatype::MPI_UINT8_T; + break; + case 3: + datatype = mlir::enzymexla::MPIDatatype::MPI_INT16_T; + break; + case 4: + datatype = mlir::enzymexla::MPIDatatype::MPI_UINT16_T; + break; + case 5: + datatype = 
mlir::enzymexla::MPIDatatype::MPI_INT32_T; + break; + case 6: + datatype = mlir::enzymexla::MPIDatatype::MPI_UINT32_T; + break; + case 7: + datatype = mlir::enzymexla::MPIDatatype::MPI_INT64_T; + break; + case 8: + datatype = mlir::enzymexla::MPIDatatype::MPI_UINT64_T; + break; + case 9: + datatype = mlir::enzymexla::MPIDatatype::MPI_BYTE; + break; + case 10: + datatype = mlir::enzymexla::MPIDatatype::MPI_SHORT; + break; + case 11: + datatype = mlir::enzymexla::MPIDatatype::MPI_UNSIGNED_SHORT; + break; + case 12: + datatype = mlir::enzymexla::MPIDatatype::MPI_INT; + break; + case 13: + datatype = mlir::enzymexla::MPIDatatype::MPI_UNSIGNED; + break; + case 14: + datatype = mlir::enzymexla::MPIDatatype::MPI_LONG; + break; + case 15: + datatype = mlir::enzymexla::MPIDatatype::MPI_UNSIGNED_LONG; + break; + case 16: + datatype = mlir::enzymexla::MPIDatatype::MPI_LONG_LONG_INT; + break; + case 17: + datatype = mlir::enzymexla::MPIDatatype::MPI_UNSIGNED_LONG_LONG; + break; + case 18: + datatype = mlir::enzymexla::MPIDatatype::MPI_CHAR; + break; + case 19: + datatype = mlir::enzymexla::MPIDatatype::MPI_SIGNED_CHAR; + break; + case 20: + datatype = mlir::enzymexla::MPIDatatype::MPI_UNSIGNED_CHAR; + break; + case 21: + datatype = mlir::enzymexla::MPIDatatype::MPI_WCHAR; + break; + case 22: + datatype = mlir::enzymexla::MPIDatatype::MPI_FLOAT; + break; + case 23: + datatype = mlir::enzymexla::MPIDatatype::MPI_DOUBLE; + break; + case 24: + datatype = mlir::enzymexla::MPIDatatype::MPI_C_FLOAT_COMPLEX; + break; + case 25: + datatype = mlir::enzymexla::MPIDatatype::MPI_C_DOUBLE_COMPLEX; + break; + case 26: + datatype = mlir::enzymexla::MPIDatatype::MPI_C_BOOL; + break; + default: + llvm_unreachable("Invalid MPI datatype mode"); + } + return wrap(mlir::enzymexla::MPIDatatypeAttr::get(unwrap(ctx), datatype)); +} + +MlirAttribute enzymexlaMPIOpAttrGet(MlirContext ctx, int32_t mode) { + mlir::enzymexla::MPIOp op; + switch (mode) { + case 0: + op = mlir::enzymexla::MPIOp::MPI_OP_NULL; + break; + case 1: + op = mlir::enzymexla::MPIOp::MPI_BAND; + break; + case 2: + op = mlir::enzymexla::MPIOp::MPI_BOR; + break; + case 3: + op = mlir::enzymexla::MPIOp::MPI_BXOR; + break; + case 4: + op = mlir::enzymexla::MPIOp::MPI_LAND; + break; + case 5: + op = mlir::enzymexla::MPIOp::MPI_LOR; + break; + case 6: + op = mlir::enzymexla::MPIOp::MPI_LXOR; + break; + case 7: + op = mlir::enzymexla::MPIOp::MPI_MAX; + break; + case 8: + op = mlir::enzymexla::MPIOp::MPI_MIN; + break; + case 9: + op = mlir::enzymexla::MPIOp::MPI_PROD; + break; + case 10: + op = mlir::enzymexla::MPIOp::MPI_REPLACE; + break; + case 11: + op = mlir::enzymexla::MPIOp::MPI_SUM; + break; + case 12: + op = mlir::enzymexla::MPIOp::MPI_NO_OP; + break; + default: + llvm_unreachable("Invalid MPI op mode"); + } + return wrap(mlir::enzymexla::MPIOpAttr::get(unwrap(ctx), op)); +} diff --git a/src/enzyme_ad/jax/Integrations/c/EnzymeXLA.h b/src/enzyme_ad/jax/Integrations/c/EnzymeXLA.h index fda4a3b6c1..db4118c9fc 100644 --- a/src/enzyme_ad/jax/Integrations/c/EnzymeXLA.h +++ b/src/enzyme_ad/jax/Integrations/c/EnzymeXLA.h @@ -43,6 +43,16 @@ MLIR_CAPI_EXPORTED MlirAttribute enzymexlaSVDAlgorithmAttrGet(MlirContext ctx, MLIR_CAPI_EXPORTED MlirAttribute enzymexlaGeluApproximationAttrGet(MlirContext ctx, int32_t mode); +//===----------------------------------------------------------------------===// +// MPI Ops +//===----------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED MlirAttribute enzymexlaMPIDatatypeAttrGet(MlirContext ctx, 
+ int32_t mode); + +MLIR_CAPI_EXPORTED MlirAttribute enzymexlaMPIOpAttrGet(MlirContext ctx, + int32_t mode); + //===----------------------------------------------------------------------===// // Other Ops / Attributes //===----------------------------------------------------------------------===// diff --git a/src/enzyme_ad/jax/Passes/LowerEnzymeXLAMPI.cpp b/src/enzyme_ad/jax/Passes/LowerEnzymeXLAMPI.cpp new file mode 100644 index 0000000000..aa30b6a03b --- /dev/null +++ b/src/enzyme_ad/jax/Passes/LowerEnzymeXLAMPI.cpp @@ -0,0 +1,1470 @@ +// #include "mhlo/IR/hlo_ops.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "src/enzyme_ad/jax/Dialect/Dialect.h" +#include "src/enzyme_ad/jax/Dialect/Ops.h" +#include "src/enzyme_ad/jax/Passes/LinalgUtils.h" +#include "src/enzyme_ad/jax/Passes/Passes.h" +#include "src/enzyme_ad/jax/Utils.h" +// #include "stablehlo/dialect/StablehloOps.h" +// #include "llvm/ADT/DynamicAPInt.h" +// #include "llvm/ADT/SetVector.h" +// #include "llvm/ADT/SmallVector.h" +// #include "llvm/Support/ErrorHandling.h" +// #include "llvm/Support/LogicalResult.h" +// #include "llvm/Support/MathExtras.h" +// #include +// #include + +namespace mlir { +namespace enzyme { +#define GEN_PASS_DEF_LOWERENZYMEXLAMPIPASS +#include "src/enzyme_ad/jax/Passes/Passes.h.inc" +} // namespace enzyme +} // namespace mlir + +using namespace mlir; + +struct MPICommRankOpLowering + : public OpRewritePattern { + + std::string backend; + MPICommRankOpLowering(std::string backend, MLIRContext *context, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), backend(backend) {} + + LogicalResult matchAndRewrite(enzymexla::MPICommRankOp op, + PatternRewriter &rewriter) const override { + auto context = op->getContext(); + + if (backend == "cpu") { + + auto moduleOp = op->getParentOfType(); + + auto llvmPtrType = LLVM::LLVMPointerType::get(context); + auto llvmVoidType = LLVM::LLVMVoidType::get(context); + auto i32Type = IntegerType::get(context, 32); + + std::string mpiFunctionName = "MPI_Comm_rank"; + + // For now we just hard code MPI_COMM_WORLD as the communicator. 
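+      // For reference, the wrapper emitted below lowers to LLVM IR of this
+      // shape (cf. test/lit_tests/mpi/comm_rank.mlir):
+      //   llvm.func @enzymexla_wrapper_MPI_Comm_rank(%arg0: !llvm.ptr) {
+      //     %0 = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr
+      //     %1 = llvm.call @MPI_Comm_rank(%0, %arg0) : (!llvm.ptr, !llvm.ptr) -> i32
+      //     llvm.return
+      //   }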
+ // TODO make this more flexible + std::string communicatorName = "MPI_COMM_WORLD"; + + // Generate the enzymexla_wrapper_MPI_Comm_rank LLVM function body + std::string wrapperFunctionName = "enzymexla_wrapper_" + mpiFunctionName; + { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + // Create the function type + auto funcType = + LLVM::LLVMFunctionType::get(llvmVoidType, // void return type + {llvmPtrType}, // pointer parameter + false); // is variadic: false + + auto wrapperFunc = rewriter.create( + op.getLoc(), wrapperFunctionName, funcType); + + // Add function-level memory effects attribute + auto memoryEffectsAttr = rewriter.getArrayAttr( + {rewriter.getStringAttr("read"), rewriter.getStringAttr("write"), + rewriter.getStringAttr("allocate"), + rewriter.getStringAttr("free")}); + wrapperFunc->setAttr("enzymexla.memory_effects", memoryEffectsAttr); + + Block *entryBlock = wrapperFunc.addEntryBlock(rewriter); + rewriter.setInsertionPointToStart(entryBlock); + + // Add argument-level memory effects attribute + wrapperFunc.setArgAttr(0, "enzymexla.memory_effects", + memoryEffectsAttr); + + // Get the rank pointer from the argument + Value rankPtr = entryBlock->getArgument(0); + + // Get the address of the communicator + // NOTE these symbols are not ABI-stable until MPI 5.0, but in practice, + // they are represented as word-size values (i.e. `int` or ptr) + Value addressOfComm = rewriter.create( + op.getLoc(), llvmPtrType, communicatorName); + + // TODO error checking + // MPI_Comm_rank returns i32 error code which we're ignoring here + rewriter.create( + op.getLoc(), TypeRange{i32Type}, + SymbolRefAttr::get(context, mpiFunctionName), + ValueRange{addressOfComm, rankPtr}); + + rewriter.create(op.getLoc(), ValueRange{}); + } + + // Insert MPI_Comm_rank function declaration if not already present + if (!moduleOp.lookupSymbol(mpiFunctionName)) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + auto funcType = LLVM::LLVMFunctionType::get( + i32Type, {llvmPtrType, llvmPtrType}, false); + + rewriter.create(op.getLoc(), mpiFunctionName, + funcType, LLVM::Linkage::External); + } + + // Insert MPI_COMM_WORLD declaration if not already present + if (!moduleOp.lookupSymbol(communicatorName)) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + rewriter.create( + op.getLoc(), llvmPtrType, + /*isConstant=*/true, LLVM::Linkage::External, communicatorName, + /*value=*/Attribute(), + /*alignment=*/0, + /*addrSpace=*/0); + } + + // Create a constant tensor to hold the result + auto tensorType = llvm::cast(op->getResultTypes()[0]); + auto constantAttr = + DenseIntElementsAttr::get(tensorType, ArrayRef{-1}); + Value constantTensor = rewriter.create( + op.getLoc(), tensorType, constantAttr); + + // Call the LLVM function with enzymexla.jit_call + auto aliasAttr = stablehlo::OutputOperandAliasAttr::get( + context, + /*outputTupleIndices=*/ArrayRef{}, + /*operandIndex=*/0, + /*operandTupleIndices=*/ArrayRef{}); + + auto jitCall = rewriter.create( + op.getLoc(), op->getResultTypes(), + mlir::FlatSymbolRefAttr::get(context, wrapperFunctionName), + ValueRange{constantTensor}, rewriter.getStringAttr(""), + /*operand_layouts=*/nullptr, + /*result_layouts=*/nullptr, + /*arg_attrs=*/nullptr, + /*res_attrs=*/nullptr, + /*output_operand_aliases=*/rewriter.getArrayAttr({aliasAttr}), + /*xla_side_effect_free=*/nullptr); + + rewriter.replaceOp(op, 
jitCall.getResult(0)); + + return success(); + } else { + return rewriter.notifyMatchFailure(op, + "Backend not supported: " + backend); + } + } +}; + +struct MPICommSizeOpLowering + : public OpRewritePattern { + + std::string backend; + MPICommSizeOpLowering(std::string backend, MLIRContext *context, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), backend(backend) {} + + LogicalResult matchAndRewrite(enzymexla::MPICommSizeOp op, + PatternRewriter &rewriter) const override { + auto context = op->getContext(); + + if (backend == "cpu") { + + auto moduleOp = op->getParentOfType(); + + auto llvmPtrType = LLVM::LLVMPointerType::get(context); + auto llvmVoidType = LLVM::LLVMVoidType::get(context); + + auto i32Type = IntegerType::get(context, 32); + + std::string mpiFunctionName = "MPI_Comm_size"; + + // For now we just hard code MPI_COMM_WORLD as the communicator. + // TODO make this more flexible + std::string communicatorName = "MPI_COMM_WORLD"; + + // Generate the enzymexla_wrapper_MPI_Comm_size LLVM function body + std::string wrapperFunctionName = "enzymexla_wrapper_" + mpiFunctionName; + { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + // Create the function type + auto funcType = + LLVM::LLVMFunctionType::get(llvmVoidType, // void return type + {llvmPtrType}, // parameter types + false); // is variadic: false + + auto wrapperFunc = rewriter.create( + op.getLoc(), wrapperFunctionName, funcType); + + // Add function-level memory effects attribute + auto memoryEffectsAttr = rewriter.getArrayAttr( + {rewriter.getStringAttr("read"), rewriter.getStringAttr("write"), + rewriter.getStringAttr("allocate"), + rewriter.getStringAttr("free")}); + wrapperFunc->setAttr("enzymexla.memory_effects", memoryEffectsAttr); + + Block *entryBlock = wrapperFunc.addEntryBlock(rewriter); + rewriter.setInsertionPointToStart(entryBlock); + + // Add argument-level memory effects attribute + wrapperFunc.setArgAttr(0, "enzymexla.memory_effects", + memoryEffectsAttr); + + // Get the first (and only) argument of the function + Value sizePtr = entryBlock->getArgument(0); + + // Get the address of the communicator + // NOTE these symbols are not ABI-stable until MPI 5.0, but in practice, + // they are represented as w ord-size values (i.e. 
`int` or ptr) + Value addressOfComm = rewriter.create( + op.getLoc(), llvmPtrType, communicatorName); + + // TODO error checking + // MPI_Comm_size returns i32 error code which we're ignoring here + rewriter.create( + op.getLoc(), TypeRange{i32Type}, + SymbolRefAttr::get(context, mpiFunctionName), + ValueRange{addressOfComm, sizePtr}); + + rewriter.create(op.getLoc(), ValueRange{}); + } + + // Insert MPI_Comm_size function declaration if not already present + if (!moduleOp.lookupSymbol(mpiFunctionName)) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + auto funcType = LLVM::LLVMFunctionType::get( + i32Type, {llvmPtrType, llvmPtrType}, false); + + rewriter.create(op.getLoc(), mpiFunctionName, + funcType, LLVM::Linkage::External); + } + + // Insert MPI_COMM_WORLD declaration if not already present + if (!moduleOp.lookupSymbol(communicatorName)) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + rewriter.create( + op.getLoc(), llvmPtrType, + /*isConstant=*/true, LLVM::Linkage::External, communicatorName, + /*value=*/Attribute(), + /*alignment=*/0, + /*addrSpace=*/0); + } + + // Create a constant tensor to hold the result + auto tensorType = llvm::cast(op->getResultTypes()[0]); + auto constantAttr = + DenseIntElementsAttr::get(tensorType, ArrayRef{-1}); + Value constantTensor = rewriter.create( + op.getLoc(), tensorType, constantAttr); + + // Call the LLVM function with enzymexla.jit_call + SmallVector aliases; + aliases.push_back(stablehlo::OutputOperandAliasAttr::get( + context, + /*output_operand_aliases=*/std::vector{}, + /*operand_index=*/0, + /*operand_tuple_indices=*/std::vector{})); + + auto jitCall = rewriter.create( + op.getLoc(), op->getResultTypes(), + mlir::FlatSymbolRefAttr::get(context, wrapperFunctionName), + ValueRange{constantTensor}, rewriter.getStringAttr(""), + /*operand_layouts=*/nullptr, + /*result_layouts=*/nullptr, + /*arg_attrs=*/nullptr, + /*res_attrs=*/nullptr, + /*output_operand_aliases=*/rewriter.getArrayAttr(aliases), + /*xla_side_effect_free=*/nullptr); + + rewriter.replaceOp(op, jitCall); + + return success(); + } else { + return rewriter.notifyMatchFailure(op, + "Backend not supported: " + backend); + } + } +}; + +struct MPIBarrierOpLowering : public OpRewritePattern { + + std::string backend; + MPIBarrierOpLowering(std::string backend, MLIRContext *context, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), backend(backend) {} + + LogicalResult matchAndRewrite(enzymexla::MPIBarrierOp op, + PatternRewriter &rewriter) const override { + auto context = op->getContext(); + + if (backend == "cpu") { + + auto moduleOp = op->getParentOfType(); + + auto llvmPtrType = LLVM::LLVMPointerType::get(context); + auto llvmVoidType = LLVM::LLVMVoidType::get(context); + + auto i32Type = IntegerType::get(context, 32); + + std::string mpiFunctionName = "MPI_Barrier"; + + // TODO For now we just hard code MPI_COMM_WORLD as the communicator. 
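+      // NOTE pre-MPI-5.0 implementations usually expose MPI_COMM_WORLD (and
+      // the MPI_Datatype/MPI_Op handles used elsewhere in this pass) as
+      // macros rather than linkable globals, so linking the generated code
+      // may require an ABI-compliant MPI library or a shim defining these
+      // symbols.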
+ std::string communicatorName = "MPI_COMM_WORLD"; + + // Generate the enzymexla_wrapper_MPI_Barrier LLVM function body + std::string wrapperFunctionName = "enzymexla_wrapper_" + mpiFunctionName; + { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + // Create the function type + auto funcType = LLVM::LLVMFunctionType::get(llvmVoidType, {}, false); + + auto wrapperFunc = rewriter.create( + op.getLoc(), wrapperFunctionName, funcType); + + // Add function-level memory effects attribute + auto memoryEffectsAttr = rewriter.getArrayAttr( + {rewriter.getStringAttr("read"), rewriter.getStringAttr("write"), + rewriter.getStringAttr("allocate"), + rewriter.getStringAttr("free")}); + wrapperFunc->setAttr("enzymexla.memory_effects", memoryEffectsAttr); + + Block *entryBlock = wrapperFunc.addEntryBlock(rewriter); + rewriter.setInsertionPointToStart(entryBlock); + + // Get the address of the communicator + // NOTE these symbols are not ABI-stable until MPI 5.0, but in practice, + // they are represented as w ord-size values (i.e. `int` or ptr) + Value addressOfComm = rewriter.create( + op.getLoc(), llvmPtrType, communicatorName); + + // Call MPI_Barrier + // int MPI_Barrier(MPI_Comm comm) + // TODO returns i32 error code which we're ignoring here + rewriter.create( + op.getLoc(), TypeRange{i32Type}, + SymbolRefAttr::get(context, mpiFunctionName), + ValueRange{addressOfComm}); + + rewriter.create(op.getLoc(), ValueRange{}); + } + + // Insert MPI_Barrier function declaration if not already present + if (!moduleOp.lookupSymbol(mpiFunctionName)) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + auto funcType = + LLVM::LLVMFunctionType::get(i32Type, {llvmPtrType}, false); + + rewriter.create(op.getLoc(), mpiFunctionName, + funcType, LLVM::Linkage::External); + } + + // Insert MPI_COMM_WORLD declaration if not already present + if (!moduleOp.lookupSymbol(communicatorName)) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + rewriter.create( + op.getLoc(), llvmPtrType, + /*isConstant=*/true, LLVM::Linkage::External, communicatorName, + /*value=*/Attribute(), + /*alignment=*/0, + /*addrSpace=*/0); + } + + // Call the LLVM function with enzymexla.jit_call + rewriter.create( + op.getLoc(), TypeRange{}, + mlir::FlatSymbolRefAttr::get(context, wrapperFunctionName), + ValueRange{}, rewriter.getStringAttr(""), + /*operand_layouts=*/nullptr, + /*result_layouts=*/nullptr, + /*arg_attrs=*/nullptr, + /*res_attrs=*/nullptr, + /*output_operand_aliases=*/nullptr, + /*xla_side_effect_free=*/nullptr); + + rewriter.eraseOp(op); + + return success(); + } else { + return rewriter.notifyMatchFailure(op, + "Backend not supported: " + backend); + } + } +}; + +struct MPISendOpLowering : public OpRewritePattern { + + std::string backend; + MPISendOpLowering(std::string backend, MLIRContext *context, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), backend(backend) {} + + LogicalResult matchAndRewrite(enzymexla::MPISendOp op, + PatternRewriter &rewriter) const override { + auto context = op->getContext(); + + if (backend == "cpu") { + + auto moduleOp = op->getParentOfType(); + + auto llvmPtrType = LLVM::LLVMPointerType::get(context); + auto llvmVoidType = LLVM::LLVMVoidType::get(context); + + auto i32Type = IntegerType::get(context, 32); + + std::string mpiFunctionName = "MPI_Send"; + + // get the MPI datatype + auto datatype = op.getDatatype(); + 
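// stringifyMPIDatatype yields the enum case name (e.g. "MPI_INT"), which
+      // doubles as the external symbol whose address is taken below and as
+      // the suffix of the generated wrapper name.
+      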
StringRef datatypeName = stringifyMPIDatatype(datatype); + + // For now we just hard code MPI_COMM_WORLD as the communicator. + // TODO make this more flexible + std::string communicatorName = "MPI_COMM_WORLD"; + + // Generate the enzymexla_wrapper LLVM function body + std::string wrapperFunctionName = + "enzymexla_wrapper_" + mpiFunctionName + "_" + datatypeName.str(); + + if (!moduleOp.lookupSymbol(wrapperFunctionName)) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + // Create the wrapper function decl + auto funcType = LLVM::LLVMFunctionType::get( + llvmVoidType, {llvmPtrType, llvmPtrType, llvmPtrType, llvmPtrType}, + false); + + auto wrapperFunc = rewriter.create( + op.getLoc(), wrapperFunctionName, funcType); + + // Add function-level memory effects attribute + auto memoryEffectsAttr = rewriter.getArrayAttr( + {rewriter.getStringAttr("read"), rewriter.getStringAttr("write"), + rewriter.getStringAttr("allocate"), + rewriter.getStringAttr("free")}); + wrapperFunc->setAttr("enzymexla.memory_effects", memoryEffectsAttr); + + Block *entryBlock = wrapperFunc.addEntryBlock(rewriter); + rewriter.setInsertionPointToStart(entryBlock); + + // Add argument-level memory effects attribute to all arguments + for (unsigned i = 0; i < 4; ++i) { + wrapperFunc.setArgAttr(i, "enzymexla.memory_effects", + memoryEffectsAttr); + } + + // Get the function arguments + Value bufPtr = entryBlock->getArgument(0); + Value countPtr = entryBlock->getArgument(1); + Value destPtr = entryBlock->getArgument(2); + Value tagPtr = entryBlock->getArgument(3); + + // Load the count, dest, tag values + Value count = + rewriter.create(op.getLoc(), i32Type, countPtr); + + Value dest = + rewriter.create(op.getLoc(), i32Type, destPtr); + + Value tag = rewriter.create(op.getLoc(), i32Type, tagPtr); + + // Get the address of the datatype + // NOTE these symbols are not ABI-stable until MPI 5.0, but in practice, + // they are represented as w ord-size values (i.e. 
`int` or ptr) + Value addressOfDtype = rewriter.create( + op.getLoc(), llvmPtrType, datatypeName); + + // Get the address of the communicator + Value addressOfComm = rewriter.create( + op.getLoc(), llvmPtrType, communicatorName); + + // Call MPI_Send + // int MPI_Send(const void* buf, int count, MPI_Datatype datatype, int + // dest, int tag, MPI_Comm comm) + // TODO returns i32 error code which we're ignoring here + rewriter.create( + op.getLoc(), TypeRange{i32Type}, + SymbolRefAttr::get(context, mpiFunctionName), + ValueRange{bufPtr, count, addressOfDtype, dest, tag, + addressOfComm}); + + rewriter.create(op.getLoc(), ValueRange{}); + } + + // Insert MPI_Send function declaration if not already present + if (!moduleOp.lookupSymbol(mpiFunctionName)) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + auto funcType = LLVM::LLVMFunctionType::get( + i32Type, + {llvmPtrType, i32Type, llvmPtrType, i32Type, i32Type, llvmPtrType}, + false); + + rewriter.create(op.getLoc(), mpiFunctionName, + funcType, LLVM::Linkage::External); + } + + // Insert MPI_COMM_WORLD declaration if not already present + if (!moduleOp.lookupSymbol(communicatorName)) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + rewriter.create( + op.getLoc(), llvmPtrType, + /*isConstant=*/true, LLVM::Linkage::External, communicatorName, + /*value=*/Attribute(), + /*alignment=*/0, + /*addrSpace=*/0); + } + + // Insert datatype declaration if not already present + if (!moduleOp.lookupSymbol(datatypeName)) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + rewriter.create(op.getLoc(), llvmPtrType, + /*isConstant=*/true, + LLVM::Linkage::External, datatypeName, + /*value=*/Attribute(), + /*alignment=*/0, + /*addrSpace=*/0); + } + + // Get all orinigal op operands + auto operands = op.getOperands(); + + // Call the LLVM function with enzymexla.jit_call + rewriter.create( + op.getLoc(), TypeRange{}, + mlir::FlatSymbolRefAttr::get(context, wrapperFunctionName), + ValueRange{operands}, rewriter.getStringAttr(""), + /*operand_layouts=*/nullptr, + /*result_layouts=*/nullptr, + /*arg_attrs=*/nullptr, + /*res_attrs=*/nullptr, + /*output_operand_aliases=*/nullptr, + /*xla_side_effect_free=*/nullptr); + + rewriter.eraseOp(op); + + return success(); + } else { + return rewriter.notifyMatchFailure(op, + "Backend not supported: " + backend); + } + } +}; + +struct MPIRecvOpLowering : public OpRewritePattern { + + std::string backend; + MPIRecvOpLowering(std::string backend, MLIRContext *context, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), backend(backend) {} + + LogicalResult matchAndRewrite(enzymexla::MPIRecvOp op, + PatternRewriter &rewriter) const override { + auto context = op->getContext(); + + if (backend == "cpu") { + + auto moduleOp = op->getParentOfType(); + + auto llvmPtrType = LLVM::LLVMPointerType::get(context); + auto llvmVoidType = LLVM::LLVMVoidType::get(context); + + auto i32Type = IntegerType::get(context, 32); + + std::string mpiFunctionName = "MPI_Recv"; + + // get the MPI datatype + auto datatype = op.getDatatype(); + StringRef datatypeName = stringifyMPIDatatype(datatype); + + // For now we just hard code MPI_COMM_WORLD as the communicator. 
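+      // NOTE the receive status is discarded: the wrapper passes the address
+      // of an external MPI_STATUS_IGNORE symbol instead of a real MPI_Status
+      // object.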
+      // TODO make this more flexible
+      std::string communicatorName = "MPI_COMM_WORLD";
+
+      std::string statusName = "MPI_STATUS_IGNORE";
+
+      // Generate the enzymexla_wrapper LLVM function body
+      std::string wrapperFunctionName =
+          "enzymexla_wrapper_" + mpiFunctionName + "_" + datatypeName.str();
+
+      if (!moduleOp.lookupSymbol(wrapperFunctionName)) {
+        OpBuilder::InsertionGuard guard(rewriter);
+        rewriter.setInsertionPointToStart(moduleOp.getBody());
+
+        // Create the wrapper function decl
+        auto funcType = LLVM::LLVMFunctionType::get(
+            llvmVoidType, {llvmPtrType, llvmPtrType, llvmPtrType, llvmPtrType},
+            false);
+
+        auto wrapperFunc = rewriter.create<LLVM::LLVMFuncOp>(
+            op.getLoc(), wrapperFunctionName, funcType);
+
+        // Add function-level memory effects attribute
+        auto memoryEffectsAttr = rewriter.getArrayAttr(
+            {rewriter.getStringAttr("read"), rewriter.getStringAttr("write"),
+             rewriter.getStringAttr("allocate"),
+             rewriter.getStringAttr("free")});
+        wrapperFunc->setAttr("enzymexla.memory_effects", memoryEffectsAttr);
+
+        Block *entryBlock = wrapperFunc.addEntryBlock(rewriter);
+        rewriter.setInsertionPointToStart(entryBlock);
+
+        // Add argument-level memory effects attribute to all arguments
+        for (unsigned i = 0; i < 4; ++i) {
+          wrapperFunc.setArgAttr(i, "enzymexla.memory_effects",
+                                 memoryEffectsAttr);
+        }
+
+        // Get the function arguments
+        Value bufPtr = entryBlock->getArgument(0);
+        Value countPtr = entryBlock->getArgument(1);
+        Value srcPtr = entryBlock->getArgument(2);
+        Value tagPtr = entryBlock->getArgument(3);
+
+        // Load the count, src, tag values
+        Value count =
+            rewriter.create<LLVM::LoadOp>(op.getLoc(), i32Type, countPtr);
+
+        Value src = rewriter.create<LLVM::LoadOp>(op.getLoc(), i32Type, srcPtr);
+
+        Value tag = rewriter.create<LLVM::LoadOp>(op.getLoc(), i32Type, tagPtr);
+
+        // Get the address of the datatype
+        // NOTE these symbols are not ABI-stable until MPI 5.0, but in practice,
+        // they are represented as word-size values (i.e. `int` or ptr)
+        Value addressOfDtype = rewriter.create<LLVM::AddressOfOp>(
+            op.getLoc(), llvmPtrType, datatypeName);
+
+        // Get the address of the communicator
+        Value addressOfComm = rewriter.create<LLVM::AddressOfOp>(
+            op.getLoc(), llvmPtrType, communicatorName);
+
+        // Get the address of the status
+        Value addressOfStatus = rewriter.create<LLVM::AddressOfOp>(
+            op.getLoc(), llvmPtrType, statusName);
+
+        // Call MPI_Recv
+        // int MPI_Recv(void* buf, int count, MPI_Datatype datatype, int
+        // source, int tag, MPI_Comm comm, MPI_Status* status)
+        // TODO returns i32 error code which we're ignoring here
+        rewriter.create<LLVM::CallOp>(
+            op.getLoc(), TypeRange{i32Type},
+            SymbolRefAttr::get(context, mpiFunctionName),
+            ValueRange{bufPtr, count, addressOfDtype, src, tag, addressOfComm,
+                       addressOfStatus});
+
+        rewriter.create<LLVM::ReturnOp>(op.getLoc(), ValueRange{});
+      }
+
+      // Insert MPI_Recv function declaration if not already present
+      if (!moduleOp.lookupSymbol(mpiFunctionName)) {
+        OpBuilder::InsertionGuard guard(rewriter);
+        rewriter.setInsertionPointToStart(moduleOp.getBody());
+
+        auto funcType = LLVM::LLVMFunctionType::get(
+            i32Type,
+            {llvmPtrType, i32Type, llvmPtrType, i32Type, i32Type, llvmPtrType,
+             llvmPtrType},
+            false);
+
+        rewriter.create<LLVM::LLVMFuncOp>(op.getLoc(), mpiFunctionName,
+                                          funcType, LLVM::Linkage::External);
+      }
+
+      // Insert MPI_STATUS_IGNORE declaration if not already present
+      if (!moduleOp.lookupSymbol(statusName)) {
+        OpBuilder::InsertionGuard guard(rewriter);
+        rewriter.setInsertionPointToStart(moduleOp.getBody());
+
+        rewriter.create<LLVM::GlobalOp>(op.getLoc(), llvmPtrType,
+                                        /*isConstant=*/true,
+                                        LLVM::Linkage::External, statusName,
+                                        /*value=*/Attribute(),
+                                        /*alignment=*/0,
+                                        /*addrSpace=*/0);
+      }
+
+      // Insert MPI_COMM_WORLD declaration if not already present
+      if (!moduleOp.lookupSymbol(communicatorName)) {
+        OpBuilder::InsertionGuard guard(rewriter);
+        rewriter.setInsertionPointToStart(moduleOp.getBody());
+
+        rewriter.create<LLVM::GlobalOp>(
+            op.getLoc(), llvmPtrType,
+            /*isConstant=*/true, LLVM::Linkage::External, communicatorName,
+            /*value=*/Attribute(),
+            /*alignment=*/0,
+            /*addrSpace=*/0);
+      }
+
+      // Insert datatype declaration if not already present
+      if (!moduleOp.lookupSymbol(datatypeName)) {
+        OpBuilder::InsertionGuard guard(rewriter);
+        rewriter.setInsertionPointToStart(moduleOp.getBody());
+
+        rewriter.create<LLVM::GlobalOp>(op.getLoc(), llvmPtrType,
+                                        /*isConstant=*/true,
+                                        LLVM::Linkage::External, datatypeName,
+                                        /*value=*/Attribute(),
+                                        /*alignment=*/0,
+                                        /*addrSpace=*/0);
+      }
+
+      // Get all original op operands
+      auto operands = op.getOperands();
+
+      // Call the LLVM function with enzymexla.jit_call
+      SmallVector<Attribute> aliases;
+      aliases.push_back(stablehlo::OutputOperandAliasAttr::get(
+          context,
+          /*output_tuple_indices=*/std::vector<int64_t>{},
+          /*operand_index=*/0,
+          /*operand_tuple_indices=*/std::vector<int64_t>{}));
+
+      auto jitCall = rewriter.create<enzymexla::JITCallOp>(
+          op.getLoc(), op->getResultTypes(),
+          mlir::FlatSymbolRefAttr::get(context, wrapperFunctionName),
+          ValueRange{operands}, rewriter.getStringAttr(""),
+          /*operand_layouts=*/nullptr,
+          /*result_layouts=*/nullptr,
+          /*arg_attrs=*/nullptr,
+          /*res_attrs=*/nullptr,
+          /*output_operand_aliases=*/rewriter.getArrayAttr(aliases),
+          /*xla_side_effect_free=*/nullptr);
+
+      rewriter.replaceOp(op, jitCall);
+
+      return success();
+    } else {
+      return rewriter.notifyMatchFailure(op,
+                                         "Backend not supported: " + backend);
+    }
+  }
+};
+
+struct MPIIsendOpLowering : public OpRewritePattern<enzymexla::MPIIsendOp> {
+
+  std::string backend;
+  MPIIsendOpLowering(std::string backend, MLIRContext *context,
+                     PatternBenefit benefit = 1)
+      : OpRewritePattern(context, benefit), 
backend(backend) {} + + LogicalResult matchAndRewrite(enzymexla::MPIIsendOp op, + PatternRewriter &rewriter) const override { + auto context = op->getContext(); + + if (backend == "cpu") { + + auto moduleOp = op->getParentOfType(); + + auto llvmPtrType = LLVM::LLVMPointerType::get(context); + auto llvmVoidType = LLVM::LLVMVoidType::get(context); + + auto i32Type = IntegerType::get(context, 32); + + std::string mpiFunctionName = "MPI_Isend"; + + // get the MPI datatype + auto datatype = op.getDatatype(); + StringRef datatypeName = stringifyMPIDatatype(datatype); + + // For now we just hard code MPI_COMM_WORLD as the communicator. + // TODO make this more flexible + std::string communicatorName = "MPI_COMM_WORLD"; + + std::string wrapperFunctionName = + "enzymexla_wrapper_" + mpiFunctionName + "_" + datatypeName.str(); + + if (!moduleOp.lookupSymbol(wrapperFunctionName)) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + // Create the wrapper function decl + auto funcType = LLVM::LLVMFunctionType::get( + llvmVoidType, + {llvmPtrType, llvmPtrType, llvmPtrType, llvmPtrType, llvmPtrType}, + false); + + auto wrapperFunc = rewriter.create( + op.getLoc(), wrapperFunctionName, funcType); + + // Add function-level memory effects attribute + auto memoryEffectsAttr = rewriter.getArrayAttr( + {rewriter.getStringAttr("read"), rewriter.getStringAttr("write"), + rewriter.getStringAttr("allocate"), + rewriter.getStringAttr("free")}); + wrapperFunc->setAttr("enzymexla.memory_effects", memoryEffectsAttr); + + Block *entryBlock = wrapperFunc.addEntryBlock(rewriter); + rewriter.setInsertionPointToStart(entryBlock); + + // Add argument-level memory effects attribute to all arguments + for (unsigned i = 0; i < 5; ++i) { + wrapperFunc.setArgAttr(i, "enzymexla.memory_effects", + memoryEffectsAttr); + } + + // Get the function arguments + Value bufPtr = entryBlock->getArgument(0); + Value countPtr = entryBlock->getArgument(1); + Value destPtr = entryBlock->getArgument(2); + Value tagPtr = entryBlock->getArgument(3); + Value requestPtr = entryBlock->getArgument(4); + + // Load the count, dest, tag values + Value count = + rewriter.create(op.getLoc(), i32Type, countPtr); + + Value dest = + rewriter.create(op.getLoc(), i32Type, destPtr); + + Value tag = rewriter.create(op.getLoc(), i32Type, tagPtr); + + // Get the address of the datatype + // NOTE these symbols are not ABI-stable until MPI 5.0, but in practice, + // they are represented as w ord-size values (i.e. 
`int` or ptr) + Value addressOfDtype = rewriter.create( + op.getLoc(), llvmPtrType, datatypeName); + + // Get the address of the communicator + Value addressOfComm = rewriter.create( + op.getLoc(), llvmPtrType, communicatorName); + + // Call MPI_Isend + // int MPI_Isend(void* buf, int count, MPI_Datatype datatype, int + // dest, int tag, MPI_Comm comm, MPI_Request* request) + // TODO returns i32 error code which we're ignoring here + rewriter.create( + op.getLoc(), TypeRange{i32Type}, + SymbolRefAttr::get(context, mpiFunctionName), + ValueRange{bufPtr, count, addressOfDtype, dest, tag, addressOfComm, + requestPtr}); + + rewriter.create(op.getLoc(), ValueRange{}); + } + + // Insert MPI_Isend function declaration if not already present + if (!moduleOp.lookupSymbol(mpiFunctionName)) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + auto funcType = LLVM::LLVMFunctionType::get( + i32Type, + {llvmPtrType, i32Type, llvmPtrType, i32Type, i32Type, llvmPtrType, + llvmPtrType}, + false); + + rewriter.create(op.getLoc(), mpiFunctionName, + funcType, LLVM::Linkage::External); + } + + // Insert MPI_COMM_WORLD declaration if not already present + if (!moduleOp.lookupSymbol(communicatorName)) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + rewriter.create( + op.getLoc(), llvmPtrType, + /*isConstant=*/true, LLVM::Linkage::External, communicatorName, + /*value=*/Attribute(), + /*alignment=*/0, + /*addrSpace=*/0); + } + + // Insert datatype declaration if not already present + if (!moduleOp.lookupSymbol(datatypeName)) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + rewriter.create(op.getLoc(), llvmPtrType, + /*isConstant=*/true, + LLVM::Linkage::External, datatypeName, + /*value=*/Attribute(), + /*alignment=*/0, + /*addrSpace=*/0); + } + + // Get all orinigal op operands + auto opOperands = op.getOperands(); + + // Create a constant tensor to hold request + auto i64Type = rewriter.getI64Type(); + auto tensorType = RankedTensorType::get({}, i64Type); + auto constantAttr = + DenseIntElementsAttr::get(tensorType, ArrayRef{-1}); + Value constantTensor = rewriter.create( + op.getLoc(), tensorType, constantAttr); + + // Combine all operands + SmallVector jitCallOperands(opOperands.begin(), opOperands.end()); + jitCallOperands.push_back(constantTensor); + + // Add request to output operand aliases + SmallVector aliases; + aliases.push_back(stablehlo::OutputOperandAliasAttr::get( + context, + /*output_operand_aliases=*/std::vector{}, + /*operand_index=*/4, + /*operand_tuple_indices=*/std::vector{})); + + // Call the LLVM function with enzymexla.jit_call + auto jitCall = rewriter.create( + op.getLoc(), op->getResultTypes(), + mlir::FlatSymbolRefAttr::get(context, wrapperFunctionName), + jitCallOperands, rewriter.getStringAttr(""), + /*operand_layouts=*/nullptr, + /*result_layouts=*/nullptr, + /*arg_attrs=*/nullptr, + /*res_attrs=*/nullptr, + /*output_operand_aliases=*/rewriter.getArrayAttr(aliases), + /*xla_side_effect_free=*/nullptr); + + rewriter.replaceOp(op, jitCall); + + return success(); + } else { + return rewriter.notifyMatchFailure(op, + "Backend not supported: " + backend); + } + } +}; + +struct MPIIrecvOpLowering : public OpRewritePattern { + + std::string backend; + MPIIrecvOpLowering(std::string backend, MLIRContext *context, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), backend(backend) {} + + 
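// mpi.irecv lowers to a jit_call whose results alias its operands:
+  // result 0 (the buffer) aliases operand 0, and result 1 (the i64 request)
+  // aliases the constant request tensor appended as operand 4.
+  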
LogicalResult matchAndRewrite(enzymexla::MPIIrecvOp op, + PatternRewriter &rewriter) const override { + auto context = op->getContext(); + + if (backend == "cpu") { + + auto moduleOp = op->getParentOfType(); + + auto llvmPtrType = LLVM::LLVMPointerType::get(context); + auto llvmVoidType = LLVM::LLVMVoidType::get(context); + + auto i32Type = IntegerType::get(context, 32); + + std::string mpiFunctionName = "MPI_Irecv"; + + // get the MPI datatype + auto datatype = op.getDatatype(); + StringRef datatypeName = stringifyMPIDatatype(datatype); + + // For now we just hard code MPI_COMM_WORLD as the communicator. + // TODO make this more flexible + std::string communicatorName = "MPI_COMM_WORLD"; + + // Generate the enzymexla_wrapper LLVM function body + std::string wrapperFunctionName = + "enzymexla_wrapper_" + mpiFunctionName + "_" + datatypeName.str(); + + if (!moduleOp.lookupSymbol(wrapperFunctionName)) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + // Create the wrapper function decl + auto funcType = LLVM::LLVMFunctionType::get( + llvmVoidType, + {llvmPtrType, llvmPtrType, llvmPtrType, llvmPtrType, llvmPtrType}, + false); + + auto wrapperFunc = rewriter.create( + op.getLoc(), wrapperFunctionName, funcType); + + // Add function-level memory effects attribute + auto memoryEffectsAttr = rewriter.getArrayAttr( + {rewriter.getStringAttr("read"), rewriter.getStringAttr("write"), + rewriter.getStringAttr("allocate"), + rewriter.getStringAttr("free")}); + wrapperFunc->setAttr("enzymexla.memory_effects", memoryEffectsAttr); + + Block *entryBlock = wrapperFunc.addEntryBlock(rewriter); + rewriter.setInsertionPointToStart(entryBlock); + + // Add argument-level memory effects attribute to all arguments + for (unsigned i = 0; i < 5; ++i) { + wrapperFunc.setArgAttr(i, "enzymexla.memory_effects", + memoryEffectsAttr); + } + + // Get the function arguments + Value bufPtr = entryBlock->getArgument(0); + Value countPtr = entryBlock->getArgument(1); + Value srcPtr = entryBlock->getArgument(2); + Value tagPtr = entryBlock->getArgument(3); + Value requestPtr = entryBlock->getArgument(4); + + // Load the count, src, tag values + Value count = + rewriter.create(op.getLoc(), i32Type, countPtr); + + Value src = rewriter.create(op.getLoc(), i32Type, srcPtr); + + Value tag = rewriter.create(op.getLoc(), i32Type, tagPtr); + + // Get the address of the datatype + // NOTE these symbols are not ABI-stable until MPI 5.0, but in practice, + // they are represented as w ord-size values (i.e. 
`int` or ptr) + Value addressOfDtype = rewriter.create( + op.getLoc(), llvmPtrType, datatypeName); + + // Get the address of the communicator + Value addressOfComm = rewriter.create( + op.getLoc(), llvmPtrType, communicatorName); + + // Call MPI_Irecv + // int MPI_Irecv(void* buf, int count, MPI_Datatype datatype, int + // source, int tag, MPI_Comm comm, MPI_Request* request) + // TODO returns i32 error code which we're ignoring here + rewriter.create( + op.getLoc(), TypeRange{i32Type}, + SymbolRefAttr::get(context, mpiFunctionName), + ValueRange{bufPtr, count, addressOfDtype, src, tag, addressOfComm, + requestPtr}); + + rewriter.create(op.getLoc(), ValueRange{}); + } + + // Insert MPI_Irecv function declaration if not already present + if (!moduleOp.lookupSymbol(mpiFunctionName)) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + auto funcType = LLVM::LLVMFunctionType::get( + i32Type, + {llvmPtrType, i32Type, llvmPtrType, i32Type, i32Type, llvmPtrType, + llvmPtrType}, + false); + + rewriter.create(op.getLoc(), mpiFunctionName, + funcType, LLVM::Linkage::External); + } + + // Insert MPI_COMM_WORLD declaration if not already present + if (!moduleOp.lookupSymbol(communicatorName)) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + rewriter.create( + op.getLoc(), llvmPtrType, + /*isConstant=*/true, LLVM::Linkage::External, communicatorName, + /*value=*/Attribute(), + /*alignment=*/0, + /*addrSpace=*/0); + } + + // Insert datatype declaration if not already present + if (!moduleOp.lookupSymbol(datatypeName)) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + rewriter.create(op.getLoc(), llvmPtrType, + /*isConstant=*/true, + LLVM::Linkage::External, datatypeName, + /*value=*/Attribute(), + /*alignment=*/0, + /*addrSpace=*/0); + } + + // Get all orinigal op operands + auto opOperands = op.getOperands(); + + // Create a constant tensor to hold request + auto i64Type = rewriter.getI64Type(); + auto tensorType = RankedTensorType::get({}, i64Type); + auto constantAttr = + DenseIntElementsAttr::get(tensorType, ArrayRef{-1}); + Value constantTensor = rewriter.create( + op.getLoc(), tensorType, constantAttr); + + // Combine all operands + SmallVector jitCallOperands(opOperands.begin(), opOperands.end()); + jitCallOperands.push_back(constantTensor); + + // Add buffer to output operand aliases + SmallVector aliases; + aliases.push_back(stablehlo::OutputOperandAliasAttr::get( + context, + /*output_operand_aliases=*/std::vector{0}, + /*operand_index=*/0, + /*operand_tuple_indices=*/std::vector{})); + + // Add request to output operand aliases + aliases.push_back(stablehlo::OutputOperandAliasAttr::get( + context, + /*output_operand_aliases=*/std::vector{1}, + /*operand_index=*/4, + /*operand_tuple_indices=*/std::vector{})); + + // Call the LLVM function with enzymexla.jit_call + auto jitCall = rewriter.create( + op.getLoc(), op->getResultTypes(), + mlir::FlatSymbolRefAttr::get(context, wrapperFunctionName), + ValueRange{jitCallOperands}, rewriter.getStringAttr(""), + /*operand_layouts=*/nullptr, + /*result_layouts=*/nullptr, + /*arg_attrs=*/nullptr, + /*res_attrs=*/nullptr, + /*output_operand_aliases=*/rewriter.getArrayAttr(aliases), + /*xla_side_effect_free=*/nullptr); + + rewriter.replaceOp(op, jitCall); + + return success(); + } else { + return rewriter.notifyMatchFailure(op, + "Backend not supported: " + backend); + } + } +}; + +struct 
MPIWaitOpLowering : public OpRewritePattern<enzymexla::MPIWaitOp> {
+
+  std::string backend;
+  MPIWaitOpLowering(std::string backend, MLIRContext *context,
+                    PatternBenefit benefit = 1)
+      : OpRewritePattern(context, benefit), backend(backend) {}
+
+  LogicalResult matchAndRewrite(enzymexla::MPIWaitOp op,
+                                PatternRewriter &rewriter) const override {
+    auto context = op->getContext();
+
+    if (backend == "cpu") {
+
+      auto moduleOp = op->getParentOfType<ModuleOp>();
+
+      auto llvmPtrType = LLVM::LLVMPointerType::get(context);
+      auto llvmVoidType = LLVM::LLVMVoidType::get(context);
+
+      auto i32Type = IntegerType::get(context, 32);
+
+      std::string mpiFunctionName = "MPI_Wait";
+
+      // Generate the enzymexla_wrapper LLVM function body
+      std::string wrapperFunctionName = "enzymexla_wrapper_" + mpiFunctionName;
+      {
+        OpBuilder::InsertionGuard guard(rewriter);
+        rewriter.setInsertionPointToStart(moduleOp.getBody());
+
+        // Create the wrapper function decl
+        auto funcType =
+            LLVM::LLVMFunctionType::get(llvmVoidType, {llvmPtrType}, false);
+
+        auto wrapperFunc = rewriter.create<LLVM::LLVMFuncOp>(
+            op.getLoc(), wrapperFunctionName, funcType);
+
+        // Add function-level memory effects attribute
+        auto memoryEffectsAttr = rewriter.getArrayAttr(
+            {rewriter.getStringAttr("read"), rewriter.getStringAttr("write"),
+             rewriter.getStringAttr("allocate"),
+             rewriter.getStringAttr("free")});
+        wrapperFunc->setAttr("enzymexla.memory_effects", memoryEffectsAttr);
+
+        Block *entryBlock = wrapperFunc.addEntryBlock(rewriter);
+        rewriter.setInsertionPointToStart(entryBlock);
+
+        // Add argument-level memory effects attribute to all arguments
+        wrapperFunc.setArgAttr(0, "enzymexla.memory_effects",
+                               memoryEffectsAttr);
+
+        // Get the function argument
+        Value requestPtr = entryBlock->getArgument(0);
+
+        // Allocate a 1x!llvm.array<6 x i32> that we use in place of MPI_Status
+        // Size of status is implementation dependent; this should cover the max
+        Value numElements = rewriter.create<LLVM::ConstantOp>(
+            op.getLoc(), i32Type, rewriter.getI32IntegerAttr(1));
+
+        auto arrayType = LLVM::LLVMArrayType::get(i32Type, 6);
+
+        Value statusPtr = rewriter.create<LLVM::AllocaOp>(
+            op.getLoc(), llvmPtrType, arrayType, numElements);
+
+        // Call MPI_Wait
+        // int MPI_Wait(MPI_Request* request, MPI_Status* status)
+        // TODO returns i32 error code which we're ignoring here
+        rewriter.create<LLVM::CallOp>(
+            op.getLoc(), TypeRange{i32Type},
+            SymbolRefAttr::get(context, mpiFunctionName),
+            ValueRange{requestPtr, statusPtr});
+
+        rewriter.create<LLVM::ReturnOp>(op.getLoc(), ValueRange{});
+      }
+
+      // Insert MPI_Wait function declaration if not already present
+      if (!moduleOp.lookupSymbol(mpiFunctionName)) {
+        OpBuilder::InsertionGuard guard(rewriter);
+        rewriter.setInsertionPointToStart(moduleOp.getBody());
+
+        auto funcType = LLVM::LLVMFunctionType::get(
+            i32Type, {llvmPtrType, llvmPtrType}, false);
+
+        rewriter.create<LLVM::LLVMFuncOp>(op.getLoc(), mpiFunctionName,
+                                          funcType, LLVM::Linkage::External);
+      }
+
+      // Get the request operand
+      auto request = op.getRequest();
+
+      // Call the LLVM function with enzymexla.jit_call
+      rewriter.create<enzymexla::JITCallOp>(
+          op.getLoc(), TypeRange{},
+          mlir::FlatSymbolRefAttr::get(context, wrapperFunctionName),
+          ValueRange{request}, rewriter.getStringAttr(""),
+          /*operand_layouts=*/nullptr,
+          /*result_layouts=*/nullptr,
+          /*arg_attrs=*/nullptr,
+          /*res_attrs=*/nullptr,
+          /*output_operand_aliases=*/nullptr,
+          /*xla_side_effect_free=*/nullptr);
+
+      rewriter.eraseOp(op);
+
+      return success();
+    } else {
+      return rewriter.notifyMatchFailure(op,
+                                         "Backend not supported: " + backend);
+    }
+  }
+};
+
+struct MPIAllreduceOpLowering
+    : public OpRewritePattern<enzymexla::MPIAllreduceOp> {
+
+  std::string backend;
+  MPIAllreduceOpLowering(std::string backend, MLIRContext *context,
+                         PatternBenefit benefit = 1)
+      : OpRewritePattern(context, benefit), backend(backend) {}
+
+  LogicalResult matchAndRewrite(enzymexla::MPIAllreduceOp op,
+                                PatternRewriter &rewriter) const override {
+    auto context = op->getContext();
+
+    if (backend == "cpu") {
+
+      auto moduleOp = op->getParentOfType<ModuleOp>();
+
+      auto llvmPtrType = LLVM::LLVMPointerType::get(context);
+      auto llvmVoidType = LLVM::LLVMVoidType::get(context);
+
+      auto i32Type = IntegerType::get(context, 32);
+
+      std::string mpiFunctionName = "MPI_Allreduce";
+
+      // get the MPI datatype
+      auto datatype = op.getDatatype();
+      StringRef datatypeName = stringifyMPIDatatype(datatype);
+
+      // get the MPI Op type
+      StringRef mpiOpName = stringifyMPIOp(op.getOp());
+
+      // TODO For now we just hard code MPI_COMM_WORLD as the communicator.
+      std::string communicatorName = "MPI_COMM_WORLD";
+
+      // Generate the enzymexla_wrapper LLVM function body
+      std::string wrapperFunctionName = "enzymexla_wrapper_" + mpiFunctionName +
+                                        "_" + mpiOpName.str() + "_" +
+                                        datatypeName.str();
+
+      if (!moduleOp.lookupSymbol(wrapperFunctionName)) {
+        OpBuilder::InsertionGuard guard(rewriter);
+        rewriter.setInsertionPointToStart(moduleOp.getBody());
+
+        // Create the wrapper function decl
+        auto funcType = LLVM::LLVMFunctionType::get(
+            llvmVoidType, {llvmPtrType, llvmPtrType, llvmPtrType}, false);
+
+        auto wrapperFunc = rewriter.create<LLVM::LLVMFuncOp>(
+            op.getLoc(), wrapperFunctionName, funcType);
+
+        // Add function-level memory effects attribute
+        auto memoryEffectsAttr = rewriter.getArrayAttr(
+            {rewriter.getStringAttr("read"), rewriter.getStringAttr("write"),
+             rewriter.getStringAttr("allocate"),
+             rewriter.getStringAttr("free")});
+        wrapperFunc->setAttr("enzymexla.memory_effects", memoryEffectsAttr);
+
+        Block *entryBlock = wrapperFunc.addEntryBlock(rewriter);
+        rewriter.setInsertionPointToStart(entryBlock);
+
+        // Add argument-level memory effects attribute to all arguments
+        for (unsigned i = 0; i < 3; ++i) {
+          wrapperFunc.setArgAttr(i, "enzymexla.memory_effects",
+                                 memoryEffectsAttr);
+        }
+
+        // Get the function arguments
+        Value sendbufPtr = entryBlock->getArgument(0);
+        Value inbufPtr = entryBlock->getArgument(1);
+        Value countPtr = entryBlock->getArgument(2);
+
+        // Load the count value
+        Value count =
+            rewriter.create<LLVM::LoadOp>(op.getLoc(), i32Type, countPtr);
+
+        // Get the address of the datatype
+        // NOTE these symbols are not ABI-stable until MPI 5.0, but in practice,
+        // they are represented as word-size values (i.e. `int` or ptr)
+        Value addressOfDtype = rewriter.create<LLVM::AddressOfOp>(
+            op.getLoc(), llvmPtrType, datatypeName);
+
+        // Get the address of the communicator
+        Value addressOfComm = rewriter.create<LLVM::AddressOfOp>(
+            op.getLoc(), llvmPtrType, communicatorName);
+
+        // Get the address of the MPI Op
+        Value addressOfMPIOp = rewriter.create<LLVM::AddressOfOp>(
+            op.getLoc(), llvmPtrType, mpiOpName);
+
+        // Call MPI_Allreduce
+        // int MPI_Allreduce(const void* sendbuf, void* recvbuf, int count,
+        // MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
+        // TODO returns i32 error code which we're ignoring here
+        rewriter.create<LLVM::CallOp>(
+            op.getLoc(), TypeRange{i32Type},
+            SymbolRefAttr::get(context, mpiFunctionName),
+            ValueRange{sendbufPtr, inbufPtr, count, addressOfDtype,
+                       addressOfMPIOp, addressOfComm});
+
+        rewriter.create<LLVM::ReturnOp>(op.getLoc(), ValueRange{});
+      }
+
+      // Insert MPI_Allreduce function declaration if not already present
+      if (!moduleOp.lookupSymbol(mpiFunctionName)) {
+        OpBuilder::InsertionGuard guard(rewriter);
+        rewriter.setInsertionPointToStart(moduleOp.getBody());
+
+        auto funcType =
+            LLVM::LLVMFunctionType::get(i32Type,
+                                        {llvmPtrType, llvmPtrType, i32Type,
+                                         llvmPtrType, llvmPtrType, llvmPtrType},
+                                        false);
+
+        rewriter.create<LLVM::LLVMFuncOp>(op.getLoc(), mpiFunctionName,
+                                          funcType, LLVM::Linkage::External);
+      }
+
+      // Insert MPI_COMM_WORLD declaration if not already present
+      if (!moduleOp.lookupSymbol(communicatorName)) {
+        OpBuilder::InsertionGuard guard(rewriter);
+        rewriter.setInsertionPointToStart(moduleOp.getBody());
+
+        rewriter.create<LLVM::GlobalOp>(
+            op.getLoc(), llvmPtrType,
+            /*isConstant=*/true, LLVM::Linkage::External, communicatorName,
+            /*value=*/Attribute(),
+            /*alignment=*/0,
+            /*addrSpace=*/0);
+      }
+
+      // Insert datatype declaration if not already present
+      if (!moduleOp.lookupSymbol(datatypeName)) {
+        OpBuilder::InsertionGuard guard(rewriter);
+        rewriter.setInsertionPointToStart(moduleOp.getBody());
+
+        rewriter.create<LLVM::GlobalOp>(op.getLoc(), llvmPtrType,
+                                        /*isConstant=*/true,
+                                        LLVM::Linkage::External, datatypeName,
+                                        /*value=*/Attribute(),
+                                        /*alignment=*/0,
+                                        /*addrSpace=*/0);
+      }
+
+      // Insert MPI_Op declaration if not already present
+      if (!moduleOp.lookupSymbol(mpiOpName)) {
+        OpBuilder::InsertionGuard guard(rewriter);
+        rewriter.setInsertionPointToStart(moduleOp.getBody());
+
+        rewriter.create<LLVM::GlobalOp>(op.getLoc(), llvmPtrType,
+                                        /*isConstant=*/true,
+                                        LLVM::Linkage::External, mpiOpName,
+                                        /*value=*/Attribute(),
+                                        /*alignment=*/0,
+                                        /*addrSpace=*/0);
+      }
+
+      // Get all original op operands
+      auto operands = op.getOperands();
+
+      // Add inbuf to output operand aliases
+      SmallVector<Attribute> aliases;
+      aliases.push_back(stablehlo::OutputOperandAliasAttr::get(
+          context,
+          /*output_tuple_indices=*/std::vector<int64_t>{},
+          /*operand_index=*/1,
+          /*operand_tuple_indices=*/std::vector<int64_t>{}));
+
+      // Call the LLVM function with enzymexla.jit_call
+      auto jitCall = rewriter.create<enzymexla::JITCallOp>(
+          op.getLoc(), op->getResultTypes(),
+          mlir::FlatSymbolRefAttr::get(context, wrapperFunctionName),
+          ValueRange{operands}, rewriter.getStringAttr(""),
+          /*operand_layouts=*/nullptr,
+          /*result_layouts=*/nullptr,
+          /*arg_attrs=*/nullptr,
+          /*res_attrs=*/nullptr,
+          /*output_operand_aliases=*/rewriter.getArrayAttr(aliases),
+          /*xla_side_effect_free=*/nullptr);
+
+      rewriter.replaceOp(op, jitCall);
+
+      return success();
+    } else {
+      return rewriter.notifyMatchFailure(op,
+                                         "Backend not supported: " + backend);
+    }
+  }
+};
+
+struct LowerEnzymeXLAMPIPass
+    : public enzyme::impl::LowerEnzymeXLAMPIPassBase<LowerEnzymeXLAMPIPass> {
+  using Base::Base;
+
+  void runOnOperation() override {
+    auto context = getOperation()->getContext();
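+    // Register one lowering pattern per enzymexla.mpi.* op; each rewrites
+    // the op into an enzymexla.jit_call of an LLVM wrapper function built
+    // into the module.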
+
+struct LowerEnzymeXLAMPIPass
+    : public enzyme::impl::LowerEnzymeXLAMPIPassBase<LowerEnzymeXLAMPIPass> {
+  using Base::Base;
+
+  void runOnOperation() override {
+    auto context = getOperation()->getContext();
+    RewritePatternSet patterns(context);
+
+    patterns.add<MPICommRankOpLowering>(backend, context);
+    patterns.add<MPICommSizeOpLowering>(backend, context);
+    patterns.add<MPIBarrierOpLowering>(backend, context);
+    patterns.add<MPISendOpLowering>(backend, context);
+    patterns.add<MPIRecvOpLowering>(backend, context);
+    patterns.add<MPIISendOpLowering>(backend, context);
+    patterns.add<MPIIRecvOpLowering>(backend, context);
+    patterns.add<MPIWaitOpLowering>(backend, context);
+    patterns.add<MPIAllreduceOpLowering>(backend, context);
+
+    GreedyRewriteConfig config;
+    if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
+                                            config))) {
+      signalPassFailure();
+    }
+  }
+};
diff --git a/src/enzyme_ad/jax/Passes/Passes.td b/src/enzyme_ad/jax/Passes/Passes.td
index 6f135eec29..178e4f6fde 100644
--- a/src/enzyme_ad/jax/Passes/Passes.td
+++ b/src/enzyme_ad/jax/Passes/Passes.td
@@ -460,6 +460,22 @@ def LowerEnzymeXLALinalgPass : Pass<"lower-enzymexla-linalg"> {
   ];
 }
 
+def LowerEnzymeXLAMPIPass : Pass<"lower-enzymexla-mpi"> {
+  let summary = "Lower MPI Ops to LLVM";
+  let dependentDialects = [
+    "enzymexla::EnzymeXLADialect",
+    "LLVM::LLVMDialect",
+  ];
+  let options = [
+    Option<
+      /*C++ variable name=*/"backend",
+      /*CLI argument=*/"backend",
+      /*type=*/"std::string",
+      /*default=*/"\"cpu\"",
+      /*description=*/"HW backend">,
+  ];
+}
+
 def LowerEnzymeXLALapackPass : Pass<"lower-enzymexla-lapack"> {
   let summary = "Lower enzymexla.lapack ops to stablehlo";
   let dependentDialects = [
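Usage note: the `backend` option is set through the pass-pipeline string; the
lit tests below all invoke the pass as, e.g.,

  enzymexlamlir-opt --pass-pipeline="builtin.module(lower-enzymexla-mpi{backend=cpu})" input.mlir

(`input.mlir` stands in for any of the test files.)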
"allocate", "free"]} { +// CPU-NEXT: %c = stablehlo.constant dense<0> : tensor +// CPU-NEXT: %c_0 = stablehlo.constant dense<1> : tensor +// CPU-NEXT: %0 = enzymexla.jit_call @enzymexla_wrapper_MPI_Allreduce_MPI_LAND_MPI_INT (%arg0, %c, %c_0) {output_operand_aliases = [#stablehlo.output_operand_alias]} : (tensor, tensor, tensor) -> tensor +// CPU-NEXT: return %0 : tensor +// CPU-NEXT: } +// CPU-NEXT: } diff --git a/test/lit_tests/mpi/barrier.mlir b/test/lit_tests/mpi/barrier.mlir new file mode 100644 index 0000000000..705342b132 --- /dev/null +++ b/test/lit_tests/mpi/barrier.mlir @@ -0,0 +1,22 @@ +// RUN: enzymexlamlir-opt --pass-pipeline="builtin.module(lower-enzymexla-mpi{backend=cpu})" %s | FileCheck %s --check-prefix=CPU + +module { + func.func @main() attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} { + enzymexla.mpi.barrier + return + } +} + +// CPU: module { +// CPU-NEXT: llvm.mlir.global external constant @MPI_COMM_WORLD() {addr_space = 0 : i32} : !llvm.ptr +// CPU-NEXT: llvm.func @MPI_Barrier(!llvm.ptr) -> i32 +// CPU-NEXT: llvm.func @enzymexla_wrapper_MPI_Barrier() attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} { +// CPU-NEXT: %0 = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr +// CPU-NEXT: %1 = llvm.call @MPI_Barrier(%0) : (!llvm.ptr) -> i32 +// CPU-NEXT: llvm.return +// CPU-NEXT: } +// CPU-NEXT: func.func @main() attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} { +// CPU-NEXT: enzymexla.jit_call @enzymexla_wrapper_MPI_Barrier () : () -> () +// CPU-NEXT: return +// CPU-NEXT: } +// CPU-NEXT: } diff --git a/test/lit_tests/mpi/comm_rank.mlir b/test/lit_tests/mpi/comm_rank.mlir new file mode 100644 index 0000000000..ad4636f314 --- /dev/null +++ b/test/lit_tests/mpi/comm_rank.mlir @@ -0,0 +1,23 @@ +// RUN: enzymexlamlir-opt --pass-pipeline="builtin.module(lower-enzymexla-mpi{backend=cpu})" %s | FileCheck %s --check-prefix=CPU + +module { + func.func @main() -> tensor attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} { + %0 = enzymexla.mpi.comm_rank : tensor + return %0 : tensor + } +} + +// CPU: module { +// CPU-NEXT: llvm.mlir.global external constant @MPI_COMM_WORLD() {addr_space = 0 : i32} : !llvm.ptr +// CPU-NEXT: llvm.func @MPI_Comm_rank(!llvm.ptr, !llvm.ptr) -> i32 +// CPU-NEXT: llvm.func @enzymexla_wrapper_MPI_Comm_rank(%arg0: !llvm.ptr {enzymexla.memory_effects = ["read", "write", "allocate", "free"]}) attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} { +// CPU-NEXT: %0 = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr +// CPU-NEXT: %1 = llvm.call @MPI_Comm_rank(%0, %arg0) : (!llvm.ptr, !llvm.ptr) -> i32 +// CPU-NEXT: llvm.return +// CPU-NEXT: } +// CPU-NEXT: func.func @main() -> tensor attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} { +// CPU-NEXT: %c = stablehlo.constant dense<-1> : tensor +// CPU-NEXT: %0 = enzymexla.jit_call @enzymexla_wrapper_MPI_Comm_rank (%c) {output_operand_aliases = [#stablehlo.output_operand_alias]} : (tensor) -> tensor +// CPU-NEXT: return %0 : tensor +// CPU-NEXT: } +// CPU-NEXT: } diff --git a/test/lit_tests/mpi/comm_size.mlir b/test/lit_tests/mpi/comm_size.mlir new file mode 100644 index 0000000000..7733ae3d83 --- /dev/null +++ b/test/lit_tests/mpi/comm_size.mlir @@ -0,0 +1,23 @@ +// RUN: enzymexlamlir-opt --pass-pipeline="builtin.module(lower-enzymexla-mpi{backend=cpu})" %s | FileCheck %s --check-prefix=CPU + +module { + func.func @main() -> tensor attributes 
diff --git a/test/lit_tests/mpi/comm_size.mlir b/test/lit_tests/mpi/comm_size.mlir
new file mode 100644
index 0000000000..7733ae3d83
--- /dev/null
+++ b/test/lit_tests/mpi/comm_size.mlir
@@ -0,0 +1,23 @@
+// RUN: enzymexlamlir-opt --pass-pipeline="builtin.module(lower-enzymexla-mpi{backend=cpu})" %s | FileCheck %s --check-prefix=CPU
+
+module {
+  func.func @main() -> tensor<i32> attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
+    %0 = enzymexla.mpi.comm_size : tensor<i32>
+    return %0 : tensor<i32>
+  }
+}
+
+// CPU: module {
+// CPU-NEXT: llvm.mlir.global external constant @MPI_COMM_WORLD() {addr_space = 0 : i32} : !llvm.ptr
+// CPU-NEXT: llvm.func @MPI_Comm_size(!llvm.ptr, !llvm.ptr) -> i32
+// CPU-NEXT: llvm.func @enzymexla_wrapper_MPI_Comm_size(%arg0: !llvm.ptr {enzymexla.memory_effects = ["read", "write", "allocate", "free"]}) attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
+// CPU-NEXT: %0 = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr
+// CPU-NEXT: %1 = llvm.call @MPI_Comm_size(%0, %arg0) : (!llvm.ptr, !llvm.ptr) -> i32
+// CPU-NEXT: llvm.return
+// CPU-NEXT: }
+// CPU-NEXT: func.func @main() -> tensor<i32> attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
+// CPU-NEXT: %c = stablehlo.constant dense<-1> : tensor<i32>
+// CPU-NEXT: %0 = enzymexla.jit_call @enzymexla_wrapper_MPI_Comm_size (%c) {output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [], operand_index = 0, operand_tuple_indices = []>]} : (tensor<i32>) -> tensor<i32>
+// CPU-NEXT: return %0 : tensor<i32>
+// CPU-NEXT: }
+// CPU-NEXT: }
diff --git a/test/lit_tests/mpi/irecv-wait.mlir b/test/lit_tests/mpi/irecv-wait.mlir
new file mode 100644
index 0000000000..62f1371e92
--- /dev/null
+++ b/test/lit_tests/mpi/irecv-wait.mlir
@@ -0,0 +1,48 @@
+// RUN: enzymexlamlir-opt --pass-pipeline="builtin.module(lower-enzymexla-mpi{backend=cpu})" %s | FileCheck %s --check-prefix=CPU
+
+module {
+  func.func @main(%arg0: tensor<5xf64> {enzymexla.memory_effects = ["read", "write", "allocate", "free"], tf.aliasing_output = 0 : i32}) -> tensor<5xf64> attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
+    %0 = stablehlo.transpose %arg0, dims = [0] : (tensor<5xf64>) -> tensor<5xf64>
+    %c = stablehlo.constant dense<1> : tensor<i32>
+    %c_0 = stablehlo.constant dense<42> : tensor<i32>
+    %c_1 = stablehlo.constant dense<5> : tensor<i32>
+    %c_2 = stablehlo.constant dense<-1> : tensor<i64>
+    %outbuf, %request = enzymexla.mpi.irecv(%0, %c_1, %c, %c_0) {datatype = #enzymexla.datatype<MPI_INT>} : (tensor<5xf64>, tensor<i32>, tensor<i32>, tensor<i32>) -> (tensor<5xf64>, tensor<i64>)
+    enzymexla.mpi.wait(%request) : tensor<i64>
+    %1 = stablehlo.transpose %outbuf, dims = [0] : (tensor<5xf64>) -> tensor<5xf64>
+    return %1 : tensor<5xf64>
+  }
+}
+
+// CPU: module {
+// CPU-NEXT: llvm.mlir.global external constant @MPI_INT() {addr_space = 0 : i32} : !llvm.ptr
+// CPU-NEXT: llvm.mlir.global external constant @MPI_COMM_WORLD() {addr_space = 0 : i32} : !llvm.ptr
+// CPU-NEXT: llvm.func @MPI_Irecv(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
+// CPU-NEXT: llvm.func @enzymexla_wrapper_MPI_Irecv_MPI_INT(%arg0: !llvm.ptr {enzymexla.memory_effects = ["read", "write", "allocate", "free"]}, %arg1: !llvm.ptr {enzymexla.memory_effects = ["read", "write", "allocate", "free"]}, %arg2: !llvm.ptr {enzymexla.memory_effects = ["read", "write", "allocate", "free"]}, %arg3: !llvm.ptr {enzymexla.memory_effects = ["read", "write", "allocate", "free"]}, %arg4: !llvm.ptr {enzymexla.memory_effects = ["read", "write", "allocate", "free"]}) attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
+// CPU-NEXT: %0 = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr
+// CPU-NEXT: %1 = llvm.mlir.addressof @MPI_INT : !llvm.ptr
+// CPU-NEXT: %2 = llvm.load %arg1 : !llvm.ptr -> i32
+// CPU-NEXT: %3 = llvm.load %arg2 : !llvm.ptr -> i32
+// CPU-NEXT: %4 = llvm.load %arg3 : !llvm.ptr -> i32
+// CPU-NEXT: %5 = llvm.call @MPI_Irecv(%arg0, %2, %1, %3, %4, %0, %arg4) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
+// CPU-NEXT: llvm.return
+// CPU-NEXT: }
+// CPU-NEXT: llvm.func @MPI_Wait(!llvm.ptr, !llvm.ptr) -> i32
+// CPU-NEXT: llvm.func @enzymexla_wrapper_MPI_Wait(%arg0: !llvm.ptr {enzymexla.memory_effects = ["read", "write", "allocate", "free"]}) attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
+// CPU-NEXT: %c1_i32 = arith.constant 1 : i32
+// CPU-NEXT: %0 = llvm.alloca %c1_i32 x !llvm.array<6 x i32> : (i32) -> !llvm.ptr
+// CPU-NEXT: %1 = llvm.call @MPI_Wait(%arg0, %0) : (!llvm.ptr, !llvm.ptr) -> i32
+// CPU-NEXT: llvm.return
+// CPU-NEXT: }
+// CPU-NEXT: func.func @main(%arg0: tensor<5xf64> {enzymexla.memory_effects = ["read", "write", "allocate", "free"], tf.aliasing_output = 0 : i32}) -> tensor<5xf64> attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
+// CPU-NEXT: %c = stablehlo.constant dense<-1> : tensor<i64>
+// CPU-NEXT: %c_0 = stablehlo.constant dense<5> : tensor<i32>
+// CPU-NEXT: %c_1 = stablehlo.constant dense<42> : tensor<i32>
+// CPU-NEXT: %c_2 = stablehlo.constant dense<1> : tensor<i32>
+// CPU-NEXT: %0 = stablehlo.transpose %arg0, dims = [0] : (tensor<5xf64>) -> tensor<5xf64>
+// CPU-NEXT: %1:2 = enzymexla.jit_call @enzymexla_wrapper_MPI_Irecv_MPI_INT (%0, %c_0, %c_2, %c_1, %c) {output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>, #stablehlo.output_operand_alias<output_tuple_indices = [1], operand_index = 4, operand_tuple_indices = []>]} : (tensor<5xf64>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i64>) -> (tensor<5xf64>, tensor<i64>)
+// CPU-NEXT: enzymexla.jit_call @enzymexla_wrapper_MPI_Wait (%1#1) : (tensor<i64>) -> ()
+// CPU-NEXT: %2 = stablehlo.transpose %1#0, dims = [0] : (tensor<5xf64>) -> tensor<5xf64>
+// CPU-NEXT: return %2 : tensor<5xf64>
+// CPU-NEXT: }
+// CPU-NEXT: }
diff --git a/test/lit_tests/mpi/irecv.mlir b/test/lit_tests/mpi/irecv.mlir
new file mode 100644
index 0000000000..37a9b6adbb
--- /dev/null
+++ b/test/lit_tests/mpi/irecv.mlir
@@ -0,0 +1,40 @@
+// RUN: enzymexlamlir-opt --pass-pipeline="builtin.module(lower-enzymexla-mpi{backend=cpu})" %s | FileCheck %s --check-prefix=CPU
+
+module {
+  func.func @main(%arg0: tensor<5xf64> {enzymexla.memory_effects = ["read", "write", "allocate", "free"], tf.aliasing_output = 0 : i32}) -> tensor<5xf64> attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
+    %0 = stablehlo.transpose %arg0, dims = [0] : (tensor<5xf64>) -> tensor<5xf64>
+    %c = stablehlo.constant dense<1> : tensor<i32>
+    %c_0 = stablehlo.constant dense<42> : tensor<i32>
+    %c_1 = stablehlo.constant dense<5> : tensor<i32>
+    %c_2 = stablehlo.constant dense<-1> : tensor<i64>
+    %outbuf, %request = enzymexla.mpi.irecv(%0, %c_1, %c, %c_0) {datatype = #enzymexla.datatype<MPI_INT>} : (tensor<5xf64>, tensor<i32>, tensor<i32>, tensor<i32>) -> (tensor<5xf64>, tensor<i64>)
+    // enzymexla.mpi.wait(%request) : tensor<i64>
+    %1 = stablehlo.transpose %outbuf, dims = [0] : (tensor<5xf64>) -> tensor<5xf64>
+    return %1 : tensor<5xf64>
+  }
+}
+
+// CPU: module {
+// CPU-NEXT: llvm.mlir.global external constant @MPI_INT() {addr_space = 0 : i32} : !llvm.ptr
+// CPU-NEXT: llvm.mlir.global external constant @MPI_COMM_WORLD() {addr_space = 0 : i32} : !llvm.ptr
+// CPU-NEXT: llvm.func @MPI_Irecv(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
+// CPU-NEXT: llvm.func @enzymexla_wrapper_MPI_Irecv_MPI_INT(%arg0: !llvm.ptr {enzymexla.memory_effects = ["read", "write", "allocate", "free"]}, %arg1: !llvm.ptr {enzymexla.memory_effects = ["read", "write", "allocate", "free"]}, %arg2: !llvm.ptr {enzymexla.memory_effects = ["read", "write", "allocate", "free"]}, %arg3: !llvm.ptr {enzymexla.memory_effects = ["read", "write", "allocate", "free"]}, %arg4: !llvm.ptr {enzymexla.memory_effects = ["read", "write", "allocate", "free"]}) attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
+// CPU-NEXT: %0 = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr
+// CPU-NEXT: %1 = llvm.mlir.addressof @MPI_INT : !llvm.ptr
+// CPU-NEXT: %2 = llvm.load %arg1 : !llvm.ptr -> i32
+// CPU-NEXT: %3 = llvm.load %arg2 : !llvm.ptr -> i32
+// CPU-NEXT: %4 = llvm.load %arg3 : !llvm.ptr -> i32
+// CPU-NEXT: %5 = llvm.call @MPI_Irecv(%arg0, %2, %1, %3, %4, %0, %arg4) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
+// CPU-NEXT: llvm.return
+// CPU-NEXT: }
+// CPU-NEXT: func.func @main(%arg0: tensor<5xf64> {enzymexla.memory_effects = ["read", "write", "allocate", "free"], tf.aliasing_output = 0 : i32}) -> tensor<5xf64> attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
+// CPU-NEXT: %c = stablehlo.constant dense<-1> : tensor<i64>
+// CPU-NEXT: %c_0 = stablehlo.constant dense<5> : tensor<i32>
+// CPU-NEXT: %c_1 = stablehlo.constant dense<42> : tensor<i32>
+// CPU-NEXT: %c_2 = stablehlo.constant dense<1> : tensor<i32>
+// CPU-NEXT: %0 = stablehlo.transpose %arg0, dims = [0] : (tensor<5xf64>) -> tensor<5xf64>
+// CPU-NEXT: %1:2 = enzymexla.jit_call @enzymexla_wrapper_MPI_Irecv_MPI_INT (%0, %c_0, %c_2, %c_1, %c) {output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>, #stablehlo.output_operand_alias<output_tuple_indices = [1], operand_index = 4, operand_tuple_indices = []>]} : (tensor<5xf64>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i64>) -> (tensor<5xf64>, tensor<i64>)
+// CPU-NEXT: %2 = stablehlo.transpose %1#0, dims = [0] : (tensor<5xf64>) -> tensor<5xf64>
+// CPU-NEXT: return %2 : tensor<5xf64>
+// CPU-NEXT: }
+// CPU-NEXT: }
diff --git a/test/lit_tests/mpi/isend.mlir b/test/lit_tests/mpi/isend.mlir
new file mode 100644
index 0000000000..8cbc2de9a7
--- /dev/null
+++ b/test/lit_tests/mpi/isend.mlir
@@ -0,0 +1,39 @@
+// RUN: enzymexlamlir-opt --pass-pipeline="builtin.module(lower-enzymexla-mpi{backend=cpu})" %s | FileCheck %s --check-prefix=CPU
+
+module {
+  func.func @main(%arg0: tensor<5xf64> {enzymexla.memory_effects = ["read", "write", "allocate", "free"], tf.aliasing_output = 0 : i32}) -> tensor<5xf64> attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
+    %0 = stablehlo.transpose %arg0, dims = [0] : (tensor<5xf64>) -> tensor<5xf64>
+    %c = stablehlo.constant dense<1> : tensor<i32>
+    %c_0 = stablehlo.constant dense<42> : tensor<i32>
+    %c_1 = stablehlo.constant dense<5> : tensor<i32>
+    %1 = enzymexla.mpi.isend(%0, %c_1, %c, %c_0) {datatype = #enzymexla.datatype<MPI_INT>} : (tensor<5xf64>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<i64>
+    // enzymexla.mpi.wait(%1) : tensor<i64>
+    %2 = stablehlo.transpose %0, dims = [0] : (tensor<5xf64>) -> tensor<5xf64>
+    return %2 : tensor<5xf64>
+  }
+}
+
+// CPU: module {
+// CPU-NEXT: llvm.mlir.global external constant @MPI_INT() {addr_space = 0 : i32} : !llvm.ptr
+// CPU-NEXT: llvm.mlir.global external constant @MPI_COMM_WORLD() {addr_space = 0 : i32} : !llvm.ptr
+// CPU-NEXT: llvm.func @MPI_Isend(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
+// CPU-NEXT: llvm.func @enzymexla_wrapper_MPI_Isend_MPI_INT(%arg0: !llvm.ptr {enzymexla.memory_effects = ["read", "write", "allocate", "free"]}, %arg1: !llvm.ptr {enzymexla.memory_effects = ["read", "write", "allocate", "free"]}, %arg2: !llvm.ptr {enzymexla.memory_effects = ["read", "write", "allocate", "free"]}, %arg3: !llvm.ptr {enzymexla.memory_effects = ["read", "write", "allocate", "free"]}, %arg4: !llvm.ptr {enzymexla.memory_effects = ["read", "write", "allocate", "free"]}) attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
+// CPU-NEXT: %0 = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr
+// CPU-NEXT: %1 = llvm.mlir.addressof @MPI_INT : !llvm.ptr
+// CPU-NEXT: %2 = llvm.load %arg1 : !llvm.ptr -> i32
+// CPU-NEXT: %3 = llvm.load %arg2 : !llvm.ptr -> i32
+// CPU-NEXT: %4 = llvm.load %arg3 : !llvm.ptr -> i32
+// CPU-NEXT: %5 = llvm.call @MPI_Isend(%arg0, %2, %1, %3, %4, %0, %arg4) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
+// CPU-NEXT: llvm.return
+// CPU-NEXT: }
+// CPU-NEXT: func.func @main(%arg0: tensor<5xf64> {enzymexla.memory_effects = ["read", "write", "allocate", "free"], tf.aliasing_output = 0 : i32}) -> tensor<5xf64> attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
+// CPU-NEXT: %c = stablehlo.constant dense<-1> : tensor<i64>
+// CPU-NEXT: %c_0 = stablehlo.constant dense<5> : tensor<i32>
+// CPU-NEXT: %c_1 = stablehlo.constant dense<42> : tensor<i32>
+// CPU-NEXT: %c_2 = stablehlo.constant dense<1> : tensor<i32>
+// CPU-NEXT: %0 = stablehlo.transpose %arg0, dims = [0] : (tensor<5xf64>) -> tensor<5xf64>
+// CPU-NEXT: %1 = enzymexla.jit_call @enzymexla_wrapper_MPI_Isend_MPI_INT (%0, %c_0, %c_2, %c_1, %c) {output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [], operand_index = 4, operand_tuple_indices = []>]} : (tensor<5xf64>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i64>) -> tensor<i64>
+// CPU-NEXT: %2 = stablehlo.transpose %0, dims = [0] : (tensor<5xf64>) -> tensor<5xf64>
+// CPU-NEXT: return %2 : tensor<5xf64>
+// CPU-NEXT: }
+// CPU-NEXT: }
diff --git a/test/lit_tests/mpi/recv.mlir b/test/lit_tests/mpi/recv.mlir
new file mode 100644
index 0000000000..5e940b82e4
--- /dev/null
+++ b/test/lit_tests/mpi/recv.mlir
@@ -0,0 +1,39 @@
+// RUN: enzymexlamlir-opt --pass-pipeline="builtin.module(lower-enzymexla-mpi{backend=cpu})" %s | FileCheck %s --check-prefix=CPU
+
+module {
+  func.func @main(%arg0: tensor<5xf64> {enzymexla.memory_effects = ["read", "write", "allocate", "free"], tf.aliasing_output = 0 : i32}) -> tensor<5xf64> attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
+    %0 = stablehlo.transpose %arg0, dims = [0] : (tensor<5xf64>) -> tensor<5xf64>
+    %c = stablehlo.constant dense<43> : tensor<i32>
+    %c_0 = stablehlo.constant dense<0> : tensor<i32>
+    %c_1 = stablehlo.constant dense<5> : tensor<i32>
+    %1 = enzymexla.mpi.recv(%0, %c_1, %c_0, %c) {datatype = #enzymexla.datatype<MPI_INT>} : (tensor<5xf64>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<5xf64>
+    %2 = stablehlo.transpose %1, dims = [0] : (tensor<5xf64>) -> tensor<5xf64>
+    return %2 : tensor<5xf64>
+  }
+}
+
+// CPU: module {
+// CPU-NEXT: llvm.mlir.global external constant @MPI_INT() {addr_space = 0 : i32} : !llvm.ptr
+// CPU-NEXT: llvm.mlir.global external constant @MPI_COMM_WORLD() {addr_space = 0 : i32} : !llvm.ptr
+// CPU-NEXT: llvm.mlir.global external constant @MPI_STATUS_IGNORE() {addr_space = 0 : i32} : !llvm.ptr
+// CPU-NEXT: llvm.func @MPI_Recv(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
+// CPU-NEXT: llvm.func @enzymexla_wrapper_MPI_Recv_MPI_INT(%arg0: !llvm.ptr {enzymexla.memory_effects = ["read", "write", "allocate", "free"]}, %arg1: !llvm.ptr {enzymexla.memory_effects = ["read", "write", "allocate", "free"]}, %arg2: !llvm.ptr {enzymexla.memory_effects = ["read", "write", "allocate", "free"]}, %arg3: !llvm.ptr {enzymexla.memory_effects = ["read", "write", "allocate", "free"]}) attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
+// CPU-NEXT: %0 = llvm.mlir.addressof @MPI_STATUS_IGNORE : !llvm.ptr
+// CPU-NEXT: %1 = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr
+// CPU-NEXT: %2 = llvm.mlir.addressof @MPI_INT : !llvm.ptr
+// CPU-NEXT: %3 = llvm.load %arg1 : !llvm.ptr -> i32
+// CPU-NEXT: %4 = llvm.load %arg2 : !llvm.ptr -> i32
+// CPU-NEXT: %5 = llvm.load %arg3 : !llvm.ptr -> i32
+// CPU-NEXT: %6 = llvm.call @MPI_Recv(%arg0, %3, %2, %4, %5, %1, %0) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
+// CPU-NEXT: llvm.return
+// CPU-NEXT: }
+// CPU-NEXT: func.func @main(%arg0: tensor<5xf64> {enzymexla.memory_effects = ["read", "write", "allocate", "free"], tf.aliasing_output = 0 : i32}) -> tensor<5xf64> attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
+// CPU-NEXT: %c = stablehlo.constant dense<5> : tensor<i32>
+// CPU-NEXT: %c_0 = stablehlo.constant dense<0> : tensor<i32>
+// CPU-NEXT: %c_1 = stablehlo.constant dense<43> : tensor<i32>
+// CPU-NEXT: %0 = stablehlo.transpose %arg0, dims = [0] : (tensor<5xf64>) -> tensor<5xf64>
+// CPU-NEXT: %1 = enzymexla.jit_call @enzymexla_wrapper_MPI_Recv_MPI_INT (%0, %c, %c_0, %c_1) {output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [], operand_index = 0, operand_tuple_indices = []>]} : (tensor<5xf64>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<5xf64>
+// CPU-NEXT: %2 = stablehlo.transpose %1, dims = [0] : (tensor<5xf64>) -> tensor<5xf64>
+// CPU-NEXT: return %2 : tensor<5xf64>
+// CPU-NEXT: }
+// CPU-NEXT: }
diff --git a/test/lit_tests/mpi/send.mlir b/test/lit_tests/mpi/send.mlir
new file mode 100644
index 0000000000..4fd5d14b63
--- /dev/null
+++ b/test/lit_tests/mpi/send.mlir
@@ -0,0 +1,37 @@
+// RUN: enzymexlamlir-opt --pass-pipeline="builtin.module(lower-enzymexla-mpi{backend=cpu})" %s | FileCheck %s --check-prefix=CPU
+
+module {
+  func.func @main(%arg0: tensor<5xi64> {enzymexla.memory_effects = ["read", "write", "allocate", "free"], tf.aliasing_output = 0 : i32}) -> tensor<5xi64> attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
+    %0 = stablehlo.transpose %arg0, dims = [0] : (tensor<5xi64>) -> tensor<5xi64>
+    %c = stablehlo.constant dense<43> : tensor<i32>
+    %c_0 = stablehlo.constant dense<1> : tensor<i32>
+    %c_1 = stablehlo.constant dense<5> : tensor<i32>
+    enzymexla.mpi.send(%0, %c_1, %c_0, %c) {datatype = #enzymexla.datatype<MPI_INT>} : tensor<5xi64>, tensor<i32>, tensor<i32>, tensor<i32>
+    %1 = stablehlo.transpose %0, dims = [0] : (tensor<5xi64>) -> tensor<5xi64>
+    return %1 : tensor<5xi64>
+  }
+}
+
+// CPU: module {
+// CPU-NEXT: llvm.mlir.global external constant @MPI_INT() {addr_space = 0 : i32} : !llvm.ptr
+// CPU-NEXT: llvm.mlir.global external constant @MPI_COMM_WORLD() {addr_space = 0 : i32} : !llvm.ptr
+// CPU-NEXT: llvm.func @MPI_Send(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
+// CPU-NEXT: llvm.func @enzymexla_wrapper_MPI_Send_MPI_INT(%arg0: !llvm.ptr {enzymexla.memory_effects = ["read", "write", "allocate", "free"]}, %arg1: !llvm.ptr {enzymexla.memory_effects = ["read", "write", "allocate", "free"]}, %arg2: !llvm.ptr {enzymexla.memory_effects = ["read", "write", "allocate", "free"]}, %arg3: !llvm.ptr {enzymexla.memory_effects = ["read", "write", "allocate", "free"]}) attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
+// CPU-NEXT: %0 = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr
+// CPU-NEXT: %1 = llvm.mlir.addressof @MPI_INT : !llvm.ptr
+// CPU-NEXT: %2 = llvm.load %arg1 : !llvm.ptr -> i32
+// CPU-NEXT: %3 = llvm.load %arg2 : !llvm.ptr -> i32
+// CPU-NEXT: %4 = llvm.load %arg3 : !llvm.ptr -> i32
+// CPU-NEXT: %5 = llvm.call @MPI_Send(%arg0, %2, %1, %3, %4, %0) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
+// CPU-NEXT: llvm.return
+// CPU-NEXT: }
+// CPU-NEXT: func.func @main(%arg0: tensor<5xi64> {enzymexla.memory_effects = ["read", "write", "allocate", "free"], tf.aliasing_output = 0 : i32}) -> tensor<5xi64> attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
+// CPU-NEXT: %c = stablehlo.constant dense<5> : tensor<i32>
+// CPU-NEXT: %c_0 = stablehlo.constant dense<1> : tensor<i32>
+// CPU-NEXT: %c_1 = stablehlo.constant dense<43> : tensor<i32>
+// CPU-NEXT: %0 = stablehlo.transpose %arg0, dims = [0] : (tensor<5xi64>) -> tensor<5xi64>
+// CPU-NEXT: enzymexla.jit_call @enzymexla_wrapper_MPI_Send_MPI_INT (%0, %c, %c_0, %c_1) : (tensor<5xi64>, tensor<i32>, tensor<i32>, tensor<i32>) -> ()
+// CPU-NEXT: %1 = stablehlo.transpose %0, dims = [0] : (tensor<5xi64>) -> tensor<5xi64>
+// CPU-NEXT: return %1 : tensor<5xi64>
+// CPU-NEXT: }
+// CPU-NEXT: }
diff --git a/test/lit_tests/mpi/wait.mlir b/test/lit_tests/mpi/wait.mlir
new file mode 100644
index 0000000000..1134939f5a
--- /dev/null
+++ b/test/lit_tests/mpi/wait.mlir
@@ -0,0 +1,24 @@
+// RUN: enzymexlamlir-opt --pass-pipeline="builtin.module(lower-enzymexla-mpi{backend=cpu})" %s | FileCheck %s --check-prefix=CPU
+
+module {
+  func.func @main() {
+    %c_2 = stablehlo.constant dense<-1> : tensor<i64>
+    enzymexla.mpi.wait(%c_2) : tensor<i64>
+    return
+  }
+}
+
+// CPU: module {
+// CPU-NEXT: llvm.func @MPI_Wait(!llvm.ptr, !llvm.ptr) -> i32
+// CPU-NEXT: llvm.func @enzymexla_wrapper_MPI_Wait(%arg0: !llvm.ptr {enzymexla.memory_effects = ["read", "write", "allocate", "free"]}) attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
+// CPU-NEXT: %c1_i32 = arith.constant 1 : i32
+// CPU-NEXT: %0 = llvm.alloca %c1_i32 x !llvm.array<6 x i32> : (i32) -> !llvm.ptr
+// CPU-NEXT: %1 = llvm.call @MPI_Wait(%arg0, %0) : (!llvm.ptr, !llvm.ptr) -> i32
+// CPU-NEXT: llvm.return
+// CPU-NEXT: }
+// CPU-NEXT: func.func @main() {
+// CPU-NEXT: %c = stablehlo.constant dense<-1> : tensor<i64>
+// CPU-NEXT: enzymexla.jit_call @enzymexla_wrapper_MPI_Wait (%c) : (tensor<i64>) -> ()
+// CPU-NEXT: return
+// CPU-NEXT: }
+// CPU-NEXT: }