diff --git a/config/config.go b/config/config.go index 32482577c74..5ac34ad9d27 100644 --- a/config/config.go +++ b/config/config.go @@ -12,7 +12,6 @@ import ( "reflect" "slices" "strings" - "sync" "time" "github.com/pkg/errors" @@ -950,63 +949,41 @@ func (config *AuthHandlerConfig) Validate(path string) error { return nil } -// TLSConfig stores the TLS config for the robot. -type TLSConfig struct { - *tls.Config - certMu sync.Mutex - tlsCert *tls.Certificate -} - -// NewTLSConfig creates a new tls config. -func NewTLSConfig(cfg *Config) *TLSConfig { - tlsCfg := &TLSConfig{} - var tlsConfig *tls.Config - if cfg.Cloud != nil && cfg.Cloud.TLSCertificate != "" { - tlsConfig = &tls.Config{ - MinVersion: tls.VersionTLS12, - GetCertificate: func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { - // always return same cert - tlsCfg.certMu.Lock() - defer tlsCfg.certMu.Unlock() - return tlsCfg.tlsCert, nil - }, - GetClientCertificate: func(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) { - // always return same cert - tlsCfg.certMu.Lock() - defer tlsCfg.certMu.Unlock() - return tlsCfg.tlsCert, nil - }, - } - } - tlsCfg.Config = tlsConfig - return tlsCfg -} - -// UpdateCert updates the TLS certificate to be returned. -func (t *TLSConfig) UpdateCert(cfg *Config) error { +// CreateTLSWithCert creates a tls.Config with the TLS certificate to be returned. +func CreateTLSWithCert(cfg *Config) (*tls.Config, error) { cert, err := tls.X509KeyPair([]byte(cfg.Cloud.TLSCertificate), []byte(cfg.Cloud.TLSPrivateKey)) if err != nil { - return err - } - t.certMu.Lock() - t.tlsCert = &cert - t.certMu.Unlock() - return nil + return nil, err + } + return &tls.Config{ + MinVersion: tls.VersionTLS12, + GetCertificate: func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { + // always return same cert + return &cert, nil + }, + GetClientCertificate: func(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) { + // always return same cert + return &cert, nil + }, + }, nil } // ProcessConfig processes robot configs. -func ProcessConfig(in *Config, tlsCfg *TLSConfig) (*Config, error) { +func ProcessConfig(in *Config) (*Config, error) { out := *in var selfCreds *rpc.Credentials if in.Cloud != nil { + // We expect a cloud config from app to always contain a non-empty `TLSCertificate` field. + // We do this empty string check just to cope with unexpected input, such as cached configs + // that are hand altered to have their `TLSCertificate` removed. if in.Cloud.TLSCertificate != "" { - if err := tlsCfg.UpdateCert(in); err != nil { + tlsConfig, err := CreateTLSWithCert(in) + if err != nil { return nil, err } + out.Network.TLSConfig = tlsConfig } - selfCreds = &rpc.Credentials{rutils.CredentialsTypeRobotSecret, in.Cloud.Secret} - out.Network.TLSConfig = tlsCfg.Config // override } out.Remotes = make([]Remote, len(in.Remotes)) diff --git a/config/config_test.go b/config/config_test.go index 9fa44c319f9..7853953bb7b 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -727,29 +727,8 @@ func TestCopyOnlyPublicFields(t *testing.T) { }) } -func TestNewTLSConfig(t *testing.T) { - for _, tc := range []struct { - TestName string - Config *config.Config - HasTLSConfig bool - }{ - {TestName: "no cloud", Config: &config.Config{}, HasTLSConfig: false}, - {TestName: "cloud but no cert", Config: &config.Config{Cloud: &config.Cloud{TLSCertificate: ""}}, HasTLSConfig: false}, - {TestName: "cloud and cert", Config: &config.Config{Cloud: &config.Cloud{TLSCertificate: "abc"}}, HasTLSConfig: true}, - } { - t.Run(tc.TestName, func(t *testing.T) { - observed := config.NewTLSConfig(tc.Config) - if tc.HasTLSConfig { - test.That(t, observed.MinVersion, test.ShouldEqual, tls.VersionTLS12) - } else { - test.That(t, observed, test.ShouldResemble, &config.TLSConfig{}) - } - }) - } -} - -func TestUpdateCert(t *testing.T) { - t.Run("cert update", func(t *testing.T) { +func TestCreateTLSWithCert(t *testing.T) { + t.Run("create TLS cert", func(t *testing.T) { cfg := &config.Config{ Cloud: &config.Cloud{ TLSCertificate: `-----BEGIN CERTIFICATE----- @@ -775,8 +754,7 @@ ph2C/7IgjA== cert, err := tls.X509KeyPair([]byte(cfg.Cloud.TLSCertificate), []byte(cfg.Cloud.TLSPrivateKey)) test.That(t, err, test.ShouldBeNil) - tlsCfg := config.NewTLSConfig(cfg) - err = tlsCfg.UpdateCert(cfg) + tlsCfg, err := config.CreateTLSWithCert(cfg) test.That(t, err, test.ShouldBeNil) observed, err := tlsCfg.GetCertificate(&tls.ClientHelloInfo{}) @@ -785,8 +763,7 @@ ph2C/7IgjA== }) t.Run("cert error", func(t *testing.T) { cfg := &config.Config{Cloud: &config.Cloud{TLSCertificate: "abcd", TLSPrivateKey: "abcd"}} - tlsCfg := &config.TLSConfig{} - err := tlsCfg.UpdateCert(cfg) + _, err := config.CreateTLSWithCert(cfg) test.That(t, err, test.ShouldBeError, errors.New("tls: failed to find any PEM data in certificate input")) }) } @@ -862,15 +839,14 @@ ph2C/7IgjA== expectedRemoteDiffManagerNoCloud := remoteDiffManager expectedRemoteDiffManagerNoCloud.Auth = expectedRemoteAuthNoCloud - tlsCfg := &config.TLSConfig{} - err := tlsCfg.UpdateCert(cloudWTLSCfg) + tlsCfg, err := config.CreateTLSWithCert(cloudWTLSCfg) test.That(t, err, test.ShouldBeNil) expectedCloudWTLSCfg := &config.Config{Cloud: cloudWTLS, Remotes: []config.Remote{}} - expectedCloudWTLSCfg.Network.TLSConfig = tlsCfg.Config + expectedCloudWTLSCfg.Network.TLSConfig = tlsCfg expectedRemotesCloudWTLSCfg := &config.Config{Cloud: cloudWTLS, Remotes: []config.Remote{expectedRemoteCloud, remoteDiffManager}} - expectedRemotesCloudWTLSCfg.Network.TLSConfig = tlsCfg.Config + expectedRemotesCloudWTLSCfg.Network.TLSConfig = tlsCfg for _, tc := range []struct { TestName string @@ -893,15 +869,26 @@ ph2C/7IgjA== {TestName: "remotes cloud and cert", Config: remotesCloudWTLSCfg, Expected: expectedRemotesCloudWTLSCfg}, } { t.Run(tc.TestName, func(t *testing.T) { - observed, err := config.ProcessConfig(tc.Config, &config.TLSConfig{}) + observed, err := config.ProcessConfig(tc.Config) test.That(t, err, test.ShouldBeNil) + // TLSConfig holds funcs, which do not resemble each other so check separately and nil them out after. + if tc.Expected.Network.TLSConfig != nil { + obsCert, err := observed.Network.TLSConfig.GetCertificate(nil) + test.That(t, err, test.ShouldBeNil) + expCert, err := tc.Expected.Network.TLSConfig.GetCertificate(nil) + test.That(t, err, test.ShouldBeNil) + + test.That(t, obsCert, test.ShouldResemble, expCert) + tc.Expected.Network.TLSConfig = nil + observed.Network.TLSConfig = nil + } test.That(t, observed, test.ShouldResemble, tc.Expected) }) } t.Run("cert error", func(t *testing.T) { cfg := &config.Config{Cloud: &config.Cloud{TLSCertificate: "abcd", TLSPrivateKey: "abcd"}} - _, err := config.ProcessConfig(cfg, &config.TLSConfig{}) + _, err := config.ProcessConfig(cfg) test.That(t, err, test.ShouldBeError, errors.New("tls: failed to find any PEM data in certificate input")) }) } diff --git a/robot/impl/local_robot.go b/robot/impl/local_robot.go index 44371447bb6..3edde0c8c67 100644 --- a/robot/impl/local_robot.go +++ b/robot/impl/local_robot.go @@ -1050,8 +1050,7 @@ func RobotFromConfigPath(ctx context.Context, cfgPath string, logger logging.Log // RobotFromConfig is a helper to process a config and then create a robot based on it. func RobotFromConfig(ctx context.Context, cfg *config.Config, logger logging.Logger, opts ...Option) (robot.LocalRobot, error) { - tlsConfig := config.NewTLSConfig(cfg) - processedCfg, err := config.ProcessConfig(cfg, tlsConfig) + processedCfg, err := config.ProcessConfig(cfg) if err != nil { return nil, err } diff --git a/web/server/entrypoint.go b/web/server/entrypoint.go index 0edb638c357..90a6e79e2f2 100644 --- a/web/server/entrypoint.go +++ b/web/server/entrypoint.go @@ -303,8 +303,7 @@ func (s *robotServer) serveWeb(ctx context.Context, cfg *config.Config) (err err ctx = rpc.ContextWithDialer(ctx, rpcDialer) processConfig := func(in *config.Config) (*config.Config, error) { - tlsCfg := config.NewTLSConfig(cfg) - out, err := config.ProcessConfig(in, tlsCfg) + out, err := config.ProcessConfig(in) if err != nil { return nil, err }