diff --git a/python_modules/dagster/dagster/_core/definitions/asset_graph.py b/python_modules/dagster/dagster/_core/definitions/asset_graph.py index f4f64b68098d6..29689ee8ce1d0 100644 --- a/python_modules/dagster/dagster/_core/definitions/asset_graph.py +++ b/python_modules/dagster/dagster/_core/definitions/asset_graph.py @@ -84,6 +84,16 @@ def tags(self) -> Mapping[str, str]: def kinds(self) -> AbstractSet[str]: return self._spec.kinds or set() + @property + def pools(self) -> Optional[Set[str]]: + if not self.assets_def.computation: + return None + return set( + op_def.pool + for op_def in self.assets_def.computation.node_def.iterate_op_defs() + if op_def.pool + ) + @property def owners(self) -> Sequence[str]: return self._spec.owners diff --git a/python_modules/dagster/dagster/_core/definitions/base_asset_graph.py b/python_modules/dagster/dagster/_core/definitions/base_asset_graph.py index 9c936179d747f..b4ef09f86ee89 100644 --- a/python_modules/dagster/dagster/_core/definitions/base_asset_graph.py +++ b/python_modules/dagster/dagster/_core/definitions/base_asset_graph.py @@ -142,6 +142,10 @@ def metadata(self) -> ArbitraryMetadataMapping: ... @abstractmethod def tags(self) -> Mapping[str, str]: ... + @property + @abstractmethod + def pools(self) -> Optional[Set[str]]: ... + @property @abstractmethod def owners(self) -> Sequence[str]: ... diff --git a/python_modules/dagster/dagster/_core/definitions/remote_asset_graph.py b/python_modules/dagster/dagster/_core/definitions/remote_asset_graph.py index d9a29e2f0aa64..15bec87b98c34 100644 --- a/python_modules/dagster/dagster/_core/definitions/remote_asset_graph.py +++ b/python_modules/dagster/dagster/_core/definitions/remote_asset_graph.py @@ -86,6 +86,10 @@ def group_name(self) -> str: def metadata(self) -> ArbitraryMetadataMapping: return self.resolve_to_singular_repo_scoped_node().asset_node_snap.metadata + @property + def pools(self) -> Optional[Set[str]]: + return self.resolve_to_singular_repo_scoped_node().pools + @property def tags(self) -> Mapping[str, str]: return self.resolve_to_singular_repo_scoped_node().asset_node_snap.tags or {} @@ -193,6 +197,10 @@ def auto_materialize_policy(self) -> Optional[AutoMaterializePolicy]: def auto_observe_interval_minutes(self) -> Optional[float]: return self.asset_node_snap.auto_observe_interval_minutes + @property + def pools(self) -> Optional[Set[str]]: + return self.asset_node_snap.pools + @whitelist_for_serdes @record @@ -274,6 +282,13 @@ def is_external(self) -> bool: def is_executable(self) -> bool: return any(node.asset_node.is_executable for node in self.repo_scoped_asset_infos) + @cached_property + def pools(self) -> Optional[Set[str]]: + pools = set() + for info in self.repo_scoped_asset_infos: + pools.update(info.asset_node.pools or set()) + return pools + @property def partition_mappings(self) -> Mapping[AssetKey, PartitionMapping]: if self.is_materializable: diff --git a/python_modules/dagster/dagster/_core/remote_representation/external_data.py b/python_modules/dagster/dagster/_core/remote_representation/external_data.py index 525926355318a..e21ee2b489225 100644 --- a/python_modules/dagster/dagster/_core/remote_representation/external_data.py +++ b/python_modules/dagster/dagster/_core/remote_representation/external_data.py @@ -1408,6 +1408,7 @@ class AssetNodeSnap(IHaveNew): parent_edges: Sequence[AssetParentEdgeSnap] child_edges: Sequence[AssetChildEdgeSnap] execution_type: AssetExecutionType + pools: Set[str] compute_kind: Optional[str] op_name: Optional[str] op_names: Sequence[str] @@ -1441,6 +1442,7 @@ def __new__( parent_edges: Sequence[AssetParentEdgeSnap], child_edges: Sequence[AssetChildEdgeSnap], execution_type: Optional[AssetExecutionType] = None, + pools: Optional[Set[str]] = None, compute_kind: Optional[str] = None, op_name: Optional[str] = None, op_names: Optional[Sequence[str]] = None, @@ -1516,6 +1518,7 @@ def __new__( parent_edges=parent_edges or [], child_edges=child_edges or [], compute_kind=compute_kind, + pools=pools or set(), op_name=op_name, op_names=op_names or [], code_version=code_version, @@ -1676,6 +1679,12 @@ def asset_node_snaps_from_repo(repo: RepositoryDefinition) -> Sequence[AssetNode graph_name = ( root_node_handle.name if root_node_handle != output_handle.node_handle else None ) + op_defs = [ + cast(OpDefinition, job_def.graph.get_node(node_handle).definition) + for node_handle in node_handles + if isinstance(job_def.graph.get_node(node_handle).definition, OpDefinition) + ] + pools = {op_def.pool for op_def in op_defs if op_def.pool} op_names = sorted([str(handle) for handle in node_handles]) op_name = graph_name or next(iter(op_names), None) or node_def.name job_names = sorted([jd.name for jd in job_defs_by_asset_key[key]]) @@ -1693,6 +1702,7 @@ def asset_node_snaps_from_repo(repo: RepositoryDefinition) -> Sequence[AssetNode else: graph_name = None + pools = set() op_names = [] op_name = None job_names = [] @@ -1731,6 +1741,7 @@ def asset_node_snaps_from_repo(repo: RepositoryDefinition) -> Sequence[AssetNode ], execution_type=asset_node.execution_type, compute_kind=compute_kind, + pools=pools, op_name=op_name, op_names=op_names, code_version=asset_node.code_version, diff --git a/python_modules/dagster/dagster_tests/asset_defs_tests/test_external_asset_graph.py b/python_modules/dagster/dagster_tests/asset_defs_tests/test_external_asset_graph.py index 839cee84988ed..8ac8cb5fc142a 100644 --- a/python_modules/dagster/dagster_tests/asset_defs_tests/test_external_asset_graph.py +++ b/python_modules/dagster/dagster_tests/asset_defs_tests/test_external_asset_graph.py @@ -8,6 +8,7 @@ from dagster import ( AssetIn, AssetKey, + AssetSpec, DagsterInstance, DailyPartitionsDefinition, Definitions, @@ -16,6 +17,9 @@ StaticPartitionMapping, StaticPartitionsDefinition, asset, + graph_asset, + multi_asset, + op, ) from dagster._core.definitions.auto_materialize_policy import AutoMaterializePolicy from dagster._core.definitions.backfill_policy import BackfillPolicy @@ -364,3 +368,42 @@ def test_dup_node_detection(instance): _ = _make_context( instance, ["dup_observation_defs_a", "dup_observation_defs_b"] ).asset_graph + + +@asset(pool="foo") +def my_asset(): + pass + + +@op(pool="bar") +def my_op(): + pass + + +@graph_asset +def my_graph_asset(): + return my_op() + + +@multi_asset( + specs=[ + AssetSpec("multi_asset_1"), + AssetSpec("multi_asset_2"), + ], + pool="baz", +) +def my_multi_asset(): + pass + + +concurrency_assets = Definitions(assets=[my_asset, my_graph_asset, my_multi_asset]) + + +def test_pool_snap(instance) -> None: + context = _make_context(instance, ["concurrency_assets"]) + asset_graph = context.asset_graph + assert asset_graph + assert asset_graph.get(AssetKey("my_asset")).pools == {"foo"} + assert asset_graph.get(AssetKey("my_graph_asset")).pools == {"bar"} + assert asset_graph.get(AssetKey("multi_asset_1")).pools == {"baz"} + assert asset_graph.get(AssetKey("multi_asset_2")).pools == {"baz"}