diff --git a/cmd/fetch/fetch_policy.go b/cmd/fetch/fetch_policy.go index 9b8bd6a2c..999545b94 100644 --- a/cmd/fetch/fetch_policy.go +++ b/cmd/fetch/fetch_policy.go @@ -115,8 +115,10 @@ func fetchPolicyCmd() *cobra.Command { sources = append(sources, &source.PolicyUrl{Url: url, Kind: source.DataKind}) } + ctx := cmd.Context() for _, s := range sources { - _, err := s.GetPolicy(cmd.Context(), destDir, true) + c, _, err := s.GetPolicy(ctx, destDir, true) + ctx = c if err != nil { return err } diff --git a/cmd/inspect/inspect_policy.go b/cmd/inspect/inspect_policy.go index 70b53bb2c..443eeb380 100644 --- a/cmd/inspect/inspect_policy.go +++ b/cmd/inspect/inspect_policy.go @@ -121,7 +121,8 @@ func inspectPolicyCmd() *cobra.Command { s := &source.PolicyUrl{Url: url, Kind: source.PolicyKind} // Download - policyDir, err := s.GetPolicy(ctx, destDir, false) + c, policyDir, err := s.GetPolicy(ctx, destDir, false) + ctx = c if err != nil { return err } diff --git a/cmd/inspect/inspect_policy_data.go b/cmd/inspect/inspect_policy_data.go index f51df593d..d8a7c4cca 100644 --- a/cmd/inspect/inspect_policy_data.go +++ b/cmd/inspect/inspect_policy_data.go @@ -91,7 +91,8 @@ func inspectPolicyDataCmd() *cobra.Command { s := &source.PolicyUrl{Url: url, Kind: source.PolicyKind} // Download - policyDir, err := s.GetPolicy(ctx, destDir, false) + c, policyDir, err := s.GetPolicy(ctx, destDir, false) + ctx = c if err != nil { return err } diff --git a/cmd/validate/common_test.go b/cmd/validate/common_test.go index 340904fb8..f8681b3e6 100644 --- a/cmd/validate/common_test.go +++ b/cmd/validate/common_test.go @@ -41,10 +41,10 @@ type mockEvaluator struct { mock.Mock } -func (e *mockEvaluator) Evaluate(ctx context.Context, target evaluator.EvaluationTarget) ([]evaluator.Outcome, evaluator.Data, error) { +func (e *mockEvaluator) Evaluate(ctx context.Context, target evaluator.EvaluationTarget) (context.Context, []evaluator.Outcome, evaluator.Data, error) { args := e.Called(ctx, target.Inputs) - return args.Get(0).([]evaluator.Outcome), args.Get(1).(evaluator.Data), args.Error(2) + return ctx, args.Get(0).([]evaluator.Outcome), args.Get(1).(evaluator.Data), args.Error(2) } func (e *mockEvaluator) Destroy() { diff --git a/cmd/validate/image.go b/cmd/validate/image.go index 68cc2b117..87172bc39 100644 --- a/cmd/validate/image.go +++ b/cmd/validate/image.go @@ -209,7 +209,8 @@ func validateImageCmd(validate imageValidationFunc) *cobra.Command { data.spec = s } - policyConfiguration, err := validate_utils.GetPolicyConfig(ctx, data.policyConfiguration) + c, policyConfiguration, err := validate_utils.GetPolicyConfig(ctx, data.policyConfiguration) + ctx = c if err != nil { allErrors = errors.Join(allErrors, err) return @@ -232,9 +233,10 @@ func validateImageCmd(validate imageValidationFunc) *cobra.Command { // We're not currently using the policyCache returned from PreProcessPolicy, but we could // use it to cache the policy for future use. - if p, _, err := policy.PreProcessPolicy(ctx, policyOptions); err != nil { + if c, p, _, err := policy.PreProcessPolicy(ctx, policyOptions); err != nil { allErrors = errors.Join(allErrors, err) } else { + ctx = c // inject extra variables into rule data per source if len(data.extraRuleData) > 0 { policySpec := p.Spec() @@ -262,7 +264,8 @@ func validateImageCmd(validate imageValidationFunc) *cobra.Command { if len(parts) < 2 { log.Errorf("Incorrect syntax for --extra-rule-data") } - extraRuleDataPolicyConfig, err := validate_utils.GetPolicyConfig(ctx, parts[1]) + c, extraRuleDataPolicyConfig, err := validate_utils.GetPolicyConfig(ctx, parts[1]) + ctx = c if err != nil { log.Errorf("Unable to load data from extraRuleData: %s", err.Error()) } diff --git a/cmd/validate/image_integration_test.go b/cmd/validate/image_integration_test.go index c530904e8..1e6e554d9 100644 --- a/cmd/validate/image_integration_test.go +++ b/cmd/validate/image_integration_test.go @@ -71,7 +71,7 @@ func TestEvaluatorLifecycle(t *testing.T) { evaluators[i].On("Destroy").NotBefore(expectations...) } - newConftestEvaluator = func(_ context.Context, s []source.PolicySource, _ evaluator.ConfigProvider, _ v1alpha1.Source) (evaluator.Evaluator, error) { + newConftestEvaluator = func(ctx context.Context, s []source.PolicySource, _ evaluator.ConfigProvider, _ v1alpha1.Source) (evaluator.Evaluator, error) { // We are splitting this url to get to the index of the evaluator. idx, err := strconv.Atoi(strings.Split(strings.Split(s[0].PolicyUrl(), "@")[0], "::")[1]) require.NoError(t, err) @@ -84,7 +84,7 @@ func TestEvaluatorLifecycle(t *testing.T) { validate := func(_ context.Context, component app.SnapshotComponent, _ *app.SnapshotSpec, _ policy.Policy, evaluators []evaluator.Evaluator, _ bool) (*output.Output, error) { for _, e := range evaluators { - _, _, err := e.Evaluate(ctx, evaluator.EvaluationTarget{Inputs: []string{}}) + _, _, _, err := e.Evaluate(ctx, evaluator.EvaluationTarget{Inputs: []string{}}) require.NoError(t, err) } diff --git a/cmd/validate/input.go b/cmd/validate/input.go index e9550c6ff..df4440023 100644 --- a/cmd/validate/input.go +++ b/cmd/validate/input.go @@ -88,7 +88,8 @@ func validateInputCmd(validate InputValidationFunc) *cobra.Command { PreRunE: func(cmd *cobra.Command, args []string) (allErrors error) { ctx := cmd.Context() - policyConfiguration, err := validate_utils.GetPolicyConfig(ctx, data.policyConfiguration) + c, policyConfiguration, err := validate_utils.GetPolicyConfig(ctx, data.policyConfiguration) + cmd.SetContext(c) if err != nil { allErrors = errors.Join(allErrors, err) return diff --git a/cmd/validate/policy.go b/cmd/validate/policy.go index d15d42957..e8a1752db 100644 --- a/cmd/validate/policy.go +++ b/cmd/validate/policy.go @@ -53,7 +53,8 @@ func ValidatePolicyCmd(validate policyValidationFunc) *cobra.Command { PreRunE: func(cmd *cobra.Command, args []string) (allErrors error) { ctx := cmd.Context() - policyConfiguration, err := validate_utils.GetPolicyConfig(ctx, data.policyConfiguration) + c, policyConfiguration, err := validate_utils.GetPolicyConfig(ctx, data.policyConfiguration) + cmd.SetContext(c) if err != nil { allErrors = errors.Join(allErrors, err) return diff --git a/internal/evaluator/conftest_evaluator.go b/internal/evaluator/conftest_evaluator.go index 6b5c26ef8..cee63c23e 100644 --- a/internal/evaluator/conftest_evaluator.go +++ b/internal/evaluator/conftest_evaluator.go @@ -364,7 +364,7 @@ func (r *policyRules) collect(a *ast.AnnotationsRef) error { return nil } -func (c conftestEvaluator) Evaluate(ctx context.Context, target EvaluationTarget) ([]Outcome, Data, error) { +func (c conftestEvaluator) Evaluate(ctx context.Context, target EvaluationTarget) (context.Context, []Outcome, Data, error) { var results []Outcome if trace.IsEnabled() { @@ -379,11 +379,11 @@ func (c conftestEvaluator) Evaluate(ctx context.Context, target EvaluationTarget rules := policyRules{} // Download all sources for _, s := range c.policySources { - dir, err := s.GetPolicy(ctx, c.workDir, false) + ctx, dir, err := s.GetPolicy(ctx, c.workDir, false) if err != nil { log.Debugf("Unable to download source from %s!", s.PolicyUrl()) // TODO do we want to download other policies instead of erroring out? - return nil, nil, err + return ctx, nil, nil, err } annotations := []*ast.AnnotationsRef{} fs := utils.FS(ctx) @@ -396,7 +396,7 @@ func (c conftestEvaluator) Evaluate(ctx context.Context, target EvaluationTarget // Let's try to give some more robust messaging to the user. policyURL, err := url.Parse(s.PolicyUrl()) if err != nil { - return nil, nil, errMsg + return ctx, nil, nil, errMsg } // Do we have a prefix at the end of the URL path? // If not, this means we aren't trying to access a specific file. @@ -410,7 +410,7 @@ func (c conftestEvaluator) Evaluate(ctx context.Context, target EvaluationTarget } } } - return nil, nil, errMsg + return ctx, nil, nil, errMsg } } @@ -419,7 +419,7 @@ func (c conftestEvaluator) Evaluate(ctx context.Context, target EvaluationTarget continue } if err := rules.collect(a); err != nil { - return nil, nil, err + return ctx, nil, nil, err } } } @@ -453,7 +453,7 @@ func (c conftestEvaluator) Evaluate(ctx context.Context, target EvaluationTarget runResults, data, err := r.Run(ctx, target.Inputs) if err != nil { // TODO do we want to evaluate further policies instead of erroring out? - return nil, nil, err + return ctx, nil, nil, err } effectiveTime := c.policy.EffectiveTime() @@ -535,10 +535,10 @@ func (c conftestEvaluator) Evaluate(ctx context.Context, target EvaluationTarget // ran due to input error, etc. if totalRules == 0 { log.Error("no successes, warnings, or failures, check input") - return nil, nil, fmt.Errorf("no successes, warnings, or failures, check input") + return ctx, nil, nil, fmt.Errorf("no successes, warnings, or failures, check input") } - return results, data, nil + return ctx, results, data, nil } func toRules(results []output.Result) []Result { diff --git a/internal/evaluator/conftest_evaluator_test.go b/internal/evaluator/conftest_evaluator_test.go index 6202d92b7..0ac2f6dfc 100644 --- a/internal/evaluator/conftest_evaluator_test.go +++ b/internal/evaluator/conftest_evaluator_test.go @@ -67,8 +67,8 @@ func withTestRunner(ctx context.Context, clnt testRunner) context.Context { type testPolicySource struct{} -func (t testPolicySource) GetPolicy(ctx context.Context, dest string, showMsg bool) (string, error) { - return "/policy", nil +func (t testPolicySource) GetPolicy(ctx context.Context, dest string, showMsg bool) (context.Context, string, error) { + return ctx, "/policy", nil } func (t testPolicySource) PolicyUrl() string { @@ -302,7 +302,7 @@ func TestConftestEvaluatorEvaluateSeverity(t *testing.T) { }, pol, ecc.Source{}) assert.NoError(t, err) - actualResults, data, err := evaluator.Evaluate(ctx, inputs) + _, actualResults, data, err := evaluator.Evaluate(ctx, inputs) assert.NoError(t, err) assert.Equal(t, expectedResults, actualResults) assert.Equal(t, expectedData, data) @@ -413,7 +413,7 @@ func TestConftestEvaluatorEvaluateNoSuccessWarningsOrFailures(t *testing.T) { }, p, ecc.Source{Config: tt.sourceConfig}) assert.NoError(t, err) - actualResults, data, err := evaluator.Evaluate(ctx, inputs) + _, actualResults, data, err := evaluator.Evaluate(ctx, inputs) assert.ErrorContains(t, err, "no successes, warnings, or failures, check input") assert.Nil(t, actualResults) assert.Nil(t, data) @@ -1305,7 +1305,7 @@ func TestConftestEvaluatorIncludeExclude(t *testing.T) { }, p, ecc.Source{}) assert.NoError(t, err) - got, data, err := evaluator.Evaluate(ctx, inputs) + _, got, data, err := evaluator.Evaluate(ctx, inputs) assert.NoError(t, err) assert.Equal(t, tt.want, got) assert.Equal(t, Data(nil), data) @@ -1825,7 +1825,7 @@ func TestConftestEvaluatorEvaluate(t *testing.T) { }, config, ecc.Source{}) require.NoError(t, err) - results, data, err := evaluator.Evaluate(ctx, EvaluationTarget{Inputs: []string{path.Join(dir, "inputs")}}) + _, results, data, err := evaluator.Evaluate(ctx, EvaluationTarget{Inputs: []string{path.Join(dir, "inputs")}}) require.NoError(t, err) // sort the slice by code for test stability @@ -1888,7 +1888,7 @@ func TestUnconformingRule(t *testing.T) { }, p, ecc.Source{}) require.NoError(t, err) - _, _, err = evaluator.Evaluate(ctx, EvaluationTarget{Inputs: []string{path.Join(dir, "inputs")}}) + _, _, _, err = evaluator.Evaluate(ctx, EvaluationTarget{Inputs: []string{path.Join(dir, "inputs")}}) assert.EqualError(t, err, `the rule "deny = true { true }" returns an unsupported value, at no_msg.rego:3`) } diff --git a/internal/evaluator/evaluator.go b/internal/evaluator/evaluator.go index 3b03c7ac6..bd888f995 100644 --- a/internal/evaluator/evaluator.go +++ b/internal/evaluator/evaluator.go @@ -26,7 +26,7 @@ type EvaluationTarget struct { } type Evaluator interface { - Evaluate(ctx context.Context, target EvaluationTarget) ([]Outcome, Data, error) + Evaluate(ctx context.Context, target EvaluationTarget) (context.Context, []Outcome, Data, error) // Destroy performs any cleanup needed Destroy() diff --git a/internal/image/validate.go b/internal/image/validate.go index 9240afae9..8517de844 100644 --- a/internal/image/validate.go +++ b/internal/image/validate.go @@ -122,7 +122,8 @@ func ValidateImage(ctx context.Context, comp app.SnapshotComponent, snap *app.Sn } else { target.Target = digest } - results, data, err := e.Evaluate(ctx, target) + c, results, data, err := e.Evaluate(ctx, target) + ctx = c log.Debug("\n\nRunning conftest policy check\n\n") if err != nil { diff --git a/internal/image/validate_test.go b/internal/image/validate_test.go index 057da2741..89de32cf4 100644 --- a/internal/image/validate_test.go +++ b/internal/image/validate_test.go @@ -289,10 +289,10 @@ type mockEvaluator struct { mock.Mock } -func (e *mockEvaluator) Evaluate(ctx context.Context, target evaluator.EvaluationTarget) ([]evaluator.Outcome, evaluator.Data, error) { +func (e *mockEvaluator) Evaluate(ctx context.Context, target evaluator.EvaluationTarget) (context.Context, []evaluator.Outcome, evaluator.Data, error) { args := e.Called(ctx, target.Inputs) - return args.Get(0).([]evaluator.Outcome), args.Get(1).(evaluator.Data), args.Error(2) + return ctx, args.Get(0).([]evaluator.Outcome), args.Get(1).(evaluator.Data), args.Error(2) } func (e *mockEvaluator) Destroy() { diff --git a/internal/input/validate.go b/internal/input/validate.go index 00e8ac04d..965b21cae 100644 --- a/internal/input/validate.go +++ b/internal/input/validate.go @@ -55,7 +55,8 @@ func ValidateInput(ctx context.Context, fpath string, policy policy.Policy, deta var allResults []evaluator.Outcome for _, e := range p.Evaluators { - results, _, err := e.Evaluate(ctx, evaluator.EvaluationTarget{Inputs: inputFiles}) + c, results, _, err := e.Evaluate(ctx, evaluator.EvaluationTarget{Inputs: inputFiles}) + ctx = c if err != nil { return nil, fmt.Errorf("evaluating policy: %w", err) } diff --git a/internal/input/validate_test.go b/internal/input/validate_test.go index c876ab0cb..7c88dc5a1 100644 --- a/internal/input/validate_test.go +++ b/internal/input/validate_test.go @@ -40,8 +40,8 @@ type ( badMockEvaluator struct{} ) -func (e mockEvaluator) Evaluate(ctx context.Context, target evaluator.EvaluationTarget) ([]evaluator.Outcome, evaluator.Data, error) { - return []evaluator.Outcome{}, nil, nil +func (e mockEvaluator) Evaluate(ctx context.Context, target evaluator.EvaluationTarget) (context.Context, []evaluator.Outcome, evaluator.Data, error) { + return ctx, []evaluator.Outcome{}, nil, nil } func (e mockEvaluator) Destroy() { @@ -51,8 +51,8 @@ func (e mockEvaluator) CapabilitiesPath() string { return "" } -func (b badMockEvaluator) Evaluate(ctx context.Context, target evaluator.EvaluationTarget) ([]evaluator.Outcome, evaluator.Data, error) { - return nil, nil, errors.New("Evaluator error") +func (b badMockEvaluator) Evaluate(ctx context.Context, target evaluator.EvaluationTarget) (context.Context, []evaluator.Outcome, evaluator.Data, error) { + return ctx, nil, nil, errors.New("Evaluator error") } func (e badMockEvaluator) Destroy() { diff --git a/internal/policy/policy.go b/internal/policy/policy.go index c6d432cd0..07fc5eb6b 100644 --- a/internal/policy/policy.go +++ b/internal/policy/policy.go @@ -567,17 +567,17 @@ func validatePolicyConfig(policyConfig string) error { // PreProcessPolicy fetches policy sources and returns a policy object with // pinned SHA/image digest URL where applicable, along with a policy cache object. -func PreProcessPolicy(ctx context.Context, policyOptions Options) (Policy, *cache.PolicyCache, error) { +func PreProcessPolicy(ctx context.Context, policyOptions Options) (context.Context, Policy, *cache.PolicyCache, error) { var policyCache *cache.PolicyCache pinnedPolicyUrls := map[string][]string{} policyCache, err := cache.NewPolicyCache(ctx) if err != nil { - return nil, nil, err + return ctx, nil, nil, err } p, err := NewPolicy(ctx, policyOptions) if err != nil { - return nil, nil, err + return ctx, nil, nil, err } sources := p.Spec().Sources @@ -589,7 +589,7 @@ func PreProcessPolicy(ctx context.Context, policyOptions Options) (Policy, *cach dir, err := utils.CreateWorkDir(fs) if err != nil { log.Debug("Failed to create work dir!") - return nil, nil, err + return ctx, nil, nil, err } for _, policySource := range policySources { @@ -597,10 +597,10 @@ func PreProcessPolicy(ctx context.Context, policyOptions Options) (Policy, *cach continue } - destDir, err := policySource.GetPolicy(ctx, dir, false) + ctx, destDir, err := policySource.GetPolicy(ctx, dir, false) if err != nil { log.Debugf("Unable to download source from %s!", policySource.PolicyUrl()) - return nil, nil, err + return ctx, nil, nil, err } log.Debugf("Downloaded policy source from %s to %s\n", policySource.PolicyUrl(), destDir) @@ -626,7 +626,7 @@ func PreProcessPolicy(ctx context.Context, policyOptions Options) (Policy, *cach } } - return p, policyCache, err + return ctx, p, policyCache, err } func urls(s []source.PolicySource, kind source.PolicyType) []string { diff --git a/internal/policy/source/git_config.go b/internal/policy/source/git_config.go index d615a2f50..02d76754c 100644 --- a/internal/policy/source/git_config.go +++ b/internal/policy/source/git_config.go @@ -54,16 +54,16 @@ func SourceIsHttp(src string) bool { return err == nil && strings.HasPrefix(normalizedUrl, "http") } -func GoGetterDownload(ctx context.Context, tmpDir, src string) (string, error) { +func GoGetterDownload(ctx context.Context, tmpDir, src string) (context.Context, string, error) { // Download the config from a url c := PolicyUrl{ Url: src, Kind: ConfigKind, } - configDir, err := c.GetPolicy(ctx, tmpDir, false) + ctx, configDir, err := c.GetPolicy(ctx, tmpDir, false) if err != nil { log.Debugf("Failed to download policy config from %s", c.Url) - return "", err + return ctx, "", err } log.Debugf("Downloaded policy config from %s to %s", c.Url, configDir) @@ -71,8 +71,8 @@ func GoGetterDownload(ctx context.Context, tmpDir, src string) (string, error) { configFile, err := choosePolicyFile(ctx, configDir) if err != nil { // A more useful error message: - return "", fmt.Errorf("no suitable config file found at %s", c.Url) + return ctx, "", fmt.Errorf("no suitable config file found at %s", c.Url) } log.Debugf("Chose file %s to use for the policy config", configFile) - return configFile, nil + return ctx, configFile, nil } diff --git a/internal/policy/source/source.go b/internal/policy/source/source.go index e49a8fccc..a0cdc32b7 100644 --- a/internal/policy/source/source.go +++ b/internal/policy/source/source.go @@ -54,6 +54,7 @@ const ( DataKind PolicyType = "data" ConfigKind PolicyType = "config" InlineDataKind PolicyType = "inline-data" + downloadCacheKey key = 1 ) type downloaderFunc interface { @@ -63,7 +64,7 @@ type downloaderFunc interface { // PolicySource in an interface representing the location a policy source. // Must implement the GetPolicy() method. type PolicySource interface { - GetPolicy(ctx context.Context, dest string, showMsg bool) (string, error) + GetPolicy(ctx context.Context, dest string, showMsg bool) (context.Context, string, error) PolicyUrl() string Subdir() string Type() PolicyType @@ -75,19 +76,23 @@ type PolicyUrl struct { Kind PolicyType } -// downloadCache is a concurrent map used to cache downloaded files. -var downloadCache sync.Map - type cacheContent struct { sourceUrl string metadata metadata.Metadata err error } -func getPolicyThroughCache(ctx context.Context, s PolicySource, workDir string, dl func(string, string) (metadata.Metadata, error)) (string, metadata.Metadata, error) { +func getPolicyThroughCache(ctx context.Context, s PolicySource, workDir string, dl func(string, string) (metadata.Metadata, error)) (context.Context, string, metadata.Metadata, error) { sourceUrl := s.PolicyUrl() dest := uniqueDestination(workDir, s.Subdir(), sourceUrl) + // downloadCache is a concurrent map used to cache downloaded files. + downloadCache, ok := ctx.Value(downloadCacheKey).(*sync.Map) + if !ok { + downloadCache = &sync.Map{} + ctx = context.WithValue(ctx, downloadCacheKey, downloadCache) + } + // Load or store the downloaded policy file from the given source URL. // If the file is already in the download cache, it is loaded from there. // Otherwise, it is downloaded from the source URL and stored in the cache. @@ -102,12 +107,12 @@ func getPolicyThroughCache(ctx context.Context, s PolicySource, workDir string, d, c := dfn.(func() (string, cacheContent))() if c.err != nil { - return "", c.metadata, c.err + return ctx, "", c.metadata, c.err } fs := utils.FS(ctx) if _, err := fs.Stat(dest); err == nil { - return dest, c.metadata, nil + return ctx, dest, c.metadata, nil } // If the destination directory is different from the source directory, we @@ -115,32 +120,32 @@ func getPolicyThroughCache(ctx context.Context, s PolicySource, workDir string, if filepath.Dir(dest) != filepath.Dir(d) { base := filepath.Dir(dest) if err := fs.MkdirAll(base, 0755); err != nil { - return "", nil, err + return ctx, "", nil, err } if symlinkableFS, ok := fs.(afero.Symlinker); ok { log.Debugf("Symlinking %s to %s", d, dest) if err := symlinkableFS.SymlinkIfPossible(d, dest); err != nil { - return "", nil, err + return ctx, "", nil, err } logMetadata(c.metadata) - return dest, c.metadata, nil + return ctx, dest, c.metadata, nil } else { log.Debugf("Filesystem does not support symlinking: %q, re-downloading instead", fs.Name()) m, err := dl(sourceUrl, dest) logMetadata(m) - return dest, m, err + return ctx, dest, m, err } } if c.metadata != nil { logMetadata(c.metadata) } - return d, c.metadata, c.err + return ctx, d, c.metadata, c.err } // GetPolicies clones the repository for a given PolicyUrl -func (p *PolicyUrl) GetPolicy(ctx context.Context, workDir string, showMsg bool) (string, error) { +func (p *PolicyUrl) GetPolicy(ctx context.Context, workDir string, showMsg bool) (context.Context, string, error) { if trace.IsEnabled() { region := trace.StartRegion(ctx, "ec:get-policy") defer region.End() @@ -155,18 +160,18 @@ func (p *PolicyUrl) GetPolicy(ctx context.Context, workDir string, showMsg bool) return downloader.Download(ctx, dest, source, showMsg) } - dest, metadata, err := getPolicyThroughCache(ctx, p, workDir, dl) + ctx, dest, metadata, err := getPolicyThroughCache(ctx, p, workDir, dl) if err != nil { - return "", err + return ctx, "", err } p.Url, err = metadata.GetPinnedURL(p.Url) log.Debug("Pinned URL: ", p.Url) if err != nil { - return "", err + return ctx, "", err } - return dest, err + return ctx, dest, err } func (p *PolicyUrl) PolicyUrl() string { @@ -215,7 +220,7 @@ func InlineData(source []byte) PolicySource { return inlineData{source} } -func (s inlineData) GetPolicy(ctx context.Context, workDir string, showMsg bool) (string, error) { +func (s inlineData) GetPolicy(ctx context.Context, workDir string, showMsg bool) (context.Context, string, error) { dl := func(source string, dest string) (metadata.Metadata, error) { fs := utils.FS(ctx) @@ -233,8 +238,8 @@ func (s inlineData) GetPolicy(ctx context.Context, workDir string, showMsg bool) return m, afero.WriteFile(fs, f, s.source, 0400) } - dest, _, err := getPolicyThroughCache(ctx, s, workDir, dl) - return dest, err + ctx, dest, _, err := getPolicyThroughCache(ctx, s, workDir, dl) + return ctx, dest, err } func (s inlineData) PolicyUrl() string { diff --git a/internal/policy/source/source_test.go b/internal/policy/source/source_test.go index 87bc36127..bd81560ea 100644 --- a/internal/policy/source/source_test.go +++ b/internal/policy/source/source_test.go @@ -98,7 +98,7 @@ func TestGetPolicy(t *testing.T) { return matched }), tt.sourceUrl, false).Return(tt.metadata, tt.err) - _, err := p.GetPolicy(usingDownloader(context.TODO(), &dl), "/tmp/ec-work-1234", false) + _, _, err := p.GetPolicy(usingDownloader(context.TODO(), &dl), "/tmp/ec-work-1234", false) if tt.err == nil { assert.NoError(t, err, "GetPolicies returned an error") } else { @@ -121,7 +121,7 @@ func TestInlineDataSource(t *testing.T) { ctx := utils.WithFS(context.Background(), fs) - dest, err := s.GetPolicy(ctx, temp, false) + _, dest, err := s.GetPolicy(ctx, temp, false) require.NoError(t, err) file := path.Join(dest, "rule_data.json") @@ -186,9 +186,9 @@ type mockPolicySource struct { *mock.Mock } -func (m mockPolicySource) GetPolicy(ctx context.Context, dest string, msgs bool) (string, error) { +func (m mockPolicySource) GetPolicy(ctx context.Context, dest string, msgs bool) (context.Context, string, error) { args := m.Called(ctx, dest, msgs) - return args.String(0), args.Error(1) + return ctx, args.String(0), args.Error(1) } func (m mockPolicySource) PolicyUrl() string { @@ -208,10 +208,6 @@ func (m mockPolicySource) Type() PolicyType { func TestGetPolicyThroughCache(t *testing.T) { test := func(t *testing.T, fs afero.Fs, expectedDownloads int) { - t.Cleanup(func() { - downloadCache = sync.Map{} - }) - ctx := utils.WithFS(context.Background(), fs) invocations := 0 @@ -229,10 +225,10 @@ func TestGetPolicyThroughCache(t *testing.T) { source.On("PolicyUrl").Return("policy-url") source.On("Subdir").Return("subdir") - s1, _, err := getPolicyThroughCache(ctx, source, "/workdir1", dl) + ctx, s1, _, err := getPolicyThroughCache(ctx, source, "/workdir1", dl) require.NoError(t, err) - s2, _, err := getPolicyThroughCache(ctx, source, "/workdir2", dl) + _, s2, _, err := getPolicyThroughCache(ctx, source, "/workdir2", dl) require.NoError(t, err) assert.NotEqual(t, s1, s2) @@ -273,9 +269,6 @@ func TestGetPolicyThroughCache(t *testing.T) { // symbolic links pointing to the same policy download within the same workdir // causing Rego compile issue func TestDownloadCacheWorkdirMismatch(t *testing.T) { - t.Cleanup(func() { - downloadCache = sync.Map{} - }) tmp := t.TempDir() source := &mockPolicySource{&mock.Mock{}} @@ -285,6 +278,8 @@ func TestDownloadCacheWorkdirMismatch(t *testing.T) { // same URL downloaded to workdir1 precachedDest := uniqueDestination(tmp, "subdir", source.PolicyUrl()) require.NoError(t, os.MkdirAll(precachedDest, 0755)) + + downloadCache := sync.Map{} downloadCache.Store("policy-url", func() (string, cacheContent) { return precachedDest, cacheContent{} }) @@ -292,13 +287,15 @@ func TestDownloadCacheWorkdirMismatch(t *testing.T) { // when working in workdir2 workdir2 := filepath.Join(tmp, "workdir2") + ctx := context.WithValue(context.Background(), downloadCacheKey, &downloadCache) + // first invocation symlinks back to workdir1 - destination1, _, err := getPolicyThroughCache(context.Background(), source, workdir2, func(s1, s2 string) (metadata.Metadata, error) { return nil, nil }) + ctx, destination1, _, err := getPolicyThroughCache(ctx, source, workdir2, func(s1, s2 string) (metadata.Metadata, error) { return nil, nil }) require.NoError(t, err) // second invocation should not create a second symlink and duplicate the // source files within workdir2 - destination2, _, err := getPolicyThroughCache(context.Background(), source, workdir2, func(s1, s2 string) (metadata.Metadata, error) { return nil, nil }) + _, destination2, _, err := getPolicyThroughCache(ctx, source, workdir2, func(s1, s2 string) (metadata.Metadata, error) { return nil, nil }) require.NoError(t, err) assert.Equal(t, destination1, destination2) diff --git a/internal/validate/helpers.go b/internal/validate/helpers.go index d2516ddba..89da572b3 100644 --- a/internal/validate/helpers.go +++ b/internal/validate/helpers.go @@ -28,7 +28,7 @@ import ( ) // Determine policyConfig -func GetPolicyConfig(ctx context.Context, policyConfiguration string) (string, error) { +func GetPolicyConfig(ctx context.Context, policyConfiguration string) (context.Context, string, error) { // If policyConfiguration is not detected as a file and is detected as a git URL, // or if policyConfiguration is an https URL try to download a config file from // the provided source. If successful we read its contents and return it. @@ -41,27 +41,29 @@ func GetPolicyConfig(ctx context.Context, policyConfiguration string) (string, e fs := utils.FS(ctx) tmpDir, err := utils.CreateWorkDir(fs) if err != nil { - return "", err + return ctx, "", err } defer utils.CleanupWorkDir(fs, tmpDir) // Git download and find a suitable config file - configFile, err := source.GoGetterDownload(ctx, tmpDir, policyConfiguration) + ctx, configFile, err := source.GoGetterDownload(ctx, tmpDir, policyConfiguration) if err != nil { - return "", err + return ctx, "", err } log.Debugf("Loading %s as policy configuration", configFile) - return ReadFile(ctx, configFile) + content, err := ReadFile(ctx, configFile) + return ctx, content, err } else if source.SourceIsFile(policyConfiguration) && utils.HasJsonOrYamlExt(policyConfiguration) { // If policyConfiguration is detected as a file and it has a json or yaml extension, // we read its contents and return it. log.Debugf("Loading %s as policy configuration", policyConfiguration) - return ReadFile(ctx, policyConfiguration) + content, err := ReadFile(ctx, policyConfiguration) + return ctx, content, err } // If policyConfiguration is not a file path, git url, or https url, // we assume it's a string and return it as is. - return policyConfiguration, nil + return ctx, policyConfiguration, nil } // Read file from the workspace and return its contents.