Skip to content

Commit 880b366

Browse files
committed
feat: Add Pagination for requesting list of prompts
Adds the Pagination feature to the `prompts/list` feature as described in the specification. To make this possible mainly two changes are made: 1. The logic for cursor handling is added. 2. Handling for invalid parameters (MCP error code `-32602 (Invalid params)`) is added to the `McpServerSession`. For now the cursor is the base64 encoded start index of the next page. The page size is set to 10. When parameters are found to be invalid the newly introduced `McpParamsValidationError` is returned to handle it properly in the `McpServerSession`.
1 parent f348a83 commit 880b366

File tree

4 files changed

+185
-9
lines changed

4 files changed

+185
-9
lines changed

Diff for: mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java

+102
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import java.util.concurrent.atomic.AtomicReference;
1212
import java.util.function.Function;
1313
import java.util.stream.Collectors;
14+
import java.util.stream.Stream;
1415

1516
import com.fasterxml.jackson.databind.ObjectMapper;
1617
import io.modelcontextprotocol.client.McpClient;
@@ -20,6 +21,7 @@
2021
import io.modelcontextprotocol.server.McpServerFeatures;
2122
import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider;
2223
import io.modelcontextprotocol.spec.McpError;
24+
import io.modelcontextprotocol.spec.McpParamsValidationError;
2325
import io.modelcontextprotocol.spec.McpSchema;
2426
import io.modelcontextprotocol.spec.McpSchema.CallToolResult;
2527
import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities;
@@ -34,7 +36,10 @@
3436
import org.junit.jupiter.api.AfterEach;
3537
import org.junit.jupiter.api.BeforeEach;
3638
import org.junit.jupiter.params.ParameterizedTest;
39+
import org.junit.jupiter.params.provider.Arguments;
40+
import org.junit.jupiter.params.provider.MethodSource;
3741
import org.junit.jupiter.params.provider.ValueSource;
42+
import reactor.core.CoreSubscriber;
3843
import reactor.core.publisher.Mono;
3944
import reactor.netty.DisposableServer;
4045
import reactor.netty.http.server.HttpServer;
@@ -46,7 +51,9 @@
4651
import org.springframework.web.reactive.function.client.WebClient;
4752
import org.springframework.web.reactive.function.server.RouterFunctions;
4853

54+
import static io.modelcontextprotocol.spec.McpSchema.ErrorCodes.INVALID_PARAMS;
4955
import static org.assertj.core.api.Assertions.assertThat;
56+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
5057
import static org.awaitility.Awaitility.await;
5158
import static org.mockito.Mockito.mock;
5259

@@ -620,4 +627,99 @@ void testLoggingNotification(String clientType) {
620627
mcpServer.close();
621628
}
622629

630+
// ---------------------------------------
631+
// Prompt List Tests
632+
// ---------------------------------------
633+
634+
static Stream<Arguments> providePaginationTestParams() {
635+
return Stream.of(Arguments.of("httpclient", 0), Arguments.of("httpclient", 1), Arguments.of("httpclient", 21),
636+
Arguments.of("webflux", 0), Arguments.of("webflux", 1), Arguments.of("webflux", 21));
637+
}
638+
639+
@ParameterizedTest(name = "{0} ({1}) : {displayName} ")
640+
@MethodSource("providePaginationTestParams")
641+
void testListPromptSuccess(String clientType, int availablePrompts) {
642+
643+
var clientBuilder = clientBuilders.get(clientType);
644+
645+
// Setup list of prompts
646+
List<McpServerFeatures.SyncPromptSpecification> prompts = new ArrayList<>();
647+
648+
for (int i = 0; i < availablePrompts; i++) {
649+
McpSchema.Prompt mockPrompt = new McpSchema.Prompt("test-prompt-" + i, "Test Prompt Description",
650+
List.of(new McpSchema.PromptArgument("arg1", "Test argument", true)));
651+
652+
var promptSpec = new McpServerFeatures.SyncPromptSpecification(mockPrompt, null);
653+
654+
prompts.add(promptSpec);
655+
}
656+
657+
var mcpServer = McpServer.sync(mcpServerTransportProvider)
658+
.capabilities(ServerCapabilities.builder().prompts(true).build())
659+
.prompts(prompts)
660+
.build();
661+
662+
try (var mcpClient = clientBuilder.build()) {
663+
664+
InitializeResult initResult = mcpClient.initialize();
665+
assertThat(initResult).isNotNull();
666+
667+
// Iterate through list
668+
var returnedPromptsSum = 0;
669+
670+
var hasEntries = true;
671+
String nextCursor = null;
672+
673+
while (hasEntries) {
674+
var res = mcpClient.listPrompts(nextCursor);
675+
returnedPromptsSum += res.prompts().size();
676+
677+
nextCursor = res.nextCursor();
678+
679+
if (nextCursor == null) {
680+
hasEntries = false;
681+
}
682+
}
683+
684+
assertThat(returnedPromptsSum).isEqualTo(availablePrompts);
685+
686+
}
687+
688+
mcpServer.close();
689+
}
690+
691+
@ParameterizedTest(name = "{0} : {displayName} ")
692+
@ValueSource(strings = { "httpclient", "webflux" })
693+
void testListPromptInvalidCursor(String clientType) {
694+
695+
var clientBuilder = clientBuilders.get(clientType);
696+
697+
McpSchema.Prompt mockPrompt = new McpSchema.Prompt("test-prompt", "Test Prompt Description",
698+
List.of(new McpSchema.PromptArgument("arg1", "Test argument", true)));
699+
700+
var promptSpec = new McpServerFeatures.SyncPromptSpecification(mockPrompt, null);
701+
702+
var mcpServer = McpServer.sync(mcpServerTransportProvider)
703+
.capabilities(ServerCapabilities.builder().prompts(true).build())
704+
.prompts(promptSpec)
705+
.build();
706+
707+
try (var mcpClient = clientBuilder.build()) {
708+
709+
InitializeResult initResult = mcpClient.initialize();
710+
assertThat(initResult).isNotNull();
711+
712+
assertThatThrownBy(() -> mcpClient.listPrompts("INVALID")).isInstanceOf(McpError.class)
713+
.hasMessage("Invalid cursor")
714+
.satisfies(exception -> {
715+
var error = (McpError) exception;
716+
assertThat(error.getJsonRpcError().code()).isEqualTo(INVALID_PARAMS);
717+
assertThat(error.getJsonRpcError().message()).isEqualTo("Invalid cursor");
718+
});
719+
720+
}
721+
722+
mcpServer.close();
723+
}
724+
623725
}

Diff for: mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java

+56-5
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import java.util.Map;
1010
import java.util.Optional;
1111
import java.util.UUID;
12+
import java.util.Base64;
1213
import java.util.concurrent.ConcurrentHashMap;
1314
import java.util.concurrent.CopyOnWriteArrayList;
1415
import java.util.function.BiFunction;
@@ -17,6 +18,7 @@
1718
import com.fasterxml.jackson.databind.ObjectMapper;
1819
import io.modelcontextprotocol.spec.McpClientSession;
1920
import io.modelcontextprotocol.spec.McpError;
21+
import io.modelcontextprotocol.spec.McpParamsValidationError;
2022
import io.modelcontextprotocol.spec.McpSchema;
2123
import io.modelcontextprotocol.spec.McpSchema.CallToolResult;
2224
import io.modelcontextprotocol.spec.McpSchema.LoggingLevel;
@@ -264,6 +266,8 @@ private static class AsyncServerImpl extends McpAsyncServer {
264266

265267
private final ConcurrentHashMap<String, McpServerFeatures.AsyncPromptSpecification> prompts = new ConcurrentHashMap<>();
266268

269+
private static final int PAGE_SIZE = 10;
270+
267271
// FIXME: this field is deprecated and should be remvoed together with the
268272
// broadcasting loggingNotification.
269273
private LoggingLevel minLoggingLevel = LoggingLevel.DEBUG;
@@ -637,20 +641,67 @@ public Mono<Void> notifyPromptsListChanged() {
637641

638642
private McpServerSession.RequestHandler<McpSchema.ListPromptsResult> promptsListRequestHandler() {
639643
return (exchange, params) -> {
640-
// TODO: Implement pagination
641-
// McpSchema.PaginatedRequest request = objectMapper.convertValue(params,
642-
// new TypeReference<McpSchema.PaginatedRequest>() {
643-
// });
644+
McpSchema.PaginatedRequest request = objectMapper.convertValue(params,
645+
new TypeReference<McpSchema.PaginatedRequest>() {
646+
});
647+
648+
if (!isCursorValid(request.cursor(), this.prompts.size())) {
649+
return Mono.error(new McpParamsValidationError("Invalid cursor"));
650+
}
651+
652+
int requestedStartIndex = 0;
653+
654+
if (request.cursor() != null) {
655+
requestedStartIndex = decodeCursor(request.cursor());
656+
}
657+
658+
int endIndex = Math.min(requestedStartIndex + PAGE_SIZE, this.prompts.size());
644659

645660
var promptList = this.prompts.values()
646661
.stream()
662+
.skip(requestedStartIndex)
663+
.limit(endIndex - requestedStartIndex)
647664
.map(McpServerFeatures.AsyncPromptSpecification::prompt)
648665
.toList();
649666

650-
return Mono.just(new McpSchema.ListPromptsResult(promptList, null));
667+
String nextCursor = null;
668+
669+
if (endIndex < this.prompts.size()) {
670+
nextCursor = encodeCursor(endIndex);
671+
}
672+
673+
return Mono.just(new McpSchema.ListPromptsResult(promptList, nextCursor));
651674
};
652675
}
653676

677+
private boolean isCursorValid(String cursor, int maxPageSize) {
678+
if (cursor == null) {
679+
return true;
680+
}
681+
682+
try {
683+
var decoded = decodeCursor(cursor);
684+
685+
if (decoded < 0 || decoded > maxPageSize) {
686+
return false;
687+
}
688+
689+
return true;
690+
}
691+
catch (NumberFormatException e) {
692+
return false;
693+
}
694+
}
695+
696+
private String encodeCursor(int index) {
697+
return Base64.getEncoder().encodeToString(String.valueOf(index).getBytes());
698+
}
699+
700+
private int decodeCursor(String cursor) {
701+
String decoded = new String(Base64.getDecoder().decode(cursor));
702+
return Integer.parseInt(decoded);
703+
}
704+
654705
private McpServerSession.RequestHandler<McpSchema.GetPromptResult> promptsGetRequestHandler() {
655706
return (exchange, params) -> {
656707
McpSchema.GetPromptRequest promptRequest = objectMapper.convertValue(params,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
package io.modelcontextprotocol.spec;
2+
3+
public class McpParamsValidationError extends McpError {
4+
5+
public McpParamsValidationError(McpSchema.JSONRPCResponse.JSONRPCError jsonRpcError) {
6+
super(jsonRpcError.message());
7+
}
8+
9+
public McpParamsValidationError(Object error) {
10+
super(error.toString());
11+
}
12+
13+
}

Diff for: mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java

+14-4
Original file line numberDiff line numberDiff line change
@@ -221,10 +221,20 @@ private Mono<McpSchema.JSONRPCResponse> handleIncomingRequest(McpSchema.JSONRPCR
221221
}
222222
return resultMono
223223
.map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null))
224-
.onErrorResume(error -> Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(),
225-
null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR,
226-
error.getMessage(), null)))); // TODO: add error message
227-
// through the data field
224+
.onErrorResume(error -> {
225+
226+
var errorCode = McpSchema.ErrorCodes.INTERNAL_ERROR;
227+
228+
if (error instanceof McpParamsValidationError) {
229+
errorCode = McpSchema.ErrorCodes.INVALID_PARAMS;
230+
}
231+
232+
// TODO: add error message through the data field
233+
var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null,
234+
new McpSchema.JSONRPCResponse.JSONRPCError(errorCode, error.getMessage(), null));
235+
236+
return Mono.just(errorResponse);
237+
});
228238
});
229239
}
230240

0 commit comments

Comments
 (0)