diff --git a/src/plugins/ucx/ucx_backend.cpp b/src/plugins/ucx/ucx_backend.cpp index f18a3f60b..7e946272f 100644 --- a/src/plugins/ucx/ucx_backend.cpp +++ b/src/plugins/ucx/ucx_backend.cpp @@ -750,8 +750,8 @@ nixlUcxChunkBackendH::complete(nixl_status_t status) { sharedState_->status.store(status); } sharedState_->pendingReqs.fetch_sub(1); - setWorker(nullptr, UINT64_MAX); NIXL_TRACE << *this << " completed with status: " << status << ", " << *sharedState_; + setWorker(nullptr, UINT64_MAX); sharedState_.reset(); } @@ -790,7 +790,7 @@ class nixlUcxCompositeBackendH : public nixlUcxBackendH { size_t getNumChunks() const { - return sharedState_->chunks.size(); + return sharedState_ ? sharedState_->chunks.size() : 0; } void @@ -816,12 +816,14 @@ class nixlUcxCompositeBackendH : public nixlUcxBackendH { release() override { NIXL_TRACE << *this << " releasing"; nixl_status_t status = nixlUcxBackendH::release(); - // Set failed status to stop progress chunks - sharedState_->status.store(NIXL_ERR_NOT_FOUND); + if (sharedState_) { + // Set failed status to stop progress chunks + sharedState_->status.store(NIXL_ERR_NOT_FOUND); + // Reset shared state - it will be effectively released when the last chunk + // resets the shared state pointer + sharedState_.reset(); + } - // Reset shared state - it will be effectively released when the last chunk - // resets the shared state pointer - sharedState_.reset(); return status; } @@ -1364,25 +1366,23 @@ nixl_status_t nixlUcxEngine::unloadMD (nixlBackendMD* input) { *****************************************/ static nixl_status_t -_retHelper(nixl_status_t ret, - nixlUcxBackendH *hndl, - nixlUcxReq &req, - ucx_connection_ptr_t conn = nullptr) { +_retHelper(nixl_status_t ret, nixlUcxBackendH *hndl, nixlUcxReq &req, ucx_connection_ptr_t conn) { /* if transfer wasn't immediately completed */ switch(ret) { - case NIXL_IN_PROG: - // TODO: this cast does not look safe - // We need to allocate a vector of nixlUcxIntReq and set nixlUcxReqt - hndl->append((nixlUcxIntReq *)req); - nixlUcxReqSetConnection(req, conn); - case NIXL_SUCCESS: - // Nothing to do - break; - default: - // Error. Release all previously initiated ops and exit: - hndl->release(); - return ret; + case NIXL_IN_PROG: + // TODO: this cast does not look safe + // We need to allocate a vector of nixlUcxIntReq and set nixlUcxReqt + hndl->append((nixlUcxIntReq *)req); + nixlUcxReqSetConnection(req, conn); + case NIXL_SUCCESS: + // Nothing to do + break; + default: + // Error. Release all previously initiated ops and exit: + hndl->release(); + return ret; } + return NIXL_SUCCESS; } @@ -1564,7 +1564,8 @@ nixlUcxEngine::postXfer(const nixl_xfer_op_t &operation, if (ret == NIXL_SUCCESS) { nixlUcxReq req; auto rmd = (nixlUcxPublicMetadata *)remote[0].metadataP; - ret = notifSendPriv(remote_agent, opt_args->notifMsg, req, int_handle->getWorkerId()); + ret = notifSendPriv( + remote_agent, opt_args->notifMsg, req, rmd->conn->getEp(int_handle->getWorkerId())); if (_retHelper(ret, int_handle, req, rmd->conn)) { return ret; } @@ -1581,23 +1582,33 @@ nixlUcxEngine::postXfer(const nixl_xfer_op_t &operation, nixl_status_t nixlUcxEngine::checkXfer (nixlBackendReqH* handle) const { nixlUcxBackendH *intHandle = (nixlUcxBackendH *)handle; - size_t workerId = intHandle->getWorkerId(); - - nixl_status_t status = intHandle->status(); auto& notif = intHandle->notification(); - if (status == NIXL_SUCCESS && notif.has_value()) { - nixlUcxReq req; - status = notifSendPriv(notif->agent, notif->payload, req, workerId); - notif.reset(); - // TODO: conn lookup - if (_retHelper(status, intHandle, req, nullptr)) { - return status; + nixl_status_t handle_status = intHandle->status(); + + if ((handle_status != NIXL_SUCCESS) || !notif.has_value()) { + if (handle_status != NIXL_IN_PROG) { // error flow + notif.reset(); } - status = intHandle->status(); + return handle_status; } - return status; + ucx_connection_ptr_t conn = getConnection(notif->agent); + if (!conn) { + notif.reset(); + return NIXL_ERR_NOT_FOUND; + } + + nixlUcxReq req; + nixl_status_t status = + notifSendPriv(notif->agent, notif->payload, req, conn->getEp(intHandle->getWorkerId())); + notif.reset(); + status = _retHelper(status, intHandle, req, conn); + if (status != NIXL_SUCCESS) { + return status; + } + + return intHandle->status(); } nixl_status_t nixlUcxEngine::releaseReqH(nixlBackendReqH* handle) const @@ -1624,30 +1635,21 @@ int nixlUcxEngine::progress() { *****************************************/ //agent will provide cached msg -nixl_status_t nixlUcxEngine::notifSendPriv(const std::string &remote_agent, - const std::string &msg, - nixlUcxReq &req, - size_t worker_id) const -{ +nixl_status_t +nixlUcxEngine::notifSendPriv(const std::string &remote_agent, + const std::string &msg, + nixlUcxReq &req, + const std::unique_ptr &ep) const { nixlSerDes ser_des; nixl_status_t ret; - auto search = remoteConnMap.find(remote_agent); - - if(search == remoteConnMap.end()) { - //TODO: err: remote connection not found - return NIXL_ERR_NOT_FOUND; - } - ser_des.addStr("name", localAgent); ser_des.addStr("msg", msg); // TODO: replace with mpool for performance auto buffer = std::make_unique(ser_des.exportStr()); - ret = search->second->getEp(worker_id)->sendAm(NOTIF_STR, NULL, 0, - (void*)buffer->data(), buffer->size(), - UCP_AM_SEND_FLAG_EAGER, req); - + ret = ep->sendAm( + NOTIF_STR, NULL, 0, (void *)buffer->data(), buffer->size(), UCP_AM_SEND_FLAG_EAGER, req); if (ret == NIXL_IN_PROG) { nixlUcxIntReq* nReq = (nixlUcxIntReq*)req; nReq->amBuffer = std::move(buffer); @@ -1655,6 +1657,12 @@ nixl_status_t nixlUcxEngine::notifSendPriv(const std::string &remote_agent, return ret; } +ucx_connection_ptr_t +nixlUcxEngine::getConnection(const std::string &remote_agent) const { + auto search = remoteConnMap.find(remote_agent); + return (search != remoteConnMap.end()) ? search->second : nullptr; +} + void nixlUcxEngine::appendNotif(std::string remote_name, std::string msg) { notifMainList.emplace_back(std::move(remote_name), std::move(msg)); @@ -1702,14 +1710,17 @@ nixl_status_t nixlUcxEngine::genNotif(const std::string &remote_agent, const std { nixl_status_t ret; nixlUcxReq req; - size_t wid = getWorkerId(); - ret = notifSendPriv(remote_agent, msg, req, wid); + auto conn = getConnection(remote_agent); + if (!conn) { + return NIXL_ERR_NOT_FOUND; + } + ret = notifSendPriv(remote_agent, msg, req, conn->getEp(getWorkerId())); switch(ret) { case NIXL_IN_PROG: /* do not track the request */ - getWorker(wid)->reqRelease(req); + getWorker(getWorkerId())->reqRelease(req); case NIXL_SUCCESS: break; default: diff --git a/src/plugins/ucx/ucx_backend.h b/src/plugins/ucx/ucx_backend.h index 6dc9cd1ef..6c33eb318 100644 --- a/src/plugins/ucx/ucx_backend.h +++ b/src/plugins/ucx/ucx_backend.h @@ -271,7 +271,10 @@ class nixlUcxEngine : public nixlBackendEngine { notifSendPriv(const std::string &remote_agent, const std::string &msg, nixlUcxReq &req, - size_t worker_id) const; + const std::unique_ptr &ep) const; + + ucx_connection_ptr_t + getConnection(const std::string &remote_agent) const; /* UCX data */ std::unique_ptr uc; diff --git a/src/utils/ucx/ucx_utils.cpp b/src/utils/ucx/ucx_utils.cpp index af898312a..271e7ee02 100644 --- a/src/utils/ucx/ucx_utils.cpp +++ b/src/utils/ucx/ucx_utils.cpp @@ -235,14 +235,17 @@ nixl_status_t nixlUcxEp::sendAm(unsigned msg_id, void* buffer, size_t len, uint32_t flags, nixlUcxReq &req) { - ucs_status_ptr_t request; + nixl_status_t status = checkTxState(); + if (status != NIXL_SUCCESS) { + return status; + } + ucp_request_param_t param = {0}; param.op_attr_mask |= UCP_OP_ATTR_FIELD_FLAGS; param.flags = flags; - request = ucp_am_send_nbx(eph, msg_id, hdr, hdr_len, buffer, len, ¶m); - + ucs_status_ptr_t request = ucp_am_send_nbx(eph, msg_id, hdr, hdr_len, buffer, len, ¶m); if (UCS_PTR_IS_PTR(request)) { req = (void*)request; return NIXL_IN_PROG; diff --git a/test/gtest/error_handling.cpp b/test/gtest/error_handling.cpp index d702b1b9d..61694d460 100644 --- a/test/gtest/error_handling.cpp +++ b/test/gtest/error_handling.cpp @@ -122,18 +122,28 @@ class TestErrorHandling : public testing::TestWithParam void testXfer(); private: - template bool isFailure(size_t iter); + template + bool + failBeforePost(size_t iter); + template + bool + failAfterPost(size_t iter); + template + bool + isFailure(size_t iter); template size_t numIter(); void exchangeMetaData(); + template std::variant - postXfer(enum nixl_xfer_op_t op, bool target_failure); + postXfer(enum nixl_xfer_op_t op, size_t iter); ScopedEnv m_env; Agent m_Initiator; @@ -256,11 +266,10 @@ void TestErrorHandling::testXfer() { exchangeMetaData(); for (size_t i = 0; i < numIter(); ++i) { - auto result = postXfer(op, isFailure(i)); nixl_status_t status; - + auto result = postXfer(op, i); if (std::holds_alternative(result)) { - // Transfer failed immediately + // Transfer completed immediately status = std::get(result); } else { // Transfer was posted, wait for completion @@ -269,7 +278,12 @@ void TestErrorHandling::testXfer() { } if (isFailure(i)) { - EXPECT_EQ(NIXL_ERR_REMOTE_DISCONNECT, status); + if (failBeforePost(i)) { + EXPECT_EQ(status, NIXL_ERR_REMOTE_DISCONNECT); + } else { + EXPECT_TRUE((status == NIXL_ERR_REMOTE_DISCONNECT) || (status == NIXL_SUCCESS)); + } + if (test_type == TestType::XFER_FAIL_RESTORE) { m_Target.init(target_name, m_backend_name, numWorkers_, numThreads_); exchangeMetaData(); @@ -293,22 +307,40 @@ void TestErrorHandling::testXfer() { return; case TestType::LOAD_REMOTE_THEN_FAIL: case TestType::XFER_THEN_FAIL: + case TestType::FAIL_AFTER_POST: m_Initiator.destroy(); return; } } template -bool TestErrorHandling::isFailure(size_t iter) { +bool +TestErrorHandling::failBeforePost(size_t iter) { switch (test_type) { - case TestType::BASIC_XFER: return false; - case TestType::LOAD_REMOTE_THEN_FAIL: return iter == 0; + case TestType::BASIC_XFER: + return false; + case TestType::LOAD_REMOTE_THEN_FAIL: + return iter == 0; case TestType::XFER_THEN_FAIL: case TestType::XFER_FAIL_RESTORE: return iter == 1; + case TestType::FAIL_AFTER_POST: + return false; } } +template +bool +TestErrorHandling::failAfterPost(size_t iter) { + return (test_type == TestType::FAIL_AFTER_POST) && (iter == 1); +} + +template +bool +TestErrorHandling::isFailure(size_t iter) { + return failBeforePost(iter) || failAfterPost(iter); +} + template size_t TestErrorHandling::numIter() { @@ -317,6 +349,7 @@ TestErrorHandling::numIter() { case TestType::LOAD_REMOTE_THEN_FAIL: return 1; case TestType::XFER_THEN_FAIL: + case TestType::FAIL_AFTER_POST: return 2; case TestType::XFER_FAIL_RESTORE: return 3; @@ -328,8 +361,9 @@ void TestErrorHandling::exchangeMetaData() { m_Target.loadRemoteMD(m_Initiator.getLocalMD()); } +template std::variant -TestErrorHandling::postXfer(enum nixl_xfer_op_t op, bool target_failure) { +TestErrorHandling::postXfer(enum nixl_xfer_op_t op, size_t iter) { EXPECT_TRUE(op == NIXL_WRITE || op == NIXL_READ); nixlBasicDesc sReq_src; @@ -341,34 +375,30 @@ TestErrorHandling::postXfer(enum nixl_xfer_op_t op, bool target_failure) { m_Target.fillRegList(rReq_descs, rReq_dst); nixlXferReqH* req_handle; - nixl_status_t status; - - status = m_Initiator.createXferReq(op, sReq_descs, rReq_descs, req_handle); + nixl_status_t status = m_Initiator.createXferReq(op, sReq_descs, rReq_descs, req_handle); EXPECT_EQ(NIXL_SUCCESS, status) << "createXferReq failed with unexpected error: " << nixlEnumStrings::statusStr(status); - if (target_failure) { + if (failBeforePost(iter)) { m_Target.destroy(); } status = m_Initiator.postXferReq(req_handle); - if (target_failure) { - // If the target is destroyed, the transfer may fail immediately - // or later - if (status == NIXL_ERR_REMOTE_DISCONNECT) { - // failed handle destroyed on post - return status; - } - EXPECT_EQ(NIXL_IN_PROG, status) << "status: " << nixlEnumStrings::statusStr(status); - } else { - EXPECT_LE(0, status) << "status: " - << nixlEnumStrings::statusStr(status); + if (failAfterPost(iter)) { + m_Target.destroy(); + } + + if (isFailure(iter) && (status == NIXL_ERR_REMOTE_DISCONNECT)) { + // failed handle destroyed on post + return status; } + EXPECT_LE(0, status) << "status: " << nixlEnumStrings::statusStr(status); return req_handle; } + TEST_P(TestErrorHandling, BasicXfer) { testXfer(); testXfer(); @@ -389,6 +419,11 @@ TEST_P(TestErrorHandling, XferFailRestore) { testXfer(); } +TEST_P(TestErrorHandling, XferPostThenFail) { + testXfer(); + testXfer(); +} + INSTANTIATE_TEST_SUITE_P(ucx, TestErrorHandling, testing::Values(std::make_tuple("UCX", 1, 0))); INSTANTIATE_TEST_SUITE_P(ucx_mo, TestErrorHandling,