Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CSM] feat: state data as tuples #557

Draft
wants to merge 2 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/modules/csm/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class FrameCheckpoint:
@dataclass
class ValidatorDuty:
index: ValidatorIndex
included: bool
is_included: bool


class FrameCheckpointsIterator:
Expand Down Expand Up @@ -200,7 +200,7 @@ def _check_duty(
for validator_duty in committee:
self.state.inc(
validator_duty.index,
included=validator_duty.included,
validator_duty.is_included,
)
if duty_epoch not in self.state.unprocessed_epochs:
raise ValueError(f"Epoch {duty_epoch} is not in epochs that should be processed")
Expand All @@ -222,7 +222,7 @@ def _prepare_committees(self, epoch: EpochNumber) -> Committees:
validators = []
# Order of insertion is used to track the positions in the committees.
for validator in committee.validators:
validators.append(ValidatorDuty(index=ValidatorIndex(int(validator)), included=False))
validators.append(ValidatorDuty(index=ValidatorIndex(int(validator)), is_included=False))
committees[(committee.slot, committee.index)] = validators
return committees

Expand All @@ -233,7 +233,7 @@ def process_attestations(attestations: Iterable[BlockAttestation], committees: C
committee = committees.get(committee_id, [])
att_bits = _to_bits(attestation.aggregation_bits)
for index_in_committee, validator_duty in enumerate(committee):
validator_duty.included = validator_duty.included or _is_attested(att_bits, index_in_committee)
validator_duty.is_included = validator_duty.is_included or _is_attested(att_bits, index_in_committee)


def _is_attested(bits: Sequence[bool], index: int) -> bool:
Expand Down
21 changes: 13 additions & 8 deletions src/modules/csm/csm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
)
from src.metrics.prometheus.duration_meter import duration_meter
from src.modules.csm.checkpoint import FrameCheckpointProcessor, FrameCheckpointsIterator, MinStepIsNotReached
from src.modules.csm.log import FramePerfLog
from src.modules.csm.state import State
from src.modules.csm.log import FramePerfLog, AttestationsAccumulatorLog
from src.modules.csm.state import State, perf, Assigned, Included, AttestationsAccumulator
from src.modules.csm.tree import Tree
from src.modules.csm.types import ReportData, Shares
from src.modules.submodules.consensus import ConsensusModule
Expand Down Expand Up @@ -228,7 +228,7 @@ def calculate_distribution(
) -> tuple[int, defaultdict[NodeOperatorId, int], FramePerfLog]:
"""Computes distribution of fee shares at the given timestamp"""

network_avg_perf = self.state.get_network_aggr().perf
network_avg_perf = perf(self.state.get_network_aggr())
threshold = network_avg_perf - self.w3.csm.oracle.perf_leeway_bp(blockstamp.block_hash) / TOTAL_BASIS_POINTS
operators_to_validators = self.module_validators_by_node_operators(blockstamp)

Expand All @@ -243,9 +243,14 @@ def calculate_distribution(
continue

for v in validators:
aggr = self.state.data.get(ValidatorIndex(int(v.index)))
aggr = (
self.state.data[ValidatorIndex(int(v.index))] or
AttestationsAccumulator((Assigned(0), Included(0)))
)

if aggr is None:
assigned, included = aggr

if not assigned:
# It's possible that the validator is not assigned to any duty, hence it's performance
# is not presented in the aggregates (e.g. exited, pending for activation etc).
continue
Expand All @@ -256,12 +261,12 @@ def calculate_distribution(
log.operators[no_id].validators[v.index].slashed = True
continue

if aggr.perf > threshold:
if perf(aggr) > threshold:
# Count of assigned attestations used as a metrics of time
# the validator was active in the current frame.
distribution[no_id] += aggr.assigned
distribution[no_id] += assigned

log.operators[no_id].validators[v.index].perf = aggr
log.operators[no_id].validators[v.index].perf = AttestationsAccumulatorLog(assigned, included)

# Calculate share of each CSM node operator.
shares = defaultdict[NodeOperatorId, int](int)
Expand Down
10 changes: 8 additions & 2 deletions src/modules/csm/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,23 @@
from collections import defaultdict
from dataclasses import asdict, dataclass, field

from src.modules.csm.state import AttestationsAccumulator
from src.modules.csm.types import Shares
from src.types import EpochNumber, NodeOperatorId, ReferenceBlockStamp


class LogJSONEncoder(json.JSONEncoder): ...


@dataclass
class AttestationsAccumulatorLog:
assigned: int = 0
included: int = 0


@dataclass
class ValidatorFrameSummary:
perf: AttestationsAccumulator = field(default_factory=AttestationsAccumulator)
# TODO: Should be renamed. Perf means different things in different contexts
perf: AttestationsAccumulatorLog = field(default_factory=AttestationsAccumulatorLog)
slashed: bool = False


Expand Down
71 changes: 33 additions & 38 deletions src/modules/csm/state.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import logging
import os
import pickle
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Self
from typing import Self, NewType

from src.types import EpochNumber, ValidatorIndex
from src.utils.range import sequence
Expand All @@ -17,20 +15,14 @@ class InvalidState(ValueError):
"""State has data considered as invalid for a report"""


@dataclass
class AttestationsAccumulator:
"""Accumulator of attestations duties observed for a validator"""
Assigned = NewType("Assigned", int)
Included = NewType("Included", int)
AttestationsAccumulator = NewType('AttestationsAccumulator', tuple[Assigned, Included])

assigned: int = 0
included: int = 0

@property
def perf(self) -> float:
return self.included / self.assigned if self.assigned else 0

def add_duty(self, included: bool) -> None:
self.assigned += 1
self.included += 1 if included else 0
def perf(acc: AttestationsAccumulator) -> float:
assigned, included = acc
return included / assigned if assigned else 0


class State:
Expand All @@ -43,15 +35,15 @@ class State:

The state can be migrated to be used for another frame's report by calling the `migrate` method.
"""
# validator_index -> (assigned, included)
data: list[AttestationsAccumulator | None]

data: defaultdict[ValidatorIndex, AttestationsAccumulator]

_epochs_to_process: tuple[EpochNumber, ...]
_epochs_to_process: set[EpochNumber]
_processed_epochs: set[EpochNumber]

def __init__(self, data: dict[ValidatorIndex, AttestationsAccumulator] | None = None) -> None:
self.data = defaultdict(AttestationsAccumulator, data or {})
self._epochs_to_process = tuple()
def __init__(self, data: list[AttestationsAccumulator | None] | None = None) -> None:
self.data = data or []
self._epochs_to_process = set()
self._processed_epochs = set()

EXTENSION = ".pkl"
Expand Down Expand Up @@ -88,13 +80,16 @@ def buffer(self) -> Path:
return self.file().with_suffix(".buf")

def clear(self) -> None:
self.data = defaultdict(AttestationsAccumulator)
self._epochs_to_process = tuple()
self.data = []
self._epochs_to_process.clear()
self._processed_epochs.clear()
assert self.is_empty

def inc(self, key: ValidatorIndex, included: bool) -> None:
self.data[key].add_duty(included)
def inc(self, key: ValidatorIndex, is_included: bool) -> None:
if key >= len(self.data):
self.data += [None] * (key - len(self.data) + 1)
assigned, included = self.data[key] or (Assigned(0), Included(0))
self.data[key] = AttestationsAccumulator((Assigned(assigned + 1), Included(included + 1 if is_included else included)))

def add_processed_epoch(self, epoch: EpochNumber) -> None:
self._processed_epochs.add(epoch)
Expand All @@ -110,7 +105,7 @@ def migrate(self, l_epoch: EpochNumber, r_epoch: EpochNumber):
self.clear()
break

self._epochs_to_process = tuple(sequence(l_epoch, r_epoch))
self._epochs_to_process = set(sequence(l_epoch, r_epoch))
self.commit()

def validate(self, l_epoch: EpochNumber, r_epoch: EpochNumber) -> None:
Expand All @@ -133,7 +128,7 @@ def is_empty(self) -> bool:
def unprocessed_epochs(self) -> set[EpochNumber]:
if not self._epochs_to_process:
raise ValueError("Epochs to process are not set")
diff = set(self._epochs_to_process) - self._processed_epochs
diff = self._epochs_to_process - self._processed_epochs
return diff

@property
Expand All @@ -149,15 +144,15 @@ def frame(self) -> tuple[EpochNumber, EpochNumber]:
def get_network_aggr(self) -> AttestationsAccumulator:
"""Return `AttestationsAccumulator` over duties of all the network validators"""

included = assigned = 0
for validator, acc in self.data.items():
if acc.included > acc.assigned:
raise ValueError(f"Invalid accumulator: {validator=}, {acc=}")
included += acc.included
assigned += acc.assigned
aggr = AttestationsAccumulator(
included=included,
assigned=assigned,
)
logger.info({"msg": "Network attestations aggregate computed", "value": repr(aggr), "avg_perf": aggr.perf})
net_included = net_assigned = 0
for validator_index, acc in enumerate(self.data):
if acc is None:
continue
assigned, included = acc
if included > assigned:
raise ValueError(f"Invalid accumulator: {validator_index=}, {acc=}")
net_included += included
net_assigned += assigned
aggr = AttestationsAccumulator((Assigned(net_assigned), Included(net_included)))
logger.info({"msg": "Network attestations aggregate computed", "value": repr(aggr), "avg_perf": perf(aggr)})
return aggr
8 changes: 4 additions & 4 deletions tests/modules/csm/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def test_checkpoints_processor_prepare_committees(mock_get_attestation_committee
assert int(committee_index) == committee_from_raw.index
assert len(validators) == 32
for validator in validators:
assert validator.included is False
assert validator.is_included is False


def test_checkpoints_processor_process_attestations(mock_get_attestation_committees, consensus_client, converter):
Expand Down Expand Up @@ -265,9 +265,9 @@ def test_checkpoints_processor_process_attestations(mock_get_attestation_committ
for validator in validators:
# only the first attestation is accounted
if index == 0:
assert validator.included is True
assert validator.is_included is True
else:
assert validator.included is False
assert validator.is_included is False


def test_checkpoints_processor_process_attestations_undefined_committee(
Expand All @@ -290,7 +290,7 @@ def test_checkpoints_processor_process_attestations_undefined_committee(
process_attestations([attestation], committees)
for validators in committees.values():
for v in validators:
assert v.included is False
assert v.is_included is False


@pytest.fixture()
Expand Down
47 changes: 30 additions & 17 deletions tests/modules/csm/test_csm_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from src.constants import UINT64_MAX
from src.modules.csm.csm import CSOracle
from src.modules.csm.state import AttestationsAccumulator, State
from src.modules.csm.state import AttestationsAccumulator, State, perf
from src.modules.csm.tree import Tree
from src.modules.submodules.oracle_module import ModuleExecuteDelay
from src.modules.submodules.types import CurrentFrame, ZERO_HASH
Expand Down Expand Up @@ -118,21 +118,34 @@ def test_calculate_distribution(module: CSOracle, csm: CSM):
)

module.state = State(
{
ValidatorIndex(0): AttestationsAccumulator(included=200, assigned=200), # short on frame
ValidatorIndex(1): AttestationsAccumulator(included=1000, assigned=1000),
ValidatorIndex(2): AttestationsAccumulator(included=1000, assigned=1000),
ValidatorIndex(3): AttestationsAccumulator(included=999, assigned=1000),
ValidatorIndex(4): AttestationsAccumulator(included=900, assigned=1000),
ValidatorIndex(5): AttestationsAccumulator(included=500, assigned=1000), # underperforming
ValidatorIndex(6): AttestationsAccumulator(included=0, assigned=0), # underperforming
ValidatorIndex(7): AttestationsAccumulator(included=900, assigned=1000),
ValidatorIndex(8): AttestationsAccumulator(included=500, assigned=1000), # underperforming
# ValidatorIndex(9): AttestationsAggregate(included=0, assigned=0), # missing in state
ValidatorIndex(10): AttestationsAccumulator(included=1000, assigned=1000),
ValidatorIndex(11): AttestationsAccumulator(included=1000, assigned=1000),
ValidatorIndex(12): AttestationsAccumulator(included=1000, assigned=1000),
}
[
# ValidatorIndex(0):
AttestationsAccumulator((200, 200)), # short on frame
# ValidatorIndex(1):
AttestationsAccumulator((1000, 1000)),
# ValidatorIndex(2):
AttestationsAccumulator((1000, 1000)),
# ValidatorIndex(3):
AttestationsAccumulator((1000, 999)),
# ValidatorIndex(4):
AttestationsAccumulator((1000, 900)),
# ValidatorIndex(5):
AttestationsAccumulator((1000, 500)), # underperforming
# ValidatorIndex(6):
AttestationsAccumulator((1000, 0)), # underperforming
# ValidatorIndex(7):
AttestationsAccumulator((1000, 900)),
# ValidatorIndex(8):
AttestationsAccumulator((1000, 500)), # underperforming
# ValidatorIndex(9):
None, # missing in state
# ValidatorIndex(10):
AttestationsAccumulator((1000, 1000)),
# ValidatorIndex(11):
AttestationsAccumulator((1000, 1000)),
# ValidatorIndex(12):
AttestationsAccumulator((1000, 1000)),
]
)
module.state.migrate(EpochNumber(100), EpochNumber(500))

Expand Down Expand Up @@ -177,7 +190,7 @@ def test_calculate_distribution(module: CSOracle, csm: CSM):
assert log.operators[NodeOperatorId(6)].distributed == 2380

assert log.frame == (100, 500)
assert log.threshold == module.state.get_network_aggr().perf - 0.05
assert log.threshold == perf(module.state.get_network_aggr()) - 0.05


# Static functions you were dreaming of for so long.
Expand Down
4 changes: 2 additions & 2 deletions tests/modules/csm/test_log.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import pytest

from src.modules.csm.log import FramePerfLog
from src.modules.csm.log import FramePerfLog, AttestationsAccumulatorLog
from src.modules.csm.state import AttestationsAccumulator
from src.types import EpochNumber, NodeOperatorId, ReferenceBlockStamp
from tests.factory.blockstamp import ReferenceBlockStampFactory
Expand Down Expand Up @@ -29,7 +29,7 @@ def test_fields_access(log: FramePerfLog):

def test_log_encode(log: FramePerfLog):
# Fill in dynamic fields to make sure we have data in it to be encoded.
log.operators[NodeOperatorId(42)].validators["41337"].perf = AttestationsAccumulator(220, 119)
log.operators[NodeOperatorId(42)].validators["41337"].perf = AttestationsAccumulatorLog(220, 119)
log.operators[NodeOperatorId(42)].distributed = 17
log.operators[NodeOperatorId(0)].distributed = 0

Expand Down
Loading
Loading