diff --git a/CMakeLists.txt b/CMakeLists.txt index 4007994..caec9a2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -32,6 +32,7 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake") list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/Modules) +find_package(Options REQUIRED) if("$ENV{USE_CUDA}" STREQUAL "0") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DDMLC_USE_ZMQ -DDMLC_USE_RDMA -DSTEPMESH_USE_GDR") @@ -67,3 +68,10 @@ target_link_libraries(af ${TORCH_PYTHON_LIBRARY}) add_subdirectory(tests) + +if(ENABLE_PLUGIN) + add_subdirectory(plugins/klx_backend) + add_custom_target(dummy_target) + add_dependencies(klx_backend af) + add_dependencies(dummy_target af klx_backend) +endif() diff --git a/Makefile b/Makefile index 1dfcbfe..1a33c4b 100644 --- a/Makefile +++ b/Makefile @@ -107,3 +107,10 @@ af: @cd cmake_build; cmake .. -DCMAKE_CUDA_COMPILER=$(CMAKE_CUDA_COMPILER) -DPython_EXECUTABLE=$(shell which python3) -DCUDA_TOOLKIT_ROOT_DIR=$(CUDA_TOOLKIT_ROOT_DIR); make -j @mkdir -p build +plugin: + @mkdir -p cmake_build + @cd cmake_build; cmake .. -DENABLE_PLUGIN=ON -DCMAKE_CUDA_COMPILER=$(CMAKE_CUDA_COMPILER) -DENABLE_PLUGIN=ON -DPython_EXECUTABLE=$(shell which python3) -DCUDA_TOOLKIT_ROOT_DIR=$(CUDA_TOOLKIT_ROOT_DIR); make -j + @mkdir -p build + + + diff --git a/cmake/FindOptions.cmake b/cmake/FindOptions.cmake new file mode 100644 index 0000000..d312773 --- /dev/null +++ b/cmake/FindOptions.cmake @@ -0,0 +1 @@ +option(ENABLE_PLUGIN "Enable plugin feature" OFF) diff --git a/fserver/csrc/public.hpp b/fserver/csrc/public.hpp index 7299ef9..dd9c50d 100644 --- a/fserver/csrc/public.hpp +++ b/fserver/csrc/public.hpp @@ -1,4 +1,7 @@ /* Copyright (c) 2025, StepFun Authors. All rights reserved. */ + +#include + #include #include #include @@ -43,6 +46,7 @@ uint64_t handler_counter_ = 0; std::unordered_map meta_map_; std::vector> q_; std::atomic q_signal_; +static void* gPluginHandle = nullptr; void RequestHandler(const AFTensorMeta& req_meta, AFTensorServer* server) { std::vector tensors; @@ -164,7 +168,12 @@ void barrier(bool include_server, bool include_worker, bool instrance_barrier=tr } -void init() { +void init(const std::string& plugin) { + if (!plugin.empty()) { + gPluginHandle = dlopen(plugin.c_str(), RTLD_NOW); + PS_CHECK(gPluginHandle) + << "can't load plugin:" << plugin << ": " << dlerror(); + } std::string role_str = ps::GetEnv("DMLC_ROLE", "server"); int offset = 0; @@ -203,14 +212,17 @@ void stop() { ps::Postoffice::GetWorker(gpu_)->DoBarrier(0, ps::kWorkerGroup + ps::kServerGroup + ps::kScheduler, true); } else if (role_ == Node::SERVER) { - ps::Postoffice::GetServer(gpu_)->DoBarrier(0, - ps::kWorkerGroup + ps::kServerGroup + ps::kScheduler, true); + ps::Postoffice::GetServer(gpu_)->DoBarrier( + 0, ps::kWorkerGroup + ps::kServerGroup + ps::kScheduler, true); } else { - ps::Postoffice::Get()->DoBarrier(0, - ps::kWorkerGroup + ps::kServerGroup + ps::kScheduler, true); + ps::Postoffice::Get()->DoBarrier( + 0, ps::kWorkerGroup + ps::kServerGroup + ps::kScheduler, true); } ps::Finalize(0, role_, true); + if (gPluginHandle) { + dlclose(gPluginHandle); + } } std::vector get_all_handlers(int handler) { @@ -237,7 +249,8 @@ uint64_t get_nanosecond() { void pybind_public(py::module &m){ - m.def("init", &init, py::call_guard()); + m.def("init", &init, py::arg("plugin") = "", + py::call_guard()); m.def("stop", &stop, py::call_guard()); m.def("register_recv_buffer", diff --git a/include/dmlc/backend_registry.h b/include/dmlc/backend_registry.h new file mode 100644 index 0000000..593ab3c --- /dev/null +++ b/include/dmlc/backend_registry.h @@ -0,0 +1,46 @@ +#pragma once + +#include "base.h" +#include "ps/backend.h" + +namespace dmlc { + +// StepMesh Backend Registration +// +// StepMesh uses a **dynamic backend registration mechanism** to decouple +// specific backend implementations (e.g., RDMA, CPU, GPU, future XPUs) from +// the core transport layer. +// +// Motivation: +// * Avoid hard-coding backend logic into the core codebase. +// * Support extensibility: new backends can be plugged in without touching +// core StepMesh logic. +// * Enable backend discovery at runtime via a registry pattern. +// +// Contract Requirements: +// * Each backend must define a unique type identifier. +// * Backends must satisfy the BackendInterface contract (init, push/pull, +// etc). +// * Registration should occur during StepMesh startup or module load. +// +// Backend registry responsibilities: +// * Store mappings from backend ID to constructor/factory. +// * Provide factory APIs to create instances by ID. +// * Ensure correct initialization order (core first, backends next). +// +// Example use cases: +// * Enable support for new accelerator types (e.g., NPU, TPU-like devices). +// * Abstract transport layer differences while exposing uniform APIs. +// +// Note: This refactor **does not add a new mandatory backend** — existing +// backends continue to work without modification unless explicitly replaced. +// It simply provides the structural foundation for extensibility. +// This design choice is focused on modularity and future-proofing. +template +struct STEPMESH_API backend_registry { + backend_registry(const std::string& name) { + ps::Backend::RegisterLazy(name, []() { return new T(); }); + } +}; + +} // namespace dmlc diff --git a/include/dmlc/base.h b/include/dmlc/base.h index c2a9365..1dd0478 100644 --- a/include/dmlc/base.h +++ b/include/dmlc/base.h @@ -190,4 +190,6 @@ inline const char *BeginPtr(const std::string &str) { #define alignof __alignof #endif +#define STEPMESH_API __attribute__((__visibility__("default"))) + #endif // DMLC_BASE_H_ diff --git a/include/dmlc/logging.h b/include/dmlc/logging.h index da641e6..4d70f07 100644 --- a/include/dmlc/logging.h +++ b/include/dmlc/logging.h @@ -71,25 +71,35 @@ inline void InitLogging(const char *argv0) { // DO NOTHING } +constexpr const char *getFileName(const char *path) { + auto last = path + strlen(path); + while (*last != '/') { + --last; + } + return ++last; +} + // Always-on checking -#define PS_CHECK(x) \ - if (!(x)) \ - dmlc::LogMessageFatal(__FILE__, __LINE__).stream() << "Check " \ - "failed: " #x \ - << ' ' +#define PS_CHECK(x) \ + if (!(x)) \ + dmlc::LogMessageFatal(dmlc::getFileName(__FILE__), __LINE__).stream() \ + << "Check " \ + "failed: " #x \ + << ' ' #define PS_CHECK_LT(x, y) PS_CHECK((x) < (y)) #define PS_CHECK_GT(x, y) PS_CHECK((x) > (y)) #define PS_CHECK_LE(x, y) PS_CHECK((x) <= (y)) #define PS_CHECK_GE(x, y) PS_CHECK((x) >= (y)) #define PS_CHECK_EQ(x, y) PS_CHECK((x) == (y)) #define PS_CHECK_NE(x, y) PS_CHECK((x) != (y)) -#define PS_CHECK_NOTNULL(x) \ - ((x) == NULL \ - ? dmlc::LogMessageFatal(__FILE__, __LINE__).stream() << "Check notnull: " #x << ' ', \ +#define PS_CHECK_NOTNULL(x) \ + ((x) == NULL \ + ? dmlc::LogMessageFatal(dmlc::getFileName(__FILE__), __LINE__).stream() \ + << "Check notnull: " #x << ' ', \ (x) : (x)) // NOLINT(*) // Debug-only checking. #ifdef NDEBUG -/* +/* #define DPS_CHECK(x) \ while (false) PS_CHECK(x) #define DPS_CHECK_LT(x, y) \ @@ -114,12 +124,12 @@ inline void InitLogging(const char *argv0) { #define DPS_CHECK_NE(x, y) PS_CHECK((x) != (y)) */ #endif // NDEBUG -#define PS_LOG_API dmlc::LogMessage(__FILE__, __LINE__) +#define PS_LOG_API dmlc::LogMessage(dmlc::getFileName(__FILE__), __LINE__) #define PS_LOG_IF(severity, condition) \ !(condition) ? (void)0 : dmlc::LogMessageVoidify() & PS_LOG_API -#define LOG_FATAL dmlc::LogMessageFatal(__FILE__, __LINE__) +#define LOG_FATAL dmlc::LogMessageFatal(dmlc::getFileName(__FILE__), __LINE__) #define PS_LOG_FATAL LOG_FATAL.stream() #define LOG_QFATAL LOG_FATAL diff --git a/include/ps/af_tensor_app.h b/include/ps/af_tensor_app.h index 2f71765..4eaaf74 100644 --- a/include/ps/af_tensor_app.h +++ b/include/ps/af_tensor_app.h @@ -18,9 +18,9 @@ #include #include +#include "ps/backend.h" #include "ps/base.h" #include "ps/hash_table8.hpp" -#include "ps/internal/backend.h" #include "ps/internal/utils.h" #include "ps/kv_app.h" @@ -236,7 +236,8 @@ class AFTensorWorker { void ZPush_(int ts, const SArray& keys, const at::Tensor& tensor, int cmd = 0) { SArray val; - val.reset(reinterpret_cast(tensor.data_ptr()), + void* mappedPtr = Backend::Get()->GetAccessibleAddr(tensor); + val.reset(reinterpret_cast(mappedPtr), tensor.numel() * tensor.itemsize(), [tensor](void*) {}); Message msg; @@ -244,7 +245,7 @@ class AFTensorWorker { msg.meta.head = cmd; msg.meta.push = true; msg.meta.timestamp = ts; - msg.meta.addr = reinterpret_cast(tensor.data_ptr()); + msg.meta.addr = reinterpret_cast(mappedPtr); msg.meta.val_len = tensor.numel() * tensor.itemsize(); PS_VLOG(2) << "ZPush_ addr: 0x" << std::hex << msg.meta.addr << std::dec << " val_len: " << msg.meta.val_len; @@ -284,13 +285,14 @@ class AFTensorWorker { *key.data() = pull_tensors[i * pull_batch_size + index].key; - val.reset(reinterpret_cast(tensor.data_ptr()), + void* mappedPtr = Backend::Get()->GetAccessibleAddr(tensor); + val.reset(reinterpret_cast(mappedPtr), tensor.numel() * tensor.itemsize(), [tensor](void*) {}); msg.meta.request = true; msg.meta.head = cmd; msg.meta.push = false; - msg.meta.addr = reinterpret_cast(tensor.data_ptr()); + msg.meta.addr = reinterpret_cast(mappedPtr); msg.meta.val_len = tensor.numel() * tensor.itemsize(); msg.meta.key = key[0]; msg.meta.is_tensor = 1; @@ -483,7 +485,8 @@ class AFTensorServer { res.keys = key; SArray tensor_val; - tensor_val.reset(reinterpret_cast(tensors[0].val.data_ptr()), + tensor_val.reset(reinterpret_cast( + Backend::Get()->GetAccessibleAddr(tensors[0].val)), tensors[0].val.numel() * tensors[0].val.itemsize(), [](void*) {}); res.vals = tensor_val; @@ -506,7 +509,8 @@ class AFTensorServer { rsp.kv_pair.keys = key; rsp.kv_pair.vals.reset( - reinterpret_cast(res_kv.val.data_ptr()), + reinterpret_cast( + Backend::Get()->GetAccessibleAddr(res_kv.val)), res_kv.val.numel() * res_kv.val.itemsize(), [](void*) {}); rsp.kv_meta = kv_meta; @@ -558,7 +562,8 @@ class AFTensorServer { PS_CHECK_GT(worker_ranks.size(), 0) << "ranks or keys should not be empty"; PS_CHECK_EQ(worker_ranks.size(), keys.size()) << "rank list and key list have unequal size"; - char* buffer_ptr = reinterpret_cast(tensor.data_ptr()); + char* buffer_ptr = + reinterpret_cast(Backend::Get()->GetAccessibleAddr(tensor)); uint64_t data_size = tensor.numel() * tensor.element_size(); int chunk_size = data_size / worker_ranks.size(); PS_CHECK_EQ(data_size % worker_ranks.size(), 0) @@ -591,8 +596,14 @@ class AFTensorServer { .dtype(at::ScalarType(req_meta.dtype)) .memory_format(at::MemoryFormat::Contiguous) .device(Backend::Get()->GetDevice()); - key_tensor.val = - at::from_blob(req_data.vals.data(), req_meta.shape, options); + key_tensor.val = at::from_blob( + Backend::Get()->GetDeviceAddrFromHostPtr( + req_data.vals.data(), + std::accumulate(std::begin(req_meta.shape), + std::end(req_meta.shape), + c10::elementSize(at::ScalarType(req_meta.dtype)), + std::multiplies())), + req_meta.shape, options); } key_tensor.key = req_data.keys[0]; return key_tensor; diff --git a/include/ps/internal/backend.h b/include/ps/backend.h similarity index 71% rename from include/ps/internal/backend.h rename to include/ps/backend.h index 63dfb39..3b1dde3 100644 --- a/include/ps/internal/backend.h +++ b/include/ps/backend.h @@ -11,13 +11,15 @@ #endif #include +#include #include #include #include #include +#include "base.h" #include "dmlc/logging.h" -#include "ps/internal/env.h" +#include "ps/env.h" namespace ps { @@ -26,7 +28,7 @@ enum { BACKEND_OK = 0, BACKEND_FAILED = -1 }; /** * \brief Abstract Backend Class */ -class Backend { +class STEPMESH_API Backend { public: /** * \brief Set device index for current thread @@ -88,6 +90,24 @@ class Backend { */ virtual int SyncEvent(void* event) = 0; + /** + *\brief Get an address that is directly readable via the PCIe bus + * @param devicePtr device physical address + * @return an address that is directly readable via the PCIe bus + */ + virtual void* GetAccessibleAddr(void* devicePtr, size_t size) { + return devicePtr; + } + + virtual void* GetAccessibleAddr(const at::Tensor& tensor) { + return GetAccessibleAddr(tensor.data_ptr(), + tensor.numel() * tensor.element_size()); + } + + virtual void* GetDeviceAddrFromHostPtr(void* hostPtr, size_t size) { + return hostPtr; + } + /** * \brief Get the backend implementation * @return the backend implementation @@ -98,12 +118,17 @@ class Backend { RegisterImpl(name, backend); } + static void RegisterLazy(const std::string& name, + const std::function& ctor); + protected: Backend() = default; private: static std::mutex backends_mutex_; static std::unordered_map backends_; + static std::unordered_map> + backend_ctors_; static Backend* GetImpl() { static Backend* backend_impl = nullptr; @@ -113,9 +138,13 @@ class Backend { return backend_impl; } std::string backend_type = "GPU"; - backend_type = Environment::Get()->find("STEPMESH_BAKCEND", backend_type); - PS_CHECK_NE(backends_.find(backend_type), backends_.end()) - << "failed to get backend impl: " << backend_type; + backend_type = Environment::Get()->find("STEPMESH_BACKEND", backend_type); + if (backends_.find(backend_type) == backends_.end()) { + PS_CHECK_NE(backend_ctors_.find(backend_type), backend_ctors_.end()) + << "failed to get backend impl: " << backend_type; + backends_[backend_type] = backend_ctors_[backend_type](); + } + backend_impl = backends_[backend_type]; } return backend_impl; diff --git a/include/ps/internal/env.h b/include/ps/env.h similarity index 100% rename from include/ps/internal/env.h rename to include/ps/env.h diff --git a/include/ps/internal/cpu_backend.h b/include/ps/internal/cpu_backend.h index 293cec5..4fc1dc2 100644 --- a/include/ps/internal/cpu_backend.h +++ b/include/ps/internal/cpu_backend.h @@ -7,7 +7,7 @@ #include -#include "ps/internal/backend.h" +#include "ps/backend.h" namespace ps { diff --git a/include/ps/internal/gpu_backend.h b/include/ps/internal/gpu_backend.h index a93a21c..cfcc4aa 100644 --- a/include/ps/internal/gpu_backend.h +++ b/include/ps/internal/gpu_backend.h @@ -7,7 +7,7 @@ #include -#include "ps/internal/backend.h" +#include "ps/backend.h" namespace ps { diff --git a/include/ps/internal/message.h b/include/ps/internal/message.h index 5f0e36e..ae4b90c 100644 --- a/include/ps/internal/message.h +++ b/include/ps/internal/message.h @@ -17,7 +17,7 @@ #ifdef STEPMESH_USE_TORCH #endif // STEPMESH_USE_TORCH -#include "ps/internal/backend.h" +#include "ps/backend.h" #include "ps/internal/multi_qp.h" #include "ps/internal/trace.h" #include "ps/sarray.h" diff --git a/include/ps/internal/postoffice.h b/include/ps/internal/postoffice.h index 47ba7be..57d8cf6 100644 --- a/include/ps/internal/postoffice.h +++ b/include/ps/internal/postoffice.h @@ -11,8 +11,8 @@ #include #include +#include "ps/env.h" #include "ps/internal/customer.h" -#include "ps/internal/env.h" #include "ps/internal/van.h" #include "ps/range.h" diff --git a/include/ps/internal/threadsafe_queue.h b/include/ps/internal/threadsafe_queue.h index e6b6b19..c986436 100644 --- a/include/ps/internal/threadsafe_queue.h +++ b/include/ps/internal/threadsafe_queue.h @@ -12,7 +12,7 @@ #include "dmlc/logging.h" #include "ps/base.h" -#include "ps/internal/env.h" +#include "ps/env.h" #include "ps/internal/spsc_queue.h" namespace ps { diff --git a/include/ps/internal/utils.h b/include/ps/internal/utils.h index a2381b4..708eaa9 100644 --- a/include/ps/internal/utils.h +++ b/include/ps/internal/utils.h @@ -19,7 +19,7 @@ #include #include "dmlc/logging.h" -#include "ps/internal/env.h" +#include "ps/env.h" namespace ps { diff --git a/include/ps/ps.h b/include/ps/ps.h index 1618eae..5c364f6 100644 --- a/include/ps/ps.h +++ b/include/ps/ps.h @@ -18,7 +18,7 @@ #include "ps/kv_app.h" /** \brief tensor-based communication with a list of attention and ffn nodes. */ #include "ps/af_tensor_app.h" -#include "ps/internal/backend.h" +#include "ps/backend.h" #include "ps/internal/cpu_backend.h" #include "ps/internal/gpu_backend.h" diff --git a/plugins/klx_backend/CMakeLists.txt b/plugins/klx_backend/CMakeLists.txt new file mode 100644 index 0000000..3ec8842 --- /dev/null +++ b/plugins/klx_backend/CMakeLists.txt @@ -0,0 +1,62 @@ +# You can build either from the source tree or from the released Python package. + +cmake_minimum_required(VERSION 3.22 FATAL_ERROR) + +project(klx_backend LANGUAGES C CXX) + +execute_process(COMMAND ${Python_EXECUTABLE} + -c "import torch; print(int(torch.compiled_with_cxx11_abi()))" + OUTPUT_VARIABLE TORCH_CXX11_ABI OUTPUT_STRIP_TRAILING_WHITESPACE) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17 -fPIC -O3 -Wall -finline-functions -msse2 -D_GLIBCXX_USE_CXX11_ABI=${TORCH_CXX11_ABI} ") + +set(CMAKE_CXX_STANDARD 17) + +set(CMAKE_VERBOSE_MAKEFILE OFF) + +find_package(Python3 REQUIRED COMPONENTS Interpreter Development) +execute_process( + COMMAND ${Python3_EXECUTABLE} -c "import sysconfig;print(sysconfig.get_config_var('SOABI'))" + OUTPUT_VARIABLE PY_SOABI + OUTPUT_STRIP_TRAILING_WHITESPACE +) + +execute_process(COMMAND ${Python_EXECUTABLE} + -c "import torch; print(torch.utils.cmake_prefix_path)" + OUTPUT_VARIABLE PYTORCH_CMAKE_PREFIX_PATH OUTPUT_STRIP_TRAILING_WHITESPACE) +list(APPEND CMAKE_PREFIX_PATH "${PYTORCH_CMAKE_PREFIX_PATH}/Torch") +find_package(Torch REQUIRED CONFIG) + +add_library(klx_backend SHARED klx_backend.cc) +target_include_directories(klx_backend PUBLIC + ${TORCH_INCLUDE_DIRS} + ${CMAKE_CURRENT_SOURCE_DIR}/../../include +) + +set(MODULE_NAME fserver_lib) +set(PY_EXT_SUFFIX ".so") + +set(PYTHON_EXTENSION_NAME + "${MODULE_NAME}.${PY_SOABI}${PY_EXT_SUFFIX}" +) + +find_package(CUDAToolkit REQUIRED) +target_link_libraries(klx_backend + PRIVATE + CUDA::cudart + ${TORCH_LIBRARIES} + ${TORCH_PYTHON_LIBRARY} +) + +target_link_options(klx_backend PRIVATE + -Wl,-rpath,${CMAKE_CURRENT_SOURCE_DIR}/../../${PYTHON_EXTENSION_NAME} +) + +install(TARGETS klx_backend + LIBRARY DESTINATION lib +) + +#set_target_properties(klx_backend PROPERTIES +# BUILD_RPATH "${CMAKE_BINARY_DIR}/../" +# INSTALL_RPATH "\$ORIGIN" +# BUILD_WITH_INSTALL_RPATH TRUE +#) diff --git a/plugins/klx_backend/klx_backend.cc b/plugins/klx_backend/klx_backend.cc new file mode 100644 index 0000000..be93af9 --- /dev/null +++ b/plugins/klx_backend/klx_backend.cc @@ -0,0 +1,324 @@ +/** + * Copyright (C) by StepAI Contributors. 2025. + */ + +#include +#include +#include + +#include + +#include "dmlc/backend_registry.h" +#include "ps/backend.h" +#include "ps/internal/gpu_backend.h" + +#define KLX_RT_CALL(func, ...) \ + do { \ + auto klx_errno = func(__VA_ARGS__); \ + PS_CHECK_EQ(klx_errno, 0) \ + << #func << " failed err:" << cudaGetErrorString(klx_errno); \ + } while (0) + +#define USE_MMAP_ALLOC +#undef USE_MMAP_ALLOC + +namespace klx { + +using namespace ps; + +class KlxBackend : public Backend { + public: + KlxBackend(); + int SetDevice(int dev) override; + int GetDeviceId() override; + at::Device GetDevice() override; + void* Alloc(uint64_t size) override; + void Free(void* m) override; + void* CreateEvent() override; + int FreeEvent(void* event) override; + int RecordEvent(void* event, void* stream) override; + int SyncEvent(void* event) override; + + void* GetAccessibleAddr(void* devicePtr, size_t size) final; + + void* GetAccessibleAddr(const at::Tensor& tensor) final; + + void* GetDeviceAddrFromHostPtr(void* hostPtr, size_t size) final; + + private: + void* CreateCudaEvent(); + int FreeCudaEvent(void* event); + int RecordCudaEvent(void* event, void* stream); + int SyncCudaEvent(void* event); + + void* CreateMemEvent(); + int FreeMemEvent(void* event); + int RecordMemEvent(void* event, void* stream); + int SyncMemEvent(void* event); + + private: + inline void DoInitGpu() { + static thread_local int gpu_idx = -1; + if (gpu_idx == -1) { + PS_CHECK_GE(gpu_idx_, 0) + << "cannot set device " << gpu_idx_ << " for gpu backend"; + SetDevice(gpu_idx_); + gpu_idx = gpu_idx_; + } + } + + /** \brief for cpu backend, the device stands for numa id */ + int gpu_idx_ = -1; + int mem_sync_ = 1; + // host address to device address map + std::unordered_map ha_da_map_; +}; + +KlxBackend::KlxBackend() { + Environment::Get()->find("STEPMESH_MEM_SYNC", &mem_sync_, mem_sync_); + PS_LOG(INFO) << "create klx backend"; +} + +int KlxBackend::SetDevice(int dev) { + PS_CHECK_GE(dev, 0) << "cannot set dev=" << dev << " for gpu backend"; + PS_CHECK_LE(dev, 7) << "cannot set dev=" << dev << " for gpu backend"; + static thread_local int gpu_idx = -1; + + gpu_idx_ = dev; + if (gpu_idx == -1 || gpu_idx != gpu_idx_) { + gpu_idx = gpu_idx_; + KLX_RT_CALL(cudaSetDevice, gpu_idx_); + } + + return BACKEND_OK; +} + +int KlxBackend::GetDeviceId() { + static thread_local int gpu_idx = -1; + if (gpu_idx != -1) { + KLX_RT_CALL(cudaGetDevice, &gpu_idx); + } + return gpu_idx; +} + +at::Device KlxBackend::GetDevice() { + PS_CHECK_GE(gpu_idx_, 0) << "device index is not initialized for gpu backend"; + return {at::kCUDA, static_cast(gpu_idx_)}; +} + +void* KlxBackend::Alloc(uint64_t size) { + DoInitGpu(); + void* ptr = nullptr; + KLX_RT_CALL(cudaMalloc, &ptr, size); + auto hostPtr = GetAccessibleAddr(ptr, size); + return hostPtr; +} + +void KlxBackend::Free(void* m) { +#ifdef USE_MMAP_ALLOC + if (ha_da_map_.find(m) != ha_da_map_.end()) { + m = ha_da_map_[m]; + free(m); + } +#endif + PS_CHECK_NE(m, nullptr) << "backend cannot free null memory"; + PS_VLOG(3) << "free gpu memory " << m; + if (ha_da_map_.erase(m)) { + m = ha_da_map_[m]; + } + KLX_RT_CALL(cudaFree, m); +} + +void* KlxBackend::GetAccessibleAddr(void* devicePtr, size_t size) { +#ifdef USE_MMAP_ALLOC + void* buf = mmap(nullptr, size, PROT_READ | PROT_WRITE, + MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); + cudaMemcpy(buf, devicePtr, size, cudaMemcpyDeviceToHost); + ha_da_map_.emplace(buf, devicePtr); + return buf; +#endif + + struct cudaPointerAttributes attrs; + KLX_RT_CALL(cudaPointerGetAttributes, &attrs, devicePtr); + PS_LOG(INFO) << "GetAccessibleAddr devicePtr=" << devicePtr + << " hostPtr=" << attrs.hostPointer; + // size_t pagesz = sysconf(_SC_PAGESIZE); + // PS_CHECK_EQ(((uintptr_t)attrs.hostPointer % pagesz), 0) << "unaligned host + // ptr"; + + if (ha_da_map_.find(attrs.hostPointer) != ha_da_map_.end()) { + return reinterpret_cast(attrs.hostPointer) + + (reinterpret_cast(devicePtr) - + reinterpret_cast(ha_da_map_[attrs.hostPointer])); + } + ha_da_map_.emplace(attrs.hostPointer, devicePtr); + + return attrs.hostPointer; +} + +void* KlxBackend::GetAccessibleAddr(const at::Tensor& tensor) { + if (tensor.device().type() == at::kCUDA) { + return GetAccessibleAddr(tensor.data_ptr(), + tensor.numel() * tensor.element_size()); + } + PS_CHECK_EQ(tensor.device().type(), at::kCPU); + return tensor.data_ptr(); +} + +void* KlxBackend::GetDeviceAddrFromHostPtr(void* hostPtr, size_t size) { + PS_CHECK_NE(ha_da_map_.find(hostPtr), ha_da_map_.end()); +#ifdef USE_MMAP_ALLOC + KLX_RT_CALL(cudaMemcpy, ha_da_map_[hostPtr], hostPtr, size, + cudaMemcpyHostToDevice); +#endif + return ha_da_map_[hostPtr]; +} + +void* KlxBackend::CreateEvent() { + DoInitGpu(); + if (!mem_sync_) { + return CreateCudaEvent(); + } else { + return CreateMemEvent(); + } +} + +int KlxBackend::FreeEvent(void* event) { + DoInitGpu(); + PS_CHECK_NE(event, nullptr) << "backend cannot free null event"; + if (!mem_sync_) { + return FreeCudaEvent(event); + } else { + return FreeMemEvent(event); + } +} + +int KlxBackend::RecordEvent(void* event, void* stream) { + DoInitGpu(); + PS_CHECK_NE(event, nullptr) << "backend cannot record null event"; + if (!mem_sync_) { + return RecordCudaEvent(event, stream); + } else { + return RecordMemEvent(event, stream); + } +} + +int KlxBackend::SyncEvent(void* event) { + DoInitGpu(); + PS_CHECK_NE(event, nullptr) << "backend cannot sync null event"; + if (!mem_sync_) { + return SyncCudaEvent(event); + } else { + return SyncMemEvent(event); + } +} + +void* KlxBackend::CreateCudaEvent() { + cudaEvent_t* ev = nullptr; + cudaMallocHost(&ev, sizeof(cudaEvent_t)); + auto status = cudaEventCreateWithFlags(ev, cudaEventDisableTiming); + PS_CHECK_EQ(status, cudaSuccess) + << "cudaEventCreateWithFlags failed for gpu " << gpu_idx_; + return reinterpret_cast(ev); +} + +int KlxBackend::FreeCudaEvent(void* event) { + auto ev = reinterpret_cast(event); + cudaError_t err = cudaEventDestroy(*ev); + PS_CHECK_EQ(err, cudaSuccess) + << "cudaEventDestroy failed for event " << reinterpret_cast(event) + << " (" << cudaGetErrorString(err) << ")"; + cudaFreeHost(ev); + return BACKEND_OK; +} + +int KlxBackend::RecordCudaEvent(void* event, void* stream) { + cudaStream_t cuda_stream; + if (stream == nullptr) { + cuda_stream = at::cuda::getCurrentCUDAStream().stream(); + } else { + cuda_stream = reinterpret_cast(stream); + } + + auto ev = reinterpret_cast(event); + auto status = cudaEventRecord(*ev, cuda_stream); + if (status == cudaSuccess) { + return BACKEND_OK; + } else { + PS_LOG(WARNING) << "failed to record cuda event: " + << " (" << cudaGetErrorString(status) << ")"; + return BACKEND_FAILED; + } +} + +int KlxBackend::SyncCudaEvent(void* event) { + auto ev = reinterpret_cast(event); + cudaError_t status; + while (true) { + status = cudaEventQuery(*ev); + if (status == cudaErrorNotReady) { + sched_yield(); + continue; + } + break; + } + if (status != cudaSuccess) { + PS_LOG(WARNING) << "failed to sync cuda event: " + << " (" << cudaGetErrorString(status) << ")"; + return BACKEND_FAILED; + } + + return BACKEND_OK; +} + +struct KlxBackendMemEvent { + int* gpu_flag = nullptr; + int* cpu_flag = nullptr; +}; + +void* KlxBackend::CreateMemEvent() { + struct KlxBackendMemEvent* ev = nullptr; + AT_CUDA_CHECK(cudaMallocHost(&ev, sizeof(KlxBackendMemEvent))); + AT_CUDA_CHECK(cudaMalloc(&(ev->gpu_flag), sizeof(int))); + AT_CUDA_CHECK(cudaMemset(ev->gpu_flag, 0, sizeof(int))); + AT_CUDA_CHECK( + cudaMallocHost(reinterpret_cast(&(ev->cpu_flag)), sizeof(int))); + *ev->cpu_flag = 0; + return reinterpret_cast(ev); +} + +int KlxBackend::FreeMemEvent(void* event) { + auto ev = reinterpret_cast(event); + AT_CUDA_CHECK(cudaFree(ev->gpu_flag)); + AT_CUDA_CHECK(cudaFreeHost(reinterpret_cast(ev->cpu_flag))); + AT_CUDA_CHECK(cudaFreeHost(ev)); + return BACKEND_OK; +} + +int KlxBackend::RecordMemEvent(void* event, void* stream) { + auto ev = reinterpret_cast(event); + *(ev->cpu_flag) = 1; + cudaStream_t cuda_stream; + if (stream == nullptr) { + cuda_stream = at::cuda::getCurrentCUDAStream().stream(); + } else { + cuda_stream = reinterpret_cast(stream); + } + + AT_CUDA_CHECK(cudaMemcpyAsync(reinterpret_cast(ev->cpu_flag), + ev->gpu_flag, sizeof(int), + cudaMemcpyDeviceToHost, cuda_stream)); + return BACKEND_OK; +} + +int KlxBackend::SyncMemEvent(void* event) { + auto ev = reinterpret_cast(event); + while (*(ev->cpu_flag) == 1) { + _mm_pause(); + } + return BACKEND_OK; +} + +dmlc::backend_registry _("KLX"); + +} // namespace klx diff --git a/setup.py b/setup.py index c27d1e3..cdc4f9c 100644 --- a/setup.py +++ b/setup.py @@ -35,6 +35,24 @@ def _get_cuda_bare_metal_version(cuda_dir): return bare_metal_major, bare_metal_minor +def filter_cuda_arch_and_code(cuda_dir, arch_list, code_list): + assert len(arch_list) == len(code_list), "arch_list and code_list should have the same length" + raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "--list-gpu-arch"], + universal_newlines=True) + nvcc_arch_list = raw_output.strip().split('\n') + + raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "--list-gpu-code"], + universal_newlines=True) + nvcc_code_list = raw_output.strip().split('\n') + + i = 0 + while i != len(arch_list): + if arch_list[i] not in nvcc_arch_list or code_list[i] not in nvcc_code_list: + del arch_list[i] + del code_list[i] + else: + i += 1 + __SRC_PATH__ = 'fserver/csrc/' __PS_PATH__ = f'{Path.cwd()}' @@ -62,8 +80,14 @@ def _get_cuda_bare_metal_version(cuda_dir): if use_cuda: extra_link += ['-lcuda', '-lcudart'] extra_compile_args['cxx'] += ['-DDMLC_USE_CUDA',] - extra_compile_args['nvcc'] = ['-O3', '-gencode', 'arch=compute_90,code=sm_90', '-gencode', 'arch=compute_80,code=sm_80', '-gencode', 'arch=compute_89,code=sm_89','-gencode', 'arch=compute_90a,code=sm_90a', - '--use_fast_math', f'-D_GLIBCXX_USE_CXX11_ABI={str(int(torch_cxx11_abi))}'] + cc_flag + cuda_arch_list = ['compute_90', 'compute_80', 'compute_89', 'compute_90a'] + cuda_code_list = ['sm_90', 'sm_80', 'sm_89', 'sm_90a'] + filter_cuda_arch_and_code(cpp_extension.CUDA_HOME, cuda_arch_list, cuda_code_list) + gencode_list = [] + for a,c in zip(cuda_arch_list, cuda_code_list): + gencode_list.append('-gencode') + gencode_list.append(f'arch={a},code={c}') + extra_compile_args['nvcc'] = ['-O3'] + gencode_list + ['--use_fast_math', f'-D_GLIBCXX_USE_CXX11_ABI={str(int(torch_cxx11_abi))}'] + cc_flag bare_metal_major, bare_metal_minor = \ _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME) diff --git a/src/backend/backend.cc b/src/backend/backend.cc index 7d360f5..c15bb12 100644 --- a/src/backend/backend.cc +++ b/src/backend/backend.cc @@ -2,7 +2,7 @@ * Copyright (C) by StepAI Contributors. 2025. */ -#include "ps/internal/backend.h" +#include "ps/backend.h" #include #include @@ -11,5 +11,12 @@ namespace ps { std::mutex Backend::backends_mutex_; std::unordered_map Backend::backends_; +std::unordered_map> + Backend::backend_ctors_; + +void Backend::RegisterLazy(const std::string& name, + const std::function& ctor) { + Backend::backend_ctors_.emplace(name, ctor); +} } // namespace ps diff --git a/src/backend/cpu_backend.cc b/src/backend/cpu_backend.cc index 7265855..3f94c6b 100644 --- a/src/backend/cpu_backend.cc +++ b/src/backend/cpu_backend.cc @@ -4,7 +4,7 @@ #include "ps/internal/cpu_backend.h" -#include "ps/internal/backend.h" +#include "ps/backend.h" namespace ps { diff --git a/src/backend/gpu_backend.cc b/src/backend/gpu_backend.cc index 337452b..29110f8 100644 --- a/src/backend/gpu_backend.cc +++ b/src/backend/gpu_backend.cc @@ -8,7 +8,7 @@ #include #include -#include "ps/internal/backend.h" +#include "ps/backend.h" namespace ps { diff --git a/src/rdma_van.h b/src/rdma_van.h index 43bda5d..d424b46 100644 --- a/src/rdma_van.h +++ b/src/rdma_van.h @@ -19,6 +19,8 @@ #ifdef DMLC_USE_RDMA +#include + #include #include #include @@ -35,7 +37,7 @@ namespace ps { class RDMAVan : public Van { public: - explicit RDMAVan(Postoffice *postoffice) + explicit RDMAVan(Postoffice* postoffice) : Van(postoffice), postoffice_(postoffice) { PS_CHECK_EQ(ibv_fork_init(), 0) << strerror(errno); } @@ -43,7 +45,7 @@ class RDMAVan : public Van { virtual std::string GetType() const { return std::string("rdma"); } - Postoffice *postoffice_; + Postoffice* postoffice_; protected: void Start(int customer_id, bool standalone) override { @@ -63,7 +65,8 @@ class RDMAVan : public Van { } if (event_channel_ == nullptr) { event_channel_ = rdma_create_event_channel(); - PS_CHECK(event_channel_) << "Create RDMA event channel failed"; + PS_CHECK(event_channel_) + << "Create RDMA event channel failed:" << strerror(errno); cm_event_polling_thread_.reset( new std::thread(&RDMAVan::PollEvents, this)); @@ -127,7 +130,7 @@ class RDMAVan : public Van { PS_CHECK(!ibv_destroy_cq(cq_)) << "Failed to destroy CQ"; - for (auto &it : mem_mr_) ibv_dereg_mr(it.second); + for (auto& it : mem_mr_) ibv_dereg_mr(it.second); // TODO(non): ibv_dealloc_pd sometimes complains resource busy, need to fix // PS_CHECK(!ibv_dealloc_pd(pd_)) << "Failed to deallocate PD: " << @@ -138,7 +141,7 @@ class RDMAVan : public Van { rdma_destroy_event_channel(event_channel_); } - int Bind(Node &node, int max_retry) override { + int Bind(Node& node, int max_retry) override { PS_CHECK_EQ(my_node_.num_ports, 1) << "RDMA van does not support multiple ports"; PS_CHECK(rdma_create_id(event_channel_, &listener_, nullptr, RDMA_PS_TCP) == @@ -160,7 +163,7 @@ class RDMAVan : public Van { for (int i = 0; i < max_retry + 1; ++i) { addr.sin_port = htons(port); if (rdma_bind_addr(listener_, - reinterpret_cast(&addr)) == 0) { + reinterpret_cast(&addr)) == 0) { break; } if (i == max_retry) { @@ -174,7 +177,7 @@ class RDMAVan : public Van { return port; } - void Connect(const Node &node) override { + void Connect(const Node& node) override { PS_VLOG(1) << "Connecting to Node " << node.id << ", My_Node=" << my_node_.id; PS_CHECK_NE(node.id, node.kEmpty); @@ -195,20 +198,20 @@ class RDMAVan : public Van { endpoints_.erase(it); } - Endpoint *endpoint; + Endpoint* endpoint; endpoints_[node.id] = std::make_unique(); endpoint = endpoints_[node.id].get(); endpoints_mu_.unlock(); endpoint->SetNodeID(node.id); - struct addrinfo *remote_addr; + struct addrinfo* remote_addr; PS_CHECK_EQ( getaddrinfo(node.hostname.c_str(), std::to_string(node.port).c_str(), nullptr, &remote_addr), 0); - struct addrinfo *addr = nullptr; + struct addrinfo* addr = nullptr; auto val = Environment::Get()->find("DMLC_NODE_HOST"); if (val) { auto rc = getaddrinfo(val, "", NULL, &addr); @@ -274,7 +277,7 @@ class RDMAVan : public Van { } } - void RegisterRecvBuffer(Message &msg) override { + void RegisterRecvBuffer(Message& msg) override { RegisterMemory(msg); std::unique_lock lock(registered_recv_buffers_mu_); uint64_t key = DecodeKey(msg.data[0]); @@ -285,12 +288,12 @@ class RDMAVan : public Van { << ", size=" << msg.data[1].size(); } - void QueryRecvBuffer(uint64_t key, int node_id, void **buffer, size_t *size, - uint32_t *rkey) override { + void QueryRecvBuffer(uint64_t key, int node_id, void** buffer, size_t* size, + uint32_t* rkey) override { std::unique_lock lock(registered_recv_buffers_mu_); auto itr = registered_recv_buffers_.find(key); if (itr != registered_recv_buffers_.end()) { - for (auto &t : itr->second) { + for (auto& t : itr->second) { if (t.first == node_id) { *buffer = t.second.data(); *size = t.second.size(); @@ -306,10 +309,10 @@ class RDMAVan : public Van { *size = 0; } - int SendMsg(Message &msg) override { + int SendMsg(Message& msg) override { int remote_id = msg.meta.recver; PS_CHECK_NE(remote_id, Meta::kEmpty); - Endpoint *endpoint = nullptr; + Endpoint* endpoint = nullptr; { std::unique_lock lock(endpoints_mu_); auto itr = endpoints_.find(remote_id); @@ -333,14 +336,14 @@ class RDMAVan : public Van { // start rendezvous if no remote info if (!IsValidPushpull(msg)) { - MessageBuffer *msg_buf = PrepareNewMsgBuf(msg); + MessageBuffer* msg_buf = PrepareNewMsgBuf(msg); StoreMsgBuf(msg_buf, msg); trans->SendRendezvousBegin(msg, msg_buf); return total_len; } if (!HasRemoteInfo(msg, msg.meta.key, msg.meta.push, remote_id)) { - MessageBuffer *msg_buf = PrepareNewMsgBuf(msg); + MessageBuffer* msg_buf = PrepareNewMsgBuf(msg); StoreMsgBuf(msg_buf, msg); PrepareData(msg, msg_buf); trans->SendRendezvousBegin(msg, msg_buf); @@ -350,9 +353,9 @@ class RDMAVan : public Van { auto addr_tuple = GetRemoteAndLocalInfo(msg.meta.key, msg.meta.push, remote_id); #ifdef STEPMESH_USE_GDR - MessageBuffer *msg_buf = std::get<5>(addr_tuple); // local message buffer + MessageBuffer* msg_buf = std::get<5>(addr_tuple); // local message buffer #else - MessageBuffer *msg_buf = std::get<3>(addr_tuple); // local message buffer + MessageBuffer* msg_buf = std::get<3>(addr_tuple); // local message buffer #endif // print detail of msg_buf as one line @@ -405,13 +408,13 @@ class RDMAVan : public Van { return total_len; } - int RecvMsg(Message *msg) override { + int RecvMsg(Message* msg) override { msg->data.clear(); - std::tuple notification; + std::tuple notification; recv_buffers_.WaitAndPop(¬ification); int cmd = std::get(notification); - Endpoint *endpoint = std::get(notification); - BufferContext *buffer_ctx = std::get(notification); + Endpoint* endpoint = std::get(notification); + BufferContext* buffer_ctx = std::get(notification); auto trans = PS_CHECK_NOTNULL(endpoint->GetTransport()); msg->meta.recver = my_node_.id; msg->meta.sender = endpoint->node_id; @@ -419,17 +422,17 @@ class RDMAVan : public Van { // the second argument is actually deprecated, // we keep it as is in order to be compatible #ifdef STEPMESH_USE_GDR - bool is_server = static_cast(trans.get())->is_server_; - char *meta_buf = is_server ? buffer_ctx->meta_buffer : buffer_ctx->buffer; + bool is_server = static_cast(trans.get())->is_server_; + char* meta_buf = is_server ? buffer_ctx->meta_buffer : buffer_ctx->buffer; PS_CHECK(meta_buf); #else - char *meta_buf = buffer_ctx->buffer; + char* meta_buf = buffer_ctx->buffer; PS_CHECK(meta_buf); #endif PS_VLOG(3) << "3. 1 RecvMsg: " << msg->DebugString(); - RawMeta *raw = reinterpret_cast(meta_buf); + RawMeta* raw = reinterpret_cast(meta_buf); auto counters = raw->slave_qp_counter; if (raw->slave_qp_num > 0) { @@ -493,7 +496,7 @@ class RDMAVan : public Van { } private: - void PrintSendLog(Message &msg, MessageBuffer *msg_buf, + void PrintSendLog(Message& msg, MessageBuffer* msg_buf, RemoteTuple remote_tuple) { if (!enable_log_) return; std::lock_guard lock(log_mu_); @@ -508,7 +511,7 @@ class RDMAVan : public Van { << "\t tensor_len=" << msg_buf->mrs[0].second << "\t remote_idx=" << std::get<2>(remote_tuple) << "\t remote_addr=" - << reinterpret_cast(std::get<0>(remote_tuple)) + << reinterpret_cast(std::get<0>(remote_tuple)) << std::flush; } else if (msg.meta.push && !msg.meta.request) { // server, push response @@ -517,7 +520,7 @@ class RDMAVan : public Van { << "\t recver=" << msg.meta.recver << "\t remote_idx=" << std::get<2>(remote_tuple) << "\t remote_addr=" - << reinterpret_cast(std::get<0>(remote_tuple)) + << reinterpret_cast(std::get<0>(remote_tuple)) << std::flush; } else if (!msg.meta.push && msg.meta.request) { // worker, pull request @@ -526,7 +529,7 @@ class RDMAVan : public Van { << "\t recver=" << msg.meta.recver << "\t remote_idx=" << std::get<2>(remote_tuple) << "\t remote_addr=" - << reinterpret_cast(std::get<0>(remote_tuple)) + << reinterpret_cast(std::get<0>(remote_tuple)) << std::flush; } else if (!msg.meta.push && !msg.meta.request) { // server, pull response @@ -536,12 +539,12 @@ class RDMAVan : public Van { << "\t tensor_len=" << msg.meta.val_len << "\t idx=" << "none" << "\t remote_addr=" - << reinterpret_cast(std::get<0>(remote_tuple)) + << reinterpret_cast(std::get<0>(remote_tuple)) << std::flush; } } - void PrintRecvLog(Message *msg, BufferContext *buffer_ctx, int meta_len) { + void PrintRecvLog(Message* msg, BufferContext* buffer_ctx, int meta_len) { if (!enable_log_) return; std::lock_guard lock(log_mu_); @@ -572,7 +575,7 @@ class RDMAVan : public Van { } } - bool HasRemoteInfo(Message &msg, uint64_t key, bool is_push, int recver) { + bool HasRemoteInfo(Message& msg, uint64_t key, bool is_push, int recver) { std::lock_guard lk(addr_mu_); if (is_push && (push_addr_.find(key) != push_addr_.end()) && (push_addr_[key].find(recver) != push_addr_[key].end())) { @@ -586,33 +589,33 @@ class RDMAVan : public Van { return false; } - void StoreMsgBuf(MessageBuffer *msg_buf, Message &msg) { + void StoreMsgBuf(MessageBuffer* msg_buf, Message& msg) { std::lock_guard lk(addr_mu_); PS_CHECK_EQ(msgbuf_cache_.find(msg_buf), msgbuf_cache_.end()); msgbuf_cache_[msg_buf] = msg; } - Message *GetFirstMsg(MessageBuffer *msg_buf) { + Message* GetFirstMsg(MessageBuffer* msg_buf) { std::lock_guard lk(addr_mu_); PS_CHECK_NE(msgbuf_cache_.find(msg_buf), msgbuf_cache_.end()); return &msgbuf_cache_[msg_buf]; } - void ReleaseFirstMsg(MessageBuffer *msg_buf) { + void ReleaseFirstMsg(MessageBuffer* msg_buf) { std::lock_guard lk(addr_mu_); PS_CHECK_NE(msgbuf_cache_.find(msg_buf), msgbuf_cache_.end()); msgbuf_cache_.erase(msg_buf); } #ifdef STEPMESH_USE_GDR - void StoreRemoteAndLocalInfo(MessageBuffer *msg_buf, uint64_t meta_addr, + void StoreRemoteAndLocalInfo(MessageBuffer* msg_buf, uint64_t meta_addr, uint32_t meta_rkey, uint64_t data_addr, uint32_t data_rkey, uint32_t idx) { std::lock_guard lk(addr_mu_); PS_CHECK_NE(msgbuf_cache_.find(msg_buf), msgbuf_cache_.end()); - auto &msg = msgbuf_cache_[msg_buf]; + auto& msg = msgbuf_cache_[msg_buf]; auto key = msg.meta.key; auto is_push = msg.meta.push; @@ -628,13 +631,13 @@ class RDMAVan : public Van { } #endif - void StoreRemoteAndLocalInfo(MessageBuffer *msg_buf, uint64_t remote_addr, + void StoreRemoteAndLocalInfo(MessageBuffer* msg_buf, uint64_t remote_addr, uint32_t rkey, uint32_t idx) { std::lock_guard lk(addr_mu_); PS_CHECK_NE(msgbuf_cache_.find(msg_buf), msgbuf_cache_.end()); - auto &msg = msgbuf_cache_[msg_buf]; + auto& msg = msgbuf_cache_[msg_buf]; auto key = msg.meta.key; auto is_push = msg.meta.push; @@ -659,8 +662,8 @@ class RDMAVan : public Van { return (is_push ? push_addr_[key][recver] : pull_addr_[key][recver]); } - MessageBuffer *PrepareNewMsgBuf(Message &msg) { - MessageBuffer *msg_buf = new MessageBuffer(); + MessageBuffer* PrepareNewMsgBuf(Message& msg) { + MessageBuffer* msg_buf = new MessageBuffer(); auto meta_len = GetPackMetaLen(msg.meta); msg_buf->inline_len = meta_len; msg_buf->inline_buf = mem_allocator_->Alloc(meta_len); @@ -669,20 +672,22 @@ class RDMAVan : public Van { return msg_buf; } - void RegisterMemory(Message &msg) { + void RegisterMemory(Message& msg) { size_t sa_cnt = 0; - for (auto &sa : msg.data) { + for (auto& sa : msg.data) { if (sa.size() == 0) continue; std::lock_guard lock(map_mu_); if ((mem_mr_.find(sa.data()) == mem_mr_.end() || mem_mr_[sa.data()]->length < sa.size()) && (sa_cnt == 1)) { // only vals register memory - struct ibv_mr *temp_mr; + struct ibv_mr* temp_mr; temp_mr = ibv_reg_mr(mem_allocator_->GetPD(), sa.data(), sa.size(), IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE); if (temp_mr == nullptr) { LOG(WARNING) << "Failed to register the memory region: " - << strerror(errno) << ", sa.size()=" << sa.size(); + << strerror(errno) + << ", sa.data()=" << reinterpret_cast(sa.data()) + << ", sa.size()=" << sa.size(); PS_CHECK(0); } @@ -693,11 +698,11 @@ class RDMAVan : public Van { // register for tensor address of pull request if (IsValidPushpull(msg) && !msg.meta.push && msg.meta.request) { PS_CHECK_GT(msg.meta.val_len, 0) << msg.meta.val_len; - auto addr = reinterpret_cast(msg.meta.addr); + auto addr = reinterpret_cast(msg.meta.addr); std::lock_guard lock(map_mu_); if (mem_mr_.find(addr) == mem_mr_.end() || mem_mr_[addr]->length < msg.meta.val_len) { - struct ibv_mr *temp_mr; + struct ibv_mr* temp_mr; temp_mr = ibv_reg_mr(mem_allocator_->GetPD(), addr, msg.meta.val_len, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE); if (temp_mr == nullptr) { @@ -710,31 +715,31 @@ class RDMAVan : public Van { } } - void PrepareData(Message &msg, MessageBuffer *msg_buf) { + void PrepareData(Message& msg, MessageBuffer* msg_buf) { if (!(msg.meta.push && msg.meta.request)) return; // only push request - auto &sa = msg_buf->data[1]; + auto& sa = msg_buf->data[1]; if (sa.size() == 0) return; std::lock_guard lock(map_mu_); auto it = mem_mr_.find(sa.data()); PS_CHECK_NE(it, mem_mr_.end()); - MRPtr ptr(it->second, [](struct ibv_mr *mr) {}); + MRPtr ptr(it->second, [](struct ibv_mr* mr) {}); PS_CHECK(ptr.get()) << strerror(errno); msg_buf->mrs.push_back(std::make_pair(std::move(ptr), sa.size())); } - void AddMeta(Message &msg) { + void AddMeta(Message& msg) { if (msg.meta.request) { msg.meta.key = DecodeKey(msg.data[0]); } if (!msg.meta.push && msg.meta.request) { // pull request std::lock_guard lock(map_mu_); - auto val_addr = reinterpret_cast(msg.meta.addr); + auto val_addr = reinterpret_cast(msg.meta.addr); msg.meta.option = mem_mr_[val_addr]->rkey; } } - void InitContext(struct ibv_context *context) { + void InitContext(struct ibv_context* context) { context_ = context; PS_CHECK(context_) << "ibv_context* empty"; @@ -751,7 +756,7 @@ class RDMAVan : public Van { PS_CHECK(cq_) << "Failed to create completion queue"; } - void ReleaseWorkRequestContext(WRContext *context, Endpoint *endpoint, + void ReleaseWorkRequestContext(WRContext* context, Endpoint* endpoint, int qpIndex = 0) { switch (context->type) { case kRendezvousStartContext: @@ -788,16 +793,15 @@ class RDMAVan : public Van { << static_cast(wc[i].wr_id) << " " << wc[i].vendor_err << " " << wc[i].opcode << " " << (wc[i].opcode == IBV_WC_RECV ? "RECV" : "OTHER") - << " postoffice ptr: " << reinterpret_cast(postoffice_); + << " postoffice ptr: " << reinterpret_cast(postoffice_); // IBV_WC_RDMA_WRITE use msg_buf as the wr_id // so there won't be context and endpoint for this op if (wc[i].opcode == IBV_WC_RDMA_WRITE) { continue; } - WRContext *context = reinterpret_cast(wc[i].wr_id); - Endpoint *endpoint = - reinterpret_cast(context->private_data); + WRContext* context = reinterpret_cast(wc[i].wr_id); + Endpoint* endpoint = reinterpret_cast(context->private_data); // IBV_WC_RDMA_WRITE use msg_buf as the wr_id // so there won't be context and endpoint for this op @@ -817,7 +821,7 @@ class RDMAVan : public Van { uint32_t cmd = ((addr_idx & 0xFFFF) >> 16) & 0xFFFF; endpoint->master_id = endpoint->cm_ids[cmIdInde]; - BufferContext *buf_ctx = addr_pool_.GetAddress(addr_idx); + BufferContext* buf_ctx = addr_pool_.GetAddress(addr_idx); recv_buffers_.Push( std::make_tuple(endpoint, buf_ctx, GetNanosecond(), cmd)); } else { @@ -829,22 +833,22 @@ class RDMAVan : public Van { case IBV_WC_RECV: { PS_CHECK(wc[i].wc_flags & IBV_WC_WITH_IMM); uint32_t imm = wc[i].imm_data; - struct ibv_mr *mr = context->buffer; + struct ibv_mr* mr = context->buffer; if (imm == kRendezvousStart) { - RendezvousStart *req = - reinterpret_cast(mr->addr); + RendezvousStart* req = + reinterpret_cast(mr->addr); auto trans = PS_CHECK_NOTNULL(endpoint->GetTransport()); trans->SendRendezvousReply(req, addr_pool_); } else if (imm == kRendezvousReply) { - RendezvousReply *resp = - reinterpret_cast(mr->addr); + RendezvousReply* resp = + reinterpret_cast(mr->addr); uint64_t origin_addr = resp->origin_addr; uint32_t idx = resp->idx; - MessageBuffer *msg_buf = - reinterpret_cast(origin_addr); + MessageBuffer* msg_buf = + reinterpret_cast(origin_addr); // Before RDMA write, store the remote info so that // subsequent write does not need repeated rendezvous #ifdef STEPMESH_USE_GDR @@ -854,7 +858,7 @@ class RDMAVan : public Van { StoreRemoteAndLocalInfo(msg_buf, resp->addr, resp->rkey, idx); #endif - Message *msg = GetFirstMsg(msg_buf); + Message* msg = GetFirstMsg(msg_buf); auto addr_tuple = GetRemoteAndLocalInfo( msg->meta.key, msg->meta.push, msg->meta.recver); auto trans = PS_CHECK_NOTNULL(endpoint->GetTransport()); @@ -918,7 +922,7 @@ class RDMAVan : public Van { continue; } - struct rdma_cm_event *event; + struct rdma_cm_event* event; PS_CHECK_EQ(rdma_get_cm_event(event_channel_, &event), 0); // TODO(clan): Reorder the list according to the event frequency switch (event->event) { @@ -948,9 +952,9 @@ class RDMAVan : public Van { } } - void OnRejected(struct rdma_cm_event *event) { - struct rdma_cm_id *id = event->id; - Endpoint *endpoint = reinterpret_cast(id->context); + void OnRejected(struct rdma_cm_event* event) { + struct rdma_cm_id* id = event->id; + Endpoint* endpoint = reinterpret_cast(id->context); endpoints_mu_.lock(); auto it = endpoints_.find(endpoint->node_id); @@ -967,8 +971,8 @@ class RDMAVan : public Van { } // Server Side - void OnConnectRequest(struct rdma_cm_event *event) { - struct rdma_cm_id *id = event->id; + void OnConnectRequest(struct rdma_cm_event* event) { + struct rdma_cm_id* id = event->id; PS_CHECK_NOTNULL(id); PS_CHECK_LE(sizeof(RequestContext), event->param.conn.private_data_len) << "RequestContext size mismatch. Actual: " @@ -976,10 +980,10 @@ class RDMAVan : public Van { << ", Expected: " << sizeof(RequestContext); PS_CHECK_NOTNULL(event->param.conn.private_data); - const RequestContext *remote_ctx = reinterpret_cast( - event->param.conn.private_data); + const RequestContext* remote_ctx = + reinterpret_cast(event->param.conn.private_data); - Endpoint *endpoint = nullptr; + Endpoint* endpoint = nullptr; std::string rem_host = std::string(remote_ctx->hostname) + "," + std::to_string(remote_ctx->node) + "," + std::to_string(remote_ctx->port); @@ -1039,16 +1043,16 @@ class RDMAVan : public Van { } // Resolve a route after address is resolved - void OnAddrResolved(struct rdma_cm_event *event) { - struct rdma_cm_id *id = event->id; + void OnAddrResolved(struct rdma_cm_event* event) { + struct rdma_cm_id* id = event->id; PS_CHECK_EQ(rdma_resolve_route(id, kTimeoutms), 0) << "Resolve RDMA route failed"; } // Make a connection after route is resolved - void OnRouteResolved(struct rdma_cm_event *event) { - struct rdma_cm_id *id = event->id; - Endpoint *endpoint = reinterpret_cast(id->context); + void OnRouteResolved(struct rdma_cm_event* event) { + struct rdma_cm_id* id = event->id; + Endpoint* endpoint = reinterpret_cast(id->context); if (context_ == nullptr) { InitContext(id->verbs); @@ -1072,10 +1076,10 @@ class RDMAVan : public Van { if (endpoint->inComingCount == QP_NUM) endpoint->inComingCount = 0; } - void OnConnected(struct rdma_cm_event *event) { - struct rdma_cm_id *id = event->id; + void OnConnected(struct rdma_cm_event* event) { + struct rdma_cm_id* id = event->id; PS_CHECK(id) << "rdma_cm_id not found."; - Endpoint *endpoint = reinterpret_cast(id->context); + Endpoint* endpoint = reinterpret_cast(id->context); PS_CHECK(endpoint) << "Endpoint not found."; if (cq_polling_thread_ == nullptr) { cq_polling_thread_.reset(new std::thread(&RDMAVan::PollCQ, this)); @@ -1103,9 +1107,9 @@ class RDMAVan : public Van { } } - void OnDisconnected(struct rdma_cm_event *event) { - struct rdma_cm_id *id = event->id; - Endpoint *endpoint = reinterpret_cast(id->context); + void OnDisconnected(struct rdma_cm_event* event) { + struct rdma_cm_id* id = event->id; + Endpoint* endpoint = reinterpret_cast(id->context); { std::lock_guard lk(endpoint->connect_mu); // endpoint->status = Endpoint::IDLE; @@ -1121,26 +1125,26 @@ class RDMAVan : public Van { std::unique_ptr rdma_trans_; std::unique_ptr ipc_trans_; - struct rdma_cm_id *listener_ = nullptr; + struct rdma_cm_id* listener_ = nullptr; std::atomic should_stop_; std::mutex endpoints_mu_; std::unordered_map> endpoints_; std::unordered_map> incoming_; - struct rdma_event_channel *event_channel_ = nullptr; - struct ibv_context *context_ = nullptr; + struct rdma_event_channel* event_channel_ = nullptr; + struct ibv_context* context_ = nullptr; // ibverbs protection domain - struct ibv_pd *pd_ = nullptr; + struct ibv_pd* pd_ = nullptr; // Completion queue, to poll on work completions - struct ibv_cq *cq_ = nullptr; + struct ibv_cq* cq_ = nullptr; // cq thread std::unique_ptr cq_polling_thread_ = nullptr; // event thread std::unique_ptr cm_event_polling_thread_ = nullptr; // Recv buffer queue - ThreadsafeQueue> + ThreadsafeQueue> recv_buffers_; // local IPC related @@ -1152,10 +1156,10 @@ class RDMAVan : public Van { // , () std::unordered_map push_addr_; std::unordered_map pull_addr_; - std::unordered_map msgbuf_cache_; // msg_buf, msg + std::unordered_map msgbuf_cache_; // msg_buf, msg std::mutex map_mu_; - std::unordered_map + std::unordered_map mem_mr_; // (memory address, ibv_mr) // logging diff --git a/tests/fserver/test_fserver.py b/tests/fserver/test_fserver.py index ca3a4d7..dff4008 100644 --- a/tests/fserver/test_fserver.py +++ b/tests/fserver/test_fserver.py @@ -1,10 +1,15 @@ -import torch, os +import torch, os, sys import time + +old_flags = sys.getdlopenflags() +sys.setdlopenflags(sys.getdlopenflags() | 0x100) import fserver_lib as f +sys.setdlopenflags(old_flags) + is_worker = os.environ.get('DMLC_ROLE') == 'worker' is_server = os.environ.get('DMLC_ROLE') == 'server' -f.init() +f.init("cmake_build/plugins/klx_backend/libklx_backend.so") if is_worker: gpu = os.environ.get('STEPMESH_GPU')