From 955ff891485a072ea2e9eced6b9ec4b2431eda76 Mon Sep 17 00:00:00 2001 From: Mars Lan Date: Mon, 28 Oct 2024 18:55:10 -0700 Subject: [PATCH] Refactor Unity Catalog to fetch catalog/schema/table metadata from System tables (#1022) --- metaphor/unity_catalog/extractor.py | 685 +++++++---------- metaphor/unity_catalog/models.py | 95 ++- metaphor/unity_catalog/profile/extractor.py | 10 +- metaphor/unity_catalog/queries.py | 715 ++++++++++++++++++ metaphor/unity_catalog/utils.py | 243 ++---- poetry.lock | 33 +- pyproject.toml | 6 +- tests/unity_catalog/expected.json | 90 +-- .../external_shallow_clone.json | 17 +- .../test_init_dataset/shallow_clone.json | 17 +- .../test_init_dataset/table.json | 24 +- .../describe_history.sql | 1 + .../show_table_properties.sql | 1 + .../test_list_catalogs/list_catalogs.sql | 29 + .../list_column_lineage.sql | 18 + .../test_list_query_logs/list_query_log.sql | 0 .../test_list_schemas/list_schemas.sql | 32 + .../list_table_lineage.sql | 14 + .../test_list_tables/list_tables.sql | 157 ++++ .../list_volume_files.sql | 1 + .../test_list_volumes/list_volumes.sql | 50 ++ tests/unity_catalog/test_extractor.py | 361 ++++----- tests/unity_catalog/test_models.py | 11 - tests/unity_catalog/test_queries.py | 551 ++++++++++++++ tests/unity_catalog/test_utils.py | 128 ++-- 25 files changed, 2334 insertions(+), 955 deletions(-) create mode 100644 metaphor/unity_catalog/queries.py create mode 100644 tests/unity_catalog/snapshots/test_queries/test_get_last_refreshed_time/describe_history.sql create mode 100644 tests/unity_catalog/snapshots/test_queries/test_get_table_properties/show_table_properties.sql create mode 100644 tests/unity_catalog/snapshots/test_queries/test_list_catalogs/list_catalogs.sql create mode 100644 tests/unity_catalog/snapshots/test_queries/test_list_column_lineage/list_column_lineage.sql rename tests/unity_catalog/snapshots/{test_utils => test_queries}/test_list_query_logs/list_query_log.sql (100%) create mode 100644 tests/unity_catalog/snapshots/test_queries/test_list_schemas/list_schemas.sql create mode 100644 tests/unity_catalog/snapshots/test_queries/test_list_table_lineage/list_table_lineage.sql create mode 100644 tests/unity_catalog/snapshots/test_queries/test_list_tables/list_tables.sql create mode 100644 tests/unity_catalog/snapshots/test_queries/test_list_volume_files/list_volume_files.sql create mode 100644 tests/unity_catalog/snapshots/test_queries/test_list_volumes/list_volumes.sql create mode 100644 tests/unity_catalog/test_queries.py diff --git a/metaphor/unity_catalog/extractor.py b/metaphor/unity_catalog/extractor.py index 92a54a28..d01e3c5e 100644 --- a/metaphor/unity_catalog/extractor.py +++ b/metaphor/unity_catalog/extractor.py @@ -1,12 +1,9 @@ -import json import re import urllib.parse -from collections import defaultdict -from datetime import datetime -from typing import Collection, Dict, Generator, Iterator, List, Optional, Tuple +from typing import Collection, Dict, Iterator, List, Optional -from databricks.sdk.service.catalog import TableInfo, TableType, VolumeInfo from databricks.sdk.service.iam import ServicePrincipal +from pydantic import BaseModel from metaphor.common.base_extractor import BaseExtractor from metaphor.common.entity_id import ( @@ -15,10 +12,10 @@ to_dataset_entity_id_from_logical_id, ) from metaphor.common.event_util import ENTITY_TYPES -from metaphor.common.filter import DatasetFilter -from metaphor.common.logger import get_logger, json_dump_to_debug_file +from metaphor.common.fieldpath import 
build_schema_field +from metaphor.common.logger import get_logger from metaphor.common.models import to_dataset_statistics -from metaphor.common.utils import to_utc_datetime_from_timestamp +from metaphor.common.utils import safe_float from metaphor.models.crawler_run_metadata import Platform from metaphor.models.metadata_change_event import ( AssetPlatform, @@ -34,12 +31,14 @@ KeyValuePair, MaterializationType, QueryLog, + SchemaField, SchemaType, SourceField, SourceInfo, SQLSchema, SystemContact, SystemContacts, + SystemDescription, SystemTag, SystemTags, SystemTagSource, @@ -52,59 +51,68 @@ VolumeFile, ) from metaphor.unity_catalog.config import UnityCatalogRunConfig -from metaphor.unity_catalog.models import extract_schema_field_from_column_info -from metaphor.unity_catalog.utils import ( +from metaphor.unity_catalog.models import ( + CatalogInfo, + ColumnInfo, + SchemaInfo, + TableInfo, + Tag, + VolumeFileInfo, + VolumeInfo, +) +from metaphor.unity_catalog.queries import ( ColumnLineageMap, TableLineageMap, + list_catalogs, + list_column_lineage, + list_schemas, + list_table_lineage, + list_tables, + list_volume_files, + list_volumes, +) +from metaphor.unity_catalog.utils import ( batch_get_last_refreshed_time, + batch_get_table_properties, create_api, create_connection, create_connection_pool, get_query_logs, - list_column_lineage, list_service_principals, - list_table_lineage, ) logger = get_logger() -# Filter out "system" database & all "information_schema" schemas -DEFAULT_FILTER: DatasetFilter = DatasetFilter( - excludes={ - "system": None, - "*": {"information_schema": None}, - } -) TABLE_TYPE_MATERIALIZATION_TYPE_MAP = { - TableType.EXTERNAL: MaterializationType.EXTERNAL, - TableType.EXTERNAL_SHALLOW_CLONE: MaterializationType.EXTERNAL, - TableType.FOREIGN: MaterializationType.EXTERNAL, - TableType.MANAGED: MaterializationType.TABLE, - TableType.MANAGED_SHALLOW_CLONE: MaterializationType.TABLE, - TableType.MATERIALIZED_VIEW: MaterializationType.MATERIALIZED_VIEW, - TableType.STREAMING_TABLE: MaterializationType.STREAM, - TableType.VIEW: MaterializationType.VIEW, + "EXTERNAL": MaterializationType.EXTERNAL, + "EXTERNAL_SHALLOW_CLONE": MaterializationType.EXTERNAL, + "FOREIGN": MaterializationType.EXTERNAL, + "MANAGED": MaterializationType.TABLE, + "MANAGED_SHALLOW_CLONE": MaterializationType.TABLE, + "MATERIALIZED_VIEW": MaterializationType.MATERIALIZED_VIEW, + "STREAMING_TABLE": MaterializationType.STREAM, + "VIEW": MaterializationType.VIEW, } TABLE_TYPE_WITH_HISTORY = set( [ - TableType.EXTERNAL, - TableType.EXTERNAL_SHALLOW_CLONE, - TableType.MANAGED, - TableType.MANAGED_SHALLOW_CLONE, + "EXTERNAL", + "EXTERNAL_SHALLOW_CLONE", + "MANAGED", + "MANAGED_SHALLOW_CLONE", ] ) TABLE_TYPE_MAP = { - TableType.EXTERNAL: UnityCatalogTableType.EXTERNAL, - TableType.EXTERNAL_SHALLOW_CLONE: UnityCatalogTableType.EXTERNAL_SHALLOW_CLONE, - TableType.FOREIGN: UnityCatalogTableType.FOREIGN, - TableType.MANAGED: UnityCatalogTableType.MANAGED, - TableType.MANAGED_SHALLOW_CLONE: UnityCatalogTableType.MANAGED_SHALLOW_CLONE, - TableType.MATERIALIZED_VIEW: UnityCatalogTableType.MATERIALIZED_VIEW, - TableType.STREAMING_TABLE: UnityCatalogTableType.STREAMING_TABLE, - TableType.VIEW: UnityCatalogTableType.VIEW, + "EXTERNAL": UnityCatalogTableType.EXTERNAL, + "EXTERNAL_SHALLOW_CLONE": UnityCatalogTableType.EXTERNAL_SHALLOW_CLONE, + "FOREIGN": UnityCatalogTableType.FOREIGN, + "MANAGED": UnityCatalogTableType.MANAGED, + "MANAGED_SHALLOW_CLONE": UnityCatalogTableType.MANAGED_SHALLOW_CLONE, + 
"MATERIALIZED_VIEW": UnityCatalogTableType.MATERIALIZED_VIEW, + "STREAMING_TABLE": UnityCatalogTableType.STREAMING_TABLE, + "VIEW": UnityCatalogTableType.VIEW, } # For variable substitution in source URLs @@ -114,22 +122,9 @@ URL_TABLE_RE = re.compile(r"{table}") -CatalogSystemTagsTuple = Tuple[List[SystemTag], Dict[str, List[SystemTag]]] -""" -(catalog system tags, schema name -> schema system tags) -""" - - -CatalogSystemTags = Dict[str, CatalogSystemTagsTuple] -""" -catalog name -> (catalog tags, schema name -> schema tags) -""" - - -def to_utc_from_timestamp_ms(timestamp_ms: Optional[int]): - if timestamp_ms is not None: - return to_utc_datetime_from_timestamp(timestamp_ms / 1000) - return None +class CatalogSystemTags(BaseModel): + catalog_tags: List[SystemTag] = [] + schema_name_to_tags: Dict[str, List[SystemTag]] = {} class UnityCatalogExtractor(BaseExtractor): @@ -168,10 +163,11 @@ def __init__(self, config: UnityCatalogRunConfig): self._service_principals: Dict[str, ServicePrincipal] = {} self._last_refresh_time_queue: List[str] = [] + self._table_properties_queue: List[str] = [] # Map fullname or volume path to a dataset self._datasets: Dict[str, Dataset] = {} - self._filter = config.filter.normalize().merge(DEFAULT_FILTER) + self._filter = config.filter.normalize() self._query_log_config = config.query_log self._hierarchies: Dict[str, Hierarchy] = {} self._volumes: Dict[str, VolumeInfo] = {} @@ -182,46 +178,47 @@ async def extract(self) -> Collection[ENTITY_TYPES]: self._service_principals = list_service_principals(self._api) logger.info(f"Found service principals: {self._service_principals}") - catalogs = [ - catalog - for catalog in self._get_catalogs() - if self._filter.include_database(catalog) - ] + catalogs = list_catalogs(self._connection) + for catalog_info in catalogs: + catalog = catalog_info.catalog_name + if not self._filter.include_database(catalog): + logger.info(f"Ignore catalog {catalog} due to filter config") + continue - logger.info(f"Found catalogs: {catalogs}") + self._init_catalog(catalog_info) - for catalog in catalogs: - schemas = self._get_schemas(catalog) - for schema in schemas: + for schema_info in list_schemas(self._connection, catalog): + schema = schema_info.schema_name if not self._filter.include_schema(catalog, schema): logger.info( - f"Ignore schema: {catalog}.{schema} due to filter config" + f"Ignore schema {catalog}.{schema} due to filter config" ) continue + self._init_schema(schema_info) + table_lineage = list_table_lineage(self._connection, catalog, schema) column_lineage = list_column_lineage(self._connection, catalog, schema) - for volume in self._get_volume_infos(catalog, schema): - assert volume.full_name - self._volumes[volume.full_name] = volume - self._init_volume(volume) - self._extract_volume_files(volume) + for volume_info in list_volumes(self._connection, catalog, schema): + self._volumes[volume_info.full_name] = volume_info + self._init_volume(volume_info) + self._extract_volume_files(volume_info) - for table_info in self._get_table_infos(catalog, schema): - table_name = f"{catalog}.{schema}.{table_info.name}" - if table_info.name is None: - logger.error(f"Ignoring table without name: {table_info}") - continue - if not self._filter.include_table(catalog, schema, table_info.name): + for table_info in list_tables(self._connection, catalog, schema): + table = table_info.table_name + table_name = f"{catalog}.{schema}.{table}" + if not self._filter.include_table(catalog, schema, table): logger.info(f"Ignore table: {table_name} due to 
filter config") continue dataset = self._init_dataset(table_info) self._populate_lineage(dataset, table_lineage, column_lineage) - self._fetch_tags(catalogs) + self._propagate_tags() + # Batch query table properties and last refreshed time + self._populate_table_properties() self._populate_last_refreshed_time() entities: List[ENTITY_TYPES] = [] @@ -229,68 +226,6 @@ async def extract(self) -> Collection[ENTITY_TYPES]: entities.extend(self._hierarchies.values()) return entities - def _get_catalogs(self) -> List[str]: - catalogs = list(self._api.catalogs.list()) - json_dump_to_debug_file(catalogs, "list-catalogs.json") - - catalog_names = [] - for catalog in catalogs: - if catalog.name is None: - continue - - catalog_names.append(catalog.name) - if not catalog.owner: - continue - - hierarchy = self._init_hierarchy(catalog.name) - hierarchy.system_contacts = SystemContacts( - contacts=[ - SystemContact( - email=self._get_owner_display_name(catalog.owner), - system_contact_source=AssetPlatform.UNITY_CATALOG, - ) - ] - ) - return catalog_names - - def _get_schemas(self, catalog: str) -> List[str]: - schemas = list(self._api.schemas.list(catalog)) - json_dump_to_debug_file(schemas, f"list-schemas-{catalog}.json") - - schema_names = [] - for schema in schemas: - if schema.name: - schema_names.append(schema.name) - if not schema.owner: - continue - - hierarchy = self._init_hierarchy(catalog, schema.name) - hierarchy.system_contacts = SystemContacts( - contacts=[ - SystemContact( - email=self._get_owner_display_name(schema.owner), - system_contact_source=AssetPlatform.UNITY_CATALOG, - ) - ] - ) - return schema_names - - def _get_table_infos( - self, catalog: str, schema: str - ) -> Generator[TableInfo, None, None]: - tables = list(self._api.tables.list(catalog, schema)) - json_dump_to_debug_file(tables, f"list-tables-{catalog}-{schema}.json") - for table in tables: - yield table - - def _get_volume_infos( - self, catalog: str, schema: str - ) -> Generator[VolumeInfo, None, None]: - volumes = list(self._api.volumes.list(catalog, schema)) - json_dump_to_debug_file(volumes, f"list-volumes-{catalog}-{schema}.json") - for volume in volumes: - yield volume - def _get_table_source_url( self, database: str, schema_name: str, table_name: str ) -> str: @@ -314,11 +249,10 @@ def _get_source_url( return url def _init_dataset(self, table_info: TableInfo) -> Dataset: - assert table_info.catalog_name and table_info.schema_name and table_info.name - table_name = table_info.name + table_name = table_info.table_name schema_name = table_info.schema_name database = table_info.catalog_name - table_type = table_info.table_type + table_type = table_info.type normalized_name = dataset_normalized_name(database, schema_name, table_name) @@ -331,15 +265,7 @@ def _init_dataset(self, table_info: TableInfo) -> Dataset: database=database, schema=schema_name, table=table_name ) - if table_type is None: - raise ValueError(f"Invalid table {table_info.name}, no table_type found") - - fields = [] - if table_info.columns is not None: - fields = [ - extract_schema_field_from_column_info(column_info) - for column_info in table_info.columns - ] + fields = [self._init_column(column) for column in table_info.columns] dataset.schema = DatasetSchema( schema_type=SchemaType.SQL, @@ -347,63 +273,83 @@ def _init_dataset(self, table_info: TableInfo) -> Dataset: fields=fields, sql_schema=SQLSchema( materialization=TABLE_TYPE_MATERIALIZATION_TYPE_MAP.get( - table_type, MaterializationType.TABLE - ), - table_schema=( - table_info.view_definition if 
table_info.view_definition else None + table_type, + MaterializationType.TABLE, ), + table_schema=table_info.view_definition, ), ) - if table_info.table_type in TABLE_TYPE_WITH_HISTORY: + # Queue tables with history for batch query later + if table_type in TABLE_TYPE_WITH_HISTORY: self._last_refresh_time_queue.append(normalized_name) main_url = self._get_table_source_url(database, schema_name, table_name) dataset.source_info = SourceInfo( main_url=main_url, - created_at_source=to_utc_from_timestamp_ms( - timestamp_ms=table_info.created_at - ), + created_at_source=table_info.created_at, + created_by=table_info.created_by, + last_updated=table_info.updated_at, + updated_by=table_info.updated_by, ) dataset.unity_catalog = UnityCatalog( dataset_type=UnityCatalogDatasetType.UNITY_CATALOG_TABLE, table_info=UnityCatalogTableInfo( - type=TABLE_TYPE_MAP.get(table_type, UnityCatalogTableType.UNKNOWN), - data_source_format=( - table_info.data_source_format.value - if table_info.data_source_format is not None - else None + type=TABLE_TYPE_MAP.get( + table_type, + UnityCatalogTableType.UNKNOWN, ), + data_source_format=table_info.data_source_format, storage_location=table_info.storage_location, owner=table_info.owner, - properties=( - [ - KeyValuePair(key=k, value=json.dumps(v)) - for k, v in table_info.properties.items() - ] - if table_info.properties is not None - else [] - ), ), ) - if table_info.owner is not None: - dataset.system_contacts = SystemContacts( - contacts=[ - SystemContact( - email=self._get_owner_display_name(table_info.owner), - system_contact_source=AssetPlatform.UNITY_CATALOG, - ) - ] - ) + # Queue non-view tables for batch query later + if table_info.view_definition is None: + self._table_properties_queue.append(normalized_name) + + dataset.system_contacts = SystemContacts( + contacts=[ + SystemContact( + email=self._get_owner_display_name(table_info.owner), + system_contact_source=AssetPlatform.UNITY_CATALOG, + ) + ] + ) - dataset.system_tags = SystemTags(tags=[]) + dataset.system_tags = SystemTags( + tags=[ + SystemTag( + key=tag.key, + value=tag.value, + system_tag_source=SystemTagSource.UNITY_CATALOG, + ) + for tag in table_info.tags + ] + ) self._datasets[normalized_name] = dataset return dataset + def _init_column(self, column_info: ColumnInfo) -> SchemaField: + field = build_schema_field( + column_name=column_info.column_name, + field_type=column_info.data_type, + description=column_info.comment, + nullable=column_info.is_nullable, + precision=safe_float(column_info.data_precision), + ) + + field.tags = [ + f"{tag.key}={tag.value}" if tag.key else tag.value + for tag in column_info.tags + ] + + return field + def _get_location_url(self, location_name: str): url = f"https://{self._hostname}/explore/location/{location_name}/browse" return url @@ -499,10 +445,12 @@ def _init_hierarchy( self, catalog: str, schema: Optional[str] = None, + owner: Optional[str] = None, + comment: Optional[str] = None, + tags: Optional[List[Tag]] = None, ) -> Hierarchy: path = [part.lower() for part in [catalog, schema] if part] - - return self._hierarchies.setdefault( + hierarchy = self._hierarchies.setdefault( ".".join(path), Hierarchy( logical_id=HierarchyLogicalID( @@ -511,242 +459,96 @@ def _init_hierarchy( ), ) - def _extract_hierarchies(self, catalog_system_tags: CatalogSystemTags) -> None: - for catalog, (catalog_tags, schema_name_to_tag) in catalog_system_tags.items(): - if catalog_tags: - hierarchy = self._init_hierarchy(catalog) - hierarchy.system_tags = SystemTags(tags=catalog_tags) - 
for schema, schema_tags in schema_name_to_tag.items(): - if schema_tags: - hierarchy = self._init_hierarchy(catalog, schema) - hierarchy.system_tags = SystemTags(tags=schema_tags) - - def _fetch_catalog_system_tags(self, catalog: str) -> CatalogSystemTagsTuple: - logger.info(f"Fetching tags for catalog {catalog}") - - with self._connection.cursor() as cursor: - catalog_tags = [] - schema_tags: Dict[str, List[SystemTag]] = defaultdict(list) - catalog_tags_query = f"SELECT tag_name, tag_value FROM {catalog}.information_schema.catalog_tags" - cursor.execute(catalog_tags_query) - for tag_name, tag_value in cursor.fetchall(): - tag = SystemTag( - key=tag_name, - value=tag_value, - system_tag_source=SystemTagSource.UNITY_CATALOG, - ) - catalog_tags.append(tag) - - schema_tags_query = f"SELECT schema_name, tag_name, tag_value FROM {catalog}.information_schema.schema_tags" - cursor.execute(schema_tags_query) - for schema_name, tag_name, tag_value in cursor.fetchall(): - if self._filter.include_schema(catalog, schema_name): - tag = SystemTag( - key=tag_name, - value=tag_value, - system_tag_source=SystemTagSource.UNITY_CATALOG, - ) - schema_tags[schema_name].append(tag) - return catalog_tags, schema_tags - - def _assign_dataset_system_tags( - self, catalog: str, catalog_system_tags: CatalogSystemTags - ) -> None: - for schema in self._api.schemas.list(catalog): - if schema.name: - for table in self._api.tables.list(catalog, schema.name): - normalized_dataset_name = dataset_normalized_name( - catalog, schema.name, table.name + if owner is not None: + hierarchy.system_contacts = SystemContacts( + contacts=[ + SystemContact( + email=self._get_owner_display_name(owner), + system_contact_source=AssetPlatform.UNITY_CATALOG, ) - dataset = self._datasets.get(normalized_dataset_name) - if dataset is not None: - assert dataset.system_tags - dataset.system_tags.tags = ( - catalog_system_tags[catalog][0] - + catalog_system_tags[catalog][1][schema.name] - ) - - def _extract_object_tags( - self, catalog, columns: List[str], tag_schema_name: str - ) -> None: - with self._connection.cursor() as cursor: - query = f"SELECT {', '.join(columns)} FROM {catalog}.information_schema.{tag_schema_name}" - - cursor.execute(query) - for ( - catalog_name, - schema_name, - dataset_name, - tag_name, - tag_value, - ) in cursor.fetchall(): - normalized_dataset_name = dataset_normalized_name( - catalog_name, schema_name, dataset_name - ) - dataset = self._datasets.get(normalized_dataset_name) - - if dataset is None: - logger.warning(f"Cannot find {normalized_dataset_name} dataset") - continue + ] + ) - assert dataset.system_tags and dataset.system_tags.tags is not None + if comment is not None: + hierarchy.system_description = SystemDescription( + description=comment, + platform=AssetPlatform.UNITY_CATALOG, + ) - if tag_value: - tag = SystemTag( - key=tag_name, - system_tag_source=SystemTagSource.UNITY_CATALOG, - value=tag_value, - ) - else: - tag = SystemTag( - key=None, + if tags is not None: + hierarchy.system_tags = SystemTags( + tags=[ + SystemTag( + key=tag.key, + value=tag.value, system_tag_source=SystemTagSource.UNITY_CATALOG, - value=tag_name, ) - dataset.system_tags.tags.append(tag) - - def _extract_table_tags(self, catalog: str) -> None: - self._extract_object_tags( - catalog, - columns=[ - "catalog_name", - "schema_name", - "table_name", - "tag_name", - "tag_value", - ], - tag_schema_name="table_tags", - ) + for tag in tags + ] + ) - def _extract_volume_tags(self, catalog: str) -> None: - self._extract_object_tags( - 
catalog, - columns=[ - "catalog_name", - "schema_name", - "volume_name", - "tag_name", - "tag_value", - ], - tag_schema_name="volume_tags", - ) + return hierarchy - def _extract_column_tags(self, catalog: str) -> None: - with self._connection.cursor() as cursor: - columns = [ - "catalog_name", - "schema_name", - "table_name", - "column_name", - "tag_name", - "tag_value", - ] - query = f"SELECT {', '.join(columns)} FROM {catalog}.information_schema.column_tags" - - cursor.execute(query) - for ( - catalog_name, - schema_name, - table_name, - column_name, - tag_name, - tag_value, - ) in cursor.fetchall(): - normalized_dataset_name = dataset_normalized_name( - catalog_name, schema_name, table_name - ) - dataset = self._datasets.get(normalized_dataset_name) - if dataset is None: - logger.warning(f"Cannot find {normalized_dataset_name} table") - continue + def _init_catalog( + self, + catalog_info: CatalogInfo, + ) -> Hierarchy: + return self._init_hierarchy( + catalog=catalog_info.catalog_name, + owner=catalog_info.owner, + comment=catalog_info.comment, + tags=catalog_info.tags, + ) - tag = f"{tag_name}={tag_value}" if tag_value else tag_name - - assert ( - dataset.schema is not None - ) # Can't be None, we initialized it at `init_dataset` - if dataset.schema.fields: - field = next( - ( - f - for f in dataset.schema.fields - if f.field_name == column_name - ), - None, - ) - if field is not None: - if not field.tags: - field.tags = [] - field.tags.append(tag) - - def _fetch_tags(self, catalogs: List[str]): - catalog_system_tags: CatalogSystemTags = {} - - for catalog in catalogs: - if self._filter.include_database(catalog): - catalog_system_tags[catalog] = self._fetch_catalog_system_tags(catalog) - self._extract_hierarchies(catalog_system_tags) - self._assign_dataset_system_tags(catalog, catalog_system_tags) - self._extract_table_tags(catalog) - self._extract_volume_tags(catalog) - self._extract_column_tags(catalog) + def _init_schema( + self, + schema_info: SchemaInfo, + ) -> Hierarchy: + return self._init_hierarchy( + catalog=schema_info.catalog_name, + schema=schema_info.schema_name, + owner=schema_info.owner, + comment=schema_info.comment, + tags=schema_info.tags, + ) def _extract_volume_files(self, volume: VolumeInfo): + catalog_name = volume.catalog_name + schema_name = volume.schema_name + volume_name = volume.volume_name + volume_dataset = self._datasets.get( - dataset_normalized_name( - volume.catalog_name, volume.schema_name, volume.name - ) + dataset_normalized_name(catalog_name, schema_name, volume_name) ) - if not volume_dataset: + if volume_dataset is None: return - with self._connection.cursor() as cursor: - query = f"LIST '/Volumes/{volume.catalog_name}/{volume.schema_name}/{volume.name}'" + volume_entity_id = str( + to_dataset_entity_id_from_logical_id(volume_dataset.logical_id) + ) + + for volume_file_info in list_volume_files(self._connection, volume): + volume_file = self._init_volume_file(volume_file_info, volume_entity_id) + assert volume_dataset.unity_catalog.volume_info.volume_files is not None - cursor.execute(query) - for path, name, size, modification_time in cursor.fetchall(): - last_updated = to_utc_from_timestamp_ms(timestamp_ms=modification_time) - volume_file = self._init_volume_file( - path, - size, - last_updated, + volume_dataset.unity_catalog.volume_info.volume_files.append( + VolumeFile( + modification_time=volume_file_info.last_updated, + name=volume_file_info.name, + path=volume_file_info.path, + size=volume_file_info.size, entity_id=str( - 
to_dataset_entity_id_from_logical_id(volume_dataset.logical_id) + to_dataset_entity_id_from_logical_id(volume_file.logical_id) ), ) - - if volume_dataset and volume_file: - assert ( - volume_dataset.unity_catalog.volume_info.volume_files - is not None - ) - volume_dataset.unity_catalog.volume_info.volume_files.append( - VolumeFile( - modification_time=last_updated, - name=name, - path=path, - size=float(size), - entity_id=str( - to_dataset_entity_id_from_logical_id( - volume_file.logical_id - ) - ), - ) - ) + ) def _init_volume(self, volume: VolumeInfo): - assert ( - volume.volume_type - and volume.schema_name - and volume.catalog_name - and volume.name - ) - schema_name = volume.schema_name catalog_name = volume.catalog_name - name = volume.name - full_name = dataset_normalized_name(catalog_name, schema_name, name) + volume_name = volume.volume_name + full_name = dataset_normalized_name(catalog_name, schema_name, volume_name) dataset = Dataset() dataset.logical_id = DatasetLogicalID( @@ -755,15 +557,15 @@ def _init_volume(self, volume: VolumeInfo): ) dataset.structure = DatasetStructure( - database=catalog_name, schema=schema_name, table=volume.name + database=catalog_name, schema=schema_name, table=volume_name ) - main_url = self._get_volume_source_url(catalog_name, schema_name, name) + main_url = self._get_volume_source_url(catalog_name, schema_name, volume_name) dataset.source_info = SourceInfo( main_url=main_url, - last_updated=to_utc_from_timestamp_ms(timestamp_ms=volume.updated_at), - created_at_source=to_utc_from_timestamp_ms(timestamp_ms=volume.created_at), + created_at_source=volume.created_at, created_by=volume.created_by, + last_updated=volume.updated_at, updated_by=volume.updated_by, ) @@ -784,14 +586,23 @@ def _init_volume(self, volume: VolumeInfo): dataset.unity_catalog = UnityCatalog( dataset_type=UnityCatalogDatasetType.UNITY_CATALOG_VOLUME, volume_info=UnityCatalogVolumeInfo( - type=UnityCatalogVolumeType[volume.volume_type.value], + type=UnityCatalogVolumeType[volume.volume_type], volume_files=[], storage_location=volume.storage_location, ), ) dataset.entity_upstream = EntityUpstream(source_entities=[]) - dataset.system_tags = SystemTags(tags=[]) + dataset.system_tags = SystemTags( + tags=[ + SystemTag( + key=tag.key, + value=tag.value, + system_tag_source=SystemTagSource.UNITY_CATALOG, + ) + for tag in volume.tags + ] + ) self._datasets[full_name] = dataset @@ -799,36 +610,57 @@ def _init_volume(self, volume: VolumeInfo): def _init_volume_file( self, - path: str, - size: int, - last_updated: Optional[datetime], - entity_id: str, - ) -> Optional[Dataset]: + volume_file_info: VolumeFileInfo, + volume_entity_id: str, + ) -> Dataset: dataset = Dataset() dataset.logical_id = DatasetLogicalID( # We use path as ID for file - name=path, + name=volume_file_info.path, platform=DataPlatform.UNITY_CATALOG_VOLUME_FILE, ) - if last_updated: - dataset.source_info = SourceInfo(last_updated=last_updated) + dataset.source_info = SourceInfo(last_updated=volume_file_info.last_updated) dataset.unity_catalog = UnityCatalog( dataset_type=UnityCatalogDatasetType.UNITY_CATALOG_VOLUME_FILE, - volume_entity_id=entity_id, + volume_entity_id=volume_entity_id, ) dataset.statistics = to_dataset_statistics( - size_bytes=size, + size_bytes=volume_file_info.size, ) dataset.entity_upstream = EntityUpstream(source_entities=[]) - self._datasets[path] = dataset + self._datasets[volume_file_info.path] = dataset return dataset + def _propagate_tags(self): + """Propagate tags from catalogs and schemas to 
tables & volumes""" + + for dataset in self._datasets.values(): + tags = [] + + if dataset.structure is None: + continue + + catalog_name = dataset.structure.database.lower() + catalog = self._hierarchies.get(catalog_name) + if catalog is not None and catalog.system_tags is not None: + tags.extend(catalog.system_tags.tags) + + schema_name = dataset.structure.schema.lower() + schema = self._hierarchies.get(f"{catalog_name}.{schema_name}") + if schema is not None and schema.system_tags is not None: + tags.extend(schema.system_tags.tags) + + if dataset.system_tags is not None: + tags.extend(dataset.system_tags.tags) + + dataset.system_tags = SystemTags(tags=tags) + def _populate_last_refreshed_time(self): connection_pool = create_connection_pool( self._token, self._hostname, self._http_path, self._max_concurrency @@ -849,6 +681,29 @@ def _populate_last_refreshed_time(self): last_updated=last_refreshed_time, ) + def _populate_table_properties(self): + connection_pool = create_connection_pool( + self._token, self._hostname, self._http_path, self._max_concurrency + ) + + result_map = batch_get_table_properties( + connection_pool, + self._table_properties_queue, + ) + + for name, properties in result_map.items(): + dataset = self._datasets.get(name) + if ( + dataset is None + or dataset.unity_catalog is None + or dataset.unity_catalog.table_info is None + ): + continue + + dataset.unity_catalog.table_info.properties = [ + KeyValuePair(key=k, value=v) for k, v in properties.items() + ] + def _get_owner_display_name(self, user_id: str) -> str: # Unity Catalog returns service principal's application_id and must be # manually map back to display_name diff --git a/metaphor/unity_catalog/models.py b/metaphor/unity_catalog/models.py index 38a721e0..443c6a98 100644 --- a/metaphor/unity_catalog/models.py +++ b/metaphor/unity_catalog/models.py @@ -1,29 +1,8 @@ -from typing import Dict, List +from datetime import datetime +from typing import Dict, List, Literal, Optional -from databricks.sdk.service.catalog import ColumnInfo from pydantic import BaseModel -from metaphor.common.fieldpath import build_schema_field -from metaphor.common.logger import get_logger -from metaphor.models.metadata_change_event import SchemaField - -logger = get_logger() - - -def extract_schema_field_from_column_info(column: ColumnInfo) -> SchemaField: - if column.name is None or column.type_name is None: - raise ValueError(f"Invalid column {column.name}, no type_name found") - - field = build_schema_field( - column.name, column.type_name.value.lower(), column.comment - ) - field.precision = ( - float(column.type_precision) - if column.type_precision is not None - else float("nan") - ) - return field - class TableLineage(BaseModel): upstream_tables: List[str] = [] @@ -36,3 +15,73 @@ class Column(BaseModel): class ColumnLineage(BaseModel): upstream_columns: Dict[str, List[Column]] = {} + + +class Tag(BaseModel): + key: str + value: str + + +class CatalogInfo(BaseModel): + catalog_name: str + owner: str + comment: Optional[str] = None + tags: List[Tag] + + +class SchemaInfo(BaseModel): + catalog_name: str + schema_name: str + owner: str + comment: Optional[str] = None + tags: List[Tag] + + +class ColumnInfo(BaseModel): + column_name: str + data_type: str + data_precision: Optional[int] + is_nullable: bool + comment: Optional[str] = None + tags: List[Tag] + + +class TableInfo(BaseModel): + catalog_name: str + schema_name: str + table_name: str + type: str + owner: str + comment: Optional[str] = None + created_at: datetime + created_by: 
str + updated_at: datetime + updated_by: str + view_definition: Optional[str] = None + storage_location: Optional[str] = None + data_source_format: str + tags: List[Tag] = [] + columns: List[ColumnInfo] = [] + + +class VolumeInfo(BaseModel): + catalog_name: str + schema_name: str + volume_name: str + volume_type: Literal["MANAGED", "EXTERNAL"] + full_name: str + owner: str + comment: Optional[str] = None + created_at: datetime + created_by: str + updated_at: datetime + updated_by: str + storage_location: str + tags: List[Tag] + + +class VolumeFileInfo(BaseModel): + last_updated: datetime + name: str + path: str + size: float diff --git a/metaphor/unity_catalog/profile/extractor.py b/metaphor/unity_catalog/profile/extractor.py index 32909037..62cb056a 100644 --- a/metaphor/unity_catalog/profile/extractor.py +++ b/metaphor/unity_catalog/profile/extractor.py @@ -16,6 +16,7 @@ from metaphor.common.entity_id import normalize_full_dataset_name from metaphor.common.event_util import ENTITY_TYPES from metaphor.common.fieldpath import build_field_statistics +from metaphor.common.filter import DatasetFilter from metaphor.common.logger import get_logger from metaphor.common.utils import safe_float from metaphor.models.crawler_run_metadata import Platform @@ -26,7 +27,6 @@ DatasetLogicalID, DatasetStatistics, ) -from metaphor.unity_catalog.extractor import DEFAULT_FILTER from metaphor.unity_catalog.profile.config import UnityCatalogProfileRunConfig from metaphor.unity_catalog.utils import ( create_api, @@ -36,6 +36,14 @@ logger = get_logger() +# Filter out "system" database & all "information_schema" schemas +DEFAULT_FILTER: DatasetFilter = DatasetFilter( + excludes={ + "system": None, + "*": {"information_schema": None}, + } +) + NON_MODIFICATION_OPERATIONS = { "SET TBLPROPERTIES", "ADD CONSTRAINT", diff --git a/metaphor/unity_catalog/queries.py b/metaphor/unity_catalog/queries.py new file mode 100644 index 00000000..3f9deba9 --- /dev/null +++ b/metaphor/unity_catalog/queries.py @@ -0,0 +1,715 @@ +from datetime import datetime +from typing import Collection, Dict, List, Optional, Tuple + +from databricks.sql.client import Connection + +from metaphor.common.logger import get_logger, json_dump_to_debug_file +from metaphor.common.utils import start_of_day +from metaphor.unity_catalog.models import ( + CatalogInfo, + Column, + ColumnInfo, + ColumnLineage, + SchemaInfo, + TableInfo, + TableLineage, + Tag, + VolumeFileInfo, + VolumeInfo, +) + +logger = get_logger() + + +TableLineageMap = Dict[str, TableLineage] +"""Map a table's full name to its table lineage""" + +ColumnLineageMap = Dict[str, ColumnLineage] +"""Map a table's full name to its column lineage""" + +IGNORED_HISTORY_OPERATIONS = { + "ADD CONSTRAINT", + "CHANGE COLUMN", + "LIQUID TAGGING", + "OPTIMIZE", + "SET TBLPROPERTIES", +} +"""These are the operations that do not modify actual data.""" + + +def to_tags(tags: Optional[List[Dict]]) -> List[Tag]: + return [Tag(key=tag["tag_name"], value=tag["tag_value"]) for tag in tags or []] + + +def list_catalogs(connection: Connection) -> List[CatalogInfo]: + """ + Fetch catalogs from system.access.information_schema + See + - https://docs.databricks.com/en/sql/language-manual/information-schema/catalogs.html + - https://docs.databricks.com/en/sql/language-manual/information-schema/catalog_tags.html + """ + catalogs: List[CatalogInfo] = [] + + with connection.cursor() as cursor: + query = """ + WITH c AS ( + SELECT + catalog_name, + catalog_owner, + comment + FROM system.information_schema.catalogs + 
WHERE catalog_name <> 'system' + ), + + t AS ( + SELECT + catalog_name, + struct(tag_name, tag_value) as tag + FROM system.information_schema.catalog_tags + WHERE catalog_name <> 'system' + ) + + SELECT + c.catalog_name AS catalog_name, + first(c.catalog_owner) AS catalog_owner, + first(c.comment) AS comment, + collect_list(t.tag) AS tags + FROM c + LEFT JOIN t + ON c.catalog_name = t.catalog_name + GROUP BY c.catalog_name + ORDER by c.catalog_name + """ + + try: + cursor.execute(query) + except Exception as error: + logger.exception(f"Failed to list catalogs: {error}") + return [] + + for row in cursor.fetchall(): + catalogs.append( + CatalogInfo( + catalog_name=row["catalog_name"], + owner=row["catalog_owner"], + comment=row["comment"], + tags=to_tags(row["tags"]), + ) + ) + + logger.info(f"Found {len(catalogs)} catalogs") + json_dump_to_debug_file(catalogs, "list_catalogs.json") + return catalogs + + +def list_schemas(connection: Connection, catalog: str) -> List[SchemaInfo]: + """ + Fetch schemas for a specific catalog from system.access.information_schema + See + - https://docs.databricks.com/en/sql/language-manual/information-schema/schemata.html + - https://docs.databricks.com/en/sql/language-manual/information-schema/schema_tags.html + """ + schemas: List[SchemaInfo] = [] + + with connection.cursor() as cursor: + query = """ + WITH s AS ( + SELECT + catalog_name, + schema_name, + schema_owner, + comment + FROM system.information_schema.schemata + WHERE catalog_name = %(catalog)s AND schema_name <> 'information_schema' + ), + + t AS ( + SELECT + catalog_name, + schema_name, + struct(tag_name, tag_value) as tag + FROM system.information_schema.schema_tags + WHERE catalog_name = %(catalog)s AND schema_name <> 'information_schema' + ) + + SELECT + first(s.catalog_name) AS catalog_name, + s.schema_name AS schema_name, + first(s.schema_owner) AS schema_owner, + first(s.comment) AS comment, + collect_list(t.tag) AS tags + FROM s + LEFT JOIN t + ON s.catalog_name = t.catalog_name AND s.schema_name = t.schema_name + GROUP BY s.schema_name + ORDER by s.schema_name + """ + + try: + cursor.execute(query, {"catalog": catalog}) + except Exception as error: + logger.exception(f"Failed to list schemas for {catalog}: {error}") + return [] + + for row in cursor.fetchall(): + schemas.append( + SchemaInfo( + catalog_name=row["catalog_name"], + schema_name=row["schema_name"], + owner=row["schema_owner"], + comment=row["comment"], + tags=to_tags(row["tags"]), + ) + ) + + logger.info(f"Found {len(schemas)} schemas from {catalog}") + json_dump_to_debug_file(schemas, f"list_schemas_{catalog}.json") + return schemas + + +def list_tables(connection: Connection, catalog: str, schema: str) -> List[TableInfo]: + """ + Fetch tables for a specific schema from system.access.information_schema + See + - https://docs.databricks.com/en/sql/language-manual/information-schema/tables.html + - https://docs.databricks.com/en/sql/language-manual/information-schema/views.html + - https://docs.databricks.com/en/sql/language-manual/information-schema/table_tags.html + - https://docs.databricks.com/en/sql/language-manual/information-schema/columns.html + - https://docs.databricks.com/en/sql/language-manual/information-schema/column_tags.html + """ + + tables: List[TableInfo] = [] + with connection.cursor() as cursor: + query = """ + WITH + t AS ( + SELECT + table_catalog, + table_schema, + table_name, + table_type, + table_owner, + comment, + data_source_format, + storage_path, + created, + created_by, + last_altered, + 
last_altered_by + FROM system.information_schema.tables + WHERE + table_catalog = %(catalog)s AND + table_schema = %(schema)s + ), + + tt AS ( + SELECT + catalog_name AS table_catalog, + schema_name AS table_schema, + table_name AS table_name, + collect_list(struct(tag_name, tag_value)) as tags + FROM system.information_schema.table_tags + WHERE + catalog_name = %(catalog)s AND + schema_name = %(schema)s + GROUP BY catalog_name, schema_name, table_name + ), + + v AS ( + SELECT + table_catalog, + table_schema, + table_name, + view_definition + FROM system.information_schema.views + WHERE + table_catalog = %(catalog)s AND + table_schema = %(schema)s + ), + + tf AS ( + SELECT + t.table_catalog, + t.table_schema, + t.table_name, + t.table_type, + t.table_owner, + t.comment, + t.data_source_format, + t.storage_path, + t.created, + t.created_by, + t.last_altered, + t.last_altered_by, + v.view_definition, + tt.tags + FROM t + LEFT JOIN v + ON + t.table_catalog = v.table_catalog AND + t.table_schema = v.table_schema AND + t.table_name = v.table_name + LEFT JOIN tt + ON + t.table_catalog = tt.table_catalog AND + t.table_schema = tt.table_schema AND + t.table_name = tt.table_name + ), + + c AS ( + SELECT + table_catalog, + table_schema, + table_name, + column_name, + data_type, + CASE + WHEN numeric_precision IS NOT NULL THEN numeric_precision + WHEN datetime_precision IS NOT NULL THEN datetime_precision + ELSE NULL + END AS data_precision, + is_nullable, + comment + FROM system.information_schema.columns + WHERE + table_catalog = %(catalog)s AND + table_schema = %(schema)s + ), + + ct AS ( + SELECT + catalog_name AS table_catalog, + schema_name AS table_schema, + table_name, + column_name, + collect_list(struct(tag_name, tag_value)) as tags + FROM system.information_schema.column_tags + WHERE + catalog_name = %(catalog)s AND + schema_name = %(schema)s + GROUP BY catalog_name, schema_name, table_name, column_name + ), + + cf AS ( + SELECT + c.table_catalog, + c.table_schema, + c.table_name, + collect_list(struct( + c.column_name, + c.data_type, + c.data_precision, + c.is_nullable, + c.comment, + ct.tags + )) as columns + FROM c + LEFT JOIN ct + ON + c.table_catalog = ct.table_catalog AND + c.table_schema = ct.table_schema AND + c.table_name = ct.table_name AND + c.column_name = ct.column_name + GROUP BY c.table_catalog, c.table_schema, c.table_name + ) + + SELECT + tf.table_catalog AS catalog_name, + tf.table_schema AS schema_name, + tf.table_name AS table_name, + tf.table_type AS table_type, + tf.table_owner AS owner, + tf.comment AS table_comment, + tf.data_source_format AS data_source_format, + tf.storage_path AS storage_path, + tf.created AS created_at, + tf.created_by AS created_by, + tf.last_altered as updated_at, + tf.last_altered_by AS updated_by, + tf.view_definition AS view_definition, + tf.tags AS tags, + cf.columns AS columns + FROM tf + LEFT JOIN cf + ON + tf.table_catalog = cf.table_catalog AND + tf.table_schema = cf.table_schema AND + tf.table_name = cf.table_name + ORDER by tf.table_catalog, tf.table_schema, tf.table_name + """ + + try: + cursor.execute(query, {"catalog": catalog, "schema": schema}) + except Exception as error: + logger.exception(f"Failed to list tables for {catalog}.{schema}: {error}") + return [] + + for row in cursor.fetchall(): + columns = [ + ColumnInfo( + column_name=column["column_name"], + data_type=column["data_type"], + data_precision=column["data_precision"], + is_nullable=column["is_nullable"] == "YES", + comment=column["comment"], + 
tags=to_tags(column["tags"]), + ) + for column in row["columns"] + ] + + tables.append( + TableInfo( + catalog_name=row["catalog_name"], + schema_name=row["schema_name"], + table_name=row["table_name"], + type=row["table_type"], + owner=row["owner"], + comment=row["table_comment"], + created_at=row["created_at"], + created_by=row["created_by"], + updated_at=row["updated_at"], + updated_by=row["updated_by"], + data_source_format=row["data_source_format"], + view_definition=row["view_definition"], + storage_location=row["storage_path"], + columns=columns, + ), + ) + + logger.info(f"Found {len(tables)} tables from {catalog}") + json_dump_to_debug_file(tables, f"list_tables_{catalog}_{schema}.json") + return tables + + +def list_volumes(connection: Connection, catalog: str, schema: str) -> List[VolumeInfo]: + """ + Fetch volumes for a specific catalog from system.access.information_schema + See + - https://docs.databricks.com/en/sql/language-manual/information-schema/volumes.html + - https://docs.databricks.com/en/sql/language-manual/information-schema/volume_tags.html + """ + volumes: List[VolumeInfo] = [] + + with connection.cursor() as cursor: + query = """ + WITH v AS ( + SELECT + volume_catalog, + volume_schema, + volume_name, + volume_type, + volume_owner, + comment, + created, + created_by, + last_altered, + last_altered_by, + storage_location + FROM system.information_schema.volumes + WHERE volume_catalog = %(catalog)s AND volume_schema = %(schema)s + ), + + t AS ( + SELECT + catalog_name, + schema_name, + volume_name, + struct(tag_name, tag_value) as tag + FROM system.information_schema.volume_tags + WHERE catalog_name = %(catalog)s AND schema_name = %(schema)s + ) + + SELECT + first(v.volume_catalog) AS volume_catalog, + first(v.volume_schema) AS volume_schema, + v.volume_name AS volume_name, + first(v.volume_type) AS volume_type, + first(v.volume_owner) AS volume_owner, + first(v.comment) AS comment, + first(v.created) AS created, + first(v.created_by) AS created_by, + first(v.last_altered) AS last_altered, + first(v.last_altered_by) AS last_altered_by, + first(v.storage_location) AS storage_location, + collect_list(t.tag) AS tags + FROM v + LEFT JOIN t + ON + v.volume_catalog = t.catalog_name AND + v.volume_schema = t.schema_name AND + v.volume_name = t.volume_name + GROUP BY v.volume_name + ORDER BY v.volume_name + """ + + try: + cursor.execute(query, {"catalog": catalog, "schema": schema}) + except Exception as error: + logger.exception(f"Failed to list volumes for {catalog}.{schema}: {error}") + return [] + + for row in cursor.fetchall(): + volumes.append( + VolumeInfo( + catalog_name=row["volume_catalog"], + schema_name=row["volume_schema"], + volume_name=row["volume_name"], + full_name=f"{row['volume_catalog']}.{row['volume_schema']}.{row['volume_name']}".lower(), + volume_type=row["volume_type"], + owner=row["volume_owner"], + comment=row["comment"], + created_at=row["created"], + created_by=row["created_by"], + updated_at=row["last_altered"], + updated_by=row["last_altered_by"], + storage_location=row["storage_location"], + tags=to_tags(row["tags"]), + ) + ) + + logger.info(f"Found {len(volumes)} volumes from {catalog}.{schema}") + json_dump_to_debug_file(volumes, f"list_volumes_{catalog}_{schema}.json") + return volumes + + +def list_volume_files( + connection: Connection, volume_info: VolumeInfo +) -> List[VolumeFileInfo]: + """ + List files in a volume + See https://docs.databricks.com/en/sql/language-manual/sql-ref-syntax-aux-list.html + """ + + catalog_name = 
volume_info.catalog_name + schema_name = volume_info.schema_name + volume_name = volume_info.volume_name + + volume_files: List[VolumeFileInfo] = [] + + with connection.cursor() as cursor: + query = f"LIST '/Volumes/{catalog_name}/{schema_name}/{volume_name}'" + + try: + cursor.execute(query) + except Exception as error: + logger.exception( + f"Failed to list files in {volume_info.full_name}: {error}" + ) + return [] + + for row in cursor.fetchall(): + volume_files.append( + VolumeFileInfo( + last_updated=row["modification_time"], + name=row["name"], + path=row["path"], + size=float(row["size"]), + ) + ) + + logger.info(f"Found {len(volume_files)} files in {volume_info.full_name}") + json_dump_to_debug_file( + volume_files, + f"list_volume_files_{catalog_name}_{schema_name}_{volume_name}.json", + ) + return volume_files + + +def list_table_lineage( + connection: Connection, catalog: str, schema: str, lookback_days=7 +) -> TableLineageMap: + """ + Fetch table lineage for a specific schema from system.access.table_lineage table + See https://docs.databricks.com/en/admin/system-tables/lineage.html for more details + """ + + table_lineage: Dict[str, TableLineage] = {} + + with connection.cursor() as cursor: + query = f""" + SELECT + source_table_full_name, + target_table_full_name + FROM system.access.table_lineage + WHERE + target_table_catalog = '{catalog}' AND + target_table_schema = '{schema}' AND + source_table_full_name IS NOT NULL AND + event_time > date_sub(now(), {lookback_days}) + GROUP BY + source_table_full_name, + target_table_full_name + """ + + try: + cursor.execute(query) + except Exception as error: + logger.exception( + f"Failed to list table lineage for {catalog}.{schema}: {error}" + ) + return {} + + for source_table, target_table in cursor.fetchall(): + lineage = table_lineage.setdefault(target_table.lower(), TableLineage()) + lineage.upstream_tables.append(source_table.lower()) + + logger.info( + f"Fetched table lineage for {len(table_lineage)} tables in {catalog}.{schema}" + ) + json_dump_to_debug_file(table_lineage, f"table_lineage_{catalog}_{schema}.json") + return table_lineage + + +def list_column_lineage( + connection: Connection, catalog: str, schema: str, lookback_days=7 +) -> ColumnLineageMap: + """ + Fetch column lineage for a specific schema from system.access.table_lineage table + See https://docs.databricks.com/en/admin/system-tables/lineage.html for more details + """ + column_lineage: Dict[str, ColumnLineage] = {} + + with connection.cursor() as cursor: + query = f""" + SELECT + source_table_full_name, + source_column_name, + target_table_full_name, + target_column_name + FROM system.access.column_lineage + WHERE + target_table_catalog = '{catalog}' AND + target_table_schema = '{schema}' AND + source_table_full_name IS NOT NULL AND + event_time > date_sub(now(), {lookback_days}) + GROUP BY + source_table_full_name, + source_column_name, + target_table_full_name, + target_column_name + """ + + try: + cursor.execute(query) + except Exception as error: + logger.exception( + f"Failed to list column lineage for {catalog}.{schema}: {error}" + ) + return {} + + cursor.execute(query) + + for ( + source_table, + source_column, + target_table, + target_column, + ) in cursor.fetchall(): + lineage = column_lineage.setdefault(target_table.lower(), ColumnLineage()) + columns = lineage.upstream_columns.setdefault(target_column.lower(), []) + columns.append( + Column( + table_name=source_table.lower(), column_name=source_column.lower() + ) + ) + + logger.info( + f"Fetched 
column lineage for {len(column_lineage)} tables in {catalog}.{schema}" + ) + json_dump_to_debug_file(column_lineage, f"column_lineage_{catalog}_{schema}.json") + return column_lineage + + +def list_query_logs( + connection: Connection, lookback_days: int, excluded_usernames: Collection[str] +): + """ + Fetch query logs from system.query.history table + See https://docs.databricks.com/en/admin/system-tables/query-history.html + """ + start = start_of_day(lookback_days) + end = start_of_day() + + user_condition = ",".join([f"'{user}'" for user in excluded_usernames]) + user_filter = f"Q.executed_by IN ({user_condition}) AND" if user_condition else "" + + with connection.cursor() as cursor: + query = f""" + SELECT + statement_id as query_id, + executed_by as email, + start_time, + int(total_task_duration_ms/1000) as duration, + read_rows as rows_read, + produced_rows as rows_written, + read_bytes as bytes_read, + written_bytes as bytes_written, + statement_type as query_type, + statement_text as query_text + FROM system.query.history + WHERE + {user_filter} + execution_status = 'FINISHED' AND + start_time >= ? AND + start_time < ? + """ + + try: + cursor.execute(query, [start, end]) + except Exception as error: + logger.exception(f"Failed to list query logs: {error}") + return [] + + return cursor.fetchall() + + +def get_last_refreshed_time( + connection: Connection, + table_full_name: str, + limit: int, +) -> Optional[Tuple[str, datetime]]: + """ + Retrieve the last refresh time for a table + See https://docs.databricks.com/en/delta/history.html + """ + + with connection.cursor() as cursor: + try: + cursor.execute(f"DESCRIBE HISTORY {table_full_name} LIMIT {limit}") + except Exception as error: + logger.exception(f"Failed to get history for {table_full_name}: {error}") + return None + + for history in cursor.fetchall(): + operation = history["operation"] + if operation not in IGNORED_HISTORY_OPERATIONS: + logger.info( + f"Fetched last refresh time for {table_full_name} ({operation})" + ) + return (table_full_name, history["timestamp"]) + + return None + + +def get_table_properties( + connection: Connection, + table_full_name: str, +) -> Optional[Tuple[str, Dict[str, str]]]: + """ + Retrieve the properties for a table + See https://docs.databricks.com/en/sql/language-manual/sql-ref-syntax-aux-show-tblproperties.html + """ + properties: Dict[str, str] = {} + with connection.cursor() as cursor: + try: + cursor.execute(f"SHOW TBLPROPERTIES {table_full_name}") + except Exception as error: + logger.exception( + f"Failed to show table properties for {table_full_name}: {error}" + ) + return None + + for row in cursor.fetchall(): + properties[row["key"]] = row["value"] + + return (table_full_name, properties) diff --git a/metaphor/unity_catalog/utils.py b/metaphor/unity_catalog/utils.py index c711b4a4..bdf1906b 100644 --- a/metaphor/unity_catalog/utils.py +++ b/metaphor/unity_catalog/utils.py @@ -1,7 +1,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime from queue import Queue -from typing import Collection, Dict, List, Optional, Set, Tuple +from typing import Dict, List, Optional, Set from databricks import sql from databricks.sdk import WorkspaceClient @@ -16,180 +16,24 @@ from metaphor.common.sql.table_level_lineage.table_level_lineage import ( extract_table_level_lineage, ) -from metaphor.common.utils import is_email, safe_float, start_of_day -from metaphor.models.metadata_change_event import DataPlatform, Dataset, QueriedDataset -from 
metaphor.unity_catalog.models import Column, ColumnLineage, TableLineage +from metaphor.common.utils import is_email, safe_float +from metaphor.models.metadata_change_event import ( + DataPlatform, + Dataset, + QueriedDataset, + SystemTag, + SystemTags, + SystemTagSource, +) +from metaphor.unity_catalog.models import Tag +from metaphor.unity_catalog.queries import ( + get_last_refreshed_time, + get_table_properties, + list_query_logs, +) logger = get_logger() -# Map a table's full name to its table lineage -TableLineageMap = Dict[str, TableLineage] - -# Map a table's full name to its column lineage -ColumnLineageMap = Dict[str, ColumnLineage] - -IGNORED_HISTORY_OPERATIONS = { - "ADD CONSTRAINT", - "CHANGE COLUMN", - "LIQUID TAGGING", - "OPTIMIZE", - "SET TBLPROPERTIES", -} -"""These are the operations that do not modify actual data.""" - - -def list_table_lineage( - connection: Connection, catalog: str, schema: str, lookback_days=7 -) -> TableLineageMap: - """ - Fetch table lineage for a specific schema from system.access.table_lineage table - See https://docs.databricks.com/en/admin/system-tables/lineage.html for more details - """ - - with connection.cursor() as cursor: - query = f""" - SELECT - source_table_full_name, - target_table_full_name - FROM system.access.table_lineage - WHERE - target_table_catalog = '{catalog}' AND - target_table_schema = '{schema}' AND - source_table_full_name IS NOT NULL AND - event_time > date_sub(now(), {lookback_days}) - GROUP BY - source_table_full_name, - target_table_full_name - """ - cursor.execute(query) - - table_lineage: Dict[str, TableLineage] = {} - for source_table, target_table in cursor.fetchall(): - lineage = table_lineage.setdefault(target_table.lower(), TableLineage()) - lineage.upstream_tables.append(source_table.lower()) - - logger.info( - f"Fetched table lineage for {len(table_lineage)} tables in {catalog}.{schema}" - ) - json_dump_to_debug_file(table_lineage, f"table_lineage_{catalog}_{schema}.json") - return table_lineage - - -def list_column_lineage( - connection: Connection, catalog: str, schema: str, lookback_days=7 -) -> ColumnLineageMap: - """ - Fetch column lineage for a specific schema from system.access.table_lineage table - See https://docs.databricks.com/en/admin/system-tables/lineage.html for more details - """ - - with connection.cursor() as cursor: - query = f""" - SELECT - source_table_full_name, - source_column_name, - target_table_full_name, - target_column_name - FROM system.access.column_lineage - WHERE - target_table_catalog = '{catalog}' AND - target_table_schema = '{schema}' AND - source_table_full_name IS NOT NULL AND - event_time > date_sub(now(), {lookback_days}) - GROUP BY - source_table_full_name, - source_column_name, - target_table_full_name, - target_column_name - """ - cursor.execute(query) - - column_lineage: Dict[str, ColumnLineage] = {} - for ( - source_table, - source_column, - target_table, - target_column, - ) in cursor.fetchall(): - lineage = column_lineage.setdefault(target_table.lower(), ColumnLineage()) - columns = lineage.upstream_columns.setdefault(target_column.lower(), []) - columns.append( - Column( - table_name=source_table.lower(), column_name=source_column.lower() - ) - ) - - logger.info( - f"Fetched column lineage for {len(column_lineage)} tables in {catalog}.{schema}" - ) - json_dump_to_debug_file(column_lineage, f"column_lineage_{catalog}_{schema}.json") - return column_lineage - - -def list_query_logs( - connection: Connection, lookback_days: int, excluded_usernames: Collection[str] 
-): - """ - Fetch query logs from system.query.history table - See https://docs.databricks.com/en/admin/system-tables/query-history.html - """ - start = start_of_day(lookback_days) - end = start_of_day() - - user_condition = ",".join([f"'{user}'" for user in excluded_usernames]) - user_filter = f"Q.executed_by IN ({user_condition}) AND" if user_condition else "" - - with connection.cursor() as cursor: - query = f""" - SELECT - statement_id as query_id, - executed_by as email, - start_time, - int(total_task_duration_ms/1000) as duration, - read_rows as rows_read, - produced_rows as rows_written, - read_bytes as bytes_read, - written_bytes as bytes_written, - statement_type as query_type, - statement_text as query_text - FROM system.query.history - WHERE - {user_filter} - execution_status = 'FINISHED' AND - start_time >= ? AND - start_time < ? - """ - cursor.execute(query, [start, end]) - return cursor.fetchall() - - -def get_last_refreshed_time( - connection: Connection, - table_full_name: str, - limit: int, -) -> Optional[Tuple[str, datetime]]: - """ - Retrieve the last refresh time for a table - See https://docs.databricks.com/en/delta/history.html - """ - - with connection.cursor() as cursor: - try: - cursor.execute(f"DESCRIBE HISTORY {table_full_name} LIMIT {limit}") - except Exception as error: - logger.exception(f"Failed to get history for {table_full_name}: {error}") - return None - - for history in cursor.fetchall(): - operation = history["operation"] - if operation not in IGNORED_HISTORY_OPERATIONS: - logger.info( - f"Fetched last refresh time for {table_full_name} ({operation})" - ) - return (table_full_name, history["timestamp"]) - - return None - def batch_get_last_refreshed_time( connection_pool: Queue, @@ -231,6 +75,44 @@ def get_last_refreshed_time_helper(table_full_name: str): return result_map +def batch_get_table_properties( + connection_pool: Queue, + table_full_names: List[str], +) -> Dict[str, Dict[str, str]]: + result_map: Dict[str, Dict[str, str]] = {} + + with ThreadPoolExecutor(max_workers=connection_pool.maxsize) as executor: + + def get_table_properties_helper(table_full_name: str): + connection = connection_pool.get() + result = get_table_properties(connection, table_full_name) + connection_pool.put(connection) + return result + + futures = { + executor.submit( + get_table_properties_helper, + table_full_name, + ): table_full_name + for table_full_name in table_full_names + } + + for future in as_completed(futures): + try: + result = future.result() + if result is None: + continue + + table_full_name, properties = result + result_map[table_full_name] = properties + except Exception: + logger.exception( + f"Not able to get table properties for {futures[future]}" + ) + + return result_map + + SPECIAL_CHARACTERS = "&*{}[],=-()+;'\"`" """ The special characters mentioned in Databricks documentation are: @@ -360,7 +242,7 @@ def find_qualified_dataset(dataset: QueriedDataset, datasets: Dict[str, Dataset] return None -def to_query_log_with_tll( +def _to_query_log_with_tll( row: Row, service_principals: Dict[str, ServicePrincipal], datasets: Dict[str, Dataset], @@ -436,7 +318,7 @@ def get_query_logs( count = 0 logger.info(f"{len(rows)} queries to fetch") for row in rows: - res = to_query_log_with_tll( + res = _to_query_log_with_tll( row, service_principals, datasets, process_query_config ) if res is not None: @@ -444,3 +326,16 @@ def get_query_logs( if count % 1000 == 0: logger.info(f"Fetched {count} queries") yield res + + +def to_system_tags(tags: List[Tag]) -> 
SystemTags: + return SystemTags( + tags=[ + SystemTag( + key=tag.key if tag.value else None, + value=tag.value if tag.value else tag.key, + system_tag_source=SystemTagSource.UNITY_CATALOG, + ) + for tag in tags + ], + ) diff --git a/poetry.lock b/poetry.lock index 4652634d..834065c4 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "aiohappyeyeballs" @@ -646,8 +646,8 @@ files = [ jmespath = ">=0.7.1,<2.0.0" python-dateutil = ">=2.1,<3.0.0" urllib3 = [ - {version = ">=1.25.4,<2.2.0 || >2.2.0,<3", markers = "python_version >= \"3.10\""}, {version = ">=1.25.4,<1.27", markers = "python_version < \"3.10\""}, + {version = ">=1.25.4,<2.2.0 || >2.2.0,<3", markers = "python_version >= \"3.10\""}, ] [package.extras] @@ -1139,13 +1139,13 @@ test = ["flake8", "isort", "pytest"] [[package]] name = "databricks-sdk" -version = "0.29.0" +version = "0.36.0" description = "Databricks SDK for Python (Beta)" optional = true python-versions = ">=3.7" files = [ - {file = "databricks-sdk-0.29.0.tar.gz", hash = "sha256:23016df608bb025548582d378f94af2ea312c0d77250ac14aa57d1f863efe88c"}, - {file = "databricks_sdk-0.29.0-py3-none-any.whl", hash = "sha256:3e08578f4128f759a6a9bba2c836ec32a4cff37fb594530209ab92f2534985bd"}, + {file = "databricks_sdk-0.36.0-py3-none-any.whl", hash = "sha256:e6105a2752c7980de35f7c7e3c4d63389c0763c9ef7bf7e2813e464acef907e9"}, + {file = "databricks_sdk-0.36.0.tar.gz", hash = "sha256:d8c46348cbd3e0b56991a6b7a59d7a6e0437947f6387bef832e6fe092e2dd427"}, ] [package.dependencies] @@ -1153,25 +1153,26 @@ google-auth = ">=2.0,<3.0" requests = ">=2.28.1,<3" [package.extras] -dev = ["autoflake", "databricks-connect", "ipython", "ipywidgets", "isort", "pycodestyle", "pyfakefs", "pytest", "pytest-cov", "pytest-mock", "pytest-rerunfailures", "pytest-xdist", "requests-mock", "wheel", "yapf"] +dev = ["autoflake", "databricks-connect", "httpx", "ipython", "ipywidgets", "isort", "langchain-openai", "openai", "pycodestyle", "pyfakefs", "pytest", "pytest-cov", "pytest-mock", "pytest-rerunfailures", "pytest-xdist", "requests-mock", "wheel", "yapf"] notebook = ["ipython (>=8,<9)", "ipywidgets (>=8,<9)"] +openai = ["httpx", "langchain-openai", "openai"] [[package]] name = "databricks-sql-connector" -version = "3.4.0" +version = "3.6.0" description = "Databricks SQL Connector for Python" optional = true python-versions = "<4.0.0,>=3.8.0" files = [ - {file = "databricks_sql_connector-3.4.0-py3-none-any.whl", hash = "sha256:7ba2efa4149529dee418ec467bacff1cb34c321a43e597d41fd020e569cbba3f"}, - {file = "databricks_sql_connector-3.4.0.tar.gz", hash = "sha256:5def7762a398e025db6a5740649f3ea856f07dc04a87cb7818af335f4157c030"}, + {file = "databricks_sql_connector-3.6.0-py3-none-any.whl", hash = "sha256:126b1b0ec8403c2ca7d84f1c617ef482f4288f428608b51122186dabc69bbd0f"}, + {file = "databricks_sql_connector-3.6.0.tar.gz", hash = "sha256:4302828afa17b9f993dc63143aa35c76e30d2b46bdb4bcc47452d4c44bb412d8"}, ] [package.dependencies] lz4 = ">=4.0.2,<5.0.0" numpy = [ - {version = ">=1.23.4,<2.0.0", markers = "python_version >= \"3.11\""}, {version = ">=1.16.6,<2.0.0", markers = "python_version >= \"3.8\" and python_version < \"3.11\""}, + {version = ">=1.23.4,<2.0.0", markers = "python_version >= \"3.11\""}, ] oauthlib = ">=3.1.0,<4.0.0" openpyxl = ">=3.0.10,<4.0.0" @@ -1221,9 +1222,9 @@ isort = ">=4.3.21,<6.0" jinja2 = 
">=2.10.1,<4.0" packaging = "*" pydantic = [ - {version = ">=1.10.0,<2.4.0 || >2.4.0,<3.0", extras = ["email"], markers = "python_version >= \"3.11\" and python_version < \"4.0\""}, {version = ">=1.5.1,<2.4.0 || >2.4.0,<3.0", extras = ["email"], markers = "python_version < \"3.10\""}, {version = ">=1.9.0,<2.4.0 || >2.4.0,<3.0", extras = ["email"], markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, + {version = ">=1.10.0,<2.4.0 || >2.4.0,<3.0", extras = ["email"], markers = "python_version >= \"3.11\" and python_version < \"4.0\""}, ] pyyaml = ">=6.0.1" toml = {version = ">=0.10.0,<1.0.0", markers = "python_version < \"3.11\""} @@ -1691,12 +1692,12 @@ files = [ google-auth = ">=2.14.1,<3.0.dev0" googleapis-common-protos = ">=1.56.2,<2.0.dev0" grpcio = [ - {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, + {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, ] grpcio-status = [ - {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, + {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, ] proto-plus = ">=1.22.3,<2.0.0dev" protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0" @@ -1832,8 +1833,8 @@ google-cloud-core = ">=2.0.0,<3.0.0dev" grpc-google-iam-v1 = ">=0.12.4,<1.0.0dev" opentelemetry-api = ">=1.9.0" proto-plus = [ - {version = ">=1.22.2,<2.0.0dev", markers = "python_version >= \"3.11\""}, {version = ">=1.22.0,<2.0.0dev", markers = "python_version < \"3.11\""}, + {version = ">=1.22.2,<2.0.0dev", markers = "python_version >= \"3.11\""}, ] protobuf = ">=3.20.2,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0dev" @@ -3956,8 +3957,8 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.23.2", markers = "python_version == \"3.11\""}, {version = ">=1.22.4", markers = "python_version < \"3.11\""}, + {version = ">=1.23.2", markers = "python_version == \"3.11\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -6859,4 +6860,4 @@ unity-catalog = ["databricks-sdk", "databricks-sql-connector", "sqlglot"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.12" -content-hash = "5cf4c7306334f39d8bdfde366c70b984695d832220a7ab4f1d5a524d402a499e" +content-hash = "97ff6a1aa416689438658a9db3b0e577082d54cf4126e8998fbb381eef3a1d96" diff --git a/pyproject.toml b/pyproject.toml index 8f15b488..e5c586f8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "metaphor-connectors" -version = "0.14.138" +version = "0.14.139" license = "Apache-2.0" description = "A collection of Python-based 'connectors' that extract metadata from various sources to ingest into the Metaphor app." 
authors = ["Metaphor "] @@ -25,8 +25,8 @@ boto3 = "^1.35.19" botocore = "^1.35.19" canonicaljson = "^2.0.0" confluent-kafka = { version = "^2.3.0", optional = true } -databricks-sdk = { version = "^0.29.0", optional = true } -databricks-sql-connector = { version = "^3.3.0", optional = true } +databricks-sdk = { version = "^0.36.0", optional = true } +databricks-sql-connector = { version = "^3.5.0", optional = true } fastavro = { version = "^1.9.2", optional = true } GitPython = "^3.1.37" google-cloud-bigquery = { version = "^3.25.0", optional = true } diff --git a/tests/unity_catalog/expected.json b/tests/unity_catalog/expected.json index 1da07bb7..0d8168dd 100644 --- a/tests/unity_catalog/expected.json +++ b/tests/unity_catalog/expected.json @@ -9,9 +9,9 @@ }, "schema": {}, "sourceInfo": { - "createdAtSource": "2024-04-25T08:39:38.658000+00:00", + "createdAtSource": "2020-01-01T00:00:00+00:00", "createdBy": "foo@bar.com", - "lastUpdated": "2024-04-25T08:39:38.658000+00:00", + "lastUpdated": "2020-01-01T00:00:00+00:00", "mainUrl": "https://dummy.host/explore/data/volumes/catalog2/schema/volume", "updatedBy": "foo@bar.com" }, @@ -45,14 +45,14 @@ "volumeFiles": [ { "entityId": "DATASET~C109145C4035631CA68E19687464C80A", - "modification_time": "2024-05-09T16:49:14+00:00", + "modification_time": "2020-01-01T00:00:00+00:00", "name": "input.csv", "path": "/Volumes/catalog2/schema/volume/input.csv", "size": 100000.0 }, { "entityId": "DATASET~3AC713F58F40836DEA91AF59F2AE4D7A", - "modification_time": "2024-05-09T18:12:34+00:00", + "modification_time": "2020-01-01T00:00:00+00:00", "name": "output.csv", "path": "/Volumes/catalog2/schema/volume/output.csv", "size": 200000.0 @@ -70,7 +70,7 @@ "platform": "UNITY_CATALOG_VOLUME_FILE" }, "sourceInfo": { - "lastUpdated": "2024-05-09T16:49:14+00:00" + "lastUpdated": "2020-01-01T00:00:00+00:00" }, "statistics": { "dataSizeBytes": 100000.0 @@ -89,7 +89,7 @@ "platform": "UNITY_CATALOG_VOLUME_FILE" }, "sourceInfo": { - "lastUpdated": "2024-05-09T18:12:34+00:00" + "lastUpdated": "2020-01-01T00:00:00+00:00" }, "statistics": { "dataSizeBytes": 200000.0 @@ -128,7 +128,8 @@ "description": "some description", "fieldName": "col1", "fieldPath": "col1", - "nativeType": "int", + "nativeType": "INT", + "nullable": true, "precision": 32.0, "tags": [ "col_tag=col_value", @@ -142,7 +143,7 @@ } }, "sourceInfo": { - "createdAtSource": "1970-01-01T00:00:00+00:00", + "createdAtSource": "2020-01-01T00:00:00+00:00", "lastUpdated": "2020-01-01T00:00:00+00:00", "mainUrl": "https://dummy.host/explore/data/catalog/schema/table" }, @@ -187,8 +188,9 @@ "value": "value" }, { + "key": "tag2", "systemTagSource": "UNITY_CATALOG", - "value": "tag2" + "value": "value2" } ] }, @@ -197,14 +199,14 @@ "tableInfo": { "dataSourceFormat": "CSV", "owner": "user1@foo.com", + "storageLocation": "s3://path", + "type": "MANAGED", "properties": [ { "key": "delta.lastCommitTimestamp", - "value": "\"1664444422000\"" + "value": "1664444422000" } - ], - "storageLocation": "s3://path", - "type": "MANAGED" + ] } } }, @@ -219,8 +221,10 @@ { "fieldName": "col1", "fieldPath": "col1", - "nativeType": "int", - "precision": 32.0 + "nativeType": "INT", + "nullable": true, + "precision": 32.0, + "tags": [] } ], "schemaType": "SQL", @@ -230,7 +234,10 @@ } }, "sourceInfo": { - "createdAtSource": "1970-01-01T00:00:00+00:00", + "createdAtSource": "2020-01-01T00:00:00+00:00", + "createdBy": "foo@bar.com", + "lastUpdated": "2020-01-01T00:00:00+00:00", + "updatedBy": "foo@bar.com", "mainUrl": 
"https://dummy.host/explore/data/catalog/schema/view" }, "structure": { @@ -274,69 +281,66 @@ "datasetType": "UNITY_CATALOG_TABLE", "tableInfo": { "owner": "user2@foo.com", + "dataSourceFormat": "CSV", + "type": "VIEW", "properties": [ { "key": "view.catalogAndNamespace.numParts", - "value": "\"2\"" + "value": "2" }, { "key": "view.sqlConfig.spark.sql.hive.convertCTAS", - "value": "\"true\"" + "value": "true" }, { "key": "view.query.out.col.0", - "value": "\"key\"" + "value": "key" }, { "key": "view.sqlConfig.spark.sql.parquet.compression.codec", - "value": "\"snappy\"" + "value": "snappy" }, { "key": "view.query.out.numCols", - "value": "\"3\"" + "value": "3" }, { "key": "view.referredTempViewNames", - "value": "\"[]\"" + "value": "[]" }, { "key": "view.query.out.col.1", - "value": "\"values\"" + "value": "values" }, { "key": "view.sqlConfig.spark.sql.streaming.stopTimeout", - "value": "\"15s\"" + "value": "15s" }, { "key": "view.catalogAndNamespace.part.0", - "value": "\"catalog\"" + "value": "catalog" }, { "key": "view.sqlConfig.spark.sql.sources.commitProtocolClass", - "value": "\"com.databricks.sql.transaction.directory.DirectoryAtomicCommitProtocol\"" + "value": "com.databricks.sql.transaction.directory.DirectoryAtomicCommitProtocol" }, { "key": "view.sqlConfig.spark.sql.sources.default", - "value": "\"delta\"" + "value": "delta" }, { "key": "view.sqlConfig.spark.sql.legacy.createHiveTableByDefault", - "value": "\"false\"" + "value": "false" }, { "key": "view.query.out.col.2", - "value": "\"nested_values\"" - }, - { - "key": "view.referredTempFunctionsNames", - "value": "\"[]\"" + "value": "nested_values" }, { "key": "view.catalogAndNamespace.part.1", - "value": "\"default\"" + "value": "default" } - ], - "type": "VIEW" + ] } } }, @@ -352,8 +356,10 @@ "description": "some description", "fieldName": "col1", "fieldPath": "col1", - "nativeType": "int", - "precision": 32.0 + "nativeType": "INT", + "nullable": true, + "precision": 32.0, + "tags": [] } ], "schemaType": "SQL", @@ -362,7 +368,7 @@ } }, "sourceInfo": { - "createdAtSource": "1970-01-01T00:00:00+00:00", + "createdAtSource": "2020-01-01T00:00:00+00:00", "lastUpdated": "2020-01-01T00:00:00+00:00", "mainUrl": "https://dummy.host/explore/data/catalog2/schema/table2" }, @@ -387,12 +393,6 @@ "tableInfo": { "dataSourceFormat": "DELTA", "owner": "sp1", - "properties": [ - { - "key": "delta.lastCommitTimestamp", - "value": "\"1664444422000\"" - } - ], "storageLocation": "s3://path", "type": "MANAGED" } diff --git a/tests/unity_catalog/snapshots/test_extractor/test_init_dataset/external_shallow_clone.json b/tests/unity_catalog/snapshots/test_extractor/test_init_dataset/external_shallow_clone.json index 53f58ec7..1a77710d 100644 --- a/tests/unity_catalog/snapshots/test_extractor/test_init_dataset/external_shallow_clone.json +++ b/tests/unity_catalog/snapshots/test_extractor/test_init_dataset/external_shallow_clone.json @@ -11,20 +11,33 @@ } }, "sourceInfo": { - "mainUrl": "http://foo.bar/catalog/schema/table" + "createdAtSource": "2020-01-01T00:00:00+00:00", + "createdBy": "bar", + "lastUpdated": "2020-01-01T00:00:00+00:00", + "mainUrl": "http://foo.bar/catalog/schema/table", + "updatedBy": "baz" }, "structure": { "database": "catalog", "schema": "schema", "table": "table" }, + "systemContacts": { + "contacts": [ + { + "email": "foo", + "systemContactSource": "UNITY_CATALOG" + } + ] + }, "systemTags": { "tags": [] }, "unityCatalog": { "datasetType": "UNITY_CATALOG_TABLE", "tableInfo": { - "properties": [], + "dataSourceFormat": "csv", + 
"owner": "foo", "type": "EXTERNAL_SHALLOW_CLONE" } } diff --git a/tests/unity_catalog/snapshots/test_extractor/test_init_dataset/shallow_clone.json b/tests/unity_catalog/snapshots/test_extractor/test_init_dataset/shallow_clone.json index 2fa59d34..b0022f91 100644 --- a/tests/unity_catalog/snapshots/test_extractor/test_init_dataset/shallow_clone.json +++ b/tests/unity_catalog/snapshots/test_extractor/test_init_dataset/shallow_clone.json @@ -11,20 +11,33 @@ } }, "sourceInfo": { - "mainUrl": "http://foo.bar/catalog/schema/table" + "createdAtSource": "2020-01-01T00:00:00+00:00", + "createdBy": "bar", + "lastUpdated": "2020-01-01T00:00:00+00:00", + "mainUrl": "http://foo.bar/catalog/schema/table", + "updatedBy": "baz" }, "structure": { "database": "catalog", "schema": "schema", "table": "table" }, + "systemContacts": { + "contacts": [ + { + "email": "foo", + "systemContactSource": "UNITY_CATALOG" + } + ] + }, "systemTags": { "tags": [] }, "unityCatalog": { "datasetType": "UNITY_CATALOG_TABLE", "tableInfo": { - "properties": [], + "dataSourceFormat": "csv", + "owner": "foo", "type": "MANAGED_SHALLOW_CLONE" } } diff --git a/tests/unity_catalog/snapshots/test_extractor/test_init_dataset/table.json b/tests/unity_catalog/snapshots/test_extractor/test_init_dataset/table.json index 4f292c55..f2427119 100644 --- a/tests/unity_catalog/snapshots/test_extractor/test_init_dataset/table.json +++ b/tests/unity_catalog/snapshots/test_extractor/test_init_dataset/table.json @@ -10,8 +10,13 @@ "description": "some description", "fieldName": "col1", "fieldPath": "col1", - "nativeType": "int", - "precision": 32.0 + "nativeType": "INT", + "nullable": true, + "precision": 32.0, + "tags": [ + "col1_tag_key_1=col1_tag_value_1", + "col1_tag_key_2=col1_tag_value_2" + ] } ], "schemaType": "SQL", @@ -20,8 +25,11 @@ } }, "sourceInfo": { - "createdAtSource": "1970-01-01T00:00:00+00:00", - "mainUrl": "http://foo.bar/catalog/schema/table" + "createdAtSource": "2020-01-01T00:00:00+00:00", + "createdBy": "foo@bar.com", + "lastUpdated": "2020-01-01T00:00:00+00:00", + "mainUrl": "http://foo.bar/catalog/schema/table", + "updatedBy": "foo@bar.com" }, "structure": { "database": "catalog", @@ -42,14 +50,8 @@ "unityCatalog": { "datasetType": "UNITY_CATALOG_TABLE", "tableInfo": { - "dataSourceFormat": "CSV", + "dataSourceFormat": "csv", "owner": "foo@bar.com", - "properties": [ - { - "key": "delta.lastCommitTimestamp", - "value": "\"1664444422000\"" - } - ], "storageLocation": "s3://path", "type": "MANAGED" } diff --git a/tests/unity_catalog/snapshots/test_queries/test_get_last_refreshed_time/describe_history.sql b/tests/unity_catalog/snapshots/test_queries/test_get_last_refreshed_time/describe_history.sql new file mode 100644 index 00000000..aa560ccc --- /dev/null +++ b/tests/unity_catalog/snapshots/test_queries/test_get_last_refreshed_time/describe_history.sql @@ -0,0 +1 @@ +DESCRIBE HISTORY db.schema.table LIMIT 50 \ No newline at end of file diff --git a/tests/unity_catalog/snapshots/test_queries/test_get_table_properties/show_table_properties.sql b/tests/unity_catalog/snapshots/test_queries/test_get_table_properties/show_table_properties.sql new file mode 100644 index 00000000..901d1f1d --- /dev/null +++ b/tests/unity_catalog/snapshots/test_queries/test_get_table_properties/show_table_properties.sql @@ -0,0 +1 @@ +SHOW TBLPROPERTIES db.schema.table \ No newline at end of file diff --git a/tests/unity_catalog/snapshots/test_queries/test_list_catalogs/list_catalogs.sql 
b/tests/unity_catalog/snapshots/test_queries/test_list_catalogs/list_catalogs.sql new file mode 100644 index 00000000..e4102e0a --- /dev/null +++ b/tests/unity_catalog/snapshots/test_queries/test_list_catalogs/list_catalogs.sql @@ -0,0 +1,29 @@ + + WITH c AS ( + SELECT + catalog_name, + catalog_owner, + comment + FROM system.information_schema.catalogs + WHERE catalog_name <> 'system' + ), + + t AS ( + SELECT + catalog_name, + struct(tag_name, tag_value) as tag + FROM system.information_schema.catalog_tags + WHERE catalog_name <> 'system' + ) + + SELECT + c.catalog_name AS catalog_name, + first(c.catalog_owner) AS catalog_owner, + first(c.comment) AS comment, + collect_list(t.tag) AS tags + FROM c + LEFT JOIN t + ON c.catalog_name = t.catalog_name + GROUP BY c.catalog_name + ORDER by c.catalog_name + \ No newline at end of file diff --git a/tests/unity_catalog/snapshots/test_queries/test_list_column_lineage/list_column_lineage.sql b/tests/unity_catalog/snapshots/test_queries/test_list_column_lineage/list_column_lineage.sql new file mode 100644 index 00000000..ca32e2a8 --- /dev/null +++ b/tests/unity_catalog/snapshots/test_queries/test_list_column_lineage/list_column_lineage.sql @@ -0,0 +1,18 @@ + + SELECT + source_table_full_name, + source_column_name, + target_table_full_name, + target_column_name + FROM system.access.column_lineage + WHERE + target_table_catalog = 'catalog' AND + target_table_schema = 'schema' AND + source_table_full_name IS NOT NULL AND + event_time > date_sub(now(), 7) + GROUP BY + source_table_full_name, + source_column_name, + target_table_full_name, + target_column_name + \ No newline at end of file diff --git a/tests/unity_catalog/snapshots/test_utils/test_list_query_logs/list_query_log.sql b/tests/unity_catalog/snapshots/test_queries/test_list_query_logs/list_query_log.sql similarity index 100% rename from tests/unity_catalog/snapshots/test_utils/test_list_query_logs/list_query_log.sql rename to tests/unity_catalog/snapshots/test_queries/test_list_query_logs/list_query_log.sql diff --git a/tests/unity_catalog/snapshots/test_queries/test_list_schemas/list_schemas.sql b/tests/unity_catalog/snapshots/test_queries/test_list_schemas/list_schemas.sql new file mode 100644 index 00000000..6a72bc1c --- /dev/null +++ b/tests/unity_catalog/snapshots/test_queries/test_list_schemas/list_schemas.sql @@ -0,0 +1,32 @@ + + WITH s AS ( + SELECT + catalog_name, + schema_name, + schema_owner, + comment + FROM system.information_schema.schemata + WHERE catalog_name = %(catalog)s AND schema_name <> 'information_schema' + ), + + t AS ( + SELECT + catalog_name, + schema_name, + struct(tag_name, tag_value) as tag + FROM system.information_schema.schema_tags + WHERE catalog_name = %(catalog)s AND schema_name <> 'information_schema' + ) + + SELECT + first(s.catalog_name) AS catalog_name, + s.schema_name AS schema_name, + first(s.schema_owner) AS schema_owner, + first(s.comment) AS comment, + collect_list(t.tag) AS tags + FROM s + LEFT JOIN t + ON s.catalog_name = t.catalog_name AND s.schema_name = t.schema_name + GROUP BY s.schema_name + ORDER by s.schema_name + \ No newline at end of file diff --git a/tests/unity_catalog/snapshots/test_queries/test_list_table_lineage/list_table_lineage.sql b/tests/unity_catalog/snapshots/test_queries/test_list_table_lineage/list_table_lineage.sql new file mode 100644 index 00000000..55e4eaa1 --- /dev/null +++ b/tests/unity_catalog/snapshots/test_queries/test_list_table_lineage/list_table_lineage.sql @@ -0,0 +1,14 @@ + + SELECT + source_table_full_name, + 
target_table_full_name + FROM system.access.table_lineage + WHERE + target_table_catalog = 'c' AND + target_table_schema = 's' AND + source_table_full_name IS NOT NULL AND + event_time > date_sub(now(), 7) + GROUP BY + source_table_full_name, + target_table_full_name + \ No newline at end of file diff --git a/tests/unity_catalog/snapshots/test_queries/test_list_tables/list_tables.sql b/tests/unity_catalog/snapshots/test_queries/test_list_tables/list_tables.sql new file mode 100644 index 00000000..eff7dbc1 --- /dev/null +++ b/tests/unity_catalog/snapshots/test_queries/test_list_tables/list_tables.sql @@ -0,0 +1,157 @@ + + WITH + t AS ( + SELECT + table_catalog, + table_schema, + table_name, + table_type, + table_owner, + comment, + data_source_format, + storage_path, + created, + created_by, + last_altered, + last_altered_by + FROM system.information_schema.tables + WHERE + table_catalog = %(catalog)s AND + table_schema = %(schema)s + ), + + tt AS ( + SELECT + catalog_name AS table_catalog, + schema_name AS table_schema, + table_name AS table_name, + collect_list(struct(tag_name, tag_value)) as tags + FROM system.information_schema.table_tags + WHERE + catalog_name = %(catalog)s AND + schema_name = %(schema)s + GROUP BY catalog_name, schema_name, table_name + ), + + v AS ( + SELECT + table_catalog, + table_schema, + table_name, + view_definition + FROM system.information_schema.views + WHERE + table_catalog = %(catalog)s AND + table_schema = %(schema)s + ), + + tf AS ( + SELECT + t.table_catalog, + t.table_schema, + t.table_name, + t.table_type, + t.table_owner, + t.comment, + t.data_source_format, + t.storage_path, + t.created, + t.created_by, + t.last_altered, + t.last_altered_by, + v.view_definition, + tt.tags + FROM t + LEFT JOIN v + ON + t.table_catalog = v.table_catalog AND + t.table_schema = v.table_schema AND + t.table_name = v.table_name + LEFT JOIN tt + ON + t.table_catalog = tt.table_catalog AND + t.table_schema = tt.table_schema AND + t.table_name = tt.table_name + ), + + c AS ( + SELECT + table_catalog, + table_schema, + table_name, + column_name, + data_type, + CASE + WHEN numeric_precision IS NOT NULL THEN numeric_precision + WHEN datetime_precision IS NOT NULL THEN datetime_precision + ELSE NULL + END AS data_precision, + is_nullable, + comment + FROM system.information_schema.columns + WHERE + table_catalog = %(catalog)s AND + table_schema = %(schema)s + ), + + ct AS ( + SELECT + catalog_name AS table_catalog, + schema_name AS table_schema, + table_name, + column_name, + collect_list(struct(tag_name, tag_value)) as tags + FROM system.information_schema.column_tags + WHERE + catalog_name = %(catalog)s AND + schema_name = %(schema)s + GROUP BY catalog_name, schema_name, table_name, column_name + ), + + cf AS ( + SELECT + c.table_catalog, + c.table_schema, + c.table_name, + collect_list(struct( + c.column_name, + c.data_type, + c.data_precision, + c.is_nullable, + c.comment, + ct.tags + )) as columns + FROM c + LEFT JOIN ct + ON + c.table_catalog = ct.table_catalog AND + c.table_schema = ct.table_schema AND + c.table_name = ct.table_name AND + c.column_name = ct.column_name + GROUP BY c.table_catalog, c.table_schema, c.table_name + ) + + SELECT + tf.table_catalog AS catalog_name, + tf.table_schema AS schema_name, + tf.table_name AS table_name, + tf.table_type AS table_type, + tf.table_owner AS owner, + tf.comment AS table_comment, + tf.data_source_format AS data_source_format, + tf.storage_path AS storage_path, + tf.created AS created_at, + tf.created_by AS created_by, + 
tf.last_altered as updated_at, + tf.last_altered_by AS updated_by, + tf.view_definition AS view_definition, + tf.tags AS tags, + cf.columns AS columns + FROM tf + LEFT JOIN cf + ON + tf.table_catalog = cf.table_catalog AND + tf.table_schema = cf.table_schema AND + tf.table_name = cf.table_name + ORDER by tf.table_catalog, tf.table_schema, tf.table_name + \ No newline at end of file diff --git a/tests/unity_catalog/snapshots/test_queries/test_list_volume_files/list_volume_files.sql b/tests/unity_catalog/snapshots/test_queries/test_list_volume_files/list_volume_files.sql new file mode 100644 index 00000000..5467b79c --- /dev/null +++ b/tests/unity_catalog/snapshots/test_queries/test_list_volume_files/list_volume_files.sql @@ -0,0 +1 @@ +LIST '/Volumes/catalog1/schema1/volume1' \ No newline at end of file diff --git a/tests/unity_catalog/snapshots/test_queries/test_list_volumes/list_volumes.sql b/tests/unity_catalog/snapshots/test_queries/test_list_volumes/list_volumes.sql new file mode 100644 index 00000000..742e4b3c --- /dev/null +++ b/tests/unity_catalog/snapshots/test_queries/test_list_volumes/list_volumes.sql @@ -0,0 +1,50 @@ + + WITH v AS ( + SELECT + volume_catalog, + volume_schema, + volume_name, + volume_type, + volume_owner, + comment, + created, + created_by, + last_altered, + last_altered_by, + storage_location + FROM system.information_schema.volumes + WHERE volume_catalog = %(catalog)s AND volume_schema = %(schema)s + ), + + t AS ( + SELECT + catalog_name, + schema_name, + volume_name, + struct(tag_name, tag_value) as tag + FROM system.information_schema.volume_tags + WHERE catalog_name = %(catalog)s AND schema_name = %(schema)s + ) + + SELECT + first(v.volume_catalog) AS volume_catalog, + first(v.volume_schema) AS volume_schema, + v.volume_name AS volume_name, + first(v.volume_type) AS volume_type, + first(v.volume_owner) AS volume_owner, + first(v.comment) AS comment, + first(v.created) AS created, + first(v.created_by) AS created_by, + first(v.last_altered) AS last_altered, + first(v.last_altered_by) AS last_altered_by, + first(v.storage_location) AS storage_location, + collect_list(t.tag) AS tags + FROM v + LEFT JOIN t + ON + v.volume_catalog = t.catalog_name AND + v.volume_schema = t.schema_name AND + v.volume_name = t.volume_name + GROUP BY v.volume_name + ORDER BY v.volume_name + \ No newline at end of file diff --git a/tests/unity_catalog/test_extractor.py b/tests/unity_catalog/test_extractor.py index 0b93cfbb..32a2d7bf 100644 --- a/tests/unity_catalog/test_extractor.py +++ b/tests/unity_catalog/test_extractor.py @@ -1,17 +1,7 @@ from datetime import datetime, timezone -from queue import Queue from unittest.mock import MagicMock, patch import pytest -from databricks.sdk.service.catalog import ( - CatalogInfo, - ColumnInfo, - ColumnTypeName, - DataSourceFormat, - SchemaInfo, -) -from databricks.sdk.service.catalog import TableInfo as Table -from databricks.sdk.service.catalog import TableType, VolumeInfo, VolumeType from databricks.sdk.service.iam import ServicePrincipal from pytest_snapshot.plugin import Snapshot @@ -20,9 +10,19 @@ from metaphor.models.metadata_change_event import DataPlatform, QueryLog from metaphor.unity_catalog.config import UnityCatalogRunConfig from metaphor.unity_catalog.extractor import UnityCatalogExtractor -from metaphor.unity_catalog.models import Column, ColumnLineage, TableLineage +from metaphor.unity_catalog.models import ( + CatalogInfo, + Column, + ColumnInfo, + ColumnLineage, + SchemaInfo, + TableInfo, + TableLineage, + Tag, + 
VolumeFileInfo, + VolumeInfo, +) from tests.test_utils import load_json, serialize_event, wrap_query_log_stream_to_event -from tests.unity_catalog.mocks import mock_sql_connection def dummy_config(): @@ -34,20 +34,35 @@ def dummy_config(): ) +mock_time = datetime(2020, 1, 1, tzinfo=timezone.utc) + + @patch("metaphor.unity_catalog.extractor.batch_get_last_refreshed_time") @patch("metaphor.unity_catalog.extractor.create_connection_pool") @patch("metaphor.unity_catalog.extractor.create_connection") @patch("metaphor.unity_catalog.extractor.create_api") +@patch("metaphor.unity_catalog.extractor.list_catalogs") +@patch("metaphor.unity_catalog.extractor.list_schemas") +@patch("metaphor.unity_catalog.extractor.list_tables") +@patch("metaphor.unity_catalog.extractor.list_volumes") +@patch("metaphor.unity_catalog.extractor.list_volume_files") @patch("metaphor.unity_catalog.extractor.list_table_lineage") @patch("metaphor.unity_catalog.extractor.list_column_lineage") +@patch("metaphor.unity_catalog.extractor.batch_get_table_properties") @patch("metaphor.unity_catalog.extractor.get_query_logs") @patch("metaphor.unity_catalog.extractor.list_service_principals") @pytest.mark.asyncio async def test_extractor( mock_list_service_principals: MagicMock, mock_get_query_logs: MagicMock, + mock_batch_get_table_properties: MagicMock, mock_list_column_lineage: MagicMock, mock_list_table_lineage: MagicMock, + mock_list_volume_files: MagicMock, + mock_list_volumes: MagicMock, + mock_list_tables: MagicMock, + mock_list_schemas: MagicMock, + mock_list_catalogs: MagicMock, mock_create_api: MagicMock, mock_create_connection: MagicMock, mock_create_connection_pool: MagicMock, @@ -58,134 +73,153 @@ async def test_extractor( "sp1": ServicePrincipal(display_name="service principal 1") } - def mock_list_catalogs(): - return [CatalogInfo(name="catalog", owner="sp1")] - - def mock_list_schemas(catalog): - return [SchemaInfo(name="schema", owner="test@foo.bar")] + mock_list_catalogs.side_effect = [ + [ + CatalogInfo( + catalog_name="catalog", + owner="sp1", + tags=[ + Tag(key="catalog_tag_key_1", value="catalog_tag_value_1"), + Tag(key="catalog_tag_key_2", value="catalog_tag_value_2"), + ], + ) + ] + ] + mock_list_schemas.side_effect = [ + [ + SchemaInfo( + catalog_name="catalog", + schema_name="schema", + owner="test@foo.bar", + tags=[ + Tag(key="schema_tag_key_1", value="schema_tag_value_1"), + Tag(key="schema_tag_key_2", value="schema_tag_value_2"), + ], + ) + ] + ] - def mock_list_tables(catalog, schema): - return [ - Table( - name="table", + mock_list_tables.side_effect = [ + [ + TableInfo( + table_name="table", catalog_name="catalog", schema_name="schema", - table_type=TableType.MANAGED, - data_source_format=DataSourceFormat.CSV, + type="MANAGED", + data_source_format="CSV", columns=[ ColumnInfo( - name="col1", - type_name=ColumnTypeName.INT, - type_precision=32, - nullable=True, + column_name="col1", + data_type="INT", + data_precision=32, + is_nullable=True, comment="some description", + tags=[ + Tag(key="col_tag", value="col_value"), + Tag(key="col_tag2", value="tag_value_2"), + ], ) ], storage_location="s3://path", owner="user1@foo.com", comment="example", - updated_at=0, + updated_at=mock_time, updated_by="foo@bar.com", - properties={ - "delta.lastCommitTimestamp": "1664444422000", - }, - created_at=0, + created_at=mock_time, + created_by="foo@bar.com", + tags=[ + Tag(key="tag", value="value"), + Tag(key="tag2", value="value2"), + ], ), - Table( - name="view", + TableInfo( + table_name="view", catalog_name="catalog", 
schema_name="schema", - table_type=TableType.VIEW, + type="VIEW", + data_source_format="CSV", columns=[ ColumnInfo( - name="col1", - type_name=ColumnTypeName.INT, - type_precision=32, - nullable=True, + column_name="col1", + data_type="INT", + data_precision=32, + is_nullable=True, + tags=[], ) ], view_definition="SELECT ...", owner="user2@foo.com", comment="example", - updated_at=0, + updated_at=mock_time, updated_by="foo@bar.com", - properties={ - "view.catalogAndNamespace.numParts": "2", - "view.sqlConfig.spark.sql.hive.convertCTAS": "true", - "view.query.out.col.0": "key", - "view.sqlConfig.spark.sql.parquet.compression.codec": "snappy", - "view.query.out.numCols": "3", - "view.referredTempViewNames": "[]", - "view.query.out.col.1": "values", - "view.sqlConfig.spark.sql.streaming.stopTimeout": "15s", - "view.catalogAndNamespace.part.0": "catalog", - "view.sqlConfig.spark.sql.sources.commitProtocolClass": "com.databricks.sql.transaction.directory.DirectoryAtomicCommitProtocol", - "view.sqlConfig.spark.sql.sources.default": "delta", - "view.sqlConfig.spark.sql.legacy.createHiveTableByDefault": "false", - "view.query.out.col.2": "nested_values", - "view.referredTempFunctionsNames": "[]", - "view.catalogAndNamespace.part.1": "default", - }, - created_at=0, + created_at=mock_time, + created_by="foo@bar.com", ), - Table( - name="table2", + TableInfo( + table_name="table2", catalog_name="catalog2", schema_name="schema", - table_type=TableType.MANAGED, - data_source_format=DataSourceFormat.DELTA, + type="MANAGED", + data_source_format="DELTA", columns=[ ColumnInfo( - name="col1", - type_name=ColumnTypeName.INT, - type_precision=32, - nullable=True, + column_name="col1", + data_type="INT", + data_precision=32, + is_nullable=True, comment="some description", + tags=[], ) ], storage_location="s3://path", owner="sp1", comment="example", - updated_at=0, + updated_at=mock_time, updated_by="foo@bar.com", - properties={ - "delta.lastCommitTimestamp": "1664444422000", - }, - created_at=0, + created_at=mock_time, + created_by="foo@bar.com", ), ] + ] - def mock_list_volumes(catalog, schema): - return [ + mock_list_volumes.side_effect = [ + [ VolumeInfo( - access_point=None, catalog_name="catalog2", - comment=None, - created_at=1714034378658, - created_by="foo@bar.com", - encryption_details=None, + schema_name="schema", + volume_name="volume", + volume_type="EXTERNAL", full_name="catalog2.schema.volume", - metastore_id="ashjkdhaskd", - name="volume", owner="foo@bar.com", - schema_name="schema", - storage_location="s3://path", - updated_at=1714034378658, + created_at=mock_time, + created_by="foo@bar.com", + updated_at=mock_time, updated_by="foo@bar.com", - volume_id="volume-id", - volume_type=VolumeType.EXTERNAL, + storage_location="s3://path", + tags=[ + Tag(key="tag", value="value"), + ], ) ] + ] + + mock_list_volume_files.side_effect = [ + [ + VolumeFileInfo( + last_updated=mock_time, + name="input.csv", + path="/Volumes/catalog2/schema/volume/input.csv", + size=100000, + ), + VolumeFileInfo( + last_updated=mock_time, + name="output.csv", + path="/Volumes/catalog2/schema/volume/output.csv", + size=200000, + ), + ] + ] - mock_client = MagicMock() - mock_client.catalogs = MagicMock() - mock_client.catalogs.list = mock_list_catalogs - mock_client.schemas = MagicMock() - mock_client.schemas.list = mock_list_schemas - mock_client.tables = MagicMock() - mock_client.tables.list = mock_list_tables - mock_client.volumes = MagicMock() - mock_client.volumes.list = mock_list_volumes 
mock_list_table_lineage.side_effect = [ { "catalog.schema.table": TableLineage( @@ -213,7 +247,7 @@ def mock_list_volumes(catalog, schema): QueryLog( query_id="foo", email="foo@bar.com", - start_time=datetime(2020, 1, 1, tzinfo=timezone.utc), + start_time=mock_time, duration=1234.0, rows_read=9487.0, rows_written=5566.0, @@ -228,53 +262,34 @@ def mock_list_volumes(catalog, schema): ) ] - mock_batch_get_last_refreshed_time.return_value = { - "catalog.schema.table": datetime(2020, 1, 1, tzinfo=timezone.utc), - "catalog2.schema.table2": datetime(2020, 1, 1, tzinfo=timezone.utc), + mock_batch_get_table_properties.return_value = { + "catalog.schema.table": { + "delta.lastCommitTimestamp": "1664444422000", + }, + "catalog.schema.view": { + "view.catalogAndNamespace.numParts": "2", + "view.sqlConfig.spark.sql.hive.convertCTAS": "true", + "view.query.out.col.0": "key", + "view.sqlConfig.spark.sql.parquet.compression.codec": "snappy", + "view.query.out.numCols": "3", + "view.referredTempViewNames": "[]", + "view.query.out.col.1": "values", + "view.sqlConfig.spark.sql.streaming.stopTimeout": "15s", + "view.catalogAndNamespace.part.0": "catalog", + "view.sqlConfig.spark.sql.sources.commitProtocolClass": "com.databricks.sql.transaction.directory.DirectoryAtomicCommitProtocol", + "view.sqlConfig.spark.sql.sources.default": "delta", + "view.sqlConfig.spark.sql.legacy.createHiveTableByDefault": "false", + "view.query.out.col.2": "nested_values", + "view.catalogAndNamespace.part.1": "default", + }, } - mock_create_api.return_value = mock_client - - results = [ - [ - ( - "/Volumes/catalog2/schema/volume/input.csv", - "input.csv", - "100000", - 1715273354000, - ), - ( - "/Volumes/catalog2/schema/volume/output.csv", - "output.csv", - "200000", - 1715278354000, - ), - ], - [ - ("catalog_tag_key_1", "catalog_tag_value_1"), - ("catalog_tag_key_2", "catalog_tag_value_2"), - ], - [ - ("schema", "schema_tag_key_1", "schema_tag_value_1"), - ("schema", "schema_tag_key_2", "schema_tag_value_2"), - ], - [ - ("catalog", "schema", "table", "tag", "value"), - ("catalog", "schema", "table", "tag2", ""), - ("does", "not", "exist", "also", "doesn't exist"), - ], - [ - ("catalog2", "schema", "volume", "tag", "value"), - ], - [ - ("catalog", "schema", "table", "col1", "col_tag", "col_value"), - ("catalog", "schema", "table", "col1", "col_tag2", "tag_value_2"), - ("does", "not", "exist", "also", "doesn't", "exist"), - ], - ] + mock_batch_get_last_refreshed_time.return_value = { + "catalog.schema.table": mock_time, + "catalog2.schema.table2": mock_time, + } - mock_create_connection.return_value = mock_sql_connection(results) - mock_create_connection_pool.return_value = Queue(1) + mock_create_api.return_value = MagicMock() extractor = UnityCatalogExtractor(dummy_config()) events = [EventUtil.trim_event(e) for e in await extractor.extract()] @@ -322,25 +337,6 @@ def test_source_url( ) -@patch("metaphor.unity_catalog.extractor.create_connection") -@patch("metaphor.unity_catalog.extractor.create_api") -def test_init_invalid_dataset( - mock_create_api: MagicMock, - mock_create_connection: MagicMock, - test_root_dir: str, -) -> None: - mock_create_api.return_value = None - mock_create_connection.return_value = None - - extractor = UnityCatalogExtractor.from_config_file( - f"{test_root_dir}/unity_catalog/config.yml" - ) - with pytest.raises(ValueError): - extractor._init_dataset( - Table(catalog_name="catalog", schema_name="schema", name="table") - ) - - @patch("metaphor.unity_catalog.extractor.create_connection") 
@patch("metaphor.unity_catalog.extractor.create_api") def test_init_dataset( @@ -359,30 +355,35 @@ def test_init_dataset( snapshot.assert_match( serialize_event( extractor._init_dataset( - Table( - name="table", + TableInfo( + table_name="table", catalog_name="catalog", schema_name="schema", - table_type=TableType.MANAGED, - data_source_format=DataSourceFormat.CSV, + type="MANAGED", + data_source_format="csv", columns=[ ColumnInfo( - name="col1", - type_name=ColumnTypeName.INT, - type_precision=32, - nullable=True, + column_name="col1", + data_type="INT", + data_precision=32, + is_nullable=True, comment="some description", + tags=[ + Tag(key="col1_tag_key_1", value="col1_tag_value_1"), + Tag(key="col1_tag_key_2", value="col1_tag_value_2"), + ], ) ], storage_location="s3://path", owner="foo@bar.com", comment="example", - updated_at=0, + updated_at=mock_time, updated_by="foo@bar.com", properties={ "delta.lastCommitTimestamp": "1664444422000", }, - created_at=0, + created_at=mock_time, + created_by="foo@bar.com", ), ) ), @@ -392,11 +393,17 @@ def test_init_dataset( snapshot.assert_match( serialize_event( extractor._init_dataset( - Table( - name="table", + TableInfo( + table_name="table", catalog_name="catalog", schema_name="schema", - table_type=TableType.MANAGED_SHALLOW_CLONE, + type="MANAGED_SHALLOW_CLONE", + owner="foo", + created_at=mock_time, + created_by="bar", + updated_at=mock_time, + updated_by="baz", + data_source_format="csv", ), ) ), @@ -406,11 +413,17 @@ def test_init_dataset( snapshot.assert_match( serialize_event( extractor._init_dataset( - Table( - name="table", + TableInfo( + table_name="table", catalog_name="catalog", schema_name="schema", - table_type=TableType.EXTERNAL_SHALLOW_CLONE, + type="EXTERNAL_SHALLOW_CLONE", + owner="foo", + created_at=mock_time, + created_by="bar", + updated_at=mock_time, + updated_by="baz", + data_source_format="csv", ), ) ), diff --git a/tests/unity_catalog/test_models.py b/tests/unity_catalog/test_models.py index 7db7417e..e69de29b 100644 --- a/tests/unity_catalog/test_models.py +++ b/tests/unity_catalog/test_models.py @@ -1,11 +0,0 @@ -import pytest -from databricks.sdk.service.catalog import ColumnInfo - -from metaphor.unity_catalog.models import extract_schema_field_from_column_info - - -def test_parse_schema_field_from_invalid_column_info() -> None: - with pytest.raises(ValueError): - extract_schema_field_from_column_info( - ColumnInfo(comment="does not have a type") - ) diff --git a/tests/unity_catalog/test_queries.py b/tests/unity_catalog/test_queries.py new file mode 100644 index 00000000..0e7ea99a --- /dev/null +++ b/tests/unity_catalog/test_queries.py @@ -0,0 +1,551 @@ +from datetime import datetime, timezone +from unittest.mock import MagicMock + +from freezegun import freeze_time +from pytest_snapshot.plugin import Snapshot + +from metaphor.unity_catalog.models import ( + CatalogInfo, + Column, + ColumnInfo, + ColumnLineage, + SchemaInfo, + TableInfo, + TableLineage, + Tag, + VolumeFileInfo, + VolumeInfo, +) +from metaphor.unity_catalog.queries import ( + get_last_refreshed_time, + get_table_properties, + list_catalogs, + list_column_lineage, + list_query_logs, + list_schemas, + list_table_lineage, + list_tables, + list_volume_files, + list_volumes, +) +from tests.unity_catalog.mocks import mock_sql_connection + + +def test_list_catalogs( + test_root_dir: str, + snapshot: Snapshot, +): + mock_cursor = MagicMock() + + mock_connection = mock_sql_connection( + [ + [ + { + "catalog_name": "catalog1", + "catalog_owner": "owner1", + 
"comment": "comment1", + "tags": [ + {"tag_name": "tag1", "tag_value": "value1"}, + {"tag_name": "tag2", "tag_value": "value2"}, + ], + }, + { + "catalog_name": "catalog2", + "catalog_owner": "owner2", + "comment": "comment2", + "tags": [], + }, + ] + ], + None, + mock_cursor, + ) + + catalogs = list_catalogs(mock_connection) + + assert catalogs == [ + CatalogInfo( + catalog_name="catalog1", + owner="owner1", + comment="comment1", + tags=[ + Tag(key="tag1", value="value1"), + Tag(key="tag2", value="value2"), + ], + ), + CatalogInfo( + catalog_name="catalog2", owner="owner2", comment="comment2", tags=[] + ), + ] + + args = mock_cursor.execute.call_args_list[0].args + snapshot.assert_match(args[0], "list_catalogs.sql") + + # Exception handling + mock_connection = mock_sql_connection([], Exception("some error")) + catalogs = list_catalogs(mock_connection) + assert catalogs == [] + + +def test_list_schemas( + test_root_dir: str, + snapshot: Snapshot, +): + mock_cursor = MagicMock() + + mock_connection = mock_sql_connection( + [ + [ + { + "catalog_name": "catalog1", + "schema_name": "schema1", + "schema_owner": "owner1", + "comment": "comment1", + "tags": [ + {"tag_name": "tag1", "tag_value": "value1"}, + {"tag_name": "tag2", "tag_value": "value2"}, + ], + }, + { + "catalog_name": "catalog1", + "schema_name": "schema2", + "schema_owner": "owner2", + "comment": "comment2", + "tags": [], + }, + ] + ], + None, + mock_cursor, + ) + + schemas = list_schemas(mock_connection, "catalog1") + + assert schemas == [ + SchemaInfo( + catalog_name="catalog1", + schema_name="schema1", + owner="owner1", + comment="comment1", + tags=[ + Tag(key="tag1", value="value1"), + Tag(key="tag2", value="value2"), + ], + ), + SchemaInfo( + catalog_name="catalog1", + schema_name="schema2", + owner="owner2", + comment="comment2", + tags=[], + ), + ] + + args = mock_cursor.execute.call_args_list[0].args + snapshot.assert_match(args[0], "list_schemas.sql") + + # Exception handling + mock_connection = mock_sql_connection([], Exception("some error")) + schemas = list_schemas(mock_connection, "catalog1") + assert schemas == [] + + +def test_list_tables( + test_root_dir: str, + snapshot: Snapshot, +): + mock_cursor = MagicMock() + + mock_connection = mock_sql_connection( + [ + [ + { + "catalog_name": "catalog1", + "schema_name": "schema1", + "table_name": "table1", + "table_type": "TABLE", + "owner": "owner1", + "table_comment": "table_comment1", + "data_source_format": "PARQUET", + "storage_path": "location1", + "created_at": datetime(2000, 1, 1), + "created_by": "user1", + "updated_at": datetime(2001, 1, 1), + "updated_by": "user2", + "view_definition": "definition", + "columns": [ + { + "column_name": "column1", + "data_type": "data_type1", + "data_precision": 10, + "is_nullable": "YES", + "comment": "column_comment1", + "tags": None, + }, + { + "column_name": "column2", + "data_type": "data_type2", + "data_precision": 20, + "is_nullable": "NO", + "comment": "column_comment2", + "tags": None, + }, + ], + }, + ], + ], + None, + mock_cursor, + ) + + tables = list_tables(mock_connection, "catalog1", "schema1") + + assert tables == [ + TableInfo( + catalog_name="catalog1", + schema_name="schema1", + table_name="table1", + type="TABLE", + owner="owner1", + comment="table_comment1", + data_source_format="PARQUET", + storage_location="location1", + created_at=datetime(2000, 1, 1), + created_by="user1", + updated_at=datetime(2001, 1, 1), + updated_by="user2", + view_definition="definition", + columns=[ + ColumnInfo( + 
column_name="column1", + data_type="data_type1", + data_precision=10, + is_nullable=True, + comment="column_comment1", + tags=[], + ), + ColumnInfo( + column_name="column2", + data_type="data_type2", + data_precision=20, + is_nullable=False, + comment="column_comment2", + tags=[], + ), + ], + ), + ] + + args = mock_cursor.execute.call_args_list[0].args + snapshot.assert_match(args[0], "list_tables.sql") + + # Exception handling + mock_connection = mock_sql_connection([], Exception("some error")) + tables = list_tables(mock_connection, "catalog1", "schema1") + assert tables == [] + + +def test_list_volumes( + test_root_dir: str, + snapshot: Snapshot, +): + mock_cursor = MagicMock() + + mock_connection = mock_sql_connection( + [ + [ + { + "volume_catalog": "catalog1", + "volume_schema": "schema1", + "volume_name": "volume1", + "volume_type": "MANAGED", + "volume_owner": "owner1", + "comment": "comment1", + "created": datetime(2000, 1, 1), + "created_by": "user1", + "last_altered": datetime(2001, 1, 1), + "last_altered_by": "user2", + "last_altered_by": "user2", + "storage_location": "location1", + "tags": [ + {"tag_name": "tag1", "tag_value": "value1"}, + {"tag_name": "tag2", "tag_value": "value2"}, + ], + }, + ] + ], + None, + mock_cursor, + ) + + volumes = list_volumes(mock_connection, "catalog1", "schema1") + + assert volumes == [ + VolumeInfo( + catalog_name="catalog1", + schema_name="schema1", + volume_name="volume1", + full_name="catalog1.schema1.volume1", + volume_type="MANAGED", + owner="owner1", + comment="comment1", + created_at=datetime(2000, 1, 1), + created_by="user1", + updated_at=datetime(2001, 1, 1), + updated_by="user2", + storage_location="location1", + tags=[ + Tag(key="tag1", value="value1"), + Tag(key="tag2", value="value2"), + ], + ), + ] + + args = mock_cursor.execute.call_args_list[0].args + snapshot.assert_match(args[0], "list_volumes.sql") + + # Exception handling + mock_connection = mock_sql_connection([], Exception("some error")) + volumes = list_volumes(mock_connection, "catalog1", "schema1") + assert volumes == [] + + +def test_list_volume_files( + test_root_dir: str, + snapshot: Snapshot, +): + mock_cursor = MagicMock() + + mock_connection = mock_sql_connection( + [ + [ + { + "path": "path1", + "name": "name1", + "size": 100, + "modification_time": datetime(2000, 1, 1), + }, + ] + ], + None, + mock_cursor, + ) + + volume_files = list_volume_files( + mock_connection, + VolumeInfo( + catalog_name="catalog1", + schema_name="schema1", + volume_name="volume1", + full_name="catalog1.schema1.volume1", + volume_type="MANAGED", + owner="owner1", + created_at=datetime(2000, 1, 1), + created_by="user1", + updated_at=datetime(2001, 1, 1), + updated_by="user2", + storage_location="location1", + tags=[], + ), + ) + + assert volume_files == [ + VolumeFileInfo( + path="path1", + name="name1", + size=100, + last_updated=datetime(2000, 1, 1), + ), + ] + + args = mock_cursor.execute.call_args_list[0].args + snapshot.assert_match(args[0], "list_volume_files.sql") + + +def test_list_table_lineage( + test_root_dir: str, + snapshot: Snapshot, +): + mock_cursor = MagicMock() + + mock_connection = mock_sql_connection( + [ + [ + ("c.s.t1", "c.s.t3"), + ("c.s.t2", "c.s.t3"), + ("c.s.t4", "c.s.t2"), + ] + ], + None, + mock_cursor, + ) + + table_lineage = list_table_lineage(mock_connection, "c", "s") + + assert table_lineage == { + "c.s.t3": TableLineage(upstream_tables=["c.s.t1", "c.s.t2"]), + "c.s.t2": TableLineage(upstream_tables=["c.s.t4"]), + } + + args = 
mock_cursor.execute.call_args_list[0].args + snapshot.assert_match(args[0], "list_table_lineage.sql") + + # Exception handling + mock_connection = mock_sql_connection([], Exception("some error")) + table_lineage = list_table_lineage(mock_connection, "c", "s") + assert table_lineage == {} + + +def test_list_column_lineage( + test_root_dir: str, + snapshot: Snapshot, +): + mock_cursor = MagicMock() + + mock_connection = mock_sql_connection( + [ + [ + ("c.s.t1", "c1", "c.s.t3", "ca"), + ("c.s.t1", "c2", "c.s.t3", "ca"), + ("c.s.t1", "c3", "c.s.t3", "cb"), + ("c.s.t2", "c4", "c.s.t3", "ca"), + ("c.s.t3", "c5", "c.s.t2", "cc"), + ] + ], + None, + mock_cursor, + ) + + column_lineage = list_column_lineage(mock_connection, "catalog", "schema") + + assert column_lineage == { + "c.s.t3": ColumnLineage( + upstream_columns={ + "ca": [ + Column(column_name="c1", table_name="c.s.t1"), + Column(column_name="c2", table_name="c.s.t1"), + Column(column_name="c4", table_name="c.s.t2"), + ], + "cb": [Column(column_name="c3", table_name="c.s.t1")], + } + ), + "c.s.t2": ColumnLineage( + upstream_columns={ + "cc": [Column(column_name="c5", table_name="c.s.t3")], + } + ), + } + + args = mock_cursor.execute.call_args_list[0].args + snapshot.assert_match(args[0], "list_column_lineage.sql") + + # Exception handling + mock_connection = mock_sql_connection([], Exception("some error")) + column_lineage = list_column_lineage(mock_connection, "c", "s") + assert column_lineage == {} + + +@freeze_time("2000-01-02") +def test_list_query_logs( + test_root_dir: str, + snapshot: Snapshot, +): + + mock_cursor = MagicMock() + mock_connection = mock_sql_connection(None, None, mock_cursor) + + list_query_logs(mock_connection, 1, ["user1", "user2"]) + + args = mock_cursor.execute.call_args_list[0].args + snapshot.assert_match(args[0], "list_query_log.sql") + assert args[1] == [ + datetime(2000, 1, 1, tzinfo=timezone.utc), + datetime(2000, 1, 2, tzinfo=timezone.utc), + ] + + # Exception handling + mock_connection = mock_sql_connection([], Exception("some error")) + query_logs = list_query_logs(mock_connection, 1, ["user1", "user2"]) + assert query_logs == [] + + +def test_get_last_refreshed_time( + test_root_dir: str, + snapshot: Snapshot, +): + + mock_cursor = MagicMock() + + mock_connection = mock_sql_connection( + [ + [ + { + "operation": "SET TBLPROPERTIES", + "timestamp": datetime(2020, 1, 4), + }, + { + "operation": "ADD CONSTRAINT", + "timestamp": datetime(2020, 1, 3), + }, + { + "operation": "CHANGE COLUMN", + "timestamp": datetime(2020, 1, 2), + }, + { + "operation": "WRITE", + "timestamp": datetime(2020, 1, 1), + }, + ] + ], + None, + mock_cursor, + ) + + result = get_last_refreshed_time(mock_connection, "db.schema.table", 50) + + assert result == ("db.schema.table", datetime(2020, 1, 1)) + + args = mock_cursor.execute.call_args_list[0].args + snapshot.assert_match(args[0], "describe_history.sql") + + # Exception handling + mock_connection = mock_sql_connection([], Exception("some error")) + result = get_last_refreshed_time(mock_connection, "db.schema.table", 50) + assert result is None + + +def test_get_table_properties( + test_root_dir: str, + snapshot: Snapshot, +): + + mock_cursor = MagicMock() + + mock_connection = mock_sql_connection( + [ + [ + { + "key": "key1", + "value": "value1", + }, + { + "key": "key2", + "value": "value2", + }, + ] + ], + None, + mock_cursor, + ) + + result = get_table_properties(mock_connection, "db.schema.table") + + assert result == ("db.schema.table", {"key1": "value1", "key2": "value2"}) 
+ + args = mock_cursor.execute.call_args_list[0].args + snapshot.assert_match(args[0], "show_table_properties.sql") + + # Exception handling + mock_connection = mock_sql_connection([], Exception("some error")) + result = get_table_properties(mock_connection, "db.schema.table") + assert result is None diff --git a/tests/unity_catalog/test_utils.py b/tests/unity_catalog/test_utils.py index 1136ade1..6ca52d45 100644 --- a/tests/unity_catalog/test_utils.py +++ b/tests/unity_catalog/test_utils.py @@ -1,9 +1,8 @@ -from datetime import datetime, timezone +from datetime import datetime from typing import Optional, Tuple from unittest.mock import MagicMock from databricks.sdk.service.iam import ServicePrincipal -from freezegun import freeze_time from pytest_snapshot.plugin import Snapshot from metaphor.common.entity_id import dataset_normalized_name, to_dataset_entity_id @@ -15,18 +14,20 @@ DatasetStructure, QueriedDataset, QueryLogs, + SystemTag, + SystemTags, + SystemTagSource, ) -from metaphor.unity_catalog.models import Column, ColumnLineage, TableLineage +from metaphor.unity_catalog.models import Tag from metaphor.unity_catalog.utils import ( batch_get_last_refreshed_time, + batch_get_table_properties, escape_special_characters, find_qualified_dataset, get_last_refreshed_time, get_query_logs, - list_column_lineage, - list_query_logs, list_service_principals, - list_table_lineage, + to_system_tags, ) from tests.test_utils import load_json from tests.unity_catalog.mocks import mock_connection_pool, mock_sql_connection @@ -179,78 +180,6 @@ def make_dataset(parts: Tuple[str, str, str]) -> Dataset: assert not found -def test_list_table_lineage(): - mock_connection = mock_sql_connection( - [ - [ - ("c.s.t1", "c.s.t3"), - ("c.s.t2", "c.s.t3"), - ("c.s.t4", "c.s.t2"), - ] - ] - ) - - table_lineage = list_table_lineage(mock_connection, "c", "s") - - assert table_lineage == { - "c.s.t3": TableLineage(upstream_tables=["c.s.t1", "c.s.t2"]), - "c.s.t2": TableLineage(upstream_tables=["c.s.t4"]), - } - - -def test_list_column_lineage(): - mock_connection = mock_sql_connection( - [ - [ - ("c.s.t1", "c1", "c.s.t3", "ca"), - ("c.s.t1", "c2", "c.s.t3", "ca"), - ("c.s.t1", "c3", "c.s.t3", "cb"), - ("c.s.t2", "c4", "c.s.t3", "ca"), - ("c.s.t3", "c5", "c.s.t2", "cc"), - ] - ] - ) - - column_lineage = list_column_lineage(mock_connection, "catalog", "schema") - - assert column_lineage == { - "c.s.t3": ColumnLineage( - upstream_columns={ - "ca": [ - Column(column_name="c1", table_name="c.s.t1"), - Column(column_name="c2", table_name="c.s.t1"), - Column(column_name="c4", table_name="c.s.t2"), - ], - "cb": [Column(column_name="c3", table_name="c.s.t1")], - } - ), - "c.s.t2": ColumnLineage( - upstream_columns={ - "cc": [Column(column_name="c5", table_name="c.s.t3")], - } - ), - } - - -@freeze_time("2000-01-02") -def test_list_query_logs( - test_root_dir: str, - snapshot: Snapshot, -): - - mock_cursor = MagicMock() - mock_connection = mock_sql_connection(None, None, mock_cursor) - - list_query_logs(mock_connection, 1, ["user1", "user2"]) - - args = mock_cursor.execute.call_args_list[0].args - snapshot.assert_match(args[0], "list_query_log.sql") - assert args[1] == [ - datetime(2000, 1, 1, tzinfo=timezone.utc), - datetime(2000, 1, 2, tzinfo=timezone.utc), - ] - - def test_get_last_refreshed_time( test_root_dir: str, snapshot: Snapshot, @@ -315,6 +244,30 @@ def test_batch_get_last_refreshed_time(): assert result_map == {"a.b.c": datetime(2020, 1, 1), "d.e.f": datetime(2020, 1, 1)} +def test_batch_get_table_properties(): + + 
connection_pool = mock_connection_pool( + [ + [ + { + "key": "prop1", + "value": "value1", + }, + ], + [ + { + "key": "prop2", + "value": "value2", + }, + ], + ], + ) + + result_map = batch_get_table_properties(connection_pool, ["a.b.c", "d.e.f"]) + + assert result_map == {"a.b.c": {"prop1": "value1"}, "d.e.f": {"prop2": "value2"}} + + def test_list_service_principals(): sp1 = ServicePrincipal(application_id="sp1", display_name="SP1") @@ -332,3 +285,22 @@ def test_list_service_principals(): def test_escape_special_characters(): assert escape_special_characters("this.is.a_table") == "this.is.a_table" assert escape_special_characters("this.is.also-a-table") == "`this.is.also-a-table`" + + +def test_to_system_tags(): + assert to_system_tags( + [Tag(key="tag", value="value"), Tag(key="tag2", value="")] + ) == SystemTags( + tags=[ + SystemTag( + key="tag", + value="value", + system_tag_source=SystemTagSource.UNITY_CATALOG, + ), + SystemTag( + key=None, + value="tag2", + system_tag_source=SystemTagSource.UNITY_CATALOG, + ), + ] + )
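
Reviewer note (not part of the patch): the sketch below shows how the new query helpers introduced in metaphor/unity_catalog/queries.py compose, mirroring the traversal the refactored extractor performs. It is illustrative only and assumes a live databricks-sql connection is already available as the connection argument; how create_connection builds it is outside this diff.

    from typing import Dict, List

    from metaphor.unity_catalog.models import TableInfo
    from metaphor.unity_catalog.queries import list_catalogs, list_schemas, list_tables


    def walk_metastore(connection) -> Dict[str, List[TableInfo]]:
        """Collect TableInfo models per schema using the system-table queries."""
        tables_by_schema: Dict[str, List[TableInfo]] = {}
        for catalog in list_catalogs(connection):
            for schema in list_schemas(connection, catalog.catalog_name):
                key = f"{catalog.catalog_name}.{schema.schema_name}"
                # Each TableInfo already carries owner, comment, tags, columns and
                # created/updated metadata, so no per-table SDK calls are needed
                # for this basic metadata.
                tables_by_schema[key] = list_tables(
                    connection, catalog.catalog_name, schema.schema_name
                )
        return tables_by_schema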
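
Similarly, batch_get_table_properties expects a queue.Queue of already-open connections and sizes its thread pool from Queue.maxsize. The snippet below is a usage sketch under the assumption that the caller has a list of connections at hand; in the extractor this pool comes from create_connection_pool, whose construction is not shown in this hunk.

    from queue import Queue
    from typing import Any, Dict, List

    from metaphor.unity_catalog.utils import batch_get_table_properties


    def fetch_properties(
        connections: List[Any], table_full_names: List[str]
    ) -> Dict[str, Dict[str, str]]:
        """Run SHOW TBLPROPERTIES for many tables over a small connection pool."""
        pool: Queue = Queue(maxsize=len(connections))
        for connection in connections:
            pool.put(connection)

        # The helper checks a connection out of the queue for each table and puts
        # it back when done, so the same pool can be reused for the next batch.
        return batch_get_table_properties(pool, table_full_names)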