diff --git a/assert_test.go b/assert_test.go new file mode 100644 index 000000000..0394b4b08 --- /dev/null +++ b/assert_test.go @@ -0,0 +1,94 @@ +// Copyright (c) 2023 Snowflake Computing Inc. All rights reserved. + +package gosnowflake + +import ( + "fmt" + "reflect" + "strings" + "testing" +) + +func assertNilF(t *testing.T, actual any, descriptions ...string) { + fatalOnNonEmpty(t, validateNil(actual, descriptions...)) +} + +func assertNotNilF(t *testing.T, actual any, descriptions ...string) { + fatalOnNonEmpty(t, validateNotNil(actual, descriptions...)) +} + +func assertEqualE(t *testing.T, actual any, expected any, descriptions ...string) { + errorOnNonEmpty(t, validateEqual(actual, expected, descriptions...)) +} + +func assertStringContainsE(t *testing.T, actual string, expectedToContain string, descriptions ...string) { + errorOnNonEmpty(t, validateStringContains(actual, expectedToContain, descriptions...)) +} + +func assertHasPrefixE(t *testing.T, actual string, expectedPrefix string, descriptions ...string) { + errorOnNonEmpty(t, validateHasPrefix(actual, expectedPrefix, descriptions...)) +} + +func fatalOnNonEmpty(t *testing.T, errMsg string) { + if errMsg != "" { + t.Fatal(errMsg) + } +} + +func errorOnNonEmpty(t *testing.T, errMsg string) { + if errMsg != "" { + t.Error(errMsg) + } +} + +func validateNil(actual any, descriptions ...string) string { + if isNil(actual) { + return "" + } + desc := joinDescriptions(descriptions...) + return fmt.Sprintf("expected \"%s\" to be nil but was not. %s", actual, desc) +} + +func validateNotNil(actual any, descriptions ...string) string { + if !isNil(actual) { + return "" + } + desc := joinDescriptions(descriptions...) + return fmt.Sprintf("expected to be not nil but was not. %s", desc) +} + +func validateEqual(actual any, expected any, descriptions ...string) string { + if expected == actual { + return "" + } + desc := joinDescriptions(descriptions...) + return fmt.Sprintf("expected \"%s\" to be equal to \"%s\" but was not. %s", actual, expected, desc) +} + +func validateStringContains(actual string, expectedToContain string, descriptions ...string) string { + if strings.Contains(actual, expectedToContain) { + return "" + } + desc := joinDescriptions(descriptions...) + return fmt.Sprintf("expected \"%s\" to contain \"%s\" but did not. %s", actual, expectedToContain, desc) +} + +func validateHasPrefix(actual string, expectedPrefix string, descriptions ...string) string { + if strings.HasPrefix(actual, expectedPrefix) { + return "" + } + desc := joinDescriptions(descriptions...) + return fmt.Sprintf("expected \"%s\" to start with \"%s\" but did not. %s", actual, expectedPrefix, desc) +} + +func joinDescriptions(descriptions ...string) string { + return strings.Join(descriptions, " ") +} + +func isNil(value any) bool { + if value == nil { + return true + } + val := reflect.ValueOf(value) + return val.Kind() == reflect.Pointer && val.IsNil() +} diff --git a/client_configuration_test.go b/client_configuration_test.go index b7ecdd2f0..8da3be3ec 100644 --- a/client_configuration_test.go +++ b/client_configuration_test.go @@ -6,7 +6,6 @@ import ( "fmt" "os" "path" - "strings" "testing" ) @@ -59,15 +58,9 @@ func TestParseConfiguration(t *testing.T) { config, err := parseClientConfiguration(fileName) - if err != nil { - t.Fatalf("Error should be nil but was %s", err) - } - if config.Common.LogLevel != tc.expectedLogLevel { - t.Errorf("Log level should be %s but was %s", tc.expectedLogLevel, config.Common.LogLevel) - } - if config.Common.LogPath != tc.expectedLogPath { - t.Errorf("Log path should be %s but was %s", tc.expectedLogPath, config.Common.LogPath) - } + assertNilF(t, err, "parse client configuration error") + assertEqualE(t, config.Common.LogLevel, tc.expectedLogLevel, "log level") + assertEqualE(t, config.Common.LogPath, tc.expectedLogPath, "log path") }) } } @@ -86,12 +79,8 @@ func TestParseAllLogLevels(t *testing.T) { config, err := parseClientConfiguration(fileName) - if err != nil { - t.Fatalf("Error should be nil but was: %s", err) - } - if config.Common.LogLevel != logLevel { - t.Errorf("Log level should be %s but was %s", logLevel, config.Common.LogLevel) - } + assertNilF(t, err, "parse client config error") + assertEqualE(t, config.Common.LogLevel, logLevel, "log level") }) } } @@ -150,17 +139,11 @@ func TestParseConfigurationFails(t *testing.T) { _, err := parseClientConfiguration(fileName) - if err == nil { - t.Fatal("Error should not be nil but was nil") - } + assertNotNilF(t, err, "parse client configuration error") errMessage := fmt.Sprint(err) expectedPrefix := "parsing client config failed" - if !strings.HasPrefix(errMessage, expectedPrefix) { - t.Errorf("Error message: \"%s\" should start with prefix: \"%s\"", errMessage, expectedPrefix) - } - if !strings.Contains(errMessage, tc.expectedErrorMessageToContain) { - t.Errorf("Error message: \"%s\" should contain given phrase: \"%s\"", errMessage, tc.expectedErrorMessageToContain) - } + assertHasPrefixE(t, errMessage, expectedPrefix, "error message") + assertStringContainsE(t, errMessage, tc.expectedErrorMessageToContain, "error message") }) } } @@ -168,8 +151,6 @@ func TestParseConfigurationFails(t *testing.T) { func createFile(t *testing.T, fileName string, fileContents string, directory string) string { fullFileName := path.Join(directory, fileName) err := os.WriteFile(fullFileName, []byte(fileContents), 0644) - if err != nil { - t.Fatal("Could not create file") - } + assertNilF(t, err, "create file error") return fullFileName }