Skip to content

Commit

Permalink
case: add post processors to get_storage() (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
ixje authored Dec 1, 2023
1 parent 1c68b6f commit a4293c6
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 20 deletions.
16 changes: 12 additions & 4 deletions boaconstructor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import signal
import re
import inspect
from typing import Optional, TypeVar, Type, cast, Sequence
from typing import Optional, TypeVar, Type, Sequence
from neo3.core import types, cryptography
from neo3.wallet import account
from neo3.api.wrappers import GenericContract, NEP17Contract, ChainFacade
Expand All @@ -18,6 +18,7 @@
from neo3.contracts import nef, manifest
from dataclasses import dataclass
from boaconstructor.node import NeoGoNode, Node
from boaconstructor.storage import PostProcessor

__version__ = "0.1.3"

Expand Down Expand Up @@ -245,6 +246,8 @@ async def get_storage(
*,
target_contract: Optional[types.UInt160] = None,
remove_prefix: bool = False,
key_post_processor: Optional[PostProcessor] = None,
values_post_processor: Optional[PostProcessor] = None,
) -> dict[bytes, bytes]:
"""
Gets the entries in the storage of the contract specified by `contract_hash`
Expand All @@ -253,6 +256,8 @@ async def get_storage(
prefix: prefix to filter the entries in the storage. Return the entire storage if not set.
target_contract: gets the storage of a different contract than the one under test. e.g. NeoToken
remove_prefix: whether the prefix should be removed from the output keys. False by default.
key_post_processor: a function to post process the storage key before placing it in the dictionary.
values_post_processor: a function to post process the storage value before placing it in the dictionary.
"""
if target_contract is None:
contract = GenericContract(cls.contract_hash)
Expand All @@ -266,9 +271,12 @@ async def get_storage(
async with noderpc.NeoRpcClient(cls.node.facade.rpc_host) as rpc_client:
async for k, v in rpc_client.find_states(contract.hash, prefix):
if remove_prefix:
results[k.removeprefix(prefix)] = v
else:
results[k] = v
k = k.removeprefix(prefix)
if key_post_processor is not None:
k = key_post_processor(k)
if values_post_processor is not None:
v = values_post_processor(v)
results[k] = v

return results

Expand Down
79 changes: 79 additions & 0 deletions boaconstructor/storage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
"""
Common post processor functions for the get_storage() method.
"""

from neo3.core import types, cryptography
from neo3.wallet import utils as walletutils
from neo3.wallet.types import NeoAddress
from neo3.api.helpers import stdlib
from typing_extensions import Protocol
from typing import Any


class PostProcessor(Protocol):
def __call__(self, data: bytes, *args: Any) -> Any:
...


def as_uint160(data: bytes, *_: Any) -> types.UInt160:
"""
Convert the data to a UInt160
Args:
data: a serialized UInt160
"""
return types.UInt160(data)


def as_uint256(data: bytes, *_: Any) -> types.UInt256:
"""
Convert the data to a UInt256
Args:
data: a serialized UInt256
"""
return types.UInt256(data)


def as_int(data: bytes, *_: Any) -> int:
"""
Convert the data to an integer
"""
return int(types.BigInteger(data))


def as_str(data: bytes, *_: Any) -> str:
"""
Convert the data to a UTF-8 encoded string
"""
return data.decode()


def as_address(data: bytes, *_: Any) -> NeoAddress:
"""
Convert the data to a Neo address string
Args:
data: a serialized UInt160
"""
return walletutils.script_hash_to_address(types.UInt160(data))


def as_public_key(data: bytes, *_: Any) -> cryptography.ECPoint:
"""
Convert the data to a public key
Args:
data: a serialized compressed public key
"""
return cryptography.ECPoint.deserialize_from_bytes(data)


def stdlib_deserialize(data: bytes, *_: Any) -> Any:
"""
Deserialize the data using the Binary Deserialize logic of the StdLib native contract
Args:
data: data that has been serialized using the StdLib native contract binary serialize method
"""
return stdlib.binary_deserialize(data)
9 changes: 5 additions & 4 deletions examples/amm/test_amm.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class AmmContractTest(SmartContractTestCase):
zgas_contract: GenericContract
zneo_contract_hash: types.UInt160
zneo_contract: GenericContract
contract: GenericContract

@classmethod
def setUpClass(cls) -> None:
Expand All @@ -35,7 +36,7 @@ def setUpClass(cls) -> None:

@classmethod
async def asyncSetupClass(cls) -> None:
cls.genesis = cls.node.wallet.account_get_by_label("committee")
cls.genesis = cls.node.wallet.account_get_by_label("committee") # type: ignore

await cls.transfer(GAS, cls.genesis.script_hash, cls.owner.script_hash, 100)
await cls.transfer(GAS, cls.genesis.script_hash, cls.user.script_hash, 100)
Expand Down Expand Up @@ -338,7 +339,7 @@ async def test_07_add_liquidity(self):
self.assertEqual(1, len(transfer_events))
self.assertEqual(3, len(transfer_events[0].state.as_list()))
sender, receiver, amount = transfer_events[0].state.as_list()
self.assertEqual(None, sender.as_none())
self.assertIsNone(sender.as_none())
self.assertEqual(self.user.script_hash, receiver.as_uint160())
self.assertEqual(liquidity, amount.as_int())

Expand Down Expand Up @@ -434,7 +435,7 @@ async def test_07_add_liquidity(self):
self.assertEqual(1, len(transfer_events))
self.assertEqual(3, len(transfer_events[0].state.as_list()))
sender, receiver, amount = transfer_events[0].state.as_list()
self.assertEqual(None, sender.as_none())
self.assertIsNone(sender.as_none())
self.assertEqual(self.user.script_hash, receiver.as_uint160())
self.assertEqual(liquidity, amount.as_int())

Expand Down Expand Up @@ -638,7 +639,7 @@ async def test_08_remove_liquidity(self):
self.assertEqual(3, len(transfer_events[0].state.as_list()))
sender, receiver, amount = transfer_events[0].state.as_list()
self.assertEqual(self.user.script_hash, sender.as_uint160())
self.assertEqual(None, receiver.as_none())
self.assertIsNone(receiver.as_none())
self.assertEqual(liquidity, amount.as_int())

self.assertEqual(1, len(sync_events))
Expand Down
37 changes: 25 additions & 12 deletions examples/nep17/test_nep17.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
AssertException,
SmartContractTestCase,
Nep17TransferEvent,
storage as _storage,
)
from neo3.api.wrappers import NEP17Contract
from neo3.wallet import account
from neo3.core import types
from neo3.contracts.contract import CONTRACT_HASHES
from typing import cast

NEO = CONTRACT_HASHES.NEO_TOKEN
GAS = CONTRACT_HASHES.GAS_TOKEN
Expand All @@ -19,6 +21,7 @@ class Nep17ContractTest(SmartContractTestCase):
user1: account.Account
user2: account.Account
balance_prefix: bytes = b"b"
contract: NEP17Contract

@classmethod
def setUpClass(cls) -> None:
Expand All @@ -34,7 +37,7 @@ def setUpClass(cls) -> None:

@classmethod
async def asyncSetupClass(cls) -> None:
cls.genesis = cls.node.wallet.account_get_by_label("committee")
cls.genesis = cls.node.wallet.account_get_by_label("committee") # type: ignore
cls.contract_hash = await cls.deploy("./resources/nep17.nef", cls.genesis)
cls.contract = NEP17Contract(cls.contract_hash)
await cls.transfer(GAS, cls.genesis.script_hash, cls.user1.script_hash, 100)
Expand Down Expand Up @@ -70,14 +73,19 @@ async def test_02_balance_of(self):
signing_account=self.user1,
)
self.assertTrue(success)

storage = await self.get_storage(prefix=self.balance_prefix, remove_prefix=True)

storage = cast(
dict[types.UInt160, bytes],
await self.get_storage(
prefix=self.balance_prefix,
remove_prefix=True,
key_post_processor=_storage.as_uint160,
),
)
result, _ = await self.call(
"balanceOf", [self.user1.script_hash], return_type=int
)
self.assertEqual(expected, result)
balance_key = self.user1.script_hash.to_array()
balance_key = self.user1.script_hash
self.assertIn(balance_key, storage)
self.assertEqual(types.BigInteger(expected).to_array(), storage[balance_key])

Expand All @@ -86,7 +94,7 @@ async def test_02_balance_of(self):
unknown_account = types.UInt160(b"\x01" * 20)
result, _ = await self.call("balanceOf", [unknown_account], return_type=int)
self.assertEqual(expected, result)
balance_key = unknown_account.to_array()
balance_key = unknown_account
self.assertNotIn(balance_key, storage)

# now test invalid account
Expand Down Expand Up @@ -116,7 +124,14 @@ async def test_03_transfer_success(self):
self.assertEqual(1, len(notifications))
self.assertEqual("Transfer", notifications[0].event_name)

storage = await self.get_storage(prefix=self.balance_prefix, remove_prefix=True)
storage = cast(
dict[types.UInt160, bytes],
await self.get_storage(
prefix=self.balance_prefix,
remove_prefix=True,
key_post_processor=_storage.as_uint160,
),
)

# test we emitted the correct transfer event
event = Nep17TransferEvent.from_notification(notifications[0])
Expand All @@ -131,12 +146,10 @@ async def test_03_transfer_success(self):
self.assertEqual(user1_balance, user2_balance)

# test storage updates
user1_balance_key = self.user1.script_hash.to_array()
user2_balance_key = self.user2.script_hash.to_array()
self.assertNotIn(user1_balance_key, storage)
self.assertIn(user2_balance_key, storage)
self.assertNotIn(self.user1.script_hash, storage)
self.assertIn(self.user2.script_hash, storage)
self.assertEqual(
types.BigInteger(user1_balance).to_array(), storage[user2_balance_key]
types.BigInteger(user1_balance).to_array(), storage[self.user2.script_hash]
)

async def test_onnep17(self):
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ tag = "v0.104.0"

[tool.mypy]
check_untyped_defs = true
disable_error_code = "func-returns-value"

[tool.bumpversion]
current_version = "0.1.3"
Expand Down

0 comments on commit a4293c6

Please sign in to comment.