fix(ingest/sql-common): sql_common to use SqlParsingAggregator
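Replace the per-source SqlParsingBuilder and the file-backed view-definition cache in sql_common with the shared SqlParsingAggregator: table and view schemas are registered via aggregator.register_schema(), view definitions via aggregator.add_view_definition(), table-location lineage via aggregator.add_known_lineage_mapping(), and all lineage workunits are emitted from aggregator.gen_metadata().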
sagar-salvi-apptware committed Dec 24, 2024
1 parent eceb799 commit 4631f81
Showing 1 changed file with 36 additions and 100 deletions.
136 changes: 36 additions & 100 deletions metadata-ingestion/src/datahub/ingestion/source/sql/sql_common.py
Expand Up @@ -11,7 +11,6 @@
Dict,
Iterable,
List,
MutableMapping,
Optional,
Set,
Tuple,
@@ -36,7 +35,6 @@
     make_tag_urn,
 )
 from datahub.emitter.mcp import MetadataChangeProposalWrapper
-from datahub.emitter.sql_parsing_builder import SqlParsingBuilder
 from datahub.ingestion.api.common import PipelineContext
 from datahub.ingestion.api.decorators import capability
 from datahub.ingestion.api.incremental_lineage_helper import auto_incremental_lineage
@@ -79,7 +77,6 @@
     StatefulIngestionSourceBase,
 )
 from datahub.metadata.com.linkedin.pegasus2avro.common import StatusClass
-from datahub.metadata.com.linkedin.pegasus2avro.dataset import UpstreamLineage
 from datahub.metadata.com.linkedin.pegasus2avro.metadata.snapshot import DatasetSnapshot
 from datahub.metadata.com.linkedin.pegasus2avro.mxe import MetadataChangeEvent
 from datahub.metadata.com.linkedin.pegasus2avro.schema import (
@@ -106,17 +103,11 @@
     GlobalTagsClass,
     SubTypesClass,
     TagAssociationClass,
-    UpstreamClass,
     ViewPropertiesClass,
 )
 from datahub.sql_parsing.schema_resolver import SchemaResolver
-from datahub.sql_parsing.sqlglot_lineage import (
-    SqlParsingResult,
-    sqlglot_lineage,
-    view_definition_lineage_helper,
-)
+from datahub.sql_parsing.sql_parsing_aggregator import SqlParsingAggregator
 from datahub.telemetry import telemetry
-from datahub.utilities.file_backed_collections import FileBackedDict
 from datahub.utilities.registries.domain_registry import DomainRegistry
 from datahub.utilities.sqlalchemy_type_converter import (
     get_native_data_type_for_sqlalchemy_type,
@@ -353,11 +344,16 @@ def __init__(self, config: SQLCommonConfig, ctx: PipelineContext, platform: str)
             env=self.config.env,
         )
         self.discovered_datasets: Set[str] = set()
-        self._view_definition_cache: MutableMapping[str, str]
-        if self.config.use_file_backed_cache:
-            self._view_definition_cache = FileBackedDict[str]()
-        else:
-            self._view_definition_cache = {}
+        self.aggregator = SqlParsingAggregator(
+            platform=self.platform,
+            platform_instance=self.config.platform_instance,
+            env=self.config.env,
+            schema_resolver=self.schema_resolver,
+            graph=self.ctx.graph,
+            generate_lineage=self.include_lineage,
+            generate_usage_statistics=False,
+            generate_operations=False,
+        )
 
     @classmethod
     def test_connection(cls, config_dict: dict) -> TestConnectionReport:
@@ -572,36 +568,10 @@ def get_workunits_internal(self) -> Iterable[Union[MetadataWorkUnit, SqlWorkUnit]]:
                 profile_requests, profiler, platform=self.platform
             )
 
-        if self.config.include_view_lineage:
-            yield from self.get_view_lineage()
-
-    def get_view_lineage(self) -> Iterable[MetadataWorkUnit]:
-        builder = SqlParsingBuilder(
-            generate_lineage=True,
-            generate_usage_statistics=False,
-            generate_operations=False,
-        )
-        for dataset_name in self._view_definition_cache.keys():
-            # TODO: Ensure that the lineage generated from the view definition
-            # matches the dataset_name.
-            view_definition = self._view_definition_cache[dataset_name]
-            result = self._run_sql_parser(
-                dataset_name,
-                view_definition,
-                self.schema_resolver,
-            )
-            if result and result.out_tables:
-                # This does not yield any workunits but we use
-                # yield here to execute this method
-                yield from builder.process_sql_parsing_result(
-                    result=result,
-                    query=view_definition,
-                    is_view_ddl=True,
-                    include_column_lineage=self.config.include_view_column_lineage,
-                )
-            else:
-                self.views_failed_parsing.add(dataset_name)
-        yield from builder.gen_workunits()
+        # Generate workunit for aggregated SQL parsing results
+        for mcp in self.aggregator.gen_metadata():
+            self.report.report_workunit(mcp.as_workunit())
+            yield mcp.as_workunit()
 
     def get_identifier(
         self, *, schema: str, entity: str, inspector: Inspector, **kwargs: Any
@@ -760,16 +730,6 @@ def _process_table(
         )
         dataset_snapshot.aspects.append(dataset_properties)
 
-        if self.config.include_table_location_lineage and location_urn:
-            external_upstream_table = UpstreamClass(
-                dataset=location_urn,
-                type=DatasetLineageTypeClass.COPY,
-            )
-            yield MetadataChangeProposalWrapper(
-                entityUrn=dataset_snapshot.urn,
-                aspect=UpstreamLineage(upstreams=[external_upstream_table]),
-            ).as_workunit()
-
         extra_tags = self.get_extra_tags(inspector, schema, table)
         pk_constraints: dict = inspector.get_pk_constraint(table, schema)
         partitions: Optional[List[str]] = self.get_partitions(inspector, schema, table)
@@ -795,7 +755,7 @@ def _process_table(
 
         dataset_snapshot.aspects.append(schema_metadata)
         if self._save_schema_to_resolver():
-            self.schema_resolver.add_schema_metadata(dataset_urn, schema_metadata)
+            self.aggregator.register_schema(dataset_urn, schema_metadata)
             self.discovered_datasets.add(dataset_name)
         db_name = self.get_db_name(inspector)
 
@@ -815,6 +775,13 @@ def _process_table(
             ),
         )
 
+        if self.config.include_table_location_lineage and location_urn:
+            self.aggregator.add_known_lineage_mapping(
+                upstream_urn=location_urn,
+                downstream_urn=dataset_snapshot.urn,
+                lineage_type=DatasetLineageTypeClass.COPY,
+            )
+
         if self.config.domain:
             assert self.domain_registry
             yield from get_domain_wu(
@@ -1108,22 +1075,28 @@ def _process_view(
             canonical_schema=schema_fields,
         )
         if self._save_schema_to_resolver():
-            self.schema_resolver.add_schema_metadata(dataset_urn, schema_metadata)
+            self.aggregator.register_schema(dataset_urn, schema_metadata)
             self.discovered_datasets.add(dataset_name)
 
         description, properties, _ = self.get_table_properties(inspector, schema, view)
         properties["is_view"] = "True"
 
         view_definition = self._get_view_definition(inspector, schema, view)
         properties["view_definition"] = view_definition
+        db_name = self.get_db_name(inspector)
         if view_definition and self.config.include_view_lineage:
-            self._view_definition_cache[dataset_name] = view_definition
+            self.aggregator.add_view_definition(
+                view_urn=dataset_urn,
+                view_definition=view_definition,
+                default_db=db_name,
+                default_schema=schema,
+            )
 
         dataset_snapshot = DatasetSnapshot(
             urn=dataset_urn,
             aspects=[StatusClass(removed=False)],
         )
-        db_name = self.get_db_name(inspector)
+
         yield from self.add_table_to_schema_container(
             dataset_urn=dataset_urn,
             db_name=db_name,
@@ -1169,49 +1142,12 @@ def _save_schema_to_resolver(self):
             hasattr(self.config, "include_lineage") and self.config.include_lineage
         )
 
-    def _run_sql_parser(
-        self, view_identifier: str, query: str, schema_resolver: SchemaResolver
-    ) -> Optional[SqlParsingResult]:
-        try:
-            database, schema = self.get_db_schema(view_identifier)
-        except ValueError:
-            logger.warning(f"Invalid view identifier: {view_identifier}")
-            return None
-        raw_lineage = sqlglot_lineage(
-            query,
-            schema_resolver=schema_resolver,
-            default_db=database,
-            default_schema=schema,
-        )
-        view_urn = make_dataset_urn_with_platform_instance(
-            self.platform,
-            view_identifier,
-            self.config.platform_instance,
-            self.config.env,
-        )
-
-        if raw_lineage.debug_info.table_error:
-            logger.debug(
-                f"Failed to parse lineage for view {view_identifier}: "
-                f"{raw_lineage.debug_info.table_error}"
-            )
-            self.report.num_view_definitions_failed_parsing += 1
-            self.report.view_definitions_parsing_failures.append(
-                f"Table-level sql parsing error for view {view_identifier}: {raw_lineage.debug_info.table_error}"
-            )
-            return None
-
-        elif raw_lineage.debug_info.column_error:
-            self.report.num_view_definitions_failed_column_parsing += 1
-            self.report.view_definitions_parsing_failures.append(
-                f"Column-level sql parsing error for view {view_identifier}: {raw_lineage.debug_info.column_error}"
-            )
-        else:
-            self.report.num_view_definitions_parsed += 1
-            if raw_lineage.out_tables != [view_urn]:
-                self.report.num_view_definitions_view_urn_mismatch += 1
-        return view_definition_lineage_helper(raw_lineage, view_urn)
+    @property
+    def include_lineage(self):
+        return self.config.include_view_lineage or (
+            hasattr(self.config, "include_lineage") and self.config.include_lineage
+        )
 
     def get_db_schema(self, dataset_identifier: str) -> Tuple[Optional[str], str]:
         database, schema, _view = dataset_identifier.split(".", 2)
         return database, schema
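For reference, a minimal sketch of the aggregator flow this commit adopts, using only calls visible in the diff above. The platform, urn, and SQL are hypothetical, and the aggregator here runs standalone; the source itself additionally passes schema_resolver=self.schema_resolver and graph=self.ctx.graph in its constructor call.

from datahub.sql_parsing.sql_parsing_aggregator import SqlParsingAggregator

# Hypothetical standalone aggregator (no schema resolver or graph wired in).
aggregator = SqlParsingAggregator(
    platform="postgres",  # hypothetical platform
    env="PROD",
    generate_lineage=True,
    generate_usage_statistics=False,
    generate_operations=False,
)

# Register a view definition; the aggregator parses it and derives
# table- and column-level lineage for the view urn.
aggregator.add_view_definition(
    view_urn="urn:li:dataset:(urn:li:dataPlatform:postgres,mydb.public.my_view,PROD)",
    view_definition="CREATE VIEW my_view AS SELECT id, name FROM public.users",
    default_db="mydb",
    default_schema="public",
)

# Emit everything the aggregator accumulated, exactly as the new
# get_workunits_internal() does.
for mcp in aggregator.gen_metadata():
    print(mcp.as_workunit().id)

The design upside: parsing, failure reporting, and lineage generation move out of per-source plumbing (SqlParsingBuilder plus a view-definition cache) into one shared component that all SQLAlchemy-based sources can reuse.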