From 7e4f85b1457e9c8f34e2742e9cd4469fd5e7446d Mon Sep 17 00:00:00 2001 From: Steven Yang Date: Tue, 10 Apr 2018 18:34:40 -0700 Subject: [PATCH] Refactored the test files with helpers to test backend ``` func testRouteToBackend(t *testing.T, front net.Listener, back net.Listener, msg string) func testNotRouteToBackend(t *testing.T, front net.Listener, back net.Listener, msg string) <-chan bool ``` --- tcpproxy_test.go | 149 ++++++++++++++++++++++++----------------------- 1 file changed, 75 insertions(+), 74 deletions(-) diff --git a/tcpproxy_test.go b/tcpproxy_test.go index 682214d..0061cf2 100644 --- a/tcpproxy_test.go +++ b/tcpproxy_test.go @@ -169,38 +169,90 @@ func testProxy(t *testing.T, front net.Listener) *Proxy { } } -func TestProxyAlwaysMatch(t *testing.T) { - front := newLocalListener(t) - defer front.Close() - back := newLocalListener(t) - defer back.Close() +func testRouteToBackendWithExpected(t *testing.T, toFront net.Conn, back net.Listener, msg string, expected string) { + io.WriteString(toFront, msg) + fromProxy, err := back.Accept() + if err != nil { + t.Fatal(err) + } - p := testProxy(t, front) - p.AddRoute(testFrontAddr, To(back.Addr().String())) - if err := p.Start(); err != nil { + buf := make([]byte, len(expected)) + if _, err := io.ReadFull(fromProxy, buf); err != nil { t.Fatal(err) } + if string(buf) != expected { + t.Fatalf("got %q; want %q", buf, expected) + } +} +func testRouteToBackend(t *testing.T, front net.Listener, back net.Listener, msg string) { toFront, err := net.Dial("tcp", front.Addr().String()) if err != nil { t.Fatal(err) } defer toFront.Close() - fromProxy, err := back.Accept() + testRouteToBackendWithExpected(t, toFront, back, msg, msg) +} + +// test the backend is not receiving traffic +func testNotRouteToBackend(t *testing.T, front net.Listener, back net.Listener, msg string) <-chan bool { + done := make(chan bool) + toFront, err := net.Dial("tcp", front.Addr().String()) if err != nil { t.Fatal(err) } - const msg = "message" - io.WriteString(toFront, msg) + defer toFront.Close() - buf := make([]byte, len(msg)) - if _, err := io.ReadFull(fromProxy, buf); err != nil { + timeC := time.NewTimer(10 * time.Millisecond).C + acceptC := make(chan struct{}) + go func() { + io.WriteString(toFront, msg) + fromProxy, err := back.Accept() + acceptC <- struct{}{} + { + if err == nil { + buf := make([]byte, len(msg)) + if _, err := io.ReadFull(fromProxy, buf); err != nil { + t.Fatal(err) + } + t.Fatalf("Expect backend to not receive message, but found %s", string(buf)) + } + err, ok := err.(net.Error) + if !ok || !err.Timeout() { + t.Fatalf("Expect backend to timeout, but found err: %v", err) + } + } + }() + go func() { + select { + case <-timeC: + { + done <- true + } + case <-acceptC: + { + t.Fatal("Expect backend to not receive message") + done <- true + } + } + }() + return done +} + +func TestProxyAlwaysMatch(t *testing.T) { + front := newLocalListener(t) + defer front.Close() + back := newLocalListener(t) + defer back.Close() + + p := testProxy(t, front) + p.AddRoute(testFrontAddr, To(back.Addr().String())) + if err := p.Start(); err != nil { t.Fatal(err) } - if string(buf) != msg { - t.Fatalf("got %q; want %q", buf, msg) - } + + testRouteToBackend(t, front, back, "message") } func TestProxyHTTP(t *testing.T) { @@ -219,27 +271,9 @@ func TestProxyHTTP(t *testing.T) { t.Fatal(err) } - toFront, err := net.Dial("tcp", front.Addr().String()) - if err != nil { - t.Fatal(err) - } - defer toFront.Close() - - const msg = "GET / HTTP/1.1\r\nHost: bar.com\r\n\r\n" - io.WriteString(toFront, msg) - - fromProxy, err := backBar.Accept() - if err != nil { - t.Fatal(err) - } - - buf := make([]byte, len(msg)) - if _, err := io.ReadFull(fromProxy, buf); err != nil { - t.Fatal(err) - } - if string(buf) != msg { - t.Fatalf("got %q; want %q", buf, msg) - } + testRouteToBackend(t, front, backBar, "GET / HTTP/1.1\r\nHost: bar.com\r\n\r\n") + <-testNotRouteToBackend(t, front, backBar, "GET / HTTP/1.1\r\nHost: boo.com\r\n\r\n") + testRouteToBackend(t, front, backFoo, "GET / HTTP/1.1\r\nHost: foo.com\r\n\r\n") } func TestProxySNI(t *testing.T) { @@ -258,27 +292,9 @@ func TestProxySNI(t *testing.T) { t.Fatal(err) } - toFront, err := net.Dial("tcp", front.Addr().String()) - if err != nil { - t.Fatal(err) - } - defer toFront.Close() - - msg := clientHelloRecord(t, "bar.com") - io.WriteString(toFront, msg) - - fromProxy, err := backBar.Accept() - if err != nil { - t.Fatal(err) - } - - buf := make([]byte, len(msg)) - if _, err := io.ReadFull(fromProxy, buf); err != nil { - t.Fatal(err) - } - if string(buf) != msg { - t.Fatalf("got %q; want %q", buf, msg) - } + testRouteToBackend(t, front, backBar, clientHelloRecord(t, "bar.com")) + <-testNotRouteToBackend(t, front, backBar, clientHelloRecord(t, "foo.com")) + testRouteToBackend(t, front, backFoo, clientHelloRecord(t, "foo.com")) } func TestProxyPROXYOut(t *testing.T) { @@ -301,23 +317,8 @@ func TestProxyPROXYOut(t *testing.T) { t.Fatal(err) } - io.WriteString(toFront, "foo") - toFront.Close() - - fromProxy, err := back.Accept() - if err != nil { - t.Fatal(err) - } - - bs, err := ioutil.ReadAll(fromProxy) - if err != nil { - t.Fatal(err) - } - want := fmt.Sprintf("PROXY TCP4 %s %d %s %d\r\nfoo", toFront.LocalAddr().(*net.TCPAddr).IP, toFront.LocalAddr().(*net.TCPAddr).Port, toFront.RemoteAddr().(*net.TCPAddr).IP, toFront.RemoteAddr().(*net.TCPAddr).Port) - if string(bs) != want { - t.Fatalf("got %q; want %q", bs, want) - } + testRouteToBackendWithExpected(t, toFront, back, "foo", want) } type tlsServer struct {