Skip to content

Commit

Permalink
Use refresh tokens to renew access tokens (#150)
Browse files Browse the repository at this point in the history
* Rename devicegrant to oauth

* Store refresh token when logging in

* Automatically refresh API token when it is expiring soon

* PR feedback
  • Loading branch information
jakemalachowski authored Dec 3, 2024
1 parent eec9917 commit 1473812
Show file tree
Hide file tree
Showing 7 changed files with 130 additions and 55 deletions.
19 changes: 18 additions & 1 deletion pkg/cfg/cfg.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
package cfg

import "os"
import (
"fmt"
"net/http"
"os"
)

const RepoURL = "https://api.github.com/repos/render-oss/cli"
const InstallationInstructionsURL = "https://render.com/docs/cli#1-install"

var Version = "dev"
var osInfo string

func GetHost() string {
if host := os.Getenv("RENDER_HOST"); host != "" {
Expand All @@ -18,3 +23,15 @@ func GetHost() string {
func GetAPIKey() string {
return os.Getenv("RENDER_API_KEY")
}

func AddUserAgent(header http.Header) http.Header {
header.Add("user-agent", fmt.Sprintf("render-cli/%s (%s)", Version, getOSInfoOnce()))
return header
}

func getOSInfoOnce() string {
if osInfo == "" {
osInfo = getOSInfo()
}
return osInfo
}
2 changes: 1 addition & 1 deletion pkg/client/useragent.go → pkg/cfg/useragent.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package client
package cfg

import (
"fmt"
Expand Down
48 changes: 35 additions & 13 deletions pkg/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,38 +7,60 @@ import (
"fmt"
"net/http"
"reflect"
"time"

"github.com/renderinc/cli/pkg/cfg"
"github.com/renderinc/cli/pkg/client/oauth"
"github.com/renderinc/cli/pkg/config"
)

var ErrUnauthorized = errors.New("unauthorized")
var ErrForbidden = errors.New("forbidden")

var osInfo string

func getOSInfoOnce() string {
if osInfo == "" {
osInfo = getOSInfo()
}
return osInfo
}

func NewDefaultClient() (*ClientWithResponses, error) {
apiCfg, err := config.DefaultAPIConfig()
if err != nil {
return nil, err
}
apiCfg = maybeRefreshAPIToken(apiCfg)
return clientWithAuth(&http.Client{}, apiCfg)
}

func AddUserAgent(header http.Header) http.Header {
header.Add("user-agent", fmt.Sprintf("render-cli/%s (%s)", cfg.Version, getOSInfoOnce()))
return header
func maybeRefreshAPIToken(apiCfg config.APIConfig) config.APIConfig {
expiresSoonThreshold := time.Now().Add(24*time.Hour).Unix()

if apiCfg.ExpiresAt > 0 && apiCfg.ExpiresAt < expiresSoonThreshold && apiCfg.RefreshToken != "" {
updatedConfig, err := refreshAPIKey(apiCfg)
if err != nil {
// failed to refresh the token, clear the refresh token so we fall back
// to the standard login flow
apiCfg.RefreshToken = ""
_ = config.SetAPIConfig(apiCfg)
return apiCfg
}

apiCfg = updatedConfig
}
return apiCfg
}

func refreshAPIKey(apiCfg config.APIConfig) (config.APIConfig, error) {
token, err := oauth.NewClient(apiCfg.Host).RefreshToken(
context.Background(),
apiCfg.RefreshToken,
)
if err != nil {
return config.APIConfig{}, err
}

apiCfg.Key = token.AccessToken
apiCfg.RefreshToken = token.RefreshToken
apiCfg.ExpiresAt = time.Now().Add(time.Second * time.Duration(token.ExpiresIn)).Unix()
return apiCfg, config.SetAPIConfig(apiCfg)
}

func AddHeaders(header http.Header, token string) http.Header {
header = AddUserAgent(header)
header = cfg.AddUserAgent(header)
header.Add("authorization", fmt.Sprintf("Bearer %s", token))
return header
}
Expand Down
41 changes: 31 additions & 10 deletions pkg/client/devicegrant/devicegrant.go → pkg/client/oauth/oauth.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package devicegrant
package oauth

import (
"bytes"
Expand All @@ -10,7 +10,7 @@ import (
"net/http"
"strings"

"github.com/renderinc/cli/pkg/client"
"github.com/renderinc/cli/pkg/cfg"
)

const cliOauthClientID = "429024F5E608930E2A65EF92591A25CC"
Expand All @@ -32,9 +32,10 @@ type GrantRequestBody struct {
}

type DeviceToken struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
RefreshToken string `json:"refresh_token"`
}

type TokenRequestBody struct {
Expand All @@ -60,7 +61,7 @@ func NewClient(host string) *Client {
}

func (c *Client) Do(req *http.Request) (*http.Response, error) {
req.Header = client.AddUserAgent(req.Header)
req.Header = cfg.AddUserAgent(req.Header)
return c.c.Do(req)
}

Expand All @@ -76,7 +77,7 @@ func (c *Client) CreateGrant(ctx context.Context) (*DeviceGrant, error) {
return &grant, nil
}

func (c *Client) GetDeviceToken(ctx context.Context, dg *DeviceGrant) (string, error) {
func (c *Client) GetDeviceTokenResponse(ctx context.Context, dg *DeviceGrant) (*DeviceToken, error) {
body := &TokenRequestBody{
ClientID: cliOauthClientID, DeviceCode: dg.DeviceCode,
GrantType: "urn:ietf:params:oauth:grant-type:device_code",
Expand All @@ -86,13 +87,33 @@ func (c *Client) GetDeviceToken(ctx context.Context, dg *DeviceGrant) (string, e
err := c.postFor(ctx, "/device-token", body, &token)
if err != nil {
if err.Error() == authorizationPendingAPIMsg {
return "", ErrAuthorizationPending
return nil, ErrAuthorizationPending
}

return "", err
return nil, err
}

return &token, nil
}

type RefreshTokenRequestBody struct {
GrantType string `json:"grant_type"`
RefreshToken string `json:"refresh_token"`
}

func (c *Client) RefreshToken(ctx context.Context, refreshToken string) (*DeviceToken, error) {
body := &RefreshTokenRequestBody{
GrantType: "refresh_token",
RefreshToken: refreshToken,
}

var token DeviceToken
err := c.postFor(ctx, "/token/refresh/", body, &token)
if err != nil {
return nil, err
}

return token.AccessToken, nil
return &token, nil
}

func (c *Client) postFor(ctx context.Context, path string, body any, v any) error {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package devicegrant_test
package oauth_test

import (
"context"
Expand All @@ -10,7 +10,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/renderinc/cli/pkg/client/devicegrant"
"github.com/renderinc/cli/pkg/client/oauth"
)

func TestClient_CreateGrant(t *testing.T) {
Expand All @@ -22,12 +22,12 @@ func TestClient_CreateGrant(t *testing.T) {
require.NoError(t, err)
}))

c := devicegrant.NewClient(s.URL)
c := oauth.NewClient(s.URL)

dg, err := c.CreateGrant(context.Background())
require.NoError(t, err)

assert.Equal(t, &devicegrant.DeviceGrant{
assert.Equal(t, &oauth.DeviceGrant{
DeviceCode: "some device code",
UserCode: "some user code",
VerificationUri: "some verification uri",
Expand All @@ -39,7 +39,7 @@ func TestClient_CreateGrant(t *testing.T) {

func TestClient_GetDeviceToken(t *testing.T) {
t.Run("it gets the device token", func(t *testing.T) {
var gotBody devicegrant.TokenRequestBody
var gotBody oauth.TokenRequestBody
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "POST", r.Method)
require.Equal(t, "/device-token", r.URL.Path)
Expand All @@ -50,14 +50,14 @@ func TestClient_GetDeviceToken(t *testing.T) {
require.NoError(t, err)
}))

c := devicegrant.NewClient(s.URL)
c := oauth.NewClient(s.URL)

token, err := c.GetDeviceToken(context.Background(), &devicegrant.DeviceGrant{
token, err := c.GetDeviceTokenResponse(context.Background(), &oauth.DeviceGrant{
DeviceCode: "some device code",
})
require.NoError(t, err)

assert.Equal(t, "some device token", token)
assert.Equal(t, "some device token", token.AccessToken)

assert.Equal(t, "some device code", gotBody.DeviceCode)
assert.NotZero(t, gotBody.ClientID)
Expand All @@ -70,12 +70,12 @@ func TestClient_GetDeviceToken(t *testing.T) {
require.NoError(t, err)
}))

c := devicegrant.NewClient(s.URL)
c := oauth.NewClient(s.URL)

_, err := c.GetDeviceToken(context.Background(), &devicegrant.DeviceGrant{
_, err := c.GetDeviceTokenResponse(context.Background(), &oauth.DeviceGrant{
DeviceCode: "some device code",
})
require.ErrorIs(t, err, devicegrant.ErrAuthorizationPending)
require.ErrorIs(t, err, oauth.ErrAuthorizationPending)
})
}

Expand Down
10 changes: 7 additions & 3 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ type Config struct {

type APIConfig struct {
Key string `yaml:"key,omitempty"`
ExpiresAt int64 `yaml:"expires_at,omitempty"`
Host string `json:"host,omitempty"`
RefreshToken string `json:"refresh_token,omitempty"`
}

func init() {
Expand Down Expand Up @@ -182,14 +184,16 @@ func getAPIConfig() (APIConfig, error) {
return cfg.APIConfig, nil
}

func SetAPIConfig(host, apiKey string) error {
func SetAPIConfig(input APIConfig) error {
cfg, err := Load()
if err != nil {
return err
}

cfg.Host = host
cfg.Key = apiKey
cfg.Host = input.Host
cfg.Key = input.Key
cfg.ExpiresAt = input.ExpiresAt
cfg.RefreshToken = input.RefreshToken
return cfg.Persist()
}

Expand Down
Loading

0 comments on commit 1473812

Please sign in to comment.