diff --git a/docs/openapi.yaml b/docs/openapi.yaml index 1c822f7..a6e396c 100644 --- a/docs/openapi.yaml +++ b/docs/openapi.yaml @@ -397,7 +397,7 @@ components: event: type: string description: Event type to subscribe to. - enum: [Enrollment, Authenticate, TokenUpdate, CheckOut] + enum: [Enrollment, Authenticate, TokenUpdate, CheckOut, Idle, IdleNotStartedSince] workflow: type: string description: Name of NanoCMD workflow. @@ -405,6 +405,9 @@ components: context: type: string description: Workflow-dependent context. + event_context: + type: string + description: Event-dependent context. JSONError: type: object properties: diff --git a/docs/operations-guide.md b/docs/operations-guide.md index 01ef433..5232948 100644 --- a/docs/operations-guide.md +++ b/docs/operations-guide.md @@ -153,7 +153,8 @@ Configures Event Subscriptions. Event Subscriptions start workflows for MDM even { "event": "Enrollment", "workflow": "io.micromdm.wf.example.v1", - "context": "string" + "context": "string", + "event_context": "string" } ``` @@ -164,8 +165,11 @@ The JSON keys are as follows: * `TokenUpdate`: when an enrollment sends a TokenUpdate MDM check-in message. * `Enrollment`: when an enrollment enrolls; i.e. the first TokenUpdate message. * `CheckOut`: when a device sends a CheckOut MDM check-in message. + * `Idle`: when an enrollment sends an Idle command response. + * `IdleNotStartedSince`: when an enrollment sends an Idle message and the associated workflow has not been started in the given number of seconds. The seconds are provided in the `event_context` string. * `workflow`: the name of the workflow. * `context`: optional context to give to the workflow when it starts. +* `event_context`: optional context to give to the event. #### FileVault profile template endpoint diff --git a/engine/engine.go b/engine/engine.go index ee8cfa5..c6fb8a7 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -5,6 +5,7 @@ import ( "context" "errors" "fmt" + "strconv" "sync" "time" @@ -188,6 +189,9 @@ func (e *Engine) StartWorkflow(ctx context.Context, name string, context []byte, if err = w.Start(ctx, ss); err != nil { return instanceID, fmt.Errorf("staring workflow: %w", err) } + if err = e.storage.RecordWorkflowStarted(ctx, startID, name, time.Now()); err != nil { + return instanceID, fmt.Errorf("recording workflow status: %w", err) + } logger.Debug( logkeys.InstanceID, instanceID, logkeys.Message, "starting workflow", @@ -288,6 +292,114 @@ func logAndError(err error, logger log.Logger, msg string) error { return fmt.Errorf("%s: %w", msg, err) } +// MDMIdleEvent is called when an MDM Report Results has an "Idle" status. +// MDMIdleEvent will dispatch workflow "Idle" events (for workflows that are +// configured for it) and will also start workflows for the "IdleNotStartedSince" +// event subscription type. +// Note: any other event subscription type starting workflows is not supported. +func (e *Engine) MDMIdleEvent(ctx context.Context, id string, raw []byte, mdmContext *workflow.MDMContext, eventAt time.Time) error { + logger := ctxlog.Logger(ctx, e.logger).With(logkeys.EnrollmentID, id) + + // dispatch the events to (only) the workflow events + event := &workflow.Event{EventFlag: workflow.EventIdle} + if err := e.dispatchEvents(ctx, id, event, mdmContext, false, true); err != nil { + logger.Info( + logkeys.Message, "idle event: dispatch workflow events", + logkeys.Event, event.EventFlag, + logkeys.Error, err, + ) + } + + if e.eventStorage == nil { + return nil + } + + subs, err := e.eventStorage.RetrieveEventSubscriptionsByEvent(ctx, workflow.EventIdleNotStartedSince) + if err != nil { + logger.Info( + logkeys.Message, "retrieving event subscriptions", + logkeys.Event, workflow.EventIdleNotStartedSince, + logkeys.Error, err, + ) + } + + if len(subs) < 1 { + return nil + } + + var wg sync.WaitGroup + event = &workflow.Event{EventFlag: workflow.EventIdleNotStartedSince} + for _, sub := range subs { + wg.Add(1) + go func(es *storage.EventSubscription) { + defer wg.Done() + + if es == nil { + return + } + + subLogger := logger.With( + logkeys.Event, workflow.EventIdleNotStartedSince, + logkeys.WorkflowName, es.Workflow, + ) + + // get the last time this workflow started for this + // enrollment ID for the workflow was subscribed to. + started, err := e.storage.RetrieveWorkflowStarted(ctx, id, es.Workflow) + if err != nil { + subLogger.Info( + logkeys.Message, "retrieving workflow status", + logkeys.Error, err, + ) + return + } + + // make sure we have a valid event context (time between runs) + if es.EventContext == "" { + subLogger.Info( + logkeys.Error, "event context is empty", + ) + return + } + sinceSeconds, err := strconv.Atoi(es.EventContext) + if err != nil { + subLogger.Info( + logkeys.Message, "converting event context to integer", + logkeys.Error, err, + ) + return + } else if sinceSeconds < 1 { + subLogger.Info( + logkeys.Error, "event context less than 1 second", + ) + return + } + + // check if we've run this workflow "recently" + if !eventAt.After(started.Add(time.Second * time.Duration(sinceSeconds))) { + // TODO: hide behind an "extra" debug flag? + // subLogger.Debug(logkeys.Message, "workflow not due yet") + return + } + + if instanceID, err := e.StartWorkflow(ctx, es.Workflow, []byte(es.Context), []string{id}, event, mdmContext); err != nil { + subLogger.Info( + logkeys.Message, "start workflow", + logkeys.InstanceID, instanceID, + logkeys.Error, err, + ) + } else { + subLogger.Debug( + logkeys.Message, "started workflow", + logkeys.InstanceID, instanceID, + ) + } + }(sub) + } + wg.Wait() + return nil +} + // MDMCommandResponseEvent receives MDM command responses. func (e *Engine) MDMCommandResponseEvent(ctx context.Context, id string, uuid string, raw []byte, mdmContext *workflow.MDMContext) error { logger := ctxlog.Logger(ctx, e.logger).With( @@ -374,14 +486,15 @@ func (e *Engine) MDMCommandResponseEvent(ctx context.Context, id string, uuid st // dispatchEvents dispatches MDM check-in events. // this includes event subscriptions (user configured) and workflow -// configurations. -func (e *Engine) dispatchEvents(ctx context.Context, id string, ev *workflow.Event, mdmCtx *workflow.MDMContext) error { +// configs. The bool subEV as true indicates to run subscription event +// workflows and wfEV indicates to run workflow-configured events. +func (e *Engine) dispatchEvents(ctx context.Context, id string, ev *workflow.Event, mdmCtx *workflow.MDMContext, subEV, wfEV bool) error { logger := ctxlog.Logger(ctx, e.logger).With( - "event", ev.EventFlag, + logkeys.Event, ev.EventFlag, logkeys.EnrollmentID, id, ) var wg sync.WaitGroup - if e.eventStorage != nil { + if subEV && e.eventStorage != nil { subs, err := e.eventStorage.RetrieveEventSubscriptionsByEvent(ctx, ev.EventFlag) if err != nil { logger.Info( @@ -411,23 +524,25 @@ func (e *Engine) dispatchEvents(ctx context.Context, id string, ev *workflow.Eve } } } - for _, w := range e.eventWorkflows(ev.EventFlag) { - wg.Add(1) - go func(w workflow.Workflow) { - defer wg.Done() - if err := w.Event(ctx, ev, id, mdmCtx); err != nil { - logger.Info( - logkeys.Message, "workflow event", - logkeys.WorkflowName, w.Name(), - logkeys.Error, err, - ) - } else { - logger.Debug( - logkeys.Message, "workflow event", - logkeys.WorkflowName, w.Name(), - ) - } - }(w) + if wfEV { + for _, w := range e.eventWorkflows(ev.EventFlag) { + wg.Add(1) + go func(w workflow.Workflow) { + defer wg.Done() + if err := w.Event(ctx, ev, id, mdmCtx); err != nil { + logger.Info( + logkeys.Message, "workflow event", + logkeys.WorkflowName, w.Name(), + logkeys.Error, err, + ) + } else { + logger.Debug( + logkeys.Message, "workflow event", + logkeys.WorkflowName, w.Name(), + ) + } + }(w) + } } wg.Wait() return nil @@ -485,12 +600,16 @@ func (e *Engine) MDMCheckinEvent(ctx context.Context, id string, checkin interfa if err := e.storage.CancelSteps(ctx, id, ""); err != nil { return logAndError(err, logger, "checkin event: cancel steps") } + // also clear out any workflow status for an id + if err := e.storage.ClearWorkflowStatus(ctx, id); err != nil { + return logAndError(err, logger, "checkin event: clearing workflow status") + } } for _, event := range events { - if err := e.dispatchEvents(ctx, id, event, mdmContext); err != nil { + if err := e.dispatchEvents(ctx, id, event, mdmContext, true, true); err != nil { logger.Info( logkeys.Message, "checkin event: dispatch events", - "event", event.EventFlag, + logkeys.Event, event.EventFlag, logkeys.Error, err, ) } diff --git a/engine/storage/diskv/diskv.go b/engine/storage/diskv/diskv.go index d0ac672..8043525 100644 --- a/engine/storage/diskv/diskv.go +++ b/engine/storage/diskv/diskv.go @@ -34,5 +34,10 @@ func New(path string) *Diskv { CacheSizeMax: 1024 * 1024, })), uuid.NewUUID(), + kvdiskv.NewBucket(diskv.New(diskv.Options{ + BasePath: filepath.Join(path, "engine", "wfstatus"), + Transform: flatTransform, + CacheSizeMax: 1024 * 1024, + })), )} } diff --git a/engine/storage/inmem/inmem.go b/engine/storage/inmem/inmem.go index af74429..b172088 100644 --- a/engine/storage/inmem/inmem.go +++ b/engine/storage/inmem/inmem.go @@ -18,5 +18,6 @@ func New() *InMem { kvmap.NewBucket(), kvmap.NewBucket(), uuid.NewUUID(), + kvmap.NewBucket(), )} } diff --git a/engine/storage/kv/event.go b/engine/storage/kv/event.go index 8d0f6c4..b151ad8 100644 --- a/engine/storage/kv/event.go +++ b/engine/storage/kv/event.go @@ -13,9 +13,10 @@ import ( ) const ( - keySfxEventFlag = ".flag" // contains a strconv integer - keySfxEventWorkflow = ".name" - keySfxEventContext = ".ctx" + keySfxEventFlag = ".flag" // contains a strconv integer + keySfxEventWorkflow = ".name" + keySfxEventContext = ".ctx" + keySfxEventEventContext = ".evctx" ) type kvEventSubscription struct { @@ -37,6 +38,9 @@ func (es *kvEventSubscription) set(ctx context.Context, b kv.Bucket, name string if len(es.Context) > 0 { esMap[name+keySfxEventContext] = []byte(es.Context) } + if len(es.EventContext) > 0 { + esMap[name+keySfxEventEventContext] = []byte(es.EventContext) + } return kv.SetMap(ctx, b, esMap) } @@ -62,13 +66,21 @@ func (es *kvEventSubscription) get(ctx context.Context, b kv.Bucket, name string es.Event = workflow.EventFlag(eventFlag).String() if ok, err := b.Has(ctx, name+keySfxEventContext); err != nil { return fmt.Errorf("checking event context: %w", err) - } else if !ok { - return nil + } else if ok { + if ctxBytes, err := b.Get(ctx, name+keySfxEventContext); err != nil { + return fmt.Errorf("getting event context: %w", err) + } else { + es.Context = string(ctxBytes) + } } - if ctxBytes, err := b.Get(ctx, name+keySfxEventContext); err != nil { - return fmt.Errorf("getting event context: %w", err) - } else { - es.Context = string(ctxBytes) + if ok, err := b.Has(ctx, name+keySfxEventEventContext); err != nil { + return fmt.Errorf("checking event event_context: %w", err) + } else if ok { + if evCtxBytes, err := b.Get(ctx, name+keySfxEventEventContext); err != nil { + return fmt.Errorf("getting event event_context: %w", err) + } else { + es.EventContext = string(evCtxBytes) + } } return nil } @@ -145,5 +157,6 @@ func (s *KV) DeleteEventSubscription(ctx context.Context, name string) error { name + keySfxEventFlag, name + keySfxEventWorkflow, name + keySfxEventContext, + name + keySfxEventEventContext, }) } diff --git a/engine/storage/kv/kv.go b/engine/storage/kv/kv.go index 9daf0b9..76dc6e8 100644 --- a/engine/storage/kv/kv.go +++ b/engine/storage/kv/kv.go @@ -5,6 +5,7 @@ import ( "context" "errors" "fmt" + "strings" "sync" "time" @@ -15,20 +16,22 @@ import ( // KV is a workflow engine storage backend using a key-value interface. type KV struct { - mu sync.RWMutex - stepStore kv.TraversingBucket - idCmdStore kv.TraversingBucket - eventStore kv.TraversingBucket - ider uuid.IDer + mu sync.RWMutex + stepStore kv.TraversingBucket + idCmdStore kv.TraversingBucket + eventStore kv.TraversingBucket + ider uuid.IDer + statusStore kv.TraversingBucket } // New creates a new key-value workflow engine storage backend. -func New(stepStore kv.TraversingBucket, idCmdStore kv.TraversingBucket, eventStore kv.TraversingBucket, ider uuid.IDer) *KV { +func New(stepStore kv.TraversingBucket, idCmdStore kv.TraversingBucket, eventStore kv.TraversingBucket, ider uuid.IDer, statusStore kv.TraversingBucket) *KV { return &KV{ - stepStore: stepStore, - idCmdStore: idCmdStore, - eventStore: eventStore, - ider: ider, + stepStore: stepStore, + idCmdStore: idCmdStore, + eventStore: eventStore, + ider: ider, + statusStore: statusStore, } } @@ -281,3 +284,51 @@ func (s *KV) CancelSteps(ctx context.Context, id, workflowName string) error { } return nil } + +func workflowStatusKey(id, workflowName string) string { + return id + "." + workflowName +} + +// RetrieveWorkflowStarted returns the last time a workflow was started for id. +func (s *KV) RetrieveWorkflowStarted(ctx context.Context, id, workflowName string) (time.Time, error) { + var started time.Time + if found, err := s.statusStore.Has(ctx, workflowStatusKey(id, workflowName)); err != nil { + return started, fmt.Errorf("status not found for id=%s workflow=%s: %w", id, workflowName, err) + } else if !found { + return started, nil + } + b, err := s.statusStore.Get(ctx, workflowStatusKey(id, workflowName)) + if err != nil { + return started, fmt.Errorf("getting workflow status: %w", err) + } + if err = started.UnmarshalText(b); err != nil { + err = fmt.Errorf("unmarshaling workflow status: %w", err) + } + return started, err +} + +// RecordWorkflowStarted stores the started time for workflowName for ids. +func (s *KV) RecordWorkflowStarted(ctx context.Context, ids []string, workflowName string, started time.Time) error { + b, err := started.MarshalText() + if err != nil { + return fmt.Errorf("marshaling workflow status: %w", err) + } + for _, id := range ids { + if err = s.statusStore.Set(ctx, workflowStatusKey(id, workflowName), b); err != nil { + return fmt.Errorf("setting workflow status for id=%s workflow=%s: %w", id, workflowName, err) + } + } + return nil +} + +// ClearWorkflowStatus removes all workflow start times for id. +func (s *KV) ClearWorkflowStatus(ctx context.Context, id string) error { + var toDelete []string + for k := range s.statusStore.Keys(nil) { + // very inefficient! this could be a large table + if strings.HasPrefix(k, id+".") { + toDelete = append(toDelete, k) + } + } + return kv.DeleteSlice(ctx, s.statusStore, toDelete) +} diff --git a/engine/storage/mysql/event.go b/engine/storage/mysql/event.go index 0f5a90b..db1d85b 100644 --- a/engine/storage/mysql/event.go +++ b/engine/storage/mysql/event.go @@ -18,9 +18,10 @@ func (s *MySQLStorage) RetrieveEventSubscriptions(ctx context.Context, names []s retEvents := make(map[string]*storage.EventSubscription) for _, event := range events { retEvents[event.EventName] = &storage.EventSubscription{ - Event: event.EventType, - Workflow: event.WorkflowName, - Context: event.Context.String, + Event: event.EventType, + Workflow: event.WorkflowName, + Context: event.Context.String, + EventContext: event.EventContext.String, } } return retEvents, nil @@ -36,9 +37,10 @@ func (s *MySQLStorage) RetrieveEventSubscriptionsByEvent(ctx context.Context, f var retEvents []*storage.EventSubscription for _, event := range events { retEvents = append(retEvents, &storage.EventSubscription{ - Event: event.EventType, - Workflow: event.WorkflowName, - Context: event.Context.String, + Event: event.EventType, + Workflow: event.WorkflowName, + Context: event.Context.String, + EventContext: event.EventContext.String, }) } return retEvents, nil @@ -51,17 +53,19 @@ func (s *MySQLStorage) StoreEventSubscription(ctx context.Context, name string, ctx, ` INSERT INTO wf_events - (event_name, event_type, workflow_name, context) + (event_name, event_type, workflow_name, event_context, context) VALUES - (?, ?, ?, ?) AS new + (?, ?, ?, ?, ?) AS new ON DUPLICATE KEY UPDATE workflow_name = new.workflow_name, event_type = new.event_type, + event_context = new.event_context, context = new.context;`, name, es.Event, es.Workflow, + sqlNullString(es.EventContext), sqlNullString(es.Context), ) return err diff --git a/engine/storage/mysql/mysql.go b/engine/storage/mysql/mysql.go index 8212aca..373f9fc 100644 --- a/engine/storage/mysql/mysql.go +++ b/engine/storage/mysql/mysql.go @@ -11,6 +11,8 @@ import ( "github.com/micromdm/nanocmd/engine/storage/mysql/sqlc" ) +const mySQLTimestampFormat = "2006-01-02 15:04:05" + // MySQLStorage implements a storage.AllStorage using MySQL. type MySQLStorage struct { db *sql.DB diff --git a/engine/storage/mysql/mysql_test.go b/engine/storage/mysql/mysql_test.go index ac2b7c2..cf45304 100644 --- a/engine/storage/mysql/mysql_test.go +++ b/engine/storage/mysql/mysql_test.go @@ -21,5 +21,15 @@ func TestMySQLStorage(t *testing.T) { t.Fatal(err) } + // to test using an existing DB/DSN: + // + // DELETE FROM id_commands; + // DELETE FROM steps; + // DELETE FROM wf_events; + // + // this clears out some left-over workflow starts that are + // intentionally left incomplete but are re-used when another + // test is completed + test.TestEngineStorage(t, func() storage.AllStorage { return s }) } diff --git a/engine/storage/mysql/query.sql b/engine/storage/mysql/query.sql index b3d6036..250231c 100644 --- a/engine/storage/mysql/query.sql +++ b/engine/storage/mysql/query.sql @@ -153,3 +153,17 @@ WHERE c.completed = 0 AND s.workflow_name = ?; +-- name: GetWorkflowLastStarted :one +SELECT + last_created_at +FROM + wf_status +WHERE + enrollment_id = ? AND + workflow_name = ?; + +-- name: ClearWorkflowStatus :exec +DELETE FROM + wf_status +WHERE + enrollment_id = ?; diff --git a/engine/storage/mysql/query_event.sql b/engine/storage/mysql/query_event.sql index ff161e8..c94fe3f 100644 --- a/engine/storage/mysql/query_event.sql +++ b/engine/storage/mysql/query_event.sql @@ -2,6 +2,7 @@ SELECT event_name, context, + event_context, workflow_name, event_type FROM @@ -12,6 +13,7 @@ WHERE -- name: GetEventsByType :many SELECT context, + event_context, workflow_name, event_type FROM diff --git a/engine/storage/mysql/schema.00001.sql b/engine/storage/mysql/schema.00001.sql new file mode 100644 index 0000000..b421243 --- /dev/null +++ b/engine/storage/mysql/schema.00001.sql @@ -0,0 +1,15 @@ +ALTER TABLE wf_events ADD COLUMN event_context MEDIUMTEXT NULL; +CREATE TABLE wf_status ( + enrollment_id VARCHAR(255) NOT NULL, + workflow_name VARCHAR(255) NOT NULL, + + last_created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + + INDEX (enrollment_id), + INDEX (workflow_name), + + PRIMARY KEY (enrollment_id, workflow_name) +); diff --git a/engine/storage/mysql/schema.sql b/engine/storage/mysql/schema.sql index 96213c4..e11171c 100644 --- a/engine/storage/mysql/schema.sql +++ b/engine/storage/mysql/schema.sql @@ -68,6 +68,7 @@ CREATE TABLE wf_events ( event_name VARCHAR(255) NOT NULL, context MEDIUMTEXT NULL, + event_context MEDIUMTEXT NULL, workflow_name VARCHAR(255) NOT NULL, event_type VARCHAR(63) NOT NULL, @@ -77,4 +78,19 @@ CREATE TABLE wf_events ( INDEX (event_type), PRIMARY KEY (event_name) -); \ No newline at end of file +); + +CREATE TABLE wf_status ( + enrollment_id VARCHAR(255) NOT NULL, + workflow_name VARCHAR(255) NOT NULL, + + last_created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + + INDEX (enrollment_id), + INDEX (workflow_name), + + PRIMARY KEY (enrollment_id, workflow_name) +); diff --git a/engine/storage/mysql/sqlc.yaml b/engine/storage/mysql/sqlc.yaml index ac70b4f..15c51aa 100644 --- a/engine/storage/mysql/sqlc.yaml +++ b/engine/storage/mysql/sqlc.yaml @@ -23,3 +23,6 @@ sql: go_type: type: "byte" slice: true + - column: "wf_status.last_created_at" + go_type: + type: "string" diff --git a/engine/storage/mysql/sqlc/db.go b/engine/storage/mysql/sqlc/db.go index 6a77d41..c5852e0 100644 --- a/engine/storage/mysql/sqlc/db.go +++ b/engine/storage/mysql/sqlc/db.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.21.0 +// sqlc v1.26.0 package sqlc diff --git a/engine/storage/mysql/sqlc/models.go b/engine/storage/mysql/sqlc/models.go index af8de69..28d3646 100644 --- a/engine/storage/mysql/sqlc/models.go +++ b/engine/storage/mysql/sqlc/models.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.21.0 +// sqlc v1.26.0 package sqlc @@ -45,8 +45,17 @@ type StepCommand struct { type WfEvent struct { EventName string Context sql.NullString + EventContext sql.NullString WorkflowName string EventType string CreatedAt sql.NullTime UpdatedAt sql.NullTime } + +type WfStatus struct { + EnrollmentID string + WorkflowName string + LastCreatedAt string + CreatedAt sql.NullTime + UpdatedAt sql.NullTime +} diff --git a/engine/storage/mysql/sqlc/query.sql.go b/engine/storage/mysql/sqlc/query.sql.go index cd7e2aa..128ac63 100644 --- a/engine/storage/mysql/sqlc/query.sql.go +++ b/engine/storage/mysql/sqlc/query.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.21.0 +// sqlc v1.26.0 // source: query.sql package sqlc @@ -11,6 +11,18 @@ import ( "strings" ) +const clearWorkflowStatus = `-- name: ClearWorkflowStatus :exec +DELETE FROM + wf_status +WHERE + enrollment_id = ? +` + +func (q *Queries) ClearWorkflowStatus(ctx context.Context, enrollmentID string) error { + _, err := q.db.ExecContext(ctx, clearWorkflowStatus, enrollmentID) + return err +} + const countOutstandingIDWorkflowStepCommands = `-- name: CountOutstandingIDWorkflowStepCommands :one SELECT COUNT(*), @@ -362,6 +374,28 @@ func (q *Queries) GetStepByID(ctx context.Context, id int64) (GetStepByIDRow, er return i, err } +const getWorkflowLastStarted = `-- name: GetWorkflowLastStarted :one +SELECT + last_created_at +FROM + wf_status +WHERE + enrollment_id = ? AND + workflow_name = ? +` + +type GetWorkflowLastStartedParams struct { + EnrollmentID string + WorkflowName string +} + +func (q *Queries) GetWorkflowLastStarted(ctx context.Context, arg GetWorkflowLastStartedParams) (string, error) { + row := q.db.QueryRowContext(ctx, getWorkflowLastStarted, arg.EnrollmentID, arg.WorkflowName) + var last_created_at string + err := row.Scan(&last_created_at) + return last_created_at, err +} + const removeIDCommandsByStepID = `-- name: RemoveIDCommandsByStepID :exec DELETE FROM id_commands diff --git a/engine/storage/mysql/sqlc/query_event.sql.go b/engine/storage/mysql/sqlc/query_event.sql.go index 37ed716..a444ef0 100644 --- a/engine/storage/mysql/sqlc/query_event.sql.go +++ b/engine/storage/mysql/sqlc/query_event.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.21.0 +// sqlc v1.26.0 // source: query_event.sql package sqlc @@ -15,6 +15,7 @@ const getEventsByNames = `-- name: GetEventsByNames :many SELECT event_name, context, + event_context, workflow_name, event_type FROM @@ -26,6 +27,7 @@ WHERE type GetEventsByNamesRow struct { EventName string Context sql.NullString + EventContext sql.NullString WorkflowName string EventType string } @@ -52,6 +54,7 @@ func (q *Queries) GetEventsByNames(ctx context.Context, names []string) ([]GetEv if err := rows.Scan( &i.EventName, &i.Context, + &i.EventContext, &i.WorkflowName, &i.EventType, ); err != nil { @@ -71,6 +74,7 @@ func (q *Queries) GetEventsByNames(ctx context.Context, names []string) ([]GetEv const getEventsByType = `-- name: GetEventsByType :many SELECT context, + event_context, workflow_name, event_type FROM @@ -81,6 +85,7 @@ WHERE type GetEventsByTypeRow struct { Context sql.NullString + EventContext sql.NullString WorkflowName string EventType string } @@ -94,7 +99,12 @@ func (q *Queries) GetEventsByType(ctx context.Context, eventType string) ([]GetE var items []GetEventsByTypeRow for rows.Next() { var i GetEventsByTypeRow - if err := rows.Scan(&i.Context, &i.WorkflowName, &i.EventType); err != nil { + if err := rows.Scan( + &i.Context, + &i.EventContext, + &i.WorkflowName, + &i.EventType, + ); err != nil { return nil, err } items = append(items, i) diff --git a/engine/storage/mysql/sqlc/query_worker.sql.go b/engine/storage/mysql/sqlc/query_worker.sql.go index 785e25f..3def03e 100644 --- a/engine/storage/mysql/sqlc/query_worker.sql.go +++ b/engine/storage/mysql/sqlc/query_worker.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.21.0 +// sqlc v1.26.0 // source: query_worker.sql package sqlc diff --git a/engine/storage/mysql/storage.go b/engine/storage/mysql/storage.go index 0ce8271..5e6f00b 100644 --- a/engine/storage/mysql/storage.go +++ b/engine/storage/mysql/storage.go @@ -5,6 +5,7 @@ import ( "database/sql" "errors" "fmt" + "strings" "time" "github.com/micromdm/nanocmd/engine/storage" @@ -241,3 +242,53 @@ func (s *MySQLStorage) CancelSteps(ctx context.Context, id, workflowName string) return nil }) } + +// RetrieveWorkflowStarted returns the last time a workflow was started for id. +func (s *MySQLStorage) RetrieveWorkflowStarted(ctx context.Context, id, workflowName string) (time.Time, error) { + ret, err := s.q.GetWorkflowLastStarted(ctx, sqlc.GetWorkflowLastStartedParams{EnrollmentID: id, WorkflowName: workflowName}) + if err != nil { + return time.Time{}, err + } + parsedTime, err := time.Parse(mySQLTimestampFormat, ret) + if err != nil { + return time.Time{}, fmt.Errorf("parsing time: %w", err) + } + return parsedTime, err +} + +// RecordWorkflowStarted stores the started time for workflowName for ids. +func (s *MySQLStorage) RecordWorkflowStarted(ctx context.Context, ids []string, workflowName string, started time.Time) error { + if len(ids) < 1 { + return errors.New("no id(s) provided") + } + const numFields = 3 + const subst = ", (?, ?, ?)" + fmt.Println(len(ids), len(ids)-1) + parms := make([]interface{}, len(ids)*numFields) + startedFormat := started.Format(mySQLTimestampFormat) + for i, id := range ids { + // these must match the SQL query, below + parms[i*numFields] = id + parms[i*numFields+1] = workflowName + parms[i*numFields+2] = startedFormat + } + val := subst[2:] + strings.Repeat(subst, len(ids)-1) + _, err := s.db.ExecContext( + ctx, + ` +INSERT INTO wf_status + (enrollment_id, workflow_name, last_created_at) +VALUES + `+val+` AS new +ON DUPLICATE KEY +UPDATE + last_created_at = new.last_created_at;`, + parms..., + ) + return err +} + +// ClearWorkflowStatus removes all workflow start times for id. +func (s *MySQLStorage) ClearWorkflowStatus(ctx context.Context, id string) error { + return s.q.ClearWorkflowStatus(ctx, id) +} diff --git a/engine/storage/storage.go b/engine/storage/storage.go index 642c264..32c3c4e 100644 --- a/engine/storage/storage.go +++ b/engine/storage/storage.go @@ -141,6 +141,17 @@ type StepResult struct { Commands []StepCommandResult } +type WorkflowStatusStorage interface { + // RetrieveWorkflowStarted returns the last time a workflow was started for id. + RetrieveWorkflowStarted(ctx context.Context, id, workflowName string) (time.Time, error) + + // RecordWorkflowStarted stores the started time for workflowName for ids. + RecordWorkflowStarted(ctx context.Context, ids []string, workflowName string, started time.Time) error + + // ClearWorkflowStatus removes all workflow start times for id. + ClearWorkflowStatus(ctx context.Context, id string) error +} + // Storage is the primary interface for workflow engine backend storage implementations. type Storage interface { // RetrieveCommandRequestType retrieves a command request type given id and uuid. @@ -172,6 +183,8 @@ type Storage interface { // workflow steps for the id. "NotUntil" (future) workflows steps // should also be canceled. CancelSteps(ctx context.Context, id, workflowName string) error + + WorkflowStatusStorage } // WorkerStorage is used by the workflow engine worker for async (scheduled) actions. @@ -207,9 +220,10 @@ type AllStorage interface { // EventSubscription is a user-configured subscription for starting workflows with optional context. type EventSubscription struct { - Event string `json:"event"` - Workflow string `json:"workflow"` - Context string `json:"context,omitempty"` + Event string `json:"event"` + Workflow string `json:"workflow"` + Context string `json:"context,omitempty"` + EventContext string `json:"event_context,omitempty"` } var ( diff --git a/engine/storage/test/event.go b/engine/storage/test/event.go index 881c8ee..25bccb8 100644 --- a/engine/storage/test/event.go +++ b/engine/storage/test/event.go @@ -12,9 +12,10 @@ func TestEventStorage(t *testing.T, store storage.EventSubscriptionStorage) { ctx := context.Background() evTest := &storage.EventSubscription{ - Event: "Enrollment", - Workflow: "wf", - Context: "ctx", + Event: "Enrollment", + Workflow: "wf", + Context: "ctx", + EventContext: "evCtx", } testEventData := func(t *testing.T, es *storage.EventSubscription) { @@ -38,6 +39,10 @@ func TestEventStorage(t *testing.T, store storage.EventSubscriptionStorage) { if have, want := es.Context, evTest.Context; have != want { t.Errorf("[context] have: %v, want: %v", have, want) } + + if have, want := es.EventContext, evTest.EventContext; have != want { + t.Errorf("[context] have: %v, want: %v", have, want) + } } t.Run("testdata", func(t *testing.T) { diff --git a/engine/storage/test/test.go b/engine/storage/test/test.go index 2234dbe..c7ab3cc 100644 --- a/engine/storage/test/test.go +++ b/engine/storage/test/test.go @@ -267,13 +267,40 @@ func mainTest(t *testing.T, s storage.AllStorage) { // t.Fatalf("invalid test data: step enqueueing with config: %v", err) // } - err := s.StoreStep(ctx, tStep.step, time.Now()) + // some backends may truncate the time and drop TZ + // so let's truncate ourselves and eliminate the TZ. + // since this value is used to compare the retrived value + // we'll stick with that. + storedAt := time.Now().UTC().Truncate(time.Second) + + err := s.StoreStep(ctx, tStep.step, storedAt) if tStep.shouldError && err == nil { t.Fatalf("StoreStep: expected error; step=%v", tStep.step) } else if !tStep.shouldError && err != nil { t.Fatalf("StoreStep: expected no error; step=%v err=%v", tStep.step, err) } + if err != nil && tStep.step != nil { + if len(tStep.step.IDs) > 0 { + err = s.RecordWorkflowStarted(ctx, tStep.step.IDs, tStep.step.WorkflowName, storedAt) + if err != nil { + t.Errorf("RecordWorkflowStarted: error for step=%s: %v", tStep.step.WorkflowName, err) + } + } + + for _, id := range tStep.step.IDs { + ts, err := s.RetrieveWorkflowStarted(ctx, id, tStep.step.WorkflowName) + if err != nil { + t.Fatalf("RetrieveWorkflowStarted: error for id=%s, step=%s err=%v", id, tStep.step.WorkflowName, err) + } + if ts.IsZero() { + t.Errorf("RetrieveWorkflowStarted: nil timestamp for id=%s, step=%s err=%v", id, tStep.step.WorkflowName, err) + } else if ts != storedAt { + t.Errorf("RetrieveWorkflowStarted: timestamp mismatch for id=%s, step=%s expected=%v got=%v", id, tStep.step.WorkflowName, storedAt, ts) + } + } + } + for _, tRespStep := range tStep.respSteps { t.Run("cmd-resp-"+tRespStep.testName, func(t *testing.T) { reqType, _, err := s.RetrieveCommandRequestType(ctx, tRespStep.id, tRespStep.resp.CommandUUID) diff --git a/log/logkeys/logkeys.go b/log/logkeys/logkeys.go index 6ee10e7..19420aa 100644 --- a/log/logkeys/logkeys.go +++ b/log/logkeys/logkeys.go @@ -22,4 +22,6 @@ const ( // a context-dependent numerical count/length of something GenericCount = "count" + + Event = "event" ) diff --git a/mdm/foss/dump.go b/mdm/foss/dump.go index 4ca42b1..f0ddfc7 100644 --- a/mdm/foss/dump.go +++ b/mdm/foss/dump.go @@ -3,6 +3,7 @@ package foss import ( "context" "io" + "time" "github.com/micromdm/nanocmd/workflow" ) @@ -23,6 +24,12 @@ func (d *MDMEventDumper) MDMCommandResponseEvent(ctx context.Context, id string, return d.next.MDMCommandResponseEvent(ctx, id, uuid, raw, mdmContext) } +// MDMIdleEvent is called when an MDM Report Results has an "Idle" status. +func (d *MDMEventDumper) MDMIdleEvent(ctx context.Context, id string, raw []byte, mdmContext *workflow.MDMContext, eventAt time.Time) error { + d.output.Write(append(raw, '\n')) + return d.next.MDMIdleEvent(ctx, id, raw, mdmContext, eventAt) +} + // MDMCheckinEvent processes the next eventer. func (d *MDMEventDumper) MDMCheckinEvent(ctx context.Context, id string, checkin interface{}, mdmContext *workflow.MDMContext) error { return d.next.MDMCheckinEvent(ctx, id, checkin, mdmContext) diff --git a/mdm/foss/process.go b/mdm/foss/process.go index c5f151c..27bdd61 100644 --- a/mdm/foss/process.go +++ b/mdm/foss/process.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "strings" + "time" "github.com/groob/plist" "github.com/micromdm/nanocmd/mdm" @@ -13,6 +14,9 @@ import ( type MDMCommandResponseEventer interface { MDMCommandResponseEvent(ctx context.Context, id string, uuid string, raw []byte, mdmContext *workflow.MDMContext) error + + // MDMIdleEvent is called when an MDM Report Results has an "Idle" status. + MDMIdleEvent(ctx context.Context, id string, raw []byte, mdmContext *workflow.MDMContext, eventAt time.Time) error } type MDMCheckinEventer interface { @@ -39,11 +43,11 @@ func processAcknowledgeEvent(ctx context.Context, e *AcknowledgeEvent, ev MDMCom if e == nil { return errors.New("empty acknowledge event") } - if e.Status == "Idle" || e.CommandUUID == "" { - return nil - } id, mdmContext := idAndContext(e.UDID, e.EnrollmentID, e.Params) - return ev.MDMCommandResponseEvent(ctx, id, e.CommandUUID, e.RawPayload, mdmContext) + if e.Status != "Idle" { + return ev.MDMCommandResponseEvent(ctx, id, e.CommandUUID, e.RawPayload, mdmContext) + } + return ev.MDMIdleEvent(ctx, id, e.RawPayload, mdmContext, time.Now()) } func processCheckinEvent(ctx context.Context, topic string, e *CheckinEvent, ev MDMCheckinEventer) error { diff --git a/mdm/foss/webhook_test.go b/mdm/foss/webhook_test.go index c3902ca..5315bb5 100644 --- a/mdm/foss/webhook_test.go +++ b/mdm/foss/webhook_test.go @@ -6,6 +6,7 @@ import ( "net/http/httptest" "os" "testing" + "time" "github.com/micromdm/nanocmd/log" "github.com/micromdm/nanocmd/mdm" @@ -36,6 +37,16 @@ func (r *eventRecorder) MDMCommandResponseEvent(ctx context.Context, id string, return nil } +func (r *eventRecorder) MDMIdleEvent(ctx context.Context, id string, raw []byte, mdmContext *workflow.MDMContext, _ time.Time) error { + r.events = append(r.events, event{ + resp: true, + id: id, + raw: raw, + ctx: mdmContext, + }) + return nil +} + func (r *eventRecorder) MDMCheckinEvent(ctx context.Context, id string, checkin interface{}, mdmContext *workflow.MDMContext) error { r.events = append(r.events, event{ resp: false, diff --git a/utils/kv/kv.go b/utils/kv/kv.go index 5dd857a..41055f3 100644 --- a/utils/kv/kv.go +++ b/utils/kv/kv.go @@ -43,3 +43,14 @@ func GetMap(ctx context.Context, b Bucket, keys []string) (map[string][]byte, er } return ret, nil } + +// DeleteSlice deletes s keys from b. +func DeleteSlice(ctx context.Context, b Bucket, s []string) error { + var err error + for _, i := range s { + if err = b.Delete(ctx, i); err != nil { + return fmt.Errorf("deleting %s: %w", i, err) + } + } + return nil +} diff --git a/workflow/event.go b/workflow/event.go index 3a34e43..b8531b1 100644 --- a/workflow/event.go +++ b/workflow/event.go @@ -22,6 +22,8 @@ const ( // continually arrive. EventEnrollment EventCheckOut + EventIdle + EventIdleNotStartedSince maxEventFlag ) @@ -41,6 +43,10 @@ func (e EventFlag) String() string { return "Enrollment" case EventCheckOut: return "CheckOut" + case EventIdle: + return "Idle" + case EventIdleNotStartedSince: + return "IdleNotStartedSince" default: return fmt.Sprintf("unknown event type: %d", e) } @@ -58,6 +64,10 @@ func EventFlagForString(s string) EventFlag { return EventEnrollment case "CheckOut": return EventCheckOut + case "Idle": + return EventIdle + case "IdleNotStartedSince": + return EventIdleNotStartedSince default: return 0 }