From 84c1bfe28a9ad910a32cc389b34112138f6347e1 Mon Sep 17 00:00:00 2001 From: TomuHirata Date: Tue, 8 Apr 2025 20:28:18 +0900 Subject: [PATCH 01/12] implement adapter retry --- dspy/adapters/base.py | 57 ++++++++++++++++++---------- dspy/adapters/utils.py | 11 ++++++ dspy/dsp/utils/settings.py | 1 + dspy/predict/retry.py | 74 ------------------------------------- dspy/utils/dummies.py | 2 + tests/adapters/test_base.py | 33 +++++++++++++++++ 6 files changed, 84 insertions(+), 94 deletions(-) delete mode 100644 dspy/predict/retry.py create mode 100644 tests/adapters/test_base.py diff --git a/dspy/adapters/base.py b/dspy/adapters/base.py index 52503781bf..be3acf0164 100644 --- a/dspy/adapters/base.py +++ b/dspy/adapters/base.py @@ -1,13 +1,18 @@ from typing import TYPE_CHECKING, Any, Optional, Type - +import json +import logging +from pydantic_core import ValidationError from dspy.adapters.types import History from dspy.adapters.types.image import try_expand_image_tags from dspy.signatures.signature import Signature from dspy.utils.callback import BaseCallback, with_callbacks +from dspy.adapters.utils import create_signature_for_retry +from dspy.dsp.utils.settings import settings if TYPE_CHECKING: from dspy.clients.lm import LM +logger = logging.getLogger(__name__) class Adapter: def __init__(self, callbacks: Optional[list[BaseCallback]] = None): @@ -28,25 +33,37 @@ def __call__( demos: list[dict[str, Any]], inputs: dict[str, Any], ) -> list[dict[str, Any]]: - inputs = self.format(signature, demos, inputs) - - outputs = lm(messages=inputs, **lm_kwargs) - values = [] - - for output in outputs: - output_logprobs = None - - if isinstance(output, dict): - output, output_logprobs = output["text"], output["logprobs"] - - value = self.parse(signature, output) - - if output_logprobs is not None: - value["logprobs"] = output_logprobs - - values.append(value) - - return values + retry_count = 0 + max_retries = max(settings.adapter_retry_count, 0) + outputs = None + while (True): + messages = self.format(signature=signature, demos=demos, inputs=inputs) + outputs = lm(messages=messages, **lm_kwargs) + values = [] + try: + for output in outputs: + output_logprobs = None + + if isinstance(output, dict): + output, output_logprobs = output["text"], output["logprobs"] + + value = self.parse(signature, output) + + if output_logprobs is not None: + value["logprobs"] = output_logprobs + + values.append(value) + + return values + except ValidationError: + if retry_count >= max_retries: + raise + + logger.debug("A ValidationError occurred while parsing the LM output. Retrying with a new signature.", exc_info=True) + if retry_count == 0: + signature = create_signature_for_retry(signature) + inputs["previous_response"] = json.dumps(outputs) + retry_count += 1 def format( self, diff --git a/dspy/adapters/utils.py b/dspy/adapters/utils.py index 4a3a480113..5e903790fe 100644 --- a/dspy/adapters/utils.py +++ b/dspy/adapters/utils.py @@ -1,3 +1,4 @@ +from typing import Type import ast import enum import inspect @@ -11,6 +12,8 @@ from pydantic.fields import FieldInfo from dspy.signatures.utils import get_dspy_field_type +from dspy.signatures.field import InputField +from dspy.signatures.signature import Signature def serialize_for_json(value: Any) -> Any: @@ -237,3 +240,11 @@ def _quoted_string_for_literal_type_annotation(s: str) -> str: else: # Neither => enclose in single quotes return f"'{s}'" + +def create_signature_for_retry(signature: Type[Signature]): + # Add previous_response field + signature = signature.append("previous_response", InputField( + prefix="Previous Response", + desc="Previous response with format errors", + )) + return signature diff --git a/dspy/dsp/utils/settings.py b/dspy/dsp/utils/settings.py index fad3cc67f0..2a25bc12ba 100644 --- a/dspy/dsp/utils/settings.py +++ b/dspy/dsp/utils/settings.py @@ -22,6 +22,7 @@ disable_history=False, track_usage=False, usage_tracker=None, + adapter_retry_count=3, # number of times to retry the adapter call ) # Global base configuration and owner tracking diff --git a/dspy/predict/retry.py b/dspy/predict/retry.py deleted file mode 100644 index 66542ba439..0000000000 --- a/dspy/predict/retry.py +++ /dev/null @@ -1,74 +0,0 @@ -# import copy - -# import dspy - -# from .predict import Predict - - -# class Retry(Predict): -# def __init__(self, module): -# super().__init__(module.signature) -# self.module = module -# self.original_signature = module.signature -# self.original_forward = module.forward -# self.new_signature = self._create_new_signature(self.original_signature) - -# def _create_new_signature(self, signature): -# # Add "Past" input fields for each output field -# for key, value in signature.output_fields.items(): -# actual_prefix = value.json_schema_extra["prefix"].split(":")[0] + ":" -# signature = signature.append(f"past_{key}", dspy.InputField( -# prefix="Previous " + actual_prefix, -# desc=f"past {actual_prefix[:-1]} with errors", -# format=value.json_schema_extra.get("format"), -# )) - -# signature = signature.append("feedback", dspy.InputField( -# prefix="Instructions:", -# desc="Some instructions you must satisfy", -# format=str, -# )) - -# return signature - -# def forward(self, *, past_outputs, **kwargs): -# # Take into account the possible new signature, as in TypedPredictor -# new_signature = kwargs.pop("new_signature", None) -# if new_signature: -# self.original_signature = new_signature -# self.new_signature = self._create_new_signature(self.original_signature) - -# # Convert the dict past_outputs={"answer": ...} to kwargs -# # {past_answer=..., ...} -# for key, value in past_outputs.items(): -# past_key = f"past_{key}" -# if past_key in self.new_signature.input_fields: -# kwargs[past_key] = value -# # Tell the wrapped module to use the new signature. -# # Note: This only works if the wrapped module is a Predict or ChainOfThought. -# kwargs["new_signature"] = self.new_signature -# return self.original_forward(**kwargs) - -# def __call__(self, **kwargs): -# copy.deepcopy(kwargs) -# kwargs["_trace"] = False -# kwargs.setdefault("demos", self.demos if self.demos is not None else []) - -# # perform backtracking -# if dspy.settings.backtrack_to == self: -# for key, value in dspy.settings.backtrack_to_args.items(): -# kwargs.setdefault(key, value) -# pred = self.forward(**kwargs) -# else: -# pred = self.module(**kwargs) - -# # now pop multiple reserved keys -# # NOTE(shangyin) past_outputs seems not useful to include in demos, -# # therefore dropped -# for key in ["_trace", "demos", "signature", "new_signature", "config", "lm", "past_outputs"]: -# kwargs.pop(key, None) - -# if dspy.settings.trace is not None: -# trace = dspy.settings.trace -# trace.append((self, {**kwargs}, pred)) -# return pred diff --git a/dspy/utils/dummies.py b/dspy/utils/dummies.py index 9b96012b93..a27ec43e7e 100644 --- a/dspy/utils/dummies.py +++ b/dspy/utils/dummies.py @@ -73,6 +73,7 @@ def __init__(self, answers: Union[list[dict[str, str]], dict[str, dict[str, str] if isinstance(answers, list): self.answers = iter(answers) self.follow_examples = follow_examples + self.call_count = 0 def _use_example(self, messages): # find all field names @@ -126,6 +127,7 @@ def format_answer_fields(field_names_and_values: Dict[str, Any]): entry = dict(**entry, cost=0) self.history.append(entry) self.update_global_history(entry) + self.call_count += 1 return outputs diff --git a/tests/adapters/test_base.py b/tests/adapters/test_base.py new file mode 100644 index 0000000000..2a0b8070a8 --- /dev/null +++ b/tests/adapters/test_base.py @@ -0,0 +1,33 @@ +import pytest +from unittest.mock import patch +from pydantic_core import ValidationError +from pydantic import TypeAdapter +import dspy +from dspy.adapters.base import Adapter +from dspy.signatures.signature import Signature +from dspy.utils.dummies import DummyLM + +class MockSignature(Signature): + question: str = dspy.InputField() + answer: int = dspy.OutputField() + +def test_adapter_retry_logic(): + adapter = Adapter() + lm = DummyLM([{"answer": "42"}]) + demos = [] + inputs = {"question": "6 x 7"} + + with dspy.context(adapter_retry_count=1, lm=lm): + # Mock parse to raise ValidationError on the first call and succeed on the second + with patch.object(adapter, "format", return_value="formatted_input") as mock_format, \ + patch.object(adapter, "parse", side_effect=lambda signature, completion: TypeAdapter(signature).validate_python(completion)) as mock_parse: + with pytest.raises(ValidationError): + adapter(lm, {}, MockSignature, demos, inputs) + + assert mock_format.call_count == 2 + _, kwargs = mock_format.call_args + assert kwargs["inputs"]["previous_response"] == '["[[ ## answer ## ]]\\n42"]' + assert kwargs["inputs"]["question"] == "6 x 7" + assert mock_parse.call_count == 2 + assert lm.call_count == 2 + \ No newline at end of file From 5d3d927d22c13f96edb14a3c70b4845b5eb49f2d Mon Sep 17 00:00:00 2001 From: TomuHirata Date: Wed, 9 Apr 2025 09:48:09 +0900 Subject: [PATCH 02/12] lint --- tests/adapters/test_base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/adapters/test_base.py b/tests/adapters/test_base.py index 2a0b8070a8..f487b39644 100644 --- a/tests/adapters/test_base.py +++ b/tests/adapters/test_base.py @@ -30,4 +30,3 @@ def test_adapter_retry_logic(): assert kwargs["inputs"]["question"] == "6 x 7" assert mock_parse.call_count == 2 assert lm.call_count == 2 - \ No newline at end of file From 93f4543c899ce301b0ba4628c8ebc9a43e518439 Mon Sep 17 00:00:00 2001 From: TomuHirata Date: Wed, 9 Apr 2025 12:14:36 +0900 Subject: [PATCH 03/12] use ValueError instead of Validation Error --- dspy/adapters/base.py | 5 ++--- dspy/adapters/utils.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/dspy/adapters/base.py b/dspy/adapters/base.py index be3acf0164..d9f476fb66 100644 --- a/dspy/adapters/base.py +++ b/dspy/adapters/base.py @@ -1,7 +1,6 @@ from typing import TYPE_CHECKING, Any, Optional, Type import json import logging -from pydantic_core import ValidationError from dspy.adapters.types import History from dspy.adapters.types.image import try_expand_image_tags from dspy.signatures.signature import Signature @@ -55,11 +54,11 @@ def __call__( values.append(value) return values - except ValidationError: + except ValueError: if retry_count >= max_retries: raise - logger.debug("A ValidationError occurred while parsing the LM output. Retrying with a new signature.", exc_info=True) + logger.debug("A ValueError occurred while parsing the LM output. Retrying with a new signature.", exc_info=True) if retry_count == 0: signature = create_signature_for_retry(signature) inputs["previous_response"] = json.dumps(outputs) diff --git a/dspy/adapters/utils.py b/dspy/adapters/utils.py index 5e903790fe..01c3b9b4c7 100644 --- a/dspy/adapters/utils.py +++ b/dspy/adapters/utils.py @@ -245,6 +245,6 @@ def create_signature_for_retry(signature: Type[Signature]): # Add previous_response field signature = signature.append("previous_response", InputField( prefix="Previous Response", - desc="Previous response with format errors", + desc="Previous response with format errors. You should avoid the same type of error as the previous response.", )) return signature From 181ace4db7b31ad5e5bce047b1b919b8c4e65a9d Mon Sep 17 00:00:00 2001 From: TomuHirata Date: Thu, 10 Apr 2025 12:36:05 +0900 Subject: [PATCH 04/12] add demos to tests --- tests/adapters/test_base.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/tests/adapters/test_base.py b/tests/adapters/test_base.py index f487b39644..0094f531fc 100644 --- a/tests/adapters/test_base.py +++ b/tests/adapters/test_base.py @@ -7,20 +7,34 @@ from dspy.signatures.signature import Signature from dspy.utils.dummies import DummyLM + class MockSignature(Signature): question: str = dspy.InputField() answer: int = dspy.OutputField() -def test_adapter_retry_logic(): + +@pytest.mark.parametrize( + "demos", + [ + [], + [dspy.Example({"question": "6 x 7", "answer": 42})], + ], +) +def test_adapter_retry_logic(demos): adapter = Adapter() lm = DummyLM([{"answer": "42"}]) - demos = [] inputs = {"question": "6 x 7"} with dspy.context(adapter_retry_count=1, lm=lm): # Mock parse to raise ValidationError on the first call and succeed on the second - with patch.object(adapter, "format", return_value="formatted_input") as mock_format, \ - patch.object(adapter, "parse", side_effect=lambda signature, completion: TypeAdapter(signature).validate_python(completion)) as mock_parse: + with ( + patch.object(adapter, "format", return_value="formatted_input") as mock_format, + patch.object( + adapter, + "parse", + side_effect=lambda signature, completion: TypeAdapter(signature).validate_python(completion), + ) as mock_parse, + ): with pytest.raises(ValidationError): adapter(lm, {}, MockSignature, demos, inputs) From 78794ba3d25c7cf1aa15416569627deeca4d6c9c Mon Sep 17 00:00:00 2001 From: TomuHirata Date: Thu, 10 Apr 2025 13:35:49 +0900 Subject: [PATCH 05/12] avoid forcing all fields in demos --- dspy/adapters/chat_adapter.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dspy/adapters/chat_adapter.py b/dspy/adapters/chat_adapter.py index 98ed24aa2b..3b4d272e93 100644 --- a/dspy/adapters/chat_adapter.py +++ b/dspy/adapters/chat_adapter.py @@ -91,9 +91,9 @@ def format_user_message_content( ) -> str: messages = [prefix] for k, v in signature.input_fields.items(): - value = inputs[k] - formatted_field_value = format_field_value(field_info=v, value=value) - messages.append(f"[[ ## {k} ## ]]\n{formatted_field_value}") + if value := inputs.get(k): + formatted_field_value = format_field_value(field_info=v, value=value) + messages.append(f"[[ ## {k} ## ]]\n{formatted_field_value}") output_requirements = self.user_message_output_requirements(signature) if output_requirements is not None: From cbb45e2bb0548f7fc39d0f777862bcc404f5e716 Mon Sep 17 00:00:00 2001 From: TomuHirata Date: Thu, 10 Apr 2025 18:02:21 +0900 Subject: [PATCH 06/12] fix test --- tests/propose/test_grounded_proposer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/propose/test_grounded_proposer.py b/tests/propose/test_grounded_proposer.py index ddf041632b..9d2108a2da 100644 --- a/tests/propose/test_grounded_proposer.py +++ b/tests/propose/test_grounded_proposer.py @@ -14,7 +14,7 @@ ) def test_propose_instructions_for_program(demo_candidates): # Set large numner here so that lm always returns the same response - prompt_model = DummyLM([{"proposed_instruction": "instruction"}] * 10) + prompt_model = DummyLM([{"observations": ""}, {"summary": ""}, {"program_description": ""}, {"module_description": ""}, {"proposed_instruction": "instruction"}]) program = Predict("question -> answer") trainset = [] @@ -36,7 +36,7 @@ def test_propose_instructions_for_program(demo_candidates): ], ) def test_propose_instruction_for_predictor(demo_candidates): - prompt_model = DummyLM([{"proposed_instruction": "instruction"}] * 10) + prompt_model = DummyLM([{"observations": ""}, {"summary": ""}, {"program_description": ""}, {"module_description": ""}, {"proposed_instruction": "instruction"}]) program = Predict("question -> answer") proposer = GroundedProposer(prompt_model=prompt_model, program=program, trainset=[], verbose=False) From f9ccd3a75e856489ef8901700cfb6d89994455da Mon Sep 17 00:00:00 2001 From: TomuHirata Date: Thu, 10 Apr 2025 19:22:05 +0900 Subject: [PATCH 07/12] change default adapter_retry_count to 0 --- dspy/dsp/utils/settings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dspy/dsp/utils/settings.py b/dspy/dsp/utils/settings.py index 2a25bc12ba..9b8966defc 100644 --- a/dspy/dsp/utils/settings.py +++ b/dspy/dsp/utils/settings.py @@ -22,7 +22,7 @@ disable_history=False, track_usage=False, usage_tracker=None, - adapter_retry_count=3, # number of times to retry the adapter call + adapter_retry_count=0, # number of times to retry the adapter call ) # Global base configuration and owner tracking From 574c25261f62e5b030d32656ef347087a771755d Mon Sep 17 00:00:00 2001 From: TomuHirata Date: Mon, 14 Apr 2025 19:28:40 +0900 Subject: [PATCH 08/12] introduce RetryAdapter --- dspy/adapters/__init__.py | 4 + dspy/adapters/base.py | 51 +++------ dspy/adapters/chat_adapter.py | 20 ---- dspy/adapters/retry_adapter.py | 145 ++++++++++++++++++++++++++ dspy/adapters/utils.py | 4 +- dspy/dsp/utils/settings.py | 1 - dspy/predict/predict.py | 4 +- dspy/predict/react.py | 3 +- dspy/predict/refine.py | 3 +- dspy/teleprompt/bootstrap_finetune.py | 4 +- tests/adapters/test_base.py | 46 -------- tests/adapters/test_retry_adapter.py | 92 ++++++++++++++++ 12 files changed, 269 insertions(+), 108 deletions(-) create mode 100644 dspy/adapters/retry_adapter.py delete mode 100644 tests/adapters/test_base.py create mode 100644 tests/adapters/test_retry_adapter.py diff --git a/dspy/adapters/__init__.py b/dspy/adapters/__init__.py index c592ac6ba1..8866a02abb 100644 --- a/dspy/adapters/__init__.py +++ b/dspy/adapters/__init__.py @@ -1,12 +1,16 @@ from dspy.adapters.base import Adapter from dspy.adapters.chat_adapter import ChatAdapter from dspy.adapters.json_adapter import JSONAdapter +from dspy.adapters.retry_adapter import RetryAdapter from dspy.adapters.types import Image, History +DEFAULT_ADAPTER = RetryAdapter(main_adapter=ChatAdapter(), fallback_adapter=JSONAdapter()) + __all__ = [ "Adapter", "ChatAdapter", "JSONAdapter", + "RetryAdapter", "Image", "History", ] diff --git a/dspy/adapters/base.py b/dspy/adapters/base.py index 19b2ad773c..0f158a408f 100644 --- a/dspy/adapters/base.py +++ b/dspy/adapters/base.py @@ -1,12 +1,9 @@ from typing import TYPE_CHECKING, Any, Optional, Type -import json import logging from dspy.adapters.types import History from dspy.adapters.types.image import try_expand_image_tags from dspy.signatures.signature import Signature from dspy.utils.callback import BaseCallback, with_callbacks -from dspy.adapters.utils import create_signature_for_retry -from dspy.dsp.utils.settings import settings if TYPE_CHECKING: from dspy.clients.lm import LM @@ -32,37 +29,23 @@ def __call__( demos: list[dict[str, Any]], inputs: dict[str, Any], ) -> list[dict[str, Any]]: - retry_count = 0 - max_retries = max(settings.adapter_retry_count, 0) - outputs = None - while (True): - messages = self.format(signature=signature, demos=demos, inputs=inputs) - outputs = lm(messages=messages, **lm_kwargs) - values = [] - try: - for output in outputs: - output_logprobs = None - - if isinstance(output, dict): - output, output_logprobs = output["text"], output["logprobs"] - - value = self.parse(signature, output) - - if output_logprobs is not None: - value["logprobs"] = output_logprobs - - values.append(value) - - return values - except ValueError: - if retry_count >= max_retries: - raise - - logger.debug("A ValueError occurred while parsing the LM output. Retrying with a new signature.", exc_info=True) - if retry_count == 0: - signature = create_signature_for_retry(signature) - inputs["previous_response"] = json.dumps(outputs) - retry_count += 1 + messages = self.format(signature=signature, demos=demos, inputs=inputs) + outputs = lm(messages=messages, **lm_kwargs) + values = [] + for output in outputs: + output_logprobs = None + + if isinstance(output, dict): + output, output_logprobs = output["text"], output["logprobs"] + + value = self.parse(signature, output) + + if output_logprobs is not None: + value["logprobs"] = output_logprobs + + values.append(value) + + return values def format( self, diff --git a/dspy/adapters/chat_adapter.py b/dspy/adapters/chat_adapter.py index f2228b756f..f4c2a040f2 100644 --- a/dspy/adapters/chat_adapter.py +++ b/dspy/adapters/chat_adapter.py @@ -29,26 +29,6 @@ class ChatAdapter(Adapter): def __init__(self, callbacks: Optional[list[BaseCallback]] = None): super().__init__(callbacks) - def __call__( - self, - lm: LM, - lm_kwargs: dict[str, Any], - signature: Type[Signature], - demos: list[dict[str, Any]], - inputs: dict[str, Any], - ) -> list[dict[str, Any]]: - try: - return super().__call__(lm, lm_kwargs, signature, demos, inputs) - except Exception as e: - # fallback to JSONAdapter - from dspy.adapters.json_adapter import JSONAdapter - - if isinstance(e, ContextWindowExceededError) or isinstance(self, JSONAdapter): - # On context window exceeded error or already using JSONAdapter, we don't want to retry with a different - # adapter. - raise e - return JSONAdapter()(lm, lm_kwargs, signature, demos, inputs) - def format_field_description(self, signature: Type[Signature]) -> str: return ( f"Your input fields are:\n{get_field_description_string(signature.input_fields)}\n" diff --git a/dspy/adapters/retry_adapter.py b/dspy/adapters/retry_adapter.py new file mode 100644 index 0000000000..fe282b705f --- /dev/null +++ b/dspy/adapters/retry_adapter.py @@ -0,0 +1,145 @@ + +from typing import TYPE_CHECKING, Any, Optional, Type +import logging + +from dspy.adapters.base import Adapter +from dspy.signatures.signature import Signature +from dspy.adapters.utils import create_signature_for_retry + +if TYPE_CHECKING: + from dspy.clients.lm import LM + +logger = logging.getLogger(__name__) + +class RetryAdapter(Adapter): + """ + RetryAdapter is an adapter that retries the execution of another adapter for + a specified number of times if it fails to parse completion outputs. + """ + + def __init__(self, main_adapter: Adapter, fallback_adapter: Optional[Adapter] = None, max_retries: int = 3): + """ + Initializes the RetryAdapter. + + Args: + main_adapter (Adapter): The main adapter to use. + fallback_adapter (Optional[Adapter]): The fallback adapter to use if the main adapter fails. + max_retries (int): The maximum number of retries. Defaults to 3. + """ + self.main_adapter = main_adapter + self.fallback_adapter = fallback_adapter + self.max_retries = max_retries + + def __call__( + self, + lm: "LM", + lm_kwargs: dict[str, Any], + signature: Type[Signature], + demos: list[dict[str, Any]], + inputs: dict[str, Any], + ) -> list[dict[str, Any]]: + """ + Execute main_adapter and fallback_adapter in the following procedure: + 1. Call the main_adapter. + 2. If the main_adapter fails, call the fallback_adapter. + 3. If the fallback_adapter fails, retry the main_adapter including previous response for `max_retries` times. + + Args: + lm (LM): The dspy.LM to use. + lm_kwargs (dict[str, Any]): Additional arguments for the lm. + signature (Type[Signature]): The signature of the function. + demos (list[dict[str, Any]]): A list of demo examples. + inputs (dict[str, Any]): A list representating the user input. + + Returns: + A list of parsed completions. The size of the list is equal to `n` argument. Defaults to 1. + + Raises: + Exception: If fail to parse outputs after the maximum number of retries. + """ + outputs = [] + max_retries = max(self.max_retries, 0) + n_completion = lm_kwargs.get("n", 1) + + values, parse_failures = self._call_adapter( + self.main_adapter, + lm, + lm_kwargs, + signature, + demos, + inputs, + ) + outputs.extend(values) + + if len(outputs) == n_completion: + return outputs + + lm_kwargs["n"] = n_completion - len(outputs) + if self.fallback_adapter is not None: + outputs.extend(self._call_adapter( + self.fallback_adapter, + lm, + lm_kwargs, + signature, + demos, + inputs, + )[0]) + if len(outputs) == n_completion: + return outputs + + # Retry the main adapter with previous response for `max_retries` times + lm_kwargs["n"] = 1 + signature = create_signature_for_retry(signature) + if parse_failures: + inputs["previous_response"] = parse_failures[0][0] + inputs["error_message"] = str(parse_failures[0][1]) + for i in range(max_retries): + values, parse_failures = self._call_adapter( + self.main_adapter, + lm, + lm_kwargs, + signature, + demos, + inputs, + ) + outputs.extend(values) + if len(outputs) == n_completion: + return outputs + logger.warning(f"Retry {i+1}/{max_retries} for {self.main_adapter.__class__.__name__} failed with error: {parse_failures[0][1]}") + inputs["previous_response"] = parse_failures[0][0] + inputs["error_message"] = str(parse_failures[0][1]) + + # raise the last error + raise ValueError("Failed to parse LM outputs for maximum retries.") from parse_failures[0][1] + + def _call_adapter( + self, + adapter: Adapter, + lm: "LM", + lm_kwargs: dict[str, Any], + signature: Type[Signature], + demos: list[dict[str, Any]], + inputs: dict[str, Any], + ): + values = [] + parse_failures = [] + messages = adapter.format(signature=signature, demos=demos, inputs=inputs) + outputs = lm(messages=messages, **lm_kwargs) + for i, output in enumerate(outputs): + try: + output_logprobs = None + + if isinstance(output, dict): + output, output_logprobs = output["text"], output["logprobs"] + + value = adapter.parse(signature, output) + + if output_logprobs is not None: + value["logprobs"] = output_logprobs + + values.append(value) + except ValueError as e: + logger.warning(f"Failed to parse the {i+1}/{lm_kwargs.get('n', 1)} LM output with adapter {adapter.__class__.__name__}. Error: {e}") + parse_failures.append((outputs[i], e)) + + return values, parse_failures \ No newline at end of file diff --git a/dspy/adapters/utils.py b/dspy/adapters/utils.py index 01c3b9b4c7..7b0c608962 100644 --- a/dspy/adapters/utils.py +++ b/dspy/adapters/utils.py @@ -242,9 +242,11 @@ def _quoted_string_for_literal_type_annotation(s: str) -> str: return f"'{s}'" def create_signature_for_retry(signature: Type[Signature]): - # Add previous_response field signature = signature.append("previous_response", InputField( prefix="Previous Response", desc="Previous response with format errors. You should avoid the same type of error as the previous response.", + )).append("error_message", InputField( + prefix="Validation Error Message", + desc="Error message for the previous response.", )) return signature diff --git a/dspy/dsp/utils/settings.py b/dspy/dsp/utils/settings.py index 9b8966defc..fad3cc67f0 100644 --- a/dspy/dsp/utils/settings.py +++ b/dspy/dsp/utils/settings.py @@ -22,7 +22,6 @@ disable_history=False, track_usage=False, usage_tracker=None, - adapter_retry_count=0, # number of times to retry the adapter call ) # Global base configuration and owner tracking diff --git a/dspy/predict/predict.py b/dspy/predict/predict.py index 8813697aed..3a9aef16a2 100644 --- a/dspy/predict/predict.py +++ b/dspy/predict/predict.py @@ -3,7 +3,7 @@ from pydantic import BaseModel -from dspy.adapters.chat_adapter import ChatAdapter +from dspy.adapters import DEFAULT_ADAPTER from dspy.clients.base_lm import BaseLM from dspy.clients.lm import LM from dspy.dsp.utils import settings @@ -103,7 +103,7 @@ def forward(self, **kwargs): missing, ) - adapter = settings.adapter or ChatAdapter() + adapter = settings.adapter or DEFAULT_ADAPTER completions = adapter( lm, lm_kwargs=config, diff --git a/dspy/predict/react.py b/dspy/predict/react.py index 395a8f5ea8..19f9b57512 100644 --- a/dspy/predict/react.py +++ b/dspy/predict/react.py @@ -5,6 +5,7 @@ from pydantic import BaseModel import dspy +from dspy.adapters import DEFAULT_ADAPTER from dspy.primitives.program import Module from dspy.primitives.tool import Tool from dspy.signatures.signature import ensure_signature @@ -67,7 +68,7 @@ def __init__(self, signature, tools: list[Callable], max_iters=5): self.extract = dspy.ChainOfThought(fallback_signature) def _format_trajectory(self, trajectory: dict[str, Any]): - adapter = dspy.settings.adapter or dspy.ChatAdapter() + adapter = dspy.settings.adapter or DEFAULT_ADAPTER trajectory_signature = dspy.Signature(f"{', '.join(trajectory.keys())} -> x") return adapter.format_user_message_content(trajectory_signature, trajectory) diff --git a/dspy/predict/refine.py b/dspy/predict/refine.py index 9991750d63..b983159ac9 100644 --- a/dspy/predict/refine.py +++ b/dspy/predict/refine.py @@ -6,6 +6,7 @@ import dspy from dspy.adapters.utils import get_field_description_string +from dspy.adapters import DEFAULT_ADAPTER from dspy.predict.predict import Prediction from dspy.signatures import InputField, OutputField, Signature @@ -100,7 +101,7 @@ def forward(self, **kwargs): temps = list(dict.fromkeys(temps))[: self.N] best_pred, best_trace, best_reward = None, None, -float("inf") advice = None - adapter = dspy.settings.adapter or dspy.ChatAdapter() + adapter = dspy.settings.adapter or DEFAULT_ADAPTER for idx, t in enumerate(temps): lm_ = lm.copy(temperature=t) diff --git a/dspy/teleprompt/bootstrap_finetune.py b/dspy/teleprompt/bootstrap_finetune.py index 6e5bf569af..f6d0790aa1 100644 --- a/dspy/teleprompt/bootstrap_finetune.py +++ b/dspy/teleprompt/bootstrap_finetune.py @@ -4,7 +4,7 @@ import dspy from dspy.adapters.base import Adapter -from dspy.adapters.chat_adapter import ChatAdapter +from dspy.adapters import DEFAULT_ADAPTER from dspy.clients.lm import LM from dspy.clients.utils_finetune import infer_data_format from dspy.dsp.utils.settings import settings @@ -162,7 +162,7 @@ def _prepare_finetune_data(self, trace_data: List[Dict[str, Any]], lm: LM, pred_ logger.info(f"After filtering with the metric, {len(trace_data)} examples remain") data = [] - adapter = self.adapter[lm] or settings.adapter or ChatAdapter() + adapter = self.adapter[lm] or settings.adapter or DEFAULT_ADAPTER data_format = infer_data_format(adapter) for item in trace_data: for pred_ind, _ in enumerate(item["trace"]): diff --git a/tests/adapters/test_base.py b/tests/adapters/test_base.py deleted file mode 100644 index 0094f531fc..0000000000 --- a/tests/adapters/test_base.py +++ /dev/null @@ -1,46 +0,0 @@ -import pytest -from unittest.mock import patch -from pydantic_core import ValidationError -from pydantic import TypeAdapter -import dspy -from dspy.adapters.base import Adapter -from dspy.signatures.signature import Signature -from dspy.utils.dummies import DummyLM - - -class MockSignature(Signature): - question: str = dspy.InputField() - answer: int = dspy.OutputField() - - -@pytest.mark.parametrize( - "demos", - [ - [], - [dspy.Example({"question": "6 x 7", "answer": 42})], - ], -) -def test_adapter_retry_logic(demos): - adapter = Adapter() - lm = DummyLM([{"answer": "42"}]) - inputs = {"question": "6 x 7"} - - with dspy.context(adapter_retry_count=1, lm=lm): - # Mock parse to raise ValidationError on the first call and succeed on the second - with ( - patch.object(adapter, "format", return_value="formatted_input") as mock_format, - patch.object( - adapter, - "parse", - side_effect=lambda signature, completion: TypeAdapter(signature).validate_python(completion), - ) as mock_parse, - ): - with pytest.raises(ValidationError): - adapter(lm, {}, MockSignature, demos, inputs) - - assert mock_format.call_count == 2 - _, kwargs = mock_format.call_args - assert kwargs["inputs"]["previous_response"] == '["[[ ## answer ## ]]\\n42"]' - assert kwargs["inputs"]["question"] == "6 x 7" - assert mock_parse.call_count == 2 - assert lm.call_count == 2 diff --git a/tests/adapters/test_retry_adapter.py b/tests/adapters/test_retry_adapter.py new file mode 100644 index 0000000000..5fd661d126 --- /dev/null +++ b/tests/adapters/test_retry_adapter.py @@ -0,0 +1,92 @@ +import pytest +from unittest.mock import patch +import dspy +from dspy.adapters import RetryAdapter, ChatAdapter, JSONAdapter +from dspy.signatures.signature import Signature +from dspy.utils.dummies import DummyLM + + +class MockSignature(Signature): + question: str = dspy.InputField() + answer: int = dspy.OutputField() + + +@pytest.mark.parametrize( + "demos", + [ + [], + [dspy.Example({"question": "6 x 7", "answer": 42})], + ], +) +@pytest.mark.parametrize( + "max_retries", + [ + 0, + 3, + ], +) +@pytest.mark.parametrize( + "n", + [ + 1, + 3, + ], +) +def test_adapter_max_retry(demos, max_retries, n): + main_adapter = ChatAdapter() + fallback_adapter = JSONAdapter() + adapter = RetryAdapter(main_adapter=main_adapter, fallback_adapter=fallback_adapter, max_retries=max_retries) + lm = DummyLM([{"answer": "42"}] * (n * 2 + max_retries)) + inputs = {"question": "6 x 7"} + + with dspy.context(lm=lm): + with ( + patch.object( + main_adapter, + "parse", + side_effect=ValueError("error"), + ) as mock_main_parse, + patch.object( + main_adapter, + "format", + wraps=main_adapter.format, + ) as mock_main_format, + patch.object( + fallback_adapter, + "parse", + side_effect=ValueError("error"), + ) as mock_fallback_parse, + ): + with pytest.raises(ValueError, match="Failed to parse LM outputs for maximum retries"): + adapter(lm, {"n": n}, MockSignature, demos, inputs) + + assert mock_main_parse.call_count == n + max_retries + assert mock_fallback_parse.call_count == n + assert lm.call_count == max_retries + 2 + + assert mock_main_format.call_count == max_retries + 1 + _, kwargs = mock_main_format.call_args + assert kwargs["inputs"]["previous_response"] == "[[ ## answer ## ]]\n42" + assert kwargs["inputs"]["error_message"] == "error" + assert kwargs["inputs"]["question"] == "6 x 7" + + +def test_adapter_fallback(): + main_adapter = JSONAdapter() + fallback_adapter = ChatAdapter() + adapter = RetryAdapter(main_adapter=main_adapter, fallback_adapter=fallback_adapter, max_retries=1) + lm = DummyLM([{"answer": "42"}] * 3) + inputs = {"question": "6 x 7"} + + with dspy.context(lm=lm): + with ( + patch.object( + main_adapter, + "parse", + side_effect=ValueError("error"), + ) as mock_main_parse, + ): + result = adapter(lm, {}, MockSignature, [], inputs) + + assert result == [{"answer": 42}] + assert mock_main_parse.call_count == 1 From 27977c67023ec28101ec1f6e3442fc2d4d9c212d Mon Sep 17 00:00:00 2001 From: TomuHirata Date: Mon, 14 Apr 2025 19:29:54 +0900 Subject: [PATCH 09/12] lint --- dspy/adapters/chat_adapter.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/dspy/adapters/chat_adapter.py b/dspy/adapters/chat_adapter.py index f4c2a040f2..09453fe6b2 100644 --- a/dspy/adapters/chat_adapter.py +++ b/dspy/adapters/chat_adapter.py @@ -2,7 +2,6 @@ import textwrap from typing import Any, Dict, NamedTuple, Optional, Type -from litellm import ContextWindowExceededError from pydantic.fields import FieldInfo from dspy.adapters.base import Adapter @@ -13,7 +12,6 @@ parse_value, translate_field_type, ) -from dspy.clients.lm import LM from dspy.signatures.signature import Signature from dspy.utils.callback import BaseCallback From ae67f49bb595fb6efd93d86d1e7988c7327ab316 Mon Sep 17 00:00:00 2001 From: TomuHirata Date: Mon, 14 Apr 2025 19:39:03 +0900 Subject: [PATCH 10/12] fix ReAct test --- dspy/predict/react.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dspy/predict/react.py b/dspy/predict/react.py index 19f9b57512..395a8f5ea8 100644 --- a/dspy/predict/react.py +++ b/dspy/predict/react.py @@ -5,7 +5,6 @@ from pydantic import BaseModel import dspy -from dspy.adapters import DEFAULT_ADAPTER from dspy.primitives.program import Module from dspy.primitives.tool import Tool from dspy.signatures.signature import ensure_signature @@ -68,7 +67,7 @@ def __init__(self, signature, tools: list[Callable], max_iters=5): self.extract = dspy.ChainOfThought(fallback_signature) def _format_trajectory(self, trajectory: dict[str, Any]): - adapter = dspy.settings.adapter or DEFAULT_ADAPTER + adapter = dspy.settings.adapter or dspy.ChatAdapter() trajectory_signature = dspy.Signature(f"{', '.join(trajectory.keys())} -> x") return adapter.format_user_message_content(trajectory_signature, trajectory) From bb4ce1703250185d9d1bb8ef2a99210da9a4d1b6 Mon Sep 17 00:00:00 2001 From: TomuHirata Date: Tue, 15 Apr 2025 15:54:22 +0900 Subject: [PATCH 11/12] remove logger --- dspy/adapters/base.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/dspy/adapters/base.py b/dspy/adapters/base.py index 0f158a408f..259bea4d49 100644 --- a/dspy/adapters/base.py +++ b/dspy/adapters/base.py @@ -1,5 +1,4 @@ from typing import TYPE_CHECKING, Any, Optional, Type -import logging from dspy.adapters.types import History from dspy.adapters.types.image import try_expand_image_tags from dspy.signatures.signature import Signature @@ -8,8 +7,6 @@ if TYPE_CHECKING: from dspy.clients.lm import LM -logger = logging.getLogger(__name__) - class Adapter: def __init__(self, callbacks: Optional[list[BaseCallback]] = None): self.callbacks = callbacks or [] @@ -32,6 +29,7 @@ def __call__( messages = self.format(signature=signature, demos=demos, inputs=inputs) outputs = lm(messages=messages, **lm_kwargs) values = [] + for output in outputs: output_logprobs = None From 1e7c9bb61918ff39fcc0808a495ca76372a49c82 Mon Sep 17 00:00:00 2001 From: TomuHirata Date: Thu, 17 Apr 2025 02:08:19 +0900 Subject: [PATCH 12/12] address comment --- dspy/adapters/retry_adapter.py | 8 ++++---- dspy/adapters/utils.py | 2 -- tests/adapters/test_retry_adapter.py | 4 ++-- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/dspy/adapters/retry_adapter.py b/dspy/adapters/retry_adapter.py index fe282b705f..f77dca985a 100644 --- a/dspy/adapters/retry_adapter.py +++ b/dspy/adapters/retry_adapter.py @@ -17,18 +17,18 @@ class RetryAdapter(Adapter): a specified number of times if it fails to parse completion outputs. """ - def __init__(self, main_adapter: Adapter, fallback_adapter: Optional[Adapter] = None, max_retries: int = 3): + def __init__(self, main_adapter: Adapter, fallback_adapter: Optional[Adapter] = None, main_adapter_max_retries: int = 3): """ Initializes the RetryAdapter. Args: main_adapter (Adapter): The main adapter to use. fallback_adapter (Optional[Adapter]): The fallback adapter to use if the main adapter fails. - max_retries (int): The maximum number of retries. Defaults to 3. + main_adapter_max_retries (int): The maximum number of retries. Defaults to 3. """ self.main_adapter = main_adapter self.fallback_adapter = fallback_adapter - self.max_retries = max_retries + self.main_adapter_max_retries = main_adapter_max_retries def __call__( self, @@ -58,7 +58,7 @@ def __call__( Exception: If fail to parse outputs after the maximum number of retries. """ outputs = [] - max_retries = max(self.max_retries, 0) + max_retries = max(self.main_adapter_max_retries, 0) n_completion = lm_kwargs.get("n", 1) values, parse_failures = self._call_adapter( diff --git a/dspy/adapters/utils.py b/dspy/adapters/utils.py index 7b0c608962..2059fabef6 100644 --- a/dspy/adapters/utils.py +++ b/dspy/adapters/utils.py @@ -243,10 +243,8 @@ def _quoted_string_for_literal_type_annotation(s: str) -> str: def create_signature_for_retry(signature: Type[Signature]): signature = signature.append("previous_response", InputField( - prefix="Previous Response", desc="Previous response with format errors. You should avoid the same type of error as the previous response.", )).append("error_message", InputField( - prefix="Validation Error Message", desc="Error message for the previous response.", )) return signature diff --git a/tests/adapters/test_retry_adapter.py b/tests/adapters/test_retry_adapter.py index 5fd661d126..1f3cc37163 100644 --- a/tests/adapters/test_retry_adapter.py +++ b/tests/adapters/test_retry_adapter.py @@ -35,7 +35,7 @@ class MockSignature(Signature): def test_adapter_max_retry(demos, max_retries, n): main_adapter = ChatAdapter() fallback_adapter = JSONAdapter() - adapter = RetryAdapter(main_adapter=main_adapter, fallback_adapter=fallback_adapter, max_retries=max_retries) + adapter = RetryAdapter(main_adapter=main_adapter, fallback_adapter=fallback_adapter, main_adapter_max_retries=max_retries) lm = DummyLM([{"answer": "42"}] * (n * 2 + max_retries)) inputs = {"question": "6 x 7"} @@ -74,7 +74,7 @@ def test_adapter_max_retry(demos, max_retries, n): def test_adapter_fallback(): main_adapter = JSONAdapter() fallback_adapter = ChatAdapter() - adapter = RetryAdapter(main_adapter=main_adapter, fallback_adapter=fallback_adapter, max_retries=1) + adapter = RetryAdapter(main_adapter=main_adapter, fallback_adapter=fallback_adapter, main_adapter_max_retries=1) lm = DummyLM([{"answer": "42"}] * 3) inputs = {"question": "6 x 7"}