From 093540828153ad0ae30b091a522dcef2e8f65ed1 Mon Sep 17 00:00:00 2001 From: Rob Hague Date: Sat, 24 Feb 2024 17:50:33 +0100 Subject: [PATCH 1/4] Replace AwaitableSocketAsyncEventArgs in SocketExtensions The existing AwaitableSocketAsyncEventArgs is useful in principal for being reusable in order to save on allocations. However, we don't reuse it and the implementation is flawed. Instead, use implementations based on TaskCompletionSource, and add a SendAsync method. Because sockets are only natively cancellable on modern .NET, I was torn between 3 options for cancellation on the targets which use SocketExtensions: 1. Do not respect the CancellationToken once the socket operation has started. I believe this is what earlier versions of .NET Core did when CancellationToken overloads were first added via SocketTaskExtensions. 2. Do not close the socket upon cancellation, meaning the socket operation continues to run after the Task has completed. This is what the previous implementation effectively does. 3. Close the socket when the CancellationToken is cancelled, in order to stop the socket operation. The behaviour of a socket after (proper) cancellation is undefined(?), so in any case it should not make sense to use the socket after triggering cancellation. I felt that option 2 was the worst of them. This iteration goes for option 3. --- .../Abstractions/SocketAbstraction.cs | 4 +- .../Abstractions/SocketExtensions.cs | 162 +++++++++--------- 2 files changed, 79 insertions(+), 87 deletions(-) diff --git a/src/Renci.SshNet/Abstractions/SocketAbstraction.cs b/src/Renci.SshNet/Abstractions/SocketAbstraction.cs index e1a12362e..c6032671b 100644 --- a/src/Renci.SshNet/Abstractions/SocketAbstraction.cs +++ b/src/Renci.SshNet/Abstractions/SocketAbstraction.cs @@ -312,9 +312,9 @@ public static int Read(Socket socket, byte[] buffer, int offset, int size, TimeS } #if NET6_0_OR_GREATER == false - public static Task ReadAsync(Socket socket, byte[] buffer, CancellationToken cancellationToken) + public static ValueTask ReadAsync(Socket socket, byte[] buffer, CancellationToken cancellationToken) { - return socket.ReceiveAsync(buffer, 0, buffer.Length, cancellationToken); + return socket.ReceiveAsync(new ArraySegment(buffer, 0, buffer.Length), SocketFlags.None, cancellationToken); } #endif diff --git a/src/Renci.SshNet/Abstractions/SocketExtensions.cs b/src/Renci.SshNet/Abstractions/SocketExtensions.cs index 2c34c899c..67949eaa4 100644 --- a/src/Renci.SshNet/Abstractions/SocketExtensions.cs +++ b/src/Renci.SshNet/Abstractions/SocketExtensions.cs @@ -1,134 +1,126 @@ -#if !NET6_0_OR_GREATER +#if !NET +#if NETFRAMEWORK || NETSTANDARD2_0 using System; +#endif using System.Net; using System.Net.Sockets; -using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; namespace Renci.SshNet.Abstractions { - // Async helpers based on https://devblogs.microsoft.com/pfxteam/awaiting-socket-operations/ internal static class SocketExtensions { - private sealed class AwaitableSocketAsyncEventArgs : SocketAsyncEventArgs, INotifyCompletion + public static async Task ConnectAsync(this Socket socket, IPEndPoint remoteEndpoint, CancellationToken cancellationToken) { - private static readonly Action SENTINEL = () => { }; + cancellationToken.ThrowIfCancellationRequested(); - private bool _isCancelled; - private Action _continuationAction; + TaskCompletionSource tcs = new(TaskCreationOptions.RunContinuationsAsynchronously); - public AwaitableSocketAsyncEventArgs() + using var args = new SocketAsyncEventArgs { - Completed += (sender, e) => SetCompleted(); - } + RemoteEndPoint = remoteEndpoint + }; + args.Completed += (_, _) => tcs.TrySetResult(null); - public AwaitableSocketAsyncEventArgs ExecuteAsync(Func func) + if (socket.ConnectAsync(args)) { - if (!func(this)) +#if NETSTANDARD2_1 + await using (cancellationToken.Register(() => +#else + using (cancellationToken.Register(() => +#endif { - SetCompleted(); - } - - return this; - } - - public void SetCompleted() - { - IsCompleted = true; - - var continuation = _continuationAction ?? Interlocked.CompareExchange(ref _continuationAction, SENTINEL, comparand: null); - if (continuation is not null) + if (tcs.TrySetCanceled(cancellationToken)) + { + socket.Dispose(); + } + }, + useSynchronizationContext: false) +#if NETSTANDARD2_1 + .ConfigureAwait(false) +#endif + ) { - continuation(); + _ = await tcs.Task.ConfigureAwait(false); } } - public void SetCancelled() + if (args.SocketError != SocketError.Success) { - _isCancelled = true; - SetCompleted(); + throw new SocketException((int) args.SocketError); } + } -#pragma warning disable S1144 // Unused private types or members should be removed - public AwaitableSocketAsyncEventArgs GetAwaiter() -#pragma warning restore S1144 // Unused private types or members should be removed - { - return this; - } +#if NETFRAMEWORK || NETSTANDARD2_0 + public static async ValueTask ReceiveAsync(this Socket socket, ArraySegment buffer, SocketFlags socketFlags, CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); - public bool IsCompleted { get; private set; } + TaskCompletionSource tcs = new(TaskCreationOptions.RunContinuationsAsynchronously); - void INotifyCompletion.OnCompleted(Action continuation) - { - if (_continuationAction == SENTINEL || Interlocked.CompareExchange(ref _continuationAction, continuation, comparand: null) == SENTINEL) - { - // We have already completed; run continuation asynchronously - _ = Task.Run(continuation); - } - } + using var args = new SocketAsyncEventArgs(); + args.SocketFlags = socketFlags; + args.Completed += (_, _) => tcs.TrySetResult(null); + args.SetBuffer(buffer.Array, buffer.Offset, buffer.Count); -#pragma warning disable S1144 // Unused private types or members should be removed - public void GetResult() -#pragma warning restore S1144 // Unused private types or members should be removed + if (socket.ReceiveAsync(args)) { - if (_isCancelled) + using (cancellationToken.Register(() => { - throw new TaskCanceledException(); - } - - if (!IsCompleted) - { - // We don't support sync/async - throw new InvalidOperationException("The asynchronous operation has not yet completed."); - } - - if (SocketError != SocketError.Success) + if (tcs.TrySetCanceled(cancellationToken)) + { + socket.Dispose(); + } + }, + useSynchronizationContext: false)) { - throw new SocketException((int)SocketError); + _ = await tcs.Task.ConfigureAwait(false); } } - } - public static async Task ConnectAsync(this Socket socket, IPEndPoint remoteEndpoint, CancellationToken cancellationToken) - { - cancellationToken.ThrowIfCancellationRequested(); - - using (var args = new AwaitableSocketAsyncEventArgs()) + if (args.SocketError != SocketError.Success) { - args.RemoteEndPoint = remoteEndpoint; - -#if NET || NETSTANDARD2_1_OR_GREATER - await using (cancellationToken.Register(o => ((AwaitableSocketAsyncEventArgs)o).SetCancelled(), args, useSynchronizationContext: false).ConfigureAwait(continueOnCapturedContext: false)) -#else - using (cancellationToken.Register(o => ((AwaitableSocketAsyncEventArgs) o).SetCancelled(), args, useSynchronizationContext: false)) -#endif // NET || NETSTANDARD2_1_OR_GREATER - { - await args.ExecuteAsync(socket.ConnectAsync); - } + throw new SocketException((int) args.SocketError); } + + return args.BytesTransferred; } - public static async Task ReceiveAsync(this Socket socket, byte[] buffer, int offset, int length, CancellationToken cancellationToken) + public static async ValueTask SendAsync(this Socket socket, byte[] buffer, SocketFlags socketFlags, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); - using (var args = new AwaitableSocketAsyncEventArgs()) - { - args.SetBuffer(buffer, offset, length); + TaskCompletionSource tcs = new(TaskCreationOptions.RunContinuationsAsynchronously); -#if NET || NETSTANDARD2_1_OR_GREATER - await using (cancellationToken.Register(o => ((AwaitableSocketAsyncEventArgs) o).SetCancelled(), args, useSynchronizationContext: false).ConfigureAwait(continueOnCapturedContext: false)) -#else - using (cancellationToken.Register(o => ((AwaitableSocketAsyncEventArgs) o).SetCancelled(), args, useSynchronizationContext: false)) -#endif // NET || NETSTANDARD2_1_OR_GREATER + using var args = new SocketAsyncEventArgs(); + args.SocketFlags = socketFlags; + args.Completed += (_, _) => tcs.TrySetResult(null); + args.SetBuffer(buffer, 0, buffer.Length); + + if (socket.SendAsync(args)) + { + using (cancellationToken.Register(() => + { + if (tcs.TrySetCanceled(cancellationToken)) + { + socket.Dispose(); + } + }, + useSynchronizationContext: false)) { - await args.ExecuteAsync(socket.ReceiveAsync); + _ = await tcs.Task.ConfigureAwait(false); } + } - return args.BytesTransferred; + if (args.SocketError != SocketError.Success) + { + throw new SocketException((int) args.SocketError); } + + return args.BytesTransferred; } +#endif // NETFRAMEWORK || NETSTANDARD2_0 } } #endif From 6fac1fb6d340512a53d2d6359b842440f95500a2 Mon Sep 17 00:00:00 2001 From: Rob Hague Date: Mon, 26 Feb 2024 20:13:46 +0100 Subject: [PATCH 2/4] Remove "IsErrorResumable" and SocketAbstraction.Send{Async} Some methods in SocketAbstraction have code to retry a socket operation if it returns certain error codes. However AFAIK, these errors are only pertinent to nonblocking sockets, which we do not use. For blocking sockets, Socket.Send{Async} only returns when all of the bytes are sent. There is no need for a loop. These changes combined mean there is no need for Send methods in SocketAbstraction. --- .../Abstractions/SocketAbstraction.Async.cs | 33 -------- .../Abstractions/SocketAbstraction.cs | 83 ++----------------- .../Channels/ChannelDirectTcpip.cs | 2 +- .../Channels/ChannelForwardedTcpip.cs | 2 +- src/Renci.SshNet/Connection/HttpConnector.cs | 9 +- .../Connection/ProtocolVersionExchange.cs | 8 +- .../Connection/Socks4Connector.cs | 3 +- .../Connection/Socks5Connector.cs | 6 +- src/Renci.SshNet/ForwardedPortDynamic.cs | 10 +-- src/Renci.SshNet/Session.cs | 2 +- .../Common/Socks5Handler.cs | 10 +-- ...rTest_Connect_TimeoutReadingHttpContent.cs | 3 +- ...orTest_Connect_TimeoutReadingStatusLine.cs | 3 +- ...Test_TimeoutReadingIdentificationString.cs | 4 +- ...onnect_TimeoutReadingDestinationAddress.cs | 3 +- ...torTest_Connect_TimeoutReadingReplyCode.cs | 3 +- ...Test_Connect_TimeoutReadingReplyVersion.cs | 3 +- 17 files changed, 42 insertions(+), 145 deletions(-) diff --git a/src/Renci.SshNet/Abstractions/SocketAbstraction.Async.cs b/src/Renci.SshNet/Abstractions/SocketAbstraction.Async.cs index 6cc6918ea..9a16b951c 100644 --- a/src/Renci.SshNet/Abstractions/SocketAbstraction.Async.cs +++ b/src/Renci.SshNet/Abstractions/SocketAbstraction.Async.cs @@ -1,7 +1,5 @@ #if NET6_0_OR_GREATER -using System; -using System.Diagnostics; using System.Net.Sockets; using System.Threading; using System.Threading.Tasks; @@ -14,37 +12,6 @@ public static ValueTask ReadAsync(Socket socket, byte[] buffer, Cancellatio { return socket.ReceiveAsync(buffer, SocketFlags.None, cancellationToken); } - - public static ValueTask SendAsync(Socket socket, ReadOnlyMemory data, CancellationToken cancellationToken = default) - { - Debug.Assert(socket != null); - Debug.Assert(data.Length > 0); - - if (cancellationToken.IsCancellationRequested) - { - return ValueTask.FromCanceled(cancellationToken); - } - - return SendAsyncCore(socket, data, cancellationToken); - - static async ValueTask SendAsyncCore(Socket socket, ReadOnlyMemory data, CancellationToken cancellationToken) - { - do - { - try - { - var bytesSent = await socket.SendAsync(data, SocketFlags.None, cancellationToken).ConfigureAwait(false); - data = data.Slice(bytesSent); - } - catch (SocketException ex) when (IsErrorResumable(ex.SocketErrorCode)) - { - // Buffer may be full; attempt a short delay and retry - await Task.Delay(30, cancellationToken).ConfigureAwait(false); - } - } - while (data.Length > 0); - } - } } } #endif // NET6_0_OR_GREATER diff --git a/src/Renci.SshNet/Abstractions/SocketAbstraction.cs b/src/Renci.SshNet/Abstractions/SocketAbstraction.cs index c6032671b..d96b9189b 100644 --- a/src/Renci.SshNet/Abstractions/SocketAbstraction.cs +++ b/src/Renci.SshNet/Abstractions/SocketAbstraction.cs @@ -6,7 +6,6 @@ using System.Threading.Tasks; using Renci.SshNet.Common; -using Renci.SshNet.Messages.Transport; namespace Renci.SshNet.Abstractions { @@ -167,11 +166,6 @@ public static void ReadContinuous(Socket socket, byte[] buffer, int offset, int } catch (SocketException ex) { - if (IsErrorResumable(ex.SocketErrorCode)) - { - continue; - } - #pragma warning disable IDE0010 // Add missing cases switch (ex.SocketErrorCode) { @@ -221,7 +215,7 @@ public static int ReadByte(Socket socket, TimeSpan timeout) public static void SendByte(Socket socket, byte value) { var buffer = new[] { value }; - Send(socket, buffer, 0, 1); + _ = socket.Send(buffer); } /// @@ -288,22 +282,12 @@ public static int Read(Socket socket, byte[] buffer, int offset, int size, TimeS totalBytesRead += bytesRead; } - catch (SocketException ex) + catch (SocketException ex) when (ex.SocketErrorCode == SocketError.TimedOut) { - if (IsErrorResumable(ex.SocketErrorCode)) - { - ThreadAbstraction.Sleep(30); - continue; - } - - if (ex.SocketErrorCode == SocketError.TimedOut) - { - throw new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture, - "Socket read operation has timed out after {0:F0} milliseconds.", - readTimeout.TotalMilliseconds)); - } - - throw; + throw new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture, + "Socket read operation has timed out after {0:F0} milliseconds.", + readTimeout.TotalMilliseconds), + ex); } } while (totalBytesRead < totalBytesToRead); @@ -317,61 +301,6 @@ public static ValueTask ReadAsync(Socket socket, byte[] buffer, Cancellatio return socket.ReceiveAsync(new ArraySegment(buffer, 0, buffer.Length), SocketFlags.None, cancellationToken); } #endif - - public static void Send(Socket socket, byte[] data) - { - Send(socket, data, 0, data.Length); - } - - public static void Send(Socket socket, byte[] data, int offset, int size) - { - var totalBytesSent = 0; // how many bytes are already sent - var totalBytesToSend = size; - - do - { - try - { - var bytesSent = socket.Send(data, offset + totalBytesSent, totalBytesToSend - totalBytesSent, SocketFlags.None); - if (bytesSent == 0) - { - throw new SshConnectionException("An established connection was aborted by the server.", - DisconnectReason.ConnectionLost); - } - - totalBytesSent += bytesSent; - } - catch (SocketException ex) - { - if (IsErrorResumable(ex.SocketErrorCode)) - { - // socket buffer is probably full, wait and try again - ThreadAbstraction.Sleep(30); - } - else - { - throw; // any serious error occurr - } - } - } - while (totalBytesSent < totalBytesToSend); - } - - public static bool IsErrorResumable(SocketError socketError) - { -#pragma warning disable IDE0010 // Add missing cases - switch (socketError) - { - case SocketError.WouldBlock: - case SocketError.IOPending: - case SocketError.NoBufferSpaceAvailable: - return true; - default: - return false; - } -#pragma warning restore IDE0010 // Add missing cases - } - private static void ConnectCompleted(object sender, SocketAsyncEventArgs e) { var eventWaitHandle = (ManualResetEvent) e.UserToken; diff --git a/src/Renci.SshNet/Channels/ChannelDirectTcpip.cs b/src/Renci.SshNet/Channels/ChannelDirectTcpip.cs index 6c521bce2..3dc7f1e4c 100644 --- a/src/Renci.SshNet/Channels/ChannelDirectTcpip.cs +++ b/src/Renci.SshNet/Channels/ChannelDirectTcpip.cs @@ -201,7 +201,7 @@ protected override void OnData(byte[] data) { if (_socket.IsConnected()) { - SocketAbstraction.Send(_socket, data, 0, data.Length); + _ = _socket.Send(data); } } } diff --git a/src/Renci.SshNet/Channels/ChannelForwardedTcpip.cs b/src/Renci.SshNet/Channels/ChannelForwardedTcpip.cs index a8382015a..2e075f76b 100644 --- a/src/Renci.SshNet/Channels/ChannelForwardedTcpip.cs +++ b/src/Renci.SshNet/Channels/ChannelForwardedTcpip.cs @@ -201,7 +201,7 @@ protected override void OnData(byte[] data) var socket = _socket; if (socket.IsConnected()) { - SocketAbstraction.Send(socket, data, 0, data.Length); + _ = socket.Send(data); } } } diff --git a/src/Renci.SshNet/Connection/HttpConnector.cs b/src/Renci.SshNet/Connection/HttpConnector.cs index afbaf0f01..72f502b00 100644 --- a/src/Renci.SshNet/Connection/HttpConnector.cs +++ b/src/Renci.SshNet/Connection/HttpConnector.cs @@ -3,6 +3,7 @@ using System.Globalization; using System.Net; using System.Net.Sockets; +using System.Text; using System.Text.RegularExpressions; using Renci.SshNet.Abstractions; @@ -41,7 +42,7 @@ protected override void HandleProxyConnect(IConnectionInfo connectionInfo, Socke var httpResponseRe = new Regex(@"HTTP/(?\d[.]\d) (?\d{3}) (?.+)$"); var httpHeaderRe = new Regex(@"(?[^\[\]()<>@,;:\""/?={} \t]+):(?.+)?"); - SocketAbstraction.Send(socket, SshData.Ascii.GetBytes(string.Format(CultureInfo.InvariantCulture, + _ = socket.Send(Encoding.ASCII.GetBytes(string.Format(CultureInfo.InvariantCulture, "CONNECT {0}:{1} HTTP/1.0\r\n", connectionInfo.Host, connectionInfo.Port))); @@ -51,11 +52,11 @@ protected override void HandleProxyConnect(IConnectionInfo connectionInfo, Socke { var authorization = string.Format(CultureInfo.InvariantCulture, "Proxy-Authorization: Basic {0}\r\n", - Convert.ToBase64String(SshData.Ascii.GetBytes($"{connectionInfo.ProxyUsername}:{connectionInfo.ProxyPassword}"))); - SocketAbstraction.Send(socket, SshData.Ascii.GetBytes(authorization)); + Convert.ToBase64String(Encoding.ASCII.GetBytes($"{connectionInfo.ProxyUsername}:{connectionInfo.ProxyPassword}"))); + _ = socket.Send(Encoding.ASCII.GetBytes(authorization)); } - SocketAbstraction.Send(socket, SshData.Ascii.GetBytes("\r\n")); + _ = socket.Send(Encoding.ASCII.GetBytes("\r\n")); HttpStatusCode? statusCode = null; var contentLength = 0; diff --git a/src/Renci.SshNet/Connection/ProtocolVersionExchange.cs b/src/Renci.SshNet/Connection/ProtocolVersionExchange.cs index b14da93c0..8cb5e9310 100644 --- a/src/Renci.SshNet/Connection/ProtocolVersionExchange.cs +++ b/src/Renci.SshNet/Connection/ProtocolVersionExchange.cs @@ -38,7 +38,7 @@ public SshIdentification Start(string clientVersion, Socket socket, TimeSpan tim { // Immediately send the identification string since the spec states both sides MUST send an identification string // when the connection has been established - SocketAbstraction.Send(socket, Encoding.UTF8.GetBytes(clientVersion + "\x0D\x0A")); + _ = socket.Send(Encoding.UTF8.GetBytes(clientVersion + "\x0D\x0A")); var bytesReceived = new List(); @@ -81,11 +81,7 @@ public async Task StartAsync(string clientVersion, Socket soc { // Immediately send the identification string since the spec states both sides MUST send an identification string // when the connection has been established -#if NET6_0_OR_GREATER - await SocketAbstraction.SendAsync(socket, Encoding.UTF8.GetBytes(clientVersion + "\x0D\x0A"), cancellationToken).ConfigureAwait(false); -#else - SocketAbstraction.Send(socket, Encoding.UTF8.GetBytes(clientVersion + "\x0D\x0A")); -#endif // NET6_0_OR_GREATER + _ = await socket.SendAsync(Encoding.UTF8.GetBytes(clientVersion + "\x0D\x0A"), SocketFlags.None, cancellationToken).ConfigureAwait(false); var bytesReceived = new List(); diff --git a/src/Renci.SshNet/Connection/Socks4Connector.cs b/src/Renci.SshNet/Connection/Socks4Connector.cs index e3e9800f0..f8b0926b2 100644 --- a/src/Renci.SshNet/Connection/Socks4Connector.cs +++ b/src/Renci.SshNet/Connection/Socks4Connector.cs @@ -3,7 +3,6 @@ using System.Net.Sockets; using System.Text; -using Renci.SshNet.Abstractions; using Renci.SshNet.Common; namespace Renci.SshNet.Connection @@ -29,7 +28,7 @@ public Socks4Connector(ISocketFactory socketFactory) protected override void HandleProxyConnect(IConnectionInfo connectionInfo, Socket socket) { var connectionRequest = CreateSocks4ConnectionRequest(connectionInfo.Host, (ushort)connectionInfo.Port, connectionInfo.ProxyUsername); - SocketAbstraction.Send(socket, connectionRequest); + _ = socket.Send(connectionRequest); // Read reply version if (SocketReadByte(socket, connectionInfo.Timeout) != 0x00) diff --git a/src/Renci.SshNet/Connection/Socks5Connector.cs b/src/Renci.SshNet/Connection/Socks5Connector.cs index ecd286e00..306db7002 100644 --- a/src/Renci.SshNet/Connection/Socks5Connector.cs +++ b/src/Renci.SshNet/Connection/Socks5Connector.cs @@ -41,7 +41,7 @@ protected override void HandleProxyConnect(IConnectionInfo connectionInfo, Socke // Username/Password authentication 0x02 }; - SocketAbstraction.Send(socket, greeting); + _ = socket.Send(greeting); var socksVersion = SocketReadByte(socket); if (socksVersion != 0x05) @@ -60,7 +60,7 @@ protected override void HandleProxyConnect(IConnectionInfo connectionInfo, Socke var authenticationRequest = CreateSocks5UserNameAndPasswordAuthenticationRequest(connectionInfo.ProxyUsername, connectionInfo.ProxyPassword); // Send authentication request - SocketAbstraction.Send(socket, authenticationRequest); + _ = socket.Send(authenticationRequest); // Read authentication result var authenticationResult = SocketAbstraction.Read(socket, 2, connectionInfo.Timeout); @@ -83,7 +83,7 @@ protected override void HandleProxyConnect(IConnectionInfo connectionInfo, Socke } var connectionRequest = CreateSocks5ConnectionRequest(connectionInfo.Host, (ushort) connectionInfo.Port); - SocketAbstraction.Send(socket, connectionRequest); + _ = socket.Send(connectionRequest); // Read Server SOCKS5 version if (SocketReadByte(socket) != 5) diff --git a/src/Renci.SshNet/ForwardedPortDynamic.cs b/src/Renci.SshNet/ForwardedPortDynamic.cs index 2a2c45f2c..440224efd 100644 --- a/src/Renci.SshNet/ForwardedPortDynamic.cs +++ b/src/Renci.SshNet/ForwardedPortDynamic.cs @@ -508,8 +508,8 @@ private bool HandleSocks4(Socket socket, IChannelDirectTcpip channel, TimeSpan t if (channel.IsOpen) { SocketAbstraction.SendByte(socket, 0x5a); - SocketAbstraction.Send(socket, portBuffer, 0, portBuffer.Length); - SocketAbstraction.Send(socket, ipBuffer, 0, ipBuffer.Length); + _ = socket.Send(portBuffer); + _ = socket.Send(ipBuffer); return true; } @@ -538,12 +538,12 @@ private bool HandleSocks5(Socket socket, IChannelDirectTcpip channel, TimeSpan t { // no user authentication is one of the authentication methods supported // by the SOCKS client - SocketAbstraction.Send(socket, new byte[] { 0x05, 0x00 }, 0, 2); + _ = socket.Send([0x05, 0x00]); } else { // the SOCKS client requires authentication, which we currently do not support - SocketAbstraction.Send(socket, new byte[] { 0x05, 0xFF }, 0, 2); + _ = socket.Send([0x05, 0xFF]); // we continue business as usual but expect the client to close the connection // so one of the subsequent reads should return -1 signaling that the client @@ -610,7 +610,7 @@ private bool HandleSocks5(Socket socket, IChannelDirectTcpip channel, TimeSpan t var socksReply = CreateSocks5Reply(channel.IsOpen); - SocketAbstraction.Send(socket, socksReply, 0, socksReply.Length); + _ = socket.Send(socksReply); return true; } diff --git a/src/Renci.SshNet/Session.cs b/src/Renci.SshNet/Session.cs index 37b7f2db3..2444aced3 100644 --- a/src/Renci.SshNet/Session.cs +++ b/src/Renci.SshNet/Session.cs @@ -1160,7 +1160,7 @@ private void SendPacket(byte[] packet, int offset, int length) throw new SshConnectionException("Client not connected."); } - SocketAbstraction.Send(_socket, packet, offset, length); + _ = _socket.Send(packet, offset, length, SocketFlags.None); } finally { diff --git a/test/Renci.SshNet.IntegrationTests/Common/Socks5Handler.cs b/test/Renci.SshNet.IntegrationTests/Common/Socks5Handler.cs index e50858c33..8ed929f71 100644 --- a/test/Renci.SshNet.IntegrationTests/Common/Socks5Handler.cs +++ b/test/Renci.SshNet.IntegrationTests/Common/Socks5Handler.cs @@ -90,7 +90,7 @@ private Socket Connect(byte[] addressBytes, int port) SocketWriteByte(socket, (byte) username.Length); // Send username - SocketAbstraction.Send(socket, username); + _ = socket.Send(username); var password = Encoding.ASCII.GetBytes(_password); @@ -99,11 +99,11 @@ private Socket Connect(byte[] addressBytes, int port) throw new ProxyException("Proxy password is too long."); } - // Send username length + // Send password length SocketWriteByte(socket, (byte) password.Length); - // Send username - SocketAbstraction.Send(socket, password); + // Send password + _ = socket.Send(password); var serverVersion = SocketReadByte(socket); @@ -135,7 +135,7 @@ private Socket Connect(byte[] addressBytes, int port) SocketWriteByte(socket, 0x00); // Send address type and address - SocketAbstraction.Send(socket, addressBytes); + _ = socket.Send(addressBytes); // Send port SocketWriteByte(socket, (byte)(port / 0xFF)); diff --git a/test/Renci.SshNet.Tests/Classes/Connection/HttpConnectorTest_Connect_TimeoutReadingHttpContent.cs b/test/Renci.SshNet.Tests/Classes/Connection/HttpConnectorTest_Connect_TimeoutReadingHttpContent.cs index 6808bfb50..22af7f447 100644 --- a/test/Renci.SshNet.Tests/Classes/Connection/HttpConnectorTest_Connect_TimeoutReadingHttpContent.cs +++ b/test/Renci.SshNet.Tests/Classes/Connection/HttpConnectorTest_Connect_TimeoutReadingHttpContent.cs @@ -119,7 +119,8 @@ protected override void Act() [TestMethod] public void ConnectShouldHaveThrownSshOperationTimeoutException() { - Assert.IsNull(_actualException.InnerException); + Assert.IsInstanceOfType(_actualException); + Assert.IsInstanceOfType(_actualException.InnerException); Assert.AreEqual(string.Format(CultureInfo.InvariantCulture, "Socket read operation has timed out after {0:F0} milliseconds.", _connectionInfo.Timeout.TotalMilliseconds), _actualException.Message); } diff --git a/test/Renci.SshNet.Tests/Classes/Connection/HttpConnectorTest_Connect_TimeoutReadingStatusLine.cs b/test/Renci.SshNet.Tests/Classes/Connection/HttpConnectorTest_Connect_TimeoutReadingStatusLine.cs index 38f65634c..d16af22c0 100644 --- a/test/Renci.SshNet.Tests/Classes/Connection/HttpConnectorTest_Connect_TimeoutReadingStatusLine.cs +++ b/test/Renci.SshNet.Tests/Classes/Connection/HttpConnectorTest_Connect_TimeoutReadingStatusLine.cs @@ -95,7 +95,8 @@ protected override void Act() [TestMethod] public void ConnectShouldHaveThrownSshOperationTimeoutException() { - Assert.IsNull(_actualException.InnerException); + Assert.IsInstanceOfType(_actualException); + Assert.IsInstanceOfType(_actualException.InnerException); Assert.AreEqual(string.Format(CultureInfo.InvariantCulture, "Socket read operation has timed out after {0:F0} milliseconds.", _connectionInfo.Timeout.TotalMilliseconds), _actualException.Message); } diff --git a/test/Renci.SshNet.Tests/Classes/Connection/ProtocolVersionExchangeTest_TimeoutReadingIdentificationString.cs b/test/Renci.SshNet.Tests/Classes/Connection/ProtocolVersionExchangeTest_TimeoutReadingIdentificationString.cs index 3710e2064..a3f1b0e9b 100644 --- a/test/Renci.SshNet.Tests/Classes/Connection/ProtocolVersionExchangeTest_TimeoutReadingIdentificationString.cs +++ b/test/Renci.SshNet.Tests/Classes/Connection/ProtocolVersionExchangeTest_TimeoutReadingIdentificationString.cs @@ -93,8 +93,8 @@ protected void Act() [TestMethod] public void StartShouldHaveThrownSshOperationTimeoutException() { - Assert.IsNotNull(_actualException); - Assert.IsNull(_actualException.InnerException); + Assert.IsInstanceOfType(_actualException); + Assert.IsInstanceOfType(_actualException.InnerException); Assert.AreEqual(string.Format("Socket read operation has timed out after {0} milliseconds.", _timeout.TotalMilliseconds), _actualException.Message); } diff --git a/test/Renci.SshNet.Tests/Classes/Connection/Socks4ConnectorTest_Connect_TimeoutReadingDestinationAddress.cs b/test/Renci.SshNet.Tests/Classes/Connection/Socks4ConnectorTest_Connect_TimeoutReadingDestinationAddress.cs index d87969ced..3604aac72 100644 --- a/test/Renci.SshNet.Tests/Classes/Connection/Socks4ConnectorTest_Connect_TimeoutReadingDestinationAddress.cs +++ b/test/Renci.SshNet.Tests/Classes/Connection/Socks4ConnectorTest_Connect_TimeoutReadingDestinationAddress.cs @@ -97,7 +97,8 @@ protected override void Act() [TestMethod] public void ConnectShouldHaveThrownSshOperationTimeoutException() { - Assert.IsNull(_actualException.InnerException); + Assert.IsInstanceOfType(_actualException); + Assert.IsInstanceOfType(_actualException.InnerException); Assert.AreEqual(string.Format(CultureInfo.InvariantCulture, "Socket read operation has timed out after {0:F0} milliseconds.", _connectionInfo.Timeout.TotalMilliseconds), _actualException.Message); } diff --git a/test/Renci.SshNet.Tests/Classes/Connection/Socks4ConnectorTest_Connect_TimeoutReadingReplyCode.cs b/test/Renci.SshNet.Tests/Classes/Connection/Socks4ConnectorTest_Connect_TimeoutReadingReplyCode.cs index 8f6ee9019..19a4b777b 100644 --- a/test/Renci.SshNet.Tests/Classes/Connection/Socks4ConnectorTest_Connect_TimeoutReadingReplyCode.cs +++ b/test/Renci.SshNet.Tests/Classes/Connection/Socks4ConnectorTest_Connect_TimeoutReadingReplyCode.cs @@ -93,7 +93,8 @@ protected override void Act() [TestMethod] public void ConnectShouldHaveThrownSshOperationTimeoutException() { - Assert.IsNull(_actualException.InnerException); + Assert.IsInstanceOfType(_actualException); + Assert.IsInstanceOfType(_actualException.InnerException); Assert.AreEqual(string.Format(CultureInfo.InvariantCulture, "Socket read operation has timed out after {0:F0} milliseconds.", _connectionInfo.Timeout.TotalMilliseconds), _actualException.Message); } diff --git a/test/Renci.SshNet.Tests/Classes/Connection/Socks4ConnectorTest_Connect_TimeoutReadingReplyVersion.cs b/test/Renci.SshNet.Tests/Classes/Connection/Socks4ConnectorTest_Connect_TimeoutReadingReplyVersion.cs index 4ca6e0c58..0e5ace5e5 100644 --- a/test/Renci.SshNet.Tests/Classes/Connection/Socks4ConnectorTest_Connect_TimeoutReadingReplyVersion.cs +++ b/test/Renci.SshNet.Tests/Classes/Connection/Socks4ConnectorTest_Connect_TimeoutReadingReplyVersion.cs @@ -82,7 +82,8 @@ protected override void Act() [TestMethod] public void ConnectShouldHaveThrownSshOperationTimeoutException() { - Assert.IsNull(_actualException.InnerException); + Assert.IsInstanceOfType(_actualException); + Assert.IsInstanceOfType(_actualException.InnerException); Assert.AreEqual(string.Format(CultureInfo.InvariantCulture, "Socket read operation has timed out after {0:F0} milliseconds.", _connectionInfo.Timeout.TotalMilliseconds), _actualException.Message); } From 75ced08eb3bdf71369172c697ed0f6066074920b Mon Sep 17 00:00:00 2001 From: Rob Hague Date: Mon, 26 Feb 2024 20:25:22 +0100 Subject: [PATCH 3/4] Cleanup SocketAbstraction * Use "using" and ManualResetEventSlim in Connect * Delete unused and unnecessary methods --- .../Abstractions/SocketAbstraction.cs | 169 ++---------------- .../Channels/ChannelForwardedTcpip.cs | 9 +- src/Renci.SshNet/Common/Extensions.cs | 13 +- src/Renci.SshNet/Connection/ConnectorBase.cs | 2 +- .../Connection/Socks5Connector.cs | 3 +- src/Renci.SshNet/ForwardedPortDynamic.cs | 6 +- src/Renci.SshNet/Session.cs | 1 + .../Common/Socks5Handler.cs | 35 ++-- ...pose_SessionIsConnectedAndChannelIsOpen.cs | 2 + 9 files changed, 48 insertions(+), 192 deletions(-) diff --git a/src/Renci.SshNet/Abstractions/SocketAbstraction.cs b/src/Renci.SshNet/Abstractions/SocketAbstraction.cs index d96b9189b..2c04e4c47 100644 --- a/src/Renci.SshNet/Abstractions/SocketAbstraction.cs +++ b/src/Renci.SshNet/Abstractions/SocketAbstraction.cs @@ -3,7 +3,9 @@ using System.Net; using System.Net.Sockets; using System.Threading; +#if NET6_0_OR_GREATER == false using System.Threading.Tasks; +#endif using Renci.SshNet.Common; @@ -11,78 +13,20 @@ namespace Renci.SshNet.Abstractions { internal static partial class SocketAbstraction { - public static bool CanRead(Socket socket) - { - if (socket.Connected) - { - return socket.Poll(-1, SelectMode.SelectRead) && socket.Available > 0; - } - - return false; - } - - /// - /// Returns a value indicating whether the specified can be used - /// to send data. - /// - /// The to check. - /// - /// if can be written to; otherwise, . - /// - public static bool CanWrite(Socket socket) - { - if (socket != null && socket.Connected) - { - return socket.Poll(-1, SelectMode.SelectWrite); - } - - return false; - } - - public static Socket Connect(IPEndPoint remoteEndpoint, TimeSpan connectTimeout) - { - var socket = new Socket(remoteEndpoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp) { NoDelay = true }; - ConnectCore(socket, remoteEndpoint, connectTimeout, ownsSocket: true); - return socket; - } - public static void Connect(Socket socket, IPEndPoint remoteEndpoint, TimeSpan connectTimeout) { - ConnectCore(socket, remoteEndpoint, connectTimeout, ownsSocket: false); - } - - public static async Task ConnectAsync(Socket socket, IPEndPoint remoteEndpoint, CancellationToken cancellationToken) - { - await socket.ConnectAsync(remoteEndpoint, cancellationToken).ConfigureAwait(false); - } - - private static void ConnectCore(Socket socket, IPEndPoint remoteEndpoint, TimeSpan connectTimeout, bool ownsSocket) - { - var connectCompleted = new ManualResetEvent(initialState: false); - var args = new SocketAsyncEventArgs - { - UserToken = connectCompleted, - RemoteEndPoint = remoteEndpoint - }; - args.Completed += ConnectCompleted; + using var connectCompleted = new ManualResetEventSlim(initialState: false); + using var args = new SocketAsyncEventArgs + { + RemoteEndPoint = remoteEndpoint + }; + args.Completed += (_, _) => connectCompleted.Set(); if (socket.ConnectAsync(args)) { - if (!connectCompleted.WaitOne(connectTimeout)) + if (!connectCompleted.Wait(connectTimeout)) { - // avoid ObjectDisposedException in ConnectCompleted - args.Completed -= ConnectCompleted; - if (ownsSocket) - { - // dispose Socket - socket.Dispose(); - } - - // dispose ManualResetEvent - connectCompleted.Dispose(); - - // dispose SocketAsyncEventArgs - args.Dispose(); + socket.Dispose(); throw new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture, "Connection failed to establish within {0:F0} milliseconds.", @@ -90,61 +34,12 @@ private static void ConnectCore(Socket socket, IPEndPoint remoteEndpoint, TimeSp } } - // dispose ManualResetEvent - connectCompleted.Dispose(); - if (args.SocketError != SocketError.Success) { var socketError = (int) args.SocketError; - if (ownsSocket) - { - // dispose Socket - socket.Dispose(); - } - - // dispose SocketAsyncEventArgs - args.Dispose(); - throw new SocketException(socketError); } - - // dispose SocketAsyncEventArgs - args.Dispose(); - } - - public static void ClearReadBuffer(Socket socket) - { - var timeout = TimeSpan.FromMilliseconds(500); - var buffer = new byte[256]; - int bytesReceived; - - do - { - bytesReceived = ReadPartial(socket, buffer, 0, buffer.Length, timeout); - } - while (bytesReceived > 0); - } - - public static int ReadPartial(Socket socket, byte[] buffer, int offset, int size, TimeSpan timeout) - { - socket.ReceiveTimeout = timeout.AsTimeout(nameof(timeout)); - - try - { - return socket.Receive(buffer, offset, size, SocketFlags.None); - } - catch (SocketException ex) - { - if (ex.SocketErrorCode == SocketError.TimedOut) - { - throw new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture, - "Socket read operation has timed out after {0:F0} milliseconds.", - timeout.TotalMilliseconds)); - } - - throw; - } } public static void ReadContinuous(Socket socket, byte[] buffer, int offset, int size, Action processReceivedBytesAction) @@ -206,41 +101,6 @@ public static int ReadByte(Socket socket, TimeSpan timeout) return buffer[0]; } - /// - /// Sends a byte using the specified . - /// - /// The to write to. - /// The value to send. - /// The write failed. - public static void SendByte(Socket socket, byte value) - { - var buffer = new[] { value }; - _ = socket.Send(buffer); - } - - /// - /// Receives data from a bound . - /// - /// The to read from. - /// The number of bytes to receive. - /// Specifies the amount of time after which the call will time out. - /// - /// The bytes received. - /// - /// - /// If no data is available for reading, the method will - /// block until data is available or the time-out value is exceeded. If the time-out value is exceeded, the - /// call will throw a . - /// If you are in non-blocking mode, and there is no data available in the in the protocol stack buffer, the - /// method will complete immediately and throw a . - /// - public static byte[] Read(Socket socket, int size, TimeSpan timeout) - { - var buffer = new byte[size]; - _ = Read(socket, buffer, 0, size, timeout); - return buffer; - } - /// /// Receives data from a bound into a receive buffer. /// @@ -258,10 +118,6 @@ public static byte[] Read(Socket socket, int size, TimeSpan timeout) /// block until data is available or the time-out value is exceeded. If the time-out value is exceeded, the /// call will throw a . /// - /// - /// If you are in non-blocking mode, and there is no data available in the in the protocol stack buffer, the - /// method will complete immediately and throw a . - /// /// public static int Read(Socket socket, byte[] buffer, int offset, int size, TimeSpan readTimeout) { @@ -301,10 +157,5 @@ public static ValueTask ReadAsync(Socket socket, byte[] buffer, Cancellatio return socket.ReceiveAsync(new ArraySegment(buffer, 0, buffer.Length), SocketFlags.None, cancellationToken); } #endif - private static void ConnectCompleted(object sender, SocketAsyncEventArgs e) - { - var eventWaitHandle = (ManualResetEvent) e.UserToken; - _ = eventWaitHandle?.Set(); - } } } diff --git a/src/Renci.SshNet/Channels/ChannelForwardedTcpip.cs b/src/Renci.SshNet/Channels/ChannelForwardedTcpip.cs index 2e075f76b..d881a7fa7 100644 --- a/src/Renci.SshNet/Channels/ChannelForwardedTcpip.cs +++ b/src/Renci.SshNet/Channels/ChannelForwardedTcpip.cs @@ -3,6 +3,7 @@ using System.Net.Sockets; using Renci.SshNet.Abstractions; using Renci.SshNet.Common; +using Renci.SshNet.Connection; using Renci.SshNet.Messages.Connection; namespace Renci.SshNet.Channels @@ -13,6 +14,7 @@ namespace Renci.SshNet.Channels internal sealed class ChannelForwardedTcpip : ServerChannel, IChannelForwardedTcpip { private readonly object _socketShutdownAndCloseLock = new object(); + private readonly ISocketFactory _socketFactory; private Socket _socket; private IForwardedPort _forwardedPort; @@ -20,6 +22,7 @@ internal sealed class ChannelForwardedTcpip : ServerChannel, IChannelForwardedTc /// Initializes a new instance of the class. /// /// The session. + /// The socket factory. /// The local channel number. /// Size of the window. /// Size of the packet. @@ -27,6 +30,7 @@ internal sealed class ChannelForwardedTcpip : ServerChannel, IChannelForwardedTc /// The window size of the remote party. /// The maximum size of a data packet that we can send to the remote party. internal ChannelForwardedTcpip(ISession session, + ISocketFactory socketFactory, uint localChannelNumber, uint localWindowSize, uint localPacketSize, @@ -41,6 +45,7 @@ internal ChannelForwardedTcpip(ISession session, remoteWindowSize, remotePacketSize) { + _socketFactory = socketFactory; } /// @@ -72,7 +77,9 @@ public void Bind(IPEndPoint remoteEndpoint, IForwardedPort forwardedPort) // Try to connect to the socket try { - _socket = SocketAbstraction.Connect(remoteEndpoint, ConnectionInfo.Timeout); + _socket = _socketFactory.Create(remoteEndpoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp); + + SocketAbstraction.Connect(_socket, remoteEndpoint, ConnectionInfo.Timeout); // Send channel open confirmation message SendMessage(new ChannelOpenConfirmationMessage(RemoteChannelNumber, LocalWindowSize, LocalPacketSize, LocalChannelNumber)); diff --git a/src/Renci.SshNet/Common/Extensions.cs b/src/Renci.SshNet/Common/Extensions.cs index 80fa8323d..2eb89e3e4 100644 --- a/src/Renci.SshNet/Common/Extensions.cs +++ b/src/Renci.SshNet/Common/Extensions.cs @@ -5,7 +5,7 @@ using System.Net; using System.Net.Sockets; using System.Text; -using Renci.SshNet.Abstractions; + using Renci.SshNet.Messages; namespace Renci.SshNet.Common @@ -336,22 +336,17 @@ public static byte[] Concat(this byte[] first, byte[] second) internal static bool CanRead(this Socket socket) { - return SocketAbstraction.CanRead(socket); + return socket.Connected && socket.Poll(-1, SelectMode.SelectRead) && socket.Available > 0; } internal static bool CanWrite(this Socket socket) { - return SocketAbstraction.CanWrite(socket); + return socket is not null && socket.Connected && socket.Poll(-1, SelectMode.SelectWrite); } internal static bool IsConnected(this Socket socket) { - if (socket is null) - { - return false; - } - - return socket.Connected; + return socket is not null && socket.Connected; } } } diff --git a/src/Renci.SshNet/Connection/ConnectorBase.cs b/src/Renci.SshNet/Connection/ConnectorBase.cs index 384091b18..9eaaee015 100644 --- a/src/Renci.SshNet/Connection/ConnectorBase.cs +++ b/src/Renci.SshNet/Connection/ConnectorBase.cs @@ -86,7 +86,7 @@ protected async Task SocketConnectAsync(string host, int port, Cancellat var socket = SocketFactory.Create(ep.AddressFamily, SocketType.Stream, ProtocolType.Tcp); try { - await SocketAbstraction.ConnectAsync(socket, ep, cancellationToken).ConfigureAwait(false); + await socket.ConnectAsync(ep, cancellationToken).ConfigureAwait(false); const int socketBufferSize = 2 * Session.MaximumSshPacketSize; socket.SendBufferSize = socketBufferSize; diff --git a/src/Renci.SshNet/Connection/Socks5Connector.cs b/src/Renci.SshNet/Connection/Socks5Connector.cs index 306db7002..8c42e591e 100644 --- a/src/Renci.SshNet/Connection/Socks5Connector.cs +++ b/src/Renci.SshNet/Connection/Socks5Connector.cs @@ -63,7 +63,8 @@ protected override void HandleProxyConnect(IConnectionInfo connectionInfo, Socke _ = socket.Send(authenticationRequest); // Read authentication result - var authenticationResult = SocketAbstraction.Read(socket, 2, connectionInfo.Timeout); + var authenticationResult = new byte[2]; + _ = SocketAbstraction.Read(socket, authenticationResult, 0, authenticationResult.Length, connectionInfo.Timeout); if (authenticationResult[0] != 0x01) { diff --git a/src/Renci.SshNet/ForwardedPortDynamic.cs b/src/Renci.SshNet/ForwardedPortDynamic.cs index 440224efd..b6705eb55 100644 --- a/src/Renci.SshNet/ForwardedPortDynamic.cs +++ b/src/Renci.SshNet/ForwardedPortDynamic.cs @@ -503,18 +503,18 @@ private bool HandleSocks4(Socket socket, IChannelDirectTcpip channel, TimeSpan t channel.Open(host, port, this, socket); - SocketAbstraction.SendByte(socket, 0x00); + _ = socket.Send([0x00]); if (channel.IsOpen) { - SocketAbstraction.SendByte(socket, 0x5a); + _ = socket.Send([0x5a]); _ = socket.Send(portBuffer); _ = socket.Send(ipBuffer); return true; } // signal that request was rejected or failed - SocketAbstraction.SendByte(socket, 0x5b); + _ = socket.Send([0x5b]); return false; } diff --git a/src/Renci.SshNet/Session.cs b/src/Renci.SshNet/Session.cs index 2444aced3..9eba01d2b 100644 --- a/src/Renci.SshNet/Session.cs +++ b/src/Renci.SshNet/Session.cs @@ -2232,6 +2232,7 @@ IChannelForwardedTcpip ISession.CreateChannelForwardedTcpip(uint remoteChannelNu uint remoteChannelDataPacketSize) { return new ChannelForwardedTcpip(this, + _socketFactory, NextChannelNumber, InitialLocalWindowSize, LocalChannelDataPacketSize, diff --git a/test/Renci.SshNet.IntegrationTests/Common/Socks5Handler.cs b/test/Renci.SshNet.IntegrationTests/Common/Socks5Handler.cs index 8ed929f71..69d9bf1af 100644 --- a/test/Renci.SshNet.IntegrationTests/Common/Socks5Handler.cs +++ b/test/Renci.SshNet.IntegrationTests/Common/Socks5Handler.cs @@ -3,18 +3,21 @@ using Renci.SshNet.Abstractions; using Renci.SshNet.Common; +using Renci.SshNet.Connection; using Renci.SshNet.Messages.Transport; namespace Renci.SshNet.IntegrationTests.Common { class Socks5Handler { + private readonly ISocketFactory _socketFactory; private readonly IPEndPoint _proxyEndPoint; private readonly string _userName; private readonly string _password; public Socks5Handler(IPEndPoint proxyEndPoint, string userName, string password) { + _socketFactory = new SocketFactory(); _proxyEndPoint = proxyEndPoint; _userName = userName; _password = password; @@ -52,17 +55,19 @@ public Socket Connect(string host, int port) private Socket Connect(byte[] addressBytes, int port) { - var socket = SocketAbstraction.Connect(_proxyEndPoint, TimeSpan.FromSeconds(5)); + var socket = _socketFactory.Create(_proxyEndPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp); + + SocketAbstraction.Connect(socket, _proxyEndPoint, TimeSpan.FromSeconds(5)); // Send socks version number - SocketWriteByte(socket, 0x05); + _ = socket.Send([0x05]); // Send number of supported authentication methods - SocketWriteByte(socket, 0x02); + _ = socket.Send([0x02]); // Send supported authentication methods - SocketWriteByte(socket, 0x00); // No authentication - SocketWriteByte(socket, 0x02); // Username/Password + _ = socket.Send([0x00]); // No authentication + _ = socket.Send([0x02]); // Username/Password var socksVersion = SocketReadByte(socket); if (socksVersion != 0x05) @@ -78,7 +83,7 @@ private Socket Connect(byte[] addressBytes, int port) case 0x02: // Send version - SocketWriteByte(socket, 0x01); + _ = socket.Send([0x01]); var username = Encoding.ASCII.GetBytes(_userName); if (username.Length > byte.MaxValue) @@ -87,7 +92,7 @@ private Socket Connect(byte[] addressBytes, int port) } // Send username length - SocketWriteByte(socket, (byte) username.Length); + _ = socket.Send([(byte) username.Length]); // Send username _ = socket.Send(username); @@ -100,7 +105,7 @@ private Socket Connect(byte[] addressBytes, int port) } // Send password length - SocketWriteByte(socket, (byte) password.Length); + _ = socket.Send([(byte) password.Length]); // Send password _ = socket.Send(password); @@ -126,20 +131,19 @@ private Socket Connect(byte[] addressBytes, int port) } // Send socks version number - SocketWriteByte(socket, 0x05); + _ = socket.Send([0x05]); // Send command code - SocketWriteByte(socket, 0x01); // establish a TCP/IP stream connection + _ = socket.Send([0x01]); // establish a TCP/IP stream connection // Send reserved, must be 0x00 - SocketWriteByte(socket, 0x00); + _ = socket.Send([0x00]); // Send address type and address _ = socket.Send(addressBytes); // Send port - SocketWriteByte(socket, (byte)(port / 0xFF)); - SocketWriteByte(socket, (byte)(port % 0xFF)); + _ = socket.Send([(byte) (port / 0xFF), (byte) (port % 0xFF)]); // Read Server SOCKS5 version if (SocketReadByte(socket) != 5) @@ -226,11 +230,6 @@ private static byte[] GetAddressBytes(IPEndPoint endPoint) throw new ProxyException(string.Format("SOCKS5: IP address '{0}' is not supported.", endPoint.Address)); } - private static void SocketWriteByte(Socket socket, byte data) - { - SocketAbstraction.Send(socket, new[] { data }); - } - private static byte SocketReadByte(Socket socket) { var buffer = new byte[1]; diff --git a/test/Renci.SshNet.Tests/Classes/Channels/ChannelForwardedTcpipTest_Dispose_SessionIsConnectedAndChannelIsOpen.cs b/test/Renci.SshNet.Tests/Classes/Channels/ChannelForwardedTcpipTest_Dispose_SessionIsConnectedAndChannelIsOpen.cs index 393eae093..df5fc4a95 100644 --- a/test/Renci.SshNet.Tests/Classes/Channels/ChannelForwardedTcpipTest_Dispose_SessionIsConnectedAndChannelIsOpen.cs +++ b/test/Renci.SshNet.Tests/Classes/Channels/ChannelForwardedTcpipTest_Dispose_SessionIsConnectedAndChannelIsOpen.cs @@ -9,6 +9,7 @@ using Moq; using Renci.SshNet.Channels; +using Renci.SshNet.Connection; using Renci.SshNet.Messages.Connection; using Renci.SshNet.Tests.Common; @@ -140,6 +141,7 @@ private void Arrange() _remoteListener.Start(); _channel = new ChannelForwardedTcpip(_sessionMock.Object, + new SocketFactory(), _localChannelNumber, _localWindowSize, _localPacketSize, From 3f6accb32e36c37a7cf65716e51e733bd3c00684 Mon Sep 17 00:00:00 2001 From: Rob Hague Date: Mon, 4 Mar 2024 22:10:02 +0100 Subject: [PATCH 4/4] Add a loop to SocketAbstraction.ReadAsync In order to ensure the buffer is read completely, as in SocketAbstraction.Read --- .../Abstractions/SocketAbstraction.Async.cs | 17 ---------- .../Abstractions/SocketAbstraction.cs | 33 +++++++++++++++---- .../Connection/ProtocolVersionExchange.cs | 2 +- 3 files changed, 28 insertions(+), 24 deletions(-) delete mode 100644 src/Renci.SshNet/Abstractions/SocketAbstraction.Async.cs diff --git a/src/Renci.SshNet/Abstractions/SocketAbstraction.Async.cs b/src/Renci.SshNet/Abstractions/SocketAbstraction.Async.cs deleted file mode 100644 index 9a16b951c..000000000 --- a/src/Renci.SshNet/Abstractions/SocketAbstraction.Async.cs +++ /dev/null @@ -1,17 +0,0 @@ -#if NET6_0_OR_GREATER - -using System.Net.Sockets; -using System.Threading; -using System.Threading.Tasks; - -namespace Renci.SshNet.Abstractions -{ - internal static partial class SocketAbstraction - { - public static ValueTask ReadAsync(Socket socket, byte[] buffer, CancellationToken cancellationToken) - { - return socket.ReceiveAsync(buffer, SocketFlags.None, cancellationToken); - } - } -} -#endif // NET6_0_OR_GREATER diff --git a/src/Renci.SshNet/Abstractions/SocketAbstraction.cs b/src/Renci.SshNet/Abstractions/SocketAbstraction.cs index 2c04e4c47..ee094c3c3 100644 --- a/src/Renci.SshNet/Abstractions/SocketAbstraction.cs +++ b/src/Renci.SshNet/Abstractions/SocketAbstraction.cs @@ -3,9 +3,7 @@ using System.Net; using System.Net.Sockets; using System.Threading; -#if NET6_0_OR_GREATER == false using System.Threading.Tasks; -#endif using Renci.SshNet.Common; @@ -151,11 +149,34 @@ public static int Read(Socket socket, byte[] buffer, int offset, int size, TimeS return totalBytesRead; } -#if NET6_0_OR_GREATER == false - public static ValueTask ReadAsync(Socket socket, byte[] buffer, CancellationToken cancellationToken) + public static async ValueTask ReadAsync(Socket socket, byte[] buffer, int offset, int size, CancellationToken cancellationToken) { - return socket.ReceiveAsync(new ArraySegment(buffer, 0, buffer.Length), SocketFlags.None, cancellationToken); + var totalBytesRead = 0; + var totalBytesToRead = size; + + do + { + try + { + var bytesRead = await socket.ReceiveAsync(new ArraySegment(buffer, offset + totalBytesRead, totalBytesToRead - totalBytesRead), SocketFlags.None, cancellationToken).ConfigureAwait(false); + if (bytesRead == 0) + { + return 0; + } + + totalBytesRead += bytesRead; + } + catch (SocketException ex) when (ex.SocketErrorCode == SocketError.TimedOut) + { + throw new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture, + "Socket read operation has timed out after {0:F0} milliseconds.", + socket.ReceiveTimeout), + ex); + } + } + while (totalBytesRead < totalBytesToRead); + + return totalBytesRead; } -#endif } } diff --git a/src/Renci.SshNet/Connection/ProtocolVersionExchange.cs b/src/Renci.SshNet/Connection/ProtocolVersionExchange.cs index 8cb5e9310..607249e2b 100644 --- a/src/Renci.SshNet/Connection/ProtocolVersionExchange.cs +++ b/src/Renci.SshNet/Connection/ProtocolVersionExchange.cs @@ -187,7 +187,7 @@ private static async Task SocketReadLineAsync(Socket socket, List // to be processed by subsequent invocations. while (true) { - var bytesRead = await SocketAbstraction.ReadAsync(socket, data, cancellationToken).ConfigureAwait(false); + var bytesRead = await SocketAbstraction.ReadAsync(socket, data, 0, data.Length, cancellationToken).ConfigureAwait(false); if (bytesRead == 0) { throw new SshConnectionException("The connection was closed by the remote host.");