-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbase.py
53 lines (40 loc) · 1.21 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from dataclasses import dataclass
from typing import Optional, List
# Delimiter that Conversation.render and Prompt.render place (after a newline)
# between rendered messages / prompt sections.
# NOTE(review): this is the literal "<|endoftext|>" special-token string used by
# GPT-2-style tokenizers — presumably chosen so the model sees a boundary between
# turns; confirm against the model/tokenizer actually in use.
SEPARATOR_TOKEN: str = "<|endoftext|>"
@dataclass(frozen=True)
class Message:
    """One chat turn: a speaker label and an optional body of text.

    Instances are immutable (``frozen=True``). A message whose ``text`` is
    ``None`` renders as the bare label, e.g. ``"Bot:"``.
    """

    user: str  # speaker label, emitted before the colon
    text: Optional[str] = None  # message body; None means "label only"

    def render(self):
        """Format the message as ``"<user>: <text>"``.

        Returns just ``"<user>:"`` (no trailing space) when ``text`` is None.
        """
        head = self.user + ":"
        if self.text is None:
            return head
        return head + " " + self.text
@dataclass
class Conversation:
    """A mutable, ordered list of messages rendered as a single transcript."""

    messages: List[Message]  # rendered in list order; prepend() inserts at index 0

    def prepend(self, message: Message):
        """Insert *message* at the front of the conversation, in place.

        Mutates ``self.messages`` (the same list object passed to the
        constructor) and returns ``self`` so calls can be chained.
        """
        self.messages.insert(0, message)
        return self

    def render(self):
        """Render each message and join them with a newline plus SEPARATOR_TOKEN."""
        delimiter = "\n" + SEPARATOR_TOKEN
        return delimiter.join(msg.render() for msg in self.messages)
@dataclass(init=True, repr=True, eq=True, frozen=True)
class Config:
    """Immutable named bundle of instructions and example conversations.

    NOTE: ``frozen`` only blocks reassigning attributes; the
    ``example_conversations`` list itself remains mutable, and ``hash()`` on
    an instance raises TypeError because a list field is unhashable.
    """

    name: str  # identifier for this configuration
    # presumably the text that becomes the Prompt header — confirm with callers
    instructions: str
    # presumably passed through as Prompt.examples — confirm with callers
    example_conversations: List[Conversation]
@dataclass(frozen=True)
class Prompt:
    """A full prompt: a header message, example conversations, and the live conversation."""

    header: Message  # rendered first
    examples: List[Conversation]  # rendered after the "Example conversations:" banner
    convo: Conversation  # rendered last, after the "Current conversation:" banner

    def render(self):
        """Render the whole prompt as one string.

        Sections, in order: the header; a ``System: Example conversations:``
        banner; every example conversation; a
        ``System: Current conversation:`` banner; the current conversation.
        Sections are joined with a newline followed by SEPARATOR_TOKEN.
        """
        sections = [self.header.render()]
        sections.append(Message("System", "Example conversations:").render())
        sections.extend(example.render() for example in self.examples)
        sections.append(Message("System", "Current conversation:").render())
        sections.append(self.convo.render())
        return ("\n" + SEPARATOR_TOKEN).join(sections)