diff --git a/.github/workflows/scorecards.yml b/.github/workflows/scorecards.yml index 0fc268d..367a8b0 100644 --- a/.github/workflows/scorecards.yml +++ b/.github/workflows/scorecards.yml @@ -1,3 +1,6 @@ +# Copyright 2024 Yahoo Inc. +# Licensed under the terms of the Apache License 2.0. Please see LICENSE file in project root for terms. + name: Scorecards supply-chain security on: # Only the default branch is supported. diff --git a/license_comment b/license_comment new file mode 100644 index 0000000..76ca316 --- /dev/null +++ b/license_comment @@ -0,0 +1,2 @@ +Copyright 2024 Yahoo Inc. +Licensed under the terms of the Apache License 2.0. Please see LICENSE file in project root for terms. diff --git a/utils/cert_reload.go b/utils/cert_reload.go new file mode 100644 index 0000000..941b706 --- /dev/null +++ b/utils/cert_reload.go @@ -0,0 +1,167 @@ +// Copyright 2024 Yahoo Inc. +// Licensed under the terms of the Apache License 2.0. Please see LICENSE file in project root for terms. + +package utils + +import ( + "bytes" + "crypto/subtle" + "crypto/tls" + "crypto/x509" + "encoding/pem" + "errors" + "fmt" + "log" + "sync" + "time" +) + +const defaultMemPollInterval = 60 * time.Minute + +// MemCertReloader reloads the (key, cert) pair by invoking the callback functions +// getter. +type MemCertReloader struct { + mu sync.RWMutex + getter func() ([]byte, []byte, error) + cert *tls.Certificate + + logger func(fmt string, args ...interface{}) + once sync.Once + stop chan struct{} + pollInterval time.Duration +} + +// GetCertificate returns the latest known certificate and can be assigned to the +// GetCertificate member of the TLS config. For http.server use. +func (w *MemCertReloader) GetCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) { + return w.GetLatestCertificate() +} + +// GetClientCertificate returns the latest known certificate and can be assigned to the +// GetClientCertificate member of the TLS config. For http.client use. +func (w *MemCertReloader) GetClientCertificate(*tls.CertificateRequestInfo) (*tls.Certificate, error) { + return w.GetLatestCertificate() +} + +// GetLatestCertificate returns the latest known certificate. +func (w *MemCertReloader) GetLatestCertificate() (*tls.Certificate, error) { + w.mu.RLock() + c := w.cert + w.mu.RUnlock() + return c, nil +} + +// Close stops the background refresh. +func (w *MemCertReloader) Close() error { + w.once.Do(func() { + close(w.stop) + }) + return nil +} + +// Reload reloads the certificate into the memory cache when the certificate is updated and valid. +func (w *MemCertReloader) Reload() error { + cb, kb, err := w.getter() + if err != nil { + return fmt.Errorf("failed to get the certificate and private key, %v", err) + } + + if err := ValidateCertExpiry(cb, time.Now()); err != nil { + return fmt.Errorf("failed to validate certicate, %v", err) + } + + cert, err := tls.X509KeyPair(cb, kb) + if err != nil { + return fmt.Errorf("failed to parse the certificate and private key, %v", err) + } + + if w.cert != nil { + if subtle.ConstantTimeCompare(cert.Certificate[0], w.cert.Certificate[0]) == 1 { + return nil + } + } + + w.mu.Lock() + w.cert = &cert + w.mu.Unlock() + w.logger("certs reloaded at %v", time.Now()) + return nil +} + +func (w *MemCertReloader) pollRefresh() { + poll := time.NewTicker(w.pollInterval) + defer poll.Stop() + for { + select { + case <-poll.C: + if err := w.Reload(); err != nil { + w.logger("cert reload error: %v\n", err) + } + case <-w.stop: + return + } + } +} + +// CertReloadConfig contains the config for cert reload. +type CertReloadConfig struct { + // CertKeyGetter gets the certificate and the private key. + CertKeyGetter func() ([]byte, []byte, error) + Logger func(fmt string, args ...interface{}) + PollInterval time.Duration +} + +// NewCertReloader returns a MemCertReloader that reloads the (key, cert) pair whenever +// the cert file changes on the filesystem. +func NewCertReloader(config CertReloadConfig) (*MemCertReloader, error) { + if config.Logger == nil { + config.Logger = log.Printf + } + if config.PollInterval == 0 { + config.PollInterval = defaultMemPollInterval + } + + var getter func() (cert []byte, key []byte, _ error) + + if config.CertKeyGetter == nil { + return nil, errors.New("no getter function found in the config") + } + + if config.CertKeyGetter != nil { + getter = config.CertKeyGetter + } + + r := &MemCertReloader{ + getter: getter, + logger: config.Logger, + pollInterval: config.PollInterval, + stop: make(chan struct{}, 10), + } + // load once to ensure cert is good. + if err := r.Reload(); err != nil { + return nil, err + } + go r.pollRefresh() + return r, nil +} + +// ValidateCertExpiry validates the certificate expiry. +func ValidateCertExpiry(certPEM []byte, now time.Time) error { + if len(bytes.TrimSpace(certPEM)) == 0 { + return errors.New("certificate is empty") + } + for { + der, rest := pem.Decode(certPEM) + cp, err := x509.ParseCertificate(der.Bytes) + if err != nil { + return err + } + if now.Before(cp.NotBefore) || now.After(cp.NotAfter) { + return fmt.Errorf("invalid certificate, NotBefore: %v, NotAfter: %v, Now: %v", cp.NotBefore, cp.NotAfter, now) + } + if len(bytes.TrimSpace(rest)) == 0 { + return nil + } + certPEM = rest + } +} diff --git a/utils/cert_reload_test.go b/utils/cert_reload_test.go new file mode 100644 index 0000000..be441b6 --- /dev/null +++ b/utils/cert_reload_test.go @@ -0,0 +1,155 @@ +// Copyright 2024 Yahoo Inc. +// Licensed under the terms of the Apache License 2.0. Please see LICENSE file in project root for terms. + +package utils + +import ( + "crypto/tls" + "fmt" + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMemCertReloader_Reload(t *testing.T) { + t.Parallel() + type expect struct { + cert *tls.Certificate + wantErr assert.ErrorAssertionFunc + } + + tests := []struct { + name string + setup func(t *testing.T, certPath, keyPath string) (*MemCertReloader, *expect) + certPath string + keyPath string + wantCert *tls.Certificate + wantErr assert.ErrorAssertionFunc + }{ + { + name: "happy path", + certPath: "testdata/client.crt", + keyPath: "testdata/client.key", + setup: func(t *testing.T, certPath, keyPath string) (*MemCertReloader, *expect) { + certPEM, err := os.ReadFile(certPath) + if err != nil { + t.Fatal(err) + } + keyPEM, err := os.ReadFile(keyPath) + if err != nil { + t.Fatal(err) + } + + reloader, err := NewCertReloader( + CertReloadConfig{ + CertKeyGetter: func() ([]byte, []byte, error) { + return certPEM, keyPEM, nil + }, + }, + ) + if err != nil { + t.Fatal(err) + } + wantCrt, err := tls.X509KeyPair(certPEM, keyPEM) + if err != nil { + t.Fatal(err) + } + want := &expect{ + cert: &wantCrt, + wantErr: assert.NoError, + } + return reloader, want + }, + }, + { + name: "getter error", + certPath: "testdata/invalid.crt", + keyPath: "testdata/invalid.key", + setup: func(t *testing.T, certPath, keyPath string) (*MemCertReloader, *expect) { + reloader := &MemCertReloader{ + getter: func() ([]byte, []byte, error) { + return nil, nil, fmt.Errorf("get error") + }, + } + want := &expect{ + wantErr: assert.Error, + } + return reloader, want + }, + }, + { + name: "unchanged cert", + certPath: "testdata/client.crt", + keyPath: "testdata/client.key", + setup: func(t *testing.T, certPath, keyPath string) (*MemCertReloader, *expect) { + certPEM, err := os.ReadFile(certPath) + if err != nil { + t.Fatal(err) + } + keyPEM, err := os.ReadFile(keyPath) + if err != nil { + t.Fatal(err) + } + + reloader, err := NewCertReloader( + CertReloadConfig{ + CertKeyGetter: func() ([]byte, []byte, error) { + return certPEM, keyPEM, nil + }, + }, + ) + if err != nil { + t.Fatal(err) + } + wantCert, err := tls.X509KeyPair(certPEM, keyPEM) + if err != nil { + t.Fatal(err) + } + reloader.cert = &wantCert + want := &expect{ + cert: &wantCert, + wantErr: assert.NoError, + } + return reloader, want + }, + }, + { + name: "invalid key pair", + certPath: "testdata/ca.crt", + keyPath: "testdata/client.key", + setup: func(t *testing.T, certPath, keyPath string) (*MemCertReloader, *expect) { + certPEM, err := os.ReadFile(certPath) + if err != nil { + t.Fatal(err) + } + keyPEM, err := os.ReadFile(keyPath) + if err != nil { + t.Fatal(err) + } + reloader := &MemCertReloader{ + getter: func() ([]byte, []byte, error) { + return certPEM, keyPEM, nil + }, + } + if err != nil { + t.Fatal(err) + } + want := &expect{ + wantErr: assert.Error, + } + return reloader, want + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + reloader, want := tt.setup(t, tt.certPath, tt.keyPath) + gotErr := reloader.Reload() + if !want.wantErr(t, gotErr, "unexpected error") { + return + } + assert.Equal(t, reloader.cert, want.cert, "unexpected result") + }) + } +} diff --git a/utils/generate_test.go b/utils/generate_test.go new file mode 100644 index 0000000..fc557a7 --- /dev/null +++ b/utils/generate_test.go @@ -0,0 +1,11 @@ +// Copyright 2022 Yahoo Inc. +// Licensed under the terms of the Apache License 2.0. Please see LICENSE file in project root for terms. + +package utils + +//go:generate certstrap init --passphrase "" --common-name "ca" --years 80 +//go:generate certstrap request-cert --passphrase "" --common-name client +//go:generate certstrap sign client --passphrase "" --CA ca --years 80 +//go:generate mkdir -p ./testdata +//go:generate mv -f ./out/ca.crt ./out/client.crt ./out/client.key ./testdata +//go:generate rm -rf ./out diff --git a/utils/testdata/ca.crt b/utils/testdata/ca.crt new file mode 100644 index 0000000..f64992d --- /dev/null +++ b/utils/testdata/ca.crt @@ -0,0 +1,28 @@ +-----BEGIN CERTIFICATE----- +MIIE3DCCAsSgAwIBAgIBATANBgkqhkiG9w0BAQsFADANMQswCQYDVQQDEwJjYTAg +Fw0yNDExMDcxNzQ4NTFaGA8yMTA2MDUwNzE3NTg1MVowDTELMAkGA1UEAxMCY2Ew +ggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAwggIKAoICAQDA3OVLRgACNl6b1IDiAq0c +pL5FncGJC/w5/01LUkgy+9rAk0lwwjnZiXf3aMOC2Q6267uB9BtpaoTLR5h8GRM7 +25GNCps2x8gbo8GaBqN8UfH/Cn+yGg652tZI4ikD5HE1rIGYAnhll3esEV+zCiFr +Nuh+RFyabcprWr0FQ/N6ysrrMdFQNo17WEIp0L3nevznLU1d7uc7h6z2lKU2DBrT +ghFKvwSO724YHhQsvCZOtNcIPsYcwTEHiugLEhZrcYQ2OjgiygCmg71OiPgHoATa +lrUGnv7tibyjvQ/XIZqRu3iL3GJAJJV3S6owHl8eSur0u8RW3mHneEZKqZqF4fXY +isSzmO4SQDJibtiWboQZP2NmkEUR7ar7Y9z6SgyFDie5GH5kvP7g3eHcIs/6sz+s +yM0DY5FO9YsNBuQbfxfEtXSbQ8Y2ZC+0NWgifYCG0DAmaoyRSZjNjBMBm3naNJrQ +V6TPtsANJxhp4b8nTa9W1Fh04w8yH6ROTPYb9LWyYuWuoV6BCE6rj5mmLWRqoT2n +FXxiqAiftg3qrco+ZCh7KZ8ht4+dlxeDC3ki2jpCx33GZZSEQhXX3+ZcizVK24d6 +BOwcbn6NKda+xW/7xaK9dLkYHeQBSxXE+U1X7RmsVfEQr/B7vigwWWzefOG+ECOE +fq6Qq9RaPcTvuFdXuVi8tQIDAQABo0UwQzAOBgNVHQ8BAf8EBAMCAQYwEgYDVR0T +AQH/BAgwBgEB/wIBADAdBgNVHQ4EFgQU/wf4Qp+YW2ZZOdUXkbzT6gluGNkwDQYJ +KoZIhvcNAQELBQADggIBAAg/nnhO2YaIrLs89BZAtTwe0UCJVaSx9kt5wTRjsAwV +A149MuGstrootq31mt4k90a7t3X63tUzvAOpDV+/JZAei3ASz2KN13C4BKXToaBW +IcMTNXYwZiL2dF1uqkMxPVo0w+NRyQtVwdYDdqRCqHKY7TFz7wJPjDrYetm5OwfQ +JritO2h5mqxr+ubg2mWECxhHT9D4N6w0dhUCJX/zbH7QF8mvuEarlQB3ct+Ew868 +DoWbvWD3pDRqD8Fjt1CraXm1FWhR84uPLga8XpOQ+NvJWnQFXWLFaqAtBs3a8eaj +nmptRl0Iue/esTSQBRpeqzu73dzCtkeFrR2Nst1Ycpmbl7cFat1m8tvBwAM7Pbku +0Bom1qT6daOZrvbIDXKAlaBAseT6o892PswWUjRSC7ZqhrUHMQTq1oJs4lxbtQTQ +pnOuVwQLWOv+vlaoCufnysP65zxHAvzMt25L7/yyTmp4f+eixn7YReQg+4px9DeE +2loWjq4YTbEcXCzgJ8HR4uhppHKZXJhB/vx7Qg386zgtRtpa/QeGAtfeUFyCHlmu +P35g9wKaonBtN9DQyFN9sBJ9ugLGZ/YeXwCkPzT3OoylNRM+rr2h1E80fiF76jOz +Gx5Uv7/b5ie5/917MPaJfdk99AZ/VOb5m4HBNqB4O8CmgTdbYJjDJAuK8YBz5rxO +-----END CERTIFICATE----- diff --git a/utils/testdata/client.crt b/utils/testdata/client.crt new file mode 100644 index 0000000..8f35f0f --- /dev/null +++ b/utils/testdata/client.crt @@ -0,0 +1,24 @@ +-----BEGIN CERTIFICATE----- +MIIEGzCCAgOgAwIBAgIQK2zwL4S3fDViQnDt0QMyqzANBgkqhkiG9w0BAQsFADAN +MQswCQYDVQQDEwJjYTAgFw0yNDExMDcxNzQ4NTFaGA8yMTA0MTEwNzE3NTg1MVow +ETEPMA0GA1UEAxMGY2xpZW50MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKC +AQEAxKRuswos/E15VoODmFJnDKkmgwB4pfdOb1GBi8r/+stRWldkX72DkOYQVJ/T +qshJocMelEItDI/HwP3kQ3tHTjVklp8ekrqxkFxh3lpxAKTMKzwbPFsLpdBm+jB5 +E5Rz7DA9d9mWVwxzrwOmRM9FWwiRM8NqjELSvigyf+Q3ZuZsgFIxHpO5vQ9h8wHr +J4Xx4MJAAltNe8wD5GGFoZ3S+gJaEOqilPl5RXFu6jUbj6tmFTEKowCoWMeBi9ZQ +EHavkTbOEdSqVSe4mkvM3Hznn//wgnx1Wxd8BYEpeAYYOAqTg5A0hdNOrv67smGc +sJpb6TNNEqzcJpvdHwA38FyUlQIDAQABo3EwbzAOBgNVHQ8BAf8EBAMCA7gwHQYD +VR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQWBBSn7pDoI+w7gybp +E6Zt+78Sl0uR1DAfBgNVHSMEGDAWgBT/B/hCn5hbZlk51ReRvNPqCW4Y2TANBgkq +hkiG9w0BAQsFAAOCAgEARndBgdqO5f4M9vTcbhBfU1qN2CB7TdKqMr1jOsfHMHaS +hu00sw2Y0MKPSsSkV3FndmJUy5sjYHbxEAmjsUVnsaW/1hXCgvnHMl0JbOVdtkiW +qrIMKudTKo49hEk3jl32dmgR0EWj8PRF8blgl7j3SdmixYIpuoJ6zxccET5SxvSc +c2Srl3QP16pBc+OnaHZoEhiUOLRogP+Gn+4daH1iPpTIC5TGvpz9aK2iexoH1wJa +sj6JykZpsfT7pCv5wl2JhNtAKSjEAhRz2gv6Md3lps/0PjG3/cEKxWKdZWtZg0Jk +5iSbAVzi2E7xcfNM3Gmp4f7xiAWN70HH0c/HlcJ/jjlQ9/pt8BBpEItlpW2oGQKc +EtkGvoBWrfPq6WRHhFSAamIL3aHCsqXa2y9CtQ6Wk4eXBIMJ1uU+zfCTEm1DMFJh +JHNENPc8eYwEPloAQDQbwkTRKDKP+FjyhRiMi071X0oVw5byIzRZegQM5We9XV2k +PGGxyqdYxQ2Xv/DHKQpiEnmDjQ5j8xHUPJaT81Do3x8L1oBXsUvPIAQZfV7tgIKk +xl89dyPLYCjAQc9bnAp+YiqS2CRTQDMVmPe+8SpRFLgGjc7YTwTKMdm+lZWj169/ +56KGiU0neQozv92gKPygazzqZ/W/BBTXxsRtSwOHzdQ0SpONeyIdNKqyh+xuc6M= +-----END CERTIFICATE----- diff --git a/utils/testdata/client.key b/utils/testdata/client.key new file mode 100644 index 0000000..37804c2 --- /dev/null +++ b/utils/testdata/client.key @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEpAIBAAKCAQEAxKRuswos/E15VoODmFJnDKkmgwB4pfdOb1GBi8r/+stRWldk +X72DkOYQVJ/TqshJocMelEItDI/HwP3kQ3tHTjVklp8ekrqxkFxh3lpxAKTMKzwb +PFsLpdBm+jB5E5Rz7DA9d9mWVwxzrwOmRM9FWwiRM8NqjELSvigyf+Q3ZuZsgFIx +HpO5vQ9h8wHrJ4Xx4MJAAltNe8wD5GGFoZ3S+gJaEOqilPl5RXFu6jUbj6tmFTEK +owCoWMeBi9ZQEHavkTbOEdSqVSe4mkvM3Hznn//wgnx1Wxd8BYEpeAYYOAqTg5A0 +hdNOrv67smGcsJpb6TNNEqzcJpvdHwA38FyUlQIDAQABAoIBAHEdcXxWNyhvXIQS +pOlDRXn392pWJBC8YQcR6KZEgNmAiWyhZuDYAE2iufQj+Dt+eTSvK/D23DkkoDB5 +HAxhZtQrTJCEOa8H68pyCe5BMk2/fp7ENZqTePDKH+J9bbiApepQmZtOs/eg9w7O +158vZ+ME25neKHVEnzu5ncsJuYDbu5EFYViL9cftOgKR7tuiDkOx1s6lwvPygKHK +yyv7n7CicnsbSBHyOluqWuCn+jZ7xBDahsPAdI/vxcNuAN/qOQx8taUg+XPQjnP/ +0nM6nYOVx/EPA1HBJzsnXrAkoVeblI18lTF5GXb0QK6dtEJI0xhjaxduprm4B/J7 +NPmsW8UCgYEA5Qs+QJwvPxn4i5BZRvGJsW9JWvHZbLcdCYUKUGMxXcQRetHASZYQ +gFPiV8l12VWVQXgoO/A5c42S1ZwH++16C/vU3Px6pXB37+93zuJnZh0teFPxkrR5 +RJcx+P6y/Cqy8YFFubWN3KxrsUVeYg4Wz6RIS2tvGdCzShwxJm3+YG8CgYEA28j5 ++7NnusMDDMsV4SHwo/RxQ/ePvYpcyzrYTFHmjIA1iAYsW0Fm4dst0VH1tPVRYzZI +BiRX2fBzON/6hR/1JHOIbh0U+0r0GIV21BCUt28e2ouRU7yfjevLzifa8+5zx2eN +ScOzJ7vYV3Iw9T+L8EX8x/RNlkaECo56PleA1TsCgYA0CtDzR0mo8dK6i2rtprYd +neayBl/bxuOPJS6Jw3AVGRbLrFsfnTxUnDrCraDcaAjI3m9t5xB2xAVICfL6eCQS +Ev9z0t5fNuXZm7TCSkkqN5j8TT6HkgA36I7PP7gVefI8052vK6R3LqotllBywTbH +qVFP6bJN0FDclvlH/RgeewKBgQCIyIyj1GTDO+jjBmVohqnwMApp5WGk8b6MkOPa +o8IbQROPw1/Jr2trNvBN7HdBlsd/OmIayHWQYnAjPmn4fgogFHMdLKZJOr2toSpy +EpuridGm6+OXPLYEKnLdq7o9w/J0cILjHJOcL/EVgzDrARCDidsnSmkbFGnK9B8q +O2UnzwKBgQCmHC3J809GUCJ5noNtuVA/qOpInfZvuDc1yo1CUvA5e+J+ZLjEu7Uh +jnWhlxdnQDW1xnSkOmAylLyT/yHBa7+u7qU4OyodOi5VhjUIqBLUtlRNIM+JWQws +8vERY3qu8gcPXf61rnNYwVZ81daJAXcMe2/XT7QrbDzhO7swXYSyzg== +-----END RSA PRIVATE KEY-----