From b3146056e2d55defd0fe2126d7105bd89cdf0741 Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Mon, 28 Apr 2025 12:59:30 -0700 Subject: [PATCH] Add stateless Streamable HTTP support - This allows a single MCP session spanning multiple requests to be handled by different servers without sharing state - This does require the servers share data protection keys, but this is standard for ASP.NET Core cookies and antiforgery as well --- .../HttpMcpServerBuilderExtensions.cs | 1 + .../HttpMcpSession.cs | 30 +-- .../HttpServerTransportOptions.cs | 7 + .../McpEndpointRouteBuilderExtensions.cs | 35 ++-- .../SseHandler.cs | 5 +- .../StatelessSessionId.cs | 16 ++ .../StatelessSessionIdJsonContext.cs | 6 + .../StreamableHttpHandler.cs | 194 ++++++++++++++---- .../Protocol/Messages/JsonRpcRequest.cs | 3 +- .../StreamableHttpClientSessionTransport.cs | 2 +- .../Transport/StreamableHttpPostTransport.cs | 42 ++-- .../StreamableHttpServerTransport.cs | 35 +++- src/ModelContextProtocol/Server/McpServer.cs | 25 ++- .../Server/McpServerOptions.cs | 22 ++ .../HttpServerIntegrationTests.cs | 2 + .../MapMcpSseTests.cs | 24 +++ .../MapMcpStatelessTests.cs | 10 + .../MapMcpStreamableHttpTests.cs | 2 +- .../MapMcpTests.cs | 32 +-- .../SseIntegrationTests.cs | 2 +- .../StatelessServerIntegrationTests.cs | 16 ++ .../StatelessServerTests.cs | 69 +++++++ .../StreamableHttpServerConformanceTests.cs | 42 +++- .../StreamableHttpServerIntegrationTests.cs | 1 - .../Program.cs | 24 +++ 25 files changed, 501 insertions(+), 146 deletions(-) create mode 100644 src/ModelContextProtocol.AspNetCore/StatelessSessionId.cs create mode 100644 src/ModelContextProtocol.AspNetCore/StatelessSessionIdJsonContext.cs create mode 100644 tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStatelessTests.cs create mode 100644 tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs create mode 100644 tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs diff --git a/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs b/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs index 8bff4596..a8a63e49 100644 --- a/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs +++ b/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs @@ -26,6 +26,7 @@ public static IMcpServerBuilder WithHttpTransport(this IMcpServerBuilder builder builder.Services.TryAddSingleton(); builder.Services.TryAddSingleton(); builder.Services.AddHostedService(); + builder.Services.AddDataProtection(); if (configureOptions is not null) { diff --git a/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs b/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs index 1b854b94..0903dda6 100644 --- a/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs +++ b/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs @@ -4,7 +4,11 @@ namespace ModelContextProtocol.AspNetCore; -internal sealed class HttpMcpSession(string sessionId, TTransport transport, ClaimsPrincipal user, TimeProvider timeProvider) : IAsyncDisposable +internal sealed class HttpMcpSession( + string sessionId, + TTransport transport, + (string Type, string Value, string Issuer)? userIdClaim, + TimeProvider timeProvider) : IAsyncDisposable where TTransport : ITransport { private int _referenceCount; @@ -13,7 +17,7 @@ internal sealed class HttpMcpSession(string sessionId, TTransport tr public string Id { get; } = sessionId; public TTransport Transport { get; } = transport; - public (string Type, string Value, string Issuer)? UserIdClaim { get; } = GetUserIdClaim(user); + public (string Type, string Value, string Issuer)? UserIdClaim { get; } = userIdClaim; public CancellationToken SessionClosed => _disposeCts.Token; @@ -63,27 +67,7 @@ public async ValueTask DisposeAsync() } public bool HasSameUserId(ClaimsPrincipal user) - => UserIdClaim == GetUserIdClaim(user); - - // SignalR only checks for ClaimTypes.NameIdentifier in HttpConnectionDispatcher, but AspNetCore.Antiforgery checks that plus the sub and UPN claims. - // However, we short-circuit unlike antiforgery since we expect to call this to verify MCP messages a lot more frequently than - // verifying antiforgery tokens from
posts. - private static (string Type, string Value, string Issuer)? GetUserIdClaim(ClaimsPrincipal user) - { - if (user?.Identity?.IsAuthenticated != true) - { - return null; - } - - var claim = user.FindFirst(ClaimTypes.NameIdentifier) ?? user.FindFirst("sub") ?? user.FindFirst(ClaimTypes.Upn); - - if (claim is { } idClaim) - { - return (idClaim.Type, idClaim.Value, idClaim.Issuer); - } - - return null; - } + => UserIdClaim == StreamableHttpHandler.GetUserIdClaim(user); private sealed class UnreferenceDisposable(HttpMcpSession session, TimeProvider timeProvider) : IDisposable { diff --git a/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs b/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs index 4880714c..df83ff6d 100644 --- a/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs +++ b/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs @@ -22,6 +22,13 @@ public class HttpServerTransportOptions /// public Func? RunSessionHandler { get; set; } + /// + /// Gets or sets whether the server should run in a stateless mode that does not require all requests for a given session + /// to arrive to the same ASP.NET Core application process. If true, the /sse endpoint will be disabled, and + /// client capabilities will be round-tripped as part of the mcp-session-id header instead of stored in memory. Defaults to false. + /// + public bool Stateless { get; set; } + /// /// Represents the duration of time the server will wait between any active requests before timing out an /// MCP session. This is checked in background every 5 seconds. A client trying to resume a session will diff --git a/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs b/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs index 0eefa52f..1e60d2aa 100644 --- a/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs +++ b/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs @@ -35,20 +35,27 @@ public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpo .WithMetadata(new AcceptsMetadata(["application/json"])) .WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status200OK, contentTypes: ["text/event-stream"])) .WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status202Accepted)); - streamableHttpGroup.MapGet("", streamableHttpHandler.HandleGetRequestAsync) - .WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status200OK, contentTypes: ["text/event-stream"])); - streamableHttpGroup.MapDelete("", streamableHttpHandler.HandleDeleteRequestAsync); - - // Map legacy HTTP with SSE endpoints. - var sseHandler = endpoints.ServiceProvider.GetRequiredService(); - var sseGroup = mcpGroup.MapGroup("") - .WithDisplayName(b => $"MCP HTTP with SSE | {b.DisplayName}"); - - sseGroup.MapGet("/sse", sseHandler.HandleSseRequestAsync) - .WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status200OK, contentTypes: ["text/event-stream"])); - sseGroup.MapPost("/message", sseHandler.HandleMessageRequestAsync) - .WithMetadata(new AcceptsMetadata(["application/json"])) - .WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status202Accepted)); + + if (!streamableHttpHandler.HttpServerTransportOptions.Stateless) + { + // The GET and DELETE endpoints are not mapped in Stateless mode since there's no way to send unsolicited messages + // for the GET to handle, and there is no server-side state for the DELETE to clean up. + streamableHttpGroup.MapGet("", streamableHttpHandler.HandleGetRequestAsync) + .WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status200OK, contentTypes: ["text/event-stream"])); + streamableHttpGroup.MapDelete("", streamableHttpHandler.HandleDeleteRequestAsync); + + // Map legacy HTTP with SSE endpoints only if not in Stateless mode, because we cannot guarantee the /message requests + // will be handled by the same process as the /sse request. + var sseHandler = endpoints.ServiceProvider.GetRequiredService(); + var sseGroup = mcpGroup.MapGroup("") + .WithDisplayName(b => $"MCP HTTP with SSE | {b.DisplayName}"); + + sseGroup.MapGet("/sse", sseHandler.HandleSseRequestAsync) + .WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status200OK, contentTypes: ["text/event-stream"])); + sseGroup.MapPost("/message", sseHandler.HandleMessageRequestAsync) + .WithMetadata(new AcceptsMetadata(["application/json"])) + .WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status202Accepted)); + } return mcpGroup; } diff --git a/src/ModelContextProtocol.AspNetCore/SseHandler.cs b/src/ModelContextProtocol.AspNetCore/SseHandler.cs index 36efadef..cea6817e 100644 --- a/src/ModelContextProtocol.AspNetCore/SseHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/SseHandler.cs @@ -34,7 +34,10 @@ public async Task HandleSseRequestAsync(HttpContext context) var requestPath = (context.Request.PathBase + context.Request.Path).ToString(); var endpointPattern = requestPath[..(requestPath.LastIndexOf('/') + 1)]; await using var transport = new SseResponseStreamTransport(context.Response.Body, $"{endpointPattern}message?sessionId={sessionId}"); - await using var httpMcpSession = new HttpMcpSession(sessionId, transport, context.User, httpMcpServerOptions.Value.TimeProvider); + + var userIdClaim = StreamableHttpHandler.GetUserIdClaim(context.User); + await using var httpMcpSession = new HttpMcpSession(sessionId, transport, userIdClaim, httpMcpServerOptions.Value.TimeProvider); + if (!_sessions.TryAdd(sessionId, httpMcpSession)) { throw new UnreachableException($"Unreachable given good entropy! Session with ID '{sessionId}' has already been created."); diff --git a/src/ModelContextProtocol.AspNetCore/StatelessSessionId.cs b/src/ModelContextProtocol.AspNetCore/StatelessSessionId.cs new file mode 100644 index 00000000..73c206e9 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/StatelessSessionId.cs @@ -0,0 +1,16 @@ +using ModelContextProtocol.Protocol.Types; +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.AspNetCore; + +internal class StatelessSessionId +{ + [JsonPropertyName("capabilities")] + public ClientCapabilities? Capabilities { get; init; } + + [JsonPropertyName("clientInfo")] + public Implementation? ClientInfo { get; init; } + + [JsonPropertyName("userIdClaim")] + public (string Type, string Value, string Issuer)? UserIdClaim { get; init; } +} diff --git a/src/ModelContextProtocol.AspNetCore/StatelessSessionIdJsonContext.cs b/src/ModelContextProtocol.AspNetCore/StatelessSessionIdJsonContext.cs new file mode 100644 index 00000000..2690a3b1 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/StatelessSessionIdJsonContext.cs @@ -0,0 +1,6 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.AspNetCore; + +[JsonSerializable(typeof(StatelessSessionId))] +internal sealed partial class StatelessSessionIdJsonContext : JsonSerializerContext; diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs index 64b10d6d..072ad851 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs @@ -1,4 +1,5 @@ -using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.DataProtection; +using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.WebUtilities; using Microsoft.Extensions.Logging; @@ -11,7 +12,9 @@ using System.Collections.Concurrent; using System.Diagnostics; using System.IO.Pipelines; +using System.Security.Claims; using System.Security.Cryptography; +using System.Text.Json; using System.Text.Json.Serialization.Metadata; namespace ModelContextProtocol.AspNetCore; @@ -19,16 +22,24 @@ namespace ModelContextProtocol.AspNetCore; internal sealed class StreamableHttpHandler( IOptions mcpServerOptionsSnapshot, IOptionsFactory mcpServerOptionsFactory, - IOptions httpMcpServerOptions, + IOptions httpServerTransportOptions, + IDataProtectionProvider dataProtection, ILoggerFactory loggerFactory, IServiceProvider applicationServices) { + private const string StatelessSessionIdPurpose = "Microsoft.AspNetCore.StreamableHttpHandler.StatelessSessionId"; + private static readonly JsonTypeInfo s_errorTypeInfo = GetRequiredJsonTypeInfo(); + private static readonly MediaTypeHeaderValue s_applicationJsonMediaType = new("application/json"); private static readonly MediaTypeHeaderValue s_textEventStreamMediaType = new("text/event-stream"); public ConcurrentDictionary> Sessions { get; } = new(StringComparer.Ordinal); + public HttpServerTransportOptions HttpServerTransportOptions => httpServerTransportOptions.Value; + + private IDataProtector Protector { get; } = dataProtection.CreateProtector(StatelessSessionIdPurpose); + public async Task HandlePostRequestAsync(HttpContext context) { // The Streamable HTTP spec mandates the client MUST accept both application/json and text/event-stream. @@ -50,14 +61,28 @@ await WriteJsonRpcErrorAsync(context, return; } - using var _ = session.AcquireReference(); - InitializeSseResponse(context); - var wroteResponse = await session.Transport.HandlePostRequest(new HttpDuplexPipe(context), context.RequestAborted); - if (!wroteResponse) + try + { + using var _ = session.AcquireReference(); + + InitializeSseResponse(context); + var wroteResponse = await session.Transport.HandlePostRequest(new HttpDuplexPipe(context), context.RequestAborted); + if (!wroteResponse) + { + // We wound up writing nothing, so there should be no Content-Type response header. + context.Response.Headers.ContentType = (string?)null; + context.Response.StatusCode = StatusCodes.Status202Accepted; + } + } + finally { - // We wound up writing nothing, so there should be no Content-Type response header. - context.Response.Headers.ContentType = (string?)null; - context.Response.StatusCode = StatusCodes.Status202Accepted; + // Stateless sessions are 1:1 with HTTP requests and are outlived by the MCP session tracked by the mcp-session-id. + // Non-stateless sessions are 1:1 with the mcp-session-id and outlive the POST request. + // Non-stateless sessions get disposed by a DELETE request or the IdleTrackingBackgroundService. + if (HttpServerTransportOptions.Stateless) + { + await session.DisposeAsync(); + } } } @@ -108,27 +133,36 @@ public async Task HandleDeleteRequestAsync(HttpContext context) private async ValueTask?> GetSessionAsync(HttpContext context, string sessionId) { - if (Sessions.TryGetValue(sessionId, out var existingSession)) + HttpMcpSession? session; + + if (HttpServerTransportOptions.Stateless) { - if (!existingSession.HasSameUserId(context.User)) - { - await WriteJsonRpcErrorAsync(context, - "Forbidden: The currently authenticated user does not match the user who initiated the session.", - StatusCodes.Status403Forbidden); - return null; - } + var sessionJson = Protector.Unprotect(sessionId); + var statelessSessionId = JsonSerializer.Deserialize(sessionJson, StatelessSessionIdJsonContext.Default.StatelessSessionId); + var transport = new StreamableHttpServerTransport(); + session = await CreateSessionAsync(context, transport, sessionId, statelessSessionId); + } + else if (!Sessions.TryGetValue(sessionId, out session)) + { + // -32001 isn't part of the MCP standard, but this is what the typescript-sdk currently does. + // One of the few other usages I found was from some Ethereum JSON-RPC documentation and this + // JSON-RPC library from Microsoft called StreamJsonRpc where it's called JsonRpcErrorCode.NoMarshaledObjectFound + // https://learn.microsoft.com/dotnet/api/streamjsonrpc.protocol.jsonrpcerrorcode?view=streamjsonrpc-2.9#fields + await WriteJsonRpcErrorAsync(context, "Session not found", StatusCodes.Status404NotFound, 32001); + return null; + } - context.Response.Headers["mcp-session-id"] = existingSession.Id; - context.Features.Set(existingSession.Server); - return existingSession; + if (!session.HasSameUserId(context.User)) + { + await WriteJsonRpcErrorAsync(context, + "Forbidden: The currently authenticated user does not match the user who initiated the session.", + StatusCodes.Status403Forbidden); + return null; } - // -32001 isn't part of the MCP standard, but this is what the typescript-sdk currently does. - // One of the few other usages I found was from some Ethereum JSON-RPC documentation and this - // JSON-RPC library from Microsoft called StreamJsonRpc where it's called JsonRpcErrorCode.NoMarshaledObjectFound - // https://learn.microsoft.com/dotnet/api/streamjsonrpc.protocol.jsonrpcerrorcode?view=streamjsonrpc-2.9#fields - await WriteJsonRpcErrorAsync(context, "Session not found", StatusCodes.Status404NotFound, 32001); - return null; + context.Response.Headers["mcp-session-id"] = session.Id; + context.Features.Set(session.Server); + return session; } private async ValueTask?> GetOrCreateSessionAsync(HttpContext context) @@ -137,14 +171,7 @@ await WriteJsonRpcErrorAsync(context, if (string.IsNullOrEmpty(sessionId)) { - var session = await CreateSessionAsync(context); - - if (!Sessions.TryAdd(session.Id, session)) - { - throw new UnreachableException($"Unreachable given good entropy! Session with ID '{sessionId}' has already been created."); - } - - return session; + return await StartNewSessionAsync(context); } else { @@ -152,29 +179,72 @@ await WriteJsonRpcErrorAsync(context, } } - private async ValueTask> CreateSessionAsync(HttpContext context) + private async ValueTask> StartNewSessionAsync(HttpContext context) { - var sessionId = MakeNewSessionId(); - context.Response.Headers["mcp-session-id"] = sessionId; + string sessionId; + var transport = new StreamableHttpServerTransport(); + + if (!HttpServerTransportOptions.Stateless) + { + sessionId = MakeNewSessionId(); + context.Response.Headers["mcp-session-id"] = sessionId; + } + else + { + // "(uninitialized stateless id)" is not written anywhere. We delay writing th mcp-session-id + // until after we receive the initialize request with the client info we need to serialize. + sessionId = "(uninitialized stateless id)"; + ScheduleStatelessSessionIdWrite(context, transport); + } + + var session = await CreateSessionAsync(context, transport, sessionId); + // The HttpMcpSession is not stored between requests in stateless mode. Instead the session is recreated from the mcp-session-id. + if (!HttpServerTransportOptions.Stateless) + { + if (!Sessions.TryAdd(sessionId, session)) + { + throw new UnreachableException($"Unreachable given good entropy! Session with ID '{sessionId}' has already been created."); + } + } + + return session; + } + + private async ValueTask> CreateSessionAsync( + HttpContext context, + StreamableHttpServerTransport transport, + string sessionId, + StatelessSessionId? statelessId = null) + { var mcpServerOptions = mcpServerOptionsSnapshot.Value; - if (httpMcpServerOptions.Value.ConfigureSessionOptions is { } configureSessionOptions) + if (statelessId is not null || HttpServerTransportOptions.ConfigureSessionOptions is not null) { mcpServerOptions = mcpServerOptionsFactory.Create(Options.DefaultName); - await configureSessionOptions(context, mcpServerOptions, context.RequestAborted); + + if (statelessId is not null) + { + mcpServerOptions.KnownClientInfo = statelessId.ClientInfo; + mcpServerOptions.KnownClientCapabilities = statelessId.Capabilities; + } + + if (HttpServerTransportOptions.ConfigureSessionOptions is { } configureSessionOptions) + { + await configureSessionOptions(context, mcpServerOptions, context.RequestAborted); + } } - var transport = new StreamableHttpServerTransport(); // Use application instead of request services, because the session will likely outlive the first initialization request. var server = McpServerFactory.Create(transport, mcpServerOptions, loggerFactory, applicationServices); context.Features.Set(server); - var session = new HttpMcpSession(sessionId, transport, context.User, httpMcpServerOptions.Value.TimeProvider) + var userIdClaim = statelessId?.UserIdClaim ?? GetUserIdClaim(context.User); + var session = new HttpMcpSession(sessionId, transport, userIdClaim, HttpServerTransportOptions.TimeProvider) { Server = server, }; - var runSessionAsync = httpMcpServerOptions.Value.RunSessionHandler ?? RunSessionAsync; + var runSessionAsync = HttpServerTransportOptions.RunSessionHandler ?? RunSessionAsync; session.ServerRunTask = runSessionAsync(context, server, session.SessionClosed); return session; @@ -210,9 +280,49 @@ internal static string MakeNewSessionId() return WebEncoders.Base64UrlEncode(buffer); } + private void ScheduleStatelessSessionIdWrite(HttpContext context, StreamableHttpServerTransport transport) + { + context.Response.OnStarting(() => + { + var statelessId = new StatelessSessionId + { + ClientInfo = transport.ClientInfo, + Capabilities = transport.ClientCapabilities, + UserIdClaim = GetUserIdClaim(context.User), + }; + + var sessionJson = JsonSerializer.Serialize(statelessId, StatelessSessionIdJsonContext.Default.StatelessSessionId); + var sessionId = Protector.Protect(sessionJson); + + context.Response.Headers["mcp-session-id"] = sessionId; + + return Task.CompletedTask; + }); + } + internal static Task RunSessionAsync(HttpContext httpContext, IMcpServer session, CancellationToken requestAborted) => session.RunAsync(requestAborted); + // SignalR only checks for ClaimTypes.NameIdentifier in HttpConnectionDispatcher, but AspNetCore.Antiforgery checks that plus the sub and UPN claims. + // However, we short-circuit unlike antiforgery since we expect to call this to verify MCP messages a lot more frequently than + // verifying antiforgery tokens from posts. + internal static (string Type, string Value, string Issuer)? GetUserIdClaim(ClaimsPrincipal user) + { + if (user?.Identity?.IsAuthenticated != true) + { + return null; + } + + var claim = user.FindFirst(ClaimTypes.NameIdentifier) ?? user.FindFirst("sub") ?? user.FindFirst(ClaimTypes.Upn); + + if (claim is { } idClaim) + { + return (idClaim.Type, idClaim.Value, idClaim.Issuer); + } + + return null; + } + private static JsonTypeInfo GetRequiredJsonTypeInfo() => (JsonTypeInfo)McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(T)); private sealed class HttpDuplexPipe(HttpContext context) : IDuplexPipe diff --git a/src/ModelContextProtocol/Protocol/Messages/JsonRpcRequest.cs b/src/ModelContextProtocol/Protocol/Messages/JsonRpcRequest.cs index ff7a4504..6e356cf2 100644 --- a/src/ModelContextProtocol/Protocol/Messages/JsonRpcRequest.cs +++ b/src/ModelContextProtocol/Protocol/Messages/JsonRpcRequest.cs @@ -35,7 +35,8 @@ internal JsonRpcRequest WithId(RequestId id) JsonRpc = JsonRpc, Id = id, Method = Method, - Params = Params + Params = Params, + RelatedTransport = RelatedTransport, }; } } diff --git a/src/ModelContextProtocol/Protocol/Transport/StreamableHttpClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StreamableHttpClientSessionTransport.cs index 7697c28e..a442d5b3 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StreamableHttpClientSessionTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StreamableHttpClientSessionTransport.cs @@ -57,7 +57,7 @@ public override async Task SendMessageAsync( cancellationToken = sendCts.Token; #if NET - using var content = JsonContent.Create(message, McpJsonUtilities.DefaultOptions.GetTypeInfo()); + using var content = JsonContent.Create(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage); #else using var content = new StringContent( JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage), diff --git a/src/ModelContextProtocol/Protocol/Transport/StreamableHttpPostTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StreamableHttpPostTransport.cs index 4cdb30b3..e3cdb404 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StreamableHttpPostTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StreamableHttpPostTransport.cs @@ -1,7 +1,5 @@ using ModelContextProtocol.Protocol.Messages; -using ModelContextProtocol.Utils; using ModelContextProtocol.Utils.Json; -using System.Buffers; using System.IO.Pipelines; using System.Net.ServerSentEvents; using System.Runtime.CompilerServices; @@ -14,12 +12,11 @@ namespace ModelContextProtocol.Protocol.Transport; /// Handles processing the request/response body pairs for the Streamable HTTP transport. /// This is typically used via . /// -internal sealed class StreamableHttpPostTransport(ChannelWriter? incomingChannel, IDuplexPipe httpBodies) : ITransport +internal sealed class StreamableHttpPostTransport(StreamableHttpServerTransport parentTransport, IDuplexPipe httpBodies) : ITransport { private readonly SseWriter _sseWriter = new(); - private readonly HashSet _pendingRequests = []; + private RequestId _pendingRequest; - // REVIEW: Should we introduce a send-only interface for RelatedTransport? public ChannelReader MessageReader => throw new NotSupportedException("JsonRpcMessage.RelatedTransport should only be used for sending messages."); /// @@ -29,15 +26,11 @@ internal sealed class StreamableHttpPostTransport(ChannelWriter? /// public async ValueTask RunAsync(CancellationToken cancellationToken) { - // The incomingChannel is null to handle the potential client GET request to handle unsolicited JsonRpcMessages. - if (incomingChannel is not null) - { - var message = await JsonSerializer.DeserializeAsync(httpBodies.Input.AsStream(), - McpJsonUtilities.JsonContext.Default.JsonRpcMessage, cancellationToken).ConfigureAwait(false); - await OnMessageReceivedAsync(message, cancellationToken).ConfigureAwait(false); - } + var message = await JsonSerializer.DeserializeAsync(httpBodies.Input.AsStream(), + McpJsonUtilities.JsonContext.Default.JsonRpcMessage, cancellationToken).ConfigureAwait(false); + await OnMessageReceivedAsync(message, cancellationToken).ConfigureAwait(false); - if (_pendingRequests.Count == 0) + if (_pendingRequest.Id is null) { return false; } @@ -63,13 +56,10 @@ public async ValueTask DisposeAsync() { yield return message; - if (message.Data is JsonRpcMessageWithId response) + if (message.Data is JsonRpcMessageWithId response && response.Id == _pendingRequest) { - if (_pendingRequests.Remove(response.Id) && _pendingRequests.Count == 0) - { - // Complete the SSE response stream now that all pending requests have been processed. - break; - } + // Complete the SSE response stream now that all pending requests have been processed. + break; } } } @@ -83,13 +73,19 @@ private async ValueTask OnMessageReceivedAsync(JsonRpcMessage? message, Cancella if (message is JsonRpcRequest request) { - _pendingRequests.Add(request.Id); + _pendingRequest = request.Id; + + // Store client capabilities so they can be serialized by "stateless" callers for use in later requests. + if (request.Method == RequestMethods.Initialize) + { + var initializeRequestParams = JsonSerializer.Deserialize(request.Params, McpJsonUtilities.JsonContext.Default.InitializeRequestParams); + parentTransport.ClientCapabilities = initializeRequestParams?.Capabilities; + parentTransport.ClientInfo = initializeRequestParams?.ClientInfo; + } } message.RelatedTransport = this; - // Really an assertion. This doesn't get called when incomingChannel is null for GET requests. - Throw.IfNull(incomingChannel); - await incomingChannel.WriteAsync(message, cancellationToken).ConfigureAwait(false); + await parentTransport.MessageWriter.WriteAsync(message, cancellationToken).ConfigureAwait(false); } } diff --git a/src/ModelContextProtocol/Protocol/Transport/StreamableHttpServerTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StreamableHttpServerTransport.cs index aa9e522d..9e8cb1d6 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StreamableHttpServerTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StreamableHttpServerTransport.cs @@ -1,4 +1,5 @@ using ModelContextProtocol.Protocol.Messages; +using ModelContextProtocol.Protocol.Types; using System.IO.Pipelines; using System.Threading.Channels; @@ -36,6 +37,21 @@ public sealed class StreamableHttpServerTransport : ITransport private int _getRequestStarted; + /// + /// Gets the capabilities supported by the client if it was received by . + /// + public ClientCapabilities? ClientCapabilities { get; internal set; } + + /// + /// Gets the version and implementation information of the connected client if it was received by . + /// + public Implementation? ClientInfo { get; internal set; } + + /// + public ChannelReader MessageReader => _incomingChannel.Reader; + + internal ChannelWriter MessageWriter => _incomingChannel.Writer; + /// /// Handles an optional SSE GET request a client using the Streamable HTTP transport might make by /// writing any unsolicited JSON-RPC messages sent via @@ -63,20 +79,17 @@ public async Task HandleGetRequest(Stream sseResponseStream, CancellationToken c /// The duplex pipe facilitates the reading and writing of HTTP request and response data. /// This token allows for the operation to be canceled if needed. /// - /// True, if data was written to the respond body. + /// True, if data was written to the response body. /// False, if nothing was written because the request body did not contain any messages to respond to. /// The HTTP application should typically respond with an empty "202 Accepted" response in this scenario. /// public async Task HandlePostRequest(IDuplexPipe httpBodies, CancellationToken cancellationToken) { using var postCts = CancellationTokenSource.CreateLinkedTokenSource(_disposeCts.Token, cancellationToken); - await using var postTransport = new StreamableHttpPostTransport(_incomingChannel.Writer, httpBodies); + await using var postTransport = new StreamableHttpPostTransport(this, httpBodies); return await postTransport.RunAsync(postCts.Token).ConfigureAwait(false); } - /// - public ChannelReader MessageReader => _incomingChannel.Reader; - /// public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) { @@ -86,14 +99,20 @@ public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken can /// public async ValueTask DisposeAsync() { - _disposeCts.Cancel(); try { - await _sseWriter.DisposeAsync().ConfigureAwait(false); + await _disposeCts.CancelAsync(); } finally { - _disposeCts.Dispose(); + try + { + await _sseWriter.DisposeAsync().ConfigureAwait(false); + } + finally + { + _disposeCts.Dispose(); + } } } } diff --git a/src/ModelContextProtocol/Server/McpServer.cs b/src/ModelContextProtocol/Server/McpServer.cs index ae0e7afc..4a55a0af 100644 --- a/src/ModelContextProtocol/Server/McpServer.cs +++ b/src/ModelContextProtocol/Server/McpServer.cs @@ -26,7 +26,8 @@ internal sealed class McpServer : McpEndpoint, IMcpServer private readonly EventHandler? _toolsChangedDelegate; private readonly EventHandler? _promptsChangedDelegate; - private string _endpointName; + private readonly string _serverOnlyEndpointName; + private string? _endpointName; private int _started; /// Holds a boxed value for the server. @@ -56,9 +57,13 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? _sessionTransport = transport; ServerOptions = options; Services = serviceProvider; - _endpointName = $"Server ({options.ServerInfo?.Name ?? DefaultImplementation.Name} {options.ServerInfo?.Version ?? DefaultImplementation.Version})"; + _serverOnlyEndpointName = $"Server ({options.ServerInfo?.Name ?? DefaultImplementation.Name} {options.ServerInfo?.Version ?? DefaultImplementation.Version})"; _servicesScopePerRequest = options.ScopeRequests; + ClientCapabilities = options.KnownClientCapabilities; + ClientInfo = options.KnownClientInfo; + UpdateEndpointNameWithClientInfo(); + // Configure all request handlers based on the supplied options. SetInitializeHandler(options); SetToolsHandler(options); @@ -114,7 +119,7 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? public IServiceProvider? Services { get; } /// - public override string EndpointName => _endpointName; + public override string EndpointName => _endpointName ?? _serverOnlyEndpointName; /// public LoggingLevel? LoggingLevel => _loggingLevel?.Value; @@ -172,8 +177,8 @@ private void SetInitializeHandler(McpServerOptions options) ClientInfo = request?.ClientInfo; // Use the ClientInfo to update the session EndpointName for logging. - _endpointName = $"{_endpointName}, Client ({ClientInfo?.Name} {ClientInfo?.Version})"; - GetSessionOrThrow().EndpointName = _endpointName; + UpdateEndpointNameWithClientInfo(); + GetSessionOrThrow().EndpointName = EndpointName; return new InitializeResult { @@ -551,6 +556,16 @@ private void SetHandler( requestTypeInfo, responseTypeInfo); } + private void UpdateEndpointNameWithClientInfo() + { + if (ClientInfo is null) + { + return; + } + + _endpointName = $"{_serverOnlyEndpointName}, Client ({ClientInfo.Name} {ClientInfo.Version})"; + } + /// Maps a to a . internal static LoggingLevel ToLoggingLevel(LogLevel level) => level switch diff --git a/src/ModelContextProtocol/Server/McpServerOptions.cs b/src/ModelContextProtocol/Server/McpServerOptions.cs index 6880d2f2..4c820a49 100644 --- a/src/ModelContextProtocol/Server/McpServerOptions.cs +++ b/src/ModelContextProtocol/Server/McpServerOptions.cs @@ -65,4 +65,26 @@ public class McpServerOptions /// handler will be invoked within a new service scope. /// public bool ScopeRequests { get; set; } = true; + + /// + /// Gets or sets preexisting knowledge about the client including its name and version to help support + /// stateless Streamable HTTP servers that encode this knowledge in the mcp-session-id header. + /// + /// + /// + /// When not specified, this information sourced from the client's initialize request. + /// + /// + public Implementation? KnownClientInfo { get; set; } + + /// + /// Gets or sets preexisting knowledge about the client client capabilities to help support + /// stateless Streamable HTTP servers that encode this knowledge in the mcp-session-id header. + /// + /// + /// + /// When not specified, this information sourced from the client's initialize request. + /// + /// + public ClientCapabilities? KnownClientCapabilities { get; set; } } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs index 57a6c6ad..fe7c9d03 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs @@ -207,6 +207,8 @@ await Assert.ThrowsAsync(() => [Fact] public async Task Sampling_Sse_TestServer() { + Assert.SkipWhen(GetType() == typeof(StatelessServerIntegrationTests), "Sampling is not supported in stateless mode."); + // arrange // Set up the sampling handler int samplingHandlerCalls = 0; diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpSseTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpSseTests.cs index d385623a..1d4917ba 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpSseTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpSseTests.cs @@ -8,6 +8,30 @@ public class MapMcpSseTests(ITestOutputHelper outputHelper) : MapMcpTests(output { protected override bool UseStreamableHttp => false; + [Theory] + [InlineData("/mcp")] + [InlineData("/mcp/secondary")] + public async Task Allows_Customizing_Route(string pattern) + { + Builder.Services.AddMcpServer().WithHttpTransport(ConfigureStateless); + await using var app = Builder.Build(); + + app.MapMcp(pattern); + + await app.StartAsync(TestContext.Current.CancellationToken); + + using var response = await HttpClient.GetAsync($"http://localhost{pattern}/sse", HttpCompletionOption.ResponseHeadersRead, TestContext.Current.CancellationToken); + response.EnsureSuccessStatusCode(); + using var sseStream = await response.Content.ReadAsStreamAsync(TestContext.Current.CancellationToken); + using var sseStreamReader = new StreamReader(sseStream, System.Text.Encoding.UTF8); + var eventLine = await sseStreamReader.ReadLineAsync(TestContext.Current.CancellationToken); + var dataLine = await sseStreamReader.ReadLineAsync(TestContext.Current.CancellationToken); + Assert.NotNull(eventLine); + Assert.Equal("event: endpoint", eventLine); + Assert.NotNull(dataLine); + Assert.Equal($"data: {pattern}/message", dataLine[..dataLine.IndexOf('?')]); + } + [Theory] [InlineData("/a", "/a/sse")] [InlineData("/a/", "/a/sse")] diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStatelessTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStatelessTests.cs new file mode 100644 index 00000000..030701c7 --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStatelessTests.cs @@ -0,0 +1,10 @@ +using Microsoft.AspNetCore.Builder; +using Microsoft.Extensions.DependencyInjection; + +namespace ModelContextProtocol.AspNetCore.Tests; + +public class MapMcpStatelessTests(ITestOutputHelper outputHelper) : MapMcpStreamableHttpTests(outputHelper) +{ + protected override bool UseStreamableHttp => true; + protected override bool Stateless => true; +} diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs index 30632a8e..0b2f68bb 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs @@ -22,7 +22,7 @@ public async Task CanConnect_WithMcpClient_AfterCustomizingRoute(string routePat Name = "TestCustomRouteServer", Version = "1.0.0", }; - }).WithHttpTransport(); + }).WithHttpTransport(ConfigureStateless); await using var app = Builder.Build(); app.MapMcp(routePattern); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs index 70b028e2..dd654071 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs @@ -15,6 +15,13 @@ public abstract class MapMcpTests(ITestOutputHelper testOutputHelper) : KestrelI { protected abstract bool UseStreamableHttp { get; } + protected virtual bool Stateless => false; + + protected void ConfigureStateless(HttpServerTransportOptions options) + { + options.Stateless = Stateless; + } + protected async Task ConnectAsync(string? path = null) { path ??= UseStreamableHttp ? "/" : "/sse"; @@ -37,34 +44,11 @@ public async Task MapMcp_ThrowsInvalidOperationException_IfWithHttpTransportIsNo Assert.StartsWith("You must call WithHttpTransport()", exception.Message); } - [Theory] - [InlineData("/mcp")] - [InlineData("/mcp/secondary")] - public async Task Allows_Customizing_Route(string pattern) - { - Builder.Services.AddMcpServer().WithHttpTransport(); - await using var app = Builder.Build(); - - app.MapMcp(pattern); - - await app.StartAsync(TestContext.Current.CancellationToken); - - using var response = await HttpClient.GetAsync($"http://localhost{pattern}/sse", HttpCompletionOption.ResponseHeadersRead, TestContext.Current.CancellationToken); - response.EnsureSuccessStatusCode(); - using var sseStream = await response.Content.ReadAsStreamAsync(TestContext.Current.CancellationToken); - using var sseStreamReader = new StreamReader(sseStream, System.Text.Encoding.UTF8); - var eventLine = await sseStreamReader.ReadLineAsync(TestContext.Current.CancellationToken); - var dataLine = await sseStreamReader.ReadLineAsync(TestContext.Current.CancellationToken); - Assert.NotNull(eventLine); - Assert.Equal("event: endpoint", eventLine); - Assert.NotNull(dataLine); - Assert.Equal($"data: {pattern}/message", dataLine[..dataLine.IndexOf('?')]); - } [Fact] public async Task Messages_FromNewUser_AreRejected() { - Builder.Services.AddMcpServer().WithHttpTransport().WithTools(); + Builder.Services.AddMcpServer().WithHttpTransport(ConfigureStateless).WithTools(); // Add an authentication scheme that will send a 403 Forbidden response. Builder.Services.AddAuthentication().AddBearerToken(); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs index b659ff17..7733c836 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs @@ -17,7 +17,7 @@ namespace ModelContextProtocol.AspNetCore.Tests; public partial class SseIntegrationTests(ITestOutputHelper outputHelper) : KestrelInMemoryTest(outputHelper) { - private SseClientTransportOptions DefaultTransportOptions = new() + private readonly SseClientTransportOptions DefaultTransportOptions = new() { Endpoint = new Uri("http://localhost/sse"), Name = "In-memory Test Server", diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs new file mode 100644 index 00000000..03ceacd7 --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs @@ -0,0 +1,16 @@ +using ModelContextProtocol.Protocol.Transport; +using System.Text; + +namespace ModelContextProtocol.AspNetCore.Tests; + +public class StatelessServerIntegrationTests(SseServerIntegrationTestFixture fixture, ITestOutputHelper testOutputHelper) + : StreamableHttpServerIntegrationTests(fixture, testOutputHelper) + +{ + protected override SseClientTransportOptions ClientTransportOptions => new() + { + Endpoint = new Uri("http://localhost/stateless"), + Name = "TestServer", + UseStreamableHttp = true, + }; +} diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs new file mode 100644 index 00000000..06bd35f9 --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs @@ -0,0 +1,69 @@ +using Microsoft.AspNetCore.Builder; +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.AspNetCore.Tests.Utils; +using ModelContextProtocol.Protocol.Types; +using System.Net; + +namespace ModelContextProtocol.AspNetCore.Tests; + +public class StatelessServerTests(ITestOutputHelper outputHelper) : KestrelInMemoryTest(outputHelper), IAsyncDisposable +{ + private WebApplication? _app; + + private async Task StartAsync() + { + Builder.Services.AddMcpServer(mcpServerOptions => + { + mcpServerOptions.ServerInfo = new Implementation + { + Name = nameof(StreamableHttpServerConformanceTests), + Version = "73", + }; + }).WithHttpTransport(httpServerTransportOptions => + { + httpServerTransportOptions.Stateless = true; + }); + + _app = Builder.Build(); + + _app.MapMcp(); + + await _app.StartAsync(TestContext.Current.CancellationToken); + + HttpClient.DefaultRequestHeaders.Accept.Add(new("application/json")); + HttpClient.DefaultRequestHeaders.Accept.Add(new("text/event-stream")); + } + + public async ValueTask DisposeAsync() + { + if (_app is not null) + { + await _app.DisposeAsync(); + } + base.Dispose(); + } + + [Fact] + public async Task EnablingStatelessMode_Disables_SseEndpoints() + { + await StartAsync(); + + using var sseResponse = await HttpClient.GetAsync("/sse", TestContext.Current.CancellationToken); + Assert.Equal(HttpStatusCode.NotFound, sseResponse.StatusCode); + + using var messageResponse = await HttpClient.PostAsync("/message", new StringContent(""), TestContext.Current.CancellationToken); + Assert.Equal(HttpStatusCode.NotFound, messageResponse.StatusCode); + } + + [Fact] + public async Task EnablingStatelessMode_Disables_GetAndDeleteEndpoints() + { + await StartAsync(); + + using var getResponse = await HttpClient.GetAsync("/", TestContext.Current.CancellationToken); + Assert.Equal(HttpStatusCode.MethodNotAllowed, getResponse.StatusCode); + + using var deleteResponse = await HttpClient.DeleteAsync("/", TestContext.Current.CancellationToken); + Assert.Equal(HttpStatusCode.MethodNotAllowed, deleteResponse.StatusCode); + } +} diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerConformanceTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerConformanceTests.cs index 9e5ce6fa..196b5b61 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerConformanceTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerConformanceTests.cs @@ -71,7 +71,6 @@ public async Task NegativeNonInfiniteIdleTimeout_Throws_ArgumentOutOfRangeExcept Assert.Contains("IdleTimeout", ex.Message); } - [Fact] public async Task NegativeMaxIdleSessionCount_Throws_ArgumentOutOfRangeException() { @@ -360,6 +359,47 @@ public async Task Progress_IsReported_InSameSseResponseAsRpcResponse() Assert.Equal(11, currentSseItem); } + [Fact] + public async Task AsyncLocalSetInRunSessionHandlerCallback_Flows_ToAllToolCalls() + { + var asyncLocal = new AsyncLocal(); + var totalSessionCount = 0; + + Builder.Services.AddMcpServer() + .WithHttpTransport(options => + { + options.RunSessionHandler = async (httpContext, mcpServer, cancellationToken) => + { + asyncLocal.Value = $"RunSessionHandler ({totalSessionCount++})"; + await mcpServer.RunAsync(cancellationToken); + }; + }); + + Builder.Services.AddSingleton(McpServerTool.Create([McpServerTool(Name = "async-local-session")] () => asyncLocal.Value)); + + await StartAsync(); + + var firstSessionId = await CallInitializeAndValidateAsync(); + + async Task CallAsyncLocalToolAndValidateAsync(int expectedSessionIndex) + { + var response = await HttpClient.PostAsync("", JsonContent(CallTool("async-local-session")), TestContext.Current.CancellationToken); + var rpcResponse = await AssertSingleSseResponseAsync(response); + var callToolResponse = AssertType(rpcResponse.Result); + var callToolContent = Assert.Single(callToolResponse.Content); + Assert.Equal("text", callToolContent.Type); + Assert.Equal($"RunSessionHandler ({expectedSessionIndex})", callToolContent.Text); + } + + await CallAsyncLocalToolAndValidateAsync(expectedSessionIndex: 0); + + await CallInitializeAndValidateAsync(); + await CallAsyncLocalToolAndValidateAsync(expectedSessionIndex: 1); + + SetSessionId(firstSessionId); + await CallAsyncLocalToolAndValidateAsync(expectedSessionIndex: 0); + } + [Fact] public async Task IdleSessions_ArePruned_AfterIdleTimeout() { diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs index 9d304892..3abb1aa3 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs @@ -1,5 +1,4 @@ using ModelContextProtocol.Protocol.Transport; -using System.Net; using System.Text; namespace ModelContextProtocol.AspNetCore.Tests; diff --git a/tests/ModelContextProtocol.TestSseServer/Program.cs b/tests/ModelContextProtocol.TestSseServer/Program.cs index 72a271cf..88124f9d 100644 --- a/tests/ModelContextProtocol.TestSseServer/Program.cs +++ b/tests/ModelContextProtocol.TestSseServer/Program.cs @@ -3,6 +3,7 @@ using ModelContextProtocol.Server; using ModelContextProtocol.Utils.Json; using Serilog; +using System.Diagnostics; using System.Text; using System.Text.Json; @@ -378,6 +379,26 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st }; } + private static void HandleStatelessMcp(IApplicationBuilder app) + { + var serviceCollection = new ServiceCollection(); + serviceCollection.AddLogging(); + serviceCollection.AddSingleton(app.ApplicationServices.GetRequiredService()); + serviceCollection.AddSingleton(app.ApplicationServices.GetRequiredService()); + serviceCollection.AddRoutingCore(); + + serviceCollection.AddMcpServer(ConfigureOptions).WithHttpTransport(options => options.Stateless = true); + + var appBuilder = new ApplicationBuilder(serviceCollection.BuildServiceProvider()); + appBuilder.UseRouting(); + appBuilder.UseEndpoints(innerEndpoints => + { + innerEndpoints.MapMcp("/stateless"); + }); + + app.Run(appBuilder.Build()); + } + public static async Task MainAsync(string[] args, ILoggerProvider? loggerProvider = null, IConnectionListenerFactory? kestrelTransport = null, CancellationToken cancellationToken = default) { Console.WriteLine("Starting server..."); @@ -419,6 +440,9 @@ public static async Task MainAsync(string[] args, ILoggerProvider? loggerProvide app.UseRouting(); app.UseEndpoints(_ => { }); + // Handle the /stateless endpoint if no other endpoints have been matched by the call to UseRouting above. + HandleStatelessMcp(app); + app.MapMcp(); await app.RunAsync(cancellationToken);