Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make Config.Load concurrent-safe #567

Merged
merged 3 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

### Changed

- Config.Load now is concurrent-safe (#567).

## [1.3.1] - 2024-09-09

### Added
Expand Down
96 changes: 63 additions & 33 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,9 @@ type Config struct {
onStatus func(loader Loader, changed bool, err error)
converter *convert.Converter

// Loaded configuration.
values atomic.Pointer[map[string]any]
providers []*provider
providersMutex sync.RWMutex

// For watching changes.
onChanges map[string][]func(*Config)
onChangesMutex sync.RWMutex
watched atomic.Bool
providers providers
onChanges onChanges
watched atomic.Bool
}

// New creates a new Config with the given Option(s).
Expand Down Expand Up @@ -71,7 +65,7 @@ func New(opts ...Option) *Config {
// Load loads configuration from the given loader.
// Each loader takes precedence over the loaders before it.
//
// This method can be called multiple times but it is not concurrency-safe.
// This method is concurrent-safe.
func (c *Config) Load(loader Loader) error {
if loader == nil {
return nil
Expand All @@ -84,19 +78,8 @@ func (c *Config) Load(loader Loader) error {
return fmt.Errorf("load configuration: %w", err)
}
c.transformKeys(values)
prd := provider{
loader: loader,
}
prd.values.Store(&values)
c.providersMutex.Lock()
c.providers = append(c.providers, &prd)
c.providersMutex.Unlock()

// Merge loaded values into values map.
if c.values.Load() == nil {
c.values.Store(&map[string]any{})
}
maps.Merge(*c.values.Load(), *prd.values.Load())
c.providers.append(loader, values)
c.providers.sync()

if _, ok := loader.(Watcher); !ok {
return nil
Expand Down Expand Up @@ -141,15 +124,16 @@ func (c *Config) Unmarshal(path string, target any) error {
}
c.nocopy.Check()

if c.values.Load() == nil {
return nil // To support zero Config
value := c.providers.value()
if value == nil { // To support zero Config
return nil
}

converter := c.converter
if converter == nil { // To support zero Config
converter = defaultConverter
}
if err := converter.Convert(c.sub(*c.values.Load(), path), target); err != nil {
if err := converter.Convert(c.sub(value, path), target); err != nil {
return fmt.Errorf("decode: %w", err)
}

Expand Down Expand Up @@ -195,12 +179,13 @@ func (c *Config) Explain(path string) string {
}
c.nocopy.Check()

if c.values.Load() == nil { // To support zero Config
value := c.providers.value()
if value == nil { // To support zero Config
return path + " has no configuration.\n\n"
}

explanation := &strings.Builder{}
c.explain(explanation, path, c.sub(*c.values.Load(), path))
c.explain(explanation, path, c.sub(value, path))

return explanation.String()
}
Expand All @@ -224,11 +209,11 @@ func (c *Config) explain(explanation *strings.Builder, path string, value any) {
value any
}
var loaders []loaderValue
for _, provider := range c.providers {
c.providers.traverse(func(provider *provider) {
if v := c.sub(*provider.values.Load(), path); v != nil {
loaders = append(loaders, loaderValue{provider.loader, v})
}
}
})
slices.Reverse(loaders)

if len(loaders) == 0 {
Expand Down Expand Up @@ -256,9 +241,54 @@ func (c *Config) explain(explanation *strings.Builder, path string, value any) {
explanation.WriteString("\n")
}

type provider struct {
loader Loader
values atomic.Pointer[map[string]any]
type (
providers struct {
providers []*provider
values atomic.Pointer[map[string]any]
mutex sync.RWMutex
}
provider struct {
loader Loader
values atomic.Pointer[map[string]any]
}
)

func (p *providers) append(loader Loader, values map[string]any) {
p.mutex.Lock()
defer p.mutex.Unlock()

provider := &provider{loader: loader}
provider.values.Store(&values)
p.providers = append(p.providers, provider)
}

func (p *providers) sync() {
p.mutex.Lock()
defer p.mutex.Unlock()

values := make(map[string]any)
for _, w := range p.providers {
maps.Merge(values, *w.values.Load())
}
p.values.Store(&values)
}

func (p *providers) traverse(action func(*provider)) {
p.mutex.RLock()
defer p.mutex.RUnlock()

for _, provider := range p.providers {
action(provider)
}
}

func (p *providers) value() map[string]any {
val := p.values.Load()
if val == nil {
return nil
}

return *val
}

//nolint:gochecknoglobals
Expand Down
2 changes: 1 addition & 1 deletion default.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func Unmarshal(path string, target any) error {
// The register function must be non-blocking and usually completes instantly.
// If it requires a long time to complete, it should be executed in a separate goroutine.
//
// This method is concurrency-safe.
// This method is concurrent-safe.
func OnChange(onChange func(), paths ...string) {
defaultConfig.Load().OnChange(func(*Config) { onChange() }, paths...)
}
Expand Down
5 changes: 3 additions & 2 deletions provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@ func (c *Config) Exists(path []string) bool {
}
c.nocopy.Check()

if c.values.Load() == nil {
value := c.providers.value()
if value == nil {
return false // To support zero Config
}

return c.sub(*c.values.Load(), strings.Join(path, c.delim())) != nil
return c.sub(value, strings.Join(path, c.delim())) != nil
}
104 changes: 49 additions & 55 deletions watch.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,8 @@ import (
"fmt"
"log/slog"
"reflect"
"slices"
"sync"
"time"

"github.com/nil-go/konf/internal/maps"
)

// Watch watches and updates configuration when it changes.
Expand All @@ -24,39 +21,31 @@ import (
func (c *Config) Watch(ctx context.Context) error { //nolint:cyclop,funlen,gocognit
c.nocopy.Check()

if hasWatcher := slices.ContainsFunc(c.providers, func(provider *provider) bool {
_, ok := provider.loader.(Watcher)

return ok
}); !hasWatcher {
return nil
}

if watched := c.watched.Swap(true); watched {
c.log(ctx, slog.LevelWarn, "Config has been watched, call Watch again has no effects.")
c.log(ctx, slog.LevelWarn, "Config has been watched, call Watch more than once has no effects.")

return nil
}

ctx, cancel := context.WithCancelCause(ctx)
defer cancel(nil)
var waitGroup sync.WaitGroup

// Start a goroutine to update the configuration while it has changes from watchers.
onChangesChannel := make(chan []func(*Config), 1)
defer close(onChangesChannel)

var waitGroup sync.WaitGroup
waitGroup.Add(1)
go func() {
defer waitGroup.Done()

for {
select {
case <-ctx.Done():
return

case onChanges := <-onChangesChannel:
values := make(map[string]any)
for _, w := range c.providers {
maps.Merge(values, *w.values.Load())
}
c.values.Store(&values)
c.providers.sync()
c.log(ctx, slog.LevelDebug, "Configuration has been updated with change.")

if len(onChanges) > 0 {
Expand All @@ -79,25 +68,19 @@ func (c *Config) Watch(ctx context.Context) error { //nolint:cyclop,funlen,gocog
if errors.Is(tctx.Err(), context.DeadlineExceeded) {
c.log(
ctx, slog.LevelWarn,
"Configuration has not been fully applied to onChanges due to timeout."+
"Configuration has not been fully applied to onChanges in one minute."+
" Please check if the onChanges is blocking or takes too long to complete.",
)
}
}
}()
}

case <-ctx.Done():
return
}
}
}()

// Start a watching goroutine for each watcher registered.
c.providersMutex.RLock()
for i := range c.providers {
provider := c.providers[i] // Use pointer for later modification.

c.providers.traverse(func(provider *provider) {
if watcher, ok := provider.loader.(Watcher); ok {
waitGroup.Add(1)
go func(ctx context.Context) {
Expand All @@ -106,24 +89,11 @@ func (c *Config) Watch(ctx context.Context) error { //nolint:cyclop,funlen,gocog
onChange := func(values map[string]any) {
c.transformKeys(values)
oldValues := *provider.values.Swap(&values)

// Find the onChanges should be triggered.
onChanges := func() []func(*Config) {
c.onChangesMutex.RLock()
defer c.onChangesMutex.RUnlock()

var callbacks []func(*Config)
for path, onChanges := range c.onChanges {
oldVal := c.sub(oldValues, path)
newVal := c.sub(values, path)
if !reflect.DeepEqual(oldVal, newVal) {
callbacks = append(callbacks, onChanges...)
}
}

return callbacks
}
onChangesChannel <- onChanges()
onChangesChannel <- c.onChanges.get(
func(path string) bool {
return !reflect.DeepEqual(c.sub(oldValues, path), c.sub(values, path))
},
)

c.log(ctx, slog.LevelInfo,
"Configuration has been changed.",
Expand All @@ -137,8 +107,7 @@ func (c *Config) Watch(ctx context.Context) error { //nolint:cyclop,funlen,gocog
}
}(ctx)
}
}
c.providersMutex.RUnlock()
})
waitGroup.Wait()

if err := context.Cause(ctx); err != nil && !errors.Is(err, ctx.Err()) {
Expand All @@ -156,27 +125,52 @@ func (c *Config) Watch(ctx context.Context) error { //nolint:cyclop,funlen,gocog
// The register function must be non-blocking and usually completes instantly.
// If it requires a long time to complete, it should be executed in a separate goroutine.
//
// This method is concurrency-safe.
// This method is concurrent-safe.
func (c *Config) OnChange(onChange func(*Config), paths ...string) {
if onChange == nil {
return // Do nothing is onchange is nil.
}
c.nocopy.Check()

if !c.caseSensitive {
for i := range paths {
paths[i] = defaultKeyMap(paths[i])
}
}
c.onChanges.register(onChange, paths)
}

type onChanges struct {
subscribers map[string][]func(*Config)
mutex sync.RWMutex
}

func (o *onChanges) register(onChange func(*Config), paths []string) {
o.mutex.Lock()
defer o.mutex.Unlock()

if len(paths) == 0 {
paths = []string{""}
}

c.onChangesMutex.Lock()
defer c.onChangesMutex.Unlock()

if c.onChanges == nil { // To support zero Config
c.onChanges = make(map[string][]func(*Config))
if o.subscribers == nil {
o.subscribers = make(map[string][]func(*Config))
}
for _, path := range paths {
if !c.caseSensitive {
path = defaultKeyMap(path)
o.subscribers[path] = append(o.subscribers[path], onChange)
}
}

func (o *onChanges) get(filter func(string) bool) []func(*Config) {
o.mutex.RLock()
defer o.mutex.RUnlock()

var callbacks []func(*Config)
for path, subscriber := range o.subscribers {
if filter(path) {
callbacks = append(callbacks, subscriber...)
}
c.onChanges[path] = append(c.onChanges[path], onChange)
}

return callbacks
}
Loading