diff --git a/test/gtest/plugins/memory_handler.h b/test/gtest/plugins/memory_handler.h index 632bf50cf..03fbb46fd 100644 --- a/test/gtest/plugins/memory_handler.h +++ b/test/gtest/plugins/memory_handler.h @@ -44,8 +44,9 @@ template<> class memoryHandler { for (auto &entry : buf_) { uint8_t expected_byte = start_byte++; if (entry != expected_byte) { - NIXL_ERROR << "Verification failed! local: " << entry - << ", expected: " << expected_byte; + NIXL_ERROR << absl::StrFormat("Byte mismatch: actual=0x%x, expected=0x%x", + static_cast(entry), + static_cast(expected_byte)); return false; } } @@ -105,8 +106,8 @@ template<> class memoryHandler { void populateMetaDesc(nixlMetaDesc *desc, int entry_index, size_t entry_size) { - desc->addr = 0; - desc->len = len_; + desc->addr = entry_index * entry_size; + desc->len = entry_size; desc->devId = devId_; desc->metadataP = md_; } diff --git a/test/gtest/plugins/obj_plugin.cpp b/test/gtest/plugins/obj_plugin.cpp index 7ccc3257c..050a3e2ae 100644 --- a/test/gtest/plugins/obj_plugin.cpp +++ b/test/gtest/plugins/obj_plugin.cpp @@ -50,29 +50,35 @@ class setupObjTestFixture : public setupBackendTestFixture { }; TEST_P(setupObjTestFixture, XferTest) { + transferMemConfig mem_cfg; transferHandler transfer( - localBackendEngine_, localBackendEngine_, local_agent_name, local_agent_name, false, 1); - transfer.setLocalMem(); + localBackendEngine_, localBackendEngine_, local_agent_name, local_agent_name, mem_cfg); + transfer.setupMems(); + transfer.setSrcMem(); transfer.testTransfer(NIXL_WRITE); - transfer.resetLocalMem(); + transfer.resetSrcMem(); transfer.testTransfer(NIXL_READ); - transfer.checkLocalMem(); + transfer.checkSrcMem(); } TEST_P(setupObjTestFixture, XferMultiBufsTest) { + transferMemConfig mem_cfg{.numBufs_ = 3}; transferHandler transfer( - localBackendEngine_, localBackendEngine_, local_agent_name, local_agent_name, false, 3); - transfer.setLocalMem(); + localBackendEngine_, localBackendEngine_, local_agent_name, local_agent_name, mem_cfg); + transfer.setupMems(); + transfer.setSrcMem(); transfer.testTransfer(NIXL_WRITE); - transfer.resetLocalMem(); + transfer.resetSrcMem(); transfer.testTransfer(NIXL_READ); - transfer.checkLocalMem(); + transfer.checkSrcMem(); } TEST_P(setupObjTestFixture, queryMemTest) { + transferMemConfig mem_cfg{.numBufs_ = 3}; transferHandler transfer( - localBackendEngine_, localBackendEngine_, local_agent_name, local_agent_name, false, 3); - transfer.setLocalMem(); + localBackendEngine_, localBackendEngine_, local_agent_name, local_agent_name, mem_cfg); + transfer.setupMems(); + transfer.setSrcMem(); transfer.testTransfer(NIXL_WRITE); nixl_reg_dlist_t descs(OBJ_SEG); diff --git a/test/gtest/plugins/transfer_handler.h b/test/gtest/plugins/transfer_handler.h index 9c13b2f9e..535efbba9 100644 --- a/test/gtest/plugins/transfer_handler.h +++ b/test/gtest/plugins/transfer_handler.h @@ -17,6 +17,7 @@ #ifndef __TRANSFER_HANDLER_H #define __TRANSFER_HANDLER_H +#include #include #include #include "backend_engine.h" @@ -26,41 +27,58 @@ namespace gtest::plugins { +int +getRandomInt(int min, int max) { + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<> dist(min, max); + return dist(gen); +} + +struct transferMemConfig { + const size_t numEntries_ = 1; + const size_t entrySize_ = 64; + const size_t numBufs_ = 1; + const uint8_t srcBufByte_ = getRandomInt(0, 255); + const uint8_t dstBufByte_ = getRandomInt(0, 255); + + size_t + bufSize() const { + return numEntries_ * entrySize_; + } +}; + template class transferHandler { public: transferHandler(std::shared_ptr src_engine, std::shared_ptr dst_engine, std::string src_agent_name, std::string dst_agent_name, - bool split_buf, - int num_bufs) + transferMemConfig mem_cfg = transferMemConfig()) : srcBackendEngine_(src_engine), dstBackendEngine_(dst_engine), + srcDescs_(std::make_unique(srcMemType)), + dstDescs_(std::make_unique(dstMemType)), + memConfig_(std::move(mem_cfg)), srcAgentName_(src_agent_name), dstAgentName_(dst_agent_name), - srcDevId_(0) { - - bool remote_xfer = srcAgentName_ != dstAgentName_; - if (remote_xfer) { - CHECK(src_engine->supportsRemote()) << "Local engine does not support remote transfers"; - dstDevId_ = 1; - verifyConnInfo(); - } else { - CHECK(src_engine->supportsLocal()) << "Local engine does not support local transfers"; - dstDevId_ = srcDevId_; - } + isRemoteXfer_(srcAgentName_ != dstAgentName_), + srcDevId_(0), + dstDevId_(isRemoteXfer_ ? 1 : 0) { + if (dstBackendEngine_->supportsNotif()) setupNotifs("Test"); + } - for (int i = 0; i < num_bufs; i++) { + void + setupMems() { + for (size_t i = 0; i < memConfig_.numBufs_; i++) { srcMem_.emplace_back( - std::make_unique>(BUF_SIZE, srcDevId_ + i)); + std::make_unique>(memConfig_.bufSize(), srcDevId_ + i)); dstMem_.emplace_back( - std::make_unique>(BUF_SIZE, dstDevId_ + i)); + std::make_unique>(memConfig_.bufSize(), dstDevId_ + i)); } - if (dstBackendEngine_->supportsNotif()) setupNotifs("Test"); - registerMems(); - prepMems(split_buf, remote_xfer); + prepareMems(); } ~transferHandler() { @@ -71,47 +89,89 @@ template class transferHandler { void testTransfer(nixl_xfer_op_t op) { - performTransfer(op); + verifyConnInfo(); + ASSERT_EQ(prepareTransfer(op), NIXL_SUCCESS); + ASSERT_EQ(postTransfer(op), NIXL_SUCCESS); + ASSERT_EQ(waitForTransfer(), NIXL_SUCCESS); + ASSERT_EQ(srcBackendEngine_->releaseReqH(xferHandle_), NIXL_SUCCESS); verifyTransfer(op); } + nixl_status_t + prepareTransfer(nixl_xfer_op_t op) { + return srcBackendEngine_->prepXfer( + op, *srcDescs_, *dstDescs_, dstAgentName_, xferHandle_, &xferOptArgs_); + } + + nixl_status_t + postTransfer(nixl_xfer_op_t op) { + nixl_status_t ret; + ret = srcBackendEngine_->postXfer( + op, *srcDescs_, *dstDescs_, dstAgentName_, xferHandle_, &xferOptArgs_); + return (ret == NIXL_SUCCESS || ret == NIXL_IN_PROG) ? NIXL_SUCCESS : NIXL_ERR_BACKEND; + } + + nixl_status_t + waitForTransfer() { + nixl_status_t ret = NIXL_IN_PROG; + auto end_time = absl::Now() + absl::Seconds(3); + + NIXL_INFO << "\t\tWaiting for transfer to complete..."; + while (ret == NIXL_IN_PROG && absl::Now() < end_time) { + ret = srcBackendEngine_->checkXfer(xferHandle_); + if (ret != NIXL_SUCCESS && ret != NIXL_IN_PROG) return ret; + + if (dstBackendEngine_->supportsProgTh()) dstBackendEngine_->progress(); + } + NIXL_INFO << "\nTransfer complete"; + + return NIXL_SUCCESS; + } + void - setLocalMem() { + addSrcDesc(nixlMetaDesc &meta_desc) { + srcDescs_->addDesc(meta_desc); + } + + void + addDstDesc(nixlMetaDesc &meta_desc) { + dstDescs_->addDesc(meta_desc); + } + + void + setSrcMem() { for (size_t i = 0; i < srcMem_.size(); i++) - srcMem_[i]->setIncreasing(LOCAL_BUF_BYTE + i); + srcMem_[i]->setIncreasing(memConfig_.srcBufByte_ + i); } void - resetLocalMem() { + resetSrcMem() { for (const auto &mem : srcMem_) mem->reset(); } void - checkLocalMem() { + checkSrcMem() { for (size_t i = 0; i < srcMem_.size(); i++) - EXPECT_TRUE(srcMem_[i]->checkIncreasing(LOCAL_BUF_BYTE + i)); + EXPECT_TRUE(srcMem_[i]->checkIncreasing(memConfig_.srcBufByte_ + i)); } private: - static constexpr uint8_t LOCAL_BUF_BYTE = 0x11; - static constexpr uint8_t XFER_BUF_BYTE = 0x22; - static constexpr size_t NUM_ENTRIES = 4; - static constexpr size_t ENTRY_SIZE = 16; - static constexpr size_t BUF_SIZE = NUM_ENTRIES * ENTRY_SIZE; - std::vector>> srcMem_; std::vector>> dstMem_; - std::shared_ptr srcBackendEngine_; - std::shared_ptr dstBackendEngine_; - std::unique_ptr srcDescs_; - std::unique_ptr dstDescs_; + const std::shared_ptr srcBackendEngine_; + const std::shared_ptr dstBackendEngine_; + const std::unique_ptr srcDescs_; + const std::unique_ptr dstDescs_; + const transferMemConfig memConfig_; + const std::string srcAgentName_; + const std::string dstAgentName_; nixl_opt_b_args_t xferOptArgs_; nixlBackendMD *xferLoadedMd_; - std::string srcAgentName_; - std::string dstAgentName_; - int srcDevId_; - int dstDevId_; + nixlBackendReqH *xferHandle_; + const bool isRemoteXfer_; + const int srcDevId_; + const int dstDevId_; void registerMems() { @@ -139,8 +199,8 @@ template class transferHandler { } void - prepMems(bool split_buf, bool remote_xfer) { - if (remote_xfer) { + prepareMems() { + if (isRemoteXfer_) { nixlBlobDesc info; dstMem_[0]->populateBlobDesc(&info); ASSERT_EQ(srcBackendEngine_->getPublicData(dstMem_[0]->getMD(), info.metaInfo), @@ -154,51 +214,17 @@ template class transferHandler { NIXL_SUCCESS); } - srcDescs_ = std::make_unique(srcMemType); - dstDescs_ = std::make_unique(dstMemType); - - int num_entries = split_buf ? NUM_ENTRIES : 1; - int entry_size = split_buf ? ENTRY_SIZE : BUF_SIZE; for (size_t i = 0; i < srcMem_.size(); i++) { - for (int entry_i = 0; entry_i < num_entries; entry_i++) { + for (size_t entry_i = 0; entry_i < memConfig_.numEntries_; entry_i++) { nixlMetaDesc desc; - srcMem_[i]->populateMetaDesc(&desc, entry_i, entry_size); + srcMem_[i]->populateMetaDesc(&desc, entry_i, memConfig_.entrySize_); srcDescs_->addDesc(desc); - dstMem_[i]->populateMetaDesc(&desc, entry_i, entry_size); + dstMem_[i]->populateMetaDesc(&desc, entry_i, memConfig_.entrySize_); dstDescs_->addDesc(desc); } } } - void - performTransfer(nixl_xfer_op_t op) { - nixlBackendReqH *handle; - nixl_status_t ret; - - ASSERT_EQ(srcBackendEngine_->prepXfer( - op, *srcDescs_, *dstDescs_, dstAgentName_, handle, &xferOptArgs_), - NIXL_SUCCESS); - - ret = srcBackendEngine_->postXfer( - op, *srcDescs_, *dstDescs_, dstAgentName_, handle, &xferOptArgs_); - ASSERT_TRUE(ret == NIXL_SUCCESS || ret == NIXL_IN_PROG); - - NIXL_INFO << "\t\tWaiting for transfer to complete..."; - - auto end_time = absl::Now() + absl::Seconds(3); - - while (ret == NIXL_IN_PROG && absl::Now() < end_time) { - ret = srcBackendEngine_->checkXfer(handle); - ASSERT_TRUE(ret == NIXL_SUCCESS || ret == NIXL_IN_PROG); - - if (dstBackendEngine_->supportsProgTh()) dstBackendEngine_->progress(); - } - - NIXL_INFO << "\nTransfer complete"; - - ASSERT_EQ(srcBackendEngine_->releaseReqH(handle), NIXL_SUCCESS); - } - void verifyTransfer(nixl_xfer_op_t op) { if (srcBackendEngine_->supportsNotif()) { @@ -245,8 +271,9 @@ template class transferHandler { void verifyConnInfo() { - std::string conn_info; + if (!isRemoteXfer_) return; + std::string conn_info; ASSERT_EQ(srcBackendEngine_->getConnInfo(conn_info), NIXL_SUCCESS); ASSERT_EQ(dstBackendEngine_->getConnInfo(conn_info), NIXL_SUCCESS); ASSERT_EQ(srcBackendEngine_->loadRemoteConnInfo(dstAgentName_, conn_info), NIXL_SUCCESS);