Skip to content

Implement adapter retry for Pydantic Validation Error #8050

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
4 changes: 4 additions & 0 deletions dspy/adapters/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from dspy.adapters.base import Adapter
from dspy.adapters.chat_adapter import ChatAdapter
from dspy.adapters.json_adapter import JSONAdapter
from dspy.adapters.retry_adapter import RetryAdapter
from dspy.adapters.types import Image, History
from dspy.adapters.two_step_adapter import TwoStepAdapter

DEFAULT_ADAPTER = RetryAdapter(main_adapter=ChatAdapter(), fallback_adapter=JSONAdapter())

__all__ = [
"Adapter",
"ChatAdapter",
"JSONAdapter",
"RetryAdapter",
"Image",
"History",
"TwoStepAdapter",
4 changes: 1 addition & 3 deletions dspy/adapters/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from typing import TYPE_CHECKING, Any, Optional, Type

from dspy.adapters.types import History
from dspy.adapters.types.image import try_expand_image_tags
from dspy.signatures.signature import Signature
@@ -8,7 +7,6 @@
if TYPE_CHECKING:
from dspy.clients.lm import LM


class Adapter:
def __init__(self, callbacks: Optional[list[BaseCallback]] = None):
    """
    Store the callbacks attached to this adapter.

    Args:
        callbacks: Optional list of callbacks. Any falsy value (``None`` or an
            empty list) results in a fresh empty list being stored.
    """
    self.callbacks = callbacks if callbacks else []
@@ -22,7 +20,7 @@ def __init_subclass__(cls, **kwargs) -> None:

def _call_post_process(self, outputs: list[dict[str, Any]], signature: Type[Signature]) -> list[dict[str, Any]]:
values = []

for output in outputs:
output_logprobs = None

22 changes: 0 additions & 22 deletions dspy/adapters/chat_adapter.py
Original file line number Diff line number Diff line change
@@ -2,7 +2,6 @@
import textwrap
from typing import Any, Dict, NamedTuple, Optional, Type

from litellm import ContextWindowExceededError
from pydantic.fields import FieldInfo

from dspy.adapters.base import Adapter
@@ -13,7 +12,6 @@
parse_value,
translate_field_type,
)
from dspy.clients.lm import LM
from dspy.signatures.signature import Signature
from dspy.utils.callback import BaseCallback

@@ -29,26 +27,6 @@ class ChatAdapter(Adapter):
def __init__(self, callbacks: Optional[list[BaseCallback]] = None):
    """
    Initialize the adapter by forwarding ``callbacks`` to the parent initializer.

    Args:
        callbacks: Optional list of callbacks, passed through unchanged.
    """
    super().__init__(callbacks=callbacks)

def __call__(
    self,
    lm: LM,
    lm_kwargs: dict[str, Any],
    signature: Type[Signature],
    demos: list[dict[str, Any]],
    inputs: dict[str, Any],
) -> list[dict[str, Any]]:
    """
    Run the parent adapter call and, on failure, retry once with a JSONAdapter.

    The original exception is re-raised instead of falling back when it is a
    ``ContextWindowExceededError`` or when ``self`` is already a JSONAdapter,
    because switching adapters is not attempted in those cases.

    Args:
        lm: The language model to call.
        lm_kwargs: Keyword arguments forwarded to the call.
        signature: The signature describing the task.
        demos: Demo examples forwarded to the call.
        inputs: Input values forwarded to the call.

    Returns:
        The result of the parent call, or of the JSONAdapter fallback call.
    """
    try:
        return super().__call__(lm, lm_kwargs, signature, demos, inputs)
    except Exception as e:
        # Imported at call time rather than module level
        # (presumably to avoid a circular import -- confirm).
        from dspy.adapters.json_adapter import JSONAdapter

        context_overflow = isinstance(e, ContextWindowExceededError)
        already_json = isinstance(self, JSONAdapter)
        if not (context_overflow or already_json):
            # fallback to JSONAdapter
            json_adapter = JSONAdapter()
            return json_adapter(lm, lm_kwargs, signature, demos, inputs)

        # On context window exceeded error or already using JSONAdapter, we
        # don't want to retry with a different adapter.
        raise e

def format_field_description(self, signature: Type[Signature]) -> str:
return (
f"Your input fields are:\n{get_field_description_string(signature.input_fields)}\n"
145 changes: 145 additions & 0 deletions dspy/adapters/retry_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@

from typing import TYPE_CHECKING, Any, Optional, Type
import logging

from dspy.adapters.base import Adapter
from dspy.signatures.signature import Signature
from dspy.adapters.utils import create_signature_for_retry

if TYPE_CHECKING:
from dspy.clients.lm import LM

logger = logging.getLogger(__name__)

class RetryAdapter(Adapter):
    """
    Adapter that wraps another adapter and retries when LM completions fail to parse.

    The main adapter is tried first, then the optional fallback adapter, and
    finally the main adapter is retried with the previous failing response and
    its error message supplied as extra inputs.
    """

    def __init__(
        self,
        main_adapter: Adapter,
        fallback_adapter: Optional[Adapter] = None,
        main_adapter_max_retries: int = 3,
    ):
        """
        Initializes the RetryAdapter.

        Args:
            main_adapter (Adapter): The main adapter to use.
            fallback_adapter (Optional[Adapter]): The fallback adapter to use if the main adapter fails.
            main_adapter_max_retries (int): The maximum number of retries. Defaults to 3.
        """
        # Run the base initializer so attributes it sets up (e.g. `callbacks`)
        # also exist on this instance.
        super().__init__()
        self.main_adapter = main_adapter
        self.fallback_adapter = fallback_adapter
        self.main_adapter_max_retries = main_adapter_max_retries

    def __call__(
        self,
        lm: "LM",
        lm_kwargs: dict[str, Any],
        signature: Type[Signature],
        demos: list[dict[str, Any]],
        inputs: dict[str, Any],
    ) -> list[dict[str, Any]]:
        """
        Execute main_adapter and fallback_adapter in the following procedure:
            1. Call the main_adapter.
            2. If the main_adapter does not yield enough parsed completions, call the fallback_adapter
               for the missing ones.
            3. If completions are still missing, retry the main_adapter up to `main_adapter_max_retries`
               times, one completion per attempt, passing the previous failing response and its error
               message as additional inputs.

        The caller's `lm_kwargs` and `inputs` dictionaries are not modified.

        Args:
            lm (LM): The dspy.LM to use.
            lm_kwargs (dict[str, Any]): Additional arguments for the lm.
            signature (Type[Signature]): The signature of the function.
            demos (list[dict[str, Any]]): A list of demo examples.
            inputs (dict[str, Any]): A dictionary representing the user input.

        Returns:
            A list of parsed completions. Its size is at least the `n` entry of `lm_kwargs` (default 1).

        Raises:
            ValueError: If not enough outputs could be parsed after the maximum number of retries.
        """
        max_retries = max(self.main_adapter_max_retries, 0)
        n_completion = lm_kwargs.get("n", 1)

        # Work on shallow copies: "n" and the retry-only input fields are
        # rewritten below and must not leak back into the caller's dicts.
        lm_kwargs = dict(lm_kwargs)
        inputs = dict(inputs)

        outputs, parse_failures = self._call_adapter(
            self.main_adapter,
            lm,
            lm_kwargs,
            signature,
            demos,
            inputs,
        )
        if len(outputs) >= n_completion:
            return outputs

        # Most recent (response, error) pair to report back to the LM. It may
        # stay None when the LM simply returned fewer outputs than requested.
        last_failure = parse_failures[0] if parse_failures else None

        # Only ask for the completions that are still missing.
        lm_kwargs["n"] = n_completion - len(outputs)
        if self.fallback_adapter is not None:
            values, fallback_failures = self._call_adapter(
                self.fallback_adapter,
                lm,
                lm_kwargs,
                signature,
                demos,
                inputs,
            )
            outputs.extend(values)
            if len(outputs) >= n_completion:
                return outputs
            # Prefer the main adapter's failure for the retry prompt, since the
            # retry is formatted by the main adapter; use the fallback's only
            # when the main adapter recorded none.
            if last_failure is None and fallback_failures:
                last_failure = fallback_failures[0]

        # Retry the main adapter with previous response for `max_retries` times
        lm_kwargs["n"] = 1
        retry_signature = create_signature_for_retry(signature)
        for i in range(max_retries):
            if last_failure is None:
                # Nothing to report yet: reuse the original signature so the
                # retry-only input fields are not required.
                call_signature = signature
            else:
                call_signature = retry_signature
                inputs["previous_response"] = last_failure[0]
                inputs["error_message"] = str(last_failure[1])

            values, parse_failures = self._call_adapter(
                self.main_adapter,
                lm,
                lm_kwargs,
                call_signature,
                demos,
                inputs,
            )
            outputs.extend(values)
            if len(outputs) >= n_completion:
                return outputs

            # A retry can parse successfully while more completions are still
            # needed; there is then no new failure to log or report.
            if parse_failures:
                last_failure = parse_failures[0]
                logger.warning(f"Retry {i+1}/{max_retries} for {self.main_adapter.__class__.__name__} failed with error: {last_failure[1]}")

        # Chain the last parse error (if any) as the cause.
        cause = last_failure[1] if last_failure is not None else None
        raise ValueError("Failed to parse LM outputs for maximum retries.") from cause

    def _call_adapter(
        self,
        adapter: Adapter,
        lm: "LM",
        lm_kwargs: dict[str, Any],
        signature: Type[Signature],
        demos: list[dict[str, Any]],
        inputs: dict[str, Any],
    ) -> tuple[list[dict[str, Any]], list[tuple[Any, ValueError]]]:
        """
        Format the prompt with `adapter`, call the LM once and parse every returned output.

        Args:
            adapter (Adapter): The adapter used for both formatting and parsing.
            lm (LM): The dspy.LM to call.
            lm_kwargs (dict[str, Any]): Keyword arguments forwarded to the lm call.
            signature (Type[Signature]): The signature used for formatting and parsing.
            demos (list[dict[str, Any]]): A list of demo examples.
            inputs (dict[str, Any]): The input values.

        Returns:
            A tuple `(values, parse_failures)`: `values` holds the successfully parsed outputs and
            `parse_failures` holds `(completion_text, error)` pairs for outputs whose parsing raised
            a `ValueError`.
        """
        values = []
        parse_failures = []
        messages = adapter.format(signature=signature, demos=demos, inputs=inputs)
        outputs = lm(messages=messages, **lm_kwargs)
        for i, output in enumerate(outputs):
            output_logprobs = None

            # Dict outputs carry the completion text together with its logprobs.
            if isinstance(output, dict):
                output, output_logprobs = output["text"], output["logprobs"]

            try:
                value = adapter.parse(signature, output)
            except ValueError as e:
                logger.warning(f"Failed to parse the {i+1}/{lm_kwargs.get('n', 1)} LM output with adapter {adapter.__class__.__name__}. Error: {e}")
                # Record the completion text (not the raw dict) so it can be
                # shown to the LM as the previous response.
                parse_failures.append((output, e))
                continue

            if output_logprobs is not None:
                value["logprobs"] = output_logprobs

            values.append(value)

        return values, parse_failures
11 changes: 11 additions & 0 deletions dspy/adapters/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Type
import ast
import enum
import inspect
@@ -11,6 +12,8 @@
from pydantic.fields import FieldInfo

from dspy.signatures.utils import get_dspy_field_type
from dspy.signatures.field import InputField
from dspy.signatures.signature import Signature


def serialize_for_json(value: Any) -> Any:
@@ -237,3 +240,11 @@ def _quoted_string_for_literal_type_annotation(s: str) -> str:
else:
# Neither => enclose in single quotes
return f"'{s}'"

def create_signature_for_retry(signature: Type[Signature]):
    """
    Extend ``signature`` with the two input fields used when retrying a failed parse.

    Appends ``previous_response`` and then ``error_message`` via
    ``signature.append`` and returns the result of the second append
    (presumably a new signature class -- confirm against ``Signature.append``).

    Args:
        signature: The signature to extend.

    Returns:
        The signature produced by appending both input fields.
    """
    previous_response_field = InputField(
        desc="Previous response with format errors. You should avoid the same type of error as the previous response.",
    )
    with_previous_response = signature.append("previous_response", previous_response_field)

    error_message_field = InputField(
        desc="Error message for the previous response.",
    )
    return with_previous_response.append("error_message", error_message_field)
4 changes: 2 additions & 2 deletions dspy/predict/predict.py
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@

from pydantic import BaseModel

from dspy.adapters.chat_adapter import ChatAdapter
from dspy.adapters import DEFAULT_ADAPTER
from dspy.clients.base_lm import BaseLM
from dspy.clients.lm import LM
from dspy.dsp.utils.settings import settings
@@ -116,7 +116,7 @@ def _forward_postprocess(self, completions, signature, **kwargs):
def forward(self, **kwargs):
lm, config, signature, demos, kwargs = self._forward_preprocess(**kwargs)

adapter = settings.adapter or ChatAdapter()
adapter = settings.adapter or DEFAULT_ADAPTER

stream_listeners = settings.stream_listeners or []
stream = settings.send_stream is not None
3 changes: 2 additions & 1 deletion dspy/predict/refine.py
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@

import dspy
from dspy.adapters.utils import get_field_description_string
from dspy.adapters import DEFAULT_ADAPTER
from dspy.predict.predict import Prediction
from dspy.signatures import InputField, OutputField, Signature

@@ -100,7 +101,7 @@ def forward(self, **kwargs):
temps = list(dict.fromkeys(temps))[: self.N]
best_pred, best_trace, best_reward = None, None, -float("inf")
advice = None
adapter = dspy.settings.adapter or dspy.ChatAdapter()
adapter = dspy.settings.adapter or DEFAULT_ADAPTER

for idx, t in enumerate(temps):
lm_ = lm.copy(temperature=t)
74 changes: 0 additions & 74 deletions dspy/predict/retry.py

This file was deleted.

4 changes: 2 additions & 2 deletions dspy/teleprompt/bootstrap_finetune.py
Original file line number Diff line number Diff line change
@@ -4,7 +4,7 @@

import dspy
from dspy.adapters.base import Adapter
from dspy.adapters.chat_adapter import ChatAdapter
from dspy.adapters import DEFAULT_ADAPTER
from dspy.clients.lm import LM
from dspy.clients.utils_finetune import infer_data_format
from dspy.dsp.utils.settings import settings
@@ -163,7 +163,7 @@ def _prepare_finetune_data(self, trace_data: List[Dict[str, Any]], lm: LM, pred_
logger.info(f"After filtering with the metric, {len(trace_data)} examples remain")

data = []
adapter = self.adapter[lm] or settings.adapter or ChatAdapter()
adapter = self.adapter[lm] or settings.adapter or DEFAULT_ADAPTER
data_format = infer_data_format(adapter)
for item in trace_data:
for pred_ind, _ in enumerate(item["trace"]):
Loading