Skip to content

Commit

Permalink
fix: prevent implicit conversions to/from EncryptedColumn
Browse files Browse the repository at this point in the history
Make EncryptedColumn a struct so we can't accidentally implicitly convert
`[]byte` to/from it. As a side-effect, discovered a few columns in the DB that
weren't the correct types.
  • Loading branch information
alecthomas committed Sep 13, 2024
1 parent 8e72149 commit 02ed469
Show file tree
Hide file tree
Showing 15 changed files with 90 additions and 56 deletions.
2 changes: 1 addition & 1 deletion backend/controller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ func New(ctx context.Context, conn *sql.DB, config Config, devel bool) (*Service
svc.routes.Store(map[string][]dal.Route{})
svc.schema.Store(&schema.Schema{})

cronSvc := cronjobs.New(ctx, key, svc.config.Advertise.Host, conn)
cronSvc := cronjobs.New(ctx, key, svc.config.Advertise.Host, encryptionSrv, conn)
svc.cronJobs = cronSvc

pubSub := pubsub.New(ctx, db, svc.tasks, svc)
Expand Down
17 changes: 13 additions & 4 deletions backend/controller/cronjobs/cronjobs.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@ import (

"github.com/TBD54566975/ftl/backend/controller/cronjobs/dal"
parentdal "github.com/TBD54566975/ftl/backend/controller/dal"
encryptionsvc "github.com/TBD54566975/ftl/backend/controller/encryption"
schemapb "github.com/TBD54566975/ftl/backend/protos/xyz/block/ftl/v1/schema"
"github.com/TBD54566975/ftl/backend/schema"
"github.com/TBD54566975/ftl/internal/cron"
"github.com/TBD54566975/ftl/internal/encryption"
"github.com/TBD54566975/ftl/internal/log"
"github.com/TBD54566975/ftl/internal/model"
)
Expand All @@ -21,18 +23,20 @@ type Service struct {
key model.ControllerKey
requestSource string
dal dal.DAL
encryption *encryptionsvc.Service
clock clock.Clock
}

func New(ctx context.Context, key model.ControllerKey, requestSource string, conn *sql.DB) *Service {
return NewForTesting(ctx, key, requestSource, *dal.New(conn), clock.New())
func New(ctx context.Context, key model.ControllerKey, requestSource string, encryption *encryptionsvc.Service, conn *sql.DB) *Service {
return NewForTesting(ctx, key, requestSource, encryption, *dal.New(conn), clock.New())
}

func NewForTesting(ctx context.Context, key model.ControllerKey, requestSource string, dal dal.DAL, clock clock.Clock) *Service {
func NewForTesting(ctx context.Context, key model.ControllerKey, requestSource string, encryption *encryptionsvc.Service, dal dal.DAL, clock clock.Clock) *Service {
svc := &Service{
key: key,
requestSource: requestSource,
dal: dal,
encryption: encryption,
clock: clock,
}
return svc
Expand Down Expand Up @@ -174,11 +178,16 @@ func (s *Service) scheduleCronJob(ctx context.Context, tx *dal.DAL, job model.Cr

logger.Tracef("Scheduling cron job %q async_call execution at %s", job.Key, nextAttemptForJob)
origin := &parentdal.AsyncOriginCron{CronJobKey: job.Key}
var request encryption.EncryptedColumn[encryption.AsyncSubKey]
err = s.encryption.Encrypt([]byte(`{}`), &request)
if err != nil {
return fmt.Errorf("failed to encrypt request for job %q: %w", job.Key, err)
}
id, err := tx.CreateAsyncCall(ctx, dal.CreateAsyncCallParams{
ScheduledAt: nextAttemptForJob,
Verb: schema.RefKey{Module: job.Verb.Module, Name: job.Verb.Name},
Origin: origin.String(),
Request: []byte(`{}`),
Request: request,
})
if err != nil {
return fmt.Errorf("failed to create async call for job %q: %w", job.Key, err)
Expand Down
23 changes: 11 additions & 12 deletions backend/controller/cronjobs/cronjobs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (
)

func TestNewCronJobsForModule(t *testing.T) {
t.Parallel()
ctx := log.ContextWithNewDefaultLogger(context.Background())
ctx, cancel := context.WithCancel(ctx)
t.Cleanup(cancel)
Expand Down Expand Up @@ -52,7 +51,7 @@ func TestNewCronJobsForModule(t *testing.T) {

// Progress so that start_time is valid
clk.Add(time.Second)
cjs := NewForTesting(ctx, key, "test.com", *dal, clk)
cjs := NewForTesting(ctx, key, "test.com", encryption, *dal, clk)
// All jobs need to be scheduled
expectUnscheduledJobs(t, dal, clk, 2)
unscheduledJobs, err := dal.GetUnscheduledCronJobs(ctx, clk.Now())
Expand All @@ -70,8 +69,8 @@ func TestNewCronJobsForModule(t *testing.T) {
for _, job := range jobsToCreate {
j, err := dal.GetCronJobByKey(ctx, job.Key)
assert.NoError(t, err)
assert.Equal(t, job.StartTime, j.StartTime)
assert.Equal(t, j.NextExecution, clk.Now().Add(time.Second))
assert.Equal(t, j.StartTime, job.StartTime)
assert.Equal(t, clk.Now().Add(time.Second), j.NextExecution)

p, err := dal.IsCronJobPending(ctx, job.Key, job.StartTime)
assert.NoError(t, err)
Expand All @@ -82,10 +81,10 @@ func TestNewCronJobsForModule(t *testing.T) {
for i, job := range jobsToCreate {
call, _, err := parentDAL.AcquireAsyncCall(ctx)
assert.NoError(t, err)
assert.Equal(t, call.Verb, job.Verb.ToRefKey())
assert.Equal(t, call.Origin.String(), fmt.Sprintf("cron:%s", job.Key))
assert.Equal(t, call.Request, []byte("{}"))
assert.Equal(t, call.QueueDepth, int64(len(jobsToCreate)-i)) // widdling down queue
assert.Equal(t, job.Verb.ToRefKey(), call.Verb)
assert.Equal(t, fmt.Sprintf("cron:%s", job.Key), call.Origin.String())
assert.Equal(t, []byte("{}"), call.Request)
assert.Equal(t, int64(len(jobsToCreate)-i), call.QueueDepth) // widdling down queue

p, err := dal.IsCronJobPending(ctx, job.Key, job.StartTime)
assert.NoError(t, err)
Expand Down Expand Up @@ -116,10 +115,10 @@ func TestNewCronJobsForModule(t *testing.T) {
for i, job := range jobsToCreate {
call, _, err := parentDAL.AcquireAsyncCall(ctx)
assert.NoError(t, err)
assert.Equal(t, call.Verb, job.Verb.ToRefKey())
assert.Equal(t, call.Origin.String(), fmt.Sprintf("cron:%s", job.Key))
assert.Equal(t, call.Request, []byte("{}"))
assert.Equal(t, call.QueueDepth, int64(len(jobsToCreate)-i)) // widdling down queue
assert.Equal(t, job.Verb.ToRefKey(), call.Verb)
assert.Equal(t, fmt.Sprintf("cron:%s", job.Key), call.Origin.String())
assert.Equal(t, []byte("{}"), call.Request)
assert.Equal(t, int64(len(jobsToCreate)-i), call.QueueDepth) // widdling down queue

assert.Equal(t, call.ScheduledAt, clk.Now())

Expand Down
4 changes: 2 additions & 2 deletions backend/controller/cronjobs/dal/internal/sql/models.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 9 additions & 2 deletions backend/controller/dal/fsm.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import (
func (d *DAL) StartFSMTransition(ctx context.Context, fsm schema.RefKey, instanceKey string, destinationState schema.RefKey, request []byte, encrypted bool, retryParams schema.RetryParams) (err error) {
var encryptedRequest encryption.EncryptedAsyncColumn
if encrypted {
encryptedRequest = encryption.EncryptedAsyncColumn(request)
encryptedRequest.Set(request)
} else {
err = d.encryption.Encrypt(request, &encryptedRequest)
if err != nil {
Expand Down Expand Up @@ -139,9 +139,16 @@ func (d *DAL) PopNextFSMEvent(ctx context.Context, fsm schema.RefKey, instanceKe
}
return optional.None[NextFSMEvent](), err
}

var decryptedRequest json.RawMessage
err = d.encryption.DecryptJSON(&next.Request, &decryptedRequest)
if err != nil {
return optional.None[NextFSMEvent](), fmt.Errorf("failed to decrypt FSM request: %w", err)
}

return optional.Some(NextFSMEvent{
DestinationState: next.NextState,
Request: next.Request,
Request: decryptedRequest,
RequestType: next.RequestType,
}), nil
}
Expand Down
4 changes: 2 additions & 2 deletions backend/controller/dal/internal/sql/models.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions backend/controller/dal/internal/sql/queries.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 6 additions & 1 deletion backend/controller/dal/pubsub.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,11 @@ func (d *DAL) ProgressSubscriptions(ctx context.Context, eventConsumptionDelay t
observability.PubSub.PropagationFailed(ctx, "GetNextEventForSubscription", subscription.Topic.Payload, nextCursor.Caller, subscriptionRef(subscription), optional.None[schema.RefKey]())
return 0, fmt.Errorf("failed to get next cursor: %w", libdal.TranslatePGError(err))
}
payload, ok := nextCursor.Payload.Get()
if !ok {
observability.PubSub.PropagationFailed(ctx, "GetNextEventForSubscription-->Payload.Get", subscription.Topic.Payload, nextCursor.Caller, subscriptionRef(subscription), optional.None[schema.RefKey]())
return 0, fmt.Errorf("could not find payload to progress subscription: %w", libdal.TranslatePGError(err))
}
nextCursorKey, ok := nextCursor.Event.Get()
if !ok {
observability.PubSub.PropagationFailed(ctx, "GetNextEventForSubscription-->Event.Get", subscription.Topic.Payload, nextCursor.Caller, subscriptionRef(subscription), optional.None[schema.RefKey]())
Expand Down Expand Up @@ -131,7 +136,7 @@ func (d *DAL) ProgressSubscriptions(ctx context.Context, eventConsumptionDelay t
ScheduledAt: time.Now(),
Verb: subscriber.Sink,
Origin: origin.String(),
Request: nextCursor.Payload, // already encrypted
Request: payload, // already encrypted
RemainingAttempts: subscriber.RetryAttempts,
Backoff: subscriber.Backoff,
MaxBackoff: subscriber.MaxBackoff,
Expand Down
25 changes: 15 additions & 10 deletions backend/controller/encryption/dal/dal.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,18 +83,18 @@ func (d *DAL) VerifyEncryptor(ctx context.Context, encryptor encryption.DataEncr
if err != nil {
return fmt.Errorf("failed to verify timeline subkey: %w", err)
}
if newTimeline != nil {
if newTimeline.Ok() {
needsUpdate = true
row.VerifyTimeline = optional.Some(newTimeline)
row.VerifyTimeline = newTimeline
}

newAsync, err := verifySubkey(encryptor, row.VerifyAsync)
if err != nil {
return fmt.Errorf("failed to verify async subkey: %w", err)
}
if newAsync != nil {
if newAsync.Ok() {
needsUpdate = true
row.VerifyAsync = optional.Some(newAsync)
row.VerifyAsync = newAsync
}

if !needsUpdate {
Expand All @@ -115,25 +115,30 @@ func (d *DAL) VerifyEncryptor(ctx context.Context, encryptor encryption.DataEncr

// verifySubkey checks if the subkey is set and if not, sets it to a verification string.
// returns (nil, nil) if verified and not changed
func verifySubkey[SK encryption.SubKey](encryptor encryption.DataEncryptor, encrypted optional.Option[encryption.EncryptedColumn[SK]]) (encryption.EncryptedColumn[SK], error) {
func verifySubkey[SK encryption.SubKey](
encryptor encryption.DataEncryptor,
encrypted optional.Option[encryption.EncryptedColumn[SK]],
) (optional.Option[encryption.EncryptedColumn[SK]], error) {
type EC = encryption.EncryptedColumn[SK]

verifyField, ok := encrypted.Get()
if !ok {
err := encryptor.Encrypt([]byte(verification), &verifyField)
if err != nil {
return nil, fmt.Errorf("failed to encrypt verification sanity string: %w", err)
return optional.None[EC](), fmt.Errorf("failed to encrypt verification sanity string: %w", err)
}
return verifyField, nil
return optional.Some(verifyField), nil
}

decrypted, err := encryptor.Decrypt(&verifyField)
if err != nil {
return nil, fmt.Errorf("failed to decrypt verification sanity string: %w", err)
return optional.None[EC](), fmt.Errorf("failed to decrypt verification sanity string: %w", err)
}

if string(decrypted) != verification {
return nil, fmt.Errorf("decrypted verification string does not match expected value")
return optional.None[EC](), fmt.Errorf("decrypted verification string does not match expected value")
}

// verified, no need to update
return nil, nil
return optional.None[EC](), nil
}
4 changes: 2 additions & 2 deletions backend/controller/leases/dal/internal/sql/models.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
-- migrate:up

ALTER TABLE fsm_next_event
ALTER COLUMN request TYPE encrypted_async;

-- migrate:down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
-- migrate:up

ALTER TABLE topic_events
ALTER COLUMN payload TYPE encrypted_async;

-- migrate:down
4 changes: 2 additions & 2 deletions internal/configuration/dal/internal/sql/models.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 8 additions & 11 deletions internal/encryption/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,28 +13,25 @@ var _ Encrypted = &EncryptedColumn[TimelineSubKey]{}
// EncryptedColumn is a type that represents an encrypted column.
//
// It can be used by sqlc to map to/from a bytea column in the database.
type EncryptedColumn[SK SubKey] []byte
type EncryptedColumn[SK SubKey] struct{ data []byte }

var _ driver.Valuer = &EncryptedColumn[TimelineSubKey]{}
var _ sql.Scanner = &EncryptedColumn[TimelineSubKey]{}

func (e *EncryptedColumn[SK]) SubKey() string { var sk SK; return sk.SubKey() }
func (e *EncryptedColumn[SK]) Bytes() []byte { return *e }
func (e *EncryptedColumn[SK]) Set(b []byte) { *e = b }
func (e *EncryptedColumn[SK]) Value() (driver.Value, error) {
return []byte(*e), nil
func (e *EncryptedColumn[SK]) SubKey() string { var sk SK; return sk.SubKey() }
func (e *EncryptedColumn[SK]) Bytes() []byte { return e.data }
func (e *EncryptedColumn[SK]) Set(b []byte) { e.data = b }
func (e EncryptedColumn[SK]) Value() (driver.Value, error) { return e.data, nil }
func (e *EncryptedColumn[SK]) GoString() string {
return fmt.Sprintf("EncryptedColumn[%s](%d bytes)", e.SubKey(), len(e.data))
}

func (e *EncryptedColumn[SK]) Scan(src interface{}) error {
if src == nil {
*e = nil
return nil
}
b, ok := src.([]byte)
if !ok {
return fmt.Errorf("expected []byte, got %T", src)
}
*e = b
e.data = b
return nil
}

Expand Down
Loading

0 comments on commit 02ed469

Please sign in to comment.