Commit fe5e695

Added GPU tests.
1 parent 446b5e0 commit fe5e695

File tree: 3 files changed, +458 −0 lines

.github/workflows/gpu-tests.yml

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
name: GPU Tests

on:
  push:
    branches:
      - main
      - v*-release
  pull_request:
    branches:
      - main
  merge_group:
  workflow_dispatch:

jobs:
  gpu-tests:
    name: Run GPU tests
    runs-on: [self-hosted, gpu]  # Requires a GPU-enabled runner
    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Install uv
        uses: astral-sh/setup-uv@v5
        with:
          version: "0.8.6"

      - name: Set up Python environment
        run: uv sync

      - name: Verify CUDA availability
        run: |
          uv run python -c "import torch; print(f'CUDA available: {torch.cuda.is_available()}'); print(f'CUDA devices: {torch.cuda.device_count()}')"

      - name: Run GPU tests
        run: |
          # Run all GPU test files
          uv run pytest -xvs tests/*_gpu.py open_instruct/*_gpu.py
        env:
          CUDA_VISIBLE_DEVICES: 0  # Use first GPU
          NCCL_CUMEM_ENABLE: 0
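
The final step collects every file matching tests/*_gpu.py and open_instruct/*_gpu.py, so the same selection can be reproduced locally on a GPU machine with `uv run pytest -xvs tests/*_gpu.py open_instruct/*_gpu.py`, exactly the command the workflow runs. As a rough illustration of a module this glob would pick up, here is a minimal sketch; the file name tests/example_smoke_gpu.py and its contents are hypothetical and not part of this commit:

# tests/example_smoke_gpu.py -- hypothetical example, not part of this commit
import unittest

import torch


class TestCudaSmoke(unittest.TestCase):
    def setUp(self):
        # Same convention as the tests in this commit: skip rather than fail on CPU-only machines.
        if not torch.cuda.is_available():
            self.skipTest("CUDA is not available, skipping test")

    def test_tensor_round_trip(self):
        # Runs on the first visible GPU (CUDA_VISIBLE_DEVICES=0 in the workflow).
        x = torch.arange(8, device="cuda")
        self.assertEqual(x.sum().item(), 28)
        self.assertTrue(torch.equal(x.cpu(), torch.arange(8)))


if __name__ == "__main__":
    unittest.main()
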
Lines changed: 221 additions & 0 deletions
@@ -0,0 +1,221 @@
import gc
import os
import unittest

import ray
import torch
from ray.util import queue as ray_queue
from transformers import AutoTokenizer
from vllm import SamplingParams

from open_instruct import utils
from open_instruct.queue_types import GenerationResult, PromptRequest
from open_instruct.vllm_utils3 import create_vllm_engines


class TestGrpoFastGPUBase(unittest.TestCase):
    """Base class with common test utilities for GPU tests."""

    def _get_resource_tracker_state(self):
        """Get current resource tracker state for debugging."""
        tracked_resources = {}
        try:
            # Try to access resource tracker directly
            from multiprocessing.resource_tracker import _resource_tracker

            if hasattr(_resource_tracker, "_cache"):
                for name, rtype in list(_resource_tracker._cache.items()):
                    if rtype not in tracked_resources:
                        tracked_resources[rtype] = []
                    tracked_resources[rtype].append(name)
        except Exception:
            # Alternative approach: check via resource_tracker module
            try:
                import multiprocessing.resource_tracker as rt

                if hasattr(rt, "getfd"):
                    # This is a hack to get the cache info

                    # Try to find the cache in the module
                    for attr_name in dir(rt):
                        attr = getattr(rt, attr_name)
                        if isinstance(attr, dict) and any("semaphore" in str(v) for v in attr.values()):
                            for k, v in attr.items():
                                if v not in tracked_resources:
                                    tracked_resources[v] = []
                                tracked_resources[v].append(k)
            except Exception:
                pass
        return tracked_resources

    def setUp(self):
        """Initialize Ray and check for pre-existing leaks."""
        # Check if CUDA is available
        if not torch.cuda.is_available():
            self.skipTest("CUDA is not available, skipping test")

        # Save original environment variable value
        self._original_nccl_cumem = os.environ.get("NCCL_CUMEM_ENABLE")

        # Record initial resource tracker state
        self._initial_resources = self._get_resource_tracker_state()

        # Track Ray queues for cleanup
        self._ray_queues = []

        # Check for leaks after Ray init
        leak_report = utils.check_runtime_leaks()
        # After Ray init, we expect exactly one Ray head worker
        if len(leak_report.ray_workers) == 1:
            # Check if it's the head worker (worker ID all zeros or all f's)
            worker = leak_report.ray_workers[0]
            worker_id = worker.get("worker_id", "")
            if worker_id in [
                "01000000ffffffffffffffffffffffffffffffffffffffffffffffff",
                "00000000ffffffffffffffffffffffffffffffffffffffffffffffff",
            ]:
                # This is the expected Ray head worker, clear it
                leak_report.ray_workers = []

        if not leak_report.is_clean:
            self.fail(f"Leaks detected before test {self._testMethodName}:\n{leak_report.pretty()}")

        # Initialize Ray for this test
        ray.init(include_dashboard=False)

    def _cleanup_ray_queues(self):
        """Clean up all Ray queues created during the test."""
        for queue in self._ray_queues:
            try:
                queue.shutdown()
            except Exception as e:
                print(f"Warning: Failed to shutdown Ray queue: {e}")
        self._ray_queues.clear()

    def tearDown(self):
        """Check for leaks and shutdown Ray."""
        # Clean up Ray queues BEFORE shutting down Ray
        self._cleanup_ray_queues()

        # Shutdown Ray
        if ray.is_initialized():
            ray.shutdown()

        # Force garbage collection to clean up any lingering objects
        gc.collect()

        # Get final resource tracker state
        final_resources = self._get_resource_tracker_state()

        # Check for new resources that weren't there initially
        new_resources = {}
        for rtype, names in final_resources.items():
            initial_names = set(self._initial_resources.get(rtype, []))
            new_names = [n for n in names if n not in initial_names]
            if new_names:
                new_resources[rtype] = new_names

        # Check for leaks before shutdown
        leak_report = utils.check_runtime_leaks()
        # We still expect the Ray head worker
        if len(leak_report.ray_workers) == 1:
            worker = leak_report.ray_workers[0]
            worker_id = worker.get("worker_id", "")
            if worker_id in [
                "01000000ffffffffffffffffffffffffffffffffffffffffffffffff",
                "00000000ffffffffffffffffffffffffffffffffffffffffffffffff",
            ]:
                # This is the expected Ray head worker, clear it
                leak_report.ray_workers = []

        if not leak_report.is_clean:
            self.fail(f"Leaks detected after test {self._testMethodName}:\n{leak_report.pretty()}")

        # Check for semaphore leaks
        if new_resources:
            # Report all new resources, especially semaphores
            leak_msg = f"Resource leaks detected after test {self._testMethodName}:\n"
            for rtype, names in new_resources.items():
                leak_msg += f"  {rtype}: {names}\n"

            # Fail if there are semaphore leaks
            if "semaphore" in new_resources:
                self.fail(leak_msg)

        # Restore original environment variable value
        if self._original_nccl_cumem is None:
            os.environ.pop("NCCL_CUMEM_ENABLE", None)
        else:
            os.environ["NCCL_CUMEM_ENABLE"] = self._original_nccl_cumem


class TestGrpoFastVLLMGPU(TestGrpoFastGPUBase):
    def test_vllm_queue_system_single_prompt(self):
        """Test the new queue-based vLLM system with a single prompt 'What is the capital of France?'"""
        # Set up tokenizer
        tokenizer_name = "EleutherAI/pythia-14m"  # Using a small model for testing
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

        # Tokenize the test prompt
        test_prompt = "What is the capital of France?"
        prompt_token_ids = tokenizer.encode(test_prompt, return_tensors="pt").tolist()[0]

        # Create Ray queues
        param_prompt_Q = ray_queue.Queue(maxsize=1)
        inference_results_Q = ray_queue.Queue(maxsize=1)

        # Track queues for cleanup
        self._ray_queues.extend([param_prompt_Q, inference_results_Q])

        # Create vLLM engines with queues
        _ = create_vllm_engines(
            num_engines=1,
            tensor_parallel_size=1,
            enforce_eager=True,
            tokenizer_name_or_path=tokenizer_name,
            pretrain=tokenizer_name,
            revision="main",
            seed=42,
            enable_prefix_caching=False,
            max_model_len=512,
            vllm_gpu_memory_utilization=0.5,  # Use less GPU memory for testing
            prompt_queue=param_prompt_Q,
            results_queue=inference_results_Q,
        )

        # Set up generation config
        generation_config = SamplingParams(
            temperature=0.0,  # Deterministic generation
            top_p=1.0,
            max_tokens=5,
            n=1,
        )

        # Create a PromptRequest
        request = PromptRequest(
            prompt_token_ids=prompt_token_ids, generation_config=generation_config, dataset_index=[0], training_step=0
        )

        # Send the request
        param_prompt_Q.put(request)

        # Get the result
        result = inference_results_Q.get(timeout=30)

        # Verify we got a GenerationResult
        self.assertIsInstance(result, GenerationResult)
        self.assertIsNotNone(result.responses)
        self.assertEqual(len(result.responses), 1)
        self.assertEqual(result.dataset_index, [0])

        # Get the response IDs (skip the prompt)
        response_ids = result.responses[0]

        # Decode the response
        generated_text = tokenizer.decode(response_ids, skip_special_tokens=True)

        self.assertIsInstance(generated_text, str)
        self.assertGreater(len(generated_text), 0)

        # Send stop signal
        param_prompt_Q.put(None)
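
Because the CUDA skip, the leak checks, and the Ray lifecycle all live in setUp/tearDown of TestGrpoFastGPUBase, a further GPU test only needs to subclass it and append any Ray queues it creates to self._ray_queues so they are shut down before Ray itself. A minimal sketch of such a follow-up test, assuming it sits in the same module; the class and method names are hypothetical and not part of this commit:

# Hypothetical follow-up test; only the classes above are part of the commit.
class TestRayQueueCleanupGPU(TestGrpoFastGPUBase):
    def test_tracked_queue_round_trip(self):
        # Queues registered in self._ray_queues are shut down by tearDown()
        # before ray.shutdown(), so they cannot leak semaphores.
        q = ray_queue.Queue(maxsize=4)
        self._ray_queues.append(q)
        q.put("ping")
        self.assertEqual(q.get(timeout=10), "ping")
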
