diff --git a/src/modules/csm/csm.py b/src/modules/csm/csm.py index 239594951..0802b960a 100644 --- a/src/modules/csm/csm.py +++ b/src/modules/csm/csm.py @@ -18,6 +18,7 @@ from src.modules.csm.types import ReportData, Shares from src.modules.submodules.consensus import ConsensusModule from src.modules.submodules.oracle_module import BaseModule, ModuleExecuteDelay +from src.modules.submodules.types import ZERO_HASH from src.providers.execution.contracts.cs_fee_oracle import CSFeeOracleContract from src.providers.execution.exceptions import InconsistentData from src.providers.ipfs import CID @@ -95,33 +96,42 @@ def execute_module(self, last_finalized_blockstamp: BlockStamp) -> ModuleExecute def build_report(self, blockstamp: ReferenceBlockStamp) -> tuple: self.validate_state(blockstamp) - distributed, shares, log = self.calculate_distribution(blockstamp) - if not distributed: - logger.info({"msg": "No shares distributed in the current frame"}) - - # Load the previous tree if any. prev_root = self.w3.csm.get_csm_tree_root(blockstamp) prev_cid = self.w3.csm.get_csm_tree_cid(blockstamp) - if prev_cid: + if bool(prev_cid) != (prev_root != ZERO_HASH): + raise InconsistentData(f"Got inconsistent previous tree data: {prev_root=} {prev_cid=}") + + distributed, shares, log = self.calculate_distribution(blockstamp) + + log_cid = self.publish_log(log) + + if not distributed: + logger.info({"msg": "No shares distributed in the current frame"}) + return ReportData( + self.report_contract.get_consensus_version(blockstamp.block_hash), + blockstamp.ref_slot, + tree_root=prev_root, + tree_cid=prev_cid or "", + log_cid=log_cid, + distributed=0, + ).as_tuple() + + if prev_cid and prev_root != ZERO_HASH: # Update cumulative amount of shares for all operators. for no_id, acc_shares in self.get_accumulated_shares(prev_cid, prev_root): shares[no_id] += acc_shares else: - logger.info({"msg": "No previous CID available"}) + logger.info({"msg": "No previous distribution. Nothing to accumulate"}) tree = self.make_tree(shares) - tree_cid: CID | None = None - - log_cid = self.publish_log(log) - if tree: - tree_cid = self.publish_tree(tree) + tree_cid = self.publish_tree(tree) return ReportData( self.report_contract.get_consensus_version(blockstamp.block_hash), blockstamp.ref_slot, - tree_root=tree.root if tree else prev_root, - tree_cid=tree_cid or prev_cid or "", + tree_root=tree.root, + tree_cid=tree_cid, log_cid=log_cid, distributed=distributed, ).as_tuple() @@ -169,7 +179,9 @@ def collect_data(self, blockstamp: BlockStamp) -> bool: report_blockstamp = self.get_blockstamp_for_report(blockstamp) if report_blockstamp and report_blockstamp.ref_epoch != r_epoch: logger.warning( - {"msg": f"Frame has been changed, but the change is not yet observed on finalized epoch {finalized_epoch}"} + { + "msg": f"Frame has been changed, but the change is not yet observed on finalized epoch {finalized_epoch}" + } ) return False @@ -296,9 +308,9 @@ def stuck_operators(self, blockstamp: ReferenceBlockStamp) -> set[NodeOperatorId ) return stuck - def make_tree(self, shares: dict[NodeOperatorId, Shares]) -> Tree | None: + def make_tree(self, shares: dict[NodeOperatorId, Shares]) -> Tree: if not shares: - return None + raise ValueError("No shares to build a tree") # XXX: We put a stone here to make sure, that even with only 1 node operator in the tree, it's still possible to # claim rewards. The CSModule contract skips pulling rewards if the proof's length is zero, which is the case diff --git a/src/modules/csm/types.py b/src/modules/csm/types.py index 94f549eb7..541a0fc5f 100644 --- a/src/modules/csm/types.py +++ b/src/modules/csm/types.py @@ -1,6 +1,6 @@ import logging from dataclasses import dataclass -from typing import Literal, TypeAlias +from typing import TypeAlias, Literal from hexbytes import HexBytes diff --git a/src/web3py/extensions/csm.py b/src/web3py/extensions/csm.py index 2dd8bd760..65206cda7 100644 --- a/src/web3py/extensions/csm.py +++ b/src/web3py/extensions/csm.py @@ -48,7 +48,7 @@ def get_csm_tree_root(self, blockstamp: BlockStamp) -> HexBytes: def get_csm_tree_cid(self, blockstamp: BlockStamp) -> CID | None: result = self.fee_distributor.tree_cid(blockstamp.block_hash) - if not result: + if result == "": return None return CIDv0(result) if is_cid_v0(result) else CIDv1(result) diff --git a/tests/modules/csm/test_csm_module.py b/tests/modules/csm/test_csm_module.py index f078168c3..7deccd026 100644 --- a/tests/modules/csm/test_csm_module.py +++ b/tests/modules/csm/test_csm_module.py @@ -1,14 +1,19 @@ import logging +from collections import defaultdict from dataclasses import dataclass -from typing import NoReturn +from typing import NoReturn, Iterable, Literal, Type from unittest.mock import Mock, patch, PropertyMock import pytest +from hexbytes import HexBytes from src.constants import UINT64_MAX from src.modules.csm.csm import CSOracle from src.modules.csm.state import AttestationsAccumulator, State -from src.modules.submodules.types import CurrentFrame +from src.modules.csm.tree import Tree +from src.modules.submodules.oracle_module import ModuleExecuteDelay +from src.modules.submodules.types import CurrentFrame, ZERO_HASH +from src.providers.ipfs import CIDv0, CID from src.types import EpochNumber, NodeOperatorId, SlotNumber, StakingModuleId, ValidatorIndex from src.web3py.extensions.csm import CSM from tests.factory.blockstamp import ReferenceBlockStampFactory @@ -210,7 +215,7 @@ class FrameTestParam: last_processing_ref_slot: int current_ref_slot: int finalized_slot: int - expected_frame: tuple[int, int] + expected_frame: tuple[int, int] | Type[ValueError] @pytest.mark.parametrize( @@ -223,10 +228,9 @@ class FrameTestParam: last_processing_ref_slot=0, current_ref_slot=0, finalized_slot=0, - expected_frame=(0, 0), + expected_frame=ValueError, ), id="initial_epoch_not_set", - marks=pytest.mark.xfail(raises=ValueError), ), pytest.param( FrameTestParam( @@ -324,10 +328,15 @@ def test_current_frame_range(module: CSOracle, csm: CSM, mock_chain_config: NoRe ) ) module.get_initial_ref_slot = Mock(return_value=param.initial_ref_slot) - bs = ReferenceBlockStampFactory.build(slot_number=param.finalized_slot) - l_epoch, r_epoch = module.current_frame_range(bs) - assert (l_epoch, r_epoch) == param.expected_frame + if param.expected_frame is ValueError: + with pytest.raises(ValueError): + module.current_frame_range(ReferenceBlockStampFactory.build(slot_number=param.finalized_slot)) + else: + bs = ReferenceBlockStampFactory.build(slot_number=param.finalized_slot) + + l_epoch, r_epoch = module.current_frame_range(bs) + assert (l_epoch, r_epoch) == param.expected_frame @pytest.fixture() @@ -493,3 +502,273 @@ def test_collect_data_fulfilled_state( # assert that it is not early return from function msg = list(filter(lambda log: "All epochs are already processed. Nothing to collect" in log, caplog.messages)) assert len(msg) == 0, "Unexpected message found in logs" + + +@dataclass(frozen=True) +class BuildReportTestParam: + prev_tree_root: HexBytes + prev_tree_cid: CID | None + prev_acc_shares: Iterable[tuple[NodeOperatorId, int]] + curr_distribution: Mock + curr_tree_root: HexBytes + curr_tree_cid: CID | Literal[""] + curr_log_cid: CID + expected_make_tree_call_args: tuple | None + expected_func_result: tuple + + +@pytest.mark.parametrize( + "param", + [ + pytest.param( + BuildReportTestParam( + prev_tree_root=HexBytes(ZERO_HASH), + prev_tree_cid=None, + prev_acc_shares=[], + curr_distribution=Mock( + return_value=( + # distributed + 0, + # shares + defaultdict(int), + # log + Mock(), + ) + ), + curr_tree_root=HexBytes(ZERO_HASH), + curr_tree_cid="", + curr_log_cid=CID("QmLOG"), + expected_make_tree_call_args=None, + expected_func_result=(1, 100500, HexBytes(ZERO_HASH), "", CID("QmLOG"), 0), + ), + id="empty_prev_report_and_no_new_distribution", + ), + pytest.param( + BuildReportTestParam( + prev_tree_root=HexBytes(ZERO_HASH), + prev_tree_cid=None, + prev_acc_shares=[], + curr_distribution=Mock( + return_value=( + # distributed + 6, + # shares + defaultdict(int, {NodeOperatorId(0): 1, NodeOperatorId(1): 2, NodeOperatorId(2): 3}), + # log + Mock(), + ) + ), + curr_tree_root=HexBytes("NEW_TREE_ROOT".encode()), + curr_tree_cid=CID("QmNEW_TREE"), + curr_log_cid=CID("QmLOG"), + expected_make_tree_call_args=(({NodeOperatorId(0): 1, NodeOperatorId(1): 2, NodeOperatorId(2): 3},),), + expected_func_result=( + 1, + 100500, + HexBytes("NEW_TREE_ROOT".encode()), + CID("QmNEW_TREE"), + CID("QmLOG"), + 6, + ), + ), + id="empty_prev_report_and_new_distribution", + ), + pytest.param( + BuildReportTestParam( + prev_tree_root=HexBytes("OLD_TREE_ROOT".encode()), + prev_tree_cid=CID("QmOLD_TREE"), + prev_acc_shares=[(NodeOperatorId(0), 100), (NodeOperatorId(1), 200), (NodeOperatorId(2), 300)], + curr_distribution=Mock( + return_value=( + # distributed + 6, + # shares + defaultdict(int, {NodeOperatorId(0): 1, NodeOperatorId(1): 2, NodeOperatorId(3): 3}), + # log + Mock(), + ) + ), + curr_tree_root=HexBytes("NEW_TREE_ROOT".encode()), + curr_tree_cid=CID("QmNEW_TREE"), + curr_log_cid=CID("QmLOG"), + expected_make_tree_call_args=( + ({NodeOperatorId(0): 101, NodeOperatorId(1): 202, NodeOperatorId(2): 300, NodeOperatorId(3): 3},), + ), + expected_func_result=( + 1, + 100500, + HexBytes("NEW_TREE_ROOT".encode()), + CID("QmNEW_TREE"), + CID("QmLOG"), + 6, + ), + ), + id="non_empty_prev_report_and_new_distribution", + ), + pytest.param( + BuildReportTestParam( + prev_tree_root=HexBytes("OLD_TREE_ROOT".encode()), + prev_tree_cid=CID("QmOLD_TREE"), + prev_acc_shares=[(NodeOperatorId(0), 100), (NodeOperatorId(1), 200), (NodeOperatorId(2), 300)], + curr_distribution=Mock( + return_value=( + # distributed + 0, + # shares + defaultdict(int), + # log + Mock(), + ) + ), + curr_tree_root=HexBytes(32), + curr_tree_cid="", + curr_log_cid=CID("QmLOG"), + expected_make_tree_call_args=None, + expected_func_result=( + 1, + 100500, + HexBytes("OLD_TREE_ROOT".encode()), + CID("QmOLD_TREE"), + CID("QmLOG"), + 0, + ), + ), + id="non_empty_prev_report_and_no_new_distribution", + ), + ], +) +def test_build_report(csm: CSM, module: CSOracle, param: BuildReportTestParam): + module.validate_state = Mock() + module.report_contract.get_consensus_version = Mock(return_value=1) + # mock previous report + module.w3.csm.get_csm_tree_root = Mock(return_value=param.prev_tree_root) + module.w3.csm.get_csm_tree_cid = Mock(return_value=param.prev_tree_cid) + module.get_accumulated_shares = Mock(return_value=param.prev_acc_shares) + # mock current frame + module.calculate_distribution = param.curr_distribution + module.make_tree = Mock(return_value=Mock(root=param.curr_tree_root)) + module.publish_tree = Mock(return_value=param.curr_tree_cid) + module.publish_log = Mock(return_value=param.curr_log_cid) + + report = module.build_report(blockstamp=Mock(ref_slot=100500)) + + assert module.make_tree.call_args == param.expected_make_tree_call_args + assert report == param.expected_func_result + + +def test_execute_module_not_collected(module: CSOracle): + module.collect_data = Mock(return_value=False) + + execute_delay = module.execute_module( + last_finalized_blockstamp=Mock(slot_number=100500), + ) + assert execute_delay is ModuleExecuteDelay.NEXT_FINALIZED_EPOCH + + +def test_execute_module_no_report_blockstamp(module: CSOracle): + module.collect_data = Mock(return_value=True) + module.get_blockstamp_for_report = Mock(return_value=None) + + execute_delay = module.execute_module( + last_finalized_blockstamp=Mock(slot_number=100500), + ) + assert execute_delay is ModuleExecuteDelay.NEXT_FINALIZED_EPOCH + + +def test_execute_module_processed(module: CSOracle): + module.collect_data = Mock(return_value=True) + module.get_blockstamp_for_report = Mock(return_value=Mock(slot_number=100500)) + module.process_report = Mock() + + execute_delay = module.execute_module( + last_finalized_blockstamp=Mock(slot_number=100500), + ) + assert execute_delay is ModuleExecuteDelay.NEXT_SLOT + + +@pytest.fixture() +def tree(): + return Tree.new( + [ + (NodeOperatorId(0), 0), + (NodeOperatorId(1), 1), + (NodeOperatorId(2), 42), + (NodeOperatorId(UINT64_MAX), 0), + ] + ) + + +def test_get_accumulated_shares(module: CSOracle, tree: Tree): + encoded_tree = tree.encode() + module.w3.ipfs = Mock(fetch=Mock(return_value=encoded_tree)) + + for i, leaf in enumerate(module.get_accumulated_shares(cid=CIDv0("0x100500"), root=tree.root)): + assert tuple(leaf) == tree.tree.values[i]["value"] + + +def test_get_accumulated_shares_unexpected_root(module: CSOracle, tree: Tree): + encoded_tree = tree.encode() + module.w3.ipfs = Mock(fetch=Mock(return_value=encoded_tree)) + + with pytest.raises(ValueError): + next(module.get_accumulated_shares(cid=CIDv0("0x100500"), root=HexBytes("0x100500"))) + + +@dataclass(frozen=True) +class MakeTreeTestParam: + shares: dict[NodeOperatorId, int] + expected_tree_values: tuple | Type[ValueError] + + +@pytest.mark.parametrize( + "param", + [ + pytest.param(MakeTreeTestParam(shares={}, expected_tree_values=ValueError), id="empty"), + pytest.param( + MakeTreeTestParam( + shares={NodeOperatorId(0): 1, NodeOperatorId(1): 2, NodeOperatorId(2): 3}, + expected_tree_values=( + {'treeIndex': 4, 'value': (0, 1)}, + {'treeIndex': 2, 'value': (1, 2)}, + {'treeIndex': 3, 'value': (2, 3)}, + ), + ), + id="normal_tree", + ), + pytest.param( + MakeTreeTestParam( + shares={NodeOperatorId(0): 1}, + expected_tree_values=( + {'treeIndex': 2, 'value': (0, 1)}, + {'treeIndex': 1, 'value': (18446744073709551615, 0)}, + ), + ), + id="put_stone", + ), + pytest.param( + MakeTreeTestParam( + shares={ + NodeOperatorId(0): 1, + NodeOperatorId(1): 2, + NodeOperatorId(2): 3, + NodeOperatorId(18446744073709551615): 0, + }, + expected_tree_values=( + {'treeIndex': 4, 'value': (0, 1)}, + {'treeIndex': 2, 'value': (1, 2)}, + {'treeIndex': 3, 'value': (2, 3)}, + ), + ), + id="remove_stone", + ), + ], +) +def test_make_tree(module: CSOracle, param: MakeTreeTestParam): + module.w3.csm.module.MAX_OPERATORS_COUNT = UINT64_MAX + + if param.expected_tree_values is ValueError: + with pytest.raises(ValueError): + module.make_tree(param.shares) + else: + tree = module.make_tree(param.shares) + assert tree.tree.values == param.expected_tree_values