diff --git a/src/core/base_plugin_manager.cpp b/src/core/base_plugin_manager.cpp new file mode 100644 index 000000000..0b7e0136e --- /dev/null +++ b/src/core/base_plugin_manager.cpp @@ -0,0 +1,275 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "base_plugin_manager.h" +#include "common/nixl_log.h" +#include +#include + +using lock_guard = const std::lock_guard; + +void +dlHandleDeleter::operator()(void *handle) const noexcept { + if (handle) { + // Call cleanup function if specified + if (!fini_func_name.empty()) { + using fini_func_t = void (*)(); + fini_func_t fini = reinterpret_cast(dlsym(handle, fini_func_name.c_str())); + if (fini) { + try { + fini(); + } + catch (const std::exception &e) { + NIXL_WARN << "Exception in plugin cleanup (" << fini_func_name + << "): " << e.what(); + } + catch (...) { + NIXL_WARN << "Unknown exception in plugin cleanup (" << fini_func_name << ")"; + } + } + } + + dlclose(handle); + } +} + +basePluginHandle::basePluginHandle(std::unique_ptr handle, + const void *plugin_interface) + : dlHandle_(std::move(handle)), + pluginInterface_(plugin_interface) { + assert(dlHandle_ && "DlHandleDeleter must not be null"); + assert(pluginInterface_ && "Plugin interface must not be null"); +} + +basePluginManager::basePluginManager(pluginConfig config) : config_(std::move(config)) {} + +plugin_load_result +basePluginManager::loadPluginFromPathInternal(const std::filesystem::path &plugin_path) { + plugin_load_result result; + + dlHandleDeleter deleter(config_.finiFuncName); + std::unique_ptr handle( + dlopen(plugin_path.c_str(), RTLD_NOW | RTLD_LOCAL), deleter); + + if (!handle) { + result = failure_result{std::string("Failed to dlopen: ") + dlerror()}; + NIXL_ERROR << "Failed to load plugin from " << plugin_path << ": " + << std::get(result).message; + return result; + } + + using init_func_t = void *(*)(); + init_func_t init = + reinterpret_cast(dlsym(handle.get(), config_.initFuncName.c_str())); + + if (!init) { + result = failure_result{std::string("Failed to find ") + config_.initFuncName + ": " + + dlerror()}; + NIXL_ERROR << "Failed to find " << config_.initFuncName << " in " << plugin_path << ": " + << std::get(result).message; + return result; + } + + void *plugin_interface = init(); + if (!plugin_interface) { + result = failure_result{"Plugin initialization returned nullptr"}; + NIXL_ERROR << "Plugin initialization failed for " << plugin_path; + return result; + } + + if (!checkApiVersion(plugin_interface)) { + result = failure_result{"API version mismatch"}; + NIXL_ERROR << "Plugin API version mismatch for " << plugin_path; + return result; + } + + result = success_result{std::move(handle), plugin_interface}; + + return result; +} + +std::string +basePluginManager::extractPluginNameFromFilename(const std::string &filename) const { + const auto &prefix = config_.filenamePrefix; + const auto &suffix = config_.filenameSuffix; + + const size_t min_length = prefix.size() + suffix.size() + 1; // +1 for at least 1 char name + if (filename.size() < min_length) { + return ""; + } + + if (filename.compare(0, prefix.size(), prefix) != 0) { + return ""; + } + + if (filename.compare(filename.size() - suffix.size(), suffix.size(), suffix) != 0) { + return ""; + } + + return filename.substr(prefix.size(), filename.size() - prefix.size() - suffix.size()); +} + +std::filesystem::path +basePluginManager::constructPluginPath(const std::filesystem::path &directory, + const std::string &plugin_name) const { + + std::string filename = config_.filenamePrefix + plugin_name + config_.filenameSuffix; + return directory / filename; +} + +void +basePluginManager::discoverPluginsFromDir(const std::filesystem::path &dirpath) { + std::error_code ec; + // Use recursive iterator to find plugins in subdirectories too (for build directories) + std::filesystem::recursive_directory_iterator dir_iter(dirpath, ec); + if (ec) { + NIXL_ERROR << "Error accessing plugin directory(" << dirpath << "): " << ec.message(); + return; + } + + for (const auto &entry : dir_iter) { + if (!entry.is_regular_file(ec)) { + continue; + } + + std::string filename = entry.path().filename().string(); + std::string plugin_name = extractPluginNameFromFilename(filename); + + if (!plugin_name.empty()) { + // Restore old behavior: actually load the plugin during discovery + auto plugin = loadPlugin(plugin_name); + if (plugin) { + NIXL_INFO << "Discovered and loaded plugin: " << plugin_name; + } + } + } +} + +void +basePluginManager::addPluginDirectory(const std::filesystem::path &directory) { + if (directory.empty()) { + NIXL_ERROR << "Cannot add empty plugin directory"; + return; + } + + if (!std::filesystem::exists(directory) || !std::filesystem::is_directory(directory)) { + NIXL_ERROR << "Plugin directory does not exist or is not readable: " << directory; + return; + } + + { + lock_guard lg(lock_); + + if (std::find(pluginDirs_.begin(), pluginDirs_.end(), directory) != pluginDirs_.end()) { + NIXL_WARN << "Plugin directory already registered: " << directory; + return; + } + + pluginDirs_.insert(pluginDirs_.begin(), directory); + } + + NIXL_INFO << "Added plugin directory: " << directory; + + discoverPluginsFromDir(directory); +} + +std::vector +basePluginManager::getPluginDirectories() const { + lock_guard lg(lock_); + return pluginDirs_; +} + +std::shared_ptr +basePluginManager::loadPluginInternal(const std::string &plugin_name) { + lock_guard lg(lock_); + + // Check if the plugin is already loaded + auto it = loadedPlugins_.find(plugin_name); + if (it != loadedPlugins_.end()) { + return it->second; + } + + // Try to load the plugin from all registered directories + for (const auto &dir : pluginDirs_) { + if (dir.empty()) { + continue; + } + + // Construct expected plugin path in this directory + auto plugin_path = constructPluginPath(dir, plugin_name); + + // Skip if plugin file doesn't exist in this directory + if (!std::filesystem::exists(plugin_path)) { + NIXL_DEBUG << "Plugin not found at: " << plugin_path; + continue; + } + + // Load the plugin + auto result = loadPluginFromPathInternal(plugin_path); + if (std::holds_alternative>(result)) { + auto &success = std::get>(result); + auto plugin_handle = createPluginHandle(std::move(success.handle), success.interface); + + if (plugin_handle) { + loadedPlugins_.emplace(plugin_name, plugin_handle); + NIXL_INFO << "Loaded plugin: " << plugin_name << " (version " + << plugin_handle->getVersion() << ")"; + onPluginLoaded(plugin_name, success.interface); + return plugin_handle; + } + } + } + + // Failed to load the plugin + NIXL_ERROR << "Failed to load plugin '" << plugin_name << "' from any directory"; + return nullptr; +} + +std::shared_ptr +basePluginManager::getPluginInternal(const std::string &plugin_name) const { + lock_guard lg(lock_); + + auto it = loadedPlugins_.find(plugin_name); + if (it != loadedPlugins_.end()) { + return it->second; + } + return nullptr; +} + +void +basePluginManager::unloadPlugin(const std::string &plugin_name) { + // Check if plugin can be unloaded (e.g., not a static plugin) + if (!canUnloadPlugin(plugin_name)) { + NIXL_DEBUG << "Plugin '" << plugin_name << "' cannot be unloaded"; + return; + } + + lock_guard lg(lock_); + loadedPlugins_.erase(plugin_name); +} + +std::vector +basePluginManager::getLoadedPluginNames() const { + lock_guard lg(lock_); + + std::vector names; + names.reserve(loadedPlugins_.size()); + for (const auto &pair : loadedPlugins_) { + names.push_back(pair.first); + } + return names; +} diff --git a/src/core/base_plugin_manager.h b/src/core/base_plugin_manager.h new file mode 100644 index 000000000..7d5c31c0d --- /dev/null +++ b/src/core/base_plugin_manager.h @@ -0,0 +1,281 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __BASE_PLUGIN_MANAGER_H +#define __BASE_PLUGIN_MANAGER_H + +#include "common/nixl_log.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +/** + * Custom deleter for dlopen handles + */ +struct dlHandleDeleter { + std::string fini_func_name; + + explicit dlHandleDeleter(std::string fini_name = "") : fini_func_name(std::move(fini_name)) {} + + void + operator()(void *handle) const noexcept; +}; + +struct pluginConfig { + std::string initFuncName; // e.g., "nixl_plugin_init" + std::string finiFuncName; // e.g., "nixl_plugin_fini" + std::string filenamePrefix; // e.g., "libplugin_" + std::string filenameSuffix; // e.g., ".so" + int expectedApiVersion; // Expected API version for validation +}; + +template struct success_result { + std::unique_ptr handle; + pluginInterface *interface; +}; + +struct failure_result { + std::string message; +}; + +template +using plugin_load_result = std::variant>; + +/** + * Base class for all plugin handles + * Provides common functionality for managing dynamic library handles + */ +class basePluginHandle { +public: + virtual ~basePluginHandle() = default; + + basePluginHandle(const basePluginHandle &) = delete; + basePluginHandle & + operator=(const basePluginHandle &) = delete; + basePluginHandle(basePluginHandle &&) = delete; + basePluginHandle & + operator=(basePluginHandle &&) = delete; + + virtual const char * + getName() const = 0; + virtual const char * + getVersion() const = 0; + + const void * + getPluginInterface() const { + return pluginInterface_; + } + +protected: + std::unique_ptr dlHandle_; + const void *pluginInterface_; + + basePluginHandle(std::unique_ptr handle, const void *pluginInterface); +}; + +/** + * Base plugin manager providing common functionality for all plugin types + * + * This class implements the common plugin loading, discovery, and management + * logic that is shared across different plugin types (backend, telemetry, etc.) + */ +class basePluginManager { +public: + virtual ~basePluginManager() = default; + + basePluginManager(const basePluginManager &) = delete; + basePluginManager & + operator=(const basePluginManager &) = delete; + + /** + * Discover and load all plugins from a directory + * Searches for files matching the configured pattern + */ + void + discoverPluginsFromDir(const std::filesystem::path &dirpath); + + /** + * Add a directory to search for plugins + * New directory is prioritized over existing ones + */ + void + addPluginDirectory(const std::filesystem::path &directory); + + /** + * Get all registered plugin directories + */ + std::vector + getPluginDirectories() const; + + /** + * Get plugin configuration + */ + const pluginConfig & + getConfig() const { + return config_; + } + + /** + * Load a specific plugin by name with automatic type casting + * Searches registered directories and loads the plugin if found + * + * @tparam handleType The specific plugin handle type (defaults to basePluginHandle) + * @param plugin_name Name of the plugin to load + * @return Typed plugin handle or nullptr if not found or cast fails + */ + template + std::shared_ptr + loadPlugin(const std::string &plugin_name) { + auto base_handle = loadPluginInternal(plugin_name); + auto typed_handle = std::dynamic_pointer_cast(base_handle); + if (!typed_handle && base_handle) { + NIXL_ERROR << "Failed to cast plugin '" << plugin_name << "' to requested handle type"; + } + + return typed_handle; + } + + /** + * Get an already loaded plugin handle with automatic type casting + * Returns nullptr if plugin is not loaded + * + * @tparam handleType The specific plugin handle type (defaults to basePluginHandle) + * @param plugin_name Name of the plugin + * @return Typed plugin handle or nullptr if not loaded or cast fails + */ + template + [[nodiscard]] std::shared_ptr + getPlugin(const std::string &plugin_name) const { + auto base_handle = getPluginInternal(plugin_name); + return std::dynamic_pointer_cast(base_handle); + } + + /** + * Load a plugin from a specific file path + * Returns typed plugin handle or nullptr if not loaded or cast fails + */ + template + std::shared_ptr + loadPluginFromPath(const std::filesystem::path &plugin_path, const std::string &plugin_name) { + auto result = loadPluginFromPathInternal(plugin_path); + if (std::holds_alternative>(result)) { + auto &success = std::get>(result); + auto base_handle = createPluginHandle(std::move(success.handle), success.interface); + loadedPlugins_.emplace(plugin_name, base_handle); + onPluginLoaded(plugin_name, success.interface); + + return std::dynamic_pointer_cast(base_handle); + } + return nullptr; + } + + /** + * Unload a plugin by name + * Does nothing if plugin is not loaded or cannot be unloaded + */ + void + unloadPlugin(const std::string &plugin_name); + + /** + * Get all loaded plugin names + */ + std::vector + getLoadedPluginNames() const; + +protected: + mutable std::mutex lock_; + std::map, std::less<>> loadedPlugins_; + explicit basePluginManager(pluginConfig config); + + /** + * Internal non-template implementation of loadPluginFromPath + */ + plugin_load_result + loadPluginFromPathInternal(const std::filesystem::path &plugin_path); + + /** + * Extract plugin name from filename based on configured prefix/suffix + * Returns empty string if filename doesn't match pattern + */ + std::string + extractPluginNameFromFilename(const std::string &filename) const; + + /** + * Construct full plugin path from directory and plugin name + */ + std::filesystem::path + constructPluginPath(const std::filesystem::path &directory, + const std::string &plugin_name) const; + + /** + * Check if API version matches expected version + * Derived classes override to provide specific version checking logic + */ + virtual bool + checkApiVersion(void *plugin_interface) const = 0; + + /** + * Factory method for creating typed plugin handles + * Derived classes override to create their specific handle type + */ + virtual std::shared_ptr + createPluginHandle(std::unique_ptr dl_handle, + void *plugin_interface) = 0; + + /** + * Check if a plugin can be unloaded + * Derived classes can override to prevent unloading (e.g., static plugins) + */ + virtual bool + canUnloadPlugin(const std::string &plugin_name) const { + return true; // Default: allow unload + } + + /** + * Called after successful plugin load - derived classes can do additional setup + */ + virtual void + onPluginLoaded(const std::string &plugin_name, void *plugin_interface) { + // Default: do nothing + (void)plugin_name; + (void)plugin_interface; + } + + /** + * Internal non-template implementation of loadPlugin + */ + std::shared_ptr + loadPluginInternal(const std::string &plugin_name); + + /** + * Internal non-template implementation of getPlugin + */ + std::shared_ptr + getPluginInternal(const std::string &plugin_name) const; + +private: + std::vector pluginDirs_; + pluginConfig config_; +}; + +#endif // __BASE_PLUGIN_MANAGER_H diff --git a/src/core/meson.build b/src/core/meson.build index 54e6344d2..ef17f6e53 100644 --- a/src/core/meson.build +++ b/src/core/meson.build @@ -55,6 +55,7 @@ endif nixl_lib = library('nixl', 'nixl_agent.cpp', + 'base_plugin_manager.cpp', 'nixl_plugin_manager.cpp', 'nixl_listener.cpp', 'telemetry.cpp', diff --git a/src/core/nixl_agent.cpp b/src/core/nixl_agent.cpp index 00d4e482b..4bf911081 100644 --- a/src/core/nixl_agent.cpp +++ b/src/core/nixl_agent.cpp @@ -168,7 +168,7 @@ nixlAgentData::~nixlAgentData() { for (auto & elm: backendEngines) { auto& plugin_manager = nixlPluginManager::getInstance(); - auto plugin_handle = plugin_manager.getPlugin(elm.second->getType()); + auto plugin_handle = plugin_manager.getPlugin(elm.second->getType()); if (plugin_handle) { // If we have a plugin handle, use it to destroy the engine @@ -246,7 +246,7 @@ nixlAgent::getPluginParams (const nixl_backend_t &type, // First try to get options from a loaded plugin auto& plugin_manager = nixlPluginManager::getInstance(); - auto plugin_handle = plugin_manager.getPlugin(type); + auto plugin_handle = plugin_manager.getPlugin(type); if (plugin_handle) { // If the plugin is already loaded, get options directly @@ -256,7 +256,7 @@ nixlAgent::getPluginParams (const nixl_backend_t &type, } // If plugin isn't loaded yet, try to load it temporarily - plugin_handle = plugin_manager.loadPlugin(type); + plugin_handle = plugin_manager.loadPlugin(type); if (plugin_handle) { params = plugin_handle->getBackendOptions(); mems = plugin_handle->getBackendMems(); @@ -332,7 +332,7 @@ nixlAgent::createBackend(const nixl_backend_t &type, // First, try to load the backend as a plugin auto& plugin_manager = nixlPluginManager::getInstance(); - auto plugin_handle = plugin_manager.loadPlugin(type); + auto plugin_handle = plugin_manager.loadPlugin(type); if (plugin_handle) { // Plugin found, use it to create the backend diff --git a/src/core/nixl_plugin_manager.cpp b/src/core/nixl_plugin_manager.cpp index fe11531a3..ab153afa2 100644 --- a/src/core/nixl_plugin_manager.cpp +++ b/src/core/nixl_plugin_manager.cpp @@ -17,7 +17,6 @@ #include "plugin_manager.h" #include "nixl.h" -#include "common/nixl_log.h" #include #include #include @@ -31,25 +30,10 @@ using lock_guard = const std::lock_guard; // pluginHandle implementation -nixlPluginHandle::nixlPluginHandle(void* handle, nixlBackendPlugin* plugin) - : handle_(handle), plugin_(plugin) { -} - -nixlPluginHandle::~nixlPluginHandle() { - if (handle_) { - // Call the plugin's cleanup function - typedef void (*fini_func_t)(); - fini_func_t fini = (fini_func_t) dlsym(handle_, "nixl_plugin_fini"); - if (fini) { - fini(); - } - - // Close the dynamic library - dlclose(handle_); - handle_ = nullptr; - plugin_ = nullptr; - } -} +nixlPluginHandle::nixlPluginHandle(std::unique_ptr handle, + nixlBackendPlugin *plugin) + : basePluginHandle(std::move(handle), plugin), + plugin_(plugin) {} nixlBackendEngine* nixlPluginHandle::createEngine(const nixlBackendInitParams* init_params) const { if (plugin_ && plugin_->create_engine) { @@ -115,64 +99,49 @@ std::map loadPluginList(const std::string& filename return plugins; } -std::shared_ptr nixlPluginManager::loadPluginFromPath(const std::string& plugin_path) { - // Open the plugin file - void* handle = dlopen(plugin_path.c_str(), RTLD_NOW | RTLD_LOCAL); - if (!handle) { - NIXL_INFO << "Failed to load plugin from " << plugin_path << ": " << dlerror(); - return nullptr; +bool +nixlPluginManager::checkApiVersion(void *plugin_interface) const { + if (!plugin_interface) { + return false; } - // Get the initialization function - typedef nixlBackendPlugin* (*init_func_t)(); - init_func_t init = (init_func_t) dlsym(handle, "nixl_plugin_init"); - if (!init) { - NIXL_ERROR << "Failed to find nixl_plugin_init in " << plugin_path << ": " << dlerror(); - dlclose(handle); - return nullptr; - } + nixlBackendPlugin *plugin = static_cast(plugin_interface); + return plugin->api_version == NIXL_PLUGIN_API_VERSION; +} - // Call the initialization function - nixlBackendPlugin* plugin = init(); - if (!plugin) { - NIXL_ERROR << "Plugin initialization failed for " << plugin_path; - dlclose(handle); - return nullptr; - } +std::shared_ptr +nixlPluginManager::createPluginHandle(std::unique_ptr dl_handle, + void *plugin_interface) { + auto *plugin = static_cast(plugin_interface); + return std::make_shared(std::move(dl_handle), plugin); +} - // Check API version - if (plugin->api_version != NIXL_PLUGIN_API_VERSION) { - NIXL_ERROR << "Plugin API version mismatch for " << plugin_path - << ": expected " << NIXL_PLUGIN_API_VERSION - << ", got " << plugin->api_version; - dlclose(handle); - return nullptr; +bool +nixlPluginManager::canUnloadPlugin(const std::string &plugin_name) const { + // Do not unload static plugins + for (const auto &splugin : staticPlugins_) { + if (splugin.name == plugin_name) { + return false; + } } - - // Create and store the plugin handle - auto plugin_handle = std::make_shared(handle, plugin); - - return plugin_handle; + return true; } -void nixlPluginManager::loadPluginsFromList(const std::string& filename) { +void +nixlPluginManager::loadPluginsFromList(const std::string &filename) { auto plugins = loadPluginList(filename); - lock_guard lg(lock); - for (const auto& pair : plugins) { const std::string& name = pair.first; - const std::string& path = pair.second; + const std::filesystem::path path = pair.second; - auto plugin_handle = loadPluginFromPath(path); - if (plugin_handle) { - loaded_plugins_[name] = plugin_handle; - } + // Load using base class - it will handle storage + basePluginManager::loadPluginFromPath(path, name); } } namespace { -static std::string +static std::filesystem::path getPluginDir() { // Environment variable takes precedence const char *plugin_dir = getenv("NIXL_PLUGIN_DIR"); @@ -186,12 +155,17 @@ getPluginDir() { NIXL_ERROR << "Failed to get plugin directory from dladdr"; return ""; } - return (std::filesystem::path(info.dli_fname).parent_path() / "plugins").string(); + return std::filesystem::path(info.dli_fname).parent_path() / "plugins"; } } // namespace // PluginManager implementation -nixlPluginManager::nixlPluginManager() { +nixlPluginManager::nixlPluginManager() + : basePluginManager(pluginConfig{.initFuncName = "nixl_plugin_init", + .finiFuncName = "nixl_plugin_fini", + .filenamePrefix = "libplugin_", + .filenameSuffix = ".so", + .expectedApiVersion = NIXL_PLUGIN_API_VERSION}) { // Force levels right before logging #ifdef NIXL_USE_PLUGIN_FILE NIXL_DEBUG << "Loading plugins from file: " << NIXL_USE_PLUGIN_FILE; @@ -201,11 +175,10 @@ nixlPluginManager::nixlPluginManager() { } #endif - std::string plugin_dir = getPluginDir(); + auto plugin_dir = getPluginDir(); if (!plugin_dir.empty()) { - NIXL_DEBUG << "Loading plugins from: " << plugin_dir; - plugin_dirs_.insert(plugin_dirs_.begin(), plugin_dir); - discoverPluginsFromDir(plugin_dir); + NIXL_DEBUG << "Loading backend plugins from: " << plugin_dir; + addPluginDirectory(plugin_dir); } registerBuiltinPlugins(); @@ -219,129 +192,6 @@ nixlPluginManager& nixlPluginManager::getInstance() { return instance; } -void nixlPluginManager::addPluginDirectory(const std::string& directory) { - if (directory.empty()) { - NIXL_ERROR << "Cannot add empty plugin directory"; - return; - } - - // Check if directory exists - if (!std::filesystem::exists(directory) || !std::filesystem::is_directory(directory)) { - NIXL_ERROR << "Plugin directory does not exist or is not readable: " << directory; - return; - } - - { - lock_guard lg(lock); - - // Check if directory is already in the list - for (const auto& dir : plugin_dirs_) { - if (dir == directory) { - NIXL_WARN << "Plugin directory already registered: " << directory; - return; - } - } - - // Prioritize the new directory by inserting it at the beginning - plugin_dirs_.insert(plugin_dirs_.begin(), directory); - } - - discoverPluginsFromDir(directory); -} - -std::shared_ptr nixlPluginManager::loadPlugin(const std::string& plugin_name) { - lock_guard lg(lock); - - // Check if the plugin is already loaded - // Static Plugins are preloaded so return handle - auto it = loaded_plugins_.find(plugin_name); - if (it != loaded_plugins_.end()) { - return it->second; - } - - // Try to load the plugin from all registered directories - for (const auto& dir : plugin_dirs_) { - // Handle path joining correctly with or without trailing slash - std::string plugin_path; - if (dir.empty()) { - continue; - } else if (dir.back() == '/') { - plugin_path = dir + "libplugin_" + plugin_name + ".so"; - } else { - plugin_path = dir + "/libplugin_" + plugin_name + ".so"; - } - - // Check if the plugin file exists before attempting to load i - if (!std::filesystem::exists(plugin_path)) { - NIXL_WARN << "Plugin file does not exist: " << plugin_path; - continue; - } - - auto plugin_handle = loadPluginFromPath(plugin_path); - if (plugin_handle) { - loaded_plugins_[plugin_name] = plugin_handle; - return plugin_handle; - } - } - - // Failed to load the plugin - NIXL_INFO << "Failed to load plugin '" << plugin_name << "' from any directory"; - return nullptr; -} - -void nixlPluginManager::discoverPluginsFromDir(const std::string& dirpath) { - std::filesystem::path dir_path(dirpath); - std::error_code ec; - std::filesystem::directory_iterator dir_iter(dir_path, ec); - if (ec) { - NIXL_ERROR << "Error accessing directory(" << dir_path << "): " - << ec.message(); - return; - } - - for (const auto& entry : dir_iter) { - std::string filename = entry.path().filename().string(); - - if(filename.size() < 11) continue; - // Check if this is a plugin file - if (filename.substr(0, 10) == "libplugin_" && - filename.substr(filename.size() - 3) == ".so") { - - // Extract plugin name - std::string plugin_name = filename.substr(10, filename.size() - 13); - - // Try to load the plugin - auto plugin = loadPlugin(plugin_name); - if (plugin) { - NIXL_INFO << "Discovered and loaded plugin: " << plugin_name; - } - } - } -} - -void nixlPluginManager::unloadPlugin(const nixl_backend_t& plugin_name) { - // Do no unload static plugins - for (const auto& splugin : getStaticPlugins()) { - if (splugin.name == plugin_name) { - return; - } - } - - lock_guard lg(lock); - - loaded_plugins_.erase(plugin_name); -} - -std::shared_ptr nixlPluginManager::getPlugin(const nixl_backend_t& plugin_name) { - lock_guard lg(lock); - - auto it = loaded_plugins_.find(plugin_name); - if (it != loaded_plugins_.end()) { - return it->second; - } - return nullptr; -} - nixl_b_params_t nixlPluginHandle::getBackendOptions() const { nixl_b_params_t params; if (plugin_ && plugin_->get_backend_options) { @@ -358,36 +208,31 @@ nixl_mem_list_t nixlPluginHandle::getBackendMems() const { return mems; // Return empty mems if not implemented } -std::vector nixlPluginManager::getLoadedPluginNames() { - lock_guard lg(lock); - - std::vector names; - for (const auto& pair : loaded_plugins_) { - names.push_back(pair.first); - } - return names; -} - -void nixlPluginManager::registerStaticPlugin(const char* name, nixlStaticPluginCreatorFunc creator) { - lock_guard lg(lock); - +void +nixlPluginManager::registerStaticPlugin(const char *name, nixlStaticPluginCreatorFunc creator) { nixlStaticPluginInfo info; info.name = name; info.createFunc = creator; - static_plugins_.push_back(info); + staticPlugins_.push_back(info); //Static Plugins are considered pre-loaded nixlBackendPlugin* plugin = info.createFunc(); NIXL_INFO << "Loading static plugin: " << name; if (plugin) { - // Register the loaded plugin - auto plugin_handle = std::make_shared(nullptr, plugin); - loaded_plugins_[name] = plugin_handle; + // Register the loaded plugin (nullptr handle for static plugins) + dlHandleDeleter deleter(""); // No cleanup for static plugins + std::unique_ptr handle(nullptr, deleter); + auto plugin_handle = createPluginHandle(std::move(handle), plugin); + + // Store in base class map + lock_guard lg(lock_); + loadedPlugins_.emplace(name, plugin_handle); } } -const std::vector& nixlPluginManager::getStaticPlugins() { - return static_plugins_; +const std::vector & +nixlPluginManager::getStaticPlugins() const noexcept { + return staticPlugins_; } #define NIXL_REGISTER_STATIC_PLUGIN(name) \ diff --git a/src/core/plugin_manager.h b/src/core/plugin_manager.h index 359488efa..750bc004f 100644 --- a/src/core/plugin_manager.h +++ b/src/core/plugin_manager.h @@ -24,6 +24,7 @@ #include #include #include "backend/backend_plugin.h" +#include "base_plugin_manager.h" // Forward declarations class nixlBackendEngine; @@ -35,19 +36,20 @@ struct nixlBackendInitParams; * operation, e.g., query operations and plugin instance creation. This allows using it in * multi-threading environments without lock protection. */ -class nixlPluginHandle { +class nixlPluginHandle : public basePluginHandle { private: - void* handle_; // Handle to the dynamically loaded library - nixlBackendPlugin* plugin_; // Plugin interface + nixlBackendPlugin *plugin_; public: - nixlPluginHandle(void* handle, nixlBackendPlugin* plugin); - ~nixlPluginHandle(); + nixlPluginHandle(std::unique_ptr handle, nixlBackendPlugin *plugin); + ~nixlPluginHandle() override = default; nixlBackendEngine* createEngine(const nixlBackendInitParams* init_params) const; void destroyEngine(nixlBackendEngine* engine) const; - const char* getName() const; - const char* getVersion() const; + const char * + getName() const override; + const char * + getVersion() const override; nixl_b_params_t getBackendOptions() const; nixl_mem_list_t getBackendMems() const; }; @@ -61,12 +63,9 @@ struct nixlStaticPluginInfo { nixlStaticPluginCreatorFunc createFunc; }; -class nixlPluginManager { +class nixlPluginManager : public basePluginManager { private: - std::map> loaded_plugins_; - std::vector plugin_dirs_; - std::vector static_plugins_; - std::mutex lock; + std::vector staticPlugins_; void registerBuiltinPlugins(); void registerStaticPlugin(const char* name, nixlStaticPluginCreatorFunc creator); @@ -74,6 +73,17 @@ class nixlPluginManager { // Private constructor for singleton pattern nixlPluginManager(); +protected: + bool + checkApiVersion(void *plugin_interface) const override; + + std::shared_ptr + createPluginHandle(std::unique_ptr dl_handle, + void *plugin_interface) override; + + bool + canUnloadPlugin(const std::string &plugin_name) const override; + public: // Singleton instance accessor static nixlPluginManager& getInstance(); @@ -82,33 +92,17 @@ class nixlPluginManager { nixlPluginManager(const nixlPluginManager&) = delete; nixlPluginManager& operator=(const nixlPluginManager&) = delete; - std::shared_ptr loadPluginFromPath(const std::string& plugin_path); - - void loadPluginsFromList(const std::string& filename); - - // Load a specific plugin - std::shared_ptr loadPlugin(const nixl_backend_t& plugin_name); - - // Search a directory for plugins - void discoverPluginsFromDir(const std::string& dirpath); - - // Unload a plugin - void unloadPlugin(const nixl_backend_t& plugin_name); - - // Get a plugin handle - std::shared_ptr getPlugin(const nixl_backend_t& plugin_name); - - // Get all loaded plugin names - std::vector getLoadedPluginNames(); + // Backend-specific plugin loading + void + loadPluginsFromList(const std::string &filename); // Get backend options - nixl_b_params_t getBackendOptions(const nixl_backend_t& type); - - // Add a plugin directory - void addPluginDirectory(const std::string& directory); + nixl_b_params_t + getBackendOptions(const nixl_backend_t &type); // Static Plugin Helpers - const std::vector& getStaticPlugins(); + const std::vector & + getStaticPlugins() const noexcept; }; #endif // __PLUGIN_MANAGER_H diff --git a/test/gtest/plugin_manager.cpp b/test/gtest/plugin_manager.cpp index 8dd274173..f0c72125e 100644 --- a/test/gtest/plugin_manager.cpp +++ b/test/gtest/plugin_manager.cpp @@ -50,7 +50,7 @@ class LoadSinglePluginTestFixture if (GetParam().type == PluginDesc::PluginType::Real) GTEST_SKIP(); #endif - plugin_handle_ = plugin_manager_.loadPlugin(GetParam().name); + plugin_handle_ = plugin_manager_.loadPlugin(GetParam().name); } void TearDown() override { @@ -77,7 +77,7 @@ class LoadMultiplePluginsTestFixture if (plugin.type == PluginDesc::PluginType::Real) continue; #endif - plugin_handles_.push_back(plugin_manager_.loadPlugin(plugin.name)); + plugin_handles_.push_back(plugin_manager_.loadPlugin(plugin.name)); } } diff --git a/test/nixl/test_plugin.cpp b/test/nixl/test_plugin.cpp index b75c7d4cc..dd67070f0 100644 --- a/test/nixl/test_plugin.cpp +++ b/test/nixl/test_plugin.cpp @@ -45,7 +45,7 @@ int verify_plugin(std::string name, nixlPluginManager& plugin_manager) std::cout << "\nLoading " << name << " plugin..." << std::endl; // Load the plugin - auto plugin_ = plugin_manager.loadPlugin(name); + auto plugin_ = plugin_manager.loadPlugin(name); if (!plugin_) { std::cerr << "Failed to load " << name << " plugin" << std::endl; return -1;