Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions flagcx/adaptor/device/cann_adaptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -163,10 +163,13 @@ flagcxResult_t cannAdaptorStreamWaitEvent(flagcxStream_t stream,
return flagcxSuccess;
}

flagcxResult_t cannAdaptorEventCreate(flagcxEvent_t *event) {
flagcxResult_t cannAdaptorEventCreate(flagcxEvent_t *event,
flagcxEventType_t eventType) {
(*event) = NULL;
flagcxCalloc(event, 1);
DEVCHECK(aclrtCreateEventWithFlag((aclrtEvent *)(*event), ACL_EVENT_SYNC));
const unsigned int flags =
(eventType == flagcxEventDefault) ? ACL_EVENT_TIME_LINE : ACL_EVENT_SYNC;
DEVCHECK(aclrtCreateEventWithFlag(&((*event)->base), flags));
return flagcxSuccess;
}

Expand Down
9 changes: 6 additions & 3 deletions flagcx/adaptor/device/cuda_adaptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -234,11 +234,14 @@ flagcxResult_t cudaAdaptorStreamWaitEvent(flagcxStream_t stream,
return flagcxSuccess;
}

flagcxResult_t cudaAdaptorEventCreate(flagcxEvent_t *event) {
flagcxResult_t cudaAdaptorEventCreate(flagcxEvent_t *event,
flagcxEventType_t eventType) {
(*event) = NULL;
flagcxCalloc(event, 1);
DEVCHECK(cudaEventCreateWithFlags((cudaEvent_t *)(*event),
cudaEventDefault));
const unsigned int flags = (eventType == flagcxEventDefault)
? cudaEventDefault
: cudaEventDisableTiming;
DEVCHECK(cudaEventCreateWithFlags(&((*event)->base), flags));
return flagcxSuccess;
}

Expand Down
9 changes: 6 additions & 3 deletions flagcx/adaptor/device/ducuda_adaptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -179,11 +179,14 @@ flagcxResult_t ducudaAdaptorStreamWaitEvent(flagcxStream_t stream,
return flagcxSuccess;
}

flagcxResult_t ducudaAdaptorEventCreate(flagcxEvent_t *event) {
flagcxResult_t ducudaAdaptorEventCreate(flagcxEvent_t *event,
flagcxEventType_t eventType) {
(*event) = NULL;
flagcxCalloc(event, 1);
DEVCHECK(cudaEventCreateWithFlags((cudaEvent_t *)(*event),
cudaEventDisableTiming));
const unsigned int flags = (eventType == flagcxEventDefault)
? cudaEventDefault
: cudaEventDisableTiming;
DEVCHECK(cudaEventCreateWithFlags(&((*event)->base), flags));
return flagcxSuccess;
}

Expand Down
9 changes: 6 additions & 3 deletions flagcx/adaptor/device/hip_adaptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -176,11 +176,14 @@ flagcxResult_t hipAdaptorStreamWaitEvent(flagcxStream_t stream,
return flagcxSuccess;
}

flagcxResult_t hipAdaptorEventCreate(flagcxEvent_t *event) {
flagcxResult_t hipAdaptorEventCreate(flagcxEvent_t *event,
flagcxEventType_t eventType) {
(*event) = NULL;
flagcxCalloc(event, 1);
DEVCHECK(
hipEventCreateWithFlags((hipEvent_t *)(*event), hipEventDisableTiming));
const unsigned int flags = (eventType == flagcxEventDefault)
? hipEventDefault
: hipEventDisableTiming;
DEVCHECK(hipEventCreateWithFlags(&((*event)->base), flags));
return flagcxSuccess;
}

Expand Down
10 changes: 7 additions & 3 deletions flagcx/adaptor/device/ixcuda_adaptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -178,11 +178,15 @@ flagcxResult_t ixcudaAdaptorStreamWaitEvent(flagcxStream_t stream,
return flagcxSuccess;
}

flagcxResult_t ixcudaAdaptorEventCreate(flagcxEvent_t *event) {
flagcxResult_t ixcudaAdaptorEventCreate(flagcxEvent_t *event,
flagcxEventType_t eventType) {
(*event) = NULL;
flagcxCalloc(event, 1);
DEVCHECK(cudaEventCreateWithFlags((cudaEvent_t *)(*event),
cudaEventDisableTiming));
const unsigned int flags = (eventType == flagcxEventDefault)
? cudaEventDefault
: cudaEventDisableTiming;
DEVCHECK(cudaEventCreateWithFlags(&((*event)->base), flags));

return flagcxSuccess;
}

Expand Down
9 changes: 6 additions & 3 deletions flagcx/adaptor/device/kunlunxin_adaptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -202,11 +202,14 @@ flagcxResult_t kunlunAdaptorStreamWaitEvent(flagcxStream_t stream,
return flagcxSuccess;
}

flagcxResult_t kunlunAdaptorEventCreate(flagcxEvent_t *event) {
flagcxResult_t kunlunAdaptorEventCreate(flagcxEvent_t *event,
flagcxEventType_t eventType) {
(*event) = NULL;
flagcxCalloc(event, 1);
DEVCHECK(cudaEventCreateWithFlags((cudaEvent_t *)(*event),
cudaEventDisableTiming));
const unsigned int flags = (eventType == flagcxEventDefault)
? cudaEventDefault
: cudaEventDisableTiming;
DEVCHECK(cudaEventCreateWithFlags(&((*event)->base), flags));
return flagcxSuccess;
}

Expand Down
7 changes: 5 additions & 2 deletions flagcx/adaptor/device/maca_adaptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,13 @@ flagcxResult_t macaAdaptorStreamWaitEvent(flagcxStream_t stream,
return flagcxSuccess;
}

flagcxResult_t macaAdaptorEventCreate(flagcxEvent_t *event) {
flagcxResult_t macaAdaptorEventCreate(flagcxEvent_t *event,
flagcxEventType_t eventType) {
(*event) = NULL;
flagcxCalloc(event, 1);
DEVCHECK(mcEventCreateWithFlags((mcEvent_t *)(*event), mcEventDisableTiming));
const unsigned int flags =
(eventType == flagcxEventDefault) ? mcEventDefault : mcEventDisableTiming;
DEVCHECK(mcEventCreateWithFlags(&((*event)->base), flags));
return flagcxSuccess;
}

Expand Down
9 changes: 6 additions & 3 deletions flagcx/adaptor/device/mlu_adaptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -203,11 +203,14 @@ flagcxResult_t mluAdaptorStreamWaitEvent(flagcxStream_t stream,
return flagcxSuccess;
}

flagcxResult_t mluAdaptorEventCreate(flagcxEvent_t *event) {
flagcxResult_t mluAdaptorEventCreate(flagcxEvent_t *event,
flagcxEventType_t eventType) {
(*event) = NULL;
flagcxCalloc(event, 1);
DEVCHECK(cnrtNotifierCreateWithFlags((cnrtNotifier_t *)(*event),
CNRT_NOTIFIER_DISABLE_TIMING_ALL));
const unsigned int flags = (eventType == flagcxEventDefault)
? CNRT_NOTIFIER_DEFAULT
: CNRT_NOTIFIER_DISABLE_TIMING_ALL;
DEVCHECK(cnrtNotifierCreateWithFlags(&((*event)->base), flags));
return flagcxSuccess;
}

Expand Down
9 changes: 6 additions & 3 deletions flagcx/adaptor/device/musa_adaptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -183,11 +183,14 @@ flagcxResult_t musaAdaptorStreamWaitEvent(flagcxStream_t stream,
return flagcxSuccess;
}

flagcxResult_t musaAdaptorEventCreate(flagcxEvent_t *event) {
flagcxResult_t musaAdaptorEventCreate(flagcxEvent_t *event,
flagcxEventType_t eventType) {
(*event) = NULL;
flagcxCalloc(event, 1);
DEVCHECK(musaEventCreateWithFlags((musaEvent_t *)(*event),
musaEventDisableTiming));
const unsigned int flags = (eventType == flagcxEventDefault)
? musaEventDefault
: musaEventDisableTiming;
DEVCHECK(musaEventCreateWithFlags(&((*event)->base), flags));
return flagcxSuccess;
}

Expand Down
3 changes: 2 additions & 1 deletion flagcx/adaptor/include/adaptor.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,8 @@ struct flagcxDeviceAdaptor {
flagcxResult_t (*streamWaitEvent)(flagcxStream_t stream, flagcxEvent_t event);

// Event functions
flagcxResult_t (*eventCreate)(flagcxEvent_t *event);
flagcxResult_t (*eventCreate)(flagcxEvent_t *event,
flagcxEventType_t eventType);
flagcxResult_t (*eventDestroy)(flagcxEvent_t event);
flagcxResult_t (*eventRecord)(flagcxEvent_t event, flagcxStream_t stream);
flagcxResult_t (*eventSynchronize)(flagcxEvent_t event);
Expand Down
6 changes: 4 additions & 2 deletions flagcx/core/group.cc
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,8 @@ static flagcxResult_t groupLaunch(struct flagcxAsyncJob *job_) {
&op->args.regHandle));
// we don't use semaphore tracking for device func for the moment
if (deviceAsyncLoad && deviceAsyncStore) {
FLAGCXCHECK(deviceAdaptor->eventCreate(&op->event));
FLAGCXCHECK(deviceAdaptor->eventCreate(&op->event,
flagcxEventDisableTiming));
FLAGCXCHECK(deviceAdaptor->eventRecord(op->event, op->stream));
std::vector<void *> argList;
FLAGCXCHECK(deviceAdaptor->deviceMalloc(
Expand Down Expand Up @@ -247,7 +248,8 @@ static flagcxResult_t groupLaunch(struct flagcxAsyncJob *job_) {
// we don't use semaphore tracking for device func for the moment
if (deviceAsyncLoad && deviceAsyncStore) {
std::vector<void *> argList;
FLAGCXCHECK(deviceAdaptor->eventCreate(&op->event));
FLAGCXCHECK(deviceAdaptor->eventCreate(&op->event,
flagcxEventDisableTiming));
FLAGCXCHECK(deviceAdaptor->eventRecord(op->event, op->stream));
FLAGCXCHECK(deviceAdaptor->deviceMalloc(
(void **)&op->args.dlArgs, sizeof(bool), flagcxMemDevice,
Expand Down
2 changes: 1 addition & 1 deletion flagcx/core/launch_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ struct flagcxHostSemaphore {
flagcxEvent_t getEvent() {
events.push_back(nullptr);
auto &event = events.back();
deviceAdaptor->eventCreate(&event);
deviceAdaptor->eventCreate(&event, flagcxEventDisableTiming);
return event;
}
void signalFlag() { __atomic_store_n(&flag, 1, __ATOMIC_RELEASE); }
Expand Down
6 changes: 4 additions & 2 deletions flagcx/core/transport.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ flagcxResult_t flagcxTransportP2pSetup(struct flagcxHeteroComm *comm,
sizeof(flagcxIbHandle));
deviceAdaptor->streamCreate(&resources->cpStream);
for (int s = 0; s < MAXSTEPS; s++) {
deviceAdaptor->eventCreate(&resources->cpEvents[s]);
deviceAdaptor->eventCreate(&resources->cpEvents[s],
flagcxEventDisableTiming);
}
resources->buffSizes[0] = REGMRBUFFERSIZE;
if (comm->netAdaptor == getUnifiedNetAdaptor(SOCKET)) {
Expand Down Expand Up @@ -70,7 +71,8 @@ flagcxResult_t flagcxTransportP2pSetup(struct flagcxHeteroComm *comm,
handle->stage.comm = comm;
deviceAdaptor->streamCreate(&resources->cpStream);
for (int s = 0; s < MAXSTEPS; s++) {
deviceAdaptor->eventCreate(&resources->cpEvents[s]);
deviceAdaptor->eventCreate(&resources->cpEvents[s],
flagcxEventDisableTiming);
}
resources->buffSizes[0] = REGMRBUFFERSIZE;
if (comm->netAdaptor == getUnifiedNetAdaptor(SOCKET)) {
Expand Down
8 changes: 7 additions & 1 deletion flagcx/include/flagcx.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ typedef enum {
flagcxMemManaged = 2
} flagcxMemType_t;

typedef enum {
flagcxEventDefault = 0,
flagcxEventDisableTiming = 1
} flagcxEventType_t;

// TODO: add more vendor types
typedef enum {
FLAGCX_VENDOR_NVIDIA = 0,
Expand Down Expand Up @@ -143,7 +148,8 @@ struct flagcxDeviceHandle {
flagcxResult_t (*streamQuery)(flagcxStream_t stream);
flagcxResult_t (*streamWaitEvent)(flagcxStream_t stream, flagcxEvent_t event);
// Event functions
flagcxResult_t (*eventCreate)(flagcxEvent_t *event);
flagcxResult_t (*eventCreate)(flagcxEvent_t *event,
flagcxEventType_t eventType);
flagcxResult_t (*eventDestroy)(flagcxEvent_t event);
flagcxResult_t (*eventRecord)(flagcxEvent_t event, flagcxStream_t stream);
flagcxResult_t (*eventSynchronize)(flagcxEvent_t event);
Expand Down
4 changes: 2 additions & 2 deletions flagcx/service/timer.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ struct flagcxRecord {

template <typename T>
flagcxRecord<T>::flagcxRecord() : duration(0.0f) {
deviceAdaptor->eventCreate(&beginEvent);
deviceAdaptor->eventCreate(&endEvent);
deviceAdaptor->eventCreate(&beginEvent, flagcxEventDefault);
deviceAdaptor->eventCreate(&endEvent, flagcxEventDefault);
}

template <typename T>
Expand Down
3 changes: 2 additions & 1 deletion plugin/interservice/flagcx_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
flagcxRedOp_t = ctypes.c_int
flagcxMemcpyType_t = ctypes.c_int
flagcxMemType_t = ctypes.c_int
flagcxEventType_t = ctypes.c_int

flagcxHandlerGroup_t = ctypes.c_void_p
flagcxComm_t = ctypes.c_void_p
Expand Down Expand Up @@ -65,7 +66,7 @@ class flagcxUniqueId(ctypes.Structure):
STREAM_QUERY_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, flagcxStream_t)
STREAM_WAIT_EVENT_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, flagcxStream_t, flagcxEvent_t)

EVENT_CREATE_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, ctypes.POINTER(flagcxEvent_t))
EVENT_CREATE_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, ctypes.POINTER(flagcxEvent_t), flagcxEventType_t)
EVENT_DESTROY_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, flagcxEvent_t)
EVENT_RECORD_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, flagcxEvent_t, flagcxStream_t)
EVENT_SYNCHRONIZE_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, flagcxEvent_t)
Expand Down