Skip to content

Commit ea148eb

Browse files
google-genai-botcopybara-github
authored andcommitted
refactor: Modify the Mcp retry logic to be completely async
PiperOrigin-RevId: 805881432
1 parent 6f6e5a2 commit ea148eb

File tree

2 files changed

+141
-60
lines changed

2 files changed

+141
-60
lines changed

core/src/main/java/com/google/adk/tools/mcp/McpToolset.java

Lines changed: 58 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
package com.google.adk.tools.mcp;
1818

19-
import static com.google.common.collect.ImmutableList.toImmutableList;
19+
import static java.util.concurrent.TimeUnit.MILLISECONDS;
2020

2121
import com.fasterxml.jackson.databind.ObjectMapper;
2222
import com.google.adk.JsonBaseModel;
@@ -198,73 +198,72 @@ public McpToolset(StreamableHttpServerParameters connectionParams) {
198198

199199
@Override
200200
public Flowable<BaseTool> getTools(ReadonlyContext readonlyContext) {
201-
return Flowable.fromCallable(
201+
return Flowable.defer(
202202
() -> {
203-
for (int i = 0; i < MAX_RETRIES; i++) {
204-
try {
205-
if (this.mcpSession == null) {
206-
logger.info("MCP session is null or closed, initializing (attempt {}).", i + 1);
207-
this.mcpSession = this.mcpSessionManager.createSession();
208-
}
209-
210-
ListToolsResult toolsResponse = this.mcpSession.listTools();
211-
return toolsResponse.tools().stream()
203+
if (this.mcpSession == null) {
204+
logger.info("MCP session is null, initializing.");
205+
this.mcpSession = this.mcpSessionManager.createSession();
206+
}
207+
208+
// Retrieve tools from the MCP session, wrap them in McpTool, filter them, and return
209+
// as a Flowable.
210+
ListToolsResult toolsResponse = this.mcpSession.listTools();
211+
return Flowable.fromStream(
212+
toolsResponse.tools().stream()
212213
.map(
213214
tool ->
214215
new McpTool(
215216
tool, this.mcpSession, this.mcpSessionManager, this.objectMapper))
216217
.filter(
217218
tool ->
218219
isToolSelected(
219-
tool, toolFilter, Optional.ofNullable(readonlyContext)))
220-
.collect(toImmutableList());
221-
} catch (IllegalArgumentException e) {
222-
// This could happen if parameters for tool loading are somehow invalid.
223-
// This is likely a fatal error and should not be retried.
224-
logger.error("Invalid argument encountered during tool loading.", e);
225-
throw new McpToolsetException.McpToolLoadingException(
226-
"Invalid argument encountered during tool loading.", e);
227-
} catch (RuntimeException e) { // Catch any other unexpected runtime exceptions
228-
logger.error("Unexpected error during tool loading, retry attempt " + (i + 1), e);
229-
if (i < MAX_RETRIES - 1) {
230-
// For other general exceptions, we might still want to retry if they are
231-
// potentially transient, or if we don't have more specific handling. But it's
232-
// better to be specific. For now, we'll treat them as potentially retryable but
233-
// log
234-
// them at a higher level.
235-
try {
236-
logger.info(
237-
"Reinitializing MCP session before next retry for unexpected error.");
238-
this.mcpSession = this.mcpSessionManager.createSession();
239-
Thread.sleep(RETRY_DELAY_MILLIS);
240-
} catch (InterruptedException ie) {
241-
Thread.currentThread().interrupt();
242-
logger.error(
243-
"Interrupted during retry delay for loadTools (unexpected error).", ie);
244-
throw new McpToolsetException.McpToolLoadingException(
245-
"Interrupted during retry delay (unexpected error)", ie);
246-
} catch (RuntimeException reinitE) {
247-
logger.error(
248-
"Failed to reinitialize session during retry (unexpected error).",
249-
reinitE);
250-
throw new McpToolsetException.McpInitializationException(
251-
"Failed to reinitialize session during tool loading retry (unexpected"
252-
+ " error).",
253-
reinitE);
254-
}
255-
} else {
256-
logger.error(
257-
"Failed to load tools after multiple retries due to unexpected error.", e);
258-
throw new McpToolsetException.McpToolLoadingException(
259-
"Failed to load tools after multiple retries due to unexpected error.", e);
260-
}
261-
}
262-
}
263-
// This line should ideally not be reached if retries are handled correctly or an
264-
// exception is always thrown.
265-
throw new IllegalStateException("Unexpected state in getTools retry loop");
220+
tool, toolFilter, Optional.ofNullable(readonlyContext))));
266221
})
267-
.flatMapIterable(tools -> tools);
222+
.retryWhen(
223+
errorObservable ->
224+
errorObservable.zipWith(
225+
Flowable.range(1, MAX_RETRIES),
226+
(error, retryCount) -> {
227+
if (error instanceof IllegalArgumentException) {
228+
// This could happen if parameters for tool loading are somehow invalid.
229+
// This is likely a fatal error and should not be retried.
230+
logger.error("Invalid argument encountered during tool loading.", error);
231+
throw new McpToolsetException.McpToolLoadingException(
232+
"Invalid argument encountered during tool loading.", error);
233+
} else if (error instanceof RuntimeException) {
234+
// Catch any other unexpected runtime exceptions
235+
logger.error(
236+
"Unexpected error during tool loading, retry attempt " + retryCount,
237+
error);
238+
logger.info(
239+
"Reinitializing MCP session before next retry for unexpected error.");
240+
this.mcpSession = null;
241+
242+
if (retryCount < MAX_RETRIES) {
243+
// For other general exceptions, we might still want to retry if they are
244+
// potentially transient, or if we don't have more specific handling. But
245+
// it's better to be specific. For now, we'll treat them as potentially
246+
// retryable but log them at a higher level.
247+
248+
// Delay before retrying
249+
return Flowable.timer(RETRY_DELAY_MILLIS, MILLISECONDS);
250+
} else {
251+
logger.error(
252+
"Failed to load tools after multiple retries due to unexpected"
253+
+ " error.",
254+
error);
255+
throw new McpToolsetException.McpToolLoadingException(
256+
"Failed to load tools after multiple retries due to unexpected"
257+
+ " error.",
258+
error);
259+
}
260+
}
261+
// This line should ideally not be reached if retries are handled correctly or
262+
// an exception is always thrown.
263+
// If an unhandled error type occurs, propagate it.
264+
return Flowable.error(error);
265+
}))
266+
.map(tools -> tools);
268267
}
269268

270269
@Override

core/src/test/java/com/google/adk/tools/mcp/McpToolsetTest.java

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,38 @@
1717
package com.google.adk.tools.mcp;
1818

1919
import static com.google.common.truth.Truth.assertThat;
20+
import static java.util.concurrent.TimeUnit.SECONDS;
2021
import static org.junit.Assert.assertThrows;
22+
import static org.mockito.Mockito.times;
23+
import static org.mockito.Mockito.verify;
24+
import static org.mockito.Mockito.when;
2125

26+
import com.google.adk.JsonBaseModel;
2227
import com.google.adk.agents.ConfigAgentUtils.ConfigurationException;
28+
import com.google.adk.agents.ReadonlyContext;
2329
import com.google.adk.tools.BaseTool;
2430
import com.google.adk.tools.mcp.McpToolset.McpToolsetConfig;
2531
import com.google.common.collect.ImmutableList;
2632
import com.google.common.collect.ImmutableMap;
33+
import io.modelcontextprotocol.client.McpSyncClient;
34+
import io.modelcontextprotocol.spec.McpSchema;
35+
import java.util.List;
36+
import java.util.Optional;
37+
import org.junit.Rule;
2738
import org.junit.Test;
2839
import org.junit.runner.RunWith;
2940
import org.junit.runners.JUnit4;
41+
import org.mockito.Mock;
42+
import org.mockito.junit.MockitoJUnit;
43+
import org.mockito.junit.MockitoRule;
3044

3145
@RunWith(JUnit4.class)
3246
public class McpToolsetTest {
47+
@Rule public final MockitoRule mocks = MockitoJUnit.rule();
48+
49+
@Mock private McpSessionManager mockMcpSessionManager;
50+
@Mock private McpSyncClient mockMcpSyncClient;
51+
@Mock private ReadonlyContext mockReadonlyContext;
3352

3453
@Test
3554
public void testMcpToolsetConfig_withStdioServerParams_parsesCorrectly() {
@@ -135,7 +154,6 @@ public void testFromConfig_validStdioParams_createsToolset() throws Configuratio
135154
McpToolset toolset = McpToolset.fromConfig(config, configPath);
136155

137156
assertThat(toolset).isNotNull();
138-
// The toolset should be created successfully with stdio parameters
139157
}
140158

141159
@Test
@@ -243,4 +261,68 @@ public void testFromConfig_emptyToolFilter_createsToolset() throws Configuration
243261
assertThat(toolset).isNotNull();
244262
// The toolset should be created successfully with empty tool filter
245263
}
264+
265+
@Test
266+
public void getTools_withToolFilter_returnsFilteredTools() {
267+
ImmutableList<String> toolFilter = ImmutableList.of("tool1", "tool3");
268+
McpSchema.Tool mockTool1 =
269+
McpSchema.Tool.builder().name("tool1").description("desc1").inputSchema("{}").build();
270+
McpSchema.Tool mockTool2 =
271+
McpSchema.Tool.builder().name("tool2").description("desc2").inputSchema("{}").build();
272+
McpSchema.Tool mockTool3 =
273+
McpSchema.Tool.builder().name("tool3").description("desc3").inputSchema("{}").build();
274+
McpSchema.ListToolsResult mockResult =
275+
new McpSchema.ListToolsResult(ImmutableList.of(mockTool1, mockTool2, mockTool3), null);
276+
277+
when(mockMcpSessionManager.createSession()).thenReturn(mockMcpSyncClient);
278+
when(mockMcpSyncClient.listTools()).thenReturn(mockResult);
279+
280+
McpToolset toolset =
281+
new McpToolset(mockMcpSessionManager, JsonBaseModel.getMapper(), Optional.of(toolFilter));
282+
283+
List<BaseTool> tools = toolset.getTools(mockReadonlyContext).toList().blockingGet();
284+
285+
assertThat(tools.stream().map(BaseTool::name).collect(ImmutableList.toImmutableList()))
286+
.containsExactly("tool1", "tool3")
287+
.inOrder();
288+
verify(mockMcpSessionManager).createSession();
289+
verify(mockMcpSyncClient).listTools();
290+
}
291+
292+
@Test
293+
public void getTools_retriesAndFailsAfterMaxRetries() {
294+
when(mockMcpSessionManager.createSession()).thenReturn(mockMcpSyncClient);
295+
when(mockMcpSyncClient.listTools()).thenThrow(new RuntimeException("Test Exception"));
296+
297+
McpToolset toolset =
298+
new McpToolset(mockMcpSessionManager, JsonBaseModel.getMapper(), Optional.empty());
299+
300+
toolset
301+
.getTools(mockReadonlyContext)
302+
.test()
303+
.awaitDone(5, SECONDS)
304+
.assertError(McpToolsetException.McpToolLoadingException.class);
305+
306+
verify(mockMcpSessionManager, times(3)).createSession();
307+
verify(mockMcpSyncClient, times(3)).listTools();
308+
}
309+
310+
@Test
311+
public void getTools_succeedsOnLastRetryAttempt() {
312+
McpSchema.ListToolsResult mockResult = new McpSchema.ListToolsResult(ImmutableList.of(), null);
313+
when(mockMcpSessionManager.createSession()).thenReturn(mockMcpSyncClient);
314+
when(mockMcpSyncClient.listTools())
315+
.thenThrow(new RuntimeException("Attempt 1 failed"))
316+
.thenThrow(new RuntimeException("Attempt 2 failed"))
317+
.thenReturn(mockResult);
318+
319+
McpToolset toolset =
320+
new McpToolset(mockMcpSessionManager, JsonBaseModel.getMapper(), Optional.empty());
321+
322+
List<BaseTool> tools = toolset.getTools(mockReadonlyContext).toList().blockingGet();
323+
324+
assertThat(tools).isEmpty();
325+
verify(mockMcpSessionManager, times(3)).createSession();
326+
verify(mockMcpSyncClient, times(3)).listTools();
327+
}
246328
}

0 commit comments

Comments
 (0)