
Commit d33d8a9

Advanced request handling optimizations (#1009)
1 parent 8206602 commit d33d8a9

6 files changed: +141 −117 lines

src/plugins/ucx/ucx_backend.cpp

Lines changed: 120 additions & 98 deletions
@@ -265,40 +265,14 @@ void nixlUcxEngine::vramFiniCtx()
     cudaCtx.reset();
 }
 
-/****************************************
- * UCX request management
- *****************************************/
-
-
-class nixlUcxIntReq {
-public:
-    operator nixlUcxReq() noexcept {
-        return static_cast<nixlUcxReq>(this);
-    }
-
-    void
-    setConnection(nixlUcxConnection *conn) {
-        conn_ = conn;
-    }
-
-    nixl_status_t
-    checkConnection(size_t ep_id) const {
-        NIXL_ASSERT(conn_) << "Connection is not set";
-        return conn_->getEp(ep_id)->checkTxState();
-    }
-
-private:
-    nixlUcxConnection *conn_;
-};
-
 /****************************************
  * Backend request management
  *****************************************/
 
 class nixlUcxBackendH : public nixlBackendReqH {
 private:
     std::set<ucx_connection_ptr_t> connections_;
-    std::vector<nixlUcxIntReq *> requests_;
+    std::vector<nixlUcxReq> requests_;
     nixlUcxWorker *worker;
     size_t worker_id;
 
@@ -313,26 +287,54 @@ class nixlUcxBackendH : public nixlBackendReqH {
     };
     std::optional<Notif> notif;
 
-public:
-    auto& notification() {
-        return notif;
+    nixl_status_t
+    checkConnection(nixl_status_t status = NIXL_SUCCESS) const {
+        NIXL_ASSERT(!connections_.empty());
+        for (const auto &conn : connections_) {
+            nixl_status_t conn_status = conn->getEp(worker_id)->checkTxState();
+            if (conn_status != NIXL_SUCCESS) {
+                return conn_status;
+            }
+        }
+        return status;
     }
 
+public:
     nixlUcxBackendH(nixlUcxWorker *worker, size_t worker_id)
         : worker(worker),
           worker_id(worker_id) {}
 
+    auto &
+    notification() {
+        return notif;
+    }
+
     void
     reserve(size_t size) {
         requests_.reserve(size);
     }
 
-    void
-    append(nixlUcxReq req, ucx_connection_ptr_t conn) {
-        auto req_int = static_cast<nixlUcxIntReq *>(req);
-        req_int->setConnection(conn.get());
-        requests_.push_back(req_int);
+    nixl_status_t
+    append(nixl_status_t status, nixlUcxReq req, ucx_connection_ptr_t conn) {
         connections_.insert(conn);
+        switch (status) {
+        case NIXL_IN_PROG:
+            requests_.push_back(req);
+            break;
+        case NIXL_SUCCESS:
+            // Nothing to do
+            break;
+        default:
+            // Error. Release all previously initiated ops and exit:
+            release();
+            return status;
+        }
+        return NIXL_SUCCESS;
+    }
+
+    const std::set<ucx_connection_ptr_t> &
+    getConnections() const {
+        return connections_;
     }
 
     virtual bool
@@ -343,7 +345,7 @@ class nixlUcxBackendH : public nixlBackendReqH {
     virtual nixl_status_t
     release() {
         // TODO: Error log: uncompleted requests found! Cancelling ...
-        for (nixlUcxIntReq *req : requests_) {
+        for (nixlUcxReq req : requests_) {
             nixl_status_t ret = ucx_status_to_nixl(ucp_request_check_status(req));
             if (ret == NIXL_IN_PROG) {
                 // TODO: Need process this properly.
@@ -370,20 +372,19 @@ class nixlUcxBackendH : public nixlBackendReqH {
 
         /* If last request is incomplete, return NIXL_IN_PROG early without
          * checking other requests */
-        nixlUcxIntReq *req = requests_.back();
+        nixlUcxReq req = requests_.back();
         nixl_status_t ret = ucx_status_to_nixl(ucp_request_check_status(req));
         if (ret == NIXL_IN_PROG) {
             return NIXL_IN_PROG;
         } else if (ret != NIXL_SUCCESS) {
-            nixl_status_t conn_status = req->checkConnection(worker_id);
-            return (conn_status == NIXL_SUCCESS) ? ret : conn_status;
+            return checkConnection(ret);
         }
 
         /* Last request completed successfully, all the others must be in the
          * same state. TODO: remove extra checks? */
         size_t incomplete_reqs = 0;
         nixl_status_t out_ret = NIXL_SUCCESS;
-        for (nixlUcxIntReq *req : requests_) {
+        for (nixlUcxReq req : requests_) {
             nixl_status_t ret = ucx_status_to_nixl(ucp_request_check_status(req));
             if (__builtin_expect(ret == NIXL_SUCCESS, 0)) {
                 worker->reqRelease(req);
@@ -394,8 +395,7 @@ class nixlUcxBackendH : public nixlBackendReqH {
                 requests_[incomplete_reqs++] = req;
             } else {
                 // Any other ret value is ERR and will be returned
-                nixl_status_t conn_status = req->checkConnection(worker_id);
-                out_ret = (conn_status == NIXL_SUCCESS) ? ret : conn_status;
+                out_ret = checkConnection(ret);
             }
         }
 
@@ -1102,7 +1102,7 @@ nixlUcxEngine::nixlUcxEngine(const nixlBackendInitParams &init_params)
     }
 
     uc = std::make_unique<nixlUcxContext>(
-        devs, sizeof(nixlUcxIntReq), init_params.enableProgTh, num_workers, init_params.syncMode);
+        devs, init_params.enableProgTh, num_workers, init_params.syncMode);
 
     for (size_t i = 0; i < num_workers; i++) {
         uws.emplace_back(std::make_unique<nixlUcxWorker>(*uc, err_handling_mode));
@@ -1324,24 +1324,6 @@ nixl_status_t nixlUcxEngine::unloadMD (nixlBackendMD* input) {
  * Data movement
  *****************************************/
 
-static nixl_status_t
-_retHelper(nixl_status_t ret, nixlUcxBackendH *hndl, nixlUcxReq &req, ucx_connection_ptr_t conn) {
-    /* if transfer wasn't immediately completed */
-    switch(ret) {
-    case NIXL_IN_PROG:
-        hndl->append(req, conn);
-    case NIXL_SUCCESS:
-        // Nothing to do
-        break;
-    default:
-        // Error. Release all previously initiated ops and exit:
-        hndl->release();
-        return ret;
-    }
-
-    return NIXL_SUCCESS;
-}
-
 size_t
 nixlUcxEngine::getWorkerId() const {
     auto it = tlsSharedWorkerMap().find(this);
@@ -1461,6 +1443,56 @@ nixl_status_t nixlUcxEngine::estimateXferCost (const nixl_xfer_op_t &operation,
     return NIXL_SUCCESS;
 }
 
+nixlUcxEngine::batchResult
+nixlUcxEngine::sendXferRangeBatch(nixlUcxEp &ep,
+                                  nixl_xfer_op_t operation,
+                                  const nixl_meta_dlist_t &local,
+                                  const nixl_meta_dlist_t &remote,
+                                  size_t worker_id,
+                                  size_t start_idx,
+                                  size_t end_idx) {
+    batchResult result = {NIXL_SUCCESS, 0, nullptr};
+
+    for (size_t i = start_idx; i < end_idx; ++i) {
+        void *laddr = (void *)local[i].addr;
+        size_t lsize = local[i].len;
+        uint64_t raddr = static_cast<uint64_t>(remote[i].addr);
+        NIXL_ASSERT(lsize == remote[i].len);
+
+        auto lmd = static_cast<nixlUcxPrivateMetadata *>(local[i].metadataP);
+        auto rmd = static_cast<nixlUcxPublicMetadata *>(remote[i].metadataP);
+        auto &rmd_ep = rmd->conn->getEp(worker_id);
+        if (__builtin_expect(rmd_ep.get() != &ep, 0)) {
+            break;
+        }
+
+        ++result.size;
+        nixlUcxReq req;
+        nixl_status_t ret = operation == NIXL_READ ?
+            ep.read(raddr, rmd->getRkey(worker_id), laddr, lmd->mem, lsize, req) :
+            ep.write(laddr, lmd->mem, raddr, rmd->getRkey(worker_id), lsize, req);
+
+        if (ret == NIXL_IN_PROG) {
+            if (__builtin_expect(result.req != nullptr, 1)) {
+                ucp_request_free(result.req);
+            }
+            result.req = req;
+        } else if (ret != NIXL_SUCCESS) {
+            result.status = ret;
+            if (result.req != nullptr) {
+                ucp_request_free(result.req);
+                result.req = nullptr;
+            }
+            break;
+        }
+    }
+
+    if (result.status == NIXL_SUCCESS && result.req) {
+        result.status = NIXL_IN_PROG;
+    }
+    return result;
+}
+
 nixl_status_t
 nixlUcxEngine::sendXferRange(const nixl_xfer_op_t &operation,
                              const nixl_meta_dlist_t &local,
@@ -1470,54 +1502,44 @@ nixlUcxEngine::sendXferRange(const nixl_xfer_op_t &operation,
                              size_t start_idx,
                              size_t end_idx) const {
     nixlUcxBackendH *intHandle = (nixlUcxBackendH *)handle;
-    nixlUcxPrivateMetadata *lmd;
-    nixlUcxPublicMetadata *rmd;
-    nixl_status_t ret;
-    nixlUcxReq req;
     size_t workerId = intHandle->getWorkerId();
+    nixl_status_t ret;
 
-    // Reserve space for the requests, +2 for flush and completion
-    intHandle->reserve(end_idx - start_idx + 2);
+    if (operation != NIXL_WRITE && operation != NIXL_READ) {
+        return NIXL_ERR_INVALID_PARAM;
+    }
 
-    for (size_t i = start_idx; i < end_idx; i++) {
-        void *laddr = (void*) local[i].addr;
-        size_t lsize = local[i].len;
-        uint64_t raddr = (uint64_t)remote[i].addr;
-        size_t rsize = remote[i].len;
+    /* Assuming we have a single EP, we need 3 requests: one pending request,
+     * one flush request, and one notification request */
+    intHandle->reserve(3);
 
-        lmd = (nixlUcxPrivateMetadata*) local[i].metadataP;
-        rmd = (nixlUcxPublicMetadata*) remote[i].metadataP;
+    for (size_t i = start_idx; i < end_idx;) {
+        /* Send requests to a single EP */
+        auto rmd = static_cast<nixlUcxPublicMetadata *>(remote[i].metadataP);
         auto &ep = rmd->conn->getEp(workerId);
+        auto result = sendXferRangeBatch(*ep, operation, local, remote, workerId, i, end_idx);
 
-        if (lsize != rsize) {
-            return NIXL_ERR_INVALID_PARAM;
-        }
-
-        switch (operation) {
-        case NIXL_READ:
-            ret = ep->read(raddr, rmd->getRkey(workerId), laddr, lmd->mem, lsize, req);
-            break;
-        case NIXL_WRITE:
-            ret = ep->write(laddr, lmd->mem, raddr, rmd->getRkey(workerId), lsize, req);
-            break;
-        default:
-            return NIXL_ERR_INVALID_PARAM;
-        }
-
-        if (_retHelper(ret, intHandle, req, rmd->conn)) {
+        /* Append a single pending request for the entire EP batch */
+        ret = intHandle->append(result.status, result.req, rmd->conn);
+        if (ret != NIXL_SUCCESS) {
             return ret;
         }
+
+        i += result.size;
     }
 
     /*
      * Flush keeps intHandle non-empty until the operation is actually
      * completed, which can happen after local requests completion.
+     * We need to flush all distinct connections to ensure that the operation
+     * is actually completed.
      */
-    rmd = (nixlUcxPublicMetadata *)remote[0].metadataP;
-    ret = rmd->conn->getEp(workerId)->flushEp(req);
-
-    if (_retHelper(ret, intHandle, req, rmd->conn)) {
-        return ret;
+    for (auto &conn : intHandle->getConnections()) {
+        nixlUcxReq req;
+        ret = conn->getEp(workerId)->flushEp(req);
+        if (intHandle->append(ret, req, conn) != NIXL_SUCCESS) {
+            return ret;
+        }
     }
 
     return NIXL_SUCCESS;
@@ -1557,7 +1579,7 @@ nixlUcxEngine::postXfer(const nixl_xfer_op_t &operation,
                              opt_args->notifMsg,
                              rmd->conn->getEp(int_handle->getWorkerId()),
                              &req);
-        if (_retHelper(ret, int_handle, req, rmd->conn)) {
+        if (int_handle->append(ret, req, rmd->conn) != NIXL_SUCCESS) {
             return ret;
         }
 
@@ -1594,8 +1616,8 @@ nixl_status_t nixlUcxEngine::checkXfer (nixlBackendReqH* handle) const
         nixl_status_t status =
             notifSendPriv(notif->agent, notif->payload, conn->getEp(intHandle->getWorkerId()), &req);
         notif.reset();
-        status = _retHelper(status, intHandle, req, conn);
-        if (status != NIXL_SUCCESS) {
+
+        if (intHandle->append(status, req, conn) != NIXL_SUCCESS) {
             return status;
         }
 
src/plugins/ucx/ucx_backend.h

Lines changed: 15 additions & 0 deletions
@@ -294,6 +294,21 @@ class nixlUcxEngine : public nixlBackendEngine {
     ucx_connection_ptr_t
     getConnection(const std::string &remote_agent) const;
 
+    struct batchResult {
+        nixl_status_t status;
+        size_t size;
+        nixlUcxReq req;
+    };
+
+    static batchResult
+    sendXferRangeBatch(nixlUcxEp &ep,
+                       nixl_xfer_op_t operation,
+                       const nixl_meta_dlist_t &local,
+                       const nixl_meta_dlist_t &remote,
+                       size_t worker_id,
+                       size_t start_idx,
+                       size_t end_idx);
+
     /* UCX data */
     std::unique_ptr<nixlUcxContext> uc;
     std::vector<std::unique_ptr<nixlUcxWorker>> uws;

src/utils/ucx/ucx_utils.cpp

Lines changed: 0 additions & 6 deletions
@@ -406,7 +406,6 @@ bool nixlUcxMtLevelIsSupported(const nixl_ucx_mt_t mt_type) noexcept
 }
 
 nixlUcxContext::nixlUcxContext(std::vector<std::string> devs,
-                               size_t req_size,
                                bool prog_thread,
                                unsigned long num_workers,
                                nixl_thread_sync_t sync_mode) {
@@ -429,11 +428,6 @@ nixlUcxContext::nixlUcxContext(std::vector<std::string> devs,
     ucp_params.features |= UCP_FEATURE_WAKEUP;
     ucp_params.mt_workers_shared = num_workers > 1 ? 1 : 0;
 
-    if (req_size) {
-        ucp_params.request_size = req_size;
-        ucp_params.field_mask |= UCP_PARAM_FIELD_REQUEST_SIZE;
-    }
-
     nixl::ucx::config config;
 
     /* If requested, restrict the set of network devices */

src/utils/ucx/ucx_utils.h

Lines changed: 0 additions & 1 deletion
@@ -199,7 +199,6 @@ class nixlUcxContext {
 
 public:
     nixlUcxContext(std::vector<std::string> devices,
-                   size_t req_size,
                    bool prog_thread,
                    unsigned long num_workers,
                    nixl_thread_sync_t sync_mode);
