Skip to content

Commit

Permalink
make Config.Load concurrent-safe (#567)
Browse files Browse the repository at this point in the history
So Load and Watch can be called on different goroutines.
  • Loading branch information
ktong authored Nov 20, 2024
1 parent a2af1b4 commit 3138d90
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 102 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

### Changed

- Config.Load now is concurrent-safe (#567).

## [1.3.1] - 2024-09-09

### Added
Expand Down
96 changes: 63 additions & 33 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,9 @@ type Config struct {
onStatus func(loader Loader, changed bool, err error)
converter *convert.Converter

// Loaded configuration.
values atomic.Pointer[map[string]any]
providers []*provider
providersMutex sync.RWMutex

// For watching changes.
onChanges map[string][]func(*Config)
onChangesMutex sync.RWMutex
watched atomic.Bool
providers providers
onChanges onChanges
watched atomic.Bool
}

// New creates a new Config with the given Option(s).
Expand Down Expand Up @@ -71,7 +65,7 @@ func New(opts ...Option) *Config {
// Load loads configuration from the given loader.
// Each loader takes precedence over the loaders before it.
//
// This method can be called multiple times but it is not concurrency-safe.
// This method is concurrent-safe.
func (c *Config) Load(loader Loader) error {
if loader == nil {
return nil
Expand All @@ -84,19 +78,8 @@ func (c *Config) Load(loader Loader) error {
return fmt.Errorf("load configuration: %w", err)
}
c.transformKeys(values)
prd := provider{
loader: loader,
}
prd.values.Store(&values)
c.providersMutex.Lock()
c.providers = append(c.providers, &prd)
c.providersMutex.Unlock()

// Merge loaded values into values map.
if c.values.Load() == nil {
c.values.Store(&map[string]any{})
}
maps.Merge(*c.values.Load(), *prd.values.Load())
c.providers.append(loader, values)
c.providers.sync()

if _, ok := loader.(Watcher); !ok {
return nil
Expand Down Expand Up @@ -141,15 +124,16 @@ func (c *Config) Unmarshal(path string, target any) error {
}
c.nocopy.Check()

if c.values.Load() == nil {
return nil // To support zero Config
value := c.providers.value()
if value == nil { // To support zero Config
return nil
}

converter := c.converter
if converter == nil { // To support zero Config
converter = defaultConverter
}
if err := converter.Convert(c.sub(*c.values.Load(), path), target); err != nil {
if err := converter.Convert(c.sub(value, path), target); err != nil {
return fmt.Errorf("decode: %w", err)
}

Expand Down Expand Up @@ -195,12 +179,13 @@ func (c *Config) Explain(path string) string {
}
c.nocopy.Check()

if c.values.Load() == nil { // To support zero Config
value := c.providers.value()
if value == nil { // To support zero Config
return path + " has no configuration.\n\n"
}

explanation := &strings.Builder{}
c.explain(explanation, path, c.sub(*c.values.Load(), path))
c.explain(explanation, path, c.sub(value, path))

return explanation.String()
}
Expand All @@ -224,11 +209,11 @@ func (c *Config) explain(explanation *strings.Builder, path string, value any) {
value any
}
var loaders []loaderValue
for _, provider := range c.providers {
c.providers.traverse(func(provider *provider) {
if v := c.sub(*provider.values.Load(), path); v != nil {
loaders = append(loaders, loaderValue{provider.loader, v})
}
}
})
slices.Reverse(loaders)

if len(loaders) == 0 {
Expand Down Expand Up @@ -256,9 +241,54 @@ func (c *Config) explain(explanation *strings.Builder, path string, value any) {
explanation.WriteString("\n")
}

type provider struct {
loader Loader
values atomic.Pointer[map[string]any]
type (
providers struct {
providers []*provider
values atomic.Pointer[map[string]any]
mutex sync.RWMutex
}
provider struct {
loader Loader
values atomic.Pointer[map[string]any]
}
)

func (p *providers) append(loader Loader, values map[string]any) {
p.mutex.Lock()
defer p.mutex.Unlock()

provider := &provider{loader: loader}
provider.values.Store(&values)
p.providers = append(p.providers, provider)
}

func (p *providers) sync() {
p.mutex.Lock()
defer p.mutex.Unlock()

values := make(map[string]any)
for _, w := range p.providers {
maps.Merge(values, *w.values.Load())
}
p.values.Store(&values)
}

func (p *providers) traverse(action func(*provider)) {
p.mutex.RLock()
defer p.mutex.RUnlock()

for _, provider := range p.providers {
action(provider)
}
}

func (p *providers) value() map[string]any {
val := p.values.Load()
if val == nil {
return nil
}

return *val
}

//nolint:gochecknoglobals
Expand Down
2 changes: 1 addition & 1 deletion default.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func Unmarshal(path string, target any) error {
// The register function must be non-blocking and usually completes instantly.
// If it requires a long time to complete, it should be executed in a separate goroutine.
//
// This method is concurrency-safe.
// This method is concurrent-safe.
func OnChange(onChange func(), paths ...string) {
defaultConfig.Load().OnChange(func(*Config) { onChange() }, paths...)
}
Expand Down
5 changes: 3 additions & 2 deletions provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@ func (c *Config) Exists(path []string) bool {
}
c.nocopy.Check()

if c.values.Load() == nil {
value := c.providers.value()
if value == nil {
return false // To support zero Config
}

return c.sub(*c.values.Load(), strings.Join(path, c.delim())) != nil
return c.sub(value, strings.Join(path, c.delim())) != nil
}
104 changes: 49 additions & 55 deletions watch.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,8 @@ import (
"fmt"
"log/slog"
"reflect"
"slices"
"sync"
"time"

"github.com/nil-go/konf/internal/maps"
)

// Watch watches and updates configuration when it changes.
Expand All @@ -24,39 +21,31 @@ import (
func (c *Config) Watch(ctx context.Context) error { //nolint:cyclop,funlen,gocognit
c.nocopy.Check()

if hasWatcher := slices.ContainsFunc(c.providers, func(provider *provider) bool {
_, ok := provider.loader.(Watcher)

return ok
}); !hasWatcher {
return nil
}

if watched := c.watched.Swap(true); watched {
c.log(ctx, slog.LevelWarn, "Config has been watched, call Watch again has no effects.")
c.log(ctx, slog.LevelWarn, "Config has been watched, call Watch more than once has no effects.")

return nil
}

ctx, cancel := context.WithCancelCause(ctx)
defer cancel(nil)
var waitGroup sync.WaitGroup

// Start a goroutine to update the configuration while it has changes from watchers.
onChangesChannel := make(chan []func(*Config), 1)
defer close(onChangesChannel)

var waitGroup sync.WaitGroup
waitGroup.Add(1)
go func() {
defer waitGroup.Done()

for {
select {
case <-ctx.Done():
return

case onChanges := <-onChangesChannel:
values := make(map[string]any)
for _, w := range c.providers {
maps.Merge(values, *w.values.Load())
}
c.values.Store(&values)
c.providers.sync()
c.log(ctx, slog.LevelDebug, "Configuration has been updated with change.")

if len(onChanges) > 0 {
Expand All @@ -79,25 +68,19 @@ func (c *Config) Watch(ctx context.Context) error { //nolint:cyclop,funlen,gocog
if errors.Is(tctx.Err(), context.DeadlineExceeded) {
c.log(
ctx, slog.LevelWarn,
"Configuration has not been fully applied to onChanges due to timeout."+
"Configuration has not been fully applied to onChanges in one minute."+
" Please check if the onChanges is blocking or takes too long to complete.",
)
}
}
}()
}

case <-ctx.Done():
return
}
}
}()

// Start a watching goroutine for each watcher registered.
c.providersMutex.RLock()
for i := range c.providers {
provider := c.providers[i] // Use pointer for later modification.

c.providers.traverse(func(provider *provider) {
if watcher, ok := provider.loader.(Watcher); ok {
waitGroup.Add(1)
go func(ctx context.Context) {
Expand All @@ -106,24 +89,11 @@ func (c *Config) Watch(ctx context.Context) error { //nolint:cyclop,funlen,gocog
onChange := func(values map[string]any) {
c.transformKeys(values)
oldValues := *provider.values.Swap(&values)

// Find the onChanges should be triggered.
onChanges := func() []func(*Config) {
c.onChangesMutex.RLock()
defer c.onChangesMutex.RUnlock()

var callbacks []func(*Config)
for path, onChanges := range c.onChanges {
oldVal := c.sub(oldValues, path)
newVal := c.sub(values, path)
if !reflect.DeepEqual(oldVal, newVal) {
callbacks = append(callbacks, onChanges...)
}
}

return callbacks
}
onChangesChannel <- onChanges()
onChangesChannel <- c.onChanges.get(
func(path string) bool {
return !reflect.DeepEqual(c.sub(oldValues, path), c.sub(values, path))
},
)

c.log(ctx, slog.LevelInfo,
"Configuration has been changed.",
Expand All @@ -137,8 +107,7 @@ func (c *Config) Watch(ctx context.Context) error { //nolint:cyclop,funlen,gocog
}
}(ctx)
}
}
c.providersMutex.RUnlock()
})
waitGroup.Wait()

if err := context.Cause(ctx); err != nil && !errors.Is(err, ctx.Err()) {
Expand All @@ -156,27 +125,52 @@ func (c *Config) Watch(ctx context.Context) error { //nolint:cyclop,funlen,gocog
// The register function must be non-blocking and usually completes instantly.
// If it requires a long time to complete, it should be executed in a separate goroutine.
//
// This method is concurrency-safe.
// This method is concurrent-safe.
func (c *Config) OnChange(onChange func(*Config), paths ...string) {
if onChange == nil {
return // Do nothing is onchange is nil.
}
c.nocopy.Check()

if !c.caseSensitive {
for i := range paths {
paths[i] = defaultKeyMap(paths[i])
}
}
c.onChanges.register(onChange, paths)
}

type onChanges struct {
subscribers map[string][]func(*Config)
mutex sync.RWMutex
}

func (o *onChanges) register(onChange func(*Config), paths []string) {
o.mutex.Lock()
defer o.mutex.Unlock()

if len(paths) == 0 {
paths = []string{""}
}

c.onChangesMutex.Lock()
defer c.onChangesMutex.Unlock()

if c.onChanges == nil { // To support zero Config
c.onChanges = make(map[string][]func(*Config))
if o.subscribers == nil {
o.subscribers = make(map[string][]func(*Config))
}
for _, path := range paths {
if !c.caseSensitive {
path = defaultKeyMap(path)
o.subscribers[path] = append(o.subscribers[path], onChange)
}
}

func (o *onChanges) get(filter func(string) bool) []func(*Config) {
o.mutex.RLock()
defer o.mutex.RUnlock()

var callbacks []func(*Config)
for path, subscriber := range o.subscribers {
if filter(path) {
callbacks = append(callbacks, subscriber...)
}
c.onChanges[path] = append(c.onChanges[path], onChange)
}

return callbacks
}
Loading

0 comments on commit 3138d90

Please sign in to comment.