From bcc86e66adb40a5f3f571756b67dc7e71af16ad9 Mon Sep 17 00:00:00 2001 From: Mars Lan Date: Mon, 14 Oct 2024 09:37:21 -0700 Subject: [PATCH] Refactor Unity Catalog to fetch catalog/schema/table metadata from System tables --- metaphor/unity_catalog/extractor.py | 405 ++++++----- metaphor/unity_catalog/models.py | 95 ++- metaphor/unity_catalog/queries.py | 633 ++++++++++++++++++ metaphor/unity_catalog/utils.py | 232 ++----- pyproject.toml | 2 +- .../describe_history.sql | 1 + .../show_table_properties.sql | 1 + .../list_catalog_tags.sql | 1 + .../test_list_catalogs/list_catalogs.sql | 8 + .../list_column_lineage.sql | 18 + .../test_list_query_logs/list_query_log.sql | 0 .../list_schema_tags.sql | 1 + .../test_list_schemas/list_schemas.sql | 9 + .../list_table_lineage.sql | 14 + .../test_list_tables/list_tables.sql | 109 +++ .../list_volume_files.sql | 1 + .../test_list_volumes/list_volumes.sql | 16 + tests/unity_catalog/test_models.py | 11 - tests/unity_catalog/test_queries.py | 556 +++++++++++++++ tests/unity_catalog/test_utils.py | 101 +-- 20 files changed, 1748 insertions(+), 466 deletions(-) create mode 100644 metaphor/unity_catalog/queries.py create mode 100644 tests/unity_catalog/snapshots/test_queries/test_get_last_refreshed_time/describe_history.sql create mode 100644 tests/unity_catalog/snapshots/test_queries/test_get_table_properties/show_table_properties.sql create mode 100644 tests/unity_catalog/snapshots/test_queries/test_list_catalog_tags/list_catalog_tags.sql create mode 100644 tests/unity_catalog/snapshots/test_queries/test_list_catalogs/list_catalogs.sql create mode 100644 tests/unity_catalog/snapshots/test_queries/test_list_column_lineage/list_column_lineage.sql rename tests/unity_catalog/snapshots/{test_utils => test_queries}/test_list_query_logs/list_query_log.sql (100%) create mode 100644 tests/unity_catalog/snapshots/test_queries/test_list_schema_tags/list_schema_tags.sql create mode 100644 tests/unity_catalog/snapshots/test_queries/test_list_schemas/list_schemas.sql create mode 100644 tests/unity_catalog/snapshots/test_queries/test_list_table_lineage/list_table_lineage.sql create mode 100644 tests/unity_catalog/snapshots/test_queries/test_list_tables/list_tables.sql create mode 100644 tests/unity_catalog/snapshots/test_queries/test_list_volume_files/list_volume_files.sql create mode 100644 tests/unity_catalog/snapshots/test_queries/test_list_volumes/list_volumes.sql create mode 100644 tests/unity_catalog/test_queries.py diff --git a/metaphor/unity_catalog/extractor.py b/metaphor/unity_catalog/extractor.py index 92a54a28..fccd2296 100644 --- a/metaphor/unity_catalog/extractor.py +++ b/metaphor/unity_catalog/extractor.py @@ -1,12 +1,10 @@ -import json import re import urllib.parse -from collections import defaultdict -from datetime import datetime -from typing import Collection, Dict, Generator, Iterator, List, Optional, Tuple +from typing import Collection, Dict, Generator, Iterator, List, Optional -from databricks.sdk.service.catalog import TableInfo, TableType, VolumeInfo +from databricks.sdk.service.catalog import TableType from databricks.sdk.service.iam import ServicePrincipal +from pydantic import BaseModel from metaphor.common.base_extractor import BaseExtractor from metaphor.common.entity_id import ( @@ -18,7 +16,6 @@ from metaphor.common.filter import DatasetFilter from metaphor.common.logger import get_logger, json_dump_to_debug_file from metaphor.common.models import to_dataset_statistics -from metaphor.common.utils import to_utc_datetime_from_timestamp from 
metaphor.models.crawler_run_metadata import Platform from metaphor.models.metadata_change_event import ( AssetPlatform, @@ -40,6 +37,7 @@ SQLSchema, SystemContact, SystemContacts, + SystemDescription, SystemTag, SystemTags, SystemTagSource, @@ -52,18 +50,35 @@ VolumeFile, ) from metaphor.unity_catalog.config import UnityCatalogRunConfig -from metaphor.unity_catalog.models import extract_schema_field_from_column_info -from metaphor.unity_catalog.utils import ( +from metaphor.unity_catalog.models import ( + CatalogInfo, + SchemaInfo, + TableInfo, + VolumeFileInfo, + VolumeInfo, +) +from metaphor.unity_catalog.queries import ( ColumnLineageMap, TableLineageMap, + list_catalog_tags, + list_catalogs, + list_column_lineage, + list_schema_tags, + list_schemas, + list_table_lineage, + list_tables, + list_volume_files, + list_volumes, +) +from metaphor.unity_catalog.utils import ( batch_get_last_refreshed_time, + batch_get_table_properties, + build_schema_field_from_column_info, create_api, create_connection, create_connection_pool, get_query_logs, - list_column_lineage, list_service_principals, - list_table_lineage, ) logger = get_logger() @@ -114,22 +129,9 @@ URL_TABLE_RE = re.compile(r"{table}") -CatalogSystemTagsTuple = Tuple[List[SystemTag], Dict[str, List[SystemTag]]] -""" -(catalog system tags, schema name -> schema system tags) -""" - - -CatalogSystemTags = Dict[str, CatalogSystemTagsTuple] -""" -catalog name -> (catalog tags, schema name -> schema tags) -""" - - -def to_utc_from_timestamp_ms(timestamp_ms: Optional[int]): - if timestamp_ms is not None: - return to_utc_datetime_from_timestamp(timestamp_ms / 1000) - return None +class CatalogSystemTags(BaseModel): + catalog_tags: List[SystemTag] = [] + schema_name_to_tags: Dict[str, List[SystemTag]] = {} class UnityCatalogExtractor(BaseExtractor): @@ -168,6 +170,7 @@ def __init__(self, config: UnityCatalogRunConfig): self._service_principals: Dict[str, ServicePrincipal] = {} self._last_refresh_time_queue: List[str] = [] + self._table_properties_queue: List[str] = [] # Map fullname or volume path to a dataset self._datasets: Dict[str, Dataset] = {} @@ -182,38 +185,37 @@ async def extract(self) -> Collection[ENTITY_TYPES]: self._service_principals = list_service_principals(self._api) logger.info(f"Found service principals: {self._service_principals}") - catalogs = [ - catalog - for catalog in self._get_catalogs() - if self._filter.include_database(catalog) - ] + catalogs = list_catalogs(self._connection) + for catalog_info in catalogs: + catalog = catalog_info.catalog_name + if not self._filter.include_database(catalog): + logger.info(f"Ignore catalog {catalog} due to filter config") + continue - logger.info(f"Found catalogs: {catalogs}") + self._init_catalog(catalog_info) - for catalog in catalogs: - schemas = self._get_schemas(catalog) - for schema in schemas: + for schema_info in list_schemas(self._connection, catalog): + schema = schema_info.schema_name if not self._filter.include_schema(catalog, schema): logger.info( - f"Ignore schema: {catalog}.{schema} due to filter config" + f"Ignore schema {catalog}.{schema} due to filter config" ) continue + self._init_schema(schema_info) + table_lineage = list_table_lineage(self._connection, catalog, schema) column_lineage = list_column_lineage(self._connection, catalog, schema) - for volume in self._get_volume_infos(catalog, schema): - assert volume.full_name - self._volumes[volume.full_name] = volume - self._init_volume(volume) - self._extract_volume_files(volume) + for volume_info in 
list_volumes(self._connection, catalog, schema): + self._volumes[volume_info.full_name] = volume_info + self._init_volume(volume_info) + self._extract_volume_files(volume_info) - for table_info in self._get_table_infos(catalog, schema): - table_name = f"{catalog}.{schema}.{table_info.name}" - if table_info.name is None: - logger.error(f"Ignoring table without name: {table_info}") - continue - if not self._filter.include_table(catalog, schema, table_info.name): + for table_info in list_tables(self._connection, catalog, schema): + table = table_info.table_name + table_name = f"{catalog}.{schema}.{table}" + if not self._filter.include_table(catalog, schema, table): logger.info(f"Ignore table: {table_name} due to filter config") continue @@ -314,11 +316,10 @@ def _get_source_url( return url def _init_dataset(self, table_info: TableInfo) -> Dataset: - assert table_info.catalog_name and table_info.schema_name and table_info.name - table_name = table_info.name + table_name = table_info.table_name schema_name = table_info.schema_name database = table_info.catalog_name - table_type = table_info.table_type + table_type = table_info.type normalized_name = dataset_normalized_name(database, schema_name, table_name) @@ -332,14 +333,14 @@ def _init_dataset(self, table_info: TableInfo) -> Dataset: ) if table_type is None: - raise ValueError(f"Invalid table {table_info.name}, no table_type found") + raise ValueError( + f"Invalid table {table_info.table_name}, no table_type found" + ) - fields = [] - if table_info.columns is not None: - fields = [ - extract_schema_field_from_column_info(column_info) - for column_info in table_info.columns - ] + fields = [ + build_schema_field_from_column_info(column_info) + for column_info in table_info.columns + ] dataset.schema = DatasetSchema( schema_type=SchemaType.SQL, @@ -349,54 +350,45 @@ def _init_dataset(self, table_info: TableInfo) -> Dataset: materialization=TABLE_TYPE_MATERIALIZATION_TYPE_MAP.get( table_type, MaterializationType.TABLE ), - table_schema=( - table_info.view_definition if table_info.view_definition else None - ), + table_schema=table_info.view_definition, ), ) - if table_info.table_type in TABLE_TYPE_WITH_HISTORY: + # Queue tables with history for batch query later + if table_type in TABLE_TYPE_WITH_HISTORY: self._last_refresh_time_queue.append(normalized_name) main_url = self._get_table_source_url(database, schema_name, table_name) dataset.source_info = SourceInfo( main_url=main_url, - created_at_source=to_utc_from_timestamp_ms( - timestamp_ms=table_info.created_at - ), + created_at_source=table_info.created_at, + created_by=table_info.created_by, + last_updated=table_info.updated_at, + updated_by=table_info.updated_by, ) dataset.unity_catalog = UnityCatalog( dataset_type=UnityCatalogDatasetType.UNITY_CATALOG_TABLE, table_info=UnityCatalogTableInfo( type=TABLE_TYPE_MAP.get(table_type, UnityCatalogTableType.UNKNOWN), - data_source_format=( - table_info.data_source_format.value - if table_info.data_source_format is not None - else None - ), + data_source_format=table_info.data_source_format, storage_location=table_info.storage_location, owner=table_info.owner, - properties=( - [ - KeyValuePair(key=k, value=json.dumps(v)) - for k, v in table_info.properties.items() - ] - if table_info.properties is not None - else [] - ), ), ) - if table_info.owner is not None: - dataset.system_contacts = SystemContacts( - contacts=[ - SystemContact( - email=self._get_owner_display_name(table_info.owner), - system_contact_source=AssetPlatform.UNITY_CATALOG, - ) - ] - 
) + # Queue non-view tables for batch query later + if table_info.view_definition is None: + self._table_properties_queue.append(table_name) + + dataset.system_contacts = SystemContacts( + contacts=[ + SystemContact( + email=self._get_owner_display_name(table_info.owner), + system_contact_source=AssetPlatform.UNITY_CATALOG, + ) + ] + ) dataset.system_tags = SystemTags(tags=[]) @@ -499,9 +491,10 @@ def _init_hierarchy( self, catalog: str, schema: Optional[str] = None, + owner: Optional[str] = None, + comment: Optional[str] = None, ) -> Hierarchy: path = [part.lower() for part in [catalog, schema] if part] - return self._hierarchies.setdefault( ".".join(path), Hierarchy( @@ -511,43 +504,94 @@ def _init_hierarchy( ), ) - def _extract_hierarchies(self, catalog_system_tags: CatalogSystemTags) -> None: - for catalog, (catalog_tags, schema_name_to_tag) in catalog_system_tags.items(): + def _init_catalog( + self, + catalog_info: CatalogInfo, + ) -> Hierarchy: + hierarchy = self._init_hierarchy(catalog_info.catalog_name) + + hierarchy.contacts = [ + SystemContact( + email=self._get_owner_display_name(catalog_info.owner), + system_contact_source=AssetPlatform.UNITY_CATALOG, + ) + ] + + if catalog_info.comment: + hierarchy.system_descriptions = [ + SystemDescription( + description=catalog_info.comment, + platform=AssetPlatform.UNITY_CATALOG, + ) + ] + + return hierarchy + + def _init_schema( + self, + schema_info: SchemaInfo, + ) -> Hierarchy: + hierarchy = self._init_hierarchy( + schema_info.catalog_name, schema_info.schema_name + ) + + hierarchy.contacts = [ + SystemContact( + email=self._get_owner_display_name(schema_info.owner), + system_contact_source=AssetPlatform.UNITY_CATALOG, + ) + ] + + if schema_info.comment: + hierarchy.system_descriptions = [ + SystemDescription( + description=schema_info.comment, + platform=AssetPlatform.UNITY_CATALOG, + ) + ] + + return hierarchy + + def _extract_hierarchies( + self, catalog_to_tags: Dict[str, CatalogSystemTags] + ) -> None: + for catalog, catalog_system_tags in catalog_to_tags.items(): + catalog_tags = catalog_system_tags.catalog_tags + schema_name_to_tags = catalog_system_tags.schema_name_to_tags if catalog_tags: hierarchy = self._init_hierarchy(catalog) hierarchy.system_tags = SystemTags(tags=catalog_tags) - for schema, schema_tags in schema_name_to_tag.items(): + for schema, schema_tags in schema_name_to_tags.items(): if schema_tags: hierarchy = self._init_hierarchy(catalog, schema) hierarchy.system_tags = SystemTags(tags=schema_tags) - def _fetch_catalog_system_tags(self, catalog: str) -> CatalogSystemTagsTuple: + def _fetch_catalog_system_tags(self, catalog: str) -> CatalogSystemTags: logger.info(f"Fetching tags for catalog {catalog}") - with self._connection.cursor() as cursor: - catalog_tags = [] - schema_tags: Dict[str, List[SystemTag]] = defaultdict(list) - catalog_tags_query = f"SELECT tag_name, tag_value FROM {catalog}.information_schema.catalog_tags" - cursor.execute(catalog_tags_query) - for tag_name, tag_value in cursor.fetchall(): - tag = SystemTag( - key=tag_name, - value=tag_value, + catalog_tags = [ + SystemTag( + key=tag.key, + value=tag.value, + system_tag_source=SystemTagSource.UNITY_CATALOG, + ) + for tag in list_catalog_tags(self._connection, catalog) + ] + + schema_name_to_tags = {} + for schema_name, tags in list_schema_tags(self._connection, catalog).items(): + schema_name_to_tags[schema_name] = [ + SystemTag( + key=tag.key, + value=tag.value, system_tag_source=SystemTagSource.UNITY_CATALOG, ) - catalog_tags.append(tag) + for 
tag in tags + ] - schema_tags_query = f"SELECT schema_name, tag_name, tag_value FROM {catalog}.information_schema.schema_tags" - cursor.execute(schema_tags_query) - for schema_name, tag_name, tag_value in cursor.fetchall(): - if self._filter.include_schema(catalog, schema_name): - tag = SystemTag( - key=tag_name, - value=tag_value, - system_tag_source=SystemTagSource.UNITY_CATALOG, - ) - schema_tags[schema_name].append(tag) - return catalog_tags, schema_tags + return CatalogSystemTags( + catalog_tags=catalog_tags, schema_name_to_tags=schema_name_to_tags + ) def _assign_dataset_system_tags( self, catalog: str, catalog_system_tags: CatalogSystemTags @@ -562,8 +606,10 @@ def _assign_dataset_system_tags( if dataset is not None: assert dataset.system_tags dataset.system_tags.tags = ( - catalog_system_tags[catalog][0] - + catalog_system_tags[catalog][1][schema.name] + catalog_system_tags.catalog_tags + + catalog_system_tags.schema_name_to_tags.get( + schema.name, [] + ) ) def _extract_object_tags( @@ -662,9 +708,8 @@ def _extract_column_tags(self, catalog: str) -> None: tag = f"{tag_name}={tag_value}" if tag_value else tag_name - assert ( - dataset.schema is not None - ) # Can't be None, we initialized it at `init_dataset` + # Can't be None, we initialized it at `init_dataset` + assert dataset.schema is not None if dataset.schema.fields: field = next( ( @@ -679,74 +724,56 @@ def _extract_column_tags(self, catalog: str) -> None: field.tags = [] field.tags.append(tag) - def _fetch_tags(self, catalogs: List[str]): - catalog_system_tags: CatalogSystemTags = {} + def _fetch_tags(self, catalogs: List[CatalogInfo]): + catalog_to_tags: Dict[str, CatalogSystemTags] = {} - for catalog in catalogs: + for catalog_info in catalogs: + catalog = catalog_info.catalog_name if self._filter.include_database(catalog): - catalog_system_tags[catalog] = self._fetch_catalog_system_tags(catalog) - self._extract_hierarchies(catalog_system_tags) - self._assign_dataset_system_tags(catalog, catalog_system_tags) + catalog_to_tags[catalog] = self._fetch_catalog_system_tags(catalog) + self._extract_hierarchies(catalog_to_tags) + self._assign_dataset_system_tags(catalog, catalog_to_tags[catalog]) self._extract_table_tags(catalog) self._extract_volume_tags(catalog) self._extract_column_tags(catalog) def _extract_volume_files(self, volume: VolumeInfo): + catalog_name = volume.catalog_name + schema_name = volume.schema_name + volume_name = volume.volume_name + volume_dataset = self._datasets.get( - dataset_normalized_name( - volume.catalog_name, volume.schema_name, volume.name - ) + dataset_normalized_name(catalog_name, schema_name, volume_name) ) - if not volume_dataset: + if volume_dataset is None: return - with self._connection.cursor() as cursor: - query = f"LIST '/Volumes/{volume.catalog_name}/{volume.schema_name}/{volume.name}'" + volume_entity_id = str( + to_dataset_entity_id_from_logical_id(volume_dataset.logical_id) + ) - cursor.execute(query) - for path, name, size, modification_time in cursor.fetchall(): - last_updated = to_utc_from_timestamp_ms(timestamp_ms=modification_time) - volume_file = self._init_volume_file( - path, - size, - last_updated, + for volume_file_info in list_volume_files(self._connection, volume): + volume_file = self._init_volume_file(volume_file_info, volume_entity_id) + assert volume_dataset.unity_catalog.volume_info.volume_files is not None + + volume_dataset.unity_catalog.volume_info.volume_files.append( + VolumeFile( + modification_time=volume_file_info.last_updated, + 
name=volume_file_info.name, + path=volume_file_info.path, + size=volume_file_info.size, entity_id=str( - to_dataset_entity_id_from_logical_id(volume_dataset.logical_id) + to_dataset_entity_id_from_logical_id(volume_file.logical_id) ), ) - - if volume_dataset and volume_file: - assert ( - volume_dataset.unity_catalog.volume_info.volume_files - is not None - ) - volume_dataset.unity_catalog.volume_info.volume_files.append( - VolumeFile( - modification_time=last_updated, - name=name, - path=path, - size=float(size), - entity_id=str( - to_dataset_entity_id_from_logical_id( - volume_file.logical_id - ) - ), - ) - ) + ) def _init_volume(self, volume: VolumeInfo): - assert ( - volume.volume_type - and volume.schema_name - and volume.catalog_name - and volume.name - ) - schema_name = volume.schema_name catalog_name = volume.catalog_name - name = volume.name - full_name = dataset_normalized_name(catalog_name, schema_name, name) + volume_name = volume.volume_name + full_name = dataset_normalized_name(catalog_name, schema_name, volume_name) dataset = Dataset() dataset.logical_id = DatasetLogicalID( @@ -755,15 +782,15 @@ def _init_volume(self, volume: VolumeInfo): ) dataset.structure = DatasetStructure( - database=catalog_name, schema=schema_name, table=volume.name + database=catalog_name, schema=schema_name, table=volume_name ) - main_url = self._get_volume_source_url(catalog_name, schema_name, name) + main_url = self._get_volume_source_url(catalog_name, schema_name, volume_name) dataset.source_info = SourceInfo( main_url=main_url, - last_updated=to_utc_from_timestamp_ms(timestamp_ms=volume.updated_at), - created_at_source=to_utc_from_timestamp_ms(timestamp_ms=volume.created_at), + created_at_source=volume.created_at, created_by=volume.created_by, + last_updated=volume.updated_at, updated_by=volume.updated_by, ) @@ -784,7 +811,7 @@ def _init_volume(self, volume: VolumeInfo): dataset.unity_catalog = UnityCatalog( dataset_type=UnityCatalogDatasetType.UNITY_CATALOG_VOLUME, volume_info=UnityCatalogVolumeInfo( - type=UnityCatalogVolumeType[volume.volume_type.value], + type=UnityCatalogVolumeType[volume.volume_type], volume_files=[], storage_location=volume.storage_location, ), @@ -799,33 +826,30 @@ def _init_volume(self, volume: VolumeInfo): def _init_volume_file( self, - path: str, - size: int, - last_updated: Optional[datetime], - entity_id: str, + volume_file_info: VolumeFileInfo, + volumn_entity_id: str, ) -> Optional[Dataset]: dataset = Dataset() dataset.logical_id = DatasetLogicalID( # We use path as ID for file - name=path, + name=volume_file_info.path, platform=DataPlatform.UNITY_CATALOG_VOLUME_FILE, ) - if last_updated: - dataset.source_info = SourceInfo(last_updated=last_updated) + dataset.source_info = SourceInfo(last_updated=volume_file_info.last_updated) dataset.unity_catalog = UnityCatalog( dataset_type=UnityCatalogDatasetType.UNITY_CATALOG_VOLUME_FILE, - volume_entity_id=entity_id, + volume_entity_id=volumn_entity_id, ) dataset.statistics = to_dataset_statistics( - size_bytes=size, + size_bytes=volume_file_info.size, ) dataset.entity_upstream = EntityUpstream(source_entities=[]) - self._datasets[path] = dataset + self._datasets[volume_file_info.path] = dataset return dataset @@ -849,6 +873,27 @@ def _populate_last_refreshed_time(self): last_updated=last_refreshed_time, ) + def _populate_table_properties(self): + connection_pool = create_connection_pool( + self._token, self._hostname, self._http_path, self._max_concurrency + ) + + result_map = batch_get_table_properties( + connection_pool, + 
self._table_properties_queue, + ) + + for name, properties in result_map.items(): + dataset = self._datasets.get(name) + assert ( + dataset is not None + and dataset.unity_catalog is not None + and dataset.unity_catalog.table_info is not None + ) + dataset.unity_catalog.table_info.properties = [ + KeyValuePair(key=k, value=v) for k, v in properties.items() + ] + def _get_owner_display_name(self, user_id: str) -> str: # Unity Catalog returns service principal's application_id and must be # manually map back to display_name diff --git a/metaphor/unity_catalog/models.py b/metaphor/unity_catalog/models.py index 38a721e0..e45dec67 100644 --- a/metaphor/unity_catalog/models.py +++ b/metaphor/unity_catalog/models.py @@ -1,29 +1,8 @@ -from typing import Dict, List +from datetime import datetime +from typing import Dict, List, Literal, Optional -from databricks.sdk.service.catalog import ColumnInfo from pydantic import BaseModel -from metaphor.common.fieldpath import build_schema_field -from metaphor.common.logger import get_logger -from metaphor.models.metadata_change_event import SchemaField - -logger = get_logger() - - -def extract_schema_field_from_column_info(column: ColumnInfo) -> SchemaField: - if column.name is None or column.type_name is None: - raise ValueError(f"Invalid column {column.name}, no type_name found") - - field = build_schema_field( - column.name, column.type_name.value.lower(), column.comment - ) - field.precision = ( - float(column.type_precision) - if column.type_precision is not None - else float("nan") - ) - return field - class TableLineage(BaseModel): upstream_tables: List[str] = [] @@ -36,3 +15,73 @@ class Column(BaseModel): class ColumnLineage(BaseModel): upstream_columns: Dict[str, List[Column]] = {} + + +class CatalogInfo(BaseModel): + catalog_name: str + owner: str + comment: Optional[str] = None + + +class SchemaInfo(BaseModel): + catalog_name: str + schema_name: str + owner: str + comment: Optional[str] = None + + +class ColumnInfo(BaseModel): + column_name: str + data_type: str + data_precision: Optional[int] + is_nullable: bool + comment: Optional[str] = None + + +class TableInfo(BaseModel): + catalog_name: str + schema_name: str + table_name: str + type: str + owner: str + comment: Optional[str] = None + created_at: datetime + created_by: str + updated_at: datetime + updated_by: str + view_definition: Optional[str] = None + storage_location: Optional[str] = None + data_source_format: str + columns: List[ColumnInfo] = [] + + +class VolumeInfo(BaseModel): + catalog_name: str + schema_name: str + volume_name: str + volume_type: Literal["MANAGED", "EXTERNAL"] + full_name: str + owner: str + comment: Optional[str] = None + created_at: datetime + created_by: str + updated_at: datetime + updated_by: str + storage_location: str + + +class VolumeFileInfo(BaseModel): + last_updated: datetime + name: str + path: str + size: float + + +class CatalogTag(BaseModel): + key: str + value: str + + +class SchemaTag(BaseModel): + key: str + value: str diff --git a/metaphor/unity_catalog/queries.py b/metaphor/unity_catalog/queries.py new file mode 100644 index 00000000..82a0aad3 --- /dev/null +++ b/metaphor/unity_catalog/queries.py @@ -0,0 +1,633 @@ +from datetime import datetime +from typing import Collection, Dict, List, Optional, Tuple + +from databricks.sql.client import Connection + +from metaphor.common.logger import get_logger, json_dump_to_debug_file +from metaphor.common.utils import start_of_day +from metaphor.unity_catalog.models import ( + CatalogInfo, + CatalogTag, + 
Column, + ColumnInfo, + ColumnLineage, + SchemaInfo, + SchemaTag, + TableInfo, + TableLineage, + VolumeFileInfo, + VolumeInfo, +) + +logger = get_logger() + + +TableLineageMap = Dict[str, TableLineage] +"""Map a table's full name to its table lineage""" + +ColumnLineageMap = Dict[str, ColumnLineage] +"""Map a table's full name to its column lineage""" + +IGNORED_HISTORY_OPERATIONS = { + "ADD CONSTRAINT", + "CHANGE COLUMN", + "LIQUID TAGGING", + "OPTIMIZE", + "SET TBLPROPERTIES", +} +"""These are the operations that do not modify actual data.""" + + +def list_catalogs(connection: Connection) -> List[CatalogInfo]: + """ + Fetch catalogs from system.access.information_schema.catalogs + See https://docs.databricks.com/en/sql/language-manual/sql-ref-information-schema.html for more details + """ + catalogs: List[CatalogInfo] = [] + + with connection.cursor() as cursor: + query = """ + SELECT + catalog_name, + catalog_owner, + comment + FROM system.information_schema.catalogs + WHERE catalog_name <> 'SYSTEM' + """ + + try: + cursor.execute(query) + except Exception as error: + logger.exception(f"Failed to list catalogs: {error}") + return [] + + for row in cursor.fetchall(): + catalogs.append( + CatalogInfo( + catalog_name=row["catalog_name"], + owner=row["catalog_owner"], + comment=row["comment"], + ) + ) + + logger.info(f"Found {len(catalogs)} catalogs") + json_dump_to_debug_file(catalogs, "list_catalogs.json") + return catalogs + + +def list_schemas(connection: Connection, catalog: str) -> List[SchemaInfo]: + """ + Fetch schemas for a specific catalog from system.access.information_schema.schemata + See https://docs.databricks.com/en/sql/language-manual/sql-ref-information-schema.html for more details + """ + schemas: List[SchemaInfo] = [] + + with connection.cursor() as cursor: + query = f""" + SELECT + catalog_name, + schema_name, + schema_owner, + comment + FROM system.information_schema.schemata + WHERE catalog_name = '{catalog}' AND schema_name <> 'INFORMATION_SCHEMA' + """ + + try: + cursor.execute(query) + except Exception as error: + logger.exception(f"Failed to list schemas for {catalog}: {error}") + return [] + + for row in cursor.fetchall(): + schemas.append( + SchemaInfo( + catalog_name=row["catalog_name"], + schema_name=row["schema_name"], + owner=row["schema_owner"], + comment=row["comment"], + ) + ) + + logger.info(f"Found {len(schemas)} schemas from {catalog}") + json_dump_to_debug_file(schemas, f"list_schemas_{catalog}.json") + return schemas + + +def list_tables(connection: Connection, catalog: str, schema: str) -> List[TableInfo]: + """ + Fetch schemas for a specific catalog from system.access.information_schema.schemata + See https://docs.databricks.com/en/sql/language-manual/sql-ref-information-schema.html for more details + """ + table_map: Dict[str, TableInfo] = {} + + with connection.cursor() as cursor: + query = f""" + WITH + t AS ( + SELECT + table_catalog, + table_schema, + table_name, + table_type, + table_owner, + comment, + data_source_format, + storage_path, + created, + created_by, + last_altered, + last_altered_by + FROM system.information_schema.tables + WHERE + table_catalog = '{catalog}' AND + table_schema <> 'INFORMATION_SCHEMA' AND + table_schema IS NOT NULL AND + table_name IS NOT NULL + ), + + v AS ( + SELECT + table_catalog, + table_schema, + table_name, + view_definition + FROM system.information_schema.views + WHERE + table_catalog = '{catalog}' AND + table_schema <> 'INFORMATION_SCHEMA' AND + table_schema IS NOT NULL AND + table_name IS NOT NULL + 
), + + tv AS ( + SELECT + t.table_catalog, + t.table_schema, + t.table_name, + t.table_type, + t.table_owner, + t.comment, + t.data_source_format, + t.storage_path, + t.created, + t.created_by, + t.last_altered, + t.last_altered_by, + v.view_definition + FROM t + LEFT JOIN v + ON + t.table_catalog = v.table_catalog AND + t.table_schema = v.table_schema AND + t.table_name = v.table_name + ), + + c AS ( + SELECT + table_catalog, + table_schema, + table_name, + column_name, + full_data_type as data_type, + CASE + WHEN numeric_precision IS NOT NULL THEN numeric_precision + WHEN datetime_precision IS NOT NULL THEN datetime_precision + ELSE NULL + END AS data_precision, + is_nullable, + comment + FROM system.information_schema.columns + WHERE + table_catalog = '{catalog}' AND + table_schema <> 'INFORMATION_SCHEMA' AND + table_schema IS NOT NULL AND + table_name IS NOT NULL + ) + + SELECT + tv.table_catalog as catalog_name, + tv.table_schema as schema_name, + tv.table_name as table_name, + tv.table_type as table_type, + tv.table_owner as owner, + tv.comment as table_comment, + tv.data_source_format as data_source_format, + tv.storage_path as storage_path, + tv.created as created_at, + tv.created_by as created_by, + tv.last_altered as updated_at, + tv.last_altered_by as updated_by, + tv.view_definition as view_definition, + c.column_name as column_name, + c.data_type as data_type, + c.data_precision as data_precision, + c.is_nullable as is_nullable, + c.comment as column_comment + FROM tv + LEFT JOIN c + ON + tv.table_catalog = c.table_catalog AND + tv.table_schema = c.table_schema AND + tv.table_name = c.table_name + """ + + try: + cursor.execute(query) + except Exception as error: + logger.exception(f"Failed to list tables for {catalog}.{schema}: {error}") + return [] + + for row in cursor.fetchall(): + full_name = ( + f"{row['catalog_name']}.{row['schema_name']}.{row['table_name']}" + ) + table = table_map.setdefault( + full_name, + TableInfo( + catalog_name=row["catalog_name"], + schema_name=row["schema_name"], + table_name=row["table_name"], + type=row["table_type"], + owner=row["owner"], + comment=row["table_comment"], + created_at=row["created_at"], + created_by=row["created_by"], + updated_at=row["updated_at"], + updated_by=row["updated_by"], + data_source_format=row["data_source_format"], + view_definition=row["view_definition"], + storage_location=row["storage_path"], + columns=[], + ), + ) + + table.columns.append( + ColumnInfo( + column_name=row["column_name"], + data_type=row["data_type"], + data_precision=row["data_precision"], + is_nullable=row["is_nullable"] == "YES", + comment=row["column_comment"], + ) + ) + + tables = list(table_map.values()) + logger.info(f"Found {len(tables)} tables from {catalog}") + json_dump_to_debug_file(tables, f"list_tables_{catalog}_{schema}.json") + return tables + + +def list_volumes(connection: Connection, catalog: str, schema: str) -> List[VolumeInfo]: + """ + Fetch volumes for a specific catalog from system.access.information_schema.volumes + See https://docs.databricks.com/en/sql/language-manual/sql-ref-information-schema.html for more details + """ + volumes: List[VolumeInfo] = [] + + with connection.cursor() as cursor: + query = f""" + SELECT + volume_catalog, + volume_schema, + volume_name, + volume_type, + volume_owner, + comment, + created, + created_by, + last_altered, + last_altered_by, + storage_location + FROM system.information_schema.volumes + WHERE volume_catalog = '{catalog}' AND volume_schema = '{schema}' + """ + + try: + 
cursor.execute(query) + except Exception as error: + logger.exception(f"Failed to list volumes for {catalog}.{schema}: {error}") + return [] + + for row in cursor.fetchall(): + volumes.append( + VolumeInfo( + catalog_name=row["volume_catalog"], + schema_name=row["volume_schema"], + volume_name=row["volume_name"], + full_name=f"{row['volume_catalog']}.{row['volume_schema']}.{row['volume_name']}".lower(), + volume_type=row["volume_type"], + owner=row["volume_owner"], + comment=row["comment"], + created_at=row["created"], + created_by=row["created_by"], + updated_at=row["last_altered"], + updated_by=row["last_altered_by"], + storage_location=row["storage_location"], + ) + ) + + logger.info(f"Found {len(volumes)} volumes from {catalog}.{schema}") + json_dump_to_debug_file(volumes, f"list_volumes_{catalog}_{schema}.json") + return volumes + + +def list_volume_files( + connection: Connection, volume_info: VolumeInfo +) -> List[VolumeFileInfo]: + """ + List files in a volume + See https://docs.databricks.com/en/sql/language-manual/sql-ref-syntax-aux-list.html + """ + + catalog_name = volume_info.catalog_name + schema_name = volume_info.schema_name + volume_name = volume_info.volume_name + + volume_files: List[VolumeFileInfo] = [] + + with connection.cursor() as cursor: + query = f"LIST '/Volumes/{catalog_name}/{schema_name}/{volume_name}'" + + try: + cursor.execute(query) + except Exception as error: + logger.exception( + f"Failed to list files in {volume_info.full_name}: {error}" + ) + return [] + + for row in cursor.fetchall(): + volume_files.append( + VolumeFileInfo( + last_updated=row["modification_time"], + name=row["name"], + path=row["path"], + size=float(row["size"]), + ) + ) + + logger.info(f"Found {len(volume_files)} files in {volume_info.full_name}") + json_dump_to_debug_file( + volume_files, + f"list_volume_files_{catalog_name}_{schema_name}_{volume_name}.json", + ) + return volume_files + + +def list_table_lineage( + connection: Connection, catalog: str, schema: str, lookback_days=7 +) -> TableLineageMap: + """ + Fetch table lineage for a specific schema from system.access.table_lineage table + See https://docs.databricks.com/en/admin/system-tables/lineage.html for more details + """ + + table_lineage: Dict[str, TableLineage] = {} + + with connection.cursor() as cursor: + query = f""" + SELECT + source_table_full_name, + target_table_full_name + FROM system.access.table_lineage + WHERE + target_table_catalog = '{catalog}' AND + target_table_schema = '{schema}' AND + source_table_full_name IS NOT NULL AND + event_time > date_sub(now(), {lookback_days}) + GROUP BY + source_table_full_name, + target_table_full_name + """ + + try: + cursor.execute(query) + except Exception as error: + logger.exception( + f"Failed to list table lineage for {catalog}.{schema}: {error}" + ) + return {} + + for source_table, target_table in cursor.fetchall(): + lineage = table_lineage.setdefault(target_table.lower(), TableLineage()) + lineage.upstream_tables.append(source_table.lower()) + + logger.info( + f"Fetched table lineage for {len(table_lineage)} tables in {catalog}.{schema}" + ) + json_dump_to_debug_file(table_lineage, f"table_lineage_{catalog}_{schema}.json") + return table_lineage + + +def list_column_lineage( + connection: Connection, catalog: str, schema: str, lookback_days=7 +) -> ColumnLineageMap: + """ + Fetch column lineage for a specific schema from system.access.table_lineage table + See https://docs.databricks.com/en/admin/system-tables/lineage.html for more details + """ + column_lineage: 
Dict[str, ColumnLineage] = {} + + with connection.cursor() as cursor: + query = f""" + SELECT + source_table_full_name, + source_column_name, + target_table_full_name, + target_column_name + FROM system.access.column_lineage + WHERE + target_table_catalog = '{catalog}' AND + target_table_schema = '{schema}' AND + source_table_full_name IS NOT NULL AND + event_time > date_sub(now(), {lookback_days}) + GROUP BY + source_table_full_name, + source_column_name, + target_table_full_name, + target_column_name + """ + + try: + cursor.execute(query) + except Exception as error: + logger.exception( + f"Failed to list column lineage for {catalog}.{schema}: {error}" + ) + return {} + + cursor.execute(query) + + for ( + source_table, + source_column, + target_table, + target_column, + ) in cursor.fetchall(): + lineage = column_lineage.setdefault(target_table.lower(), ColumnLineage()) + columns = lineage.upstream_columns.setdefault(target_column.lower(), []) + columns.append( + Column( + table_name=source_table.lower(), column_name=source_column.lower() + ) + ) + + logger.info( + f"Fetched column lineage for {len(column_lineage)} tables in {catalog}.{schema}" + ) + json_dump_to_debug_file(column_lineage, f"column_lineage_{catalog}_{schema}.json") + return column_lineage + + +def list_query_logs( + connection: Connection, lookback_days: int, excluded_usernames: Collection[str] +): + """ + Fetch query logs from system.query.history table + See https://docs.databricks.com/en/admin/system-tables/query-history.html + """ + start = start_of_day(lookback_days) + end = start_of_day() + + user_condition = ",".join([f"'{user}'" for user in excluded_usernames]) + user_filter = f"Q.executed_by IN ({user_condition}) AND" if user_condition else "" + + with connection.cursor() as cursor: + query = f""" + SELECT + statement_id as query_id, + executed_by as email, + start_time, + int(total_task_duration_ms/1000) as duration, + read_rows as rows_read, + produced_rows as rows_written, + read_bytes as bytes_read, + written_bytes as bytes_written, + statement_type as query_type, + statement_text as query_text + FROM system.query.history + WHERE + {user_filter} + execution_status = 'FINISHED' AND + start_time >= ? AND + start_time < ? 
+ """ + + try: + cursor.execute(query, [start, end]) + except Exception as error: + logger.exception(f"Failed to list query logs: {error}") + return [] + + return cursor.fetchall() + + +def get_last_refreshed_time( + connection: Connection, + table_full_name: str, + limit: int, +) -> Optional[Tuple[str, datetime]]: + """ + Retrieve the last refresh time for a table + See https://docs.databricks.com/en/delta/history.html + """ + + with connection.cursor() as cursor: + try: + cursor.execute(f"DESCRIBE HISTORY {table_full_name} LIMIT {limit}") + except Exception as error: + logger.exception(f"Failed to get history for {table_full_name}: {error}") + return None + + for history in cursor.fetchall(): + operation = history["operation"] + if operation not in IGNORED_HISTORY_OPERATIONS: + logger.info( + f"Fetched last refresh time for {table_full_name} ({operation})" + ) + return (table_full_name, history["timestamp"]) + + return None + + +def get_table_properties( + connection: Connection, + table_full_name: str, +) -> Dict[str, str]: + """ + Retrieve the properties for a table + See https://docs.databricks.com/en/sql/language-manual/sql-ref-syntax-aux-show-tblproperties.html + """ + properties: Dict[str, str] = {} + with connection.cursor() as cursor: + try: + cursor.execute(f"SHOW TBLPROPERTIES {table_full_name}") + except Exception as error: + logger.exception( + f"Failed to show table properties for {table_full_name}: {error}" + ) + return properties + + for row in cursor.fetchall(): + properties[row["key"]] = row["value"] + + return properties + + +def list_catalog_tags(connection: Connection, catalog: str) -> List[CatalogTag]: + """ + List tags for a catalog + See https://docs.databricks.com/en/sql/language-manual/information-schema/catalog_tags.html + """ + + catalog_tags: List[CatalogTag] = [] + + with connection.cursor() as cursor: + try: + cursor.execute( + f"SELECT tag_name, tag_value FROM {catalog}.information_schema.catalog_tags" + ) + except Exception as error: + logger.exception(f"Failed to fetch catalog tags for {catalog}: {error}") + return [] + + for row in cursor.fetchall(): + catalog_tags.append( + CatalogTag( + key=row["tag_name"], + value=row["tag_value"], + ) + ) + + return catalog_tags + + +def list_schema_tags( + connection: Connection, catalog: str +) -> Dict[str, List[SchemaTag]]: + """ + List tags for all schemas in a catalog + See https://docs.databricks.com/en/sql/language-manual/information-schema/schema_tags.html + """ + + schema_tags: Dict[str, List[SchemaTag]] = {} + + with connection.cursor() as cursor: + try: + cursor.execute( + f"SELECT schema_name, tag_name, tag_value FROM {catalog}.information_schema.schema_tags" + ) + except Exception as error: + logger.exception(f"Failed to fetch schema tags for {catalog}: {error}") + return {} + + for row in cursor.fetchall(): + schema_tags[row["schema_name"]].append( + SchemaTag( + key=row["tag_name"], + value=row["tag_value"], + ) + ) + + return schema_tags diff --git a/metaphor/unity_catalog/utils.py b/metaphor/unity_catalog/utils.py index c711b4a4..33343895 100644 --- a/metaphor/unity_catalog/utils.py +++ b/metaphor/unity_catalog/utils.py @@ -1,7 +1,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime from queue import Queue -from typing import Collection, Dict, List, Optional, Set, Tuple +from typing import Dict, List, Optional, Set from databricks import sql from databricks.sdk import WorkspaceClient @@ -10,186 +10,25 @@ from databricks.sql.types import Row from 
metaphor.common.entity_id import dataset_normalized_name, to_dataset_entity_id +from metaphor.common.fieldpath import build_schema_field from metaphor.common.logger import get_logger, json_dump_to_debug_file from metaphor.common.sql.process_query.config import ProcessQueryConfig from metaphor.common.sql.query_log import PartialQueryLog, process_and_init_query_log from metaphor.common.sql.table_level_lineage.table_level_lineage import ( extract_table_level_lineage, ) -from metaphor.common.utils import is_email, safe_float, start_of_day +from metaphor.common.utils import is_email, safe_float +from metaphor.datahub.gql_parser import SchemaField from metaphor.models.metadata_change_event import DataPlatform, Dataset, QueriedDataset -from metaphor.unity_catalog.models import Column, ColumnLineage, TableLineage +from metaphor.unity_catalog.models import ColumnInfo +from metaphor.unity_catalog.queries import ( + get_last_refreshed_time, + get_table_properties, + list_query_logs, +) logger = get_logger() -# Map a table's full name to its table lineage -TableLineageMap = Dict[str, TableLineage] - -# Map a table's full name to its column lineage -ColumnLineageMap = Dict[str, ColumnLineage] - -IGNORED_HISTORY_OPERATIONS = { - "ADD CONSTRAINT", - "CHANGE COLUMN", - "LIQUID TAGGING", - "OPTIMIZE", - "SET TBLPROPERTIES", -} -"""These are the operations that do not modify actual data.""" - - -def list_table_lineage( - connection: Connection, catalog: str, schema: str, lookback_days=7 -) -> TableLineageMap: - """ - Fetch table lineage for a specific schema from system.access.table_lineage table - See https://docs.databricks.com/en/admin/system-tables/lineage.html for more details - """ - - with connection.cursor() as cursor: - query = f""" - SELECT - source_table_full_name, - target_table_full_name - FROM system.access.table_lineage - WHERE - target_table_catalog = '{catalog}' AND - target_table_schema = '{schema}' AND - source_table_full_name IS NOT NULL AND - event_time > date_sub(now(), {lookback_days}) - GROUP BY - source_table_full_name, - target_table_full_name - """ - cursor.execute(query) - - table_lineage: Dict[str, TableLineage] = {} - for source_table, target_table in cursor.fetchall(): - lineage = table_lineage.setdefault(target_table.lower(), TableLineage()) - lineage.upstream_tables.append(source_table.lower()) - - logger.info( - f"Fetched table lineage for {len(table_lineage)} tables in {catalog}.{schema}" - ) - json_dump_to_debug_file(table_lineage, f"table_lineage_{catalog}_{schema}.json") - return table_lineage - - -def list_column_lineage( - connection: Connection, catalog: str, schema: str, lookback_days=7 -) -> ColumnLineageMap: - """ - Fetch column lineage for a specific schema from system.access.table_lineage table - See https://docs.databricks.com/en/admin/system-tables/lineage.html for more details - """ - - with connection.cursor() as cursor: - query = f""" - SELECT - source_table_full_name, - source_column_name, - target_table_full_name, - target_column_name - FROM system.access.column_lineage - WHERE - target_table_catalog = '{catalog}' AND - target_table_schema = '{schema}' AND - source_table_full_name IS NOT NULL AND - event_time > date_sub(now(), {lookback_days}) - GROUP BY - source_table_full_name, - source_column_name, - target_table_full_name, - target_column_name - """ - cursor.execute(query) - - column_lineage: Dict[str, ColumnLineage] = {} - for ( - source_table, - source_column, - target_table, - target_column, - ) in cursor.fetchall(): - lineage = 
column_lineage.setdefault(target_table.lower(), ColumnLineage()) - columns = lineage.upstream_columns.setdefault(target_column.lower(), []) - columns.append( - Column( - table_name=source_table.lower(), column_name=source_column.lower() - ) - ) - - logger.info( - f"Fetched column lineage for {len(column_lineage)} tables in {catalog}.{schema}" - ) - json_dump_to_debug_file(column_lineage, f"column_lineage_{catalog}_{schema}.json") - return column_lineage - - -def list_query_logs( - connection: Connection, lookback_days: int, excluded_usernames: Collection[str] -): - """ - Fetch query logs from system.query.history table - See https://docs.databricks.com/en/admin/system-tables/query-history.html - """ - start = start_of_day(lookback_days) - end = start_of_day() - - user_condition = ",".join([f"'{user}'" for user in excluded_usernames]) - user_filter = f"Q.executed_by IN ({user_condition}) AND" if user_condition else "" - - with connection.cursor() as cursor: - query = f""" - SELECT - statement_id as query_id, - executed_by as email, - start_time, - int(total_task_duration_ms/1000) as duration, - read_rows as rows_read, - produced_rows as rows_written, - read_bytes as bytes_read, - written_bytes as bytes_written, - statement_type as query_type, - statement_text as query_text - FROM system.query.history - WHERE - {user_filter} - execution_status = 'FINISHED' AND - start_time >= ? AND - start_time < ? - """ - cursor.execute(query, [start, end]) - return cursor.fetchall() - - -def get_last_refreshed_time( - connection: Connection, - table_full_name: str, - limit: int, -) -> Optional[Tuple[str, datetime]]: - """ - Retrieve the last refresh time for a table - See https://docs.databricks.com/en/delta/history.html - """ - - with connection.cursor() as cursor: - try: - cursor.execute(f"DESCRIBE HISTORY {table_full_name} LIMIT {limit}") - except Exception as error: - logger.exception(f"Failed to get history for {table_full_name}: {error}") - return None - - for history in cursor.fetchall(): - operation = history["operation"] - if operation not in IGNORED_HISTORY_OPERATIONS: - logger.info( - f"Fetched last refresh time for {table_full_name} ({operation})" - ) - return (table_full_name, history["timestamp"]) - - return None - def batch_get_last_refreshed_time( connection_pool: Queue, @@ -231,6 +70,43 @@ def get_last_refreshed_time_helper(table_full_name: str): return result_map +def batch_get_table_properties( + connection_pool: Queue, + table_full_names: List[str], +): + result_map: Dict[str, Dict[str, str]] = {} + + with ThreadPoolExecutor(max_workers=connection_pool.maxsize) as executor: + + def get_table_properties_helper(table_full_name: str): + connection = connection_pool.get() + result = get_table_properties(connection, table_full_name) + connection_pool.put(connection) + return result + + futures = { + executor.submit( + get_table_properties_helper, + table_full_name, + ): table_full_name + for table_full_name in table_full_names + } + + for future in as_completed(futures): + try: + result = future.result() + if result is None: + continue + table_full_name, properties = result + result_map[table_full_name] = properties + except Exception: + logger.exception( + f"Not able to get table properties for {futures[future]}" + ) + + return result_map + + SPECIAL_CHARACTERS = "&*{}[],=-()+;'\"`" """ The special characters mentioned in Databricks documentation are: @@ -360,7 +236,7 @@ def find_qualified_dataset(dataset: QueriedDataset, datasets: Dict[str, Dataset] return None -def to_query_log_with_tll( 
+def _to_query_log_with_tll( row: Row, service_principals: Dict[str, ServicePrincipal], datasets: Dict[str, Dataset], @@ -436,7 +312,7 @@ def get_query_logs( count = 0 logger.info(f"{len(rows)} queries to fetch") for row in rows: - res = to_query_log_with_tll( + res = _to_query_log_with_tll( row, service_principals, datasets, process_query_config ) if res is not None: @@ -444,3 +320,13 @@ def get_query_logs( if count % 1000 == 0: logger.info(f"Fetched {count} queries") yield res + + +def build_schema_field_from_column_info(column: ColumnInfo) -> SchemaField: + return build_schema_field( + column_name=column.column_name, + field_type=column.data_type, + description=column.comment, + nullable=column.is_nullable, + precision=safe_float(column.data_precision), + ) diff --git a/pyproject.toml b/pyproject.toml index 5595112d..c6e2a210 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "metaphor-connectors" -version = "0.14.136" +version = "0.14.137" license = "Apache-2.0" description = "A collection of Python-based 'connectors' that extract metadata from various sources to ingest into the Metaphor app." authors = ["Metaphor "] diff --git a/tests/unity_catalog/snapshots/test_queries/test_get_last_refreshed_time/describe_history.sql b/tests/unity_catalog/snapshots/test_queries/test_get_last_refreshed_time/describe_history.sql new file mode 100644 index 00000000..aa560ccc --- /dev/null +++ b/tests/unity_catalog/snapshots/test_queries/test_get_last_refreshed_time/describe_history.sql @@ -0,0 +1 @@ +DESCRIBE HISTORY db.schema.table LIMIT 50 \ No newline at end of file diff --git a/tests/unity_catalog/snapshots/test_queries/test_get_table_properties/show_table_properties.sql b/tests/unity_catalog/snapshots/test_queries/test_get_table_properties/show_table_properties.sql new file mode 100644 index 00000000..901d1f1d --- /dev/null +++ b/tests/unity_catalog/snapshots/test_queries/test_get_table_properties/show_table_properties.sql @@ -0,0 +1 @@ +SHOW TBLPROPERTIES db.schema.table \ No newline at end of file diff --git a/tests/unity_catalog/snapshots/test_queries/test_list_catalog_tags/list_catalog_tags.sql b/tests/unity_catalog/snapshots/test_queries/test_list_catalog_tags/list_catalog_tags.sql new file mode 100644 index 00000000..df8270ba --- /dev/null +++ b/tests/unity_catalog/snapshots/test_queries/test_list_catalog_tags/list_catalog_tags.sql @@ -0,0 +1 @@ +SELECT tag_name, tag_value FROM catalog1.information_schema.catalog_tags \ No newline at end of file diff --git a/tests/unity_catalog/snapshots/test_queries/test_list_catalogs/list_catalogs.sql b/tests/unity_catalog/snapshots/test_queries/test_list_catalogs/list_catalogs.sql new file mode 100644 index 00000000..8d7046e5 --- /dev/null +++ b/tests/unity_catalog/snapshots/test_queries/test_list_catalogs/list_catalogs.sql @@ -0,0 +1,8 @@ + + SELECT + catalog_name, + catalog_owner, + comment + FROM system.information_schema.catalogs + WHERE catalog_name <> 'SYSTEM' + \ No newline at end of file diff --git a/tests/unity_catalog/snapshots/test_queries/test_list_column_lineage/list_column_lineage.sql b/tests/unity_catalog/snapshots/test_queries/test_list_column_lineage/list_column_lineage.sql new file mode 100644 index 00000000..ca32e2a8 --- /dev/null +++ b/tests/unity_catalog/snapshots/test_queries/test_list_column_lineage/list_column_lineage.sql @@ -0,0 +1,18 @@ + + SELECT + source_table_full_name, + source_column_name, + target_table_full_name, + target_column_name + FROM system.access.column_lineage + WHERE + 
target_table_catalog = 'catalog' AND + target_table_schema = 'schema' AND + source_table_full_name IS NOT NULL AND + event_time > date_sub(now(), 7) + GROUP BY + source_table_full_name, + source_column_name, + target_table_full_name, + target_column_name + \ No newline at end of file diff --git a/tests/unity_catalog/snapshots/test_utils/test_list_query_logs/list_query_log.sql b/tests/unity_catalog/snapshots/test_queries/test_list_query_logs/list_query_log.sql similarity index 100% rename from tests/unity_catalog/snapshots/test_utils/test_list_query_logs/list_query_log.sql rename to tests/unity_catalog/snapshots/test_queries/test_list_query_logs/list_query_log.sql diff --git a/tests/unity_catalog/snapshots/test_queries/test_list_schema_tags/list_schema_tags.sql b/tests/unity_catalog/snapshots/test_queries/test_list_schema_tags/list_schema_tags.sql new file mode 100644 index 00000000..b3c30963 --- /dev/null +++ b/tests/unity_catalog/snapshots/test_queries/test_list_schema_tags/list_schema_tags.sql @@ -0,0 +1 @@ +SELECT schema_name, tag_name, tag_value FROM catalog1.information_schema.schema_tags \ No newline at end of file diff --git a/tests/unity_catalog/snapshots/test_queries/test_list_schemas/list_schemas.sql b/tests/unity_catalog/snapshots/test_queries/test_list_schemas/list_schemas.sql new file mode 100644 index 00000000..1b554e81 --- /dev/null +++ b/tests/unity_catalog/snapshots/test_queries/test_list_schemas/list_schemas.sql @@ -0,0 +1,9 @@ + + SELECT + catalog_name, + schema_name, + schema_owner, + comment + FROM system.information_schema.schemata + WHERE catalog_name = 'catalog1' AND schema_name <> 'INFORMATION_SCHEMA' + \ No newline at end of file diff --git a/tests/unity_catalog/snapshots/test_queries/test_list_table_lineage/list_table_lineage.sql b/tests/unity_catalog/snapshots/test_queries/test_list_table_lineage/list_table_lineage.sql new file mode 100644 index 00000000..55e4eaa1 --- /dev/null +++ b/tests/unity_catalog/snapshots/test_queries/test_list_table_lineage/list_table_lineage.sql @@ -0,0 +1,14 @@ + + SELECT + source_table_full_name, + target_table_full_name + FROM system.access.table_lineage + WHERE + target_table_catalog = 'c' AND + target_table_schema = 's' AND + source_table_full_name IS NOT NULL AND + event_time > date_sub(now(), 7) + GROUP BY + source_table_full_name, + target_table_full_name + \ No newline at end of file diff --git a/tests/unity_catalog/snapshots/test_queries/test_list_tables/list_tables.sql b/tests/unity_catalog/snapshots/test_queries/test_list_tables/list_tables.sql new file mode 100644 index 00000000..cc97839c --- /dev/null +++ b/tests/unity_catalog/snapshots/test_queries/test_list_tables/list_tables.sql @@ -0,0 +1,109 @@ + + WITH + t AS ( + SELECT + table_catalog, + table_schema, + table_name, + table_type, + table_owner, + comment, + data_source_format, + storage_path, + created, + created_by, + last_altered, + last_altered_by + FROM system.information_schema.tables + WHERE + table_catalog = 'catalog1' AND + table_schema <> 'INFORMATION_SCHEMA' AND + table_schema IS NOT NULL AND + table_name IS NOT NULL + ), + + v AS ( + SELECT + table_catalog, + table_schema, + table_name, + view_definition + FROM system.information_schema.views + WHERE + table_catalog = 'catalog1' AND + table_schema <> 'INFORMATION_SCHEMA' AND + table_schema IS NOT NULL AND + table_name IS NOT NULL + ), + + tv AS ( + SELECT + t.table_catalog, + t.table_schema, + t.table_name, + t.table_type, + t.table_owner, + t.comment, + t.data_source_format, + t.storage_path, + 
t.created, + t.created_by, + t.last_altered, + t.last_altered_by, + v.view_definition + FROM t + LEFT JOIN v + ON + t.table_catalog = v.table_catalog AND + t.table_schema = v.table_schema AND + t.table_name = v.table_name + ), + + c AS ( + SELECT + table_catalog, + table_schema, + table_name, + column_name, + full_data_type as data_type, + CASE + WHEN numeric_precision IS NOT NULL THEN numeric_precision + WHEN datetime_precision IS NOT NULL THEN datetime_precision + ELSE NULL + END AS data_precision, + is_nullable, + comment + FROM system.information_schema.columns + WHERE + table_catalog = 'catalog1' AND + table_schema <> 'INFORMATION_SCHEMA' AND + table_schema IS NOT NULL AND + table_name IS NOT NULL + ) + + SELECT + tv.table_catalog as catalog_name, + tv.table_schema as schema_name, + tv.table_name as table_name, + tv.table_type as table_type, + tv.table_owner as owner, + tv.comment as table_comment, + tv.data_source_format as data_source_format, + tv.storage_path as storage_path, + tv.created as created_at, + tv.created_by as created_by, + tv.last_altered as updated_at, + tv.last_altered_by as updated_by, + tv.view_definition as view_definition, + c.column_name as column_name, + c.data_type as data_type, + c.data_precision as data_precision, + c.is_nullable as is_nullable, + c.comment as column_comment + FROM tv + LEFT JOIN c + ON + tv.table_catalog = c.table_catalog AND + tv.table_schema = c.table_schema AND + tv.table_name = c.table_name + \ No newline at end of file diff --git a/tests/unity_catalog/snapshots/test_queries/test_list_volume_files/list_volume_files.sql b/tests/unity_catalog/snapshots/test_queries/test_list_volume_files/list_volume_files.sql new file mode 100644 index 00000000..5467b79c --- /dev/null +++ b/tests/unity_catalog/snapshots/test_queries/test_list_volume_files/list_volume_files.sql @@ -0,0 +1 @@ +LIST '/Volumes/catalog1/schema1/volume1' \ No newline at end of file diff --git a/tests/unity_catalog/snapshots/test_queries/test_list_volumes/list_volumes.sql b/tests/unity_catalog/snapshots/test_queries/test_list_volumes/list_volumes.sql new file mode 100644 index 00000000..e8263dfa --- /dev/null +++ b/tests/unity_catalog/snapshots/test_queries/test_list_volumes/list_volumes.sql @@ -0,0 +1,16 @@ + + SELECT + volume_catalog, + volume_schema, + volume_name, + volume_type, + volume_owner, + comment, + created, + created_by, + last_altered, + last_altered_by, + storage_location + FROM system.information_schema.volumes + WHERE volume_catalog = 'catalog1' AND volume_schema = 'schema1' + \ No newline at end of file diff --git a/tests/unity_catalog/test_models.py b/tests/unity_catalog/test_models.py index 7db7417e..e69de29b 100644 --- a/tests/unity_catalog/test_models.py +++ b/tests/unity_catalog/test_models.py @@ -1,11 +0,0 @@ -import pytest -from databricks.sdk.service.catalog import ColumnInfo - -from metaphor.unity_catalog.models import extract_schema_field_from_column_info - - -def test_parse_schema_field_from_invalid_column_info() -> None: - with pytest.raises(ValueError): - extract_schema_field_from_column_info( - ColumnInfo(comment="does not have a type") - ) diff --git a/tests/unity_catalog/test_queries.py b/tests/unity_catalog/test_queries.py new file mode 100644 index 00000000..75b75427 --- /dev/null +++ b/tests/unity_catalog/test_queries.py @@ -0,0 +1,556 @@ +from datetime import datetime, timezone +from unittest.mock import MagicMock + +from freezegun import freeze_time +from pytest_snapshot.plugin import Snapshot + +from metaphor.unity_catalog.models import ( 
+ CatalogInfo, + CatalogTag, + Column, + ColumnInfo, + ColumnLineage, + SchemaInfo, + SchemaTag, + TableInfo, + TableLineage, + VolumeFileInfo, + VolumeInfo, +) +from metaphor.unity_catalog.queries import ( + get_last_refreshed_time, + get_table_properties, + list_catalog_tags, + list_catalogs, + list_column_lineage, + list_query_logs, + list_schema_tags, + list_schemas, + list_table_lineage, + list_tables, + list_volume_files, + list_volumes, +) +from tests.unity_catalog.mocks import mock_sql_connection + + +def test_list_catalogs( + test_root_dir: str, + snapshot: Snapshot, +): + mock_cursor = MagicMock() + + mock_connection = mock_sql_connection( + [ + [ + { + "catalog_name": "catalog1", + "catalog_owner": "owner1", + "comment": "comment1", + }, + { + "catalog_name": "catalog2", + "catalog_owner": "owner2", + "comment": "comment2", + }, + ] + ], + None, + mock_cursor, + ) + + catalogs = list_catalogs(mock_connection) + + assert catalogs == [ + CatalogInfo(catalog_name="catalog1", owner="owner1", comment="comment1"), + CatalogInfo(catalog_name="catalog2", owner="owner2", comment="comment2"), + ] + + args = mock_cursor.execute.call_args_list[0].args + snapshot.assert_match(args[0], "list_catalogs.sql") + + +def test_list_schemas( + test_root_dir: str, + snapshot: Snapshot, +): + mock_cursor = MagicMock() + + mock_connection = mock_sql_connection( + [ + [ + { + "catalog_name": "catalog1", + "schema_name": "schema1", + "schema_owner": "owner1", + "comment": "comment1", + }, + { + "catalog_name": "catalog1", + "schema_name": "schema2", + "schema_owner": "owner2", + "comment": "comment2", + }, + ] + ], + None, + mock_cursor, + ) + + schemas = list_schemas(mock_connection, "catalog1") + + assert schemas == [ + SchemaInfo( + catalog_name="catalog1", + schema_name="schema1", + owner="owner1", + comment="comment1", + ), + SchemaInfo( + catalog_name="catalog1", + schema_name="schema2", + owner="owner2", + comment="comment2", + ), + ] + + args = mock_cursor.execute.call_args_list[0].args + snapshot.assert_match(args[0], "list_schemas.sql") + + +def test_list_tables( + test_root_dir: str, + snapshot: Snapshot, +): + mock_cursor = MagicMock() + + mock_connection = mock_sql_connection( + [ + [ + { + "catalog_name": "catalog1", + "schema_name": "schema1", + "table_name": "table1", + "table_type": "TABLE", + "owner": "owner1", + "table_comment": "table_comment1", + "data_source_format": "PARQUET", + "storage_path": "location1", + "created_at": datetime(2000, 1, 1), + "created_by": "user1", + "updated_at": datetime(2001, 1, 1), + "updated_by": "user2", + "view_definition": "definition", + "column_name": "column1", + "data_type": "data_type1", + "data_precision": 10, + "is_nullable": "YES", + "column_comment": "column_comment1", + }, + { + "catalog_name": "catalog1", + "schema_name": "schema1", + "table_name": "table1", + "table_type": "TABLE", + "owner": "owner1", + "table_comment": "table_comment1", + "data_source_format": "PARQUET", + "storage_path": "location1", + "created_at": datetime(2000, 1, 1), + "created_by": "user1", + "updated_at": datetime(2001, 1, 1), + "updated_by": "user2", + "view_definition": "definition", + "column_name": "column2", + "data_type": "data_type2", + "data_precision": 20, + "is_nullable": "NO", + "column_comment": "column_comment2", + }, + ], + ], + None, + mock_cursor, + ) + + tables = list_tables(mock_connection, "catalog1", "schema1") + + assert tables == [ + TableInfo( + catalog_name="catalog1", + schema_name="schema1", + table_name="table1", + type="TABLE", + 
owner="owner1", + comment="table_comment1", + data_source_format="PARQUET", + storage_location="location1", + created_at=datetime(2000, 1, 1), + created_by="user1", + updated_at=datetime(2001, 1, 1), + updated_by="user2", + view_definition="definition", + columns=[ + ColumnInfo( + column_name="column1", + data_type="data_type1", + data_precision=10, + is_nullable=True, + comment="column_comment1", + ), + ColumnInfo( + column_name="column2", + data_type="data_type2", + data_precision=20, + is_nullable=False, + comment="column_comment2", + ), + ], + ), + ] + + args = mock_cursor.execute.call_args_list[0].args + snapshot.assert_match(args[0], "list_tables.sql") + + +def test_list_volumes( + test_root_dir: str, + snapshot: Snapshot, +): + mock_cursor = MagicMock() + + mock_connection = mock_sql_connection( + [ + [ + { + "volume_catalog": "catalog1", + "volume_schema": "schema1", + "volume_name": "volume1", + "volume_type": "MANAGED", + "volume_owner": "owner1", + "comment": "comment1", + "created": datetime(2000, 1, 1), + "created_by": "user1", + "last_altered": datetime(2001, 1, 1), + "last_altered_by": "user2", + "last_altered_by": "user2", + "storage_location": "location1", + }, + ] + ], + None, + mock_cursor, + ) + + volumes = list_volumes(mock_connection, "catalog1", "schema1") + + assert volumes == [ + VolumeInfo( + catalog_name="catalog1", + schema_name="schema1", + volume_name="volume1", + full_name="catalog1.schema1.volume1", + volume_type="MANAGED", + owner="owner1", + comment="comment1", + created_at=datetime(2000, 1, 1), + created_by="user1", + updated_at=datetime(2001, 1, 1), + updated_by="user2", + storage_location="location1", + ), + ] + + args = mock_cursor.execute.call_args_list[0].args + snapshot.assert_match(args[0], "list_volumes.sql") + + +def test_list_volume_files( + test_root_dir: str, + snapshot: Snapshot, +): + mock_cursor = MagicMock() + + mock_connection = mock_sql_connection( + [ + [ + { + "path": "path1", + "name": "name1", + "size": 100, + "modification_time": datetime(2000, 1, 1), + }, + ] + ], + None, + mock_cursor, + ) + + volume_files = list_volume_files( + mock_connection, + VolumeInfo( + catalog_name="catalog1", + schema_name="schema1", + volume_name="volume1", + full_name="catalog1.schema1.volume1", + volume_type="MANAGED", + owner="owner1", + created_at=datetime(2000, 1, 1), + created_by="user1", + updated_at=datetime(2001, 1, 1), + updated_by="user2", + storage_location="location1", + ), + ) + + assert volume_files == [ + VolumeFileInfo( + path="path1", + name="name1", + size=100, + last_updated=datetime(2000, 1, 1), + ), + ] + + args = mock_cursor.execute.call_args_list[0].args + snapshot.assert_match(args[0], "list_volume_files.sql") + + +def test_list_table_lineage( + test_root_dir: str, + snapshot: Snapshot, +): + mock_cursor = MagicMock() + + mock_connection = mock_sql_connection( + [ + [ + ("c.s.t1", "c.s.t3"), + ("c.s.t2", "c.s.t3"), + ("c.s.t4", "c.s.t2"), + ] + ], + None, + mock_cursor, + ) + + table_lineage = list_table_lineage(mock_connection, "c", "s") + + assert table_lineage == { + "c.s.t3": TableLineage(upstream_tables=["c.s.t1", "c.s.t2"]), + "c.s.t2": TableLineage(upstream_tables=["c.s.t4"]), + } + + args = mock_cursor.execute.call_args_list[0].args + snapshot.assert_match(args[0], "list_table_lineage.sql") + + +def test_list_column_lineage( + test_root_dir: str, + snapshot: Snapshot, +): + mock_cursor = MagicMock() + + mock_connection = mock_sql_connection( + [ + [ + ("c.s.t1", "c1", "c.s.t3", "ca"), + ("c.s.t1", "c2", "c.s.t3", "ca"), 
+ ("c.s.t1", "c3", "c.s.t3", "cb"), + ("c.s.t2", "c4", "c.s.t3", "ca"), + ("c.s.t3", "c5", "c.s.t2", "cc"), + ] + ], + None, + mock_cursor, + ) + + column_lineage = list_column_lineage(mock_connection, "catalog", "schema") + + assert column_lineage == { + "c.s.t3": ColumnLineage( + upstream_columns={ + "ca": [ + Column(column_name="c1", table_name="c.s.t1"), + Column(column_name="c2", table_name="c.s.t1"), + Column(column_name="c4", table_name="c.s.t2"), + ], + "cb": [Column(column_name="c3", table_name="c.s.t1")], + } + ), + "c.s.t2": ColumnLineage( + upstream_columns={ + "cc": [Column(column_name="c5", table_name="c.s.t3")], + } + ), + } + + args = mock_cursor.execute.call_args_list[0].args + snapshot.assert_match(args[0], "list_column_lineage.sql") + + +@freeze_time("2000-01-02") +def test_list_query_logs( + test_root_dir: str, + snapshot: Snapshot, +): + + mock_cursor = MagicMock() + mock_connection = mock_sql_connection(None, None, mock_cursor) + + list_query_logs(mock_connection, 1, ["user1", "user2"]) + + args = mock_cursor.execute.call_args_list[0].args + snapshot.assert_match(args[0], "list_query_log.sql") + assert args[1] == [ + datetime(2000, 1, 1, tzinfo=timezone.utc), + datetime(2000, 1, 2, tzinfo=timezone.utc), + ] + + +def test_get_last_refreshed_time( + test_root_dir: str, + snapshot: Snapshot, +): + + mock_cursor = MagicMock() + + mock_connection = mock_sql_connection( + [ + [ + { + "operation": "SET TBLPROPERTIES", + "timestamp": datetime(2020, 1, 4), + }, + { + "operation": "ADD CONSTRAINT", + "timestamp": datetime(2020, 1, 3), + }, + { + "operation": "CHANGE COLUMN", + "timestamp": datetime(2020, 1, 2), + }, + { + "operation": "WRITE", + "timestamp": datetime(2020, 1, 1), + }, + ] + ], + None, + mock_cursor, + ) + + result = get_last_refreshed_time(mock_connection, "db.schema.table", 50) + + assert result == ("db.schema.table", datetime(2020, 1, 1)) + + args = mock_cursor.execute.call_args_list[0].args + snapshot.assert_match(args[0], "describe_history.sql") + + +def test_get_table_properties( + test_root_dir: str, + snapshot: Snapshot, +): + + mock_cursor = MagicMock() + + mock_connection = mock_sql_connection( + [ + [ + { + "key": "key1", + "value": "value1", + }, + { + "key": "key2", + "value": "value2", + }, + ] + ], + None, + mock_cursor, + ) + + result = get_table_properties(mock_connection, "db.schema.table") + + assert result == {"key1": "value1", "key2": "value2"} + + args = mock_cursor.execute.call_args_list[0].args + snapshot.assert_match(args[0], "show_table_properties.sql") + + +def test_list_catalog_tags( + test_root_dir: str, + snapshot: Snapshot, +): + mock_cursor = MagicMock() + + mock_connection = mock_sql_connection( + [ + [ + { + "catalog_name": "catalog1", + "tag_name": "key1", + "tag_value": "value1", + }, + { + "catalog_name": "catalog1", + "tag_name": "key2", + "tag_value": "value2", + }, + ] + ], + None, + mock_cursor, + ) + + result = list_catalog_tags(mock_connection, "catalog1") + + assert result == [ + CatalogTag(key="key1", value="value1"), + CatalogTag(key="key2", value="value2"), + ] + + args = mock_cursor.execute.call_args_list[0].args + snapshot.assert_match(args[0], "list_catalog_tags.sql") + + +def test_list_schema_tags( + test_root_dir: str, + snapshot: Snapshot, +): + mock_cursor = MagicMock() + + mock_connection = mock_sql_connection( + [ + [ + { + "catalog_name": "catalog1", + "schema_name": "schema1", + "tag_name": "key1", + "tag_value": "value1", + }, + { + "catalog_name": "catalog1", + "schema_name": "schema1", + "tag_name": 
"key2", + "tag_value": "value2", + }, + ] + ], + None, + mock_cursor, + ) + + result = list_schema_tags(mock_connection, "catalog1") + + assert result == { + "schema1": [ + SchemaTag(key="key1", value="value1"), + SchemaTag(key="key2", value="value2"), + ] + } + + args = mock_cursor.execute.call_args_list[0].args + snapshot.assert_match(args[0], "list_schema_tags.sql") diff --git a/tests/unity_catalog/test_utils.py b/tests/unity_catalog/test_utils.py index 1136ade1..62f67bfc 100644 --- a/tests/unity_catalog/test_utils.py +++ b/tests/unity_catalog/test_utils.py @@ -1,9 +1,8 @@ -from datetime import datetime, timezone +from datetime import datetime from typing import Optional, Tuple from unittest.mock import MagicMock from databricks.sdk.service.iam import ServicePrincipal -from freezegun import freeze_time from pytest_snapshot.plugin import Snapshot from metaphor.common.entity_id import dataset_normalized_name, to_dataset_entity_id @@ -15,18 +14,17 @@ DatasetStructure, QueriedDataset, QueryLogs, + SchemaField, ) -from metaphor.unity_catalog.models import Column, ColumnLineage, TableLineage +from metaphor.unity_catalog.models import ColumnInfo from metaphor.unity_catalog.utils import ( batch_get_last_refreshed_time, + build_schema_field_from_column_info, escape_special_characters, find_qualified_dataset, get_last_refreshed_time, get_query_logs, - list_column_lineage, - list_query_logs, list_service_principals, - list_table_lineage, ) from tests.test_utils import load_json from tests.unity_catalog.mocks import mock_connection_pool, mock_sql_connection @@ -179,78 +177,6 @@ def make_dataset(parts: Tuple[str, str, str]) -> Dataset: assert not found -def test_list_table_lineage(): - mock_connection = mock_sql_connection( - [ - [ - ("c.s.t1", "c.s.t3"), - ("c.s.t2", "c.s.t3"), - ("c.s.t4", "c.s.t2"), - ] - ] - ) - - table_lineage = list_table_lineage(mock_connection, "c", "s") - - assert table_lineage == { - "c.s.t3": TableLineage(upstream_tables=["c.s.t1", "c.s.t2"]), - "c.s.t2": TableLineage(upstream_tables=["c.s.t4"]), - } - - -def test_list_column_lineage(): - mock_connection = mock_sql_connection( - [ - [ - ("c.s.t1", "c1", "c.s.t3", "ca"), - ("c.s.t1", "c2", "c.s.t3", "ca"), - ("c.s.t1", "c3", "c.s.t3", "cb"), - ("c.s.t2", "c4", "c.s.t3", "ca"), - ("c.s.t3", "c5", "c.s.t2", "cc"), - ] - ] - ) - - column_lineage = list_column_lineage(mock_connection, "catalog", "schema") - - assert column_lineage == { - "c.s.t3": ColumnLineage( - upstream_columns={ - "ca": [ - Column(column_name="c1", table_name="c.s.t1"), - Column(column_name="c2", table_name="c.s.t1"), - Column(column_name="c4", table_name="c.s.t2"), - ], - "cb": [Column(column_name="c3", table_name="c.s.t1")], - } - ), - "c.s.t2": ColumnLineage( - upstream_columns={ - "cc": [Column(column_name="c5", table_name="c.s.t3")], - } - ), - } - - -@freeze_time("2000-01-02") -def test_list_query_logs( - test_root_dir: str, - snapshot: Snapshot, -): - - mock_cursor = MagicMock() - mock_connection = mock_sql_connection(None, None, mock_cursor) - - list_query_logs(mock_connection, 1, ["user1", "user2"]) - - args = mock_cursor.execute.call_args_list[0].args - snapshot.assert_match(args[0], "list_query_log.sql") - assert args[1] == [ - datetime(2000, 1, 1, tzinfo=timezone.utc), - datetime(2000, 1, 2, tzinfo=timezone.utc), - ] - - def test_get_last_refreshed_time( test_root_dir: str, snapshot: Snapshot, @@ -332,3 +258,22 @@ def test_list_service_principals(): def test_escape_special_characters(): assert escape_special_characters("this.is.a_table") == 
"this.is.a_table" assert escape_special_characters("this.is.also-a-table") == "`this.is.also-a-table`" + + +def test_parse_schema_field_from_column_info(): + assert build_schema_field_from_column_info( + ColumnInfo( + column_name="name", + data_type="type", + data_precision=3, + is_nullable=True, + comment="comment", + ) + ) == SchemaField( + field_name="name", + field_path="name", + native_type="type", + description="comment", + nullable=True, + precision=3.0, + )