forked from PaddlePaddle/FastDeploy
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgpu_worker.py
More file actions
267 lines (225 loc) · 10.8 KB
/
gpu_worker.py
File metadata and controls
267 lines (225 loc) · 10.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import gc
import time
from typing import List, Optional
import paddle
import pynvml
from paddle import nn
from fastdeploy import envs
from fastdeploy.config import FDConfig
from fastdeploy.engine.request import Request
from fastdeploy.plugins.model_runner import load_model_runner_plugins
from fastdeploy.usage.usage_lib import report_usage_stats
from fastdeploy.utils import get_logger, set_random_seed
from fastdeploy.worker.model_runner_base import ModelRunnerBase
from fastdeploy.worker.output import ModelRunnerOutput
from fastdeploy.worker.worker_base import WorkerBase
# Module-level logger shared by everything in this file.
logger = get_logger("gpu_worker", "gpu_worker.log")

# Resolve the ModelRunner class once at import time: prefer whatever the plugin
# loader returns, and fall back to the built-in GPUModelRunner if the loader
# raises for any reason (the failure is logged at info level, not re-raised).
try:
    ModelRunner = load_model_runner_plugins()
except Exception as e:
    logger.info(f"Plugin ModelRunner not available ({e}), using default GPUModelRunner")
    # Imported lazily so the default runner module is only loaded when needed.
    from fastdeploy.worker.gpu_model_runner import GPUModelRunner as ModelRunner
class GpuWorker(WorkerBase):
    """Worker that drives a model runner on a single CUDA device.

    Most methods are thin delegations to ``self.model_runner``, which is
    constructed in :meth:`init_device`. The configuration attributes read here
    (``device_config``, ``parallel_config``, ``model_config``, ``cache_config``,
    ``fd_config``, ``local_rank``, ``rank``) are presumably set up by
    ``WorkerBase.__init__`` -- verify against the base class.
    """

    def __init__(
        self,
        fd_config: FDConfig,
        local_rank: int,
        rank: int,
    ):
        """Forward the configuration and rank information to ``WorkerBase``."""
        super().__init__(
            fd_config=fd_config,
            local_rank=local_rank,
            rank=rank,
        )

    def init_device(self):
        """
        Initialize device and construct model runner

        Raises:
            RuntimeError: If the configured device type is not ``"cuda"`` or
                Paddle was not compiled with CUDA.
        """
        self.max_chips_per_node = 8
        if self.device_config.device_type == "cuda" and paddle.device.is_compiled_with_cuda():
            # Device ids are configured as a comma-separated string.
            self.device_ids = self.parallel_config.device_ids.split(",")
            # The local rank is wrapped by max_chips_per_node to pick the device.
            self.device = f"gpu:{self.local_rank % self.max_chips_per_node}"
            paddle.device.set_device(self.device)
            paddle.set_default_dtype(self.model_config.dtype)

            gc.collect()
            paddle.device.cuda.empty_cache()
            if (
                not self.parallel_config.disable_custom_all_reduce
                and self.parallel_config.tensor_parallel_size > 1
                and paddle.is_compiled_with_cuda()
            ):
                from fastdeploy.distributed.communication import use_custom_allreduce

                use_custom_allreduce(self.fd_config.parallel_config.tp_group)
        else:
            # Report the same attribute the check above inspected.
            raise RuntimeError(f"Not support device type: {self.device_config.device_type}")

        # Only local rank 0 reports usage statistics.
        if self.local_rank == 0:
            report_usage_stats(self.fd_config)

        set_random_seed(self.fd_config.model_config.seed)
        # Construct model runner
        self.model_runner: ModelRunnerBase = ModelRunner(
            fd_config=self.fd_config,
            device=self.device,
            device_id=int(self.device_ids[self.local_rank % self.max_chips_per_node]),
            rank=self.rank,
            local_rank=self.local_rank,
        )

    def exist_prefill(self):
        """
        check whether prefill stage exist
        """
        return self.model_runner.exist_prefill()

    def determine_available_memory(self) -> int:
        """
        Profiles the peak memory usage of the model to determine how much
        memory can be used for KV cache without OOMs.

        The engine will first conduct a profiling of the existing memory usage.
        Then, it calculates the maximum possible number of GPU and CPU blocks
        that can be allocated with the remaining free memory.

        Tip:
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.

        Returns:
            The estimated amount of memory available for the KV cache on this
            device. NOTE(review): the computation multiplies by
            ``gpu_memory_utilization``, so the value may be a float despite
            the ``int`` annotation -- confirm what callers expect.
        """
        # 1. Record memory state before profile run
        start_time = time.perf_counter()
        gib = 1024**3
        local_rank = self.local_rank % self.max_chips_per_node
        paddle.device.cuda.reset_max_memory_reserved(local_rank)
        paddle.device.cuda.reset_max_memory_allocated(local_rank)
        paddle_reserved_mem_before_run = paddle.device.cuda.max_memory_reserved(local_rank)
        paddle_allocated_mem_before_run = paddle.device.cuda.max_memory_allocated(local_rank)  # not reserved

        pynvml.nvmlInit()
        # Guarantee NVML is shut down even if profiling or a query raises.
        try:
            handle = pynvml.nvmlDeviceGetHandleByIndex(int(self.device_ids[local_rank]))
            before_run_meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)

            # Adjacent string literals are concatenated into ONE message;
            # separating them with commas would log a tuple repr instead.
            logger.info(
                "Before running the profile, the memory usage info is as follows:"
                f"\nDevice Total memory: {before_run_meminfo.total / gib}"
                f"\nDevice used memory: {before_run_meminfo.used / gib}"
                f"\nDevice free memory: {before_run_meminfo.free / gib}"
                f"\nPaddle reserved memory: {paddle_reserved_mem_before_run / gib}"
                f"\nPaddle allocated memory: {paddle_allocated_mem_before_run / gib}"
            )

            # 2. Profile run
            self.model_runner.profile_run()
            # Re-seed after profiling (presumably because the profile run
            # consumes random state -- confirm).
            set_random_seed(self.fd_config.model_config.seed)

            # 3. Statistical memory information
            paddle_reserved_mem_after_run = paddle.device.cuda.max_memory_reserved(local_rank)
            paddle_allocated_mem_after_run = paddle.device.cuda.max_memory_allocated(local_rank)

            model_block_memory_used = self.cal_theortical_kvcache()
            paddle_peak_increase = paddle_allocated_mem_after_run - paddle_allocated_mem_before_run

            paddle.device.cuda.empty_cache()
            after_run_meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
        finally:
            pynvml.nvmlShutdown()

        available_kv_cache_memory = (
            after_run_meminfo.total * self.cache_config.gpu_memory_utilization
            - after_run_meminfo.used
            - paddle_peak_increase
        )
        # Add back the memory of total_block_num blocks (presumably the blocks
        # allocated for the profile run are counted in "used" -- confirm).
        available_kv_cache_memory += model_block_memory_used * self.cache_config.total_block_num

        end_time = time.perf_counter()
        logger.info(
            "After running the profile, the memory usage info is as follows:"
            f"\nDevice Total memory: {after_run_meminfo.total / gib}"
            f"\nDevice used memory: {after_run_meminfo.used / gib}"
            f"\nDevice free memory: {after_run_meminfo.free / gib}"
            f"\nPaddle reserved memory: {paddle_reserved_mem_after_run / gib}"
            f"\nPaddle allocated memory: {paddle_allocated_mem_after_run / gib}"
            f"\nAvailable KV Cache memory: {available_kv_cache_memory / gib}"
            f"\nProfile time: {end_time - start_time}"
        )

        return available_kv_cache_memory  # return to calculate the block num in this device

    def load_model(self) -> None:
        """Load model"""
        self.model_runner.load_model()

    def get_model(self) -> nn.Layer:
        """Get current model"""
        return self.model_runner.get_model()

    def initialize_cache(self, num_gpu_blocks: int) -> None:
        """Initialize the KV Cache with accurate num_gpu_blocks"""
        # accurate cache size
        self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)

        # Initialize routing replay manager
        if self.fd_config.routing_replay_config.enable_routing_replay:
            self.model_runner.initialize_routing_replay_manager()

    def update_weights(self, version: Optional[str] = None, verify_checksum: bool = False):
        """update weights in place

        Both arguments are passed through unchanged to the model runner, and
        its return value is returned to the caller.
        """
        return self.model_runner.update_weights(version, verify_checksum)

    def sleep(self, **kwargs) -> None:
        """Offload memory from GPU"""
        return self.model_runner.sleep(**kwargs)

    def wakeup(self, **kwargs) -> None:
        """Reload memory into GPU"""
        return self.model_runner.wakeup(**kwargs)

    def execute_model(
        self,
        model_forward_batch: Optional[List[Request]] = None,
        num_running_request: Optional[int] = None,
    ) -> Optional[ModelRunnerOutput]:
        """Run the model runner on the given batch and return its output.

        Both arguments are forwarded positionally to
        ``self.model_runner.execute_model``.
        """
        output = self.model_runner.execute_model(model_forward_batch, num_running_request)
        return output

    def preprocess_new_task(self, req_dicts: List[Request], num_running_requests: int) -> None:
        """Process new requests and then start the decode loop
        TODO(gongshaotian):The scheduler should schedule the handling of prefill,
        and workers and modelrunners should not perceive it.
        """
        # The insertion entry point depends on the ENABLE_V1_KVCACHE_SCHEDULER env switch.
        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
            self.model_runner.insert_tasks_v1(req_dicts=req_dicts, num_running_requests=num_running_requests)
        else:
            self.model_runner.insert_prefill_inputs(req_dicts=req_dicts, num_running_requests=num_running_requests)

    def graph_optimize_and_warm_up_model(self) -> None:
        """
        Perform the warm-up and the graph optimization.

        Execution modes:
        | Mode                              | Prefill + Mixed          | Decode                   |
        |-----------------------------------|--------------------------|--------------------------|
        | Dynamic (graph_opt_level=0)       | Dynamic                  | Dynamic + CUDAGraph      |
        | Static Full Graph (full=True)     | Dynamic                  | Static + CUDAGraph       |
        | Static Split Graph (full=False)   | Static + CUDAGraph       | Dynamic + CUDAGraph      |
        """
        graph_opt_config = self.fd_config.graph_opt_config

        if graph_opt_config.graph_opt_level >= 1 and not self.model_runner.use_cudagraph:
            self.model_runner.sot_warmup()

        if graph_opt_config.graph_opt_level >= 1:
            self.model_runner.vision_encoder_compile()

        # Static split graph mode: capture CUDAGraph for prefill/mixed phase
        if graph_opt_config.graph_opt_level >= 1 and not graph_opt_config.full_cuda_graph:
            self.model_runner.capture_model_prefill_and_mixed()

        # Capture CUDAGraph for decode phase (all modes)
        self.model_runner.capture_model()

        # Deterministic mode: reset RNG and share_inputs after warmup.
        # Warmup _dummy_run() calls consume CUDA RNG state and leave stale
        # data (infer_seed, stop_flags, seq_lens, etc.) in share_inputs.
        # Without this reset, the first real request may see different state
        # than subsequent requests, causing occasional first-run divergence.
        if envs.FD_DETERMINISTIC_MODE:
            set_random_seed(self.fd_config.model_config.seed)
            self.model_runner.share_inputs.reset_share_inputs()

    def check_health(self) -> bool:
        """Report worker health; this implementation always returns True."""
        return True

    def cal_theortical_kvcache(self) -> int:
        """Calculate the block memory required"""
        return self.model_runner.cal_theortical_kvcache()