diff --git a/CHANGELOG.md b/CHANGELOG.md index 61796c5..bff3b5e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +### Changed + +- Config.Load now is concurrent-safe (#567). + ## [1.3.1] - 2024-09-09 ### Added diff --git a/config.go b/config.go index 92fd8cd..8494f44 100644 --- a/config.go +++ b/config.go @@ -34,15 +34,9 @@ type Config struct { onStatus func(loader Loader, changed bool, err error) converter *convert.Converter - // Loaded configuration. - values atomic.Pointer[map[string]any] - providers []*provider - providersMutex sync.RWMutex - - // For watching changes. - onChanges map[string][]func(*Config) - onChangesMutex sync.RWMutex - watched atomic.Bool + providers providers + onChanges onChanges + watched atomic.Bool } // New creates a new Config with the given Option(s). @@ -71,7 +65,7 @@ func New(opts ...Option) *Config { // Load loads configuration from the given loader. // Each loader takes precedence over the loaders before it. // -// This method can be called multiple times but it is not concurrency-safe. +// This method is concurrent-safe. func (c *Config) Load(loader Loader) error { if loader == nil { return nil @@ -84,19 +78,8 @@ func (c *Config) Load(loader Loader) error { return fmt.Errorf("load configuration: %w", err) } c.transformKeys(values) - prd := provider{ - loader: loader, - } - prd.values.Store(&values) - c.providersMutex.Lock() - c.providers = append(c.providers, &prd) - c.providersMutex.Unlock() - - // Merge loaded values into values map. - if c.values.Load() == nil { - c.values.Store(&map[string]any{}) - } - maps.Merge(*c.values.Load(), *prd.values.Load()) + c.providers.append(loader, values) + c.providers.sync() if _, ok := loader.(Watcher); !ok { return nil @@ -141,15 +124,16 @@ func (c *Config) Unmarshal(path string, target any) error { } c.nocopy.Check() - if c.values.Load() == nil { - return nil // To support zero Config + value := c.providers.value() + if value == nil { // To support zero Config + return nil } converter := c.converter if converter == nil { // To support zero Config converter = defaultConverter } - if err := converter.Convert(c.sub(*c.values.Load(), path), target); err != nil { + if err := converter.Convert(c.sub(value, path), target); err != nil { return fmt.Errorf("decode: %w", err) } @@ -195,12 +179,13 @@ func (c *Config) Explain(path string) string { } c.nocopy.Check() - if c.values.Load() == nil { // To support zero Config + value := c.providers.value() + if value == nil { // To support zero Config return path + " has no configuration.\n\n" } explanation := &strings.Builder{} - c.explain(explanation, path, c.sub(*c.values.Load(), path)) + c.explain(explanation, path, c.sub(value, path)) return explanation.String() } @@ -224,11 +209,11 @@ func (c *Config) explain(explanation *strings.Builder, path string, value any) { value any } var loaders []loaderValue - for _, provider := range c.providers { + c.providers.traverse(func(provider *provider) { if v := c.sub(*provider.values.Load(), path); v != nil { loaders = append(loaders, loaderValue{provider.loader, v}) } - } + }) slices.Reverse(loaders) if len(loaders) == 0 { @@ -256,9 +241,54 @@ func (c *Config) explain(explanation *strings.Builder, path string, value any) { explanation.WriteString("\n") } -type provider struct { - loader Loader - values atomic.Pointer[map[string]any] +type ( + providers struct { + providers []*provider + values atomic.Pointer[map[string]any] + mutex sync.RWMutex + } + provider struct { + loader Loader + values atomic.Pointer[map[string]any] + } +) + +func (p *providers) append(loader Loader, values map[string]any) { + p.mutex.Lock() + defer p.mutex.Unlock() + + provider := &provider{loader: loader} + provider.values.Store(&values) + p.providers = append(p.providers, provider) +} + +func (p *providers) sync() { + p.mutex.Lock() + defer p.mutex.Unlock() + + values := make(map[string]any) + for _, w := range p.providers { + maps.Merge(values, *w.values.Load()) + } + p.values.Store(&values) +} + +func (p *providers) traverse(action func(*provider)) { + p.mutex.RLock() + defer p.mutex.RUnlock() + + for _, provider := range p.providers { + action(provider) + } +} + +func (p *providers) value() map[string]any { + val := p.values.Load() + if val == nil { + return nil + } + + return *val } //nolint:gochecknoglobals diff --git a/default.go b/default.go index 4011cb6..c1802b6 100644 --- a/default.go +++ b/default.go @@ -44,7 +44,7 @@ func Unmarshal(path string, target any) error { // The register function must be non-blocking and usually completes instantly. // If it requires a long time to complete, it should be executed in a separate goroutine. // -// This method is concurrency-safe. +// This method is concurrent-safe. func OnChange(onChange func(), paths ...string) { defaultConfig.Load().OnChange(func(*Config) { onChange() }, paths...) } diff --git a/provider.go b/provider.go index da8ca02..862fc0f 100644 --- a/provider.go +++ b/provider.go @@ -41,9 +41,10 @@ func (c *Config) Exists(path []string) bool { } c.nocopy.Check() - if c.values.Load() == nil { + value := c.providers.value() + if value == nil { return false // To support zero Config } - return c.sub(*c.values.Load(), strings.Join(path, c.delim())) != nil + return c.sub(value, strings.Join(path, c.delim())) != nil } diff --git a/watch.go b/watch.go index e3d57af..ea89187 100644 --- a/watch.go +++ b/watch.go @@ -9,11 +9,8 @@ import ( "fmt" "log/slog" "reflect" - "slices" "sync" "time" - - "github.com/nil-go/konf/internal/maps" ) // Watch watches and updates configuration when it changes. @@ -24,39 +21,31 @@ import ( func (c *Config) Watch(ctx context.Context) error { //nolint:cyclop,funlen,gocognit c.nocopy.Check() - if hasWatcher := slices.ContainsFunc(c.providers, func(provider *provider) bool { - _, ok := provider.loader.(Watcher) - - return ok - }); !hasWatcher { - return nil - } - if watched := c.watched.Swap(true); watched { - c.log(ctx, slog.LevelWarn, "Config has been watched, call Watch again has no effects.") + c.log(ctx, slog.LevelWarn, "Config has been watched, call Watch more than once has no effects.") return nil } ctx, cancel := context.WithCancelCause(ctx) defer cancel(nil) - var waitGroup sync.WaitGroup // Start a goroutine to update the configuration while it has changes from watchers. onChangesChannel := make(chan []func(*Config), 1) defer close(onChangesChannel) + + var waitGroup sync.WaitGroup waitGroup.Add(1) go func() { defer waitGroup.Done() for { select { + case <-ctx.Done(): + return + case onChanges := <-onChangesChannel: - values := make(map[string]any) - for _, w := range c.providers { - maps.Merge(values, *w.values.Load()) - } - c.values.Store(&values) + c.providers.sync() c.log(ctx, slog.LevelDebug, "Configuration has been updated with change.") if len(onChanges) > 0 { @@ -79,25 +68,19 @@ func (c *Config) Watch(ctx context.Context) error { //nolint:cyclop,funlen,gocog if errors.Is(tctx.Err(), context.DeadlineExceeded) { c.log( ctx, slog.LevelWarn, - "Configuration has not been fully applied to onChanges due to timeout."+ + "Configuration has not been fully applied to onChanges in one minute."+ " Please check if the onChanges is blocking or takes too long to complete.", ) } } }() } - - case <-ctx.Done(): - return } } }() // Start a watching goroutine for each watcher registered. - c.providersMutex.RLock() - for i := range c.providers { - provider := c.providers[i] // Use pointer for later modification. - + c.providers.traverse(func(provider *provider) { if watcher, ok := provider.loader.(Watcher); ok { waitGroup.Add(1) go func(ctx context.Context) { @@ -106,24 +89,11 @@ func (c *Config) Watch(ctx context.Context) error { //nolint:cyclop,funlen,gocog onChange := func(values map[string]any) { c.transformKeys(values) oldValues := *provider.values.Swap(&values) - - // Find the onChanges should be triggered. - onChanges := func() []func(*Config) { - c.onChangesMutex.RLock() - defer c.onChangesMutex.RUnlock() - - var callbacks []func(*Config) - for path, onChanges := range c.onChanges { - oldVal := c.sub(oldValues, path) - newVal := c.sub(values, path) - if !reflect.DeepEqual(oldVal, newVal) { - callbacks = append(callbacks, onChanges...) - } - } - - return callbacks - } - onChangesChannel <- onChanges() + onChangesChannel <- c.onChanges.get( + func(path string) bool { + return !reflect.DeepEqual(c.sub(oldValues, path), c.sub(values, path)) + }, + ) c.log(ctx, slog.LevelInfo, "Configuration has been changed.", @@ -137,8 +107,7 @@ func (c *Config) Watch(ctx context.Context) error { //nolint:cyclop,funlen,gocog } }(ctx) } - } - c.providersMutex.RUnlock() + }) waitGroup.Wait() if err := context.Cause(ctx); err != nil && !errors.Is(err, ctx.Err()) { @@ -156,27 +125,52 @@ func (c *Config) Watch(ctx context.Context) error { //nolint:cyclop,funlen,gocog // The register function must be non-blocking and usually completes instantly. // If it requires a long time to complete, it should be executed in a separate goroutine. // -// This method is concurrency-safe. +// This method is concurrent-safe. func (c *Config) OnChange(onChange func(*Config), paths ...string) { if onChange == nil { return // Do nothing is onchange is nil. } c.nocopy.Check() + if !c.caseSensitive { + for i := range paths { + paths[i] = defaultKeyMap(paths[i]) + } + } + c.onChanges.register(onChange, paths) +} + +type onChanges struct { + subscribers map[string][]func(*Config) + mutex sync.RWMutex +} + +func (o *onChanges) register(onChange func(*Config), paths []string) { + o.mutex.Lock() + defer o.mutex.Unlock() + if len(paths) == 0 { paths = []string{""} } - c.onChangesMutex.Lock() - defer c.onChangesMutex.Unlock() - - if c.onChanges == nil { // To support zero Config - c.onChanges = make(map[string][]func(*Config)) + if o.subscribers == nil { + o.subscribers = make(map[string][]func(*Config)) } for _, path := range paths { - if !c.caseSensitive { - path = defaultKeyMap(path) + o.subscribers[path] = append(o.subscribers[path], onChange) + } +} + +func (o *onChanges) get(filter func(string) bool) []func(*Config) { + o.mutex.RLock() + defer o.mutex.RUnlock() + + var callbacks []func(*Config) + for path, subscriber := range o.subscribers { + if filter(path) { + callbacks = append(callbacks, subscriber...) } - c.onChanges[path] = append(c.onChanges[path], onChange) } + + return callbacks } diff --git a/watch_test.go b/watch_test.go index 38598ae..35521be 100644 --- a/watch_test.go +++ b/watch_test.go @@ -16,7 +16,6 @@ import ( "github.com/nil-go/konf" "github.com/nil-go/konf/internal/assert" - "github.com/nil-go/konf/provider/env" ) func TestOnChange_nil(*testing.T) { @@ -124,19 +123,11 @@ func TestConfig_Watch_onchange_block(t *testing.T) { <-ctx.Done() time.Sleep(10 * time.Millisecond) // Wait for log to be written expected := `level=INFO msg="Configuration has been changed." loader=stringWatcher -level=WARN msg="Configuration has not been fully applied to onChanges due to timeout. Please check if the onChanges is blocking or takes too long to complete." +level=WARN msg="Configuration has not been fully applied to onChanges in one minute. Please check if the onChanges is blocking or takes too long to complete." ` assert.Equal(t, expected, buf.String()) } -func TestConfig_Watch_without_loader(t *testing.T) { - t.Parallel() - - var config konf.Config - assert.NoError(t, config.Load(env.New())) - assert.NoError(t, config.Watch(context.Background())) -} - func TestConfig_Watch_twice(t *testing.T) { t.Parallel() @@ -157,7 +148,7 @@ func TestConfig_Watch_twice(t *testing.T) { time.Sleep(100 * time.Millisecond) // Wait for watch to start assert.NoError(t, config.Watch(ctx)) - expected := "level=WARN msg=\"Config has been watched, call Watch again has no effects.\"\n" + expected := "level=WARN msg=\"Config has been watched, call Watch more than once has no effects.\"\n" assert.Equal(t, expected, buf.String()) }