|
11 | 11 | import java.util.concurrent.atomic.AtomicReference;
|
12 | 12 | import java.util.function.Function;
|
13 | 13 | import java.util.stream.Collectors;
|
| 14 | +import java.util.stream.Stream; |
14 | 15 |
|
15 | 16 | import com.fasterxml.jackson.databind.ObjectMapper;
|
16 | 17 | import io.modelcontextprotocol.client.McpClient;
|
|
20 | 21 | import io.modelcontextprotocol.server.McpServerFeatures;
|
21 | 22 | import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider;
|
22 | 23 | import io.modelcontextprotocol.spec.McpError;
|
| 24 | +import io.modelcontextprotocol.spec.McpParamsValidationError; |
23 | 25 | import io.modelcontextprotocol.spec.McpSchema;
|
24 | 26 | import io.modelcontextprotocol.spec.McpSchema.CallToolResult;
|
25 | 27 | import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities;
|
|
34 | 36 | import org.junit.jupiter.api.AfterEach;
|
35 | 37 | import org.junit.jupiter.api.BeforeEach;
|
36 | 38 | import org.junit.jupiter.params.ParameterizedTest;
|
| 39 | +import org.junit.jupiter.params.provider.Arguments; |
| 40 | +import org.junit.jupiter.params.provider.MethodSource; |
37 | 41 | import org.junit.jupiter.params.provider.ValueSource;
|
| 42 | +import reactor.core.CoreSubscriber; |
38 | 43 | import reactor.core.publisher.Mono;
|
39 | 44 | import reactor.netty.DisposableServer;
|
40 | 45 | import reactor.netty.http.server.HttpServer;
|
|
46 | 51 | import org.springframework.web.reactive.function.client.WebClient;
|
47 | 52 | import org.springframework.web.reactive.function.server.RouterFunctions;
|
48 | 53 |
|
| 54 | +import static io.modelcontextprotocol.spec.McpSchema.ErrorCodes.INVALID_PARAMS; |
49 | 55 | import static org.assertj.core.api.Assertions.assertThat;
|
| 56 | +import static org.assertj.core.api.Assertions.assertThatThrownBy; |
50 | 57 | import static org.awaitility.Awaitility.await;
|
51 | 58 | import static org.mockito.Mockito.mock;
|
52 | 59 |
|
@@ -620,4 +627,99 @@ void testLoggingNotification(String clientType) {
|
620 | 627 | mcpServer.close();
|
621 | 628 | }
|
622 | 629 |
|
| 630 | + // --------------------------------------- |
| 631 | + // Prompt List Tests |
| 632 | + // --------------------------------------- |
| 633 | + |
| 634 | + static Stream<Arguments> providePaginationTestParams() { |
| 635 | + return Stream.of(Arguments.of("httpclient", 0), Arguments.of("httpclient", 1), Arguments.of("httpclient", 21), |
| 636 | + Arguments.of("webflux", 0), Arguments.of("webflux", 1), Arguments.of("webflux", 21)); |
| 637 | + } |
| 638 | + |
| 639 | + @ParameterizedTest(name = "{0} ({1}) : {displayName} ") |
| 640 | + @MethodSource("providePaginationTestParams") |
| 641 | + void testListPromptSuccess(String clientType, int availablePrompts) { |
| 642 | + |
| 643 | + var clientBuilder = clientBuilders.get(clientType); |
| 644 | + |
| 645 | + // Setup list of prompts |
| 646 | + List<McpServerFeatures.SyncPromptSpecification> prompts = new ArrayList<>(); |
| 647 | + |
| 648 | + for (int i = 0; i < availablePrompts; i++) { |
| 649 | + McpSchema.Prompt mockPrompt = new McpSchema.Prompt("test-prompt-" + i, "Test Prompt Description", |
| 650 | + List.of(new McpSchema.PromptArgument("arg1", "Test argument", true))); |
| 651 | + |
| 652 | + var promptSpec = new McpServerFeatures.SyncPromptSpecification(mockPrompt, null); |
| 653 | + |
| 654 | + prompts.add(promptSpec); |
| 655 | + } |
| 656 | + |
| 657 | + var mcpServer = McpServer.sync(mcpServerTransportProvider) |
| 658 | + .capabilities(ServerCapabilities.builder().prompts(true).build()) |
| 659 | + .prompts(prompts) |
| 660 | + .build(); |
| 661 | + |
| 662 | + try (var mcpClient = clientBuilder.build()) { |
| 663 | + |
| 664 | + InitializeResult initResult = mcpClient.initialize(); |
| 665 | + assertThat(initResult).isNotNull(); |
| 666 | + |
| 667 | + // Iterate through list |
| 668 | + var returnedPromptsSum = 0; |
| 669 | + |
| 670 | + var hasEntries = true; |
| 671 | + String nextCursor = null; |
| 672 | + |
| 673 | + while (hasEntries) { |
| 674 | + var res = mcpClient.listPrompts(nextCursor); |
| 675 | + returnedPromptsSum += res.prompts().size(); |
| 676 | + |
| 677 | + nextCursor = res.nextCursor(); |
| 678 | + |
| 679 | + if (nextCursor == null) { |
| 680 | + hasEntries = false; |
| 681 | + } |
| 682 | + } |
| 683 | + |
| 684 | + assertThat(returnedPromptsSum).isEqualTo(availablePrompts); |
| 685 | + |
| 686 | + } |
| 687 | + |
| 688 | + mcpServer.close(); |
| 689 | + } |
| 690 | + |
| 691 | + @ParameterizedTest(name = "{0} : {displayName} ") |
| 692 | + @ValueSource(strings = { "httpclient", "webflux" }) |
| 693 | + void testListPromptInvalidCursor(String clientType) { |
| 694 | + |
| 695 | + var clientBuilder = clientBuilders.get(clientType); |
| 696 | + |
| 697 | + McpSchema.Prompt mockPrompt = new McpSchema.Prompt("test-prompt", "Test Prompt Description", |
| 698 | + List.of(new McpSchema.PromptArgument("arg1", "Test argument", true))); |
| 699 | + |
| 700 | + var promptSpec = new McpServerFeatures.SyncPromptSpecification(mockPrompt, null); |
| 701 | + |
| 702 | + var mcpServer = McpServer.sync(mcpServerTransportProvider) |
| 703 | + .capabilities(ServerCapabilities.builder().prompts(true).build()) |
| 704 | + .prompts(promptSpec) |
| 705 | + .build(); |
| 706 | + |
| 707 | + try (var mcpClient = clientBuilder.build()) { |
| 708 | + |
| 709 | + InitializeResult initResult = mcpClient.initialize(); |
| 710 | + assertThat(initResult).isNotNull(); |
| 711 | + |
| 712 | + assertThatThrownBy(() -> mcpClient.listPrompts("INVALID")).isInstanceOf(McpError.class) |
| 713 | + .hasMessage("Invalid cursor") |
| 714 | + .satisfies(exception -> { |
| 715 | + var error = (McpError) exception; |
| 716 | + assertThat(error.getJsonRpcError().code()).isEqualTo(INVALID_PARAMS); |
| 717 | + assertThat(error.getJsonRpcError().message()).isEqualTo("Invalid cursor"); |
| 718 | + }); |
| 719 | + |
| 720 | + } |
| 721 | + |
| 722 | + mcpServer.close(); |
| 723 | + } |
| 724 | + |
623 | 725 | }
|
0 commit comments