Skip to content

Commit

Permalink
Convert github auth code to use x/oauth2
Browse files Browse the repository at this point in the history
Replaces the dependency on go-oidc/oauth2 with x/oauth2 in the
github connector auth flows. One result of the switch is that it
permitted a lot of simplification - the github client cache used
by the auth server was removed entirely.
  • Loading branch information
rosstimothy committed Nov 18, 2024
1 parent 79a1680 commit fe74173
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 106 deletions.
46 changes: 3 additions & 43 deletions lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ import (
"sync"
"time"

"github.com/coreos/go-oidc/oauth2"
"github.com/google/uuid"
liblicense "github.com/gravitational/license"
"github.com/gravitational/trace"
Expand Down Expand Up @@ -494,7 +493,6 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) {
Authority: cfg.Authority,
AuthServiceName: cfg.AuthServiceName,
ServerID: cfg.HostUUID,
githubClients: make(map[string]*githubClient),
cancelFunc: cancelFunc,
closeCtx: closeCtx,
emitter: cfg.Emitter,
Expand Down Expand Up @@ -886,10 +884,9 @@ type ReadOnlyCache = readonly.Cache
// - same for users and their sessions
// - checks public keys to see if they're signed by it (can be trusted or not)
type Server struct {
lock sync.RWMutex
githubClients map[string]*githubClient
clock clockwork.Clock
bk backend.Backend
lock sync.RWMutex
clock clockwork.Clock
bk backend.Backend

closeCtx context.Context
cancelFunc context.CancelFunc
Expand Down Expand Up @@ -7523,43 +7520,6 @@ func (k *authKeepAliver) Close() error {
return nil
}

// githubClient is internal structure that stores Github OAuth 2client and its config
type githubClient struct {
client *oauth2.Client
config oauth2.Config
}

// oauth2ConfigsEqual returns true if the provided OAuth2 configs are equal
func oauth2ConfigsEqual(a, b oauth2.Config) bool {
if a.Credentials.ID != b.Credentials.ID {
return false
}
if a.Credentials.Secret != b.Credentials.Secret {
return false
}
if a.RedirectURL != b.RedirectURL {
return false
}
if len(a.Scope) != len(b.Scope) {
return false
}
for i := range a.Scope {
if a.Scope[i] != b.Scope[i] {
return false
}
}
if a.AuthURL != b.AuthURL {
return false
}
if a.TokenURL != b.TokenURL {
return false
}
if a.AuthMethod != b.AuthMethod {
return false
}
return true
}

// DefaultDNSNamesForRole returns default DNS names for the specified role.
func DefaultDNSNamesForRole(role types.SystemRole) []string {
if (types.SystemRoles{role}).IncludeAny(
Expand Down
99 changes: 36 additions & 63 deletions lib/auth/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ import (
"strings"
"time"

"github.com/coreos/go-oidc/oauth2"
"github.com/gravitational/trace"
"github.com/sirupsen/logrus"
"golang.org/x/oauth2"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/constants"
Expand Down Expand Up @@ -143,7 +143,7 @@ func (g *GithubConverter) UpdateGithubConnector(ctx context.Context, connector t

// CreateGithubAuthRequest creates a new request for Github OAuth2 flow
func (a *Server) CreateGithubAuthRequest(ctx context.Context, req types.GithubAuthRequest) (*types.GithubAuthRequest, error) {
connector, client, err := a.getGithubConnectorAndClient(ctx, req)
connector, err := a.getGithubConnector(ctx, req)
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -163,7 +163,10 @@ func (a *Server) CreateGithubAuthRequest(ctx context.Context, req types.GithubAu
if err != nil {
return nil, trace.Wrap(err)
}
req.RedirectURL = client.AuthCodeURL(req.StateToken, "", "")

config := newGithubOAuth2Config(connector)

req.RedirectURL = config.AuthCodeURL(req.StateToken)
log.WithFields(logrus.Fields{teleport.ComponentKey: "github"}).Debugf(
"Redirect URL: %v.", req.RedirectURL)
req.SetExpiry(a.GetClock().Now().UTC().Add(defaults.GithubAuthRequestTTL))
Expand Down Expand Up @@ -487,86 +490,51 @@ func validateGithubAuthCallbackHelper(ctx context.Context, m githubManager, diag
return auth, nil
}

func (a *Server) getGithubConnectorAndClient(ctx context.Context, request types.GithubAuthRequest) (types.GithubConnector, *oauth2.Client, error) {
func (a *Server) getGithubConnector(ctx context.Context, request types.GithubAuthRequest) (types.GithubConnector, error) {
if request.SSOTestFlow {
if request.ConnectorSpec == nil {
return nil, nil, trace.BadParameter("ConnectorSpec cannot be nil when SSOTestFlow is true")
return nil, trace.BadParameter("ConnectorSpec cannot be nil when SSOTestFlow is true")
}

if request.ConnectorID == "" {
return nil, nil, trace.BadParameter("ConnectorID cannot be empty")
return nil, trace.BadParameter("ConnectorID cannot be empty")
}

// stateless test flow
connector, err := services.NewGithubConnector(request.ConnectorID, *request.ConnectorSpec)
if err != nil {
return nil, nil, trace.Wrap(err)
}

// construct client directly.
config := newGithubOAuth2Config(connector)
client, err := oauth2.NewClient(http.DefaultClient, config)
if err != nil {
return nil, nil, trace.Wrap(err)
return nil, trace.Wrap(err)
}

return connector, client, nil
return connector, nil
}

// regular execution flow
connector, err := a.GetGithubConnector(ctx, request.ConnectorID, true)
if err != nil {
return nil, nil, trace.Wrap(err)
return nil, trace.Wrap(err)
}
connector, err = services.InitGithubConnector(connector)
if err != nil {
return nil, nil, trace.Wrap(err)
}

client, err := a.getGithubOAuth2Client(connector)
if err != nil {
return nil, nil, trace.Wrap(err)
return nil, trace.Wrap(err)
}

return connector, client, nil
return connector, nil
}

func newGithubOAuth2Config(connector types.GithubConnector) oauth2.Config {
return oauth2.Config{
Credentials: oauth2.ClientCredentials{
ID: connector.GetClientID(),
Secret: connector.GetClientSecret(),
ClientID: connector.GetClientID(),
ClientSecret: connector.GetClientSecret(),
RedirectURL: connector.GetRedirectURL(),
Scopes: GithubScopes,
Endpoint: oauth2.Endpoint{
AuthURL: fmt.Sprintf("%s/%s", connector.GetEndpointURL(), GithubAuthPath),
TokenURL: fmt.Sprintf("%s/%s", connector.GetEndpointURL(), GithubTokenPath),
},
RedirectURL: connector.GetRedirectURL(),
Scope: GithubScopes,
AuthURL: fmt.Sprintf("%s/%s", connector.GetEndpointURL(), GithubAuthPath),
TokenURL: fmt.Sprintf("%s/%s", connector.GetEndpointURL(), GithubTokenPath),
}
}

func (a *Server) getGithubOAuth2Client(connector types.GithubConnector) (*oauth2.Client, error) {
config := newGithubOAuth2Config(connector)

a.lock.Lock()
defer a.lock.Unlock()

cachedClient, ok := a.githubClients[connector.GetName()]
if ok && oauth2ConfigsEqual(cachedClient.config, config) {
return cachedClient.client, nil
}

delete(a.githubClients, connector.GetName())
client, err := oauth2.NewClient(http.DefaultClient, config)
if err != nil {
return nil, trace.Wrap(err)
}
a.githubClients[connector.GetName()] = &githubClient{
client: client,
config: config,
}
return client, nil
}

// ValidateGithubAuthCallback validates Github auth callback redirect
func (a *Server) validateGithubAuthCallback(ctx context.Context, diagCtx *SSODiagContext, q url.Values) (*authclient.GithubAuthResponse, error) {
logger := log.WithFields(logrus.Fields{teleport.ComponentKey: "github"})
Expand All @@ -584,19 +552,19 @@ func (a *Server) validateGithubAuthCallback(ctx context.Context, diagCtx *SSODia

// optional parameter: error_description
errDesc := q.Get("error_description")
oauthErr := trace.OAuth2(oauth2.ErrorInvalidRequest, errParam, q)
oauthErr := trace.OAuth2("invalid_request", errParam, q)
return nil, trace.WithUserMessage(oauthErr, "GitHub returned error: %v [%v]", errDesc, errParam)
}

code := q.Get("code")
if code == "" {
oauthErr := trace.OAuth2(oauth2.ErrorInvalidRequest, "code query param must be set", q)
oauthErr := trace.OAuth2("invalid_request", "code query param must be set", q)
return nil, trace.WithUserMessage(oauthErr, "Invalid parameters received from GitHub.")
}

stateToken := q.Get("state")
if stateToken == "" {
oauthErr := trace.OAuth2(oauth2.ErrorInvalidRequest, "missing state query param", q)
oauthErr := trace.OAuth2("invalid_request", "missing state query param", q)
return nil, trace.WithUserMessage(oauthErr, "Invalid parameters received from GitHub.")
}
diagCtx.RequestID = stateToken
Expand All @@ -607,15 +575,15 @@ func (a *Server) validateGithubAuthCallback(ctx context.Context, diagCtx *SSODia
}
diagCtx.Info.TestFlow = req.SSOTestFlow

connector, client, err := a.getGithubConnectorAndClient(ctx, *req)
connector, err := a.getGithubConnector(ctx, *req)
if err != nil {
return nil, trace.Wrap(err, "Failed to get GitHub connector and client.")
}
diagCtx.Info.GithubTeamsToLogins = connector.GetTeamsToLogins()
diagCtx.Info.GithubTeamsToRoles = connector.GetTeamsToRoles()
logger.Debugf("Connector %q teams to logins: %v, roles: %v", connector.GetName(), connector.GetTeamsToLogins(), connector.GetTeamsToRoles())

userResp, teamsResp, err := a.getGithubUserAndTeams(ctx, connector, code, client, diagCtx, logger)
userResp, teamsResp, err := a.getGithubUserAndTeams(ctx, connector, code, diagCtx, logger)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down Expand Up @@ -752,7 +720,6 @@ func (a *Server) getGithubUserAndTeams(
ctx context.Context,
connector types.GithubConnector,
code string,
client *oauth2.Client,
diagCtx *SSODiagContext,
logger *logrus.Entry,
) (*GithubUserResponse, []GithubTeamResponse, error) {
Expand All @@ -762,20 +729,26 @@ func (a *Server) getGithubUserAndTeams(
return a.GithubUserAndTeamsOverride()
}

config := newGithubOAuth2Config(connector)

// exchange the authorization code received by the callback for an access token
token, err := client.RequestToken(oauth2.GrantTypeAuthCode, code)
token, err := config.Exchange(ctx, code)
if err != nil {
return nil, nil, trace.Wrap(err, "Requesting GitHub OAuth2 token failed.")
}

scope, ok := token.Extra("scope").(string)
if !ok {
return nil, nil, trace.BadParameter("missing or invalid scope found in GitHub OAuth2 token")
}
diagCtx.Info.GithubTokenInfo = &types.GithubTokenInfo{
TokenType: token.TokenType,
Expires: int64(token.Expires),
Scope: token.Scope,
Expires: token.ExpiresIn,
Scope: scope,
}

logger.Debugf("Obtained OAuth2 token: Type=%v Expires=%v Scope=%v.",
token.TokenType, token.Expires, token.Scope)
token.TokenType, token.ExpiresIn, scope)

// Get the Github organizations the user is a member of so we don't
// make unnecessary API requests
Expand Down

0 comments on commit fe74173

Please sign in to comment.