diff --git a/pkg/report/receiver_stream.go b/pkg/report/receiver_stream.go index b899bb1b..3f239c2e 100644 --- a/pkg/report/receiver_stream.go +++ b/pkg/report/receiver_stream.go @@ -12,6 +12,13 @@ import ( "github.com/pion/rtp" ) +const ( + // packetsPerHistoryEntry represents how many packets are in the bitmask for + // each entry in the `packets` slice in the receiver stream. Because we use + // a uint64, we can keep track of 64 packets per entry. + packetsPerHistoryEntry = 64 +) + type receiverStream struct { ssrc uint32 receiverSSRC uint32 @@ -86,18 +93,18 @@ func (stream *receiverStream) processRTP(now time.Time, pktHeader *rtp.Header) { } func (stream *receiverStream) setReceived(seq uint16) { - pos := seq % stream.size - stream.packets[pos/64] |= 1 << (pos % 64) + pos := seq % (stream.size * packetsPerHistoryEntry) + stream.packets[pos/packetsPerHistoryEntry] |= 1 << (pos % packetsPerHistoryEntry) } func (stream *receiverStream) delReceived(seq uint16) { - pos := seq % stream.size - stream.packets[pos/64] &^= 1 << (pos % 64) + pos := seq % (stream.size * packetsPerHistoryEntry) + stream.packets[pos/packetsPerHistoryEntry] &^= 1 << (pos % packetsPerHistoryEntry) } func (stream *receiverStream) getReceived(seq uint16) bool { - pos := seq % stream.size - return (stream.packets[pos/64] & (1 << (pos % 64))) != 0 + pos := seq % (stream.size * packetsPerHistoryEntry) + return (stream.packets[pos/packetsPerHistoryEntry] & (1 << (pos % packetsPerHistoryEntry))) != 0 } func (stream *receiverStream) processSenderReport(now time.Time, sr *rtcp.SenderReport) { diff --git a/pkg/report/receiver_stream_test.go b/pkg/report/receiver_stream_test.go new file mode 100644 index 00000000..c04e3fd8 --- /dev/null +++ b/pkg/report/receiver_stream_test.go @@ -0,0 +1,31 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package report + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestReceiverStream(t *testing.T) { + t.Run("can use entire history size", func(t *testing.T) { + stream := newReceiverStream(12345, 90000) + maxPackets := stream.size * packetsPerHistoryEntry + + // We shouldn't wrap around so long as we only try maxPackets worth. + for seq := uint16(0); seq < maxPackets; seq++ { + require.False(t, stream.getReceived(seq), "packet with SN %v shouldn't be received yet", seq) + stream.setReceived(seq) + require.True(t, stream.getReceived(seq), "packet with SN %v should now be received", seq) + } + + // Delete should also work. + for seq := uint16(0); seq < maxPackets; seq++ { + require.True(t, stream.getReceived(seq), "packet with SN %v should still be marked as received", seq) + stream.delReceived(seq) + require.False(t, stream.getReceived(seq), "packet with SN %v should no longer be received", seq) + } + }) +}