diff --git a/metaphor/common/sql/dialect.py b/metaphor/common/sql/dialect.py index f2946ff9..b72e37c9 100644 --- a/metaphor/common/sql/dialect.py +++ b/metaphor/common/sql/dialect.py @@ -9,6 +9,7 @@ DataPlatform.REDSHIFT: "redshift", DataPlatform.SNOWFLAKE: "snowflake", DataPlatform.UNITY_CATALOG: "databricks", + DataPlatform.ORACLE: "oracle", } """ Mapping from data platforms supported by Metaphor to dialects recognized by SQLGlot. diff --git a/metaphor/common/sql/table_level_lineage/table.py b/metaphor/common/sql/table_level_lineage/table.py index 126943d3..8d0ac978 100644 --- a/metaphor/common/sql/table_level_lineage/table.py +++ b/metaphor/common/sql/table_level_lineage/table.py @@ -22,17 +22,21 @@ def from_sqlglot_table(cls, table: exp.Table): ) def to_queried_dataset( - self, platform: DataPlatform, account: Optional[str], database: Optional[str] + self, + platform: DataPlatform, + account: Optional[str], + default_database: Optional[str] = None, + default_schema: Optional[str] = None, ): + schema = self.schema or default_schema + database = self.db or default_database return QueriedDataset( - database=self.db, - schema=self.schema, + database=database, + schema=schema, table=self.table, id=str( to_dataset_entity_id( - dataset_normalized_name( - self.db or database, self.schema, self.table - ), + dataset_normalized_name(database, schema, self.table), platform, account, ) diff --git a/metaphor/common/sql/table_level_lineage/table_level_lineage.py b/metaphor/common/sql/table_level_lineage/table_level_lineage.py index 3a2ccb4c..b4c9f07d 100644 --- a/metaphor/common/sql/table_level_lineage/table_level_lineage.py +++ b/metaphor/common/sql/table_level_lineage/table_level_lineage.py @@ -134,7 +134,8 @@ def extract_table_level_lineage( account: Optional[str], statement_type: Optional[str] = None, query_id: Optional[str] = None, - database: Optional[str] = None, + default_database: Optional[str] = None, + default_schema: Optional[str] = None, ) -> Result: if statement_type and statement_type.upper() not in _VALID_STATEMENT_TYPES: @@ -153,11 +154,15 @@ def extract_table_level_lineage( try: return Result( targets=[ - target.to_queried_dataset(platform, account, database) + target.to_queried_dataset( + platform, account, default_database, default_schema + ) for target in _find_targets(expression) ], sources=[ - source.to_queried_dataset(platform, account, database) + source.to_queried_dataset( + platform, account, default_database, default_schema + ) for source in _find_sources(expression) ], ) diff --git a/metaphor/informatica/extractor.py b/metaphor/informatica/extractor.py index 0a8cfcbc..320cb242 100644 --- a/metaphor/informatica/extractor.py +++ b/metaphor/informatica/extractor.py @@ -186,7 +186,7 @@ def trans_source(source: MappingParameter) -> List[str]: sql=source.customQuery, platform=get_platform(connection), account=get_account(connection), - database=connection.database, + default_database=connection.database, ) return [dataset.id for dataset in result.sources] diff --git a/metaphor/oracle/extractor.py b/metaphor/oracle/extractor.py index 7a0f6634..5193f3d4 100644 --- a/metaphor/oracle/extractor.py +++ b/metaphor/oracle/extractor.py @@ -1,11 +1,14 @@ from typing import Collection, Iterator, List -from sqlalchemy import Inspector, text +from sqlalchemy import Connection, Inspector, text from metaphor.common.entity_id import dataset_normalized_name from metaphor.common.event_util import ENTITY_TYPES from metaphor.common.logger import get_logger -from metaphor.common.utils import md5_digest, start_of_day, to_utc_time +from metaphor.common.sql.table_level_lineage.table_level_lineage import ( + extract_table_level_lineage, +) +from metaphor.common.utils import md5_digest, safe_float, start_of_day, to_utc_time from metaphor.database.extractor import GenericDatabaseExtractor from metaphor.models.crawler_run_metadata import Platform from metaphor.models.metadata_change_event import ( @@ -175,6 +178,56 @@ def _extract_ddl(self, inspector: Inspector): assert dataset.schema and dataset.schema.sql_schema dataset.schema.sql_schema.table_schema = ddl + def _inner_fetch_query_logs( + self, sql: str, connection: Connection + ) -> List[QueryLog]: + cursor = connection.execute(text(sql)) + + rows = cursor.fetchall() + logs: List[QueryLog] = [] + for ( + user, + query, + start, + duration, + query_id, + read_bytes, + write_bytes, + row_count, + ) in rows: + schema = user.lower() if user else None + database = self._database if self._database else None + + ttl = extract_table_level_lineage( + query, + platform=DataPlatform.ORACLE, + account=self._alternative_host or self._config.host, + query_id=query_id, + default_database=database, + default_schema=schema, + ) + + logs.append( + QueryLog( + id=f"{DataPlatform.ORACLE.name}:{query_id}", + query_id=query_id, + platform=DataPlatform.ORACLE, + default_database=database, + default_schema=schema, + user_id=user, + sql=query, + sql_hash=md5_digest(query.encode("utf-8")), + duration=float(duration), + start_time=to_utc_time(start), + bytes_read=safe_float(read_bytes), + bytes_written=safe_float(write_bytes), + sources=ttl.sources, + rows_read=safe_float(row_count), + targets=ttl.targets, + ) + ) + return logs + def _extract_query_logs(self, inspector: Inspector, excluded_users: List[str]): start_time = start_of_day( daysAgo=self._query_logs_config.lookback_days @@ -183,33 +236,34 @@ def _extract_query_logs(self, inspector: Inspector, excluded_users: List[str]): users = [f"'{user.upper()}'" for user in excluded_users] with inspector.engine.connect() as connection: + offset = 0 result_limit = 1000 filters = f"""AND PARSING_SCHEMA_NAME not in ({','.join(users)})""" - sql = f""" - SELECT - PARSING_SCHEMA_NAME, - SQL_FULLTEXT AS query_text, - TO_TIMESTAMP(FIRST_LOAD_TIME, 'yy-MM-dd/HH24:MI:SS') AS start_time, - ELAPSED_TIME / 1000 AS duration, - SQL_ID - FROM gv$sql - WHERE OBJECT_STATUS = 'VALID' - {filters} - AND TO_TIMESTAMP(FIRST_LOAD_TIME, 'yy-MM-dd/HH24:MI:SS') >= TO_TIMESTAMP('{start_time}', 'yy-MM-dd HH24:MI:SS') - ORDER BY FIRST_LOAD_TIME DESC - OFFSET 0 ROWS FETCH NEXT {result_limit} ROWS ONLY - """ - - cursor = connection.execute(text(sql)) - for user, query, start, duration, query_id in cursor: - yield QueryLog( - id=f"{DataPlatform.ORACLE.name}:{query_id}", - query_id=query_id, - platform=DataPlatform.ORACLE, - user_id=user, - sql=query, - sql_hash=md5_digest(query.encode("utf-8")), - duration=float(duration), - start_time=to_utc_time(start), - ) + while True: + sql = f""" + SELECT + PARSING_SCHEMA_NAME, + SQL_FULLTEXT AS query_text, + TO_TIMESTAMP(FIRST_LOAD_TIME, 'yy-MM-dd/HH24:MI:SS') AS start_time, + ELAPSED_TIME / 1000 AS duration, + SQL_ID, + PHYSICAL_READ_BYTES, + PHYSICAL_WRITE_BYTES, + ROWS_PROCESSED + FROM gv$sql + WHERE OBJECT_STATUS = 'VALID' + {filters} + AND TO_TIMESTAMP(FIRST_LOAD_TIME, 'yy-MM-dd/HH24:MI:SS') >= TO_TIMESTAMP('{start_time}', 'yy-MM-dd HH24:MI:SS') + ORDER BY FIRST_LOAD_TIME DESC + OFFSET {offset} ROWS FETCH NEXT {offset + result_limit} ROWS ONLY + """ + logs = self._inner_fetch_query_logs(sql, connection) + for log in logs: + yield log + + logger.info(f"Fetched {len(logs)} query logs") + + if len(logs) < result_limit: + break + offset += result_limit diff --git a/tests/common/sql/table_level_lineage/test_select_dialect_specific.py b/tests/common/sql/table_level_lineage/test_select_dialect_specific.py index aa6b6b15..f6f98d01 100644 --- a/tests/common/sql/table_level_lineage/test_select_dialect_specific.py +++ b/tests/common/sql/table_level_lineage/test_select_dialect_specific.py @@ -102,3 +102,20 @@ def test_select_from_unnest_with_ordinality(platform: DataPlatform): CROSS JOIN UNNEST(numbers) WITH ORDINALITY AS t (n, a); """ assert_table_lineage_equal(sql, None, None, platform=platform) + + +@pytest.mark.parametrize("platform", [DataPlatform.ORACLE]) +def test_oracle_select_statement(platform: DataPlatform): + assert_table_lineage_equal( + "/* **** 60 */\nwith ss as (\n select\n i_item_id,sum(ss_ext_sales_price) total_sales\n from\n \tstore_sales,\n \tdate_dim,\n customer_address,\n item\n where\n i_item_id in (select\n i_item_id\nfrom\n item\nwhere i_category in ('Children'))\n and ss_item_sk = i_item_sk\n and ss_sold_date_sk = d_date_sk\n and d_year = 2000\n and d_moy = 8\n and ss_addr_sk = ca_address_sk\n and ca_gmt_offset = -7\n group by i_item_id),\n cs as (\n select\n i_item_id,sum(cs_ext_sales_price) total_sales\n from\n \tcatalog_sales,\n \tdate_dim,\n customer_address,\n item\n where\n i_item_id in (select\n i_item_id\nfrom\n item\nwhere i_category in ('Children'))\n and cs_item_sk = i_item_sk\n and cs_sold_date_sk = d_date_sk\n and d_year = 2000\n and d_moy = 8\n and cs_bill_addr_sk = ca_address_sk\n and ca_gmt_offset = -7\n group by i_item_id),\n ws as (\n select\n i_item_id,sum(ws_ext_sales_price) total_sales\n from\n \tweb_sales,\n \tdate_dim,\n customer_address,\n item\n where\n i_item_id in (select\n i_item_id\nfrom\n item\nwhere i_category in ('Children'))\n and ws_item_sk = i_item_sk\n and ws_sold_date_sk = d_date_sk\n and d_year = 2000\n and d_moy = 8\n and ws_bill_addr_sk = ca_address_sk\n and ca_gmt_offset = -7\n group by i_item_id)\n select * from ( select\n i_item_id\n,sum(total_sales) total_sales\n from (select * from ss\n union all\n select * from cs\n union all\n select * from ws) tmp1\n group by i_item_id\n order by i_item_id\n ,total_sales\n ) where rownum <= 100", + { + "catalog_sales", + "customer_address", + "date_dim", + "item", + "store_sales", + "web_sales", + }, + set(), + platform=platform, + ) diff --git a/tests/oracle/expected_query_logs.json b/tests/oracle/expected_query_logs.json index b50d848f..97a00dd6 100644 --- a/tests/oracle/expected_query_logs.json +++ b/tests/oracle/expected_query_logs.json @@ -4,9 +4,23 @@ "duration": 10.0, "platform": "ORACLE", "queryId": "sql-id", - "sql": "SELECT...", "startTime": "2024-07-30T15:31:33+00:00", - "sqlHash": "191df6d782898cbb739c413fa5868422", - "userId": "DEV" + "userId": "DEV", + "bytesRead": 1024.0, + "bytesWritten": 0.0, + "defaultDatabase": "db", + "defaultSchema": "dev", + "rowsRead": 20.0, + "sources": [ + { + "database": "db", + "id": "DATASET~FB036B29A2F624861C86A3D4237DF75F", + "schema": "dev", + "table": "TABLE_5566" + } + ], + "sql": "SELECT x, y FROM TABLE_5566", + "sqlHash": "a5bfd703435ab791044cc28af8999abf", + "targets": [] } ] diff --git a/tests/oracle/test_extractor.py b/tests/oracle/test_extractor.py index 7400000f..26e80902 100644 --- a/tests/oracle/test_extractor.py +++ b/tests/oracle/test_extractor.py @@ -119,18 +119,18 @@ def mock_connection(): [("mview",)], # extract_mviews_names, prod [("TABLE1", "DEV", "DDL ......")], # extract_ddl [("SYS",)], # get_system_users - ] - ) - - cursor.__iter__.return_value = iter( - [ - ( - "DEV", - "SELECT...", - datetime.fromtimestamp(1722353493, tz=timezone.utc), - 10.0, - "sql-id", - ) + [ + ( + "DEV", + "SELECT x, y FROM TABLE_5566", + datetime.fromtimestamp(1722353493, tz=timezone.utc), + 10.0, + "sql-id", + 1024.0, + 0.0, + 20, + ) + ], ] )