fix: catch any HTTP errors at exchange.generate(...) (#26)
lukealvoeiro committed Sep 3, 2024
1 parent 49d5f6f commit 119b8d7
Showing 4 changed files with 115 additions and 7 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ## [Unreleased]

+## [0.8.4] - 2024-09-02
+
+- Catch any HTTP errors the provider emits and retry the call to `generate` with different messages
+
 ## [0.8.3] - 2024-09-02

 - Refactor checkpoints to allow exchange to stay in-sync across messages and checkpoints
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "ai-exchange"
-version = "0.8.3"
+version = "0.8.4"
 description = "a uniform python SDK for message generation with LLMs"
 readme = "README.md"
 requires-python = ">=3.10"
42 changes: 36 additions & 6 deletions src/exchange/exchange.py
@@ -4,6 +4,7 @@
 from typing import Any, Dict, List, Mapping, Tuple

 from attrs import define, evolve, field
+from httpx import HTTPStatusError
 from tiktoken import get_encoding

 from exchange.checkpoint import Checkpoint, CheckpointData
@@ -14,6 +15,11 @@
 from exchange.providers import Provider, Usage
 from exchange.tool import Tool

+FAILED_TO_GENERATE_MSG = "Failed to generate the next message."

+# TODO: decide on the correct number of retries here
+REMOVE_MESSAGE_RETRY_TIMES = 3


 def validate_tool_output(output: str) -> None:
     """Validate tool output for the given model"""
@@ -72,12 +78,36 @@ def generate(self) -> Message:
         """Generate the next message."""
         self.moderator.rewrite(self)

-        message, usage = self.provider.complete(
-            self.model,
-            self.system,
-            messages=self.messages,
-            tools=self.tools,
-        )
+        num_times_attempted = 0
+        while num_times_attempted < REMOVE_MESSAGE_RETRY_TIMES:
+            # We attempt to generate a response, retrying on failure a few
+            # times. If we run into an HTTP error, we pop the last message
+            # until the most recent message is a user text message, then try
+            # again, giving up after REMOVE_MESSAGE_RETRY_TIMES attempts.
+            # Providers hosted on your own machine should not throw HTTP
+            # errors, and can instead configure their own behavior by
+            # throwing a different type of error.
+            try:
+                message, usage = self.provider.complete(
+                    self.model,
+                    self.system,
+                    messages=self.messages,
+                    tools=self.tools,
+                )
+                break
+            except HTTPStatusError:
+                if len(self.messages) <= 1:
+                    # we can't pop any messages, so we have to give up
+                    raise Exception(FAILED_TO_GENERATE_MSG)
+                self.pop_last_message()
+                while len(self.messages) > 1 and self.messages[-1].role == "assistant":
+                    # why > 1? because we need to keep at least one user message in the exchange
+                    self.pop_last_message()
+                num_times_attempted += 1

+        if num_times_attempted >= REMOVE_MESSAGE_RETRY_TIMES:
+            # we failed to generate a response after REMOVE_MESSAGE_RETRY_TIMES attempts
+            raise Exception(FAILED_TO_GENERATE_MSG)

         self.add(message)
         self.add_checkpoints_from_usage(usage)  # this has to come after adding the response
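Note: the comment in `generate()` above suggests that self-hosted providers avoid this retry path by raising a non-HTTP error. Below is a minimal sketch of that idea, not part of this commit. `LocalProvider`, `LocalProviderError`, and `_run_local_inference` are hypothetical names; it assumes `Message` is importable from `exchange.message` and that a `Provider` subclass only needs to implement `complete` (as `MockErrorProvider` in the tests below does).

from typing import List, Tuple

from exchange.message import Message  # assumed import path
from exchange.providers import Provider, Usage
from exchange.tool import Tool


class LocalProviderError(Exception):
    """Not an httpx.HTTPStatusError, so Exchange.generate() will not catch it."""


class LocalProvider(Provider):
    def complete(self, model: str, system: str, messages: List[Message], tools: List[Tool]) -> Tuple[Message, Usage]:
        try:
            text = self._run_local_inference(model, system, messages, tools)
        except RuntimeError as e:
            # surface local failures with a non-HTTP error type
            raise LocalProviderError(str(e)) from e
        return Message.assistant(text), Usage(input_tokens=0, output_tokens=0, total_tokens=0)

    def _run_local_inference(self, model, system, messages, tools) -> str:
        # stand-in for an actual local model call
        raise RuntimeError("local model not available")

Because `generate()` only catches `HTTPStatusError`, a `LocalProviderError` propagates to the caller on the first failure instead of triggering the message-popping retries.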
74 changes: 74 additions & 0 deletions tests/test_exchange.py
@@ -1,5 +1,7 @@
 from typing import List, Tuple
+from unittest.mock import Mock

+from httpx import HTTPStatusError
 import pytest

 from exchange.checkpoint import Checkpoint, CheckpointData
@@ -679,3 +681,75 @@ def test_prepend_checkpointed_message_empty_exchange(normal_exchange):
     assert len(ex.checkpoint_data.checkpoints) == 3
     assert len(ex.messages) == 3
     assert_no_overlapping_checkpoints(ex)


+def test_generate_successful_response_on_first_try(normal_exchange):
+    ex = normal_exchange
+    ex.add(Message(role="user", content=[Text("Hello")]))
+    ex.generate()


+class MockErrorProvider(Provider):
+    def __init__(self, sequence: List[Message | Exception]):
+        self.sequence = sequence
+        self.call_count = 0

+    def complete(self, model: str, system: str, messages: List[Message], tools: List[Tool]) -> Tuple[Message, Usage]:
+        next_item = self.sequence[self.call_count]
+        # advance through the sequence, but stop at the last item so any
+        # further calls keep returning it
+        if self.call_count != len(self.sequence) - 1:
+            self.call_count += 1
+        if isinstance(next_item, Exception):
+            raise HTTPStatusError("Bad Request", request=Mock(), response=Mock())
+        return next_item, Usage(input_tokens=10, output_tokens=10, total_tokens=20)


+def test_generate_http_error_recovery_empty_messages():
+    ex = Exchange(
+        provider=MockErrorProvider(
+            sequence=[Message.assistant("Some text"), Exception(), Message.assistant("Some other text")]
+        ),
+        model="gpt-4o-2024-05-13",
+        system="You are a helpful assistant.",
+        moderator=PassiveModerator(),
+    )
+    ex.add(Message.user("Hello"))
+    ex.generate()
+    ex.add(Message.user("Hello again!"))
+    ex.generate()


+def test_generate_http_error_changes_messages():
+    ex = Exchange(
+        provider=MockErrorProvider(
+            sequence=[
+                Message(role="assistant", content=[ToolUse(id="1", name="dishwasher", parameters={})]),
+                Exception(),
+                Message.assistant("I'm done cleaning the dishes."),
+            ]
+        ),
+        model="gpt-4o-2024-05-13",
+        system="You are a helpful assistant.",
+        moderator=PassiveModerator(),
+    )
+    ex.add(Message.user("Hi! Can you clean the dishes for me?"))
+    ex.generate()
+    ex.add(Message(role="user", content=[ToolResult(tool_use_id="1", output="I cleaned the dishes.")]))
+    ex.generate()

+    assert len(ex.messages) == 2
+    assert len(ex.checkpoint_data.checkpoints) == 2
+    assert ex.messages[0].text == "Hi! Can you clean the dishes for me?"
+    assert ex.messages[1].text == "I'm done cleaning the dishes."


+def test_generate_http_error_no_recovery():
+    ex = Exchange(
+        provider=MockErrorProvider(sequence=[Exception()]),
+        model="gpt-4o-2024-05-13",
+        system="You are a helpful assistant.",
+        moderator=PassiveModerator(),
+    )
+    ex.add(Message(role="user", content=[Text("Hello")]))
+    with pytest.raises(Exception) as e:
+        ex.generate()
+    assert str(e.value) == "Failed to generate the next message."
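The new tests can be run in isolation with pytest's -k filter (assuming the project's standard pytest setup):

pytest tests/test_exchange.py -k "http_error or first_try"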
