Skip to content

Commit

Permalink
feat: add support for TOTP multi-factor authentication (#220)
Browse files Browse the repository at this point in the history
Co-authored-by: Tanner Henhawke <[email protected]>
Co-authored-by: Jeffrey Lo <[email protected]>
Co-authored-by: Lance Ivy <[email protected]>
  • Loading branch information
4 people authored Nov 28, 2023
1 parent 436a2b9 commit bbc6311
Show file tree
Hide file tree
Showing 55 changed files with 1,648 additions and 171 deletions.
8 changes: 7 additions & 1 deletion app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ type App struct {
AccountStore data.AccountStore
RefreshTokenStore data.RefreshTokenStore
KeyStore data.KeyStore
TOTPCache data.TOTPCache
Actives data.Actives
Reporter ops.ErrorReporter
OauthProviders map[string]oauth.Provider
Expand Down Expand Up @@ -69,10 +70,12 @@ func NewApp(cfg *Config, logger logrus.FieldLogger) (*App, error) {
return nil, errors.Wrap(err, "NewBlobStore")
}

encryptedBlobStore := data.NewEncryptedBlobStore(blobStore, cfg.DBEncryptionKey)

keyStore := data.NewRotatingKeyStore()
if cfg.IdentitySigningKey == nil {
m := data.NewKeyStoreRotater(
data.NewEncryptedBlobStore(blobStore, cfg.DBEncryptionKey),
encryptedBlobStore,
cfg.AccessTokenTTL,
logger,
)
Expand All @@ -84,6 +87,8 @@ func NewApp(cfg *Config, logger logrus.FieldLogger) (*App, error) {
keyStore.Rotate(cfg.IdentitySigningKey)
}

totpCache := data.NewTOTPCache(encryptedBlobStore)

var actives data.Actives
if redis != nil {
actives = dataRedis.NewActives(
Expand Down Expand Up @@ -121,6 +126,7 @@ func NewApp(cfg *Config, logger logrus.FieldLogger) (*App, error) {
AccountStore: accountStore,
RefreshTokenStore: tokenStore,
KeyStore: keyStore,
TOTPCache: totpCache,
Actives: actives,
Reporter: errorReporter,
OauthProviders: oauthProviders,
Expand Down
2 changes: 2 additions & 0 deletions app/data/account_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ type AccountStore interface {
SetPassword(id int, p []byte) (bool, error)
UpdateUsername(id int, u string) (bool, error)
SetLastLogin(id int) (bool, error)
SetTOTPSecret(id int, secret []byte) (bool, error)
DeleteTOTPSecret(id int) (bool, error)
}

func NewAccountStore(db sqlx.Ext) (AccountStore, error) {
Expand Down
6 changes: 6 additions & 0 deletions app/data/blob_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ type BlobStore interface {

// WriteNX will write the blob into the store only if the name does not exist.
WriteNX(name string, blob []byte) (bool, error)

// Write will write the blob into the store
Write(name string, blob []byte) (bool, error)

// Delete will remove the blob from the store
Delete(name string) error
}

func NewBlobStore(interval time.Duration, redis *redis.Client, db *sqlx.DB, reporter ops.ErrorReporter) (BlobStore, error) {
Expand Down
12 changes: 12 additions & 0 deletions app/data/encrypted_blob_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,15 @@ func (bs *EncryptedBlobStore) WriteNX(name string, blob []byte) (bool, error) {
}
return bs.store.WriteNX(name, encryptedBlob)
}

func (bs *EncryptedBlobStore) Write(name string, blob []byte) (bool, error) {
encryptedBlob, err := compat.Encrypt(blob, bs.encryptionKey)
if err != nil {
return false, err
}
return bs.store.Write(name, encryptedBlob)
}

func (bs *EncryptedBlobStore) Delete(name string) error {
return bs.store.Delete(name)
}
52 changes: 50 additions & 2 deletions app/data/mock/account_store.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package mock

import (
"database/sql"
"fmt"
"strings"
"time"
Expand All @@ -23,15 +24,29 @@ type accountStore struct {
idByUsername map[string]int
oauthAccountsByID map[int][]*models.OauthAccount
idByOauthID map[string]int
errorOnID int
}

func NewAccountStore() *accountStore {
return &accountStore{
func WithSetTOTPFailureID(id int) func(s *accountStore) {
return func(s *accountStore) {
s.errorOnID = id
}
}

func NewAccountStore(opts ...func(*accountStore)) *accountStore {
s := &accountStore{
accountsByID: make(map[int]*models.Account),
oauthAccountsByID: make(map[int][]*models.OauthAccount),
idByUsername: make(map[string]int),
idByOauthID: make(map[string]int),
errorOnID: -1,
}

for _, o := range opts {
o(s)
}

return s
}

func (s *accountStore) Find(id int) (*models.Account, error) {
Expand Down Expand Up @@ -204,6 +219,39 @@ func (s *accountStore) SetLastLogin(id int) (bool, error) {
return true, nil
}

func (s *accountStore) SetTOTPSecret(id int, secret []byte) (bool, error) {
account := s.accountsByID[id]
if account == nil {
return false, nil
}

// this is weird, but we can return "unaffected" if the secret already exists
// to approximate the failure mode for testing.
if account.TOTPSecret.Valid {
return false, nil
}

if account.ID == s.errorOnID {
return false, fmt.Errorf("rejecting for bad ID: %d", account.ID)
}

account.TOTPSecret = sql.NullString{String: string(secret), Valid: true}
return true, nil
}

func (s *accountStore) DeleteTOTPSecret(id int) (bool, error) {
account := s.accountsByID[id]
if account == nil {
return false, nil
}
deleted := false
if account.TOTPSecret.Valid {
account.TOTPSecret = sql.NullString{}
deleted = true
}
return deleted, nil
}

// i think this works? i want to avoid accidentally giving callers the ability
// to reach into the memory map and modify things or see changes without relying
// on the store api.
Expand Down
19 changes: 17 additions & 2 deletions app/data/mock/blob_store.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package mock

import "time"
import "sync"
import (
"sync"
"time"
)

type BlobStore struct {
blobs map[string][]byte
Expand All @@ -10,6 +12,11 @@ type BlobStore struct {
LockTime time.Duration
}

func (bs *BlobStore) Delete(name string) error {
delete(bs.blobs, name)
return nil
}

var placeholder = "mock-blob-store"

func NewBlobStore(ttl time.Duration, lockTime time.Duration) *BlobStore {
Expand Down Expand Up @@ -39,3 +46,11 @@ func (bs *BlobStore) WriteNX(name string, blob []byte) (bool, error) {
bs.blobs[name] = blob
return true, nil
}

func (bs *BlobStore) Write(name string, blob []byte) (bool, error) {
bs.mutex.Lock()
defer bs.mutex.Unlock()

bs.blobs[name] = blob
return true, nil
}
44 changes: 44 additions & 0 deletions app/data/mock/totp_cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package mock

import (
"fmt"
)

type TOTP struct {
store map[int][]byte
errorOnID int
}

func NewTOTPCache(errorOnID int) *TOTP {
return &TOTP{
errorOnID: errorOnID,
store: make(map[int][]byte),
}
}

func (m TOTP) CacheTOTPSecret(accountID int, secret []byte) error {
if accountID == m.errorOnID {
return fmt.Errorf("error forced by ID: %d", accountID)
}
m.store[accountID] = secret
return nil
}

func (m TOTP) LoadTOTPSecret(accountID int) ([]byte, error) {
if accountID == m.errorOnID {
return nil, fmt.Errorf("error forced by ID: %d", accountID)
}
r, ok := m.store[accountID]
if !ok {
return nil, nil
}
return r, nil
}

func (m TOTP) RemoveTOTPSecret(accountID int) error {
if accountID == m.errorOnID {
return fmt.Errorf("error forced by ID: %d", accountID)
}
delete(m.store, accountID)
return nil
}
10 changes: 10 additions & 0 deletions app/data/mysql/account_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,16 @@ func (db *AccountStore) SetLastLogin(id int) (bool, error) {
return ok(result, err)
}

func (db *AccountStore) SetTOTPSecret(id int, secret []byte) (bool, error) {
result, err := db.Exec("UPDATE accounts SET totp_secret = ? WHERE id = ?", secret, id)
return ok(result, err)
}

func (db *AccountStore) DeleteTOTPSecret(id int) (bool, error) {
result, err := db.Exec("UPDATE accounts SET totp_secret = NULL WHERE id = ?", id)
return ok(result, err)
}

func ok(result sql.Result, err error) (bool, error) {
if err != nil {
return false, err
Expand Down
19 changes: 17 additions & 2 deletions app/data/mysql/migrations.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package mysql

import "github.com/jmoiron/sqlx"
import "github.com/go-sql-driver/mysql"
import (
"github.com/go-sql-driver/mysql"
"github.com/jmoiron/sqlx"
)

// MigrateDB is committed to doing the work necessary to converge the database
// in a safe, production-grade fashion. This will mean conditional logic as it
Expand All @@ -12,6 +14,7 @@ func MigrateDB(db *sqlx.DB) error {
createAccounts,
createOauthAccounts,
createAccountLastLoginAtField,
createAccountTOTPFields,
}
for _, m := range migrations {
if err := m(db); err != nil {
Expand Down Expand Up @@ -69,3 +72,15 @@ func createAccountLastLoginAtField(db *sqlx.DB) error {
}
return err
}

func createAccountTOTPFields(db *sqlx.DB) error {
_, err := db.Exec(`
ALTER TABLE accounts ADD totp_secret VARCHAR(255) DEFAULT NULL
`)
if mysqlError, ok := err.(*mysql.MySQLError); ok {
if mysqlError.Number == 1060 { // 1060 = Duplicate column name
err = nil
}
}
return err
}
10 changes: 10 additions & 0 deletions app/data/postgres/account_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,16 @@ func (db *AccountStore) SetLastLogin(id int) (bool, error) {
return ok(result, err)
}

func (db *AccountStore) SetTOTPSecret(id int, secret []byte) (bool, error) {
result, err := db.Exec("UPDATE accounts SET totp_secret = $1 WHERE id = $2", secret, id)
return ok(result, err)
}

func (db *AccountStore) DeleteTOTPSecret(id int) (bool, error) {
result, err := db.Exec("UPDATE accounts SET totp_secret = NULL WHERE id = $1", id)
return ok(result, err)
}

func ok(result sql.Result, err error) (bool, error) {
if err != nil {
return false, err
Expand Down
9 changes: 9 additions & 0 deletions app/data/postgres/migrations.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ func MigrateDB(db *sqlx.DB) error {
createOauthAccounts,
createAccountLastLoginAtField,
caseInsensitiveUsername,
createAccountTOTPFields,
}
for _, m := range migrations {
if err := m(db); err != nil {
Expand All @@ -20,6 +21,7 @@ func MigrateDB(db *sqlx.DB) error {
}
return nil
}

func migrateAccounts(db *sqlx.DB) error {
_, err := db.Exec(`
CREATE TABLE IF NOT EXISTS accounts (
Expand Down Expand Up @@ -68,3 +70,10 @@ func caseInsensitiveUsername(db *sqlx.DB) error {
`)
return err
}

func createAccountTOTPFields(db *sqlx.DB) error {
_, err := db.Exec(`
ALTER TABLE accounts ADD COLUMN IF NOT EXISTS totp_secret TEXT DEFAULT NULL
`)
return err
}
12 changes: 12 additions & 0 deletions app/data/redis/blob_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,15 @@ func (s *BlobStore) Read(name string) ([]byte, error) {
func (s *BlobStore) WriteNX(name string, blob []byte) (bool, error) {
return s.Client.SetNX(context.TODO(), name, blob, s.TTL).Result()
}

func (s *BlobStore) Write(name string, blob []byte) (bool, error) {
res, err := s.Client.Set(context.TODO(), name, blob, s.TTL).Result()
if res != "OK" {
return false, err
}
return true, nil
}

func (s *BlobStore) Delete(name string) error {
return s.Client.Del(context.TODO(), name).Err()
}
10 changes: 10 additions & 0 deletions app/data/sqlite3/account_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,16 @@ func (db *AccountStore) SetLastLogin(id int) (bool, error) {
return ok(result, err)
}

func (db *AccountStore) SetTOTPSecret(id int, secret []byte) (bool, error) {
result, err := db.Exec("UPDATE accounts SET totp_secret = ? WHERE id = ?", secret, id)
return ok(result, err)
}

func (db *AccountStore) DeleteTOTPSecret(id int) (bool, error) {
result, err := db.Exec("UPDATE accounts SET totp_secret = NULL WHERE id = ?", id)
return ok(result, err)
}

func ok(result sql.Result, err error) (bool, error) {
if err != nil {
return false, err
Expand Down
14 changes: 14 additions & 0 deletions app/data/sqlite3/blob_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,17 @@ func (s *BlobStore) WriteNX(name string, blob []byte) (bool, error) {
}
return true, nil
}

func (s *BlobStore) Write(name string, blob []byte) (bool, error) {
expiresAt := time.Now().Add(s.TTL)
_, err := s.DB.Exec("INSERT or REPLACE INTO blobs (name, blob, expires_at) VALUES (?, ?, ?)", name, blob, expiresAt, blob, expiresAt, name)
if err != nil {
return false, err
}
return true, nil
}

func (s *BlobStore) Delete(name string) error {
_, err := s.DB.Exec("DELETE FROM blobs WHERE name = ?", name)
return err
}
Loading

0 comments on commit bbc6311

Please sign in to comment.