Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fast-verification-of-multiple-bls-signatures #67

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions py_ecc/bls/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@
sign,
verify,
verify_multiple,
verify_multiple_multiple,
)
89 changes: 69 additions & 20 deletions py_ecc/bls/api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
from operator import (
itemgetter,
)
from secrets import (
randbelow,
)
from typing import (
Iterator,
Sequence,
Tuple,
)

from cytoolz.itertoolz import (
groupby,
)
from eth_typing import (
BLSPubkey,
BLSSignature,
Expand All @@ -24,6 +35,10 @@
neg,
pairing,
)

from .typing import (
G1Uncompressed,
)
from .utils import (
G1_to_pubkey,
G2_to_signature,
Expand Down Expand Up @@ -80,32 +95,66 @@ def aggregate_pubkeys(pubkeys: Sequence[BLSPubkey]) -> BLSPubkey:
return G1_to_pubkey(o)


def _group_key_by_msg(pubkeys: Sequence[BLSPubkey],
message_hashes: Sequence[Hash32])-> Iterator[Tuple[G1Uncompressed, Hash32]]:
if len(pubkeys) != len(message_hashes):
raise ValidationError(
"len(pubkeys) (%s) should be equal to len(message_hashes) (%s)" % (
len(pubkeys), len(message_hashes)
)
)
groups_dict = groupby(itemgetter(1), enumerate(message_hashes))
for message_hash, group in groups_dict.items():
agg_key = Z1
for i, _ in group:
agg_key = add(agg_key, pubkey_to_G1(pubkeys[i]))
yield agg_key, message_hash


def verify_multiple(pubkeys: Sequence[BLSPubkey],
message_hashes: Sequence[Hash32],
signature: BLSSignature,
domain: int) -> bool:
len_msgs = len(message_hashes)

if len(pubkeys) != len_msgs:
o = FQ12.one()
for pubkey, message_hash in _group_key_by_msg(pubkeys, message_hashes):
o *= pairing(
hash_to_G2(message_hash, domain),
pubkey,
final_exponentiate=False,
)
o *= pairing(signature_to_G2(signature), neg(G1), final_exponentiate=False)
final_exponentiation = final_exponentiate(o)
return final_exponentiation == FQ12.one()


def verify_multiple_multiple(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renaming ideas: optimized_verify_multiple, fast_verify_multiple...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it's faster why not just replace the existing implementation instead of maintaining two?

Also, if there's an existing one and we are keeping both around, seems prudent to test them against each other to ensure they have equal behavior (sorry if I missed the test that does this)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new function is an optimized version of doing multiple times of verify_multiple. The later verifies a single attestation and the former verifies many attestations in a block.

Copy link
Contributor Author

@ChihChengLiang ChihChengLiang May 7, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do desperately want to get rid of verify_multiple_multiple and rename it properly. Besides the name of the function, the function parameter pubkeys_and_messages: Sequence[Tuple[Sequence[BLSPubkey], Sequence[Hash32]]] is also confusing. Should we create some objects like

class SignedMessage:
    signature: BLSSignature
    public_keys: Sequence[BLSPubkey]
    message_hashes: Sequence[Hash32]
    def __init__(self, signature, public_keys, message_hashes):
        self.signature = signature
        if len(public_keys) != len(message_hashes)
            raise ValidationError()
        self.public_keys = public_keys
        self.message_hashes = message_hashes

Then write the function parameters in this fashion?

def verify_multiple(signed_message: SignedMessage): 
def verify_multiple_multiple(signed_messages: Sequence[SignedMessage]): 

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ChihChengLiang

I do desperately want to get rid of verify_multiple_multiple and rename it properly.

how about batch_verify ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good. We could also do "verify_multiple_aggregate_signatures" to match milagro's function name.

signatures: Sequence[BLSSignature],
pubkeys_and_messages: Sequence[Tuple[Sequence[BLSPubkey], Sequence[Hash32]]],
domain: int)-> bool:
"""
This is the optimized version of len(signatures) rounds of verify_multiple
"""
if len(signatures) != len(pubkeys_and_messages):
raise ValidationError(
"len(pubkeys) (%s) should be equal to len(message_hashes) (%s)" % (
len(pubkeys), len_msgs
"len(signatures) (%s) should be equal to len(pubkeys_and_messages) (%s)" % (
len(signatures), len(pubkeys_and_messages)
)
)

try:
o = FQ12([1] + [0] * 11)
for m_pubs in set(message_hashes):
# aggregate the pubs
group_pub = Z1
for i in range(len_msgs):
if message_hashes[i] == m_pubs:
group_pub = add(group_pub, pubkey_to_G1(pubkeys[i]))

o *= pairing(hash_to_G2(m_pubs, domain), group_pub, final_exponentiate=False)
o *= pairing(signature_to_G2(signature), neg(G1), final_exponentiate=False)

final_exponentiation = final_exponentiate(o)
return final_exponentiation == FQ12.one()
except (ValidationError, ValueError, AssertionError):
return False
random_ints = (1,) + tuple(2**randbelow(64) for _ in signatures[:-1])
o = FQ12.one()
for r_i, (pubkeys, message_hashes) in zip(random_ints, pubkeys_and_messages):
for pubkey, message_hash in _group_key_by_msg(pubkeys, message_hashes):
o *= pairing(
multiply(hash_to_G2(message_hash, domain), r_i),
pubkey,
final_exponentiate=False,
)
agg_sig = Z2
for r_i, sig in zip(random_ints, signatures):
agg_sig = add(agg_sig, multiply(signature_to_G2(sig), r_i))
o *= pairing(agg_sig, neg(G1), final_exponentiate=False)

final_exponentiation = final_exponentiate(o)
return final_exponentiation == FQ12.one()
68 changes: 68 additions & 0 deletions scripts/benchmark_multi_multi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from random import sample
from py_ecc.bls import (
aggregate_pubkeys,
aggregate_signatures,
sign,
privtopub,
verify_multiple,
verify_multiple_multiple,
)
import time

domain = 0
validator_indices = tuple(range(1000))
privkeys = tuple(2**i for i in validator_indices)
pubkeys = [privtopub(k) for k in privkeys]

MAX_ATTESTATIONS = 128
TARGET_COMMITTEE_SIZE = 128


class Attestation:
def __init__(self, msg_1, msg_2):
msg_1_validators = sample(validator_indices, TARGET_COMMITTEE_SIZE//2)
msg_2_validators = sample(validator_indices, TARGET_COMMITTEE_SIZE//2)
self.agg_pubkeys = [
aggregate_pubkeys([pubkeys[i] for i in msg_1_validators]),
aggregate_pubkeys([pubkeys[i] for i in msg_2_validators]),
]
self.msgs = [msg_1, msg_2]
msg_1_sigs = [sign(msg_1, privkeys[i], domain) for i in msg_1_validators]
msg_2_sigs = [sign(msg_2, privkeys[i], domain) for i in msg_2_validators]
self.sig = aggregate_signatures([
aggregate_signatures(msg_1_sigs),
aggregate_signatures(msg_2_sigs),
])


att = Attestation(b'\x12' * 32, b'\x34' * 32)
atts = (att,) * MAX_ATTESTATIONS


def profile_verify_multiple():
t = time.time()
for att in atts:
assert verify_multiple(
pubkeys=att.agg_pubkeys,
message_hashes=att.msgs,
signature=att.sig,
domain=domain,
)
print(time.time() - t)


def profile_verify_multiple_multiple():
t = time.time()
assert verify_multiple_multiple(
signatures=[att.sig for att in atts],
pubkeys_and_messages=[[att.agg_pubkeys, att.msgs] for att in atts],
domain=domain,
)
print(time.time() - t)


if __name__ == '__main__':
print("profile_verify_multiple")
profile_verify_multiple()
print("profile_verify_multiple_multiple")
profile_verify_multiple_multiple()
31 changes: 31 additions & 0 deletions tests/test_bls.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
sign,
verify,
verify_multiple,
verify_multiple_multiple,
)
from py_ecc.bls.hash import (
hash_eth2,
Expand Down Expand Up @@ -44,6 +45,7 @@
normalize,
field_modulus as q,
)
from random import sample


@pytest.mark.parametrize(
Expand Down Expand Up @@ -238,3 +240,32 @@ def test_multi_aggregation(msg_1, msg_2, privkeys_1, privkeys_2):
signature=aggsig,
domain=domain,
)


def test_multi_multi():
domain = 0
validator_indices = tuple(range(10))
privkeys = tuple(2**i for i in validator_indices)
pubkeys = [privtopub(k) for k in privkeys]

class Attestation:
def __init__(self):
msg_1_validators = (1, 2, 3, 4)
msg_2_validators = (4, 5, 6, 7)
self.agg_pubkeys = [
aggregate_pubkeys([pubkeys[i] for i in msg_1_validators]),
aggregate_pubkeys([pubkeys[i] for i in msg_2_validators]),
]
self.msgs = (b'\x12' * 32, b'\x34' * 32)
msg_1_sigs = [sign(self.msgs[0], privkeys[i], domain) for i in msg_1_validators]
msg_2_sigs = [sign(self.msgs[1], privkeys[i], domain) for i in msg_2_validators]
self.sig = aggregate_signatures([
aggregate_signatures(msg_1_sigs),
aggregate_signatures(msg_2_sigs),
])
atts = (Attestation(),) * 3
assert verify_multiple_multiple(
signatures=[att.sig for att in atts],
pubkeys_and_messages=[[att.agg_pubkeys, att.msgs] for att in atts],
domain=domain,
)