diff --git a/client/client.go b/client/client.go
index f6e024c64..b373da7a0 100644
--- a/client/client.go
+++ b/client/client.go
@@ -7,7 +7,6 @@ import (
 	"io"
 	"net"
 	"net/http"
-	"os"
 	"path/filepath"
 	"regexp"
 	"strconv"
@@ -737,78 +736,11 @@ func (client *Client) initDialers(proxies map[string]*commonconfig.ProxyConfig)
 				)
 			}
 		},
-		SaveBanditRewards:     saveBanditRewards(configDir),
-		LoadLastBanditRewards: loadLastBanditRewards(configDir),
+		BanditDir: filepath.Join(configDir, "bandit"),
 	})
 	return dialers, dialer, nil
 }
 
-func saveBanditRewards(dir string) func(map[string]dialer.BanditMetrics) {
-	return func(metrics map[string]dialer.BanditMetrics) {
-		dir := filepath.Join(dir, "bandit")
-		if err := os.MkdirAll(dir, 0755); err != nil {
-			log.Errorf("unable to create bandit directory: %v", err)
-			return
-		}
-		file := filepath.Join(dir, "rewards.csv")
-		csv := new(strings.Builder)
-		csv.WriteString("dialer,reward,count\n")
-		for dialerName, metric := range metrics {
-			csv.WriteString(fmt.Sprintf("%s,%f,%d\n", dialerName, metric.Reward, metric.Count))
-		}
-		f, err := os.Create(file)
-		if err != nil {
-			log.Errorf("unable to create bandit rewards file: %v", err)
-			return
-		}
-		defer f.Close()
-		if _, err := f.WriteString(csv.String()); err != nil {
-			log.Errorf("unable to write bandit rewards to file: %v", err)
-		}
-	}
-}
-
-func loadLastBanditRewards(outputDir string) func() map[string]dialer.BanditMetrics {
-	return func() map[string]dialer.BanditMetrics {
-		dir := filepath.Join(outputDir, "bandit")
-		file := filepath.Join(dir, "rewards.csv")
-		if _, err := os.Stat(file); os.IsNotExist(err) {
-			return nil
-		}
-		data, err := os.ReadFile(file)
-		if err != nil {
-			log.Errorf("unable to read bandit rewards from file: %v", err)
-			return nil
-		}
-		lines := strings.Split(string(data), "\n")
-		metrics := make(map[string]dialer.BanditMetrics)
-		for i, line := range lines {
-			if i == 0 {
-				continue
-			}
-			parts := strings.Split(line, ",")
-			if len(parts) != 3 {
-				continue
-			}
-			reward, err := strconv.ParseFloat(parts[1], 64)
-			if err != nil {
-				log.Errorf("unable to parse reward from %s: %v", parts[1], err)
-				continue
-			}
-			count, err := strconv.Atoi(parts[2])
-			if err != nil {
-				log.Errorf("unable to parse count from %s: %v", parts[2], err)
-				continue
-			}
-			metrics[parts[0]] = dialer.BanditMetrics{
-				Reward: reward,
-				Count:  count,
-			}
-		}
-		return metrics
-	}
-}
-
 // Creates a local server to capture client hello messages from the browser and
 // caches them.
 func (client *Client) cacheClientHellos() {
diff --git a/client/client_test.go b/client/client_test.go
index ef58f41a8..b2013bac2 100644
--- a/client/client_test.go
+++ b/client/client_test.go
@@ -11,9 +11,7 @@ import (
 	"net/http/httptest"
 	"net/url"
 	"os"
-	"path/filepath"
 	"strconv"
-	"strings"
 	"sync/atomic"
 	"testing"
 	"time"
@@ -424,89 +422,6 @@ func TestAccessingProxyPort(t *testing.T) {
 	assert.Equal(t, "0", resp.Header.Get("Content-Length"))
 }
 
-func TestSaveBanditRewards(t *testing.T) {
-	var tests = []struct {
-		name   string
-		given  map[string]dialer.BanditMetrics
-		assert func(t *testing.T, dir string)
-	}{
-		{
-			name: "it should save the rewards",
-			given: map[string]dialer.BanditMetrics{
-				"test-dialer": {
-					Reward: 1.0,
-					Count:  1,
-				},
-			},
-			assert: func(t *testing.T, dir string) {
-				f, err := os.Open(filepath.Join(dir, "bandit", "rewards.csv"))
-				require.NoError(t, err)
-				defer f.Close()
-				b, err := io.ReadAll(f)
-				require.NoError(t, err)
-
-				lines := strings.Split(string(b), "\n")
-				// check if headers are there
-				assert.Contains(t, lines[0], "dialer,reward,count")
-				// check if the data is there
-				assert.Contains(t, lines[1], "test-dialer,1.000000,1")
-			},
-		},
-	}
-	for _, tt := range tests {
-		tt := tt
-		t.Run(tt.name, func(t *testing.T) {
-			tempDir, err := os.MkdirTemp("", "client_test")
-			require.NoError(t, err)
-			defer os.RemoveAll(tempDir)
-
-			f := saveBanditRewards(tempDir)
-			f(tt.given)
-			tt.assert(t, tempDir)
-		})
-	}
-}
-
-func TestLoadLastBanditRewards(t *testing.T) {
-	var tests = []struct {
-		name   string
-		given  string
-		assert func(t *testing.T, metrics map[string]dialer.BanditMetrics)
-	}{
-		{
-			name:  "it should load the rewards",
-			given: "dialer,reward,count\ntest-dialer,1.000000,1\n",
-			assert: func(t *testing.T, metrics map[string]dialer.BanditMetrics) {
-				assert.Contains(t, metrics, "test-dialer")
-				assert.Equal(t, 1.0, metrics["test-dialer"].Reward)
-				assert.Equal(t, 1, metrics["test-dialer"].Count)
-			},
-		},
-	}
-	for _, tt := range tests {
-		tt := tt
-		t.Run(tt.name, func(t *testing.T) {
-			tempDir, err := os.MkdirTemp("", "client_test")
-			require.NoError(t, err)
-			defer os.RemoveAll(tempDir)
-
-			if err := os.MkdirAll(filepath.Join(tempDir, "bandit"), 0755); err != nil {
-				log.Errorf("unable to create bandit directory: %v", err)
-				return
-			}
-
-			f, err := os.Create(filepath.Join(tempDir, "bandit", "rewards.csv"))
-			require.NoError(t, err)
-			defer f.Close()
-			_, err = f.WriteString(tt.given)
-			require.NoError(t, err)
-
-			metrics := loadLastBanditRewards(tempDir)()
-			tt.assert(t, metrics)
-		})
-	}
-}
-
 // Assert that a testDialer is a bandit.Dialer
 var _ dialer.ProxyDialer = &testDialer{}
 
diff --git a/dialer/bandit.go b/dialer/bandit.go
index a562ae73b..54f12727d 100644
--- a/dialer/bandit.go
+++ b/dialer/bandit.go
@@ -2,8 +2,14 @@ package dialer
 
 import (
 	"context"
+	"encoding/csv"
+	"fmt"
+	"io"
 	"math/rand"
 	"net"
+	"os"
+	"path/filepath"
+	"strconv"
 	"sync/atomic"
 	"time"
 
@@ -36,15 +42,19 @@ func NewBandit(opts *Options) (Dialer, error) {
 		dialers: dialers,
 		opts:    opts,
 	}
-	if opts.LoadLastBanditRewards != nil {
-		log.Debugf("Loading bandit weights from %s", opts.LoadLastBanditRewards)
-		dialerWeights := opts.LoadLastBanditRewards()
+
+	dialerWeights, err := dialer.LoadLastBanditRewards()
+	if err != nil {
+		log.Errorf("unable to load bandit weights: %v", err)
+	}
+	if dialerWeights != nil {
+		log.Debugf("Loading bandit weights from %q", opts.BanditDir)
 		counts := make([]int, len(dialers))
 		rewards := make([]float64, len(dialers))
 		for arm, dialer := range dialers {
-			if banditMetrics, ok := dialerWeights[dialer.Name()]; ok {
-				rewards[arm] = banditMetrics.Reward
-				counts[arm] = banditMetrics.Count
+			if metrics, ok := dialerWeights[dialer.Name()]; ok {
+				rewards[arm] = metrics.Reward
+				counts[arm] = metrics.Count
 			}
 		}
 		b, err = bandit.NewEpsilonGreedy(0.1, counts, rewards)
@@ -52,16 +62,18 @@ func NewBandit(opts *Options) (Dialer, error) {
 			log.Errorf("unable to create weighted bandit: %w", err)
 			return nil, err
 		}
-	} else {
-		b, err = bandit.NewEpsilonGreedy(0.1, nil, nil)
-		if err != nil {
-			log.Errorf("unable to create bandit: %v", err)
-			return nil, err
-		}
-		if err := b.Init(len(dialers)); err != nil {
-			log.Errorf("unable to initialize bandit: %v", err)
-			return nil, err
-		}
+		dialer.bandit = b
+		return dialer, nil
+	}
+
+	b, err = bandit.NewEpsilonGreedy(0.1, nil, nil)
+	if err != nil {
+		log.Errorf("unable to create bandit: %v", err)
+		return nil, err
+	}
+	if err := b.Init(len(dialers)); err != nil {
+		log.Errorf("unable to initialize bandit: %v", err)
+		return nil, err
 	}
 
 	dialer.bandit = b
@@ -124,18 +136,19 @@ func (bd *BanditDialer) DialContext(ctx context.Context, network, addr string) (
 
 	time.AfterFunc(30*time.Second, func() {
 		// Save the bandit weights
-		if bd.opts.SaveBanditRewards != nil {
-			metrics := make(map[string]BanditMetrics)
-			rewards := bd.bandit.GetRewards()
-			counts := bd.bandit.GetCounts()
-			for i, d := range bd.dialers {
-				metrics[d.Name()] = BanditMetrics{
-					Reward: rewards[i],
-					Count:  counts[i],
-				}
+		metrics := make(map[string]banditMetrics)
+		rewards := bd.bandit.GetRewards()
+		counts := bd.bandit.GetCounts()
+		for i, d := range bd.dialers {
+			metrics[d.Name()] = banditMetrics{
+				Reward: rewards[i],
+				Count:  counts[i],
 			}
+		}
 
-			bd.opts.SaveBanditRewards(metrics)
+		// Keep err local to the timer goroutine instead of writing to DialContext's captured err.
+		if err := bd.SaveBanditRewards(metrics); err != nil {
+			log.Errorf("unable to save bandit weights: %v", err)
 		}
 	})
 
@@ -143,6 +156,101 @@ func (bd *BanditDialer) DialContext(ctx context.Context, network, addr string) (
 	return dt, err
 }
 
+// LoadLastBanditRewards reads the rewards and counts persisted by
+// SaveBanditRewards from rewards.csv in opts.BanditDir, keyed by dialer name.
+// It returns nil, nil when BanditDir is unset or no rewards file exists yet,
+// and an error when the file cannot be opened or parsed.
+func (o *BanditDialer) LoadLastBanditRewards() (map[string]banditMetrics, error) {
+	if o.opts.BanditDir == "" {
+		return nil, nil
+	}
+
+	file := filepath.Join(o.opts.BanditDir, "rewards.csv")
+	data, err := os.Open(file)
+	if err != nil {
+		if os.IsNotExist(err) {
+			// Nothing has been saved yet (e.g. first run): not an error.
+			return nil, nil
+		}
+		return nil, log.Errorf("unable to read bandit rewards from file: %v", err)
+	}
+	defer data.Close()
+
+	reader := csv.NewReader(data)
+	_, err = reader.Read() // Skip the header
+	if err != nil {
+		return nil, log.Errorf("unable to skip headers from bandit rewards csv: %v", err)
+	}
+	metrics := make(map[string]banditMetrics)
+	for {
+		line, err := reader.Read()
+		if err == io.EOF {
+			break
+		}
+		if err != nil {
+			return nil, log.Errorf("unable to read line from bandit rewards csv: %v", err)
+		}
+
+		if len(line) != 3 {
+			return nil, log.Errorf("invalid line in bandit rewards csv: %v", line)
+		}
+		reward, err := strconv.ParseFloat(line[1], 64)
+		if err != nil {
+			return nil, log.Errorf("unable to parse reward from %s: %v", line[0], err)
+		}
+		count, err := strconv.Atoi(line[2])
+		if err != nil {
+			return nil, log.Errorf("unable to parse count from %s: %v", line[0], err)
+		}
+		metrics[line[0]] = banditMetrics{
+			Reward: reward,
+			Count:  count,
+		}
+	}
+	return metrics, nil
+}
+
+// SaveBanditRewards writes the given per-dialer rewards and counts to
+// rewards.csv in opts.BanditDir, replacing any previous contents. The file
+// always starts with a header row, which LoadLastBanditRewards skips.
+func (o *BanditDialer) SaveBanditRewards(metrics map[string]banditMetrics) error {
+	if o.opts.BanditDir == "" {
+		return log.Error("bandit directory is not set")
+	}
+
+	if err := os.MkdirAll(o.opts.BanditDir, 0755); err != nil {
+		return log.Errorf("unable to create bandit directory: %v", err)
+	}
+	file := filepath.Join(o.opts.BanditDir, "rewards.csv")
+
+	f, err := os.OpenFile(file, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
+	if err != nil {
+		return log.Errorf("unable to open bandit rewards file: %v", err)
+	}
+	defer f.Close()
+
+	// The file is truncated on every save, so the header must be rewritten
+	// each time; otherwise the loader would discard the first data row.
+	w := csv.NewWriter(f)
+	if err = w.Write([]string{"dialer", "reward", "count"}); err != nil {
+		return log.Errorf("unable to write headers to bandit rewards file: %v", err)
+	}
+
+	for dialerName, metric := range metrics {
+		if err = w.Write([]string{dialerName, fmt.Sprintf("%f", metric.Reward), fmt.Sprintf("%d", metric.Count)}); err != nil {
+			return log.Errorf("unable to write bandit rewards to file: %v", err)
+		}
+	}
+
+	// csv.Writer buffers; Flush surfaces write failures only through Error.
+	w.Flush()
+	if err = w.Error(); err != nil {
+		return log.Errorf("unable to write bandit rewards to file: %v", err)
+	}
+
+	return nil
+}
+
 func (o *BanditDialer) chooseDialerForDomain(network, addr string) (ProxyDialer, int) {
 	// Loop through the number of dialers we have and select the one that is best
 	// for the given domain.
diff --git a/dialer/bandit_test.go b/dialer/bandit_test.go
index 8413055eb..2e054ab23 100644
--- a/dialer/bandit_test.go
+++ b/dialer/bandit_test.go
@@ -2,10 +2,14 @@ package dialer
 
 import (
 	"context"
+	"fmt"
 	"io"
 	"math/rand"
 	"net"
+	"os"
+	"path/filepath"
 	"reflect"
+	"strings"
 	"testing"
 	"time"
 
@@ -81,17 +85,22 @@ func TestBanditDialer_chooseDialerForDomain(t *testing.T) {
 
 func TestNewBandit(t *testing.T) {
 	oldDialer := newTcpConnDialer()
+	oldDialerMetric := banditMetrics{
+		Reward: 0.7,
+		Count:  10,
+	}
 	tests := []struct {
 		name   string
 		opts   *Options
-		assert func(t *testing.T, got Dialer, err error)
+		assert func(t *testing.T, got Dialer, err error, dir string)
+		setup  func() string
 	}{
 		{
 			name: "should fail if there are no dialers",
 			opts: &Options{
 				Dialers: nil,
 			},
-			assert: func(t *testing.T, got Dialer, err error) {
+			assert: func(t *testing.T, got Dialer, err error, _ string) {
 				assert.Nil(t, got)
 				assert.Error(t, err)
 			},
@@ -101,7 +110,7 @@ func TestNewBandit(t *testing.T) {
 			opts: &Options{
 				Dialers: []ProxyDialer{newTcpConnDialer()},
 			},
-			assert: func(t *testing.T, got Dialer, err error) {
+			assert: func(t *testing.T, got Dialer, err error, _ string) {
 				assert.NotNil(t, got)
 				assert.NoError(t, err)
 				assert.IsType(t, &BanditDialer{}, got)
@@ -111,23 +120,26 @@ func TestNewBandit(t *testing.T) {
 			name: "should load the last bandit rewards if they exist",
 			opts: &Options{
 				Dialers: []ProxyDialer{oldDialer, newTcpConnDialer()},
-				LoadLastBanditRewards: func() map[string]BanditMetrics {
-					return map[string]BanditMetrics{
-						oldDialer.Name(): {
-							Reward: 0.7,
-							Count:  10,
-						},
-					}
-				},
 			},
-			assert: func(t *testing.T, got Dialer, err error) {
+			setup: func() string {
+				tempDir, err := os.MkdirTemp("", "client_test")
+				require.NoError(t, err)
+
+				// create rewards.csv
+				err = os.WriteFile(filepath.Join(tempDir, "rewards.csv"), []byte(fmt.Sprintf("dialer,reward,count\n%s,%f,%d\n", oldDialer.Name(), oldDialerMetric.Reward, oldDialerMetric.Count)), 0644)
+				require.NoError(t, err)
+				return tempDir
+			},
+			assert: func(t *testing.T, got Dialer, err error, dir string) {
 				assert.NotNil(t, got)
 				assert.NoError(t, err)
 				assert.IsType(t, &BanditDialer{}, got)
 				rewards := got.(*BanditDialer).bandit.GetRewards()
 				counts := got.(*BanditDialer).bandit.GetCounts()
-				assert.Equal(t, 0.7, rewards[0])
-				assert.Equal(t, 10, counts[0])
+				// checking if the rewards are loaded correctly
+				assert.Equal(t, oldDialerMetric.Reward, rewards[0])
+				assert.Equal(t, oldDialerMetric.Count, counts[0])
+				// since there's no data for the second dialer, it should be 0
 				assert.Equal(t, float64(0), rewards[1])
 				assert.Equal(t, 0, counts[1])
 			},
@@ -135,8 +147,14 @@ func TestNewBandit(t *testing.T) {
 	}
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
+			dir := ""
+			if tt.setup != nil {
+				dir = tt.setup()
+				defer os.RemoveAll(dir)
+				tt.opts.BanditDir = dir
+			}
 			got, err := NewBandit(tt.opts)
-			tt.assert(t, got, err)
+			tt.assert(t, got, err, dir)
 		})
 	}
 }
@@ -306,6 +324,95 @@ func Test_differentArm(t *testing.T) {
 	}
 }
 
+func TestSaveBanditRewards(t *testing.T) {
+	var tests = []struct {
+		name   string
+		given  map[string]banditMetrics
+		assert func(t *testing.T, dir string, err error)
+	}{
+		{
+			name: "it should save the rewards",
+			given: map[string]banditMetrics{
+				"test-dialer": {
+					Reward: 1.0,
+					Count:  1,
+				},
+			},
+			assert: func(t *testing.T, dir string, err error) {
+				assert.NoError(t, err)
+				f, err := os.Open(filepath.Join(dir, "rewards.csv"))
+				require.NoError(t, err)
+				defer f.Close()
+				b, err := io.ReadAll(f)
+				require.NoError(t, err)
+
+				lines := strings.Split(string(b), "\n")
+				// check if headers are there
+				assert.Contains(t, lines[0], "dialer,reward,count")
+				// check if the data is there
+				assert.Contains(t, lines[1], "test-dialer,1.000000,1")
+			},
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			tempDir, err := os.MkdirTemp("", "client_test")
+			require.NoError(t, err)
+			defer os.RemoveAll(tempDir)
+
+			banditDialer := &BanditDialer{
+				opts: &Options{
+					BanditDir: tempDir,
+				},
+			}
+			// Save twice: overwriting an existing file must keep the header row.
+			require.NoError(t, banditDialer.SaveBanditRewards(tt.given))
+			err = banditDialer.SaveBanditRewards(tt.given)
+			tt.assert(t, tempDir, err)
+		})
+	}
+}
+
+func TestLoadLastBanditRewards(t *testing.T) {
+	var tests = []struct {
+		name   string
+		given  string
+		assert func(t *testing.T, metrics map[string]banditMetrics, err error)
+	}{
+		{
+			name:  "it should load the rewards",
+			given: "dialer,reward,count\ntest-dialer,1.000000,1\n",
+			assert: func(t *testing.T, metrics map[string]banditMetrics, err error) {
+				assert.NoError(t, err)
+				assert.Contains(t, metrics, "test-dialer")
+				assert.Equal(t, 1.0, metrics["test-dialer"].Reward)
+				assert.Equal(t, 1, metrics["test-dialer"].Count)
+			},
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			tempDir, err := os.MkdirTemp("", "client_test")
+			require.NoError(t, err)
+			defer os.RemoveAll(tempDir)
+
+			f, err := os.Create(filepath.Join(tempDir, "rewards.csv"))
+			require.NoError(t, err)
+			defer f.Close()
+			_, err = f.WriteString(tt.given)
+			require.NoError(t, err)
+
+			banditDialer := &BanditDialer{
+				opts: &Options{
+					BanditDir: tempDir,
+				},
+			}
+			metrics, err := banditDialer.LoadLastBanditRewards()
+			tt.assert(t, metrics, err)
+		})
+	}
+}
+
 func newTcpConnDialer() ProxyDialer {
 	client, server := net.Pipe()
 	return &tcpConnDialer{
diff --git a/dialer/dialer.go b/dialer/dialer.go
index 8ea6f4299..8ef1dee83 100644
--- a/dialer/dialer.go
+++ b/dialer/dialer.go
@@ -83,15 +83,12 @@ type Options struct {
 	// OnSuccess is the callback that is called by dialer after a successful dial.
 	OnSuccess func(ProxyDialer)
 
-	// LoadLastBanditRewards is a function that returns the last bandit rewards
-	// for each dialer. If this is set, the bandit will be initialized with the
-	// last metrics.
-	LoadLastBanditRewards func() map[string]BanditMetrics
-
-	SaveBanditRewards func(map[string]BanditMetrics)
+	// BanditDir is the directory where the bandit persists its rewards
+	// (rewards.csv). When empty, no previous rewards are loaded.
+	BanditDir string
 }
 
-type BanditMetrics struct {
+type banditMetrics struct {
 	Reward float64
 	Count  int
 }