Skip to content

Commit

Permalink
Stop old dialers when we get a new proxy config
Browse files Browse the repository at this point in the history
  • Loading branch information
myleshorton committed Nov 20, 2024
1 parent f48a119 commit c79d88f
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 7 deletions.
16 changes: 15 additions & 1 deletion dialer/dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,33 @@ import (
"io"
"net"
"runtime/debug"
"sync/atomic"
"time"

"github.com/getlantern/golog"
)

var log = golog.LoggerFor("dialer")

var currentDialer atomic.Value

// New creates a new dialer that first tries to connect as quickly as possilbe while also
// optimizing for the fastest dialer.
func New(opts *Options) Dialer {
return NewTwoPhaseDialer(opts, func(opts *Options, existing Dialer) Dialer {
if currentDialer.Load() != nil {
log.Debug("Closing existing dialer")
currentDialer.Load().(Dialer).Close()
}
d := NewTwoPhaseDialer(opts, func(opts *Options, existing Dialer) Dialer {
bandit, err := NewBandit(opts)
if err != nil {
log.Errorf("Unable to create bandit: %v", err)
return existing
}
return bandit
})
currentDialer.Store(d)
return d
}

// NoDialer returns a dialer that does nothing. This is useful during startup
Expand All @@ -48,6 +57,8 @@ func (d *noDialer) DialContext(ctx context.Context, network, addr string) (net.C
return nil, errors.New("no dialer available")
}

func (d *noDialer) Close() {}

const (
// NetworkConnect is a pseudo network name to instruct the dialer to establish
// a CONNECT tunnel to the proxy.
Expand Down Expand Up @@ -90,6 +101,9 @@ type Dialer interface {

// DialContext dials out to the domain or IP address representing a destination site.
DialContext(ctx context.Context, network, addr string) (net.Conn, error)

// Close closes the dialer and cleans up any resources
Close()
}

// hasSucceedingDialer checks whether or not any of the given dialers is able to successfully dial our proxies
Expand Down
37 changes: 31 additions & 6 deletions dialer/fastconnect.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,14 @@ type fastConnectDialer struct {

next func(*Options, Dialer) Dialer
opts *Options

// Create a channel for stopping connections to dialers
stopCh chan struct{}
}

// Make sure fastConnectDialer implements Dialer
var _ Dialer = (*fastConnectDialer)(nil)

func newFastConnectDialer(opts *Options, next func(opts *Options, existing Dialer) Dialer) *fastConnectDialer {
if opts.OnError == nil {
opts.OnError = func(error, bool) {}
Expand All @@ -45,6 +51,7 @@ func newFastConnectDialer(opts *Options, next func(opts *Options, existing Diale
opts: opts,
next: next,
topDialer: protectedDialer{},
stopCh: make(chan struct{}),
}
}

Expand Down Expand Up @@ -76,6 +83,13 @@ func (fcd *fastConnectDialer) DialContext(ctx context.Context, network, addr str
return conn, err
}

func (fcd *fastConnectDialer) Close() {
// We don't call Stop on the Dialers themselves here because they are likely
// in use by other Dialers, such as the BanditDialer.
// Stop all dialing
fcd.stopCh <- struct{}{}
}

func (fcd *fastConnectDialer) onConnected(pd ProxyDialer, connectTime time.Duration) {
log.Debugf("Connected to %v", pd.Name())

Expand All @@ -99,13 +113,23 @@ func (fcd *fastConnectDialer) connectAll(dialers []ProxyDialer) {
return
}
log.Debugf("Dialing all dialers in parallel %#v", dialers)
// Loop until we're connected
for len(fcd.connected.dialers) < 2 {
fcd.parallelDial(dialers)
// Add jitter to avoid thundering herd
time.Sleep(time.Duration(rand.Intn(4000)) * time.Millisecond)
outerLoop:
for {
select {
case <-fcd.stopCh:
log.Debug("Stopping parallel dialing")
return
default:
// Loop until we're connected
if len(fcd.connected.dialers) < 2 {
fcd.parallelDial(dialers)
// Add jitter to avoid thundering herd
time.Sleep(time.Duration(rand.Intn(4000)) * time.Millisecond)
} else {
break outerLoop
}
}
}

// At this point, we've tried all of the dialers, and they've all either
// succeeded or failed.

Expand All @@ -114,6 +138,7 @@ func (fcd *fastConnectDialer) connectAll(dialers []ProxyDialer) {
nextOpts := fcd.opts.Clone()
nextOpts.Dialers = fcd.connected.proxyDialers()
fcd.next(nextOpts, fcd)

}

func (fcd *fastConnectDialer) parallelDial(dialers []ProxyDialer) {
Expand Down
12 changes: 12 additions & 0 deletions dialer/two_phase_dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ type twoPhaseDialer struct {
activeDialer activeDialer
}

// Make sure twoPhaseDialer implements Dialer
var _ Dialer = (*twoPhaseDialer)(nil)

// NewTwoPhaseDialer creates a new dialer for checking proxy connectivity.
func NewTwoPhaseDialer(opts *Options, next func(opts *Options, existing Dialer) Dialer) Dialer {
log.Debugf("Creating new two phase dialer with %d dialers", len(opts.Dialers))
Expand All @@ -26,6 +29,7 @@ func NewTwoPhaseDialer(opts *Options, next func(opts *Options, existing Dialer)
// This is where we move to the second dialer.
nextDialer := next(dialerOpts, existing)
tpd.activeDialer.set(nextDialer)
existing.Close()
return nextDialer
})

Expand All @@ -45,6 +49,14 @@ func (ccd *twoPhaseDialer) DialContext(ctx context.Context, network string, addr
return td.DialContext(ctx, network, addr)
}

// Close implements Dialer.
func (ccd *twoPhaseDialer) Close() {
td := ccd.activeDialer.get()
if td != nil {
td.Close()
}
}

// protectedDialer protects a dialer.Dialer with a RWMutex. We can't use an atomic.Value here
// because Dialer is an interface.
type activeDialer struct {
Expand Down

0 comments on commit c79d88f

Please sign in to comment.