Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

case: add post processors to get_storage() #7

Merged
merged 8 commits into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions boaconstructor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import signal
import re
import inspect
from typing import Optional, TypeVar, Type, cast, Sequence
from typing import Optional, TypeVar, Type, Sequence
from neo3.core import types, cryptography
from neo3.wallet import account
from neo3.api.wrappers import GenericContract, NEP17Contract, ChainFacade
Expand All @@ -18,6 +18,7 @@
from neo3.contracts import nef, manifest
from dataclasses import dataclass
from boaconstructor.node import NeoGoNode, Node
from boaconstructor.storage import PostProcessor

__version__ = "0.1.3"

Expand Down Expand Up @@ -245,6 +246,8 @@ async def get_storage(
*,
target_contract: Optional[types.UInt160] = None,
remove_prefix: bool = False,
key_post_processor: Optional[PostProcessor] = None,
values_post_processor: Optional[PostProcessor] = None,
) -> dict[bytes, bytes]:
"""
Gets the entries in the storage of the contract specified by `contract_hash`
Expand All @@ -253,6 +256,8 @@ async def get_storage(
prefix: prefix to filter the entries in the storage. Return the entire storage if not set.
target_contract: gets the storage of a different contract than the one under test. e.g. NeoToken
remove_prefix: whether the prefix should be removed from the output keys. False by default.
key_post_processor: a function to post process the storage key before placing it in the dictionary.
values_post_processor: a function to post process the storage value before placing it in the dictionary.
"""
if target_contract is None:
contract = GenericContract(cls.contract_hash)
Expand All @@ -266,9 +271,12 @@ async def get_storage(
async with noderpc.NeoRpcClient(cls.node.facade.rpc_host) as rpc_client:
async for k, v in rpc_client.find_states(contract.hash, prefix):
if remove_prefix:
results[k.removeprefix(prefix)] = v
else:
results[k] = v
k = k.removeprefix(prefix)
if key_post_processor is not None:
k = key_post_processor(k)
if values_post_processor is not None:
v = values_post_processor(v)
results[k] = v

return results

Expand Down
68 changes: 68 additions & 0 deletions boaconstructor/storage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""
Common post processor functions for the get_storage() method.
"""

from neo3.core import types, cryptography
from neo3.wallet import utils as walletutils
from neo3.wallet.types import NeoAddress
from typing_extensions import Protocol
from typing import Any


class PostProcessor(Protocol):
def __call__(self, data: bytes, *args: Any) -> Any:
...


def as_uint160(data: bytes, *_: Any) -> types.UInt160:
"""
Convert the data to a UInt160

Args:
data: a serialized UInt160
"""
return types.UInt160(data)


def as_uint256(data: bytes, *_: Any) -> types.UInt256:
"""
Convert the data to a UInt256

Args:
data: a serialized UInt256
"""
return types.UInt256(data)


def as_int(data: bytes, *_: Any) -> int:
"""
Convert the data to an integer
"""
return int(types.BigInteger(data))


def as_str(data: bytes, *_: Any) -> str:
"""
Convert the data to a UTF-8 encoded string
"""
return data.decode()


def as_address(data: bytes, *_: Any) -> NeoAddress:
"""
Convert the data to a Neo address string

Args:
data: a serialized UInt160
"""
return walletutils.script_hash_to_address(types.UInt160(data))


def as_public_key(data: bytes, *_: Any) -> cryptography.ECPoint:
"""
Convert the data to a public key

Args:
data: a serialized compressed public key
"""
return cryptography.ECPoint.deserialize_from_bytes(data)
9 changes: 5 additions & 4 deletions examples/amm/test_amm.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class AmmContractTest(SmartContractTestCase):
zgas_contract: GenericContract
zneo_contract_hash: types.UInt160
zneo_contract: GenericContract
contract: GenericContract

@classmethod
def setUpClass(cls) -> None:
Expand All @@ -35,7 +36,7 @@ def setUpClass(cls) -> None:

@classmethod
async def asyncSetupClass(cls) -> None:
cls.genesis = cls.node.wallet.account_get_by_label("committee")
cls.genesis = cls.node.wallet.account_get_by_label("committee") # type: ignore

await cls.transfer(GAS, cls.genesis.script_hash, cls.owner.script_hash, 100)
await cls.transfer(GAS, cls.genesis.script_hash, cls.user.script_hash, 100)
Expand Down Expand Up @@ -338,7 +339,7 @@ async def test_07_add_liquidity(self):
self.assertEqual(1, len(transfer_events))
self.assertEqual(3, len(transfer_events[0].state.as_list()))
sender, receiver, amount = transfer_events[0].state.as_list()
self.assertEqual(None, sender.as_none())
self.assertIsNone(sender.as_none())
self.assertEqual(self.user.script_hash, receiver.as_uint160())
self.assertEqual(liquidity, amount.as_int())

Expand Down Expand Up @@ -434,7 +435,7 @@ async def test_07_add_liquidity(self):
self.assertEqual(1, len(transfer_events))
self.assertEqual(3, len(transfer_events[0].state.as_list()))
sender, receiver, amount = transfer_events[0].state.as_list()
self.assertEqual(None, sender.as_none())
self.assertIsNone(sender.as_none())
self.assertEqual(self.user.script_hash, receiver.as_uint160())
self.assertEqual(liquidity, amount.as_int())

Expand Down Expand Up @@ -638,7 +639,7 @@ async def test_08_remove_liquidity(self):
self.assertEqual(3, len(transfer_events[0].state.as_list()))
sender, receiver, amount = transfer_events[0].state.as_list()
self.assertEqual(self.user.script_hash, sender.as_uint160())
self.assertEqual(None, receiver.as_none())
self.assertIsNone(receiver.as_none())
self.assertEqual(liquidity, amount.as_int())

self.assertEqual(1, len(sync_events))
Expand Down
37 changes: 25 additions & 12 deletions examples/nep17/test_nep17.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
AssertException,
SmartContractTestCase,
Nep17TransferEvent,
storage as _storage,
)
from neo3.api.wrappers import NEP17Contract
from neo3.wallet import account
from neo3.core import types
from neo3.contracts.contract import CONTRACT_HASHES
from typing import cast

NEO = CONTRACT_HASHES.NEO_TOKEN
GAS = CONTRACT_HASHES.GAS_TOKEN
Expand All @@ -19,6 +21,7 @@ class Nep17ContractTest(SmartContractTestCase):
user1: account.Account
user2: account.Account
balance_prefix: bytes = b"b"
contract: NEP17Contract

@classmethod
def setUpClass(cls) -> None:
Expand All @@ -34,7 +37,7 @@ def setUpClass(cls) -> None:

@classmethod
async def asyncSetupClass(cls) -> None:
cls.genesis = cls.node.wallet.account_get_by_label("committee")
cls.genesis = cls.node.wallet.account_get_by_label("committee") # type: ignore
cls.contract_hash = await cls.deploy("./resources/nep17.nef", cls.genesis)
cls.contract = NEP17Contract(cls.contract_hash)
await cls.transfer(GAS, cls.genesis.script_hash, cls.user1.script_hash, 100)
Expand Down Expand Up @@ -70,14 +73,19 @@ async def test_02_balance_of(self):
signing_account=self.user1,
)
self.assertTrue(success)

storage = await self.get_storage(prefix=self.balance_prefix, remove_prefix=True)

storage = cast(
dict[types.UInt160, bytes],
await self.get_storage(
prefix=self.balance_prefix,
remove_prefix=True,
key_post_processor=_storage.as_uint160,
),
)
result, _ = await self.call(
"balanceOf", [self.user1.script_hash], return_type=int
)
self.assertEqual(expected, result)
balance_key = self.user1.script_hash.to_array()
balance_key = self.user1.script_hash
self.assertIn(balance_key, storage)
self.assertEqual(types.BigInteger(expected).to_array(), storage[balance_key])

Expand All @@ -86,7 +94,7 @@ async def test_02_balance_of(self):
unknown_account = types.UInt160(b"\x01" * 20)
result, _ = await self.call("balanceOf", [unknown_account], return_type=int)
self.assertEqual(expected, result)
balance_key = unknown_account.to_array()
balance_key = unknown_account
self.assertNotIn(balance_key, storage)

# now test invalid account
Expand Down Expand Up @@ -116,7 +124,14 @@ async def test_03_transfer_success(self):
self.assertEqual(1, len(notifications))
self.assertEqual("Transfer", notifications[0].event_name)

storage = await self.get_storage(prefix=self.balance_prefix, remove_prefix=True)
storage = cast(
dict[types.UInt160, bytes],
await self.get_storage(
prefix=self.balance_prefix,
remove_prefix=True,
key_post_processor=_storage.as_uint160,
),
)

# test we emitted the correct transfer event
event = Nep17TransferEvent.from_notification(notifications[0])
Expand All @@ -131,12 +146,10 @@ async def test_03_transfer_success(self):
self.assertEqual(user1_balance, user2_balance)

# test storage updates
user1_balance_key = self.user1.script_hash.to_array()
user2_balance_key = self.user2.script_hash.to_array()
self.assertNotIn(user1_balance_key, storage)
self.assertIn(user2_balance_key, storage)
self.assertNotIn(self.user1.script_hash, storage)
self.assertIn(self.user2.script_hash, storage)
self.assertEqual(
types.BigInteger(user1_balance).to_array(), storage[user2_balance_key]
types.BigInteger(user1_balance).to_array(), storage[self.user2.script_hash]
)

async def test_onnep17(self):
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ tag = "v0.102.0"

[tool.mypy]
check_untyped_defs = true
disable_error_code = "func-returns-value"

[tool.bumpversion]
current_version = "0.1.3"
Expand Down