Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass loading context through _blocking_batch_load #25377

Merged
merged 3 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions python_modules/dagster/dagster/_core/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ def get_loaders_for(
if not issubclass(ttype, InstanceLoadableBy):
check.failed(f"{ttype} is not Loadable")

batch_load_fn = partial(ttype._batch_load, instance=self.instance) # noqa
blocking_batch_load_fn = partial(ttype._blocking_batch_load, instance=self.instance) # noqa
Comment on lines 65 to -69
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When I first set this up, I imagined we might have other _X_LoadableBy interfaces that would bind _X_ instead of the instance for fetching — so in this case, ContextLoadableBy.

I'm not sure that having different flavors is necessary in the end, but we might want to drop the "Instance" part of the LoadableBy interface's name if we always want to pass the context.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, my thinking was that in a world where there are multiple things we want to load with (i.e. multiple Xs), there would likely be single objects that want more than one of those Xs, so at that point we'd probably make this change anyway.

Agreed that LoadableBy would be a better name (maybe that could be a follow-up to this PR?).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Either a follow-up or doing it in this PR is fine with me.

batch_load_fn = partial(ttype._batch_load, context=self) # noqa
blocking_batch_load_fn = partial(ttype._blocking_batch_load, context=self) # noqa

self.loaders[ttype] = (
DataLoader(batch_load_fn=batch_load_fn),
Expand All @@ -88,14 +88,14 @@ class InstanceLoadableBy(ABC, Generic[TKey]):

@classmethod
async def _batch_load(
cls, keys: Iterable[TKey], instance: "DagsterInstance"
cls, keys: Iterable[TKey], context: "LoadingContext"
) -> Iterable[Optional[Self]]:
return cls._blocking_batch_load(keys, instance)
return cls._blocking_batch_load(keys, context)

@classmethod
@abstractmethod
def _blocking_batch_load(
cls, keys: Iterable[TKey], instance: "DagsterInstance"
cls, keys: Iterable[TKey], context: "LoadingContext"
) -> Iterable[Optional[Self]]:
# There is no good way of turning an async function into a sync one that
# will allow us to execute that sync function inside of a broader async context.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from dagster._core.definitions.asset_check_evaluation import AssetCheckEvaluation
from dagster._core.definitions.asset_key import AssetCheckKey
from dagster._core.events.log import DagsterEventType, EventLogEntry
from dagster._core.instance import DagsterInstance
from dagster._core.loader import InstanceLoadableBy, LoadingContext
from dagster._core.storage.dagster_run import DagsterRunStatus, RunRecord
from dagster._serdes.serdes import deserialize_value
Expand Down Expand Up @@ -124,9 +123,9 @@ def from_db_row(cls, row, key: AssetCheckKey) -> "AssetCheckExecutionRecord":

@classmethod
def _blocking_batch_load(
cls, keys: Iterable[AssetCheckKey], instance: DagsterInstance
cls, keys: Iterable[AssetCheckKey], context: LoadingContext
) -> Iterable[Optional["AssetCheckExecutionRecord"]]:
records_by_key = instance.event_log_storage.get_latest_asset_check_execution_by_key(
records_by_key = context.instance.event_log_storage.get_latest_asset_check_execution_by_key(
list(keys)
)
return [records_by_key.get(key) for key in keys]
Expand Down
7 changes: 3 additions & 4 deletions python_modules/dagster/dagster/_core/storage/dagster_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from dagster._annotations import PublicAttr, experimental_param, public
from dagster._core.definitions.asset_check_spec import AssetCheckKey
from dagster._core.definitions.events import AssetKey
from dagster._core.loader import InstanceLoadableBy
from dagster._core.loader import InstanceLoadableBy, LoadingContext
from dagster._core.origin import JobPythonOrigin
from dagster._core.storage.tags import (
ASSET_EVALUATION_ID_TAG,
Expand All @@ -41,7 +41,6 @@
if TYPE_CHECKING:
from dagster._core.definitions.schedule_definition import ScheduleDefinition
from dagster._core.definitions.sensor_definition import SensorDefinition
from dagster._core.instance import DagsterInstance
from dagster._core.remote_representation.external import RemoteSchedule, RemoteSensor
from dagster._core.remote_representation.origin import RemoteJobOrigin
from dagster._core.scheduler.instigation import InstigatorState
Expand Down Expand Up @@ -643,12 +642,12 @@ def __new__(

@classmethod
def _blocking_batch_load(
cls, keys: Iterable[str], instance: "DagsterInstance"
cls, keys: Iterable[str], context: LoadingContext
) -> Iterable[Optional["RunRecord"]]:
result_map: Dict[str, Optional[RunRecord]] = {run_id: None for run_id in keys}

# this should be replaced with an async DB call
records = instance.get_run_records(RunsFilter(run_ids=list(result_map.keys())))
records = context.instance.get_run_records(RunsFilter(run_ids=list(result_map.keys())))

for record in records:
result_map[record.dagster_run.run_id] = record
Expand Down
19 changes: 12 additions & 7 deletions python_modules/dagster/dagster/_core/storage/event_log/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@
build_run_stats_from_events,
build_run_step_stats_from_events,
)
from dagster._core.instance import DagsterInstance, MayHaveInstanceWeakref, T_DagsterInstance
from dagster._core.loader import InstanceLoadableBy
from dagster._core.instance import MayHaveInstanceWeakref, T_DagsterInstance
from dagster._core.loader import InstanceLoadableBy, LoadingContext
from dagster._core.storage.asset_check_execution_record import AssetCheckExecutionRecord
from dagster._core.storage.dagster_run import DagsterRunStatsSnapshot
from dagster._core.storage.partition_status_cache import get_and_update_asset_status_cache_value
Expand Down Expand Up @@ -138,11 +138,11 @@ class AssetRecord(

@classmethod
def _blocking_batch_load(
cls, keys: Iterable[AssetKey], instance: DagsterInstance
cls, keys: Iterable[AssetKey], context: LoadingContext
) -> Iterable[Optional["AssetRecord"]]:
records_by_key = {
record.asset_entry.asset_key: record
for record in instance.get_asset_records(list(keys))
for record in context.instance.get_asset_records(list(keys))
}
return [records_by_key.get(key) for key in keys]

Expand All @@ -160,9 +160,11 @@ class AssetCheckSummaryRecord(
):
@classmethod
def _blocking_batch_load(
cls, keys: Iterable[AssetCheckKey], instance: DagsterInstance
cls, keys: Iterable[AssetCheckKey], context: LoadingContext
) -> Iterable[Optional["AssetCheckSummaryRecord"]]:
records_by_key = instance.event_log_storage.get_asset_check_summary_records(list(keys))
records_by_key = context.instance.event_log_storage.get_asset_check_summary_records(
list(keys)
)
return [records_by_key[key] for key in keys]


Expand Down Expand Up @@ -653,11 +655,14 @@ def default_run_scoped_event_tailer_offset(self) -> int:
def get_asset_status_cache_values(
self,
partitions_defs_by_key: Mapping[AssetKey, Optional[PartitionsDefinition]],
context: LoadingContext,
) -> Sequence[Optional["AssetStatusCacheValue"]]:
"""Get the cached status information for each asset."""
values = []
for asset_key, partitions_def in partitions_defs_by_key.items():
values.append(
get_and_update_asset_status_cache_value(self._instance, asset_key, partitions_def)
get_and_update_asset_status_cache_value(
self._instance, asset_key, partitions_def, loading_context=context
)
)
return values
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,9 @@ def from_db_string(db_string: str) -> Optional["AssetStatusCacheValue"]:

@classmethod
def _blocking_batch_load(
cls, keys: Iterable[Tuple[AssetKey, PartitionsDefinition]], instance: "DagsterInstance"
cls, keys: Iterable[Tuple[AssetKey, PartitionsDefinition]], context: LoadingContext
) -> Iterable[Optional["AssetStatusCacheValue"]]:
return instance.event_log_storage.get_asset_status_cache_values(dict(keys))
return context.instance.event_log_storage.get_asset_status_cache_values(dict(keys), context)

def deserialize_materialized_partition_subsets(
self, partitions_def: PartitionsDefinition
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
from dagster._core.execution.plan.objects import StepFailureData, StepSuccessData
from dagster._core.execution.stats import StepEventStatus
from dagster._core.instance import RUNLESS_JOB_NAME, RUNLESS_RUN_ID
from dagster._core.loader import LoadingContextForTest
from dagster._core.remote_representation.external_data import PartitionsSnap
from dagster._core.remote_representation.origin import (
InProcessCodeLocationOrigin,
Expand Down Expand Up @@ -6023,7 +6024,9 @@ def test_get_updated_asset_status_cache_values(
AssetKey("static"): StaticPartitionsDefinition(["a", "b", "c"]),
}

assert storage.get_asset_status_cache_values(partition_defs_by_key) == [
assert storage.get_asset_status_cache_values(
partition_defs_by_key, LoadingContextForTest(instance)
) == [
None,
None,
None,
Expand All @@ -6038,6 +6041,10 @@ def test_get_updated_asset_status_cache_values(
instance.report_runless_asset_event(AssetMaterialization(asset_key="static", partition="a"))

partition_defs = list(partition_defs_by_key.values())
for i, value in enumerate(storage.get_asset_status_cache_values(partition_defs_by_key)):
for i, value in enumerate(
storage.get_asset_status_cache_values(
partition_defs_by_key, LoadingContextForTest(instance)
),
):
assert value is not None
assert len(value.deserialize_materialized_partition_subsets(partition_defs[i])) == 1
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,9 @@ class LoadableThing(
):
@classmethod
def _blocking_batch_load(
cls, keys: Iterable[str], instance: mock.MagicMock
cls, keys: Iterable[str], context: mock.MagicMock
) -> List["LoadableThing"]:
instance.query(keys)
context.query(keys)
return [LoadableThing(key, random.randint(0, 100000)) for key in keys]


Expand Down