Skip to content
This repository has been archived by the owner on May 4, 2023. It is now read-only.

Commit

Permalink
This commit fixes the execute variable in macros (#47)
Browse files Browse the repository at this point in the history
This allows macros during the initial compile phase not to do anything
but during a run phase to execute remote calls.
  • Loading branch information
DomBlack authored Aug 19, 2021
1 parent e63373e commit 96c2ee6
Show file tree
Hide file tree
Showing 8 changed files with 133 additions and 53 deletions.
4 changes: 4 additions & 0 deletions compiler/GlobalContext.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,11 @@ func (g *GlobalContext) GetMacro(name string) (compilerInterface.FunctionDef, er
newEC := ec.PushState()
// Note we copy any varaibles defined within the macro's own file in to the context being executed here too
macro.ec.CopyVariablesInto(newEC)

// We keep the caller and execute context however as these will change from when the macro was registered to when
// it is called
newEC.SetVariable("caller", ec.GetVariable("caller"))
newEC.SetVariable("execute", ec.GetVariable("execute"))

return macro.function(newEC, caller, args)
}, nil
Expand Down
74 changes: 65 additions & 9 deletions compiler/builtInFunctions.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"os"
"strings"

"ddbt/bigquery"
"ddbt/compiler/dbtUtils"
"ddbt/compilerInterface"
"ddbt/fs"
"ddbt/utils"
Expand Down Expand Up @@ -163,6 +165,15 @@ var builtInFunctions = map[string]compilerInterface.FunctionDef{
},

// Jinja2 Filter functions
"upper": func(ec compilerInterface.ExecutionContext, caller compilerInterface.AST, args compilerInterface.Arguments) (*compilerInterface.Value, error) {
values, err := requiredArgs(ec, caller, args, "upper", compilerInterface.StringVal)
if err != nil {
return nil, err
}

return compilerInterface.NewString(strings.ToUpper(values[0].AsStringValue())), nil
},

"lower": func(ec compilerInterface.ExecutionContext, caller compilerInterface.AST, args compilerInterface.Arguments) (*compilerInterface.Value, error) {
values, err := requiredArgs(ec, caller, args, "lower", compilerInterface.StringVal)
if err != nil {
Expand Down Expand Up @@ -226,15 +237,60 @@ var adapterFunctions = map[string]compilerInterface.FunctionDef{
"dispatch": noopMethod(),
"get_missing_columns": noopMethod(),
"expand_target_column_types": noopMethod(),
"get_relation": noopMethod(),
"get_columns_in_relation": noopMethod(),
"create_schema": noopMethod(),
"drop_schema": noopMethod(),
"drop_relation": noopMethod(),
"rename_relation": noopMethod(),
"get_columns_in_table": noopMethod(),
"already_exists": noopMethod(),
"adapter_macro": noopMethod(),
"get_relation": func(ec compilerInterface.ExecutionContext, caller compilerInterface.AST, args compilerInterface.Arguments) (*compilerInterface.Value, error) {
arguments, err := dbtUtils.GetArgs(args, dbtUtils.Param("database"), dbtUtils.Param("schema"), dbtUtils.Param("identifier"))
if err != nil {
return nil, ec.ErrorAt(caller, fmt.Sprintf("%s", err))
}

projectID := arguments[0].AsStringValue()
dataSet := arguments[1].AsStringValue()
table := arguments[2].AsStringValue()

return compilerInterface.NewString(
"`" + projectID + "`.`" + dataSet + "`.`" + table + "`",
), nil
},
"get_columns_in_relation": func(ec compilerInterface.ExecutionContext, caller compilerInterface.AST, args compilerInterface.Arguments) (*compilerInterface.Value, error) {
values, err := requiredArgs(ec, caller, args, "adapter.get_columns_in_relation", compilerInterface.StringVal)
if err != nil {
return nil, err
}

if isOnlyCompilingSQL(ec) {
return ec.MarkAsDynamicSQL()
}

returnColumns := make([]*compilerInterface.Value, 0)

target, err := ec.GetTarget()
if err != nil {
return nil, ec.ErrorAt(caller, fmt.Sprintf("Unable to get the columns in relation: %s", err.Error()))
}

columns, err := bigquery.GetColumnsFromTable(values[0].AsStringValue(), target)
if err != nil {
return nil, ec.ErrorAt(caller, fmt.Sprintf("Unable to get the columns in relation: %s", err.Error()))
}

for _, column := range columns {
columnMap := compilerInterface.NewMap(map[string]*compilerInterface.Value{
"name": compilerInterface.NewString(column.Name),
"column": compilerInterface.NewString(column.Name),
"data_type": compilerInterface.NewString(string(column.Type)),
})
returnColumns = append(returnColumns, columnMap)
}

return compilerInterface.NewList(returnColumns), nil
},
"create_schema": noopMethod(),
"drop_schema": noopMethod(),
"drop_relation": noopMethod(),
"rename_relation": noopMethod(),
"get_columns_in_table": noopMethod(),
"already_exists": noopMethod(),
"adapter_macro": noopMethod(),

// Note listed on their site
"check_schema_exists": noopMethod(),
Expand Down
25 changes: 13 additions & 12 deletions compiler/dbtUtils/queryMacros.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ package dbtUtils

import (
"context"
"ddbt/bigquery"
"ddbt/compilerInterface"
"fmt"
"strconv"
"strings"

"ddbt/bigquery"
"ddbt/compilerInterface"
)

// GetColumnValues is a fallback GetColumnValuesWithContext
Expand All @@ -16,11 +17,11 @@ func GetColumnValues(ec compilerInterface.ExecutionContext, caller compilerInter
}

func GetColumnValuesWithContext(ctx context.Context, ec compilerInterface.ExecutionContext, caller compilerInterface.AST, arguments compilerInterface.Arguments) (*compilerInterface.Value, error) {
if isOnlyCompilingSQL(ec) {
if IsOnlyCompilingSQL(ec) {
return ec.MarkAsDynamicSQL()
}

args, err := getArgs(arguments, param("table"), param("column"), param("max_records"))
args, err := GetArgs(arguments, Param("table"), Param("column"), Param("max_records"))
if err != nil {
return nil, ec.ErrorAt(caller, fmt.Sprintf("%s", err))
}
Expand Down Expand Up @@ -66,17 +67,17 @@ func GetColumnValuesWithContext(ctx context.Context, ec compilerInterface.Execut
}

func Unpivot(ec compilerInterface.ExecutionContext, caller compilerInterface.AST, arguments compilerInterface.Arguments) (*compilerInterface.Value, error) {
if isOnlyCompilingSQL(ec) {
if IsOnlyCompilingSQL(ec) {
return ec.MarkAsDynamicSQL()
}

args, err := getArgs(arguments,
paramWithDefault("table", compilerInterface.NewString("")),
paramWithDefault("cast_to", compilerInterface.NewString("varchar")),
paramWithDefault("exclude", compilerInterface.NewList(make([]*compilerInterface.Value, 0))),
paramWithDefault("remove", compilerInterface.NewList(make([]*compilerInterface.Value, 0))),
paramWithDefault("field_name", compilerInterface.NewString("field_name")),
paramWithDefault("value_name", compilerInterface.NewString("value_name")),
args, err := GetArgs(arguments,
ParamWithDefault("table", compilerInterface.NewString("")),
ParamWithDefault("cast_to", compilerInterface.NewString("varchar")),
ParamWithDefault("exclude", compilerInterface.NewList(make([]*compilerInterface.Value, 0))),
ParamWithDefault("remove", compilerInterface.NewList(make([]*compilerInterface.Value, 0))),
ParamWithDefault("field_name", compilerInterface.NewString("field_name")),
ParamWithDefault("value_name", compilerInterface.NewString("value_name")),
)
if err != nil {
return nil, ec.ErrorAt(caller, fmt.Sprintf("%s", err))
Expand Down
26 changes: 13 additions & 13 deletions compiler/dbtUtils/replacements.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
)

func UnionAllTables(ec compilerInterface.ExecutionContext, caller compilerInterface.AST, arguments compilerInterface.Arguments) (*compilerInterface.Value, error) {
args, err := getArgs(arguments, param("tables"), param("column_names"))
args, err := GetArgs(arguments, Param("tables"), Param("column_names"))
if err != nil {
return nil, ec.ErrorAt(caller, fmt.Sprintf("%s", err))
}
Expand Down Expand Up @@ -45,7 +45,7 @@ func UnionAllTables(ec compilerInterface.ExecutionContext, caller compilerInterf
}

func GroupBy(ec compilerInterface.ExecutionContext, caller compilerInterface.AST, arguments compilerInterface.Arguments) (*compilerInterface.Value, error) {
args, err := getArgs(arguments, paramWithDefault("n", compilerInterface.NewNumber(0)))
args, err := GetArgs(arguments, ParamWithDefault("n", compilerInterface.NewNumber(0)))
if err != nil {
return nil, ec.ErrorAt(caller, fmt.Sprintf("%s", err))
}
Expand All @@ -68,17 +68,17 @@ func GroupBy(ec compilerInterface.ExecutionContext, caller compilerInterface.AST
}

func Pivot(ec compilerInterface.ExecutionContext, caller compilerInterface.AST, arguments compilerInterface.Arguments) (*compilerInterface.Value, error) {
args, err := getArgs(arguments,
paramWithDefault("column", compilerInterface.NewString("")),
paramWithDefault("values", compilerInterface.NewList(make([]*compilerInterface.Value, 0))),
paramWithDefault("alias", compilerInterface.NewBoolean(true)),
paramWithDefault("agg", compilerInterface.NewString("sum")),
paramWithDefault("cmp", compilerInterface.NewString("=")),
paramWithDefault("prefix", compilerInterface.NewString("")),
paramWithDefault("suffix", compilerInterface.NewString("")),
param("then_value"),
param("else_value"),
paramWithDefault("quote_identifiers", compilerInterface.NewBoolean(true)),
args, err := GetArgs(arguments,
ParamWithDefault("column", compilerInterface.NewString("")),
ParamWithDefault("values", compilerInterface.NewList(make([]*compilerInterface.Value, 0))),
ParamWithDefault("alias", compilerInterface.NewBoolean(true)),
ParamWithDefault("agg", compilerInterface.NewString("sum")),
ParamWithDefault("cmp", compilerInterface.NewString("=")),
ParamWithDefault("prefix", compilerInterface.NewString("")),
ParamWithDefault("suffix", compilerInterface.NewString("")),
Param("then_value"),
Param("else_value"),
ParamWithDefault("quote_identifiers", compilerInterface.NewBoolean(true)),
)
if err != nil {
return nil, ec.ErrorAt(caller, fmt.Sprintf("%s", err))
Expand Down
13 changes: 7 additions & 6 deletions compiler/dbtUtils/utils.go
Original file line number Diff line number Diff line change
@@ -1,22 +1,23 @@
package dbtUtils

import (
"ddbt/compilerInterface"
"fmt"

"ddbt/compilerInterface"
)

func param(name string) compilerInterface.Argument {
return paramWithDefault(name, nil)
func Param(name string) compilerInterface.Argument {
return ParamWithDefault(name, nil)
}

func paramWithDefault(name string, value *compilerInterface.Value) compilerInterface.Argument {
func ParamWithDefault(name string, value *compilerInterface.Value) compilerInterface.Argument {
return compilerInterface.Argument{
Name: name,
Value: value,
}
}

func getArgs(arguments compilerInterface.Arguments, params ...compilerInterface.Argument) ([]*compilerInterface.Value, error) {
func GetArgs(arguments compilerInterface.Arguments, params ...compilerInterface.Argument) ([]*compilerInterface.Value, error) {
args := make([]*compilerInterface.Value, len(params))

// quick lookup map
Expand Down Expand Up @@ -62,7 +63,7 @@ func getArgs(arguments compilerInterface.Arguments, params ...compilerInterface.
return args, nil
}

func isOnlyCompilingSQL(ec compilerInterface.ExecutionContext) bool {
func IsOnlyCompilingSQL(ec compilerInterface.ExecutionContext) bool {
value := ec.GetVariable("execute")

if value.Type() == compilerInterface.BooleanValue {
Expand Down
40 changes: 29 additions & 11 deletions compilerInterface/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"reflect"
"strconv"
"strings"

"ddbt/jinja/lexer"
)
Expand Down Expand Up @@ -127,26 +128,43 @@ func (v *Value) Type() ValueType {
}
}

func (v *Value) Properties() map[string]*Value {
func (v *Value) Properties(isForFunctionCall bool) map[string]*Value {
switch v.Type() {
case MapVal:
return v.MapValue

case ListVal:
return map[string]*Value{
"items": NewFunction(func(_ ExecutionContext, _ AST, _ Arguments) (*Value, error) { return v, nil }),
"extend": NewFunction(func(_ ExecutionContext, _ AST, args Arguments) (*Value, error) {
for _, arg := range args[0].Value.ListValue {
if arg != v {
v.ListValue = append(v.ListValue, arg)
}
extendFunc := NewFunction(func(_ ExecutionContext, _ AST, args Arguments) (*Value, error) {
for _, arg := range args[0].Value.ListValue {
if arg != v {
v.ListValue = append(v.ListValue, arg)
}
return v, nil
}),
}
return v, nil
})

return map[string]*Value{
"items": NewFunction(func(_ ExecutionContext, _ AST, _ Arguments) (*Value, error) { return v, nil }),
"extend": extendFunc,
"append": extendFunc,
}

case ReturnVal:
return v.ReturnValue.Properties()
return v.ReturnValue.Properties(isForFunctionCall)

case StringVal:
if !isForFunctionCall {
return nil
}

return map[string]*Value{
"upper": NewFunction(func(_ ExecutionContext, _ AST, _ Arguments) (*Value, error) {
return NewString(strings.ToUpper(v.StringValue)), nil
}),
"lower": NewFunction(func(_ ExecutionContext, _ AST, _ Arguments) (*Value, error) {
return NewString(strings.ToUpper(v.StringValue)), nil
}),
}

default:
return nil
Expand Down
2 changes: 1 addition & 1 deletion jinja/ast/Variable.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ func (v *Variable) resolvePropertyLookup(ec compilerInterface.ExecutionContext,
return nil, err
}

data := value.Properties()
data := value.Properties(isForFunctionCall)
if data == nil {
return nil, ec.ErrorAt(v, fmt.Sprintf("unable reference by property key in a %s", value.Type()))
}
Expand Down
2 changes: 1 addition & 1 deletion utils/version.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
package utils

const DdbtVersion = "0.6.3"
const DdbtVersion = "0.6.4"

0 comments on commit 96c2ee6

Please sign in to comment.