Skip to content

feat: add a method in mcp session to provide the last request timestamp #174

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,14 @@
import java.io.BufferedReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;

import com.fasterxml.jackson.core.type.TypeReference;
Expand Down Expand Up @@ -102,6 +107,7 @@ public class HttpServletSseServerTransportProvider extends HttpServlet implement

/** Session factory for creating new sessions */
private McpServerSession.Factory sessionFactory;
private ScheduledFuture<?> removeSessionsDeprecarted;

/**
* Creates a new HttpServletSseServerTransportProvider instance with a custom SSE
Expand Down Expand Up @@ -131,6 +137,18 @@ public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String b
this.baseUrl = baseUrl;
this.messageEndpoint = messageEndpoint;
this.sseEndpoint = sseEndpoint;
this.removeSessionsDeprecarted = Executors.newScheduledThreadPool(1)
.scheduleAtFixedRate(() -> new ArrayList<>(sessions.values())
.forEach(session -> {
if (TimeUnit.MINUTES.convert(Duration.ofMillis(System.currentTimeMillis() - session.lastRequestTimestamp())) > 30L) {
// close the session if it has not received a request in the last 30 minutes
session.closeGracefully()
.doOnError(error ->
logger.warn("Failed to gracefully close the session {} while it has not received any request in the last 30 minutes." +
"Here is the following error msg: {}", session.getId(), error.getMessage()))
.subscribe();
}
}), 0, 1, TimeUnit.MINUTES);
}

/**
Expand Down Expand Up @@ -323,8 +341,8 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
public Mono<Void> closeGracefully() {
isClosing.set(true);
logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size());

return Flux.fromIterable(sessions.values()).flatMap(McpServerSession::closeGracefully).then();
return Flux.fromIterable(sessions.values()).flatMap(McpServerSession::closeGracefully)
.doOnTerminate(() -> this.removeSessionsDeprecarted.cancel(true)).then();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ public class McpClientSession implements McpSession {
/** Atomic counter for generating unique request IDs */
private final AtomicLong requestCounter = new AtomicLong(0);

/** To record the last request timestamp */
private final AtomicLong lastRequestTs = new AtomicLong(System.currentTimeMillis());

private final Disposable connection;

/**
Expand Down Expand Up @@ -135,6 +138,7 @@ public McpClientSession(Duration requestTimeout, McpClientTransport transport,
}
else if (message instanceof McpSchema.JSONRPCRequest request) {
logger.debug("Received request: {}", request);
lastRequestTs.set(System.currentTimeMillis());
handleIncomingRequest(request).subscribe(response -> transport.sendMessage(response).subscribe(),
error -> {
var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(),
Expand Down Expand Up @@ -286,4 +290,9 @@ public void close() {
transport.close();
}

@Override
public long lastRequestTimestamp() {
return this.lastRequestTs.get();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ public class McpServerSession implements McpSession {

private final AtomicInteger state = new AtomicInteger(STATE_UNINITIALIZED);

/** To record the last request timestamp */
private final AtomicLong lastRequestTs = new AtomicLong(System.currentTimeMillis());

/**
* Creates a new server session with the given parameters and the transport to use.
* @param id session id
Expand Down Expand Up @@ -169,6 +172,7 @@ public Mono<Void> handle(McpSchema.JSONRPCMessage message) {
}
else if (message instanceof McpSchema.JSONRPCRequest request) {
logger.debug("Received request: {}", request);
lastRequestTs.set(System.currentTimeMillis());
return handleIncomingRequest(request).onErrorResume(error -> {
var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null,
new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR,
Expand Down Expand Up @@ -277,6 +281,11 @@ public void close() {
this.transport.close();
}

@Override
public long lastRequestTimestamp() {
return lastRequestTs.get();
}

/**
* Request handler for the initialization request.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,4 +79,9 @@ default Mono<Void> sendNotification(String method) {
*/
void close();

/**
* @return get the timestamp of the last request the session received.
*/
long lastRequestTimestamp();

}