From 5da0ebf443724eae894a22a59a48344a06154406 Mon Sep 17 00:00:00 2001 From: Kyle <25kylworc@gmail.com> Date: Sun, 22 Jan 2023 17:00:20 -0800 Subject: [PATCH] Fix deadlock in track.Bind() (#466) Occurs when read errors happen from a driver source during a call to track.Bind() --- track.go | 17 +++++++++++++---- track_test.go | 51 +++++++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 62 insertions(+), 6 deletions(-) diff --git a/track.go b/track.go index 32d13d1f..121c082b 100644 --- a/track.go +++ b/track.go @@ -83,6 +83,7 @@ type baseTrack struct { Source err error onErrorHandler func(error) + errMu sync.Mutex mu sync.Mutex endOnce sync.Once kind MediaDeviceType @@ -129,10 +130,10 @@ func (track *baseTrack) RID() string { // OnEnded sets an error handler. When a track has been created and started, if an // error occurs, handler will get called with the error given to the parameter. func (track *baseTrack) OnEnded(handler func(error)) { - track.mu.Lock() + track.errMu.Lock() track.onErrorHandler = handler err := track.err - track.mu.Unlock() + track.errMu.Unlock() if err != nil && handler != nil { // Already errored. @@ -144,10 +145,10 @@ func (track *baseTrack) OnEnded(handler func(error)) { // onError is a callback when an error occurs func (track *baseTrack) onError(err error) { - track.mu.Lock() + track.errMu.Lock() track.err = err handler := track.onErrorHandler - track.mu.Unlock() + track.errMu.Unlock() if handler != nil { track.endOnce.Do(func() { @@ -171,6 +172,14 @@ func (track *baseTrack) bind(ctx webrtc.TrackLocalContext, specializedTrack Trac for _, wantedCodec := range ctx.CodecParameters() { logger.Debugf("trying to build %s rtp reader", wantedCodec.MimeType) encodedReader, err = specializedTrack.NewRTPReader(wantedCodec.MimeType, uint32(ctx.SSRC()), rtpOutboundMTU) + + track.errMu.Lock() + if track.err != nil { + err = track.err + encodedReader = nil + } + track.errMu.Unlock() + if err == nil { selectedCodec = wantedCodec break diff --git a/track_test.go b/track_test.go index d01bdc6f..7ff6d84e 100644 --- a/track_test.go +++ b/track_test.go @@ -2,14 +2,33 @@ package mediadevices import ( "errors" - "github.com/pion/interceptor" "io" + "sync" "testing" "time" + + "github.com/pion/interceptor" + "github.com/pion/webrtc/v3" ) +var errExpected error = errors.New("an error") + +type DummyBindTrack struct { + *baseTrack +} + +func (track *DummyBindTrack) Bind(ctx webrtc.TrackLocalContext) (webrtc.RTPCodecParameters, error) { + track.mu.Lock() + defer track.mu.Unlock() + + track.onError(errExpected) + + <-time.After(5 * time.Millisecond) + + return webrtc.RTPCodecParameters{}, nil +} + func TestOnEnded(t *testing.T) { - errExpected := errors.New("an error") t.Run("ErrorAfterRegister", func(t *testing.T) { tr := &baseTrack{} @@ -54,6 +73,34 @@ func TestOnEnded(t *testing.T) { t.Error("Timeout") } }) + + t.Run("ErrorDurringBind", func(t *testing.T) { + tr := &DummyBindTrack{ + baseTrack: &baseTrack{ + activePeerConnections: make(map[string]chan<- chan<- struct{}), + mu: sync.Mutex{}, + }, + } + + called := make(chan error, 1) + tr.OnEnded(func(err error) { + called <- errExpected + }) + + _, err := tr.Bind(webrtc.TrackLocalContext{}) + if err != nil { + t.Fatal(err) + } + + select { + case err := <-called: + if err != errExpected { + t.Errorf("Expected to receive error: %v, got: %v", errExpected, err) + } + case <-time.After(10 * time.Millisecond): + t.Error("Timeout") + } + }) } type fakeRTCPReader struct {