diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index ef98ff0ad..935e8ee68 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -8,7 +8,7 @@ on: jobs: lint: name: Lint - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 steps: - name: Checkout uses: actions/checkout@v3 @@ -19,7 +19,7 @@ jobs: unit-test: name: Unit tests - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 steps: - name: Checkout uses: actions/checkout@v3 @@ -30,7 +30,7 @@ jobs: integration-test-lxd-charm: name: Integration tests for the charm (lxd) - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 steps: - name: Checkout uses: actions/checkout@v3 @@ -49,7 +49,7 @@ jobs: integration-test-lxd-tls: name: Integration tests for TLS (lxd) - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 steps: - name: Checkout uses: actions/checkout@v3 diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index fdb4296f8..9d2129cb2 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -6,19 +6,26 @@ on: - main jobs: -# lib-check: -# name: Check libraries -# runs-on: ubuntu-latest -# steps: -# - name: Checkout -# uses: actions/checkout@v3 + lib-check: + name: Check libraries + runs-on: ubuntu-22.04 + steps: + - name: Checkout + uses: actions/checkout@v3 + with: + fetch-depth: 0 + - name: Check libs + uses: canonical/charming-actions/check-libraries@2.1.1 + with: + credentials: "${{ secrets.CHARMHUB_TOKEN }}" # FIXME: current token will expire in 2023-07-04 + github-token: "${{ secrets.GITHUB_TOKEN }}" ci-tests: uses: ./.github/workflows/ci.yaml release-to-charmhub: name: Release to CharmHub needs: - # - lib-check + - lib-check - ci-tests runs-on: ubuntu-22.04 steps: diff --git a/lib/charms/opensearch/v0/opensearch_tls.py b/lib/charms/opensearch/v0/opensearch_tls.py index 1a7d6eebd..e8064014a 100644 --- a/lib/charms/opensearch/v0/opensearch_tls.py +++ b/lib/charms/opensearch/v0/opensearch_tls.py @@ -17,7 +17,7 @@ import logging import re import socket -from typing import Dict, Optional, Tuple +from typing import Dict, List, Optional, Tuple from charms.opensearch.v0.constants_tls import TLS_RELATION, CertType from charms.opensearch.v0.helper_databag import Scope @@ -199,14 +199,14 @@ def _request_certificate( if self.charm.model.get_relation(TLS_RELATION): self.certs.request_certificate_creation(certificate_signing_request=csr) - def _get_sans(self, cert_type: CertType) -> dict: + def _get_sans(self, cert_type: CertType) -> Dict[str, List[str]]: """Create a list of OID/IP/DNS names for an OpenSearch unit. Returns: A list representing the hostnames of the OpenSearch unit. or None if admin cert_type, because that cert is not tied to a specific host. """ - sans = {"sans_oid": "1.2.3.4.5.5"} # required for node discovery + sans = {"sans_oid": ["1.2.3.4.5.5"]} # required for node discovery if cert_type == CertType.APP_ADMIN: return sans diff --git a/lib/charms/tls_certificates_interface/v1/tls_certificates.py b/lib/charms/tls_certificates_interface/v1/tls_certificates.py index e98d7866c..1eda19bf1 100644 --- a/lib/charms/tls_certificates_interface/v1/tls_certificates.py +++ b/lib/charms/tls_certificates_interface/v1/tls_certificates.py @@ -149,7 +149,7 @@ def __init__(self, *args): self.certificates.on.certificate_available, self._on_certificate_available ) self.framework.observe( - self.on.certificates.on.certificate_expiring, self._on_certificate_expiring + self.certificates.on.certificate_expiring, self._on_certificate_expiring ) def _on_install(self, event) -> None: @@ -226,9 +226,11 @@ def _on_certificate_expiring(self, event: CertificateExpiringEvent) -> None: from typing import Dict, List, Optional from cryptography import x509 +from cryptography.hazmat._oid import ExtensionOID from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives.serialization import pkcs12 +from cryptography.x509.extensions import Extension, ExtensionNotFound from jsonschema import exceptions, validate # type: ignore[import] from ops.charm import CharmBase, CharmEvents, RelationChangedEvent, UpdateStatusEvent from ops.framework import EventBase, EventSource, Handle, Object @@ -241,7 +243,7 @@ def _on_certificate_expiring(self, event: CertificateExpiringEvent) -> None: # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 6 +LIBPATCH = 10 REQUIRER_JSON_SCHEMA = { "$schema": "http://json-schema.org/draft-04/schema#", @@ -371,13 +373,13 @@ def restore(self, snapshot: dict): class CertificateExpiringEvent(EventBase): """Charm Event triggered when a TLS certificate is almost expired.""" - def __init__(self, handle, certificate: str, expiry: datetime): + def __init__(self, handle, certificate: str, expiry: str): """CertificateExpiringEvent. Args: handle (Handle): Juju framework handle certificate (str): TLS Certificate - expiry (datetime): Datetime object reprensenting the time at which the certificate + expiry (str): Datetime string reprensenting the time at which the certificate won't be valid anymore. """ super().__init__(handle) @@ -480,7 +482,7 @@ def _load_relation_data(raw_relation_data: dict) -> dict: for key in raw_relation_data: try: certificate_data[key] = json.loads(raw_relation_data[key]) - except json.decoder.JSONDecodeError: + except (json.decoder.JSONDecodeError, TypeError): certificate_data[key] = raw_relation_data[key] return certificate_data @@ -549,7 +551,7 @@ def generate_certificate( ca_key: bytes, ca_key_password: Optional[bytes] = None, validity: int = 365, - alt_names: list = None, + alt_names: List[str] = None, ) -> bytes: """Generates a TLS certificate based on a CSR. @@ -559,7 +561,7 @@ def generate_certificate( ca_key (bytes): CA private key ca_key_password: CA private key password validity (int): Certificate validity (in days) - alt_names: Certificate Subject alternative names + alt_names (list): List of alt names to put on cert - prefer putting SANs in CSR Returns: bytes: Certificate @@ -578,13 +580,36 @@ def generate_certificate( .not_valid_before(datetime.utcnow()) .not_valid_after(datetime.utcnow() + timedelta(days=validity)) ) + + extensions_list = csr_object.extensions + san_ext: Optional[x509.Extension] = None if alt_names: - names = [x509.DNSName(n) for n in alt_names] + full_sans_dns = alt_names.copy() + try: + loaded_san_ext = csr_object.extensions.get_extension_for_class( + x509.SubjectAlternativeName + ) + full_sans_dns.extend(loaded_san_ext.value.get_values_for_type(x509.DNSName)) + except ExtensionNotFound: + pass + finally: + san_ext = Extension( + ExtensionOID.SUBJECT_ALTERNATIVE_NAME, + False, + x509.SubjectAlternativeName([x509.DNSName(name) for name in full_sans_dns]), + ) + if not extensions_list: + extensions_list = x509.Extensions([san_ext]) + + for extension in extensions_list: + if extension.value.oid == ExtensionOID.SUBJECT_ALTERNATIVE_NAME and san_ext: + extension = san_ext + certificate_builder = certificate_builder.add_extension( - x509.SubjectAlternativeName(names), - critical=False, + extension.value, + critical=extension.critical, ) - certificate_builder._version = x509.Version.v1 + certificate_builder._version = x509.Version.v3 cert = certificate_builder.sign(private_key, hashes.SHA256()) # type: ignore[arg-type] return cert.public_bytes(serialization.Encoding.PEM) @@ -658,7 +683,8 @@ def generate_csr( email_address: str = None, country_name: str = None, private_key_password: Optional[bytes] = None, - sans_oid: Optional[str] = None, + sans: Optional[List[str]] = None, + sans_oid: Optional[List[str]] = None, sans_ip: Optional[List[str]] = None, sans_dns: Optional[List[str]] = None, additional_critical_extensions: Optional[List] = None, @@ -675,9 +701,11 @@ def generate_csr( email_address (str): Email address. country_name (str): Country Name. private_key_password (bytes): Private key password - sans_dns (list): List of DNS subject alternative names + sans (list): Use sans_dns - this will be deprecated in a future release + List of DNS subject alternative names (keeping it for now for backward compatibility) + sans_oid (list): List of registered ID SANs + sans_dns (list): List of DNS subject alternative names (similar to the arg: sans) sans_ip (list): List of IP subject alternative names - sans_oid (str): Additional OID additional_critical_extensions (list): List if critical additional extension objects. Object must be a x509 ExtensionType. @@ -699,19 +727,22 @@ def generate_csr( subject_name.append(x509.NameAttribute(x509.NameOID.COUNTRY_NAME, country_name)) csr = x509.CertificateSigningRequestBuilder(subject_name=x509.Name(subject_name)) - _sans = [] + _sans: List[x509.GeneralName] = [] if sans_oid: - _sans.append(x509.RegisteredID(x509.ObjectIdentifier(sans_oid))) + _sans.extend([x509.RegisteredID(x509.ObjectIdentifier(san)) for san in sans_oid]) if sans_ip: _sans.extend([x509.IPAddress(IPv4Address(san)) for san in sans_ip]) + if sans: + _sans.extend([x509.DNSName(san) for san in sans]) if sans_dns: _sans.extend([x509.DNSName(san) for san in sans_dns]) if _sans: - csr = csr.add_extension(x509.SubjectAlternativeName(_sans), critical=False) + csr = csr.add_extension(x509.SubjectAlternativeName(set(_sans)), critical=False) if additional_critical_extensions: for extension in additional_critical_extensions: csr = csr.add_extension(extension, critical=True) + signed_certificate = csr.sign(signing_key, hashes.SHA256()) # type: ignore[arg-type] return signed_certificate.public_bytes(serialization.Encoding.PEM) @@ -744,29 +775,18 @@ def __init__(self, charm: CharmBase, relationship_name: str): self.charm = charm self.relationship_name = relationship_name - @property - def _provider_certificates(self) -> List[Dict]: - """Returns list of provider CSR's from relation data.""" - relation = self.model.get_relation(self.relationship_name) - if not relation: - raise RuntimeError(f"Relation {self.relationship_name} does not exist") - provider_relation_data = _load_relation_data(relation.data[self.model.app]) - return provider_relation_data.get("certificates", []) - - def _requirer_csrs(self, unit) -> List[Dict[str, str]]: - """Returns list of requirer CSR's from relation data.""" - relation = self.model.get_relation(self.relationship_name) - if not relation: - raise RuntimeError(f"Relation {self.relationship_name} does not exist") - requirer_relation_data = _load_relation_data(relation.data[unit]) - return requirer_relation_data.get("certificate_signing_requests", []) - def _add_certificate( - self, certificate: str, certificate_signing_request: str, ca: str, chain: List[str] + self, + relation_id: int, + certificate: str, + certificate_signing_request: str, + ca: str, + chain: List[str], ) -> None: """Adds certificate to relation data. Args: + relation_id (int): Relation id certificate (str): Certificate certificate_signing_request (str): Certificate Signing Request ca (str): CA Certificate @@ -775,7 +795,9 @@ def _add_certificate( Returns: None """ - relation = self.model.get_relation(self.relationship_name) + relation = self.model.get_relation( + relation_name=self.relationship_name, relation_id=relation_id + ) if not relation: raise RuntimeError( f"Relation {self.relationship_name} does not exist - " @@ -787,7 +809,9 @@ def _add_certificate( "ca": ca, "chain": chain, } - certificates = copy.deepcopy(self._provider_certificates) + provider_relation_data = _load_relation_data(relation.data[self.charm.app]) + provider_certificates = provider_relation_data.get("certificates", []) + certificates = copy.deepcopy(provider_certificates) if new_certificate in certificates: logger.info("Certificate already in relation data - Doing nothing") return @@ -818,7 +842,9 @@ def _remove_certificate( raise RuntimeError( f"Relation {self.relationship_name} with relation id {relation_id} does not exist" ) - certificates = copy.deepcopy(self._provider_certificates) + provider_relation_data = _load_relation_data(relation.data[self.charm.app]) + provider_certificates = provider_relation_data.get("certificates", []) + certificates = copy.deepcopy(provider_certificates) for certificate_dict in certificates: if certificate and certificate_dict["certificate"] == certificate: certificates.remove(certificate_dict) @@ -845,6 +871,14 @@ def _relation_data_is_valid(certificates_data: dict) -> bool: except exceptions.ValidationError: return False + def revoke_all_certificates(self) -> None: + """Revokes all certificates of this provider. + + This method is meant to be used when the Root CA has changed. + """ + for relation in self.model.relations[self.relationship_name]: + relation.data[self.model.app]["certificates"] = json.dumps([]) + def set_relation_certificate( self, certificate: str, @@ -875,6 +909,7 @@ def set_relation_certificate( relation_id=relation_id, ) self._add_certificate( + relation_id=relation_id, certificate=certificate.strip(), certificate_signing_request=certificate_signing_request.strip(), ca=ca.strip(), @@ -911,19 +946,23 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: Returns: None """ + assert event.unit is not None requirer_relation_data = _load_relation_data(event.relation.data[event.unit]) + provider_relation_data = _load_relation_data(event.relation.data[self.charm.app]) if not self._relation_data_is_valid(requirer_relation_data): logger.warning( f"Relation data did not pass JSON Schema validation: {requirer_relation_data}" ) return + provider_certificates = provider_relation_data.get("certificates", []) + requirer_csrs = requirer_relation_data.get("certificate_signing_requests", []) provider_csrs = [ certificate_creation_request["certificate_signing_request"] - for certificate_creation_request in self._provider_certificates + for certificate_creation_request in provider_certificates ] requirer_unit_csrs = [ certificate_creation_request["certificate_signing_request"] - for certificate_creation_request in self._requirer_csrs(event.unit) + for certificate_creation_request in requirer_csrs ] for certificate_signing_request in requirer_unit_csrs: if certificate_signing_request not in provider_csrs: @@ -950,12 +989,14 @@ def _revoke_certificates_for_which_no_csr_exists(self, relation_id: int) -> None ) if not certificates_relation: raise RuntimeError(f"Relation {self.relationship_name} does not exist") + provider_relation_data = _load_relation_data(certificates_relation.data[self.charm.app]) list_of_csrs: List[str] = [] for unit in certificates_relation.units: - list_of_csrs.extend( - csr["certificate_signing_request"] for csr in self._requirer_csrs(unit) - ) - for certificate in self._provider_certificates: + requirer_relation_data = _load_relation_data(certificates_relation.data[unit]) + requirer_csrs = requirer_relation_data.get("certificate_signing_requests", []) + list_of_csrs.extend(csr["certificate_signing_request"] for csr in requirer_csrs) + provider_certificates = provider_relation_data.get("certificates", []) + for certificate in provider_certificates: if certificate["certificate_signing_request"] not in list_of_csrs: self.on.certificate_revocation_request.emit( certificate=certificate["certificate"], @@ -1009,7 +1050,9 @@ def _provider_certificates(self) -> List[Dict[str, str]]: relation = self.model.get_relation(self.relationship_name) if not relation: raise RuntimeError(f"Relation {self.relationship_name} does not exist") - provider_relation_data = _load_relation_data(relation.data[relation.app]) # type: ignore[index] # noqa: E501 + if not relation.app: + raise RuntimeError(f"Remote app for relation {self.relationship_name} does not exist") + provider_relation_data = _load_relation_data(relation.data[relation.app]) return provider_relation_data.get("certificates", []) def _add_requirer_csr(self, csr: str) -> None: @@ -1148,11 +1191,14 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: if not relation: logger.warning(f"No relation: {self.relationship_name}") return - provider_relation_data = _load_relation_data(relation.data[relation.app]) # type: ignore[index] # noqa: E501 + if not relation.app: + logger.warning(f"No remote app in relation: {self.relationship_name}") + return + provider_relation_data = _load_relation_data(relation.data[relation.app]) if not self._relation_data_is_valid(provider_relation_data): logger.warning( f"Provider relation data did not pass JSON Schema validation: " - f"{event.relation.data[event.app]}" + f"{event.relation.data[relation.app]}" ) return requirer_csrs = [ @@ -1185,15 +1231,23 @@ def _on_update_status(self, event: UpdateStatusEvent) -> None: if not relation: logger.warning(f"No relation: {self.relationship_name}") return - provider_relation_data = _load_relation_data(relation.data[relation.app]) # type: ignore[index] # noqa: E501 + if not relation.app: + logger.warning(f"No remote app in relation: {self.relationship_name}") + return + provider_relation_data = _load_relation_data(relation.data[relation.app]) if not self._relation_data_is_valid(provider_relation_data): logger.warning( - f"Provider relation data did not pass JSON Schema validation: {relation.data[relation.app]}" # type: ignore[index] # noqa: W505 + f"Provider relation data did not pass JSON Schema validation: " + f"{relation.data[relation.app]}" ) return for certificate_dict in self._provider_certificates: certificate = certificate_dict["certificate"] - certificate_object = x509.load_pem_x509_certificate(data=certificate.encode()) + try: + certificate_object = x509.load_pem_x509_certificate(data=certificate.encode()) + except ValueError: + logger.warning("Could not load certificate.") + continue time_difference = certificate_object.not_valid_after - datetime.utcnow() if time_difference.total_seconds() < 0: logger.warning("Certificate is expired") @@ -1203,5 +1257,5 @@ def _on_update_status(self, event: UpdateStatusEvent) -> None: if time_difference.total_seconds() < (self.expiry_notification_time * 60 * 60): logger.warning("Certificate almost expired") self.on.certificate_expiring.emit( - certificate=certificate, expiry=certificate_object.not_valid_after + certificate=certificate, expiry=certificate_object.not_valid_after.isoformat() ) diff --git a/tests/unit/test_opensearch_tls.py b/tests/unit/test_opensearch_tls.py index 85807452e..9aadf1148 100644 --- a/tests/unit/test_opensearch_tls.py +++ b/tests/unit/test_opensearch_tls.py @@ -38,14 +38,14 @@ def test_get_sans(self): """Test the SANs returned depending on the cert type.""" self.assertDictEqual( self.charm.tls._get_sans(CertType.APP_ADMIN), - {"sans_oid": "1.2.3.4.5.5"}, + {"sans_oid": ["1.2.3.4.5.5"]}, ) for cert_type in [CertType.UNIT_HTTP, CertType.UNIT_TRANSPORT]: self.assertDictEqual( self.charm.tls._get_sans(cert_type), { - "sans_oid": "1.2.3.4.5.5", + "sans_oid": ["1.2.3.4.5.5"], "sans_ip": ["1.1.1.1"], "sans_dns": [self.charm.unit_name, "nebula"], },