From fd8593db8dfc13a7136e3292e9fe2e43170fd0a1 Mon Sep 17 00:00:00 2001 From: Ilayaperumal Gopinathan Date: Mon, 21 Jul 2025 17:21:39 +0100 Subject: [PATCH] Tool filtering as ToolCallingChatOption - This PR provides a toolcalling option to filter the toolcallbacks - The toolcallback filter comes with the Predicate with any type of ToolCallback At filter time, via ToolCallback, both the ToolDefinition and ToolMetadata are available along with any custom metadata from ToolCallback implementations Signed-off-by: Ilayaperumal Gopinathan --- .../ai/mcp/AsyncMcpToolCallback.java | 2 +- .../ai/mcp/McpToolCallback.java | 28 +++++++++++++++++++ .../ai/mcp/SyncMcpToolCallback.java | 2 +- .../ai/anthropic/AnthropicChatModel.java | 11 ++++++-- .../ai/anthropic/AnthropicChatOptions.java | 22 ++++++++++++++- .../ai/anthropic/AnthropicChatModelIT.java | 18 +++++++++++- .../azure/openai/AzureOpenAiChatOptions.java | 20 +++++++++++++ .../bedrock/converse/BedrockChatOptions.java | 21 ++++++++++++++ .../converse/BedrockProxyChatModel.java | 5 ++-- .../ai/deepseek/DeepSeekChatOptions.java | 23 +++++++++++++-- .../google/genai/GoogleGenAiChatOptions.java | 19 +++++++++++++ .../ai/minimax/MiniMaxChatOptions.java | 20 +++++++++++++ .../ai/mistralai/MistralAiChatOptions.java | 20 +++++++++++++ .../ai/ollama/api/OllamaOptions.java | 22 ++++++++++++++- .../ai/openai/OpenAiChatModel.java | 7 +++-- .../ai/openai/OpenAiChatOptions.java | 20 +++++++++++++ .../gemini/VertexAiGeminiChatOptions.java | 16 +++++++++++ .../ai/zhipuai/ZhiPuAiChatOptions.java | 20 +++++++++++++ .../tool/DefaultToolCallingChatOptions.java | 23 +++++++++++++++ .../ai/model/tool/ToolCallingChatOptions.java | 25 +++++++++++++++++ 20 files changed, 329 insertions(+), 15 deletions(-) create mode 100644 mcp/common/src/main/java/org/springframework/ai/mcp/McpToolCallback.java diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java b/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java index 5f8da416109..eda63ec20d8 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java @@ -59,7 +59,7 @@ * @see McpAsyncClient * @see Tool */ -public class AsyncMcpToolCallback implements ToolCallback { +public class AsyncMcpToolCallback implements McpToolCallback { private final McpAsyncClient asyncMcpClient; diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/McpToolCallback.java b/mcp/common/src/main/java/org/springframework/ai/mcp/McpToolCallback.java new file mode 100644 index 00000000000..732b170bac2 --- /dev/null +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/McpToolCallback.java @@ -0,0 +1,28 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.mcp; + +import org.springframework.ai.tool.ToolCallback; + +/** + * Custom type for MCP specific tool. + */ +public interface McpToolCallback extends ToolCallback { + + // TODO: Add MCP metadata + +} diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java b/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java index fc61d801df1..0e65d43fc35 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java @@ -61,7 +61,7 @@ * @see McpSyncClient * @see Tool */ -public class SyncMcpToolCallback implements ToolCallback { +public class SyncMcpToolCallback implements McpToolCallback { private static final Logger logger = LoggerFactory.getLogger(SyncMcpToolCallback.class); diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java index 2ded856a05f..1b1eb1d1596 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java @@ -455,17 +455,22 @@ Prompt buildRequestPrompt(Prompt prompt) { this.defaultOptions.getInternalToolExecutionEnabled())); requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(), this.defaultOptions.getToolNames())); - requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(), - this.defaultOptions.getToolCallbacks())); + // Make sure to set the tool context before setting toolcallbacks so that the + // context can be used to filter the toolcallbacks. requestOptions.setToolContext(ToolCallingChatOptions.mergeToolContext(runtimeOptions.getToolContext(), this.defaultOptions.getToolContext())); + requestOptions.setToolCallbacks(runtimeOptions.getFilteredToolCallbacks(ToolCallingChatOptions + .mergeToolCallbacks(runtimeOptions.getToolCallbacks(), this.defaultOptions.getToolCallbacks()))); } else { requestOptions.setHttpHeaders(this.defaultOptions.getHttpHeaders()); requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled()); requestOptions.setToolNames(this.defaultOptions.getToolNames()); - requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks()); + // Make sure to set the tool context before setting toolcallbacks so that the + // context can be used to filter the toolcallbacks. requestOptions.setToolContext(this.defaultOptions.getToolContext()); + requestOptions + .setToolCallbacks(this.defaultOptions.getFilteredToolCallbacks(this.defaultOptions.getToolCallbacks())); } ToolCallingChatOptions.validateToolCallbacks(requestOptions.getToolCallbacks()); diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java index dbfbee561c8..4bf3f831af2 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java @@ -24,6 +24,8 @@ import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.function.BiPredicate; +import java.util.function.Predicate; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; @@ -82,13 +84,15 @@ public class AnthropicChatOptions implements ToolCallingChatOptions { @JsonIgnore private Map toolContext = new HashMap<>(); - /** * Optional HTTP headers to be added to the chat completion request. */ @JsonIgnore private Map httpHeaders = new HashMap<>(); + @JsonIgnore + private Predicate toolCallbackFilter; + // @formatter:on public static Builder builder() { @@ -110,6 +114,7 @@ public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions) .toolNames(fromOptions.getToolNames() != null ? new HashSet<>(fromOptions.getToolNames()) : null) .internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()) .toolContext(fromOptions.getToolContext() != null ? new HashMap<>(fromOptions.getToolContext()) : null) + .toolCallbackFilter(fromOptions.getToolCallbackFilter()) .httpHeaders(fromOptions.getHttpHeaders() != null ? new HashMap<>(fromOptions.getHttpHeaders()) : null) .build(); } @@ -259,6 +264,16 @@ public void setHttpHeaders(Map httpHeaders) { this.httpHeaders = httpHeaders; } + @JsonIgnore + public Predicate getToolCallbackFilter() { + return this.toolCallbackFilter; + } + + @Override + public void setToolCallbackFilter(Predicate toolCallbackFilter) { + this.toolCallbackFilter = toolCallbackFilter; + } + @Override @SuppressWarnings("unchecked") public AnthropicChatOptions copy() { @@ -384,6 +399,11 @@ public Builder toolContext(Map toolContext) { return this; } + public Builder toolCallbackFilter(Predicate toolCallbackFilter) { + this.options.setToolCallbackFilter(toolCallbackFilter); + return this; + } + public Builder httpHeaders(Map httpHeaders) { this.options.setHttpHeaders(httpHeaders); return this; diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java index 6570d5ee6a6..e35d8b680a9 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java @@ -21,6 +21,8 @@ import java.util.Arrays; import java.util.List; import java.util.Map; +import java.util.function.BiPredicate; +import java.util.function.Predicate; import java.util.stream.Collectors; import org.junit.jupiter.api.Test; @@ -42,6 +44,7 @@ import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.model.StreamingChatModel; +import org.springframework.ai.chat.model.ToolContext; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.chat.prompt.SystemPromptTemplate; @@ -50,6 +53,7 @@ import org.springframework.ai.converter.ListOutputConverter; import org.springframework.ai.converter.MapOutputConverter; import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; @@ -284,11 +288,23 @@ void functionCallTest() { var promptOptions = AnthropicChatOptions.builder() .model(AnthropicApi.ChatModel.CLAUDE_3_OPUS.getName()) + .toolContext(Map.of("tool_prefix", "get")) .toolCallbacks(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description( "Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.") .inputType(MockWeatherService.Request.class) - .build()) + .build(), + FunctionToolCallback.builder("retrieveWeather", new MockWeatherService()) + .description( + "Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.") + .inputType(MockWeatherService.Request.class) + .build()) + .toolCallbackFilter(new Predicate() { + @Override + public boolean test(ToolCallback toolCallback) { + return (toolCallback.getToolDefinition().name().startsWith("get")) ? true : false; + } + }) .build(); ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java index da442b4ad4d..6a83f0970d7 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java @@ -24,6 +24,7 @@ import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.function.Predicate; import com.azure.ai.openai.models.AzureChatEnhancementConfiguration; import com.azure.ai.openai.models.ChatCompletionStreamOptions; @@ -257,6 +258,9 @@ public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecut this.internalToolExecutionEnabled = internalToolExecutionEnabled; } + @JsonIgnore + private Predicate toolCallbackFilter; + public static Builder builder() { return new Builder(); } @@ -288,6 +292,7 @@ public static AzureOpenAiChatOptions fromOptions(AzureOpenAiChatOptions fromOpti .toolCallbacks( fromOptions.getToolCallbacks() != null ? new ArrayList<>(fromOptions.getToolCallbacks()) : null) .toolNames(fromOptions.getToolNames() != null ? new HashSet<>(fromOptions.getToolNames()) : null) + .toolCallbackFilter(fromOptions.getToolCallbackFilter()) .build(); } @@ -474,6 +479,16 @@ public void setToolContext(Map toolContext) { this.toolContext = toolContext; } + @JsonIgnore + public Predicate getToolCallbackFilter() { + return this.toolCallbackFilter; + } + + @Override + public void setToolCallbackFilter(Predicate toolCallbackFilter) { + this.toolCallbackFilter = toolCallbackFilter; + } + public ChatCompletionStreamOptions getStreamOptions() { return this.streamOptions; } @@ -664,6 +679,11 @@ public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecut return this; } + public Builder toolCallbackFilter(Predicate toolCallbackFilter) { + this.options.setToolCallbackFilter(toolCallbackFilter); + return this; + } + public AzureOpenAiChatOptions build() { return this.options; } diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockChatOptions.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockChatOptions.java index 776cba66d58..09578a59fd6 100644 --- a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockChatOptions.java +++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockChatOptions.java @@ -24,10 +24,12 @@ import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.function.Predicate; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; + import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.tool.ToolCallback; import org.springframework.lang.Nullable; @@ -77,6 +79,9 @@ public class BedrockChatOptions implements ToolCallingChatOptions { @JsonIgnore private Boolean internalToolExecutionEnabled; + @JsonIgnore + private Predicate toolCallbackFilter; + public static Builder builder() { return new Builder(); } @@ -96,6 +101,7 @@ public static BedrockChatOptions fromOptions(BedrockChatOptions fromOptions) { .toolNames(new HashSet<>(fromOptions.getToolNames())) .toolContext(new HashMap<>(fromOptions.getToolContext())) .internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()) + .toolCallbackFilter(fromOptions.getToolCallbackFilter()) .build(); } @@ -224,6 +230,16 @@ public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecut this.internalToolExecutionEnabled = internalToolExecutionEnabled; } + @JsonIgnore + public Predicate getToolCallbackFilter() { + return this.toolCallbackFilter; + } + + @Override + public void setToolCallbackFilter(Predicate toolCallbackFilter) { + this.toolCallbackFilter = toolCallbackFilter; + } + @Override @SuppressWarnings("unchecked") public BedrockChatOptions copy() { @@ -337,6 +353,11 @@ public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecut return this; } + public Builder toolCallbackFilter(Predicate toolCallbackFilter) { + this.options.setToolCallbackFilter(toolCallbackFilter); + return this; + } + public BedrockChatOptions build() { return this.options; } diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java index 071e77a78cb..b81fca818f0 100644 --- a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java +++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java @@ -303,8 +303,8 @@ Prompt buildRequestPrompt(Prompt prompt) { : this.defaultOptions.getTemperature()) .topP(runtimeOptions.getTopP() != null ? runtimeOptions.getTopP() : this.defaultOptions.getTopP()) - .toolCallbacks(runtimeOptions.getToolCallbacks() != null ? runtimeOptions.getToolCallbacks() - : this.defaultOptions.getToolCallbacks()) + .toolCallbacks(runtimeOptions.getFilteredToolCallbacks(runtimeOptions.getToolCallbacks() != null + ? runtimeOptions.getToolCallbacks() : this.defaultOptions.getToolCallbacks())) .toolNames(runtimeOptions.getToolNames() != null ? runtimeOptions.getToolNames() : this.defaultOptions.getToolNames()) .toolContext(runtimeOptions.getToolContext() != null ? runtimeOptions.getToolContext() @@ -312,6 +312,7 @@ Prompt buildRequestPrompt(Prompt prompt) { .internalToolExecutionEnabled(runtimeOptions.getInternalToolExecutionEnabled() != null ? runtimeOptions.getInternalToolExecutionEnabled() : this.defaultOptions.getInternalToolExecutionEnabled()) + .toolCallbackFilter(runtimeOptions.getToolCallbackFilter()) .build(); } diff --git a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatOptions.java b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatOptions.java index b9c7a3d4962..e7c02347dc5 100644 --- a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatOptions.java +++ b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatOptions.java @@ -24,6 +24,7 @@ import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.function.Predicate; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; @@ -143,7 +144,10 @@ public class DeepSeekChatOptions implements ToolCallingChatOptions { private Set toolNames = new HashSet<>(); @JsonIgnore - private Map toolContext = new HashMap<>();; + private Map toolContext = new HashMap<>(); + + @JsonIgnore + private Predicate toolCallbackFilter; public static Builder builder() { return new Builder(); @@ -246,7 +250,6 @@ public void setToolChoice(Object toolChoice) { this.toolChoice = toolChoice; } - @Override @JsonIgnore public List getToolCallbacks() { @@ -322,6 +325,16 @@ public void setToolContext(Map toolContext) { this.toolContext = toolContext; } + @JsonIgnore + public Predicate getToolCallbackFilter() { + return this.toolCallbackFilter; + } + + @Override + public void setToolCallbackFilter(Predicate toolCallbackFilter) { + this.toolCallbackFilter = toolCallbackFilter; + } + @Override public DeepSeekChatOptions copy() { return DeepSeekChatOptions.fromOptions(this); @@ -379,6 +392,7 @@ public static DeepSeekChatOptions fromOptions(DeepSeekChatOptions fromOptions) { .toolNames(fromOptions.getToolNames() != null ? new HashSet<>(fromOptions.getToolNames()) : null) .internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()) .toolContext(fromOptions.getToolContext() != null ? new HashMap<>(fromOptions.getToolContext()) : null) + .toolCallbackFilter(fromOptions.getToolCallbackFilter()) .build(); } @@ -497,6 +511,11 @@ public Builder toolContext(Map toolContext) { return this; } + public Builder toolCallbackFilter(Predicate toolCallbackFilter) { + this.options.setToolCallbackFilter(toolCallbackFilter); + return this; + } + public DeepSeekChatOptions build() { return this.options; } diff --git a/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatOptions.java b/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatOptions.java index 2ee9e4fa029..1ad98d216da 100644 --- a/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatOptions.java +++ b/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatOptions.java @@ -24,6 +24,7 @@ import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.function.Predicate; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; @@ -138,6 +139,9 @@ public class GoogleGenAiChatOptions implements ToolCallingChatOptions { @JsonIgnore private List safetySettings = new ArrayList<>(); + + @JsonIgnore + private Predicate toolCallbackFilter; // @formatter:on public static Builder builder() { @@ -327,6 +331,16 @@ public void setToolContext(Map toolContext) { this.toolContext = toolContext; } + @JsonIgnore + public Predicate getToolCallbackFilter() { + return this.toolCallbackFilter; + } + + @Override + public void setToolCallbackFilter(Predicate toolCallbackFilter) { + this.toolCallbackFilter = toolCallbackFilter; + } + @Override public boolean equals(Object o) { if (this == o) { @@ -489,6 +503,11 @@ public Builder toolContext(Map toolContext) { return this; } + public Builder toolCallbackFilter(Predicate toolCallbackFilter) { + this.options.setToolCallbackFilter(toolCallbackFilter); + return this; + } + public GoogleGenAiChatOptions build() { return this.options; } diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java index a8f1e62e77e..0b14fb600ab 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java @@ -25,6 +25,7 @@ import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.function.Predicate; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; @@ -156,6 +157,9 @@ public class MiniMaxChatOptions implements ToolCallingChatOptions { @JsonIgnore private Boolean internalToolExecutionEnabled; + @JsonIgnore + private Predicate toolCallbackFilter; + // @formatter:on public static Builder builder() { @@ -180,6 +184,7 @@ public static MiniMaxChatOptions fromOptions(MiniMaxChatOptions fromOptions) { .toolNames(fromOptions.getToolNames()) .internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()) .toolContext(fromOptions.getToolContext()) + .toolCallbackFilter(fromOptions.getToolCallbackFilter()) .build(); } @@ -362,6 +367,16 @@ public void setToolContext(Map toolContext) { this.toolContext = toolContext; } + @JsonIgnore + public Predicate getToolCallbackFilter() { + return this.toolCallbackFilter; + } + + @Override + public void setToolCallbackFilter(Predicate toolCallbackFilter) { + this.toolCallbackFilter = toolCallbackFilter; + } + @Override public int hashCode() { return Objects.hash(model, frequencyPenalty, maxTokens, n, presencePenalty, responseFormat, seed, stop, @@ -508,6 +523,11 @@ public Builder toolContext(Map toolContext) { return this; } + public Builder toolCallbackFilter(Predicate toolCallbackFilter) { + this.options.setToolCallbackFilter(toolCallbackFilter); + return this; + } + public MiniMaxChatOptions build() { return this.options; } diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java index 801c35f2118..8993bf6db62 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java @@ -24,6 +24,7 @@ import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.function.Predicate; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; @@ -159,6 +160,9 @@ public class MistralAiChatOptions implements ToolCallingChatOptions { @JsonIgnore private Map toolContext = new HashMap<>(); + @JsonIgnore + private Predicate toolCallbackFilter; + public static Builder builder() { return new Builder(); } @@ -182,6 +186,7 @@ public static MistralAiChatOptions fromOptions(MistralAiChatOptions fromOptions) .toolNames(fromOptions.getToolNames() != null ? new HashSet<>(fromOptions.getToolNames()) : null) .internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()) .toolContext(fromOptions.getToolContext() != null ? new HashMap<>(fromOptions.getToolContext()) : null) + .toolCallbackFilter(fromOptions.getToolCallbackFilter()) .build(); } @@ -366,6 +371,16 @@ public void setToolContext(Map toolContext) { this.toolContext = toolContext; } + @JsonIgnore + public Predicate getToolCallbackFilter() { + return this.toolCallbackFilter; + } + + @Override + public void setToolCallbackFilter(Predicate toolCallbackFilter) { + this.toolCallbackFilter = toolCallbackFilter; + } + @Override @SuppressWarnings("unchecked") public MistralAiChatOptions copy() { @@ -517,6 +532,11 @@ public Builder toolContext(Map toolContext) { return this; } + public Builder toolCallbackFilter(Predicate toolCallbackFilter) { + this.options.setToolCallbackFilter(toolCallbackFilter); + return this; + } + public MistralAiChatOptions build() { return this.options; } diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java index a71be1ce2b2..4e631fd6ade 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java @@ -24,6 +24,7 @@ import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.function.Predicate; import java.util.stream.Collectors; import com.fasterxml.jackson.annotation.JsonIgnore; @@ -344,6 +345,9 @@ public class OllamaOptions implements ToolCallingChatOptions, EmbeddingOptions { @JsonIgnore private Map toolContext = new HashMap<>(); + @JsonIgnore + private Predicate toolCallbackFilter; + public static Builder builder() { return new Builder(); } @@ -398,7 +402,8 @@ public static OllamaOptions fromOptions(OllamaOptions fromOptions) { .toolNames(fromOptions.getToolNames()) .internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()) .toolCallbacks(fromOptions.getToolCallbacks()) - .toolContext(fromOptions.getToolContext()).build(); + .toolContext(fromOptions.getToolContext()) + .toolCallbackFilter(fromOptions.getToolCallbackFilter()).build(); } // ------------------- @@ -764,6 +769,16 @@ public void setToolContext(Map toolContext) { this.toolContext = toolContext; } + @JsonIgnore + public Predicate getToolCallbackFilter() { + return this.toolCallbackFilter; + } + + @Override + public void setToolCallbackFilter(Predicate toolCallbackFilter) { + this.toolCallbackFilter = toolCallbackFilter; + } + /** * Convert the {@link OllamaOptions} object to a {@link Map} of key/value pairs. * @return The {@link Map} of key/value pairs. @@ -1039,6 +1054,11 @@ public Builder toolContext(Map toolContext) { return this; } + public Builder toolCallbackFilter(Predicate toolCallbackFilter) { + this.options.setToolCallbackFilter(toolCallbackFilter); + return this; + } + public OllamaOptions build() { return this.options; } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java index 2ad584fa82f..7026c8d05d4 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java @@ -529,10 +529,10 @@ Prompt buildRequestPrompt(Prompt prompt) { this.defaultOptions.getInternalToolExecutionEnabled())); requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(), this.defaultOptions.getToolNames())); - requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(), - this.defaultOptions.getToolCallbacks())); requestOptions.setToolContext(ToolCallingChatOptions.mergeToolContext(runtimeOptions.getToolContext(), this.defaultOptions.getToolContext())); + requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(), + this.defaultOptions.getToolCallbacks())); } else { requestOptions.setHttpHeaders(this.defaultOptions.getHttpHeaders()); @@ -542,7 +542,8 @@ Prompt buildRequestPrompt(Prompt prompt) { requestOptions.setToolContext(this.defaultOptions.getToolContext()); } - ToolCallingChatOptions.validateToolCallbacks(requestOptions.getToolCallbacks()); + ToolCallingChatOptions + .validateToolCallbacks(requestOptions.getFilteredToolCallbacks(requestOptions.getToolCallbacks())); return new Prompt(prompt.getInstructions(), requestOptions); } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java index afbbd803ec6..b115c25b335 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java @@ -24,6 +24,7 @@ import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.function.Predicate; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; @@ -228,6 +229,9 @@ public class OpenAiChatOptions implements ToolCallingChatOptions { @JsonIgnore private Map toolContext = new HashMap<>(); + @JsonIgnore + private Predicate toolCallbackFilter; + // @formatter:on public static Builder builder() { @@ -268,6 +272,7 @@ public static OpenAiChatOptions fromOptions(OpenAiChatOptions fromOptions) { .metadata(fromOptions.getMetadata()) .reasoningEffort(fromOptions.getReasoningEffort()) .webSearchOptions(fromOptions.getWebSearchOptions()) + .toolCallbackFilter(fromOptions.getToolCallbackFilter()) .build(); } @@ -564,6 +569,16 @@ public void setWebSearchOptions(WebSearchOptions webSearchOptions) { this.webSearchOptions = webSearchOptions; } + @JsonIgnore + public Predicate getToolCallbackFilter() { + return this.toolCallbackFilter; + } + + @Override + public void setToolCallbackFilter(Predicate toolCallbackFilter) { + this.toolCallbackFilter = toolCallbackFilter; + } + @Override public OpenAiChatOptions copy() { return OpenAiChatOptions.fromOptions(this); @@ -802,6 +817,11 @@ public Builder webSearchOptions(WebSearchOptions webSearchOptions) { return this; } + public Builder toolCallbackFilter(Predicate toolCallbackFilter) { + this.options.setToolCallbackFilter(toolCallbackFilter); + return this; + } + public OpenAiChatOptions build() { return this.options; } diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java index 9c7788c82a3..ca551f6daf4 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java @@ -24,6 +24,8 @@ import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.function.BiPredicate; +import java.util.function.Predicate; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; @@ -151,6 +153,9 @@ public class VertexAiGeminiChatOptions implements ToolCallingChatOptions { @JsonIgnore private List safetySettings = new ArrayList<>(); + + @JsonIgnore + private Predicate toolCallbackFilter; // @formatter:on public static Builder builder() { @@ -178,6 +183,7 @@ public static VertexAiGeminiChatOptions fromOptions(VertexAiGeminiChatOptions fr options.setToolContext(fromOptions.getToolContext()); options.setLogprobs(fromOptions.getLogprobs()); options.setResponseLogprobs(fromOptions.getResponseLogprobs()); + options.setToolCallbackFilter(fromOptions.getToolCallbackFilter()); return options; } @@ -358,6 +364,16 @@ public boolean getResponseLogprobs() { return responseLogprobs; } + @JsonIgnore + public Predicate getToolCallbackFilter() { + return this.toolCallbackFilter; + } + + @Override + public void setToolCallbackFilter(Predicate toolCallbackFilter) { + this.toolCallbackFilter = toolCallbackFilter; + } + @Override public boolean equals(Object o) { if (this == o) { diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java index c31320defe1..2a74d41e56c 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java @@ -23,6 +23,7 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.function.Predicate; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; @@ -126,6 +127,9 @@ public class ZhiPuAiChatOptions implements ToolCallingChatOptions { private Map toolContext = new HashMap<>(); // @formatter:on + @JsonIgnore + private Predicate toolCallbackFilter; + public static Builder builder() { return new Builder(); } @@ -146,6 +150,7 @@ public static ZhiPuAiChatOptions fromOptions(ZhiPuAiChatOptions fromOptions) { .toolNames(fromOptions.getToolNames()) .internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()) .toolContext(fromOptions.getToolContext()) + .toolCallbackFilter(fromOptions.getToolCallbackFilter()) .build(); } @@ -314,6 +319,16 @@ public void setToolContext(Map toolContext) { this.toolContext = toolContext; } + @JsonIgnore + public Predicate getToolCallbackFilter() { + return this.toolCallbackFilter; + } + + @Override + public void setToolCallbackFilter(Predicate toolCallbackFilter) { + this.toolCallbackFilter = toolCallbackFilter; + } + @Override public int hashCode() { final int prime = 31; @@ -610,6 +625,11 @@ public Builder toolContext(Map toolContext) { return this; } + public Builder toolCallbackFilter(Predicate toolCallbackFilter) { + this.options.setToolCallbackFilter(toolCallbackFilter); + return this; + } + public ZhiPuAiChatOptions build() { return this.options; } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptions.java b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptions.java index 870db6931b9..da295eda950 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptions.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptions.java @@ -23,6 +23,8 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.function.BiPredicate; +import java.util.function.Predicate; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.tool.ToolCallback; @@ -70,6 +72,9 @@ public class DefaultToolCallingChatOptions implements ToolCallingChatOptions { @Nullable private Double topP; + @Nullable + private Predicate toolCallbackFilter; + @Override public List getToolCallbacks() { return List.copyOf(this.toolCallbacks); @@ -198,6 +203,16 @@ public void setTopP(@Nullable Double topP) { this.topP = topP; } + @Override + @Nullable + public Predicate getToolCallbackFilter() { + return this.toolCallbackFilter; + } + + public void setToolCallbackFilter(@Nullable Predicate toolCallbackFilter) { + this.toolCallbackFilter = toolCallbackFilter; + } + @Override @SuppressWarnings("unchecked") public T copy() { @@ -214,6 +229,7 @@ public T copy() { options.setTemperature(getTemperature()); options.setTopK(getTopK()); options.setTopP(getTopP()); + options.setToolCallbackFilter(getToolCallbackFilter()); return (T) options; } @@ -325,6 +341,13 @@ public ToolCallingChatOptions.Builder topP(@Nullable Double topP) { return this; } + @Override + public ToolCallingChatOptions.Builder toolCallbackFilter( + @Nullable Predicate toolCallbackFilter) { + this.options.setToolCallbackFilter(toolCallbackFilter); + return this; + } + @Override public ToolCallingChatOptions build() { return this.options; diff --git a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java index f06e71aa869..09463719379 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java @@ -22,8 +22,11 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.function.BiPredicate; +import java.util.function.Predicate; import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ToolContext; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.support.ToolUtils; @@ -88,6 +91,26 @@ public interface ToolCallingChatOptions extends ChatOptions { */ void setToolContext(Map toolContext); + void setToolCallbackFilter(Predicate toolCallbackFilter); + + Predicate getToolCallbackFilter(); + + default List getFilteredToolCallbacks(List toolCallbacks) { + Predicate filter = getToolCallbackFilter(); + if (filter == null) { + return this.getToolCallbacks(); + } + else { + return applyFilter(toolCallbacks, filter); + } + } + + private List applyFilter(List toolCallbacks, + Predicate filter) { + + return toolCallbacks.stream().filter(toolCallback -> filter.test((T) toolCallback)).toList(); + } + /** * A builder to create a new {@link ToolCallingChatOptions} instance. */ @@ -193,6 +216,8 @@ interface Builder extends ChatOptions.Builder { */ Builder toolContext(String key, Object value); + Builder toolCallbackFilter(@Nullable Predicate toolCallbackFilter); + // ChatOptions.Builder methods @Override