diff --git a/pyproject.toml b/pyproject.toml index b5f6a3b..457efc3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ transformers = "^4.31.0" sentence-transformers = "^2.2.2" zstd = "^1.5.5.1" openai = "^0.27.8" +litellm = "^0.1.400" numpy = "^1.20" nltk = "^3.8.1" rank-bm25 = "^0.2.2" diff --git a/retrieval_qa_benchmark/models/openai.py b/retrieval_qa_benchmark/models/openai.py index 6906185..461be8b 100644 --- a/retrieval_qa_benchmark/models/openai.py +++ b/retrieval_qa_benchmark/models/openai.py @@ -4,6 +4,7 @@ from typing import Any, Dict, Optional import openai +import litellm from retrieval_qa_benchmark.schema import BaseLLM, BaseLLMOutput from retrieval_qa_benchmark.utils.registry import REGISTRY @@ -79,7 +80,7 @@ def _generate( self, text: str = "", ) -> BaseLLMOutput: - completion = openai.ChatCompletion.create( + completion = litellm.completion( model=self.name, messages=[ {"role": "system", "content": self.system_prompt},