Skip to content

Commit

Permalink
detect NUL in string and support []byte with base64 for binary (#8)
Browse files Browse the repository at this point in the history
* detect NUL in string and support []byte with base64 for binary

* add support for Duration and Time (in export, still to do the reverse). Make the SetFromEnv pluggable

* switch how to hook different lookup to simple function instead of interface, also fix bug found when using the mock in test (recurse was using the os version)

* implement reverse env to struct for duration in floating point seconds and date/time in rfc3339

* split setting of leaf value into function

* update readme and package header
  • Loading branch information
ldemailly authored Nov 7, 2023
1 parent 752d976 commit 95f1f47
Show file tree
Hide file tree
Showing 3 changed files with 203 additions and 71 deletions.
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,10 @@ Using
kv, errs := struct2env.StructToEnvVars(foo)
txt := struct2env.ToShellWithPrefix("TST_", kv)
```

Type conversions:

- Most primitive type to their string representation, single quote (') escaped.
- []byte are encoded as base64
- time.Time are formatted as RFC3339
- time.Duration are in (floating point) seconds.
174 changes: 128 additions & 46 deletions env.go
Original file line number Diff line number Diff line change
@@ -1,30 +1,30 @@
// Package env provides conversion from structure to and from environment variables.
//
// It supports converting struct fields to environment variables using field tags,
// handling different data types, and transforming strings between different case
// conventions, which is useful for generating or parsing environment variables,
// JSON tags, or command line flags.
// Supports converting struct fields to environment variables using field tags,
// handling most data types. Provides functions to serialize structs into slices
// of key-value pairs where the keys are derived from struct field names transformed
// to upper snake case by default, or specified explicitly via struct field tags.
//
// The package also defines several case conversion functions that aid in manipulating
// strings to fit conventional casing for various programming and configuration contexts.
// Additionally, it provides functions to serialize structs into slices of key-value pairs
// where the keys are derived from struct field names transformed to upper snake case by default,
// or specified explicitly via struct field tags.
//
// It also includes functionality to deserialize environment variables back into
// Includes functionality to deserialize environment variables back into
// struct fields, handling pointers and nested structs appropriately, as well as providing
// shell-compatible output for environment variable definitions.
//
// Incidentally the package also defines several case conversion functions that aid in manipulating
// which is useful for generating or parsing environment variables,
// JSON tags, or command line flags style of naming (camelCase, UPPER_SNAKE_CASE, lower-kebab-case ...)
//
// The package leverages reflection to dynamically handle arbitrary struct types,
// and has 0 dependencies.
package struct2env

import (
"encoding/base64"
"fmt"
"os"
"reflect"
"strconv"
"strings"
"time"
"unicode"
)

Expand Down Expand Up @@ -97,12 +97,16 @@ type KeyValue struct {
}

// Escape characters such as the result string can be embedded as a single argument in a shell fragment
// e.g for ENV_VAR=<value> such as <value> is safe (no $(cmd...) no ` etc`).
func ShellQuote(input string) string {
// e.g for ENV_VAR=<value> such as <value> is safe (no $(cmd...) no ` etc`). Will error out if NUL is found
// in the input (use []byte for that and it'll get base64 encoded/decoded).
func ShellQuote(input string) (string, error) {
if strings.ContainsRune(input, 0) {
return "", fmt.Errorf("String value %q should not contain NUL", input)
}
// To emit a single quote in a single quote enclosed string you have to close the current ' then emit a quote (\'),
// then reopen the single quote sequence to finish. Note that when the string ends with a quote there is an unnecessary
// trailing ''.
return "'" + strings.ReplaceAll(input, "'", `'\''`) + "'"
return "'" + strings.ReplaceAll(input, "'", `'\''`) + "'", nil
}

func (kv KeyValue) String() string {
Expand All @@ -129,16 +133,20 @@ func ToShellWithPrefix(prefix string, kvl []KeyValue) string {
return sb.String()
}

func SerializeValue(value interface{}) string {
func SerializeValue(value interface{}) (string, error) {
switch v := value.(type) {
case bool:
res := "false"
if v {
res = "true"
}
return res
return res, nil
case []byte:
return ShellQuote(base64.StdEncoding.EncodeToString(v))
case string:
return ShellQuote(v)
case time.Duration:
return fmt.Sprintf("%g", v.Seconds()), nil
default:
return ShellQuote(fmt.Sprint(value))
}
Expand All @@ -151,6 +159,7 @@ func SerializeValue(value interface{}) string {
// If the field is exportable and the tag is missing we'll use the field name
// converted to UPPER_SNAKE_CASE (using CamelCaseToUpperSnakeCase()) as the
// environment variable name.
// []byte are encoded as base64, time.Time are formatted as RFC3339, time.Duration are in (floating point) seconds.
func StructToEnvVars(s interface{}) ([]KeyValue, []error) {
var allErrors []error
var allKeyValVals []KeyValue
Expand Down Expand Up @@ -186,24 +195,49 @@ func structToEnvVars(envVars []KeyValue, allErrors []error, prefix string, s int
}
fieldValue := v.Field(i)
stringValue := ""
var err error

if fieldValue.Type() == reflect.TypeOf(time.Time{}) { // other wise we hit the "struct" case below
timeField := fieldValue.Interface().(time.Time)
stringValue, err = SerializeValue(timeField.Format(time.RFC3339))
if err != nil {
allErrors = append(allErrors, err)
} else {
envVars = append(envVars, KeyValue{Key: prefix + tag, QuotedValue: stringValue})
}
continue // Continue to the next field
}

switch fieldValue.Kind() { //nolint: exhaustive // we have default: for the other cases
case reflect.Ptr:
if !fieldValue.IsNil() {
fieldValue = fieldValue.Elem()
stringValue = SerializeValue(fieldValue.Interface())
stringValue, err = SerializeValue(fieldValue.Interface())
}
case reflect.Map, reflect.Array, reflect.Chan, reflect.Slice:
// log.LogVf("Skipping field %s of type %v, not supported", fieldType.Name, fieldType.Type)
continue
// From that list of other types, only support []byte
if fieldValue.Type().Elem().Kind() == reflect.Uint8 {
stringValue, err = SerializeValue(fieldValue.Interface())
} else {
// log.LogVf("Skipping field %s of type %v, not supported", fieldType.Name, fieldType.Type)
continue
}
case reflect.Struct:
// Recurse with prefix
envVars, allErrors = structToEnvVars(envVars, allErrors, tag+"_", fieldValue.Interface())
continue
default:
value := fieldValue.Interface()
stringValue = SerializeValue(value)
if !fieldValue.CanInterface() {
err = fmt.Errorf("can't interface %s", fieldType.Name)
} else {
value := fieldValue.Interface()
stringValue, err = SerializeValue(value)
}
}
envVars = append(envVars, KeyValue{Key: prefix + tag, QuotedValue: stringValue})
if err != nil {
allErrors = append(allErrors, err)
}
}
return envVars, allErrors
}
Expand All @@ -217,8 +251,8 @@ func setPointer(fieldValue reflect.Value) reflect.Value {
return fieldValue.Elem()
}

func checkEnv(envName, fieldName string, fieldValue reflect.Value) (*string, error) {
val, found := os.LookupEnv(envName)
func checkEnv(envLookup EnvLookup, envName, fieldName string, fieldValue reflect.Value) (*string, error) {
val, found := envLookup(envName)
if !found {
// log.LogVf("%q not set for %s", envName, fieldName)
return nil, nil //nolint:nilnil
Expand All @@ -231,11 +265,19 @@ func checkEnv(envName, fieldName string, fieldValue reflect.Value) (*string, err
return &val, nil
}

type EnvLookup func(key string) (string, bool)

// Reverse of StructToEnvVars, assumes the same encoding. Using the current os environment variables as source.
func SetFromEnv(prefix string, s interface{}) []error {
return setFromEnv(nil, prefix, s)
return SetFrom(os.LookupEnv, prefix, s)
}

func setFromEnv(allErrors []error, prefix string, s interface{}) []error {
// Reverse of StructToEnvVars, assumes the same encoding. Using passed it lookup object that can lookup values by keys.
func SetFrom(envLookup EnvLookup, prefix string, s interface{}) []error {
return setFromEnv(nil, envLookup, prefix, s)
}

func setFromEnv(allErrors []error, envLookup EnvLookup, prefix string, s interface{}) []error {
// TODO: this is quite similar in structure to structToEnvVars() - can it be refactored with
// passing setter vs getter function and share the same iteration (yet a little bit of copy is the go way too)
v := reflect.ValueOf(s)
Expand Down Expand Up @@ -263,17 +305,18 @@ func setFromEnv(allErrors []error, prefix string, s interface{}) []error {

kind := fieldValue.Kind()

if kind == reflect.Struct {
// Handle time.Time separately a bit below after we get the value
if kind == reflect.Struct && fieldType.Type != reflect.TypeOf(time.Time{}) {
// Recurse with prefix
if fieldValue.CanAddr() { // Check if we can get the address
SetFromEnv(envName+"_", fieldValue.Addr().Interface())
allErrors = setFromEnv(allErrors, envLookup, envName+"_", fieldValue.Addr().Interface())
} else {
err := fmt.Errorf("cannot take the address of %s to recurse", fieldType.Name)
allErrors = append(allErrors, err)
}
continue
}
val, err := checkEnv(envName, fieldType.Name, fieldValue)
val, err := checkEnv(envLookup, envName, fieldType.Name, fieldValue)
if err != nil {
allErrors = append(allErrors, err)
continue
Expand All @@ -288,33 +331,72 @@ func setFromEnv(allErrors []error, prefix string, s interface{}) []error {
kind = fieldValue.Type().Elem().Kind()
fieldValue = setPointer(fieldValue)
}
switch kind { //nolint: exhaustive // we have default: for the other cases
case reflect.String:
fieldValue.SetString(envVal)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
var ev int64
ev, err = strconv.ParseInt(envVal, 10, fieldValue.Type().Bits())
if fieldType.Type == reflect.TypeOf(time.Time{}) {
var timeField time.Time
timeField, err = time.Parse(time.RFC3339, envVal)
if err == nil {
fieldValue.SetInt(ev)
fieldValue.Set(reflect.ValueOf(timeField))
} else {
allErrors = append(allErrors, err)
}
case reflect.Float32, reflect.Float64:
continue
}
allErrors = setValue(allErrors, fieldType, fieldValue, kind, envName, envVal)
}
return allErrors
}

func setValue(
allErrors []error,
fieldType reflect.StructField,
fieldValue reflect.Value,
kind reflect.Kind,
envName, envVal string,
) []error {
var err error
switch kind { //nolint: exhaustive // we have default: for the other cases
case reflect.String:
fieldValue.SetString(envVal)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
// if it's a duration, parse it as a float seconds
if fieldType.Type == reflect.TypeOf(time.Duration(0)) {
var ev float64
ev, err = strconv.ParseFloat(envVal, fieldValue.Type().Bits())
ev, err = strconv.ParseFloat(envVal, 64)
if err == nil {
fieldValue.SetFloat(ev)
fieldValue.SetInt(int64(ev * float64(1*time.Second)))
}
case reflect.Bool:
var ev bool
ev, err = strconv.ParseBool(envVal)
} else {
var ev int64
ev, err = strconv.ParseInt(envVal, 10, fieldValue.Type().Bits())
if err == nil {
fieldValue.SetBool(ev)
fieldValue.SetInt(ev)
}
default:
err = fmt.Errorf("unsupported type %v to set from %s=%q", kind, envName, envVal)
}
if err != nil {
allErrors = append(allErrors, err)
case reflect.Float32, reflect.Float64:
var ev float64
ev, err = strconv.ParseFloat(envVal, fieldValue.Type().Bits())
if err == nil {
fieldValue.SetFloat(ev)
}
case reflect.Bool:
var ev bool
ev, err = strconv.ParseBool(envVal)
if err == nil {
fieldValue.SetBool(ev)
}
case reflect.Slice:
if fieldValue.Type().Elem().Kind() != reflect.Uint8 {
err = fmt.Errorf("unsupported slice of %v to set from %s=%q", fieldValue.Type().Elem().Kind(), envName, envVal)
} else {
var data []byte
data, err = base64.StdEncoding.DecodeString(envVal)
fieldValue.SetBytes(data)
}
default:
err = fmt.Errorf("unsupported type %v to set from %s=%q", kind, envName, envVal)
}
if err != nil {
allErrors = append(allErrors, err)
}
return allErrors
}
Loading

0 comments on commit 95f1f47

Please sign in to comment.