Skip to content

Commit

Permalink
prevent data races around setting nodekey in extension (#2039)
Browse files Browse the repository at this point in the history
  • Loading branch information
zackattack01 authored Jan 9, 2025
1 parent 30b4511 commit 9f071da
Showing 1 changed file with 30 additions and 5 deletions.
35 changes: 30 additions & 5 deletions pkg/osquery/extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,12 @@ func (e *Extension) Enroll(ctx context.Context) (string, bool, error) {
}

func (e *Extension) enrolled() bool {
return e.NodeKey != ""
// grab a reference to the existing nodekey to prevent data races with any re-enrollments
e.enrollMutex.Lock()
nodeKey := e.NodeKey
e.enrollMutex.Unlock()

return nodeKey != ""
}

// RequireReenroll clears the existing node key information, ensuring that the
Expand Down Expand Up @@ -548,7 +553,12 @@ var reenrollmentInvalidErr = errors.New("enrollment invalid, reenrollment invali

// Helper to allow for a single attempt at re-enrollment
func (e *Extension) generateConfigsWithReenroll(ctx context.Context, reenroll bool) (string, error) {
config, invalid, err := e.serviceClient.RequestConfig(ctx, e.NodeKey)
// grab a reference to the existing nodekey to prevent data races with any re-enrollments
e.enrollMutex.Lock()
nodeKey := e.NodeKey
e.enrollMutex.Unlock()

config, invalid, err := e.serviceClient.RequestConfig(ctx, nodeKey)
switch {
case errors.Is(err, service.ErrDeviceDisabled{}):
uninstall.Uninstall(ctx, e.knapsack, true)
Expand Down Expand Up @@ -796,7 +806,12 @@ func (e *Extension) writeBufferedLogsForType(typ logger.LogType) error {

// Helper to allow for a single attempt at re-enrollment
func (e *Extension) writeLogsWithReenroll(ctx context.Context, typ logger.LogType, logs []string, reenroll bool) error {
_, _, invalid, err := e.serviceClient.PublishLogs(ctx, e.NodeKey, typ, logs)
// grab a reference to the existing nodekey to prevent data races with any re-enrollments
e.enrollMutex.Lock()
nodeKey := e.NodeKey
e.enrollMutex.Unlock()

_, _, invalid, err := e.serviceClient.PublishLogs(ctx, nodeKey, typ, logs)

if errors.Is(err, service.ErrDeviceDisabled{}) {
uninstall.Uninstall(ctx, e.knapsack, true)
Expand Down Expand Up @@ -912,8 +927,13 @@ func (e *Extension) getQueriesWithReenroll(ctx context.Context, reenroll bool) (
ctx, span := traces.StartSpan(ctx)
defer span.End()

// grab a reference to the existing nodekey to prevent data races with any re-enrollments
e.enrollMutex.Lock()
nodeKey := e.NodeKey
e.enrollMutex.Unlock()

// Note that we set invalid two ways -- in the return, and via isNodeinvaliderr
queries, invalid, err := e.serviceClient.RequestQueries(ctx, e.NodeKey)
queries, invalid, err := e.serviceClient.RequestQueries(ctx, nodeKey)

switch {
case errors.Is(err, service.ErrDeviceDisabled{}):
Expand Down Expand Up @@ -971,7 +991,12 @@ func (e *Extension) writeResultsWithReenroll(ctx context.Context, results []dist
ctx, span := traces.StartSpan(ctx)
defer span.End()

_, _, invalid, err := e.serviceClient.PublishResults(ctx, e.NodeKey, results)
// grab a reference to the existing nodekey to prevent data races with any re-enrollments
e.enrollMutex.Lock()
nodeKey := e.NodeKey
e.enrollMutex.Unlock()

_, _, invalid, err := e.serviceClient.PublishResults(ctx, nodeKey, results)
switch {
case errors.Is(err, service.ErrDeviceDisabled{}):
uninstall.Uninstall(ctx, e.knapsack, true)
Expand Down

0 comments on commit 9f071da

Please sign in to comment.