From a0094835f081433ffc12454e94cd477d495d9bb8 Mon Sep 17 00:00:00 2001 From: Tim Ross Date: Thu, 7 Nov 2024 12:14:09 -0500 Subject: [PATCH] Convert github auth code to use x/oauth2 Replaces the dependency on go-oidc/oauth2 with x/oauth2 in the github connector auth flows. One result of the switch is that it permitted a lot of simplification - the github client cache used by the auth server was removed entirely. --- lib/auth/auth.go | 46 ++-------------------- lib/auth/github.go | 96 ++++++++++++++++------------------------------ 2 files changed, 36 insertions(+), 106 deletions(-) diff --git a/lib/auth/auth.go b/lib/auth/auth.go index b17b25a6205d6..b9d91381cd51c 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -49,7 +49,6 @@ import ( "sync" "time" - "github.com/coreos/go-oidc/oauth2" "github.com/google/uuid" liblicense "github.com/gravitational/license" "github.com/gravitational/trace" @@ -494,7 +493,6 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) { Authority: cfg.Authority, AuthServiceName: cfg.AuthServiceName, ServerID: cfg.HostUUID, - githubClients: make(map[string]*githubClient), cancelFunc: cancelFunc, closeCtx: closeCtx, emitter: cfg.Emitter, @@ -886,10 +884,9 @@ type ReadOnlyCache = readonly.Cache // - same for users and their sessions // - checks public keys to see if they're signed by it (can be trusted or not) type Server struct { - lock sync.RWMutex - githubClients map[string]*githubClient - clock clockwork.Clock - bk backend.Backend + lock sync.RWMutex + clock clockwork.Clock + bk backend.Backend closeCtx context.Context cancelFunc context.CancelFunc @@ -7523,43 +7520,6 @@ func (k *authKeepAliver) Close() error { return nil } -// githubClient is internal structure that stores Github OAuth 2client and its config -type githubClient struct { - client *oauth2.Client - config oauth2.Config -} - -// oauth2ConfigsEqual returns true if the provided OAuth2 configs are equal -func oauth2ConfigsEqual(a, b oauth2.Config) bool { - if a.Credentials.ID != b.Credentials.ID { - return false - } - if a.Credentials.Secret != b.Credentials.Secret { - return false - } - if a.RedirectURL != b.RedirectURL { - return false - } - if len(a.Scope) != len(b.Scope) { - return false - } - for i := range a.Scope { - if a.Scope[i] != b.Scope[i] { - return false - } - } - if a.AuthURL != b.AuthURL { - return false - } - if a.TokenURL != b.TokenURL { - return false - } - if a.AuthMethod != b.AuthMethod { - return false - } - return true -} - // DefaultDNSNamesForRole returns default DNS names for the specified role. func DefaultDNSNamesForRole(role types.SystemRole) []string { if (types.SystemRoles{role}).IncludeAny( diff --git a/lib/auth/github.go b/lib/auth/github.go index d49c789f81349..aa4770e0a9add 100644 --- a/lib/auth/github.go +++ b/lib/auth/github.go @@ -32,9 +32,9 @@ import ( "strings" "time" - "github.com/coreos/go-oidc/oauth2" "github.com/gravitational/trace" "github.com/sirupsen/logrus" + "golang.org/x/oauth2" "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/constants" @@ -143,7 +143,7 @@ func (g *GithubConverter) UpdateGithubConnector(ctx context.Context, connector t // CreateGithubAuthRequest creates a new request for Github OAuth2 flow func (a *Server) CreateGithubAuthRequest(ctx context.Context, req types.GithubAuthRequest) (*types.GithubAuthRequest, error) { - connector, client, err := a.getGithubConnectorAndClient(ctx, req) + connector, err := a.getGithubConnector(ctx, req) if err != nil { return nil, trace.Wrap(err) } @@ -163,7 +163,10 @@ func (a *Server) CreateGithubAuthRequest(ctx context.Context, req types.GithubAu if err != nil { return nil, trace.Wrap(err) } - req.RedirectURL = client.AuthCodeURL(req.StateToken, "", "") + + config := newGithubOAuth2Config(connector) + + req.RedirectURL = config.AuthCodeURL(req.StateToken) log.WithFields(logrus.Fields{teleport.ComponentKey: "github"}).Debugf( "Redirect URL: %v.", req.RedirectURL) req.SetExpiry(a.GetClock().Now().UTC().Add(defaults.GithubAuthRequestTTL)) @@ -487,84 +490,49 @@ func validateGithubAuthCallbackHelper(ctx context.Context, m githubManager, diag return auth, nil } -func (a *Server) getGithubConnectorAndClient(ctx context.Context, request types.GithubAuthRequest) (types.GithubConnector, *oauth2.Client, error) { +func (a *Server) getGithubConnector(ctx context.Context, request types.GithubAuthRequest) (types.GithubConnector, error) { if request.SSOTestFlow { if request.ConnectorSpec == nil { - return nil, nil, trace.BadParameter("ConnectorSpec cannot be nil when SSOTestFlow is true") + return nil, trace.BadParameter("ConnectorSpec cannot be nil when SSOTestFlow is true") } if request.ConnectorID == "" { - return nil, nil, trace.BadParameter("ConnectorID cannot be empty") + return nil, trace.BadParameter("ConnectorID cannot be empty") } // stateless test flow connector, err := services.NewGithubConnector(request.ConnectorID, *request.ConnectorSpec) if err != nil { - return nil, nil, trace.Wrap(err) - } - - // construct client directly. - config := newGithubOAuth2Config(connector) - client, err := oauth2.NewClient(http.DefaultClient, config) - if err != nil { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } - return connector, client, nil + return connector, nil } // regular execution flow connector, err := a.GetGithubConnector(ctx, request.ConnectorID, true) if err != nil { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } connector, err = services.InitGithubConnector(connector) if err != nil { - return nil, nil, trace.Wrap(err) - } - - client, err := a.getGithubOAuth2Client(connector) - if err != nil { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } - return connector, client, nil + return connector, nil } func newGithubOAuth2Config(connector types.GithubConnector) oauth2.Config { return oauth2.Config{ - Credentials: oauth2.ClientCredentials{ - ID: connector.GetClientID(), - Secret: connector.GetClientSecret(), + ClientID: connector.GetClientID(), + ClientSecret: connector.GetClientSecret(), + RedirectURL: connector.GetRedirectURL(), + Scopes: GithubScopes, + Endpoint: oauth2.Endpoint{ + AuthURL: fmt.Sprintf("%s/%s", connector.GetEndpointURL(), GithubAuthPath), + TokenURL: fmt.Sprintf("%s/%s", connector.GetEndpointURL(), GithubTokenPath), }, - RedirectURL: connector.GetRedirectURL(), - Scope: GithubScopes, - AuthURL: fmt.Sprintf("%s/%s", connector.GetEndpointURL(), GithubAuthPath), - TokenURL: fmt.Sprintf("%s/%s", connector.GetEndpointURL(), GithubTokenPath), - } -} - -func (a *Server) getGithubOAuth2Client(connector types.GithubConnector) (*oauth2.Client, error) { - config := newGithubOAuth2Config(connector) - - a.lock.Lock() - defer a.lock.Unlock() - - cachedClient, ok := a.githubClients[connector.GetName()] - if ok && oauth2ConfigsEqual(cachedClient.config, config) { - return cachedClient.client, nil } - - delete(a.githubClients, connector.GetName()) - client, err := oauth2.NewClient(http.DefaultClient, config) - if err != nil { - return nil, trace.Wrap(err) - } - a.githubClients[connector.GetName()] = &githubClient{ - client: client, - config: config, - } - return client, nil } // ValidateGithubAuthCallback validates Github auth callback redirect @@ -584,19 +552,19 @@ func (a *Server) validateGithubAuthCallback(ctx context.Context, diagCtx *SSODia // optional parameter: error_description errDesc := q.Get("error_description") - oauthErr := trace.OAuth2(oauth2.ErrorInvalidRequest, errParam, q) + oauthErr := trace.OAuth2("invalid_request", errParam, q) return nil, trace.WithUserMessage(oauthErr, "GitHub returned error: %v [%v]", errDesc, errParam) } code := q.Get("code") if code == "" { - oauthErr := trace.OAuth2(oauth2.ErrorInvalidRequest, "code query param must be set", q) + oauthErr := trace.OAuth2("invalid_request", "code query param must be set", q) return nil, trace.WithUserMessage(oauthErr, "Invalid parameters received from GitHub.") } stateToken := q.Get("state") if stateToken == "" { - oauthErr := trace.OAuth2(oauth2.ErrorInvalidRequest, "missing state query param", q) + oauthErr := trace.OAuth2("invalid_request", "missing state query param", q) return nil, trace.WithUserMessage(oauthErr, "Invalid parameters received from GitHub.") } diagCtx.RequestID = stateToken @@ -607,7 +575,7 @@ func (a *Server) validateGithubAuthCallback(ctx context.Context, diagCtx *SSODia } diagCtx.Info.TestFlow = req.SSOTestFlow - connector, client, err := a.getGithubConnectorAndClient(ctx, *req) + connector, err := a.getGithubConnector(ctx, *req) if err != nil { return nil, trace.Wrap(err, "Failed to get GitHub connector and client.") } @@ -615,7 +583,7 @@ func (a *Server) validateGithubAuthCallback(ctx context.Context, diagCtx *SSODia diagCtx.Info.GithubTeamsToRoles = connector.GetTeamsToRoles() logger.Debugf("Connector %q teams to logins: %v, roles: %v", connector.GetName(), connector.GetTeamsToLogins(), connector.GetTeamsToRoles()) - userResp, teamsResp, err := a.getGithubUserAndTeams(ctx, connector, code, client, diagCtx, logger) + userResp, teamsResp, err := a.getGithubUserAndTeams(ctx, connector, code, diagCtx, logger) if err != nil { return nil, trace.Wrap(err) } @@ -752,7 +720,6 @@ func (a *Server) getGithubUserAndTeams( ctx context.Context, connector types.GithubConnector, code string, - client *oauth2.Client, diagCtx *SSODiagContext, logger *logrus.Entry, ) (*GithubUserResponse, []GithubTeamResponse, error) { @@ -762,20 +729,23 @@ func (a *Server) getGithubUserAndTeams( return a.GithubUserAndTeamsOverride() } + config := newGithubOAuth2Config(connector) + // exchange the authorization code received by the callback for an access token - token, err := client.RequestToken(oauth2.GrantTypeAuthCode, code) + token, err := config.Exchange(ctx, code) if err != nil { return nil, nil, trace.Wrap(err, "Requesting GitHub OAuth2 token failed.") } + scope := token.Extra("scope").(string) diagCtx.Info.GithubTokenInfo = &types.GithubTokenInfo{ TokenType: token.TokenType, - Expires: int64(token.Expires), - Scope: token.Scope, + Expires: token.ExpiresIn, + Scope: scope, } logger.Debugf("Obtained OAuth2 token: Type=%v Expires=%v Scope=%v.", - token.TokenType, token.Expires, token.Scope) + token.TokenType, token.ExpiresIn, scope) // Get the Github organizations the user is a member of so we don't // make unnecessary API requests