diff --git a/pkg/config/config.go b/pkg/config/config.go index 9351c8779..3fd7f53e0 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -37,7 +37,7 @@ var ( // `cmd.PersistentFlags().AddFlagSet(connection.Flags)` Flags *pflag.FlagSet = CreateFlagSet() - // Directory is $HOME/config/registry + // Directory is $HOME/.config/registry Directory string ErrCannotDeleteActive = fmt.Errorf("cannot delete active configuration") ErrReservedConfigName = fmt.Errorf("%q is reserved", ActivePointerFilename) @@ -100,10 +100,6 @@ func ValidateName(name string) error { if name == ActivePointerFilename { return ErrReservedConfigName } - - if dir, _ := filepath.Split(name); dir != "" { - return fmt.Errorf("%q must not include a path", name) - } return nil } @@ -199,9 +195,10 @@ func ReadValid(name string) (c Configuration, err error) { } // Read loads a Configuration from yaml file matching `name`. If name -// contains a path, the file will be read from that path, otherwise -// the path is assumed as: ~/.config/registry. Does a simple read from the -// file: does not bind to env vars or flags, resolve, or validate. +// contains a path or refers to a local file, the file will be read +// using name, otherwise the path is assumed as: ~/.config/registry. +// Does a simple read from the file: does not bind to env vars or flags, +// resolve, or validate. // See also: ReadValid() func Read(name string) (c Configuration, err error) { if err = ValidateName(name); err != nil { @@ -210,7 +207,11 @@ func Read(name string) (c Configuration, err error) { dir, file := filepath.Split(name) if dir == "" { - name = filepath.Join(Directory, file) + // If name refers to a local file, preferentially read the local file. + // Otherwise assume name refers to a file in the config directory. + if info, err := os.Stat(file); errors.Is(err, os.ErrNotExist) || info.IsDir() { + name = filepath.Join(Directory, file) + } } var r io.Reader if r, err = os.Open(name); err != nil { diff --git a/pkg/config/configuration_test.go b/pkg/config/configuration_test.go index 32804876c..a20f163e6 100644 --- a/pkg/config/configuration_test.go +++ b/pkg/config/configuration_test.go @@ -15,6 +15,7 @@ package config_test import ( + "log" "os" "path" "path/filepath" @@ -24,6 +25,7 @@ import ( "github.com/apigee/registry/pkg/config" "github.com/apigee/registry/pkg/config/test" "github.com/google/go-cmp/cmp" + "gopkg.in/yaml.v2" ) func TestMissingDirectory(t *testing.T) { @@ -476,3 +478,40 @@ func TestResolve(t *testing.T) { t.Errorf("want: %s, got: %s", "hello", c.Registry.Token) } } + +func TestReadExternalFile(t *testing.T) { + c := config.Configuration{ + Registry: config.Registry{Address: "example.com:443"}, + } + bytes, err := yaml.Marshal(c) + if err != nil { + t.Fatal(err) + } + f, err := os.CreateTemp(".", "tmpfile-") + if err != nil { + log.Fatal(err) + } + if _, err = f.Write(bytes); err != nil { + log.Fatal(err) + } + if err = f.Close(); err != nil { + log.Fatal(err) + } + defer os.Remove(f.Name()) + // Verify that we can read a file using its full path name. + c2, err := config.Read(f.Name()) + if err != nil { + t.Fatal(err) + } + if c2.Registry.Address != c.Registry.Address { + t.Errorf("want: %s, got: %s", c2.Registry.Address, c.Registry.Address) + } + // Verify that we can read a local file using its base name. + c3, err := config.Read(filepath.Base(f.Name())) + if err != nil { + t.Fatal(err) + } + if c3.Registry.Address != c.Registry.Address { + t.Errorf("want: %s, got: %s", c3.Registry.Address, c.Registry.Address) + } +}