Skip to content

Commit

Permalink
Fix tsh play --skip-idle-time not working correctly (#47304)
Browse files Browse the repository at this point in the history
* fix(player): use skip idle flag and adjust max value

* test(player): increase timeout

* refactor(player): use time.Duration instead of float64 for timings

* refactor(player): store duration values in nanoseconds
  • Loading branch information
gabrielcorado authored Nov 4, 2024
1 parent c5333f4 commit 0b33c2c
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 36 deletions.
6 changes: 2 additions & 4 deletions lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -2330,13 +2330,11 @@ func playSession(ctx context.Context, sessionID string, speed float64, streamer
}
playing = !playing
case keyLeft, keyDown:
current := time.Duration(player.LastPlayed() * int64(time.Millisecond))
player.SetPos(max(current-skipDuration, 0)) // rewind
player.SetPos(max(player.LastPlayed()-skipDuration, 0)) // rewind
term.Clear()
term.SetCursorPos(1, 1)
case keyRight, keyUp:
current := time.Duration(player.LastPlayed() * int64(time.Millisecond))
player.SetPos(current + skipDuration) // advance forward
player.SetPos(player.LastPlayed() + skipDuration) // advance forward
}
}
}()
Expand Down
68 changes: 37 additions & 31 deletions lib/player/player.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ type Player struct {
advanceTo atomic.Int64

emit chan events.AuditEvent
wake chan int64
wake chan time.Duration
done chan struct{}

// playPause holds a channel to be closed when
Expand All @@ -82,7 +82,12 @@ type Player struct {
translator sessionPrintTranslator
}

const normalPlayback = math.MinInt64
const (
normalPlayback = time.Duration(0)
// MaxIdleTime defines the max idle time when skipping idle
// periods on the recording.
MaxIdleTime = 500 * time.Millisecond
)

// Streamer is the underlying streamer that provides
// access to recorded session events.
Expand Down Expand Up @@ -135,18 +140,19 @@ func New(cfg *Config) (*Player, error) {
)

p := &Player{
clock: clk,
log: log,
sessionID: cfg.SessionID,
streamer: cfg.Streamer,
emit: make(chan events.AuditEvent, 1024),
playPause: make(chan chan struct{}, 1),
wake: make(chan int64),
done: make(chan struct{}),
clock: clk,
log: log,
sessionID: cfg.SessionID,
streamer: cfg.Streamer,
emit: make(chan events.AuditEvent, 1024),
playPause: make(chan chan struct{}, 1),
wake: make(chan time.Duration),
done: make(chan struct{}),
skipIdleTime: cfg.SkipIdleTime,
}

p.speed.Store(float64(defaultPlaybackSpeed))
p.advanceTo.Store(normalPlayback)
p.advanceTo.Store(int64(normalPlayback))

// start in a paused state
p.playPause <- make(chan struct{})
Expand Down Expand Up @@ -184,7 +190,7 @@ func (p *Player) stream() {
defer cancel()

eventsC, errC := p.streamer.StreamSessionEvents(ctx, p.sessionID, 0)
lastDelay := int64(0)
var lastDelay time.Duration
for {
select {
case <-p.done:
Expand Down Expand Up @@ -216,20 +222,20 @@ func (p *Player) stream() {

currentDelay := getDelay(evt)
if currentDelay > 0 && currentDelay >= lastDelay {
switch adv := p.advanceTo.Load(); {
switch adv := time.Duration(p.advanceTo.Load()); {
case adv >= currentDelay:
// no timing delay necessary, we are fast forwarding
break
case adv < 0 && adv != normalPlayback:
// any negative value other than normalPlayback means
// we rewind (by restarting the stream and seeking forward
// to the rewind point)
p.advanceTo.Store(adv * -1)
p.advanceTo.Store(int64(adv) * -1)
go p.stream()
return
default:
if adv != normalPlayback {
p.advanceTo.Store(normalPlayback)
p.advanceTo.Store(int64(normalPlayback))

// we're catching back up to real time, so the delay
// is calculated not from the last event but from the
Expand Down Expand Up @@ -257,7 +263,7 @@ func (p *Player) stream() {
//
// TODO: consider a select with a timeout to detect blocked readers?
p.emit <- evt
p.lastPlayed.Store(currentDelay)
p.lastPlayed.Store(int64(currentDelay))
}
}
}
Expand Down Expand Up @@ -309,14 +315,14 @@ func (p *Player) SetPos(d time.Duration) error {
if d == 0 {
d = 1 * time.Millisecond
}
if d.Milliseconds() < p.lastPlayed.Load() {
if d < time.Duration(p.lastPlayed.Load()) {
d = -1 * d
}
p.advanceTo.Store(d.Milliseconds())
p.advanceTo.Store(int64(d))

// try to wake up the player if it's waiting to emit an event
select {
case p.wake <- d.Milliseconds():
case p.wake <- d:
default:
}

Expand All @@ -333,18 +339,18 @@ func (p *Player) SetPos(d time.Duration) error {
//
// A nil return value indicates that the delay has elapsed and that
// the next even can be emitted.
func (p *Player) applyDelay(lastDelay, currentDelay int64) error {
func (p *Player) applyDelay(lastDelay, currentDelay time.Duration) error {
loop:
for {
// TODO(zmb3): changing play speed during a long sleep
// will not apply until after the sleep completes
speed := p.speed.Load().(float64)
scaled := float64(currentDelay-lastDelay) / speed
scaled := time.Duration(float64(currentDelay-lastDelay) / speed)
if p.skipIdleTime {
scaled = min(scaled, 500.0*float64(time.Millisecond))
scaled = min(scaled, MaxIdleTime)
}

timer := p.clock.NewTimer(time.Duration(scaled) * time.Millisecond)
timer := p.clock.NewTimer(scaled)
defer timer.Stop()

start := time.Now()
Expand All @@ -358,7 +364,7 @@ loop:
case newPos == interruptForPause:
// the user paused playback while we were waiting to emit the next event:
// 1) figure out much of the sleep we completed
dur := float64(time.Since(start).Milliseconds()) * speed
dur := time.Duration(float64(time.Since(start)) * speed)

// 2) wait here until the user resumes playback
if err := p.waitWhilePaused(); errors.Is(err, errSeekWhilePaused) {
Expand All @@ -370,7 +376,7 @@ loop:
// now that we're playing again, update our delay to account
// for the portion that was already satisfied and apply the
// remaining delay
lastDelay += int64(dur)
lastDelay += dur
timer.Stop()
continue loop
case newPos > currentDelay:
Expand Down Expand Up @@ -455,8 +461,8 @@ func (p *Player) waitWhilePaused() error {

// LastPlayed returns the time of the last played event,
// expressed as milliseconds since the start of the session.
func (p *Player) LastPlayed() int64 {
return p.lastPlayed.Load()
func (p *Player) LastPlayed() time.Duration {
return time.Duration(p.lastPlayed.Load())
}

// translateEvent translates events if applicable and return if they should be
Expand Down Expand Up @@ -491,13 +497,13 @@ var databaseTranslators = map[string]newSessionPrintTranslatorFunc{
// player.
var SupportedDatabaseProtocols = maps.Keys(databaseTranslators)

func getDelay(e events.AuditEvent) int64 {
func getDelay(e events.AuditEvent) time.Duration {
switch x := e.(type) {
case *events.DesktopRecording:
return x.DelayMilliseconds
return time.Duration(x.DelayMilliseconds) * time.Millisecond
case *events.SessionPrint:
return x.DelayMilliseconds
return time.Duration(x.DelayMilliseconds) * time.Millisecond
default:
return int64(0)
return time.Duration(0)
}
}
31 changes: 30 additions & 1 deletion lib/player/player_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"time"

"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

apievents "github.com/gravitational/teleport/api/types/events"
Expand Down Expand Up @@ -169,7 +170,7 @@ func TestClose(t *testing.T) {
_, ok := <-p.C()
require.False(t, ok, "player channel should have been closed")
require.NoError(t, p.Err())
require.Equal(t, int64(1000), p.LastPlayed())
require.Equal(t, time.Second, p.LastPlayed())
}

func TestSeekForward(t *testing.T) {
Expand Down Expand Up @@ -321,6 +322,34 @@ func TestUseDatabaseTranslator(t *testing.T) {
})
}

func TestSkipIdlePeriods(t *testing.T) {
eventCount := 3
delayMilliseconds := 60000
clk := clockwork.NewFakeClock()
p, err := player.New(&player.Config{
Clock: clk,
SessionID: "test-session",
SkipIdleTime: true,
Streamer: &simpleStreamer{count: int64(eventCount), delay: int64(delayMilliseconds)},
})
require.NoError(t, err)
require.NoError(t, p.Play())

for i := range eventCount {
// Consume events in an eventually loop to avoid firing the clock
// events before the timer is set.
require.EventuallyWithT(t, func(t *assert.CollectT) {
clk.Advance(player.MaxIdleTime)
select {
case evt := <-p.C():
assert.Equal(t, int64(i), evt.GetIndex())
default:
assert.Fail(t, "expected to receive event after short period, but got nothing")
}
}, 3*time.Second, 100*time.Millisecond)
}
}

// simpleStreamer streams a fake session that contains
// count events, emitted at a particular interval
type simpleStreamer struct {
Expand Down

0 comments on commit 0b33c2c

Please sign in to comment.