diff --git a/pkg/patch/cmd.go b/pkg/patch/cmd.go index f980f978b..442acc56f 100644 --- a/pkg/patch/cmd.go +++ b/pkg/patch/cmd.go @@ -25,6 +25,7 @@ type patchArgs struct { ignoreError bool format string output string + silent bool bkOpts buildkit.Opts } @@ -50,6 +51,7 @@ func NewPatchCmd() *cobra.Command { ua.scanner, ua.format, ua.output, + ua.silent, ua.ignoreError, bkopts) }, @@ -66,6 +68,7 @@ func NewPatchCmd() *cobra.Command { flags.DurationVar(&ua.timeout, "timeout", 5*time.Minute, "Timeout for the operation, defaults to '5m'") flags.StringVarP(&ua.scanner, "scanner", "s", "trivy", "Scanner used to generate the report, defaults to 'trivy'") flags.BoolVar(&ua.ignoreError, "ignore-errors", false, "Ignore errors and continue patching") + flags.BoolVar(&ua.silent, "silent", false, "silences the buildkit output while processing") flags.StringVarP(&ua.format, "format", "f", "openvex", "Output format, defaults to 'openvex'") flags.StringVarP(&ua.output, "output", "o", "", "Output file path") diff --git a/pkg/patch/cmd_test.go b/pkg/patch/cmd_test.go index eaa21c95a..b587861ef 100644 --- a/pkg/patch/cmd_test.go +++ b/pkg/patch/cmd_test.go @@ -4,14 +4,22 @@ import "testing" func TestNewPatchCmd(t *testing.T) { tests := []struct { - name string - args []string - expected string + name string + args []string + expected bool + errString string }{ { - name: "Missing image flag", - args: []string{"-r", "trivy.json", "-t", "3.7-alpine-patched"}, - expected: "required flag(s) \"image\" not set", + name: "Missing image flag", + args: []string{"-r", "trivy.json", "-t", "3.7-alpine-patched"}, + expected: true, + errString: "required flag(s) \"image\" not set", + }, + { + name: "Silent flag used", + args: []string{"-i", "alpine:3.14", "--debug"}, + expected: false, + errString: "", }, } @@ -24,8 +32,16 @@ func TestNewPatchCmd(t *testing.T) { // Run the command and capture the output err := cmd.Execute() - if err == nil || err.Error() != tt.expected { - t.Errorf("Unexpected error: %v, expected: %v", err, tt.expected) + if !tt.expected { + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + } else { + if err == nil { + t.Errorf("Expected error: %v, got %v", tt.expected, err) + } else if err != nil && err.Error() != tt.errString { + t.Errorf("Unexpected error: %v, expected: %v", err, tt.expected) + } } }) } diff --git a/pkg/patch/patch.go b/pkg/patch/patch.go index 38d839de4..9b25a4730 100644 --- a/pkg/patch/patch.go +++ b/pkg/patch/patch.go @@ -43,13 +43,13 @@ const ( ) // Patch command applies package updates to an OCI image given a vulnerability report. -func Patch(ctx context.Context, timeout time.Duration, image, reportFile, patchedTag, workingFolder, scanner, format, output string, ignoreError bool, bkOpts buildkit.Opts) error { +func Patch(ctx context.Context, timeout time.Duration, image, reportFile, patchedTag, workingFolder, scanner, format, output string, silent, ignoreError bool, bkOpts buildkit.Opts) error { timeoutCtx, cancel := context.WithTimeout(ctx, timeout) defer cancel() ch := make(chan error) go func() { - ch <- patchWithContext(timeoutCtx, ch, image, reportFile, patchedTag, workingFolder, scanner, format, output, ignoreError, bkOpts) + ch <- patchWithContext(timeoutCtx, ch, image, reportFile, patchedTag, workingFolder, scanner, format, output, silent, ignoreError, bkOpts) }() select { @@ -74,7 +74,7 @@ func removeIfNotDebug(workingFolder string) { } } -func patchWithContext(ctx context.Context, ch chan error, image, reportFile, patchedTag, workingFolder, scanner, format, output string, ignoreError bool, bkOpts buildkit.Opts) error { +func patchWithContext(ctx context.Context, ch chan error, image, reportFile, patchedTag, workingFolder, scanner, format, output string, silent, ignoreError bool, bkOpts buildkit.Opts) error { imageName, err := reference.ParseNormalizedNamed(image) if err != nil { return err @@ -275,21 +275,35 @@ func patchWithContext(ctx context.Context, ch chan error, image, reportFile, pat return err }) + if silent { + eg.Go(func() error { + for { + select { + case <-ctx.Done(): + return context.Cause(ctx) + case _, ok := <-buildChannel: + if !ok { + return nil + } + } + } + }) + } else { + eg.Go(func() error { + // not using shared context to not disrupt display but let us finish reporting errors + mode := progressui.AutoMode + if log.GetLevel() >= log.DebugLevel { + mode = progressui.PlainMode + } + display, err := progressui.NewDisplay(os.Stderr, mode) + if err != nil { + return err + } - eg.Go(func() error { - // not using shared context to not disrupt display but let us finish reporting errors - mode := progressui.AutoMode - if log.GetLevel() >= log.DebugLevel { - mode = progressui.PlainMode - } - display, err := progressui.NewDisplay(os.Stderr, mode) - if err != nil { + _, err = display.UpdateFrom(ctx, buildChannel) return err - } - - _, err = display.UpdateFrom(ctx, buildChannel) - return err - }) + }) + } eg.Go(func() error { if err := dockerLoad(ctx, pipeR); err != nil {