diff --git a/dialer/dialer.go b/dialer/dialer.go index 593746602..52792ae93 100644 --- a/dialer/dialer.go +++ b/dialer/dialer.go @@ -11,6 +11,7 @@ import ( "io" "net" "runtime/debug" + "sync/atomic" "time" "github.com/getlantern/golog" @@ -18,10 +19,16 @@ import ( var log = golog.LoggerFor("dialer") +var currentDialer atomic.Value + // New creates a new dialer that first tries to connect as quickly as possilbe while also // optimizing for the fastest dialer. func New(opts *Options) Dialer { - return NewTwoPhaseDialer(opts, func(opts *Options, existing Dialer) Dialer { + if currentDialer.Load() != nil { + log.Debug("Closing existing dialer") + currentDialer.Load().(Dialer).Close() + } + d := NewTwoPhaseDialer(opts, func(opts *Options, existing Dialer) Dialer { bandit, err := NewBandit(opts) if err != nil { log.Errorf("Unable to create bandit: %v", err) @@ -29,6 +36,8 @@ func New(opts *Options) Dialer { } return bandit }) + currentDialer.Store(d) + return d } // NoDialer returns a dialer that does nothing. This is useful during startup @@ -48,6 +57,8 @@ func (d *noDialer) DialContext(ctx context.Context, network, addr string) (net.C return nil, errors.New("no dialer available") } +func (d *noDialer) Close() {} + const ( // NetworkConnect is a pseudo network name to instruct the dialer to establish // a CONNECT tunnel to the proxy. @@ -90,6 +101,9 @@ type Dialer interface { // DialContext dials out to the domain or IP address representing a destination site. DialContext(ctx context.Context, network, addr string) (net.Conn, error) + + // Close closes the dialer and cleans up any resources + Close() } // hasSucceedingDialer checks whether or not any of the given dialers is able to successfully dial our proxies diff --git a/dialer/fastconnect.go b/dialer/fastconnect.go index 055d7b389..a1f20fc75 100644 --- a/dialer/fastconnect.go +++ b/dialer/fastconnect.go @@ -29,8 +29,14 @@ type fastConnectDialer struct { next func(*Options, Dialer) Dialer opts *Options + + // Create a channel for stopping connections to dialers + stopCh chan struct{} } +// Make sure fastConnectDialer implements Dialer +var _ Dialer = (*fastConnectDialer)(nil) + func newFastConnectDialer(opts *Options, next func(opts *Options, existing Dialer) Dialer) *fastConnectDialer { if opts.OnError == nil { opts.OnError = func(error, bool) {} @@ -45,6 +51,7 @@ func newFastConnectDialer(opts *Options, next func(opts *Options, existing Diale opts: opts, next: next, topDialer: protectedDialer{}, + stopCh: make(chan struct{}), } } @@ -76,6 +83,13 @@ func (fcd *fastConnectDialer) DialContext(ctx context.Context, network, addr str return conn, err } +func (fcd *fastConnectDialer) Close() { + // We don't call Stop on the Dialers themselves here because they are likely + // in use by other Dialers, such as the BanditDialer. + // Stop all dialing + fcd.stopCh <- struct{}{} +} + func (fcd *fastConnectDialer) onConnected(pd ProxyDialer, connectTime time.Duration) { log.Debugf("Connected to %v", pd.Name()) @@ -99,13 +113,23 @@ func (fcd *fastConnectDialer) connectAll(dialers []ProxyDialer) { return } log.Debugf("Dialing all dialers in parallel %#v", dialers) - // Loop until we're connected - for len(fcd.connected.dialers) < 2 { - fcd.parallelDial(dialers) - // Add jitter to avoid thundering herd - time.Sleep(time.Duration(rand.Intn(4000)) * time.Millisecond) +outerLoop: + for { + select { + case <-fcd.stopCh: + log.Debug("Stopping parallel dialing") + return + default: + // Loop until we're connected + if len(fcd.connected.dialers) < 2 { + fcd.parallelDial(dialers) + // Add jitter to avoid thundering herd + time.Sleep(time.Duration(rand.Intn(4000)) * time.Millisecond) + } else { + break outerLoop + } + } } - // At this point, we've tried all of the dialers, and they've all either // succeeded or failed. @@ -114,6 +138,7 @@ func (fcd *fastConnectDialer) connectAll(dialers []ProxyDialer) { nextOpts := fcd.opts.Clone() nextOpts.Dialers = fcd.connected.proxyDialers() fcd.next(nextOpts, fcd) + } func (fcd *fastConnectDialer) parallelDial(dialers []ProxyDialer) { diff --git a/dialer/two_phase_dialer.go b/dialer/two_phase_dialer.go index d27f0d0e1..a5acc4b6d 100644 --- a/dialer/two_phase_dialer.go +++ b/dialer/two_phase_dialer.go @@ -16,6 +16,9 @@ type twoPhaseDialer struct { activeDialer activeDialer } +// Make sure twoPhaseDialer implements Dialer +var _ Dialer = (*twoPhaseDialer)(nil) + // NewTwoPhaseDialer creates a new dialer for checking proxy connectivity. func NewTwoPhaseDialer(opts *Options, next func(opts *Options, existing Dialer) Dialer) Dialer { log.Debugf("Creating new two phase dialer with %d dialers", len(opts.Dialers)) @@ -26,6 +29,7 @@ func NewTwoPhaseDialer(opts *Options, next func(opts *Options, existing Dialer) // This is where we move to the second dialer. nextDialer := next(dialerOpts, existing) tpd.activeDialer.set(nextDialer) + existing.Close() return nextDialer }) @@ -45,6 +49,14 @@ func (ccd *twoPhaseDialer) DialContext(ctx context.Context, network string, addr return td.DialContext(ctx, network, addr) } +// Close implements Dialer. +func (ccd *twoPhaseDialer) Close() { + td := ccd.activeDialer.get() + if td != nil { + td.Close() + } +} + // protectedDialer protects a dialer.Dialer with a RWMutex. We can't use an atomic.Value here // because Dialer is an interface. type activeDialer struct {