diff --git a/enricher/epss/epss.go b/enricher/epss/epss.go index c790cedd4..ad5546f46 100644 --- a/enricher/epss/epss.go +++ b/enricher/epss/epss.go @@ -1,17 +1,12 @@ package epss import ( - "bufio" "compress/gzip" "context" + "encoding/csv" "encoding/json" + "errors" "fmt" - "github.com/google/uuid" - "github.com/pkg/errors" - "github.com/quay/claircore" - "github.com/quay/claircore/libvuln/driver" - "github.com/quay/claircore/pkg/tmp" - "github.com/quay/zlog" "io" "net/http" "net/url" @@ -21,15 +16,26 @@ import ( "strconv" "strings" "time" + + "github.com/quay/claircore" + "github.com/quay/claircore/libvuln/driver" + "github.com/quay/claircore/pkg/tmp" + "github.com/quay/zlog" ) var ( _ driver.Enricher = (*Enricher)(nil) _ driver.EnrichmentUpdater = (*Enricher)(nil) - - defaultFeed *url.URL ) +type EPSSItem struct { + ModelVersion string `json:"modelVersion"` + Date string `json:"date"` + CVE string `json:"cve"` + EPSS float64 `json:"epss"` + Percentile float64 `json:"percentile"` +} + // This is a slightly more relaxed version of the validation pattern in the NVD // JSON schema: https://csrc.nist.gov/schema/nvd/feed/1.1/CVE_JSON_4.0_min_1.1.schema // @@ -41,9 +47,9 @@ const ( // Type is the type of data returned from the Enricher's Enrich method. Type = `message/vnd.clair.map.vulnerability; enricher=clair.epss schema=https://csrc.nist.gov/schema/nvd/feed/1.1/cvss-v3.x.json` - // DefaultFeeds is the default place to look for EPSS feeds. + // DefaultFeed is the default place to look for EPSS feeds. // epss_scores-YYYY-MM-DD.csv.gz needs to be specified to get all data - DefaultFeeds = `https://epss.cyentia.com/` + DefaultFeed = `https://epss.cyentia.com/` // epssName is the name of the enricher epssName = `clair.epss` @@ -51,7 +57,6 @@ const ( func init() { var err error - defaultFeed, err = url.Parse(DefaultFeeds) if err != nil { panic(err) } @@ -76,43 +81,43 @@ func (e *Enricher) Configure(ctx context.Context, f driver.ConfigUnmarshaler, c ctx = zlog.ContextWithValues(ctx, "component", "enricher/epss/Enricher/Configure") var cfg Config e.c = c + e.feedPath = currentFeedURL() if f == nil { - zlog.Warn(ctx).Msg("No configuration provided; proceeding with default settings") - e.sourceURL() + zlog.Debug(ctx).Msg("No configuration provided; proceeding with default settings") return nil } if err := f(&cfg); err != nil { return err } - if cfg.FeedRoot != nil { // validate the URL format if _, err := url.Parse(*cfg.FeedRoot); err != nil { return fmt.Errorf("invalid URL format for FeedRoot: %w", err) } - // Check for a .gz file + // only .gz file is supported if strings.HasSuffix(*cfg.FeedRoot, ".gz") { + //overwrite feedPath is cfg provides another feed path e.feedPath = *cfg.FeedRoot } else { - e.sourceURL() // Fallback to the default source URL if not a .gz file + return fmt.Errorf("invalid feed root: expected a '.gz' file, but got '%q'", *cfg.FeedRoot) } - } else { - e.sourceURL() } return nil } -func (e *Enricher) FetchEnrichment(ctx context.Context, _ driver.Fingerprint) (io.ReadCloser, driver.Fingerprint, error) { +// FetchEnrichment implements driver.EnrichmentUpdater. +func (e *Enricher) FetchEnrichment(ctx context.Context, prevFingerprint driver.Fingerprint) (io.ReadCloser, driver.Fingerprint, error) { ctx = zlog.ContextWithValues(ctx, "component", "enricher/epss/Enricher/FetchEnrichment") - newUUID := uuid.New() - hint := driver.Fingerprint(newUUID.String()) - zlog.Info(ctx).Str("hint", string(hint)).Msg("starting fetch") - out, err := tmp.NewFile("", "enricher.epss.*.json") + if e.feedPath == "" || !strings.HasSuffix(e.feedPath, ".gz") { + return nil, "", fmt.Errorf("invalid feed path: %q must be non-empty and end with '.gz'", e.feedPath) + } + + out, err := tmp.NewFile("", "epss.") if err != nil { - return nil, hint, err + return nil, "", err } var success bool defer func() { @@ -123,103 +128,146 @@ func (e *Enricher) FetchEnrichment(ctx context.Context, _ driver.Fingerprint) (i } }() - if e.feedPath == "" || !strings.HasSuffix(e.feedPath, ".gz") { - e.sourceURL() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, e.feedPath, nil) + if err != nil { + return nil, "", fmt.Errorf("unable to create request for %s: %w", e.feedPath, err) } - resp, err := http.Get(e.feedPath) + resp, err := e.c.Do(req) if err != nil { - return nil, "", fmt.Errorf("failed to fetch file from %s: %w", e.feedPath, err) + return nil, "", fmt.Errorf("unable to fetch file from %s: %w", e.feedPath, err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return nil, "", fmt.Errorf("failed to fetch file: received status %d", resp.StatusCode) + return nil, "", fmt.Errorf("unable to fetch file: received status %d", resp.StatusCode) + } + + etag := resp.Header.Get("ETag") + if etag == "" { + return nil, "", fmt.Errorf("ETag not found in response headers") + } + + newFingerprint := driver.Fingerprint(etag) + + if prevFingerprint == newFingerprint { + zlog.Info(ctx).Str("fingerprint", string(newFingerprint)).Msg("file unchanged; skipping processing") + return nil, prevFingerprint, nil } gzipReader, err := gzip.NewReader(resp.Body) if err != nil { - return nil, "", fmt.Errorf("failed to decompress file: %w", err) + return nil, "", fmt.Errorf("unable to decompress file: %w", err) } defer gzipReader.Close() - scanner := bufio.NewScanner(gzipReader) - var headers []string + csvReader := csv.NewReader(gzipReader) + csvReader.FieldsPerRecord = -1 // Allow variable-length fields + + // assume metadata is always in the first line + record, err := csvReader.Read() + if err != nil { + return nil, "", fmt.Errorf("failed to read metadata line: %w", err) + } + + var modelVersion, date string + for _, field := range record { + field = strings.TrimSpace(field) + if strings.HasPrefix(field, "#") { + field = strings.TrimPrefix(field, "#") + } + kv := strings.SplitN(field, ":", 2) + if len(kv) == 2 { + switch strings.TrimSpace(kv[0]) { + case "model_version": + modelVersion = strings.TrimSpace(kv[1]) + case "score_date": + date = strings.TrimSpace(kv[1]) + } + } + } + + if modelVersion == "" || date == "" { + return nil, "", fmt.Errorf("missing metadata fields in record: %v", record) + } + + csvReader.Comment = '#' // Ignore subsequent comment lines + + record, err = csvReader.Read() + if err != nil { + return nil, "", fmt.Errorf("unable to read header line: %w", err) + } + if len(record) < 3 || record[0] != "cve" || record[1] != "epss" || record[2] != "percentile" { + return nil, "", fmt.Errorf("unexpected CSV headers: %v", record) + } + headers := record + enc := json.NewEncoder(out) totalCVEs := 0 - var modelVersion, date string - for scanner.Scan() { - line := strings.TrimSpace(scanner.Text()) - if line == "" { - continue + for { + record, err = csvReader.Read() + if errors.Is(err, io.EOF) { + break } - // assume metadata is always available at first comment of the file - if strings.HasPrefix(line, "#") && date == "" && modelVersion == "" { - modelVersion, date = parseMetadata(line) - zlog.Info(ctx). - Str("modelVersion", modelVersion). - Str("scoreDate", date). - Msg("parsed metadata") - continue - } - if headers == nil { - headers = strings.Split(line, ",") - continue + if err != nil { + return nil, "", fmt.Errorf("unable to read line in CSV: %w", err) } - record := strings.Split(line, ",") if len(record) != len(headers) { - zlog.Warn(ctx).Str("line", line).Msg("skipping line with mismatched fields") + zlog.Warn(ctx).Str("record", fmt.Sprintf("%v", record)).Msg("skipping record with mismatched fields") continue } r, err := newItemFeed(record, headers, modelVersion, date) if err != nil { - return nil, "", err + zlog.Warn(ctx).Str("record", fmt.Sprintf("%v", record)).Msg("skipping invalid record") + continue } if err = enc.Encode(&r); err != nil { - return nil, "", fmt.Errorf("failed to write JSON line to file: %w", err) + return nil, "", fmt.Errorf("unable to write JSON line to file: %w", err) } totalCVEs++ } - if err := scanner.Err(); err != nil { - return nil, "", fmt.Errorf("error reading file: %w", err) - } - zlog.Info(ctx).Int("totalCVEs", totalCVEs).Msg("processed CVEs") if _, err := out.Seek(0, io.SeekStart); err != nil { - return nil, hint, fmt.Errorf("unable to reset file pointer: %w", err) + return nil, newFingerprint, fmt.Errorf("unable to reset file pointer: %w", err) } success = true - return out, hint, nil + return out, newFingerprint, nil } // ParseEnrichment implements driver.EnrichmentUpdater. func (e *Enricher) ParseEnrichment(ctx context.Context, rc io.ReadCloser) ([]driver.EnrichmentRecord, error) { ctx = zlog.ContextWithValues(ctx, "component", "enricher/epss/Enricher/ParseEnrichment") - // Our Fetch method actually has all the smarts w/r/t to constructing the - // records, so this is just decoding in a loop. + defer func() { _ = rc.Close() }() - var err error + dec := json.NewDecoder(rc) - ret := make([]driver.EnrichmentRecord, 0, 250_000) // Wild guess at initial capacity. - // This is going to allocate like mad, hold onto your butts. - for err == nil { - ret = append(ret, driver.EnrichmentRecord{}) - err = dec.Decode(&ret[len(ret)-1]) + ret := make([]driver.EnrichmentRecord, 0, 250_000) + var err error + + for { + var record driver.EnrichmentRecord + if err = dec.Decode(&record); err != nil { + break + } + ret = append(ret, record) } + zlog.Debug(ctx). - Int("count", len(ret)-1). + Int("count", len(ret)). Msg("decoded enrichments") + if !errors.Is(err, io.EOF) { - return nil, err + return nil, fmt.Errorf("error decoding enrichment records: %w", err) } + return ret, nil } @@ -227,18 +275,18 @@ func (*Enricher) Name() string { return epssName } -func (e *Enricher) sourceURL() { +func currentFeedURL() string { currentDate := time.Now() formattedDate := currentDate.Format("2006-01-02") filePath := fmt.Sprintf("epss_scores-%s.csv.gz", formattedDate) - feedURL, err := url.Parse(DefaultFeeds) + feedURL, err := url.Parse(DefaultFeed) if err != nil { panic(fmt.Errorf("invalid default feed URL: %w", err)) } feedURL.Path = path.Join(feedURL.Path, filePath) - e.feedPath = feedURL.String() + return feedURL.String() } func (e *Enricher) Enrich(ctx context.Context, g driver.EnrichmentGetter, r *claircore.VulnerabilityReport) (string, []json.RawMessage, error) { @@ -284,7 +332,6 @@ func (e *Enricher) Enrich(ctx context.Context, g driver.EnrichmentGetter, r *cla sort.Strings(ts) cveKey := strings.Join(ts, "_") - zlog.Debug(ctx).Str("cve_key", cveKey).Strs("cve", ts).Msg("generated CVE cache key") rec, ok := erCache[cveKey] if !ok { @@ -324,57 +371,42 @@ func (e *Enricher) Enrich(ctx context.Context, g driver.EnrichmentGetter, r *cla } func newItemFeed(record []string, headers []string, modelVersion string, scoreDate string) (driver.EnrichmentRecord, error) { - item := make(map[string]interface{}) // Use interface{} to allow mixed types + if len(record) != len(headers) { + return driver.EnrichmentRecord{}, fmt.Errorf("record and headers length mismatch") + } + + var item EPSSItem for i, value := range record { - // epss details are numeric values - if f, err := strconv.ParseFloat(value, 64); err == nil { - item[headers[i]] = f - } else { - item[headers[i]] = value + switch headers[i] { + case "cve": + item.CVE = value + case "epss": + if f, err := strconv.ParseFloat(value, 64); err == nil { + item.EPSS = f + } else { + return driver.EnrichmentRecord{}, fmt.Errorf("invalid float for epss: %w", err) + } + case "percentile": + if f, err := strconv.ParseFloat(value, 64); err == nil { + item.Percentile = f + } else { + return driver.EnrichmentRecord{}, fmt.Errorf("invalid float for percentile: %w", err) + } } } - if modelVersion != "" { - item["modelVersion"] = modelVersion - } - if scoreDate != "" { - item["date"] = scoreDate - } + item.ModelVersion = modelVersion + item.Date = scoreDate enrichment, err := json.Marshal(item) if err != nil { - return driver.EnrichmentRecord{}, fmt.Errorf("failed to encode enrichment: %w", err) + return driver.EnrichmentRecord{}, fmt.Errorf("unable to encode enrichment: %w", err) } r := driver.EnrichmentRecord{ - Tags: []string{item["cve"].(string)}, // Ensure the "cve" field is a string + Tags: []string{item.CVE}, // CVE field should be set Enrichment: enrichment, } return r, nil } - -func parseMetadata(line string) (modelVersion string, scoreDate string) { - // Set default values - modelVersion = "N/A" - scoreDate = "0001-01-01" - - trimmedLine := strings.TrimPrefix(line, "#") - parts := strings.Split(trimmedLine, ",") - for _, part := range parts { - keyValue := strings.SplitN(part, ":", 2) - if len(keyValue) == 2 { - key := strings.TrimSpace(keyValue[0]) - value := strings.TrimSpace(keyValue[1]) - - switch key { - case "score_date": - scoreDate = value - case "model_version": - modelVersion = value - } - } - } - - return modelVersion, scoreDate -} diff --git a/enricher/epss/epss_test.go b/enricher/epss/epss_test.go index 6d4a0c57a..3ee4059b7 100644 --- a/enricher/epss/epss_test.go +++ b/enricher/epss/epss_test.go @@ -5,10 +5,6 @@ import ( "context" "encoding/json" "errors" - "github.com/google/go-cmp/cmp" - "github.com/quay/claircore" - "github.com/quay/claircore/libvuln/driver" - "github.com/quay/zlog" "io" "log" "net/http" @@ -17,6 +13,12 @@ import ( "path" "path/filepath" "testing" + + "github.com/google/go-cmp/cmp" + + "github.com/quay/claircore" + "github.com/quay/claircore/libvuln/driver" + "github.com/quay/zlog" ) func TestConfigure(t *testing.T) { @@ -32,7 +34,7 @@ func TestConfigure(t *testing.T) { }, }, { - Name: "OK", // URL without .gz will be replaced with default URL + Name: "Not OK", // URL without .gz is invalid Config: func(i interface{}) error { cfg := i.(*Config) s := "http://example.com/" @@ -40,8 +42,8 @@ func TestConfigure(t *testing.T) { return nil }, Check: func(t *testing.T, err error) { - if err != nil { - t.Errorf("unexpected error with .gz URL: %v", err) + if err == nil { + t.Errorf("expected invalid URL error, but got none: %v", err) } }, }, @@ -161,28 +163,36 @@ func noopConfig(_ interface{}) error { return nil } func mockServer(t *testing.T) *httptest.Server { const root = `testdata/` + + // Define a static ETag for testing purposes + const etagValue = `"test-etag-12345"` + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch path.Ext(r.URL.Path) { case ".gz": // only gz feed is supported + w.Header().Set("ETag", etagValue) + f, err := os.Open(filepath.Join(root, "data.csv")) if err != nil { t.Errorf("open failed: %v", err) w.WriteHeader(http.StatusInternalServerError) - break + return } defer f.Close() + gz := gzip.NewWriter(w) defer gz.Close() if _, err := io.Copy(gz, f); err != nil { t.Errorf("write error: %v", err) w.WriteHeader(http.StatusInternalServerError) - break + return } default: t.Errorf("unknown request path: %q", r.URL.Path) w.WriteHeader(http.StatusBadRequest) } })) + t.Cleanup(srv.Close) return srv } @@ -255,7 +265,9 @@ func (tc parseTestcase) Run(ctx context.Context, srv *httptest.Server) func(*tes if err := e.Configure(ctx, f, srv.Client()); err != nil { t.Errorf("unexpected error: %v", err) } - rc, _, err := e.FetchEnrichment(ctx, "") + + hint := driver.Fingerprint("test-e-tag-54321") + rc, _, err := e.FetchEnrichment(ctx, hint) if err != nil { t.Errorf("unexpected error: %v", err) } diff --git a/go.mod b/go.mod index e6ce7480e..345028d43 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,6 @@ require ( github.com/knqyf263/go-deb-version v0.0.0-20190517075300-09fca494f03d github.com/knqyf263/go-rpm-version v0.0.0-20170716094938-74609b86c936 github.com/package-url/packageurl-go v0.1.3 - github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.20.5 github.com/quay/claircore/toolkit v1.2.4 github.com/quay/claircore/updater/driver v1.0.0