diff --git a/experimental/plugins/macro/macro.go b/experimental/plugins/macro/macro.go index ec756f0d..fabd039d 100644 --- a/experimental/plugins/macro/macro.go +++ b/experimental/plugins/macro/macro.go @@ -99,38 +99,38 @@ func (m *macro) compile(input string) error { return fmt.Errorf("empty macro") } - currentToken := strings.Builder{} m.original = input + var currentToken strings.Builder isMacro := false + for i := 0; i < l; i++ { c := input[i] - if c == '%' && (i <= l && input[i+1] == '{') { - // we have a macro + + if c == '%' && i+1 < l && input[i+1] == '{' { if currentToken.Len() > 0 { - // we add the text token m.tokens = append(m.tokens, macroToken{ text: currentToken.String(), variable: variables.Unknown, key: "", }) + currentToken.Reset() } - currentToken.Reset() isMacro = true - i++ + i++ // Skip '{' continue } if isMacro { if c == '}' { - // we close a macro isMacro = false - // TODO(jcchavezs): key should only be empty in single collections + if input[i-1] == '.' { + return fmt.Errorf("empty variable name") + } varName, key, _ := strings.Cut(currentToken.String(), ".") v, err := variables.Parse(varName) if err != nil { return fmt.Errorf("unknown variable %q", varName) } - // we add the variable token m.tokens = append(m.tokens, macroToken{ text: currentToken.String(), variable: v, @@ -140,8 +140,7 @@ func (m *macro) compile(input string) error { continue } - if !(c == '.' || c == '_' || c == '-' || (c >= '0' && c <= '9') || (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z')) { - currentToken.WriteByte(c) + if !isValidMacroChar(c) { return fmt.Errorf("malformed variable starting with %q", "%{"+currentToken.String()) } @@ -152,10 +151,10 @@ func (m *macro) compile(input string) error { } continue } - // we have a normal character + currentToken.WriteByte(c) } - // if there is something left + if currentToken.Len() > 0 { m.tokens = append(m.tokens, macroToken{ text: currentToken.String(), @@ -166,6 +165,10 @@ func (m *macro) compile(input string) error { return nil } +func isValidMacroChar(c byte) bool { + return c == '.' || c == '_' || c == '-' || (c >= '0' && c <= '9') || (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') +} + // String returns the original string func (m *macro) String() string { return m.original diff --git a/experimental/plugins/macro/macro_test.go b/experimental/plugins/macro/macro_test.go index 94e9579b..7ea662df 100644 --- a/experimental/plugins/macro/macro_test.go +++ b/experimental/plugins/macro/macro_test.go @@ -36,7 +36,39 @@ func TestCompile(t *testing.T) { } }) - t.Run("malformed macro", func(t *testing.T) { + t.Run("single percent sign", func(t *testing.T) { + m := ¯o{} + err := m.compile("%") + if err != nil { + t.Errorf("single percent sign should not error") + } + }) + + t.Run("empty braces", func(t *testing.T) { + m := ¯o{} + err := m.compile("%{}") + if err == nil { + t.Errorf("expected error for empty braces") + } + }) + + t.Run("missing key", func(t *testing.T) { + m := ¯o{} + err := m.compile("%{tx.}") + if err == nil { + t.Errorf("expected error for missing key") + } + }) + + t.Run("missing collection", func(t *testing.T) { + m := ¯o{} + err := m.compile("%{.key}") + if err == nil { + t.Errorf("expected error for missing collection") + } + }) + + t.Run("malformed macros", func(t *testing.T) { for _, test := range []string{"%{tx.count", "%{{tx.count}", "%{{tx.{count}", "something %{tx.count"} { t.Run(test, func(t *testing.T) { m := ¯o{}