Skip to content

Commit

Permalink
Bug fixes and docs
Browse files Browse the repository at this point in the history
  • Loading branch information
TimothyClaeys committed Apr 13, 2021
1 parent 7a61ab0 commit 9b4576f
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 13 deletions.
12 changes: 7 additions & 5 deletions cose/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@


class CoseAlgorithm(_CoseAttribute, ABC):
""" Base class for all COSE algorithms. """

_registered_algorithms = {}

@classmethod
Expand All @@ -45,7 +47,7 @@ class _HashAlg(CoseAlgorithm, ABC):
#: Set in derived class to hash constructor
hash_cls = None
#: Set in derived class to optional trucation size in byte count
truc_size: Optional[int] = None
trunc_size: Optional[int] = None

@classmethod
def get_hash_func(cls) -> HashAlgorithm:
Expand All @@ -57,8 +59,8 @@ def compute_hash(cls, data: bytes) -> bytes:
h.update(data)
digest = h.finalize()

if cls.truc_size:
digest = digest[:cls.truc_size]
if cls.trunc_size:
digest = digest[:cls.trunc_size]

return digest

Expand Down Expand Up @@ -633,7 +635,7 @@ class Sha512Trunc256(_HashAlg):
identifier = -17
fullname = "SHA-512/256"
hash_cls = SHA512
truc_size = 32
trunc_size = 32


@CoseAlgorithm.register_attribute()
Expand All @@ -648,7 +650,7 @@ class Sha256Trunc64(_HashAlg):
identifier = -15
fullname = "SHA-256/64"
hash_cls = SHA256
truc_size = 8
trunc_size = 8


@CoseAlgorithm.register_attribute()
Expand Down
6 changes: 3 additions & 3 deletions cose/keys/ec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,10 @@ def crv(self) -> Optional[Type['CoseCurve']]:
raise CoseInvalidKey("EC2 COSE key must have the EC2KpCurve attribute")

@crv.setter
def crv(self, crv: Type['CoseCurve']):
def crv(self, crv: Union[Type['CoseCurve'], int, str]):
if crv not in [P256, P384, P521] \
or crv not in [P256.identifier, P384.identifier, P521.identifier] \
or crv not in [P256.fullname, P384.fullname, P521.fullname]:
and crv not in [P256.identifier, P384.identifier, P521.identifier] \
and crv not in [P256.fullname, P384.fullname, P521.fullname]:
raise CoseIllegalCurve("Invalid COSE curve attribute")
else:
self.store[EC2KpCurve] = CoseCurve.from_id(crv)
Expand Down
16 changes: 16 additions & 0 deletions cose/keys/keyparam.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
from cose.utils import _CoseAttribute


#########################################
# Base Key Parameters
#########################################

class KeyParam(_CoseAttribute, ABC):
_registered_algorithms = {}

Expand Down Expand Up @@ -44,6 +48,10 @@ class KpBaseIV(KeyParam):
fullname = 'BASE_IV'


#########################################
# EC2 Key Parameters
#########################################

class EC2KeyParam(_CoseAttribute, ABC):
_registered_algorithms = {}
_registered_algorithms.update(KeyParam.get_registered_classes())
Expand Down Expand Up @@ -77,6 +85,10 @@ class EC2KpD(EC2KeyParam):
fullname = "D"


#########################################
# OKP Key Parameters
#########################################

class OKPKeyParam(_CoseAttribute, ABC):
_registered_algorithms = {}
_registered_algorithms.update(KeyParam.get_registered_classes())
Expand Down Expand Up @@ -104,6 +116,10 @@ class OKPKpX(OKPKeyParam):
fullname = "X"


#########################################
# Symmetric Key Parameters
#########################################

class SymmetricKeyParam(_CoseAttribute, ABC):
_registered_algorithms = {}
_registered_algorithms.update(KeyParam.get_registered_classes())
Expand Down
17 changes: 13 additions & 4 deletions cose/keys/okp.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,15 @@ def _key_transform(key: Union[Type['OKPKeyParam'], Type['KeyParam'], str, int]):

def __init__(self, crv: Union[Type['CoseCurve'], str, int], x: bytes = b'', d: bytes = b'',
optional_params: Optional[dict] = None):
"""
Create an COSE OKP key.
:param crv: An OKP elliptic curve.
:param x: Public value of the OKP key.
:param d: Private value of the OKP key.
:param optional_params: A dictionary with optional key parameters.
"""

transformed_dict = {}

if len(x) == 0 and len(d) == 0:
Expand Down Expand Up @@ -110,10 +119,10 @@ def crv(self) -> Optional[Type['CoseCurve']]:
raise CoseInvalidKey("OKP COSE key must have the OKP KpCurve attribute")

@crv.setter
def crv(self, crv: Type['CoseCurve']):
def crv(self, crv: Union[Type['CoseCurve'], int, str]):
if crv not in [X25519, X448, Ed25519, Ed448] \
or crv not in [X25519.identifier, X448.identifier, Ed25519.identifier, Ed448.identifier] \
or crv not in [X25519.fullname, X448.fullname, Ed25519.fullname, Ed448.identifier]:
and crv not in [X25519.identifier, X448.identifier, Ed25519.identifier, Ed448.identifier] \
and crv not in [X25519.fullname, X448.fullname, Ed25519.fullname, Ed448.identifier]:
raise CoseIllegalCurve("Invalid COSE curve attribute")
else:
self.store[OKPKpCurve] = CoseCurve.from_id(crv)
Expand Down Expand Up @@ -183,7 +192,7 @@ def generate_key(curve: Union[Type['CoseCurve'], str, int], optional_params: dic
d=private_key.private_bytes(encoding, private_format, encryption),
optional_params=optional_params)

def __delitem__(self, key):
def __delitem__(self, key: Union['KeyParam', str, int]):
if self._key_transform(key) != KpKty and self._key_transform(key) != OKPKpCurve:
if self._key_transform(key) == OKPKpD and OKPKpX not in self.store:
pass
Expand Down
23 changes: 22 additions & 1 deletion tests/test_okp_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from cose.algorithms import EdDSA
from cose.curves import Ed448, Ed25519, X448, X25519
from cose.exceptions import CoseInvalidKey, CoseIllegalKeyType
from cose.exceptions import CoseInvalidKey, CoseIllegalKeyType, CoseIllegalCurve
from cose.keys import OKPKey, CoseKey
from cose.keys.keyops import SignOp
from cose.keys.keyparam import KpKty, OKPKpCurve, OKPKpX, OKPKpD, KpAlg, KpKeyOps
Expand Down Expand Up @@ -156,3 +156,24 @@ def test_unknown_key_attributes():
key = CoseKey.decode(unhexlify(key))

assert "subject name" in key


def test_key_set_curve():
key = 'a401012006215820898ff79a02067a16ea1eccb90fa52246f5aa4dd6ec076bba0259d904b7ec8b0c2358208f781a095372f85b6d' \
'9f6109ae422611734d7dbfa0069a2df2935bb2e053bf35'
key = CoseKey.decode(unhexlify(key))

assert key.crv == Ed25519

key.crv = X25519

assert key.crv == X25519

with pytest.raises(CoseIllegalCurve) as excinfo:
key.crv = 3

assert "Invalid COSE curve attribute" in str(excinfo.value)

key.crv = X448.identifier

assert key.crv == X448

0 comments on commit 9b4576f

Please sign in to comment.