Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 42 additions & 3 deletions launch/launch/utilities/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,45 @@
import os
from functools import wraps
from typing import List

from langchain_core.messages import BaseMessage, HumanMessage
from tenacity import retry, stop_after_attempt, wait_exponential_jitter
from tenacity import (
retry,
retry_if_exception,
stop_after_attempt,
wait_exponential_jitter,
)

try: # OpenAI SDK present
from openai import RateLimitError as OpenAIRateLimitError
except Exception: # pragma: no cover - openai optional
class OpenAIRateLimitError(Exception):
"""Fallback placeholder when OpenAI SDK is unavailable."""


try:
from anthropic import RateLimitError as AnthropicRateLimitError
except Exception: # pragma: no cover - anthropic optional
class AnthropicRateLimitError(Exception):
"""Fallback placeholder when Anthropic SDK is unavailable."""


def _is_rate_limit_error(exc: Exception) -> bool:
"""
Detects rate limit errors across different SDKs.

Besides dedicated RateLimitError classes, we inspect HTTP-style status
attributes and error messages to catch wrapped 429 responses.
"""
if isinstance(exc, (OpenAIRateLimitError, AnthropicRateLimitError)):
return True

status = getattr(exc, "status_code", None) or getattr(exc, "status", None)
if status == 429:
return True

message = str(getattr(exc, "message", exc)).lower()
return "rate limit" in message


def logged_invoke(invoke_func):
Expand Down Expand Up @@ -78,8 +115,10 @@ def __init__(self, llm_provider: str, log_folder: str | None = "./llm_logs", **k

@logged_invoke
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential_jitter(initial=5, max=10, jitter=3)
reraise=True,
retry=retry_if_exception(_is_rate_limit_error),
stop=stop_after_attempt(6),
wait=wait_exponential_jitter(initial=2, max=60),
)
def invoke(self, messages: List[BaseMessage]) -> BaseMessage:
"""
Expand Down