From f4fcf01a4755f50f3e2692829546ff9993a697ef Mon Sep 17 00:00:00 2001 From: Laurent Demailly Date: Mon, 6 Nov 2023 16:19:34 -0800 Subject: [PATCH] detect NUL in string and support []byte with base64 for binary --- env.go | 42 +++++++++++++++++++++++++++++++++--------- env_test.go | 32 +++++++++++++++++++++++++++----- 2 files changed, 60 insertions(+), 14 deletions(-) diff --git a/env.go b/env.go index 60aab0a..cc6b096 100644 --- a/env.go +++ b/env.go @@ -20,6 +20,7 @@ package struct2env import ( + "encoding/base64" "fmt" "os" "reflect" @@ -97,12 +98,16 @@ type KeyValue struct { } // Escape characters such as the result string can be embedded as a single argument in a shell fragment -// e.g for ENV_VAR= such as is safe (no $(cmd...) no ` etc`). -func ShellQuote(input string) string { +// e.g for ENV_VAR= such as is safe (no $(cmd...) no ` etc`). Will error out if NUL is found +// in the input (use []byte for that and it'll get base64 encoded/decoded). +func ShellQuote(input string) (string, error) { + if strings.ContainsRune(input, 0) { + return "", fmt.Errorf("String value %q should not contain NUL", input) + } // To emit a single quote in a single quote enclosed string you have to close the current ' then emit a quote (\'), // then reopen the single quote sequence to finish. Note that when the string ends with a quote there is an unnecessary // trailing ''. - return "'" + strings.ReplaceAll(input, "'", `'\''`) + "'" + return "'" + strings.ReplaceAll(input, "'", `'\''`) + "'", nil } func (kv KeyValue) String() string { @@ -129,14 +134,16 @@ func ToShellWithPrefix(prefix string, kvl []KeyValue) string { return sb.String() } -func SerializeValue(value interface{}) string { +func SerializeValue(value interface{}) (string, error) { switch v := value.(type) { case bool: res := "false" if v { res = "true" } - return res + return res, nil + case []byte: + return ShellQuote(base64.StdEncoding.EncodeToString(v)) case string: return ShellQuote(v) default: @@ -186,24 +193,33 @@ func structToEnvVars(envVars []KeyValue, allErrors []error, prefix string, s int } fieldValue := v.Field(i) stringValue := "" + var err error switch fieldValue.Kind() { //nolint: exhaustive // we have default: for the other cases case reflect.Ptr: if !fieldValue.IsNil() { fieldValue = fieldValue.Elem() - stringValue = SerializeValue(fieldValue.Interface()) + stringValue, err = SerializeValue(fieldValue.Interface()) } case reflect.Map, reflect.Array, reflect.Chan, reflect.Slice: - // log.LogVf("Skipping field %s of type %v, not supported", fieldType.Name, fieldType.Type) - continue + // From that list of other types, only support []byte + if fieldValue.Type().Elem().Kind() == reflect.Uint8 { + stringValue, err = SerializeValue(fieldValue.Interface()) + } else { + // log.LogVf("Skipping field %s of type %v, not supported", fieldType.Name, fieldType.Type) + continue + } case reflect.Struct: // Recurse with prefix envVars, allErrors = structToEnvVars(envVars, allErrors, tag+"_", fieldValue.Interface()) continue default: value := fieldValue.Interface() - stringValue = SerializeValue(value) + stringValue, err = SerializeValue(value) } envVars = append(envVars, KeyValue{Key: prefix + tag, QuotedValue: stringValue}) + if err != nil { + allErrors = append(allErrors, err) + } } return envVars, allErrors } @@ -309,6 +325,14 @@ func setFromEnv(allErrors []error, prefix string, s interface{}) []error { if err == nil { fieldValue.SetBool(ev) } + case reflect.Slice: + if fieldValue.Type().Elem().Kind() != reflect.Uint8 { + err = fmt.Errorf("unsupported slice of %v to set from %s=%q", fieldValue.Type().Elem().Kind(), envName, envVal) + } else { + var data []byte + data, err = base64.StdEncoding.DecodeString(envVal) + fieldValue.SetBytes(data) + } default: err = fmt.Errorf("unsupported type %v to set from %s=%q", kind, envName, envVal) } diff --git a/env_test.go b/env_test.go index 4e53418..fb92d0d 100644 --- a/env_test.go +++ b/env_test.go @@ -118,12 +118,13 @@ type FooConfig struct { Embedded HiddenEmbedded `env:"-"` RecurseHere Embedded + SomeBinary []byte } func TestStructToEnvVars(t *testing.T) { intV := 199 foo := FooConfig{ - Foo: "a\nfoo with $X, `backticks`, \" quotes and \\ and ' in middle and end '", + Foo: "a newline:\nfoo with $X, `backticks`, \" quotes and \\ and ' in middle and end '", Bar: "42str", Blah: 42, ABool: true, @@ -135,6 +136,7 @@ func TestStructToEnvVars(t *testing.T) { InnerA: "rec a", InnerB: "rec b", }, + SomeBinary: []byte{0, 1, 2}, } foo.InnerA = "inner a" foo.InnerB = "inner b" @@ -149,12 +151,12 @@ func TestStructToEnvVars(t *testing.T) { if len(errors) != 0 { t.Errorf("expected no error, got %v", errors) } - if len(envVars) != 11 { - t.Errorf("expected 11 env vars, got %d: %+v", len(envVars), envVars) + if len(envVars) != 12 { + t.Errorf("expected 12 env vars, got %d: %+v", len(envVars), envVars) } str := ToShellWithPrefix("TST_", envVars) //nolint:lll - expected := `TST_FOO='a + expected := `TST_FOO='a newline: foo with $X, ` + "`backticks`" + `, " quotes and \ and '\'' in middle and end '\''' TST_BAR='42str' TST_A_SPECIAL_BLAH='42' @@ -166,11 +168,27 @@ TST_INNER_A='inner a' TST_INNER_B='inner b' TST_RECURSE_HERE_INNER_A='rec a' TST_RECURSE_HERE_INNER_B='rec b' -export TST_FOO TST_BAR TST_A_SPECIAL_BLAH TST_A_BOOL TST_HTTP_SERVER TST_INT_POINTER TST_FLOAT_POINTER TST_INNER_A TST_INNER_B TST_RECURSE_HERE_INNER_A TST_RECURSE_HERE_INNER_B +TST_SOME_BINARY='AAEC' +export TST_FOO TST_BAR TST_A_SPECIAL_BLAH TST_A_BOOL TST_HTTP_SERVER TST_INT_POINTER TST_FLOAT_POINTER TST_INNER_A TST_INNER_B TST_RECURSE_HERE_INNER_A TST_RECURSE_HERE_INNER_B TST_SOME_BINARY ` if str != expected { t.Errorf("\n---expected:---\n%s\n---got:---\n%s", expected, str) } + // NUL in string + type Cfg struct { + Foo string + } + cfg := Cfg{Foo: "ABC\x00DEF"} + envVars, errors = StructToEnvVars(&cfg) + if len(errors) != 1 { + t.Errorf("Should have had error with embedded NUL") + } + if envVars[0].Key != "FOO" { + t.Errorf("Expecting key to be present %v", envVars) + } + if envVars[0].QuotedValue != "" { + t.Errorf("Expecting value to be empty %v", envVars) + } } func TestSetFromEnv(t *testing.T) { @@ -186,6 +204,7 @@ func TestSetFromEnv(t *testing.T) { {"TST2_A_BOOL", "1"}, {"TST2_FLOAT_POINTER", "5.75"}, {"TST2_INT_POINTER", "73"}, + {"TST2_SOME_BINARY", "QUJDAERFRg=="}, } for _, e := range envs { os.Setenv(e.k, e.v) @@ -199,4 +218,7 @@ func TestSetFromEnv(t *testing.T) { foo.IntPointer == nil || *foo.IntPointer != 73 { t.Errorf("Mismatch in object values, got: %+v", foo) } + if string(foo.SomeBinary) != "ABC\x00DEF" { + t.Errorf("Base64 decoding not working for []byte field: %q", string(foo.SomeBinary)) + } }