Skip to content

Commit

Permalink
Add validation of constants at the time of import
Browse files Browse the repository at this point in the history
  • Loading branch information
Bhargavasomu committed Dec 16, 2018
1 parent e502901 commit 00e65a8
Show file tree
Hide file tree
Showing 6 changed files with 153 additions and 11 deletions.
17 changes: 10 additions & 7 deletions py_ecc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import sys

from .bls12_381_curve import (
from .bls12_381_curve import ( # noqa: F401
bls12_381,
BLS12_381_Curve,
optimized_bls12_381,
Optimized_BLS12_381_Curve,
)

from .bn128_curve import (
from .bn128_curve import ( # noqa: F401
bn128,
BN128_Curve,
optimized_bn128,
Optimized_BN128_Curve,
)

Expand All @@ -26,11 +30,10 @@

from .secp256k1 import secp256k1 # noqa: F401

from .validate_constants import validate_constants

sys.setrecursionlimit(max(100000, sys.getrecursionlimit()))

bn128 = BN128_Curve()
optimized_bn128 = Optimized_BN128_Curve()
sys.setrecursionlimit(max(100000, sys.getrecursionlimit()))

bls12_381 = BLS12_381_Curve()
optimized_bls12_381 = Optimized_BLS12_381_Curve()
# Check all the constants are valid, before using them
validate_constants()
4 changes: 4 additions & 0 deletions py_ecc/bls12_381_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,7 @@ def miller_loop(cls,
return f ** ((cls.field_modulus ** 12 - 1) // cls.curve_order)
else:
return f


bls12_381 = BLS12_381_Curve()
optimized_bls12_381 = Optimized_BLS12_381_Curve()
4 changes: 4 additions & 0 deletions py_ecc/bn128_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,3 +174,7 @@ def miller_loop(cls,
return f ** ((cls.field_modulus ** 12 - 1) // cls.curve_order)
else:
return f


bn128 = BN128_Curve()
optimized_bn128 = Optimized_BN128_Curve()
2 changes: 0 additions & 2 deletions py_ecc/field_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ def __init__(self, val: IntOrFQ, curve_name: str) -> None:
"""
self.curve_name = curve_name
self.field_modulus = field_properties[curve_name]["field_modulus"]
# See, it's prime!
# assert pow(2, self.field_modulus, self.field_modulus) == 2

if isinstance(val, FQ):
self.n = val.n
Expand Down
2 changes: 0 additions & 2 deletions py_ecc/optimized_field_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ def __init__(self, val: IntOrFQ, curve_name: str) -> None:
"""
self.curve_name = curve_name
self.field_modulus = field_properties[curve_name]["field_modulus"]
# See, it's prime!
# assert pow(2, self.field_modulus, self.field_modulus) == 2

if isinstance(val, FQ):
self.n = val.n
Expand Down
135 changes: 135 additions & 0 deletions py_ecc/validate_constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
from functools import lru_cache

from py_ecc.bls12_381_curve import (
bls12_381,
optimized_bls12_381,
)

from py_ecc.bn128_curve import (
bn128,
optimized_bn128,
)

from py_ecc.curve_properties import (
curve_properties,
optimized_curve_properties,
)

from py_ecc.field_properties import (
field_properties,
)


def validate_field_properties() -> None:
for curve_name in field_properties:
# Check if field_modulus is prime
field_modulus = field_properties[curve_name]["field_modulus"]
if pow(2, field_modulus, field_modulus) != 2:
raise ValueError(
"Field Modulus of the curve {} is not a prime".format(curve_name)
)


def validate_curve_properties() -> None:
for curve_name in curve_properties:
# Check if curve_order is prime
curve_order = curve_properties[curve_name]["curve_order"]
if pow(2, curve_order, curve_order) != 2:
raise ValueError(
"Curve Order of the curve {} is not a prime".format(curve_name)
)

# Check consistency b/w field_modulus and curve_order
field_modulus = field_properties[curve_name]["field_modulus"]
if (field_modulus ** 12 - 1) % curve_order != 0:
raise ValueError(
"Inconsistent values among field_modulus and curve_order in the curve {}"
.format(curve_name)
)

# Check validity of pseudo_binary_encoding
pseudo_binary_encoding = curve_properties[curve_name]["pseudo_binary_encoding"]
ate_loop_count = curve_properties[curve_name]["ate_loop_count"]
if sum([e * 2**i for i, e in enumerate(pseudo_binary_encoding)]) != ate_loop_count:
raise ValueError(
"Inconsistent values among pseudo_binary_encoding and ate_loop_count"
)


def validate_optimized_curve_properties() -> None:
for curve_name in optimized_curve_properties:
# Check if curve_order is prime
curve_order = optimized_curve_properties[curve_name]["curve_order"]
if pow(2, curve_order, curve_order) != 2:
raise ValueError(
"Curve Order of the optimized curve {} is not a prime".format(curve_name)
)

# Check consistency b/w field_modulus and curve_order
field_modulus = field_properties[curve_name]["field_modulus"]
if (field_modulus ** 12 - 1) % curve_order != 0:
raise ValueError(
"Inconsistent values among field_modulus and curve_order in the optimized curve {}"
.format(curve_name)
)

# Check validity of pseudo_binary_encoding
pseudo_binary_encoding = optimized_curve_properties[curve_name]["pseudo_binary_encoding"]
ate_loop_count = optimized_curve_properties[curve_name]["ate_loop_count"]
if sum([e * 2**i for i, e in enumerate(pseudo_binary_encoding)]) != ate_loop_count:
raise ValueError(
"Inconsistent values among pseudo_binary_encoding and ate_loop_count"
"in the optimized curve {}"
.format(curve_name)
)


def validate_generators() -> None:
# Validate generators of normal curves
for curve_obj in (bn128, bls12_381):
if not curve_obj.is_on_curve(curve_obj.G1, curve_obj.b):
raise ValueError(
"G1 doesn't lie on the curve {} defined by b".format(curve_obj.curve_name)
)

if not curve_obj.is_on_curve(curve_obj.G2, curve_obj.b2):
raise ValueError(
"G2 doesn't lie on the curve {} defined by b2".format(curve_obj.curve_name)
)

if not curve_obj.is_on_curve(curve_obj.G12, curve_obj.b12):
raise ValueError(
"G12 doesn't lie on the curve {} defined by b12".format(curve_obj.curve_name)
)

# Validate generators of optimized curves
for optimized_curve_obj in (optimized_bn128, optimized_bls12_381):
if not optimized_curve_obj.is_on_curve(optimized_curve_obj.G1, optimized_curve_obj.b):
raise ValueError(
"G1 doesn't lie on the optimized curve {} defined by b"
.format(optimized_curve_obj.curve_name)
)

if not optimized_curve_obj.is_on_curve(optimized_curve_obj.G2, optimized_curve_obj.b2):
raise ValueError(
"G2 doesn't lie on the optimized curve {} defined by b2"
.format(optimized_curve_obj.curve_name)
)

if not optimized_curve_obj.is_on_curve(optimized_curve_obj.G12, optimized_curve_obj.b12):
raise ValueError(
"G12 doesn't lie on the optimized curve {} defined by b12"
.format(optimized_curve_obj.curve_name)
)


@lru_cache(maxsize=1)
def validate_constants() -> None:
"""
This function validates the constants that are being used throughout
the whole codebase. It specifically verifies the curve and field properties
"""
validate_field_properties()
validate_curve_properties()
validate_optimized_curve_properties()
validate_generators()

0 comments on commit 00e65a8

Please sign in to comment.