diff --git a/examples/partition_example/partition_example.py b/examples/partition_example/partition_example.py
index 5fc8e80c2499c..638bfc87088d1 100644
--- a/examples/partition_example/partition_example.py
+++ b/examples/partition_example/partition_example.py
@@ -33,7 +33,7 @@ def relativedelta(*args, **kwargs):
     metadata={"partition_expr": "LastModifiedDate"},
 )
 def salesforce_customers(context: AssetExecutionContext) -> pd.DataFrame:
-    start_date_str = context.asset_partition_key_for_output()
+    start_date_str = context.partition_info.key
     timezone = pytz.timezone("GMT")  # Replace 'Your_Timezone' with the desired timezone
     start_obj = datetime.datetime.strptime(start_date_str, "%Y-%m-%d").replace(tzinfo=timezone)

@@ -65,7 +65,7 @@ def realized_vol(context: AssetExecutionContext, orats_daily_prices: pd.DataFram
     The volatility is calculated using various methods such as close-to-close, Parkinson,
     Hodges-Tompkins, and Yang-Zhang. The function returns a DataFrame with the calculated volatilities.
     """
-    trade_date = context.asset_partition_key_for_output()
+    trade_date = context.partition_info.key
     ticker_id = 1

     df = all_realvols(orats_daily_prices, ticker_id, trade_date)
@@ -80,7 +80,7 @@ def realized_vol(context: AssetExecutionContext, orats_daily_prices: pd.DataFram

 @asset(io_manager_def="parquet_io_manager", partitions_def=hourly_partitions)
 def my_custom_df(context: AssetExecutionContext) -> pd.DataFrame:
-    start, end = context.asset_partitions_time_window_for_output()
+    start, end = context.partition_info.time_window

     df = pd.DataFrame({"timestamp": pd.date_range(start, end, freq="5T")})
     df["count"] = df["timestamp"].map(lambda a: random.randint(1, 1000))
@@ -93,7 +93,7 @@ def fetch_blog_posts_from_external_api(*args, **kwargs):

 @asset(partitions_def=HourlyPartitionsDefinition(start_date="2022-01-01-00:00"))
 def blog_posts(context: AssetExecutionContext) -> List[Dict]:
-    partition_datetime_str = context.asset_partition_key_for_output()
+    partition_datetime_str = context.partition_info.key
     hour = datetime.datetime.fromisoformat(partition_datetime_str)
     posts = fetch_blog_posts_from_external_api(hour_when_posted=hour)
     return posts
@@ -106,7 +106,7 @@ def blog_posts(context: AssetExecutionContext) -> List[Dict]:
     key_prefix=["snowflake", "eldermark_proxy"],
 )
 def resident(context: AssetExecutionContext) -> Output[pd.DataFrame]:
-    start, end = context.asset_partitions_time_window_for_output()
+    start, end = context.partition_info.time_window
     filter_str = f"LastMod_Stamp >= {start.timestamp()} AND LastMod_Stamp < {end.timestamp()}"

     records = context.resources.eldermark.fetch_obj(obj="Resident", filter=filter_str)
diff --git a/python_modules/dagster/dagster/_core/execution/context/compute.py b/python_modules/dagster/dagster/_core/execution/context/compute.py
index 792d061509c3b..a01cbb1d6bc34 100644
--- a/python_modules/dagster/dagster/_core/execution/context/compute.py
+++ b/python_modules/dagster/dagster/_core/execution/context/compute.py
@@ -1,6 +1,7 @@
 from abc import ABC, ABCMeta, abstractmethod
 from contextlib import contextmanager
 from contextvars import ContextVar
+from functools import cached_property
 from inspect import (
     _empty as EmptyAnnotation,
 )
@@ -47,7 +48,10 @@
 from dagster._core.definitions.partition import PartitionsDefinition
 from dagster._core.definitions.partition_key_range import PartitionKeyRange
 from dagster._core.definitions.step_launcher import StepLauncher
-from dagster._core.definitions.time_window_partitions import TimeWindow
+from dagster._core.definitions.time_window_partitions import (
+    TimeWindow,
+    has_one_dimension_time_window_partitioning,
+)
 from dagster._core.errors import (
     DagsterInvalidDefinitionError,
     DagsterInvalidPropertyError,
@@ -1375,6 +1379,17 @@ def _copy_docs_from_op_execution_context(obj):
     "run_config": "run.run_config",
     "run_tags": "run.tags",
     "get_op_execution_context": "op_execution_context",
+    "asset_partition_key_for_output": "partition_info.key",
+    "asset_partitions_time_window_for_output": "partition_info.time_window",
+    "asset_partition_key_range_for_output": "partition_info.key_range",
+    "asset_partition_key_range_for_input": "upstream_partition_info(asset_key).key_range",
+    "asset_partition_key_for_input": "upstream_partition_info(asset_key).key",
+    "asset_partitions_def_for_output": "partition_info.definition",
+    "asset_partitions_def_for_input": "upstream_partition_info(asset_key).definition",
+    "asset_partition_keys_for_output": "partition_info.keys",
+    "asset_partition_keys_for_input": "upstream_partition_info(asset_key).keys",
+    "asset_partitions_time_window_for_input": "upstream_partition_info(asset_key).time_window",
+    "has_partition_key": "is_partitioned_materialization",
 }

 ALTERNATE_EXPRESSIONS = {
@@ -1401,6 +1416,47 @@ def _get_deprecation_kwargs(attr: str):

     return deprecation_kwargs

+
+class PartitionInfo:  # TODO - this is a bad name, figure out something else
+    def __init__(
+        self,
+        key: Optional[str],
+        keys: Sequence[str],
+        key_range: PartitionKeyRange,
+        time_window: Optional[TimeWindow],
+        definition: PartitionsDefinition,
+    ):
+        self._key = key
+        self._keys = keys
+        self._key_range = key_range
+        self._time_window = time_window
+        self._definition = definition
+
+    @property
+    def key(self) -> str:
+        if self._key is None:
+            raise DagsterInvariantViolationError(
+                "Cannot access partition key for a partitioned run with a range of partitions."
+                " Call key_range instead."
+            )
+        return self._key
+
+    @property
+    def keys(self) -> Sequence[str]:
+        return self._keys
+
+    @property
+    def key_range(self) -> PartitionKeyRange:
+        return self._key_range
+
+    @property
+    def time_window(self) -> TimeWindow:
+        if self._time_window is None:
+            raise ValueError(
+                "Tried to get partition time window for an asset that is not time-partitioned."
+            )
+        return self._time_window
+
+
 class AssetExecutionContext(OpExecutionContext):
     def __init__(self, op_execution_context: OpExecutionContext) -> None:
         self._op_execution_context = check.inst_param(
@@ -1408,6 +1464,8 @@ def __init__(self, op_execution_context: OpExecutionContext) -> None:
         )
         self._step_execution_context = self._op_execution_context._step_execution_context  # noqa: SLF001

+        self._cached_upstream_partition_infos: Dict[AssetKey, PartitionInfo] = {}
+
     @staticmethod
     def get() -> "AssetExecutionContext":
         ctx = _current_asset_execution_context.get()
@@ -1493,6 +1551,70 @@ def latest_materialization_for_upstream_asset(
                 "in order to call latest_materialization_for_upstream_asset."
            )

+    @public
+    @property
+    def is_partitioned_materialization(self) -> bool:
+        return self.op_execution_context.has_partition_key
+
+    @cached_property
+    def partition_info(self) -> PartitionInfo:
+        """Returns a PartitionInfo describing the partition(s) of the currently materializing asset."""
+        partitions_def = self.assets_def.partitions_def
+        if self.is_partitioned_materialization and partitions_def:
+            key_range = self.op_execution_context.partition_key_range
+            return PartitionInfo(
+                key=self.op_execution_context.partition_key
+                if key_range.start == key_range.end
+                else None,
+                keys=self.op_execution_context.partition_keys,
+                key_range=key_range,
+                definition=partitions_def,
+                time_window=self.op_execution_context.partition_time_window
+                if has_one_dimension_time_window_partitioning(partitions_def)
+                else None,
+            )
+        raise DagsterInvariantViolationError(
+            "Cannot access information about the asset's partition in a non-partitioned run."
+        )
+
+    def upstream_partition_info(self, key: CoercibleToAssetKey) -> PartitionInfo:
+        """Returns a PartitionInfo describing the partition(s) of the upstream asset 'key' that
+        the partition of the currently materializing asset depends on.
+        """
+        asset_key = AssetKey.from_coercible(key)
+        if asset_key in self._cached_upstream_partition_infos:
+            return self._cached_upstream_partition_infos[asset_key]
+
+        partitions_def = self._step_execution_context.job_def.asset_layer.partitions_def_for_asset(
+            asset_key
+        )
+        if self.is_partitioned_materialization and partitions_def:
+            key_range = self._step_execution_context.asset_partition_key_range_for_upstream(
+                asset_key
+            )
+            info = PartitionInfo(
+                key=self._step_execution_context.asset_partition_key_for_upstream(asset_key)
+                if key_range.start == key_range.end
+                else None,
+                keys=list(
+                    self._step_execution_context.asset_partitions_subset_for_upstream(
+                        asset_key
+                    ).get_partition_keys()
+                ),
+                key_range=key_range,
+                definition=partitions_def,
+                time_window=self._step_execution_context.asset_partitions_time_window_for_upstream(
+                    asset_key
+                )
+                if has_one_dimension_time_window_partitioning(partitions_def)
+                else None,
+            )
+            self._cached_upstream_partition_infos[asset_key] = info
+            return info
+        raise DagsterInvariantViolationError(
+            "Cannot access information about the asset's partition in a non-partitioned run."
+ ) + ######## Deprecated methods @deprecated(**_get_deprecation_kwargs("dagster_run")) @@ -1534,6 +1656,110 @@ def get_tag(self, key: str) -> Optional[str]: @deprecated(**_get_deprecation_kwargs("get_op_execution_context")) def get_op_execution_context(self) -> "OpExecutionContext": return self.op_execution_context + + @deprecated(**_get_deprecation_kwargs("has_partition_key")) + @public + @property + @_copy_docs_from_op_execution_context + def has_partition_key(self) -> bool: + return self.op_execution_context.has_partition_key + + @deprecated(**_get_deprecation_kwargs("partition_key")) + @public + @property + @_copy_docs_from_op_execution_context + def partition_key(self) -> str: + return self.op_execution_context.partition_key + + @deprecated(**_get_deprecation_kwargs("partition_keys")) + @public + @property + @_copy_docs_from_op_execution_context + def partition_keys(self) -> Sequence[str]: + return self.op_execution_context.partition_keys + + @deprecated(breaking_version="2.0", additional_warn_text="Use `partition_key_range` instead.") + @public + @property + @_copy_docs_from_op_execution_context + def asset_partition_key_range(self) -> PartitionKeyRange: + return self.op_execution_context.asset_partition_key_range + + @deprecated(**_get_deprecation_kwargs("partition_key_range")) + @public + @property + @_copy_docs_from_op_execution_context + def partition_key_range(self) -> PartitionKeyRange: + return self.op_execution_context.partition_key_range + + @deprecated(**_get_deprecation_kwargs("partition_time_window")) + @public + @property + @_copy_docs_from_op_execution_context + def partition_time_window(self) -> TimeWindow: + return self.op_execution_context.partition_time_window + + @deprecated(**_get_deprecation_kwargs("asset_partition_key_for_output")) + @public + @_copy_docs_from_op_execution_context + def asset_partition_key_for_output(self, output_name: str = "result") -> str: + return self.op_execution_context.asset_partition_key_for_output(output_name=output_name) + + @deprecated(**_get_deprecation_kwargs("asset_partitions_time_window_for_output")) + @public + @_copy_docs_from_op_execution_context + def asset_partitions_time_window_for_output(self, output_name: str = "result") -> TimeWindow: + return self.op_execution_context.asset_partitions_time_window_for_output(output_name) + + @deprecated(**_get_deprecation_kwargs("asset_partition_key_range_for_output")) + @public + @_copy_docs_from_op_execution_context + def asset_partition_key_range_for_output( + self, output_name: str = "result" + ) -> PartitionKeyRange: + return self.op_execution_context.asset_partition_key_range_for_output(output_name) + + @deprecated(**_get_deprecation_kwargs("asset_partition_key_range_for_input")) + @public + @_copy_docs_from_op_execution_context + def asset_partition_key_range_for_input(self, input_name: str) -> PartitionKeyRange: + return self.op_execution_context.asset_partition_key_range_for_input(input_name) + + @deprecated(**_get_deprecation_kwargs("asset_partition_key_for_input")) + @public + @_copy_docs_from_op_execution_context + def asset_partition_key_for_input(self, input_name: str) -> str: + return self.op_execution_context.asset_partition_key_for_input(input_name) + + @deprecated(**_get_deprecation_kwargs("asset_partitions_def_for_output")) + @public + @_copy_docs_from_op_execution_context + def asset_partitions_def_for_output(self, output_name: str = "result") -> PartitionsDefinition: + return self.op_execution_context.asset_partitions_def_for_output(output_name=output_name) + + 
+    @deprecated(**_get_deprecation_kwargs("asset_partitions_def_for_input"))
+    @public
+    @_copy_docs_from_op_execution_context
+    def asset_partitions_def_for_input(self, input_name: str) -> PartitionsDefinition:
+        return self.op_execution_context.asset_partitions_def_for_input(input_name=input_name)
+
+    @deprecated(**_get_deprecation_kwargs("asset_partition_keys_for_output"))
+    @public
+    @_copy_docs_from_op_execution_context
+    def asset_partition_keys_for_output(self, output_name: str = "result") -> Sequence[str]:
+        return self.op_execution_context.asset_partition_keys_for_output(output_name=output_name)
+
+    @deprecated(**_get_deprecation_kwargs("asset_partition_keys_for_input"))
+    @public
+    @_copy_docs_from_op_execution_context
+    def asset_partition_keys_for_input(self, input_name: str) -> Sequence[str]:
+        return self.op_execution_context.asset_partition_keys_for_input(input_name=input_name)
+
+    @deprecated(**_get_deprecation_kwargs("asset_partitions_time_window_for_input"))
+    @public
+    @_copy_docs_from_op_execution_context
+    def asset_partitions_time_window_for_input(self, input_name: str = "result") -> TimeWindow:
+        return self.op_execution_context.asset_partitions_time_window_for_input(input_name)

     ########## pass-through to op context

@@ -1652,97 +1878,6 @@ def step_launcher(self) -> Optional[StepLauncher]:
     def get_step_execution_context(self) -> StepExecutionContext:
         return self.op_execution_context.get_step_execution_context()

-    #### partition_related
-
-    @public
-    @property
-    @_copy_docs_from_op_execution_context
-    def has_partition_key(self) -> bool:
-        return self.op_execution_context.has_partition_key
-
-    @public
-    @property
-    @_copy_docs_from_op_execution_context
-    def partition_key(self) -> str:
-        return self.op_execution_context.partition_key
-
-    @public
-    @property
-    @_copy_docs_from_op_execution_context
-    def partition_keys(self) -> Sequence[str]:
-        return self.op_execution_context.partition_keys
-
-    @deprecated(breaking_version="2.0", additional_warn_text="Use `partition_key_range` instead.")
-    @public
-    @property
-    @_copy_docs_from_op_execution_context
-    def asset_partition_key_range(self) -> PartitionKeyRange:
-        return self.op_execution_context.asset_partition_key_range
-
-    @public
-    @property
-    @_copy_docs_from_op_execution_context
-    def partition_key_range(self) -> PartitionKeyRange:
-        return self.op_execution_context.partition_key_range
-
-    @public
-    @property
-    @_copy_docs_from_op_execution_context
-    def partition_time_window(self) -> TimeWindow:
-        return self.op_execution_context.partition_time_window
-
-    @public
-    @_copy_docs_from_op_execution_context
-    def asset_partition_key_for_output(self, output_name: str = "result") -> str:
-        return self.op_execution_context.asset_partition_key_for_output(output_name=output_name)
-
-    @public
-    @_copy_docs_from_op_execution_context
-    def asset_partitions_time_window_for_output(self, output_name: str = "result") -> TimeWindow:
-        return self.op_execution_context.asset_partitions_time_window_for_output(output_name)
-
-    @public
-    @_copy_docs_from_op_execution_context
-    def asset_partition_key_range_for_output(
-        self, output_name: str = "result"
-    ) -> PartitionKeyRange:
-        return self.op_execution_context.asset_partition_key_range_for_output(output_name)
-
-    @public
-    @_copy_docs_from_op_execution_context
-    def asset_partition_key_range_for_input(self, input_name: str) -> PartitionKeyRange:
-        return self.op_execution_context.asset_partition_key_range_for_input(input_name)
-
-    @public
-    @_copy_docs_from_op_execution_context
-    def asset_partition_key_for_input(self, input_name: str) -> str:
-        return self.op_execution_context.asset_partition_key_for_input(input_name)
-
-    @public
-    @_copy_docs_from_op_execution_context
-    def asset_partitions_def_for_output(self, output_name: str = "result") -> PartitionsDefinition:
-        return self.op_execution_context.asset_partitions_def_for_output(output_name=output_name)
-
-    @public
-    @_copy_docs_from_op_execution_context
-    def asset_partitions_def_for_input(self, input_name: str) -> PartitionsDefinition:
-        return self.op_execution_context.asset_partitions_def_for_input(input_name=input_name)
-
-    @public
-    @_copy_docs_from_op_execution_context
-    def asset_partition_keys_for_output(self, output_name: str = "result") -> Sequence[str]:
-        return self.op_execution_context.asset_partition_keys_for_output(output_name=output_name)
-
-    @public
-    @_copy_docs_from_op_execution_context
-    def asset_partition_keys_for_input(self, input_name: str) -> Sequence[str]:
-        return self.op_execution_context.asset_partition_keys_for_input(input_name=input_name)
-
-    @public
-    @_copy_docs_from_op_execution_context
-    def asset_partitions_time_window_for_input(self, input_name: str = "result") -> TimeWindow:
-        return self.op_execution_context.asset_partitions_time_window_for_input(input_name)
-
     #### Event log related

     @_copy_docs_from_op_execution_context
diff --git a/python_modules/dagster/dagster_tests/asset_defs_tests/partition_mapping_tests/test_asset_partition_mappings.py b/python_modules/dagster/dagster_tests/asset_defs_tests/partition_mapping_tests/test_asset_partition_mappings.py
index 7c1a07f8916a7..81f9ecbc610e6 100644
--- a/python_modules/dagster/dagster_tests/asset_defs_tests/partition_mapping_tests/test_asset_partition_mappings.py
+++ b/python_modules/dagster/dagster_tests/asset_defs_tests/partition_mapping_tests/test_asset_partition_mappings.py
@@ -601,7 +601,7 @@ def upstream():
     )
     def downstream(context: AssetExecutionContext):
         upstream_key = datetime.strptime(
-            context.asset_partition_key_for_input("upstream"), "%Y-%m-%d"
+            context.upstream_partition_info("upstream").key, "%Y-%m-%d"
         )

         current_partition_key = datetime.strptime(context.partition_key, "%Y-%m-%d")
@@ -653,12 +653,8 @@ def multi_asset_1():

     @multi_asset(specs=[asset_3, asset_4], partitions_def=partitions_def)
     def multi_asset_2(context: AssetExecutionContext):
-        asset_1_key = datetime.strptime(
-            context.asset_partition_key_for_input("asset_1"), "%Y-%m-%d"
-        )
-        asset_2_key = datetime.strptime(
-            context.asset_partition_key_for_input("asset_2"), "%Y-%m-%d"
-        )
+        asset_1_key = datetime.strptime(context.upstream_partition_info("asset_1").key, "%Y-%m-%d")
+        asset_2_key = datetime.strptime(context.upstream_partition_info("asset_2").key, "%Y-%m-%d")

         current_partition_key = datetime.strptime(context.partition_key, "%Y-%m-%d")

@@ -760,7 +756,7 @@ def test_self_dependent_partition_mapping_with_asset_deps():
     )
     def self_dependent(context: AssetExecutionContext):
         upstream_key = datetime.strptime(
-            context.asset_partition_key_for_input("self_dependent"), "%Y-%m-%d"
+            context.upstream_partition_info("self_dependent").key, "%Y-%m-%d"
         )

         current_partition_key = datetime.strptime(context.partition_key, "%Y-%m-%d")
@@ -786,9 +782,7 @@ def self_dependent(context: AssetExecutionContext):

     @multi_asset(specs=[asset_1], partitions_def=partitions_def)
     def the_multi_asset(context: AssetExecutionContext):
-        asset_1_key = datetime.strptime(
-            context.asset_partition_key_for_input("asset_1"), "%Y-%m-%d"
-        )
+        asset_1_key = datetime.strptime(context.upstream_partition_info("asset_1").key, "%Y-%m-%d")

         current_partition_key = datetime.strptime(context.partition_key, "%Y-%m-%d")

@@ -810,7 +804,7 @@ def upstream():
         deps=[AssetDep(upstream, partition_mapping=SpecificPartitionsPartitionMapping(["apple"]))],
     )
     def downstream(context: AssetExecutionContext):
-        assert context.asset_partition_key_for_input("upstream") == "apple"
+        assert context.upstream_partition_info("upstream").key == "apple"
         assert context.partition_key == "orange"

     with instance_for_test() as instance:
@@ -840,7 +834,7 @@ def asset_1_multi_asset():

     @multi_asset(specs=[asset_2], partitions_def=partitions_def)
     def asset_2_multi_asset(context: AssetExecutionContext):
-        assert context.asset_partition_key_for_input("asset_1") == "apple"
+        assert context.upstream_partition_info("asset_1").key == "apple"
         assert context.partition_key == "orange"

     with instance_for_test() as instance:
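
Usage sketch (not part of the diff): a minimal example of how partitioned assets might use the partition_info and upstream_partition_info API introduced above, in the style of the example assets in this change. It assumes a build with this branch applied; the asset names daily_events and daily_event_counts and the daily partitions definition are hypothetical.

import pandas as pd

from dagster import AssetExecutionContext, DailyPartitionsDefinition, asset

daily_partitions = DailyPartitionsDefinition(start_date="2024-01-01")


@asset(partitions_def=daily_partitions)
def daily_events(context: AssetExecutionContext) -> pd.DataFrame:
    # Hypothetical upstream asset. For a single-partition run, partition_info.key
    # returns the "%Y-%m-%d" key; it raises when the run targets a range of
    # partitions (use partition_info.key_range or partition_info.keys instead).
    day = context.partition_info.key
    return pd.DataFrame({"day": [day], "events": [0]})


@asset(partitions_def=daily_partitions, deps=[daily_events])
def daily_event_counts(context: AssetExecutionContext) -> pd.DataFrame:
    # time_window is only defined for time-window partition definitions.
    start, end = context.partition_info.time_window
    # upstream_partition_info resolves the upstream partition(s) this partition depends on.
    upstream_keys = context.upstream_partition_info("daily_events").keys
    return pd.DataFrame({"start": [start], "end": [end], "n_upstream": [len(upstream_keys)]})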