diff --git a/flashlight.go b/flashlight.go index d0fcc9f7e..4747f3aaf 100644 --- a/flashlight.go +++ b/flashlight.go @@ -95,8 +95,6 @@ type Flashlight struct { errorHandler func(HandledErrorType, error) mxProxyListeners sync.RWMutex proxyListeners []func(map[string]*commonconfig.ProxyConfig, config.Source) - - configService *services.ConfigService } // clientCallbacks are callbacks the client is configured with @@ -130,7 +128,7 @@ func (f *Flashlight) EnabledFeatures() map[string]bool { } global := f.global f.mxGlobal.RUnlock() - country := f.configService.Country() + country := services.GetCountry() for feature := range global.FeaturesEnabled { if f.calcFeature(global, country, "0.0.1", feature) { featuresEnabled[feature] = true @@ -188,7 +186,7 @@ func (f *Flashlight) FeatureEnabled(feature, applicationVersion string) bool { f.mxGlobal.RLock() global := f.global f.mxGlobal.RUnlock() - return f.calcFeature(global, f.configService.Country(), applicationVersion, feature) + return f.calcFeature(global, services.GetCountry(), applicationVersion, feature) } func (f *Flashlight) calcFeature(global *config.Global, country, applicationVersion, feature string) bool { @@ -333,16 +331,7 @@ func New( proxyListeners: make([]func(map[string]*commonconfig.ProxyConfig, config.Source), 0), } - fops.InitGlobalContext( - appName, appVersion, revisionDate, deviceID, isPro, - func() string { - if f.configService != nil { - return f.configService.Country() - } - - return "" - }, - ) + fops.InitGlobalContext(appName, appVersion, revisionDate, deviceID, isPro, services.GetCountry) f.addProxyListener(func(proxies map[string]*commonconfig.ProxyConfig, src config.Source) { log.Debug("Applying proxy config with proxies") @@ -467,50 +456,53 @@ func (f *Flashlight) StartBackgroundServices() (func(), error) { stopMonitor := goroutines.Monitor(time.Minute, 800, 5) stopGlobalConfigFetch := f.startGlobalConfigFetch() - bypassStopFn := services.StartBypassService(f.addProxyListener, f.configDir, f.userConfig) + stopBypass := services.StartBypassService(f.addProxyListener, f.configDir, f.userConfig) - onConfig := func(conf *services.ClientConfig) { - country := f.configService.Country() - if nc := conf.GetCountry(); nc != country { + onConfig := func(old, new *services.ClientConfig) { + var country string + if old != nil { + country = old.GetCountry() + } + + if nc := new.GetCountry(); nc != country { // Update the country if it has changed log.Debugf("Setting detour country to %v", nc) detour.SetCountry(nc) } - proxyMap := f.convertNewProxyConfToOld(conf.GetProxy().GetProxies()) + proxyMap := f.convertNewProxyConfToOld(new.GetProxy().GetProxies()) f.notifyProxyListeners(proxyMap, config.Fetched) } + configOpts := &services.ConfigOptions{ - SaveDir: f.configDir, - Obfuscate: true, - OriginURL: "", - UserConfig: f.userConfig, - Sticky: false, - Rt: proxied.ParallelPreferChained(), - OnConfig: onConfig, + SaveDir: f.configDir, + Obfuscate: true, + OriginURL: "", + UserConfig: f.userConfig, + Sticky: false, + RoundTripper: proxied.ParallelPreferChained(), + OnConfig: onConfig, } setConfigFlagOpts(configOpts, f.flagsAsMap) - configService, err := services.StartConfigService(configOpts) + stopConfigService, err := services.StartConfigService(configOpts) if err != nil { return func() { stopMonitor() - bypassStopFn() + stopBypass() stopGlobalConfigFetch() }, err } - f.configService = configService - // TODO: update all code to use the new config format for geolookup geolookup.EnablePersistence(filepath.Join(f.configDir, "latestgeoinfo.json")) geolookup.Refresh() return func() { stopMonitor() - bypassStopFn() + stopBypass() stopGlobalConfigFetch() - configService.Stop() + stopConfigService() }, err } diff --git a/services/bypass.go b/services/bypass.go index 7b856ba7f..b7e8b4d46 100644 --- a/services/bypass.go +++ b/services/bypass.go @@ -71,7 +71,7 @@ func StartBypassService( listen func(func(map[string]*commonconfig.ProxyConfig, config.Source)), configDir string, userConfig common.UserConfig, -) func() { +) StopFn { b := &bypassService{ infos: make(map[string]*commonconfig.ProxyConfig), proxies: make([]*proxy, 0), diff --git a/services/config.go b/services/config.go index 0867522a1..3fa2b3634 100644 --- a/services/config.go +++ b/services/config.go @@ -28,8 +28,8 @@ import ( ) const ( - defaultSaveDir = "" - defaultFilename = "proxies.conf" + defaultConfigSaveDir = "" + defaultConfigFilename = "proxies.conf" defaultConfigPollInterval = 3 * time.Minute defaultConfigPollJitter = 2 * time.Minute @@ -70,10 +70,9 @@ type ConfigOptions struct { // update it with remote data. Sticky bool - // Rt provides the RoundTripper the fetcher should use, which allows us to - // dictate whether the fetcher will use dual fetching (from fronted and - // chained URLs) or not. - Rt http.RoundTripper + // RoundTripper provides the http.RoundTripper the fetcher should use, which allows us to + // dictate whether the fetcher will use dual fetching (from fronted and chained URLs) or not. + RoundTripper http.RoundTripper // PollInterval specifies how frequently to poll for new config. PollInterval time.Duration @@ -81,25 +80,39 @@ type ConfigOptions struct { PollJitter time.Duration // OnConfig is a callback that is called when a new config is received. - OnConfig func(conf *ClientConfig) + OnConfig func(old, new *ClientConfig) } -type ConfigService struct { +type configService struct { opts *ConfigOptions clientInfo *ClientInfo clientConfig atomic.Value lastFetched time.Time - done chan struct{} - once sync.Once - logger golog.Logger + done chan struct{} + running bool + logger golog.Logger } -// StartConfigService starts a new config service with the given options. It will return an error -// if opts.OriginURL, opts.Rt, opts.Fetcher, or opts.OnConfig are nil. -func StartConfigService(opts *ConfigOptions) (*ConfigService, error) { +var ( + // initialize variable so we don't have to lock mutex and check if it's nil every time someone + // calls GetClientConfig + _configService = &configService{clientConfig: atomic.Value{}} + configServiceMu sync.Mutex +) + +// StartConfigService starts a new config service with the given options and returns a func to stop +// it. It will return an error if opts.OriginURL, opts.Rt, opts.Fetcher, or opts.OnConfig are nil. +func StartConfigService(opts *ConfigOptions) (StopFn, error) { + configServiceMu.Lock() + defer configServiceMu.Unlock() + + if _configService != nil && _configService.running { + return _configService.Stop, nil + } + switch { - case opts.Rt == nil: + case opts.RoundTripper == nil: return nil, errors.New("RoundTripper is required") case opts.OnConfig == nil: return nil, errors.New("OnConfig is required") @@ -108,8 +121,8 @@ func StartConfigService(opts *ConfigOptions) (*ConfigService, error) { } if opts.SaveDir == "" { - opts.SaveDir = defaultSaveDir - opts.filePath = filepath.Join(opts.SaveDir, defaultFilename) + opts.SaveDir = defaultConfigSaveDir + opts.filePath = filepath.Join(opts.SaveDir, defaultConfigFilename) } if opts.PollInterval <= 0 { @@ -130,119 +143,125 @@ func StartConfigService(opts *ConfigOptions) (*ConfigService, error) { detour.ForceWhitelist(u.Host) userId := strconv.Itoa(int(opts.UserConfig.GetUserID())) - ch := &ConfigService{ - opts: opts, - clientInfo: &ClientInfo{ - FlashlightVersion: common.LibraryVersion, - ClientVersion: common.CompileTimeApplicationVersion, - UserId: userId, - ProToken: opts.UserConfig.GetToken(), - }, - clientConfig: atomic.Value{}, - done: make(chan struct{}), - once: sync.Once{}, - logger: logger, + _configService.opts = opts + _configService.clientInfo = &ClientInfo{ + FlashlightVersion: common.LibraryVersion, + ClientVersion: common.CompileTimeApplicationVersion, + UserId: userId, + ProToken: opts.UserConfig.GetToken(), } + _configService.done = make(chan struct{}) + _configService.logger = logger - if err := ch.init(); err != nil { + if err := _configService.init(); err != nil { return nil, err } - ch.logger.Debug("Starting config service") + _configService.logger.Debug("Starting config service") + _configService.running = true if opts.Sticky { - return ch, nil + return _configService.Stop, nil } fn := func() int64 { - sleep, _ := ch.fetchConfig() + sleep, _ := _configService.fetchConfig() return sleep } - go callRandomly(fn, ch.opts.PollInterval, ch.opts.PollJitter, ch.done, ch.logger) - return ch, nil + go callRandomly(fn, opts.PollInterval, opts.PollJitter, _configService.done, _configService.logger) + + return _configService.Stop, nil } -func (ch *ConfigService) init() error { - ch.logger.Debug("Initializing config service") - conf, err := readExistingClientConfig(ch.opts.filePath, ch.opts.Obfuscate) +func (cs *configService) init() error { + cs.logger.Debug("Initializing config service") + conf, err := readExistingClientConfig(cs.opts.filePath, cs.opts.Obfuscate) if conf == nil { if err != nil { - ch.logger.Errorf("could not read existing config: %v", err) + cs.logger.Errorf("could not read existing config: %v", err) } - ch.clientConfig.Store(&ClientConfig{}) + cs.clientConfig.Store(&ClientConfig{}) return err } - ch.logger.Debugf("loaded saved config at %v", ch.opts.filePath) + cs.logger.Debugf("loaded saved config at %v", cs.opts.filePath) + + cs.clientInfo.Country = conf.Country + cs.clientConfig.Store(conf) + cs.opts.OnConfig(nil, conf) - ch.clientInfo.Country = conf.Country - ch.clientConfig.Store(conf) - ch.opts.OnConfig(conf) return nil } -func (ch *ConfigService) updateClientInfo(conf *ClientConfig) { - ch.clientInfo.ProToken = conf.ProToken - ch.clientInfo.Country = conf.Country - ch.clientInfo.Ip = conf.Ip +func (cs *configService) updateClientInfo(conf *ClientConfig) { + cs.clientInfo.ProToken = conf.ProToken + cs.clientInfo.Country = conf.Country + cs.clientInfo.Ip = conf.Ip } -func (ch *ConfigService) Stop() { - ch.once.Do(func() { - close(ch.done) - }) +func (cs *configService) Stop() { + configServiceMu.Lock() + defer configServiceMu.Unlock() + + if !cs.running { + return + } + + close(cs.done) + cs.running = false } // fetchConfig fetches the current config from the server and updates the client's config if a change // has occurred. It returns the extra sleep time received from the server response and any error that // occurred. -func (ch *ConfigService) fetchConfig() (int64, error) { +func (cs *configService) fetchConfig() (int64, error) { op := ops.Begin("Fetching config") defer op.End() - newConf, sleep, err := ch.fetch() + newConf, sleep, err := cs.fetch() if err != nil { return 0, op.FailIf(err) } - ch.lastFetched = time.Now() + cs.lastFetched = time.Now() - ch.logger.Debug("Received config") - if !configIsNew(ch.clientInfo, newConf) { + cs.logger.Debug("Received config") + curConf := GetClientConfig() + if curConf != nil && !configIsNew(curConf, newConf) { op.Set("config_changed", false) - ch.logger.Debug("Config is unchanged") + cs.logger.Debug("Config is unchanged") return sleep, nil } op.Set("config_changed", true) - err = saveClientConfig(ch.opts.filePath, newConf, ch.opts.Obfuscate) + err = saveClientConfig(cs.opts.filePath, newConf, cs.opts.Obfuscate) if err != nil { - ch.logger.Error(err) + cs.logger.Error(err) } else { - ch.logger.Debugf("Wrote config to %v", ch.opts.filePath) + cs.logger.Debugf("Wrote config to %v", cs.opts.filePath) } - ch.updateClientInfo(newConf) - ch.clientConfig.Store(newConf) - ch.opts.OnConfig(newConf) + cs.updateClientInfo(newConf) + old := cs.clientConfig.Swap(newConf) + cs.opts.OnConfig(old.(*ClientConfig), newConf) return sleep, nil } -func (ch *ConfigService) fetch() (*ClientConfig, int64, error) { - confReq := ch.newRequest() +func (cs *configService) fetch() (*ClientConfig, int64, error) { + confReq := cs.newRequest() buf, err := proto.Marshal(confReq) if err != nil { return nil, 0, fmt.Errorf("unable to marshal config request: %w", err) } resp, sleep, err := post( - ch.opts.OriginURL, + cs.opts.OriginURL, bytes.NewReader(buf), - ch.opts.Rt, - ch.opts.UserConfig, - ch.logger, + cs.opts.RoundTripper, + cs.opts.UserConfig, + cs.logger, ) if err != nil { return nil, 0, fmt.Errorf("config request failed: %w", err) @@ -265,18 +284,23 @@ func (ch *ConfigService) fetch() (*ClientConfig, int64, error) { return newConf, sleep, err } -func (ch *ConfigService) newRequest() *ConfigRequest { - proxies := ch.Proxies() +func (cs *configService) newRequest() *ConfigRequest { + conf := GetClientConfig() + proxies := []*ProxyConnectConfig{} + if conf != nil { // not the first request + proxies = conf.GetProxy().GetProxies() + } + ids := make([]string, len(proxies)) for _, proxy := range proxies { ids = append(ids, proxy.GetTrack()) } confReq := &ConfigRequest{ - ClientInfo: ch.clientInfo, + ClientInfo: cs.clientInfo, Proxy: &ConfigProxies{ Ids: ids, - LastRequest: timestamppb.New(ch.lastFetched), + LastRequest: timestamppb.New(cs.lastFetched), }, } @@ -311,7 +335,7 @@ func readExistingClientConfig(filePath string, obfuscate bool) (*ClientConfig, e return conf, err } -// saveClientConfig writes conf to a file at the specified path, filePath, obfuscating them if +// saveClientConfig writes conf to a file at the specified path, filePath, obfuscating it if // obfuscate is true. If the file already exists, it will be overwritten. func saveClientConfig(filePath string, conf *ClientConfig, obfuscate bool) error { in, err := proto.Marshal(conf) @@ -339,25 +363,25 @@ func saveClientConfig(filePath string, conf *ClientConfig, obfuscate bool) error // configIsNew returns true if country, proToken, or ip in currInfo differ from new or if new has // proxy configs. -func configIsNew(currInfo *ClientInfo, new *ClientConfig) bool { - return currInfo.GetCountry() != new.GetCountry() || - currInfo.GetProToken() != new.GetProToken() || - currInfo.GetIp() != new.GetIp() || +func configIsNew(cur, new *ClientConfig) bool { + return cur.GetCountry() != new.GetCountry() || + cur.GetProToken() != new.GetProToken() || len(new.GetProxy().GetProxies()) > 0 } -func (ch *ConfigService) Country() string { - return ch.clientConfig.Load().(*ClientConfig).GetCountry() +// GetClientConfig returns the current client config. +func GetClientConfig() *ClientConfig { + // We don't need to lock the mutex here because we know that the configService var is not nil + return _configService.clientConfig.Load().(*ClientConfig) } -func (ch *ConfigService) Ip() string { - return ch.clientConfig.Load().(*ClientConfig).GetIp() -} - -func (ch *ConfigService) ProToken() string { - return ch.clientConfig.Load().(*ClientConfig).GetProToken() -} +// GetCountry returns the country from the current client config. If there is no config, it returns +// the default country. +func GetCountry() string { + conf := GetClientConfig() + if conf == nil { // no config yet + return "" + } -func (ch *ConfigService) Proxies() []*ProxyConnectConfig { - return ch.clientConfig.Load().(*ClientConfig).GetProxy().GetProxies() + return conf.GetCountry() } diff --git a/services/service.go b/services/service.go index 0d5ed97a4..fd5720b96 100644 --- a/services/service.go +++ b/services/service.go @@ -8,6 +8,8 @@ import ( "github.com/getlantern/golog" ) +type StopFn func() + // callRandomly continuously calls fn randomly between interval-jitter and interval+jitter, with // the initial call being made immediately. fn can return a positive value to extend the wait time. func callRandomly(