Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: tests for csm build report #516

Merged
merged 11 commits into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/modules/csm/csm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from src.modules.csm.types import ReportData, Shares
from src.modules.submodules.consensus import ConsensusModule
from src.modules.submodules.oracle_module import BaseModule, ModuleExecuteDelay
from src.modules.submodules.types import ZERO_HASH
from src.providers.execution.contracts.cs_fee_oracle import CSFeeOracleContract
from src.providers.execution.exceptions import InconsistentData
from src.providers.ipfs import CID
Expand Down Expand Up @@ -98,7 +99,7 @@ def build_report(self, blockstamp: ReferenceBlockStamp) -> tuple:
prev_root = self.w3.csm.get_csm_tree_root(blockstamp)
prev_cid = self.w3.csm.get_csm_tree_cid(blockstamp)

if bool(prev_root) != bool(prev_cid):
if bool(prev_root) != (prev_root is not ZERO_HASH):
vgorkavenko marked this conversation as resolved.
Show resolved Hide resolved
raise InconsistentData(f"Got inconsistent previous tree data: {prev_root=} {prev_cid=}")

distributed, shares, log = self.calculate_distribution(blockstamp)
Expand All @@ -110,13 +111,13 @@ def build_report(self, blockstamp: ReferenceBlockStamp) -> tuple:
return ReportData(
self.report_contract.get_consensus_version(blockstamp.block_hash),
blockstamp.ref_slot,
tree_root=prev_root or HexBytes(32),
tree_root=prev_root,
tree_cid=prev_cid or "",
log_cid=log_cid,
distributed=0,
).as_tuple()

if prev_cid and prev_root:
if prev_cid and prev_root is not ZERO_HASH:
madlabman marked this conversation as resolved.
Show resolved Hide resolved
# Update cumulative amount of shares for all operators.
for no_id, acc_shares in self.get_accumulated_shares(prev_cid, prev_root):
shares[no_id] += acc_shares
Expand Down
7 changes: 2 additions & 5 deletions src/web3py/extensions/csm.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,8 @@ def get_csm_last_processing_ref_slot(self, blockstamp: BlockStamp) -> SlotNumber
FRAME_PREV_REPORT_REF_SLOT.labels("csm_oracle").set(result)
return result

def get_csm_tree_root(self, blockstamp: BlockStamp) -> HexBytes | None:
result = self.fee_distributor.tree_root(blockstamp.block_hash)
if result == HexBytes(32):
return None
return result
def get_csm_tree_root(self, blockstamp: BlockStamp) -> HexBytes:
return self.fee_distributor.tree_root(blockstamp.block_hash)

def get_csm_tree_cid(self, blockstamp: BlockStamp) -> CID | None:
result = self.fee_distributor.tree_cid(blockstamp.block_hash)
Expand Down
12 changes: 6 additions & 6 deletions tests/modules/csm/test_csm_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from src.modules.csm.state import AttestationsAccumulator, State
from src.modules.csm.tree import Tree
from src.modules.submodules.oracle_module import ModuleExecuteDelay
from src.modules.submodules.types import CurrentFrame
from src.modules.submodules.types import CurrentFrame, ZERO_HASH
from src.providers.ipfs import CIDv0, CID
from src.types import EpochNumber, NodeOperatorId, SlotNumber, StakingModuleId, ValidatorIndex
from src.web3py.extensions.csm import CSM
Expand Down Expand Up @@ -506,7 +506,7 @@ def test_collect_data_fulfilled_state(

@dataclass(frozen=True)
class BuildReportTestParam:
prev_tree_root: HexBytes | None
prev_tree_root: HexBytes
prev_tree_cid: CID | None
prev_acc_shares: Iterable[tuple[NodeOperatorId, int]]
curr_distribution: Mock
Expand All @@ -522,7 +522,7 @@ class BuildReportTestParam:
[
pytest.param(
BuildReportTestParam(
prev_tree_root=None,
prev_tree_root=HexBytes(ZERO_HASH),
prev_tree_cid=None,
prev_acc_shares=[],
curr_distribution=Mock(
Expand All @@ -535,17 +535,17 @@ class BuildReportTestParam:
Mock(),
)
),
curr_tree_root=HexBytes(32),
curr_tree_root=HexBytes(ZERO_HASH),
curr_tree_cid="",
curr_log_cid=CID("QmLOG"),
expected_make_tree_call_args=None,
expected_func_result=(1, 100500, HexBytes(32), "", CID("QmLOG"), 0),
expected_func_result=(1, 100500, HexBytes(ZERO_HASH), "", CID("QmLOG"), 0),
),
id="empty_prev_report_and_no_new_distribution",
),
pytest.param(
BuildReportTestParam(
prev_tree_root=None,
prev_tree_root=HexBytes(ZERO_HASH),
prev_tree_cid=None,
prev_acc_shares=[],
curr_distribution=Mock(
Expand Down
Loading