Skip to content

Commit

Permalink
Improve readability
Browse files Browse the repository at this point in the history
  • Loading branch information
sosthene-nitrokey committed Oct 25, 2024
1 parent 66191e0 commit 1e5c7e8
Showing 1 changed file with 35 additions and 34 deletions.
69 changes: 35 additions & 34 deletions pynitrokey/cli/nk3/piv.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import click
import cryptography
from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives._asymmetric import AsymmetricPadding
from cryptography.hazmat.primitives.asymmetric import ec, rsa
Expand All @@ -21,21 +22,21 @@ class RsaPivSigner(rsa.RSAPrivateKey):
_device: PivApp
_key_reference: int
_public_key: rsa.RSAPublicKey
_key_size: int

def __init__(
self, device: PivApp, key_reference: int, public_key: rsa.RSAPublicKey
):
self._device = device
self._key_reference = key_reference
self._public_key = public_key
self._key_size = public_key.key_size

def public_key(self) -> rsa.RSAPublicKey:
return self._public_key

# for some reason the type checking thinks this should be a constant
# it fails at runtime if it's not a function though
def key_size(self) -> int: # type: ignore
return self._key_size
return self._public_key.key_size

def sign(
self,
Expand Down Expand Up @@ -84,12 +85,16 @@ def exchange(
def public_key(self) -> ec.EllipticCurvePublicKey:
return self._public_key

# for some reason the type checking thinks this should be a constant
# it fails at runtime if it's not a function though
def curve(self) -> ec.EllipticCurve: # type: ignore
return self._public_key.curve

def private_numbers(self) -> ec.EllipticCurvePrivateNumbers:
raise NotImplementedError()

# for some reason the type checking thinks this should be a constant
# it fails at runtime if it's not a function though
def key_size(self) -> int: # type: ignore
return self._public_key.key_size

Expand Down Expand Up @@ -182,7 +187,7 @@ def info() -> None:
if not printed_head:
local_print("Keys:")
printed_head = True
parsed_cert = cryptography.x509.load_der_x509_certificate(cert)
parsed_cert = x509.load_der_x509_certificate(cert)
local_print(f" {key}")
local_print(
f" algorithm: {parsed_cert.signature_algorithm_oid._name}"
Expand Down Expand Up @@ -427,10 +432,8 @@ def generate_key(
algo = algo.lower()
if algo == "rsa2048":
algo_id = b"\x07"
signature_algorithm = "sha256_rsa"
elif algo == "nistp256":
algo_id = b"\x11"
signature_algorithm = "sha256_ecdsa"
else:
local_critical("Unimplemented algorithm", support_hint=False)

Expand Down Expand Up @@ -476,30 +479,26 @@ def generate_key(
else:
local_critical("Unimplemented algorithm")

certificate_builder = cryptography.x509.CertificateBuilder()
csr_builder = cryptography.x509.CertificateSigningRequestBuilder()
certificate_builder = x509.CertificateBuilder()
csr_builder = x509.CertificateSigningRequestBuilder()

if domain_component is None:
domain_component = []

if subject_name is None:
crypto_rdns = cryptography.x509.Name([])
crypto_rdns = x509.Name([])
else:
crypto_rdns = cryptography.x509.Name(
crypto_rdns = x509.Name(
[
cryptography.x509.RelativeDistinguishedName(
x509.RelativeDistinguishedName(
[
cryptography.x509.NameAttribute(
cryptography.x509.NameOID.DOMAIN_COMPONENT, subject
)
x509.NameAttribute(x509.NameOID.DOMAIN_COMPONENT, subject)
for subject in domain_component
]
),
cryptography.x509.RelativeDistinguishedName(
x509.RelativeDistinguishedName(
[
cryptography.x509.NameAttribute(
cryptography.x509.NameOID.COMMON_NAME, subject
)
x509.NameAttribute(x509.NameOID.COMMON_NAME, subject)
for subject in subject_name
]
),
Expand All @@ -511,7 +510,7 @@ def generate_key(
.issuer_name(crypto_rdns)
.not_valid_before(datetime.datetime(2000, 1, 1, 0, 0))
.not_valid_after(datetime.datetime(2099, 1, 1, 0, 0))
.serial_number(cryptography.x509.random_serial_number())
.serial_number(x509.random_serial_number())
)
csr_builder = csr_builder.subject_name(crypto_rdns)

Expand Down Expand Up @@ -544,10 +543,10 @@ def generate_key(
)
)

crypto_extensions: Sequence[Tuple[cryptography.x509.ExtensionType, bool]] = [
(cryptography.x509.BasicConstraints(ca=False, path_length=None), True),
crypto_extensions: Sequence[Tuple[x509.ExtensionType, bool]] = [
(x509.BasicConstraints(ca=False, path_length=None), True),
(
cryptography.x509.KeyUsage(
x509.KeyUsage(
digital_signature=True,
content_commitment=True,
key_encipherment=False,
Expand All @@ -561,17 +560,17 @@ def generate_key(
True,
),
(
cryptography.x509.ExtendedKeyUsage(
x509.ExtendedKeyUsage(
[
cryptography.x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH,
cryptography.x509.oid.ExtendedKeyUsageOID.SMARTCARD_LOGON,
x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH,
x509.oid.ExtendedKeyUsageOID.SMARTCARD_LOGON,
]
),
False,
),
(
cryptography.x509.UnrecognizedExtension(
oid=cryptography.x509.oid.ObjectIdentifier("1.2.840.113549.1.9.15"),
x509.UnrecognizedExtension(
oid=x509.oid.ObjectIdentifier("1.2.840.113549.1.9.15"),
value=smime_extension,
),
False,
Expand All @@ -583,10 +582,12 @@ def generate_key(
csr_builder = csr_builder.add_extension(ext, critical)

if subject_alt_name_upn is not None:
crypto_sujbect_alt_name = cryptography.x509.SubjectAlternativeName(
crypto_sujbect_alt_name = x509.SubjectAlternativeName(
[
cryptography.x509.OtherName(
cryptography.x509.ObjectIdentifier("1.3.6.1.4.1.311.20.2.3"),
x509.OtherName(
x509.ObjectIdentifier("1.3.6.1.4.1.311.20.2.3"),
# bytes, because it's different from bytearray, and tlv because
# it expects already DER encoded ASN1
bytes(Tlv.build([(0x0C, subject_alt_name_upn.encode("utf-8"))])),
)
]
Expand Down Expand Up @@ -700,9 +701,9 @@ def write_certificate(admin_key: str, format: str, key: str, path: str) -> None:
format = format.upper()
if format == "DER":
cert_serialized = cert_bytes
cert = cryptography.x509.load_der_x509_certificate(cert_bytes)
cert = x509.load_der_x509_certificate(cert_bytes)
elif format == "PEM":
cert = cryptography.x509.load_pem_x509_certificate(cert_bytes)
cert = x509.load_pem_x509_certificate(cert_bytes)
cert_serialized = cert.public_bytes(Encoding.DER)

payload = Tlv.build(
Expand Down Expand Up @@ -774,9 +775,9 @@ def read_certificate(format: str, key: str, path: str) -> None:
format = format.upper()
if format == "DER":
cert_serialized = value
cryptography.x509.load_der_x509_certificate(value)
x509.load_der_x509_certificate(value)
elif format == "PEM":
cert = cryptography.x509.load_der_x509_certificate(value)
cert = x509.load_der_x509_certificate(value)
cert_serialized = cert.public_bytes(Encoding.PEM)

with click.open_file(path, mode="wb") as f:
Expand Down

0 comments on commit 1e5c7e8

Please sign in to comment.