@@ -265,40 +265,14 @@ void nixlUcxEngine::vramFiniCtx()
265265 cudaCtx.reset ();
266266}
267267
268- /* ***************************************
269- * UCX request management
270- *****************************************/
271-
272-
273- class nixlUcxIntReq {
274- public:
275- operator nixlUcxReq () noexcept {
276- return static_cast <nixlUcxReq>(this );
277- }
278-
279- void
280- setConnection (nixlUcxConnection *conn) {
281- conn_ = conn;
282- }
283-
284- nixl_status_t
285- checkConnection (size_t ep_id) const {
286- NIXL_ASSERT (conn_) << " Connection is not set" ;
287- return conn_->getEp (ep_id)->checkTxState ();
288- }
289-
290- private:
291- nixlUcxConnection *conn_;
292- };
293-
294268/* ***************************************
295269 * Backend request management
296270*****************************************/
297271
298272class nixlUcxBackendH : public nixlBackendReqH {
299273private:
300274 std::set<ucx_connection_ptr_t > connections_;
301- std::vector<nixlUcxIntReq * > requests_;
275+ std::vector<nixlUcxReq > requests_;
302276 nixlUcxWorker *worker;
303277 size_t worker_id;
304278
@@ -313,26 +287,54 @@ class nixlUcxBackendH : public nixlBackendReqH {
313287 };
314288 std::optional<Notif> notif;
315289
316- public:
317- auto & notification () {
318- return notif;
290+ nixl_status_t
291+ checkConnection (nixl_status_t status = NIXL_SUCCESS) const {
292+ NIXL_ASSERT (!connections_.empty ());
293+ for (const auto &conn : connections_) {
294+ nixl_status_t conn_status = conn->getEp (worker_id)->checkTxState ();
295+ if (conn_status != NIXL_SUCCESS) {
296+ return conn_status;
297+ }
298+ }
299+ return status;
319300 }
320301
302+ public:
321303 nixlUcxBackendH (nixlUcxWorker *worker, size_t worker_id)
322304 : worker(worker),
323305 worker_id(worker_id) {}
324306
307+ auto &
308+ notification () {
309+ return notif;
310+ }
311+
325312 void
326313 reserve (size_t size) {
327314 requests_.reserve (size);
328315 }
329316
330- void
331- append (nixlUcxReq req, ucx_connection_ptr_t conn) {
332- auto req_int = static_cast <nixlUcxIntReq *>(req);
333- req_int->setConnection (conn.get ());
334- requests_.push_back (req_int);
317+ nixl_status_t
318+ append (nixl_status_t status, nixlUcxReq req, ucx_connection_ptr_t conn) {
335319 connections_.insert (conn);
320+ switch (status) {
321+ case NIXL_IN_PROG:
322+ requests_.push_back (req);
323+ break ;
324+ case NIXL_SUCCESS:
325+ // Nothing to do
326+ break ;
327+ default :
328+ // Error. Release all previously initiated ops and exit:
329+ release ();
330+ return status;
331+ }
332+ return NIXL_SUCCESS;
333+ }
334+
335+ const std::set<ucx_connection_ptr_t > &
336+ getConnections () const {
337+ return connections_;
336338 }
337339
338340 virtual bool
@@ -343,7 +345,7 @@ class nixlUcxBackendH : public nixlBackendReqH {
343345 virtual nixl_status_t
344346 release () {
345347 // TODO: Error log: uncompleted requests found! Cancelling ...
346- for (nixlUcxIntReq * req : requests_) {
348+ for (nixlUcxReq req : requests_) {
347349 nixl_status_t ret = ucx_status_to_nixl (ucp_request_check_status (req));
348350 if (ret == NIXL_IN_PROG) {
349351 // TODO: Need process this properly.
@@ -370,20 +372,19 @@ class nixlUcxBackendH : public nixlBackendReqH {
370372
371373 /* If last request is incomplete, return NIXL_IN_PROG early without
372374 * checking other requests */
373- nixlUcxIntReq * req = requests_.back ();
375+ nixlUcxReq req = requests_.back ();
374376 nixl_status_t ret = ucx_status_to_nixl (ucp_request_check_status (req));
375377 if (ret == NIXL_IN_PROG) {
376378 return NIXL_IN_PROG;
377379 } else if (ret != NIXL_SUCCESS) {
378- nixl_status_t conn_status = req->checkConnection (worker_id);
379- return (conn_status == NIXL_SUCCESS) ? ret : conn_status;
380+ return checkConnection (ret);
380381 }
381382
382383 /* Last request completed successfully, all the others must be in the
383384 * same state. TODO: remove extra checks? */
384385 size_t incomplete_reqs = 0 ;
385386 nixl_status_t out_ret = NIXL_SUCCESS;
386- for (nixlUcxIntReq * req : requests_) {
387+ for (nixlUcxReq req : requests_) {
387388 nixl_status_t ret = ucx_status_to_nixl (ucp_request_check_status (req));
388389 if (__builtin_expect (ret == NIXL_SUCCESS, 0 )) {
389390 worker->reqRelease (req);
@@ -394,8 +395,7 @@ class nixlUcxBackendH : public nixlBackendReqH {
394395 requests_[incomplete_reqs++] = req;
395396 } else {
396397 // Any other ret value is ERR and will be returned
397- nixl_status_t conn_status = req->checkConnection (worker_id);
398- out_ret = (conn_status == NIXL_SUCCESS) ? ret : conn_status;
398+ out_ret = checkConnection (ret);
399399 }
400400 }
401401
@@ -1102,7 +1102,7 @@ nixlUcxEngine::nixlUcxEngine(const nixlBackendInitParams &init_params)
11021102 }
11031103
11041104 uc = std::make_unique<nixlUcxContext>(
1105- devs, sizeof (nixlUcxIntReq), init_params.enableProgTh , num_workers, init_params.syncMode );
1105+ devs, init_params.enableProgTh , num_workers, init_params.syncMode );
11061106
11071107 for (size_t i = 0 ; i < num_workers; i++) {
11081108 uws.emplace_back (std::make_unique<nixlUcxWorker>(*uc, err_handling_mode));
@@ -1324,24 +1324,6 @@ nixl_status_t nixlUcxEngine::unloadMD (nixlBackendMD* input) {
13241324 * Data movement
13251325*****************************************/
13261326
1327- static nixl_status_t
1328- _retHelper (nixl_status_t ret, nixlUcxBackendH *hndl, nixlUcxReq &req, ucx_connection_ptr_t conn) {
1329- /* if transfer wasn't immediately completed */
1330- switch (ret) {
1331- case NIXL_IN_PROG:
1332- hndl->append (req, conn);
1333- case NIXL_SUCCESS:
1334- // Nothing to do
1335- break ;
1336- default :
1337- // Error. Release all previously initiated ops and exit:
1338- hndl->release ();
1339- return ret;
1340- }
1341-
1342- return NIXL_SUCCESS;
1343- }
1344-
13451327size_t
13461328nixlUcxEngine::getWorkerId () const {
13471329 auto it = tlsSharedWorkerMap ().find (this );
@@ -1461,6 +1443,56 @@ nixl_status_t nixlUcxEngine::estimateXferCost (const nixl_xfer_op_t &operation,
14611443 return NIXL_SUCCESS;
14621444}
14631445
1446+ nixlUcxEngine::batchResult
1447+ nixlUcxEngine::sendXferRangeBatch (nixlUcxEp &ep,
1448+ nixl_xfer_op_t operation,
1449+ const nixl_meta_dlist_t &local,
1450+ const nixl_meta_dlist_t &remote,
1451+ size_t worker_id,
1452+ size_t start_idx,
1453+ size_t end_idx) {
1454+ batchResult result = {NIXL_SUCCESS, 0 , nullptr };
1455+
1456+ for (size_t i = start_idx; i < end_idx; ++i) {
1457+ void *laddr = (void *)local[i].addr ;
1458+ size_t lsize = local[i].len ;
1459+ uint64_t raddr = static_cast <uint64_t >(remote[i].addr );
1460+ NIXL_ASSERT (lsize == remote[i].len );
1461+
1462+ auto lmd = static_cast <nixlUcxPrivateMetadata *>(local[i].metadataP );
1463+ auto rmd = static_cast <nixlUcxPublicMetadata *>(remote[i].metadataP );
1464+ auto &rmd_ep = rmd->conn ->getEp (worker_id);
1465+ if (__builtin_expect (rmd_ep.get () != &ep, 0 )) {
1466+ break ;
1467+ }
1468+
1469+ ++result.size ;
1470+ nixlUcxReq req;
1471+ nixl_status_t ret = operation == NIXL_READ ?
1472+ ep.read (raddr, rmd->getRkey (worker_id), laddr, lmd->mem , lsize, req) :
1473+ ep.write (laddr, lmd->mem , raddr, rmd->getRkey (worker_id), lsize, req);
1474+
1475+ if (ret == NIXL_IN_PROG) {
1476+ if (__builtin_expect (result.req != nullptr , 1 )) {
1477+ ucp_request_free (result.req );
1478+ }
1479+ result.req = req;
1480+ } else if (ret != NIXL_SUCCESS) {
1481+ result.status = ret;
1482+ if (result.req != nullptr ) {
1483+ ucp_request_free (result.req );
1484+ result.req = nullptr ;
1485+ }
1486+ break ;
1487+ }
1488+ }
1489+
1490+ if (result.status == NIXL_SUCCESS && result.req ) {
1491+ result.status = NIXL_IN_PROG;
1492+ }
1493+ return result;
1494+ }
1495+
14641496nixl_status_t
14651497nixlUcxEngine::sendXferRange (const nixl_xfer_op_t &operation,
14661498 const nixl_meta_dlist_t &local,
@@ -1470,54 +1502,44 @@ nixlUcxEngine::sendXferRange(const nixl_xfer_op_t &operation,
14701502 size_t start_idx,
14711503 size_t end_idx) const {
14721504 nixlUcxBackendH *intHandle = (nixlUcxBackendH *)handle;
1473- nixlUcxPrivateMetadata *lmd;
1474- nixlUcxPublicMetadata *rmd;
1475- nixl_status_t ret;
1476- nixlUcxReq req;
14771505 size_t workerId = intHandle->getWorkerId ();
1506+ nixl_status_t ret;
14781507
1479- // Reserve space for the requests, +2 for flush and completion
1480- intHandle->reserve (end_idx - start_idx + 2 );
1508+ if (operation != NIXL_WRITE && operation != NIXL_READ) {
1509+ return NIXL_ERR_INVALID_PARAM;
1510+ }
14811511
1482- for (size_t i = start_idx; i < end_idx; i++) {
1483- void *laddr = (void *) local[i].addr ;
1484- size_t lsize = local[i].len ;
1485- uint64_t raddr = (uint64_t )remote[i].addr ;
1486- size_t rsize = remote[i].len ;
1512+ /* Assuming we have a single EP, we need 3 requests: one pending request,
1513+ * one flush request, and one notification request */
1514+ intHandle->reserve (3 );
14871515
1488- lmd = (nixlUcxPrivateMetadata*) local[i].metadataP ;
1489- rmd = (nixlUcxPublicMetadata*) remote[i].metadataP ;
1516+ for (size_t i = start_idx; i < end_idx;) {
1517+ /* Send requests to a single EP */
1518+ auto rmd = static_cast <nixlUcxPublicMetadata *>(remote[i].metadataP );
14901519 auto &ep = rmd->conn ->getEp (workerId);
1520+ auto result = sendXferRangeBatch (*ep, operation, local, remote, workerId, i, end_idx);
14911521
1492- if (lsize != rsize) {
1493- return NIXL_ERR_INVALID_PARAM;
1494- }
1495-
1496- switch (operation) {
1497- case NIXL_READ:
1498- ret = ep->read (raddr, rmd->getRkey (workerId), laddr, lmd->mem , lsize, req);
1499- break ;
1500- case NIXL_WRITE:
1501- ret = ep->write (laddr, lmd->mem , raddr, rmd->getRkey (workerId), lsize, req);
1502- break ;
1503- default :
1504- return NIXL_ERR_INVALID_PARAM;
1505- }
1506-
1507- if (_retHelper (ret, intHandle, req, rmd->conn )) {
1522+ /* Append a single pending request for the entire EP batch */
1523+ ret = intHandle->append (result.status , result.req , rmd->conn );
1524+ if (ret != NIXL_SUCCESS) {
15081525 return ret;
15091526 }
1527+
1528+ i += result.size ;
15101529 }
15111530
15121531 /*
15131532 * Flush keeps intHandle non-empty until the operation is actually
15141533 * completed, which can happen after local requests completion.
1534+ * We need to flush all distinct connections to ensure that the operation
1535+ * is actually completed.
15151536 */
1516- rmd = (nixlUcxPublicMetadata *)remote[0 ].metadataP ;
1517- ret = rmd->conn ->getEp (workerId)->flushEp (req);
1518-
1519- if (_retHelper (ret, intHandle, req, rmd->conn )) {
1520- return ret;
1537+ for (auto &conn : intHandle->getConnections ()) {
1538+ nixlUcxReq req;
1539+ ret = conn->getEp (workerId)->flushEp (req);
1540+ if (intHandle->append (ret, req, conn) != NIXL_SUCCESS) {
1541+ return ret;
1542+ }
15211543 }
15221544
15231545 return NIXL_SUCCESS;
@@ -1557,7 +1579,7 @@ nixlUcxEngine::postXfer(const nixl_xfer_op_t &operation,
15571579 opt_args->notifMsg ,
15581580 rmd->conn ->getEp (int_handle->getWorkerId ()),
15591581 &req);
1560- if (_retHelper (ret, int_handle, req, rmd->conn )) {
1582+ if (int_handle-> append (ret, req, rmd->conn ) != NIXL_SUCCESS ) {
15611583 return ret;
15621584 }
15631585
@@ -1594,8 +1616,8 @@ nixl_status_t nixlUcxEngine::checkXfer (nixlBackendReqH* handle) const
15941616 nixl_status_t status =
15951617 notifSendPriv (notif->agent , notif->payload , conn->getEp (intHandle->getWorkerId ()), &req);
15961618 notif.reset ();
1597- status = _retHelper (status, intHandle, req, conn);
1598- if (status != NIXL_SUCCESS) {
1619+
1620+ if (intHandle-> append ( status, req, conn) != NIXL_SUCCESS) {
15991621 return status;
16001622 }
16011623
0 commit comments