Skip to content

Commit

Permalink
Fix deadlock in track.Bind() (pion#466)
Browse files Browse the repository at this point in the history
Occurs when read errors happen from a
driver source during a call to track.Bind()
  • Loading branch information
KW-M authored Jan 23, 2023
1 parent f8f8511 commit 5da0ebf
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 6 deletions.
17 changes: 13 additions & 4 deletions track.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ type baseTrack struct {
Source
err error
onErrorHandler func(error)
errMu sync.Mutex
mu sync.Mutex
endOnce sync.Once
kind MediaDeviceType
Expand Down Expand Up @@ -129,10 +130,10 @@ func (track *baseTrack) RID() string {
// OnEnded sets an error handler. When a track has been created and started, if an
// error occurs, handler will get called with the error given to the parameter.
func (track *baseTrack) OnEnded(handler func(error)) {
track.mu.Lock()
track.errMu.Lock()
track.onErrorHandler = handler
err := track.err
track.mu.Unlock()
track.errMu.Unlock()

if err != nil && handler != nil {
// Already errored.
Expand All @@ -144,10 +145,10 @@ func (track *baseTrack) OnEnded(handler func(error)) {

// onError is a callback when an error occurs
func (track *baseTrack) onError(err error) {
track.mu.Lock()
track.errMu.Lock()
track.err = err
handler := track.onErrorHandler
track.mu.Unlock()
track.errMu.Unlock()

if handler != nil {
track.endOnce.Do(func() {
Expand All @@ -171,6 +172,14 @@ func (track *baseTrack) bind(ctx webrtc.TrackLocalContext, specializedTrack Trac
for _, wantedCodec := range ctx.CodecParameters() {
logger.Debugf("trying to build %s rtp reader", wantedCodec.MimeType)
encodedReader, err = specializedTrack.NewRTPReader(wantedCodec.MimeType, uint32(ctx.SSRC()), rtpOutboundMTU)

track.errMu.Lock()
if track.err != nil {
err = track.err
encodedReader = nil
}
track.errMu.Unlock()

if err == nil {
selectedCodec = wantedCodec
break
Expand Down
51 changes: 49 additions & 2 deletions track_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,33 @@ package mediadevices

import (
"errors"
"github.com/pion/interceptor"
"io"
"sync"
"testing"
"time"

"github.com/pion/interceptor"
"github.com/pion/webrtc/v3"
)

var errExpected error = errors.New("an error")

type DummyBindTrack struct {
*baseTrack
}

func (track *DummyBindTrack) Bind(ctx webrtc.TrackLocalContext) (webrtc.RTPCodecParameters, error) {
track.mu.Lock()
defer track.mu.Unlock()

track.onError(errExpected)

<-time.After(5 * time.Millisecond)

return webrtc.RTPCodecParameters{}, nil
}

func TestOnEnded(t *testing.T) {
errExpected := errors.New("an error")

t.Run("ErrorAfterRegister", func(t *testing.T) {
tr := &baseTrack{}
Expand Down Expand Up @@ -54,6 +73,34 @@ func TestOnEnded(t *testing.T) {
t.Error("Timeout")
}
})

t.Run("ErrorDurringBind", func(t *testing.T) {
tr := &DummyBindTrack{
baseTrack: &baseTrack{
activePeerConnections: make(map[string]chan<- chan<- struct{}),
mu: sync.Mutex{},
},
}

called := make(chan error, 1)
tr.OnEnded(func(err error) {
called <- errExpected
})

_, err := tr.Bind(webrtc.TrackLocalContext{})
if err != nil {
t.Fatal(err)
}

select {
case err := <-called:
if err != errExpected {
t.Errorf("Expected to receive error: %v, got: %v", errExpected, err)
}
case <-time.After(10 * time.Millisecond):
t.Error("Timeout")
}
})
}

type fakeRTCPReader struct {
Expand Down

0 comments on commit 5da0ebf

Please sign in to comment.