
Commit b288ff4

[vllm] Bug fix for V1 Completion (#735)
1 parent dbe7b4e commit b288ff4

File tree

10 files changed: +217 −3 lines


.gitignore

Lines changed: 3 additions & 0 deletions

@@ -84,6 +84,9 @@ ipython_config.py
 *.distcp
 .metadata

+# Track golden test outputs (these are checked in for correctness tests)
+!tests/integration_tests/fixtures/golden_outputs/*.pt
+
 # mypy
 .mypy_cache/
 .dmypy.json
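
The `!` pattern re-includes the golden `.pt` fixtures that a broader ignore rule elsewhere in the file (not visible in this hunk, but implied by the comment) would otherwise exclude; without it, `git add` would silently skip the baseline files the new correctness test depends on. If in doubt about which rule wins for a given path, running `git check-ignore -v tests/integration_tests/fixtures/golden_outputs/completion_0.pt` from the repository root reports the matching pattern.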

src/forge/actors/vllm/v1/generator.py

Lines changed: 24 additions & 2 deletions

@@ -527,6 +527,24 @@ async def validate_model_params(self, validate_fn):
         logger.info("start validating model parameters.")
         return await self.workers.validate_model_params.call(validate_fn)

+    def _extract_logprobs(self, output) -> torch.Tensor | None:
+        """Extract logprobs from vLLM output as a torch.Tensor.
+
+        Args:
+            output: vLLM CompletionOutput with optional logprobs.
+
+        Returns:
+            torch.Tensor of logprobs for each token, or None if not available.
+        """
+        if output.logprobs is not None:
+            return torch.tensor(
+                [
+                    top_k_dict[token].logprob
+                    for token, top_k_dict in zip(output.token_ids, output.logprobs)
+                ]
+            )
+        return None
+
     def _to_completions(
         self, request_output: RequestOutput, prompt: str
     ) -> list[Completion]:
@@ -553,15 +571,19 @@ def _to_completions(
                 token_ids=torch.tensor(
                     output.token_ids if hasattr(output, "token_ids") else []
                 ),
-                logprobs=(output.logprobs if hasattr(output, "logprobs") else None),
+                logprobs=self._extract_logprobs(output),
                 stop_reason=output.finish_reason,
                 generator_version=self.generator_version,
-                metadata=None,
+                metadata={"num_cached_tokens": request_output.num_cached_tokens},
             )
             completions.append(completion)

         return completions

+    @endpoint
+    async def _reset_prefix_cache(self):
+        await self.llm.reset_prefix_cache()
+

 class _WeightFetcher(ForgeActor):
     """Fetches weights from torchstore and loads them into shared memory.
tests/integration_tests/fixtures/golden_outputs/ (6 new binary .pt fixtures: the golden Completion outputs plus metadata)

3.2 KB   Binary file not shown.
3.14 KB  Binary file not shown.
3.2 KB   Binary file not shown.
3.26 KB  Binary file not shown.
2.58 KB  Binary file not shown.
1.42 KB  Binary file not shown.

tests/integration_tests/generate_golden_outputs.py

Lines changed: 113 additions & 0 deletions

@@ -0,0 +1,113 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Generate baseline golden output files using the current Generator implementation.
+
+These golden files serve as a baseline for verifying that new implementations
+produce identical outputs. Uses deterministic sampling (temperature=0) for
+reproducibility.
+
+NOTE: Golden output artifacts are checked into git. Keep the number of prompts
+and MAX_TOKENS small to avoid bloating the repository. Current artifacts are
+~20KB total.
+
+Usage:
+    python tests/integration_tests/generate_golden_outputs.py
+
+The script will generate golden files in tests/integration_tests/fixtures/golden_outputs/
+"""
+
+import asyncio
+from pathlib import Path
+
+import torch
+from forge.actors.generator import Generator
+
+
+# Configuration - matches test_vllm_policy_correctness.py
+MODEL_NAME = "facebook/opt-125m"
+MAX_MODEL_LEN = 512
+GPU_MEMORY_UTILIZATION = 0.1
+ENFORCE_EAGER = True
+ENABLE_PREFIX_CACHING = True
+TENSOR_PARALLEL_SIZE = 1
+
+# Deterministic sampling
+MAX_TOKENS = 50
+TEMPERATURE = 0.0
+TOP_P = 1.0
+N_SAMPLES = 1
+
+TEST_PROMPTS = [
+    "Hello, how are you?",
+    "What is 2+2?",
+    "Tell me a joke.",
+    "Explain machine learning briefly.",
+    "What color is the sky?",
+]
+
+
+async def generate_golden_outputs():
+    """Generate golden outputs using the current Generator."""
+    golden_dir = Path(__file__).parent / "fixtures" / "golden_outputs"
+    golden_dir.mkdir(parents=True, exist_ok=True)
+
+    print(f"Generating golden outputs to: {golden_dir}")
+    print(f"Model: {MODEL_NAME}")
+
+    generator = None
+    try:
+        generator = await Generator.options(
+            procs=1, num_replicas=1, with_gpus=True
+        ).as_service(
+            engine_args={
+                "model": MODEL_NAME,
+                "tensor_parallel_size": TENSOR_PARALLEL_SIZE,
+                "enforce_eager": ENFORCE_EAGER,
+                "max_model_len": MAX_MODEL_LEN,
+                "gpu_memory_utilization": GPU_MEMORY_UTILIZATION,
+                "enable_prefix_caching": ENABLE_PREFIX_CACHING,
+            },
+            sampling_params={
+                "n": N_SAMPLES,
+                "max_tokens": MAX_TOKENS,
+                "temperature": TEMPERATURE,
+                "top_p": TOP_P,
+                "logprobs": 1,
+            },
+        )
+
+        print("Generator ready. Generating outputs...\n")
+
+        for i, prompt in enumerate(TEST_PROMPTS):
+            print(f"[{i + 1}/{len(TEST_PROMPTS)}] Prompt: {prompt[:50]}...")
+
+            result = await generator.generate.route(prompt)
+            completion = result[0]
+
+            # Serialize entire Completion object
+            golden_path = golden_dir / f"completion_{i}.pt"
+            torch.save(completion, golden_path)
+            print(f"  Saved: {golden_path}")
+
+        metadata = {
+            "model": MODEL_NAME,
+            "max_tokens": MAX_TOKENS,
+            "temperature": TEMPERATURE,
+            "prompts": TEST_PROMPTS,
+        }
+        torch.save(metadata, golden_dir / "metadata.pt")
+
+        print("\nGolden output generation complete!")
+
+    finally:
+        if generator is not None:
+            await generator.shutdown()
+
+
+if __name__ == "__main__":
+    asyncio.run(generate_golden_outputs())
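
Since the golden artifacts are checked in, they can be inspected directly without running the generator. A rough sketch of loading them (paths relative to the repository root; the Completion field names are taken from the generator diff above):

from pathlib import Path

import torch

golden_dir = Path("tests/integration_tests/fixtures/golden_outputs")

# metadata.pt records the model, sampling settings, and prompts used for the baseline.
metadata = torch.load(golden_dir / "metadata.pt", weights_only=False)
print(metadata["model"], metadata["max_tokens"], metadata["temperature"])
print(metadata["prompts"])

# Each completion_<i>.pt is a pickled Completion dataclass (tensors included),
# hence weights_only=False; the files come from this repository, not an untrusted source.
golden = torch.load(golden_dir / "completion_0.pt", weights_only=False)
print(golden.token_ids)   # generated token ids as a tensor
print(golden.logprobs)    # per-token logprobs tensor (logprobs=1 was requested)
print(golden.stop_reason, golden.generator_version, golden.metadata)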

tests/integration_tests/test_vllm_policy_correctness.py

Lines changed: 77 additions & 1 deletion

@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 import pytest
+import torch
 from forge.actors.generator import Generator
 from vllm import SamplingParams
 from vllm.engine.arg_utils import AsyncEngineArgs
@@ -29,7 +30,7 @@

 @pytest.mark.asyncio
 async def test_same_output():
-    """Compare outputs between vLLM and Generator service"""
+    """Compare outputs between vLLM and Generator service."""
     test_prompts = [
         "Hello, how are you?",
         "What is 2+2?",
@@ -236,3 +237,78 @@ async def test_cache_usage():
     finally:
         if generator is not None:
             await generator.shutdown()
+
+
+@pytest.mark.asyncio
+async def test_generator_matches_golden():
+    """Verify Generator produces identical outputs to baseline golden files.
+
+    Golden files are already committed. Only regenerate when updating baseline:
+        python tests/integration_tests/generate_golden_outputs.py
+    """
+    from dataclasses import fields
+    from pathlib import Path
+
+    def completions_equal(a, b) -> bool:
+        """Compare two Completion objects, handling tensors correctly."""
+        for field in fields(a):
+            val_a = getattr(a, field.name)
+            val_b = getattr(b, field.name)
+            if isinstance(val_a, torch.Tensor) and isinstance(val_b, torch.Tensor):
+                if not torch.equal(val_a, val_b):
+                    return False
+            elif val_a != val_b:
+                return False
+        return True
+
+    golden_dir = Path(__file__).parent / "fixtures" / "golden_outputs"
+    metadata_path = golden_dir / "metadata.pt"
+
+    if not metadata_path.exists():
+        pytest.skip(
+            "Golden files not found. Generate baseline first: "
+            "python tests/integration_tests/generate_golden_outputs.py"
+        )
+
+    metadata = torch.load(metadata_path, weights_only=False)
+    test_prompts = metadata["prompts"]
+
+    generator = None
+    try:
+        generator = await Generator.options(
+            procs=1, num_replicas=1, with_gpus=True
+        ).as_service(
+            engine_args={
+                "model": MODEL_NAME,
+                "tensor_parallel_size": TENSOR_PARALLEL_SIZE,
+                "enforce_eager": ENFORCE_EAGER,
+                "max_model_len": MAX_MODEL_LEN,
+                "gpu_memory_utilization": GPU_MEMORY_UTILIZATION,
+                "enable_prefix_caching": ENABLE_PREFIX_CACHING,
+            },
+            sampling_params={
+                "n": N_SAMPLES,
+                "max_tokens": MAX_TOKENS,
+                "temperature": TEMPERATURE,
+                "top_p": TOP_P,
+                "logprobs": 1,
+            },
+        )
+
+        for i, prompt in enumerate(test_prompts):
+            golden_path = golden_dir / f"completion_{i}.pt"
+            assert golden_path.exists(), f"Golden file not found: {golden_path}"
+
+            golden = torch.load(golden_path, weights_only=False)
+            result = await generator.generate.route(prompt)
+            completion = result[0]
+
+            assert completions_equal(
+                completion, golden
+            ), f"Prompt {i}: completion mismatch"
+
+            print(f"Prompt {i}: PASS")
+
+    finally:
+        if generator is not None:
+            await generator.shutdown()
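
The `completions_equal` helper is needed because `Completion` carries tensor-valued fields, and dataclass equality compares fields with `==`, which for a multi-element tensor yields a boolean tensor that fails when coerced to a single bool. A tiny illustration of the failure mode (the `Fake` dataclass is purely illustrative):

from dataclasses import dataclass

import torch


@dataclass
class Fake:
    token_ids: torch.Tensor


a = Fake(torch.tensor([1, 2, 3]))
b = Fake(torch.tensor([1, 2, 3]))

try:
    a == b  # dataclass __eq__ compares the tensors with ==, then coerces to bool
except RuntimeError as err:
    # PyTorch refuses to convert a multi-element boolean tensor to a single bool.
    print("plain equality fails:", err)

# Tensor-aware comparison, as completions_equal does field by field:
print(torch.equal(a.token_ids, b.token_ids))  # True

To run just this test locally (a GPU is required given `with_gpus=True`), something like `pytest tests/integration_tests/test_vllm_policy_correctness.py::test_generator_matches_golden -s` should work; `-s` surfaces the per-prompt PASS prints.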

0 commit comments