tests/diffusion/test_gpu_worker.py (261 additions, 0 deletions)
@@ -0,0 +1,261 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""
Unit tests for the GPUWorker class.

This module tests the GPUWorker implementation:
- load_weights: Loading model weights into the pipeline
- sleep: Putting the worker into sleep mode (levels 1 and 2)
- wake_up: Waking the worker from sleep mode
"""

from unittest.mock import Mock, patch

import pytest
import torch

from vllm_omni.diffusion.worker.gpu_worker import GPUWorker


@pytest.fixture
def mock_od_config():
"""Create a mock OmniDiffusionConfig."""
config = Mock()
config.num_gpus = 1
config.master_port = 12345
config.enable_sleep_mode = False
config.cache_backend = None
config.cache_config = None
config.model = "test-model"
return config


@pytest.fixture
def mock_gpu_worker(mock_od_config):
"""Create a GPUWorker with mocked initialization."""
with patch.object(GPUWorker, "init_device_and_model"):
worker = GPUWorker(local_rank=0, rank=0, od_config=mock_od_config)
# Mock the pipeline
worker.pipeline = Mock()
worker.cache_backend = None
return worker


class TestGPUWorkerLoadWeights:
"""Test GPUWorker.load_weights method."""

def test_load_weights_calls_pipeline(self, mock_gpu_worker):
"""Test that load_weights delegates to pipeline.load_weights."""
# Setup mock weights
mock_weights = [
("layer1.weight", torch.randn(10, 10)),
("layer2.weight", torch.randn(20, 20)),
]
expected_loaded = {"layer1.weight", "layer2.weight"}

# Configure pipeline mock
mock_gpu_worker.pipeline.load_weights = Mock(return_value=expected_loaded)

# Call load_weights
result = mock_gpu_worker.load_weights(mock_weights)

# Verify pipeline.load_weights was called with the weights
mock_gpu_worker.pipeline.load_weights.assert_called_once_with(mock_weights)
assert result == expected_loaded

def test_load_weights_empty_iterable(self, mock_gpu_worker):
"""Test load_weights with empty weights iterable."""
mock_gpu_worker.pipeline.load_weights = Mock(return_value=set())

result = mock_gpu_worker.load_weights([])

mock_gpu_worker.pipeline.load_weights.assert_called_once_with([])
assert result == set()


class TestGPUWorkerSleep:
"""Test GPUWorker.sleep method."""

@patch("vllm_omni.diffusion.worker.gpu_worker.torch.cuda.mem_get_info")
@patch("vllm.device_allocator.cumem.CuMemAllocator")
def test_sleep_level_1(self, mock_allocator_class, mock_mem_info, mock_gpu_worker):
"""Test sleep mode level 1 (offload weights only)."""
# Setup memory info mocks
# Before sleep: 1GB free, 8GB total
# After sleep: 3GB free, 8GB total (freed 2GB)
mock_mem_info.side_effect = [
(1 * 1024**3, 8 * 1024**3), # Before sleep
(3 * 1024**3, 8 * 1024**3), # After sleep
]

# Setup allocator mock
mock_allocator = Mock()
mock_allocator_class.get_instance = Mock(return_value=mock_allocator)
mock_allocator.sleep = Mock()

# Call sleep with level 1
result = mock_gpu_worker.sleep(level=1)

# Verify sleep was called with correct tags
mock_allocator.sleep.assert_called_once_with(offload_tags=("weights",))
assert result is True
# Verify buffers were NOT saved (level 1 doesn't save buffers)
assert len(mock_gpu_worker._sleep_saved_buffers) == 0

@patch("vllm_omni.diffusion.worker.gpu_worker.torch.cuda.mem_get_info")
@patch("vllm.device_allocator.cumem.CuMemAllocator")
def test_sleep_level_2(self, mock_allocator_class, mock_mem_info, mock_gpu_worker):
"""Test sleep mode level 2 (offload all, save buffers)."""
# Setup memory info mocks
mock_mem_info.side_effect = [
(1 * 1024**3, 8 * 1024**3), # Before sleep
(5 * 1024**3, 8 * 1024**3), # After sleep (freed 4GB)
]

# Setup allocator mock
mock_allocator = Mock()
mock_allocator_class.get_instance = Mock(return_value=mock_allocator)
mock_allocator.sleep = Mock()

# Mock pipeline buffers
mock_buffer1 = torch.randn(10, 10)
mock_buffer2 = torch.randn(20, 20)
mock_gpu_worker.pipeline.named_buffers = Mock(
return_value=[
("buffer1", mock_buffer1),
("buffer2", mock_buffer2),
]
)

# Call sleep with level 2
result = mock_gpu_worker.sleep(level=2)

# Verify sleep was called with empty tags (offload all)
mock_allocator.sleep.assert_called_once_with(offload_tags=tuple())
assert result is True

# Verify buffers were saved
assert len(mock_gpu_worker._sleep_saved_buffers) == 2
assert "buffer1" in mock_gpu_worker._sleep_saved_buffers
assert "buffer2" in mock_gpu_worker._sleep_saved_buffers

@patch("vllm_omni.diffusion.worker.gpu_worker.torch.cuda.mem_get_info")
@patch("vllm.device_allocator.cumem.CuMemAllocator")
def test_sleep_memory_freed_validation(self, mock_allocator_class, mock_mem_info, mock_gpu_worker):
"""Test that sleep validates memory was actually freed."""
# Simulate memory increase (should trigger assertion error)
mock_mem_info.side_effect = [
(3 * 1024**3, 8 * 1024**3), # Before sleep: 3GB free
(1 * 1024**3, 8 * 1024**3), # After sleep: 1GB free (negative freed!)
]

mock_allocator = Mock()
mock_allocator_class.get_instance = Mock(return_value=mock_allocator)
mock_allocator.sleep = Mock()

# This should raise an assertion error
with pytest.raises(AssertionError, match="Memory usage increased after sleeping"):
mock_gpu_worker.sleep(level=1)


class TestGPUWorkerWakeUp:
"""Test GPUWorker.wake_up method."""

@patch("vllm.device_allocator.cumem.CuMemAllocator")
def test_wake_up_without_buffers(self, mock_allocator_class, mock_gpu_worker):
"""Test wake_up without saved buffers (level 1 sleep)."""
# Setup allocator mock
mock_allocator = Mock()
mock_allocator_class.get_instance = Mock(return_value=mock_allocator)
mock_allocator.wake_up = Mock()

# Ensure no saved buffers
mock_gpu_worker._sleep_saved_buffers = {}

# Call wake_up
result = mock_gpu_worker.wake_up(tags=["weights"])

# Verify allocator.wake_up was called
mock_allocator.wake_up.assert_called_once_with(["weights"])
assert result is True

@patch("vllm.device_allocator.cumem.CuMemAllocator")
def test_wake_up_with_buffers(self, mock_allocator_class, mock_gpu_worker):
"""Test wake_up with saved buffers (level 2 sleep)."""
# Setup allocator mock
mock_allocator = Mock()
mock_allocator_class.get_instance = Mock(return_value=mock_allocator)
mock_allocator.wake_up = Mock()

# Create saved buffers
saved_buffer1 = torch.randn(10, 10)
saved_buffer2 = torch.randn(20, 20)
mock_gpu_worker._sleep_saved_buffers = {
"buffer1": saved_buffer1,
"buffer2": saved_buffer2,
}

# Mock pipeline buffers (these will be restored)
mock_buffer1 = Mock()
mock_buffer1.data = Mock()
mock_buffer2 = Mock()
mock_buffer2.data = Mock()

mock_gpu_worker.pipeline.named_buffers = Mock(
return_value=[
("buffer1", mock_buffer1),
("buffer2", mock_buffer2),
]
)

# Call wake_up
result = mock_gpu_worker.wake_up(tags=None)

# Verify allocator.wake_up was called
mock_allocator.wake_up.assert_called_once_with(None)

# Verify buffers were restored
mock_buffer1.data.copy_.assert_called_once()
mock_buffer2.data.copy_.assert_called_once()

# Verify saved buffers were cleared
assert len(mock_gpu_worker._sleep_saved_buffers) == 0
assert result is True

@patch("vllm.device_allocator.cumem.CuMemAllocator")
def test_wake_up_partial_buffer_restore(self, mock_allocator_class, mock_gpu_worker):
"""Test wake_up only restores buffers that were saved."""
# Setup allocator mock
mock_allocator = Mock()
mock_allocator_class.get_instance = Mock(return_value=mock_allocator)
mock_allocator.wake_up = Mock()

# Only save buffer1, not buffer2
saved_buffer1 = torch.randn(10, 10)
mock_gpu_worker._sleep_saved_buffers = {
"buffer1": saved_buffer1,
}

# Mock pipeline has both buffers
mock_buffer1 = Mock()
mock_buffer1.data = Mock()
mock_buffer2 = Mock()
mock_buffer2.data = Mock()

mock_gpu_worker.pipeline.named_buffers = Mock(
return_value=[
("buffer1", mock_buffer1),
("buffer2", mock_buffer2),
]
)

# Call wake_up
result = mock_gpu_worker.wake_up()

# Verify only buffer1 was restored
mock_buffer1.data.copy_.assert_called_once()
# buffer2 should NOT be restored since it wasn't saved
mock_buffer2.data.copy_.assert_not_called()

assert result is True
vllm_omni/diffusion/data.py (3 additions, 0 deletions)
@@ -244,6 +244,9 @@ class OmniDiffusionConfig:
# Compilation
enable_torch_compile: bool = False

# Enable sleep mode
enable_sleep_mode: bool = False
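    # Illustrative only, not taken from this PR: a caller might opt in with
    #   OmniDiffusionConfig(model="some-model", enable_sleep_mode=True)
    # so that GPUWorker allocates weights in a CuMemAllocator memory pool that
    # sleep() can offload and wake_up() can restore.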

disable_autocast: bool = False

# VSA parameters
vllm_omni/diffusion/worker/gpu_worker.py (61 additions, 6 deletions)
@@ -3,6 +3,8 @@
import multiprocessing as mp
import os
import time
from collections.abc import Iterable
from contextlib import AbstractContextManager, nullcontext

import torch
import zmq
@@ -42,7 +44,7 @@ def __init__(
self.rank = rank
self.od_config = od_config
self.pipeline = None

self._sleep_saved_buffers: dict[str, torch.Tensor] = {}
self.init_device_and_model()

def init_device_and_model(self) -> None:
@@ -71,11 +73,12 @@ def init_device_and_model(self) -> None:
load_config = LoadConfig()
model_loader = DiffusersPipelineLoader(load_config)
time_before_load = time.perf_counter()
with DeviceMemoryProfiler() as m:
self.pipeline = model_loader.load_model(
od_config=self.od_config,
load_device=f"cuda:{rank}",
)
with self._maybe_get_memory_pool_context(tag="weights"):
with DeviceMemoryProfiler() as m:
self.pipeline = model_loader.load_model(
od_config=self.od_config,
load_device=f"cuda:{rank}",
)
time_after_load = time.perf_counter()

logger.info(
@@ -107,6 +110,58 @@ def execute_model(self, reqs: list[OmniDiffusionRequest], od_config: OmniDiffusi
output = self.pipeline.forward(req)
return output

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
return self.pipeline.load_weights(weights)

def sleep(self, level: int = 1) -> bool:
from vllm.device_allocator.cumem import CuMemAllocator

free_bytes_before_sleep = torch.cuda.mem_get_info()[0]

# Save the buffers before level 2 sleep
if level == 2:
model = self.pipeline
self._sleep_saved_buffers = {name: buffer.cpu().clone() for name, buffer in model.named_buffers()}

allocator = CuMemAllocator.get_instance()
allocator.sleep(offload_tags=("weights",) if level == 1 else tuple())
free_bytes_after_sleep, total = torch.cuda.mem_get_info()
freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
used_bytes = total - free_bytes_after_sleep
assert freed_bytes >= 0, "Memory usage increased after sleeping."
logger.info(
"Sleep mode freed %.2f GiB memory, %.2f GiB memory is still in use.",
freed_bytes / GiB_bytes,
used_bytes / GiB_bytes,
)
return True

def wake_up(self, tags: list[str] | None = None) -> bool:
from vllm.device_allocator.cumem import CuMemAllocator

allocator = CuMemAllocator.get_instance()
allocator.wake_up(tags)

# Restore the buffers after level 2 sleep
if len(self._sleep_saved_buffers):
model = self.pipeline
for name, buffer in model.named_buffers():
if name in self._sleep_saved_buffers:
buffer.data.copy_(self._sleep_saved_buffers[name].data)
self._sleep_saved_buffers = {}
return True

def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager:
if self.od_config.enable_sleep_mode:
from vllm.device_allocator.cumem import CuMemAllocator

allocator = CuMemAllocator.get_instance()
if tag == "weights":
assert allocator.get_current_usage() == 0, "Sleep mode can only be used for one instance per process."
return allocator.use_memory_pool(tag=tag)
else:
return nullcontext()

def shutdown(self) -> None:
if torch.distributed.is_initialized():
try: