-
Notifications
You must be signed in to change notification settings - Fork 14
Expand file tree
/
Copy pathopenai.py
More file actions
95 lines (83 loc) · 2.69 KB
/
openai.py
File metadata and controls
95 lines (83 loc) · 2.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from __future__ import annotations
import os
from typing import Any, Dict, Optional
import openai
import litellm
from retrieval_qa_benchmark.schema import BaseLLM, BaseLLMOutput
from retrieval_qa_benchmark.utils.registry import REGISTRY
@REGISTRY.register_model("remote-llm")
class RemoteLLM(BaseLLM):
    """Client for an OpenAI-compatible text-completion endpoint.

    Registered under the model key ``"remote-llm"``. Requests go through the
    legacy ``openai.Completion`` API, with the system prompt prepended to the
    user text as plain text.
    """

    # Class-level default. Note that ``build`` passes ``system_prompt or ""``,
    # so instances created through ``build`` without a prompt get "" instead.
    system_prompt: str = "You are a helpful assistant."

    @classmethod
    def build(
        cls,
        name: str = "llama2-13b-chat",
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
        system_prompt: Optional[str] = None,
        run_args: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> RemoteLLM:
        """Configure the ``openai`` module and construct an instance.

        Args:
            name: Model name sent as ``model`` in every request.
            api_base: Endpoint base URL. When ``None``, it is read from the
                ``OPENAI_API_BASE`` environment variable at call time, falling
                back to ``"http://10.1.3.28:8990/v1"``.
            api_key: API key. When ``None``, it is read from ``OPENAI_API_KEY``
                at call time, falling back to a placeholder string.
            system_prompt: Text prepended to every prompt; ``None`` becomes "".
            run_args: Extra keyword arguments forwarded to each request;
                ``None`` becomes an empty dict.
            **kwargs: Forwarded unchanged to the model constructor.

        Returns:
            A new instance of ``cls``.

        Note:
            ``openai.api_base`` and ``openai.api_key`` are module-level
            globals, and ``_generate`` does not pass them per request. The
            most recent ``build`` call therefore decides the endpoint and key
            used by every instance in the process.
        """
        # Resolve the environment here rather than in the signature. Signature
        # defaults are evaluated once at import time and would ignore any
        # environment changes made afterwards.
        if api_base is None:
            # NOTE(review): the fallback is a private-network address,
            # presumably a lab-local server — confirm it is still intended.
            api_base = os.getenv("OPENAI_API_BASE", "http://10.1.3.28:8990/v1")
        if api_key is None:
            api_key = os.getenv(
                "OPENAI_API_KEY", "sk-some-super-secret-key-you-will-never-know"
            )
        openai.api_base = api_base
        openai.api_key = api_key
        return cls(
            name=name,
            run_args=run_args or {},
            system_prompt=system_prompt or "",
            **kwargs,
        )

    def _generate(
        self,
        text: str,
    ) -> BaseLLMOutput:
        """Send one text-completion request and return the first choice.

        Args:
            text: User text appended after the system prompt.

        Returns:
            ``BaseLLMOutput`` holding the first choice's text and the
            prompt/completion token counts reported in ``usage``.
        """
        # Join with a newline only when a system prompt exists. Otherwise
        # "\n".join(["", text]) would send a stray leading newline.
        prompt = f"{self.system_prompt}\n{text}" if self.system_prompt else text
        completion = openai.Completion.create(
            model=self.name,
            prompt=prompt,
            **self.run_args,
        )
        return BaseLLMOutput(
            generated=completion.choices[0].text,
            prompt_tokens=completion.usage.prompt_tokens,
            completion_tokens=completion.usage.completion_tokens,
        )
@REGISTRY.register_model("gpt35")
class GPT(RemoteLLM):
    """OpenAI-hosted text-completion model, registered as ``"gpt35"``.

    Reuses ``RemoteLLM._generate``; only the ``build`` defaults differ.
    """

    @classmethod
    def build(
        cls,
        name: str = "text-davinci-003",
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
        system_prompt: Optional[str] = None,
        run_args: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> GPT:
        """Configure the ``openai`` module and construct an instance.

        Args:
            name: Model name sent as ``model`` in every request.
                NOTE(review): ``text-davinci-003`` has been retired by OpenAI;
                confirm this default is still usable.
            api_base: Endpoint base URL. When ``None``, it is read from
                ``OPENAI_API_BASE`` at call time, falling back to
                ``"https://api.openai.com/v1"``.
            api_key: API key. When ``None``, it is read from ``OPENAI_API_KEY``
                at call time, falling back to "".
            system_prompt: Text prepended to every prompt; ``None`` becomes "".
            run_args: Extra keyword arguments forwarded to each request;
                ``None`` becomes an empty dict.
            **kwargs: Forwarded unchanged to the model constructor.

        Returns:
            A new instance of ``cls``.

        Note:
            Assigning ``openai.api_base``/``openai.api_key`` mutates
            module-level globals that are shared by every instance in the
            process.
        """
        # Read the environment per call. Signature defaults would be frozen at
        # import time.
        if api_base is None:
            api_base = os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1")
        if api_key is None:
            api_key = os.getenv("OPENAI_API_KEY", "")
        openai.api_base = api_base
        openai.api_key = api_key
        return cls(
            name=name,
            run_args=run_args or {},
            system_prompt=system_prompt or "",
            **kwargs,
        )
@REGISTRY.register_model("chatgpt35")
class ChatGPT(GPT):
    """Chat-completion model, registered as ``"chatgpt35"``.

    Uses the inherited ``build`` for configuration, but sends requests through
    ``litellm.completion`` as chat-style messages.
    """

    def _generate(
        self,
        text: str = "",
    ) -> BaseLLMOutput:
        """Send ``text`` as a user message and return the first choice.

        Args:
            text: Content of the user message.

        Returns:
            ``BaseLLMOutput`` holding the reply text (``""`` when the response
            carries no content) and the prompt/completion token counts
            reported in ``usage``.
        """
        messages = []
        # Include a system message only when there is a prompt to send. An
        # empty one carries no information and may be rejected by some
        # providers reachable through litellm.
        if self.system_prompt:
            messages.append({"role": "system", "content": self.system_prompt})
        messages.append({"role": "user", "content": text})
        completion = litellm.completion(
            model=self.name,
            messages=messages,
            **self.run_args,
        )
        # ``message.content`` can be None (e.g. tool-call replies); normalize
        # it so callers always receive a string.
        return BaseLLMOutput(
            generated=completion.choices[0].message.content or "",
            prompt_tokens=completion.usage.prompt_tokens,
            completion_tokens=completion.usage.completion_tokens,
        )