From 25d521901577713dd872de0c35db4ef891433796 Mon Sep 17 00:00:00 2001 From: "a.darafeyeu" Date: Thu, 10 Apr 2025 16:58:51 +0200 Subject: [PATCH] feat(ws): adds WebSocketClientTransport --- .../transport/WebSocketClientTransport.java | 221 ++++++++++++++++++ .../WebSocketClientTransportTest.java | 124 ++++++++++ mcp/src/test/resources/ws/Dockerfile | 17 ++ mcp/src/test/resources/ws/server.js | 21 ++ 4 files changed, 383 insertions(+) create mode 100644 mcp/src/main/java/io/modelcontextprotocol/client/transport/WebSocketClientTransport.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/client/transport/WebSocketClientTransportTest.java create mode 100644 mcp/src/test/resources/ws/Dockerfile create mode 100644 mcp/src/test/resources/ws/server.js diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/WebSocketClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/WebSocketClientTransport.java new file mode 100644 index 00000000..721ff8fa --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/WebSocketClientTransport.java @@ -0,0 +1,221 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport; + +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.WebSocket; +import java.time.Duration; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Function; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.util.Assert; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.util.retry.Retry; + +/** + * The WebSocket (WS) implementation of the + * {@link io.modelcontextprotocol.spec.McpTransport} that follows the MCP HTTP with WS + * transport specification, using Java's HttpClient. + * + * @author Aliaksei Darafeyeu + */ +public class WebSocketClientTransport implements McpClientTransport { + + private static final Logger LOGGER = LoggerFactory.getLogger(WebSocketClientTransport.class); + + private final HttpClient httpClient; + + private final ObjectMapper objectMapper; + + private final URI uri; + + private final AtomicReference webSocketRef = new AtomicReference<>(); + + private final AtomicReference state = new AtomicReference<>(TransportState.DISCONNECTED); + + private final Sinks.Many errorSink = Sinks.many().multicast().onBackpressureBuffer(); + + /** + * The constructor for the WebSocketClientTransport. + * @param uri the URI to connect to + * @param clientBuilder the HttpClient builder + * @param objectMapper the ObjectMapper for JSON serialization/deserialization + */ + WebSocketClientTransport(final URI uri, final HttpClient.Builder clientBuilder, final ObjectMapper objectMapper) { + this.uri = uri; + this.httpClient = clientBuilder.build(); + this.objectMapper = objectMapper; + } + + /** + * Creates a new WebSocketClientTransport instance with the specified URI. + * @param uri the URI to connect to + * @return a new Builder instance + */ + public static Builder builder(final URI uri) { + return new Builder().uri(uri); + } + + /** + * The state of the Transport connection. + */ + public enum TransportState { + + DISCONNECTED, CONNECTING, CONNECTED, CLOSED + + } + + /** + * A builder for creating instances of WebSocketClientTransport. + */ + public static class Builder { + + private URI uri; + + private final HttpClient.Builder clientBuilder = HttpClient.newBuilder() + .version(HttpClient.Version.HTTP_1_1) + .connectTimeout(Duration.ofSeconds(10)); + + private ObjectMapper objectMapper = new ObjectMapper(); + + public Builder uri(final URI uri) { + this.uri = uri; + return this; + } + + public Builder customizeClient(final Consumer clientCustomizer) { + Assert.notNull(clientCustomizer, "clientCustomizer must not be null"); + clientCustomizer.accept(clientBuilder); + return this; + } + + public Builder objectMapper(final ObjectMapper objectMapper) { + this.objectMapper = objectMapper; + return this; + } + + public WebSocketClientTransport build() { + return new WebSocketClientTransport(uri, clientBuilder, objectMapper); + } + + } + + public Mono connect(final Function, Mono> handler) { + if (!state.compareAndSet(TransportState.DISCONNECTED, TransportState.CONNECTING)) { + return Mono.error(new IllegalStateException("WebSocket is already connecting or connected")); + } + + return Mono.fromFuture(httpClient.newWebSocketBuilder().buildAsync(uri, new WebSocket.Listener() { + private final StringBuilder messageBuffer = new StringBuilder(); + + @Override + public void onOpen(WebSocket webSocket) { + webSocketRef.set(webSocket); + state.set(TransportState.CONNECTED); + } + + @Override + public CompletionStage onText(WebSocket webSocket, CharSequence data, boolean last) { + messageBuffer.append(data); + if (last) { + final String fullMessage = messageBuffer.toString(); + messageBuffer.setLength(0); + try { + final McpSchema.JSONRPCMessage msg = McpSchema.deserializeJsonRpcMessage(objectMapper, + fullMessage); + handler.apply(Mono.just(msg)).subscribe(); + } + catch (Exception e) { + errorSink.tryEmitNext(e); + LOGGER.error("Error processing WS event", e); + } + } + + webSocket.request(1); + return CompletableFuture.completedFuture(null); + } + + @Override + public void onError(WebSocket webSocket, Throwable error) { + errorSink.tryEmitNext(error); + state.set(TransportState.CLOSED); + LOGGER.error("WS connection error", error); + } + + @Override + public CompletionStage onClose(WebSocket webSocket, int statusCode, String reason) { + state.set(TransportState.CLOSED); + return CompletableFuture.completedFuture(null); + } + + })).then(); + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + + return Mono.defer(() -> { + WebSocket ws = webSocketRef.get(); + if (ws == null && state.get() == TransportState.CONNECTING) { + return Mono.error(new IllegalStateException("WebSocket is connecting.")); + } + + if (ws == null || state.get() == TransportState.DISCONNECTED || state.get() == TransportState.CLOSED) { + return Mono.error(new IllegalStateException("WebSocket is closed.")); + } + + try { + String json = objectMapper.writeValueAsString(message); + return Mono.fromFuture(ws.sendText(json, true)).then(); + } + catch (Exception e) { + return Mono.error(e); + } + }).retryWhen(Retry.backoff(3, Duration.ofSeconds(3)).filter(err -> { + if (err instanceof IllegalStateException) { + return err.getMessage().equals("WebSocket is connecting."); + } + return true; + })).onErrorResume(e -> { + LOGGER.error("Failed to send message after retries", e); + errorSink.tryEmitNext(e); + return Mono.error(new IllegalStateException("WebSocket send failed after retries", e)); + }); + + } + + @Override + public Mono closeGracefully() { + WebSocket webSocket = webSocketRef.getAndSet(null); + if (webSocket != null && state.get() == TransportState.CONNECTED) { + state.set(TransportState.CLOSED); + return Mono.fromFuture(webSocket.sendClose(WebSocket.NORMAL_CLOSURE, "Closing")).then(); + } + return Mono.empty(); + } + + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return objectMapper.convertValue(data, typeRef); + } + + public TransportState getState() { + return state.get(); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/transport/WebSocketClientTransportTest.java b/mcp/src/test/java/io/modelcontextprotocol/client/transport/WebSocketClientTransportTest.java new file mode 100644 index 00000000..20d29cd1 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/client/transport/WebSocketClientTransportTest.java @@ -0,0 +1,124 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.net.URI; +import java.util.List; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.images.builder.ImageFromDockerfile; + +import io.modelcontextprotocol.spec.McpSchema; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +/** + * Tests for the {@link WebSocketClientTransport} class. + * + * @author Aliaksei Darafeyeu + */ +class WebSocketClientTransportTest { + + private static GenericContainer wsContainer; + + private static URI websocketUri; + + private WebSocketClientTransport transport; + + @BeforeAll + static void startContainer() { + wsContainer = new GenericContainer<>( + new ImageFromDockerfile().withFileFromClasspath("server.js", "ws/server.js") + .withFileFromClasspath("Dockerfile", "ws/Dockerfile")) + .withExposedPorts(8080); + + wsContainer.start(); + + int port = wsContainer.getMappedPort(8080); + websocketUri = URI.create("ws://localhost:" + port); + } + + @BeforeEach + public void setUp() { + transport = WebSocketClientTransport.builder(websocketUri).build(); + } + + @AfterAll + static void tearDown() { + wsContainer.stop(); + } + + @Test + void testConnectSuccessfully() { + // Try to connect to the WebSocket server + Mono connection = transport.connect(message -> Mono.empty()); + + // Wait for the connection to complete + StepVerifier.create(connection).expectComplete().verify(); + + // Ensure that connection is established + assertEquals(WebSocketClientTransport.TransportState.CONNECTED, transport.getState()); + } + + @Test + void testSendMessage() { + // Connect to the server + Mono connection = transport.connect(message -> Mono.empty()); + + // Ensure connection is successful + StepVerifier.create(connection).expectComplete().verify(); + + // Create a simple message to send + var messageRequest = new McpSchema.CreateMessageRequest( + List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))), + null, null, null, null, 0, null, null); + McpSchema.JSONRPCMessage message = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, + McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, "test-id", messageRequest); + + // Send a message to the server + Mono sendMessage = transport.sendMessage(message); + + // Ensure message is sent successfully + StepVerifier.create(sendMessage).expectComplete().verify(); + } + + @Test + void testCloseConnectionGracefully() { + Mono connection = transport.connect(message -> Mono.empty()); + + StepVerifier.create(connection).expectComplete().verify(); + + // Close the connection gracefully + Mono closeConnection = transport.closeGracefully(); + + // Verify that the connection is closed successfully + StepVerifier.create(closeConnection).expectComplete().verify(); + + assertEquals(WebSocketClientTransport.TransportState.CLOSED, transport.getState()); + } + + @Test + void testSendMessageAfterConnectionClosed() { + // Send a message before connection is established + // Create a simple message to send + var messageRequest = new McpSchema.CreateMessageRequest( + List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))), + null, null, null, null, 0, null, null); + McpSchema.JSONRPCMessage message = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, + McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, "test-id", messageRequest); + + Mono sendMessageBeforeConnect = transport.sendMessage(message); + + // Verify that the transport returns an error because the connection is closed + StepVerifier.create(sendMessageBeforeConnect).expectError(IllegalStateException.class).verify(); + } + +} \ No newline at end of file diff --git a/mcp/src/test/resources/ws/Dockerfile b/mcp/src/test/resources/ws/Dockerfile new file mode 100644 index 00000000..f9f99083 --- /dev/null +++ b/mcp/src/test/resources/ws/Dockerfile @@ -0,0 +1,17 @@ +# Use a Node.js base image +FROM node:14 + +# Set the working directory inside the container +WORKDIR /usr/src/app + +# Copy the server.js file into the container +COPY server.js /usr/src/app/ + +# Install dependencies (e.g., the ws package) +RUN npm init -y && npm install ws + +# Expose the port for WebSocket (e.g., 8080) +EXPOSE 8080 + +# Command to run the WebSocket server +CMD ["node", "server.js"] diff --git a/mcp/src/test/resources/ws/server.js b/mcp/src/test/resources/ws/server.js new file mode 100644 index 00000000..914829d1 --- /dev/null +++ b/mcp/src/test/resources/ws/server.js @@ -0,0 +1,21 @@ +// Import the WebSocket package +const WebSocket = require('ws'); + +// Set up the WebSocket server to listen on port 8080 +const wss = new WebSocket.Server({ port: 8080 }); + +// When a new WebSocket connection is established +wss.on('connection', function connection(ws) { + console.log('New client connected'); + + // When a message is received from the client + ws.on('message', function incoming(message) { + console.log('received: %s', message); + }); + + // Send a welcome message to the client + ws.send('Welcome to the WebSocket server!'); +}); + +// Log the WebSocket server start +console.log('WebSocket server is listening on ws://localhost:8080');