diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-deepseek/src/main/java/org/springframework/ai/model/deepseek/autoconfigure/DeepSeekChatAutoConfiguration.java b/auto-configurations/models/spring-ai-autoconfigure-model-deepseek/src/main/java/org/springframework/ai/model/deepseek/autoconfigure/DeepSeekChatAutoConfiguration.java index 23515e416f2..b4e4ab625da 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-deepseek/src/main/java/org/springframework/ai/model/deepseek/autoconfigure/DeepSeekChatAutoConfiguration.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-deepseek/src/main/java/org/springframework/ai/model/deepseek/autoconfigure/DeepSeekChatAutoConfiguration.java @@ -18,9 +18,11 @@ import io.micrometer.observation.ObservationRegistry; +import org.springframework.ai.chat.model.StreamFunctionCallingHelper; import org.springframework.ai.chat.observation.ChatModelObservationConvention; import org.springframework.ai.deepseek.DeepSeekChatModel; import org.springframework.ai.deepseek.api.DeepSeekApi; +import org.springframework.ai.deepseek.api.DeepSeekStreamFunctionCallingHelper; import org.springframework.ai.model.SimpleApiKey; import org.springframework.ai.model.SpringAIModelProperties; import org.springframework.ai.model.SpringAIModels; @@ -69,11 +71,14 @@ public DeepSeekChatModel deepSeekChatModel(DeepSeekConnectionProperties commonPr RetryTemplate retryTemplate, ResponseErrorHandler responseErrorHandler, ObjectProvider observationRegistry, ObjectProvider observationConvention, - ObjectProvider deepseekToolExecutionEligibilityPredicate) { + ObjectProvider deepseekToolExecutionEligibilityPredicate, + ObjectProvider> streamFunctionCallingHelper) { var deepSeekApi = deepSeekApi(chatProperties, commonProperties, restClientBuilderProvider.getIfAvailable(RestClient::builder), - webClientBuilderProvider.getIfAvailable(WebClient::builder), responseErrorHandler); + webClientBuilderProvider.getIfAvailable(WebClient::builder), + responseErrorHandler, + streamFunctionCallingHelper.getIfAvailable(DeepSeekStreamFunctionCallingHelper::new)); var chatModel = DeepSeekChatModel.builder() .deepSeekApi(deepSeekApi) @@ -92,7 +97,8 @@ public DeepSeekChatModel deepSeekChatModel(DeepSeekConnectionProperties commonPr private DeepSeekApi deepSeekApi(DeepSeekChatProperties chatProperties, DeepSeekConnectionProperties commonProperties, RestClient.Builder restClientBuilder, - WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) { + WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler, + StreamFunctionCallingHelper chunkMerger) { String resolvedBaseUrl = StringUtils.hasText(chatProperties.getBaseUrl()) ? chatProperties.getBaseUrl() : commonProperties.getBaseUrl(); @@ -110,6 +116,7 @@ private DeepSeekApi deepSeekApi(DeepSeekChatProperties chatProperties, .restClientBuilder(restClientBuilder) .webClientBuilder(webClientBuilder) .responseErrorHandler(responseErrorHandler) + .streamFunctionCallingHelper(chunkMerger) .build(); } diff --git a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/api/DeepSeekApi.java b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/api/DeepSeekApi.java index f565c2ba26e..7611d16f6cd 100644 --- a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/api/DeepSeekApi.java +++ b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/api/DeepSeekApi.java @@ -32,6 +32,7 @@ import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import org.springframework.ai.chat.model.StreamFunctionCallingHelper; import org.springframework.ai.model.ApiKey; import org.springframework.ai.model.ChatModelDescription; import org.springframework.ai.model.ModelOptionsUtils; @@ -67,7 +68,7 @@ public class DeepSeekApi { private final WebClient webClient; - private DeepSeekStreamFunctionCallingHelper chunkMerger = new DeepSeekStreamFunctionCallingHelper(); + private final StreamFunctionCallingHelper chunkMerger; /** * Create a new chat completion api. @@ -79,10 +80,11 @@ public class DeepSeekApi { * @param restClientBuilder RestClient builder. * @param webClientBuilder WebClient builder. * @param responseErrorHandler Response error handler. + * @param chunkMerger Chat completion chunk merger. */ public DeepSeekApi(String baseUrl, ApiKey apiKey, MultiValueMap headers, String completionsPath, String betaPrefixPath, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder, - ResponseErrorHandler responseErrorHandler) { + ResponseErrorHandler responseErrorHandler, StreamFunctionCallingHelper chunkMerger) { Assert.hasText(completionsPath, "Completions Path must not be null"); Assert.hasText(betaPrefixPath, "Beta feature path must not be null"); @@ -105,6 +107,8 @@ public DeepSeekApi(String baseUrl, ApiKey apiKey, MultiValueMap .baseUrl(baseUrl) .defaultHeaders(finalHeaders) .build(); // @formatter:on + + this.chunkMerger = chunkMerger; } /** @@ -923,6 +927,8 @@ public static class Builder { private ResponseErrorHandler responseErrorHandler = RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER; + private StreamFunctionCallingHelper chunkMerger = new DeepSeekStreamFunctionCallingHelper(); + public Builder baseUrl(String baseUrl) { Assert.hasText(baseUrl, "baseUrl cannot be null or empty"); this.baseUrl = baseUrl; @@ -977,10 +983,16 @@ public Builder responseErrorHandler(ResponseErrorHandler responseErrorHandler) { return this; } + public Builder streamFunctionCallingHelper(StreamFunctionCallingHelper chunkMerger){ + Assert.notNull(chunkMerger, "chunkMerger cannot be null"); + this.chunkMerger = chunkMerger; + return this; + } + public DeepSeekApi build() { Assert.notNull(this.apiKey, "apiKey must be set"); return new DeepSeekApi(this.baseUrl, this.apiKey, this.headers, this.completionsPath, this.betaPrefixPath, - this.restClientBuilder, this.webClientBuilder, this.responseErrorHandler); + this.restClientBuilder, this.webClientBuilder, this.responseErrorHandler, this.chunkMerger); } } diff --git a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/api/DeepSeekStreamFunctionCallingHelper.java b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/api/DeepSeekStreamFunctionCallingHelper.java index 68cbe2a4b93..f05050e3b1d 100644 --- a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/api/DeepSeekStreamFunctionCallingHelper.java +++ b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/api/DeepSeekStreamFunctionCallingHelper.java @@ -19,6 +19,7 @@ import java.util.ArrayList; import java.util.List; +import org.springframework.ai.chat.model.StreamFunctionCallingHelper; import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionChunk; import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionChunk.ChunkChoice; import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionFinishReason; @@ -34,8 +35,9 @@ * * @author Geng Rong */ -public class DeepSeekStreamFunctionCallingHelper { +public class DeepSeekStreamFunctionCallingHelper implements StreamFunctionCallingHelper { + @Override public ChatCompletionChunk merge(ChatCompletionChunk previous, ChatCompletionChunk current) { if (previous == null) { @@ -142,6 +144,7 @@ private ChatCompletionFunction merge(ChatCompletionFunction previous, ChatComple * @param chatCompletion the ChatCompletionChunk to check * @return true if the ChatCompletionChunk is a streaming tool function call. */ + @Override public boolean isStreamingToolFunctionCall(ChatCompletionChunk chatCompletion) { if (chatCompletion == null || CollectionUtils.isEmpty(chatCompletion.choices())) { @@ -160,6 +163,7 @@ public boolean isStreamingToolFunctionCall(ChatCompletionChunk chatCompletion) { * @return true if the ChatCompletionChunk is a streaming tool function call and it is * the last one. */ + @Override public boolean isStreamingToolFunctionCallFinish(ChatCompletionChunk chatCompletion) { if (chatCompletion == null || CollectionUtils.isEmpty(chatCompletion.choices())) { diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/model/StreamFunctionCallingHelper.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/model/StreamFunctionCallingHelper.java new file mode 100644 index 00000000000..1993505f057 --- /dev/null +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/model/StreamFunctionCallingHelper.java @@ -0,0 +1,38 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.model; + +/** + * Helper class to support Streaming function calling. It can merge the streamed + * ChatCompletionChunk in case of function calling message. + */ +public interface StreamFunctionCallingHelper { + T merge(T previous, T current); + + /** + * @param chatCompletion the ChatCompletionChunk to check + * @return true if the ChatCompletionChunk is a streaming tool function call. + */ + boolean isStreamingToolFunctionCall(T chatCompletion); + + /** + * @param chatCompletion the ChatCompletionChunk to check + * @return true if the ChatCompletionChunk is a streaming tool function call and it is + * the last one. + */ + boolean isStreamingToolFunctionCallFinish(T chatCompletion); +}