diff --git a/internal/tools/tls.go b/internal/tools/tls.go index 269f9cfe42..d59fa15846 100644 --- a/internal/tools/tls.go +++ b/internal/tools/tls.go @@ -5,7 +5,6 @@ import ( "crypto/x509" "errors" "fmt" - "os" ) type ConfigGetter interface { @@ -13,10 +12,27 @@ type ConfigGetter interface { GetString(name string) string } -func MakeTLSConfig(v ConfigGetter, keyPrefix string) (*tls.Config, error) { +// ReadFileFunc is an abstraction for os.ReadFile but also io/fs.ReadFile +// wrapped with an io/fs.FS instance. +// +// Note that os.DirFS has slightly different semantics compared to the native +// filesystem APIs, see https://go.dev/issue/44279 +type ReadFileFunc func(name string) ([]byte, error) + +// MakeTLSConfig constructs a tls.Config instance using the given configuration +// scoped under key prefix. +func MakeTLSConfig(v ConfigGetter, keyPrefix string, readFile ReadFileFunc) (*tls.Config, error) { tlsConfig := &tls.Config{} if v.GetString(keyPrefix+"tls_cert") != "" && v.GetString(keyPrefix+"tls_key") != "" { - cert, err := tls.LoadX509KeyPair(v.GetString(keyPrefix+"tls_cert"), v.GetString(keyPrefix+"tls_key")) + certPEMBlock, err := readFile(v.GetString(keyPrefix + "tls_cert")) + if err != nil { + return nil, err + } + keyPEMBlock, err := readFile(v.GetString(keyPrefix + "tls_key")) + if err != nil { + return nil, err + } + cert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock) if err != nil { return nil, fmt.Errorf("could not read the certificate/key: %s", err) } @@ -29,7 +45,7 @@ func MakeTLSConfig(v ConfigGetter, keyPrefix string) (*tls.Config, error) { tlsConfig.Certificates = []tls.Certificate{cert} } if v.GetString(keyPrefix+"tls_root_ca") != "" { - caCert, err := os.ReadFile(v.GetString(keyPrefix + "tls_root_ca")) + caCert, err := readFile(v.GetString(keyPrefix + "tls_root_ca")) if err != nil { return nil, fmt.Errorf("can not read the CA certificate: %s", err) } diff --git a/internal/tools/tls_test.go b/internal/tools/tls_test.go new file mode 100644 index 0000000000..34dae9da17 --- /dev/null +++ b/internal/tools/tls_test.go @@ -0,0 +1,300 @@ +package tools + +import ( + "crypto/tls" + "crypto/x509" + "errors" + "strconv" + "testing" + "testing/fstest" +) + +const testCertPEM = `-----BEGIN CERTIFICATE----- +MIIBhTCCASugAwIBAgIQIRi6zePL6mKjOipn+dNuaTAKBggqhkjOPQQDAjASMRAw +DgYDVQQKEwdBY21lIENvMB4XDTE3MTAyMDE5NDMwNloXDTE4MTAyMDE5NDMwNlow +EjEQMA4GA1UEChMHQWNtZSBDbzBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABD0d +7VNhbWvZLWPuj/RtHFjvtJBEwOkhbN/BnnE8rnZR8+sbwnc/KhCk3FhnpHZnQz7B +5aETbbIgmuvewdjvSBSjYzBhMA4GA1UdDwEB/wQEAwICpDATBgNVHSUEDDAKBggr +BgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MCkGA1UdEQQiMCCCDmxvY2FsaG9zdDo1 +NDUzgg4xMjcuMC4wLjE6NTQ1MzAKBggqhkjOPQQDAgNIADBFAiEA2zpJEPQyz6/l +Wf86aX6PepsntZv2GYlA5UpabfT2EZICICpJ5h/iI+i341gBmLiAFQOyTDT+/wQc +6MF9+Yw1Yy0t +-----END CERTIFICATE-----` + +const testKeyPEM = `-----BEGIN EC PRIVATE KEY----- +MHcCAQEEIIrYSSNQFaA2Hwf1duRSxKtLYX5CB04fSeQ6tF1aY/PuoAoGCCqGSM49 +AwEHoUQDQgAEPR3tU2Fta9ktY+6P9G0cWO+0kETA6SFs38GecTyudlHz6xvCdz8q +EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA== +-----END EC PRIVATE KEY-----` + +var testCertificate, _ = tls.X509KeyPair([]byte(testCertPEM), []byte(testKeyPEM)) + +var testCertPool, _ = newCertPoolFromPEM([]byte(testCertPEM)) + +type testConfigGetter map[string]string + +func (c testConfigGetter) GetBool(name string) bool { + v, _ := strconv.ParseBool(c[name]) + return v +} + +func (c testConfigGetter) GetString(name string) string { + return c[name] +} + +func TestMakeTLSConfig(t *testing.T) { + testCases := []struct { + name string + config testConfigGetter + prefix string + fsys fstest.MapFS + expect tls.Config + errOK bool + }{{ + name: "empty", + }, { + name: "serverName", + config: testConfigGetter{ + "tls_server_name": "example.com", + }, + expect: tls.Config{ + ServerName: "example.com", + }, + }, { + name: "insecureSkipVerify", + config: testConfigGetter{ + "tls_insecure_skip_verify": "true", + }, + expect: tls.Config{ + InsecureSkipVerify: true, + }, + }, { + name: "certPEM", + config: testConfigGetter{ + "tls_key_pem": testKeyPEM, + "tls_cert_pem": testCertPEM, + }, + expect: tls.Config{ + Certificates: []tls.Certificate{testCertificate}, + }, + }, { + name: "certPEMWithoutKey", + config: testConfigGetter{ + "tls_cert_pem": testCertPEM, + }, + }, { + name: "keyPEMWithoutCert", + config: testConfigGetter{ + "tls_key_pem": testKeyPEM, + }, + }, { + name: "badKeyPairPEM", + config: testConfigGetter{ + "tls_key_pem": "garbage", + "tls_cert_pem": "garbage", + }, + errOK: true, + }, { + name: "certFile", + config: testConfigGetter{ + "tls_key": "tls/centrifugo.key", + "tls_cert": "tls/centrifugo.cert", + }, + fsys: fstest.MapFS{ + "tls/centrifugo.key": &fstest.MapFile{ + Data: []byte(testKeyPEM), + }, + "tls/centrifugo.cert": &fstest.MapFile{ + Data: []byte(testCertPEM), + }, + }, + expect: tls.Config{ + Certificates: []tls.Certificate{testCertificate}, + }, + }, { + name: "certFileWithoutKey", + config: testConfigGetter{ + "tls_cert": "tls/centrifugo.cert", + }, + }, { + name: "keyFileWithoutCert", + config: testConfigGetter{ + "tls_key": "tls/centrifugo.key", + }, + }, { + name: "missingCertFile", + config: testConfigGetter{ + "tls_key": "tls/centrifugo.key", + "tls_cert": "tls/centrifugo.cert", + }, + fsys: fstest.MapFS{ + "tls/centrifugo.key": &fstest.MapFile{ + Data: []byte(testKeyPEM), + }, + }, + errOK: true, + }, { + name: "missingKeyFile", + config: testConfigGetter{ + "tls_key": "tls/centrifugo.key", + "tls_cert": "tls/centrifugo.cert", + }, + fsys: fstest.MapFS{ + "tls/centrifugo.cert": &fstest.MapFile{ + Data: []byte(testCertPEM), + }, + }, + errOK: true, + }, { + name: "badKeyPairFile", + config: testConfigGetter{ + "tls_key": "tls/centrifugo.key", + "tls_cert": "tls/centrifugo.cert", + }, + fsys: fstest.MapFS{ + "tls/centrifugo.key": &fstest.MapFile{}, + "tls/centrifugo.cert": &fstest.MapFile{}, + }, + errOK: true, + }, { + name: "rootCAPEM", + config: testConfigGetter{ + "tls_root_ca_pem": testCertPEM, + }, + expect: tls.Config{ + RootCAs: testCertPool, + }, + }, { + name: "badRootCAPEM", + config: testConfigGetter{ + "tls_root_ca_pem": "garbage", + }, + errOK: true, + }, { + name: "rootCAFile", + config: testConfigGetter{ + "tls_root_ca": "certs.pem", + }, + fsys: fstest.MapFS{ + "certs.pem": &fstest.MapFile{ + Data: []byte(testCertPEM), + }, + }, + expect: tls.Config{ + RootCAs: testCertPool, + }, + }, { + name: "missingRootCAFile", + config: testConfigGetter{ + "tls_root_ca": "certs.pem", + }, + errOK: true, + }, { + name: "badRootCAFile", + config: testConfigGetter{ + "tls_root_ca": "certs.pem", + }, + fsys: fstest.MapFS{ + "certs.pem": &fstest.MapFile{}, + }, + errOK: true, + }, { + name: "prefixedWithAllFields", + config: testConfigGetter{ + "test_tls_cert": "tls/centrifugo.cert", + "test_tls_key": "tls/centrifugo.key", + "test_tls_cert_pem": "garbage", + "test_tls_key_pem": "garbage", + "test_tls_root_ca": "root.pem", + "test_tls_root_ca_pem": "garbage", + "test_tls_server_name": "example.com", + "test_tls_insecure_skip_verify": "true", + }, + prefix: "test_", + fsys: fstest.MapFS{ + "tls/centrifugo.key": &fstest.MapFile{ + Data: []byte(testKeyPEM), + }, + "tls/centrifugo.cert": &fstest.MapFile{ + Data: []byte(testCertPEM), + }, + "root.pem": &fstest.MapFile{ + Data: []byte(testCertPEM), + }, + }, + expect: tls.Config{ + ServerName: "example.com", + InsecureSkipVerify: true, + RootCAs: testCertPool, + Certificates: []tls.Certificate{testCertificate}, + }, + }, { + name: "prefixedWithAllPEMFields", + config: testConfigGetter{ + "test_tls_cert_pem": testCertPEM, + "test_tls_key_pem": testKeyPEM, + "test_tls_root_ca_pem": testCertPEM, + "test_tls_server_name": "example.com", + "test_tls_insecure_skip_verify": "true", + }, + prefix: "test_", + expect: tls.Config{ + ServerName: "example.com", + InsecureSkipVerify: true, + RootCAs: testCertPool, + Certificates: []tls.Certificate{testCertificate}, + }, + }} + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + config, err := MakeTLSConfig(tc.config, tc.prefix, tc.fsys.ReadFile) + switch { + case tc.errOK: + if err == nil { + t.Fatal("expected error") + } + case err != nil: + t.Fatal(err) + default: + checkTLSConfig(t, &tc.expect, config) + } + }) + } +} + +func checkTLSConfig(t *testing.T, a, b *tls.Config) { + if a.ServerName != b.ServerName { + t.Errorf( + "expected tls.Config.ServerName to be %q, but got %q", + a.ServerName, b.ServerName, + ) + } + if a.InsecureSkipVerify != b.InsecureSkipVerify { + t.Errorf( + "expected tls.Config.InsecureSkipVerify to be %t, but got %t", + a.InsecureSkipVerify, b.InsecureSkipVerify, + ) + } + if len(a.Certificates) != len(b.Certificates) { + // TODO: check that tls.Certificate instances are equal. + t.Errorf( + "expected tls.Config.Certificates length to be %d, but got %d", + len(a.Certificates), len(b.Certificates), + ) + } + if !a.RootCAs.Equal(b.RootCAs) { + t.Error("expected tls.Config.RootCAs to be equal") + } +} + +// newCertPoolFromPEM returns certificate pool for the given PEM-encoded +// certificate bundle. Note that it currently ignores invalid blocks. +func newCertPoolFromPEM(pem []byte) (*x509.CertPool, error) { + certPool := x509.NewCertPool() + ok := certPool.AppendCertsFromPEM(pem) + if !ok { + return nil, errors.New("no valid certificates found") + } + return certPool, nil +} diff --git a/main.go b/main.go index 661bafa89e..9036c187c2 100644 --- a/main.go +++ b/main.go @@ -1318,18 +1318,18 @@ func getTLSConfig() (*tls.Config, error) { } else if tlsEnabled { // Autocert disabled - just try to use provided SSL cert and key files. - return tools.MakeTLSConfig(viper.GetViper(), "") + return tools.MakeTLSConfig(viper.GetViper(), "", os.ReadFile) } return nil, nil } func tlsConfigForGRPC() (*tls.Config, error) { - return tools.MakeTLSConfig(viper.GetViper(), "grpc_api_") + return tools.MakeTLSConfig(viper.GetViper(), "grpc_api_", os.ReadFile) } func tlsConfigForUniGRPC() (*tls.Config, error) { - return tools.MakeTLSConfig(viper.GetViper(), "uni_grpc_") + return tools.MakeTLSConfig(viper.GetViper(), "uni_grpc_", os.ReadFile) } type httpErrorLogWriter struct { @@ -2435,7 +2435,7 @@ func addRedisShardCommonSettings(shardConf *centrifuge.RedisShardConfig) { shardConf.ClientName = viper.GetString("redis_client_name") if viper.GetBool("redis_tls") { - tlsConfig, err := tools.MakeTLSConfig(viper.GetViper(), "redis_") + tlsConfig, err := tools.MakeTLSConfig(viper.GetViper(), "redis_", os.ReadFile) if err != nil { log.Fatal().Msgf("error creating Redis TLS config: %v", err) } @@ -2498,7 +2498,7 @@ func getRedisShardConfigs() ([]centrifuge.RedisShardConfig, string, error) { } conf.SentinelClientName = viper.GetString("redis_sentinel_client_name") if viper.GetBool("redis_sentinel_tls") { - tlsConfig, err := tools.MakeTLSConfig(viper.GetViper(), "redis_sentinel_") + tlsConfig, err := tools.MakeTLSConfig(viper.GetViper(), "redis_sentinel_", os.ReadFile) if err != nil { log.Fatal().Msgf("error creating Redis Sentinel TLS config: %v", err) }