diff --git a/python_modules/dagster/dagster/_core/loader.py b/python_modules/dagster/dagster/_core/loader.py index e0de5e20159fe..5b78cc1b65d11 100644 --- a/python_modules/dagster/dagster/_core/loader.py +++ b/python_modules/dagster/dagster/_core/loader.py @@ -65,8 +65,8 @@ def get_loaders_for( if not issubclass(ttype, InstanceLoadableBy): check.failed(f"{ttype} is not Loadable") - batch_load_fn = partial(ttype._batch_load, instance=self.instance) # noqa - blocking_batch_load_fn = partial(ttype._blocking_batch_load, instance=self.instance) # noqa + batch_load_fn = partial(ttype._batch_load, context=self) # noqa + blocking_batch_load_fn = partial(ttype._blocking_batch_load, context=self) # noqa self.loaders[ttype] = ( DataLoader(batch_load_fn=batch_load_fn), @@ -88,14 +88,14 @@ class InstanceLoadableBy(ABC, Generic[TKey]): @classmethod async def _batch_load( - cls, keys: Iterable[TKey], instance: "DagsterInstance" + cls, keys: Iterable[TKey], context: "LoadingContext" ) -> Iterable[Optional[Self]]: - return cls._blocking_batch_load(keys, instance) + return cls._blocking_batch_load(keys, context) @classmethod @abstractmethod def _blocking_batch_load( - cls, keys: Iterable[TKey], instance: "DagsterInstance" + cls, keys: Iterable[TKey], context: "LoadingContext" ) -> Iterable[Optional[Self]]: # There is no good way of turning an async function into a sync one that # will allow us to execute that sync function inside of a broader async context. diff --git a/python_modules/dagster/dagster/_core/storage/asset_check_execution_record.py b/python_modules/dagster/dagster/_core/storage/asset_check_execution_record.py index 7d7e94496d20e..a76887308e9b3 100644 --- a/python_modules/dagster/dagster/_core/storage/asset_check_execution_record.py +++ b/python_modules/dagster/dagster/_core/storage/asset_check_execution_record.py @@ -5,7 +5,6 @@ from dagster._core.definitions.asset_check_evaluation import AssetCheckEvaluation from dagster._core.definitions.asset_key import AssetCheckKey from dagster._core.events.log import DagsterEventType, EventLogEntry -from dagster._core.instance import DagsterInstance from dagster._core.loader import InstanceLoadableBy, LoadingContext from dagster._core.storage.dagster_run import DagsterRunStatus, RunRecord from dagster._serdes.serdes import deserialize_value @@ -124,9 +123,9 @@ def from_db_row(cls, row, key: AssetCheckKey) -> "AssetCheckExecutionRecord": @classmethod def _blocking_batch_load( - cls, keys: Iterable[AssetCheckKey], instance: DagsterInstance + cls, keys: Iterable[AssetCheckKey], context: LoadingContext ) -> Iterable[Optional["AssetCheckExecutionRecord"]]: - records_by_key = instance.event_log_storage.get_latest_asset_check_execution_by_key( + records_by_key = context.instance.event_log_storage.get_latest_asset_check_execution_by_key( list(keys) ) return [records_by_key.get(key) for key in keys] diff --git a/python_modules/dagster/dagster/_core/storage/dagster_run.py b/python_modules/dagster/dagster/_core/storage/dagster_run.py index e292e6c5fda2b..fdad1a6da185d 100644 --- a/python_modules/dagster/dagster/_core/storage/dagster_run.py +++ b/python_modules/dagster/dagster/_core/storage/dagster_run.py @@ -20,7 +20,7 @@ from dagster._annotations import PublicAttr, experimental_param, public from dagster._core.definitions.asset_check_spec import AssetCheckKey from dagster._core.definitions.events import AssetKey -from dagster._core.loader import InstanceLoadableBy +from dagster._core.loader import InstanceLoadableBy, LoadingContext from dagster._core.origin import JobPythonOrigin from dagster._core.storage.tags import ( ASSET_EVALUATION_ID_TAG, @@ -41,7 +41,6 @@ if TYPE_CHECKING: from dagster._core.definitions.schedule_definition import ScheduleDefinition from dagster._core.definitions.sensor_definition import SensorDefinition - from dagster._core.instance import DagsterInstance from dagster._core.remote_representation.external import RemoteSchedule, RemoteSensor from dagster._core.remote_representation.origin import RemoteJobOrigin from dagster._core.scheduler.instigation import InstigatorState @@ -643,12 +642,12 @@ def __new__( @classmethod def _blocking_batch_load( - cls, keys: Iterable[str], instance: "DagsterInstance" + cls, keys: Iterable[str], context: LoadingContext ) -> Iterable[Optional["RunRecord"]]: result_map: Dict[str, Optional[RunRecord]] = {run_id: None for run_id in keys} # this should be replaced with an async DB call - records = instance.get_run_records(RunsFilter(run_ids=list(result_map.keys()))) + records = context.instance.get_run_records(RunsFilter(run_ids=list(result_map.keys()))) for record in records: result_map[record.dagster_run.run_id] = record diff --git a/python_modules/dagster/dagster/_core/storage/event_log/base.py b/python_modules/dagster/dagster/_core/storage/event_log/base.py index e8ccfeca8878e..2c823ef1b7358 100644 --- a/python_modules/dagster/dagster/_core/storage/event_log/base.py +++ b/python_modules/dagster/dagster/_core/storage/event_log/base.py @@ -32,8 +32,8 @@ build_run_stats_from_events, build_run_step_stats_from_events, ) -from dagster._core.instance import DagsterInstance, MayHaveInstanceWeakref, T_DagsterInstance -from dagster._core.loader import InstanceLoadableBy +from dagster._core.instance import MayHaveInstanceWeakref, T_DagsterInstance +from dagster._core.loader import InstanceLoadableBy, LoadingContext from dagster._core.storage.asset_check_execution_record import AssetCheckExecutionRecord from dagster._core.storage.dagster_run import DagsterRunStatsSnapshot from dagster._core.storage.partition_status_cache import get_and_update_asset_status_cache_value @@ -138,11 +138,11 @@ class AssetRecord( @classmethod def _blocking_batch_load( - cls, keys: Iterable[AssetKey], instance: DagsterInstance + cls, keys: Iterable[AssetKey], context: LoadingContext ) -> Iterable[Optional["AssetRecord"]]: records_by_key = { record.asset_entry.asset_key: record - for record in instance.get_asset_records(list(keys)) + for record in context.instance.get_asset_records(list(keys)) } return [records_by_key.get(key) for key in keys] @@ -160,9 +160,11 @@ class AssetCheckSummaryRecord( ): @classmethod def _blocking_batch_load( - cls, keys: Iterable[AssetCheckKey], instance: DagsterInstance + cls, keys: Iterable[AssetCheckKey], context: LoadingContext ) -> Iterable[Optional["AssetCheckSummaryRecord"]]: - records_by_key = instance.event_log_storage.get_asset_check_summary_records(list(keys)) + records_by_key = context.instance.event_log_storage.get_asset_check_summary_records( + list(keys) + ) return [records_by_key[key] for key in keys] @@ -653,11 +655,14 @@ def default_run_scoped_event_tailer_offset(self) -> int: def get_asset_status_cache_values( self, partitions_defs_by_key: Mapping[AssetKey, Optional[PartitionsDefinition]], + context: LoadingContext, ) -> Sequence[Optional["AssetStatusCacheValue"]]: """Get the cached status information for each asset.""" values = [] for asset_key, partitions_def in partitions_defs_by_key.items(): values.append( - get_and_update_asset_status_cache_value(self._instance, asset_key, partitions_def) + get_and_update_asset_status_cache_value( + self._instance, asset_key, partitions_def, loading_context=context + ) ) return values diff --git a/python_modules/dagster/dagster/_core/storage/partition_status_cache.py b/python_modules/dagster/dagster/_core/storage/partition_status_cache.py index 6e2c194e8a824..6e653132cbd22 100644 --- a/python_modules/dagster/dagster/_core/storage/partition_status_cache.py +++ b/python_modules/dagster/dagster/_core/storage/partition_status_cache.py @@ -148,9 +148,9 @@ def from_db_string(db_string: str) -> Optional["AssetStatusCacheValue"]: @classmethod def _blocking_batch_load( - cls, keys: Iterable[Tuple[AssetKey, PartitionsDefinition]], instance: "DagsterInstance" + cls, keys: Iterable[Tuple[AssetKey, PartitionsDefinition]], context: LoadingContext ) -> Iterable[Optional["AssetStatusCacheValue"]]: - return instance.event_log_storage.get_asset_status_cache_values(dict(keys)) + return context.instance.event_log_storage.get_asset_status_cache_values(dict(keys), context) def deserialize_materialized_partition_subsets( self, partitions_def: PartitionsDefinition diff --git a/python_modules/dagster/dagster_tests/storage_tests/utils/event_log_storage.py b/python_modules/dagster/dagster_tests/storage_tests/utils/event_log_storage.py index 63902b4ccf886..bef2efdb64ef7 100644 --- a/python_modules/dagster/dagster_tests/storage_tests/utils/event_log_storage.py +++ b/python_modules/dagster/dagster_tests/storage_tests/utils/event_log_storage.py @@ -83,6 +83,7 @@ from dagster._core.execution.plan.objects import StepFailureData, StepSuccessData from dagster._core.execution.stats import StepEventStatus from dagster._core.instance import RUNLESS_JOB_NAME, RUNLESS_RUN_ID +from dagster._core.loader import LoadingContextForTest from dagster._core.remote_representation.external_data import PartitionsSnap from dagster._core.remote_representation.origin import ( InProcessCodeLocationOrigin, @@ -6023,7 +6024,9 @@ def test_get_updated_asset_status_cache_values( AssetKey("static"): StaticPartitionsDefinition(["a", "b", "c"]), } - assert storage.get_asset_status_cache_values(partition_defs_by_key) == [ + assert storage.get_asset_status_cache_values( + partition_defs_by_key, LoadingContextForTest(instance) + ) == [ None, None, None, @@ -6038,6 +6041,10 @@ def test_get_updated_asset_status_cache_values( instance.report_runless_asset_event(AssetMaterialization(asset_key="static", partition="a")) partition_defs = list(partition_defs_by_key.values()) - for i, value in enumerate(storage.get_asset_status_cache_values(partition_defs_by_key)): + for i, value in enumerate( + storage.get_asset_status_cache_values( + partition_defs_by_key, LoadingContextForTest(instance) + ), + ): assert value is not None assert len(value.deserialize_materialized_partition_subsets(partition_defs[i])) == 1 diff --git a/python_modules/dagster/dagster_tests/utils_tests/test_dataloader.py b/python_modules/dagster/dagster_tests/utils_tests/test_dataloader.py index 8fe56fdfbe836..b1e24716cb746 100644 --- a/python_modules/dagster/dagster_tests/utils_tests/test_dataloader.py +++ b/python_modules/dagster/dagster_tests/utils_tests/test_dataloader.py @@ -144,9 +144,9 @@ class LoadableThing( ): @classmethod def _blocking_batch_load( - cls, keys: Iterable[str], instance: mock.MagicMock + cls, keys: Iterable[str], context: mock.MagicMock ) -> List["LoadableThing"]: - instance.query(keys) + context.query(keys) return [LoadableThing(key, random.randint(0, 100000)) for key in keys]