From 940dd090506142e4198a11f24fe883bbe26166f0 Mon Sep 17 00:00:00 2001 From: Andrew Druk Date: Mon, 4 Aug 2025 14:53:38 +0300 Subject: [PATCH] Fix WebSocket buffered read Add support for fragmented messages Buffered socket reads could result in incomplete frame parsing due to incorrect assumptions about TCP delivery. This patch introduces proper accumulation of partial reads. Also adds handling for fragmented WebSocket messages split across multiple frames. --- .../WebSocket/WebSocketURLProtocol.swift | 25 +++- .../URLSession/libcurl/EasyHandle.swift | 4 +- Tests/Foundation/HTTPServer.swift | 99 ++++++++++++++-- Tests/Foundation/TestURLSession.swift | 107 +++++++++--------- 4 files changed, 169 insertions(+), 66 deletions(-) diff --git a/Sources/FoundationNetworking/URLSession/WebSocket/WebSocketURLProtocol.swift b/Sources/FoundationNetworking/URLSession/WebSocket/WebSocketURLProtocol.swift index 8216f23d58..e35612ddae 100644 --- a/Sources/FoundationNetworking/URLSession/WebSocket/WebSocketURLProtocol.swift +++ b/Sources/FoundationNetworking/URLSession/WebSocket/WebSocketURLProtocol.swift @@ -17,6 +17,9 @@ import Foundation import Dispatch internal class _WebSocketURLProtocol: _HTTPURLProtocol { + + private var messageData = Data() + public required init(task: URLSessionTask, cachedResponse: CachedURLResponse?, client: URLProtocolClient?) { super.init(task: task, cachedResponse: nil, client: client) } @@ -118,14 +121,14 @@ internal class _WebSocketURLProtocol: _HTTPURLProtocol { lastRedirectBody = redirectBody } - let flags = easyHandle.getWebSocketFlags() + let (offset, bytesLeft, flags) = easyHandle.getWebSocketMeta() - notifyTask(aboutReceivedData: data, flags: flags) + notifyTask(aboutReceivedData: data, offset: offset, bytesLeft: bytesLeft, flags: flags) internalState = .transferInProgress(ts) return .proceed } - fileprivate func notifyTask(aboutReceivedData data: Data, flags: _EasyHandle.WebSocketFlags) { + fileprivate func notifyTask(aboutReceivedData data: Data, offset: Int64, bytesLeft: Int64, flags: _EasyHandle.WebSocketFlags) { guard let t = self.task else { fatalError("Cannot notify") } @@ -159,10 +162,21 @@ internal class _WebSocketURLProtocol: _HTTPURLProtocol { } else if flags.contains(.pong) { task.noteReceivedPong() } else if flags.contains(.binary) { - let message = URLSessionWebSocketTask.Message.data(data) + if bytesLeft > 0 || flags.contains(.cont) { + messageData.append(data) + return + } + messageData.append(data) + let message = URLSessionWebSocketTask.Message.data(messageData) task.appendReceivedMessage(message) + messageData = Data() // Reset for the next message } else if flags.contains(.text) { - guard let utf8 = String(data: data, encoding: .utf8) else { + if bytesLeft > 0 || flags.contains(.cont) { + messageData.append(data) + return + } + messageData.append(data) + guard let utf8 = String(data: messageData, encoding: .utf8) else { NSLog("Invalid utf8 message received from server \(data)") let error = NSError(domain: NSURLErrorDomain, code: NSURLErrorBadServerResponse, userInfo: [ @@ -175,6 +189,7 @@ internal class _WebSocketURLProtocol: _HTTPURLProtocol { } let message = URLSessionWebSocketTask.Message.string(utf8) task.appendReceivedMessage(message) + messageData = Data() // Reset for the next message } else { NSLog("Unexpected message received from server \(data) \(flags)") let error = NSError(domain: NSURLErrorDomain, code: NSURLErrorBadServerResponse, diff --git a/Sources/FoundationNetworking/URLSession/libcurl/EasyHandle.swift b/Sources/FoundationNetworking/URLSession/libcurl/EasyHandle.swift index e9b90a23f1..a891c9818e 100644 --- a/Sources/FoundationNetworking/URLSession/libcurl/EasyHandle.swift +++ b/Sources/FoundationNetworking/URLSession/libcurl/EasyHandle.swift @@ -375,10 +375,10 @@ extension _EasyHandle { } // Only valid to call within a didReceive(data:size:nmemb:) call - func getWebSocketFlags() -> WebSocketFlags { + func getWebSocketMeta() -> (Int64, Int64, WebSocketFlags) { let metadataPointer = CFURLSessionEasyHandleWebSocketsMetadata(rawHandle) let flags = WebSocketFlags(rawValue: metadataPointer.pointee.flags) - return flags + return (metadataPointer.pointee.offset, metadataPointer.pointee.bytesLeft, flags) } func receiveWebSocketsData() throws -> (Data, WebSocketFlags) { diff --git a/Tests/Foundation/HTTPServer.swift b/Tests/Foundation/HTTPServer.swift index a6457f285f..9922d32897 100644 --- a/Tests/Foundation/HTTPServer.swift +++ b/Tests/Foundation/HTTPServer.swift @@ -914,6 +914,8 @@ public class TestURLSessionServer: CustomStringConvertible { "Connection: Upgrade"] let expectFullRequestResponseTests: Bool + let bufferedSendingTests: Bool + let fragmentedTests: Bool let sendClosePacket: Bool let completeUpgrade: Bool @@ -921,14 +923,32 @@ public class TestURLSessionServer: CustomStringConvertible { switch uri { case "/web-socket": expectFullRequestResponseTests = true + bufferedSendingTests = false + fragmentedTests = false + completeUpgrade = true + sendClosePacket = true + case "/web-socket/buffered-sending": + expectFullRequestResponseTests = true + bufferedSendingTests = true + fragmentedTests = false + completeUpgrade = true + sendClosePacket = true + case "/web-socket/fragmented": + expectFullRequestResponseTests = true + bufferedSendingTests = false + fragmentedTests = true completeUpgrade = true sendClosePacket = true case "/web-socket/semi-abrupt-close": expectFullRequestResponseTests = false + bufferedSendingTests = false + fragmentedTests = false completeUpgrade = true sendClosePacket = false case "/web-socket/abrupt-close": expectFullRequestResponseTests = false + bufferedSendingTests = false + fragmentedTests = false completeUpgrade = false sendClosePacket = false default: @@ -944,6 +964,8 @@ public class TestURLSessionServer: CustomStringConvertible { } responseHeaders.append("Sec-WebSocket-Protocol: \(expectedProtocol)") expectFullRequestResponseTests = false + bufferedSendingTests = false + fragmentedTests = false completeUpgrade = true sendClosePacket = true } @@ -978,10 +1000,41 @@ public class TestURLSessionServer: CustomStringConvertible { NSLog("Invalid string frame") throw InternalServerError.badBody } - - // Send a string message - let sendStringFrame = Data([0x81, UInt8(stringPayload.count)]) + stringPayload - try httpServer.tcpSocket.writeRawData(sendStringFrame) + + if bufferedSendingTests { + // Send a string message in chunks of 2 bytes + let sendStringFrame = Data([0x81, UInt8(stringPayload.count)]) + stringPayload + let bufferSize = 2 // Let's assume the server has a buffer size of 2 bytes + for i in stride(from: 0, to: sendStringFrame.count, by: bufferSize) { + let end = min(i + bufferSize, sendStringFrame.count) + let chunk = sendStringFrame.subdata(in: i.. Void { + let url = try XCTUnwrap(URL(string: urlString)) + let request = URLRequest(url: url) + + let delegate = SessionDelegate(with: expectation(description: "\(urlString): Connect")) + let task = delegate.runWebSocketTask(with: request, timeoutInterval: 4) + + // We interleave sending and receiving, as the test HTTPServer implementation is barebones, and can't handle receiving more than one frame at a time. So, this back-and-forth acts as a gating mechanism + try await task.send(.string("Hello")) + + let stringMessage = try await task.receive() + switch stringMessage { + case .string(let str): + XCTAssert(str == "Hello") + default: + XCTFail("Unexpected String Message") + } + + try await task.send(.data(Data([0x20, 0x22, 0x10, 0x03]))) + + let dataMessage = try await task.receive() + switch dataMessage { + case .data(let data): + XCTAssert(data == Data([0x20, 0x22, 0x10, 0x03])) + default: + XCTFail("Unexpected Data Message") + } + + do { + try await task.sendPing() + // Server hasn't closed the connection yet + } catch { + // Server closed the connection before we could process the pong + let urlError = try XCTUnwrap(error as? URLError) + XCTAssertEqual(urlError._nsError.code, NSURLErrorNetworkConnectionLost) + } + + await fulfillment(of: [delegate.expectation], timeout: 50) + + do { + _ = try await task.receive() + XCTFail("Expected to throw when receiving on closed task") + } catch { + let urlError = try XCTUnwrap(error as? URLError) + XCTAssertEqual(urlError._nsError.code, NSURLErrorNetworkConnectionLost) + } + + let callbacks = [ "urlSession(_:webSocketTask:didOpenWithProtocol:)", + "urlSession(_:webSocketTask:didCloseWith:reason:)", + "urlSession(_:task:didCompleteWithError:)" ] + XCTAssertEqual(delegate.callbacks.count, callbacks.count) + XCTAssertEqual(delegate.callbacks, callbacks, "Callbacks for \(#function)") } - - let callbacks = [ "urlSession(_:webSocketTask:didOpenWithProtocol:)", - "urlSession(_:webSocketTask:didCloseWith:reason:)", - "urlSession(_:task:didCompleteWithError:)" ] - XCTAssertEqual(delegate.callbacks.count, callbacks.count) - XCTAssertEqual(delegate.callbacks, callbacks, "Callbacks for \(#function)") + + try await testWebSocket(withURL: "ws://127.0.0.1:\(TestURLSession.serverPort)/web-socket") + try await testWebSocket(withURL: "ws://127.0.0.1:\(TestURLSession.serverPort)/web-socket/buffered-sending") + try await testWebSocket(withURL: "ws://127.0.0.1:\(TestURLSession.serverPort)/web-socket/fragmented") } func test_webSocketShared() async throws {