From d1db4781e111a59ecdfcbb43ecc53815d76e6cc1 Mon Sep 17 00:00:00 2001 From: Stuart Geipel Date: Thu, 6 Feb 2025 23:20:09 -0500 Subject: [PATCH] [networks] Add timeouts/graceful shutdown to test suite (#33734) --- pkg/network/tracer/testutil/tcp.go | 48 +++++++++++++++++- pkg/network/tracer/tracer_linux_test.go | 67 +++++++++++-------------- pkg/network/tracer/tracer_test.go | 40 +++++++-------- 3 files changed, 97 insertions(+), 58 deletions(-) diff --git a/pkg/network/tracer/testutil/tcp.go b/pkg/network/tracer/testutil/tcp.go index 9ef69afc27172c..21dd4bdf0b9f9c 100644 --- a/pkg/network/tracer/testutil/tcp.go +++ b/pkg/network/tracer/testutil/tcp.go @@ -8,7 +8,13 @@ // Package testutil has utilities for testing the network tracer package testutil -import "net" +import ( + "errors" + "fmt" + "io" + "net" + "time" +) // TCPServer is a simple TCP server for use in tests type TCPServer struct { @@ -61,6 +67,10 @@ func (t *TCPServer) Run() error { if err != nil { return } + err = SetTestDeadline(conn) + if err != nil { + return + } go t.onMessage(conn) } }() @@ -68,6 +78,24 @@ func (t *TCPServer) Run() error { return nil } +// Dial creates a TCP connection to the server, and sets reasonable timeouts +func (t *TCPServer) Dial() (net.Conn, error) { + return DialTCP("tcp", t.Address()) +} + +// DialTCP creates a connection to the specified address, and sets reasonable timeouts for TCP +func DialTCP(network, address string) (net.Conn, error) { + conn, err := net.DialTimeout(network, address, time.Second) + if err != nil { + return nil, fmt.Errorf("failed to dial %s: %w", address, err) + } + err = SetTestDeadline(conn) + if err != nil { + return nil, err + } + return conn, nil +} + // Shutdown stops the TCP server func (t *TCPServer) Shutdown() { if t.ln != nil { @@ -75,3 +103,21 @@ func (t *TCPServer) Shutdown() { t.ln = nil } } + +// SetTestDeadline prevents connection reads/writes from blocking the test indefinitely +func SetTestDeadline(conn net.Conn) error { + // any test in the tracer suite should conclude in less than a minute (normally a couple seconds) + return conn.SetDeadline(time.Now().Add(time.Minute)) +} + +// GracefulCloseTCP closes a connection after making sure all data has been sent/read +// It first shuts down the write end, then reads until EOF, then closes the connection +// https://blog.netherlabs.nl/articles/2009/01/18/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable +func GracefulCloseTCP(conn net.Conn) error { + tcpConn := conn.(*net.TCPConn) + + shutdownErr := tcpConn.CloseWrite() + _, readErr := io.Copy(io.Discard, tcpConn) + closeErr := tcpConn.Close() + return errors.Join(shutdownErr, readErr, closeErr) +} diff --git a/pkg/network/tracer/tracer_linux_test.go b/pkg/network/tracer/tracer_linux_test.go index e78f8401b3d913..f26bfdeb985e56 100644 --- a/pkg/network/tracer/tracer_linux_test.go +++ b/pkg/network/tracer/tracer_linux_test.go @@ -25,7 +25,6 @@ import ( "regexp" "strconv" "strings" - "sync" "sync/atomic" "syscall" "testing" @@ -349,7 +348,7 @@ func (s *TracerSuite) TestTCPMiscount() { t.Cleanup(server.Shutdown) require.NoError(t, server.Run()) - c, err := net.DialTimeout("tcp", server.Address(), 50*time.Millisecond) + c, err := server.Dial() if err != nil { t.Fatal(err) } @@ -467,7 +466,7 @@ func (s *TracerSuite) TestConntrackExpiration() { _, port, err := net.SplitHostPort(server.Address()) require.NoError(t, err, "could not split server address %s", server.Address()) - c, err := net.Dial("tcp", "2.2.2.2:"+port) + c, err := tracertestutil.DialTCP("tcp", "2.2.2.2:"+port) require.NoError(t, err) t.Cleanup(func() { c.Close() @@ -520,7 +519,6 @@ func (s *TracerSuite) TestConntrackDelays() { skipOnEbpflessNotSupported(t, cfg) netlinktestutil.SetupDNAT(t) - wg := sync.WaitGroup{} tr := setupTracer(t, cfg) // This will ensure that the first lookup for every connection fails, while the following ones succeed @@ -529,9 +527,7 @@ func (s *TracerSuite) TestConntrackDelays() { // Letting the OS pick an open port is necessary to avoid flakiness in the test. Running the the test multiple // times can fail if binding to the same port since Conntrack might not emit NEW events for the same tuple server := tracertestutil.NewTCPServerOnAddress(fmt.Sprintf("1.1.1.1:%d", 0), func(c net.Conn) { - wg.Add(1) - defer wg.Done() - defer c.Close() + defer tracertestutil.GracefulCloseTCP(c) r := bufio.NewReader(c) r.ReadBytes(byte('\n')) @@ -541,9 +537,9 @@ func (s *TracerSuite) TestConntrackDelays() { _, port, err := net.SplitHostPort(server.Address()) require.NoError(t, err) - c, err := net.Dial("tcp", fmt.Sprintf("2.2.2.2:%s", port)) + c, err := tracertestutil.DialTCP("tcp", fmt.Sprintf("2.2.2.2:%s", port)) require.NoError(t, err) - defer c.Close() + defer tracertestutil.GracefulCloseTCP(c) _, err = c.Write([]byte("ping")) require.NoError(t, err) @@ -557,7 +553,6 @@ func (s *TracerSuite) TestConntrackDelays() { // write newline so server connections will exit _, err = c.Write([]byte("\n")) require.NoError(t, err) - wg.Wait() } func (s *TracerSuite) TestTranslationBindingRegression() { @@ -567,15 +562,12 @@ func (s *TracerSuite) TestTranslationBindingRegression() { skipOnEbpflessNotSupported(t, cfg) netlinktestutil.SetupDNAT(t) - wg := sync.WaitGroup{} tr := setupTracer(t, cfg) // Setup TCP server server := tracertestutil.NewTCPServerOnAddress(fmt.Sprintf("1.1.1.1:%d", 0), func(c net.Conn) { - wg.Add(1) - defer wg.Done() - defer c.Close() + defer tracertestutil.GracefulCloseTCP(c) r := bufio.NewReader(c) r.ReadBytes(byte('\n')) @@ -586,9 +578,9 @@ func (s *TracerSuite) TestTranslationBindingRegression() { // Send data to 2.2.2.2 (which should be translated to 1.1.1.1) _, port, err := net.SplitHostPort(server.Address()) require.NoError(t, err) - c, err := net.Dial("tcp", fmt.Sprintf("2.2.2.2:%s", port)) + c, err := tracertestutil.DialTCP("tcp", fmt.Sprintf("2.2.2.2:%s", port)) require.NoError(t, err) - defer c.Close() + defer tracertestutil.GracefulCloseTCP(c) _, err = c.Write([]byte("ping")) require.NoError(t, err) @@ -616,7 +608,6 @@ func (s *TracerSuite) TestTranslationBindingRegression() { // write newline so server connections will exit _, err = c.Write([]byte("\n")) require.NoError(t, err) - wg.Wait() } func (s *TracerSuite) TestUnconnectedUDPSendIPv6() { @@ -910,7 +901,7 @@ func (s *TracerSuite) TestGatewayLookupCrossNamespace() { var conn *network.ConnectionStats t.Run("client in root namespace", func(t *testing.T) { - c, err := net.DialTimeout("tcp", server.Address(), 2*time.Second) + c, err := server.Dial() require.NoError(t, err) // write some data @@ -940,7 +931,7 @@ func (s *TracerSuite) TestGatewayLookupCrossNamespace() { var c net.Conn err = kernel.WithNS(test2Ns, func() error { var err error - c, err = net.DialTimeout("tcp", server.Address(), 2*time.Second) + c, err = server.Dial() return err }) require.NoError(t, err) @@ -1116,7 +1107,7 @@ func (s *TracerSuite) TestDNATIntraHostIntegration() { require.NoError(t, err) var conn net.Conn - conn, err = net.Dial("tcp", "2.2.2.2:"+port) + conn, err = tracertestutil.DialTCP("tcp", "2.2.2.2:"+port) require.NoError(t, err, "error connecting to client") t.Cleanup(func() { conn.Close() @@ -1719,7 +1710,7 @@ func (s *TracerSuite) TestSendfileRegression() { require.NoError(t, server.Run()) // Connect to TCP server - c, err := net.DialTimeout("tcp", server.Address(), time.Second) + c, err := tracertestutil.DialTCP("tcp", server.Address()) require.NoError(t, err) testSendfileServer(t, c.(*net.TCPConn), network.TCP, family, func() int64 { return rcvd.Load() }) @@ -1790,7 +1781,7 @@ func (s *TracerSuite) TestSendfileError() { require.NoError(t, server.Run()) t.Cleanup(server.Shutdown) - c, err := net.DialTimeout("tcp", server.Address(), time.Second) + c, err := server.Dial() require.NoError(t, err) // Send file contents via SENDFILE(2) @@ -1949,21 +1940,21 @@ func (s *TracerSuite) TestBlockingReadCounts() { tr := setupTracer(t, testConfig()) ch := make(chan struct{}) server := tracertestutil.NewTCPServer(func(c net.Conn) { + defer tracertestutil.GracefulCloseTCP(c) _, err := c.Write([]byte("foo")) require.NoError(t, err, "error writing to client") time.Sleep(time.Second) _, err = c.Write([]byte("foo")) require.NoError(t, err, "error writing to client") - <-ch }) require.NoError(t, server.Run()) t.Cleanup(server.Shutdown) t.Cleanup(func() { close(ch) }) - c, err := net.DialTimeout("tcp", server.Address(), 5*time.Second) + c, err := server.Dial() require.NoError(t, err) - defer c.Close() + defer tracertestutil.GracefulCloseTCP(c) rawConn, err := c.(syscall.Conn).SyscallConn() require.NoError(t, err, "error getting raw conn") @@ -2019,12 +2010,12 @@ func (s *TracerSuite) TestPreexistingConnectionDirection() { } _, _ = c.Write(genPayload(serverMessageSize)) } - c.Close() + tracertestutil.GracefulCloseTCP(c) }) t.Cleanup(server.Shutdown) require.NoError(t, server.Run()) - c, err := net.DialTimeout("tcp", server.Address(), 50*time.Millisecond) + c, err := server.Dial() require.NoError(t, err) t.Cleanup(func() { c.Close() }) @@ -2041,7 +2032,7 @@ func (s *TracerSuite) TestPreexistingConnectionDirection() { _, err = r.ReadBytes(byte('\n')) require.NoError(t, err) - c.Close() + tracertestutil.GracefulCloseTCP(c) var incoming, outgoing *network.ConnectionStats require.EventuallyWithT(t, func(collect *assert.CollectT) { @@ -2127,7 +2118,7 @@ func testPreexistingEmptyIncomingConnectionDirection(t *testing.T, config *confi require.NoError(t, server.Run()) t.Cleanup(server.Shutdown) - c, err := net.DialTimeout("tcp", server.Address(), 5*time.Second) + c, err := server.Dial() require.NoError(t, err) // Enable BPF-based system probe @@ -2432,9 +2423,8 @@ func (s *TracerSuite) TestConnectionDuration() { require.NoError(t, srv.Run(), "error running server") t.Cleanup(srv.Shutdown) - srvAddr := srv.Address() - c, err := net.DialTimeout("tcp", srvAddr, time.Second) - require.NoError(t, err, "could not connect to server at %s", srvAddr) + c, err := srv.Dial() + require.NoError(t, err) ticker := time.NewTicker(100 * time.Millisecond) t.Cleanup(ticker.Stop) @@ -2583,7 +2573,7 @@ func (s *TracerSuite) TestTCPFailureConnectionResetWithDNAT() { // Attempt to connect to the DNAT address (2.2.2.2), which should be redirected to the server at 1.1.1.1 serverAddr := "2.2.2.2:80" - c, err := net.Dial("tcp", serverAddr) + c, err := tracertestutil.DialTCP("tcp", serverAddr) require.NoError(t, err, "could not connect to server: ", err) // Write to the server and expect a reset @@ -2649,6 +2639,7 @@ func (s *TracerSuite) TestTLSClassification() { postTracerSetup: func(t *testing.T) (uint16, uint16) { srv := usmtestutil.NewTLSServerWithSpecificVersion("localhost:0", func(conn net.Conn) { defer conn.Close() + tracertestutil.SetTestDeadline(conn) _, err := io.Copy(conn, conn) if err != nil { fmt.Printf("Failed to echo data: %v\n", err) @@ -2674,7 +2665,7 @@ func (s *TracerSuite) TestTLSClassification() { SessionTicketsDisabled: true, ClientSessionCache: nil, } - conn, err := net.Dial("tcp", addr) + conn, err := tracertestutil.DialTCP("tcp", addr) require.NoError(t, err) defer conn.Close() @@ -2705,7 +2696,8 @@ func (s *TracerSuite) TestTLSClassification() { return } go func(c net.Conn) { - defer c.Close() + tracertestutil.SetTestDeadline(c) + defer tracertestutil.GracefulCloseTCP(c) buf := make([]byte, 1024) _, _ = c.Read(buf) // Do nothing with the data @@ -2722,9 +2714,10 @@ func (s *TracerSuite) TestTLSClassification() { port := uint16(portInt) // Client connects to the server - conn, err := net.Dial("tcp", addr) + conn, err := tracertestutil.DialTCP("tcp", addr) require.NoError(t, err) - defer conn.Close() + defer tracertestutil.GracefulCloseTCP(conn) + tracertestutil.SetTestDeadline(conn) // Send invalid TLS handshake data _, err = conn.Write([]byte("invalid TLS data")) diff --git a/pkg/network/tracer/tracer_test.go b/pkg/network/tracer/tracer_test.go index 3b7f8529c0b9c2..c641738d46aeb8 100644 --- a/pkg/network/tracer/tracer_test.go +++ b/pkg/network/tracer/tracer_test.go @@ -176,15 +176,15 @@ func (s *TracerSuite) TestTCPSendAndReceive() { break } } - c.Close() + testutil.GracefulCloseTCP(c) }) t.Cleanup(server.Shutdown) err := server.Run() require.NoError(t, err) - c, err := net.DialTimeout("tcp", server.Address(), 50*time.Millisecond) + c, err := server.Dial() require.NoError(t, err) - defer c.Close() + defer testutil.GracefulCloseTCP(c) // Connect to server 10 times wg := new(errgroup.Group) @@ -241,7 +241,7 @@ func (s *TracerSuite) TestTCPShortLived() { require.NoError(t, server.Run()) // Connect to server - c, err := net.DialTimeout("tcp", server.Address(), 50*time.Millisecond) + c, err := server.Dial() require.NoError(t, err) // Write clientMessageSize to server, and read response @@ -305,12 +305,12 @@ func (s *TracerSuite) TestTCPOverIPv6() { r := bufio.NewReader(c) r.ReadBytes(byte('\n')) c.Write(genPayload(serverMessageSize)) - c.Close() + testutil.GracefulCloseTCP(c) } }() // Connect to server - c, err := net.DialTimeout("tcp6", ln.Addr().String(), 50*time.Millisecond) + c, err := testutil.DialTCP("tcp6", ln.Addr().String()) require.NoError(t, err) // Write clientMessageSize to server, and read response @@ -354,17 +354,18 @@ func (s *TracerSuite) TestTCPCollectionDisabled() { r := bufio.NewReader(c) r.ReadBytes(byte('\n')) c.Write(genPayload(serverMessageSize)) - c.Close() + testutil.GracefulCloseTCP(c) }) t.Cleanup(server.Shutdown) require.NoError(t, server.Run()) // Connect to server - c, err := net.DialTimeout("tcp", server.Address(), 50*time.Millisecond) + c, err := server.Dial() if err != nil { t.Fatal(err) } + defer testutil.GracefulCloseTCP(c) // Write clientMessageSize to server, and read response if _, err = c.Write(genPayload(clientMessageSize)); err != nil { @@ -398,7 +399,7 @@ func (s *TracerSuite) TestTCPConnsReported() { require.NoError(t, server.Run()) // Connect to server - c, err := net.DialTimeout("tcp", server.Address(), 50*time.Millisecond) + c, err := server.Dial() require.NoError(t, err) <-processedChan c.Close() @@ -833,7 +834,7 @@ func benchEchoTCP(size int) func(b *testing.B) { for { buf, err := r.ReadBytes(byte('\n')) if err == io.EOF { - c.Close() + testutil.GracefulCloseTCP(c) return } c.Write(buf) @@ -845,11 +846,11 @@ func benchEchoTCP(size int) func(b *testing.B) { b.Cleanup(server.Shutdown) require.NoError(b, server.Run()) - c, err := net.DialTimeout("tcp", server.Address(), 50*time.Millisecond) + c, err := server.Dial() if err != nil { b.Fatal(err) } - defer c.Close() + defer testutil.GracefulCloseTCP(c) r := bufio.NewReader(c) b.ResetTimer() @@ -872,7 +873,7 @@ func benchSendTCP(size int) func(b *testing.B) { for { // Drop all payloads received _, err := r.Discard(r.Buffered() + 1) if err == io.EOF { - c.Close() + testutil.GracefulCloseTCP(c) return } } @@ -883,11 +884,11 @@ func benchSendTCP(size int) func(b *testing.B) { b.Cleanup(server.Shutdown) require.NoError(b, server.Run()) - c, err := net.DialTimeout("tcp", server.Address(), 50*time.Millisecond) + c, err := server.Dial() if err != nil { b.Fatal(err) } - defer c.Close() + defer testutil.GracefulCloseTCP(c) b.ResetTimer() for i := 0; i < b.N; i++ { // Send-heavy workload @@ -1146,7 +1147,7 @@ func (s *TracerSuite) TestTCPEstablishedPreExistingConn() { t.Cleanup(server.Shutdown) require.NoError(t, server.Run()) - c, err := net.DialTimeout("tcp", server.Address(), 50*time.Millisecond) + c, err := server.Dial() require.NoError(t, err) laddr, raddr := c.LocalAddr(), c.RemoteAddr() t.Logf("laddr=%s raddr=%s", laddr, raddr) @@ -1466,8 +1467,7 @@ func BenchmarkGetActiveConnections(b *testing.B) { cfg := testConfig() tr := setupTracer(b, cfg) server := testutil.NewTCPServer(func(c net.Conn) { - io.Copy(io.Discard, c) - c.Close() + testutil.GracefulCloseTCP(c) }) b.Cleanup(server.Shutdown) require.NoError(b, server.Run()) @@ -1476,7 +1476,7 @@ func BenchmarkGetActiveConnections(b *testing.B) { b.ResetTimer() for range b.N { - c, err := net.DialTimeout("tcp", server.Address(), 50*time.Millisecond) + c, err := server.Dial() require.NoError(b, err) laddr, raddr := c.LocalAddr(), c.RemoteAddr() c.Write([]byte("hello")) @@ -1486,7 +1486,7 @@ func BenchmarkGetActiveConnections(b *testing.B) { require.True(b, ok) assert.Equal(b, uint32(1), conn.Last.TCPEstablished) assert.Equal(b, uint32(0), conn.Last.TCPClosed) - c.Close() + testutil.GracefulCloseTCP(c) // Wait for the connection to be sent from the perf buffer require.Eventually(b, func() bool {