From 8b6014514af93ce3049ca7f65ebf91abcabf64fc Mon Sep 17 00:00:00 2001 From: Qiao Han Date: Sat, 7 Dec 2024 16:35:32 +0800 Subject: [PATCH] chore: skip loading project ref if flag is set --- cmd/root.go | 1 + internal/projects/list/list.go | 5 ++--- internal/services/services.go | 2 +- internal/start/start.go | 2 +- internal/utils/flags/db_url.go | 16 +++++-------- internal/utils/flags/project_ref.go | 35 +++++++++++++---------------- 6 files changed, 27 insertions(+), 34 deletions(-) diff --git a/cmd/root.go b/cmd/root.go index 35540a897..80b4c592c 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -109,6 +109,7 @@ var ( } } } + // TODO: always load config.toml if err := flags.ParseDatabaseConfig(cmd.Flags(), fsys); err != nil { return err } diff --git a/internal/projects/list/list.go b/internal/projects/list/list.go index d211f71e9..ef356dfcf 100644 --- a/internal/projects/list/list.go +++ b/internal/projects/list/list.go @@ -29,8 +29,7 @@ func Run(ctx context.Context, fsys afero.Fs) error { return errors.New("Unexpected error retrieving projects: " + string(resp.Body)) } - projectRef, err := flags.LoadProjectRef(fsys) - if err != nil && err != utils.ErrNotLinked { + if err := flags.LoadProjectRef(fsys); err != nil && err != utils.ErrNotLinked { fmt.Fprintln(os.Stderr, err) } @@ -38,7 +37,7 @@ func Run(ctx context.Context, fsys afero.Fs) error { for _, project := range *resp.JSON200 { projects = append(projects, linkedProject{ V1ProjectWithDatabaseResponse: project, - Linked: project.Id == projectRef, + Linked: project.Id == flags.ProjectRef, }) } diff --git a/internal/services/services.go b/internal/services/services.go index b9bfae6e6..ebc08efca 100644 --- a/internal/services/services.go +++ b/internal/services/services.go @@ -16,7 +16,7 @@ import ( ) func Run(ctx context.Context, fsys afero.Fs) error { - if _, err := flags.LoadProjectRef(fsys); err != nil && !errors.Is(err, utils.ErrNotLinked) { + if err := flags.LoadProjectRef(fsys); err != nil && !errors.Is(err, utils.ErrNotLinked) { fmt.Fprintln(os.Stderr, err) } if err := utils.Config.Load("", utils.NewRootFS(fsys)); err != nil && !errors.Is(err, os.ErrNotExist) { diff --git a/internal/start/start.go b/internal/start/start.go index 5877b47e4..b9e89f4c4 100644 --- a/internal/start/start.go +++ b/internal/start/start.go @@ -37,7 +37,7 @@ import ( func Run(ctx context.Context, fsys afero.Fs, excludedContainers []string, ignoreHealthCheck bool) error { // Sanity checks. { - _, _ = flags.LoadProjectRef(fsys) + _ = flags.LoadProjectRef(fsys) if err := utils.LoadConfigFS(fsys); err != nil { return err } diff --git a/internal/utils/flags/db_url.go b/internal/utils/flags/db_url.go index d089f2d13..013453b86 100644 --- a/internal/utils/flags/db_url.go +++ b/internal/utils/flags/db_url.go @@ -50,10 +50,8 @@ func ParseDatabaseConfig(flagSet *pflag.FlagSet, fsys afero.Fs) error { // Update connection config switch connType { case direct: - if err := utils.Config.Load("", utils.NewRootFS(fsys)); err != nil { - if !errors.Is(err, os.ErrNotExist) { - return err - } + if err := utils.Config.Load("", utils.NewRootFS(fsys)); err != nil && !errors.Is(err, os.ErrNotExist) { + return err } if flag := flagSet.Lookup("db-url"); flag != nil { config, err := pgconn.ParseConfig(flag.Value.String()) @@ -76,25 +74,23 @@ func ParseDatabaseConfig(flagSet *pflag.FlagSet, fsys afero.Fs) error { if err := utils.LoadConfigFS(fsys); err != nil { return err } - projectRef, err := LoadProjectRef(fsys) - if err != nil { + if err := LoadProjectRef(fsys); err != nil { return err } - DbConfig = NewDbConfigWithPassword(projectRef) + DbConfig = NewDbConfigWithPassword(ProjectRef) case proxy: token, err := utils.LoadAccessTokenFS(fsys) if err != nil { return err } - projectRef, err := LoadProjectRef(fsys) - if err != nil { + if err := LoadProjectRef(fsys); err != nil { return err } DbConfig.Host = utils.GetSupabaseAPIHost() DbConfig.Port = 443 DbConfig.User = "postgres" DbConfig.Password = token - DbConfig.Database = projectRef + DbConfig.Database = ProjectRef } return nil } diff --git a/internal/utils/flags/project_ref.go b/internal/utils/flags/project_ref.go index 437bbb668..73df72353 100644 --- a/internal/utils/flags/project_ref.go +++ b/internal/utils/flags/project_ref.go @@ -16,12 +16,7 @@ import ( var ProjectRef string func ParseProjectRef(ctx context.Context, fsys afero.Fs) error { - // Flag takes highest precedence - if len(ProjectRef) > 0 { - return utils.AssertProjectRefIsValid(ProjectRef) - } - // Followed by linked ref file - if _, err := LoadProjectRef(fsys); !errors.Is(err, utils.ErrNotLinked) { + if err := LoadProjectRef(fsys); !errors.Is(err, utils.ErrNotLinked) { return err } // Prompt as the last resort @@ -55,20 +50,22 @@ func PromptProjectRef(ctx context.Context, title string) error { return nil } -func LoadProjectRef(fsys afero.Fs) (string, error) { +func LoadProjectRef(fsys afero.Fs) error { + // Flag takes highest precedence + if len(ProjectRef) > 0 { + return utils.AssertProjectRefIsValid(ProjectRef) + } // Env var takes precedence over ref file - ProjectRef = viper.GetString("PROJECT_ID") - if len(ProjectRef) == 0 { - projectRefBytes, err := afero.ReadFile(fsys, utils.ProjectRefPath) - if errors.Is(err, os.ErrNotExist) { - return "", errors.New(utils.ErrNotLinked) - } else if err != nil { - return "", errors.Errorf("failed to load project ref: %w", err) - } - ProjectRef = string(bytes.TrimSpace(projectRefBytes)) + if ProjectRef = viper.GetString("PROJECT_ID"); len(ProjectRef) > 0 { + return utils.AssertProjectRefIsValid(ProjectRef) } - if err := utils.AssertProjectRefIsValid(ProjectRef); err != nil { - return "", err + // Load from local file last + projectRefBytes, err := afero.ReadFile(fsys, utils.ProjectRefPath) + if errors.Is(err, os.ErrNotExist) { + return errors.New(utils.ErrNotLinked) + } else if err != nil { + return errors.Errorf("failed to load project ref: %w", err) } - return ProjectRef, nil + ProjectRef = string(bytes.TrimSpace(projectRefBytes)) + return utils.AssertProjectRefIsValid(ProjectRef) }