From 61631a9e1fff6e4e5d7e11dcafc4a80f4c9733c2 Mon Sep 17 00:00:00 2001 From: Arvid Bjurklint Date: Fri, 16 Aug 2024 01:13:43 +0200 Subject: [PATCH] Add switching between mounts --- dev-env.sh | 1 - internal/vault/client.go | 64 +++++++++++++++++++++++++++++------ internal/vault/client_test.go | 6 ++-- main.go | 45 +++++++++++++++++++----- 4 files changed, 92 insertions(+), 24 deletions(-) diff --git a/dev-env.sh b/dev-env.sh index 32564a0..15d6622 100644 --- a/dev-env.sh +++ b/dev-env.sh @@ -1,3 +1,2 @@ export VAULT_ADDR=http://127.0.0.1:8200 export VAULT_TOKEN=dev-only-token -export VAULT_MOUNT=secret diff --git a/internal/vault/client.go b/internal/vault/client.go index 74d5725..a3905ba 100644 --- a/internal/vault/client.go +++ b/internal/vault/client.go @@ -6,6 +6,7 @@ import ( "io" "log/slog" "net/http" + "slices" "strings" "sync" ) @@ -13,7 +14,6 @@ import ( type Client struct { Addr string Token string - Mount string } type dirEnt struct { @@ -21,29 +21,35 @@ type dirEnt struct { Name string } -func GetKeys(client Client) []string { +var cachedKeys = make(map[string][]string) + +func GetKeys(client Client, mount string) []string { + if keys, found := cachedKeys[mount]; found { + return keys + } entrypoint := dirEnt{ IsDir: true, Name: "/", } recv := make(chan string) go func() { - recurse(recv, client, entrypoint) + recurse(recv, client, mount, entrypoint) close(recv) }() keys := []string{} for key := range recv { keys = append(keys, key) } + cachedKeys[mount] = keys return keys } -func recurse(recv chan string, client Client, entry dirEnt) { +func recurse(recv chan string, client Client, mount string, entry dirEnt) { if !entry.IsDir { recv <- entry.Name return } - relativeEntries, err := client.listDir(entry.Name) + relativeEntries, err := client.listDir(mount, entry.Name) if err != nil { slog.Error("Failed to list directory", "directory", entry.Name, "err", err.Error()) return @@ -60,14 +66,14 @@ func recurse(recv chan string, client Client, entry dirEnt) { wg.Add(1) go func(entry dirEnt) { defer wg.Done() - recurse(recv, client, e) + recurse(recv, client, mount, e) }(e) } wg.Wait() } -func (c Client) listDir(name string) ([]dirEnt, error) { - url := fmt.Sprintf("%s/v1/%s/metadata%s?list=true", c.Addr, c.Mount, name) +func (c Client) listDir(mount string, name string) ([]dirEnt, error) { + url := fmt.Sprintf("%s/v1/%s/metadata%s?list=true", c.Addr, mount, name) request, err := http.NewRequest("GET", url, nil) if err != nil { return []dirEnt{}, fmt.Errorf("Failed to create request: %s", err) @@ -117,11 +123,11 @@ type Secret struct { var cachedSecrets = make(map[string]Secret) -func (c Client) GetSecret(name string) Secret { +func (c Client) GetSecret(mount, name string) Secret { if secret, found := cachedSecrets[name]; found { return secret } - url := fmt.Sprintf("%s/v1/%s/data%s", c.Addr, c.Mount, name) + url := fmt.Sprintf("%s/v1/%s/data%s", c.Addr, mount, name) request, err := http.NewRequest("GET", url, nil) if err != nil { panic(fmt.Errorf("Failed to create request: %s", err)) @@ -151,3 +157,41 @@ func (c Client) GetSecret(name string) Secret { cachedSecrets[name] = secret return secret } + +type MountResponse struct { + Data struct { + Secret map[string]Mount + } +} + +type Mount struct { + Type string +} + +func (c Client) GetMounts() []string { + url := fmt.Sprintf("%s/v1/sys/internal/ui/mounts", c.Addr) + request, err := http.NewRequest("GET", url, nil) + if err != nil { + panic(fmt.Errorf("Failed to create request: %s", err)) + } + request.Header.Set("X-Vault-Token", c.Token) + request.Header.Set("Accept", "application/json") + response, err := http.DefaultClient.Do(request) + if err != nil { + panic(fmt.Errorf("Failed to perform request: %s", err)) + } + defer response.Body.Close() + body, err := io.ReadAll(response.Body) + var mounts MountResponse + if err := json.Unmarshal(body, &mounts); err != nil { + panic(fmt.Errorf("failed to unmarshal response body %s: %s", string(body), err)) + } + mountNames := []string{} + for k, v := range mounts.Data.Secret { + if v.Type == "kv" { + mountNames = append(mountNames, strings.TrimSuffix(k, "/")) + } + } + slices.Sort(mountNames) + return mountNames +} diff --git a/internal/vault/client_test.go b/internal/vault/client_test.go index 2f5fd4c..cf12559 100644 --- a/internal/vault/client_test.go +++ b/internal/vault/client_test.go @@ -36,9 +36,8 @@ func TestGetKeys(t *testing.T) { vaultClient := Client{ Addr: vaultAddr, Token: token, - Mount: "secret", } - keys := GetKeys(vaultClient) + keys := GetKeys(vaultClient, "secret") if len(keys) != len(secrets) { t.Fatalf("Expected %d keys, got %d", len(secrets), len(keys)) } @@ -69,9 +68,8 @@ func TestGetSecret(t *testing.T) { vaultClient := Client{ Addr: vaultAddr, Token: token, - Mount: "secret", } - secret := vaultClient.GetSecret("/bar/baz") + secret := vaultClient.GetSecret("secret", "/bar/baz") if err != nil { t.Fatalf("Got unexpected error: %s", err) } diff --git a/main.go b/main.go index cb37ef8..36e6cbd 100644 --- a/main.go +++ b/main.go @@ -42,6 +42,8 @@ type Ui struct { Height int Result []byte Vault vault.Client + Mounts []string + CurrentMount int } const ( @@ -50,7 +52,7 @@ const ( var ( STYLE_KEY = tcell.StyleDefault.Foreground(tcell.ColorBlue) - STYLE_STRING = tcell.StyleDefault.Foreground(tcell.ColorGreen) + STYLE_STRING = tcell.StyleDefault.Foreground(tcell.ColorPink) STYLE_NULL = tcell.StyleDefault.Foreground(tcell.ColorGray) STYLE_DEFAULT = tcell.StyleDefault ) @@ -60,8 +62,9 @@ func main() { vaultClient := vault.Client{ Addr: mustGetEnv("VAULT_ADDR"), Token: mustGetEnv("VAULT_TOKEN"), - Mount: mustGetEnv("VAULT_MOUNT"), } + mounts := []string{} + mounts = vaultClient.GetMounts() if len(os.Getenv("DEBUG")) > 0 { logFile, err := os.Create("./log") if err != nil { @@ -81,8 +84,10 @@ func main() { screen.EnablePaste() screen.Clear() state := Ui{ - Screen: screen, - Vault: vaultClient, + Screen: screen, + Vault: vaultClient, + Mounts: mounts, + CurrentMount: 0, } quit := func() { // You have to catch panics in a defer, clean up, and @@ -101,7 +106,7 @@ func main() { drawPrompt(state) drawLoadingScreen(state) screen.Show() - state.Keys = vault.GetKeys(vaultClient) + state.Keys = vault.GetKeys(vaultClient, state.Mounts[state.CurrentMount]) newKeysView(&state) for { ev := screen.PollEvent() @@ -115,7 +120,6 @@ func main() { state.ViewStart = 0 } case *tcell.EventKey: - // TODO: Add ability to switch between key-value vault mounts switch ev.Key() { case tcell.KeyEscape, tcell.KeyCtrlC: return @@ -136,6 +140,20 @@ func main() { case tcell.KeyCtrlU: state.Prompt = "" newKeysView(&state) + case tcell.KeyCtrlO: + state.CurrentMount = (state.CurrentMount + 1) % len(state.Mounts) + state.Keys = vault.GetKeys(state.Vault, state.Mounts[state.CurrentMount]) + state.Prompt = "" + newKeysView(&state) + case tcell.KeyCtrlI: + if state.CurrentMount == 0 { + state.CurrentMount = len(state.Mounts) - 1 + } else { + state.CurrentMount-- + } + state.Keys = vault.GetKeys(state.Vault, state.Mounts[state.CurrentMount]) + state.Prompt = "" + newKeysView(&state) case tcell.KeyRune: state.Prompt += string(ev.Rune()) newKeysView(&state) @@ -250,8 +268,17 @@ func drawData(s tcell.Screen, x int, y *int, name string, data map[string]interf } func drawStats(s Ui) { - nKeys := len(s.Keys) - drawLine(s.Screen, 2, s.Height-2, tcell.StyleDefault.Foreground(tcell.ColorYellow), fmt.Sprintf("%d", nKeys)) + nKeysStr := fmt.Sprint(len(s.Keys)) + drawLine(s.Screen, 2, s.Height-2, tcell.StyleDefault.Foreground(tcell.ColorYellow), nKeysStr) + mountsStr := "" + for i, m := range s.Mounts { + if i == s.CurrentMount { + mountsStr = fmt.Sprintf("%s [%s]", mountsStr, m) + } else { + mountsStr = fmt.Sprintf("%s %s ", mountsStr, m) + } + } + drawLine(s.Screen, 4, s.Height-2, tcell.StyleDefault.Foreground(tcell.ColorYellow), mountsStr) } func drawPrompt(s Ui) { @@ -298,7 +325,7 @@ func newKeysView(s *Ui) { func setSecret(s *Ui) { if len(s.FilteredKeys) > 0 { - s.Secret = s.Vault.GetSecret(s.FilteredKeys[s.ViewStart+s.Cursor]) + s.Secret = s.Vault.GetSecret(s.Mounts[s.CurrentMount], s.FilteredKeys[s.ViewStart+s.Cursor]) } else { s.Secret = vault.Secret{} }