diff --git a/bypass/bypass.go b/bypass/bypass.go index e8af3d5c9..d78e6d3d6 100644 --- a/bypass/bypass.go +++ b/bypass/bypass.go @@ -58,7 +58,6 @@ type bypass struct { // Start sends periodic traffic to the bypass server. The client periodically sends traffic to the server both via // domain fronting and proxying to determine if proxies are blocked. func Start(listen func(func(map[string]*commonconfig.ProxyConfig, config.Source)), configDir string, userConfig common.UserConfig) func() { - mrand.Seed(time.Now().UnixNano()) b := &bypass{ infos: make(map[string]*commonconfig.ProxyConfig), proxies: make([]*proxy, 0), @@ -173,14 +172,14 @@ func (p *proxy) sendToBypass() int64 { } defer func() { if resp.Body != nil { - if closeerr := resp.Body.Close(); closeerr != nil { - log.Errorf("Error closing response body: %v", closeerr) + if _, err := io.Copy(io.Discard, resp.Body); err != nil { + log.Errorf("Error reading response body: %v", err) + } + if err := resp.Body.Close(); err != nil { + log.Errorf("Error closing response body: %v", err) } } }() - if resp.Body != nil { - io.Copy(io.Discard, resp.Body) - } var sleepTime int64 sleepVal := resp.Header.Get(common.SleepHeader) diff --git a/client/client.go b/client/client.go index babb9b962..7a00cb92e 100644 --- a/client/client.go +++ b/client/client.go @@ -178,6 +178,8 @@ func NewClient( configDir: configDir, requestTimeout: requestTimeout, dialer: &protectedDialer{ + // This is just a placeholder dialer until we're able to fetch the + // actual proxy dialers from the config. dialer: dialer.NoDialer(), }, disconnected: disconnected, @@ -291,7 +293,12 @@ func (client *Client) ListenAndServeHTTP(requestedAddr string, onListeningFn fun } return fmt.Errorf("unable to accept connection: %v", err) } - go client.handle(conn) + go func(conn net.Conn) { + err := client.handle(conn) + if err != nil { + log.Errorf("Error handling connection: %v", err) + } + }(conn) } } @@ -396,6 +403,11 @@ func (client *Client) Stop() error { var TimeoutWaitingForDNSResolutionMap = 5 * time.Second func (client *Client) dial(ctx context.Context, isConnect bool, network, addr string) (conn net.Conn, err error) { + op := ops.Begin("proxied_dialer") + op.Set("local_proxy_type", "http") + op.OriginPort(addr, "") + defer op.End() + // Fetch DNS resolution map, if any // XXX <01-04-2022, soltzen> Do this fetch now, so it won't be affected by // the context timeout of client.doDial() @@ -413,7 +425,7 @@ func (client *Client) dial(ctx context.Context, isConnect bool, network, addr st ctx2, cancel2 := context.WithTimeout(ctx, client.requestTimeout) defer cancel2() - return client.doDial(ctx2, isConnect, addr, dnsResolutionMapForDirectDials) + return client.doDial(op, ctx2, isConnect, addr, dnsResolutionMapForDirectDials) } // doDial is the ultimate place to dial an origin site. It takes following steps: @@ -422,19 +434,24 @@ func (client *Client) dial(ctx context.Context, isConnect bool, network, addr st // * If the host or port is configured not proxyable, dial directly. // * If the site is allowed by shortcut, dial directly. If it failed before the deadline, try proxying. // * Try dial the site directly with 1/5th of the requestTimeout, then try proxying. -func (client *Client) doDial(ctx context.Context, isCONNECT bool, addr string, +func (client *Client) doDial(op *ops.Op, ctx context.Context, isCONNECT bool, addr string, dnsResolutionMapForDirectDials map[string]string) (net.Conn, error) { dialDirect := func(ctx context.Context, network, addr string) (net.Conn, error) { if v, ok := dnsResolutionMapForDirectDials[addr]; ok { log.Debugf("Bypassed DNS resolution: dialing %v as %v", addr, v) - return netx.DialContext(ctx, network, v) + conn, err := netx.DialContext(ctx, network, v) + op.FailIf(err) + return conn, err } else { - return netx.DialContext(ctx, network, addr) + conn, err := netx.DialContext(ctx, network, addr) + op.FailIf(err) + return conn, err } } dialProxied := func(ctx context.Context, _unused, addr string) (net.Conn, error) { + op.Set("remotely_proxied", true) proto := dialer.NetworkPersistent if isCONNECT { // UGLY HACK ALERT! In this case, we know we need to send a CONNECT request @@ -465,40 +482,57 @@ func (client *Client) doDial(ctx context.Context, isCONNECT bool, addr string, if routingRuleForDomain == domainrouting.MustDirect { log.Debugf("Forcing direct to %v per domain routing rules (MustDirect)", host) + op.Set("force_direct", true) + op.Set("force_direct_reason", "routingrule") return dialDirect(ctx, "tcp", addr) } if shouldForceProxying() { log.Tracef("Proxying to %v because everything is forced to be proxied", addr) + op.Set("force_proxied", true) + op.Set("force_proxied_reason", "forceproxying") return dialProxied(ctx, "whatever", addr) } if routingRuleForDomain == domainrouting.MustProxy { log.Tracef("Proxying to %v per domain routing rules (MustProxy)", addr) + op.Set("force_proxied", true) + op.Set("force_proxied_reason", "routingrule") return dialProxied(ctx, "whatever", addr) } if err := client.allowSendingToProxy(addr); err != nil { log.Debugf("%v, sending directly to %v", err, addr) + op.Set("force_direct", true) + op.Set("force_direct_reason", err.Error()) return dialDirect(ctx, "tcp", addr) } if client.proxyAll() { log.Tracef("Proxying to %v because proxyall is enabled", addr) + op.Set("force_proxied", true) + op.Set("force_proxied_reason", "proxyall") return dialProxied(ctx, "whatever", addr) } dialDirectForShortcut := func(ctx context.Context, network, addr string, ip net.IP) (net.Conn, error) { log.Debugf("Use shortcut (dial directly) for %v(%v)", addr, ip) + op.Set("shortcut_direct", true) + op.Set("shortcut_direct_ip", ip) + op.Set("shortcut_origin", addr) return dialDirect(ctx, "tcp", addr) } switch domainrouting.RuleFor(host) { case domainrouting.Direct: log.Tracef("Directly dialing %v per domain routing rules (Direct)", addr) + op.Set("force_direct", true) + op.Set("force_direct_reason", "routingrule") return dialDirect(ctx, "tcp", addr) case domainrouting.Proxy: log.Tracef("Proxying to %v per domain routing rules (Proxy)", addr) + op.Set("force_proxied", true) + op.Set("force_proxied_reason", "routingrule") return dialProxied(ctx, "whatever", addr) } @@ -529,6 +563,7 @@ func (client *Client) doDial(ctx context.Context, isCONNECT bool, addr string, var dialer func(ctx context.Context, network, addr string) (net.Conn, error) if client.useDetour() { + op.Set("detour", true) dialer = detour.Dialer(dialDirectForDetour, dialProxied) } else if !client.useShortcut() { dialer = func(ctx context.Context, network, addr string) (net.Conn, error) { diff --git a/client/handler.go b/client/handler.go index 103a6e335..1a24e0284 100644 --- a/client/handler.go +++ b/client/handler.go @@ -5,7 +5,6 @@ import ( "errors" "net" "net/http" - "net/url" "strings" "time" @@ -104,16 +103,6 @@ func (client *Client) filter(cs *filters.ConnectionState, req *http.Request, nex return next(cs, req) } -// getBaseUrl returns the URL for the base domain of an ad without the full path, query string, -// etc. -func (client *Client) getBaseUrl(originalUrl string) string { - url, err := url.Parse(originalUrl) - if err != nil { - return originalUrl - } - return url.Scheme + "://" + url.Host -} - func (client *Client) isHTTPProxyPort(r *http.Request) bool { host, port, err := net.SplitHostPort(r.Host) if err != nil { @@ -161,16 +150,6 @@ func (client *Client) interceptProRequest(cs *filters.ConnectionState, r *http.R return filters.ShortCircuit(cs, r, resp) } -func (client *Client) easyblock(cs *filters.ConnectionState, req *http.Request) (*http.Response, *filters.ConnectionState, error) { - log.Debugf("Blocking %v on %v", req.URL, req.Host) - client.statsTracker.IncAdsBlocked() - resp := &http.Response{ - StatusCode: http.StatusForbidden, - Close: true, - } - return filters.ShortCircuit(cs, req, resp) -} - func (client *Client) redirectHTTPS(cs *filters.ConnectionState, req *http.Request, httpsURL string, op *ops.Op) (*http.Response, *filters.ConnectionState, error) { log.Debugf("httpseverywhere redirecting to %v", httpsURL) if op != nil { diff --git a/dialer/bandit.go b/dialer/bandit.go index b90d18eef..fe170760a 100644 --- a/dialer/bandit.go +++ b/dialer/bandit.go @@ -35,7 +35,10 @@ func NewBandit(opts *Options) (Dialer, error) { return nil, err } - b.Init(len(dialers)) + if err := b.Init(len(dialers)); err != nil { + log.Errorf("unable to initialize bandit: %v", err) + return nil, err + } dialer := &BanditDialer{ dialers: dialers, bandit: b, @@ -67,14 +70,18 @@ func (bd *BanditDialer) DialContext(ctx context.Context, network, addr string) ( if !failedUpstream { log.Errorf("Dialer %v failed in %v seconds: %v", d.Name(), time.Since(start).Seconds(), err) - bd.bandit.Update(chosenArm, 0) + if err := bd.bandit.Update(chosenArm, 0); err != nil { + log.Errorf("unable to update bandit: %v", err) + } } else { log.Debugf("Dialer %v failed upstream...", d.Name()) // This can happen, for example, if the upstream server is down, or // if the DNS resolves to localhost, for example. It is also possible // that the proxy is blacklisted by upstream sites for some reason, // so we have to choose some reasonable value. - bd.bandit.Update(chosenArm, 0.00005) + if err := bd.bandit.Update(chosenArm, 0.00005); err != nil { + log.Errorf("unable to update bandit: %v", err) + } } return nil, err } @@ -90,7 +97,9 @@ func (bd *BanditDialer) DialContext(ctx context.Context, network, addr string) ( time.AfterFunc(secondsForSample*time.Second, func() { speed := normalizeReceiveSpeed(dataRecv.Load()) //log.Debugf("Dialer %v received %v bytes in %v seconds, normalized speed: %v", d.Name(), dt.dataRecv, secondsForSample, speed) - bd.bandit.Update(chosenArm, speed) + if err := bd.bandit.Update(chosenArm, speed); err != nil { + log.Errorf("unable to update bandit: %v", err) + } }) bd.opts.OnSuccess(d) diff --git a/dialer/dialer.go b/dialer/dialer.go index cc4760ffe..593746602 100644 --- a/dialer/dialer.go +++ b/dialer/dialer.go @@ -31,6 +31,8 @@ func New(opts *Options) Dialer { }) } +// NoDialer returns a dialer that does nothing. This is useful during startup +// until a real dialer is available. func NoDialer() Dialer { return &noDialer{} } @@ -38,6 +40,9 @@ func NoDialer() Dialer { type noDialer struct{} func (d *noDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + // This ideally shouldn't be called, as it indicates we're attempting to send + // traffic through proxies before we actually have proxies. It's not a fatal + // error, but it's a sign that we should look into why we're here. // Print the goroutine stack to help debug why we're here log.Errorf("No dialer available -- should not be called, stack: %s", debug.Stack()) return nil, errors.New("no dialer available") diff --git a/dialer/fastconnect.go b/dialer/fastconnect.go index ffa95c0e4..055d7b389 100644 --- a/dialer/fastconnect.go +++ b/dialer/fastconnect.go @@ -16,28 +16,16 @@ type connectTimeProxyDialer struct { connectTime time.Duration } -type dialersByConnectTime []connectTimeProxyDialer - -func (d dialersByConnectTime) Len() int { - return len(d) -} - -func (d dialersByConnectTime) Less(i, j int) bool { - return d[i].connectTime < d[j].connectTime -} - -func (d dialersByConnectTime) Swap(i, j int) { - d[i], d[j] = d[j], d[i] +type connectedDialers struct { + dialers []connectTimeProxyDialer + sync.RWMutex } // fastConnectDialer stores the time it took to connect to each dialer and uses // that information to select the fastest dialer to use. type fastConnectDialer struct { - topDialer protectedDialer - connected dialersByConnectTime - connectedChan chan int - // Lock for the slice of dialers. - connectedLock sync.RWMutex + topDialer protectedDialer + connected connectedDialers next func(*Options, Dialer) Dialer opts *Options @@ -51,10 +39,12 @@ func newFastConnectDialer(opts *Options, next func(opts *Options, existing Diale opts.OnSuccess = func(ProxyDialer) {} } return &fastConnectDialer{ - connected: make(dialersByConnectTime, 0), - connectedChan: make(chan int), - opts: opts, - next: next, + connected: connectedDialers{ + dialers: make([]connectTimeProxyDialer, 0), + }, + opts: opts, + next: next, + topDialer: protectedDialer{}, } } @@ -70,7 +60,7 @@ func (fcd *fastConnectDialer) DialContext(ctx context.Context, network, addr str // the domain here. conn, failedUpstream, err := td.DialContext(ctx, network, addr) if err != nil { - hasSucceeding := len(fcd.connected) > 0 + hasSucceeding := len(fcd.connected.dialers) > 0 fcd.opts.OnError(err, hasSucceeding) // Error connecting to the proxy or to the destination if failedUpstream { @@ -88,18 +78,11 @@ func (fcd *fastConnectDialer) DialContext(ctx context.Context, network, addr str func (fcd *fastConnectDialer) onConnected(pd ProxyDialer, connectTime time.Duration) { log.Debugf("Connected to %v", pd.Name()) - fcd.connectedLock.Lock() - defer fcd.connectedLock.Unlock() - fcd.connected = append(fcd.connected, connectTimeProxyDialer{ - ProxyDialer: pd, - connectTime: connectTime, - }) - sort.Sort(fcd.connected) + newTopDialer := fcd.connected.onConnected(pd, connectTime) // Set top dialer if the fastest dialer changed. td := fcd.topDialer.get() - newTopDialer := fcd.connected[0].ProxyDialer if td != newTopDialer { log.Debugf("Setting new top dialer to %v", newTopDialer.Name()) fcd.topDialer.set(newTopDialer) @@ -117,7 +100,7 @@ func (fcd *fastConnectDialer) connectAll(dialers []ProxyDialer) { } log.Debugf("Dialing all dialers in parallel %#v", dialers) // Loop until we're connected - for len(fcd.connected) < 2 { + for len(fcd.connected.dialers) < 2 { fcd.parallelDial(dialers) // Add jitter to avoid thundering herd time.Sleep(time.Duration(rand.Intn(4000)) * time.Millisecond) @@ -129,7 +112,7 @@ func (fcd *fastConnectDialer) connectAll(dialers []ProxyDialer) { // If we've connected to more than one dialer after trying all of them, // switch to the next dialer that's optimized for multiple connections. nextOpts := fcd.opts.Clone() - nextOpts.Dialers = fcd.proxyDialers() + nextOpts.Dialers = fcd.connected.proxyDialers() fcd.next(nextOpts, fcd) } @@ -162,20 +145,36 @@ func (fcd *fastConnectDialer) parallelDial(dialers []ProxyDialer) { } // Accessor for a copy of the ProxyDialer slice -func (fcd *fastConnectDialer) proxyDialers() []ProxyDialer { - fcd.connectedLock.RLock() - defer fcd.connectedLock.RUnlock() +func (cd *connectedDialers) proxyDialers() []ProxyDialer { + cd.RLock() + defer cd.RUnlock() - dialers := make([]ProxyDialer, len(fcd.connected)) + dialers := make([]ProxyDialer, len(cd.dialers)) // Note that we manually copy here vs using copy because we need an array of // ProxyDialers, not a dialersByConnectTime. - for i, ctd := range fcd.connected { + for i, ctd := range cd.dialers { dialers[i] = ctd.ProxyDialer } return dialers } +// onConnected adds a connected dialer to the list of connected dialers and returns +// the fastest dialer. +func (cd *connectedDialers) onConnected(pd ProxyDialer, connectTime time.Duration) ProxyDialer { + cd.Lock() + defer cd.Unlock() + + cd.dialers = append(cd.dialers, connectTimeProxyDialer{ + ProxyDialer: pd, + connectTime: connectTime, + }) + sort.Slice(cd.dialers, func(i, j int) bool { + return cd.dialers[i].connectTime < cd.dialers[j].connectTime + }) + return cd.dialers[0].ProxyDialer +} + // protectedDialer protects a dialer.Dialer with a RWMutex. We can't use an atomic.Value here // because ProxyDialer is an interface. type protectedDialer struct { diff --git a/dialer/fastconnect_test.go b/dialer/fastconnect_test.go index d6bad539b..575d6255e 100644 --- a/dialer/fastconnect_test.go +++ b/dialer/fastconnect_test.go @@ -2,30 +2,76 @@ package dialer import ( - "sort" "testing" "time" - "github.com/stretchr/testify/require" + "github.com/stretchr/testify/assert" ) -func TestConnectTimeProxyDialer(t *testing.T) { - //dialer := newMockProxyDialer("dialer1", false) - dialer := newTcpConnDialer() - ctd1 := connectTimeProxyDialer{ - ProxyDialer: dialer, connectTime: 1 * time.Second, - } - ctd2 := connectTimeProxyDialer{ - ProxyDialer: dialer, connectTime: 100 * time.Second, +func TestOnConnected(t *testing.T) { + mockDialer1 := new(mockProxyDialer) + mockDialer2 := new(mockProxyDialer) + mockDialer3 := new(mockProxyDialer) + + opts := &Options{ + OnError: func(err error, hasSucceeding bool) {}, + OnSuccess: func(pd ProxyDialer) {}, } - ctd3 := connectTimeProxyDialer{ - ProxyDialer: dialer, connectTime: 10 * time.Second, + + fcd := newFastConnectDialer(opts, nil) + + // Test adding the first dialer + fcd.onConnected(mockDialer1, 100*time.Millisecond) + assert.Equal(t, 1, len(fcd.connected.dialers)) + assert.Equal(t, mockDialer1, fcd.topDialer.get()) + + // Test adding a faster dialer + fcd.onConnected(mockDialer2, 50*time.Millisecond) + assert.Equal(t, 2, len(fcd.connected.dialers)) + assert.Equal(t, mockDialer2, fcd.topDialer.get()) + + // Test adding a slower dialer + fcd.onConnected(mockDialer1, 150*time.Millisecond) + assert.Equal(t, 3, len(fcd.connected.dialers)) + assert.Equal(t, mockDialer2, fcd.topDialer.get()) + + // Test adding a new fastest dialer + fcd.onConnected(mockDialer3, 10*time.Millisecond) + assert.Equal(t, 4, len(fcd.connected.dialers)) + assert.Equal(t, mockDialer3, fcd.topDialer.get()) +} +func TestConnectAll(t *testing.T) { + mockDialer1 := new(mockProxyDialer) + mockDialer2 := new(mockProxyDialer) + mockDialer3 := new(mockProxyDialer) + + opts := &Options{ + OnError: func(err error, hasSucceeding bool) {}, + OnSuccess: func(pd ProxyDialer) {}, } - dialers := dialersByConnectTime{ctd1, ctd2, ctd3} - sort.Sort(dialers) + fcd := newFastConnectDialer(opts, func(opts *Options, existing Dialer) Dialer { + return nil + }) + + dialers := []ProxyDialer{mockDialer1, mockDialer2, mockDialer3} + + // Test connecting with multiple dialers + fcd.connectAll(dialers) + + // Sleep for a bit to allow the goroutines to finish while checking for + // the connected dialers + tries := 0 + for len(fcd.connected.dialers) < 3 && tries < 100 { + time.Sleep(10 * time.Millisecond) + tries++ + } + assert.Equal(t, 3, len(fcd.connected.dialers)) + assert.NotNil(t, fcd.topDialer.get()) - // Make sure the lowest connect time is first - require.True(t, dialers[0].connectTime < dialers[1].connectTime, "Expected dialers to be ordered by connect time") - require.True(t, dialers[1].connectTime < dialers[2].connectTime, "Expected dialers to be ordered by connect time") + // Test with no dialers + fcd = newFastConnectDialer(opts, nil) + fcd.connectAll([]ProxyDialer{}) + assert.Equal(t, 0, len(fcd.connected.dialers)) + assert.Nil(t, fcd.topDialer.get()) }