Skip to content

Commit

Permalink
Convert AESGCM AEAD to Rust
Browse files Browse the repository at this point in the history
  • Loading branch information
alex committed Oct 28, 2023
1 parent 3b39f65 commit 7134cc7
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 65 deletions.
3 changes: 1 addition & 2 deletions src/cryptography/hazmat/backends/openssl/aead.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@
from cryptography.hazmat.backends.openssl.backend import Backend
from cryptography.hazmat.primitives.ciphers.aead import (
AESCCM,
AESGCM,
ChaCha20Poly1305,
)

_AEADTypes = typing.Union[AESCCM, AESGCM, ChaCha20Poly1305]
_AEADTypes = typing.Union[AESCCM, ChaCha20Poly1305]


def _is_evp_aead_supported_cipher(
Expand Down
17 changes: 17 additions & 0 deletions src/cryptography/hazmat/bindings/_rust/openssl/aead.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,23 @@
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
# for complete details.

class AESGCM:
def __init__(self, key: bytes) -> None: ...
@staticmethod
def generate_key(key_size: int) -> bytes: ...
def encrypt(
self,
nonce: bytes,
data: bytes,
associated_data: bytes | None,
) -> bytes: ...
def decrypt(
self,
nonce: bytes,
data: bytes,
associated_data: bytes | None,
) -> bytes: ...

class AESSIV:
def __init__(self, key: bytes) -> None: ...
@staticmethod
Expand Down
64 changes: 1 addition & 63 deletions src/cryptography/hazmat/primitives/ciphers/aead.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"AESSIV",
]

AESGCM = rust_openssl.aead.AESGCM
AESSIV = rust_openssl.aead.AESSIV
AESOCB3 = rust_openssl.aead.AESOCB3

Expand Down Expand Up @@ -180,66 +181,3 @@ def _check_params(
utils._check_byteslike("associated_data", associated_data)
if not 7 <= len(nonce) <= 13:
raise ValueError("Nonce must be between 7 and 13 bytes")


class AESGCM:
_MAX_SIZE = 2**31 - 1

def __init__(self, key: bytes):
utils._check_byteslike("key", key)
if len(key) not in (16, 24, 32):
raise ValueError("AESGCM key must be 128, 192, or 256 bits.")

self._key = key

@classmethod
def generate_key(cls, bit_length: int) -> bytes:
if not isinstance(bit_length, int):
raise TypeError("bit_length must be an integer")

if bit_length not in (128, 192, 256):
raise ValueError("bit_length must be 128, 192, or 256")

return os.urandom(bit_length // 8)

def encrypt(
self,
nonce: bytes,
data: bytes,
associated_data: bytes | None,
) -> bytes:
if associated_data is None:
associated_data = b""

if len(data) > self._MAX_SIZE or len(associated_data) > self._MAX_SIZE:
# This is OverflowError to match what cffi would raise
raise OverflowError(
"Data or associated data too long. Max 2**31 - 1 bytes"
)

self._check_params(nonce, data, associated_data)
return aead._encrypt(backend, self, nonce, data, [associated_data], 16)

def decrypt(
self,
nonce: bytes,
data: bytes,
associated_data: bytes | None,
) -> bytes:
if associated_data is None:
associated_data = b""

self._check_params(nonce, data, associated_data)
return aead._decrypt(backend, self, nonce, data, [associated_data], 16)

def _check_params(
self,
nonce: bytes,
data: bytes,
associated_data: bytes,
) -> None:
utils._check_byteslike("nonce", nonce)
utils._check_byteslike("data", data)
utils._check_byteslike("associated_data", associated_data)
if len(nonce) < 8 or len(nonce) > 128:
raise ValueError("Nonce must be between 8 and 128 bytes")
87 changes: 87 additions & 0 deletions src/rust/src/backend/aead.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,92 @@ impl EvpCipherAead {
}
}

#[pyo3::prelude::pyclass(
frozen,
module = "cryptography.hazmat.bindings._rust.openssl.aead",
name = "AESGCM"
)]
struct AesGcm {
ctx: EvpCipherAead,
}

#[pyo3::prelude::pymethods]
impl AesGcm {
#[new]
fn new(py: pyo3::Python<'_>, key: pyo3::Py<pyo3::PyAny>) -> CryptographyResult<AesGcm> {
let key_buf = key.extract::<CffiBuf<'_>>(py)?;
let cipher = match key_buf.as_bytes().len() {
16 => openssl::cipher::Cipher::aes_128_gcm(),
24 => openssl::cipher::Cipher::aes_192_gcm(),
32 => openssl::cipher::Cipher::aes_256_gcm(),
_ => {
return Err(CryptographyError::from(
pyo3::exceptions::PyValueError::new_err(
"AESGCM key must be 128, 192, or 256 bits.",
),
))
}
};

Ok(AesGcm {
ctx: EvpCipherAead::new(cipher, key_buf.as_bytes(), 16, false)?,
})
}

#[staticmethod]
fn generate_key(py: pyo3::Python<'_>, bit_length: usize) -> CryptographyResult<&pyo3::PyAny> {
if bit_length != 128 && bit_length != 192 && bit_length != 256 {
return Err(CryptographyError::from(
pyo3::exceptions::PyValueError::new_err("bit_length must be 128, 192, or 256"),
));
}

Ok(py
.import(pyo3::intern!(py, "os"))?
.call_method1(pyo3::intern!(py, "urandom"), (bit_length / 8,))?)
}

fn encrypt<'p>(
&self,
py: pyo3::Python<'p>,
nonce: CffiBuf<'_>,
data: CffiBuf<'_>,
associated_data: Option<CffiBuf<'_>>,
) -> CryptographyResult<&'p pyo3::types::PyBytes> {
let nonce_bytes = nonce.as_bytes();
let aad = associated_data.map(Aad::Single);

if nonce_bytes.len() < 8 || nonce_bytes.len() > 128 {
return Err(CryptographyError::from(
pyo3::exceptions::PyValueError::new_err("Nonce must be between 8 and 128 bytes"),
));
}

self.ctx
.encrypt(py, data.as_bytes(), aad, Some(nonce_bytes))
}

fn decrypt<'p>(
&self,
py: pyo3::Python<'p>,
nonce: CffiBuf<'_>,
data: CffiBuf<'_>,
associated_data: Option<CffiBuf<'_>>,
) -> CryptographyResult<&'p pyo3::types::PyBytes> {
let nonce_bytes = nonce.as_bytes();
let aad = associated_data.map(Aad::Single);

if nonce_bytes.len() < 8 || nonce_bytes.len() > 128 {
return Err(CryptographyError::from(
pyo3::exceptions::PyValueError::new_err("Nonce must be between 8 and 128 bytes"),
));
}

self.ctx
.decrypt(py, data.as_bytes(), aad, Some(nonce_bytes))
}
}

#[pyo3::prelude::pyclass(
frozen,
module = "cryptography.hazmat.bindings._rust.openssl.aead",
Expand Down Expand Up @@ -413,6 +499,7 @@ impl AesOcb3 {
pub(crate) fn create_module(py: pyo3::Python<'_>) -> pyo3::PyResult<&pyo3::prelude::PyModule> {
let m = pyo3::prelude::PyModule::new(py, "aead")?;

m.add_class::<AesGcm>()?;
m.add_class::<AesSiv>()?;
m.add_class::<AesOcb3>()?;

Expand Down

0 comments on commit 7134cc7

Please sign in to comment.