Skip to content

Commit

Permalink
Fix doubling points on the x axis bug, add unittest (#18)
Browse files Browse the repository at this point in the history
  • Loading branch information
lc6chang authored Feb 12, 2022
1 parent 55ee2e8 commit eb8b8c1
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 17 deletions.
21 changes: 5 additions & 16 deletions ecc/curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,25 +100,17 @@ def add_point(self, P: Point, Q: Point) -> Point:
elif Q.is_at_infinity():
return P

if P == Q:
return self._double_point(P)
if P == -Q:
return self.INF
if P == Q:
return self._double_point(P)

return self._add_point(P, Q)

@abstractmethod
def _add_point(self, P: Point, Q: Point) -> Point:
pass

def double_point(self, P: Point) -> Point:
if not self.is_on_curve(P):
raise ValueError("The point is not on the curve.")
if P.is_at_infinity():
return self.INF

return self._double_point(P)

@abstractmethod
def _double_point(self, P: Point) -> Point:
pass
Expand All @@ -134,17 +126,14 @@ def mul_point(self, d: int, P: Point) -> Point:
if d == 0:
return self.INF

res = None
res = self.INF
is_negative_scalar = d < 0
d = -d if is_negative_scalar else d
tmp = P
while d:
if d & 0x1 == 1:
if res:
res = self.add_point(res, tmp)
else:
res = tmp
tmp = self.double_point(tmp)
res = self.add_point(res, tmp)
tmp = self.add_point(tmp, tmp)
d >>= 1
if is_negative_scalar:
return -res
Expand Down
8 changes: 7 additions & 1 deletion tests/test_curve.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import unittest

from ecc.curve import (
P256, secp256k1, Curve25519, M383, E222, E382
P256, secp256k1, Curve25519, M383, E222, E382, Point
)

CURVES = [P256, secp256k1, Curve25519, M383, E222, E382]
Expand All @@ -24,3 +24,9 @@ def test_operator(self):
self.assertEqual(curve.INF + curve.INF, curve.INF)
self.assertEqual(0 * P, curve.INF)
self.assertEqual(1000 * curve.INF, curve.INF)

def test_double_points_y_equals_to_0(self):
P = Point(x=0, y=0, curve=Curve25519)
self.assertEqual(P + P, Curve25519.INF)
self.assertEqual(2 * P, Curve25519.INF)
self.assertEqual(-2 * P, Curve25519.INF)

0 comments on commit eb8b8c1

Please sign in to comment.