diff --git a/mediadeviceinfo.go b/mediadeviceinfo.go index 21c119e0..ce6c4fa5 100644 --- a/mediadeviceinfo.go +++ b/mediadeviceinfo.go @@ -7,7 +7,7 @@ type MediaDeviceType int // MediaDeviceType definitions. const ( - VideoInput MediaDeviceType = iota + VideoInput MediaDeviceType = iota + 1 AudioInput AudioOutput ) diff --git a/mediastream.go b/mediastream.go index d8520903..6b99239b 100644 --- a/mediastream.go +++ b/mediastream.go @@ -2,8 +2,6 @@ package mediadevices import ( "sync" - - "github.com/pion/webrtc/v2" ) // MediaStream is an interface that represents a collection of existing tracks. @@ -21,21 +19,20 @@ type MediaStream interface { } type mediaStream struct { - trackers map[string]Tracker + trackers map[Tracker]struct{} l sync.RWMutex } -const rtpCodecTypeDefault webrtc.RTPCodecType = 0 +const trackTypeDefault MediaDeviceType = 0 // NewMediaStream creates a MediaStream interface that's defined in // https://w3c.github.io/mediacapture-main/#dom-mediastream func NewMediaStream(trackers ...Tracker) (MediaStream, error) { - m := mediaStream{trackers: make(map[string]Tracker)} + m := mediaStream{trackers: make(map[Tracker]struct{})} for _, tracker := range trackers { - id := tracker.LocalTrack().ID() - if _, ok := m.trackers[id]; !ok { - m.trackers[id] = tracker + if _, ok := m.trackers[tracker]; !ok { + m.trackers[tracker] = struct{}{} } } @@ -43,26 +40,26 @@ func NewMediaStream(trackers ...Tracker) (MediaStream, error) { } func (m *mediaStream) GetAudioTracks() []Tracker { - return m.queryTracks(webrtc.RTPCodecTypeAudio) + return m.queryTracks(AudioInput) } func (m *mediaStream) GetVideoTracks() []Tracker { - return m.queryTracks(webrtc.RTPCodecTypeVideo) + return m.queryTracks(VideoInput) } func (m *mediaStream) GetTracks() []Tracker { - return m.queryTracks(rtpCodecTypeDefault) + return m.queryTracks(trackTypeDefault) } // queryTracks returns all tracks that are the same kind as t. // If t is 0, which is the default, queryTracks will return all the tracks. -func (m *mediaStream) queryTracks(t webrtc.RTPCodecType) []Tracker { +func (m *mediaStream) queryTracks(t MediaDeviceType) []Tracker { m.l.RLock() defer m.l.RUnlock() result := make([]Tracker, 0) - for _, tracker := range m.trackers { - if tracker.LocalTrack().Kind() == t || t == rtpCodecTypeDefault { + for tracker := range m.trackers { + if tracker.Kind() == t || t == trackTypeDefault { result = append(result, tracker) } } @@ -74,17 +71,16 @@ func (m *mediaStream) AddTrack(t Tracker) { m.l.Lock() defer m.l.Unlock() - id := t.LocalTrack().ID() - if _, ok := m.trackers[id]; ok { + if _, ok := m.trackers[t]; ok { return } - m.trackers[id] = t + m.trackers[t] = struct{}{} } func (m *mediaStream) RemoveTrack(t Tracker) { m.l.Lock() defer m.l.Unlock() - delete(m.trackers, t.LocalTrack().ID()) + delete(m.trackers, t) } diff --git a/mediastream_test.go b/mediastream_test.go new file mode 100644 index 00000000..c86b28be --- /dev/null +++ b/mediastream_test.go @@ -0,0 +1,83 @@ +package mediadevices + +import ( + "testing" + + "github.com/pion/webrtc/v2" +) + +type mockMediaStreamTrack struct { + kind MediaDeviceType +} + +func (track *mockMediaStreamTrack) Track() *webrtc.Track { + return nil +} + +func (track *mockMediaStreamTrack) LocalTrack() LocalTrack { + return nil +} + +func (track *mockMediaStreamTrack) Stop() { +} + +func (track *mockMediaStreamTrack) Kind() MediaDeviceType { + return track.kind +} + +func (track *mockMediaStreamTrack) OnEnded(handler func(error)) { +} + +func TestMediaStreamFilters(t *testing.T) { + audioTracks := []Tracker{ + &mockMediaStreamTrack{AudioInput}, + &mockMediaStreamTrack{AudioInput}, + &mockMediaStreamTrack{AudioInput}, + &mockMediaStreamTrack{AudioInput}, + &mockMediaStreamTrack{AudioInput}, + } + + videoTracks := []Tracker{ + &mockMediaStreamTrack{VideoInput}, + &mockMediaStreamTrack{VideoInput}, + &mockMediaStreamTrack{VideoInput}, + } + + tracks := append(audioTracks, videoTracks...) + stream, err := NewMediaStream(tracks...) + if err != nil { + t.Fatal(err) + } + + expect := func(t *testing.T, actual, expected []Tracker) { + if len(actual) != len(expected) { + t.Fatalf("%s: Expected to get %d trackers, but got %d trackers", t.Name(), len(expected), len(actual)) + } + + for _, a := range actual { + found := false + for _, e := range expected { + if e == a { + found = true + break + } + } + + if !found { + t.Fatalf("%s: Expected to find %p in the query results", t.Name(), a) + } + } + } + + t.Run("GetAudioTracks", func(t *testing.T) { + expect(t, stream.GetAudioTracks(), audioTracks) + }) + + t.Run("GetVideoTracks", func(t *testing.T) { + expect(t, stream.GetVideoTracks(), videoTracks) + }) + + t.Run("GetTracks", func(t *testing.T) { + expect(t, stream.GetTracks(), tracks) + }) +} diff --git a/track.go b/track.go index 74cba6f3..8ca69cee 100644 --- a/track.go +++ b/track.go @@ -18,6 +18,7 @@ type Tracker interface { Track() *webrtc.Track LocalTrack() LocalTrack Stop() + Kind() MediaDeviceType // OnEnded registers a handler to receive an error from the media stream track. // If the error is already occured before registering, the handler will be // immediately called. @@ -41,12 +42,14 @@ type track struct { err error mu sync.Mutex endOnce sync.Once + kind MediaDeviceType } func newTrack(opts *MediaDevicesOptions, d driver.Driver, constraints MediaTrackConstraints) (*track, error) { var encoderBuilders []encoderBuilder var rtpCodecs []*webrtc.RTPCodec var buildSampler func(t LocalTrack) samplerFunc + var kind MediaDeviceType var err error err = d.Open() @@ -56,10 +59,12 @@ func newTrack(opts *MediaDevicesOptions, d driver.Driver, constraints MediaTrack switch r := d.(type) { case driver.VideoRecorder: + kind = VideoInput rtpCodecs = opts.codecs[webrtc.RTPCodecTypeVideo] buildSampler = newVideoSampler encoderBuilders, err = newVideoEncoderBuilders(r, constraints) case driver.AudioRecorder: + kind = AudioInput rtpCodecs = opts.codecs[webrtc.RTPCodecTypeAudio] buildSampler = func(t LocalTrack) samplerFunc { return newAudioSampler(t, constraints.selectedMedia.Latency) @@ -108,6 +113,7 @@ func newTrack(opts *MediaDevicesOptions, d driver.Driver, constraints MediaTrack sample: buildSampler(localTrack), d: d, encoder: encoder, + kind: kind, } go t.start() return &t, nil @@ -117,6 +123,11 @@ func newTrack(opts *MediaDevicesOptions, d driver.Driver, constraints MediaTrack return nil, errors.New("newTrack: failed to find a matching codec") } +// Kind returns track's kind +func (t *track) Kind() MediaDeviceType { + return t.kind +} + // OnEnded sets an error handler. When a track has been created and started, if an // error occurs, handler will get called with the error given to the parameter. func (t *track) OnEnded(handler func(error)) {