Skip to content

Commit

Permalink
feat: load previous rewards before rewriting so we keep rewards history
Browse files Browse the repository at this point in the history
  • Loading branch information
WendelHime committed Nov 27, 2024
1 parent 928b600 commit 72a77d4
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 22 deletions.
49 changes: 28 additions & 21 deletions dialer/bandit.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ func (bd *BanditDialer) DialContext(ctx context.Context, network, addr string) (
}
}

err = bd.SaveBanditRewards(metrics)
err = bd.UpdateBanditRewards(metrics)
if err != nil {
log.Errorf("unable to save bandit weights: %v", err)
}
Expand All @@ -171,9 +171,10 @@ func (o *BanditDialer) LoadLastBanditRewards() (map[string]banditMetrics, error)

file := filepath.Join(o.opts.BanditDir, "rewards.csv")
data, err := os.Open(file)
if err != nil && os.IsNotExist(err) {
return nil, log.Errorf("unable to read bandit rewards from file: %v", err)
if err != nil {
return nil, err
}

reader := csv.NewReader(data)
_, err = reader.Read() // Skip the header
if err != nil {
Expand Down Expand Up @@ -208,42 +209,48 @@ func (o *BanditDialer) LoadLastBanditRewards() (map[string]banditMetrics, error)
return metrics, nil
}

func (o *BanditDialer) SaveBanditRewards(metrics map[string]banditMetrics) error {
func (o *BanditDialer) UpdateBanditRewards(newRewards map[string]banditMetrics) error {
if err := os.MkdirAll(o.opts.BanditDir, 0755); err != nil {
return log.Errorf("unable to create bandit directory: %v", err)
}

previousRewards, err := o.LoadLastBanditRewards()
if err != nil && !os.IsNotExist(err) {
return log.Errorf("couldn't load previous bandit rewards: %w", err)
}
o.banditRewardsMutex.Lock()
defer o.banditRewardsMutex.Unlock()

// if there are previous rewards, merge the new rewards into them, overwriting the entry for any dialer present in both
if previousRewards != nil {
for dialer, metrics := range newRewards {
previousRewards[dialer] = metrics
}
} else {
previousRewards = newRewards
}

if o.opts.BanditDir == "" {
return log.Error("bandit directory is not set")
}

if err := os.MkdirAll(o.opts.BanditDir, 0755); err != nil {
return log.Errorf("unable to create bandit directory: %v", err)
}
file := filepath.Join(o.opts.BanditDir, "rewards.csv")

headers := []string{"dialer", "reward", "count"}
writeHeaders := false
if _, err := os.Stat(file); err != nil {
if !os.IsNotExist(err) {
return log.Errorf("unable to stat bandit rewards file: %v", err)
}
writeHeaders = true
}

f, err := os.OpenFile(file, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
f, err := os.OpenFile(file, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
if err != nil {
return log.Errorf("unable to open bandit rewards file: %v", err)
}
defer f.Close()

w := csv.NewWriter(f)
defer w.Flush()
if writeHeaders {
if err = w.Write(headers); err != nil {
return log.Errorf("unable to write headers to bandit rewards file: %v", err)
}

if err = w.Write(headers); err != nil {
return log.Errorf("unable to write headers to bandit rewards file: %v", err)
}

for dialerName, metric := range metrics {
for dialerName, metric := range previousRewards {
if err = w.Write([]string{dialerName, fmt.Sprintf("%f", metric.Reward), fmt.Sprintf("%d", metric.Count)}); err != nil {
return log.Errorf("unable to write bandit rewards to file: %v", err)
}
Expand Down
2 changes: 1 addition & 1 deletion dialer/bandit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ func TestSaveBanditRewards(t *testing.T) {
BanditDir: tempDir,
},
}
err = banditDialer.SaveBanditRewards(tt.given)
err = banditDialer.UpdateBanditRewards(tt.given)
tt.assert(t, tempDir, err)
})
}
Expand Down

0 comments on commit 72a77d4

Please sign in to comment.