diff --git a/revocation/crl/bundle.go b/revocation/crl/bundle.go new file mode 100644 index 00000000..63b7e0f4 --- /dev/null +++ b/revocation/crl/bundle.go @@ -0,0 +1,28 @@ +// Copyright The Notary Project Authors. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package crl + +import "crypto/x509" + +// Bundle is a collection of CRLs, including base and delta CRLs +type Bundle struct { + // BaseCRL is the parsed base CRL + BaseCRL *x509.RevocationList + + // DeltaCRL is the parsed delta CRL + // + // TODO: support delta CRL https://github.com/notaryproject/notation-core-go/issues/228 + // It will always be nil until we support delta CRL + DeltaCRL *x509.RevocationList +} diff --git a/revocation/crl/cache.go b/revocation/crl/cache.go new file mode 100644 index 00000000..0410dfed --- /dev/null +++ b/revocation/crl/cache.go @@ -0,0 +1,32 @@ +// Copyright The Notary Project Authors. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package crl + +import "context" + +// Cache is an interface that specifies methods used for caching +type Cache interface { + // Get retrieves the CRL bundle with the given url + // + // url is the key to retrieve the CRL bundle + // + // if the key does not exist or the content is expired, return ErrCacheMiss. + Get(ctx context.Context, url string) (*Bundle, error) + + // Set stores the CRL bundle with the given url + // + // url is the key to store the CRL bundle + // bundle is the CRL collections to store + Set(ctx context.Context, url string, bundle *Bundle) error +} diff --git a/revocation/internal/crl/errors.go b/revocation/crl/errors.go similarity index 76% rename from revocation/internal/crl/errors.go rename to revocation/crl/errors.go index 37866551..a1978910 100644 --- a/revocation/internal/crl/errors.go +++ b/revocation/crl/errors.go @@ -15,8 +15,5 @@ package crl import "errors" -var ( - // ErrDeltaCRLNotSupported is returned when the CRL contains a delta CRL but - // the delta CRL is not supported. - ErrDeltaCRLNotSupported = errors.New("delta CRL is not supported") -) +// ErrCacheMiss is returned when a cache miss occurs. +var ErrCacheMiss = errors.New("cache miss") diff --git a/revocation/crl/fetcher.go b/revocation/crl/fetcher.go new file mode 100644 index 00000000..cef1a7b5 --- /dev/null +++ b/revocation/crl/fetcher.go @@ -0,0 +1,167 @@ +// Copyright The Notary Project Authors. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package crl provides Fetcher interface with its implementation, and the +// Cache interface. +package crl + +import ( + "context" + "crypto/x509" + "encoding/asn1" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "time" +) + +// oidFreshestCRL is the object identifier for the distribution point +// for the delta CRL. (See RFC 5280, Section 5.2.6) +var oidFreshestCRL = asn1.ObjectIdentifier{2, 5, 29, 46} + +// maxCRLSize is the maximum size of CRL in bytes +// +// The 32 MiB limit is based on investigation that even the largest CRLs +// are less than 16 MiB. The limit is set to 32 MiB to prevent +const maxCRLSize = 32 * 1024 * 1024 // 32 MiB + +// Fetcher is an interface that specifies methods used for fetching CRL +// from the given URL +type Fetcher interface { + // Fetch retrieves the CRL from the given URL. + Fetch(ctx context.Context, url string) (*Bundle, error) +} + +// HTTPFetcher is a Fetcher implementation that fetches CRL from the given URL +type HTTPFetcher struct { + // Cache stores fetched CRLs and reuses them until the CRL reaches the + // NextUpdate time. + // If Cache is nil, no cache is used. + Cache Cache + + // DiscardCacheError specifies whether to discard any error on cache. + // + // ErrCacheMiss is not considered as an failure and will not be returned as + // an error if DiscardCacheError is false. + DiscardCacheError bool + + httpClient *http.Client +} + +// NewHTTPFetcher creates a new HTTPFetcher with the given HTTP client +func NewHTTPFetcher(httpClient *http.Client) (*HTTPFetcher, error) { + if httpClient == nil { + return nil, errors.New("httpClient cannot be nil") + } + + return &HTTPFetcher{ + httpClient: httpClient, + }, nil +} + +// Fetch retrieves the CRL from the given URL +// +// If cache is not nil, try to get the CRL from the cache first. On failure +// (e.g. cache miss), it will download the CRL from the URL and store it to the +// cache. +func (f *HTTPFetcher) Fetch(ctx context.Context, url string) (*Bundle, error) { + if url == "" { + return nil, errors.New("CRL URL cannot be empty") + } + + if f.Cache != nil { + bundle, err := f.Cache.Get(ctx, url) + if err == nil { + // check expiry + nextUpdate := bundle.BaseCRL.NextUpdate + if !nextUpdate.IsZero() && !time.Now().After(nextUpdate) { + return bundle, nil + } + } else if !errors.Is(err, ErrCacheMiss) && !f.DiscardCacheError { + return nil, fmt.Errorf("failed to retrieve CRL from cache: %w", err) + } + } + + bundle, err := f.fetch(ctx, url) + if err != nil { + return nil, fmt.Errorf("failed to retrieve CRL: %w", err) + } + + if f.Cache != nil { + err = f.Cache.Set(ctx, url, bundle) + if err != nil && !f.DiscardCacheError { + return nil, fmt.Errorf("failed to store CRL to cache: %w", err) + } + } + + return bundle, nil +} + +// fetch downloads the CRL from the given URL. +func (f *HTTPFetcher) fetch(ctx context.Context, url string) (*Bundle, error) { + // fetch base CRL + base, err := fetchCRL(ctx, url, f.httpClient) + if err != nil { + return nil, err + } + + // check delta CRL + // TODO: support delta CRL https://github.com/notaryproject/notation-core-go/issues/228 + for _, ext := range base.Extensions { + if ext.Id.Equal(oidFreshestCRL) { + return nil, errors.New("delta CRL is not supported") + } + } + + return &Bundle{ + BaseCRL: base, + }, nil +} + +func fetchCRL(ctx context.Context, crlURL string, client *http.Client) (*x509.RevocationList, error) { + // validate URL + parsedURL, err := url.Parse(crlURL) + if err != nil { + return nil, fmt.Errorf("invalid CRL URL: %w", err) + } + if parsedURL.Scheme != "http" { + return nil, fmt.Errorf("unsupported scheme: %s. Only supports CRL URL in HTTP protocol", parsedURL.Scheme) + } + + // download CRL + req, err := http.NewRequestWithContext(ctx, http.MethodGet, crlURL, nil) + if err != nil { + return nil, fmt.Errorf("failed to create CRL request: %w", err) + } + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("request failed: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("failed to download with status code: %d", resp.StatusCode) + } + // read with size limit + data, err := io.ReadAll(io.LimitReader(resp.Body, maxCRLSize)) + if err != nil { + return nil, fmt.Errorf("failed to read CRL response: %w", err) + } + if len(data) == maxCRLSize { + return nil, fmt.Errorf("CRL size exceeds the limit: %d", maxCRLSize) + } + + // parse CRL + return x509.ParseRevocationList(data) +} diff --git a/revocation/crl/fetcher_test.go b/revocation/crl/fetcher_test.go new file mode 100644 index 00000000..9b22f97e --- /dev/null +++ b/revocation/crl/fetcher_test.go @@ -0,0 +1,448 @@ +// Copyright The Notary Project Authors. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package crl + +import ( + "bytes" + "context" + "crypto/rand" + "crypto/x509" + "crypto/x509/pkix" + "errors" + "fmt" + "io" + "math/big" + "net/http" + "strings" + "sync" + "testing" + "time" + + "github.com/notaryproject/notation-core-go/testhelper" +) + +func TestNewHTTPFetcher(t *testing.T) { + t.Run("httpClient is nil", func(t *testing.T) { + _, err := NewHTTPFetcher(nil) + if err.Error() != "httpClient cannot be nil" { + t.Errorf("NewHTTPFetcher() error = %v, want %v", err, "httpClient cannot be nil") + } + }) +} + +func TestFetch(t *testing.T) { + // prepare crl + certChain := testhelper.GetRevokableRSAChainWithRevocations(2, false, true) + crlBytes, err := x509.CreateRevocationList(rand.Reader, &x509.RevocationList{ + Number: big.NewInt(1), + NextUpdate: time.Now().Add(1 * time.Hour), + }, certChain[1].Cert, certChain[1].PrivateKey) + if err != nil { + t.Fatalf("failed to create base CRL: %v", err) + } + baseCRL, err := x509.ParseRevocationList(crlBytes) + if err != nil { + t.Fatalf("failed to parse base CRL: %v", err) + } + const exampleURL = "http://example.com" + const uncachedURL = "http://uncached.com" + + bundle := &Bundle{ + BaseCRL: baseCRL, + } + + t.Run("url is empty", func(t *testing.T) { + c := &memoryCache{} + httpClient := &http.Client{} + f, err := NewHTTPFetcher(httpClient) + if err != nil { + t.Errorf("NewHTTPFetcher() error = %v, want nil", err) + } + f.Cache = c + _, err = f.Fetch(context.Background(), "") + if err.Error() != "CRL URL cannot be empty" { + t.Fatalf("Fetcher.Fetch() error = %v, want CRL URL cannot be empty", err) + } + }) + + t.Run("fetch without cache", func(t *testing.T) { + httpClient := &http.Client{ + Transport: expectedRoundTripperMock{Body: baseCRL.Raw}, + } + f, err := NewHTTPFetcher(httpClient) + if err != nil { + t.Errorf("NewHTTPFetcher() error = %v, want nil", err) + } + bundle, err := f.Fetch(context.Background(), exampleURL) + if err != nil { + t.Errorf("Fetcher.Fetch() error = %v, want nil", err) + } + if !bytes.Equal(bundle.BaseCRL.Raw, baseCRL.Raw) { + t.Errorf("Fetcher.Fetch() base.Raw = %v, want %v", bundle.BaseCRL.Raw, baseCRL.Raw) + } + }) + + t.Run("cache hit", func(t *testing.T) { + // set the cache + c := &memoryCache{} + if err := c.Set(context.Background(), exampleURL, bundle); err != nil { + t.Errorf("Cache.Set() error = %v, want nil", err) + } + + httpClient := &http.Client{} + f, err := NewHTTPFetcher(httpClient) + if err != nil { + t.Errorf("NewHTTPFetcher() error = %v, want nil", err) + } + f.Cache = c + bundle, err := f.Fetch(context.Background(), exampleURL) + if err != nil { + t.Errorf("Fetcher.Fetch() error = %v, want nil", err) + } + if !bytes.Equal(bundle.BaseCRL.Raw, baseCRL.Raw) { + t.Errorf("Fetcher.Fetch() base.Raw = %v, want %v", bundle.BaseCRL.Raw, baseCRL.Raw) + } + }) + + t.Run("cache miss and download failed error", func(t *testing.T) { + c := &memoryCache{} + httpClient := &http.Client{ + Transport: errorRoundTripperMock{}, + } + f, err := NewHTTPFetcher(httpClient) + f.Cache = c + if err != nil { + t.Errorf("NewHTTPFetcher() error = %v, want nil", err) + } + _, err = f.Fetch(context.Background(), uncachedURL) + if err == nil { + t.Errorf("Fetcher.Fetch() error = nil, want not nil") + } + }) + + t.Run("cache miss", func(t *testing.T) { + c := &memoryCache{} + httpClient := &http.Client{ + Transport: expectedRoundTripperMock{Body: baseCRL.Raw}, + } + f, err := NewHTTPFetcher(httpClient) + if err != nil { + t.Errorf("NewHTTPFetcher() error = %v, want nil", err) + } + f.Cache = c + f.DiscardCacheError = false + bundle, err := f.Fetch(context.Background(), uncachedURL) + if err != nil { + t.Errorf("Fetcher.Fetch() error = %v, want nil", err) + } + if !bytes.Equal(bundle.BaseCRL.Raw, baseCRL.Raw) { + t.Errorf("Fetcher.Fetch() base.Raw = %v, want %v", bundle.BaseCRL.Raw, baseCRL.Raw) + } + }) + + t.Run("cache expired", func(t *testing.T) { + c := &memoryCache{} + // prepare an expired CRL + certChain := testhelper.GetRevokableRSAChainWithRevocations(2, false, true) + expiredCRLBytes, err := x509.CreateRevocationList(rand.Reader, &x509.RevocationList{ + Number: big.NewInt(1), + NextUpdate: time.Now().Add(-1 * time.Hour), + }, certChain[1].Cert, certChain[1].PrivateKey) + if err != nil { + t.Fatalf("failed to create base CRL: %v", err) + } + expiredCRL, err := x509.ParseRevocationList(expiredCRLBytes) + if err != nil { + t.Fatalf("failed to parse base CRL: %v", err) + } + // store the expired CRL + const expiredCRLURL = "http://example.com/expired" + bundle := &Bundle{ + BaseCRL: expiredCRL, + } + if err := c.Set(context.Background(), expiredCRLURL, bundle); err != nil { + t.Errorf("Cache.Set() error = %v, want nil", err) + } + + // fetch the expired CRL + httpClient := &http.Client{ + Transport: expectedRoundTripperMock{Body: baseCRL.Raw}, + } + f, err := NewHTTPFetcher(httpClient) + if err != nil { + t.Errorf("NewHTTPFetcher() error = %v, want nil", err) + } + f.Cache = c + f.DiscardCacheError = true + bundle, err = f.Fetch(context.Background(), expiredCRLURL) + if err != nil { + t.Errorf("Fetcher.Fetch() error = %v, want nil", err) + } + // should re-download the CRL + if !bytes.Equal(bundle.BaseCRL.Raw, baseCRL.Raw) { + t.Errorf("Fetcher.Fetch() base.Raw = %v, want %v", bundle.BaseCRL.Raw, baseCRL.Raw) + } + }) + + t.Run("delta CRL is not supported", func(t *testing.T) { + c := &memoryCache{} + // prepare a CRL with refresh CRL extension + certChain := testhelper.GetRevokableRSAChainWithRevocations(2, false, true) + expiredCRLBytes, err := x509.CreateRevocationList(rand.Reader, &x509.RevocationList{ + Number: big.NewInt(1), + NextUpdate: time.Now().Add(-1 * time.Hour), + ExtraExtensions: []pkix.Extension{ + { + Id: oidFreshestCRL, + Value: []byte{0x01, 0x02, 0x03}, + }, + }, + }, certChain[1].Cert, certChain[1].PrivateKey) + if err != nil { + t.Fatalf("failed to create base CRL: %v", err) + } + + httpClient := &http.Client{ + Transport: expectedRoundTripperMock{Body: expiredCRLBytes}, + } + f, err := NewHTTPFetcher(httpClient) + if err != nil { + t.Errorf("NewHTTPFetcher() error = %v, want nil", err) + } + f.Cache = c + f.DiscardCacheError = true + _, err = f.Fetch(context.Background(), uncachedURL) + if !strings.Contains(err.Error(), "delta CRL is not supported") { + t.Errorf("Fetcher.Fetch() error = %v, want delta CRL is not supported", err) + } + }) + + t.Run("Set cache error", func(t *testing.T) { + c := &errorCache{ + GetError: ErrCacheMiss, + SetError: errors.New("cache error"), + } + httpClient := &http.Client{ + Transport: expectedRoundTripperMock{Body: baseCRL.Raw}, + } + f, err := NewHTTPFetcher(httpClient) + if err != nil { + t.Errorf("NewHTTPFetcher() error = %v, want nil", err) + } + f.Cache = c + f.DiscardCacheError = true + bundle, err = f.Fetch(context.Background(), exampleURL) + if err != nil { + t.Errorf("Fetcher.Fetch() error = %v, want nil", err) + } + if !bytes.Equal(bundle.BaseCRL.Raw, baseCRL.Raw) { + t.Errorf("Fetcher.Fetch() base.Raw = %v, want %v", bundle.BaseCRL.Raw, baseCRL.Raw) + } + }) + + t.Run("Get error without discard", func(t *testing.T) { + c := &errorCache{ + GetError: errors.New("cache error"), + } + httpClient := &http.Client{ + Transport: expectedRoundTripperMock{Body: baseCRL.Raw}, + } + f, err := NewHTTPFetcher(httpClient) + if err != nil { + t.Errorf("NewHTTPFetcher() error = %v, want nil", err) + } + f.Cache = c + f.DiscardCacheError = false + _, err = f.Fetch(context.Background(), exampleURL) + if !strings.HasPrefix(err.Error(), "failed to retrieve CRL from cache:") { + t.Errorf("Fetcher.Fetch() error = %v, want failed to retrieve CRL from cache:", err) + } + }) + + t.Run("Set error without discard", func(t *testing.T) { + c := &errorCache{ + GetError: ErrCacheMiss, + SetError: errors.New("cache error"), + } + httpClient := &http.Client{ + Transport: expectedRoundTripperMock{Body: baseCRL.Raw}, + } + f, err := NewHTTPFetcher(httpClient) + if err != nil { + t.Errorf("NewHTTPFetcher() error = %v, want nil", err) + } + f.Cache = c + f.DiscardCacheError = false + _, err = f.Fetch(context.Background(), exampleURL) + if !strings.HasPrefix(err.Error(), "failed to store CRL to cache:") { + t.Errorf("Fetcher.Fetch() error = %v, want failed to store CRL to cache:", err) + } + }) +} + +func TestDownload(t *testing.T) { + t.Run("parse url error", func(t *testing.T) { + _, err := fetchCRL(context.Background(), ":", http.DefaultClient) + if err == nil { + t.Fatal("expected error") + } + }) + t.Run("https download", func(t *testing.T) { + _, err := fetchCRL(context.Background(), "https://example.com", http.DefaultClient) + if err == nil { + t.Fatal("expected error") + } + }) + + t.Run("http.NewRequestWithContext error", func(t *testing.T) { + var ctx context.Context = nil + _, err := fetchCRL(ctx, "http://example.com", &http.Client{}) + if err == nil { + t.Fatal("expected error") + } + }) + + t.Run("client.Do error", func(t *testing.T) { + _, err := fetchCRL(context.Background(), "http://example.com", &http.Client{ + Transport: errorRoundTripperMock{}, + }) + + if err == nil { + t.Fatal("expected error") + } + }) + + t.Run("status code is not 2xx", func(t *testing.T) { + _, err := fetchCRL(context.Background(), "http://example.com", &http.Client{ + Transport: serverErrorRoundTripperMock{}, + }) + if err == nil { + t.Fatal("expected error") + } + }) + + t.Run("readAll error", func(t *testing.T) { + _, err := fetchCRL(context.Background(), "http://example.com", &http.Client{ + Transport: readFailedRoundTripperMock{}, + }) + if err == nil { + t.Fatal("expected error") + } + }) + + t.Run("exceed the size limit", func(t *testing.T) { + _, err := fetchCRL(context.Background(), "http://example.com", &http.Client{ + Transport: expectedRoundTripperMock{Body: make([]byte, maxCRLSize+1)}, + }) + if err == nil { + t.Fatal("expected error") + } + }) + + t.Run("invalid crl", func(t *testing.T) { + _, err := fetchCRL(context.Background(), "http://example.com", &http.Client{ + Transport: expectedRoundTripperMock{Body: []byte("invalid crl")}, + }) + if err == nil { + t.Fatal("expected error") + } + }) +} + +type errorRoundTripperMock struct{} + +func (rt errorRoundTripperMock) RoundTrip(req *http.Request) (*http.Response, error) { + return nil, fmt.Errorf("error") +} + +type serverErrorRoundTripperMock struct{} + +func (rt serverErrorRoundTripperMock) RoundTrip(req *http.Request) (*http.Response, error) { + return &http.Response{ + Request: req, + StatusCode: http.StatusInternalServerError, + }, nil +} + +type readFailedRoundTripperMock struct{} + +func (rt readFailedRoundTripperMock) RoundTrip(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Body: errorReaderMock{}, + }, nil +} + +type errorReaderMock struct{} + +func (r errorReaderMock) Read(p []byte) (n int, err error) { + return 0, fmt.Errorf("error") +} + +func (r errorReaderMock) Close() error { + return nil +} + +type expectedRoundTripperMock struct { + Body []byte +} + +func (rt expectedRoundTripperMock) RoundTrip(req *http.Request) (*http.Response, error) { + return &http.Response{ + Request: req, + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBuffer(rt.Body)), + }, nil +} + +// memoryCache is an in-memory cache that stores CRL bundles for testing. +type memoryCache struct { + store sync.Map +} + +// Get retrieves the CRL from the memory store. +// +// - if the key does not exist, return ErrNotFound +// - if the CRL is expired, return ErrCacheMiss +func (c *memoryCache) Get(ctx context.Context, url string) (*Bundle, error) { + value, ok := c.store.Load(url) + if !ok { + return nil, ErrCacheMiss + } + bundle, ok := value.(*Bundle) + if !ok { + return nil, fmt.Errorf("invalid type: %T", value) + } + + return bundle, nil +} + +// Set stores the CRL in the memory store. +func (c *memoryCache) Set(ctx context.Context, url string, bundle *Bundle) error { + c.store.Store(url, bundle) + return nil +} + +type errorCache struct { + GetError error + SetError error +} + +func (c *errorCache) Get(ctx context.Context, url string) (*Bundle, error) { + return nil, c.GetError +} + +func (c *errorCache) Set(ctx context.Context, url string, bundle *Bundle) error { + return c.SetError +} diff --git a/revocation/internal/crl/crl.go b/revocation/internal/crl/crl.go index 48a5930a..50f2c085 100644 --- a/revocation/internal/crl/crl.go +++ b/revocation/internal/crl/crl.go @@ -21,11 +21,9 @@ import ( "encoding/asn1" "errors" "fmt" - "io" - "net/http" - "net/url" "time" + "github.com/notaryproject/notation-core-go/revocation/crl" "github.com/notaryproject/notation-core-go/revocation/result" ) @@ -43,18 +41,13 @@ var ( oidInvalidityDate = asn1.ObjectIdentifier{2, 5, 29, 24} ) -// maxCRLSize is the maximum size of CRL in bytes -// -// CRL examples: https://chasersystems.com/blog/an-analysis-of-certificate-revocation-list-sizes/ -const maxCRLSize = 32 * 1024 * 1024 // 32 MiB - -// CertCheckStatusOptions specifies values that are needed to check CRL +// CertCheckStatusOptions specifies values that are needed to check CRL. type CertCheckStatusOptions struct { - // HTTPClient is the HTTP client used to download CRL - HTTPClient *http.Client + // Fetcher is used to fetch the CRL from the CRL distribution points. + Fetcher crl.Fetcher // SigningTime is used to compare with the invalidity date during revocation - // check + // check. SigningTime time.Time } @@ -73,12 +66,31 @@ func CertCheckStatus(ctx context.Context, cert, issuer *x509.Certificate, opts C Result: result.ResultNonRevokable, ServerResults: []*result.ServerResult{{ RevocationMethod: result.RevocationMethodCRL, + Error: errors.New("CRL is not supported"), Result: result.ResultNonRevokable, }}, RevocationMethod: result.RevocationMethodCRL, } } + if opts.Fetcher == nil { + return &result.CertRevocationResult{ + Result: result.ResultUnknown, + ServerResults: []*result.ServerResult{{ + RevocationMethod: result.RevocationMethodCRL, + Error: errors.New("CRL fetcher cannot be nil"), + Result: result.ResultUnknown, + }}, + RevocationMethod: result.RevocationMethodCRL, + } + } + + var ( + serverResults = make([]*result.ServerResult, 0, len(cert.CRLDistributionPoints)) + lastErr error + crlURL string + ) + // The CRLDistributionPoints contains the URIs of all the CRL distribution // points. Since it does not distinguish the reason field, it needs to check // all the URIs to avoid missing any partial CRLs. @@ -86,28 +98,24 @@ func CertCheckStatus(ctx context.Context, cert, issuer *x509.Certificate, opts C // For the majority of the certificates, there is only one CRL distribution // point with one CRL URI, which will be cached, so checking all the URIs is // not a performance issue. - var ( - serverResults = make([]*result.ServerResult, 0, len(cert.CRLDistributionPoints)) - lastErr error - crlURL string - ) for _, crlURL = range cert.CRLDistributionPoints { - baseCRL, err := download(ctx, crlURL, opts.HTTPClient) + bundle, err := opts.Fetcher.Fetch(ctx, crlURL) if err != nil { lastErr = fmt.Errorf("failed to download CRL from %s: %w", crlURL, err) break } - if err = validate(baseCRL, issuer); err != nil { + if err = validate(bundle.BaseCRL, issuer); err != nil { lastErr = fmt.Errorf("failed to validate CRL from %s: %w", crlURL, err) break } - crlResult, err := checkRevocation(cert, baseCRL, opts.SigningTime, crlURL) + crlResult, err := checkRevocation(cert, bundle.BaseCRL, opts.SigningTime, crlURL) if err != nil { lastErr = fmt.Errorf("failed to check revocation status from %s: %w", crlURL, err) break } + if crlResult.Result == result.ResultRevoked { return &result.CertRevocationResult{ Result: result.ResultRevoked, @@ -152,15 +160,18 @@ func validate(crl *x509.RevocationList, issuer *x509.Certificate) error { } // check validity + if crl.NextUpdate.IsZero() { + return errors.New("CRL NextUpdate is not set") + } now := time.Now() - if !crl.NextUpdate.IsZero() && now.After(crl.NextUpdate) { + if now.After(crl.NextUpdate) { return fmt.Errorf("expired CRL. Current time %v is after CRL NextUpdate %v", now, crl.NextUpdate) } for _, ext := range crl.Extensions { switch { case ext.Id.Equal(oidFreshestCRL): - return ErrDeltaCRLNotSupported + return errors.New("delta CRL is not supported") case ext.Id.Equal(oidIssuingDistributionPoint): // IssuingDistributionPoint is a critical extension that identifies // the scope of the CRL. Since we will check all the CRL @@ -247,42 +258,3 @@ func parseEntryExtensions(entry x509.RevocationListEntry) (entryExtensions, erro return extensions, nil } - -func download(ctx context.Context, crlURL string, client *http.Client) (*x509.RevocationList, error) { - // validate URL - parsedURL, err := url.Parse(crlURL) - if err != nil { - return nil, fmt.Errorf("invalid CRL URL: %w", err) - } - if parsedURL.Scheme != "http" { - return nil, fmt.Errorf("unsupported CRL endpoint: %s. Only urls with HTTP scheme is supported", crlURL) - } - - // download CRL - req, err := http.NewRequestWithContext(ctx, http.MethodGet, crlURL, nil) - if err != nil { - return nil, fmt.Errorf("failed to create CRL request %q: %w", crlURL, err) - } - resp, err := client.Do(req) - if err != nil { - return nil, fmt.Errorf("request failed for %q: %w", crlURL, err) - } - defer resp.Body.Close() - - // check response - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return nil, fmt.Errorf("%s %q: failed to download with status code: %d", resp.Request.Method, resp.Request.URL, resp.StatusCode) - } - - // read with size limit - limitedReader := io.LimitReader(resp.Body, maxCRLSize) - data, err := io.ReadAll(limitedReader) - if err != nil { - return nil, fmt.Errorf("failed to read CRL response from %q: %w", resp.Request.URL, err) - } - if len(data) == maxCRLSize { - return nil, fmt.Errorf("%s %q: CRL size reached the %d MiB size limit", resp.Request.Method, resp.Request.URL, maxCRLSize/1024/1024) - } - - return x509.ParseRevocationList(data) -} diff --git a/revocation/internal/crl/crl_test.go b/revocation/internal/crl/crl_test.go index 2129fb79..5555374f 100644 --- a/revocation/internal/crl/crl_test.go +++ b/revocation/internal/crl/crl_test.go @@ -20,49 +20,78 @@ import ( "crypto/x509" "crypto/x509/pkix" "encoding/asn1" - "errors" "fmt" "io" "math/big" "net/http" + "strings" + "sync" "testing" "time" + crlutils "github.com/notaryproject/notation-core-go/revocation/crl" "github.com/notaryproject/notation-core-go/revocation/result" "github.com/notaryproject/notation-core-go/testhelper" ) func TestCertCheckStatus(t *testing.T) { - t.Run("certtificate does not have CRLDistributionPoints", func(t *testing.T) { + t.Run("certificate does not have CRLDistributionPoints", func(t *testing.T) { cert := &x509.Certificate{} r := CertCheckStatus(context.Background(), cert, &x509.Certificate{}, CertCheckStatusOptions{}) - if r.Result != result.ResultNonRevokable { - t.Fatalf("expected NonRevokable, got %s", r.Result) + if r.ServerResults[0].Error.Error() != "CRL is not supported" { + t.Fatalf("expected CRL is not supported, got %v", r.ServerResults[0].Error) + } + }) + + t.Run("fetcher is nil", func(t *testing.T) { + cert := &x509.Certificate{ + CRLDistributionPoints: []string{"http://example.com"}, + } + r := CertCheckStatus(context.Background(), cert, &x509.Certificate{}, CertCheckStatusOptions{}) + if r.ServerResults[0].Error.Error() != "CRL fetcher cannot be nil" { + t.Fatalf("expected CRL fetcher cannot be nil, got %v", r.ServerResults[0].Error) } }) t.Run("download error", func(t *testing.T) { + memoryCache := &memoryCache{} + cert := &x509.Certificate{ CRLDistributionPoints: []string{"http://example.com"}, } + fetcher, err := crlutils.NewHTTPFetcher( + &http.Client{Transport: errorRoundTripperMock{}}, + ) + if err != nil { + t.Fatal(err) + } + fetcher.Cache = memoryCache + r := CertCheckStatus(context.Background(), cert, &x509.Certificate{}, CertCheckStatusOptions{ - HTTPClient: &http.Client{ - Transport: errorRoundTripperMock{}, - }, + Fetcher: fetcher, }) + if r.ServerResults[0].Error == nil { t.Fatal("expected error") } }) t.Run("CRL validate failed", func(t *testing.T) { + memoryCache := &memoryCache{} + cert := &x509.Certificate{ CRLDistributionPoints: []string{"http://example.com"}, } + fetcher, err := crlutils.NewHTTPFetcher( + &http.Client{Transport: expiredCRLRoundTripperMock{}}, + ) + if err != nil { + t.Fatal(err) + } + fetcher.Cache = memoryCache + r := CertCheckStatus(context.Background(), cert, &x509.Certificate{}, CertCheckStatusOptions{ - HTTPClient: &http.Client{ - Transport: expiredCRLRoundTripperMock{}, - }, + Fetcher: fetcher, }) if r.ServerResults[0].Error == nil { t.Fatal("expected error") @@ -75,6 +104,8 @@ func TestCertCheckStatus(t *testing.T) { issuerKey := chain[1].PrivateKey t.Run("revoked", func(t *testing.T) { + memoryCache := &memoryCache{} + crlBytes, err := x509.CreateRevocationList(rand.Reader, &x509.RevocationList{ NextUpdate: time.Now().Add(time.Hour), Number: big.NewInt(20240720), @@ -89,10 +120,16 @@ func TestCertCheckStatus(t *testing.T) { t.Fatal(err) } + fetcher, err := crlutils.NewHTTPFetcher( + &http.Client{Transport: expectedRoundTripperMock{Body: crlBytes}}, + ) + if err != nil { + t.Fatal(err) + } + fetcher.Cache = memoryCache + fetcher.DiscardCacheError = true r := CertCheckStatus(context.Background(), chain[0].Cert, issuerCert, CertCheckStatusOptions{ - HTTPClient: &http.Client{ - Transport: expectedRoundTripperMock{Body: crlBytes}, - }, + Fetcher: fetcher, }) if r.Result != result.ResultRevoked { t.Fatalf("expected revoked, got %s", r.Result) @@ -100,6 +137,8 @@ func TestCertCheckStatus(t *testing.T) { }) t.Run("unknown critical extension", func(t *testing.T) { + memoryCache := &memoryCache{} + crlBytes, err := x509.CreateRevocationList(rand.Reader, &x509.RevocationList{ NextUpdate: time.Now().Add(time.Hour), Number: big.NewInt(20240720), @@ -120,10 +159,17 @@ func TestCertCheckStatus(t *testing.T) { t.Fatal(err) } + fetcher, err := crlutils.NewHTTPFetcher( + &http.Client{Transport: expectedRoundTripperMock{Body: crlBytes}}, + ) + if err != nil { + t.Fatal(err) + } + + fetcher.Cache = memoryCache + fetcher.DiscardCacheError = true r := CertCheckStatus(context.Background(), chain[0].Cert, issuerCert, CertCheckStatusOptions{ - HTTPClient: &http.Client{ - Transport: expectedRoundTripperMock{Body: crlBytes}, - }, + Fetcher: fetcher, }) if r.ServerResults[0].Error == nil { t.Fatal("expected error") @@ -131,6 +177,8 @@ func TestCertCheckStatus(t *testing.T) { }) t.Run("Not revoked", func(t *testing.T) { + memoryCache := &memoryCache{} + crlBytes, err := x509.CreateRevocationList(rand.Reader, &x509.RevocationList{ NextUpdate: time.Now().Add(time.Hour), Number: big.NewInt(20240720), @@ -139,10 +187,16 @@ func TestCertCheckStatus(t *testing.T) { t.Fatal(err) } + fetcher, err := crlutils.NewHTTPFetcher( + &http.Client{Transport: expectedRoundTripperMock{Body: crlBytes}}, + ) + if err != nil { + t.Fatal(err) + } + fetcher.Cache = memoryCache + fetcher.DiscardCacheError = true r := CertCheckStatus(context.Background(), chain[0].Cert, issuerCert, CertCheckStatusOptions{ - HTTPClient: &http.Client{ - Transport: expectedRoundTripperMock{Body: crlBytes}, - }, + Fetcher: fetcher, }) if r.Result != result.ResultOK { t.Fatalf("expected OK, got %s", r.Result) @@ -150,6 +204,8 @@ func TestCertCheckStatus(t *testing.T) { }) t.Run("CRL with delta CRL is not checked", func(t *testing.T) { + memoryCache := &memoryCache{} + crlBytes, err := x509.CreateRevocationList(rand.Reader, &x509.RevocationList{ NextUpdate: time.Now().Add(time.Hour), Number: big.NewInt(20240720), @@ -163,14 +219,113 @@ func TestCertCheckStatus(t *testing.T) { if err != nil { t.Fatal(err) } + fetcher, err := crlutils.NewHTTPFetcher( + &http.Client{Transport: expectedRoundTripperMock{Body: crlBytes}}, + ) + if err != nil { + t.Fatal(err) + } + fetcher.Cache = memoryCache + fetcher.DiscardCacheError = true + r := CertCheckStatus(context.Background(), chain[0].Cert, issuerCert, CertCheckStatusOptions{ + Fetcher: fetcher, + }) + if !strings.Contains(r.ServerResults[0].Error.Error(), "delta CRL is not supported") { + t.Fatalf("unexpected error, got %v, expected %v", r.ServerResults[0].Error, "delta CRL is not supported") + } + }) + + memoryCache := &memoryCache{} + // create a stale CRL + crlBytes, err := x509.CreateRevocationList(rand.Reader, &x509.RevocationList{ + NextUpdate: time.Now().Add(-time.Hour), + Number: big.NewInt(20240720), + }, issuerCert, issuerKey) + if err != nil { + t.Fatal(err) + } + base, err := x509.ParseRevocationList(crlBytes) + if err != nil { + t.Fatal(err) + } + bundle := &crlutils.Bundle{ + BaseCRL: base, + } + + chain[0].Cert.CRLDistributionPoints = []string{"http://example.com"} + + t.Run("invalid stale CRL cache, and re-download failed", func(t *testing.T) { + // save to cache + if err := memoryCache.Set(context.Background(), "http://example.com", bundle); err != nil { + t.Fatal(err) + } + + fetcher, err := crlutils.NewHTTPFetcher( + &http.Client{Transport: errorRoundTripperMock{}}, + ) + if err != nil { + t.Fatal(err) + } + fetcher.Cache = memoryCache + fetcher.DiscardCacheError = true r := CertCheckStatus(context.Background(), chain[0].Cert, issuerCert, CertCheckStatusOptions{ - HTTPClient: &http.Client{ - Transport: expectedRoundTripperMock{Body: crlBytes}, - }, + Fetcher: fetcher, }) - if !errors.Is(r.ServerResults[0].Error, ErrDeltaCRLNotSupported) { - t.Fatal("expected ErrDeltaCRLNotChecked") + if !strings.HasPrefix(r.ServerResults[0].Error.Error(), "failed to download CRL from") { + t.Fatalf("unexpected error, got %v", r.ServerResults[0].Error) + } + }) + + t.Run("invalid stale CRL cache, re-download and still validate failed", func(t *testing.T) { + // save to cache + if err := memoryCache.Set(context.Background(), "http://example.com", bundle); err != nil { + t.Fatal(err) + } + + fetcher, err := crlutils.NewHTTPFetcher( + &http.Client{Transport: expectedRoundTripperMock{Body: crlBytes}}, + ) + if err != nil { + t.Fatal(err) + } + fetcher.Cache = memoryCache + fetcher.DiscardCacheError = true + r := CertCheckStatus(context.Background(), chain[0].Cert, issuerCert, CertCheckStatusOptions{ + Fetcher: fetcher, + }) + if !strings.HasPrefix(r.ServerResults[0].Error.Error(), "failed to validate CRL from") { + t.Fatalf("unexpected error, got %v", r.ServerResults[0].Error) + } + }) + + t.Run("invalid stale CRL cache, re-download and validate seccessfully", func(t *testing.T) { + // save to cache + if err := memoryCache.Set(context.Background(), "http://example.com", bundle); err != nil { + t.Fatal(err) + } + + crlBytes, err := x509.CreateRevocationList(rand.Reader, &x509.RevocationList{ + NextUpdate: time.Now().Add(time.Hour), + Number: big.NewInt(20240720), + }, issuerCert, issuerKey) + if err != nil { + t.Fatal(err) + } + + fetcher, err := crlutils.NewHTTPFetcher( + &http.Client{Transport: expectedRoundTripperMock{Body: crlBytes}}, + ) + if err != nil { + t.Fatal(err) + } + fetcher.Cache = memoryCache + fetcher.DiscardCacheError = true + r := CertCheckStatus(context.Background(), chain[0].Cert, issuerCert, CertCheckStatusOptions{ + Fetcher: fetcher, + }) + if r.Result != result.ResultOK { + t.Fatalf("expected OK, got %s", r.Result) } }) } @@ -268,6 +423,35 @@ func TestValidate(t *testing.T) { t.Fatal(err) } }) + + t.Run("delta CRL is not supported", func(t *testing.T) { + chain := testhelper.GetRevokableRSAChainWithRevocations(1, false, true) + issuerCert := chain[0].Cert + issuerKey := chain[0].PrivateKey + + crlBytes, err := x509.CreateRevocationList(rand.Reader, &x509.RevocationList{ + NextUpdate: time.Now().Add(time.Hour), + Number: big.NewInt(20240720), + ExtraExtensions: []pkix.Extension{ + { + Id: oidFreshestCRL, + Critical: false, + }, + }, + }, issuerCert, issuerKey) + if err != nil { + t.Fatal(err) + } + + crl, err := x509.ParseRevocationList(crlBytes) + if err != nil { + t.Fatal(err) + } + + if err := validate(crl, issuerCert); err.Error() != "delta CRL is not supported" { + t.Fatalf("got %v, expected delta CRL is not supported", err) + } + }) } func TestCheckRevocation(t *testing.T) { @@ -535,66 +719,6 @@ func marshalGeneralizedTimeToBytes(t time.Time) ([]byte, error) { return asn1.Marshal(t) } -func TestDownload(t *testing.T) { - t.Run("parse url error", func(t *testing.T) { - _, err := download(context.Background(), ":", http.DefaultClient) - if err == nil { - t.Fatal("expected error") - } - }) - t.Run("https download", func(t *testing.T) { - _, err := download(context.Background(), "https://example.com", http.DefaultClient) - if err == nil { - t.Fatal("expected error") - } - }) - - t.Run("http.NewRequestWithContext error", func(t *testing.T) { - var ctx context.Context = nil - _, err := download(ctx, "http://example.com", &http.Client{}) - if err == nil { - t.Fatal("expected error") - } - }) - - t.Run("client.Do error", func(t *testing.T) { - _, err := download(context.Background(), "http://example.com", &http.Client{ - Transport: errorRoundTripperMock{}, - }) - - if err == nil { - t.Fatal("expected error") - } - }) - - t.Run("status code is not 2xx", func(t *testing.T) { - _, err := download(context.Background(), "http://example.com", &http.Client{ - Transport: serverErrorRoundTripperMock{}, - }) - if err == nil { - t.Fatal("expected error") - } - }) - - t.Run("readAll error", func(t *testing.T) { - _, err := download(context.Background(), "http://example.com", &http.Client{ - Transport: readFailedRoundTripperMock{}, - }) - if err == nil { - t.Fatal("expected error") - } - }) - - t.Run("exceed the size limit", func(t *testing.T) { - _, err := download(context.Background(), "http://example.com", &http.Client{ - Transport: expectedRoundTripperMock{Body: make([]byte, maxCRLSize+1)}, - }) - if err == nil { - t.Fatal("expected error") - } - }) -} - func TestSupported(t *testing.T) { t.Run("supported", func(t *testing.T) { cert := &x509.Certificate{ @@ -619,28 +743,6 @@ func (rt errorRoundTripperMock) RoundTrip(req *http.Request) (*http.Response, er return nil, fmt.Errorf("error") } -type serverErrorRoundTripperMock struct{} - -func (rt serverErrorRoundTripperMock) RoundTrip(req *http.Request) (*http.Response, error) { - return &http.Response{ - Request: req, - StatusCode: http.StatusInternalServerError, - }, nil -} - -type readFailedRoundTripperMock struct{} - -func (rt readFailedRoundTripperMock) RoundTrip(req *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: http.StatusOK, - Body: errorReaderMock{}, - Request: &http.Request{ - Method: http.MethodGet, - URL: req.URL, - }, - }, nil -} - type expiredCRLRoundTripperMock struct{} func (rt expiredCRLRoundTripperMock) RoundTrip(req *http.Request) (*http.Response, error) { @@ -662,16 +764,6 @@ func (rt expiredCRLRoundTripperMock) RoundTrip(req *http.Request) (*http.Respons }, nil } -type errorReaderMock struct{} - -func (r errorReaderMock) Read(p []byte) (n int, err error) { - return 0, fmt.Errorf("error") -} - -func (r errorReaderMock) Close() error { - return nil -} - type expectedRoundTripperMock struct { Body []byte } @@ -683,3 +775,31 @@ func (rt expectedRoundTripperMock) RoundTrip(req *http.Request) (*http.Response, Body: io.NopCloser(bytes.NewBuffer(rt.Body)), }, nil } + +// memoryCache is an in-memory cache that stores CRL bundles for testing. +type memoryCache struct { + store sync.Map +} + +// Get retrieves the CRL from the memory store. +// +// - if the key does not exist, return ErrNotFound +// - if the CRL is expired, return ErrCacheMiss +func (c *memoryCache) Get(ctx context.Context, url string) (*crlutils.Bundle, error) { + value, ok := c.store.Load(url) + if !ok { + return nil, crlutils.ErrCacheMiss + } + bundle, ok := value.(*crlutils.Bundle) + if !ok { + return nil, fmt.Errorf("invalid type: %T", value) + } + + return bundle, nil +} + +// Set stores the CRL in the memory store. +func (c *memoryCache) Set(ctx context.Context, url string, bundle *crlutils.Bundle) error { + c.store.Store(url, bundle) + return nil +} diff --git a/revocation/ocsp/error.go b/revocation/ocsp/errors.go similarity index 100% rename from revocation/ocsp/error.go rename to revocation/ocsp/errors.go diff --git a/revocation/revocation.go b/revocation/revocation.go index a20c63c1..f249d915 100644 --- a/revocation/revocation.go +++ b/revocation/revocation.go @@ -24,6 +24,7 @@ import ( "sync" "time" + crlutil "github.com/notaryproject/notation-core-go/revocation/crl" "github.com/notaryproject/notation-core-go/revocation/internal/crl" "github.com/notaryproject/notation-core-go/revocation/internal/ocsp" "github.com/notaryproject/notation-core-go/revocation/internal/x509util" @@ -69,7 +70,7 @@ type Validator interface { // revocation is an internal struct used for revocation checking type revocation struct { ocspHTTPClient *http.Client - crlHTTPClient *http.Client + crlFetcher crlutil.Fetcher certChainPurpose purpose.Purpose } @@ -81,9 +82,14 @@ func New(httpClient *http.Client) (Revocation, error) { if httpClient == nil { return nil, errors.New("invalid input: a non-nil httpClient must be specified") } + fetcher, err := crlutil.NewHTTPFetcher(httpClient) + if err != nil { + return nil, err + } + return &revocation{ ocspHTTPClient: httpClient, - crlHTTPClient: httpClient, + crlFetcher: fetcher, certChainPurpose: purpose.CodeSigning, }, nil } @@ -95,10 +101,10 @@ type Options struct { // OPTIONAL. OCSPHTTPClient *http.Client - // CRLHTTPClient is the HTTP client for CRL request. If not provided, - // a default *http.Client with timeout of 5 seconds will be used. - // OPTIONAL. - CRLHTTPClient *http.Client + // CRLFetcher is a fetcher for CRL with cache. If not provided, a default + // fetcher with an HTTP client and a timeout of 5 seconds will be used + // without cache. + CRLFetcher crlutil.Fetcher // CertChainPurpose is the purpose of the certificate chain. Supported // values are CodeSigning and Timestamping. Default value is CodeSigning. @@ -112,8 +118,13 @@ func NewWithOptions(opts Options) (Validator, error) { opts.OCSPHTTPClient = &http.Client{Timeout: 2 * time.Second} } - if opts.CRLHTTPClient == nil { - opts.CRLHTTPClient = &http.Client{Timeout: 5 * time.Second} + fetcher := opts.CRLFetcher + if fetcher == nil { + newFetcher, err := crlutil.NewHTTPFetcher(&http.Client{Timeout: 5 * time.Second}) + if err != nil { + return nil, err + } + fetcher = newFetcher } switch opts.CertChainPurpose { @@ -124,7 +135,7 @@ func NewWithOptions(opts Options) (Validator, error) { return &revocation{ ocspHTTPClient: opts.OCSPHTTPClient, - crlHTTPClient: opts.CRLHTTPClient, + crlFetcher: fetcher, certChainPurpose: opts.CertChainPurpose, }, nil } @@ -170,8 +181,9 @@ func (r *revocation) ValidateContext(ctx context.Context, validateContextOpts Va HTTPClient: r.ocspHTTPClient, SigningTime: validateContextOpts.AuthenticSigningTime, } + crlOpts := crl.CertCheckStatusOptions{ - HTTPClient: r.crlHTTPClient, + Fetcher: r.crlFetcher, SigningTime: validateContextOpts.AuthenticSigningTime, } diff --git a/revocation/revocation_test.go b/revocation/revocation_test.go index 2ac8b4c9..00e9597d 100644 --- a/revocation/revocation_test.go +++ b/revocation/revocation_test.go @@ -28,6 +28,7 @@ import ( "testing" "time" + "github.com/notaryproject/notation-core-go/revocation/crl" revocationocsp "github.com/notaryproject/notation-core-go/revocation/internal/ocsp" "github.com/notaryproject/notation-core-go/revocation/purpose" "github.com/notaryproject/notation-core-go/revocation/result" @@ -1035,15 +1036,20 @@ func TestCRL(t *testing.T) { t.Run("CRL check valid", func(t *testing.T) { chain := testhelper.GetRevokableRSAChainWithRevocations(3, false, true) - revocationClient, err := NewWithOptions(Options{ - CRLHTTPClient: &http.Client{ - Timeout: 5 * time.Second, - Transport: &crlRoundTripper{ - CertChain: chain, - Revoked: false, - }, + fetcher, err := crl.NewHTTPFetcher(&http.Client{ + Timeout: 5 * time.Second, + Transport: &crlRoundTripper{ + CertChain: chain, + Revoked: false, }, + }) + if err != nil { + t.Errorf("Expected successful creation of fetcher, but received error: %v", err) + } + + revocationClient, err := NewWithOptions(Options{ OCSPHTTPClient: &http.Client{}, + CRLFetcher: fetcher, CertChainPurpose: purpose.CodeSigning, }) if err != nil { @@ -1084,15 +1090,20 @@ func TestCRL(t *testing.T) { t.Run("CRL check with revoked status", func(t *testing.T) { chain := testhelper.GetRevokableRSAChainWithRevocations(3, false, true) - revocationClient, err := NewWithOptions(Options{ - CRLHTTPClient: &http.Client{ - Timeout: 5 * time.Second, - Transport: &crlRoundTripper{ - CertChain: chain, - Revoked: true, - }, + fetcher, err := crl.NewHTTPFetcher(&http.Client{ + Timeout: 5 * time.Second, + Transport: &crlRoundTripper{ + CertChain: chain, + Revoked: true, }, + }) + if err != nil { + t.Errorf("Expected successful creation of fetcher, but received error: %v", err) + } + + revocationClient, err := NewWithOptions(Options{ OCSPHTTPClient: &http.Client{}, + CRLFetcher: fetcher, CertChainPurpose: purpose.CodeSigning, }) if err != nil { @@ -1140,17 +1151,21 @@ func TestCRL(t *testing.T) { t.Run("OCSP fallback to CRL", func(t *testing.T) { chain := testhelper.GetRevokableRSAChainWithRevocations(3, true, true) + fetcher, err := crl.NewHTTPFetcher(&http.Client{ + Timeout: 5 * time.Second, + Transport: &crlRoundTripper{ + CertChain: chain, + Revoked: true, + FailOCSP: true, + }, + }) + if err != nil { + t.Errorf("Expected successful creation of fetcher, but received error: %v", err) + } revocationClient, err := NewWithOptions(Options{ - CRLHTTPClient: &http.Client{ - Timeout: 5 * time.Second, - Transport: &crlRoundTripper{ - CertChain: chain, - Revoked: true, - FailOCSP: true, - }, - }, OCSPHTTPClient: &http.Client{}, + CRLFetcher: fetcher, CertChainPurpose: purpose.CodeSigning, }) if err != nil { @@ -1218,9 +1233,14 @@ func TestPanicHandling(t *testing.T) { Transport: panicTransport{}, } + fetcher, err := crl.NewHTTPFetcher(client) + if err != nil { + t.Errorf("Expected successful creation of fetcher, but received error: %v", err) + } + r, err := NewWithOptions(Options{ OCSPHTTPClient: client, - CRLHTTPClient: client, + CRLFetcher: fetcher, CertChainPurpose: purpose.CodeSigning, }) if err != nil { @@ -1245,9 +1265,14 @@ func TestPanicHandling(t *testing.T) { Transport: panicTransport{}, } + fetcher, err := crl.NewHTTPFetcher(client) + if err != nil { + t.Errorf("Expected successful creation of fetcher, but received error: %v", err) + } + r, err := NewWithOptions(Options{ OCSPHTTPClient: client, - CRLHTTPClient: client, + CRLFetcher: fetcher, CertChainPurpose: purpose.CodeSigning, }) if err != nil {