diff --git a/cmd/main.go b/cmd/main.go index 640fb33..6571fc6 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -1,10 +1,12 @@ // SPDX-FileCopyrightText: 2025 The Pion community // SPDX-License-Identifier: MIT +// Package main implements the entry point for the bandwidth estimation test tool. package main import ( "context" + "errors" "flag" "io" "log" @@ -35,11 +37,64 @@ func realMain() error { return send(*addr, *rtpLogFile, *rtcpLogFile, *ccLogFile) } - log.Fatalf("invalid mode: %s\n", *mode) + log.Fatalf("invalid mode: %s", *mode) + return nil } func receive(addr, rtpLogFile, rtcpLogFile string) error { + rcv, err := newReceiver(rtpLogFile, rtcpLogFile) + if err != nil { + return err + } + defer func() { + if err = rcv.Close(); err != nil { + log.Printf("failed to close receiver: %v", err) + } + }() + + err = rcv.receiver.SetupPeerConnection() + if err != nil { + return err + } + http.Handle("/sdp", rcv.receiver.SDPHandler()) + + //nolint:gosec + return http.ListenAndServe(addr, nil) +} + +type recv struct { + receiver *receiver.Receiver + rtpLogger io.WriteCloser + rtcpLogger io.WriteCloser +} + +func (c recv) Close() error { + var errs []error + + err := c.receiver.Close() + if err != nil { + errs = append(errs, err) + } + + if c.rtpLogger != nil { + err = c.rtpLogger.Close() + if err != nil { + errs = append(errs, err) + } + } + + if c.rtcpLogger != nil { + err = c.rtcpLogger.Close() + if err != nil { + errs = append(errs, err) + } + } + + return errors.Join(errs...) +} + +func newReceiver(rtpLogFile, rtcpLogFile string) (recv, error) { options := []receiver.Option{ receiver.PacketLogWriter(os.Stdout, os.Stdout), receiver.DefaultInterceptors(), @@ -50,101 +105,143 @@ func receive(addr, rtpLogFile, rtcpLogFile string) error { if rtpLogFile != "" { rtpLogger, err = logging.GetLogFile(rtpLogFile) if err != nil { - return err + return recv{}, err } - defer rtpLogger.Close() } if rtcpLogFile != "" { rtcpLogger, err = logging.GetLogFile(rtcpLogFile) if err != nil { - return err + return recv{}, err } - defer rtcpLogger.Close() } if rtpLogger != nil || rtcpLogger != nil { options = append(options, receiver.PacketLogWriter(rtpLogger, rtcpLogger)) } r, err := receiver.NewReceiver(options...) + if err != nil { + return recv{}, err + } + + return recv{ + receiver: r, + rtpLogger: rtpLogger, + rtcpLogger: rtcpLogger, + }, nil +} + +func send(addr, rtpLogFile, rtcpLogFile, ccLogFile string) error { + snd, err := newSender(rtpLogFile, rtcpLogFile, ccLogFile) if err != nil { return err } - err = r.SetupPeerConnection() + defer func() { + if err = snd.Close(); err != nil { + log.Printf("failed to close sender: %v", err) + } + }() + + err = snd.sender.SetupPeerConnection() if err != nil { return err } - http.Handle("/sdp", r.SDPHandler()) - log.Fatal(http.ListenAndServe(addr, nil)) - return nil + err = snd.sender.SignalHTTP(addr, "sdp") + if err != nil { + return err + } + + ctx, cancel := context.WithCancel(context.Background()) + + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) + defer func() { + signal.Stop(sigs) + cancel() + }() + go func() { + select { + case <-sigs: + cancel() + log.Println("cancel called") + case <-ctx.Done(): + } + }() + + return snd.sender.Start(ctx) } -func send(addr, rtpLogFile, rtcpLogFile, ccLogFile string) error { +type sndr struct { + sender *sender.Sender + rtpLogger io.WriteCloser + rtcpLogger io.WriteCloser + ccLogger io.WriteCloser +} + +func (s sndr) Close() error { + var errs []error + + err := s.rtpLogger.Close() + if err != nil { + errs = append(errs, err) + } + + err = s.rtcpLogger.Close() + if err != nil { + errs = append(errs, err) + } + + err = s.ccLogger.Close() + if err != nil { + errs = append(errs, err) + } + + return errors.Join(errs...) +} + +func newSender(rtpLogFile, rtcpLogFile, ccLogFile string) (sndr, error) { options := []sender.Option{ sender.DefaultInterceptors(), sender.GCC(initialBitrate), } var rtpLogger io.WriteCloser var rtcpLogger io.WriteCloser + var ccLogger io.WriteCloser var err error if rtpLogFile != "" { rtpLogger, err = logging.GetLogFile(rtpLogFile) if err != nil { - return err + return sndr{}, err } - defer rtpLogger.Close() } if rtcpLogFile != "" { rtcpLogger, err = logging.GetLogFile(rtcpLogFile) if err != nil { - return err + return sndr{}, err } - defer rtcpLogger.Close() } if ccLogFile != "" { - var ccLogger io.WriteCloser ccLogger, err = logging.GetLogFile(ccLogFile) if err != nil { - return err + return sndr{}, err } - defer ccLogger.Close() options = append(options, sender.CCLogWriter(ccLogger)) } if rtpLogger != nil || rtcpLogger != nil { options = append(options, sender.PacketLogWriter(rtpLogger, rtcpLogger)) } - s, err := sender.NewSender( + snd, err := sender.NewSender( sender.NewStatisticalEncoderSource(), options..., ) if err != nil { - return err + return sndr{}, err } - err = s.SetupPeerConnection() - if err != nil { - return err - } - err = s.SignalHTTP(addr, "sdp") - if err != nil { - return err - } - - ctx, cancel := context.WithCancel(context.Background()) - - sigs := make(chan os.Signal, 1) - signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) - defer func() { - signal.Stop(sigs) - cancel() - }() - go func() { - select { - case <-sigs: - cancel() - log.Println("cancel called") - case <-ctx.Done(): - } - }() - return s.Start(ctx) + return sndr{ + sender: snd, + rtpLogger: rtpLogger, + rtcpLogger: rtcpLogger, + ccLogger: ccLogger, + }, nil } func main() { diff --git a/go.mod b/go.mod index ce1b713..7a87b84 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/pion/rtp v1.8.13 github.com/pion/transport/v3 v3.0.7 github.com/pion/webrtc/v4 v4.0.14 + github.com/stretchr/testify v1.10.0 golang.org/x/sync v0.11.0 ) @@ -24,12 +25,15 @@ require ( ) require ( + github.com/davecgh/go-spew v1.1.1 // indirect github.com/google/uuid v1.6.0 // indirect github.com/pion/datachannel v1.5.10 // indirect github.com/pion/randutil v0.1.0 // indirect github.com/pion/sctp v1.8.37 // indirect github.com/pion/sdp/v3 v3.0.11 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect golang.org/x/crypto v0.33.0 // indirect golang.org/x/net v0.35.0 // indirect golang.org/x/sys v0.30.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 1da3027..24e59c9 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,11 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/pion/datachannel v1.5.10 h1:ly0Q26K1i6ZkGf42W7D4hQYR90pZwzFOjTq5AuCKk4o= github.com/pion/datachannel v1.5.10/go.mod h1:p/jJfC9arb29W7WrxyKbepTU20CFgyx5oLo8Rs4Py/M= github.com/pion/dtls/v3 v3.0.4 h1:44CZekewMzfrn9pmGrj5BNnTMDCFwr+6sLH+cCuLM7U= @@ -36,7 +39,9 @@ github.com/pion/turn/v4 v4.0.0/go.mod h1:MuPDkm15nYSklKpN8vWJ9W2M0PlyQZqYt1McGux github.com/pion/webrtc/v4 v4.0.14 h1:nyds/sFRR+HvmWoBa6wrL46sSfpArE0qR883MBW96lg= github.com/pion/webrtc/v4 v4.0.14/go.mod h1:R3+qTnQTS03UzwDarYecgioNf7DYgTsldxnCXB821Kk= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/wlynxg/anet v0.0.5 h1:J3VJGi1gvo0JwZ/P1/Yc/8p63SoW98B5dHkYDmpgvvU= github.com/wlynxg/anet v0.0.5/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA= golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus= @@ -47,4 +52,7 @@ golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w= golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/sequencenumber/unwrapper.go b/internal/sequencenumber/unwrapper.go new file mode 100644 index 0000000..48500b3 --- /dev/null +++ b/internal/sequencenumber/unwrapper.go @@ -0,0 +1,48 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +// Package sequencenumber provides a sequence number unwrapper +package sequencenumber + +const ( + maxSequenceNumberPlusOne = int64(65536) + breakpoint = 32768 // half of max uint16 +) + +// Unwrapper stores an unwrapped sequence number. +type Unwrapper struct { + init bool + lastUnwrapped int64 +} + +func isNewer(value, previous uint16) bool { + if value-previous == breakpoint { + return value > previous + } + + return value != previous && (value-previous) < breakpoint +} + +// Unwrap unwraps the next sequencenumber. +func (u *Unwrapper) Unwrap(i uint16) int64 { + if !u.init { + u.init = true + u.lastUnwrapped = int64(i) + + return u.lastUnwrapped + } + + lastWrapped := uint16(u.lastUnwrapped) //nolint:gosec // G115 + delta := int64(i - lastWrapped) + if isNewer(i, lastWrapped) { + if delta < 0 { + delta += maxSequenceNumberPlusOne + } + } else if delta > 0 && u.lastUnwrapped+delta-maxSequenceNumberPlusOne >= 0 { + delta -= maxSequenceNumberPlusOne + } + + u.lastUnwrapped += delta + + return u.lastUnwrapped +} diff --git a/internal/sequencenumber/unwrapper_test.go b/internal/sequencenumber/unwrapper_test.go new file mode 100644 index 0000000..47c4154 --- /dev/null +++ b/internal/sequencenumber/unwrapper_test.go @@ -0,0 +1,115 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package sequencenumber + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIsNewer(t *testing.T) { + cases := []struct { + a, b uint16 + expected bool + }{ + { + a: 1, + b: 0, + expected: true, + }, + { + a: 65534, + b: 65535, + expected: false, + }, + { + a: 65535, + b: 65535, + expected: false, + }, + { + a: 0, + b: 65535, + expected: true, + }, + { + a: 0, + b: 32767, + expected: false, + }, + { + a: 32770, + b: 2, + expected: true, + }, + { + a: 3, + b: 32770, + expected: false, + }, + } + for i, tc := range cases { + t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { + assert.Equalf(t, tc.expected, isNewer(tc.a, tc.b), "expected isNewer(%v, %v) to be %v", tc.a, tc.b, tc.expected) + }) + } +} + +func TestUnwrapper(t *testing.T) { + cases := []struct { + input []uint16 + expected []int64 + }{ + { + input: []uint16{}, + expected: []int64{}, + }, + { + input: []uint16{0, 1, 2, 3, 4}, + expected: []int64{0, 1, 2, 3, 4}, + }, + { + input: []uint16{65534, 65535, 0, 1, 2}, + expected: []int64{65534, 65535, 65536, 65537, 65538}, + }, + { + input: []uint16{32769, 0}, + expected: []int64{32769, 65536}, + }, + { + input: []uint16{32767, 0}, + expected: []int64{32767, 0}, + }, + { + input: []uint16{0, 1, 4, 3, 2, 5}, + expected: []int64{0, 1, 4, 3, 2, 5}, + }, + { + input: []uint16{65534, 0, 1, 65535, 4, 3, 2, 5}, + expected: []int64{65534, 65536, 65537, 65535, 65540, 65539, 65538, 65541}, + }, + { + input: []uint16{ + 0, 32767, 32768, 32769, 32770, + 1, 2, 32765, 32770, 65535, + }, + expected: []int64{ + 0, 32767, 32768, 32769, 32770, + 65537, 65538, 98301, 98306, 131071, + }, + }, + } + for i, tc := range cases { + t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { + u := &Unwrapper{} + result := []int64{} + for _, i := range tc.input { + result = append(result, u.Unwrap(i)) + } + assert.Equal(t, tc.expected, result) + }) + } +} diff --git a/logging/file.go b/logging/file.go index 8a80951..f038e27 100644 --- a/logging/file.go +++ b/logging/file.go @@ -1,6 +1,7 @@ // SPDX-FileCopyrightText: 2025 The Pion community // SPDX-License-Identifier: MIT +// Package logging provides utilities for logging in bandwidth estimation tests. package logging import ( @@ -9,6 +10,10 @@ import ( "os" ) +// GetLogFile returns an io.WriteCloser for the specified file path. +// If file is empty, it returns a no-op writer. +// If file is "stdout", it returns os.Stdout wrapped in a nopCloser. +// Otherwise, it creates and returns the specified file. func GetLogFile(file string) (io.WriteCloser, error) { if len(file) == 0 { return nopCloser{io.Discard}, nil @@ -16,6 +21,7 @@ func GetLogFile(file string) (io.WriteCloser, error) { if file == "stdout" { return nopCloser{os.Stdout}, nil } + //nolint:gosec fd, err := os.Create(file) if err != nil { return nil, err @@ -47,5 +53,6 @@ func (f *fileCloser) Close() error { if err := f.buf.Flush(); err != nil { return err } + return f.f.Close() } diff --git a/logging/format.go b/logging/format.go index fcef194..bc6a20d 100644 --- a/logging/format.go +++ b/logging/format.go @@ -1,70 +1,37 @@ // SPDX-FileCopyrightText: 2025 The Pion community // SPDX-License-Identifier: MIT +// Package logging provides utilities for logging in bandwidth estimation tests. package logging import ( "fmt" "time" + "github.com/pion/bwe-test/internal/sequencenumber" "github.com/pion/interceptor" "github.com/pion/rtcp" "github.com/pion/rtp" ) -const ( - maxSequenceNumberPlusOne = int64(65536) - breakpoint = 32768 // half of max uint16 -) - -type unwrapper struct { - init bool - lastUnwrapped int64 -} - -func isNewer(value, previous uint16) bool { - if value-previous == breakpoint { - return value > previous - } - return value != previous && (value-previous) < breakpoint -} - -func (u *unwrapper) unwrap(i uint16) int64 { - if !u.init { - u.init = true - u.lastUnwrapped = int64(i) - return u.lastUnwrapped - } - - lastWrapped := uint16(u.lastUnwrapped) - delta := int64(i - lastWrapped) - if isNewer(i, lastWrapped) { - if delta < 0 { - delta += maxSequenceNumberPlusOne - } - } else if delta > 0 && u.lastUnwrapped+delta-maxSequenceNumberPlusOne >= 0 { - delta -= maxSequenceNumberPlusOne - } - - u.lastUnwrapped += int64(delta) - return u.lastUnwrapped -} - +// RTPFormatter formats RTP packets for logging. type RTPFormatter struct { - seqnr unwrapper + seqnr sequencenumber.Unwrapper } +// RTPFormat formats an RTP packet as a string for logging. func (f *RTPFormatter) RTPFormat(pkt *rtp.Packet, _ interceptor.Attributes) string { var twcc rtp.TransportCCExtension - unwrappedSeqNr := f.seqnr.unwrap(pkt.SequenceNumber) + unwrappedSeqNr := f.seqnr.Unwrap(pkt.SequenceNumber) var twccNr uint16 if len(pkt.GetExtensionIDs()) > 0 { ext := pkt.GetExtension(pkt.GetExtensionIDs()[0]) if err := twcc.Unmarshal(ext); err != nil { - panic(err) + return fmt.Sprintf("Error unmarshaling TWCC extension: %v", err) } twccNr = twcc.TransportSequence } + return fmt.Sprintf("%v, %v, %v, %v, %v, %v, %v, %v, %v\n", time.Now().UnixMilli(), pkt.PayloadType, @@ -78,6 +45,7 @@ func (f *RTPFormatter) RTPFormat(pkt *rtp.Packet, _ interceptor.Attributes) stri ) } +// RTCPFormat formats RTCP packets as a string for logging. func RTCPFormat(pkts []rtcp.Packet, _ interceptor.Attributes) string { now := time.Now().UnixMilli() size := 0 @@ -86,8 +54,9 @@ func RTCPFormat(pkts []rtcp.Packet, _ interceptor.Attributes) string { case *rtcp.TransportLayerCC: size += int(feedback.Len()) case *rtcp.RawPacket: - size += int(len(*feedback)) + size += len(*feedback) } } + return fmt.Sprintf("%v, %v\n", now, size) } diff --git a/receiver/option.go b/receiver/option.go index 3753fc1..0eee8e0 100644 --- a/receiver/option.go +++ b/receiver/option.go @@ -1,24 +1,26 @@ // SPDX-FileCopyrightText: 2025 The Pion community // SPDX-License-Identifier: MIT +// Package receiver implements WebRTC receiver functionality for bandwidth estimation tests. package receiver import ( "io" "time" + "github.com/pion/bwe-test/logging" "github.com/pion/interceptor/pkg/packetdump" plogging "github.com/pion/logging" "github.com/pion/transport/v3/vnet" "github.com/pion/webrtc/v4" - - "github.com/pion/bwe-test/logging" ) +// Option is a function that configures a Receiver. type Option func(*Receiver) error +// PacketLogWriter returns an Option that configures RTP and RTCP packet logging. func PacketLogWriter(rtpWriter, rtcpWriter io.Writer) Option { - return func(r *Receiver) error { + return func(receiver *Receiver) error { formatter := logging.RTPFormatter{} rtpLogger, err := packetdump.NewReceiverInterceptor( packetdump.RTPFormatter(formatter.RTPFormat), @@ -34,31 +36,37 @@ func PacketLogWriter(rtpWriter, rtcpWriter io.Writer) Option { if err != nil { return err } - r.registry.Add(rtpLogger) - r.registry.Add(rtcpLogger) + receiver.registry.Add(rtpLogger) + receiver.registry.Add(rtcpLogger) + return nil } } +// DefaultInterceptors returns an Option that registers the default WebRTC interceptors. func DefaultInterceptors() Option { return func(r *Receiver) error { return webrtc.RegisterDefaultInterceptors(r.mediaEngine, r.registry) } } +// SetVnet returns an Option that configures the virtual network for testing. func SetVnet(v *vnet.Net, publicIPs []string) Option { return func(r *Receiver) error { r.settingEngine.SetNet(v) r.settingEngine.SetICETimeouts(time.Second, time.Second, 200*time.Millisecond) r.settingEngine.SetNAT1To1IPs(publicIPs, webrtc.ICECandidateTypeHost) + return nil } } +// SetLoggerFactory returns an Option that configures the logger factory. func SetLoggerFactory(loggerFactory plogging.LoggerFactory) Option { return func(s *Receiver) error { s.settingEngine.LoggerFactory = loggerFactory s.log = loggerFactory.NewLogger("receiver") + return nil } } diff --git a/receiver/receiver.go b/receiver/receiver.go index 28422f2..e7c5325 100644 --- a/receiver/receiver.go +++ b/receiver/receiver.go @@ -1,11 +1,13 @@ // SPDX-FileCopyrightText: 2025 The Pion community // SPDX-License-Identifier: MIT +// Package receiver implements WebRTC receiver functionality for bandwidth estimation tests. package receiver import ( "context" "encoding/json" + "errors" "io" "net/http" "time" @@ -16,6 +18,7 @@ import ( "github.com/pion/webrtc/v4" ) +// Receiver manages a WebRTC connection for receiving media. type Receiver struct { settingEngine *webrtc.SettingEngine mediaEngine *webrtc.MediaEngine @@ -27,29 +30,33 @@ type Receiver struct { log logging.LeveledLogger } +// NewReceiver creates a new WebRTC receiver with the given options. func NewReceiver(opts ...Option) (*Receiver, error) { - r := &Receiver{ + receiver := &Receiver{ settingEngine: &webrtc.SettingEngine{}, mediaEngine: &webrtc.MediaEngine{}, peerConnection: &webrtc.PeerConnection{}, registry: &interceptor.Registry{}, log: logging.NewDefaultLoggerFactory().NewLogger("receiver"), } - if err := r.mediaEngine.RegisterDefaultCodecs(); err != nil { + if err := receiver.mediaEngine.RegisterDefaultCodecs(); err != nil { return nil, err } for _, opt := range opts { - if err := opt(r); err != nil { + if err := opt(receiver); err != nil { return nil, err } } - return r, nil + + return receiver, nil } +// Close stops and cleans up the receiver. func (r *Receiver) Close() error { return r.peerConnection.Close() } +// SetupPeerConnection initializes the WebRTC peer connection. func (r *Receiver) SetupPeerConnection() error { peerConnection, err := webrtc.NewAPI( webrtc.WithSettingEngine(*r.settingEngine), @@ -63,25 +70,27 @@ func (r *Receiver) SetupPeerConnection() error { // Set the handler for ICE connection state // This will notify you when the peer has connected/disconnected peerConnection.OnICEConnectionStateChange(func(connectionState webrtc.ICEConnectionState) { - r.log.Infof("Receiver Connection State has changed %s \n", connectionState.String()) + r.log.Infof("Receiver Connection State has changed %s", connectionState.String()) }) // Set the handler for Peer connection state // This will notify you when the peer has connected/disconnected peerConnection.OnConnectionStateChange(func(s webrtc.PeerConnectionState) { - r.log.Infof("Receiver Peer Connection State has changed: %s\n", s.String()) + r.log.Infof("Receiver Peer Connection State has changed: %s", s.String()) }) peerConnection.OnICECandidate(func(i *webrtc.ICECandidate) { - r.log.Infof("Receiver candidate: %v\n", i) + r.log.Infof("Receiver candidate: %v", i) }) peerConnection.OnTrack(r.onTrack) r.peerConnection = peerConnection + return nil } +// AcceptOffer processes a WebRTC offer from the remote peer and creates an answer. func (r *Receiver) AcceptOffer(offer *webrtc.SessionDescription) (*webrtc.SessionDescription, error) { if err := r.peerConnection.SetRemoteDescription(*offer); err != nil { return nil, err @@ -120,7 +129,7 @@ func (r *Receiver) onTrack(trackRemote *webrtc.TrackRemote, rtpReceiver *webrtc. bits := float64(bytesReceived) * 8.0 rate := bits / delta.Seconds() mBitPerSecond := rate / float64(vnet.MBit) - r.log.Infof("throughput: %.2f Mb/s\n", mBitPerSecond) + r.log.Infof("throughput: %.2f Mb/s", mBitPerSecond) bytesReceived = 0 last = now case newBytesReceived := <-bytesReceivedChan: @@ -137,36 +146,46 @@ func (r *Receiver) onTrack(trackRemote *webrtc.TrackRemote, rtpReceiver *webrtc. } p, _, err := trackRemote.ReadRTP() - if err == io.EOF { + if errors.Is(err, io.EOF) { r.log.Infof("trackRemote.ReadRTP received EOF") + return } if err != nil { - r.log.Infof("trackRemote.ReadRTP returned error: %v\n", err) + r.log.Infof("trackRemote.ReadRTP returned error: %v", err) + continue } bytesReceivedChan <- p.MarshalSize() } } +// SDPHandler returns an HTTP handler for WebRTC signaling. func (r *Receiver) SDPHandler() http.HandlerFunc { - return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + return http.HandlerFunc(func(respWriter http.ResponseWriter, req *http.Request) { sdp := webrtc.SessionDescription{} if err := json.NewDecoder(req.Body).Decode(&sdp); err != nil { - panic(err) + r.log.Errorf("failed to decode SDP offer: %v", err) + respWriter.WriteHeader(http.StatusBadRequest) + + return } answer, err := r.AcceptOffer(&sdp) if err != nil { - w.WriteHeader(http.StatusBadRequest) + respWriter.WriteHeader(http.StatusBadRequest) + return } // Send our answer to the HTTP server listening in the other process payload, err := json.Marshal(answer) if err != nil { - panic(err) + r.log.Errorf("failed to marshal SDP answer: %v", err) + respWriter.WriteHeader(http.StatusInternalServerError) + + return } - w.Header().Set("Content-Type", "application/json") - if _, err := w.Write(payload); err != nil { + respWriter.Header().Set("Content-Type", "application/json") + if _, err := respWriter.Write(payload); err != nil { r.log.Errorf("failed to write signaling response: %v", err) } }) diff --git a/sender/abr.go b/sender/abr.go index cd8f821..58fcaeb 100644 --- a/sender/abr.go +++ b/sender/abr.go @@ -1,10 +1,12 @@ // SPDX-FileCopyrightText: 2025 The Pion community // SPDX-License-Identifier: MIT +// Package sender implements WebRTC sender functionality for bandwidth estimation tests. package sender import ( "context" + "errors" "sync" "github.com/pion/bwe-test/syncodec" @@ -12,7 +14,7 @@ import ( "github.com/pion/webrtc/v4/pkg/media" ) -// StatisticalEncoderSource is a source that fakes a media encoder using syncodec.StatisticalCodec +// StatisticalEncoderSource is a source that fakes a media encoder using syncodec.StatisticalCodec. type StatisticalEncoderSource struct { codec syncodec.Codec sampleWriter func(media.Sample) error @@ -23,12 +25,14 @@ type StatisticalEncoderSource struct { log logging.LeveledLogger } -// NewStatisticalEncoderSource returns a new StatisticalEncoderSource +var errUninitializedtatisticalEncoderSource = errors.New("write on uninitialized StatisticalEncoderSource.WriteSample") + +// NewStatisticalEncoderSource returns a new StatisticalEncoderSource. func NewStatisticalEncoderSource() *StatisticalEncoderSource { return &StatisticalEncoderSource{ codec: nil, sampleWriter: func(_ media.Sample) error { - panic("write on uninitialized StatisticalEncoderSource.WriteSample") + return errUninitializedtatisticalEncoderSource }, updateTargetBitrate: make(chan int), newFrame: make(chan syncodec.Frame), @@ -38,14 +42,17 @@ func NewStatisticalEncoderSource() *StatisticalEncoderSource { } } +// SetTargetBitrate sets the target bitrate for the encoder. func (s *StatisticalEncoderSource) SetTargetBitrate(rate int) { s.updateTargetBitrate <- rate } +// SetWriter sets the sample writer function. func (s *StatisticalEncoderSource) SetWriter(f func(sample media.Sample) error) { s.sampleWriter = f } +// Start begins the encoding process and runs until context is done. func (s *StatisticalEncoderSource) Start(ctx context.Context) error { s.wg.Add(1) defer s.wg.Done() @@ -58,7 +65,7 @@ func (s *StatisticalEncoderSource) Start(ctx context.Context) error { go s.codec.Start() defer func() { if err := s.codec.Close(); err != nil { - s.log.Infof("failed to close codec: %v", err) + s.log.Errorf("failed to close codec: %v", err) } }() @@ -78,12 +85,15 @@ func (s *StatisticalEncoderSource) Start(ctx context.Context) error { } } +// WriteFrame writes a frame to the encoder. func (s *StatisticalEncoderSource) WriteFrame(frame syncodec.Frame) { s.newFrame <- frame } +// Close stops the encoder and cleans up resources. func (s *StatisticalEncoderSource) Close() error { defer s.wg.Wait() close(s.done) + return nil } diff --git a/sender/option.go b/sender/option.go index 360041f..7e3f725 100644 --- a/sender/option.go +++ b/sender/option.go @@ -1,26 +1,28 @@ // SPDX-FileCopyrightText: 2025 The Pion community // SPDX-License-Identifier: MIT +// Package sender implements WebRTC sender functionality for bandwidth estimation tests. package sender import ( "io" "time" + "github.com/pion/bwe-test/logging" "github.com/pion/interceptor/pkg/cc" "github.com/pion/interceptor/pkg/gcc" "github.com/pion/interceptor/pkg/packetdump" plogging "github.com/pion/logging" "github.com/pion/transport/v3/vnet" "github.com/pion/webrtc/v4" - - "github.com/pion/bwe-test/logging" ) +// Option is a function that configures a Sender. type Option func(*Sender) error +// PacketLogWriter returns an Option that configures RTP and RTCP packet logging. func PacketLogWriter(rtpWriter, rtcpWriter io.Writer) Option { - return func(s *Sender) error { + return func(sndr *Sender) error { formatter := logging.RTPFormatter{} rtpLogger, err := packetdump.NewSenderInterceptor( packetdump.RTPFormatter(formatter.RTPFormat), @@ -36,27 +38,32 @@ func PacketLogWriter(rtpWriter, rtcpWriter io.Writer) Option { if err != nil { return err } - s.registry.Add(rtpLogger) - s.registry.Add(rtcpLogger) + sndr.registry.Add(rtpLogger) + sndr.registry.Add(rtcpLogger) + return nil } } +// DefaultInterceptors returns an Option that registers the default WebRTC interceptors. func DefaultInterceptors() Option { return func(s *Sender) error { return webrtc.RegisterDefaultInterceptors(s.mediaEngine, s.registry) } } +// CCLogWriter returns an Option that configures congestion control logging. func CCLogWriter(w io.Writer) Option { return func(s *Sender) error { s.ccLogWriter = w + return nil } } +// GCC returns an Option that configures Google Congestion Control with the specified initial bitrate. func GCC(initialBitrate int) Option { - return func(s *Sender) error { + return func(sndr *Sender) error { controller, err := cc.NewInterceptor(func() (cc.BandwidthEstimator, error) { return gcc.NewSendSideBWE(gcc.SendSideBWEInitialBitrate(initialBitrate)) }) @@ -65,37 +72,44 @@ func GCC(initialBitrate int) Option { } controller.OnNewPeerConnection(func(_ string, estimator cc.BandwidthEstimator) { go func() { - s.estimatorChan <- estimator + sndr.estimatorChan <- estimator }() }) - s.registry.Add(controller) - if err = webrtc.ConfigureTWCCHeaderExtensionSender(s.mediaEngine, s.registry); err != nil { + sndr.registry.Add(controller) + if err = webrtc.ConfigureTWCCHeaderExtensionSender(sndr.mediaEngine, sndr.registry); err != nil { return err } + return nil } } +// SetVnet returns an Option that configures the virtual network for testing. func SetVnet(v *vnet.Net, publicIPs []string) Option { return func(s *Sender) error { s.settingEngine.SetNet(v) s.settingEngine.SetICETimeouts(time.Second, time.Second, 200*time.Millisecond) s.settingEngine.SetNAT1To1IPs(publicIPs, webrtc.ICECandidateTypeHost) + return nil } } +// SetMediaSource returns an Option that sets the media source for the sender. func SetMediaSource(source MediaSource) Option { return func(s *Sender) error { s.source = source + return nil } } +// SetLoggerFactory returns an Option that configures the logger factory. func SetLoggerFactory(loggerFactory plogging.LoggerFactory) Option { return func(s *Sender) error { s.settingEngine.LoggerFactory = loggerFactory s.log = loggerFactory.NewLogger("sender") + return nil } } diff --git a/sender/sender.go b/sender/sender.go index f335389..168e94d 100644 --- a/sender/sender.go +++ b/sender/sender.go @@ -1,6 +1,7 @@ // SPDX-FileCopyrightText: 2025 The Pion community // SPDX-License-Identifier: MIT +// Package sender implements WebRTC sender functionality for bandwidth estimation tests. package sender import ( @@ -20,12 +21,14 @@ import ( "golang.org/x/sync/errgroup" ) +// MediaSource represents a source of media samples that can be sent over WebRTC. type MediaSource interface { SetTargetBitrate(int) SetWriter(func(sample media.Sample) error) Start(ctx context.Context) error } +// Sender manages a WebRTC connection for sending media. type Sender struct { settingEngine *webrtc.SettingEngine mediaEngine *webrtc.MediaEngine @@ -44,6 +47,7 @@ type Sender struct { log logging.LeveledLogger } +// NewSender creates a new WebRTC sender with the given media source and options. func NewSender(source MediaSource, opts ...Option) (*Sender, error) { sender := &Sender{ settingEngine: &webrtc.SettingEngine{}, @@ -69,6 +73,7 @@ func NewSender(source MediaSource, opts ...Option) (*Sender, error) { return sender, nil } +// SetupPeerConnection initializes the WebRTC peer connection. func (s *Sender) SetupPeerConnection() error { // Create a new RTCPeerConnection peerConnection, err := webrtc.NewAPI( @@ -82,7 +87,13 @@ func (s *Sender) SetupPeerConnection() error { s.peerConnection = peerConnection // Create a video track - videoTrack, err := webrtc.NewTrackLocalStaticSample(webrtc.RTPCodecCapability{MimeType: webrtc.MimeTypeVP8}, "video", "pion") + videoTrack, err := webrtc.NewTrackLocalStaticSample( + webrtc.RTPCodecCapability{ + MimeType: webrtc.MimeTypeVP8, + }, + "video", + "pion", + ) if err != nil { return err } @@ -108,22 +119,26 @@ func (s *Sender) SetupPeerConnection() error { // Set the handler for ICE connection state // This will notify you when the peer has connected/disconnected s.peerConnection.OnICEConnectionStateChange(func(connectionState webrtc.ICEConnectionState) { - s.log.Infof("Sender Connection State has changed %s \n", connectionState.String()) + s.log.Infof("Sender Connection State has changed %s", connectionState.String()) }) // Set the handler for Peer connection state // This will notify you when the peer has connected/disconnected s.peerConnection.OnConnectionStateChange(func(state webrtc.PeerConnectionState) { - s.log.Infof("Sender Peer Connection State has changed: %s\n", state.String()) + s.log.Infof("Sender Peer Connection State has changed: %s", state.String()) }) peerConnection.OnICECandidate(func(i *webrtc.ICECandidate) { - s.log.Infof("Sender candidate: %v\n", i) + s.log.Infof("Sender candidate: %v", i) }) + return nil } +var errNoPeerConnection = fmt.Errorf("no PeerConnection created") + +// CreateOffer creates a WebRTC offer for signaling. func (s *Sender) CreateOffer() (*webrtc.SessionDescription, error) { if s.peerConnection == nil { - return nil, fmt.Errorf("no PeerConnection created") + return nil, errNoPeerConnection } offer, err := s.peerConnection.CreateOffer(nil) if err != nil { @@ -139,16 +154,18 @@ func (s *Sender) CreateOffer() (*webrtc.SessionDescription, error) { // we do this because we only can exchange one signaling message // in a production application you should exchange ICE Candidates via OnICECandidate <-gatherComplete - s.log.Infof("Sender gatherComplete: %v\n", s.peerConnection.ICEGatheringState()) + s.log.Infof("Sender gatherComplete: %v", s.peerConnection.ICEGatheringState()) return s.peerConnection.LocalDescription(), nil } +// AcceptAnswer processes a WebRTC answer from the remote peer. func (s *Sender) AcceptAnswer(answer *webrtc.SessionDescription) error { // Sets the LocalDescription, and starts our UDP listeners return s.peerConnection.SetRemoteDescription(*answer) } +// Start begins the media sending process and runs until context is done. func (s *Sender) Start(ctx context.Context) error { ticker := time.NewTicker(100 * time.Millisecond) defer ticker.Stop() @@ -171,14 +188,16 @@ func (s *Sender) Start(ctx context.Context) error { case now := <-ticker.C: targetBitrate := estimator.GetTargetBitrate() if now.Sub(lastLog) >= time.Second { - s.log.Infof("targetBitrate = %v\n", targetBitrate) + s.log.Infof("targetBitrate = %v", targetBitrate) lastLog = now } if lastBitrate != targetBitrate { s.source.SetTargetBitrate(targetBitrate) lastBitrate = targetBitrate } - fmt.Fprintf(s.ccLogWriter, "%v, %v\n", now.UnixMilli(), targetBitrate) + if _, err := fmt.Fprintf(s.ccLogWriter, "%v, %v\n", now.UnixMilli(), targetBitrate); err != nil { + s.log.Errorf("failed to write to ccLogWriter: %v", err) + } case <-ctx.Done(): return nil } @@ -189,10 +208,18 @@ func (s *Sender) Start(ctx context.Context) error { return s.source.Start(ctx) }) - defer s.peerConnection.Close() + defer func() { + if err := s.peerConnection.Close(); err != nil { + s.log.Errorf("failed to close peer connection: %v", err) + } + }() + return wg.Wait() } +var errSignalingFailed = fmt.Errorf("signaling failed") + +// SignalHTTP performs WebRTC signaling over HTTP. func (s *Sender) SignalHTTP(addr, route string) error { offer, err := s.CreateOffer() if err != nil { @@ -203,17 +230,23 @@ func (s *Sender) SignalHTTP(addr, route string) error { return err } url := fmt.Sprintf("http://%s/%s", addr, route) - s.log.Infof("connecting to '%v'\n", url) + s.log.Infof("connecting to '%v'", url) + //nolint:gosec,noctx resp, err := http.Post(url, "application/json; charset=utf-8", bytes.NewReader(payload)) if err != nil { return err } + defer func() { + if err = resp.Body.Close(); err != nil { + s.log.Errorf("failed to close signal http body: %v", err) + } + }() if resp.StatusCode != http.StatusOK { - return fmt.Errorf("signaling received unexpected status code: %v: %v", resp.StatusCode, resp.Status) + return fmt.Errorf("%w: unexpected status code: %v: %v", errSignalingFailed, resp.StatusCode, resp.Status) } answer := webrtc.SessionDescription{} if sdpErr := json.NewDecoder(resp.Body).Decode(&answer); sdpErr != nil { - panic(sdpErr) + return fmt.Errorf("decode SDP answer: %w", sdpErr) } return s.AcceptAnswer(&answer) diff --git a/sender/simulcast.go b/sender/simulcast.go index 413244d..7402483 100644 --- a/sender/simulcast.go +++ b/sender/simulcast.go @@ -1,10 +1,12 @@ // SPDX-FileCopyrightText: 2025 The Pion community // SPDX-License-Identifier: MIT +// Package sender implements WebRTC sender functionality for bandwidth estimation tests. package sender import ( "context" + "errors" "io" "os" "sync" @@ -30,6 +32,8 @@ const ( ivfHeaderSize = 32 ) +// SimulcastFilesSource is a media source that switches between different quality +// video files based on the target bitrate. type SimulcastFilesSource struct { qualityLevels []struct { fileName string @@ -43,13 +47,17 @@ type SimulcastFilesSource struct { log logging.LeveledLogger } +// Close stops the simulcast source and cleans up resources. func (s *SimulcastFilesSource) Close() error { defer s.wg.Wait() close(s.done) + return nil } -// NewSimulcastFilesSource returns a new SimulcastFilesSource +var errUninitializedSimulcastFileSource = errors.New("write on uninitialized SimulcastFileSource.WriteSample") + +// NewSimulcastFilesSource returns a new SimulcastFilesSource. func NewSimulcastFilesSource() *SimulcastFilesSource { return &SimulcastFilesSource{ qualityLevels: []struct { @@ -62,8 +70,8 @@ func NewSimulcastFilesSource() *SimulcastFilesSource { }, currentQualityLevel: 0, updateTargetBitrate: make(chan int), - WriteSample: func(sample media.Sample) error { - panic("write on uninitialized SimulcastFileSource.WriteSample") + WriteSample: func(_ media.Sample) error { + return errUninitializedSimulcastFileSource }, done: make(chan struct{}), wg: sync.WaitGroup{}, @@ -71,14 +79,19 @@ func NewSimulcastFilesSource() *SimulcastFilesSource { } } +// SetTargetBitrate sets the target bitrate for the simulcast source. func (s *SimulcastFilesSource) SetTargetBitrate(rate int) { s.updateTargetBitrate <- rate } +// SetWriter sets the sample writer function. func (s *SimulcastFilesSource) SetWriter(f func(sample media.Sample) error) { s.WriteSample = f } +// Start begins the simulcast process and runs until context is done. +// +//nolint:gocognit,cyclop func (s *SimulcastFilesSource) Start(ctx context.Context) error { files := make(map[string]*os.File) file, err := os.Open(s.qualityLevels[s.currentQualityLevel].fileName) @@ -90,7 +103,7 @@ func (s *SimulcastFilesSource) Start(ctx context.Context) error { for _, file := range files { err1 := file.Close() if err1 != nil { - s.log.Infof("failed to close file %v: %v", file.Name(), err1) + s.log.Errorf("failed to close file %v: %v", file.Name(), err1) } } }() @@ -105,7 +118,8 @@ func (s *SimulcastFilesSource) Start(ctx context.Context) error { // It is important to use a time.Ticker instead of time.Sleep because // * avoids accumulating skew, just calling time.Sleep didn't compensate for the time spent parsing the data // * works around latency issues with Sleep (see https://github.com/golang/go/issues/44343) - ticker := time.NewTicker(time.Millisecond * time.Duration((float32(header.TimebaseNumerator)/float32(header.TimebaseDenominator))*1000)) + tickerMS := int64(float32(header.TimebaseNumerator) / float32(header.TimebaseDenominator) * 1000) + ticker := time.NewTicker(time.Millisecond * time.Duration(tickerMS)) var frame []byte frameHeader := &ivfreader.IVFFrameHeader{} currentTimestamp := uint64(0) @@ -113,6 +127,7 @@ func (s *SimulcastFilesSource) Start(ctx context.Context) error { setReaderFile := func(filename string) (f func(_ int64) io.Reader, err error) { file, ok := files[s.qualityLevels[s.currentQualityLevel].fileName] if !ok { + //nolint:gosec file, err = os.Open(filename) if err != nil { return nil, err @@ -122,13 +137,16 @@ func (s *SimulcastFilesSource) Start(ctx context.Context) error { if _, err = file.Seek(ivfHeaderSize, io.SeekStart); err != nil { return nil, err } + return func(_ int64) io.Reader { return file }, nil } switchQualityLevel := func(newQualityLevel int) error { - s.log.Infof("Switching from %s to %s \n", s.qualityLevels[s.currentQualityLevel].fileName, s.qualityLevels[newQualityLevel].fileName) + from := s.qualityLevels[s.currentQualityLevel].fileName + to := s.qualityLevels[newQualityLevel].fileName + s.log.Infof("Switching from %s to %s", from, to) s.currentQualityLevel = newQualityLevel readerFile, err1 := setReaderFile(s.qualityLevels[s.currentQualityLevel].fileName) @@ -138,12 +156,11 @@ func (s *SimulcastFilesSource) Start(ctx context.Context) error { ivf.ResetReader(readerFile) for { if frame, frameHeader, err = ivf.ParseNextFrame(); err != nil { - break + return err } else if frameHeader.Timestamp >= currentTimestamp && frame[0]&0x1 == 0 { - break + return nil } } - return nil } targetBitrate := initialBitrate @@ -152,6 +169,7 @@ func (s *SimulcastFilesSource) Start(ctx context.Context) error { case rate := <-s.updateTargetBitrate: targetBitrate = rate case <-ticker.C: + haveBetterQuality := len(s.qualityLevels) > (s.currentQualityLevel + 1) switch { // If current quality level is below target bitrate drop to level below case s.currentQualityLevel != 0 && targetBitrate < s.qualityLevels[s.currentQualityLevel].bitrate: @@ -161,7 +179,7 @@ func (s *SimulcastFilesSource) Start(ctx context.Context) error { } // If next quality level is above target bitrate move to next level - case len(s.qualityLevels) > (s.currentQualityLevel+1) && targetBitrate > s.qualityLevels[s.currentQualityLevel+1].bitrate: + case haveBetterQuality && targetBitrate > s.qualityLevels[s.currentQualityLevel+1].bitrate: err = switchQualityLevel(s.currentQualityLevel + 1) if err != nil { return err @@ -172,15 +190,15 @@ func (s *SimulcastFilesSource) Start(ctx context.Context) error { frame, _, err = ivf.ParseNextFrame() } - switch err { + switch { // No error write the video frame - case nil: + case err == nil: currentTimestamp = frameHeader.Timestamp if err = s.WriteSample(media.Sample{Data: frame, Duration: time.Second}); err != nil { return err } // If we have reached the end of the file start again - case io.EOF: + case errors.Is(err, io.EOF): readerFile, err1 := setReaderFile(s.qualityLevels[s.currentQualityLevel].fileName) if err1 != nil { return err1 diff --git a/stats/server.go b/stats/server.go index 498f8c8..a72a8ab 100644 --- a/stats/server.go +++ b/stats/server.go @@ -1,77 +1,81 @@ // SPDX-FileCopyrightText: 2025 The Pion community // SPDX-License-Identifier: MIT +// Package stats provides functionality for collecting and visualizing statistics. package stats import ( - "context" "html/template" - "log" "net/http" "github.com/gorilla/websocket" + "github.com/pion/logging" ) +// DataPoint represents a single data point for visualization. type DataPoint struct { Label string Timestamp int64 // milliseconds after Start Value float64 } +// Server handles WebSocket connections for real-time data visualization. type Server struct { - srv *http.Server upgrader *websocket.Upgrader dataChan chan DataPoint + log logging.LeveledLogger } +// New creates a new statistics server. func New() *Server { return &Server{ upgrader: &websocket.Upgrader{}, dataChan: make(chan DataPoint), + log: logging.NewDefaultLoggerFactory().NewLogger("server"), } } +// Add adds a data point to the server for broadcasting to clients. func (s *Server) Add(d DataPoint) { go func() { s.dataChan <- d }() } +// Start starts the statistics server on the specified address. +func (s *Server) Start(addr string) error { + mux := http.NewServeMux() + mux.HandleFunc("/", s.home) + mux.HandleFunc("/update", s.update) + + //nolint:gosec + return http.ListenAndServe(addr, mux) +} + func (s *Server) update(w http.ResponseWriter, r *http.Request) { - c, err := s.upgrader.Upgrade(w, r, nil) + wsConn, err := s.upgrader.Upgrade(w, r, nil) if err != nil { - log.Print("s.upgrader.Upgrade:", err) + s.log.Errorf("s.upgrader.Upgrade: %v", err) + return } - defer c.Close() - for dataPoint := range s.dataChan { - if err = c.WriteJSON(dataPoint); err != nil { - log.Print("c.WriteJSON:", err) - return + defer func() { + if err = wsConn.Close(); err != nil { + s.log.Errorf("failed to close websocket connection: %v", err) } - } -} + }() -func (s *Server) home(w http.ResponseWriter, r *http.Request) { - homeTemplate.Execute(w, "ws://"+r.Host+"/update") -} + for dataPoint := range s.dataChan { + if err = wsConn.WriteJSON(dataPoint); err != nil { + s.log.Errorf("c.WriteJSON: %v", err) -func (s *Server) Start() { - mux := &http.ServeMux{} - mux.HandleFunc("/", s.home) - mux.HandleFunc("/update", s.update) - s.srv = &http.Server{ - Addr: ":8080", - Handler: mux, + return + } } - s.srv.ListenAndServe() } -func (s *Server) Shutdown(ctx context.Context) error { - return s.srv.Shutdown(ctx) -} - -var homeTemplate = template.Must(template.New("").Parse(` +func (s *Server) home(respWriter http.ResponseWriter, req *http.Request) { + homeTemplate := template.Must(template.New("").Parse(` @@ -84,13 +88,13 @@ var homeTemplate = template.Must(template.New("").Parse(`
`)) + + if err := homeTemplate.Execute(respWriter, "ws://"+req.Host+"/update"); err != nil { + s.log.Errorf("failed to execute template: %v", err) + http.Error(respWriter, "Internal server error", http.StatusInternalServerError) + } +} diff --git a/syncodec/codec.go b/syncodec/codec.go index c9d4b61..7670019 100644 --- a/syncodec/codec.go +++ b/syncodec/codec.go @@ -1,6 +1,8 @@ // SPDX-FileCopyrightText: 2025 The Pion community // SPDX-License-Identifier: MIT +// Package syncodec provides synthetic codec implementations for bandwidth estimation tests. +// It simulates media codecs with configurable bitrate and other parameters. package syncodec import ( @@ -8,22 +10,33 @@ import ( "time" ) +// Frame represents a media frame with content and duration. type Frame struct { - Content []byte - Duration time.Duration + Content []byte // Raw frame data + Duration time.Duration // Duration of the frame } func (f Frame) String() string { return fmt.Sprintf("FRAME: \n\tDURATION: %v\n\tSIZE: %v\n", f.Duration, len(f.Content)) } +// Codec defines the interface for synthetic codecs. type Codec interface { + // GetTargetBitrate returns the current target bitrate in bits per second. GetTargetBitrate() int + + // SetTargetBitrate sets the target bitrate in bits per second. SetTargetBitrate(int) + + // Start begins the codec operation. Start() + + // Close stops the codec and cleans up resources. Close() error } +// FrameWriter defines the interface for writing frames. type FrameWriter interface { + // WriteFrame writes a frame to the underlying media sink. WriteFrame(Frame) } diff --git a/syncodec/perfect_codec.go b/syncodec/perfect_codec.go index 2fe856c..843097d 100644 --- a/syncodec/perfect_codec.go +++ b/syncodec/perfect_codec.go @@ -9,6 +9,8 @@ import ( var _ Codec = (*PerfectCodec)(nil) +// PerfectCodec implements a simple codec that produces frames at a constant rate +// with sizes exactly matching the target bitrate. type PerfectCodec struct { writer FrameWriter @@ -18,6 +20,7 @@ type PerfectCodec struct { done chan struct{} } +// NewPerfectCodec creates a new PerfectCodec with the specified frame writer and target bitrate. func NewPerfectCodec(writer FrameWriter, targetBitrateBps int) *PerfectCodec { return &PerfectCodec{ writer: writer, @@ -37,6 +40,7 @@ func (c *PerfectCodec) SetTargetBitrate(r int) { c.targetBitrateBps = r } +// Start begins the codec operation, generating frames at the configured frame rate. func (c *PerfectCodec) Start() { msToNextFrame := time.Duration((1.0/float64(c.fps))*1000.0) * time.Millisecond ticker := time.NewTicker(msToNextFrame) @@ -53,7 +57,9 @@ func (c *PerfectCodec) Start() { } } +// Close stops the codec and cleans up resources. func (c *PerfectCodec) Close() error { close(c.done) + return nil } diff --git a/syncodec/statistical_codec.go b/syncodec/statistical_codec.go index 0b26a30..328b636 100644 --- a/syncodec/statistical_codec.go +++ b/syncodec/statistical_codec.go @@ -10,10 +10,7 @@ import ( "time" ) -func init() { - rand.Seed(time.Now().UnixNano()) -} - +// Constants for the statistical codec. const ( defaultTargetBitrateBps = 1_000_000 // 1 Mbps defaultFPS = 30 @@ -23,38 +20,45 @@ const ( defaultT0 = 33 * time.Millisecond defaultB0 = 4_170 // 4.17 KB - // scaling parameter of zero-mean laplacian distribution describing - // deviations in normalized frame interval + // Scaling parameter of zero-mean laplacian distribution describing + // deviations in normalized frame interval. defaultScaleT = 0.15 - // scaling parameter of zero-mean laplacian distribution describing - // deviations in normalized frame size + // Scaling parameter of zero-mean laplacian distribution describing + // deviations in normalized frame size. defaultScaleB = 0.15 defaultRMin = 150_000 // 150 kbps defaultRMax = 150_000_000 // 150 Mbps ) +// noiser defines an interface for adding noise to values. type noiser interface { noise() float64 } +// laplaceNoise implements the noiser interface using a Laplace distribution. type laplaceNoise struct { rnd *rand.Rand scale float64 } +// noise returns a random value from a Laplace distribution. func (l laplaceNoise) noise() float64 { if l.rnd == nil { + //nolint:gosec l.rnd = rand.New(rand.NewSource(time.Now().UnixNano())) } e1 := -l.scale * math.Log(l.rnd.Float64()) e2 := -l.scale * math.Log(l.rnd.Float64()) + return e1 - e2 } var _ Codec = (*StatisticalCodec)(nil) +// StatisticalCodec implements a codec that produces frames with sizes and timings +// that follow statistical distributions to simulate real-world codecs. type StatisticalCodec struct { // requested target bitrate targetBitrateBps int @@ -77,10 +81,10 @@ type StatisticalCodec struct { // reference frame size targetBitrateBps / fps b0 int - // max rate supported by video encoder + // min rate supported by video encoder rMin int - // min rate supported by video encoder + // max rate supported by video encoder rMax int // output writer @@ -95,7 +99,6 @@ type StatisticalCodec struct { scaleT float64 // internal types - targetBitrateLock sync.Mutex targetBitrateChan chan int lastTargetBitrateUpdate time.Time @@ -108,50 +111,64 @@ type StatisticalCodec struct { done chan struct{} } +// StatisticalCodecOption is a function that configures a StatisticalCodec. type StatisticalCodecOption func(*StatisticalCodec) error +// WithInitialTargetBitrate sets the initial target bitrate for the codec. func WithInitialTargetBitrate(targetBitrateBps int) StatisticalCodecOption { return func(sc *StatisticalCodec) error { sc.targetBitrateBps = targetBitrateBps + return nil } } +// WithFramesPerSecond sets the frames per second for the codec. func WithFramesPerSecond(fps int) StatisticalCodecOption { return func(sc *StatisticalCodec) error { sc.fps = fps + return nil } } +// WithScaleB sets the scaling parameter for frame size noise. func WithScaleB(scale float64) StatisticalCodecOption { return func(sc *StatisticalCodec) error { sc.scaleB = scale + return nil } } +// WithScaleT sets the scaling parameter for frame timing noise. func WithScaleT(scale float64) StatisticalCodecOption { return func(sc *StatisticalCodec) error { sc.scaleT = scale + return nil } } -func min(a, b int) int { +// minimum returns the minimum of two integers. +func minimum(a, b int) int { if a < b { return a } + return b } -func max(a, b int) int { +// maximum returns the maximum of two integers. +func maximum(a, b int) int { if a > b { return a } + return b } +// NewStatisticalEncoder creates a new StatisticalCodec with the given frame writer and options. func NewStatisticalEncoder(w FrameWriter, opts ...StatisticalCodecOption) (*StatisticalCodec, error) { sc := &StatisticalCodec{ targetBitrateBps: defaultTargetBitrateBps, @@ -182,10 +199,12 @@ func NewStatisticalEncoder(w FrameWriter, opts ...StatisticalCodecOption) (*Stat } sc.frameSizeNoiser = laplaceNoise{ + //nolint:gosec rnd: rand.New(rand.NewSource(time.Now().UnixNano())), scale: sc.scaleB, } sc.frameDurationNoiser = laplaceNoise{ + //nolint:gosec rnd: rand.New(rand.NewSource(time.Now().UnixNano())), scale: sc.scaleT, } @@ -207,13 +226,14 @@ func (c *StatisticalCodec) GetTargetBitrate() int { // c.rMin, bitrate will be set to c.rMin. func (c *StatisticalCodec) SetTargetBitrate(r int) { if r < c.targetBitrateBps { - c.targetBitrateBps = max(r, c.rMin) + c.targetBitrateBps = maximum(r, c.rMin) + return } - c.targetBitrateBps = min(r, c.rMax) + c.targetBitrateBps = minimum(r, c.rMax) } -// NextFrame returns the next faked video frame +// nextFrame returns the next faked video frame. func (c *StatisticalCodec) nextFrame() Frame { duration := time.Duration((1.0/float64(c.fps))*1000.0) * time.Millisecond @@ -244,7 +264,7 @@ func (c *StatisticalCodec) nextFrame() Frame { } } -// Start starts the StatisticalCodec +// Start starts the StatisticalCodec. func (c *StatisticalCodec) Start() { timer := time.NewTimer(c.t0) for { @@ -270,8 +290,9 @@ func (c *StatisticalCodec) Start() { } } -// Close stops and closes the StatisticalCodec +// Close stops and closes the StatisticalCodec. func (c *StatisticalCodec) Close() error { close(c.done) + return nil } diff --git a/vnet/flow.go b/vnet/flow.go index f69b347..ea5f405 100644 --- a/vnet/flow.go +++ b/vnet/flow.go @@ -1,6 +1,7 @@ // SPDX-FileCopyrightText: 2025 The Pion community // SPDX-License-Identifier: MIT +// Package main implements virtual network functionality for bandwidth estimation tests. package main import ( @@ -8,23 +9,19 @@ import ( "fmt" "io" - plogging "github.com/pion/logging" - "github.com/pion/bwe-test/logging" "github.com/pion/bwe-test/receiver" "github.com/pion/bwe-test/sender" + plogging "github.com/pion/logging" ) +// Flow represents a WebRTC connection between a sender and receiver over a virtual network. type Flow struct { - sender *sender.Sender - receiver *receiver.Receiver - senderRTPLogger io.WriteCloser - senderRTCPLogger io.WriteCloser - ccLogger io.WriteCloser - receiverRTPLogger io.WriteCloser - receiverRTCPLogger io.WriteCloser + sender sndr + receiver recv } +// NewSimpleFlow creates a new Flow with the specified parameters. func NewSimpleFlow( loggerFactory plogging.LoggerFactory, nm *NetworkManager, @@ -32,35 +29,123 @@ func NewSimpleFlow( senderMode senderMode, dataDir string, ) (Flow, error) { - leftVnet, publicIPLeft, err := nm.GetLeftNet() + snd, err := newSender(loggerFactory, nm, id, senderMode, dataDir) if err != nil { - return Flow{}, fmt.Errorf("get left net: %w", err) + return Flow{}, fmt.Errorf("new sender: %w", err) } - rightVnet, publicIPRight, err := nm.GetRightNet() + err = snd.sender.SetupPeerConnection() if err != nil { - return Flow{}, fmt.Errorf("get right net: %w", err) + return Flow{}, fmt.Errorf("sender setup peer connection: %w", err) } - senderRTPLogger, err := logging.GetLogFile(fmt.Sprintf("%v/%v_sender_rtp.log", dataDir, id)) + offer, err := snd.sender.CreateOffer() if err != nil { - return Flow{}, fmt.Errorf("get sender rtp log file: %w", err) + return Flow{}, fmt.Errorf("sender create offer: %w", err) } - senderRTCPLogger, err := logging.GetLogFile(fmt.Sprintf("%v/%v_sender_rtcp.log", dataDir, id)) + rc, err := newReceiver(nm, id, dataDir) + if err != nil { + return Flow{}, fmt.Errorf("new sender: %w", err) + } + + err = rc.receiver.SetupPeerConnection() + if err != nil { + return Flow{}, fmt.Errorf("receiver setup peer connection: %w", err) + } + + answer, err := rc.receiver.AcceptOffer(offer) + if err != nil { + return Flow{}, fmt.Errorf("receiver accept offer: %w", err) + } + + err = snd.sender.AcceptAnswer(answer) + if err != nil { + return Flow{}, fmt.Errorf("sender accept answer: %w", err) + } + + return Flow{ + sender: snd, + receiver: rc, + }, nil +} + +// Close stops the flow and cleans up all resources. +func (f Flow) Close() error { + var errs []error + err := f.receiver.Close() + if err != nil { + errs = append(errs, fmt.Errorf("receiver close: %w", err)) + } + err = f.sender.Close() + if err != nil { + errs = append(errs, fmt.Errorf("sender close: %w", err)) + } + + return errors.Join(errs...) +} + +var errUnknownSenderMode = errors.New("unknown sender mode") + +type sndr struct { + sender *sender.Sender + ccLogger io.WriteCloser + senderRTPLogger io.WriteCloser + senderRTCPLogger io.WriteCloser +} + +func (s sndr) Close() error { + var errs []error + + err := s.ccLogger.Close() + if err != nil { + errs = append(errs, err) + } + + err = s.senderRTPLogger.Close() + if err != nil { + errs = append(errs, err) + } + + err = s.senderRTCPLogger.Close() + if err != nil { + errs = append(errs, err) + } + + return errors.Join(errs...) +} + +func newSender( + loggerFactory plogging.LoggerFactory, + nm *NetworkManager, + id int, + senderMode senderMode, + dataDir string, +) (sndr, error) { + leftVnet, publicIPLeft, err := nm.GetLeftNet() if err != nil { - return Flow{}, fmt.Errorf("get sender rtcp log file: %w", err) + return sndr{}, fmt.Errorf("get left net: %w", err) } ccLogger, err := logging.GetLogFile(fmt.Sprintf("%v/%v_cc.log", dataDir, id)) if err != nil { - return Flow{}, fmt.Errorf("get cc log file: %w", err) + return sndr{}, fmt.Errorf("get cc log file: %w", err) } - var s *sender.Sender + senderRTPLogger, err := logging.GetLogFile(fmt.Sprintf("%v/%v_sender_rtp.log", dataDir, id)) + if err != nil { + return sndr{}, fmt.Errorf("get sender rtp log file: %w", err) + } + + senderRTCPLogger, err := logging.GetLogFile(fmt.Sprintf("%v/%v_sender_rtcp.log", dataDir, id)) + if err != nil { + return sndr{}, fmt.Errorf("get sender rtcp log file: %w", err) + } + + var snd *sender.Sender switch senderMode { case abrSenderMode: - s, err = sender.NewSender( + snd, err = sender.NewSender( sender.NewStatisticalEncoderSource(), sender.SetVnet(leftVnet, []string{publicIPLeft}), sender.PacketLogWriter(senderRTPLogger, senderRTCPLogger), @@ -69,10 +154,10 @@ func NewSimpleFlow( sender.SetLoggerFactory(loggerFactory), ) if err != nil { - return Flow{}, fmt.Errorf("new abr sender: %w", err) + return sndr{}, fmt.Errorf("new abr sender: %w", err) } case simulcastSenderMode: - s, err = sender.NewSender( + snd, err = sender.NewSender( sender.NewSimulcastFilesSource(), sender.SetVnet(leftVnet, []string{publicIPLeft}), sender.PacketLogWriter(senderRTPLogger, senderRTCPLogger), @@ -81,92 +166,79 @@ func NewSimpleFlow( sender.SetLoggerFactory(loggerFactory), ) if err != nil { - return Flow{}, fmt.Errorf("new simulcast sender: %w", err) + return sndr{}, fmt.Errorf("new simulcast sender: %w", err) } default: - return Flow{}, fmt.Errorf("invalid sender mode: %v", senderMode) + return sndr{}, fmt.Errorf("%w: %v", errUnknownSenderMode, senderMode) } - receiverRTPLogger, err := logging.GetLogFile(fmt.Sprintf("%v/%v_receiver_rtp.log", dataDir, id)) - if err != nil { - return Flow{}, fmt.Errorf("get receiver rtp log file: %w", err) - } + return sndr{ + sender: snd, + ccLogger: ccLogger, + senderRTPLogger: senderRTPLogger, + senderRTCPLogger: senderRTCPLogger, + }, nil +} - receiverRTCPLogger, err := logging.GetLogFile(fmt.Sprintf("%v/%v_receiver_rtcp.log", dataDir, id)) +type recv struct { + receiver *receiver.Receiver + receiverRTPLogger io.WriteCloser + receiverRTCPLogger io.WriteCloser +} + +func (s recv) Close() error { + var errs []error + + err := s.receiver.Close() if err != nil { - return Flow{}, fmt.Errorf("get receiver rtcp log file: %w", err) + errs = append(errs, err) } - rc, err := receiver.NewReceiver( - receiver.SetVnet(rightVnet, []string{publicIPRight}), - receiver.PacketLogWriter(receiverRTPLogger, receiverRTCPLogger), - receiver.DefaultInterceptors(), - ) + err = s.receiverRTPLogger.Close() if err != nil { - return Flow{}, fmt.Errorf("new receiver: %w", err) + errs = append(errs, err) } - err = s.SetupPeerConnection() + err = s.receiverRTCPLogger.Close() if err != nil { - return Flow{}, fmt.Errorf("sender setup peer connection: %w", err) + errs = append(errs, err) } - offer, err := s.CreateOffer() + return errors.Join(errs...) +} + +func newReceiver( + nm *NetworkManager, + id int, + dataDir string, +) (recv, error) { + rightVnet, publicIPRight, err := nm.GetRightNet() if err != nil { - return Flow{}, fmt.Errorf("sender create offer: %w", err) + return recv{}, fmt.Errorf("get right net: %w", err) } - err = rc.SetupPeerConnection() + receiverRTPLogger, err := logging.GetLogFile(fmt.Sprintf("%v/%v_receiver_rtp.log", dataDir, id)) if err != nil { - return Flow{}, fmt.Errorf("receiver setup peer connection: %w", err) + return recv{}, fmt.Errorf("get receiver rtp log file: %w", err) } - answer, err := rc.AcceptOffer(offer) + receiverRTCPLogger, err := logging.GetLogFile(fmt.Sprintf("%v/%v_receiver_rtcp.log", dataDir, id)) if err != nil { - return Flow{}, fmt.Errorf("receiver accept offer: %w", err) + return recv{}, fmt.Errorf("get receiver rtcp log file: %w", err) } - err = s.AcceptAnswer(answer) + rc, err := receiver.NewReceiver( + receiver.SetVnet(rightVnet, []string{publicIPRight}), + receiver.PacketLogWriter(receiverRTPLogger, receiverRTCPLogger), + receiver.DefaultInterceptors(), + ) if err != nil { - return Flow{}, fmt.Errorf("sender accept answer: %w", err) + return recv{}, fmt.Errorf("new receiver: %w", err) } - return Flow{ - sender: s, + return recv{ receiver: rc, - senderRTPLogger: senderRTPLogger, - senderRTCPLogger: senderRTCPLogger, - ccLogger: ccLogger, receiverRTPLogger: receiverRTPLogger, receiverRTCPLogger: receiverRTCPLogger, }, nil } - -func (f Flow) Close() error { - var errs []error - err := f.receiver.Close() - if err != nil { - errs = append(errs, fmt.Errorf("receiver close: %w", err)) - } - err = f.senderRTPLogger.Close() - if err != nil { - errs = append(errs, fmt.Errorf("sender rtp logger close: %w", err)) - } - err = f.senderRTCPLogger.Close() - if err != nil { - errs = append(errs, fmt.Errorf("sender rtcp logger close: %w", err)) - } - err = f.ccLogger.Close() - if err != nil { - errs = append(errs, fmt.Errorf("cc logger close: %w", err)) - } - err = f.receiverRTPLogger.Close() - if err != nil { - errs = append(errs, fmt.Errorf("receiver rtp logger close: %w", err)) - } - err = f.receiverRTCPLogger.Close() - if err != nil { - errs = append(errs, fmt.Errorf("receiver rtcp logger close: %w", err)) - } - return errors.Join(errs...) -} diff --git a/vnet/main.go b/vnet/main.go index d46ef6d..fe04971 100644 --- a/vnet/main.go +++ b/vnet/main.go @@ -1,10 +1,12 @@ // SPDX-FileCopyrightText: 2025 The Pion community // SPDX-License-Identifier: MIT +// Package main implements virtual network functionality for bandwidth estimation tests. package main import ( "context" + "errors" "flag" "fmt" "log" @@ -16,6 +18,7 @@ import ( "github.com/pion/transport/v3/vnet" ) +// senderMode defines the type of sender to use in the test. type senderMode int const ( @@ -23,6 +26,7 @@ const ( abrSenderMode ) +// flowMode defines whether to use a single flow or multiple flows in the test. type flowMode int const ( @@ -82,6 +86,8 @@ func main() { } } +var errUnknownLogLevel = errors.New("unknown log level") + func getLoggerFactory(logLevel string) (*logging.DefaultLoggerFactory, error) { logLevels := map[string]logging.LogLevel{ "disable": logging.LogLevelDisabled, @@ -94,7 +100,7 @@ func getLoggerFactory(logLevel string) (*logging.DefaultLoggerFactory, error) { level, ok := logLevels[strings.ToLower(logLevel)] if !ok { - return nil, fmt.Errorf("unknown log level: %v", logLevel) + return nil, fmt.Errorf("%w: %s", errUnknownLogLevel, logLevel) } loggerFactory := &logging.DefaultLoggerFactory{ @@ -106,6 +112,7 @@ func getLoggerFactory(logLevel string) (*logging.DefaultLoggerFactory, error) { return loggerFactory, nil } +// Runner manages the execution of bandwidth estimation tests. type Runner struct { loggerFactory *logging.DefaultLoggerFactory logger logging.LeveledLogger @@ -114,21 +121,25 @@ type Runner struct { flowMode flowMode } +var errUnknownFlowMode = errors.New("unknown flow mode") + +// Run executes the test based on the configured flow mode. func (r *Runner) Run() error { switch r.flowMode { case singleFlowMode: err := r.runVariableAvailableCapacitySingleFlow() if err != nil { - return fmt.Errorf("run variable availiable capacity single flow: %w", err) + return fmt.Errorf("run variable available capacity single flow: %w", err) } case multipleFlowsMode: err := r.runVariableAvailableCapacityMultipleFlows() if err != nil { - return fmt.Errorf("run variable availiable capacity multiple flows: %w", err) + return fmt.Errorf("run variable available capacity multiple flows: %w", err) } default: - return fmt.Errorf("unknown flow mode: %v", r.flowMode) + return fmt.Errorf("%w: %v", errUnknownFlowMode, r.flowMode) } + return nil } @@ -139,7 +150,7 @@ func (r *Runner) runVariableAvailableCapacitySingleFlow() error { } dataDir := fmt.Sprintf("data/%v", r.name) - err = os.MkdirAll(dataDir, os.ModePerm) + err = os.MkdirAll(dataDir, 0o750) if err != nil { return fmt.Errorf("mkdir data: %w", err) } @@ -149,7 +160,7 @@ func (r *Runner) runVariableAvailableCapacitySingleFlow() error { return fmt.Errorf("setup simple flow: %w", err) } defer func(flow Flow) { - err := flow.Close() + err = flow.Close() if err != nil { r.logger.Errorf("flow close: %v", err) } @@ -158,13 +169,13 @@ func (r *Runner) runVariableAvailableCapacitySingleFlow() error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() go func() { - err = flow.sender.Start(ctx) + err = flow.sender.sender.Start(ctx) if err != nil { r.logger.Errorf("sender start: %v", err) } }() - c := pathCharacteristics{ + path := pathCharacteristics{ referenceCapacity: 1 * vnet.MBit, phases: []phase{ { @@ -189,7 +200,8 @@ func (r *Runner) runVariableAvailableCapacitySingleFlow() error { }, }, } - r.runNetworkSimulation(c, nm) + r.runNetworkSimulation(path, nm) + return nil } @@ -200,7 +212,7 @@ func (r *Runner) runVariableAvailableCapacityMultipleFlows() error { } dataDir := fmt.Sprintf("data/%v", r.name) - err = os.MkdirAll(dataDir, os.ModePerm) + err = os.MkdirAll(dataDir, 0o750) if err != nil { return fmt.Errorf("mkdir data: %w", err) } @@ -208,7 +220,7 @@ func (r *Runner) runVariableAvailableCapacityMultipleFlows() error { for i := 0; i < 2; i++ { flow, err := NewSimpleFlow(r.loggerFactory, nm, i, r.senderMode, dataDir) defer func(flow Flow) { - err := flow.Close() + err = flow.Close() if err != nil { r.logger.Errorf("flow close: %v", err) } @@ -217,14 +229,14 @@ func (r *Runner) runVariableAvailableCapacityMultipleFlows() error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() go func() { - err = flow.sender.Start(ctx) + err = flow.sender.sender.Start(ctx) if err != nil { r.logger.Errorf("sender start: %v", err) } }() } - c := pathCharacteristics{ + path := pathCharacteristics{ referenceCapacity: 1 * vnet.MBit, phases: []phase{ { @@ -255,15 +267,18 @@ func (r *Runner) runVariableAvailableCapacityMultipleFlows() error { }, }, } - r.runNetworkSimulation(c, nm) + r.runNetworkSimulation(path, nm) + return nil } +// pathCharacteristics defines the network characteristics for the test. type pathCharacteristics struct { referenceCapacity int phases []phase } +// phase defines a single phase of the network simulation with specific characteristics. type phase struct { duration time.Duration capacityRatio float64 @@ -272,7 +287,7 @@ type phase struct { func (r *Runner) runNetworkSimulation(c pathCharacteristics, nm *NetworkManager) { for _, phase := range c.phases { - r.logger.Infof("enter next phase: %v\n", phase) + r.logger.Infof("enter next phase: %v", phase) nm.SetCapacity( int(float64(c.referenceCapacity)*phase.capacityRatio), phase.maxBurst, diff --git a/vnet/manager.go b/vnet/manager.go index 7999df4..3a32262 100644 --- a/vnet/manager.go +++ b/vnet/manager.go @@ -1,16 +1,20 @@ // SPDX-FileCopyrightText: 2025 The Pion community // SPDX-License-Identifier: MIT +// Package main implements virtual network functionality for bandwidth estimation tests. package main import ( - "fmt" + "errors" "strings" "github.com/pion/logging" "github.com/pion/transport/v3/vnet" ) +var errNoIPAvailiable = errors.New("no IP available") + +// RouterWithConfig combines a vnet Router with its configuration and IP tracking. type RouterWithConfig struct { *vnet.RouterConfig *vnet.Router @@ -19,7 +23,7 @@ type RouterWithConfig struct { func (r *RouterWithConfig) getIPMapping() (private, public string, err error) { if len(r.usedIPs) >= len(r.StaticIPs) { - return "", "", fmt.Errorf("no IP available") + return "", "", errNoIPAvailiable } ip := r.StaticIPs[0] for i := 1; i < len(r.StaticIPs); i++ { @@ -31,9 +35,11 @@ func (r *RouterWithConfig) getIPMapping() (private, public string, err error) { mapping := strings.Split(ip, "/") public = mapping[0] private = mapping[1] + return } +// NetworkManager manages the virtual network topology for bandwidth estimation tests. type NetworkManager struct { leftRouter *RouterWithConfig leftTBF *vnet.TokenBucketFilter @@ -46,6 +52,7 @@ const ( initMaxBurst = 80 * vnet.KBit ) +// NewManager creates a new NetworkManager with default configuration. func NewManager() (*NetworkManager, error) { wan, err := vnet.NewRouter(&vnet.RouterConfig{ CIDR: "0.0.0.0/0", @@ -55,76 +62,39 @@ func NewManager() (*NetworkManager, error) { return nil, err } - leftRouterConfig := &vnet.RouterConfig{ - CIDR: "10.0.1.0/24", - StaticIPs: []string{ - "10.0.1.1/10.0.1.101", - }, - LoggerFactory: logging.NewDefaultLoggerFactory(), - NATType: &vnet.NATType{ - Mode: vnet.NATModeNAT1To1, - }, - } - leftRouter, err := vnet.NewRouter(leftRouterConfig) + leftRouter, leftTBF, err := newLeftNet() if err != nil { return nil, err } - leftTBF, err := vnet.NewTokenBucketFilter( - leftRouter, - vnet.TBFRate(initCapacity), - vnet.TBFMaxBurst(initMaxBurst), - ) - if err != nil { - return nil, err - } err = wan.AddNet(leftTBF) if err != nil { return nil, err } - err = wan.AddChildRouter(leftRouter) + err = wan.AddChildRouter(leftRouter.Router) if err != nil { return nil, err } - rightRouterConfig := &vnet.RouterConfig{ - CIDR: "10.0.2.0/24", - StaticIPs: []string{ - "10.0.2.1/10.0.2.101", - }, - LoggerFactory: logging.NewDefaultLoggerFactory(), - NATType: &vnet.NATType{ - Mode: vnet.NATModeNAT1To1, - }, - } - rightRouter, err := vnet.NewRouter(rightRouterConfig) - if err != nil { - return nil, err - } - rightTBF, err := vnet.NewTokenBucketFilter(rightRouter, vnet.TBFRate(initCapacity), vnet.TBFMaxBurst(initMaxBurst)) + rightRouter, rightTBF, err := newRightNet() if err != nil { return nil, err } + err = wan.AddNet(rightTBF) if err != nil { return nil, err } - err = wan.AddChildRouter(rightRouter) + err = wan.AddChildRouter(rightRouter.Router) if err != nil { return nil, err } manager := &NetworkManager{ - leftRouter: &RouterWithConfig{ - Router: leftRouter, - RouterConfig: leftRouterConfig, - }, - leftTBF: leftTBF, - rightRouter: &RouterWithConfig{ - Router: rightRouter, - RouterConfig: rightRouterConfig, - }, - rightTBF: rightTBF, + leftRouter: leftRouter, + leftTBF: leftTBF, + rightRouter: rightRouter, + rightTBF: rightTBF, } if err := wan.Start(); err != nil { @@ -134,6 +104,7 @@ func NewManager() (*NetworkManager, error) { return manager, nil } +// GetLeftNet creates and returns a new Net on the left side of the network topology. func (m *NetworkManager) GetLeftNet() (*vnet.Net, string, error) { privateIP, publicIP, err := m.leftRouter.getIPMapping() if err != nil { @@ -152,9 +123,11 @@ func (m *NetworkManager) GetLeftNet() (*vnet.Net, string, error) { if err != nil { return nil, "", err } + return net, publicIP, nil } +// GetRightNet creates and returns a new Net on the right side of the network topology. func (m *NetworkManager) GetRightNet() (*vnet.Net, string, error) { privateIP, publicIP, err := m.rightRouter.getIPMapping() if err != nil { @@ -173,10 +146,78 @@ func (m *NetworkManager) GetRightNet() (*vnet.Net, string, error) { if err != nil { return nil, "", err } + return net, publicIP, nil } +// SetCapacity sets the capacity and maximum burst size for both sides of the network. func (m *NetworkManager) SetCapacity(capacity, maxBurst int) { m.leftTBF.Set(vnet.TBFRate(capacity), vnet.TBFMaxBurst(maxBurst)) m.rightTBF.Set(vnet.TBFRate(capacity), vnet.TBFMaxBurst(maxBurst)) } + +func newLeftNet() (*RouterWithConfig, *vnet.TokenBucketFilter, error) { + routerConfig := &vnet.RouterConfig{ + CIDR: "10.0.1.0/24", + StaticIPs: []string{ + "10.0.1.1/10.0.1.101", + }, + LoggerFactory: logging.NewDefaultLoggerFactory(), + NATType: &vnet.NATType{ + Mode: vnet.NATModeNAT1To1, + }, + } + router, err := vnet.NewRouter(routerConfig) + if err != nil { + return nil, nil, err + } + + tbf, err := vnet.NewTokenBucketFilter( + router, + vnet.TBFRate(initCapacity), + vnet.TBFMaxBurst(initMaxBurst), + ) + if err != nil { + return nil, nil, err + } + + routerWithConfig := &RouterWithConfig{ + Router: router, + RouterConfig: routerConfig, + } + + return routerWithConfig, tbf, nil +} + +func newRightNet() (*RouterWithConfig, *vnet.TokenBucketFilter, error) { + routerConfig := &vnet.RouterConfig{ + CIDR: "10.0.2.0/24", + StaticIPs: []string{ + "10.0.2.1/10.0.2.101", + }, + LoggerFactory: logging.NewDefaultLoggerFactory(), + NATType: &vnet.NATType{ + Mode: vnet.NATModeNAT1To1, + }, + } + router, err := vnet.NewRouter(routerConfig) + if err != nil { + return nil, nil, err + } + + tbf, err := vnet.NewTokenBucketFilter( + router, + vnet.TBFRate(initCapacity), + vnet.TBFMaxBurst(initMaxBurst), + ) + if err != nil { + return nil, nil, err + } + + routerWithConfig := &RouterWithConfig{ + Router: router, + RouterConfig: routerConfig, + } + + return routerWithConfig, tbf, nil +}