From fdecdbbd6ffaed1e12e7bfa986a527dcf7e24307 Mon Sep 17 00:00:00 2001 From: Tiago Silva Date: Mon, 11 Nov 2024 12:01:17 +0000 Subject: [PATCH] handle code review comments --- .../authorizedkeys/authorized_keys.go | 23 ++++++++----------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/lib/secretsscanner/authorizedkeys/authorized_keys.go b/lib/secretsscanner/authorizedkeys/authorized_keys.go index a3f1e862435bc..43351dd0e43b8 100644 --- a/lib/secretsscanner/authorizedkeys/authorized_keys.go +++ b/lib/secretsscanner/authorizedkeys/authorized_keys.go @@ -222,9 +222,9 @@ func (w *Watcher) start(ctx context.Context) error { var requiresReportToExtendTTL bool for { - err := w.fetchAndReportAuthorizedKeys(ctx, fileWatcher, requiresReportToExtendTTL) + keysReported, err := w.fetchAndReportAuthorizedKeys(ctx, fileWatcher, requiresReportToExtendTTL) expirationTimerInterval := maxReSendInterval - if err != nil && !errors.Is(err, errKeysNotReported) { + if err != nil { w.logger.WarnContext(ctx, "Failed to report authorized keys", "error", err) expirationTimerInterval = maxInitialDelay } @@ -233,7 +233,7 @@ func (w *Watcher) start(ctx context.Context) error { requiresReportToExtendTTL = false // If the keys were reported, reset the expiration timer. - if !errors.Is(err, errKeysNotReported) { + if keysReported { resetTimer(expirationTimer, expirationTimerInterval) } @@ -299,20 +299,17 @@ func (w *Watcher) fetchAuthorizedKeys( return keys, nil } -// errKeysNotReported is returned when the keys are not reported. -var errKeysNotReported = errors.New("keys not reported") - // fetchAndReportAuthorizedKeys fetches the authorized keys from the system and reports them to the cluster. func (w *Watcher) fetchAndReportAuthorizedKeys( ctx context.Context, fileWatcher *fsnotify.Watcher, requiresReportToExtendTTL bool, -) (returnErr error) { +) (reported bool, returnErr error) { // fetchAuthorizedKeys fetches the authorized keys from the system. keys, err := w.fetchAuthorizedKeys(ctx, fileWatcher) if err != nil { - return trace.Wrap(err) + return false, trace.Wrap(err) } // for the given keys, sort the key names and return them. @@ -323,7 +320,7 @@ func (w *Watcher) fetchAndReportAuthorizedKeys( // If the cluster does not require a report to extend the TTL of the authorized keys, // and the key names are the same, there is no need to report the keys. if !requiresReportToExtendTTL && slices.Equal(w.keyNames, keyNames) { - return errKeysNotReported + return false, nil } // Report the authorized keys to the cluster. @@ -331,7 +328,7 @@ func (w *Watcher) fetchAndReportAuthorizedKeys( stream, err := w.client.AccessGraphSecretsScannerClient().ReportAuthorizedKeys(ctx) if err != nil { - return trace.Wrap(err) + return false, trace.Wrap(err) } defer func() { if closeErr := stream.CloseSend(); closeErr != nil && !errors.Is(closeErr, io.EOF) { @@ -359,16 +356,16 @@ func (w *Watcher) fetchAndReportAuthorizedKeys( Operation: accessgraphsecretsv1pb.OperationType_OPERATION_TYPE_ADD, }, ); err != nil { - return trace.Wrap(err) + return false, trace.Wrap(err) } } if err := stream.Send( &accessgraphsecretsv1pb.ReportAuthorizedKeysRequest{Operation: accessgraphsecretsv1pb.OperationType_OPERATION_TYPE_SYNC}, ); err != nil { - return trace.Wrap(err) + return false, trace.Wrap(err) } - return nil + return true, nil } func (w *Watcher) parseAuthorizedKeysFile(ctx context.Context, u user.User, authorizedKeysPath string) ([]*accessgraphsecretsv1pb.AuthorizedKey, error) {