diff --git a/docs/docs/how-tos/node-retries.ipynb b/docs/docs/how-tos/node-retries.ipynb index b4bd89e1ae..600afa4b1e 100644 --- a/docs/docs/how-tos/node-retries.ipynb +++ b/docs/docs/how-tos/node-retries.ipynb @@ -31,7 +31,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -41,7 +41,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -78,20 +78,9 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "RetryPolicy(initial_interval=0.5, backoff_factor=2.0, max_interval=128.0, max_attempts=3, jitter=True, retry_on=)" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "from langgraph.pregel import RetryPolicy\n", "\n", @@ -131,7 +120,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ @@ -180,6 +169,75 @@ "\n", "graph = builder.compile()" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## State-Modifying Retry Policies\n", + "\n", + "Sometimes you might want to modify the state before retrying after an error occurs. LangGraph now supports this through a dictionary mapping of exceptions to state-modifying functions in the retry policy.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Exception occured: Error occurred in processing\n", + "Counter: 0\n", + "Exception occured: Error occurred in processing\n", + "Counter: 1\n", + "Exception occured: Error occurred in processing\n", + "Counter: 2\n", + "{'counter': 3, 'result': 'Success!'}\n" + ] + } + ], + "source": [ + "class CounterState(TypedDict):\n", + " counter: int\n", + " result: str\n", + "\n", + "\n", + "def increment_counter(state: CounterState, exception: Exception):\n", + " print(f\"Exception occured: {exception}\")\n", + " print(f\"Counter: {state['counter']}\")\n", + " state['counter'] += 1\n", + "\n", + "def processing_node(state: CounterState) -> CounterState:\n", + " if state['counter'] < 3:\n", + " raise ValueError(\"Error occurred in processing\")\n", + " state['result'] = \"Success!\"\n", + " return state\n", + "\n", + "workflow = StateGraph(CounterState)\n", + "\n", + "retry_policy = RetryPolicy(\n", + " max_attempts=5,\n", + " retry_on={ValueError: increment_counter}\n", + ")\n", + "\n", + "\n", + "workflow.add_node(\"process\", processing_node, retry=retry_policy)\n", + "workflow.add_edge(START, \"process\")\n", + "workflow.add_edge(\"process\", END)\n", + "\n", + "app = workflow.compile()\n", + "final_state = app.invoke({\"counter\": 0, \"result\": \"\"})\n", + "print(final_state)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this example, we define a retry policy that maps `ValueError` to the `increment_counter` function. When a `ValueError` occurs, `increment_counter` is called with the current state and the exception before the retry attempt. " + ] } ], "metadata": { diff --git a/libs/langgraph/langgraph/pregel/retry.py b/libs/langgraph/langgraph/pregel/retry.py index 6d0e43b547..00e9d8a098 100644 --- a/libs/langgraph/langgraph/pregel/retry.py +++ b/libs/langgraph/langgraph/pregel/retry.py @@ -77,6 +77,12 @@ def run_with_retry( elif callable(retry_policy.retry_on): if not retry_policy.retry_on(exc): # type: ignore[call-arg] raise + elif isinstance(retry_policy.retry_on, dict): + exception_handler = retry_policy.retry_on.get(type(exc), False) + if callable(exception_handler): + exception_handler(task.input, exc) + else: + raise else: raise TypeError( "retry_on must be an Exception class, a list or tuple of Exception classes, or a callable" @@ -165,6 +171,12 @@ async def arun_with_retry( elif callable(retry_policy.retry_on): if not retry_policy.retry_on(exc): # type: ignore[call-arg] raise + elif isinstance(retry_policy.retry_on, dict): + exception_handler = retry_policy.retry_on.get(type(exc), False) + if callable(exception_handler): + exception_handler(task.input, exc) + else: + raise else: raise TypeError( "retry_on must be an Exception class, a list or tuple of Exception classes, or a callable" diff --git a/libs/langgraph/langgraph/types.py b/libs/langgraph/langgraph/types.py index 0a3e9f2058..cbb87e1858 100644 --- a/libs/langgraph/langgraph/types.py +++ b/libs/langgraph/langgraph/types.py @@ -114,9 +114,12 @@ class RetryPolicy(NamedTuple): jitter: bool = True """Whether to add random jitter to the interval between retries.""" retry_on: Union[ - Type[Exception], Sequence[Type[Exception]], Callable[[Exception], bool] + Type[Exception], + Sequence[Type[Exception]], + Callable[[Exception], bool], + dict[Type[Exception], Callable[[Any, Exception], None]], ] = default_retry_on - """List of exception classes that should trigger a retry, or a callable that returns True for exceptions that should trigger a retry.""" + """List of exception classes that should trigger a retry, or a callable that returns True for exceptions that should trigger a retry, or a dictionary mapping exception classes to callables that can modify state before retrying. The state-modifying callables should accept (state, exception) as parameters.""" class CachePolicy(NamedTuple):