From d72ca94665efd68923138cc961ccf364cd725c2e Mon Sep 17 00:00:00 2001 From: Jeroen Dekkers Date: Wed, 8 May 2024 00:19:40 +0200 Subject: [PATCH 1/9] Enable more mypy checks --- boefjes/boefjes/__main__.py | 2 +- boefjes/boefjes/api.py | 8 ++--- boefjes/boefjes/app.py | 6 ++-- boefjes/boefjes/katalogus/api.py | 2 +- .../katalogus/dependencies/encryption.py | 4 +-- .../boefjes/katalogus/dependencies/plugins.py | 16 +++++----- boefjes/boefjes/katalogus/models.py | 2 +- .../katalogus/routers/organisations.py | 8 ++--- boefjes/boefjes/katalogus/routers/plugins.py | 6 ++-- .../boefjes/katalogus/routers/repositories.py | 8 ++--- boefjes/boefjes/katalogus/routers/settings.py | 8 +++-- boefjes/boefjes/migrations/env.py | 2 +- ...fafdaf_json_settings_for_settings_table.py | 4 +-- boefjes/boefjes/plugins/kat_crt_sh/main.py | 8 ++++- .../plugins/kat_cve_2023_35078/normalize.py | 18 +++++------ boefjes/boefjes/plugins/kat_dnssec/main.py | 2 +- .../plugins/kat_manual/csv/normalize.py | 10 +++---- boefjes/boefjes/plugins/kat_masscan/main.py | 2 +- boefjes/boefjes/plugins/kat_nmap_tcp/main.py | 2 +- .../kat_security_txt_downloader/main.py | 3 +- .../boefjes/plugins/kat_snyk/check_version.py | 8 +++-- .../plugins/kat_webpage_analysis/main.py | 9 +++--- .../plugins/kat_webpage_capture/main.py | 4 +-- boefjes/boefjes/plugins/models.py | 2 +- boefjes/boefjes/runtime_interfaces.py | 2 +- boefjes/boefjes/seed.py | 2 +- boefjes/boefjes/sql/organisation_storage.py | 2 +- boefjes/boefjes/sql/repository_storage.py | 2 +- boefjes/boefjes/sql/setting_storage.py | 2 +- bytes/bytes/api/metrics.py | 2 +- bytes/bytes/api/root.py | 2 +- bytes/bytes/api/router.py | 2 +- bytes/bytes/database/migrations/env.py | 2 +- bytes/bytes/database/sql_meta_repository.py | 2 +- bytes/bytes/rabbitmq.py | 2 +- cveapi/cveapi.py | 4 +-- mula/scheduler/__init__.py | 4 +-- mula/scheduler/connectors/connector.py | 3 +- .../connectors/listeners/__init__.py | 2 ++ .../connectors/listeners/listeners.py | 2 +- .../scheduler/connectors/services/__init__.py | 2 ++ .../connectors/services/katalogus.py | 2 +- mula/scheduler/context/__init__.py | 2 ++ mula/scheduler/context/context.py | 3 ++ mula/scheduler/models/__init__.py | 30 +++++++++++++++++-- mula/scheduler/models/tasks.py | 2 +- mula/scheduler/queues/__init__.py | 10 +++++++ mula/scheduler/rankers/__init__.py | 2 ++ mula/scheduler/schedulers/__init__.py | 2 ++ mula/scheduler/schedulers/boefje.py | 2 +- mula/scheduler/server/__init__.py | 2 ++ mula/scheduler/storage/__init__.py | 8 +++++ mula/scheduler/storage/filters/__init__.py | 2 ++ mula/scheduler/storage/filters/functions.py | 3 +- mula/scheduler/storage/pq_store.py | 2 +- mula/scheduler/storage/storage.py | 3 +- mula/scheduler/utils/__init__.py | 2 ++ octopoes/bits/check_csp_header/bit.py | 2 +- .../bits/check_csp_header/check_csp_header.py | 2 +- octopoes/bits/check_cve_2021_41773/bit.py | 2 +- .../check_cve_2021_41773.py | 2 +- octopoes/bits/check_hsts_header/bit.py | 2 +- .../check_hsts_header/check_hsts_header.py | 2 +- .../cipher_classification.py | 5 ++-- .../missing_certificate.py | 2 +- octopoes/bits/nxdomain_flag/bit.py | 2 +- octopoes/bits/nxdomain_flag/nxdomain_flag.py | 2 +- octopoes/bits/nxdomain_header_flag/bit.py | 2 +- .../nxdomain_header_flag.py | 2 +- octopoes/bits/oois_in_headers/bit.py | 2 +- .../bits/oois_in_headers/oois_in_headers.py | 4 +-- octopoes/bits/retire_js/retire_js.py | 8 ++--- octopoes/bits/runner.py | 10 ++----- octopoes/bits/spf_discovery/spf_discovery.py | 2 +- octopoes/octopoes/api/router.py | 16 +++++----- octopoes/octopoes/connector/__init__.py | 6 +--- octopoes/octopoes/connector/octopoes.py | 10 +++---- octopoes/octopoes/core/app.py | 2 +- octopoes/octopoes/core/service.py | 6 ++-- octopoes/octopoes/models/__init__.py | 14 ++++----- octopoes/octopoes/models/ooi/dns/records.py | 2 +- octopoes/octopoes/models/ooi/findings.py | 2 +- octopoes/octopoes/models/ooi/network.py | 2 +- octopoes/octopoes/models/origin.py | 4 ++- octopoes/octopoes/models/path.py | 10 +++---- octopoes/octopoes/models/persistence.py | 4 +-- octopoes/octopoes/models/tree.py | 2 +- octopoes/octopoes/models/types.py | 6 ++-- .../octopoes/repositories/ooi_repository.py | 20 ++++++------- .../origin_parameter_repository.py | 2 +- .../repositories/scan_profile_repository.py | 2 +- octopoes/octopoes/tasks/tasks.py | 5 ++-- octopoes/octopoes/xtdb/client.py | 10 +++---- octopoes/octopoes/xtdb/query.py | 8 ++--- octopoes/octopoes/xtdb/query_builder.py | 3 +- .../octopoes/xtdb/related_field_generator.py | 8 ++--- pyproject.toml | 15 ++++++---- rocky/account/forms/__init__.py | 16 ++++++++++ rocky/account/forms/organization.py | 14 ++++----- rocky/account/mixins.py | 8 ++--- rocky/account/models.py | 2 +- rocky/account/views/account.py | 2 +- rocky/crisis_room/views.py | 3 +- rocky/katalogus/client.py | 22 +++++++------- rocky/katalogus/forms/plugin_settings.py | 4 ++- rocky/katalogus/views/mixins.py | 5 ++-- rocky/katalogus/views/plugin_detail.py | 8 ++--- .../katalogus/views/plugin_enable_disable.py | 2 +- rocky/onboarding/view_helpers.py | 4 +-- rocky/onboarding/views.py | 10 +++---- rocky/poetry.lock | 17 ++++++++++- rocky/pyproject.toml | 1 + rocky/reports/forms.py | 6 ++-- .../aggregate_organisation_report/report.py | 10 ++++--- rocky/reports/report_types/definitions.py | 4 ++- .../multi_organization_report/report.py | 2 +- .../report_types/name_server_report/report.py | 6 ++-- .../vulnerability_report/report.py | 2 +- .../report_types/web_system_report/report.py | 6 ++-- rocky/reports/views/aggregate_report.py | 4 +-- rocky/reports/views/base.py | 12 ++++---- rocky/reports/views/generate_report.py | 4 +-- rocky/reports/views/multi_report.py | 4 +-- rocky/requirements-dev.txt | 3 ++ rocky/requirements.txt | 3 ++ rocky/rocky/bytes_client.py | 6 ++-- rocky/rocky/exceptions.py | 7 +++-- rocky/rocky/keiko.py | 6 ++-- rocky/rocky/middleware/onboarding.py | 8 ++--- rocky/rocky/paginator.py | 14 +++++---- rocky/rocky/scheduler.py | 8 ++--- rocky/rocky/views/finding_add.py | 3 +- rocky/rocky/views/finding_list.py | 8 ++--- rocky/rocky/views/health.py | 5 ++-- rocky/rocky/views/mixins.py | 9 +++--- rocky/rocky/views/ooi_detail.py | 12 ++++---- .../rocky/views/ooi_detail_related_object.py | 5 ++-- rocky/rocky/views/ooi_list.py | 13 ++++---- rocky/rocky/views/ooi_report.py | 2 +- rocky/rocky/views/ooi_tree.py | 8 ++--- rocky/rocky/views/ooi_view.py | 7 +++-- rocky/rocky/views/organization_member_add.py | 11 +++---- rocky/rocky/views/organization_member_list.py | 2 +- rocky/rocky/views/organization_settings.py | 3 +- rocky/rocky/views/privacy_statement.py | 2 +- rocky/rocky/views/scan_profile.py | 2 +- rocky/rocky/views/tasks.py | 9 +++--- rocky/rocky/views/upload_csv.py | 8 ++--- rocky/tools/add_ooi_information.py | 2 +- rocky/tools/forms/base.py | 8 ++--- rocky/tools/forms/finding_type.py | 7 +++-- rocky/tools/forms/ooi.py | 6 ++-- rocky/tools/forms/ooi_form.py | 6 ++-- .../management/commands/export_migrations.py | 3 +- .../management/commands/generate_report.py | 6 ++-- .../management/commands/setup_test_users.py | 6 ++-- rocky/tools/models.py | 8 ++--- rocky/tools/ooi_helpers.py | 11 +++---- rocky/tools/templatetags/ooi_extra.py | 14 ++++----- rocky/tools/view_helpers.py | 27 ++++++++--------- 160 files changed, 527 insertions(+), 391 deletions(-) diff --git a/boefjes/boefjes/__main__.py b/boefjes/boefjes/__main__.py index e9621de3e0e..b0c264ea339 100644 --- a/boefjes/boefjes/__main__.py +++ b/boefjes/boefjes/__main__.py @@ -21,7 +21,7 @@ help="Log level", default="INFO", ) -def cli(worker_type: str, log_level: str): +def cli(worker_type: str, log_level: str) -> None: logger.setLevel(log_level) logger.info("Starting runtime for %s", worker_type) diff --git a/boefjes/boefjes/api.py b/boefjes/boefjes/api.py index c26f64e7979..99b2a90a93e 100644 --- a/boefjes/boefjes/api.py +++ b/boefjes/boefjes/api.py @@ -30,7 +30,7 @@ def __init__(self, config: Config): self.server = Server(config=config) self.config = config - def stop(self): + def stop(self) -> None: self.terminate() def run(self, *args, **kwargs): @@ -89,7 +89,7 @@ async def boefje_input( task_id: UUID, scheduler_client: SchedulerAPIClient = Depends(get_scheduler_client), local_repository: LocalPluginRepository = Depends(get_local_repository), -): +) -> BoefjeInput: task = get_task(task_id, scheduler_client) if task.status is not TaskStatus.RUNNING: @@ -108,7 +108,7 @@ async def boefje_output( scheduler_client: SchedulerAPIClient = Depends(get_scheduler_client), bytes_client: BytesAPIClient = Depends(get_bytes_client), local_repository: LocalPluginRepository = Depends(get_local_repository), -): +) -> Response: task = get_task(task_id, scheduler_client) if task.status is not TaskStatus.RUNNING: @@ -126,7 +126,7 @@ async def boefje_output( for file in boefje_output.files: raw = base64.b64decode(file.content) # when supported, also save file.name to Bytes - bytes_client.save_raw(task_id, raw, mime_types.union(file.tags)) + bytes_client.save_raw(task_id, raw, mime_types.union(file.tags) if file.tags else mime_types) if boefje_output.status == StatusEnum.COMPLETED: scheduler_client.patch_task(task_id, TaskStatus.COMPLETED) diff --git a/boefjes/boefjes/app.py b/boefjes/boefjes/app.py index 6734395a6cd..c9672a4f278 100644 --- a/boefjes/boefjes/app.py +++ b/boefjes/boefjes/app.py @@ -77,7 +77,7 @@ def run(self, queue_type: WorkerManager.Queue) -> None: raise - def _fill_queue(self, task_queue: Queue, queue_type: WorkerManager.Queue): + def _fill_queue(self, task_queue: Queue, queue_type: WorkerManager.Queue) -> None: if task_queue.qsize() > self.settings.pool_size: time.sleep(self.settings.worker_heartbeat) return @@ -186,7 +186,7 @@ def _cleanup_pending_worker_task(self, worker: mp.Process) -> None: def _worker_args(self) -> tuple: return self.task_queue, self.item_handler, self.scheduler_client, self.handling_tasks - def exit(self, queue_type: WorkerManager.Queue, signum: int | None = None): + def exit(self, queue_type: WorkerManager.Queue, signum: int | None = None) -> None: try: if signum: logger.info("Received %s, exiting", signal.Signals(signum).name) @@ -235,7 +235,7 @@ def _start_working( handler: Handler, scheduler_client: SchedulerClientInterface, handling_tasks: dict[int, str], -): +) -> None: logger.info("Started listening for tasks from worker[pid=%s]", os.getpid()) while True: diff --git a/boefjes/boefjes/katalogus/api.py b/boefjes/boefjes/katalogus/api.py index 3da6cb85668..f25301aef82 100644 --- a/boefjes/boefjes/katalogus/api.py +++ b/boefjes/boefjes/katalogus/api.py @@ -45,7 +45,7 @@ @app.exception_handler(StorageError) -def entity_not_found_handler(request: Request, exc: StorageError): +def entity_not_found_handler(request: Request, exc: StorageError) -> JSONResponse: logger.exception("some error", exc_info=exc) return JSONResponse( diff --git a/boefjes/boefjes/katalogus/dependencies/encryption.py b/boefjes/boefjes/katalogus/dependencies/encryption.py index 45e001a9fb1..43c56376b16 100644 --- a/boefjes/boefjes/katalogus/dependencies/encryption.py +++ b/boefjes/boefjes/katalogus/dependencies/encryption.py @@ -34,8 +34,8 @@ def __init__(self, private_key: str, public_key: str): def encode(self, contents: str) -> str: encrypted_contents = self.box.encrypt(contents.encode()) - encrypted_contents = base64.b64encode(encrypted_contents) - return encrypted_contents.decode() + base64_encrypted_contents = base64.b64encode(encrypted_contents) + return base64_encrypted_contents.decode() def decode(self, contents: str) -> str: encrypted_binary = base64.b64decode(contents) diff --git a/boefjes/boefjes/katalogus/dependencies/plugins.py b/boefjes/boefjes/katalogus/dependencies/plugins.py index 974b23cfde2..9155caccfc3 100644 --- a/boefjes/boefjes/katalogus/dependencies/plugins.py +++ b/boefjes/boefjes/katalogus/dependencies/plugins.py @@ -87,10 +87,10 @@ def by_plugin_ids(self, plugin_ids: list[str], organisation_id: str) -> list[Plu return found_plugins - def get_all_settings(self, organisation_id: str, plugin_id: str): + def get_all_settings(self, organisation_id: str, plugin_id: str) -> dict[str, str]: return self.settings_storage.get_all(organisation_id, plugin_id) - def clone_settings_to_organisation(self, from_organisation: str, to_organisation: str): + def clone_settings_to_organisation(self, from_organisation: str, to_organisation: str) -> None: # One requirement is that we also do not keep previously enabled boefjes enabled of they are not copied. for repository_id, plugins in self.plugin_enabled_store.get_all_enabled(to_organisation).items(): for plugin_id in plugins: @@ -104,12 +104,12 @@ def clone_settings_to_organisation(self, from_organisation: str, to_organisation for plugin_id in plugins: self.update_by_id(repository_id, plugin_id, to_organisation, enabled=True) - def upsert_settings(self, values: dict, organisation_id: str, plugin_id: str): + def upsert_settings(self, values: dict, organisation_id: str, plugin_id: str) -> None: self._assert_settings_match_schema(values, organisation_id, plugin_id) - return self.settings_storage.upsert(values, organisation_id, plugin_id) + self.settings_storage.upsert(values, organisation_id, plugin_id) - def delete_settings(self, organisation_id: str, plugin_id: str): + def delete_settings(self, organisation_id: str, plugin_id: str) -> None: self.settings_storage.delete(organisation_id, plugin_id) try: @@ -155,7 +155,7 @@ def repository_plugin(self, repository_id: str, plugin_id: str, organisation_id: return plugin - def update_by_id(self, repository_id: str, plugin_id: str, organisation_id: str, enabled: bool): + def update_by_id(self, repository_id: str, plugin_id: str, organisation_id: str, enabled: bool) -> None: if enabled: all_settings = self.settings_storage.get_all(organisation_id, plugin_id) self._assert_settings_match_schema(all_settings, organisation_id, plugin_id) @@ -186,7 +186,7 @@ def _plugins_for_repos( return plugins - def _assert_settings_match_schema(self, all_settings: dict, organisation_id: str, plugin_id: str): + def _assert_settings_match_schema(self, all_settings: dict, organisation_id: str, plugin_id: str) -> None: schema = self.schema(plugin_id) if schema: # No schema means that there is nothing to assert @@ -207,7 +207,7 @@ def _namespaced_id(repository_id: str, plugin_id: str) -> str: def get_plugin_service(organisation_id: str) -> Iterator[PluginService]: - def closure(session: Session): + def closure(session: Session) -> PluginService: return PluginService( create_plugin_enabled_storage(session), create_repository_storage(session), diff --git a/boefjes/boefjes/katalogus/models.py b/boefjes/boefjes/katalogus/models.py index 00f3c341925..7e9ac119165 100644 --- a/boefjes/boefjes/katalogus/models.py +++ b/boefjes/boefjes/katalogus/models.py @@ -30,7 +30,7 @@ class Plugin(BaseModel): related: list[str] | None = None enabled: bool = False - def __str__(self): + def __str__(self) -> str: return f"{self.id}:{self.version}" diff --git a/boefjes/boefjes/katalogus/routers/organisations.py b/boefjes/boefjes/katalogus/routers/organisations.py index fcb68e0f99f..077893066f5 100644 --- a/boefjes/boefjes/katalogus/routers/organisations.py +++ b/boefjes/boefjes/katalogus/routers/organisations.py @@ -26,7 +26,7 @@ def check_organisation_exists( @router.get("", response_model=dict[str, Organisation]) def list_organisations( storage: OrganisationStorage = Depends(get_organisations_store), -): +) -> dict[str, Organisation]: return storage.get_all() @@ -34,7 +34,7 @@ def list_organisations( def get_organisation( organisation_id: str, storage: OrganisationStorage = Depends(get_organisations_store), -): +) -> Organisation: try: return storage.get_by_id(organisation_id) except (KeyError, ObjectNotFoundException): @@ -45,7 +45,7 @@ def get_organisation( def add_organisation( organisation: Organisation, storage: OrganisationStorage = Depends(get_organisations_store), -): +) -> None: with storage as store: store.create(organisation) @@ -54,6 +54,6 @@ def add_organisation( def remove_organisation( organisation_id: str, storage: OrganisationStorage = Depends(get_organisations_store), -): +) -> None: with storage as store: store.delete_by_id(organisation_id) diff --git a/boefjes/boefjes/katalogus/routers/plugins.py b/boefjes/boefjes/katalogus/routers/plugins.py index be3b9cabd83..9ea11e2a9da 100644 --- a/boefjes/boefjes/katalogus/routers/plugins.py +++ b/boefjes/boefjes/katalogus/routers/plugins.py @@ -95,7 +95,7 @@ def list_repository_plugins( repository_id: str, organisation_id: str, plugin_service: PluginService = Depends(get_plugin_service), -): +) -> dict[str, PluginType]: with plugin_service as p: return p.repository_plugins(repository_id, organisation_id) @@ -123,7 +123,7 @@ def update_plugin_state( organisation_id: str, enabled: bool = Body(False, embed=True), plugin_service: PluginService = Depends(get_plugin_service), -): +) -> None: try: with plugin_service as p: p.update_by_id(repository_id, plugin_id, organisation_id, enabled) @@ -181,6 +181,6 @@ def clone_organisation_settings( organisation_id: str, to_organisation_id: str, storage: PluginService = Depends(get_plugin_service), -): +) -> None: with storage as store: store.clone_settings_to_organisation(organisation_id, to_organisation_id) diff --git a/boefjes/boefjes/katalogus/routers/repositories.py b/boefjes/boefjes/katalogus/routers/repositories.py index b00d08d16ea..af2edb79578 100644 --- a/boefjes/boefjes/katalogus/routers/repositories.py +++ b/boefjes/boefjes/katalogus/routers/repositories.py @@ -13,12 +13,12 @@ @router.get("", response_model=dict[str, Repository], response_model_exclude={0: {0: False}}) -def list_repositories(storage: RepositoryStorage = Depends(get_repository_store)): +def list_repositories(storage: RepositoryStorage = Depends(get_repository_store)) -> dict[str, Repository]: return storage.get_all() @router.get("/{repository_id}", response_model=Repository) -def get_repository(repository_id: str, storage: RepositoryStorage = Depends(get_repository_store)): +def get_repository(repository_id: str, storage: RepositoryStorage = Depends(get_repository_store)) -> Repository: try: return storage.get_by_id(repository_id) except KeyError: @@ -29,7 +29,7 @@ def get_repository(repository_id: str, storage: RepositoryStorage = Depends(get_ def add_repository( repository: Repository, storage: RepositoryStorage = Depends(get_repository_store), -): +) -> None: with storage as store: store.create(repository) @@ -38,7 +38,7 @@ def add_repository( def remove_repository( repository_id: str, storage: RepositoryStorage = Depends(get_repository_store), -): +) -> None: if repository_id == RESERVED_LOCAL_ID: raise HTTPException(status.HTTP_403_FORBIDDEN, "LOCAL repository cannot be deleted") with storage as store: diff --git a/boefjes/boefjes/katalogus/routers/settings.py b/boefjes/boefjes/katalogus/routers/settings.py index 9c1bebe8d04..2055f72b9bd 100644 --- a/boefjes/boefjes/katalogus/routers/settings.py +++ b/boefjes/boefjes/katalogus/routers/settings.py @@ -15,7 +15,7 @@ def list_settings( organisation_id: str, plugin_id: str, plugin_service: PluginService = Depends(get_plugin_service), -): +) -> dict[str, str]: with plugin_service as p: return p.get_all_settings(organisation_id, plugin_id) @@ -26,12 +26,14 @@ def upsert_settings( plugin_id: str, values: dict, plugin_service: PluginService = Depends(get_plugin_service), -): +) -> None: with plugin_service as p: p.upsert_settings(values, organisation_id, plugin_id) @router.delete("") -def remove_settings(organisation_id: str, plugin_id: str, plugin_service: PluginService = Depends(get_plugin_service)): +def remove_settings( + organisation_id: str, plugin_id: str, plugin_service: PluginService = Depends(get_plugin_service) +) -> None: with plugin_service as p: p.delete_settings(organisation_id, plugin_id) diff --git a/boefjes/boefjes/migrations/env.py b/boefjes/boefjes/migrations/env.py index cf7a0a4c2dd..2cb712f59f3 100644 --- a/boefjes/boefjes/migrations/env.py +++ b/boefjes/boefjes/migrations/env.py @@ -4,7 +4,7 @@ from sqlalchemy import engine_from_config, pool from boefjes.config import settings -from boefjes.sql.db_models import SQL_BASE +from boefjes.sql.db import SQL_BASE # this is the Alembic Config object, which provides # access to the values within the .ini file in use. diff --git a/boefjes/boefjes/migrations/versions/cd34fdfafdaf_json_settings_for_settings_table.py b/boefjes/boefjes/migrations/versions/cd34fdfafdaf_json_settings_for_settings_table.py index 7351d140501..5c512fe4ddf 100644 --- a/boefjes/boefjes/migrations/versions/cd34fdfafdaf_json_settings_for_settings_table.py +++ b/boefjes/boefjes/migrations/versions/cd34fdfafdaf_json_settings_for_settings_table.py @@ -42,7 +42,7 @@ def upgrade() -> None: # ### end Alembic commands ### -def upgrade_encrypted_settings(conn: Connection): +def upgrade_encrypted_settings(conn: Connection) -> None: encrypter = create_encrypter() with conn.begin(): @@ -90,7 +90,7 @@ def downgrade() -> None: # ### end Alembic commands ### -def downgrade_encrypted_settings(conn: Connection): +def downgrade_encrypted_settings(conn: Connection) -> None: encrypter = create_encrypter() with conn.begin(): diff --git a/boefjes/boefjes/plugins/kat_crt_sh/main.py b/boefjes/boefjes/plugins/kat_crt_sh/main.py index cb2d891bae0..318a45efaf1 100644 --- a/boefjes/boefjes/plugins/kat_crt_sh/main.py +++ b/boefjes/boefjes/plugins/kat_crt_sh/main.py @@ -31,7 +31,13 @@ ) -def request_certs(search_string, search_type="Identity", match="=", deduplicate=True, json_output=True) -> str: +def request_certs( + search_string: str, + search_type: str = "Identity", + match: str = "=", + deduplicate: bool = True, + json_output: bool = True, +) -> str: """Queries the public service CRT.sh for certificate information the searchtype can be specified and defaults to Identity. the type of sql matching can be specified and defaults to "=" diff --git a/boefjes/boefjes/plugins/kat_cve_2023_35078/normalize.py b/boefjes/boefjes/plugins/kat_cve_2023_35078/normalize.py index 92670e81f45..d93a87dc9c9 100644 --- a/boefjes/boefjes/plugins/kat_cve_2023_35078/normalize.py +++ b/boefjes/boefjes/plugins/kat_cve_2023_35078/normalize.py @@ -4,12 +4,12 @@ from octopoes.models import Reference from octopoes.models.ooi.findings import CVEFindingType, Finding from octopoes.models.ooi.software import Software, SoftwareInstance -from packaging import version +from packaging.version import Version, parse VULNERABLE_RANGES: list[tuple[str, str]] = [("0", "11.8.1.1"), ("11.9.0.0", "11.9.1.1"), ("11.10.0.0", "11.10.0.2")] -def extract_js_version(html_content: str) -> version.Version | bool: +def extract_js_version(html_content: str) -> Version | bool: telltale = "/mifs/scripts/auth.js?" telltale_position = html_content.find(telltale) if telltale_position == -1: @@ -20,10 +20,10 @@ def extract_js_version(html_content: str) -> version.Version | bool: version_string = html_content[telltale_position + len(telltale) : version_end] if not version_string: return False - return version.parse(" ".join(strip_vsp_and_build(version_string))) + return parse(" ".join(strip_vsp_and_build(version_string))) -def extract_css_version(html_content: str) -> version.Version | bool: +def extract_css_version(html_content: str) -> Version | bool: telltale = "/mifs/css/windowsAllAuth.css?" telltale_position = html_content.find(telltale) if telltale_position == -1: @@ -34,7 +34,7 @@ def extract_css_version(html_content: str) -> version.Version | bool: version_string = html_content[telltale_position + len(telltale) : version_end] if not version_string: return False - return version.parse(" ".join(strip_vsp_and_build(version_string))) + return parse(" ".join(strip_vsp_and_build(version_string))) def strip_vsp_and_build(url: str) -> Iterable[str]: @@ -47,9 +47,7 @@ def strip_vsp_and_build(url: str) -> Iterable[str]: yield part -def is_vulnerable_version( - vulnerable_ranges: list[tuple[version.Version, version.Version]], detected_version: version.Version -) -> bool: +def is_vulnerable_version(vulnerable_ranges: list[tuple[Version, Version]], detected_version: Version) -> bool: return any(start <= detected_version < end for start, end in vulnerable_ranges) @@ -70,11 +68,11 @@ def run(input_ooi: dict, raw: bytes) -> Iterable[NormalizerOutput]: yield software_instance if js_detected_version: vulnerable = is_vulnerable_version( - [(version.parse(start), version.parse(end)) for start, end in VULNERABLE_RANGES], js_detected_version + [(parse(start), parse(end)) for start, end in VULNERABLE_RANGES], js_detected_version ) else: # The CSS version only included the first two parts of the version number so we don't know the patch level - vulnerable = css_detected_version < version.parse("11.8") + vulnerable = css_detected_version < parse("11.8") if vulnerable: finding_type = CVEFindingType(id="CVE-2023-35078") finding = Finding( diff --git a/boefjes/boefjes/plugins/kat_dnssec/main.py b/boefjes/boefjes/plugins/kat_dnssec/main.py index fd94d371ba0..7d01b4a1b0b 100644 --- a/boefjes/boefjes/plugins/kat_dnssec/main.py +++ b/boefjes/boefjes/plugins/kat_dnssec/main.py @@ -4,7 +4,7 @@ import subprocess -def run(boefje_meta: dict): +def run(boefje_meta: dict) -> list[tuple[set, bytes | str]]: input_ = boefje_meta["arguments"]["input"] domain = input_["name"] diff --git a/boefjes/boefjes/plugins/kat_manual/csv/normalize.py b/boefjes/boefjes/plugins/kat_manual/csv/normalize.py index a25bf73f0d1..262ba102cbe 100644 --- a/boefjes/boefjes/plugins/kat_manual/csv/normalize.py +++ b/boefjes/boefjes/plugins/kat_manual/csv/normalize.py @@ -7,7 +7,7 @@ from pydantic import ValidationError from boefjes.job_models import NormalizerDeclaration, NormalizerOutput -from octopoes.models import Reference +from octopoes.models import OOI, Reference from octopoes.models.ooi.dns.zone import Hostname from octopoes.models.ooi.network import IPAddressV4, IPAddressV6, Network from octopoes.models.ooi.web import URL @@ -30,7 +30,7 @@ def run(input_ooi: dict, raw: bytes) -> Iterable[NormalizerOutput]: yield from process_csv(raw, reference_cache) -def process_csv(csv_raw_data, reference_cache) -> Iterable[NormalizerOutput]: +def process_csv(csv_raw_data: bytes, reference_cache: dict) -> Iterable[NormalizerOutput]: csv_data = io.StringIO(csv_raw_data.decode("UTF-8")) object_type = get_object_type(csv_data) @@ -74,7 +74,7 @@ def get_object_type(csv_data: io.StringIO) -> str: def get_ooi_from_csv( - ooi_type_name: str, values: dict[str, str], reference_cache + ooi_type_name: str, values: dict[str, str], reference_cache: dict ) -> tuple[OOIType, list[NormalizerDeclaration]]: skip_properties = ("object_type", "scan_profile", "primary_key") @@ -85,7 +85,7 @@ def get_ooi_from_csv( if field not in skip_properties ] - kwargs = {} + kwargs: dict[str, Reference | str | None] = {} extra_declarations: list[NormalizerDeclaration] = [] for field, is_reference, required in ooi_fields: @@ -109,7 +109,7 @@ def get_ooi_from_csv( return ooi_type(**kwargs), extra_declarations -def get_or_create_reference(ooi_type_name: str, value: str | None, reference_cache): +def get_or_create_reference(ooi_type_name: str, value: str | None, reference_cache: dict) -> OOI: ooi_type_name = next(filter(lambda x: x.casefold() == ooi_type_name.casefold(), OOI_TYPES.keys())) # get from cache diff --git a/boefjes/boefjes/plugins/kat_masscan/main.py b/boefjes/boefjes/plugins/kat_masscan/main.py index 5ffe991aa17..63044ca1d72 100644 --- a/boefjes/boefjes/plugins/kat_masscan/main.py +++ b/boefjes/boefjes/plugins/kat_masscan/main.py @@ -10,7 +10,7 @@ FILE_PATH = "/tmp/output.json" # noqa: S108 -def run_masscan(target_ip) -> bytes: +def run_masscan(target_ip: str) -> bytes: """Run Masscan in Docker.""" client = docker.from_env() diff --git a/boefjes/boefjes/plugins/kat_nmap_tcp/main.py b/boefjes/boefjes/plugins/kat_nmap_tcp/main.py index c46d686134f..2c787aea076 100644 --- a/boefjes/boefjes/plugins/kat_nmap_tcp/main.py +++ b/boefjes/boefjes/plugins/kat_nmap_tcp/main.py @@ -5,7 +5,7 @@ TOP_PORTS_DEFAULT = 250 -def run(boefje_meta: dict): +def run(boefje_meta: dict) -> list[tuple[set, bytes | str]]: top_ports_key = "TOP_PORTS" if boefje_meta["boefje"]["id"] == "nmap-udp": top_ports_key = "TOP_PORTS_UDP" diff --git a/boefjes/boefjes/plugins/kat_security_txt_downloader/main.py b/boefjes/boefjes/plugins/kat_security_txt_downloader/main.py index 0ca93f35932..4dc1b6d806f 100644 --- a/boefjes/boefjes/plugins/kat_security_txt_downloader/main.py +++ b/boefjes/boefjes/plugins/kat_security_txt_downloader/main.py @@ -5,6 +5,7 @@ import requests from forcediphttpsadapter.adapters import ForcedIPHTTPSAdapter from requests import Session +from requests.models import Response from boefjes.job_models import BoefjeMeta @@ -53,7 +54,7 @@ def run(boefje_meta: BoefjeMeta) -> list[tuple[set, bytes | str]]: return [(set(), json.dumps(results))] -def do_request(hostname: str, session: Session, uri: str, useragent: str): +def do_request(hostname: str, session: Session, uri: str, useragent: str) -> Response: response = session.get( uri, headers={"Host": hostname, "User-Agent": useragent}, diff --git a/boefjes/boefjes/plugins/kat_snyk/check_version.py b/boefjes/boefjes/plugins/kat_snyk/check_version.py index 5625ba57f6b..146b2b02706 100644 --- a/boefjes/boefjes/plugins/kat_snyk/check_version.py +++ b/boefjes/boefjes/plugins/kat_snyk/check_version.py @@ -70,7 +70,7 @@ def check_version(version1: str, version2: str) -> VersionCheck: return check_version(version1_split[1], version2_split[1]) -def check_version_agains_versionlist(my_version: str, all_versions: list[str]): +def check_version_agains_versionlist(my_version: str, all_versions: list[str]) -> tuple[bool, list[str] | None]: lowerbound = all_versions.pop(0).strip() upperbound = None @@ -164,10 +164,12 @@ def check_version_agains_versionlist(my_version: str, all_versions: list[str]): return True, all_versions -def check_version_in(version: str, versions: str): +def check_version_in(version: str, versions: str) -> bool: if not version: return False - all_versions = versions.split(",") # Example: https://snyk.io/vuln/composer%3Awoocommerce%2Fwoocommerce-blocks + all_versions: list[str] | None = versions.split( + "," + ) # Example: https://snyk.io/vuln/composer%3Awoocommerce%2Fwoocommerce-blocks in_range = False while not in_range and all_versions: in_range, all_versions = check_version_agains_versionlist(version, all_versions) diff --git a/boefjes/boefjes/plugins/kat_webpage_analysis/main.py b/boefjes/boefjes/plugins/kat_webpage_analysis/main.py index e8accd74d7b..a3ff205bdcd 100644 --- a/boefjes/boefjes/plugins/kat_webpage_analysis/main.py +++ b/boefjes/boefjes/plugins/kat_webpage_analysis/main.py @@ -7,6 +7,7 @@ import requests from forcediphttpsadapter.adapters import ForcedIPHTTPSAdapter from requests import Session +from requests.models import Response from boefjes.job_models import BoefjeMeta @@ -62,9 +63,9 @@ def run(boefje_meta: BoefjeMeta) -> list[tuple[set, bytes | str]]: body_mimetypes.add(content_type) # Pick up the content type for the body from the server and split away encodings to make normalization easier - content_type = content_type.split(";") - if content_type[0] in ALLOWED_CONTENT_TYPES: - body_mimetypes.add(content_type[0]) + content_type_splitted = content_type.split(";") + if content_type_splitted[0] in ALLOWED_CONTENT_TYPES: + body_mimetypes.add(content_type_splitted[0]) # in case of a full response object, we hexdump to avoid issues with binary data or different encoding response_dump = json.dumps(create_response_object(response)) @@ -95,7 +96,7 @@ def create_response_object(response: requests.Response) -> dict: } -def do_request(hostname: str, session: Session, uri: str, useragent: str): +def do_request(hostname: str, session: Session, uri: str, useragent: str) -> Response: response = session.get( uri, headers={"Host": hostname, "User-Agent": useragent}, diff --git a/boefjes/boefjes/plugins/kat_webpage_capture/main.py b/boefjes/boefjes/plugins/kat_webpage_capture/main.py index 6ee7e8dd44e..afecb4bbc89 100644 --- a/boefjes/boefjes/plugins/kat_webpage_capture/main.py +++ b/boefjes/boefjes/plugins/kat_webpage_capture/main.py @@ -10,11 +10,11 @@ class WebpageCaptureException(Exception): """Exception raised when webpage capture fails.""" - def __init__(self, message, container_log=None): + def __init__(self, message: str, container_log: str): self.message = message self.container_log = container_log - def __str__(self): + def __str__(self) -> str: return str(self.message) + "\n\nContainer log:\n" + self.container_log diff --git a/boefjes/boefjes/plugins/models.py b/boefjes/boefjes/plugins/models.py index d02c255fd39..2c218909bae 100644 --- a/boefjes/boefjes/plugins/models.py +++ b/boefjes/boefjes/plugins/models.py @@ -89,7 +89,7 @@ def get_runnable_hash(path: Path) -> str: return folder_hash.hexdigest() -def _default_mime_types(boefje: Boefje): +def _default_mime_types(boefje: Boefje) -> set: mime_types = {f"boefje/{boefje.id}"} if boefje.version is not None: diff --git a/boefjes/boefjes/runtime_interfaces.py b/boefjes/boefjes/runtime_interfaces.py index 0a8375bdb86..70bacfb5554 100644 --- a/boefjes/boefjes/runtime_interfaces.py +++ b/boefjes/boefjes/runtime_interfaces.py @@ -4,7 +4,7 @@ class Handler: - def handle(self, item: BoefjeMeta | NormalizerMeta): + def handle(self, item: BoefjeMeta | NormalizerMeta) -> None: raise NotImplementedError() diff --git a/boefjes/boefjes/seed.py b/boefjes/boefjes/seed.py index ede27ccfbef..034088bcf26 100644 --- a/boefjes/boefjes/seed.py +++ b/boefjes/boefjes/seed.py @@ -6,7 +6,7 @@ from boefjes.sql.db_models import RepositoryInDB -def main(): +def main() -> None: session = sessionmaker(bind=get_engine())() try: diff --git a/boefjes/boefjes/sql/organisation_storage.py b/boefjes/boefjes/sql/organisation_storage.py index d772b890e3e..7f492b40a2e 100644 --- a/boefjes/boefjes/sql/organisation_storage.py +++ b/boefjes/boefjes/sql/organisation_storage.py @@ -89,5 +89,5 @@ def to_organisation(organisation_in_db: OrganisationInDB) -> Organisation: ) -def create_organisation_storage(session) -> SQLOrganisationStorage: +def create_organisation_storage(session: Session) -> SQLOrganisationStorage: return SQLOrganisationStorage(session, settings) diff --git a/boefjes/boefjes/sql/repository_storage.py b/boefjes/boefjes/sql/repository_storage.py index 9ba65769630..978ccb5fac9 100644 --- a/boefjes/boefjes/sql/repository_storage.py +++ b/boefjes/boefjes/sql/repository_storage.py @@ -67,5 +67,5 @@ def to_repository(repository_in_db: RepositoryInDB) -> Repository: ) -def create_repository_storage(session) -> SQLRepositoryStorage: +def create_repository_storage(session: Session) -> SQLRepositoryStorage: return SQLRepositoryStorage(session, settings) diff --git a/boefjes/boefjes/sql/setting_storage.py b/boefjes/boefjes/sql/setting_storage.py index f7c8d182fdc..17102e85395 100644 --- a/boefjes/boefjes/sql/setting_storage.py +++ b/boefjes/boefjes/sql/setting_storage.py @@ -67,7 +67,7 @@ def _db_instance_by_id(self, organisation_id: str, plugin_id: str) -> SettingsIn return instance -def create_setting_storage(session) -> SettingsStorage: +def create_setting_storage(session: Session) -> SettingsStorage: encrypter = create_encrypter() return SQLSettingsStorage(session, encrypter) diff --git a/bytes/bytes/api/metrics.py b/bytes/bytes/api/metrics.py index 48f937b4719..29c2d913f53 100644 --- a/bytes/bytes/api/metrics.py +++ b/bytes/bytes/api/metrics.py @@ -24,7 +24,7 @@ logger = logging.getLogger(__name__) -def ignore_arguments_key(meta_repository: MetaDataRepository): +def ignore_arguments_key(meta_repository: MetaDataRepository) -> str: return "" diff --git a/bytes/bytes/api/root.py b/bytes/bytes/api/root.py index d26468ee2e1..3b94b8e44e1 100644 --- a/bytes/bytes/api/root.py +++ b/bytes/bytes/api/root.py @@ -52,7 +52,7 @@ def health() -> ServiceHealth: @router.get("/metrics", dependencies=[Depends(authenticate_token)]) -def metrics(meta_repository: MetaDataRepository = Depends(create_meta_data_repository)): +def metrics(meta_repository: MetaDataRepository = Depends(create_meta_data_repository)) -> Response: collector_registry = get_registry(meta_repository) data = prometheus_client.generate_latest(collector_registry) diff --git a/bytes/bytes/api/router.py b/bytes/bytes/api/router.py index d6961c91dcd..0b091785d97 100644 --- a/bytes/bytes/api/router.py +++ b/bytes/bytes/api/router.py @@ -257,7 +257,7 @@ def get_raw_count_per_mime_type( return cached_counts_per_mime_type(meta_repository, query_filter) -def ignore_arguments_key(meta_repository: MetaDataRepository, query_filter: RawDataFilter): +def ignore_arguments_key(meta_repository: MetaDataRepository, query_filter: RawDataFilter) -> str: """Helper to not cache based on the stateful meta_repository, but only use the query parameters as a key.""" return query_filter.json() diff --git a/bytes/bytes/database/migrations/env.py b/bytes/bytes/database/migrations/env.py index 6c731d7d72f..0e267ea735e 100644 --- a/bytes/bytes/database/migrations/env.py +++ b/bytes/bytes/database/migrations/env.py @@ -6,7 +6,7 @@ # this is the Alembic Config object, which provides # access to the values within the .ini file in use. from bytes.config import get_settings -from bytes.database.db_models import SQL_BASE +from bytes.database.db import SQL_BASE config = context.config diff --git a/bytes/bytes/database/sql_meta_repository.py b/bytes/bytes/database/sql_meta_repository.py index 4faac6e4cab..3e7004250fb 100644 --- a/bytes/bytes/database/sql_meta_repository.py +++ b/bytes/bytes/database/sql_meta_repository.py @@ -229,7 +229,7 @@ def create_meta_data_repository() -> Iterator[MetaDataRepository]: class ObjectNotFoundException(Exception): - def __init__(self, cls: type[SQL_BASE], **kwargs): + def __init__(self, cls: type[SQL_BASE], **kwargs: str): super().__init__(f"The object of type {cls} was not found for query parameters {kwargs}") diff --git a/bytes/bytes/rabbitmq.py b/bytes/bytes/rabbitmq.py index 94413d27f3b..9dfa4ce024a 100644 --- a/bytes/bytes/rabbitmq.py +++ b/bytes/bytes/rabbitmq.py @@ -41,7 +41,7 @@ def publish(self, event: Event) -> None: logger.info("Published event [event_id=%s] to queue %s", event.event_id, queue_name) - def _check_connection(self): + def _check_connection(self) -> None: if self.connection.is_closed: self.connection = pika.BlockingConnection(pika.URLParameters(self.queue_uri)) self.channel = self.connection.channel() diff --git a/cveapi/cveapi.py b/cveapi/cveapi.py index 92c5722a07f..a65d83dacd4 100644 --- a/cveapi/cveapi.py +++ b/cveapi/cveapi.py @@ -13,7 +13,7 @@ logger = logging.getLogger("cveapi") -def download_files(directory, last_update, update_timestamp): +def download_files(directory: pathlib.Path, last_update: datetime | None, update_timestamp: datetime) -> None: index = 0 client = httpx.Client() error_count = 0 @@ -66,7 +66,7 @@ def download_files(directory, last_update, update_timestamp): logger.info("Downloaded new information of %s CVEs", response_json["totalResults"]) -def run(): +def run() -> None: loglevel = os.getenv("CVEAPI_LOGLEVEL", "INFO") numeric_level = getattr(logging, loglevel.upper(), None) if not isinstance(numeric_level, int): diff --git a/mula/scheduler/__init__.py b/mula/scheduler/__init__.py index 211a13feb2f..3bd881035da 100644 --- a/mula/scheduler/__init__.py +++ b/mula/scheduler/__init__.py @@ -1,4 +1,4 @@ from .app import App -from .version import version +from .version import __version__ -__version__ = version +__all__ = ["App", "__version__"] diff --git a/mula/scheduler/connectors/connector.py b/mula/scheduler/connectors/connector.py index 29f3cbc4351..d7832079e9f 100644 --- a/mula/scheduler/connectors/connector.py +++ b/mula/scheduler/connectors/connector.py @@ -1,6 +1,7 @@ import socket import time from collections.abc import Callable +from typing import Any import httpx import structlog @@ -47,7 +48,7 @@ def is_host_healthy(self, host: str, health_endpoint: str) -> bool: self.logger.warning("Exception: %s", exc) return False - def retry(self, func: Callable, *args, **kwargs) -> bool: + def retry(self, func: Callable, *args: Any, **kwargs: Any) -> bool: """Retry a function until it returns True. Args: diff --git a/mula/scheduler/connectors/listeners/__init__.py b/mula/scheduler/connectors/listeners/__init__.py index ba7a3fab3ef..85194b47cb5 100644 --- a/mula/scheduler/connectors/listeners/__init__.py +++ b/mula/scheduler/connectors/listeners/__init__.py @@ -1,3 +1,5 @@ from .listeners import Listener, RabbitMQ from .raw_data import RawData from .scan_profile import ScanProfileMutation + +__all__ = ["Listener", "RabbitMQ", "RawData", "ScanProfileMutation"] diff --git a/mula/scheduler/connectors/listeners/listeners.py b/mula/scheduler/connectors/listeners/listeners.py index 6fde9c54e08..891604f3d39 100644 --- a/mula/scheduler/connectors/listeners/listeners.py +++ b/mula/scheduler/connectors/listeners/listeners.py @@ -164,7 +164,7 @@ def callback( # Submit the message to the thread pool executor self.executor.submit(self.dispatch, channel, method.delivery_tag, body) - def dispatch(self, channel, delivery_tag, body: bytes) -> None: + def dispatch(self, channel: pika.channel.Channel, delivery_tag: int, body: bytes) -> None: # Check if we still have a connection if not self.connection: self.logger.debug("No connection available, cannot dispatch message!") diff --git a/mula/scheduler/connectors/services/__init__.py b/mula/scheduler/connectors/services/__init__.py index 2082f2c95e6..b969ff08778 100644 --- a/mula/scheduler/connectors/services/__init__.py +++ b/mula/scheduler/connectors/services/__init__.py @@ -3,3 +3,5 @@ from .octopoes import Octopoes from .rocky import Rocky from .services import HTTPService + +__all__ = ["Bytes", "Katalogus", "Octopoes", "Rocky", "HTTPService"] diff --git a/mula/scheduler/connectors/services/katalogus.py b/mula/scheduler/connectors/services/katalogus.py index 7f9062bdd61..356fa08c6cb 100644 --- a/mula/scheduler/connectors/services/katalogus.py +++ b/mula/scheduler/connectors/services/katalogus.py @@ -137,7 +137,7 @@ def get_boefje(self, boefje_id: str) -> Boefje: return Boefje(**response.json()) @exception_handler - def get_organisation(self, organisation_id) -> Organisation: + def get_organisation(self, organisation_id: str) -> Organisation: url = f"{self.host}/v1/organisations/{organisation_id}" response = self.get(url) return Organisation(**response.json()) diff --git a/mula/scheduler/context/__init__.py b/mula/scheduler/context/__init__.py index 61627b2c729..ae686bb6e0c 100644 --- a/mula/scheduler/context/__init__.py +++ b/mula/scheduler/context/__init__.py @@ -1 +1,3 @@ from .context import AppContext + +__all__ = ["AppContext"] diff --git a/mula/scheduler/context/context.py b/mula/scheduler/context/context.py index 4369132a51a..db5a9a7083e 100644 --- a/mula/scheduler/context/context.py +++ b/mula/scheduler/context/context.py @@ -34,6 +34,9 @@ class AppContext: the schedulers. """ + metrics_qsize: Gauge + metrics_task_status_counts: Gauge + def __init__(self) -> None: """Initializer of the AppContext class.""" self.config: settings.Settings = settings.Settings() diff --git a/mula/scheduler/models/__init__.py b/mula/scheduler/models/__init__.py index 894fb0de555..ce0fe7f762e 100644 --- a/mula/scheduler/models/__init__.py +++ b/mula/scheduler/models/__init__.py @@ -1,12 +1,38 @@ from .base import Base from .boefje import Boefje, BoefjeMeta -from .events import RawData, RawDataReceivedEvent +from .events import RawDataReceivedEvent from .health import ServiceHealth from .normalizer import Normalizer -from .ooi import OOI, MutationOperationType, ScanProfile, ScanProfileMutation +from .ooi import OOI, MutationOperationType, ScanProfileMutation from .organisation import Organisation from .plugin import Plugin from .queue import PrioritizedItem, PrioritizedItemDB, Queue +from .raw_data import RawData from .request import PrioritizedItemRequest from .scheduler import Scheduler from .tasks import BoefjeTask, NormalizerTask, Task, TaskDB, TaskStatus + +__all__ = [ + "Base", + "Boefje", + "BoefjeMeta", + "RawData", + "RawDataReceivedEvent", + "ServiceHealth", + "Normalizer", + "OOI", + "MutationOperationType", + "ScanProfileMutation", + "Organisation", + "Plugin", + "PrioritizedItem", + "PrioritizedItemDB", + "Queue", + "PrioritizedItemRequest", + "Scheduler", + "BoefjeTask", + "NormalizerTask", + "Task", + "TaskDB", + "TaskStatus", +] diff --git a/mula/scheduler/models/tasks.py b/mula/scheduler/models/tasks.py index 52c229ecd39..4bc1bf925c6 100644 --- a/mula/scheduler/models/tasks.py +++ b/mula/scheduler/models/tasks.py @@ -61,7 +61,7 @@ class Task(BaseModel): modified_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) - def __repr__(self): + def __repr__(self) -> str: return f"Task(id={self.id}, scheduler_id={self.scheduler_id}, type={self.type}, status={self.status})" diff --git a/mula/scheduler/queues/__init__.py b/mula/scheduler/queues/__init__.py index 50d51f48471..446d418ffd2 100644 --- a/mula/scheduler/queues/__init__.py +++ b/mula/scheduler/queues/__init__.py @@ -2,3 +2,13 @@ from .errors import InvalidPrioritizedItemError, NotAllowedError, QueueEmptyError, QueueFullError from .normalizer import NormalizerPriorityQueue from .pq import PriorityQueue + +__all__ = [ + "BoefjePriorityQueue", + "InvalidPrioritizedItemError", + "NotAllowedError", + "QueueEmptyError", + "QueueFullError", + "NormalizerPriorityQueue", + "PriorityQueue", +] diff --git a/mula/scheduler/rankers/__init__.py b/mula/scheduler/rankers/__init__.py index 49a565268ec..4e1c36adbf9 100644 --- a/mula/scheduler/rankers/__init__.py +++ b/mula/scheduler/rankers/__init__.py @@ -1,3 +1,5 @@ from .boefje import BoefjeRanker from .normalizer import NormalizerRanker from .ranker import Ranker + +__all__ = ["BoefjeRanker", "NormalizerRanker", "Ranker"] diff --git a/mula/scheduler/schedulers/__init__.py b/mula/scheduler/schedulers/__init__.py index 5614508b532..bebfa613c2b 100644 --- a/mula/scheduler/schedulers/__init__.py +++ b/mula/scheduler/schedulers/__init__.py @@ -1,3 +1,5 @@ from .boefje import BoefjeScheduler from .normalizer import NormalizerScheduler from .scheduler import Scheduler + +__all__ = ["BoefjeScheduler", "NormalizerScheduler", "Scheduler"] diff --git a/mula/scheduler/schedulers/boefje.py b/mula/scheduler/schedulers/boefje.py index 516ed777f26..d8f23b8d359 100644 --- a/mula/scheduler/schedulers/boefje.py +++ b/mula/scheduler/schedulers/boefje.py @@ -782,7 +782,7 @@ def has_grace_period_passed(self, task: BoefjeTask) -> bool: return True - def get_boefjes_for_ooi(self, ooi) -> list[Plugin]: + def get_boefjes_for_ooi(self, ooi: OOI) -> list[Plugin]: """Get available all boefjes (enabled and disabled) for an ooi. Args: diff --git a/mula/scheduler/server/__init__.py b/mula/scheduler/server/__init__.py index b7f2cf59516..09ed39ca17f 100644 --- a/mula/scheduler/server/__init__.py +++ b/mula/scheduler/server/__init__.py @@ -1 +1,3 @@ from .server import Server + +__all__ = ["Server"] diff --git a/mula/scheduler/storage/__init__.py b/mula/scheduler/storage/__init__.py index 4f6d7155872..93a01287117 100644 --- a/mula/scheduler/storage/__init__.py +++ b/mula/scheduler/storage/__init__.py @@ -2,3 +2,11 @@ from .pq_store import PriorityQueueStore from .storage import DBConn, retry from .task_store import TaskStore + +__all__ = [ + "apply_filter", + "PriorityQueueStore", + "DBConn", + "retry", + "TaskStore", +] diff --git a/mula/scheduler/storage/filters/__init__.py b/mula/scheduler/storage/filters/__init__.py index ddf32f56ef3..eb44f2d36e1 100644 --- a/mula/scheduler/storage/filters/__init__.py +++ b/mula/scheduler/storage/filters/__init__.py @@ -1,3 +1,5 @@ from .casting import cast_expression from .filters import Filter, FilterRequest from .functions import apply_filter + +__all__ = ["cast_expression", "Filter", "FilterRequest", "apply_filter"] diff --git a/mula/scheduler/storage/filters/functions.py b/mula/scheduler/storage/filters/functions.py index 9158643e6a6..5d343ffb423 100644 --- a/mula/scheduler/storage/filters/functions.py +++ b/mula/scheduler/storage/filters/functions.py @@ -1,4 +1,5 @@ import sqlalchemy +from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm.query import Query from sqlalchemy.sql.elements import BinaryExpression @@ -9,7 +10,7 @@ from .operators import FILTER_OPERATORS -def apply_filter(entity, query: Query, filter_request: FilterRequest) -> Query: +def apply_filter(entity: DeclarativeBase, query: Query, filter_request: FilterRequest) -> Query: """Apply the filter criteria to a SQLAlchemy query. Args: diff --git a/mula/scheduler/storage/pq_store.py b/mula/scheduler/storage/pq_store.py index 3cb300102f2..ec9b826433b 100644 --- a/mula/scheduler/storage/pq_store.py +++ b/mula/scheduler/storage/pq_store.py @@ -75,7 +75,7 @@ def remove(self, scheduler_id: str, item_id: UUID) -> None: ) @retry() - def get(self, scheduler_id, item_id: UUID) -> models.PrioritizedItem | None: + def get(self, scheduler_id: str, item_id: UUID) -> models.PrioritizedItem | None: with self.dbconn.session.begin() as session: item_orm = ( session.query(models.PrioritizedItemDB) diff --git a/mula/scheduler/storage/storage.py b/mula/scheduler/storage/storage.py index 405090b72d3..719c5299785 100644 --- a/mula/scheduler/storage/storage.py +++ b/mula/scheduler/storage/storage.py @@ -1,6 +1,7 @@ import json import logging import time +from collections.abc import Callable from functools import partial, wraps import sqlalchemy @@ -39,7 +40,7 @@ def __init__(self, dsn: str) -> None: ) -def retry(max_retries: int = 3, retry_delay: float = 5.0): +def retry(max_retries: int = 3, retry_delay: float = 5.0) -> Callable: def decorator(func): @wraps(func) def wrapper(*args, **kwargs): diff --git a/mula/scheduler/utils/__init__.py b/mula/scheduler/utils/__init__.py index 47047ddb189..cb8b2af51d1 100644 --- a/mula/scheduler/utils/__init__.py +++ b/mula/scheduler/utils/__init__.py @@ -2,3 +2,5 @@ from .dict_utils import ExpiredError, ExpiringDict, deep_get from .functions import remove_trailing_slash from .thread import ThreadRunner + +__all__ = ["GUID", "ExpiredError", "ExpiringDict", "deep_get", "remove_trailing_slash", "ThreadRunner"] diff --git a/octopoes/bits/check_csp_header/bit.py b/octopoes/bits/check_csp_header/bit.py index 72bc8be7bc4..1304327fd72 100644 --- a/octopoes/bits/check_csp_header/bit.py +++ b/octopoes/bits/check_csp_header/bit.py @@ -1,5 +1,5 @@ from bits.definitions import BitDefinition -from octopoes.models.types import HTTPHeader +from octopoes.models.ooi.web import HTTPHeader BIT = BitDefinition( id="check-csp-header", diff --git a/octopoes/bits/check_csp_header/check_csp_header.py b/octopoes/bits/check_csp_header/check_csp_header.py index c621a652e74..46c76a3387f 100644 --- a/octopoes/bits/check_csp_header/check_csp_header.py +++ b/octopoes/bits/check_csp_header/check_csp_header.py @@ -5,7 +5,7 @@ from octopoes.models import OOI, Reference from octopoes.models.ooi.findings import Finding, KATFindingType -from octopoes.models.types import HTTPHeader +from octopoes.models.ooi.web import HTTPHeader NON_DECIMAL_FILTER = re.compile(r"[^\d.]+") diff --git a/octopoes/bits/check_cve_2021_41773/bit.py b/octopoes/bits/check_cve_2021_41773/bit.py index 367183e5f8d..3e32a458c9c 100644 --- a/octopoes/bits/check_cve_2021_41773/bit.py +++ b/octopoes/bits/check_cve_2021_41773/bit.py @@ -1,5 +1,5 @@ from bits.definitions import BitDefinition -from octopoes.models.types import HTTPHeader +from octopoes.models.ooi.web import HTTPHeader BIT = BitDefinition( id="check_cve_2021_41773", diff --git a/octopoes/bits/check_cve_2021_41773/check_cve_2021_41773.py b/octopoes/bits/check_cve_2021_41773/check_cve_2021_41773.py index c55e786b7fc..89ac0aed4b4 100644 --- a/octopoes/bits/check_cve_2021_41773/check_cve_2021_41773.py +++ b/octopoes/bits/check_cve_2021_41773/check_cve_2021_41773.py @@ -3,7 +3,7 @@ from octopoes.models import OOI from octopoes.models.ooi.findings import CVEFindingType, Finding -from octopoes.models.types import HTTPHeader +from octopoes.models.ooi.web import HTTPHeader def run(input_ooi: HTTPHeader, additional_oois: list, config: dict[str, Any]) -> Iterator[OOI]: diff --git a/octopoes/bits/check_hsts_header/bit.py b/octopoes/bits/check_hsts_header/bit.py index 5e8436084f6..6b98f2d86a9 100644 --- a/octopoes/bits/check_hsts_header/bit.py +++ b/octopoes/bits/check_hsts_header/bit.py @@ -1,5 +1,5 @@ from bits.definitions import BitDefinition -from octopoes.models.types import HTTPHeader +from octopoes.models.ooi.web import HTTPHeader BIT = BitDefinition( id="check-hsts-header", diff --git a/octopoes/bits/check_hsts_header/check_hsts_header.py b/octopoes/bits/check_hsts_header/check_hsts_header.py index d7e2dcf0c12..87d3090a8f8 100644 --- a/octopoes/bits/check_hsts_header/check_hsts_header.py +++ b/octopoes/bits/check_hsts_header/check_hsts_header.py @@ -4,7 +4,7 @@ from octopoes.models import OOI, Reference from octopoes.models.ooi.findings import Finding, KATFindingType -from octopoes.models.types import HTTPHeader +from octopoes.models.ooi.web import HTTPHeader def run(input_ooi: HTTPHeader, additional_oois: list, config: dict[str, Any]) -> Iterator[OOI]: diff --git a/octopoes/bits/cipher_classification/cipher_classification.py b/octopoes/bits/cipher_classification/cipher_classification.py index 8a4594613cf..0235fd8fb77 100644 --- a/octopoes/bits/cipher_classification/cipher_classification.py +++ b/octopoes/bits/cipher_classification/cipher_classification.py @@ -1,6 +1,7 @@ import csv from collections.abc import Iterator from pathlib import Path +from typing import Any from octopoes.models import OOI from octopoes.models.ooi.findings import Finding, KATFindingType @@ -13,7 +14,7 @@ } -def get_severity_and_reasons(cipher_suite) -> list[tuple[str, str]]: +def get_severity_and_reasons(cipher_suite: str) -> list[tuple[str, str]]: with Path.open(Path(__file__).parent / "list-ciphers-openssl-with-finding-type.csv", newline="") as csvfile: reader = csv.DictReader(csvfile) data = [{k.strip(): v.strip() for k, v in row.items() if k} for row in reader] @@ -76,7 +77,7 @@ def get_highest_severity_and_all_reasons(cipher_suites: dict) -> tuple[str, str] return highest_severity, all_reasons_str -def run(input_ooi: TLSCipher, additional_oois, config) -> Iterator[OOI]: +def run(input_ooi: TLSCipher, additional_oois: list, config: dict[str, Any]) -> Iterator[OOI]: # Get the highest severity and all reasons for the cipher suite highest_severity, all_reasons = get_highest_severity_and_all_reasons(input_ooi.suites) diff --git a/octopoes/bits/missing_certificate/missing_certificate.py b/octopoes/bits/missing_certificate/missing_certificate.py index 04721d51dc6..ae08a2a214e 100644 --- a/octopoes/bits/missing_certificate/missing_certificate.py +++ b/octopoes/bits/missing_certificate/missing_certificate.py @@ -6,7 +6,7 @@ from octopoes.models.ooi.web import Website -def run(input_ooi: Website, additional_oois, config: dict[str, Any]) -> Iterator[OOI]: +def run(input_ooi: Website, additional_oois: list, config: dict[str, Any]) -> Iterator[OOI]: if input_ooi.ip_service.tokenized.service.name.lower() != "https": return diff --git a/octopoes/bits/nxdomain_flag/bit.py b/octopoes/bits/nxdomain_flag/bit.py index 13464c56483..c094f5903bc 100644 --- a/octopoes/bits/nxdomain_flag/bit.py +++ b/octopoes/bits/nxdomain_flag/bit.py @@ -1,6 +1,6 @@ from bits.definitions import BitDefinition, BitParameterDefinition +from octopoes.models.ooi.dns.records import NXDOMAIN from octopoes.models.ooi.dns.zone import Hostname -from octopoes.models.types import NXDOMAIN BIT = BitDefinition( id="nxdomain-flag", diff --git a/octopoes/bits/nxdomain_flag/nxdomain_flag.py b/octopoes/bits/nxdomain_flag/nxdomain_flag.py index b42685d9210..a92c6cc85cf 100644 --- a/octopoes/bits/nxdomain_flag/nxdomain_flag.py +++ b/octopoes/bits/nxdomain_flag/nxdomain_flag.py @@ -2,9 +2,9 @@ from typing import Any from octopoes.models import OOI +from octopoes.models.ooi.dns.records import NXDOMAIN from octopoes.models.ooi.dns.zone import Hostname from octopoes.models.ooi.findings import Finding, KATFindingType -from octopoes.models.types import NXDOMAIN def run(input_ooi: Hostname, additional_oois: list[NXDOMAIN], config: dict[str, Any]) -> Iterator[OOI]: diff --git a/octopoes/bits/nxdomain_header_flag/bit.py b/octopoes/bits/nxdomain_header_flag/bit.py index 3d4883adec1..296ac1a3580 100644 --- a/octopoes/bits/nxdomain_header_flag/bit.py +++ b/octopoes/bits/nxdomain_header_flag/bit.py @@ -1,7 +1,7 @@ from bits.definitions import BitDefinition, BitParameterDefinition +from octopoes.models.ooi.dns.records import NXDOMAIN from octopoes.models.ooi.dns.zone import Hostname from octopoes.models.ooi.web import HTTPHeaderHostname -from octopoes.models.types import NXDOMAIN BIT = BitDefinition( id="nxdomain-header-flag", diff --git a/octopoes/bits/nxdomain_header_flag/nxdomain_header_flag.py b/octopoes/bits/nxdomain_header_flag/nxdomain_header_flag.py index f55c6b54561..7a4f78a004e 100644 --- a/octopoes/bits/nxdomain_header_flag/nxdomain_header_flag.py +++ b/octopoes/bits/nxdomain_header_flag/nxdomain_header_flag.py @@ -2,10 +2,10 @@ from typing import Any from octopoes.models import OOI +from octopoes.models.ooi.dns.records import NXDOMAIN from octopoes.models.ooi.dns.zone import Hostname from octopoes.models.ooi.findings import Finding, KATFindingType from octopoes.models.ooi.web import HTTPHeaderHostname -from octopoes.models.types import NXDOMAIN def run( diff --git a/octopoes/bits/oois_in_headers/bit.py b/octopoes/bits/oois_in_headers/bit.py index dbdd9b7981e..515554738f3 100644 --- a/octopoes/bits/oois_in_headers/bit.py +++ b/octopoes/bits/oois_in_headers/bit.py @@ -1,5 +1,5 @@ from bits.definitions import BitDefinition -from octopoes.models.types import HTTPHeader +from octopoes.models.ooi.web import HTTPHeader BIT = BitDefinition( id="oois-in-headers", diff --git a/octopoes/bits/oois_in_headers/oois_in_headers.py b/octopoes/bits/oois_in_headers/oois_in_headers.py index a6e67d15b08..6e48863952c 100644 --- a/octopoes/bits/oois_in_headers/oois_in_headers.py +++ b/octopoes/bits/oois_in_headers/oois_in_headers.py @@ -7,8 +7,8 @@ from octopoes.models import OOI from octopoes.models.ooi.dns.zone import Hostname -from octopoes.models.ooi.web import HTTPHeaderHostname -from octopoes.models.types import URL, HTTPHeader, HTTPHeaderURL, Network +from octopoes.models.ooi.network import Network +from octopoes.models.ooi.web import URL, HTTPHeader, HTTPHeaderHostname, HTTPHeaderURL def is_url(input_str): diff --git a/octopoes/bits/retire_js/retire_js.py b/octopoes/bits/retire_js/retire_js.py index 5b57382ffc4..b79a06ec6fc 100644 --- a/octopoes/bits/retire_js/retire_js.py +++ b/octopoes/bits/retire_js/retire_js.py @@ -7,7 +7,7 @@ from octopoes.models import OOI from octopoes.models.ooi.findings import CVEFindingType, Finding, RetireJSFindingType from octopoes.models.ooi.software import Software, SoftwareInstance -from packaging import version +from packaging.version import parse def run(input_ooi: Software, additional_oois: list[SoftwareInstance], config: dict[str, Any]) -> Iterator[OOI]: @@ -40,7 +40,7 @@ def run(input_ooi: Software, additional_oois: list[SoftwareInstance], config: di ) -def _check_vulnerabilities(name, package_version: str, known_vulnerabilities: dict) -> dict[str, list[str]]: +def _check_vulnerabilities(name: str, package_version: str, known_vulnerabilities: dict) -> dict[str, list[str]]: vulnerabilities: dict[str, list[str]] = {"CVE": [], "RetireJS": []} processed_name = _process_name(name) found_brands = [brand for brand in known_vulnerabilities if processed_name == _process_name(brand)] @@ -70,10 +70,10 @@ def _hash_identifiers(identifiers: dict[str, str | list[str]]) -> str: def _check_versions(package_version: str, known_vulnerability: dict) -> bool: - below = version.parse(package_version) < version.parse(known_vulnerability["below"]) + below = parse(package_version) < parse(known_vulnerability["below"]) # Some packages are only vulnerable below a version and not above above = ( - version.parse(package_version) >= version.parse(known_vulnerability["atOrAbove"]) + parse(package_version) >= parse(known_vulnerability["atOrAbove"]) if "atOrAbove" in known_vulnerability else True ) diff --git a/octopoes/bits/runner.py b/octopoes/bits/runner.py index da851076c34..931297ae5df 100644 --- a/octopoes/bits/runner.py +++ b/octopoes/bits/runner.py @@ -1,7 +1,7 @@ from collections.abc import Iterator from importlib import import_module from inspect import isfunction, signature -from typing import Any, Protocol +from typing import Any from bits.definitions import BitDefinition from octopoes.models import OOI @@ -11,15 +11,11 @@ class ModuleException(Exception): """General error for modules""" -class Runnable(Protocol): - def run(self, *args, **kwargs) -> Any: ... - - class BitRunner: def __init__(self, bit_definition: BitDefinition): self.module = bit_definition.module - def run(self, *args, **kwargs) -> list[OOI]: + def run(self, *args: Any, **kwargs: Any) -> list[OOI]: module = import_module(self.module) if not hasattr(module, "run") or not isfunction(module.run): @@ -31,7 +27,7 @@ def run(self, *args, **kwargs) -> list[OOI]: ) return list(module.run(*args, **kwargs)) - def __str__(self): + def __str__(self) -> str: return f"BitRunner {self.module}" diff --git a/octopoes/bits/spf_discovery/spf_discovery.py b/octopoes/bits/spf_discovery/spf_discovery.py index 36a10f16003..a094cc28cfb 100644 --- a/octopoes/bits/spf_discovery/spf_discovery.py +++ b/octopoes/bits/spf_discovery/spf_discovery.py @@ -10,7 +10,7 @@ from octopoes.models.ooi.network import IPAddressV4, IPAddressV6, Network -def run(input_ooi: DNSTXTRecord, additional_oois, config: dict[str, Any]) -> Iterator[OOI]: +def run(input_ooi: DNSTXTRecord, additional_oois: list, config: dict[str, Any]) -> Iterator[OOI]: if input_ooi.value.startswith("v=spf1"): spf_value = input_ooi.value.replace("%(d)", input_ooi.hostname.tokenized.name) parsed = parse(spf_value) diff --git a/octopoes/octopoes/api/router.py b/octopoes/octopoes/api/router.py index eb1afbd4e41..3dd6fe0c55d 100644 --- a/octopoes/octopoes/api/router.py +++ b/octopoes/octopoes/api/router.py @@ -122,7 +122,7 @@ def list_objects( scan_profile_type: set[ScanProfileType] = Query(DEFAULT_SCAN_PROFILE_TYPE_FILTER), offset: int = 0, limit: int = 20, -): +) -> Paginated[OOI]: return octopoes.list_ooi(types, valid_time, offset, limit, scan_level, scan_profile_type) @@ -134,7 +134,7 @@ def query( valid_time: datetime = Depends(extract_valid_time), offset: int = DEFAULT_OFFSET, limit: int = DEFAULT_LIMIT, -): +) -> list[OOI | tuple]: object_path = ObjectPath.parse(path) xtdb_query = XTDBQuery.from_path(object_path).offset(offset).limit(limit) @@ -150,7 +150,7 @@ def query_many( sources: list[str] = Query(), octopoes: OctopoesService = Depends(octopoes_service), valid_time: datetime = Depends(extract_valid_time), -): +) -> list[OOI | tuple]: """ How does this work and why do we do this? @@ -195,7 +195,7 @@ def load_objects_bulk( octopoes: OctopoesService = Depends(octopoes_service), valid_time: datetime = Depends(extract_valid_time), references: set[Reference] = Depends(extract_references), -): +) -> dict[str, OOI]: return octopoes.ooi_repository.load_bulk(references, valid_time) @@ -204,7 +204,7 @@ def get_object( octopoes: OctopoesService = Depends(octopoes_service), valid_time: datetime = Depends(extract_valid_time), reference: Reference = Depends(extract_reference), -): +) -> OOI: return octopoes.get_ooi(reference, valid_time) @@ -236,7 +236,7 @@ def list_random_objects( valid_time: datetime = Depends(extract_valid_time), amount: int = 1, scan_level: set[ScanLevel] = Query(DEFAULT_SCAN_LEVEL_FILTER), -): +) -> list[OOI]: return octopoes.list_random_ooi(valid_time, amount, scan_level) @@ -426,8 +426,8 @@ def get_scan_profile_inheritance( def list_findings( exclude_muted: bool = True, only_muted: bool = False, - offset=DEFAULT_OFFSET, - limit=DEFAULT_LIMIT, + offset: int = DEFAULT_OFFSET, + limit: int = DEFAULT_LIMIT, octopoes: OctopoesService = Depends(octopoes_service), valid_time: datetime = Depends(extract_valid_time), severities: set[RiskLevelSeverity] = Query(DEFAULT_SEVERITY_FILTER), diff --git a/octopoes/octopoes/connector/__init__.py b/octopoes/octopoes/connector/__init__.py index a5ca0237d37..5f56b2ee59e 100644 --- a/octopoes/octopoes/connector/__init__.py +++ b/octopoes/octopoes/connector/__init__.py @@ -1,12 +1,8 @@ -# Keep for backwards compatibility -from octopoes.models.exception import ObjectNotFoundException - - class ConnectorException(Exception): def __init__(self, value: str): self.value = value - def __str__(self): + def __str__(self) -> str: return self.value diff --git a/octopoes/octopoes/connector/octopoes.py b/octopoes/octopoes/connector/octopoes.py index ec8ff1393be..2ec7dcbbc9a 100644 --- a/octopoes/octopoes/connector/octopoes.py +++ b/octopoes/octopoes/connector/octopoes.py @@ -1,5 +1,5 @@ import json -from collections.abc import Sequence, Set +from collections.abc import Iterable, Sequence, Set from datetime import datetime from uuid import UUID @@ -173,7 +173,7 @@ def save_affirmation(self, affirmation: Affirmation) -> None: content=affirmation.model_dump_json(), ) - def save_scan_profile(self, scan_profile: ScanProfile, valid_time: datetime): + def save_scan_profile(self, scan_profile: ScanProfile, valid_time: datetime) -> None: params = {"valid_time": str(valid_time)} self.session.put( f"/{self.client}/scan_profiles", @@ -221,7 +221,7 @@ def count_findings_by_severity(self, valid_time: datetime) -> dict[str, int]: def list_findings( self, - severities: set[RiskLevelSeverity], + severities: Iterable[RiskLevelSeverity], valid_time: datetime, exclude_muted: bool = True, only_muted: bool = False, @@ -239,9 +239,9 @@ def list_findings( res = self.session.get(f"/{self.client}/findings", params=params) return TypeAdapter(Paginated[Finding]).validate_json(res.content) - def load_objects_bulk(self, references: set[Reference], valid_time): + def load_objects_bulk(self, references: set[Reference], valid_time: datetime) -> dict[Reference, OOIType]: params = { - "valid_time": valid_time, + "valid_time": str(valid_time), } res = self.session.post( f"/{self.client}/objects/load_bulk", params=params, json=[str(ref) for ref in references] diff --git a/octopoes/octopoes/core/app.py b/octopoes/octopoes/core/app.py index 514742a09e9..9dd7e1ff323 100644 --- a/octopoes/octopoes/core/app.py +++ b/octopoes/octopoes/core/app.py @@ -21,7 +21,7 @@ def get_xtdb_client(base_uri: str, client: str) -> XTDBHTTPClient: return XTDBHTTPClient(f"{base_uri}/_xtdb", client) -def close_rabbit_channel(queue_uri: str): +def close_rabbit_channel(queue_uri: str) -> None: rabbit_channel = get_rabbit_channel(queue_uri) try: diff --git a/octopoes/octopoes/core/service.py b/octopoes/octopoes/core/service.py index c0494f36d9d..efb4e88f29a 100644 --- a/octopoes/octopoes/core/service.py +++ b/octopoes/octopoes/core/service.py @@ -136,7 +136,7 @@ def get_ooi_tree( valid_time: datetime, search_types: set[type[OOI]] | None = None, depth: int = 1, - ): + ) -> ReferenceTree: tree = self.ooi_repository.get_tree(reference, valid_time, search_types, depth) self._populate_scan_profiles(tree.store.values(), valid_time) return tree @@ -207,7 +207,7 @@ def _run_inference(self, origin: Origin, valid_time: datetime) -> None: self.save_origin(origin, resulting_oois, valid_time) @staticmethod - def check_path_level(path_level: int | None, current_level: int): + def check_path_level(path_level: int | None, current_level: int) -> bool: return path_level is not None and path_level >= current_level def recalculate_scan_profiles(self, valid_time: datetime) -> None: @@ -331,7 +331,7 @@ def recalculate_scan_profiles(self, valid_time: datetime) -> None: ) logger.info("Recalculated scan profiles") - def process_event(self, event: DBEvent): + def process_event(self, event: DBEvent) -> None: # handle event event_handler_name = f"_on_{event.operation_type.value}_{event.entity_type}" handler: Callable[[DBEvent], None] | None = getattr(self, event_handler_name) diff --git a/octopoes/octopoes/models/__init__.py b/octopoes/octopoes/models/__init__.py index 02c970eb36e..516315ba6f6 100644 --- a/octopoes/octopoes/models/__init__.py +++ b/octopoes/octopoes/models/__init__.py @@ -43,12 +43,12 @@ def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHa return core_schema.with_info_after_validator_function(cls.validate, core_schema.str_schema()) @classmethod - def validate(cls, v, info: ValidationInfo): + def validate(cls, v: str, info: ValidationInfo) -> Any: if not isinstance(v, str): raise TypeError("string required") return cls(str(v)) - def __repr__(self): + def __repr__(self) -> str: return f"Reference({super().__repr__()})" @classmethod @@ -122,7 +122,7 @@ class OOI(BaseModel): def model_post_init(self, __context: Any) -> None: # noqa: F841 self.primary_key = self.primary_key or f"{self.get_object_type()}|{self.natural_key}" - def __str__(self): + def __str__(self) -> str: return self.primary_key @classmethod @@ -189,11 +189,11 @@ def get_reverse_relation_name(cls, attr: str) -> str: return cls._reverse_relation_names.get(attr, f"{cls.get_object_type()}_{attr}") @classmethod - def get_tokenized_primary_key(cls, natural_key: str): + def get_tokenized_primary_key(cls, natural_key: str) -> PrimaryKeyToken: token_tree = build_token_tree(cls) natural_key_parts = natural_key.split("|") - def hydrate(node) -> dict | str: + def hydrate(node: dict[str, dict | str]) -> dict | str: for key, value in node.items(): if isinstance(value, dict): node[key] = hydrate(value) @@ -228,10 +228,10 @@ def format_id_short(id_: str) -> str: class PrimaryKeyToken(RootModel): root: dict[str, str | PrimaryKeyToken] - def __getattr__(self, item) -> Any: + def __getattr__(self, item: str) -> Any: return self.root[item] - def __getitem__(self, item) -> Any: + def __getitem__(self, item: str) -> Any: return self.root[item] diff --git a/octopoes/octopoes/models/ooi/dns/records.py b/octopoes/octopoes/models/ooi/dns/records.py index dc510f46000..e204907cef9 100644 --- a/octopoes/octopoes/models/ooi/dns/records.py +++ b/octopoes/octopoes/models/ooi/dns/records.py @@ -166,7 +166,7 @@ class CAATAGS(Enum): ISSUEVMC = "issuevmc" ISSUEMAIL = "issuemail" - def __str__(self): + def __str__(self) -> str: return self.value diff --git a/octopoes/octopoes/models/ooi/findings.py b/octopoes/octopoes/models/ooi/findings.py index a8ca1cd6456..ed148770d9d 100644 --- a/octopoes/octopoes/models/ooi/findings.py +++ b/octopoes/octopoes/models/ooi/findings.py @@ -27,7 +27,7 @@ class RiskLevelSeverity(Enum): def __gt__(self, other: "RiskLevelSeverity") -> bool: return severity_order.index(self.value) > severity_order.index(other.value) - def __str__(self): + def __str__(self) -> str: return self.value diff --git a/octopoes/octopoes/models/ooi/network.py b/octopoes/octopoes/models/ooi/network.py index c36a8da02ae..e661f635cce 100644 --- a/octopoes/octopoes/models/ooi/network.py +++ b/octopoes/octopoes/models/ooi/network.py @@ -90,7 +90,7 @@ class IPPort(OOI): _information_value = ["protocol", "port"] @classmethod - def format_reference_human_readable(cls, reference: Reference): + def format_reference_human_readable(cls, reference: Reference) -> str: tokenized = reference.tokenized return f"{tokenized.address.address}:{tokenized.port}/{tokenized.protocol}" diff --git a/octopoes/octopoes/models/origin.py b/octopoes/octopoes/models/origin.py index 8b87d332810..3d956ad0570 100644 --- a/octopoes/octopoes/models/origin.py +++ b/octopoes/octopoes/models/origin.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from enum import Enum from uuid import UUID @@ -20,7 +22,7 @@ class Origin(BaseModel): result: list[Reference] = Field(default_factory=list) task_id: UUID | None = None - def __sub__(self, other) -> set[Reference]: + def __sub__(self, other: Origin) -> set[Reference]: if isinstance(other, Origin): return set(self.result) - set(other.result) else: diff --git a/octopoes/octopoes/models/path.py b/octopoes/octopoes/models/path.py index 545e4c32185..8bb94417ccd 100644 --- a/octopoes/octopoes/models/path.py +++ b/octopoes/octopoes/models/path.py @@ -48,7 +48,7 @@ def parse_step(cls, step: str) -> tuple[Direction, str, type[OOI] | None]: raise ValueError(f"Could not parse step: {step}") @classmethod - def calculate_step(cls, source_type: type[OOI], step: str): + def calculate_step(cls, source_type: type[OOI], step: str) -> Segment: direction, property_name, explicit_target_type = cls.parse_step(step) if explicit_target_type: @@ -96,7 +96,7 @@ def __eq__(self, other: object) -> bool: and self.property_name == other.property_name ) - def __str__(self): + def __str__(self) -> str: if self.direction == Direction.INCOMING: if self.target_type is None: raise ValueError("Direction cannot be incoming if target type is None") @@ -105,7 +105,7 @@ def __str__(self): else: return f"{self.property_name}" - def __repr__(self): + def __repr__(self) -> str: return str(self) @@ -114,7 +114,7 @@ def __init__(self, segments: list[Segment]): self.segments = segments @classmethod - def parse(cls, path: str): + def parse(cls, path: str) -> Path: start_type, step, *rest = path.split(".") segments = [Segment.calculate_step(type_by_name(start_type), step)] @@ -146,7 +146,7 @@ def __lt__(self, other): def __hash__(self): return hash(str(self)) - def __repr__(self): + def __repr__(self) -> str: return str(self) diff --git a/octopoes/octopoes/models/persistence.py b/octopoes/octopoes/models/persistence.py index d10ff1bf03a..fbc420d3adf 100644 --- a/octopoes/octopoes/models/persistence.py +++ b/octopoes/octopoes/models/persistence.py @@ -1,4 +1,4 @@ -from __future__ import annotations +from typing import Any from pydantic import Field from pydantic.fields import FieldInfo @@ -11,7 +11,7 @@ def ReferenceField( *, max_issue_scan_level: int | None = None, max_inherit_scan_level: int | None = None, - **kwargs, + **kwargs: Any, ) -> FieldInfo: if not isinstance(object_type, str): object_type = object_type.get_object_type() diff --git a/octopoes/octopoes/models/tree.py b/octopoes/octopoes/models/tree.py index e6efdc3a3e2..06d7d92a70a 100644 --- a/octopoes/octopoes/models/tree.py +++ b/octopoes/octopoes/models/tree.py @@ -12,7 +12,7 @@ class ReferenceNode(BaseModel): reference: Reference children: dict[str, list[ReferenceNode]] - def filter_children(self, filter_fn: Callable[[ReferenceNode], bool]): + def filter_children(self, filter_fn: Callable[[ReferenceNode], bool]) -> bool: """ Mutable filter function to evict any children from the tree that do not adhere to the provided callback """ diff --git a/octopoes/octopoes/models/types.py b/octopoes/octopoes/models/types.py index 0eb9c16d4dd..62957d29559 100644 --- a/octopoes/octopoes/models/types.py +++ b/octopoes/octopoes/models/types.py @@ -2,6 +2,8 @@ from collections.abc import Iterator +from pydantic.fields import FieldInfo + from octopoes.models import OOI, Reference from octopoes.models.exception import TypeNotFound from octopoes.models.ooi.certificate import ( @@ -206,14 +208,14 @@ def to_concrete(object_types: set[type[OOI]]) -> set[type[OOI]]: return concrete_types -def type_by_name(type_name: str): +def type_by_name(type_name: str) -> type[OOI]: try: return next(t for t in ALL_TYPES if t.__name__ == type_name) except StopIteration: raise TypeNotFound -def related_object_type(field) -> type[OOI]: +def related_object_type(field: FieldInfo) -> type[OOI]: object_type: str | type[OOI] = field.json_schema_extra["object_type"] if isinstance(object_type, str): return type_by_name(object_type) diff --git a/octopoes/octopoes/repositories/ooi_repository.py b/octopoes/octopoes/repositories/ooi_repository.py index f0a15c190ba..9039003fcad 100644 --- a/octopoes/octopoes/repositories/ooi_repository.py +++ b/octopoes/octopoes/repositories/ooi_repository.py @@ -125,12 +125,12 @@ def count_findings_by_severity(self, valid_time: datetime) -> Counter: def list_findings( self, - severities, - valid_time, - exclude_muted, - only_muted, - offset, - limit, + severities: set[RiskLevelSeverity], + valid_time: datetime, + exclude_muted: bool = False, + only_muted: bool = False, + offset: int = DEFAULT_OFFSET, + limit: int = DEFAULT_LIMIT, ) -> Paginated[Finding]: raise NotImplementedError @@ -656,10 +656,10 @@ def list_findings( self, severities: set[RiskLevelSeverity], valid_time: datetime, - exclude_muted=False, - only_muted=False, - offset=DEFAULT_OFFSET, - limit=DEFAULT_LIMIT, + exclude_muted: bool = False, + only_muted: bool = False, + offset: int = DEFAULT_OFFSET, + limit: int = DEFAULT_LIMIT, ) -> Paginated[Finding]: # clause to find risk_severity concrete_finding_types = to_concrete({FindingType}) diff --git a/octopoes/octopoes/repositories/origin_parameter_repository.py b/octopoes/octopoes/repositories/origin_parameter_repository.py index 76c601cb09b..e6b2d5b640a 100644 --- a/octopoes/octopoes/repositories/origin_parameter_repository.py +++ b/octopoes/octopoes/repositories/origin_parameter_repository.py @@ -80,7 +80,7 @@ def list_by_origin(self, origin_id: set[str], valid_time: datetime) -> list[Orig results = self.session.client.query(query, valid_time=valid_time) return [self.deserialize(r[0]) for r in results] - def list_by_reference(self, reference: Reference, valid_time: datetime): + def list_by_reference(self, reference: Reference, valid_time: datetime) -> list[OriginParameter]: query = generate_pull_query( FieldSet.ALL_FIELDS, { diff --git a/octopoes/octopoes/repositories/scan_profile_repository.py b/octopoes/octopoes/repositories/scan_profile_repository.py index 939db4338cf..321ce303662 100644 --- a/octopoes/octopoes/repositories/scan_profile_repository.py +++ b/octopoes/octopoes/repositories/scan_profile_repository.py @@ -53,7 +53,7 @@ def commit(self): self.session.commit() @classmethod - def format_id(cls, ooi_reference: Reference): + def format_id(cls, ooi_reference: Reference) -> str: return f"{cls.object_type}|{ooi_reference}" @classmethod diff --git a/octopoes/octopoes/tasks/tasks.py b/octopoes/octopoes/tasks/tasks.py index 5df58487440..8920dca657f 100644 --- a/octopoes/octopoes/tasks/tasks.py +++ b/octopoes/octopoes/tasks/tasks.py @@ -3,6 +3,7 @@ from datetime import datetime, timezone from logging import config, getLogger from pathlib import Path +from typing import Any import yaml from celery.signals import worker_process_init, worker_process_shutdown @@ -44,7 +45,7 @@ def init_worker(**kwargs): @app.task(queue=QUEUE_NAME_OCTOPOES) -def handle_event(event: dict): +def handle_event(event: dict) -> None: try: parsed_event: DBEvent = TypeAdapter(DBEventType).validate_python(event) @@ -75,7 +76,7 @@ def schedule_scan_profile_recalculations(): @app.task(queue=QUEUE_NAME_OCTOPOES) -def recalculate_scan_profiles(org: str, *args, **kwargs): +def recalculate_scan_profiles(org: str, *args: Any, **kwargs: Any) -> None: session = XTDBSession(get_xtdb_client(str(settings.xtdb_uri), org)) octopoes = bootstrap_octopoes(settings, org, session) diff --git a/octopoes/octopoes/xtdb/client.py b/octopoes/octopoes/xtdb/client.py index 529239c3064..dc2872fb1da 100644 --- a/octopoes/octopoes/xtdb/client.py +++ b/octopoes/octopoes/xtdb/client.py @@ -51,7 +51,7 @@ def _get_xtdb_http_session(base_url: str) -> httpx.Client: class XTDBHTTPClient: - def __init__(self, base_url, client: str): + def __init__(self, base_url: str, client: str): self._client = client self._session = _get_xtdb_http_session(base_url) @@ -168,7 +168,7 @@ def export_transactions(self) -> Any: self._verify_response(res) return res.json() - def sync(self, timeout: int | None = None): + def sync(self, timeout: int | None = None) -> Any: params = {} if timeout is not None: @@ -193,10 +193,10 @@ def __enter__(self): def __exit__(self, _exc_type: type[Exception], _exc_value: str, _exc_traceback: str) -> None: self.commit() - def add(self, operation: Operation): + def add(self, operation: Operation) -> None: self._operations.append(operation) - def put(self, document: str | dict[str, Any], valid_time: datetime): + def put(self, document: str | dict[str, Any], valid_time: datetime) -> None: self.add((OperationType.PUT, document, valid_time)) def commit(self) -> None: @@ -214,5 +214,5 @@ def commit(self) -> None: logger.info("Called %s callbacks after committing XTDBSession", len(self.post_commit_callbacks)) self.post_commit_callbacks = [] - def listen_post_commit(self, callback: Callable[[], None]): + def listen_post_commit(self, callback: Callable[[], None]) -> None: self.post_commit_callbacks.append(callback) diff --git a/octopoes/octopoes/xtdb/query.py b/octopoes/octopoes/xtdb/query.py index 3f263046d72..067958688eb 100644 --- a/octopoes/octopoes/xtdb/query.py +++ b/octopoes/octopoes/xtdb/query.py @@ -75,7 +75,7 @@ class Query: _limit: int | None = None _offset: int | None = None - def where(self, ooi_type: Ref, **kwargs) -> "Query": + def where(self, ooi_type: Ref, **kwargs: Ref | str | set[str]) -> "Query": for field_name, value in kwargs.items(): self._where_field_is(ooi_type, field_name, value) @@ -321,7 +321,7 @@ def _assert_type(self, ref: Ref, ooi_type: type[OOI]) -> str: def _to_object_type_statement(self, ref: Ref, other_type: type[OOI]) -> str: return f'[ {self._get_object_alias(ref)} :object_type "{other_type.get_object_type()}" ]' - def _compile_where_clauses(self, *, separator=" ") -> str: + def _compile_where_clauses(self, *, separator: str = " ") -> str: """Sorted and deduplicated where clauses, since they are both idempotent and commutative""" return separator + separator.join(sorted(set(self._where_clauses))) @@ -329,7 +329,7 @@ def _compile_where_clauses(self, *, separator=" ") -> str: def _compile_find_clauses(self) -> str: return " ".join(self._find_clauses) - def _compile(self, *, separator=" ") -> str: + def _compile(self, *, separator: str = " ") -> str: result_ooi_type = self.result_type.type if isinstance(self.result_type, Aliased) else self.result_type self._where_clauses.append(self._assert_type(self.result_type, result_ooi_type)) @@ -361,7 +361,7 @@ def _get_object_alias(self, object_type: Ref) -> str: def __str__(self) -> str: return self._compile() - def __eq__(self, other: object): + def __eq__(self, other: object) -> bool: if not isinstance(other, Query): return NotImplemented diff --git a/octopoes/octopoes/xtdb/query_builder.py b/octopoes/octopoes/xtdb/query_builder.py index 172cfc745c2..54335f5c8d8 100644 --- a/octopoes/octopoes/xtdb/query_builder.py +++ b/octopoes/octopoes/xtdb/query_builder.py @@ -2,7 +2,8 @@ from collections.abc import Iterable, Mapping from typing import Any -from octopoes.xtdb.related_field_generator import FieldSet, RelatedFieldNode +from octopoes.xtdb import FieldSet +from octopoes.xtdb.related_field_generator import RelatedFieldNode def join_csv(values: Iterable[Any]) -> str: diff --git a/octopoes/octopoes/xtdb/related_field_generator.py b/octopoes/octopoes/xtdb/related_field_generator.py index ac48a8f0e88..7ba43a5fd3c 100644 --- a/octopoes/octopoes/xtdb/related_field_generator.py +++ b/octopoes/octopoes/xtdb/related_field_generator.py @@ -69,7 +69,7 @@ def construct_incoming_relations(self): self.path + (foreign_key,), ) - def build_tree(self, depth: int): + def build_tree(self, depth: int) -> None: if depth > 0: self.construct_outgoing_relations() for child_node in self.relations_out.values(): @@ -79,7 +79,7 @@ def build_tree(self, depth: int): for child_node in self.relations_in.values(): child_node.build_tree(depth - 1) - def generate_field(self, field_set: FieldSet, pk_prefix: str): + def generate_field(self, field_set: FieldSet, pk_prefix: str) -> str: queried_fields = pk_prefix if field_set is FieldSet.ONLY_ID else "*" """ Output dicts in XTDB Query Language @@ -123,10 +123,10 @@ def search_nodes(self, search_object_types=set[str]): # Match self return not self.object_types.isdisjoint(search_object_types) - def __repr__(self): + def __repr__(self) -> str: return f"QueryNode[{self}]" - def __str__(self): + def __str__(self) -> str: return ",".join(self.object_types) def __eq__(self, other): diff --git a/pyproject.toml b/pyproject.toml index 8f4ab8bc977..a071805366b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,20 +2,23 @@ python_version = "3.10" plugins = ["pydantic.mypy"] strict = true -follow_imports = "skip" -warn_unused_ignores = false # This gives false positives in pre-commit as long as we don't enable follow imports disallow_subclassing_any = false disallow_untyped_decorators = false # Needed for FastAPI decorators disallow_any_generics = false disallow_untyped_calls = false -disallow_incomplete_defs = false disallow_untyped_defs = false -no_implicit_reexport = false warn_return_any = false [[tool.mypy.overrides]] -module = ["httpx.*"] -follow_imports = "normal" +# Following pydantic imports currently gives 2000 errors +module = ["pydantic.*"] +follow_imports = "skip" + +[[tool.mypy.overrides]] +module = ["bytes.*", "cveapi.*"] +disallow_any_generics = true +disallow_untyped_defs = true +disallow_untyped_calls = true [tool.setuptools_scm] write_to = "_version.py" diff --git a/rocky/account/forms/__init__.py b/rocky/account/forms/__init__.py index 448a075559d..7da11d44645 100644 --- a/rocky/account/forms/__init__.py +++ b/rocky/account/forms/__init__.py @@ -11,3 +11,19 @@ from account.forms.login import LoginForm from account.forms.password_reset import PasswordResetForm from account.forms.token import TwoFactorBackupTokenForm, TwoFactorSetupTokenForm, TwoFactorVerifyTokenForm + +__all__ = [ + "AccountTypeSelectForm", + "IndemnificationAddForm", + "MemberRegistrationForm", + "OnboardingOrganizationUpdateForm", + "OrganizationForm", + "OrganizationMemberEditForm", + "OrganizationUpdateForm", + "SetPasswordForm", + "LoginForm", + "PasswordResetForm", + "TwoFactorBackupTokenForm", + "TwoFactorSetupTokenForm", + "TwoFactorVerifyTokenForm", +] diff --git a/rocky/account/forms/organization.py b/rocky/account/forms/organization.py index c582fa7755e..3ea28c47995 100644 --- a/rocky/account/forms/organization.py +++ b/rocky/account/forms/organization.py @@ -30,11 +30,11 @@ def populate_dropdown_list(self, user): organizations.append([organization.code, organization.name]) if organizations: - props = { - "required": True, - "label": _("Organizations"), - "help_text": _("The organization from which to clone settings."), - "error_messages": self.error_messages, - } - self.fields["organization"] = forms.ChoiceField(**props) + self.fields["organization"] = forms.ChoiceField( + required=True, + label=_("Organizations"), + help_text=_("The organization from which to clone settings."), + error_messages=self.error_messages, + ) + self.fields["organization"].choices = [BLANK_CHOICE] + organizations diff --git a/rocky/account/mixins.py b/rocky/account/mixins.py index 951c6707628..e2de135bec4 100644 --- a/rocky/account/mixins.py +++ b/rocky/account/mixins.py @@ -6,7 +6,7 @@ from django.core.exceptions import PermissionDenied from django.http import Http404 from django.utils.translation import gettext_lazy as _ -from django.views import View +from django.views import ContextMixin, View from tools.models import Indemnification, Organization, OrganizationMember from octopoes.connector.octopoes import OctopoesAPIConnector @@ -25,7 +25,7 @@ class OrganizationPermLookupDict: def __init__(self, organization_member, app_label): self.organization_member, self.app_label = organization_member, app_label - def __repr__(self): + def __repr__(self) -> str: return str(self.organization_member.get_all_permissions) def __getitem__(self, perm_name): @@ -44,7 +44,7 @@ class OrganizationPermWrapper: def __init__(self, organization_member): self.organization_member = organization_member - def __repr__(self): + def __repr__(self) -> str: return f"{self.__class__.__qualname__}({self.organization_member!r})" def __getitem__(self, app_label): @@ -65,7 +65,7 @@ def __contains__(self, perm_name): return self[app_label][perm_name] -class OrganizationView(View): +class OrganizationView(ContextMixin, View): def setup(self, request, *args, **kwargs): super().setup(request, *args, **kwargs) diff --git a/rocky/account/models.py b/rocky/account/models.py index 7d082739f7d..740d3f6d189 100644 --- a/rocky/account/models.py +++ b/rocky/account/models.py @@ -129,7 +129,7 @@ class Meta: models.UniqueConstraint("user", Lower("name"), name="unique name"), ] - def __str__(self): + def __str__(self) -> str: return f"{self.name} ({self.user})" def generate_new_token(self) -> str: diff --git a/rocky/account/views/account.py b/rocky/account/views/account.py index 49c5c298565..2e082adf170 100644 --- a/rocky/account/views/account.py +++ b/rocky/account/views/account.py @@ -23,7 +23,7 @@ def post(self, request, *args, **kwargs): # Mypy doesn't have the information to understand this return self.get(request, *args, **kwargs) # type: ignore[attr-defined] - def handle_page_action(self, action: str): + def handle_page_action(self, action: str) -> None: if action == PageActions.ACCEPT_CLEARANCE.value: self.organization_member.acknowledged_clearance_level = self.organization_member.trusted_clearance_level elif action == PageActions.WITHDRAW_ACCEPTANCE.value: diff --git a/rocky/crisis_room/views.py b/rocky/crisis_room/views.py index 6e16e8ba0ef..46bfec3f68c 100644 --- a/rocky/crisis_room/views.py +++ b/rocky/crisis_room/views.py @@ -14,8 +14,7 @@ from octopoes.connector import ConnectorException from octopoes.connector.octopoes import OctopoesAPIConnector from octopoes.models.ooi.findings import RiskLevelSeverity -from rocky.views.mixins import ObservedAtMixin -from rocky.views.ooi_view import ConnectorFormMixin +from rocky.views.mixins import ConnectorFormMixin, ObservedAtMixin logger = logging.getLogger(__name__) diff --git a/rocky/katalogus/client.py b/rocky/katalogus/client.py index 84425683c91..5bfed248aa0 100644 --- a/rocky/katalogus/client.py +++ b/rocky/katalogus/client.py @@ -8,6 +8,7 @@ from jsonschema.validators import Draft202012Validator from pydantic import BaseModel, Field, field_serializer from tools.enums import SCAN_LEVEL +from tools.models import OrganizationMember from octopoes.models import OOI from octopoes.models.exception import TypeNotFound @@ -34,7 +35,7 @@ class Plugin(BaseModel): # """Pydantic does not stringify the OOI classes, but then templates can't render them""" # # todo: use field_serializer instead - def can_scan(self, member) -> bool: + def can_scan(self, member: OrganizationMember) -> bool: return member.has_perm("tools.can_scan_organization") @@ -47,10 +48,10 @@ class Boefje(Plugin): # use a custom field_serializer for `consumes` @field_serializer("consumes") - def serialize_consumes(self, consumes: set[type[OOI]]): + def serialize_consumes(self, consumes: set[type[OOI]]) -> set[str]: return {ooi_class.get_ooi_type() for ooi_class in consumes} - def can_scan(self, member) -> bool: + def can_scan(self, member: OrganizationMember) -> bool: return super().can_scan(member) and member.acknowledged_clearance_level >= self.scan_level.value @@ -60,7 +61,7 @@ class Normalizer(Plugin): # use a custom field_serializer for `produces` @field_serializer("produces") - def serialize_produces(self, produces: set[type[OOI]]): + def serialize_produces(self, produces: set[type[OOI]]) -> set[str]: return {ooi_class.get_ooi_type() for ooi_class in produces} @@ -91,11 +92,11 @@ def organization_exists(self) -> bool: return response.status_code != 404 - def create_organization(self, name: str): + def create_organization(self, name: str) -> None: response = self.session.post("/v1/organisations/", json={"id": self.organization, "name": name}) response.raise_for_status() - def delete_organization(self): + def delete_organization(self) -> None: response = self.session.delete(self.organization_uri) response.raise_for_status() @@ -112,7 +113,7 @@ def get_plugin(self, plugin_id: str) -> Plugin: response.raise_for_status() return parse_plugin(response.json()) - def get_plugin_schema(self, plugin_id) -> dict | None: + def get_plugin_schema(self, plugin_id: str) -> dict | None: response = self.session.get(f"{self.organization_uri}/plugins/{plugin_id}/schema.json") response.raise_for_status() @@ -138,17 +139,14 @@ def upsert_plugin_settings(self, plugin_id: str, values: dict) -> None: response = self.session.put(f"{self.organization_uri}/{plugin_id}/settings", json=values) response.raise_for_status() - def delete_plugin_settings(self, plugin_id: str): + def delete_plugin_settings(self, plugin_id: str) -> None: response = self.session.delete(f"{self.organization_uri}/{plugin_id}/settings") response.raise_for_status() - return response - def clone_all_configuration_to_organization(self, to_organization: str): + def clone_all_configuration_to_organization(self, to_organization: str) -> None: response = self.session.post(f"{self.organization_uri}/settings/clone/{to_organization}") response.raise_for_status() - return response - def health(self) -> ServiceHealth: response = self.session.get("/health") response.raise_for_status() diff --git a/rocky/katalogus/forms/plugin_settings.py b/rocky/katalogus/forms/plugin_settings.py index d09ac808066..5040e9156a5 100644 --- a/rocky/katalogus/forms/plugin_settings.py +++ b/rocky/katalogus/forms/plugin_settings.py @@ -1,3 +1,5 @@ +from typing import Any + from django import forms from django.utils.translation import gettext_lazy as _ from jsonschema.validators import Draft202012Validator @@ -13,7 +15,7 @@ class PluginSchemaForm(forms.Form): "required": _("This field is required."), } - def __init__(self, plugin_schema: dict, values: dict, *args, **kwargs): + def __init__(self, plugin_schema: dict, values: dict, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.plugin_schema = plugin_schema self.values = values diff --git a/rocky/katalogus/views/mixins.py b/rocky/katalogus/views/mixins.py index fe4146bdf5f..40be3b5cd3f 100644 --- a/rocky/katalogus/views/mixins.py +++ b/rocky/katalogus/views/mixins.py @@ -1,8 +1,9 @@ from logging import getLogger +from typing import Any from account.mixins import OrganizationView from django.contrib import messages -from django.http import Http404 +from django.http import Http404, HttpRequest from django.shortcuts import redirect from django.urls import reverse from django.utils.translation import gettext_lazy as _ @@ -24,7 +25,7 @@ class SinglePluginView(OrganizationView): katalogus_client: KATalogusClientV1 plugin: KATalogusBoefje | KATalogusNormalizer - def setup(self, request, *args, plugin_id: str, **kwargs): + def setup(self, request: HttpRequest, *args: Any, plugin_id: str, **kwargs: Any) -> None: """ Prepare organization info and KAT-alogus API client. """ diff --git a/rocky/katalogus/views/plugin_detail.py b/rocky/katalogus/views/plugin_detail.py index b1e524f2bd7..4774321cc47 100644 --- a/rocky/katalogus/views/plugin_detail.py +++ b/rocky/katalogus/views/plugin_detail.py @@ -47,13 +47,13 @@ def get_task_history(self) -> Page: input_ooi = self.request.GET.get("task_history_search") status = self.request.GET.get("task_history_status") - if self.request.GET.get("task_history_from"): - min_created_at = datetime.strptime(self.request.GET.get("task_history_from"), "%Y-%m-%d") + if task_history_from := self.request.GET.get("task_history_from"): + min_created_at = datetime.strptime(task_history_from, "%Y-%m-%d") else: min_created_at = None - if self.request.GET.get("task_history_to"): - max_created_at = datetime.strptime(self.request.GET.get("task_history_to"), "%Y-%m-%d") + if task_history_to := self.request.GET.get("task_history_to"): + max_created_at = datetime.strptime(task_history_to, "%Y-%m-%d") else: max_created_at = None diff --git a/rocky/katalogus/views/plugin_enable_disable.py b/rocky/katalogus/views/plugin_enable_disable.py index 8717c6d540b..f5519b85b5b 100644 --- a/rocky/katalogus/views/plugin_enable_disable.py +++ b/rocky/katalogus/views/plugin_enable_disable.py @@ -13,7 +13,7 @@ class PluginEnableDisableView(SinglePluginView): - def check_required_settings(self, settings: dict): + def check_required_settings(self, settings: dict) -> bool: if self.plugin_schema is None or "required" not in self.plugin_schema: return True diff --git a/rocky/onboarding/view_helpers.py b/rocky/onboarding/view_helpers.py index 246fa890ed7..3d01fc412d5 100644 --- a/rocky/onboarding/view_helpers.py +++ b/rocky/onboarding/view_helpers.py @@ -2,7 +2,7 @@ from django.utils.translation import gettext_lazy as _ from reports.views.base import get_selection from tools.models import Organization -from tools.view_helpers import BreadcrumbsMixin, StepsMixin +from tools.view_helpers import Breadcrumb, BreadcrumbsMixin, StepsMixin ONBOARDING_PERMISSIONS = ( "tools.can_scan_organization", @@ -89,7 +89,7 @@ def build_steps(self): class OnboardingBreadcrumbsMixin(BreadcrumbsMixin): organization: Organization - def build_breadcrumbs(self): + def build_breadcrumbs(self) -> list[Breadcrumb]: return [ { "url": reverse_lazy("step_introduction", kwargs={"organization_code": self.organization.code}), diff --git a/rocky/onboarding/views.py b/rocky/onboarding/views.py index 2aa5f103532..c56a89284eb 100644 --- a/rocky/onboarding/views.py +++ b/rocky/onboarding/views.py @@ -39,7 +39,7 @@ from rocky.exceptions import RockyError from rocky.messaging import clearance_level_warning_dns_report from rocky.views.indemnification_add import IndemnificationAddView -from rocky.views.ooi_view import SingleOOITreeMixin +from rocky.views.mixins import SingleOOITreeMixin User = get_user_model() @@ -119,6 +119,7 @@ class OnboardingSetupScanOOIInfoView( class OnboardingSetupScanOOIAddView( OrganizationPermissionRequiredMixin, IntroductionStepsMixin, + OnboardingBreadcrumbsMixin, SingleOOITreeMixin, FormView, ): @@ -421,11 +422,8 @@ class OnboardingOrganizationSetupView( permission_required = "tools.add_organization" def get(self, request, *args, **kwargs): - members = OrganizationMember.objects.filter(user=self.request.user) - if members: - return redirect( - reverse("step_organization_update", kwargs={"organization_code": members.first().organization.code}) - ) + if member := OrganizationMember.objects.filter(user=self.request.user).first(): + return redirect(reverse("step_organization_update", kwargs={"organization_code": member.organization.code})) return super().get(request, *args, **kwargs) def post(self, request, *args, **kwargs): diff --git a/rocky/poetry.lock b/rocky/poetry.lock index 0b83c59152f..b80fc98d623 100644 --- a/rocky/poetry.lock +++ b/rocky/poetry.lock @@ -638,6 +638,21 @@ url = "https://github.com/jazzband/django-rest-knox" reference = "dd7b062147bc4b9718e22d5acd6cf1301a1036b9" resolved_reference = "dd7b062147bc4b9718e22d5acd6cf1301a1036b9" +[[package]] +name = "django-stubs-ext" +version = "5.0.0" +description = "Monkey-patching and extensions for django-stubs" +optional = false +python-versions = ">=3.8" +files = [ + {file = "django_stubs_ext-5.0.0-py3-none-any.whl", hash = "sha256:8e1334fdf0c8bff87e25d593b33d4247487338aaed943037826244ff788b56a8"}, + {file = "django_stubs_ext-5.0.0.tar.gz", hash = "sha256:5bacfbb498a206d5938454222b843d81da79ea8b6fcd1a59003f529e775bc115"}, +] + +[package.dependencies] +django = "*" +typing-extensions = "*" + [[package]] name = "django-tagulous" version = "1.3.3" @@ -3383,4 +3398,4 @@ test = ["pytest"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "a7a5b4190bcdb94605c963d061082eb405d38c0f037e23bdb0dbf98424a43068" +content-hash = "d5cfc36d60d1c0f4d328cb0571887b8b6722619aff5d08cc3e1e70e0536e95f4" diff --git a/rocky/pyproject.toml b/rocky/pyproject.toml index 761919d60a1..abe3fb24ae4 100644 --- a/rocky/pyproject.toml +++ b/rocky/pyproject.toml @@ -49,6 +49,7 @@ opentelemetry-instrumentation-wsgi = "^0.45b0" opentelemetry-proto = "^1.24.0" opentelemetry-semantic-conventions = "^0.45b0" opentelemetry-util-http = "^0.45b0" +django-stubs-ext = "^5.0.0" [tool.poetry.group.dev.dependencies] diff --git a/rocky/reports/forms.py b/rocky/reports/forms.py index ed34728cf30..c3cfb5bda6a 100644 --- a/rocky/reports/forms.py +++ b/rocky/reports/forms.py @@ -1,3 +1,5 @@ +from typing import Any + from django import forms from django.utils.translation import gettext_lazy as _ from tools.forms.base import BaseRockyForm @@ -12,7 +14,7 @@ class OOITypeMultiCheckboxForReportForm(BaseRockyForm): widget=forms.CheckboxSelectMultiple, ) - def __init__(self, ooi_types: list[str], *args, **kwargs): + def __init__(self, ooi_types: list[str], *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) self.fields["ooi_type"].choices = ((ooi_type, ooi_type) for ooi_type in ooi_types) @@ -24,7 +26,7 @@ class ReportTypeMultiselectForm(BaseRockyForm): widget=forms.CheckboxSelectMultiple, ) - def __init__(self, report_types: set[Report], *args, **kwargs): + def __init__(self, report_types: set[Report], *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) report_types_choices = ((report_type.id, report_type.name) for report_type in report_types) self.fields["report_type"].choices = report_types_choices diff --git a/rocky/reports/report_types/aggregate_organisation_report/report.py b/rocky/reports/report_types/aggregate_organisation_report/report.py index ea938ecf724..4afe198909a 100644 --- a/rocky/reports/report_types/aggregate_organisation_report/report.py +++ b/rocky/reports/report_types/aggregate_organisation_report/report.py @@ -25,7 +25,7 @@ class AggregateOrganisationReport(AggregateReport): id = "aggregate-organisation-report" - name = "Aggregate Organisation Report" + name = _("Aggregate Organisation Report") description = "Aggregate Organisation Report" reports = { "required": [SystemReport], @@ -42,7 +42,7 @@ class AggregateOrganisationReport(AggregateReport): } template_path = "aggregate_organisation_report/report.html" - def post_process_data(self, data: dict[str, Any], valid_time) -> dict[str, Any]: + def post_process_data(self, data: dict[str, Any], valid_time: datetime) -> dict[str, Any]: systems: dict[str, dict[str, Any]] = {"services": {}} services = {} open_ports = {} @@ -179,7 +179,7 @@ def post_process_data(self, data: dict[str, Any], valid_time) -> dict[str, Any]: basic_security["system_specific"][SystemType.WEB] = [ report for ip in web_report_data for report in web_report_data[ip] ] - basic_security["system_specific"][SystemType.DNS] = [ + basic_security["syst_specific"][SystemType.DNS] = [ report for ip in dns_report_data for report in dns_report_data[ip] ] @@ -412,7 +412,9 @@ def is_mail_compliant(result): "config_oois": config_oois, } - def collect_system_specific_data(self, data, services, system_type: str, report_id: str) -> dict[str, Any]: + def collect_system_specific_data( + self, data: dict, services: dict, system_type: str, report_id: str + ) -> dict[str, Any]: """Given a system, return a list of report data from the right sub-reports based on the related report_id""" report_data: dict[str, Any] = {} diff --git a/rocky/reports/report_types/definitions.py b/rocky/reports/report_types/definitions.py index 62acac98b23..aaaeac822bf 100644 --- a/rocky/reports/report_types/definitions.py +++ b/rocky/reports/report_types/definitions.py @@ -4,6 +4,8 @@ from pathlib import Path from typing import Any, TypedDict, TypeVar +from django_stubs_ext import StrPromise + from octopoes.connector.octopoes import OctopoesAPIConnector from octopoes.models import OOI, Reference from octopoes.models.ooi.dns.zone import Hostname @@ -21,7 +23,7 @@ class ReportPlugins(TypedDict): class BaseReport: id: str - name: str + name: StrPromise description: str template_path: str = "report.html" label_style = "1-light" # default/fallback color diff --git a/rocky/reports/report_types/multi_organization_report/report.py b/rocky/reports/report_types/multi_organization_report/report.py index 3f236eba780..4f31f7a826b 100644 --- a/rocky/reports/report_types/multi_organization_report/report.py +++ b/rocky/reports/report_types/multi_organization_report/report.py @@ -257,7 +257,7 @@ def collect_report_data( connector: OctopoesAPIConnector, input_ooi_references: list[str], observed_at: datetime, -): +) -> dict: report_data = {} for ooi in [x for x in input_ooi_references if Reference.from_str(x).class_type == ReportData]: report_data[ooi] = connector.get(Reference.from_str(ooi), observed_at).dict() diff --git a/rocky/reports/report_types/name_server_report/report.py b/rocky/reports/report_types/name_server_report/report.py index 1b5544da74f..0b55154b919 100644 --- a/rocky/reports/report_types/name_server_report/report.py +++ b/rocky/reports/report_types/name_server_report/report.py @@ -40,13 +40,13 @@ def has_dnssec(self): def has_valid_dnssec(self): return sum([check.has_valid_dnssec for check in self.checks]) - def __bool__(self): + def __bool__(self) -> bool: return all(bool(check) for check in self.checks) - def __len__(self): + def __len__(self) -> int: return len(self.checks) - def __add__(self, other: "NameServerChecks"): + def __add__(self, other: "NameServerChecks") -> "NameServerChecks": return NameServerChecks(checks=self.checks + other.checks) diff --git a/rocky/reports/report_types/vulnerability_report/report.py b/rocky/reports/report_types/vulnerability_report/report.py index 8b45bd37d84..73380d21311 100644 --- a/rocky/reports/report_types/vulnerability_report/report.py +++ b/rocky/reports/report_types/vulnerability_report/report.py @@ -142,7 +142,7 @@ def collect_data(self, input_oois: Iterable[str], valid_time: datetime) -> dict[ return result - def get_findings(self, input_oois: Iterable[str], valid_time: datetime): + def get_findings(self, input_oois: Iterable[str], valid_time: datetime) -> dict: ips_by_input_ooi = self.to_ips(input_oois, valid_time) all_ips = list({ip for key, ips in ips_by_input_ooi.items() for ip in ips}) diff --git a/rocky/reports/report_types/web_system_report/report.py b/rocky/reports/report_types/web_system_report/report.py index 55711e45199..dc0bedf1900 100644 --- a/rocky/reports/report_types/web_system_report/report.py +++ b/rocky/reports/report_types/web_system_report/report.py @@ -80,13 +80,13 @@ def certificates_not_expired(self): def certificates_not_expiring_soon(self): return sum([check.certificates_not_expiring_soon for check in self.checks]) - def __bool__(self): + def __bool__(self) -> bool: return all(bool(check) for check in self.checks) - def __len__(self): + def __len__(self) -> int: return len(self.checks) - def __add__(self, other: "WebChecks"): + def __add__(self, other: "WebChecks") -> "WebChecks": return WebChecks(checks=self.checks + other.checks) diff --git a/rocky/reports/views/aggregate_report.py b/rocky/reports/views/aggregate_report.py index 789f2d34471..4e454adabdd 100644 --- a/rocky/reports/views/aggregate_report.py +++ b/rocky/reports/views/aggregate_report.py @@ -10,7 +10,7 @@ from django.utils.translation import gettext_lazy as _ from django.views.generic import TemplateView from django_weasyprint import WeasyTemplateResponseMixin -from tools.view_helpers import url_with_querystring +from tools.view_helpers import Breadcrumb, url_with_querystring from reports.report_types.aggregate_organisation_report.report import AggregateOrganisationReport, aggregate_reports from reports.report_types.definitions import Report @@ -29,7 +29,7 @@ class BreadcrumbsAggregateReportView(ReportBreadcrumbs): - def build_breadcrumbs(self): + def build_breadcrumbs(self) -> list[Breadcrumb]: breadcrumbs = super().build_breadcrumbs() kwargs = self.get_kwargs() selection = get_selection(self.request) diff --git a/rocky/reports/views/base.py b/rocky/reports/views/base.py index c93064e28e4..b41c0b136cc 100644 --- a/rocky/reports/views/base.py +++ b/rocky/reports/views/base.py @@ -1,4 +1,4 @@ -from collections.abc import Iterable, Sequence +from collections.abc import Iterable, Mapping, Sequence from logging import getLogger from operator import attrgetter from typing import Any, Literal, cast @@ -13,7 +13,7 @@ from django.utils.translation import gettext_lazy as _ from django.views.generic import TemplateView from katalogus.client import KATalogusError, Plugin, get_katalogus -from tools.view_helpers import BreadcrumbsMixin +from tools.view_helpers import Breadcrumb, BreadcrumbsMixin from octopoes.models import OOI from reports.forms import OOITypeMultiCheckboxForReportForm @@ -28,7 +28,7 @@ } -def get_selection(request: HttpRequest, pre_selection: dict[str, str | Sequence[str]] | None = None) -> str: +def get_selection(request: HttpRequest, pre_selection: Mapping[str, str | Sequence[str]] | None = None) -> str: if pre_selection is not None: return "?" + urlencode(pre_selection, True) return "?" + urlencode(request.GET, True) @@ -50,19 +50,17 @@ def get_kwargs(self): def is_valid_breadcrumbs(self): return self.breadcrumbs_step < len(self.breadcrumbs) - def build_breadcrumbs(self): + def build_breadcrumbs(self) -> list[Breadcrumb]: kwargs = self.get_kwargs() selection = get_selection(self.request) - breadcrumbs = [ + return [ { "url": reverse("reports", kwargs=kwargs) + selection, "text": _("Reports"), }, ] - return breadcrumbs - def get_breadcrumbs(self): if self.is_valid_breadcrumbs(): return self.breadcrumbs[: self.breadcrumbs_step] diff --git a/rocky/reports/views/generate_report.py b/rocky/reports/views/generate_report.py index b4e5b14763e..89c40658d91 100644 --- a/rocky/reports/views/generate_report.py +++ b/rocky/reports/views/generate_report.py @@ -9,7 +9,7 @@ from django.utils.translation import gettext_lazy as _ from django.views.generic import TemplateView from django_weasyprint import WeasyTemplateResponseMixin -from tools.view_helpers import url_with_querystring +from tools.view_helpers import Breadcrumb, url_with_querystring from octopoes.models import Reference from octopoes.models.exception import ObjectNotFoundException, TypeNotFound @@ -28,7 +28,7 @@ class BreadcrumbsGenerateReportView(ReportBreadcrumbs): - def build_breadcrumbs(self): + def build_breadcrumbs(self) -> list[Breadcrumb]: breadcrumbs = super().build_breadcrumbs() kwargs = self.get_kwargs() selection = get_selection(self.request) diff --git a/rocky/reports/views/multi_report.py b/rocky/reports/views/multi_report.py index 0425c0185c3..68ba789dc9d 100644 --- a/rocky/reports/views/multi_report.py +++ b/rocky/reports/views/multi_report.py @@ -8,7 +8,7 @@ from django.utils.translation import gettext_lazy as _ from django.views.generic import TemplateView from django_weasyprint import WeasyTemplateResponseMixin -from tools.view_helpers import url_with_querystring +from tools.view_helpers import Breadcrumb, url_with_querystring from reports.report_types.multi_organization_report.report import MultiOrganizationReport, collect_report_data from reports.views.base import ( @@ -24,7 +24,7 @@ class BreadcrumbsMultiReportView(ReportBreadcrumbs): - def build_breadcrumbs(self): + def build_breadcrumbs(self) -> list[Breadcrumb]: breadcrumbs = super().build_breadcrumbs() kwargs = self.get_kwargs() selection = get_selection(self.request) diff --git a/rocky/requirements-dev.txt b/rocky/requirements-dev.txt index 27f8e373fae..ffe3be4add4 100644 --- a/rocky/requirements-dev.txt +++ b/rocky/requirements-dev.txt @@ -275,6 +275,9 @@ django-phonenumber-field==7.3.0 ; python_version >= "3.10" and python_version < --hash=sha256:bc6eaa49d1f9d870944f5280258db511e3a1ba5e2fbbed255488dceacae45d06 \ --hash=sha256:f9cdb3de085f99c249328293a3b93d4e5fa440c0c8e3b99eb0d0f54748629797 django-rest-knox @ git+https://github.com/jazzband/django-rest-knox@dd7b062147bc4b9718e22d5acd6cf1301a1036b9 ; python_version >= "3.10" and python_version < "4.0" +django-stubs-ext==5.0.0 ; python_version >= "3.10" and python_version < "4.0" \ + --hash=sha256:5bacfbb498a206d5938454222b843d81da79ea8b6fcd1a59003f529e775bc115 \ + --hash=sha256:8e1334fdf0c8bff87e25d593b33d4247487338aaed943037826244ff788b56a8 django-tagulous==1.3.3 ; python_version >= "3.10" and python_version < "4.0" \ --hash=sha256:ad3bb85f4cce83a47e4c0257143229cb92a294defa02fe661823b0442b35d478 \ --hash=sha256:d445590ae1b5cb9b8c5a425f97bf5f01148a33419c19edeb721ebd9fdd6792fe diff --git a/rocky/requirements.txt b/rocky/requirements.txt index cf1f3177b99..0617113b28d 100644 --- a/rocky/requirements.txt +++ b/rocky/requirements.txt @@ -216,6 +216,9 @@ django-phonenumber-field==7.3.0 ; python_version >= "3.10" and python_version < --hash=sha256:bc6eaa49d1f9d870944f5280258db511e3a1ba5e2fbbed255488dceacae45d06 \ --hash=sha256:f9cdb3de085f99c249328293a3b93d4e5fa440c0c8e3b99eb0d0f54748629797 django-rest-knox @ git+https://github.com/jazzband/django-rest-knox@dd7b062147bc4b9718e22d5acd6cf1301a1036b9 ; python_version >= "3.10" and python_version < "4.0" +django-stubs-ext==5.0.0 ; python_version >= "3.10" and python_version < "4.0" \ + --hash=sha256:5bacfbb498a206d5938454222b843d81da79ea8b6fcd1a59003f529e775bc115 \ + --hash=sha256:8e1334fdf0c8bff87e25d593b33d4247487338aaed943037826244ff788b56a8 django-tagulous==1.3.3 ; python_version >= "3.10" and python_version < "4.0" \ --hash=sha256:ad3bb85f4cce83a47e4c0257143229cb92a294defa02fe661823b0442b35d478 \ --hash=sha256:d445590ae1b5cb9b8c5a425f97bf5f01148a33419c19edeb721ebd9fdd6792fe diff --git a/rocky/rocky/bytes_client.py b/rocky/rocky/bytes_client.py index cb2c2964677..f82d85fccef 100644 --- a/rocky/rocky/bytes_client.py +++ b/rocky/rocky/bytes_client.py @@ -30,14 +30,14 @@ def health(self) -> ServiceHealth: return ServiceHealth.parse_obj(response.json()) @staticmethod - def raw_from_declarations(declarations: list[Declaration]): + def raw_from_declarations(declarations: list[Declaration]) -> bytes: json_string = f"[{','.join([declaration.json() for declaration in declarations])}]" return json_string.encode("utf-8") def add_manual_proof( self, normalizer_id: uuid.UUID, raw: bytes, manual_mime_types: Set[str] = frozenset({"manual/ooi"}) - ): + ) -> None: """Per convention for a generic normalizer, we add a raw list of declarations, not a single declaration""" self.login() @@ -70,7 +70,7 @@ def add_manual_proof( ), ) - def upload_raw(self, raw: bytes, manual_mime_types: set[str], input_ooi: str | None = None): + def upload_raw(self, raw: bytes, manual_mime_types: set[str], input_ooi: str | None = None) -> None: self.login() boefje_meta = BoefjeMeta( diff --git a/rocky/rocky/exceptions.py b/rocky/rocky/exceptions.py index ccd3e720012..8f76f5c052c 100644 --- a/rocky/rocky/exceptions.py +++ b/rocky/rocky/exceptions.py @@ -1,3 +1,6 @@ +from typing import Any + + class RockyError(Exception): pass @@ -21,13 +24,13 @@ class TrustedClearanceLevelTooLowException(ClearanceLevelTooLowException): class ServiceException(RockyError): """Base exception representing an issue with an (external) service""" - def __init__(self, service_name: str, *args): + def __init__(self, service_name: str, *args: Any): super().__init__(*args) self.service_name = service_name class OctopoesException(ServiceException): - def __init__(self, *args): + def __init__(self, *args: Any): super().__init__("Octopoes", *args) diff --git a/rocky/rocky/keiko.py b/rocky/rocky/keiko.py index 944cc4ddd1d..b202ae89387 100644 --- a/rocky/rocky/keiko.py +++ b/rocky/rocky/keiko.py @@ -196,7 +196,7 @@ def get_organization_finding_report( return self.get_report(valid_time, "Organisatie", organization_name, store, filters) @classmethod - def ooi_report_file_name(cls, valid_time: datetime, organization_code: str, ooi_id: str): + def ooi_report_file_name(cls, valid_time: datetime, organization_code: str, ooi_id: str) -> str: report_file_name = "_".join( [ "bevindingenrapport", @@ -214,7 +214,7 @@ def ooi_report_file_name(cls, valid_time: datetime, organization_code: str, ooi_ return report_file_name @classmethod - def organization_report_file_name(cls, organization_code: str): + def organization_report_file_name(cls, organization_code: str) -> str: file_name = "_".join( [ "bevindingenrapport_nl", @@ -226,7 +226,7 @@ def organization_report_file_name(cls, organization_code: str): return f"{file_name}.pdf" -def _ooi_field_as_string(findings_grouped: dict, store: dict): +def _ooi_field_as_string(findings_grouped: dict, store: dict) -> dict: new_findings_grouped = {} for finding_type, finding_group in findings_grouped.items(): diff --git a/rocky/rocky/middleware/onboarding.py b/rocky/rocky/middleware/onboarding.py index fb96f152fa7..3f14fcd71a0 100644 --- a/rocky/rocky/middleware/onboarding.py +++ b/rocky/rocky/middleware/onboarding.py @@ -31,12 +31,12 @@ def middleware(request): if request.user.is_superuser: return redirect(reverse("step_introduction_registration")) - member = OrganizationMember.objects.filter(user=request.user) - # Members with these permissions can run a full DNS-report onboarding. - if member.exists() and member.first().has_perms(ONBOARDING_PERMISSIONS): + if (member := OrganizationMember.objects.filter(user=request.user).first()) and member.has_perms( + ONBOARDING_PERMISSIONS + ): return redirect( - reverse("step_introduction", kwargs={"organization_code": member.first().organization.code}) + reverse("step_introduction", kwargs={"organization_code": member.organization.code}) ) return response diff --git a/rocky/rocky/paginator.py b/rocky/rocky/paginator.py index d78b18c5097..2b74578fe83 100644 --- a/rocky/rocky/paginator.py +++ b/rocky/rocky/paginator.py @@ -1,3 +1,5 @@ +from typing import Any + from django.core.paginator import EmptyPage, Page, PageNotAnInteger, Paginator from django.utils.translation import gettext_lazy as _ @@ -7,25 +9,25 @@ def __init__( self, *args, **kwargs, - ) -> None: + ): super().__init__(*args, **kwargs) if self.orphans != 0: raise ValueError("Setting orphans is not supported") - def validate_number(self, number) -> int: + def validate_number(self, number: Any) -> int: """Validate the given 1-based page number.""" try: if isinstance(number, float) and not number.is_integer(): raise ValueError - number = int(number) + parsed_number = int(number) except (TypeError, ValueError): raise PageNotAnInteger(_("That page number is not an integer")) - if number < 1: + if parsed_number < 1: raise EmptyPage(_("That page number is less than 1")) - return number + return parsed_number - def page(self, number) -> Page: + def page(self, number: Any) -> Page: """Return a Page object per page number.""" number = self.validate_number(number) bottom = (number - 1) * self.per_page diff --git a/rocky/rocky/scheduler.py b/rocky/rocky/scheduler.py index f9ac40fb633..83e46cfd55f 100644 --- a/rocky/rocky/scheduler.py +++ b/rocky/rocky/scheduler.py @@ -126,7 +126,7 @@ class LazyTaskList: def __init__( self, scheduler_client: SchedulerClient, - **kwargs, + **kwargs: Any, ): self.scheduler_client = scheduler_client self.kwargs = kwargs @@ -144,7 +144,7 @@ def count(self) -> int: def __len__(self): return self.count - def __getitem__(self, key) -> list[Task]: + def __getitem__(self, key: slice | int) -> list[Task]: if isinstance(key, slice): offset = key.start or 0 limit = key.stop - offset @@ -167,7 +167,7 @@ def __getitem__(self, key) -> list[Task]: class SchedulerError(Exception): message = _("Connectivity issues with Mula.") - def __str__(self): + def __str__(self) -> str: return str(self.message) @@ -193,7 +193,7 @@ def __init__(self, base_uri: str): def list_tasks( self, - **kwargs, + **kwargs: Any, ) -> PaginatedTasksResponse: kwargs = {k: v for k, v in kwargs.items() if v is not None} # filter Nones from kwargs res = self._client.get("/tasks", params=kwargs) diff --git a/rocky/rocky/views/finding_add.py b/rocky/rocky/views/finding_add.py index e165fd50bb3..a4d439cc04b 100644 --- a/rocky/rocky/views/finding_add.py +++ b/rocky/rocky/views/finding_add.py @@ -1,6 +1,7 @@ from datetime import datetime, timezone from uuid import uuid4 +from django.forms import Form from django.shortcuts import redirect from django.urls.base import reverse from django.utils.translation import gettext_lazy as _ @@ -78,7 +79,7 @@ def get_form_kwargs(self): return kwargs - def get_form(self, form_class=None) -> FindingAddForm: + def get_form(self, form_class: type[Form] | None = None) -> FindingAddForm: if form_class is None: form_class = self.get_form_class() diff --git a/rocky/rocky/views/finding_list.py b/rocky/rocky/views/finding_list.py index 042fd67a250..ff52486d012 100644 --- a/rocky/rocky/views/finding_list.py +++ b/rocky/rocky/views/finding_list.py @@ -7,7 +7,7 @@ from django.views.generic import ListView from tools.forms.base import ObservedAtForm from tools.forms.findings import FindingSeverityMultiSelectForm, MutedFindingSelectionForm -from tools.view_helpers import BreadcrumbsMixin +from tools.view_helpers import Breadcrumb, BreadcrumbsMixin from octopoes.models.ooi.findings import RiskLevelSeverity from rocky.views.mixins import ConnectorFormMixin, FindingList, OctopoesView, SeveritiesMixin @@ -15,7 +15,7 @@ logger = logging.getLogger(__name__) -def sort_by_severity_desc(findings) -> list[dict[str, Any]]: +def sort_by_severity_desc(findings: Iterable) -> list[dict[str, Any]]: # Sorting is stable (when multiple records have the same key, their original # order is preserved) so if we first sort by finding id the findings with # the same risk score will be sorted by finding id @@ -83,7 +83,7 @@ class FindingListView(BreadcrumbsMixin, FindingListFilter): template_name = "findings/finding_list.html" paginate_by = 20 - def build_breadcrumbs(self): + def build_breadcrumbs(self) -> list[Breadcrumb]: return [ { "url": reverse_lazy("finding_list", kwargs={"organization_code": self.organization.code}), @@ -96,7 +96,7 @@ class Top10FindingListView(FindingListView): template_name = "findings/finding_list.html" paginate_by = 10 - def build_breadcrumbs(self): + def build_breadcrumbs(self) -> list[Breadcrumb]: return [ { "url": reverse_lazy("organization_crisis_room", kwargs={"organization_code": self.organization.code}), diff --git a/rocky/rocky/views/health.py b/rocky/rocky/views/health.py index 222df344813..2994b30fe83 100644 --- a/rocky/rocky/views/health.py +++ b/rocky/rocky/views/health.py @@ -1,7 +1,8 @@ import logging +from typing import Any from account.mixins import OrganizationView -from django.http import JsonResponse +from django.http import HttpRequest, JsonResponse from django.urls.base import reverse from django.utils.translation import gettext_lazy as _ from django.views.generic import TemplateView, View @@ -19,7 +20,7 @@ class Health(OrganizationView, View): - def get(self, request, *args, **kwargs) -> JsonResponse: + def get(self, request: HttpRequest, *args: Any, **kwargs: Any) -> JsonResponse: octopoes_connector = self.octopoes_api_connector rocky_health = get_rocky_health(octopoes_connector) return JsonResponse(rocky_health.model_dump()) diff --git a/rocky/rocky/views/mixins.py b/rocky/rocky/views/mixins.py index 17f80b0a353..4502e6f7ac8 100644 --- a/rocky/rocky/views/mixins.py +++ b/rocky/rocky/views/mixins.py @@ -1,4 +1,5 @@ import logging +from collections.abc import Iterable from dataclasses import dataclass from datetime import datetime, timezone from functools import cached_property @@ -17,9 +18,9 @@ from tools.ooi_helpers import get_knowledge_base_data_for_ooi_store from tools.view_helpers import convert_date_to_datetime, get_ooi_url -from octopoes.connector import ObjectNotFoundException from octopoes.connector.octopoes import OctopoesAPIConnector from octopoes.models import OOI, Reference, ScanLevel, ScanProfileType +from octopoes.models.exception import ObjectNotFoundException from octopoes.models.explanation import InheritanceSection from octopoes.models.ooi.findings import Finding, FindingType, RiskLevelSeverity from octopoes.models.origin import Origin, OriginType @@ -122,7 +123,7 @@ def get_origins( logger.error(e) return [], [], [] - def handle_connector_exception(self, exception: Exception): + def handle_connector_exception(self, exception: Exception) -> None: if isinstance(exception, ObjectNotFoundException): raise Http404("OOI not found") @@ -198,7 +199,7 @@ def __init__( self, octopoes_connector: OctopoesAPIConnector, valid_time: datetime, - severities: set[RiskLevelSeverity], + severities: Iterable[RiskLevelSeverity], exclude_muted: bool = True, only_muted: bool = False, ): @@ -305,7 +306,7 @@ def get_breadcrumb_list(self): }, ] - def get_ooi_properties(self, ooi: OOI): + def get_ooi_properties(self, ooi: OOI) -> dict: class_relations = get_relations(ooi.__class__) props = {field_name: value for field_name, value in ooi if field_name not in class_relations} diff --git a/rocky/rocky/views/ooi_detail.py b/rocky/rocky/views/ooi_detail.py index 3052b3e1637..61f7adfadeb 100644 --- a/rocky/rocky/views/ooi_detail.py +++ b/rocky/rocky/views/ooi_detail.py @@ -5,7 +5,7 @@ from django.contrib import messages from django.core.paginator import Page, Paginator -from django.http import Http404 +from django.http import Http404, HttpResponse from django.shortcuts import redirect from django.utils.translation import gettext_lazy as _ from httpx import HTTPError @@ -58,7 +58,7 @@ def post(self, request, *args, **kwargs): action = self.request.POST.get("action") return self.handle_page_action(action) - def handle_page_action(self, action: str) -> bool: + def handle_page_action(self, action: str) -> HttpResponse: try: if action == PageActions.CHANGE_CLEARANCE_LEVEL.value: clearance_level = int(self.request.POST.get("level")) @@ -127,13 +127,13 @@ def get_task_history(self) -> Page: status = self.request.GET.get("task_history_status") - if self.request.GET.get("task_history_from"): - min_created_at = datetime.strptime(self.request.GET.get("task_history_from"), "%Y-%m-%d") + if task_history_from := self.request.GET.get("task_history_from"): + min_created_at = datetime.strptime(task_history_from, "%Y-%m-%d") else: min_created_at = None - if self.request.GET.get("task_history_to"): - max_created_at = datetime.strptime(self.request.GET.get("task_history_to"), "%Y-%m-%d") + if task_history_to := self.request.GET.get("task_history_to"): + max_created_at = datetime.strptime(task_history_to, "%Y-%m-%d") else: max_created_at = None diff --git a/rocky/rocky/views/ooi_detail_related_object.py b/rocky/rocky/views/ooi_detail_related_object.py index e34eeabc36b..ed594add330 100644 --- a/rocky/rocky/views/ooi_detail_related_object.py +++ b/rocky/rocky/views/ooi_detail_related_object.py @@ -1,4 +1,5 @@ from collections import Counter +from typing import Any from django.shortcuts import redirect from django.urls import reverse @@ -10,7 +11,7 @@ from octopoes.models import OOI from octopoes.models.ooi.findings import Finding, FindingType, RiskLevelSeverity from octopoes.models.types import OOI_TYPES, get_relations, to_concrete -from rocky.views.ooi_view import SingleOOITreeMixin +from rocky.views.mixins import SingleOOITreeMixin class OOIRelatedObjectManager(SingleOOITreeMixin): @@ -78,7 +79,7 @@ def get(self, request, *args, **kwargs): return super().get(request, *args, **kwargs) - def split_ooi_type_choice(self, ooi_type_choice) -> dict[str, str]: + def split_ooi_type_choice(self, ooi_type_choice: str) -> dict[str, Any]: ooi_type = ooi_type_choice.split("|", 1) return { diff --git a/rocky/rocky/views/ooi_list.py b/rocky/rocky/views/ooi_list.py index 0c27d65fc7e..9b6c475e53c 100644 --- a/rocky/rocky/views/ooi_list.py +++ b/rocky/rocky/views/ooi_list.py @@ -2,6 +2,7 @@ import json from datetime import datetime, timezone from enum import Enum +from typing import Any from django.contrib import messages from django.http import Http404, HttpRequest, HttpResponse @@ -55,14 +56,14 @@ def get_context_data(self, **kwargs): return context - def get(self, request: HttpRequest, *args, status=200, **kwargs) -> HttpResponse: + def get(self, request: HttpRequest, *args: Any, status: int = 200, **kwargs: Any) -> HttpResponse: """Override the response status in case submitting a form returns an error message""" response = super().get(request, *args, **kwargs) response.status_code = status return response - def post(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: + def post(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse: """Perform bulk action on selected oois.""" selected_oois = request.POST.getlist("ooi") if not selected_oois: @@ -87,7 +88,7 @@ def post(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: return self.get(request, status=404, *args, **kwargs) def _set_scan_profiles( - self, selected_oois: list[Reference], level: CUSTOM_SCAN_LEVEL, request: HttpRequest, *args, **kwargs + self, selected_oois: list[Reference], level: CUSTOM_SCAN_LEVEL, request: HttpRequest, *args: Any, **kwargs: Any ) -> HttpResponse: try: self.raise_clearance_levels(selected_oois, level.value) @@ -153,7 +154,7 @@ def _set_scan_profiles( return self.get(request, *args, **kwargs) def _set_oois_to_inherit( - self, selected_oois: list[Reference], request: HttpRequest, *args, **kwargs + self, selected_oois: list[Reference], request: HttpRequest, *args: Any, **kwargs: Any ) -> HttpResponse: scan_profiles = [EmptyScanProfile(reference=Reference.from_str(ooi)) for ooi in selected_oois] @@ -181,7 +182,9 @@ def _set_oois_to_inherit( ) return self.get(request, *args, **kwargs) - def _delete_oois(self, selected_oois: list[Reference], request: HttpRequest, *args, **kwargs) -> HttpResponse: + def _delete_oois( + self, selected_oois: list[Reference], request: HttpRequest, *args: Any, **kwargs: Any + ) -> HttpResponse: connector = self.octopoes_api_connector valid_time = datetime.now(timezone.utc) diff --git a/rocky/rocky/views/ooi_report.py b/rocky/rocky/views/ooi_report.py index 9496d17bc76..4725e873300 100644 --- a/rocky/rocky/views/ooi_report.py +++ b/rocky/rocky/views/ooi_report.py @@ -174,7 +174,7 @@ def get_ooi_type_filter(cls): return [ooi.get_ooi_type() for ooi in cls.allowed_ooi_types] @classmethod - def get_boefjes(cls, organization: Organization): + def get_boefjes(cls, organization: Organization) -> list: cls.boefjes = [] katalogus_boefjes = get_katalogus(organization.code).get_boefjes() diff --git a/rocky/rocky/views/ooi_tree.py b/rocky/rocky/views/ooi_tree.py index 032abe8aa44..a00d918dd6c 100644 --- a/rocky/rocky/views/ooi_tree.py +++ b/rocky/rocky/views/ooi_tree.py @@ -13,7 +13,7 @@ class OOITreeView(BaseOOIDetailView): def get_tree_dict(self): return create_object_tree_item_from_ref(self.tree.root, self.tree.store) - def get_filtered_tree(self, tree_dict): + def get_filtered_tree(self, tree_dict: dict) -> dict: filtered_types = self.request.GET.getlist("ooi_type", []) return filter_ooi_tree(tree_dict, filtered_types) @@ -60,7 +60,7 @@ def get_last_breadcrumb(self): class OOIGraphView(OOITreeView): template_name = "graph-d3.html" - def get_filtered_tree(self, tree_dict): + def get_filtered_tree(self, tree_dict: dict) -> dict: filtered_tree = super().get_filtered_tree(tree_dict) return hydrate_tree(filtered_tree, self.organization.code) @@ -71,11 +71,11 @@ def get_last_breadcrumb(self): } -def hydrate_tree(tree, organization_code: str): +def hydrate_tree(tree: dict, organization_code: str) -> dict: return hydrate_branch(tree, organization_code) -def hydrate_branch(branch, organization_code: str): +def hydrate_branch(branch: dict, organization_code: str) -> dict: branch["name"] = branch["tree_meta"]["location"] + "-" + branch["ooi_type"] branch["overlay_data"] = {"Type": branch["ooi_type"]} if branch["ooi_type"] == "Finding": diff --git a/rocky/rocky/views/ooi_view.py b/rocky/rocky/views/ooi_view.py index d48205d9d91..7d26929bc50 100644 --- a/rocky/rocky/views/ooi_view.py +++ b/rocky/rocky/views/ooi_view.py @@ -2,7 +2,8 @@ from time import sleep from typing import Any -from django import forms, http +from django import http +from django.forms import Form from django.shortcuts import redirect from django.urls import reverse from django.utils.translation import gettext_lazy as _ @@ -138,12 +139,12 @@ def build_breadcrumbs(self) -> list[Breadcrumb]: class BaseOOIFormView(SingleOOIMixin, FormView): ooi_class: type[OOI] - form_class: forms.Form = OOIForm + form_class: type[BaseRockyForm] = OOIForm def get_ooi_class(self): return self.ooi.__class__ if hasattr(self, "ooi") else None - def get_form(self, form_class=None) -> BaseRockyForm: + def get_form(self, form_class: type[Form] | None = None) -> BaseRockyForm: form = super().get_form(form_class) # Disable natural key attributes diff --git a/rocky/rocky/views/organization_member_add.py b/rocky/rocky/views/organization_member_add.py index f4af5b320ef..c1f724890a6 100644 --- a/rocky/rocky/views/organization_member_add.py +++ b/rocky/rocky/views/organization_member_add.py @@ -10,6 +10,7 @@ from django.contrib.auth.models import Group from django.core.exceptions import ObjectDoesNotExist, ValidationError from django.db import transaction +from django.forms import Form from django.http import FileResponse, HttpRequest, HttpResponse from django.shortcuts import redirect from django.urls import reverse_lazy @@ -19,7 +20,7 @@ from onboarding.view_helpers import DNS_REPORT_LEAST_CLEARANCE_LEVEL from tools.forms.upload_csv import UploadCSVForm from tools.models import GROUP_ADMIN, GROUP_CLIENT, GROUP_REDTEAM, OrganizationMember -from tools.view_helpers import OrganizationMemberBreadcrumbsMixin +from tools.view_helpers import Breadcrumb, OrganizationMemberBreadcrumbsMixin from rocky.messaging import clearance_level_warning_dns_report @@ -65,7 +66,7 @@ def get(self, request: HttpRequest, *args: str, **kwargs: Any) -> HttpResponse: ) ) - def build_breadcrumbs(self): + def build_breadcrumbs(self) -> list[Breadcrumb]: breadcrumbs = super().build_breadcrumbs() breadcrumbs.append( { @@ -110,7 +111,7 @@ def add_success_notification(self): def get_success_url(self, **kwargs): return reverse_lazy("organization_member_list", kwargs={"organization_code": self.organization.code}) - def build_breadcrumbs(self): + def build_breadcrumbs(self) -> list[Breadcrumb]: breadcrumbs = super().build_breadcrumbs() breadcrumbs.extend( [ @@ -159,7 +160,7 @@ def form_valid(self, form): self.process_csv(form) return super().form_valid(form) - def process_csv(self, form) -> None: + def process_csv(self, form: Form) -> None: csv_raw_data = form.cleaned_data["csv_file"].read() csv_data = io.StringIO(csv_raw_data.decode("UTF-8")) @@ -215,7 +216,7 @@ def process_csv(self, form) -> None: def save_models( self, name: str, email: str, account_type: str, trusted_clearance: int, acknowledged_clearance: int - ): + ) -> None: user, user_created = User.objects.get_or_create(email=email, defaults={"full_name": name}) member_kwargs = { diff --git a/rocky/rocky/views/organization_member_list.py b/rocky/rocky/views/organization_member_list.py index 3a53e04b07b..f4077fde35b 100644 --- a/rocky/rocky/views/organization_member_list.py +++ b/rocky/rocky/views/organization_member_list.py @@ -66,7 +66,7 @@ def post(self, request, *args, **kwargs): self.handle_page_action(request.POST.get("action")) return redirect(reverse("organization_member_list", kwargs={"organization_code": self.organization.code})) - def handle_page_action(self, action: str): + def handle_page_action(self, action: str) -> None: member_id = self.request.POST.get("member_id") organizationmember = self.model.objects.get(id=member_id) try: diff --git a/rocky/rocky/views/organization_settings.py b/rocky/rocky/views/organization_settings.py index 838c07f500a..9ec4a7f168e 100644 --- a/rocky/rocky/views/organization_settings.py +++ b/rocky/rocky/views/organization_settings.py @@ -1,5 +1,6 @@ from datetime import datetime from enum import Enum +from typing import Any from account.mixins import OrganizationPermissionRequiredMixin, OrganizationView from django.contrib import messages @@ -20,7 +21,7 @@ class OrganizationSettingsView( template_name = "organizations/organization_settings.html" permission_required = "tools.view_organization" - def post(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: + def post(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse: """Perform actions based on action type""" action = request.POST.get("action") if not self.request.user.has_perm("tools.can_recalculate_bits"): diff --git a/rocky/rocky/views/privacy_statement.py b/rocky/rocky/views/privacy_statement.py index c1d0e8ee21e..ec4a516b802 100644 --- a/rocky/rocky/views/privacy_statement.py +++ b/rocky/rocky/views/privacy_statement.py @@ -1,4 +1,4 @@ -from django.shortcuts import reverse +from django.urls import reverse from django.utils.translation import gettext_lazy as _ from django.views.generic import TemplateView diff --git a/rocky/rocky/views/scan_profile.py b/rocky/rocky/views/scan_profile.py index d82620b7cbe..0928ec28d1a 100644 --- a/rocky/rocky/views/scan_profile.py +++ b/rocky/rocky/views/scan_profile.py @@ -17,7 +17,7 @@ class ScanProfileDetailView(OOIDetailView, FormView): template_name = "scan_profiles/scan_profile_detail.html" form_class = SetClearanceLevelForm - def get_context_data(self, **kwargs) -> dict[str, Any]: + def get_context_data(self, **kwargs: Any) -> dict[str, Any]: context = super().get_context_data(**kwargs) context["mandatory_fields"] = get_mandatory_fields(self.request) context["user"] = self.organization_member diff --git a/rocky/rocky/views/tasks.py b/rocky/rocky/views/tasks.py index 897f89eec5f..8b4b65768dd 100644 --- a/rocky/rocky/views/tasks.py +++ b/rocky/rocky/views/tasks.py @@ -37,6 +37,7 @@ def get(self, request, *args, **kwargs): class TaskListView(OrganizationView, ListView): paginate_by = 20 paginator_class = RockyPaginator + plugin_type: str def get_queryset(self): scheduler_id = self.plugin_type + "-" + self.organization.code @@ -46,13 +47,13 @@ def get_queryset(self): input_ooi = self.request.GET.get("scan_history_search") if self.request.GET.get("scan_history_search") else None - if self.request.GET.get("scan_history_from"): - min_created_at = datetime.strptime(self.request.GET.get("scan_history_from"), "%Y-%m-%d") + if scan_history_from := self.request.GET.get("scan_history_from"): + min_created_at = datetime.strptime(scan_history_from, "%Y-%m-%d") else: min_created_at = None - if self.request.GET.get("scan_history_to"): - max_created_at = datetime.strptime(self.request.GET.get("scan_history_to"), "%Y-%m-%d") + if scan_history_to := self.request.GET.get("scan_history_to"): + max_created_at = datetime.strptime(scan_history_to, "%Y-%m-%d") else: max_created_at = None diff --git a/rocky/rocky/views/upload_csv.py b/rocky/rocky/views/upload_csv.py index f2056b6aa3d..b1527fba5f3 100644 --- a/rocky/rocky/views/upload_csv.py +++ b/rocky/rocky/views/upload_csv.py @@ -16,7 +16,7 @@ from tools.forms.upload_oois import UploadOOICSVForm from octopoes.api.models import Declaration -from octopoes.models import Reference +from octopoes.models import OOI, Reference from octopoes.models.ooi.dns.zone import Hostname from octopoes.models.ooi.network import IPAddressV4, IPAddressV6, Network from octopoes.models.ooi.web import URL @@ -79,7 +79,7 @@ def get_context_data(self, **kwargs): context["criteria"] = CSV_CRITERIA return context - def get_or_create_reference(self, ooi_type_name: str, value: str | None): + def get_or_create_reference(self, ooi_type_name: str, value: str | None) -> OOI: ooi_type_name = next(filter(lambda x: x.casefold() == ooi_type_name.casefold(), self.ooi_types.keys())) # get from cache @@ -100,7 +100,7 @@ def get_or_create_reference(self, ooi_type_name: str, value: str | None): return ooi - def get_ooi_from_csv(self, ooi_type_name: str, values: dict[str, str]): + def get_ooi_from_csv(self, ooi_type_name: str, values: dict[str, str]) -> tuple[OOI, int | None]: key = "clearance" level = int(values[key]) if key in values and values[key] in CLEARANCE_VALUES else None ooi_type = self.ooi_types[ooi_type_name]["type"] @@ -110,7 +110,7 @@ def get_ooi_from_csv(self, ooi_type_name: str, values: dict[str, str]): if field not in self.skip_properties ] - kwargs = {} + kwargs: dict[str, Any] = {} for field, is_reference, required in ooi_fields: if is_reference and required: try: diff --git a/rocky/tools/add_ooi_information.py b/rocky/tools/add_ooi_information.py index 755f4704c25..00dff377538 100644 --- a/rocky/tools/add_ooi_information.py +++ b/rocky/tools/add_ooi_information.py @@ -60,7 +60,7 @@ def iana_service_table(search_query: str) -> list[_Service]: return services -def service_info(value) -> tuple[str, str]: +def service_info(value: str) -> tuple[str, str]: """Provides information about IP Services such as common assigned ports for certain protocols and descriptions""" services = iana_service_table(value) source = "https://www.iana.org/assignments/service-names-port-numbers/service-names-port-numbers.xhtml" diff --git a/rocky/tools/forms/base.py b/rocky/tools/forms/base.py index 7bf73024c8f..0491ea4bbbd 100644 --- a/rocky/tools/forms/base.py +++ b/rocky/tools/forms/base.py @@ -98,17 +98,17 @@ class CheckboxGroup(forms.CheckboxSelectMultiple): def __init__( self, required_options: list[str] | None = None, - *args, - **kwargs, + *args: Any, + **kwargs: Any, ) -> None: super().__init__(*args, **kwargs) self.required_options = required_options or [] - def get_context(self, name, value, attrs) -> dict[str, Any]: + def get_context(self, name: str, value: Any, attrs: dict) -> dict[str, Any] | None: context = super().get_context(name, value, attrs) return context - def create_option(self, *arg, **kwargs) -> dict[str, Any]: + def create_option(self, *arg: Any, **kwargs: Any) -> dict[str, Any]: option = super().create_option(*arg, **kwargs) option["wrap_label"] = self.wrap_label option["attrs"]["checked"] = self.is_required_option(option["value"]) diff --git a/rocky/tools/forms/finding_type.py b/rocky/tools/forms/finding_type.py index ee06c19c281..29311df7bde 100644 --- a/rocky/tools/forms/finding_type.py +++ b/rocky/tools/forms/finding_type.py @@ -1,12 +1,13 @@ from datetime import datetime, timezone +from typing import Any from django import forms from django.core.exceptions import ValidationError from django.utils.translation import gettext_lazy as _ -from octopoes.connector import ObjectNotFoundException from octopoes.connector.octopoes import OctopoesAPIConnector from octopoes.models import Reference +from octopoes.models.exception import ObjectNotFoundException from tools.forms.base import BaseRockyForm, DataListInput, DateTimeInput from tools.forms.settings import ( FINDING_DATETIME_HELP_TEXT, @@ -141,8 +142,8 @@ def __init__( self, connector: OctopoesAPIConnector, ooi_list: list[tuple[str, str]], - *args, - **kwargs, + *args: Any, + **kwargs: Any, ): self.octopoes_connector = connector super().__init__(*args, **kwargs) diff --git a/rocky/tools/forms/ooi.py b/rocky/tools/forms/ooi.py index 849862f46af..d925be98039 100644 --- a/rocky/tools/forms/ooi.py +++ b/rocky/tools/forms/ooi.py @@ -25,7 +25,7 @@ class OoiTreeSettingsForm(OOIReportSettingsForm): required=False, ) - def __init__(self, ooi_types: list[str], *args, **kwargs): + def __init__(self, ooi_types: list[str], *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) self.set_ooi_types(ooi_types) @@ -55,8 +55,8 @@ def __init__( oois: list[OOI], organization_code: str, mandatory_fields: list | None = None, - *args, - **kwargs, + *args: Any, + **kwargs: Any, ): super().__init__(*args, **kwargs) self.fields["ooi"].widget.attrs["organization_code"] = organization_code diff --git a/rocky/tools/forms/ooi_form.py b/rocky/tools/forms/ooi_form.py index 372a26164ec..05f5834e502 100644 --- a/rocky/tools/forms/ooi_form.py +++ b/rocky/tools/forms/ooi_form.py @@ -2,7 +2,7 @@ from enum import Enum from inspect import isclass from ipaddress import IPv4Address, IPv6Address -from typing import Literal, Union, get_args, get_origin +from typing import Any, Literal, Union, get_args, get_origin from django import forms from django.utils.translation import gettext_lazy as _ @@ -13,13 +13,13 @@ from octopoes.models import OOI from octopoes.models.ooi.question import Question from octopoes.models.types import get_collapsed_types, get_relations +from tools.enums import SCAN_LEVEL from tools.forms.base import BaseRockyForm, CheckboxGroup from tools.forms.settings import CLEARANCE_TYPE_CHOICES -from tools.models import SCAN_LEVEL class OOIForm(BaseRockyForm): - def __init__(self, ooi_class: type[OOI], connector: OctopoesAPIConnector, *args, **kwargs): + def __init__(self, ooi_class: type[OOI], connector: OctopoesAPIConnector, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) self.ooi_class = ooi_class self.api_connector = connector diff --git a/rocky/tools/management/commands/export_migrations.py b/rocky/tools/management/commands/export_migrations.py index 0fc4711db67..eb8980e8bb5 100644 --- a/rocky/tools/management/commands/export_migrations.py +++ b/rocky/tools/management/commands/export_migrations.py @@ -1,5 +1,6 @@ from logging import getLogger from pathlib import Path +from typing import Any from django.core.management import BaseCommand, CommandParser from django.db import DEFAULT_DB_ALIAS, connections @@ -22,7 +23,7 @@ def add_arguments(self, parser: CommandParser) -> None: help="Output folder", ) - def handle(self, **options) -> None: + def handle(self, **options: Any) -> None: # Get the database we're operating from connection = connections[DEFAULT_DB_ALIAS] diff --git a/rocky/tools/management/commands/generate_report.py b/rocky/tools/management/commands/generate_report.py index 8782885979b..42c927bed5b 100644 --- a/rocky/tools/management/commands/generate_report.py +++ b/rocky/tools/management/commands/generate_report.py @@ -75,7 +75,9 @@ def handle(self, *args, **options): self.stdout.buffer.write(report.read()) @staticmethod - def get_findings_metadata(organization, valid_time, severities) -> list[dict[str, Any]]: + def get_findings_metadata( + organization: Organization, valid_time: datetime, severities: list[RiskLevelSeverity] + ) -> list[dict[str, Any]]: findings = FindingList( OctopoesAPIConnector(settings.OCTOPOES_API, organization.code), valid_time, @@ -85,7 +87,7 @@ def get_findings_metadata(organization, valid_time, severities) -> list[dict[str return generate_findings_metadata(findings, severities) @staticmethod - def get_organization(**options) -> Organization | None: + def get_organization(**options: str) -> Organization | None: if options["code"] and options["id"]: return None diff --git a/rocky/tools/management/commands/setup_test_users.py b/rocky/tools/management/commands/setup_test_users.py index cdb7973c5a1..134ba825e1d 100644 --- a/rocky/tools/management/commands/setup_test_users.py +++ b/rocky/tools/management/commands/setup_test_users.py @@ -19,7 +19,7 @@ def handle(self, **options): add_test_user("e2e-client", password, GROUP_CLIENT) -def add_superuser(email: str, password: str): +def add_superuser(email: str, password: str) -> None: user_kwargs: dict[str, str | bool] = { "email": email, "password": password, @@ -31,7 +31,7 @@ def add_superuser(email: str, password: str): add_user(user_kwargs) -def add_test_user(email: str, password: str, group_name: str | None = None): +def add_test_user(email: str, password: str, group_name: str | None = None) -> None: user_kwargs: dict[str, str | bool] = { "email": email, "password": password, @@ -41,7 +41,7 @@ def add_test_user(email: str, password: str, group_name: str | None = None): add_user(user_kwargs, group_name) -def add_user(user_kwargs: dict[str, str | bool], group_name: str | None = None): +def add_user(user_kwargs: dict[str, str | bool], group_name: str | None = None) -> None: """ Creates a test user with the given user_kwargs. User is optionally added to group group_name. diff --git a/rocky/tools/models.py b/rocky/tools/models.py index 5e66ac7d340..b2dddc8cb65 100644 --- a/rocky/tools/models.py +++ b/rocky/tools/models.py @@ -20,7 +20,7 @@ from octopoes.api.models import Declaration from octopoes.connector.octopoes import OctopoesAPIConnector -from octopoes.models.ooi.web import Network +from octopoes.models.ooi.network import Network from rocky.exceptions import OctopoesDownException, OctopoesException, OctopoesUnhealthyException from tools.add_ooi_information import SEPARATOR, get_info from tools.enums import SCAN_LEVEL @@ -91,7 +91,7 @@ class Organization(models.Model): ) tags = tagulous.models.TagField(to=OrganizationTag, blank=True) - def __str__(self): + def __str__(self) -> str: return str(self.name) class Meta: @@ -258,7 +258,7 @@ def has_perms(self, perm_list: Iterable[str]) -> bool: class Meta: unique_together = ["user", "organization"] - def __str__(self): + def __str__(self) -> str: return str(self.user) @@ -304,5 +304,5 @@ def get_internet_description(self): self.data[key] = value self.save() - def __str__(self): + def __str__(self) -> str: return self.id diff --git a/rocky/tools/ooi_helpers.py b/rocky/tools/ooi_helpers.py index 7d2c6238254..b72ed60b8dd 100644 --- a/rocky/tools/ooi_helpers.py +++ b/rocky/tools/ooi_helpers.py @@ -1,3 +1,4 @@ +from collections.abc import Sequence from datetime import datetime from enum import Enum from typing import Any @@ -47,7 +48,7 @@ def format_display(data: dict, ignore: list | None = None) -> dict[str, str]: return {format_attr_name(k): format_value(v) for k, v in data.items() if k not in ignore} -def get_knowledge_base_data_for_ooi_store(ooi_store) -> dict[str, dict]: +def get_knowledge_base_data_for_ooi_store(ooi_store: dict) -> dict[str, dict]: knowledge_base = {} for ooi in ooi_store.values(): @@ -130,9 +131,9 @@ def create_object_tree_item_from_ref( reference_node: ReferenceNode, ooi_store: dict[str, OOI], knowledge_base: dict[str, dict] | None = None, - depth=0, - position=1, - location="loc", + depth: int = 0, + position: int = 1, + location: str = "loc", ) -> dict: depth = sum([depth, 1]) location = location + "-" + str(position) @@ -181,7 +182,7 @@ def get_ooi_types_from_tree(ooi, include_self=True): return sorted(types) -def filter_ooi_tree(ooi_node: dict, show_types=[], hide_types=[]) -> dict: +def filter_ooi_tree(ooi_node: dict, show_types: Sequence = [], hide_types: Sequence = []) -> dict: if not show_types and not hide_types: return ooi_node diff --git a/rocky/tools/templatetags/ooi_extra.py b/rocky/tools/templatetags/ooi_extra.py index 3757e9efa3f..63ee811f8fe 100644 --- a/rocky/tools/templatetags/ooi_extra.py +++ b/rocky/tools/templatetags/ooi_extra.py @@ -12,7 +12,7 @@ @register.filter -def get_encoded_dict(data_dict: dict): +def get_encoded_dict(data_dict: dict) -> str: return parse.urlencode(data_dict) @@ -37,27 +37,27 @@ def get_scan_levels() -> list[str]: @register.filter -def ooi_types_to_strings(ooi_types: set[type[OOI]]): +def ooi_types_to_strings(ooi_types: set[type[OOI]]) -> list["str"]: return [ooi_type.get_ooi_type() for ooi_type in ooi_types] @register.filter() -def get_type(x: Any): +def get_type(x: Any) -> Any: return type(x) @register.simple_tag() -def ooi_url(routename: str, ooi_id: str, organization_code: str, **kwargs) -> str: +def ooi_url(routename: str, ooi_id: str, organization_code: str, **kwargs: str) -> str: return get_ooi_url(routename, ooi_id, organization_code, **kwargs) @register.filter() -def is_finding(ooi: OOI): +def is_finding(ooi: OOI) -> bool: return isinstance(ooi, Finding) @register.filter() -def is_finding_type(ooi: OOI): +def is_finding_type(ooi: OOI) -> bool: return isinstance(ooi, FindingType) @@ -79,5 +79,5 @@ def index(indexable, i): @register.filter -def pretty_json(obj: dict): +def pretty_json(obj: dict) -> str: return json.dumps(obj, default=str, indent=4) diff --git a/rocky/tools/view_helpers.py b/rocky/tools/view_helpers.py index 741c59c868a..06f241ce45d 100644 --- a/rocky/tools/view_helpers.py +++ b/rocky/tools/view_helpers.py @@ -1,12 +1,13 @@ import uuid from datetime import date, datetime, timezone -from typing import TypedDict +from typing import Any, TypedDict from urllib.parse import urlencode, urlparse, urlunparse from django.contrib import messages from django.http import HttpRequest from django.urls.base import reverse, reverse_lazy from django.utils.translation import gettext_lazy as _ +from django_stubs_ext import StrPromise from octopoes.models.types import OOI_TYPES from rocky.scheduler import PrioritizedItem, SchedulerError, client @@ -18,7 +19,7 @@ def convert_date_to_datetime(d: date) -> datetime: return datetime.combine(d, datetime.max.time(), tzinfo=timezone.utc) -def get_mandatory_fields(request, params: list[str] | None = None): +def get_mandatory_fields(request: HttpRequest, params: list[str] | None = None) -> list: mandatory_fields = [] if not params: @@ -38,7 +39,7 @@ def generate_job_id(): return str(uuid.uuid4()) -def url_with_querystring(path, doseq=False, **kwargs) -> str: +def url_with_querystring(path: str, doseq: bool = False, /, **kwargs: Any) -> str: parsed_route = urlparse(path) return str( @@ -55,7 +56,7 @@ def url_with_querystring(path, doseq=False, **kwargs) -> str: ) -def get_ooi_url(routename: str, ooi_id: str, organization_code: str, **kwargs) -> str: +def get_ooi_url(routename: str, ooi_id: str, organization_code: str, **kwargs: Any) -> str: if ooi_id: kwargs["ooi_id"] = ooi_id @@ -68,7 +69,7 @@ def get_ooi_url(routename: str, ooi_id: str, organization_code: str, **kwargs) - return url_with_querystring(reverse(routename, kwargs={"organization_code": organization_code}), **kwargs) -def existing_ooi_type(ooi_type: str): +def existing_ooi_type(ooi_type: str) -> bool: if not ooi_type: return False @@ -76,7 +77,7 @@ def existing_ooi_type(ooi_type: str): class Breadcrumb(TypedDict): - text: str + text: StrPromise url: str @@ -130,35 +131,31 @@ class OrganizationBreadcrumbsMixin(BreadcrumbsMixin): class OrganizationDetailBreadcrumbsMixin(BreadcrumbsMixin): organization: Organization - def build_breadcrumbs(self): - breadcrumbs = [ + def build_breadcrumbs(self) -> list[Breadcrumb]: + return [ { "url": reverse("organization_settings", kwargs={"organization_code": self.organization.code}), "text": _("Settings"), }, ] - return breadcrumbs - class OrganizationMemberBreadcrumbsMixin(BreadcrumbsMixin): organization: Organization - def build_breadcrumbs(self): - breadcrumbs = [ + def build_breadcrumbs(self) -> list[Breadcrumb]: + return [ { "url": reverse("organization_member_list", kwargs={"organization_code": self.organization.code}), "text": _("Members"), }, ] - return breadcrumbs - class ObjectsBreadcrumbsMixin(BreadcrumbsMixin): organization: Organization - def build_breadcrumbs(self): + def build_breadcrumbs(self) -> list[Breadcrumb]: return [ { "url": reverse_lazy("ooi_list", kwargs={"organization_code": self.organization.code}), From b7d3a86b312ffb594d4e057d30293f539ca7b012 Mon Sep 17 00:00:00 2001 From: Jeroen Dekkers Date: Tue, 14 May 2024 17:41:19 +0200 Subject: [PATCH 2/9] More typing --- rocky/account/admin.py | 1 - rocky/account/mixins.py | 2 +- rocky/katalogus/client.py | 10 +++++++--- rocky/katalogus/views/katalogus_settings.py | 2 +- rocky/katalogus/views/plugin_detail.py | 2 ++ rocky/reports/report_types/definitions.py | 2 +- rocky/rocky/views/ooi_detail.py | 16 ++++++++++++---- rocky/rocky/views/ooi_list.py | 17 ++++++++--------- rocky/rocky/views/ooi_report.py | 3 +-- rocky/rocky/views/ooi_view.py | 2 +- rocky/rocky/views/organization_list.py | 4 ++-- rocky/rocky/views/organization_member_add.py | 2 +- rocky/rocky/views/organization_member_list.py | 4 ++-- rocky/rocky/views/scans.py | 3 ++- rocky/tools/admin.py | 10 ++++++---- rocky/tools/forms/base.py | 2 +- rocky/tools/forms/ooi_form.py | 11 ++++++++--- rocky/tools/forms/settings.py | 5 ++--- rocky/tools/models.py | 1 + rocky/tools/view_helpers.py | 4 ++-- 20 files changed, 61 insertions(+), 42 deletions(-) diff --git a/rocky/account/admin.py b/rocky/account/admin.py index 564513931ac..3666a9726b0 100644 --- a/rocky/account/admin.py +++ b/rocky/account/admin.py @@ -10,7 +10,6 @@ @admin.register(User) class KATUserAdmin(UserAdmin): - model = User list_display = ( "email", "is_staff", diff --git a/rocky/account/mixins.py b/rocky/account/mixins.py index e2de135bec4..4949925c771 100644 --- a/rocky/account/mixins.py +++ b/rocky/account/mixins.py @@ -6,7 +6,7 @@ from django.core.exceptions import PermissionDenied from django.http import Http404 from django.utils.translation import gettext_lazy as _ -from django.views import ContextMixin, View +from django.views.generic.base import ContextMixin, View from tools.models import Indemnification, Organization, OrganizationMember from octopoes.connector.octopoes import OctopoesAPIConnector diff --git a/rocky/katalogus/client.py b/rocky/katalogus/client.py index 5bfed248aa0..6280d476b22 100644 --- a/rocky/katalogus/client.py +++ b/rocky/katalogus/client.py @@ -1,5 +1,6 @@ from io import BytesIO from logging import getLogger +from typing import TYPE_CHECKING import httpx from django.conf import settings @@ -8,7 +9,10 @@ from jsonschema.validators import Draft202012Validator from pydantic import BaseModel, Field, field_serializer from tools.enums import SCAN_LEVEL -from tools.models import OrganizationMember + +if TYPE_CHECKING: + # This prevents circurlar import + from tools.models import OrganizationMember from octopoes.models import OOI from octopoes.models.exception import TypeNotFound @@ -35,7 +39,7 @@ class Plugin(BaseModel): # """Pydantic does not stringify the OOI classes, but then templates can't render them""" # # todo: use field_serializer instead - def can_scan(self, member: OrganizationMember) -> bool: + def can_scan(self, member: "OrganizationMember") -> bool: return member.has_perm("tools.can_scan_organization") @@ -51,7 +55,7 @@ class Boefje(Plugin): def serialize_consumes(self, consumes: set[type[OOI]]) -> set[str]: return {ooi_class.get_ooi_type() for ooi_class in consumes} - def can_scan(self, member: OrganizationMember) -> bool: + def can_scan(self, member: "OrganizationMember") -> bool: return super().can_scan(member) and member.acknowledged_clearance_level >= self.scan_level.value diff --git a/rocky/katalogus/views/katalogus_settings.py b/rocky/katalogus/views/katalogus_settings.py index 14edf7a5719..933232624d9 100644 --- a/rocky/katalogus/views/katalogus_settings.py +++ b/rocky/katalogus/views/katalogus_settings.py @@ -87,7 +87,7 @@ def get_settings(self): messages.add_message( self.request, messages.ERROR, - _("Failed getting settings for boefje {}").format(self.plugin.id), + _("Failed getting settings for boefje {}").format(boefje.id), ) continue diff --git a/rocky/katalogus/views/plugin_detail.py b/rocky/katalogus/views/plugin_detail.py index 4774321cc47..ef5e838a700 100644 --- a/rocky/katalogus/views/plugin_detail.py +++ b/rocky/katalogus/views/plugin_detail.py @@ -83,6 +83,8 @@ def post(self, request, *args, **kwargs): def handle_page_action(self, action: str) -> None: if action == PageActions.RESCHEDULE_TASK.value: task_id = self.request.POST.get("task_id") + if not task_id: + raise ValueError("Missing task_id value") reschedule_task(self.request, self.organization.code, task_id) def get_context_data(self, **kwargs): diff --git a/rocky/reports/report_types/definitions.py b/rocky/reports/report_types/definitions.py index aaaeac822bf..941ca455ea5 100644 --- a/rocky/reports/report_types/definitions.py +++ b/rocky/reports/report_types/definitions.py @@ -24,7 +24,7 @@ class ReportPlugins(TypedDict): class BaseReport: id: str name: StrPromise - description: str + description: StrPromise template_path: str = "report.html" label_style = "1-light" # default/fallback color diff --git a/rocky/rocky/views/ooi_detail.py b/rocky/rocky/views/ooi_detail.py index 61f7adfadeb..30785ce6a7b 100644 --- a/rocky/rocky/views/ooi_detail.py +++ b/rocky/rocky/views/ooi_detail.py @@ -50,19 +50,21 @@ def post(self, request, *args, **kwargs): ) return self.get(request, status_code=403, *args, **kwargs) - if "action" not in self.request.POST: + action = self.request.POST.get("action") + if not action: return self.get(request, status_code=404, *args, **kwargs) self.ooi = self.get_ooi() - action = self.request.POST.get("action") return self.handle_page_action(action) def handle_page_action(self, action: str) -> HttpResponse: try: if action == PageActions.CHANGE_CLEARANCE_LEVEL.value: - clearance_level = int(self.request.POST.get("level")) - if not self.can_raise_clearance_level(self.ooi, clearance_level): + clearance_level = self.request.POST.get("level") + if clearance_level is None: + raise ValueError("Missing boefje_id parameter") + if not self.can_raise_clearance_level(self.ooi, int(clearance_level)): return redirect("account_detail", organization_code=self.organization.code) return self.get(self.request, *self.args, **self.kwargs) @@ -72,7 +74,11 @@ def handle_page_action(self, action: str) -> HttpResponse: if action == PageActions.START_SCAN.value: boefje_id = self.request.POST.get("boefje_id") + if boefje_id is None: + raise ValueError("Missing boefje_id parameter") ooi_id = self.request.GET.get("ooi_id") + if ooi_id is None: + raise ValueError("Missing boefje_id parameter") boefje = get_katalogus(self.organization.code).get_plugin(boefje_id) ooi = self.get_single_ooi(pk=ooi_id) @@ -85,6 +91,8 @@ def handle_page_action(self, action: str) -> HttpResponse: return self.get(self.request, status_code=500, *self.args, **self.kwargs) schema_answer = self.request.POST.get("schema") + if schema_answer is None: + raise ValueError("Missing schema parameter") parsed_schema_answer = json.loads(schema_answer) validator = Draft202012Validator(json.loads(self.ooi.json_schema)) diff --git a/rocky/rocky/views/ooi_list.py b/rocky/rocky/views/ooi_list.py index 9b6c475e53c..372c954c29a 100644 --- a/rocky/rocky/views/ooi_list.py +++ b/rocky/rocky/views/ooi_list.py @@ -8,7 +8,8 @@ from django.http import Http404, HttpRequest, HttpResponse from django.shortcuts import redirect from django.urls import reverse, reverse_lazy -from django.utils.translation import gettext_lazy as _ +from django.utils.translation import gettext as _ +from django.utils.translation import gettext_lazy from httpx import HTTPError from tools.enums import CUSTOM_SCAN_LEVEL from tools.forms.ooi import SelectOOIForm @@ -34,7 +35,7 @@ class PageActions(Enum): class OOIListView(BaseOOIListView, OctopoesView): - breadcrumbs = [{"url": reverse_lazy("ooi_list"), "text": _("Objects")}] + breadcrumbs = [{"url": reverse_lazy("ooi_list"), "text": gettext_lazy("Objects")}] template_name = "oois/ooi_list.html" def get_context_data(self, **kwargs): @@ -88,10 +89,10 @@ def post(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse: return self.get(request, status=404, *args, **kwargs) def _set_scan_profiles( - self, selected_oois: list[Reference], level: CUSTOM_SCAN_LEVEL, request: HttpRequest, *args: Any, **kwargs: Any + self, selected_oois: list[str], level: CUSTOM_SCAN_LEVEL, request: HttpRequest, *args: Any, **kwargs: Any ) -> HttpResponse: try: - self.raise_clearance_levels(selected_oois, level.value) + self.raise_clearance_levels([Reference.from_str(ooi) for ooi in selected_oois], level.value) except IndemnificationNotPresentException: messages.add_message( self.request, @@ -154,7 +155,7 @@ def _set_scan_profiles( return self.get(request, *args, **kwargs) def _set_oois_to_inherit( - self, selected_oois: list[Reference], request: HttpRequest, *args: Any, **kwargs: Any + self, selected_oois: list[str], request: HttpRequest, *args: Any, **kwargs: Any ) -> HttpResponse: scan_profiles = [EmptyScanProfile(reference=Reference.from_str(ooi)) for ooi in selected_oois] @@ -182,14 +183,12 @@ def _set_oois_to_inherit( ) return self.get(request, *args, **kwargs) - def _delete_oois( - self, selected_oois: list[Reference], request: HttpRequest, *args: Any, **kwargs: Any - ) -> HttpResponse: + def _delete_oois(self, selected_oois: list[str], request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse: connector = self.octopoes_api_connector valid_time = datetime.now(timezone.utc) try: - connector.delete_many(selected_oois, valid_time) + connector.delete_many([Reference.from_ooi(ooi) for ooi in selected_oois], valid_time) except (HTTPError, RemoteException, ConnectionError): messages.add_message(request, messages.ERROR, _("An error occurred while deleting oois.")) return self.get(request, status=500, *args, **kwargs) diff --git a/rocky/rocky/views/ooi_report.py b/rocky/rocky/views/ooi_report.py index 4725e873300..af44aaf29a7 100644 --- a/rocky/rocky/views/ooi_report.py +++ b/rocky/rocky/views/ooi_report.py @@ -1,4 +1,4 @@ -from datetime import datetime +from datetime import datetime, timezone from typing import Any from account.mixins import OrganizationView @@ -6,7 +6,6 @@ from django.http import FileResponse, HttpRequest, HttpResponse from django.shortcuts import redirect from django.urls import reverse -from django.utils import timezone from django.utils.translation import gettext_lazy as _ from katalogus.client import get_katalogus from tools.forms.ooi import OOIReportSettingsForm diff --git a/rocky/rocky/views/ooi_view.py b/rocky/rocky/views/ooi_view.py index 7d26929bc50..842db849585 100644 --- a/rocky/rocky/views/ooi_view.py +++ b/rocky/rocky/views/ooi_view.py @@ -6,7 +6,7 @@ from django.forms import Form from django.shortcuts import redirect from django.urls import reverse -from django.utils.translation import gettext_lazy as _ +from django.utils.translation import gettext as _ from django.views.generic import ListView, TemplateView from django.views.generic.edit import FormView from pydantic import ValidationError diff --git a/rocky/rocky/views/organization_list.py b/rocky/rocky/views/organization_list.py index 6bf12f1956a..f0f7451490e 100644 --- a/rocky/rocky/views/organization_list.py +++ b/rocky/rocky/views/organization_list.py @@ -1,5 +1,5 @@ from account.models import KATUser -from django.db.models import Count +from django.db.models import Count, QuerySet from django.views.generic import ListView from tools.models import Organization from tools.view_helpers import OrganizationBreadcrumbsMixin @@ -11,7 +11,7 @@ class OrganizationListView( ): template_name = "organizations/organization_list.html" - def get_queryset(self) -> list[Organization]: + def get_queryset(self) -> QuerySet[Organization]: user: KATUser = self.request.user return ( Organization.objects.annotate(member_count=Count("members")) diff --git a/rocky/rocky/views/organization_member_add.py b/rocky/rocky/views/organization_member_add.py index c1f724890a6..775abe85230 100644 --- a/rocky/rocky/views/organization_member_add.py +++ b/rocky/rocky/views/organization_member_add.py @@ -179,7 +179,7 @@ def process_csv(self, form: Form) -> None: ) except KeyError: messages.add_message(self.request, messages.ERROR, _("The csv file is missing required columns")) - return redirect("organization_member_upload", self.organization.code) + return try: with transaction.atomic(): diff --git a/rocky/rocky/views/organization_member_list.py b/rocky/rocky/views/organization_member_list.py index f4077fde35b..d592b600e5b 100644 --- a/rocky/rocky/views/organization_member_list.py +++ b/rocky/rocky/views/organization_member_list.py @@ -38,7 +38,7 @@ def get_queryset(self): queryset = self.model.objects.filter(organization=self.organization) if "client_status" in self.request.GET: status_filter = self.request.GET.getlist("client_status", []) - queryset = [member for member in queryset if member.status in status_filter] + queryset = queryset.filter(status__in=status_filter) if "blocked_status" in self.request.GET: blocked_filter = self.request.GET.getlist("blocked_status", []) @@ -51,7 +51,7 @@ def get_queryset(self): if filter_option == "unblocked": blocked_filter_bools.append(False) - queryset = [member for member in queryset if member.blocked in blocked_filter_bools] + queryset = queryset.filter(blocked__in=blocked_filter_bools) return queryset def setup(self, request, *args, **kwargs): diff --git a/rocky/rocky/views/scans.py b/rocky/rocky/views/scans.py index 103fcf360af..b8036096db6 100644 --- a/rocky/rocky/views/scans.py +++ b/rocky/rocky/views/scans.py @@ -1,6 +1,7 @@ from logging import getLogger from account.mixins import OrganizationView +from django.utils.translation import gettext_lazy as _ from django.views.generic import TemplateView from katalogus.client import get_katalogus from tools.view_helpers import Breadcrumb, ObjectsBreadcrumbsMixin @@ -15,7 +16,7 @@ def build_breadcrumbs(self) -> list[Breadcrumb]: breadcrumbs = super().build_breadcrumbs() breadcrumbs.append( - {"url": "", "text": "Scans"}, + {"url": "", "text": _("Scans")}, ) return breadcrumbs diff --git a/rocky/tools/admin.py b/rocky/tools/admin.py index b3e9731bbb6..2fe34181233 100644 --- a/rocky/tools/admin.py +++ b/rocky/tools/admin.py @@ -5,7 +5,7 @@ from django.contrib import admin, messages from django.db.models import JSONField from django.forms import widgets -from django.http import HttpResponseRedirect +from django.http import HttpRequest, HttpResponseRedirect from rocky.exceptions import RockyError from tools.models import Indemnification, OOIInformation, Organization, OrganizationMember, OrganizationTag @@ -34,14 +34,16 @@ class OOIInformationAdmin(admin.ModelAdmin): formfield_overrides = {JSONField: {"widget": JSONInfoWidget}} # if pk is not readonly, it will create a new record upon editing - def get_readonly_fields(self, request, obj=None): + def get_readonly_fields( + self, request: HttpRequest, obj: OOIInformation | None = None + ) -> list[str] | tuple[str, ...]: if obj is not None: # editing an existing object if not obj.value: - return self.readonly_fields + ( + return tuple(self.readonly_fields) + ( "id", "consult_api", ) - return self.readonly_fields + ("id",) + return tuple(self.readonly_fields) + ("id",) return self.readonly_fields diff --git a/rocky/tools/forms/base.py b/rocky/tools/forms/base.py index 0491ea4bbbd..2a76f96336c 100644 --- a/rocky/tools/forms/base.py +++ b/rocky/tools/forms/base.py @@ -104,7 +104,7 @@ def __init__( super().__init__(*args, **kwargs) self.required_options = required_options or [] - def get_context(self, name: str, value: Any, attrs: dict) -> dict[str, Any] | None: + def get_context(self, name: str, value: Any, attrs: dict[str, Any] | None) -> dict[str, Any]: context = super().get_context(name, value, attrs) return context diff --git a/rocky/tools/forms/ooi_form.py b/rocky/tools/forms/ooi_form.py index 05f5834e502..ee7b12a0493 100644 --- a/rocky/tools/forms/ooi_form.py +++ b/rocky/tools/forms/ooi_form.py @@ -2,7 +2,7 @@ from enum import Enum from inspect import isclass from ipaddress import IPv4Address, IPv6Address -from typing import Any, Literal, Union, get_args, get_origin +from typing import Any, Literal, TypedDict, Union, get_args, get_origin from django import forms from django.utils.translation import gettext_lazy as _ @@ -39,7 +39,7 @@ def generate_form_fields( self, hidden_ooi_fields: dict[str, str] | None = None, ) -> dict[str, forms.fields.Field]: - fields = {} + fields: dict[str, forms.fields.Field] = {} for name, field in self.ooi_class.model_fields.items(): annotation = field.annotation default_attrs = default_field_options(name, field) @@ -156,7 +156,12 @@ def generate_url_field(field: FieldInfo) -> forms.fields.Field: return field -def default_field_options(name: str, field_info: FieldInfo) -> dict[str, str | bool]: +class DefaultFieldOptions(TypedDict): + label: str + required: bool + + +def default_field_options(name: str, field_info: FieldInfo) -> DefaultFieldOptions: return { "label": name, "required": field_info.is_required(), diff --git a/rocky/tools/forms/settings.py b/rocky/tools/forms/settings.py index 4774f022942..e9d5b684a3f 100644 --- a/rocky/tools/forms/settings.py +++ b/rocky/tools/forms/settings.py @@ -1,10 +1,9 @@ -from typing import Any - from django.utils.translation import gettext_lazy as _ +from django_stubs_ext import StrPromise from tools.enums import SCAN_LEVEL -Choice = tuple[Any, str] +Choice = tuple[str, StrPromise] Choices = list[Choice] ChoicesGroup = tuple[str, Choices] ChoicesGroups = list[ChoicesGroup] diff --git a/rocky/tools/models.py b/rocky/tools/models.py index b2dddc8cb65..4148ca3bc50 100644 --- a/rocky/tools/models.py +++ b/rocky/tools/models.py @@ -79,6 +79,7 @@ def css_class(self): class Organization(models.Model): + id: int name = models.CharField(max_length=126, unique=True, help_text=_("The name of the organisation")) code = LowerCaseSlugField( max_length=ORGANIZATION_CODE_LENGTH, diff --git a/rocky/tools/view_helpers.py b/rocky/tools/view_helpers.py index 06f241ce45d..4c215e05019 100644 --- a/rocky/tools/view_helpers.py +++ b/rocky/tools/view_helpers.py @@ -7,7 +7,7 @@ from django.http import HttpRequest from django.urls.base import reverse, reverse_lazy from django.utils.translation import gettext_lazy as _ -from django_stubs_ext import StrPromise +from django_stubs_ext import StrOrPromise from octopoes.models.types import OOI_TYPES from rocky.scheduler import PrioritizedItem, SchedulerError, client @@ -77,7 +77,7 @@ def existing_ooi_type(ooi_type: str) -> bool: class Breadcrumb(TypedDict): - text: StrPromise + text: StrOrPromise url: str From c3ff3b5fd6e8e986a65c02da00e31f2e7e776fe9 Mon Sep 17 00:00:00 2001 From: Jeroen Dekkers Date: Tue, 3 Dec 2024 22:44:20 +0100 Subject: [PATCH 3/9] Fix promise --- rocky/reports/report_types/definitions.py | 6 +++--- rocky/tools/forms/settings.py | 4 ++-- rocky/tools/view_helpers.py | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/rocky/reports/report_types/definitions.py b/rocky/reports/report_types/definitions.py index ccffb34f38d..05262d3f3c0 100644 --- a/rocky/reports/report_types/definitions.py +++ b/rocky/reports/report_types/definitions.py @@ -3,7 +3,7 @@ from pathlib import Path from typing import Any, TypedDict, TypeVar -from django_stubs_ext import StrPromise +from django.utils.functional import Promise from octopoes.connector.octopoes import OctopoesAPIConnector from octopoes.models import OOI, Reference @@ -39,8 +39,8 @@ def report_plugins_union(report_types: list[type["BaseReport"]]) -> ReportPlugin class BaseReport: id: str - name: StrPromise - description: StrPromise + name: Promise + description: Promise template_path: str = "report.html" plugins: ReportPlugins input_ooi_types: set[type[OOI]] diff --git a/rocky/tools/forms/settings.py b/rocky/tools/forms/settings.py index 52704ab73b2..88af5cf2f38 100644 --- a/rocky/tools/forms/settings.py +++ b/rocky/tools/forms/settings.py @@ -1,10 +1,10 @@ +from django.utils.functional import Promise from django.utils.safestring import mark_safe from django.utils.translation import gettext_lazy as _ -from django_stubs_ext import StrPromise from tools.enums import SCAN_LEVEL -Choice = tuple[str, StrPromise] +Choice = tuple[str, Promise] Choices = list[Choice] ChoicesGroup = tuple[str, Choices] ChoicesGroups = list[ChoicesGroup] diff --git a/rocky/tools/view_helpers.py b/rocky/tools/view_helpers.py index 03d32ef05a4..8bc300818dc 100644 --- a/rocky/tools/view_helpers.py +++ b/rocky/tools/view_helpers.py @@ -6,8 +6,8 @@ from django.http import HttpRequest from django.http.response import HttpResponseRedirectBase from django.urls.base import reverse, reverse_lazy +from django.utils.functional import Promise from django.utils.translation import gettext_lazy as _ -from django_stubs_ext import StrOrPromise from octopoes.models.types import OOI_TYPES from tools.models import Organization @@ -76,7 +76,7 @@ def existing_ooi_type(ooi_type: str) -> bool: class Breadcrumb(TypedDict): - text: StrOrPromise + text: str | Promise url: str From 16cc41156dcf4b54092f9058ae3583a8b7c9f361 Mon Sep 17 00:00:00 2001 From: Jeroen Dekkers Date: Tue, 3 Dec 2024 22:44:36 +0100 Subject: [PATCH 4/9] Fix typo --- .../report_types/aggregate_organisation_report/report.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rocky/reports/report_types/aggregate_organisation_report/report.py b/rocky/reports/report_types/aggregate_organisation_report/report.py index 0d0d3dfa197..de7f1fbc61f 100644 --- a/rocky/reports/report_types/aggregate_organisation_report/report.py +++ b/rocky/reports/report_types/aggregate_organisation_report/report.py @@ -179,7 +179,7 @@ def post_process_data(self, data: dict[str, Any], valid_time: datetime, organiza basic_security["system_specific"][SystemType.WEB] = [ report for ip in web_report_data for report in web_report_data[ip] ] - basic_security["syst_specific"][SystemType.DNS] = [ + basic_security["system_specific"][SystemType.DNS] = [ report for ip in dns_report_data for report in dns_report_data[ip] ] From deaa2778d954feac4693ad18d08ba2b0851e9742 Mon Sep 17 00:00:00 2001 From: Jeroen Dekkers Date: Tue, 3 Dec 2024 22:57:47 +0100 Subject: [PATCH 5/9] Make lang --- .../report_overview/report_history_table.html | 10 ++++---- rocky/rocky/locale/django.pot | 23 +++++++++++-------- 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/rocky/reports/templates/report_overview/report_history_table.html b/rocky/reports/templates/report_overview/report_history_table.html index 9452ad6cb24..cee41f0f203 100644 --- a/rocky/reports/templates/report_overview/report_history_table.html +++ b/rocky/reports/templates/report_overview/report_history_table.html @@ -154,11 +154,11 @@ {% translate "Subreports details:" %}
{% translate "Report types" %}

- {% blocktranslate count counter=report.total_children_reports %} - This report consist of {{counter}} subreport with the following report type and object. - {% plural %} - This report consist of {{counter}} subreports with the following report types and objects. - {% endblocktranslate %} + {% blocktranslate trimmed count counter=report.total_children_reports %} + This report consist of {{ counter }} subreport with the following report type and object. + {% plural %} + This report consist of {{ counter }} subreports with the following report types and objects. + {% endblocktranslate %}

diff --git a/rocky/rocky/locale/django.pot b/rocky/rocky/locale/django.pot index 962ed1d10cf..744f46adbde 100644 --- a/rocky/rocky/locale/django.pot +++ b/rocky/rocky/locale/django.pot @@ -9,7 +9,7 @@ msgid "" msgstr "" "Project-Id-Version: PACKAGE VERSION\n" "Report-Msgid-Bugs-To: \n" -"POT-Creation-Date: 2024-11-25 09:27+0000\n" +"POT-Creation-Date: 2024-12-03 21:56+0000\n" "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" "Last-Translator: FULL NAME \n" "Language-Team: LANGUAGE \n" @@ -2772,6 +2772,10 @@ msgstr "" msgid "No CVEs have been found." msgstr "" +#: reports/report_types/aggregate_organisation_report/report.py +msgid "Aggregate Organisation Report" +msgstr "" + #: reports/report_types/aggregate_organisation_report/report_design.html msgid "Observed at:" msgstr "" @@ -3809,6 +3813,7 @@ msgid "Add reference date" msgstr "" #: reports/templates/partials/report_names_form.html +#: reports/templates/report_overview/modal_partials/rename_modal.html #: rocky/views/scan_profile.py msgid "Reset" msgstr "" @@ -4261,15 +4266,11 @@ msgstr "" #: reports/templates/report_overview/report_history_table.html #, python-format msgid "" -"\n" -" This report consist of %(counter)s " -"subreport with the following report type and object.\n" -" " +"This report consist of %(counter)s subreport with the following report type " +"and object." msgid_plural "" -"\n" -" This report consist of %(counter)s " -"subreports with the following report types and objects.\n" -" " +"This report consist of %(counter)s subreports with the following report " +"types and objects." msgstr[0] "" msgstr[1] "" @@ -7350,6 +7351,10 @@ msgstr "" msgid "Can not reset scan level. Scan level of {ooi_name} not declared" msgstr "" +#: rocky/views/scans.py +msgid "Scans" +msgstr "" + #: rocky/views/scheduler.py msgid "Your report has been scheduled." msgstr "" From cae54dca479e33acdc243ddb0301242a97107f13 Mon Sep 17 00:00:00 2001 From: Jeroen Dekkers Date: Tue, 3 Dec 2024 23:04:23 +0100 Subject: [PATCH 6/9] Use future annotations --- octopoes/octopoes/models/ooi/findings.py | 4 +++- octopoes/octopoes/xtdb/query.py | 20 ++++++++++--------- .../report_types/name_server_report/report.py | 4 +++- .../report_types/web_system_report/report.py | 4 +++- 4 files changed, 20 insertions(+), 12 deletions(-) diff --git a/octopoes/octopoes/models/ooi/findings.py b/octopoes/octopoes/models/ooi/findings.py index c60c3803c3a..af82e2de756 100644 --- a/octopoes/octopoes/models/ooi/findings.py +++ b/octopoes/octopoes/models/ooi/findings.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from enum import Enum from functools import total_ordering from typing import Annotated, Literal @@ -24,7 +26,7 @@ class RiskLevelSeverity(Enum): # unknown = the third party has been contacted, but third party has not determined the risk level (yet) UNKNOWN = "unknown" - def __gt__(self, other: "RiskLevelSeverity") -> bool: + def __gt__(self, other: RiskLevelSeverity) -> bool: return severity_order.index(self.value) > severity_order.index(other.value) def __str__(self) -> str: diff --git a/octopoes/octopoes/xtdb/query.py b/octopoes/octopoes/xtdb/query.py index 8af202e6555..147c58174a0 100644 --- a/octopoes/octopoes/xtdb/query.py +++ b/octopoes/octopoes/xtdb/query.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from dataclasses import dataclass, field from uuid import UUID, uuid4 @@ -76,13 +78,13 @@ class Query: _offset: int | None = None _order_by: tuple[Aliased, bool] | None = None - def where(self, ooi_type: Ref, **kwargs: Ref | str | set[str] | bool) -> "Query": + def where(self, ooi_type: Ref, **kwargs: Ref | str | set[str] | bool) -> Query: for field_name, value in kwargs.items(): self._where_field_is(ooi_type, field_name, value) return self - def where_in(self, ooi_type: Ref, **kwargs: list[str]) -> "Query": + def where_in(self, ooi_type: Ref, **kwargs: list[str]) -> Query: """Allows for filtering on multiple values for a specific field.""" for field_name, values in kwargs.items(): @@ -94,7 +96,7 @@ def format(self) -> str: return self._compile(separator="\n ") @classmethod - def from_path(cls, path: Path) -> "Query": + def from_path(cls, path: Path) -> Query: """ Create a query from a Path. @@ -147,14 +149,14 @@ def from_path(cls, path: Path) -> "Query": return query - def pull(self, ooi_type: Ref, *, fields: str = "[*]") -> "Query": + def pull(self, ooi_type: Ref, *, fields: str = "[*]") -> Query: """By default, we pull the target type. But when using find, count, etc., you have to pull explicitly.""" self._find_clauses.append(f"(pull {self._get_object_alias(ooi_type)} {fields})") return self - def find(self, item: Ref, *, index: int | None = None) -> "Query": + def find(self, item: Ref, *, index: int | None = None) -> Query: """Add a find clause, so we can select specific fields in a query to be returned as well.""" if index is None: @@ -164,22 +166,22 @@ def find(self, item: Ref, *, index: int | None = None) -> "Query": return self - def count(self, ooi_type: Ref) -> "Query": + def count(self, ooi_type: Ref) -> Query: self._find_clauses.append(f"(count {self._get_object_alias(ooi_type)})") return self - def limit(self, limit: int) -> "Query": + def limit(self, limit: int) -> Query: self._limit = limit return self - def offset(self, offset: int) -> "Query": + def offset(self, offset: int) -> Query: self._offset = offset return self - def order_by(self, ref: Aliased, ascending: bool = True) -> "Query": + def order_by(self, ref: Aliased, ascending: bool = True) -> Query: self._order_by = (ref, ascending) return self diff --git a/rocky/reports/report_types/name_server_report/report.py b/rocky/reports/report_types/name_server_report/report.py index f62851e2dd8..cdec60e352c 100644 --- a/rocky/reports/report_types/name_server_report/report.py +++ b/rocky/reports/report_types/name_server_report/report.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from collections.abc import Iterable from dataclasses import dataclass, field from datetime import datetime @@ -43,7 +45,7 @@ def __bool__(self) -> bool: def __len__(self) -> int: return len(self.checks) - def __add__(self, other: "NameServerChecks") -> "NameServerChecks": + def __add__(self, other: NameServerChecks) -> NameServerChecks: return NameServerChecks(checks=self.checks + other.checks) diff --git a/rocky/reports/report_types/web_system_report/report.py b/rocky/reports/report_types/web_system_report/report.py index 6be0bed0db0..47f62c760c9 100644 --- a/rocky/reports/report_types/web_system_report/report.py +++ b/rocky/reports/report_types/web_system_report/report.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from collections.abc import Iterable from dataclasses import dataclass, field from datetime import datetime @@ -83,7 +85,7 @@ def __bool__(self) -> bool: def __len__(self) -> int: return len(self.checks) - def __add__(self, other: "WebChecks") -> "WebChecks": + def __add__(self, other: WebChecks) -> WebChecks: return WebChecks(checks=self.checks + other.checks) From a26b91d0fa61ddc691d95cea384e3878739f40f6 Mon Sep 17 00:00:00 2001 From: Jeroen Dekkers Date: Tue, 3 Dec 2024 23:10:22 +0100 Subject: [PATCH 7/9] No need for gettext_lazy --- rocky/rocky/views/scans.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rocky/rocky/views/scans.py b/rocky/rocky/views/scans.py index 82b1c0585ff..0ce17c061dc 100644 --- a/rocky/rocky/views/scans.py +++ b/rocky/rocky/views/scans.py @@ -1,5 +1,5 @@ from account.mixins import OrganizationView -from django.utils.translation import gettext_lazy as _ +from django.utils.translation import gettext as _ from django.views.generic import TemplateView from tools.view_helpers import Breadcrumb, ObjectsBreadcrumbsMixin From 61773b4458f719bb4290ce3b34be37cfb15bb422 Mon Sep 17 00:00:00 2001 From: Jeroen Dekkers Date: Wed, 4 Dec 2024 00:04:49 +0100 Subject: [PATCH 8/9] Fix missing import --- mula/scheduler/models/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mula/scheduler/models/__init__.py b/mula/scheduler/models/__init__.py index 8fc580db6ca..a5390ad6ede 100644 --- a/mula/scheduler/models/__init__.py +++ b/mula/scheduler/models/__init__.py @@ -1,9 +1,9 @@ from .base import Base from .boefje import Boefje, BoefjeMeta -from .events import RawDataReceivedEvent +from .events import RawData, RawDataReceivedEvent from .health import ServiceHealth from .normalizer import Normalizer -from .ooi import OOI, MutationOperationType, ScanProfileMutation +from .ooi import OOI, MutationOperationType, ScanProfile, ScanProfileMutation from .organisation import Organisation from .plugin import Plugin from .queue import Queue From 0bd87c932c4fa806bc06cde3ec4bcfcb3e04cb1c Mon Sep 17 00:00:00 2001 From: Jeroen Dekkers Date: Wed, 4 Dec 2024 00:31:08 +0100 Subject: [PATCH 9/9] Remove return typing from octopoes routers --- octopoes/octopoes/api/router.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/octopoes/octopoes/api/router.py b/octopoes/octopoes/api/router.py index f1759d5c368..5816d2a2d88 100644 --- a/octopoes/octopoes/api/router.py +++ b/octopoes/octopoes/api/router.py @@ -114,7 +114,7 @@ def list_objects( search_string: str | None = None, order_by: Literal["scan_level", "object_type"] = "object_type", asc_desc: Literal["asc", "desc"] = "asc", -) -> Paginated[OOI]: +): return octopoes.list_ooi( types, valid_time, offset, limit, scan_level, scan_profile_type, search_string, order_by, asc_desc ) @@ -128,7 +128,7 @@ def query( valid_time: datetime = Depends(extract_valid_time), offset: int = DEFAULT_OFFSET, limit: int = DEFAULT_LIMIT, -) -> list[OOI | tuple]: +): object_path = ObjectPath.parse(path) xtdb_query = XTDBQuery.from_path(object_path).offset(offset).limit(limit) @@ -144,7 +144,7 @@ def query_many( sources: list[str] = Query(), octopoes: OctopoesService = Depends(octopoes_service), valid_time: datetime = Depends(extract_valid_time), -) -> list[OOI | tuple]: +): """ How does this work and why do we do this? @@ -189,7 +189,7 @@ def load_objects_bulk( octopoes: OctopoesService = Depends(octopoes_service), valid_time: datetime = Depends(extract_valid_time), references: set[Reference] = Depends(extract_references), -) -> dict[str, OOI]: +): return octopoes.ooi_repository.load_bulk(references, valid_time) @@ -198,7 +198,7 @@ def get_object( octopoes: OctopoesService = Depends(octopoes_service), valid_time: datetime = Depends(extract_valid_time), reference: Reference = Depends(extract_reference), -) -> OOI: +): return octopoes.get_ooi(reference, valid_time) @@ -230,7 +230,7 @@ def list_random_objects( valid_time: datetime = Depends(extract_valid_time), amount: int = 1, scan_level: set[ScanLevel] = Query(DEFAULT_SCAN_LEVEL_FILTER), -) -> list[OOI]: +): return octopoes.list_random_ooi(valid_time, amount, scan_level) @@ -459,8 +459,8 @@ def list_findings( @router.get("/reports", tags=["Reports"]) def list_reports( - offset=DEFAULT_OFFSET, - limit=DEFAULT_LIMIT, + offset: int = DEFAULT_OFFSET, + limit: int = DEFAULT_LIMIT, octopoes: OctopoesService = Depends(octopoes_service), valid_time: datetime = Depends(extract_valid_time), ) -> Paginated[tuple[Report, list[Report | None]]]: