Skip to content

Commit

Permalink
chore: fix how OCI images are pulled
Browse files Browse the repository at this point in the history
  • Loading branch information
banjoh committed Sep 21, 2023
1 parent e34742f commit 7c9f0c7
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 19 deletions.
78 changes: 59 additions & 19 deletions pkg/oci/pull.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@ import (
"context"
"fmt"
"net/http"
"net/url"
"path/filepath"
"strings"

ocispec "github.com/opencontainers/image-spec/specs-go/v1"
"github.com/pkg/errors"
"github.com/replicatedhq/troubleshoot/internal/util"
"github.com/replicatedhq/troubleshoot/pkg/version"
"k8s.io/klog/v2"
"oras.land/oras-go/pkg/auth"
dockerauth "oras.land/oras-go/pkg/auth/docker"
"oras.land/oras-go/pkg/content"
Expand All @@ -27,14 +29,39 @@ var (
)

func PullPreflightFromOCI(uri string) ([]byte, error) {
return pullFromOCI(uri, "replicated.preflight.spec", "replicated-preflight")
return pullFromOCI(context.Background(), uri, "replicated.preflight.spec", "replicated-preflight")
}

func PullSupportBundleFromOCI(uri string) ([]byte, error) {
return pullFromOCI(uri, "replicated.supportbundle.spec", "replicated-supportbundle")
return pullFromOCI(context.Background(), uri, "replicated.supportbundle.spec", "replicated-supportbundle")
}

func pullFromOCI(uri string, mediaType string, imageName string) ([]byte, error) {
func PullSpecsFromOCI(ctx context.Context, uri string) ([]string, error) {
rawSpecs := []string{}

// First try to pull the preflight spec
rawPreflight, err := pullFromOCI(ctx, uri, "replicated.preflight.spec", "replicated-preflight")
if err != nil {
// Ignore "not found" error and continue fetching the support bundle spec
if !errors.Is(err, ErrNoRelease) {
return nil, err
}
} else {
rawSpecs = append(rawSpecs, string(rawPreflight))
}

// Then try to pull the support bundle spec
rawSupportBundle, err := pullFromOCI(ctx, uri, "replicated.supportbundle.spec", "replicated-supportbundle")
// If we had found a preflight spec, do not return an error
if err != nil && len(rawSpecs) == 0 {
return nil, err
}
rawSpecs = append(rawSpecs, string(rawSupportBundle))

return rawSpecs, nil
}

func pullFromOCI(ctx context.Context, uri string, mediaType string, imageName string) ([]byte, error) {
// helm credentials
helmCredentialsFile := filepath.Join(util.HomeDir(), HelmCredentialsFileBasename)
dockerauthClient, err := dockerauth.NewClientWithDockerFallback(helmCredentialsFile)
Expand All @@ -52,6 +79,7 @@ func pullFromOCI(uri string, mediaType string, imageName string) ([]byte, error)
return nil, errors.Wrap(err, "failed to create resolver")
}

// TODO: How do we handle "not found" cases?
memoryStore := content.NewMemory()
allowedMediaTypes := []string{
mediaType,
Expand All @@ -60,24 +88,13 @@ func pullFromOCI(uri string, mediaType string, imageName string) ([]byte, error)
var descriptors, layers []ocispec.Descriptor
registryStore := content.Registry{Resolver: resolver}

// remove the oci://
uri = strings.TrimPrefix(uri, "oci://")

uriParts := strings.Split(uri, ":")
uri = fmt.Sprintf("%s/%s", uriParts[0], imageName)

if len(uriParts) > 1 {
uri = fmt.Sprintf("%s:%s", uri, uriParts[1])
} else {
uri = fmt.Sprintf("%s:latest", uri)
}

parsedRef, err := registry.ParseReference(uri)
parsedRef, err := toRegistryRef(uri)
if err != nil {
return nil, errors.Wrap(err, "failed to parse reference")
return nil, err
}
klog.V(1).Infof("Pulling OCI image from %q", parsedRef.String())

manifest, err := oras.Copy(context.TODO(), registryStore, parsedRef.String(), memoryStore, "",
manifest, err := oras.Copy(ctx, registryStore, parsedRef.String(), memoryStore, "",
oras.WithPullEmptyNameAllowed(),
oras.WithAllowedMediaTypes(allowedMediaTypes),
oras.WithLayerDescriptors(func(l []ocispec.Descriptor) {
Expand All @@ -94,7 +111,7 @@ func pullFromOCI(uri string, mediaType string, imageName string) ([]byte, error)
descriptors = append(descriptors, manifest)
descriptors = append(descriptors, layers...)

// expect 1 descriptor
// expect 2 descriptors
if len(descriptors) != 2 {
return nil, fmt.Errorf("expected 2 descriptor, got %d", len(descriptors))
}
Expand All @@ -120,3 +137,26 @@ func pullFromOCI(uri string, mediaType string, imageName string) ([]byte, error)

return matchingSpec, nil
}

func toRegistryRef(raw string) (registry.Reference, error) {
u, err := url.Parse(raw)
if err != nil {
return registry.Reference{}, err
}

// Always check the scheme. If more schemes need to be supported
// we need to compare u.Scheme against a list of supported schemes.
// url.Parse(raw) will not return an error is a scheme is not present.
if u.Scheme != "oci" {
return registry.Reference{}, fmt.Errorf("%q is an invalid OCI registry scheme", u.Scheme)
}

parts := strings.Split(u.EscapedPath(), ":")
tag := "latest"
if len(parts) > 1 {
tag = parts[1]
}
// remove the oci://
uri := fmt.Sprintf("%s%s:%s", u.Host, parts[0], tag)
return registry.ParseReference(uri)
}
56 changes: 56 additions & 0 deletions pkg/oci/pull_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package oci

import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func Test_toRegistryRef(t *testing.T) {
tests := []struct {
name string
uri string
want string
wantErr bool
}{
{
name: "valid uri",
uri: "oci://localhost/replicated-preflight",
want: "localhost/replicated-preflight:latest",
},
{
name: "valid uri with port",
uri: "oci://localhost:5000/replicated-preflight",
want: "localhost:5000/replicated-preflight:latest",
},
{
name: "valid uri with tag",
uri: "oci://localhost:5000/replicated-preflight:v4",
want: "localhost:5000/replicated-preflight:v4",
},
{
name: "invalid uri - missing scheme",
uri: "localhost:5000/replicated-preflight:v4",
wantErr: true,
},
{
name: "invalid uri - wrong scheme",
uri: "https://localhost:5000/replicated-preflight:v4",
wantErr: true,
},
{
name: "empty uri",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := toRegistryRef(tt.uri)
require.Equalf(t, (err != nil), tt.wantErr, "toRegistryRef() error = %v, wantErr %v", err, tt.wantErr)

gotStr := got.String()
assert.Equalf(t, tt.want, gotStr, "toRegistryRef() = %v, want %v", gotStr, tt.want)
})
}
}

0 comments on commit 7c9f0c7

Please sign in to comment.