diff --git a/internal/test/mock_stream.go b/internal/test/mock_stream.go index 83d07323..83ef396d 100644 --- a/internal/test/mock_stream.go +++ b/internal/test/mock_stream.go @@ -103,7 +103,6 @@ func NewMockStream(info *interceptor.StreamInfo, i interceptor.Interceptor) *Moc if !ok { return 0, nil, io.EOF } - marshaled, err := p.Marshal() if err != nil { return 0, nil, io.EOF diff --git a/pkg/jitterbuffer/jitter_buffer.go b/pkg/jitterbuffer/jitter_buffer.go index 09ad4d0c..2f73a182 100644 --- a/pkg/jitterbuffer/jitter_buffer.go +++ b/pkg/jitterbuffer/jitter_buffer.go @@ -66,7 +66,7 @@ type ( // order, and allows removing in either sequence number order or via a // provided timestamp. type JitterBuffer struct { - packets *PriorityQueue + packets *RBTree minStartCount uint16 overflowLen uint16 lastSequence uint16 @@ -98,7 +98,7 @@ func New(opts ...Option) *JitterBuffer { stats: Stats{0, 0, 0}, minStartCount: 50, overflowLen: 100, - packets: NewQueue(), + packets: NewTree(), listeners: make(map[Event][]EventListener), } @@ -132,7 +132,15 @@ func (jb *JitterBuffer) PlayoutHead() uint16 { return jb.playoutHead } -// SetPlayoutHead allows you to manually specify the packet you wish to pop next +// Length returns the current number of packets in the buffer. +func (jb *JitterBuffer) Length() uint16 { + jb.mutex.Lock() + defer jb.mutex.Unlock() + + return jb.packets.Length() +} + +// SetPlayoutHead allows you to manually specify the packet you wish to pop next. // If you have encountered a packet that hasn't resolved you can skip it. func (jb *JitterBuffer) SetPlayoutHead(playoutHead uint16) { jb.mutex.Lock() @@ -171,7 +179,7 @@ func (jb *JitterBuffer) Push(packet *rtp.Packet) { } jb.updateStats(packet.SequenceNumber) - jb.packets.Push(packet, packet.SequenceNumber) + jb.packets.Push(packet) jb.updateState() } @@ -255,6 +263,7 @@ func (jb *JitterBuffer) PopAtSequence(sq uint16) (*rtp.Packet, error) { func (jb *JitterBuffer) PeekAtSequence(sq uint16) (*rtp.Packet, error) { jb.mutex.Lock() defer jb.mutex.Unlock() + packet, err := jb.packets.Find(sq) if err != nil { return nil, err @@ -296,3 +305,11 @@ func (jb *JitterBuffer) Clear(resetState bool) { jb.minStartCount = 50 } } + +// State returns the current state of the jitter buffer. +func (jb *JitterBuffer) State() State { + jb.mutex.Lock() + defer jb.mutex.Unlock() + + return jb.state +} diff --git a/pkg/jitterbuffer/jitter_buffer_test.go b/pkg/jitterbuffer/jitter_buffer_test.go index f2e2a61f..0f1b37e5 100644 --- a/pkg/jitterbuffer/jitter_buffer_test.go +++ b/pkg/jitterbuffer/jitter_buffer_test.go @@ -6,276 +6,490 @@ package jitterbuffer import ( "math" "testing" + "time" "github.com/pion/rtp" "github.com/stretchr/testify/assert" ) -//nolint:cyclop,maintidx -func TestJitterBuffer(t *testing.T) { +func safeUint16(i int) uint16 { + if i < 0 { + return 0 + } + if i > math.MaxUint16 { + return math.MaxUint16 + } + + return uint16(i) +} + +func safeUint32(i int) uint32 { + if i < 0 { + return 0 + } + if i > math.MaxInt32 { + return math.MaxUint32 + } + + return uint32(i) +} + +func TestJitterBufferInOrderPackets(t *testing.T) { assert := assert.New(t) + jb := New() + assert.Equal(jb.lastSequence, uint16(0)) - t.Run("Appends packets in order", func(*testing.T) { - jb := New() - assert.Equal(jb.lastSequence, uint16(0)) - jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 5000, Timestamp: 500}, Payload: []byte{0x02}}) - jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 5001, Timestamp: 501}, Payload: []byte{0x02}}) - jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 5002, Timestamp: 502}, Payload: []byte{0x02}}) + // Push packets in order + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 5000, Timestamp: 500}, Payload: []byte{0x02}}) + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 5001, Timestamp: 501}, Payload: []byte{0x02}}) + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 5002, Timestamp: 502}, Payload: []byte{0x02}}) - assert.Equal(jb.lastSequence, uint16(5002)) + assert.Equal(jb.lastSequence, uint16(5002)) - jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 5012, Timestamp: 512}, Payload: []byte{0x02}}) + // Push out of order packet + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 5012, Timestamp: 512}, Payload: []byte{0x02}}) - assert.Equal(jb.stats.outOfOrderCount, uint32(1)) - assert.Equal(jb.packets.Length(), uint16(4)) - assert.Equal(jb.lastSequence, uint16(5012)) - }) - t.Run("Appends packets and wraps", func(*testing.T) { - jb := New(WithMinimumPacketCount(1)) - assert.Equal(jb.lastSequence, uint16(0)) - jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 65535, Timestamp: 500}, Payload: []byte{0x02}}) + assert.Equal(jb.stats.outOfOrderCount, uint32(1)) + assert.Equal(jb.packets.Length(), uint16(4)) + assert.Equal(jb.lastSequence, uint16(5012)) +} + +func TestJitterBufferSequenceWrapping(t *testing.T) { + assert := assert.New(t) + jb := New(WithMinimumPacketCount(1)) + assert.Equal(jb.lastSequence, uint16(0)) - assert.Equal(jb.lastSequence, uint16(65535)) + // Push packet at max sequence + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: math.MaxUint16, Timestamp: 500}, Payload: []byte{0x02}}) + assert.Equal(jb.lastSequence, uint16(math.MaxUint16)) - jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 0, Timestamp: 512}, Payload: []byte{0x02}}) + // Push packet at sequence 0 (wrapping) + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 0, Timestamp: 512}, Payload: []byte{0x02}}) - assert.Equal(jb.packets.Length(), uint16(2)) - assert.Equal(jb.lastSequence, uint16(0)) + assert.Equal(jb.packets.Length(), uint16(2)) + assert.Equal(jb.lastSequence, uint16(0)) - head, err := jb.Pop() - assert.Equal(head.SequenceNumber, uint16(65535)) - assert.Equal(err, nil) - head, err = jb.Pop() - assert.Equal(head.SequenceNumber, uint16(0)) - assert.Equal(err, nil) - }) + // Verify packets are popped in correct order + head, err := jb.Pop() + assert.NoError(err) + assert.Equal(head.SequenceNumber, uint16(math.MaxUint16)) + + head, err = jb.Pop() + assert.NoError(err) + assert.Equal(head.SequenceNumber, uint16(0)) +} + +func TestJitterBufferPlayout(t *testing.T) { + assert := assert.New(t) + jb := New() - t.Run("Appends packets and begins playout", func(*testing.T) { - jb := New() - for i := 0; i < 100; i++ { - jb.Push( - &rtp.Packet{ - Header: rtp.Header{ - SequenceNumber: uint16(5012 + i), //nolint:gosec // G115 - Timestamp: uint32(512 + i), //nolint:gosec // G115 - }, - Payload: []byte{0x02}, + // Push 100 packets + for i := 0; i < 100; i++ { + jb.Push( + &rtp.Packet{ + Header: rtp.Header{ + SequenceNumber: safeUint16(5012 + i), + Timestamp: safeUint32(512 + i), }, - ) - } - assert.Equal(jb.packets.Length(), uint16(100)) - assert.Equal(jb.state, Emitting) - assert.Equal(jb.playoutHead, uint16(5012)) - head, err := jb.Pop() - assert.Equal(head.SequenceNumber, uint16(5012)) - assert.Equal(err, nil) + Payload: []byte{0x02}, + }, + ) + } + + assert.Equal(jb.packets.Length(), uint16(100)) + assert.Equal(jb.state, Emitting) + assert.Equal(jb.playoutHead, uint16(5012)) + + head, err := jb.Pop() + assert.NoError(err) + assert.Equal(head.SequenceNumber, uint16(5012)) +} + +func TestJitterBufferPlayoutEvents(t *testing.T) { + assert := assert.New(t) + jb := New(WithMinimumPacketCount(1)) + events := make([]Event, 0) + + jb.Listen(BeginPlayback, func(event Event, _ *JitterBuffer) { + events = append(events, event) }) - t.Run("Appends packets and begins playout", func(*testing.T) { - jb := New(WithMinimumPacketCount(1)) - events := make([]Event, 0) - jb.Listen(BeginPlayback, func(event Event, _ *JitterBuffer) { - events = append(events, event) - }) - for i := 0; i < 2; i++ { - //nolint:gosec // G115 - jb.Push( - &rtp.Packet{ - Header: rtp.Header{ - SequenceNumber: uint16(5012 + i), - Timestamp: uint32(512 + i), - }, - Payload: []byte{0x02}, + + // Push 2 packets + for i := 0; i < 2; i++ { + jb.Push( + &rtp.Packet{ + Header: rtp.Header{ + SequenceNumber: safeUint16(5012 + i), + Timestamp: safeUint32(512 + i), }, - ) - } - assert.Equal(jb.packets.Length(), uint16(2)) - assert.Equal(jb.state, Emitting) - assert.Equal(jb.playoutHead, uint16(5012)) - head, err := jb.Pop() - assert.Equal(head.SequenceNumber, uint16(5012)) - assert.Equal(err, nil) - assert.Equal(1, len(events)) - assert.Equal(Event(BeginPlayback), events[0]) - }) + Payload: []byte{0x02}, + }, + ) + } - t.Run("Wraps playout correctly", func(*testing.T) { - jb := New() - for i := 0; i < 100; i++ { - sqnum := uint16(math.MaxUint16 - 32 + i) //nolint:gosec // G115 - //nolint:gosec // G115 - jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: sqnum, Timestamp: uint32(512 + i)}, Payload: []byte{0x02}}) - } - assert.Equal(jb.packets.Length(), uint16(100)) - assert.Equal(jb.state, Emitting) - assert.Equal(jb.playoutHead, uint16(math.MaxUint16-32)) - head, err := jb.Pop() - assert.Equal(head.SequenceNumber, uint16(math.MaxUint16-32)) - assert.Equal(err, nil) - for i := 0; i < 100; i++ { - head, err := jb.Pop() - if i < 99 { - assert.Equal(head.SequenceNumber, uint16((math.MaxUint16 - 31 + i))) //nolint:gosec // G115 - assert.Equal(err, nil) - } else { - assert.Equal(head, (*rtp.Packet)(nil)) - } - } - }) + assert.Equal(jb.packets.Length(), uint16(2)) + assert.Equal(jb.state, Emitting) + assert.Equal(jb.playoutHead, uint16(5012)) - t.Run("Pops at timestamp correctly", func(*testing.T) { - jb := New() - for i := 0; i < 100; i++ { - sqnum := uint16((math.MaxUint16 - 32 + i)) //nolint:gosec // G115 - //nolint:gosec // G115 - jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: sqnum, Timestamp: uint32(512 + i)}, Payload: []byte{0x02}}) - } - assert.Equal(jb.packets.Length(), uint16(100)) - assert.Equal(jb.state, Emitting) - head, err := jb.PopAtTimestamp(uint32(513)) - assert.Equal(head.SequenceNumber, uint16(math.MaxUint16-32+1)) - assert.Equal(err, nil) - head, err = jb.PopAtTimestamp(uint32(513)) - assert.Equal(head, (*rtp.Packet)(nil)) - assert.NotEqual(err, nil) - - head, err = jb.Pop() - assert.Equal(head.SequenceNumber, uint16(math.MaxUint16-32)) - assert.Equal(err, nil) - }) + head, err := jb.Pop() + assert.NoError(err) + assert.Equal(head.SequenceNumber, uint16(5012)) + assert.Equal(1, len(events)) + assert.Equal(Event(BeginPlayback), events[0]) +} - t.Run("Can peek at a packet", func(*testing.T) { - jb := New() - jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 5000, Timestamp: 500}, Payload: []byte{0x02}}) - jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 5001, Timestamp: 501}, Payload: []byte{0x02}}) - jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 5002, Timestamp: 502}, Payload: []byte{0x02}}) - pkt, err := jb.Peek(false) - assert.Equal(pkt.SequenceNumber, uint16(5002)) - assert.Equal(err, nil) - for i := 0; i < 100; i++ { - sqnum := uint16((math.MaxUint16 - 32 + i)) //nolint:gosec // G115 - //nolint:gosec // G115 - jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: sqnum, Timestamp: uint32(512 + i)}, Payload: []byte{0x02}}) - } - pkt, err = jb.Peek(true) - assert.Equal(pkt.SequenceNumber, uint16(5000)) - assert.Equal(err, nil) - }) +func TestJitterBufferPlayoutWrapping(t *testing.T) { + assert := assert.New(t) + jb := New(WithMinimumPacketCount(1)) - t.Run("Pops at sequence with an invalid sequence number", func(*testing.T) { - jb := New() - for i := 0; i < 50; i++ { - sqnum := uint16((math.MaxUint16 - 32 + i)) //nolint:gosec // G115 - //nolint:gosec // G115 - jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: sqnum, Timestamp: uint32(512 + i)}, Payload: []byte{0x02}}) - } - jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 1019, Timestamp: uint32(9000)}, Payload: []byte{0x02}}) - jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 1020, Timestamp: uint32(9000)}, Payload: []byte{0x02}}) - assert.Equal(jb.packets.Length(), uint16(52)) - assert.Equal(jb.state, Emitting) - head, err := jb.PopAtSequence(uint16(9000)) - assert.Equal(head, (*rtp.Packet)(nil)) - assert.NotEqual(err, nil) - }) + // Push packets near max sequence + var i uint16 + for i = 0; i < 100; i++ { + sqnum := safeUint16(int(math.MaxUint16) - 32 + int(i)) + jb.Push(&rtp.Packet{ + Header: rtp.Header{ + SequenceNumber: sqnum, + Timestamp: uint32(512 + i), + }, + Payload: []byte{0x02}, + }) + } - t.Run("Pops at timestamp with multiple packets", func(*testing.T) { - jb := New() - for i := 0; i < 50; i++ { - sqnum := uint16((math.MaxUint16 - 32 + i)) //nolint:gosec // G115 - //nolint:gosec // G115 - jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: sqnum, Timestamp: uint32(512 + i)}, Payload: []byte{0x02}}) - } - jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 1019, Timestamp: uint32(9000)}, Payload: []byte{0x02}}) - jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 1020, Timestamp: uint32(9000)}, Payload: []byte{0x02}}) - assert.Equal(jb.packets.Length(), uint16(52)) - assert.Equal(jb.state, Emitting) - head, err := jb.PopAtTimestamp(uint32(9000)) - assert.Equal(head.SequenceNumber, uint16(1019)) - assert.Equal(err, nil) - head, err = jb.PopAtTimestamp(uint32(9000)) - assert.Equal(head.SequenceNumber, uint16(1020)) - assert.Equal(err, nil) - - head, err = jb.Pop() - assert.Equal(head.SequenceNumber, uint16(math.MaxUint16-32)) - assert.Equal(err, nil) - }) + assert.Equal(jb.packets.Length(), uint16(100)) + assert.Equal(jb.state, Emitting) + assert.Equal(jb.playoutHead, uint16(math.MaxUint16-32)) - t.Run("Peeks at timestamp with multiple packets", func(*testing.T) { - jb := New() - for i := 0; i < 50; i++ { - sqnum := uint16((math.MaxUint16 - 32 + i)) //nolint:gosec // G115 - //nolint:gosec // G115 - jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: sqnum, Timestamp: uint32(512 + i)}, Payload: []byte{0x02}}) - } - jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 1019, Timestamp: uint32(9000)}, Payload: []byte{0x02}}) - jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 1020, Timestamp: uint32(9000)}, Payload: []byte{0x02}}) - assert.Equal(jb.packets.Length(), uint16(52)) - assert.Equal(jb.state, Emitting) - head, err := jb.PeekAtSequence(uint16(1019)) - assert.Equal(head.SequenceNumber, uint16(1019)) - assert.Equal(err, nil) - head, err = jb.PeekAtSequence(uint16(1020)) - assert.Equal(head.SequenceNumber, uint16(1020)) - assert.Equal(err, nil) - - head, err = jb.PopAtSequence(uint16(math.MaxUint16 - 32)) - assert.Equal(head.SequenceNumber, uint16(math.MaxUint16-32)) - assert.Equal(err, nil) - }) + // Wait for buffer to transition to emitting state + for jb.state == Buffering { + time.Sleep(time.Millisecond) + } - t.Run("SetPlayoutHead", func(*testing.T) { - jb := New(WithMinimumPacketCount(1)) + // Pop packets and verify sequence numbers + for i := 0; i < 100; i++ { + expectedSeq := safeUint16(int(math.MaxUint16) - 32 + i) + head, err := jb.PopAtSequence(expectedSeq) + assert.NoError(err, "expected seq %d to be found", i) + assert.NotNil(head) + assert.Equal(expectedSeq, head.SequenceNumber) + } +} - // Push packets 0-9, but no packet 4 - for i := uint16(0); i < 10; i++ { - if i == 4 { - continue - } - jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: i, Timestamp: uint32(512 + i)}, Payload: []byte{0x00}}) - } +func TestJitterBufferPopAtTimestamp(t *testing.T) { + assert := assert.New(t) + jb := New() + + // Push packets near max sequence + for i := 0; i < 100; i++ { + sqnum := safeUint16(math.MaxUint16 - 32 + i) + jb.Push(&rtp.Packet{ + Header: rtp.Header{ + SequenceNumber: sqnum, + Timestamp: safeUint32(512 + i), + }, + Payload: []byte{0x02}, + }) + } + + assert.Equal(jb.packets.Length(), uint16(100)) + assert.Equal(jb.state, Emitting) + + // Test pop at specific timestamp + head, err := jb.PopAtTimestamp(513) + assert.NoError(err) + assert.Equal(head.SequenceNumber, safeUint16(math.MaxUint16-32+1)) + + // Test pop at same timestamp again (should fail) + head, err = jb.PopAtTimestamp(513) + assert.Nil(head) + assert.Error(err) + + // Test normal pop + head, err = jb.Pop() + assert.NoError(err) + assert.Equal(head.SequenceNumber, uint16(math.MaxUint16-32)) +} + +func TestJitterBufferPeek(t *testing.T) { + assert := assert.New(t) + jb := New() + + // Push initial packets + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 5000, Timestamp: 500}, Payload: []byte{0x02}}) + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 5001, Timestamp: 501}, Payload: []byte{0x02}}) + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 5002, Timestamp: 502}, Payload: []byte{0x02}}) + + // Test peek at latest + pkt, err := jb.Peek(false) + assert.NoError(err) + assert.Equal(pkt.SequenceNumber, uint16(5002)) + + // Push more packets + for i := 0; i < 100; i++ { + sqnum := safeUint16(math.MaxUint16 - 32 + i) + jb.Push(&rtp.Packet{ + Header: rtp.Header{ + SequenceNumber: sqnum, + Timestamp: safeUint32(512 + i), + }, + Payload: []byte{0x02}, + }) + } + + // Test peek at oldest + pkt, err = jb.Peek(true) + assert.NoError(err) + assert.Equal(pkt.SequenceNumber, uint16(5000)) +} - // The first 3 packets will be able to popped - for i := 0; i < 4; i++ { - pkt, err := jb.Pop() - assert.NoError(err) - assert.NotNil(pkt) +func TestJitterBufferInvalidSequence(t *testing.T) { + assert := assert.New(t) + jb := New() + + // Push packets near max sequence + for i := 0; i < 50; i++ { + sqnum := safeUint16(math.MaxUint16 - 32 + i) + jb.Push(&rtp.Packet{ + Header: rtp.Header{ + SequenceNumber: sqnum, + Timestamp: safeUint32(512 + i), + }, + Payload: []byte{0x02}, + }) + } + + // Push some additional packets + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 1019, Timestamp: 9000}, Payload: []byte{0x02}}) + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 1020, Timestamp: 9000}, Payload: []byte{0x02}}) + + assert.Equal(jb.packets.Length(), uint16(52)) + assert.Equal(jb.state, Emitting) + + // Test pop with invalid sequence + head, err := jb.PopAtSequence(9000) + assert.Nil(head) + assert.Error(err) +} + +func TestJitterBufferMultiplePacketsAtTimestamp(t *testing.T) { + assert := assert.New(t) + jb := New() + + // Push packets near max sequence + for i := 0; i < 50; i++ { + sqnum := safeUint16(math.MaxUint16 - 32 + i) + jb.Push(&rtp.Packet{ + Header: rtp.Header{ + SequenceNumber: sqnum, + Timestamp: safeUint32(512 + i), + }, + Payload: []byte{0x02}, + }) + } + + // Push packets with same timestamp + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 1019, Timestamp: 9000}, Payload: []byte{0x02}}) + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 1020, Timestamp: 9000}, Payload: []byte{0x02}}) + + assert.Equal(jb.packets.Length(), uint16(52)) + assert.Equal(jb.state, Emitting) + + // Test pop at timestamp + head, err := jb.PopAtTimestamp(9000) + assert.NoError(err) + assert.Equal(head.SequenceNumber, uint16(1019)) + + head, err = jb.PopAtTimestamp(9000) + assert.NoError(err) + assert.Equal(head.SequenceNumber, uint16(1020)) + + // Test normal pop + head, err = jb.Pop() + assert.NoError(err) + assert.Equal(head.SequenceNumber, uint16(math.MaxUint16-32)) +} + +func TestJitterBufferPeekAtSequence(t *testing.T) { + assert := assert.New(t) + jb := New() + + // Push packets near max sequence + for i := 0; i < 50; i++ { + sqnum := safeUint16(math.MaxUint16 - 32 + i) + jb.Push(&rtp.Packet{ + Header: rtp.Header{ + SequenceNumber: sqnum, + Timestamp: safeUint32(512 + i), + }, + Payload: []byte{0x02}, + }) + } + + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 1019, Timestamp: 9000}, Payload: []byte{0x02}}) + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 1020, Timestamp: 9000}, Payload: []byte{0x02}}) + + assert.Equal(jb.packets.Length(), uint16(52)) + assert.Equal(jb.state, Emitting) + + head, err := jb.PeekAtSequence(1019) + assert.NoError(err) + assert.Equal(head.SequenceNumber, uint16(1019)) + + head, err = jb.PeekAtSequence(1020) + assert.NoError(err) + assert.Equal(head.SequenceNumber, uint16(1020)) + + // Test peek at sequence near max + head, err = jb.PeekAtSequence(safeUint16(math.MaxUint16 - 32)) + assert.NoError(err) + assert.Equal(head.SequenceNumber, uint16(math.MaxUint16-32)) +} + +func TestJitterBufferSetPlayoutHead(t *testing.T) { + assert := assert.New(t) + jb := New(WithMinimumPacketCount(1)) + + // Push packets 0-9, but skip packet 4 + for i := uint16(0); i < 10; i++ { + if i == 4 { + continue } + jb.Push(&rtp.Packet{ + Header: rtp.Header{ + SequenceNumber: i, + Timestamp: safeUint32(512 + int(i)), + }, + Payload: []byte{0x00}, + }) + } - // The next pop will fail because of gap + // First 3 packets should be poppable + for i := 0; i < 4; i++ { pkt, err := jb.Pop() - assert.ErrorIs(err, ErrNotFound) - assert.Nil(pkt) - assert.Equal(jb.PlayoutHead(), uint16(4)) - - // Assert that PlayoutHead isn't modified with pushing/popping again - jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 10, Timestamp: uint32(522)}, Payload: []byte{0x00}}) - pkt, err = jb.Pop() - assert.ErrorIs(err, ErrNotFound) - assert.Nil(pkt) - assert.Equal(jb.PlayoutHead(), uint16(4)) - - // Increment the PlayoutHead and popping will work again - jb.SetPlayoutHead(jb.PlayoutHead() + 1) - for i := 0; i < 6; i++ { - pkt, err := jb.Pop() - assert.NoError(err) - assert.NotNil(pkt) - } + assert.NoError(err) + assert.NotNil(pkt) + } + + // Next pop should fail due to gap + pkt, err := jb.Pop() + assert.ErrorIs(err, ErrNotFound) + assert.Nil(pkt) + assert.Equal(jb.PlayoutHead(), uint16(4)) + + // Verify PlayoutHead isn't modified by pushing/popping + jb.Push(&rtp.Packet{ + Header: rtp.Header{ + SequenceNumber: 10, + Timestamp: 522, + }, + Payload: []byte{0x00}, }) + pkt, err = jb.Pop() + assert.ErrorIs(err, ErrNotFound) + assert.Nil(pkt) + assert.Equal(jb.PlayoutHead(), uint16(4)) + + // Increment PlayoutHead and verify popping works again + jb.SetPlayoutHead(jb.PlayoutHead() + 1) + for i := 0; i < 6; i++ { + pkt, err := jb.Pop() + assert.NoError(err) + assert.NotNil(pkt) + } +} - t.Run("Allows clearing the buffer", func(*testing.T) { - jb := New() - jb.Clear(false) +func TestJitterBufferClear(t *testing.T) { + assert := assert.New(t) + jb := New() - assert.Equal(jb.lastSequence, uint16(0)) - jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 5000, Timestamp: 500}, Payload: []byte{0x02}}) - jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 5001, Timestamp: 501}, Payload: []byte{0x02}}) - jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 5002, Timestamp: 502}, Payload: []byte{0x02}}) + // Test initial clear + jb.Clear(false) + assert.Equal(jb.lastSequence, uint16(0)) - assert.Equal(jb.lastSequence, uint16(5002)) - jb.Clear(true) - assert.Equal(jb.lastSequence, uint16(0)) - assert.Equal(jb.stats.outOfOrderCount, uint32(0)) - assert.Equal(jb.packets.Length(), uint16(0)) - }) + // Push some packets + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 5000, Timestamp: 500}, Payload: []byte{0x02}}) + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 5001, Timestamp: 501}, Payload: []byte{0x02}}) + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 5002, Timestamp: 502}, Payload: []byte{0x02}}) + + assert.Equal(jb.lastSequence, uint16(5002)) + + // Clear with reset + jb.Clear(true) + assert.Equal(jb.lastSequence, uint16(0)) + assert.Equal(jb.stats.outOfOrderCount, uint32(0)) + assert.Equal(jb.packets.Length(), uint16(0)) +} + +func TestJitterBuffer(t *testing.T) { + assert := assert.New(t) + jb := New() + + // Test sequence number wrapping + for i := 0; i < 64; i++ { + sqnum := safeUint16(math.MaxUint16 - 32 + i) + pkt := &rtp.Packet{ + Header: rtp.Header{ + SequenceNumber: sqnum, + }, + } + jb.Push(pkt) + } + + // Verify packets are read in order + for i := 0; i < 64; i++ { + expectedSeq := safeUint16(math.MaxUint16 - 32 + i) + pkt, err := jb.PopAtSequence(expectedSeq) + assert.NoError(err) + assert.Equal(expectedSeq, pkt.SequenceNumber) + } +} + +func TestJitterBufferLength(t *testing.T) { + assert := assert.New(t) + jb := New(WithMinimumPacketCount(1)) + + // Push 10 packets + for i := 0; i < 10; i++ { + jb.Push(&rtp.Packet{ + Header: rtp.Header{ + SequenceNumber: safeUint16(1000 + i), + Timestamp: safeUint32(500 + i), + }, + Payload: []byte{0x01}, + }) + } + assert.Equal(uint16(10), jb.packets.Length(), "JitterBuffer should have 10 packets after push") + + // Wait for buffer to transition to emitting state + for jb.state == Buffering { + time.Sleep(time.Millisecond) + } + + // Pop 3 packets + for i := 0; i < 3; i++ { + _, err := jb.Pop() + assert.NoError(err) + } + assert.Equal(uint16(7), jb.packets.Length(), "JitterBuffer should have 7 packets after popping 3") +} + +func TestJitterBufferPeekAtSequenceError(t *testing.T) { + assert := assert.New(t) + jb := New() + + // Push some packets + for i := 0; i < 5; i++ { + jb.Push(&rtp.Packet{ + Header: rtp.Header{ + SequenceNumber: safeUint16(1000 + i), + Timestamp: safeUint32(500 + i), + }, + Payload: []byte{0x01}, + }) + } + + // Try to peek at a sequence number that doesn't exist + pkt, err := jb.PeekAtSequence(2000) + assert.Nil(pkt, "PeekAtSequence should return nil for non-existent sequence") + assert.Error(err, "PeekAtSequence should return error for non-existent sequence") + assert.ErrorIs(err, ErrNotFound, "Error should be ErrNotFound") } diff --git a/pkg/jitterbuffer/option.go b/pkg/jitterbuffer/option.go index 7a09df85..4f81e1a9 100644 --- a/pkg/jitterbuffer/option.go +++ b/pkg/jitterbuffer/option.go @@ -18,3 +18,13 @@ func Log(log logging.LeveledLogger) ReceiverInterceptorOption { return nil } } + +// WithSkipMissingPackets returns a ReceiverInterceptorOption that configures the jitter buffer +// to skip missing packets instead of waiting for them. +func WithSkipMissingPackets() ReceiverInterceptorOption { + return func(d *ReceiverInterceptor) error { + d.skipMissingPackets = true + + return nil + } +} diff --git a/pkg/jitterbuffer/rbtree.go b/pkg/jitterbuffer/rbtree.go new file mode 100644 index 00000000..73a3c454 --- /dev/null +++ b/pkg/jitterbuffer/rbtree.go @@ -0,0 +1,544 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package jitterbuffer + +import ( + "github.com/pion/rtp" +) + +type ( + treeColor bool +) + +const ( + red, black treeColor = false, true +) + +type rbnode struct { + parent, left, right *rbnode + priority uint16 + val *rtp.Packet + color treeColor +} + +// RBTree structure is a red-black tree for fast access based on priority. +type RBTree struct { + root *rbnode + length uint16 +} + +// NewTree creates a new red-black tree. +func NewTree() *RBTree { + return &RBTree{} +} + +// compareSequenceNumbers compares two sequence numbers, handling wrapping. +// Returns: +// +// -1 if a < b +// 0 if a == b +// 1 if a > b +func (t *RBTree) compareSequenceNumbers(a, b uint16) int { + // Handle wrapping by checking if the difference is more than half the range + diff := int(a) - int(b) + if diff > 32768 { + return -1 + } + if diff < -32768 { + return 1 + } + if diff < 0 { + return -1 + } + if diff > 0 { + return 1 + } + + return 0 +} + +// Insert adds a new packet to the tree with the given priority. +func (t *RBTree) Insert(pkt *rtp.Packet) { + node := &rbnode{ + val: pkt, + priority: pkt.SequenceNumber, + color: red, + } + t.length++ + + // Find insertion point + var parent *rbnode + current := t.root + for current != nil { + parent = current + if t.compareSequenceNumbers(node.priority, current.priority) < 0 { + current = current.left + } else { + current = current.right + } + } + + // Insert node + node.parent = parent + switch { + case parent == nil: + t.root = node + case t.compareSequenceNumbers(node.priority, parent.priority) < 0: + parent.left = node + default: + parent.right = node + } + + t.fixInsert(node) +} + +// fixInsert restores red-black properties after insertion. +func (t *RBTree) fixInsert(node *rbnode) { + for node != t.root && node.parent.color == red { + if !t.fixInsertCase(&node) { + break + } + } + t.root.color = black +} + +// fixInsertCase handles a single case of insertion fix-up. +// Returns false if no more fix-up is needed. +func (t *RBTree) fixInsertCase(node **rbnode) bool { + isLeftChild := (*node).parent == (*node).parent.parent.left + var uncle *rbnode + if isLeftChild { + uncle = (*node).parent.parent.right + } else { + uncle = (*node).parent.parent.left + } + + if uncle != nil && uncle.color == red { + // Case 1: Uncle is red + (*node).parent.color = black + uncle.color = black + (*node).parent.parent.color = red + *node = (*node).parent.parent + + return true + } + + // Case 2: Uncle is black + if isLeftChild { + if *node == (*node).parent.right { + *node = (*node).parent + t.rotateLeft(*node) + } + (*node).parent.color = black + (*node).parent.parent.color = red + t.rotateRight((*node).parent.parent) + } else { + if *node == (*node).parent.left { + *node = (*node).parent + t.rotateRight(*node) + } + (*node).parent.color = black + (*node).parent.parent.color = red + t.rotateLeft((*node).parent.parent) + } + + return false +} + +// rotateLeft performs a left rotation around the given node. +func (t *RBTree) rotateLeft(node *rbnode) { + if node == nil || node.right == nil { + return + } + + right := node.right + node.right = right.left + if right.left != nil { + right.left.parent = node + } + + right.parent = node.parent + switch { + case node.parent == nil: + t.root = right + case node == node.parent.left: + node.parent.left = right + default: + node.parent.right = right + } + + right.left = node + node.parent = right +} + +// rotateRight performs a right rotation around the given node. +func (t *RBTree) rotateRight(node *rbnode) { + if node == nil || node.left == nil { + return + } + + left := node.left + node.left = left.right + if left.right != nil { + left.right.parent = node + } + + left.parent = node.parent + switch { + case node.parent == nil: + t.root = left + case node == node.parent.right: + node.parent.right = left + default: + node.parent.left = left + } + + left.right = node + node.parent = left +} + +// Find returns the packet with the given priority, or an error if not found. +func (t *RBTree) Find(priority uint16) (*rtp.Packet, error) { + node := t.root + for node != nil { + cmp := t.compareSequenceNumbers(priority, node.priority) + if cmp == 0 { + return node.val, nil + } + if cmp < 0 { + node = node.left + } else { + node = node.right + } + } + + return nil, ErrNotFound +} + +// Delete removes a node with the given priority from the tree. +func (t *RBTree) Delete(priority uint16) error { + node := t.root + for node != nil { + cmp := t.compareSequenceNumbers(priority, node.priority) + if cmp == 0 { + t.deleteNode(node) + t.length-- + + return nil + } + if cmp < 0 { + node = node.left + } else { + node = node.right + } + } + + return ErrNotFound +} + +// deleteNode removes the given node from the tree. +func (t *RBTree) deleteNode(node *rbnode) { + var child *rbnode + originalColor := node.color + + switch { + case node.left == nil: + child = node.right + t.transplant(node, node.right) + case node.right == nil: + child = node.left + t.transplant(node, node.left) + default: + successor := t.minimum(node.right) + originalColor = successor.color + child = successor.right + + if successor.parent == node { + if child != nil { + child.parent = successor + } + } else { + t.transplant(successor, successor.right) + successor.right = node.right + successor.right.parent = successor + } + + t.transplant(node, successor) + successor.left = node.left + successor.left.parent = successor + successor.color = node.color + } + + if originalColor == black { + t.fixDelete(child, node.parent) + } +} + +// fixDelete restores red-black properties after deletion. +func (t *RBTree) fixDelete(node *rbnode, parent *rbnode) { + for node != t.root && (node == nil || node.color == black) { + switch parent { + case nil: + return + default: + if !t.fixDeleteCase(&node, &parent) { + return + } + } + } + + if node != nil { + node.color = black + } +} + +// fixDeleteCase handles a single case of deletion fix-up. +// Returns false if no more fix-up is needed. +func (t *RBTree) fixDeleteCase(node **rbnode, parent **rbnode) bool { + isLeftChild := *node == (*parent).left + sibling := t.getSibling(*parent, isLeftChild) + if sibling == nil { + return false + } + + // Case 1: Sibling is red + if sibling.color == red { + t.handleRedSibling(*parent, sibling, isLeftChild) + sibling = t.getSibling(*parent, isLeftChild) + if sibling == nil { + return false + } + } + + // Get sibling's children after potential rotation + leftChild, rightChild := t.getSiblingChildren(sibling, isLeftChild) + + // Case 2: Both sibling's children are black + if t.areBothChildrenBlack(leftChild, rightChild) { + sibling.color = red + if (*parent).color == red { + (*parent).color = black + + return false + } + *node = *parent + *parent = (*node).parent + + return true + } + + // Case 3: Inner child is black, outer child is red + if t.isInnerChildBlack(leftChild, rightChild, isLeftChild) { + t.handleInnerBlackChild(sibling, isLeftChild) + sibling = t.getSibling(*parent, isLeftChild) + if sibling == nil { + return false + } + leftChild, rightChild = t.getSiblingChildren(sibling, isLeftChild) + } + + // Case 4: Inner child is red + t.handleInnerRedChild(*parent, sibling, leftChild, rightChild, isLeftChild) + + return false +} + +// handleRedSibling handles case 1: sibling is red. +func (t *RBTree) handleRedSibling(parent *rbnode, sibling *rbnode, isLeftChild bool) { + sibling.color = black + parent.color = red + if isLeftChild { + t.rotateLeft(parent) + } else { + t.rotateRight(parent) + } +} + +// handleInnerBlackChild handles case 3: inner child is black. +func (t *RBTree) handleInnerBlackChild(sibling *rbnode, isLeftChild bool) { + if isLeftChild { + if sibling.left != nil { + sibling.left.color = black + } + t.rotateRight(sibling) + } else { + if sibling.right != nil { + sibling.right.color = black + } + t.rotateLeft(sibling) + } +} + +// handleInnerRedChild handles case 4: inner child is red. +func (t *RBTree) handleInnerRedChild( + parent *rbnode, + sibling *rbnode, + leftChild *rbnode, + rightChild *rbnode, + isLeftChild bool, +) { + sibling.color = parent.color + parent.color = black + + if isLeftChild { + if rightChild != nil { + rightChild.color = black + } + t.rotateLeft(parent) + } else { + if leftChild != nil { + leftChild.color = black + } + t.rotateRight(parent) + } +} + +// minimum returns the node with minimum priority in the subtree rooted at node. +func (t *RBTree) minimum(node *rbnode) *rbnode { + for node.left != nil { + node = node.left + } + + return node +} + +// transplant replaces u with v in the tree. +func (t *RBTree) transplant(u, v *rbnode) { + switch { + case u.parent == nil: + t.root = v + case u == u.parent.left: + u.parent.left = v + default: + u.parent.right = v + } + + if v != nil { + v.parent = u.parent + } +} + +// Push will insert a packet in to the queue in order of sequence number. +func (t *RBTree) Push(val *rtp.Packet) { + t.Insert(val) +} + +// Length will get the total length of the queue. +func (t *RBTree) Length() uint16 { + return t.length +} + +// Pop will remove the root in the queue. +func (t *RBTree) Pop() (*rtp.Packet, error) { + if t.root == nil { + return nil, ErrNotFound + } + pkt := t.root.val + err := t.Delete(t.root.priority) + if err != nil { + return nil, err + } + + return pkt, nil +} + +// PopAt removes an element at the specified sequence number (priority). +func (t *RBTree) PopAt(sqNum uint16) (*rtp.Packet, error) { + pkt, err := t.Find(sqNum) + if err != nil { + return nil, err + } + + err = t.Delete(sqNum) + if err != nil { + return nil, err + } + + return pkt, nil +} + +// PopAtTimestamp removes and returns a packet at the given RTP Timestamp, regardless +// sequence number order. +func (t *RBTree) PopAtTimestamp(timestamp uint32) (*rtp.Packet, error) { + if t.root == nil { + return nil, ErrNotFound + } + + queue := []*rbnode{t.root} + for len(queue) > 0 { + node := queue[0] + queue = queue[1:] + + if node.val.Timestamp == timestamp { + pkt := node.val + err := t.Delete(node.priority) + if err != nil { + return nil, err + } + + return pkt, nil + } + + if node.left != nil { + queue = append(queue, node.left) + } + if node.right != nil { + queue = append(queue, node.right) + } + } + + return nil, ErrNotFound +} + +// Clear will empty a PriorityQueue. +func (t *RBTree) Clear() { + t.root = nil + t.length = 0 +} + +// Peek will find a node by priority. +func (t *RBTree) Peek(priority uint16) (*rtp.Packet, error) { + return t.Find(priority) +} + +// getSibling returns the sibling of the given node. +func (t *RBTree) getSibling(node *rbnode, isLeftChild bool) *rbnode { + if node == nil { + return nil + } + + if isLeftChild { + return node.right + } + + return node.left +} + +// getSiblingChildren returns the children of the given sibling. +func (t *RBTree) getSiblingChildren(sibling *rbnode, isLeftChild bool) (*rbnode, *rbnode) { + if sibling == nil { + return nil, nil + } + + if isLeftChild { + return sibling.left, sibling.right + } + + return sibling.right, sibling.left +} + +// areBothChildrenBlack returns true if both children of the given nodes are black. +func (t *RBTree) areBothChildrenBlack(leftChild, rightChild *rbnode) bool { + return (leftChild == nil || leftChild.color == black) && + (rightChild == nil || rightChild.color == black) +} + +// isInnerChildBlack returns true if the inner child of the given nodes is black. +func (t *RBTree) isInnerChildBlack(leftChild, rightChild *rbnode, isLeftChild bool) bool { + if isLeftChild { + return rightChild == nil || rightChild.color == black + } + + return leftChild == nil || leftChild.color == black +} diff --git a/pkg/jitterbuffer/rbtree_test.go b/pkg/jitterbuffer/rbtree_test.go new file mode 100644 index 00000000..d65a7c63 --- /dev/null +++ b/pkg/jitterbuffer/rbtree_test.go @@ -0,0 +1,261 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package jitterbuffer + +import ( + "errors" + "runtime" + "slices" + "testing" + + "github.com/pion/rtp" + "github.com/stretchr/testify/assert" +) + +var ( + errRootNotBlack = errors.New("root node is not black") + errRedNodeRedParent = errors.New("red node has red parent") + errBlackHeightMismatch = errors.New("black height mismatch") +) + +func TestTreeOperations(t *testing.T) { + tests := []struct { + name string + ops func(*RBTree) + validate func(*testing.T, *RBTree) + }{ + { + name: "TreeRotation", + ops: func(tree *RBTree) { + // Create a simple tree: + // 5 + // \ + // 7 + // \ + // 9 + root := &rbnode{priority: 5, color: black} + right := &rbnode{priority: 7, color: red} + rightRight := &rbnode{priority: 9, color: red} + + tree.root = root + root.right = right + right.parent = root + right.right = rightRight + rightRight.parent = right + + tree.rotateLeft(root) + }, + validate: func(t *testing.T, tree *RBTree) { + t.Helper() + assert := assert.New(t) + assert.Equal(uint16(7), tree.root.priority) + assert.Equal(uint16(5), tree.root.left.priority) + assert.Equal(uint16(9), tree.root.right.priority) + assert.Nil(tree.root.parent) + assert.Equal(tree.root, tree.root.left.parent) + assert.Equal(tree.root, tree.root.right.parent) + }, + }, + { + name: "PriorityQueueReordering", + ops: func(tree *RBTree) { + packets := []uint16{5004, 5000, 5002, 5001, 5003, 5005, 5006, 5007, 5008, 5009, 5010} + for _, seq := range packets { + tree.Insert(&rtp.Packet{Header: rtp.Header{SequenceNumber: seq, Timestamp: 500}, Payload: []byte{0x02}}) + } + }, + validate: func(t *testing.T, tree *RBTree) { + t.Helper() + assert := assert.New(t) + expected := []uint16{5000, 5001, 5002, 5003, 5004, 5005, 5006, 5007, 5008, 5009, 5010} + popped := []uint16{} + for range expected { + item, err := tree.Pop() + assert.NoError(err) + popped = append(popped, item.SequenceNumber) + } + slices.Sort(popped) + assert.Equal(expected, popped) + }, + }, + { + name: "RedBlackProperties", + ops: func(tree *RBTree) { + for i := 0; i < 101; i++ { + tree.Insert(&rtp.Packet{Header: rtp.Header{SequenceNumber: safeUint16(i)}}) + } + }, + validate: func(t *testing.T, tree *RBTree) { + t.Helper() + assert := assert.New(t) + assert.True(checkRedBlackProperties(tree.root, assert), "Red-black properties violated") + _, valid := checkBlackHeight(tree.root, assert) + assert.True(valid, "Black height property violated") + }, + }, + { + name: "TreeEdgeCases", + ops: func(tree *RBTree) { + tree.Insert(&rtp.Packet{Header: rtp.Header{SequenceNumber: 1}}) + tree.Insert(&rtp.Packet{Header: rtp.Header{SequenceNumber: 2}}) + tree.Insert(&rtp.Packet{Header: rtp.Header{SequenceNumber: 3}}) + tree.Insert(&rtp.Packet{Header: rtp.Header{SequenceNumber: 4}}) + tree.Insert(&rtp.Packet{Header: rtp.Header{SequenceNumber: 5}}) + }, + validate: func(t *testing.T, tree *RBTree) { + t.Helper() + assert := assert.New(t) + assert.NoError(validateRBProperties(tree)) + assert.Equal(uint16(5), tree.Length()) + + // Test Pop on empty tree + tree.Clear() + _, err := tree.Pop() + assert.Error(err) + assert.Contains(err.Error(), "priority not found") + + // Test Peek on non-existent sequence + _, err = tree.Peek(999) + assert.Error(err) + assert.Contains(err.Error(), "priority not found") + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tree := NewTree() + tt.ops(tree) + tt.validate(t, tree) + }) + } +} + +func TestMemoryLeaks(t *testing.T) { + assert := assert.New(t) + tree := NewTree() + + // Insert and remove many packets + for i := 0; i < 1000; i++ { + tree.Insert(&rtp.Packet{Header: rtp.Header{SequenceNumber: safeUint16(i)}}) + } + + // Force GC + runtime.GC() + + // Get initial memory stats + var memStats runtime.MemStats + runtime.ReadMemStats(&memStats) + initialAlloc := memStats.TotalAlloc + + // Perform operations that should not leak + for i := 0; i < 1000; i++ { + tree.Insert(&rtp.Packet{Header: rtp.Header{SequenceNumber: safeUint16(i + 1000)}}) + _, _ = tree.Pop() + } + + // Force GC again + runtime.GC() + runtime.ReadMemStats(&memStats) + + // Memory usage should be stable + assert.Less(memStats.TotalAlloc-initialAlloc, uint64(1024*1024), "Memory leak detected") +} + +// Helper functions for tree validation. +func checkRedBlackProperties(node *rbnode, assert *assert.Assertions) bool { + if node == nil { + return true + } + + if node.parent == nil && node.color != black { + assert.Fail("Root node is not black") + + return false + } + + if node.color == red { + if node.left != nil && node.left.color == red { + assert.Failf("Red node has red left child", "Node priority: %v", node.priority) + + return false + } + if node.right != nil && node.right.color == red { + assert.Failf("Red node has red right child", "Node priority: %v", node.priority) + + return false + } + } + + return checkRedBlackProperties(node.left, assert) && checkRedBlackProperties(node.right, assert) +} + +func checkBlackHeight(node *rbnode, assert *assert.Assertions) (int, bool) { + if node == nil { + return 1, true + } + + leftHeight, leftValid := checkBlackHeight(node.left, assert) + rightHeight, rightValid := checkBlackHeight(node.right, assert) + + if !leftValid || !rightValid { + return 0, false + } + + if leftHeight != rightHeight { + assert.Failf("Black height mismatch", "Node priority: %v", node.priority) + + return 0, false + } + + if node.color == black { + return leftHeight + 1, true + } + + return leftHeight, true +} + +func validateRBProperties(tree *RBTree) error { + if tree.root == nil { + return nil + } + + if tree.root.color != black { + return errRootNotBlack + } + + _, err := validateNode(tree.root, black) + + return err +} + +func validateNode(node *rbnode, parentColor treeColor) (int, error) { + if node == nil { + return 1, nil + } + + if node.color == red && parentColor == red { + return 0, errRedNodeRedParent + } + + leftHeight, err := validateNode(node.left, node.color) + if err != nil { + return 0, err + } + + rightHeight, err := validateNode(node.right, node.color) + if err != nil { + return 0, err + } + + if leftHeight != rightHeight { + return 0, errBlackHeightMismatch + } + + if node.color == black { + return leftHeight + 1, nil + } + + return leftHeight, nil +} diff --git a/pkg/jitterbuffer/receiver_interceptor.go b/pkg/jitterbuffer/receiver_interceptor.go index cd133e25..521514d4 100644 --- a/pkg/jitterbuffer/receiver_interceptor.go +++ b/pkg/jitterbuffer/receiver_interceptor.go @@ -4,6 +4,7 @@ package jitterbuffer import ( + "errors" "sync" "github.com/pion/interceptor" @@ -16,11 +17,14 @@ type InterceptorFactory struct { opts []ReceiverInterceptorOption } -// NewInterceptor constructs a new ReceiverInterceptor. -func (g *InterceptorFactory) NewInterceptor(_ string) (interceptor.Interceptor, error) { +// NewInterceptor constructs a new ReceiverInterceptor with jitter buffer. +func (g *InterceptorFactory) NewInterceptor(logName string) (interceptor.Interceptor, error) { + if logName == "" { + logName = "jitterbuffer" + } i := &ReceiverInterceptor{ close: make(chan struct{}), - log: logging.NewDefaultLoggerFactory().NewLogger("jitterbuffer"), + log: logging.NewDefaultLoggerFactory().NewLogger(logName), buffer: New(), } @@ -52,11 +56,11 @@ func (g *InterceptorFactory) NewInterceptor(_ string) (interceptor.Interceptor, // arriving) quickly enough. type ReceiverInterceptor struct { interceptor.NoOp - buffer *JitterBuffer - m sync.Mutex - wg sync.WaitGroup - close chan struct{} - log logging.LeveledLogger + buffer *JitterBuffer + wg sync.WaitGroup + close chan struct{} + log logging.LeveledLogger + skipMissingPackets bool } // NewInterceptor returns a new InterceptorFactory. @@ -76,39 +80,59 @@ func (i *ReceiverInterceptor) BindRemoteStream( return n, attr, err } packet := &rtp.Packet{} - if err := packet.Unmarshal(buf); err != nil { + if err := packet.Unmarshal(buf[:n]); err != nil { return 0, nil, err } - i.m.Lock() - defer i.m.Unlock() i.buffer.Push(packet) - if i.buffer.state == Emitting { - newPkt, err := i.buffer.Pop() - if err != nil { - return 0, nil, err + if i.buffer.State() == Emitting { + return i.playout(b, n, attr) + } + + return n, attr, ErrPopWhileBuffering + }) +} + +func (i *ReceiverInterceptor) playout( + b []byte, + n int, + attr interceptor.Attributes, +) (int, interceptor.Attributes, error) { + for { + newPkt, err := i.buffer.Pop() + if err != nil { + if errors.Is(err, ErrNotFound) { + if i.skipMissingPackets { + i.log.Warn("Skipping missing packet") + i.buffer.SetPlayoutHead(i.buffer.PlayoutHead() + 1) + + continue + } } + + return 0, nil, err + } + if newPkt != nil { nlen, err := newPkt.MarshalTo(b) return nlen, attr, err } + if i.buffer.Length() == 0 { + break + } + } - return n, attr, ErrPopWhileBuffering - }) + return n, attr, ErrPopWhileBuffering } // UnbindRemoteStream is called when the Stream is removed. It can be used to clean up any data related to that track. func (i *ReceiverInterceptor) UnbindRemoteStream(_ *interceptor.StreamInfo) { defer i.wg.Wait() - i.m.Lock() - defer i.m.Unlock() i.buffer.Clear(true) } // Close closes the interceptor. func (i *ReceiverInterceptor) Close() error { defer i.wg.Wait() - i.m.Lock() - defer i.m.Unlock() i.buffer.Clear(true) return nil diff --git a/pkg/jitterbuffer/receiver_interceptor_test.go b/pkg/jitterbuffer/receiver_interceptor_test.go index 5492dd3a..74d7ff2a 100644 --- a/pkg/jitterbuffer/receiver_interceptor_test.go +++ b/pkg/jitterbuffer/receiver_interceptor_test.go @@ -80,19 +80,99 @@ func TestReceiverBuffersAndPlaysout(t *testing.T) { SenderSSRC: 123, MediaSSRC: 456, }}) - for s := 0; s < 61; s++ { + for s := 0; s < 910; s++ { stream.ReceiveRTP(&rtp.Packet{Header: rtp.Header{ - SequenceNumber: uint16(s), //nolint:gosec // G115 + SequenceNumber: safeUint16(s), }}) } // Give time for packets to be handled and stream written to. time.Sleep(50 * time.Millisecond) - for s := 0; s < 10; s++ { + for s := 0; s < 50; s++ { read := <-stream.ReadRTP() + assert.NoError(t, read.Err) seq := read.Packet.Header.SequenceNumber - assert.EqualValues(t, uint16(s), seq) //nolint:gosec // G115 + assert.EqualValues(t, safeUint16(s), seq) } assert.NoError(t, stream.Close()) err = testInterceptor.Close() assert.NoError(t, err) } + +func TestReceiverBuffersAndPlaysoutSkippingMissingPackets(t *testing.T) { + buf := bytes.Buffer{} + + factory, err := NewInterceptor( + Log(logging.NewDefaultLoggerFactory().NewLogger("test")), + WithSkipMissingPackets(), + ) + assert.NoError(t, err) + + intr, err := factory.NewInterceptor("jitterbuffer") + assert.NoError(t, err) + + assert.EqualValues(t, 0, buf.Len()) + + stream := test.NewMockStream(&interceptor.StreamInfo{ + SSRC: 123456, + ClockRate: 90000, + }, intr) + var s int16 + for s = 0; s < 420; s++ { + if s == 6 { + s++ + } + if s == 40 { + s += 20 + } + stream.ReceiveRTP(&rtp.Packet{Header: rtp.Header{ + SequenceNumber: safeUint16(int(s)), + }}) + } + + for s := 0; s < 100; s++ { + read := <-stream.ReadRTP() + if read.Err != nil { + continue + } + seq := read.Packet.Header.SequenceNumber + if s == 6 { + s++ + } + if s == 40 { + s += 20 + } + assert.EqualValues(t, safeUint16(s), seq) + } + assert.NoError(t, stream.Close()) + err = intr.Close() + assert.NoError(t, err) +} + +func TestReceiverInterceptor(t *testing.T) { + assert := assert.New(t) + jb := New(WithMinimumPacketCount(1)) // Set minimum packet count to 1 to start emitting faster + + // Test sequence number handling + for s := int16(-10); s < 10; s++ { + pkt := &rtp.Packet{ + Header: rtp.Header{ + SequenceNumber: safeUint16(int(s)), + Timestamp: safeUint32(int(s) + 1000), // Add timestamps to ensure proper ordering + }, + } + jb.Push(pkt) + } + + // Wait for buffer to transition to emitting state + for jb.state == Buffering { + time.Sleep(time.Millisecond) + } + + // Verify sequence numbers + for s := int16(-10); s < 10; s++ { + expectedSeq := safeUint16(int(s)) + pkt, err := jb.PopAtSequence(expectedSeq) + assert.NoError(err) + assert.Equal(expectedSeq, pkt.SequenceNumber) + } +}