From 6f07a5b167e8d8a1c98aa9eaf2e4f59f5ed9c8f1 Mon Sep 17 00:00:00 2001 From: Bhargavasomu Date: Wed, 28 Nov 2018 22:56:59 +0530 Subject: [PATCH] Enable Type Hinting for optimized_bn128 module --- py_ecc/optimized_bn128/optimized_curve.py | 44 ++-- .../optimized_field_elements.py | 196 ++++++++++-------- py_ecc/optimized_bn128/optimized_pairing.py | 35 +++- py_ecc/typing.py | 38 +++- tox.ini | 1 + 5 files changed, 195 insertions(+), 119 deletions(-) diff --git a/py_ecc/optimized_bn128/optimized_curve.py b/py_ecc/optimized_bn128/optimized_curve.py index a9281bd7..0a3caa6f 100644 --- a/py_ecc/optimized_bn128/optimized_curve.py +++ b/py_ecc/optimized_bn128/optimized_curve.py @@ -1,10 +1,21 @@ from __future__ import absolute_import +from typing import ( + cast, +) + from .optimized_field_elements import ( - FQ2, - FQ12, field_modulus, FQ, + FQ2, + FQ12, + FQP, +) + +from py_ecc.typing import ( + Optimized_Field, + Optimized_Point2D, + Optimized_Point3D, ) @@ -43,12 +54,12 @@ # Check if a point is the point at infinity -def is_inf(pt): - return pt[-1] == pt[-1].__class__.zero() +def is_inf(pt: Optimized_Point3D[Optimized_Field]) -> bool: + return pt[-1] == (type(pt[-1]).zero()) # Check that a point is on the curve defined by y**2 == x**3 + b -def is_on_curve(pt, b): +def is_on_curve(pt: Optimized_Point3D[Optimized_Field], b: Optimized_Field) -> bool: if is_inf(pt): return True x, y, z = pt @@ -60,7 +71,7 @@ def is_on_curve(pt, b): # Elliptic curve doubling -def double(pt): +def double(pt: Optimized_Point3D[Optimized_Field]) -> Optimized_Point3D[Optimized_Field]: x, y, z = pt W = 3 * x * x S = y * z @@ -70,12 +81,13 @@ def double(pt): newx = 2 * H * S newy = W * (4 * B - H) - 8 * y * y * S_squared newz = 8 * S * S_squared - return newx, newy, newz + return (newx, newy, newz) # Elliptic curve addition -def add(p1, p2): - one, zero = p1[0].__class__.one(), p1[0].__class__.zero() +def add(p1: Optimized_Point3D[Optimized_Field], + p2: Optimized_Point3D[Optimized_Field]) -> Optimized_Point3D[Optimized_Field]: + one, zero = type(p1[0]).one(), type(p1[0]).zero() if p1[2] == zero or p2[2] == zero: return p1 if p2[2] == zero else p2 x1, y1, z1 = p1 @@ -102,9 +114,9 @@ def add(p1, p2): # Elliptic curve point multiplication -def multiply(pt, n): +def multiply(pt: Optimized_Point3D[Optimized_Field], n: int) -> Optimized_Point3D[Optimized_Field]: if n == 0: - return (pt[0].__class__.one(), pt[0].__class__.one(), pt[0].__class__.zero()) + return (type(pt[0]).one(), type(pt[0]).one(), type(pt[0]).zero()) elif n == 1: return pt elif not n % 2: @@ -113,13 +125,13 @@ def multiply(pt, n): return add(multiply(double(pt), int(n // 2)), pt) -def eq(p1, p2): +def eq(p1: Optimized_Point3D[Optimized_Field], p2: Optimized_Point3D[Optimized_Field]) -> bool: x1, y1, z1 = p1 x2, y2, z2 = p2 return x1 * z2 == x2 * z1 and y1 * z2 == y2 * z1 -def normalize(pt): +def normalize(pt: Optimized_Point3D[Optimized_Field]) -> Optimized_Point2D[Optimized_Field]: x, y, z = pt return (x / z, y / z) @@ -129,14 +141,14 @@ def normalize(pt): # Convert P => -P -def neg(pt): +def neg(pt: Optimized_Point3D[Optimized_Field]) -> Optimized_Point3D[Optimized_Field]: if pt is None: return None x, y, z = pt return (x, -y, z) -def twist(pt): +def twist(pt: Optimized_Point3D[FQP]) -> Optimized_Point3D[FQP]: if pt is None: return None _x, _y, _z = pt @@ -151,5 +163,5 @@ def twist(pt): # Check that the twist creates a point that is on the curve -G12 = twist(G2) +G12 = twist(cast(Optimized_Point3D[FQ2], G2)) assert is_on_curve(G12, b12) diff --git a/py_ecc/optimized_bn128/optimized_field_elements.py b/py_ecc/optimized_bn128/optimized_field_elements.py index 3fa2f225..5b03b794 100644 --- a/py_ecc/optimized_bn128/optimized_field_elements.py +++ b/py_ecc/optimized_bn128/optimized_field_elements.py @@ -1,22 +1,23 @@ from __future__ import absolute_import -import sys - +from typing import ( # noqa: F401 + cast, + List, + Sequence, + Tuple, + Union, +) field_modulus = 21888242871839275222246405745257275088696311157297823662689037894645226208583 -FQ12_modulus_coeffs = [82, 0, 0, 0, 0, 0, -18, 0, 0, 0, 0, 0] # Implied + [1] -FQ12_mc_tuples = [(i, c) for i, c in enumerate(FQ12_modulus_coeffs) if c] - -# python3 compatibility -if sys.version_info.major == 2: - int_types = (int, long) # noqa: F821 -else: - int_types = (int,) +FQ2_MODULUS_COEFFS = [1, 0] +FQ12_MODULUS_COEFFS = [82, 0, 0, 0, 0, 0, -18, 0, 0, 0, 0, 0] # Implied + [1] +FQ2_MC_TUPLES = [(0, 1)] +FQ12_MC_TUPLES = [(i, c) for i, c in enumerate(FQ12_MODULUS_COEFFS) if c] # Extended euclidean algorithm to find modular inverses for # integers -def prime_field_inv(a, n): +def prime_field_inv(a: int, n: int) -> int: if a == 0: return 0 lm, hm = 1, 0 @@ -28,55 +29,60 @@ def prime_field_inv(a, n): return lm % n +IntOrFQ = Union[int, "FQ"] + + # A class for field elements in FQ. Wrap a number in this class, # and it becomes a field element. class FQ(object): - def __init__(self, n): - if isinstance(n, self.__class__): - self.n = n.n + n = None # type: int + + def __init__(self, val: IntOrFQ) -> None: + if isinstance(val, FQ): + self.n = val.n else: - self.n = n % field_modulus - assert isinstance(self.n, int_types) + self.n = val % field_modulus + assert isinstance(self.n, int) - def __add__(self, other): + def __add__(self, other: IntOrFQ) -> "FQ": on = other.n if isinstance(other, FQ) else other return FQ((self.n + on) % field_modulus) - def __mul__(self, other): + def __mul__(self, other: IntOrFQ) -> "FQ": on = other.n if isinstance(other, FQ) else other return FQ((self.n * on) % field_modulus) - def __rmul__(self, other): + def __rmul__(self, other: IntOrFQ) -> "FQ": return self * other - def __radd__(self, other): + def __radd__(self, other: IntOrFQ) -> "FQ": return self + other - def __rsub__(self, other): + def __rsub__(self, other: IntOrFQ) -> "FQ": on = other.n if isinstance(other, FQ) else other return FQ((on - self.n) % field_modulus) - def __sub__(self, other): + def __sub__(self, other: IntOrFQ) -> "FQ": on = other.n if isinstance(other, FQ) else other return FQ((self.n - on) % field_modulus) - def __div__(self, other): + def __div__(self, other: IntOrFQ) -> "FQ": on = other.n if isinstance(other, FQ) else other - assert isinstance(on, int_types) + assert isinstance(on, int) return FQ(self.n * prime_field_inv(on, field_modulus) % field_modulus) - def __truediv__(self, other): + def __truediv__(self, other: IntOrFQ) -> "FQ": return self.__div__(other) - def __rdiv__(self, other): + def __rdiv__(self, other: IntOrFQ) -> "FQ": on = other.n if isinstance(other, FQ) else other - assert isinstance(on, int_types), on + assert isinstance(on, int), on return FQ(prime_field_inv(self.n, field_modulus) * on % field_modulus) - def __rtruediv__(self, other): + def __rtruediv__(self, other: IntOrFQ) -> "FQ": return self.__rdiv__(other) - def __pow__(self, other): + def __pow__(self, other: int) -> "FQ": if other == 0: return FQ(1) elif other == 1: @@ -86,45 +92,49 @@ def __pow__(self, other): else: return ((self * self) ** int(other // 2)) * self - def __eq__(self, other): + def __eq__(self, other: IntOrFQ) -> bool: # type:ignore # https://github.com/python/mypy/issues/2783 # noqa: E501 if isinstance(other, FQ): return self.n == other.n else: return self.n == other - def __ne__(self, other): + def __ne__(self, other: IntOrFQ) -> bool: # type:ignore # https://github.com/python/mypy/issues/2783 # noqa: E501 return not self == other - def __neg__(self): + def __neg__(self) -> "FQ": return FQ(-self.n) - def __repr__(self): + def __repr__(self) -> str: return repr(self.n) + def __int__(self) -> int: + return self.n + @classmethod - def one(cls): + def one(cls) -> "FQ": return cls(1) @classmethod - def zero(cls): + def zero(cls) -> "FQ": return cls(0) # Utility methods for polynomial math -def deg(p): +def deg(p: Sequence[IntOrFQ]) -> int: d = len(p) - 1 while p[d] == 0 and d: d -= 1 return d -def poly_rounded_div(a, b): +def poly_rounded_div(a: Sequence[IntOrFQ], + b: Sequence[IntOrFQ]) -> Sequence[IntOrFQ]: dega = deg(a) degb = deg(b) temp = [x for x in a] o = [0 for x in a] for i in range(dega - degb, -1, -1): - o[i] = (o[i] + temp[degb + i] * prime_field_inv(b[degb], field_modulus)) + o[i] = int(o[i] + temp[degb + i] * prime_field_inv(int(b[degb]), field_modulus)) for c in range(degb + 1): temp[c + i] = (temp[c + i] - o[c]) return [x % field_modulus for x in o[:deg(o) + 1]] @@ -132,7 +142,12 @@ def poly_rounded_div(a, b): # A class for elements in polynomial extension fields class FQP(object): - def __init__(self, coeffs, modulus_coeffs): + degree = 0 # type: int + mc_tuples = None # type: List[Tuple[int, int]] + + def __init__(self, + coeffs: Sequence[IntOrFQ], + modulus_coeffs: Sequence[IntOrFQ]=None) -> None: assert len(coeffs) == len(modulus_coeffs) self.coeffs = coeffs # The coefficients of the modulus, without the leading [1] @@ -140,58 +155,58 @@ def __init__(self, coeffs, modulus_coeffs): # The degree of the extension field self.degree = len(self.modulus_coeffs) - def __add__(self, other): - assert isinstance(other, self.__class__) - return self.__class__([ - (x + y) % field_modulus + def __add__(self, other: "FQP") -> "FQP": + assert isinstance(other, type(self)) + return type(self)([ + int(x + y) % field_modulus for x, y in zip(self.coeffs, other.coeffs) ]) - def __sub__(self, other): - assert isinstance(other, self.__class__) - return self.__class__([ - (x - y) % field_modulus + def __sub__(self, other: "FQP") -> "FQP": + assert isinstance(other, type(self)) + return type(self)([ + int(x - y) % field_modulus for x, y in zip(self.coeffs, other.coeffs) ]) - def __mul__(self, other): - if isinstance(other, int_types): - return self.__class__([c * other % field_modulus for c in self.coeffs]) + def __mul__(self, other: Union[int, "FQP"]) -> "FQP": + if isinstance(other, int): + return type(self)([int(c) * other % field_modulus for c in self.coeffs]) else: # assert isinstance(other, self.__class__) b = [0] * (self.degree * 2 - 1) inner_enumerate = list(enumerate(other.coeffs)) for i, eli in enumerate(self.coeffs): for j, elj in inner_enumerate: - b[i + j] += eli * elj + b[i + j] += int(eli * elj) # MID = len(self.coeffs) // 2 for exp in range(self.degree - 2, -1, -1): top = b.pop() for i, c in self.mc_tuples: b[exp + i] -= top * c - return self.__class__([x % field_modulus for x in b]) + return type(self)([x % field_modulus for x in b]) - def __rmul__(self, other): + def __rmul__(self, other: Union[int, "FQP"]) -> "FQP": return self * other - def __div__(self, other): - if isinstance(other, int_types): - return self.__class__([ - c * prime_field_inv(other, field_modulus) % field_modulus + def __div__(self, other: Union[int, "FQ", "FQP"]) -> "FQP": + if isinstance(other, int): + return type(self)([ + int(c) * prime_field_inv(other, field_modulus) % field_modulus for c in self.coeffs ]) else: - assert isinstance(other, self.__class__) + assert isinstance(other, type(self)) return self * other.inv() - def __truediv__(self, other): + def __truediv__(self, other: Union[int, "FQ", "FQP"]) -> "FQP": return self.__div__(other) - def __pow__(self, other): - o = self.__class__([1] + [0] * (self.degree - 1)) + def __pow__(self, other: int) -> "FQP": + o = type(self)([1] + [0] * (self.degree - 1)) t = self while other > 0: if other & 1: @@ -201,64 +216,71 @@ def __pow__(self, other): return o # Extended euclidean algorithm used to find the modular inverse - def inv(self): + def inv(self) -> "FQP": lm, hm = [1] + [0] * self.degree, [0] * (self.degree + 1) - low, high = self.coeffs + [0], self.modulus_coeffs + [1] + low, high = ( + # Ignore mypy yelling about the inner types for the lists being incompatible + cast(List[IntOrFQ], list(self.coeffs + [0])), # type: ignore + cast(List[IntOrFQ], list(self.modulus_coeffs + [1])), # type: ignore + ) + low, high = list(self.coeffs + [0]), self.modulus_coeffs + [1] # type: ignore while deg(low): - r = poly_rounded_div(high, low) + r = cast(List[IntOrFQ], poly_rounded_div(high, low)) r += [0] * (self.degree + 1 - len(r)) nm = [x for x in hm] new = [x for x in high] # assert len(lm) == len(hm) == len(low) == len(high) == len(nm) == len(new) == self.degree + 1 # noqa: E501 for i in range(self.degree + 1): for j in range(self.degree + 1 - i): - nm[i + j] -= lm[i] * r[j] + nm[i + j] -= lm[i] * int(r[j]) new[i + j] -= low[i] * r[j] nm = [x % field_modulus for x in nm] - new = [x % field_modulus for x in new] + new = [int(x) % field_modulus for x in new] lm, low, hm, high = nm, new, lm, low - return self.__class__(lm[:self.degree]) / low[0] + return type(self)(lm[:self.degree]) / low[0] - def __repr__(self): + def __repr__(self) -> str: return repr(self.coeffs) - def __eq__(self, other): - assert isinstance(other, self.__class__) + def __eq__(self, other: "FQP") -> bool: # type: ignore # https://github.com/python/mypy/issues/2783 # noqa: E501 + assert isinstance(other, type(self)) for c1, c2 in zip(self.coeffs, other.coeffs): if c1 != c2: return False return True - def __ne__(self, other): + def __ne__(self, other: "FQP") -> bool: # type: ignore # https://github.com/python/mypy/issues/2783 # noqa: E501 return not self == other - def __neg__(self): - return self.__class__([-c for c in self.coeffs]) + def __neg__(self) -> "FQP": + return type(self)([-c for c in self.coeffs]) @classmethod - def one(cls): + def one(cls) -> "FQP": return cls([1] + [0] * (cls.degree - 1)) @classmethod - def zero(cls): + def zero(cls) -> "FQP": return cls([0] * cls.degree) # The quadratic extension field class FQ2(FQP): - def __init__(self, coeffs): - self.coeffs = coeffs - self.modulus_coeffs = [1, 0] - self.mc_tuples = [(0, 1)] - self.degree = 2 - self.__class__.degree = 2 + degree = 2 + mc_tuples = FQ2_MC_TUPLES + + def __init__(self, coeffs: Sequence[IntOrFQ]) -> None: + super().__init__(coeffs, FQ2_MODULUS_COEFFS) + assert self.degree == 2 + assert self.mc_tuples == FQ2_MC_TUPLES # The 12th-degree extension field class FQ12(FQP): - def __init__(self, coeffs): - self.coeffs = coeffs - self.modulus_coeffs = FQ12_modulus_coeffs - self.mc_tuples = FQ12_mc_tuples - self.degree = 12 - self.__class__.degree = 12 + degree = 12 + mc_tuples = FQ12_MC_TUPLES + + def __init__(self, coeffs: Sequence[IntOrFQ]) -> None: + super().__init__(coeffs, FQ12_MODULUS_COEFFS) + assert self.degree == 12 + assert self.mc_tuples == FQ12_MC_TUPLES diff --git a/py_ecc/optimized_bn128/optimized_pairing.py b/py_ecc/optimized_bn128/optimized_pairing.py index bd846a53..8289d548 100644 --- a/py_ecc/optimized_bn128/optimized_pairing.py +++ b/py_ecc/optimized_bn128/optimized_pairing.py @@ -1,5 +1,13 @@ from __future__ import absolute_import +from py_ecc.typing import ( + Optimized_Field, + Optimized_FQPoint3D, + Optimized_FQ2Point3D, + Optimized_Point2D, + Optimized_Point3D, +) + from .optimized_curve import ( double, add, @@ -14,9 +22,10 @@ normalize, ) from .optimized_field_elements import ( - FQ12, field_modulus, FQ, + FQ12, + FQP, ) @@ -33,17 +42,19 @@ assert sum([e * 2**i for i, e in enumerate(pseudo_binary_encoding)]) == ate_loop_count -def normalize1(p): +def normalize1(p: Optimized_Point3D[Optimized_Field]) -> Optimized_Point3D[Optimized_Field]: x, y = normalize(p) - return x, y, x.__class__.one() + return x, y, type(x).one() # Create a function representing the line between P1 and P2, # and evaluate it at T. Returns a numerator and a denominator # to avoid unneeded divisions -def linefunc(P1, P2, T): - zero = P1[0].__class__.zero() +def linefunc(P1: Optimized_Point3D[Optimized_Field], + P2: Optimized_Point3D[Optimized_Field], + T: Optimized_Point3D[Optimized_Field]) -> Optimized_Point2D[Optimized_Field]: + zero = type(P1[0]).zero() x1, y1, z1 = P1 x2, y2, z2 = P2 xt, yt, zt = T @@ -66,7 +77,7 @@ def linefunc(P1, P2, T): return xt * z1 - x1 * zt, z1 * zt -def cast_point_to_fq12(pt): +def cast_point_to_fq12(pt: Optimized_Point3D[FQ]) -> Optimized_Point3D[FQ12]: if pt is None: return None x, y, z = pt @@ -95,10 +106,12 @@ def cast_point_to_fq12(pt): # Main miller loop -def miller_loop(Q, P, final_exponentiate=True): +def miller_loop(Q: Optimized_Point3D[FQP], + P: Optimized_Point3D[FQP], + final_exponentiate: bool=True) -> FQP: if Q is None or P is None: return FQ12.one() - R = Q + R = Q # type: Optimized_Point3D[FQP] f_num, f_den = FQ12.one(), FQ12.one() # for i in range(log_ate_loop_count, -1, -1): for v in pseudo_binary_encoding[63::-1]: @@ -135,13 +148,13 @@ def miller_loop(Q, P, final_exponentiate=True): # Pairing computation -def pairing(Q, P, final_exponentiate=True): +def pairing(Q: Optimized_FQ2Point3D, P: Optimized_FQPoint3D, final_exponentiate: bool=True) -> FQP: assert is_on_curve(Q, b2) assert is_on_curve(P, b) - if P[-1] == P[-1].__class__.zero() or Q[-1] == Q[-1].__class__.zero(): + if P[-1] == (type(P[-1]).zero()) or Q[-1] == (type(Q[-1]).zero()): return FQ12.one() return miller_loop(twist(Q), cast_point_to_fq12(P), final_exponentiate=final_exponentiate) -def final_exponentiate(p): +def final_exponentiate(p: Optimized_Field) -> Optimized_Field: return p ** ((field_modulus ** 12 - 1) // curve_order) diff --git a/py_ecc/typing.py b/py_ecc/typing.py index c96bffe2..f08a3670 100644 --- a/py_ecc/typing.py +++ b/py_ecc/typing.py @@ -1,3 +1,10 @@ +from typing import ( + Tuple, + TypeVar, + TYPE_CHECKING, + Union, +) + from py_ecc.bn128.bn128_field_elements import ( FQ, FQP, @@ -5,13 +12,16 @@ FQ12, ) -from typing import ( - Tuple, - TypeVar, - Union, -) +if TYPE_CHECKING: + from py_ecc.optimized_bn128.optimized_field_elements import ( # noqa: F401 + FQ as Optimized_FQ, + FQP as Optimized_FQP, + FQ2 as Optimized_FQ2, + FQ12 as Optimized_FQ12, + ) +# Types For bn128 module # These types are wrt FQ, FQ2, FQ12 FQPoint2D = Tuple[FQ, FQ] FQ2Point2D = Tuple[FQ2, FQ2] @@ -31,3 +41,21 @@ Point2D = Tuple[Field, Field] Point3D = Tuple[Field, Field, Field] GeneralPoint = Union[Point2D[Field], Point3D[Field]] + + +# Types For optimized_bn128_module +# These types are wrt FQ, FQ2, FQ12 +Optimized_FQPoint2D = Tuple["Optimized_FQ", "Optimized_FQ"] +Optimized_FQ2Point2D = Tuple["Optimized_FQ2", "Optimized_FQ2"] +Optimized_FQ12Point2D = Tuple["Optimized_FQ12", "Optimized_FQ12"] +Optimized_FQPPoint2D = Tuple["Optimized_FQP", "Optimized_FQP"] + +Optimized_FQPoint3D = Tuple["Optimized_FQ", "Optimized_FQ", "Optimized_FQ"] +Optimized_FQ2Point3D = Tuple["Optimized_FQ2", "Optimized_FQ2", "Optimized_FQ2"] +Optimized_FQ12Point3D = Tuple["Optimized_FQ12", "Optimized_FQ12", "Optimized_FQ12"] +Optimized_FQPPoint3D = Tuple["Optimized_FQP", "Optimized_FQP", "Optimized_FQP"] + +Optimized_Field = TypeVar('Optimized_Field', "Optimized_FQ", "Optimized_FQP") +Optimized_Point2D = Tuple[Optimized_Field, Optimized_Field] +Optimized_Point3D = Tuple[Optimized_Field, Optimized_Field, Optimized_Field] +Optimized_GeneralPoint = Union[Point2D[Optimized_Field], Point3D[Optimized_Field]] diff --git a/tox.ini b/tox.ini index c43d43cf..1b3f3474 100644 --- a/tox.ini +++ b/tox.ini @@ -25,4 +25,5 @@ extras=lint commands= flake8 {toxinidir}/py_ecc mypy --strict --follow-imports=silent --ignore-missing-imports --no-strict-optional -p py_ecc.bn128 + mypy --strict --follow-imports=silent --ignore-missing-imports --no-strict-optional -p py_ecc.optimized_bn128 mypy --strict --follow-imports=silent --ignore-missing-imports --no-strict-optional -p py_ecc.secp256k1