3 changes: 3 additions & 0 deletions vllm_omni/diffusion/data.py
@@ -244,6 +244,9 @@ class OmniDiffusionConfig:
     # Compilation
     enable_torch_compile: bool = False
 
+    # Enable sleep mode
+    enable_sleep_mode: bool = False
+
     disable_autocast: bool = False
 
     # VSA parameters
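The worker change below reads this flag as `self.od_config.enable_sleep_mode`. A minimal sketch of opting in, assuming `OmniDiffusionConfig` can be constructed directly with keyword overrides and that its remaining fields have usable defaults (everything outside this diff is an assumption):

    from vllm_omni.diffusion.data import OmniDiffusionConfig

    # Hypothetical construction; only the flag itself comes from this diff.
    od_config = OmniDiffusionConfig(enable_sleep_mode=True)
    # With the flag set, the worker loads pipeline weights inside a CuMemAllocator
    # memory pool tagged "weights", so they can later be offloaded or discarded.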
67 changes: 61 additions & 6 deletions vllm_omni/diffusion/worker/gpu_worker.py
@@ -3,6 +3,8 @@
 import multiprocessing as mp
 import os
 import time
+from collections.abc import Iterable
+from contextlib import AbstractContextManager, nullcontext
 
 import torch
 import zmq
@@ -42,7 +44,7 @@ def __init__(
         self.rank = rank
         self.od_config = od_config
         self.pipeline = None
-
+        self._sleep_saved_buffers: dict[str, torch.Tensor] = {}
         self.init_device_and_model()
 
     def init_device_and_model(self) -> None:
@@ -71,11 +73,12 @@ def init_device_and_model(self) -> None:
         load_config = LoadConfig()
         model_loader = DiffusersPipelineLoader(load_config)
         time_before_load = time.perf_counter()
-        with DeviceMemoryProfiler() as m:
-            self.pipeline = model_loader.load_model(
-                od_config=self.od_config,
-                load_device=f"cuda:{rank}",
-            )
+        with self._maybe_get_memory_pool_context(tag="weights"):
+            with DeviceMemoryProfiler() as m:
+                self.pipeline = model_loader.load_model(
+                    od_config=self.od_config,
+                    load_device=f"cuda:{rank}",
+                )
         time_after_load = time.perf_counter()
 
         logger.info(
@@ -107,6 +110,58 @@ def execute_model(self, reqs: list[OmniDiffusionRequest], od_config: OmniDiffusi
         output = self.pipeline.forward(req)
         return output
 
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        return self.pipeline.load_weights(weights)
+
+    def sleep(self, level: int = 1) -> bool:
+        from vllm.device_allocator.cumem import CuMemAllocator
+
+        free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
+
+        # Save the buffers before level 2 sleep
+        if level == 2:
+            model = self.pipeline
+            self._sleep_saved_buffers = {name: buffer.cpu().clone() for name, buffer in model.named_buffers()}
+
+        allocator = CuMemAllocator.get_instance()
+        allocator.sleep(offload_tags=("weights",) if level == 1 else tuple())
+        free_bytes_after_sleep, total = torch.cuda.mem_get_info()
+        freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
+        used_bytes = total - free_bytes_after_sleep
+        assert freed_bytes >= 0, "Memory usage increased after sleeping."
+        logger.info(
+            "Sleep mode freed %.2f GiB memory, %.2f GiB memory is still in use.",
+            freed_bytes / GiB_bytes,
+            used_bytes / GiB_bytes,
+        )
+        return True
+
+    def wake_up(self, tags: list[str] | None = None) -> bool:
+        from vllm.device_allocator.cumem import CuMemAllocator
+
+        allocator = CuMemAllocator.get_instance()
+        allocator.wake_up(tags)
+
+        # Restore the buffers after level 2 sleep
+        if len(self._sleep_saved_buffers):
+            model = self.pipeline
+            for name, buffer in model.named_buffers():
+                if name in self._sleep_saved_buffers:
+                    buffer.data.copy_(self._sleep_saved_buffers[name].data)
+            self._sleep_saved_buffers = {}
+        return True
+
+    def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager:
+        if self.od_config.enable_sleep_mode:
+            from vllm.device_allocator.cumem import CuMemAllocator
+
+            allocator = CuMemAllocator.get_instance()
+            if tag == "weights":
+                assert allocator.get_current_usage() == 0, "Sleep mode can only be used for one instance per process."
+            return allocator.use_memory_pool(tag=tag)
+        else:
+            return nullcontext()
+
     def shutdown(self) -> None:
         if torch.distributed.is_initialized():
             try:
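For orientation, a hypothetical sleep/wake cycle on a worker built with the flag above might look like the sketch below; the `worker` handle and call sites are assumptions, while the tag and level semantics come from the methods added in this diff:

    # Assumes `worker` is a gpu_worker instance constructed with enable_sleep_mode=True.
    worker.sleep(level=1)   # offload "weights"-tagged allocations to CPU
    worker.wake_up()        # remap the pool and restore what was offloaded

    worker.sleep(level=2)   # discard tagged memory instead; named buffers are first saved to CPU
    worker.wake_up()        # remap memory and copy the saved buffers back
    # After a level-2 sleep the weight values are not restored by wake_up();
    # reloading them is presumably what the new load_weights() method is for.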