diff --git a/pkg/mdformatter/linktransformer/link.go b/pkg/mdformatter/linktransformer/link.go index 2f6044f..b59cafb 100644 --- a/pkg/mdformatter/linktransformer/link.go +++ b/pkg/mdformatter/linktransformer/link.go @@ -125,7 +125,7 @@ type validator struct { c *colly.Collector futureMu sync.Mutex - destFutures map[futureKey]*futureResult + destFutures sync.Map } type futureKey struct { @@ -156,14 +156,13 @@ func NewValidator(ctx context.Context, logger log.Logger, linksValidateConfig [] localLinks: map[string]*[]string{}, remoteLinks: map[string]error{}, c: colly.NewCollector(colly.Async(), colly.StdlibContext(ctx)), - destFutures: map[futureKey]*futureResult{}, } // Set very soft limits. // E.g github has 50-5000 https://docs.github.com/en/free-pro-team@latest/rest/reference/rate-limit limit depending // on api (only search is below 100). if err := v.c.Limit(&colly.LimitRule{ DomainGlob: "*", - Parallelism: 100, + Parallelism: 5, }); err != nil { return nil, err } @@ -234,6 +233,11 @@ func MustNewValidator(logger log.Logger, linksValidateConfig []byte, anchorDir s } func (v *validator) TransformDestination(ctx mdformatter.SourceContext, destination []byte) (_ []byte, err error) { + select { + case <-ctx.Context.Done(): + return nil, ctx.Err() + default: + } v.visit(ctx.Filepath, string(destination), ctx.LineNumbers) return destination, nil } @@ -242,7 +246,14 @@ func (v *validator) Close(ctx mdformatter.SourceContext) error { v.c.Wait() var keys []futureKey - for k := range v.destFutures { + // Read map from sync.Map. + destFuturesMap := map[futureKey]*futureResult{} + v.destFutures.Range(func(key, value interface{}) bool { + destFuturesMap[key.(futureKey)] = value.(*futureResult) + return true + }) + + for k := range destFuturesMap { if k.filepath != ctx.Filepath { continue } @@ -263,7 +274,7 @@ func (v *validator) Close(ctx mdformatter.SourceContext) error { } for _, k := range keys { - f := v.destFutures[k] + f := destFuturesMap[k] if err := f.resultFn(); err != nil { if f.cases == 1 { merr.Add(errors.Wrapf(err, "%v:%v", path, k.lineNumbers)) @@ -279,11 +290,16 @@ func (v *validator) visit(filepath string, dest string, lineNumbers string) { v.futureMu.Lock() defer v.futureMu.Unlock() k := futureKey{filepath: filepath, dest: dest, lineNumbers: lineNumbers} - if _, ok := v.destFutures[k]; ok { - v.destFutures[k].cases++ + // If key present, delete and increment cases. + if prevResult, loaded := v.destFutures.LoadAndDelete(k); loaded { + newResult := prevResult.(*futureResult) + newResult.cases++ + v.destFutures.Store(k, newResult) return } - v.destFutures[k] = &futureResult{cases: 1, resultFn: func() error { return nil }} + + // Key not present, no store. + v.destFutures.Store(k, &futureResult{cases: 1, resultFn: func() error { return nil }}) matches := remoteLinkPrefixRe.FindAllStringIndex(dest, 1) if matches == nil { // Relative or absolute path. Check if exists. @@ -291,7 +307,10 @@ func (v *validator) visit(filepath string, dest string, lineNumbers string) { // Local link. Check if exists. if err := v.localLinks.Lookup(newDest); err != nil { - v.destFutures[k].resultFn = func() error { return errors.Wrapf(err, "link %v, normalized to", dest) } + prevResult, _ := v.destFutures.LoadAndDelete(k) + newResult := prevResult.(*futureResult) + newResult.resultFn = func() error { return errors.Wrapf(err, "link %v, normalized to", dest) } + v.destFutures.Store(k, newResult) } return } diff --git a/pkg/mdformatter/linktransformer/validator.go b/pkg/mdformatter/linktransformer/validator.go index 1000813..bc75120 100644 --- a/pkg/mdformatter/linktransformer/validator.go +++ b/pkg/mdformatter/linktransformer/validator.go @@ -34,7 +34,11 @@ func (v GitHubValidator) IsValid(k futureKey, r *validator) (bool, error) { // RoundTripValidator.IsValid returns true if url is checked by colly. func (v RoundTripValidator) IsValid(k futureKey, r *validator) (bool, error) { // Result will be in future. - r.destFutures[k].resultFn = func() error { return r.remoteLinks[k.dest] } + prevResult, _ := r.destFutures.LoadAndDelete(k) + newResult := prevResult.(*futureResult) + newResult.resultFn = func() error { return r.remoteLinks[k.dest] } + r.destFutures.Store(k, newResult) + r.rMu.RLock() if _, ok := r.remoteLinks[k.dest]; ok { r.rMu.RUnlock() diff --git a/pkg/mdformatter/mdformatter.go b/pkg/mdformatter/mdformatter.go index 71793f3..50dafd8 100644 --- a/pkg/mdformatter/mdformatter.go +++ b/pkg/mdformatter/mdformatter.go @@ -10,6 +10,7 @@ import ( "io/ioutil" "os" "sort" + "sync" "time" "github.com/Kunde21/markdownfmt/v2/markdown" @@ -217,7 +218,7 @@ func newSpinner(suffix string) (*yacspin.Spinner, error) { // Format formats given markdown files in-place. IsFormatted `With...` function to see what modifiers you can add. func Format(ctx context.Context, logger log.Logger, files []string, opts ...Option) error { - spin, err := newSpinner(" Formatting: ") + spin, err := newSpinner(" Formatting... ") if err != nil { return err } @@ -228,7 +229,7 @@ func Format(ctx context.Context, logger log.Logger, files []string, opts ...Opti // If diff is empty it means all files are formatted. func IsFormatted(ctx context.Context, logger log.Logger, files []string, opts ...Option) (diffs Diffs, err error) { d := Diffs{} - spin, err := newSpinner(" Checking: ") + spin, err := newSpinner(" Checking... ") if err != nil { return nil, err } @@ -240,57 +241,73 @@ func IsFormatted(ctx context.Context, logger log.Logger, files []string, opts .. func format(ctx context.Context, logger log.Logger, files []string, diffs *Diffs, spin *yacspin.Spinner, opts ...Option) error { f := New(ctx, opts...) - b := bytes.Buffer{} - // TODO(bwplotka): Add concurrency (collector will need to redone). + errorChannel := make(chan error) + var wg sync.WaitGroup errs := merrors.New() if spin != nil { errs.Add(spin.Start()) } + + wg.Add(len(files)) + + go func() { + wg.Wait() + close(errorChannel) + }() + for _, fn := range files { - select { - case <-ctx.Done(): - return ctx.Err() - default: - } - if spin != nil { - spin.Message(fn + "...") - } - errs.Add(func() error { + go func(fn string) { + defer wg.Done() + b := bytes.Buffer{} + file, err := os.OpenFile(fn, os.O_RDWR, 0) if err != nil { - return errors.Wrapf(err, "open %v", fn) + errorChannel <- errors.Wrapf(err, "open %v", fn) + return } defer logerrcapture.ExhaustClose(logger, file, "close file %v", fn) b.Reset() if err := f.Format(file, &b); err != nil { - return err + errorChannel <- err + return } if diffs != nil { if _, err := file.Seek(0, 0); err != nil { - return err + errorChannel <- err + return } in, err := ioutil.ReadAll(file) if err != nil { - return errors.Wrapf(err, "read all %v", fn) + errorChannel <- errors.Wrapf(err, "read all %v", fn) + return } if !bytes.Equal(in, b.Bytes()) { *diffs = append(*diffs, gitdiff.CompareBytes(in, fn, b.Bytes(), fn+" (formatted)")) } - return nil + return } n, err := file.WriteAt(b.Bytes(), 0) if err != nil { - return errors.Wrapf(err, "write %v", fn) + errorChannel <- errors.Wrapf(err, "write %v", fn) + return + } + if err := file.Truncate(int64(n)); err != nil { + errorChannel <- err + return } - return file.Truncate(int64(n)) - }()) + }(fn) } + + for err := range errorChannel { + errs.Add(err) + } + if spin != nil { errs.Add(spin.Stop()) }