From 587f9993ddd6bb73edec0c6d257b7fe1825d833e Mon Sep 17 00:00:00 2001 From: Serhii Dimchenko Date: Tue, 14 Jan 2025 16:57:21 +0100 Subject: [PATCH] Fixed typing --- .../src/datahub/ingestion/source/aws/glue.py | 23 +++++++++++++------ 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/metadata-ingestion/src/datahub/ingestion/source/aws/glue.py b/metadata-ingestion/src/datahub/ingestion/source/aws/glue.py index e2a3f25639d18e..e14e875ab6a28a 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/aws/glue.py +++ b/metadata-ingestion/src/datahub/ingestion/source/aws/glue.py @@ -5,6 +5,7 @@ from dataclasses import dataclass, field as dataclass_field from functools import lru_cache from typing import ( + TYPE_CHECKING, Any, DefaultDict, Dict, @@ -114,6 +115,12 @@ from datahub.utilities.delta import delta_type_to_hive_type from datahub.utilities.hive_schema_to_avro import get_schema_fields_for_hive_column +if TYPE_CHECKING: + from mypy_boto3_glue.type_defs import ( + DatabasePaginatorTypeDef, + TablePaginatorTypeDef, + ) + logger = logging.getLogger(__name__) DEFAULT_PLATFORM = "glue" @@ -156,8 +163,8 @@ class GlueSourceConfig( default=None, description="The aws account id where the target glue catalog lives. If None, datahub will ingest glue in aws caller's account.", ) - catalog_name: str = Field( - default="awsdatacatalog", description="The aws athena catalog name" + catalog_name: Optional[str] = Field( + default=None, description="The aws athena catalog name" ) ignore_resource_links: Optional[bool] = Field( default=False, @@ -715,7 +722,7 @@ def get_datajob_wu(self, node: Dict[str, Any], job_name: str) -> MetadataWorkUni return MetadataWorkUnit(id=f'{job_name}-{node["Id"]}', mce=mce) - def get_all_databases(self) -> Iterable[Mapping[str, Any]]: + def get_all_databases(self) -> Iterable[DatabasePaginatorTypeDef]: logger.debug("Getting all databases") # see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/glue/paginator/GetDatabases.html paginator = self.glue_client.get_paginator("get_databases") @@ -743,7 +750,9 @@ def get_all_databases(self) -> Iterable[Mapping[str, Any]]: self.report.databases.processed(database["Name"]) yield database - def get_tables_from_database(self, database: Mapping[str, Any]) -> Iterable[Dict]: + def get_tables_from_database( + self, database: DatabasePaginatorTypeDef + ) -> Iterable[TablePaginatorTypeDef]: logger.debug(f"Getting tables from database {database['Name']}") # see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/glue/paginator/GetTables.html paginator = self.glue_client.get_paginator("get_tables") @@ -770,7 +779,7 @@ def get_tables_from_database(self, database: Mapping[str, Any]) -> Iterable[Dict def get_all_databases_and_tables( self, - ) -> Tuple[List[Mapping[str, Any]], List[Dict]]: + ) -> Tuple[List[DatabasePaginatorTypeDef], List[TablePaginatorTypeDef]]: all_databases = [*self.get_all_databases()] all_tables = [ tables @@ -1038,7 +1047,7 @@ def gen_database_key(self, database: str) -> DatabaseKey: ) def gen_database_containers( - self, database: Mapping[str, Any] + self, database: DatabasePaginatorTypeDef ) -> Iterable[MetadataWorkUnit]: domain_urn = self._gen_domain_urn(database["Name"]) database_container_key = self.gen_database_key(database["Name"]) @@ -1113,7 +1122,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: if self.extract_transforms: yield from self._transform_extraction() - def _gen_table_wu(self, table: Dict) -> Iterable[MetadataWorkUnit]: + def _gen_table_wu(self, table: TablePaginatorTypeDef) -> Iterable[MetadataWorkUnit]: database_name = table["DatabaseName"] table_name = table["Name"] full_table_name = (