
Co-Locating vLLM w/ training to achieve higher throughput and GPU utilization #3162


Open
wants to merge 26 commits into main
Commits (26): changes from all commits
9ce5478  Add vllm colocation (toslali-ibm, Mar 24, 2025)
e3a0734  Fix typo (toslali-ibm, Mar 24, 2025)
72697c2  Remove profiling (toslali-ibm, Mar 24, 2025)
6762037  Fix default dtype (toslali-ibm, Mar 24, 2025)
95ef38f  Remove profiling (toslali-ibm, Mar 24, 2025)
65b3bc6  Print for debugging (toslali-ibm, Mar 24, 2025)
4f6aa27  Fix bug - generate in all procs (toslali-ibm, Mar 24, 2025)
85f8f40  Fix guided decoding param (toslali-ibm, Mar 24, 2025)
845328e  Fix reset prefix caching (toslali-ibm, Mar 24, 2025)
20d6fef  Fix reset prefix caching (toslali-ibm, Mar 24, 2025)
852c48d  Merge remote-tracking branch 'origin/main' into coloc (toslali-ibm, Mar 25, 2025)
47aca04  Revert dtype (toslali-ibm, Mar 25, 2025)
fa14b21  Add timeout arg (toslali-ibm, Mar 25, 2025)
e0dfba8  Fix vllm client init (toslali-ibm, Mar 25, 2025)
0f98e5c  Remove lazy import (toslali-ibm, Mar 25, 2025)
837263c  Debugging client (toslali-ibm, Mar 25, 2025)
76ca767  Add dtype auto as default (toslali-ibm, Mar 25, 2025)
d96655f  Remove prints and set default for vllm params (toslali-ibm, Mar 25, 2025)
c3e58a3  Remove debug statements (toslali-ibm, Mar 25, 2025)
5b43f0f  have just 1 vllm_client.py (fabianlim, Mar 26, 2025)
5aa7882  move controls into vllm_client objects (fabianlim, Mar 26, 2025)
07c2b0f  Add comments and docstring (toslali-ibm, Mar 26, 2025)
df38dd2  Merge branch 'main' into coloc (toslali-ibm, Mar 26, 2025)
0551491  Remove hf overrides (toslali-ibm, Mar 27, 2025)
192bf38  Merge branch 'main' into coloc (toslali-ibm, Mar 31, 2025)
d3fc0d9  Merge branch 'main' into coloc (toslali-ibm, Apr 1, 2025)
222 changes: 217 additions & 5 deletions trl/extras/vllm_client.py
@@ -22,21 +22,78 @@

from ..import_utils import is_requests_available, is_vllm_available

from ..trainer.grpo_config import GRPOConfig

if is_requests_available():
import requests
from requests import ConnectionError


if is_vllm_available():
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.utils import StatelessProcessGroup
from vllm import SamplingParams, LLM
from vllm.sampling_params import GuidedDecodingParams

from accelerate import Accelerator
from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed

logger = logging.getLogger(__name__)

class VLLMNoOpClient:
"""
A no-op vLLM client used in distributed training when the process is neither the main process
nor running in vLLM colocation mode.

This stub client ensures compatibility in distributed setups without performing actual
inference or model updates.

Methods like `generate` and `update_named_param` are implemented as no-ops or return default
values to maintain consistent interfaces across processes.

This class should only be used internally by `get_vllm_client`.
"""

def __init__(self, process_index: int):
self.process_index = process_index

def generate(
self,
prompts: list[str],
n: int = 1,
repetition_penalty: float = 1.0,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
min_p: float = 0.0,
max_tokens: int = 16,
guided_decoding_regex: Optional[str] = None,
) -> list[list[str]]:
orig_size = len(prompts)
prompts = gather_object(prompts)
completion_ids = [None] * len(prompts)
return self._broadcast_and_slice(completion_ids, orig_size)

def update_named_param(self, name: str, weights: torch.Tensor):
pass

def reset_prefix_cache(self):
pass

def _gather(self, prompts):
return gather_object(prompts)

def _broadcast_and_slice(self, completion_ids: list, slice_size: int):
# Broadcast the completions from the main process to all processes, ensuring each process receives its
# corresponding slice

completion_ids = broadcast_object_list(completion_ids, from_process=0)
process_slice = slice(
self.process_index * slice_size,
(self.process_index + 1) * slice_size,
)
return completion_ids[process_slice]

class VLLMClient(VLLMNoOpClient):
"""
A client class to interact with a vLLM server.

@@ -80,13 +137,17 @@ class VLLMClient:
"""

def __init__(
self, host: str = "0.0.0.0", server_port: int = 8000, group_port: int = 51216, connection_timeout: float = 0.0,
distributed: bool = False
):
super().__init__(process_index=0)

if not is_requests_available():
raise ImportError("requests is not installed. Please install it with `pip install requests`.")
if not is_vllm_available():
raise ImportError("vLLM is not installed. Please install it with `pip install vllm`.")

self.distributed = distributed
self.session = requests.Session()
self.host = host
self.server_port = server_port
@@ -168,6 +229,16 @@ def generate(
`list[list[int]]`:
List of lists of token IDs representing the model-generated completions for each prompt.
"""

if self.distributed:
orig_size = len(prompts)
prompts = self._gather(prompts)

# Since 'prompts' contains each prompt repeated 'n' times consecutively, keep only every n-th
# prompt and let the server generate 'n' outputs for each unique prompt. This is faster than
# generating outputs for each duplicated prompt individually.
prompts = prompts[::n]

url = f"http://{self.host}:{self.server_port}/generate/"
response = self.session.post(
url,
@@ -184,10 +255,15 @@
},
)
if response.status_code == 200:
completion_ids = response.json()["completion_ids"]
else:
raise Exception(f"Request failed: {response.status_code}, {response.text}")

if self.distributed:
completion_ids = self._broadcast_and_slice(completion_ids, orig_size)

return completion_ids

def init_communicator(self):
"""
Initializes the weight update group in a distributed setup for model synchronization.
@@ -264,8 +340,143 @@ def close_communicator(self):
if response.status_code != 200:
raise Exception(f"Request failed: {response.status_code}, {response.text}")

class VLLMColocationClient:
"""
A client class to interact with vLLM processes colocated with the training process.

This client bypasses remote communication and directly interacts with the in-process vLLM engine.
It supports weight updates and text generation functionalities similar to `VLLMClient`, but is optimized
for scenarios where vLLM is running in the same process or node as training.

Args:
args (`GRPOConfig`): Configuration object containing vLLM parameters.
model (`transformers.PreTrainedModel`): The model being used.
vllm_device (`torch.device` or `str`): Device on which the model is loaded (e.g., "cuda:0").
"""

def __init__(self, args: GRPOConfig, model, vllm_device):
self.args: GRPOConfig = args
self.model = model
self.vllm_device = vllm_device

self.llm = LLM(
model=self.model.name_or_path,
device=self.vllm_device,
gpu_memory_utilization=self.args.vllm_gpu_memory_utilization,
dtype=self.args.vllm_dtype,
enable_prefix_caching=self.args.vllm_enable_prefix_caching,
max_model_len=self.args.vllm_max_model_len,
distributed_executor_backend="external_launcher",
)

def update_named_param(self, name: str, weights: torch.Tensor):
"""
Updates a specific named parameter in the model.

Args:
name (`str`):
Name of the layer whose weights are being updated.
weights (`torch.Tensor`):
Tensor containing the updated weights.
"""
llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
llm_model.load_weights([(name, weights)])

def generate(
self,
prompts: list[str],
n: int = 1,
repetition_penalty: float = 1.0,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
min_p: float = 0.0,
max_tokens: int = 16,
guided_decoding_regex: Optional[str] = None,
) -> list[list[str]]:
"""
Generates model completions for the provided prompts.

Args:
prompts (`list[str]`):
List of text prompts for which the model will generate completions.
n (`int`, *optional*, defaults to `1`):
Number of completions to generate for each prompt.
repetition_penalty (`float`, *optional*, defaults to `1.0`):
Parameter for repetition penalty. 1.0 means no penalty.
temperature (`float`, *optional*, defaults to `1.0`):
Temperature parameter for sampling. Higher values increase diversity.
top_p (`float`, *optional*, defaults to `1.0`):
Top-p sampling parameter. `1.0` means no truncation.
top_k (`int`, *optional*, defaults to `-1`):
Top-k sampling parameter. `-1` means no truncation.
min_p (`float`, *optional*, defaults to `0.0`):
Minimum probability for sampling.
max_tokens (`int`, *optional*, defaults to `16`):
Maximum number of tokens to generate for each prompt.
guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`):
Regular expression to guide the decoding process.

Returns:
`list[list[int]]`:
List of lists of token IDs representing the model-generated completions for each prompt.
"""
# Guided decoding, if enabled
if guided_decoding_regex is not None:
guided_decoding = GuidedDecodingParams(backend="outlines", regex=guided_decoding_regex)
else:
guided_decoding = None

sampling_params = SamplingParams(
n=1,  # in vllm_colocation mode, the vLLM engine on each GPU generates only 1 completion per prompt
repetition_penalty=repetition_penalty,
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
max_tokens=max_tokens,
guided_decoding=guided_decoding,
)

all_outputs = self.llm.generate(
prompts, sampling_params=sampling_params, use_tqdm=False
)
completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs]
return completion_ids

def reset_prefix_cache(self):
"""
Resets the prefix cache for the model.
"""
self.llm.reset_prefix_cache()

def get_vllm_client(args: GRPOConfig, model, accelerator: Accelerator) -> VLLMNoOpClient:
"""
Returns the appropriate vLLM client based on the current configuration.

This function acts as a proxy to initialize and return the correct vLLM client type:
- If colocation is enabled, it returns `VLLMColocationClient`, which interacts directly with
the colocated vLLM process for faster integration.
- If running in the main process (non-colocated mode), it returns `VLLMClient`, which communicates
with an external vLLM server.
- If not the main process and colocation is disabled, it returns a base client (`VLLMNoOpClient`)
for compatibility in distributed settings.

Args:
args (`GRPOConfig`): Configuration object containing flags for colocation, server host, port, etc.
model (`transformers.PreTrainedModel`): The model to use, passed only for the colocated client.
accelerator (`Accelerator`): Hugging Face `Accelerator` object that helps with multi-GPU training.
"""
if args.vllm_colocation:
return VLLMColocationClient(args, model, accelerator.device)
elif accelerator.is_main_process:
return VLLMClient(
args.vllm_server_host, args.vllm_server_port, connection_timeout=args.vllm_server_timeout,
distributed=accelerator.num_processes > 1,
)
return VLLMNoOpClient(accelerator.process_index)

# Example usage for VLLMClient
if __name__ == "__main__":
from vllm import SamplingParams

@@ -280,3 +491,4 @@ def close_communicator(self):

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B").to("cuda")
client.update_model_params(model)
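The factory above gives every training process a client with the same interface, so trainer code does not need to branch on rank or colocation mode. A rough usage sketch follows (not part of this diff); the model name, prompts, and sampling values are placeholders, and it assumes the script is launched with `accelerate launch` so each process owns one GPU.

```python
# Hypothetical usage sketch for the client factory introduced in this PR.
# Assumes `accelerate launch train.py` with one process per GPU and vllm_colocation=True.
from accelerate import Accelerator
from transformers import AutoModelForCausalLM

from trl import GRPOConfig
from trl.extras.vllm_client import get_vllm_client

accelerator = Accelerator()
args = GRPOConfig(output_dir="out", use_vllm=True, vllm_colocation=True)
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B").to(accelerator.device)

# Colocated mode: every process gets a VLLMColocationClient bound to its own device.
# With vllm_colocation=False, only the main process would talk to an external vLLM server
# (VLLMClient); the other ranks would receive the no-op stub.
client = get_vllm_client(args, model, accelerator)

# GRPO-style batch: each prompt is repeated n times before calling generate.
n = 2
prompts = ["The capital of France is"] * n + ["1 + 1 ="] * n
completion_ids = client.generate(prompts, n=n, temperature=0.7, max_tokens=32)

# After an optimizer step, push the updated weights into the inference engine and
# reset the prefix cache so stale KV blocks are not reused on the next generation round.
for name, param in model.named_parameters():
    client.update_named_param(name, param.data)
client.reset_prefix_cache()
```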

18 changes: 15 additions & 3 deletions trl/trainer/grpo_config.py
@@ -90,6 +90,10 @@ class GRPOConfig(TrainingArguments):
timeout, a `ConnectionError` is raised.
vllm_guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`):
Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled.
vllm_colocation (`bool`, *optional*, defaults to `False`):
Whether to use colocated vLLM execution via external launcher. If set to `True`, vLLM will be
initialized in **all processes**, each assigned to its respective device. This allows multi-GPU
or multi-node execution with vLLM's external launcher, enabling improved large-scale inference.

> Parameters that control the training

@@ -250,6 +254,14 @@ class GRPOConfig(TrainingArguments):
default=None,
metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
)
vllm_colocation: Optional[bool] = field(
default=False,
metadata={
"help": "Whether to use colocated vLLM execution via external launcher. If set to `True`, vLLM will be "
"initialized in all processes, each assigned to its respective device. This enables optimized "
"multi-GPU inference."
},
)

# Parameters that control the training
learning_rate: float = field(
@@ -349,15 +361,15 @@ class GRPOConfig(TrainingArguments):
},
)
vllm_gpu_memory_utilization: Optional[float] = field(
default=0.3,
metadata={
"help": "This parameter is deprecated and will be removed in version 0.18.0. To control the GPU memory "
"utilization for vLLM, you should now use the `gpu_memory_utilization` parameter in the vLLM server "
"configuration."
},
)
vllm_dtype: Optional[str] = field(
default="auto",
metadata={
"help": "This parameter is deprecated and will be removed in version 0.18.0. To control the data type for "
"vLLM generation, you should now use the `dtype` parameter in the vLLM server configuration."
@@ -372,7 +384,7 @@
},
)
vllm_enable_prefix_caching: Optional[bool] = field(
default=False,
metadata={
"help": "This parameter is deprecated and will be removed in version 0.18.0. To control prefix caching in "
"vLLM, you should now use the `enable_prefix_caching` parameter in the vLLM server configuration."
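For completeness, a minimal training-script sketch that switches the new flag on could look like the following (again, not part of this diff). The dataset, reward function, and model name are placeholders, and it assumes `GRPOTrainer` routes generation through `get_vllm_client` as wired up elsewhere in this PR.

```python
# Hypothetical GRPO training script with colocated vLLM generation enabled.
# Launch with one process per GPU, e.g.: accelerate launch --num_processes 8 train_grpo.py
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

def reward_len(completions, **kwargs):
    # Toy reward: prefer completions close to 20 characters.
    return [-abs(20 - len(c)) for c in completions]

dataset = load_dataset("trl-lib/tldr", split="train")

training_args = GRPOConfig(
    output_dir="qwen-grpo-coloc",
    use_vllm=True,
    vllm_colocation=True,              # new flag from this PR
    vllm_gpu_memory_utilization=0.3,   # leave most of each GPU free for training
    vllm_dtype="auto",
    vllm_enable_prefix_caching=False,
)

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    args=training_args,
    reward_funcs=reward_len,
    train_dataset=dataset,
)
trainer.train()
```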