From 532ca5f26c42a5dda0e1abd2b643df1f0c4c72c8 Mon Sep 17 00:00:00 2001 From: JamieDeMaria Date: Thu, 18 Jan 2024 12:55:10 -0500 Subject: [PATCH] partition api 4 --- .../partition_example/partition_example.py | 10 +- .../_core/execution/context/compute.py | 288 +++++++++++------- .../test_asset_partition_mappings.py | 24 +- 3 files changed, 191 insertions(+), 131 deletions(-) diff --git a/examples/partition_example/partition_example.py b/examples/partition_example/partition_example.py index 5fc8e80c2499c..2b1b2f4f43ea8 100644 --- a/examples/partition_example/partition_example.py +++ b/examples/partition_example/partition_example.py @@ -33,7 +33,7 @@ def relativedelta(*args, **kwargs): metadata={"partition_expr": "LastModifiedDate"}, ) def salesforce_customers(context: AssetExecutionContext) -> pd.DataFrame: - start_date_str = context.asset_partition_key_for_output() + start_date_str = context.partition_key timezone = pytz.timezone("GMT") # Replace 'Your_Timezone' with the desired timezone start_obj = datetime.datetime.strptime(start_date_str, "%Y-%m-%d").replace(tzinfo=timezone) @@ -65,7 +65,7 @@ def realized_vol(context: AssetExecutionContext, orats_daily_prices: pd.DataFram The volatility is calculated using various methods such as close-to-close, Parkinson, Hodges-Tompkins, and Yang-Zhang. The function returns a DataFrame with the calculated volatilities. """ - trade_date = context.asset_partition_key_for_output() + trade_date = context.partition_key ticker_id = 1 df = all_realvols(orats_daily_prices, ticker_id, trade_date) @@ -80,7 +80,7 @@ def realized_vol(context: AssetExecutionContext, orats_daily_prices: pd.DataFram @asset(io_manager_def="parquet_io_manager", partitions_def=hourly_partitions) def my_custom_df(context: AssetExecutionContext) -> pd.DataFrame: - start, end = context.asset_partitions_time_window_for_output() + start, end = context.partition_time_window df = pd.DataFrame({"timestamp": pd.date_range(start, end, freq="5T")}) df["count"] = df["timestamp"].map(lambda a: random.randint(1, 1000)) @@ -93,7 +93,7 @@ def fetch_blog_posts_from_external_api(*args, **kwargs): @asset(partitions_def=HourlyPartitionsDefinition(start_date="2022-01-01-00:00")) def blog_posts(context: AssetExecutionContext) -> List[Dict]: - partition_datetime_str = context.asset_partition_key_for_output() + partition_datetime_str = context.partition_key hour = datetime.datetime.fromisoformat(partition_datetime_str) posts = fetch_blog_posts_from_external_api(hour_when_posted=hour) return posts @@ -106,7 +106,7 @@ def blog_posts(context: AssetExecutionContext) -> List[Dict]: key_prefix=["snowflake", "eldermark_proxy"], ) def resident(context: AssetExecutionContext) -> Output[pd.DataFrame]: - start, end = context.asset_partitions_time_window_for_output() + start, end = context.partition_time_window filter_str = f"LastMod_Stamp >= {start.timestamp()} AND LastMod_Stamp < {end.timestamp()}" records = context.resources.eldermark.fetch_obj(obj="Resident", filter=filter_str) diff --git a/python_modules/dagster/dagster/_core/execution/context/compute.py b/python_modules/dagster/dagster/_core/execution/context/compute.py index 792d061509c3b..aaf7a5029cc98 100644 --- a/python_modules/dagster/dagster/_core/execution/context/compute.py +++ b/python_modules/dagster/dagster/_core/execution/context/compute.py @@ -1375,6 +1375,17 @@ def _copy_docs_from_op_execution_context(obj): "run_config": "run.run_config", "run_tags": "run.tags", "get_op_execution_context": "op_execution_context", + "asset_partition_key_for_output": "partition_key", + "asset_partitions_time_window_for_output": "partition_time_window", + "asset_partition_key_range_for_output": "partition_key_range", + "asset_partition_key_range_for_input": "dep_context(asset_key).partition_key_range", + "asset_partition_key_for_input": "dep_context(asset_key).partition_key", + "asset_partitions_def_for_output": "assets_def.partitions_def", + "asset_partitions_def_for_input": "dep_context(asset_key).partitions_def", + "asset_partition_keys_for_output": "partition_keys", + "asset_partition_keys_for_input": "dep_context(asset_key).partition_keys", + "asset_partitions_time_window_for_input": "dep_context(asset_key).partition_time_window", + "has_partition_key": "is_partitioned_materialization", } ALTERNATE_EXPRESSIONS = { @@ -1401,6 +1412,61 @@ def _get_deprecation_kwargs(attr: str): return deprecation_kwargs +class DepContext: + def __init__( + self, asset_key: CoercibleToAssetKey, step_execution_context: StepExecutionContext + ): + self._key = AssetKey.from_coercible(asset_key) + self._step_execution_context = step_execution_context + + @property + def key(self) -> AssetKey: + return self._key + + @property + def latest_materialization(self) -> Optional[AssetMaterialization]: + materialization_events = self._step_execution_context.upstream_asset_materialization_events + if self.key in materialization_events.keys(): + return materialization_events.get(self.key) + + raise DagsterInvariantViolationError( + f"Cannot fetch AssetMaterialization for asset {self.key}. {self.key} must be an upstream dependency" + "in order to call latest_materialization." + ) + + @property + def partition_key(self) -> str: + return self._step_execution_context.asset_partition_key_for_upstream_asset(self.key) + + @property + def partition_keys(self) -> Sequence[str]: + return list( + self._step_execution_context.asset_partitions_subset_for_upstream_asset( + self.key + ).get_partition_keys() + ) + + @property + def partition_time_window(self) -> TimeWindow: + return self._step_execution_context.asset_partitions_time_window_for_upstream_asset( + self.key + ) + + @property + def partition_key_range(self) -> PartitionKeyRange: + return self._step_execution_context.asset_partition_key_range_for_upstream_asset(self.key) + + @property + def partitions_def(self) -> PartitionsDefinition: + result = self._step_execution_context.job_def.asset_layer.partitions_def_for_asset(self.key) + if result is None: + raise DagsterInvariantViolationError( + f"Attempting to access partitions def for asset {self.key}, but it is not partitioned" + ) + + return result + + class AssetExecutionContext(OpExecutionContext): def __init__(self, op_execution_context: OpExecutionContext) -> None: self._op_execution_context = check.inst_param( @@ -1471,30 +1537,55 @@ def job_def(self) -> JobDefinition: return self.op_execution_context.job_def @public - def latest_materialization_for_upstream_asset( - self, key: CoercibleToAssetKey - ) -> Optional[AssetMaterialization]: - """Get the most recent AssetMaterialization event for the key. The key must be an upstream - asset for the currently materializing asset. Information like metadata and tags can be found - on the AssetMaterialization. If the key is not an upstream asset of the currently - materializing asset, an error will be raised. If no AssetMaterialization exists for key, None - will be returned. + def dep_context(self, asset_key: CoercibleToAssetKey): + if ( + self.job_def.asset_layer.input_for_asset_key( + self.op_execution_context.node_handle, AssetKey.from_coercible(asset_key) + ) + is not None + ): + return DepContext( + asset_key=asset_key, step_execution_context=self._step_execution_context + ) + else: + raise DagsterInvariantViolationError( + f"Cannot access DepContext for asset {asset_key} since it is not an upstream dependency of {self.asset_key}." + ) - Returns: Optional[AssetMaterialization] - """ - materialization_events = ( - self.op_execution_context._step_execution_context.upstream_asset_materialization_events # noqa: SLF001 - ) - if AssetKey.from_coercible(key) in materialization_events.keys(): - return materialization_events.get(AssetKey.from_coercible(key)) + @public + @property + def is_partitioned_materialization(self) -> bool: + return self.op_execution_context.has_partition_key - raise DagsterInvariantViolationError( - f"Cannot fetch AssetMaterialization for asset {key}. {key} must be an upstream dependency" - "in order to call latest_materialization_for_upstream_asset." - ) + @public + @property + def partition_key(self) -> str: + return self.op_execution_context.partition_key + + @public + @property + def partition_keys(self) -> Sequence[str]: + return self.op_execution_context.partition_keys + + @public + @property + def partition_key_range(self) -> PartitionKeyRange: + return self.op_execution_context.partition_key_range + + @public + @property + def partition_time_window(self) -> TimeWindow: + return self.op_execution_context.partition_time_window ######## Deprecated methods + @deprecated(**_get_deprecation_kwargs("has_partition_key")) + @public + @property + @_copy_docs_from_op_execution_context + def has_partition_key(self) -> bool: + return self.op_execution_context.has_partition_key + @deprecated(**_get_deprecation_kwargs("dagster_run")) @property @_copy_docs_from_op_execution_context @@ -1534,6 +1625,74 @@ def get_tag(self, key: str) -> Optional[str]: @deprecated(**_get_deprecation_kwargs("get_op_execution_context")) def get_op_execution_context(self) -> "OpExecutionContext": return self.op_execution_context + @deprecated(breaking_version="2.0", additional_warn_text="Use `partition_key_range` instead.") + @public + @property + @_copy_docs_from_op_execution_context + def asset_partition_key_range(self) -> PartitionKeyRange: + return self.op_execution_context.asset_partition_key_range + + @deprecated(**_get_deprecation_kwargs("asset_partition_key_for_output")) + @public + @_copy_docs_from_op_execution_context + def asset_partition_key_for_output(self, output_name: str = "result") -> str: + return self.op_execution_context.asset_partition_key_for_output(output_name=output_name) + + @deprecated(**_get_deprecation_kwargs("asset_partitions_time_window_for_output")) + @public + @_copy_docs_from_op_execution_context + def asset_partitions_time_window_for_output(self, output_name: str = "result") -> TimeWindow: + return self.op_execution_context.asset_partitions_time_window_for_output(output_name) + + @deprecated(**_get_deprecation_kwargs("asset_partition_key_range_for_output")) + @public + @_copy_docs_from_op_execution_context + def asset_partition_key_range_for_output( + self, output_name: str = "result" + ) -> PartitionKeyRange: + return self.op_execution_context.asset_partition_key_range_for_output(output_name) + + @deprecated(**_get_deprecation_kwargs("asset_partition_key_range_for_input")) + @public + @_copy_docs_from_op_execution_context + def asset_partition_key_range_for_input(self, input_name: str) -> PartitionKeyRange: + return self.op_execution_context.asset_partition_key_range_for_input(input_name) + + @deprecated(**_get_deprecation_kwargs("asset_partition_key_for_input")) + @public + @_copy_docs_from_op_execution_context + def asset_partition_key_for_input(self, input_name: str) -> str: + return self.op_execution_context.asset_partition_key_for_input(input_name) + + @deprecated(**_get_deprecation_kwargs("asset_partitions_def_for_output")) + @public + @_copy_docs_from_op_execution_context + def asset_partitions_def_for_output(self, output_name: str = "result") -> PartitionsDefinition: + return self.op_execution_context.asset_partitions_def_for_output(output_name=output_name) + + @deprecated(**_get_deprecation_kwargs("asset_partitions_def_for_input")) + @public + @_copy_docs_from_op_execution_context + def asset_partitions_def_for_input(self, input_name: str) -> PartitionsDefinition: + return self.op_execution_context.asset_partitions_def_for_input(input_name=input_name) + + @deprecated(**_get_deprecation_kwargs("asset_partition_keys_for_output")) + @public + @_copy_docs_from_op_execution_context + def asset_partition_keys_for_output(self, output_name: str = "result") -> Sequence[str]: + return self.op_execution_context.asset_partition_keys_for_output(output_name=output_name) + + @deprecated(**_get_deprecation_kwargs("asset_partition_keys_for_input")) + @public + @_copy_docs_from_op_execution_context + def asset_partition_keys_for_input(self, input_name: str) -> Sequence[str]: + return self.op_execution_context.asset_partition_keys_for_input(input_name=input_name) + + @deprecated(**_get_deprecation_kwargs("asset_partitions_time_window_for_input")) + @public + @_copy_docs_from_op_execution_context + def asset_partitions_time_window_for_input(self, input_name: str = "result") -> TimeWindow: + return self.op_execution_context.asset_partitions_time_window_for_input(input_name) ########## pass-through to op context @@ -1652,97 +1811,6 @@ def step_launcher(self) -> Optional[StepLauncher]: def get_step_execution_context(self) -> StepExecutionContext: return self.op_execution_context.get_step_execution_context() - #### partition_related - - @public - @property - @_copy_docs_from_op_execution_context - def has_partition_key(self) -> bool: - return self.op_execution_context.has_partition_key - - @public - @property - @_copy_docs_from_op_execution_context - def partition_key(self) -> str: - return self.op_execution_context.partition_key - - @public - @property - @_copy_docs_from_op_execution_context - def partition_keys(self) -> Sequence[str]: - return self.op_execution_context.partition_keys - - @deprecated(breaking_version="2.0", additional_warn_text="Use `partition_key_range` instead.") - @public - @property - @_copy_docs_from_op_execution_context - def asset_partition_key_range(self) -> PartitionKeyRange: - return self.op_execution_context.asset_partition_key_range - - @public - @property - @_copy_docs_from_op_execution_context - def partition_key_range(self) -> PartitionKeyRange: - return self.op_execution_context.partition_key_range - - @public - @property - @_copy_docs_from_op_execution_context - def partition_time_window(self) -> TimeWindow: - return self.op_execution_context.partition_time_window - - @public - @_copy_docs_from_op_execution_context - def asset_partition_key_for_output(self, output_name: str = "result") -> str: - return self.op_execution_context.asset_partition_key_for_output(output_name=output_name) - - @public - @_copy_docs_from_op_execution_context - def asset_partitions_time_window_for_output(self, output_name: str = "result") -> TimeWindow: - return self.op_execution_context.asset_partitions_time_window_for_output(output_name) - - @public - @_copy_docs_from_op_execution_context - def asset_partition_key_range_for_output( - self, output_name: str = "result" - ) -> PartitionKeyRange: - return self.op_execution_context.asset_partition_key_range_for_output(output_name) - - @public - @_copy_docs_from_op_execution_context - def asset_partition_key_range_for_input(self, input_name: str) -> PartitionKeyRange: - return self.op_execution_context.asset_partition_key_range_for_input(input_name) - - @public - @_copy_docs_from_op_execution_context - def asset_partition_key_for_input(self, input_name: str) -> str: - return self.op_execution_context.asset_partition_key_for_input(input_name) - - @public - @_copy_docs_from_op_execution_context - def asset_partitions_def_for_output(self, output_name: str = "result") -> PartitionsDefinition: - return self.op_execution_context.asset_partitions_def_for_output(output_name=output_name) - - @public - @_copy_docs_from_op_execution_context - def asset_partitions_def_for_input(self, input_name: str) -> PartitionsDefinition: - return self.op_execution_context.asset_partitions_def_for_input(input_name=input_name) - - @public - @_copy_docs_from_op_execution_context - def asset_partition_keys_for_output(self, output_name: str = "result") -> Sequence[str]: - return self.op_execution_context.asset_partition_keys_for_output(output_name=output_name) - - @public - @_copy_docs_from_op_execution_context - def asset_partition_keys_for_input(self, input_name: str) -> Sequence[str]: - return self.op_execution_context.asset_partition_keys_for_input(input_name=input_name) - - @public - @_copy_docs_from_op_execution_context - def asset_partitions_time_window_for_input(self, input_name: str = "result") -> TimeWindow: - return self.op_execution_context.asset_partitions_time_window_for_input(input_name) - #### Event log related @_copy_docs_from_op_execution_context diff --git a/python_modules/dagster/dagster_tests/asset_defs_tests/partition_mapping_tests/test_asset_partition_mappings.py b/python_modules/dagster/dagster_tests/asset_defs_tests/partition_mapping_tests/test_asset_partition_mappings.py index 7c1a07f8916a7..020d0db118d7a 100644 --- a/python_modules/dagster/dagster_tests/asset_defs_tests/partition_mapping_tests/test_asset_partition_mappings.py +++ b/python_modules/dagster/dagster_tests/asset_defs_tests/partition_mapping_tests/test_asset_partition_mappings.py @@ -472,15 +472,13 @@ def test_multipartitions_def_partition_mapping_infer_single_dim_to_multi(): @asset(partitions_def=abc_def) def upstream(context): - assert context.asset_partition_keys_for_output("result") == ["a"] + assert context.partition_keys == ["a"] return 1 @asset(partitions_def=composite) def downstream(context, upstream): - assert context.asset_partition_keys_for_input("upstream") == ["a"] - assert context.asset_partition_keys_for_output("result") == [ - MultiPartitionKey({"abc": "a", "123": "1"}) - ] + assert context.dep_context("upstream").partition_keys == ["a"] + assert context.partition_keys == [MultiPartitionKey({"abc": "a", "123": "1"})] return 1 asset_graph = AssetGraph.from_assets([upstream, downstream]) @@ -534,8 +532,8 @@ def upstream(context): @asset(partitions_def=abc_def) def downstream(context, upstream): - assert set(context.asset_partition_keys_for_input("upstream")) == a_multipartition_keys - assert context.asset_partition_keys_for_output("result") == ["a"] + assert set(context.dep_context("upstream").partition_keys) == a_multipartition_keys + assert context.partition_keys == ["a"] return 1 asset_graph = AssetGraph.from_assets([upstream, downstream]) @@ -600,9 +598,7 @@ def upstream(): ], ) def downstream(context: AssetExecutionContext): - upstream_key = datetime.strptime( - context.asset_partition_key_for_input("upstream"), "%Y-%m-%d" - ) + upstream_key = datetime.strptime(context.dep_context("upstream").partition_key, "%Y-%m-%d") current_partition_key = datetime.strptime(context.partition_key, "%Y-%m-%d") @@ -653,12 +649,8 @@ def multi_asset_1(): @multi_asset(specs=[asset_3, asset_4], partitions_def=partitions_def) def multi_asset_2(context: AssetExecutionContext): - asset_1_key = datetime.strptime( - context.asset_partition_key_for_input("asset_1"), "%Y-%m-%d" - ) - asset_2_key = datetime.strptime( - context.asset_partition_key_for_input("asset_2"), "%Y-%m-%d" - ) + asset_1_key = datetime.strptime(context.dep_context("asset_1").partition_key, "%Y-%m-%d") + asset_2_key = datetime.strptime(context.dep_context("asset_2").partition_key, "%Y-%m-%d") current_partition_key = datetime.strptime(context.partition_key, "%Y-%m-%d")