diff --git a/src/plugins/libfabric/libfabric_backend.cpp b/src/plugins/libfabric/libfabric_backend.cpp index 5813e88656..beb3a47b8f 100644 --- a/src/plugins/libfabric/libfabric_backend.cpp +++ b/src/plugins/libfabric/libfabric_backend.cpp @@ -1,6 +1,6 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-FileCopyrightText: Copyright (c) 2025 Amazon.com, Inc. and affiliates. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026 Amazon.com, Inc. and affiliates. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -169,13 +169,14 @@ nixlLibfabricBackendH::nixlLibfabricBackendH(nixl_xfer_op_t op, const std::strin : completed_requests_(0), submitted_requests_(0), operation_(op), - remote_agent_(remote_agent) { - // Initialize BinaryNotification - binary_notif.clear(); + remote_agent_(remote_agent), + total_notif_msg_len(0) { + // Initialize BinaryNotification vector + binary_notifs.clear(); NIXL_DEBUG << " handle constructor called, address: " << this << " total_requests_used=" << submitted_requests_.load() - << " BinaryNotification initialized"; + << " BinaryNotification vector initialized"; } nixlLibfabricBackendH::~nixlLibfabricBackendH() { @@ -192,8 +193,8 @@ nixlLibfabricBackendH::init_request_tracking(size_t num_requests) { void nixlLibfabricBackendH::increment_completed_requests() { - size_t completed = completed_requests_.fetch_add(1); - NIXL_DEBUG << "Request completed, total completed: " << completed << "/" + completed_requests_.fetch_add(1); + NIXL_DEBUG << "Request completed, total completed: " << completed_requests_.load() << "/" << submitted_requests_.load(); } @@ -203,12 +204,12 @@ nixlLibfabricBackendH::get_completed_requests_count() const { } size_t -nixlLibfabricBackendH::get_total_requests_used() const { +nixlLibfabricBackendH::get_submitted_requests_count() const { return submitted_requests_.load(); } void -nixlLibfabricBackendH::adjust_total_requests(size_t actual_count) { +nixlLibfabricBackendH::adjust_total_submitted_requests(size_t actual_count) { submitted_requests_.store(actual_count); NIXL_DEBUG << "Adjusted total requests to actual count: " << actual_count; } @@ -436,8 +437,8 @@ nixlLibfabricEngine::loadRemoteConnInfo(const std::string &remote_agent, std::lock_guard lock(connection_state_mutex_); NIXL_DEBUG << "Loading remote info for agent: " << remote_agent - << ", info length=" << remote_conn_info.length() - << ", info (hex): " << LibfabricUtils::hexdump(remote_conn_info.data()); + << ", info length=" << remote_conn_info.length() << ", info (hex): " + << LibfabricUtils::hexdump(remote_conn_info.data(), remote_conn_info.length()); if (remote_conn_info.empty()) { NIXL_ERROR << "Empty remote connection info received"; @@ -640,11 +641,11 @@ nixlLibfabricEngine::establishConnection(const std::string &remote_agent) const << " data rails and " << conn_info->control_ep_names_.size() << " control rails"; for (size_t i = 0; i < conn_info->src_ep_names_.size(); ++i) { NIXL_DEBUG << "Data rail " << i << ": " - << LibfabricUtils::hexdump(conn_info->src_ep_names_[i]); + << LibfabricUtils::hexdump(conn_info->src_ep_names_[i], LF_EP_NAME_MAX_LEN); } for (size_t i = 0; i < conn_info->control_ep_names_.size(); ++i) { NIXL_DEBUG << "Control rail " << i << ": " - << LibfabricUtils::hexdump(conn_info->control_ep_names_[i]); + << LibfabricUtils::hexdump(conn_info->control_ep_names_[i], LF_EP_NAME_MAX_LEN); } NIXL_DEBUG << "Agent index: " << it->second->agent_index_; if (!conn_info) { @@ -664,8 +665,9 @@ nixlLibfabricEngine::establishConnection(const std::string &remote_agent) const return serialize_status; } - nixlLibfabricReq *control_request = rail_manager.getControlRail(control_rail_id) - .allocateControlRequest(serialized_conn_info.length()); + nixlLibfabricReq *control_request = + rail_manager.getControlRail(control_rail_id) + .allocateControlRequest(serialized_conn_info.length(), LibfabricUtils::getNextXferId()); if (!control_request) { NIXL_ERROR << "Failed to allocate control request for connection establishment"; return NIXL_ERR_BACKEND; @@ -837,9 +839,8 @@ nixlLibfabricEngine::loadMetadataHelper(const std::vector &rail_keys, pub_md->remote_buf_addr_ = reinterpret_cast(buffer); pub_md->conn_ = conn; - NIXL_DEBUG << "Metadata loaded with" - << " Remote addr: " << (void *)pub_md->remote_buf_addr_ << " Remote keys for " - << pub_md->rail_remote_key_list_.size() << " rails" + NIXL_DEBUG << "Metadata loaded with" << " Remote addr: " << (void *)pub_md->remote_buf_addr_ + << " Remote keys for " << pub_md->rail_remote_key_list_.size() << " rails" << " Remote fi_addr: " << pub_md->conn_->rail_remote_addr_list_[0][0]; output = pub_md.release(); return NIXL_SUCCESS; @@ -933,10 +934,16 @@ nixlLibfabricEngine::prepXfer(const nixl_xfer_op_t &operation, // Set agent name and message in BinaryNotification during prepXfer if (opt_args && opt_args->hasNotif) { backend_handle->has_notif = true; - backend_handle->binary_notif.setAgentName(localAgent); - backend_handle->binary_notif.setMessage(opt_args->notifMsg); - backend_handle->binary_notif.expected_completions = 0; - NIXL_DEBUG << "Setting notification message: " << opt_args->notifMsg; + + // Use common fragmentation helper function + fragmentNotificationMessage(opt_args->notifMsg, + localAgent, + backend_handle->total_notif_msg_len, + backend_handle->binary_notifs); + + NIXL_DEBUG << "prepXfer: Fragmented notification into " + << backend_handle->binary_notifs.size() + << " fragments, total_length=" << backend_handle->total_notif_msg_len; } handle = backend_handle; // Assign to base class pointer @@ -993,13 +1000,8 @@ nixlLibfabricEngine::postXfer(const nixl_xfer_op_t &operation, return NIXL_ERR_INVALID_PARAM; } - // Use pre-allocated BinaryNotification from handle and set xfer_id - backend_handle->binary_notif.xfer_id = LibfabricUtils::getNextXferId(); - backend_handle->binary_notif.expected_completions = - 0; // Will be incremented during transfer submission - - NIXL_DEBUG << "Using pre-allocated BinaryNotification with XFER_ID=" - << backend_handle->binary_notif.xfer_id; + // Allocate xfer_id once in prepXfer + backend_handle->post_xfer_id = LibfabricUtils::getNextXferId(); nixlLibfabricReq::OpType op_type; int desc_count = local.descCount(); @@ -1013,6 +1015,8 @@ nixlLibfabricEngine::postXfer(const nixl_xfer_op_t &operation, size_t max_possible_requests = desc_count * rail_manager.getNumDataRails(); backend_handle->init_request_tracking(max_possible_requests); + size_t total_submitted = 0; + // Core transfer submission to process each descriptor with direct submission for (int desc_idx = 0; desc_idx < desc_count; ++desc_idx) { auto *local_md = static_cast(local[desc_idx].metadataP); @@ -1043,6 +1047,7 @@ nixlLibfabricEngine::postXfer(const nixl_xfer_op_t &operation, // Use descriptor's specific target address uint64_t remote_target_addr = remote[desc_idx].addr; + size_t submitted_count = 0; nixl_status_t status = rail_manager.prepareAndSubmitTransfer( op_type, transfer_addr, @@ -1054,11 +1059,11 @@ nixlLibfabricEngine::postXfer(const nixl_xfer_op_t &operation, remote_md->remote_selected_endpoints_, conn_it->second->rail_remote_addr_list_, conn_it->second->agent_index_, + backend_handle->post_xfer_id, [backend_handle]() { backend_handle->increment_completed_requests(); }, // Completion callback - &(backend_handle->binary_notif) // Populate BinaryNotification - ); + submitted_count); if (status != NIXL_SUCCESS) { NIXL_ERROR << "prepareAndSubmitTransfer failed for descriptor " << desc_idx << " GPU " @@ -1066,35 +1071,39 @@ nixlLibfabricEngine::postXfer(const nixl_xfer_op_t &operation, return status; } + // Add submitted requests to the total count + total_submitted += submitted_count; + NIXL_DEBUG << "Successfully processed descriptor " << desc_idx << " with " - << backend_handle->binary_notif.expected_completions << " requests submitted"; + << submitted_count << " requests submitted (accumulated: " << total_submitted + << ")"; } - NIXL_DEBUG << "Processing complete: submitted " - << backend_handle->binary_notif.expected_completions << " requests from " - << desc_count << " descriptors" - << " with " << backend_handle->binary_notif.expected_completions - << " total XFER_IDs"; + NIXL_DEBUG << "Processing complete: submitted " << total_submitted << " requests from " + << desc_count << " descriptors" << " for xfer_id" << backend_handle->post_xfer_id; - // For same-agent transfers, we need to set the total to 0 since we bypassed all rail operations + // For same-agent transfers, override to 0 since we bypassed all rail operations if (remote_agent == localAgent) { - backend_handle->adjust_total_requests(0); + backend_handle->adjust_total_submitted_requests(0); NIXL_DEBUG << "Same-agent transfer: adjusted total requests to 0 (all handled via memcpy)"; } else { // Adjust to actual request count after all submissions complete - backend_handle->adjust_total_requests(backend_handle->binary_notif.expected_completions); + backend_handle->adjust_total_submitted_requests(total_submitted); } // Send notification immediately after successful request submission if (backend_handle->has_notif && backend_handle->operation_ == nixl_xfer_op_t::NIXL_WRITE) { - nixl_status_t notif_status = notifSendPriv(remote_agent, backend_handle->binary_notif); + nixl_status_t notif_status = notifSendPriv(remote_agent, + backend_handle->binary_notifs, + backend_handle->total_notif_msg_len, + backend_handle->post_xfer_id, + backend_handle->get_submitted_requests_count()); if (notif_status != NIXL_SUCCESS) { NIXL_ERROR << "Failed to send notification"; return notif_status; } - NIXL_DEBUG << "Notification sent immediately with XFER_ID=" - << backend_handle->binary_notif.xfer_id << ", expected_completions: " - << backend_handle->binary_notif.expected_completions; + NIXL_DEBUG << "Notification sent immediately with XFER_ID=" << backend_handle->post_xfer_id + << ", expected_completions: " << backend_handle->get_submitted_requests_count(); } // Progress data rails to kick off transfers @@ -1108,8 +1117,11 @@ nixlLibfabricEngine::postXfer(const nixl_xfer_op_t &operation, // For very small transfers we can check for local completions immediately. if (backend_handle->is_completed()) { if (backend_handle->has_notif && backend_handle->operation_ == nixl_xfer_op_t::NIXL_READ) { - backend_handle->binary_notif.expected_completions = 0; - nixl_status_t notif_status = notifSendPriv(remote_agent, backend_handle->binary_notif); + nixl_status_t notif_status = notifSendPriv(remote_agent, + backend_handle->binary_notifs, + backend_handle->total_notif_msg_len, + backend_handle->post_xfer_id, + 0); if (notif_status != NIXL_SUCCESS) { NIXL_ERROR << "Failed to send notification"; return notif_status; @@ -1136,9 +1148,11 @@ nixlLibfabricEngine::checkXfer(nixlBackendReqH *handle) const { if (backend_handle->is_completed()) { NIXL_DEBUG << "Data transfer completed successfully"; if (backend_handle->has_notif && backend_handle->operation_ == nixl_xfer_op_t::NIXL_READ) { - backend_handle->binary_notif.expected_completions = 0; - nixl_status_t notif_status = - notifSendPriv(backend_handle->remote_agent_, backend_handle->binary_notif); + nixl_status_t notif_status = notifSendPriv(backend_handle->remote_agent_, + backend_handle->binary_notifs, + backend_handle->total_notif_msg_len, + backend_handle->post_xfer_id, + 0); if (notif_status != NIXL_SUCCESS) { NIXL_ERROR << "Failed to send notification"; return notif_status; @@ -1166,10 +1180,99 @@ nixlLibfabricEngine::releaseReqH(nixlBackendReqH *handle) const { return NIXL_SUCCESS; } -// notifSendPriv that accept control request +/**************************************** + * Notification Functions + *****************************************/ + +void +nixlLibfabricEngine::fragmentNotificationMessage( + const std::string &message, + const std::string &agent_name, + uint32_t &total_message_length, + std::vector &fragments_out) const { + // agent_name + message forms a single combined payload + std::string combined_payload = agent_name + message; + total_message_length = static_cast(combined_payload.length()); + + const size_t max_control_msg_size = BinaryNotification::MAX_FRAGMENT_SIZE; + + // Calculate fragment 0 capacity (has extra headers) + size_t frag0_overhead = sizeof(BinaryNotificationHeader) + sizeof(BinaryNotificationMetadata); + size_t frag0_capacity = max_control_msg_size - frag0_overhead; + + // Calculate fragment 1+ capacity (only has minimal header) + size_t frag_overhead = sizeof(BinaryNotificationHeader); + size_t frag_capacity = max_control_msg_size - frag_overhead; + + // Calculate number of fragments needed + size_t num_fragments = 1; // At least fragment 0 + size_t remaining = 0; + if (total_message_length > frag0_capacity) { + remaining = total_message_length - frag0_capacity; + num_fragments += (remaining + frag_capacity - 1) / frag_capacity; + } + + fragments_out.clear(); + fragments_out.resize(num_fragments); + + NIXL_DEBUG << "Fragmenting: agent_name=" << agent_name.length() + << "B, message=" << message.length() + << "B, combined_payload=" << total_message_length << "B, fragments=" << num_fragments + << ", frag0_capacity=" << frag0_capacity << ", frag_capacity=" << frag_capacity; + + size_t offset = 0; + + for (size_t frag_idx = 0; frag_idx < num_fragments; ++frag_idx) { + // Set header fields + BinaryNotificationHeader header; + header.notif_xfer_id = 0; // Will be set later in notifSendPriv + header.notif_seq_id = static_cast(frag_idx); + header.notif_seq_len = static_cast(num_fragments); + + if (frag_idx == 0) { + // Fragment 0: Pack metadata + combined_payload_chunk + size_t payload_chunk_len = + std::min(frag0_capacity, static_cast(total_message_length)); + header.payload_length = static_cast(payload_chunk_len); + + fragments_out[0].setHeader(header); + fragments_out[0].setMetadata(total_message_length, + 0, // expected_completions set later + static_cast(agent_name.length())); + // Set the payload chunk directly + fragments_out[0].setPayload(combined_payload.substr(0, payload_chunk_len)); + + offset = payload_chunk_len; + + NIXL_DEBUG << "Fragment 0: combined_payload_chunk=" << payload_chunk_len << "B"; + } else { + // Fragment 1+: Pack only combined_payload continuation + size_t payload_chunk_len = + std::min(frag_capacity, static_cast(total_message_length) - offset); + header.payload_length = static_cast(payload_chunk_len); + + fragments_out[frag_idx].setHeader(header); + // Set the payload chunk directly + fragments_out[frag_idx].setPayload(combined_payload.substr(offset, payload_chunk_len)); + + offset += payload_chunk_len; + + NIXL_DEBUG << "Fragment " << frag_idx + << ": combined_payload_chunk=" << payload_chunk_len << "B"; + } + } + + NIXL_DEBUG << "Fragmentation complete: " << num_fragments + << " fragments, total_payload=" << total_message_length << "B"; +} + +// notifSendPriv that accepts vector of BinaryNotifications for fragmentation support nixl_status_t nixlLibfabricEngine::notifSendPriv(const std::string &remote_agent, - BinaryNotification &binary_notification) const { + std::vector &binary_notifications, + uint32_t total_message_length, + uint16_t notif_xfer_id, + uint32_t expected_completions) const { auto it = connections_.find(remote_agent); if (it == connections_.end()) { NIXL_ERROR << "No connection found for agent: " << remote_agent; @@ -1179,46 +1282,74 @@ nixlLibfabricEngine::notifSendPriv(const std::string &remote_agent, auto connection = it->second; const size_t control_rail_id = 0; // Only use control rail 0 for notifications - // Allocate control request for notification - nixlLibfabricReq *control_request = rail_manager.getControlRail(control_rail_id) - .allocateControlRequest(sizeof(BinaryNotification)); - if (!control_request) { - NIXL_ERROR << "Failed to allocate control request for notification"; - return NIXL_ERR_BACKEND; - } + NIXL_DEBUG << "Sending " << binary_notifications.size() << " notification fragments" + << " total_message_length=" << total_message_length; - // Copy BinaryNotification to control request buffer - memcpy(control_request->buffer, &binary_notification, sizeof(BinaryNotification)); + // Send each notification fragment + for (size_t seq_id = 0; seq_id < binary_notifications.size(); ++seq_id) { + auto &binary_notification = binary_notifications[seq_id]; - // Set the correct buffer size for the notification - control_request->buffer_size = sizeof(BinaryNotification); + // Update header fields for this notification + BinaryNotificationHeader header = binary_notification.getHeader(); + header.notif_xfer_id = notif_xfer_id; + binary_notification.setHeader(header); - NIXL_DEBUG << "Sending binary notification control request" - << " Message: " << binary_notification.getMessage() - << " expected_completions: " << binary_notification.expected_completions; - nixl_status_t status = rail_manager.postControlMessage( - nixlLibfabricRailManager::ControlMessageType::NOTIFICATION, - control_request, - connection->control_rail_remote_addr_list_[control_rail_id][0], - connection->agent_index_); + // Update first fragment header with expected_completions (only for fragment 0) + // Note: agent_name_length was already set during fragmentation + if (seq_id == 0) { + const BinaryNotificationMetadata &metadata = binary_notification.getMetadata(); + binary_notification.setMetadata( + total_message_length, expected_completions, metadata.agent_name_length); + } - if (status != NIXL_SUCCESS) { - NIXL_ERROR << "postControlMessage failed on control rail " << control_rail_id; - return NIXL_ERR_BACKEND; + // Allocate control request for this notification fragment + size_t max_size = BinaryNotification::MAX_FRAGMENT_SIZE; + nixlLibfabricReq *control_request = rail_manager.getControlRail(control_rail_id) + .allocateControlRequest(max_size, notif_xfer_id); + + if (!control_request) { + NIXL_ERROR << "Failed to allocate control request for notification fragment " << seq_id; + return NIXL_ERR_BACKEND; + } + + // Serialize BinaryNotification to control request buffer + size_t serialized_size = binary_notification.serialize(control_request->buffer); + control_request->buffer_size = serialized_size; + + NIXL_DEBUG << "Sending binary notification fragment " << seq_id << "/" + << binary_notifications.size() << " size=" << serialized_size << "B" + << " payload_chunk_size=" << header.payload_length << "B" + << " notif_xfer_id=" << header.notif_xfer_id; + + nixl_status_t status = rail_manager.postControlMessage( + nixlLibfabricRailManager::ControlMessageType::NOTIFICATION, + control_request, + connection->control_rail_remote_addr_list_[control_rail_id][0], + connection->agent_index_); + + if (status != NIXL_SUCCESS) { + NIXL_ERROR << "postControlMessage failed on control rail " << control_rail_id + << " for fragment " << seq_id; + return NIXL_ERR_BACKEND; + } } + NIXL_DEBUG << "Successfully sent all " << binary_notifications.size() + << " notification fragments" << " total_length=" << total_message_length; return NIXL_SUCCESS; } nixl_status_t nixlLibfabricEngine::genNotif(const std::string &remote_agent, const std::string &msg) const { - // Create BinaryNotification directly in the control buffer - BinaryNotification binary_notif; - binary_notif.clear(); - binary_notif.setAgentName(localAgent); - binary_notif.setMessage(msg); + // Use common fragmentation helper function + uint32_t total_msg_len = 0; + std::vector notifications; + fragmentNotificationMessage(msg, localAgent, total_msg_len, notifications); + + NIXL_DEBUG << "genNotif: Fragmented notification into " << notifications.size() + << " fragments, total_length=" << total_msg_len; - return notifSendPriv(remote_agent, binary_notif); + return notifSendPriv(remote_agent, notifications, total_msg_len, 0, 0); } nixl_status_t @@ -1322,7 +1453,8 @@ nixlLibfabricEngine::postShutdownCompletion() { const size_t control_rail_id = 0; const size_t shutdown_msg_len = 8; // "SHUTDOWN" length nixlLibfabricReq *control_request = - rail_manager.getControlRail(control_rail_id).allocateControlRequest(shutdown_msg_len); + rail_manager.getControlRail(control_rail_id) + .allocateControlRequest(shutdown_msg_len, LibfabricUtils::getNextXferId()); if (!control_request) { NIXL_ERROR << "Failed to allocate control request for shutdown"; return; @@ -1354,70 +1486,88 @@ nixlLibfabricEngine::postShutdownCompletion() { void nixlLibfabricEngine::processNotification(const std::string &serialized_notif) { - // Only handle binary notification format - // Check if this is a binary notification (fixed size) - NIXL_DEBUG << "Received notification size=" << serialized_notif.size() - << ", sizeof(Notification): " << sizeof(BinaryNotification); - - if (serialized_notif.size() != sizeof(BinaryNotification)) { - NIXL_ERROR << "Invalid notification size=" << serialized_notif.size() - << ", expected: " << sizeof(BinaryNotification); - return; - } - - // Process binary notification format - const BinaryNotification *binary_notif = - reinterpret_cast(serialized_notif.data()); - - std::string remote_name = binary_notif->getAgentName(); - std::string msg = binary_notif->getMessage(); - uint16_t xfer_id = binary_notif->xfer_id; - uint32_t expected_completions = binary_notif->expected_completions; - - NIXL_TRACE << "Received notification from " << remote_name << " msg: " << msg - << " XFER_ID=" << xfer_id << " expected_completions: " << expected_completions; - - // Check if this is a transfer notification that needs completions matching - if (expected_completions > 0) { - NIXL_DEBUG << "Transfer notification with expected_completions=" << expected_completions - << ", for XFER_ID " << xfer_id; - - { - std::lock_guard lock(receiver_tracking_mutex_); - - // Create composite key for O(1) lookup - auto it = pending_notifications_.find(xfer_id); - - if (it != pending_notifications_.end()) { - // Case 1: Writes already arrived - update placeholder with real values - it->second.remote_agent = remote_name; // Update agent name from notification - it->second.message = msg; - it->second.expected_completions = expected_completions; - - NIXL_DEBUG << "Updated placeholder notification for agent " << remote_name - << " XFER_ID " << xfer_id - << " expected_completions=" << expected_completions - << " received_completions=" << it->second.received_completions; - } else { - // Case 2: Notification arrived first - create a pending notification entry - PendingNotification pending_notif(remote_name, msg, xfer_id, expected_completions); - pending_notifications_[xfer_id] = pending_notif; + NIXL_DEBUG << "Received notification size=" << serialized_notif.size(); - NIXL_DEBUG << "Created pending notification for agent " << remote_name - << " xfer_id=" << xfer_id - << " expected_completions=" << expected_completions; - } + // Deserialize binary notification + BinaryNotification binary_notif; + BinaryNotification::deserialize(serialized_notif.data(), serialized_notif.size(), binary_notif); + + // Extract fields + const BinaryNotificationHeader &header = binary_notif.getHeader(); + uint16_t notif_xfer_id = header.notif_xfer_id; + uint16_t notif_seq_id = header.notif_seq_id; + uint16_t notif_seq_len = header.notif_seq_len; + + // Get payload chunk (combined agent_name + message chunk for all fragments) + const std::string &payload_chunk = binary_notif.getPayload(); + + // Get metadata from first fragment (only valid for fragment 0) + uint32_t expected_completions = 0; + uint32_t total_payload_length = 0; + uint16_t agent_name_length = 0; + if (notif_seq_id == 0) { + const BinaryNotificationMetadata &metadata = binary_notif.getMetadata(); + expected_completions = metadata.expected_completions; + total_payload_length = metadata.total_payload_length; + agent_name_length = metadata.agent_name_length; + } + + NIXL_TRACE << "Received notification fragment" << " notif_xfer_id=" << notif_xfer_id + << " notif_seq_id=" << notif_seq_id << "/" << notif_seq_len + << " payload_chunk_size=" << payload_chunk.size() + << " expected_completions=" << expected_completions; + + { + std::lock_guard lock(receiver_tracking_mutex_); + + // Use try_emplace to construct in-place - eliminates extra copy + auto [it, inserted] = pending_notifications_.try_emplace(notif_xfer_id, notif_xfer_id); + + if (inserted) { + NIXL_DEBUG << "Created pending notification" << " notif_xfer_id=" << notif_xfer_id + << " expected_completions=" << expected_completions + << " expected_msg_fragments=" << notif_seq_len; } - // Check if any notifications can now be completed (after releasing the lock) - checkPendingNotifications(); - } else { - // Regular notification without expected completions - process immediately - NIXL_TRACE << "Regular notification (expected_completions=0), processing immediately"; - std::lock_guard lock(notif_mutex_); - notifMainList_.push_back({remote_name, msg}); - NIXL_TRACE << "Regular notification processed immediately: " << msg; + // Initialize fragment vector on first fragment (check if vector is empty) + if (it->second.message_fragments.empty()) { + it->second.message_fragments.resize(notif_seq_len); + it->second.expected_msg_fragments = notif_seq_len; + } + + // Validate fragment index + if (notif_seq_id >= notif_seq_len) { + NIXL_ERROR << "Invalid fragment sequence: notif_seq_id=" << notif_seq_id + << " >= notif_seq_len=" << notif_seq_len; + return; + } + + // Check for duplicate fragment + if (!it->second.message_fragments[notif_seq_id].empty()) { + NIXL_WARN << "Duplicate fragment received: notif_seq_id=" << notif_seq_id; + return; + } + + // Store payload chunk (combined agent_name + message chunk) + it->second.message_fragments[notif_seq_id] = payload_chunk; + it->second.received_msg_fragments++; + + // Update metadata from fragment 0 (agent_name will be extracted after reassembly) + if (notif_seq_id == 0) { + it->second.expected_completions = expected_completions; + it->second.total_message_length = total_payload_length; + it->second.agent_name_length = agent_name_length; + } + + NIXL_DEBUG << "Stored fragment" << " notif_xfer_id=" << notif_xfer_id << " fragment " + << notif_seq_id << "/" << notif_seq_len + << " received_msg_fragments=" << it->second.received_msg_fragments + << " expected_completions=" << it->second.expected_completions + << " received_completions=" << it->second.received_completions; } + + // Check if any notifications can now be completed (after releasing the lock) + checkPendingNotifications(); } void @@ -1486,7 +1636,8 @@ nixlLibfabricEngine::processConnectionRequest(uint16_t agent_idx, // Allocate control request const size_t control_rail_id = 0; nixlLibfabricReq *control_request = - rail_manager.getControlRail(control_rail_id).allocateControlRequest(ep_name_len); + rail_manager.getControlRail(control_rail_id) + .allocateControlRequest(ep_name_len, LibfabricUtils::getNextXferId()); if (!control_request) { NIXL_ERROR << "Failed to allocate control request for connection ACK"; return NIXL_ERR_BACKEND; @@ -1518,28 +1669,26 @@ void nixlLibfabricEngine::addReceivedXferId(uint16_t xfer_id) { { std::lock_guard lock(receiver_tracking_mutex_); - auto it = pending_notifications_.find(xfer_id); - if (it != pending_notifications_.end()) { - // Case 1: Notification already exists (message arrived first or placeholder exists) - it->second.received_completions++; - - NIXL_DEBUG << "Incremented received count for XFER_ID " << xfer_id << ": " - << it->second.received_completions << "/" << it->second.expected_completions; - } else { - // Case 2: Write arrived before notification - create placeholder with INT_MAX - PendingNotification placeholder; - placeholder.remote_agent = ""; // Empty until notification arrives - placeholder.message = ""; // Empty until notification arrives - placeholder.post_xfer_id = xfer_id; - placeholder.expected_completions = INT_MAX; // Sentinel value - placeholder.received_completions = 1; // Start with this completion - - pending_notifications_[xfer_id] = placeholder; - - NIXL_DEBUG << "Created placeholder notification for posted_xfer_id " << xfer_id + // Use try_emplace to construct in-place - eliminates extra copy + // First parameter: map key for lookup + // Second parameter: constructor argument for PendingNotification + auto [it, inserted] = pending_notifications_.try_emplace(xfer_id, xfer_id); + + if (inserted) { + // Set placeholder values for write-arrived-first case + it->second.remote_agent = ""; + it->second.expected_completions = INT_MAX; + it->second.received_completions = 0; + it->second.expected_msg_fragments = 1; // Default to 1 fragment + it->second.received_msg_fragments = 0; + NIXL_DEBUG << "Created placeholder notification for notif_xfer_id " << xfer_id << " (write arrived first)"; } + + it->second.received_completions++; + NIXL_DEBUG << "Incremented received count for notif_xfer_id " << xfer_id << ": " + << it->second.received_completions << "/" << it->second.expected_completions; } // Check if any notifications can now be completed (after releasing the lock) @@ -1555,18 +1704,47 @@ nixlLibfabricEngine::checkPendingNotifications() { std::lock_guard lock(receiver_tracking_mutex_); auto it = pending_notifications_.begin(); while (it != pending_notifications_.end()) { - // Check if transfer is complete by checking if all the remote completions for - // the xfer_id are received. - if (it->second.received_completions >= it->second.expected_completions) { - NIXL_TRACE << "Received all remote completions for queued notification, processing now"; + // Check BOTH conditions: fragments complete AND writes complete + bool fragments_complete = + (it->second.received_msg_fragments >= it->second.expected_msg_fragments); + bool writes_complete = (it->second.received_completions >= it->second.expected_completions); + + if (fragments_complete && writes_complete) { + NIXL_TRACE << "Notification complete: fragments=" << it->second.received_msg_fragments + << "/" << it->second.expected_msg_fragments + << " writes=" << it->second.received_completions << "/" + << it->second.expected_completions; + + // Reassemble combined payload from fragments + std::string combined_payload; + combined_payload.reserve(it->second.total_message_length); + for (const auto &fragment : it->second.message_fragments) { + combined_payload.append(fragment); + } + + // Extract agent_name and message from combined payload + uint16_t agent_name_len = it->second.agent_name_length; + std::string remote_agent; + std::string message; + + if (agent_name_len > 0 && combined_payload.size() >= agent_name_len) { + remote_agent = combined_payload.substr(0, agent_name_len); + if (combined_payload.size() > agent_name_len) { + message = combined_payload.substr(agent_name_len); + } + } else { + NIXL_ERROR << "Invalid combined payload: agent_name_len=" << agent_name_len + << " combined_payload_size=" << combined_payload.size(); + } // Move notification to main list (need to acquire notif_mutex_) { std::lock_guard notif_lock(notif_mutex_); - notifMainList_.push_back({it->second.remote_agent, it->second.message}); + notifMainList_.push_back({remote_agent, message}); } - NIXL_TRACE << "Processed queued notification: " << it->second.message; + NIXL_TRACE << "Processed queued notification from " << remote_agent + << " message_len=" << message.length(); // Remove from pending list it = pending_notifications_.erase(it); diff --git a/src/plugins/libfabric/libfabric_backend.h b/src/plugins/libfabric/libfabric_backend.h index a19a058ad7..ed9a97c42b 100644 --- a/src/plugins/libfabric/libfabric_backend.h +++ b/src/plugins/libfabric/libfabric_backend.h @@ -1,6 +1,6 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-FileCopyrightText: Copyright (c) 2025 Amazon.com, Inc. and affiliates. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026 Amazon.com, Inc. and affiliates. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -26,7 +26,6 @@ #include #include #include -#include #include #include @@ -137,11 +136,13 @@ class nixlLibfabricBackendH : public nixlBackendReqH { std::atomic submitted_requests_; // Total number of submitted requests public: + uint16_t post_xfer_id; const nixl_xfer_op_t operation_; const std::string remote_agent_; bool has_notif; + uint32_t total_notif_msg_len; // Total length of notification message across all fragments - BinaryNotification binary_notif; // Direct BinaryNotification instance + std::vector binary_notifs; // Vector of BinaryNotification for fragmentation nixlLibfabricBackendH(nixl_xfer_op_t op, const std::string &remote_agent); ~nixlLibfabricBackendH(); @@ -158,17 +159,17 @@ class nixlLibfabricBackendH : public nixlBackendReqH { void increment_completed_requests(); - /** Get current count of completed requests */ + /** Get current count of requests completed as part of this transfer */ size_t get_completed_requests_count() const; - /** Get total number of requests used for this transfer */ + /** Get total number of requests submitted as part of this transfer */ size_t - get_total_requests_used() const; + get_submitted_requests_count() const; - /** Adjust total request count to actual value after submissions complete */ + /** Adjust total submitted request count to actual value after submissions complete */ void - adjust_total_requests(size_t actual_count); + adjust_total_submitted_requests(size_t actual_count); }; class nixlLibfabricEngine : public nixlBackendEngine { @@ -224,28 +225,28 @@ class nixlLibfabricEngine : public nixlBackendEngine { // Notification Queuing struct PendingNotification { std::string remote_agent; - std::string message; - uint16_t post_xfer_id; - uint32_t expected_completions; // Expected transfer requests for this post_xfer_id + std::vector message_fragments; // Store each fragment separately + uint16_t notif_xfer_id; + uint32_t expected_completions; // Expected transfer requests for this notif_xfer_id uint32_t received_completions; // Actual remote transfer completions received for this - // post_xfer_id - - // Default constructor for map operations - PendingNotification() : post_xfer_id(0), expected_completions(0), received_completions(0) {} - - PendingNotification(const std::string &agent, - const std::string &msg, - uint16_t xfer_id, - uint32_t expected_cnt = 0) - : remote_agent(agent), - message(msg), - post_xfer_id(xfer_id), - expected_completions(expected_cnt), - received_completions(0) {} + // notif_xfer_id + uint16_t expected_msg_fragments; // Total fragments expected (from notif_seq_len) + uint16_t received_msg_fragments; // Fragments received so far + uint32_t total_message_length; // Total length of complete message (all fragments) + uint16_t agent_name_length; // Length of agent_name in combined payload + + PendingNotification(uint16_t xfer_id) + : notif_xfer_id(xfer_id), + expected_completions(0), + received_completions(0), + expected_msg_fragments(0), + received_msg_fragments(0), + total_message_length(0), + agent_name_length(0) {} }; // O(1) lookup with postXferID key - std::map pending_notifications_; + std::unordered_map pending_notifications_; // Connection management helpers nixl_status_t @@ -256,9 +257,21 @@ class nixlLibfabricEngine : public nixlBackendEngine { createAgentConnection(const std::string &agent_name, const std::vector> &data_rail_endpoints, const std::vector> &control_rail_endpoints); + // Private notification implementation with unified binary notification system nixl_status_t - notifSendPriv(const std::string &remote_agent, BinaryNotification &binary_notification) const; + notifSendPriv(const std::string &remote_agent, + std::vector &binary_notifications, + uint32_t total_message_length, + uint16_t notif_xfer_id, + uint32_t expected_completions) const; + + // Private function to fragment notification messages to binary notifications + void + fragmentNotificationMessage(const std::string &message, + const std::string &agent_name, + uint32_t &total_message_length, + std::vector &fragments_out) const; #ifdef HAVE_CUDA // CUDA context management std::unique_ptr cudaCtx_; diff --git a/src/utils/libfabric/libfabric_common.cpp b/src/utils/libfabric/libfabric_common.cpp index 140694e118..5257307162 100644 --- a/src/utils/libfabric/libfabric_common.cpp +++ b/src/utils/libfabric/libfabric_common.cpp @@ -1,6 +1,6 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-FileCopyrightText: Copyright (c) 2025 Amazon.com, Inc. and affiliates. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026 Amazon.com, Inc. and affiliates. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -96,12 +96,11 @@ getAvailableNetworkDevices() { } std::string -hexdump(const void *data) { - static constexpr uint HEXDUMP_MAX_LENGTH = 56; +hexdump(const void *data, size_t size) { std::stringstream ss; - ss.str().reserve(HEXDUMP_MAX_LENGTH * 3); + ss.str().reserve(size * 3); const unsigned char *bytes = static_cast(data); - for (size_t i = 0; i < HEXDUMP_MAX_LENGTH; ++i) { + for (size_t i = 0; i < size; ++i) { ss << std::hex << std::setw(2) << std::setfill('0') << static_cast(bytes[i]) << " "; } return ss.str(); diff --git a/src/utils/libfabric/libfabric_common.h b/src/utils/libfabric/libfabric_common.h index f98339b3d5..dde212712d 100644 --- a/src/utils/libfabric/libfabric_common.h +++ b/src/utils/libfabric/libfabric_common.h @@ -1,6 +1,6 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-FileCopyrightText: Copyright (c) 2025 Amazon.com, Inc. and affiliates. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026 Amazon.com, Inc. and affiliates. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -23,6 +23,7 @@ #include #include #include +#include #include "nixl.h" @@ -44,9 +45,10 @@ #define LF_EP_NAME_MAX_LEN 56 // Request pool configuration constants -#define NIXL_LIBFABRIC_CONTROL_REQUESTS_PER_RAIL 1024 // SEND/RECV operations (1:1 with buffers) +#define NIXL_LIBFABRIC_CONTROL_REQUESTS_PER_RAIL 4096 // SEND/RECV operations (1:1 with buffers) #define NIXL_LIBFABRIC_DATA_REQUESTS_PER_RAIL 1024 // WRITE/read operations (no buffers) #define NIXL_LIBFABRIC_SEND_RECV_BUFFER_SIZE 8192 +#define NIXL_LIBFABRIC_RECV_POOL_SIZE 1024 // Number of recv requests to pre-post per rail // Retry configuration constants #define NIXL_LIBFABRIC_MAX_RETRIES 10 @@ -99,51 +101,161 @@ (((uint64_t)(seq_id) & NIXL_SEQ_ID_MASK) << NIXL_SEQ_ID_SHIFT)) /** - * @brief Binary notification format with counter-based matching + * @brief Notification header for all fragments (10 bytes) * - * This structure provides a fixed-size, binary format for notifications + * This is present in every fragment and contains only the essential + * fields needed for fragment identification and reassembly. */ -struct BinaryNotification { - char agent_name[256]; // Fixed-size agent name (null-terminated) - char message[1024]; // Fixed-size message (binary data, not null-terminated) - uint32_t message_length; // Actual length of message data - uint16_t xfer_id; // 16-bit postXfer ID (unique per postXfer call) - uint32_t expected_completions; // Total write requests for this xfer_id - - /** @brief Clear all fields to zero */ +struct BinaryNotificationHeader { + uint16_t notif_xfer_id; // Transfer ID for matching notifications + uint16_t notif_seq_id; // Fragment index (0, 1, 2...) + uint16_t notif_seq_len; // Total number of fragments + uint32_t payload_length; // Message bytes of this fragment +} __attribute__((packed)); + +/** + * @brief Metadata for fragment 0 only (10 bytes) + * + * This contains metadata that is constant across all fragments, + * so we only send it once in the first fragment. + */ +struct BinaryNotificationMetadata { + uint32_t total_payload_length; // Total message bytes across all fragments + uint32_t expected_completions; // Expected RDMA write completions + uint16_t agent_name_length; // Actual length of agent_name +} __attribute__((packed)); + +/** + * @brief Binary notification with variable-length encoding and fragmentation support + * + * The notification payload consists of agent_name + message, which is treated as a single + * combined payload that can be fragmented across multiple network messages. + * + * Fragment 0 layout: [Header:10B] [Metadata:10B] [combined_payload_chunk:variable] + * Fragment 1+ layout: [Header:10B] [combined_payload_chunk:variable] + * + * After reassembly, use metadata.agent_name_length to split the combined payload: + * - agent_name = combined_payload.substr(0, agent_name_length) + * - message = combined_payload.substr(agent_name_length) + * + * @note The __attribute__((packed)) ensures consistent byte layout across platforms, + * preventing padding-related data corruption during network serialization. + */ +class BinaryNotification { +private: + BinaryNotificationHeader header_; + BinaryNotificationMetadata metadata_; // Only valid for seq_id=0 + std::string payload_; // Chunk of (agent_name + message) combined payload + +public: + /** @brief Maximum fragment size for control messages */ + static constexpr size_t MAX_FRAGMENT_SIZE = NIXL_LIBFABRIC_SEND_RECV_BUFFER_SIZE; + + /** @brief Constructor */ + BinaryNotification() { + memset(&header_, 0, sizeof(header_)); + memset(&metadata_, 0, sizeof(metadata_)); + } + + /** @brief Set header fields */ void - clear() { - memset(this, 0, sizeof(BinaryNotification)); + setHeader(const BinaryNotificationHeader &header) { + header_ = header; } - /** @brief Set agent name with bounds checking */ + /** + * @brief Set metadata (only valid for fragment 0) + * @param total_payload_length Total length of combined payload across all fragments + * @param expected_completions Expected RDMA write completions + * @param agent_name_length Length of agent_name within combined payload + * @pre header_.notif_seq_id must be 0 + */ void - setAgentName(const std::string &name) { - strncpy(agent_name, name.c_str(), sizeof(agent_name) - 1); - agent_name[sizeof(agent_name) - 1] = '\0'; + setMetadata(uint32_t total_payload_length, + uint32_t expected_completions, + uint16_t agent_name_length) { + assert(header_.notif_seq_id == 0 && "setMetadata() can only be called for fragment 0"); + metadata_.total_payload_length = total_payload_length; + metadata_.expected_completions = expected_completions; + metadata_.agent_name_length = agent_name_length; } - /** @brief Set message with bounds checking and proper binary data handling */ + /** + * @brief Set payload chunk for this fragment using move semantics + * @param payload Chunk of (agent_name + message) combined payload (passed by value for move) + * @note Also updates header_.payload_length to match the chunk size + */ void - setMessage(const std::string &msg) { - message_length = std::min(msg.length(), sizeof(message)); - memcpy(message, msg.data(), message_length); - // Zero out remaining space for consistency - if (message_length < sizeof(message)) { - memset(message + message_length, 0, sizeof(message) - message_length); - } + setPayload(std::string payload) { + payload_ = std::move(payload); + header_.payload_length = static_cast(payload_.length()); } - /** @brief Get agent name as string */ - std::string - getAgentName() const { - return std::string(agent_name); + /** @brief Get header (valid for all fragments) */ + const BinaryNotificationHeader & + getHeader() const { + return header_; } - /** @brief Get message as string using stored length for proper binary data handling */ - std::string - getMessage() const { - return std::string(message, message_length); + /** + * @brief Get metadata (only valid for fragment 0) + * @return Reference to metadata + * @pre header_.notif_seq_id must be 0 + */ + const BinaryNotificationMetadata & + getMetadata() const { + assert(header_.notif_seq_id == 0 && "getMetadata() can only be called for fragment 0"); + return metadata_; + } + + /** @brief Get payload chunk for this fragment */ + const std::string & + getPayload() const { + return payload_; + } + + /** @brief Serialize to buffer for transmission */ + size_t + serialize(void *buffer) const { + char *ptr = static_cast(buffer); + size_t offset = 0; + + // Write header (always present) + memcpy(ptr + offset, &header_, sizeof(header_)); + offset += sizeof(header_); + + if (header_.notif_seq_id == 0) { + // Fragment 0: write metadata + memcpy(ptr + offset, &metadata_, sizeof(metadata_)); + offset += sizeof(metadata_); + } + + // Write payload chunk (single memcpy) + memcpy(ptr + offset, payload_.data(), payload_.size()); + offset += payload_.size(); + + return offset; + } + + /** @brief Deserialize from buffer */ + static void + deserialize(const void *buffer, size_t size, BinaryNotification ¬if_out) { + const char *ptr = static_cast(buffer); + size_t offset = 0; + + // Read header + memcpy(¬if_out.header_, ptr + offset, sizeof(notif_out.header_)); + offset += sizeof(notif_out.header_); + + if (notif_out.header_.notif_seq_id == 0) { + // Fragment 0: read metadata + memcpy(¬if_out.metadata_, ptr + offset, sizeof(notif_out.metadata_)); + offset += sizeof(notif_out.metadata_); + } + + // Read payload chunk + size_t remaining = size - offset; + notif_out.payload_.assign(ptr + offset, remaining); } }; @@ -167,7 +279,7 @@ std::pair> getAvailableNetworkDevices(); // String utilities std::string -hexdump(const void *data); +hexdump(const void *data, size_t size); } // namespace LibfabricUtils #endif // NIXL_SRC_UTILS_LIBFABRIC_LIBFABRIC_COMMON_H diff --git a/src/utils/libfabric/libfabric_rail.cpp b/src/utils/libfabric/libfabric_rail.cpp index 379582b20e..bb46a1597d 100644 --- a/src/utils/libfabric/libfabric_rail.cpp +++ b/src/utils/libfabric/libfabric_rail.cpp @@ -1,6 +1,6 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-FileCopyrightText: Copyright (c) 2025 Amazon.com, Inc. and affiliates. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026 Amazon.com, Inc. and affiliates. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -118,7 +118,7 @@ RequestPool::getPoolUtilization() const { } nixlLibfabricReq * -RequestPool::allocateReq() { +RequestPool::allocateReq(uint32_t req_id) { std::lock_guard lock(pool_mutex_); if (free_indices_.empty()) { @@ -148,7 +148,7 @@ RequestPool::allocateReq() { nixlLibfabricReq *req = &requests_[idx]; req->in_use = true; - req->xfer_id = LibfabricUtils::getNextXferId(); + req->xfer_id = req_id; return req; } @@ -299,7 +299,7 @@ ControlRequestPool::expandPool() { } nixlLibfabricReq * -ControlRequestPool::allocate(size_t needed_size) { +ControlRequestPool::allocate(size_t needed_size, uint32_t req_id) { // Validate size before attempting allocation if (needed_size > NIXL_LIBFABRIC_SEND_RECV_BUFFER_SIZE) { NIXL_ERROR << "Control pool allocation failed on rail " << rail_id_ << " - requested size " @@ -309,7 +309,7 @@ ControlRequestPool::allocate(size_t needed_size) { } // Use common allocation logic from base class - nixlLibfabricReq *req = allocateReq(); + nixlLibfabricReq *req = allocateReq(req_id); if (req) { // Always reset buffer_size to the actual message size needed @@ -370,9 +370,9 @@ DataRequestPool::expandPool() { } nixlLibfabricReq * -DataRequestPool::allocate(nixlLibfabricReq::OpType op_type) { +DataRequestPool::allocate(nixlLibfabricReq::OpType op_type, uint32_t req_id) { // Use common allocation logic from base class - nixlLibfabricReq *req = allocateReq(); + nixlLibfabricReq *req = allocateReq(req_id); if (req) { // Set the operation type specific to data requests req->operation_type = op_type; @@ -586,20 +586,29 @@ nixlLibfabricRail::nixlLibfabricRail(const std::string &device, << " control requests, " << NIXL_LIBFABRIC_DATA_REQUESTS_PER_RAIL << " data requests for rail " << rail_id; - // Post initial receive using new resource management system - nixlLibfabricReq *recv_req = allocateControlRequest(NIXL_LIBFABRIC_SEND_RECV_BUFFER_SIZE); - if (!recv_req) { - NIXL_ERROR << "Failed to allocate request for initial receive on rail " << rail_id; - throw std::runtime_error("Failed to allocate request for initial receive on rail " + - std::to_string(rail_id)); - } - status = postRecv(recv_req); - if (status != NIXL_SUCCESS) { - NIXL_ERROR << "Failed to post initial receive on rail " << rail_id; - releaseRequest(recv_req); - throw std::runtime_error("Failed to post initial receive on rail " + - std::to_string(rail_id)); + // Post initial pool of receives using new resource management system + NIXL_INFO << "Pre-posting " << NIXL_LIBFABRIC_RECV_POOL_SIZE << " recv requests for rail " + << rail_id; + + for (size_t i = 0; i < NIXL_LIBFABRIC_RECV_POOL_SIZE; ++i) { + nixlLibfabricReq *recv_req = allocateControlRequest( + NIXL_LIBFABRIC_SEND_RECV_BUFFER_SIZE, LibfabricUtils::getNextXferId()); + if (!recv_req) { + NIXL_ERROR << "Failed to allocate request for recv " << i << " on rail " << rail_id; + throw std::runtime_error("Failed to allocate request for recv pool on rail " + + std::to_string(rail_id)); + } + status = postRecv(recv_req); + if (status != NIXL_SUCCESS) { + NIXL_ERROR << "Failed to post recv " << i << " on rail " << rail_id; + releaseRequest(recv_req); + throw std::runtime_error("Failed to post recv pool on rail " + + std::to_string(rail_id)); + } } + + NIXL_INFO << "Successfully pre-posted " << NIXL_LIBFABRIC_RECV_POOL_SIZE + << " recv requests for rail " << rail_id; NIXL_TRACE << "Successfully initialized rail " << rail_id; } catch (...) { @@ -971,7 +980,8 @@ nixlLibfabricRail::processRecvCompletion(struct fi_cq_data_entry *comp) const { releaseRequest(req); // Post a new receive using new resource management system - nixlLibfabricReq *new_req = allocateControlRequest(NIXL_LIBFABRIC_SEND_RECV_BUFFER_SIZE); + nixlLibfabricReq *new_req = allocateControlRequest(NIXL_LIBFABRIC_SEND_RECV_BUFFER_SIZE, + LibfabricUtils::getNextXferId()); if (!new_req) { NIXL_ERROR << "Failed to allocate request for subsequent receive on rail " << rail_id; return NIXL_ERR_BACKEND; @@ -1455,13 +1465,13 @@ nixlLibfabricRail::getMemoryKey(struct fid_mr *mr) const { // Optimized Resource Management Methods nixlLibfabricReq * -nixlLibfabricRail::allocateControlRequest(size_t needed_size) const { - return const_cast(control_request_pool_).allocate(needed_size); +nixlLibfabricRail::allocateControlRequest(size_t needed_size, uint32_t req_id) const { + return const_cast(control_request_pool_).allocate(needed_size, req_id); } nixlLibfabricReq * -nixlLibfabricRail::allocateDataRequest(nixlLibfabricReq::OpType op_type) const { - return const_cast(data_request_pool_).allocate(op_type); +nixlLibfabricRail::allocateDataRequest(nixlLibfabricReq::OpType op_type, uint32_t req_id) const { + return const_cast(data_request_pool_).allocate(op_type, req_id); } void @@ -1499,3 +1509,8 @@ nixlLibfabricRail::findRequestFromContext(void *context) const { NIXL_ERROR << "No request found for context " << context << " on rail " << rail_id; return nullptr; } + +fi_info * +nixlLibfabricRail::getRailInfo() const { + return info; +} diff --git a/src/utils/libfabric/libfabric_rail.h b/src/utils/libfabric/libfabric_rail.h index 33c44da6dd..5c335c93bd 100644 --- a/src/utils/libfabric/libfabric_rail.h +++ b/src/utils/libfabric/libfabric_rail.h @@ -1,6 +1,6 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-FileCopyrightText: Copyright (c) 2025 Amazon.com, Inc. and affiliates. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026 Amazon.com, Inc. and affiliates. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -109,7 +109,7 @@ class RequestPool { protected: /** Common allocation logic shared by both pool types */ nixlLibfabricReq * - allocateReq(); + allocateReq(uint32_t req_id); public: // Non-copyable and non-movable since we use unique_ptr for management @@ -162,7 +162,7 @@ class ControlRequestPool : public RequestPool { /** Allocate control request with size validation */ nixlLibfabricReq * - allocate(size_t needed_size); + allocate(size_t needed_size, uint32_t req_id); /** Expand pool by adding new buffer chunk - implements pure virtual */ nixl_status_t @@ -205,7 +205,7 @@ class DataRequestPool : public RequestPool { /** Allocate data request for specified operation type */ nixlLibfabricReq * - allocate(nixlLibfabricReq::OpType op_type); + allocate(nixlLibfabricReq::OpType op_type, uint32_t req_id); /** Expand pool by doubling request count - implements pure virtual */ nixl_status_t @@ -361,11 +361,11 @@ class nixlLibfabricRail { // Optimized resource management methods /** Allocate control request with size validation */ [[nodiscard]] nixlLibfabricReq * - allocateControlRequest(size_t needed_size) const; + allocateControlRequest(size_t needed_size, uint32_t req_id) const; /** Allocate data request for specified operation */ [[nodiscard]] nixlLibfabricReq * - allocateDataRequest(nixlLibfabricReq::OpType op_type) const; + allocateDataRequest(nixlLibfabricReq::OpType op_type, uint32_t req_id) const; /** Release request back to appropriate pool */ void @@ -375,6 +375,9 @@ class nixlLibfabricRail { nixlLibfabricReq * findRequestFromContext(void *context) const; + fi_info * + getRailInfo() const; + private: // Core libfabric resources struct fi_info *info; // from rail_infos[rail_id] diff --git a/src/utils/libfabric/libfabric_rail_manager.cpp b/src/utils/libfabric/libfabric_rail_manager.cpp index 5bf23ae165..92124340f7 100644 --- a/src/utils/libfabric/libfabric_rail_manager.cpp +++ b/src/utils/libfabric/libfabric_rail_manager.cpp @@ -1,6 +1,6 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-FileCopyrightText: Copyright (c) 2025 Amazon.com, Inc. and affiliates. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026 Amazon.com, Inc. and affiliates. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -148,8 +148,12 @@ nixlLibfabricRailManager::prepareAndSubmitTransfer( const std::vector &remote_selected_endpoints, const std::unordered_map> &dest_addrs, uint16_t agent_idx, + uint16_t xfer_id, std::function completion_callback, - BinaryNotification *binary_notif) { + size_t &submitted_count_out) { + // Initialize output parameter + submitted_count_out = 0; + if (selected_rails.empty()) { NIXL_ERROR << "No rails selected for transfer"; return NIXL_ERR_INVALID_PARAM; @@ -166,7 +170,7 @@ nixlLibfabricRailManager::prepareAndSubmitTransfer( remote_selected_endpoints[counter_value % remote_selected_endpoints.size()]; NIXL_DEBUG << "rail " << rail_id << ", remote_ep_id " << remote_ep_id; // Allocate request - nixlLibfabricReq *req = data_rails_[rail_id]->allocateDataRequest(op_type); + nixlLibfabricReq *req = data_rails_[rail_id]->allocateDataRequest(op_type, xfer_id); if (!req) { NIXL_ERROR << "Failed to allocate request for rail " << rail_id; return NIXL_ERR_BACKEND; @@ -179,13 +183,12 @@ nixlLibfabricRailManager::prepareAndSubmitTransfer( // For TCP providers, use offset 0 instead of virtual address // TCP providers don't support FI_MR_VIRT_ADDR and expect offset-based addressing - if (data_rails_[rail_id]->provider_name == "tcp" || - data_rails_[rail_id]->provider_name == "sockets") { + if (data_rails_[rail_id]->getRailInfo()->domain_attr->mr_mode & FI_MR_VIRT_ADDR) { + req->remote_addr = remote_base_addr; // Use virtual address for EFA and other providers + } else { req->remote_addr = 0; // Use offset 0 for TCP providers NIXL_DEBUG << "TCP provider detected: using offset 0 instead of virtual address " << (void *)remote_base_addr << " for rail " << rail_id; - } else { - req->remote_addr = remote_base_addr; // Use virtual address for EFA and other providers } req->local_mr = local_mrs[rail_id]; @@ -196,8 +199,8 @@ nixlLibfabricRailManager::prepareAndSubmitTransfer( if (op_type == nixlLibfabricReq::WRITE) { // Generate next SEQ_ID for this specific write operation uint8_t seq_id = LibfabricUtils::getNextSeqId(); - uint64_t imm_data = NIXL_MAKE_IMM_DATA( - NIXL_LIBFABRIC_MSG_TRANSFER, agent_idx, binary_notif->xfer_id, seq_id); + uint64_t imm_data = + NIXL_MAKE_IMM_DATA(NIXL_LIBFABRIC_MSG_TRANSFER, agent_idx, xfer_id, seq_id); status = data_rails_[rail_id]->postWrite(req->local_addr, req->chunk_size, fi_mr_desc(req->local_mr), @@ -224,7 +227,8 @@ nixlLibfabricRailManager::prepareAndSubmitTransfer( return status; } - binary_notif->expected_completions++; + // Track submitted request + submitted_count_out = 1; NIXL_DEBUG << "Round-robin: submitted single request on rail " << rail_id << " for " << transfer_size << " bytes, XFER_ID=" << req->xfer_id; @@ -242,7 +246,7 @@ nixlLibfabricRailManager::prepareAndSubmitTransfer( size_t current_chunk_size = chunk_size + (i == num_rails - 1 ? remainder : 0); if (current_chunk_size == 0) break; // Allocate request - nixlLibfabricReq *req = data_rails_[rail_id]->allocateDataRequest(op_type); + nixlLibfabricReq *req = data_rails_[rail_id]->allocateDataRequest(op_type, xfer_id); if (!req) { NIXL_ERROR << "Failed to allocate request for rail " << rail_id; return NIXL_ERR_BACKEND; @@ -258,15 +262,14 @@ nixlLibfabricRailManager::prepareAndSubmitTransfer( // For TCP providers, use offset instead of virtual address // TCP providers don't support FI_MR_VIRT_ADDR and expect offset-based addressing - if (data_rails_[rail_id]->provider_name == "tcp" || - data_rails_[rail_id]->provider_name == "sockets") { + if (data_rails_[rail_id]->getRailInfo()->domain_attr->mr_mode & FI_MR_VIRT_ADDR) { + req->remote_addr = remote_base_addr + + chunk_offset; // Use virtual address for EFA and other providers + } else { req->remote_addr = chunk_offset; // Use chunk offset for TCP providers NIXL_DEBUG << "TCP provider detected: using chunk offset " << chunk_offset << " instead of virtual address " << (void *)(remote_base_addr + chunk_offset) << " for rail " << rail_id; - } else { - req->remote_addr = remote_base_addr + - chunk_offset; // Use virtual address for EFA and other providers } req->local_mr = local_mrs[rail_id]; @@ -276,8 +279,8 @@ nixlLibfabricRailManager::prepareAndSubmitTransfer( if (op_type == nixlLibfabricReq::WRITE) { // Generate next SEQ_ID for this specific transfer operation uint8_t seq_id = LibfabricUtils::getNextSeqId(); - uint64_t imm_data = NIXL_MAKE_IMM_DATA( - NIXL_LIBFABRIC_MSG_TRANSFER, agent_idx, binary_notif->xfer_id, seq_id); + uint64_t imm_data = + NIXL_MAKE_IMM_DATA(NIXL_LIBFABRIC_MSG_TRANSFER, agent_idx, xfer_id, seq_id); status = data_rails_[rail_id]->postWrite(req->local_addr, req->chunk_size, fi_mr_desc(req->local_mr), @@ -304,16 +307,14 @@ nixlLibfabricRailManager::prepareAndSubmitTransfer( return status; } - binary_notif->expected_completions++; + // Track submitted request + submitted_count_out++; } - NIXL_DEBUG << "Striping: submitted " - << (binary_notif ? binary_notif->expected_completions : 0) << " requests for " + NIXL_DEBUG << "Striping: submitted " << submitted_count_out << " requests for " << transfer_size << " bytes"; } - NIXL_DEBUG << "Successfully submitted " - << (binary_notif ? binary_notif->expected_completions : 0) << " requests for " - << transfer_size << " bytes"; + NIXL_DEBUG << "Successfully submitted requests for " << transfer_size << " bytes"; return NIXL_SUCCESS; } diff --git a/src/utils/libfabric/libfabric_rail_manager.h b/src/utils/libfabric/libfabric_rail_manager.h index 9baca7c18f..e824eb2777 100644 --- a/src/utils/libfabric/libfabric_rail_manager.h +++ b/src/utils/libfabric/libfabric_rail_manager.h @@ -1,6 +1,6 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-FileCopyrightText: Copyright (c) 2025 Amazon.com, Inc. and affiliates. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026 Amazon.com, Inc. and affiliates. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -167,8 +167,9 @@ class nixlLibfabricRailManager { * @param remote_selected_endpoints Selected remote endpoints, where remote keys are registered * @param dest_addrs Destination addresses for each rail * @param agent_idx Remote agent index for immediate data + * @param xfer_id Transfer ID for tracking * @param completion_callback Callback for completion notification - * @param binary_notif Binary notification to populate with XFER_IDs + * @param submitted_count_out Number of requests successfully submitted * @return NIXL_SUCCESS on success, error code on failure */ nixl_status_t @@ -182,8 +183,9 @@ class nixlLibfabricRailManager { const std::vector &remote_selected_endpoints, const std::unordered_map> &dest_addrs, uint16_t agent_idx, + uint16_t xfer_id, std::function completion_callback, - BinaryNotification *binary_notif); + size_t &submitted_count_out); /** Determine if striping should be used for given transfer size * @param transfer_size Size of the transfer in bytes * @return true if striping should be used, false for round-robin