512 changes: 345 additions & 167 deletions src/plugins/libfabric/libfabric_backend.cpp

Large diffs are not rendered by default.

69 changes: 41 additions & 28 deletions src/plugins/libfabric/libfabric_backend.h
@@ -1,6 +1,6 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2025 Amazon.com, Inc. and affiliates.
* SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2025-2026 Amazon.com, Inc. and affiliates.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -26,7 +26,6 @@
#include <condition_variable>
#include <atomic>
#include <chrono>
#include <map>
#include <unordered_map>
#include <unordered_set>

@@ -137,11 +136,13 @@ class nixlLibfabricBackendH : public nixlBackendReqH {
std::atomic<size_t> submitted_requests_; // Total number of submitted requests

public:
uint16_t post_xfer_id;
const nixl_xfer_op_t operation_;
const std::string remote_agent_;
bool has_notif;
uint32_t total_notif_msg_len; // Total length of notification message across all fragments

BinaryNotification binary_notif; // Direct BinaryNotification instance
std::vector<BinaryNotification> binary_notifs; // Vector of BinaryNotification for fragmentation

nixlLibfabricBackendH(nixl_xfer_op_t op, const std::string &remote_agent);
~nixlLibfabricBackendH();
@@ -158,17 +159,17 @@ class nixlLibfabricBackendH : public nixlBackendReqH {
void
increment_completed_requests();

/** Get current count of completed requests */
/** Get current count of requests completed as part of this transfer */
size_t
get_completed_requests_count() const;

/** Get total number of requests used for this transfer */
/** Get total number of requests submitted as part of this transfer */
size_t
get_total_requests_used() const;
get_submitted_requests_count() const;

/** Adjust total request count to actual value after submissions complete */
/** Adjust total submitted request count to actual value after submissions complete */
void
adjust_total_requests(size_t actual_count);
adjust_total_submitted_requests(size_t actual_count);
};

class nixlLibfabricEngine : public nixlBackendEngine {
@@ -224,28 +225,28 @@ class nixlLibfabricEngine : public nixlBackendEngine {
// Notification Queuing
struct PendingNotification {
std::string remote_agent;
std::string message;
uint16_t post_xfer_id;
uint32_t expected_completions; // Expected transfer requests for this post_xfer_id
std::vector<std::string> message_fragments; // Store each fragment separately
uint16_t notif_xfer_id;
uint32_t expected_completions; // Expected transfer requests for this notif_xfer_id
uint32_t received_completions; // Actual remote transfer completions received for this
// post_xfer_id

// Default constructor for map operations
PendingNotification() : post_xfer_id(0), expected_completions(0), received_completions(0) {}

PendingNotification(const std::string &agent,
const std::string &msg,
uint16_t xfer_id,
uint32_t expected_cnt = 0)
: remote_agent(agent),
message(msg),
post_xfer_id(xfer_id),
expected_completions(expected_cnt),
received_completions(0) {}
// notif_xfer_id
uint16_t expected_msg_fragments; // Total fragments expected (from notif_seq_len)
uint16_t received_msg_fragments; // Fragments received so far
uint32_t total_message_length; // Total length of complete message (all fragments)
uint16_t agent_name_length; // Length of agent_name in combined payload

PendingNotification(uint16_t xfer_id)
: notif_xfer_id(xfer_id),
expected_completions(0),
received_completions(0),
expected_msg_fragments(0),
received_msg_fragments(0),
total_message_length(0),
agent_name_length(0) {}
};

// O(1) lookup with postXferID key
std::map<uint16_t, PendingNotification> pending_notifications_;
std::unordered_map<uint16_t, PendingNotification> pending_notifications_;
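The reassembly logic that consumes these fields lives in libfabric_backend.cpp, whose diff is collapsed above. As a rough, hypothetical sketch of how an incoming fragment could be folded into this per-xfer-id state (the struct below mirrors the declarations in this hunk; `Fragment` and `onNotificationFragment` are illustrative names, not code from this PR):

```cpp
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

// Mirrors the PendingNotification declared in this header (fields only).
struct PendingNotification {
    std::string remote_agent;
    std::vector<std::string> message_fragments;
    uint16_t notif_xfer_id;
    uint32_t expected_completions;
    uint32_t received_completions;
    uint16_t expected_msg_fragments;
    uint16_t received_msg_fragments;
    uint32_t total_message_length;
    uint16_t agent_name_length;
    PendingNotification(uint16_t xfer_id)
        : notif_xfer_id(xfer_id), expected_completions(0), received_completions(0),
          expected_msg_fragments(0), received_msg_fragments(0),
          total_message_length(0), agent_name_length(0) {}
};

// Decoded wire fragment; field names follow BinaryNotificationHeader/Metadata.
struct Fragment {
    uint16_t xfer_id, seq_id, seq_len;
    uint32_t total_payload_length, expected_completions; // valid only when seq_id == 0
    uint16_t agent_name_length;                          // valid only when seq_id == 0
    std::string payload;
};

// Folds one fragment into the pending map; returns true once all fragments arrived.
bool onNotificationFragment(std::unordered_map<uint16_t, PendingNotification> &pending,
                            const Fragment &frag) {
    auto it = pending.try_emplace(frag.xfer_id, frag.xfer_id).first;
    PendingNotification &entry = it->second;

    if (entry.message_fragments.empty())
        entry.message_fragments.resize(frag.seq_len); // one slot per expected fragment
    entry.expected_msg_fragments = frag.seq_len;
    entry.message_fragments[frag.seq_id] = frag.payload; // assumes seq_id < seq_len
    entry.received_msg_fragments++;

    if (frag.seq_id == 0) { // metadata rides only on fragment 0
        entry.total_message_length = frag.total_payload_length;
        entry.expected_completions = frag.expected_completions;
        entry.agent_name_length = frag.agent_name_length;
    }
    return entry.received_msg_fragments == entry.expected_msg_fragments;
}
```

Once it returns true, the caller can concatenate message_fragments and split the result at agent_name_length to recover the agent name and the notification message.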

// Connection management helpers
nixl_status_t
@@ -256,9 +257,21 @@
createAgentConnection(const std::string &agent_name,
const std::vector<std::array<char, 56>> &data_rail_endpoints,
const std::vector<std::array<char, 56>> &control_rail_endpoints);

// Private notification implementation with unified binary notification system
nixl_status_t
notifSendPriv(const std::string &remote_agent, BinaryNotification &binary_notification) const;
notifSendPriv(const std::string &remote_agent,
std::vector<BinaryNotification> &binary_notifications,
uint32_t total_message_length,
uint16_t notif_xfer_id,
uint32_t expected_completions) const;

// Private helper that fragments a notification message into binary notifications
void
fragmentNotificationMessage(const std::string &message,
const std::string &agent_name,
uint32_t &total_message_length,
std::vector<BinaryNotification> &fragments_out) const;
#ifdef HAVE_CUDA
// CUDA context management
std::unique_ptr<nixlLibfabricCudaCtx> cudaCtx_;
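Because the libfabric_backend.cpp diff above is collapsed, here is a minimal sketch of the chunking arithmetic that fragmentNotificationMessage() has to perform under the wire layout documented in libfabric_common.h below, where fragment 0 also carries a metadata block. The constants and the splitFragments name are assumptions, not code from this PR:

```cpp
#include <algorithm>
#include <cstddef>
#include <string>
#include <vector>

constexpr size_t kMaxFragment = 8192; // NIXL_LIBFABRIC_SEND_RECV_BUFFER_SIZE
constexpr size_t kHeaderSize  = 10;   // sizeof(BinaryNotificationHeader)
constexpr size_t kMetaSize    = 10;   // sizeof(BinaryNotificationMetadata), fragment 0 only

// Splits the combined payload (agent_name + message) into per-fragment chunks.
std::vector<std::string> splitFragments(const std::string &agent_name,
                                        const std::string &message) {
    const std::string combined = agent_name + message;
    std::vector<std::string> chunks;
    size_t offset = 0;
    while (offset < combined.size() || chunks.empty()) {
        // Fragment 0 also carries the metadata block, so it holds fewer payload bytes.
        const size_t capacity =
            kMaxFragment - kHeaderSize - (chunks.empty() ? kMetaSize : 0);
        const size_t take = std::min(capacity, combined.size() - offset);
        chunks.emplace_back(combined, offset, take);
        offset += take;
    }
    return chunks; // chunks.size() becomes notif_seq_len on the wire
}
```

Each chunk then becomes the payload of one BinaryNotification, with notif_seq_id set to its index and notif_seq_len set to chunks.size().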
11 changes: 5 additions & 6 deletions src/utils/libfabric/libfabric_common.cpp
@@ -1,6 +1,6 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2025 Amazon.com, Inc. and affiliates.
* SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2025-2026 Amazon.com, Inc. and affiliates.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -96,12 +96,11 @@ getAvailableNetworkDevices() {
}

std::string
hexdump(const void *data) {
static constexpr uint HEXDUMP_MAX_LENGTH = 56;
hexdump(const void *data, size_t size) {
std::stringstream ss;
ss.str().reserve(HEXDUMP_MAX_LENGTH * 3);
ss.str().reserve(size * 3);
const unsigned char *bytes = static_cast<const unsigned char *>(data);
for (size_t i = 0; i < HEXDUMP_MAX_LENGTH; ++i) {
for (size_t i = 0; i < size; ++i) {
ss << std::hex << std::setw(2) << std::setfill('0') << static_cast<int>(bytes[i]) << " ";
}
return ss.str();
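With the length now supplied by the caller instead of a fixed 56-byte constant, a call site might look like the following; the wrapper function is illustrative and the include path is an assumption:

```cpp
#include <cstddef>
#include <string>
#include "libfabric_common.h" // declares LibfabricUtils::hexdump(const void *, size_t)

void logEndpointAddress(const void *addr, size_t addr_len) {
    // Dump exactly the bytes the provider reported rather than a fixed-size window.
    std::string dump = LibfabricUtils::hexdump(addr, addr_len);
    // e.g. feed `dump` into the backend's debug logging here
    (void)dump;
}
```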
184 changes: 148 additions & 36 deletions src/utils/libfabric/libfabric_common.h
@@ -1,6 +1,6 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2025 Amazon.com, Inc. and affiliates.
* SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2025-2026 Amazon.com, Inc. and affiliates.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -23,6 +23,7 @@
#include <unordered_set>
#include <unordered_map>
#include <cstring>
#include <cassert>

#include "nixl.h"

@@ -44,9 +45,10 @@
#define LF_EP_NAME_MAX_LEN 56

// Request pool configuration constants
#define NIXL_LIBFABRIC_CONTROL_REQUESTS_PER_RAIL 1024 // SEND/RECV operations (1:1 with buffers)
#define NIXL_LIBFABRIC_CONTROL_REQUESTS_PER_RAIL 4096 // SEND/RECV operations (1:1 with buffers)
#define NIXL_LIBFABRIC_DATA_REQUESTS_PER_RAIL 1024 // WRITE/read operations (no buffers)
#define NIXL_LIBFABRIC_SEND_RECV_BUFFER_SIZE 8192
#define NIXL_LIBFABRIC_RECV_POOL_SIZE 1024 // Number of recv requests to pre-post per rail

// Retry configuration constants
#define NIXL_LIBFABRIC_MAX_RETRIES 10
@@ -99,51 +101,161 @@
(((uint64_t)(seq_id) & NIXL_SEQ_ID_MASK) << NIXL_SEQ_ID_SHIFT))

/**
* @brief Binary notification format with counter-based matching
* @brief Notification header for all fragments (10 bytes)
*
* This structure provides a fixed-size, binary format for notifications
* This is present in every fragment and contains only the essential
* fields needed for fragment identification and reassembly.
*/
struct BinaryNotification {
char agent_name[256]; // Fixed-size agent name (null-terminated)
char message[1024]; // Fixed-size message (binary data, not null-terminated)
uint32_t message_length; // Actual length of message data
uint16_t xfer_id; // 16-bit postXfer ID (unique per postXfer call)
uint32_t expected_completions; // Total write requests for this xfer_id

/** @brief Clear all fields to zero */
struct BinaryNotificationHeader {
uint16_t notif_xfer_id; // Transfer ID for matching notifications
uint16_t notif_seq_id; // Fragment index (0, 1, 2...)
uint16_t notif_seq_len; // Total number of fragments
uint32_t payload_length; // Message bytes of this fragment
} __attribute__((packed));

/**
* @brief Metadata for fragment 0 only (10 bytes)
*
* This contains metadata that is constant across all fragments,
* so we only send it once in the first fragment.
*/
struct BinaryNotificationMetadata {
uint32_t total_payload_length; // Total message bytes across all fragments
uint32_t expected_completions; // Expected RDMA write completions
uint16_t agent_name_length; // Actual length of agent_name
} __attribute__((packed));

/**
* @brief Binary notification with variable-length encoding and fragmentation support
*
* The notification payload consists of agent_name + message, which is treated as a single
* combined payload that can be fragmented across multiple network messages.
*
* Fragment 0 layout: [Header:10B] [Metadata:10B] [combined_payload_chunk:variable]
* Fragment 1+ layout: [Header:10B] [combined_payload_chunk:variable]
*
* After reassembly, use metadata.agent_name_length to split the combined payload:
* - agent_name = combined_payload.substr(0, agent_name_length)
* - message = combined_payload.substr(agent_name_length)
*
* @note The __attribute__((packed)) ensures consistent byte layout across platforms,
* preventing padding-related data corruption during network serialization.
*/
class BinaryNotification {
private:
BinaryNotificationHeader header_;
BinaryNotificationMetadata metadata_; // Only valid for seq_id=0
std::string payload_; // Chunk of (agent_name + message) combined payload

public:
/** @brief Maximum fragment size for control messages */
static constexpr size_t MAX_FRAGMENT_SIZE = NIXL_LIBFABRIC_SEND_RECV_BUFFER_SIZE;

/** @brief Constructor */
BinaryNotification() {
memset(&header_, 0, sizeof(header_));
memset(&metadata_, 0, sizeof(metadata_));
}

/** @brief Set header fields */
void
clear() {
memset(this, 0, sizeof(BinaryNotification));
setHeader(const BinaryNotificationHeader &header) {
header_ = header;
}

/** @brief Set agent name with bounds checking */
/**
* @brief Set metadata (only valid for fragment 0)
* @param total_payload_length Total length of combined payload across all fragments
* @param expected_completions Expected RDMA write completions
* @param agent_name_length Length of agent_name within combined payload
* @pre header_.notif_seq_id must be 0
*/
void
setAgentName(const std::string &name) {
strncpy(agent_name, name.c_str(), sizeof(agent_name) - 1);
agent_name[sizeof(agent_name) - 1] = '\0';
setMetadata(uint32_t total_payload_length,
uint32_t expected_completions,
uint16_t agent_name_length) {
assert(header_.notif_seq_id == 0 && "setMetadata() can only be called for fragment 0");
metadata_.total_payload_length = total_payload_length;
metadata_.expected_completions = expected_completions;
metadata_.agent_name_length = agent_name_length;
}

/** @brief Set message with bounds checking and proper binary data handling */
/**
* @brief Set payload chunk for this fragment using move semantics
* @param payload Chunk of (agent_name + message) combined payload (passed by value for move)
* @note Also updates header_.payload_length to match the chunk size
*/
void
setMessage(const std::string &msg) {
message_length = std::min(msg.length(), sizeof(message));
memcpy(message, msg.data(), message_length);
// Zero out remaining space for consistency
if (message_length < sizeof(message)) {
memset(message + message_length, 0, sizeof(message) - message_length);
}
setPayload(std::string payload) {
payload_ = std::move(payload);
header_.payload_length = static_cast<uint32_t>(payload_.length());
}

/** @brief Get agent name as string */
std::string
getAgentName() const {
return std::string(agent_name);
/** @brief Get header (valid for all fragments) */
const BinaryNotificationHeader &
getHeader() const {
return header_;
}

/** @brief Get message as string using stored length for proper binary data handling */
std::string
getMessage() const {
return std::string(message, message_length);
/**
* @brief Get metadata (only valid for fragment 0)
* @return Reference to metadata
* @pre header_.notif_seq_id must be 0
*/
const BinaryNotificationMetadata &
getMetadata() const {
assert(header_.notif_seq_id == 0 && "getMetadata() can only be called for fragment 0");
return metadata_;
}

/** @brief Get payload chunk for this fragment */
const std::string &
getPayload() const {
return payload_;
}

/** @brief Serialize to buffer for transmission */
size_t
serialize(void *buffer) const {
char *ptr = static_cast<char *>(buffer);
size_t offset = 0;

// Write header (always present)
memcpy(ptr + offset, &header_, sizeof(header_));
offset += sizeof(header_);

if (header_.notif_seq_id == 0) {
// Fragment 0: write metadata
memcpy(ptr + offset, &metadata_, sizeof(metadata_));
offset += sizeof(metadata_);
}

// Write payload chunk (single memcpy)
memcpy(ptr + offset, payload_.data(), payload_.size());
offset += payload_.size();

return offset;
}

/** @brief Deserialize from buffer */
static void
deserialize(const void *buffer, size_t size, BinaryNotification &notif_out) {
const char *ptr = static_cast<const char *>(buffer);
size_t offset = 0;

// Read header
memcpy(&notif_out.header_, ptr + offset, sizeof(notif_out.header_));
offset += sizeof(notif_out.header_);

if (notif_out.header_.notif_seq_id == 0) {
// Fragment 0: read metadata
memcpy(&notif_out.metadata_, ptr + offset, sizeof(notif_out.metadata_));
offset += sizeof(notif_out.metadata_);
}

// Read payload chunk
size_t remaining = size - offset;
notif_out.payload_.assign(ptr + offset, remaining);
}
};
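Taken together, the header, metadata, and payload accessors support a simple round trip. The example below builds a single-fragment notification, serializes it, and splits the reassembled payload back into agent name and message using agent_name_length; the agent/message strings, the expected_completions value, and the include path are illustrative:

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include "libfabric_common.h" // assumed include path for the definitions above

int main() {
    const std::string agent = "agent-A";         // illustrative agent name
    const std::string message = "transfer done"; // illustrative notification body

    BinaryNotification frag;
    BinaryNotificationHeader hdr{};
    hdr.notif_xfer_id = 42; // would come from the sender's post_xfer_id
    hdr.notif_seq_id = 0;   // fragment 0 carries the metadata block
    hdr.notif_seq_len = 1;  // the whole payload fits in one fragment
    frag.setHeader(hdr);
    frag.setMetadata(static_cast<uint32_t>(agent.size() + message.size()),
                     /*expected_completions=*/4,
                     static_cast<uint16_t>(agent.size()));
    frag.setPayload(agent + message); // also fills header payload_length

    char wire[BinaryNotification::MAX_FRAGMENT_SIZE];
    size_t wire_len = frag.serialize(wire);

    BinaryNotification decoded;
    BinaryNotification::deserialize(wire, wire_len, decoded);
    const BinaryNotificationMetadata &meta = decoded.getMetadata(); // seq_id == 0, so valid
    const std::string &payload = decoded.getPayload();
    std::string agent_out = payload.substr(0, meta.agent_name_length);
    std::string message_out = payload.substr(meta.agent_name_length);
    std::cout << agent_out << ": " << message_out << "\n"; // agent-A: transfer done
    return 0;
}
```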

@@ -167,7 +279,7 @@ std::pair<std::string, std::vector<std::string>>
getAvailableNetworkDevices();
// String utilities
std::string
hexdump(const void *data);
hexdump(const void *data, size_t size);
} // namespace LibfabricUtils

#endif // NIXL_SRC_UTILS_LIBFABRIC_LIBFABRIC_COMMON_H