Skip to content

Commit

Permalink
add tests for tools.MakeTLSConfig
Browse files Browse the repository at this point in the history
This change is a preparation for refactoring MakeTLSConfig and adds
tests to ensure that we keep the current behavior. To make testing
easier, we pass os.ReadFile as an argument and use fstest.MapFS.ReadFile
in tests. In particular, this also means avoiding tls.LoadX509KeyPair.
  • Loading branch information
tie committed Nov 26, 2023
1 parent 99526d1 commit 7248d4c
Show file tree
Hide file tree
Showing 3 changed files with 325 additions and 9 deletions.
24 changes: 20 additions & 4 deletions internal/tools/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,34 @@ import (
"crypto/x509"
"errors"
"fmt"
"os"
)

type ConfigGetter interface {
GetBool(name string) bool
GetString(name string) string
}

func MakeTLSConfig(v ConfigGetter, keyPrefix string) (*tls.Config, error) {
// ReadFileFunc is an abstraction for os.ReadFile but also io/fs.ReadFile
// wrapped with an io/fs.FS instance.
//
// Note that os.DirFS has slightly different semantics compared to the native
// filesystem APIs, see https://go.dev/issue/44279
type ReadFileFunc func(name string) ([]byte, error)

// MakeTLSConfig constructs a tls.Config instance using the given configuration
// scoped under key prefix.
func MakeTLSConfig(v ConfigGetter, keyPrefix string, readFile ReadFileFunc) (*tls.Config, error) {
tlsConfig := &tls.Config{}
if v.GetString(keyPrefix+"tls_cert") != "" && v.GetString(keyPrefix+"tls_key") != "" {
cert, err := tls.LoadX509KeyPair(v.GetString(keyPrefix+"tls_cert"), v.GetString(keyPrefix+"tls_key"))
certPEMBlock, err := readFile(v.GetString(keyPrefix + "tls_cert"))
if err != nil {
return nil, err
}
keyPEMBlock, err := readFile(v.GetString(keyPrefix + "tls_key"))
if err != nil {
return nil, err
}
cert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock)
if err != nil {
return nil, fmt.Errorf("could not read the certificate/key: %s", err)
}
Expand All @@ -29,7 +45,7 @@ func MakeTLSConfig(v ConfigGetter, keyPrefix string) (*tls.Config, error) {
tlsConfig.Certificates = []tls.Certificate{cert}
}
if v.GetString(keyPrefix+"tls_root_ca") != "" {
caCert, err := os.ReadFile(v.GetString(keyPrefix + "tls_root_ca"))
caCert, err := readFile(v.GetString(keyPrefix + "tls_root_ca"))
if err != nil {
return nil, fmt.Errorf("can not read the CA certificate: %s", err)
}
Expand Down
300 changes: 300 additions & 0 deletions internal/tools/tls_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,300 @@
package tools

import (
"crypto/tls"
"crypto/x509"
"errors"
"strconv"
"testing"
"testing/fstest"
)

const testCertPEM = `-----BEGIN CERTIFICATE-----
MIIBhTCCASugAwIBAgIQIRi6zePL6mKjOipn+dNuaTAKBggqhkjOPQQDAjASMRAw
DgYDVQQKEwdBY21lIENvMB4XDTE3MTAyMDE5NDMwNloXDTE4MTAyMDE5NDMwNlow
EjEQMA4GA1UEChMHQWNtZSBDbzBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABD0d
7VNhbWvZLWPuj/RtHFjvtJBEwOkhbN/BnnE8rnZR8+sbwnc/KhCk3FhnpHZnQz7B
5aETbbIgmuvewdjvSBSjYzBhMA4GA1UdDwEB/wQEAwICpDATBgNVHSUEDDAKBggr
BgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MCkGA1UdEQQiMCCCDmxvY2FsaG9zdDo1
NDUzgg4xMjcuMC4wLjE6NTQ1MzAKBggqhkjOPQQDAgNIADBFAiEA2zpJEPQyz6/l
Wf86aX6PepsntZv2GYlA5UpabfT2EZICICpJ5h/iI+i341gBmLiAFQOyTDT+/wQc
6MF9+Yw1Yy0t
-----END CERTIFICATE-----`

const testKeyPEM = `-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIIrYSSNQFaA2Hwf1duRSxKtLYX5CB04fSeQ6tF1aY/PuoAoGCCqGSM49
AwEHoUQDQgAEPR3tU2Fta9ktY+6P9G0cWO+0kETA6SFs38GecTyudlHz6xvCdz8q
EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA==
-----END EC PRIVATE KEY-----`

var testCertificate, _ = tls.X509KeyPair([]byte(testCertPEM), []byte(testKeyPEM))

var testCertPool, _ = newCertPoolFromPEM([]byte(testCertPEM))

type testConfigGetter map[string]string

func (c testConfigGetter) GetBool(name string) bool {
v, _ := strconv.ParseBool(c[name])
return v
}

func (c testConfigGetter) GetString(name string) string {
return c[name]
}

func TestMakeTLSConfig(t *testing.T) {
testCases := []struct {
name string
config testConfigGetter
prefix string
fsys fstest.MapFS
expect tls.Config
errOK bool
}{{
name: "empty",
}, {
name: "serverName",
config: testConfigGetter{
"tls_server_name": "example.com",
},
expect: tls.Config{
ServerName: "example.com",
},
}, {
name: "insecureSkipVerify",
config: testConfigGetter{
"tls_insecure_skip_verify": "true",
},
expect: tls.Config{
InsecureSkipVerify: true,
},
}, {
name: "certPEM",
config: testConfigGetter{
"tls_key_pem": testKeyPEM,
"tls_cert_pem": testCertPEM,
},
expect: tls.Config{
Certificates: []tls.Certificate{testCertificate},
},
}, {
name: "certPEMWithoutKey",
config: testConfigGetter{
"tls_cert_pem": testCertPEM,
},
}, {
name: "keyPEMWithoutCert",
config: testConfigGetter{
"tls_key_pem": testKeyPEM,
},
}, {
name: "badKeyPairPEM",
config: testConfigGetter{
"tls_key_pem": "garbage",
"tls_cert_pem": "garbage",
},
errOK: true,
}, {
name: "certFile",
config: testConfigGetter{
"tls_key": "tls/centrifugo.key",
"tls_cert": "tls/centrifugo.cert",
},
fsys: fstest.MapFS{
"tls/centrifugo.key": &fstest.MapFile{
Data: []byte(testKeyPEM),
},
"tls/centrifugo.cert": &fstest.MapFile{
Data: []byte(testCertPEM),
},
},
expect: tls.Config{
Certificates: []tls.Certificate{testCertificate},
},
}, {
name: "certFileWithoutKey",
config: testConfigGetter{
"tls_cert": "tls/centrifugo.cert",
},
}, {
name: "keyFileWithoutCert",
config: testConfigGetter{
"tls_key": "tls/centrifugo.key",
},
}, {
name: "missingCertFile",
config: testConfigGetter{
"tls_key": "tls/centrifugo.key",
"tls_cert": "tls/centrifugo.cert",
},
fsys: fstest.MapFS{
"tls/centrifugo.key": &fstest.MapFile{
Data: []byte(testKeyPEM),
},
},
errOK: true,
}, {
name: "missingKeyFile",
config: testConfigGetter{
"tls_key": "tls/centrifugo.key",
"tls_cert": "tls/centrifugo.cert",
},
fsys: fstest.MapFS{
"tls/centrifugo.cert": &fstest.MapFile{
Data: []byte(testCertPEM),
},
},
errOK: true,
}, {
name: "badKeyPairFile",
config: testConfigGetter{
"tls_key": "tls/centrifugo.key",
"tls_cert": "tls/centrifugo.cert",
},
fsys: fstest.MapFS{
"tls/centrifugo.key": &fstest.MapFile{},
"tls/centrifugo.cert": &fstest.MapFile{},
},
errOK: true,
}, {
name: "rootCAPEM",
config: testConfigGetter{
"tls_root_ca_pem": testCertPEM,
},
expect: tls.Config{
RootCAs: testCertPool,
},
}, {
name: "badRootCAPEM",
config: testConfigGetter{
"tls_root_ca_pem": "garbage",
},
errOK: true,
}, {
name: "rootCAFile",
config: testConfigGetter{
"tls_root_ca": "certs.pem",
},
fsys: fstest.MapFS{
"certs.pem": &fstest.MapFile{
Data: []byte(testCertPEM),
},
},
expect: tls.Config{
RootCAs: testCertPool,
},
}, {
name: "missingRootCAFile",
config: testConfigGetter{
"tls_root_ca": "certs.pem",
},
errOK: true,
}, {
name: "badRootCAFile",
config: testConfigGetter{
"tls_root_ca": "certs.pem",
},
fsys: fstest.MapFS{
"certs.pem": &fstest.MapFile{},
},
errOK: true,
}, {
name: "prefixedWithAllFields",
config: testConfigGetter{
"test_tls_cert": "tls/centrifugo.cert",
"test_tls_key": "tls/centrifugo.key",
"test_tls_cert_pem": "garbage",
"test_tls_key_pem": "garbage",
"test_tls_root_ca": "root.pem",
"test_tls_root_ca_pem": "garbage",
"test_tls_server_name": "example.com",
"test_tls_insecure_skip_verify": "true",
},
prefix: "test_",
fsys: fstest.MapFS{
"tls/centrifugo.key": &fstest.MapFile{
Data: []byte(testKeyPEM),
},
"tls/centrifugo.cert": &fstest.MapFile{
Data: []byte(testCertPEM),
},
"root.pem": &fstest.MapFile{
Data: []byte(testCertPEM),
},
},
expect: tls.Config{
ServerName: "example.com",
InsecureSkipVerify: true,
RootCAs: testCertPool,
Certificates: []tls.Certificate{testCertificate},
},
}, {
name: "prefixedWithAllPEMFields",
config: testConfigGetter{
"test_tls_cert_pem": testCertPEM,
"test_tls_key_pem": testKeyPEM,
"test_tls_root_ca_pem": testCertPEM,
"test_tls_server_name": "example.com",
"test_tls_insecure_skip_verify": "true",
},
prefix: "test_",
expect: tls.Config{
ServerName: "example.com",
InsecureSkipVerify: true,
RootCAs: testCertPool,
Certificates: []tls.Certificate{testCertificate},
},
}}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
config, err := MakeTLSConfig(tc.config, tc.prefix, tc.fsys.ReadFile)
switch {
case tc.errOK:
if err == nil {
t.Fatal("expected error")
}
case err != nil:
t.Fatal(err)
default:
checkTLSConfig(t, &tc.expect, config)
}
})
}
}

func checkTLSConfig(t *testing.T, a, b *tls.Config) {
if a.ServerName != b.ServerName {
t.Errorf(
"expected tls.Config.ServerName to be %q, but got %q",
a.ServerName, b.ServerName,
)
}
if a.InsecureSkipVerify != b.InsecureSkipVerify {
t.Errorf(
"expected tls.Config.InsecureSkipVerify to be %t, but got %t",
a.InsecureSkipVerify, b.InsecureSkipVerify,
)
}
if len(a.Certificates) != len(b.Certificates) {
// TODO: check that tls.Certificate instances are equal.
t.Errorf(
"expected tls.Config.Certificates length to be %d, but got %d",
len(a.Certificates), len(b.Certificates),
)
}
if !a.RootCAs.Equal(b.RootCAs) {
t.Error("expected tls.Config.RootCAs to be equal")
}
}

// newCertPoolFromPEM returns certificate pool for the given PEM-encoded
// certificate bundle. Note that it currently ignores invalid blocks.
func newCertPoolFromPEM(pem []byte) (*x509.CertPool, error) {
certPool := x509.NewCertPool()
ok := certPool.AppendCertsFromPEM(pem)
if !ok {
return nil, errors.New("no valid certificates found")
}
return certPool, nil
}
10 changes: 5 additions & 5 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -1318,18 +1318,18 @@ func getTLSConfig() (*tls.Config, error) {

} else if tlsEnabled {
// Autocert disabled - just try to use provided SSL cert and key files.
return tools.MakeTLSConfig(viper.GetViper(), "")
return tools.MakeTLSConfig(viper.GetViper(), "", os.ReadFile)
}

return nil, nil
}

func tlsConfigForGRPC() (*tls.Config, error) {
return tools.MakeTLSConfig(viper.GetViper(), "grpc_api_")
return tools.MakeTLSConfig(viper.GetViper(), "grpc_api_", os.ReadFile)
}

func tlsConfigForUniGRPC() (*tls.Config, error) {
return tools.MakeTLSConfig(viper.GetViper(), "uni_grpc_")
return tools.MakeTLSConfig(viper.GetViper(), "uni_grpc_", os.ReadFile)
}

type httpErrorLogWriter struct {
Expand Down Expand Up @@ -2435,7 +2435,7 @@ func addRedisShardCommonSettings(shardConf *centrifuge.RedisShardConfig) {
shardConf.ClientName = viper.GetString("redis_client_name")

if viper.GetBool("redis_tls") {
tlsConfig, err := tools.MakeTLSConfig(viper.GetViper(), "redis_")
tlsConfig, err := tools.MakeTLSConfig(viper.GetViper(), "redis_", os.ReadFile)
if err != nil {
log.Fatal().Msgf("error creating Redis TLS config: %v", err)
}
Expand Down Expand Up @@ -2498,7 +2498,7 @@ func getRedisShardConfigs() ([]centrifuge.RedisShardConfig, string, error) {
}
conf.SentinelClientName = viper.GetString("redis_sentinel_client_name")
if viper.GetBool("redis_sentinel_tls") {
tlsConfig, err := tools.MakeTLSConfig(viper.GetViper(), "redis_sentinel_")
tlsConfig, err := tools.MakeTLSConfig(viper.GetViper(), "redis_sentinel_", os.ReadFile)
if err != nil {
log.Fatal().Msgf("error creating Redis Sentinel TLS config: %v", err)
}
Expand Down

0 comments on commit 7248d4c

Please sign in to comment.