From 9f15aa83823a45903c91847ae9dcae704cbf9650 Mon Sep 17 00:00:00 2001 From: Andrew Sikowitz Date: Sat, 21 Dec 2024 06:03:30 -0800 Subject: [PATCH] fix test; enhance report --- .../src/datahub/ingestion/source/aws/glue.py | 12 +++++++----- .../tests/unit/glue/test_glue_source.py | 15 +++++++++------ 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/metadata-ingestion/src/datahub/ingestion/source/aws/glue.py b/metadata-ingestion/src/datahub/ingestion/source/aws/glue.py index e67c56f814275a..cdcdd0221934bc 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/aws/glue.py +++ b/metadata-ingestion/src/datahub/ingestion/source/aws/glue.py @@ -52,6 +52,7 @@ platform_name, support_status, ) +from datahub.ingestion.api.report import EntityFilterReport from datahub.ingestion.api.source import MetadataWorkUnitProcessor from datahub.ingestion.api.workunit import MetadataWorkUnit from datahub.ingestion.source.aws import s3_util @@ -219,6 +220,7 @@ def platform_validator(cls, v: str) -> str: class GlueSourceReport(StaleEntityRemovalSourceReport): tables_scanned = 0 filtered: List[str] = dataclass_field(default_factory=list) + databases = EntityFilterReport.field(type="database") num_job_script_location_missing: int = 0 num_job_script_location_invalid: int = 0 @@ -684,15 +686,15 @@ def get_all_databases(self) -> Iterable[Mapping[str, Any]]: pattern += "[?!TargetDatabase]" for database in paginator_response.search(pattern): - if not self.source_config.database_pattern.allowed(database["Name"]): - continue - if ( + if (not self.source_config.database_pattern.allowed(database["Name"])) or ( self.source_config.catalog_id and database.get("CatalogId") and database.get("CatalogId") != self.source_config.catalog_id ): - continue - yield database + self.report.databases.dropped(database["Name"]) + else: + self.report.databases.processed(database["Name"]) + yield database def get_tables_from_database(self, database: Mapping[str, Any]) -> Iterable[Dict]: logger.debug(f"Getting tables from database {database['Name']}") diff --git a/metadata-ingestion/tests/unit/glue/test_glue_source.py b/metadata-ingestion/tests/unit/glue/test_glue_source.py index 29f4175a1331b0..aff169aaf60f36 100644 --- a/metadata-ingestion/tests/unit/glue/test_glue_source.py +++ b/metadata-ingestion/tests/unit/glue/test_glue_source.py @@ -316,17 +316,19 @@ def format_databases(databases): return set(d["Name"] for d in databases) all_catalogs_source: GlueSource = GlueSource( - config=GlueSourceConfig(), ctx=PipelineContext(run_id="glue-source-test") + config=GlueSourceConfig(aws_region="us-west-2"), + ctx=PipelineContext(run_id="glue-source-test"), ) with Stubber(all_catalogs_source.glue_client) as glue_stubber: glue_stubber.add_response("get_databases", get_databases_response, {}) - expected = format_databases([flights_database, test_database, empty_database]) - assert format_databases(all_catalogs_source.get_all_databases()) == expected + expected = [flights_database, test_database, empty_database] + actual = all_catalogs_source.get_all_databases() + assert format_databases(actual) == format_databases(expected) catalog_id = "123412341234" single_catalog_source = GlueSource( - config=GlueSourceConfig(catalog_id=catalog_id), + config=GlueSourceConfig(catalog_id=catalog_id, aws_region="us-west-2"), ctx=PipelineContext(run_id="glue-source-test"), ) with Stubber(single_catalog_source.glue_client) as glue_stubber: @@ -334,8 +336,9 @@ def format_databases(databases): "get_databases", get_databases_response, {"CatalogId": catalog_id} ) - expected = format_databases([flights_database, test_database]) - assert format_databases(single_catalog_source.get_all_databases()) == expected + expected = [flights_database, test_database] + actual = single_catalog_source.get_all_databases() + assert format_databases(actual) == format_databases(expected) @freeze_time(FROZEN_TIME)