Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support OAEP encryption schema #126

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion rsa/pkcs1.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
"SHA-512": b"\x30\x51\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x03\x05\x00\x04\x40",
}

HASH_METHODS: typing.Dict[str, typing.Callable[[], HashType]] = {
HASH_METHODS: typing.Dict[str, typing.Callable[..., HashType]] = {
"MD5": hashlib.md5,
"SHA-1": hashlib.sha1,
"SHA-224": hashlib.sha224,
Expand Down
207 changes: 201 additions & 6 deletions rsa/pkcs1_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,24 @@
"""Functions for PKCS#1 version 2 encryption and signing

This module implements certain functionality from PKCS#1 version 2. Main
documentation is RFC 2437: https://tools.ietf.org/html/rfc2437
documentation is RFC 8017: https://tools.ietf.org/html/rfc8017
"""

from rsa import (
common,
pkcs1,
transform,
)
import os
from hmac import compare_digest

from . import common, transform, core, key, pkcs1
from ._compat import xor_bytes


def _constant_time_select(v: int, t: int, f: int) -> int:
"""Return t if v else f.

v must be 0 or 1. (False and True are allowed)
t and f are integer between 0 and 255.
"""
v -= 1
return (~v & t) | (v & f)


def mgf1(seed: bytes, length: int, hasher: str = "SHA-1") -> bytes:
Expand Down Expand Up @@ -81,8 +91,193 @@ def mgf1(seed: bytes, length: int, hasher: str = "SHA-1") -> bytes:
return output[:length]


def _OAEP_encode(
message: bytes, keylength: int, label, hash_method: str, mgf1_hash_method: str
) -> bytes:
try:
hasher = pkcs1.HASH_METHODS[hash_method](label)
except KeyError:
raise ValueError(
"Invalid `hash_method` specified. Please select one of: {hash_list}".format(
hash_list=", ".join(sorted(pkcs1.HASH_METHODS.keys()))
)
)
hash_length = hasher.digest_size
max_message_length = keylength - 2 * hash_length - 2
message_length = len(message)
if message_length > max_message_length:
raise OverflowError(
"message is too long; at most %s bytes, given %s bytes"
% (max_message_length, len(message))
)

lhash = hasher.digest()
ps = bytearray(keylength - message_length - 2 * hash_length - 2)
db = (
hasher.digest()
+ b"\0" * (keylength - message_length - 2 * hash_length - 2)
+ b"\x01"
+ message
)

seed = os.urandom(hash_length)
db_mask = mgf1(seed, keylength - hash_length - 1, mgf1_hash_method)
masked_db = xor_bytes(db, db_mask)

seed_mask = mgf1(masked_db, hash_length, mgf1_hash_method)
masked_seed = xor_bytes(seed, seed_mask)

em = b"\x00" + masked_seed + masked_db
return em


def encrypt_OAEP(
message: bytes,
pub_key: key.PublicKey,
label: bytes = b"",
hash_method: str = "SHA-1",
mgf1_hash_method: str = None,
) -> bytes:
"""Encrypts the given message using PKCS#1 v2 RSA-OEAP.

:param message: the message to encrypt.
:param pub_key: the public key to encrypt with.
:param label: optional RSA-OAEP label.
:param hash_method: hash function to be used. 'SHA-1' (default),
'SHA-256', 'SHA-384', and 'SHA-512' can be used.
:param mgf1_hash_method: hash function to be used by MGF1 function.
If it is None (default), *hash_method* is used.
"""
# NOTE: Some hash method other than listed in the docstring can be used
# for hash_method. But the RFC 8017 recommends only them.
if mgf1_hash_method is None:
mgf1_hash_method = hash_method
keylength = common.byte_size(pub_key.n)

em = _OAEP_encode(message, keylength, label, hash_method, mgf1_hash_method)

m = transform.bytes2int(em)
encrypted = core.encrypt_int(m, pub_key.e, pub_key.n)
c = transform.int2bytes(encrypted, keylength)

return c


def decrypt_OAEP(
crypto: bytes,
priv_key: key.PrivateKey,
label: bytes = b"",
hash_method: str = "SHA-1",
mgf1_hash_method: str = None,
) -> bytes:
"""Decrypts the givem crypto using PKCS#1 v2 RSA-OAEP.

:param crypto: the crypto text as returned by :py:func:`rsa.encrypt`
:param priv_key: the private key to decrypt with.
:param label: optional RSA-OAEP label.
:param hash_method: hash function to be used. 'SHA-1' (default),
'SHA-256', 'SHA-384', and 'SHA-512' can be used.
:param mgf1_hash_method: hash function to be used by MGF1 function.
If it is None (default), *hash_method* is used.

:raise rsa.pkcs1.DecryptionError: when the decryption fails. No details are given as
to why the code thinks the decryption fails, as this would leak
information about the private key.

>>> import rsa
>>> (pub_key, priv_key) = rsa.newkeys(512)

It works with binary data:

>>> crypto = encrypt_OAEP(b'hello', pub_key)
>>> decrypt_OAEP(crypto, priv_key)
b'hello'

You can pass optional label data too:

>>> crypto = encrypt_OAEP(b'hello', pub_key, label=b'world')
>>> decrypt_OAEP(crypto, priv_key, label=b'world')
b'hello'

Altering the encrypted information will cause a
:py:class:`rsa.pkcs1.DecryptionError`.

>>> crypto = encrypt_OAEP(b'hello', pub_key)
>>> crypto = crypto[0:5] + bytes([(ord(crypto[5:6])+1)%256]) + crypto[6:] # change a byte
>>> decrypt_OAEP(crypto, priv_key)
Traceback (most recent call last):
...
rsa.pkcs1.DecryptionError: Decryption failed

Changing label will also cause the error.

>>> crypto = encrypt_OAEP(b'hello', pub_key, label=b'world')
>>> decrypt_OAEP(crypto, priv_key, label=b'universe')
Traceback (most recent call last):
...
rsa.pkcs1.DecryptionError: Decryption failed
"""
if mgf1_hash_method is None:
mgf1_hash_method = hash_method

# todo: Step 1: length checking
k = common.byte_size(priv_key.n)
if k != len(crypto):
raise pkcs1.DecryptionError("Decryption failed")

# Step 2: RSA Decryption
c = transform.bytes2int(crypto)
m = priv_key.blinded_decrypt(c)
em = transform.int2bytes(m, k)

# Step 3: EME-OAEP decoding
try:
hasher = pkcs1.HASH_METHODS[hash_method](label)
except KeyError:
raise ValueError(
"Invalid `hash_method` specified. Please select one of: {hash_list}".format(
hash_list=", ".join(sorted(pkcs1.HASH_METHODS.keys()))
)
)
hash_length = hasher.digest_size
lhash = hasher.digest()
Y = em[0:1]
masked_seed = em[1 : 1 + hash_length]
masked_db = em[1 + hash_length :]

seed_mask = mgf1(masked_db, hash_length, mgf1_hash_method)
seed = xor_bytes(masked_seed, seed_mask)

db_mask = mgf1(seed, k - hash_length - 1, mgf1_hash_method)
db = xor_bytes(masked_db, db_mask)

lhash_ = db[:hash_length]
rest = db[hash_length:]

# NOTE: Take care about timing attack. See note in the RFC.
hash_is_good = compare_digest(lhash, lhash_)

index = invalid = 0
looking_one = 1

for i, c in enumerate(rest):
iszero = c == 0
isone = c == 1

index = _constant_time_select(looking_one & isone, i, index)
looking_one = _constant_time_select(isone, 0, looking_one)
invalid = _constant_time_select(looking_one & ~iszero, 1, invalid)

if invalid | looking_one | (not hash_is_good):
raise pkcs1.DecryptionError("Decryption failed")

return rest[index + 1 :]


__all__ = [
"mgf1",
"encrypt_OAEP",
"decrypt_OAEP",
]

if __name__ == "__main__":
Expand Down
43 changes: 43 additions & 0 deletions tests/test_pkcs1_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,13 @@
http://www.itomorrowmag.com/emc-plus/rsa-labs/standards-initiatives/pkcs-rsa-cryptography-standard.htm
"""

import struct
import unittest

import rsa
from rsa import pkcs1_v2
from rsa._compat import byte
from rsa.pkcs1 import DecryptionError


class MGFTest(unittest.TestCase):
Expand Down Expand Up @@ -77,3 +81,42 @@ def test_invalid_hasher(self):
def test_invalid_length(self):
with self.assertRaises(OverflowError):
pkcs1_v2.mgf1(b"\x06\xe1\xde\xb2", length=2 ** 50)


class BinaryTest(unittest.TestCase):
def setUp(self):
(self.pub, self.priv) = rsa.newkeys(512)

def test_enc_dec(self):
message = struct.pack(">IIII", 0, 0, 0, 1)
print("\tMessage: %r" % message)

encrypted = pkcs1_v2.encrypt_OAEP(message, self.pub)
print("\tEncrypted: %r" % encrypted)

decrypted = pkcs1_v2.decrypt_OAEP(encrypted, self.priv)
print("\tDecrypted: %r" % decrypted)

self.assertEqual(message, decrypted)

def test_decoding_failure(self):
message = struct.pack(">IIII", 0, 0, 0, 1)
encrypted = pkcs1_v2.encrypt_OAEP(message, self.pub)

# Alter the encrypted stream
a = encrypted[5]
altered_a = (a + 1) % 256
encrypted = encrypted[:5] + byte(altered_a) + encrypted[6:]

self.assertRaises(DecryptionError, pkcs1_v2.decrypt_OAEP, encrypted, self.priv)

def test_randomness(self):
"""Encrypting the same message twice should result in different
cryptos.
"""

message = struct.pack(">IIII", 0, 0, 0, 1)
encrypted1 = pkcs1_v2.encrypt_OAEP(message, self.pub)
encrypted2 = pkcs1_v2.encrypt_OAEP(message, self.pub)

self.assertNotEqual(encrypted1, encrypted2)