Skip to content

Commit

Permalink
Merge pull request #29 from Bhargavasomu/type_hinting
Browse files Browse the repository at this point in the history
Enable Type Hinting for optimized_bn128 module
  • Loading branch information
pipermerriam authored Dec 6, 2018
2 parents bacb225 + 6f07a5b commit 212b193
Show file tree
Hide file tree
Showing 5 changed files with 195 additions and 119 deletions.
44 changes: 28 additions & 16 deletions py_ecc/optimized_bn128/optimized_curve.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,21 @@
from __future__ import absolute_import

from typing import (
cast,
)

from .optimized_field_elements import (
FQ2,
FQ12,
field_modulus,
FQ,
FQ2,
FQ12,
FQP,
)

from py_ecc.typing import (
Optimized_Field,
Optimized_Point2D,
Optimized_Point3D,
)


Expand Down Expand Up @@ -43,12 +54,12 @@


# Check if a point is the point at infinity
def is_inf(pt):
return pt[-1] == pt[-1].__class__.zero()
def is_inf(pt: Optimized_Point3D[Optimized_Field]) -> bool:
return pt[-1] == (type(pt[-1]).zero())


# Check that a point is on the curve defined by y**2 == x**3 + b
def is_on_curve(pt, b):
def is_on_curve(pt: Optimized_Point3D[Optimized_Field], b: Optimized_Field) -> bool:
if is_inf(pt):
return True
x, y, z = pt
Expand All @@ -60,7 +71,7 @@ def is_on_curve(pt, b):


# Elliptic curve doubling
def double(pt):
def double(pt: Optimized_Point3D[Optimized_Field]) -> Optimized_Point3D[Optimized_Field]:
x, y, z = pt
W = 3 * x * x
S = y * z
Expand All @@ -70,12 +81,13 @@ def double(pt):
newx = 2 * H * S
newy = W * (4 * B - H) - 8 * y * y * S_squared
newz = 8 * S * S_squared
return newx, newy, newz
return (newx, newy, newz)


# Elliptic curve addition
def add(p1, p2):
one, zero = p1[0].__class__.one(), p1[0].__class__.zero()
def add(p1: Optimized_Point3D[Optimized_Field],
p2: Optimized_Point3D[Optimized_Field]) -> Optimized_Point3D[Optimized_Field]:
one, zero = type(p1[0]).one(), type(p1[0]).zero()
if p1[2] == zero or p2[2] == zero:
return p1 if p2[2] == zero else p2
x1, y1, z1 = p1
Expand All @@ -102,9 +114,9 @@ def add(p1, p2):


# Elliptic curve point multiplication
def multiply(pt, n):
def multiply(pt: Optimized_Point3D[Optimized_Field], n: int) -> Optimized_Point3D[Optimized_Field]:
if n == 0:
return (pt[0].__class__.one(), pt[0].__class__.one(), pt[0].__class__.zero())
return (type(pt[0]).one(), type(pt[0]).one(), type(pt[0]).zero())
elif n == 1:
return pt
elif not n % 2:
Expand All @@ -113,13 +125,13 @@ def multiply(pt, n):
return add(multiply(double(pt), int(n // 2)), pt)


def eq(p1, p2):
def eq(p1: Optimized_Point3D[Optimized_Field], p2: Optimized_Point3D[Optimized_Field]) -> bool:
x1, y1, z1 = p1
x2, y2, z2 = p2
return x1 * z2 == x2 * z1 and y1 * z2 == y2 * z1


def normalize(pt):
def normalize(pt: Optimized_Point3D[Optimized_Field]) -> Optimized_Point2D[Optimized_Field]:
x, y, z = pt
return (x / z, y / z)

Expand All @@ -129,14 +141,14 @@ def normalize(pt):


# Convert P => -P
def neg(pt):
def neg(pt: Optimized_Point3D[Optimized_Field]) -> Optimized_Point3D[Optimized_Field]:
if pt is None:
return None
x, y, z = pt
return (x, -y, z)


def twist(pt):
def twist(pt: Optimized_Point3D[FQP]) -> Optimized_Point3D[FQP]:
if pt is None:
return None
_x, _y, _z = pt
Expand All @@ -151,5 +163,5 @@ def twist(pt):


# Check that the twist creates a point that is on the curve
G12 = twist(G2)
G12 = twist(cast(Optimized_Point3D[FQ2], G2))
assert is_on_curve(G12, b12)
Loading

0 comments on commit 212b193

Please sign in to comment.