27 changes: 21 additions & 6 deletions src/guidellm/backends/openai.py
@@ -16,6 +16,7 @@
import json
import time
from collections.abc import AsyncIterator
from itertools import chain
from pathlib import Path
from typing import Any, ClassVar, Optional, Union

@@ -29,7 +30,7 @@
GenerationRequestTimings,
GenerationResponse,
)
from guidellm.scheduler import ScheduledRequestInfo
from guidellm.scheduler import HistoryT, ScheduledRequestInfo

__all__ = ["OpenAIHTTPBackend", "UsageStats"]

@@ -280,7 +281,7 @@ async def resolve(
self,
request: GenerationRequest,
request_info: ScheduledRequestInfo,
history: Optional[list[tuple[GenerationRequest, GenerationResponse]]] = None,
history: Optional[HistoryT[GenerationRequest, GenerationResponse]] = None,
) -> AsyncIterator[tuple[GenerationResponse, ScheduledRequestInfo]]:
"""
Process a generation request and yield progressive responses.
@@ -295,10 +296,8 @@ async def resolve(
:yields: Tuples of (response, updated_request_info) as generation progresses.
"""
self._check_in_process()
if history is not None:
raise NotImplementedError(
"Multi-turn requests with conversation history are not yet supported"
)
if history:
request = self._apply_history(request, history)

response = GenerationResponse(
request_id=request.request_id,
@@ -500,6 +499,22 @@ async def chat_completions(
self._get_completions_usage_stats(data),
)

def _apply_history(
self,
request: GenerationRequest,
history: HistoryT[GenerationRequest, GenerationResponse],
) -> GenerationRequest:
"""
Apply conversation history to the current request.
"""

def turn_to_text(turn: tuple[GenerationRequest, GenerationResponse]) -> str:
req, res = turn
return f"{req.content}{res.value}"

request.content = "".join(chain(map(turn_to_text, history), (request.content,)))
return request

Comment on lines +502 to +517
Temporary hack until we land request templates.

def _build_headers(
self,
api_key: Optional[str],
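For context on the new history handling, here is a minimal sketch of what `_apply_history` does, using stand-in dataclasses since the real `GenerationRequest`/`GenerationResponse` constructors are not shown in this diff:

```python
# Minimal sketch of the _apply_history flattening, assuming HistoryT behaves like a
# sequence of (request, response) tuples and that content/value are plain strings.
from dataclasses import dataclass
from itertools import chain


@dataclass
class Req:  # stand-in for GenerationRequest
    content: str


@dataclass
class Res:  # stand-in for GenerationResponse
    value: str


history = [(Req("What is 2+2?"), Res(" 4.")), (Req(" And 3+3?"), Res(" 6."))]
request = Req(" Now 4+4?")

# Each prior turn is flattened to "<prompt><completion>"; the new prompt goes last.
request.content = "".join(
    chain((f"{req.content}{res.value}" for req, res in history), (request.content,))
)
print(request.content)  # What is 2+2? 4. And 3+3? 6. Now 4+4?
```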
103 changes: 71 additions & 32 deletions src/guidellm/dataset/synthetic.py
@@ -3,7 +3,7 @@
from collections.abc import Iterable, Iterator
from itertools import cycle
from pathlib import Path
from typing import Any, Literal, Optional, Union
from typing import Any, Optional, TypedDict, Union

import yaml
from datasets import (
@@ -69,6 +69,26 @@ class SyntheticDatasetConfig(BaseModel):
gt=0,
default=None,
)
turns: int = Field(
description="The number of turns in the conversation.",
gt=0,
default=1,
)
turns_stdev: Optional[int] = Field(
description="The standard deviation of the number of turns.",
gt=0,
default=None,
)
turns_min: Optional[int] = Field(
description="The minimum number of turns in the conversation.",
gt=0,
default=None,
)
turns_max: Optional[int] = Field(
description="The maximum number of turns in the conversation.",
gt=0,
default=None,
)
samples: int = Field(
description="The number of samples to generate for the dataset.",
gt=0,
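A hedged example of how the new turn knobs combine with the existing token settings when building a config programmatically (field names follow this diff; anything else keeps its defaults):

```python
# Sketch only: prompt_tokens, output_tokens, and samples are the fields referenced
# elsewhere in this diff; any other required fields are assumed to have defaults.
from guidellm.dataset.synthetic import SyntheticDatasetConfig

config = SyntheticDatasetConfig(
    prompt_tokens=256,   # average prompt length per turn
    output_tokens=128,   # average completion length per turn
    samples=500,         # number of conversations to generate
    turns=3,             # new: average turns per conversation
    turns_stdev=1,       # new: spread around that average
    turns_min=1,         # new: clamp on the sampled turn count
    turns_max=5,
)
```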
Expand Down Expand Up @@ -124,14 +144,13 @@ def parse_config_file(data: Union[str, Path]) -> "SyntheticDatasetConfig":
return SyntheticDatasetConfig(**config_dict)


class SyntheticTextItemsGenerator(
Iterable[
dict[
Literal["prompt", "prompt_tokens_count", "output_tokens_count"],
Union[str, int],
]
]
):
class SyntheticDatasetRow(TypedDict):
prompt: list[str]
prompt_tokens_count: list[int]
output_tokens_count: list[int]


class SyntheticTextItemsGenerator(Iterable[SyntheticDatasetRow]):
def __init__(
self,
config: SyntheticDatasetConfig,
@@ -147,12 +166,7 @@ def __init__(

def __iter__(
self,
) -> Iterator[
dict[
Literal["prompt", "prompt_tokens_count", "output_tokens_count"],
Union[str, int],
]
]:
) -> Iterator[SyntheticDatasetRow]:
prompt_tokens_sampler = IntegerRangeSampler(
average=self.config.prompt_tokens,
variance=self.config.prompt_tokens_stdev,
@@ -167,31 +181,56 @@ def __iter__(
max_value=self.config.output_tokens_max,
random_seed=self.random_seed + 1, # ensure diff dist from prompts
)
turns_sampler = IntegerRangeSampler(
average=self.config.turns,
variance=self.config.turns_stdev,
min_value=self.config.turns_min,
max_value=self.config.turns_max,
random_seed=self.random_seed + 7, # ensure diff dist
)
# ensure diff distribution from output tokens
rand = random.Random(self.random_seed + 2) # noqa: S311
unique_prefix_iter = cycle(self.processor.get_vocab().values())

prefix_index = rand.randint(0, len(self.text_creator.words))
prefix_tokens = self._create_prompt(self.config.prefix_tokens, prefix_index)

for _, prompt_tokens, output_tokens in zip(
range(self.config.samples),
prompt_tokens_sampler,
output_tokens_sampler,
):
start_index = rand.randint(0, len(self.text_creator.words))
prompt_text = self.processor.decode(
prefix_tokens
+ self._create_prompt(
prompt_tokens, start_index, next(unique_prefix_iter)
),
skip_special_tokens=True,
)
yield {
"prompt": prompt_text,
"prompt_tokens_count": self.config.prefix_tokens + prompt_tokens,
"output_tokens_count": output_tokens,
for _, turns in zip(range(self.config.samples), turns_sampler):
row: SyntheticDatasetRow = {
"prompt": [],
"prompt_tokens_count": [],
"output_tokens_count": [],
}
for i, prompt_tokens, output_tokens in zip(
range(turns),
prompt_tokens_sampler,
output_tokens_sampler,
):
start_index = rand.randint(0, len(self.text_creator.words))
# Append the prefix tokens only for the first turn
if i == 0:
prompt_text = self.processor.decode(
prefix_tokens
+ self._create_prompt(
prompt_tokens, start_index, next(unique_prefix_iter)
),
skip_special_tokens=True,
)
row["prompt"].append(prompt_text)
row["prompt_tokens_count"].append(self.config.prefix_tokens + prompt_tokens)
row["output_tokens_count"].append(output_tokens)
else:
prompt_text = self.processor.decode(
self._create_prompt(
prompt_tokens, start_index, next(unique_prefix_iter)
),
skip_special_tokens=True,
)
row["prompt"].append(prompt_text)
row["prompt_tokens_count"].append(prompt_tokens)
row["output_tokens_count"].append(output_tokens)

yield row

def _create_prompt(
self, prompt_tokens: int, start_index: int, unique_prefix: Optional[int] = None
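Put together, each sample the generator yields is now one conversation rather than one prompt; an illustrative row (values made up) looks like:

```python
# Illustrative SyntheticDatasetRow for a two-turn conversation with prefix_tokens=128:
# parallel lists, one entry per turn, with the shared prefix counted only in turn 0.
row = {
    "prompt": ["<prefix + first-turn prompt text>", "<second-turn prompt text>"],
    "prompt_tokens_count": [128 + 256, 256],
    "output_tokens_count": [64, 64],
}
```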
60 changes: 41 additions & 19 deletions src/guidellm/request/loader.py
@@ -105,14 +105,14 @@ def __init__(
self.preserve_iter_state = iter_type == "infinite" # ensure no caching requests
self._preserved_iter = None

def __iter__(self) -> Iterator[GenerationRequest]:
def __iter__(self) -> Iterator[list[tuple[GenerationRequest, float]]]:
scope_create_count = 0

while (dataset_iter := self._get_dataset_iter(scope_create_count)) is not None:
scope_create_count += 1

for item in dataset_iter:
yield self._create_request(item)
yield self._create_requests(item)

self._preserved_iter = None

@@ -260,25 +260,47 @@ def _get_dataset_iter(

return dataset_iter

def _create_request(self, item: dict[str, Any]) -> GenerationRequest:
prompt_tokens = (
item[self.column_mappings["prompt_tokens_count_column"]]
def _create_requests(
self, item: dict[str, Any]
) -> list[tuple[GenerationRequest, float]]:
prompts = list(item[self.column_mappings["prompt_column"]])
prompts_tokens: list[Optional[int]] = (
list(item[self.column_mappings["prompt_tokens_count_column"]])
if "prompt_tokens_count_column" in self.column_mappings
else None
else [None] * len(prompts)
)
output_tokens = (
item[self.column_mappings["output_tokens_count_column"]]
outputs_tokens: list[Optional[int]] = (
list(item[self.column_mappings["output_tokens_count_column"]])
if "output_tokens_count_column" in self.column_mappings
else None
else [None] * len(prompts)
)

return GenerationRequest(
request_type=settings.preferred_route,
content=item[self.column_mappings["prompt_column"]],
stats=(
{"prompt_tokens": prompt_tokens} if prompt_tokens is not None else {}
),
constraints=(
{"output_tokens": output_tokens} if output_tokens is not None else {}
),
)
if not (len(prompts) == len(prompts_tokens) == len(outputs_tokens)):
raise ValueError(
"Mismatched lengths between prompts and token counts. "
f"Prompts: {len(prompts)}, Prompt Tokens: {len(prompts_tokens)}, "
f"Output Tokens: {len(outputs_tokens)}"
)

return [
(
GenerationRequest(
request_type=settings.preferred_route,
content=prompt,
stats=(
{"prompt_tokens": prompt_tokens}
if prompt_tokens is not None
else {}
),
constraints=(
{"output_tokens": output_tokens}
if output_tokens is not None
else {}
),
),
0.0, # TODO: delay
)
for prompt, prompt_tokens, output_tokens in zip(
prompts, prompts_tokens, outputs_tokens
)
]
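The loader therefore changes its iteration contract: each yielded item is a whole conversation, a list of `(GenerationRequest, delay)` pairs, with the delay pinned to `0.0` until the TODO above lands. A rough consumption sketch with stand-in values:

```python
# Stand-in strings are used in place of real GenerationRequest objects; the point is
# the shape: Iterator[list[tuple[request, float]]] instead of Iterator[request].
def consume(loader):
    for conversation in loader:                      # one list per dataset item
        for turn_index, (request, delay) in enumerate(conversation):
            print(turn_index, delay, request)


consume([[("first turn prompt", 0.0), ("second turn prompt", 0.0)]])
```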
10 changes: 8 additions & 2 deletions src/guidellm/scheduler/__init__.py
@@ -15,11 +15,14 @@
from .objects import (
BackendInterface,
BackendT,
DatasetIterT,
HistoryT,
MeasuredRequestTimings,
MultiTurnRequestT,
RequestDataT,
RequestSchedulerTimings,
RequestT,
ResponseT,
ScheduledRequestAugmentation,
ScheduledRequestInfo,
SchedulerMessagingPydanticRegistry,
SchedulerState,
@@ -55,22 +58,25 @@
"Constraint",
"ConstraintInitializer",
"ConstraintsInitializerFactory",
"DatasetIterT",
"Environment",
"HistoryT",
"LastCompletionRequestTimings",
"MaxDurationConstraint",
"MaxErrorRateConstraint",
"MaxErrorsConstraint",
"MaxGlobalErrorRateConstraint",
"MaxNumberConstraint",
"MeasuredRequestTimings",
"MultiTurnRequestT",
"NoDelayRequestTimings",
"NonDistributedEnvironment",
"PoissonRateRequestTimings",
"PydanticConstraintInitializer",
"RequestDataT",
"RequestSchedulerTimings",
"RequestT",
"ResponseT",
"ScheduledRequestAugmentation",
"ScheduledRequestInfo",
"ScheduledRequestTimings",
"Scheduler",
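The shapes of the newly exported names are not defined in this diff; inferring from how they are used in the backend and loader changes, a rough reading is:

```python
# Assumed shapes, inferred from usage elsewhere in this diff; the real definitions
# live in guidellm/scheduler/objects.py and may differ.
#
#   HistoryT[RequestT, ResponseT]  ~ sequence of (request, response) tuples, one per
#                                    completed turn (see _apply_history above).
#   DatasetIterT[RequestT]         ~ iterable of conversations, matching the loader's
#                                    Iterator[list[tuple[GenerationRequest, float]]].
#   ScheduledRequestAugmentation   ~ new per-request scheduling metadata; its fields
#                                    are not visible here.
from guidellm.scheduler import DatasetIterT, HistoryT, ScheduledRequestAugmentation
```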
16 changes: 8 additions & 8 deletions src/guidellm/scheduler/environments.py
@@ -19,14 +19,14 @@

import time
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator, Iterable
from collections.abc import AsyncIterator
from typing import (
Generic,
)

from guidellm.scheduler.constraints import Constraint
from guidellm.scheduler.objects import (
MultiTurnRequestT,
DatasetIterT,
RequestT,
ResponseT,
ScheduledRequestInfo,
@@ -52,11 +52,11 @@ class Environment(ABC, Generic[RequestT, ResponseT], InfoMixin):
@abstractmethod
async def sync_run_params(
self,
requests: Iterable[RequestT | MultiTurnRequestT[RequestT]],
requests: DatasetIterT[RequestT],
strategy: SchedulingStrategy,
constraints: dict[str, Constraint],
) -> tuple[
Iterable[RequestT | MultiTurnRequestT[RequestT]],
DatasetIterT[RequestT],
SchedulingStrategy,
dict[str, Constraint],
]:
@@ -130,7 +130,7 @@ async def sync_run_end(
) -> AsyncIterator[
tuple[
ResponseT,
RequestT | MultiTurnRequestT[RequestT],
RequestT,
ScheduledRequestInfo,
SchedulerState,
]
@@ -194,11 +194,11 @@ def __init__(self):

async def sync_run_params(
self,
requests: Iterable[RequestT | MultiTurnRequestT[RequestT]],
requests: DatasetIterT[RequestT],
strategy: SchedulingStrategy,
constraints: dict[str, Constraint],
) -> tuple[
Iterable[RequestT | MultiTurnRequestT[RequestT]],
DatasetIterT[RequestT],
SchedulingStrategy,
dict[str, Constraint],
]:
@@ -250,7 +250,7 @@ async def sync_run_end(
) -> AsyncIterator[
tuple[
ResponseT,
RequestT | MultiTurnRequestT[RequestT],
RequestT,
ScheduledRequestInfo,
SchedulerState,
]
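For the single-node case the change is purely a type narrowing; a hedged sketch of the assumed pass-through behaviour (the real `NonDistributedEnvironment` body is outside this diff's context):

```python
class SingleNodeEnvironmentSketch:
    """Illustration only: assumed pass-through for a non-distributed run."""

    async def sync_run_params(self, requests, strategy, constraints):
        # Single process: nothing to negotiate across workers, so the dataset
        # iterator (DatasetIterT[RequestT]), strategy, and constraints are
        # returned unchanged (assumption, not taken from this diff).
        return requests, strategy, constraints
```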