Commit 49970aa

Added retry when sending httpx requests to LLM provider APIs (#20)
Co-authored-by: Andy Lane <[email protected]>
Co-authored-by: Luke Alvoeiro <[email protected]>
3 people committed Sep 2, 2024
1 parent bb22210 commit 49970aa
Showing 8 changed files with 270 additions and 26 deletions.
24 changes: 6 additions & 18 deletions src/exchange/providers/anthropic.py

@@ -1,12 +1,12 @@
 import os
-import time
 from typing import Any, Dict, List, Tuple, Type
 
 import httpx
 
 from exchange import Message, Tool
 from exchange.content import Text, ToolResult, ToolUse
 from exchange.providers.base import Provider, Usage
+from exchange.providers.retry_with_back_off_decorator import retry_httpx_request
 from exchange.providers.utils import raise_for_status
 
 ANTHROPIC_HOST = "https://api.anthropic.com/v1/messages"
@@ -138,26 +138,14 @@ def complete(
         )
         payload = {k: v for k, v in payload.items() if v}
 
-        max_retries = 5
-        initial_wait = 10  # Start with 10 seconds
-        backoff_factor = 1
-        for retry in range(max_retries):
-            response = self.client.post(ANTHROPIC_HOST, json=payload)
-            if response.status_code not in (429, 529, 500):
-                break
-            else:
-                sleep_time = initial_wait + (backoff_factor * (2**retry))
-                time.sleep(sleep_time)
-
-        if response.status_code in (429, 529, 500):
-            raise httpx.HTTPStatusError(
-                f"Failed after {max_retries} retries due to rate limiting",
-                request=response.request,
-                response=response,
-            )
+        response = self._send_request(payload)
 
         response_data = raise_for_status(response).json()
         message = self.anthropic_response_to_message(response_data)
         usage = self.get_usage(response_data)
 
         return message, usage
+
+    @retry_httpx_request()
+    def _send_request(self, payload: Dict[str, Any]) -> httpx.Response:
+        return self.client.post(ANTHROPIC_HOST, json=payload)
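
The backoff schedule itself is unchanged by this refactor: both the deleted inline loop and the new decorator sleep initial_wait + backoff_factor * 2**retry seconds between attempts. As an illustrative sketch (not part of the commit), the default schedule works out as follows:

# Illustrative only: the default schedule (max_retries=5, initial_wait=10, backoff_factor=1).
# Sleeps happen after attempts 1-4; there is no sleep after the final attempt.
for retry in range(4):
    sleep_time = 10 + (1 * (2**retry))
    print(f"attempt {retry + 1} failed -> wait {sleep_time}s")  # 11, 12, 14, 18 seconds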
7 changes: 6 additions & 1 deletion src/exchange/providers/azure.py

@@ -5,6 +5,7 @@
 
 from exchange.message import Message
 from exchange.providers.base import Provider, Usage
+from exchange.providers.retry_with_back_off_decorator import retry_httpx_request
 from exchange.providers.utils import (
     messages_to_openai_spec,
     openai_response_to_message,
@@ -98,7 +99,7 @@ def complete(
 
         payload = {k: v for k, v in payload.items() if v}
         request_url = f"{self.client.base_url}/chat/completions?api-version={self.api_version}"
-        response = self.client.post(request_url, json=payload)
+        response = self._send_request(payload, request_url)
 
         # Check for context_length_exceeded error for single, long input message
         if "error" in response.json() and len(messages) == 1:
@@ -109,3 +110,7 @@
         message = openai_response_to_message(data)
         usage = self.get_usage(data)
         return message, usage
+
+    @retry_httpx_request()
+    def _send_request(self, payload: Any, request_url: str) -> httpx.Response:  # noqa: ANN401
+        return self.client.post(request_url, json=payload)
7 changes: 6 additions & 1 deletion src/exchange/providers/bedrock.py

@@ -12,6 +12,7 @@
 from exchange.content import Text, ToolResult, ToolUse
 from exchange.message import Message
 from exchange.providers import Provider, Usage
+from exchange.providers.retry_with_back_off_decorator import retry_httpx_request
 from exchange.providers.utils import raise_for_status
 from exchange.tool import Tool
 
@@ -204,7 +205,7 @@ def complete(
 
         path = f"model/{model}/converse"
 
-        response = self.client.post(path, json=payload)
+        response = self._send_request(payload, path)
         raise_for_status(response)
         response_message = response.json()["output"]["message"]
 
@@ -217,6 +218,10 @@
 
         return self.response_to_message(response_message), usage
 
+    @retry_httpx_request()
+    def _send_request(self, payload: Any, path: str) -> httpx.Response:  # noqa: ANN401
+        return self.client.post(path, json=payload)
+
     @staticmethod
     def message_to_bedrock_spec(message: Message) -> dict:
         bedrock_content = []
13 changes: 9 additions & 4 deletions src/exchange/providers/databricks.py

@@ -5,6 +5,7 @@
 
 from exchange.message import Message
 from exchange.providers.base import Provider, Usage
+from exchange.providers.retry_with_back_off_decorator import retry_httpx_request
 from exchange.providers.utils import (
     messages_to_openai_spec,
     openai_response_to_message,
@@ -79,11 +80,15 @@ def complete(
             **kwargs,
         )
         payload = {k: v for k, v in payload.items() if v}
-        response = self.client.post(
-            f"serving-endpoints/{model}/invocations",
-            json=payload,
-        )
+        response = self._send_request(model, payload)
         data = raise_for_status(response).json()
         message = openai_response_to_message(data)
        usage = self.get_usage(data)
         return message, usage
+
+    @retry_httpx_request()
+    def _send_request(self, model: str, payload: Any) -> httpx.Response:  # noqa: ANN401
+        return self.client.post(
+            f"serving-endpoints/{model}/invocations",
+            json=payload,
+        )
7 changes: 6 additions & 1 deletion src/exchange/providers/ollama.py

@@ -5,6 +5,7 @@
 
 from exchange.message import Message
 from exchange.providers.base import Provider, Usage
+from exchange.providers.retry_with_back_off_decorator import retry_httpx_request
 from exchange.providers.utils import (
     messages_to_openai_spec,
     openai_response_to_message,
@@ -80,7 +81,7 @@ def complete(
             **kwargs,
         )
         payload = {k: v for k, v in payload.items() if v}
-        response = self.client.post("v1/chat/completions", json=payload)
+        response = self._send_request(payload)
 
         # Check for context_length_exceeded error for single, long input message
         if "error" in response.json() and len(messages) == 1:
@@ -91,3 +92,7 @@
         message = openai_response_to_message(data)
         usage = self.get_usage(data)
         return message, usage
+
+    @retry_httpx_request()
+    def _send_request(self, payload: Any) -> httpx.Response:  # noqa: ANN401
+        return self.client.post("v1/chat/completions", json=payload)
7 changes: 6 additions & 1 deletion src/exchange/providers/openai.py

@@ -5,6 +5,7 @@
 
 from exchange.message import Message
 from exchange.providers.base import Provider, Usage
+from exchange.providers.retry_with_back_off_decorator import retry_httpx_request
 from exchange.providers.utils import (
     messages_to_openai_spec,
     openai_response_to_message,
@@ -74,7 +75,7 @@ def complete(
             **kwargs,
         )
         payload = {k: v for k, v in payload.items() if v}
-        response = self.client.post("v1/chat/completions", json=payload)
+        response = self._send_request(payload)
 
         # Check for context_length_exceeded error for single, long input message
         if "error" in response.json() and len(messages) == 1:
@@ -85,3 +86,7 @@
         message = openai_response_to_message(data)
         usage = self.get_usage(data)
         return message, usage
+
+    @retry_httpx_request()
+    def _send_request(self, payload: Any) -> httpx.Response:  # noqa: ANN401
+        return self.client.post("v1/chat/completions", json=payload)
56 changes: 56 additions & 0 deletions src/exchange/providers/retry_with_back_off_decorator.py

@@ -0,0 +1,56 @@
+import time
+from functools import wraps
+from typing import Any, Callable, Dict, Iterable, List, Optional
+
+from httpx import HTTPStatusError, Response
+
+
+def retry_with_backoff(
+    should_retry: Callable,
+    max_retries: Optional[int] = 5,
+    initial_wait: Optional[float] = 10,
+    backoff_factor: Optional[float] = 1,
+    handle_retry_exhausted: Optional[Callable] = None,
+) -> Callable:
+    def decorator(func: Callable) -> Callable:
+        @wraps(func)
+        def wrapper(*args: List, **kwargs: Dict) -> Any:  # noqa: ANN401
+            result = None
+            for retry in range(max_retries):
+                result = func(*args, **kwargs)
+                if not should_retry(result):
+                    return result
+                if (retry + 1) == max_retries:
+                    break
+                sleep_time = initial_wait + (backoff_factor * (2**retry))
+                time.sleep(sleep_time)
+            if handle_retry_exhausted:
+                handle_retry_exhausted(result, max_retries)
+            return result
+        return wrapper
+    return decorator
+
+
+def retry_httpx_request(
+    retry_on_status_code: Optional[Iterable[int]] = None,
+    max_retries: Optional[int] = 5,
+    initial_wait: Optional[float] = 10,
+    backoff_factor: Optional[float] = 1,
+) -> Callable:
+    if retry_on_status_code is None:
+        retry_on_status_code = set(range(401, 999))
+
+    def should_retry(response: Response) -> bool:
+        return response.status_code in retry_on_status_code
+
+    def handle_retry_exhausted(response: Response, max_retries: int) -> None:
+        raise HTTPStatusError(
+            f"Failed after {max_retries} retries.",
+            request=response.request,
+            response=response,
+        )
+
+    return retry_with_backoff(
+        max_retries=max_retries,
+        initial_wait=initial_wait,
+        backoff_factor=backoff_factor,
+        should_retry=should_retry,
+        handle_retry_exhausted=handle_retry_exhausted,
+    )
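
As a usage sketch (not part of this commit), any callable returning an httpx.Response can opt into the same retry behavior the providers now get. Note that the default retry_on_status_code is set(range(401, 999)), much broader than the (429, 529, 500) set the old inline Anthropic loop retried on; pass an explicit set to narrow it. The flaky_call helper and the tiny wait values below are hypothetical, chosen so the example finishes quickly.

import httpx

from exchange.providers.retry_with_back_off_decorator import retry_httpx_request


@retry_httpx_request(retry_on_status_code={429, 500, 529}, max_retries=3, initial_wait=0.1, backoff_factor=0.1)
def flaky_call() -> httpx.Response:
    # Hypothetical endpoint that always fails, to demonstrate retry exhaustion.
    return httpx.Response(500, request=httpx.Request("POST", "https://example.com"))


try:
    flaky_call()
except httpx.HTTPStatusError as error:
    print(error)  # Failed after 3 retries.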