Skip to content

Commit

Permalink
chore: refactor encryption so dal does not know about awscli
Browse files Browse the repository at this point in the history
  • Loading branch information
gak committed Aug 19, 2024
1 parent 2505efb commit 87ea00e
Show file tree
Hide file tree
Showing 12 changed files with 104 additions and 72 deletions.
3 changes: 2 additions & 1 deletion backend/controller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ import (
"github.com/TBD54566975/ftl/frontend"
cf "github.com/TBD54566975/ftl/internal/configuration"
"github.com/TBD54566975/ftl/internal/cors"
"github.com/TBD54566975/ftl/internal/encryption"
ftlhttp "github.com/TBD54566975/ftl/internal/http"
"github.com/TBD54566975/ftl/internal/log"
ftlmaps "github.com/TBD54566975/ftl/internal/maps"
Expand Down Expand Up @@ -229,7 +230,7 @@ func New(ctx context.Context, conn *sql.DB, config Config, runnerScaling scaling
config.ControllerTimeout = time.Second * 5
}

db, err := dal.New(ctx, conn, optional.Ptr[string](config.KMSURI))
db, err := dal.New(ctx, conn, *encryption.NewBuilder().WithKMSURI(config.KMSURI))
if err != nil {
return nil, fmt.Errorf("failed to create DAL: %w", err)
}
Expand Down
3 changes: 2 additions & 1 deletion backend/controller/cronjobs/cronjobs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/TBD54566975/ftl/backend/controller/sql/sqltest"
ftlv1 "github.com/TBD54566975/ftl/backend/protos/xyz/block/ftl/v1"
"github.com/TBD54566975/ftl/backend/schema"
"github.com/TBD54566975/ftl/internal/encryption"
"github.com/TBD54566975/ftl/internal/log"
"github.com/TBD54566975/ftl/internal/model"
"github.com/TBD54566975/ftl/internal/slices"
Expand All @@ -37,7 +38,7 @@ func TestServiceWithMockDal(t *testing.T) {
attemptCountMap: map[string]int{},
}
conn := sqltest.OpenForTesting(ctx, t)
parentDAL, err := db.New(ctx, conn, optional.None[string]())
parentDAL, err := db.New(ctx, conn, *encryption.NewBuilder())
assert.NoError(t, err)

testServiceWithDal(ctx, t, mockDal, parentDAL, clk)
Expand Down
6 changes: 3 additions & 3 deletions backend/controller/dal/async_calls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,18 @@ import (
"context"
"testing"

"github.com/alecthomas/types/optional"
"github.com/alecthomas/assert/v2"

"github.com/TBD54566975/ftl/backend/controller/sql/sqltest"
dalerrs "github.com/TBD54566975/ftl/backend/dal"
"github.com/TBD54566975/ftl/internal/encryption"
"github.com/TBD54566975/ftl/internal/log"
"github.com/alecthomas/assert/v2"
)

func TestNoCallToAcquire(t *testing.T) {
ctx := log.ContextWithNewDefaultLogger(context.Background())
conn := sqltest.OpenForTesting(ctx, t)
dal, err := New(ctx, conn, optional.None[string]())
dal, err := New(ctx, conn, *encryption.NewBuilder())
assert.NoError(t, err)

_, err = dal.AcquireAsyncCall(ctx)
Expand Down
7 changes: 4 additions & 3 deletions backend/controller/dal/dal.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,16 +210,17 @@ func WithReservation(ctx context.Context, reservation Reservation, fn func() err
return reservation.Commit(ctx)
}

func New(ctx context.Context, conn *stdsql.DB, kmsURL optional.Option[string]) (*DAL, error) {
func New(ctx context.Context, conn *stdsql.DB, encryptionBuilder encryption.Builder) (*DAL, error) {
d := &DAL{
db: sql.NewDB(conn),
DeploymentChanges: pubsub.New[DeploymentNotification](),
kmsURL: kmsURL,
}

if err := d.setupEncryptor(ctx); err != nil {
encryptor, err := encryptionBuilder.Build(ctx, d)
if err != nil {
return nil, fmt.Errorf("failed to setup encryptor: %w", err)
}
d.encryptor = encryptor

return d, nil
}
Expand Down
7 changes: 4 additions & 3 deletions backend/controller/dal/dal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
dalerrs "github.com/TBD54566975/ftl/backend/dal"
ftlv1 "github.com/TBD54566975/ftl/backend/protos/xyz/block/ftl/v1"
"github.com/TBD54566975/ftl/backend/schema"
"github.com/TBD54566975/ftl/internal/encryption"
"github.com/TBD54566975/ftl/internal/log"
"github.com/TBD54566975/ftl/internal/model"
"github.com/TBD54566975/ftl/internal/sha256"
Expand All @@ -26,7 +27,7 @@ import (
func TestDAL(t *testing.T) {
ctx := log.ContextWithNewDefaultLogger(context.Background())
conn := sqltest.OpenForTesting(ctx, t)
dal, err := New(ctx, conn, optional.None[string]())
dal, err := New(ctx, conn, *encryption.NewBuilder())
assert.NoError(t, err)
assert.NotZero(t, dal)
var testContent = bytes.Repeat([]byte("sometestcontentthatislongerthanthereadbuffer"), 100)
Expand Down Expand Up @@ -373,7 +374,7 @@ func TestDAL(t *testing.T) {
func TestCreateArtefactConflict(t *testing.T) {
ctx := log.ContextWithNewDefaultLogger(context.Background())
conn := sqltest.OpenForTesting(ctx, t)
dal, err := New(ctx, conn, optional.None[string]())
dal, err := New(ctx, conn, *encryption.NewBuilder())
assert.NoError(t, err)

idch := make(chan sha256.SHA256, 2)
Expand Down Expand Up @@ -450,7 +451,7 @@ func assertEventsEqual(t *testing.T, expected, actual []TimelineEvent) {
func TestDeleteOldEvents(t *testing.T) {
ctx := log.ContextWithNewDefaultLogger(context.Background())
conn := sqltest.OpenForTesting(ctx, t)
dal, err := New(ctx, conn, optional.None[string]())
dal, err := New(ctx, conn, *encryption.NewBuilder())
assert.NoError(t, err)

var testContent = bytes.Repeat([]byte("sometestcontentthatislongerthanthereadbuffer"), 100)
Expand Down
57 changes: 18 additions & 39 deletions backend/controller/dal/encryption.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,11 @@ import (
"context"
"encoding/json"
"fmt"

"github.com/TBD54566975/ftl/backend/dal"
"github.com/TBD54566975/ftl/internal/encryption"
"github.com/TBD54566975/ftl/internal/log"
)

type name struct {
}
"github.com/TBD54566975/ftl/internal/encryption"
)

func (d *DAL) encrypt(subkey encryption.Subkey, cleartext []byte) ([]byte, error) {
if d.encryptor == nil {
Expand Down Expand Up @@ -61,49 +58,31 @@ func (d *DAL) decryptJSON(subkey encryption.Subkey, encrypted []byte, v any) err
return nil
}

// setupEncryptor sets up the encryptor for the DAL.
// It will either create a key or load the existing one.
// If the KMS URL is not set, it will use a NoOpEncryptor which does not encrypt anything.
func (d *DAL) setupEncryptor(ctx context.Context) (err error) {
func (d *DAL) EnsureKey(ctx context.Context, generateKey func() ([]byte, error)) ([]byte, error) {
logger := log.FromContext(ctx)
tx, err := d.Begin(ctx)
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
return nil, fmt.Errorf("failed to begin transaction: %w", err)
}
defer tx.CommitOrRollback(ctx, &err)

url, ok := d.kmsURL.Get()
if !ok {
logger.Infof("KMS URL not set, encryption not enabled")
d.encryptor = encryption.NewNoOpEncryptor()
return nil
}

encryptedKey, err := tx.db.GetOnlyEncryptionKey(ctx)
if err != nil {
if dal.IsNotFound(err) {
logger.Infof("No encryption key found, generating a new one")
encryptor, err := encryption.NewKMSEncryptorGenerateKey(url, nil)
if err != nil {
return fmt.Errorf("failed to create encryptor for generation: %w", err)
}
d.encryptor = encryptor

if err = tx.db.CreateOnlyEncryptionKey(ctx, encryptor.GetEncryptedKeyset()); err != nil {
return fmt.Errorf("failed to create only encryption key: %w", err)
}

return nil
if err != nil && dal.IsNotFound(err) {
logger.Infof("No encryption key found, generating a new one")
key, err := generateKey()
if err != nil {
return nil, fmt.Errorf("failed to generate key: %w", err)
}
return fmt.Errorf("failed to get only encryption key: %w", err)
}

logger.Debugf("Encryption key found, using it")
encryptor, err := encryption.NewKMSEncryptorWithKMS(url, nil, encryptedKey)
if err != nil {
return fmt.Errorf("failed to create encryptor with encrypted key: %w", err)
if err = tx.db.CreateOnlyEncryptionKey(ctx, key); err != nil {
return nil, fmt.Errorf("failed to save the encryption key: %w", err)
}

return key, nil
} else if err != nil {
return nil, fmt.Errorf("failed to load the encryption key from the db: %w", err)
}
d.encryptor = encryptor

return nil
logger.Debugf("Encryption key found, using it")
return encryptedKey, nil
}
4 changes: 2 additions & 2 deletions backend/controller/dal/fsm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package dal

import (
"context"
"github.com/alecthomas/types/optional"
"testing"
"time"

Expand All @@ -12,13 +11,14 @@ import (
"github.com/TBD54566975/ftl/backend/controller/sql/sqltest"
dalerrs "github.com/TBD54566975/ftl/backend/dal"
"github.com/TBD54566975/ftl/backend/schema"
"github.com/TBD54566975/ftl/internal/encryption"
"github.com/TBD54566975/ftl/internal/log"
)

func TestSendFSMEvent(t *testing.T) {
ctx := log.ContextWithNewDefaultLogger(context.Background())
conn := sqltest.OpenForTesting(ctx, t)
dal, err := New(ctx, conn, optional.None[string]())
dal, err := New(ctx, conn, *encryption.NewBuilder())
assert.NoError(t, err)

_, err = dal.AcquireAsyncCall(ctx)
Expand Down
5 changes: 3 additions & 2 deletions backend/controller/dal/lease_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/TBD54566975/ftl/backend/controller/sql"
"github.com/TBD54566975/ftl/backend/controller/sql/sqltest"
dalerrs "github.com/TBD54566975/ftl/backend/dal"
"github.com/TBD54566975/ftl/internal/encryption"
"github.com/TBD54566975/ftl/internal/log"
)

Expand All @@ -36,7 +37,7 @@ func TestLease(t *testing.T) {
}
ctx := log.ContextWithNewDefaultLogger(context.Background())
conn := sqltest.OpenForTesting(ctx, t)
dal, err := New(ctx, conn, optional.None[string]())
dal, err := New(ctx, conn, *encryption.NewBuilder())
assert.NoError(t, err)

// TTL is too short, expect an error
Expand Down Expand Up @@ -71,7 +72,7 @@ func TestExpireLeases(t *testing.T) {
}
ctx := log.ContextWithNewDefaultLogger(context.Background())
conn := sqltest.OpenForTesting(ctx, t)
dal, err := New(ctx, conn, optional.None[string]())
dal, err := New(ctx, conn, *encryption.NewBuilder())
assert.NoError(t, err)

leasei, _, err := dal.AcquireLease(ctx, leases.SystemKey("test"), time.Second*5, optional.None[any]())
Expand Down
2 changes: 1 addition & 1 deletion backend/controller/sql/sqltest/testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func OpenForTesting(ctx context.Context, t testing.TB) *sql.DB {
t.Helper()
// Acquire lock for this DB.
lockPath := filepath.Join(os.TempDir(), "ftl-db-test.lock")
release, err := flock.Acquire(ctx, lockPath, 20*time.Second)
release, err := flock.Acquire(ctx, lockPath, 30*time.Second)
assert.NoError(t, err)
t.Cleanup(func() { _ = release() }) //nolint:errcheck

Expand Down
7 changes: 5 additions & 2 deletions cmd/ftl-controller/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ import (
"context"
"database/sql"
"fmt"
"github.com/TBD54566975/ftl/internal/encryption"
"os"
"strconv"
"time"

"github.com/alecthomas/kong"
"github.com/alecthomas/types/optional"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/secretsmanager"

Expand Down Expand Up @@ -55,7 +55,10 @@ func main() {
// The FTL controller currently only supports DB as a configuration provider/resolver.
conn, err := sql.Open("pgx", cli.ControllerConfig.DSN)
kctx.FatalIfErrorf(err)
dal, err := dal.New(ctx, conn, optional.Some[string](*cli.ControllerConfig.KMSURI))

encryptionBuilder := encryption.NewBuilder().WithKMSURI(cli.ControllerConfig.KMSURI)
kctx.FatalIfErrorf(err)
dal, err := dal.New(ctx, conn, *encryptionBuilder)
kctx.FatalIfErrorf(err)

configDal, err := cfdal.New(ctx, conn)
Expand Down
70 changes: 56 additions & 14 deletions internal/encryption/encryption.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ package encryption

import (
"bytes"
"context"
"fmt"
"strings"

"github.com/alecthomas/types/optional"
awsv1kms "github.com/aws/aws-sdk-go/service/kms"
"github.com/tink-crypto/tink-go-awskms/integration/awskms"
"github.com/tink-crypto/tink-go/v2/aead"
Expand All @@ -20,11 +22,6 @@ type Subkey interface {
Salt() string
}

//const (
// TimelineSubkey Subkey = "timeline"
// AsyncSubkey Subkey = "async"
//)

type TimelineSubkey struct{}

func (t TimelineSubkey) Salt() string {
Expand All @@ -37,6 +34,58 @@ func (a AsyncSubkey) Salt() string {
return "async"
}

type KeyStoreProvider interface {
// EnsureKey asks a provider to check for an encrypted key.
// If not available, call the generateKey function to create a new key.
// The provider should handle transactions around checking and setting the key, to prevent race conditions.
EnsureKey(ctx context.Context, generateKey func() ([]byte, error)) ([]byte, error)
}

// Builder constructs a DataEncryptor when used with a provider.
// Use a chain of With* methods to configure the builder.
type Builder struct {
kmsURI optional.Option[string]
awsV1Client *awsv1kms.KMS
}

func NewBuilder() *Builder {
return &Builder{
kmsURI: optional.None[string](),
}
}

// WithKMSURI sets the URI for the KMS key to use. Omitting this call or using nil will create a NoOpEncryptor.
func (b *Builder) WithKMSURI(kmsURI *string) *Builder {
b.kmsURI = optional.Ptr[string](kmsURI)
return b
}

func (b *Builder) WithAWSV1Client(kms *awsv1kms.KMS) *Builder {
b.awsV1Client = kms
return b
}

func (b *Builder) Build(ctx context.Context, provider KeyStoreProvider) (DataEncryptor, error) {
kmsURI, ok := b.kmsURI.Get()
if !ok {
return NewNoOpEncryptor(), nil
}

key, err := provider.EnsureKey(ctx, func() ([]byte, error) {
return newKey(kmsURI, nil)
})
if err != nil {
return nil, fmt.Errorf("failed to ensure key from provider: %w", err)
}

encryptor, err := NewKMSEncryptorWithKMS(kmsURI, nil, key)
if err != nil {
return nil, fmt.Errorf("failed to create KMS encryptor: %w", err)
}

return encryptor, nil
}

type DataEncryptor interface {
Encrypt(subkey Subkey, cleartext []byte) ([]byte, error)
Decrypt(subkey Subkey, encrypted []byte) ([]byte, error)
Expand Down Expand Up @@ -96,7 +145,7 @@ func newClientWithAEAD(uri string, kms *awsv1kms.KMS) (tink.AEAD, error) {
return kekAEAD, nil
}

func NewKMSEncryptorGenerateKey(uri string, v1client *awsv1kms.KMS) (*KMSEncryptor, error) {
func newKey(uri string, v1client *awsv1kms.KMS) ([]byte, error) {
kekAEAD, err := newClientWithAEAD(uri, v1client)
if err != nil {
return nil, fmt.Errorf("failed to create KMS client: %w", err)
Expand Down Expand Up @@ -125,14 +174,7 @@ func NewKMSEncryptorGenerateKey(uri string, v1client *awsv1kms.KMS) (*KMSEncrypt
if err != nil {
return nil, fmt.Errorf("failed to encrypt DEK: %w", err)
}
encryptedKeyset := buf.Bytes()

return &KMSEncryptor{
root: *handle,
kekAEAD: kekAEAD,
encryptedKeyset: encryptedKeyset,
cachedDerived: make(map[Subkey]tink.AEAD),
}, nil
return buf.Bytes(), nil
}

func NewKMSEncryptorWithKMS(uri string, v1client *awsv1kms.KMS, encryptedKeyset []byte) (*KMSEncryptor, error) {
Expand Down
5 changes: 4 additions & 1 deletion internal/encryption/encryption_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ func TestNoOpEncryptor(t *testing.T) {
func TestKMSEncryptorFakeKMS(t *testing.T) {
uri := "fake-kms://CKbvh_ILElQKSAowdHlwZS5nb29nbGVhcGlzLmNvbS9nb29nbGUuY3J5cHRvLnRpbmsuQWVzR2NtS2V5EhIaEE6tD2yE5AWYOirhmkY-r3sYARABGKbvh_ILIAE"

encryptor, err := NewKMSEncryptorGenerateKey(uri, nil)
key, err := newKey(uri, nil)
assert.NoError(t, err)

encryptor, err := NewKMSEncryptorWithKMS(uri, nil, key)
assert.NoError(t, err)

encrypted, err := encryptor.Encrypt(TimelineSubkey{}, []byte("hunter2"))
Expand Down

0 comments on commit 87ea00e

Please sign in to comment.