Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions peerconnection.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"crypto/rand"
"errors"
"fmt"
"io"
"slices"
"strconv"
"strings"
Expand Down Expand Up @@ -1688,7 +1687,7 @@ func (pc *PeerConnection) handleNonMediaBandwidthProbe() {
}
}

func (pc *PeerConnection) handleIncomingSSRC(rtpStream io.Reader, ssrc SSRC) error { //nolint:gocyclo,gocognit,cyclop
func (pc *PeerConnection) handleIncomingSSRC(rtpStream *srtp.ReadStreamSRTP, ssrc SSRC) error { //nolint:gocyclo,gocognit,cyclop,lll
remoteDescription := pc.RemoteDescription()
if remoteDescription == nil {
return errPeerConnRemoteDescriptionNil
Expand Down Expand Up @@ -1725,7 +1724,7 @@ func (pc *PeerConnection) handleIncomingSSRC(rtpStream io.Reader, ssrc SSRC) err
// We read the RTP packet to determine the payload type
b := make([]byte, pc.api.settingEngine.getReceiveMTU())

i, err := rtpStream.Read(b)
i, err := rtpStream.Peek(b)
if err != nil {
return err
}
Expand Down Expand Up @@ -1802,6 +1801,8 @@ func (pc *PeerConnection) handleIncomingSSRC(rtpStream io.Reader, ssrc SSRC) err
return err
}

peekedPackets := []*peekedPacket{}

// if the first packet didn't contain simuilcast IDs, then probe more packets
var paddingOnly bool
for readCount := 0; readCount <= simulcastProbeCount; readCount++ {
Expand All @@ -1811,11 +1812,16 @@ func (pc *PeerConnection) handleIncomingSSRC(rtpStream io.Reader, ssrc SSRC) err
readCount--
}

i, _, err := interceptor.Read(b, nil)
i, attributes, err := interceptor.Read(b, nil)
if err != nil {
return err
}

peekedPackets = append(peekedPackets, &peekedPacket{
payload: slices.Clone(b[:i]),
attributes: attributes,
})

if paddingOnly, err = handleUnknownRTPPacket(
b[:i], uint8(midExtensionID), //nolint:gosec // G115
uint8(streamIDExtensionID), //nolint:gosec // G115
Expand Down Expand Up @@ -1851,6 +1857,7 @@ func (pc *PeerConnection) handleIncomingSSRC(rtpStream io.Reader, ssrc SSRC) err
interceptor,
rtcpReadStream,
rtcpInterceptor,
peekedPackets,
)
if err != nil {
return err
Expand Down Expand Up @@ -1930,7 +1937,7 @@ func (pc *PeerConnection) undeclaredRTPMediaProcessor() { //nolint:cyclop
continue
}

go func(rtpStream io.Reader, ssrc SSRC) {
go func(rtpStream *srtp.ReadStreamSRTP, ssrc SSRC) {
if err := pc.handleIncomingSSRC(rtpStream, ssrc); err != nil {
pc.log.Errorf(incomingUnhandledRTPSsrc, ssrc, err)
}
Expand Down
107 changes: 104 additions & 3 deletions peerconnection_media_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"bufio"
"bytes"
"context"
"crypto/rand"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -2062,14 +2063,13 @@ func TestPeerConnection_Simulcast_RTX(t *testing.T) { //nolint:cyclop
assert.NotZero(t, ridID)
assert.NotZero(t, rsid)

err = signalPairWithModification(pcOffer, pcAnswer, func(sdp string) string {
assert.NoError(t, signalPairWithModification(pcOffer, pcAnswer, func(sdp string) string {
// Original chrome sdp contains no ssrc info https://pastebin.com/raw/JTjX6zg6
re := regexp.MustCompile("(?m)[\r\n]+^.*a=ssrc.*$")
res := re.ReplaceAllString(sdp, "")

return res
})
assert.NoError(t, err)
}))

// padding only packets should not affect simulcast probe
var sequenceNumber uint16
Expand Down Expand Up @@ -2493,3 +2493,104 @@ func Test_PeerConnection_RTX_E2E(t *testing.T) { //nolint:cyclop
closePairNow(t, pcOffer, pcAnswer)
assert.NoError(t, wan.Stop())
}

// Assert that we don't drop any packets during the probe.
func TestPeerConnection_Simulcast_Probe_PacketLoss(t *testing.T) { //nolint:cyclop
lim := test.TimeOut(time.Second * 30)
defer lim.Stop()

report := test.CheckRoutines(t)
defer report()

const rtpPktCount = 10
pcOffer, pcAnswer, wan := createVNetPair(t, nil)

rids := []string{"a", "b", "c"}
vp8WriterA, err := NewTrackLocalStaticRTP(
RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion2", WithRTPStreamID(rids[0]),
)
assert.NoError(t, err)

vp8WriterB, err := NewTrackLocalStaticRTP(
RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion2", WithRTPStreamID(rids[1]),
)
assert.NoError(t, err)

vp8WriterC, err := NewTrackLocalStaticRTP(
RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion2", WithRTPStreamID(rids[2]),
)
assert.NoError(t, err)

sender, err := pcOffer.AddTrack(vp8WriterA)
assert.NoError(t, err)
assert.NotNil(t, sender)

assert.NoError(t, sender.AddEncoding(vp8WriterB))
assert.NoError(t, sender.AddEncoding(vp8WriterC))

expectedBuffer := make([]byte, outboundMTU*rtpPktCount)
_, err = rand.Read(expectedBuffer)
assert.NoError(t, err)

ctx, cancel := context.WithCancel(context.Background())
pcAnswer.OnTrack(func(trackRemote *TrackRemote, _ *RTPReceiver) {
actualBuffer := []byte{}

for i := 0; i < rtpPktCount; i++ {
pkt, _, err := trackRemote.ReadRTP()
assert.NoError(t, err)

actualBuffer = append(actualBuffer, pkt.Payload...)
}

assert.Equal(t, actualBuffer, expectedBuffer)
cancel()
})

var midID, ridID uint8
for _, extension := range sender.GetParameters().HeaderExtensions {
switch extension.URI {
case sdp.SDESMidURI:
midID = uint8(extension.ID) //nolint:gosec // G115
case sdp.SDESRTPStreamIDURI:
ridID = uint8(extension.ID) //nolint:gosec // G115
}
}
assert.NotZero(t, midID)
assert.NotZero(t, ridID)

assert.NoError(t, signalPairWithModification(pcOffer, pcAnswer, func(sdp string) string {
// Original chrome sdp contains no ssrc info https://pastebin.com/raw/JTjX6zg6
re := regexp.MustCompile("(?m)[\r\n]+^.*a=ssrc.*$")
res := re.ReplaceAllString(sdp, "")

return res
}))

peerConnectionConnected := untilConnectionState(PeerConnectionStateConnected, pcOffer, pcAnswer)
peerConnectionConnected.Wait()

for sequenceNumber := uint16(0); sequenceNumber < rtpPktCount; sequenceNumber++ {
pkt := &rtp.Packet{
Header: rtp.Header{
Version: 2,
PayloadType: 96,
SequenceNumber: sequenceNumber,
},
}

// Make sure that packets for Stream received before MID/RID don't get dropped
if sequenceNumber > 3 {
assert.NoError(t, pkt.SetExtension(midID, []byte("0")))
assert.NoError(t, pkt.SetExtension(ridID, []byte(vp8WriterA.RID())))
}

offset := int(sequenceNumber) * outboundMTU
pkt.Payload = expectedBuffer[offset : offset+outboundMTU]
assert.NoError(t, vp8WriterA.WriteRTP(pkt))
}

<-ctx.Done()
assert.NoError(t, wan.Stop())
closePairNow(t, pcOffer, pcAnswer)
}
6 changes: 0 additions & 6 deletions peerconnection_renegotiation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1105,12 +1105,6 @@ func TestPeerConnection_Renegotiation_Simulcast(t *testing.T) {

for _, track := range trackMap {
_, _, err := track.ReadRTP()

// Ignore first Read, this was our peeked data
if err == nil {
_, _, err = track.ReadRTP()
}

assert.Equal(t, err, io.EOF)
}
}
Expand Down
37 changes: 23 additions & 14 deletions rtpreceiver.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"io"
"math"
"sync"
"sync/atomic"
"time"

"github.com/pion/interceptor"
Expand Down Expand Up @@ -64,8 +65,9 @@ type RTPReceiver struct {

tracks []trackStreams

closed, received chan any
mu sync.RWMutex
closed atomic.Bool
closedChan, received chan any
mu sync.RWMutex

tr *RTPTransceiver

Expand All @@ -84,12 +86,12 @@ func (api *API) NewRTPReceiver(kind RTPCodecType, transport *DTLSTransport) (*RT
}

rtpReceiver := &RTPReceiver{
kind: kind,
transport: transport,
api: api,
closed: make(chan any),
received: make(chan any),
tracks: []trackStreams{},
kind: kind,
transport: transport,
api: api,
closedChan: make(chan any),
received: make(chan any),
tracks: []trackStreams{},
rtxPool: sync.Pool{New: func() any {
return make([]byte, api.settingEngine.getReceiveMTU())
}},
Expand Down Expand Up @@ -290,7 +292,7 @@ func (r *RTPReceiver) Read(b []byte) (n int, a interceptor.Attributes, err error
}

return r.tracks[0].rtcpInterceptor.Read(b, a)
case <-r.closed:
case <-r.closedChan:
return 0, nil, io.ErrClosedPipe
}
}
Expand All @@ -315,7 +317,7 @@ func (r *RTPReceiver) ReadSimulcast(b []byte, rid string) (n int, a interceptor.

return rtcpInterceptor.Read(b, a)

case <-r.closed:
case <-r.closedChan:
return 0, nil, io.ErrClosedPipe
}
}
Expand Down Expand Up @@ -359,14 +361,18 @@ func (r *RTPReceiver) haveReceived() bool {
}
}

func (r *RTPReceiver) haveClosed() bool {
return r.closed.Load()
}

// Stop irreversibly stops the RTPReceiver.
func (r *RTPReceiver) Stop() error { //nolint:cyclop
r.mu.Lock()
defer r.mu.Unlock()
var err error

select {
case <-r.closed:
case <-r.closedChan:
return err
default:
}
Expand Down Expand Up @@ -405,7 +411,8 @@ func (r *RTPReceiver) Stop() error { //nolint:cyclop
default:
}

close(r.closed)
close(r.closedChan)
r.closed.Store(true)

return err
}
Expand Down Expand Up @@ -519,7 +526,7 @@ func (r *RTPReceiver) streamsForTrack(t *TrackRemote) *trackStreams {
func (r *RTPReceiver) readRTP(b []byte, reader *TrackRemote) (n int, a interceptor.Attributes, err error) {
select {
case <-r.received:
case <-r.closed:
case <-r.closedChan:
return 0, nil, io.EOF
}

Expand All @@ -540,6 +547,7 @@ func (r *RTPReceiver) receiveForRid(
rtpInterceptor interceptor.RTPReader,
rtcpReadStream *srtp.ReadStreamSRTCP,
rtcpInterceptor interceptor.RTCPReader,
peekedPackets []*peekedPacket,
) (*TrackRemote, error) {
r.mu.Lock()
defer r.mu.Unlock()
Expand All @@ -551,6 +559,7 @@ func (r *RTPReceiver) receiveForRid(
r.tracks[i].track.codec = params.Codecs[0]
r.tracks[i].track.params = params
r.tracks[i].track.ssrc = SSRC(streamInfo.SSRC)
r.tracks[i].track.peekedPackets = peekedPackets
r.tracks[i].track.mu.Unlock()

r.tracks[i].streamInfo = streamInfo
Expand Down Expand Up @@ -651,7 +660,7 @@ func (r *RTPReceiver) receiveForRtx(
copy(b[headerLength:i-2], b[headerLength+2:i])

select {
case <-r.closed:
case <-r.closedChan:
r.rtxPool.Put(b) // nolint:staticcheck

return
Expand Down
2 changes: 1 addition & 1 deletion stats_go.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ func (p *defaultAudioPlayoutStatsProvider) AddTrack(track *TrackRemote) error {
}

select {
case <-receiver.closed:
case <-receiver.closedChan:
p.removeTrackInternal(track)
case <-ctx.Done():
return
Expand Down
4 changes: 2 additions & 2 deletions stats_go_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2349,7 +2349,7 @@ func TestDefaultAudioPlayoutStatsProvider_AccumulateSnapshot(t *testing.T) {
}

func TestDefaultAudioPlayoutStatsProvider_AddRemoveTrack(t *testing.T) {
receiver := &RTPReceiver{closed: make(chan any)}
receiver := &RTPReceiver{closedChan: make(chan any)}
track := newTrackRemote(RTPCodecTypeAudio, 1234, 0, "", receiver)
samplesPerBatch := 960

Expand All @@ -2371,7 +2371,7 @@ func TestDefaultAudioPlayoutStatsProvider_AddRemoveTrack(t *testing.T) {
}

func TestDefaultAudioPlayoutStatsProvider_MultipleProviders(t *testing.T) {
receiver := &RTPReceiver{closed: make(chan any)}
receiver := &RTPReceiver{closedChan: make(chan any)}
track := newTrackRemote(RTPCodecTypeAudio, 5555, 0, "", receiver)
samplesPerBatch := 960

Expand Down
3 changes: 1 addition & 2 deletions track_local_static_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -827,8 +827,7 @@ func Test_TrackRemote_ReadRTP_UnmarshalError(t *testing.T) {
tr := newTrackRemote(RTPCodecTypeVideo, 0, 0, "", recv)

tr.mu.Lock()
tr.peeked = []byte{0x80, 96}
tr.peekedAttributes = nil
tr.peekedPackets = []*peekedPacket{{payload: []byte{0x80, 96}}}
tr.mu.Unlock()

pkt, attrs, err := tr.ReadRTP()
Expand Down
Loading
Loading