From 201f05df1be3717adb9a9ee3d3776b0ff8fa469d Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Thu, 1 May 2025 18:16:17 -0400 Subject: [PATCH 1/2] Improve exception diagnostics for stdio client Track the last several lines of stderr and include those in an exception created as part of CleanupAsync if the server process has already exited. This is based on an assumption that the server should never exit prior to CleanupAsync being called. --- .../System/Diagnostics/ProcessExtensions.cs | 37 -------- .../Transport/SseClientSessionTransport.cs | 6 +- .../Transport/StdioClientSessionTransport.cs | 90 +++++++++++++------ .../Transport/StdioClientTransport.cs | 56 ++++++++---- .../Transport/StreamClientSessionTransport.cs | 30 +++---- .../Transport/StreamServerTransport.cs | 12 +-- .../StreamableHttpClientSessionTransport.cs | 4 +- .../Protocol/Transport/TransportBase.cs | 71 ++++++++++++--- src/ModelContextProtocol/Shared/McpSession.cs | 2 +- .../Transport/StdioClientTransportTests.cs | 20 +++++ 10 files changed, 202 insertions(+), 126 deletions(-) delete mode 100644 src/Common/Polyfills/System/Diagnostics/ProcessExtensions.cs create mode 100644 tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs diff --git a/src/Common/Polyfills/System/Diagnostics/ProcessExtensions.cs b/src/Common/Polyfills/System/Diagnostics/ProcessExtensions.cs deleted file mode 100644 index acd06077..00000000 --- a/src/Common/Polyfills/System/Diagnostics/ProcessExtensions.cs +++ /dev/null @@ -1,37 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -namespace System.Diagnostics; - -internal static class ProcessExtensions -{ - public static void Kill(this Process process, bool entireProcessTree) - { - _ = entireProcessTree; - process.Kill(); - } - - public static async Task WaitForExitAsync(this Process process, CancellationToken cancellationToken = default) - { - if (process.HasExited) - { - return; - } - - var tcs = new TaskCompletionSource(); - void ProcessExitedHandler(object? sender, EventArgs e) => tcs.TrySetResult(true); - - try - { - process.EnableRaisingEvents = true; - process.Exited += ProcessExitedHandler; - - using var _ = cancellationToken.Register(() => tcs.TrySetCanceled(cancellationToken)); - await tcs.Task.ConfigureAwait(false); - } - finally - { - process.Exited -= ProcessExitedHandler; - } - } -} \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs index c083764a..ccdfc758 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs @@ -137,7 +137,7 @@ private async Task CloseAsync() } finally { - SetConnected(false); + SetDisconnected(); } } @@ -203,7 +203,7 @@ private async Task ReceiveMessagesAsync(CancellationToken cancellationToken) } finally { - SetConnected(false); + SetDisconnected(); } } @@ -251,7 +251,7 @@ private void HandleEndpointEvent(string data) _messageEndpoint = new Uri(_sseEndpoint, data); // Set connected state - SetConnected(true); + SetConnected(); _connectionEstablished.TrySetResult(true); } diff --git a/src/ModelContextProtocol/Protocol/Transport/StdioClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StdioClientSessionTransport.cs index 3bb5f312..194d659f 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StdioClientSessionTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StdioClientSessionTransport.cs @@ -9,54 +9,51 @@ internal sealed class StdioClientSessionTransport : StreamClientSessionTransport { private readonly StdioClientTransportOptions _options; private readonly Process _process; + private readonly Queue _stderrRollingLog; + private int _cleanedUp = 0; - public StdioClientSessionTransport(StdioClientTransportOptions options, Process process, string endpointName, ILoggerFactory? loggerFactory) + public StdioClientSessionTransport(StdioClientTransportOptions options, Process process, string endpointName, Queue stderrRollingLog, ILoggerFactory? loggerFactory) : base(process.StandardInput, process.StandardOutput, endpointName, loggerFactory) { _process = process; _options = options; + _stderrRollingLog = stderrRollingLog; } /// - /// - /// - /// For stdio-based transports, this implementation first verifies that the underlying process - /// is still running before attempting to send the message. If the process has exited or cannot - /// be accessed, a is thrown with details about the failure. - /// - /// - /// After verifying the process state, this method delegates to the base class implementation - /// to handle the actual message serialization and transmission to the process's standard input stream. - /// - /// - /// - /// Thrown when the underlying process has exited or cannot be accessed. - /// public override async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) { - Exception? processException = null; - bool hasExited = false; try { - hasExited = _process.HasExited; + await base.SendMessageAsync(message, cancellationToken); } - catch (Exception e) + catch (IOException) { - processException = e; - hasExited = true; - } + // We failed to send due to an I/O error. If the server process has exited, which is then very likely the cause + // for the I/O error, we should throw an exception for that instead. + if (await GetUnexpectedExitExceptionAsync(cancellationToken).ConfigureAwait(false) is Exception processExitException) + { + throw processExitException; + } - if (hasExited) - { - throw new InvalidOperationException("Transport is not connected", processException); + throw; } - - await base.SendMessageAsync(message, cancellationToken).ConfigureAwait(false); } /// - protected override ValueTask CleanupAsync(CancellationToken cancellationToken) + protected override async ValueTask CleanupAsync(Exception? error = null, CancellationToken cancellationToken = default) { + // Only clean up once. + if (Interlocked.Exchange(ref _cleanedUp, 1) != 0) + { + return; + } + + // We've not yet forcefully terminated the server. If it's already shut down, something went wrong, + // so create an exception with details about that. + error ??= await GetUnexpectedExitExceptionAsync(cancellationToken).ConfigureAwait(false); + + // Now terminate the server process. try { StdioClientTransport.DisposeProcess(_process, processRunning: true, _options.ShutdownTimeout, Name); @@ -66,6 +63,41 @@ protected override ValueTask CleanupAsync(CancellationToken cancellationToken) LogTransportShutdownFailed(Name, ex); } - return base.CleanupAsync(cancellationToken); + // And handle cleanup in the base type. + await base.CleanupAsync(error, cancellationToken); + } + + private async ValueTask GetUnexpectedExitExceptionAsync(CancellationToken cancellationToken) + { + if (!StdioClientTransport.HasExited(_process)) + { + return null; + } + + Debug.Assert(StdioClientTransport.HasExited(_process)); + try + { + // The process has exited, but we still need to ensure stderr has been flushed. +#if NET + await _process.WaitForExitAsync(cancellationToken).ConfigureAwait(false); +#else + _process.WaitForExit(); +#endif + } + catch { } + + string errorMessage = "MCP server process exited unexpectedly."; + lock (_stderrRollingLog) + { + if (_stderrRollingLog.Count > 0) + { + errorMessage = + $"{errorMessage}{Environment.NewLine}" + + $"Server's stderr tail:{Environment.NewLine}" + + $"{string.Join(Environment.NewLine, _stderrRollingLog)}"; + } + } + + return new IOException(errorMessage); } } diff --git a/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs index 3fb00f04..c65ba5e4 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs @@ -127,8 +127,28 @@ public async Task ConnectAsync(CancellationToken cancellationToken = process = new() { StartInfo = startInfo }; - // Set up error logging - process.ErrorDataReceived += (sender, args) => LogReadStderr(logger, endpointName, args.Data ?? "(no data)"); + // Set up stderr handling. Log all stderr output, and keep the last + // few lines in a rolling log for use in exceptions. + const int MaxStderrLength = 10; // keep the last 10 lines of stderr + Queue stderrRollingLog = new(MaxStderrLength); + process.ErrorDataReceived += (sender, args) => + { + string? data = args.Data; + if (data is not null) + { + lock (stderrRollingLog) + { + if (stderrRollingLog.Count >= MaxStderrLength) + { + stderrRollingLog.Dequeue(); + } + + stderrRollingLog.Enqueue(data); + } + + LogReadStderr(logger, endpointName, data); + } + }; // We need both stdin and stdout to use a no-BOM UTF-8 encoding. On .NET Core, // we can use ProcessStartInfo.StandardOutputEncoding/StandardInputEncoding, but @@ -154,14 +174,14 @@ public async Task ConnectAsync(CancellationToken cancellationToken = if (!processStarted) { LogTransportProcessStartFailed(logger, endpointName); - throw new InvalidOperationException("Failed to start MCP server process"); + throw new IOException("Failed to start MCP server process."); } LogTransportProcessStarted(logger, endpointName, process.Id); process.BeginErrorReadLine(); - return new StdioClientSessionTransport(_options, process, endpointName, _loggerFactory); + return new StdioClientSessionTransport(_options, process, endpointName, stderrRollingLog, _loggerFactory); } catch (Exception ex) { @@ -176,7 +196,7 @@ public async Task ConnectAsync(CancellationToken cancellationToken = LogTransportShutdownFailed(logger, endpointName, ex2); } - throw new InvalidOperationException("Failed to connect transport", ex); + throw new IOException("Failed to connect transport.", ex); } } @@ -185,20 +205,9 @@ internal static void DisposeProcess( { if (process is not null) { - if (processRunning) - { - try - { - processRunning = !process.HasExited; - } - catch - { - processRunning = false; - } - } - try { + processRunning = processRunning && !HasExited(process); if (processRunning) { // Wait for the process to exit. @@ -214,6 +223,19 @@ internal static void DisposeProcess( } } + /// Gets whether has exited. + internal static bool HasExited(Process process) + { + try + { + return process.HasExited; + } + catch + { + return true; + } + } + [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} connecting.")] private static partial void LogTransportConnecting(ILogger logger, string endpointName); diff --git a/src/ModelContextProtocol/Protocol/Transport/StreamClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StreamClientSessionTransport.cs index fc8672e0..26c87191 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StreamClientSessionTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StreamClientSessionTransport.cs @@ -53,7 +53,7 @@ public StreamClientSessionTransport( _readTask = readTask.Unwrap(); readTask.Start(); - SetConnected(true); + SetConnected(); } /// @@ -61,7 +61,7 @@ public override async Task SendMessageAsync(JsonRpcMessage message, Cancellation { if (!IsConnected) { - throw new InvalidOperationException("Transport is not connected"); + throw new InvalidOperationException("Transport is not connected."); } string id = "(no id)"; @@ -82,31 +82,22 @@ public override async Task SendMessageAsync(JsonRpcMessage message, Cancellation catch (Exception ex) { LogTransportSendFailed(Name, id, ex); - throw new InvalidOperationException("Failed to send message", ex); + throw new IOException("Failed to send message", ex); } } /// - /// - /// Asynchronously releases all resources used by the stream client session transport. - /// - /// A task that represents the asynchronous dispose operation. - /// - /// This method cancels ongoing operations and waits for the read task to complete - /// before marking the transport as disconnected. It calls - /// to perform the actual cleanup work. - /// After disposal, the transport can no longer be used to send or receive messages. - /// - public override ValueTask DisposeAsync() => - CleanupAsync(CancellationToken.None); + public override ValueTask DisposeAsync() => + CleanupAsync(cancellationToken: CancellationToken.None); private async Task ReadMessagesAsync(CancellationToken cancellationToken) { + Exception? error = null; try { LogTransportEnteringReadMessagesLoop(Name); - while (!cancellationToken.IsCancellationRequested) + while (true) { if (await _serverOutput.ReadLineAsync(cancellationToken).ConfigureAwait(false) is not string line) { @@ -130,12 +121,13 @@ private async Task ReadMessagesAsync(CancellationToken cancellationToken) } catch (Exception ex) { + error = ex; LogTransportReadMessagesFailed(Name, ex); } finally { _readTask = null; - await CleanupAsync(cancellationToken).ConfigureAwait(false); + await CleanupAsync(error, cancellationToken).ConfigureAwait(false); } } @@ -166,7 +158,7 @@ private async Task ProcessMessageAsync(string line, CancellationToken cancellati } } - protected virtual async ValueTask CleanupAsync(CancellationToken cancellationToken) + protected virtual async ValueTask CleanupAsync(Exception? error = null, CancellationToken cancellationToken = default) { LogTransportShuttingDown(Name); @@ -191,7 +183,7 @@ protected virtual async ValueTask CleanupAsync(CancellationToken cancellationTok } } - SetConnected(false); + SetDisconnected(error); LogTransportShutDown(Name); } } diff --git a/src/ModelContextProtocol/Protocol/Transport/StreamServerTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StreamServerTransport.cs index fba41782..730f7b1d 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StreamServerTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StreamServerTransport.cs @@ -51,7 +51,7 @@ public StreamServerTransport(Stream inputStream, Stream outputStream, string? se _inputReader = new StreamReader(inputStream, Encoding.UTF8); _outputStream = outputStream; - SetConnected(true); + SetConnected(); _readLoopCompleted = Task.Run(ReadMessagesAsync, _shutdownCts.Token); } @@ -60,7 +60,7 @@ public override async Task SendMessageAsync(JsonRpcMessage message, Cancellation { if (!IsConnected) { - throw new InvalidOperationException("Transport is not connected"); + throw new InvalidOperationException("Transport is not connected."); } using var _ = await _sendLock.LockAsync(cancellationToken).ConfigureAwait(false); @@ -80,13 +80,14 @@ public override async Task SendMessageAsync(JsonRpcMessage message, Cancellation catch (Exception ex) { LogTransportSendFailed(Name, id, ex); - throw new InvalidOperationException("Failed to send message", ex); + throw new IOException("Failed to send message.", ex); } } private async Task ReadMessagesAsync() { CancellationToken shutdownToken = _shutdownCts.Token; + Exception? error = null; try { LogTransportEnteringReadMessagesLoop(Name); @@ -140,10 +141,11 @@ private async Task ReadMessagesAsync() catch (Exception ex) { LogTransportReadMessagesFailed(Name, ex); + error = ex; } finally { - SetConnected(false); + SetDisconnected(error); } } @@ -183,7 +185,7 @@ public override async ValueTask DisposeAsync() } finally { - SetConnected(false); + SetDisconnected(); LogTransportShutDown(Name); } diff --git a/src/ModelContextProtocol/Protocol/Transport/StreamableHttpClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StreamableHttpClientSessionTransport.cs index a442d5b3..080e90e4 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StreamableHttpClientSessionTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StreamableHttpClientSessionTransport.cs @@ -45,7 +45,7 @@ public StreamableHttpClientSessionTransport(SseClientTransportOptions transportO // We connect with the initialization request with the MCP transport. This means that any errors won't be observed // until the first call to SendMessageAsync. Fortunately, that happens internally in McpClientFactory.ConnectAsync // so we still throw any connection-related Exceptions from there and never expose a pre-connected client to the user. - SetConnected(true); + SetConnected(); } /// @@ -139,7 +139,7 @@ public override async ValueTask DisposeAsync() } finally { - SetConnected(false); + SetDisconnected(); } } diff --git a/src/ModelContextProtocol/Protocol/Transport/TransportBase.cs b/src/ModelContextProtocol/Protocol/Transport/TransportBase.cs index 1af8d91f..77292224 100644 --- a/src/ModelContextProtocol/Protocol/Transport/TransportBase.cs +++ b/src/ModelContextProtocol/Protocol/Transport/TransportBase.cs @@ -1,7 +1,8 @@ +using System.Diagnostics; using System.Threading.Channels; using Microsoft.Extensions.Logging; -using ModelContextProtocol.Protocol.Messages; using Microsoft.Extensions.Logging.Abstractions; +using ModelContextProtocol.Protocol.Messages; namespace ModelContextProtocol.Protocol.Transport; @@ -23,7 +24,14 @@ public abstract partial class TransportBase : ITransport { private readonly Channel _messageChannel; private readonly ILogger _logger; - private int _isConnected; + private volatile int _state = StateInitial; + + /// The transport has not yet been connected. + private const int StateInitial = 0; + /// The transport is connected. + private const int StateConnected = 1; + /// The transport was previously connected and is now disconnected. + private const int StateDisconnected = 2; /// /// Initializes a new instance of the class. @@ -53,7 +61,7 @@ protected TransportBase(string name, ILoggerFactory? loggerFactory) protected string Name { get; } /// - public bool IsConnected => _isConnected == 1; + public bool IsConnected => _state == StateConnected; /// public ChannelReader MessageReader => _messageChannel.Reader; @@ -73,7 +81,7 @@ protected async Task WriteMessageAsync(JsonRpcMessage message, CancellationToken { if (!IsConnected) { - throw new InvalidOperationException("Transport is not connected"); + throw new InvalidOperationException("Transport is not connected."); } if (_logger.IsEnabled(LogLevel.Debug)) @@ -82,24 +90,61 @@ protected async Task WriteMessageAsync(JsonRpcMessage message, CancellationToken LogTransportReceivedMessage(Name, messageId); } - await _messageChannel.Writer.WriteAsync(message, cancellationToken).ConfigureAwait(false); + bool wrote = _messageChannel.Writer.TryWrite(message); + Debug.Assert(wrote || !IsConnected, "_messageChannel is unbounded; this should only ever return false if the channel has been closed."); } /// - /// Sets the connected state of the transport. + /// Sets the transport to a connected state. /// - /// Whether the transport is connected. - protected void SetConnected(bool isConnected) + protected void SetConnected() { - var newIsConnected = isConnected ? 1 : 0; - if (Interlocked.Exchange(ref _isConnected, newIsConnected) == newIsConnected) + while (true) { - return; + int state = _state; + switch (state) + { + case StateInitial: + if (Interlocked.CompareExchange(ref _state, StateConnected, StateInitial) == StateInitial) + { + return; + } + break; + + case StateConnected: + return; + + case StateDisconnected: + throw new InvalidOperationException("Transport is already disconnected and can't be reconnected."); + + default: + Debug.Fail($"Unexpected state: {state}"); + return; + } } + } - if (!isConnected) + /// + /// Sets the transport to a disconnected state. + /// + /// Optional error information associated with the transport disconnecting. Should be if the disconnect was graceful and expected. + protected void SetDisconnected(Exception? error = null) + { + int state = _state; + switch (state) { - _messageChannel.Writer.Complete(); + case StateInitial: + case StateConnected: + _state = StateDisconnected; + _messageChannel.Writer.TryComplete(error); + break; + + case StateDisconnected: + return; + + default: + Debug.Fail($"Unexpected state: {state}"); + break; } } diff --git a/src/ModelContextProtocol/Shared/McpSession.cs b/src/ModelContextProtocol/Shared/McpSession.cs index fc7136a9..ab93efd2 100644 --- a/src/ModelContextProtocol/Shared/McpSession.cs +++ b/src/ModelContextProtocol/Shared/McpSession.cs @@ -197,7 +197,7 @@ await SendMessageAsync(new JsonRpcError // Fail any pending requests, as they'll never be satisfied. foreach (var entry in _pendingRequests) { - entry.Value.TrySetException(new InvalidOperationException("The server shut down unexpectedly.")); + entry.Value.TrySetException(new IOException("The server shut down unexpectedly.")); } } } diff --git a/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs new file mode 100644 index 00000000..c3a1929f --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs @@ -0,0 +1,20 @@ +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol.Transport; + +namespace ModelContextProtocol.Tests.Transport; + +public class StdioClientTransportTests +{ + [Fact] + public async Task CreateAsync_ValidProcessInvalidServer_Throws() + { + StdioClientTransport transport = new(new() { Command = "echo", Arguments = ["this is a test", "1>&2"] }); + + IOException e = await Assert.ThrowsAsync(() => McpClientFactory.CreateAsync(transport, cancellationToken: TestContext.Current.CancellationToken)); + string exStr = e.ToString(); + if (!exStr.Contains("this is a test")) + { + throw new Exception($"Expected error message not found in exception: {exStr}"); + } + } +} From b0ece50ba761304395451e2f41ebe049439664b7 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Thu, 1 May 2025 21:51:36 -0400 Subject: [PATCH 2/2] Address feedback and fix test failure --- .../Transport/StdioClientSessionTransport.cs | 12 ++++++++++-- .../Transport/StreamClientSessionTransport.cs | 7 +------ .../Protocol/Transport/StreamServerTransport.cs | 2 +- .../Protocol/Transport/TransportBase.cs | 2 +- src/ModelContextProtocol/Shared/McpEndpoint.cs | 13 ++++++++++++- .../Transport/StdioClientTransportTests.cs | 13 +++++++------ 6 files changed, 32 insertions(+), 17 deletions(-) diff --git a/src/ModelContextProtocol/Protocol/Transport/StdioClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StdioClientSessionTransport.cs index 194d659f..d8ad6afb 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StdioClientSessionTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StdioClientSessionTransport.cs @@ -86,13 +86,21 @@ protected override async ValueTask CleanupAsync(Exception? error = null, Cancell } catch { } - string errorMessage = "MCP server process exited unexpectedly."; + string errorMessage = "MCP server process exited unexpectedly"; + + string? exitCode = null; + try + { + exitCode = $" (exit code: {(uint)_process.ExitCode})"; + } + catch { } + lock (_stderrRollingLog) { if (_stderrRollingLog.Count > 0) { errorMessage = - $"{errorMessage}{Environment.NewLine}" + + $"{errorMessage}{exitCode}{Environment.NewLine}" + $"Server's stderr tail:{Environment.NewLine}" + $"{string.Join(Environment.NewLine, _stderrRollingLog)}"; } diff --git a/src/ModelContextProtocol/Protocol/Transport/StreamClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StreamClientSessionTransport.cs index 26c87191..07b78c75 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StreamClientSessionTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StreamClientSessionTransport.cs @@ -59,11 +59,6 @@ public StreamClientSessionTransport( /// public override async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) { - if (!IsConnected) - { - throw new InvalidOperationException("Transport is not connected."); - } - string id = "(no id)"; if (message is JsonRpcMessageWithId messageWithId) { @@ -82,7 +77,7 @@ public override async Task SendMessageAsync(JsonRpcMessage message, Cancellation catch (Exception ex) { LogTransportSendFailed(Name, id, ex); - throw new IOException("Failed to send message", ex); + throw new IOException("Failed to send message.", ex); } } diff --git a/src/ModelContextProtocol/Protocol/Transport/StreamServerTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StreamServerTransport.cs index 730f7b1d..a2b9fd4e 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StreamServerTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StreamServerTransport.cs @@ -60,7 +60,7 @@ public override async Task SendMessageAsync(JsonRpcMessage message, Cancellation { if (!IsConnected) { - throw new InvalidOperationException("Transport is not connected."); + return; } using var _ = await _sendLock.LockAsync(cancellationToken).ConfigureAwait(false); diff --git a/src/ModelContextProtocol/Protocol/Transport/TransportBase.cs b/src/ModelContextProtocol/Protocol/Transport/TransportBase.cs index 77292224..56e55a63 100644 --- a/src/ModelContextProtocol/Protocol/Transport/TransportBase.cs +++ b/src/ModelContextProtocol/Protocol/Transport/TransportBase.cs @@ -115,7 +115,7 @@ protected void SetConnected() return; case StateDisconnected: - throw new InvalidOperationException("Transport is already disconnected and can't be reconnected."); + throw new IOException("Transport is already disconnected and can't be reconnected."); default: Debug.Fail($"Unexpected state: {state}"); diff --git a/src/ModelContextProtocol/Shared/McpEndpoint.cs b/src/ModelContextProtocol/Shared/McpEndpoint.cs index 394ccaa7..0e709d0f 100644 --- a/src/ModelContextProtocol/Shared/McpEndpoint.cs +++ b/src/ModelContextProtocol/Shared/McpEndpoint.cs @@ -125,7 +125,18 @@ public virtual async ValueTask DisposeUnsynchronizedAsync() } protected McpSession GetSessionOrThrow() - => _session ?? throw new InvalidOperationException($"This should be unreachable from public API! Call {nameof(InitializeSession)} before sending messages."); + { +#if NET + ObjectDisposedException.ThrowIf(_disposed, this); +#else + if (_disposed) + { + throw new ObjectDisposedException(GetType().Name); + } +#endif + + return _session ?? throw new InvalidOperationException($"This should be unreachable from public API! Call {nameof(InitializeSession)} before sending messages."); + } [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} shutting down.")] private partial void LogEndpointShuttingDown(string endpointName); diff --git a/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs index c3a1929f..8aadf95b 100644 --- a/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs @@ -1,5 +1,6 @@ using ModelContextProtocol.Client; using ModelContextProtocol.Protocol.Transport; +using System.Runtime.InteropServices; namespace ModelContextProtocol.Tests.Transport; @@ -8,13 +9,13 @@ public class StdioClientTransportTests [Fact] public async Task CreateAsync_ValidProcessInvalidServer_Throws() { - StdioClientTransport transport = new(new() { Command = "echo", Arguments = ["this is a test", "1>&2"] }); + string id = Guid.NewGuid().ToString("N"); + + StdioClientTransport transport = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? + new(new() { Command = "cmd", Arguments = ["/C", $"echo \"{id}\" >&2"] }) : + new(new() { Command = "ls", Arguments = [id] }); IOException e = await Assert.ThrowsAsync(() => McpClientFactory.CreateAsync(transport, cancellationToken: TestContext.Current.CancellationToken)); - string exStr = e.ToString(); - if (!exStr.Contains("this is a test")) - { - throw new Exception($"Expected error message not found in exception: {exStr}"); - } + Assert.Contains(id, e.ToString()); } }