Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert github auth code to use x/oauth2 #48598

Merged
merged 1 commit into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 3 additions & 43 deletions lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ import (
"sync"
"time"

"github.com/coreos/go-oidc/oauth2"
"github.com/google/uuid"
liblicense "github.com/gravitational/license"
"github.com/gravitational/trace"
Expand Down Expand Up @@ -494,7 +493,6 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) {
Authority: cfg.Authority,
AuthServiceName: cfg.AuthServiceName,
ServerID: cfg.HostUUID,
githubClients: make(map[string]*githubClient),
cancelFunc: cancelFunc,
closeCtx: closeCtx,
emitter: cfg.Emitter,
Expand Down Expand Up @@ -886,10 +884,9 @@ type ReadOnlyCache = readonly.Cache
// - same for users and their sessions
// - checks public keys to see if they're signed by it (can be trusted or not)
type Server struct {
lock sync.RWMutex
githubClients map[string]*githubClient
clock clockwork.Clock
bk backend.Backend
lock sync.RWMutex
clock clockwork.Clock
bk backend.Backend

closeCtx context.Context
cancelFunc context.CancelFunc
Expand Down Expand Up @@ -7523,43 +7520,6 @@ func (k *authKeepAliver) Close() error {
return nil
}

// githubClient is internal structure that stores Github OAuth 2client and its config
type githubClient struct {
client *oauth2.Client
config oauth2.Config
}

// oauth2ConfigsEqual returns true if the provided OAuth2 configs are equal
func oauth2ConfigsEqual(a, b oauth2.Config) bool {
if a.Credentials.ID != b.Credentials.ID {
return false
}
if a.Credentials.Secret != b.Credentials.Secret {
return false
}
if a.RedirectURL != b.RedirectURL {
return false
}
if len(a.Scope) != len(b.Scope) {
return false
}
for i := range a.Scope {
if a.Scope[i] != b.Scope[i] {
return false
}
}
if a.AuthURL != b.AuthURL {
return false
}
if a.TokenURL != b.TokenURL {
return false
}
if a.AuthMethod != b.AuthMethod {
return false
}
return true
}

// DefaultDNSNamesForRole returns default DNS names for the specified role.
func DefaultDNSNamesForRole(role types.SystemRole) []string {
if (types.SystemRoles{role}).IncludeAny(
Expand Down
99 changes: 36 additions & 63 deletions lib/auth/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ import (
"strings"
"time"

"github.com/coreos/go-oidc/oauth2"
"github.com/gravitational/trace"
"github.com/sirupsen/logrus"
"golang.org/x/oauth2"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/constants"
Expand Down Expand Up @@ -143,7 +143,7 @@ func (g *GithubConverter) UpdateGithubConnector(ctx context.Context, connector t

// CreateGithubAuthRequest creates a new request for Github OAuth2 flow
func (a *Server) CreateGithubAuthRequest(ctx context.Context, req types.GithubAuthRequest) (*types.GithubAuthRequest, error) {
connector, client, err := a.getGithubConnectorAndClient(ctx, req)
connector, err := a.getGithubConnector(ctx, req)
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -163,7 +163,10 @@ func (a *Server) CreateGithubAuthRequest(ctx context.Context, req types.GithubAu
if err != nil {
return nil, trace.Wrap(err)
}
req.RedirectURL = client.AuthCodeURL(req.StateToken, "", "")

config := newGithubOAuth2Config(connector)

req.RedirectURL = config.AuthCodeURL(req.StateToken)
log.WithFields(logrus.Fields{teleport.ComponentKey: "github"}).Debugf(
"Redirect URL: %v.", req.RedirectURL)
req.SetExpiry(a.GetClock().Now().UTC().Add(defaults.GithubAuthRequestTTL))
Expand Down Expand Up @@ -487,86 +490,51 @@ func validateGithubAuthCallbackHelper(ctx context.Context, m githubManager, diag
return auth, nil
}

func (a *Server) getGithubConnectorAndClient(ctx context.Context, request types.GithubAuthRequest) (types.GithubConnector, *oauth2.Client, error) {
func (a *Server) getGithubConnector(ctx context.Context, request types.GithubAuthRequest) (types.GithubConnector, error) {
if request.SSOTestFlow {
if request.ConnectorSpec == nil {
return nil, nil, trace.BadParameter("ConnectorSpec cannot be nil when SSOTestFlow is true")
return nil, trace.BadParameter("ConnectorSpec cannot be nil when SSOTestFlow is true")
}

if request.ConnectorID == "" {
return nil, nil, trace.BadParameter("ConnectorID cannot be empty")
return nil, trace.BadParameter("ConnectorID cannot be empty")
}

// stateless test flow
connector, err := services.NewGithubConnector(request.ConnectorID, *request.ConnectorSpec)
if err != nil {
return nil, nil, trace.Wrap(err)
}

// construct client directly.
config := newGithubOAuth2Config(connector)
client, err := oauth2.NewClient(http.DefaultClient, config)
if err != nil {
return nil, nil, trace.Wrap(err)
return nil, trace.Wrap(err)
}

return connector, client, nil
return connector, nil
}

// regular execution flow
connector, err := a.GetGithubConnector(ctx, request.ConnectorID, true)
if err != nil {
return nil, nil, trace.Wrap(err)
return nil, trace.Wrap(err)
}
connector, err = services.InitGithubConnector(connector)
if err != nil {
return nil, nil, trace.Wrap(err)
}

client, err := a.getGithubOAuth2Client(connector)
if err != nil {
return nil, nil, trace.Wrap(err)
return nil, trace.Wrap(err)
}

return connector, client, nil
return connector, nil
}

func newGithubOAuth2Config(connector types.GithubConnector) oauth2.Config {
return oauth2.Config{
Credentials: oauth2.ClientCredentials{
ID: connector.GetClientID(),
Secret: connector.GetClientSecret(),
ClientID: connector.GetClientID(),
ClientSecret: connector.GetClientSecret(),
RedirectURL: connector.GetRedirectURL(),
Scopes: GithubScopes,
Endpoint: oauth2.Endpoint{
AuthURL: fmt.Sprintf("%s/%s", connector.GetEndpointURL(), GithubAuthPath),
TokenURL: fmt.Sprintf("%s/%s", connector.GetEndpointURL(), GithubTokenPath),
rosstimothy marked this conversation as resolved.
Show resolved Hide resolved
},
RedirectURL: connector.GetRedirectURL(),
Scope: GithubScopes,
AuthURL: fmt.Sprintf("%s/%s", connector.GetEndpointURL(), GithubAuthPath),
TokenURL: fmt.Sprintf("%s/%s", connector.GetEndpointURL(), GithubTokenPath),
}
}

func (a *Server) getGithubOAuth2Client(connector types.GithubConnector) (*oauth2.Client, error) {
config := newGithubOAuth2Config(connector)

a.lock.Lock()
defer a.lock.Unlock()

cachedClient, ok := a.githubClients[connector.GetName()]
if ok && oauth2ConfigsEqual(cachedClient.config, config) {
return cachedClient.client, nil
}

delete(a.githubClients, connector.GetName())
client, err := oauth2.NewClient(http.DefaultClient, config)
if err != nil {
return nil, trace.Wrap(err)
}
a.githubClients[connector.GetName()] = &githubClient{
client: client,
config: config,
}
return client, nil
}

// ValidateGithubAuthCallback validates Github auth callback redirect
func (a *Server) validateGithubAuthCallback(ctx context.Context, diagCtx *SSODiagContext, q url.Values) (*authclient.GithubAuthResponse, error) {
logger := log.WithFields(logrus.Fields{teleport.ComponentKey: "github"})
Expand All @@ -584,19 +552,19 @@ func (a *Server) validateGithubAuthCallback(ctx context.Context, diagCtx *SSODia

// optional parameter: error_description
errDesc := q.Get("error_description")
oauthErr := trace.OAuth2(oauth2.ErrorInvalidRequest, errParam, q)
oauthErr := trace.OAuth2("invalid_request", errParam, q)
return nil, trace.WithUserMessage(oauthErr, "GitHub returned error: %v [%v]", errDesc, errParam)
}

code := q.Get("code")
if code == "" {
oauthErr := trace.OAuth2(oauth2.ErrorInvalidRequest, "code query param must be set", q)
oauthErr := trace.OAuth2("invalid_request", "code query param must be set", q)
return nil, trace.WithUserMessage(oauthErr, "Invalid parameters received from GitHub.")
}

stateToken := q.Get("state")
if stateToken == "" {
oauthErr := trace.OAuth2(oauth2.ErrorInvalidRequest, "missing state query param", q)
oauthErr := trace.OAuth2("invalid_request", "missing state query param", q)
return nil, trace.WithUserMessage(oauthErr, "Invalid parameters received from GitHub.")
}
diagCtx.RequestID = stateToken
Expand All @@ -607,15 +575,15 @@ func (a *Server) validateGithubAuthCallback(ctx context.Context, diagCtx *SSODia
}
diagCtx.Info.TestFlow = req.SSOTestFlow

connector, client, err := a.getGithubConnectorAndClient(ctx, *req)
connector, err := a.getGithubConnector(ctx, *req)
if err != nil {
return nil, trace.Wrap(err, "Failed to get GitHub connector and client.")
}
diagCtx.Info.GithubTeamsToLogins = connector.GetTeamsToLogins()
diagCtx.Info.GithubTeamsToRoles = connector.GetTeamsToRoles()
logger.Debugf("Connector %q teams to logins: %v, roles: %v", connector.GetName(), connector.GetTeamsToLogins(), connector.GetTeamsToRoles())

userResp, teamsResp, err := a.getGithubUserAndTeams(ctx, connector, code, client, diagCtx, logger)
userResp, teamsResp, err := a.getGithubUserAndTeams(ctx, connector, code, diagCtx, logger)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down Expand Up @@ -752,7 +720,6 @@ func (a *Server) getGithubUserAndTeams(
ctx context.Context,
connector types.GithubConnector,
code string,
client *oauth2.Client,
diagCtx *SSODiagContext,
logger *logrus.Entry,
) (*GithubUserResponse, []GithubTeamResponse, error) {
Expand All @@ -762,20 +729,26 @@ func (a *Server) getGithubUserAndTeams(
return a.GithubUserAndTeamsOverride()
}

config := newGithubOAuth2Config(connector)

// exchange the authorization code received by the callback for an access token
token, err := client.RequestToken(oauth2.GrantTypeAuthCode, code)
token, err := config.Exchange(ctx, code)
if err != nil {
return nil, nil, trace.Wrap(err, "Requesting GitHub OAuth2 token failed.")
}

scope, ok := token.Extra("scope").(string)
if !ok {
return nil, nil, trace.BadParameter("missing or invalid scope found in GitHub OAuth2 token")
}
diagCtx.Info.GithubTokenInfo = &types.GithubTokenInfo{
TokenType: token.TokenType,
Expires: int64(token.Expires),
Scope: token.Scope,
Expires: token.ExpiresIn,
Scope: scope,
}

logger.Debugf("Obtained OAuth2 token: Type=%v Expires=%v Scope=%v.",
token.TokenType, token.Expires, token.Scope)
token.TokenType, token.ExpiresIn, scope)

// Get the Github organizations the user is a member of so we don't
// make unnecessary API requests
Expand Down
Loading