Skip to content

Commit

Permalink
fix: adding suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
WendelHime committed Nov 26, 2024
1 parent 6a8d589 commit 63bc8ef
Show file tree
Hide file tree
Showing 5 changed files with 253 additions and 202 deletions.
70 changes: 1 addition & 69 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"io"
"net"
"net/http"
"os"
"path/filepath"
"regexp"
"strconv"
Expand Down Expand Up @@ -737,78 +736,11 @@ func (client *Client) initDialers(proxies map[string]*commonconfig.ProxyConfig)
)
}
},
SaveBanditRewards: saveBanditRewards(configDir),
LoadLastBanditRewards: loadLastBanditRewards(configDir),
BanditDir: filepath.Join(configDir, "bandit"),
})
return dialers, dialer, nil
}

func saveBanditRewards(dir string) func(map[string]dialer.BanditMetrics) {
return func(metrics map[string]dialer.BanditMetrics) {
dir := filepath.Join(dir, "bandit")
if err := os.MkdirAll(dir, 0755); err != nil {
log.Errorf("unable to create bandit directory: %v", err)
return
}
file := filepath.Join(dir, "rewards.csv")
csv := new(strings.Builder)
csv.WriteString("dialer,reward,count\n")
for dialerName, metric := range metrics {
csv.WriteString(fmt.Sprintf("%s,%f,%d\n", dialerName, metric.Reward, metric.Count))
}
f, err := os.Create(file)
if err != nil {
log.Errorf("unable to create bandit rewards file: %v", err)
return
}
defer f.Close()
if _, err := f.WriteString(csv.String()); err != nil {
log.Errorf("unable to write bandit rewards to file: %v", err)
}
}
}

func loadLastBanditRewards(outputDir string) func() map[string]dialer.BanditMetrics {
return func() map[string]dialer.BanditMetrics {
dir := filepath.Join(outputDir, "bandit")
file := filepath.Join(dir, "rewards.csv")
if _, err := os.Stat(file); os.IsNotExist(err) {
return nil
}
data, err := os.ReadFile(file)
if err != nil {
log.Errorf("unable to read bandit rewards from file: %v", err)
return nil
}
lines := strings.Split(string(data), "\n")
metrics := make(map[string]dialer.BanditMetrics)
for i, line := range lines {
if i == 0 {
continue
}
parts := strings.Split(line, ",")
if len(parts) != 3 {
continue
}
reward, err := strconv.ParseFloat(parts[1], 64)
if err != nil {
log.Errorf("unable to parse reward from %s: %v", parts[1], err)
continue
}
count, err := strconv.Atoi(parts[2])
if err != nil {
log.Errorf("unable to parse count from %s: %v", parts[2], err)
continue
}
metrics[parts[0]] = dialer.BanditMetrics{
Reward: reward,
Count: count,
}
}
return metrics
}
}

// Creates a local server to capture client hello messages from the browser and
// caches them.
func (client *Client) cacheClientHellos() {
Expand Down
85 changes: 0 additions & 85 deletions client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@ import (
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
"sync/atomic"
"testing"
"time"
Expand Down Expand Up @@ -424,89 +422,6 @@ func TestAccessingProxyPort(t *testing.T) {
assert.Equal(t, "0", resp.Header.Get("Content-Length"))
}

func TestSaveBanditRewards(t *testing.T) {
var tests = []struct {
name string
given map[string]dialer.BanditMetrics
assert func(t *testing.T, dir string)
}{
{
name: "it should save the rewards",
given: map[string]dialer.BanditMetrics{
"test-dialer": {
Reward: 1.0,
Count: 1,
},
},
assert: func(t *testing.T, dir string) {
f, err := os.Open(filepath.Join(dir, "bandit", "rewards.csv"))
require.NoError(t, err)
defer f.Close()
b, err := io.ReadAll(f)
require.NoError(t, err)

lines := strings.Split(string(b), "\n")
// check if headers are there
assert.Contains(t, lines[0], "dialer,reward,count")
// check if the data is there
assert.Contains(t, lines[1], "test-dialer,1.000000,1")
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
tempDir, err := os.MkdirTemp("", "client_test")
require.NoError(t, err)
defer os.RemoveAll(tempDir)

f := saveBanditRewards(tempDir)
f(tt.given)
tt.assert(t, tempDir)
})
}
}

func TestLoadLastBanditRewards(t *testing.T) {
var tests = []struct {
name string
given string
assert func(t *testing.T, metrics map[string]dialer.BanditMetrics)
}{
{
name: "it should load the rewards",
given: "dialer,reward,count\ntest-dialer,1.000000,1\n",
assert: func(t *testing.T, metrics map[string]dialer.BanditMetrics) {
assert.Contains(t, metrics, "test-dialer")
assert.Equal(t, 1.0, metrics["test-dialer"].Reward)
assert.Equal(t, 1, metrics["test-dialer"].Count)
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
tempDir, err := os.MkdirTemp("", "client_test")
require.NoError(t, err)
defer os.RemoveAll(tempDir)

if err := os.MkdirAll(filepath.Join(tempDir, "bandit"), 0755); err != nil {
log.Errorf("unable to create bandit directory: %v", err)
return
}

f, err := os.Create(filepath.Join(tempDir, "bandit", "rewards.csv"))
require.NoError(t, err)
defer f.Close()
_, err = f.WriteString(tt.given)
require.NoError(t, err)

metrics := loadLastBanditRewards(tempDir)()
tt.assert(t, metrics)
})
}
}

// Assert that a testDialer is a bandit.Dialer
var _ dialer.ProxyDialer = &testDialer{}

Expand Down
154 changes: 128 additions & 26 deletions dialer/bandit.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,14 @@ package dialer

import (
"context"
"encoding/csv"
"fmt"
"io"
"math/rand"
"net"
"os"
"path/filepath"
"strconv"
"sync/atomic"
"time"

Expand Down Expand Up @@ -36,32 +42,38 @@ func NewBandit(opts *Options) (Dialer, error) {
dialers: dialers,
opts: opts,
}
if opts.LoadLastBanditRewards != nil {
log.Debugf("Loading bandit weights from %s", opts.LoadLastBanditRewards)
dialerWeights := opts.LoadLastBanditRewards()

dialerWeights, err := dialer.LoadLastBanditRewards()
if err != nil {
log.Errorf("unable to load bandit weights: %v", err)
}
if dialerWeights != nil {
log.Debugf("Loading bandit weights from %q", opts.BanditDir)
counts := make([]int, len(dialers))
rewards := make([]float64, len(dialers))
for arm, dialer := range dialers {
if banditMetrics, ok := dialerWeights[dialer.Name()]; ok {
rewards[arm] = banditMetrics.Reward
counts[arm] = banditMetrics.Count
if metrics, ok := dialerWeights[dialer.Name()]; ok {
rewards[arm] = metrics.Reward
counts[arm] = metrics.Count
}
}
b, err = bandit.NewEpsilonGreedy(0.1, counts, rewards)
if err != nil {
log.Errorf("unable to create weighted bandit: %w", err)
return nil, err
}
} else {
b, err = bandit.NewEpsilonGreedy(0.1, nil, nil)
if err != nil {
log.Errorf("unable to create bandit: %v", err)
return nil, err
}
if err := b.Init(len(dialers)); err != nil {
log.Errorf("unable to initialize bandit: %v", err)
return nil, err
}
dialer.bandit = b
return dialer, nil
}

b, err = bandit.NewEpsilonGreedy(0.1, nil, nil)
if err != nil {
log.Errorf("unable to create bandit: %v", err)
return nil, err
}
if err := b.Init(len(dialers)); err != nil {
log.Errorf("unable to initialize bandit: %v", err)
return nil, err
}
dialer.bandit = b

Expand Down Expand Up @@ -124,25 +136,115 @@ func (bd *BanditDialer) DialContext(ctx context.Context, network, addr string) (

time.AfterFunc(30*time.Second, func() {
// Save the bandit weights
if bd.opts.SaveBanditRewards != nil {
metrics := make(map[string]BanditMetrics)
rewards := bd.bandit.GetRewards()
counts := bd.bandit.GetCounts()
for i, d := range bd.dialers {
metrics[d.Name()] = BanditMetrics{
Reward: rewards[i],
Count: counts[i],
}
metrics := make(map[string]banditMetrics)
rewards := bd.bandit.GetRewards()
counts := bd.bandit.GetCounts()
for i, d := range bd.dialers {
metrics[d.Name()] = banditMetrics{
Reward: rewards[i],
Count: counts[i],
}
}

bd.opts.SaveBanditRewards(metrics)
err = bd.SaveBanditRewards(metrics)
if err != nil {
log.Errorf("unable to save bandit weights: %v", err)
}
})

bd.opts.OnSuccess(d)
return dt, err
}

// LoadLastBanditRewards is a function that returns the last bandit rewards
// for each dialer. If this is set, the bandit will be initialized with the
// last metrics.
func (o *BanditDialer) LoadLastBanditRewards() (map[string]banditMetrics, error) {
if o.opts.BanditDir == "" {
return nil, nil
}

file := filepath.Join(o.opts.BanditDir, "rewards.csv")
data, err := os.Open(file)
if err != nil && os.IsNotExist(err) {
return nil, log.Errorf("unable to read bandit rewards from file: %v", err)
}
reader := csv.NewReader(data)
_, err = reader.Read() // Skip the header
if err != nil {
return nil, log.Errorf("unable to skip headers from bandit rewards csv: %v", err)
}
metrics := make(map[string]banditMetrics)
for {
line, err := reader.Read()
if err == io.EOF {
break
}
if err != nil {
return nil, log.Errorf("unable to read line from bandit rewards csv: %v", err)
}

if len(line) != 3 {
return nil, log.Errorf("invalid line in bandit rewards csv: %v", line)
}
reward, err := strconv.ParseFloat(line[1], 64)
if err != nil {
return nil, log.Errorf("unable to parse reward from %s: %v", line[0], err)
}
count, err := strconv.Atoi(line[2])
if err != nil {
return nil, log.Errorf("unable to parse count from %s: %v", line[0], err)
}
metrics[line[0]] = banditMetrics{
Reward: reward,
Count: count,
}
}
return metrics, nil
}

func (o *BanditDialer) SaveBanditRewards(metrics map[string]banditMetrics) error {
if o.opts.BanditDir == "" {
return log.Error("bandit directory is not set")
}

if err := os.MkdirAll(o.opts.BanditDir, 0755); err != nil {
return log.Errorf("unable to create bandit directory: %v", err)
}
file := filepath.Join(o.opts.BanditDir, "rewards.csv")

headers := []string{"dialer", "reward", "count"}
writeHeaders := false
if _, err := os.Stat(file); err != nil {
if !os.IsNotExist(err) {
return log.Errorf("unable to stat bandit rewards file: %v", err)
}
writeHeaders = true
}

f, err := os.OpenFile(file, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
if err != nil {
return log.Errorf("unable to open bandit rewards file: %v", err)
}
defer f.Close()

w := csv.NewWriter(f)
defer w.Flush()
if writeHeaders {
if err = w.Write(headers); err != nil {
return log.Errorf("unable to write headers to bandit rewards file: %v", err)
}
}

for dialerName, metric := range metrics {
if err = w.Write([]string{dialerName, fmt.Sprintf("%f", metric.Reward), fmt.Sprintf("%d", metric.Count)}); err != nil {
return log.Errorf("unable to write bandit rewards to file: %v", err)
}
}

return nil
}

func (o *BanditDialer) chooseDialerForDomain(network, addr string) (ProxyDialer, int) {
// Loop through the number of dialers we have and select the one that is best
// for the given domain.
Expand Down
Loading

0 comments on commit 63bc8ef

Please sign in to comment.