diff --git a/engine/storage/mysql/query.sql b/engine/storage/mysql/query.sql index 88fb22c..8b67168 100644 --- a/engine/storage/mysql/query.sql +++ b/engine/storage/mysql/query.sql @@ -128,6 +128,16 @@ WHERE step_id = ? AND completed != 0; +-- name: LockIDCommandsByStepID :exec +SELECT + command_uuid +FROM + id_commands +WHERE + enrollment_id = ? AND + step_id = ? +FOR UPDATE; + -- name: RemoveIDCommandsByStepID :exec DELETE FROM id_commands diff --git a/engine/storage/mysql/sqlc/query.sql.go b/engine/storage/mysql/sqlc/query.sql.go index 8b608dd..d1fae99 100644 --- a/engine/storage/mysql/sqlc/query.sql.go +++ b/engine/storage/mysql/sqlc/query.sql.go @@ -396,6 +396,27 @@ func (q *Queries) GetWorkflowLastStarted(ctx context.Context, arg GetWorkflowLas return last_created_unix, err } +const lockIDCommandsByStepID = `-- name: LockIDCommandsByStepID :exec +SELECT + command_uuid +FROM + id_commands +WHERE + enrollment_id = ? AND + step_id = ? +FOR UPDATE +` + +type LockIDCommandsByStepIDParams struct { + EnrollmentID string + StepID int64 +} + +func (q *Queries) LockIDCommandsByStepID(ctx context.Context, arg LockIDCommandsByStepIDParams) error { + _, err := q.db.ExecContext(ctx, lockIDCommandsByStepID, arg.EnrollmentID, arg.StepID) + return err +} + const removeIDCommandsByStepID = `-- name: RemoveIDCommandsByStepID :exec DELETE FROM id_commands diff --git a/engine/storage/mysql/storage.go b/engine/storage/mysql/storage.go index 7cd9e99..ef954dc 100644 --- a/engine/storage/mysql/storage.go +++ b/engine/storage/mysql/storage.go @@ -97,7 +97,13 @@ func (s *MySQLStorage) StoreCommandResponseAndRetrieveCompletedStep(ctx context. Commands: []storage.StepCommandResult{*sc}, } - // TODO: select ... for update on id commands? + err = qtx.LockIDCommandsByStepID(ctx, sqlc.LockIDCommandsByStepIDParams{ + EnrollmentID: id, + StepID: cmdCt.StepID, + }) + if err != nil { + return fmt.Errorf("lock commands by step by id (%d): %w", cmdCt.StepID, err) + } cmdR, err := qtx.GetIDCommandsByStepID(ctx, sqlc.GetIDCommandsByStepIDParams{ EnrollmentID: id,