Skip to content

Commit

Permalink
fix: cancel goroutine before the next one is created
Browse files Browse the repository at this point in the history
Signed-off-by: Xiaoxuan Wang <[email protected]>
  • Loading branch information
wangxiaoxuan273 committed Apr 7, 2024
1 parent 9b6f321 commit 3b52a20
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 13 deletions.
21 changes: 13 additions & 8 deletions copy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1779,7 +1779,9 @@ func TestCopyGraph_WithOptions(t *testing.T) {
t.Run("MountFrom error", func(t *testing.T) {
root = descs[6]
dst := &countingStorage{storage: cas.NewMemory()}
opts = oras.CopyGraphOptions{}
opts = oras.CopyGraphOptions{
Concurrency: 1,
}
var numMountFrom atomic.Int64
e := errors.New("mountFrom error")
opts.MountFrom = func(ctx context.Context, desc ocispec.Descriptor) ([]string, error) {
Expand All @@ -1790,7 +1792,8 @@ func TestCopyGraph_WithOptions(t *testing.T) {
t.Fatalf("CopyGraph() error = %v, wantErr %v", err, e)
}

if got, expected := dst.numExists.Load(), int64(7); got != expected {
// with a very low probability, dst.numExists may be 3
if got, expected := dst.numExists.Load(), int64(4); got != expected {
t.Errorf("count(Exists()) = %d, want %d", got, expected)
}
if got, expected := dst.numFetch.Load(), int64(0); got != expected {
Expand All @@ -1799,7 +1802,7 @@ func TestCopyGraph_WithOptions(t *testing.T) {
if got, expected := dst.numPush.Load(), int64(0); got != expected {
t.Errorf("count(Push()) = %d, want %d", got, expected)
}
if got, expected := numMountFrom.Load(), int64(4); got != expected {
if got, expected := numMountFrom.Load(), int64(1); got != expected {
t.Errorf("count(MountFrom()) = %d, want %d", got, expected)
}
})
Expand Down Expand Up @@ -1828,7 +1831,9 @@ func TestCopyGraph_WithOptions(t *testing.T) {
}
return nil
}
opts = oras.CopyGraphOptions{}
opts = oras.CopyGraphOptions{
Concurrency: 1,
}
var numPreCopy, numPostCopy, numOnMounted, numMountFrom atomic.Int64
opts.PreCopy = func(ctx context.Context, desc ocispec.Descriptor) error {
numPreCopy.Add(1)
Expand All @@ -1851,7 +1856,7 @@ func TestCopyGraph_WithOptions(t *testing.T) {
t.Fatalf("CopyGraph() error = %v, wantErr %v", err, e)
}

if got, expected := dst.numExists.Load(), int64(7); got != expected {
if got, expected := dst.numExists.Load(), int64(4); got != expected {
t.Errorf("count(Exists()) = %d, want %d", got, expected)
}
if got, expected := dst.numFetch.Load(), int64(0); got != expected {
Expand All @@ -1860,13 +1865,13 @@ func TestCopyGraph_WithOptions(t *testing.T) {
if got, expected := dst.numPush.Load(), int64(0); got != expected {
t.Errorf("count(Push()) = %d, want %d", got, expected)
}
if got, expected := numMount.Load(), int64(4); got != expected {
if got, expected := numMount.Load(), int64(1); got != expected {
t.Errorf("count(Mount()) = %d, want %d", got, expected)
}
if got, expected := numOnMounted.Load(), int64(4); got != expected {
if got, expected := numOnMounted.Load(), int64(1); got != expected {
t.Errorf("count(OnMounted()) = %d, want %d", got, expected)
}
if got, expected := numMountFrom.Load(), int64(4); got != expected {
if got, expected := numMountFrom.Load(), int64(1); got != expected {
t.Errorf("count(MountFrom()) = %d, want %d", got, expected)
}
if got, expected := numPreCopy.Load(), int64(0); got != expected {
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ go 1.21
require (
github.com/opencontainers/go-digest v1.0.0
github.com/opencontainers/image-spec v1.1.0
golang.org/x/sync v0.6.0
golang.org/x/sync v0.7.0
)
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@ github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQb2IpWsCzug=
github.com/opencontainers/image-spec v1.1.0/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2szdn85PY75RI83NrTrM=
golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ=
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
14 changes: 12 additions & 2 deletions internal/syncutil/limit.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package syncutil

import (
"context"
"sync/atomic"

"golang.org/x/sync/errgroup"
"golang.org/x/sync/semaphore"
Expand Down Expand Up @@ -68,15 +69,24 @@ type GoFunc[T any] func(ctx context.Context, region *LimitedRegion, t T) error
// Go concurrently invokes fn on items.
func Go[T any](ctx context.Context, limiter *semaphore.Weighted, fn GoFunc[T], items ...T) error {
eg, egCtx := errgroup.WithContext(ctx)
var egErr atomic.Value
for _, item := range items {
region := LimitRegion(ctx, limiter)
region := LimitRegion(egCtx, limiter)
if err := region.Start(); err != nil {
if egErr, ok := egErr.Load().(error); ok && egErr != nil {
return egErr
}
return err
}
eg.Go(func(t T) func() error {
return func() error {
defer region.End()
return fn(egCtx, region, t)
// cancel the gorountine before the next goroutine is created
err := fn(egCtx, region, t)
if err != nil {
egErr.CompareAndSwap(nil, err)
}
return err
}
}(item))
}
Expand Down

0 comments on commit 3b52a20

Please sign in to comment.