diff --git a/tests/diffusion/test_gpu_worker.py b/tests/diffusion/test_gpu_worker.py
new file mode 100644
index 000000000..defeffe5b
--- /dev/null
+++ b/tests/diffusion/test_gpu_worker.py
@@ -0,0 +1,261 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+"""
+Unit tests for GPUWorker class.
+
+This module tests the GPUWorker implementation:
+- load_weights: Loading model weights
+- sleep: Putting worker into sleep mode (levels 1 and 2)
+- wake_up: Waking worker from sleep mode
+"""
+
+from unittest.mock import Mock, patch
+
+import pytest
+import torch
+
+from vllm_omni.diffusion.worker.gpu_worker import GPUWorker
+
+
+@pytest.fixture
+def mock_od_config():
+    """Create a mock OmniDiffusionConfig."""
+    config = Mock()
+    config.num_gpus = 1
+    config.master_port = 12345
+    config.enable_sleep_mode = False
+    config.cache_backend = None
+    config.cache_config = None
+    config.model = "test-model"
+    return config
+
+
+@pytest.fixture
+def mock_gpu_worker(mock_od_config):
+    """Create a GPUWorker with mocked initialization."""
+    with patch.object(GPUWorker, "init_device_and_model"):
+        worker = GPUWorker(local_rank=0, rank=0, od_config=mock_od_config)
+        # Mock the pipeline
+        worker.pipeline = Mock()
+        worker.cache_backend = None
+        return worker
+
+
+class TestGPUWorkerLoadWeights:
+    """Test GPUWorker.load_weights method."""
+
+    def test_load_weights_calls_pipeline(self, mock_gpu_worker):
+        """Test that load_weights delegates to pipeline.load_weights."""
+        # Setup mock weights
+        mock_weights = [
+            ("layer1.weight", torch.randn(10, 10)),
+            ("layer2.weight", torch.randn(20, 20)),
+        ]
+        expected_loaded = {"layer1.weight", "layer2.weight"}
+
+        # Configure pipeline mock
+        mock_gpu_worker.pipeline.load_weights = Mock(return_value=expected_loaded)
+
+        # Call load_weights
+        result = mock_gpu_worker.load_weights(mock_weights)
+
+        # Verify pipeline.load_weights was called with the weights
+        mock_gpu_worker.pipeline.load_weights.assert_called_once_with(mock_weights)
+        assert result == expected_loaded
+
+    def test_load_weights_empty_iterable(self, mock_gpu_worker):
+        """Test load_weights with empty weights iterable."""
+        mock_gpu_worker.pipeline.load_weights = Mock(return_value=set())
+
+        result = mock_gpu_worker.load_weights([])
+
+        mock_gpu_worker.pipeline.load_weights.assert_called_once_with([])
+        assert result == set()
+
+
+class TestGPUWorkerSleep:
+    """Test GPUWorker.sleep method."""
+
+    @patch("vllm_omni.diffusion.worker.gpu_worker.torch.cuda.mem_get_info")
+    @patch("vllm.device_allocator.cumem.CuMemAllocator")
+    def test_sleep_level_1(self, mock_allocator_class, mock_mem_info, mock_gpu_worker):
+        """Test sleep mode level 1 (offload weights only)."""
+        # Setup memory info mocks
+        # Before sleep: 1GB free, 8GB total
+        # After sleep: 3GB free, 8GB total (freed 2GB)
+        mock_mem_info.side_effect = [
+            (1 * 1024**3, 8 * 1024**3),  # Before sleep
+            (3 * 1024**3, 8 * 1024**3),  # After sleep
+        ]
+
+        # Setup allocator mock
+        mock_allocator = Mock()
+        mock_allocator_class.get_instance = Mock(return_value=mock_allocator)
+        mock_allocator.sleep = Mock()
+
+        # Call sleep with level 1
+        result = mock_gpu_worker.sleep(level=1)
+
+        # Verify sleep was called with correct tags
+        mock_allocator.sleep.assert_called_once_with(offload_tags=("weights",))
+        assert result is True
+        # Verify buffers were NOT saved (level 1 doesn't save buffers)
+        assert len(mock_gpu_worker._sleep_saved_buffers) == 0
+
@patch("vllm_omni.diffusion.worker.gpu_worker.torch.cuda.mem_get_info") + @patch("vllm.device_allocator.cumem.CuMemAllocator") + def test_sleep_level_2(self, mock_allocator_class, mock_mem_info, mock_gpu_worker): + """Test sleep mode level 2 (offload all, save buffers).""" + # Setup memory info mocks + mock_mem_info.side_effect = [ + (1 * 1024**3, 8 * 1024**3), # Before sleep + (5 * 1024**3, 8 * 1024**3), # After sleep (freed 4GB) + ] + + # Setup allocator mock + mock_allocator = Mock() + mock_allocator_class.get_instance = Mock(return_value=mock_allocator) + mock_allocator.sleep = Mock() + + # Mock pipeline buffers + mock_buffer1 = torch.randn(10, 10) + mock_buffer2 = torch.randn(20, 20) + mock_gpu_worker.pipeline.named_buffers = Mock( + return_value=[ + ("buffer1", mock_buffer1), + ("buffer2", mock_buffer2), + ] + ) + + # Call sleep with level 2 + result = mock_gpu_worker.sleep(level=2) + + # Verify sleep was called with empty tags (offload all) + mock_allocator.sleep.assert_called_once_with(offload_tags=tuple()) + assert result is True + + # Verify buffers were saved + assert len(mock_gpu_worker._sleep_saved_buffers) == 2 + assert "buffer1" in mock_gpu_worker._sleep_saved_buffers + assert "buffer2" in mock_gpu_worker._sleep_saved_buffers + + @patch("vllm_omni.diffusion.worker.gpu_worker.torch.cuda.mem_get_info") + @patch("vllm.device_allocator.cumem.CuMemAllocator") + def test_sleep_memory_freed_validation(self, mock_allocator_class, mock_mem_info, mock_gpu_worker): + """Test that sleep validates memory was actually freed.""" + # Simulate memory increase (should trigger assertion error) + mock_mem_info.side_effect = [ + (3 * 1024**3, 8 * 1024**3), # Before sleep: 3GB free + (1 * 1024**3, 8 * 1024**3), # After sleep: 1GB free (negative freed!) + ] + + mock_allocator = Mock() + mock_allocator_class.get_instance = Mock(return_value=mock_allocator) + mock_allocator.sleep = Mock() + + # This should raise an assertion error + with pytest.raises(AssertionError, match="Memory usage increased after sleeping"): + mock_gpu_worker.sleep(level=1) + + +class TestGPUWorkerWakeUp: + """Test GPUWorker.wake_up method.""" + + @patch("vllm.device_allocator.cumem.CuMemAllocator") + def test_wake_up_without_buffers(self, mock_allocator_class, mock_gpu_worker): + """Test wake_up without saved buffers (level 1 sleep).""" + # Setup allocator mock + mock_allocator = Mock() + mock_allocator_class.get_instance = Mock(return_value=mock_allocator) + mock_allocator.wake_up = Mock() + + # Ensure no saved buffers + mock_gpu_worker._sleep_saved_buffers = {} + + # Call wake_up + result = mock_gpu_worker.wake_up(tags=["weights"]) + + # Verify allocator.wake_up was called + mock_allocator.wake_up.assert_called_once_with(["weights"]) + assert result is True + + @patch("vllm.device_allocator.cumem.CuMemAllocator") + def test_wake_up_with_buffers(self, mock_allocator_class, mock_gpu_worker): + """Test wake_up with saved buffers (level 2 sleep).""" + # Setup allocator mock + mock_allocator = Mock() + mock_allocator_class.get_instance = Mock(return_value=mock_allocator) + mock_allocator.wake_up = Mock() + + # Create saved buffers + saved_buffer1 = torch.randn(10, 10) + saved_buffer2 = torch.randn(20, 20) + mock_gpu_worker._sleep_saved_buffers = { + "buffer1": saved_buffer1, + "buffer2": saved_buffer2, + } + + # Mock pipeline buffers (these will be restored) + mock_buffer1 = Mock() + mock_buffer1.data = Mock() + mock_buffer2 = Mock() + mock_buffer2.data = Mock() + + mock_gpu_worker.pipeline.named_buffers = Mock( + return_value=[ 
+ ("buffer1", mock_buffer1), + ("buffer2", mock_buffer2), + ] + ) + + # Call wake_up + result = mock_gpu_worker.wake_up(tags=None) + + # Verify allocator.wake_up was called + mock_allocator.wake_up.assert_called_once_with(None) + + # Verify buffers were restored + mock_buffer1.data.copy_.assert_called_once() + mock_buffer2.data.copy_.assert_called_once() + + # Verify saved buffers were cleared + assert len(mock_gpu_worker._sleep_saved_buffers) == 0 + assert result is True + + @patch("vllm.device_allocator.cumem.CuMemAllocator") + def test_wake_up_partial_buffer_restore(self, mock_allocator_class, mock_gpu_worker): + """Test wake_up only restores buffers that were saved.""" + # Setup allocator mock + mock_allocator = Mock() + mock_allocator_class.get_instance = Mock(return_value=mock_allocator) + mock_allocator.wake_up = Mock() + + # Only save buffer1, not buffer2 + saved_buffer1 = torch.randn(10, 10) + mock_gpu_worker._sleep_saved_buffers = { + "buffer1": saved_buffer1, + } + + # Mock pipeline has both buffers + mock_buffer1 = Mock() + mock_buffer1.data = Mock() + mock_buffer2 = Mock() + mock_buffer2.data = Mock() + + mock_gpu_worker.pipeline.named_buffers = Mock( + return_value=[ + ("buffer1", mock_buffer1), + ("buffer2", mock_buffer2), + ] + ) + + # Call wake_up + result = mock_gpu_worker.wake_up() + + # Verify only buffer1 was restored + mock_buffer1.data.copy_.assert_called_once() + # buffer2 should NOT be restored since it wasn't saved + mock_buffer2.data.copy_.assert_not_called() + + assert result is True diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index 8210bedab..5958fb6d3 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -244,6 +244,9 @@ class OmniDiffusionConfig: # Compilation enable_torch_compile: bool = False + # Enable sleep mode + enable_sleep_mode: bool = False + disable_autocast: bool = False # VSA parameters diff --git a/vllm_omni/diffusion/worker/gpu_worker.py b/vllm_omni/diffusion/worker/gpu_worker.py index 868611841..ee78f4357 100644 --- a/vllm_omni/diffusion/worker/gpu_worker.py +++ b/vllm_omni/diffusion/worker/gpu_worker.py @@ -3,6 +3,8 @@ import multiprocessing as mp import os import time +from collections.abc import Iterable +from contextlib import AbstractContextManager, nullcontext import torch import zmq @@ -42,7 +44,7 @@ def __init__( self.rank = rank self.od_config = od_config self.pipeline = None - + self._sleep_saved_buffers: dict[str, torch.Tensor] = {} self.init_device_and_model() def init_device_and_model(self) -> None: @@ -71,11 +73,12 @@ def init_device_and_model(self) -> None: load_config = LoadConfig() model_loader = DiffusersPipelineLoader(load_config) time_before_load = time.perf_counter() - with DeviceMemoryProfiler() as m: - self.pipeline = model_loader.load_model( - od_config=self.od_config, - load_device=f"cuda:{rank}", - ) + with self._maybe_get_memory_pool_context(tag="weights"): + with DeviceMemoryProfiler() as m: + self.pipeline = model_loader.load_model( + od_config=self.od_config, + load_device=f"cuda:{rank}", + ) time_after_load = time.perf_counter() logger.info( @@ -107,6 +110,58 @@ def execute_model(self, reqs: list[OmniDiffusionRequest], od_config: OmniDiffusi output = self.pipeline.forward(req) return output + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + return self.pipeline.load_weights(weights) + + def sleep(self, level: int = 1) -> bool: + from vllm.device_allocator.cumem import CuMemAllocator + + free_bytes_before_sleep = 
torch.cuda.mem_get_info()[0] + + # Save the buffers before level 2 sleep + if level == 2: + model = self.pipeline + self._sleep_saved_buffers = {name: buffer.cpu().clone() for name, buffer in model.named_buffers()} + + allocator = CuMemAllocator.get_instance() + allocator.sleep(offload_tags=("weights",) if level == 1 else tuple()) + free_bytes_after_sleep, total = torch.cuda.mem_get_info() + freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep + used_bytes = total - free_bytes_after_sleep + assert freed_bytes >= 0, "Memory usage increased after sleeping." + logger.info( + "Sleep mode freed %.2f GiB memory, %.2f GiB memory is still in use.", + freed_bytes / GiB_bytes, + used_bytes / GiB_bytes, + ) + return True + + def wake_up(self, tags: list[str] | None = None) -> bool: + from vllm.device_allocator.cumem import CuMemAllocator + + allocator = CuMemAllocator.get_instance() + allocator.wake_up(tags) + + # Restore the buffers after level 2 sleep + if len(self._sleep_saved_buffers): + model = self.pipeline + for name, buffer in model.named_buffers(): + if name in self._sleep_saved_buffers: + buffer.data.copy_(self._sleep_saved_buffers[name].data) + self._sleep_saved_buffers = {} + return True + + def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager: + if self.od_config.enable_sleep_mode: + from vllm.device_allocator.cumem import CuMemAllocator + + allocator = CuMemAllocator.get_instance() + if tag == "weights": + assert allocator.get_current_usage() == 0, "Sleep mode can only be used for one instance per process." + return allocator.use_memory_pool(tag=tag) + else: + return nullcontext() + def shutdown(self) -> None: if torch.distributed.is_initialized(): try:
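For reviewers, a minimal usage sketch of the sleep/wake-up cycle added above. This is illustrative only: it assumes `OmniDiffusionConfig` can be constructed directly with keyword arguments and that a `GPUWorker` is driven outside the engine loop on a single GPU; the model name is a placeholder. Only `enable_sleep_mode`, `sleep(level=...)`, and `wake_up(tags=...)` come from this diff.

```python
# Illustrative sketch only. Direct config/worker construction and any extra
# config fields a real run needs are assumptions; sleep()/wake_up()/
# enable_sleep_mode are the APIs introduced in this diff.
import torch

from vllm_omni.diffusion.data import OmniDiffusionConfig
from vllm_omni.diffusion.worker.gpu_worker import GPUWorker

od_config = OmniDiffusionConfig(
    model="your-diffusion-model",  # hypothetical placeholder
    enable_sleep_mode=True,  # weights load inside a CuMemAllocator pool tagged "weights"
)
worker = GPUWorker(local_rank=0, rank=0, od_config=od_config)

free_before, _ = torch.cuda.mem_get_info()

# Level 1 offloads only the tagged weights; level 2 would also snapshot
# named_buffers() to CPU so wake_up() can restore them.
worker.sleep(level=1)

free_after, _ = torch.cuda.mem_get_info()
print(f"sleep freed {(free_after - free_before) / 1024**3:.2f} GiB")

# Re-materialize the weights before serving the next request.
worker.wake_up(tags=["weights"])
```

Design note: level 1 leaves module buffers on the GPU and only offloads allocator-tagged weight memory, while level 2 additionally saves `named_buffers()` to CPU and restores them in `wake_up`, which is why the tests assert `_sleep_saved_buffers` is populated only for level-2 sleep.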