Skip to content

Commit

Permalink
Make compute asset/check subset functions async (dagster-io#25264)
Browse files Browse the repository at this point in the history
## Summary & Motivation
We now want to make the compute asset and asset check subset functions on `AssetGraphView` async so we can use `AssetStatusCacheValue.gen()` instead of `AssetStatusCacheValue.blocking_get()`. Same for `AssetCheckExecutionRecord`.

## How I Tested These Changes
Existing tests should pass
  • Loading branch information
briantu authored and Grzyblon committed Oct 26, 2024
1 parent 365309e commit bc6dab3
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ async def resolve_executionForLatestMaterialization(
record = await AssetCheckExecutionRecord.gen(graphene_info.context, self._asset_check.key)
return (
GrapheneAssetCheckExecution(record)
if record and record.targets_latest_materialization(graphene_info.context)
if record and await record.targets_latest_materialization(graphene_info.context)
else None
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import (
TYPE_CHECKING,
AbstractSet,
Awaitable,
Callable,
Dict,
Literal,
Expand Down Expand Up @@ -356,86 +357,90 @@ def compute_latest_time_window_subset(
else:
check.failed(f"Unsupported partitions_def: {partitions_def}")

def compute_subset_with_status(
async def compute_subset_with_status(
self, key: AssetCheckKey, status: Optional["AssetCheckExecutionResolvedStatus"]
):
"""Returns the subset of an asset check that matches a given status."""
from dagster._core.storage.event_log.base import AssetCheckSummaryRecord

summary = AssetCheckSummaryRecord.blocking_get(self, key)
"""Returns the subset of an asset check that matches a given status."""
summary = await AssetCheckSummaryRecord.gen(self, key)
latest_record = summary.last_check_execution_record if summary else None
resolved_status = (
latest_record.resolve_status(self)
if latest_record and latest_record.targets_latest_materialization(self)
if latest_record and await latest_record.targets_latest_materialization(self)
else None
)
if resolved_status == status:
return self.get_full_subset(key=key)
else:
return self.get_empty_subset(key=key)

def _compute_run_in_progress_check_subset(
async def _compute_run_in_progress_check_subset(
self, key: AssetCheckKey
) -> EntitySubset[AssetCheckKey]:
from dagster._core.storage.asset_check_execution_record import (
AssetCheckExecutionResolvedStatus,
)

return self.compute_subset_with_status(key, AssetCheckExecutionResolvedStatus.IN_PROGRESS)
return await self.compute_subset_with_status(
key, AssetCheckExecutionResolvedStatus.IN_PROGRESS
)

def _compute_execution_failed_check_subset(
async def _compute_execution_failed_check_subset(
self, key: AssetCheckKey
) -> EntitySubset[AssetCheckKey]:
from dagster._core.storage.asset_check_execution_record import (
AssetCheckExecutionResolvedStatus,
)

return self.compute_subset_with_status(
return await self.compute_subset_with_status(
key, AssetCheckExecutionResolvedStatus.EXECUTION_FAILED
)

def _compute_missing_check_subset(self, key: AssetCheckKey) -> EntitySubset[AssetCheckKey]:
return self.compute_subset_with_status(key, None)
async def _compute_missing_check_subset(
self, key: AssetCheckKey
) -> EntitySubset[AssetCheckKey]:
return await self.compute_subset_with_status(key, None)

def _compute_run_in_progress_asset_subset(self, key: AssetKey) -> EntitySubset[AssetKey]:
async def _compute_run_in_progress_asset_subset(self, key: AssetKey) -> EntitySubset[AssetKey]:
from dagster._core.storage.partition_status_cache import AssetStatusCacheValue

partitions_def = self._get_partitions_def(key)
if partitions_def:
cache_value = AssetStatusCacheValue.blocking_get(self, (key, partitions_def))
cache_value = await AssetStatusCacheValue.gen(self, (key, partitions_def))
return (
cache_value.get_in_progress_subset(self, key, partitions_def)
if cache_value
else self.get_empty_subset(key=key)
)
else:
value = self._queryer.get_in_progress_asset_subset(asset_key=key).value
value = self._queryer.get_in_progress_asset_subset(asset_key=key).value
return EntitySubset(self, key=key, value=_ValidatedEntitySubsetValue(value))

def _compute_backfill_in_progress_asset_subset(self, key: AssetKey) -> EntitySubset[AssetKey]:
async def _compute_backfill_in_progress_asset_subset(
self, key: AssetKey
) -> EntitySubset[AssetKey]:
value = (
self._queryer.get_active_backfill_in_progress_asset_graph_subset()
.get_asset_subset(asset_key=key, asset_graph=self.asset_graph)
.value
)
return EntitySubset(self, key=key, value=_ValidatedEntitySubsetValue(value))

def _compute_execution_failed_asset_subset(self, key: AssetKey) -> EntitySubset[AssetKey]:
async def _compute_execution_failed_asset_subset(self, key: AssetKey) -> EntitySubset[AssetKey]:
from dagster._core.storage.partition_status_cache import AssetStatusCacheValue

partitions_def = self._get_partitions_def(key)
if partitions_def:
cache_value = AssetStatusCacheValue.blocking_get(self, (key, partitions_def))
cache_value = await AssetStatusCacheValue.gen(self, (key, partitions_def))
return (
cache_value.get_failed_subset(self, key, partitions_def)
if cache_value
else self.get_empty_subset(key=key)
)
else:
value = self._queryer.get_failed_asset_subset(asset_key=key).value
value = self._queryer.get_failed_asset_subset(asset_key=key).value
return EntitySubset(self, key=key, value=_ValidatedEntitySubsetValue(value))

def _compute_missing_asset_subset(
async def _compute_missing_asset_subset(
self, key: AssetKey, from_subset: EntitySubset
) -> EntitySubset[AssetKey]:
"""Returns a subset which is the subset of the input subset that has never been materialized
Expand All @@ -451,7 +456,7 @@ def _compute_missing_asset_subset(
# cheap call which takes advantage of the partition status cache
partitions_def = self._get_partitions_def(key)
if partitions_def:
cache_value = AssetStatusCacheValue.blocking_get(self, (key, partitions_def))
cache_value = await AssetStatusCacheValue.gen(self, (key, partitions_def))
materialized_subset = (
cache_value.get_materialized_subset(self, key, partitions_def)
if cache_value
Expand All @@ -475,51 +480,56 @@ def _compute_missing_asset_subset(
)

@cached_method
def compute_run_in_progress_subset(self, *, key: EntityKey) -> EntitySubset:
return _dispatch(
async def compute_run_in_progress_subset(self, *, key: EntityKey) -> EntitySubset:
return await _dispatch(
key=key,
check_method=self._compute_run_in_progress_check_subset,
asset_method=self._compute_run_in_progress_asset_subset,
)

@cached_method
def compute_backfill_in_progress_subset(self, *, key: EntityKey) -> EntitySubset:
return _dispatch(
async def compute_backfill_in_progress_subset(self, *, key: EntityKey) -> EntitySubset:
async def get_empty_subset(key: EntityKey) -> EntitySubset:
return self.get_empty_subset(key=key)

return await _dispatch(
key=key,
# asset checks cannot currently be backfilled
check_method=lambda k: self.get_empty_subset(key=k),
check_method=get_empty_subset,
asset_method=self._compute_backfill_in_progress_asset_subset,
)

@cached_method
def compute_execution_failed_subset(self, *, key: EntityKey) -> EntitySubset:
return _dispatch(
async def compute_execution_failed_subset(self, *, key: EntityKey) -> EntitySubset:
return await _dispatch(
key=key,
check_method=self._compute_execution_failed_check_subset,
asset_method=self._compute_execution_failed_asset_subset,
)

@cached_method
def compute_missing_subset(self, *, key: EntityKey, from_subset: EntitySubset) -> EntitySubset:
return _dispatch(
async def compute_missing_subset(
self, *, key: EntityKey, from_subset: EntitySubset
) -> EntitySubset:
return await _dispatch(
key=key,
check_method=self._compute_missing_check_subset,
asset_method=functools.partial(
self._compute_missing_asset_subset, from_subset=from_subset
),
)

def _expensively_filter_entity_subset(
self, subset: EntitySubset, filter_fn: Callable[[Optional[str]], bool]
async def _expensively_filter_entity_subset(
self, subset: EntitySubset, filter_fn: Callable[[Optional[str]], Awaitable[bool]]
) -> EntitySubset:
if subset.is_partitioned:
return subset.compute_intersection_with_partition_keys(
{pk for pk in subset.expensively_compute_partition_keys() if filter_fn(pk)}
{pk for pk in subset.expensively_compute_partition_keys() if await filter_fn(pk)}
)
else:
return (
subset
if not subset.is_empty and filter_fn(None)
if not subset.is_empty and await filter_fn(None)
else self.get_empty_subset(key=subset.key)
)

Expand All @@ -528,48 +538,49 @@ def _run_record_targets_entity(self, run_record: "RunRecord", target_key: Entity
check_selection = run_record.dagster_run.asset_check_selection or set()
return target_key in (asset_selection | check_selection)

def _compute_latest_check_run_executed_with_target(
async def _compute_latest_check_run_executed_with_target(
self, partition_key: Optional[str], query_key: AssetCheckKey, target_key: EntityKey
) -> bool:
from dagster._core.storage.asset_check_execution_record import AssetCheckExecutionRecord
from dagster._core.storage.dagster_run import RunRecord
from dagster._core.storage.event_log.base import AssetCheckSummaryRecord

check.invariant(partition_key is None, "Partitioned checks not supported")
check_record = AssetCheckExecutionRecord.blocking_get(self, query_key)
summary = await AssetCheckSummaryRecord.gen(self, query_key)
check_record = summary.last_check_execution_record if summary else None
if check_record and check_record.event:
run_record = RunRecord.blocking_get(self, check_record.event.run_id)
run_record = await RunRecord.gen(self, check_record.event.run_id)
return bool(run_record) and self._run_record_targets_entity(run_record, target_key)
else:
return False

def _compute_latest_asset_run_executed_with_target(
async def _compute_latest_asset_run_executed_with_target(
self, partition_key: Optional[str], query_key: AssetKey, target_key: EntityKey
) -> bool:
from dagster._core.storage.dagster_run import RunRecord
from dagster._core.storage.event_log.base import AssetRecord

asset_record = AssetRecord.blocking_get(self, query_key)
asset_record = await AssetRecord.gen(self, query_key)
if (
asset_record
and asset_record.asset_entry.last_materialization
and asset_record.asset_entry.last_materialization.asset_materialization
and asset_record.asset_entry.last_materialization.asset_materialization.partition
== partition_key
):
run_record = RunRecord.blocking_get(
run_record = await RunRecord.gen(
self, asset_record.asset_entry.last_materialization.run_id
)
return bool(run_record) and self._run_record_targets_entity(run_record, target_key)
else:
return False

def compute_latest_run_executed_with_subset(
async def compute_latest_run_executed_with_subset(
self, from_subset: EntitySubset, target: EntityKey
) -> EntitySubset:
"""Computes the subset of from_subset for which the latest run also targeted
the provided target EntityKey.
"""
return _dispatch(
return await _dispatch(
key=from_subset.key,
check_method=lambda k: self._expensively_filter_entity_subset(
from_subset,
Expand All @@ -589,22 +600,22 @@ def compute_latest_run_executed_with_subset(
),
)

def _compute_updated_since_cursor_subset(
async def _compute_updated_since_cursor_subset(
self, key: AssetKey, cursor: Optional[int]
) -> EntitySubset[AssetKey]:
value = self._queryer.get_asset_subset_updated_after_cursor(
asset_key=key, after_cursor=cursor
).value
return EntitySubset(self, key=key, value=_ValidatedEntitySubsetValue(value))

def _compute_updated_since_time_subset(
async def _compute_updated_since_time_subset(
self, key: AssetCheckKey, time: datetime
) -> EntitySubset[AssetCheckKey]:
from dagster._core.events import DagsterEventType
from dagster._core.storage.event_log.base import AssetCheckSummaryRecord

# intentionally left unimplemented for AssetKey, as this is a less performant query
summary = AssetCheckSummaryRecord.blocking_get(self, key)
summary = await AssetCheckSummaryRecord.gen(self, key)
record = summary.last_check_execution_record if summary else None
if (
record is None
Expand All @@ -617,10 +628,10 @@ def _compute_updated_since_time_subset(
return self.get_full_subset(key=key)

@cached_method
def compute_updated_since_temporal_context_subset(
async def compute_updated_since_temporal_context_subset(
self, *, key: EntityKey, temporal_context: TemporalContext
) -> EntitySubset:
return _dispatch(
return await _dispatch(
key=key,
check_method=functools.partial(
self._compute_updated_since_time_subset, time=temporal_context.effective_dt
Expand Down Expand Up @@ -705,14 +716,14 @@ def _build_multi_partition_subset(
O_Dispatch = TypeVar("O_Dispatch")


def _dispatch(
async def _dispatch(
*,
key: EntityKey,
check_method: Callable[[AssetCheckKey], O_Dispatch],
asset_method: Callable[[AssetKey], O_Dispatch],
check_method: Callable[[AssetCheckKey], Awaitable[O_Dispatch]],
asset_method: Callable[[AssetKey], Awaitable[O_Dispatch]],
) -> O_Dispatch:
"""Applies a method for either a check or an asset."""
if isinstance(key, AssetCheckKey):
return check_method(key)
return await check_method(key)
else:
return asset_method(key)
return await asset_method(key)
Loading

0 comments on commit bc6dab3

Please sign in to comment.