From 93f11198ce2f941f1dfbfc6597fa630557f98de2 Mon Sep 17 00:00:00 2001 From: "a.darafeyeu" Date: Tue, 8 Apr 2025 20:00:22 +0200 Subject: [PATCH 1/5] refactor: migrates CompletableFuture to reactive patterns for HttpClientSseClientTransport --- .../client/transport/FlowSseClient.java | 191 ++++++++++++------ .../HttpClientSseClientTransport.java | 147 +++++++------- 2 files changed, 204 insertions(+), 134 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/FlowSseClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/FlowSseClient.java index 50af35c7..e768ab98 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/FlowSseClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/FlowSseClient.java @@ -7,12 +7,15 @@ import java.net.http.HttpClient; import java.net.http.HttpRequest; import java.net.http.HttpResponse; -import java.util.concurrent.CompletableFuture; +import java.util.Map; import java.util.concurrent.Flow; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiConsumer; import java.util.function.Function; import java.util.regex.Pattern; +import reactor.core.publisher.Mono; + /** * A Server-Sent Events (SSE) client implementation using Java's Flow API for reactive * stream processing. This client establishes a connection to an SSE endpoint and @@ -59,6 +62,11 @@ public class FlowSseClient { */ private static final Pattern EVENT_TYPE_PATTERN = Pattern.compile("^event:(.+)$", Pattern.MULTILINE); + /** + * Atomic reference to hold the current subscription for the SSE stream. + */ + private final AtomicReference currentSubscription = new AtomicReference<>(); + /** * Record class representing a Server-Sent Event with its standard fields. * @@ -66,7 +74,7 @@ public class FlowSseClient { * @param type the event type (defaults to "message" if not specified in the stream) * @param data the event payload data */ - public static record SseEvent(String id, String type, String data) { + public record SseEvent(String id, String type, String data) { } /** @@ -121,90 +129,143 @@ public FlowSseClient(HttpClient httpClient, HttpRequest.Builder requestBuilder) * @throws RuntimeException if the connection fails with a non-200 status code */ public void subscribe(String url, SseEventHandler eventHandler) { + subscribeAsync(url, eventHandler).subscribe(); + } + + /** + * Subscribes to an SSE endpoint and processes the event stream. + * + *

+ * This method establishes a connection to the specified URL and begins processing the + * SSE stream. Events are parsed and delivered to the provided event handler. The + * connection remains active until either an error occurs or the server closes the + * connection. + * @param url the SSE endpoint URL to connect to + * @param eventHandler the handler that will receive SSE events and error + * notifications + * @return a Mono representing the completion of the subscription + * @throws RuntimeException if the connection fails with a non-200 status code + */ + public Mono subscribeAsync(String url, SseEventHandler eventHandler) { HttpRequest request = this.requestBuilder.uri(URI.create(url)) .header("Accept", "text/event-stream") .header("Cache-Control", "no-cache") .GET() .build(); - StringBuilder eventBuilder = new StringBuilder(); - AtomicReference currentEventId = new AtomicReference<>(); - AtomicReference currentEventType = new AtomicReference<>("message"); + SseSubscriber lineSubscriber = new SseSubscriber(eventHandler); + Function, HttpResponse.BodySubscriber> subscriberFactory = HttpResponse.BodySubscribers::fromLineSubscriber; - Flow.Subscriber lineSubscriber = new Flow.Subscriber<>() { - private Flow.Subscription subscription; + return Mono + .fromFuture(() -> this.httpClient.sendAsync(request, info -> subscriberFactory.apply(lineSubscriber))) + .doOnTerminate(lineSubscriber::cancelSubscription) + .doOnError(eventHandler::onError) + .doOnSuccess(response -> { + int status = response.statusCode(); + if (status != 200 && status != 201 && status != 202 && status != 206) { + throw new RuntimeException("Failed to connect to SSE stream. Unexpected status code: " + status); + } + }) + .then() + .doOnSubscribe(subscription -> currentSubscription.set(lineSubscriber.getSubscription())); + } - @Override - public void onSubscribe(Flow.Subscription subscription) { - this.subscription = subscription; - subscription.request(Long.MAX_VALUE); - } + /** + * Gracefully close the SSE stream subscription if active. + */ + public void close() { + Flow.Subscription subscription = currentSubscription.get(); + if (subscription != null) { + subscription.cancel(); + currentSubscription.set(null); + } + } - @Override - public void onNext(String line) { - if (line.isEmpty()) { - // Empty line means end of event - if (eventBuilder.length() > 0) { - String eventData = eventBuilder.toString(); - SseEvent event = new SseEvent(currentEventId.get(), currentEventType.get(), eventData.trim()); - eventHandler.onEvent(event); - eventBuilder.setLength(0); - } + /** + * Inner class that implements Flow.Subscriber to handle incoming SSE events. + * It processes the event stream, parsing the data and notifying the event handler. + */ + private static class SseSubscriber implements Flow.Subscriber { + + private final SseEventHandler eventHandler; + + private final StringBuilder eventBuilder = new StringBuilder(); + + private final AtomicReference currentEventId = new AtomicReference<>(); + + private final AtomicReference currentEventType = new AtomicReference<>("message"); + + private Flow.Subscription subscription; + + public SseSubscriber(SseEventHandler eventHandler) { + this.eventHandler = eventHandler; + } + + @Override + public void onSubscribe(Flow.Subscription subscription) { + this.subscription = subscription; + subscription.request(Long.MAX_VALUE); + } + + @Override + public void onNext(String line) { + if (line.isEmpty()) { + // Empty line means end of event + if (eventBuilder.isEmpty()) { + String eventData = eventBuilder.toString(); + SseEvent event = new SseEvent(currentEventId.get(), currentEventType.get(), eventData.trim()); + eventHandler.onEvent(event); + eventBuilder.setLength(0); } - else { - if (line.startsWith("data:")) { - var matcher = EVENT_DATA_PATTERN.matcher(line); - if (matcher.find()) { - eventBuilder.append(matcher.group(1).trim()).append("\n"); - } + } + else { + if (line.startsWith("data:")) { + var matcher = EVENT_DATA_PATTERN.matcher(line); + if (matcher.find()) { + eventBuilder.append(matcher.group(1).trim()).append("\n"); } - else if (line.startsWith("id:")) { - var matcher = EVENT_ID_PATTERN.matcher(line); - if (matcher.find()) { - currentEventId.set(matcher.group(1).trim()); - } + } + else if (line.startsWith("id:")) { + var matcher = EVENT_ID_PATTERN.matcher(line); + if (matcher.find()) { + currentEventId.set(matcher.group(1).trim()); } - else if (line.startsWith("event:")) { - var matcher = EVENT_TYPE_PATTERN.matcher(line); - if (matcher.find()) { - currentEventType.set(matcher.group(1).trim()); - } + } + else if (line.startsWith("event:")) { + var matcher = EVENT_TYPE_PATTERN.matcher(line); + if (matcher.find()) { + currentEventType.set(matcher.group(1).trim()); } } - subscription.request(1); } + subscription.request(1); + } - @Override - public void onError(Throwable throwable) { - eventHandler.onError(throwable); - } + @Override + public void onError(Throwable throwable) { + eventHandler.onError(throwable); + } - @Override - public void onComplete() { - // Handle any remaining event data - if (eventBuilder.length() > 0) { - String eventData = eventBuilder.toString(); - SseEvent event = new SseEvent(currentEventId.get(), currentEventType.get(), eventData.trim()); - eventHandler.onEvent(event); - } + @Override + public void onComplete() { + // Handle any remaining event data + if (eventBuilder.isEmpty()) { + String eventData = eventBuilder.toString(); + SseEvent event = new SseEvent(currentEventId.get(), currentEventType.get(), eventData.trim()); + eventHandler.onEvent(event); } - }; + } - Function, HttpResponse.BodySubscriber> subscriberFactory = subscriber -> HttpResponse.BodySubscribers - .fromLineSubscriber(subscriber); + public Flow.Subscription getSubscription() { + return this.subscription; + } - CompletableFuture> future = this.httpClient.sendAsync(request, - info -> subscriberFactory.apply(lineSubscriber)); - - future.thenAccept(response -> { - int status = response.statusCode(); - if (status != 200 && status != 201 && status != 202 && status != 206) { - throw new RuntimeException("Failed to connect to SSE stream. Unexpected status code: " + status); + public void cancelSubscription() { + if (subscription != null) { + subscription.cancel(); } - }).exceptionally(throwable -> { - eventHandler.onError(throwable); - return null; - }); + } + } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java index 99cf2a62..174c7be2 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java @@ -9,9 +9,6 @@ import java.net.http.HttpRequest; import java.net.http.HttpResponse; import java.time.Duration; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.function.Function; @@ -28,6 +25,8 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoSink; +import reactor.util.retry.Retry; /** * Server-Sent Events (SSE) implementation of the @@ -90,18 +89,12 @@ public class HttpClientSseClientTransport implements McpClientTransport { /** JSON object mapper for message serialization/deserialization */ protected ObjectMapper objectMapper; - /** Flag indicating if the transport is in closing state */ - private volatile boolean isClosing = false; - - /** Latch for coordinating endpoint discovery */ - private final CountDownLatch closeLatch = new CountDownLatch(1); + /** Enum indicating the transport state */ + private final AtomicReference state = new AtomicReference<>(TransportState.DISCONNECTED); /** Holds the discovered message endpoint URL */ private final AtomicReference messageEndpoint = new AtomicReference<>(); - /** Holds the SSE connection future */ - private final AtomicReference> connectionFuture = new AtomicReference<>(); - /** * Creates a new transport instance with default HTTP client and object mapper. * @param baseUri the base URI of the MCP server @@ -338,48 +331,51 @@ public HttpClientSseClientTransport build() { */ @Override public Mono connect(Function, Mono> handler) { - CompletableFuture future = new CompletableFuture<>(); - connectionFuture.set(future); + state.set(TransportState.CONNECTING); + return Mono.create(sink -> subscribeSse(handler, sink)) + .timeout(Duration.ofSeconds(10)) + .retryWhen(Retry.backoff(3, Duration.ofSeconds(1)).maxBackoff(Duration.ofSeconds(5))) + .doOnError(err -> logger.error("Error during connection", err)); - URI clientUri = Utils.resolveUri(this.baseUri, this.sseEndpoint); - sseClient.subscribe(clientUri.toString(), new FlowSseClient.SseEventHandler() { + } + + private void subscribeSse(final Function, Mono> handler, MonoSink sink) { + final URI clientUri = Utils.resolveUri(this.baseUri, this.sseEndpoint); + sseClient.subscribe(clientUri.toString(), new FlowSseClient.SseEventHandler() { @Override public void onEvent(SseEvent event) { - if (isClosing) { + if (state.get() == TransportState.CLOSING || state.get() == TransportState.DISCONNECTED) { return; } try { - if (ENDPOINT_EVENT_TYPE.equals(event.type())) { - String endpoint = event.data(); - messageEndpoint.set(endpoint); - closeLatch.countDown(); - future.complete(null); - } - else if (MESSAGE_EVENT_TYPE.equals(event.type())) { - JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, event.data()); - handler.apply(Mono.just(message)).subscribe(); - } - else { - logger.error("Received unrecognized SSE event type: {}", event.type()); + switch (event.type()) { + case ENDPOINT_EVENT_TYPE -> { + messageEndpoint.set(event.data()); + state.set(TransportState.CONNECTED); + sink.success(); + } + case MESSAGE_EVENT_TYPE -> { + JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, event.data()); + handler.apply(Mono.just(message)).subscribe(); + } + default -> logger.error("Received unrecognized SSE event type: {}", event.type()); } } - catch (IOException e) { + catch (Exception e) { logger.error("Error processing SSE event", e); - future.completeExceptionally(e); + sink.error(new McpError("Error processing SSE event")); } } @Override public void onError(Throwable error) { - if (!isClosing) { + if (state.get() != TransportState.CLOSING) { logger.error("SSE connection error", error); - future.completeExceptionally(error); + sink.error(error); } } }); - - return Mono.fromFuture(future); } /** @@ -394,44 +390,44 @@ public void onError(Throwable error) { */ @Override public Mono sendMessage(JSONRPCMessage message) { - if (isClosing) { + if (state.get() == TransportState.CLOSING || state.get() == TransportState.DISCONNECTED) { return Mono.empty(); } - - try { - if (!closeLatch.await(10, TimeUnit.SECONDS)) { - return Mono.error(new McpError("Failed to wait for the message endpoint")); + return Mono.defer(() -> { + if (messageEndpoint.get() == null) { + return Mono.error(new McpError("No message endpoint available")); } - } - catch (InterruptedException e) { - return Mono.error(new McpError("Failed to wait for the message endpoint")); - } - String endpoint = messageEndpoint.get(); - if (endpoint == null) { - return Mono.error(new McpError("No message endpoint available")); - } + return serializeMessage(message).flatMap(body -> sendHttpPost(messageEndpoint.get(), body)) + .doOnNext(this::logIfNotOk) + .doOnError(err -> logger.error("Error sending message", err)) + .then(); + + }).retryWhen(Retry.fixedDelay(3, Duration.ofSeconds(3)).filter(err -> messageEndpoint.get() == null)); + } + private Mono serializeMessage(final JSONRPCMessage message) { try { - String jsonText = this.objectMapper.writeValueAsString(message); - URI requestUri = Utils.resolveUri(baseUri, endpoint); - HttpRequest request = this.requestBuilder.uri(requestUri) - .POST(HttpRequest.BodyPublishers.ofString(jsonText)) - .build(); - - return Mono.fromFuture( - httpClient.sendAsync(request, HttpResponse.BodyHandlers.discarding()).thenAccept(response -> { - if (response.statusCode() != 200 && response.statusCode() != 201 && response.statusCode() != 202 - && response.statusCode() != 206) { - logger.error("Error sending message: {}", response.statusCode()); - } - })); + return Mono.just(objectMapper.writeValueAsString(message)); } catch (IOException e) { - if (!isClosing) { - return Mono.error(new RuntimeException("Failed to serialize message", e)); - } - return Mono.empty(); + return Mono.error(new McpError("Failed to serialize message")); + } + } + + private Mono> sendHttpPost(final String endpoint, final String body) { + final URI requestUri = Utils.resolveUri(baseUri, endpoint); + final HttpRequest request = requestBuilder.uri(requestUri) + .POST(HttpRequest.BodyPublishers.ofString(body)) + .build(); + + return Mono.fromFuture(httpClient.sendAsync(request, HttpResponse.BodyHandlers.discarding())); + } + + private void logIfNotOk(final HttpResponse response) { + if (response.statusCode() != 200 && response.statusCode() != 201 && response.statusCode() != 202 + && response.statusCode() != 206) { + logger.error("Error sending message: {}", response.statusCode()); } } @@ -445,12 +441,10 @@ public Mono sendMessage(JSONRPCMessage message) { */ @Override public Mono closeGracefully() { + state.set(TransportState.CLOSING); return Mono.fromRunnable(() -> { - isClosing = true; - CompletableFuture future = connectionFuture.get(); - if (future != null && !future.isDone()) { - future.cancel(true); - } + sseClient.close(); + state.set(TransportState.DISCONNECTED); }); } @@ -466,4 +460,19 @@ public T unmarshalFrom(Object data, TypeReference typeRef) { return this.objectMapper.convertValue(data, typeRef); } + /** + * Get the current transport state. + * @return the current transport state + */ + public TransportState getState() { + return state.get(); + } + + // Enum to manage transport states + public enum TransportState { + + DISCONNECTED, CONNECTING, CONNECTED, CLOSING + + } + } From 5b734018df7b28c6f092d064b68cf7570dc60c23 Mon Sep 17 00:00:00 2001 From: "a.darafeyeu" Date: Tue, 8 Apr 2025 20:08:05 +0200 Subject: [PATCH 2/5] refactor: code review --- .../io/modelcontextprotocol/client/transport/FlowSseClient.java | 2 -- 1 file changed, 2 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/FlowSseClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/FlowSseClient.java index e768ab98..b9467e47 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/FlowSseClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/FlowSseClient.java @@ -7,10 +7,8 @@ import java.net.http.HttpClient; import java.net.http.HttpRequest; import java.net.http.HttpResponse; -import java.util.Map; import java.util.concurrent.Flow; import java.util.concurrent.atomic.AtomicReference; -import java.util.function.BiConsumer; import java.util.function.Function; import java.util.regex.Pattern; From e153034a75ce95df277678b6be19485cea2f1386 Mon Sep 17 00:00:00 2001 From: "a.darafeyeu" Date: Wed, 9 Apr 2025 16:16:51 +0200 Subject: [PATCH 3/5] refactor: correct retry --- .../client/transport/FlowSseClient.java | 209 +++++++++--------- .../HttpClientSseClientTransport.java | 5 +- .../transport/SseConnectionException.java | 41 ++++ 3 files changed, 148 insertions(+), 107 deletions(-) create mode 100644 mcp/src/main/java/io/modelcontextprotocol/client/transport/SseConnectionException.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/FlowSseClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/FlowSseClient.java index b9467e47..b17dbe81 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/FlowSseClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/FlowSseClient.java @@ -7,12 +7,14 @@ import java.net.http.HttpClient; import java.net.http.HttpRequest; import java.net.http.HttpResponse; +import java.time.Duration; import java.util.concurrent.Flow; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import java.util.regex.Pattern; import reactor.core.publisher.Mono; +import reactor.util.retry.Retry; /** * A Server-Sent Events (SSE) client implementation using Java's Flow API for reactive @@ -65,6 +67,12 @@ public class FlowSseClient { */ private final AtomicReference currentSubscription = new AtomicReference<>(); + /** + * Atomic reference to hold the last event ID received from the SSE stream. This can + * be used to resume the stream from the last known event. + */ + private final AtomicReference lastEventId = new AtomicReference<>(); + /** * Record class representing a Server-Sent Event with its standard fields. * @@ -145,125 +153,120 @@ public void subscribe(String url, SseEventHandler eventHandler) { * @throws RuntimeException if the connection fails with a non-200 status code */ public Mono subscribeAsync(String url, SseEventHandler eventHandler) { - HttpRequest request = this.requestBuilder.uri(URI.create(url)) - .header("Accept", "text/event-stream") - .header("Cache-Control", "no-cache") - .GET() - .build(); - - SseSubscriber lineSubscriber = new SseSubscriber(eventHandler); - Function, HttpResponse.BodySubscriber> subscriberFactory = HttpResponse.BodySubscribers::fromLineSubscriber; + final Function, HttpResponse.BodySubscriber> subscriberFactory = HttpResponse.BodySubscribers::fromLineSubscriber; + final StringBuilder eventBuilder = new StringBuilder(); + final AtomicReference currentEventId = new AtomicReference<>(); + final AtomicReference currentEventType = new AtomicReference<>("message"); + final Flow.Subscriber lineSubscriber = new Flow.Subscriber<>() { + private Flow.Subscription subscription; + + @Override + public void onSubscribe(Flow.Subscription subscription) { + this.subscription = subscription; + currentSubscription.set(subscription); + subscription.request(Long.MAX_VALUE); + } - return Mono - .fromFuture(() -> this.httpClient.sendAsync(request, info -> subscriberFactory.apply(lineSubscriber))) - .doOnTerminate(lineSubscriber::cancelSubscription) - .doOnError(eventHandler::onError) - .doOnSuccess(response -> { - int status = response.statusCode(); - if (status != 200 && status != 201 && status != 202 && status != 206) { - throw new RuntimeException("Failed to connect to SSE stream. Unexpected status code: " + status); + @Override + public void onNext(String line) { + if (line.isEmpty()) { + // Empty line means end of event + if (eventBuilder.length() > 0) { + String eventData = eventBuilder.toString(); + SseEvent event = new SseEvent(currentEventId.get(), currentEventType.get(), eventData.trim()); + lastEventId.set(currentEventId.get()); + eventHandler.onEvent(event); + eventBuilder.setLength(0); + } } - }) - .then() - .doOnSubscribe(subscription -> currentSubscription.set(lineSubscriber.getSubscription())); - } - - /** - * Gracefully close the SSE stream subscription if active. - */ - public void close() { - Flow.Subscription subscription = currentSubscription.get(); - if (subscription != null) { - subscription.cancel(); - currentSubscription.set(null); - } - } - - /** - * Inner class that implements Flow.Subscriber to handle incoming SSE events. - * It processes the event stream, parsing the data and notifying the event handler. - */ - private static class SseSubscriber implements Flow.Subscriber { - - private final SseEventHandler eventHandler; - - private final StringBuilder eventBuilder = new StringBuilder(); - - private final AtomicReference currentEventId = new AtomicReference<>(); - - private final AtomicReference currentEventType = new AtomicReference<>("message"); - - private Flow.Subscription subscription; - - public SseSubscriber(SseEventHandler eventHandler) { - this.eventHandler = eventHandler; - } + else { + if (line.startsWith("data:")) { + var matcher = EVENT_DATA_PATTERN.matcher(line); + if (matcher.find()) { + eventBuilder.append(matcher.group(1).trim()).append("\n"); + } + } + else if (line.startsWith("id:")) { + var matcher = EVENT_ID_PATTERN.matcher(line); + if (matcher.find()) { + currentEventId.set(matcher.group(1).trim()); + } + } + else if (line.startsWith("event:")) { + var matcher = EVENT_TYPE_PATTERN.matcher(line); + if (matcher.find()) { + currentEventType.set(matcher.group(1).trim()); + } + } + } + subscription.request(1); + } - @Override - public void onSubscribe(Flow.Subscription subscription) { - this.subscription = subscription; - subscription.request(Long.MAX_VALUE); - } + @Override + public void onError(Throwable throwable) { + eventHandler.onError(throwable); + } - @Override - public void onNext(String line) { - if (line.isEmpty()) { - // Empty line means end of event - if (eventBuilder.isEmpty()) { + @Override + public void onComplete() { + // Handle any remaining event data + if (eventBuilder.length() > 0) { String eventData = eventBuilder.toString(); SseEvent event = new SseEvent(currentEventId.get(), currentEventType.get(), eventData.trim()); eventHandler.onEvent(event); - eventBuilder.setLength(0); - } - } - else { - if (line.startsWith("data:")) { - var matcher = EVENT_DATA_PATTERN.matcher(line); - if (matcher.find()) { - eventBuilder.append(matcher.group(1).trim()).append("\n"); - } - } - else if (line.startsWith("id:")) { - var matcher = EVENT_ID_PATTERN.matcher(line); - if (matcher.find()) { - currentEventId.set(matcher.group(1).trim()); - } - } - else if (line.startsWith("event:")) { - var matcher = EVENT_TYPE_PATTERN.matcher(line); - if (matcher.find()) { - currentEventType.set(matcher.group(1).trim()); - } } } - subscription.request(1); - } + }; - @Override - public void onError(Throwable throwable) { - eventHandler.onError(throwable); - } + return Mono.defer(() -> { + HttpRequest.Builder builder = this.requestBuilder.uri(URI.create(url)) + .header("Accept", "text/event-stream") + .header("Cache-Control", "no-cache") + .GET(); - @Override - public void onComplete() { - // Handle any remaining event data - if (eventBuilder.isEmpty()) { - String eventData = eventBuilder.toString(); - SseEvent event = new SseEvent(currentEventId.get(), currentEventType.get(), eventData.trim()); - eventHandler.onEvent(event); + String lastId = lastEventId.get(); + if (lastId != null) { + builder.header("Last-Event-ID", lastId); } - } - public Flow.Subscription getSubscription() { - return this.subscription; - } + HttpRequest request = builder.build(); - public void cancelSubscription() { - if (subscription != null) { - subscription.cancel(); + return Mono + .fromFuture(() -> this.httpClient.sendAsync(request, info -> subscriberFactory.apply(lineSubscriber))) + .flatMap(response -> { + int status = response.statusCode(); + if (status >= 400 && status < 500 && status != 429 && status != 408) { + return Mono.error(new SseConnectionException("Client error." + status, status)); + } + if (status != 200 && status != 201 && status != 202 && status != 206) { + return Mono.error(new SseConnectionException("Failed to connect to SSE stream.", status)); + } + return Mono.empty(); + }) + .doOnError(eventHandler::onError) + .doFinally(sig -> { + Flow.Subscription active = currentSubscription.getAndSet(null); + if (active != null) + active.cancel(); + }) + .then(); + }).retryWhen(Retry.backoff(3, Duration.ofSeconds(2)).filter(err -> { + if (err instanceof SseConnectionException exception) { + return exception.isRetryable(); } - } + return true; // Retry on other exceptions + }).onRetryExhaustedThrow((spec, signal) -> signal.failure())); + + } + /** + * Gracefully close the SSE stream subscription if active. + */ + public void close() { + Flow.Subscription subscription = currentSubscription.getAndSet(null); + if (subscription != null) { + subscription.cancel(); + } } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java index 174c7be2..36ed70c3 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java @@ -333,8 +333,6 @@ public HttpClientSseClientTransport build() { public Mono connect(Function, Mono> handler) { state.set(TransportState.CONNECTING); return Mono.create(sink -> subscribeSse(handler, sink)) - .timeout(Duration.ofSeconds(10)) - .retryWhen(Retry.backoff(3, Duration.ofSeconds(1)).maxBackoff(Duration.ofSeconds(5))) .doOnError(err -> logger.error("Error during connection", err)); } @@ -347,13 +345,12 @@ public void onEvent(SseEvent event) { if (state.get() == TransportState.CLOSING || state.get() == TransportState.DISCONNECTED) { return; } - + sink.success(); try { switch (event.type()) { case ENDPOINT_EVENT_TYPE -> { messageEndpoint.set(event.data()); state.set(TransportState.CONNECTED); - sink.success(); } case MESSAGE_EVENT_TYPE -> { JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, event.data()); diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/SseConnectionException.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/SseConnectionException.java new file mode 100644 index 00000000..58fb4f21 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/SseConnectionException.java @@ -0,0 +1,41 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport; + +/** + * Exception thrown when there is an issue with the SSE connection. + */ +public class SseConnectionException extends RuntimeException { + + private final int statusCode; + + /** + * Constructor for SseConnectionException. + * @param message the error message + * @param statusCode the HTTP status code associated with the error + */ + public SseConnectionException(final String message, final int statusCode) { + super(message + " (Status code: " + statusCode + ")"); + this.statusCode = statusCode; + } + + /** + * Gets the HTTP status code associated with this exception. + * @return the HTTP status code. + */ + public int getStatusCode() { + return statusCode; + } + + /** + * Checks if the status code indicates a retryable error. + * @return true if the status code is 408, 429, or in the 500-599 range; false + * otherwise. + */ + public boolean isRetryable() { + return statusCode == 408 || statusCode == 429 || (statusCode >= 500 && statusCode < 600); + } + +} From f72e92b6d6f7bb0987c8caa031bd60f8e9deb4a5 Mon Sep 17 00:00:00 2001 From: "a.darafeyeu" Date: Mon, 21 Apr 2025 13:49:23 +0200 Subject: [PATCH 4/5] fix format --- .../client/transport/HttpClientSseClientTransport.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java index 36ed70c3..75a85b38 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java @@ -338,8 +338,8 @@ public Mono connect(Function, Mono> h } private void subscribeSse(final Function, Mono> handler, MonoSink sink) { - final URI clientUri = Utils.resolveUri(this.baseUri, this.sseEndpoint); - sseClient.subscribe(clientUri.toString(), new FlowSseClient.SseEventHandler() { + final URI clientUri = Utils.resolveUri(this.baseUri, this.sseEndpoint); + sseClient.subscribe(clientUri.toString(), new FlowSseClient.SseEventHandler() { @Override public void onEvent(SseEvent event) { if (state.get() == TransportState.CLOSING || state.get() == TransportState.DISCONNECTED) { @@ -413,7 +413,7 @@ private Mono serializeMessage(final JSONRPCMessage message) { } private Mono> sendHttpPost(final String endpoint, final String body) { - final URI requestUri = Utils.resolveUri(baseUri, endpoint); + final URI requestUri = Utils.resolveUri(baseUri, endpoint); final HttpRequest request = requestBuilder.uri(requestUri) .POST(HttpRequest.BodyPublishers.ofString(body)) .build(); From f7f6e65434744c941bf78e52538bd57d4f06fdd6 Mon Sep 17 00:00:00 2001 From: "a.darafeyeu" Date: Mon, 21 Apr 2025 18:20:55 +0200 Subject: [PATCH 5/5] removes mock test from non-mock test class --- .../HttpClientSseClientTransportTests.java | 29 +------------------ 1 file changed, 1 insertion(+), 28 deletions(-) diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java index 762264de..96207346 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java @@ -7,10 +7,8 @@ import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpRequest; -import java.net.http.HttpResponse; import java.time.Duration; import java.util.Map; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; @@ -22,8 +20,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; -import org.mockito.ArgumentCaptor; -import org.mockito.Mockito; + import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; import reactor.core.publisher.Mono; @@ -34,9 +31,6 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; import com.fasterxml.jackson.databind.ObjectMapper; @@ -370,25 +364,4 @@ void testChainedCustomizations() { customizedTransport.closeGracefully().block(); } - @Test - @SuppressWarnings("unchecked") - void testResolvingClientEndpoint() { - HttpClient httpClient = Mockito.mock(HttpClient.class); - HttpResponse httpResponse = Mockito.mock(HttpResponse.class); - CompletableFuture> future = new CompletableFuture<>(); - future.complete(httpResponse); - when(httpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))).thenReturn(future); - - HttpClientSseClientTransport transport = new HttpClientSseClientTransport(httpClient, HttpRequest.newBuilder(), - "http://example.com", "http://example.com/sse", new ObjectMapper()); - - transport.connect(Function.identity()); - - ArgumentCaptor httpRequestCaptor = ArgumentCaptor.forClass(HttpRequest.class); - verify(httpClient).sendAsync(httpRequestCaptor.capture(), any(HttpResponse.BodyHandler.class)); - assertThat(httpRequestCaptor.getValue().uri()).isEqualTo(URI.create("http://example.com/sse")); - - transport.closeGracefully().block(); - } - }