diff --git a/llm_toolkit/models.py b/llm_toolkit/models.py
index 5a81c28dd0..2da03b0f30 100644
--- a/llm_toolkit/models.py
+++ b/llm_toolkit/models.py
@@ -77,6 +77,9 @@ def __init__(
     self.temperature = temperature
     self.temperature_list = temperature_list
 
+    # Preserve chat history for OpenAI
+    self.messages = []
+
   def cloud_setup(self):
     """Runs Cloud specific-setup."""
     # Only a subset of models need a cloud specific set up, so
@@ -275,14 +278,19 @@ def chat_llm(self, client: Any, prompt: prompts.Prompt) -> str:
       logger.info('OpenAI does not allow temperature list: %s',
                   self.temperature_list)
 
+    self.messages.extend(prompt.get())
+
     completion = self.with_retry_on_error(
-        lambda: client.chat.completions.create(messages=prompt.get(),
+        lambda: client.chat.completions.create(messages=self.messages,
                                                model=self.name,
                                                n=self.num_samples,
                                                temperature=self.temperature),
         [openai.OpenAIError])
-    return completion.choices[0].message.content
+    llm_response = completion.choices[0].message.content
+    self.messages.append({'role': 'assistant', 'content': llm_response})
+
+    return llm_response
 
   def ask_llm(self, prompt: prompts.Prompt) -> str:
     """Queries LLM a single prompt and returns its response."""
diff --git a/llm_toolkit/prompts.py b/llm_toolkit/prompts.py
index 647fafc82c..469ac6d205 100644
--- a/llm_toolkit/prompts.py
+++ b/llm_toolkit/prompts.py
@@ -125,7 +125,8 @@ def gettext(self) -> str:
     """Gets the final formatted prompt in plain text."""
     result = ''
     for item in self.get():
-      result = f'{result}\n{item.get("content", "")}'
+      result = (f'{result}\n{item.get("role", "Unknown")}:'
+                f'\n{item.get("content", "")}')
     return result
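
Note on the pattern above: the patch makes chat_llm stateful. Each call extends
self.messages with the new prompt turns, sends the accumulated history to the
API, and appends the assistant reply before returning it, so later calls in the
same session see the whole conversation. Below is a minimal, self-contained
sketch of that accumulation pattern; FakeClient and ChatSession are illustrative
names invented here and are not part of llm_toolkit or the OpenAI SDK.

from typing import Any


class FakeClient:
  """Echoes the last user message; stands in for the real API client."""

  def create(self, messages: list[dict[str, str]]) -> str:
    return f'echo: {messages[-1]["content"]}'


class ChatSession:
  """Accumulates messages the way the patched chat_llm does."""

  def __init__(self, client: Any):
    self.client = client
    self.messages: list[dict[str, str]] = []  # preserved across calls

  def chat(self, new_messages: list[dict[str, str]]) -> str:
    # Append the caller's new turns, then send the full history.
    self.messages.extend(new_messages)
    reply = self.client.create(self.messages)
    # Record the assistant reply so the next call sees it too.
    self.messages.append({'role': 'assistant', 'content': reply})
    return reply


if __name__ == '__main__':
  session = ChatSession(FakeClient())
  session.chat([{'role': 'user', 'content': 'hello'}])
  session.chat([{'role': 'user', 'content': 'again'}])
  # History now holds: user, assistant, user, assistant.
  assert [m['role'] for m in session.messages] == [
      'user', 'assistant', 'user', 'assistant'
  ]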