Skip to content

Commit

Permalink
chore(refactor): extract encryption dal for sharing with other packages
Browse files Browse the repository at this point in the history
  • Loading branch information
wesbillman committed Sep 13, 2024
1 parent cf20910 commit a67bf63
Show file tree
Hide file tree
Showing 22 changed files with 871 additions and 230 deletions.
10 changes: 6 additions & 4 deletions backend/controller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import (
"github.com/TBD54566975/ftl/backend/controller/console"
"github.com/TBD54566975/ftl/backend/controller/cronjobs"
"github.com/TBD54566975/ftl/backend/controller/dal"
"github.com/TBD54566975/ftl/backend/controller/encryption"
"github.com/TBD54566975/ftl/backend/controller/ingress"
"github.com/TBD54566975/ftl/backend/controller/leases"
leasesdal "github.com/TBD54566975/ftl/backend/controller/leases/dal"
Expand All @@ -54,7 +55,7 @@ import (
frontend "github.com/TBD54566975/ftl/frontend/console"
cf "github.com/TBD54566975/ftl/internal/configuration/manager"
"github.com/TBD54566975/ftl/internal/cors"
"github.com/TBD54566975/ftl/internal/encryption"
ftlencryption "github.com/TBD54566975/ftl/internal/encryption"
ftlhttp "github.com/TBD54566975/ftl/internal/http"
"github.com/TBD54566975/ftl/internal/log"
ftlmaps "github.com/TBD54566975/ftl/internal/maps"
Expand Down Expand Up @@ -231,12 +232,13 @@ func New(ctx context.Context, conn *sql.DB, config Config, devel bool) (*Service
config.ControllerTimeout = time.Second * 5
}

ldb := leasesdal.New(conn)
db, err := dal.New(ctx, conn, encryption.NewBuilder().WithKMSURI(optional.Ptr(config.KMSURI)))
encryptionSrv, err := encryption.New(ctx, conn, ftlencryption.NewBuilder().WithKMSURI(optional.Ptr(config.KMSURI)))
if err != nil {
return nil, fmt.Errorf("failed to create DAL: %w", err)
return nil, fmt.Errorf("failed to create encryption dal: %w", err)
}

db := dal.New(ctx, conn, encryptionSrv)
ldb := leasesdal.New(conn)
svc := &Service{
tasks: scheduledtask.New(ctx, key, ldb),
dal: db,
Expand Down
8 changes: 6 additions & 2 deletions backend/controller/cronjobs/cronjobs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@ import (

"github.com/TBD54566975/ftl/backend/controller/cronjobs/dal"
parentdal "github.com/TBD54566975/ftl/backend/controller/dal"
"github.com/TBD54566975/ftl/backend/controller/encryption"
"github.com/TBD54566975/ftl/backend/controller/sql/sqltest"
"github.com/TBD54566975/ftl/backend/libdal"
"github.com/TBD54566975/ftl/backend/schema"
"github.com/TBD54566975/ftl/internal/cron"
"github.com/TBD54566975/ftl/internal/encryption"
ftlencryption "github.com/TBD54566975/ftl/internal/encryption"
"github.com/TBD54566975/ftl/internal/log"
"github.com/TBD54566975/ftl/internal/model"
)
Expand All @@ -34,8 +35,11 @@ func TestNewCronJobsForModule(t *testing.T) {
key := model.NewControllerKey("localhost", strconv.Itoa(8080+1))
conn := sqltest.OpenForTesting(ctx, t)
dal := dal.New(conn)
parentDAL, err := parentdal.New(ctx, conn, encryption.NewBuilder())

encryption, err := encryption.New(ctx, conn, ftlencryption.NewBuilder())
assert.NoError(t, err)

parentDAL := parentdal.New(ctx, conn, encryption)
moduleName := "initial"
jobsToCreate := newCronJobs(t, moduleName, "* * * * * *", clk, 2) // every minute

Expand Down
10 changes: 5 additions & 5 deletions backend/controller/dal/async_calls.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,12 +139,12 @@ func (d *DAL) AcquireAsyncCall(ctx context.Context) (call *AsyncCall, leaseCtx c
return nil, ctx, fmt.Errorf("failed to parse origin key %q: %w", row.Origin, err)
}

decryptedRequest, err := d.decrypt(&row.Request)
decryptedRequest, err := d.encryption.Decrypt(&row.Request)
if err != nil {
return nil, ctx, fmt.Errorf("failed to decrypt async call request: %w", err)
}

lease, leaseCtx := d.leasedal.NewLease(ctx, row.LeaseKey, row.LeaseIdempotencyKey, ttl)
lease, leaseCtx := d.leaseDAL.NewLease(ctx, row.LeaseKey, row.LeaseIdempotencyKey, ttl)
return &AsyncCall{
ID: row.AsyncCallID,
Verb: row.Verb,
Expand Down Expand Up @@ -192,7 +192,7 @@ func (d *DAL) CompleteAsyncCall(ctx context.Context,
switch result := result.(type) {
case either.Left[[]byte, string]: // Successful response.
var encryptedResult encryption.EncryptedAsyncColumn
err := tx.encrypt(result.Get(), &encryptedResult)
err := tx.encryption.Encrypt(result.Get(), &encryptedResult)
if err != nil {
return false, fmt.Errorf("failed to encrypt async call result: %w", err)
}
Expand Down Expand Up @@ -261,7 +261,7 @@ func (d *DAL) LoadAsyncCall(ctx context.Context, id int64) (*AsyncCall, error) {
if err != nil {
return nil, fmt.Errorf("failed to parse origin key %q: %w", row.Origin, err)
}
request, err := d.decrypt(&row.Request)
request, err := d.encryption.Decrypt(&row.Request)
if err != nil {
return nil, fmt.Errorf("failed to decrypt async call request: %w", err)
}
Expand All @@ -284,7 +284,7 @@ func (d *DAL) GetZombieAsyncCalls(ctx context.Context, limit int) ([]*AsyncCall,
if err != nil {
return nil, fmt.Errorf("failed to parse origin key %q: %w", row.Origin, err)
}
decryptedRequest, err := d.decrypt(&row.Request)
decryptedRequest, err := d.encryption.Decrypt(&row.Request)
if err != nil {
return nil, fmt.Errorf("failed to decrypt async call request: %w", err)
}
Expand Down
7 changes: 5 additions & 2 deletions backend/controller/dal/async_calls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,23 @@ import (

"github.com/alecthomas/assert/v2"

"github.com/TBD54566975/ftl/backend/controller/encryption"
"github.com/TBD54566975/ftl/backend/controller/sql/sqltest"
"github.com/TBD54566975/ftl/backend/libdal"
"github.com/TBD54566975/ftl/backend/schema"
"github.com/TBD54566975/ftl/internal/encryption"
ftlencryption "github.com/TBD54566975/ftl/internal/encryption"
"github.com/TBD54566975/ftl/internal/log"
"github.com/TBD54566975/ftl/internal/model"
)

func TestNoCallToAcquire(t *testing.T) {
ctx := log.ContextWithNewDefaultLogger(context.Background())
conn := sqltest.OpenForTesting(ctx, t)
dal, err := New(ctx, conn, encryption.NewBuilder())
encryption, err := encryption.New(ctx, conn, ftlencryption.NewBuilder())
assert.NoError(t, err)

dal := New(ctx, conn, encryption)

_, _, err = dal.AcquireAsyncCall(ctx)
assert.IsError(t, err, libdal.ErrNotFound)
assert.EqualError(t, err, "no pending async calls: not found")
Expand Down
46 changes: 20 additions & 26 deletions backend/controller/dal/dal.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@ import (
"google.golang.org/protobuf/proto"

dalsql "github.com/TBD54566975/ftl/backend/controller/dal/internal/sql"
"github.com/TBD54566975/ftl/backend/controller/encryption"
leasedal "github.com/TBD54566975/ftl/backend/controller/leases/dal"
"github.com/TBD54566975/ftl/backend/controller/sql/sqltypes"
"github.com/TBD54566975/ftl/backend/libdal"
ftlv1 "github.com/TBD54566975/ftl/backend/protos/xyz/block/ftl/v1"
"github.com/TBD54566975/ftl/backend/schema"
"github.com/TBD54566975/ftl/internal/encryption"
ftlencryption "github.com/TBD54566975/ftl/internal/encryption"
"github.com/TBD54566975/ftl/internal/log"
"github.com/TBD54566975/ftl/internal/maps"
"github.com/TBD54566975/ftl/internal/model"
Expand Down Expand Up @@ -201,40 +202,33 @@ func WithReservation(ctx context.Context, reservation Reservation, fn func() err
return reservation.Commit(ctx)
}

func New(ctx context.Context, conn libdal.Connection, encryptionBuilder encryption.Builder) (*DAL, error) {
func New(ctx context.Context, conn libdal.Connection, encryption *encryption.Service) *DAL {
var d *DAL
d = &DAL{
leasedal: leasedal.New(conn),
db: dalsql.New(conn),
leaseDAL: leasedal.New(conn),
db: dalsql.New(conn),
encryption: encryption,
Handle: libdal.New(conn, func(h *libdal.Handle[DAL]) *DAL {
return &DAL{
Handle: h,
db: dalsql.New(h.Connection),
leasedal: leasedal.New(h.Connection),
encryptor: d.encryptor,
leaseDAL: leasedal.New(h.Connection),
encryption: d.encryption,
DeploymentChanges: d.DeploymentChanges,
}
}),
DeploymentChanges: pubsub.New[DeploymentNotification](),
}
encryptor, err := encryptionBuilder.Build(ctx, d)
if err != nil {
return nil, fmt.Errorf("build encryptor: %w", err)
}
if err := d.verifyEncryptor(ctx, encryptor); err != nil {
return nil, fmt.Errorf("verify encryptor: %w", err)
}
d.encryptor = encryptor
return d, nil

return d
}

type DAL struct {
*libdal.Handle[DAL]
db dalsql.Querier

leasedal *leasedal.DAL

encryptor encryption.DataEncryptor
leaseDAL *leasedal.DAL
encryption *encryption.Service

// DeploymentChanges is a Topic that receives changes to the deployments table.
DeploymentChanges *pubsub.Topic[DeploymentNotification]
Expand Down Expand Up @@ -611,8 +605,8 @@ func (d *DAL) SetDeploymentReplicas(ctx context.Context, key model.DeploymentKey
return libdal.TranslatePGError(err)
}
}
var payload encryption.EncryptedTimelineColumn
err = d.encryptJSON(map[string]interface{}{
var payload ftlencryption.EncryptedTimelineColumn
err = d.encryption.EncryptJSON(map[string]interface{}{
"prev_min_replicas": deployment.MinReplicas,
"min_replicas": minReplicas,
}, &payload)
Expand Down Expand Up @@ -685,8 +679,8 @@ func (d *DAL) ReplaceDeployment(ctx context.Context, newDeploymentKey model.Depl
}
}

var payload encryption.EncryptedTimelineColumn
err = d.encryptJSON(map[string]any{
var payload ftlencryption.EncryptedTimelineColumn
err = d.encryption.EncryptJSON(map[string]any{
"min_replicas": int32(minReplicas),
"replaced": replacedDeploymentKey,
}, &payload)
Expand Down Expand Up @@ -898,8 +892,8 @@ func (d *DAL) InsertLogEvent(ctx context.Context, log *LogEvent) error {
"error": log.Error,
"stack": log.Stack,
}
var encryptedPayload encryption.EncryptedTimelineColumn
err := d.encryptJSON(payload, &encryptedPayload)
var encryptedPayload ftlencryption.EncryptedTimelineColumn
err := d.encryption.EncryptJSON(payload, &encryptedPayload)
if err != nil {
return fmt.Errorf("failed to encrypt log payload: %w", err)
}
Expand Down Expand Up @@ -979,8 +973,8 @@ func (d *DAL) InsertCallEvent(ctx context.Context, call *CallEvent) error {
if pr, ok := call.ParentRequestKey.Get(); ok {
parentRequestKey = optional.Some(pr.String())
}
var payload encryption.EncryptedTimelineColumn
err := d.encryptJSON(map[string]any{
var payload ftlencryption.EncryptedTimelineColumn
err := d.encryption.EncryptJSON(map[string]any{
"duration_ms": call.Duration.Milliseconds(),
"request": call.Request,
"response": call.Response,
Expand Down
93 changes: 11 additions & 82 deletions backend/controller/dal/dal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@ import (
"github.com/alecthomas/types/optional"
"golang.org/x/sync/errgroup"

"github.com/TBD54566975/ftl/backend/controller/encryption"
"github.com/TBD54566975/ftl/backend/controller/sql/sqltest"
"github.com/TBD54566975/ftl/backend/libdal"
ftlv1 "github.com/TBD54566975/ftl/backend/protos/xyz/block/ftl/v1"
"github.com/TBD54566975/ftl/backend/schema"
"github.com/TBD54566975/ftl/internal/encryption"
ftlencryption "github.com/TBD54566975/ftl/internal/encryption"
"github.com/TBD54566975/ftl/internal/log"
"github.com/TBD54566975/ftl/internal/model"
"github.com/TBD54566975/ftl/internal/sha256"
Expand All @@ -27,9 +28,11 @@ import (
func TestDAL(t *testing.T) {
ctx := log.ContextWithNewDefaultLogger(context.Background())
conn := sqltest.OpenForTesting(ctx, t)
dal, err := New(ctx, conn, encryption.NewBuilder())
encryption, err := encryption.New(ctx, conn, ftlencryption.NewBuilder())
assert.NoError(t, err)

dal := New(ctx, conn, encryption)

var testContent = bytes.Repeat([]byte("sometestcontentthatislongerthanthereadbuffer"), 100)
var testSHA = sha256.Sum(testContent)

Expand Down Expand Up @@ -291,9 +294,11 @@ func TestDAL(t *testing.T) {
func TestCreateArtefactConflict(t *testing.T) {
ctx := log.ContextWithNewDefaultLogger(context.Background())
conn := sqltest.OpenForTesting(ctx, t)
dal, err := New(ctx, conn, encryption.NewBuilder())
encryption, err := encryption.New(ctx, conn, ftlencryption.NewBuilder())
assert.NoError(t, err)

dal := New(ctx, conn, encryption)

idch := make(chan sha256.SHA256, 2)

wg := sync.WaitGroup{}
Expand Down Expand Up @@ -368,9 +373,11 @@ func assertEventsEqual(t *testing.T, expected, actual []TimelineEvent) {
func TestDeleteOldEvents(t *testing.T) {
ctx := log.ContextWithNewDefaultLogger(context.Background())
conn := sqltest.OpenForTesting(ctx, t)
dal, err := New(ctx, conn, encryption.NewBuilder())
encryption, err := encryption.New(ctx, conn, ftlencryption.NewBuilder())
assert.NoError(t, err)

dal := New(ctx, conn, encryption)

var testContent = bytes.Repeat([]byte("sometestcontentthatislongerthanthereadbuffer"), 100)
var testSha sha256.SHA256

Expand Down Expand Up @@ -459,81 +466,3 @@ func TestDeleteOldEvents(t *testing.T) {
assert.Equal(t, int64(0), count)
})
}

func TestVerifyEncryption(t *testing.T) {
ctx := log.ContextWithNewDefaultLogger(context.Background())
conn := sqltest.OpenForTesting(ctx, t)
uri := "fake-kms://CK6YwYkBElQKSAowdHlwZS5nb29nbGVhcGlzLmNvbS9nb29nbGUuY3J5cHRvLnRpbmsuQWVzR2NtS2V5EhIaEJy4TIQgfCuwxA3ZZgChp_wYARABGK6YwYkBIAE"

t.Run("DeleteVerificationColumns", func(t *testing.T) {
dal, err := New(ctx, conn, encryption.NewBuilder().WithKMSURI(optional.Some(uri)))
assert.NoError(t, err)

// check that there are columns set in encryption_keys
row, err := dal.db.GetOnlyEncryptionKey(ctx)
assert.NoError(t, err)
assert.NotZero(t, row.VerifyTimeline.Ok())
assert.NotZero(t, row.VerifyAsync.Ok())

// delete the columns to see if they are recreated
err = dal.db.UpdateEncryptionVerification(ctx, optional.None[encryption.EncryptedTimelineColumn](), optional.None[encryption.EncryptedAsyncColumn]())
assert.NoError(t, err)

dal, err = New(ctx, conn, encryption.NewBuilder().WithKMSURI(optional.Some(uri)))
assert.NoError(t, err)

row, err = dal.db.GetOnlyEncryptionKey(ctx)
assert.NoError(t, err)
assert.NotZero(t, row.VerifyTimeline.Ok())
assert.NotZero(t, row.VerifyAsync.Ok())
})

t.Run("DifferentKey", func(t *testing.T) {
_, err := New(ctx, conn, encryption.NewBuilder().WithKMSURI(optional.Some(uri)))
assert.NoError(t, err)

differentKey := "fake-kms://CJP7ksIKElQKSAowdHlwZS5nb29nbGVhcGlzLmNvbS9nb29nbGUuY3J5cHRvLnRpbmsuQWVzR2NtS2V5EhIaEJWT3z-xdW23HO7hc9vF3YoYARABGJP7ksIKIAE"
_, err = New(ctx, conn, encryption.NewBuilder().WithKMSURI(optional.Some(differentKey)))
assert.Error(t, err)
assert.Contains(t, err.Error(), "decryption failed")
})

t.Run("SameKeyButWrongTimelineVerification", func(t *testing.T) {
dal, err := New(ctx, conn, encryption.NewBuilder().WithKMSURI(optional.Some(uri)))
assert.NoError(t, err)

err = dal.db.UpdateEncryptionVerification(ctx, optional.Some[encryption.EncryptedTimelineColumn]([]byte("123")), optional.None[encryption.EncryptedAsyncColumn]())
assert.NoError(t, err)
_, err = New(ctx, conn, encryption.NewBuilder().WithKMSURI(optional.Some(uri)))
assert.Error(t, err)
assert.Contains(t, err.Error(), "verification sanity")
assert.Contains(t, err.Error(), "verify timeline")

err = dal.db.UpdateEncryptionVerification(ctx, optional.None[encryption.EncryptedTimelineColumn](), optional.Some[encryption.EncryptedAsyncColumn]([]byte("123")))
assert.NoError(t, err)
_, err = New(ctx, conn, encryption.NewBuilder().WithKMSURI(optional.Some(uri)))
assert.Error(t, err)
assert.Contains(t, err.Error(), "verification sanity")
assert.Contains(t, err.Error(), "verify async")
})

t.Run("SameKeyButEncryptWrongPlainText", func(t *testing.T) {
result, err := conn.Exec("DELETE FROM encryption_keys")
assert.NoError(t, err)
affected, err := result.RowsAffected()
assert.NoError(t, err)
assert.Equal(t, int64(1), affected)
dal, err := New(ctx, conn, encryption.NewBuilder().WithKMSURI(optional.Some(uri)))
assert.NoError(t, err)

encrypted := encryption.EncryptedColumn[encryption.TimelineSubKey]{}
err = dal.encrypt([]byte("123"), &encrypted)
assert.NoError(t, err)

err = dal.db.UpdateEncryptionVerification(ctx, optional.Some(encrypted), optional.None[encryption.EncryptedAsyncColumn]())
assert.NoError(t, err)
_, err = New(ctx, conn, encryption.NewBuilder().WithKMSURI(optional.Some(uri)))
assert.Error(t, err)
assert.Contains(t, err.Error(), "string does not match")
})
}
Loading

0 comments on commit a67bf63

Please sign in to comment.