diff --git a/hcl2template/utils.go b/hcl2template/utils.go index a542a2467ac..ce9486ed676 100644 --- a/hcl2template/utils.go +++ b/hcl2template/utils.go @@ -11,6 +11,7 @@ import ( "github.com/gobwas/glob" "github.com/hashicorp/hcl/v2" + "github.com/hashicorp/hcl/v2/hclsyntax" "github.com/hashicorp/packer/hcl2template/repl" hcl2shim "github.com/hashicorp/packer/hcl2template/shim" "github.com/zclconf/go-cty/cty" @@ -187,22 +188,47 @@ func ConvertPluginConfigValueToHCLValue(v interface{}) (cty.Value, error) { return buildValue, nil } +// GetVarsByType walks through a hcl body, and gathers all the Traversals that +// have a root type matching one of the specified top-level labels. +// +// This will only work on finite, expanded, HCL bodies. func GetVarsByType(block *hcl.Block, topLevelLabels ...string) []hcl.Traversal { - attributes, _ := block.Body.JustAttributes() - - var vars []hcl.Traversal - - for _, attr := range attributes { - for _, variable := range attr.Expr.Variables() { - rootLabel := variable.RootName() - for _, label := range topLevelLabels { - if label == rootLabel { - vars = append(vars, variable) - break - } + var travs []hcl.Traversal + + switch body := block.Body.(type) { + case *hclsyntax.Body: + travs = getVarsByTypeForHCLSyntaxBody(body) + default: + attrs, _ := body.JustAttributes() + for _, attr := range attrs { + travs = append(travs, attr.Expr.Variables()...) + } + } + + var rets []hcl.Traversal + for _, t := range travs { + varRootname := t.RootName() + for _, lbl := range topLevelLabels { + if varRootname == lbl { + rets = append(rets, t) + break } } } - return vars + return rets +} + +func getVarsByTypeForHCLSyntaxBody(body *hclsyntax.Body) []hcl.Traversal { + var rets []hcl.Traversal + + for _, attr := range body.Attributes { + rets = append(rets, attr.Expr.Variables()...) + } + + for _, block := range body.Blocks { + rets = append(rets, getVarsByTypeForHCLSyntaxBody(block.Body)...) + } + + return rets }