diff --git a/dspy/adapters/__init__.py b/dspy/adapters/__init__.py
index f6a128a7e7..f66ec95441 100644
--- a/dspy/adapters/__init__.py
+++ b/dspy/adapters/__init__.py
@@ -1,13 +1,17 @@
 from dspy.adapters.base import Adapter
 from dspy.adapters.chat_adapter import ChatAdapter
 from dspy.adapters.json_adapter import JSONAdapter
+from dspy.adapters.retry_adapter import RetryAdapter
 from dspy.adapters.types import Image, History
 from dspy.adapters.two_step_adapter import TwoStepAdapter
 
+DEFAULT_ADAPTER = RetryAdapter(main_adapter=ChatAdapter(), fallback_adapter=JSONAdapter())
+
 __all__ = [
     "Adapter",
     "ChatAdapter",
     "JSONAdapter",
+    "RetryAdapter",
     "Image",
     "History",
     "TwoStepAdapter",
diff --git a/dspy/adapters/base.py b/dspy/adapters/base.py
index 0fa3b5d657..07bab01069 100644
--- a/dspy/adapters/base.py
+++ b/dspy/adapters/base.py
@@ -1,5 +1,4 @@
 from typing import TYPE_CHECKING, Any, Optional, Type
-
 from dspy.adapters.types import History
 from dspy.adapters.types.image import try_expand_image_tags
 from dspy.signatures.signature import Signature
@@ -8,7 +7,6 @@
 if TYPE_CHECKING:
     from dspy.clients.lm import LM
 
-
 class Adapter:
     def __init__(self, callbacks: Optional[list[BaseCallback]] = None):
         self.callbacks = callbacks or []
@@ -22,7 +20,7 @@ def __init_subclass__(cls, **kwargs) -> None:
 
     def _call_post_process(self, outputs: list[dict[str, Any]], signature: Type[Signature]) -> list[dict[str, Any]]:
         values = []
-
+
         for output in outputs:
             output_logprobs = None
 
diff --git a/dspy/adapters/chat_adapter.py b/dspy/adapters/chat_adapter.py
index f2228b756f..09453fe6b2 100644
--- a/dspy/adapters/chat_adapter.py
+++ b/dspy/adapters/chat_adapter.py
@@ -2,7 +2,6 @@
 import textwrap
 from typing import Any, Dict, NamedTuple, Optional, Type
 
-from litellm import ContextWindowExceededError
 from pydantic.fields import FieldInfo
 
 from dspy.adapters.base import Adapter
@@ -13,7 +12,6 @@
     parse_value,
     translate_field_type,
 )
-from dspy.clients.lm import LM
 from dspy.signatures.signature import Signature
 from dspy.utils.callback import BaseCallback
 
@@ -29,26 +27,6 @@ class ChatAdapter(Adapter):
     def __init__(self, callbacks: Optional[list[BaseCallback]] = None):
         super().__init__(callbacks)
 
-    def __call__(
-        self,
-        lm: LM,
-        lm_kwargs: dict[str, Any],
-        signature: Type[Signature],
-        demos: list[dict[str, Any]],
-        inputs: dict[str, Any],
-    ) -> list[dict[str, Any]]:
-        try:
-            return super().__call__(lm, lm_kwargs, signature, demos, inputs)
-        except Exception as e:
-            # fallback to JSONAdapter
-            from dspy.adapters.json_adapter import JSONAdapter
-
-            if isinstance(e, ContextWindowExceededError) or isinstance(self, JSONAdapter):
-                # On context window exceeded error or already using JSONAdapter, we don't want to retry with a different
-                # adapter.
-                raise e
-            return JSONAdapter()(lm, lm_kwargs, signature, demos, inputs)
-
     def format_field_description(self, signature: Type[Signature]) -> str:
         return (
             f"Your input fields are:\n{get_field_description_string(signature.input_fields)}\n"
diff --git a/dspy/adapters/retry_adapter.py b/dspy/adapters/retry_adapter.py
new file mode 100644
index 0000000000..f77dca985a
--- /dev/null
+++ b/dspy/adapters/retry_adapter.py
@@ -0,0 +1,145 @@
+
+from typing import TYPE_CHECKING, Any, Optional, Type
+import logging
+
+from dspy.adapters.base import Adapter
+from dspy.signatures.signature import Signature
+from dspy.adapters.utils import create_signature_for_retry
+
+if TYPE_CHECKING:
+    from dspy.clients.lm import LM
+
+logger = logging.getLogger(__name__)
+
+class RetryAdapter(Adapter):
+    """
+    RetryAdapter is an adapter that retries the execution of another adapter a
+    specified number of times when it fails to parse completion outputs.
+    """
+
+    def __init__(self, main_adapter: Adapter, fallback_adapter: Optional[Adapter] = None, main_adapter_max_retries: int = 3):
+        """
+        Initializes the RetryAdapter.
+
+        Args:
+            main_adapter (Adapter): The main adapter to use.
+            fallback_adapter (Optional[Adapter]): The fallback adapter to use if the main adapter fails.
+            main_adapter_max_retries (int): The maximum number of retries for the main adapter. Defaults to 3.
+        """
+        self.main_adapter = main_adapter
+        self.fallback_adapter = fallback_adapter
+        self.main_adapter_max_retries = main_adapter_max_retries
+
+    def __call__(
+        self,
+        lm: "LM",
+        lm_kwargs: dict[str, Any],
+        signature: Type[Signature],
+        demos: list[dict[str, Any]],
+        inputs: dict[str, Any],
+    ) -> list[dict[str, Any]]:
+        """
+        Execute the main_adapter and fallback_adapter using the following procedure:
+        1. Call the main_adapter.
+        2. If the main_adapter fails, call the fallback_adapter.
+        3. If the fallback_adapter also fails (or is not set), retry the main_adapter with the previous response and error message included, up to `main_adapter_max_retries` times.
+
+        Args:
+            lm (LM): The dspy.LM to use.
+            lm_kwargs (dict[str, Any]): Additional keyword arguments for the LM.
+            signature (Type[Signature]): The DSPy signature describing the expected inputs and outputs.
+            demos (list[dict[str, Any]]): A list of demo examples.
+            inputs (dict[str, Any]): A dictionary of the user inputs.
+
+        Returns:
+            A list of parsed completions. The size of the list equals the `n` value in `lm_kwargs`, which defaults to 1.
+
+        Raises:
+            ValueError: If the outputs cannot be parsed after the maximum number of retries.
+        """
+        outputs = []
+        max_retries = max(self.main_adapter_max_retries, 0)
+        n_completion = lm_kwargs.get("n", 1)
+
+        values, parse_failures = self._call_adapter(
+            self.main_adapter,
+            lm,
+            lm_kwargs,
+            signature,
+            demos,
+            inputs,
+        )
+        outputs.extend(values)
+
+        if len(outputs) == n_completion:
+            return outputs
+
+        lm_kwargs["n"] = n_completion - len(outputs)
+        if self.fallback_adapter is not None:
+            outputs.extend(self._call_adapter(
+                self.fallback_adapter,
+                lm,
+                lm_kwargs,
+                signature,
+                demos,
+                inputs,
+            )[0])
+            if len(outputs) == n_completion:
+                return outputs
+
+        # Retry the main adapter, feeding back the previous response and error, up to `max_retries` times
+        lm_kwargs["n"] = 1
+        signature = create_signature_for_retry(signature)
+        if parse_failures:
+            inputs["previous_response"] = parse_failures[0][0]
+            inputs["error_message"] = str(parse_failures[0][1])
+        for i in range(max_retries):
+            values, parse_failures = self._call_adapter(
+                self.main_adapter,
+                lm,
+                lm_kwargs,
+                signature,
+                demos,
+                inputs,
+            )
+            outputs.extend(values)
+            if len(outputs) == n_completion:
+                return outputs
+            logger.warning(f"Retry {i+1}/{max_retries} for {self.main_adapter.__class__.__name__} failed with error: {parse_failures[0][1]}")
+            inputs["previous_response"] = parse_failures[0][0]
+            inputs["error_message"] = str(parse_failures[0][1])
+
+        # All retries exhausted; raise from the last parse failure
+        raise ValueError("Failed to parse LM outputs for maximum retries.") from parse_failures[0][1]
+
+    def _call_adapter(
+        self,
+        adapter: Adapter,
+        lm: "LM",
+        lm_kwargs: dict[str, Any],
+        signature: Type[Signature],
+        demos: list[dict[str, Any]],
+        inputs: dict[str, Any],
+    ):
+        values = []
+        parse_failures = []
+        messages = adapter.format(signature=signature, demos=demos, inputs=inputs)
+        outputs = lm(messages=messages, **lm_kwargs)
+        for i, output in enumerate(outputs):
+            try:
+                output_logprobs = None
+
+                if isinstance(output, dict):
+                    output, output_logprobs = output["text"], output["logprobs"]
+
+                value = adapter.parse(signature, output)
+
+                if output_logprobs is not None:
+                    value["logprobs"] = output_logprobs
+
+                values.append(value)
+            except ValueError as e:
+                logger.warning(f"Failed to parse the {i+1}/{lm_kwargs.get('n', 1)} LM output with adapter {adapter.__class__.__name__}. Error: {e}")
+                parse_failures.append((outputs[i], e))
+
+        return values, parse_failures
\ No newline at end of file
diff --git a/dspy/adapters/utils.py b/dspy/adapters/utils.py
index 4a3a480113..2059fabef6 100644
--- a/dspy/adapters/utils.py
+++ b/dspy/adapters/utils.py
@@ -1,3 +1,4 @@
+from typing import Type
 import ast
 import enum
 import inspect
@@ -11,6 +12,8 @@
 from pydantic.fields import FieldInfo
 
 from dspy.signatures.utils import get_dspy_field_type
+from dspy.signatures.field import InputField
+from dspy.signatures.signature import Signature
 
 
 def serialize_for_json(value: Any) -> Any:
@@ -237,3 +240,11 @@ def _quoted_string_for_literal_type_annotation(s: str) -> str:
     else:
         # Neither => enclose in single quotes
         return f"'{s}'"
+
+def create_signature_for_retry(signature: Type[Signature]):
+    signature = signature.append("previous_response", InputField(
+        desc="Previous response with format errors. You should avoid the same type of error as the previous response.",
+    )).append("error_message", InputField(
+        desc="Error message for the previous response.",
+    ))
+    return signature
diff --git a/dspy/predict/predict.py b/dspy/predict/predict.py
index e6c800dd1d..251e53c696 100644
--- a/dspy/predict/predict.py
+++ b/dspy/predict/predict.py
@@ -3,7 +3,7 @@
 
 from pydantic import BaseModel
 
-from dspy.adapters.chat_adapter import ChatAdapter
+from dspy.adapters import DEFAULT_ADAPTER
 from dspy.clients.base_lm import BaseLM
 from dspy.clients.lm import LM
 from dspy.dsp.utils.settings import settings
@@ -116,7 +116,7 @@ def _forward_postprocess(self, completions, signature, **kwargs):
     def forward(self, **kwargs):
         lm, config, signature, demos, kwargs = self._forward_preprocess(**kwargs)
 
-        adapter = settings.adapter or ChatAdapter()
+        adapter = settings.adapter or DEFAULT_ADAPTER
         stream_listeners = settings.stream_listeners or []
         stream = settings.send_stream is not None
diff --git a/dspy/predict/refine.py b/dspy/predict/refine.py
index 9991750d63..b983159ac9 100644
--- a/dspy/predict/refine.py
+++ b/dspy/predict/refine.py
@@ -6,6 +6,7 @@
 import dspy
 from dspy.adapters.utils import get_field_description_string
+from dspy.adapters import DEFAULT_ADAPTER
 from dspy.predict.predict import Prediction
 from dspy.signatures import InputField, OutputField, Signature
 
 
@@ -100,7 +101,7 @@ def forward(self, **kwargs):
         temps = list(dict.fromkeys(temps))[: self.N]
         best_pred, best_trace, best_reward = None, None, -float("inf")
         advice = None
-        adapter = dspy.settings.adapter or dspy.ChatAdapter()
+        adapter = dspy.settings.adapter or DEFAULT_ADAPTER
 
         for idx, t in enumerate(temps):
             lm_ = lm.copy(temperature=t)
diff --git a/dspy/predict/retry.py b/dspy/predict/retry.py
deleted file mode 100644
index 66542ba439..0000000000
--- a/dspy/predict/retry.py
+++ /dev/null
@@ -1,74 +0,0 @@
-# import copy
-
-# import dspy
-
-# from .predict import Predict
-
-
-# class Retry(Predict):
-#     def __init__(self, module):
-#         super().__init__(module.signature)
-#         self.module = module
-#         self.original_signature = module.signature
-#         self.original_forward = module.forward
-#         self.new_signature = self._create_new_signature(self.original_signature)
-
-#     def _create_new_signature(self, signature):
-#         # Add "Past" input fields for each output field
-#         for key, value in signature.output_fields.items():
-#             actual_prefix = value.json_schema_extra["prefix"].split(":")[0] + ":"
-#             signature = signature.append(f"past_{key}", dspy.InputField(
-#                 prefix="Previous " + actual_prefix,
-#                 desc=f"past {actual_prefix[:-1]} with errors",
-#                 format=value.json_schema_extra.get("format"),
-#             ))
-
-#         signature = signature.append("feedback", dspy.InputField(
-#             prefix="Instructions:",
-#             desc="Some instructions you must satisfy",
-#             format=str,
-#         ))
-
-#         return signature
-
-#     def forward(self, *, past_outputs, **kwargs):
-#         # Take into account the possible new signature, as in TypedPredictor
-#         new_signature = kwargs.pop("new_signature", None)
-#         if new_signature:
-#             self.original_signature = new_signature
-#             self.new_signature = self._create_new_signature(self.original_signature)
-
-#         # Convert the dict past_outputs={"answer": ...} to kwargs
-#         # {past_answer=..., ...}
-#         for key, value in past_outputs.items():
-#             past_key = f"past_{key}"
-#             if past_key in self.new_signature.input_fields:
-#                 kwargs[past_key] = value
-#         # Tell the wrapped module to use the new signature.
-#         # Note: This only works if the wrapped module is a Predict or ChainOfThought.
-#         kwargs["new_signature"] = self.new_signature
-#         return self.original_forward(**kwargs)
-
-#     def __call__(self, **kwargs):
-#         copy.deepcopy(kwargs)
-#         kwargs["_trace"] = False
-#         kwargs.setdefault("demos", self.demos if self.demos is not None else [])
-
-#         # perform backtracking
-#         if dspy.settings.backtrack_to == self:
-#             for key, value in dspy.settings.backtrack_to_args.items():
-#                 kwargs.setdefault(key, value)
-#             pred = self.forward(**kwargs)
-#         else:
-#             pred = self.module(**kwargs)
-
-#         # now pop multiple reserved keys
-#         # NOTE(shangyin) past_outputs seems not useful to include in demos,
-#         # therefore dropped
-#         for key in ["_trace", "demos", "signature", "new_signature", "config", "lm", "past_outputs"]:
-#             kwargs.pop(key, None)
-
-#         if dspy.settings.trace is not None:
-#             trace = dspy.settings.trace
-#             trace.append((self, {**kwargs}, pred))
-#         return pred
diff --git a/dspy/teleprompt/bootstrap_finetune.py b/dspy/teleprompt/bootstrap_finetune.py
index 129a8f7933..ce16cae700 100644
--- a/dspy/teleprompt/bootstrap_finetune.py
+++ b/dspy/teleprompt/bootstrap_finetune.py
@@ -4,7 +4,7 @@
 
 import dspy
 from dspy.adapters.base import Adapter
-from dspy.adapters.chat_adapter import ChatAdapter
+from dspy.adapters import DEFAULT_ADAPTER
 from dspy.clients.lm import LM
 from dspy.clients.utils_finetune import infer_data_format
 from dspy.dsp.utils.settings import settings
@@ -163,7 +163,7 @@ def _prepare_finetune_data(self, trace_data: List[Dict[str, Any]], lm: LM, pred_
             logger.info(f"After filtering with the metric, {len(trace_data)} examples remain")
 
         data = []
-        adapter = self.adapter[lm] or settings.adapter or ChatAdapter()
+        adapter = self.adapter[lm] or settings.adapter or DEFAULT_ADAPTER
         data_format = infer_data_format(adapter)
         for item in trace_data:
             for pred_ind, _ in enumerate(item["trace"]):
diff --git a/dspy/utils/dummies.py b/dspy/utils/dummies.py
index 99a028fd06..f233f81f3a 100644
--- a/dspy/utils/dummies.py
+++ b/dspy/utils/dummies.py
@@ -73,6 +73,7 @@ def __init__(self, answers: Union[list[dict[str, str]], dict[str, dict[str, str]
         if isinstance(answers, list):
             self.answers = iter(answers)
         self.follow_examples = follow_examples
+        self.call_count = 0
 
     def _use_example(self, messages):
         # find all field names
@@ -126,6+127,7 @@ def format_answer_fields(field_names_and_values: Dict[str, Any]):
             entry = dict(**entry, cost=0)
             self.history.append(entry)
             self.update_global_history(entry)
+            self.call_count += 1
 
         return outputs
 
diff --git a/tests/adapters/test_retry_adapter.py b/tests/adapters/test_retry_adapter.py
new file mode 100644
index 0000000000..1f3cc37163
--- /dev/null
+++ b/tests/adapters/test_retry_adapter.py
@@ -0,0 +1,92 @@
+import pytest
+from unittest.mock import patch
+import dspy
+from dspy.adapters import RetryAdapter, ChatAdapter, JSONAdapter
+from dspy.signatures.signature import Signature
+from dspy.utils.dummies import DummyLM
+
+
+class MockSignature(Signature):
+    question: str = dspy.InputField()
+    answer: int = dspy.OutputField()
+
+
+@pytest.mark.parametrize(
+    "demos",
+    [
+        [],
+        [dspy.Example({"question": "6 x 7", "answer": 42})],
+    ],
+)
+@pytest.mark.parametrize(
+    "max_retries",
+    [
+        0,
+        3,
+    ],
+)
+@pytest.mark.parametrize(
+    "n",
+    [
+        1,
+        3,
+    ],
+)
+def test_adapter_max_retry(demos, max_retries, n):
+    main_adapter = ChatAdapter()
+    fallback_adapter = JSONAdapter()
+    adapter = RetryAdapter(main_adapter=main_adapter, fallback_adapter=fallback_adapter, main_adapter_max_retries=max_retries)
+    lm = DummyLM([{"answer": "42"}] * (n * 2 + max_retries))
+    inputs = {"question": "6 x 7"}
+
+    with dspy.context(lm=lm):
+        with (
+            patch.object(
+                main_adapter,
+                "parse",
+                side_effect=ValueError("error"),
+            ) as mock_main_parse,
+            patch.object(
+                main_adapter,
+                "format",
+                wraps=main_adapter.format,
+            ) as mock_main_format,
+            patch.object(
+                fallback_adapter,
+                "parse",
+                side_effect=ValueError("error"),
+            ) as mock_fallback_parse,
+        ):
+            with pytest.raises(ValueError, match="Failed to parse LM outputs for maximum retries"):
+                adapter(lm, {"n": n}, MockSignature, demos, inputs)
+
+    assert mock_main_parse.call_count == n + max_retries
+    assert mock_fallback_parse.call_count == n
+    assert lm.call_count == max_retries + 2
+
+    assert mock_main_format.call_count == max_retries + 1
+    _, kwargs = mock_main_format.call_args
+    assert kwargs["inputs"]["previous_response"] == "[[ ## answer ## ]]\n42"
+    assert kwargs["inputs"]["error_message"] == "error"
+    assert kwargs["inputs"]["question"] == "6 x 7"
+
+
+def test_adapter_fallback():
+    main_adapter = JSONAdapter()
+    fallback_adapter = ChatAdapter()
+    adapter = RetryAdapter(main_adapter=main_adapter, fallback_adapter=fallback_adapter, main_adapter_max_retries=1)
+    lm = DummyLM([{"answer": "42"}] * 3)
+    inputs = {"question": "6 x 7"}
+
+    with dspy.context(lm=lm):
+        with (
+            patch.object(
+                main_adapter,
+                "parse",
+                side_effect=ValueError("error"),
+            ) as mock_main_parse,
+        ):
+            result = adapter(lm, {}, MockSignature, [], inputs)
+
+    assert result == [{"answer": 42}]
+    assert mock_main_parse.call_count == 1
diff --git a/tests/adapters/test_two_step_adapter.py b/tests/adapters/test_two_step_adapter.py
index cf83c3dbdc..9e7f46d3a4 100644
--- a/tests/adapters/test_two_step_adapter.py
+++ b/tests/adapters/test_two_step_adapter.py
@@ -2,6 +2,7 @@
 import pytest
 
 import dspy
+from dspy.utils.dummies import DummyLM
 
 
 def test_two_step_adapter_call():
@@ -141,15 +142,7 @@ class ComplexSignature(dspy.Signature):
     first_response = "main LM response"
 
-    mock_extraction_lm = mock.MagicMock(spec=dspy.LM)
-    mock_extraction_lm.return_value = ["""
-        {
-            "tags": ["AI", "deep learning", "neural networks"],
-            "confidence": 0.87
-        }
-    """]
-    mock_extraction_lm.kwargs = {"temperature": 1.0}
-    mock_extraction_lm.model = "openai/gpt-4o"
+    mock_extraction_lm = DummyLM([{"tags": ["AI", "deep learning", "neural networks"], "confidence": 0.87}])
 
     adapter = dspy.TwoStepAdapter(mock_extraction_lm)
     dspy.configure(adapter=adapter, lm=mock_extraction_lm)
 
diff --git a/tests/propose/test_grounded_proposer.py b/tests/propose/test_grounded_proposer.py
index ddf041632b..9d2108a2da 100644
--- a/tests/propose/test_grounded_proposer.py
+++ b/tests/propose/test_grounded_proposer.py
@@ -14,7 +14,7 @@
 )
 def test_propose_instructions_for_program(demo_candidates):
     # Set large numner here so that lm always returns the same response
-    prompt_model = DummyLM([{"proposed_instruction": "instruction"}] * 10)
+    prompt_model = DummyLM([{"observations": ""}, {"summary": ""}, {"program_description": ""}, {"module_description": ""}, {"proposed_instruction": "instruction"}])
     program = Predict("question -> answer")
     trainset = []
@@ -36,7 +36,7 @@
 def test_propose_instruction_for_predictor(demo_candidates):
-    prompt_model = DummyLM([{"proposed_instruction": "instruction"}] * 10)
+    prompt_model = DummyLM([{"observations": ""}, {"summary": ""}, {"program_description": ""}, {"module_description": ""}, {"proposed_instruction": "instruction"}])
     program = Predict("question -> answer")
     proposer = GroundedProposer(prompt_model=prompt_model, program=program, trainset=[], verbose=False)
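
Usage note: with this change, any call that does not set `settings.adapter` now goes through `DEFAULT_ADAPTER`, a `RetryAdapter` that wraps `ChatAdapter` and falls back to `JSONAdapter`. The sketch below shows the equivalent explicit configuration; the model name is a placeholder, and `dspy.configure` / `dspy.Predict` are used the same way as in the tests above.

# Sketch only: explicit configuration equivalent to the new DEFAULT_ADAPTER.
# The model name is a placeholder; any dspy.LM instance works here.
import dspy
from dspy.adapters import ChatAdapter, JSONAdapter, RetryAdapter

lm = dspy.LM("openai/gpt-4o-mini")
adapter = RetryAdapter(
    main_adapter=ChatAdapter(),      # formats the prompt and parses completions first
    fallback_adapter=JSONAdapter(),  # called once if the main adapter cannot parse
    main_adapter_max_retries=3,      # then the main adapter is re-asked with the failed output fed back
)
dspy.configure(lm=lm, adapter=adapter)

predict = dspy.Predict("question -> answer")
print(predict(question="What is 6 x 7?").answer)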
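
For reference, a small sketch of what `create_signature_for_retry` does on the retry path, using a signature shaped like `MockSignature` from the new test; the two extra field names come directly from the diff.

# Sketch only: the retry-time signature gains two extra input fields.
import dspy
from dspy.adapters.utils import create_signature_for_retry

class QA(dspy.Signature):
    question: str = dspy.InputField()
    answer: int = dspy.OutputField()

retry_signature = create_signature_for_retry(QA)
assert set(retry_signature.input_fields) == {"question", "previous_response", "error_message"}
assert set(retry_signature.output_fields) == {"answer"}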