diff --git a/benchmark/kvbench/README.md b/benchmark/kvbench/README.md
index d8214de562..2f9ce21e91 100644
--- a/benchmark/kvbench/README.md
+++ b/benchmark/kvbench/README.md
@@ -118,7 +118,7 @@ These arguments are used by both `plan` and `profile` commands:
 | -------- | ----------- |
 | `--source` | Source of the nixl descriptors [file, memory, gpu] (default: file) |
 | `--destination` | Destination of the nixl descriptors [file, memory, gpu] (default: memory) |
-| `--backend` | Communication backend [UCX, UCX_MO, GDS, GDS_MT, POSIX, GPUNETIO, Mooncake, HF3FS, OBJ] (default: UCX) |
+| `--backend` | Communication backend [UCX, UCX_MO, GDS, GDS_MT, POSIX, GPUNETIO, Mooncake, HF3FS, OBJ, LIBFABRIC] (default: UCX) |
 | `--worker_type` | Worker to use to transfer data [nixl, nvshmem] (default: nixl) |
 | `--initiator_seg_type` | Memory segment type for initiator [DRAM, VRAM] (default: DRAM) |
 | `--target_seg_type` | Memory segment type for target [DRAM, VRAM] (default: DRAM) |
diff --git a/benchmark/kvbench/commands/args.py b/benchmark/kvbench/commands/args.py
index 38480bc6a1..9415a5c559 100644
--- a/benchmark/kvbench/commands/args.py
+++ b/benchmark/kvbench/commands/args.py
@@ -72,7 +72,7 @@ def nixl_bench_args(func):
     func = click.option(
         "--backend",
         type=str,
-        help="Communication backend [UCX, UCX_MO, GDS, GDS_MT, POSIX, GPUNETIO, Mooncake, HF3FS, OBJ] (default: UCX)",
+        help="Communication backend [UCX, UCX_MO, GDS, GDS_MT, POSIX, GPUNETIO, Mooncake, HF3FS, OBJ, LIBFABRIC] (default: UCX)",
     )(func)
     func = click.option(
         "--worker_type",
diff --git a/benchmark/kvbench/test/custom_traffic_perftest.py b/benchmark/kvbench/test/custom_traffic_perftest.py
index 51b4be2958..ad2ecb64d3 100644
--- a/benchmark/kvbench/test/custom_traffic_perftest.py
+++ b/benchmark/kvbench/test/custom_traffic_perftest.py
@@ -55,6 +55,8 @@ def __init__(
         self.nixl_agent = nixl_agent
         if mem_type in ("cuda", "vram"):
             device = torch.device("cuda")
+        elif mem_type == "hpu":
+            device = torch.device("hpu")
         elif mem_type in ("cpu", "dram"):
             device = torch.device("cpu")
         else:
@@ -95,6 +97,8 @@ def destroy(self):
         if hasattr(self.buf, "is_cuda") and self.buf.is_cuda:
             del self.buf
             torch.cuda.empty_cache()
+        elif hasattr(self.buf, "is_hpu") and self.buf.is_hpu:
+            del self.buf


 class CTPerftest:
@@ -122,6 +126,15 @@ def __init__(
             logger.warning(
                 "Cuda buffers detected, but the env var CUDA_VISIBLE_DEVICES is not set, this will cause every process in the same host to use the same GPU device."
             )
+
+        if (
+            not os.environ.get("HABANA_VISIBLE_MODULES")
+            and self.traffic_pattern.mem_type == "hpu"
+        ):
+            logger.warning(
+                "HPU buffers detected, but the env var HABANA_VISIBLE_MODULES is not set, this will cause every process in the same host to use the same HPU device."
+            )
+
         """Initialize the buffers, one big send and recv buffer is used for all the transfers
         it has to be chunked inside each transfer to get buffers per ranks
@@ -250,7 +263,7 @@ def _warmup(
         self,
         iters=15,
         fill_value: int = 100000,
-        mem_type: Literal["cuda", "vram", "cpu", "dram"] = "cuda",
+        mem_type: Literal["cuda", "vram", "cpu", "dram", "hpu"] = "cuda",
     ):
         full_matrix = np.full((self.world_size, self.world_size), fill_value=fill_value)
         tp = TrafficPattern(matrix=full_matrix, mem_type=mem_type)
diff --git a/benchmark/kvbench/test/sequential_custom_traffic_perftest.py b/benchmark/kvbench/test/sequential_custom_traffic_perftest.py
index c68a5e6c98..89888e456a 100644
--- a/benchmark/kvbench/test/sequential_custom_traffic_perftest.py
+++ b/benchmark/kvbench/test/sequential_custom_traffic_perftest.py
@@ -71,6 +71,15 @@ def __init__(
             logger.warning(
                 "Cuda buffers detected, but the env var CUDA_VISIBLE_DEVICES is not set, this will cause every process in the same host to use the same GPU device."
             )
+
+        if (
+            not os.environ.get("HABANA_VISIBLE_MODULES")
+            and self.traffic_pattern.mem_type == "hpu"
+        ):
+            logger.warning(
+                "HPU buffers detected, but the env var HABANA_VISIBLE_MODULES is not set, this will cause every process in the same host to use the same HPU device."
+            )
+
         assert "UCX" in self.nixl_agent.get_plugin_list(), "UCX plugin is not loaded"

         # NixlBuffer caches buffers and reuse them if they are big enough, let's initialize them once, with the largest needed size
diff --git a/benchmark/kvbench/test/traffic_pattern.py b/benchmark/kvbench/test/traffic_pattern.py
index 1f612279af..93e6a1fd4e 100644
--- a/benchmark/kvbench/test/traffic_pattern.py
+++ b/benchmark/kvbench/test/traffic_pattern.py
@@ -35,7 +35,7 @@ class TrafficPattern:
     """

     matrix: np.ndarray
-    mem_type: Literal["cuda", "vram", "cpu", "dram"]
+    mem_type: Literal["cuda", "vram", "cpu", "dram", "hpu"]
     xfer_op: Literal["WRITE", "READ"] = "WRITE"
     shards: int = 1
     dtype: torch.dtype = torch.int8
diff --git a/benchmark/nixlbench/meson.build b/benchmark/nixlbench/meson.build
index 1910e7ad31..1d593798cb 100644
--- a/benchmark/nixlbench/meson.build
+++ b/benchmark/nixlbench/meson.build
@@ -103,16 +103,87 @@ if cuda_available
   endif
 endif

+# SynapseAI (Habana Gaudi) dependency detection
+synapse_inc_path = get_option('synapsepath_inc')
+synapse_lib_path = get_option('synapsepath_lib')
+
+if synapse_lib_path == ''
+  # Use default paths
+  # Try to find both libSynapse and hl-thunk libraries
+  synapse_lib = cpp.find_library('Synapse',
+    dirs: ['/usr/lib/habanalabs', '/usr/local/lib/habanalabs'],
+    required: false)
+  hlthunk_lib = cpp.find_library('hl-thunk',
+    dirs: ['/usr/lib/habanalabs', '/usr/local/lib/habanalabs'],
+    required: false)
+else
+  synapse_lib = cpp.find_library('Synapse',
+    dirs: [synapse_lib_path],
+    required: false)
+  hlthunk_lib = cpp.find_library('hl-thunk',
+    dirs: [synapse_lib_path],
+    required: false)
+endif
+
+if synapse_inc_path == ''
+  # Use default path
+  synapse_inc_path = '/usr/include/habanalabs/'
+endif
+
+# SynapseAI support requires both libraries
+synapseai_dep = dependency('', required: false) # Initialize as not found
+if synapse_lib.found() and hlthunk_lib.found()
+  synapseai_dep = declare_dependency(dependencies: [synapse_lib, hlthunk_lib])
+elif hlthunk_lib.found()
+  # Fall back to just hl-thunk if libSynapse is not available
+  synapseai_dep = hlthunk_lib
+endif
+
+if synapseai_dep.found()
+  # Create proper dependency with include paths (including the DRM path for habanalabs headers)
+  synapseai_dep = declare_dependency(
+    dependencies: synapseai_dep,
+    include_directories: [
+      include_directories('/usr/include/drm'),
+      include_directories(synapse_inc_path)
+    ]
+  )
+  message('Found SynapseAI support for Habana Gaudi devices')
+  synapseai_available = true
+else
+  synapseai_available = false
+  synapseaiutils_deps = [] # keep later references valid when src/synapseai is not entered
+  warning('SynapseAI not found. Habana Gaudi device support will be disabled.')
+endif
+
 # GFlags
 gflags_dep = dependency('gflags', required: true)

 # OpenMP
 openmp_dep = dependency('openmp', required: true)
-
-# Check for etcd-cpp-api - use multiple methods for discovery
+# Try pkg-config first
 etcd_dep = dependency('etcd-cpp-api', required : false)
+if not etcd_dep.found()
+  # Fallback: manual configuration
+  # message('etcd-cpp-api not found via pkg-config, using manual configuration')
-# Ensure etcd is available
+  # Check if we have the library files
+  etcd_lib = meson.get_compiler('cpp').find_library('etcd-cpp-api',
+    dirs: ['/usr/local/lib'],
+    required: false)
+
+  if etcd_lib.found()
+    etcd_dep = declare_dependency(
+      include_directories: include_directories('/usr/local/include'),
+      dependencies: [etcd_lib],
+      # Add any required dependencies for etcd-cpp-api
+      link_args: [] # Add any additional link args if needed
+    )
+    message('etcd-cpp-api found manually in /usr/local/lib')
+  else
+    etcd_dep = disabler()
+    message('etcd-cpp-api not found anywhere')
+  endif
+endif
 etcd_available = etcd_dep.found()
 if etcd_available
   add_project_arguments('-DHAVE_ETCD', language: 'cpp')
@@ -148,7 +219,14 @@ if cuda_fabric_available
   add_project_arguments('-DHAVE_CUDA_FABRIC', language: 'cpp')
 endif

+if synapseai_available
+  add_project_arguments('-DHAVE_SYNAPSEAI', language: 'cpp')
+endif
+
 # Subprojects
+if synapseai_available
+  subdir('src/synapseai')
+endif
 subdir('src/utils')
 subdir('src/runtime')
 subdir('src/worker')
@@ -161,6 +239,7 @@ configure_file(
     'HAVE_NVSHMEM': nvshmem_available ? '1' : '0',
     'HAVE_CUDA': cuda_available ? '1' : '0',
     'HAVE_CUDA_FABRIC': cuda_fabric_available ? '1' : '0',
+    'HAVE_SYNAPSEAI': synapseai_available ? '1' : '0',
   },
   install: true,
   install_dir: get_option('includedir') / 'nixlbench'
@@ -174,6 +253,11 @@ endif
 if cuda_available
   deps += [cuda_dep]
 endif
+
+if synapseai_available
+  deps += [synapseai_dep]
+  message('add synapseai_dep')
+endif
 if nvshmem_available
   deps += [nvshmem_lib]
   args += [
@@ -185,9 +269,9 @@ if nvshmem_available
   ]
 endif

-if not etcd_available
-  error('No runtime available or not found')
-endif
+#if not etcd_available
+#  error('No runtime available or not found')
+#endif

 if nvshmem_available
   # Use nvcc directly for compilation and linking
@@ -240,11 +324,21 @@ if nvshmem_available
                install_dir: get_option('bindir'),
                depends: [nixlbench_runtimes, utils_lib, worker_libs])
 else
-  executable('nixlbench', 'src/main.cpp',
-             include_directories: inc_dir,
-             link_with: [nixlbench_runtimes, utils_lib, worker_libs],
-             dependencies: deps,
-             link_args: args,
-             install: true,
-             install_dir: get_option('bindir'))
+  if synapseai_available
+    executable('nixlbench', 'src/main.cpp',
+               include_directories: inc_dir,
+               link_with: [nixlbench_runtimes, utils_lib, worker_libs, synapseaiutils_lib],
+               dependencies: deps,
+               link_args: args,
+               install: true,
+               install_dir: get_option('bindir'))
+  else
+    executable('nixlbench', 'src/main.cpp',
+               include_directories: inc_dir,
+               link_with: [nixlbench_runtimes, utils_lib, worker_libs],
+               dependencies: deps,
+               link_args: args,
+               install: true,
+               install_dir: get_option('bindir'))
+  endif
 endif
diff --git a/benchmark/nixlbench/meson_options.txt b/benchmark/nixlbench/meson_options.txt
index 136cfda3c6..a6786f9a26 100644
--- a/benchmark/nixlbench/meson_options.txt
+++ b/benchmark/nixlbench/meson_options.txt
@@ -21,3 +21,5 @@ option('etcd_lib_path', type: 'string', value: '', description: 'Path to ETCD C+
 option('nixl_path', type: 'string', value: '/usr/local', description: 'Path to NiXL')
 option('nvshmem_inc_path', type: 'string', value: '', description: 'Path to NVSHMEM include directory')
 option('nvshmem_lib_path', type: 'string', value: '', description: 'Path to NVSHMEM library directory')
+option('synapsepath_inc', type: 'string', value: '', description: 'Include path for Intel Gaudi / HPU')
+option('synapsepath_lib', type: 'string', value: '', description: 'Library path for Intel Gaudi / HPU')
diff --git a/benchmark/nixlbench/src/synapseai/meson.build b/benchmark/nixlbench/src/synapseai/meson.build
new file mode 100644
index 0000000000..4e661f0c7d
--- /dev/null
+++ b/benchmark/nixlbench/src/synapseai/meson.build
@@ -0,0 +1,34 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+synapseaiutils_sources = [
+    'synapse_utils.cpp',
+    'synapse_utils.h',
+]
+
+synapseaiutils_deps = [
+    synapseai_dep
+]
+
+synapseaiutils_lib = static_library('synapseaiutils',
+    synapseaiutils_sources,
+    dependencies: synapseaiutils_deps,
+    include_directories: inc_dir
+)
+synapseaiutils_dep = declare_dependency(
+    link_with: synapseaiutils_lib,
+    dependencies: synapseaiutils_deps,
+    include_directories: inc_dir
+)
diff --git a/benchmark/nixlbench/src/synapseai/synapse_utils.cpp b/benchmark/nixlbench/src/synapseai/synapse_utils.cpp
new file mode 100644
index 0000000000..1333cf0028
--- /dev/null
+++ b/benchmark/nixlbench/src/synapseai/synapse_utils.cpp
@@ -0,0 +1,111 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <cstdio>
+#include <cstdlib>
+#include <iostream>
+#include <mutex>
+#include <string>
+#include "synapse_utils.h"
+
+static bool device_initialized = false;
+static std::mutex mtx;
+static synDeviceId deviceHandle;
+static synStreamHandle stream;
+
+namespace Synapseaiutils {
+static void
+check(int ret, const char *msg) {
+    if (ret) {
+        fprintf(stderr, "%s: %s(%d)\n", msg, "failed", -ret);
+        exit(1);
+    }
+}
+
+int
+init_synapse_device() {
+    std::lock_guard<std::mutex> lock(mtx);
+    auto env = std::getenv("HLS_MODULE_ID");
+    int module_id = 0;
+    if (env != nullptr) {
+        module_id = std::stoi(env);
+    }
+    if (device_initialized) return 0;
+    check(synInitialize(), "synInitialize");
+    check(synDeviceAcquireByModuleId(&deviceHandle, module_id), "synDeviceAcquire");
+    device_initialized = true;
+    check(synStreamCreateGeneric(&stream, deviceHandle, 0), "synStreamCreateGeneric");
+    return 0;
+}
+
+synDeviceId
+get_device_handle() {
+    return deviceHandle;
+}
+
+uint64_t
+allocate_synapse_memory(size_t len, void *host_buffer) {
+    uint64_t device_buffer;
+    std::lock_guard<std::mutex> lock(mtx);
+    if (!device_initialized) {
+        fprintf(stderr, "%s\n", "device not initialized");
+        exit(1);
+    }
+
+    check(synDeviceMalloc(deviceHandle, len, 0x0, 0, &device_buffer), "synDeviceMalloc");
+    check(synHostMap(deviceHandle, len, host_buffer), "synHostMap");
+    check(synMemCopyAsync(stream, (uint64_t)host_buffer, len, device_buffer, HOST_TO_DRAM),
+          "synMemCopyAsync");
+    check(synStreamSynchronize(stream), "synStreamSynchronize");
+    std::cout << "allocate_synapse_memory" << " device buffer::" << device_buffer
+              << " host buffer::" << host_buffer << " Len::" << len << std::endl;
+    check(synHostUnmap(deviceHandle, host_buffer), "synHostUnmap");
+    return device_buffer;
+}
+
+void
+free_synapse_memory(uint64_t ptr) {
+    std::lock_guard<std::mutex> lock(mtx);
+    if (!device_initialized) fprintf(stderr, "%s\n", "device not initialized");
+    // cleanup Synapse resources
+    check(synDeviceFree(deviceHandle, ptr, 0), "synDeviceFree");
+}
+
+void
+deinit_synapse_device() {
+    std::lock_guard<std::mutex> lock(mtx);
+    if (!device_initialized) {
+        fprintf(stderr, "%s\n", "device not initialized");
+        exit(1);
+    }
+    check(synStreamDestroy(stream), "synStreamDestroy");
+    check(synDeviceRelease(deviceHandle), "synDeviceRelease");
+    check(synDestroy(), "synDestroy");
+    device_initialized = false;
+}
+
+void
+copy_from_device_buffer(uint64_t device_buffer, void *host_buffer, size_t len) {
+    std::lock_guard<std::mutex> lock(mtx);
+    if (!device_initialized) {
+        fprintf(stderr, "%s\n", "device not initialized");
+        exit(1);
+    }
+    check(synHostMap(deviceHandle, len, host_buffer), "synHostMap");
+    check(synMemCopyAsync(stream, device_buffer, len, (uint64_t)host_buffer, DRAM_TO_HOST),
+          "synMemCopyAsync");
+    check(synStreamSynchronize(stream), "synStreamSynchronize");
+    check(synHostUnmap(deviceHandle, host_buffer), "synHostUnmap");
+}
+} // namespace Synapseaiutils
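Reviewer note: for context, a minimal sketch of how these helpers are meant to compose (not part of the patch; assumes a Gaudi device is visible to the process, with `HLS_MODULE_ID` optionally selecting the module):

```cpp
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include "synapse_utils.h"

int main() {
    const size_t len = 4096;

    // Acquire the device once per process; repeated calls are no-ops.
    Synapseaiutils::init_synapse_device();

    // Seed a host buffer and stage it into device memory.
    void *host = calloc(1, len);
    memset(host, 0xAB, len);
    uint64_t dev = Synapseaiutils::allocate_synapse_memory(len, host);

    // Read the device buffer back and verify the round trip.
    void *readback = calloc(1, len);
    Synapseaiutils::copy_from_device_buffer(dev, readback, len);
    printf("round trip %s\n", memcmp(host, readback, len) == 0 ? "ok" : "mismatch");

    Synapseaiutils::free_synapse_memory(dev);
    Synapseaiutils::deinit_synapse_device();
    free(host);
    free(readback);
    return 0;
}
```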
diff --git a/benchmark/nixlbench/src/synapseai/synapse_utils.h b/benchmark/nixlbench/src/synapseai/synapse_utils.h
new file mode 100644
index 0000000000..363cb98c00
--- /dev/null
+++ b/benchmark/nixlbench/src/synapseai/synapse_utils.h
@@ -0,0 +1,34 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <cstddef>
+#include <cstdint>
+#include <synapse_api.h>
+
+namespace Synapseaiutils {
+int
+init_synapse_device();
+synDeviceId
+get_device_handle();
+uint64_t
+allocate_synapse_memory(size_t len, void *host_buffer);
+void
+free_synapse_memory(uint64_t ptr);
+void
+deinit_synapse_device();
+void
+copy_from_device_buffer(uint64_t device_buffer, void *host_buffer, size_t len);
+} // namespace Synapseaiutils
diff --git a/benchmark/nixlbench/src/utils/meson.build b/benchmark/nixlbench/src/utils/meson.build
index 5ec29a0943..616a318cef 100644
--- a/benchmark/nixlbench/src/utils/meson.build
+++ b/benchmark/nixlbench/src/utils/meson.build
@@ -22,7 +22,8 @@ utils_sources = [
 utils_deps = [
   cuda_dep,
   gflags_dep,
-  openmp_dep
+  openmp_dep,
+  synapseai_dep
 ]

 utils_lib = static_library('utils',
@@ -32,7 +33,7 @@ utils_lib = static_library('utils',
 )
 utils_dep = declare_dependency(
   link_with: utils_lib,
-  dependencies: utils_deps,
+  dependencies: [synapseaiutils_deps, utils_deps],
   include_directories: inc_dir
 )
diff --git a/benchmark/nixlbench/src/utils/utils.cpp b/benchmark/nixlbench/src/utils/utils.cpp
index 385464d8a1..16279b0645 100644
--- a/benchmark/nixlbench/src/utils/utils.cpp
+++ b/benchmark/nixlbench/src/utils/utils.cpp
@@ -29,6 +29,9 @@
 #if HAVE_CUDA
 #include <cuda_runtime.h>
 #endif
+#if HAVE_SYNAPSEAI
+#include "src/synapseai/synapse_utils.h"
+#endif

 #include
 #include
@@ -391,7 +394,7 @@ xferBenchConfig::printConfig() {
     }
     printOption("Worker type (--worker_type=[nixl,nvshmem])", worker_type);
     if (worker_type == XFERBENCH_WORKER_NIXL) {
-        printOption("Backend (--backend=[UCX,UCX_MO,GDS,GDS_MT,POSIX,Mooncake,HF3FS,OBJ])",
+        printOption("Backend (--backend=[UCX,UCX_MO,GDS,GDS_MT,POSIX,Mooncake,HF3FS,OBJ,LIBFABRIC])",
                     backend);
         printOption ("Enable pt (--enable_pt=[0,1])", std::to_string (enable_pt));
         printOption("Progress threads (--progress_threads=N)", std::to_string(progress_threads));
@@ -476,6 +479,7 @@
 std::vector<std::string> xferBenchConfig::parseDeviceList() {
     // TODO: Add support for other schemes
     if (xferBenchConfig::scheme == XFERBENCH_SCHEME_PAIRWISE && xferBenchConfig::device_list != "all") {
+        std::cout << "xferBenchConfig::device_list not all" << std::endl;
         while (std::getline(ss, dev, ',')) {
             devices.push_back(dev);
         }
@@ -488,6 +492,7 @@ std::vector<std::string> xferBenchConfig::parseDeviceList() {
             return {};
         }
     } else {
+        std::cout << "xferBenchConfig::device_list == all" << std::endl;
         devices.push_back("all");
     }

@@ -548,13 +553,19 @@ void xferBenchUtils::checkConsistency(std::vector<std::vector<xferBenchIOV>> &io
         xferBenchConfig::backend == XFERBENCH_BACKEND_GPUNETIO) {
         if (xferBenchConfig::op_type == XFERBENCH_OP_READ) {
             if (xferBenchConfig::initiator_seg_type == XFERBENCH_SEG_TYPE_VRAM) {
-#if HAVE_CUDA
+#if HAVE_CUDA || HAVE_SYNAPSEAI
                 addr = calloc(1, len);
                 is_allocated = true;
-                CHECK_CUDA_ERROR(cudaMemcpy(addr, (void *)iov.addr, len,
-                                 cudaMemcpyDeviceToHost), "cudaMemcpy failed");
+#if HAVE_CUDA
+                CHECK_CUDA_ERROR(
+                    cudaMemcpy(addr, (void *)iov.addr, len, cudaMemcpyDeviceToHost),
+                    "cudaMemcpy failed");
+#else
+                Synapseaiutils::copy_from_device_buffer(uint64_t(iov.addr), addr, len);
+#endif
 #else
-                std::cerr << "Failure in consistency check: VRAM segment type not supported without CUDA"
+                std::cerr << "Failure in consistency check: VRAM segment type not "
+                             "supported without CUDA"
                           << std::endl;
                 exit(EXIT_FAILURE);
 #endif
@@ -595,16 +606,22 @@ void xferBenchUtils::checkConsistency(std::vector<std::vector<xferBenchIOV>> &io
     // This will be called on target process in case of write and
     // on initiator process in case of read
     if ((xferBenchConfig::op_type == XFERBENCH_OP_WRITE &&
-         xferBenchConfig::target_seg_type == XFERBENCH_SEG_TYPE_VRAM) ||
-        (xferBenchConfig::op_type == XFERBENCH_OP_READ &&
-         xferBenchConfig::initiator_seg_type == XFERBENCH_SEG_TYPE_VRAM)) {
-#if HAVE_CUDA
+         xferBenchConfig::target_seg_type == XFERBENCH_SEG_TYPE_VRAM) ||
+        (xferBenchConfig::op_type == XFERBENCH_OP_READ &&
+         xferBenchConfig::initiator_seg_type == XFERBENCH_SEG_TYPE_VRAM)) {
+#if HAVE_CUDA || HAVE_SYNAPSEAI
         addr = calloc(1, len);
         is_allocated = true;
-        CHECK_CUDA_ERROR(cudaMemcpy(addr, (void *)iov.addr, len,
-                         cudaMemcpyDeviceToHost), "cudaMemcpy failed");
+#if HAVE_CUDA
+        CHECK_CUDA_ERROR(
+            cudaMemcpy(addr, (void *)iov.addr, len, cudaMemcpyDeviceToHost),
+            "cudaMemcpy failed");
+#else
+        Synapseaiutils::copy_from_device_buffer(uint64_t(iov.addr), addr, len);
+#endif
 #else
-        std::cerr << "Failure in consistency check: VRAM segment type not supported without CUDA"
+        std::cerr << "Failure in consistency check: VRAM segment type not supported "
+                     "without CUDA"
                   << std::endl;
         exit(EXIT_FAILURE);
 #endif
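Reviewer note: the two `#if HAVE_CUDA || HAVE_SYNAPSEAI` blocks above follow the same shape, so a condensed sketch of the dispatch they implement may help (hypothetical helper, not part of the patch; it relies on the surrounding utils.cpp context for `CHECK_CUDA_ERROR` and the CUDA/Synapse headers):

```cpp
// Copy a device buffer into a freshly calloc'd host buffer for validation.
// CUDA and SynapseAI builds differ only in the copy primitive used.
static void *copy_back_for_check(uintptr_t dev_addr, size_t len) {
    void *host = calloc(1, len);
#if HAVE_CUDA
    CHECK_CUDA_ERROR(cudaMemcpy(host, (void *)dev_addr, len, cudaMemcpyDeviceToHost),
                     "cudaMemcpy failed");
#elif HAVE_SYNAPSEAI
    Synapseaiutils::copy_from_device_buffer(uint64_t(dev_addr), host, len);
#endif
    return host; // caller compares contents and free()s
}
```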
diff --git a/benchmark/nixlbench/src/worker/meson.build b/benchmark/nixlbench/src/worker/meson.build
index d7accfde2d..e45b133edb 100644
--- a/benchmark/nixlbench/src/worker/meson.build
+++ b/benchmark/nixlbench/src/worker/meson.build
@@ -22,7 +22,8 @@ worker_deps = [
   cuda_dep,
   gflags_dep,
   openmp_dep,
-  etcd_dep
+  etcd_dep,
+  synapseai_dep
 ]

 worker_lib = static_library('worker',
diff --git a/benchmark/nixlbench/src/worker/nixl/meson.build b/benchmark/nixlbench/src/worker/nixl/meson.build
index c31321a727..f46e8c2763 100644
--- a/benchmark/nixlbench/src/worker/nixl/meson.build
+++ b/benchmark/nixlbench/src/worker/nixl/meson.build
@@ -23,9 +23,12 @@ if cuda_available
   nixl_worker_deps += [cuda_dep]
 endif

+if synapseai_available
+  nixl_worker_deps += [synapseai_dep]
+endif
 nixl_worker_lib = static_library('nixl_worker',
   nixl_worker_sources,
   include_directories: inc_dir,
-  dependencies: nixl_worker_deps,
+  dependencies: [synapseaiutils_deps, nixl_worker_deps],
   install: true,
-)
\ No newline at end of file
+)
diff --git a/benchmark/nixlbench/src/worker/nixl/nixl_worker.cpp b/benchmark/nixlbench/src/worker/nixl/nixl_worker.cpp
index eec2be60f4..fecd6cb4a7 100644
--- a/benchmark/nixlbench/src/worker/nixl/nixl_worker.cpp
+++ b/benchmark/nixlbench/src/worker/nixl/nixl_worker.cpp
@@ -23,6 +23,9 @@
 #include <cuda.h>
 #include <cuda_runtime.h>
 #endif
+#if HAVE_SYNAPSEAI
+#include "synapseai/synapse_utils.h"
+#endif
 #include
 #include
 #include
@@ -46,7 +49,7 @@
     } \
     } while (0)

-#if HAVE_CUDA
+#if HAVE_CUDA || HAVE_SYNAPSEAI
 #define HANDLE_VRAM_SEGMENT(_seg_type) _seg_type = VRAM_SEG;
 #else
 #define HANDLE_VRAM_SEGMENT(_seg_type) \
@@ -207,6 +210,10 @@ xferBenchNixlWorker::xferBenchNixlWorker(int *argc, char ***argv, std::vector<std::string> devices)
     createBackend(backend_name, backend_params, backend_engine);
+#if HAVE_SYNAPSEAI
+    std::cout << "Initializing synapse device" << std::endl;
+    Synapseaiutils::init_synapse_device();
+#endif
 }

 xferBenchNixlWorker::~xferBenchNixlWorker() {
@@ -421,6 +428,43 @@ xferBenchNixlWorker::initBasicDescVram(size_t buffer_size, int mem_dev_id) {
 }
 #endif /* HAVE_CUDA */

+#if HAVE_SYNAPSEAI
+static std::optional<xferBenchIOV>
+getVramDescSynapseai(int devid, size_t buffer_size, uint8_t memset_value) {
+    void *host_addr = calloc(1, buffer_size);
+    memset(host_addr, memset_value, buffer_size);
+    auto device_buffer = Synapseaiutils::allocate_synapse_memory(buffer_size, host_addr);
+
+    free(host_addr);
+    return std::optional<xferBenchIOV>(std::in_place, (uintptr_t)device_buffer, buffer_size, devid);
+}
+
+static std::optional<xferBenchIOV>
+getVramDesc(int devid, size_t buffer_size, bool isInit) {
+    uint8_t memset_value =
+        isInit ? XFERBENCH_INITIATOR_BUFFER_ELEMENT : XFERBENCH_TARGET_BUFFER_ELEMENT;
+
+    return getVramDescSynapseai(devid, buffer_size, memset_value);
+}
+
+std::optional<xferBenchIOV>
+xferBenchNixlWorker::initBasicDescVram(size_t buffer_size, int mem_dev_id) {
+    if (IS_PAIRWISE_AND_SG()) {
+        int devid = rt->getRank();
+
+        if (isTarget()) {
+            devid -= xferBenchConfig::num_initiator_dev;
+        }
+
+        if (devid != mem_dev_id) {
+            return std::nullopt;
+        }
+    }
+
+    return getVramDesc(mem_dev_id, buffer_size, isInitiator());
+}
+#endif /* HAVE_SYNAPSEAI */
+
 static std::vector<int>
 createFileFds(std::string name, int num_files) {
     std::vector<int> fds;
@@ -548,6 +592,13 @@ xferBenchNixlWorker::cleanupBasicDescVram(xferBenchIOV &iov) {
 }
 #endif /* HAVE_CUDA */

+#if HAVE_SYNAPSEAI
+void
+xferBenchNixlWorker::cleanupBasicDescVram(xferBenchIOV &iov) {
+    Synapseaiutils::free_synapse_memory((uint64_t)iov.addr);
+}
+#endif /* HAVE_SYNAPSEAI */
+
 void
 xferBenchNixlWorker::cleanupBasicDescFile(xferBenchIOV &iov) {
     close(iov.devId);
@@ -669,7 +720,7 @@ xferBenchNixlWorker::allocateMemory(int num_threads) {
             case DRAM_SEG:
                 basic_desc = initBasicDescDram(buffer_size, i);
                 break;
-#if HAVE_CUDA
+#if HAVE_CUDA || HAVE_SYNAPSEAI
             case VRAM_SEG:
                 basic_desc = initBasicDescVram(buffer_size, i);
                 break;
@@ -707,7 +758,7 @@ xferBenchNixlWorker::deallocateMemory(std::vector<std::vector<xferBenchIOV>> &io
             case DRAM_SEG:
                 cleanupBasicDescDram(iov);
                 break;
-#if HAVE_CUDA
+#if HAVE_CUDA || HAVE_SYNAPSEAI
             case VRAM_SEG:
                 cleanupBasicDescVram(iov);
                 break;
diff --git a/benchmark/nixlbench/src/worker/nixl/nixl_worker.h b/benchmark/nixlbench/src/worker/nixl/nixl_worker.h
index 7cb8f62ef0..f208546223 100644
--- a/benchmark/nixlbench/src/worker/nixl/nixl_worker.h
+++ b/benchmark/nixlbench/src/worker/nixl/nixl_worker.h
@@ -70,7 +70,7 @@ class xferBenchNixlWorker: public xferBenchWorker {
     std::optional<xferBenchIOV>
     initBasicDescDram(size_t buffer_size, int mem_dev_id);
     void cleanupBasicDescDram(xferBenchIOV &basic_desc);
-#if HAVE_CUDA
+#if HAVE_CUDA || HAVE_SYNAPSEAI
     std::optional<xferBenchIOV> initBasicDescVram(size_t buffer_size, int mem_dev_id);
     void cleanupBasicDescVram(xferBenchIOV &basic_desc);
 #endif
diff --git a/benchmark/nixlbench/src/worker/worker.cpp b/benchmark/nixlbench/src/worker/worker.cpp
index 3cfc994f45..ec3bc076a3 100644
--- a/benchmark/nixlbench/src/worker/worker.cpp
+++ b/benchmark/nixlbench/src/worker/worker.cpp
@@ -116,7 +116,6 @@ int xferBenchWorker::synchronize() {
 xferBenchWorker::xferBenchWorker(int *argc, char ***argv) {
     terminate = 0;
-
     rt = createRT(&terminate);
     if (!rt) {
         std::cerr << "Failed to create runtime object" << std::endl;
diff --git a/pyproject.toml b/pyproject.toml
index 3cfbde6ef2..3e956fb25d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -34,4 +34,9 @@ dependencies = ["torch", "numpy"]
 profile = "black"

 [tool.meson-python.args]
-setup = ['-Dinstall_headers=false']
+setup = ["-Dinstall_headers=false",
+    "-Dlibfabric_path=/software/users/kpjeeja/nixl/libfabric",
+    "-Ducx_path=/home/kpjeeja/ofi/ucx/build/install-debug/",
+    "-Dsynapsepath_inc=/software/users/kpjeeja/include/habanalabs",
+    "-Dsynapsepath_lib=/home/kpjeeja/qnpu/pt_1_23_368/bin/latest/",
+    "-Detcd_lib_path=/usr/local/lib/"]
diff --git a/src/plugins/libfabric/libfabric_backend.cpp b/src/plugins/libfabric/libfabric_backend.cpp
index 8bbd3264e5..4a96786f5b 100644
--- a/src/plugins/libfabric/libfabric_backend.cpp
+++ b/src/plugins/libfabric/libfabric_backend.cpp
@@ -19,7 +19,6 @@
 #include "libfabric_backend.h"
 #include "serdes/serdes.h"
 #include "common/nixl_log.h"
-#include "libfabric/libfabric_topology.h"

 #include
 #include
@@ -497,7 +496,7 @@ nixlLibfabricEngine::connect(const std::string &remote_agent) {
     auto it = connections_.find(remote_agent);
     if (it != connections_.end() && it->second->overall_state_ == ConnectionState::CONNECTED) {
         NIXL_DEBUG << "Connection already established for " << remote_agent
-                   << ", fi_addr: " << it->second->rail_remote_addr_list_[0];
+                   << ", fi_addr: " << it->second->rail_remote_addr_list_[0][0];
         return NIXL_SUCCESS;
     }
@@ -535,7 +534,7 @@ nixlLibfabricEngine::disconnect(const std::string &remote_agent) {
     // Connection exists - check if already disconnected
     if (it->second->overall_state_ == ConnectionState::DISCONNECTED) {
         NIXL_DEBUG << "Connection already disconnected for " << remote_agent
-                   << ", fi_addr: " << it->second->rail_remote_addr_list_[0];
+                   << ", fi_addr: " << it->second->rail_remote_addr_list_[0][0];
         return NIXL_SUCCESS;
     }
     // TODO: Implement disconnect logic to cleanup the AV Address Entries from both local and remote
@@ -558,11 +557,9 @@ nixlLibfabricEngine::createAgentConnection(

     NIXL_DEBUG << "Creating connection for agent: " << agent_name;

-    // Validate input parameters
     if (data_rail_endpoints.size() != rail_manager.getNumDataRails()) {
-        NIXL_ERROR << "Expected " << rail_manager.getNumDataRails() << " data rail endpoints, got "
-                   << data_rail_endpoints.size();
-        return NIXL_ERR_INVALID_PARAM;
+        NIXL_INFO << "Local side has " << rail_manager.getNumDataRails()
+                  << " data rail endpoints, remote side has " << data_rail_endpoints.size();
     }

     if (control_rail_endpoints.size() != rail_manager.getNumControlRails()) {
@@ -696,7 +693,7 @@ nixlLibfabricEngine::establishConnection(const std::string &remote_agent) const
     nixl_status_t status = rail_manager.postControlMessage(
         nixlLibfabricRailManager::ControlMessageType::CONNECTION_REQ,
         control_request,
-        conn_info->control_rail_remote_addr_list_[0], // Always use control rail 0
+        conn_info->control_rail_remote_addr_list_[0][0], // Always use control rail 0
         it->second->agent_index_ // agent_index is only used in the ACK back from remote,
                                  // to match connection request
     );
@@ -892,24 +889,32 @@ nixlLibfabricEngine::getPublicData(const nixlBackendMD *meta, std::string &str) {
 }

 nixl_status_t
-nixlLibfabricEngine::loadLocalMD(nixlBackendMD *input, nixlBackendMD *&output) {
-    nixlLibfabricPrivateMetadata *input_md = static_cast<nixlLibfabricPrivateMetadata *>(input);
+nixlLibfabricEngine::loadMetadataHelper(const std::vector<uint64_t> &rail_keys,
+                                        void *buffer,
+                                        std::shared_ptr<nixlLibfabricConnection> conn,
+                                        nixlBackendMD *&output) {
     auto pub_md = std::make_unique<nixlLibfabricPublicMetadata>();
-    // Store all rail keys instead of just the first one
-    pub_md->rail_remote_key_list_.reserve(input_md->rail_key_list_.size());
-    for (size_t rail_id = 0; rail_id < input_md->rail_key_list_.size(); ++rail_id) {
-        pub_md->rail_remote_key_list_.push_back(input_md->rail_key_list_[rail_id]);
-        NIXL_DEBUG << "Added rail " << rail_id << " key: " << input_md->rail_key_list_[rail_id];
-    }
-    pub_md->remote_buf_addr_ = reinterpret_cast<uint64_t>(input_md->buffer_);
-    pub_md->conn_ = connections_[localAgent];
+    pub_md->rail_remote_key_list_ = rail_keys;
+    pub_md->derive_remote_selected_endpoints();
+    pub_md->remote_buf_addr_ = reinterpret_cast<uint64_t>(buffer);
+    pub_md->conn_ = conn;
+    NIXL_DEBUG << "Metadata loaded with"
+               << " Remote addr: " << (void *)pub_md->remote_buf_addr_ << " Remote keys for "
+               << pub_md->rail_remote_key_list_.size() << " rails"
+               << " Remote fi_addr: " << pub_md->conn_->rail_remote_addr_list_[0][0];

     output = pub_md.release();
-    NIXL_DEBUG << "Loading Local MD with " << input_md->rail_key_list_.size() << " rail keys";
     return NIXL_SUCCESS;
 }

+nixl_status_t
+nixlLibfabricEngine::loadLocalMD(nixlBackendMD *input, nixlBackendMD *&output) {
+    nixlLibfabricPrivateMetadata *input_md = static_cast<nixlLibfabricPrivateMetadata *>(input);
+    return loadMetadataHelper(
+        input_md->rail_key_list_, input_md->buffer_, connections_[localAgent], output);
+}
+
 nixl_status_t
 nixlLibfabricEngine::loadRemoteMD(const nixlBlobDesc &input,
                                   const nixl_mem_t &nixl_mem,
@@ -927,24 +932,17 @@ nixlLibfabricEngine::loadRemoteMD(const nixlBlobDesc &input,

     std::vector<uint64_t> remote_keys;
     uint64_t remote_addr;
     nixl_status_t status =
-        rail_manager.deserializeMemoryKeys(input.metaInfo, remote_keys, remote_addr);
+        rail_manager.deserializeMemoryKeys(input.metaInfo,
+                                           conn_it->second->rail_remote_addr_list_.at(0).size(),
+                                           remote_keys,
+                                           remote_addr);
     if (status != NIXL_SUCCESS) {
         NIXL_ERROR << "Rail Manager deserializeMemoryKeys failed";
         return status;
     }

-    // Engine handles connection management and metadata object creation
-    auto pub_md = std::make_unique<nixlLibfabricPublicMetadata>();
-    pub_md->conn_ = conn_it->second;
-    pub_md->rail_remote_key_list_ = std::move(remote_keys);
-    pub_md->remote_buf_addr_ = remote_addr;
-    NIXL_DEBUG << "Remote metadata loaded with"
-               << " Remote addr: " << (void *)pub_md->remote_buf_addr_ << " Remote keys for "
-               << pub_md->rail_remote_key_list_.size() << " rails"
-               << " Remote fi_addr: " << pub_md->conn_->rail_remote_addr_list_[0];
-
-    output = pub_md.release();
-    return NIXL_SUCCESS;
+    return loadMetadataHelper(
+        remote_keys, reinterpret_cast<void *>(remote_addr), conn_it->second, output);
 }

 nixl_status_t
@@ -953,6 +951,23 @@ nixlLibfabricEngine::unloadMD(nixlBackendMD *input) {
     return NIXL_SUCCESS;
 }

+/****************************************
+ * Public Metadata Methods
+ *****************************************/
+
+void
+nixlLibfabricPublicMetadata::derive_remote_selected_endpoints() {
+    remote_selected_endpoints_.clear();
+
+    for (size_t i = 0; i < rail_remote_key_list_.size(); ++i) {
+        if (rail_remote_key_list_[i] != 0) {
+            remote_selected_endpoints_.push_back(i);
+        } else {
+            NIXL_DEBUG << "Skipping remote endpoint " << i << " with key 0";
+        }
+    }
+}
+
 /****************************************
  * Data movement
  *****************************************/
@@ -1087,36 +1102,6 @@ nixlLibfabricEngine::postXfer(const nixl_xfer_op_t &operation,

         NIXL_DEBUG << "DEBUG: remote_agent='" << remote_agent << "' localAgent='" << localAgent
                    << "'";

-        // Check for same-agent (local) transfer - handle with direct memcpy
-        if (remote_agent == localAgent) {
-            NIXL_DEBUG << "Same-agent transfer detected from localAgent= " << localAgent
-                       << "to remote_agent " << remote_agent << "for descriptor " << desc_idx
-                       << ", using memcpy fallback for " << transfer_size << " bytes";
-
-            // For same-agent transfers, we need to copy directly between the descriptor addresses
-            // The remote[desc_idx].addr should be the target address for the transfer
-            void *remote_addr = reinterpret_cast<void *>(remote[desc_idx].addr);
-
-            NIXL_DEBUG << "About to perform memcpy: local_addr=" << transfer_addr
-                       << " remote_addr=" << remote_addr << " size=" << transfer_size;
-
-            if (op_type == nixlLibfabricReq::WRITE) {
-                // Write: copy from local_addr to remote_addr
-                std::memcpy(remote_addr, transfer_addr, transfer_size);
-                NIXL_DEBUG << "Same-agent memcpy write completed: " << transfer_addr << " -> "
-                           << remote_addr << " (" << transfer_size << " bytes)";
-            } else {
-                // Read: copy from remote_addr to local_addr
-                std::memcpy(transfer_addr, remote_addr, transfer_size);
-                NIXL_DEBUG << "Same-agent memcpy read completed: " << remote_addr << " -> "
-                           << transfer_addr << " (" << transfer_size << " bytes)";
-            }
-
-            NIXL_DEBUG << "Successfully processed same-agent descriptor " << desc_idx
-                       << " using memcpy fallback";
-            continue; // Skip the rail manager transfer for this descriptor
-        }
-
         // Prepare and submit transfer for remote agents
         // Use descriptor's specific target address
         uint64_t remote_target_addr = remote[desc_idx].addr;
@@ -1129,6 +1114,7 @@ nixlLibfabricEngine::postXfer(const nixl_xfer_op_t &operation,
             local_md->selected_rails_,
             local_md->rail_mr_list_,
             remote_md->rail_remote_key_list_,
+            remote_md->remote_selected_endpoints_,
             conn_it->second->rail_remote_addr_list_,
             conn_it->second->agent_index_,
             [backend_handle]() {
@@ -1149,8 +1135,9 @@ nixlLibfabricEngine::postXfer(const nixl_xfer_op_t &operation,
     NIXL_DEBUG << "Processing complete: submitted "
                << backend_handle->binary_notif.expected_completions << " requests from "
-               << desc_count << " descriptors" << " with "
-               << backend_handle->binary_notif.expected_completions << " total XFER_IDs";
+               << desc_count << " descriptors"
+               << " with " << backend_handle->binary_notif.expected_completions
+               << " total XFER_IDs";

     // For same-agent transfers, we need to set the total to 0 since we bypassed all rail operations
     if (remote_agent == localAgent) {
@@ -1272,11 +1259,11 @@ nixlLibfabricEngine::notifSendPriv(const std::string &remote_agent,
     NIXL_DEBUG << "Sending binary notification control request"
                << " Message: " << binary_notification.getMessage()
                << " expected_completions: " << binary_notification.expected_completions;
-    nixl_status_t status =
-        rail_manager.postControlMessage(nixlLibfabricRailManager::ControlMessageType::NOTIFICATION,
-                                        control_request,
-                                        connection->control_rail_remote_addr_list_[control_rail_id],
-                                        connection->agent_index_);
+    nixl_status_t status = rail_manager.postControlMessage(
+        nixlLibfabricRailManager::ControlMessageType::NOTIFICATION,
+        control_request,
+        connection->control_rail_remote_addr_list_[control_rail_id][0],
+        connection->agent_index_);

     if (status != NIXL_SUCCESS) {
         NIXL_ERROR << "postControlMessage failed on control rail " << control_rail_id;
@@ -1412,7 +1399,7 @@ nixlLibfabricEngine::postShutdownCompletion() {
         nixl_status_t status = rail_manager.postControlMessage(
             nixlLibfabricRailManager::ControlMessageType::DISCONNECT_REQ,
             control_request,
-            self_conn_it->second->rail_remote_addr_list_[rail_id],
+            self_conn_it->second->rail_remote_addr_list_[rail_id][0],
             self_conn_it->second->agent_index_);

         if (status == NIXL_SUCCESS) {
@@ -1529,7 +1516,7 @@ nixlLibfabricEngine::processConnectionRequest(uint16_t agent_idx,
     }

     // Insert ALL data rail addresses at once
-    std::vector<fi_addr_t> data_fi_addrs;
+    std::unordered_map<size_t, std::vector<fi_addr_t>> data_fi_addrs;
     std::vector<std::string> data_ep_names;
     status = rail_manager.insertAllAddresses(
         nixlLibfabricRailManager::RailType::DATA, data_endpoints, data_fi_addrs, data_ep_names);
@@ -1539,7 +1526,7 @@ nixlLibfabricEngine::processConnectionRequest(uint16_t agent_idx,
     }

     // Insert ALL control rail addresses at once
-    std::vector<fi_addr_t> control_fi_addrs;
+    std::unordered_map<size_t, std::vector<fi_addr_t>> control_fi_addrs;
     std::vector<std::string> control_ep_names;
     status = rail_manager.insertAllAddresses(nixlLibfabricRailManager::RailType::CONTROL,
                                              control_endpoints,
@@ -1551,7 +1538,7 @@ nixlLibfabricEngine::processConnectionRequest(uint16_t agent_idx,
     }

     // Use the first control rail's fi_addr for ACK (same as before)
-    fi_addr_t initiator_control_fi_addr = control_fi_addrs[0];
+    fi_addr_t initiator_control_fi_addr = control_fi_addrs[0][0];

     NIXL_DEBUG << "Successfully inserted addresses for " << data_fi_addrs.size()
                << " data rails and " << control_fi_addrs.size() << " control rails"
@@ -1623,7 +1610,6 @@ nixlLibfabricEngine::addReceivedXferId(uint16_t xfer_id) {
     checkPendingNotifications();
 }

-
 /****************************************
  * Notification Queuing Helper Methods
  *****************************************/
diff --git a/src/plugins/libfabric/libfabric_backend.h b/src/plugins/libfabric/libfabric_backend.h
index fcee2544c4..2a3154de35 100644
--- a/src/plugins/libfabric/libfabric_backend.h
+++ b/src/plugins/libfabric/libfabric_backend.h
@@ -96,9 +96,15 @@ class nixlLibfabricPublicMetadata : public nixlBackendMD {
     std::shared_ptr<nixlLibfabricConnection> conn_; // Connection to remote agent
     std::vector<uint64_t> rail_remote_key_list_; // Remote access keys, one per rail
     std::vector<std::string> src_ep_names_; // Source endpoint names, one per rail
+    std::vector<size_t>
+        remote_selected_endpoints_; // Remote rails selected, derived from rail_remote_key_list_.

 public:
     nixlLibfabricPublicMetadata() : nixlBackendMD(false) {}
+
+    void
+    derive_remote_selected_endpoints();
+
     friend class nixlLibfabricEngine;
 };

@@ -107,8 +113,10 @@ class nixlLibfabricConnection : public nixlBackendConnMD {
 private:
     size_t agent_index_; // Unique agent identifier in agent_names vector
     std::string remoteAgent_; // Remote agent name
-    std::vector<fi_addr_t> rail_remote_addr_list_; // Data rail libfabric addresses
-    std::vector<fi_addr_t> control_rail_remote_addr_list_; // Control rail libfabric addresses
+    std::unordered_map<size_t, std::vector<fi_addr_t>>
+        rail_remote_addr_list_; // Data rail libfabric addresses. Key: data rail id.
+    std::unordered_map<size_t, std::vector<fi_addr_t>>
+        control_rail_remote_addr_list_; // Control rail libfabric addresses. Key: control rail id.
     std::vector<std::string> src_ep_names_; // Data rail endpoint names
     std::vector<std::string> control_ep_names_; // Control rail endpoint names
     ConnectionState overall_state_; // Current connection state
@@ -281,7 +289,11 @@ class nixlLibfabricEngine : public nixlBackendEngine {
     processConnectionRequest(uint16_t agent_idx,
                              const std::string &serialized_data,
                              nixlLibfabricRail *rail);
-
+    nixl_status_t
+    loadMetadataHelper(const std::vector<uint64_t> &rail_keys,
+                       void *buffer,
+                       std::shared_ptr<nixlLibfabricConnection> conn,
+                       nixlBackendMD *&output);

 #ifdef HAVE_CUDA
     // CUDA context management methods
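Reviewer note: the header change above is the heart of the addressing rework. Per-connection addresses go from one `fi_addr_t` per rail to a map from local rail id to the list of remote endpoint addresses, so local and remote rail counts no longer have to match. A compact standalone illustration of the lookup the data path now performs (values are fabricated; `fi_addr_t` comes from libfabric's `<rdma/fabric.h>`):

```cpp
#include <cstddef>
#include <cstdio>
#include <unordered_map>
#include <vector>
#include <rdma/fabric.h> // fi_addr_t

int main() {
    // rail_remote_addr_list_[local_rail_id][remote_endpoint_id] -> fi_addr_t
    std::unordered_map<size_t, std::vector<fi_addr_t>> rail_remote_addr_list = {
        {0, {100, 101}}, // local rail 0 can reach two remote endpoints
        {1, {200, 201}},
    };
    // Endpoints whose remote key was non-zero (see derive_remote_selected_endpoints()).
    std::vector<size_t> remote_selected_endpoints = {0, 1};

    size_t counter = 7; // stand-in for the round-robin counter
    size_t local_rail = counter % rail_remote_addr_list.size();
    size_t remote_ep = remote_selected_endpoints[counter % remote_selected_endpoints.size()];

    fi_addr_t dest = rail_remote_addr_list.at(local_rail)[remote_ep];
    printf("rail %zu -> remote endpoint %zu (fi_addr %lu)\n",
           local_rail, remote_ep, (unsigned long)dest);
    return 0;
}
```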
diff --git a/src/utils/libfabric/libfabric_rail_manager.cpp b/src/utils/libfabric/libfabric_rail_manager.cpp
index e2d639c6f5..d1f1050f48 100644
--- a/src/utils/libfabric/libfabric_rail_manager.cpp
+++ b/src/utils/libfabric/libfabric_rail_manager.cpp
@@ -34,6 +34,7 @@ resetSeqId();

 // Static round-robin counter for rail selection
 static std::atomic<size_t> round_robin_counter{0};
+static const std::string NUM_RAILS_TAG{"num_rails"};

 nixlLibfabricRailManager::nixlLibfabricRailManager(size_t striping_threshold)
     : striping_threshold_(striping_threshold) {
@@ -136,17 +137,19 @@ nixlLibfabricRailManager::shouldUseStriping(size_t transfer_size) const {
 }

 nixl_status_t
-nixlLibfabricRailManager::prepareAndSubmitTransfer(nixlLibfabricReq::OpType op_type,
-                                                   void *local_addr,
-                                                   size_t transfer_size,
-                                                   uint64_t remote_base_addr,
-                                                   const std::vector<size_t> &selected_rails,
-                                                   const std::vector<struct fid_mr *> &local_mrs,
-                                                   const std::vector<uint64_t> &remote_keys,
-                                                   const std::vector<fi_addr_t> &dest_addrs,
-                                                   uint16_t agent_idx,
-                                                   std::function<void()> completion_callback,
-                                                   BinaryNotification *binary_notif) {
+nixlLibfabricRailManager::prepareAndSubmitTransfer(
+    nixlLibfabricReq::OpType op_type,
+    void *local_addr,
+    size_t transfer_size,
+    uint64_t remote_base_addr,
+    const std::vector<size_t> &selected_rails,
+    const std::vector<struct fid_mr *> &local_mrs,
+    const std::vector<uint64_t> &remote_keys,
+    const std::vector<size_t> &remote_selected_endpoints,
+    const std::unordered_map<size_t, std::vector<fi_addr_t>> &dest_addrs,
+    uint16_t agent_idx,
+    std::function<void()> completion_callback,
+    BinaryNotification *binary_notif) {
     if (selected_rails.empty()) {
         NIXL_ERROR << "No rails selected for transfer";
         return NIXL_ERR_INVALID_PARAM;
@@ -154,11 +157,14 @@ nixlLibfabricRailManager::prepareAndSubmitTransfer(nixlLibfabricReq::OpType op_t

     // Determine striping strategy
     bool use_striping = shouldUseStriping(transfer_size) && selected_rails.size() > 1;
-
+    NIXL_DEBUG << "use_striping: " << use_striping;
     if (!use_striping) {
         // Round-robin: use one rail for entire transfer
-        size_t rail_idx = round_robin_counter.fetch_add(1) % selected_rails.size();
-        size_t rail_id = selected_rails[rail_idx];
+        const auto counter_value = round_robin_counter.fetch_add(1);
+        const size_t rail_id = selected_rails[counter_value % selected_rails.size()];
+        const size_t remote_ep_id =
+            remote_selected_endpoints[counter_value % remote_selected_endpoints.size()];
+        NIXL_DEBUG << "rail " << rail_id << ", remote_ep_id " << remote_ep_id;
         // Allocate request
         nixlLibfabricReq *req = data_rails_[rail_id]->allocateDataRequest(op_type);
         if (!req) {
@@ -183,7 +189,7 @@ nixlLibfabricRailManager::prepareAndSubmitTransfer(nixlLibfabricReq::OpType op_t
         }

         req->local_mr = local_mrs[rail_id];
-        req->remote_key = remote_keys[rail_id];
+        req->remote_key = remote_keys[remote_ep_id];
         req->rail_id = rail_id;
         // Submit immediately
         nixl_status_t status;
@@ -196,7 +202,7 @@ nixlLibfabricRailManager::prepareAndSubmitTransfer(nixlLibfabricReq::OpType op_t
                                                    req->chunk_size,
                                                    fi_mr_desc(req->local_mr),
                                                    imm_data,
-                                                   dest_addrs[rail_id],
+                                                   dest_addrs.at(rail_id)[remote_ep_id],
                                                    req->remote_addr,
                                                    req->remote_key,
                                                    req);
@@ -204,7 +210,7 @@ nixlLibfabricRailManager::prepareAndSubmitTransfer(nixlLibfabricReq::OpType op_t
             status = data_rails_[rail_id]->postRead(req->local_addr,
                                                     req->chunk_size,
                                                     fi_mr_desc(req->local_mr),
-                                                    dest_addrs[rail_id],
+                                                    dest_addrs.at(rail_id)[remote_ep_id],
                                                     req->remote_addr,
                                                     req->remote_key,
                                                     req);
@@ -229,7 +235,10 @@ nixlLibfabricRailManager::prepareAndSubmitTransfer(nixlLibfabricReq::OpType op_t
         size_t chunk_size = transfer_size / num_rails;
         size_t remainder = transfer_size % num_rails;
         for (size_t i = 0; i < num_rails; ++i) {
-            size_t rail_id = selected_rails[i];
+            const size_t rail_id = selected_rails[i];
+            const size_t remote_ep_id =
+                remote_selected_endpoints[i % remote_selected_endpoints.size()];
+            NIXL_DEBUG << "rail " << rail_id << ", remote_ep_id=" << remote_ep_id;
             size_t current_chunk_size = chunk_size + (i == num_rails - 1 ? remainder : 0);
             if (current_chunk_size == 0) break;
             // Allocate request
@@ -261,7 +270,7 @@ nixlLibfabricRailManager::prepareAndSubmitTransfer(nixlLibfabricReq::OpType op_t
             }

             req->local_mr = local_mrs[rail_id];
-            req->remote_key = remote_keys[rail_id];
+            req->remote_key = remote_keys[remote_ep_id];
             req->rail_id = rail_id;
             nixl_status_t status;
             if (op_type == nixlLibfabricReq::WRITE) {
@@ -273,7 +282,7 @@ nixlLibfabricRailManager::prepareAndSubmitTransfer(nixlLibfabricReq::OpType op_t
                                                        req->chunk_size,
                                                        fi_mr_desc(req->local_mr),
                                                        imm_data,
-                                                       dest_addrs[rail_id],
+                                                       dest_addrs.at(rail_id)[remote_ep_id],
                                                        req->remote_addr,
                                                        req->remote_key,
                                                        req);
@@ -281,7 +290,7 @@ nixlLibfabricRailManager::prepareAndSubmitTransfer(nixlLibfabricReq::OpType op_t
                 status = data_rails_[rail_id]->postRead(req->local_addr,
                                                         req->chunk_size,
                                                         fi_mr_desc(req->local_mr),
-                                                        dest_addrs[rail_id],
+                                                        dest_addrs.at(rail_id)[remote_ep_id],
                                                         req->remote_addr,
                                                         req->remote_key,
                                                         req);
@@ -482,38 +491,33 @@ nixl_status_t
 nixlLibfabricRailManager::insertAllAddresses(
     RailType rail_type,
     const std::vector<std::vector<char>> &endpoints,
-    std::vector<fi_addr_t> &fi_addrs_out,
+    std::unordered_map<size_t, std::vector<fi_addr_t>> &fi_addrs_out,
     std::vector<std::string> &ep_names_out) {
     auto &rails = (rail_type == RailType::DATA) ? data_rails_ : control_rails_;
     const char *rail_type_str = (rail_type == RailType::DATA) ? "data" : "control";

-    if (endpoints.size() != rails.size()) {
-        NIXL_ERROR << "Expected " << rails.size() << " " << rail_type_str << " endpoints, got "
-                   << endpoints.size();
-        return NIXL_ERR_INVALID_PARAM;
-    }
-
     fi_addrs_out.clear();
     ep_names_out.clear();
-    fi_addrs_out.reserve(rails.size());
     ep_names_out.reserve(rails.size());

     // Process all rails in one operation
     for (size_t rail_id = 0; rail_id < rails.size(); ++rail_id) {
-        fi_addr_t fi_addr;
-        nixl_status_t status = rails[rail_id]->insertAddress(endpoints[rail_id].data(), &fi_addr);
-        if (status != NIXL_SUCCESS) {
-            NIXL_ERROR << "Failed for " << rail_type_str << " rail " << rail_id;
-            return status;
+        fi_addrs_out[rail_id].reserve(endpoints.size());
+        for (const auto &endpoint : endpoints) {
+            fi_addr_t fi_addr;
+            nixl_status_t status = rails[rail_id]->insertAddress(endpoint.data(), &fi_addr);
+            if (status != NIXL_SUCCESS) {
+                NIXL_ERROR << "Failed for " << rail_type_str << " rail " << rail_id;
+                return status;
+            }
+            fi_addrs_out[rail_id].push_back(fi_addr);
+            NIXL_DEBUG << "Processed " << rail_type_str << " rail " << rail_id
+                       << " (fi_addr: " << fi_addr << ")";
         }
-        fi_addrs_out.push_back(fi_addr);
         ep_names_out.push_back(
             rails[rail_id]
                 ->ep_name); // This is char[LF_EP_NAME_MAX_LEN], will be converted to char*
-
-        NIXL_DEBUG << "Processed " << rail_type_str << " rail " << rail_id
-                   << " (fi_addr: " << fi_addr << ")";
     }

     NIXL_DEBUG << "Successfully processed " << rails.size() << " " << rail_type_str << " rails";
@@ -717,19 +721,20 @@ nixlLibfabricRailManager::serializeMemoryKeys(const std::vector<uint64_t> &keys,

 nixl_status_t
 nixlLibfabricRailManager::deserializeMemoryKeys(const std::string &serialized_data,
+                                                const size_t num_keys,
                                                 std::vector<uint64_t> &keys_out,
                                                 uint64_t &remote_addr_out) const {
     nixlSerDes ser_des;
     ser_des.importStr(serialized_data);
     // Load all rail keys instead of just one
     keys_out.clear();
-    keys_out.reserve(data_rails_.size());
-    for (size_t rail_id = 0; rail_id < data_rails_.size(); ++rail_id) {
-        std::string key_name = "key_" + std::to_string(rail_id);
+    keys_out.reserve(num_keys);
+    for (size_t idx = 0; idx < num_keys; ++idx) {
+        std::string key_name = "key_" + std::to_string(idx);
         uint64_t remote_key;
         nixl_status_t status = ser_des.getBuf(key_name.c_str(), &remote_key, sizeof(remote_key));
         if (status != NIXL_SUCCESS) {
-            NIXL_ERROR << "Failed to get key " << key_name << " for rail " << rail_id;
+            NIXL_ERROR << "Failed to get key " << key_name;
             return NIXL_ERR_BACKEND;
         }
         keys_out.push_back(remote_key);
@@ -775,14 +780,13 @@ nixlLibfabricRailManager::deserializeConnectionInfo(
     // Use user prefix with standard suffixes
     std::string data_prefix = user_prefix + "_data_ep_";
     std::string control_prefix = user_prefix + "_control_ep_";
-    nixl_status_t data_status =
-        deserializeRailEndpoints(ser_des, data_prefix, data_rails_.size(), data_endpoints_out);
+    nixl_status_t data_status = deserializeRailEndpoints(ser_des, data_prefix, data_endpoints_out);
     if (data_status != NIXL_SUCCESS) {
         NIXL_ERROR << "Failed to deserialize data rail endpoints with prefix: " << data_prefix;
         return data_status;
     }
-    nixl_status_t control_status = deserializeRailEndpoints(
-        ser_des, control_prefix, control_rails_.size(), control_endpoints_out);
+    nixl_status_t control_status =
+        deserializeRailEndpoints(ser_des, control_prefix, control_endpoints_out);
     if (control_status != NIXL_SUCCESS) {
         NIXL_ERROR << "Failed to deserialize control rail endpoints with prefix: "
                    << control_prefix;
@@ -802,6 +806,8 @@ nixlLibfabricRailManager::serializeRailEndpoints(nixlSerDes &ser_des,
     auto &rails = (rail_type == RailType::DATA) ? data_rails_ : control_rails_;
     const char *rail_type_str = (rail_type == RailType::DATA) ? "data" : "control";

+    ser_des.addStr(NUM_RAILS_TAG, std::to_string(rails.size()));
+
     for (size_t rail_id = 0; rail_id < rails.size(); ++rail_id) {
         std::string rail_key = key_prefix + std::to_string(rail_id);
         const char *ep_name = rails[rail_id]->ep_name;
@@ -817,11 +823,31 @@ nixl_status_t
 nixlLibfabricRailManager::deserializeRailEndpoints(
     nixlSerDes &ser_des,
     const std::string &key_prefix,
-    size_t expected_count,
     std::vector<std::vector<char>> &endpoints_out) const {
-    endpoints_out.resize(expected_count);
-    for (size_t rail_id = 0; rail_id < expected_count; ++rail_id) {
+    std::string str;
+    unsigned long num_rails_val;
+    try {
+        str = ser_des.getStr(NUM_RAILS_TAG);
+        num_rails_val = std::stoul(str);
+        if (num_rails_val > std::numeric_limits<size_t>::max()) {
+            NIXL_ERROR << "Key " << NUM_RAILS_TAG
+                       << " value out of range (size_t): " << num_rails_val;
+            return NIXL_ERR_BACKEND;
+        }
+    }
+    catch (const std::invalid_argument &) {
+        NIXL_ERROR << "Key " << NUM_RAILS_TAG << " not found or invalid.";
+        return NIXL_ERR_BACKEND;
+    }
+    catch (const std::out_of_range &) {
+        NIXL_ERROR << "Key " << NUM_RAILS_TAG << " value out of range (unsigned long): " << str;
+        return NIXL_ERR_BACKEND;
+    }
+    const size_t num_rails = static_cast<size_t>(num_rails_val);
+    endpoints_out.resize(num_rails);
+
+    for (size_t rail_id = 0; rail_id < num_rails; ++rail_id) {
         std::string rail_key = key_prefix + std::to_string(rail_id);

         // First check if the key exists and get its length
@@ -846,7 +872,7 @@ nixlLibfabricRailManager::deserializeRailEndpoints(
         }
     }

-    NIXL_DEBUG << "Successfully deserialized " << expected_count << " rail endpoints";
+    NIXL_DEBUG << "Successfully deserialized " << num_rails << " rail endpoints.";
     return NIXL_SUCCESS;
 }
diff --git a/src/utils/libfabric/libfabric_rail_manager.h b/src/utils/libfabric/libfabric_rail_manager.h
index fe68e962d0..97ead734f3 100644
--- a/src/utils/libfabric/libfabric_rail_manager.h
+++ b/src/utils/libfabric/libfabric_rail_manager.h
@@ -140,14 +140,15 @@ class nixlLibfabricRailManager {
     /** Insert addresses into address vectors for all rails of specified type
      * @param rail_type Type of rails to operate on (DATA or CONTROL)
      * @param endpoints Remote endpoint addresses to insert
-     * @param fi_addrs_out Libfabric address handles for inserted endpoints
+     * @param fi_addrs_out Libfabric address handles for inserted endpoints,
+     *                     indexed by local rail id
      * @param ep_names_out Local endpoint names for reference
      * @return NIXL_SUCCESS on success, error code on failure
      */
     nixl_status_t
     insertAllAddresses(RailType rail_type,
                        const std::vector<std::vector<char>> &endpoints,
-                       std::vector<fi_addr_t> &fi_addrs_out,
+                       std::unordered_map<size_t, std::vector<fi_addr_t>> &fi_addrs_out,
                        std::vector<std::string> &ep_names_out);
     /** Clean up connection resources for specified rail type
      * @param rail_type Type of rails to clean up (DATA or CONTROL)
@@ -165,6 +166,7 @@ class nixlLibfabricRailManager {
      * @param selected_rails Rails to use for the transfer
      * @param local_mrs Local memory registrations
      * @param remote_keys Remote access keys
+     * @param remote_selected_endpoints Remote endpoints on which the remote keys are registered
      * @param dest_addrs Destination addresses for each rail
      * @param agent_idx Remote agent index for immediate data
      * @param completion_callback Callback for completion notification
@@ -179,7 +181,8 @@ class nixlLibfabricRailManager {
                              const std::vector<size_t> &selected_rails,
                              const std::vector<struct fid_mr *> &local_mrs,
                              const std::vector<uint64_t> &remote_keys,
-                             const std::vector<fi_addr_t> &dest_addrs,
+                             const std::vector<size_t> &remote_selected_endpoints,
+                             const std::unordered_map<size_t, std::vector<fi_addr_t>> &dest_addrs,
                              uint16_t agent_idx,
                              std::function<void()> completion_callback,
                              BinaryNotification *binary_notif);
@@ -271,12 +274,14 @@ class nixlLibfabricRailManager {
     serializeMemoryKeys(const std::vector<uint64_t> &keys, void *buffer, std::string &str) const;
     /** Deserialize memory keys and remote address
      * @param serialized_data Serialized memory information
+     * @param num_keys Number of keys to deserialize
      * @param keys_out Remote access keys for all rails
      * @param remote_addr_out Remote buffer address
      * @return NIXL_SUCCESS on success, error code on failure
      */
     nixl_status_t
     deserializeMemoryKeys(const std::string &serialized_data,
+                          const size_t num_keys,
                           std::vector<uint64_t> &keys_out,
                           uint64_t &remote_addr_out) const;
     // SerDes-based Connection Info Serialization
@@ -333,7 +338,6 @@ class nixlLibfabricRailManager {
     deserializeRailEndpoints(
         nixlSerDes &ser_des,
         const std::string &key_prefix,
-        size_t expected_count,
         std::vector<std::vector<char>> &endpoints_out) const;
 };
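Reviewer note: to make the asymmetric-rail behavior concrete, here is a standalone sketch of the split performed by the striping path in `prepareAndSubmitTransfer()`: the transfer is divided evenly across local rails, the last rail absorbs the remainder, and the remote endpoint id wraps independently of the local rail count (sizes and rail ids are fabricated):

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
    const size_t transfer_size = 10 * 1024 * 1024 + 5;
    const std::vector<size_t> selected_rails = {0, 1, 2};
    const std::vector<size_t> remote_selected_endpoints = {0, 1}; // may differ in count

    const size_t num_rails = selected_rails.size();
    const size_t chunk_size = transfer_size / num_rails;
    const size_t remainder = transfer_size % num_rails;

    for (size_t i = 0; i < num_rails; ++i) {
        const size_t rail_id = selected_rails[i];
        // Remote endpoints wrap around when there are fewer of them than local rails.
        const size_t remote_ep_id = remote_selected_endpoints[i % remote_selected_endpoints.size()];
        const size_t current_chunk = chunk_size + (i == num_rails - 1 ? remainder : 0);
        printf("rail %zu -> remote ep %zu: %zu bytes\n", rail_id, remote_ep_id, current_chunk);
    }
    return 0;
}
```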