Humanity's last exam #520

Merged
merged 13 commits on Feb 18, 2025
5 changes: 5 additions & 0 deletions examples/model_configs/serverless_model_with_openai.yaml
@@ -0,0 +1,5 @@
model:
model_name: "deepseek-ai/DeepSeek-R1" #meta-llama/Llama-3.1-8B-Instruct" #Qwen/Qwen2.5-14B" #Qwen/Qwen2.5-7B"
api:
base_url: "https://huggingface.co/api/inference-proxy/together"
api_key: "hf_"
38 changes: 32 additions & 6 deletions src/lighteval/metrics/llm_as_judge.py
@@ -26,16 +26,21 @@
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Literal

from pydantic import BaseModel
from tqdm import tqdm

from lighteval.utils.imports import is_litellm_available, is_openai_available, is_vllm_available
from lighteval.utils.utils import as_list


logging.getLogger("openai").setLevel(logging.ERROR)
logging.getLogger("httpx").setLevel(logging.ERROR)
logger = logging.getLogger(__name__)


DEFAULT_FORMAT = {"type": "text"}


class JudgeLM:
"""
A class representing a judge for evaluating answers using either the OpenAI or Transformers library.
@@ -76,6 +81,7 @@ def __init__(
judge_backend: Literal["litellm", "openai", "transformers", "tgi", "vllm"],
url: str | None = None,
api_key: str | None = None,
response_format: BaseModel = None,
):
self.model = model
self.template = templates
@@ -91,6 +97,8 @@ def __init__(
self.api_key = api_key
self.backend = judge_backend

self.response_format = response_format if response_format is not None else DEFAULT_FORMAT

def __lazy_load_client(self):
match self.backend:
# Whether we use openai or TGI models, we go through the openai API
@@ -244,16 +252,34 @@ def __call_api_parallel(self, prompts):
def __call_api(self, prompt):
for _ in range(self.API_MAX_RETRY):
try:
response = self.client.chat.completions.create(
# Base model: try the structured-output parse endpoint with the configured response format
response = self.client.beta.chat.completions.parse(
model=self.model,
messages=prompt,
response_format={"type": "text"},
max_tokens=512,
messages=as_list(prompt),
response_format=self.response_format,
max_tokens=4096,
temperature=0.0,
n=1,
)
text = response.choices[0].message.content
return text
answer = response.choices[0].message.parsed
return answer
except TypeError:
try:
# Fine-tuned model: fall back to a plain chat completion
response = self.client.chat.completions.create(
model=self.model,
messages=as_list(prompt),
response_format=self.response_format,
max_tokens=512,
n=1,
)
text = response.choices[0].message.content
return text
except Exception as e:
logger.warning(f"{type(e), e}")
time.sleep(self.API_RETRY_SLEEP)
except Exception as e:
logger.warning(f"{type(e), e}")
time.sleep(self.API_RETRY_SLEEP)

raise Exception("Failed to get response from the API")
5 changes: 4 additions & 1 deletion src/lighteval/metrics/metrics_sample.py
@@ -35,6 +35,7 @@
from nltk.tokenize import word_tokenize
from nltk.tokenize.treebank import TreebankWordTokenizer
from nltk.translate.bleu_score import sentence_bleu
from pydantic import BaseModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from lighteval.metrics.imports.bert_scorer import BERTScorer
@@ -852,7 +853,7 @@ def edit_similarity(self, s1, s2):


class JudgeLLM:
available_models_openai = ["gpt-3.5-turbo", "gpt-4o", "gpt-4-turbo", "gpt-4"]
available_models_openai = ["gpt-3.5-turbo", "gpt-4o", "gpt-4-turbo", "gpt-4", "gpt-4o-2024-08-06"]

def __init__(
self,
@@ -861,6 +862,7 @@ def __init__(
process_judge_response: Callable,
judge_backend: Literal["litellm", "openai", "transformers", "vllm", "tgi"],
short_judge_name: str | None = None,
response_format: BaseModel = None,
) -> None:
match judge_backend:
case "openai":
@@ -893,6 +895,7 @@ def __init__(
api_key=api_key,
url=url,
judge_backend=judge_backend,
response_format=response_format,
)

def compute(self, predictions: list[str], formatted_doc: Doc, **kwargs) -> dict[str, float]:
97 changes: 86 additions & 11 deletions src/lighteval/models/endpoints/openai_model.py
@@ -20,6 +20,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import asyncio
import logging
import os
import time
@@ -28,6 +29,8 @@
from typing import Optional

from tqdm import tqdm
from tqdm.asyncio import tqdm_asyncio
from transformers import AutoTokenizer

from lighteval.data import GenerativeTaskDataset, LoglikelihoodDataset
from lighteval.models.abstract_model import LightevalModel
@@ -54,7 +57,7 @@
import logging

import tiktoken
from openai import OpenAI
from openai import AsyncOpenAI, OpenAI

logging.getLogger("openai").setLevel(logging.ERROR)
logging.getLogger("httpx").setLevel(logging.ERROR)
@@ -64,6 +67,8 @@
class OpenAIModelConfig:
model: str
generation_parameters: GenerationParameters = None
base_url: str = "https://api.openai.com/v1"
api_key: str = os.environ.get("OPENAI_API_KEY", None)

def __post_init__(self):
if not self.generation_parameters:
@@ -74,17 +79,23 @@ def from_path(cls, path: str) -> "OpenAIModelConfig":
import yaml

with open(path, "r") as f:
config = yaml.safe_load(f)["model"]
loaded_file = yaml.safe_load(f)
config = loaded_file["model"]
api = loaded_file.get("api", {})
generation_parameters = GenerationParameters.from_dict(config)
return cls(model=config["model_name"], generation_parameters=generation_parameters)
return cls(model=config["model_name"], generation_parameters=generation_parameters, **api)


class OpenAIClient(LightevalModel):
_DEFAULT_MAX_LENGTH: int = 4096

def __init__(self, config: OpenAIModelConfig, env_config) -> None:
api_key = os.environ["OPENAI_API_KEY"]
self.client = OpenAI(api_key=api_key)
def __init__(self, config: OpenAIModelConfig, env_config, is_async: bool = False) -> None:
self.is_async = is_async
if is_async:
self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
else:
self.client = OpenAI(api_key=config.api_key, base_url=config.base_url)
self.config = config
self.generation_parameters = config.generation_parameters
self.sampling_params = self.generation_parameters.to_vllm_openai_dict()

@@ -99,27 +110,32 @@ def __init__(self, config: OpenAIModelConfig, env_config) -> None:
self.API_RETRY_MULTIPLIER = 2
self.CONCURENT_CALLS = 100
self.model = config.model
self._tokenizer = tiktoken.encoding_for_model(self.model)
try:
self._tokenizer = tiktoken.encoding_for_model(self.model)
except KeyError:
self._tokenizer = AutoTokenizer.from_pretrained(self.model)
self.pairwise_tokenization = False

def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, logit_bias):
for _ in range(self.API_MAX_RETRY):
try:
response_format = {"response_format": {"type": "text"}} if "openai" in self.config.base_url else {}
response = self.client.chat.completions.create(
model=self.model,
messages=[{"role": "user", "content": prompt}],
response_format={"type": "text"},
max_tokens=max_new_tokens if max_new_tokens > 0 else None,
logprobs=return_logits,
logit_bias=logit_bias,
n=num_samples,
**self.sampling_params,
**response_format,
)
self.API_RETRY_SLEEP = 3
return response
except Exception as e:
logger.warning(f"{type(e), e}")
time.sleep(self.API_RETRY_SLEEP)
self.API_RETRY_SLEEP = self.API_RETRY_SLEEP**self.API_RETRY_MULTIPLIER
self.API_RETRY_SLEEP = self.API_RETRY_SLEEP * self.API_RETRY_MULTIPLIER
raise Exception("Failed to get response from the API")

def __call_api_parallel(
@@ -153,6 +169,62 @@ def __call_api_parallel(

return results

async def __call_api_async_one(self, prompt, return_logits, max_new_tokens, num_samples, logit_bias):
for _ in range(self.API_MAX_RETRY):
try:
response_format = {"response_format": {"type": "text"}} if "openai" in self.config.base_url else {}
response = await self.client.chat.completions.create(
model=self.model,
messages=[{"role": "user", "content": prompt}],
max_tokens=max_new_tokens if max_new_tokens > 0 else None,
logprobs=return_logits,
logit_bias=logit_bias,
n=num_samples,
**self.sampling_params,
**response_format,
)
return response
except Exception as e:
logger.warning(f"{type(e), e}")
# Sleep asynchronously so retries do not block the event loop
await asyncio.sleep(self.API_RETRY_SLEEP)
self.API_RETRY_SLEEP = self.API_RETRY_SLEEP * self.API_RETRY_MULTIPLIER
raise Exception("Failed to get response from the API")

async def __call_api_async(
self,
prompts,
return_logits: bool | list[bool],
max_new_tokens: int | list[int],
num_samples: int | list[int],
logit_bias: list[dict[int, float]] | None = None,
):
# Convert single values to lists
return_logitss = [return_logits for _ in prompts] if not isinstance(return_logits, list) else return_logits
max_new_tokenss = [max_new_tokens for _ in prompts] if not isinstance(max_new_tokens, list) else max_new_tokens
num_sampless = [num_samples for _ in prompts] if not isinstance(num_samples, list) else num_samples
logit_biass = [logit_bias for _ in prompts] if logit_bias is None else logit_bias

# Validate input lengths
assert (
len(prompts) == len(return_logitss) == len(max_new_tokenss) == len(num_sampless) == len(logit_biass)
), "prompts, return_logitss, max_new_tokenss, num_sampless and logit_biass must all have the same length"

# Bound concurrency with a shared semaphore (10 = num workers)
semaphore = asyncio.Semaphore(10)

async def bounded_call(prompt, ret_log, max_tok, num_samp, log_bias):
async with semaphore:
return await self.__call_api_async_one(prompt, ret_log, max_tok, num_samp, log_bias)

# One coroutine per prompt, awaited concurrently by gather below
tasks = [
bounded_call(prompt, ret_log, max_tok, num_samp, log_bias)
for prompt, ret_log, max_tok, num_samp, log_bias in zip(
prompts, return_logitss, max_new_tokenss, num_sampless, logit_biass
)
]

results = await tqdm_asyncio.gather(*tasks, return_exceptions=True)

# With return_exceptions=True, failed calls come back as exception instances
errors = [r for r in results if isinstance(r, Exception)]
if errors:
raise ValueError(f"{len(errors)} API calls failed, please inspect the logged errors and retry.")

return results

def greedy_until(
self,
requests: list[GreedyUntilRequest],
@@ -181,12 +253,15 @@ def greedy_until(
position=0,
disable=False, # self.disable_tqdm,
):
max_new_tokens = dataset[0].generation_size # could be none
max_new_tokens = 500 # dataset[0].generation_size # could be none
return_logits = dataset[0].use_logits
num_samples = dataset[0].num_samples
contexts = [c.context for c in dataset]

responses = self.__call_api_parallel(contexts, return_logits, max_new_tokens, num_samples)
if self.is_async:
responses = asyncio.run(self.__call_api_async(contexts, return_logits, max_new_tokens, num_samples))
else:
responses = self.__call_api_parallel(contexts, return_logits, max_new_tokens, num_samples)

for response in responses:
result: list[str] = [output.message.content for output in response.choices]
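
A minimal usage sketch of the new async path, assuming OPENAI_API_KEY is set in the environment; env_config is passed as None purely for illustration, since the real value comes from the evaluation pipeline:

from lighteval.models.endpoints.openai_model import OpenAIClient, OpenAIModelConfig

config = OpenAIModelConfig(model="gpt-4o")
client = OpenAIClient(config, env_config=None, is_async=True)
# greedy_until() checks is_async and drives requests through
# asyncio.run(self.__call_api_async(...)) instead of the thread pool.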
3 changes: 2 additions & 1 deletion src/lighteval/tasks/extended/__init__.py
@@ -24,12 +24,13 @@


if can_load_extended_tasks():
import lighteval.tasks.extended.hle.main as hle
import lighteval.tasks.extended.ifeval.main as ifeval
import lighteval.tasks.extended.mix_eval.main as mix_eval
import lighteval.tasks.extended.mt_bench.main as mt_bench
import lighteval.tasks.extended.tiny_benchmarks.main as tiny_benchmarks

AVAILABLE_EXTENDED_TASKS_MODULES = [ifeval, tiny_benchmarks, mt_bench, mix_eval]
AVAILABLE_EXTENDED_TASKS_MODULES = [ifeval, tiny_benchmarks, mt_bench, mix_eval, hle]

else:
AVAILABLE_EXTENDED_TASKS_MODULES = []