Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/Modules)

find_package(Options REQUIRED)

if("$ENV{USE_CUDA}" STREQUAL "0")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DDMLC_USE_ZMQ -DDMLC_USE_RDMA -DSTEPMESH_USE_GDR")
Expand Down Expand Up @@ -67,3 +68,10 @@ target_link_libraries(af
${TORCH_PYTHON_LIBRARY})

add_subdirectory(tests)

# Optional out-of-tree plugin backend; gated by ENABLE_PLUGIN (default OFF,
# declared in cmake/FindOptions.cmake).
if(ENABLE_PLUGIN)
# The plugin links against `af`, so build ordering is enforced below.
add_subdirectory(plugins/klx_backend)
# NOTE(review): dummy_target has no commands and only aggregates af and
# klx_backend — presumably to force both to build together; confirm needed.
add_custom_target(dummy_target)
add_dependencies(klx_backend af)
add_dependencies(dummy_target af klx_backend)
endif()
7 changes: 7 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,10 @@ af:
@cd cmake_build; cmake .. -DCMAKE_CUDA_COMPILER=$(CMAKE_CUDA_COMPILER) -DPython_EXECUTABLE=$(shell which python3) -DCUDA_TOOLKIT_ROOT_DIR=$(CUDA_TOOLKIT_ROOT_DIR); make -j
@mkdir -p build

# Build with the optional plugin backend enabled (passes -DENABLE_PLUGIN=ON;
# see cmake/FindOptions.cmake). Mirrors the `af` target otherwise.
# Fix: the flag -DENABLE_PLUGIN=ON was previously passed twice on the cmake
# command line; the duplicate is removed.
plugin:
	@mkdir -p cmake_build
	@cd cmake_build; cmake .. -DENABLE_PLUGIN=ON -DCMAKE_CUDA_COMPILER=$(CMAKE_CUDA_COMPILER) -DPython_EXECUTABLE=$(shell which python3) -DCUDA_TOOLKIT_ROOT_DIR=$(CUDA_TOOLKIT_ROOT_DIR); make -j
	@mkdir -p build



1 change: 1 addition & 0 deletions cmake/FindOptions.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Gates the optional plugin build (plugins/klx_backend); OFF by default.
# Enable via -DENABLE_PLUGIN=ON (the Makefile's `plugin` target does this).
option(ENABLE_PLUGIN "Enable plugin feature" OFF)
25 changes: 19 additions & 6 deletions fserver/csrc/public.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
/* Copyright (c) 2025, StepFun Authors. All rights reserved. */

#include <dlfcn.h>

#include <execinfo.h>
#include <stdio.h>
#include <signal.h>
Expand Down Expand Up @@ -43,6 +46,7 @@ uint64_t handler_counter_ = 0;
std::unordered_map<uint64_t, AFTensorMeta> meta_map_;
std::vector<std::deque<ServerDataBatch>> q_;
std::atomic<uint64_t> q_signal_;
static void* gPluginHandle = nullptr;

void RequestHandler(const AFTensorMeta& req_meta, AFTensorServer* server) {
std::vector<torch::Tensor> tensors;
Expand Down Expand Up @@ -164,7 +168,12 @@ void barrier(bool include_server, bool include_worker, bool instrance_barrier=tr
}


void init() {
void init(const std::string& plugin) {
if (!plugin.empty()) {
gPluginHandle = dlopen(plugin.c_str(), RTLD_NOW);
PS_CHECK(gPluginHandle)
<< "can't load plugin:" << plugin << ": " << dlerror();
}

std::string role_str = ps::GetEnv("DMLC_ROLE", "server");
int offset = 0;
Expand Down Expand Up @@ -203,14 +212,17 @@ void stop() {
ps::Postoffice::GetWorker(gpu_)->DoBarrier(0,
ps::kWorkerGroup + ps::kServerGroup + ps::kScheduler, true);
} else if (role_ == Node::SERVER) {
ps::Postoffice::GetServer(gpu_)->DoBarrier(0,
ps::kWorkerGroup + ps::kServerGroup + ps::kScheduler, true);
ps::Postoffice::GetServer(gpu_)->DoBarrier(
0, ps::kWorkerGroup + ps::kServerGroup + ps::kScheduler, true);
} else {
ps::Postoffice::Get()->DoBarrier(0,
ps::kWorkerGroup + ps::kServerGroup + ps::kScheduler, true);
ps::Postoffice::Get()->DoBarrier(
0, ps::kWorkerGroup + ps::kServerGroup + ps::kScheduler, true);
}

ps::Finalize(0, role_, true);
if (gPluginHandle) {
dlclose(gPluginHandle);
}
}

std::vector<int> get_all_handlers(int handler) {
Expand All @@ -237,7 +249,8 @@ uint64_t get_nanosecond() {


void pybind_public(py::module &m){
m.def("init", &init, py::call_guard<py::gil_scoped_release>());
m.def("init", &init, py::arg("plugin") = "",
py::call_guard<py::gil_scoped_release>());
m.def("stop", &stop, py::call_guard<py::gil_scoped_release>());

m.def("register_recv_buffer",
Expand Down
46 changes: 46 additions & 0 deletions include/dmlc/backend_registry.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#pragma once

#include "base.h"
#include "ps/backend.h"

namespace dmlc {

// StepMesh Backend Registration
//
// StepMesh uses a **dynamic backend registration mechanism** to decouple
// specific backend implementations (e.g., RDMA, CPU, GPU, future XPUs) from
// the core transport layer.
//
// Motivation:
// * Avoid hard-coding backend logic into the core codebase.
// * Support extensibility: new backends can be plugged in without touching
// core StepMesh logic.
// * Enable backend discovery at runtime via a registry pattern.
//
// Contract Requirements:
// * Each backend must define a unique type identifier.
// * Backends must satisfy the BackendInterface contract (init, push/pull,
// etc).
// * Registration should occur during StepMesh startup or module load.
//
// Backend registry responsibilities:
// * Store mappings from backend ID to constructor/factory.
// * Provide factory APIs to create instances by ID.
// * Ensure correct initialization order (core first, backends next).
//
// Example use cases:
// * Enable support for new accelerator types (e.g., NPU, TPU-like devices).
// * Abstract transport layer differences while exposing uniform APIs.
//
// Note: This refactor **does not add a new mandatory backend** — existing
// backends continue to work without modification unless explicitly replaced.
// It simply provides the structural foundation for extensibility.
// This design choice is focused on modularity and future-proofing.
template <typename T>
struct STEPMESH_API backend_registry {
backend_registry(const std::string& name) {
ps::Backend::RegisterLazy(name, []() { return new T(); });
}
};

} // namespace dmlc
2 changes: 2 additions & 0 deletions include/dmlc/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,4 +190,6 @@ inline const char *BeginPtr(const std::string &str) {
#define alignof __alignof
#endif

#define STEPMESH_API __attribute__((__visibility__("default")))

#endif // DMLC_BASE_H_
32 changes: 21 additions & 11 deletions include/dmlc/logging.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,25 +71,35 @@ inline void InitLogging(const char *argv0) {
// DO NOTHING
}

// Strip the directory prefix from a path at compile time, so the PS_CHECK /
// PS_LOG macros print only the basename of __FILE__.
// Fix: the previous implementation scanned backward from the terminator and
// never stopped at the start of the string — undefined behavior for any path
// without a '/' (e.g. "file.h", ""). It also relied on strlen, which is not
// a constexpr function in standard C++. This forward scan is safe, truly
// constexpr, and returns the whole string when there is no separator.
constexpr const char *getFileName(const char *path) {
  const char *base = path;
  for (const char *p = path; *p != '\0'; ++p) {
    if (*p == '/') {
      base = p + 1;
    }
  }
  return base;
}

// Always-on checking
#define PS_CHECK(x) \
if (!(x)) \
dmlc::LogMessageFatal(__FILE__, __LINE__).stream() << "Check " \
"failed: " #x \
<< ' '
#define PS_CHECK(x) \
if (!(x)) \
dmlc::LogMessageFatal(dmlc::getFileName(__FILE__), __LINE__).stream() \
<< "Check " \
"failed: " #x \
<< ' '
#define PS_CHECK_LT(x, y) PS_CHECK((x) < (y))
#define PS_CHECK_GT(x, y) PS_CHECK((x) > (y))
#define PS_CHECK_LE(x, y) PS_CHECK((x) <= (y))
#define PS_CHECK_GE(x, y) PS_CHECK((x) >= (y))
#define PS_CHECK_EQ(x, y) PS_CHECK((x) == (y))
#define PS_CHECK_NE(x, y) PS_CHECK((x) != (y))
#define PS_CHECK_NOTNULL(x) \
((x) == NULL \
? dmlc::LogMessageFatal(__FILE__, __LINE__).stream() << "Check notnull: " #x << ' ', \
#define PS_CHECK_NOTNULL(x) \
((x) == NULL \
? dmlc::LogMessageFatal(dmlc::getFileName(__FILE__), __LINE__).stream() \
<< "Check notnull: " #x << ' ', \
(x) : (x)) // NOLINT(*)
// Debug-only checking.
#ifdef NDEBUG
/*
/*
#define DPS_CHECK(x) \
while (false) PS_CHECK(x)
#define DPS_CHECK_LT(x, y) \
Expand All @@ -114,12 +124,12 @@ inline void InitLogging(const char *argv0) {
#define DPS_CHECK_NE(x, y) PS_CHECK((x) != (y)) */
#endif // NDEBUG

#define PS_LOG_API dmlc::LogMessage(__FILE__, __LINE__)
#define PS_LOG_API dmlc::LogMessage(dmlc::getFileName(__FILE__), __LINE__)

#define PS_LOG_IF(severity, condition) \
!(condition) ? (void)0 : dmlc::LogMessageVoidify() & PS_LOG_API

#define LOG_FATAL dmlc::LogMessageFatal(__FILE__, __LINE__)
#define LOG_FATAL dmlc::LogMessageFatal(dmlc::getFileName(__FILE__), __LINE__)
#define PS_LOG_FATAL LOG_FATAL.stream()
#define LOG_QFATAL LOG_FATAL

Expand Down
31 changes: 21 additions & 10 deletions include/ps/af_tensor_app.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
#include <utility>
#include <vector>

#include "ps/backend.h"
#include "ps/base.h"
#include "ps/hash_table8.hpp"
#include "ps/internal/backend.h"
#include "ps/internal/utils.h"
#include "ps/kv_app.h"

Expand Down Expand Up @@ -236,15 +236,16 @@ class AFTensorWorker {
void ZPush_(int ts, const SArray<Key>& keys, const at::Tensor& tensor,
int cmd = 0) {
SArray<char> val;
val.reset(reinterpret_cast<char*>(tensor.data_ptr()),
void* mappedPtr = Backend::Get()->GetAccessibleAddr(tensor);
val.reset(reinterpret_cast<char*>(mappedPtr),
tensor.numel() * tensor.itemsize(), [tensor](void*) {});

Message msg;
msg.meta.request = true;
msg.meta.head = cmd;
msg.meta.push = true;
msg.meta.timestamp = ts;
msg.meta.addr = reinterpret_cast<uint64_t>(tensor.data_ptr());
msg.meta.addr = reinterpret_cast<uint64_t>(mappedPtr);
msg.meta.val_len = tensor.numel() * tensor.itemsize();
PS_VLOG(2) << "ZPush_ addr: 0x" << std::hex << msg.meta.addr << std::dec
<< " val_len: " << msg.meta.val_len;
Expand Down Expand Up @@ -284,13 +285,14 @@ class AFTensorWorker {

*key.data() = pull_tensors[i * pull_batch_size + index].key;

val.reset(reinterpret_cast<char*>(tensor.data_ptr()),
void* mappedPtr = Backend::Get()->GetAccessibleAddr(tensor);
val.reset(reinterpret_cast<char*>(mappedPtr),
tensor.numel() * tensor.itemsize(), [tensor](void*) {});

msg.meta.request = true;
msg.meta.head = cmd;
msg.meta.push = false;
msg.meta.addr = reinterpret_cast<uint64_t>(tensor.data_ptr());
msg.meta.addr = reinterpret_cast<uint64_t>(mappedPtr);
msg.meta.val_len = tensor.numel() * tensor.itemsize();
msg.meta.key = key[0];
msg.meta.is_tensor = 1;
Expand Down Expand Up @@ -483,7 +485,8 @@ class AFTensorServer {
res.keys = key;

SArray<char> tensor_val;
tensor_val.reset(reinterpret_cast<char*>(tensors[0].val.data_ptr()),
tensor_val.reset(reinterpret_cast<char*>(
Backend::Get()->GetAccessibleAddr(tensors[0].val)),
tensors[0].val.numel() * tensors[0].val.itemsize(),
[](void*) {});
res.vals = tensor_val;
Expand All @@ -506,7 +509,8 @@ class AFTensorServer {
rsp.kv_pair.keys = key;

rsp.kv_pair.vals.reset(
reinterpret_cast<char*>(res_kv.val.data_ptr()),
reinterpret_cast<char*>(
Backend::Get()->GetAccessibleAddr(res_kv.val)),
res_kv.val.numel() * res_kv.val.itemsize(), [](void*) {});

rsp.kv_meta = kv_meta;
Expand Down Expand Up @@ -558,7 +562,8 @@ class AFTensorServer {
PS_CHECK_GT(worker_ranks.size(), 0) << "ranks or keys should not be empty";
PS_CHECK_EQ(worker_ranks.size(), keys.size())
<< "rank list and key list have unequal size";
char* buffer_ptr = reinterpret_cast<char*>(tensor.data_ptr());
char* buffer_ptr =
reinterpret_cast<char*>(Backend::Get()->GetAccessibleAddr(tensor));
uint64_t data_size = tensor.numel() * tensor.element_size();
int chunk_size = data_size / worker_ranks.size();
PS_CHECK_EQ(data_size % worker_ranks.size(), 0)
Expand Down Expand Up @@ -591,8 +596,14 @@ class AFTensorServer {
.dtype(at::ScalarType(req_meta.dtype))
.memory_format(at::MemoryFormat::Contiguous)
.device(Backend::Get()->GetDevice());
key_tensor.val =
at::from_blob(req_data.vals.data(), req_meta.shape, options);
key_tensor.val = at::from_blob(
Backend::Get()->GetDeviceAddrFromHostPtr(
req_data.vals.data(),
std::accumulate(std::begin(req_meta.shape),
std::end(req_meta.shape),
c10::elementSize(at::ScalarType(req_meta.dtype)),
std::multiplies<uint64_t>())),
req_meta.shape, options);
}
key_tensor.key = req_data.keys[0];
return key_tensor;
Expand Down
39 changes: 34 additions & 5 deletions include/ps/internal/backend.h → include/ps/backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@
#endif
#include <torch/torch.h>

#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

#include "base.h"
#include "dmlc/logging.h"
#include "ps/internal/env.h"
#include "ps/env.h"

namespace ps {

Expand All @@ -26,7 +28,7 @@ enum { BACKEND_OK = 0, BACKEND_FAILED = -1 };
/**
* \brief Abstract Backend Class
*/
class Backend {
class STEPMESH_API Backend {
public:
/**
* \brief Set device index for current thread
Expand Down Expand Up @@ -88,6 +90,24 @@ class Backend {
*/
virtual int SyncEvent(void* event) = 0;

/**
 * \brief Translate a device pointer into an address that is directly
 *        readable over the PCIe bus. The base implementation is the
 *        identity mapping (returns devicePtr unchanged); backends that
 *        need a remapping override this.
 * @param devicePtr device memory pointer — NOTE(review): original doc said
 *        "device physical address"; presumably a device-side pointer,
 *        confirm with backend implementers
 * @param size length in bytes of the region (unused by the default)
 * @return an address that is directly readable via the PCIe bus
 */
virtual void* GetAccessibleAddr(void* devicePtr, size_t size) {
return devicePtr;
}

/**
 * \brief Convenience overload: resolve the PCIe-accessible address for a
 *        tensor's backing storage (data_ptr, numel * element_size bytes),
 *        by delegating to the pointer/size overload.
 * @param tensor tensor whose buffer should be mapped
 * @return an address that is directly readable via the PCIe bus
 */
virtual void* GetAccessibleAddr(const at::Tensor& tensor) {
return GetAccessibleAddr(tensor.data_ptr(),
tensor.numel() * tensor.element_size());
}

/**
 * \brief Map a host pointer to a device-usable address. The base
 *        implementation is the identity mapping (host pointer used as-is);
 *        backends requiring registration/translation override this.
 * @param hostPtr host memory pointer
 * @param size length in bytes (unused by the default implementation)
 * @return device-accessible address corresponding to hostPtr
 */
virtual void* GetDeviceAddrFromHostPtr(void* hostPtr, size_t size) {
return hostPtr;
}

/**
* \brief Get the backend implementation
* @return the backend implementation
Expand All @@ -98,12 +118,17 @@ class Backend {
RegisterImpl(name, backend);
}

static void RegisterLazy(const std::string& name,
const std::function<Backend*(void)>& ctor);

protected:
Backend() = default;

private:
static std::mutex backends_mutex_;
static std::unordered_map<std::string, Backend*> backends_;
static std::unordered_map<std::string, std::function<Backend*(void)>>
backend_ctors_;

static Backend* GetImpl() {
static Backend* backend_impl = nullptr;
Expand All @@ -113,9 +138,13 @@ class Backend {
return backend_impl;
}
std::string backend_type = "GPU";
backend_type = Environment::Get()->find("STEPMESH_BAKCEND", backend_type);
PS_CHECK_NE(backends_.find(backend_type), backends_.end())
<< "failed to get backend impl: " << backend_type;
backend_type = Environment::Get()->find("STEPMESH_BACKEND", backend_type);
if (backends_.find(backend_type) == backends_.end()) {
PS_CHECK_NE(backend_ctors_.find(backend_type), backend_ctors_.end())
<< "failed to get backend impl: " << backend_type;
backends_[backend_type] = backend_ctors_[backend_type]();
}

backend_impl = backends_[backend_type];
}
return backend_impl;
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion include/ps/internal/cpu_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

#include <memory>

#include "ps/internal/backend.h"
#include "ps/backend.h"

namespace ps {

Expand Down
Loading