diff --git a/pkg/nack/generator_interceptor_test.go b/pkg/nack/generator_interceptor_test.go index 46a812a..5ea926c 100644 --- a/pkg/nack/generator_interceptor_test.go +++ b/pkg/nack/generator_interceptor_test.go @@ -76,3 +76,60 @@ func TestGeneratorInterceptor_InvalidSize(t *testing.T) { _, err := f.NewInterceptor("") assert.Error(t, err, ErrInvalidSize) } + +func TestGeneratorInterceptor_StreamFilter(t *testing.T) { + const interval = time.Millisecond * 10 + f, err := NewGeneratorInterceptor( + GeneratorSize(64), + GeneratorSkipLastN(2), + GeneratorInterval(interval), + GeneratorLog(logging.NewDefaultLoggerFactory().NewLogger("test")), + GeneratorStreamsFilter(func(info *interceptor.StreamInfo) bool { + return info.SSRC != 1 // enable nacks only for ssrc 2 + }), + ) + assert.NoError(t, err) + + i, err := f.NewInterceptor("") + assert.NoError(t, err) + + streamWithoutNacks := test.NewMockStream(&interceptor.StreamInfo{ + SSRC: 1, + RTCPFeedback: []interceptor.RTCPFeedback{{Type: "nack"}}, + }, i) + defer func() { + assert.NoError(t, streamWithoutNacks.Close()) + }() + + streamWithNacks := test.NewMockStream(&interceptor.StreamInfo{ + SSRC: 2, + RTCPFeedback: []interceptor.RTCPFeedback{{Type: "nack"}}, + }, i) + defer func() { + assert.NoError(t, streamWithNacks.Close()) + }() + + for _, seqNum := range []uint16{10, 11, 12, 14, 16, 18} { + streamWithNacks.ReceiveRTP(&rtp.Packet{Header: rtp.Header{SequenceNumber: seqNum}}) + streamWithoutNacks.ReceiveRTP(&rtp.Packet{Header: rtp.Header{SequenceNumber: seqNum}}) + } + + time.Sleep(interval * 2) // wait for at least 2 nack packets + + // both test streams receive RTCP packets about both test streams (as they both call BindRTCPWriter), so we + // can check only one + rtcpStream := streamWithNacks.WrittenRTCP() + + for { + select { + case pkts := <-rtcpStream: + for _, pkt := range pkts { + if nack, isNack := pkt.(*rtcp.TransportLayerNack); isNack { + assert.NotEqual(t, uint32(1), nack.MediaSSRC) // check there are no nacks for ssrc 1 + } + } + default: + return + } + } +} diff --git a/pkg/nack/responder_interceptor_test.go b/pkg/nack/responder_interceptor_test.go index 360e142..9eb5b23 100644 --- a/pkg/nack/responder_interceptor_test.go +++ b/pkg/nack/responder_interceptor_test.go @@ -150,3 +150,84 @@ func TestResponderInterceptor_Race(t *testing.T) { } } } + +func TestResponderInterceptor_StreamFilter(t *testing.T) { + f, err := NewResponderInterceptor( + ResponderSize(8), + ResponderLog(logging.NewDefaultLoggerFactory().NewLogger("test")), + ResponderStreamsFilter(func(info *interceptor.StreamInfo) bool { + return info.SSRC != 1 // enable nacks only for ssrc 2 + })) + + require.NoError(t, err) + + i, err := f.NewInterceptor("") + require.NoError(t, err) + + streamWithoutNacks := test.NewMockStream(&interceptor.StreamInfo{ + SSRC: 1, + RTCPFeedback: []interceptor.RTCPFeedback{{Type: "nack"}}, + }, i) + defer func() { + require.NoError(t, streamWithoutNacks.Close()) + }() + + streamWithNacks := test.NewMockStream(&interceptor.StreamInfo{ + SSRC: 2, + RTCPFeedback: []interceptor.RTCPFeedback{{Type: "nack"}}, + }, i) + defer func() { + require.NoError(t, streamWithNacks.Close()) + }() + + for _, seqNum := range []uint16{10, 11, 12, 14, 15} { + require.NoError(t, streamWithoutNacks.WriteRTP(&rtp.Packet{Header: rtp.Header{SequenceNumber: seqNum, SSRC: 1}})) + require.NoError(t, streamWithNacks.WriteRTP(&rtp.Packet{Header: rtp.Header{SequenceNumber: seqNum, SSRC: 2}})) + + select { + case p := <-streamWithoutNacks.WrittenRTP(): + require.Equal(t, seqNum, p.SequenceNumber) + case <-time.After(10 * time.Millisecond): + t.Fatal("written rtp packet not found") + } + + select { + case p := <-streamWithNacks.WrittenRTP(): + require.Equal(t, seqNum, p.SequenceNumber) + case <-time.After(10 * time.Millisecond): + t.Fatal("written rtp packet not found") + } + } + + streamWithoutNacks.ReceiveRTCP([]rtcp.Packet{ + &rtcp.TransportLayerNack{ + MediaSSRC: 1, + SenderSSRC: 2, + Nacks: []rtcp.NackPair{ + {PacketID: 11, LostPackets: 0b1011}, // sequence numbers: 11, 12, 13, 15 + }, + }, + }) + + streamWithNacks.ReceiveRTCP([]rtcp.Packet{ + &rtcp.TransportLayerNack{ + MediaSSRC: 2, + SenderSSRC: 2, + Nacks: []rtcp.NackPair{ + {PacketID: 11, LostPackets: 0b1011}, // sequence numbers: 11, 12, 13, 15 + }, + }, + }) + + select { + case <-streamWithNacks.WrittenRTP(): + case <-time.After(10 * time.Millisecond): + t.Fatal("nack response expected") + } + + select { + case <-streamWithoutNacks.WrittenRTP(): + t.Fatal("no nack response expected") + case <-time.After(10 * time.Millisecond): + } +}