Skip to content

Commit

Permalink
fixes #11453 -- include localKeyID when serializaing a key with a cert (
Browse files Browse the repository at this point in the history
  • Loading branch information
alex authored Aug 23, 2024
1 parent 041ef8b commit b5a312f
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 22 deletions.
4 changes: 4 additions & 0 deletions src/rust/cryptography-x509/src/pkcs12.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub const SHROUDED_KEY_BAG_OID: asn1::ObjectIdentifier =
asn1::oid!(1, 2, 840, 113549, 1, 12, 10, 1, 2);
pub const X509_CERTIFICATE_OID: asn1::ObjectIdentifier = asn1::oid!(1, 2, 840, 113549, 1, 9, 22, 1);
pub const FRIENDLY_NAME_OID: asn1::ObjectIdentifier = asn1::oid!(1, 2, 840, 113549, 1, 9, 20);
pub const LOCAL_KEY_ID_OID: asn1::ObjectIdentifier = asn1::oid!(1, 2, 840, 113549, 1, 9, 21);

#[derive(asn1::Asn1Write)]
pub struct Pfx<'a> {
Expand Down Expand Up @@ -46,6 +47,9 @@ pub struct Attribute<'a> {
pub enum AttributeSet<'a> {
#[defined_by(FRIENDLY_NAME_OID)]
FriendlyName(asn1::SetOfWriter<'a, Utf8StoredBMPString<'a>, [Utf8StoredBMPString<'a>; 1]>),

#[defined_by(LOCAL_KEY_ID_OID)]
LocalKeyId(asn1::SetOfWriter<'a, &'a [u8], [&'a [u8]; 1]>),
}

#[derive(asn1::Asn1DefinedByWrite)]
Expand Down
60 changes: 41 additions & 19 deletions src/rust/src/pkcs12.rs
Original file line number Diff line number Diff line change
Expand Up @@ -338,38 +338,51 @@ fn pkcs12_kdf(
Ok(result)
}

fn friendly_name_attributes(
friendly_name: Option<&[u8]>,
fn pkcs12_attributes<'a>(
friendly_name: Option<&'a [u8]>,
local_key_id: Option<&'a [u8]>,
) -> CryptographyResult<
Option<
asn1::SetOfWriter<
'_,
cryptography_x509::pkcs12::Attribute<'_>,
Vec<cryptography_x509::pkcs12::Attribute<'_>>,
'a,
cryptography_x509::pkcs12::Attribute<'a>,
Vec<cryptography_x509::pkcs12::Attribute<'a>>,
>,
>,
> {
let mut attrs = vec![];
if let Some(name) = friendly_name {
let name_str = std::str::from_utf8(name).map_err(|_| {
pyo3::exceptions::PyValueError::new_err("friendly_name must be valid UTF-8")
})?;

Ok(Some(asn1::SetOfWriter::new(vec![
cryptography_x509::pkcs12::Attribute {
_attr_id: asn1::DefinedByMarker::marker(),
attr_values: cryptography_x509::pkcs12::AttributeSet::FriendlyName(
asn1::SetOfWriter::new([Utf8StoredBMPString::new(name_str)]),
),
},
])))
} else {
attrs.push(cryptography_x509::pkcs12::Attribute {
_attr_id: asn1::DefinedByMarker::marker(),
attr_values: cryptography_x509::pkcs12::AttributeSet::FriendlyName(
asn1::SetOfWriter::new([Utf8StoredBMPString::new(name_str)]),
),
});
}
if let Some(key_id) = local_key_id {
attrs.push(cryptography_x509::pkcs12::Attribute {
_attr_id: asn1::DefinedByMarker::marker(),
attr_values: cryptography_x509::pkcs12::AttributeSet::LocalKeyId(
asn1::SetOfWriter::new([key_id]),
),
});
}

if attrs.is_empty() {
Ok(None)
} else {
Ok(Some(asn1::SetOfWriter::new(attrs)))
}
}

fn cert_to_bag<'a>(
cert: &'a Certificate,
friendly_name: Option<&'a [u8]>,
local_key_id: Option<&'a [u8]>,
) -> CryptographyResult<cryptography_x509::pkcs12::SafeBag<'a>> {
Ok(cryptography_x509::pkcs12::SafeBag {
_bag_id: asn1::DefinedByMarker::marker(),
Expand All @@ -381,7 +394,7 @@ fn cert_to_bag<'a>(
)),
},
)),
attributes: friendly_name_attributes(friendly_name)?,
attributes: pkcs12_attributes(friendly_name, local_key_id)?,
})
}

Expand Down Expand Up @@ -499,6 +512,7 @@ fn serialize_key_and_certificates<'p>(
key_ciphertext,
);
let mut ca_certs = vec![];
let mut key_id = None;
if cert.is_some() || cas.is_some() {
let mut cert_bags = vec![];

Expand All @@ -515,9 +529,14 @@ fn serialize_key_and_certificates<'p>(
),
));
}
key_id = Some(cert.fingerprint(py, &types::SHA1.get(py)?.call0()?)?);
}

cert_bags.push(cert_to_bag(cert, name)?);
cert_bags.push(cert_to_bag(
cert,
name,
key_id.as_ref().map(|v| v.as_bytes()),
)?);
}

if let Some(cas) = cas {
Expand All @@ -527,10 +546,13 @@ fn serialize_key_and_certificates<'p>(

for cert in &ca_certs {
let bag = match cert {
CertificateOrPKCS12Certificate::Certificate(c) => cert_to_bag(c.get(), None)?,
CertificateOrPKCS12Certificate::Certificate(c) => {
cert_to_bag(c.get(), None, None)?
}
CertificateOrPKCS12Certificate::PKCS12Certificate(c) => cert_to_bag(
c.get().certificate.get(),
c.get().friendly_name.as_ref().map(|v| v.as_bytes(py)),
None,
)?,
};
cert_bags.push(bag);
Expand Down Expand Up @@ -627,7 +649,7 @@ fn serialize_key_and_certificates<'p>(
},
),
),
attributes: friendly_name_attributes(name)?,
attributes: pkcs12_attributes(name, key_id.as_ref().map(|v| v.as_bytes()))?,
}
} else {
let pkcs8_tlv = asn1::parse_single(&pkcs8_bytes)?;
Expand All @@ -637,7 +659,7 @@ fn serialize_key_and_certificates<'p>(
bag_value: asn1::Explicit::new(cryptography_x509::pkcs12::BagValue::KeyBag(
pkcs8_tlv,
)),
attributes: friendly_name_attributes(name)?,
attributes: pkcs12_attributes(name, key_id.as_ref().map(|v| v.as_bytes()))?,
}
};

Expand Down
6 changes: 3 additions & 3 deletions src/rust/src/x509/certificate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,16 +84,16 @@ impl Certificate {
)
}

fn fingerprint<'p>(
pub(crate) fn fingerprint<'p>(
&self,
py: pyo3::Python<'p>,
algorithm: &pyo3::Bound<'p, pyo3::PyAny>,
) -> CryptographyResult<pyo3::Bound<'p, pyo3::PyAny>> {
) -> CryptographyResult<pyo3::Bound<'p, pyo3::types::PyBytes>> {
let serialized = asn1::write_single(&self.raw.borrow_dependent())?;

let mut h = hashes::Hash::new(py, algorithm, None)?;
h.update_bytes(&serialized)?;
Ok(h.finalize(py)?.into_any())
h.finalize(py)
}

fn public_bytes<'p>(
Expand Down
24 changes: 24 additions & 0 deletions tests/hazmat/primitives/test_pkcs12.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,6 +697,30 @@ def test_set_mac_key_certificate_mismatch(self, backend):
b"name", key, cacert, [], encryption
)

@pytest.mark.parametrize(
"encryption_algorithm",
[
serialization.NoEncryption(),
serialization.BestAvailableEncryption(b"password"),
],
)
def test_generate_localkeyid(self, backend, encryption_algorithm):
cert, key = _load_ca(backend)

p12 = serialize_key_and_certificates(
None, key, cert, None, encryption_algorithm
)
# Dirty, but does the trick. Should be there:
# * 2x if unencrypted (once for the key and once for the cert)
# * 1x if encrypted (the cert one is encrypted, but the key one is
# plaintext)
count = (
2
if isinstance(encryption_algorithm, serialization.NoEncryption)
else 1
)
assert p12.count(cert.fingerprint(hashes.SHA1())) == count


@pytest.mark.skip_fips(
reason="PKCS12 unsupported in FIPS mode. So much bad crypto in it."
Expand Down

0 comments on commit b5a312f

Please sign in to comment.