From 7880a4c45f7977b416c7e370d051168703cef45b Mon Sep 17 00:00:00 2001 From: briantu Date: Fri, 18 Oct 2024 16:52:23 -0400 Subject: [PATCH 1/3] Pass loading context through _blocking_batch_load --- .../dagster/dagster/_core/loader.py | 30 +++++++++++++++---- .../storage/asset_check_execution_record.py | 5 ++-- .../dagster/_core/storage/dagster_run.py | 7 ++--- .../dagster/_core/storage/event_log/base.py | 19 +++++++----- .../_core/storage/partition_status_cache.py | 4 +-- .../storage_tests/utils/event_log_storage.py | 11 +++++-- .../utils_tests/test_dataloader.py | 4 +-- 7 files changed, 55 insertions(+), 25 deletions(-) diff --git a/python_modules/dagster/dagster/_core/loader.py b/python_modules/dagster/dagster/_core/loader.py index e0de5e20159fe..9853d3bb4dedc 100644 --- a/python_modules/dagster/dagster/_core/loader.py +++ b/python_modules/dagster/dagster/_core/loader.py @@ -58,6 +58,10 @@ def instance(self) -> "DagsterInstance": def loaders(self) -> Dict[Type, Tuple[DataLoader, BlockingDataLoader]]: raise NotImplementedError() + @staticmethod + def ephemeral(instance: "DagsterInstance") -> "LoadingContext": + return EphemeralLoadingContext(instance) + def get_loaders_for( self, ttype: Type["InstanceLoadableBy"] ) -> Tuple[DataLoader, BlockingDataLoader]: @@ -65,8 +69,8 @@ def get_loaders_for( if not issubclass(ttype, InstanceLoadableBy): check.failed(f"{ttype} is not Loadable") - batch_load_fn = partial(ttype._batch_load, instance=self.instance) # noqa - blocking_batch_load_fn = partial(ttype._blocking_batch_load, instance=self.instance) # noqa + batch_load_fn = partial(ttype._batch_load, context=self) # noqa + blocking_batch_load_fn = partial(ttype._blocking_batch_load, context=self) # noqa self.loaders[ttype] = ( DataLoader(batch_load_fn=batch_load_fn), @@ -80,6 +84,22 @@ def clear_loaders(self) -> None: del self.loaders[ttype] +class EphemeralLoadingContext(LoadingContext): + """Loading context that can be constructed for short-lived method resolution.""" + + def __init__(self, instance: "DagsterInstance"): + self._instance = instance + self._loaders = {} + + @property + def instance(self) -> "DagsterInstance": + return self._instance + + @property + def loaders(self) -> Dict[Type, Tuple[DataLoader, BlockingDataLoader]]: + return self._loaders + + # Expected there may be other "Loadable" base classes based on what is needed to load. @@ -88,14 +108,14 @@ class InstanceLoadableBy(ABC, Generic[TKey]): @classmethod async def _batch_load( - cls, keys: Iterable[TKey], instance: "DagsterInstance" + cls, keys: Iterable[TKey], context: "LoadingContext" ) -> Iterable[Optional[Self]]: - return cls._blocking_batch_load(keys, instance) + return cls._blocking_batch_load(keys, context) @classmethod @abstractmethod def _blocking_batch_load( - cls, keys: Iterable[TKey], instance: "DagsterInstance" + cls, keys: Iterable[TKey], context: "LoadingContext" ) -> Iterable[Optional[Self]]: # There is no good way of turning an async function into a sync one that # will allow us to execute that sync function inside of a broader async context. diff --git a/python_modules/dagster/dagster/_core/storage/asset_check_execution_record.py b/python_modules/dagster/dagster/_core/storage/asset_check_execution_record.py index 7d7e94496d20e..a76887308e9b3 100644 --- a/python_modules/dagster/dagster/_core/storage/asset_check_execution_record.py +++ b/python_modules/dagster/dagster/_core/storage/asset_check_execution_record.py @@ -5,7 +5,6 @@ from dagster._core.definitions.asset_check_evaluation import AssetCheckEvaluation from dagster._core.definitions.asset_key import AssetCheckKey from dagster._core.events.log import DagsterEventType, EventLogEntry -from dagster._core.instance import DagsterInstance from dagster._core.loader import InstanceLoadableBy, LoadingContext from dagster._core.storage.dagster_run import DagsterRunStatus, RunRecord from dagster._serdes.serdes import deserialize_value @@ -124,9 +123,9 @@ def from_db_row(cls, row, key: AssetCheckKey) -> "AssetCheckExecutionRecord": @classmethod def _blocking_batch_load( - cls, keys: Iterable[AssetCheckKey], instance: DagsterInstance + cls, keys: Iterable[AssetCheckKey], context: LoadingContext ) -> Iterable[Optional["AssetCheckExecutionRecord"]]: - records_by_key = instance.event_log_storage.get_latest_asset_check_execution_by_key( + records_by_key = context.instance.event_log_storage.get_latest_asset_check_execution_by_key( list(keys) ) return [records_by_key.get(key) for key in keys] diff --git a/python_modules/dagster/dagster/_core/storage/dagster_run.py b/python_modules/dagster/dagster/_core/storage/dagster_run.py index e292e6c5fda2b..fdad1a6da185d 100644 --- a/python_modules/dagster/dagster/_core/storage/dagster_run.py +++ b/python_modules/dagster/dagster/_core/storage/dagster_run.py @@ -20,7 +20,7 @@ from dagster._annotations import PublicAttr, experimental_param, public from dagster._core.definitions.asset_check_spec import AssetCheckKey from dagster._core.definitions.events import AssetKey -from dagster._core.loader import InstanceLoadableBy +from dagster._core.loader import InstanceLoadableBy, LoadingContext from dagster._core.origin import JobPythonOrigin from dagster._core.storage.tags import ( ASSET_EVALUATION_ID_TAG, @@ -41,7 +41,6 @@ if TYPE_CHECKING: from dagster._core.definitions.schedule_definition import ScheduleDefinition from dagster._core.definitions.sensor_definition import SensorDefinition - from dagster._core.instance import DagsterInstance from dagster._core.remote_representation.external import RemoteSchedule, RemoteSensor from dagster._core.remote_representation.origin import RemoteJobOrigin from dagster._core.scheduler.instigation import InstigatorState @@ -643,12 +642,12 @@ def __new__( @classmethod def _blocking_batch_load( - cls, keys: Iterable[str], instance: "DagsterInstance" + cls, keys: Iterable[str], context: LoadingContext ) -> Iterable[Optional["RunRecord"]]: result_map: Dict[str, Optional[RunRecord]] = {run_id: None for run_id in keys} # this should be replaced with an async DB call - records = instance.get_run_records(RunsFilter(run_ids=list(result_map.keys()))) + records = context.instance.get_run_records(RunsFilter(run_ids=list(result_map.keys()))) for record in records: result_map[record.dagster_run.run_id] = record diff --git a/python_modules/dagster/dagster/_core/storage/event_log/base.py b/python_modules/dagster/dagster/_core/storage/event_log/base.py index e8ccfeca8878e..2c823ef1b7358 100644 --- a/python_modules/dagster/dagster/_core/storage/event_log/base.py +++ b/python_modules/dagster/dagster/_core/storage/event_log/base.py @@ -32,8 +32,8 @@ build_run_stats_from_events, build_run_step_stats_from_events, ) -from dagster._core.instance import DagsterInstance, MayHaveInstanceWeakref, T_DagsterInstance -from dagster._core.loader import InstanceLoadableBy +from dagster._core.instance import MayHaveInstanceWeakref, T_DagsterInstance +from dagster._core.loader import InstanceLoadableBy, LoadingContext from dagster._core.storage.asset_check_execution_record import AssetCheckExecutionRecord from dagster._core.storage.dagster_run import DagsterRunStatsSnapshot from dagster._core.storage.partition_status_cache import get_and_update_asset_status_cache_value @@ -138,11 +138,11 @@ class AssetRecord( @classmethod def _blocking_batch_load( - cls, keys: Iterable[AssetKey], instance: DagsterInstance + cls, keys: Iterable[AssetKey], context: LoadingContext ) -> Iterable[Optional["AssetRecord"]]: records_by_key = { record.asset_entry.asset_key: record - for record in instance.get_asset_records(list(keys)) + for record in context.instance.get_asset_records(list(keys)) } return [records_by_key.get(key) for key in keys] @@ -160,9 +160,11 @@ class AssetCheckSummaryRecord( ): @classmethod def _blocking_batch_load( - cls, keys: Iterable[AssetCheckKey], instance: DagsterInstance + cls, keys: Iterable[AssetCheckKey], context: LoadingContext ) -> Iterable[Optional["AssetCheckSummaryRecord"]]: - records_by_key = instance.event_log_storage.get_asset_check_summary_records(list(keys)) + records_by_key = context.instance.event_log_storage.get_asset_check_summary_records( + list(keys) + ) return [records_by_key[key] for key in keys] @@ -653,11 +655,14 @@ def default_run_scoped_event_tailer_offset(self) -> int: def get_asset_status_cache_values( self, partitions_defs_by_key: Mapping[AssetKey, Optional[PartitionsDefinition]], + context: LoadingContext, ) -> Sequence[Optional["AssetStatusCacheValue"]]: """Get the cached status information for each asset.""" values = [] for asset_key, partitions_def in partitions_defs_by_key.items(): values.append( - get_and_update_asset_status_cache_value(self._instance, asset_key, partitions_def) + get_and_update_asset_status_cache_value( + self._instance, asset_key, partitions_def, loading_context=context + ) ) return values diff --git a/python_modules/dagster/dagster/_core/storage/partition_status_cache.py b/python_modules/dagster/dagster/_core/storage/partition_status_cache.py index 6e2c194e8a824..6e653132cbd22 100644 --- a/python_modules/dagster/dagster/_core/storage/partition_status_cache.py +++ b/python_modules/dagster/dagster/_core/storage/partition_status_cache.py @@ -148,9 +148,9 @@ def from_db_string(db_string: str) -> Optional["AssetStatusCacheValue"]: @classmethod def _blocking_batch_load( - cls, keys: Iterable[Tuple[AssetKey, PartitionsDefinition]], instance: "DagsterInstance" + cls, keys: Iterable[Tuple[AssetKey, PartitionsDefinition]], context: LoadingContext ) -> Iterable[Optional["AssetStatusCacheValue"]]: - return instance.event_log_storage.get_asset_status_cache_values(dict(keys)) + return context.instance.event_log_storage.get_asset_status_cache_values(dict(keys), context) def deserialize_materialized_partition_subsets( self, partitions_def: PartitionsDefinition diff --git a/python_modules/dagster/dagster_tests/storage_tests/utils/event_log_storage.py b/python_modules/dagster/dagster_tests/storage_tests/utils/event_log_storage.py index 63902b4ccf886..ef3aa3ac81a96 100644 --- a/python_modules/dagster/dagster_tests/storage_tests/utils/event_log_storage.py +++ b/python_modules/dagster/dagster_tests/storage_tests/utils/event_log_storage.py @@ -83,6 +83,7 @@ from dagster._core.execution.plan.objects import StepFailureData, StepSuccessData from dagster._core.execution.stats import StepEventStatus from dagster._core.instance import RUNLESS_JOB_NAME, RUNLESS_RUN_ID +from dagster._core.loader import LoadingContext from dagster._core.remote_representation.external_data import PartitionsSnap from dagster._core.remote_representation.origin import ( InProcessCodeLocationOrigin, @@ -6023,7 +6024,9 @@ def test_get_updated_asset_status_cache_values( AssetKey("static"): StaticPartitionsDefinition(["a", "b", "c"]), } - assert storage.get_asset_status_cache_values(partition_defs_by_key) == [ + assert storage.get_asset_status_cache_values( + partition_defs_by_key, LoadingContext.ephemeral(instance) + ) == [ None, None, None, @@ -6038,6 +6041,10 @@ def test_get_updated_asset_status_cache_values( instance.report_runless_asset_event(AssetMaterialization(asset_key="static", partition="a")) partition_defs = list(partition_defs_by_key.values()) - for i, value in enumerate(storage.get_asset_status_cache_values(partition_defs_by_key)): + for i, value in enumerate( + storage.get_asset_status_cache_values( + partition_defs_by_key, LoadingContext.ephemeral(instance) + ), + ): assert value is not None assert len(value.deserialize_materialized_partition_subsets(partition_defs[i])) == 1 diff --git a/python_modules/dagster/dagster_tests/utils_tests/test_dataloader.py b/python_modules/dagster/dagster_tests/utils_tests/test_dataloader.py index 8fe56fdfbe836..b1e24716cb746 100644 --- a/python_modules/dagster/dagster_tests/utils_tests/test_dataloader.py +++ b/python_modules/dagster/dagster_tests/utils_tests/test_dataloader.py @@ -144,9 +144,9 @@ class LoadableThing( ): @classmethod def _blocking_batch_load( - cls, keys: Iterable[str], instance: mock.MagicMock + cls, keys: Iterable[str], context: mock.MagicMock ) -> List["LoadableThing"]: - instance.query(keys) + context.query(keys) return [LoadableThing(key, random.randint(0, 100000)) for key in keys] From da94d727bbbb68f887e727f6ea1db1199ba54d3c Mon Sep 17 00:00:00 2001 From: briantu Date: Mon, 21 Oct 2024 18:30:00 -0400 Subject: [PATCH 2/3] Use LoadingContextForTest --- python_modules/dagster/dagster/_core/loader.py | 18 +----------------- 1 file changed, 1 insertion(+), 17 deletions(-) diff --git a/python_modules/dagster/dagster/_core/loader.py b/python_modules/dagster/dagster/_core/loader.py index 9853d3bb4dedc..adec5e3084d95 100644 --- a/python_modules/dagster/dagster/_core/loader.py +++ b/python_modules/dagster/dagster/_core/loader.py @@ -60,7 +60,7 @@ def loaders(self) -> Dict[Type, Tuple[DataLoader, BlockingDataLoader]]: @staticmethod def ephemeral(instance: "DagsterInstance") -> "LoadingContext": - return EphemeralLoadingContext(instance) + return LoadingContextForTest(instance) def get_loaders_for( self, ttype: Type["InstanceLoadableBy"] @@ -84,22 +84,6 @@ def clear_loaders(self) -> None: del self.loaders[ttype] -class EphemeralLoadingContext(LoadingContext): - """Loading context that can be constructed for short-lived method resolution.""" - - def __init__(self, instance: "DagsterInstance"): - self._instance = instance - self._loaders = {} - - @property - def instance(self) -> "DagsterInstance": - return self._instance - - @property - def loaders(self) -> Dict[Type, Tuple[DataLoader, BlockingDataLoader]]: - return self._loaders - - # Expected there may be other "Loadable" base classes based on what is needed to load. From 45d68da4a2397a6e3f3a05bd49d258b370555130 Mon Sep 17 00:00:00 2001 From: briantu Date: Tue, 22 Oct 2024 16:39:08 -0400 Subject: [PATCH 3/3] Use LoadingContextForTest --- python_modules/dagster/dagster/_core/loader.py | 4 ---- .../dagster_tests/storage_tests/utils/event_log_storage.py | 6 +++--- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/python_modules/dagster/dagster/_core/loader.py b/python_modules/dagster/dagster/_core/loader.py index adec5e3084d95..5b78cc1b65d11 100644 --- a/python_modules/dagster/dagster/_core/loader.py +++ b/python_modules/dagster/dagster/_core/loader.py @@ -58,10 +58,6 @@ def instance(self) -> "DagsterInstance": def loaders(self) -> Dict[Type, Tuple[DataLoader, BlockingDataLoader]]: raise NotImplementedError() - @staticmethod - def ephemeral(instance: "DagsterInstance") -> "LoadingContext": - return LoadingContextForTest(instance) - def get_loaders_for( self, ttype: Type["InstanceLoadableBy"] ) -> Tuple[DataLoader, BlockingDataLoader]: diff --git a/python_modules/dagster/dagster_tests/storage_tests/utils/event_log_storage.py b/python_modules/dagster/dagster_tests/storage_tests/utils/event_log_storage.py index ef3aa3ac81a96..bef2efdb64ef7 100644 --- a/python_modules/dagster/dagster_tests/storage_tests/utils/event_log_storage.py +++ b/python_modules/dagster/dagster_tests/storage_tests/utils/event_log_storage.py @@ -83,7 +83,7 @@ from dagster._core.execution.plan.objects import StepFailureData, StepSuccessData from dagster._core.execution.stats import StepEventStatus from dagster._core.instance import RUNLESS_JOB_NAME, RUNLESS_RUN_ID -from dagster._core.loader import LoadingContext +from dagster._core.loader import LoadingContextForTest from dagster._core.remote_representation.external_data import PartitionsSnap from dagster._core.remote_representation.origin import ( InProcessCodeLocationOrigin, @@ -6025,7 +6025,7 @@ def test_get_updated_asset_status_cache_values( } assert storage.get_asset_status_cache_values( - partition_defs_by_key, LoadingContext.ephemeral(instance) + partition_defs_by_key, LoadingContextForTest(instance) ) == [ None, None, @@ -6043,7 +6043,7 @@ def test_get_updated_asset_status_cache_values( partition_defs = list(partition_defs_by_key.values()) for i, value in enumerate( storage.get_asset_status_cache_values( - partition_defs_by_key, LoadingContext.ephemeral(instance) + partition_defs_by_key, LoadingContextForTest(instance) ), ): assert value is not None