diff --git a/tcpproxy.go b/tcpproxy.go index 9826d94..d09b1cb 100644 --- a/tcpproxy.go +++ b/tcpproxy.go @@ -347,8 +347,6 @@ func UnderlyingConn(c net.Conn) net.Conn { return c } -func goCloseConn(c net.Conn) { go c.Close() } - // HandleConn implements the Target interface. func (dp *DialProxy) HandleConn(src net.Conn) { ctx := context.Background() @@ -364,13 +362,13 @@ func (dp *DialProxy) HandleConn(src net.Conn) { dp.onDialError()(src, err) return } - defer goCloseConn(dst) + defer dst.Close() if err = dp.sendProxyHeader(dst, src); err != nil { dp.onDialError()(src, err) return } - defer goCloseConn(src) + defer src.Close() if ka := dp.keepAlivePeriod(); ka > 0 { if c, ok := UnderlyingConn(src).(*net.TCPConn); ok { @@ -386,7 +384,12 @@ func (dp *DialProxy) HandleConn(src net.Conn) { errc := make(chan error, 1) go proxyCopy(errc, src, dst) go proxyCopy(errc, dst, src) - <-errc + + for i := 0; i < 2; i++ { + if err = <-errc; err != nil { + return + } + } } func (dp *DialProxy) sendProxyHeader(w io.Writer, src net.Conn) error { @@ -437,6 +440,14 @@ func proxyCopy(errc chan<- error, dst, src net.Conn) { dst = UnderlyingConn(dst) _, err := io.Copy(dst, src) + + if tcpConn, ok := dst.(*net.TCPConn); ok { + tcpConn.CloseWrite() + } + if tcpConn, ok := src.(*net.TCPConn); ok { + tcpConn.CloseRead() + } + errc <- err } diff --git a/tcpproxy_test.go b/tcpproxy_test.go index 5d75cc3..7c2c4e9 100644 --- a/tcpproxy_test.go +++ b/tcpproxy_test.go @@ -362,7 +362,7 @@ func (t *tlsServer) Close() { // cert creates a well-formed, but completely insecure self-signed // cert for domain. func cert(t *testing.T, domain string) tls.Certificate { - private, err := rsa.GenerateKey(rand.Reader, 512) + private, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { t.Fatal(err) }