Commit 267242d (parent: b705c21)

[data][llm] Add single-node Ray Data LLM perf baseline benchmark + regression guard (#58289)

Signed-off-by: Nikhil Ghosh <[email protected]>
Signed-off-by: Nikhil G <[email protected]>

5 files changed: +185 -2 lines changed
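
The regression guard added here is driven entirely by environment variables, so CI can tighten or relax thresholds without code changes. As a hedged sketch (not part of the commit; the env var names and the test id are taken from the diff below, the artifact path is a placeholder), a wrapper that runs the benchmark with explicit thresholds might look like:

import os
import subprocess

# Thresholds and artifact path consumed by test_single_node_baseline_benchmark.
env = dict(os.environ)
env["RAY_DATA_LLM_BENCHMARK_MIN_THROUGHPUT"] = "5"    # fail the run below 5 req/s
env["RAY_DATA_LLM_BENCHMARK_MAX_LATENCY_S"] = "120"   # fail the run above 120 s
env["RAY_LLM_BENCHMARK_ARTIFACT_PATH"] = "/tmp/bench/metrics.json"  # placeholder

# Run only the baseline benchmark test, mirroring the release-test script.
subprocess.run(
    [
        "pytest", "-v",
        "test_batch_single_node_vllm.py::test_single_node_baseline_benchmark",
    ],
    env=env,
    check=True,
)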

python/ray/llm/_internal/batch/benchmark/benchmark_processor.py

Lines changed: 1 addition & 2 deletions

The ShareGPTDataset import is made package-relative so it resolves correctly when benchmark_processor is imported as part of the ray.llm package rather than run from its own directory.

@@ -13,9 +13,8 @@
 from enum import Enum
 from time import perf_counter, sleep
 
-from dataset import ShareGPTDataset
-
 import ray
+from .dataset import ShareGPTDataset
 from ray import data, serve
 from ray.data.llm import (
     ServeDeploymentProcessorConfig,
release/llm_tests/batch/llm_single_node_benchmark_l4.yaml (new file; path inferred from the working_dir and cluster_compute entries in release_tests.yaml below)

Lines changed: 18 additions & 0 deletions

@@ -0,0 +1,18 @@
+# Single-node compute config for Ray Data LLM baseline benchmark
+# Instance: g6.xlarge (1x NVIDIA L4 GPU, 24GB VRAM)
+
+cloud_id: {{env["ANYSCALE_CLOUD_ID"]}}
+region: us-west-2
+
+head_node_type:
+  name: head_node
+  instance_type: m5.large
+  resources:
+    cpu: 0
+
+worker_node_types:
+  - name: worker_node
+    instance_type: g6.xlarge
+    min_workers: 1
+    max_workers: 1
+    use_spot: false
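
A quick back-of-envelope check (an editorial sketch, not from the commit) of why this single-L4 config suffices for the benchmark model: facebook/opt-1.3b at fp16 needs roughly 2 bytes per parameter for weights, a small fraction of the 24 GB VRAM noted above.

# Rough VRAM estimate for the config above; assumes fp16/bf16 weights.
params = 1.3e9                  # facebook/opt-1.3b parameter count
weights_gb = params * 2 / 1e9   # ~2.6 GB at 2 bytes per parameter
l4_vram_gb = 24                 # NVIDIA L4 on g6.xlarge
assert weights_gb < l4_vram_gb  # leaves headroom for KV cache and activations
print(f"weights ~= {weights_gb:.1f} GB of {l4_vram_gb} GB VRAM")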
release/llm_tests/batch/test_batch_single_node_vllm.py (new file; path inferred from the release-test entry below, whose script runs this test)

Lines changed: 148 additions & 0 deletions

@@ -0,0 +1,148 @@
+#!/usr/bin/env python
+"""
+Single-node vLLM baseline benchmark for Ray Data LLM batch inference.
+
+Measures throughput and supports env-driven thresholds and
+JSON artifact output.
+"""
+import json
+import os
+import sys
+
+import pytest
+
+import ray
+from ray.llm._internal.batch.benchmark.dataset import ShareGPTDataset
+from ray.llm._internal.batch.benchmark.benchmark_processor import (
+    Mode,
+    VLLM_SAMPLING_PARAMS,
+    benchmark,
+)
+
+
+# Benchmark constants
+NUM_REQUESTS = 1000
+MODEL_ID = "facebook/opt-1.3b"
+BATCH_SIZE = 64
+CONCURRENCY = 1
+
+
+@pytest.fixture(autouse=True)
+def disable_vllm_compile_cache(monkeypatch):
+    """Disable vLLM compile cache to avoid cache corruption."""
+    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
+
+
+@pytest.fixture(autouse=True)
+def cleanup_ray_resources():
+    """Cleanup Ray resources between tests."""
+    yield
+    ray.shutdown()
+
+
+def _get_float_env(name: str, default: float | None = None) -> float | None:
+    value = os.getenv(name)
+    if value is None or value == "":
+        return default
+    try:
+        return float(value)
+    except ValueError:
+        raise AssertionError(f"Invalid float for {name}: {value}")
+
+
+def test_single_node_baseline_benchmark():
+    """
+    Single-node baseline benchmark: facebook/opt-1.3b, TP=1, PP=1, 1000 prompts.
+
+    Logs BENCHMARK_* metrics and optionally asserts perf thresholds from env:
+    - RAY_DATA_LLM_BENCHMARK_MIN_THROUGHPUT (req/s)
+    - RAY_DATA_LLM_BENCHMARK_MAX_LATENCY_S (seconds)
+    Writes JSON artifact to RAY_LLM_BENCHMARK_ARTIFACT_PATH if set.
+    """
+    # Dataset setup
+    dataset_path = os.getenv(
+        "RAY_LLM_BENCHMARK_DATASET_PATH", "/tmp/ray_llm_benchmark_dataset"
+    )
+
+    dataset = ShareGPTDataset(
+        dataset_path=dataset_path,
+        seed=0,
+        hf_dataset_id="Crystalcareai/Code-feedback-sharegpt-renamed",
+        hf_split="train",
+        truncate_prompt=2048,
+    )
+
+    print(f"Loading {NUM_REQUESTS} prompts from ShareGPT dataset...")
+    prompts = dataset.sample(num_requests=NUM_REQUESTS)
+    print(f"Loaded {len(prompts)} prompts")
+
+    ds = ray.data.from_items(prompts)
+
+    # Benchmark config (single node, TP=1, PP=1)
+    print(
+        f"\nBenchmark: {MODEL_ID}, batch={BATCH_SIZE}, concurrency={CONCURRENCY}, TP=1, PP=1"
+    )
+
+    # Use benchmark processor to run a single-node vLLM benchmark
+    result = benchmark(
+        Mode.VLLM_ENGINE,
+        ds,
+        batch_size=BATCH_SIZE,
+        concurrency=CONCURRENCY,
+        model=MODEL_ID,
+        sampling_params=VLLM_SAMPLING_PARAMS,
+        pipeline_parallel_size=1,
+        tensor_parallel_size=1,
+        distributed_executor_backend="mp",
+    )
+
+    result.show()
+
+    # Assertions and metrics
+    assert result.samples == len(prompts)
+    assert result.throughput > 0
+
+    print("\n" + "=" * 60)
+    print("BENCHMARK METRICS")
+    print("=" * 60)
+    print(f"BENCHMARK_THROUGHPUT: {result.throughput:.4f} req/s")
+    print(f"BENCHMARK_LATENCY: {result.elapsed_s:.4f} s")
+    print(f"BENCHMARK_SAMPLES: {result.samples}")
+    print("=" * 60)
+
+    # Optional thresholds to fail on regressions
+    min_throughput = _get_float_env("RAY_DATA_LLM_BENCHMARK_MIN_THROUGHPUT", 5)
+    max_latency_s = _get_float_env("RAY_DATA_LLM_BENCHMARK_MAX_LATENCY_S", 120)
+    if min_throughput is not None:
+        assert (
+            result.throughput >= min_throughput
+        ), f"Throughput regression: {result.throughput:.4f} < {min_throughput:.4f} req/s"
+    if max_latency_s is not None:
+        assert (
+            result.elapsed_s <= max_latency_s
+        ), f"Latency regression: {result.elapsed_s:.4f} > {max_latency_s:.4f} s"
+
+    # Optional JSON artifact emission for downstream ingestion
+    artifact_path = os.getenv("RAY_LLM_BENCHMARK_ARTIFACT_PATH")
+    if artifact_path:
+        metrics = {
+            "model": MODEL_ID,
+            "batch_size": BATCH_SIZE,
+            "concurrency": CONCURRENCY,
+            "samples": int(result.samples),
+            "throughput_req_per_s": float(result.throughput),
+            "elapsed_s": float(result.elapsed_s),
+        }
+        try:
+            os.makedirs(os.path.dirname(artifact_path), exist_ok=True)
+            with open(artifact_path, "w", encoding="utf-8") as f:
+                json.dump(metrics, f, indent=2, sort_keys=True)
+            print(f"Wrote benchmark artifact to: {artifact_path}")
+        except Exception as e:  # noqa: BLE001
+            print(
+                f"Warning: failed to write benchmark artifact to {artifact_path}: {e}"
+            )
+
+
+if __name__ == "__main__":
+    sys.exit(pytest.main(["-v", "-s", __file__]))
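
The artifact keys above ("throughput_req_per_s", "elapsed_s", and so on) are what a downstream ingestion job would consume. A minimal sketch of such a consumer (hypothetical helper and paths; only the JSON schema comes from the test):

import json

def check_against_baseline(artifact_path, baseline_path, tolerance=0.9):
    """Fail if throughput drops below `tolerance` of a stored baseline."""
    with open(artifact_path, encoding="utf-8") as f:
        current = json.load(f)
    with open(baseline_path, encoding="utf-8") as f:
        baseline = json.load(f)
    floor = tolerance * baseline["throughput_req_per_s"]
    assert current["throughput_req_per_s"] >= floor, (
        f"Throughput regression: "
        f"{current['throughput_req_per_s']:.4f} < {floor:.4f} req/s"
    )

# e.g. check_against_baseline("/tmp/bench/metrics.json", "/tmp/bench/baseline.json")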

release/release_data_tests.yaml

Lines changed: 1 addition & 0 deletions (a blank separator line added before the Cross-AZ section)

@@ -647,6 +647,7 @@
     timeout: 5400
     script: python tpch_q1.py --sf 100
 
+
 #################################################
 # Cross-AZ RPC fault tolerance test
 #################################################

release/release_tests.yaml

Lines changed: 17 additions & 0 deletions

@@ -4105,6 +4105,23 @@
   script: >
     pytest -sv test_batch_multi_node_vllm.py
 
+- name: llm_batch_single_node_baseline_benchmark
+  group: llm-batch
+  working_dir: llm_tests/batch
+  frequency: weekly
+  team: llm
+
+  cluster:
+    byod:
+      runtime_env:
+        - VLLM_DISABLE_COMPILE_CACHE=1
+      type: gpu
+    cluster_compute: llm_single_node_benchmark_l4.yaml
+
+  run:
+    timeout: 3600
+    script: pytest -v test_batch_single_node_vllm.py::test_single_node_baseline_benchmark
+
 
 - name: text_embeddings_benchmark_{{scaling}}
   frequency: nightly
