Skip to content
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Queue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentLinkedDeque;

import io.micrometer.observation.ObservationRegistry;
import org.slf4j.Logger;
Expand All @@ -44,6 +47,10 @@
import org.springframework.ai.tool.observation.ToolCallingObservationDocumentation;
import org.springframework.ai.tool.resolution.DelegatingToolCallbackResolver;
import org.springframework.ai.tool.resolution.ToolCallbackResolver;
import org.springframework.core.task.TaskExecutor;
import org.springframework.core.task.support.ContextPropagatingTaskDecorator;
import org.springframework.lang.Nullable;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;

Expand Down Expand Up @@ -71,6 +78,8 @@ public final class DefaultToolCallingManager implements ToolCallingManager {
private static final ToolExecutionExceptionProcessor DEFAULT_TOOL_EXECUTION_EXCEPTION_PROCESSOR
= DefaultToolExecutionExceptionProcessor.builder().build();

private static final TaskExecutor DEFAULT_TASK_EXECUTOR = buildDefaultTaskExecutor();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Initializing the static thread pool by calling buildDefaultTaskExecutor() here causes the method to be executed regardless of whether the user provides a thread pool in the constructor, which seems inappropriate.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just deployed a change tweaking this behavior. The system will now create a default instance only when no custom configuration is supplied.


// @formatter:on

private final ObservationRegistry observationRegistry;
Expand All @@ -79,17 +88,20 @@ public final class DefaultToolCallingManager implements ToolCallingManager {

private final ToolExecutionExceptionProcessor toolExecutionExceptionProcessor;

private final TaskExecutor taskExecutor;

private ToolCallingObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;

public DefaultToolCallingManager(ObservationRegistry observationRegistry, ToolCallbackResolver toolCallbackResolver,
ToolExecutionExceptionProcessor toolExecutionExceptionProcessor) {
ToolExecutionExceptionProcessor toolExecutionExceptionProcessor, @Nullable TaskExecutor taskExecutor) {
Assert.notNull(observationRegistry, "observationRegistry cannot be null");
Assert.notNull(toolCallbackResolver, "toolCallbackResolver cannot be null");
Assert.notNull(toolExecutionExceptionProcessor, "toolCallExceptionConverter cannot be null");

this.observationRegistry = observationRegistry;
this.toolCallbackResolver = toolCallbackResolver;
this.toolExecutionExceptionProcessor = toolExecutionExceptionProcessor;
this.taskExecutor = taskExecutor != null ? taskExecutor : buildDefaultTaskExecutor();
}

@Override
Expand Down Expand Up @@ -173,64 +185,59 @@ private static List<Message> buildConversationHistoryBeforeToolExecution(Prompt
*/
private InternalToolExecutionResult executeToolCall(Prompt prompt, AssistantMessage assistantMessage,
ToolContext toolContext) {
List<ToolCallback> toolCallbacks = List.of();
if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) {
toolCallbacks = toolCallingChatOptions.getToolCallbacks();
}

List<ToolResponseMessage.ToolResponse> toolResponses = new ArrayList<>();

Boolean returnDirect = null;

for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) {

logger.debug("Executing tool call: {}", toolCall.name());

String toolName = toolCall.name();
String toolInputArguments = toolCall.arguments();

ToolCallback toolCallback = toolCallbacks.stream()
.filter(tool -> toolName.equals(tool.getToolDefinition().name()))
.findFirst()
.orElseGet(() -> this.toolCallbackResolver.resolve(toolName));

if (toolCallback == null) {
throw new IllegalStateException("No ToolCallback found for tool name: " + toolName);
}
final List<ToolCallback> toolCallbacks = (prompt
.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions)
? toolCallingChatOptions.getToolCallbacks() : List.of();

if (returnDirect == null) {
returnDirect = toolCallback.getToolMetadata().returnDirect();
}
else {
returnDirect = returnDirect && toolCallback.getToolMetadata().returnDirect();
}

ToolCallingObservationContext observationContext = ToolCallingObservationContext.builder()
.toolDefinition(toolCallback.getToolDefinition())
.toolMetadata(toolCallback.getToolMetadata())
.toolCallArguments(toolInputArguments)
.build();

String toolCallResult = ToolCallingObservationDocumentation.TOOL_CALL
.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
this.observationRegistry)
.observe(() -> {
String toolResult;
try {
toolResult = toolCallback.call(toolInputArguments, toolContext);
}
catch (ToolExecutionException ex) {
toolResult = this.toolExecutionExceptionProcessor.process(ex);
}
observationContext.setToolCallResult(toolResult);
return toolResult;
});

toolResponses.add(new ToolResponseMessage.ToolResponse(toolCall.id(), toolName,
toolCallResult != null ? toolCallResult : ""));
}

return new InternalToolExecutionResult(new ToolResponseMessage(toolResponses, Map.of()), returnDirect);
final Queue<Boolean> toolsReturnDirect = new ConcurrentLinkedDeque<>();
List<ToolResponseMessage.ToolResponse> toolResponses = assistantMessage.getToolCalls()
.stream()
.map(toolCall -> CompletableFuture.supplyAsync(() -> {
logger.debug("Executing tool call: {}", toolCall.name());

String toolName = toolCall.name();
String toolInputArguments = toolCall.arguments();

ToolCallback toolCallback = toolCallbacks.stream()
.filter(tool -> toolName.equals(tool.getToolDefinition().name()))
.findFirst()
.orElseGet(() -> this.toolCallbackResolver.resolve(toolName));

if (toolCallback == null) {
throw new IllegalStateException("No ToolCallback found for tool name: " + toolName);
}

toolsReturnDirect.add(toolCallback.getToolMetadata().returnDirect());

ToolCallingObservationContext observationContext = ToolCallingObservationContext.builder()
.toolDefinition(toolCallback.getToolDefinition())
.toolMetadata(toolCallback.getToolMetadata())
.toolCallArguments(toolInputArguments)
.build();

String toolCallResult = ToolCallingObservationDocumentation.TOOL_CALL
.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
this.observationRegistry)
.observe(() -> {
String toolResult;
try {
toolResult = toolCallback.call(toolInputArguments, toolContext);
}
catch (ToolExecutionException ex) {
toolResult = this.toolExecutionExceptionProcessor.process(ex);
}
observationContext.setToolCallResult(toolResult);
return toolResult;
});

return new ToolResponseMessage.ToolResponse(toolCall.id(), toolName,
toolCallResult != null ? toolCallResult : "");
}, this.taskExecutor))
.map(CompletableFuture::join)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this intended to wait for all tools to finish execution in startup order? I'm not sure if this approach makes sense—would using CompletableFuture.allOf() be more appropriate?

Copy link
Contributor Author

@rafaelrddc rafaelrddc Aug 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed it to use allOf, but we still need to use join to get the responses. However, it will return immediately.

.toList();

return new InternalToolExecutionResult(new ToolResponseMessage(toolResponses, Map.of()),
toolsReturnDirect.stream().allMatch(Boolean::booleanValue));
}

private List<Message> buildConversationHistoryAfterToolExecution(List<Message> previousMessages,
Expand All @@ -245,6 +252,16 @@ public void setObservationConvention(ToolCallingObservationConvention observatio
this.observationConvention = observationConvention;
}

private static TaskExecutor buildDefaultTaskExecutor() {
ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor();
taskExecutor.setThreadNamePrefix("ai-toll-calling-");
taskExecutor.setCorePoolSize(4);
taskExecutor.setMaxPoolSize(16);
taskExecutor.setTaskDecorator(new ContextPropagatingTaskDecorator());
taskExecutor.initialize();
return taskExecutor;
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think, no matter what, we shouldn't hardcode thread pool configurations as "magic values" here, even if it's just a default thread pool. Perhaps we should introduce corresponding configuration options and a dedicated configuration class to manage these parameters.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By the way, the spelling of the thread name also appears to be incorrect: "toll" should be "tool".

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can add it to the ToolCallingProperties class, which will allow us to customize the default values using properties. What do you think?

Exemple

@ConfigurationProperties(ToolCallingProperties.CONFIG_PREFIX)
public class ToolCallingProperties {

	public static final String CONFIG_PREFIX = "spring.ai.tools";

	private final Observations observations = new Observations();

	private final TaskExecutorProperties taskExecutor = new TaskExecutorProperties();
    
        public static class TaskExecutorProperties {
		/**
		 * Whether to enable custom task executor configuration for tool calls.
		 */
		private boolean enabled = false;

		/**
		 * Core number of threads in the pool.
		 */
		private int corePoolSize = Runtime.getRuntime().availableProcessors();

		/**
		 * Maximum number of threads in the pool.
		 */
		private int maxPoolSize = Runtime.getRuntime().availableProcessors() * 2;

		/**
		 * Capacity of the queue for holding tasks before they are executed.
		 */
		private int queueCapacity = 100;

		/**
		 * Prefix for thread names in the pool.
		 */
		private String threadNamePrefix = "tool-call-exec-";
	}
}

public static Builder builder() {
return new Builder();
}
Expand All @@ -260,6 +277,8 @@ public final static class Builder {

private ToolExecutionExceptionProcessor toolExecutionExceptionProcessor = DEFAULT_TOOL_EXECUTION_EXCEPTION_PROCESSOR;

private TaskExecutor taskExecutor = DEFAULT_TASK_EXECUTOR;

private Builder() {
}

Expand All @@ -279,9 +298,14 @@ public Builder toolExecutionExceptionProcessor(
return this;
}

public Builder taskExecutor(TaskExecutor taskExecutor) {
this.taskExecutor = taskExecutor;
return this;
}

public DefaultToolCallingManager build() {
return new DefaultToolCallingManager(this.observationRegistry, this.toolCallbackResolver,
this.toolExecutionExceptionProcessor);
this.toolExecutionExceptionProcessor, this.taskExecutor);
}

}
Expand Down