From 15727f19dbae05f87ef49ba85032ef98b95eeebd Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Mon, 21 Jan 2019 21:04:16 +0900 Subject: [PATCH 1/5] Implement RSA-OAEP encryption schema --- rsa/pkcs1_v2.py | 195 ++++++++++++++++++++++++++++++++++++++++- tests/test_pkcs1_v2.py | 46 ++++++++++ 2 files changed, 238 insertions(+), 3 deletions(-) diff --git a/rsa/pkcs1_v2.py b/rsa/pkcs1_v2.py index 5f9c7dd..123fa74 100644 --- a/rsa/pkcs1_v2.py +++ b/rsa/pkcs1_v2.py @@ -17,17 +17,41 @@ """Functions for PKCS#1 version 2 encryption and signing This module implements certain functionality from PKCS#1 version 2. Main -documentation is RFC 2437: https://tools.ietf.org/html/rfc2437 +documentation is RFC 8017: https://tools.ietf.org/html/rfc8017 """ -from rsa._compat import range +import os + +from rsa._compat import range, xor_bytes, PY2 from rsa import ( common, + core, pkcs1, transform, ) +try: + from hmac import compare_digest # Available from 2.7.7+ and 3.3+ +except ImportError: + # https://www.reddit.com/r/Python/comments/49hwq0/constant_time_comparison_in_python/d0ry5qx/ + def compare_digest(a, b): + result = True + for x, y in zip(a, b): + result &= (x == y) + return result + + +def _constant_time_select(v, t, f): + """Return t if v else f. + + v must be 0 or 1. (False and True are allowed) + t and f are integer between 0 and 255. + """ + v -= 1 + return (~v & t) | (v & f) + + def mgf1(seed, length, hasher='SHA-1'): """ MGF1 is a Mask Generation Function based on a hash function. @@ -84,8 +108,173 @@ def mgf1(seed, length, hasher='SHA-1'): return output[:length] +def _OAEP_encode(message, keylength, label, hash_method, mgf1_hash_method): + try: + hasher = pkcs1.HASH_METHODS[hash_method](label) + except KeyError: + raise ValueError( + 'Invalid `hash_method` specified. Please select one of: {hash_list}'.format( + hash_list=', '.join(sorted(pkcs1.HASH_METHODS.keys())) + ) + ) + hash_length = hasher.digest_size + max_message_length = keylength - 2*hash_length - 2 + message_length = len(message) + if message_length > max_message_length: + raise OverflowError("message is too long; at most %s bytes, given %s bytes" % + (max_message_length, len(message))) + + lhash = hasher.digest() + ps = bytearray(keylength - message_length - 2*hash_length -2) + db = hasher.digest() + b'\0' * (keylength - message_length - 2*hash_length - 2) \ + + b'\x01' + message + + seed = os.urandom(hash_length) + db_mask = mgf1(seed, keylength - hash_length - 1, mgf1_hash_method) + masked_db = xor_bytes(db, db_mask) + + seed_mask = mgf1(masked_db, hash_length, mgf1_hash_method) + masked_seed = xor_bytes(seed, seed_mask) + + em = b'\x00' + masked_seed + masked_db + return em + + +def OAEP_encrypt(message, pub_key, label=b'', hash_method="SHA-1", + mgf1_hash_method=None): + """Encrypts the given message using PKCS#1 v2 RSA-OEAP. + + :param bytes message: the message to encrypt. + :param rsa.PublicKey pub_key: the public key to encrypt with. + :param bytes label: optional RSA-OAEP label. + :param str hash_method: hash function to be used. 'SHA-1' (default), + 'SHA-256', 'SHA-384', and 'SHA-512' can be used. + """ + # NOTE: Some hash method other than listed in the docstring can be used + # for hash_method. But the RFC 8017 recommends only them. + if mgf1_hash_method is None: + mgf1_hash_method = hash_method + keylength = common.byte_size(pub_key.n) + + em = _OAEP_encode(message, keylength, label, hash_method, mgf1_hash_method) + + m = transform.bytes2int(em) + encrypted = core.encrypt_int(m, pub_key.e, pub_key.n) + c = transform.int2bytes(encrypted, keylength) + + return c + + +def OAEP_decrypt(crypto, priv_key, label=b'', hash_method="SHA-1", + mgf1_hash_method=None): + """Decrypts the givem crypto using PKCS#1 v2 RSA-OAEP. + + :param bytes crypto: the crypto text as returned by :py:func:`rsa.encrypt` + :param rsa.PrivateKey priv_key: the private key to decrypt with. + :param bytes label: optional RSA-OAEP label. + :param str hash_method: hash function to be used. 'SHA-1' (default), + 'SHA-256', 'SHA-384', and 'SHA-512' can be used. + :param str mgf1_hash_method: hash function to be used by MGF1 function. + If it is None (default), *hash_method* is used. + + :raise rsa.pkcs1.DecryptionError: when the decryption fails. No details are given as + to why the code thinks the decryption fails, as this would leak + information about the private key. + + >>> import rsa + >>> (pub_key, priv_key) = rsa.newkeys(512) + + It works with binary data: + + >>> crypto = OAEP_encrypt(b'hello', pub_key) + >>> OAEP_decrypt(crypto, priv_key) + b'hello' + + You can pass optional label data too: + + >>> crypto = OAEP_encrypt(b'hello', pub_key, label=b'world') + >>> OAEP_decrypt(crypto, priv_key, label=b'world') + b'hello' + + Altering the encrypted information will cause a + :py:class:`rsa.pkcs1.DecryptionError`. + + >>> crypto = OAEP_encrypt(b'hello', pub_key) + >>> crypto = crypto[0:5] + bytes([(ord(crypto[5:6])+1)%256]) + crypto[6:] # change a byte + >>> OAEP_decrypt(crypto, priv_key) + Traceback (most recent call last): + ... + rsa.pkcs1.DecryptionError: Decryption failed + + Changing label will also cause the error. + + >>> crypto = OAEP_encrypt(b'hello', pub_key, label=b'world') + >>> OAEP_decrypt(crypto, priv_key, label=b'universe') + Traceback (most recent call last): + ... + rsa.pkcs1.DecryptionError: Decryption failed + """ + if mgf1_hash_method is None: + mgf1_hash_method = hash_method + + # todo: Step 1: length checking + k = common.byte_size(priv_key.n) + if k != len(crypto): + raise pkcs1.DecryptionError('Decryption failed') + + # Step 2: RSA Decryption + c = transform.bytes2int(crypto) + m = priv_key.blinded_decrypt(c) + em = transform.int2bytes(m, k) + + # Step 3: EME-OAEP decoding + try: + hasher = pkcs1.HASH_METHODS[hash_method](label) + except KeyError: + raise ValueError( + 'Invalid `hash_method` specified. Please select one of: {hash_list}'.format( + hash_list=', '.join(sorted(pkcs1.HASH_METHODS.keys())) + ) + ) + hash_length = hasher.digest_size + lhash = hasher.digest() + Y = em[0:1] + masked_seed = em[1:1+hash_length] + masked_db = em[1+hash_length:] + + seed_mask = mgf1(masked_db, hash_length, mgf1_hash_method) + seed = xor_bytes(masked_seed, seed_mask) + + db_mask = mgf1(seed, k-hash_length-1, mgf1_hash_method) + db = xor_bytes(masked_db, db_mask) + + lhash_ = db[:hash_length] + rest = db[hash_length:] + + # NOTE: Take care about timing attack. See note in the RFC. + hash_is_good = compare_digest(lhash, lhash_) + + index = invalid = 0 + looking_one = 1 + + if PY2: + rest = bytearray(rest) + for i, c in enumerate(rest): + iszero = c == 0 + isone = c == 1 + + index = _constant_time_select(looking_one & isone, i, index) + looking_one = _constant_time_select(isone, 0, looking_one) + invalid = _constant_time_select(looking_one & ~iszero, 1, invalid) + + if invalid | looking_one | (not hash_is_good): + raise pkcs1.DecryptionError('Decryption failed') + + return rest[index+1:] + + __all__ = [ - 'mgf1', + 'mgf1', 'OAEP_encrypt', 'OAEP_decrypt', ] if __name__ == '__main__': diff --git a/tests/test_pkcs1_v2.py b/tests/test_pkcs1_v2.py index 1d8f001..0c3d1e1 100644 --- a/tests/test_pkcs1_v2.py +++ b/tests/test_pkcs1_v2.py @@ -20,9 +20,13 @@ http://www.itomorrowmag.com/emc-plus/rsa-labs/standards-initiatives/pkcs-rsa-cryptography-standard.htm """ +import struct import unittest +import rsa from rsa import pkcs1_v2 +from rsa._compat import byte, is_bytes +from rsa.pkcs1 import DecryptionError class MGFTest(unittest.TestCase): @@ -81,3 +85,45 @@ def test_invalid_hasher(self): def test_invalid_length(self): with self.assertRaises(OverflowError): pkcs1_v2.mgf1(b'\x06\xe1\xde\xb2', length=2**50) + + +class BinaryTest(unittest.TestCase): + def setUp(self): + (self.pub, self.priv) = rsa.newkeys(512) + + def test_enc_dec(self): + message = struct.pack('>IIII', 0, 0, 0, 1) + print("\tMessage: %r" % message) + + encrypted = pkcs1_v2.OAEP_encrypt(message, self.pub) + print("\tEncrypted: %r" % encrypted) + + decrypted = pkcs1_v2.OAEP_decrypt(encrypted, self.priv) + print("\tDecrypted: %r" % decrypted) + + self.assertEqual(message, decrypted) + + def test_decoding_failure(self): + message = struct.pack('>IIII', 0, 0, 0, 1) + encrypted = pkcs1_v2.OAEP_encrypt(message, self.pub) + + # Alter the encrypted stream + a = encrypted[5] + if is_bytes(a): + a = ord(a) + altered_a = (a + 1) % 256 + encrypted = encrypted[:5] + byte(altered_a) + encrypted[6:] + + self.assertRaises(DecryptionError, pkcs1_v2.OAEP_decrypt, + encrypted, self.priv) + + def test_randomness(self): + """Encrypting the same message twice should result in different + cryptos. + """ + + message = struct.pack('>IIII', 0, 0, 0, 1) + encrypted1 = pkcs1_v2.OAEP_encrypt(message, self.pub) + encrypted2 = pkcs1_v2.OAEP_encrypt(message, self.pub) + + self.assertNotEqual(encrypted1, encrypted2) From 9177447c65ebe95e0a807f3f107ad58489f582ec Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Thu, 10 Jun 2021 18:06:04 +0900 Subject: [PATCH 2/5] fixup --- rsa/pkcs1_v2.py | 4 ++-- tests/test_pkcs1_v2.py | 4 +--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/rsa/pkcs1_v2.py b/rsa/pkcs1_v2.py index 83e4db7..c7dc6de 100644 --- a/rsa/pkcs1_v2.py +++ b/rsa/pkcs1_v2.py @@ -18,6 +18,7 @@ documentation is RFC 8017: https://tools.ietf.org/html/rfc8017 """ +import os from hmac import compare_digest from rsa import ( common, @@ -25,6 +26,7 @@ pkcs1, transform, ) +from rsa._compat import xor_bytes def _constant_time_select(v, t, f): @@ -242,8 +244,6 @@ def OAEP_decrypt(crypto, priv_key, label=b'', hash_method="SHA-1", index = invalid = 0 looking_one = 1 - if PY2: - rest = bytearray(rest) for i, c in enumerate(rest): iszero = c == 0 isone = c == 1 diff --git a/tests/test_pkcs1_v2.py b/tests/test_pkcs1_v2.py index 3d1c415..d7524fd 100644 --- a/tests/test_pkcs1_v2.py +++ b/tests/test_pkcs1_v2.py @@ -23,7 +23,7 @@ import rsa from rsa import pkcs1_v2 -from rsa._compat import byte, is_bytes +from rsa._compat import byte from rsa.pkcs1 import DecryptionError @@ -105,8 +105,6 @@ def test_decoding_failure(self): # Alter the encrypted stream a = encrypted[5] - if is_bytes(a): - a = ord(a) altered_a = (a + 1) % 256 encrypted = encrypted[:5] + byte(altered_a) + encrypted[6:] From ddaf57566aa499123deb2b352f6790c5ed68d813 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Thu, 10 Jun 2021 18:17:42 +0900 Subject: [PATCH 3/5] rename OAEP_encrypt to encrypt_OAEP --- rsa/pkcs1_v2.py | 22 +++++++++++----------- tests/test_pkcs1_v2.py | 12 ++++++------ 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/rsa/pkcs1_v2.py b/rsa/pkcs1_v2.py index c7dc6de..acf4869 100644 --- a/rsa/pkcs1_v2.py +++ b/rsa/pkcs1_v2.py @@ -127,7 +127,7 @@ def _OAEP_encode(message, keylength, label, hash_method, mgf1_hash_method): return em -def OAEP_encrypt(message, pub_key, label=b'', hash_method="SHA-1", +def encrypt_OAEP(message, pub_key, label=b'', hash_method="SHA-1", mgf1_hash_method=None): """Encrypts the given message using PKCS#1 v2 RSA-OEAP. @@ -152,7 +152,7 @@ def OAEP_encrypt(message, pub_key, label=b'', hash_method="SHA-1", return c -def OAEP_decrypt(crypto, priv_key, label=b'', hash_method="SHA-1", +def decrypt_OAEP(crypto, priv_key, label=b'', hash_method="SHA-1", mgf1_hash_method=None): """Decrypts the givem crypto using PKCS#1 v2 RSA-OAEP. @@ -173,30 +173,30 @@ def OAEP_decrypt(crypto, priv_key, label=b'', hash_method="SHA-1", It works with binary data: - >>> crypto = OAEP_encrypt(b'hello', pub_key) - >>> OAEP_decrypt(crypto, priv_key) + >>> crypto = encrypt_OAEP(b'hello', pub_key) + >>> decrypt_OAEP(crypto, priv_key) b'hello' You can pass optional label data too: - >>> crypto = OAEP_encrypt(b'hello', pub_key, label=b'world') - >>> OAEP_decrypt(crypto, priv_key, label=b'world') + >>> crypto = encrypt_OAEP(b'hello', pub_key, label=b'world') + >>> decrypt_OAEP(crypto, priv_key, label=b'world') b'hello' Altering the encrypted information will cause a :py:class:`rsa.pkcs1.DecryptionError`. - >>> crypto = OAEP_encrypt(b'hello', pub_key) + >>> crypto = encrypt_OAEP(b'hello', pub_key) >>> crypto = crypto[0:5] + bytes([(ord(crypto[5:6])+1)%256]) + crypto[6:] # change a byte - >>> OAEP_decrypt(crypto, priv_key) + >>> decrypt_OAEP(crypto, priv_key) Traceback (most recent call last): ... rsa.pkcs1.DecryptionError: Decryption failed Changing label will also cause the error. - >>> crypto = OAEP_encrypt(b'hello', pub_key, label=b'world') - >>> OAEP_decrypt(crypto, priv_key, label=b'universe') + >>> crypto = encrypt_OAEP(b'hello', pub_key, label=b'world') + >>> decrypt_OAEP(crypto, priv_key, label=b'universe') Traceback (most recent call last): ... rsa.pkcs1.DecryptionError: Decryption failed @@ -259,7 +259,7 @@ def OAEP_decrypt(crypto, priv_key, label=b'', hash_method="SHA-1", __all__ = [ - 'mgf1', 'OAEP_encrypt', 'OAEP_decrypt', + 'mgf1', 'encrypt_OAEP', 'decrypt_OAEP', ] if __name__ == "__main__": diff --git a/tests/test_pkcs1_v2.py b/tests/test_pkcs1_v2.py index d7524fd..d605157 100644 --- a/tests/test_pkcs1_v2.py +++ b/tests/test_pkcs1_v2.py @@ -91,24 +91,24 @@ def test_enc_dec(self): message = struct.pack('>IIII', 0, 0, 0, 1) print("\tMessage: %r" % message) - encrypted = pkcs1_v2.OAEP_encrypt(message, self.pub) + encrypted = pkcs1_v2.encrypt_OAEP(message, self.pub) print("\tEncrypted: %r" % encrypted) - decrypted = pkcs1_v2.OAEP_decrypt(encrypted, self.priv) + decrypted = pkcs1_v2.decrypt_OAEP(encrypted, self.priv) print("\tDecrypted: %r" % decrypted) self.assertEqual(message, decrypted) def test_decoding_failure(self): message = struct.pack('>IIII', 0, 0, 0, 1) - encrypted = pkcs1_v2.OAEP_encrypt(message, self.pub) + encrypted = pkcs1_v2.encrypt_OAEP(message, self.pub) # Alter the encrypted stream a = encrypted[5] altered_a = (a + 1) % 256 encrypted = encrypted[:5] + byte(altered_a) + encrypted[6:] - self.assertRaises(DecryptionError, pkcs1_v2.OAEP_decrypt, + self.assertRaises(DecryptionError, pkcs1_v2.decrypt_OAEP, encrypted, self.priv) def test_randomness(self): @@ -117,7 +117,7 @@ def test_randomness(self): """ message = struct.pack('>IIII', 0, 0, 0, 1) - encrypted1 = pkcs1_v2.OAEP_encrypt(message, self.pub) - encrypted2 = pkcs1_v2.OAEP_encrypt(message, self.pub) + encrypted1 = pkcs1_v2.encrypt_OAEP(message, self.pub) + encrypted2 = pkcs1_v2.encrypt_OAEP(message, self.pub) self.assertNotEqual(encrypted1, encrypted2) From 75070dcc9cfab2aa13dc82158afcedc0619cf095 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Thu, 10 Jun 2021 18:19:32 +0900 Subject: [PATCH 4/5] black --- rsa/pkcs1_v2.py | 50 +++++++++++++++++++++++------------------- tests/test_pkcs1_v2.py | 11 +++++----- 2 files changed, 33 insertions(+), 28 deletions(-) diff --git a/rsa/pkcs1_v2.py b/rsa/pkcs1_v2.py index acf4869..a4aa534 100644 --- a/rsa/pkcs1_v2.py +++ b/rsa/pkcs1_v2.py @@ -100,21 +100,27 @@ def _OAEP_encode(message, keylength, label, hash_method, mgf1_hash_method): hasher = pkcs1.HASH_METHODS[hash_method](label) except KeyError: raise ValueError( - 'Invalid `hash_method` specified. Please select one of: {hash_list}'.format( - hash_list=', '.join(sorted(pkcs1.HASH_METHODS.keys())) + "Invalid `hash_method` specified. Please select one of: {hash_list}".format( + hash_list=", ".join(sorted(pkcs1.HASH_METHODS.keys())) ) ) hash_length = hasher.digest_size - max_message_length = keylength - 2*hash_length - 2 + max_message_length = keylength - 2 * hash_length - 2 message_length = len(message) if message_length > max_message_length: - raise OverflowError("message is too long; at most %s bytes, given %s bytes" % - (max_message_length, len(message))) + raise OverflowError( + "message is too long; at most %s bytes, given %s bytes" + % (max_message_length, len(message)) + ) lhash = hasher.digest() - ps = bytearray(keylength - message_length - 2*hash_length -2) - db = hasher.digest() + b'\0' * (keylength - message_length - 2*hash_length - 2) \ - + b'\x01' + message + ps = bytearray(keylength - message_length - 2 * hash_length - 2) + db = ( + hasher.digest() + + b"\0" * (keylength - message_length - 2 * hash_length - 2) + + b"\x01" + + message + ) seed = os.urandom(hash_length) db_mask = mgf1(seed, keylength - hash_length - 1, mgf1_hash_method) @@ -123,12 +129,11 @@ def _OAEP_encode(message, keylength, label, hash_method, mgf1_hash_method): seed_mask = mgf1(masked_db, hash_length, mgf1_hash_method) masked_seed = xor_bytes(seed, seed_mask) - em = b'\x00' + masked_seed + masked_db + em = b"\x00" + masked_seed + masked_db return em -def encrypt_OAEP(message, pub_key, label=b'', hash_method="SHA-1", - mgf1_hash_method=None): +def encrypt_OAEP(message, pub_key, label=b"", hash_method="SHA-1", mgf1_hash_method=None): """Encrypts the given message using PKCS#1 v2 RSA-OEAP. :param bytes message: the message to encrypt. @@ -152,8 +157,7 @@ def encrypt_OAEP(message, pub_key, label=b'', hash_method="SHA-1", return c -def decrypt_OAEP(crypto, priv_key, label=b'', hash_method="SHA-1", - mgf1_hash_method=None): +def decrypt_OAEP(crypto, priv_key, label=b"", hash_method="SHA-1", mgf1_hash_method=None): """Decrypts the givem crypto using PKCS#1 v2 RSA-OAEP. :param bytes crypto: the crypto text as returned by :py:func:`rsa.encrypt` @@ -207,7 +211,7 @@ def decrypt_OAEP(crypto, priv_key, label=b'', hash_method="SHA-1", # todo: Step 1: length checking k = common.byte_size(priv_key.n) if k != len(crypto): - raise pkcs1.DecryptionError('Decryption failed') + raise pkcs1.DecryptionError("Decryption failed") # Step 2: RSA Decryption c = transform.bytes2int(crypto) @@ -219,20 +223,20 @@ def decrypt_OAEP(crypto, priv_key, label=b'', hash_method="SHA-1", hasher = pkcs1.HASH_METHODS[hash_method](label) except KeyError: raise ValueError( - 'Invalid `hash_method` specified. Please select one of: {hash_list}'.format( - hash_list=', '.join(sorted(pkcs1.HASH_METHODS.keys())) + "Invalid `hash_method` specified. Please select one of: {hash_list}".format( + hash_list=", ".join(sorted(pkcs1.HASH_METHODS.keys())) ) ) hash_length = hasher.digest_size lhash = hasher.digest() Y = em[0:1] - masked_seed = em[1:1+hash_length] - masked_db = em[1+hash_length:] + masked_seed = em[1 : 1 + hash_length] + masked_db = em[1 + hash_length :] seed_mask = mgf1(masked_db, hash_length, mgf1_hash_method) seed = xor_bytes(masked_seed, seed_mask) - db_mask = mgf1(seed, k-hash_length-1, mgf1_hash_method) + db_mask = mgf1(seed, k - hash_length - 1, mgf1_hash_method) db = xor_bytes(masked_db, db_mask) lhash_ = db[:hash_length] @@ -253,13 +257,15 @@ def decrypt_OAEP(crypto, priv_key, label=b'', hash_method="SHA-1", invalid = _constant_time_select(looking_one & ~iszero, 1, invalid) if invalid | looking_one | (not hash_is_good): - raise pkcs1.DecryptionError('Decryption failed') + raise pkcs1.DecryptionError("Decryption failed") - return rest[index+1:] + return rest[index + 1 :] __all__ = [ - 'mgf1', 'encrypt_OAEP', 'decrypt_OAEP', + "mgf1", + "encrypt_OAEP", + "decrypt_OAEP", ] if __name__ == "__main__": diff --git a/tests/test_pkcs1_v2.py b/tests/test_pkcs1_v2.py index d605157..eee1cbc 100644 --- a/tests/test_pkcs1_v2.py +++ b/tests/test_pkcs1_v2.py @@ -80,7 +80,7 @@ def test_invalid_hasher(self): def test_invalid_length(self): with self.assertRaises(OverflowError): - pkcs1_v2.mgf1(b"\x06\xe1\xde\xb2", length=2**50) + pkcs1_v2.mgf1(b"\x06\xe1\xde\xb2", length=2 ** 50) class BinaryTest(unittest.TestCase): @@ -88,7 +88,7 @@ def setUp(self): (self.pub, self.priv) = rsa.newkeys(512) def test_enc_dec(self): - message = struct.pack('>IIII', 0, 0, 0, 1) + message = struct.pack(">IIII", 0, 0, 0, 1) print("\tMessage: %r" % message) encrypted = pkcs1_v2.encrypt_OAEP(message, self.pub) @@ -100,7 +100,7 @@ def test_enc_dec(self): self.assertEqual(message, decrypted) def test_decoding_failure(self): - message = struct.pack('>IIII', 0, 0, 0, 1) + message = struct.pack(">IIII", 0, 0, 0, 1) encrypted = pkcs1_v2.encrypt_OAEP(message, self.pub) # Alter the encrypted stream @@ -108,15 +108,14 @@ def test_decoding_failure(self): altered_a = (a + 1) % 256 encrypted = encrypted[:5] + byte(altered_a) + encrypted[6:] - self.assertRaises(DecryptionError, pkcs1_v2.decrypt_OAEP, - encrypted, self.priv) + self.assertRaises(DecryptionError, pkcs1_v2.decrypt_OAEP, encrypted, self.priv) def test_randomness(self): """Encrypting the same message twice should result in different cryptos. """ - message = struct.pack('>IIII', 0, 0, 0, 1) + message = struct.pack(">IIII", 0, 0, 0, 1) encrypted1 = pkcs1_v2.encrypt_OAEP(message, self.pub) encrypted2 = pkcs1_v2.encrypt_OAEP(message, self.pub) From fd273d04bc84f2d8f6a3758a9992ebcc914e7066 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Fri, 11 Jun 2021 12:35:11 +0900 Subject: [PATCH 5/5] format --- rsa/pkcs1.py | 2 +- rsa/pkcs1_v2.py | 52 ++++++++++++++++++++++++++++++------------------- 2 files changed, 33 insertions(+), 21 deletions(-) diff --git a/rsa/pkcs1.py b/rsa/pkcs1.py index 5992c7f..6ef477d 100644 --- a/rsa/pkcs1.py +++ b/rsa/pkcs1.py @@ -49,7 +49,7 @@ "SHA-512": b"\x30\x51\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x03\x05\x00\x04\x40", } -HASH_METHODS: typing.Dict[str, typing.Callable[[], HashType]] = { +HASH_METHODS: typing.Dict[str, typing.Callable[..., HashType]] = { "MD5": hashlib.md5, "SHA-1": hashlib.sha1, "SHA-224": hashlib.sha224, diff --git a/rsa/pkcs1_v2.py b/rsa/pkcs1_v2.py index a4aa534..c956155 100644 --- a/rsa/pkcs1_v2.py +++ b/rsa/pkcs1_v2.py @@ -20,16 +20,12 @@ import os from hmac import compare_digest -from rsa import ( - common, - core, - pkcs1, - transform, -) -from rsa._compat import xor_bytes +from . import common, transform, core, key, pkcs1 +from ._compat import xor_bytes -def _constant_time_select(v, t, f): + +def _constant_time_select(v: int, t: int, f: int) -> int: """Return t if v else f. v must be 0 or 1. (False and True are allowed) @@ -95,7 +91,9 @@ def mgf1(seed: bytes, length: int, hasher: str = "SHA-1") -> bytes: return output[:length] -def _OAEP_encode(message, keylength, label, hash_method, mgf1_hash_method): +def _OAEP_encode( + message: bytes, keylength: int, label, hash_method: str, mgf1_hash_method: str +) -> bytes: try: hasher = pkcs1.HASH_METHODS[hash_method](label) except KeyError: @@ -133,14 +131,22 @@ def _OAEP_encode(message, keylength, label, hash_method, mgf1_hash_method): return em -def encrypt_OAEP(message, pub_key, label=b"", hash_method="SHA-1", mgf1_hash_method=None): +def encrypt_OAEP( + message: bytes, + pub_key: key.PublicKey, + label: bytes = b"", + hash_method: str = "SHA-1", + mgf1_hash_method: str = None, +) -> bytes: """Encrypts the given message using PKCS#1 v2 RSA-OEAP. - :param bytes message: the message to encrypt. - :param rsa.PublicKey pub_key: the public key to encrypt with. - :param bytes label: optional RSA-OAEP label. - :param str hash_method: hash function to be used. 'SHA-1' (default), + :param message: the message to encrypt. + :param pub_key: the public key to encrypt with. + :param label: optional RSA-OAEP label. + :param hash_method: hash function to be used. 'SHA-1' (default), 'SHA-256', 'SHA-384', and 'SHA-512' can be used. + :param mgf1_hash_method: hash function to be used by MGF1 function. + If it is None (default), *hash_method* is used. """ # NOTE: Some hash method other than listed in the docstring can be used # for hash_method. But the RFC 8017 recommends only them. @@ -157,15 +163,21 @@ def encrypt_OAEP(message, pub_key, label=b"", hash_method="SHA-1", mgf1_hash_met return c -def decrypt_OAEP(crypto, priv_key, label=b"", hash_method="SHA-1", mgf1_hash_method=None): +def decrypt_OAEP( + crypto: bytes, + priv_key: key.PrivateKey, + label: bytes = b"", + hash_method: str = "SHA-1", + mgf1_hash_method: str = None, +) -> bytes: """Decrypts the givem crypto using PKCS#1 v2 RSA-OAEP. - :param bytes crypto: the crypto text as returned by :py:func:`rsa.encrypt` - :param rsa.PrivateKey priv_key: the private key to decrypt with. - :param bytes label: optional RSA-OAEP label. - :param str hash_method: hash function to be used. 'SHA-1' (default), + :param crypto: the crypto text as returned by :py:func:`rsa.encrypt` + :param priv_key: the private key to decrypt with. + :param label: optional RSA-OAEP label. + :param hash_method: hash function to be used. 'SHA-1' (default), 'SHA-256', 'SHA-384', and 'SHA-512' can be used. - :param str mgf1_hash_method: hash function to be used by MGF1 function. + :param mgf1_hash_method: hash function to be used by MGF1 function. If it is None (default), *hash_method* is used. :raise rsa.pkcs1.DecryptionError: when the decryption fails. No details are given as