feat(ingest): add output schema inference for sql parser (#8989)
hsheth2 authored Oct 12, 2023
1 parent 245c5c0 commit 84bba4d
Showing 24 changed files with 604 additions and 148 deletions.
119 changes: 107 additions & 12 deletions metadata-ingestion/src/datahub/utilities/sqlglot_lineage.py
@@ -5,12 +5,13 @@
import logging
import pathlib
from collections import defaultdict
from typing import Dict, List, Optional, Set, Tuple, Union
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import pydantic.dataclasses
import sqlglot
import sqlglot.errors
import sqlglot.lineage
import sqlglot.optimizer.annotate_types
import sqlglot.optimizer.qualify
import sqlglot.optimizer.qualify_columns
from pydantic import BaseModel
@@ -23,7 +24,17 @@
from datahub.ingestion.api.closeable import Closeable
from datahub.ingestion.graph.client import DataHubGraph
from datahub.ingestion.source.bigquery_v2.bigquery_audit import BigqueryTableIdentifier
from datahub.metadata.schema_classes import OperationTypeClass, SchemaMetadataClass
from datahub.metadata.schema_classes import (
ArrayTypeClass,
BooleanTypeClass,
DateTypeClass,
NumberTypeClass,
OperationTypeClass,
SchemaFieldDataTypeClass,
SchemaMetadataClass,
StringTypeClass,
TimeTypeClass,
)
from datahub.utilities.file_backed_collections import ConnectionWrapper, FileBackedDict
from datahub.utilities.urns.dataset_urn import DatasetUrn

@@ -90,8 +101,18 @@ def get_query_type_of_sql(expression: sqlglot.exp.Expression) -> QueryType:
return QueryType.UNKNOWN


class _ParserBaseModel(
BaseModel,
arbitrary_types_allowed=True,
json_encoders={
SchemaFieldDataTypeClass: lambda v: v.to_obj(),
},
):
pass
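
The `json_encoders` hook is what makes these models JSON-serializable: `SchemaFieldDataTypeClass` is a generated Avro class rather than a pydantic model, so without a custom encoder `.json()` would fail on it, and `arbitrary_types_allowed` is what permits it as a field type at all. A minimal sketch of the effect (the `_Example` model below is hypothetical, not part of this commit):

```python
from typing import Optional

from datahub.metadata.schema_classes import (
    SchemaFieldDataTypeClass,
    StringTypeClass,
)


class _Example(_ParserBaseModel):
    column_type: Optional[SchemaFieldDataTypeClass]


ref = _Example(column_type=SchemaFieldDataTypeClass(type=StringTypeClass()))
# The encoder delegates to to_obj(), producing the same union encoding that
# appears in the golden files further down:
# {"column_type": {"type": {"com.linkedin.pegasus2avro.schema.StringType": {}}}}
print(ref.json())
```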


@functools.total_ordering
class _FrozenModel(BaseModel, frozen=True):
class _FrozenModel(_ParserBaseModel, frozen=True):
def __lt__(self, other: "_FrozenModel") -> bool:
for field in self.__fields__:
self_v = getattr(self, field)
@@ -146,37 +167,50 @@ class _ColumnRef(_FrozenModel):
column: str


class ColumnRef(BaseModel):
class ColumnRef(_ParserBaseModel):
table: Urn
column: str


class _DownstreamColumnRef(BaseModel):
class _DownstreamColumnRef(_ParserBaseModel):
table: Optional[_TableName]
column: str
column_type: Optional[sqlglot.exp.DataType]


class DownstreamColumnRef(BaseModel):
class DownstreamColumnRef(_ParserBaseModel):
table: Optional[Urn]
column: str
column_type: Optional[SchemaFieldDataTypeClass]
native_column_type: Optional[str]

@pydantic.validator("column_type", pre=True)
def _load_column_type(
cls, v: Optional[Union[dict, SchemaFieldDataTypeClass]]
) -> Optional[SchemaFieldDataTypeClass]:
if v is None:
return None
if isinstance(v, SchemaFieldDataTypeClass):
return v
return SchemaFieldDataTypeClass.from_obj(v)
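
The `pre=True` validator handles the opposite direction: when a serialized result is parsed back (e.g. out of a cache), `column_type` arrives as a plain dict and is rebuilt via `from_obj`. A round-trip sketch, assuming pydantic v1 semantics (`parse_raw`) and a payload in the shape the encoder above emits:

```python
# Hypothetical serialized DownstreamColumnRef, not taken from the commit.
serialized = (
    '{"table": null, "column": "col1",'
    ' "column_type": {"type": {"com.linkedin.pegasus2avro.schema.StringType": {}}},'
    ' "native_column_type": "TEXT"}'
)
restored = DownstreamColumnRef.parse_raw(serialized)
assert isinstance(restored.column_type, SchemaFieldDataTypeClass)
```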


class _ColumnLineageInfo(BaseModel):
class _ColumnLineageInfo(_ParserBaseModel):
downstream: _DownstreamColumnRef
upstreams: List[_ColumnRef]

logic: Optional[str]


class ColumnLineageInfo(BaseModel):
class ColumnLineageInfo(_ParserBaseModel):
downstream: DownstreamColumnRef
upstreams: List[ColumnRef]

# Logic for this column, as a SQL expression.
logic: Optional[str] = pydantic.Field(default=None, exclude=True)


class SqlParsingDebugInfo(BaseModel, arbitrary_types_allowed=True):
class SqlParsingDebugInfo(_ParserBaseModel):
confidence: float = 0.0

tables_discovered: int = 0
@@ -190,7 +224,7 @@ def error(self) -> Optional[Exception]:
return self.table_error or self.column_error


class SqlParsingResult(BaseModel):
class SqlParsingResult(_ParserBaseModel):
query_type: QueryType = QueryType.UNKNOWN

in_tables: List[Urn]
@@ -541,6 +575,15 @@ def _schema_aware_fuzzy_column_resolve(
) from e
logger.debug("Qualified sql %s", statement.sql(pretty=True, dialect=dialect))

# Try to figure out the types of the output columns.
try:
statement = sqlglot.optimizer.annotate_types.annotate_types(
statement, schema=sqlglot_db_schema
)
except sqlglot.errors.OptimizeError as e:
# This is not a fatal error, so we can continue.
logger.debug("sqlglot failed to annotate types: %s", e)

column_lineage = []

try:
@@ -553,7 +596,6 @@ def _schema_aware_fuzzy_column_resolve(
logger.debug("output columns: %s", [col[0] for col in output_columns])
output_col: str
for output_col, original_col_expression in output_columns:
# print(f"output column: {output_col}")
if output_col == "*":
# If schema information is available, the * will be expanded to the actual columns.
# Otherwise, we can't process it.
@@ -613,12 +655,19 @@ def _schema_aware_fuzzy_column_resolve(

output_col = _schema_aware_fuzzy_column_resolve(output_table, output_col)

# Guess the output column type.
output_col_type = None
if original_col_expression.type:
output_col_type = original_col_expression.type

if not direct_col_upstreams:
logger.debug(f' "{output_col}" has no upstreams')
column_lineage.append(
_ColumnLineageInfo(
downstream=_DownstreamColumnRef(
table=output_table, column=output_col
table=output_table,
column=output_col,
column_type=output_col_type,
),
upstreams=sorted(direct_col_upstreams),
# logic=column_logic.sql(pretty=True, dialect=dialect),
@@ -673,6 +722,42 @@ def _try_extract_select(
return statement


def _translate_sqlglot_type(
sqlglot_type: sqlglot.exp.DataType.Type,
) -> Optional[SchemaFieldDataTypeClass]:
TypeClass: Any
if sqlglot_type in sqlglot.exp.DataType.TEXT_TYPES:
TypeClass = StringTypeClass
elif sqlglot_type in sqlglot.exp.DataType.NUMERIC_TYPES or sqlglot_type in {
sqlglot.exp.DataType.Type.DECIMAL,
}:
TypeClass = NumberTypeClass
elif sqlglot_type in {
sqlglot.exp.DataType.Type.BOOLEAN,
sqlglot.exp.DataType.Type.BIT,
}:
TypeClass = BooleanTypeClass
elif sqlglot_type in {
sqlglot.exp.DataType.Type.DATE,
}:
TypeClass = DateTypeClass
elif sqlglot_type in sqlglot.exp.DataType.TEMPORAL_TYPES:
TypeClass = TimeTypeClass
elif sqlglot_type in {
sqlglot.exp.DataType.Type.ARRAY,
}:
TypeClass = ArrayTypeClass
elif sqlglot_type in {
sqlglot.exp.DataType.Type.UNKNOWN,
}:
return None
else:
logger.debug("Unknown sqlglot type: %s", sqlglot_type)
return None

return SchemaFieldDataTypeClass(type=TypeClass())
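
The mapping is deliberately coarse: sqlglot's fine-grained types collapse into DataHub's generic type classes, and anything unrecognized degrades to `None` rather than failing. Illustrative calls (not from the commit):

```python
field_type = _translate_sqlglot_type(sqlglot.exp.DataType.Type.VARCHAR)
assert field_type is not None and isinstance(field_type.type, StringTypeClass)

# DATE is checked before the broader TEMPORAL_TYPES bucket, so timestamps
# map to TimeTypeClass while plain dates keep DateTypeClass.
field_type = _translate_sqlglot_type(sqlglot.exp.DataType.Type.TIMESTAMP)
assert field_type is not None and isinstance(field_type.type, TimeTypeClass)

assert _translate_sqlglot_type(sqlglot.exp.DataType.Type.UNKNOWN) is None
```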


def _translate_internal_column_lineage(
table_name_urn_mapping: Dict[_TableName, str],
raw_column_lineage: _ColumnLineageInfo,
@@ -684,6 +769,16 @@ def _translate_internal_column_lineage(
downstream=DownstreamColumnRef(
table=downstream_urn,
column=raw_column_lineage.downstream.column,
column_type=_translate_sqlglot_type(
raw_column_lineage.downstream.column_type.this
)
if raw_column_lineage.downstream.column_type
else None,
native_column_type=raw_column_lineage.downstream.column_type.sql()
if raw_column_lineage.downstream.column_type
and raw_column_lineage.downstream.column_type.this
!= sqlglot.exp.DataType.Type.UNKNOWN
else None,
),
upstreams=[
ColumnRef(
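Putting the pieces together: after qualification, `annotate_types` stamps each projection with an inferred `sqlglot.exp.DataType`, which the translation layer above then splits into a generic DataHub type (via `_translate_sqlglot_type` on `.this`) and a native type string (via `.sql()`). A standalone sketch of that flow using sqlglot directly, with a made-up table and schema:

```python
import sqlglot
import sqlglot.optimizer.annotate_types
import sqlglot.optimizer.qualify

schema = {"tbl": {"col1": "varchar", "col2": "bigint"}}

statement = sqlglot.parse_one("SELECT col1, col2 + 1 AS col3 FROM tbl")
statement = sqlglot.optimizer.qualify.qualify(statement, schema=schema)
statement = sqlglot.optimizer.annotate_types.annotate_types(statement, schema=schema)

for expr in statement.selects:
    # expr.type is a sqlglot.exp.DataType; .this is the DataType.Type enum
    # fed to _translate_sqlglot_type, and .sql() is the native type string
    # stored in native_column_type (e.g. "VARCHAR", "BIGINT").
    print(expr.alias_or_name, expr.type.this, expr.type.sql())
```
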
93 changes: 21 additions & 72 deletions metadata-ingestion/tests/integration/powerbi/test_m_parser.py
@@ -17,7 +17,6 @@
)
from datahub.ingestion.source.powerbi.m_query import parser, resolver, tree_function
from datahub.ingestion.source.powerbi.m_query.resolver import DataPlatformTable, Lineage
from datahub.utilities.sqlglot_lineage import ColumnLineageInfo, DownstreamColumnRef

pytestmark = pytest.mark.integration_batch_2

@@ -742,75 +741,25 @@ def test_sqlglot_parser():
== "urn:li:dataset:(urn:li:dataPlatform:snowflake,sales_deployment.operations_analytics.transformed_prod.v_sme_unit_targets,PROD)"
)

assert lineage[0].column_lineage == [
ColumnLineageInfo(
downstream=DownstreamColumnRef(table=None, column="client_director"),
upstreams=[],
logic=None,
),
ColumnLineageInfo(
downstream=DownstreamColumnRef(table=None, column="tier"),
upstreams=[],
logic=None,
),
ColumnLineageInfo(
downstream=DownstreamColumnRef(table=None, column='upper("manager")'),
upstreams=[],
logic=None,
),
ColumnLineageInfo(
downstream=DownstreamColumnRef(table=None, column="team_type"),
upstreams=[],
logic=None,
),
ColumnLineageInfo(
downstream=DownstreamColumnRef(table=None, column="date_target"),
upstreams=[],
logic=None,
),
ColumnLineageInfo(
downstream=DownstreamColumnRef(table=None, column="monthid"),
upstreams=[],
logic=None,
),
ColumnLineageInfo(
downstream=DownstreamColumnRef(table=None, column="target_team"),
upstreams=[],
logic=None,
),
ColumnLineageInfo(
downstream=DownstreamColumnRef(table=None, column="seller_email"),
upstreams=[],
logic=None,
),
ColumnLineageInfo(
downstream=DownstreamColumnRef(table=None, column="agent_key"),
upstreams=[],
logic=None,
),
ColumnLineageInfo(
downstream=DownstreamColumnRef(table=None, column="sme_quota"),
upstreams=[],
logic=None,
),
ColumnLineageInfo(
downstream=DownstreamColumnRef(table=None, column="revenue_quota"),
upstreams=[],
logic=None,
),
ColumnLineageInfo(
downstream=DownstreamColumnRef(table=None, column="service_quota"),
upstreams=[],
logic=None,
),
ColumnLineageInfo(
downstream=DownstreamColumnRef(table=None, column="bl_target"),
upstreams=[],
logic=None,
),
ColumnLineageInfo(
downstream=DownstreamColumnRef(table=None, column="software_quota"),
upstreams=[],
logic=None,
),
# TODO: None of these columns have upstreams?
# That doesn't seem right - we probably need to add fake schemas for the two tables above.
cols = [
"client_director",
"tier",
'upper("manager")',
"team_type",
"date_target",
"monthid",
"target_team",
"seller_email",
"agent_key",
"sme_quota",
"revenue_quota",
"service_quota",
"bl_target",
"software_quota",
]
for i, column in enumerate(cols):
assert lineage[0].column_lineage[i].downstream.table is None
assert lineage[0].column_lineage[i].downstream.column == column
assert lineage[0].column_lineage[i].upstreams == []
@@ -12,7 +12,13 @@
{
"downstream": {
"table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,my-proj-2.dataset.my_view,PROD)",
"column": "col5"
"column": "col5",
"column_type": {
"type": {
"com.linkedin.pegasus2avro.schema.StringType": {}
}
},
"native_column_type": "TEXT"
},
"upstreams": [
{
@@ -24,7 +30,13 @@
{
"downstream": {
"table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,my-proj-2.dataset.my_view,PROD)",
"column": "col1"
"column": "col1",
"column_type": {
"type": {
"com.linkedin.pegasus2avro.schema.StringType": {}
}
},
"native_column_type": "TEXT"
},
"upstreams": [
{
@@ -36,7 +48,13 @@
{
"downstream": {
"table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,my-proj-2.dataset.my_view,PROD)",
"column": "col2"
"column": "col2",
"column_type": {
"type": {
"com.linkedin.pegasus2avro.schema.StringType": {}
}
},
"native_column_type": "TEXT"
},
"upstreams": [
{
@@ -48,7 +66,13 @@
{
"downstream": {
"table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,my-proj-2.dataset.my_view,PROD)",
"column": "col3"
"column": "col3",
"column_type": {
"type": {
"com.linkedin.pegasus2avro.schema.StringType": {}
}
},
"native_column_type": "TEXT"
},
"upstreams": [
{