diff --git a/copy_test.go b/copy_test.go index 02421524..4fd04772 100644 --- a/copy_test.go +++ b/copy_test.go @@ -1779,7 +1779,9 @@ func TestCopyGraph_WithOptions(t *testing.T) { t.Run("MountFrom error", func(t *testing.T) { root = descs[6] dst := &countingStorage{storage: cas.NewMemory()} - opts = oras.CopyGraphOptions{} + opts = oras.CopyGraphOptions{ + Concurrency: 1, + } var numMountFrom atomic.Int64 e := errors.New("mountFrom error") opts.MountFrom = func(ctx context.Context, desc ocispec.Descriptor) ([]string, error) { @@ -1790,7 +1792,8 @@ func TestCopyGraph_WithOptions(t *testing.T) { t.Fatalf("CopyGraph() error = %v, wantErr %v", err, e) } - if got, expected := dst.numExists.Load(), int64(7); got != expected { + // with a very low probability, dst.numExists may be 3 + if got, expected := dst.numExists.Load(), int64(4); got != expected { t.Errorf("count(Exists()) = %d, want %d", got, expected) } if got, expected := dst.numFetch.Load(), int64(0); got != expected { @@ -1799,7 +1802,7 @@ func TestCopyGraph_WithOptions(t *testing.T) { if got, expected := dst.numPush.Load(), int64(0); got != expected { t.Errorf("count(Push()) = %d, want %d", got, expected) } - if got, expected := numMountFrom.Load(), int64(4); got != expected { + if got, expected := numMountFrom.Load(), int64(1); got != expected { t.Errorf("count(MountFrom()) = %d, want %d", got, expected) } }) @@ -1828,7 +1831,9 @@ func TestCopyGraph_WithOptions(t *testing.T) { } return nil } - opts = oras.CopyGraphOptions{} + opts = oras.CopyGraphOptions{ + Concurrency: 1, + } var numPreCopy, numPostCopy, numOnMounted, numMountFrom atomic.Int64 opts.PreCopy = func(ctx context.Context, desc ocispec.Descriptor) error { numPreCopy.Add(1) @@ -1851,7 +1856,7 @@ func TestCopyGraph_WithOptions(t *testing.T) { t.Fatalf("CopyGraph() error = %v, wantErr %v", err, e) } - if got, expected := dst.numExists.Load(), int64(7); got != expected { + if got, expected := dst.numExists.Load(), int64(4); got != expected { t.Errorf("count(Exists()) = %d, want %d", got, expected) } if got, expected := dst.numFetch.Load(), int64(0); got != expected { @@ -1860,13 +1865,13 @@ func TestCopyGraph_WithOptions(t *testing.T) { if got, expected := dst.numPush.Load(), int64(0); got != expected { t.Errorf("count(Push()) = %d, want %d", got, expected) } - if got, expected := numMount.Load(), int64(4); got != expected { + if got, expected := numMount.Load(), int64(1); got != expected { t.Errorf("count(Mount()) = %d, want %d", got, expected) } - if got, expected := numOnMounted.Load(), int64(4); got != expected { + if got, expected := numOnMounted.Load(), int64(1); got != expected { t.Errorf("count(OnMounted()) = %d, want %d", got, expected) } - if got, expected := numMountFrom.Load(), int64(4); got != expected { + if got, expected := numMountFrom.Load(), int64(1); got != expected { t.Errorf("count(MountFrom()) = %d, want %d", got, expected) } if got, expected := numPreCopy.Load(), int64(0); got != expected { diff --git a/go.mod b/go.mod index 85b83d90..bd267939 100644 --- a/go.mod +++ b/go.mod @@ -5,5 +5,5 @@ go 1.21 require ( github.com/opencontainers/go-digest v1.0.0 github.com/opencontainers/image-spec v1.1.0 - golang.org/x/sync v0.6.0 + golang.org/x/sync v0.7.0 ) diff --git a/go.sum b/go.sum index 9b89e8ae..eec227b2 100644 --- a/go.sum +++ b/go.sum @@ -2,5 +2,5 @@ github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8 github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQb2IpWsCzug= github.com/opencontainers/image-spec v1.1.0/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2szdn85PY75RI83NrTrM= -golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= -golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= diff --git a/internal/syncutil/limit.go b/internal/syncutil/limit.go index 2a05d4ea..3b28b8ed 100644 --- a/internal/syncutil/limit.go +++ b/internal/syncutil/limit.go @@ -17,6 +17,7 @@ package syncutil import ( "context" + "sync/atomic" "golang.org/x/sync/errgroup" "golang.org/x/sync/semaphore" @@ -68,15 +69,24 @@ type GoFunc[T any] func(ctx context.Context, region *LimitedRegion, t T) error // Go concurrently invokes fn on items. func Go[T any](ctx context.Context, limiter *semaphore.Weighted, fn GoFunc[T], items ...T) error { eg, egCtx := errgroup.WithContext(ctx) + var egErr atomic.Value for _, item := range items { - region := LimitRegion(ctx, limiter) + region := LimitRegion(egCtx, limiter) if err := region.Start(); err != nil { + if egErr, ok := egErr.Load().(error); ok && egErr != nil { + return egErr + } return err } eg.Go(func(t T) func() error { return func() error { defer region.End() - return fn(egCtx, region, t) + // cancel the gorountine before the next goroutine is created + err := fn(egCtx, region, t) + if err != nil { + egErr.CompareAndSwap(nil, err) + } + return err } }(item)) }