diff --git a/go.mod b/go.mod index 563fb925..929d2535 100644 --- a/go.mod +++ b/go.mod @@ -15,12 +15,13 @@ require ( github.com/livekit/protocol v1.36.2-0.20250408143132-c193b8d080da github.com/livekit/psrpc v0.6.1-0.20250205181828-a0beed2e4126 github.com/livekit/server-sdk-go/v2 v2.5.0 - github.com/livekit/sipgo v0.13.2-0.20250130142851-36ed3228d934 + github.com/livekit/sipgo v0.13.2-0.20250410120437-ca5b8ca7b53d github.com/mjibson/go-dsp v0.0.0-20180508042940-11479a337f12 github.com/ory/dockertest/v3 v3.11.0 github.com/pion/interceptor v0.1.37 github.com/pion/rtp v1.8.11 github.com/pion/sdp/v3 v3.0.10 + github.com/pion/srtp/v3 v3.0.4 github.com/pion/webrtc/v4 v4.0.9 github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.20.5 @@ -98,7 +99,6 @@ require ( github.com/pion/randutil v0.1.0 // indirect github.com/pion/rtcp v1.2.15 // indirect github.com/pion/sctp v1.8.35 // indirect - github.com/pion/srtp/v3 v3.0.4 // indirect github.com/pion/stun/v3 v3.0.0 // indirect github.com/pion/transport/v3 v3.0.7 // indirect github.com/pion/turn/v4 v4.0.0 // indirect diff --git a/go.sum b/go.sum index 3c9ad814..fc1dfc44 100644 --- a/go.sum +++ b/go.sum @@ -133,8 +133,8 @@ github.com/livekit/psrpc v0.6.1-0.20250205181828-a0beed2e4126 h1:fzuYpAQbCid7ySP github.com/livekit/psrpc v0.6.1-0.20250205181828-a0beed2e4126/go.mod h1:X5WtEZ7OnEs72Fi5/J+i0on3964F1aynQpCalcgMqRo= github.com/livekit/server-sdk-go/v2 v2.5.0 h1:HCKm3f6PvefGp8emNC2mi9+9IXzBYrynuGbtUdp5u+w= github.com/livekit/server-sdk-go/v2 v2.5.0/go.mod h1:98/Sa+Wgb27ABwu0WYxLaMZaRfGljrrtoZDQ2xA4oVg= -github.com/livekit/sipgo v0.13.2-0.20250130142851-36ed3228d934 h1:BKDNIg729VUlRCqQ0dNXFFxuEvxBBIPcqbRADPfkz54= -github.com/livekit/sipgo v0.13.2-0.20250130142851-36ed3228d934/go.mod h1:nbNi0IsYn4tyY2ab7Rafvifty07miHYvgedPMKWbaI4= +github.com/livekit/sipgo v0.13.2-0.20250410120437-ca5b8ca7b53d h1:x3JSKtsQpWt/fynro+s7MusrOqIcaGdCnuSXKGEajXc= +github.com/livekit/sipgo v0.13.2-0.20250410120437-ca5b8ca7b53d/go.mod h1:nbNi0IsYn4tyY2ab7Rafvifty07miHYvgedPMKWbaI4= github.com/mackerelio/go-osstat v0.2.5 h1:+MqTbZUhoIt4m8qzkVoXUJg1EuifwlAJSk4Yl2GXh+o= github.com/mackerelio/go-osstat v0.2.5/go.mod h1:atxwWF+POUZcdtR1wnsUcQxTytoHG4uhl2AKKzrOajY= github.com/magefile/mage v1.15.0 h1:BvGheCMAsG3bWUDbZ8AyXXpCNwU9u5CB6sM+HNb9HYg= diff --git a/pkg/media/dtmf/dtmf.go b/pkg/media/dtmf/dtmf.go index e89998ab..8deebcc3 100644 --- a/pkg/media/dtmf/dtmf.go +++ b/pkg/media/dtmf/dtmf.go @@ -161,11 +161,11 @@ func Decode(data []byte) (Event, error) { }, nil } -func DecodeRTP(p *rtp.Packet) (Event, bool) { - if !p.Marker { +func DecodeRTP(h *rtp.Header, payload []byte) (Event, bool) { + if !h.Marker { return Event{}, false } - ev, err := Decode(p.Payload) + ev, err := Decode(payload) if err != nil { return Event{}, false } diff --git a/pkg/media/rtp/conn.go b/pkg/media/rtp/conn.go index 626c67cb..faf52a24 100644 --- a/pkg/media/rtp/conn.go +++ b/pkg/media/rtp/conn.go @@ -16,17 +16,21 @@ package rtp import ( "net" + "net/netip" "sync" "sync/atomic" "time" "github.com/frostbyte73/core" "github.com/pion/rtp" + + "github.com/livekit/protocol/logger" ) var _ Writer = (*Conn)(nil) type ConnConfig struct { + Log logger.Logger MediaTimeoutInitial time.Duration MediaTimeout time.Duration TimeoutCallback func() @@ -40,6 +44,9 @@ func NewConnWith(conn UDPConn, conf *ConnConfig) *Conn { if conf == nil { conf = &ConnConfig{} } + if conf.Log == nil { + conf.Log = logger.GetLogger() + } if conf.MediaTimeoutInitial <= 0 { conf.MediaTimeoutInitial = 30 * time.Second } @@ -47,7 +54,8 @@ func NewConnWith(conn UDPConn, conf *ConnConfig) *Conn { conf.MediaTimeout = 15 * time.Second } c := &Conn{ - readBuf: make([]byte, 1500), // MTU + log: conf.Log, + readBuf: make([]byte, MTUSize+1), // larger buffer to detect overflow received: make(chan struct{}), conn: conn, timeout: conf.MediaTimeout, @@ -67,7 +75,9 @@ type UDPConn interface { Close() error } +// Deprecated: use MediaPort instead type Conn struct { + log logger.Logger wmu sync.Mutex conn UDPConn closed core.Fuse @@ -131,9 +141,11 @@ func (c *Conn) Listen(portMin, portMax int, listenAddr string) error { if listenAddr == "" { listenAddr = "0.0.0.0" } - - var err error - c.conn, err = ListenUDPPortRange(portMin, portMax, net.ParseIP(listenAddr)) + ip, err := netip.ParseAddr(listenAddr) + if err != nil { + return err + } + c.conn, err = ListenUDPPortRange(portMin, portMax, ip) if err != nil { return err } @@ -150,6 +162,7 @@ func (c *Conn) ListenAndServe(portMin, portMax int, listenAddr string) error { func (c *Conn) readLoop() { conn, buf := c.conn, c.readBuf + overflow := false var p rtp.Packet for { n, srcAddr, err := conn.ReadFromUDP(buf) @@ -157,6 +170,13 @@ func (c *Conn) readLoop() { return } c.dest.Store(srcAddr) + if n > MTUSize { + if !overflow { + overflow = true + c.log.Errorw("RTP packet is larger than MTU limit", nil) + } + continue // ignore partial messages + } p = rtp.Packet{} if err := p.Unmarshal(buf[:n]); err != nil { @@ -167,24 +187,23 @@ func (c *Conn) readLoop() { close(c.received) } if h := c.onRTP.Load(); h != nil { - _ = (*h).HandleRTP(&p) + _ = (*h).HandleRTP(&p.Header, p.Payload) } } } -func (c *Conn) WriteRTP(p *rtp.Packet) error { +func (c *Conn) WriteRTP(h *rtp.Header, payload []byte) (int, error) { addr := c.dest.Load() if addr == nil { - return nil + return 0, nil } - data, err := p.Marshal() + data, err := (&rtp.Packet{Header: *h, Payload: payload}).Marshal() if err != nil { - return err + return 0, err } c.wmu.Lock() defer c.wmu.Unlock() - _, err = c.conn.WriteToUDP(data, addr) - return err + return c.conn.WriteToUDP(data, addr) } func (c *Conn) ReadRTP() (*rtp.Packet, *net.UDPAddr, error) { diff --git a/pkg/media/rtp/jitter.go b/pkg/media/rtp/jitter.go index 32a77711..1fd31074 100644 --- a/pkg/media/rtp/jitter.go +++ b/pkg/media/rtp/jitter.go @@ -38,11 +38,11 @@ type jitterHandler struct { buf *jitter.Buffer } -func (h *jitterHandler) HandleRTP(p *rtp.Packet) error { - h.buf.Push(p.Clone()) +func (r *jitterHandler) HandleRTP(h *rtp.Header, payload []byte) error { + r.buf.Push(&rtp.Packet{Header: *h, Payload: payload}) var last error - for _, p := range h.buf.Pop(false) { - if err := h.h.HandleRTP(p); err != nil { + for _, p := range r.buf.Pop(false) { + if err := r.h.HandleRTP(&p.Header, p.Payload); err != nil { last = err } } diff --git a/pkg/media/rtp/listen.go b/pkg/media/rtp/listen.go index 3d9fb7c2..72c26652 100644 --- a/pkg/media/rtp/listen.go +++ b/pkg/media/rtp/listen.go @@ -18,14 +18,15 @@ import ( "errors" "math/rand" "net" + "net/netip" ) var ListenErr = errors.New("failed to listen on udp port") -func ListenUDPPortRange(portMin, portMax int, IP net.IP) (*net.UDPConn, error) { +func ListenUDPPortRange(portMin, portMax int, ip netip.Addr) (*net.UDPConn, error) { if portMin == 0 && portMax == 0 { return net.ListenUDP("udp", &net.UDPAddr{ - IP: IP, + IP: ip.AsSlice(), Port: 0, }) } @@ -48,7 +49,7 @@ func ListenUDPPortRange(portMin, portMax int, IP net.IP) (*net.UDPConn, error) { portCurrent := portStart for { - c, e := net.ListenUDP("udp", &net.UDPAddr{IP: IP, Port: portCurrent}) + c, e := net.ListenUDP("udp", &net.UDPAddr{IP: ip.AsSlice(), Port: portCurrent}) if e == nil { return c, nil } diff --git a/pkg/media/rtp/mux.go b/pkg/media/rtp/mux.go index d115fdee..52d20226 100644 --- a/pkg/media/rtp/mux.go +++ b/pkg/media/rtp/mux.go @@ -35,25 +35,25 @@ type Mux struct { // HandleRTP selects a Handler based on payload type. // Types can be registered with Register. If no handler is set, a default one will be used. -func (m *Mux) HandleRTP(p *rtp.Packet) error { +func (m *Mux) HandleRTP(h *rtp.Header, payload []byte) error { if m == nil { return nil } - var h Handler + var r Handler m.mu.RLock() - if p.PayloadType < byte(len(m.static)) { - h = m.static[p.PayloadType] + if h.PayloadType < byte(len(m.static)) { + r = m.static[h.PayloadType] } else { - h = m.dynamic[p.PayloadType] + r = m.dynamic[h.PayloadType] } - if h == nil { - h = m.def + if r == nil { + r = m.def } m.mu.RUnlock() - if h == nil { + if r == nil { return nil } - return h.HandleRTP(p) + return r.HandleRTP(h, payload) } // SetDefault sets a default RTP handler. diff --git a/pkg/media/rtp/rtp.go b/pkg/media/rtp/rtp.go index 08ccf7c5..fd8eebda 100644 --- a/pkg/media/rtp/rtp.go +++ b/pkg/media/rtp/rtp.go @@ -17,6 +17,7 @@ package rtp import ( "fmt" "math/rand/v2" + "slices" "sync" "github.com/pion/interceptor" @@ -31,7 +32,7 @@ type BytesFrame interface { } type Writer interface { - WriteRTP(p *rtp.Packet) error + WriteRTP(h *rtp.Header, payload []byte) (int, error) } type Reader interface { @@ -39,13 +40,13 @@ type Reader interface { } type Handler interface { - HandleRTP(p *rtp.Packet) error + HandleRTP(h *rtp.Header, payload []byte) error } -type HandlerFunc func(p *rtp.Packet) error +type HandlerFunc func(h *rtp.Header, payload []byte) error -func (fnc HandlerFunc) HandleRTP(p *rtp.Packet) error { - return fnc(p) +func (fnc HandlerFunc) HandleRTP(h *rtp.Header, payload []byte) error { + return fnc(h, payload) } func HandleLoop(r Reader, h Handler) error { @@ -54,7 +55,7 @@ func HandleLoop(r Reader, h Handler) error { if err != nil { return err } - err = h.HandleRTP(p) + err = h.HandleRTP(&p.Header, p.Payload) if err != nil { return err } @@ -64,26 +65,27 @@ func HandleLoop(r Reader, h Handler) error { // Buffer is a Writer that clones and appends RTP packets into a slice. type Buffer []*Packet -func (b *Buffer) WriteRTP(p *Packet) error { - p2 := p.Clone() - *b = append(*b, p2) - return nil +func (b *Buffer) WriteRTP(h *rtp.Header, payload []byte) (int, error) { + *b = append(*b, &rtp.Packet{ + Header: *h, + Payload: slices.Clone(payload), + }) + return len(payload), nil } // NewSeqWriter creates an RTP writer that automatically increments the sequence number. func NewSeqWriter(w Writer) *SeqWriter { s := &SeqWriter{w: w} - s.p = rtp.Packet{ - Header: rtp.Header{ - Version: 2, - SSRC: rand.Uint32(), - SequenceNumber: 0, - }, + s.h = rtp.Header{ + Version: 2, + SSRC: rand.Uint32(), + SequenceNumber: 0, } return s } type Packet = rtp.Packet +type Header = rtp.Header type Event struct { Type byte @@ -95,20 +97,19 @@ type Event struct { type SeqWriter struct { mu sync.Mutex w Writer - p Packet + h Header } func (s *SeqWriter) WriteEvent(ev *Event) error { s.mu.Lock() defer s.mu.Unlock() - s.p.PayloadType = ev.Type - s.p.Payload = ev.Payload - s.p.Marker = ev.Marker - s.p.Timestamp = ev.Timestamp - if err := s.w.WriteRTP(&s.p); err != nil { + s.h.PayloadType = ev.Type + s.h.Marker = ev.Marker + s.h.Timestamp = ev.Timestamp + if _, err := s.w.WriteRTP(&s.h, ev.Payload); err != nil { return err } - s.p.Header.SequenceNumber++ + s.h.SequenceNumber++ return nil } @@ -211,6 +212,6 @@ func (s *MediaStreamIn[T]) String() string { return fmt.Sprintf("RTP(%d) -> %s", s.Writer.SampleRate(), s.Writer) } -func (s *MediaStreamIn[T]) HandleRTP(p *rtp.Packet) error { - return s.Writer.WriteSample(T(p.Payload)) +func (s *MediaStreamIn[T]) HandleRTP(_ *rtp.Header, payload []byte) error { + return s.Writer.WriteSample(T(payload)) } diff --git a/pkg/media/rtp/session.go b/pkg/media/rtp/session.go new file mode 100644 index 00000000..505f53f9 --- /dev/null +++ b/pkg/media/rtp/session.go @@ -0,0 +1,223 @@ +// Copyright 2024 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rtp + +import ( + "io" + "net" + "slices" + "sync" + + "github.com/frostbyte73/core" + "github.com/pion/rtp" + + "github.com/livekit/protocol/logger" +) + +const ( + enableZeroCopy = true + MTUSize = 1500 +) + +type Session interface { + OpenWriteStream() (WriteStream, error) + AcceptStream() (ReadStream, uint32, error) + Close() error +} + +type WriteStream interface { + // WriteRTP writes RTP packet to the connection. + WriteRTP(h *rtp.Header, payload []byte) (int, error) +} + +type ReadStream interface { + // ReadRTP reads RTP packet and its header from the connection. + ReadRTP(h *rtp.Header, payload []byte) (int, error) +} + +func NewSession(log logger.Logger, conn net.Conn) Session { + return &session{ + log: log, + conn: conn, + w: &writeStream{conn: conn}, + bySSRC: make(map[uint32]*readStream), + rbuf: make([]byte, MTUSize+1), // larger buffer to detect overflow + } +} + +type session struct { + log logger.Logger + conn net.Conn + closed core.Fuse + w *writeStream + + rmu sync.Mutex + rbuf []byte + bySSRC map[uint32]*readStream +} + +func (s *session) OpenWriteStream() (WriteStream, error) { + return s.w, nil +} + +func (s *session) AcceptStream() (ReadStream, uint32, error) { + s.rmu.Lock() + defer s.rmu.Unlock() + overflow := false + for { + n, err := s.conn.Read(s.rbuf[:]) + if err != nil { + return nil, 0, err + } + if n > MTUSize { + overflow = true + if !overflow { + s.log.Errorw("RTP packet is larger than MTU limit", nil) + } + continue // ignore partial messages + } + buf := s.rbuf[:n] + var p rtp.Packet + err = p.Unmarshal(buf) + if err != nil { + continue // ignore + } + + isNew := false + r := s.bySSRC[p.SSRC] + if r == nil { + r = &readStream{ + ssrc: p.SSRC, + closed: s.closed.Watch(), + copied: make(chan int), + recv: make(chan *rtp.Packet, 10), + } + s.bySSRC[p.SSRC] = r + isNew = true + } + r.write(&p) + if isNew { + return r, r.ssrc, nil + } + } +} + +func (s *session) Close() error { + var err error + s.closed.Once(func() { + err = s.conn.Close() + s.rmu.Lock() + defer s.rmu.Unlock() + s.bySSRC = nil + }) + return err +} + +type writeStream struct { + mu sync.Mutex + buf []byte + conn net.Conn +} + +func (w *writeStream) WriteRTP(h *rtp.Header, payload []byte) (int, error) { + w.mu.Lock() + defer w.mu.Unlock() + hsz := h.MarshalSize() + sz := hsz + len(payload) + w.buf = w.buf[:0] + w.buf = slices.Grow(w.buf, sz) + buf := w.buf[:sz] + n, err := h.MarshalTo(buf) + if err != nil { + return 0, err + } + copy(buf[n:], payload) + return w.conn.Write(buf) +} + +type readStream struct { + ssrc uint32 + + closed <-chan struct{} + recv chan *rtp.Packet + copied chan int + mu sync.Mutex + hdr *rtp.Header + payload []byte +} + +func (r *readStream) write(p *rtp.Packet) { + if enableZeroCopy { + r.mu.Lock() + h, payload := r.hdr, r.payload + r.hdr, r.payload = nil, nil + r.mu.Unlock() + if h != nil { + // zero copy + *h = p.Header + n := copy(payload, p.Payload) + select { + case <-r.closed: + case r.copied <- n: + } + return + } + } + p.Payload = slices.Clone(p.Payload) + select { + case r.recv <- p: + default: + } +} + +func (r *readStream) ReadRTP(h *rtp.Header, payload []byte) (int, error) { + direct := false + if enableZeroCopy { + r.mu.Lock() + if r.hdr == nil { + r.hdr = h + r.payload = payload + direct = true + } + r.mu.Unlock() + } + if !direct { + select { + case p := <-r.recv: + *h = p.Header + n := copy(payload, p.Payload) + return n, nil + case <-r.closed: + } + return 0, io.EOF + } + defer func() { + r.mu.Lock() + defer r.mu.Unlock() + if r.hdr == h { + r.hdr, r.payload = nil, nil + } + }() + select { + case n := <-r.copied: + return n, nil + case p := <-r.recv: + *h = p.Header + n := copy(payload, p.Payload) + return n, nil + case <-r.closed: + } + return 0, io.EOF +} diff --git a/pkg/media/sdp/offer.go b/pkg/media/sdp/offer.go index 0ff5f257..941f1294 100644 --- a/pkg/media/sdp/offer.go +++ b/pkg/media/sdp/offer.go @@ -15,6 +15,7 @@ package sdp import ( + "encoding/base64" "errors" "fmt" "math/rand/v2" @@ -29,6 +30,20 @@ import ( "github.com/livekit/sip/pkg/media" "github.com/livekit/sip/pkg/media/dtmf" "github.com/livekit/sip/pkg/media/rtp" + "github.com/livekit/sip/pkg/media/srtp" +) + +var ( + ErrNoCommonMedia = errors.New("common audio codec not found") + ErrNoCommonCrypto = errors.New("no common encryption profiles") +) + +type Encryption int + +const ( + EncryptionNone Encryption = iota + EncryptionAllow + EncryptionRequire ) type CodecInfo struct { @@ -70,11 +85,27 @@ func OfferCodecs() []CodecInfo { } type MediaDesc struct { - Codecs []CodecInfo - DTMFType byte // set to 0 if there's no DTMF + Codecs []CodecInfo + DTMFType byte // set to 0 if there's no DTMF + CryptoProfiles []srtp.Profile } -func OfferMedia(rtpListenerPort int) (MediaDesc, *sdp.MediaDescription) { +func appendCryptoProfiles(attrs []sdp.Attribute, profiles []srtp.Profile) []sdp.Attribute { + var buf []byte + for _, p := range profiles { + buf = buf[:0] + buf = append(buf, p.Key...) + buf = append(buf, p.Salt...) + skey := base64.StdEncoding.WithPadding(base64.StdPadding).EncodeToString(buf) + attrs = append(attrs, sdp.Attribute{ + Key: "crypto", + Value: fmt.Sprintf("%d %s inline:%s", p.Index, p.Profile, skey), + }) + } + return attrs +} + +func OfferMedia(rtpListenerPort int, encrypted Encryption) (MediaDesc, *sdp.MediaDescription, error) { // Static compiler check for frame duration hardcoded below. var _ = [1]struct{}{}[20*time.Millisecond-rtp.DefFrameDur] @@ -98,26 +129,42 @@ func OfferMedia(rtpListenerPort int) (MediaDesc, *sdp.MediaDescription) { Key: "fmtp", Value: fmt.Sprintf("%d 0-16", dtmfType), }) } + var cryptoProfiles []srtp.Profile + if encrypted != EncryptionNone { + var err error + cryptoProfiles, err = srtp.DefaultProfiles() + if err != nil { + return MediaDesc{}, nil, err + } + attrs = appendCryptoProfiles(attrs, cryptoProfiles) + } + attrs = append(attrs, []sdp.Attribute{ {Key: "ptime", Value: "20"}, {Key: "sendrecv"}, }...) + proto := "AVP" + if encrypted != EncryptionNone { + proto = "SAVP" + } + return MediaDesc{ - Codecs: codecs, - DTMFType: dtmfType, + Codecs: codecs, + DTMFType: dtmfType, + CryptoProfiles: cryptoProfiles, }, &sdp.MediaDescription{ MediaName: sdp.MediaName{ Media: "audio", Port: sdp.RangedPort{Value: rtpListenerPort}, - Protos: []string{"RTP", "AVP"}, + Protos: []string{"RTP", proto}, Formats: formats, }, Attributes: attrs, - } + }, nil } -func AnswerMedia(rtpListenerPort int, audio *AudioConfig) *sdp.MediaDescription { +func AnswerMedia(rtpListenerPort int, audio *AudioConfig, crypt *srtp.Profile) *sdp.MediaDescription { // Static compiler check for frame duration hardcoded below. var _ = [1]struct{}{}[20*time.Millisecond-rtp.DefFrameDur] @@ -134,6 +181,11 @@ func AnswerMedia(rtpListenerPort int, audio *AudioConfig) *sdp.MediaDescription {Key: "fmtp", Value: fmt.Sprintf("%d 0-16", audio.DTMFType)}, }...) } + proto := "AVP" + if crypt != nil { + proto = "SAVP" + attrs = appendCryptoProfiles(attrs, []srtp.Profile{*crypt}) + } attrs = append(attrs, []sdp.Attribute{ {Key: "ptime", Value: "20"}, {Key: "sendrecv"}, @@ -142,7 +194,7 @@ func AnswerMedia(rtpListenerPort int, audio *AudioConfig) *sdp.MediaDescription MediaName: sdp.MediaName{ Media: "audio", Port: sdp.RangedPort{Value: rtpListenerPort}, - Protos: []string{"RTP", "AVP"}, + Protos: []string{"RTP", proto}, Formats: formats, }, Attributes: attrs, @@ -159,10 +211,13 @@ type Offer Description type Answer Description -func NewOffer(publicIp netip.Addr, rtpListenerPort int) *Offer { +func NewOffer(publicIp netip.Addr, rtpListenerPort int, encrypted Encryption) (*Offer, error) { sessId := rand.Uint64() // TODO: do we need to track these? - m, mediaDesc := OfferMedia(rtpListenerPort) + m, mediaDesc, err := OfferMedia(rtpListenerPort, encrypted) + if err != nil { + return nil, err + } offer := sdp.SessionDescription{ Version: 0, Origin: sdp.Origin{ @@ -193,16 +248,34 @@ func NewOffer(publicIp netip.Addr, rtpListenerPort int) *Offer { SDP: offer, Addr: netip.AddrPortFrom(publicIp, uint16(rtpListenerPort)), MediaDesc: m, - } + }, nil } -func (d *Offer) Answer(publicIp netip.Addr, rtpListenerPort int) (*Answer, *MediaConfig, error) { +func (d *Offer) Answer(publicIp netip.Addr, rtpListenerPort int, enc Encryption) (*Answer, *MediaConfig, error) { audio, err := SelectAudio(d.MediaDesc) if err != nil { return nil, nil, err } - mediaDesc := AnswerMedia(rtpListenerPort, audio) + var ( + sconf *srtp.Config + sprof *srtp.Profile + ) + if len(d.CryptoProfiles) != 0 && enc != EncryptionNone { + answer, err := srtp.DefaultProfiles() + if err != nil { + return nil, nil, err + } + sconf, sprof, err = SelectCrypto(d.CryptoProfiles, answer, true) + if err != nil { + return nil, nil, err + } + } + if sprof == nil && enc == EncryptionRequire { + return nil, nil, ErrNoCommonCrypto + } + + mediaDesc := AnswerMedia(rtpListenerPort, audio, sprof) answer := sdp.SessionDescription{ Version: 0, Origin: sdp.Origin{ @@ -243,18 +316,30 @@ func (d *Offer) Answer(publicIp netip.Addr, rtpListenerPort int) (*Answer, *Medi Local: src, Remote: d.Addr, Audio: *audio, + Crypto: sconf, }, nil } -func (d *Answer) Apply(offer *Offer) (*MediaConfig, error) { +func (d *Answer) Apply(offer *Offer, enc Encryption) (*MediaConfig, error) { audio, err := SelectAudio(d.MediaDesc) if err != nil { return nil, err } + var sconf *srtp.Config + if len(d.CryptoProfiles) != 0 && enc != EncryptionNone { + sconf, _, err = SelectCrypto(offer.CryptoProfiles, d.CryptoProfiles, false) + if err != nil { + return nil, err + } + } + if sconf == nil && enc == EncryptionRequire { + return nil, ErrNoCommonCrypto + } return &MediaConfig{ Local: offer.Addr, Remote: d.Addr, Audio: *audio, + Crypto: sconf, }, nil } @@ -292,6 +377,42 @@ func ParseAnswer(data []byte) (*Answer, error) { return (*Answer)(d), nil } +func parseSRTPProfile(val string) (*srtp.Profile, error) { + val = strings.TrimSpace(val) + sub := strings.SplitN(val, " ", 3) + if len(sub) != 3 { + return nil, nil // ignore + } + sind, prof, skey := sub[0], srtp.ProtectionProfile(sub[1]), sub[2] + ind, err := strconv.Atoi(sind) + if err != nil { + return nil, err + } + var ok bool + skey, ok = strings.CutPrefix(skey, "inline:") + if !ok { + return nil, nil // ignore + } + keys, err := base64.StdEncoding.WithPadding(base64.StdPadding).DecodeString(skey) + if err != nil { + return nil, fmt.Errorf("cannot parse crypto key %q: %v", skey, err) + } + var salt []byte + if sp, err := prof.Parse(); err == nil { + keyLen, err := sp.KeyLen() + if err != nil { + return nil, err + } + keys, salt = keys[:keyLen], keys[keyLen:] + } + return &srtp.Profile{ + Index: ind, + Profile: prof, + Key: keys, + Salt: salt, + }, nil +} + func ParseMedia(d *sdp.MediaDescription) (*MediaDesc, error) { var out MediaDesc for _, m := range d.Attributes { @@ -315,6 +436,14 @@ func ParseMedia(d *sdp.MediaDescription) (*MediaDesc, error) { Type: byte(typ), Codec: codec, }) + case "crypto": + p, err := parseSRTPProfile(m.Value) + if err != nil { + return nil, fmt.Errorf("cannot parse srtp profile %q: %v", m.Value, err) + } else if p == nil { + continue + } + out.CryptoProfiles = append(out.CryptoProfiles, *p) } } for _, f := range d.MediaName.Formats { @@ -335,6 +464,7 @@ type MediaConfig struct { Local netip.AddrPort Remote netip.AddrPort Audio AudioConfig + Crypto *srtp.Config } type AudioConfig struct { @@ -361,7 +491,7 @@ func SelectAudio(desc MediaDesc) (*AudioConfig, error) { } } if audioCodec == nil { - return nil, fmt.Errorf("common audio codec not found") + return nil, ErrNoCommonMedia } return &AudioConfig{ Codec: audioCodec, @@ -369,3 +499,40 @@ func SelectAudio(desc MediaDesc) (*AudioConfig, error) { DTMFType: desc.DTMFType, }, nil } + +func SelectCrypto(offer, answer []srtp.Profile, swap bool) (*srtp.Config, *srtp.Profile, error) { + if len(offer) == 0 { + return nil, nil, nil + } + for _, ans := range answer { + sp, err := ans.Profile.Parse() + if err != nil { + continue + } + i := slices.IndexFunc(offer, func(off srtp.Profile) bool { + return off.Profile == ans.Profile + }) + if i >= 0 { + off := offer[i] + c := &srtp.Config{ + Keys: srtp.SessionKeys{ + LocalMasterKey: off.Key, + LocalMasterSalt: off.Salt, + RemoteMasterKey: ans.Key, + RemoteMasterSalt: ans.Salt, + }, + Profile: sp, + } + if swap { + c.Keys.LocalMasterKey, c.Keys.RemoteMasterKey = c.Keys.RemoteMasterKey, c.Keys.LocalMasterKey + c.Keys.LocalMasterSalt, c.Keys.RemoteMasterSalt = c.Keys.RemoteMasterSalt, c.Keys.LocalMasterSalt + } + prof := &off + if swap { + prof = &ans + } + return c, prof, nil + } + } + return nil, nil, nil +} diff --git a/pkg/media/sdp/offer_test.go b/pkg/media/sdp/offer_test.go index c9dbbca4..6fe8b7fb 100644 --- a/pkg/media/sdp/offer_test.go +++ b/pkg/media/sdp/offer_test.go @@ -15,6 +15,8 @@ package sdp_test import ( + "slices" + "strings" "testing" "github.com/pion/sdp/v3" @@ -25,11 +27,22 @@ import ( "github.com/livekit/sip/pkg/media/g722" "github.com/livekit/sip/pkg/media/rtp" . "github.com/livekit/sip/pkg/media/sdp" + "github.com/livekit/sip/pkg/media/srtp" ) +func getInline(s string) string { + const word = "inline:" + i := strings.Index(s, word) + if i < 0 { + return s + } + return s[i+len(word):] +} + func TestSDPMediaOffer(t *testing.T) { const port = 12345 - _, offer := OfferMedia(port) + _, offer, err := OfferMedia(port, EncryptionNone) + require.NoError(t, err) require.Equal(t, &sdp.MediaDescription{ MediaName: sdp.MediaName{ Media: "audio", @@ -48,10 +61,39 @@ func TestSDPMediaOffer(t *testing.T) { }, }, offer) + _, offer, err = OfferMedia(port, EncryptionRequire) + require.NoError(t, err) + i := slices.IndexFunc(offer.Attributes, func(a sdp.Attribute) bool { + return a.Key == "crypto" + }) + require.True(t, i > 0) + require.Equal(t, &sdp.MediaDescription{ + MediaName: sdp.MediaName{ + Media: "audio", + Port: sdp.RangedPort{Value: port}, + Protos: []string{"RTP", "SAVP"}, + Formats: []string{"9", "0", "8", "101"}, + }, + Attributes: []sdp.Attribute{ + {Key: "rtpmap", Value: "9 G722/8000"}, + {Key: "rtpmap", Value: "0 PCMU/8000"}, + {Key: "rtpmap", Value: "8 PCMA/8000"}, + {Key: "rtpmap", Value: "101 telephone-event/8000"}, + {Key: "fmtp", Value: "101 0-16"}, + {Key: "crypto", Value: "1 AES_CM_128_HMAC_SHA1_80 inline:" + getInline(offer.Attributes[i+0].Value)}, + {Key: "crypto", Value: "2 AES_CM_128_HMAC_SHA1_32 inline:" + getInline(offer.Attributes[i+1].Value)}, + {Key: "crypto", Value: "3 AES_256_CM_HMAC_SHA1_80 inline:" + getInline(offer.Attributes[i+2].Value)}, + {Key: "crypto", Value: "4 AES_256_CM_HMAC_SHA1_32 inline:" + getInline(offer.Attributes[i+3].Value)}, + {Key: "ptime", Value: "20"}, + {Key: "sendrecv"}, + }, + }, offer) + media.CodecSetEnabled(g722.SDPName, false) defer media.CodecSetEnabled(g722.SDPName, true) - _, offer = OfferMedia(port) + _, offer, err = OfferMedia(port, EncryptionNone) + require.NoError(t, err) require.Equal(t, &sdp.MediaDescription{ MediaName: sdp.MediaName{ Media: "audio", @@ -244,7 +286,8 @@ func TestSDPMediaAnswer(t *testing.T) { require.Equal(t, c.exp, got) }) } - _, offer := OfferMedia(port) + _, offer, err := OfferMedia(port, EncryptionNone) + require.NoError(t, err) require.Equal(t, &sdp.MediaDescription{ MediaName: sdp.MediaName{ Media: "audio", @@ -263,3 +306,66 @@ func TestSDPMediaAnswer(t *testing.T) { }, }, offer) } + +func TestParseOfferSRTP(t *testing.T) { + const sdpData = `v=0 +o=lin 3723 713 IN IP4 192.168.0.2 +s=Talk +c=IN IP4 192.168.0.2 +t=0 0 +a=rtcp-xr:rcvr-rtt=all:10000 stat-summary=loss,dup,jitt,TTL voip-metrics +a=record:off +m=audio 11200 RTP/SAVP 96 0 9 97 101 +a=rtpmap:96 opus/48000/2 +a=fmtp:96 useinbandfec=1 +a=rtpmap:97 telephone-event/48000 +a=rtpmap:101 telephone-event/8000 +a=crypto:1 AES_CM_128_HMAC_SHA1_80 inline:pMIPxjzYIG5TQuIWfkjTnaACVrzohhFfOGhSMgV1 +a=crypto:2 AES_CM_128_HMAC_SHA1_32 inline:ZKkTQfuCsliegVZtFSya3Z6oEVUtSwjGCfHlbrMf +a=crypto:3 AES_256_CM_HMAC_SHA1_80 inline:BvoLeRu5IcBkgNN14qtqaxi0r7ei2rwuBSodd0SANggS9JHsp5IU7lhEsyna1A== +a=crypto:4 AES_256_CM_HMAC_SHA1_32 inline:j92SDyNTUe0BNGk4LeeCsPqX0qwPP/e9TLafvd7L9waM8r4arjzgUqs7uUERyg== +a=crypto:5 AEAD_AES_128_GCM inline:gGetEkQgGk4NZIoKj/cbFpRkHdocmKlP0u3VMw== +a=crypto:6 AEAD_AES_256_GCM inline:EFFzS2FMyNoYcVcaARU+nvk+JhHmVbvdFtRxZuRi9rDmLYpLms5ySv93iy0= +a=rtcp-fb:* trr-int 5000 +a=rtcp-fb:* ccm tmmbr +` + v, err := ParseOffer([]byte(sdpData)) + require.NoError(t, err) + require.Equal(t, []srtp.Profile{ + { + Index: 1, + Profile: "AES_CM_128_HMAC_SHA1_80", + Key: []byte{0xa4, 0xc2, 0xf, 0xc6, 0x3c, 0xd8, 0x20, 0x6e, 0x53, 0x42, 0xe2, 0x16, 0x7e, 0x48, 0xd3, 0x9d}, + Salt: []byte{0xa0, 0x2, 0x56, 0xbc, 0xe8, 0x86, 0x11, 0x5f, 0x38, 0x68, 0x52, 0x32, 0x5, 0x75}, + }, + { + Index: 2, + Profile: "AES_CM_128_HMAC_SHA1_32", + Key: []byte{0x64, 0xa9, 0x13, 0x41, 0xfb, 0x82, 0xb2, 0x58, 0x9e, 0x81, 0x56, 0x6d, 0x15, 0x2c, 0x9a, 0xdd}, + Salt: []byte{0x9e, 0xa8, 0x11, 0x55, 0x2d, 0x4b, 0x8, 0xc6, 0x9, 0xf1, 0xe5, 0x6e, 0xb3, 0x1f}, + }, + { + Index: 3, + Profile: "AES_256_CM_HMAC_SHA1_80", + Key: []byte{0x6, 0xfa, 0xb, 0x79, 0x1b, 0xb9, 0x21, 0xc0, 0x64, 0x80, 0xd3, 0x75, 0xe2, 0xab, 0x6a, 0x6b, 0x18, 0xb4, 0xaf, 0xb7, 0xa2, 0xda, 0xbc, 0x2e, 0x5, 0x2a, 0x1d, 0x77, 0x44, 0x80, 0x36, 0x8}, + Salt: []uint8{0x12, 0xf4, 0x91, 0xec, 0xa7, 0x92, 0x14, 0xee, 0x58, 0x44, 0xb3, 0x29, 0xda, 0xd4}, + }, + { + Index: 4, + Profile: "AES_256_CM_HMAC_SHA1_32", + Key: []uint8{0x8f, 0xdd, 0x92, 0xf, 0x23, 0x53, 0x51, 0xed, 0x1, 0x34, 0x69, 0x38, 0x2d, 0xe7, 0x82, 0xb0, 0xfa, 0x97, 0xd2, 0xac, 0xf, 0x3f, 0xf7, 0xbd, 0x4c, 0xb6, 0x9f, 0xbd, 0xde, 0xcb, 0xf7, 0x6}, + Salt: []uint8{0x8c, 0xf2, 0xbe, 0x1a, 0xae, 0x3c, 0xe0, 0x52, 0xab, 0x3b, 0xb9, 0x41, 0x11, 0xca}, + }, + { + Index: 5, + Profile: "AEAD_AES_128_GCM", + Key: []uint8{0x80, 0x67, 0xad, 0x12, 0x44, 0x20, 0x1a, 0x4e, 0xd, 0x64, 0x8a, 0xa, 0x8f, 0xf7, 0x1b, 0x16, 0x94, 0x64, 0x1d, 0xda, 0x1c, 0x98, 0xa9, 0x4f, 0xd2, 0xed, 0xd5, 0x33}, + }, + { + Index: 6, + Profile: "AEAD_AES_256_GCM", + Key: []uint8{0x10, 0x51, 0x73, 0x4b, 0x61, 0x4c, 0xc8, 0xda, 0x18, 0x71, 0x57, 0x1a, 0x1, 0x15, 0x3e, 0x9e, 0xf9, 0x3e, 0x26, 0x11, 0xe6, 0x55, 0xbb, 0xdd, 0x16, 0xd4, 0x71, 0x66, 0xe4, 0x62, 0xf6, 0xb0, 0xe6, 0x2d, 0x8a, 0x4b, 0x9a, 0xce, 0x72, 0x4a, 0xff, 0x77, 0x8b, 0x2d}, + }, + }, + v.CryptoProfiles) +} diff --git a/pkg/media/srtp/srtp.go b/pkg/media/srtp/srtp.go new file mode 100644 index 00000000..767ac348 --- /dev/null +++ b/pkg/media/srtp/srtp.go @@ -0,0 +1,158 @@ +// Copyright 2024 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package srtp + +import ( + "crypto/rand" + "fmt" + "net" + + prtp "github.com/pion/rtp" + "github.com/pion/srtp/v3" + + "github.com/livekit/protocol/logger" + "github.com/livekit/sip/pkg/media/rtp" +) + +var defaultProfiles = []ProtectionProfile{ + "AES_CM_128_HMAC_SHA1_80", + "AES_CM_128_HMAC_SHA1_32", + "AES_256_CM_HMAC_SHA1_80", + "AES_256_CM_HMAC_SHA1_32", +} + +func DefaultProfiles() ([]Profile, error) { + out := make([]Profile, 0, len(defaultProfiles)) + for i, p := range defaultProfiles { + sp, err := p.Parse() + if err != nil { + return nil, err + } + keyLen, err := sp.KeyLen() + if err != nil { + return nil, err + } + saltLen, err := sp.SaltLen() + if err != nil { + return nil, err + } + key := make([]byte, keyLen) + salt := make([]byte, saltLen) + if _, err := rand.Read(key); err != nil { + return nil, err + } + if _, err := rand.Read(salt); err != nil { + return nil, err + } + out = append(out, Profile{ + Index: i + 1, + Profile: p, + Key: key, + Salt: salt, + }) + } + return out, nil +} + +type Options struct { + Profiles []Profile +} + +type ProtectionProfile string + +func (p ProtectionProfile) Parse() (srtp.ProtectionProfile, error) { + switch p { + case "AES_CM_128_HMAC_SHA1_80": + return srtp.ProtectionProfileAes128CmHmacSha1_80, nil + case "AES_CM_128_HMAC_SHA1_32": + return srtp.ProtectionProfileAes128CmHmacSha1_32, nil + case "AES_256_CM_HMAC_SHA1_80": + return srtp.ProtectionProfileAes256CmHmacSha1_80, nil + case "AES_256_CM_HMAC_SHA1_32": + return srtp.ProtectionProfileAes256CmHmacSha1_32, nil + default: + return 0, fmt.Errorf("unsupported profile %q", p) + } +} + +type Profile struct { + Index int + Profile ProtectionProfile + Key []byte + Salt []byte +} + +type Config = srtp.Config +type SessionKeys = srtp.SessionKeys + +func NewSession(log logger.Logger, conn net.Conn, conf *Config) (rtp.Session, error) { + s, err := srtp.NewSessionSRTP(conn, conf) + if err != nil { + return nil, err + } + return &session{log: log, s: s}, nil +} + +type session struct { + log logger.Logger + s *srtp.SessionSRTP +} + +func (s *session) OpenWriteStream() (rtp.WriteStream, error) { + w, err := s.s.OpenWriteStream() + if err != nil { + return nil, err + } + return writeStream{w: w}, nil +} + +func (s *session) AcceptStream() (rtp.ReadStream, uint32, error) { + r, ssrc, err := s.s.AcceptStream() + if err != nil { + return nil, 0, err + } + return readStream{r: r}, ssrc, nil +} + +func (s *session) Close() error { + return s.s.Close() +} + +type writeStream struct { + w *srtp.WriteStreamSRTP +} + +func (w writeStream) WriteRTP(h *prtp.Header, payload []byte) (int, error) { + return w.w.WriteRTP(h, payload) +} + +type readStream struct { + r *srtp.ReadStreamSRTP +} + +func (r readStream) ReadRTP(h *prtp.Header, payload []byte) (int, error) { + buf := payload + n, err := r.r.Read(buf) + if err != nil { + return 0, err + } + var p rtp.Packet + if err = p.Unmarshal(buf[:n]); err != nil { + return 0, err + } + *h = p.Header + n = copy(payload, p.Payload) + return n, nil +} diff --git a/pkg/service/psrpc.go b/pkg/service/psrpc.go index 264d0bc6..01fa85b8 100644 --- a/pkg/service/psrpc.go +++ b/pkg/service/psrpc.go @@ -7,7 +7,6 @@ import ( "github.com/livekit/protocol/logger" "github.com/livekit/protocol/rpc" "github.com/livekit/protocol/tracer" - "github.com/livekit/sip/pkg/sip" ) @@ -82,10 +81,11 @@ func DispatchCall(ctx context.Context, psrpcClient rpc.IOInfoClient, log logger. case rpc.SIPDispatchResult_LEGACY_ACCEPT_OR_PIN: if resp.RequestPin { return sip.CallDispatch{ - ProjectID: resp.ProjectId, - TrunkID: resp.SipTrunkId, - DispatchRuleID: resp.SipDispatchRuleId, - Result: sip.DispatchRequestPin, + ProjectID: resp.ProjectId, + TrunkID: resp.SipTrunkId, + DispatchRuleID: resp.SipDispatchRuleId, + Result: sip.DispatchRequestPin, + MediaEncryption: resp.MediaEncryption, } } // TODO: finally deprecate and drop @@ -114,6 +114,7 @@ func DispatchCall(ctx context.Context, psrpcClient rpc.IOInfoClient, log logger. EnabledFeatures: resp.EnabledFeatures, RingingTimeout: resp.RingingTimeout.AsDuration(), MaxCallDuration: resp.MaxCallDuration.AsDuration(), + MediaEncryption: resp.MediaEncryption, } case rpc.SIPDispatchResult_ACCEPT: return sip.CallDispatch{ @@ -141,12 +142,14 @@ func DispatchCall(ctx context.Context, psrpcClient rpc.IOInfoClient, log logger. EnabledFeatures: resp.EnabledFeatures, RingingTimeout: resp.RingingTimeout.AsDuration(), MaxCallDuration: resp.MaxCallDuration.AsDuration(), + MediaEncryption: resp.MediaEncryption, } case rpc.SIPDispatchResult_REQUEST_PIN: return sip.CallDispatch{ - ProjectID: resp.ProjectId, - Result: sip.DispatchRequestPin, - TrunkID: resp.SipTrunkId, + ProjectID: resp.ProjectId, + Result: sip.DispatchRequestPin, + TrunkID: resp.SipTrunkId, + MediaEncryption: resp.MediaEncryption, } case rpc.SIPDispatchResult_REJECT: return sip.CallDispatch{ diff --git a/pkg/sip/client.go b/pkg/sip/client.go index e1df75ea..1c6b2902 100644 --- a/pkg/sip/client.go +++ b/pkg/sip/client.go @@ -160,6 +160,10 @@ func (c *Client) createSIPParticipant(ctx context.Context, req *rpc.InternalCrea if req.SipTrunkId != "" { log = log.WithValues("sipTrunk", req.SipTrunkId) } + enc, err := sdpEncryption(req.MediaEncryption) + if err != nil { + return nil, err + } log = log.WithValues( "callID", req.SipCallId, "room", req.RoomName, @@ -215,6 +219,7 @@ func (c *Client) createSIPParticipant(ctx context.Context, req *rpc.InternalCrea ringingTimeout: req.RingingTimeout.AsDuration(), maxCallDuration: req.MaxCallDuration.AsDuration(), enabledFeatures: req.EnabledFeatures, + mediaEncryption: enc, } log.Infow("Creating SIP participant") call, err := c.newCall(ctx, c.conf, log, LocalTag(req.SipCallId), roomConf, sipConf, state) diff --git a/pkg/sip/inbound.go b/pkg/sip/inbound.go index 5fb2098e..eae60ac4 100644 --- a/pkg/sip/inbound.go +++ b/pkg/sip/inbound.go @@ -36,6 +36,7 @@ import ( "github.com/livekit/protocol/tracer" "github.com/livekit/psrpc" lksdk "github.com/livekit/server-sdk-go/v2" + "github.com/livekit/sip/pkg/media/sdp" "github.com/livekit/sipgo/sip" "github.com/livekit/sip/pkg/config" @@ -433,22 +434,39 @@ func (c *inboundCall) handleInvite(ctx context.Context, req *sip.Request, trunkI pinPrompt = true } - // We need to start media first, otherwise we won't be able to send audio prompts to the caller, or receive DTMF. - answerData, err := c.runMediaConn(req.Body(), conf, disp.EnabledFeatures) - if err != nil { - c.log.Errorw("Cannot start media", err) - c.cc.RespondAndDrop(sip.StatusInternalServerError, "") - c.close(true, callDropped, "media-failed") - return err + runMedia := func(enc livekit.SIPMediaEncryption) ([]byte, error) { + answerData, err := c.runMediaConn(req.Body(), enc, conf, disp.EnabledFeatures) + if err != nil { + isError := true + status, reason := callDropped, "media-failed" + if errors.Is(err, sdp.ErrNoCommonMedia) { + status, reason = callMediaFailed, "no-common-codec" + isError = false + } else if errors.Is(err, sdp.ErrNoCommonCrypto) { + status, reason = callMediaFailed, "no-common-crypto" + isError = false + } + if isError { + c.log.Errorw("Cannot start media", err) + } else { + c.log.Warnw("Cannot start media", err) + } + c.cc.RespondAndDrop(sip.StatusInternalServerError, "") + c.close(true, status, reason) + return nil, err + } + return answerData, nil } - acceptCall := func() (bool, error) { + + // We need to start media first, otherwise we won't be able to send audio prompts to the caller, or receive DTMF. + acceptCall := func(answerData []byte) (bool, error) { headers := disp.Headers c.attrsToHdr = disp.AttributesToHeaders if r := c.lkRoom.Room(); r != nil { headers = AttrsToHeaders(r.LocalParticipant.Attributes(), c.attrsToHdr, headers) } c.log.Infow("Accepting the call", "headers", headers) - if err = c.cc.Accept(ctx, answerData, headers); err != nil { + if err := c.cc.Accept(ctx, answerData, headers); err != nil { c.log.Errorw("Cannot respond to INVITE", err) return false, err } @@ -461,15 +479,30 @@ func (c *inboundCall) handleInvite(ctx context.Context, req *sip.Request, trunkI } ok := false + var answerData []byte if pinPrompt { + var err error // Accept the call first on the SIP side, so that we can send audio prompts. - if ok, err = acceptCall(); !ok { + // This also means we have to pick encryption setting early, before room is selected. + // Backend must explicitly enable encryption for pin prompts. + answerData, err = runMedia(disp.MediaEncryption) + if err != nil { + return err // already sent a response + } + if ok, err = acceptCall(answerData); !ok { return err // could be success if the caller hung up } disp, ok, err = c.pinPrompt(ctx, trunkID) if !ok { return err // already sent a response. Could be success if user hung up } + } else { + // Start media with given encryption settings. + var err error + answerData, err = runMedia(disp.MediaEncryption) + if err != nil { + return err // already sent a response + } } p := &disp.Room.Participant p.Attributes = HeadersToAttrs(p.Attributes, disp.HeadersToAttributes, disp.IncludeHeaders, c.cc, nil) @@ -502,7 +535,7 @@ func (c *inboundCall) handleInvite(ctx context.Context, req *sip.Request, trunkI if ok, err := c.waitSubscribe(ctx, disp.RingingTimeout); !ok { return err // already sent a response. Could be success if caller hung up } - if ok, err := acceptCall(); !ok { + if ok, err := acceptCall(answerData); !ok { return err // already sent a response. Could be success if caller hung up } } @@ -536,11 +569,16 @@ func (c *inboundCall) handleInvite(ctx context.Context, req *sip.Request, trunkI } } -func (c *inboundCall) runMediaConn(offerData []byte, conf *config.Config, features []livekit.SIPFeature) (answerData []byte, _ error) { +func (c *inboundCall) runMediaConn(offerData []byte, enc livekit.SIPMediaEncryption, conf *config.Config, features []livekit.SIPFeature) (answerData []byte, _ error) { c.mon.SDPSize(len(offerData), true) c.log.Debugw("SDP offer", "sdp", string(offerData)) + e, err := sdpEncryption(enc) + if err != nil { + c.log.Errorw("Cannot parse encryption", err) + return nil, err + } - mp, err := NewMediaPort(c.log, c.mon, &MediaConfig{ + mp, err := NewMediaPort(c.log, c.mon, &MediaOptions{ IP: c.s.sconf.MediaIP, Ports: conf.RTPPort, MediaTimeoutInitial: c.s.conf.MediaTimeoutInitial, @@ -554,7 +592,7 @@ func (c *inboundCall) runMediaConn(offerData []byte, conf *config.Config, featur c.media.EnableTimeout(false) // enabled once we accept the call c.media.SetDTMFAudio(conf.AudioDTMF) - answer, mconf, err := mp.SetOffer(offerData) + answer, mconf, err := mp.SetOffer(offerData, e) if err != nil { return nil, err } @@ -717,17 +755,19 @@ func (c *inboundCall) close(error bool, status CallStatus, reason string) { } c.setStatus(status) c.mon.CallTerminate(reason) + sipCode, sipStatus := status.SIPStatus() + log := c.log.WithValues("status", sipCode, "reason", reason) if error { - c.log.Warnw("Closing inbound call with error", nil, "reason", reason) + log.Warnw("Closing inbound call with error", nil) } else { - c.log.Infow("Closing inbound call", "reason", reason) + log.Infow("Closing inbound call") } if status != callFlood { - defer c.log.Infow("Inbound call closed", "reason", reason) + defer log.Infow("Inbound call closed") } c.closeMedia() - c.cc.Close() + c.cc.CloseWithStatus(sipCode, sipStatus) if c.callDur != nil { c.callDur() } @@ -1242,17 +1282,20 @@ func (c *sipInbound) sendBye() { sendAndACK(ctx, c, r) } -func (c *sipInbound) sendRejected() { +func (c *sipInbound) sendStatus(code sip.StatusCode, status string) { if c.inviteOk != nil { return // call already established } if c.inviteTx == nil { return // rejected or closed } - _, span := tracer.Start(context.Background(), "sipInbound.sendRejected") + _, span := tracer.Start(context.Background(), "sipInbound.sendStatus") defer span.End() - r := sip.NewResponseFromRequest(c.invite, sip.StatusBusyHere, "Rejected", nil) + if status == "" { + status = sipStatus(code) + } + r := sip.NewResponseFromRequest(c.invite, code, status, nil) if c.setHeaders != nil { for k, v := range c.setHeaders(nil) { r.AppendHeader(sip.NewHeader(k, v)) @@ -1364,12 +1407,17 @@ func (c *sipInbound) handleNotify(req *sip.Request, tx sip.ServerTransaction) er // Close the inbound call cleanly. Depending on the call state it either sends BYE or terminates INVITE with busy status. func (c *sipInbound) Close() { + c.CloseWithStatus(sip.StatusBusyHere, "Rejected") +} + +// CloseWithStatus the inbound call cleanly. Depending on the call state it either sends BYE or terminates INVITE with a specified status. +func (c *sipInbound) CloseWithStatus(code sip.StatusCode, status string) { c.mu.Lock() defer c.mu.Unlock() if c.inviteOk != nil { c.sendBye() } else if c.inviteTx != nil { - c.sendRejected() + c.sendStatus(code, status) } else { c.drop() } diff --git a/pkg/sip/media.go b/pkg/sip/media.go index d4eeb5b8..9810e6e4 100644 --- a/pkg/sip/media.go +++ b/pkg/sip/media.go @@ -17,6 +17,8 @@ package sip import ( "strconv" + prtp "github.com/pion/rtp" + "github.com/livekit/sip/pkg/media/rtp" "github.com/livekit/sip/pkg/stats" ) @@ -26,13 +28,13 @@ const ( RoomSampleRate = 48000 ) -func newRTPStatsHandler(mon *stats.CallMonitor, typ string, h rtp.Handler) rtp.Handler { - if h == nil { - h = rtp.HandlerFunc(func(p *rtp.Packet) error { +func newRTPStatsHandler(mon *stats.CallMonitor, typ string, r rtp.Handler) rtp.Handler { + if r == nil { + r = rtp.HandlerFunc(func(h *rtp.Header, payload []byte) error { return nil }) } - return &rtpStatsHandler{h: h, typ: typ, mon: mon} + return &rtpStatsHandler{h: r, typ: typ, mon: mon} } type rtpStatsHandler struct { @@ -41,34 +43,34 @@ type rtpStatsHandler struct { mon *stats.CallMonitor } -func (h *rtpStatsHandler) HandleRTP(p *rtp.Packet) error { - if h.mon != nil { - typ := h.typ +func (r *rtpStatsHandler) HandleRTP(h *rtp.Header, payload []byte) error { + if r.mon != nil { + typ := r.typ if typ == "" { - typ = strconv.Itoa(int(p.PayloadType)) + typ = strconv.Itoa(int(h.PayloadType)) } - h.mon.RTPPacketRecv(typ) + r.mon.RTPPacketRecv(typ) } - return h.h.HandleRTP(p) + return r.h.HandleRTP(h, payload) } -func newRTPStatsWriter(mon *stats.CallMonitor, typ string, w rtp.Writer) rtp.Writer { +func newRTPStatsWriter(mon *stats.CallMonitor, typ string, w rtp.WriteStream) rtp.WriteStream { return &rtpStatsWriter{w: w, typ: typ, mon: mon} } type rtpStatsWriter struct { - w rtp.Writer + w rtp.WriteStream typ string mon *stats.CallMonitor } -func (h *rtpStatsWriter) WriteRTP(p *rtp.Packet) error { - if h.mon != nil { - typ := h.typ +func (w *rtpStatsWriter) WriteRTP(h *prtp.Header, payload []byte) (int, error) { + if w.mon != nil { + typ := w.typ if typ == "" { - typ = strconv.Itoa(int(p.PayloadType)) + typ = strconv.Itoa(int(h.PayloadType)) } - h.mon.RTPPacketSend(typ) + w.mon.RTPPacketSend(typ) } - return h.w.WriteRTP(p) + return w.w.WriteRTP(h, payload) } diff --git a/pkg/sip/media_port.go b/pkg/sip/media_port.go index 30564c85..ca139134 100644 --- a/pkg/sip/media_port.go +++ b/pkg/sip/media_port.go @@ -16,12 +16,17 @@ package sip import ( "context" + "errors" + "io" "net" "net/netip" + "strings" "sync" "sync/atomic" "time" + "github.com/frostbyte73/core" + "github.com/livekit/mediatransportutil/pkg/rtcconfig" "github.com/livekit/protocol/logger" @@ -29,16 +34,62 @@ import ( "github.com/livekit/sip/pkg/media/dtmf" "github.com/livekit/sip/pkg/media/rtp" "github.com/livekit/sip/pkg/media/sdp" + "github.com/livekit/sip/pkg/media/srtp" "github.com/livekit/sip/pkg/mixer" "github.com/livekit/sip/pkg/stats" ) +type UDPConn interface { + net.Conn + ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) + WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) +} + +func newUDPConn(conn UDPConn) *udpConn { + return &udpConn{UDPConn: conn} +} + +type udpConn struct { + UDPConn + src atomic.Pointer[netip.AddrPort] + dst atomic.Pointer[netip.AddrPort] +} + +func (c *udpConn) GetSrc() (netip.AddrPort, bool) { + ptr := c.src.Load() + if ptr == nil { + return netip.AddrPort{}, false + } + addr := *ptr + return addr, addr.IsValid() +} + +func (c *udpConn) SetDst(addr netip.AddrPort) { + if addr.IsValid() { + c.dst.Store(&addr) + } +} + +func (c *udpConn) Read(b []byte) (n int, err error) { + n, addr, err := c.ReadFromUDPAddrPort(b) + c.src.Store(&addr) + return n, err +} + +func (c *udpConn) Write(b []byte) (n int, err error) { + dst := c.dst.Load() + if dst == nil { + return len(b), nil // ignore + } + return c.WriteToUDPAddrPort(b, *dst) +} + type MediaConf struct { sdp.MediaConfig Processor media.PCM16Processor } -type MediaConfig struct { +type MediaOptions struct { IP netip.Addr Ports rtcconfig.PortRange MediaTimeoutInitial time.Duration @@ -46,31 +97,42 @@ type MediaConfig struct { EnableJitterBuffer bool } -func NewMediaPort(log logger.Logger, mon *stats.CallMonitor, conf *MediaConfig, sampleRate int) (*MediaPort, error) { - return NewMediaPortWith(log, mon, nil, conf, sampleRate) +func NewMediaPort(log logger.Logger, mon *stats.CallMonitor, opts *MediaOptions, sampleRate int) (*MediaPort, error) { + return NewMediaPortWith(log, mon, nil, opts, sampleRate) } -func NewMediaPortWith(log logger.Logger, mon *stats.CallMonitor, conn rtp.UDPConn, conf *MediaConfig, sampleRate int) (*MediaPort, error) { +func NewMediaPortWith(log logger.Logger, mon *stats.CallMonitor, conn UDPConn, opts *MediaOptions, sampleRate int) (*MediaPort, error) { + if opts == nil { + opts = &MediaOptions{} + } + if opts.MediaTimeoutInitial <= 0 { + opts.MediaTimeoutInitial = 30 * time.Second + } + if opts.MediaTimeout <= 0 { + opts.MediaTimeout = 15 * time.Second + } + if conn == nil { + c, err := rtp.ListenUDPPortRange(opts.Ports.Start, opts.Ports.End, netip.AddrFrom4([4]byte{0, 0, 0, 0})) + if err != nil { + return nil, err + } + conn = c + } mediaTimeout := make(chan struct{}) p := &MediaPort{ log: log, + opts: opts, mon: mon, - externalIP: conf.IP, + externalIP: opts.IP, mediaTimeout: mediaTimeout, - jitterEnabled: conf.EnableJitterBuffer, - conn: rtp.NewConnWith(conn, &rtp.ConnConfig{ - MediaTimeoutInitial: conf.MediaTimeoutInitial, - MediaTimeout: conf.MediaTimeout, - TimeoutCallback: func() { - close(mediaTimeout) - }, - }), - audioOut: media.NewSwitchWriter(sampleRate), - audioIn: media.NewSwitchWriter(sampleRate), - } - if err := p.conn.ListenAndServe(conf.Ports.Start, conf.Ports.End, "0.0.0.0"); err != nil { - return nil, err + jitterEnabled: opts.EnableJitterBuffer, + port: newUDPConn(conn), + audioOut: media.NewSwitchWriter(sampleRate), + audioIn: media.NewSwitchWriter(sampleRate), } + go p.timeoutLoop(func() { + close(mediaTimeout) + }) p.log.Debugw("listening for media on UDP", "port", p.Port()) return p, nil } @@ -78,16 +140,22 @@ func NewMediaPortWith(log logger.Logger, mon *stats.CallMonitor, conn rtp.UDPCon // MediaPort combines all functionality related to sending and accepting SIP media. type MediaPort struct { log logger.Logger + opts *MediaOptions mon *stats.CallMonitor externalIP netip.Addr - conn *rtp.Conn + port *udpConn + mediaReceived core.Fuse + packetCount atomic.Uint64 mediaTimeout <-chan struct{} + timeoutStart atomic.Pointer[time.Time] + closed core.Fuse dtmfAudioEnabled bool jitterEnabled bool - closed atomic.Bool mu sync.Mutex conf *MediaConf + sess rtp.Session + hnd atomic.Pointer[rtp.Handler] dtmfOutRTP *rtp.Stream dtmfOutAudio media.PCM16Writer @@ -99,38 +167,70 @@ type MediaPort struct { } func (p *MediaPort) EnableTimeout(enabled bool) { - p.conn.EnableTimeout(enabled) -} - -func (p *MediaPort) Close() { - if !p.closed.CompareAndSwap(false, true) { + if !enabled { + p.timeoutStart.Store(nil) return } - p.mu.Lock() - defer p.mu.Unlock() - if w := p.audioOut.Swap(nil); w != nil { - _ = w.Close() - } - if w := p.audioIn.Swap(nil); w != nil { - _ = w.Close() - } - p.audioOutRTP = nil - p.audioInHandler = nil - p.dtmfOutRTP = nil - if p.dtmfOutAudio != nil { - p.dtmfOutAudio.Close() - p.dtmfOutAudio = nil + now := time.Now() + p.timeoutStart.Store(&now) +} + +func (p *MediaPort) timeoutLoop(timeoutCallback func()) { + ticker := time.NewTicker(p.opts.MediaTimeout) + defer ticker.Stop() + + var lastPackets uint64 + for { + select { + case <-p.closed.Watch(): + return + case <-ticker.C: + curPackets := p.packetCount.Load() + if curPackets != lastPackets { + lastPackets = curPackets + continue + } + start := p.timeoutStart.Load() + if start == nil { + continue // temporary disabled + } + if lastPackets == 0 && time.Since(*start) < p.opts.MediaTimeoutInitial { + continue + } + timeoutCallback() + return + } } - p.dtmfIn.Store(nil) - _ = p.conn.Close() +} + +func (p *MediaPort) Close() { + p.closed.Once(func() { + p.mu.Lock() + defer p.mu.Unlock() + if w := p.audioOut.Swap(nil); w != nil { + _ = w.Close() + } + if w := p.audioIn.Swap(nil); w != nil { + _ = w.Close() + } + p.audioOutRTP = nil + p.audioInHandler = nil + p.dtmfOutRTP = nil + if p.dtmfOutAudio != nil { + p.dtmfOutAudio.Close() + p.dtmfOutAudio = nil + } + p.dtmfIn.Store(nil) + _ = p.port.Close() + }) } func (p *MediaPort) Port() int { - return p.conn.LocalAddr().Port + return p.port.LocalAddr().(*net.UDPAddr).Port } func (p *MediaPort) Received() <-chan struct{} { - return p.conn.Received() + return p.mediaReceived.Watch() } func (p *MediaPort) Timeout() <-chan struct{} { @@ -159,17 +259,17 @@ func (p *MediaPort) GetAudioWriter() media.PCM16Writer { } // NewOffer generates an SDP offer for the media. -func (p *MediaPort) NewOffer() *sdp.Offer { - return sdp.NewOffer(p.externalIP, p.Port()) +func (p *MediaPort) NewOffer(encrypted sdp.Encryption) (*sdp.Offer, error) { + return sdp.NewOffer(p.externalIP, p.Port(), encrypted) } // SetAnswer decodes and applies SDP answer for offer from NewOffer. SetConfig must be called with the decoded configuration. -func (p *MediaPort) SetAnswer(offer *sdp.Offer, answerData []byte) (*MediaConf, error) { +func (p *MediaPort) SetAnswer(offer *sdp.Offer, answerData []byte, enc sdp.Encryption) (*MediaConf, error) { answer, err := sdp.ParseAnswer(answerData) if err != nil { return nil, err } - mc, err := answer.Apply(offer) + mc, err := answer.Apply(offer, enc) if err != nil { return nil, err } @@ -177,12 +277,12 @@ func (p *MediaPort) SetAnswer(offer *sdp.Offer, answerData []byte) (*MediaConf, } // SetOffer decodes the offer from another party and returns encoded answer. To accept the offer, call SetConfig. -func (p *MediaPort) SetOffer(offerData []byte) (*sdp.Answer, *MediaConf, error) { +func (p *MediaPort) SetOffer(offerData []byte, enc sdp.Encryption) (*sdp.Answer, *MediaConf, error) { offer, err := sdp.ParseOffer(offerData) if err != nil { return nil, nil, err } - answer, mc, err := offer.Answer(p.externalIP, p.Port()) + answer, mc, err := offer.Answer(p.externalIP, p.Port(), enc) if err != nil { return nil, nil, err } @@ -190,30 +290,106 @@ func (p *MediaPort) SetOffer(offerData []byte) (*sdp.Answer, *MediaConf, error) } func (p *MediaPort) SetConfig(c *MediaConf) error { + var crypto string + if c.Crypto != nil { + crypto = c.Crypto.Profile.String() + } p.log.Infow("using codecs", "audio-codec", c.Audio.Codec.Info().SDPName, "audio-rtp", c.Audio.Type, "dtmf-rtp", c.Audio.DTMFType, + "srtp", crypto, ) + p.port.SetDst(c.Remote) + var ( + sess rtp.Session + err error + ) + if c.Crypto != nil { + sess, err = srtp.NewSession(p.log, p.port, c.Crypto) + } else { + sess = rtp.NewSession(p.log, p.port) + } + if err != nil { + return err + } + p.mu.Lock() defer p.mu.Unlock() - if ip := c.Remote; ip.IsValid() { - p.conn.SetDestAddr(&net.UDPAddr{ - IP: ip.Addr().AsSlice(), - Port: int(ip.Port()), - }) - } + p.port.SetDst(c.Remote) p.conf = c + p.sess = sess - p.setupOutput() + if err = p.setupOutput(); err != nil { + return err + } p.setupInput() return nil } +func (p *MediaPort) rtpLoop(sess rtp.Session) { + // Need a loop to process all incoming packets. + for { + r, _, err := sess.AcceptStream() + if err != nil { + if !errors.Is(err, net.ErrClosed) && !strings.Contains(err.Error(), "closed") { + p.log.Errorw("cannot accept RTP stream", err) + } + return + } + p.mediaReceived.Break() + go p.rtpReadLoop(r) + } +} + +func (p *MediaPort) rtpReadLoop(r rtp.ReadStream) { + buf := make([]byte, rtp.MTUSize+1) + overflow := false + var h rtp.Header + for { + h = rtp.Header{} + n, err := r.ReadRTP(&h, buf) + if err == io.EOF { + return + } else if err != nil { + p.log.Errorw("read RTP failed", err) + return + } + p.packetCount.Add(1) + if n > rtp.MTUSize { + overflow = true + if !overflow { + p.log.Errorw("RTP packet is larger than MTU limit", nil) + } + continue // ignore partial messages + } + + ptr := p.hnd.Load() + if ptr == nil { + continue + } + hnd := *ptr + if hnd == nil { + continue + } + err = hnd.HandleRTP(&h, buf[:n]) + if err != nil { + p.log.Errorw("handle RTP failed", err) + continue + } + } +} + // Must be called holding the lock -func (p *MediaPort) setupOutput() { +func (p *MediaPort) setupOutput() error { + go p.rtpLoop(p.sess) + w, err := p.sess.OpenWriteStream() + if err != nil { + return err + } + // TODO: this says "audio", but actually includes DTMF too - s := rtp.NewSeqWriter(newRTPStatsWriter(p.mon, "audio", p.conn)) + s := rtp.NewSeqWriter(newRTPStatsWriter(p.mon, "audio", w)) p.audioOutRTP = s.NewStream(p.conf.Audio.Type, p.conf.Audio.Codec.Info().RTPClockRate) // Encoding pipeline (LK PCM -> SIP RTP) @@ -233,6 +409,7 @@ func (p *MediaPort) setupOutput() { if w := p.audioOut.Swap(audioOut); w != nil { _ = w.Close() } + return nil } func (p *MediaPort) setupInput() { @@ -244,24 +421,23 @@ func (p *MediaPort) setupInput() { mux.SetDefault(newRTPStatsHandler(p.mon, "", nil)) mux.Register(p.conf.Audio.Type, newRTPStatsHandler(p.mon, p.conf.Audio.Codec.Info().SDPName, audioHandler)) if p.conf.Audio.DTMFType != 0 { - mux.Register(p.conf.Audio.DTMFType, newRTPStatsHandler(p.mon, dtmf.SDPName, rtp.HandlerFunc(func(pck *rtp.Packet) error { + mux.Register(p.conf.Audio.DTMFType, newRTPStatsHandler(p.mon, dtmf.SDPName, rtp.HandlerFunc(func(h *rtp.Header, payload []byte) error { ptr := p.dtmfIn.Load() if ptr == nil { return nil } fnc := *ptr - if ev, ok := dtmf.DecodeRTP(pck); ok && fnc != nil { + if ev, ok := dtmf.DecodeRTP(h, payload); ok && fnc != nil { fnc(ev) } return nil }))) } - + var hnd rtp.Handler = mux if p.jitterEnabled { - p.conn.OnRTP(rtp.HandleJitter(p.conf.Audio.Codec.Info().RTPClockRate, mux)) - } else { - p.conn.OnRTP(mux) + hnd = rtp.HandleJitter(p.conf.Audio.Codec.Info().RTPClockRate, hnd) } + p.hnd.Store(&hnd) } // SetDTMFAudio forces SIP to generate audio dTMF tones in addition to digital signals. diff --git a/pkg/sip/media_port_test.go b/pkg/sip/media_port_test.go index 7bad4887..8c634c1d 100644 --- a/pkg/sip/media_port_test.go +++ b/pkg/sip/media_port_test.go @@ -31,23 +31,67 @@ import ( "github.com/livekit/mediatransportutil/pkg/rtcconfig" "github.com/livekit/protocol/logger" + "github.com/livekit/sip/pkg/media/sdp" "github.com/livekit/sip/pkg/media" "github.com/livekit/sip/pkg/media/rtp" ) type testUDPConn struct { - addr *net.UDPAddr + addr netip.AddrPort closed chan struct{} buf chan []byte peer atomic.Pointer[testUDPConn] } -func (c *testUDPConn) LocalAddr() net.Addr { - return c.addr +func (c *testUDPConn) Read(b []byte) (int, error) { + n, _, err := c.ReadFromUDPAddrPort(b) + return n, err +} + +func (c *testUDPConn) Write(b []byte) (int, error) { + return c.WriteToUDPAddrPort(b, netip.AddrPort{}) +} + +func (c *testUDPConn) RemoteAddr() net.Addr { + p := c.peer.Load() + if p == nil { + return &net.UDPAddr{} + } + return p.LocalAddr() +} + +func (c *testUDPConn) SetDeadline(t time.Time) error { + return nil +} + +func (c *testUDPConn) SetReadDeadline(t time.Time) error { + return nil +} + +func (c *testUDPConn) SetWriteDeadline(t time.Time) error { + return nil +} + +func (c *testUDPConn) ReadFromUDPAddrPort(buf []byte) (int, netip.AddrPort, error) { + peer := c.peer.Load() + if peer == nil { + return 0, netip.AddrPort{}, io.ErrClosedPipe + } + select { + case <-c.closed: + return 0, netip.AddrPort{}, io.ErrClosedPipe + case data := <-c.buf: + n := copy(buf, data) + var err error + if n < len(data) { + err = io.ErrShortBuffer + } + return n, peer.addr, err + } } -func (c *testUDPConn) WriteToUDP(buf []byte, addr *net.UDPAddr) (int, error) { +func (c *testUDPConn) WriteToUDPAddrPort(buf []byte, addr netip.AddrPort) (int, error) { peer := c.peer.Load() if peer == nil { return 0, io.ErrClosedPipe @@ -65,21 +109,10 @@ func (c *testUDPConn) WriteToUDP(buf []byte, addr *net.UDPAddr) (int, error) { } } -func (c *testUDPConn) ReadFromUDP(buf []byte) (int, *net.UDPAddr, error) { - peer := c.peer.Load() - if peer == nil { - return 0, nil, io.ErrClosedPipe - } - select { - case <-c.closed: - return 0, nil, io.ErrClosedPipe - case data := <-c.buf: - n := copy(buf, data) - var err error - if n < len(data) { - err = io.ErrShortBuffer - } - return n, peer.addr, err +func (c *testUDPConn) LocalAddr() net.Addr { + return &net.UDPAddr{ + IP: c.addr.Addr().AsSlice(), + Port: int(c.addr.Port()), } } @@ -90,20 +123,20 @@ func (c *testUDPConn) Close() error { return nil } -func newUDPConn(i int) *testUDPConn { +func newTestConn(i int) *testUDPConn { return &testUDPConn{ - addr: &net.UDPAddr{ - IP: net.IPv4(byte(i), byte(i), byte(i), byte(i)), - Port: 10000 * i, - }, + addr: netip.AddrPortFrom( + netip.AddrFrom4([4]byte{byte(i), byte(i), byte(i), byte(i)}), + uint16(10000*i), + ), buf: make(chan []byte, 10), closed: make(chan struct{}), } } func newUDPPipe() (c1, c2 *testUDPConn) { - c1 = newUDPConn(1) - c2 = newUDPConn(2) + c1 = newTestConn(1) + c2 = newTestConn(2) c1.peer.Store(c2) c2.peer.Store(c1) return @@ -151,43 +184,51 @@ func TestMediaPort(t *testing.T) { nativeRate *= 2 // error in RFC } - for _, rate := range []int{ - nativeRate, - 48000, + for _, tconf := range []struct { + Rate int + Encrypted sdp.Encryption + }{ + {nativeRate, sdp.EncryptionNone}, + {48000, sdp.EncryptionRequire}, } { - t.Run(strconv.Itoa(rate), func(t *testing.T) { + suff := "" + if tconf.Encrypted != sdp.EncryptionNone { + suff = " srtp" + } + t.Run(fmt.Sprintf("%d%s", tconf.Rate, suff), func(t *testing.T) { c1, c2 := newUDPPipe() log := logger.GetLogger() - m1, err := NewMediaPortWith(log.WithName("one"), nil, c1, &MediaConfig{ + m1, err := NewMediaPortWith(log.WithName("one"), nil, c1, &MediaOptions{ IP: newIP("1.1.1.1"), Ports: rtcconfig.PortRange{Start: 10000}, - }, rate) + }, tconf.Rate) require.NoError(t, err) defer m1.Close() - m2, err := NewMediaPortWith(log.WithName("two"), nil, c2, &MediaConfig{ + m2, err := NewMediaPortWith(log.WithName("two"), nil, c2, &MediaOptions{ IP: newIP("2.2.2.2"), Ports: rtcconfig.PortRange{Start: 20000}, - }, rate) + }, tconf.Rate) require.NoError(t, err) defer m2.Close() - offer := m1.NewOffer() + offer, err := m1.NewOffer(tconf.Encrypted) + require.NoError(t, err) offerData, err := offer.SDP.Marshal() require.NoError(t, err) t.Logf("SDP offer:\n%s", string(offerData)) - answer, conf, err := m2.SetOffer(offerData) + answer, conf, err := m2.SetOffer(offerData, tconf.Encrypted) require.NoError(t, err) answerData, err := answer.SDP.Marshal() require.NoError(t, err) t.Logf("SDP answer:\n%s", string(answerData)) - mc, err := m1.SetAnswer(offer, answerData) + mc, err := m1.SetAnswer(offer, answerData, tconf.Encrypted) require.NoError(t, err) err = m1.SetConfig(mc) @@ -200,17 +241,17 @@ func TestMediaPort(t *testing.T) { require.Equal(t, info.SDPName, m2.Config().Audio.Codec.Info().SDPName) var buf1 media.PCM16Sample - bw1 := media.NewPCM16BufferWriter(&buf1, rate) + bw1 := media.NewPCM16BufferWriter(&buf1, tconf.Rate) m1.WriteAudioTo(bw1) var buf2 media.PCM16Sample - bw2 := media.NewPCM16BufferWriter(&buf2, rate) + bw2 := media.NewPCM16BufferWriter(&buf2, tconf.Rate) m2.WriteAudioTo(bw2) w1 := m1.GetAudioWriter() w2 := m2.GetAudioWriter() - packetSize := uint32(rate / int(time.Second/rtp.DefFrameDur)) + packetSize := uint32(tconf.Rate / int(time.Second/rtp.DefFrameDur)) sample1 := make(media.PCM16Sample, packetSize) sample2 := make(media.PCM16Sample, packetSize) for i := range packetSize { @@ -219,7 +260,7 @@ func TestMediaPort(t *testing.T) { } writes := 1 - if rate == nativeRate { + if tconf.Rate == nativeRate { expChain := fmt.Sprintf("Switch(%d) -> %s(encode) -> RTP(%d)", nativeRate, name, nativeRate) require.Equal(t, expChain, w1.String()) require.Equal(t, expChain, w2.String()) diff --git a/pkg/sip/outbound.go b/pkg/sip/outbound.go index b363e0f1..f30364e4 100644 --- a/pkg/sip/outbound.go +++ b/pkg/sip/outbound.go @@ -33,6 +33,7 @@ import ( "github.com/livekit/protocol/utils/guid" "github.com/livekit/psrpc" lksdk "github.com/livekit/server-sdk-go/v2" + "github.com/livekit/sip/pkg/media/sdp" "github.com/livekit/sipgo/sip" "github.com/livekit/sip/pkg/config" @@ -59,6 +60,7 @@ type sipOutboundConfig struct { ringingTimeout time.Duration maxCallDuration time.Duration enabledFeatures []livekit.SIPFeature + mediaEncryption sdp.Encryption } type outboundCall struct { @@ -117,7 +119,7 @@ func (c *Client) newCall(ctx context.Context, conf *config.Config, log logger.Lo call.mon = c.mon.NewCall(stats.Outbound, sipConf.host, sipConf.address) var err error - call.media, err = NewMediaPort(call.log, call.mon, &MediaConfig{ + call.media, err = NewMediaPort(call.log, call.mon, &MediaOptions{ IP: c.sconf.MediaIP, Ports: conf.RTPPort, MediaTimeoutInitial: c.conf.MediaTimeoutInitial, @@ -479,7 +481,10 @@ func (c *outboundCall) sipSignal(ctx context.Context) error { cancel() }() - sdpOffer := c.media.NewOffer() + sdpOffer, err := c.media.NewOffer(c.sipConf.mediaEncryption) + if err != nil { + return err + } sdpOfferData, err := sdpOffer.SDP.Marshal() if err != nil { return err @@ -523,7 +528,7 @@ func (c *outboundCall) sipSignal(ctx context.Context) error { c.log = LoggerWithHeaders(c.log, c.cc) - mc, err := c.media.SetAnswer(sdpOffer, sdpResp) + mc, err := c.media.SetAnswer(sdpOffer, sdpResp, c.sipConf.mediaEncryption) if err != nil { return err } @@ -737,6 +742,7 @@ authLoop: case sip.StatusBadRequest, sip.StatusNotFound, sip.StatusTemporarilyUnavailable, + sip.StatusNotAcceptableHere, sip.StatusBusyHere: err := &livekit.SIPStatus{Code: livekit.SIPStatusCode(resp.StatusCode)} if body := resp.Body(); len(body) != 0 { diff --git a/pkg/sip/participant.go b/pkg/sip/participant.go index 6477ff6a..1e3cbde0 100644 --- a/pkg/sip/participant.go +++ b/pkg/sip/participant.go @@ -18,6 +18,7 @@ import ( "time" "github.com/livekit/protocol/livekit" + "github.com/livekit/sipgo/sip" ) const ( @@ -80,6 +81,15 @@ func (v CallStatus) DisconnectReason() livekit.DisconnectReason { } } +func (v CallStatus) SIPStatus() (sip.StatusCode, string) { + switch v { + default: + return sip.StatusBusyHere, "Rejected" + case callMediaFailed: + return sip.StatusNotAcceptableHere, "MediaFailed" + } +} + const ( callDropped = CallStatus(iota) callFlood @@ -90,4 +100,5 @@ const ( CallHangup callUnavailable callRejected + callMediaFailed ) diff --git a/pkg/sip/protocol.go b/pkg/sip/protocol.go index 39a8c5e3..1382808c 100644 --- a/pkg/sip/protocol.go +++ b/pkg/sip/protocol.go @@ -98,6 +98,13 @@ var statusNamesMap = map[int]string{ 606: "GlobalNotAcceptable", } +func sipStatus(code sip.StatusCode) string { + if name := statusNamesMap[int(code)]; name != "" { + return name + } + return fmt.Sprintf("Status%d", int(code)) +} + func statusName(status int) string { if name := statusNamesMap[status]; name != "" { return fmt.Sprintf("%d-%s", status, name) diff --git a/pkg/sip/server.go b/pkg/sip/server.go index b481bb20..a221f889 100644 --- a/pkg/sip/server.go +++ b/pkg/sip/server.go @@ -96,6 +96,7 @@ type CallDispatch struct { EnabledFeatures []livekit.SIPFeature RingingTimeout time.Duration MaxCallDuration time.Duration + MediaEncryption livekit.SIPMediaEncryption } type Handler interface { diff --git a/pkg/sip/service_test.go b/pkg/sip/service_test.go index 73bf79d1..2e2623e9 100644 --- a/pkg/sip/service_test.go +++ b/pkg/sip/service_test.go @@ -108,7 +108,8 @@ func testInvite(t *testing.T, h Handler, hidden bool, from, to string, test func sipClient, err := sipgo.NewClient(sipUserAgent) require.NoError(t, err) - offer := sdp.NewOffer(localIP, 0xB0B) + offer, err := sdp.NewOffer(localIP, 0xB0B, sdp.EncryptionNone) + require.NoError(t, err) offerData, err := offer.SDP.Marshal() require.NoError(t, err) diff --git a/pkg/sip/types.go b/pkg/sip/types.go index 3a8ddbdc..6185d9d8 100644 --- a/pkg/sip/types.go +++ b/pkg/sip/types.go @@ -23,6 +23,7 @@ import ( "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" + "github.com/livekit/sip/pkg/media/sdp" "github.com/livekit/sipgo/sip" ) @@ -111,7 +112,11 @@ func (u URI) GetHost() string { func (u URI) GetPort() int { port := int(u.Addr.Port()) if port == 0 { - port = 5060 + if u.Transport == TransportTLS { + port = 5061 + } else { + port = 5060 + } } return port } @@ -307,3 +312,15 @@ func AttrsToHeaders(attrs, attrToHdr, headers map[string]string) map[string]stri } return headers } + +func sdpEncryption(e livekit.SIPMediaEncryption) (sdp.Encryption, error) { + switch e { + case livekit.SIPMediaEncryption_SIP_MEDIA_ENCRYPT_DISABLE: + return sdp.EncryptionNone, nil + case livekit.SIPMediaEncryption_SIP_MEDIA_ENCRYPT_ALLOW: + return sdp.EncryptionAllow, nil + case livekit.SIPMediaEncryption_SIP_MEDIA_ENCRYPT_REQUIRE: + return sdp.EncryptionRequire, nil + } + return sdp.EncryptionAllow, errors.New("invalid SIP media encryption type") +} diff --git a/pkg/siptest/client.go b/pkg/siptest/client.go index b9a473dc..2e12df89 100644 --- a/pkg/siptest/client.go +++ b/pkg/siptest/client.go @@ -218,31 +218,31 @@ func (c *Client) Close() { func (c *Client) setupRTPReceiver() { var lastTs atomic.Uint32 - c.mux = rtp.NewMux(rtp.HandlerFunc(func(pck *rtp.Packet) error { - lastTs.Store(pck.Timestamp) + c.mux = rtp.NewMux(rtp.HandlerFunc(func(hdr *rtp.Header, payload []byte) error { + lastTs.Store(hdr.Timestamp) h := c.recordHandler.Load() if h != nil { - return (*h).HandleRTP(pck) + return (*h).HandleRTP(hdr, payload) } return nil })) - c.mux.Register(101, rtp.HandlerFunc(func(pck *rtp.Packet) error { + c.mux.Register(101, rtp.HandlerFunc(func(hdr *rtp.Header, payload []byte) error { ts := lastTs.Load() var diff int64 if ts > 0 { - diff = int64(pck.Timestamp) - int64(ts) + diff = int64(hdr.Timestamp) - int64(ts) } if diff > int64(c.audioCodec.Info().RTPClockRate) || diff < -int64(c.audioCodec.Info().RTPClockRate) { - c.log.Info("reveived out of sync DTMF message", "dtmfTs", pck.Timestamp, "lastTs", ts) + c.log.Info("reveived out of sync DTMF message", "dtmfTs", hdr.Timestamp, "lastTs", ts) return nil } if c.conf.OnDTMF == nil { return nil } - if ev, ok := dtmf.DecodeRTP(pck); ok { + if ev, ok := dtmf.DecodeRTP(hdr, payload); ok { c.conf.OnDTMF(ev) } return nil @@ -664,7 +664,7 @@ func (c *Client) WaitSignals(ctx context.Context, vals []int, w io.WriteCloser) pkts := make(chan *rtp.Packet, 1) done := make(chan struct{}) - h := rtp.Handler(rtp.HandlerFunc(func(pkt *rtp.Packet) error { + h := rtp.Handler(rtp.HandlerFunc(func(hdr *rtp.Header, payload []byte) error { // Make sure er do not send on a closed channel select { case <-done: @@ -677,7 +677,7 @@ func (c *Client) WaitSignals(ctx context.Context, vals []int, w io.WriteCloser) close(pkts) close(done) return ctx.Err() - case pkts <- pkt: + case pkts <- &rtp.Packet{Header: *hdr, Payload: slices.Clone(payload)}: } return nil @@ -696,7 +696,7 @@ func (c *Client) WaitSignals(ctx context.Context, vals []int, w io.WriteCloser) continue } decoded = decoded[:0] - if err := dec.HandleRTP(p); err != nil { + if err := dec.HandleRTP(&p.Header, p.Payload); err != nil { return err } if ws != nil {