diff --git a/metadata-ingestion/src/datahub/ingestion/graph/client.py b/metadata-ingestion/src/datahub/ingestion/graph/client.py
index d91165ac9777c..5c24b06dde999 100644
--- a/metadata-ingestion/src/datahub/ingestion/graph/client.py
+++ b/metadata-ingestion/src/datahub/ingestion/graph/client.py
@@ -787,9 +787,11 @@ def get_aspect_counts(self, aspect: str, urn_like: Optional[str] = None) -> int:
     def execute_graphql(self, query: str, variables: Optional[Dict] = None) -> Dict:
         url = f"{self.config.server}/api/graphql"
+
         body: Dict = {
             "query": query,
         }
+
         if variables:
             body["variables"] = variables
@@ -1065,6 +1067,28 @@ def parse_sql_lineage(
             default_schema=default_schema,
         )
 
+    def create_tag(self, tag_name: str) -> str:
+        graph_query: str = """
+            mutation($tag_detail: CreateTagInput!) {
+                createTag(input: $tag_detail)
+            }
+        """
+
+        variables = {
+            "tag_detail": {
+                "name": tag_name,
+                "id": tag_name,
+            },
+        }
+
+        res = self.execute_graphql(
+            query=graph_query,
+            variables=variables,
+        )
+
+        # return urn
+        return res["createTag"]
+
     def close(self) -> None:
         self._make_schema_resolver.cache_clear()
         super().close()
diff --git a/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_tags.py b/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_tags.py
index 5a276ad899c48..72a8c226e491e 100644
--- a/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_tags.py
+++ b/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_tags.py
@@ -1,14 +1,24 @@
+import logging
 from typing import Callable, List, Optional, cast
 
+import datahub.emitter.mce_builder as builder
 from datahub.configuration.common import (
     KeyValuePattern,
     TransformerSemanticsConfigModel,
 )
 from datahub.configuration.import_resolver import pydantic_resolve_key
 from datahub.emitter.mce_builder import Aspect
+from datahub.emitter.mcp import MetadataChangeProposalWrapper
 from datahub.ingestion.api.common import PipelineContext
 from datahub.ingestion.transformer.dataset_transformer import DatasetTagsTransformer
-from datahub.metadata.schema_classes import GlobalTagsClass, TagAssociationClass
+from datahub.metadata.schema_classes import (
+    GlobalTagsClass,
+    TagAssociationClass,
+    TagKeyClass,
+)
+from datahub.utilities.urns.tag_urn import TagUrn
+
+logger = logging.getLogger(__name__)
 
 
 class AddDatasetTagsConfig(TransformerSemanticsConfigModel):
@@ -22,11 +32,13 @@ class AddDatasetTags(DatasetTagsTransformer):
 
     ctx: PipelineContext
     config: AddDatasetTagsConfig
+    processed_tags: List[TagAssociationClass]
 
     def __init__(self, config: AddDatasetTagsConfig, ctx: PipelineContext):
         super().__init__()
         self.ctx = ctx
         self.config = config
+        self.processed_tags = []
 
     @classmethod
     def create(cls, config_dict: dict, ctx: PipelineContext) -> "AddDatasetTags":
@@ -45,11 +57,38 @@ def transform_aspect(
         tags_to_add = self.config.get_tags_to_add(entity_urn)
         if tags_to_add is not None:
             out_global_tags_aspect.tags.extend(tags_to_add)
+            self.processed_tags.extend(
+                tags_to_add
+            )  # Keep track of tags added so that we can create them in handle_end_of_stream
 
         return self.get_result_semantics(
             self.config, self.ctx.graph, entity_urn, out_global_tags_aspect
         )
 
+    def handle_end_of_stream(self) -> List[MetadataChangeProposalWrapper]:
+
+        mcps: List[MetadataChangeProposalWrapper] = []
+
+        logger.debug("Generating tags")
+
+        for tag_association in self.processed_tags:
+            ids: List[str] = TagUrn.create_from_string(
+                tag_association.tag
+            ).get_entity_id()
+
+            assert len(ids) == 1, "Invalid Tag Urn"
+
+            tag_name: str = ids[0]
+
+            mcps.append(
+                MetadataChangeProposalWrapper(
+                    entityUrn=builder.make_tag_urn(tag=tag_name),
+                    aspect=TagKeyClass(name=tag_name),
+                )
+            )
+
+        return mcps
+
 
 class SimpleDatasetTagConfig(TransformerSemanticsConfigModel):
     tag_urns: List[str]
@@ -82,6 +121,7 @@ class PatternAddDatasetTags(AddDatasetTags):
     """Transformer that adds a specified set of tags to each dataset."""
 
     def __init__(self, config: PatternDatasetTagsConfig, ctx: PipelineContext):
+        config.tag_pattern.all
         tag_pattern = config.tag_pattern
         generic_config = AddDatasetTagsConfig(
             get_tags_to_add=lambda _: [
diff --git a/metadata-ingestion/src/datahub/ingestion/transformer/base_transformer.py b/metadata-ingestion/src/datahub/ingestion/transformer/base_transformer.py
index e0d6ae720c9a1..8b6f42dcfba4b 100644
--- a/metadata-ingestion/src/datahub/ingestion/transformer/base_transformer.py
+++ b/metadata-ingestion/src/datahub/ingestion/transformer/base_transformer.py
@@ -17,13 +17,30 @@
 log = logging.getLogger(__name__)
 
 
-class LegacyMCETransformer(Transformer, metaclass=ABCMeta):
+def _update_work_unit_id(
+    envelope: RecordEnvelope, urn: str, aspect_name: str
+) -> Dict[Any, Any]:
+    structured_urn = Urn.create_from_string(urn)
+    simple_name = "-".join(structured_urn.get_entity_id())
+    record_metadata = envelope.metadata.copy()
+    record_metadata.update({"workunit_id": f"txform-{simple_name}-{aspect_name}"})
+    return record_metadata
+
+
+class HandleEndOfStreamTransformer:
+    def handle_end_of_stream(self) -> List[MetadataChangeProposalWrapper]:
+        return []
+
+
+class LegacyMCETransformer(
+    Transformer, HandleEndOfStreamTransformer, metaclass=ABCMeta
+):
     @abstractmethod
     def transform_one(self, mce: MetadataChangeEventClass) -> MetadataChangeEventClass:
         pass
 
 
-class SingleAspectTransformer(metaclass=ABCMeta):
+class SingleAspectTransformer(HandleEndOfStreamTransformer, metaclass=ABCMeta):
     @abstractmethod
     def aspect_name(self) -> str:
         """Implement this method to specify a single aspect that the transformer is interested in subscribing to. No default provided."""
@@ -180,6 +197,32 @@ def _transform_or_record_mcpw(
             self._record_mcp(envelope.record)
         return envelope if envelope.record.aspect is not None else None
 
+    def _handle_end_of_stream(
+        self, envelope: RecordEnvelope
+    ) -> Iterable[RecordEnvelope]:
+
+        if not isinstance(self, SingleAspectTransformer) and not isinstance(
+            self, LegacyMCETransformer
+        ):
+            return
+
+        mcps: List[MetadataChangeProposalWrapper] = self.handle_end_of_stream()
+
+        for mcp in mcps:
+            if mcp.aspect is None or mcp.entityUrn is None:  # to silence the lint error
+                continue
+
+            record_metadata = _update_work_unit_id(
+                envelope=envelope,
+                aspect_name=mcp.aspect.get_aspect_name(),  # type: ignore
+                urn=mcp.entityUrn,
+            )
+
+            yield RecordEnvelope(
+                record=mcp,
+                metadata=record_metadata,
+            )
+
     def transform(
         self, record_envelopes: Iterable[RecordEnvelope]
     ) -> Iterable[RecordEnvelope]:
@@ -216,17 +259,10 @@ def transform(
                             else None,
                         )
                         if transformed_aspect:
-                            # for end of stream records, we modify the workunit-id
                             structured_urn = Urn.create_from_string(urn)
-                            simple_name = "-".join(structured_urn.get_entity_id())
-                            record_metadata = envelope.metadata.copy()
-                            record_metadata.update(
-                                {
-                                    "workunit_id": f"txform-{simple_name}-{self.aspect_name()}"
-                                }
-                            )
-                            yield RecordEnvelope(
-                                record=MetadataChangeProposalWrapper(
+
+                            mcp: MetadataChangeProposalWrapper = (
+                                MetadataChangeProposalWrapper(
                                     entityUrn=urn,
                                     entityType=structured_urn.get_type(),
                                     systemMetadata=last_seen_mcp.systemMetadata
@@ -234,8 +270,21 @@
                                     else last_seen_mce_system_metadata,
                                     aspectName=self.aspect_name(),
                                     aspect=transformed_aspect,
-                                ),
+                                )
+                            )
+
+                            record_metadata = _update_work_unit_id(
+                                envelope=envelope,
+                                aspect_name=mcp.aspect.get_aspect_name(),  # type: ignore
+                                urn=mcp.entityUrn,
+                            )
+
+                            yield RecordEnvelope(
+                                record=mcp,
                                 metadata=record_metadata,
                             )
                     self._mark_processed(urn)
+
+                yield from self._handle_end_of_stream(envelope=envelope)
+
             yield envelope
diff --git a/metadata-ingestion/tests/unit/test_transform_dataset.py b/metadata-ingestion/tests/unit/test_transform_dataset.py
index 8014df2f5c519..546549dcf37a4 100644
--- a/metadata-ingestion/tests/unit/test_transform_dataset.py
+++ b/metadata-ingestion/tests/unit/test_transform_dataset.py
@@ -813,13 +813,25 @@ def test_simple_dataset_tags_transformation(mock_time):
             ]
         )
     )
-    assert len(outputs) == 3
+
+    assert len(outputs) == 5
 
     # Check that tags were added.
     tags_aspect = outputs[1].record.aspect
+    assert tags_aspect.tags[0].tag == builder.make_tag_urn("NeedsDocumentation")
     assert tags_aspect
    assert len(tags_aspect.tags) == 2
-    assert tags_aspect.tags[0].tag == builder.make_tag_urn("NeedsDocumentation")
+
+    # Check that the new tag entities are present
+    assert outputs[2].record.aspectName == "tagKey"
+    assert outputs[2].record.aspect.name == "NeedsDocumentation"
+    assert outputs[2].record.entityUrn == builder.make_tag_urn("NeedsDocumentation")
+
+    assert outputs[3].record.aspectName == "tagKey"
+    assert outputs[3].record.aspect.name == "Legacy"
+    assert outputs[3].record.entityUrn == builder.make_tag_urn("Legacy")
+
+    assert isinstance(outputs[4].record, EndOfStream)
 
 
 def dummy_tag_resolver_method(dataset_snapshot):
@@ -853,7 +865,7 @@ def test_pattern_dataset_tags_transformation(mock_time):
         )
     )
 
-    assert len(outputs) == 3
+    assert len(outputs) == 5
     tags_aspect = outputs[1].record.aspect
     assert tags_aspect
     assert len(tags_aspect.tags) == 2
@@ -1363,7 +1375,7 @@ def test_mcp_add_tags_missing(mock_time):
     ]
     input_stream.append(RecordEnvelope(record=EndOfStream(), metadata={}))
     outputs = list(transformer.transform(input_stream))
-    assert len(outputs) == 3
+    assert len(outputs) == 5
     assert outputs[0].record == dataset_mcp
     # Check that tags were added, this will be the second result
     tags_aspect = outputs[1].record.aspect
@@ -1395,13 +1407,23 @@ def test_mcp_add_tags_existing(mock_time):
     ]
     input_stream.append(RecordEnvelope(record=EndOfStream(), metadata={}))
     outputs = list(transformer.transform(input_stream))
-    assert len(outputs) == 2
+
+    assert len(outputs) == 4
+
     # Check that tags were added, this will be the second result
     tags_aspect = outputs[0].record.aspect
     assert tags_aspect
     assert len(tags_aspect.tags) == 3
     assert tags_aspect.tags[0].tag == builder.make_tag_urn("Test")
     assert tags_aspect.tags[1].tag == builder.make_tag_urn("NeedsDocumentation")
+    assert tags_aspect.tags[2].tag == builder.make_tag_urn("Legacy")
+
+    # Check that tag entities were added
+    assert outputs[1].record.entityType == "tag"
+    assert outputs[1].record.entityUrn == builder.make_tag_urn("NeedsDocumentation")
+    assert outputs[2].record.entityType == "tag"
+    assert outputs[2].record.entityUrn == builder.make_tag_urn("Legacy")
+    assert isinstance(outputs[-1].record, EndOfStream)
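For reference, the new DataHubGraph.create_tag helper added in client.py can be exercised directly against a running DataHub instance. The snippet below is a minimal sketch, not part of the patch; the server address is a placeholder.

# Minimal sketch (not part of the patch): calling the new DataHubGraph.create_tag.
# The server address below is a placeholder for a running DataHub GMS endpoint.
from datahub.ingestion.graph.client import DatahubClientConfig, DataHubGraph

graph = DataHubGraph(DatahubClientConfig(server="http://localhost:8080"))

# create_tag issues the createTag GraphQL mutation defined above and returns
# the urn of the created tag, e.g. "urn:li:tag:NeedsDocumentation".
tag_urn = graph.create_tag(tag_name="NeedsDocumentation")
print(tag_urn)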
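The end-of-stream contract introduced in base_transformer.py works the same way for any transformer: override handle_end_of_stream() and return MCPs, and _handle_end_of_stream() wraps them in RecordEnvelopes (with a rewritten workunit_id) before the EndOfStream record is re-emitted. The sketch below mirrors what AddDatasetTags does in this patch; the EmitPendingTags class and its pending_tag_names attribute are hypothetical, not part of the codebase.

# Sketch only: a hypothetical transformer mixin usage of the new end-of-stream hook.
from typing import List

import datahub.emitter.mce_builder as builder
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.transformer.base_transformer import HandleEndOfStreamTransformer
from datahub.metadata.schema_classes import TagKeyClass


class EmitPendingTags(HandleEndOfStreamTransformer):
    """Collects tag names during a run and materializes them as tag entities
    once the stream ends (hypothetical example class)."""

    def __init__(self) -> None:
        self.pending_tag_names: List[str] = []

    def handle_end_of_stream(self) -> List[MetadataChangeProposalWrapper]:
        # One tagKey aspect per pending tag; the base transformer turns each MCP
        # into a RecordEnvelope just before yielding the EndOfStream envelope.
        return [
            MetadataChangeProposalWrapper(
                entityUrn=builder.make_tag_urn(tag=name),
                aspect=TagKeyClass(name=name),
            )
            for name in self.pending_tag_names
        ]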