diff --git a/src/http/client.zig b/src/http/client.zig
index 462a26e7..197da8d0 100644
--- a/src/http/client.zig
+++ b/src/http/client.zig
@@ -321,12 +321,16 @@ const Connection = struct {
     socket: posix.socket_t,
 
     const TLSClient = union(enum) {
-        blocking: tls.Connection(std.net.Stream),
+        blocking: tls.Connection(std.net.Stream), // Note: can also be the connection to a TLS proxy when the destination is not secure
        blocking_tlsproxy: struct {
            proxy: tls.Connection(std.net.Stream), // Note, self-referential field. Proxy should be pinned in memory.
            destination: tls.Connection(*tls.Connection(std.net.Stream)),
        },
        nonblocking: tls.nonblock.Connection,
+        nonblocking_tlsproxy: struct {
+            proxy: tls.nonblock.Connection,
+            destination: tls.nonblock.Connection,
+        },
 
        fn close(self: *TLSClient) void {
            switch (self.*) {
@@ -335,7 +339,7 @@ const Connection = struct {
                    tls_in_tls.destination.close() catch {};
                    tls_in_tls.proxy.close() catch {};
                },
-                .nonblocking => {},
+                .nonblocking, .nonblocking_tlsproxy => {},
            }
        }
    };
@@ -787,32 +791,16 @@ pub const Request = struct {
            .handler = handler,
            .read_buf = state.read_buf,
            .write_buf = state.write_buf,
+            .write_connect_buf = state.write_connect_buf,
            .reader = self.newReader(),
            .socket = connection.socket,
            .conn = .{ .handler = async_handler, .protocol = .{ .plain = {} } },
        };
 
-        if (self._client.isConnectProxy() and self._proxy_secure) log.warn(.http, "ASYNC TLS CONNECT no impl.", .{});
-        if (self._request_secure) {
-            if (self._connection_from_keepalive) {
-                // If the connection came from the keepalive pool, than we already
-                // have a TLS Connection.
-                async_handler.conn.protocol = .{ .encrypted = .{ .conn = &connection.tls.?.nonblocking } };
-            } else {
-                std.debug.assert(connection.tls == null);
-                async_handler.conn.protocol = .{
-                    .handshake = tls.nonblock.Client.init(.{
-                        .host = if (self._client.isConnectProxy()) self._request_host else self._connect_host, // looks wrong
-                        .root_ca = self._client.root_ca,
-                        .insecure_skip_verify = self._tls_verify_host == false,
-                        .key_log_callback = tls.config.key_log.callback,
-                    }),
-                };
-            }
-        }
-
-        if (self._connection_from_keepalive) {
-            // we're already connected
+        if (self._connection_from_keepalive and self._request_secure) {
+            // If the connection came from the keepalive pool, then we already have a TLS Connection...
+            async_handler.conn.protocol = .{ .encrypted = .{ .conn = &connection.tls.?.nonblocking } };
+            // ...and we're already connected.
            async_handler.pending_connect = false;
            return async_handler.conn.connected();
        }
@@ -1065,6 +1053,7 @@ fn AsyncHandler(comptime H: type) type {
        // need a separate read and write buf because, with TLS, messages are
        // not strictly req->resp.
        write_buf: []u8,
+        write_connect_buf: []u8,
 
        socket: posix.socket_t,
        read_completion: IO.Completion = undefined,
@@ -1081,7 +1070,7 @@ fn AsyncHandler(comptime H: type) type {
        send_queue: SendQueue = .{},
 
        // Used to help us know if we're writing the header or the body;
-        state: SendState = .handshake,
+        state: SendState = .init,
 
        // Abstraction over TLS and plain text socket, this is a version of
        // the request._connection (which is a *Connection) that is async-specific.
@@ -1110,7 +1099,9 @@ fn AsyncHandler(comptime H: type) type {
        const SendQueue = std.DoublyLinkedList([]const u8);
 
        const SendState = enum {
-            connect,
+            init,
+            connect_handshake,
+            connect_header,
            handshake,
            header,
            body,
@@ -1134,6 +1125,16 @@ fn AsyncHandler(comptime H: type) type {
            self.maybeShutdown();
        }
 
+        /// Shifts the unused part of the buffer to the beginning.
+        /// `unused` must be a slice of the tail end of `buf`.
+        /// Returns the write position for the next write into the buffer.
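+        /// Example (illustrative values): if `buf` holds 10 bytes and
+        /// `unused` is `buf[6..10]`, the 4 leftover bytes are copied to
+        /// `buf[0..4]` and 4 is returned, so the caller resumes writing
+        /// at `buf[4..]`.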
+        fn shiftUnused(buf: []u8, unused: []const u8) usize {
+            if (unused.len == 0) return 0;
+            if (unused.ptr == buf.ptr) return unused.len; // already at the start
+            std.mem.copyForwards(u8, buf, unused);
+            return unused.len;
+        }
+
        fn connected(self: *Self, _: *IO.Completion, result: IO.ConnectError!void) void {
            self.pending_connect = false;
            if (self.shutdown) {
@@ -1142,16 +1143,81 @@ fn AsyncHandler(comptime H: type) type {
 
            result catch |err| return self.handleError("Connection failed", err);
 
-            if (self.request.shouldProxyConnect()) {
-                self.state = .connect;
-                const header = self.request.buildConnectHeader() catch |err| {
+            const request = self.request;
+            if (request._request_secure) {
+                std.debug.assert(request._connection.?.tls == null);
+                self.conn.protocol = .{
+                    .handshake = tls.nonblock.Client.init(.{
+                        .host = request._request_host,
+                        .root_ca = request._client.root_ca,
+                        .insecure_skip_verify = request._tls_verify_host == false,
+                        .key_log_callback = tls.config.key_log.callback,
+                    }),
+                };
+            }
+
+            if (request.shouldProxyConnect()) {
+                if (request._proxy_secure) {
+                    // With a secure CONNECT proxy, we need to complete a TLS
+                    // handshake with the proxy before sending the CONNECT request.
+                    self.state = .connect_handshake;
+                    const tls_config = tls.config.Client{
+                        .host = request._connect_host,
+                        .root_ca = request._client.root_ca,
+                        .insecure_skip_verify = request._tls_verify_host == false,
+                        .key_log_callback = tls.config.key_log.callback,
+                    };
+                    self.conn.connect_protocol = .{ .handshake = tls.nonblock.Client.init(tls_config) };
+                    const handshake = &self.conn.connect_protocol.?.handshake;
+
+                    var recv_buf: [tls.max_ciphertext_record_len]u8 = undefined;
+                    var send_buf: [tls.max_ciphertext_record_len]u8 = undefined;
+                    var recv_pos: usize = 0;
+                    // Run the handshake until done. Note that this loop is
+                    // effectively blocking: it busy-waits on WouldBlock even
+                    // though the rest of this handler is nonblocking.
+                    while (true) {
+                        const res = handshake.run(recv_buf[0..recv_pos], &send_buf) catch |err| {
+                            return self.handleError("TLS handshake error", err);
+                        };
+                        if (res.send.len > 0) {
+                            var i: usize = 0;
+                            while (i < res.send.len) {
+                                i += posix.write(self.socket, res.send[i..]) catch |err| {
+                                    return self.handleError("TLS handshake write error", err);
+                                };
+                            }
+                        }
+                        recv_pos = shiftUnused(&recv_buf, res.unused_recv);
+                        if (handshake.done()) break;
+                        while (true) {
+                            recv_pos += posix.read(self.socket, recv_buf[recv_pos..]) catch |err| {
+                                if (err == error.WouldBlock) continue; // retry on WouldBlock
+                                return self.handleError("TLS handshake read error", err);
+                            };
+                            break;
+                        }
+                    }
+
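+                    // The proxy handshake is now complete. The rest of the
+                    // CONNECT flow (a sketch of the intended happy path): wrap
+                    // the negotiated cipher in a tls.nonblock.Connection, send
+                    // the CONNECT header (queueSend encrypts it via
+                    // connect_protocol), then parse the proxy's response in
+                    // the .connect_header read path below.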
+                    const encrypted = tls.nonblock.Connection.init(handshake.cipher().?); // TODO: steal the cipher from the blocking connection if it exists?
+                    std.debug.assert(request._connection.?.tls == null);
+                    request._connection.?.tls = .{ .nonblocking = encrypted }; // TODO: consider storing it in nonblocking_tlsproxy directly if request_secure
+                    self.conn.connect_protocol = .{
+                        .encrypted = .{
+                            .conn = &request._connection.?.tls.?.nonblocking,
+                        },
+                    };
+                }
+
+                // Secure proxy or plain, we now send the CONNECT request. For
+                // a secure proxy, queueSend encrypts it via connect_protocol.
+                self.state = .connect_header;
+                const header = request.buildConnectHeader() catch |err| {
                    return self.handleError("Failed to build CONNECT header", err);
                };
                self.send(header);
                self.receive();
                return;
            }
 
+            // There is no CONNECT proxy, so we proceed directly with the TLS
+            // handshake (or with the header, if the request is plain).
+            self.state = .handshake;
            self.conn.connected() catch |err| {
                self.handleError("connected handler error", err);
            };
@@ -1164,6 +1230,18 @@ fn AsyncHandler(comptime H: type) type {
            };
            node.data = data;
+
+            if (self.conn.connect_protocol) |*connect_protocol| {
+                // Encrypt the data with the proxy cipher
+                const res = connect_protocol.encrypted.conn.encrypt(data, self.conn.handler.write_connect_buf) catch |err| {
+                    self.handleError("TLS proxy encrypt error", err);
+                    return;
+                };
+                connect_protocol.encrypted.unsent = res.unused_cleartext;
+                node.data = res.ciphertext;
+                if (res.unused_cleartext.len > 0) log.warn(.http_client, "TLS encrypt unused data", .{});
+            }
+
            self.send_queue.append(node);
            if (self.send_queue.len > 1) {
                // if we already had a message in the queue, then our send loop
@@ -1177,7 +1255,7 @@ fn AsyncHandler(comptime H: type) type {
                self,
                &self.send_completion,
                sent,
-                self.socket,
+                self.socket, // For TLS CONNECT, should this be a TLS client? No: by this point the data has already been encrypted if that was needed.
                node.data,
            ) catch |err| {
                self.handleError("loop send error", err);
@@ -1221,7 +1299,7 @@ fn AsyncHandler(comptime H: type) type {
                return;
            }
 
-            if (self.state == .connect) {
+            if (self.state == .connect_handshake or self.state == .connect_header) {
                // We're in a proxy CONNECT flow. There's nothing for us to
                // do except for wait for the response.
                return;
@@ -1272,7 +1350,60 @@ fn AsyncHandler(comptime H: type) type {
            const data = self.read_buf[0 .. self.read_pos + n];
 
-            if (self.state == .connect) {
+            if (self.state == .connect_handshake) {
+                // TODO send/recv: the proxy handshake currently completes
+                // synchronously in connected(), so we should not yet receive
+                // reads while in this state.
+            }
+
+            if (self.state == .connect_header) {
+                blk: {
+                    if (self.conn.connect_protocol) |*connect_protocol| {
+                        const res = connect_protocol.encrypted.conn.decrypt(data, data) catch |err| {
+                            return self.handleError("TLS proxy decrypt error", err);
+                        };
+
+                        if (res.ciphertext_pos == 0) {
+                            // No part of the encrypted data was consumed, so
+                            // no cleartext should have been generated.
+                            std.debug.assert(res.cleartext.len == 0);
+
+                            // our next read needs to append more data to
+                            // the existing data
+                            self.read_pos = data.len;
+                            return if (res.closed) break :blk else self.receive();
+                        }
+
+                        if (res.cleartext.len > 0) {
+                            // status = self.processData(res.cleartext);
+                            break :blk; // we assume the full CONNECT response header arrives in one go
+                        }
+
+                        if (res.closed) break :blk;
+
+                        const unused = res.unused_ciphertext;
+                        if (unused.len == 0) {
+                            // all of data was used up, our next read can use
+                            // the whole read buffer.
+                            self.read_pos = 0;
+                            return self.receive();
+                        }
+
+                        // We used some of the data, but have some leftover
+                        // (i.e. there were 1+ full records AND an incomplete
+                        // record). We need to maintain the "leftover" data
+                        // for subsequent reads.
+
+                        // Remember that our read_buf is the MAX possible TLS
+                        // record size. So as long as we make sure that the
+                        // start of a record is at read_buf[0], we know that
+                        // we'll always have enough space for 1 record.
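+                        // Worked example (hypothetical sizes): if data holds
+                        // one full 4000-byte record plus 1000 bytes of the
+                        // next one, decrypt() consumes the full record and
+                        // unused_ciphertext is the trailing 1000 bytes; we
+                        // move those to read_buf[0..1000] and keep reading
+                        // at read_buf[1000..].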
+                        std.mem.copyForwards(u8, self.read_buf, unused);
+                        self.read_pos = unused.len;
+
+                        // an incomplete record means there must be more data
+                        return self.receive();
+                    }
+                }
                const success = self.reader.connectResponse(data) catch |err| {
                    return self.handleError("Invalid CONNECT response", err);
                };
 
@@ -1524,6 +1655,7 @@ fn AsyncHandler(comptime H: type) type {
        const Conn = struct {
            handler: *Self,
            protocol: Protocol,
+            connect_protocol: ?Protocol = null,
 
            const Encrypted = struct {
                conn: *tls.nonblock.Connection,
@@ -1580,6 +1712,50 @@ fn AsyncHandler(comptime H: type) type {
 
            fn received(self: *Conn, data: []u8) !ProcessStatus {
                const handler = self.handler;
+
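+                // TLS-in-TLS layering (a sketch of the intended design):
+                // bytes from the socket are first decrypted with the proxy
+                // cipher (connect_protocol); the resulting cleartext is
+                // itself a TLS stream for the destination, handled by
+                // self.protocol in the switch below.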
+                if (self.connect_protocol) |*connect_protocol| {
+                    blk: {
+                        const res = try connect_protocol.encrypted.conn.decrypt(data, data);
+
+                        // TODO: not sure about this code; it mirrors the
+                        // .connect_header decrypt path in processData above.
+
+                        if (res.ciphertext_pos == 0) {
+                            // No part of the encrypted data was consumed, so
+                            // no cleartext should have been generated.
+                            std.debug.assert(res.cleartext.len == 0);
+
+                            // our next read needs to append more data to the existing data
+                            handler.read_pos = data.len;
+                            if (res.closed) break :blk;
+                            return .need_more;
+                        }
+
+                        if (res.cleartext.len > 0) break :blk;
+                        if (res.closed) break :blk;
+
+                        const unused = res.unused_ciphertext;
+                        if (unused.len == 0) {
+                            // all of data was used up, our next read can use
+                            // the whole read buffer.
+                            handler.read_pos = 0;
+                            return .need_more;
+                        }
+
+                        // We used some of the data, but have some leftover
+                        // (i.e. there were 1+ full records AND an incomplete
+                        // record). We need to maintain the "leftover" data
+                        // for subsequent reads.
+
+                        // Remember that our read_buf is the MAX possible TLS
+                        // record size. So as long as we make sure that the start
+                        // of a record is at read_buf[0], we know that we'll
+                        // always have enough space for 1 record.
+                        std.mem.copyForwards(u8, handler.read_buf, unused);
+                        handler.read_pos = unused.len;
+
+                        // an incomplete record means there must be more data
+                        return .need_more;
+                    }
+                }
+
                switch (self.protocol) {
                    .plain => {
                        std.debug.assert(handler.state == .body);
@@ -1672,7 +1848,7 @@ fn AsyncHandler(comptime H: type) type {
                const handler = self.handler;
                switch (self.protocol) {
                    .plain => switch (handler.state) {
-                        .handshake, .connect => unreachable,
+                        .init, .handshake, .connect_header, .connect_handshake => unreachable,
                        .header => return self.sendBody(),
                        .body => {},
                    },
@@ -1681,7 +1857,7 @@ fn AsyncHandler(comptime H: type) type {
                            return self.send(encrypted.unsent);
                        }
                        switch (handler.state) {
-                            .handshake, .connect => unreachable,
+                            .init, .handshake, .connect_header, .connect_handshake => unreachable,
                            .header => return self.sendBody(),
                            .body => {},
                        }
@@ -1748,6 +1924,9 @@ fn AsyncHandler(comptime H: type) type {
                    .encrypted => |*encrypted| {
                        const res = try encrypted.conn.encrypt(data, handler.write_buf);
                        encrypted.unsent = res.unused_cleartext;
+
+                        // TODO: also encrypt with the CONNECT proxy TLS layer, if any
+
                        return handler.send(res.ciphertext);
                    },
                    .handshake => {
@@ -1774,7 +1953,7 @@ const SyncHandler = struct {
            const c = request._connection.?;
            if (c.tls) |*tls_client| {
                switch (tls_client.*) {
-                    .nonblocking => unreachable,
+                    .nonblocking, .nonblocking_tlsproxy => unreachable,
                    .blocking => |*blocking| {
                        break :blk .{ .tls = blocking };
                    },
@@ -2790,6 +2969,7 @@ const State = struct {
    // write_buf, even though HTTP is req -> resp, it's for TLS, which has
    // bidirectional data.
    write_buf: []u8,
+    write_connect_buf: []u8,
 
    // Used for keeping any unparsed header line until more data is received
    // At most, this represents 1 line in the header.
@@ -2808,6 +2988,8 @@ const State = struct {
        const write_buf = try allocator.alloc(u8, buf_size);
        errdefer allocator.free(write_buf);
+        // Sized with headroom: TLS-encrypted output is larger than its cleartext input.
+        const write_connect_buf = try allocator.alloc(u8, buf_size * 5);
+        errdefer allocator.free(write_connect_buf);
 
        const header_buf = try allocator.alloc(u8, header_size);
        errdefer allocator.free(header_buf);
@@ -2816,6 +2998,7 @@ const State = struct {
            .peek_buf = peek_buf,
            .read_buf = read_buf,
            .write_buf = write_buf,
+            .write_connect_buf = write_connect_buf,
            .header_buf = header_buf,
            .arena = std.heap.ArenaAllocator.init(allocator),
        };
@@ -2830,6 +3013,7 @@ const State = struct {
        allocator.free(self.peek_buf);
        allocator.free(self.read_buf);
        allocator.free(self.write_buf);
+        allocator.free(self.write_connect_buf);
        allocator.free(self.header_buf);
        self.arena.deinit();
    }