From d9bb83977aaf7603df69998c4923bd00ab4c20b6 Mon Sep 17 00:00:00 2001 From: chaosinthecrd Date: Tue, 5 Dec 2023 17:43:16 +0000 Subject: [PATCH] adding ability to hot-reload certs when files modified Signed-off-by: chaosinthecrd --- internal/provider/provider.go | 196 ++++++++++++++++++++++++++++++++++ main.go | 85 ++------------- pkg/utils/watcher.go | 102 ++++++++++++++++++ 3 files changed, 305 insertions(+), 78 deletions(-) create mode 100644 internal/provider/provider.go create mode 100644 pkg/utils/watcher.go diff --git a/internal/provider/provider.go b/internal/provider/provider.go new file mode 100644 index 0000000..13ff376 --- /dev/null +++ b/internal/provider/provider.go @@ -0,0 +1,196 @@ +package provider + +import ( + "context" + "crypto" + "crypto/tls" + "crypto/x509" + "errors" + "flag" + "fmt" + "github.com/go-logr/logr" + "github.com/testifysec/archivista-data-provider/pkg/handler" + "github.com/testifysec/archivista-data-provider/pkg/utils" + "github.com/testifysec/archivista-data-provider/pkg/utils/certs" + "github.com/testifysec/go-witness/archivista" + "k8s.io/klog/v2" + "net/http" + "os" + "time" +) + +const ( + // TODO: Fix timeout handling. + timeout = 300 * time.Second + defaultPort = 8090 + apiVersion = "externaldata.gatekeeper.sh/v1alpha1" + defaultCertFile = "/etc/ssl/certs/server.crt" + defaultKeyFile = "/etc/ssl/certs/server.key" + defaultCaCertFile = "/usr/local/tls/client-ca/ca.crt" +) + +var ( + certFile string + keyFile string + clientCAFile string + port int +) + +// Provider is used for running the gatekeeper provider. Provider will verify images when sent requests to do so by Gatekeeper. +type Provider struct { + // log is the Controller logger. + log logr.Logger + + ctx context.Context + + tls TLS + + watchers map[string]*utils.Watcher +} + +type TLS struct { + certificate *tls.Certificate + key *crypto.PrivateKey + clientCAs *x509.CertPool +} + +// New constructs a new Provider instance. +func New() (*Provider, error) { + p := &Provider{ + ctx: context.Background(), + watchers: map[string]*utils.Watcher{}, + } + + klog.InitFlags(nil) + flag.StringVar(&certFile, "tls-cert-file", "", "path to the file containing the TLS certificate for the provider") + flag.StringVar(&keyFile, "tls-key-file", "", "path to the file containing the TLS private key for the provider") + flag.StringVar(&clientCAFile, "client-ca-file", defaultCaCertFile, "path to client CA certificate") + flag.IntVar(&port, "port", defaultPort, "Port for the server to listen on") + flag.Parse() + + if certFile == "" || keyFile == "" { + return nil, fmt.Errorf("tls certificate and key path is required for the provider") + } + + f, err := os.ReadFile(certFile) + if err != nil { + return nil, errors.Join(err, fmt.Errorf("reading client certificate file from path %s", certFile)) + } + + p.tls.certificate, err = certs.ParseCert(f) + if err != nil { + return nil, errors.Join(err, fmt.Errorf("faild to parse certificate")) + } + + w, err := utils.NewWatcher(p.ctx, certFile, func() error { + f, err := os.ReadFile(certFile) + if err != nil { + return err + } + + p.tls.certificate, err = certs.ParseCert(f) + if err != nil { + return errors.Join(err, fmt.Errorf("faild to parse certificate")) + } + + return nil + }) + if err != nil { + return nil, fmt.Errorf("failed to create watcher: %w", err) + } + + p.watchers["cert"] = w + + f, err = os.ReadFile(keyFile) + if err != nil { + return nil, errors.Join(err, fmt.Errorf("reading client private key from path %s", keyFile)) + } + + p.tls.key, err = certs.ParseKey(f) + if err != nil { + return nil, errors.Join(err, fmt.Errorf("faild to parse certificate key")) + } + + w, err = utils.NewWatcher(p.ctx, keyFile, func() error { + f, err := os.ReadFile(keyFile) + if err != nil { + return err + } + + p.tls.key, err = certs.ParseKey(f) + if err != nil { + return errors.Join(err, fmt.Errorf("faild to parse certificate key")) + } + + return nil + }) + if err != nil { + return nil, fmt.Errorf("failed to create watcher: %w", err) + } + + p.watchers["key"] = w + + // For now having the CA cert field populated is going to be optional. There might be other ways that people are mounting in the CA cert. + if clientCAFile != "" { + cacert, err := os.ReadFile(clientCAFile) + if err != nil { + return nil, errors.Join(err, fmt.Errorf("reading gatekeeper CA certificate file from path %s", clientCAFile)) + } + certPool := x509.NewCertPool() + if !certPool.AppendCertsFromPEM(cacert) { + return nil, fmt.Errorf("Failed to add Gatekeeper CA Certificate to pool") + } + + p.tls.clientCAs = certPool + + } else { + // TODO: Logically, I don't really see why we shouldn't be trying to hot-reload the system cert pool. + // I am not sure at the moment how this would work however. + rootCAs, _ := x509.SystemCertPool() + if err != nil { + return nil, errors.Join(err, fmt.Errorf("failed to get system certificate pool")) + } + if rootCAs == nil { + rootCAs = x509.NewCertPool() + } + + p.tls.clientCAs = rootCAs + } + + return p, nil +} + +func (p *Provider) Start() error { + fmt.Println("starting server...") + + ac := archivista.New("https://archivista.testifysec.io") + vh := handler.NewValidateHandler(ac) + + mux := http.NewServeMux() + mux.HandleFunc("/", vh.Handler) + + tlsConfig := &tls.Config{ + MinVersion: tls.VersionTLS13, + ClientCAs: p.tls.clientCAs, + ClientAuth: tls.RequireAndVerifyClientCert, + GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { + var certificate tls.Certificate + certificate.Certificate = p.tls.certificate.Certificate + certificate.PrivateKey = *p.tls.key + return &certificate, nil + }, + } + + server := &http.Server{ + Addr: fmt.Sprintf(":%d", port), + Handler: mux, + TLSConfig: tlsConfig, + ReadHeaderTimeout: time.Duration(5) * time.Second, + } + + if err := server.ListenAndServeTLS("", ""); err != nil { + return err + } + + return nil +} diff --git a/main.go b/main.go index e3b5b03..21f748a 100644 --- a/main.go +++ b/main.go @@ -1,91 +1,20 @@ package main import ( - "crypto/tls" - "crypto/x509" - "flag" - "fmt" - "net/http" "os" - "path/filepath" - "time" - - "github.com/testifysec/archivista-data-provider/pkg/handler" - "github.com/testifysec/archivista-data-provider/pkg/manager" - "github.com/testifysec/go-witness/archivista" + "github.com/testifysec/archivista-data-provider/internal/provider" "k8s.io/klog/v2" ) -const ( - // TODO: Fix timeout handling. - timeout = 300 * time.Second - defaultPort = 8090 - - certName = "tls.crt" - keyName = "tls.key" -) - -var ( - certDir string - clientCAFile string - port int -) - -func init() { - klog.InitFlags(nil) - flag.StringVar(&certDir, "cert-dir", "", "path to directory containing TLS certificates") - flag.StringVar(&clientCAFile, "client-ca-file", "", "path to client CA certificate") - flag.IntVar(&port, "port", defaultPort, "Port for the server to listen on") - flag.Parse() -} - func main() { - ac := archivista.New("https://archivista.testifysec.io") - vh := handler.NewValidateHandler(ac) - - mux := http.NewServeMux() - mux.HandleFunc("/", vh.Handler) - - server := &http.Server{ - Addr: fmt.Sprintf(":%d", port), - Handler: mux, - ReadHeaderTimeout: time.Duration(5) * time.Second, - } - - config := &tls.Config{ - MinVersion: tls.VersionTLS13, - } - if clientCAFile != "" { - klog.InfoS("loading Gatekeeper's CA certificate", "clientCAFile", clientCAFile) - caCert, err := os.ReadFile(clientCAFile) - if err != nil { - klog.ErrorS(err, "unable to load Gatekeeper's CA certificate", "clientCAFile", clientCAFile) - os.Exit(1) - } - - clientCAs := x509.NewCertPool() - clientCAs.AppendCertsFromPEM(caCert) - - config.ClientCAs = clientCAs - config.ClientAuth = tls.RequireAndVerifyClientCert - server.TLSConfig = config + c, err := provider.New() + if err != nil { + klog.ErrorS(err, "unable to initialize archivista data provider server") + os.Exit(1) } - if certDir != "" { - certFile := filepath.Join(certDir, certName) - keyFile := filepath.Join(certDir, keyName) - - klog.Info("start archivista controller manager") - go manager.StartManager() + klog.Info("starting archivista data provider...") - klog.InfoS("starting archivista data provider server", "port", port, "certFile", certFile, "keyFile", keyFile) - if err := server.ListenAndServeTLS(certFile, keyFile); err != nil { - klog.ErrorS(err, "unable to start archivista data provider server") - os.Exit(1) - } - } else { - klog.Error("TLS certificates are not provided, the server will not be started") - os.Exit(1) - } + c.Start() } diff --git a/pkg/utils/watcher.go b/pkg/utils/watcher.go new file mode 100644 index 0000000..79c3f51 --- /dev/null +++ b/pkg/utils/watcher.go @@ -0,0 +1,102 @@ +package utils + +import ( + "context" + "log" + "time" + + "github.com/fsnotify/fsnotify" +) + +// Watcher is an opinionated fsnotify.Watcher that is designed to +// watch Kubernetes config maps and perform actions on change. +type Watcher struct { + actions []func() error + notify chan struct{} +} + +// Notify manually runs all actions in a watcher +func (w *Watcher) Notify() { + go func(w *Watcher) { + w.notify <- struct{}{} + }(w) +} + +func NewWatcher(ctx context.Context, filePath string, actions ...func() error) (*Watcher, error) { + w := &Watcher{ + actions: actions, + notify: make(chan struct{}), + } + watcher, err := fsnotify.NewWatcher() + if err != nil { + return nil, err + } + err = watcher.Add(filePath) + if err != nil { + return nil, err + } + + go func(w *Watcher, watcher *fsnotify.Watcher) { + // only perform actions every 5 seconds at most + t := time.NewTicker(5 * time.Second) + defer t.Stop() + // we only really care about the last fsnotify event, as we are going to attempt to perform actions + // every 5 seconds. If there were 2 writes in that time we don't really mind. + var lastEvent *fsnotify.Event + runAll := func(actions ...func() error) []error { + var allErrors []error + for _, a := range w.actions { + if err := a(); err != nil { + allErrors = append(allErrors, err) + } + } + return allErrors + } + for { + select { + case <-w.notify: + allErrors := runAll(w.actions...) + for _, e := range allErrors { + log.Printf("error while reloading config (%s)", e.Error()) + } + case <-t.C: + if lastEvent == nil { + continue + } + allErrors := runAll(w.actions...) + for _, e := range allErrors { + log.Printf("error while reloading file %s (%s)", lastEvent.Name, e.Error()) + } + // if no errors, clear the last event. + if len(allErrors) == 0 { + lastEvent = nil + } + case event, ok := <-watcher.Events: + if !ok { + return + } + // When a config map is updated, behind the scenes Kubernetes creates + // a new directory with the new contents, then replaces the symlink to point + // to the new config map, then deletes the old one. In this case, we get a delete + // event rather than a write event. + if event.Op == fsnotify.Remove { + // Only error here would be attempting to remove a non-existent watch + _ = watcher.Remove(event.Name) + err := watcher.Add(event.Name) + if err != nil { + log.Fatalf("file %s change detected, but could not re-watch the file: %s", event.Name, err.Error()) + } + lastEvent = &event + } + // However, people might not be using a config map after all + if event.Op == fsnotify.Write || event.Op == fsnotify.Create { + lastEvent = &event + } + case <-ctx.Done(): + watcher.Close() + return + } + } + }(w, watcher) + return w, nil +}