Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AV1 support to client implementation #152

Merged
merged 1 commit into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ func (c *Client) StartScreenShare(tracks []webrtc.TrackLocal) (*webrtc.RTPTransc
}
}

c.screenTransceiver = trx
c.screenTransceivers = append(c.screenTransceivers, trx)

sender := trx.Sender()

Expand All @@ -131,13 +131,14 @@ func (c *Client) StopScreenShare() error {
c.mut.Lock()
defer c.mut.Unlock()

if c.screenTransceiver != nil {
if err := c.pc.RemoveTrack(c.screenTransceiver.Sender()); err != nil {
for _, trx := range c.screenTransceivers {
if err := c.pc.RemoveTrack(trx.Sender()); err != nil {
return fmt.Errorf("failed to remove track: %w", err)
}
c.screenTransceiver = nil
}

c.screenTransceivers = nil

return c.sendWS(wsEventScreenOff, nil, false)
}

Expand Down
154 changes: 151 additions & 3 deletions client/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ func TestAPIScreenShare(t *testing.T) {
require.NoError(t, err)

t.Run("not initialized", func(t *testing.T) {
_, err := th.userClient.StartScreenShare([]webrtc.TrackLocal{th.newScreenTrack()})
_, err := th.userClient.StartScreenShare([]webrtc.TrackLocal{th.newScreenTrack(webrtc.MimeTypeVP8)})
require.EqualError(t, err, "rtc client is not initialized")
})

Expand Down Expand Up @@ -533,7 +533,7 @@ func TestAPIScreenShare(t *testing.T) {
// Test logic

// User screen shares, admin should receive the track
userScreenTrack := th.newScreenTrack()
userScreenTrack := th.newScreenTrack(webrtc.MimeTypeVP8)
_, err = th.userClient.StartScreenShare([]webrtc.TrackLocal{userScreenTrack})
require.NoError(t, err)
go th.screenTrackWriter(userScreenTrack, userCloseCh)
Expand Down Expand Up @@ -623,6 +623,154 @@ func TestAPIScreenShare(t *testing.T) {
}
}

func TestAPIScreenShareAV1(t *testing.T) {
th := setupTestHelper(t, "calls0")

th.userClient.cfg.EnableAV1 = true
th.adminClient.cfg.EnableAV1 = true

// Setup
userConnectCh := make(chan struct{})
err := th.userClient.On(RTCConnectEvent, func(_ any) error {
close(userConnectCh)
return nil
})
require.NoError(t, err)

adminConnectCh := make(chan struct{})
err = th.adminClient.On(RTCConnectEvent, func(_ any) error {
close(adminConnectCh)
return nil
})
require.NoError(t, err)

t.Run("not initialized", func(t *testing.T) {
_, err := th.userClient.StartScreenShare([]webrtc.TrackLocal{th.newScreenTrack(webrtc.MimeTypeAV1)})
require.EqualError(t, err, "rtc client is not initialized")
})

go func() {
err := th.userClient.Connect()
require.NoError(t, err)
}()

go func() {
err := th.adminClient.Connect()
require.NoError(t, err)
}()

select {
case <-userConnectCh:
case <-time.After(waitTimeout):
require.Fail(t, "timed out waiting for user connect event")
}

select {
case <-adminConnectCh:
case <-time.After(waitTimeout):
require.Fail(t, "timed out waiting for admin connect event")
}

userCloseCh := make(chan struct{})
adminCloseCh := make(chan struct{})

// Test logic

// User screen shares, admin should receive the track
userScreenTrack := th.newScreenTrack(webrtc.MimeTypeAV1)
_, err = th.userClient.StartScreenShare([]webrtc.TrackLocal{userScreenTrack})
require.NoError(t, err)
go th.screenTrackWriter(userScreenTrack, userCloseCh)

screenTrackCh := make(chan struct{})
err = th.adminClient.On(RTCTrackEvent, func(ctx any) error {
m := ctx.(map[string]any)
track := m["track"].(*webrtc.TrackRemote)
if track.Codec().MimeType == webrtc.MimeTypeAV1 {
close(screenTrackCh)
}
return nil
})
require.NoError(t, err)

userScreenOnCh := make(chan struct{})
err = th.adminClient.On(WSCallScreenOnEvent, func(ctx any) error {
sessionID := ctx.(string)
if sessionID == th.userClient.originalConnID {
close(userScreenOnCh)
}
return nil
})
require.NoError(t, err)

select {
case <-userScreenOnCh:
case <-time.After(waitTimeout):
require.Fail(t, "timed out waiting for user screen on event")
}

select {
case <-screenTrackCh:
case <-time.After(waitTimeout):
require.Fail(t, "timed out waiting for screen track")
}

userScreenOffCh := make(chan struct{})
err = th.adminClient.On(WSCallScreenOffEvent, func(ctx any) error {
sessionID := ctx.(string)
if sessionID == th.userClient.originalConnID {
close(userScreenOffCh)
}
return nil
})
require.NoError(t, err)

err = th.userClient.StopScreenShare()
require.NoError(t, err)

select {
case <-userScreenOffCh:
case <-time.After(waitTimeout):
require.Fail(t, "timed out waiting for user screen off event")
}

// Teardown

err = th.userClient.On(CloseEvent, func(_ any) error {
close(userCloseCh)
return nil
})
require.NoError(t, err)

err = th.adminClient.On(CloseEvent, func(_ any) error {
close(adminCloseCh)
return nil
})
require.NoError(t, err)

go func() {
err := th.userClient.Close()
require.NoError(t, err)
}()

go func() {
err := th.adminClient.Close()
require.NoError(t, err)
}()

select {
case <-userCloseCh:
case <-time.After(waitTimeout):
require.Fail(t, "timed out waiting for close event")
}

select {
case <-adminCloseCh:
case <-time.After(waitTimeout):
require.Fail(t, "timed out waiting for close event")
}
}

func TestAPIConcurrency(t *testing.T) {
t.Run("Mute/Unmute", func(t *testing.T) {
th := setupTestHelper(t, "calls0")
Expand Down Expand Up @@ -746,7 +894,7 @@ func TestAPIScreenShareAndVoice(t *testing.T) {
// Test logic

// User screen shares, admin should receive the track
userScreenTrack := th.newScreenTrack()
userScreenTrack := th.newScreenTrack(webrtc.MimeTypeVP8)
_, err = th.userClient.StartScreenShare([]webrtc.TrackLocal{userScreenTrack})
require.NoError(t, err)
go th.screenTrackWriter(userScreenTrack, userCloseCh)
Expand Down
5 changes: 3 additions & 2 deletions client/call.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ import (

func (c *Client) joinCall() error {
if err := c.SendWS(wsEventJoin, CallJoinMessage{
ChannelID: c.cfg.ChannelID,
JobID: c.cfg.JobID,
ChannelID: c.cfg.ChannelID,
JobID: c.cfg.JobID,
AV1Support: c.cfg.EnableAV1,
}, false); err != nil {
return fmt.Errorf("failed to send ws msg: %w", err)
}
Expand Down
12 changes: 6 additions & 6 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,12 @@ type Client struct {
currentConnID string

// WebRTC
pc *webrtc.PeerConnection
dc *webrtc.DataChannel
iceCh chan webrtc.ICECandidateInit
receivers map[string][]*webrtc.RTPReceiver
voiceSender *webrtc.RTPSender
screenTransceiver *webrtc.RTPTransceiver
pc *webrtc.PeerConnection
dc *webrtc.DataChannel
iceCh chan webrtc.ICECandidateInit
receivers map[string][]*webrtc.RTPReceiver
voiceSender *webrtc.RTPSender
screenTransceivers []*webrtc.RTPTransceiver

state int32

Expand Down
3 changes: 3 additions & 0 deletions client/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ type Config struct {
// JobID is an id used to identify bot initiated sessions (e.g.
// recording/transcription)
JobID string
// EnableAV1 controls whether the client should advertise support
// for receiving the AV1 codec.
EnableAV1 bool

wsURL string
}
Expand Down
22 changes: 15 additions & 7 deletions client/helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,11 @@ const (
waitTimeout = 5 * time.Second
)

func (th *TestHelper) newScreenTrack() *webrtc.TrackLocalStaticRTP {
func (th *TestHelper) newScreenTrack(mimeType string) *webrtc.TrackLocalStaticRTP {
th.tb.Helper()

track, err := webrtc.NewTrackLocalStaticRTP(webrtc.RTPCodecCapability{
MimeType: "video/VP8",
MimeType: mimeType,
ClockRate: 90000,
SDPFmtpLine: "",
RTCPFeedback: []webrtc.RTCPFeedback{
Expand All @@ -68,19 +68,27 @@ func (th *TestHelper) newScreenTrack() *webrtc.TrackLocalStaticRTP {
}

func (th *TestHelper) screenTrackWriter(track *webrtc.TrackLocalStaticRTP, closeCh <-chan struct{}) {
var payloader rtp.Payloader
payloader = &codecs.VP8Payloader{
EnablePictureID: true,
}
filename := "../testfiles/video.ivf"
if track.Codec().MimeType == webrtc.MimeTypeAV1 {
payloader = &codecs.AV1Payloader{}
filename = "../testfiles/video_av1.ivf"
}

packetizer := rtp.NewPacketizer(
1200,
0,
0,
&codecs.VP8Payloader{
EnablePictureID: true,
},
payloader,
rtp.NewRandomSequencer(),
90000,
)

// Open a IVF file and start reading using our IVFReader
file, ivfErr := os.Open("../testfiles/video.ivf")
file, ivfErr := os.Open(filename)
if ivfErr != nil {
log.Fatalf(ivfErr.Error())
}
Expand Down Expand Up @@ -139,7 +147,7 @@ func (th *TestHelper) screenTrackWriter(track *webrtc.TrackLocalStaticRTP, close
func (th *TestHelper) transmitScreenTrack(c *Client) {
th.tb.Helper()

track := th.newScreenTrack()
track := th.newScreenTrack(webrtc.MimeTypeVP8)

sender, err := c.pc.AddTrack(track)
require.NoError(th.tb, err)
Expand Down
5 changes: 3 additions & 2 deletions client/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@ package client
const pluginID = "com.mattermost.calls"

type CallJoinMessage struct {
ChannelID string `json:"channelID"`
JobID string `json:"jobID"`
ChannelID string `json:"channelID"`
JobID string `json:"jobID"`
AV1Support bool `json:"av1Support"`
}

type CallReconnectMessage struct {
Expand Down
Binary file added testfiles/video_av1.ivf
Binary file not shown.
Loading