diff --git a/internal/cli/main.go b/internal/cli/main.go index e1d21934d..29330057a 100644 --- a/internal/cli/main.go +++ b/internal/cli/main.go @@ -67,10 +67,20 @@ func printUsageAndExit() { func dbMakeConnectionString(driver, user, password, address, name, ssl string) string { return fmt.Sprintf("%s://%s:%s@%s/%s?sslmode=%s", - driver, url.QueryEscape(user), url.QueryEscape(password), address, name, ssl, + driver, EscapeIfNeeded(user), EscapeIfNeeded(password), address, name, ssl, ) } +func EscapeIfNeeded(str string) string { + unescapedStr, err := url.QueryUnescape(str) + if err != nil || unescapedStr == str { + // If the str is already unescaped or an error occurred, escape it + return url.QueryEscape(str) + } + // If the str was successfully unescaped and is different from the original, return the original + return str +} + // Main function of a cli application. It is public for backwards compatibility with `cli` package func Main(version string) { help := viper.GetBool("help") diff --git a/internal/cli/main_test.go b/internal/cli/main_test.go new file mode 100644 index 000000000..ab3c192db --- /dev/null +++ b/internal/cli/main_test.go @@ -0,0 +1,48 @@ +package cli + +import ( + "testing" +) + +func TestEscapeIfNeeded(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "AlreadyEscaped", + input: "hello%20world", + expected: "hello%20world", + }, + { + name: "Unescaped", + input: "hello world", + expected: "hello+world", + }, + { + name: "PartiallyEscaped", + input: "hello%20world!", + expected: "hello%20world!", + }, + { + name: "EmptyString", + input: "", + expected: "", + }, + { + name: "SpecialCharacters", + input: "hello@world.com", + expected: "hello%40world.com", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := EscapeIfNeeded(tt.input) + if result != tt.expected { + t.Errorf("EscapeIfNeeded(%q) actual = %q; want = %q", tt.input, result, tt.expected) + } + }) + } +}