Skip to content

Commit

Permalink
refactored to make code more self-documenting
Browse files Browse the repository at this point in the history
  • Loading branch information
myleshorton committed Nov 12, 2024
1 parent 6e02586 commit d6c3a75
Show file tree
Hide file tree
Showing 6 changed files with 154 additions and 132 deletions.
12 changes: 3 additions & 9 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,16 +173,10 @@ func NewClient(
if err != nil {
return nil, errors.New("Unable to create rewrite LRU: %v", err)
}
banditDialer, err := dialer.NewFastConnectDialer(&dialer.Options{}, func(*dialer.Options) (dialer.Dialer, error) {
return nil, nil
})
if err != nil {
return nil, errors.New("Unable to create bandit: %v", err)
}
client := &Client{
configDir: configDir,
requestTimeout: requestTimeout,
dialer: banditDialer,
dialer: dialer.New(&dialer.Options{}),
disconnected: disconnected,
proxyAll: proxyAll,
useShortcut: useShortcut,
Expand Down Expand Up @@ -715,15 +709,15 @@ func (client *Client) initDialers(proxies map[string]*commonconfig.ProxyConfig)
configDir := client.configDir
chained.PersistSessionStates(configDir)
dialers := chained.CreateDialers(configDir, proxies, client.user)
dialer, err := dialer.New(&dialer.Options{
dialer := dialer.New(&dialer.Options{
Dialers: dialers,
OnError: client.onDialError,
OnSuccess: func(dialer dialer.ProxyDialer) {
client.onSucceedingProxy()
},
StatsTracker: client.statsTracker,
})
return dialers, dialer, err
return dialers, dialer, nil
}

// Creates a local server to capture client hello messages from the browser and
Expand Down
2 changes: 1 addition & 1 deletion client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func newTestUserConfig() *common.UserConfigData {
}

func resetDialers(client *Client, dial func(network, addr string) (net.Conn, error)) {
d, _ := dialer.New(&dialer.Options{
d := dialer.New(&dialer.Options{
Dialers: []dialer.ProxyDialer{&testDialer{
name: "test-dialer",
dial: dial,
Expand Down
6 changes: 3 additions & 3 deletions dialer/bandit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func TestBanditDialer_chooseDialerForDomain(t *testing.T) {
}
}

func TestNew(t *testing.T) {
func TestNewBandit(t *testing.T) {
tests := []struct {
name string
opts *Options
Expand Down Expand Up @@ -109,7 +109,7 @@ func TestNew(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
got, err := NewBandit(tt.opts)
if (err != nil) != tt.wantErr {
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
t.Errorf("NewBandit() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.want != nil && !reflect.TypeOf(got).AssignableTo(reflect.TypeOf(tt.want)) {
Expand Down Expand Up @@ -154,7 +154,7 @@ func TestBanditDialer_DialContext(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
o, err := New(tt.opts)
o, err := NewBandit(tt.opts)
if err != nil {
t.Fatal(err)
}
Expand Down
13 changes: 10 additions & 3 deletions dialer/dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,16 @@ import (

var log = golog.LoggerFor("dialer")

func New(opts *Options) (Dialer, error) {
return NewFastConnectDialer(opts, func(opts *Options) (Dialer, error) {
return NewBandit(opts)
// New creates a new dialer that first tries to connect as quickly as possilbe while also
// optimizing for the fastest dialer.
func New(opts *Options) Dialer {
return TwoPhaseDialer(opts, func(opts *Options, existing Dialer) Dialer {
bandit, err := NewBandit(opts)
if err != nil {
log.Errorf("Unable to create bandit: %v", err)
return existing
}
return bandit
})
}

Expand Down
194 changes: 78 additions & 116 deletions dialer/fastconnect.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,30 +12,6 @@ import (
"github.com/getlantern/flashlight/v7/stats"
)

// FastConnectDialer finds a working dialer as quickly as possible.
type FastConnectDialer struct {
dialers []ProxyDialer
onError func(error, bool)
onSuccess func(ProxyDialer)
statsTracker stats.Tracker
connectTimeDialer *connectTimeDialer

activeDialer Dialer
activeDialerLock sync.RWMutex

next func(*Options) (Dialer, error)
opts *Options
}

// DialContext implements Dialer.
func (ccd *FastConnectDialer) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
td := ccd.loadActiveDialer()
if td == nil {
return nil, errors.New("no active dialer")
}
return td.DialContext(ctx, network, addr)
}

type connectTimeProxyDialer struct {
ProxyDialer

Expand All @@ -56,151 +32,137 @@ func (d dialersByConnectTime) Swap(i, j int) {
d[i], d[j] = d[j], d[i]
}

type connectTimeDialer struct {
// fastConnectDialer stores the time it took to connect to each dialer and uses
// that information to select the fastest dialer to use.
type fastConnectDialer struct {
topDialer ProxyDialer
topDialerLock sync.RWMutex
connected dialersByConnectTime
connectedChan chan int
// Lock for the slice of dialers.
connectedLock sync.RWMutex

next func(*Options, Dialer) Dialer
opts *Options

onError func(error, bool)
onSuccess func(ProxyDialer)
statsTracker stats.Tracker
}

func (ctd *connectTimeDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
func newFastConnectDialer(opts *Options, next func(opts *Options, existing Dialer) Dialer) *fastConnectDialer {
if opts.OnError == nil {
opts.OnError = func(error, bool) {}
}
if opts.OnSuccess == nil {
opts.OnSuccess = func(ProxyDialer) {}
}
if opts.StatsTracker == nil {
opts.StatsTracker = stats.NewNoop()
}
return &fastConnectDialer{
connected: make(dialersByConnectTime, 0),
connectedChan: make(chan int),
opts: opts,
next: next,
onError: opts.OnError,
onSuccess: opts.OnSuccess,
statsTracker: opts.StatsTracker,
}
}

func (fcd *fastConnectDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
// Use the dialer with the lowest connect time, waiting on early dials for any
// connections at all.
td := ctd.loadTopDialer()
td := fcd.loadTopDialer()
if td == nil {
log.Debug("No top dialer")
return nil, errors.New("no top dialer")
}
conn, up, err := td.DialContext(ctx, network, addr)
conn, failedUpstream, err := td.DialContext(ctx, network, addr)
if err != nil {
hasSucceeding := len(fcd.connected) > 0
fcd.statsTracker.SetHasSucceedingProxy(hasSucceeding)
fcd.onError(err, hasSucceeding)
// Error connecting to the proxy or to the destination
if up {
if failedUpstream {
// Error connecting to the destination
log.Debugf("Error connecting to upstream destination %v: %v", addr, err)
} else {
// Error connecting to the proxy
log.Debugf("Error connecting to proxy %v: %v", td.Name(), err)

}
return nil, err
}
fcd.statsTracker.SetHasSucceedingProxy(true)
fcd.onSuccess(td)
return conn, err

}

// Accessor for a copy of the ProxyDialer slice
func (cdt *connectTimeDialer) proxyDialers() []ProxyDialer {
cdt.connectedLock.RLock()
defer cdt.connectedLock.RUnlock()
dialers := make([]ProxyDialer, len(cdt.connected))
for i, ctd := range cdt.connected {
func (fcd *fastConnectDialer) proxyDialers() []ProxyDialer {
fcd.connectedLock.RLock()
defer fcd.connectedLock.RUnlock()

dialers := make([]ProxyDialer, len(fcd.connected))

// Note that we manually copy here vs using copy because we need an array of
// ProxyDialers, not a dialersByConnectTime.
for i, ctd := range fcd.connected {
dialers[i] = ctd.ProxyDialer
}
return dialers
}

func (ctd *connectTimeDialer) onConnected(pd ProxyDialer, connectTime time.Duration) {
func (fcd *fastConnectDialer) onConnected(pd ProxyDialer, connectTime time.Duration) {
log.Debugf("Connected to %v", pd.Name())
ctd.connectedLock.Lock()
defer ctd.connectedLock.Unlock()
fcd.connectedLock.Lock()
defer fcd.connectedLock.Unlock()

ctd.connected = append(ctd.connected, connectTimeProxyDialer{
fcd.connected = append(fcd.connected, connectTimeProxyDialer{
ProxyDialer: pd,
connectTime: connectTime,
})
sort.Sort(ctd.connected)
sort.Sort(fcd.connected)

// Set top dialer if the fastest dialer changed.
td := ctd.loadTopDialer()
newTopDialer := ctd.connected[0].ProxyDialer
td := fcd.loadTopDialer()
newTopDialer := fcd.connected[0].ProxyDialer
if td != newTopDialer {
ctd.storeTopDialer(newTopDialer)
fcd.storeTopDialer(newTopDialer)
}
log.Debug("Finished adding connected dialer")
}

func (ctd *connectTimeDialer) loadTopDialer() ProxyDialer {
ctd.topDialerLock.RLock()
defer ctd.topDialerLock.RUnlock()
return ctd.topDialer
}

func (ctd *connectTimeDialer) storeTopDialer(pd ProxyDialer) {
ctd.topDialerLock.Lock()
defer ctd.topDialerLock.Unlock()
ctd.topDialer = pd
}

// NewFastConnectDialer creates a new dialer for checking proxy connectivity.
func NewFastConnectDialer(opts *Options, next func(opts *Options) (Dialer, error)) (Dialer, error) {
if opts.OnError == nil {
opts.OnError = func(error, bool) {}
}
if opts.OnSuccess == nil {
opts.OnSuccess = func(ProxyDialer) {}
}
if opts.StatsTracker == nil {
opts.StatsTracker = stats.NewNoop()
}

log.Debugf("Creating new dialer with %d dialers", len(opts.Dialers))

ctd := &connectTimeDialer{
connected: make(dialersByConnectTime, 0),
connectedChan: make(chan int),
}
//ctd.storeTopDialer(newWaitForConnectionDialer(ctd.connectedChan))

fcd := &FastConnectDialer{
dialers: opts.Dialers,
onError: opts.OnError,
onSuccess: opts.OnSuccess,
statsTracker: opts.StatsTracker,
connectTimeDialer: ctd,
next: next,
opts: opts,
}
fcd.storeActiveDialer(ctd)

fcd.parallelDial()

return fcd, nil
}

// parallelDial dials all the dialers in parallel to connect the user as quickly as
// possible on startup.
func (fcd *FastConnectDialer) parallelDial() {
if len(fcd.dialers) == 0 {
func (fcd *fastConnectDialer) connectAll(dialers []ProxyDialer) {
if len(dialers) == 0 {
log.Errorf("No dialers to connect to")
return
}
log.Debugf("Dialing all dialers in parallel %#v", fcd.dialers)
log.Debugf("Dialing all dialers in parallel %#v", dialers)
// Loop until we're connected
for len(fcd.connectTimeDialer.connected) < 2 {
fcd.connectAll()
for len(fcd.connected) < 2 {
fcd.parallelDial(dialers)
// Add jitter to avoid thundering herd
time.Sleep(time.Duration(rand.Intn(4000)) * time.Millisecond)
}

// At this point, we've tried all of the dialers, and they've all either
// succeeded or failed.

// If we've connected to more than one dialer after trying all of them,
// switch to the next dialer that's optimized for multiple connections.
nextOpts := fcd.opts.Clone()
nextOpts.Dialers = fcd.connectTimeDialer.proxyDialers()
nextDialer, err := fcd.next(nextOpts)
if err != nil {
log.Errorf("Could not create next dialer? ", err)
} else {
log.Debug("Switching to next dialer")
fcd.storeActiveDialer(nextDialer)
}
nextOpts.Dialers = fcd.proxyDialers()
fcd.next(nextOpts, fcd)
}

func (fcd *FastConnectDialer) connectAll() {
func (fcd *fastConnectDialer) parallelDial(dialers []ProxyDialer) {
log.Debug("Connecting to all dialers")
var wg sync.WaitGroup
for index, d := range fcd.dialers {
for index, d := range dialers {
wg.Add(1)
go func(pd ProxyDialer, index int) {
defer wg.Done()
Expand All @@ -220,20 +182,20 @@ func (fcd *FastConnectDialer) connectAll() {

log.Debugf("Dialer %v succeeded in %v", pd.Name(), time.Since(start))
fcd.statsTracker.SetHasSucceedingProxy(true)
fcd.connectTimeDialer.onConnected(pd, time.Since(start))
fcd.onConnected(pd, time.Since(start))
}(d, index)
}
wg.Wait()
}

func (fcd *FastConnectDialer) storeActiveDialer(active Dialer) {
fcd.activeDialerLock.Lock()
defer fcd.activeDialerLock.Unlock()
fcd.activeDialer = active
func (fcd *fastConnectDialer) loadTopDialer() ProxyDialer {
fcd.topDialerLock.RLock()
defer fcd.topDialerLock.RUnlock()
return fcd.topDialer
}

func (fcd *FastConnectDialer) loadActiveDialer() Dialer {
fcd.activeDialerLock.RLock()
defer fcd.activeDialerLock.RUnlock()
return fcd.activeDialer
func (fcd *fastConnectDialer) storeTopDialer(pd ProxyDialer) {
fcd.topDialerLock.Lock()
defer fcd.topDialerLock.Unlock()
fcd.topDialer = pd
}
Loading

0 comments on commit d6c3a75

Please sign in to comment.