diff --git a/jwcrypto/jwk.py b/jwcrypto/jwk.py index c273445..168fb5c 100644 --- a/jwcrypto/jwk.py +++ b/jwcrypto/jwk.py @@ -368,6 +368,18 @@ def generate_key(self, **params): gen(params) + @classmethod + def generate_similar(cls, key): + return cls.generate(**key.get_generate_params()) + + def get_generate_params(self): + params = {param: self.get(param) for param in ["kty", "crv", "use", "key_ops"] if param in self} + if self.get("kty") == "RSA": + params["size"] = self._get_public_key().key_size + elif self.get("kty") == "oct": + params["size"] = len(base64url_decode(self.k)) * 8 + return params + def _get_gen_size(self, params, default_size=None): size = default_size if 'size' in params: diff --git a/jwcrypto/tests.py b/jwcrypto/tests.py index 647235b..8569e9d 100644 --- a/jwcrypto/tests.py +++ b/jwcrypto/tests.py @@ -506,6 +506,58 @@ def test_jwk_from_json(self): y = jwk.JWK.from_json(k.export()) self.assertEqual(k.export(), y.export()) + def test_generate_similar(self): + KEY_PARAMS = [ + { + "kty": "oct", + "size": 192 + }, + { + "kty": "RSA", + "size": 3072 + }, + { + "kty": "EC", + "crv": "P-256" + }, + { + "kty": "EC", + "crv": "P-384", + }, + { + "kty": "OKP", + "crv": "Ed25519" + }, + { + "kty": "OKP", + "crv": "Ed448" + }, + { + "kty": "OKP", + "crv": "X25519" + }, + { + "kty": "OKP", + "crv": "X448" + }, + { + "kty": "RSA", + "size": 3072, + "use": "sig" + }, + { + "kty": "oct", + "size": 256, + "key_ops": "sign" + }, + ] + for params in KEY_PARAMS: + key1 = jwk.JWK.generate(**params) + key2 = jwk.JWK.generate_similar(key1) + for prop in ["kty", "crv", "use", "key_ops"]: + self.assertEqual(key1.get(prop), key2.get(prop)) + self.assertEqual(type(key1.get_op_key("sign")), type(key2.get_op_key("sign"))) + def test_jwkset(self): k = jwk.JWK(**RSAPrivateKey) ks = jwk.JWKSet()