diff --git a/pkg/patch/cmd.go b/pkg/patch/cmd.go index c8a3f62b..388e3a44 100644 --- a/pkg/patch/cmd.go +++ b/pkg/patch/cmd.go @@ -30,7 +30,7 @@ type patchArgs struct { ignoreError bool format string output string - pushDest string + push bool } func NewPatchCmd() *cobra.Command { @@ -50,7 +50,7 @@ func NewPatchCmd() *cobra.Command { ua.format, ua.output, ua.ignoreError, - ua.pushDest) + ua.push) }, } flags := patchCmd.Flags() @@ -63,7 +63,7 @@ func NewPatchCmd() *cobra.Command { flags.BoolVar(&ua.ignoreError, "ignore-errors", false, "Ignore errors and continue patching") flags.StringVarP(&ua.format, "format", "f", "openvex", "Output format, defaults to 'openvex'") flags.StringVarP(&ua.output, "output", "o", "", "Output file path") - flags.StringVarP(&ua.pushDest, "push", "p", "", "Push patched image to destination registry. Format: /:. Note: this takes precedence over tag flag if set. ") + flags.BoolVarP(&ua.push, "push", "p", false, "Push patched image to destination registry") if err := patchCmd.MarkFlagRequired("image"); err != nil { panic(err) diff --git a/pkg/patch/patch.go b/pkg/patch/patch.go index 7741f779..b2701357 100644 --- a/pkg/patch/patch.go +++ b/pkg/patch/patch.go @@ -10,6 +10,7 @@ import ( "errors" "fmt" "os" + "strings" "time" log "github.com/sirupsen/logrus" @@ -29,13 +30,13 @@ const ( ) // Patch command applies package updates to an OCI image given a vulnerability report. -func Patch(ctx context.Context, timeout time.Duration, buildkitAddr, image, reportFile, patchedTag, workingFolder, format, output string, ignoreError bool, pushDest string) error { +func Patch(ctx context.Context, timeout time.Duration, buildkitAddr, image, reportFile, patchedTag, workingFolder, format, output string, ignoreError, push bool) error { timeoutCtx, cancel := context.WithTimeout(ctx, timeout) defer cancel() ch := make(chan error) go func() { - ch <- patchWithContext(timeoutCtx, buildkitAddr, image, reportFile, patchedTag, workingFolder, format, output, ignoreError, pushDest) + ch <- patchWithContext(timeoutCtx, buildkitAddr, image, reportFile, patchedTag, workingFolder, format, output, ignoreError, push) }() select { @@ -60,31 +61,11 @@ func removeIfNotDebug(workingFolder string) { } } -func patchWithContext(ctx context.Context, buildkitAddr, image, reportFile, patchedTag, workingFolder, format, output string, ignoreError bool, pushDest string) error { - imageName, err := ref.ParseNamed(image) +func patchWithContext(ctx context.Context, buildkitAddr, image, reportFile, patchedTag, workingFolder, format, output string, ignoreError, push bool) error { + patchedImageName, err := patchedImageTarget(image, patchedTag) if err != nil { return err } - if ref.IsNameOnly(imageName) { - log.Warnf("Image name has no tag or digest, using latest as tag") - imageName = ref.TagNameOnly(imageName) - } - taggedName, ok := imageName.(ref.Tagged) - if !ok { - err := errors.New("unexpected: TagNameOnly did create Tagged ref") - log.Error(err) - return err - } - tag := taggedName.Tag() - if patchedTag == "" { - if tag == "" { - log.Warnf("No output tag specified for digest-referenced image, defaulting to `%s`", defaultPatchedTagSuffix) - patchedTag = defaultPatchedTagSuffix - } else { - patchedTag = fmt.Sprintf("%s-%s", tag, defaultPatchedTagSuffix) - } - } - patchedImageName := fmt.Sprintf("%s:%s", imageName.Name(), patchedTag) // Ensure working folder exists for call to InstallUpdates if workingFolder == "" { @@ -138,12 +119,12 @@ func patchWithContext(ctx context.Context, buildkitAddr, image, reportFile, patc return err } - if pushDest != "" { - if err = buildkit.SolveToRegistry(ctx, config.Client, patchedImageState, config.ConfigData, pushDest); err != nil { + if push { + if err = buildkit.SolveToRegistry(ctx, config.Client, patchedImageState, config.ConfigData, *patchedImageName); err != nil { return err } } else { - if err = buildkit.SolveToDocker(ctx, config.Client, patchedImageState, config.ConfigData, patchedImageName); err != nil { + if err = buildkit.SolveToDocker(ctx, config.Client, patchedImageState, config.ConfigData, *patchedImageName); err != nil { return err } } @@ -166,3 +147,38 @@ func patchWithContext(ctx context.Context, buildkitAddr, image, reportFile, patc } return nil } + +func patchedImageTarget(image, patchedTag string) (*string, error) { + imageName, err := ref.ParseNamed(image) + if err != nil { + return nil, err + } + if ref.IsNameOnly(imageName) { + log.Warnf("Image name has no tag or digest, using latest as tag") + imageName = ref.TagNameOnly(imageName) + } + taggedName, ok := imageName.(ref.Tagged) + if !ok { + err := errors.New("unexpected: TagNameOnly did create Tagged ref") + log.Error(err) + return nil, err + } + tag := taggedName.Tag() + var patchedImageName string + if patchedTag == "" { + if tag == "" { + log.Warnf("No output tag specified for digest-referenced image, defaulting to `%s`", defaultPatchedTagSuffix) + patchedTag = defaultPatchedTagSuffix + } else { + patchedTag = fmt.Sprintf("%s-%s", tag, defaultPatchedTagSuffix) + } + } + // this implies user has passed a destination image name, not just a tag + if strings.Contains(patchedTag, "/") { + patchedImageName = patchedTag + } else { + patchedImageName = fmt.Sprintf("%s:%s", imageName.Name(), patchedTag) + } + + return &patchedImageName, nil +} diff --git a/pkg/patch/patch_test.go b/pkg/patch/patch_test.go index abc0acda..7e745bf1 100644 --- a/pkg/patch/patch_test.go +++ b/pkg/patch/patch_test.go @@ -49,3 +49,49 @@ func TestRemoveIfNotDebug(t *testing.T) { os.RemoveAll(workingFolder) }) } + +func TestPatchedImageTarget(t *testing.T) { + tests := []struct { + name string + image string + patchedTag string + want string + wantErr bool + }{ + { + name: "tag passed is empty", + image: "docker.io/library/nginx:1.21.3", + patchedTag: "", + want: "docker.io/library/nginx:1.21.3-patched", + wantErr: false, + }, + + { + name: "tag passed with value", + image: "docker.io/library/nginx:1.21.3", + patchedTag: "custom", + want: "docker.io/library/nginx:custom", + wantErr: false, + }, + { + name: "tag passed and contains a slash(indicating registry)", + image: "docker.io/library/nginx:1.21.3", + patchedTag: "my.registry/nginx:1.21.3-patched", + want: "my.registry/nginx:1.21.3-patched", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := patchedImageTarget(tt.image, tt.patchedTag) + if (err != nil) != tt.wantErr { + t.Errorf("patchedImageTarget() error = %v, wantErr %v", err, tt.wantErr) + return + } + if *got != tt.want { + t.Errorf("patchedImageTarget() = %v, want %v", *got, tt.want) + } + }) + } +}