diff --git a/src/Common/Polyfills/System/Diagnostics/ProcessExtensions.cs b/src/Common/Polyfills/System/Diagnostics/ProcessExtensions.cs deleted file mode 100644 index acd06077..00000000 --- a/src/Common/Polyfills/System/Diagnostics/ProcessExtensions.cs +++ /dev/null @@ -1,37 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -namespace System.Diagnostics; - -internal static class ProcessExtensions -{ - public static void Kill(this Process process, bool entireProcessTree) - { - _ = entireProcessTree; - process.Kill(); - } - - public static async Task WaitForExitAsync(this Process process, CancellationToken cancellationToken = default) - { - if (process.HasExited) - { - return; - } - - var tcs = new TaskCompletionSource(); - void ProcessExitedHandler(object? sender, EventArgs e) => tcs.TrySetResult(true); - - try - { - process.EnableRaisingEvents = true; - process.Exited += ProcessExitedHandler; - - using var _ = cancellationToken.Register(() => tcs.TrySetCanceled(cancellationToken)); - await tcs.Task.ConfigureAwait(false); - } - finally - { - process.Exited -= ProcessExitedHandler; - } - } -} \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs index c083764a..ccdfc758 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs @@ -137,7 +137,7 @@ private async Task CloseAsync() } finally { - SetConnected(false); + SetDisconnected(); } } @@ -203,7 +203,7 @@ private async Task ReceiveMessagesAsync(CancellationToken cancellationToken) } finally { - SetConnected(false); + SetDisconnected(); } } @@ -251,7 +251,7 @@ private void HandleEndpointEvent(string data) _messageEndpoint = new Uri(_sseEndpoint, data); // Set connected state - SetConnected(true); + SetConnected(); _connectionEstablished.TrySetResult(true); } diff --git a/src/ModelContextProtocol/Protocol/Transport/StdioClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StdioClientSessionTransport.cs index 3bb5f312..d8ad6afb 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StdioClientSessionTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StdioClientSessionTransport.cs @@ -9,54 +9,51 @@ internal sealed class StdioClientSessionTransport : StreamClientSessionTransport { private readonly StdioClientTransportOptions _options; private readonly Process _process; + private readonly Queue _stderrRollingLog; + private int _cleanedUp = 0; - public StdioClientSessionTransport(StdioClientTransportOptions options, Process process, string endpointName, ILoggerFactory? loggerFactory) + public StdioClientSessionTransport(StdioClientTransportOptions options, Process process, string endpointName, Queue stderrRollingLog, ILoggerFactory? loggerFactory) : base(process.StandardInput, process.StandardOutput, endpointName, loggerFactory) { _process = process; _options = options; + _stderrRollingLog = stderrRollingLog; } /// - /// - /// - /// For stdio-based transports, this implementation first verifies that the underlying process - /// is still running before attempting to send the message. If the process has exited or cannot - /// be accessed, a is thrown with details about the failure. - /// - /// - /// After verifying the process state, this method delegates to the base class implementation - /// to handle the actual message serialization and transmission to the process's standard input stream. - /// - /// - /// - /// Thrown when the underlying process has exited or cannot be accessed. - /// public override async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) { - Exception? processException = null; - bool hasExited = false; try { - hasExited = _process.HasExited; + await base.SendMessageAsync(message, cancellationToken); } - catch (Exception e) + catch (IOException) { - processException = e; - hasExited = true; - } + // We failed to send due to an I/O error. If the server process has exited, which is then very likely the cause + // for the I/O error, we should throw an exception for that instead. + if (await GetUnexpectedExitExceptionAsync(cancellationToken).ConfigureAwait(false) is Exception processExitException) + { + throw processExitException; + } - if (hasExited) - { - throw new InvalidOperationException("Transport is not connected", processException); + throw; } - - await base.SendMessageAsync(message, cancellationToken).ConfigureAwait(false); } /// - protected override ValueTask CleanupAsync(CancellationToken cancellationToken) + protected override async ValueTask CleanupAsync(Exception? error = null, CancellationToken cancellationToken = default) { + // Only clean up once. + if (Interlocked.Exchange(ref _cleanedUp, 1) != 0) + { + return; + } + + // We've not yet forcefully terminated the server. If it's already shut down, something went wrong, + // so create an exception with details about that. + error ??= await GetUnexpectedExitExceptionAsync(cancellationToken).ConfigureAwait(false); + + // Now terminate the server process. try { StdioClientTransport.DisposeProcess(_process, processRunning: true, _options.ShutdownTimeout, Name); @@ -66,6 +63,49 @@ protected override ValueTask CleanupAsync(CancellationToken cancellationToken) LogTransportShutdownFailed(Name, ex); } - return base.CleanupAsync(cancellationToken); + // And handle cleanup in the base type. + await base.CleanupAsync(error, cancellationToken); + } + + private async ValueTask GetUnexpectedExitExceptionAsync(CancellationToken cancellationToken) + { + if (!StdioClientTransport.HasExited(_process)) + { + return null; + } + + Debug.Assert(StdioClientTransport.HasExited(_process)); + try + { + // The process has exited, but we still need to ensure stderr has been flushed. +#if NET + await _process.WaitForExitAsync(cancellationToken).ConfigureAwait(false); +#else + _process.WaitForExit(); +#endif + } + catch { } + + string errorMessage = "MCP server process exited unexpectedly"; + + string? exitCode = null; + try + { + exitCode = $" (exit code: {(uint)_process.ExitCode})"; + } + catch { } + + lock (_stderrRollingLog) + { + if (_stderrRollingLog.Count > 0) + { + errorMessage = + $"{errorMessage}{exitCode}{Environment.NewLine}" + + $"Server's stderr tail:{Environment.NewLine}" + + $"{string.Join(Environment.NewLine, _stderrRollingLog)}"; + } + } + + return new IOException(errorMessage); } } diff --git a/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs index 3fb00f04..c65ba5e4 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs @@ -127,8 +127,28 @@ public async Task ConnectAsync(CancellationToken cancellationToken = process = new() { StartInfo = startInfo }; - // Set up error logging - process.ErrorDataReceived += (sender, args) => LogReadStderr(logger, endpointName, args.Data ?? "(no data)"); + // Set up stderr handling. Log all stderr output, and keep the last + // few lines in a rolling log for use in exceptions. + const int MaxStderrLength = 10; // keep the last 10 lines of stderr + Queue stderrRollingLog = new(MaxStderrLength); + process.ErrorDataReceived += (sender, args) => + { + string? data = args.Data; + if (data is not null) + { + lock (stderrRollingLog) + { + if (stderrRollingLog.Count >= MaxStderrLength) + { + stderrRollingLog.Dequeue(); + } + + stderrRollingLog.Enqueue(data); + } + + LogReadStderr(logger, endpointName, data); + } + }; // We need both stdin and stdout to use a no-BOM UTF-8 encoding. On .NET Core, // we can use ProcessStartInfo.StandardOutputEncoding/StandardInputEncoding, but @@ -154,14 +174,14 @@ public async Task ConnectAsync(CancellationToken cancellationToken = if (!processStarted) { LogTransportProcessStartFailed(logger, endpointName); - throw new InvalidOperationException("Failed to start MCP server process"); + throw new IOException("Failed to start MCP server process."); } LogTransportProcessStarted(logger, endpointName, process.Id); process.BeginErrorReadLine(); - return new StdioClientSessionTransport(_options, process, endpointName, _loggerFactory); + return new StdioClientSessionTransport(_options, process, endpointName, stderrRollingLog, _loggerFactory); } catch (Exception ex) { @@ -176,7 +196,7 @@ public async Task ConnectAsync(CancellationToken cancellationToken = LogTransportShutdownFailed(logger, endpointName, ex2); } - throw new InvalidOperationException("Failed to connect transport", ex); + throw new IOException("Failed to connect transport.", ex); } } @@ -185,20 +205,9 @@ internal static void DisposeProcess( { if (process is not null) { - if (processRunning) - { - try - { - processRunning = !process.HasExited; - } - catch - { - processRunning = false; - } - } - try { + processRunning = processRunning && !HasExited(process); if (processRunning) { // Wait for the process to exit. @@ -214,6 +223,19 @@ internal static void DisposeProcess( } } + /// Gets whether has exited. + internal static bool HasExited(Process process) + { + try + { + return process.HasExited; + } + catch + { + return true; + } + } + [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} connecting.")] private static partial void LogTransportConnecting(ILogger logger, string endpointName); diff --git a/src/ModelContextProtocol/Protocol/Transport/StreamClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StreamClientSessionTransport.cs index fc8672e0..07b78c75 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StreamClientSessionTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StreamClientSessionTransport.cs @@ -53,17 +53,12 @@ public StreamClientSessionTransport( _readTask = readTask.Unwrap(); readTask.Start(); - SetConnected(true); + SetConnected(); } /// public override async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) { - if (!IsConnected) - { - throw new InvalidOperationException("Transport is not connected"); - } - string id = "(no id)"; if (message is JsonRpcMessageWithId messageWithId) { @@ -82,31 +77,22 @@ public override async Task SendMessageAsync(JsonRpcMessage message, Cancellation catch (Exception ex) { LogTransportSendFailed(Name, id, ex); - throw new InvalidOperationException("Failed to send message", ex); + throw new IOException("Failed to send message.", ex); } } /// - /// - /// Asynchronously releases all resources used by the stream client session transport. - /// - /// A task that represents the asynchronous dispose operation. - /// - /// This method cancels ongoing operations and waits for the read task to complete - /// before marking the transport as disconnected. It calls - /// to perform the actual cleanup work. - /// After disposal, the transport can no longer be used to send or receive messages. - /// - public override ValueTask DisposeAsync() => - CleanupAsync(CancellationToken.None); + public override ValueTask DisposeAsync() => + CleanupAsync(cancellationToken: CancellationToken.None); private async Task ReadMessagesAsync(CancellationToken cancellationToken) { + Exception? error = null; try { LogTransportEnteringReadMessagesLoop(Name); - while (!cancellationToken.IsCancellationRequested) + while (true) { if (await _serverOutput.ReadLineAsync(cancellationToken).ConfigureAwait(false) is not string line) { @@ -130,12 +116,13 @@ private async Task ReadMessagesAsync(CancellationToken cancellationToken) } catch (Exception ex) { + error = ex; LogTransportReadMessagesFailed(Name, ex); } finally { _readTask = null; - await CleanupAsync(cancellationToken).ConfigureAwait(false); + await CleanupAsync(error, cancellationToken).ConfigureAwait(false); } } @@ -166,7 +153,7 @@ private async Task ProcessMessageAsync(string line, CancellationToken cancellati } } - protected virtual async ValueTask CleanupAsync(CancellationToken cancellationToken) + protected virtual async ValueTask CleanupAsync(Exception? error = null, CancellationToken cancellationToken = default) { LogTransportShuttingDown(Name); @@ -191,7 +178,7 @@ protected virtual async ValueTask CleanupAsync(CancellationToken cancellationTok } } - SetConnected(false); + SetDisconnected(error); LogTransportShutDown(Name); } } diff --git a/src/ModelContextProtocol/Protocol/Transport/StreamServerTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StreamServerTransport.cs index fba41782..a2b9fd4e 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StreamServerTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StreamServerTransport.cs @@ -51,7 +51,7 @@ public StreamServerTransport(Stream inputStream, Stream outputStream, string? se _inputReader = new StreamReader(inputStream, Encoding.UTF8); _outputStream = outputStream; - SetConnected(true); + SetConnected(); _readLoopCompleted = Task.Run(ReadMessagesAsync, _shutdownCts.Token); } @@ -60,7 +60,7 @@ public override async Task SendMessageAsync(JsonRpcMessage message, Cancellation { if (!IsConnected) { - throw new InvalidOperationException("Transport is not connected"); + return; } using var _ = await _sendLock.LockAsync(cancellationToken).ConfigureAwait(false); @@ -80,13 +80,14 @@ public override async Task SendMessageAsync(JsonRpcMessage message, Cancellation catch (Exception ex) { LogTransportSendFailed(Name, id, ex); - throw new InvalidOperationException("Failed to send message", ex); + throw new IOException("Failed to send message.", ex); } } private async Task ReadMessagesAsync() { CancellationToken shutdownToken = _shutdownCts.Token; + Exception? error = null; try { LogTransportEnteringReadMessagesLoop(Name); @@ -140,10 +141,11 @@ private async Task ReadMessagesAsync() catch (Exception ex) { LogTransportReadMessagesFailed(Name, ex); + error = ex; } finally { - SetConnected(false); + SetDisconnected(error); } } @@ -183,7 +185,7 @@ public override async ValueTask DisposeAsync() } finally { - SetConnected(false); + SetDisconnected(); LogTransportShutDown(Name); } diff --git a/src/ModelContextProtocol/Protocol/Transport/StreamableHttpClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StreamableHttpClientSessionTransport.cs index a442d5b3..080e90e4 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StreamableHttpClientSessionTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StreamableHttpClientSessionTransport.cs @@ -45,7 +45,7 @@ public StreamableHttpClientSessionTransport(SseClientTransportOptions transportO // We connect with the initialization request with the MCP transport. This means that any errors won't be observed // until the first call to SendMessageAsync. Fortunately, that happens internally in McpClientFactory.ConnectAsync // so we still throw any connection-related Exceptions from there and never expose a pre-connected client to the user. - SetConnected(true); + SetConnected(); } /// @@ -139,7 +139,7 @@ public override async ValueTask DisposeAsync() } finally { - SetConnected(false); + SetDisconnected(); } } diff --git a/src/ModelContextProtocol/Protocol/Transport/TransportBase.cs b/src/ModelContextProtocol/Protocol/Transport/TransportBase.cs index 1af8d91f..56e55a63 100644 --- a/src/ModelContextProtocol/Protocol/Transport/TransportBase.cs +++ b/src/ModelContextProtocol/Protocol/Transport/TransportBase.cs @@ -1,7 +1,8 @@ +using System.Diagnostics; using System.Threading.Channels; using Microsoft.Extensions.Logging; -using ModelContextProtocol.Protocol.Messages; using Microsoft.Extensions.Logging.Abstractions; +using ModelContextProtocol.Protocol.Messages; namespace ModelContextProtocol.Protocol.Transport; @@ -23,7 +24,14 @@ public abstract partial class TransportBase : ITransport { private readonly Channel _messageChannel; private readonly ILogger _logger; - private int _isConnected; + private volatile int _state = StateInitial; + + /// The transport has not yet been connected. + private const int StateInitial = 0; + /// The transport is connected. + private const int StateConnected = 1; + /// The transport was previously connected and is now disconnected. + private const int StateDisconnected = 2; /// /// Initializes a new instance of the class. @@ -53,7 +61,7 @@ protected TransportBase(string name, ILoggerFactory? loggerFactory) protected string Name { get; } /// - public bool IsConnected => _isConnected == 1; + public bool IsConnected => _state == StateConnected; /// public ChannelReader MessageReader => _messageChannel.Reader; @@ -73,7 +81,7 @@ protected async Task WriteMessageAsync(JsonRpcMessage message, CancellationToken { if (!IsConnected) { - throw new InvalidOperationException("Transport is not connected"); + throw new InvalidOperationException("Transport is not connected."); } if (_logger.IsEnabled(LogLevel.Debug)) @@ -82,24 +90,61 @@ protected async Task WriteMessageAsync(JsonRpcMessage message, CancellationToken LogTransportReceivedMessage(Name, messageId); } - await _messageChannel.Writer.WriteAsync(message, cancellationToken).ConfigureAwait(false); + bool wrote = _messageChannel.Writer.TryWrite(message); + Debug.Assert(wrote || !IsConnected, "_messageChannel is unbounded; this should only ever return false if the channel has been closed."); } /// - /// Sets the connected state of the transport. + /// Sets the transport to a connected state. /// - /// Whether the transport is connected. - protected void SetConnected(bool isConnected) + protected void SetConnected() { - var newIsConnected = isConnected ? 1 : 0; - if (Interlocked.Exchange(ref _isConnected, newIsConnected) == newIsConnected) + while (true) { - return; + int state = _state; + switch (state) + { + case StateInitial: + if (Interlocked.CompareExchange(ref _state, StateConnected, StateInitial) == StateInitial) + { + return; + } + break; + + case StateConnected: + return; + + case StateDisconnected: + throw new IOException("Transport is already disconnected and can't be reconnected."); + + default: + Debug.Fail($"Unexpected state: {state}"); + return; + } } + } - if (!isConnected) + /// + /// Sets the transport to a disconnected state. + /// + /// Optional error information associated with the transport disconnecting. Should be if the disconnect was graceful and expected. + protected void SetDisconnected(Exception? error = null) + { + int state = _state; + switch (state) { - _messageChannel.Writer.Complete(); + case StateInitial: + case StateConnected: + _state = StateDisconnected; + _messageChannel.Writer.TryComplete(error); + break; + + case StateDisconnected: + return; + + default: + Debug.Fail($"Unexpected state: {state}"); + break; } } diff --git a/src/ModelContextProtocol/Shared/McpEndpoint.cs b/src/ModelContextProtocol/Shared/McpEndpoint.cs index 394ccaa7..0e709d0f 100644 --- a/src/ModelContextProtocol/Shared/McpEndpoint.cs +++ b/src/ModelContextProtocol/Shared/McpEndpoint.cs @@ -125,7 +125,18 @@ public virtual async ValueTask DisposeUnsynchronizedAsync() } protected McpSession GetSessionOrThrow() - => _session ?? throw new InvalidOperationException($"This should be unreachable from public API! Call {nameof(InitializeSession)} before sending messages."); + { +#if NET + ObjectDisposedException.ThrowIf(_disposed, this); +#else + if (_disposed) + { + throw new ObjectDisposedException(GetType().Name); + } +#endif + + return _session ?? throw new InvalidOperationException($"This should be unreachable from public API! Call {nameof(InitializeSession)} before sending messages."); + } [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} shutting down.")] private partial void LogEndpointShuttingDown(string endpointName); diff --git a/src/ModelContextProtocol/Shared/McpSession.cs b/src/ModelContextProtocol/Shared/McpSession.cs index fc7136a9..ab93efd2 100644 --- a/src/ModelContextProtocol/Shared/McpSession.cs +++ b/src/ModelContextProtocol/Shared/McpSession.cs @@ -197,7 +197,7 @@ await SendMessageAsync(new JsonRpcError // Fail any pending requests, as they'll never be satisfied. foreach (var entry in _pendingRequests) { - entry.Value.TrySetException(new InvalidOperationException("The server shut down unexpectedly.")); + entry.Value.TrySetException(new IOException("The server shut down unexpectedly.")); } } } diff --git a/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs new file mode 100644 index 00000000..8aadf95b --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs @@ -0,0 +1,21 @@ +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol.Transport; +using System.Runtime.InteropServices; + +namespace ModelContextProtocol.Tests.Transport; + +public class StdioClientTransportTests +{ + [Fact] + public async Task CreateAsync_ValidProcessInvalidServer_Throws() + { + string id = Guid.NewGuid().ToString("N"); + + StdioClientTransport transport = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? + new(new() { Command = "cmd", Arguments = ["/C", $"echo \"{id}\" >&2"] }) : + new(new() { Command = "ls", Arguments = [id] }); + + IOException e = await Assert.ThrowsAsync(() => McpClientFactory.CreateAsync(transport, cancellationToken: TestContext.Current.CancellationToken)); + Assert.Contains(id, e.ToString()); + } +}