Skip to content

Allow overwrite max_iter in ReAct #8096

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions dspy/predict/react.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ def _format_trajectory(self, trajectory: dict[str, Any]):

def forward(self, **input_args):
trajectory = {}
for idx in range(self.max_iters):
max_iters = input_args.pop("max_iters", self.max_iters)
for idx in range(max_iters):
pred = self._call_with_potential_trajectory_truncation(self.react, trajectory, **input_args)

trajectory[f"thought_{idx}"] = pred.next_thought
Expand Down Expand Up @@ -173,8 +174,4 @@ def truncate_trajectory(self, trajectory):
TOPIC 06: Idiomatically allowing tools that maintain state across iterations, but not across different `forward` calls.
* So the tool would be newly initialized at the start of each `forward` call, but maintain state across iterations.
* This is pretty useful for allowing the agent to keep notes or count certain things, etc.

TOPIC 07: Make max_iters a bit more expressive.
* Allow passing `max_iters` in forward to overwrite the default.
* Get rid of `last_iteration: bool` in the format function. It's not necessary now.
"""
40 changes: 36 additions & 4 deletions tests/predict/test_react.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
from dataclasses import dataclass

from pydantic import BaseModel

import dspy
from dspy.predict import react
from dspy.utils.dummies import DummyLM, dummy_rm
from dspy.utils.dummies import DummyLM
import litellm

# def test_example_no_tools():
Expand Down Expand Up @@ -276,3 +273,38 @@ def mock_react(**kwargs):
assert "thought_0" not in result.trajectory
assert "thought_2" in result.trajectory
assert result.output_text == "Final output"


def test_error_retry():
def foo(a, b):
raise Exception("tool error")

react = dspy.ReAct("a, b -> c:int", tools=[foo])
max_iters = 2
lm = DummyLM(
[
{"next_thought": "I need to add two numbers.", "next_tool_name": "foo", "next_tool_args": {"a": 1, "b": 2}},
{"next_thought": "I need to add two numbers.", "next_tool_name": "foo", "next_tool_args": {"a": 1, "b": 2}},
{"reasoning": "I added the numbers successfully", "c": 3},
]
)
dspy.settings.configure(lm=lm)

outputs = react(a=1, b=2, max_iters=max_iters)
expected_trajectory = {
"thought_0": "I need to add two numbers.",
"tool_name_0": "foo",
"tool_args_0": {
"a": 1,
"b": 2,
},
'observation_0': 'Failed to execute: tool error',
'thought_1': 'I need to add two numbers.',
'tool_name_1': 'foo',
"tool_args_1": {
"a": 1,
"b": 2,
},
'observation_1': 'Failed to execute: tool error',
}
assert outputs.trajectory == expected_trajectory
Loading