diff --git a/boaconstructor/__init__.py b/boaconstructor/__init__.py index 3eda817..d262935 100644 --- a/boaconstructor/__init__.py +++ b/boaconstructor/__init__.py @@ -4,7 +4,7 @@ import signal import re import inspect -from typing import Optional, TypeVar, Type, cast, Sequence +from typing import Optional, TypeVar, Type, Sequence from neo3.core import types, cryptography from neo3.wallet import account from neo3.api.wrappers import GenericContract, NEP17Contract, ChainFacade @@ -18,6 +18,7 @@ from neo3.contracts import nef, manifest from dataclasses import dataclass from boaconstructor.node import NeoGoNode, Node +from boaconstructor.storage import PostProcessor __version__ = "0.1.3" @@ -245,6 +246,8 @@ async def get_storage( *, target_contract: Optional[types.UInt160] = None, remove_prefix: bool = False, + key_post_processor: Optional[PostProcessor] = None, + values_post_processor: Optional[PostProcessor] = None, ) -> dict[bytes, bytes]: """ Gets the entries in the storage of the contract specified by `contract_hash` @@ -253,6 +256,8 @@ async def get_storage( prefix: prefix to filter the entries in the storage. Return the entire storage if not set. target_contract: gets the storage of a different contract than the one under test. e.g. NeoToken remove_prefix: whether the prefix should be removed from the output keys. False by default. + key_post_processor: a function to post process the storage key before placing it in the dictionary. + values_post_processor: a function to post process the storage value before placing it in the dictionary. """ if target_contract is None: contract = GenericContract(cls.contract_hash) @@ -266,9 +271,12 @@ async def get_storage( async with noderpc.NeoRpcClient(cls.node.facade.rpc_host) as rpc_client: async for k, v in rpc_client.find_states(contract.hash, prefix): if remove_prefix: - results[k.removeprefix(prefix)] = v - else: - results[k] = v + k = k.removeprefix(prefix) + if key_post_processor is not None: + k = key_post_processor(k) + if values_post_processor is not None: + v = values_post_processor(v) + results[k] = v return results diff --git a/boaconstructor/storage.py b/boaconstructor/storage.py new file mode 100644 index 0000000..a016a96 --- /dev/null +++ b/boaconstructor/storage.py @@ -0,0 +1,79 @@ +""" +Common post processor functions for the get_storage() method. +""" + +from neo3.core import types, cryptography +from neo3.wallet import utils as walletutils +from neo3.wallet.types import NeoAddress +from neo3.api.helpers import stdlib +from typing_extensions import Protocol +from typing import Any + + +class PostProcessor(Protocol): + def __call__(self, data: bytes, *args: Any) -> Any: + ... + + +def as_uint160(data: bytes, *_: Any) -> types.UInt160: + """ + Convert the data to a UInt160 + + Args: + data: a serialized UInt160 + """ + return types.UInt160(data) + + +def as_uint256(data: bytes, *_: Any) -> types.UInt256: + """ + Convert the data to a UInt256 + + Args: + data: a serialized UInt256 + """ + return types.UInt256(data) + + +def as_int(data: bytes, *_: Any) -> int: + """ + Convert the data to an integer + """ + return int(types.BigInteger(data)) + + +def as_str(data: bytes, *_: Any) -> str: + """ + Convert the data to a UTF-8 encoded string + """ + return data.decode() + + +def as_address(data: bytes, *_: Any) -> NeoAddress: + """ + Convert the data to a Neo address string + + Args: + data: a serialized UInt160 + """ + return walletutils.script_hash_to_address(types.UInt160(data)) + + +def as_public_key(data: bytes, *_: Any) -> cryptography.ECPoint: + """ + Convert the data to a public key + + Args: + data: a serialized compressed public key + """ + return cryptography.ECPoint.deserialize_from_bytes(data) + + +def stdlib_deserialize(data: bytes, *_: Any) -> Any: + """ + Deserialize the data using the Binary Deserialize logic of the StdLib native contract + + Args: + data: data that has been serialized using the StdLib native contract binary serialize method + """ + return stdlib.binary_deserialize(data) diff --git a/examples/amm/test_amm.py b/examples/amm/test_amm.py index a060dde..70b3e3d 100644 --- a/examples/amm/test_amm.py +++ b/examples/amm/test_amm.py @@ -20,6 +20,7 @@ class AmmContractTest(SmartContractTestCase): zgas_contract: GenericContract zneo_contract_hash: types.UInt160 zneo_contract: GenericContract + contract: GenericContract @classmethod def setUpClass(cls) -> None: @@ -35,7 +36,7 @@ def setUpClass(cls) -> None: @classmethod async def asyncSetupClass(cls) -> None: - cls.genesis = cls.node.wallet.account_get_by_label("committee") + cls.genesis = cls.node.wallet.account_get_by_label("committee") # type: ignore await cls.transfer(GAS, cls.genesis.script_hash, cls.owner.script_hash, 100) await cls.transfer(GAS, cls.genesis.script_hash, cls.user.script_hash, 100) @@ -338,7 +339,7 @@ async def test_07_add_liquidity(self): self.assertEqual(1, len(transfer_events)) self.assertEqual(3, len(transfer_events[0].state.as_list())) sender, receiver, amount = transfer_events[0].state.as_list() - self.assertEqual(None, sender.as_none()) + self.assertIsNone(sender.as_none()) self.assertEqual(self.user.script_hash, receiver.as_uint160()) self.assertEqual(liquidity, amount.as_int()) @@ -434,7 +435,7 @@ async def test_07_add_liquidity(self): self.assertEqual(1, len(transfer_events)) self.assertEqual(3, len(transfer_events[0].state.as_list())) sender, receiver, amount = transfer_events[0].state.as_list() - self.assertEqual(None, sender.as_none()) + self.assertIsNone(sender.as_none()) self.assertEqual(self.user.script_hash, receiver.as_uint160()) self.assertEqual(liquidity, amount.as_int()) @@ -638,7 +639,7 @@ async def test_08_remove_liquidity(self): self.assertEqual(3, len(transfer_events[0].state.as_list())) sender, receiver, amount = transfer_events[0].state.as_list() self.assertEqual(self.user.script_hash, sender.as_uint160()) - self.assertEqual(None, receiver.as_none()) + self.assertIsNone(receiver.as_none()) self.assertEqual(liquidity, amount.as_int()) self.assertEqual(1, len(sync_events)) diff --git a/examples/nep17/test_nep17.py b/examples/nep17/test_nep17.py index 8b898e5..d4325e3 100644 --- a/examples/nep17/test_nep17.py +++ b/examples/nep17/test_nep17.py @@ -4,11 +4,13 @@ AssertException, SmartContractTestCase, Nep17TransferEvent, + storage as _storage, ) from neo3.api.wrappers import NEP17Contract from neo3.wallet import account from neo3.core import types from neo3.contracts.contract import CONTRACT_HASHES +from typing import cast NEO = CONTRACT_HASHES.NEO_TOKEN GAS = CONTRACT_HASHES.GAS_TOKEN @@ -19,6 +21,7 @@ class Nep17ContractTest(SmartContractTestCase): user1: account.Account user2: account.Account balance_prefix: bytes = b"b" + contract: NEP17Contract @classmethod def setUpClass(cls) -> None: @@ -34,7 +37,7 @@ def setUpClass(cls) -> None: @classmethod async def asyncSetupClass(cls) -> None: - cls.genesis = cls.node.wallet.account_get_by_label("committee") + cls.genesis = cls.node.wallet.account_get_by_label("committee") # type: ignore cls.contract_hash = await cls.deploy("./resources/nep17.nef", cls.genesis) cls.contract = NEP17Contract(cls.contract_hash) await cls.transfer(GAS, cls.genesis.script_hash, cls.user1.script_hash, 100) @@ -70,14 +73,19 @@ async def test_02_balance_of(self): signing_account=self.user1, ) self.assertTrue(success) - - storage = await self.get_storage(prefix=self.balance_prefix, remove_prefix=True) - + storage = cast( + dict[types.UInt160, bytes], + await self.get_storage( + prefix=self.balance_prefix, + remove_prefix=True, + key_post_processor=_storage.as_uint160, + ), + ) result, _ = await self.call( "balanceOf", [self.user1.script_hash], return_type=int ) self.assertEqual(expected, result) - balance_key = self.user1.script_hash.to_array() + balance_key = self.user1.script_hash self.assertIn(balance_key, storage) self.assertEqual(types.BigInteger(expected).to_array(), storage[balance_key]) @@ -86,7 +94,7 @@ async def test_02_balance_of(self): unknown_account = types.UInt160(b"\x01" * 20) result, _ = await self.call("balanceOf", [unknown_account], return_type=int) self.assertEqual(expected, result) - balance_key = unknown_account.to_array() + balance_key = unknown_account self.assertNotIn(balance_key, storage) # now test invalid account @@ -116,7 +124,14 @@ async def test_03_transfer_success(self): self.assertEqual(1, len(notifications)) self.assertEqual("Transfer", notifications[0].event_name) - storage = await self.get_storage(prefix=self.balance_prefix, remove_prefix=True) + storage = cast( + dict[types.UInt160, bytes], + await self.get_storage( + prefix=self.balance_prefix, + remove_prefix=True, + key_post_processor=_storage.as_uint160, + ), + ) # test we emitted the correct transfer event event = Nep17TransferEvent.from_notification(notifications[0]) @@ -131,12 +146,10 @@ async def test_03_transfer_success(self): self.assertEqual(user1_balance, user2_balance) # test storage updates - user1_balance_key = self.user1.script_hash.to_array() - user2_balance_key = self.user2.script_hash.to_array() - self.assertNotIn(user1_balance_key, storage) - self.assertIn(user2_balance_key, storage) + self.assertNotIn(self.user1.script_hash, storage) + self.assertIn(self.user2.script_hash, storage) self.assertEqual( - types.BigInteger(user1_balance).to_array(), storage[user2_balance_key] + types.BigInteger(user1_balance).to_array(), storage[self.user2.script_hash] ) async def test_onnep17(self): diff --git a/pyproject.toml b/pyproject.toml index edfb5a7..376e098 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ tag = "v0.104.0" [tool.mypy] check_untyped_defs = true +disable_error_code = "func-returns-value" [tool.bumpversion] current_version = "0.1.3"