diff --git a/docs/features/weight_update.md b/docs/features/weight_update.md index adb61704c80..6ee8de53869 100644 --- a/docs/features/weight_update.md +++ b/docs/features/weight_update.md @@ -50,7 +50,7 @@ In FastDeploy >= 2.6, the underlying control-signal communication path is optimi | `/v1/is_paused` | `GET` | none | Return `{"is_paused": bool}`. | | `/v1/sleep` | `POST` | `?tags=weight,kv_cache` | Offload selected GPU memory objects. Supported tags are `weight` and `kv_cache`. If omitted, both are used. | | `/v1/wakeup` | `POST` | `?tags=weight,kv_cache` | Reload previously offloaded weights and/or KV cache. On success, the engine resumes automatically. | -| `/v1/update_weights` | `POST` | JSON `{"version":"...", "rsync_config": {...}}` | Refresh weights in place through the worker control path. This API is intended for remote versioned updates, especially `load_strategy=rsync`. | +| `/v1/update_weights` | `POST` | JSON `{"version":"...", "verify_checksum": false}` | Refresh weights in place through the worker control path. This API is intended for remote versioned updates, especially `load_strategy=rsync`. | ### Compatibility Notes @@ -114,7 +114,7 @@ After `wakeup` succeeds, FastDeploy automatically calls `resume`. Current request fields: - `version`: optional string. Used to choose a target checkpoint version. -- `rsync_config`: optional dictionary. Must contain `etcd_server` when provided. +- `verify_checksum`: optional boolean. Defaults to `false`. Set to `true` to verify data integrity during weight synchronization. Important semantics: @@ -186,9 +186,7 @@ curl -X POST http://127.0.0.1:8000/v1/update_weights \ -H "Content-Type: application/json" \ -d '{ "version": "global_step_1200", - "rsync_config": { - "etcd_server": "127.0.0.1:2379" - } + "verify_checksum": false }' ``` @@ -261,9 +259,7 @@ curl -X POST http://127.0.0.1:8000/v1/update_weights \ -H "Content-Type: application/json" \ -d '{ "version": "global_step_1200", - "rsync_config": { - "etcd_server": "127.0.0.1:2379" - } + "verify_checksum": false }' # Resume the service after the update completes diff --git a/docs/zh/features/weight_update.md b/docs/zh/features/weight_update.md index 1b34a29b9bb..c95d89d6bca 100644 --- a/docs/zh/features/weight_update.md +++ b/docs/zh/features/weight_update.md @@ -50,7 +50,7 @@ python -m fastdeploy.entrypoints.openai.api_server \ | `/v1/is_paused` | `GET` | 无 | 返回 `{"is_paused": bool}`。 | | `/v1/sleep` | `POST` | `?tags=weight,kv_cache` | 卸载指定 GPU 内存对象。支持 `weight` 与 `kv_cache`;不传时默认同时处理两者。 | | `/v1/wakeup` | `POST` | `?tags=weight,kv_cache` | 重新加载之前被卸载的权重和/或 KV Cache。成功后会自动 `resume`。 | -| `/v1/update_weights` | `POST` | JSON `{"version":"...", "rsync_config": {...}}` | 通过 worker 控制链路原地刷新模型权重。该接口主要面向 `load_strategy=rsync` 的远端版本更新。 | +| `/v1/update_weights` | `POST` | JSON `{"version":"...", "verify_checksum": false}` | 通过 worker 控制链路原地刷新模型权重。该接口主要面向 `load_strategy=rsync` 的远端版本更新。 | ### 兼容性说明 @@ -113,7 +113,7 @@ python -m fastdeploy.entrypoints.openai.api_server \ 当前支持的请求字段: - `version`:可选字符串,用于指定目标 checkpoint 版本。 -- `rsync_config`:可选字典;如果传入,必须包含 `etcd_server`。 +- `verify_checksum`:可选布尔值;默认为 `false`。设置为 `true` 时,会在权重同步过程中校验数据完整性。 关键语义: @@ -185,9 +185,7 @@ curl -X POST http://127.0.0.1:8000/v1/update_weights \ -H "Content-Type: application/json" \ -d '{ "version": "global_step_1200", - "rsync_config": { - "etcd_server": "127.0.0.1:2379" - } + "verify_checksum": false }' ``` @@ -260,9 +258,7 @@ curl -X POST http://127.0.0.1:8000/v1/update_weights \ -H "Content-Type: application/json" \ -d '{ "version": "global_step_1200", - "rsync_config": { - "etcd_server": "127.0.0.1:2379" - } + "verify_checksum": false }' # 更新完成后恢复服务 diff --git a/fastdeploy/entrypoints/openai/api_server.py b/fastdeploy/entrypoints/openai/api_server.py index a48850e2958..06db509142e 100644 --- a/fastdeploy/entrypoints/openai/api_server.py +++ b/fastdeploy/entrypoints/openai/api_server.py @@ -461,19 +461,14 @@ async def update_weights(request: Request) -> Response: ) args["version"] = request_data["version"] - # Validate and extract rsync_config parameter - if "rsync_config" in request_data and request_data["rsync_config"] is not None: - if not isinstance(request_data["rsync_config"], dict): + # Validate and extract verify_checksum parameter + if "verify_checksum" in request_data and request_data["verify_checksum"] is not None: + if not isinstance(request_data["verify_checksum"], bool): return JSONResponse( status_code=400, - content={"error": "Invalid parameter type", "message": "rsync_config must be a dictionary"}, + content={"error": "Invalid parameter type", "message": "verify_checksum must be a boolean"}, ) - if "etcd_server" not in request_data["rsync_config"]: - return JSONResponse( - status_code=400, - content={"error": "Invalid parameter type", "message": "rsync_config must contain etcd_server"}, - ) - args["rsync_config"] = request_data["rsync_config"] + args["verify_checksum"] = request_data["verify_checksum"] control_request = ControlRequest(request_id, "update_weights", args) control_response = await app.state.engine_client.run_control_method(control_request) diff --git a/fastdeploy/rl/dynamic_weight_manager.py b/fastdeploy/rl/dynamic_weight_manager.py index ea26db28b1e..7cda0acee7b 100644 --- a/fastdeploy/rl/dynamic_weight_manager.py +++ b/fastdeploy/rl/dynamic_weight_manager.py @@ -16,7 +16,6 @@ import gc import glob -import io import os import re import time @@ -31,30 +30,6 @@ from fastdeploy.inter_communicator import KVCacheStatus, ModelWeightsStatus -def sync_weights_by_rdma(config, step, rank): - from checkpoint_transfer.core import RDMAWeightsDownloader - - downloader = RDMAWeightsDownloader(config) - downloader.initialize() - logger.info(f"Fetching weights for step:{step}, rank:{rank}...") - data = downloader.get_weights(step, rank) - if data is None: - logger.error("Failed to get weights!") - raise Exception("Failed to rsync weights through checkpoint_transfer") - logger.info(f"Successfully retrieved data. Type: {type(data)}") - if isinstance(data, np.ndarray): - data_bytes = data.tobytes() - elif isinstance(data, (bytes, bytearray)): - data_bytes = data - else: - data_bytes = bytes(data) - logger.info(f"Data size: {len(data_bytes)} bytes") - - buffer = io.BytesIO(data_bytes) - new_state_dict = paddle.load(buffer) - return new_state_dict - - class DynamicWeightManager: """Manages model weights loading, updating and shared state across processes.""" @@ -75,6 +50,7 @@ def __init__(self, fd_config: FDConfig, models, local_rank: int): else: self.model_list = models self._capture_model_state() + self.rdma_handle = None if self.load_config.load_strategy == "rsync": self.update_weights_by_rdma() else: @@ -91,10 +67,12 @@ def _capture_model_state(self): """Capture and store initial model parameters state.""" for model in self.model_list: for name, param in model.state_dict().items(): - logger.info(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}") + if hasattr(param, "_is_initialized") and not param._is_initialized(): + param.initialize() + logger.info(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}, place={param.place}") self.state_dict[name] = param - def update_weights_by_rdma(self, version: str = None, rsync_config: Dict[str, Any] = None): + def update_weights_by_rdma(self, version: str = None, verify_checksum: bool = False): def valid_parameters(old_state_dict, new_state_dict): is_valid = True for key in old_state_dict: @@ -110,17 +88,11 @@ def valid_parameters(old_state_dict, new_state_dict): ) elif old_state_dict[key].dtype != new_state_dict[key].dtype: is_valid = False - logger.error(f"Invalid parameter: {key} dtype mismatch") + logger.error( + f"Invalid parameter: {key} dtype mismatch, old:{old_state_dict[key].dtype}, new:{new_state_dict[key].dtype}" + ) return is_valid - if rsync_config is None: - rsync_config = self.fd_config.load_config.rsync_config - if rsync_config is None or len(rsync_config) == 0: - raise Exception( - "rsync config not set, please set it in 1) launch arguments '--rsync-config' " - "or 2) interface arguments 'rsync_config'" - ) - if version is None or version == "": version = self.read_model_version_from_file() if version is None or version == "": @@ -129,11 +101,23 @@ def valid_parameters(old_state_dict, new_state_dict): "or 2) interface arguments 'version'" ) - logger.info(f"START update_weights_by_rdma, version:{version}, rsync_config:{rsync_config}") - rank = self.local_rank + logger.info( + f"START rank:{self.local_rank}/{self.nranks} update_weights_by_rdma, " + f"version:{version}, verify_checksum:{verify_checksum}" + ) + + if self.rdma_handle is None: + from checkpoint_transfer import CheckpointTransfer + + config = self.fd_config.load_config.rsync_config + logger.info(f"CheckpointTransfer rsync config:{config}") + self.rdma_handle = CheckpointTransfer(**config, local_rank=self.local_rank, group_size=self.nranks) + self.rdma_handle.initialize() sync_start = time.perf_counter() - new_state_dict = sync_weights_by_rdma(rsync_config, version, rank) + new_state_dict = dict() + for key, param in self.rdma_handle.receive_stream(step_id=version, verify_checksum=verify_checksum): + new_state_dict[key] = param sync_cost = time.perf_counter() - sync_start logger.info(f"weights sync cost {sync_cost:.2f} seconds") @@ -148,18 +132,17 @@ def valid_parameters(old_state_dict, new_state_dict): param.set_value(new_state_dict[name]) update_cost = time.perf_counter() - update_start logger.info(f"params set value cost {update_cost:.2f} seconds") - total_cost = time.perf_counter() - sync_start logger.info( f"END update_weights_by_rdma, cost {total_cost:.2f} seconds" - f" version:{version}, rsync_config: {rsync_config}", + f" version:{version}, verify_checksum: {verify_checksum}, local_rank: {self.local_rank}", ) return { "sync_cost": sync_cost, "update_cost": update_cost, "total_cost": total_cost, "version": version, - "rank": rank, + "rank": self.local_rank, } def update_parameters(self, pid: int = 0, restart_process_group=False) -> None: diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index beef546db34..2db362488f3 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -20,7 +20,7 @@ import time from concurrent.futures import Future from threading import Thread -from typing import Any, Dict, List, Optional, cast +from typing import Dict, List, Optional, cast import numpy as np import paddle @@ -2692,8 +2692,8 @@ def update_parameters(self, pid): self.dynamic_weight_manager._log_memory("dynamic weight manager update all memory") - def update_weights(self, version: str = None, rsync_config: Dict[str, Any] = None): - return self.dynamic_weight_manager.update_weights_by_rdma(version, rsync_config) + def update_weights(self, version: str = None, verify_checksum: bool = False): + return self.dynamic_weight_manager.update_weights_by_rdma(version, verify_checksum) def sleep(self, tags): diff --git a/fastdeploy/worker/gpu_worker.py b/fastdeploy/worker/gpu_worker.py index 5025dc95c7e..aebf3f21111 100644 --- a/fastdeploy/worker/gpu_worker.py +++ b/fastdeploy/worker/gpu_worker.py @@ -16,7 +16,7 @@ import gc import time -from typing import Any, Dict, List, Optional +from typing import List, Optional import paddle import pynvml @@ -192,9 +192,9 @@ def initialize_cache(self, num_gpu_blocks: int) -> None: if self.fd_config.routing_replay_config.enable_routing_replay: self.model_runner.initialize_routing_replay_manager() - def update_weights(self, version: str = None, rsync_config: Dict[str, Any] = None): + def update_weights(self, version: str = None, verify_checksum: bool = False): """update weights in place""" - return self.model_runner.update_weights(version, rsync_config) + return self.model_runner.update_weights(version, verify_checksum) def sleep(self, **kwargs) -> None: """Offload memory from GPU""" diff --git a/fastdeploy/worker/metax_model_runner.py b/fastdeploy/worker/metax_model_runner.py index fa7daa41cfb..93f5cec6a57 100644 --- a/fastdeploy/worker/metax_model_runner.py +++ b/fastdeploy/worker/metax_model_runner.py @@ -20,7 +20,7 @@ import time from concurrent.futures import Future from threading import Thread -from typing import Any, Dict, List, Optional, cast +from typing import List, Optional, cast import numpy as np import paddle @@ -2550,8 +2550,8 @@ def update_parameters(self, pid): self.dynamic_weight_manager._log_memory("dynamic weight manager update all memory") - def update_weights(self, version: str = None, rsync_config: Dict[str, Any] = None): - return self.dynamic_weight_manager.update_weights_by_rdma(version, rsync_config) + def update_weights(self, version: str = None, verify_checksum: bool = False): + return self.dynamic_weight_manager.update_weights_by_rdma(version, verify_checksum) def padding_cudagraph_inputs(self) -> None: """ diff --git a/fastdeploy/worker/metax_worker.py b/fastdeploy/worker/metax_worker.py index 4e5ef2d9b97..083f5949788 100644 --- a/fastdeploy/worker/metax_worker.py +++ b/fastdeploy/worker/metax_worker.py @@ -17,7 +17,7 @@ import gc import os import time -from typing import Any, Dict, List, Optional +from typing import List, Optional import paddle from paddle import nn @@ -191,9 +191,9 @@ def initialize_cache(self, num_gpu_blocks: int) -> None: # accurate cache size self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks) - def update_weights(self, version: str = None, rsync_config: Dict[str, Any] = None): + def update_weights(self, version: str = None, verify_checksum: bool = False): """update weights in place""" - return self.model_runner.update_weights(version, rsync_config) + return self.model_runner.update_weights(version, verify_checksum) def execute_model( self, diff --git a/tests/entrypoints/openai/test_api_server.py b/tests/entrypoints/openai/test_api_server.py index 0cd57421701..db12a9a168b 100644 --- a/tests/entrypoints/openai/test_api_server.py +++ b/tests/entrypoints/openai/test_api_server.py @@ -604,13 +604,13 @@ async def test_update_weights_route_validation(): api_server.app.state.engine_client.run_control_method = AsyncMock(return_value=mock_control_response) valid_req = MagicMock() - valid_req.body = AsyncMock(return_value=b'{"version":"v2","rsync_config":{"etcd_server":"127.0.0.1"}}') - valid_req.json = AsyncMock(return_value={"version": "v2", "rsync_config": {"etcd_server": "127.0.0.1"}}) + valid_req.body = AsyncMock(return_value=b'{"version":"v2","verify_checksum":true}') + valid_req.json = AsyncMock(return_value={"version": "v2", "verify_checksum": True}) valid_resp = await api_server.update_weights(valid_req) assert valid_resp.status_code == 200 control_request = api_server.app.state.engine_client.run_control_method.await_args.args[0] assert control_request.method == "update_weights" - assert control_request.args == {"version": "v2", "rsync_config": {"etcd_server": "127.0.0.1"}} + assert control_request.args == {"version": "v2", "verify_checksum": True} invalid_version_req = MagicMock() invalid_version_req.body = AsyncMock(return_value=b'{"version":1}') @@ -618,11 +618,11 @@ async def test_update_weights_route_validation(): invalid_version_resp = await api_server.update_weights(invalid_version_req) assert invalid_version_resp.status_code == 400 - invalid_rsync_req = MagicMock() - invalid_rsync_req.body = AsyncMock(return_value=b'{"rsync_config":{"user":"u"}}') - invalid_rsync_req.json = AsyncMock(return_value={"rsync_config": {"user": "u"}}) - invalid_rsync_resp = await api_server.update_weights(invalid_rsync_req) - assert invalid_rsync_resp.status_code == 400 + invalid_checksum_req = MagicMock() + invalid_checksum_req.body = AsyncMock(return_value=b'{"verify_checksum":"true"}') + invalid_checksum_req.json = AsyncMock(return_value={"verify_checksum": "true"}) + invalid_checksum_resp = await api_server.update_weights(invalid_checksum_req) + assert invalid_checksum_resp.status_code == 400 @pytest.mark.asyncio