Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add coincurve instead of secp256k1 #46

Merged
merged 7 commits into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion bolt11/cli.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
""" bolt11 CLI """
"""bolt11 CLI"""

import json
import sys
Expand Down Expand Up @@ -52,10 +52,12 @@ def decode(bolt11, ignore_exceptions, strict):
@click.argument("private_key", type=str, default=None, required=False)
@click.argument("ignore_exceptions", type=bool, default=True)
@click.argument("strict", type=bool, default=False)
@click.argument("keep_payee", type=bool, default=False)
def encode(
json_string,
ignore_exceptions: bool = True,
strict: bool = False,
keep_payee: bool = False,
private_key: Optional[str] = None,
):
"""
Expand Down Expand Up @@ -92,6 +94,7 @@ def encode(
private_key,
ignore_exceptions=ignore_exceptions,
strict=strict,
keep_payee=keep_payee,
)
click.echo(encoded)
except Bolt11Exception as exc:
Expand Down
16 changes: 11 additions & 5 deletions bolt11/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,15 @@ def decode(
timestamp = data_part.read(35).uint

tags = Tags()
payee = None

while data_part.pos != data_part.len:
tag, tagdata, data_part = _pull_tagged(data_part)
data_length = int(len(tagdata or []) / 5)

# MUST skip over unknown fields, OR an f field with unknown version, OR p, h,
# s or n fields that do NOT have data_lengths of 52, 52, 52 or 53, respectively.

if (
tag == TagChar.payment_hash.value
and data_length == 52
Expand Down Expand Up @@ -93,9 +95,10 @@ def decode(
and data_length == 53
and not tags.has(TagChar.payee)
):
payee = trim_to_bytes(tagdata).hex()
tags.add(
TagChar.payee,
trim_to_bytes(tagdata).hex(),
payee,
)
elif (
tag == TagChar.description.value
Expand Down Expand Up @@ -133,19 +136,22 @@ def decode(
elif tag == TagChar.route_hint.value:
tags.add(TagChar.route_hint, RouteHint.from_bitstring(tagdata))

else:
# skip unknown fields
pass

signature = Signature(
signature_data=signature_data,
signing_data=hrp.encode() + data_part.tobytes(),
signing_data=data_part.tobytes(),
hrp=hrp,
)

# A reader MUST check that the `signature` is valid (see the `n` tagged field
# specified below). A reader MUST use the `n` field to validate the signature
# instead of performing signature recovery if a valid `n` field is provided.
payee = tags.get(TagChar.payee)
if payee:
# TODO: research why no test runs this?
try:
signature.verify(payee.data)
signature.verify(payee)
except Exception as exc:
raise Bolt11SignatureVerifyException() from exc
else:
Expand Down
11 changes: 6 additions & 5 deletions bolt11/encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def encode(
private_key: Optional[str] = None,
ignore_exceptions: bool = False,
strict: bool = False,
keep_payee: bool = False,
) -> str:
try:
if invoice.description_hash:
Expand All @@ -75,10 +76,8 @@ def encode(
tags += _tagged_bytes(tag.bech32, bytes.fromhex(tag.data))
elif tag.char == TagChar.metadata:
tags += _tagged_bytes(tag.bech32, bytes.fromhex(tag.data))
# TODO: why uncommented?
# payee is not needed, needs more research
# elif tag.char == TagChar.payee:
# tags += _tagged_bytes(tag.bech32, bytes.fromhex(tag.data))
elif tag.char == TagChar.payee and keep_payee:
tags += _tagged_bytes(tag.bech32, bytes.fromhex(tag.data))
elif tag.char == TagChar.features:
tags += _tagged_bytes(tag.bech32, tag.data.data)
elif tag.char == TagChar.fallback:
Expand All @@ -94,7 +93,9 @@ def encode(
data_part = timestamp + tags

if private_key:
invoice.signature = Signature.from_private_key(private_key, hrp, data_part)
invoice.signature = Signature.from_private_key(
hrp=hrp, private_key=private_key, signing_data=data_part.tobytes()
)

if not invoice.signature:
raise Bolt11NoSignatureException()
Expand Down
60 changes: 38 additions & 22 deletions bolt11/models/signature.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,60 @@
from dataclasses import dataclass
from hashlib import sha256
from typing import Optional

from bitstring import Bits
from ecdsa import SECP256k1, VerifyingKey
from ecdsa.util import sigdecode_string
from secp256k1 import PrivateKey
from coincurve import PrivateKey, PublicKey, verify_signature
from coincurve.ecdsa import cdata_to_der, deserialize_recoverable, recoverable_convert


def message(hrp: str, signing_data: bytes) -> bytes:
return bytearray([ord(c) for c in hrp]) + signing_data


@dataclass
class Signature:
"""An invoice signature."""

hrp: str
signing_data: bytes
signature_data: Optional[bytes] = None
signature_data: bytes

@classmethod
def from_signature_data(
cls, hrp: str, signature_data: bytes, signing_data: bytes
) -> "Signature":
return cls(hrp=hrp, signature_data=signature_data, signing_data=signing_data)

@classmethod
def from_private_key(
cls, private_key: str, hrp: str, signing_data: Bits
cls, hrp: str, private_key: str, signing_data: bytes
) -> "Signature":
key = PrivateKey(bytes.fromhex(private_key))
sig = key.ecdsa_sign_recoverable(
bytearray([ord(c) for c in hrp]) + signing_data.tobytes()
)
sig, recid = key.ecdsa_recoverable_serialize(sig)
signature_data = bytes(sig) + bytes([recid])
return cls(signing_data=signing_data.tobytes(), signature_data=signature_data)
key = PrivateKey.from_hex(private_key)
signature_data = key.sign_recoverable(message(hrp, signing_data))
return cls(hrp=hrp, signing_data=signing_data, signature_data=signature_data)

def verify(self, payee: str) -> bool:
key = VerifyingKey.from_string(bytes.fromhex(payee), curve=SECP256k1)
return key.verify(
self.sig, self.signing_data, sha256, sigdecode=sigdecode_string
)
if not self.signature_data:
raise ValueError("No signature data")
if not self.signing_data:
raise ValueError("No signing data")
sig = deserialize_recoverable(self.signature_data)
sig = recoverable_convert(sig)
sig = cdata_to_der(sig)
if not verify_signature(
sig, message(self.hrp, self.signing_data), bytes.fromhex(payee)
):
raise ValueError("Invalid signature")
return True

def recover_public_key(self) -> str:
keys = VerifyingKey.from_public_key_recovery(
self.sig, self.signing_data, SECP256k1, sha256
if not self.signature_data:
raise ValueError("No signature data")
if not self.signing_data:
raise ValueError("No signing data")

key = PublicKey.from_signature_and_message(
self.signature_data, message(self.hrp, self.signing_data)
)
key = keys[self.recovery_flag]
return key.to_string("compressed").hex()
return key.format(compressed=True).hex()

@property
def r(self) -> str:
Expand Down
Loading
Loading