diff --git a/py_ecc/bls/__init__.py b/py_ecc/bls/__init__.py index f52c3de1..0710cedb 100644 --- a/py_ecc/bls/__init__.py +++ b/py_ecc/bls/__init__.py @@ -5,4 +5,5 @@ sign, verify, verify_multiple, + verify_multiple_multiple, ) diff --git a/py_ecc/bls/api.py b/py_ecc/bls/api.py index 6a07d1d8..00b14efe 100644 --- a/py_ecc/bls/api.py +++ b/py_ecc/bls/api.py @@ -1,7 +1,18 @@ +from operator import ( + itemgetter, +) +from secrets import ( + randbelow, +) from typing import ( + Iterator, Sequence, + Tuple, ) +from cytoolz.itertoolz import ( + groupby, +) from eth_typing import ( BLSPubkey, BLSSignature, @@ -24,6 +35,10 @@ neg, pairing, ) + +from .typing import ( + G1Uncompressed, +) from .utils import ( G1_to_pubkey, G2_to_signature, @@ -80,32 +95,66 @@ def aggregate_pubkeys(pubkeys: Sequence[BLSPubkey]) -> BLSPubkey: return G1_to_pubkey(o) +def _group_key_by_msg(pubkeys: Sequence[BLSPubkey], + message_hashes: Sequence[Hash32])-> Iterator[Tuple[G1Uncompressed, Hash32]]: + if len(pubkeys) != len(message_hashes): + raise ValidationError( + "len(pubkeys) (%s) should be equal to len(message_hashes) (%s)" % ( + len(pubkeys), len(message_hashes) + ) + ) + groups_dict = groupby(itemgetter(1), enumerate(message_hashes)) + for message_hash, group in groups_dict.items(): + agg_key = Z1 + for i, _ in group: + agg_key = add(agg_key, pubkey_to_G1(pubkeys[i])) + yield agg_key, message_hash + + def verify_multiple(pubkeys: Sequence[BLSPubkey], message_hashes: Sequence[Hash32], signature: BLSSignature, domain: int) -> bool: - len_msgs = len(message_hashes) - if len(pubkeys) != len_msgs: + o = FQ12.one() + for pubkey, message_hash in _group_key_by_msg(pubkeys, message_hashes): + o *= pairing( + hash_to_G2(message_hash, domain), + pubkey, + final_exponentiate=False, + ) + o *= pairing(signature_to_G2(signature), neg(G1), final_exponentiate=False) + final_exponentiation = final_exponentiate(o) + return final_exponentiation == FQ12.one() + + +def verify_multiple_multiple( + signatures: Sequence[BLSSignature], + pubkeys_and_messages: Sequence[Tuple[Sequence[BLSPubkey], Sequence[Hash32]]], + domain: int)-> bool: + """ + This is the optimized version of len(signatures) rounds of verify_multiple + """ + if len(signatures) != len(pubkeys_and_messages): raise ValidationError( - "len(pubkeys) (%s) should be equal to len(message_hashes) (%s)" % ( - len(pubkeys), len_msgs + "len(signatures) (%s) should be equal to len(pubkeys_and_messages) (%s)" % ( + len(signatures), len(pubkeys_and_messages) ) ) - try: - o = FQ12([1] + [0] * 11) - for m_pubs in set(message_hashes): - # aggregate the pubs - group_pub = Z1 - for i in range(len_msgs): - if message_hashes[i] == m_pubs: - group_pub = add(group_pub, pubkey_to_G1(pubkeys[i])) - - o *= pairing(hash_to_G2(m_pubs, domain), group_pub, final_exponentiate=False) - o *= pairing(signature_to_G2(signature), neg(G1), final_exponentiate=False) - - final_exponentiation = final_exponentiate(o) - return final_exponentiation == FQ12.one() - except (ValidationError, ValueError, AssertionError): - return False + random_ints = (1,) + tuple(2**randbelow(64) for _ in signatures[:-1]) + o = FQ12.one() + for r_i, (pubkeys, message_hashes) in zip(random_ints, pubkeys_and_messages): + for pubkey, message_hash in _group_key_by_msg(pubkeys, message_hashes): + o *= pairing( + multiply(hash_to_G2(message_hash, domain), r_i), + pubkey, + final_exponentiate=False, + ) + agg_sig = Z2 + for r_i, sig in zip(random_ints, signatures): + agg_sig = add(agg_sig, multiply(signature_to_G2(sig), r_i)) + o *= pairing(agg_sig, neg(G1), final_exponentiate=False) + + final_exponentiation = final_exponentiate(o) + return final_exponentiation == FQ12.one() diff --git a/scripts/benchmark_multi_multi.py b/scripts/benchmark_multi_multi.py new file mode 100644 index 00000000..90a48ca1 --- /dev/null +++ b/scripts/benchmark_multi_multi.py @@ -0,0 +1,68 @@ +from random import sample +from py_ecc.bls import ( + aggregate_pubkeys, + aggregate_signatures, + sign, + privtopub, + verify_multiple, + verify_multiple_multiple, +) +import time + +domain = 0 +validator_indices = tuple(range(1000)) +privkeys = tuple(2**i for i in validator_indices) +pubkeys = [privtopub(k) for k in privkeys] + +MAX_ATTESTATIONS = 128 +TARGET_COMMITTEE_SIZE = 128 + + +class Attestation: + def __init__(self, msg_1, msg_2): + msg_1_validators = sample(validator_indices, TARGET_COMMITTEE_SIZE//2) + msg_2_validators = sample(validator_indices, TARGET_COMMITTEE_SIZE//2) + self.agg_pubkeys = [ + aggregate_pubkeys([pubkeys[i] for i in msg_1_validators]), + aggregate_pubkeys([pubkeys[i] for i in msg_2_validators]), + ] + self.msgs = [msg_1, msg_2] + msg_1_sigs = [sign(msg_1, privkeys[i], domain) for i in msg_1_validators] + msg_2_sigs = [sign(msg_2, privkeys[i], domain) for i in msg_2_validators] + self.sig = aggregate_signatures([ + aggregate_signatures(msg_1_sigs), + aggregate_signatures(msg_2_sigs), + ]) + + +att = Attestation(b'\x12' * 32, b'\x34' * 32) +atts = (att,) * MAX_ATTESTATIONS + + +def profile_verify_multiple(): + t = time.time() + for att in atts: + assert verify_multiple( + pubkeys=att.agg_pubkeys, + message_hashes=att.msgs, + signature=att.sig, + domain=domain, + ) + print(time.time() - t) + + +def profile_verify_multiple_multiple(): + t = time.time() + assert verify_multiple_multiple( + signatures=[att.sig for att in atts], + pubkeys_and_messages=[[att.agg_pubkeys, att.msgs] for att in atts], + domain=domain, + ) + print(time.time() - t) + + +if __name__ == '__main__': + print("profile_verify_multiple") + profile_verify_multiple() + print("profile_verify_multiple_multiple") + profile_verify_multiple_multiple() diff --git a/tests/test_bls.py b/tests/test_bls.py index 94173736..7cec7380 100644 --- a/tests/test_bls.py +++ b/tests/test_bls.py @@ -10,6 +10,7 @@ sign, verify, verify_multiple, + verify_multiple_multiple, ) from py_ecc.bls.hash import ( hash_eth2, @@ -44,6 +45,7 @@ normalize, field_modulus as q, ) +from random import sample @pytest.mark.parametrize( @@ -238,3 +240,32 @@ def test_multi_aggregation(msg_1, msg_2, privkeys_1, privkeys_2): signature=aggsig, domain=domain, ) + + +def test_multi_multi(): + domain = 0 + validator_indices = tuple(range(10)) + privkeys = tuple(2**i for i in validator_indices) + pubkeys = [privtopub(k) for k in privkeys] + + class Attestation: + def __init__(self): + msg_1_validators = (1, 2, 3, 4) + msg_2_validators = (4, 5, 6, 7) + self.agg_pubkeys = [ + aggregate_pubkeys([pubkeys[i] for i in msg_1_validators]), + aggregate_pubkeys([pubkeys[i] for i in msg_2_validators]), + ] + self.msgs = (b'\x12' * 32, b'\x34' * 32) + msg_1_sigs = [sign(self.msgs[0], privkeys[i], domain) for i in msg_1_validators] + msg_2_sigs = [sign(self.msgs[1], privkeys[i], domain) for i in msg_2_validators] + self.sig = aggregate_signatures([ + aggregate_signatures(msg_1_sigs), + aggregate_signatures(msg_2_sigs), + ]) + atts = (Attestation(),) * 3 + assert verify_multiple_multiple( + signatures=[att.sig for att in atts], + pubkeys_and_messages=[[att.agg_pubkeys, att.msgs] for att in atts], + domain=domain, + )