diff --git a/models/spring-ai-ollama/pom.xml b/models/spring-ai-ollama/pom.xml index 673064e4bb1..94563516c88 100644 --- a/models/spring-ai-ollama/pom.xml +++ b/models/spring-ai-ollama/pom.xml @@ -114,5 +114,11 @@ ollama test - + + org.springframework.ai + spring-ai-client-chat + 1.1.0-SNAPSHOT + compile + + diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java index c6bd6c2676e..e6f737af14b 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java @@ -17,10 +17,7 @@ package org.springframework.ai.ollama; import java.time.Duration; -import java.util.Base64; -import java.util.List; -import java.util.Map; -import java.util.Optional; +import java.util.*; import com.fasterxml.jackson.core.type.TypeReference; import io.micrometer.observation.Observation; @@ -28,6 +25,7 @@ import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.client.ChatClientRequest; import reactor.core.publisher.Flux; import reactor.core.scheduler.Schedulers; @@ -234,6 +232,9 @@ public ChatResponse call(Prompt prompt) { return this.internalCall(requestPrompt, null); } + + + private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) { OllamaApi.ChatRequest request = ollamaChatRequest(prompt, false); @@ -304,6 +305,22 @@ public Flux stream(Prompt prompt) { return this.internalStream(requestPrompt, null); } + public Flux stream(ChatClientRequest chatClientRequest) { + Prompt prompt = chatClientRequest.prompt(); + Prompt requestPrompt = buildRequestPrompt(prompt); + Flux responseFlux = this.internalStream(requestPrompt, null); + return responseFlux.map(chatResponse -> { + if 
(isStop(chatResponse)) { + return ChatResponse.builder() + .context(chatClientRequest.context()) + .from(chatResponse) + .build(); + } + return chatResponse; + }); + } + + private Flux internalStream(Prompt prompt, ChatResponse previousChatResponse) { return Flux.deferContextual(contextView -> { OllamaApi.ChatRequest request = ollamaChatRequest(prompt, true); @@ -337,7 +354,6 @@ private Flux internalStream(Prompt prompt, ChatResponse previousCh } var assistantMessage = new AssistantMessage(content, Map.of(), toolCalls); - ChatGenerationMetadata generationMetadata = ChatGenerationMetadata.NULL; if (chunk.promptEvalCount() != null && chunk.evalCount() != null) { generationMetadata = ChatGenerationMetadata.builder().finishReason(chunk.doneReason()).build(); diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelObservationIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelObservationIT.java index 0d8b6a0b714..a117fcf19a2 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelObservationIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelObservationIT.java @@ -16,6 +16,7 @@ package org.springframework.ai.ollama; +import java.util.HashMap; import java.util.List; import java.util.stream.Collectors; @@ -23,6 +24,7 @@ import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.client.ChatClientRequest; import reactor.core.publisher.Flux; import org.springframework.ai.chat.metadata.ChatResponseMetadata; @@ -90,6 +92,36 @@ void observationForChatOperation() { validate(responseMetadata); } + @Test + void observationForChatOperationUsingRequest() { + var options = OllamaOptions.builder() + .model(MODEL) + .frequencyPenalty(0.0) + .numPredict(2048) + .presencePenalty(0.0) + 
.stop(List.of("this-is-the-end")) + .temperature(0.7) + .topK(1) + .topP(1.0) + .build(); + HashMap ctx = new HashMap<>(); + Prompt prompt = new Prompt("Why does a raven look like a desk?", options); + ChatClientRequest chatClientRequest=new ChatClientRequest(prompt,ctx); + Flux chatResponseFlux= this.chatModel.stream(chatClientRequest); + List responses = chatResponseFlux.collectList().block(); + System.out.println(responses); + assertThat(responses).isNotEmpty(); + assertThat(responses).hasSizeGreaterThan(10); + + String aggregatedResponse = responses.subList(0, responses.size() - 1) + .stream() + .map(r -> r.getResult().getOutput().getText()) + .collect(Collectors.joining()); + assertThat(aggregatedResponse).isNotEmpty(); + + + } + @Test void observationForStreamingChatOperation() { var options = OllamaOptions.builder() diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechResponse.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechResponse.java index 9662764aec5..24ddc790814 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechResponse.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechResponse.java @@ -16,9 +16,7 @@ package org.springframework.ai.openai.audio.speech; -import java.util.Collections; -import java.util.List; -import java.util.Objects; +import java.util.*; import org.springframework.ai.model.ModelResponse; import org.springframework.ai.openai.metadata.audio.OpenAiAudioSpeechResponseMetadata; @@ -38,6 +36,8 @@ public class SpeechResponse implements ModelResponse { private final OpenAiAudioSpeechResponseMetadata speechResponseMetadata; + private final Map context; + /** * Creates a new instance of SpeechResponse with the given speech result. 
* @param speech the speech result to be set in the SpeechResponse @@ -59,6 +59,13 @@ public SpeechResponse(Speech speech) { public SpeechResponse(Speech speech, OpenAiAudioSpeechResponseMetadata speechResponseMetadata) { this.speech = speech; this.speechResponseMetadata = speechResponseMetadata; + this.context=new HashMap<>(); + } + + public SpeechResponse(Speech speech, OpenAiAudioSpeechResponseMetadata speechResponseMetadata,Map context) { + this.speech = speech; + this.speechResponseMetadata = speechResponseMetadata; + this.context=context; } @Override @@ -76,6 +83,11 @@ public OpenAiAudioSpeechResponseMetadata getMetadata() { return this.speechResponseMetadata; } + @Override + public Map getContext() { + return this.context; + } + @Override public boolean equals(Object o) { if (this == o) { diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClientMessageAggregator.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClientMessageAggregator.java index 582d77b488f..9fee0eb250a 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClientMessageAggregator.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClientMessageAggregator.java @@ -46,7 +46,7 @@ public Flux aggregateChatClientResponse(Flux> context = new AtomicReference<>(new HashMap<>()); return new MessageAggregator().aggregate(chatClientResponses.mapNotNull(chatClientResponse -> { - context.get().putAll(chatClientResponse.context()); + context.get().putAll(chatClientResponse.chatResponse().getContext()); return chatClientResponse.chatResponse(); }), aggregatedChatResponse -> { ChatClientResponse aggregatedChatClientResponse = ChatClientResponse.builder() diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClientResponse.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClientResponse.java index 
a069702356b..0511f70b7bf 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClientResponse.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClientResponse.java @@ -27,23 +27,34 @@ * Represents a response returned by a {@link ChatClient}. * * @param chatResponse The response returned by the AI model - * @param context The contextual data propagated through the execution chain * @author Thomas Vitale * @since 1.0.0 */ -public record ChatClientResponse(@Nullable ChatResponse chatResponse, Map context) { +public record ChatClientResponse(@Nullable ChatResponse chatResponse) { public ChatClientResponse { - Assert.notNull(context, "context cannot be null"); - Assert.noNullElements(context.keySet(), "context keys cannot be null"); } public ChatClientResponse copy() { - return new ChatClientResponse(this.chatResponse, new HashMap<>(this.context)); + if (this.chatResponse == null) { + return new ChatClientResponse(null); + } + Map copiedContext = new HashMap<>(this.chatResponse.getContext()); + ChatResponse copiedChatResponse = new ChatResponse( + this.chatResponse.getGenerations(), + this.chatResponse.getMetadata(), + copiedContext + ); + return new ChatClientResponse(copiedChatResponse); } + public Builder mutate() { - return new Builder().chatResponse(this.chatResponse).context(new HashMap<>(this.context)); + Builder builder = new Builder().chatResponse(this.chatResponse); + if (this.chatResponse != null && this.chatResponse.getContext() != null) { + builder.context(new HashMap<>(this.chatResponse.getContext())); + } + return builder; } public static Builder builder() { @@ -70,16 +81,21 @@ public Builder context(Map context) { return this; } - public Builder context(String key, Object value) { - Assert.notNull(key, "key cannot be null"); - this.context.put(key, value); - return this; - } + // In Builder class public ChatClientResponse build() { - return new ChatClientResponse(this.chatResponse, 
this.context); + if (this.chatResponse == null) { + return new ChatClientResponse(null); + } + Map mergedContext = new HashMap<>(this.chatResponse.getContext()); + mergedContext.putAll(this.context); + ChatResponse newChatResponse = new ChatResponse( + this.chatResponse.getGenerations(), + this.chatResponse.getMetadata(), + mergedContext + ); + return new ChatClientResponse(newChatResponse); } - } } diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java index 1b8bbea84e9..69bfe28e1ad 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java @@ -108,7 +108,7 @@ public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorCh .map(g -> (Message) g.getOutput()) .toList(); } - this.chatMemory.add(this.getConversationId(chatClientResponse.context(), this.defaultConversationId), + this.chatMemory.add(this.getConversationId(chatClientResponse.chatResponse().getContext(), this.defaultConversationId), assistantMessages); return chatClientResponse; } diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java index de88715e896..fde4999cf6b 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java @@ -155,18 +155,18 @@ else if (chatClientResponse.chatResponse() != null) { } if (!assistantMessages.isEmpty()) { - 
this.chatMemory.add(this.getConversationId(chatClientResponse.context(), this.defaultConversationId), + this.chatMemory.add(this.getConversationId(chatClientResponse.chatResponse().getContext(), this.defaultConversationId), assistantMessages); if (logger.isDebugEnabled()) { logger.debug( "[PromptChatMemoryAdvisor.after] Added ASSISTANT messages to memory for conversationId={}: {}", - this.getConversationId(chatClientResponse.context(), this.defaultConversationId), + this.getConversationId(chatClientResponse.chatResponse().getContext(), this.defaultConversationId), assistantMessages); List memoryMessages = this.chatMemory - .get(this.getConversationId(chatClientResponse.context(), this.defaultConversationId)); + .get(this.getConversationId(chatClientResponse.chatResponse().getContext(), this.defaultConversationId)); logger.debug("[PromptChatMemoryAdvisor.after] Memory after ASSISTANT add for conversationId={}: {}", - this.getConversationId(chatClientResponse.context(), this.defaultConversationId), + this.getConversationId(chatClientResponse.chatResponse().getContext(), this.defaultConversationId), memoryMessages); } } diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/ChatClientResponseTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/ChatClientResponseTests.java index 234309a02c8..15a83d38523 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/ChatClientResponseTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/ChatClientResponseTests.java @@ -16,13 +16,15 @@ package org.springframework.ai.chat.client; +import java.util.ArrayList; import java.util.HashMap; import java.util.Map; import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.model.ChatResponse; import static org.assertj.core.api.Assertions.assertThat; -import static 
org.assertj.core.api.Assertions.assertThatThrownBy; /** * Unit tests for {@link ChatClientResponse}. @@ -31,55 +33,41 @@ */ class ChatClientResponseTests { - @Test - void whenContextIsNullThenThrow() { - assertThatThrownBy(() -> new ChatClientResponse(null, null)).isInstanceOf(IllegalArgumentException.class) - .hasMessage("context cannot be null"); - - assertThatThrownBy(() -> ChatClientResponse.builder().chatResponse(null).context(null).build()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("context cannot be null"); - } - - @Test - void whenContextHasNullKeysThenThrow() { - Map context = new HashMap<>(); - context.put(null, "something"); - assertThatThrownBy(() -> new ChatClientResponse(null, context)).isInstanceOf(IllegalArgumentException.class) - .hasMessage("context keys cannot be null"); - } @Test void whenCopyThenImmutableContext() { Map context = new HashMap<>(); context.put("key", "value"); - ChatClientResponse response = ChatClientResponse.builder().chatResponse(null).context(context).build(); + ChatResponse chatResponse=new ChatResponse(new ArrayList<>(),new ChatResponseMetadata(), context); + ChatClientResponse response = ChatClientResponse.builder().chatResponse(chatResponse).build(); ChatClientResponse copy = response.copy(); - copy.context().put("key2", "value2"); - assertThat(response.context()).doesNotContainKey("key2"); - assertThat(copy.context()).containsKey("key2"); + copy.chatResponse().getContext().put("key2", "value2"); + assertThat(response.chatResponse().getContext()).doesNotContainKey("key2"); + assertThat(copy.chatResponse().getContext()).containsKey("key2"); - copy.context().put("key", "newValue"); - assertThat(copy.context()).containsEntry("key", "newValue"); - assertThat(response.context()).containsEntry("key", "value"); + copy.chatResponse().getContext().put("key", "newValue"); + assertThat(copy.chatResponse().getContext()).containsEntry("key", "newValue"); + 
assertThat(response.chatResponse().getContext()).containsEntry("key", "value"); } @Test void whenMutateThenImmutableContext() { Map context = new HashMap<>(); context.put("key", "value"); - ChatClientResponse response = ChatClientResponse.builder().chatResponse(null).context(context).build(); - - ChatClientResponse copy = response.mutate().context(Map.of("key2", "value2")).build(); - - assertThat(response.context()).doesNotContainKey("key2"); - assertThat(copy.context()).containsKey("key2"); - - copy.context().put("key", "newValue"); - assertThat(copy.context()).containsEntry("key", "newValue"); - assertThat(response.context()).containsEntry("key", "value"); + ChatResponse chatResponse=new ChatResponse(new ArrayList<>(),new ChatResponseMetadata(),context); + ChatClientResponse response = ChatClientResponse.builder().chatResponse(chatResponse).build(); + HashMap hashMap=new HashMap<>(); + hashMap.put("key2","value"); + ChatClientResponse copy = response.mutate().context(hashMap).build(); + + assertThat(response.chatResponse().getContext()).doesNotContainKey("key2"); + assertThat(copy.chatResponse().getContext()).containsKey("key2"); + + copy.chatResponse().getContext().put("key", "newValue"); + assertThat(copy.chatResponse().getContext()).containsEntry("key", "newValue"); + assertThat(response.chatResponse().getContext()).containsEntry("key", "value"); } } diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/AdvisorsTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/AdvisorsTests.java index 48b3da6873e..f531450c726 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/AdvisorsTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/AdvisorsTests.java @@ -17,6 +17,7 @@ package org.springframework.ai.chat.client.advisor; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import 
java.util.Map; import java.util.stream.Collectors; @@ -27,6 +28,7 @@ import org.mockito.Captor; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; import reactor.core.publisher.Flux; import org.springframework.ai.chat.client.ChatClient; @@ -65,9 +67,18 @@ public void callAdvisorsContextPropagation() { // the priority. var mockAroundAdvisor1 = new MockAroundAdvisor("Advisor1", 0); var mockAroundAdvisor2 = new MockAroundAdvisor("Advisor2", 1); + HashMap context=new HashMap<>(); + context.put("key1", "value1"); + context.put("key2", "value2"); + context.put("aroundCallBeforeAdvisor1", "AROUND_CALL_BEFORE Advisor1"); + context.put("aroundCallAfterAdvisor1", "AROUND_CALL_AFTER Advisor1"); + context.put("aroundCallBeforeAdvisor2", "AROUND_CALL_BEFORE Advisor2"); + context.put("aroundCallAfterAdvisor2", "AROUND_CALL_AFTER Advisor2"); + context.put("lastBefore", "Advisor2"); + context.put("lastAfter", "Advisor1"); given(this.chatModel.call(this.promptCaptor.capture())) - .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("Hello John"))))); + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("Hello John"))),new ChatResponseMetadata(),context)); var chatClient = ChatClient.builder(this.chatModel) .defaultSystem("Default system text.") @@ -85,7 +96,7 @@ public void callAdvisorsContextPropagation() { // AROUND assertThat(mockAroundAdvisor1.chatClientResponse.chatResponse()).isNotNull(); - assertThat(mockAroundAdvisor1.chatClientResponse.context()).containsEntry("key1", "value1") + assertThat(mockAroundAdvisor1.chatClientResponse.chatResponse().getContext()).containsEntry("key1", "value1") .containsEntry("key2", "value2") .containsEntry("aroundCallBeforeAdvisor1", "AROUND_CALL_BEFORE Advisor1") .containsEntry("aroundCallAfterAdvisor1", "AROUND_CALL_AFTER Advisor1") @@ -97,15 +108,27 @@ public void callAdvisorsContextPropagation() { 
verify(this.chatModel).call(this.promptCaptor.capture()); } + @Test public void streamAdvisorsContextPropagation() { var mockAroundAdvisor1 = new MockAroundAdvisor("Advisor1", 0); var mockAroundAdvisor2 = new MockAroundAdvisor("Advisor2", 1); + HashMap context = new HashMap<>(); + context.put("key1", "value1"); + context.put("key2", "value2"); + context.put("aroundCallBeforeAdvisor1", "AROUND_CALL_BEFORE Advisor1"); + context.put("aroundCallAfterAdvisor1", "AROUND_CALL_AFTER Advisor1"); + context.put("aroundCallBeforeAdvisor2", "AROUND_CALL_BEFORE Advisor2"); + context.put("aroundCallAfterAdvisor2", "AROUND_CALL_AFTER Advisor2"); + context.put("lastBefore", "Advisor2"); + context.put("lastAfter", "Advisor1"); given(this.chatModel.stream(this.promptCaptor.capture())) - .willReturn(Flux.just(new ChatResponse(List.of(new Generation(new AssistantMessage("Hello")))), - new ChatResponse(List.of(new Generation(new AssistantMessage(" John")))))); + .willReturn(Flux.just( + new ChatResponse(List.of(new Generation(new AssistantMessage("Hello"))), null, context), + new ChatResponse(List.of(new Generation(new AssistantMessage(" John"))), null, context) + )); var chatClient = ChatClient.builder(this.chatModel) .defaultSystem("Default system text.") @@ -129,7 +152,7 @@ public void streamAdvisorsContextPropagation() { assertThat(mockAroundAdvisor1.advisedChatClientResponses).isNotEmpty(); mockAroundAdvisor1.advisedChatClientResponses.stream() - .forEach(chatClientResponse -> assertThat(chatClientResponse.context()).containsEntry("key1", "value1") + .forEach(chatClientResponse -> assertThat(chatClientResponse.chatResponse().getContext()).containsEntry("key1", "value1") .containsEntry("key2", "value2") .containsEntry("aroundStreamBeforeAdvisor1", "AROUND_STREAM_BEFORE Advisor1") .containsEntry("aroundStreamAfterAdvisor1", "AROUND_STREAM_AFTER Advisor1") @@ -172,20 +195,27 @@ public int getOrder() { @Override public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, 
CallAdvisorChain callAdvisorChain) { this.chatClientRequest = chatClientRequest.mutate() - .context(Map.of("aroundCallBefore" + getName(), "AROUND_CALL_BEFORE " + getName(), "lastBefore", - getName())) - .build(); + .context(mergeContext(chatClientRequest.context(), Map.of( + "aroundCallBefore" + getName(), "AROUND_CALL_BEFORE " + getName(), + "lastBefore", getName()))) + .build(); var chatClientResponse = callAdvisorChain.nextCall(this.chatClientRequest); - this.chatClientResponse = chatClientResponse.mutate() - .context( - Map.of("aroundCallAfter" + getName(), "AROUND_CALL_AFTER " + getName(), "lastAfter", getName())) - .build(); - + .context(mergeContext(chatClientResponse.chatResponse().getContext(), Map.of( + "aroundCallAfter" + getName(), "AROUND_CALL_AFTER " + getName(), + "lastAfter", getName()))) + .build(); + System.out.println("...."+this.chatClientResponse.chatResponse().getContext()); return this.chatClientResponse; } + private Map mergeContext(Map original, Map additions) { + Map merged = new HashMap<>(original); + merged.putAll(additions); + return merged; + } + @Override public Flux adviseStream(ChatClientRequest chatClientRequest, StreamAdvisorChain streamAdvisorChain) { diff --git a/spring-ai-model/src/main/java/org/springframework/ai/audio/transcription/AudioTranscriptionResponse.java b/spring-ai-model/src/main/java/org/springframework/ai/audio/transcription/AudioTranscriptionResponse.java index c274698d148..1709f62f42f 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/audio/transcription/AudioTranscriptionResponse.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/audio/transcription/AudioTranscriptionResponse.java @@ -16,7 +16,9 @@ package org.springframework.ai.audio.transcription; +import java.util.HashMap; import java.util.List; +import java.util.Map; import org.springframework.ai.model.ModelResponse; @@ -33,14 +35,21 @@ public class AudioTranscriptionResponse implements ModelResponse transcriptionContext; + 
public AudioTranscriptionResponse(AudioTranscription transcript) { - this(transcript, new AudioTranscriptionResponseMetadata()); + this(transcript, new AudioTranscriptionResponseMetadata(),new HashMap<>()); + } + + public AudioTranscriptionResponse(AudioTranscription transcript,AudioTranscriptionResponseMetadata transcriptionResponseMetadata) { + this(transcript, transcriptionResponseMetadata, new HashMap<>()); } public AudioTranscriptionResponse(AudioTranscription transcript, - AudioTranscriptionResponseMetadata transcriptionResponseMetadata) { + AudioTranscriptionResponseMetadata transcriptionResponseMetadata, Map transcriptionContext) { this.transcript = transcript; this.transcriptionResponseMetadata = transcriptionResponseMetadata; + this.transcriptionContext=transcriptionContext; } @Override @@ -58,4 +67,17 @@ public AudioTranscriptionResponseMetadata getMetadata() { return this.transcriptionResponseMetadata; } + @Override + public Map getContext() { + return transcriptionContext; + } + + @Override + public String toString() { + return "AudioTranscriptionResponse{" + + "transcript=" + transcript + + ", transcriptionResponseMetadata=" + transcriptionResponseMetadata + + ", transcriptionContext=" + transcriptionContext + + '}'; + } } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/audio/tts/TextToSpeechResponse.java b/spring-ai-model/src/main/java/org/springframework/ai/audio/tts/TextToSpeechResponse.java index 6e23ef43e49..5a39912e3d4 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/audio/tts/TextToSpeechResponse.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/audio/tts/TextToSpeechResponse.java @@ -16,7 +16,9 @@ package org.springframework.ai.audio.tts; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Objects; import org.springframework.ai.model.ModelResponse; @@ -32,13 +34,16 @@ public class TextToSpeechResponse implements ModelResponse { private final
TextToSpeechResponseMetadata textToSpeechResponseMetadata; + private final Map textToSpeechContext; + public TextToSpeechResponse(List results) { - this(results, null); + this(results, null,new HashMap<>()); } - public TextToSpeechResponse(List results, TextToSpeechResponseMetadata textToSpeechResponseMetadata) { + public TextToSpeechResponse(List results, TextToSpeechResponseMetadata textToSpeechResponseMetadata, Map textToSpeechContext) { this.results = results; this.textToSpeechResponseMetadata = textToSpeechResponseMetadata; + this.textToSpeechContext= textToSpeechContext; } @Override @@ -55,6 +60,11 @@ public TextToSpeechResponseMetadata getMetadata() { return this.textToSpeechResponseMetadata; } + @Override + public Map getContext() { + return textToSpeechContext; + } + @Override public boolean equals(Object o) { if (this == o) @@ -71,7 +81,10 @@ public int hashCode() { @Override public String toString() { - return "TextToSpeechResponse{" + "results=" + results + '}'; + return "TextToSpeechResponse{" + + "results=" + results + + ", textToSpeechResponseMetadata=" + textToSpeechResponseMetadata + + ", textToSpeechContext=" + textToSpeechContext + + '}'; } - } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/model/ChatModel.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/model/ChatModel.java index 95bebeb2ea6..a5c0f849739 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/model/ChatModel.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/model/ChatModel.java @@ -51,4 +51,12 @@ default Flux stream(Prompt prompt) { throw new UnsupportedOperationException("streaming is not supported"); } + default boolean isStop(ChatResponse chatResponse) { + + return (null !=chatResponse.getResult() + && null !=chatResponse.getResult().getMetadata() + && null !=chatResponse.getResult().getMetadata().getFinishReason() + && "stop".equals(chatResponse.getResult().getMetadata().getFinishReason())); + } + } 
diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/model/ChatResponse.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/model/ChatResponse.java index 66a844bab25..dc55b397fd7 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/model/ChatResponse.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/model/ChatResponse.java @@ -16,10 +16,7 @@ package org.springframework.ai.chat.model; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Set; +import java.util.*; import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.model.ModelResponse; @@ -45,25 +42,56 @@ public class ChatResponse implements ModelResponse { */ private final List generations; + public ChatResponseMetadata getChatResponseMetadata() { + return chatResponseMetadata; + } + + public List getGenerations() { + return generations; + } + /** * Construct a new {@link ChatResponse} instance without metadata. + * * @param generations the {@link List} of {@link Generation} returned by the AI * provider. */ + private final Map context; + + + /** + * Constructs a new `ChatResponse` with the provided list of generations. + * + * @param generations the list of `Generation` objects returned by the AI provider + */ + public ChatResponse(List generations) { - this(generations, new ChatResponseMetadata()); + this(generations, new ChatResponseMetadata(), Map.of()); } + + /** * Construct a new {@link ChatResponse} instance. - * @param generations the {@link List} of {@link Generation} returned by the AI - * provider. + * + * @param generations the {@link List} of {@link Generation} returned by the AI + * provider. * @param chatResponseMetadata {@link ChatResponseMetadata} containing information - * about the use of the AI provider's API. + * about the use of the AI provider's API. 
*/ - public ChatResponse(List generations, ChatResponseMetadata chatResponseMetadata) { + public ChatResponse(List generations, ChatResponseMetadata chatResponseMetadata, Map context) { this.chatResponseMetadata = chatResponseMetadata; this.generations = List.copyOf(generations); + this.context = context != null ? new HashMap<>(context) : new HashMap<>(); + } + + @Override + public Map getContext() { + return context; + } + + public ChatResponse(List generations, ChatResponseMetadata chatResponseMetadata) { + this(generations, chatResponseMetadata, Map.of()); } public static Builder builder() { @@ -75,6 +103,7 @@ public static Builder builder() { *

* It is a {@link List} of {@link List lists} because the Prompt could request * multiple output {@link Generation generations}. + * * @return the {@link List} of {@link Generation generated outputs}. */ @@ -102,6 +131,7 @@ public ChatResponseMetadata getMetadata() { return this.chatResponseMetadata; } + /** * Whether the model has requested the execution of a tool. */ @@ -129,7 +159,7 @@ public boolean hasFinishReasons(Set finishReasons) { @Override public String toString() { - return "ChatResponse [metadata=" + this.chatResponseMetadata + ", generations=" + this.generations + "]"; + return "ChatResponse [metadata=" + this.chatResponseMetadata + ", generations=" + this.generations + ", context=" +this.context+"]"; } @Override @@ -149,12 +179,17 @@ public int hashCode() { return Objects.hash(this.chatResponseMetadata, this.generations); } + + + public static final class Builder { private List generations; private ChatResponseMetadata.Builder chatResponseMetadataBuilder; + private Map context; + private Builder() { this.chatResponseMetadataBuilder = ChatResponseMetadata.builder(); } @@ -185,11 +220,15 @@ public Builder metadata(ChatResponseMetadata other) { public Builder generations(List generations) { this.generations = generations; return this; + } + public Builder context(Map context) { + this.context = context != null ? 
new HashMap<>(context) : new HashMap<>(); + return this; } public ChatResponse build() { - return new ChatResponse(this.generations, this.chatResponseMetadataBuilder.build()); + return new ChatResponse(this.generations, this.chatResponseMetadataBuilder.build(), this.context); } } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/embedding/EmbeddingResponse.java b/spring-ai-model/src/main/java/org/springframework/ai/embedding/EmbeddingResponse.java index 2ad2afac32f..a47d6a09fb7 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/embedding/EmbeddingResponse.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/embedding/EmbeddingResponse.java @@ -16,7 +16,9 @@ package org.springframework.ai.embedding; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Objects; import org.springframework.ai.model.ModelResponse; @@ -41,8 +43,17 @@ public class EmbeddingResponse implements ModelResponse { * Creates a new {@link EmbeddingResponse} instance with empty metadata. * @param embeddings the embedding data. 
*/ + + private final Map embeddingContext; + public EmbeddingResponse(List embeddings) { - this(embeddings, new EmbeddingResponseMetadata()); + this(embeddings, new EmbeddingResponseMetadata(),new HashMap<>()); + } + + public EmbeddingResponse(List embeddings, EmbeddingResponseMetadata metadata,Map embeddingContext) { + this.embeddings = embeddings; + this.metadata = metadata; + this.embeddingContext=embeddingContext; } /** @@ -53,6 +64,7 @@ public EmbeddingResponse(List embeddings) { public EmbeddingResponse(List embeddings, EmbeddingResponseMetadata metadata) { this.embeddings = embeddings; this.metadata = metadata; + this.embeddingContext=new HashMap<>(); } /** @@ -62,6 +74,11 @@ public EmbeddingResponseMetadata getMetadata() { return this.metadata; } + @Override + public Map getContext() { + return embeddingContext; + } + @Override public Embedding getResult() { Assert.notEmpty(this.embeddings, "No embedding data available."); @@ -95,7 +112,10 @@ public int hashCode() { @Override public String toString() { - return "EmbeddingResult{" + "data=" + this.embeddings + ", metadata=" + this.metadata + '}'; + return "EmbeddingResponse{" + + "embeddings=" + embeddings + + ", metadata=" + metadata + + ", embeddingContext=" + embeddingContext + + '}'; } - } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/image/ImageResponse.java b/spring-ai-model/src/main/java/org/springframework/ai/image/ImageResponse.java index c4605d81890..60734fdc648 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/image/ImageResponse.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/image/ImageResponse.java @@ -16,7 +16,9 @@ package org.springframework.ai.image; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Objects; import org.springframework.ai.model.ModelResponse; @@ -40,23 +42,35 @@ public class ImageResponse implements ModelResponse { /** * Construct a new {@link ImageResponse} instance without metadata. 
+ * * @param generations the {@link List} of {@link ImageGeneration} returned by the AI * provider. */ + + private final Map context; + public ImageResponse(List generations) { - this(generations, new ImageResponseMetadata()); + this(generations, new ImageResponseMetadata(), new HashMap<>()); } /** * Construct a new {@link ImageResponse} instance. - * @param generations the {@link List} of {@link ImageGeneration} returned by the AI - * provider. + * + * @param generations the {@link List} of {@link ImageGeneration} returned by the AI + * provider. * @param imageResponseMetadata {@link ImageResponseMetadata} containing information - * about the use of the AI provider's API. + * about the use of the AI provider's API. */ public ImageResponse(List generations, ImageResponseMetadata imageResponseMetadata) { this.imageResponseMetadata = imageResponseMetadata; this.imageGenerations = List.copyOf(generations); + this.context = new HashMap<>(); + } + + public ImageResponse(List generations, ImageResponseMetadata imageResponseMetadata, Map context) { + this.imageResponseMetadata = imageResponseMetadata; + this.imageGenerations = List.copyOf(generations); + this.context = context; } /** @@ -64,6 +78,7 @@ public ImageResponse(List generations, ImageResponseMetadata im *

* It is a {@link List} of {@link List lists} because the Prompt could request * multiple output {@link ImageGeneration generations}. + * * @return the {@link List} of {@link ImageGeneration generated outputs}. */ @Override @@ -91,10 +106,18 @@ public ImageResponseMetadata getMetadata() { return this.imageResponseMetadata; } + @Override + public Map getContext() { + return this.context; + } + @Override public String toString() { - return "ImageResponse [" + "imageResponseMetadata=" + this.imageResponseMetadata + ", imageGenerations=" - + this.imageGenerations + "]"; + return "ImageResponse{" + + "imageResponseMetadata=" + imageResponseMetadata + + ", imageGenerations=" + imageGenerations + + ", context=" + context + + '}'; } @Override diff --git a/spring-ai-model/src/main/java/org/springframework/ai/model/ModelResponse.java b/spring-ai-model/src/main/java/org/springframework/ai/model/ModelResponse.java index 5df8b8d2a82..575e363ee33 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/model/ModelResponse.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/model/ModelResponse.java @@ -17,6 +17,7 @@ package org.springframework.ai.model; import java.util.List; +import java.util.Map; /** * Interface representing the response received from an AI model. 
This interface provides @@ -49,4 +50,10 @@ public interface ModelResponse> { */ ResponseMetadata getMetadata(); + /** + * Retrieves the context of the response. + * @return the context of the response + */ + Map getContext(); + } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/moderation/ModerationResponse.java b/spring-ai-model/src/main/java/org/springframework/ai/moderation/ModerationResponse.java index 043104436e1..2bb1ad22862 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/moderation/ModerationResponse.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/moderation/ModerationResponse.java @@ -16,7 +16,9 @@ package org.springframework.ai.moderation; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Objects; import org.springframework.ai.model.ModelResponse; @@ -37,6 +39,8 @@ public class ModerationResponse implements ModelResponse { private final Generation generations; + private final Map context; + public ModerationResponse(Generation generations) { this(generations, new ModerationResponseMetadata()); } @@ -44,8 +48,16 @@ public ModerationResponse(Generation generations) { public ModerationResponse(Generation generations, ModerationResponseMetadata moderationResponseMetadata) { this.moderationResponseMetadata = moderationResponseMetadata; this.generations = generations; + this.context=new HashMap<>(); + } + + public ModerationResponse(Generation generations, ModerationResponseMetadata moderationResponseMetadata,Map context) { + this.moderationResponseMetadata = moderationResponseMetadata; + this.generations = generations; + this.context=context; } + @Override public Generation getResult() { return this.generations; @@ -61,10 +73,18 @@ public ModerationResponseMetadata getMetadata() { return this.moderationResponseMetadata; } + @Override + public Map getContext() { + return this.context; + } + @Override public String toString() { - return "ModerationResponse{" + "moderationResponseMetadata=" + 
this.moderationResponseMetadata - + ", generations=" + this.generations + '}'; + return "ModerationResponse{" + + "moderationResponseMetadata=" + moderationResponseMetadata + + ", generations=" + generations + + ", context=" + context + + '}'; } @Override