-
Notifications
You must be signed in to change notification settings - Fork 2k
Parallel Tool Execution #4255
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Parallel Tool Execution #4255
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,6 +21,9 @@ | |
| import java.util.List; | ||
| import java.util.Map; | ||
| import java.util.Optional; | ||
| import java.util.Queue; | ||
| import java.util.concurrent.CompletableFuture; | ||
| import java.util.concurrent.ConcurrentLinkedDeque; | ||
|
|
||
| import io.micrometer.observation.ObservationRegistry; | ||
| import org.slf4j.Logger; | ||
|
|
@@ -44,6 +47,10 @@ | |
| import org.springframework.ai.tool.observation.ToolCallingObservationDocumentation; | ||
| import org.springframework.ai.tool.resolution.DelegatingToolCallbackResolver; | ||
| import org.springframework.ai.tool.resolution.ToolCallbackResolver; | ||
| import org.springframework.core.task.TaskExecutor; | ||
| import org.springframework.core.task.support.ContextPropagatingTaskDecorator; | ||
| import org.springframework.lang.Nullable; | ||
| import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; | ||
| import org.springframework.util.Assert; | ||
| import org.springframework.util.CollectionUtils; | ||
|
|
||
|
|
@@ -71,6 +78,8 @@ public final class DefaultToolCallingManager implements ToolCallingManager { | |
| private static final ToolExecutionExceptionProcessor DEFAULT_TOOL_EXECUTION_EXCEPTION_PROCESSOR | ||
| = DefaultToolExecutionExceptionProcessor.builder().build(); | ||
|
|
||
| private static final TaskExecutor DEFAULT_TASK_EXECUTOR = buildDefaultTaskExecutor(); | ||
|
|
||
| // @formatter:on | ||
|
|
||
| private final ObservationRegistry observationRegistry; | ||
|
|
@@ -79,17 +88,20 @@ public final class DefaultToolCallingManager implements ToolCallingManager { | |
|
|
||
| private final ToolExecutionExceptionProcessor toolExecutionExceptionProcessor; | ||
|
|
||
| private final TaskExecutor taskExecutor; | ||
|
|
||
| private ToolCallingObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION; | ||
|
|
||
| public DefaultToolCallingManager(ObservationRegistry observationRegistry, ToolCallbackResolver toolCallbackResolver, | ||
| ToolExecutionExceptionProcessor toolExecutionExceptionProcessor) { | ||
| ToolExecutionExceptionProcessor toolExecutionExceptionProcessor, @Nullable TaskExecutor taskExecutor) { | ||
| Assert.notNull(observationRegistry, "observationRegistry cannot be null"); | ||
| Assert.notNull(toolCallbackResolver, "toolCallbackResolver cannot be null"); | ||
| Assert.notNull(toolExecutionExceptionProcessor, "toolCallExceptionConverter cannot be null"); | ||
|
|
||
| this.observationRegistry = observationRegistry; | ||
| this.toolCallbackResolver = toolCallbackResolver; | ||
| this.toolExecutionExceptionProcessor = toolExecutionExceptionProcessor; | ||
| this.taskExecutor = taskExecutor != null ? taskExecutor : buildDefaultTaskExecutor(); | ||
| } | ||
|
|
||
| @Override | ||
|
|
@@ -173,64 +185,59 @@ private static List<Message> buildConversationHistoryBeforeToolExecution(Prompt | |
| */ | ||
| private InternalToolExecutionResult executeToolCall(Prompt prompt, AssistantMessage assistantMessage, | ||
| ToolContext toolContext) { | ||
| List<ToolCallback> toolCallbacks = List.of(); | ||
| if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) { | ||
| toolCallbacks = toolCallingChatOptions.getToolCallbacks(); | ||
| } | ||
|
|
||
| List<ToolResponseMessage.ToolResponse> toolResponses = new ArrayList<>(); | ||
|
|
||
| Boolean returnDirect = null; | ||
|
|
||
| for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) { | ||
|
|
||
| logger.debug("Executing tool call: {}", toolCall.name()); | ||
|
|
||
| String toolName = toolCall.name(); | ||
| String toolInputArguments = toolCall.arguments(); | ||
|
|
||
| ToolCallback toolCallback = toolCallbacks.stream() | ||
| .filter(tool -> toolName.equals(tool.getToolDefinition().name())) | ||
| .findFirst() | ||
| .orElseGet(() -> this.toolCallbackResolver.resolve(toolName)); | ||
|
|
||
| if (toolCallback == null) { | ||
| throw new IllegalStateException("No ToolCallback found for tool name: " + toolName); | ||
| } | ||
| final List<ToolCallback> toolCallbacks = (prompt | ||
| .getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) | ||
| ? toolCallingChatOptions.getToolCallbacks() : List.of(); | ||
|
|
||
| if (returnDirect == null) { | ||
| returnDirect = toolCallback.getToolMetadata().returnDirect(); | ||
| } | ||
| else { | ||
| returnDirect = returnDirect && toolCallback.getToolMetadata().returnDirect(); | ||
| } | ||
|
|
||
| ToolCallingObservationContext observationContext = ToolCallingObservationContext.builder() | ||
| .toolDefinition(toolCallback.getToolDefinition()) | ||
| .toolMetadata(toolCallback.getToolMetadata()) | ||
| .toolCallArguments(toolInputArguments) | ||
| .build(); | ||
|
|
||
| String toolCallResult = ToolCallingObservationDocumentation.TOOL_CALL | ||
| .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, | ||
| this.observationRegistry) | ||
| .observe(() -> { | ||
| String toolResult; | ||
| try { | ||
| toolResult = toolCallback.call(toolInputArguments, toolContext); | ||
| } | ||
| catch (ToolExecutionException ex) { | ||
| toolResult = this.toolExecutionExceptionProcessor.process(ex); | ||
| } | ||
| observationContext.setToolCallResult(toolResult); | ||
| return toolResult; | ||
| }); | ||
|
|
||
| toolResponses.add(new ToolResponseMessage.ToolResponse(toolCall.id(), toolName, | ||
| toolCallResult != null ? toolCallResult : "")); | ||
| } | ||
|
|
||
| return new InternalToolExecutionResult(new ToolResponseMessage(toolResponses, Map.of()), returnDirect); | ||
| final Queue<Boolean> toolsReturnDirect = new ConcurrentLinkedDeque<>(); | ||
| List<ToolResponseMessage.ToolResponse> toolResponses = assistantMessage.getToolCalls() | ||
| .stream() | ||
| .map(toolCall -> CompletableFuture.supplyAsync(() -> { | ||
| logger.debug("Executing tool call: {}", toolCall.name()); | ||
|
|
||
| String toolName = toolCall.name(); | ||
| String toolInputArguments = toolCall.arguments(); | ||
|
|
||
| ToolCallback toolCallback = toolCallbacks.stream() | ||
| .filter(tool -> toolName.equals(tool.getToolDefinition().name())) | ||
| .findFirst() | ||
| .orElseGet(() -> this.toolCallbackResolver.resolve(toolName)); | ||
|
|
||
| if (toolCallback == null) { | ||
| throw new IllegalStateException("No ToolCallback found for tool name: " + toolName); | ||
| } | ||
|
|
||
| toolsReturnDirect.add(toolCallback.getToolMetadata().returnDirect()); | ||
|
|
||
| ToolCallingObservationContext observationContext = ToolCallingObservationContext.builder() | ||
| .toolDefinition(toolCallback.getToolDefinition()) | ||
| .toolMetadata(toolCallback.getToolMetadata()) | ||
| .toolCallArguments(toolInputArguments) | ||
| .build(); | ||
|
|
||
| String toolCallResult = ToolCallingObservationDocumentation.TOOL_CALL | ||
| .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, | ||
| this.observationRegistry) | ||
| .observe(() -> { | ||
| String toolResult; | ||
| try { | ||
| toolResult = toolCallback.call(toolInputArguments, toolContext); | ||
| } | ||
| catch (ToolExecutionException ex) { | ||
| toolResult = this.toolExecutionExceptionProcessor.process(ex); | ||
| } | ||
| observationContext.setToolCallResult(toolResult); | ||
| return toolResult; | ||
| }); | ||
|
|
||
| return new ToolResponseMessage.ToolResponse(toolCall.id(), toolName, | ||
| toolCallResult != null ? toolCallResult : ""); | ||
| }, this.taskExecutor)) | ||
| .map(CompletableFuture::join) | ||
|
||
| .toList(); | ||
|
|
||
| return new InternalToolExecutionResult(new ToolResponseMessage(toolResponses, Map.of()), | ||
| toolsReturnDirect.stream().allMatch(Boolean::booleanValue)); | ||
| } | ||
|
|
||
| private List<Message> buildConversationHistoryAfterToolExecution(List<Message> previousMessages, | ||
|
|
@@ -245,6 +252,16 @@ public void setObservationConvention(ToolCallingObservationConvention observatio | |
| this.observationConvention = observationConvention; | ||
| } | ||
|
|
||
| private static TaskExecutor buildDefaultTaskExecutor() { | ||
| ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor(); | ||
| taskExecutor.setThreadNamePrefix("ai-toll-calling-"); | ||
| taskExecutor.setCorePoolSize(4); | ||
| taskExecutor.setMaxPoolSize(16); | ||
| taskExecutor.setTaskDecorator(new ContextPropagatingTaskDecorator()); | ||
| taskExecutor.initialize(); | ||
| return taskExecutor; | ||
| } | ||
|
|
||
|
||
| public static Builder builder() { | ||
| return new Builder(); | ||
| } | ||
|
|
@@ -260,6 +277,8 @@ public final static class Builder { | |
|
|
||
| private ToolExecutionExceptionProcessor toolExecutionExceptionProcessor = DEFAULT_TOOL_EXECUTION_EXCEPTION_PROCESSOR; | ||
|
|
||
| private TaskExecutor taskExecutor = DEFAULT_TASK_EXECUTOR; | ||
|
|
||
| private Builder() { | ||
| } | ||
|
|
||
|
|
@@ -279,9 +298,14 @@ public Builder toolExecutionExceptionProcessor( | |
| return this; | ||
| } | ||
|
|
||
| public Builder taskExecutor(TaskExecutor taskExecutor) { | ||
| this.taskExecutor = taskExecutor; | ||
| return this; | ||
| } | ||
|
|
||
| public DefaultToolCallingManager build() { | ||
| return new DefaultToolCallingManager(this.observationRegistry, this.toolCallbackResolver, | ||
| this.toolExecutionExceptionProcessor); | ||
| this.toolExecutionExceptionProcessor, this.taskExecutor); | ||
| } | ||
|
|
||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Initializing the static thread pool by calling
buildDefaultTaskExecutor()here causes the method to be executed regardless of whether the user provides a thread pool in the constructor, which seems inappropriate.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just deployed a change tweaking this behavior. The system will now create a default instance only when no custom configuration is supplied.