From ee4ab8f5dcbdabab38d5338e3fd7329649996f4d Mon Sep 17 00:00:00 2001 From: Simon Ferquel Date: Thu, 8 Mar 2018 16:07:10 +0100 Subject: [PATCH] Allow users to create TLS config from arbitrary sources of PEM data Fix #52 Signed-off-by: Simon Ferquel --- tlsconfig/config.go | 94 +++++++++++++++++++++++------ tlsconfig/config_test.go | 126 ++++++++++++++++++++++++++------------- 2 files changed, 160 insertions(+), 60 deletions(-) diff --git a/tlsconfig/config.go b/tlsconfig/config.go index 0ef3fdcb4..855cb9abe 100644 --- a/tlsconfig/config.go +++ b/tlsconfig/config.go @@ -16,15 +16,47 @@ import ( "github.com/pkg/errors" ) +// PEMSource represents a provider of some PEM Block +type PEMSource interface { + Data() ([]byte, error) + Name() string +} + +// PEMFile is a PEMSource backed by a file +type PEMFile string + +// Data returns PEM data from the file +func (f PEMFile) Data() ([]byte, error) { + return ioutil.ReadFile(string(f)) +} + +// Name returns the name of the file +func (f PEMFile) Name() string { + return string(f) +} + +// PEMInMemory is a PENSource from memory +type PEMInMemory []byte + +// Data returns PEM data from memory +func (m PEMInMemory) Data() ([]byte, error) { + return []byte(m), nil +} + +// Name returns +func (m PEMInMemory) Name() string { + return "" +} + // Options represents the information needed to create client and server TLS configurations. type Options struct { - CAFile string + CA PEMSource - // If either CertFile or KeyFile is empty, Client() will not load them + // If either Cert or Key is nil, Client() will not load them // preventing the client from authenticating to the server. // However, Server() requires them and will error out if they are empty. - CertFile string - KeyFile string + Cert PEMSource + Key PEMSource // client-only option InsecureSkipVerify bool @@ -94,7 +126,10 @@ func ClientDefault(ops ...func(*tls.Config)) *tls.Config { } // certPool returns an X.509 certificate pool from `caFile`, the certificate file. -func certPool(caFile string, exclusivePool bool) (*x509.CertPool, error) { +func certPool(ca PEMSource, exclusivePool bool) (*x509.CertPool, error) { + if ca == nil { + return nil, fmt.Errorf("no CA PEM data") + } // If we should verify the server, we need to load a trusted ca var ( certPool *x509.CertPool @@ -108,12 +143,12 @@ func certPool(caFile string, exclusivePool bool) (*x509.CertPool, error) { return nil, fmt.Errorf("failed to read system certificates: %v", err) } } - pem, err := ioutil.ReadFile(caFile) + pem, err := ca.Data() if err != nil { - return nil, fmt.Errorf("could not read CA certificate %q: %v", caFile, err) + return nil, fmt.Errorf("could not read CA certificate %q: %v", ca.Name(), err) } if !certPool.AppendCertsFromPEM(pem) { - return nil, fmt.Errorf("failed to append certificates from PEM file: %q", caFile) + return nil, fmt.Errorf("failed to append certificates from PEM file: %q", ca.Name()) } return certPool, nil } @@ -172,18 +207,24 @@ func getPrivateKey(keyBytes []byte, passphrase string) ([]byte, error) { // if the key is encrypted, the Passphrase in 'options' will be used to // decrypt it. func getCert(options Options) ([]tls.Certificate, error) { - if options.CertFile == "" && options.KeyFile == "" { + if options.Cert == nil && options.Key == nil { return nil, nil } + if options.Cert == nil { + return nil, errors.New("cert is missing") + } + if options.Key == nil { + return nil, errors.New("key is missing") + } errMessage := "Could not load X509 key pair" - cert, err := ioutil.ReadFile(options.CertFile) + cert, err := options.Cert.Data() if err != nil { return nil, errors.Wrap(err, errMessage) } - prKeyBytes, err := ioutil.ReadFile(options.KeyFile) + prKeyBytes, err := options.Key.Data() if err != nil { return nil, errors.Wrap(err, errMessage) } @@ -205,8 +246,8 @@ func getCert(options Options) ([]tls.Certificate, error) { func Client(options Options) (*tls.Config, error) { tlsConfig := ClientDefault() tlsConfig.InsecureSkipVerify = options.InsecureSkipVerify - if !options.InsecureSkipVerify && options.CAFile != "" { - CAs, err := certPool(options.CAFile, options.ExclusiveRootPools) + if !options.InsecureSkipVerify && options.CA != nil { + CAs, err := certPool(options.CA, options.ExclusiveRootPools) if err != nil { return nil, err } @@ -230,16 +271,33 @@ func Client(options Options) (*tls.Config, error) { func Server(options Options) (*tls.Config, error) { tlsConfig := ServerDefault() tlsConfig.ClientAuth = options.ClientAuth - tlsCert, err := tls.LoadX509KeyPair(options.CertFile, options.KeyFile) + if options.Cert == nil { + return nil, errors.New("cert is missing") + } + if options.Key == nil { + return nil, errors.New("key is missing") + } + cert, err := options.Cert.Data() if err != nil { if os.IsNotExist(err) { - return nil, fmt.Errorf("Could not load X509 key pair (cert: %q, key: %q): %v", options.CertFile, options.KeyFile, err) + return nil, fmt.Errorf("Could not load X509 key pair (cert: %q, key: %q): %v", options.Cert.Name(), options.Key.Name(), err) } - return nil, fmt.Errorf("Error reading X509 key pair (cert: %q, key: %q): %v. Make sure the key is not encrypted.", options.CertFile, options.KeyFile, err) + return nil, fmt.Errorf("could not read cert data: %s", err) + } + key, err := options.Key.Data() + if err != nil { + if os.IsNotExist(err) { + return nil, fmt.Errorf("Could not load X509 key pair (cert: %q, key: %q): %v", options.Cert.Name(), options.Key.Name(), err) + } + return nil, fmt.Errorf("could not read key data: %s", err) + } + tlsCert, err := tls.X509KeyPair(cert, key) + if err != nil { + return nil, fmt.Errorf("Error reading X509 key pair (cert: %q, key: %q): %v. Make sure the key is not encrypted.", options.Cert.Name(), options.Key.Name(), err) } tlsConfig.Certificates = []tls.Certificate{tlsCert} - if options.ClientAuth >= tls.VerifyClientCertIfGiven && options.CAFile != "" { - CAs, err := certPool(options.CAFile, options.ExclusiveRootPools) + if options.ClientAuth >= tls.VerifyClientCertIfGiven && options.CA != nil { + CAs, err := certPool(options.CA, options.ExclusiveRootPools) if err != nil { return nil, err } diff --git a/tlsconfig/config_test.go b/tlsconfig/config_test.go index 345cbe778..4b9bdf56c 100644 --- a/tlsconfig/config_test.go +++ b/tlsconfig/config_test.go @@ -85,9 +85,9 @@ func TestConfigServerTLSFailsIfUnableToLoadCerts(t *testing.T) { files[i] = badFile result, err := Server(Options{ - CertFile: files[0], - KeyFile: files[1], - CAFile: files[2], + Cert: PEMFile(files[0]), + Key: PEMFile(files[1]), + CA: PEMFile(files[2]), ClientAuth: tls.VerifyClientCertIfGiven, }) if err == nil || result != nil { @@ -108,8 +108,8 @@ func TestConfigServerTLSServerCertsOnly(t *testing.T) { } tlsConfig, err := Server(Options{ - CertFile: cert, - KeyFile: key, + Cert: PEMFile(cert), + Key: PEMFile(key), }) if err != nil || tlsConfig == nil { t.Fatal("Unable to configure server TLS", err) @@ -145,10 +145,10 @@ func TestConfigServerTLSClientCANotSetIfClientAuthTooLow(t *testing.T) { ca := getMultiCert() tlsConfig, err := Server(Options{ - CertFile: cert, - KeyFile: key, + Cert: PEMFile(cert), + Key: PEMFile(key), ClientAuth: tls.RequestClientCert, - CAFile: ca, + CA: PEMFile(ca), }) if err != nil || tlsConfig == nil { @@ -173,10 +173,10 @@ func TestConfigServerTLSClientCASet(t *testing.T) { ca := getMultiCert() tlsConfig, err := Server(Options{ - CertFile: cert, - KeyFile: key, + Cert: PEMFile(cert), + Key: PEMFile(key), ClientAuth: tls.VerifyClientCertIfGiven, - CAFile: ca, + CA: PEMFile(ca), }) if err != nil || tlsConfig == nil { @@ -226,10 +226,10 @@ func TestConfigServerExclusiveRootPools(t *testing.T) { // ExclusiveRootPools not set, so should be able to verify both system-signed certs // and custom CA-signed certs tlsConfig, err := Server(Options{ - CertFile: cert, - KeyFile: key, + Cert: PEMFile(cert), + Key: PEMFile(key), ClientAuth: tls.VerifyClientCertIfGiven, - CAFile: ca, + CA: PEMFile(ca), }) if err != nil || tlsConfig == nil { @@ -245,10 +245,10 @@ func TestConfigServerExclusiveRootPools(t *testing.T) { // ExclusiveRootPools set and custom CA provided, so system certs should not be verifiable // and custom CA-signed certs should be verifiable tlsConfig, err = Server(Options{ - CertFile: cert, - KeyFile: key, + Cert: PEMFile(cert), + Key: PEMFile(key), ClientAuth: tls.VerifyClientCertIfGiven, - CAFile: ca, + CA: PEMFile(ca), ExclusiveRootPools: true, }) @@ -268,8 +268,8 @@ func TestConfigServerExclusiveRootPools(t *testing.T) { // No CA file provided, system cert should be verifiable only tlsConfig, err = Server(Options{ - CertFile: cert, - KeyFile: key, + Cert: PEMFile(cert), + Key: PEMFile(key), }) if err != nil || tlsConfig == nil { @@ -336,8 +336,8 @@ func TestConfigServerTLSMinVersionIsSetBasedOnOptions(t *testing.T) { for _, v := range versions { tlsConfig, err := Server(Options{ MinVersion: v, - CertFile: cert, - KeyFile: key, + Cert: PEMFile(cert), + Key: PEMFile(key), }) if err != nil || tlsConfig == nil { @@ -357,8 +357,8 @@ func TestConfigServerTLSMinVersionNotSetIfMinVersionIsTooLow(t *testing.T) { _, err := Server(Options{ MinVersion: tls.VersionSSL30, - CertFile: cert, - KeyFile: key, + Cert: PEMFile(cert), + Key: PEMFile(key), }) if err == nil { @@ -373,8 +373,8 @@ func TestConfigServerTLSMinVersionNotSetIfMinVersionIsInvalid(t *testing.T) { _, err := Server(Options{ MinVersion: 1, - CertFile: cert, - KeyFile: key, + Cert: PEMFile(cert), + Key: PEMFile(key), }) if err == nil { @@ -387,7 +387,7 @@ func TestConfigServerTLSMinVersionNotSetIfMinVersionIsInvalid(t *testing.T) { func TestConfigClientTLSNoVerify(t *testing.T) { ca := getMultiCert() - tlsConfig, err := Client(Options{CAFile: ca, InsecureSkipVerify: true}) + tlsConfig, err := Client(Options{CA: PEMFile(ca), InsecureSkipVerify: true}) if err != nil || tlsConfig == nil { t.Fatal("Unable to configure client TLS", err) @@ -438,7 +438,7 @@ func TestConfigClientTLSNoRoot(t *testing.T) { func TestConfigClientTLSRootCAFileWithOneCert(t *testing.T) { ca := getMultiCert() - tlsConfig, err := Client(Options{CAFile: ca}) + tlsConfig, err := Client(Options{CA: PEMFile(ca)}) if err != nil || tlsConfig == nil { t.Fatal("Unable to configure client TLS", err) @@ -458,7 +458,7 @@ func TestConfigClientTLSRootCAFileWithOneCert(t *testing.T) { // An error is returned if a root CA is provided but the file doesn't exist. func TestConfigClientTLSNonexistentRootCAFile(t *testing.T) { - tlsConfig, err := Client(Options{CAFile: "nonexistent"}) + tlsConfig, err := Client(Options{CA: PEMFile("nonexistent")}) if err == nil || tlsConfig != nil { t.Fatal("Should not have been able to configure client TLS", err) @@ -482,7 +482,7 @@ func TestConfigClientTLSClientCertOrKeyInvalid(t *testing.T) { files := []string{cert, key} files[i] = invalid - tlsConfig, err := Client(Options{CertFile: files[0], KeyFile: files[1]}) + tlsConfig, err := Client(Options{Cert: PEMFile(files[0]), Key: PEMFile(files[1])}) if err == nil || tlsConfig != nil { t.Fatal("Should not have been able to configure client TLS", err) } @@ -500,7 +500,49 @@ func TestConfigClientTLSValidClientCertAndKey(t *testing.T) { t.Fatal("Unable to load the generated cert and key") } - tlsConfig, err := Client(Options{CertFile: cert, KeyFile: key}) + tlsConfig, err := Client(Options{Cert: PEMFile(cert), Key: PEMFile(key)}) + + if err != nil || tlsConfig == nil { + t.Fatal("Unable to configure client TLS", err) + } + + if len(tlsConfig.Certificates) != 1 { + t.Fatal("Unexpected client certificates") + } + if len(tlsConfig.Certificates[0].Certificate) != len(keypair.Certificate) { + t.Fatal("Unexpected client certificates") + } + for i, cert := range tlsConfig.Certificates[0].Certificate { + if !bytes.Equal(cert, keypair.Certificate[i]) { + t.Fatal("Unexpected client certificates") + } + } + + if tlsConfig.RootCAs != nil { + t.Fatal("Root CAs should not have been set", err) + } +} + +// The certificate is set if the client cert and client key are provided and +// valid. +func TestConfigClientTLSValidClientCertAndKeyFromMemory(t *testing.T) { + key, cert := getCertAndKey() + + keypair, err := tls.LoadX509KeyPair(cert, key) + if err != nil { + t.Fatal("Unable to load the generated cert and key") + } + + certData, err := ioutil.ReadFile(cert) + if err != nil { + t.Fatal("Unable to load the cert") + } + keyData, err := ioutil.ReadFile(key) + if err != nil { + t.Fatal("Unable to load the key") + } + + tlsConfig, err := Client(Options{Cert: PEMInMemory(certData), Key: PEMInMemory(keyData)}) if err != nil || tlsConfig == nil { t.Fatal("Unable to configure client TLS", err) @@ -529,8 +571,8 @@ func TestConfigClientTLSValidClientCertAndEncryptedKey(t *testing.T) { key, cert := getCertAndEncryptedKey() tlsConfig, err := Client(Options{ - CertFile: cert, - KeyFile: key, + Cert: PEMFile(cert), + Key: PEMFile(key), Passphrase: "FooBar123", }) @@ -549,8 +591,8 @@ func TestConfigClientTLSNotSetWithInvalidPassphrase(t *testing.T) { key, cert := getCertAndEncryptedKey() tlsConfig, err := Client(Options{ - CertFile: cert, - KeyFile: key, + Cert: PEMFile(cert), + Key: PEMFile(key), Passphrase: "InvalidPassphrase", }) @@ -584,7 +626,7 @@ func TestConfigClientExclusiveRootPools(t *testing.T) { // ExclusiveRootPools not set, so should be able to verify both system-signed certs // and custom CA-signed certs - tlsConfig, err := Client(Options{CAFile: ca}) + tlsConfig, err := Client(Options{CA: PEMFile(ca)}) if err != nil || tlsConfig == nil { t.Fatal("Unable to configure client TLS", err) @@ -599,7 +641,7 @@ func TestConfigClientExclusiveRootPools(t *testing.T) { // ExclusiveRootPools set and custom CA provided, so system certs should not be verifiable // and custom CA-signed certs should be verifiable tlsConfig, err = Client(Options{ - CAFile: ca, + CA: PEMFile(ca), ExclusiveRootPools: true, }) @@ -642,8 +684,8 @@ func TestConfigClientTLSMinVersionIsSetBasedOnOptions(t *testing.T) { tlsConfig, err := Client(Options{ MinVersion: tls.VersionTLS12, - CertFile: cert, - KeyFile: key, + Cert: PEMFile(cert), + Key: PEMFile(key), }) if err != nil || tlsConfig == nil { @@ -662,8 +704,8 @@ func TestConfigClientTLSMinVersionNotSetIfMinVersionIsTooLow(t *testing.T) { _, err := Client(Options{ MinVersion: tls.VersionTLS11, - CertFile: cert, - KeyFile: key, + Cert: PEMFile(cert), + Key: PEMFile(key), }) if err == nil { @@ -678,8 +720,8 @@ func TestConfigClientTLSMinVersionNotSetIfMinVersionIsInvalid(t *testing.T) { _, err := Client(Options{ MinVersion: 1, - CertFile: cert, - KeyFile: key, + Cert: PEMFile(cert), + Key: PEMFile(key), }) if err == nil {