Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
104 commits
Select commit Hold shift + click to select a range
449552e
Add mpi commrank op
romanlee Nov 12, 2025
e1a9745
Add bazel rules and #includes, now builds
romanlee Nov 15, 2025
29acf47
Cleanup MPI TypeDefs
romanlee Nov 18, 2025
743f7d5
Skeleton of MPI to LLVM lowering pass
romanlee Nov 18, 2025
049f971
MPICommRankOpLowering pass stuff
romanlee Nov 24, 2025
00b1754
commrankoplowering pass stuff
romanlee Nov 24, 2025
cb7819b
Update EnzymeXLAOps.td LowerEnzymeXLAMPI.cpp
romanlee Nov 26, 2025
2f02303
Just trying to get something working
romanlee Dec 3, 2025
3435454
Revert "Just trying to get something working"
romanlee Dec 3, 2025
44777ec
Just trying to get something working take 2
romanlee Dec 3, 2025
40c2c38
Remove args to comm_rank for now, will hard code comm=comm world
romanlee Dec 4, 2025
51031d2
Add some placeholder mpi lit tests
romanlee Dec 4, 2025
4ae45b9
Messing around w tests
romanlee Dec 5, 2025
ba2abc5
Toy lowering pass
romanlee Dec 6, 2025
d617950
TODO figure out what's going on with return success/failure
romanlee Dec 8, 2025
79b35fa
Starting to resemble an actual lowering pass
romanlee Dec 9, 2025
abe208b
Call the llvm function with enzymexla.jit_call
romanlee Dec 9, 2025
f22f7fb
Fix/use the jit call
romanlee Dec 9, 2025
ff64716
success() probably is behaving correctly, I just probably don't know …
romanlee Dec 9, 2025
edd9e1b
Declare and use COMM_WORLD and use func arg
romanlee Dec 10, 2025
e634b77
cleanup
romanlee Dec 10, 2025
2dccd0f
Output operand aliasings
romanlee Dec 10, 2025
de266cd
cleanup
romanlee Dec 10, 2025
bf64f3b
Use rewriter.create<>() instead of ::create(rewriter, ...)
romanlee Dec 10, 2025
67eb5c4
Style
romanlee Dec 10, 2025
2b23a11
Remove debug print statement
romanlee Dec 10, 2025
c13a0bf
Update mpi.mlir test
romanlee Dec 10, 2025
d83b9ea
Use stablehlo.constant, add enzymexla.memory_effects attr, update tests
romanlee Dec 10, 2025
b10af90
Add EnzymeXLAOp for MPI_Comm_size and lowering pass
romanlee Dec 11, 2025
ed278d1
Add MPI Send Op
romanlee Dec 12, 2025
291f0f9
Switch send arg order to correspond with the underlying c api
romanlee Dec 12, 2025
47f8a61
Fix typo
romanlee Dec 12, 2025
db759f7
Add count arg to Send Op
romanlee Dec 13, 2025
ca3ec3c
Add Send lowering pass
romanlee Dec 13, 2025
5ab51b1
Don't think we need to run enzyme-hlo-opt in mpi tests
romanlee Dec 13, 2025
a0dfbc5
Cleanup MPI ops
romanlee Dec 15, 2025
58740c4
Add MPI_Recv Op
romanlee Dec 15, 2025
a0dac96
buf should be an output of Recv, and other various and sundry items
romanlee Dec 15, 2025
3a2936d
Think we need the buf as both an input and output of Recv based on how
romanlee Dec 16, 2025
9add455
Half baked Recv lowering pass
romanlee Dec 16, 2025
6736cb3
Finish Recv Op lowering, add test
romanlee Dec 16, 2025
c2ac363
Add Irecv and Wait Ops
romanlee Dec 16, 2025
9b64ee4
Add MPI_Wait lowering pass and test
romanlee Dec 17, 2025
c133f71
Add Irecv lowering pass. Untested
romanlee Dec 17, 2025
fffb30b
Add half a irecv lit test
romanlee Dec 17, 2025
25d70f5
Fix Irecv lowering pass, finish test
romanlee Dec 17, 2025
4191dd8
Add Isend Op
romanlee Dec 17, 2025
6aacea5
Add Isend lowering pass and lit test
romanlee Dec 17, 2025
5451cd4
Add MPI Barrier Op
romanlee Dec 17, 2025
bdda0e4
Add Barrier lowering pass
romanlee Dec 17, 2025
c8f7172
Fix Barrier Op, add lit test
romanlee Dec 17, 2025
1ee4451
Add Allreduce Op
romanlee Dec 17, 2025
8929872
Add Allreduce lowering pass and lit test
romanlee Dec 17, 2025
3ada090
Let Comm_rank/size use input/output tensor paradigm
romanlee Dec 18, 2025
e39b868
Revert "Let Comm_rank/size use input/output tensor paradigm"
romanlee Dec 18, 2025
41af065
Let comm_rank take a rank tensor as input
romanlee Dec 18, 2025
8ddc82a
Fix name
romanlee Dec 18, 2025
6787a95
Modify comm_rank lowering pass to take a rank tensor as input
romanlee Dec 19, 2025
fb50ab1
Let comm_size take size as input
romanlee Dec 19, 2025
8e84c34
Update comm_size lit test
romanlee Dec 19, 2025
8cfb7c9
Fix MPI wait lit test, remove irecv to better isolate what we want to
romanlee Dec 19, 2025
7503a44
Fix send lit test
romanlee Dec 19, 2025
ad95368
Remove roundtrip test
romanlee Dec 19, 2025
4c11991
Add irecv + wait lit test
romanlee Dec 19, 2025
ba64b94
COmments
romanlee Dec 19, 2025
a3f1dc5
Add stringattr to send op
romanlee Dec 19, 2025
c4f8748
Update MPI Send lowering pass and test to use datatype attr
romanlee Dec 19, 2025
47634a2
Add datatype strattr to recv, isend, irecv, allreduce (and add op
romanlee Dec 20, 2025
4e0ce74
Clang format
romanlee Dec 22, 2025
1c9af14
Cleanup unecessary changes
romanlee Dec 22, 2025
79b701c
Update Ops summaries
romanlee Dec 23, 2025
3b747d7
Remove inrank arg from comm_rank op, temporarily comment out lowering
romanlee Jan 6, 2026
a54d31b
Uncomment comm rank lowering pass
romanlee Jan 6, 2026
40b415f
Update comm_rank lowering pass and tests to not take any args
romanlee Jan 6, 2026
7cf58a9
Fix comm_rank lowering pass
romanlee Jan 6, 2026
20b6106
Use `mpi.` prefix for ops
romanlee Jan 6, 2026
a963166
Update comm_size op and lowering pass to take no args
romanlee Jan 7, 2026
aed10ef
Update comm_size test
romanlee Jan 7, 2026
8ecfb08
Remove inrequest arg from mpi isend op
romanlee Jan 7, 2026
1ba1467
Reuncomment isend lowering
romanlee Jan 7, 2026
baf0efd
Update Isend lowering pass
romanlee Jan 7, 2026
8b261c0
Update isend test
romanlee Jan 7, 2026
7b0ccc6
Remove inrequest from Irecv op, update lowering pass
romanlee Jan 7, 2026
0c4f8b6
Update irecv tests
romanlee Jan 7, 2026
5183ef3
MPI Send op, replace string attr with enum attr, first step
romanlee Jan 8, 2026
82de909
Test: uncomment lapackuploattrget c api export thing
romanlee Jan 8, 2026
701fc52
Revert "Test: uncomment lapackuploattrget c api export thing"
romanlee Jan 8, 2026
9df57f6
Placeholder MPIDatatypeAttr enum
romanlee Jan 8, 2026
a90180d
Use new MPIDatatypeAttr in MPI Send
romanlee Jan 8, 2026
9db29fd
Towards an actual MPIDatatypeAttr
romanlee Jan 8, 2026
8c61600
Frogot to update cpp side
romanlee Jan 8, 2026
5126f33
Uncomment SendOp lowering pass
romanlee Jan 8, 2026
f367c9a
Update MPISendOp lowering pass and test to use new MPIDatatypeAttr
romanlee Jan 8, 2026
ecee86d
Flesh out MPI_Datatype enum
romanlee Jan 8, 2026
1700638
Oops, accidentatlly removed in last commit
romanlee Jan 8, 2026
a8e6c18
Forgot to update EnzymeXLA.cpp, do that here
romanlee Jan 8, 2026
302db8b
Recv, Isend/recv, allreduce: replace StrAttr datatype with EnumAttr
romanlee Jan 8, 2026
e568ab0
Update the set of MPI datatypes
romanlee Jan 8, 2026
bf8e032
Remove comment
romanlee Jan 9, 2026
f6fe860
Update mpi lit tests
romanlee Jan 9, 2026
8ee3acf
Use EnumAttr instead of StrAttr for Op in allreduce
romanlee Jan 9, 2026
a4f61b7
Update allreduce lit test
romanlee Jan 9, 2026
74ee582
Mark comm_rank/size as pure
romanlee Jan 9, 2026
0280340
Clang formatter
romanlee Jan 9, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions src/enzyme_ad/jax/Dialect/EnzymeXLAAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -133,4 +133,72 @@ def EnzymeXLA_GuaranteedAnalysisResult : I32EnumAttr<"GuaranteedAnalysisResult",
def EnzymeXLA_GuaranteedAnalysisResultAttr : EnumAttr<EnzymeXLA_Dialect,
EnzymeXLA_GuaranteedAnalysisResult, "guaranteed">;

// MPI

def EnzymeXLA_MPIDatatype : I32EnumAttr<"MPIDatatype",
"MPI Datatype",
[
I32EnumAttrCase<"MPI_DATATYPE_NULL", 0>,
I32EnumAttrCase<"MPI_INT8_T", 1>,
I32EnumAttrCase<"MPI_UINT8_T", 2>,
I32EnumAttrCase<"MPI_INT16_T", 3>,
I32EnumAttrCase<"MPI_UINT16_T", 4>,
I32EnumAttrCase<"MPI_INT32_T", 5>,
I32EnumAttrCase<"MPI_UINT32_T", 6>,
I32EnumAttrCase<"MPI_INT64_T", 7>,
I32EnumAttrCase<"MPI_UINT64_T", 8>,
I32EnumAttrCase<"MPI_BYTE", 9>,
I32EnumAttrCase<"MPI_SHORT", 10>,
I32EnumAttrCase<"MPI_UNSIGNED_SHORT", 11>,
I32EnumAttrCase<"MPI_INT", 12>,
I32EnumAttrCase<"MPI_UNSIGNED", 13>,
I32EnumAttrCase<"MPI_LONG", 14>,
I32EnumAttrCase<"MPI_UNSIGNED_LONG", 15>,
I32EnumAttrCase<"MPI_LONG_LONG_INT", 16>,
I32EnumAttrCase<"MPI_UNSIGNED_LONG_LONG", 17>,
I32EnumAttrCase<"MPI_CHAR", 18>,
I32EnumAttrCase<"MPI_SIGNED_CHAR", 19>,
I32EnumAttrCase<"MPI_UNSIGNED_CHAR", 20>,
I32EnumAttrCase<"MPI_WCHAR", 21>,
I32EnumAttrCase<"MPI_FLOAT", 22>,
I32EnumAttrCase<"MPI_DOUBLE", 23>,
I32EnumAttrCase<"MPI_C_FLOAT_COMPLEX", 24>,
I32EnumAttrCase<"MPI_C_DOUBLE_COMPLEX", 25>,
I32EnumAttrCase<"MPI_C_BOOL", 26>
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::enzymexla";
}

def EnzymeXLA_MPIDatatypeAttr : EnumAttr<EnzymeXLA_Dialect,
EnzymeXLA_MPIDatatype, "datatype"> {
let assemblyFormat = "`<` $value `>`";
}

def EnzymeXLA_MPIOp : I32EnumAttr<"MPIOp",
"MPI Operator",
[
I32EnumAttrCase<"MPI_OP_NULL", 0>,
I32EnumAttrCase<"MPI_BAND", 1>,
I32EnumAttrCase<"MPI_BOR", 2>,
I32EnumAttrCase<"MPI_BXOR", 3>,
I32EnumAttrCase<"MPI_LAND", 4>,
I32EnumAttrCase<"MPI_LOR", 5>,
I32EnumAttrCase<"MPI_LXOR", 6>,
I32EnumAttrCase<"MPI_MAX", 7>,
I32EnumAttrCase<"MPI_MIN", 8>,
I32EnumAttrCase<"MPI_PROD", 9>,
I32EnumAttrCase<"MPI_REPLACE", 10>,
I32EnumAttrCase<"MPI_SUM", 11>,
I32EnumAttrCase<"MPI_NO_OP", 12>
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::enzymexla";
}

def EnzymeXLA_MPIOpAttr : EnumAttr<EnzymeXLA_Dialect,
EnzymeXLA_MPIOp, "op"> {
let assemblyFormat = "`<` $value `>`";
}

#endif // ENZYMEXLA_ATTRS
126 changes: 126 additions & 0 deletions src/enzyme_ad/jax/Dialect/EnzymeXLAOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1088,4 +1088,130 @@ def AffineStoreVar : EnzymeXLA_Op<"store_var", [Pure]> {
let summary = "Fake store an SSA value for conversion to ISL";
}

// MPI Ops

def MPICommRankOp : EnzymeXLA_Op<"mpi.comm_rank", [Pure]> {
let summary = "Equivalent to " "`MPI_Comm_rank(MPI_COMM_WORLD, &rank)`";

let results = (
outs TensorOf<[I32]> : $rank
);

let assemblyFormat = "attr-dict `:` type(results)";
}

def MPICommSizeOp : EnzymeXLA_Op<"mpi.comm_size", [Pure]> {
let summary = "Equivalent to MPI_Comm_size(MPI_COMM_WORLD, &size)";

let results = (
outs TensorOf<[I32]> : $size
);

let assemblyFormat = "attr-dict `:` type(results)";
}

def MPIBarrierOp : EnzymeXLA_Op<"mpi.barrier", []> {
let summary = "Equivalent to MPI_Barrier(MPI_COMM_WORLD)";
let assemblyFormat = "attr-dict";
}

def MPISendOp : EnzymeXLA_Op<"mpi.send", []> {
let summary = "Equivalent to "
"`MPI_Send(&buf, count, datatype, dest, tag, comm)`";

let arguments = (
ins AnyTensor : $buf,
TensorOf<[I32]> : $count,
TensorOf<[I32]> : $dest,
TensorOf<[I32]> : $tag,
EnzymeXLA_MPIDatatypeAttr:$datatype
);

let assemblyFormat = "`(` operands `)` attr-dict `:` type(operands)";
}

def MPIRecvOp : EnzymeXLA_Op<"mpi.recv", []> {
let summary = "Equivalent to "
"`MPI_Recv(&buf, count, datatype, source, tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE)`";

let arguments = (
ins AnyTensor : $inbuf,
TensorOf<[I32]> : $count,
TensorOf<[I32]> : $source,
TensorOf<[I32]> : $tag,
EnzymeXLA_MPIDatatypeAttr:$datatype
);

let results = (
outs AnyTensor : $outbuf
);

let assemblyFormat = "`(` operands `)` attr-dict `:` functional-type(operands, results)";
}

def MPIIsendOp : EnzymeXLA_Op<"mpi.isend", []> {
let summary = "Equivalent to "
"`MPI_Isend(&buf, count, datatype, dest, tag, MPI_COMM_WORLD, &request)`";

let arguments = (
ins AnyTensor : $buf,
TensorOf<[I32]> : $count,
TensorOf<[I32]> : $dest,
TensorOf<[I32]> : $tag,
EnzymeXLA_MPIDatatypeAttr:$datatype
);

let results = (
outs TensorOf<[I64]> : $request
);

let assemblyFormat = "`(` operands `)` attr-dict `:` functional-type(operands, results)";
}

def MPIIrecvOp : EnzymeXLA_Op<"mpi.irecv", []> {
let summary = "Equivalent to "
"`MPI_Irecv(&buf, count, datatype, source, tag, MPI_COMM_WORLD, &request)`";

let arguments = (
ins AnyTensor : $inbuf,
TensorOf<[I32]> : $count,
TensorOf<[I32]> : $source,
TensorOf<[I32]> : $tag,
EnzymeXLA_MPIDatatypeAttr:$datatype
);

let results = (
outs AnyTensor : $outbuf,
TensorOf<[I64]> : $request
);

let assemblyFormat = "`(` operands `)` attr-dict `:` functional-type(operands, results)";
}

def MPIWaitOp : EnzymeXLA_Op<"mpi.wait", []> {
let summary = "Equivalent to "
"`MPI_Wait(&request, &status)`";
let arguments = (ins TensorOf<[I64]> : $request);
let assemblyFormat = "`(` operands `)` attr-dict `:` type(operands)";
}

def MPIAllreduceOp : EnzymeXLA_Op<"mpi.allreduce", []> {
let summary = "Equivalent to "
"`MPI_Allreduce(&sendbuf, &recvbuf, count, datatype, op, MPI_COMM_WORLD)`";

let arguments = (
ins AnyTensor : $sendbuf,
AnyTensor : $inbuf,
TensorOf<[I32]> : $count,
EnzymeXLA_MPIDatatypeAttr:$datatype,
EnzymeXLA_MPIOpAttr:$op
);

let results = (
outs AnyTensor : $outbuf
);

let assemblyFormat = "`(` operands `)` attr-dict `:` functional-type(operands, results)";
}

#endif // ENZYMEXLA_OPS
138 changes: 138 additions & 0 deletions src/enzyme_ad/jax/Integrations/c/EnzymeXLA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,141 @@ MlirAttribute enzymexlaGuaranteedAnalysisResultAttrGet(MlirContext ctx,
return wrap(mlir::enzymexla::GuaranteedAnalysisResultAttr::get(unwrap(ctx),
analysis));
}

MlirAttribute enzymexlaMPIDatatypeAttrGet(MlirContext ctx, int32_t mode) {
mlir::enzymexla::MPIDatatype datatype;
switch (mode) {
case 0:
datatype = mlir::enzymexla::MPIDatatype::MPI_DATATYPE_NULL;
break;
case 1:
datatype = mlir::enzymexla::MPIDatatype::MPI_INT8_T;
break;
case 2:
datatype = mlir::enzymexla::MPIDatatype::MPI_UINT8_T;
break;
case 3:
datatype = mlir::enzymexla::MPIDatatype::MPI_INT16_T;
break;
case 4:
datatype = mlir::enzymexla::MPIDatatype::MPI_UINT16_T;
break;
case 5:
datatype = mlir::enzymexla::MPIDatatype::MPI_INT32_T;
break;
case 6:
datatype = mlir::enzymexla::MPIDatatype::MPI_UINT32_T;
break;
case 7:
datatype = mlir::enzymexla::MPIDatatype::MPI_INT64_T;
break;
case 8:
datatype = mlir::enzymexla::MPIDatatype::MPI_UINT64_T;
break;
case 9:
datatype = mlir::enzymexla::MPIDatatype::MPI_BYTE;
break;
case 10:
datatype = mlir::enzymexla::MPIDatatype::MPI_SHORT;
break;
case 11:
datatype = mlir::enzymexla::MPIDatatype::MPI_UNSIGNED_SHORT;
break;
case 12:
datatype = mlir::enzymexla::MPIDatatype::MPI_INT;
break;
case 13:
datatype = mlir::enzymexla::MPIDatatype::MPI_UNSIGNED;
break;
case 14:
datatype = mlir::enzymexla::MPIDatatype::MPI_LONG;
break;
case 15:
datatype = mlir::enzymexla::MPIDatatype::MPI_UNSIGNED_LONG;
break;
case 16:
datatype = mlir::enzymexla::MPIDatatype::MPI_LONG_LONG_INT;
break;
case 17:
datatype = mlir::enzymexla::MPIDatatype::MPI_UNSIGNED_LONG_LONG;
break;
case 18:
datatype = mlir::enzymexla::MPIDatatype::MPI_CHAR;
break;
case 19:
datatype = mlir::enzymexla::MPIDatatype::MPI_SIGNED_CHAR;
break;
case 20:
datatype = mlir::enzymexla::MPIDatatype::MPI_UNSIGNED_CHAR;
break;
case 21:
datatype = mlir::enzymexla::MPIDatatype::MPI_WCHAR;
break;
case 22:
datatype = mlir::enzymexla::MPIDatatype::MPI_FLOAT;
break;
case 23:
datatype = mlir::enzymexla::MPIDatatype::MPI_DOUBLE;
break;
case 24:
datatype = mlir::enzymexla::MPIDatatype::MPI_C_FLOAT_COMPLEX;
break;
case 25:
datatype = mlir::enzymexla::MPIDatatype::MPI_C_DOUBLE_COMPLEX;
break;
case 26:
datatype = mlir::enzymexla::MPIDatatype::MPI_C_BOOL;
break;
default:
llvm_unreachable("Invalid MPI datatype mode");
}
return wrap(mlir::enzymexla::MPIDatatypeAttr::get(unwrap(ctx), datatype));
}

MlirAttribute enzymexlaMPIOpAttrGet(MlirContext ctx, int32_t mode) {
mlir::enzymexla::MPIOp op;
switch (mode) {
case 0:
op = mlir::enzymexla::MPIOp::MPI_OP_NULL;
break;
case 1:
op = mlir::enzymexla::MPIOp::MPI_BAND;
break;
case 2:
op = mlir::enzymexla::MPIOp::MPI_BOR;
break;
case 3:
op = mlir::enzymexla::MPIOp::MPI_BXOR;
break;
case 4:
op = mlir::enzymexla::MPIOp::MPI_LAND;
break;
case 5:
op = mlir::enzymexla::MPIOp::MPI_LOR;
break;
case 6:
op = mlir::enzymexla::MPIOp::MPI_LXOR;
break;
case 7:
op = mlir::enzymexla::MPIOp::MPI_MAX;
break;
case 8:
op = mlir::enzymexla::MPIOp::MPI_MIN;
break;
case 9:
op = mlir::enzymexla::MPIOp::MPI_PROD;
break;
case 10:
op = mlir::enzymexla::MPIOp::MPI_REPLACE;
break;
case 11:
op = mlir::enzymexla::MPIOp::MPI_SUM;
break;
case 12:
op = mlir::enzymexla::MPIOp::MPI_NO_OP;
break;
default:
llvm_unreachable("Invalid MPI op mode");
}
return wrap(mlir::enzymexla::MPIOpAttr::get(unwrap(ctx), op));
}
10 changes: 10 additions & 0 deletions src/enzyme_ad/jax/Integrations/c/EnzymeXLA.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,16 @@ MLIR_CAPI_EXPORTED MlirAttribute enzymexlaSVDAlgorithmAttrGet(MlirContext ctx,
MLIR_CAPI_EXPORTED MlirAttribute
enzymexlaGeluApproximationAttrGet(MlirContext ctx, int32_t mode);

//===----------------------------------------------------------------------===//
// MPI Ops
//===----------------------------------------------------------------------===//

MLIR_CAPI_EXPORTED MlirAttribute enzymexlaMPIDatatypeAttrGet(MlirContext ctx,
int32_t mode);

MLIR_CAPI_EXPORTED MlirAttribute enzymexlaMPIOpAttrGet(MlirContext ctx,
int32_t mode);

//===----------------------------------------------------------------------===//
// Other Ops / Attributes
//===----------------------------------------------------------------------===//
Expand Down
Loading
Loading