67 changes: 57 additions & 10 deletions gpt_oss/evals/__main__.py
@@ -12,6 +12,7 @@
     ChatCompletionsSampler,
 )
 from .responses_sampler import ResponsesSampler
+from .harmony_sampler import HarmonySampler


 def main():
@@ -34,9 +35,9 @@ def main():
     parser.add_argument(
         "--sampler",
         type=str,
-        choices=["responses", "chat_completions"],
+        choices=["responses", "chat_completions", "harmony"],
         default="responses",
-        help="Sampler backend to use for models.",
+        help="Sampler backend to use for models. 'harmony' uses openai_harmony tokenization with the SGLang /generate endpoint.",
     )
     parser.add_argument(
         "--base-url",
@@ -56,6 +57,24 @@ def main():
         default=1.0,
         help="Sampling temperature",
     )
+    parser.add_argument(
+        "--top-p",
+        type=float,
+        default=None,
+        help="Top-p (nucleus) sampling parameter",
+    )
+    parser.add_argument(
+        "--top-k",
+        type=int,
+        default=None,
+        help="Top-k sampling parameter (SGLang/vLLM specific)",
+    )
+    parser.add_argument(
+        "--max-tokens",
+        type=int,
+        default=32768,
+        help="Maximum number of output tokens",
+    )
     parser.add_argument(
         "--n-threads",
         type=int,
@@ -68,22 +87,45 @@ def main():
     parser.add_argument(
         "--examples", type=int, help="Number of examples to use (overrides default)"
     )
+    parser.add_argument(
+        "--n-repeats",
+        type=int,
+        default=None,
+        help="Number of repeats per example (default: 1 in debug mode, 8 otherwise)",
+    )
+    parser.add_argument(
+        "--dump-inputs",
+        type=str,
+        default=None,
+        help="Directory to dump input tokens to JSON files (harmony sampler only)",
+    )

     args = parser.parse_args()

-    sampler_cls = ResponsesSampler if args.sampler == "responses" else ChatCompletionsSampler
+    if args.sampler == "responses":
+        sampler_cls = ResponsesSampler
+    elif args.sampler == "chat_completions":
+        sampler_cls = ChatCompletionsSampler
+    else:  # harmony
+        sampler_cls = HarmonySampler

     models = {}
     for model_name in args.model.split(","):
         for reasoning_effort in args.reasoning_effort.split(","):
-            models[f"{model_name}-{reasoning_effort}"] = sampler_cls(
+            sampler_kwargs = dict(
                 model=model_name,
                 reasoning_model=True,
                 reasoning_effort=reasoning_effort,
                 temperature=args.temperature,
+                top_p=args.top_p,
+                top_k=args.top_k,
                 base_url=args.base_url,
-                max_tokens=131_072,
+                max_tokens=args.max_tokens,
             )
+            # Add dump_inputs_dir for harmony sampler
+            if args.sampler == "harmony" and args.dump_inputs:
+                sampler_kwargs["dump_inputs_dir"] = args.dump_inputs
+            models[f"{model_name}-{reasoning_effort}"] = sampler_cls(**sampler_kwargs)

     print(f"Running with args {args}")

@@ -98,13 +140,18 @@ def get_evals(eval_name, debug_mode):
         num_examples = (
             args.examples if args.examples is not None else (5 if debug_mode else None)
         )
+        # Determine n_repeats: use --n-repeats if provided, else 1 for debug, else 8
+        if args.n_repeats is not None:
+            n_repeats = args.n_repeats
+        else:
+            n_repeats = 1 if debug_mode else 8
         # Set num_examples = None to reproduce full evals
         match eval_name:
             case "basic":
                 return BasicEval()
             case "gpqa":
                 return GPQAEval(
-                    n_repeats=1 if args.debug else 8,
+                    n_repeats=n_repeats,
                     num_examples=num_examples,
                     debug=debug_mode,
                     n_threads=args.n_threads or 1,
@@ -113,29 +160,29 @@ def get_evals(eval_name, debug_mode):
                 return HealthBenchEval(
                     grader_model=grading_sampler,
                     num_examples=10 if debug_mode else num_examples,
-                    n_repeats=1,
+                    n_repeats=n_repeats,
                     n_threads=args.n_threads or 1,
                     subset_name=None,
                 )
             case "healthbench_hard":
                 return HealthBenchEval(
                     grader_model=grading_sampler,
                     num_examples=10 if debug_mode else num_examples,
-                    n_repeats=1,
+                    n_repeats=n_repeats,
                     n_threads=args.n_threads or 1,
                     subset_name="hard",
                 )
             case "healthbench_consensus":
                 return HealthBenchEval(
                     grader_model=grading_sampler,
                     num_examples=10 if debug_mode else num_examples,
-                    n_repeats=1,
+                    n_repeats=n_repeats,
                     n_threads=args.n_threads or 1,
                     subset_name="consensus",
                 )
             case "aime25":
                 return AIME25Eval(
-                    n_repeats=1 if args.debug else 8,
+                    n_repeats=n_repeats,
                     num_examples=num_examples,
                     n_threads=args.n_threads or 1,
                 )
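With these additions, every new sampling control is reachable from the CLI. A hypothetical invocation (the model name, server URL, and flag values below are illustrative only, not taken from this PR) could look like:

`python -m gpt_oss.evals --model gpt-oss-120b --sampler harmony --base-url http://localhost:30000 --top-p 0.95 --top-k 40 --max-tokens 16384 --n-repeats 4 --dump-inputs /tmp/harmony_inputs`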
13 changes: 13 additions & 0 deletions gpt_oss/evals/chat_completions_sampler.py
@@ -26,6 +26,8 @@ def __init__(
         reasoning_model: bool = False,
         reasoning_effort: str | None = None,
         base_url: str = "http://localhost:8000/v1",
+        top_p: float | None = None,
+        top_k: int | None = None,
     ):
         self.client = OpenAI(base_url=base_url, timeout=24 * 60 * 60)
         self.model = model
@@ -35,6 +37,8 @@ def __init__(
         self.reasoning_model = reasoning_model
         self.reasoning_effort = reasoning_effort
         self.image_format = "url"
+        self.top_p = top_p
+        self.top_k = top_k

     def _pack_message(self, role: str, content: Any) -> dict[str, Any]:
         return {"role": str(role), "content": content}
@@ -47,20 +51,29 @@ def __call__(self, message_list: MessageList) -> SamplerResponse:
         trial = 0
         while True:
             try:
+                # Build extra kwargs for optional sampling parameters
+                extra_kwargs = {}
+                if self.top_p is not None:
+                    extra_kwargs["top_p"] = self.top_p
+                if self.top_k is not None:
+                    extra_kwargs["extra_body"] = {"top_k": self.top_k}
+
                 if self.reasoning_model:
                     response = self.client.chat.completions.create(
                         model=self.model,
                         messages=message_list,
                         reasoning_effort=self.reasoning_effort,
                         temperature=self.temperature,
                         max_tokens=self.max_tokens,
+                        **extra_kwargs,
                     )
                 else:
                     response = self.client.chat.completions.create(
                         model=self.model,
                         messages=message_list,
                         temperature=self.temperature,
                         max_tokens=self.max_tokens,
+                        **extra_kwargs,
                     )

                 choice = response.choices[0]
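For reviewers who want to exercise the new sampler options outside the eval harness, below is a minimal standalone sketch of the `extra_body` pattern this file now uses. It assumes the official `openai` Python SDK pointed at an OpenAI-compatible server such as SGLang or vLLM; the base URL, model name, and sampling values are illustrative assumptions, not part of this PR.

```python
# Minimal sketch of the extra_body pattern introduced in this PR.
# The base URL, model name, and sampling values are illustrative assumptions.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")


def sample(prompt: str, top_p: float | None = None, top_k: int | None = None) -> str:
    extra_kwargs = {}
    if top_p is not None:
        extra_kwargs["top_p"] = top_p  # first-class OpenAI parameter
    if top_k is not None:
        # top_k is not part of the OpenAI API surface; servers such as
        # SGLang and vLLM read it from the raw request body, so the SDK
        # forwards it via extra_body.
        extra_kwargs["extra_body"] = {"top_k": top_k}
    response = client.chat.completions.create(
        model="gpt-oss-120b",  # hypothetical model name
        messages=[{"role": "user", "content": prompt}],
        temperature=1.0,
        max_tokens=256,
        **extra_kwargs,
    )
    return response.choices[0].message.content


print(sample("Say hello.", top_p=0.95, top_k=40))
```

Routing `top_k` through `extra_body` keeps the SDK call valid for servers that ignore the parameter while still exposing it to backends that honor it.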