diff --git a/x/configurl/doc.go b/x/configurl/doc.go index feffc7a3..165f5f2e 100644 --- a/x/configurl/doc.go +++ b/x/configurl/doc.go @@ -57,6 +57,21 @@ SOCKS5 proxy (works with both stream and packet dialers, package [golang.getoutl USERINFO field is optional and only required if username and password authentication is used. It is in the format of username:password. +HTTP CONNECT proxy (streams only, package [golang.getoutline.org/sdk/x/httpconnect]) + +Three variants are available: + + - httpconnect: HTTP/1.1, or HTTP/2 if negotiated via TLS ALPN. When H2 is negotiated, CONNECT streams are multiplexed over a single TCP connection. + - h2connect: Pure HTTP/2. Always multiplexed. Supports h2c (cleartext H2) via plain=true. + - h3connect: HTTP/3 over QUIC. Always multiplexed. Creates its own UDP socket. + +The sni parameter sets the TLS SNI. The certname parameter sets the name to validate against the server certificate. +For h2connect, plain=true enables h2c (cleartext HTTP/2 without TLS). + + httpconnect://[HOST]:[PORT][?sni=SNI][&certname=CERTNAME] + h2connect://[HOST]:[PORT][?sni=SNI][&certname=CERTNAME][&plain=true] + h3connect://[HOST]:[PORT][?sni=SNI][&certname=CERTNAME] + # Transports TLS transport (currently streams only, package [golang.getoutline.org/sdk/transport/tls]) diff --git a/x/configurl/httpconnect.go b/x/configurl/httpconnect.go new file mode 100644 index 00000000..b654d34d --- /dev/null +++ b/x/configurl/httpconnect.go @@ -0,0 +1,144 @@ +// Copyright 2025 The Outline Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package configurl + +import ( + "context" + "fmt" + "net" + "net/url" + "strings" + + "golang.getoutline.org/sdk/transport" + "golang.getoutline.org/sdk/transport/tls" + "golang.getoutline.org/sdk/x/httpconnect" +) + +// parseConnectOptions parses query parameters from a hierarchical config URL +// (e.g. h2connect://host:port?sni=example.com&plain=true) into TransportOptions. +// +// Supported parameters: +// - sni: TLS server name for SNI. +// - certname: name to validate against the server certificate. +// - plain: if "true", use cleartext (no TLS). Only meaningful for h2connect (h2c). +func parseConnectOptions(configURL url.URL) ([]httpconnect.TransportOption, error) { + values, err := url.ParseQuery(configURL.RawQuery) + if err != nil { + return nil, err + } + var opts []httpconnect.TransportOption + var tlsOpts []tls.ClientOption + for key, vals := range values { + switch strings.ToLower(key) { + case "sni": + if len(vals) != 1 { + return nil, fmt.Errorf("sni option must have one value, found %v", len(vals)) + } + tlsOpts = append(tlsOpts, tls.WithSNI(vals[0])) + case "certname": + if len(vals) != 1 { + return nil, fmt.Errorf("certname option must have one value, found %v", len(vals)) + } + tlsOpts = append(tlsOpts, tls.WithCertVerifier(&tls.StandardCertVerifier{CertificateName: vals[0]})) + case "plain": + if len(vals) != 1 { + return nil, fmt.Errorf("plain option must have one value, found %v", len(vals)) + } + if vals[0] == "true" { + opts = append(opts, httpconnect.WithPlainHTTP()) + } + default: + return nil, fmt.Errorf("unsupported option %v", key) + } + } + if len(tlsOpts) > 0 { + opts = append(opts, httpconnect.WithTLSOptions(tlsOpts...)) + } + return opts, nil +} + +// registerHTTPConnectStreamDialer registers an HTTP CONNECT proxy transport (H1.1, or H2 via ALPN). +// +// Config format: httpconnect://host:port[?sni=SNI][&certname=CERTNAME] +// +// The base dialer (from the previous element in the pipe chain) is used to establish +// the TCP connection to the proxy. TLS is negotiated by the transport itself. +// When H2 is negotiated via ALPN, CONNECT streams are multiplexed over the single TCP connection. +func registerHTTPConnectStreamDialer(r TypeRegistry[transport.StreamDialer], typeID string, newSD BuildFunc[transport.StreamDialer]) { + r.RegisterType(typeID, func(ctx context.Context, config *Config) (transport.StreamDialer, error) { + sd, err := newSD(ctx, config.BaseConfig) + if err != nil { + return nil, err + } + opts, err := parseConnectOptions(config.URL) + if err != nil { + return nil, err + } + tr, err := httpconnect.NewHTTPProxyTransport(sd, config.URL.Host, opts...) + if err != nil { + return nil, err + } + return httpconnect.NewConnectClient(tr) + }) +} + +// registerH2ConnectStreamDialer registers a pure HTTP/2 CONNECT proxy transport. +// +// Config format: h2connect://host:port[?sni=SNI][&certname=CERTNAME] +// +// Unlike httpconnect, all CONNECT streams are multiplexed over a single TCP connection +// to the proxy. The base dialer is used to establish that connection. +func registerH2ConnectStreamDialer(r TypeRegistry[transport.StreamDialer], typeID string, newSD BuildFunc[transport.StreamDialer]) { + r.RegisterType(typeID, func(ctx context.Context, config *Config) (transport.StreamDialer, error) { + sd, err := newSD(ctx, config.BaseConfig) + if err != nil { + return nil, err + } + opts, err := parseConnectOptions(config.URL) + if err != nil { + return nil, err + } + tr, err := httpconnect.NewH2ProxyTransport(sd, config.URL.Host, opts...) + if err != nil { + return nil, err + } + return httpconnect.NewConnectClient(tr) + }) +} + +// registerH3ConnectStreamDialer registers an HTTP/3 CONNECT proxy transport over QUIC. +// +// Config format: h3connect://host:port[?sni=SNI][&certname=CERTNAME] +// +// A UDP socket is created internally and shared across all CONNECT streams (QUIC multiplexing). +// The base stream dialer is not used; QUIC always runs over a fresh UDP connection. +func registerH3ConnectStreamDialer(r TypeRegistry[transport.StreamDialer], typeID string) { + r.RegisterType(typeID, func(ctx context.Context, config *Config) (transport.StreamDialer, error) { + opts, err := parseConnectOptions(config.URL) + if err != nil { + return nil, err + } + udpConn, err := net.ListenPacket("udp", ":0") + if err != nil { + return nil, fmt.Errorf("failed to create UDP socket: %w", err) + } + tr, err := httpconnect.NewH3ProxyTransport(udpConn, config.URL.Host, opts...) + if err != nil { + udpConn.Close() + return nil, err + } + return httpconnect.NewConnectClient(tr) + }) +} diff --git a/x/configurl/httpconnect_test.go b/x/configurl/httpconnect_test.go new file mode 100644 index 00000000..d3958d2f --- /dev/null +++ b/x/configurl/httpconnect_test.go @@ -0,0 +1,92 @@ +// Copyright 2025 The Outline Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package configurl_test + +import ( + "context" + "encoding/json" + "fmt" + "net" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + "golang.getoutline.org/sdk/transport" + "golang.getoutline.org/sdk/x/configurl" + "golang.getoutline.org/sdk/x/httpproxy" + "golang.org/x/net/http2" +) + +// Test_H2Connect_H2C tests the h2connect configurl type using h2c (cleartext HTTP/2). +// It starts a local h2c proxy, builds a stream dialer via "h2connect://host:port?plain=true", +// and verifies that an HTTP request is tunneled through to a target server. +func Test_H2Connect_H2C(t *testing.T) { + t.Parallel() + + tcpDialer := &transport.TCPDialer{} + + // Start an h2c proxy server (plain HTTP/2 without TLS). + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + t.Cleanup(func() { ln.Close() }) + + h2srv := &http2.Server{} + handler := httpproxy.NewConnectHandler(tcpDialer) + go func() { + for { + conn, err := ln.Accept() + if err != nil { + return + } + go h2srv.ServeConn(conn, &http2.ServeConnOpts{Handler: handler}) + } + }() + + // Build a dialer using the configurl h2connect type. + providers := configurl.NewDefaultProviders() + dialer, err := providers.NewStreamDialer(context.Background(), + fmt.Sprintf("h2connect://%s?plain=true", ln.Addr().String()), + ) + require.NoError(t, err) + + // Start a target server that returns a JSON response. + type Response struct { + Message string `json:"message"` + } + want := Response{Message: "hello"} + targetSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(want) + })) + t.Cleanup(targetSrv.Close) + + // Make an HTTP request through the tunnel. + hc := &http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, _, addr string) (net.Conn, error) { + return dialer.DialStream(ctx, addr) + }, + }, + } + resp, err := hc.Get(targetSrv.URL) + require.NoError(t, err) + defer resp.Body.Close() + + require.Equal(t, http.StatusOK, resp.StatusCode) + var got Response + require.NoError(t, json.NewDecoder(resp.Body).Decode(&got)) + require.Equal(t, want, got) +} diff --git a/x/configurl/module.go b/x/configurl/module.go index 5c730e12..1a4847a4 100644 --- a/x/configurl/module.go +++ b/x/configurl/module.go @@ -46,6 +46,10 @@ func RegisterDefaultProviders(c *ProviderContainer) *ProviderContainer { registerDO53StreamDialer(&c.StreamDialers, "do53", c.StreamDialers.NewInstance, c.PacketDialers.NewInstance) registerDOHStreamDialer(&c.StreamDialers, "doh", c.StreamDialers.NewInstance) + registerH2ConnectStreamDialer(&c.StreamDialers, "h2connect", c.StreamDialers.NewInstance) + registerH3ConnectStreamDialer(&c.StreamDialers, "h3connect") + registerHTTPConnectStreamDialer(&c.StreamDialers, "httpconnect", c.StreamDialers.NewInstance) + registerOverrideStreamDialer(&c.StreamDialers, "override", c.StreamDialers.NewInstance) registerOverridePacketDialer(&c.PacketDialers, "override", c.PacketDialers.NewInstance) @@ -128,7 +132,7 @@ func SanitizeConfig(configStr string) (string, error) { if err != nil { return "", err } - case "override", "split", "tls", "tlsfrag": + case "h2connect", "h3connect", "httpconnect", "override", "split", "tls", "tlsfrag": // No sanitization needed part = config.URL.String() default: diff --git a/x/httpconnect/connect_client.go b/x/httpconnect/connect_client.go index 9e66afe6..6d84a3b7 100644 --- a/x/httpconnect/connect_client.go +++ b/x/httpconnect/connect_client.go @@ -28,7 +28,8 @@ import ( // // The package also includes transport builders: // - NewHTTPProxyTransport -// - NewHTTP3ProxyTransport +// - NewH2ProxyTransport +// - NewH3ProxyTransport // // Options: // - WithHeaders appends the provided headers to every CONNECT request. @@ -39,13 +40,18 @@ type ConnectClient struct { var _ transport.StreamDialer = (*ConnectClient)(nil) +// ProxyRoundTripper is the minimal interface required by ConnectClient to send HTTP CONNECT requests. +// The Scheme method is used to construct the request URL, and the RoundTrip method is used to send the request. type ProxyRoundTripper interface { http.RoundTripper Scheme() string } +// ClientOption is an option for configuring the ConnectClient. type ClientOption func(c *clientConfig) +// NewConnectClient creates a new ConnectClient that uses the provided ProxyRoundTripper to send HTTP CONNECT requests. +// The returned client implements the [transport.StreamDialer] interface. func NewConnectClient(proxyRT ProxyRoundTripper, opts ...ClientOption) (*ConnectClient, error) { if proxyRT == nil { return nil, fmt.Errorf("transport must not be nil") @@ -75,6 +81,7 @@ type clientConfig struct { headers http.Header } +// DialStream implements the [transport.StreamDialer] interface by sending an HTTP CONNECT request to the proxy and returning a connection that tunnels to the target address. func (cc *ConnectClient) DialStream(ctx context.Context, remoteAddr string) (transport.StreamConn, error) { raddr, err := transport.MakeNetAddr("tcp", remoteAddr) if err != nil { diff --git a/x/httpconnect/connect_client_test.go b/x/httpconnect/connect_client_test.go index c8b544e2..320f6fbf 100644 --- a/x/httpconnect/connect_client_test.go +++ b/x/httpconnect/connect_client_test.go @@ -16,373 +16,413 @@ package httpconnect import ( "context" - "crypto/rand" - "crypto/rsa" stdTLS "crypto/tls" "crypto/x509" - "crypto/x509/pkix" "encoding/base64" "encoding/json" - "encoding/pem" "io" - "math/big" "net" "net/http" "net/http/httptest" "net/url" "sync" "testing" - "time" - "golang.getoutline.org/sdk/transport" - "golang.getoutline.org/sdk/transport/tls" "github.com/quic-go/quic-go/http3" "github.com/stretchr/testify/require" + "golang.getoutline.org/sdk/transport" + "golang.getoutline.org/sdk/transport/tls" + "golang.org/x/net/http2" "golang.getoutline.org/sdk/x/httpproxy" ) -func newTargetSrv(t *testing.T, resp interface{}) *httptest.Server { - return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) +// Compile-time check: net.Conn satisfies io.ReadWriteCloser. +var _ io.ReadWriteCloser = (net.Conn)(nil) - jsonResp, err := json.Marshal(resp) - require.NoError(t, err) +// tlsCertPool returns the built-in httptest TLS certificate and a cert pool trusting it. +// Reusing httptest's certificate avoids generating a custom CA in each test. +func tlsCertPool(t *testing.T) (stdTLS.Certificate, *x509.CertPool) { + t.Helper() + // Create a throwaway server solely to borrow its built-in TLS cert material. + srv := httptest.NewTLSServer(nil) + srv.Close() + pool := x509.NewCertPool() + pool.AddCert(srv.Certificate()) + return srv.TLS.Certificates[0], pool +} - _, err = w.Write(jsonResp) - require.NoError(t, err) +// verifyTunnel sends an HTTP GET through the given dialer to a local target server +// and asserts the response is received correctly end-to-end. +func verifyTunnel(t *testing.T, dialer transport.StreamDialer) { + t.Helper() + + type Response struct { + Message string `json:"message"` + } + want := Response{Message: "hello"} + + targetSrv := newTargetSrv(t, want) + t.Cleanup(targetSrv.Close) + + hc := &http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, _, addr string) (net.Conn, error) { + conn, err := dialer.DialStream(ctx, addr) + if err != nil { + return nil, err + } + require.Equal(t, addr, conn.RemoteAddr().String()) + return conn, nil + }, + }, + } + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, targetSrv.URL, nil) + require.NoError(t, err) + req.Close = true // close the tunnel right after the request + + resp, err := hc.Do(req) + require.NoError(t, err) + t.Cleanup(func() { resp.Body.Close() }) + + require.Equal(t, http.StatusOK, resp.StatusCode) + + var got Response + require.NoError(t, json.NewDecoder(resp.Body).Decode(&got)) + require.Equal(t, want, got) +} + + +// Test_ConnectClient_H1_Plain verifies that custom headers (e.g. Proxy-Authorization) +// are forwarded on every CONNECT request when using a plain HTTP/1.1 proxy. +func Test_ConnectClient_H1_Plain(t *testing.T) { + t.Parallel() + + tcpDialer := &transport.TCPDialer{} + creds := base64.StdEncoding.EncodeToString([]byte("username:password")) + + proxySrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "Basic "+creds, r.Header.Get("Proxy-Authorization")) + httpproxy.NewConnectHandler(tcpDialer).ServeHTTP(w, r) })) + t.Cleanup(proxySrv.Close) + + proxyURL, err := url.Parse(proxySrv.URL) + require.NoError(t, err, "Parse") + + tr, err := NewHTTPProxyTransport(tcpDialer, proxyURL.Host, WithPlainHTTP()) + require.NoError(t, err, "NewHTTPProxyTransport") + + connClient, err := NewConnectClient(tr, WithHeaders(http.Header{ + "Proxy-Authorization": []string{"Basic " + creds}, + })) + require.NoError(t, err, "NewConnectClient") + + verifyTunnel(t, connClient) } -func Test_NewConnectClient_Ok(t *testing.T) { +// Test_ConnectClient_H1_TLS verifies end-to-end tunneling over a TLS-wrapped HTTP/1.1 proxy. +func Test_ConnectClient_H1_TLS(t *testing.T) { t.Parallel() - var _ io.ReadWriteCloser = (net.Conn)(nil) + tcpDialer := &transport.TCPDialer{} + + proxySrv := httptest.NewUnstartedServer(httpproxy.NewConnectHandler(tcpDialer)) + proxySrv.StartTLS() + t.Cleanup(proxySrv.Close) + + certPool := x509.NewCertPool() + certPool.AddCert(proxySrv.Certificate()) + + proxyURL, err := url.Parse(proxySrv.URL) + require.NoError(t, err, "Parse") + + tr, err := NewHTTPProxyTransport(tcpDialer, proxyURL.Host, + WithTLSOptions(tls.WithCertVerifier(&tls.StandardCertVerifier{Roots: certPool})), + ) + require.NoError(t, err, "NewHTTPProxyTransport") + + connClient, err := NewConnectClient(tr) + require.NoError(t, err, "NewConnectClient") + + verifyTunnel(t, connClient) +} + +// Test_ConnectClient_H2_TLS verifies tunneling over HTTP/2 using NewH2ProxyTransport directly with TLS. +func Test_ConnectClient_H2_TLS(t *testing.T) { + t.Parallel() tcpDialer := &transport.TCPDialer{} - h1ConnectHandler := httpproxy.NewConnectHandler(tcpDialer) - type closeFunc func() + // Use httpproxy.NewConnectHandler directly to verify it handles H2 (not just H1). + // Previously this would fail because the handler required http.Hijacker, which H2 doesn't support. + proxySrv := httptest.NewUnstartedServer(httpproxy.NewConnectHandler(tcpDialer)) + proxySrv.EnableHTTP2 = true + proxySrv.StartTLS() + t.Cleanup(proxySrv.Close) - type TestCase struct { - name string - prepareDialer func(t *testing.T) (transport.StreamDialer, closeFunc) - wantErr string - } + certPool := x509.NewCertPool() + certPool.AddCert(proxySrv.Certificate()) - tests := []TestCase{ - { - name: "ok. Plain HTTP/1 with headers", - prepareDialer: func(t *testing.T) (transport.StreamDialer, closeFunc) { - creds := base64.StdEncoding.EncodeToString([]byte("username:password")) + proxyURL, err := url.Parse(proxySrv.URL) + require.NoError(t, err, "Parse") - proxySrv := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { - require.Equal(t, "Basic "+creds, request.Header.Get("Proxy-Authorization")) - h1ConnectHandler.ServeHTTP(writer, request) - })) + tr, err := NewH2ProxyTransport(tcpDialer, proxyURL.Host, + WithTLSOptions(tls.WithCertVerifier(&tls.StandardCertVerifier{Roots: certPool})), + ) + require.NoError(t, err, "NewH2ProxyTransport") - proxyURL, err := url.Parse(proxySrv.URL) - require.NoError(t, err, "Parse") + connClient, err := NewConnectClient(tr) + require.NoError(t, err, "NewConnectClient") - tr, err := NewHTTPProxyTransport(tcpDialer, proxyURL.Host, WithPlainHTTP()) - require.NoError(t, err, "NewHTTPProxyTransport") + verifyTunnel(t, connClient) +} - connClient, err := NewConnectClient(tr, WithHeaders(http.Header{ - "Proxy-Authorization": []string{"Basic " + creds}, - })) - require.NoError(t, err, "NewConnectClient") +// Test_ConnectClient_H2_TLS_HTTPTransport verifies tunneling over HTTP/2 when ALPN negotiation selects h2. +// Uses NewHTTPProxyTransport, which adds H2 support on top of net/http.Transport via ALPN. +func Test_ConnectClient_H2_TLS_HTTPTransport(t *testing.T) { + t.Parallel() - return connClient, proxySrv.Close - }, - }, - { - name: "ok. HTTP/1.1 with TLS", - prepareDialer: func(t *testing.T) (transport.StreamDialer, closeFunc) { - proxySrv := httptest.NewUnstartedServer(h1ConnectHandler) + tcpDialer := &transport.TCPDialer{} - rootCA, key := generateRootCA(t) - proxySrv.TLS = &stdTLS.Config{Certificates: []stdTLS.Certificate{key}} - certPool := x509.NewCertPool() - certPool.AddCert(rootCA) + proxySrv := httptest.NewUnstartedServer(httpproxy.NewConnectHandler(tcpDialer)) + proxySrv.EnableHTTP2 = true + proxySrv.StartTLS() + t.Cleanup(proxySrv.Close) - proxySrv.StartTLS() + certPool := x509.NewCertPool() + certPool.AddCert(proxySrv.Certificate()) - proxyURL, err := url.Parse(proxySrv.URL) - require.NoError(t, err, "Parse") + proxyURL, err := url.Parse(proxySrv.URL) + require.NoError(t, err, "Parse") - tr, err := NewHTTPProxyTransport( - tcpDialer, - proxyURL.Host, - WithTLSOptions(tls.WithCertVerifier(&tls.StandardCertVerifier{Roots: certPool})), - ) - require.NoError(t, err, "NewHTTPProxyTransport") + tr, err := NewHTTPProxyTransport(tcpDialer, proxyURL.Host, + WithTLSOptions( + tls.WithCertVerifier(&tls.StandardCertVerifier{Roots: certPool}), + tls.WithALPN([]string{"h2"}), + ), + ) + require.NoError(t, err, "NewHTTPProxyTransport") - connClient, err := NewConnectClient(tr) - require.NoError(t, err, "NewConnectClient") + connClient, err := NewConnectClient(tr) + require.NoError(t, err, "NewConnectClient") - return connClient, proxySrv.Close - }, - }, - { - name: "ok. HTTP/2 with TLS", - prepareDialer: func(t *testing.T) (transport.StreamDialer, closeFunc) { - proxySrv := httptest.NewUnstartedServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { - require.Equal(t, "HTTP/2.0", request.Proto, "Proto") - require.Equal(t, http.MethodConnect, request.Method, "Method") - - conn, err := net.Dial("tcp", request.URL.Host) - require.NoError(t, err, "Dial") - defer conn.Close() - - writer.WriteHeader(http.StatusOK) - writer.(http.Flusher).Flush() - - wg := &sync.WaitGroup{} - - wg.Add(1) - go func() { - defer wg.Done() - io.Copy(conn, request.Body) - }() - - wg.Add(1) - go func() { - defer wg.Done() - // we can't use io.Copy, because it doesn't flush - fw := &flusherWriter{ - Flusher: writer.(http.Flusher), - Writer: writer, - } - fw.ReadFrom(conn) - }() - - wg.Wait() - })) - - rootCA, key := generateRootCA(t) - proxySrv.TLS = &stdTLS.Config{Certificates: []stdTLS.Certificate{key}} - certPool := x509.NewCertPool() - certPool.AddCert(rootCA) - - proxySrv.EnableHTTP2 = true - proxySrv.StartTLS() - - proxyURL, err := url.Parse(proxySrv.URL) - require.NoError(t, err, "Parse") - - tr, err := NewHTTPProxyTransport( - tcpDialer, - proxyURL.Host, - WithTLSOptions( - tls.WithCertVerifier(&tls.StandardCertVerifier{Roots: certPool}), - tls.WithALPN([]string{"h2"}), - ), - ) - require.NoError(t, err, "NewHTTPProxyTransport") - - connClient, err := NewConnectClient(tr) - require.NoError(t, err, "NewConnectClient") - - return connClient, proxySrv.Close - }, - }, - { - name: "fail. enforced HTTP/2, but server doesn't support it", - prepareDialer: func(t *testing.T) (transport.StreamDialer, closeFunc) { - connectHandler := httpproxy.NewConnectHandler(tcpDialer) - proxySrv := httptest.NewUnstartedServer(connectHandler) - - rootCA, key := generateRootCA(t) - proxySrv.TLS = &stdTLS.Config{Certificates: []stdTLS.Certificate{key}} - certPool := x509.NewCertPool() - certPool.AddCert(rootCA) - - proxySrv.StartTLS() - - proxyURL, err := url.Parse(proxySrv.URL) - require.NoError(t, err, "Parse") - - tr, err := NewHTTPProxyTransport( - tcpDialer, - proxyURL.Host, - WithTLSOptions( - tls.WithCertVerifier(&tls.StandardCertVerifier{Roots: certPool}), - tls.WithALPN([]string{"h2"}), - ), - ) - require.NoError(t, err, "NewHTTPProxyTransport") - - connClient, err := NewConnectClient(tr) - require.NoError(t, err, "NewConnectClient") - - return connClient, proxySrv.Close - }, - wantErr: "tls: no application protocol", - }, - { - name: "ok. HTTP/3 over QUIC with TLS", - prepareDialer: func(t *testing.T) (transport.StreamDialer, closeFunc) { - rootCA, key := generateRootCA(t) - certPool := x509.NewCertPool() - certPool.AddCert(rootCA) - - srvConn, err := net.ListenPacket("udp", "127.0.0.1:0") - require.NoError(t, err, "ListenPacket") - - proxySrv := &http3.Server{ - Addr: "127.0.0.1:0", - Handler: http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { - require.Equal(t, "HTTP/3.0", request.Proto, "Proto") - require.Equal(t, http.MethodConnect, request.Method, "Method") - - conn, err := net.Dial("tcp", request.URL.Host) - require.NoError(t, err, "DialStream") - defer conn.Close() - - writer.WriteHeader(http.StatusOK) - writer.(http.Flusher).Flush() - - streamer, ok := writer.(http3.HTTPStreamer) - if !ok { - t.Fatal("http.ResponseWriter expected to implement http3.HTTPStreamer") - } - stream := streamer.HTTPStream() - defer stream.Close() - - wg := &sync.WaitGroup{} - - wg.Add(1) - go func() { - defer wg.Done() - io.Copy(stream, conn) - }() - - wg.Add(1) - go func() { - defer wg.Done() - io.Copy(conn, stream) - }() - - wg.Wait() - }), - TLSConfig: &stdTLS.Config{ - Certificates: []stdTLS.Certificate{key}, - }, - } - go func() { - _ = proxySrv.Serve(srvConn) - }() - - cliConn, err := net.ListenPacket("udp", "127.0.0.1:0") - require.NoError(t, err, "DialPacket") - - tr, err := NewHTTP3ProxyTransport( - cliConn.(net.PacketConn), - srvConn.LocalAddr().String(), - WithTLSOptions(tls.WithCertVerifier(&tls.StandardCertVerifier{Roots: certPool})), - ) - require.NoError(t, err, "NewHTTP3ProxyTransport") - - connClient, err := NewConnectClient(tr) - require.NoError(t, err, "NewConnectClient") - - return connClient, func() { - _ = cliConn.Close() - _ = proxySrv.Close() - _ = srvConn.Close() - } - }, - }, - } + verifyTunnel(t, connClient) +} - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() +// Test_ConnectClient_H2_TLS_AlpnFails verifies that when the client enforces H2 via ALPN +// but the server only supports H1, the TLS handshake fails with a clear protocol error. +func Test_ConnectClient_H2_TLS_AlpnFails(t *testing.T) { + t.Parallel() - type Response struct { - Message string `json:"message"` - } - wantResp := Response{Message: "hello"} - - targetSrv := newTargetSrv(t, wantResp) - defer targetSrv.Close() - - connClient, srvCloser := tt.prepareDialer(t) - defer srvCloser() - - hc := &http.Client{ - Transport: &http.Transport{ - DialContext: func(ctx context.Context, _, addr string) (net.Conn, error) { - conn, err := connClient.DialStream(ctx, addr) - if err != nil { - return nil, err - } - require.Equal(t, conn.RemoteAddr().String(), addr) - - return conn, nil - }, - }, - } + tcpDialer := &transport.TCPDialer{} + + // H1-only server: no EnableHTTP2. + proxySrv := httptest.NewUnstartedServer(httpproxy.NewConnectHandler(tcpDialer)) + proxySrv.StartTLS() + t.Cleanup(proxySrv.Close) + + certPool := x509.NewCertPool() + certPool.AddCert(proxySrv.Certificate()) + + proxyURL, err := url.Parse(proxySrv.URL) + require.NoError(t, err, "Parse") + + tr, err := NewHTTPProxyTransport(tcpDialer, proxyURL.Host, + WithTLSOptions( + tls.WithCertVerifier(&tls.StandardCertVerifier{Roots: certPool}), + tls.WithALPN([]string{"h2"}), + ), + ) + require.NoError(t, err, "NewHTTPProxyTransport") + + connClient, err := NewConnectClient(tr) + require.NoError(t, err, "NewConnectClient") + + _, err = connClient.DialStream(context.Background(), "127.0.0.1:1") + require.ErrorContains(t, err, "tls: no application protocol") +} + +// Test_ConnectClient_H2C verifies tunneling over cleartext HTTP/2 (h2c) via prior knowledge. +// Uses NewH2ProxyTransport with WithPlainHTTP(): no TLS, no HTTP upgrade — H2 from the first byte. +func Test_ConnectClient_H2C(t *testing.T) { + t.Parallel() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - req, err := http.NewRequestWithContext(ctx, http.MethodGet, targetSrv.URL, nil) - req.Close = true // close the connection after the request to close the tunnel right away - require.NoError(t, err, "NewRequest") + tcpDialer := &transport.TCPDialer{} - resp, err := hc.Do(req) - if tt.wantErr != "" { - require.Contains(t, err.Error(), tt.wantErr, "Do") + // Serve H2 prior knowledge (no TLS, no upgrade) on a raw TCP listener. + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err, "Listen") + t.Cleanup(func() { ln.Close() }) + + h2srv := &http2.Server{} + handler := httpproxy.NewConnectHandler(tcpDialer) + go func() { + for { + conn, err := ln.Accept() + if err != nil { return } - require.NoError(t, err, "Do") - defer resp.Body.Close() + go h2srv.ServeConn(conn, &http2.ServeConnOpts{Handler: handler}) + } + }() - require.Equal(t, http.StatusOK, resp.StatusCode) + tr, err := NewH2ProxyTransport(tcpDialer, ln.Addr().String(), WithPlainHTTP()) + require.NoError(t, err, "NewH2ProxyTransport") - var gotResp Response - err = json.NewDecoder(resp.Body).Decode(&gotResp) - require.NoError(t, err, "Decode") + connClient, err := NewConnectClient(tr) + require.NoError(t, err, "NewConnectClient") - require.Equal(t, wantResp, gotResp, "Response") - }) - } + verifyTunnel(t, connClient) } -func generateRootCA(t *testing.T) (*x509.Certificate, stdTLS.Certificate) { - t.Helper() +// Test_ConnectClient_H2_TLS_Multiplexed verifies that NewH2ProxyTransport uses a single +// underlying TCP connection to the proxy for multiple concurrent CONNECT streams. +func Test_ConnectClient_H2_TLS_Multiplexed(t *testing.T) { + t.Parallel() - privKey, err := rsa.GenerateKey(rand.Reader, 2048) - require.NoError(t, err) + tcpDialer := &transport.TCPDialer{} - template := x509.Certificate{ - SerialNumber: big.NewInt(1), - Subject: pkix.Name{Organization: []string{"Test Root CA"}}, - NotBefore: time.Now().Add(-1 * time.Hour), - NotAfter: time.Now().Add(24 * time.Hour), - KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, - BasicConstraintsValid: true, - IsCA: true, - IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + proxySrv := httptest.NewUnstartedServer(httpproxy.NewConnectHandler(tcpDialer)) + proxySrv.EnableHTTP2 = true + proxySrv.StartTLS() + t.Cleanup(proxySrv.Close) + + certPool := x509.NewCertPool() + certPool.AddCert(proxySrv.Certificate()) + + proxyURL, err := url.Parse(proxySrv.URL) + require.NoError(t, err, "Parse") + + // Wrap the dialer to count how many TCP connections are opened to the proxy. + var mu sync.Mutex + var dialCount int + countingDialer := transport.FuncStreamDialer(func(ctx context.Context, addr string) (transport.StreamConn, error) { + mu.Lock() + dialCount++ + mu.Unlock() + return tcpDialer.DialStream(ctx, addr) + }) + + tr, err := NewH2ProxyTransport(countingDialer, proxyURL.Host, + WithTLSOptions(tls.WithCertVerifier(&tls.StandardCertVerifier{Roots: certPool})), + ) + require.NoError(t, err, "NewH2ProxyTransport") + + connClient, err := NewConnectClient(tr) + require.NoError(t, err, "NewConnectClient") + + // Open 3 concurrent tunnels and assert they all share 1 TCP connection to the proxy. + targetSrv := newTargetSrv(t, "ignored") + t.Cleanup(targetSrv.Close) + targetURL, err := url.Parse(targetSrv.URL) + require.NoError(t, err, "Parse") + + var wg sync.WaitGroup + for range 3 { + wg.Add(1) + go func() { + defer wg.Done() + conn, err := connClient.DialStream(context.Background(), targetURL.Host) + require.NoError(t, err, "DialStream") + conn.Close() + }() } + wg.Wait() - certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privKey.PublicKey, privKey) - require.NoError(t, err) + mu.Lock() + require.Equal(t, 1, dialCount, "expected all streams to share 1 TCP connection to proxy") + mu.Unlock() - keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privKey)}) - certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + verifyTunnel(t, connClient) +} - tlsCert, err := stdTLS.X509KeyPair(certPEM, keyPEM) - require.NoError(t, err) +// Test_ConnectClient_H3_QUIC verifies tunneling over HTTP/3, where CONNECT streams +// run over QUIC rather than TCP. Uses http3.HTTPStreamer to access the raw H3 stream. +func Test_ConnectClient_H3_QUIC(t *testing.T) { + t.Parallel() - cert, err := x509.ParseCertificate(certDER) - require.NoError(t, err) + // http3.Server requires its own TLS config; borrow httptest's built-in cert. + tlsCert, certPool := tlsCertPool(t) + + srvConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err, "ListenPacket") + + proxySrv := &http3.Server{ + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "HTTP/3.0", r.Proto, "Proto") + require.Equal(t, http.MethodConnect, r.Method, "Method") + + conn, err := net.Dial("tcp", r.URL.Host) + require.NoError(t, err, "Dial") + defer conn.Close() + + w.WriteHeader(http.StatusOK) + w.(http.Flusher).Flush() + + streamer, ok := w.(http3.HTTPStreamer) + require.True(t, ok, "expected http3.HTTPStreamer") + stream := streamer.HTTPStream() + defer stream.Close() + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + io.Copy(stream, conn) + }() + wg.Add(1) + go func() { + defer wg.Done() + io.Copy(conn, stream) + }() + wg.Wait() + }), + TLSConfig: &stdTLS.Config{Certificates: []stdTLS.Certificate{tlsCert}}, + } + go func() { _ = proxySrv.Serve(srvConn) }() + t.Cleanup(func() { + _ = proxySrv.Close() + _ = srvConn.Close() + }) + + cliConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err, "ListenPacket") + t.Cleanup(func() { _ = cliConn.Close() }) + + tr, err := NewH3ProxyTransport(cliConn.(net.PacketConn), srvConn.LocalAddr().String(), + WithTLSOptions(tls.WithCertVerifier(&tls.StandardCertVerifier{Roots: certPool})), + ) + require.NoError(t, err, "NewH3ProxyTransport") - return cert, tlsCert + connClient, err := NewConnectClient(tr) + require.NoError(t, err, "NewConnectClient") + + verifyTunnel(t, connClient) +} + +// newTargetSrv starts a local HTTP server that responds to any request with resp serialized as JSON. +// It represents the tunnel destination — the server the client reaches through the proxy. +func newTargetSrv(t *testing.T, resp interface{}) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + + jsonResp, err := json.Marshal(resp) + require.NoError(t, err) + + _, err = w.Write(jsonResp) + require.NoError(t, err) + })) } +// flusherWriter wraps an http.ResponseWriter so that every Write is followed by a Flush. +// This is required when relaying data over an H2 response stream: without explicit flushing, +// written bytes sit in the buffer and the remote end never receives them. type flusherWriter struct { http.Flusher io.Writer diff --git a/x/httpconnect/doc.go b/x/httpconnect/doc.go index 54418797..737fef75 100644 --- a/x/httpconnect/doc.go +++ b/x/httpconnect/doc.go @@ -12,5 +12,91 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package httpconnect contains an HTTP CONNECT client implementation. +// Package httpconnect provides an HTTP CONNECT tunnel client that works over +// HTTP/1.1, HTTP/2, and HTTP/3. +// +// # Overview +// +// An HTTP CONNECT proxy accepts a CONNECT request from the client, dials the +// requested host:port, responds with 200 OK, and then relays bytes in both +// directions. This package implements the client side of that exchange. +// +// The entry points are: +// - [NewConnectClient]: creates a [transport.StreamDialer] that tunnels through the proxy. +// - [NewHTTPProxyTransport]: HTTP/1.1 (or H2 via ALPN) transport for [NewConnectClient]. +// - [NewH2ProxyTransport]: pure HTTP/2 transport; supports h2c and multiplexes all CONNECT +// streams over a single TCP connection. +// - [NewH3ProxyTransport]: HTTP/3 over QUIC transport. It multiplexes all CONNECT streams +// over a single QUIC connection. +// +// # Manual testing with Caddy +// +// Caddy (https://caddyserver.com) with the forward_proxy community module +// (https://github.com/caddyserver/forwardproxy) is a convenient way to run a +// local CONNECT proxy that supports H1, H2, and H3. Use xcaddy to build it: +// +// go install github.com/caddyserver/xcaddy/cmd/xcaddy@latest +// xcaddy build --with github.com/caddyserver/forwardproxy +// +// Then write a Caddyfile. Use tls internal for a locally-trusted certificate +// (requires the Caddy root CA to be trusted — run "caddy trust" once): +// +// # Caddyfile +// :8443 { +// tls internal +// forward_proxy +// } +// +// Or supply your own certificate: +// +// # Caddyfile +// :8443 { +// tls /path/to/cert.pem /path/to/key.pem +// forward_proxy +// } +// +// Start the proxy: +// +// caddy run --config Caddyfile +// +// Caddy serves H1, H2, and H3 on the same port. H3 runs over UDP on the same +// port number as the TLS listener. +// +// # Connecting with this package +// +// Connect over H2 (multiplexed — all CONNECT streams share one TCP connection): +// +// tr, err := httpconnect.NewH2ProxyTransport(&transport.TCPDialer{}, "127.0.0.1:8443", +// httpconnect.WithTLSOptions( +// tls.WithCertVerifier(&tls.StandardCertVerifier{}), // uses system roots +// ), +// ) +// client, err := httpconnect.NewConnectClient(tr) +// conn, err := client.DialStream(ctx, "example.com:443") +// +// Connect over H3 (QUIC): +// +// udpConn, err := net.ListenPacket("udp", "127.0.0.1:0") +// tr, err := httpconnect.NewH3ProxyTransport(udpConn, "127.0.0.1:8443", +// httpconnect.WithTLSOptions( +// tls.WithCertVerifier(&tls.StandardCertVerifier{}), +// ), +// ) +// client, err := httpconnect.NewConnectClient(tr) +// conn, err := client.DialStream(ctx, "example.com:443") +// +// # Notes on the forward_proxy module and localhost +// +// By default the forward_proxy module denies connections to private/loopback +// addresses (127.0.0.0/8, 10.0.0.0/8, etc.) as an SSRF mitigation. To allow +// them in a test environment, add an explicit ACL rule in the Caddyfile: +// +// :8443 { +// tls internal +// forward_proxy { +// acl { +// allow 127.0.0.0/8 +// } +// } +// } package httpconnect diff --git a/x/httpconnect/transport.go b/x/httpconnect/transport.go index 93bcd0b9..cf41165a 100644 --- a/x/httpconnect/transport.go +++ b/x/httpconnect/transport.go @@ -93,12 +93,77 @@ func NewHTTPProxyTransport(dialer transport.StreamDialer, proxyAddr string, opts }, nil } -// NewHTTP3ProxyTransport creates an HTTP/3 transport that establishes a QUIC connection to the proxy using the given [net.PacketConn]. +// NewH2ProxyTransport creates a pure HTTP/2 transport that establishes a connection to the proxy +// using the given [transport.StreamDialer]. +// The proxy address must be in the form "host:port". +// +// Unlike [NewHTTPProxyTransport], this uses [golang.org/x/net/http2.Transport] directly, enabling: +// - h2c (cleartext HTTP/2 via prior knowledge) with [WithPlainHTTP] — no TLS required +// - Pure H2 from byte 1: multiple concurrent CONNECT tunnels share one TCP connection +func NewH2ProxyTransport(dialer transport.StreamDialer, proxyAddr string, opts ...TransportOption) (ProxyRoundTripper, error) { + if dialer == nil { + return nil, errors.New("dialer must not be nil") + } + host, _, err := net.SplitHostPort(proxyAddr) + if err != nil { + return nil, fmt.Errorf("failed to parse proxy address %s: %w", proxyAddr, err) + } + + cfg := &transportConfig{} + cfg.applyOptions(opts...) + + var tr *http2.Transport + if cfg.plainHTTP { + tr = &http2.Transport{ + AllowHTTP: true, + // DialTLSContext is used even for plaintext when AllowHTTP is true. + DialTLSContext: func(ctx context.Context, _, _ string, _ *stdTLS.Config) (net.Conn, error) { + return dialer.DialStream(ctx, proxyAddr) + }, + } + } else { + tlsCfg := tls.ClientConfig{ServerName: host} + for _, opt := range cfg.tlsOptions { + opt(host, &tlsCfg) + } + stdCfg := toStdConfig(tlsCfg) + // Ensure "h2" is in ALPN NextProtos so the server negotiates HTTP/2. + stdCfg.NextProtos = append([]string{"h2"}, stdCfg.NextProtos...) + tr = &http2.Transport{ + // http2.Transport type-asserts to *tls.Conn to read NegotiatedProtocol, + // so we must return *stdTLS.Conn directly — not an sdk tls.WrapConn wrapper. + DialTLSContext: func(ctx context.Context, _, _ string, _ *stdTLS.Config) (net.Conn, error) { + conn, err := dialer.DialStream(ctx, proxyAddr) + if err != nil { + return nil, err + } + tlsConn := stdTLS.Client(conn, stdCfg) + if err := tlsConn.HandshakeContext(ctx); err != nil { + conn.Close() + return nil, err + } + return tlsConn, nil + }, + } + } + + sch := schemeHTTPS + if cfg.plainHTTP { + sch = schemeHTTP + } + + return proxyRT{ + RoundTripper: tr, + scheme: sch, + }, nil +} + +// NewH3ProxyTransport creates an HTTP/3 transport that establishes a QUIC connection to the proxy using the given [net.PacketConn]. // The proxy address must be in the form "host:port". // // For HTTP/3 over QUIC over a datagram connection. // [tls.WithALPN] has no effect on this transport. -func NewHTTP3ProxyTransport(conn net.PacketConn, proxyAddr string, opts ...TransportOption) (ProxyRoundTripper, error) { +func NewH3ProxyTransport(conn net.PacketConn, proxyAddr string, opts ...TransportOption) (ProxyRoundTripper, error) { if conn == nil { return nil, errors.New("conn must not be nil") } diff --git a/x/httpproxy/connect_handler.go b/x/httpproxy/connect_handler.go index 3197694a..75f284e9 100644 --- a/x/httpproxy/connect_handler.go +++ b/x/httpproxy/connect_handler.go @@ -22,9 +22,9 @@ import ( "net" "net/http" "strings" + "sync" "golang.getoutline.org/sdk/transport" - "golang.getoutline.org/sdk/x/configurl" ) type sanitizeErrorDialer struct { @@ -50,9 +50,24 @@ func (d *sanitizeErrorDialer) DialStream(ctx context.Context, addr string) (tran return conn, nil } +// StreamDialerParser creates a [transport.StreamDialer] from a config string. +// It is used by [NewConnectHandler] to support the Transport request header. +type StreamDialerParser func(ctx context.Context, config string) (transport.StreamDialer, error) + +// HandlerOption configures a connect handler. +type HandlerOption func(*connectHandler) + +// WithStreamDialerParser sets a factory that creates a dialer from the Transport request header value. +// When set, clients can override the transport per-request by sending a Transport header. +func WithStreamDialerParser(f StreamDialerParser) HandlerOption { + return func(h *connectHandler) { + h.dialerFactory = f + } +} + type connectHandler struct { - dialer *sanitizeErrorDialer - providers *configurl.ProviderContainer + dialer *sanitizeErrorDialer + dialerFactory StreamDialerParser } var _ http.Handler = (*connectHandler)(nil) @@ -75,81 +90,151 @@ func (h *connectHandler) ServeHTTP(proxyResp http.ResponseWriter, proxyReq *http return } - // Dial the target. - transportConfig := proxyReq.Header.Get("Transport") - dialer, err := h.providers.NewStreamDialer(proxyReq.Context(), transportConfig) - if err != nil { - // Because we sanitize the base dialer error, it's safe to return error details here. - http.Error(proxyResp, fmt.Sprintf("Invalid config in Transport header: %v", err), http.StatusBadRequest) - return + // Dial the target, optionally using a per-request transport from the Transport header. + var dialer transport.StreamDialer = h.dialer + if transportConfig := proxyReq.Header.Get("Transport"); transportConfig != "" { + if h.dialerFactory == nil { + http.Error(proxyResp, "Transport header is not supported", http.StatusBadRequest) + return + } + var err error + dialer, err = h.dialerFactory(proxyReq.Context(), transportConfig) + if err != nil { + // Because we sanitize the base dialer error, it's safe to return error details here. + http.Error(proxyResp, fmt.Sprintf("Invalid config in Transport header: %v", err), http.StatusBadRequest) + return + } } - targetConn, err := dialer.DialStream(proxyReq.Context(), proxyReq.Host) - if err != nil { - http.Error(proxyResp, fmt.Sprintf("Failed to connect to %v: %v", proxyReq.Host, err), http.StatusServiceUnavailable) + targetConn, err2 := dialer.DialStream(proxyReq.Context(), proxyReq.Host) + if err2 != nil { + http.Error(proxyResp, fmt.Sprintf("Failed to connect to %v: %v", proxyReq.Host, err2), http.StatusServiceUnavailable) return } defer targetConn.Close() - hijacker, ok := proxyResp.(http.Hijacker) - if !ok { - http.Error(proxyResp, "Webserver doesn't support hijacking", http.StatusInternalServerError) - return - } - - httpConn, clientRW, err := hijacker.Hijack() - if err != nil { - http.Error(proxyResp, "Failed to hijack connection", http.StatusInternalServerError) - return + // Set up protocol-specific client I/O. H1 hijacks the raw connection; H2/H3 stream + // through the ResponseWriter with explicit flushing after each write. + var clientReader io.Reader + var clientWriter io.ReaderFrom + var afterCopy func() + if hijacker, ok := proxyResp.(http.Hijacker); ok { + // H1: hijack the raw connection and relay using the underlying bufio.ReadWriter. + httpConn, clientRW, err := hijacker.Hijack() + if err != nil { + http.Error(proxyResp, "Failed to hijack connection", http.StatusInternalServerError) + return + } + // TODO(fortuna): Use context.AfterFunc after we migrate to Go 1.21. + go func() { + // We close the hijacked connection when the context is done. This way + // we allow the HTTP server to control the request lifetime. + // The request context will be cancelled right after ServeHTTP returns, + // but it can be cancelled before, if the server uses a custom BaseContext. + <-proxyReq.Context().Done() + httpConn.Close() + }() + clientRW.Write([]byte("HTTP/1.1 200 Connection established\r\n\r\n")) + clientRW.Flush() + // clientRW (bufio.ReadWriter) implements io.ReaderFrom via its embedded bufio.Writer. + clientReader = clientRW + clientWriter = clientRW + // afterCopy flushes the bufio buffer to push any remaining bytes to the client. + afterCopy = func() { clientRW.Flush() } + } else { + // H2/H3: hijacking is not available on multiplexed connections. + flusher, ok := proxyResp.(http.Flusher) + if !ok { + http.Error(proxyResp, "Webserver doesn't support flushing", http.StatusInternalServerError) + return + } + proxyResp.WriteHeader(http.StatusOK) + flusher.Flush() + // flushingWriter flushes after every write, so no afterCopy flush is needed. + clientReader = proxyReq.Body + clientWriter = &flushingWriter{w: proxyResp, f: flusher} + afterCopy = func() {} } - // TODO(fortuna): Use context.AfterFunc after we migrate to Go 1.21. - go func() { - // We close the hijacked connection when the context is done. This way - // we allow the HTTP server to control the request lifetime. - // The request context will be cancelled right after ServeHTTP returns, - // but it can be cancelled before, if the server uses a custom BaseContext. - <-proxyReq.Context().Done() - httpConn.Close() - }() - - // Inform the client that the connection has been established. - clientRW.Write([]byte("HTTP/1.1 200 Connection established\r\n\r\n")) - clientRW.Flush() // Relay data between client and target in both directions. + var wg sync.WaitGroup + wg.Add(1) go func() { + defer wg.Done() // io.Copy prefers WriteTo, which clientRW implements. However, // bufio.ReadWriter.WriteTo issues an empty Write() call, which flushes // the Shadowsocks IV and connect request, breaking the coalescing with // the initial data. By preferring ReaderFrom, the coalescing of IV, // request and initial data is preserved. if rf, ok := targetConn.(io.ReaderFrom); ok { - rf.ReadFrom(clientRW) + rf.ReadFrom(clientReader) } else { - io.Copy(targetConn, clientRW) + io.Copy(targetConn, clientReader) } targetConn.CloseWrite() }() // We can't use io.Copy here because it doesn't call Flush on writes, so the first - // write is never sent and the entire relay gets stuck. bufio.Writer.ReadFrom takes - // care of that. - clientRW.ReadFrom(targetConn) - clientRW.Flush() + // write is never sent and the entire relay gets stuck. bufio.Writer.ReadFrom (H1) + // and flushingWriter.ReadFrom (H2/H3) take care of that. + clientWriter.ReadFrom(targetConn) + afterCopy() + wg.Wait() +} + +// flushingWriter wraps an http.ResponseWriter and flushes after every write, +// ensuring bytes are sent to the client immediately over H2/H3 streams. +type flushingWriter struct { + w http.ResponseWriter + f http.Flusher +} + +func (fw *flushingWriter) Write(b []byte) (int, error) { + n, err := fw.w.Write(b) + fw.f.Flush() + return n, err +} + +// ReadFrom shadows http.ResponseWriter's own ReadFrom (present in net/http's *response), +// which does not flush. This implementation flushes after every write so bytes reach +// the client immediately, and prefers r.WriteTo to avoid an intermediate buffer. +func (fw *flushingWriter) ReadFrom(r io.Reader) (int64, error) { + if wt, ok := r.(io.WriterTo); ok { + return wt.WriteTo(fw) + } + buf := make([]byte, 32*1024) + var n int64 + for { + nr, er := r.Read(buf) + if nr > 0 { + nw, ew := fw.Write(buf[:nr]) + n += int64(nw) + if ew != nil { + return n, ew + } + } + if er == io.EOF { + return n, nil + } + if er != nil { + return n, er + } + } } // NewConnectHandler creates a [http.Handler] that handles CONNECT requests and forwards // the requests using the given [transport.StreamDialer]. // -// Clients can specify a Transport header with a value of a transport config as specified in -// the [configurl] package to specify the transport for a given request. +// Use [WithStreamDialerParser] to support the Transport request header, which allows clients +// to specify a per-request transport config. // // The resulting handler is currently vulnerable to probing attacks. It's ok as a localhost proxy // but it may be vulnerable if used as a public proxy. -func NewConnectHandler(dialer transport.StreamDialer) http.Handler { +func NewConnectHandler(dialer transport.StreamDialer, opts ...HandlerOption) http.Handler { // We sanitize the errors from the input Dialer because we don't want to leak sensitive details // of the base dialer (e.g. access key credentials) to the user. sd := &sanitizeErrorDialer{dialer} - // TODO(fortuna): Inject the config parser - providers := configurl.NewDefaultProviders() - providers.StreamDialers.BaseInstance = sd - return &connectHandler{sd, providers} + h := &connectHandler{dialer: sd} + for _, opt := range opts { + opt(h) + } + return h }