From 925536cbed7addd01b86400a99fd9cc918156e5e Mon Sep 17 00:00:00 2001
From: Daniel Zayas
Date: Thu, 4 Dec 2025 12:20:32 -0800
Subject: [PATCH] update llm.py to back off after LLM provider rate limits

---
 launch/launch/utilities/llm.py | 70 +++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 67 insertions(+), 3 deletions(-)

diff --git a/launch/launch/utilities/llm.py b/launch/launch/utilities/llm.py
index 88e9abd..39f01c4 100644
--- a/launch/launch/utilities/llm.py
+++ b/launch/launch/utilities/llm.py
@@ -4,8 +4,70 @@ import os
 from functools import wraps
 from typing import List
 
+
 from langchain_core.messages import BaseMessage, HumanMessage
-from tenacity import retry, stop_after_attempt, wait_exponential_jitter
+from tenacity import (
+    retry,
+    retry_if_exception,
+    stop_after_attempt,
+    wait_exponential_jitter,
+)
+
+try:  # OpenAI SDK present
+    from openai import RateLimitError as OpenAIRateLimitError
+except Exception:  # pragma: no cover - openai optional
+    class OpenAIRateLimitError(Exception):
+        """Fallback placeholder when OpenAI SDK is unavailable."""
+
+
+try:
+    from anthropic import RateLimitError as AnthropicRateLimitError
+except Exception:  # pragma: no cover - anthropic optional
+    class AnthropicRateLimitError(Exception):
+        """Fallback placeholder when Anthropic SDK is unavailable."""
+
+
+def _is_rate_limit_error(exc: Exception) -> bool:
+    """
+    Detect rate limit errors across different SDKs.
+
+    Used as the retry predicate for ``invoke``. ``exc`` and every exception
+    it wraps (``__cause__``, or ``__context__`` unless suppressed) is checked
+    for any of:
+
+    * an OpenAI or Anthropic ``RateLimitError`` instance;
+    * HTTP status 429 (int or numeric string) in ``status_code``/``status``,
+      on the exception itself or on its ``response`` attribute;
+    * a message containing "rate limit" (case-insensitive).
+    """
+    seen = set()
+    current = exc
+    # Walk the chain so a 429 re-raised inside a wrapper exception is still
+    # retried; ``seen`` guards against cyclic exception chains.
+    while current is not None and id(current) not in seen:
+        seen.add(id(current))
+        if isinstance(current, (OpenAIRateLimitError, AnthropicRateLimitError)):
+            return True
+
+        # Some HTTP clients keep the status on ``exc.response`` instead.
+        for holder in (current, getattr(current, "response", None)):
+            status = getattr(holder, "status_code", None) or getattr(
+                holder, "status", None
+            )
+            if status == 429 or str(status).strip() == "429":
+                return True
+
+        message = str(getattr(current, "message", current)).lower()
+        if "rate limit" in message:
+            return True
+
+        if current.__cause__ is not None:
+            current = current.__cause__
+        elif not current.__suppress_context__:
+            current = current.__context__
+        else:
+            current = None
+    return False
 
 
 def logged_invoke(invoke_func):
@@ -78,8 +140,10 @@ def __init__(self, llm_provider: str, log_folder: str | None = "./llm_logs", **k
 
     @logged_invoke
     @retry(
-        stop=stop_after_attempt(3),
-        wait=wait_exponential_jitter(initial=5, max=10, jitter=3)
+        reraise=True,
+        retry=retry_if_exception(_is_rate_limit_error),
+        stop=stop_after_attempt(6),
+        wait=wait_exponential_jitter(initial=2, max=60),
     )
     def invoke(self, messages: List[BaseMessage]) -> BaseMessage:
         """