Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,11 @@ void addAndGet() {
assertThat(((UserMessage) messages.get(0)).getMedia()).usingRecursiveFieldByFieldElementComparator()
.isEqualTo(media);
memory.deleteByConversationId(sessionId);
ToolResponseMessage toolResponseMessage = new ToolResponseMessage(
List.of(new ToolResponse("id", "name", "responseData"),
new ToolResponse("id2", "name2", "responseData2")),
Map.of("id", "id", "metadataKey", "metadata"));
ToolResponseMessage toolResponseMessage = ToolResponseMessage.builder()
.responses(List.of(new ToolResponse("id", "name", "responseData"),
new ToolResponse("id2", "name2", "responseData2")))
.metadata(Map.of("id", "id", "metadataKey", "metadata"))
.build();
memory.saveAll(sessionId, List.of(toolResponseMessage));
messages = memory.findByConversationId(sessionId);
assertThat(messages.size()).isEqualTo(1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ private Message getMessage(UdtValue udt) {
return SystemMessage.builder().text(content).metadata(props).build();
case TOOL:
// todo – persist ToolResponse somehow
return new ToolResponseMessage(List.of(), props);
return ToolResponseMessage.builder().responses(List.of()).metadata(props).build();
default:
throw new IllegalStateException(
String.format("unknown message type %s", udt.getString(this.conf.messageUdtTypeColumn)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ private Message mapToMessage(Map<String, Object> doc) {
case ASSISTANT -> new AssistantMessage(content, metadata);
case USER -> UserMessage.builder().text(content).metadata(metadata).build();
case SYSTEM -> SystemMessage.builder().text(content).metadata(metadata).build();
case TOOL -> new ToolResponseMessage(List.of(), metadata);
case TOOL -> ToolResponseMessage.builder().responses(List.of()).metadata(metadata).build();
default -> throw new IllegalStateException(String.format("Unknown message type: %s", messageTypeStr));
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ public Message mapRow(ResultSet rs, int i) throws SQLException {
// The content is always stored empty for ToolResponseMessages.
// If we want to capture the actual content, we need to extend
// AddBatchPreparedStatement to support it.
case TOOL -> new ToolResponseMessage(List.of());
case TOOL -> ToolResponseMessage.builder().responses(List.of()).build();
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.messages.ToolResponseMessage.ToolResponse;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.content.Media;
import org.springframework.ai.content.MediaContent;
Expand Down Expand Up @@ -172,12 +173,12 @@ public Neo4jChatMemoryRepositoryConfig getConfig() {

private Message buildToolMessage(org.neo4j.driver.Record record) {
Message message;
message = new ToolResponseMessage(record.get("toolResponses").asList(v -> {
message = ToolResponseMessage.builder().responses(record.get("toolResponses").asList(v -> {
Map<String, Object> trMap = v.asMap();
return new ToolResponseMessage.ToolResponse((String) trMap.get(ToolResponseAttributes.ID.getValue()),
return new ToolResponse((String) trMap.get(ToolResponseAttributes.ID.getValue()),
(String) trMap.get(ToolResponseAttributes.NAME.getValue()),
(String) trMap.get(ToolResponseAttributes.RESPONSE_DATA.getValue()));
}), record.get("metadata").asMap());
})).metadata(record.get("metadata").asMap()).build();
return message;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,9 @@ void saveAndFindMultipleMessages() {
List<Message> messages = List.of(new AssistantMessage("Message from assistant - " + conversationId),
new UserMessage("Message from user - " + conversationId),
new SystemMessage("Message from system - " + conversationId),
new ToolResponseMessage(List.of(new ToolResponse("id", "name", "responseData"))));
ToolResponseMessage.builder()
.responses(List.of(new ToolResponse("id", "name", "responseData")))
.build());

this.chatMemoryRepository.saveAll(conversationId, messages);
List<Message> retrievedMessages = this.chatMemoryRepository.findByConversationId(conversationId);
Expand Down Expand Up @@ -285,9 +287,11 @@ void handleAssistantMessageWithToolCalls() {
void handleToolResponseMessage() {
var conversationId = UUID.randomUUID().toString();

ToolResponseMessage toolResponseMessage = new ToolResponseMessage(List
.of(new ToolResponse("id1", "name1", "responseData1"), new ToolResponse("id2", "name2", "responseData2")),
Map.of("metadataKey", "metadataValue"));
ToolResponseMessage toolResponseMessage = ToolResponseMessage.builder()
.responses(List.of(new ToolResponse("id1", "name1", "responseData1"),
new ToolResponse("id2", "name2", "responseData2")))
.metadata(Map.of("metadataKey", "metadataValue"))
.build();

this.chatMemoryRepository.saveAll(conversationId, List.<Message>of(toolResponseMessage));

Expand Down Expand Up @@ -408,7 +412,9 @@ private Message createMessageByType(String content, MessageType messageType) {
case ASSISTANT -> new AssistantMessage(content);
case USER -> new UserMessage(content);
case SYSTEM -> new SystemMessage(content);
case TOOL -> new ToolResponseMessage(List.of(new ToolResponse("id", "name", "responseData")));
case TOOL -> ToolResponseMessage.builder()
.responses(List.of(new ToolResponse("id", "name", "responseData")))
.build();
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,9 @@ void createChatCompletionMessagesWithToolResponseMessage() {
var toolResponse1 = createToolResponse(1);
var toolResponse2 = createToolResponse(2);
var toolResponse3 = createToolResponse(3);
var toolResponseMessage = new ToolResponseMessage(List.of(toolResponse1, toolResponse2, toolResponse3));
var toolResponseMessage = ToolResponseMessage.builder()
.responses(List.of(toolResponse1, toolResponse2, toolResponse3))
.build();
var prompt = createPrompt(toolResponseMessage);
var chatCompletionRequest = this.chatModel.createRequest(prompt, false);
var chatCompletionMessages = chatCompletionRequest.messages();
Expand All @@ -212,7 +214,7 @@ void createChatCompletionMessagesWithToolResponseMessage() {
@Test
void createChatCompletionMessagesWithInvalidToolResponseMessage() {
var toolResponse = new ToolResponseMessage.ToolResponse(null, null, null);
var toolResponseMessage = new ToolResponseMessage(List.of(toolResponse));
var toolResponseMessage = ToolResponseMessage.builder().responses(List.of(toolResponse)).build();
var prompt = createPrompt(toolResponseMessage);
assertThatThrownBy(() -> this.chatModel.createRequest(prompt, false))
.isInstanceOf(IllegalArgumentException.class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.messages.ToolResponseMessage.ToolResponse;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
Expand Down Expand Up @@ -256,11 +257,10 @@ private static List<Message> createMessagesWithAllMessageTypes() {
var systemMessage = new SystemMessage("Test system message");
var userMessage = new UserMessage("Test user message");
// @formatter:off
var toolResponseMessage = new ToolResponseMessage(List.of(
new ToolResponseMessage.ToolResponse("tool1", "Tool 1", "Test tool response 1"),
new ToolResponseMessage.ToolResponse("tool2", "Tool 2", "Test tool response 2"),
new ToolResponseMessage.ToolResponse("tool3", "Tool 3", "Test tool response 3"))
);
var toolResponseMessage = ToolResponseMessage.builder().responses(List.of(
new ToolResponse("tool1", "Tool 1", "Test tool response 1"),
new ToolResponse("tool2", "Tool 2", "Test tool response 2"),
new ToolResponse("tool3", "Tool 3", "Test tool response 3"))).build();
// @formatter:on
var assistantMessage = new AssistantMessage("Test assistant message");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,27 @@ public class ToolResponseMessage extends AbstractMessage {

protected final List<ToolResponse> responses;

/**
* @deprecated in favor of using {@link ToolResponseMessage.Builder}
*/
@Deprecated
public ToolResponseMessage(List<ToolResponse> responses) {
this(responses, Map.of());
}

/**
* @deprecated in favor of using {@link ToolResponseMessage.Builder}
*/
@Deprecated
public ToolResponseMessage(List<ToolResponse> responses, Map<String, Object> metadata) {
super(MessageType.TOOL, "", metadata);
this.responses = responses;
}

public static Builder builder() {
return new Builder();
}

public List<ToolResponse> getResponses() {
return this.responses;
}
Expand Down Expand Up @@ -73,4 +85,29 @@ public record ToolResponse(String id, String name, String responseData) {

}

public static final class Builder {

private List<ToolResponse> responses;

private Map<String, Object> metadata = Map.of();

private Builder() {
}

public Builder responses(List<ToolResponse> responses) {
this.responses = responses;
return this;
}

public Builder metadata(Map<String, Object> metadata) {
this.metadata = metadata;
return this;
}

public ToolResponseMessage build() {
return new ToolResponseMessage(this.responses, this.metadata);
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,10 @@ else if (message instanceof AssistantMessage assistantMessage) {
.build());
}
else if (message instanceof ToolResponseMessage toolResponseMessage) {
messagesCopy.add(new ToolResponseMessage(new ArrayList<>(toolResponseMessage.getResponses()),
new HashMap<>(toolResponseMessage.getMetadata())));
messagesCopy.add(ToolResponseMessage.builder()
.responses(new ArrayList<>(toolResponseMessage.getResponses()))
.metadata(new HashMap<>(toolResponseMessage.getMetadata()))
.build());
}
else {
throw new IllegalArgumentException("Unsupported message type: " + message.getClass().getName());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,8 @@ private InternalToolExecutionResult executeToolCall(Prompt prompt, AssistantMess
toolCallResult != null ? toolCallResult : ""));
}

return new InternalToolExecutionResult(new ToolResponseMessage(toolResponses, Map.of()), returnDirect);
return new InternalToolExecutionResult(ToolResponseMessage.builder().responses(toolResponses).build(),
returnDirect);
}

private List<Message> buildConversationHistoryAfterToolExecution(List<Message> previousMessages,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.messages.ToolResponseMessage.ToolResponse;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
Expand Down Expand Up @@ -169,8 +170,9 @@ void whenSingleToolCallInChatResponseThenExecute() {
.build())))
.build();

ToolResponseMessage expectedToolResponse = new ToolResponseMessage(
List.of(new ToolResponseMessage.ToolResponse("toolA", "toolA", "Mission accomplished!")));
ToolResponseMessage expectedToolResponse = ToolResponseMessage.builder()
.responses(List.of(new ToolResponse("toolA", "toolA", "Mission accomplished!")))
.build();

ToolExecutionResult toolExecutionResult = toolCallingManager.executeToolCalls(prompt, chatResponse);

Expand All @@ -194,8 +196,9 @@ void whenSingleToolCallWithReturnDirectInChatResponseThenExecute() {
.build())))
.build();

ToolResponseMessage expectedToolResponse = new ToolResponseMessage(
List.of(new ToolResponseMessage.ToolResponse("toolA", "toolA", "Mission accomplished!")));
ToolResponseMessage expectedToolResponse = ToolResponseMessage.builder()
.responses(List.of(new ToolResponse("toolA", "toolA", "Mission accomplished!")))
.build();

ToolExecutionResult toolExecutionResult = toolCallingManager.executeToolCalls(prompt, chatResponse);

Expand Down Expand Up @@ -223,9 +226,10 @@ void whenMultipleToolCallsInChatResponseThenExecute() {
.build())))
.build();

ToolResponseMessage expectedToolResponse = new ToolResponseMessage(
List.of(new ToolResponseMessage.ToolResponse("toolA", "toolA", "Mission accomplished!"),
new ToolResponseMessage.ToolResponse("toolB", "toolB", "Mission accomplished!")));
ToolResponseMessage expectedToolResponse = ToolResponseMessage.builder()
.responses(List.of(new ToolResponse("toolA", "toolA", "Mission accomplished!"),
new ToolResponse("toolB", "toolB", "Mission accomplished!")))
.build();

ToolExecutionResult toolExecutionResult = toolCallingManager.executeToolCalls(prompt, chatResponse);

Expand All @@ -249,8 +253,9 @@ void whenDuplicateMixedToolCallsInChatResponseThenExecute() {
.build())))
.build();

ToolResponseMessage expectedToolResponse = new ToolResponseMessage(
List.of(new ToolResponseMessage.ToolResponse("toolA", "toolA", "Mission accomplished!")));
ToolResponseMessage expectedToolResponse = ToolResponseMessage.builder()
.responses(List.of(new ToolResponse("toolA", "toolA", "Mission accomplished!")))
.build();

ToolExecutionResult toolExecutionResult = toolCallingManager.executeToolCalls(prompt, chatResponse);

Expand All @@ -277,9 +282,10 @@ void whenMultipleToolCallsWithReturnDirectInChatResponseThenExecute() {
.build())))
.build();

ToolResponseMessage expectedToolResponse = new ToolResponseMessage(
List.of(new ToolResponseMessage.ToolResponse("toolA", "toolA", "Mission accomplished!"),
new ToolResponseMessage.ToolResponse("toolB", "toolB", "Mission accomplished!")));
ToolResponseMessage expectedToolResponse = ToolResponseMessage.builder()
.responses(List.of(new ToolResponse("toolA", "toolA", "Mission accomplished!"),
new ToolResponse("toolB", "toolB", "Mission accomplished!")))
.build();

ToolExecutionResult toolExecutionResult = toolCallingManager.executeToolCalls(prompt, chatResponse);

Expand Down Expand Up @@ -307,9 +313,10 @@ void whenMultipleToolCallsWithMixedReturnDirectInChatResponseThenExecute() {
.build())))
.build();

ToolResponseMessage expectedToolResponse = new ToolResponseMessage(
List.of(new ToolResponseMessage.ToolResponse("toolA", "toolA", "Mission accomplished!"),
new ToolResponseMessage.ToolResponse("toolB", "toolB", "Mission accomplished!")));
ToolResponseMessage expectedToolResponse = ToolResponseMessage.builder()
.responses(List.of(new ToolResponse("toolA", "toolA", "Mission accomplished!"),
new ToolResponse("toolB", "toolB", "Mission accomplished!")))
.build();

ToolExecutionResult toolExecutionResult = toolCallingManager.executeToolCalls(prompt, chatResponse);

Expand All @@ -334,8 +341,9 @@ void whenToolCallWithExceptionThenReturnError() {
.build())))
.build();

ToolResponseMessage expectedToolResponse = new ToolResponseMessage(
List.of(new ToolResponseMessage.ToolResponse("toolC", "toolC", "You failed this city!")));
ToolResponseMessage expectedToolResponse = ToolResponseMessage.builder()
.responses(List.of(new ToolResponse("toolC", "toolC", "You failed this city!")))
.build();

ToolExecutionResult toolExecutionResult = toolCallingManager.executeToolCalls(prompt, chatResponse);

Expand Down Expand Up @@ -378,10 +386,10 @@ void whenMixedMethodToolCallsInChatResponseThenExecute() throws NoSuchMethodExce
.build())))
.build();

ToolResponseMessage expectedToolResponse = new ToolResponseMessage(
List.of(new ToolResponseMessage.ToolResponse("toolA", "toolA", TestGenericClass.CALL_RESULT_JSON),
new ToolResponseMessage.ToolResponse("toolB", "toolB",
TestGenericClass.CALL_WITH_TOOL_CONTEXT_RESULT_JSON)));
ToolResponseMessage expectedToolResponse = ToolResponseMessage.builder()
.responses(List.of(new ToolResponse("toolA", "toolA", TestGenericClass.CALL_RESULT_JSON),
new ToolResponse("toolB", "toolB", TestGenericClass.CALL_WITH_TOOL_CONTEXT_RESULT_JSON)))
.build();

ToolExecutionResult toolExecutionResult = toolCallingManager.executeToolCalls(prompt, chatResponse);

Expand Down
Loading