Skip to content

Commit

Permalink
hotfix/test(query_clients): Test vpool.query.all_pools fn. Account fo…
Browse files Browse the repository at this point in the history
…r deserialize edge case #102. (#103)

* fix(query_clients): account for missing fields on the pb_msg in deserialize

* fix(query_clients): account for missing fields on the pb_msg in deserialize

* test(vpool): test_query_vpool_base_asset_price
  • Loading branch information
Unique-Divine authored Aug 21, 2022
1 parent 3e6e8b6 commit e5d3a0c
Show file tree
Hide file tree
Showing 7 changed files with 177 additions and 29 deletions.
87 changes: 80 additions & 7 deletions nibiru/query_clients/util.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,24 @@
from typing import Dict, List

from google.protobuf import message as protobuf_message
from google.protobuf.json_format import MessageToDict

from nibiru.utils import from_sdk_dec
from nibiru.utils import from_sdk_dec, from_sdk_int

PROTOBUF_MSG_BASE_ATTRS: List[str] = (
dir(protobuf_message.Message)
+ ['Extensions', 'FindInitializationErrors', '_CheckCalledFromGeneratedFile']
+ ['_extensions_by_name', '_extensions_by_number']
)
"""PROTOBUF_MSG_BASE_ATTRS (List[str]): The default attributes and methods of
an instance of the 'protobuf.message.Message' class.
"""


def camel_to_snake(s):
return ''.join(['_' + c.lower() if c.isupper() else c for c in s]).lstrip('_')
def camel_to_snake(camel: str):
return ''.join(
['_' + char.lower() if char.isupper() else char for char in camel]
).lstrip('_')


def t_dict(d):
Expand All @@ -16,16 +30,75 @@ def t_dict(d):
}


def deserialize(proto_message: object) -> dict:
def deserialize(pb_msg: protobuf_message.Message) -> dict:
"""Deserializes a proto message into a dictionary.
- sdk.Dec values are converted to floats.
- sdk.Int values are converted to ints.
- Missing fields become blank strings.
Args:
pb_msg (protobuf.message.Message)
Returns:
dict: 'pb_msg' as a JSON-able dictionary.
"""
if not isinstance(pb_msg, protobuf_message.Message):
raise TypeError(f"expted protobuf Message for 'pb_msg', not {type(pb_msg)}")
custom_dtypes: Dict[str, bytes] = {
str(field[1]): field[0].GetOptions().__getstate__().get("serialized", None)
for field in pb_msg.ListFields()
}
serialized_output = {}
expected_fields: List[str] = [
attr for attr in dir(pb_msg) if attr not in PROTOBUF_MSG_BASE_ATTRS
]

for _, attr in enumerate(expected_fields):

attr_search = pb_msg.__getattribute__(attr)
custom_dtype = custom_dtypes.get(str(attr_search))

if custom_dtype is not None:

if "sdk/types.Dec" in str(custom_dtype):
serialized_output[str(attr)] = from_sdk_dec(
pb_msg.__getattribute__(attr)
)
elif "sdk/types.Int" in str(custom_dtype):
serialized_output[str(attr)] = from_sdk_int(
pb_msg.__getattribute__(attr)
)
else:
try:
val = pb_msg.__getattribute__(attr)
if hasattr(val, '__len__') and not isinstance(val, str):
updated_vals = []
for v in val:
updated_vals.append(deserialize(v))
serialized_output[str(attr)] = updated_vals
else:
serialized_output[str(attr)] = deserialize(val)
except:
serialized_output[str(attr)] = pb_msg.__getattribute__(attr)
elif (custom_dtype is None) and (attr_search == ''):
serialized_output[str(attr)] = ""
else:
serialized_output[str(attr)] = deserialize(pb_msg.__getattribute__(attr))

return serialized_output


def deserialize_exp(proto_message: protobuf_message.Message) -> dict:
"""
Take a proto message and convert it into a dictionnary.
sdk.Dec values are converted to be consistent with txs.
Args:
proto_message (object): The proto message
proto_message (protobuf.message.Message)
Returns:
dict: The dictionary
dict
"""
output = MessageToDict(proto_message)

Expand All @@ -38,7 +111,7 @@ def deserialize(proto_message: object) -> dict:
if field.message_type is not None:
# This is another proto object
try:
output[field.camelcase_name] = deserialize(
output[field.camelcase_name] = deserialize_exp(
proto_message.__getattribute__(field.camelcase_name)
)
except AttributeError:
Expand Down
24 changes: 19 additions & 5 deletions nibiru/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,25 @@


class Sdk:
"""
The Sdk class creates an interface to sign and send transactions or execute queries from a node.
"""The Sdk class creates an interface to sign and send transactions or execute
queries from a node.
It is associated to:
- a wallet, which can be either created or recovered from an existing mnemonic.
- a network, defining the node to connect to
- optionally a configuration defining how to behave and the gas configuration for each transaction
- a wallet, which can be either created or recovered from an existing mnemonic.
- a network, defining the node to connect to
- optionally a configuration defining how to behave and the gas configuration
for each transaction
Each method starting with `with_` will replace the existing Sdk object with a new version having the defined
behavior.
Attributes:
priv_key
query
tx
network
tx_config
Example ::
Expand All @@ -38,6 +47,11 @@ class Sdk:
)
"""

query: GrpcClient
network: Network
tx: BaseTxClient
tx_config: TxConfig

def __init__(self, _error_do_not_use_init_directly=None) -> None:
"""Unsupported, please use from_mnemonic to initialize."""
if not _error_do_not_use_init_directly:
Expand Down
4 changes: 2 additions & 2 deletions nibiru/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,8 @@ def to_sdk_int(i: float) -> str:
return str(int(i))


def from_sdk_int(int_str: str) -> float:
return float(int_str)
def from_sdk_int(int_str: str) -> int:
return int(int_str)


def toPbTimestamp(dt: datetime):
Expand Down
6 changes: 5 additions & 1 deletion tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Tests for the nibiru package"""
import logging
import sys
from typing import Iterable

import shutup

Expand All @@ -25,7 +26,10 @@ def init_test_logger() -> logging.Logger:
"""Simple logger to use throughout the test suite."""


def dict_keys_must_match(dict_: dict, keys: list[str]):
def dict_keys_must_match(dict_: dict, keys: Iterable[str]):
keys = list(keys)
if not isinstance(dict_, dict):
raise TypeError(f"'dict' must be a dicitonary, not {type(dict_)}")
assert len(dict_.keys()) == len(keys)
for key in dict_.keys():
assert key in keys
Expand Down
9 changes: 0 additions & 9 deletions tests/chain_info_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,3 @@ def test_query_perp_params(val_node: Sdk):
"twapLookbackWindow",
]
assert all([(param_name in params) for param_name in perp_param_names])


def test_query_vpool_reserve_assets(val_node: Sdk):
expected_pairs: List[str] = ["ubtc:unusd", "ueth:unusd"]
for pair in expected_pairs:
query_resp: dict = val_node.query.vpool.reserve_assets(pair)
assert isinstance(query_resp, dict)
assert query_resp["base_asset_reserve"] > 0
assert query_resp["quote_asset_reserve"] > 0
10 changes: 5 additions & 5 deletions tests/perp_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# perp_test.py
import pytest
from grpc._channel import _InactiveRpcError
from pytest import approx, raises

import nibiru
import nibiru.msg
Expand All @@ -25,7 +25,7 @@ def test_open_close_position(val_node: nibiru.Sdk, agent: nibiru.Sdk):
)

# Exception must be raised when requesting not existing position
with raises(_InactiveRpcError, match="no position found"):
with pytest.raises(_InactiveRpcError, match="no position found"):
agent.query.perp.trader_position(trader=agent.address, token_pair=pair)

# Transaction open_position must succeed
Expand Down Expand Up @@ -57,12 +57,12 @@ def test_open_close_position(val_node: nibiru.Sdk, agent: nibiru.Sdk):
],
)
# Margin ratio must be ~10%
assert position_res["margin_ratio_mark"] == approx(0.1, PRECISION)
assert position_res["margin_ratio_mark"] == pytest.approx(0.1, PRECISION)

position = position_res["position"]
assert position["margin"] == 10.0
assert position["open_notional"] == 100.0
assert position["size"] == approx(0.005, PRECISION)
assert position["size"] == pytest.approx(0.005, PRECISION)

# Transaction add_margin must succeed
tx_output = agent.tx.execute_msgs(
Expand Down Expand Up @@ -103,5 +103,5 @@ def test_open_close_position(val_node: nibiru.Sdk, agent: nibiru.Sdk):
transaction_must_succeed(tx_output)

# Exception must be raised when querying closed position
with raises(_InactiveRpcError, match="no position found"):
with pytest.raises(_InactiveRpcError, match="no position found"):
agent.query.perp.trader_position(trader=agent.address, token_pair=pair)
66 changes: 66 additions & 0 deletions tests/vpool_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import pprint
from typing import Dict, List

import nibiru
import tests
from nibiru import common


def test_query_vpool_reserve_assets(val_node: nibiru.Sdk):
expected_pairs: List[str] = ["ubtc:unusd", "ueth:unusd"]
for pair in expected_pairs:
query_resp: dict = val_node.query.vpool.reserve_assets(pair)
assert isinstance(query_resp, dict)
assert query_resp["base_asset_reserve"] > 0
assert query_resp["quote_asset_reserve"] > 0


def test_query_vpool_all_pools(agent: nibiru.Sdk):
"""Tests deserialization and expected attributes for the
'nibid query vpool all-pools' command.
"""

query_resp: Dict[str, List[dict]] = agent.query.vpool.all_pools()
tests.dict_keys_must_match(query_resp, keys=["pools", "prices"])

all_vpools: List[dict] = query_resp["pools"]
vpool_fields: List[str] = [
"base_asset_reserve",
"fluctuation_limit_ratio",
"maintenance_margin_ratio",
"max_leverage",
"max_oracle_spread_ratio",
"pair",
"quote_asset_reserve",
"trade_limit_ratio",
]
tests.dict_keys_must_match(all_vpools[0], keys=vpool_fields)

all_vpool_prices = query_resp["prices"]
price_fields: List[str] = [
"block_number",
"index_price",
"mark_price",
"swap_invariant",
"twap_mark",
"pair",
]
tests.dict_keys_must_match(all_vpool_prices[0], keys=price_fields)

vpool_prices = all_vpool_prices[0]
assert isinstance(vpool_prices["block_number"], int), "block_number"
assert isinstance(vpool_prices["index_price"], float), "index_price"
assert isinstance(vpool_prices["mark_price"], float), "mark_price"
assert isinstance(vpool_prices["swap_invariant"], int), "swap_invariant"
assert isinstance(vpool_prices["twap_mark"], float), "twap_mark"
assert isinstance(vpool_prices["pair"], str), "pair"
tests.LOGGER.info(f"vpool_prices: {pprint.pformat(vpool_prices, indent=3)}")


def test_query_vpool_base_asset_price(agent: nibiru.Sdk):
query_resp: Dict[str, List[dict]] = agent.query.vpool.base_asset_price(
pair="ueth:unusd", direction=common.Direction.ADD, base_asset_amount="15"
)
tests.dict_keys_must_match(query_resp, keys=["price_in_quote_denom"])
assert isinstance(query_resp["price_in_quote_denom"], float)
assert query_resp["price_in_quote_denom"] > 0

0 comments on commit e5d3a0c

Please sign in to comment.