From 6dc9d5f4177c04661d8a5791f42403e470b8a050 Mon Sep 17 00:00:00 2001 From: Jiaqiang Huang <96465211+River2000i@users.noreply.github.com> Date: Thu, 19 Dec 2024 16:32:28 +0800 Subject: [PATCH] fix test --- dm/config/security/security.go | 3 --- dm/config/security_test.go | 3 --- dm/loader/lightning.go | 19 ++++++++------ dm/loader/lightning_test.go | 46 +++++++++++++++++----------------- 4 files changed, 34 insertions(+), 37 deletions(-) diff --git a/dm/config/security/security.go b/dm/config/security/security.go index a3670902877..4ec521d23eb 100644 --- a/dm/config/security/security.go +++ b/dm/config/security/security.go @@ -85,9 +85,6 @@ func (s *Security) ClearSSLBytesData() { s.SSLCABytes = s.SSLCABytes[:0] s.SSLKeyBytes = s.SSLKeyBytes[:0] s.SSLCertBytes = s.SSLCertBytes[:0] - s.SSLCA = "" - s.SSLCert = "" - s.SSLKey = "" } // Clone returns a deep copy of Security. diff --git a/dm/config/security_test.go b/dm/config/security_test.go index c713229d6c0..40e4c833c9a 100644 --- a/dm/config/security_test.go +++ b/dm/config/security_test.go @@ -106,9 +106,6 @@ func (c *testTLSConfig) TestLoadAndClearContent() { c.Require().Len(s.SSLCABytes, 0) c.Require().Len(s.SSLCertBytes, 0) c.Require().Len(s.SSLKeyBytes, 0) - c.Require().Equal(s.SSLCA, "") - c.Require().Equal(s.SSLCert, "") - c.Require().Equal(s.SSLKey, "") s.SSLCABase64 = "MTIz" err = s.LoadTLSContent() diff --git a/dm/loader/lightning.go b/dm/loader/lightning.go index ab4dd26cd7a..61e5e437c00 100644 --- a/dm/loader/lightning.go +++ b/dm/loader/lightning.go @@ -106,10 +106,10 @@ func NewLightning(cfg *config.SubTaskConfig, cli *clientv3.Client, workerName st // MakeGlobalConfig converts subtask config to lightning global config. func MakeGlobalConfig(cfg *config.SubTaskConfig) *lcfg.GlobalConfig { lightningCfg := lcfg.NewGlobalConfig() - if cfg.LoaderConfig.Security != nil { - lightningCfg.Security.CAPath = cfg.LoaderConfig.Security.SSLCA - lightningCfg.Security.CertPath = cfg.LoaderConfig.Security.SSLCert - lightningCfg.Security.KeyPath = cfg.LoaderConfig.Security.SSLKey + if cfg.To.Security != nil { + lightningCfg.Security.CABytes = cfg.To.Security.SSLCABytes + lightningCfg.Security.CertBytes = cfg.To.Security.SSLCertBytes + lightningCfg.Security.KeyBytes = cfg.To.Security.SSLKeyBytes } lightningCfg.TiDB.Host = cfg.To.Host lightningCfg.TiDB.Psw = cfg.To.Password @@ -330,10 +330,13 @@ func GetLightningConfig(globalCfg *lcfg.GlobalConfig, subtaskCfg *config.SubTask return nil, err } cfg.TiDB.Security = &globalCfg.Security - if subtaskCfg.To.Security != nil { - cfg.TiDB.Security.CAPath = subtaskCfg.To.Security.SSLCA - cfg.TiDB.Security.CertPath = subtaskCfg.To.Security.SSLCert - cfg.TiDB.Security.KeyPath = subtaskCfg.To.Security.SSLKey + if subtaskCfg.LoaderConfig.Security != nil { + cfg.Security.CABytes = cfg.Security.CABytes[:0] + cfg.Security.CertBytes = cfg.Security.CertBytes[:0] + cfg.Security.KeyBytes = cfg.Security.KeyBytes[:0] + cfg.Security.CAPath = subtaskCfg.LoaderConfig.Security.SSLCA + cfg.Security.CertPath = subtaskCfg.LoaderConfig.Security.SSLCert + cfg.Security.KeyPath = subtaskCfg.LoaderConfig.Security.SSLKey } // TableConcurrency is adjusted to the value of RegionConcurrency // when using TiDB backend. diff --git a/dm/loader/lightning_test.go b/dm/loader/lightning_test.go index 85142b2e4e3..83f8e99e8bf 100644 --- a/dm/loader/lightning_test.go +++ b/dm/loader/lightning_test.go @@ -127,21 +127,34 @@ func TestGetLightiningConfig(t *testing.T) { conf, err = GetLightningConfig( &lcfg.GlobalConfig{Security: lcfg.Security{CAPath: caPath, CertPath: certPath, KeyPath: keyPath}}, &config.SubTaskConfig{ - LoaderConfig: config.LoaderConfig{Security: &security.Security{SSLCA: caPath, SSLCert: certPath, SSLKey: keyPath}}, - To: dbconfig.DBConfig{Security: &security.Security{SSLCA: caPath2, SSLCert: certPath2, SSLKey: keyPath2}}, + To: dbconfig.DBConfig{Security: &security.Security{SSLCA: caPath, SSLCert: certPath, SSLKey: keyPath}}, + LoaderConfig: config.LoaderConfig{Security: &security.Security{SSLCA: caPath2, SSLCert: certPath2, SSLKey: keyPath2}}, }) require.NoError(t, err) - require.Equal(t, conf.Security.CAPath, caPath) - require.Equal(t, conf.Security.CertPath, certPath) - require.Equal(t, conf.Security.KeyPath, keyPath) - require.Equal(t, conf.TiDB.Security.CAPath, caPath2) - require.Equal(t, conf.TiDB.Security.CertPath, certPath2) - require.Equal(t, conf.TiDB.Security.KeyPath, keyPath2) + require.Equal(t, conf.Security.CAPath, caPath2) + require.Equal(t, conf.Security.CertPath, certPath2) + require.Equal(t, conf.Security.KeyPath, keyPath2) + require.Equal(t, conf.TiDB.Security.CAPath, caPath) + require.Equal(t, conf.TiDB.Security.CertPath, certPath) + require.Equal(t, conf.TiDB.Security.KeyPath, keyPath) conf, err = GetLightningConfig( - &lcfg.GlobalConfig{Security: lcfg.Security{CAPath: caPath, CertPath: certPath, KeyPath: keyPath}}, + &lcfg.GlobalConfig{}, &config.SubTaskConfig{ - LoaderConfig: config.LoaderConfig{Security: &security.Security{SSLCA: caPath, SSLCert: certPath, SSLKey: keyPath}}, To: dbconfig.DBConfig{}, + LoaderConfig: config.LoaderConfig{Security: &security.Security{SSLCA: caPath2, SSLCert: certPath2, SSLKey: keyPath2}}, + }) + require.NoError(t, err) + require.Equal(t, conf.Security.CAPath, caPath2) + require.Equal(t, conf.Security.CertPath, certPath2) + require.Equal(t, conf.Security.KeyPath, keyPath2) + require.Equal(t, conf.TiDB.Security.CAPath, "") + require.Equal(t, conf.TiDB.Security.CertPath, "") + require.Equal(t, conf.TiDB.Security.KeyPath, "") + conf, err = GetLightningConfig( + &lcfg.GlobalConfig{Security: lcfg.Security{CAPath: caPath, CertPath: certPath, KeyPath: keyPath}}, + &config.SubTaskConfig{ + To: dbconfig.DBConfig{Security: &security.Security{SSLCA: caPath, SSLCert: certPath, SSLKey: keyPath}}, + LoaderConfig: config.LoaderConfig{}, }) require.NoError(t, err) require.Equal(t, conf.Security.CAPath, caPath) @@ -150,19 +163,6 @@ func TestGetLightiningConfig(t *testing.T) { require.Equal(t, conf.TiDB.Security.CAPath, caPath) require.Equal(t, conf.TiDB.Security.CertPath, certPath) require.Equal(t, conf.TiDB.Security.KeyPath, keyPath) - conf, err = GetLightningConfig( - &lcfg.GlobalConfig{}, - &config.SubTaskConfig{ - LoaderConfig: config.LoaderConfig{}, - To: dbconfig.DBConfig{Security: &security.Security{SSLCA: caPath2, SSLCert: certPath2, SSLKey: keyPath2}}, - }) - require.NoError(t, err) - require.Equal(t, conf.Security.CAPath, "") - require.Equal(t, conf.Security.CertPath, "") - require.Equal(t, conf.Security.KeyPath, "") - require.Equal(t, conf.TiDB.Security.CAPath, caPath2) - require.Equal(t, conf.TiDB.Security.CertPath, certPath2) - require.Equal(t, conf.TiDB.Security.KeyPath, keyPath2) // invalid security file path _, err = GetLightningConfig( &lcfg.GlobalConfig{Security: lcfg.Security{CAPath: "caPath"}},