Skip to content

Commit

Permalink
refactor TLS config builder
Browse files Browse the repository at this point in the history
This change refactors TLS configuration builder and adds support for
enabling mutual TLS authentication for Centrifugo servers.
  • Loading branch information
tie committed Nov 27, 2023
1 parent 7248d4c commit 7b88e8e
Show file tree
Hide file tree
Showing 3 changed files with 252 additions and 52 deletions.
220 changes: 180 additions & 40 deletions internal/tools/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,51 +23,191 @@ type ReadFileFunc func(name string) ([]byte, error)
// scoped under key prefix.
func MakeTLSConfig(v ConfigGetter, keyPrefix string, readFile ReadFileFunc) (*tls.Config, error) {
tlsConfig := &tls.Config{}
if v.GetString(keyPrefix+"tls_cert") != "" && v.GetString(keyPrefix+"tls_key") != "" {
certPEMBlock, err := readFile(v.GetString(keyPrefix + "tls_cert"))
if err != nil {
return nil, err
}
keyPEMBlock, err := readFile(v.GetString(keyPrefix + "tls_key"))
if err != nil {

loaders := []tlsConfigLoader{
chainTLSConfigLoaders(loadCertFromFile, loadCertFromPEM),
chainTLSConfigLoaders(loadRootCAFromFile, loadRootCAFromPEM),
chainTLSConfigLoaders(loadMutualTLSFromFile, loadMutualTLSFromPEM),
}
for _, loadConfig := range loaders {
if _, err := loadConfig(tlsConfig, v, keyPrefix, readFile); err != nil {
return nil, err
}
cert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock)
if err != nil {
return nil, fmt.Errorf("could not read the certificate/key: %s", err)
}
tlsConfig.Certificates = []tls.Certificate{cert}
} else if v.GetString(keyPrefix+"tls_cert_pem") != "" && v.GetString(keyPrefix+"tls_key_pem") != "" {
cert, err := tls.X509KeyPair([]byte(v.GetString(keyPrefix+"tls_cert_pem")), []byte(v.GetString(keyPrefix+"tls_key_pem")))
if err != nil {
return nil, fmt.Errorf("error creating X509 key pair: %s", err)
}
tlsConfig.Certificates = []tls.Certificate{cert}
}
if v.GetString(keyPrefix+"tls_root_ca") != "" {
caCert, err := readFile(v.GetString(keyPrefix + "tls_root_ca"))
if err != nil {
return nil, fmt.Errorf("can not read the CA certificate: %s", err)
}
caCertPool := x509.NewCertPool()
ok := caCertPool.AppendCertsFromPEM(caCert)
if !ok {
return nil, errors.New("can not parse CA certificate")
}
tlsConfig.RootCAs = caCertPool
} else if v.GetString(keyPrefix+"tls_root_ca_pem") != "" {
caCertPool := x509.NewCertPool()
ok := caCertPool.AppendCertsFromPEM([]byte(v.GetString(keyPrefix + "tls_root_ca_pem")))
if !ok {
return nil, errors.New("can not parse CA certificate")

tlsConfig.ServerName = v.GetString(keyPrefix + "tls_server_name")
tlsConfig.InsecureSkipVerify = v.GetBool(keyPrefix + "tls_insecure_skip_verify")

return tlsConfig, nil
}

// tlsConfigLoader is a function that loads TLS from the given ConfigGetter.
// It returns false, nil if configuration does not exist, true, nil on success,
// or true, err ≠ nil if there was an error loading the configuration.
type tlsConfigLoader func(c *tls.Config, v ConfigGetter, keyPrefix string, readFile ReadFileFunc) (bool, error)

// chainTLSConfigLoaders returns tlsConfigLoader function that attempts to load
// TLS configuration until either a configuration is found or an error occurs.
func chainTLSConfigLoaders(loaders ...tlsConfigLoader) tlsConfigLoader {
return func(c *tls.Config, v ConfigGetter, keyPrefix string, readFile ReadFileFunc) (bool, error) {
for _, f := range loaders {
found, err := f(c, v, keyPrefix, readFile)
if found || err != nil {
return found, err
}
}
tlsConfig.RootCAs = caCertPool
return false, nil
}
}

// loadCertFromFile loads the TLS configuration with certificate from key pair
// files containing PEM-encoded TLS key and certificate.
func loadCertFromFile(tlsConfig *tls.Config, v ConfigGetter, keyPrefix string, readFile ReadFileFunc) (bool, error) {
certFileKeyName, keyFileKeyName := keyPrefix+"tls_cert", keyPrefix+"tls_key"

certFile, keyFile := v.GetString(certFileKeyName), v.GetString(keyFileKeyName)
if certFile == "" || keyFile == "" {
return false, nil
}
if v.GetString(keyPrefix+"tls_server_name") != "" {
tlsConfig.ServerName = v.GetString(keyPrefix + "tls_server_name")

certPEMBlock, err := readFile(certFile)
if err != nil {
return true, fmt.Errorf("read TLS certificate for %s: %w", certFileKeyName, err)
}
if v.GetBool(keyPrefix + "tls_insecure_skip_verify") {
tlsConfig.InsecureSkipVerify = true
keyPEMBlock, err := readFile(keyFile)
if err != nil {
return true, fmt.Errorf("read TLS key for %s: %w", keyFileKeyName, err)
}
return tlsConfig, nil

cert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock)
if err != nil {
return true, fmt.Errorf("parse certificate/key pair for %s/%s: %w", certFileKeyName, keyFileKeyName, err)
}

tlsConfig.Certificates = []tls.Certificate{cert}

return true, nil
}

// loadCertFromPEM loads the TLS configuration with certificate from key pair
// strings containing PEM-encoded TLS key and certificate.
func loadCertFromPEM(tlsConfig *tls.Config, v ConfigGetter, keyPrefix string, _ ReadFileFunc) (bool, error) {
certPEMKeyName, keyPEMKeyName := keyPrefix+"tls_cert_pem", keyPrefix+"tls_key_pem"

certPEM, keyPEM := v.GetString(certPEMKeyName), v.GetString(keyPEMKeyName)
if certPEM == "" || keyPEM == "" {
return false, nil
}

cert, err := tls.X509KeyPair([]byte(certPEM), []byte(keyPEM))
if err != nil {
return true, fmt.Errorf("parse certificate/key pair for %s/%s: %w", certPEMKeyName, keyPEMKeyName, err)
}

tlsConfig.Certificates = []tls.Certificate{cert}

return true, nil
}

// loadRootCAFromFile loads the TLS configuration with root CA bundle from file
// containing PEM-encoded certificates.
func loadRootCAFromFile(tlsConfig *tls.Config, v ConfigGetter, keyPrefix string, readFile ReadFileFunc) (bool, error) {
keyName := keyPrefix + "tls_root_ca"

rootCAFile := v.GetString(keyName)
if rootCAFile == "" {
return false, nil
}

caCert, err := readFile(rootCAFile)
if err != nil {
return true, fmt.Errorf("read the root CA certificate for %s: %w", keyName, err)
}

caCertPool, err := newCertPoolFromPEM(caCert)
if err != nil {
return true, fmt.Errorf("parse root CA certificate for %s: %w", keyName, err)
}

tlsConfig.RootCAs = caCertPool

return true, nil
}

// loadRootCAFromFile loads the TLS configuration with root CA bundle from
// string containing PEM-encoded certificates.
func loadRootCAFromPEM(tlsConfig *tls.Config, v ConfigGetter, keyPrefix string, _ ReadFileFunc) (bool, error) {
keyName := keyPrefix + "tls_root_ca_pem"

rootCAPEM := v.GetString(keyName)
if rootCAPEM == "" {
return false, nil
}

caCertPool, err := newCertPoolFromPEM([]byte(rootCAPEM))
if err != nil {
return true, fmt.Errorf("parse root CA certificate for %s: %w", keyName, err)
}

tlsConfig.RootCAs = caCertPool

return true, nil
}

// loadMutualTLSFromFile loads the TLS configuration for server-side mutual TLS
// authentication from file containing PEM-encoded certificates.
func loadMutualTLSFromFile(tlsConfig *tls.Config, v ConfigGetter, keyPrefix string, readFile ReadFileFunc) (bool, error) {
keyName := keyPrefix + "tls_client_ca"

clientCAFile := v.GetString(keyName)
if clientCAFile == "" {
return false, nil
}

caCert, err := readFile(clientCAFile)
if err != nil {
return true, fmt.Errorf("read the client CA certificate for %s: %w", keyName, err)
}

caCertPool, err := newCertPoolFromPEM(caCert)
if err != nil {
return true, fmt.Errorf("parse client CA certificate for %s: %w", keyName, err)
}

tlsConfig.ClientCAs = caCertPool
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert

return true, nil
}

// loadMutualTLSFromFile loads the TLS configuration for server-side mutual TLS
// authentication from string containing PEM-encoded certificates.
func loadMutualTLSFromPEM(tlsConfig *tls.Config, v ConfigGetter, keyPrefix string, _ ReadFileFunc) (bool, error) {
keyName := keyPrefix + "tls_client_ca_pem"

clientCAPEM := v.GetString(keyName)
if clientCAPEM == "" {
return false, nil
}

caCertPool, err := newCertPoolFromPEM([]byte(clientCAPEM))
if err != nil {
return true, fmt.Errorf("parse client CA certificate for %s: %w", keyName, err)
}

tlsConfig.ClientCAs = caCertPool
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert

return true, nil
}

// newCertPoolFromPEM returns certificate pool for the given PEM-encoded
// certificate bundle. Note that it currently ignores invalid blocks.
func newCertPoolFromPEM(pem []byte) (*x509.CertPool, error) {
certPool := x509.NewCertPool()
ok := certPool.AppendCertsFromPEM(pem)
if !ok {
return nil, errors.New("no valid certificates found")
}
return certPool, nil
}
74 changes: 62 additions & 12 deletions internal/tools/tls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ package tools

import (
"crypto/tls"
"crypto/x509"
"errors"
"strconv"
"testing"
"testing/fstest"
Expand Down Expand Up @@ -198,6 +196,50 @@ func TestMakeTLSConfig(t *testing.T) {
"certs.pem": &fstest.MapFile{},
},
errOK: true,
}, {
name: "clientCAPEM",
config: testConfigGetter{
"tls_client_ca_pem": testCertPEM,
},
expect: tls.Config{
ClientCAs: testCertPool,
ClientAuth: tls.RequireAndVerifyClientCert,
},
}, {
name: "badClientCAPEM",
config: testConfigGetter{
"tls_client_ca_pem": "garbage",
},
errOK: true,
}, {
name: "clientCAFile",
config: testConfigGetter{
"tls_client_ca": "certs.pem",
},
fsys: fstest.MapFS{
"certs.pem": &fstest.MapFile{
Data: []byte(testCertPEM),
},
},
expect: tls.Config{
ClientCAs: testCertPool,
ClientAuth: tls.RequireAndVerifyClientCert,
},
}, {
name: "missingClientCAFile",
config: testConfigGetter{
"tls_client_ca": "certs.pem",
},
errOK: true,
}, {
name: "badClientCAFile",
config: testConfigGetter{
"tls_client_ca": "certs.pem",
},
fsys: fstest.MapFS{
"certs.pem": &fstest.MapFile{},
},
errOK: true,
}, {
name: "prefixedWithAllFields",
config: testConfigGetter{
Expand All @@ -207,6 +249,8 @@ func TestMakeTLSConfig(t *testing.T) {
"test_tls_key_pem": "garbage",
"test_tls_root_ca": "root.pem",
"test_tls_root_ca_pem": "garbage",
"test_tls_client_ca": "client.pem",
"test_tls_client_ca_pem": "garbage",
"test_tls_server_name": "example.com",
"test_tls_insecure_skip_verify": "true",
},
Expand All @@ -221,11 +265,16 @@ func TestMakeTLSConfig(t *testing.T) {
"root.pem": &fstest.MapFile{
Data: []byte(testCertPEM),
},
"client.pem": &fstest.MapFile{
Data: []byte(testCertPEM),
},
},
expect: tls.Config{
ServerName: "example.com",
InsecureSkipVerify: true,
RootCAs: testCertPool,
ClientCAs: testCertPool,
ClientAuth: tls.RequireAndVerifyClientCert,
Certificates: []tls.Certificate{testCertificate},
},
}, {
Expand All @@ -234,6 +283,7 @@ func TestMakeTLSConfig(t *testing.T) {
"test_tls_cert_pem": testCertPEM,
"test_tls_key_pem": testKeyPEM,
"test_tls_root_ca_pem": testCertPEM,
"test_tls_client_ca_pem": testCertPEM,
"test_tls_server_name": "example.com",
"test_tls_insecure_skip_verify": "true",
},
Expand All @@ -242,6 +292,8 @@ func TestMakeTLSConfig(t *testing.T) {
ServerName: "example.com",
InsecureSkipVerify: true,
RootCAs: testCertPool,
ClientCAs: testCertPool,
ClientAuth: tls.RequireAndVerifyClientCert,
Certificates: []tls.Certificate{testCertificate},
},
}}
Expand Down Expand Up @@ -286,15 +338,13 @@ func checkTLSConfig(t *testing.T, a, b *tls.Config) {
if !a.RootCAs.Equal(b.RootCAs) {
t.Error("expected tls.Config.RootCAs to be equal")
}
}

// newCertPoolFromPEM returns certificate pool for the given PEM-encoded
// certificate bundle. Note that it currently ignores invalid blocks.
func newCertPoolFromPEM(pem []byte) (*x509.CertPool, error) {
certPool := x509.NewCertPool()
ok := certPool.AppendCertsFromPEM(pem)
if !ok {
return nil, errors.New("no valid certificates found")
if !a.ClientCAs.Equal(b.ClientCAs) {
t.Error("expected tls.Config.ClientCAs to be equal")
}
if a.ClientAuth != b.ClientAuth {
t.Errorf(
"expected tls.Config.ClientAuth length to be %s, but got %s",
a.ClientAuth, b.ClientAuth,
)
}
return certPool, nil
}
Loading

0 comments on commit 7b88e8e

Please sign in to comment.