Skip to content

Commit

Permalink
[networks] Add timeouts/graceful shutdown to test suite (#33734)
Browse files Browse the repository at this point in the history
  • Loading branch information
pimlu authored Feb 7, 2025
1 parent df47013 commit d1db478
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 58 deletions.
48 changes: 47 additions & 1 deletion pkg/network/tracer/testutil/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,13 @@
// Package testutil has utilities for testing the network tracer
package testutil

import "net"
import (
"errors"
"fmt"
"io"
"net"
"time"
)

// TCPServer is a simple TCP server for use in tests
type TCPServer struct {
Expand Down Expand Up @@ -61,17 +67,57 @@ func (t *TCPServer) Run() error {
if err != nil {
return
}
err = SetTestDeadline(conn)
if err != nil {
return
}
go t.onMessage(conn)
}
}()

return nil
}

// Dial creates a TCP connection to the server, and sets reasonable timeouts
func (t *TCPServer) Dial() (net.Conn, error) {
return DialTCP("tcp", t.Address())
}

// DialTCP creates a connection to the specified address, and sets reasonable timeouts for TCP
func DialTCP(network, address string) (net.Conn, error) {
conn, err := net.DialTimeout(network, address, time.Second)
if err != nil {
return nil, fmt.Errorf("failed to dial %s: %w", address, err)
}
err = SetTestDeadline(conn)
if err != nil {
return nil, err
}
return conn, nil
}

// Shutdown stops the TCP server
func (t *TCPServer) Shutdown() {
if t.ln != nil {
_ = t.ln.Close()
t.ln = nil
}
}

// SetTestDeadline prevents connection reads/writes from blocking the test indefinitely
func SetTestDeadline(conn net.Conn) error {
// any test in the tracer suite should conclude in less than a minute (normally a couple seconds)
return conn.SetDeadline(time.Now().Add(time.Minute))
}

// GracefulCloseTCP closes a connection after making sure all data has been sent/read
// It first shuts down the write end, then reads until EOF, then closes the connection
// https://blog.netherlabs.nl/articles/2009/01/18/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable
func GracefulCloseTCP(conn net.Conn) error {
tcpConn := conn.(*net.TCPConn)

shutdownErr := tcpConn.CloseWrite()
_, readErr := io.Copy(io.Discard, tcpConn)
closeErr := tcpConn.Close()
return errors.Join(shutdownErr, readErr, closeErr)
}
67 changes: 30 additions & 37 deletions pkg/network/tracer/tracer_linux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import (
"regexp"
"strconv"
"strings"
"sync"
"sync/atomic"
"syscall"
"testing"
Expand Down Expand Up @@ -349,7 +348,7 @@ func (s *TracerSuite) TestTCPMiscount() {
t.Cleanup(server.Shutdown)
require.NoError(t, server.Run())

c, err := net.DialTimeout("tcp", server.Address(), 50*time.Millisecond)
c, err := server.Dial()
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -467,7 +466,7 @@ func (s *TracerSuite) TestConntrackExpiration() {
_, port, err := net.SplitHostPort(server.Address())
require.NoError(t, err, "could not split server address %s", server.Address())

c, err := net.Dial("tcp", "2.2.2.2:"+port)
c, err := tracertestutil.DialTCP("tcp", "2.2.2.2:"+port)
require.NoError(t, err)
t.Cleanup(func() {
c.Close()
Expand Down Expand Up @@ -520,7 +519,6 @@ func (s *TracerSuite) TestConntrackDelays() {
skipOnEbpflessNotSupported(t, cfg)

netlinktestutil.SetupDNAT(t)
wg := sync.WaitGroup{}

tr := setupTracer(t, cfg)
// This will ensure that the first lookup for every connection fails, while the following ones succeed
Expand All @@ -529,9 +527,7 @@ func (s *TracerSuite) TestConntrackDelays() {
// Letting the OS pick an open port is necessary to avoid flakiness in the test. Running the the test multiple
// times can fail if binding to the same port since Conntrack might not emit NEW events for the same tuple
server := tracertestutil.NewTCPServerOnAddress(fmt.Sprintf("1.1.1.1:%d", 0), func(c net.Conn) {
wg.Add(1)
defer wg.Done()
defer c.Close()
defer tracertestutil.GracefulCloseTCP(c)

r := bufio.NewReader(c)
r.ReadBytes(byte('\n'))
Expand All @@ -541,9 +537,9 @@ func (s *TracerSuite) TestConntrackDelays() {

_, port, err := net.SplitHostPort(server.Address())
require.NoError(t, err)
c, err := net.Dial("tcp", fmt.Sprintf("2.2.2.2:%s", port))
c, err := tracertestutil.DialTCP("tcp", fmt.Sprintf("2.2.2.2:%s", port))
require.NoError(t, err)
defer c.Close()
defer tracertestutil.GracefulCloseTCP(c)
_, err = c.Write([]byte("ping"))
require.NoError(t, err)

Expand All @@ -557,7 +553,6 @@ func (s *TracerSuite) TestConntrackDelays() {
// write newline so server connections will exit
_, err = c.Write([]byte("\n"))
require.NoError(t, err)
wg.Wait()
}

func (s *TracerSuite) TestTranslationBindingRegression() {
Expand All @@ -567,15 +562,12 @@ func (s *TracerSuite) TestTranslationBindingRegression() {
skipOnEbpflessNotSupported(t, cfg)

netlinktestutil.SetupDNAT(t)
wg := sync.WaitGroup{}

tr := setupTracer(t, cfg)

// Setup TCP server
server := tracertestutil.NewTCPServerOnAddress(fmt.Sprintf("1.1.1.1:%d", 0), func(c net.Conn) {
wg.Add(1)
defer wg.Done()
defer c.Close()
defer tracertestutil.GracefulCloseTCP(c)

r := bufio.NewReader(c)
r.ReadBytes(byte('\n'))
Expand All @@ -586,9 +578,9 @@ func (s *TracerSuite) TestTranslationBindingRegression() {
// Send data to 2.2.2.2 (which should be translated to 1.1.1.1)
_, port, err := net.SplitHostPort(server.Address())
require.NoError(t, err)
c, err := net.Dial("tcp", fmt.Sprintf("2.2.2.2:%s", port))
c, err := tracertestutil.DialTCP("tcp", fmt.Sprintf("2.2.2.2:%s", port))
require.NoError(t, err)
defer c.Close()
defer tracertestutil.GracefulCloseTCP(c)
_, err = c.Write([]byte("ping"))
require.NoError(t, err)

Expand Down Expand Up @@ -616,7 +608,6 @@ func (s *TracerSuite) TestTranslationBindingRegression() {
// write newline so server connections will exit
_, err = c.Write([]byte("\n"))
require.NoError(t, err)
wg.Wait()
}

func (s *TracerSuite) TestUnconnectedUDPSendIPv6() {
Expand Down Expand Up @@ -910,7 +901,7 @@ func (s *TracerSuite) TestGatewayLookupCrossNamespace() {

var conn *network.ConnectionStats
t.Run("client in root namespace", func(t *testing.T) {
c, err := net.DialTimeout("tcp", server.Address(), 2*time.Second)
c, err := server.Dial()
require.NoError(t, err)

// write some data
Expand Down Expand Up @@ -940,7 +931,7 @@ func (s *TracerSuite) TestGatewayLookupCrossNamespace() {
var c net.Conn
err = kernel.WithNS(test2Ns, func() error {
var err error
c, err = net.DialTimeout("tcp", server.Address(), 2*time.Second)
c, err = server.Dial()
return err
})
require.NoError(t, err)
Expand Down Expand Up @@ -1116,7 +1107,7 @@ func (s *TracerSuite) TestDNATIntraHostIntegration() {
require.NoError(t, err)

var conn net.Conn
conn, err = net.Dial("tcp", "2.2.2.2:"+port)
conn, err = tracertestutil.DialTCP("tcp", "2.2.2.2:"+port)
require.NoError(t, err, "error connecting to client")
t.Cleanup(func() {
conn.Close()
Expand Down Expand Up @@ -1719,7 +1710,7 @@ func (s *TracerSuite) TestSendfileRegression() {
require.NoError(t, server.Run())

// Connect to TCP server
c, err := net.DialTimeout("tcp", server.Address(), time.Second)
c, err := tracertestutil.DialTCP("tcp", server.Address())
require.NoError(t, err)

testSendfileServer(t, c.(*net.TCPConn), network.TCP, family, func() int64 { return rcvd.Load() })
Expand Down Expand Up @@ -1790,7 +1781,7 @@ func (s *TracerSuite) TestSendfileError() {
require.NoError(t, server.Run())
t.Cleanup(server.Shutdown)

c, err := net.DialTimeout("tcp", server.Address(), time.Second)
c, err := server.Dial()
require.NoError(t, err)

// Send file contents via SENDFILE(2)
Expand Down Expand Up @@ -1949,21 +1940,21 @@ func (s *TracerSuite) TestBlockingReadCounts() {
tr := setupTracer(t, testConfig())
ch := make(chan struct{})
server := tracertestutil.NewTCPServer(func(c net.Conn) {
defer tracertestutil.GracefulCloseTCP(c)
_, err := c.Write([]byte("foo"))
require.NoError(t, err, "error writing to client")
time.Sleep(time.Second)
_, err = c.Write([]byte("foo"))
require.NoError(t, err, "error writing to client")
<-ch
})

require.NoError(t, server.Run())
t.Cleanup(server.Shutdown)
t.Cleanup(func() { close(ch) })

c, err := net.DialTimeout("tcp", server.Address(), 5*time.Second)
c, err := server.Dial()
require.NoError(t, err)
defer c.Close()
defer tracertestutil.GracefulCloseTCP(c)

rawConn, err := c.(syscall.Conn).SyscallConn()
require.NoError(t, err, "error getting raw conn")
Expand Down Expand Up @@ -2019,12 +2010,12 @@ func (s *TracerSuite) TestPreexistingConnectionDirection() {
}
_, _ = c.Write(genPayload(serverMessageSize))
}
c.Close()
tracertestutil.GracefulCloseTCP(c)
})
t.Cleanup(server.Shutdown)
require.NoError(t, server.Run())

c, err := net.DialTimeout("tcp", server.Address(), 50*time.Millisecond)
c, err := server.Dial()
require.NoError(t, err)
t.Cleanup(func() { c.Close() })

Expand All @@ -2041,7 +2032,7 @@ func (s *TracerSuite) TestPreexistingConnectionDirection() {
_, err = r.ReadBytes(byte('\n'))
require.NoError(t, err)

c.Close()
tracertestutil.GracefulCloseTCP(c)

var incoming, outgoing *network.ConnectionStats
require.EventuallyWithT(t, func(collect *assert.CollectT) {
Expand Down Expand Up @@ -2127,7 +2118,7 @@ func testPreexistingEmptyIncomingConnectionDirection(t *testing.T, config *confi
require.NoError(t, server.Run())
t.Cleanup(server.Shutdown)

c, err := net.DialTimeout("tcp", server.Address(), 5*time.Second)
c, err := server.Dial()
require.NoError(t, err)

// Enable BPF-based system probe
Expand Down Expand Up @@ -2432,9 +2423,8 @@ func (s *TracerSuite) TestConnectionDuration() {
require.NoError(t, srv.Run(), "error running server")
t.Cleanup(srv.Shutdown)

srvAddr := srv.Address()
c, err := net.DialTimeout("tcp", srvAddr, time.Second)
require.NoError(t, err, "could not connect to server at %s", srvAddr)
c, err := srv.Dial()
require.NoError(t, err)

ticker := time.NewTicker(100 * time.Millisecond)
t.Cleanup(ticker.Stop)
Expand Down Expand Up @@ -2583,7 +2573,7 @@ func (s *TracerSuite) TestTCPFailureConnectionResetWithDNAT() {

// Attempt to connect to the DNAT address (2.2.2.2), which should be redirected to the server at 1.1.1.1
serverAddr := "2.2.2.2:80"
c, err := net.Dial("tcp", serverAddr)
c, err := tracertestutil.DialTCP("tcp", serverAddr)
require.NoError(t, err, "could not connect to server: ", err)

// Write to the server and expect a reset
Expand Down Expand Up @@ -2649,6 +2639,7 @@ func (s *TracerSuite) TestTLSClassification() {
postTracerSetup: func(t *testing.T) (uint16, uint16) {
srv := usmtestutil.NewTLSServerWithSpecificVersion("localhost:0", func(conn net.Conn) {
defer conn.Close()
tracertestutil.SetTestDeadline(conn)
_, err := io.Copy(conn, conn)
if err != nil {
fmt.Printf("Failed to echo data: %v\n", err)
Expand All @@ -2674,7 +2665,7 @@ func (s *TracerSuite) TestTLSClassification() {
SessionTicketsDisabled: true,
ClientSessionCache: nil,
}
conn, err := net.Dial("tcp", addr)
conn, err := tracertestutil.DialTCP("tcp", addr)
require.NoError(t, err)
defer conn.Close()

Expand Down Expand Up @@ -2705,7 +2696,8 @@ func (s *TracerSuite) TestTLSClassification() {
return
}
go func(c net.Conn) {
defer c.Close()
tracertestutil.SetTestDeadline(c)
defer tracertestutil.GracefulCloseTCP(c)
buf := make([]byte, 1024)
_, _ = c.Read(buf)
// Do nothing with the data
Expand All @@ -2722,9 +2714,10 @@ func (s *TracerSuite) TestTLSClassification() {
port := uint16(portInt)

// Client connects to the server
conn, err := net.Dial("tcp", addr)
conn, err := tracertestutil.DialTCP("tcp", addr)
require.NoError(t, err)
defer conn.Close()
defer tracertestutil.GracefulCloseTCP(conn)
tracertestutil.SetTestDeadline(conn)

// Send invalid TLS handshake data
_, err = conn.Write([]byte("invalid TLS data"))
Expand Down
Loading

0 comments on commit d1db478

Please sign in to comment.