Skip to content
This repository has been archived by the owner on Dec 1, 2023. It is now read-only.

Commit

Permalink
Sanitize call/invoke params (#162)
Browse files Browse the repository at this point in the history
* sanitize call/invoke params

* add docstrings

* fix tests

* fix linter

* add tests for common (#164)

Co-authored-by: Andrew Fleming <[email protected]>
  • Loading branch information
martriay and andrew-fleming authored Aug 11, 2022
1 parent 651f59d commit 5856ba5
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 22 deletions.
17 changes: 16 additions & 1 deletion src/nile/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def run_command(

if arguments:
command.append("--inputs")
command.extend([argument for argument in arguments])
command.extend(prepare_params(arguments))

if network == "mainnet":
os.environ["STARKNET_NETWORK"] = "alpha-mainnet"
Expand All @@ -77,3 +77,18 @@ def parse_information(x):
# address is 64, tx_hash is 64 chars long
address, tx_hash = re.findall("0x[\\da-f]{1,64}", str(x))
return address, tx_hash


def stringify(x):
"""Recursively convert list elements to strings."""
if isinstance(x, list):
return [stringify(y) for y in x]
else:
return str(x)


def prepare_params(params):
"""Sanitize call, invoke, and deploy parameters."""
if params is None:
params = []
return stringify(params)
12 changes: 6 additions & 6 deletions src/nile/core/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def send(self, to, method, calldata, max_fee, nonce=None):

if nonce is None:
nonce = int(
call_or_invoke(self.address, "call", "get_nonce", [], self.network)
call_or_invoke(self.address, "call", "get_nonce", [], self.network)[0]
)

if max_fee is None:
Expand All @@ -80,11 +80,11 @@ def send(self, to, method, calldata, max_fee, nonce=None):
)

params = []
params.append(str(len(call_array)))
params.extend([str(elem) for sublist in call_array for elem in sublist])
params.append(str(len(calldata)))
params.extend([str(param) for param in calldata])
params.append(str(nonce))
params.append(len(call_array))
params.extend(*call_array)
params.append(len(calldata))
params.extend(calldata)
params.append(nonce)

return call_or_invoke(
contract=self.address,
Expand Down
6 changes: 4 additions & 2 deletions src/nile/core/call_or_invoke.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import subprocess

from nile import deployments
from nile.common import GATEWAYS
from nile.common import GATEWAYS, prepare_params


def call_or_invoke(
Expand Down Expand Up @@ -32,6 +32,8 @@ def call_or_invoke(
gateway_prefix = "feeder_gateway" if type == "call" else "gateway"
command.append(f"--{gateway_prefix}_url={GATEWAYS.get(network)}")

params = prepare_params(params)

if len(params) > 0:
command.append("--inputs")
command.extend(params)
Expand All @@ -47,7 +49,7 @@ def call_or_invoke(
command.append("--no_wallet")

try:
return subprocess.check_output(command).strip().decode("utf-8")
return subprocess.check_output(command).strip().decode("utf-8").split()
except subprocess.CalledProcessError:
p = subprocess.Popen(command, stderr=subprocess.PIPE)
_, error = p.communicate()
Expand Down
6 changes: 0 additions & 6 deletions src/nile/nre.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,14 @@ def declare(self, contract, alias=None, overriding_path=None):

def deploy(self, contract, arguments=None, alias=None, overriding_path=None):
"""Deploy a smart contract."""
if arguments is None:
arguments = []
return deploy(contract, arguments, self.network, alias, overriding_path)

def call(self, contract, method, params=None):
"""Call a view function in a smart contract."""
if params is None:
params = []
return call_or_invoke(contract, "call", method, params, self.network)

def invoke(self, contract, method, params=None):
"""Invoke a mutable function in a smart contract."""
if params is None:
params = []
return call_or_invoke(contract, "invoke", method, params, self.network)

def get_deployment(self, identifier):
Expand Down
8 changes: 4 additions & 4 deletions tests/commands/test_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def test_send_nonce_call(mock_call):
"callarray, calldata",
# The following callarray and calldata args tests the Account's list comprehensions
# ensuring they're set to strings and passed correctly
[([[111]], []), ([[111, 222]], [333, 444, 555])],
[([["111"]], []), ([["111", "222"]], ["333", "444", "555"])],
)
def test_send_sign_transaction_and_execute(callarray, calldata):
account = Account(KEY, NETWORK)
Expand Down Expand Up @@ -128,11 +128,11 @@ def test_send_sign_transaction_and_execute(callarray, calldata):
method="__execute__",
network=NETWORK,
params=[
str(len(callarray)),
len(callarray),
*(str(elem) for sublist in callarray for elem in sublist),
str(len(calldata)),
len(calldata),
*(str(param) for param in calldata),
str(nonce),
nonce,
],
signature=[str(sig_r), str(sig_s)],
type="invoke",
Expand Down
33 changes: 30 additions & 3 deletions tests/test_common.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
"""Tests for deploy command."""
"""Tests for common library."""
from unittest.mock import patch

import pytest

from nile.common import BUILD_DIRECTORY, run_command
from nile.common import BUILD_DIRECTORY, prepare_params, run_command, stringify

CONTRACT = "contract"
OPERATION = "invoke"
NETWORK = "goerli"
ARGS = [1, 2, 3]
ARGS = ["1", "2", "3"]
LIST1 = [1, 2, 3]
LIST2 = [1, 2, 3, [4, 5, 6]]
LIST3 = [1, 2, 3, [4, 5, 6, [7, 8, 9]]]


@pytest.mark.parametrize("operation", ["invoke", "call"])
Expand All @@ -30,3 +33,27 @@ def test_run_command(mock_subprocess, operation):
"--no_wallet",
]
)


@pytest.mark.parametrize(
"args, expected",
[
([], []),
([LIST1], [["1", "2", "3"]]),
([LIST2], [["1", "2", "3", ["4", "5", "6"]]]),
([LIST3], [["1", "2", "3", ["4", "5", "6", ["7", "8", "9"]]]]),
],
)
def test_stringify(args, expected):
assert stringify(args) == expected


@pytest.mark.parametrize(
"args, expected",
[
([], []),
([LIST1], [["1", "2", "3"]]),
],
)
def test_prepare_params(args, expected):
assert prepare_params(args) == expected

0 comments on commit 5856ba5

Please sign in to comment.