Skip to content

Commit

Permalink
feat(sdk): support urns in other urn constructors (#12311)
Browse files Browse the repository at this point in the history
  • Loading branch information
hsheth2 authored Jan 10, 2025
1 parent 208447d commit 5f63f3f
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 14 deletions.
38 changes: 25 additions & 13 deletions metadata-ingestion/scripts/avro_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def write_urn_classes(key_aspects: List[dict], urn_dir: Path) -> None:
code = """
# This file contains classes corresponding to entity URNs.
from typing import ClassVar, List, Optional, Type, TYPE_CHECKING
from typing import ClassVar, List, Optional, Type, TYPE_CHECKING, Union
import functools
from deprecated.sphinx import deprecated as _sphinx_deprecated
Expand Down Expand Up @@ -547,10 +547,31 @@ def generate_urn_class(entity_type: str, key_aspect: dict) -> str:
assert fields[0]["type"] == ["null", "string"]
fields[0]["type"] = "string"

field_urn_type_classes = {}
for field in fields:
# Figure out if urn types are valid for each field.
field_urn_type_class = None
if field_name(field) == "platform":
field_urn_type_class = "DataPlatformUrn"
elif field.get("Urn"):
if len(field.get("entityTypes", [])) == 1:
field_entity_type = field["entityTypes"][0]
field_urn_type_class = f"{capitalize_entity_name(field_entity_type)}Urn"
else:
field_urn_type_class = "Urn"

field_urn_type_classes[field_name(field)] = field_urn_type_class

_init_arg_parts: List[str] = []
for field in fields:
field_urn_type_class = field_urn_type_classes[field_name(field)]

default = '"PROD"' if field_name(field) == "env" else None
_arg_part = f"{field_name(field)}: {field_type(field)}"

type_hint = field_type(field)
if field_urn_type_class:
type_hint = f'Union["{field_urn_type_class}", str]'
_arg_part = f"{field_name(field)}: {type_hint}"
if default:
_arg_part += f" = {default}"
_init_arg_parts.append(_arg_part)
Expand Down Expand Up @@ -579,16 +600,7 @@ def generate_urn_class(entity_type: str, key_aspect: dict) -> str:
init_validation += f'if not {field_name(field)}:\n raise InvalidUrnError("{class_name} {field_name(field)} cannot be empty")\n'

# Generalized mechanism for validating embedded urns.
field_urn_type_class = None
if field_name(field) == "platform":
field_urn_type_class = "DataPlatformUrn"
elif field.get("Urn"):
if len(field.get("entityTypes", [])) == 1:
field_entity_type = field["entityTypes"][0]
field_urn_type_class = f"{capitalize_entity_name(field_entity_type)}Urn"
else:
field_urn_type_class = "Urn"

field_urn_type_class = field_urn_type_classes[field_name(field)]
if field_urn_type_class:
init_validation += f"{field_name(field)} = str({field_name(field)})\n"
init_validation += (
Expand All @@ -608,7 +620,7 @@ def generate_urn_class(entity_type: str, key_aspect: dict) -> str:
init_coercion += " platform_name = DataPlatformUrn.from_string(platform_name).platform_name\n"

if field_name(field) == "platform":
init_coercion += "platform = DataPlatformUrn(platform).urn()\n"
init_coercion += "platform = platform.urn() if isinstance(platform, DataPlatformUrn) else DataPlatformUrn(platform).urn()\n"
elif field_urn_type_class is None:
# For all non-urns, run the value through the UrnEncoder.
init_coercion += (
Expand Down
22 changes: 21 additions & 1 deletion metadata-ingestion/tests/unit/urns/test_urn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@

import pytest

from datahub.metadata.urns import CorpUserUrn, DatasetUrn, Urn
from datahub.metadata.urns import (
CorpUserUrn,
DataPlatformUrn,
DatasetUrn,
SchemaFieldUrn,
Urn,
)
from datahub.utilities.urns.error import InvalidUrnError

pytestmark = pytest.mark.filterwarnings("ignore::DeprecationWarning")
Expand Down Expand Up @@ -60,6 +66,20 @@ def test_urn_coercion() -> None:
assert urn == Urn.from_string(urn.urn())


def test_urns_in_init() -> None:
platform = DataPlatformUrn("abc")
assert platform.urn() == "urn:li:dataPlatform:abc"

dataset_urn = DatasetUrn(platform, "def", "PROD")
assert dataset_urn.urn() == "urn:li:dataset:(urn:li:dataPlatform:abc,def,PROD)"

schema_field = SchemaFieldUrn(dataset_urn, "foo")
assert (
schema_field.urn()
== "urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:abc,def,PROD),foo)"
)


def test_urn_type_dispatch_1() -> None:
urn = Urn.from_string("urn:li:dataset:(urn:li:dataPlatform:abc,def,PROD)")
assert isinstance(urn, DatasetUrn)
Expand Down

0 comments on commit 5f63f3f

Please sign in to comment.