diff --git a/protocol.go b/protocol.go index 917ee62..ec7291f 100644 --- a/protocol.go +++ b/protocol.go @@ -56,10 +56,12 @@ type Listener struct { // may be speaking the Proxy Protocol. If it is, the RemoteAddr() will // return the address of the client instead of the proxy address. type Conn struct { - bufReader *bufio.Reader - conn net.Conn - dstAddr *net.TCPAddr - srcAddr *net.TCPAddr + bufReader *bufio.Reader + conn net.Conn + dstAddr *net.TCPAddr + srcAddr *net.TCPAddr + // Any error encountered while reading the proxyproto header + proxyErr error useConnAddr bool once sync.Once proxyHeaderTimeout time.Duration @@ -158,7 +160,7 @@ func (p *Conn) LocalAddr() net.Addr { // protocol is being used, otherwise just returns the address of // the socket peer. If there is an error parsing the header, the // address of the client is not returned, and the socket is closed. -// Once implication of this is that the call could block if the +// One implication of this is that the call could block if the // client is slow. Using a Deadline is recommended if this is called // before Read() func (p *Conn) RemoteAddr() net.Addr { @@ -169,6 +171,22 @@ func (p *Conn) RemoteAddr() net.Addr { return p.conn.RemoteAddr() } +// ProxySourceAddr returns the source address according to the proxyproto. +// If there was an error parsing the proxy header, that error will be returned. +// This call will read the proxy header if it hasn't been read yet, and thus +// using a Deadline is recommended if this is called before Read(). +// This method, if called, can be used to reliably check if the connection is +// using a proxy. +// If UnknownTrue is set on the listener, ProxySourcAddr may return 'nil, nil' +// in the case of a proxy protocol being used with PROXY UNKNOWN. +func (p *Conn) ProxySourceAddr() (net.Addr, error) { + p.checkPrefixOnce() + if p.srcAddr == nil { + return nil, p.proxyErr + } + return p.srcAddr, p.proxyErr +} + func (p *Conn) SetDeadline(t time.Time) error { return p.conn.SetDeadline(t) } @@ -203,6 +221,7 @@ func (p *Conn) checkPrefix() error { inp, err := p.bufReader.Peek(i) if err != nil { + p.proxyErr = fmt.Errorf("error while trying to read proxy header: %w", err) if neterr, ok := err.(net.Error); ok && neterr.Timeout() { return nil } else { @@ -212,6 +231,7 @@ func (p *Conn) checkPrefix() error { // Check for a prefix mis-match, quit early if !bytes.Equal(inp, prefix[:i]) { + p.proxyErr = fmt.Errorf("connection read did not match proxy header") return nil } } @@ -219,6 +239,7 @@ func (p *Conn) checkPrefix() error { // Read the header line header, err := p.bufReader.ReadString('\n') if err != nil { + p.proxyErr = fmt.Errorf("error reading first proxyheader line: %w", err) p.conn.Close() return err } @@ -230,7 +251,8 @@ func (p *Conn) checkPrefix() error { parts := strings.Split(header, " ") if len(parts) < 2 { p.conn.Close() - return fmt.Errorf("Invalid header line: %s", header) + p.proxyErr = fmt.Errorf("invalid header line: %s", header) + return p.proxyErr } // Verify the type is known @@ -238,7 +260,8 @@ func (p *Conn) checkPrefix() error { case "UNKNOWN": if !p.unknownOK || len(parts) != 2 { p.conn.Close() - return fmt.Errorf("Invalid UNKNOWN header line: %s", header) + p.proxyErr = fmt.Errorf("invalid UNKNOWN header line: %s", header) + return p.proxyErr } p.useConnAddr = true return nil @@ -246,24 +269,28 @@ func (p *Conn) checkPrefix() error { case "TCP6": default: p.conn.Close() - return fmt.Errorf("Unhandled address type: %s", parts[1]) + p.proxyErr = fmt.Errorf("Unhandled address type: %s", parts[1]) + return p.proxyErr } if len(parts) != 6 { p.conn.Close() - return fmt.Errorf("Invalid header line: %s", header) + p.proxyErr = fmt.Errorf("Invalid header line (should have 6 parts): %s", header) + return p.proxyErr } // Parse out the source address ip := net.ParseIP(parts[2]) if ip == nil { p.conn.Close() - return fmt.Errorf("Invalid source ip: %s", parts[2]) + p.proxyErr = fmt.Errorf("Invalid source ip: %s", parts[2]) + return p.proxyErr } port, err := strconv.Atoi(parts[4]) if err != nil { p.conn.Close() - return fmt.Errorf("Invalid source port: %s", parts[4]) + p.proxyErr = fmt.Errorf("Invalid source port: %s", parts[4]) + return p.proxyErr } p.srcAddr = &net.TCPAddr{IP: ip, Port: port} @@ -271,12 +298,14 @@ func (p *Conn) checkPrefix() error { ip = net.ParseIP(parts[3]) if ip == nil { p.conn.Close() - return fmt.Errorf("Invalid destination ip: %s", parts[3]) + p.proxyErr = fmt.Errorf("Invalid destination ip: %s", parts[3]) + return p.proxyErr } port, err = strconv.Atoi(parts[5]) if err != nil { p.conn.Close() - return fmt.Errorf("Invalid destination port: %s", parts[5]) + p.proxyErr = fmt.Errorf("Invalid destination port: %s", parts[5]) + return p.proxyErr } p.dstAddr = &net.TCPAddr{IP: ip, Port: port} diff --git a/protocol_test.go b/protocol_test.go index 0abe798..af1fb62 100644 --- a/protocol_test.go +++ b/protocol_test.go @@ -62,6 +62,10 @@ func TestPassthrough(t *testing.T) { if _, err := conn.Write([]byte("pong")); err != nil { t.Fatalf("err: %v", err) } + + if src, err := conn.(*Conn).ProxySourceAddr(); err == nil { + t.Fatalf("expected error on passthrough, but got nil and src %v", src) + } } func TestTimeout(t *testing.T) { @@ -185,6 +189,13 @@ func TestParse_ipv4(t *testing.T) { if addr.Port != 1000 { t.Fatalf("bad: %v", addr) } + src, err := conn.(*Conn).ProxySourceAddr() + if err != nil { + t.Fatalf("expected no error on proxy source addr: %v", err) + } + if src != addr { + t.Fatalf("expected addrs to match in working proxy case: %v != %v", src, addr) + } } func TestParse_ipv6(t *testing.T) { @@ -244,6 +255,13 @@ func TestParse_ipv6(t *testing.T) { if addr.Port != 1000 { t.Fatalf("bad: %v", addr) } + src, err := conn.(*Conn).ProxySourceAddr() + if err != nil { + t.Fatalf("expected no error on proxy source addr: %v", err) + } + if src != addr { + t.Fatalf("expected addrs to match in working proxy case: %v != %v", src, addr) + } } func TestParse_Unknown(t *testing.T) { @@ -294,7 +312,13 @@ func TestParse_Unknown(t *testing.T) { if _, err := conn.Write([]byte("pong")); err != nil { t.Fatalf("err: %v", err) } - + src, err := conn.(*Conn).ProxySourceAddr() + if err != nil { + t.Fatalf("expected no error on proxy source addr for UNKNOWN: %v", err) + } + if src != nil { + t.Fatalf("expected src addr to be nil on UNKNOWN proxy: %v", src) + } } func TestParse_BadHeader(t *testing.T) { @@ -337,6 +361,11 @@ func TestParse_BadHeader(t *testing.T) { t.Fatalf("bad: %v", addr) } + // ProxySourceAddr should return the error + if _, err := conn.(*Conn).ProxySourceAddr(); err == nil { + t.Fatalf("expected an error when the proxy header was wrong") + } + // Read should fail recv := make([]byte, 4) _, err = conn.Read(recv)