Skip to content

Commit

Permalink
use CobraCmd context in the upgrade flow
Browse files Browse the repository at this point in the history
  • Loading branch information
neogopher committed Dec 4, 2024
1 parent b881c92 commit ed3c336
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 11 deletions.
10 changes: 7 additions & 3 deletions cmd/vclusterctl/cmd/upgrade.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package cmd

import (
"context"

"github.com/loft-sh/log"
"github.com/loft-sh/vcluster/pkg/upgrade"
"github.com/pkg/errors"
Expand Down Expand Up @@ -28,16 +30,18 @@ func NewUpgradeCmd() *cobra.Command {
Upgrades the vcluster CLI to the newest version
#######################################################`,
Args: cobra.NoArgs,
RunE: cmd.Run,
RunE: func(cobraCmd *cobra.Command, _ []string) error {
return cmd.Run(cobraCmd.Context())
},
}

upgradeCmd.Flags().StringVar(&cmd.Version, "version", "", "The version to update vcluster to. Defaults to the latest stable version available")
return upgradeCmd
}

// Run executes the command logic
func (cmd *UpgradeCmd) Run(*cobra.Command, []string) error {
err := upgrade.Upgrade(cmd.Version, cmd.log)
func (cmd *UpgradeCmd) Run(ctx context.Context) error {
err := upgrade.Upgrade(ctx, cmd.Version, cmd.log)
if err != nil {
return errors.Errorf("Couldn't upgrade: %v", err)
}
Expand Down
3 changes: 1 addition & 2 deletions pkg/upgrade/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,8 @@ import (
"golang.org/x/oauth2"
)

func fetchReleaseByTag(owner, repo, tag string) (*github.RepositoryRelease, error) {
func fetchReleaseByTag(ctx context.Context, owner, repo, tag string) (*github.RepositoryRelease, error) {
var (
ctx = context.Background()
token string
hc *http.Client

Expand Down
9 changes: 5 additions & 4 deletions pkg/upgrade/upgrade.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package upgrade

import (
"context"
"fmt"
"os"
"regexp"
Expand Down Expand Up @@ -116,15 +117,15 @@ func NewerVersionAvailable() string {
}

// Upgrade downloads the latest release from github and replaces vcluster if a new version is found
func Upgrade(flagVersion string, log log.Logger) error {
func Upgrade(ctx context.Context, flagVersion string, log log.Logger) error {
updater, err := selfupdate.NewUpdater(selfupdate.Config{
Filters: []string{"vcluster"},
})
if err != nil {
return fmt.Errorf("failed to initialize updater: %w", err)
}
if flagVersion != "" {
release, found, err := DetectVersion(githubSlug, flagVersion)
release, found, err := DetectVersion(ctx, githubSlug, flagVersion)
if err != nil {
return errors.Wrap(err, "find version")
} else if !found {
Expand Down Expand Up @@ -188,7 +189,7 @@ func Upgrade(flagVersion string, log log.Logger) error {
return nil
}

func DetectVersion(slug string, version string) (*selfupdate.Release, bool, error) {
func DetectVersion(ctx context.Context, slug string, version string) (*selfupdate.Release, bool, error) {
var (
release *selfupdate.Release
found bool
Expand All @@ -200,7 +201,7 @@ func DetectVersion(slug string, version string) (*selfupdate.Release, bool, erro
return nil, false, fmt.Errorf("invalid slug format. It should be 'owner/name': %s", slug)
}

githubRelease, err := fetchReleaseByTag(repo[0], repo[1], version)
githubRelease, err := fetchReleaseByTag(ctx, repo[0], repo[1], version)
if err != nil {
return nil, false, fmt.Errorf("repository or release not found: %w", err)
}
Expand Down
6 changes: 4 additions & 2 deletions pkg/upgrade/upgrade_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package upgrade

import (
"context"
"os"
"strings"
"testing"
Expand Down Expand Up @@ -75,7 +76,8 @@ func TestUpgrade(t *testing.T) {
defer func() { version = versionBackup }()

// Newest version already reached
err = Upgrade("", log.GetInstance())
ctx := context.Background()
err = Upgrade(ctx, "", log.GetInstance())
assert.Equal(t, false, err != nil, "Upgrade returned error if newest version already reached")
err = logFile.Close()
if err != nil {
Expand All @@ -91,6 +93,6 @@ func TestUpgrade(t *testing.T) {
githubSlugBackup := githubSlug
githubSlug = ""
defer func() { githubSlug = githubSlugBackup }()
err = Upgrade("", log.GetInstance())
err = Upgrade(ctx, "", log.GetInstance())
assert.Equal(t, true, err != nil, "No error returned if DetectLatest returns one.")
}

0 comments on commit ed3c336

Please sign in to comment.