Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add memory cert reloader #517

Merged
merged 1 commit into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/scorecards.yml
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright 2024 Yahoo Inc.
# Licensed under the terms of the Apache License 2.0. Please see LICENSE file in project root for terms.

name: Scorecards supply-chain security
on:
# Only the default branch is supported.
Expand Down
2 changes: 2 additions & 0 deletions license_comment
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Copyright 2024 Yahoo Inc.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

was this mistakenly checked in?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's the license file for https://github.com/google/addlicense. I'll add a comment in the readme.

Licensed under the terms of the Apache License 2.0. Please see LICENSE file in project root for terms.
167 changes: 167 additions & 0 deletions utils/cert_reload.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
// Copyright 2024 Yahoo Inc.
// Licensed under the terms of the Apache License 2.0. Please see LICENSE file in project root for terms.

package utils
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we call this certreload?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure, I'll change it in the next PR.


import (
"bytes"
"crypto/subtle"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"log"
"sync"
"time"
)

const defaultMemPollInterval = 60 * time.Minute

// MemCertReloader reloads the (key, cert) pair by invoking the callback functions
// getter.
type MemCertReloader struct {
mu sync.RWMutex
getter func() ([]byte, []byte, error)
cert *tls.Certificate

logger func(fmt string, args ...interface{})
once sync.Once
stop chan struct{}
pollInterval time.Duration
}

// GetCertificate returns the latest known certificate and can be assigned to the
// GetCertificate member of the TLS config. For http.server use.
func (w *MemCertReloader) GetCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) {
return w.GetLatestCertificate()
}

// GetClientCertificate returns the latest known certificate and can be assigned to the
// GetClientCertificate member of the TLS config. For http.client use.
func (w *MemCertReloader) GetClientCertificate(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
return w.GetLatestCertificate()
}

// GetLatestCertificate returns the latest known certificate.
func (w *MemCertReloader) GetLatestCertificate() (*tls.Certificate, error) {
w.mu.RLock()
c := w.cert
w.mu.RUnlock()
return c, nil
}

// Close stops the background refresh.
func (w *MemCertReloader) Close() error {
w.once.Do(func() {
close(w.stop)
})
return nil
}

// Reload reloads the certificate into the memory cache when the certificate is updated and valid.
func (w *MemCertReloader) Reload() error {
cb, kb, err := w.getter()
if err != nil {
return fmt.Errorf("failed to get the certificate and private key, %v", err)
}

if err := ValidateCertExpiry(cb, time.Now()); err != nil {
return fmt.Errorf("failed to validate certicate, %v", err)
}

cert, err := tls.X509KeyPair(cb, kb)
if err != nil {
return fmt.Errorf("failed to parse the certificate and private key, %v", err)
}

if w.cert != nil {
if subtle.ConstantTimeCompare(cert.Certificate[0], w.cert.Certificate[0]) == 1 {
return nil
}
}

w.mu.Lock()
w.cert = &cert
w.mu.Unlock()
w.logger("certs reloaded at %v", time.Now())
return nil
}

func (w *MemCertReloader) pollRefresh() {
poll := time.NewTicker(w.pollInterval)
defer poll.Stop()
for {
select {
case <-poll.C:
if err := w.Reload(); err != nil {
w.logger("cert reload error: %v\n", err)
}
case <-w.stop:
return
}
}
}

// CertReloadConfig contains the config for cert reload.
type CertReloadConfig struct {
// CertKeyGetter gets the certificate and the private key.
CertKeyGetter func() ([]byte, []byte, error)
Logger func(fmt string, args ...interface{})
PollInterval time.Duration
}

// NewCertReloader returns a MemCertReloader that reloads the (key, cert) pair whenever
// the cert file changes on the filesystem.
func NewCertReloader(config CertReloadConfig) (*MemCertReloader, error) {
if config.Logger == nil {
config.Logger = log.Printf
}
if config.PollInterval == 0 {
config.PollInterval = defaultMemPollInterval
}

var getter func() (cert []byte, key []byte, _ error)

if config.CertKeyGetter == nil {
return nil, errors.New("no getter function found in the config")
}

if config.CertKeyGetter != nil {
getter = config.CertKeyGetter
}

r := &MemCertReloader{
getter: getter,
logger: config.Logger,
pollInterval: config.PollInterval,
stop: make(chan struct{}, 10),
}
// load once to ensure cert is good.
if err := r.Reload(); err != nil {
return nil, err
}
go r.pollRefresh()
return r, nil
}

// ValidateCertExpiry validates the certificate expiry.
func ValidateCertExpiry(certPEM []byte, now time.Time) error {
if len(bytes.TrimSpace(certPEM)) == 0 {
return errors.New("certificate is empty")
}
for {
der, rest := pem.Decode(certPEM)
cp, err := x509.ParseCertificate(der.Bytes)
if err != nil {
return err
}
if now.Before(cp.NotBefore) || now.After(cp.NotAfter) {
return fmt.Errorf("invalid certificate, NotBefore: %v, NotAfter: %v, Now: %v", cp.NotBefore, cp.NotAfter, now)
}
if len(bytes.TrimSpace(rest)) == 0 {
return nil
}
certPEM = rest
}
}
155 changes: 155 additions & 0 deletions utils/cert_reload_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
// Copyright 2024 Yahoo Inc.
// Licensed under the terms of the Apache License 2.0. Please see LICENSE file in project root for terms.

package utils

import (
"crypto/tls"
"fmt"
"os"
"testing"

"github.com/stretchr/testify/assert"
)

func TestMemCertReloader_Reload(t *testing.T) {
t.Parallel()
type expect struct {
cert *tls.Certificate
wantErr assert.ErrorAssertionFunc
}

tests := []struct {
name string
setup func(t *testing.T, certPath, keyPath string) (*MemCertReloader, *expect)
certPath string
keyPath string
wantCert *tls.Certificate
wantErr assert.ErrorAssertionFunc
}{
{
name: "happy path",
certPath: "testdata/client.crt",
keyPath: "testdata/client.key",
setup: func(t *testing.T, certPath, keyPath string) (*MemCertReloader, *expect) {
certPEM, err := os.ReadFile(certPath)
if err != nil {
t.Fatal(err)
}
keyPEM, err := os.ReadFile(keyPath)
if err != nil {
t.Fatal(err)
}

reloader, err := NewCertReloader(
CertReloadConfig{
CertKeyGetter: func() ([]byte, []byte, error) {
return certPEM, keyPEM, nil
},
},
)
if err != nil {
t.Fatal(err)
}
wantCrt, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
t.Fatal(err)
}
want := &expect{
cert: &wantCrt,
wantErr: assert.NoError,
}
return reloader, want
},
},
{
name: "getter error",
certPath: "testdata/invalid.crt",
keyPath: "testdata/invalid.key",
setup: func(t *testing.T, certPath, keyPath string) (*MemCertReloader, *expect) {
reloader := &MemCertReloader{
getter: func() ([]byte, []byte, error) {
return nil, nil, fmt.Errorf("get error")
},
}
want := &expect{
wantErr: assert.Error,
}
return reloader, want
},
},
{
name: "unchanged cert",
certPath: "testdata/client.crt",
keyPath: "testdata/client.key",
setup: func(t *testing.T, certPath, keyPath string) (*MemCertReloader, *expect) {
certPEM, err := os.ReadFile(certPath)
if err != nil {
t.Fatal(err)
}
keyPEM, err := os.ReadFile(keyPath)
if err != nil {
t.Fatal(err)
}

reloader, err := NewCertReloader(
CertReloadConfig{
CertKeyGetter: func() ([]byte, []byte, error) {
return certPEM, keyPEM, nil
},
},
)
if err != nil {
t.Fatal(err)
}
wantCert, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
t.Fatal(err)
}
reloader.cert = &wantCert
want := &expect{
cert: &wantCert,
wantErr: assert.NoError,
}
return reloader, want
},
},
{
name: "invalid key pair",
certPath: "testdata/ca.crt",
keyPath: "testdata/client.key",
setup: func(t *testing.T, certPath, keyPath string) (*MemCertReloader, *expect) {
certPEM, err := os.ReadFile(certPath)
if err != nil {
t.Fatal(err)
}
keyPEM, err := os.ReadFile(keyPath)
if err != nil {
t.Fatal(err)
}
reloader := &MemCertReloader{
getter: func() ([]byte, []byte, error) {
return certPEM, keyPEM, nil
},
}
if err != nil {
t.Fatal(err)
}
want := &expect{
wantErr: assert.Error,
}
return reloader, want
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reloader, want := tt.setup(t, tt.certPath, tt.keyPath)
gotErr := reloader.Reload()
if !want.wantErr(t, gotErr, "unexpected error") {
return
}
assert.Equal(t, reloader.cert, want.cert, "unexpected result")
})
}
}
11 changes: 11 additions & 0 deletions utils/generate_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// Copyright 2022 Yahoo Inc.
// Licensed under the terms of the Apache License 2.0. Please see LICENSE file in project root for terms.

package utils

//go:generate certstrap init --passphrase "" --common-name "ca" --years 80
//go:generate certstrap request-cert --passphrase "" --common-name client
//go:generate certstrap sign client --passphrase "" --CA ca --years 80
//go:generate mkdir -p ./testdata
//go:generate mv -f ./out/ca.crt ./out/client.crt ./out/client.key ./testdata
//go:generate rm -rf ./out
28 changes: 28 additions & 0 deletions utils/testdata/ca.crt
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
-----BEGIN CERTIFICATE-----
MIIE3DCCAsSgAwIBAgIBATANBgkqhkiG9w0BAQsFADANMQswCQYDVQQDEwJjYTAg
Fw0yNDExMDcxNzQ4NTFaGA8yMTA2MDUwNzE3NTg1MVowDTELMAkGA1UEAxMCY2Ew
ggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAwggIKAoICAQDA3OVLRgACNl6b1IDiAq0c
pL5FncGJC/w5/01LUkgy+9rAk0lwwjnZiXf3aMOC2Q6267uB9BtpaoTLR5h8GRM7
25GNCps2x8gbo8GaBqN8UfH/Cn+yGg652tZI4ikD5HE1rIGYAnhll3esEV+zCiFr
Nuh+RFyabcprWr0FQ/N6ysrrMdFQNo17WEIp0L3nevznLU1d7uc7h6z2lKU2DBrT
ghFKvwSO724YHhQsvCZOtNcIPsYcwTEHiugLEhZrcYQ2OjgiygCmg71OiPgHoATa
lrUGnv7tibyjvQ/XIZqRu3iL3GJAJJV3S6owHl8eSur0u8RW3mHneEZKqZqF4fXY
isSzmO4SQDJibtiWboQZP2NmkEUR7ar7Y9z6SgyFDie5GH5kvP7g3eHcIs/6sz+s
yM0DY5FO9YsNBuQbfxfEtXSbQ8Y2ZC+0NWgifYCG0DAmaoyRSZjNjBMBm3naNJrQ
V6TPtsANJxhp4b8nTa9W1Fh04w8yH6ROTPYb9LWyYuWuoV6BCE6rj5mmLWRqoT2n
FXxiqAiftg3qrco+ZCh7KZ8ht4+dlxeDC3ki2jpCx33GZZSEQhXX3+ZcizVK24d6
BOwcbn6NKda+xW/7xaK9dLkYHeQBSxXE+U1X7RmsVfEQr/B7vigwWWzefOG+ECOE
fq6Qq9RaPcTvuFdXuVi8tQIDAQABo0UwQzAOBgNVHQ8BAf8EBAMCAQYwEgYDVR0T
AQH/BAgwBgEB/wIBADAdBgNVHQ4EFgQU/wf4Qp+YW2ZZOdUXkbzT6gluGNkwDQYJ
KoZIhvcNAQELBQADggIBAAg/nnhO2YaIrLs89BZAtTwe0UCJVaSx9kt5wTRjsAwV
A149MuGstrootq31mt4k90a7t3X63tUzvAOpDV+/JZAei3ASz2KN13C4BKXToaBW
IcMTNXYwZiL2dF1uqkMxPVo0w+NRyQtVwdYDdqRCqHKY7TFz7wJPjDrYetm5OwfQ
JritO2h5mqxr+ubg2mWECxhHT9D4N6w0dhUCJX/zbH7QF8mvuEarlQB3ct+Ew868
DoWbvWD3pDRqD8Fjt1CraXm1FWhR84uPLga8XpOQ+NvJWnQFXWLFaqAtBs3a8eaj
nmptRl0Iue/esTSQBRpeqzu73dzCtkeFrR2Nst1Ycpmbl7cFat1m8tvBwAM7Pbku
0Bom1qT6daOZrvbIDXKAlaBAseT6o892PswWUjRSC7ZqhrUHMQTq1oJs4lxbtQTQ
pnOuVwQLWOv+vlaoCufnysP65zxHAvzMt25L7/yyTmp4f+eixn7YReQg+4px9DeE
2loWjq4YTbEcXCzgJ8HR4uhppHKZXJhB/vx7Qg386zgtRtpa/QeGAtfeUFyCHlmu
P35g9wKaonBtN9DQyFN9sBJ9ugLGZ/YeXwCkPzT3OoylNRM+rr2h1E80fiF76jOz
Gx5Uv7/b5ie5/917MPaJfdk99AZ/VOb5m4HBNqB4O8CmgTdbYJjDJAuK8YBz5rxO
-----END CERTIFICATE-----
24 changes: 24 additions & 0 deletions utils/testdata/client.crt
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
-----BEGIN CERTIFICATE-----
MIIEGzCCAgOgAwIBAgIQK2zwL4S3fDViQnDt0QMyqzANBgkqhkiG9w0BAQsFADAN
MQswCQYDVQQDEwJjYTAgFw0yNDExMDcxNzQ4NTFaGA8yMTA0MTEwNzE3NTg1MVow
ETEPMA0GA1UEAxMGY2xpZW50MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKC
AQEAxKRuswos/E15VoODmFJnDKkmgwB4pfdOb1GBi8r/+stRWldkX72DkOYQVJ/T
qshJocMelEItDI/HwP3kQ3tHTjVklp8ekrqxkFxh3lpxAKTMKzwbPFsLpdBm+jB5
E5Rz7DA9d9mWVwxzrwOmRM9FWwiRM8NqjELSvigyf+Q3ZuZsgFIxHpO5vQ9h8wHr
J4Xx4MJAAltNe8wD5GGFoZ3S+gJaEOqilPl5RXFu6jUbj6tmFTEKowCoWMeBi9ZQ
EHavkTbOEdSqVSe4mkvM3Hznn//wgnx1Wxd8BYEpeAYYOAqTg5A0hdNOrv67smGc
sJpb6TNNEqzcJpvdHwA38FyUlQIDAQABo3EwbzAOBgNVHQ8BAf8EBAMCA7gwHQYD
VR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQWBBSn7pDoI+w7gybp
E6Zt+78Sl0uR1DAfBgNVHSMEGDAWgBT/B/hCn5hbZlk51ReRvNPqCW4Y2TANBgkq
hkiG9w0BAQsFAAOCAgEARndBgdqO5f4M9vTcbhBfU1qN2CB7TdKqMr1jOsfHMHaS
hu00sw2Y0MKPSsSkV3FndmJUy5sjYHbxEAmjsUVnsaW/1hXCgvnHMl0JbOVdtkiW
qrIMKudTKo49hEk3jl32dmgR0EWj8PRF8blgl7j3SdmixYIpuoJ6zxccET5SxvSc
c2Srl3QP16pBc+OnaHZoEhiUOLRogP+Gn+4daH1iPpTIC5TGvpz9aK2iexoH1wJa
sj6JykZpsfT7pCv5wl2JhNtAKSjEAhRz2gv6Md3lps/0PjG3/cEKxWKdZWtZg0Jk
5iSbAVzi2E7xcfNM3Gmp4f7xiAWN70HH0c/HlcJ/jjlQ9/pt8BBpEItlpW2oGQKc
EtkGvoBWrfPq6WRHhFSAamIL3aHCsqXa2y9CtQ6Wk4eXBIMJ1uU+zfCTEm1DMFJh
JHNENPc8eYwEPloAQDQbwkTRKDKP+FjyhRiMi071X0oVw5byIzRZegQM5We9XV2k
PGGxyqdYxQ2Xv/DHKQpiEnmDjQ5j8xHUPJaT81Do3x8L1oBXsUvPIAQZfV7tgIKk
xl89dyPLYCjAQc9bnAp+YiqS2CRTQDMVmPe+8SpRFLgGjc7YTwTKMdm+lZWj169/
56KGiU0neQozv92gKPygazzqZ/W/BBTXxsRtSwOHzdQ0SpONeyIdNKqyh+xuc6M=
-----END CERTIFICATE-----
Loading
Loading