diff --git a/py_ecc/__init__.py b/py_ecc/__init__.py index 982d917f..df9e3897 100644 --- a/py_ecc/__init__.py +++ b/py_ecc/__init__.py @@ -1,12 +1,16 @@ import sys -from .bls12_381_curve import ( +from .bls12_381_curve import ( # noqa: F401 + bls12_381, BLS12_381_Curve, + optimized_bls12_381, Optimized_BLS12_381_Curve, ) -from .bn128_curve import ( +from .bn128_curve import ( # noqa: F401 + bn128, BN128_Curve, + optimized_bn128, Optimized_BN128_Curve, ) @@ -26,11 +30,10 @@ from .secp256k1 import secp256k1 # noqa: F401 +from .validate_constants import validate_constants -sys.setrecursionlimit(max(100000, sys.getrecursionlimit())) -bn128 = BN128_Curve() -optimized_bn128 = Optimized_BN128_Curve() +sys.setrecursionlimit(max(100000, sys.getrecursionlimit())) -bls12_381 = BLS12_381_Curve() -optimized_bls12_381 = Optimized_BLS12_381_Curve() +# Check all the constants are valid, before using them +validate_constants() diff --git a/py_ecc/bls12_381_curve.py b/py_ecc/bls12_381_curve.py index 8bdb370a..e2fead54 100644 --- a/py_ecc/bls12_381_curve.py +++ b/py_ecc/bls12_381_curve.py @@ -162,3 +162,7 @@ def miller_loop(cls, return f ** ((cls.field_modulus ** 12 - 1) // cls.curve_order) else: return f + + +bls12_381 = BLS12_381_Curve() +optimized_bls12_381 = Optimized_BLS12_381_Curve() diff --git a/py_ecc/bn128_curve.py b/py_ecc/bn128_curve.py index 5bbcd6f7..75ae0f2f 100644 --- a/py_ecc/bn128_curve.py +++ b/py_ecc/bn128_curve.py @@ -174,3 +174,7 @@ def miller_loop(cls, return f ** ((cls.field_modulus ** 12 - 1) // cls.curve_order) else: return f + + +bn128 = BN128_Curve() +optimized_bn128 = Optimized_BN128_Curve() diff --git a/py_ecc/field_elements.py b/py_ecc/field_elements.py index 7a3a1c03..b450fbbf 100644 --- a/py_ecc/field_elements.py +++ b/py_ecc/field_elements.py @@ -34,8 +34,6 @@ def __init__(self, val: IntOrFQ, curve_name: str) -> None: """ self.curve_name = curve_name self.field_modulus = field_properties[curve_name]["field_modulus"] - # See, it's prime! - # assert pow(2, self.field_modulus, self.field_modulus) == 2 if isinstance(val, FQ): self.n = val.n diff --git a/py_ecc/optimized_field_elements.py b/py_ecc/optimized_field_elements.py index 1d384652..d45bd467 100644 --- a/py_ecc/optimized_field_elements.py +++ b/py_ecc/optimized_field_elements.py @@ -34,8 +34,6 @@ def __init__(self, val: IntOrFQ, curve_name: str) -> None: """ self.curve_name = curve_name self.field_modulus = field_properties[curve_name]["field_modulus"] - # See, it's prime! - # assert pow(2, self.field_modulus, self.field_modulus) == 2 if isinstance(val, FQ): self.n = val.n diff --git a/py_ecc/validate_constants.py b/py_ecc/validate_constants.py new file mode 100644 index 00000000..02656f10 --- /dev/null +++ b/py_ecc/validate_constants.py @@ -0,0 +1,135 @@ +from functools import lru_cache + +from py_ecc.bls12_381_curve import ( + bls12_381, + optimized_bls12_381, +) + +from py_ecc.bn128_curve import ( + bn128, + optimized_bn128, +) + +from py_ecc.curve_properties import ( + curve_properties, + optimized_curve_properties, +) + +from py_ecc.field_properties import ( + field_properties, +) + + +def validate_field_properties() -> None: + for curve_name in field_properties: + # Check if field_modulus is prime + field_modulus = field_properties[curve_name]["field_modulus"] + if pow(2, field_modulus, field_modulus) != 2: + raise ValueError( + "Field Modulus of the curve {} is not a prime".format(curve_name) + ) + + +def validate_curve_properties() -> None: + for curve_name in curve_properties: + # Check if curve_order is prime + curve_order = curve_properties[curve_name]["curve_order"] + if pow(2, curve_order, curve_order) != 2: + raise ValueError( + "Curve Order of the curve {} is not a prime".format(curve_name) + ) + + # Check consistency b/w field_modulus and curve_order + field_modulus = field_properties[curve_name]["field_modulus"] + if (field_modulus ** 12 - 1) % curve_order != 0: + raise ValueError( + "Inconsistent values among field_modulus and curve_order in the curve {}" + .format(curve_name) + ) + + # Check validity of pseudo_binary_encoding + pseudo_binary_encoding = curve_properties[curve_name]["pseudo_binary_encoding"] + ate_loop_count = curve_properties[curve_name]["ate_loop_count"] + if sum([e * 2**i for i, e in enumerate(pseudo_binary_encoding)]) != ate_loop_count: + raise ValueError( + "Inconsistent values among pseudo_binary_encoding and ate_loop_count" + ) + + +def validate_optimized_curve_properties() -> None: + for curve_name in optimized_curve_properties: + # Check if curve_order is prime + curve_order = optimized_curve_properties[curve_name]["curve_order"] + if pow(2, curve_order, curve_order) != 2: + raise ValueError( + "Curve Order of the optimized curve {} is not a prime".format(curve_name) + ) + + # Check consistency b/w field_modulus and curve_order + field_modulus = field_properties[curve_name]["field_modulus"] + if (field_modulus ** 12 - 1) % curve_order != 0: + raise ValueError( + "Inconsistent values among field_modulus and curve_order in the optimized curve {}" + .format(curve_name) + ) + + # Check validity of pseudo_binary_encoding + pseudo_binary_encoding = optimized_curve_properties[curve_name]["pseudo_binary_encoding"] + ate_loop_count = optimized_curve_properties[curve_name]["ate_loop_count"] + if sum([e * 2**i for i, e in enumerate(pseudo_binary_encoding)]) != ate_loop_count: + raise ValueError( + "Inconsistent values among pseudo_binary_encoding and ate_loop_count" + "in the optimized curve {}" + .format(curve_name) + ) + + +def validate_generators() -> None: + # Validate generators of normal curves + for curve_obj in (bn128, bls12_381): + if not curve_obj.is_on_curve(curve_obj.G1, curve_obj.b): + raise ValueError( + "G1 doesn't lie on the curve {} defined by b".format(curve_obj.curve_name) + ) + + if not curve_obj.is_on_curve(curve_obj.G2, curve_obj.b2): + raise ValueError( + "G2 doesn't lie on the curve {} defined by b2".format(curve_obj.curve_name) + ) + + if not curve_obj.is_on_curve(curve_obj.G12, curve_obj.b12): + raise ValueError( + "G12 doesn't lie on the curve {} defined by b12".format(curve_obj.curve_name) + ) + + # Validate generators of optimized curves + for optimized_curve_obj in (optimized_bn128, optimized_bls12_381): + if not optimized_curve_obj.is_on_curve(optimized_curve_obj.G1, optimized_curve_obj.b): + raise ValueError( + "G1 doesn't lie on the optimized curve {} defined by b" + .format(optimized_curve_obj.curve_name) + ) + + if not optimized_curve_obj.is_on_curve(optimized_curve_obj.G2, optimized_curve_obj.b2): + raise ValueError( + "G2 doesn't lie on the optimized curve {} defined by b2" + .format(optimized_curve_obj.curve_name) + ) + + if not optimized_curve_obj.is_on_curve(optimized_curve_obj.G12, optimized_curve_obj.b12): + raise ValueError( + "G12 doesn't lie on the optimized curve {} defined by b12" + .format(optimized_curve_obj.curve_name) + ) + + +@lru_cache(maxsize=1) +def validate_constants() -> None: + """ + This function validates the constants that are being used throughout + the whole codebase. It specifically verifies the curve and field properties + """ + validate_field_properties() + validate_curve_properties() + validate_optimized_curve_properties() + validate_generators()