src/plugins/libfabric/libfabric_backend.cpp (35 additions, 5 deletions)
@@ -1,6 +1,6 @@
 /*
- * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
- * SPDX-FileCopyrightText: Copyright (c) 2025 Amazon.com, Inc. and affiliates.
+ * SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-FileCopyrightText: Copyright (c) 2025-2026 Amazon.com, Inc. and affiliates.
  * SPDX-License-Identifier: Apache-2.0
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
@@ -56,7 +56,7 @@
 
 #ifdef HAVE_CUDA
 static int
-cudaQueryAddr(void *address, bool &is_dev, CUdevice &dev, CUcontext &ctx) {
+cudaQueryAddr(void *address, bool &is_dev, CUdevice &dev, CUcontext &ctx, std::string &pci_bus_id) {
     CUmemorytype mem_type = CU_MEMORYTYPE_HOST;
     uint32_t is_managed = 0;
     CUpointer_attribute attr_type[4];
@@ -75,6 +75,19 @@ cudaQueryAddr(void *address, bool &is_dev, CUdevice &dev, CUcontext &ctx) {
     result = cuPointerGetAttributes(4, attr_type, attr_data, (CUdeviceptr)address);
     is_dev = (mem_type == CU_MEMORYTYPE_DEVICE);
 
+    // Get PCI bus ID if device memory
+    if (result == CUDA_SUCCESS && is_dev) {
+        char pci_buf[32];
+        CUresult pci_result = cuDeviceGetPCIBusId(pci_buf, sizeof(pci_buf), dev);
+        if (pci_result == CUDA_SUCCESS) {
+            pci_bus_id = std::string(pci_buf);
+        } else {
+            pci_bus_id = "";
+        }
+    } else {
+        pci_bus_id = "";
+    }
+
     return (CUDA_SUCCESS != result);
 }
 
@@ -89,14 +102,15 @@ nixlLibfabricCudaCtx::cudaUpdateCtxPtr(void *address, int expected_dev, bool &wa
     bool is_dev;
     CUdevice dev;
     CUcontext ctx;
+    std::string pci_bus_id; // Not used here, but required by cudaQueryAddr
     int ret;
 
     was_updated = false;
 
     if (expected_dev == -1) return -1;
     if (myDevId_ != -1 && expected_dev != myDevId_) return -1;
 
-    ret = cudaQueryAddr(address, is_dev, dev, ctx);
+    ret = cudaQueryAddr(address, is_dev, dev, ctx, pci_bus_id);
     if (ret) return ret;
     if (!is_dev) return 0;
     if (dev != expected_dev) return -1;
@@ -734,6 +748,7 @@ nixlLibfabricEngine::registerMem(const nixlBlobDesc &mem,
     priv->length_ = mem.len;
     priv->gpu_device_id_ = mem.devId; // Store GPU device ID
 
+    std::string pci_bus_id = "";
 #ifdef HAVE_CUDA
     // Handle CUDA memory registration with GPU Direct RDMA support
     if (nixl_mem == VRAM_SEG) {
@@ -760,6 +775,19 @@
             }
             NIXL_DEBUG << "Set CUDA device context to GPU " << mem.devId;
         }
+
+        // Query PCI bus ID from memory address (AFTER setting context)
+        bool is_dev;
+        CUdevice dev;
+        CUcontext ctx;
+
+        int ret = cudaQueryAddr((void *)mem.addr, is_dev, dev, ctx, pci_bus_id);
+        if (ret || !is_dev) {
+            NIXL_ERROR << "Failed to query device from memory " << (void *)mem.addr;
+            return NIXL_ERR_BACKEND;
+        }
+
+        NIXL_DEBUG << "Queried PCI bus ID: " << pci_bus_id << " for GPU " << mem.devId;
     }
 #endif
 
@@ -777,12 +805,14 @@ nixlLibfabricEngine::registerMem(const nixlBlobDesc &mem,
 
     // Use Rail Manager for centralized memory registration with GPU Direct RDMA support
     NIXL_TRACE << "Registering memory: addr=" << (void *)mem.addr << " len=" << mem.len
-               << " mem_type=" << nixl_mem << " devId=" << mem.devId;
+               << " mem_type=" << nixl_mem << " devId=" << mem.devId
+               << (nixl_mem == VRAM_SEG ? " pci_bus_id=" + pci_bus_id : "");
 
     nixl_status_t status = rail_manager.registerMemory((void *)mem.addr,
                                                        mem.len,
                                                        nixl_mem,
                                                        mem.devId,
+                                                       pci_bus_id,
                                                        priv->rail_mr_list_,
                                                        priv->rail_key_list_,
                                                        priv->selected_rails_);
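For context, here is a minimal standalone sketch of the PCI-bus-ID lookup the patched `cudaQueryAddr()` performs. It uses the CUDA driver API directly, but queries the two pointer attributes one at a time instead of batching four through `cuPointerGetAttributes()` as the patch does; the helper name `queryPciBusId` and the `main()` harness are ours, not part of the patch.

```cpp
// pci_query_sketch.cpp (illustrative only); build with: g++ pci_query_sketch.cpp -lcuda
#include <cuda.h>
#include <cstdio>
#include <string>

// Resolve the PCI bus ID of the GPU backing a device pointer.
static bool queryPciBusId(void *addr, std::string &pci_bus_id) {
    CUmemorytype mem_type = CU_MEMORYTYPE_HOST;
    int dev_ordinal = -1;
    if (cuPointerGetAttribute(&mem_type, CU_POINTER_ATTRIBUTE_MEMORY_TYPE,
                              (CUdeviceptr)addr) != CUDA_SUCCESS)
        return false;
    if (mem_type != CU_MEMORYTYPE_DEVICE)
        return false; // host memory: no PCI bus ID to report
    if (cuPointerGetAttribute(&dev_ordinal, CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
                              (CUdeviceptr)addr) != CUDA_SUCCESS)
        return false;
    char buf[32]; // "0000:5e:00.0"-style IDs fit comfortably
    if (cuDeviceGetPCIBusId(buf, (int)sizeof(buf), dev_ordinal) != CUDA_SUCCESS)
        return false;
    pci_bus_id = buf;
    return true;
}

int main() {
    CUdevice dev;
    CUcontext ctx;
    CUdeviceptr dptr;
    if (cuInit(0) || cuDeviceGet(&dev, 0) || cuCtxCreate(&ctx, 0, dev) ||
        cuMemAlloc(&dptr, 1 << 20))
        return 1;
    std::string pci;
    if (queryPciBusId((void *)dptr, pci))
        std::printf("GPU 0 PCI bus ID: %s\n", pci.c_str());
    cuMemFree(dptr);
    cuCtxDestroy(ctx);
    return 0;
}
```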
src/utils/libfabric/libfabric_rail_manager.cpp (29 additions, 13 deletions)
@@ -1,6 +1,6 @@
 /*
- * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
- * SPDX-FileCopyrightText: Copyright (c) 2025 Amazon.com, Inc. and affiliates.
+ * SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-FileCopyrightText: Copyright (c) 2025-2026 Amazon.com, Inc. and affiliates.
  * SPDX-License-Identifier: Apache-2.0
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
@@ -21,6 +21,7 @@
 #include "libfabric/libfabric_topology.h"
 #include "common/nixl_log.h"
 #include "serdes/serdes.h"
+#include <sstream>
 
 // Forward declaration for LibfabricUtils namespace
 namespace LibfabricUtils {
@@ -46,6 +47,7 @@ nixlLibfabricRailManager::nixlLibfabricRailManager(size_t striping_threshold)
 
     // Get network devices from topology and create rails automatically
     std::vector<std::string> all_devices = topology->getAllDevices();
+
     std::string selected_provider_name = topology->getProviderName();
 
     NIXL_DEBUG << "Got " << all_devices.size()
@@ -321,16 +323,25 @@ nixlLibfabricRailManager::prepareAndSubmitTransfer(
 std::vector<size_t>
 nixlLibfabricRailManager::selectRailsForMemory(void *mem_addr,
                                                nixl_mem_t mem_type,
-                                               int gpu_id) const {
+                                               int gpu_id,
+                                               const std::string &gpu_pci_bus_id) const {
     if (mem_type == VRAM_SEG) {
 #ifdef HAVE_CUDA
         if (gpu_id < 0) {
             NIXL_ERROR << "Invalid GPU ID " << gpu_id << " for VRAM memory " << mem_addr;
             return {}; // Return empty vector to indicate failure
         }
-        std::vector<std::string> gpu_efa_devices = topology->getEfaDevicesForGpu(gpu_id);
+
+        // Use PCI bus ID provided by caller (queried in backend layer)
+        if (gpu_pci_bus_id.empty()) {
+            NIXL_ERROR << "Empty PCI bus ID provided for VRAM memory " << mem_addr;
+            return {}; // Return empty vector to indicate failure
+        }
+
+        // Get EFA devices for this PCI bus ID
+        std::vector<std::string> gpu_efa_devices = topology->getEfaDevicesForGPUPci(gpu_pci_bus_id);
         if (gpu_efa_devices.empty()) {
-            NIXL_ERROR << "No EFA devices found for GPU " << gpu_id;
+            NIXL_ERROR << "No EFA devices found for PCI " << gpu_pci_bus_id;
             return {}; // Return empty vector to indicate failure
         }
         std::vector<size_t> gpu_rails;
@@ -340,26 +351,26 @@
                 // Bounds check: ensure rail index is valid
                 if (it->second < data_rails_.size()) {
                     gpu_rails.push_back(it->second);
-                    NIXL_DEBUG << "VRAM memory " << mem_addr << " on GPU " << gpu_id
+                    NIXL_DEBUG << "VRAM memory " << mem_addr << " on GPU-PCI " << gpu_pci_bus_id
                                << " mapped to rail " << it->second << " (EFA device=" << efa_device
                                << ")";
                 } else {
                     NIXL_WARN << "EFA device " << efa_device << " maps to rail " << it->second
                               << " but only " << data_rails_.size() << " rails available";
                 }
             } else {
-                NIXL_WARN << "EFA device " << efa_device << " not found in rail mapping for GPU "
-                          << gpu_id;
+                NIXL_WARN << "EFA device " << efa_device
+                          << " not found in rail mapping for GPU-PCI " << gpu_pci_bus_id;
             }
         }
 
         if (gpu_rails.empty()) {
-            NIXL_ERROR << "No valid rail mapping found for GPU " << gpu_id << " (checked "
-                       << gpu_efa_devices.size() << " EFA devices)";
+            NIXL_ERROR << "No valid rail mapping found for GPU-PCI " << gpu_pci_bus_id
+                       << " (checked " << gpu_efa_devices.size() << " EFA devices)";
             return {};
         }
 
-        NIXL_DEBUG << "VRAM memory " << mem_addr << " on GPU " << gpu_id << " will use "
+        NIXL_DEBUG << "VRAM memory " << mem_addr << " on GPU-PCI " << gpu_pci_bus_id << " will use "
                    << gpu_rails.size() << " rails total";
         return gpu_rails;
 #else
@@ -390,6 +401,7 @@ nixlLibfabricRailManager::registerMemory(void *buffer,
                                          size_t length,
                                          nixl_mem_t mem_type,
                                          int gpu_id,
+                                         const std::string &gpu_pci_bus_id,
                                          std::vector<struct fid_mr *> &mr_list_out,
                                          std::vector<uint64_t> &key_list_out,
                                          std::vector<size_t> &selected_rails_out) {
@@ -398,8 +410,11 @@ nixlLibfabricRailManager::registerMemory(void *buffer,
         return NIXL_ERR_INVALID_PARAM;
     }
 
-    // Use internal rail selection with explicit GPU ID
-    std::vector<size_t> selected_rails = selectRailsForMemory(buffer, mem_type, gpu_id);
+    // Select rails based on memory type and PCI bus ID
+    // For VRAM: uses PCI bus ID provided by backend to map to topology-aware rails
+    // For DRAM: uses all available rails
+    std::vector<size_t> selected_rails =
+        selectRailsForMemory(buffer, mem_type, gpu_id, gpu_pci_bus_id);
     if (selected_rails.empty()) {
         NIXL_ERROR << "No rails selected for memory type " << mem_type;
         return NIXL_ERR_NOT_SUPPORTED;
@@ -429,6 +444,7 @@ nixlLibfabricRailManager::registerMemory(void *buffer,
 
         struct fid_mr *mr;
         uint64_t key;
+        // Pass gpu_id parameter to individual rail's registerMemory calls
         nixl_status_t status =
             data_rails_[rail_idx]->registerMemory(buffer, length, mem_type, gpu_id, &mr, &key);
         if (status != NIXL_SUCCESS) {
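The rail selection above hinges on `topology->getEfaDevicesForGPUPci()`, which this diff does not show. Below is a hedged sketch of the lookup's shape, with a hard-coded map standing in for the real `nixlLibfabricTopology` data; the device names, PCI IDs, and helper name are invented for illustration.

```cpp
// Illustrative stand-in for the topology lookup; the real mapping is built by
// nixlLibfabricTopology from the host's PCIe layout, not hard-coded.
#include <map>
#include <string>
#include <vector>

// Hypothetical wiring: GPU PCI bus ID -> EFA devices behind the same PCIe switch.
static const std::map<std::string, std::vector<std::string>> kGpuPciToEfa = {
    {"0000:53:00.0", {"rdmap79s0", "rdmap80s0"}},
    {"0000:64:00.0", {"rdmap81s0", "rdmap82s0"}},
};

std::vector<std::string> lookupEfaDevicesForGpuPci(const std::string &pci_bus_id) {
    auto it = kGpuPciToEfa.find(pci_bus_id);
    // An empty result is the "no rails for this GPU" signal that
    // selectRailsForMemory() turns into an error.
    return it != kGpuPciToEfa.end() ? it->second : std::vector<std::string>{};
}
```

Keying the lookup on the PCI bus ID rather than the CUDA ordinal ties rail selection to the physical slot, which stays stable even when CUDA_VISIBLE_DEVICES renumbers the ordinals.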
src/utils/libfabric/libfabric_rail_manager.h (8 additions, 3 deletions)
@@ -1,6 +1,6 @@
 /*
- * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
- * SPDX-FileCopyrightText: Copyright (c) 2025 Amazon.com, Inc. and affiliates.
+ * SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-FileCopyrightText: Copyright (c) 2025-2026 Amazon.com, Inc. and affiliates.
  * SPDX-License-Identifier: Apache-2.0
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
@@ -110,6 +110,7 @@ class nixlLibfabricRailManager {
      * @param length Buffer size in bytes
      * @param mem_type Memory type (DRAM_SEG or VRAM_SEG)
      * @param gpu_id GPU device ID (used for VRAM_SEG, ignored for DRAM_SEG)
+     * @param gpu_pci_bus_id PCI bus ID for VRAM-GPU (queried in backend layer), empty for DRAM
      * @param mr_list_out Memory registration handles, indexed by rail ID
      * @param key_list_out Remote access keys, indexed by rail ID
      * @param selected_rails_out List of rail IDs where memory was registered
@@ -120,6 +121,7 @@
                    size_t length,
                    nixl_mem_t mem_type,
                    int gpu_id,
+                   const std::string &gpu_pci_bus_id,
                    std::vector<struct fid_mr *> &mr_list_out,
                    std::vector<uint64_t> &key_list_out,
                    std::vector<size_t> &selected_rails_out);
@@ -316,7 +318,10 @@ class nixlLibfabricRailManager {
 
     // Internal rail selection method
     std::vector<size_t>
-    selectRailsForMemory(void *mem_addr, nixl_mem_t mem_type, int gpu_id) const;
+    selectRailsForMemory(void *mem_addr,
+                         nixl_mem_t mem_type,
+                         int gpu_id,
+                         const std::string &pci_bus_id = "") const;
 
     // Helper functions for connection SerDes
     void
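Finally, a hedged caller-side fragment showing the widened `registerMemory()` signature in use. The buffers, lengths, and IDs are placeholders, and `rail_manager` is assumed to be an already-constructed `nixlLibfabricRailManager`.

```cpp
// Usage fragment (assumes NIXL's libfabric headers); values are placeholders.
void *dev_buf = nullptr;  // e.g. from cuMemAlloc
void *host_buf = nullptr; // e.g. from posix_memalign
size_t len = 1 << 20;
int gpu_id = 0;
std::string pci_bus_id = "0000:53:00.0"; // queried in the backend layer

std::vector<struct fid_mr *> mr_list;
std::vector<uint64_t> key_list;
std::vector<size_t> selected_rails;

// VRAM: forward the PCI bus ID alongside the device ID.
nixl_status_t status = rail_manager.registerMemory(
    dev_buf, len, VRAM_SEG, gpu_id, pci_bus_id, mr_list, key_list, selected_rails);

// DRAM: the PCI bus ID stays empty and selection falls through to all rails.
status = rail_manager.registerMemory(
    host_buf, len, DRAM_SEG, gpu_id, "", mr_list, key_list, selected_rails);
```

Note that only the private `selectRailsForMemory()` declaration defaults `pci_bus_id` to `""`; the public `registerMemory()` requires the argument explicitly.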