Skip to content

refactor(client): migrates CompletableFuture to reactive patterns for HttpClientSseClientTransport #128

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.util.concurrent.CompletableFuture;
import java.time.Duration;
import java.util.concurrent.Flow;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.regex.Pattern;

import reactor.core.publisher.Mono;
import reactor.util.retry.Retry;

/**
* A Server-Sent Events (SSE) client implementation using Java's Flow API for reactive
* stream processing. This client establishes a connection to an SSE endpoint and
Expand Down Expand Up @@ -59,14 +62,25 @@ public class FlowSseClient {
*/
private static final Pattern EVENT_TYPE_PATTERN = Pattern.compile("^event:(.+)$", Pattern.MULTILINE);

/**
* Atomic reference to hold the current subscription for the SSE stream.
*/
private final AtomicReference<Flow.Subscription> currentSubscription = new AtomicReference<>();

/**
* Atomic reference to hold the last event ID received from the SSE stream. This can
* be used to resume the stream from the last known event.
*/
private final AtomicReference<String> lastEventId = new AtomicReference<>();

/**
* Record class representing a Server-Sent Event with its standard fields.
*
* @param id the event ID (may be null)
* @param type the event type (defaults to "message" if not specified in the stream)
* @param data the event payload data
*/
public static record SseEvent(String id, String type, String data) {
public record SseEvent(String id, String type, String data) {
}

/**
Expand Down Expand Up @@ -121,22 +135,35 @@ public FlowSseClient(HttpClient httpClient, HttpRequest.Builder requestBuilder)
* @throws RuntimeException if the connection fails with a non-200 status code
*/
public void subscribe(String url, SseEventHandler eventHandler) {
HttpRequest request = this.requestBuilder.uri(URI.create(url))
.header("Accept", "text/event-stream")
.header("Cache-Control", "no-cache")
.GET()
.build();

StringBuilder eventBuilder = new StringBuilder();
AtomicReference<String> currentEventId = new AtomicReference<>();
AtomicReference<String> currentEventType = new AtomicReference<>("message");
subscribeAsync(url, eventHandler).subscribe();
}

Flow.Subscriber<String> lineSubscriber = new Flow.Subscriber<>() {
/**
* Subscribes to an SSE endpoint and processes the event stream.
*
* <p>
* This method establishes a connection to the specified URL and begins processing the
* SSE stream. Events are parsed and delivered to the provided event handler. The
* connection remains active until either an error occurs or the server closes the
* connection.
* @param url the SSE endpoint URL to connect to
* @param eventHandler the handler that will receive SSE events and error
* notifications
* @return a Mono representing the completion of the subscription
* @throws RuntimeException if the connection fails with a non-200 status code
*/
public Mono<Void> subscribeAsync(String url, SseEventHandler eventHandler) {
final Function<Flow.Subscriber<String>, HttpResponse.BodySubscriber<Void>> subscriberFactory = HttpResponse.BodySubscribers::fromLineSubscriber;
final StringBuilder eventBuilder = new StringBuilder();
final AtomicReference<String> currentEventId = new AtomicReference<>();
final AtomicReference<String> currentEventType = new AtomicReference<>("message");
final Flow.Subscriber<String> lineSubscriber = new Flow.Subscriber<>() {
private Flow.Subscription subscription;

@Override
public void onSubscribe(Flow.Subscription subscription) {
this.subscription = subscription;
currentSubscription.set(subscription);
subscription.request(Long.MAX_VALUE);
}

Expand All @@ -147,6 +174,7 @@ public void onNext(String line) {
if (eventBuilder.length() > 0) {
String eventData = eventBuilder.toString();
SseEvent event = new SseEvent(currentEventId.get(), currentEventType.get(), eventData.trim());
lastEventId.set(currentEventId.get());
eventHandler.onEvent(event);
eventBuilder.setLength(0);
}
Expand Down Expand Up @@ -190,21 +218,55 @@ public void onComplete() {
}
};

Function<Flow.Subscriber<String>, HttpResponse.BodySubscriber<Void>> subscriberFactory = subscriber -> HttpResponse.BodySubscribers
.fromLineSubscriber(subscriber);
return Mono.defer(() -> {
HttpRequest.Builder builder = this.requestBuilder.uri(URI.create(url))
.header("Accept", "text/event-stream")
.header("Cache-Control", "no-cache")
.GET();

String lastId = lastEventId.get();
if (lastId != null) {
builder.header("Last-Event-ID", lastId);
}

CompletableFuture<HttpResponse<Void>> future = this.httpClient.sendAsync(request,
info -> subscriberFactory.apply(lineSubscriber));
HttpRequest request = builder.build();

future.thenAccept(response -> {
int status = response.statusCode();
if (status != 200 && status != 201 && status != 202 && status != 206) {
throw new RuntimeException("Failed to connect to SSE stream. Unexpected status code: " + status);
return Mono
.fromFuture(() -> this.httpClient.sendAsync(request, info -> subscriberFactory.apply(lineSubscriber)))
.flatMap(response -> {
int status = response.statusCode();
if (status >= 400 && status < 500 && status != 429 && status != 408) {
return Mono.error(new SseConnectionException("Client error." + status, status));
}
if (status != 200 && status != 201 && status != 202 && status != 206) {
return Mono.error(new SseConnectionException("Failed to connect to SSE stream.", status));
}
return Mono.empty();
})
.doOnError(eventHandler::onError)
.doFinally(sig -> {
Flow.Subscription active = currentSubscription.getAndSet(null);
if (active != null)
active.cancel();
})
.then();
}).retryWhen(Retry.backoff(3, Duration.ofSeconds(2)).filter(err -> {
if (err instanceof SseConnectionException exception) {
return exception.isRetryable();
}
}).exceptionally(throwable -> {
eventHandler.onError(throwable);
return null;
});
return true; // Retry on other exceptions
}).onRetryExhaustedThrow((spec, signal) -> signal.failure()));

}

/**
* Gracefully close the SSE stream subscription if active.
*/
public void close() {
Flow.Subscription subscription = currentSubscription.getAndSet(null);
if (subscription != null) {
subscription.cancel();
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,6 @@
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.time.Duration;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.Function;
Expand All @@ -28,6 +25,8 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Mono;
import reactor.core.publisher.MonoSink;
import reactor.util.retry.Retry;

/**
* Server-Sent Events (SSE) implementation of the
Expand Down Expand Up @@ -90,18 +89,12 @@ public class HttpClientSseClientTransport implements McpClientTransport {
/** JSON object mapper for message serialization/deserialization */
protected ObjectMapper objectMapper;

/** Flag indicating if the transport is in closing state */
private volatile boolean isClosing = false;

/** Latch for coordinating endpoint discovery */
private final CountDownLatch closeLatch = new CountDownLatch(1);
/** Enum indicating the transport state */
private final AtomicReference<TransportState> state = new AtomicReference<>(TransportState.DISCONNECTED);

/** Holds the discovered message endpoint URL */
private final AtomicReference<String> messageEndpoint = new AtomicReference<>();

/** Holds the SSE connection future */
private final AtomicReference<CompletableFuture<Void>> connectionFuture = new AtomicReference<>();

/**
* Creates a new transport instance with default HTTP client and object mapper.
* @param baseUri the base URI of the MCP server
Expand Down Expand Up @@ -338,48 +331,48 @@ public HttpClientSseClientTransport build() {
*/
@Override
public Mono<Void> connect(Function<Mono<JSONRPCMessage>, Mono<JSONRPCMessage>> handler) {
CompletableFuture<Void> future = new CompletableFuture<>();
connectionFuture.set(future);
state.set(TransportState.CONNECTING);
return Mono.<Void>create(sink -> subscribeSse(handler, sink))
.doOnError(err -> logger.error("Error during connection", err));

URI clientUri = Utils.resolveUri(this.baseUri, this.sseEndpoint);
}

private void subscribeSse(final Function<Mono<JSONRPCMessage>, Mono<JSONRPCMessage>> handler, MonoSink<Void> sink) {
final URI clientUri = Utils.resolveUri(this.baseUri, this.sseEndpoint);
sseClient.subscribe(clientUri.toString(), new FlowSseClient.SseEventHandler() {
@Override
public void onEvent(SseEvent event) {
if (isClosing) {
if (state.get() == TransportState.CLOSING || state.get() == TransportState.DISCONNECTED) {
return;
}

sink.success();
try {
if (ENDPOINT_EVENT_TYPE.equals(event.type())) {
String endpoint = event.data();
messageEndpoint.set(endpoint);
closeLatch.countDown();
future.complete(null);
}
else if (MESSAGE_EVENT_TYPE.equals(event.type())) {
JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, event.data());
handler.apply(Mono.just(message)).subscribe();
}
else {
logger.error("Received unrecognized SSE event type: {}", event.type());
switch (event.type()) {
case ENDPOINT_EVENT_TYPE -> {
messageEndpoint.set(event.data());
state.set(TransportState.CONNECTED);
}
case MESSAGE_EVENT_TYPE -> {
JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, event.data());
handler.apply(Mono.just(message)).subscribe();
}
default -> logger.error("Received unrecognized SSE event type: {}", event.type());
}
}
catch (IOException e) {
catch (Exception e) {
logger.error("Error processing SSE event", e);
future.completeExceptionally(e);
sink.error(new McpError("Error processing SSE event"));
}
}

@Override
public void onError(Throwable error) {
if (!isClosing) {
if (state.get() != TransportState.CLOSING) {
logger.error("SSE connection error", error);
future.completeExceptionally(error);
sink.error(error);
}
}
});

return Mono.fromFuture(future);
}

/**
Expand All @@ -394,44 +387,44 @@ public void onError(Throwable error) {
*/
@Override
public Mono<Void> sendMessage(JSONRPCMessage message) {
if (isClosing) {
if (state.get() == TransportState.CLOSING || state.get() == TransportState.DISCONNECTED) {
return Mono.empty();
}

try {
if (!closeLatch.await(10, TimeUnit.SECONDS)) {
return Mono.error(new McpError("Failed to wait for the message endpoint"));
return Mono.defer(() -> {
if (messageEndpoint.get() == null) {
return Mono.error(new McpError("No message endpoint available"));
}
}
catch (InterruptedException e) {
return Mono.error(new McpError("Failed to wait for the message endpoint"));
}

String endpoint = messageEndpoint.get();
if (endpoint == null) {
return Mono.error(new McpError("No message endpoint available"));
}
return serializeMessage(message).flatMap(body -> sendHttpPost(messageEndpoint.get(), body))
.doOnNext(this::logIfNotOk)
.doOnError(err -> logger.error("Error sending message", err))
.then();

}).retryWhen(Retry.fixedDelay(3, Duration.ofSeconds(3)).filter(err -> messageEndpoint.get() == null));
}

private Mono<String> serializeMessage(final JSONRPCMessage message) {
try {
String jsonText = this.objectMapper.writeValueAsString(message);
URI requestUri = Utils.resolveUri(baseUri, endpoint);
HttpRequest request = this.requestBuilder.uri(requestUri)
.POST(HttpRequest.BodyPublishers.ofString(jsonText))
.build();

return Mono.fromFuture(
httpClient.sendAsync(request, HttpResponse.BodyHandlers.discarding()).thenAccept(response -> {
if (response.statusCode() != 200 && response.statusCode() != 201 && response.statusCode() != 202
&& response.statusCode() != 206) {
logger.error("Error sending message: {}", response.statusCode());
}
}));
return Mono.just(objectMapper.writeValueAsString(message));
}
catch (IOException e) {
if (!isClosing) {
return Mono.error(new RuntimeException("Failed to serialize message", e));
}
return Mono.empty();
return Mono.error(new McpError("Failed to serialize message"));
}
}

private Mono<HttpResponse<Void>> sendHttpPost(final String endpoint, final String body) {
final URI requestUri = Utils.resolveUri(baseUri, endpoint);
final HttpRequest request = requestBuilder.uri(requestUri)
.POST(HttpRequest.BodyPublishers.ofString(body))
.build();

return Mono.fromFuture(httpClient.sendAsync(request, HttpResponse.BodyHandlers.discarding()));
}

private void logIfNotOk(final HttpResponse<?> response) {
if (response.statusCode() != 200 && response.statusCode() != 201 && response.statusCode() != 202
&& response.statusCode() != 206) {
logger.error("Error sending message: {}", response.statusCode());
}
}

Expand All @@ -445,12 +438,10 @@ public Mono<Void> sendMessage(JSONRPCMessage message) {
*/
@Override
public Mono<Void> closeGracefully() {
state.set(TransportState.CLOSING);
return Mono.fromRunnable(() -> {
isClosing = true;
CompletableFuture<Void> future = connectionFuture.get();
if (future != null && !future.isDone()) {
future.cancel(true);
}
sseClient.close();
state.set(TransportState.DISCONNECTED);
});
}

Expand All @@ -466,4 +457,19 @@ public <T> T unmarshalFrom(Object data, TypeReference<T> typeRef) {
return this.objectMapper.convertValue(data, typeRef);
}

/**
* Get the current transport state.
* @return the current transport state
*/
public TransportState getState() {
return state.get();
}

// Enum to manage transport states
public enum TransportState {

DISCONNECTED, CONNECTING, CONNECTED, CLOSING

}

}
Loading