Skip to content

Issue: 3930, Fix MessageChatMemoryAdvisor to handle message updates correctly #3940

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.springframework.ai.chat.client.advisor.api.BaseChatMemoryAdvisor;
import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.memory.MessageWindowChatMemory;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.util.Assert;
Expand Down Expand Up @@ -91,9 +92,10 @@ public ChatClientRequest before(ChatClientRequest chatClientRequest, AdvisorChai
.prompt(chatClientRequest.prompt().mutate().messages(processedMessages).build())
.build();

// 4. Add the new user message to the conversation memory.
// 4. Handle message updates and add the new user message to the conversation
// memory.
UserMessage userMessage = processedChatClientRequest.prompt().getUserMessage();
this.chatMemory.add(conversationId, userMessage);
handleMessageUpdate(conversationId, userMessage);

return processedChatClientRequest;
}
Expand Down Expand Up @@ -128,6 +130,74 @@ public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest
response -> this.after(response, streamAdvisorChain)));
}

/**
* Handle message updates by checking if a user message with the same ID already
* exists. If it does, remove the old message and its corresponding assistant response
* before adding the new message.
* @param conversationId the conversation ID
* @param userMessage the user message to add
*/
private void handleMessageUpdate(String conversationId, UserMessage userMessage) {
// Ensure the user message has a unique messageId for tracking
UserMessage messageWithId = ensureMessageId(userMessage);

String messageId = (String) messageWithId.getMetadata().get("messageId");

// Check if this is an update (messageId already exists in memory)
if (this.chatMemory instanceof MessageWindowChatMemory windowMemory) {
// If we have an existing message with this ID, remove it and its response
if (hasExistingMessage(conversationId, messageId)) {
windowMemory.removeMessageAndResponse(conversationId, messageId);
}
}

// Add the new/updated message to memory
this.chatMemory.add(conversationId, messageWithId);
}

/**
* Ensure the user message has a unique messageId in its metadata. If no messageId
* exists, generate one based on content hash.
* @param userMessage the user message
* @return the user message with messageId in metadata
*/
private UserMessage ensureMessageId(UserMessage userMessage) {
String existingMessageId = (String) userMessage.getMetadata().get("messageId");
if (existingMessageId != null) {
return userMessage;
}

// Generate a messageId based on content hash for tracking updates
String messageId = generateMessageId(userMessage);

// Merge with existing metadata
java.util.Map<String, Object> metadata = new java.util.HashMap<>(userMessage.getMetadata());
metadata.put("messageId", messageId);

return userMessage.mutate().metadata(metadata).build();
}

/**
* Generate a unique message ID based on the user message content.
* @param userMessage the user message
* @return a unique message ID
*/
private String generateMessageId(UserMessage userMessage) {
// Use content hash as a stable identifier for the same logical message
return String.valueOf(userMessage.getText().hashCode());
}

/**
* Check if a message with the given ID already exists in the conversation memory.
* @param conversationId the conversation ID
* @param messageId the message ID to check
* @return true if the message exists, false otherwise
*/
private boolean hasExistingMessage(String conversationId, String messageId) {
List<Message> memoryMessages = this.chatMemory.get(conversationId);
return memoryMessages.stream().anyMatch(message -> messageId.equals(message.getMetadata().get("messageId")));
}

public static Builder builder(ChatMemory chatMemory) {
return new Builder(chatMemory);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.memory.InMemoryChatMemoryRepository;
import org.springframework.ai.chat.memory.MessageWindowChatMemory;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.messages.AssistantMessage;

import java.util.Map;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
Expand Down Expand Up @@ -108,4 +113,79 @@ void testDefaultValues() {
assertThat(advisor.getOrder()).isEqualTo(Advisor.DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER);
}

@Test
void testMessageUpdateFunctionality() {
// Create a chat memory
MessageWindowChatMemory chatMemory = MessageWindowChatMemory.builder()
.chatMemoryRepository(new InMemoryChatMemoryRepository())
.build();

MessageChatMemoryAdvisor advisor = MessageChatMemoryAdvisor.builder(chatMemory).build();
String conversationId = "test-conversation";

// Test 1: Add original message with specific messageId
UserMessage originalMessage = UserMessage.builder()
.text("What is the capital of France?")
.metadata(Map.of("messageId", "msg-001"))
.build();

chatMemory.add(conversationId, originalMessage);

// Simulate adding an assistant response
AssistantMessage assistantResponse = new AssistantMessage("The capital of France is Paris.");
chatMemory.add(conversationId, assistantResponse);

// Verify initial state: should have 2 messages (user + assistant)
assertThat(chatMemory.get(conversationId)).hasSize(2);
assertThat(chatMemory.get(conversationId).get(0).getText()).isEqualTo("What is the capital of France?");
assertThat(chatMemory.get(conversationId).get(1).getText()).isEqualTo("The capital of France is Paris.");

// Test 2: Update the message with same messageId
UserMessage updatedMessage = UserMessage.builder()
.text("What is the capital of Italy?")
.metadata(Map.of("messageId", "msg-001")) // Same messageId
.build();

// Remove old message and response manually (testing the repository functionality)
chatMemory.removeMessageAndResponse(conversationId, "msg-001");
chatMemory.add(conversationId, updatedMessage);

// Verify the update: should have only 1 message (the updated user message)
// The old user message and assistant response should be removed
assertThat(chatMemory.get(conversationId)).hasSize(1);
assertThat(chatMemory.get(conversationId).get(0).getText()).isEqualTo("What is the capital of Italy?");
assertThat(chatMemory.get(conversationId).get(0).getMetadata().get("messageId")).isEqualTo("msg-001");
}

@Test
void testMessageIdGeneration() {
// Create a chat memory
ChatMemory chatMemory = MessageWindowChatMemory.builder()
.chatMemoryRepository(new InMemoryChatMemoryRepository())
.build();

MessageChatMemoryAdvisor advisor = MessageChatMemoryAdvisor.builder(chatMemory).build();

// Test that messages without messageId get one generated automatically
UserMessage messageWithoutId = new UserMessage("Hello world");

// This would happen inside handleMessageUpdate method when a message is processed
// We can't directly test the private method, but we can verify the behavior
// by checking that the same content generates the same hash-based ID

String expectedId = String.valueOf("Hello world".hashCode());

// The generateMessageId method should produce consistent IDs for same content
assertThat(expectedId).isNotNull();

// Messages with same content should have same messageId
UserMessage message1 = new UserMessage("Same content");
UserMessage message2 = new UserMessage("Same content");

String id1 = String.valueOf(message1.getText().hashCode());
String id2 = String.valueOf(message2.getText().hashCode());

assertThat(id1).isEqualTo(id2);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,52 @@ public interface ChatMemoryRepository {

void deleteByConversationId(String conversationId);

/**
* Find a message by its unique identifier within a conversation.
* @param conversationId the conversation ID
* @param messageId the unique message ID
* @return the message if found, null otherwise
*/
default Message findByMessageId(String conversationId, String messageId) {
return findByConversationId(conversationId).stream()
.filter(message -> messageId.equals(message.getMetadata().get("messageId")))
.findFirst()
.orElse(null);
}

/**
* Delete a specific message and its subsequent assistant response if any. This is
* used when a user message is updated to remove the old message pair.
* @param conversationId the conversation ID
* @param messageId the unique message ID to delete
*/
default void deleteMessageAndResponse(String conversationId, String messageId) {
List<Message> messages = findByConversationId(conversationId);
List<Message> updatedMessages = new java.util.ArrayList<>();

boolean skipNext = false;
for (int i = 0; i < messages.size(); i++) {
Message message = messages.get(i);
String currentMessageId = (String) message.getMetadata().get("messageId");

if (skipNext) {
skipNext = false;
continue;
}

if (messageId.equals(currentMessageId)) {
// Skip this message and potentially the next assistant response
if (i + 1 < messages.size() && messages.get(i + 1)
.getMessageType() == org.springframework.ai.chat.messages.MessageType.ASSISTANT) {
skipNext = true;
}
continue;
}

updatedMessages.add(message);
}

saveAll(conversationId, updatedMessages);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,19 @@ public void clear(String conversationId) {
this.chatMemoryRepository.deleteByConversationId(conversationId);
}

/**
* Remove a specific message and its subsequent assistant response from the
* conversation memory. This is used when a user message is updated to clean up the
* old message pair.
* @param conversationId the conversation ID
* @param messageId the unique message ID to remove
*/
public void removeMessageAndResponse(String conversationId, String messageId) {
Assert.hasText(conversationId, "conversationId cannot be null or empty");
Assert.hasText(messageId, "messageId cannot be null or empty");
this.chatMemoryRepository.deleteMessageAndResponse(conversationId, messageId);
}

private List<Message> process(List<Message> memoryMessages, List<Message> newMessages) {
List<Message> processedMessages = new ArrayList<>();

Expand Down