diff --git a/pkg/cfg/cfg.go b/pkg/cfg/cfg.go index 552ae2d..eafa36b 100644 --- a/pkg/cfg/cfg.go +++ b/pkg/cfg/cfg.go @@ -1,11 +1,16 @@ package cfg -import "os" +import ( + "fmt" + "net/http" + "os" +) const RepoURL = "https://api.github.com/repos/render-oss/cli" const InstallationInstructionsURL = "https://render.com/docs/cli#1-install" var Version = "dev" +var osInfo string func GetHost() string { if host := os.Getenv("RENDER_HOST"); host != "" { @@ -18,3 +23,15 @@ func GetHost() string { func GetAPIKey() string { return os.Getenv("RENDER_API_KEY") } + +func AddUserAgent(header http.Header) http.Header { + header.Add("user-agent", fmt.Sprintf("render-cli/%s (%s)", Version, getOSInfoOnce())) + return header +} + +func getOSInfoOnce() string { + if osInfo == "" { + osInfo = getOSInfo() + } + return osInfo +} diff --git a/pkg/client/useragent.go b/pkg/cfg/useragent.go similarity index 98% rename from pkg/client/useragent.go rename to pkg/cfg/useragent.go index e98a0b9..813dccc 100644 --- a/pkg/client/useragent.go +++ b/pkg/cfg/useragent.go @@ -1,4 +1,4 @@ -package client +package cfg import ( "fmt" diff --git a/pkg/client/client.go b/pkg/client/client.go index d1ca501..33a1664 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -7,38 +7,60 @@ import ( "fmt" "net/http" "reflect" + "time" "github.com/renderinc/cli/pkg/cfg" + "github.com/renderinc/cli/pkg/client/oauth" "github.com/renderinc/cli/pkg/config" ) var ErrUnauthorized = errors.New("unauthorized") var ErrForbidden = errors.New("forbidden") -var osInfo string - -func getOSInfoOnce() string { - if osInfo == "" { - osInfo = getOSInfo() - } - return osInfo -} - func NewDefaultClient() (*ClientWithResponses, error) { apiCfg, err := config.DefaultAPIConfig() if err != nil { return nil, err } + apiCfg = maybeRefreshAPIToken(apiCfg) return clientWithAuth(&http.Client{}, apiCfg) } -func AddUserAgent(header http.Header) http.Header { - header.Add("user-agent", fmt.Sprintf("render-cli/%s (%s)", cfg.Version, getOSInfoOnce())) - return header +func maybeRefreshAPIToken(apiCfg config.APIConfig) config.APIConfig { + expiresSoonThreshold := time.Now().Add(24*time.Hour).Unix() + + if apiCfg.ExpiresAt > 0 && apiCfg.ExpiresAt < expiresSoonThreshold && apiCfg.RefreshToken != "" { + updatedConfig, err := refreshAPIKey(apiCfg) + if err != nil { + // failed to refresh the token, clear the refresh token so we fall back + // to the standard login flow + apiCfg.RefreshToken = "" + _ = config.SetAPIConfig(apiCfg) + return apiCfg + } + + apiCfg = updatedConfig + } + return apiCfg +} + +func refreshAPIKey(apiCfg config.APIConfig) (config.APIConfig, error) { + token, err := oauth.NewClient(apiCfg.Host).RefreshToken( + context.Background(), + apiCfg.RefreshToken, + ) + if err != nil { + return config.APIConfig{}, err + } + + apiCfg.Key = token.AccessToken + apiCfg.RefreshToken = token.RefreshToken + apiCfg.ExpiresAt = time.Now().Add(time.Second * time.Duration(token.ExpiresIn)).Unix() + return apiCfg, config.SetAPIConfig(apiCfg) } func AddHeaders(header http.Header, token string) http.Header { - header = AddUserAgent(header) + header = cfg.AddUserAgent(header) header.Add("authorization", fmt.Sprintf("Bearer %s", token)) return header } diff --git a/pkg/client/devicegrant/devicegrant.go b/pkg/client/oauth/oauth.go similarity index 74% rename from pkg/client/devicegrant/devicegrant.go rename to pkg/client/oauth/oauth.go index bbd200b..29e0d56 100644 --- a/pkg/client/devicegrant/devicegrant.go +++ b/pkg/client/oauth/oauth.go @@ -1,4 +1,4 @@ -package devicegrant +package oauth import ( "bytes" @@ -10,7 +10,7 @@ import ( "net/http" "strings" - "github.com/renderinc/cli/pkg/client" + "github.com/renderinc/cli/pkg/cfg" ) const cliOauthClientID = "429024F5E608930E2A65EF92591A25CC" @@ -32,9 +32,10 @@ type GrantRequestBody struct { } type DeviceToken struct { - AccessToken string `json:"access_token"` - TokenType string `json:"token_type"` - ExpiresIn int `json:"expires_in"` + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + RefreshToken string `json:"refresh_token"` } type TokenRequestBody struct { @@ -60,7 +61,7 @@ func NewClient(host string) *Client { } func (c *Client) Do(req *http.Request) (*http.Response, error) { - req.Header = client.AddUserAgent(req.Header) + req.Header = cfg.AddUserAgent(req.Header) return c.c.Do(req) } @@ -76,7 +77,7 @@ func (c *Client) CreateGrant(ctx context.Context) (*DeviceGrant, error) { return &grant, nil } -func (c *Client) GetDeviceToken(ctx context.Context, dg *DeviceGrant) (string, error) { +func (c *Client) GetDeviceTokenResponse(ctx context.Context, dg *DeviceGrant) (*DeviceToken, error) { body := &TokenRequestBody{ ClientID: cliOauthClientID, DeviceCode: dg.DeviceCode, GrantType: "urn:ietf:params:oauth:grant-type:device_code", @@ -86,13 +87,33 @@ func (c *Client) GetDeviceToken(ctx context.Context, dg *DeviceGrant) (string, e err := c.postFor(ctx, "/device-token", body, &token) if err != nil { if err.Error() == authorizationPendingAPIMsg { - return "", ErrAuthorizationPending + return nil, ErrAuthorizationPending } - return "", err + return nil, err + } + + return &token, nil +} + +type RefreshTokenRequestBody struct { + GrantType string `json:"grant_type"` + RefreshToken string `json:"refresh_token"` +} + +func (c *Client) RefreshToken(ctx context.Context, refreshToken string) (*DeviceToken, error) { + body := &RefreshTokenRequestBody{ + GrantType: "refresh_token", + RefreshToken: refreshToken, + } + + var token DeviceToken + err := c.postFor(ctx, "/token/refresh/", body, &token) + if err != nil { + return nil, err } - return token.AccessToken, nil + return &token, nil } func (c *Client) postFor(ctx context.Context, path string, body any, v any) error { diff --git a/pkg/client/devicegrant/devicegrant_test.go b/pkg/client/oauth/oauth_test.go similarity index 79% rename from pkg/client/devicegrant/devicegrant_test.go rename to pkg/client/oauth/oauth_test.go index 2a36fe0..c62c5c0 100644 --- a/pkg/client/devicegrant/devicegrant_test.go +++ b/pkg/client/oauth/oauth_test.go @@ -1,4 +1,4 @@ -package devicegrant_test +package oauth_test import ( "context" @@ -10,7 +10,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/renderinc/cli/pkg/client/devicegrant" + "github.com/renderinc/cli/pkg/client/oauth" ) func TestClient_CreateGrant(t *testing.T) { @@ -22,12 +22,12 @@ func TestClient_CreateGrant(t *testing.T) { require.NoError(t, err) })) - c := devicegrant.NewClient(s.URL) + c := oauth.NewClient(s.URL) dg, err := c.CreateGrant(context.Background()) require.NoError(t, err) - assert.Equal(t, &devicegrant.DeviceGrant{ + assert.Equal(t, &oauth.DeviceGrant{ DeviceCode: "some device code", UserCode: "some user code", VerificationUri: "some verification uri", @@ -39,7 +39,7 @@ func TestClient_CreateGrant(t *testing.T) { func TestClient_GetDeviceToken(t *testing.T) { t.Run("it gets the device token", func(t *testing.T) { - var gotBody devicegrant.TokenRequestBody + var gotBody oauth.TokenRequestBody s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { require.Equal(t, "POST", r.Method) require.Equal(t, "/device-token", r.URL.Path) @@ -50,14 +50,14 @@ func TestClient_GetDeviceToken(t *testing.T) { require.NoError(t, err) })) - c := devicegrant.NewClient(s.URL) + c := oauth.NewClient(s.URL) - token, err := c.GetDeviceToken(context.Background(), &devicegrant.DeviceGrant{ + token, err := c.GetDeviceTokenResponse(context.Background(), &oauth.DeviceGrant{ DeviceCode: "some device code", }) require.NoError(t, err) - assert.Equal(t, "some device token", token) + assert.Equal(t, "some device token", token.AccessToken) assert.Equal(t, "some device code", gotBody.DeviceCode) assert.NotZero(t, gotBody.ClientID) @@ -70,12 +70,12 @@ func TestClient_GetDeviceToken(t *testing.T) { require.NoError(t, err) })) - c := devicegrant.NewClient(s.URL) + c := oauth.NewClient(s.URL) - _, err := c.GetDeviceToken(context.Background(), &devicegrant.DeviceGrant{ + _, err := c.GetDeviceTokenResponse(context.Background(), &oauth.DeviceGrant{ DeviceCode: "some device code", }) - require.ErrorIs(t, err, devicegrant.ErrAuthorizationPending) + require.ErrorIs(t, err, oauth.ErrAuthorizationPending) }) } diff --git a/pkg/config/config.go b/pkg/config/config.go index c8dbbd6..0df7dd8 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -35,7 +35,9 @@ type Config struct { type APIConfig struct { Key string `yaml:"key,omitempty"` + ExpiresAt int64 `yaml:"expires_at,omitempty"` Host string `json:"host,omitempty"` + RefreshToken string `json:"refresh_token,omitempty"` } func init() { @@ -182,14 +184,16 @@ func getAPIConfig() (APIConfig, error) { return cfg.APIConfig, nil } -func SetAPIConfig(host, apiKey string) error { +func SetAPIConfig(input APIConfig) error { cfg, err := Load() if err != nil { return err } - cfg.Host = host - cfg.Key = apiKey + cfg.Host = input.Host + cfg.Key = input.Key + cfg.ExpiresAt = input.ExpiresAt + cfg.RefreshToken = input.RefreshToken return cfg.Persist() } diff --git a/pkg/tui/views/login.go b/pkg/tui/views/login.go index 06e6ce8..b5282c7 100644 --- a/pkg/tui/views/login.go +++ b/pkg/tui/views/login.go @@ -10,11 +10,11 @@ import ( tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" + "github.com/renderinc/cli/pkg/client/oauth" "github.com/spf13/cobra" "github.com/renderinc/cli/pkg/cfg" "github.com/renderinc/cli/pkg/client" - "github.com/renderinc/cli/pkg/client/devicegrant" "github.com/renderinc/cli/pkg/client/version" "github.com/renderinc/cli/pkg/command" "github.com/renderinc/cli/pkg/config" @@ -24,7 +24,7 @@ import ( ) func NonInteractiveLogin(cmd *cobra.Command) error { - dc := devicegrant.NewClient(cfg.GetHost()) + dc := oauth.NewClient(cfg.GetHost()) vc := version.NewClient(cfg.RepoURL) alreadyLoggedIn := isAlreadyLoggedIn(cmd.Context()) @@ -49,7 +49,7 @@ func NonInteractiveLogin(cmd *cobra.Command) error { return nil } -func login(cmd *cobra.Command, c *devicegrant.Client) error { +func login(cmd *cobra.Command, c *oauth.Client) error { dg, err := c.CreateGrant(cmd.Context()) if err != nil { return err @@ -72,20 +72,21 @@ func login(cmd *cobra.Command, c *devicegrant.Client) error { return err } - return config.SetAPIConfig(cfg.GetHost(), token) + apiCfg := configForToken(token) + return config.SetAPIConfig(apiCfg) } type LoginView struct { ctx context.Context - dc *devicegrant.Client + dc *oauth.Client vc *version.Client dashURL string } func NewLoginView(ctx context.Context) *LoginView { - dc := devicegrant.NewClient(cfg.GetHost()) + dc := oauth.NewClient(cfg.GetHost()) vc := version.NewClient(cfg.RepoURL) return &LoginView{ @@ -97,12 +98,12 @@ func NewLoginView(ctx context.Context) *LoginView { type loginStartedMsg struct { dashURL string - deviceGrant *devicegrant.DeviceGrant + deviceGrant *oauth.DeviceGrant } type loginCompleteMsg struct{} -func startLogin(ctx context.Context, dc *devicegrant.Client) tea.Cmd { +func startLogin(ctx context.Context, dc *oauth.Client) tea.Cmd { return func() tea.Msg { dg, err := dc.CreateGrant(ctx) if err != nil { @@ -126,14 +127,15 @@ func startLogin(ctx context.Context, dc *devicegrant.Client) tea.Cmd { } } -func pollForLogin(ctx context.Context, dc *devicegrant.Client, msg loginStartedMsg) tea.Cmd { +func pollForLogin(ctx context.Context, dc *oauth.Client, msg loginStartedMsg) tea.Cmd { return func() tea.Msg { token, err := pollForToken(ctx, dc, msg.deviceGrant) if err != nil { return tui.ErrorMsg{Err: err} } - err = config.SetAPIConfig(cfg.GetHost(), token) + apiCfg := configForToken(token) + err = config.SetAPIConfig(apiCfg) if err != nil { return tui.ErrorMsg{Err: err} } @@ -195,7 +197,7 @@ func isAlreadyLoggedIn(ctx context.Context) bool { return err == nil && resp.StatusCode == http.StatusOK } -func dashboardAuthURL(dg *devicegrant.DeviceGrant) (*url.URL, error) { +func dashboardAuthURL(dg *oauth.DeviceGrant) (*url.URL, error) { u, err := url.Parse(dg.VerificationUriComplete) if err != nil { return nil, err @@ -209,24 +211,33 @@ func dashboardAuthURL(dg *devicegrant.DeviceGrant) (*url.URL, error) { return u, nil } -func pollForToken(ctx context.Context, c *devicegrant.Client, dg *devicegrant.DeviceGrant) (string, error) { +func pollForToken(ctx context.Context, c *oauth.Client, dg *oauth.DeviceGrant) (*oauth.DeviceToken, error) { timeout := time.NewTimer(time.Duration(dg.ExpiresIn) * time.Second) interval := time.NewTicker(time.Duration(dg.Interval) * time.Second) for { select { case <-timeout.C: - return "", errors.New("timed out") + return nil, errors.New("timed out") case <-interval.C: - token, err := c.GetDeviceToken(ctx, dg) - if errors.Is(err, devicegrant.ErrAuthorizationPending) { + token, err := c.GetDeviceTokenResponse(ctx, dg) + if errors.Is(err, oauth.ErrAuthorizationPending) { continue } if err != nil { - return "", err + return nil, err } return token, nil } } } + +func configForToken(token *oauth.DeviceToken) config.APIConfig { + return config.APIConfig{ + Host: cfg.GetHost(), + Key: token.AccessToken, + ExpiresAt: time.Now().Add(time.Duration(token.ExpiresIn) * time.Second).Unix(), + RefreshToken: token.RefreshToken, + } +}