Skip to content

Commit

Permalink
fix test; enhance report
Browse files — browse the repository at this point in the history
  • Loading branch information
asikowitz committed Dec 21, 2024
1 parent b7c84b3 commit 9f15aa8
Show / hide file tree
Showing 2 changed files with 16 additions and 11 deletions.
12 changes: 7 additions & 5 deletions metadata-ingestion/src/datahub/ingestion/source/aws/glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
platform_name,
support_status,
)
from datahub.ingestion.api.report import EntityFilterReport
from datahub.ingestion.api.source import MetadataWorkUnitProcessor
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.aws import s3_util
Expand Down Expand Up @@ -219,6 +220,7 @@ def platform_validator(cls, v: str) -> str:
class GlueSourceReport(StaleEntityRemovalSourceReport):
tables_scanned = 0
filtered: List[str] = dataclass_field(default_factory=list)
databases = EntityFilterReport.field(type="database")

num_job_script_location_missing: int = 0
num_job_script_location_invalid: int = 0
Expand Down Expand Up @@ -684,15 +686,15 @@ def get_all_databases(self) -> Iterable[Mapping[str, Any]]:
pattern += "[?!TargetDatabase]"

for database in paginator_response.search(pattern):
if not self.source_config.database_pattern.allowed(database["Name"]):
continue
if (
if (not self.source_config.database_pattern.allowed(database["Name"])) or (
self.source_config.catalog_id
and database.get("CatalogId")
and database.get("CatalogId") != self.source_config.catalog_id
):
continue
yield database
self.report.databases.dropped(database["Name"])
else:
self.report.databases.processed(database["Name"])
yield database

def get_tables_from_database(self, database: Mapping[str, Any]) -> Iterable[Dict]:
logger.debug(f"Getting tables from database {database['Name']}")
Expand Down
15 changes: 9 additions & 6 deletions metadata-ingestion/tests/unit/glue/test_glue_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,26 +316,29 @@ def format_databases(databases):
return set(d["Name"] for d in databases)

all_catalogs_source: GlueSource = GlueSource(
config=GlueSourceConfig(), ctx=PipelineContext(run_id="glue-source-test")
config=GlueSourceConfig(aws_region="us-west-2"),
ctx=PipelineContext(run_id="glue-source-test"),
)
with Stubber(all_catalogs_source.glue_client) as glue_stubber:
glue_stubber.add_response("get_databases", get_databases_response, {})

expected = format_databases([flights_database, test_database, empty_database])
assert format_databases(all_catalogs_source.get_all_databases()) == expected
expected = [flights_database, test_database, empty_database]
actual = all_catalogs_source.get_all_databases()
assert format_databases(actual) == format_databases(expected)

catalog_id = "123412341234"
single_catalog_source = GlueSource(
config=GlueSourceConfig(catalog_id=catalog_id),
config=GlueSourceConfig(catalog_id=catalog_id, aws_region="us-west-2"),
ctx=PipelineContext(run_id="glue-source-test"),
)
with Stubber(single_catalog_source.glue_client) as glue_stubber:
glue_stubber.add_response(
"get_databases", get_databases_response, {"CatalogId": catalog_id}
)

expected = format_databases([flights_database, test_database])
assert format_databases(single_catalog_source.get_all_databases()) == expected
expected = [flights_database, test_database]
actual = single_catalog_source.get_all_databases()
assert format_databases(actual) == format_databases(expected)


@freeze_time(FROZEN_TIME)
Expand Down

0 comments on commit 9f15aa8

Please sign in to comment.