diff --git a/internal/rtpbuffer/errors.go b/internal/rtpbuffer/errors.go new file mode 100644 index 00000000..57fca021 --- /dev/null +++ b/internal/rtpbuffer/errors.go @@ -0,0 +1,15 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package rtpbuffer + +import "errors" + +// ErrInvalidSize is returned by newReceiveLog/newRTPBuffer, when an incorrect buffer size is supplied. +var ErrInvalidSize = errors.New("invalid buffer size") + +var ( + errPacketReleased = errors.New("could not retain packet, already released") + errFailedToCastHeaderPool = errors.New("could not access header pool, failed cast") + errFailedToCastPayloadPool = errors.New("could not access payload pool, failed cast") +) diff --git a/pkg/nack/retainable_packet.go b/internal/rtpbuffer/packet_factory.go similarity index 60% rename from pkg/nack/retainable_packet.go rename to internal/rtpbuffer/packet_factory.go index 18c533a8..4ab07fbe 100644 --- a/pkg/nack/retainable_packet.go +++ b/internal/rtpbuffer/packet_factory.go @@ -1,7 +1,7 @@ // SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT -package nack +package rtpbuffer import ( "encoding/binary" @@ -11,16 +11,22 @@ import ( "github.com/pion/rtp" ) -const maxPayloadLen = 1460 +// PacketFactory allows custom logic around the handle of RTP Packets before they added to the RTPBuffer. +// The NoOpPacketFactory doesn't copy packets, while the RetainablePacket will take a copy before adding +type PacketFactory interface { + NewPacket(header *rtp.Header, payload []byte, rtxSsrc uint32, rtxPayloadType uint8) (*RetainablePacket, error) +} -type packetManager struct { +// PacketFactoryCopy is PacketFactory that takes a copy of packets when added to the RTPBuffer +type PacketFactoryCopy struct { headerPool *sync.Pool payloadPool *sync.Pool rtxSequencer rtp.Sequencer } -func newPacketManager() *packetManager { - return &packetManager{ +// NewPacketFactoryCopy constructs a PacketFactory that takes a copy of packets when added to the RTPBuffer +func NewPacketFactoryCopy() *PacketFactoryCopy { + return &PacketFactoryCopy{ headerPool: &sync.Pool{ New: func() interface{} { return &rtp.Header{} @@ -36,12 +42,13 @@ func newPacketManager() *packetManager { } } -func (m *packetManager) NewPacket(header *rtp.Header, payload []byte, rtxSsrc uint32, rtxPayloadType uint8) (*retainablePacket, error) { +// NewPacket constructs a new RetainablePacket that can be added to the RTPBuffer +func (m *PacketFactoryCopy) NewPacket(header *rtp.Header, payload []byte, rtxSsrc uint32, rtxPayloadType uint8) (*RetainablePacket, error) { if len(payload) > maxPayloadLen { return nil, io.ErrShortBuffer } - p := &retainablePacket{ + p := &RetainablePacket{ onRelease: m.releasePacket, sequenceNumber: header.SequenceNumber, // new packets have retain count of 1 @@ -92,17 +99,19 @@ func (m *packetManager) NewPacket(header *rtp.Header, payload []byte, rtxSsrc ui return p, nil } -func (m *packetManager) releasePacket(header *rtp.Header, payload *[]byte) { +func (m *PacketFactoryCopy) releasePacket(header *rtp.Header, payload *[]byte) { m.headerPool.Put(header) if payload != nil { m.payloadPool.Put(payload) } } -type noOpPacketFactory struct{} +// PacketFactoryNoOp is a PacketFactory implementation that doesn't copy packets +type PacketFactoryNoOp struct{} -func (f *noOpPacketFactory) NewPacket(header *rtp.Header, payload []byte, _ uint32, _ uint8) (*retainablePacket, error) { - return &retainablePacket{ +// NewPacket constructs a new RetainablePacket that can be added to the RTPBuffer +func (f *PacketFactoryNoOp) NewPacket(header *rtp.Header, payload []byte, _ uint32, _ uint8) (*RetainablePacket, error) { + return &RetainablePacket{ onRelease: f.releasePacket, count: 1, header: header, @@ -111,52 +120,6 @@ func (f *noOpPacketFactory) NewPacket(header *rtp.Header, payload []byte, _ uint }, nil } -func (f *noOpPacketFactory) releasePacket(_ *rtp.Header, _ *[]byte) { +func (f *PacketFactoryNoOp) releasePacket(_ *rtp.Header, _ *[]byte) { // no-op } - -type retainablePacket struct { - onRelease func(*rtp.Header, *[]byte) - - countMu sync.Mutex - count int - - header *rtp.Header - buffer *[]byte - payload []byte - - sequenceNumber uint16 -} - -func (p *retainablePacket) Header() *rtp.Header { - return p.header -} - -func (p *retainablePacket) Payload() []byte { - return p.payload -} - -func (p *retainablePacket) Retain() error { - p.countMu.Lock() - defer p.countMu.Unlock() - if p.count == 0 { - // already released - return errPacketReleased - } - p.count++ - return nil -} - -func (p *retainablePacket) Release() { - p.countMu.Lock() - defer p.countMu.Unlock() - p.count-- - - if p.count == 0 { - // release back to pool - p.onRelease(p.header, p.buffer) - p.header = nil - p.buffer = nil - p.payload = nil - } -} diff --git a/internal/rtpbuffer/retainable_packet.go b/internal/rtpbuffer/retainable_packet.go new file mode 100644 index 00000000..5a9afd79 --- /dev/null +++ b/internal/rtpbuffer/retainable_packet.go @@ -0,0 +1,61 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package rtpbuffer + +import ( + "sync" + + "github.com/pion/rtp" +) + +// RetainablePacket is a referenced counted RTP packet +type RetainablePacket struct { + onRelease func(*rtp.Header, *[]byte) + + countMu sync.Mutex + count int + + header *rtp.Header + buffer *[]byte + payload []byte + + sequenceNumber uint16 +} + +// Header returns the RTP Header of the RetainablePacket +func (p *RetainablePacket) Header() *rtp.Header { + return p.header +} + +// Payload returns the RTP Payload of the RetainablePacket +func (p *RetainablePacket) Payload() []byte { + return p.payload +} + +// Retain increases the reference count of the RetainablePacket +func (p *RetainablePacket) Retain() error { + p.countMu.Lock() + defer p.countMu.Unlock() + if p.count == 0 { + // already released + return errPacketReleased + } + p.count++ + return nil +} + +// Release decreases the reference count of the RetainablePacket and frees if needed +func (p *RetainablePacket) Release() { + p.countMu.Lock() + defer p.countMu.Unlock() + p.count-- + + if p.count == 0 { + // release back to pool + p.onRelease(p.header, p.buffer) + p.header = nil + p.buffer = nil + p.payload = nil + } +} diff --git a/internal/rtpbuffer/rtpbuffer.go b/internal/rtpbuffer/rtpbuffer.go new file mode 100644 index 00000000..92535074 --- /dev/null +++ b/internal/rtpbuffer/rtpbuffer.go @@ -0,0 +1,103 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +// Package rtpbuffer provides a buffer for storing RTP packets +package rtpbuffer + +import ( + "fmt" +) + +const ( + // Uint16SizeHalf is half of a math.Uint16 + Uint16SizeHalf = 1 << 15 + + maxPayloadLen = 1460 +) + +// RTPBuffer stores RTP packets and allows custom logic around the lifetime of them via the PacketFactory +type RTPBuffer struct { + packets []*RetainablePacket + size uint16 + lastAdded uint16 + started bool +} + +// NewRTPBuffer constructs a new RTPBuffer +func NewRTPBuffer(size uint16) (*RTPBuffer, error) { + allowedSizes := make([]uint16, 0) + correctSize := false + for i := 0; i < 16; i++ { + if size == 1<= Uint16SizeHalf { + return nil + } + + if diff >= r.size { + return nil + } + + pkt := r.packets[seq%r.size] + if pkt != nil { + if pkt.sequenceNumber != seq { + return nil + } + // already released + if err := pkt.Retain(); err != nil { + return nil + } + } + return pkt +} diff --git a/pkg/nack/send_buffer_test.go b/internal/rtpbuffer/rtpbuffer_test.go similarity index 64% rename from pkg/nack/send_buffer_test.go rename to internal/rtpbuffer/rtpbuffer_test.go index 8e45f0f6..746fb5c1 100644 --- a/pkg/nack/send_buffer_test.go +++ b/internal/rtpbuffer/rtpbuffer_test.go @@ -1,7 +1,7 @@ // SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT -package nack +package rtpbuffer import ( "testing" @@ -10,12 +10,12 @@ import ( "github.com/stretchr/testify/require" ) -func TestSendBuffer(t *testing.T) { - pm := newPacketManager() +func TestRTPBuffer(t *testing.T) { + pm := NewPacketFactoryCopy() for _, start := range []uint16{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 511, 512, 513, 32767, 32768, 32769, 65527, 65528, 65529, 65530, 65531, 65532, 65533, 65534, 65535} { start := start - sb, err := newSendBuffer(8) + sb, err := NewRTPBuffer(8) require.NoError(t, err) add := func(nums ...uint16) { @@ -23,7 +23,7 @@ func TestSendBuffer(t *testing.T) { seq := start + n pkt, err := pm.NewPacket(&rtp.Header{SequenceNumber: seq}, nil, 0, 0) require.NoError(t, err) - sb.add(pkt) + sb.Add(pkt) } } @@ -31,7 +31,7 @@ func TestSendBuffer(t *testing.T) { t.Helper() for _, n := range nums { seq := start + n - packet := sb.get(seq) + packet := sb.Get(seq) if packet == nil { t.Errorf("packet not found: %d", seq) continue @@ -46,7 +46,7 @@ func TestSendBuffer(t *testing.T) { t.Helper() for _, n := range nums { seq := start + n - packet := sb.get(seq) + packet := sb.Get(seq) if packet != nil { t.Errorf("packet found for %d: %d", seq, packet.Header().SequenceNumber) } @@ -70,21 +70,21 @@ func TestSendBuffer(t *testing.T) { } } -func TestSendBuffer_Overridden(t *testing.T) { +func TestRTPBuffer_Overridden(t *testing.T) { // override original packet content and get - pm := newPacketManager() - sb, err := newSendBuffer(1) + pm := NewPacketFactoryCopy() + sb, err := NewRTPBuffer(1) require.NoError(t, err) require.Equal(t, uint16(1), sb.size) originalBytes := []byte("originalContent") pkt, err := pm.NewPacket(&rtp.Header{SequenceNumber: 1}, originalBytes, 0, 0) require.NoError(t, err) - sb.add(pkt) + sb.Add(pkt) // change payload copy(originalBytes, "altered") - retrieved := sb.get(1) + retrieved := sb.Get(1) require.NotNil(t, retrieved) require.Equal(t, "originalContent", string(retrieved.Payload())) retrieved.Release() @@ -93,41 +93,8 @@ func TestSendBuffer_Overridden(t *testing.T) { // ensure original packet is released pkt, err = pm.NewPacket(&rtp.Header{SequenceNumber: 2}, originalBytes, 0, 0) require.NoError(t, err) - sb.add(pkt) + sb.Add(pkt) require.Equal(t, 0, retrieved.count) - require.Nil(t, sb.get(1)) -} - -// this test is only useful when being run with the race detector, it won't fail otherwise: -// -// go test -race ./pkg/nack/ -func TestSendBuffer_Race(t *testing.T) { - pm := newPacketManager() - for _, start := range []uint16{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 511, 512, 513, 32767, 32768, 32769, 65527, 65528, 65529, 65530, 65531, 65532, 65533, 65534, 65535} { - start := start - - sb, err := newSendBuffer(8) - require.NoError(t, err) - - add := func(nums ...uint16) { - for _, n := range nums { - seq := start + n - pkt, err := pm.NewPacket(&rtp.Header{SequenceNumber: seq}, nil, 0, 0) - require.NoError(t, err) - sb.add(pkt) - } - } - - get := func(nums ...uint16) { - t.Helper() - for _, n := range nums { - seq := start + n - sb.get(seq) - } - } - - go add(0, 1, 2, 3, 4, 5, 6, 7) - go get(0, 1, 2, 3, 4, 5, 6, 7) - } + require.Nil(t, sb.Get(1)) } diff --git a/pkg/nack/errors.go b/pkg/nack/errors.go index b47ec39c..8b0958d0 100644 --- a/pkg/nack/errors.go +++ b/pkg/nack/errors.go @@ -3,13 +3,7 @@ package nack -import "errors" +import "github.com/pion/interceptor/internal/rtpbuffer" -// ErrInvalidSize is returned by newReceiveLog/newSendBuffer, when an incorrect buffer size is supplied. -var ErrInvalidSize = errors.New("invalid buffer size") - -var ( - errPacketReleased = errors.New("could not retain packet, already released") - errFailedToCastHeaderPool = errors.New("could not access header pool, failed cast") - errFailedToCastPayloadPool = errors.New("could not access payload pool, failed cast") -) +// ErrInvalidSize is returned by newReceiveLog/newRTPBuffer, when an incorrect buffer size is supplied. +var ErrInvalidSize = rtpbuffer.ErrInvalidSize diff --git a/pkg/nack/receive_log.go b/pkg/nack/receive_log.go index 6a19996e..313133e2 100644 --- a/pkg/nack/receive_log.go +++ b/pkg/nack/receive_log.go @@ -6,6 +6,8 @@ package nack import ( "fmt" "sync" + + "github.com/pion/interceptor/internal/rtpbuffer" ) type receiveLog struct { @@ -54,7 +56,7 @@ func (s *receiveLog) add(seq uint16) { switch { case diff == 0: return - case diff < uint16SizeHalf: + case diff < rtpbuffer.Uint16SizeHalf: // this means a positive diff, in other words seq > end (with counting for rollovers) for i := s.end + 1; i != seq; i++ { // clear packets between end and seq (these may contain packets from a "size" ago) @@ -82,7 +84,7 @@ func (s *receiveLog) get(seq uint16) bool { defer s.m.RUnlock() diff := s.end - seq - if diff >= uint16SizeHalf { + if diff >= rtpbuffer.Uint16SizeHalf { return false } @@ -98,7 +100,7 @@ func (s *receiveLog) missingSeqNumbers(skipLastN uint16) []uint16 { defer s.m.RUnlock() until := s.end - skipLastN - if until-s.lastConsecutive >= uint16SizeHalf { + if until-s.lastConsecutive >= rtpbuffer.Uint16SizeHalf { // until < s.lastConsecutive (counting for rollover) return nil } diff --git a/pkg/nack/responder_interceptor.go b/pkg/nack/responder_interceptor.go index 22d038ba..58e34301 100644 --- a/pkg/nack/responder_interceptor.go +++ b/pkg/nack/responder_interceptor.go @@ -7,6 +7,7 @@ import ( "sync" "github.com/pion/interceptor" + "github.com/pion/interceptor/internal/rtpbuffer" "github.com/pion/logging" "github.com/pion/rtcp" "github.com/pion/rtp" @@ -17,10 +18,6 @@ type ResponderInterceptorFactory struct { opts []ResponderOption } -type packetFactory interface { - NewPacket(header *rtp.Header, payload []byte, rtxSsrc uint32, rtxPayloadType uint8) (*retainablePacket, error) -} - // NewInterceptor constructs a new ResponderInterceptor func (r *ResponderInterceptorFactory) NewInterceptor(_ string) (interceptor.Interceptor, error) { i := &ResponderInterceptor{ @@ -37,10 +34,10 @@ func (r *ResponderInterceptorFactory) NewInterceptor(_ string) (interceptor.Inte } if i.packetFactory == nil { - i.packetFactory = newPacketManager() + i.packetFactory = rtpbuffer.NewPacketFactoryCopy() } - if _, err := newSendBuffer(i.size); err != nil { + if _, err := rtpbuffer.NewRTPBuffer(i.size); err != nil { return nil, err } @@ -53,15 +50,16 @@ type ResponderInterceptor struct { streamsFilter func(info *interceptor.StreamInfo) bool size uint16 log logging.LeveledLogger - packetFactory packetFactory + packetFactory rtpbuffer.PacketFactory streams map[uint32]*localStream streamsMu sync.Mutex } type localStream struct { - sendBuffer *sendBuffer - rtpWriter interceptor.RTPWriter + rtpBuffer *rtpbuffer.RTPBuffer + rtpBufferMutex sync.RWMutex + rtpWriter interceptor.RTPWriter } // NewResponderInterceptor returns a new ResponderInterceptorFactor @@ -106,11 +104,11 @@ func (n *ResponderInterceptor) BindLocalStream(info *interceptor.StreamInfo, wri } // error is already checked in NewGeneratorInterceptor - sendBuffer, _ := newSendBuffer(n.size) + rtpBuffer, _ := rtpbuffer.NewRTPBuffer(n.size) n.streamsMu.Lock() n.streams[info.SSRC] = &localStream{ - sendBuffer: sendBuffer, - rtpWriter: writer, + rtpBuffer: rtpBuffer, + rtpWriter: writer, } n.streamsMu.Unlock() @@ -119,7 +117,11 @@ func (n *ResponderInterceptor) BindLocalStream(info *interceptor.StreamInfo, wri if err != nil { return 0, err } - sendBuffer.add(pkt) + n.streams[info.SSRC].rtpBufferMutex.Lock() + defer n.streams[info.SSRC].rtpBufferMutex.Unlock() + + rtpBuffer.Add(pkt) + return writer.Write(header, payload, attributes) }) } @@ -141,7 +143,10 @@ func (n *ResponderInterceptor) resendPackets(nack *rtcp.TransportLayerNack) { for i := range nack.Nacks { nack.Nacks[i].Range(func(seq uint16) bool { - if p := stream.sendBuffer.get(seq); p != nil { + stream.rtpBufferMutex.Lock() + defer stream.rtpBufferMutex.Unlock() + + if p := stream.rtpBuffer.Get(seq); p != nil { if _, err := stream.rtpWriter.Write(p.Header(), p.Payload(), interceptor.Attributes{}); err != nil { n.log.Warnf("failed resending nacked packet: %+v", err) } diff --git a/pkg/nack/responder_interceptor_test.go b/pkg/nack/responder_interceptor_test.go index 68e88d12..019d85e5 100644 --- a/pkg/nack/responder_interceptor_test.go +++ b/pkg/nack/responder_interceptor_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/pion/interceptor" + "github.com/pion/interceptor/internal/rtpbuffer" "github.com/pion/interceptor/internal/test" "github.com/pion/logging" "github.com/pion/rtcp" @@ -111,7 +112,7 @@ func TestResponderInterceptor_DisableCopy(t *testing.T) { require.NoError(t, err) i, err := f.NewInterceptor("id") require.NoError(t, err) - _, ok := i.(*ResponderInterceptor).packetFactory.(*noOpPacketFactory) + _, ok := i.(*ResponderInterceptor).packetFactory.(*rtpbuffer.PacketFactoryNoOp) require.True(t, ok) } diff --git a/pkg/nack/responder_option.go b/pkg/nack/responder_option.go index 24c7c469..fd1a11cd 100644 --- a/pkg/nack/responder_option.go +++ b/pkg/nack/responder_option.go @@ -5,6 +5,7 @@ package nack import ( "github.com/pion/interceptor" + "github.com/pion/interceptor/internal/rtpbuffer" "github.com/pion/logging" ) @@ -32,7 +33,7 @@ func ResponderLog(log logging.LeveledLogger) ResponderOption { // you are not re-using underlying buffers of packets that have been written func DisableCopy() ResponderOption { return func(s *ResponderInterceptor) error { - s.packetFactory = &noOpPacketFactory{} + s.packetFactory = &rtpbuffer.PacketFactoryNoOp{} return nil } } diff --git a/pkg/nack/send_buffer.go b/pkg/nack/send_buffer.go deleted file mode 100644 index 2b3b076f..00000000 --- a/pkg/nack/send_buffer.go +++ /dev/null @@ -1,104 +0,0 @@ -// SPDX-FileCopyrightText: 2023 The Pion community -// SPDX-License-Identifier: MIT - -package nack - -import ( - "fmt" - "sync" -) - -const ( - uint16SizeHalf = 1 << 15 -) - -type sendBuffer struct { - packets []*retainablePacket - size uint16 - lastAdded uint16 - started bool - - m sync.RWMutex -} - -func newSendBuffer(size uint16) (*sendBuffer, error) { - allowedSizes := make([]uint16, 0) - correctSize := false - for i := 0; i < 16; i++ { - if size == 1<= uint16SizeHalf { - return nil - } - - if diff >= s.size { - return nil - } - - pkt := s.packets[seq%s.size] - if pkt != nil { - if pkt.sequenceNumber != seq { - return nil - } - // already released - if err := pkt.Retain(); err != nil { - return nil - } - } - return pkt -}