Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Certificate Hot Reloading #24

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 196 additions & 0 deletions internal/provider/provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
package provider

import (
"context"
"crypto"
"crypto/tls"
"crypto/x509"
"errors"
"flag"
"fmt"
"github.com/go-logr/logr"
"github.com/testifysec/archivista-data-provider/pkg/handler"
"github.com/testifysec/archivista-data-provider/pkg/utils"
"github.com/testifysec/archivista-data-provider/pkg/utils/certs"

Check failure on line 14 in internal/provider/provider.go

View workflow job for this annotation

GitHub Actions / Analyze (go)

no required module provides package github.com/testifysec/archivista-data-provider/pkg/utils/certs; to add it:
"github.com/testifysec/go-witness/archivista"
"k8s.io/klog/v2"
"net/http"
"os"
"time"
)

const (
// TODO: Fix timeout handling.
timeout = 300 * time.Second
defaultPort = 8090
apiVersion = "externaldata.gatekeeper.sh/v1alpha1"
defaultCertFile = "/etc/ssl/certs/server.crt"
defaultKeyFile = "/etc/ssl/certs/server.key"
defaultCaCertFile = "/usr/local/tls/client-ca/ca.crt"
)

var (
certFile string
keyFile string
clientCAFile string
port int
)

// Provider is used for running the gatekeeper provider. Provider will verify images when sent requests to do so by Gatekeeper.
type Provider struct {
// log is the Controller logger.
log logr.Logger

ctx context.Context

tls TLS

watchers map[string]*utils.Watcher
}

type TLS struct {
certificate *tls.Certificate
key *crypto.PrivateKey
clientCAs *x509.CertPool
}

// New constructs a new Provider instance.
func New() (*Provider, error) {
p := &Provider{
ctx: context.Background(),
watchers: map[string]*utils.Watcher{},
}

klog.InitFlags(nil)
flag.StringVar(&certFile, "tls-cert-file", "", "path to the file containing the TLS certificate for the provider")
flag.StringVar(&keyFile, "tls-key-file", "", "path to the file containing the TLS private key for the provider")
flag.StringVar(&clientCAFile, "client-ca-file", defaultCaCertFile, "path to client CA certificate")
flag.IntVar(&port, "port", defaultPort, "Port for the server to listen on")
flag.Parse()

if certFile == "" || keyFile == "" {
return nil, fmt.Errorf("tls certificate and key path is required for the provider")
}

f, err := os.ReadFile(certFile)
if err != nil {
return nil, errors.Join(err, fmt.Errorf("reading client certificate file from path %s", certFile))
}

p.tls.certificate, err = certs.ParseCert(f)
if err != nil {
return nil, errors.Join(err, fmt.Errorf("faild to parse certificate"))
}

w, err := utils.NewWatcher(p.ctx, certFile, func() error {
f, err := os.ReadFile(certFile)
if err != nil {
return err
}

p.tls.certificate, err = certs.ParseCert(f)
if err != nil {
return errors.Join(err, fmt.Errorf("faild to parse certificate"))
}

return nil
})
if err != nil {
return nil, fmt.Errorf("failed to create watcher: %w", err)
}

p.watchers["cert"] = w

f, err = os.ReadFile(keyFile)
if err != nil {
return nil, errors.Join(err, fmt.Errorf("reading client private key from path %s", keyFile))
}

p.tls.key, err = certs.ParseKey(f)
if err != nil {
return nil, errors.Join(err, fmt.Errorf("faild to parse certificate key"))
}

w, err = utils.NewWatcher(p.ctx, keyFile, func() error {
f, err := os.ReadFile(keyFile)
if err != nil {
return err
}

p.tls.key, err = certs.ParseKey(f)
if err != nil {
return errors.Join(err, fmt.Errorf("faild to parse certificate key"))
}

return nil
})
if err != nil {
return nil, fmt.Errorf("failed to create watcher: %w", err)
}

p.watchers["key"] = w

// For now having the CA cert field populated is going to be optional. There might be other ways that people are mounting in the CA cert.
if clientCAFile != "" {
cacert, err := os.ReadFile(clientCAFile)
if err != nil {
return nil, errors.Join(err, fmt.Errorf("reading gatekeeper CA certificate file from path %s", clientCAFile))
}
certPool := x509.NewCertPool()
if !certPool.AppendCertsFromPEM(cacert) {
return nil, fmt.Errorf("Failed to add Gatekeeper CA Certificate to pool")
}

p.tls.clientCAs = certPool

} else {
// TODO: Logically, I don't really see why we shouldn't be trying to hot-reload the system cert pool.
// I am not sure at the moment how this would work however.
rootCAs, _ := x509.SystemCertPool()
if err != nil {
return nil, errors.Join(err, fmt.Errorf("failed to get system certificate pool"))
}
if rootCAs == nil {
rootCAs = x509.NewCertPool()
}

p.tls.clientCAs = rootCAs
}

return p, nil
}

func (p *Provider) Start() error {
fmt.Println("starting server...")

ac := archivista.New("https://archivista.testifysec.io")
vh := handler.NewValidateHandler(ac)

mux := http.NewServeMux()
mux.HandleFunc("/", vh.Handler)

tlsConfig := &tls.Config{
MinVersion: tls.VersionTLS13,
ClientCAs: p.tls.clientCAs,
ClientAuth: tls.RequireAndVerifyClientCert,
GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
var certificate tls.Certificate
certificate.Certificate = p.tls.certificate.Certificate
certificate.PrivateKey = *p.tls.key
return &certificate, nil
},
}

server := &http.Server{
Addr: fmt.Sprintf(":%d", port),
Handler: mux,
TLSConfig: tlsConfig,
ReadHeaderTimeout: time.Duration(5) * time.Second,
}

if err := server.ListenAndServeTLS("", ""); err != nil {
return err
}

return nil
}
85 changes: 7 additions & 78 deletions main.go
Original file line number Diff line number Diff line change
@@ -1,91 +1,20 @@
package main

import (
"crypto/tls"
"crypto/x509"
"flag"
"fmt"
"net/http"
"os"
"path/filepath"
"time"

"github.com/testifysec/archivista-data-provider/pkg/handler"
"github.com/testifysec/archivista-data-provider/pkg/manager"
"github.com/testifysec/go-witness/archivista"

"github.com/testifysec/archivista-data-provider/internal/provider"
"k8s.io/klog/v2"
)

const (
// TODO: Fix timeout handling.
timeout = 300 * time.Second
defaultPort = 8090

certName = "tls.crt"
keyName = "tls.key"
)

var (
certDir string
clientCAFile string
port int
)

func init() {
klog.InitFlags(nil)
flag.StringVar(&certDir, "cert-dir", "", "path to directory containing TLS certificates")
flag.StringVar(&clientCAFile, "client-ca-file", "", "path to client CA certificate")
flag.IntVar(&port, "port", defaultPort, "Port for the server to listen on")
flag.Parse()
}

func main() {
ac := archivista.New("https://archivista.testifysec.io")
vh := handler.NewValidateHandler(ac)

mux := http.NewServeMux()
mux.HandleFunc("/", vh.Handler)

server := &http.Server{
Addr: fmt.Sprintf(":%d", port),
Handler: mux,
ReadHeaderTimeout: time.Duration(5) * time.Second,
}

config := &tls.Config{
MinVersion: tls.VersionTLS13,
}
if clientCAFile != "" {
klog.InfoS("loading Gatekeeper's CA certificate", "clientCAFile", clientCAFile)
caCert, err := os.ReadFile(clientCAFile)
if err != nil {
klog.ErrorS(err, "unable to load Gatekeeper's CA certificate", "clientCAFile", clientCAFile)
os.Exit(1)
}

clientCAs := x509.NewCertPool()
clientCAs.AppendCertsFromPEM(caCert)

config.ClientCAs = clientCAs
config.ClientAuth = tls.RequireAndVerifyClientCert
server.TLSConfig = config
c, err := provider.New()
if err != nil {
klog.ErrorS(err, "unable to initialize archivista data provider server")
os.Exit(1)
}

if certDir != "" {
certFile := filepath.Join(certDir, certName)
keyFile := filepath.Join(certDir, keyName)

klog.Info("start archivista controller manager")
go manager.StartManager()
klog.Info("starting archivista data provider...")

klog.InfoS("starting archivista data provider server", "port", port, "certFile", certFile, "keyFile", keyFile)
if err := server.ListenAndServeTLS(certFile, keyFile); err != nil {
klog.ErrorS(err, "unable to start archivista data provider server")
os.Exit(1)
}
} else {
klog.Error("TLS certificates are not provided, the server will not be started")
os.Exit(1)
}
c.Start()
}
102 changes: 102 additions & 0 deletions pkg/utils/watcher.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package utils

import (
"context"
"log"
"time"

"github.com/fsnotify/fsnotify"
)

// Watcher is an opinionated fsnotify.Watcher that is designed to
// watch Kubernetes config maps and perform actions on change.
type Watcher struct {
actions []func() error
notify chan struct{}
}

// Notify manually runs all actions in a watcher
func (w *Watcher) Notify() {
go func(w *Watcher) {
w.notify <- struct{}{}
}(w)
}

func NewWatcher(ctx context.Context, filePath string, actions ...func() error) (*Watcher, error) {
w := &Watcher{
actions: actions,
notify: make(chan struct{}),
}
watcher, err := fsnotify.NewWatcher()
if err != nil {
return nil, err
}
err = watcher.Add(filePath)
if err != nil {
return nil, err
}

go func(w *Watcher, watcher *fsnotify.Watcher) {
// only perform actions every 5 seconds at most
t := time.NewTicker(5 * time.Second)
defer t.Stop()
// we only really care about the last fsnotify event, as we are going to attempt to perform actions
// every 5 seconds. If there were 2 writes in that time we don't really mind.
var lastEvent *fsnotify.Event
runAll := func(actions ...func() error) []error {
var allErrors []error
for _, a := range w.actions {
if err := a(); err != nil {
allErrors = append(allErrors, err)
}
}
return allErrors
}
for {
select {
case <-w.notify:
allErrors := runAll(w.actions...)
for _, e := range allErrors {
log.Printf("error while reloading config (%s)", e.Error())
}
case <-t.C:
if lastEvent == nil {
continue
}
allErrors := runAll(w.actions...)
for _, e := range allErrors {
log.Printf("error while reloading file %s (%s)", lastEvent.Name, e.Error())
}
// if no errors, clear the last event.
if len(allErrors) == 0 {
lastEvent = nil
}
case event, ok := <-watcher.Events:
if !ok {
return
}
// When a config map is updated, behind the scenes Kubernetes creates
// a new directory with the new contents, then replaces the symlink to point
// to the new config map, then deletes the old one. In this case, we get a delete
// event rather than a write event.
if event.Op == fsnotify.Remove {
// Only error here would be attempting to remove a non-existent watch
_ = watcher.Remove(event.Name)
err := watcher.Add(event.Name)
if err != nil {
log.Fatalf("file %s change detected, but could not re-watch the file: %s", event.Name, err.Error())
}
lastEvent = &event
}
// However, people might not be using a config map after all
if event.Op == fsnotify.Write || event.Op == fsnotify.Create {
lastEvent = &event
}
case <-ctx.Done():
watcher.Close()
return
}
}
}(w, watcher)
return w, nil
}
Loading