diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index 2f85654e8..ccfc74ab6 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -4,6 +4,8 @@ package io.modelcontextprotocol; import java.time.Duration; +import java.util.ArrayList; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -14,6 +16,7 @@ import java.util.function.BiFunction; import java.util.function.Function; import java.util.stream.Collectors; +import java.util.stream.Stream; import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.client.McpClient; @@ -30,6 +33,8 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; import reactor.core.publisher.Mono; import reactor.netty.DisposableServer; @@ -42,8 +47,10 @@ import org.springframework.web.reactive.function.server.RouterFunctions; import reactor.test.StepVerifier; +import static io.modelcontextprotocol.spec.McpSchema.ErrorCodes.INVALID_PARAMS; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.Assertions.assertWith; import static org.awaitility.Awaitility.await; import static org.mockito.Mockito.mock; @@ -851,6 +858,138 @@ void testToolListChangeHandlingSuccess(String clientType) { mcpServer.close(); } + // --------------------------------------- + // Tests for Paginated Tool List Results + // --------------------------------------- + + @ParameterizedTest(name = "{0} ({1}) : {displayName} ") + @MethodSource("providePaginationTestParams") + void testListToolsSuccess(String clientType, int availableElements) { + + var clientBuilder = clientBuilders.get(clientType); + + // Setup list of prompts + List tools = new ArrayList<>(); + + for (int i = 0; i < availableElements; i++) { + var mock = new McpSchema.Tool("test-tool-" + i, "Test Tool Description", emptyJsonSchema); + var spec = new McpServerFeatures.SyncToolSpecification(mock, null); + + tools.add(spec); + } + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tools) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + var returnedElements = new HashSet(); + + var hasEntries = true; + String nextCursor = null; + + while (hasEntries) { + var res = mcpClient.listTools(nextCursor); + + res.tools().forEach(e -> returnedElements.add(e.name())); // store unique + // attribute + + nextCursor = res.nextCursor(); + + if (nextCursor == null) { + hasEntries = false; + } + } + + assertThat(returnedElements.size()).isEqualTo(availableElements); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testListToolsCursorInvalidListChanged(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + // Setup list of prompts + var pageSize = 10; + List tools = new ArrayList<>(); + + for (int i = 0; i <= pageSize; i++) { + var mock = new McpSchema.Tool("test-tool-" + i, "Test Tool Description", emptyJsonSchema); + var spec = new McpServerFeatures.SyncToolSpecification(mock, null); + + tools.add(spec); + } + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tools) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + var res = mcpClient.listTools(null); + + // Change list + var mock = new McpSchema.Tool("test-tool-xyz", "Test Tool Description", emptyJsonSchema); + mcpServer.addTool(new McpServerFeatures.SyncToolSpecification(mock, null)); + + assertThatThrownBy(() -> mcpClient.listTools(res.nextCursor())).isInstanceOf(McpError.class) + .hasMessage("Invalid cursor") + .satisfies(exception -> { + var error = (McpError) exception; + assertThat(error.getJsonRpcError().code()).isEqualTo(INVALID_PARAMS); + assertThat(error.getJsonRpcError().message()).isEqualTo("Invalid cursor"); + }); + + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testListToolsInvalidCursor(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var mock = new McpSchema.Tool("test-tool", "Test Tool Description", emptyJsonSchema); + var spec = new McpServerFeatures.SyncToolSpecification(mock, null); + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(spec) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatThrownBy(() -> mcpClient.listTools("INVALID")).isInstanceOf(McpError.class) + .hasMessage("Invalid cursor") + .satisfies(exception -> { + var error = (McpError) exception; + assertThat(error.getJsonRpcError().code()).isEqualTo(INVALID_PARAMS); + assertThat(error.getJsonRpcError().message()).isEqualTo("Invalid cursor"); + }); + + } + + mcpServer.close(); + } + @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "httpclient", "webflux" }) void testInitialize(String clientType) { @@ -1025,4 +1164,381 @@ void testCompletionShouldReturnExpectedSuggestions(String clientType) { mcpServer.close(); } -} \ No newline at end of file + // --------------------------------------- + // Tests for Paginated Prompt List Results + // --------------------------------------- + + @ParameterizedTest(name = "{0} ({1}) : {displayName} ") + @MethodSource("providePaginationTestParams") + void testListPromptsSuccess(String clientType, int availableElements) { + + var clientBuilder = clientBuilders.get(clientType); + + // Setup list of prompts + List prompts = new ArrayList<>(); + + for (int i = 0; i < availableElements; i++) { + var mock = new McpSchema.Prompt("test-prompt-" + i, "Test Prompt Description", + List.of(new McpSchema.PromptArgument("arg1", "Test argument", true))); + var spec = new McpServerFeatures.SyncPromptSpecification(mock, null); + + prompts.add(spec); + } + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .capabilities(ServerCapabilities.builder().prompts(true).build()) + .prompts(prompts) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + var returnedElements = new HashSet(); + + var hasEntries = true; + String nextCursor = null; + + while (hasEntries) { + var res = mcpClient.listPrompts(nextCursor); + + res.prompts().forEach(e -> returnedElements.add(e.name())); // store + // unique + // attribute + + nextCursor = res.nextCursor(); + + if (nextCursor == null) { + hasEntries = false; + } + } + + assertThat(returnedElements.size()).isEqualTo(availableElements); + + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testListPromptsCursorInvalidListChanged(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + // Setup list of prompts + var pageSize = 10; + List prompts = new ArrayList<>(); + + for (int i = 0; i <= pageSize; i++) { + var mock = new McpSchema.Prompt("test-prompt-" + i, "Test Prompt Description", + List.of(new McpSchema.PromptArgument("arg1", "Test argument", true))); + var spec = new McpServerFeatures.SyncPromptSpecification(mock, null); + + prompts.add(spec); + } + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .capabilities(ServerCapabilities.builder().prompts(true).build()) + .prompts(prompts) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + var res = mcpClient.listPrompts(null); + + // Change list + var mock = new McpSchema.Prompt("test-prompt-xyz", "Test Prompt Description", + List.of(new McpSchema.PromptArgument("arg1", "Test argument", true))); + + mcpServer.addPrompt(new McpServerFeatures.SyncPromptSpecification(mock, null)); + + assertThatThrownBy(() -> mcpClient.listPrompts(res.nextCursor())).isInstanceOf(McpError.class) + .hasMessage("Invalid cursor") + .satisfies(exception -> { + var error = (McpError) exception; + assertThat(error.getJsonRpcError().code()).isEqualTo(INVALID_PARAMS); + assertThat(error.getJsonRpcError().message()).isEqualTo("Invalid cursor"); + }); + + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testListPromptsInvalidCursor(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var mock = new McpSchema.Prompt("test-prompt", "Test Prompt Description", + List.of(new McpSchema.PromptArgument("arg1", "Test argument", true))); + + var spec = new McpServerFeatures.SyncPromptSpecification(mock, null); + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .capabilities(ServerCapabilities.builder().prompts(true).build()) + .prompts(spec) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatThrownBy(() -> mcpClient.listPrompts("INVALID")).isInstanceOf(McpError.class) + .hasMessage("Invalid cursor") + .satisfies(exception -> { + var error = (McpError) exception; + assertThat(error.getJsonRpcError().code()).isEqualTo(INVALID_PARAMS); + assertThat(error.getJsonRpcError().message()).isEqualTo("Invalid cursor"); + }); + + } + + mcpServer.close(); + } + + // --------------------------------------- + // Tests for Paginated Resources List Results + // --------------------------------------- + + @ParameterizedTest(name = "{0} ({1}) : {displayName} ") + @MethodSource("providePaginationTestParams") + void testListResourcesSuccess(String clientType, int availableElements) { + + var clientBuilder = clientBuilders.get(clientType); + + // Setup list of prompts + List resources = new ArrayList<>(); + + for (int i = 0; i < availableElements; i++) { + var mock = new McpSchema.Resource("file://example-" + i + ".txt", "test-resource", + "Test Resource Description", "application/octet-stream", null); + var spec = new McpServerFeatures.SyncResourceSpecification(mock, null); + + resources.add(spec); + } + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .capabilities(ServerCapabilities.builder().resources(true, true).build()) + .resources(resources) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + var returnedElements = new HashSet(); + + var hasEntries = true; + String nextCursor = null; + + while (hasEntries) { + var res = mcpClient.listResources(nextCursor); + + res.resources().forEach(e -> returnedElements.add(e.uri())); // store + // unique + // attribute + + nextCursor = res.nextCursor(); + + if (nextCursor == null) { + hasEntries = false; + } + } + + assertThat(returnedElements.size()).isEqualTo(availableElements); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testListResourcesCursorInvalidListChanged(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + // Setup list of prompts + var pageSize = 10; + List resources = new ArrayList<>(); + + for (int i = 0; i <= pageSize; i++) { + var mock = new McpSchema.Resource("file://example-" + i + ".txt", "test-resource", + "Test Resource Description", "application/octet-stream", null); + var spec = new McpServerFeatures.SyncResourceSpecification(mock, null); + + resources.add(spec); + } + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .capabilities(ServerCapabilities.builder().resources(true, true).build()) + .resources(resources) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + var res = mcpClient.listResources(null); + + // Change list + var mock = new McpSchema.Resource("file://example-xyz.txt", "test-resource", "Test Resource Description", + "application/octet-stream", null); + mcpServer.addResource(new McpServerFeatures.SyncResourceSpecification(mock, null)); + + assertThatThrownBy(() -> mcpClient.listResources(res.nextCursor())).isInstanceOf(McpError.class) + .hasMessage("Invalid cursor") + .satisfies(exception -> { + var error = (McpError) exception; + assertThat(error.getJsonRpcError().code()).isEqualTo(INVALID_PARAMS); + assertThat(error.getJsonRpcError().message()).isEqualTo("Invalid cursor"); + }); + + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testListResourcesInvalidCursor(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var mock = new McpSchema.Resource("file://example.txt", "test-resource", "Test Resource Description", + "application/octet-stream", null); + var spec = new McpServerFeatures.SyncResourceSpecification(mock, null); + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .capabilities(ServerCapabilities.builder().resources(true, true).build()) + .resources(spec) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatThrownBy(() -> mcpClient.listResources("INVALID")).isInstanceOf(McpError.class) + .hasMessage("Invalid cursor") + .satisfies(exception -> { + var error = (McpError) exception; + assertThat(error.getJsonRpcError().code()).isEqualTo(INVALID_PARAMS); + assertThat(error.getJsonRpcError().message()).isEqualTo("Invalid cursor"); + }); + + } + + mcpServer.close(); + } + + // --------------------------------------- + // Tests for Paginated Resource Templates Results + // --------------------------------------- + + @ParameterizedTest(name = "{0} ({1}) : {displayName} ") + @MethodSource("providePaginationTestParams") + void testListResourceTemplatesSuccess(String clientType, int availableElements) { + + var clientBuilder = clientBuilders.get(clientType); + + // Setup list of prompts + List resourceTemplates = new ArrayList<>(); + + for (int i = 0; i < availableElements; i++) { + resourceTemplates.add(new McpSchema.ResourceTemplate("file://{path}-" + i + ".txt", "test-resource", + "Test Resource Description", "application/octet-stream", null)); + } + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .capabilities(ServerCapabilities.builder().resources(true, true).build()) + .resourceTemplates(resourceTemplates) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + var returnedElements = new HashSet(); + + var hasEntries = true; + String nextCursor = null; + + while (hasEntries) { + var res = mcpClient.listResourceTemplates(nextCursor); + + res.resourceTemplates().forEach(e -> returnedElements.add(e.uriTemplate())); // store + // unique + // attribute + + nextCursor = res.nextCursor(); + + if (nextCursor == null) { + hasEntries = false; + } + } + + assertThat(returnedElements.size()).isEqualTo(availableElements); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testListResourceTemplatesInvalidCursor(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var mock = new McpSchema.ResourceTemplate("file://{path}.txt", "test-resource", "Test Resource Description", + "application/octet-stream", null); + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .capabilities(ServerCapabilities.builder().resources(true, true).build()) + .resourceTemplates(mock) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatThrownBy(() -> mcpClient.listResourceTemplates("INVALID")).isInstanceOf(McpError.class) + .hasMessage("Invalid cursor") + .satisfies(exception -> { + var error = (McpError) exception; + assertThat(error.getJsonRpcError().code()).isEqualTo(INVALID_PARAMS); + assertThat(error.getJsonRpcError().message()).isEqualTo("Invalid cursor"); + }); + + } + + mcpServer.close(); + } + + // --------------------------------------- + // Helpers for Tests of Paginated Lists + // --------------------------------------- + + /** + * Helper function for pagination tests. This provides a stream of the following + * parameters: 1. Client type (e.g. httpclient, webflux) 2. Number of available + * elements in the list + * @return a stream of arguments with test parameters + */ + static Stream providePaginationTestParams() { + return Stream.of(Arguments.of("httpclient", 0), Arguments.of("httpclient", 1), Arguments.of("httpclient", 21), + Arguments.of("webflux", 0), Arguments.of("webflux", 1), Arguments.of("webflux", 21)); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index 02ad955b9..3a334d8d7 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -6,6 +6,7 @@ import java.time.Duration; import java.util.ArrayList; +import java.util.Base64; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -19,13 +20,13 @@ import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.spec.McpClientSession; import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpParamsValidationError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate; import io.modelcontextprotocol.spec.McpSchema.SetLevelRequest; -import io.modelcontextprotocol.spec.McpSchema.Tool; import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.spec.McpServerTransportProvider; import io.modelcontextprotocol.util.DeafaultMcpUriTemplateManagerFactory; @@ -100,6 +101,8 @@ public class McpAsyncServer { private final ConcurrentHashMap prompts = new ConcurrentHashMap<>(); + private static final int PAGE_SIZE = 10; + // FIXME: this field is deprecated and should be remvoed together with the // broadcasting loggingNotification. private LoggingLevel minLoggingLevel = LoggingLevel.DEBUG; @@ -340,9 +343,25 @@ public Mono notifyToolsListChanged() { private McpServerSession.RequestHandler toolsListRequestHandler() { return (exchange, params) -> { - List tools = this.tools.stream().map(McpServerFeatures.AsyncToolSpecification::tool).toList(); + McpSchema.PaginatedRequest request = objectMapper.convertValue(params, + new TypeReference() { + }); + + int mapSize = this.tools.size(); + int mapHash = this.tools.hashCode(); + + int requestedStartIndex = handleCursor(request.cursor(), mapSize, mapHash).block(); + int endIndex = Math.min(requestedStartIndex + PAGE_SIZE, mapSize); - return Mono.just(new McpSchema.ListToolsResult(tools, null)); + var nextCursor = getCursor(endIndex, mapSize, mapHash); + + var resultList = this.tools.stream() + .skip(requestedStartIndex) + .limit(endIndex - requestedStartIndex) + .map(McpServerFeatures.AsyncToolSpecification::tool) + .toList(); + + return Mono.just(new McpSchema.ListToolsResult(resultList, nextCursor)); }; } @@ -441,18 +460,49 @@ public Mono notifyResourcesUpdated(McpSchema.ResourcesUpdatedNotification private McpServerSession.RequestHandler resourcesListRequestHandler() { return (exchange, params) -> { - var resourceList = this.resources.values() + McpSchema.PaginatedRequest request = objectMapper.convertValue(params, + new TypeReference() { + }); + + int mapSize = this.resources.size(); + int mapHash = this.resources.hashCode(); + + int requestedStartIndex = handleCursor(request.cursor(), mapSize, mapHash).block(); + int endIndex = Math.min(requestedStartIndex + PAGE_SIZE, mapSize); + + var nextCursor = getCursor(endIndex, mapSize, mapHash); + + var resultList = this.resources.values() .stream() + .skip(requestedStartIndex) + .limit(endIndex - requestedStartIndex) .map(McpServerFeatures.AsyncResourceSpecification::resource) .toList(); - return Mono.just(new McpSchema.ListResourcesResult(resourceList, null)); + + return Mono.just(new McpSchema.ListResourcesResult(resultList, nextCursor)); }; } private McpServerSession.RequestHandler resourceTemplateListRequestHandler() { - return (exchange, params) -> Mono - .just(new McpSchema.ListResourceTemplatesResult(this.getResourceTemplates(), null)); + return (exchange, params) -> { + McpSchema.PaginatedRequest request = objectMapper.convertValue(params, + new TypeReference() { + }); + var all = this.getResourceTemplates(); + + int mapSize = all.size(); + int mapHash = all.hashCode(); + + int requestedStartIndex = handleCursor(request.cursor(), mapSize, mapHash).block(); + int endIndex = Math.min(requestedStartIndex + PAGE_SIZE, mapSize); + + var nextCursor = getCursor(endIndex, mapSize, mapHash); + + var resultList = all.stream().skip(requestedStartIndex).limit(endIndex - requestedStartIndex).toList(); + + return Mono.just(new McpSchema.ListResourceTemplatesResult(resultList, nextCursor)); + }; } private List getResourceTemplates() { @@ -568,17 +618,27 @@ public Mono notifyPromptsListChanged() { private McpServerSession.RequestHandler promptsListRequestHandler() { return (exchange, params) -> { - // TODO: Implement pagination - // McpSchema.PaginatedRequest request = objectMapper.convertValue(params, - // new TypeReference() { - // }); - var promptList = this.prompts.values() + McpSchema.PaginatedRequest request = objectMapper.convertValue(params, + new TypeReference() { + }); + + int mapSize = this.prompts.size(); + int mapHash = this.prompts.hashCode(); + + int requestedStartIndex = handleCursor(request.cursor(), mapSize, mapHash).block(); + int endIndex = Math.min(requestedStartIndex + PAGE_SIZE, mapSize); + + var nextCursor = getCursor(endIndex, mapSize, mapHash); + + var resultList = this.prompts.values() .stream() + .skip(requestedStartIndex) + .limit(endIndex - requestedStartIndex) .map(McpServerFeatures.AsyncPromptSpecification::prompt) .toList(); - return Mono.just(new McpSchema.ListPromptsResult(promptList, null)); + return Mono.just(new McpSchema.ListPromptsResult(resultList, nextCursor)); }; } @@ -747,4 +807,79 @@ void setProtocolVersions(List protocolVersions) { this.protocolVersions = protocolVersions; } + // --------------------------------------- + // Cursor Handling for paginated requests + // --------------------------------------- + + /** + * Handles the cursor by decoding, validating and reading the index of it. + * @param cursor the base64 representation of the cursor. + * @param mapSize the size of the map from which the values should be read. + * @param mapHash the hash of the map to compare the cursor value to. + * @return a {@link Mono} which contains the index to which the cursor points. + */ + private Mono handleCursor(String cursor, int mapSize, int mapHash) { + if (cursor == null) { + return Mono.just(0); + } + + var decodedCursor = decodeCursor(cursor); + + if (!isCursorValid(decodedCursor, mapSize, mapHash)) { + return Mono.error(new McpParamsValidationError("Invalid cursor")); + } + + return Mono.just(getCursorIndex(decodedCursor)); + } + + private String getCursor(int endIndex, int mapSize, int mapHash) { + if (endIndex >= mapSize) { + return null; + } + return encodeCursor(endIndex, mapHash); + } + + private int getCursorIndex(String cursor) { + return Integer.parseInt(cursor.split(":")[0]); + } + + private boolean isCursorValid(String cursor, int maxPageSize, int currentHash) { + var cursorElements = cursor.split(":"); + + if (cursorElements.length != 2) { + logger.debug("Length of elements in cursor doesn't match expected number. Cursor: {} Actual number: {}", + cursor, cursorElements.length); + return false; + } + + int index; + int hash; + + try { + index = Integer.parseInt(cursorElements[0]); + hash = Integer.parseInt(cursorElements[1]); + } + catch (NumberFormatException e) { + logger.debug("Failed to parse cursor elements."); + return false; + } + + if (index < 0 || index > maxPageSize || hash != currentHash) { + logger.debug("Cursor boundaries are invalid."); + return false; + } + + return true; + } + + private String encodeCursor(int index, int hash) { + var cursor = index + ":" + hash; + + return Base64.getEncoder().encodeToString(cursor.getBytes()); + } + + private String decodeCursor(String base64Cursor) { + return new String(Base64.getDecoder().decode(base64Cursor)); + } + } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpParamsValidationError.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpParamsValidationError.java new file mode 100644 index 000000000..e7ecb0058 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpParamsValidationError.java @@ -0,0 +1,9 @@ +package io.modelcontextprotocol.spec; + +public class McpParamsValidationError extends McpError { + + public McpParamsValidationError(String error) { + super(error); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index 86906d859..1f6feaa17 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -226,10 +226,20 @@ private Mono handleIncomingRequest(McpSchema.JSONRPCR } return resultMono .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null)) - .onErrorResume(error -> Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), - null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, - error.getMessage(), null)))); // TODO: add error message - // through the data field + .onErrorResume(error -> { + + var errorCode = McpSchema.ErrorCodes.INTERNAL_ERROR; + + if (error instanceof McpParamsValidationError) { + errorCode = McpSchema.ErrorCodes.INVALID_PARAMS; + } + + // TODO: add error message through the data field + var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, + new McpSchema.JSONRPCResponse.JSONRPCError(errorCode, error.getMessage(), null)); + + return Mono.just(errorResponse); + }); }); }