diff --git a/cmd/keymasterd/2fa_totp_test.go b/cmd/keymasterd/2fa_totp_test.go index 025b8c3..5541db7 100644 --- a/cmd/keymasterd/2fa_totp_test.go +++ b/cmd/keymasterd/2fa_totp_test.go @@ -127,6 +127,7 @@ func TestGenerateNewTOTPSuccess(t *testing.T) { if valid { t.Fatal("should NOT have been valid") } + state.dbDone <- struct{}{} } @@ -213,6 +214,7 @@ func TestVerifyTOTPHandlerSuccess(t *testing.T) { if err != nil { t.Fatal(err) } + state.dbDone <- struct{}{} } func TestAuthTOTPHandlerSuccess(t *testing.T) { @@ -259,6 +261,7 @@ func TestAuthTOTPHandlerSuccess(t *testing.T) { if err != nil { t.Fatal(err) } + state.dbDone <- struct{}{} } func TestTOTPTokenManagerHandlerUpdateSuccess(t *testing.T) { @@ -327,4 +330,5 @@ func TestTOTPTokenManagerHandlerUpdateSuccess(t *testing.T) { if profile.TOTPAuthData[0].Name != newName { t.Fatal("update not successul") } + state.dbDone <- struct{}{} } diff --git a/cmd/keymasterd/2fa_webauthn_test.go b/cmd/keymasterd/2fa_webauthn_test.go index 8db7dff..69e04e6 100644 --- a/cmd/keymasterd/2fa_webauthn_test.go +++ b/cmd/keymasterd/2fa_webauthn_test.go @@ -88,10 +88,10 @@ func TestWebAuthnRegistrationBegin(t *testing.T) { */ /* - Example post for finalization: - { - "{\"id\":\"_N2M7t9Qe2rwS4asNZ15I4Thd-nkXow6_lyDT6CURM3gD1sAq0FyMnf8NDOARMWMjjNgPfeHpPWP0Q8nkx-v7pNRuR0IwRHkvZeZxaV3Ql3HFigByVOhuB3OCq2em8Ve\",\"rawId\":\"_N2M7t9Qe2rwS4asNZ15I4Thd-nkXow6_lyDT6CURM3gD1sAq0FyMnf8NDOARMWMjjNgPfeHpPWP0Q8nkx-v7pNRuR0IwRHkvZeZxaV3Ql3HFigByVOhuB3OCq2em8Ve\",\"type\":\"public-key\",\"response\":{\"attestationObject\":\"o2NmbXRkbm9uZWdhdHRTdG10oGhhdXRoRGF0YVjkSZYN5YgOjGh0NBcPZHZgW4_krrmihjLHmVzzuoMdl2NBAAADlwAAAAAAAAAAAAAAAAAAAAAAYPzdjO7fUHtq8EuGrDWdeSOE4Xfp5F6MOv5cg0-glETN4A9bAKtBcjJ3_DQzgETFjI4zYD33h6T1j9EPJ5Mfr-6TUbkdCMER5L2XmcWld0JdxxYoAclTobgdzgqtnpvFXqUBAgMmIAEhWCBwm_S46LuncSKubWLGS7236xBQyY-Ptg0dTKpOmddRMCJYIG02ZJischNpyUqMXRdiJfBW2kDmG3TROzKzHHBHmLlp\",\"clientDataJSON\":\"eyJjaGFsbGVuZ2UiOiJlTW1Ca0gxQ05KZzFsbGRQb3ZXQUN6R0pMZUpYRHZndmViUXIycDRxdWNVIiwib3JpZ2luIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6MzM0NDMiLCJ0eXBlIjoid2ViYXV0aG4uY3JlYXRlIn0\"}}": "" - } + Example post for finalization: + { + "{\"id\":\"_N2M7t9Qe2rwS4asNZ15I4Thd-nkXow6_lyDT6CURM3gD1sAq0FyMnf8NDOARMWMjjNgPfeHpPWP0Q8nkx-v7pNRuR0IwRHkvZeZxaV3Ql3HFigByVOhuB3OCq2em8Ve\",\"rawId\":\"_N2M7t9Qe2rwS4asNZ15I4Thd-nkXow6_lyDT6CURM3gD1sAq0FyMnf8NDOARMWMjjNgPfeHpPWP0Q8nkx-v7pNRuR0IwRHkvZeZxaV3Ql3HFigByVOhuB3OCq2em8Ve\",\"type\":\"public-key\",\"response\":{\"attestationObject\":\"o2NmbXRkbm9uZWdhdHRTdG10oGhhdXRoRGF0YVjkSZYN5YgOjGh0NBcPZHZgW4_krrmihjLHmVzzuoMdl2NBAAADlwAAAAAAAAAAAAAAAAAAAAAAYPzdjO7fUHtq8EuGrDWdeSOE4Xfp5F6MOv5cg0-glETN4A9bAKtBcjJ3_DQzgETFjI4zYD33h6T1j9EPJ5Mfr-6TUbkdCMER5L2XmcWld0JdxxYoAclTobgdzgqtnpvFXqUBAgMmIAEhWCBwm_S46LuncSKubWLGS7236xBQyY-Ptg0dTKpOmddRMCJYIG02ZJischNpyUqMXRdiJfBW2kDmG3TROzKzHHBHmLlp\",\"clientDataJSON\":\"eyJjaGFsbGVuZ2UiOiJlTW1Ca0gxQ05KZzFsbGRQb3ZXQUN6R0pMZUpYRHZndmViUXIycDRxdWNVIiwib3JpZ2luIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6MzM0NDMiLCJ0eXBlIjoid2ViYXV0aG4uY3JlYXRlIn0\"}}": "" + } */ - + state.dbDone <- struct{}{} } diff --git a/cmd/keymasterd/app.go b/cmd/keymasterd/app.go index 8d883c9..c6c9354 100644 --- a/cmd/keymasterd/app.go +++ b/cmd/keymasterd/app.go @@ -209,6 +209,7 @@ type RuntimeState struct { db *sql.DB dbType string cacheDB *sql.DB + dbDone chan struct{} remoteDBQueryTimeout time.Duration htmlTemplate *htmltemplate.Template passwordChecker pwauth.PasswordAuthenticator @@ -1944,7 +1945,7 @@ func main() { ClientAuth: tls.VerifyClientCertIfGiven, GetCertificate: runtimeState.certManager.GetCertificate, MinVersion: tls.VersionTLS12, - CurvePreferences: []tls.CurveID{tls.CurveP521, tls.CurveP384, tls.CurveP256}, + CurvePreferences: []tls.CurveID{tls.CurveP521, tls.CurveP384, tls.CurveP256, tls.X25519}, PreferServerCipherSuites: true, CipherSuites: []uint16{ tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, @@ -2013,7 +2014,7 @@ func main() { ClientAuth: tls.VerifyClientCertIfGiven, GetCertificate: runtimeState.certManager.GetCertificate, MinVersion: tls.VersionTLS12, - CurvePreferences: []tls.CurveID{tls.CurveP521, tls.CurveP384, tls.CurveP256}, + CurvePreferences: []tls.CurveID{tls.CurveP521, tls.CurveP384, tls.CurveP256, tls.X25519}, PreferServerCipherSuites: true, CipherSuites: []uint16{ tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, diff --git a/cmd/keymasterd/main_test.go b/cmd/keymasterd/main_test.go index 44652bd..11a8086 100644 --- a/cmd/keymasterd/main_test.go +++ b/cmd/keymasterd/main_test.go @@ -622,6 +622,7 @@ func TestLoginAPIBasicAuth(t *testing.T) { t.Fatal(err) } } + state.dbDone <- struct{}{} } func TestLoginAPIFormAuth(t *testing.T) { @@ -707,6 +708,7 @@ func TestLoginAPIFormAuth(t *testing.T) { t.Fatal(err) } } + state.dbDone <- struct{}{} } func TestProfileHandlerTemplate(t *testing.T) { @@ -752,6 +754,8 @@ func TestProfileHandlerTemplate(t *testing.T) { t.Fatal(err) } //TODO: verify HTML output + + state.dbDone <- struct{}{} } func TestU2fTokenManagerHandlerUpdateSuccess(t *testing.T) { @@ -821,10 +825,12 @@ func TestU2fTokenManagerHandlerUpdateSuccess(t *testing.T) { if profile.U2fAuthData[0].Name != newName { t.Fatal("update not successul") } + + state.dbDone <- struct{}{} } func TestU2fTokenManagerHandlerDeleteNotAdmin(t *testing.T) { - var state RuntimeState + state := RuntimeState{logger: testlogger.New(t)} //load signer signer, err := getSignerFromPEMBytes([]byte(testSignerPrivateKey)) if err != nil { @@ -888,6 +894,7 @@ func TestU2fTokenManagerHandlerDeleteNotAdmin(t *testing.T) { if len(profile.U2fAuthData) != 2 { t.Fatal("delete should not have succeeded") } + state.dbDone <- struct{}{} } func TestU2fTokenManagerHandlerDeleteSuccess(t *testing.T) { @@ -955,4 +962,5 @@ func TestU2fTokenManagerHandlerDeleteSuccess(t *testing.T) { if len(profile.U2fAuthData) != 1 { t.Fatal("update not successul") } + state.dbDone <- struct{}{} } diff --git a/cmd/keymasterd/storage.go b/cmd/keymasterd/storage.go index 191a488..69dc2c0 100644 --- a/cmd/keymasterd/storage.go +++ b/cmd/keymasterd/storage.go @@ -13,6 +13,7 @@ import ( "github.com/Cloud-Foundations/golib/pkg/awsutil/metadata" "github.com/Cloud-Foundations/golib/pkg/awsutil/secretsmgr" + "github.com/Cloud-Foundations/golib/pkg/log" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" ) @@ -82,38 +83,39 @@ func (state *RuntimeState) expandStorageUrl() error { } func initDB(state *RuntimeState) (err error) { - logger.Debugf(3, "Top of initDB") + state.logger.Debugf(3, "Top of initDB") state.Config.ProfileStorage.setSyncLimits() //open/create cache DB first cacheDBFilename := filepath.Join(state.Config.Base.DataDirectory, cachedDBFilename) state.cacheDB, err = initFileDBSQLite(cacheDBFilename, state.cacheDB) if err != nil { - logger.Printf("Failure on creation of cacheDB") + state.logger.Printf("Failure on creation of cacheDB") return err } - logger.Debugf(3, "storage=%s", state.Config.ProfileStorage.StorageUrl) + state.logger.Debugf(3, "storage=%s", state.Config.ProfileStorage.StorageUrl) storageURL := state.Config.ProfileStorage.StorageUrl if storageURL == "" { storageURL = "sqlite:" } splitString := strings.SplitN(storageURL, ":", 2) if len(splitString) < 1 { - logger.Printf("invalid string") + state.logger.Printf("invalid string") err := errors.New("Bad storage url string") return err } state.remoteDBQueryTimeout = time.Second * 2 - go state.BackgroundDBCopy(state.Config.ProfileStorage.SyncDelay) + state.dbDone = make(chan struct{}) + go state.BackgroundDBCopy(state.Config.ProfileStorage.SyncDelay, state.dbDone, state.logger) switch splitString[0] { case "sqlite": - logger.Printf("doing sqlite") + state.logger.Printf("doing sqlite") return initDBSQlite(state) case "postgresql": - logger.Printf("doing postgres") + state.logger.Printf("doing postgres") return initDBPostgres(state) default: - logger.Printf("invalid storage url string") + state.logger.Printf("invalid storage url string") err := errors.New("Bad storage url string") return err } @@ -130,13 +132,13 @@ func initDBPostgres(state *RuntimeState) (err error) { sqlStmt := `create table if not exists user_profile (id serial not null primary key, username text unique, profile_data bytea);` _, err = state.db.Exec(sqlStmt) if err != nil { - logger.Printf("init postgres err: %s: %q\n", err, sqlStmt) + state.logger.Printf("init postgres err: %s: %q\n", err, sqlStmt) return err } sqlStmt = `create table if not exists expiring_signed_user_data(id serial not null primary key, username text not null, jws_data text not null, type integer not null, expiration_epoch integer not null, update_epoch integer not null, UNIQUE(username,type));` _, err = state.db.Exec(sqlStmt) if err != nil { - logger.Printf("init postgres err: %s: %q\n", err, sqlStmt) + state.logger.Printf("init postgres err: %s: %q\n", err, sqlStmt) return err } } @@ -207,8 +209,16 @@ func initFileDBSQLite(dbFilename string, currentDB *sql.DB) (*sql.DB, error) { return currentDB, nil } -func (state *RuntimeState) BackgroundDBCopy(initialSleep time.Duration) { - time.Sleep(initialSleep) +func (state *RuntimeState) BackgroundDBCopy(initialSleep time.Duration, done chan struct{}, logger log.DebugLogger) { + select { + case <-done: + //fmt.Println("Received:", v) + logger.Debugf(0, "Cancelled before first copy") + return + case <-time.After(initialSleep): + logger.Debugf(1, "BackgroundDBCopy, initial sleep done") + + } for { logger.Debugf(0, "starting db copy") err := copyDBIntoSQLite(state.db, state.cacheDB, "sqlite") @@ -219,7 +229,17 @@ func (state *RuntimeState) BackgroundDBCopy(initialSleep time.Duration) { } cleanupDBData(state.db) cleanupDBData(state.cacheDB) - time.Sleep(state.Config.ProfileStorage.SyncInterval) + + select { + case <-done: + //fmt.Println("Received:", v) + logger.Debugf(0, "Cancelled after copy") + return + case <-time.After(state.Config.ProfileStorage.SyncInterval): + logger.Debugf(1, "BackgroundDBCopy, sleep complete sleep done") + + } + //time.Sleep(state.Config.ProfileStorage.SyncInterval) } } diff --git a/cmd/keymasterd/storage_test.go b/cmd/keymasterd/storage_test.go index 7b16b54..185775b 100644 --- a/cmd/keymasterd/storage_test.go +++ b/cmd/keymasterd/storage_test.go @@ -61,6 +61,7 @@ func TestDBCopy(t *testing.T) { if err != nil { t.Fatal(err) } + state.dbDone <- struct{}{} } func TestFetchFromCache(t *testing.T) { @@ -106,4 +107,5 @@ func TestFetchFromCache(t *testing.T) { if ok { t.Fatal("This should have failed for invalid user") } + state.dbDone <- struct{}{} }