Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.springframework.core.ParameterizedTypeReference;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.BDDMockito.given;

/**
Expand Down Expand Up @@ -136,8 +137,104 @@ public void customSoCResponseEntityTest() {
assertThat(userMessage.getText()).contains("Tell me about Max");
}

record MyBean(String name, int age) {
@Test
public void whenEmptyResponseContentThenHandleGracefully() {
var chatResponse = new ChatResponse(List.of(new Generation(new AssistantMessage(""))));
given(this.chatModel.call(this.promptCaptor.capture())).willReturn(chatResponse);

assertThatThrownBy(() -> ChatClient.builder(this.chatModel)
.build()
.prompt()
.user("test")
.call()
.responseEntity(MyBean.class)).isInstanceOf(RuntimeException.class);
}

@Test
public void whenInvalidJsonResponseThenThrows() {
var chatResponse = new ChatResponse(List.of(new Generation(new AssistantMessage("invalid json content"))));
given(this.chatModel.call(this.promptCaptor.capture())).willReturn(chatResponse);

assertThatThrownBy(() -> ChatClient.builder(this.chatModel)
.build()
.prompt()
.user("test")
.call()
.responseEntity(MyBean.class)).isInstanceOf(RuntimeException.class);
}

@Test
public void whenParameterizedTypeWithMapThenParseCorrectly() {
var chatResponse = new ChatResponse(List.of(new Generation(new AssistantMessage("""
{
"key1": "value1",
"key2": "value2",
"key3": "value3"
}
"""))));

given(this.chatModel.call(this.promptCaptor.capture())).willReturn(chatResponse);

ResponseEntity<ChatResponse, Map<String, String>> responseEntity = ChatClient.builder(this.chatModel)
.build()
.prompt()
.user("test")
.call()
.responseEntity(new ParameterizedTypeReference<Map<String, String>>() {
});

assertThat(responseEntity.getEntity()).containsEntry("key1", "value1");
assertThat(responseEntity.getEntity()).containsEntry("key2", "value2");
assertThat(responseEntity.getEntity()).containsEntry("key3", "value3");
}

@Test
public void whenEmptyArrayResponseThenReturnEmptyList() {
var chatResponse = new ChatResponse(List.of(new Generation(new AssistantMessage("[]"))));
given(this.chatModel.call(this.promptCaptor.capture())).willReturn(chatResponse);

ResponseEntity<ChatResponse, List<MyBean>> responseEntity = ChatClient.builder(this.chatModel)
.build()
.prompt()
.user("test")
.call()
.responseEntity(new ParameterizedTypeReference<List<MyBean>>() {
});

assertThat(responseEntity.getEntity()).isEmpty();
}

@Test
public void whenBooleanPrimitiveResponseThenParseCorrectly() {
var chatResponse = new ChatResponse(List.of(new Generation(new AssistantMessage("true"))));
given(this.chatModel.call(this.promptCaptor.capture())).willReturn(chatResponse);

ResponseEntity<ChatResponse, Boolean> responseEntity = ChatClient.builder(this.chatModel)
.build()
.prompt()
.user("Is this true?")
.call()
.responseEntity(Boolean.class);

assertThat(responseEntity.getEntity()).isTrue();
}

@Test
public void whenIntegerResponseThenParseCorrectly() {
var chatResponse = new ChatResponse(List.of(new Generation(new AssistantMessage("1"))));
given(this.chatModel.call(this.promptCaptor.capture())).willReturn(chatResponse);

ResponseEntity<ChatResponse, Integer> responseEntity = ChatClient.builder(this.chatModel)
.build()
.prompt()
.user("What is the answer?")
.call()
.responseEntity(Integer.class);

assertThat(responseEntity.getEntity()).isEqualTo(1);
}

record MyBean(String name, int age) {
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,12 @@
import java.util.Map;

import org.junit.jupiter.api.Test;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.Mockito.mock;

/**
* Unit tests for {@link ChatClientResponse}.
Expand Down Expand Up @@ -82,4 +85,102 @@ void whenMutateThenImmutableContext() {
assertThat(response.context()).containsEntry("key", "value");
}

@Test
void whenValidChatResponseThenCreateSuccessfully() {
ChatResponse chatResponse = mock(ChatResponse.class);
Map<String, Object> context = Map.of("key", "value");

ChatClientResponse response = new ChatClientResponse(chatResponse, context);

assertThat(response.chatResponse()).isEqualTo(chatResponse);
assertThat(response.context()).containsExactlyInAnyOrderEntriesOf(context);
}

@Test
void whenBuilderWithValidDataThenCreateSuccessfully() {
ChatResponse chatResponse = mock(ChatResponse.class);
Map<String, Object> context = Map.of("key1", "value1", "key2", 42);

ChatClientResponse response = ChatClientResponse.builder().chatResponse(chatResponse).context(context).build();

assertThat(response.chatResponse()).isEqualTo(chatResponse);
assertThat(response.context()).containsExactlyInAnyOrderEntriesOf(context);
}

@Test
void whenEmptyContextThenCreateSuccessfully() {
ChatResponse chatResponse = mock(ChatResponse.class);
Map<String, Object> emptyContext = Map.of();

ChatClientResponse response = new ChatClientResponse(chatResponse, emptyContext);

assertThat(response.chatResponse()).isEqualTo(chatResponse);
assertThat(response.context()).isEmpty();
}

@Test
void whenContextWithNullValuesThenCreateSuccessfully() {
ChatResponse chatResponse = mock(ChatResponse.class);
Map<String, Object> context = new HashMap<>();
context.put("key1", "value1");
context.put("key2", null);

ChatClientResponse response = new ChatClientResponse(chatResponse, context);

assertThat(response.context()).containsEntry("key1", "value1");
assertThat(response.context()).containsEntry("key2", null);
}

@Test
void whenCopyWithNullChatResponseThenPreserveNull() {
Map<String, Object> context = Map.of("key", "value");
ChatClientResponse response = new ChatClientResponse(null, context);

ChatClientResponse copy = response.copy();

assertThat(copy.chatResponse()).isNull();
assertThat(copy.context()).containsExactlyInAnyOrderEntriesOf(context);
}

@Test
void whenMutateWithNewChatResponseThenUpdate() {
ChatResponse originalResponse = mock(ChatResponse.class);
ChatResponse newResponse = mock(ChatResponse.class);
Map<String, Object> context = Map.of("key", "value");

ChatClientResponse response = new ChatClientResponse(originalResponse, context);
ChatClientResponse mutated = response.mutate().chatResponse(newResponse).build();

assertThat(response.chatResponse()).isEqualTo(originalResponse);
assertThat(mutated.chatResponse()).isEqualTo(newResponse);
assertThat(mutated.context()).containsExactlyInAnyOrderEntriesOf(context);
}

@Test
void whenBuilderWithoutChatResponseThenCreateWithNull() {
Map<String, Object> context = Map.of("key", "value");

ChatClientResponse response = ChatClientResponse.builder().context(context).build();

assertThat(response.chatResponse()).isNull();
}

@Test
void whenComplexObjectsInContextThenPreserveCorrectly() {
ChatResponse chatResponse = mock(ChatResponse.class);
Generation generation = mock(Generation.class);
Map<String, Object> nestedMap = Map.of("nested", "value");

Map<String, Object> context = Map.of("string", "value", "number", 1, "boolean", true, "generation", generation,
"map", nestedMap);

ChatClientResponse response = new ChatClientResponse(chatResponse, context);

assertThat(response.context()).containsEntry("string", "value");
assertThat(response.context()).containsEntry("number", 1);
assertThat(response.context()).containsEntry("boolean", true);
assertThat(response.context()).containsEntry("generation", generation);
assertThat(response.context()).containsEntry("map", nestedMap);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -159,4 +159,94 @@ void whenOverridingUserPromptThenLatestValueIsUsed() {
assertThat(defaultRequest.getUserText()).isEqualTo("second user prompt");
}

@Test
void whenDefaultUserStringSetThenAppliedToRequest() {
var chatModel = mock(ChatModel.class);
var builder = new DefaultChatClientBuilder(chatModel);

builder.defaultUser("test user prompt");

var defaultRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ReflectionTestUtils.getField(builder,
"defaultRequest");
assertThat(defaultRequest.getUserText()).isEqualTo("test user prompt");
}

@Test
void whenDefaultSystemStringSetThenAppliedToRequest() {
var chatModel = mock(ChatModel.class);
var builder = new DefaultChatClientBuilder(chatModel);

builder.defaultSystem("test system prompt");

var defaultRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ReflectionTestUtils.getField(builder,
"defaultRequest");
assertThat(defaultRequest.getSystemText()).isEqualTo("test system prompt");
}

@Test
void whenBuilderMethodChainingThenAllSettingsApplied() {
var chatModel = mock(ChatModel.class);

var builder = new DefaultChatClientBuilder(chatModel).defaultSystem("system prompt").defaultUser("user prompt");

var defaultRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ReflectionTestUtils.getField(builder,
"defaultRequest");

assertThat(defaultRequest.getSystemText()).isEqualTo("system prompt");
assertThat(defaultRequest.getUserText()).isEqualTo("user prompt");
}

@Test
void whenCloneWithAllSettingsThenAllAreCopied() {
var chatModel = mock(ChatModel.class);

var originalBuilder = new DefaultChatClientBuilder(chatModel).defaultSystem("system prompt")
.defaultUser("user prompt");

var clonedBuilder = (DefaultChatClientBuilder) originalBuilder.clone();
var clonedRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ReflectionTestUtils.getField(clonedBuilder,
"defaultRequest");

assertThat(clonedRequest.getSystemText()).isEqualTo("system prompt");
assertThat(clonedRequest.getUserText()).isEqualTo("user prompt");
}

@Test
void whenBuilderUsedMultipleTimesThenProducesDifferentInstances() {
var chatModel = mock(ChatModel.class);
var builder = new DefaultChatClientBuilder(chatModel);

var client1 = builder.build();
var client2 = builder.build();

assertThat(client1).isNotSameAs(client2);
assertThat(client1).isInstanceOf(DefaultChatClient.class);
assertThat(client2).isInstanceOf(DefaultChatClient.class);
}

@Test
void whenDefaultUserWithTemplateVariablesThenProcessed() {
var chatModel = mock(ChatModel.class);
var builder = new DefaultChatClientBuilder(chatModel);

builder.defaultUser("Hello {name}, welcome to {service}!");

var defaultRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ReflectionTestUtils.getField(builder,
"defaultRequest");
assertThat(defaultRequest.getUserText()).isEqualTo("Hello {name}, welcome to {service}!");
}

@Test
void whenMultipleSystemSettingsThenLastOneWins() {
var chatModel = mock(ChatModel.class);
var builder = new DefaultChatClientBuilder(chatModel);

builder.defaultSystem("first system message");
builder.defaultSystem("final system message");

var defaultRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ReflectionTestUtils.getField(builder,
"defaultRequest");
assertThat(defaultRequest.getSystemText()).isEqualTo("final system message");
}

}