Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 66 additions & 55 deletions src/plugins/ucx/ucx_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -750,8 +750,8 @@ nixlUcxChunkBackendH::complete(nixl_status_t status) {
sharedState_->status.store(status);
}
sharedState_->pendingReqs.fetch_sub(1);
setWorker(nullptr, UINT64_MAX);
NIXL_TRACE << *this << " completed with status: " << status << ", " << *sharedState_;
setWorker(nullptr, UINT64_MAX);
sharedState_.reset();
}

Expand Down Expand Up @@ -790,7 +790,7 @@ class nixlUcxCompositeBackendH : public nixlUcxBackendH {

size_t
getNumChunks() const {
return sharedState_->chunks.size();
return sharedState_ ? sharedState_->chunks.size() : 0;
}

void
Expand All @@ -816,12 +816,14 @@ class nixlUcxCompositeBackendH : public nixlUcxBackendH {
release() override {
NIXL_TRACE << *this << " releasing";
nixl_status_t status = nixlUcxBackendH::release();
// Set failed status to stop progress chunks
sharedState_->status.store(NIXL_ERR_NOT_FOUND);
if (sharedState_) {
// Set failed status to stop progress chunks
sharedState_->status.store(NIXL_ERR_NOT_FOUND);
// Reset shared state - it will be effectively released when the last chunk
// resets the shared state pointer
sharedState_.reset();
}

// Reset shared state - it will be effectively released when the last chunk
// resets the shared state pointer
sharedState_.reset();
return status;
}

Expand Down Expand Up @@ -1364,25 +1366,23 @@ nixl_status_t nixlUcxEngine::unloadMD (nixlBackendMD* input) {
*****************************************/

static nixl_status_t
_retHelper(nixl_status_t ret,
nixlUcxBackendH *hndl,
nixlUcxReq &req,
ucx_connection_ptr_t conn = nullptr) {
_retHelper(nixl_status_t ret, nixlUcxBackendH *hndl, nixlUcxReq &req, ucx_connection_ptr_t conn) {
/* if transfer wasn't immediately completed */
switch(ret) {
case NIXL_IN_PROG:
// TODO: this cast does not look safe
// We need to allocate a vector of nixlUcxIntReq and set nixlUcxReqt
hndl->append((nixlUcxIntReq *)req);
nixlUcxReqSetConnection(req, conn);
case NIXL_SUCCESS:
// Nothing to do
break;
default:
// Error. Release all previously initiated ops and exit:
hndl->release();
return ret;
case NIXL_IN_PROG:
// TODO: this cast does not look safe
// We need to allocate a vector of nixlUcxIntReq and set nixlUcxReqt
hndl->append((nixlUcxIntReq *)req);
nixlUcxReqSetConnection(req, conn);
case NIXL_SUCCESS:
// Nothing to do
break;
default:
// Error. Release all previously initiated ops and exit:
hndl->release();
return ret;
}

return NIXL_SUCCESS;
}

Expand Down Expand Up @@ -1564,7 +1564,8 @@ nixlUcxEngine::postXfer(const nixl_xfer_op_t &operation,
if (ret == NIXL_SUCCESS) {
nixlUcxReq req;
auto rmd = (nixlUcxPublicMetadata *)remote[0].metadataP;
ret = notifSendPriv(remote_agent, opt_args->notifMsg, req, int_handle->getWorkerId());
ret = notifSendPriv(
remote_agent, opt_args->notifMsg, req, rmd->conn->getEp(int_handle->getWorkerId()));
if (_retHelper(ret, int_handle, req, rmd->conn)) {
return ret;
}
Expand All @@ -1581,23 +1582,33 @@ nixlUcxEngine::postXfer(const nixl_xfer_op_t &operation,
nixl_status_t nixlUcxEngine::checkXfer (nixlBackendReqH* handle) const
{
nixlUcxBackendH *intHandle = (nixlUcxBackendH *)handle;
size_t workerId = intHandle->getWorkerId();

nixl_status_t status = intHandle->status();
auto& notif = intHandle->notification();
if (status == NIXL_SUCCESS && notif.has_value()) {
nixlUcxReq req;
status = notifSendPriv(notif->agent, notif->payload, req, workerId);
notif.reset();
// TODO: conn lookup
if (_retHelper(status, intHandle, req, nullptr)) {
return status;
nixl_status_t handle_status = intHandle->status();

if ((handle_status != NIXL_SUCCESS) || !notif.has_value()) {
if (handle_status != NIXL_IN_PROG) { // error flow
notif.reset();
}

status = intHandle->status();
return handle_status;
}

return status;
ucx_connection_ptr_t conn = getConnection(notif->agent);
if (!conn) {
notif.reset();
return NIXL_ERR_NOT_FOUND;
}

nixlUcxReq req;
nixl_status_t status =
notifSendPriv(notif->agent, notif->payload, req, conn->getEp(intHandle->getWorkerId()));
notif.reset();
status = _retHelper(status, intHandle, req, conn);
if (status != NIXL_SUCCESS) {
return status;
}

return intHandle->status();
}

nixl_status_t nixlUcxEngine::releaseReqH(nixlBackendReqH* handle) const
Expand All @@ -1624,37 +1635,34 @@ int nixlUcxEngine::progress() {
*****************************************/

//agent will provide cached msg
nixl_status_t nixlUcxEngine::notifSendPriv(const std::string &remote_agent,
const std::string &msg,
nixlUcxReq &req,
size_t worker_id) const
{
nixl_status_t
nixlUcxEngine::notifSendPriv(const std::string &remote_agent,
const std::string &msg,
nixlUcxReq &req,
const std::unique_ptr<nixlUcxEp> &ep) const {
nixlSerDes ser_des;
nixl_status_t ret;

auto search = remoteConnMap.find(remote_agent);

if(search == remoteConnMap.end()) {
//TODO: err: remote connection not found
return NIXL_ERR_NOT_FOUND;
}

ser_des.addStr("name", localAgent);
ser_des.addStr("msg", msg);
// TODO: replace with mpool for performance

auto buffer = std::make_unique<std::string>(ser_des.exportStr());
ret = search->second->getEp(worker_id)->sendAm(NOTIF_STR, NULL, 0,
(void*)buffer->data(), buffer->size(),
UCP_AM_SEND_FLAG_EAGER, req);

ret = ep->sendAm(
NOTIF_STR, NULL, 0, (void *)buffer->data(), buffer->size(), UCP_AM_SEND_FLAG_EAGER, req);
if (ret == NIXL_IN_PROG) {
nixlUcxIntReq* nReq = (nixlUcxIntReq*)req;
nReq->amBuffer = std::move(buffer);
}
return ret;
}

ucx_connection_ptr_t
nixlUcxEngine::getConnection(const std::string &remote_agent) const {
auto search = remoteConnMap.find(remote_agent);
return (search != remoteConnMap.end()) ? search->second : nullptr;
}

void
nixlUcxEngine::appendNotif(std::string remote_name, std::string msg) {
notifMainList.emplace_back(std::move(remote_name), std::move(msg));
Expand Down Expand Up @@ -1702,14 +1710,17 @@ nixl_status_t nixlUcxEngine::genNotif(const std::string &remote_agent, const std
{
nixl_status_t ret;
nixlUcxReq req;
size_t wid = getWorkerId();

ret = notifSendPriv(remote_agent, msg, req, wid);
auto conn = getConnection(remote_agent);
if (!conn) {
return NIXL_ERR_NOT_FOUND;
}

ret = notifSendPriv(remote_agent, msg, req, conn->getEp(getWorkerId()));
switch(ret) {
case NIXL_IN_PROG:
/* do not track the request */
getWorker(wid)->reqRelease(req);
getWorker(getWorkerId())->reqRelease(req);
case NIXL_SUCCESS:
break;
default:
Expand Down
5 changes: 4 additions & 1 deletion src/plugins/ucx/ucx_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,10 @@ class nixlUcxEngine : public nixlBackendEngine {
notifSendPriv(const std::string &remote_agent,
const std::string &msg,
nixlUcxReq &req,
size_t worker_id) const;
const std::unique_ptr<nixlUcxEp> &ep) const;

ucx_connection_ptr_t
getConnection(const std::string &remote_agent) const;

/* UCX data */
std::unique_ptr<nixlUcxContext> uc;
Expand Down
9 changes: 6 additions & 3 deletions src/utils/ucx/ucx_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,14 +235,17 @@ nixl_status_t nixlUcxEp::sendAm(unsigned msg_id,
void* buffer, size_t len,
uint32_t flags, nixlUcxReq &req)
{
ucs_status_ptr_t request;
nixl_status_t status = checkTxState();
if (status != NIXL_SUCCESS) {
return status;
}

ucp_request_param_t param = {0};

param.op_attr_mask |= UCP_OP_ATTR_FIELD_FLAGS;
param.flags = flags;

request = ucp_am_send_nbx(eph, msg_id, hdr, hdr_len, buffer, len, &param);

ucs_status_ptr_t request = ucp_am_send_nbx(eph, msg_id, hdr, hdr_len, buffer, len, &param);
if (UCS_PTR_IS_PTR(request)) {
req = (void*)request;
return NIXL_IN_PROG;
Expand Down
Loading