diff --git a/py_ecc/bls/api.py b/py_ecc/bls/api.py index 461a4657..d8d5568a 100644 --- a/py_ecc/bls/api.py +++ b/py_ecc/bls/api.py @@ -1,6 +1,4 @@ -from operator import ( - itemgetter, -) + from secrets import ( randbelow, ) @@ -9,10 +7,6 @@ Sequence, Tuple, ) - -from cytoolz.itertoolz import ( - groupby, -) from eth_typing import ( BLSPubkey, BLSSignature, @@ -92,19 +86,15 @@ def aggregate_pubkeys(pubkeys: Sequence[BLSPubkey]) -> BLSPubkey: return G1_to_pubkey(o) -def _group_by_messages( - message_hashes: Sequence[Hash32], - pubkeys: Sequence[BLSPubkey]) -> Iterator[Tuple[Hash32, Tuple[BLSPubkey, ...]]]: +def _zip(pubkeys: Sequence[BLSPubkey], + message_hashes: Sequence[Hash32])-> Iterator[Tuple[BLSPubkey, Hash32]]: if len(pubkeys) != len(message_hashes): raise ValidationError( "len(pubkeys) (%s) should be equal to len(message_hashes) (%s)" % ( len(pubkeys), len(message_hashes) ) ) - groups = groupby(itemgetter(0), zip(message_hashes, pubkeys)).items() - for message_hash, group in groups: - group_pubkeys = tuple(BLSPubkey(pubkey) for _, pubkey in group) - yield Hash32(message_hash), group_pubkeys + return zip(pubkeys, message_hashes) def verify_multiple(pubkeys: Sequence[BLSPubkey], @@ -113,11 +103,12 @@ def verify_multiple(pubkeys: Sequence[BLSPubkey], domain: int) -> bool: o = FQ12.one() - for message_hash, group_pubkeys in _group_by_messages(message_hashes, pubkeys): - agg_pub = Z1 - for key in group_pubkeys: - agg_pub = add(agg_pub, pubkey_to_G1(key)) - o *= pairing(hash_to_G2(message_hash, domain), agg_pub, final_exponentiate=False) + for pubkey, message_hash in _zip(pubkeys, message_hashes): + o *= pairing( + hash_to_G2(message_hash, domain), + pubkey_to_G1(pubkey), + final_exponentiate=False, + ) o *= pairing(signature_to_G2(signature), neg(G1), final_exponentiate=False) final_exponentiation = final_exponentiate(o) return final_exponentiation == FQ12.one() @@ -139,15 +130,11 @@ def verify_multiple_multiple( random_ints = (1,) + tuple(2**randbelow(64) for _ in signatures[:-1]) o = FQ12.one() - for r_i, pm in zip(random_ints, pubkeys_and_messages): - pubkeys, message_hashes = pm - for message_hash, group_pubkeys in _group_by_messages(message_hashes, pubkeys): - agg_pub = Z1 - for key in group_pubkeys: - agg_pub = add(agg_pub, pubkey_to_G1(key)) + for r_i, (pubkeys, message_hashes) in zip(random_ints, pubkeys_and_messages): + for pubkey, message_hash in _zip(pubkeys, message_hashes): o *= pairing( multiply(hash_to_G2(message_hash, domain), r_i), - agg_pub, + pubkey_to_G1(pubkey), final_exponentiate=False, ) agg_sig = Z2