Skip to content

Commit

Permalink
Make compute asset subset functions async
Browse files Browse the repository at this point in the history
  • Loading branch information
briantu committed Oct 14, 2024
1 parent 2c11e72 commit 3bf9beb
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def compute_latest_time_window_subset(
else:
check.failed(f"Unsupported partitions_def: {partitions_def}")

def compute_missing_subset(
async def compute_missing_subset(
self, asset_key: "AssetKey", from_subset: EntitySubset[AssetKey]
) -> EntitySubset[AssetKey]:
"""Returns a subset which is the subset of the input subset that has never been materialized
Expand All @@ -359,7 +359,7 @@ def compute_missing_subset(
# cheap call which takes advantage of the partition status cache
partitions_def = self._get_partitions_def(asset_key)
if partitions_def:
cache_value = AssetStatusCacheValue.blocking_get(self, (asset_key, partitions_def))
cache_value = await AssetStatusCacheValue.gen(self, (asset_key, partitions_def))
return (
cache_value.get_materialized_subset(self, asset_key, partitions_def)
if cache_value
Expand All @@ -383,13 +383,15 @@ def compute_missing_subset(
)

@cached_method
def compute_in_progress_asset_subset(self, *, asset_key: AssetKey) -> EntitySubset[AssetKey]:
async def compute_in_progress_asset_subset(
self, *, asset_key: AssetKey
) -> EntitySubset[AssetKey]:
from dagster._core.storage.partition_status_cache import AssetStatusCacheValue

# part of in progress run
partitions_def = self._get_partitions_def(asset_key)
if partitions_def:
cache_value = AssetStatusCacheValue.blocking_get(self, (asset_key, partitions_def))
cache_value = await AssetStatusCacheValue.gen(self, (asset_key, partitions_def))
return (
cache_value.get_in_progress_subset(self, asset_key, partitions_def)
if cache_value
Expand All @@ -400,12 +402,12 @@ def compute_in_progress_asset_subset(self, *, asset_key: AssetKey) -> EntitySubs
return EntitySubset(self, key=asset_key, value=_ValidatedEntitySubsetValue(value))

@cached_method
def compute_failed_asset_subset(self, *, asset_key: "AssetKey") -> EntitySubset[AssetKey]:
async def compute_failed_asset_subset(self, *, asset_key: "AssetKey") -> EntitySubset[AssetKey]:
from dagster._core.storage.partition_status_cache import AssetStatusCacheValue

partitions_def = self._get_partitions_def(asset_key)
if partitions_def:
cache_value = AssetStatusCacheValue.blocking_get(self, (asset_key, partitions_def))
cache_value = await AssetStatusCacheValue.gen(self, (asset_key, partitions_def))
return (
cache_value.get_failed_subset(self, asset_key, partitions_def)
if cache_value
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datetime
import inspect
from abc import abstractmethod
from typing import Optional

Expand Down Expand Up @@ -50,10 +51,14 @@ def compute_subset(
self, context: AutomationContext[T_EntityKey]
) -> EntitySubset[T_EntityKey]: ...

def evaluate(self, context: AutomationContext[T_EntityKey]) -> AutomationResult[T_EntityKey]:
async def evaluate(
self, context: AutomationContext[T_EntityKey]
) -> AutomationResult[T_EntityKey]:
# don't compute anything if there are no candidates
if context.candidate_subset.is_empty:
true_subset = context.get_empty_subset()
elif inspect.iscoroutinefunction(self.compute_subset):
true_subset = await self.compute_subset(context)
else:
true_subset = self.compute_subset(context)

Expand All @@ -71,8 +76,8 @@ def description(self) -> str:
def name(self) -> str:
return "missing"

def compute_subset(self, context: AutomationContext) -> EntitySubset[AssetKey]:
return context.asset_graph_view.compute_missing_subset(
async def compute_subset(self, context: AutomationContext) -> EntitySubset[AssetKey]:
return await context.asset_graph_view.compute_missing_subset(
context.key, from_subset=context.candidate_subset
)

Expand All @@ -88,8 +93,10 @@ def description(self) -> str:
def name(self) -> str:
return "in_progress"

def compute_subset(self, context: AutomationContext) -> EntitySubset[AssetKey]:
return context.asset_graph_view.compute_in_progress_asset_subset(asset_key=context.key)
async def compute_subset(self, context: AutomationContext) -> EntitySubset[AssetKey]:
return await context.asset_graph_view.compute_in_progress_asset_subset(
asset_key=context.key
)


@whitelist_for_serdes
Expand All @@ -103,8 +110,8 @@ def description(self) -> str:
def name(self) -> str:
return "failed"

def compute_subset(self, context: AutomationContext) -> EntitySubset[AssetKey]:
return context.asset_graph_view.compute_failed_asset_subset(asset_key=context.key)
async def compute_subset(self, context: AutomationContext) -> EntitySubset[AssetKey]:
return await context.asset_graph_view.compute_failed_asset_subset(asset_key=context.key)


@whitelist_for_serdes
Expand Down

0 comments on commit 3bf9beb

Please sign in to comment.