Skip to content

Commit 1d50bcf

Browse files
Sean-Dercptpcrd
andcommitted
Add Peek to ReadStreamSRTP
This will be used while probing for PayloadType and Simulcast. Relates to pion/webrtc#2777 Co-authored-by: cptpcrd <[email protected]>
1 parent c515598 commit 1d50bcf

File tree

2 files changed

+125
-1
lines changed

2 files changed

+125
-1
lines changed

stream_srtp.go

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package srtp
66
import (
77
"errors"
88
"io"
9+
"slices"
910
"sync"
1011
"time"
1112

@@ -26,7 +27,8 @@ type ReadStreamSRTP struct {
2627
ssrc uint32
2728
isInited bool
2829

29-
buffer io.ReadWriteCloser
30+
buffer io.ReadWriteCloser
31+
peekedPackets [][]byte
3032
}
3133

3234
// Used by getOrCreateReadStream.
@@ -74,8 +76,31 @@ func (r *ReadStreamSRTP) write(buf []byte) (n int, err error) {
7476
return n, err
7577
}
7678

79+
// Peek reads and decrypts full RTP packet from the nextConn.
80+
// It is then buffered so that a call to `Read` will return it.
81+
func (r *ReadStreamSRTP) Peek(buf []byte) (n int, err error) {
82+
n, err = r.buffer.Read(buf)
83+
if err == nil {
84+
r.peekedPackets = append(r.peekedPackets, slices.Clone(buf[:n]))
85+
}
86+
87+
return
88+
}
89+
7790
// Read reads and decrypts full RTP packet from the nextConn.
7891
func (r *ReadStreamSRTP) Read(buf []byte) (int, error) {
92+
if len(r.peekedPackets) != 0 {
93+
if len(r.peekedPackets[0]) > len(buf) {
94+
return 0, io.ErrShortBuffer
95+
}
96+
97+
n := len(r.peekedPackets[0])
98+
copy(buf, r.peekedPackets[0])
99+
r.peekedPackets = r.peekedPackets[1:]
100+
101+
return n, nil
102+
}
103+
79104
return r.buffer.Read(buf)
80105
}
81106

stream_srtp_test.go

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,105 @@ func (c *noopConn) SetDeadline(time.Time) error { return nil }
3535
func (c *noopConn) SetReadDeadline(time.Time) error { return nil }
3636
func (c *noopConn) SetWriteDeadline(time.Time) error { return nil }
3737

38+
func TestPeek(t *testing.T) {
39+
firstBuffer := []byte{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}
40+
secondBuffer := []byte{0xBB, 0xBB, 0xBB}
41+
thirdBuffer := []byte{0xCC, 0xCC, 0xCC}
42+
43+
buffer := packetio.NewBuffer()
44+
stream := &ReadStreamSRTP{buffer: buffer}
45+
46+
t.Run("Short Peek", func(t *testing.T) {
47+
_, err := buffer.Write(firstBuffer)
48+
assert.NoError(t, err)
49+
50+
readBuff := make([]byte, 1)
51+
_, err = stream.Peek(readBuff)
52+
assert.Error(t, err, io.ErrShortBuffer)
53+
})
54+
55+
t.Run("Short Read", func(t *testing.T) {
56+
_, err := buffer.Write(firstBuffer)
57+
assert.NoError(t, err)
58+
59+
readBuff := make([]byte, 6)
60+
n, err := stream.Peek(readBuff)
61+
assert.NoError(t, err)
62+
assert.Equal(t, n, 6)
63+
assert.Equal(t, readBuff, firstBuffer)
64+
65+
n, err = stream.Read([]byte{})
66+
assert.Error(t, err, io.ErrShortBuffer)
67+
assert.Equal(t, n, 0)
68+
assert.Equal(t, readBuff, firstBuffer)
69+
70+
n, err = stream.Read(readBuff)
71+
assert.NoError(t, err)
72+
assert.Equal(t, n, 6)
73+
assert.Equal(t, readBuff, firstBuffer)
74+
})
75+
76+
t.Run("Single Peek", func(t *testing.T) {
77+
_, err := buffer.Write(firstBuffer)
78+
assert.NoError(t, err)
79+
80+
readBuff := make([]byte, 6)
81+
82+
n, err := stream.Peek(readBuff)
83+
assert.NoError(t, err)
84+
assert.Equal(t, n, 6)
85+
assert.Equal(t, readBuff, firstBuffer)
86+
87+
n, err = stream.Read(readBuff)
88+
assert.NoError(t, err)
89+
assert.Equal(t, n, 6)
90+
assert.Equal(t, readBuff, firstBuffer)
91+
})
92+
93+
t.Run("Multi Peek", func(t *testing.T) {
94+
_, err := buffer.Write(firstBuffer)
95+
assert.NoError(t, err)
96+
97+
_, err = buffer.Write(secondBuffer)
98+
assert.NoError(t, err)
99+
100+
_, err = buffer.Write(thirdBuffer)
101+
assert.NoError(t, err)
102+
103+
readBuff := make([]byte, 6)
104+
105+
n, err := stream.Peek(readBuff)
106+
assert.NoError(t, err)
107+
assert.Equal(t, n, 6)
108+
assert.Equal(t, readBuff[:n], firstBuffer)
109+
110+
n, err = stream.Peek(readBuff)
111+
assert.NoError(t, err)
112+
assert.Equal(t, n, 3)
113+
assert.Equal(t, readBuff[:n], secondBuffer)
114+
115+
n, err = stream.Peek(readBuff)
116+
assert.NoError(t, err)
117+
assert.Equal(t, n, 3)
118+
assert.Equal(t, readBuff[:n], thirdBuffer)
119+
120+
n, err = stream.Read(readBuff)
121+
assert.NoError(t, err)
122+
assert.Equal(t, n, 6)
123+
assert.Equal(t, readBuff[:n], firstBuffer)
124+
125+
n, err = stream.Read(readBuff)
126+
assert.NoError(t, err)
127+
assert.Equal(t, n, 3)
128+
assert.Equal(t, readBuff[:n], secondBuffer)
129+
130+
n, err = stream.Read(readBuff)
131+
assert.NoError(t, err)
132+
assert.Equal(t, n, 3)
133+
assert.Equal(t, readBuff[:n], thirdBuffer)
134+
})
135+
}
136+
38137
func TestBufferFactory(t *testing.T) {
39138
wg := sync.WaitGroup{}
40139
wg.Add(2)

0 commit comments

Comments
 (0)