Skip to content

changes #3984

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open

changes #3984

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion models/spring-ai-ollama/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -114,5 +114,11 @@
<artifactId>ollama</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<!-- Sibling module of this build: follow the project's own version instead of pinning a literal SNAPSHOT. Scope is omitted because compile is Maven's default. -->
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-client-chat</artifactId>
<version>${project.version}</version>
</dependency>
</dependencies>
</project>
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,15 @@
package org.springframework.ai.ollama;

import java.time.Duration;
import java.util.Base64;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.*;

import com.fasterxml.jackson.core.type.TypeReference;
import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationRegistry;
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClientRequest;
import reactor.core.publisher.Flux;
import reactor.core.scheduler.Schedulers;

Expand Down Expand Up @@ -234,6 +232,9 @@ public ChatResponse call(Prompt prompt) {
return this.internalCall(requestPrompt, null);
}




private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) {

OllamaApi.ChatRequest request = ollamaChatRequest(prompt, false);
Expand Down Expand Up @@ -304,6 +305,22 @@ public Flux<ChatResponse> stream(Prompt prompt) {
return this.internalStream(requestPrompt, null);
}

/**
 * Streams chat responses for the prompt carried by the given {@link ChatClientRequest}.
 * <p>
 * The request prompt is built with {@code buildRequestPrompt} and streamed through
 * {@code internalStream}. Every response for which {@code isStop} returns {@code true}
 * is rebuilt through {@code ChatResponse.builder()} with the request's context attached;
 * all other responses are passed through unchanged.
 * @param chatClientRequest the request supplying the prompt and the context
 * @return the stream of chat responses
 */
public Flux<ChatResponse> stream(ChatClientRequest chatClientRequest) {
	Prompt requestPrompt = buildRequestPrompt(chatClientRequest.prompt());
	return this.internalStream(requestPrompt, null)
		.map(chatResponse -> isStop(chatResponse)
				? ChatResponse.builder().context(chatClientRequest.context()).from(chatResponse).build()
				: chatResponse);
}


private Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse) {
return Flux.deferContextual(contextView -> {
OllamaApi.ChatRequest request = ollamaChatRequest(prompt, true);
Expand Down Expand Up @@ -337,7 +354,6 @@ private Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCh
}

var assistantMessage = new AssistantMessage(content, Map.of(), toolCalls);

ChatGenerationMetadata generationMetadata = ChatGenerationMetadata.NULL;
if (chunk.promptEvalCount() != null && chunk.evalCount() != null) {
generationMetadata = ChatGenerationMetadata.builder().finishReason(chunk.doneReason()).build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@

package org.springframework.ai.ollama;

import java.util.HashMap;
import java.util.List;
import java.util.stream.Collectors;

import io.micrometer.observation.tck.TestObservationRegistry;
import io.micrometer.observation.tck.TestObservationRegistryAssert;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.springframework.ai.chat.client.ChatClientRequest;
import reactor.core.publisher.Flux;

import org.springframework.ai.chat.metadata.ChatResponseMetadata;
Expand Down Expand Up @@ -90,6 +92,36 @@ void observationForChatOperation() {
validate(responseMetadata);
}

/**
 * Streams a completion through the {@link ChatClientRequest} overload of
 * {@code stream} and checks that a non-trivial number of chunks with text is produced.
 */
@Test
void observationForChatOperationUsingRequest() {
	var options = OllamaOptions.builder()
		.model(MODEL)
		.frequencyPenalty(0.0)
		.numPredict(2048)
		.presencePenalty(0.0)
		.stop(List.of("this-is-the-end"))
		.temperature(0.7)
		.topK(1)
		.topP(1.0)
		.build();
	HashMap<String, Object> ctx = new HashMap<>();
	Prompt prompt = new Prompt("Why does a raven look like a desk?", options);
	ChatClientRequest chatClientRequest = new ChatClientRequest(prompt, ctx);

	Flux<ChatResponse> chatResponseFlux = this.chatModel.stream(chatClientRequest);
	List<ChatResponse> responses = chatResponseFlux.collectList().block();

	assertThat(responses).isNotEmpty();
	assertThat(responses).hasSizeGreaterThan(10);

	// The last element is left out of the aggregation.
	// NOTE(review): presumably the terminal chunk carries no text — confirm.
	String aggregatedResponse = responses.subList(0, responses.size() - 1)
		.stream()
		.map(r -> r.getResult().getOutput().getText())
		.collect(Collectors.joining());
	assertThat(aggregatedResponse).isNotEmpty();
}

@Test
void observationForStreamingChatOperation() {
var options = OllamaOptions.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@

package org.springframework.ai.openai.audio.speech;

import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.*;

import org.springframework.ai.model.ModelResponse;
import org.springframework.ai.openai.metadata.audio.OpenAiAudioSpeechResponseMetadata;
Expand All @@ -38,6 +36,8 @@ public class SpeechResponse implements ModelResponse<Speech> {

private final OpenAiAudioSpeechResponseMetadata speechResponseMetadata;

private final Map<String,Speech> context;

/**
* Creates a new instance of SpeechResponse with the given speech result.
* @param speech the speech result to be set in the SpeechResponse
Expand All @@ -59,6 +59,13 @@ public SpeechResponse(Speech speech) {
public SpeechResponse(Speech speech, OpenAiAudioSpeechResponseMetadata speechResponseMetadata) {
	// No context supplied by the caller: start from a fresh, empty, mutable map.
	this(speech, speechResponseMetadata, new HashMap<>());
}

/**
 * Creates a new instance of SpeechResponse with the given speech result, metadata and
 * context.
 * @param speech the speech result to be set in the SpeechResponse
 * @param speechResponseMetadata the metadata to be set in the SpeechResponse
 * @param context the context map to expose through {@link #getContext()}; the given
 * instance is stored as-is (no copy). When {@code null}, an empty mutable map is used
 * so that {@link #getContext()} never returns {@code null}
 */
public SpeechResponse(Speech speech, OpenAiAudioSpeechResponseMetadata speechResponseMetadata,
		Map<String, Speech> context) {
	this.speech = speech;
	this.speechResponseMetadata = speechResponseMetadata;
	// Fall back to an empty map so a null argument cannot leak out of getContext().
	this.context = (context != null) ? context : new HashMap<>();
}

@Override
Expand All @@ -76,6 +83,11 @@ public OpenAiAudioSpeechResponseMetadata getMetadata() {
return this.speechResponseMetadata;
}

/**
 * Returns the context map held by this response. The stored map instance itself is
 * returned, not a copy.
 * @return the context map of this response
 */
@Override
public Map<String,Speech> getContext() {
return this.context;
}

@Override
public boolean equals(Object o) {
if (this == o) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public Flux<ChatClientResponse> aggregateChatClientResponse(Flux<ChatClientRespo
AtomicReference<Map<String, Object>> context = new AtomicReference<>(new HashMap<>());

return new MessageAggregator().aggregate(chatClientResponses.mapNotNull(chatClientResponse -> {
context.get().putAll(chatClientResponse.context());
context.get().putAll(chatClientResponse.chatResponse().getContext());
return chatClientResponse.chatResponse();
}), aggregatedChatResponse -> {
ChatClientResponse aggregatedChatClientResponse = ChatClientResponse.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,23 +27,34 @@
* Represents a response returned by a {@link ChatClient}.
*
* @param chatResponse The response returned by the AI model
* @param context The contextual data propagated through the execution chain
* @author Thomas Vitale
* @since 1.0.0
*/
public record ChatClientResponse(@Nullable ChatResponse chatResponse, Map<String, Object> context) {
public record ChatClientResponse(@Nullable ChatResponse chatResponse) {

public ChatClientResponse {
Assert.notNull(context, "context cannot be null");
Assert.noNullElements(context.keySet(), "context keys cannot be null");
}

public ChatClientResponse copy() {
return new ChatClientResponse(this.chatResponse, new HashMap<>(this.context));
if (this.chatResponse == null) {
return new ChatClientResponse(null);
}
Map<String, Object> copiedContext = new HashMap<>(this.chatResponse.getContext());
ChatResponse copiedChatResponse = new ChatResponse(
this.chatResponse.getGenerations(),
this.chatResponse.getMetadata(),
copiedContext
);
return new ChatClientResponse(copiedChatResponse);
}


public Builder mutate() {
return new Builder().chatResponse(this.chatResponse).context(new HashMap<>(this.context));
Builder builder = new Builder().chatResponse(this.chatResponse);
if (this.chatResponse != null && this.chatResponse.getContext() != null) {
builder.context(new HashMap<>(this.chatResponse.getContext()));
}
return builder;
}

public static Builder builder() {
Expand All @@ -70,16 +81,21 @@ public Builder context(Map<String, Object> context) {
return this;
}

public Builder context(String key, Object value) {
Assert.notNull(key, "key cannot be null");
this.context.put(key, value);
return this;
}

// In Builder class
public ChatClientResponse build() {
return new ChatClientResponse(this.chatResponse, this.context);
if (this.chatResponse == null) {
return new ChatClientResponse(null);
}
Map<String, Object> mergedContext = new HashMap<>(this.chatResponse.getContext());
mergedContext.putAll(this.context);
ChatResponse newChatResponse = new ChatResponse(
this.chatResponse.getGenerations(),
this.chatResponse.getMetadata(),
mergedContext
);
return new ChatClientResponse(newChatResponse);
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorCh
.map(g -> (Message) g.getOutput())
.toList();
}
this.chatMemory.add(this.getConversationId(chatClientResponse.context(), this.defaultConversationId),
this.chatMemory.add(this.getConversationId(chatClientResponse.chatResponse().getContext(), this.defaultConversationId),
assistantMessages);
return chatClientResponse;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,18 +155,18 @@ else if (chatClientResponse.chatResponse() != null) {
}

if (!assistantMessages.isEmpty()) {
this.chatMemory.add(this.getConversationId(chatClientResponse.context(), this.defaultConversationId),
this.chatMemory.add(this.getConversationId(chatClientResponse.chatResponse().getContext(), this.defaultConversationId),
assistantMessages);

if (logger.isDebugEnabled()) {
logger.debug(
"[PromptChatMemoryAdvisor.after] Added ASSISTANT messages to memory for conversationId={}: {}",
this.getConversationId(chatClientResponse.context(), this.defaultConversationId),
this.getConversationId(chatClientResponse.chatResponse().getContext(), this.defaultConversationId),
assistantMessages);
List<Message> memoryMessages = this.chatMemory
.get(this.getConversationId(chatClientResponse.context(), this.defaultConversationId));
.get(this.getConversationId(chatClientResponse.chatResponse().getContext(), this.defaultConversationId));
logger.debug("[PromptChatMemoryAdvisor.after] Memory after ASSISTANT add for conversationId={}: {}",
this.getConversationId(chatClientResponse.context(), this.defaultConversationId),
this.getConversationId(chatClientResponse.chatResponse().getContext(), this.defaultConversationId),
memoryMessages);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@

package org.springframework.ai.chat.client;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

import org.junit.jupiter.api.Test;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.model.ChatResponse;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

/**
* Unit tests for {@link ChatClientResponse}.
Expand All @@ -31,55 +33,41 @@
*/
class ChatClientResponseTests {

@Test
void whenContextIsNullThenThrow() {
assertThatThrownBy(() -> new ChatClientResponse(null, null)).isInstanceOf(IllegalArgumentException.class)
.hasMessage("context cannot be null");

assertThatThrownBy(() -> ChatClientResponse.builder().chatResponse(null).context(null).build())
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("context cannot be null");
}

@Test
void whenContextHasNullKeysThenThrow() {
Map<String, Object> context = new HashMap<>();
context.put(null, "something");
assertThatThrownBy(() -> new ChatClientResponse(null, context)).isInstanceOf(IllegalArgumentException.class)
.hasMessage("context keys cannot be null");
}

@Test
void whenCopyThenImmutableContext() {
Map<String, Object> context = new HashMap<>();
context.put("key", "value");
ChatClientResponse response = ChatClientResponse.builder().chatResponse(null).context(context).build();
ChatResponse chatResponse=new ChatResponse(new ArrayList<>(),new ChatResponseMetadata(), context);
ChatClientResponse response = ChatClientResponse.builder().chatResponse(chatResponse).build();

ChatClientResponse copy = response.copy();

copy.context().put("key2", "value2");
assertThat(response.context()).doesNotContainKey("key2");
assertThat(copy.context()).containsKey("key2");
copy.chatResponse().getContext().put("key2", "value2");
assertThat(response.chatResponse().getContext()).doesNotContainKey("key2");
assertThat(copy.chatResponse().getContext()).containsKey("key2");

copy.context().put("key", "newValue");
assertThat(copy.context()).containsEntry("key", "newValue");
assertThat(response.context()).containsEntry("key", "value");
copy.chatResponse().getContext().put("key", "newValue");
assertThat(copy.chatResponse().getContext()).containsEntry("key", "newValue");
assertThat(response.chatResponse().getContext()).containsEntry("key", "value");
}

@Test
void whenMutateThenImmutableContext() {
Map<String, Object> context = new HashMap<>();
context.put("key", "value");
ChatClientResponse response = ChatClientResponse.builder().chatResponse(null).context(context).build();

ChatClientResponse copy = response.mutate().context(Map.of("key2", "value2")).build();

assertThat(response.context()).doesNotContainKey("key2");
assertThat(copy.context()).containsKey("key2");

copy.context().put("key", "newValue");
assertThat(copy.context()).containsEntry("key", "newValue");
assertThat(response.context()).containsEntry("key", "value");
ChatResponse chatResponse=new ChatResponse(new ArrayList<>(),new ChatResponseMetadata(),context);
ChatClientResponse response = ChatClientResponse.builder().chatResponse(chatResponse).build();
HashMap<String,Object> hashMap=new HashMap<>();
hashMap.put("key2","value");
ChatClientResponse copy = response.mutate().context(hashMap).build();

assertThat(response.chatResponse().getContext()).doesNotContainKey("key2");
assertThat(copy.chatResponse().getContext()).containsKey("key2");

copy.chatResponse().getContext().put("key", "newValue");
assertThat(copy.chatResponse().getContext()).containsEntry("key", "newValue");
assertThat(response.chatResponse().getContext()).containsEntry("key", "value");
}

}
Loading