diff --git a/tool/teleport/common/wait.go b/tool/teleport/common/wait.go index 3522c37dc935c..e4d8df8553923 100644 --- a/tool/teleport/common/wait.go +++ b/tool/teleport/common/wait.go @@ -21,6 +21,7 @@ package common import ( "context" "errors" + "fmt" "log/slog" "net" "os" @@ -29,7 +30,6 @@ import ( "time" "github.com/gravitational/trace" - log "github.com/sirupsen/logrus" "github.com/gravitational/teleport/api/utils/retryutils" "github.com/gravitational/teleport/lib/utils" @@ -92,6 +92,8 @@ func waitNoResolve(ctx context.Context, domain string, period, timeout time.Dura if timeout == 0 { return trace.BadParameter("no timeout provided") } + log := slog.With("domain", domain) + log.InfoContext(ctx, "waiting until the domain stops resolving to ensure that every auth server running the previous major version has been updated/terminated") var err error ctx, cancel := context.WithTimeout(ctx, timeout) @@ -124,44 +126,60 @@ func waitNoResolve(ctx context.Context, domain string, period, timeout time.Dura return trace.Wrap(err) case <-periodic.Next(): - exit, err = checkDomainNoResolve(domain) + exit, err = checkDomainNoResolve(ctx, domain, log) if err != nil { return trace.Wrap(err) } } } - log.Info("no endpoints found, exiting with success code") + log.InfoContext(ctx, "no endpoints found, exiting with success code") return nil } -func checkDomainNoResolve(domainName string) (exit bool, err error) { - endpoints, err := countEndpoints(domainName) +func checkDomainNoResolve(ctx context.Context, domainName string, log *slog.Logger) (exit bool, err error) { + endpoints, err := resolveEndpoints(domainName) if err != nil { var dnsErr *net.DNSError if !errors.As(trace.Unwrap(err), &dnsErr) { - log.Errorf("unexpected error when resolving domain %s : %s", domainName, err) + log.ErrorContext(ctx, "unexpected error when resolving domain", "error", err) return false, trace.Wrap(err) } - if dnsErr.Temporary() { - log.Warnf("temporary error when resolving domain %s : %s", domainName, err) - return false, nil - } + if dnsErr.IsNotFound { - log.Infof("domain %s not found", domainName) + log.InfoContext(ctx, "domain not found") return true, nil } - log.Errorf("error when resolving domain %s : %s", domainName, err) + + // Creating a new logger because the linter doesn't want both key/value and slog.Attr in the same log write. + log := log.With(slog.Group("dns_error", + "name", dnsErr.Name, + "server", dnsErr.Server, + "is_timeout", dnsErr.IsTimeout, + "is_temporary", dnsErr.IsTemporary, + "is_not_found", dnsErr.IsNotFound, + // Logging the error type can help understanding where the error comes from + "wrapped_error_type", fmt.Sprintf("%T", dnsErr.Unwrap()), + )) + if dnsErr.Temporary() { + log.WarnContext(ctx, "temporary error when resolving domain", "error", err) + return false, nil + } + log.ErrorContext(ctx, "error when resolving domain", "error", err) return false, nil } - log.Infof("%d endpoints found when resolving domain %s", endpoints, domainName) - return endpoints == 0, nil + if len(endpoints) == 0 { + log.InfoContext(ctx, "domain found and resolution returned no endpoints") + return true, nil + } + log.InfoContext(ctx, "endpoints found when resolving domain", "endpoints", endpoints) + return false, nil } -func countEndpoints(serviceName string) (int, error) { +func resolveEndpoints(serviceName string) ([]net.IP, error) { ips, err := net.LookupIP(serviceName) if err != nil { - return 0, trace.Wrap(err) + return nil, trace.Wrap(err) } - return len(ips), nil + return ips, nil }