diff --git a/client/api.go b/client/api.go index 76b04d6..61d7f8f 100644 --- a/client/api.go +++ b/client/api.go @@ -109,7 +109,7 @@ func (c *Client) StartScreenShare(tracks []webrtc.TrackLocal) (*webrtc.RTPTransc } } - c.screenTransceiver = trx + c.screenTransceivers = append(c.screenTransceivers, trx) sender := trx.Sender() @@ -131,13 +131,14 @@ func (c *Client) StopScreenShare() error { c.mut.Lock() defer c.mut.Unlock() - if c.screenTransceiver != nil { - if err := c.pc.RemoveTrack(c.screenTransceiver.Sender()); err != nil { + for _, trx := range c.screenTransceivers { + if err := c.pc.RemoveTrack(trx.Sender()); err != nil { return fmt.Errorf("failed to remove track: %w", err) } - c.screenTransceiver = nil } + c.screenTransceivers = nil + return c.sendWS(wsEventScreenOff, nil, false) } diff --git a/client/api_test.go b/client/api_test.go index 5ae877f..3fa6564 100644 --- a/client/api_test.go +++ b/client/api_test.go @@ -501,7 +501,7 @@ func TestAPIScreenShare(t *testing.T) { require.NoError(t, err) t.Run("not initialized", func(t *testing.T) { - _, err := th.userClient.StartScreenShare([]webrtc.TrackLocal{th.newScreenTrack()}) + _, err := th.userClient.StartScreenShare([]webrtc.TrackLocal{th.newScreenTrack(webrtc.MimeTypeVP8)}) require.EqualError(t, err, "rtc client is not initialized") }) @@ -533,7 +533,7 @@ func TestAPIScreenShare(t *testing.T) { // Test logic // User screen shares, admin should receive the track - userScreenTrack := th.newScreenTrack() + userScreenTrack := th.newScreenTrack(webrtc.MimeTypeVP8) _, err = th.userClient.StartScreenShare([]webrtc.TrackLocal{userScreenTrack}) require.NoError(t, err) go th.screenTrackWriter(userScreenTrack, userCloseCh) @@ -623,6 +623,154 @@ func TestAPIScreenShare(t *testing.T) { } } +func TestAPIScreenShareAV1(t *testing.T) { + th := setupTestHelper(t, "calls0") + + th.userClient.cfg.EnableAV1 = true + th.adminClient.cfg.EnableAV1 = true + + // Setup + userConnectCh := make(chan struct{}) + err := th.userClient.On(RTCConnectEvent, func(_ any) error { + close(userConnectCh) + return nil + }) + require.NoError(t, err) + + adminConnectCh := make(chan struct{}) + err = th.adminClient.On(RTCConnectEvent, func(_ any) error { + close(adminConnectCh) + return nil + }) + require.NoError(t, err) + + t.Run("not initialized", func(t *testing.T) { + _, err := th.userClient.StartScreenShare([]webrtc.TrackLocal{th.newScreenTrack(webrtc.MimeTypeAV1)}) + require.EqualError(t, err, "rtc client is not initialized") + }) + + go func() { + err := th.userClient.Connect() + require.NoError(t, err) + }() + + go func() { + err := th.adminClient.Connect() + require.NoError(t, err) + }() + + select { + case <-userConnectCh: + case <-time.After(waitTimeout): + require.Fail(t, "timed out waiting for user connect event") + } + + select { + case <-adminConnectCh: + case <-time.After(waitTimeout): + require.Fail(t, "timed out waiting for admin connect event") + } + + userCloseCh := make(chan struct{}) + adminCloseCh := make(chan struct{}) + + // Test logic + + // User screen shares, admin should receive the track + userScreenTrack := th.newScreenTrack(webrtc.MimeTypeAV1) + _, err = th.userClient.StartScreenShare([]webrtc.TrackLocal{userScreenTrack}) + require.NoError(t, err) + go th.screenTrackWriter(userScreenTrack, userCloseCh) + + screenTrackCh := make(chan struct{}) + err = th.adminClient.On(RTCTrackEvent, func(ctx any) error { + m := ctx.(map[string]any) + track := m["track"].(*webrtc.TrackRemote) + if track.Codec().MimeType == webrtc.MimeTypeAV1 { + close(screenTrackCh) + } + return nil + }) + require.NoError(t, err) + + userScreenOnCh := make(chan struct{}) + err = th.adminClient.On(WSCallScreenOnEvent, func(ctx any) error { + sessionID := ctx.(string) + if sessionID == th.userClient.originalConnID { + close(userScreenOnCh) + } + return nil + }) + require.NoError(t, err) + + select { + case <-userScreenOnCh: + case <-time.After(waitTimeout): + require.Fail(t, "timed out waiting for user screen on event") + } + + select { + case <-screenTrackCh: + case <-time.After(waitTimeout): + require.Fail(t, "timed out waiting for screen track") + } + + userScreenOffCh := make(chan struct{}) + err = th.adminClient.On(WSCallScreenOffEvent, func(ctx any) error { + sessionID := ctx.(string) + if sessionID == th.userClient.originalConnID { + close(userScreenOffCh) + } + return nil + }) + require.NoError(t, err) + + err = th.userClient.StopScreenShare() + require.NoError(t, err) + + select { + case <-userScreenOffCh: + case <-time.After(waitTimeout): + require.Fail(t, "timed out waiting for user screen off event") + } + + // Teardown + + err = th.userClient.On(CloseEvent, func(_ any) error { + close(userCloseCh) + return nil + }) + require.NoError(t, err) + + err = th.adminClient.On(CloseEvent, func(_ any) error { + close(adminCloseCh) + return nil + }) + require.NoError(t, err) + + go func() { + err := th.userClient.Close() + require.NoError(t, err) + }() + + go func() { + err := th.adminClient.Close() + require.NoError(t, err) + }() + + select { + case <-userCloseCh: + case <-time.After(waitTimeout): + require.Fail(t, "timed out waiting for close event") + } + + select { + case <-adminCloseCh: + case <-time.After(waitTimeout): + require.Fail(t, "timed out waiting for close event") + } +} + func TestAPIConcurrency(t *testing.T) { t.Run("Mute/Unmute", func(t *testing.T) { th := setupTestHelper(t, "calls0") @@ -746,7 +894,7 @@ func TestAPIScreenShareAndVoice(t *testing.T) { // Test logic // User screen shares, admin should receive the track - userScreenTrack := th.newScreenTrack() + userScreenTrack := th.newScreenTrack(webrtc.MimeTypeVP8) _, err = th.userClient.StartScreenShare([]webrtc.TrackLocal{userScreenTrack}) require.NoError(t, err) go th.screenTrackWriter(userScreenTrack, userCloseCh) diff --git a/client/call.go b/client/call.go index a1dfb39..48d7b68 100644 --- a/client/call.go +++ b/client/call.go @@ -9,8 +9,9 @@ import ( func (c *Client) joinCall() error { if err := c.SendWS(wsEventJoin, CallJoinMessage{ - ChannelID: c.cfg.ChannelID, - JobID: c.cfg.JobID, + ChannelID: c.cfg.ChannelID, + JobID: c.cfg.JobID, + AV1Support: c.cfg.EnableAV1, }, false); err != nil { return fmt.Errorf("failed to send ws msg: %w", err) } diff --git a/client/client.go b/client/client.go index f290e4f..7963c88 100644 --- a/client/client.go +++ b/client/client.go @@ -97,12 +97,12 @@ type Client struct { currentConnID string // WebRTC - pc *webrtc.PeerConnection - dc *webrtc.DataChannel - iceCh chan webrtc.ICECandidateInit - receivers map[string][]*webrtc.RTPReceiver - voiceSender *webrtc.RTPSender - screenTransceiver *webrtc.RTPTransceiver + pc *webrtc.PeerConnection + dc *webrtc.DataChannel + iceCh chan webrtc.ICECandidateInit + receivers map[string][]*webrtc.RTPReceiver + voiceSender *webrtc.RTPSender + screenTransceivers []*webrtc.RTPTransceiver state int32 diff --git a/client/config.go b/client/config.go index 84f4564..862ee8c 100644 --- a/client/config.go +++ b/client/config.go @@ -22,6 +22,9 @@ type Config struct { // JobID is an id used to identify bot initiated sessions (e.g. // recording/transcription) JobID string + // EnableAV1 controls whether the client should advertise support + // for receiving the AV1 codec. + EnableAV1 bool wsURL string } diff --git a/client/helper_test.go b/client/helper_test.go index 84ef8c2..7e87eca 100644 --- a/client/helper_test.go +++ b/client/helper_test.go @@ -48,11 +48,11 @@ const ( waitTimeout = 5 * time.Second ) -func (th *TestHelper) newScreenTrack() *webrtc.TrackLocalStaticRTP { +func (th *TestHelper) newScreenTrack(mimeType string) *webrtc.TrackLocalStaticRTP { th.tb.Helper() track, err := webrtc.NewTrackLocalStaticRTP(webrtc.RTPCodecCapability{ - MimeType: "video/VP8", + MimeType: mimeType, ClockRate: 90000, SDPFmtpLine: "", RTCPFeedback: []webrtc.RTCPFeedback{ @@ -68,19 +68,27 @@ func (th *TestHelper) newScreenTrack() *webrtc.TrackLocalStaticRTP { } func (th *TestHelper) screenTrackWriter(track *webrtc.TrackLocalStaticRTP, closeCh <-chan struct{}) { + var payloader rtp.Payloader + payloader = &codecs.VP8Payloader{ + EnablePictureID: true, + } + filename := "../testfiles/video.ivf" + if track.Codec().MimeType == webrtc.MimeTypeAV1 { + payloader = &codecs.AV1Payloader{} + filename = "../testfiles/video_av1.ivf" + } + packetizer := rtp.NewPacketizer( 1200, 0, 0, - &codecs.VP8Payloader{ - EnablePictureID: true, - }, + payloader, rtp.NewRandomSequencer(), 90000, ) // Open a IVF file and start reading using our IVFReader - file, ivfErr := os.Open("../testfiles/video.ivf") + file, ivfErr := os.Open(filename) if ivfErr != nil { log.Fatalf(ivfErr.Error()) } @@ -139,7 +147,7 @@ func (th *TestHelper) screenTrackWriter(track *webrtc.TrackLocalStaticRTP, close func (th *TestHelper) transmitScreenTrack(c *Client) { th.tb.Helper() - track := th.newScreenTrack() + track := th.newScreenTrack(webrtc.MimeTypeVP8) sender, err := c.pc.AddTrack(track) require.NoError(th.tb, err) diff --git a/client/types.go b/client/types.go index d3f0d20..4089998 100644 --- a/client/types.go +++ b/client/types.go @@ -6,8 +6,9 @@ package client const pluginID = "com.mattermost.calls" type CallJoinMessage struct { - ChannelID string `json:"channelID"` - JobID string `json:"jobID"` + ChannelID string `json:"channelID"` + JobID string `json:"jobID"` + AV1Support bool `json:"av1Support"` } type CallReconnectMessage struct { diff --git a/testfiles/video_av1.ivf b/testfiles/video_av1.ivf new file mode 100644 index 0000000..dacecee Binary files /dev/null and b/testfiles/video_av1.ivf differ