Skip to content

Commit

Permalink
Convert src/backend/dh.rs to new pyo3 APIs (#10714)
Browse files Browse the repository at this point in the history
  • Loading branch information
alex authored Apr 4, 2024
1 parent f284aee commit 98e6fd4
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 29 deletions.
59 changes: 31 additions & 28 deletions src/rust/src/backend/dh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::asn1::encode_der_data;
use crate::backend::utils;
use crate::error::{CryptographyError, CryptographyResult};
use crate::{types, x509};
use pyo3::prelude::PyAnyMethods;
use pyo3::prelude::{PyAnyMethods, PyModuleMethods};

const MIN_MODULUS_SIZE: u32 = 512;

Expand All @@ -31,7 +31,7 @@ struct DHParameters {
fn generate_parameters(
generator: u32,
key_size: u32,
backend: Option<&pyo3::PyAny>,
backend: Option<pyo3::Bound<'_, pyo3::PyAny>>,
) -> CryptographyResult<DHParameters> {
let _ = backend;

Expand Down Expand Up @@ -89,7 +89,7 @@ fn pkey_from_dh<T: openssl::pkey::HasParams>(
#[pyo3::prelude::pyfunction]
fn from_der_parameters(
data: &[u8],
backend: Option<&pyo3::PyAny>,
backend: Option<pyo3::Bound<'_, pyo3::PyAny>>,
) -> CryptographyResult<DHParameters> {
let _ = backend;
let asn1_params = asn1::parse_single::<common::DHParams<'_>>(data)?;
Expand All @@ -109,7 +109,7 @@ fn from_der_parameters(
#[pyo3::prelude::pyfunction]
fn from_pem_parameters(
data: &[u8],
backend: Option<&pyo3::PyAny>,
backend: Option<pyo3::Bound<'_, pyo3::PyAny>>,
) -> CryptographyResult<DHParameters> {
let _ = backend;
let parsed = x509::find_in_pem(
Expand Down Expand Up @@ -156,13 +156,14 @@ impl DHPrivateKey {
&self,
py: pyo3::Python<'p>,
peer_public_key: &DHPublicKey,
) -> CryptographyResult<&'p pyo3::types::PyBytes> {
) -> CryptographyResult<pyo3::Bound<'p, pyo3::types::PyBytes>> {
let mut deriver = openssl::derive::Deriver::new(&self.pkey)?;
deriver
.set_peer(&peer_public_key.pkey)
.map_err(|_| pyo3::exceptions::PyValueError::new_err("Error computing shared key."))?;

Ok(pyo3::types::PyBytes::new_with(py, deriver.len()?, |b| {
let len = deriver.len()?;
Ok(pyo3::types::PyBytes::new_bound_with(py, len, |b| {
let n = deriver.derive(b).unwrap();

let pad = b.len() - n;
Expand Down Expand Up @@ -341,8 +342,8 @@ impl DHParameters {
fn parameter_bytes<'p>(
&self,
py: pyo3::Python<'p>,
encoding: &'p pyo3::PyAny,
format: &pyo3::PyAny,
encoding: pyo3::Bound<'p, pyo3::PyAny>,
format: pyo3::Bound<'p, pyo3::PyAny>,
) -> CryptographyResult<pyo3::Bound<'p, pyo3::types::PyBytes>> {
if !format.is(types::PARAMETER_FORMAT_PKCS3.get(py)?) {
return Err(CryptographyError::from(
Expand All @@ -368,7 +369,7 @@ impl DHParameters {
} else {
"X9.42 DH PARAMETERS"
};
encode_der_data(py, tag.to_string(), data, encoding)
encode_der_data(py, tag.to_string(), data, encoding.into_gil_ref())
}
}

Expand Down Expand Up @@ -412,7 +413,7 @@ impl DHPrivateNumbers {
fn private_key(
&self,
py: pyo3::Python<'_>,
backend: Option<&pyo3::PyAny>,
backend: Option<pyo3::Bound<'_, pyo3::PyAny>>,
) -> CryptographyResult<DHPrivateKey> {
let _ = backend;

Expand All @@ -439,11 +440,11 @@ impl DHPrivateNumbers {
py: pyo3::Python<'_>,
other: pyo3::PyRef<'_, Self>,
) -> CryptographyResult<bool> {
Ok(self.x.as_ref(py).eq(other.x.as_ref(py))?
Ok(self.x.bind(py).eq(other.x.bind(py))?
&& self
.public_numbers
.as_ref(py)
.eq(other.public_numbers.as_ref(py))?)
.bind(py)
.eq(other.public_numbers.bind(py))?)
}
}

Expand All @@ -464,7 +465,7 @@ impl DHPublicNumbers {
fn public_key(
&self,
py: pyo3::Python<'_>,
backend: Option<&pyo3::PyAny>,
backend: Option<pyo3::Bound<'_, pyo3::PyAny>>,
) -> CryptographyResult<DHPublicKey> {
let _ = backend;

Expand All @@ -482,11 +483,11 @@ impl DHPublicNumbers {
py: pyo3::Python<'_>,
other: pyo3::PyRef<'_, Self>,
) -> CryptographyResult<bool> {
Ok(self.y.as_ref(py).eq(other.y.as_ref(py))?
Ok(self.y.bind(py).eq(other.y.bind(py))?
&& self
.parameter_numbers
.as_ref(py)
.eq(other.parameter_numbers.as_ref(py))?)
.bind(py)
.eq(other.parameter_numbers.bind(py))?)
}
}

Expand All @@ -499,13 +500,13 @@ impl DHParameterNumbers {
g: pyo3::Py<pyo3::types::PyLong>,
q: Option<pyo3::Py<pyo3::types::PyLong>>,
) -> CryptographyResult<DHParameterNumbers> {
if g.as_ref(py).lt(2)? {
if g.bind(py).lt(2)? {
return Err(CryptographyError::from(
pyo3::exceptions::PyValueError::new_err("DH generator must be 2 or greater"),
));
}

if p.as_ref(py)
if p.bind(py)
.call_method0("bit_length")?
.lt(MIN_MODULUS_SIZE)?
{
Expand All @@ -522,7 +523,7 @@ impl DHParameterNumbers {
fn parameters(
&self,
py: pyo3::Python<'_>,
backend: Option<&pyo3::PyAny>,
backend: Option<pyo3::Bound<'_, pyo3::PyAny>>,
) -> CryptographyResult<DHParameters> {
let _ = backend;

Expand All @@ -536,21 +537,23 @@ impl DHParameterNumbers {
other: pyo3::PyRef<'_, Self>,
) -> CryptographyResult<bool> {
let q_equal = match (self.q.as_ref(), other.q.as_ref()) {
(Some(self_q), Some(other_q)) => self_q.as_ref(py).eq(other_q.as_ref(py))?,
(Some(self_q), Some(other_q)) => self_q.bind(py).eq(other_q.bind(py))?,
(None, None) => true,
_ => false,
};
Ok(self.p.as_ref(py).eq(other.p.as_ref(py))?
&& self.g.as_ref(py).eq(other.g.as_ref(py))?
Ok(self.p.bind(py).eq(other.p.bind(py))?
&& self.g.bind(py).eq(other.g.bind(py))?
&& q_equal)
}
}

pub(crate) fn create_module(py: pyo3::Python<'_>) -> pyo3::PyResult<&pyo3::prelude::PyModule> {
let m = pyo3::prelude::PyModule::new(py, "dh")?;
m.add_function(pyo3::wrap_pyfunction!(generate_parameters, m)?)?;
m.add_function(pyo3::wrap_pyfunction!(from_der_parameters, m)?)?;
m.add_function(pyo3::wrap_pyfunction!(from_pem_parameters, m)?)?;
pub(crate) fn create_module(
py: pyo3::Python<'_>,
) -> pyo3::PyResult<pyo3::Bound<'_, pyo3::prelude::PyModule>> {
let m = pyo3::prelude::PyModule::new_bound(py, "dh")?;
m.add_function(pyo3::wrap_pyfunction_bound!(generate_parameters, &m)?)?;
m.add_function(pyo3::wrap_pyfunction_bound!(from_der_parameters, &m)?)?;
m.add_function(pyo3::wrap_pyfunction_bound!(from_pem_parameters, &m)?)?;

m.add_class::<DHPrivateKey>()?;
m.add_class::<DHPublicKey>()?;
Expand Down
2 changes: 1 addition & 1 deletion src/rust/src/backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ pub(crate) fn add_to_module(module: &pyo3::prelude::PyModule) -> pyo3::PyResult<
module.add_submodule(aead::create_module(module.py())?.into_gil_ref())?;
module.add_submodule(ciphers::create_module(module.py())?.into_gil_ref())?;
module.add_submodule(cmac::create_module(module.py())?)?;
module.add_submodule(dh::create_module(module.py())?)?;
module.add_submodule(dh::create_module(module.py())?.into_gil_ref())?;
module.add_submodule(dsa::create_module(module.py())?)?;
module.add_submodule(ec::create_module(module.py())?)?;
module.add_submodule(keys::create_module(module.py())?)?;
Expand Down

0 comments on commit 98e6fd4

Please sign in to comment.