Skip to content

Commit

Permalink
thread pool info to remote representation
Browse files Browse the repository at this point in the history
  • Loading branch information
prha committed Jan 2, 2025
1 parent a36d157 commit ec8a4d6
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 0 deletions.
10 changes: 10 additions & 0 deletions python_modules/dagster/dagster/_core/definitions/asset_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,16 @@ def tags(self) -> Mapping[str, str]:
def kinds(self) -> AbstractSet[str]:
return self._spec.kinds or set()

@property
def pools(self) -> Optional[Set[str]]:
if not self.assets_def.computation:
return None
return set(
op_def.pool
for op_def in self.assets_def.computation.node_def.iterate_op_defs()
if op_def.pool
)

@property
def owners(self) -> Sequence[str]:
return self._spec.owners
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,10 @@ def metadata(self) -> ArbitraryMetadataMapping: ...
@abstractmethod
def tags(self) -> Mapping[str, str]: ...

@property
@abstractmethod
def pools(self) -> Optional[Set[str]]: ...

@property
@abstractmethod
def owners(self) -> Sequence[str]: ...
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ def group_name(self) -> str:
def metadata(self) -> ArbitraryMetadataMapping:
return self.resolve_to_singular_repo_scoped_node().asset_node_snap.metadata

@property
def pools(self) -> Optional[Set[str]]:
return self.resolve_to_singular_repo_scoped_node().pools

@property
def tags(self) -> Mapping[str, str]:
return self.resolve_to_singular_repo_scoped_node().asset_node_snap.tags or {}
Expand Down Expand Up @@ -193,6 +197,10 @@ def auto_materialize_policy(self) -> Optional[AutoMaterializePolicy]:
def auto_observe_interval_minutes(self) -> Optional[float]:
return self.asset_node_snap.auto_observe_interval_minutes

@property
def pools(self) -> Optional[Set[str]]:
return self.asset_node_snap.pools


@whitelist_for_serdes
@record
Expand Down Expand Up @@ -274,6 +282,13 @@ def is_external(self) -> bool:
def is_executable(self) -> bool:
return any(node.asset_node.is_executable for node in self.repo_scoped_asset_infos)

@cached_property
def pools(self) -> Optional[Set[str]]:
pools = set()
for info in self.repo_scoped_asset_infos:
pools.update(info.asset_node.pools or set())
return pools

@property
def partition_mappings(self) -> Mapping[AssetKey, PartitionMapping]:
if self.is_materializable:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1408,6 +1408,7 @@ class AssetNodeSnap(IHaveNew):
parent_edges: Sequence[AssetParentEdgeSnap]
child_edges: Sequence[AssetChildEdgeSnap]
execution_type: AssetExecutionType
pools: Set[str]
compute_kind: Optional[str]
op_name: Optional[str]
op_names: Sequence[str]
Expand Down Expand Up @@ -1441,6 +1442,7 @@ def __new__(
parent_edges: Sequence[AssetParentEdgeSnap],
child_edges: Sequence[AssetChildEdgeSnap],
execution_type: Optional[AssetExecutionType] = None,
pools: Optional[Set[str]] = None,
compute_kind: Optional[str] = None,
op_name: Optional[str] = None,
op_names: Optional[Sequence[str]] = None,
Expand Down Expand Up @@ -1516,6 +1518,7 @@ def __new__(
parent_edges=parent_edges or [],
child_edges=child_edges or [],
compute_kind=compute_kind,
pools=pools,
op_name=op_name,
op_names=op_names or [],
code_version=code_version,
Expand Down Expand Up @@ -1676,6 +1679,12 @@ def asset_node_snaps_from_repo(repo: RepositoryDefinition) -> Sequence[AssetNode
graph_name = (
root_node_handle.name if root_node_handle != output_handle.node_handle else None
)
op_defs = [
cast(OpDefinition, job_def.graph.get_node(node_handle).definition)
for node_handle in node_handles
if isinstance(job_def.graph.get_node(node_handle).definition, OpDefinition)
]
pools = {op_def.pool for op_def in op_defs if op_def.pool}
op_names = sorted([str(handle) for handle in node_handles])
op_name = graph_name or next(iter(op_names), None) or node_def.name
job_names = sorted([jd.name for jd in job_defs_by_asset_key[key]])
Expand All @@ -1693,6 +1702,7 @@ def asset_node_snaps_from_repo(repo: RepositoryDefinition) -> Sequence[AssetNode

else:
graph_name = None
pools = set()
op_names = []
op_name = None
job_names = []
Expand Down Expand Up @@ -1731,6 +1741,7 @@ def asset_node_snaps_from_repo(repo: RepositoryDefinition) -> Sequence[AssetNode
],
execution_type=asset_node.execution_type,
compute_kind=compute_kind,
pools=pools,
op_name=op_name,
op_names=op_names,
code_version=asset_node.code_version,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from dagster import (
AssetIn,
AssetKey,
AssetSpec,
DagsterInstance,
DailyPartitionsDefinition,
Definitions,
Expand All @@ -16,6 +17,9 @@
StaticPartitionMapping,
StaticPartitionsDefinition,
asset,
graph_asset,
multi_asset,
op,
)
from dagster._core.definitions.auto_materialize_policy import AutoMaterializePolicy
from dagster._core.definitions.backfill_policy import BackfillPolicy
Expand Down Expand Up @@ -364,3 +368,42 @@ def test_dup_node_detection(instance):
_ = _make_context(
instance, ["dup_observation_defs_a", "dup_observation_defs_b"]
).asset_graph


@asset(pool="foo")
def my_asset():
pass


@op(pool="bar")
def my_op():
pass


@graph_asset
def my_graph_asset():
return my_op()


@multi_asset(
specs=[
AssetSpec("multi_asset_1"),
AssetSpec("multi_asset_2"),
],
pool="baz",
)
def my_multi_asset():
pass


concurrency_assets = Definitions(assets=[my_asset, my_graph_asset, my_multi_asset])


def test_pool_snap(instance) -> None:
context = _make_context(instance, ["concurrency_assets"])
asset_graph = context.asset_graph
assert asset_graph
assert asset_graph.get(AssetKey("my_asset")).pools == {"foo"}
assert asset_graph.get(AssetKey("my_graph_asset")).pools == {"bar"}
assert asset_graph.get(AssetKey("multi_asset_1")).pools == {"baz"}
assert asset_graph.get(AssetKey("multi_asset_2")).pools == {"baz"}

0 comments on commit ec8a4d6

Please sign in to comment.