Skip to content

Commit

Permalink
RSDK-7921 - Cleanup processConfig in entrypoint (viamrobotics#4450)
Browse files Browse the repository at this point in the history
  • Loading branch information
cheukt authored Oct 15, 2024
1 parent 5cd29ab commit 00eba29
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 82 deletions.
67 changes: 22 additions & 45 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"reflect"
"slices"
"strings"
"sync"
"time"

"github.com/pkg/errors"
Expand Down Expand Up @@ -950,63 +949,41 @@ func (config *AuthHandlerConfig) Validate(path string) error {
return nil
}

// TLSConfig stores the TLS config for the robot.
type TLSConfig struct {
*tls.Config
certMu sync.Mutex
tlsCert *tls.Certificate
}

// NewTLSConfig creates a new tls config.
func NewTLSConfig(cfg *Config) *TLSConfig {
tlsCfg := &TLSConfig{}
var tlsConfig *tls.Config
if cfg.Cloud != nil && cfg.Cloud.TLSCertificate != "" {
tlsConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
GetCertificate: func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
// always return same cert
tlsCfg.certMu.Lock()
defer tlsCfg.certMu.Unlock()
return tlsCfg.tlsCert, nil
},
GetClientCertificate: func(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
// always return same cert
tlsCfg.certMu.Lock()
defer tlsCfg.certMu.Unlock()
return tlsCfg.tlsCert, nil
},
}
}
tlsCfg.Config = tlsConfig
return tlsCfg
}

// UpdateCert updates the TLS certificate to be returned.
func (t *TLSConfig) UpdateCert(cfg *Config) error {
// CreateTLSWithCert creates a tls.Config with the TLS certificate to be returned.
func CreateTLSWithCert(cfg *Config) (*tls.Config, error) {
cert, err := tls.X509KeyPair([]byte(cfg.Cloud.TLSCertificate), []byte(cfg.Cloud.TLSPrivateKey))
if err != nil {
return err
}
t.certMu.Lock()
t.tlsCert = &cert
t.certMu.Unlock()
return nil
return nil, err
}
return &tls.Config{
MinVersion: tls.VersionTLS12,
GetCertificate: func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
// always return same cert
return &cert, nil
},
GetClientCertificate: func(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
// always return same cert
return &cert, nil
},
}, nil
}

// ProcessConfig processes robot configs.
func ProcessConfig(in *Config, tlsCfg *TLSConfig) (*Config, error) {
func ProcessConfig(in *Config) (*Config, error) {
out := *in
var selfCreds *rpc.Credentials
if in.Cloud != nil {
// We expect a cloud config from app to always contain a non-empty `TLSCertificate` field.
// We do this empty string check just to cope with unexpected input, such as cached configs
// that are hand altered to have their `TLSCertificate` removed.
if in.Cloud.TLSCertificate != "" {
if err := tlsCfg.UpdateCert(in); err != nil {
tlsConfig, err := CreateTLSWithCert(in)
if err != nil {
return nil, err
}
out.Network.TLSConfig = tlsConfig
}

selfCreds = &rpc.Credentials{rutils.CredentialsTypeRobotSecret, in.Cloud.Secret}
out.Network.TLSConfig = tlsCfg.Config // override
}

out.Remotes = make([]Remote, len(in.Remotes))
Expand Down
53 changes: 20 additions & 33 deletions config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -727,29 +727,8 @@ func TestCopyOnlyPublicFields(t *testing.T) {
})
}

func TestNewTLSConfig(t *testing.T) {
for _, tc := range []struct {
TestName string
Config *config.Config
HasTLSConfig bool
}{
{TestName: "no cloud", Config: &config.Config{}, HasTLSConfig: false},
{TestName: "cloud but no cert", Config: &config.Config{Cloud: &config.Cloud{TLSCertificate: ""}}, HasTLSConfig: false},
{TestName: "cloud and cert", Config: &config.Config{Cloud: &config.Cloud{TLSCertificate: "abc"}}, HasTLSConfig: true},
} {
t.Run(tc.TestName, func(t *testing.T) {
observed := config.NewTLSConfig(tc.Config)
if tc.HasTLSConfig {
test.That(t, observed.MinVersion, test.ShouldEqual, tls.VersionTLS12)
} else {
test.That(t, observed, test.ShouldResemble, &config.TLSConfig{})
}
})
}
}

func TestUpdateCert(t *testing.T) {
t.Run("cert update", func(t *testing.T) {
func TestCreateTLSWithCert(t *testing.T) {
t.Run("create TLS cert", func(t *testing.T) {
cfg := &config.Config{
Cloud: &config.Cloud{
TLSCertificate: `-----BEGIN CERTIFICATE-----
Expand All @@ -775,8 +754,7 @@ ph2C/7IgjA==
cert, err := tls.X509KeyPair([]byte(cfg.Cloud.TLSCertificate), []byte(cfg.Cloud.TLSPrivateKey))
test.That(t, err, test.ShouldBeNil)

tlsCfg := config.NewTLSConfig(cfg)
err = tlsCfg.UpdateCert(cfg)
tlsCfg, err := config.CreateTLSWithCert(cfg)
test.That(t, err, test.ShouldBeNil)

observed, err := tlsCfg.GetCertificate(&tls.ClientHelloInfo{})
Expand All @@ -785,8 +763,7 @@ ph2C/7IgjA==
})
t.Run("cert error", func(t *testing.T) {
cfg := &config.Config{Cloud: &config.Cloud{TLSCertificate: "abcd", TLSPrivateKey: "abcd"}}
tlsCfg := &config.TLSConfig{}
err := tlsCfg.UpdateCert(cfg)
_, err := config.CreateTLSWithCert(cfg)
test.That(t, err, test.ShouldBeError, errors.New("tls: failed to find any PEM data in certificate input"))
})
}
Expand Down Expand Up @@ -862,15 +839,14 @@ ph2C/7IgjA==
expectedRemoteDiffManagerNoCloud := remoteDiffManager
expectedRemoteDiffManagerNoCloud.Auth = expectedRemoteAuthNoCloud

tlsCfg := &config.TLSConfig{}
err := tlsCfg.UpdateCert(cloudWTLSCfg)
tlsCfg, err := config.CreateTLSWithCert(cloudWTLSCfg)
test.That(t, err, test.ShouldBeNil)

expectedCloudWTLSCfg := &config.Config{Cloud: cloudWTLS, Remotes: []config.Remote{}}
expectedCloudWTLSCfg.Network.TLSConfig = tlsCfg.Config
expectedCloudWTLSCfg.Network.TLSConfig = tlsCfg

expectedRemotesCloudWTLSCfg := &config.Config{Cloud: cloudWTLS, Remotes: []config.Remote{expectedRemoteCloud, remoteDiffManager}}
expectedRemotesCloudWTLSCfg.Network.TLSConfig = tlsCfg.Config
expectedRemotesCloudWTLSCfg.Network.TLSConfig = tlsCfg

for _, tc := range []struct {
TestName string
Expand All @@ -893,15 +869,26 @@ ph2C/7IgjA==
{TestName: "remotes cloud and cert", Config: remotesCloudWTLSCfg, Expected: expectedRemotesCloudWTLSCfg},
} {
t.Run(tc.TestName, func(t *testing.T) {
observed, err := config.ProcessConfig(tc.Config, &config.TLSConfig{})
observed, err := config.ProcessConfig(tc.Config)
test.That(t, err, test.ShouldBeNil)
// TLSConfig holds funcs, which do not resemble each other so check separately and nil them out after.
if tc.Expected.Network.TLSConfig != nil {
obsCert, err := observed.Network.TLSConfig.GetCertificate(nil)
test.That(t, err, test.ShouldBeNil)
expCert, err := tc.Expected.Network.TLSConfig.GetCertificate(nil)
test.That(t, err, test.ShouldBeNil)

test.That(t, obsCert, test.ShouldResemble, expCert)
tc.Expected.Network.TLSConfig = nil
observed.Network.TLSConfig = nil
}
test.That(t, observed, test.ShouldResemble, tc.Expected)
})
}

t.Run("cert error", func(t *testing.T) {
cfg := &config.Config{Cloud: &config.Cloud{TLSCertificate: "abcd", TLSPrivateKey: "abcd"}}
_, err := config.ProcessConfig(cfg, &config.TLSConfig{})
_, err := config.ProcessConfig(cfg)
test.That(t, err, test.ShouldBeError, errors.New("tls: failed to find any PEM data in certificate input"))
})
}
Expand Down
3 changes: 1 addition & 2 deletions robot/impl/local_robot.go
Original file line number Diff line number Diff line change
Expand Up @@ -1050,8 +1050,7 @@ func RobotFromConfigPath(ctx context.Context, cfgPath string, logger logging.Log

// RobotFromConfig is a helper to process a config and then create a robot based on it.
func RobotFromConfig(ctx context.Context, cfg *config.Config, logger logging.Logger, opts ...Option) (robot.LocalRobot, error) {
tlsConfig := config.NewTLSConfig(cfg)
processedCfg, err := config.ProcessConfig(cfg, tlsConfig)
processedCfg, err := config.ProcessConfig(cfg)
if err != nil {
return nil, err
}
Expand Down
3 changes: 1 addition & 2 deletions web/server/entrypoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,8 +303,7 @@ func (s *robotServer) serveWeb(ctx context.Context, cfg *config.Config) (err err
ctx = rpc.ContextWithDialer(ctx, rpcDialer)

processConfig := func(in *config.Config) (*config.Config, error) {
tlsCfg := config.NewTLSConfig(cfg)
out, err := config.ProcessConfig(in, tlsCfg)
out, err := config.ProcessConfig(in)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 00eba29

Please sign in to comment.