From 82669f5489184bbe73bb0672fe32892e9bb4ecc8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastian=20M=C3=BCller?= Date: Thu, 13 Aug 2020 19:10:15 -0700 Subject: [PATCH] validate script argument count --- runtime/runtime.go | 29 +++++++++++------- runtime/runtime_test.go | 65 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 84 insertions(+), 10 deletions(-) diff --git a/runtime/runtime.go b/runtime/runtime.go index bd7e502bfb..a610b7b76f 100644 --- a/runtime/runtime.go +++ b/runtime/runtime.go @@ -180,7 +180,7 @@ func (r *interpreterRuntime) ExecuteScript( checker, functions, nil, - scriptExecutionFunction(epSignature.Parameters, len(arguments), arguments, runtimeInterface), + scriptExecutionFunction(epSignature.Parameters, arguments, runtimeInterface), ) if err != nil { return nil, newError(err) @@ -196,12 +196,11 @@ func (r *interpreterRuntime) ExecuteScript( return exportValue(value), nil } -func scriptExecutionFunction(parameters []*sema.Parameter, argumentCount int, arguments [][]byte, runtimeInterface Interface) func(inter *interpreter.Interpreter) (interpreter.Value, error) { +func scriptExecutionFunction(parameters []*sema.Parameter, arguments [][]byte, runtimeInterface Interface) func(inter *interpreter.Interpreter) (interpreter.Value, error) { return func(inter *interpreter.Interpreter) (interpreter.Value, error) { values, err := validateArgumentParams( inter, runtimeInterface, - argumentCount, arguments, parameters) if err != nil { @@ -370,7 +369,6 @@ func (r *interpreterRuntime) ExecuteTransaction( nil, r.transactionExecutionFunction( transactionType.Parameters, - argumentCount, arguments, runtimeInterface, authorizerValues, @@ -406,7 +404,6 @@ func wrapPanic(f func()) { func (r *interpreterRuntime) transactionExecutionFunction( parameters []*sema.Parameter, - argumentCount int, arguments [][]byte, runtimeInterface Interface, authorizerValues []interpreter.Value, @@ -415,9 +412,9 @@ func (r *interpreterRuntime) transactionExecutionFunction( values, err := validateArgumentParams( inter, runtimeInterface, - argumentCount, arguments, - parameters) + parameters, + ) if err != nil { return nil, err } @@ -430,11 +427,23 @@ func (r *interpreterRuntime) transactionExecutionFunction( func validateArgumentParams( inter *interpreter.Interpreter, runtimeInterface Interface, - argumentCount int, arguments [][]byte, - parameters []*sema.Parameter) ([]interpreter.Value, error) { + parameters []*sema.Parameter, +) ( + []interpreter.Value, + error, +) { + argumentCount := len(arguments) + parameterCount := len(parameters) + + if argumentCount != parameterCount { + return nil, InvalidEntryPointParameterCountError{ + Expected: parameterCount, + Actual: argumentCount, + } + } - argumentValues := make([]interpreter.Value, argumentCount) + argumentValues := make([]interpreter.Value, len(arguments)) // Decode arguments against parameter types for i, parameter := range parameters { diff --git a/runtime/runtime_test.go b/runtime/runtime_test.go index 04eb586bcc..ae9afad9a7 100644 --- a/runtime/runtime_test.go +++ b/runtime/runtime_test.go @@ -4698,3 +4698,68 @@ func TestRuntimeTransaction_UpdateAccountCodeUnsafeNotInitializing(t *testing.T) _, err = runtime.ExecuteScript(script2, nil, runtimeInterface, nextTransactionLocation()) require.NoError(t, err) } + +func TestRuntime(t *testing.T) { + + t.Parallel() + + runtime := NewInterpreterRuntime() + + runtimeInterface := &testRuntimeInterface{ + decodeArgument: func(b []byte, t cadence.Type) (cadence.Value, error) { + return jsoncdc.Decode(b) + }, + } + + script := []byte(` + pub fun main(num: Int) {} + `) + + type testCase struct { + name string + arguments [][]byte + valid bool + } + + test := func(tc testCase) { + t.Run(tc.name, func(t *testing.T) { + + t.Parallel() + + _, err := runtime.ExecuteScript(script, tc.arguments, runtimeInterface, ScriptLocation{0x1}) + + if tc.valid { + require.NoError(t, err) + } else { + require.Error(t, err) + require.IsType(t, Error{}, err) + assert.IsType(t, InvalidEntryPointParameterCountError{}, err.(Error).Unwrap()) + } + }) + } + + for _, testCase := range []testCase{ + { + name: "too few arguments", + arguments: [][]byte{}, + valid: false, + }, + { + name: "correct number of arguments", + arguments: [][]byte{ + jsoncdc.MustEncode(cadence.NewInt(1)), + }, + valid: true, + }, + { + name: "too many arguments", + arguments: [][]byte{ + jsoncdc.MustEncode(cadence.NewInt(1)), + jsoncdc.MustEncode(cadence.NewInt(2)), + }, + valid: false, + }, + } { + test(testCase) + } +}