Skip to content

Commit

Permalink
Pass loading context through _blocking_batch_load
Browse files Browse the repository at this point in the history
  • Loading branch information
briantu committed Oct 21, 2024
1 parent 4a1a5a3 commit 2e65c08
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 25 deletions.
30 changes: 25 additions & 5 deletions python_modules/dagster/dagster/_core/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,19 @@ def instance(self) -> "DagsterInstance":
def loaders(self) -> Dict[Type, Tuple[DataLoader, BlockingDataLoader]]:
raise NotImplementedError()

@staticmethod
def ephemeral(instance: "DagsterInstance") -> "LoadingContext":
return EphemeralLoadingContext(instance)

def get_loaders_for(
self, ttype: Type["InstanceLoadableBy"]
) -> Tuple[DataLoader, BlockingDataLoader]:
if ttype not in self.loaders:
if not issubclass(ttype, InstanceLoadableBy):
check.failed(f"{ttype} is not Loadable")

batch_load_fn = partial(ttype._batch_load, instance=self.instance) # noqa
blocking_batch_load_fn = partial(ttype._blocking_batch_load, instance=self.instance) # noqa
batch_load_fn = partial(ttype._batch_load, context=self) # noqa
blocking_batch_load_fn = partial(ttype._blocking_batch_load, context=self) # noqa

self.loaders[ttype] = (
DataLoader(batch_load_fn=batch_load_fn),
Expand All @@ -80,6 +84,22 @@ def clear_loaders(self) -> None:
del self.loaders[ttype]


class EphemeralLoadingContext(LoadingContext):
"""Loading context that can be constructed for short-lived method resolution."""

def __init__(self, instance: "DagsterInstance"):
self._instance = instance
self._loaders = {}

@property
def instance(self) -> "DagsterInstance":
return self._instance

@property
def loaders(self) -> Dict[Type, Tuple[DataLoader, BlockingDataLoader]]:
return self._loaders


# Expected there may be other "Loadable" base classes based on what is needed to load.


Expand All @@ -88,14 +108,14 @@ class InstanceLoadableBy(ABC, Generic[TKey]):

@classmethod
async def _batch_load(
cls, keys: Iterable[TKey], instance: "DagsterInstance"
cls, keys: Iterable[TKey], context: "LoadingContext"
) -> Iterable[Optional[Self]]:
return cls._blocking_batch_load(keys, instance)
return cls._blocking_batch_load(keys, context)

@classmethod
@abstractmethod
def _blocking_batch_load(
cls, keys: Iterable[TKey], instance: "DagsterInstance"
cls, keys: Iterable[TKey], context: "LoadingContext"
) -> Iterable[Optional[Self]]:
# There is no good way of turning an async function into a sync one that
# will allow us to execute that sync function inside of a broader async context.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from dagster._core.definitions.asset_check_evaluation import AssetCheckEvaluation
from dagster._core.definitions.asset_key import AssetCheckKey
from dagster._core.events.log import DagsterEventType, EventLogEntry
from dagster._core.instance import DagsterInstance
from dagster._core.loader import InstanceLoadableBy, LoadingContext
from dagster._core.storage.dagster_run import DagsterRunStatus, RunRecord
from dagster._serdes.serdes import deserialize_value
Expand Down Expand Up @@ -124,9 +123,9 @@ def from_db_row(cls, row, key: AssetCheckKey) -> "AssetCheckExecutionRecord":

@classmethod
def _blocking_batch_load(
cls, keys: Iterable[AssetCheckKey], instance: DagsterInstance
cls, keys: Iterable[AssetCheckKey], context: LoadingContext
) -> Iterable[Optional["AssetCheckExecutionRecord"]]:
records_by_key = instance.event_log_storage.get_latest_asset_check_execution_by_key(
records_by_key = context.instance.event_log_storage.get_latest_asset_check_execution_by_key(
list(keys)
)
return [records_by_key.get(key) for key in keys]
Expand Down
7 changes: 3 additions & 4 deletions python_modules/dagster/dagster/_core/storage/dagster_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from dagster._annotations import PublicAttr, experimental_param, public
from dagster._core.definitions.asset_check_spec import AssetCheckKey
from dagster._core.definitions.events import AssetKey
from dagster._core.loader import InstanceLoadableBy
from dagster._core.loader import InstanceLoadableBy, LoadingContext
from dagster._core.origin import JobPythonOrigin
from dagster._core.storage.tags import (
ASSET_EVALUATION_ID_TAG,
Expand All @@ -41,7 +41,6 @@
if TYPE_CHECKING:
from dagster._core.definitions.schedule_definition import ScheduleDefinition
from dagster._core.definitions.sensor_definition import SensorDefinition
from dagster._core.instance import DagsterInstance
from dagster._core.remote_representation.external import RemoteSchedule, RemoteSensor
from dagster._core.remote_representation.origin import RemoteJobOrigin
from dagster._core.scheduler.instigation import InstigatorState
Expand Down Expand Up @@ -643,12 +642,12 @@ def __new__(

@classmethod
def _blocking_batch_load(
cls, keys: Iterable[str], instance: "DagsterInstance"
cls, keys: Iterable[str], context: LoadingContext
) -> Iterable[Optional["RunRecord"]]:
result_map: Dict[str, Optional[RunRecord]] = {run_id: None for run_id in keys}

# this should be replaced with an async DB call
records = instance.get_run_records(RunsFilter(run_ids=list(result_map.keys())))
records = context.instance.get_run_records(RunsFilter(run_ids=list(result_map.keys())))

for record in records:
result_map[record.dagster_run.run_id] = record
Expand Down
19 changes: 12 additions & 7 deletions python_modules/dagster/dagster/_core/storage/event_log/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@
build_run_stats_from_events,
build_run_step_stats_from_events,
)
from dagster._core.instance import DagsterInstance, MayHaveInstanceWeakref, T_DagsterInstance
from dagster._core.loader import InstanceLoadableBy
from dagster._core.instance import MayHaveInstanceWeakref, T_DagsterInstance
from dagster._core.loader import InstanceLoadableBy, LoadingContext
from dagster._core.storage.asset_check_execution_record import AssetCheckExecutionRecord
from dagster._core.storage.dagster_run import DagsterRunStatsSnapshot
from dagster._core.storage.partition_status_cache import get_and_update_asset_status_cache_value
Expand Down Expand Up @@ -138,11 +138,11 @@ class AssetRecord(

@classmethod
def _blocking_batch_load(
cls, keys: Iterable[AssetKey], instance: DagsterInstance
cls, keys: Iterable[AssetKey], context: LoadingContext
) -> Iterable[Optional["AssetRecord"]]:
records_by_key = {
record.asset_entry.asset_key: record
for record in instance.get_asset_records(list(keys))
for record in context.instance.get_asset_records(list(keys))
}
return [records_by_key.get(key) for key in keys]

Expand All @@ -160,9 +160,11 @@ class AssetCheckSummaryRecord(
):
@classmethod
def _blocking_batch_load(
cls, keys: Iterable[AssetCheckKey], instance: DagsterInstance
cls, keys: Iterable[AssetCheckKey], context: LoadingContext
) -> Iterable[Optional["AssetCheckSummaryRecord"]]:
records_by_key = instance.event_log_storage.get_asset_check_summary_records(list(keys))
records_by_key = context.instance.event_log_storage.get_asset_check_summary_records(
list(keys)
)
return [records_by_key[key] for key in keys]


Expand Down Expand Up @@ -653,11 +655,14 @@ def default_run_scoped_event_tailer_offset(self) -> int:
def get_asset_status_cache_values(
self,
partitions_defs_by_key: Mapping[AssetKey, Optional[PartitionsDefinition]],
context: LoadingContext,
) -> Sequence[Optional["AssetStatusCacheValue"]]:
"""Get the cached status information for each asset."""
values = []
for asset_key, partitions_def in partitions_defs_by_key.items():
values.append(
get_and_update_asset_status_cache_value(self._instance, asset_key, partitions_def)
get_and_update_asset_status_cache_value(
self._instance, asset_key, partitions_def, loading_context=context
)
)
return values
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,9 @@ def from_db_string(db_string: str) -> Optional["AssetStatusCacheValue"]:

@classmethod
def _blocking_batch_load(
cls, keys: Iterable[Tuple[AssetKey, PartitionsDefinition]], instance: "DagsterInstance"
cls, keys: Iterable[Tuple[AssetKey, PartitionsDefinition]], context: LoadingContext
) -> Iterable[Optional["AssetStatusCacheValue"]]:
return instance.event_log_storage.get_asset_status_cache_values(dict(keys))
return context.instance.event_log_storage.get_asset_status_cache_values(dict(keys), context)

def deserialize_materialized_partition_subsets(
self, partitions_def: PartitionsDefinition
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
from dagster._core.execution.plan.objects import StepFailureData, StepSuccessData
from dagster._core.execution.stats import StepEventStatus
from dagster._core.instance import RUNLESS_JOB_NAME, RUNLESS_RUN_ID
from dagster._core.loader import LoadingContext
from dagster._core.remote_representation.external_data import PartitionsSnap
from dagster._core.remote_representation.origin import (
InProcessCodeLocationOrigin,
Expand Down Expand Up @@ -6023,7 +6024,9 @@ def test_get_updated_asset_status_cache_values(
AssetKey("static"): StaticPartitionsDefinition(["a", "b", "c"]),
}

assert storage.get_asset_status_cache_values(partition_defs_by_key) == [
assert storage.get_asset_status_cache_values(
partition_defs_by_key, LoadingContext.ephemeral(instance)
) == [
None,
None,
None,
Expand All @@ -6038,6 +6041,10 @@ def test_get_updated_asset_status_cache_values(
instance.report_runless_asset_event(AssetMaterialization(asset_key="static", partition="a"))

partition_defs = list(partition_defs_by_key.values())
for i, value in enumerate(storage.get_asset_status_cache_values(partition_defs_by_key)):
for i, value in enumerate(
storage.get_asset_status_cache_values(
partition_defs_by_key, LoadingContext.ephemeral(instance)
),
):
assert value is not None
assert len(value.deserialize_materialized_partition_subsets(partition_defs[i])) == 1
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,9 @@ class LoadableThing(
):
@classmethod
def _blocking_batch_load(
cls, keys: Iterable[str], instance: mock.MagicMock
cls, keys: Iterable[str], context: mock.MagicMock
) -> List["LoadableThing"]:
instance.query(keys)
context.query(keys)
return [LoadableThing(key, random.randint(0, 100000)) for key in keys]


Expand Down

0 comments on commit 2e65c08

Please sign in to comment.