Skip to content

Commit

Permalink
deprecate asset_partition_*_for_output on AssetExecutionContext (#19436)
Browse files Browse the repository at this point in the history
  • Loading branch information
jamiedemaria authored Jan 31, 2024
1 parent 2c70f4d commit b2b732b
Show file tree
Hide file tree
Showing 21 changed files with 208 additions and 178 deletions.
64 changes: 37 additions & 27 deletions python_modules/dagster/dagster/_core/execution/context/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -1364,6 +1364,11 @@ def _copy_docs_from_op_execution_context(obj):
"run_config": "run.run_config",
"run_tags": "run.tags",
"get_op_execution_context": "op_execution_context",
"asset_partition_key_for_output": "partition_key",
"asset_partitions_time_window_for_output": "partition_time_window",
"asset_partition_key_range_for_output": "partition_key_range",
"asset_partitions_def_for_output": "assets_def.partitions_def",
"asset_partition_keys_for_output": "partition_keys",
}

ALTERNATE_EXPRESSIONS = {
Expand Down Expand Up @@ -1501,6 +1506,38 @@ def get_tag(self, key: str) -> Optional[str]:
def get_op_execution_context(self) -> "OpExecutionContext":
return self.op_execution_context

@deprecated(**_get_deprecation_kwargs("asset_partition_key_for_output"))
@public
@_copy_docs_from_op_execution_context
def asset_partition_key_for_output(self, output_name: str = "result") -> str:
return self.op_execution_context.asset_partition_key_for_output(output_name=output_name)

@deprecated(**_get_deprecation_kwargs("asset_partitions_time_window_for_output"))
@public
@_copy_docs_from_op_execution_context
def asset_partitions_time_window_for_output(self, output_name: str = "result") -> TimeWindow:
return self.op_execution_context.asset_partitions_time_window_for_output(output_name)

@deprecated(**_get_deprecation_kwargs("asset_partition_key_range_for_output"))
@public
@_copy_docs_from_op_execution_context
def asset_partition_key_range_for_output(
self, output_name: str = "result"
) -> PartitionKeyRange:
return self.op_execution_context.asset_partition_key_range_for_output(output_name)

@deprecated(**_get_deprecation_kwargs("asset_partitions_def_for_output"))
@public
@_copy_docs_from_op_execution_context
def asset_partitions_def_for_output(self, output_name: str = "result") -> PartitionsDefinition:
return self.op_execution_context.asset_partitions_def_for_output(output_name=output_name)

@deprecated(**_get_deprecation_kwargs("asset_partition_keys_for_output"))
@public
@_copy_docs_from_op_execution_context
def asset_partition_keys_for_output(self, output_name: str = "result") -> Sequence[str]:
return self.op_execution_context.asset_partition_keys_for_output(output_name=output_name)

########## pass-through to op context

#### op related
Expand Down Expand Up @@ -1657,23 +1694,6 @@ def partition_key_range(self) -> PartitionKeyRange:
def partition_time_window(self) -> TimeWindow:
return self.op_execution_context.partition_time_window

@public
@_copy_docs_from_op_execution_context
def asset_partition_key_for_output(self, output_name: str = "result") -> str:
return self.op_execution_context.asset_partition_key_for_output(output_name=output_name)

@public
@_copy_docs_from_op_execution_context
def asset_partitions_time_window_for_output(self, output_name: str = "result") -> TimeWindow:
return self.op_execution_context.asset_partitions_time_window_for_output(output_name)

@public
@_copy_docs_from_op_execution_context
def asset_partition_key_range_for_output(
self, output_name: str = "result"
) -> PartitionKeyRange:
return self.op_execution_context.asset_partition_key_range_for_output(output_name)

@public
@_copy_docs_from_op_execution_context
def asset_partition_key_range_for_input(self, input_name: str) -> PartitionKeyRange:
Expand All @@ -1684,21 +1704,11 @@ def asset_partition_key_range_for_input(self, input_name: str) -> PartitionKeyRa
def asset_partition_key_for_input(self, input_name: str) -> str:
return self.op_execution_context.asset_partition_key_for_input(input_name)

@public
@_copy_docs_from_op_execution_context
def asset_partitions_def_for_output(self, output_name: str = "result") -> PartitionsDefinition:
return self.op_execution_context.asset_partitions_def_for_output(output_name=output_name)

@public
@_copy_docs_from_op_execution_context
def asset_partitions_def_for_input(self, input_name: str) -> PartitionsDefinition:
return self.op_execution_context.asset_partitions_def_for_input(input_name=input_name)

@public
@_copy_docs_from_op_execution_context
def asset_partition_keys_for_output(self, output_name: str = "result") -> Sequence[str]:
return self.op_execution_context.asset_partition_keys_for_output(output_name=output_name)

@public
@_copy_docs_from_op_execution_context
def asset_partition_keys_for_input(self, input_name: str) -> Sequence[str]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,10 @@ def asset_partitions_time_window_for_output(self, output_name: str = "result") -
Union[MultiPartitionsDefinition, TimeWindowPartitionsDefinition], partitions_def
).time_window_for_partition_key(self.partition_key)

@property
def partition_time_window(self) -> TimeWindow:
return self.asset_partitions_time_window_for_output()

def add_output_metadata(
self,
metadata: Mapping[str, Any],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,15 +105,15 @@ def load_input(self, context):
assert context.asset_partitions_def == upstream_partitions_def

@asset(partitions_def=upstream_partitions_def)
def upstream_asset(context):
assert context.asset_partition_key_for_output() == "2"
def upstream_asset(context: AssetExecutionContext):
assert context.partition_key == "2"

@asset(
partitions_def=downstream_partitions_def,
ins={"upstream_asset": AssetIn(partition_mapping=TrailingWindowPartitionMapping())},
)
def downstream_asset(context, upstream_asset):
assert context.asset_partition_key_for_output() == "2"
def downstream_asset(context: AssetExecutionContext, upstream_asset):
assert context.partition_key == "2"
assert upstream_asset is None
assert context.asset_partitions_def_for_input("upstream_asset") == upstream_partitions_def

Expand Down Expand Up @@ -341,9 +341,8 @@ def test_partition_keys_in_range():
]

@asset(partitions_def=DailyPartitionsDefinition(start_date="2022-09-11"))
def upstream(context):
assert context.asset_partition_keys_for_output("result") == ["2022-09-11"]
assert context.asset_partition_keys_for_output() == ["2022-09-11"]
def upstream(context: AssetExecutionContext):
assert context.partition_keys == ["2022-09-11"]

@asset(partitions_def=WeeklyPartitionsDefinition(start_date="2022-09-11"))
def downstream(context, upstream):
Expand Down Expand Up @@ -383,8 +382,8 @@ def test_dependency_resolution_partition_mapping():
partitions_def=DailyPartitionsDefinition(start_date="2020-01-01"),
key_prefix=["staging"],
)
def upstream(context):
partition_date_str = context.asset_partition_key_for_output()
def upstream(context: AssetExecutionContext):
partition_date_str = context.partition_key
return partition_date_str

@asset(
Expand Down Expand Up @@ -441,11 +440,8 @@ def upstream(context):
return 1

@asset(partitions_def=composite)
def downstream(context, upstream):
assert (
context.asset_partition_keys_for_input("upstream")
== context.asset_partition_keys_for_output()
)
def downstream(context: AssetExecutionContext, upstream):
assert context.asset_partition_keys_for_input("upstream") == context.partition_keys
return 1

asset_graph = AssetGraph.from_assets([upstream, downstream])
Expand All @@ -471,16 +467,14 @@ def test_multipartitions_def_partition_mapping_infer_single_dim_to_multi():
)

@asset(partitions_def=abc_def)
def upstream(context):
assert context.asset_partition_keys_for_output("result") == ["a"]
def upstream(context: AssetExecutionContext):
assert context.partition_keys == ["a"]
return 1

@asset(partitions_def=composite)
def downstream(context, upstream):
def downstream(context: AssetExecutionContext, upstream):
assert context.asset_partition_keys_for_input("upstream") == ["a"]
assert context.asset_partition_keys_for_output("result") == [
MultiPartitionKey({"abc": "a", "123": "1"})
]
assert context.partition_keys == [MultiPartitionKey({"abc": "a", "123": "1"})]
return 1

asset_graph = AssetGraph.from_assets([upstream, downstream])
Expand Down Expand Up @@ -533,9 +527,9 @@ def upstream(context):
return 1

@asset(partitions_def=abc_def)
def downstream(context, upstream):
def downstream(context: AssetExecutionContext, upstream):
assert set(context.asset_partition_keys_for_input("upstream")) == a_multipartition_keys
assert context.asset_partition_keys_for_output("result") == ["a"]
assert context.partition_keys == ["a"]
return 1

asset_graph = AssetGraph.from_assets([upstream, downstream])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import pytest
from dagster import (
AssetExecutionContext,
AssetKey,
AssetOut,
AssetsDefinition,
Expand Down Expand Up @@ -278,8 +279,8 @@ def the_asset(context):

def test_materialize_partition_key():
@asset(partitions_def=DailyPartitionsDefinition(start_date="2022-01-01"))
def the_asset(context):
assert context.asset_partition_key_for_output() == "2022-02-02"
def the_asset(context: AssetExecutionContext):
assert context.partition_key == "2022-02-02"

with instance_for_test() as instance:
result = materialize([the_asset], partition_key="2022-02-02", instance=instance)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,8 @@ def multi_asset_with_internal_deps(thing):

def test_materialize_to_memory_partition_key():
@asset(partitions_def=DailyPartitionsDefinition(start_date="2022-01-01"))
def the_asset(context):
assert context.asset_partition_key_for_output() == "2022-02-02"
def the_asset(context: AssetExecutionContext):
assert context.partition_key == "2022-02-02"

result = materialize_to_memory([the_asset], partition_key="2022-02-02")
assert result.success
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pendulum
import pytest
from dagster import (
AssetExecutionContext,
AssetMaterialization,
AssetOut,
AssetsDefinition,
Expand Down Expand Up @@ -155,8 +156,8 @@ def load_input(self, context):
assert False, "shouldn't get here"

@asset(partitions_def=partitions_def)
def my_asset(context):
assert context.asset_partitions_def_for_output() == partitions_def
def my_asset(context: AssetExecutionContext):
assert context.assets_def.partitions_def == partitions_def

my_job = build_assets_job(
"my_job",
Expand Down Expand Up @@ -212,24 +213,24 @@ def test_access_partition_keys_from_context_direct_invocation():
partitions_def = StaticPartitionsDefinition(["a"])

@asset(partitions_def=partitions_def)
def partitioned_asset(context):
assert context.asset_partition_key_for_output() == "a"
def partitioned_asset(context: AssetExecutionContext):
assert context.partition_key == "a"

context = build_asset_context(partition_key="a")

# check unbound context
assert context.asset_partition_key_for_output() == "a"
assert context.partition_key == "a"

# check bound context
partitioned_asset(context)

# check failure for non-partitioned asset
@asset
def non_partitioned_asset(context):
def non_partitioned_asset(context: AssetExecutionContext):
with pytest.raises(
CheckError, match="Tried to access partition_key for a non-partitioned run"
):
context.asset_partition_key_for_output()
_ = context.partition_key

context = build_asset_context()
non_partitioned_asset(context)
Expand Down Expand Up @@ -257,8 +258,8 @@ def load_input(self, context):
assert context.asset_partition_key_range == PartitionKeyRange("a", "c")

@asset(partitions_def=upstream_partitions_def)
def upstream_asset(context):
assert context.asset_partition_key_for_output() == "b"
def upstream_asset(context: AssetExecutionContext):
assert context.partition_key == "b"

@asset
def downstream_asset(upstream_asset):
Expand Down Expand Up @@ -606,7 +607,6 @@ def test_partition_range_single_run():
@asset(partitions_def=partitions_def)
def upstream_asset(context) -> None:
key_range = PartitionKeyRange(start="2020-01-01", end="2020-01-03")
assert context.asset_partition_key_range_for_output() == key_range
assert context.partition_key_range == key_range
assert context.partition_time_window == TimeWindow(
partitions_def.time_window_for_partition_key(key_range.start).start,
Expand All @@ -615,11 +615,11 @@ def upstream_asset(context) -> None:
assert context.partition_keys == partitions_def.get_partition_keys_in_range(key_range)

@asset(partitions_def=partitions_def, deps=["upstream_asset"])
def downstream_asset(context) -> None:
def downstream_asset(context: AssetExecutionContext) -> None:
assert context.asset_partition_key_range_for_input("upstream_asset") == PartitionKeyRange(
start="2020-01-01", end="2020-01-03"
)
assert context.asset_partition_key_range_for_output() == PartitionKeyRange(
assert context.partition_key_range == PartitionKeyRange(
start="2020-01-01", end="2020-01-03"
)

Expand Down Expand Up @@ -653,17 +653,15 @@ def test_multipartition_range_single_run():
)

@asset(partitions_def=partitions_def)
def multipartitioned_asset(context) -> None:
key_range = context.asset_partition_key_range_for_output()
def multipartitioned_asset(context: AssetExecutionContext) -> None:
key_range = context.partition_key_range

assert isinstance(key_range.start, MultiPartitionKey)
assert isinstance(key_range.end, MultiPartitionKey)
assert key_range.start == MultiPartitionKey({"date": "2020-01-01", "abc": "a"})
assert key_range.end == MultiPartitionKey({"date": "2020-01-03", "abc": "a"})

assert all(
isinstance(key, MultiPartitionKey) for key in context.asset_partition_keys_for_output()
)
assert all(isinstance(key, MultiPartitionKey) for key in context.partition_keys)

the_job = define_asset_job("job").resolve(
asset_graph=AssetGraph.from_assets([multipartitioned_asset])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,11 @@ def test_deprecation_warnings():
"asset_key_for_input",
"asset_key_for_output",
"asset_partition_key_for_input",
"asset_partition_key_for_output",
"asset_partition_key_range",
"asset_partition_key_range_for_input",
"asset_partition_key_range_for_output",
"asset_partition_keys_for_input",
"asset_partition_keys_for_output",
"asset_partitions_def_for_input",
"asset_partitions_def_for_output",
"asset_partitions_time_window_for_input",
"asset_partitions_time_window_for_output",
"assets_def",
"get_output_metadata",
"has_asset_checks_def",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1259,8 +1259,8 @@ def test_partitions_time_window_asset_invocation():
@asset(
partitions_def=partitions_def,
)
def partitioned_asset(context):
start, end = context.asset_partitions_time_window_for_output()
def partitioned_asset(context: AssetExecutionContext):
start, end = context.partition_time_window
assert start == pendulum.instance(datetime(2023, 2, 2), tz=partitions_def.timezone)
assert end == pendulum.instance(datetime(2023, 2, 3), tz=partitions_def.timezone)

Expand All @@ -1279,18 +1279,21 @@ def test_multipartitioned_time_window_asset_invocation():
)

@asset(partitions_def=partitions_def)
def my_asset(context):
def my_asset(context: AssetExecutionContext):
time_partition = get_time_partitions_def(partitions_def)
if time_partition is None:
assert False, "partitions def does not have a time component"
time_window = TimeWindow(
start=pendulum.instance(
datetime(year=2020, month=1, day=1),
tz=get_time_partitions_def(partitions_def).timezone,
tz=time_partition.timezone,
),
end=pendulum.instance(
datetime(year=2020, month=1, day=2),
tz=get_time_partitions_def(partitions_def).timezone,
tz=time_partition.timezone,
),
)
assert context.asset_partitions_time_window_for_output() == time_window
assert context.partition_time_window == time_window
return 1

context = build_asset_context(
Expand All @@ -1306,12 +1309,12 @@ def my_asset(context):
)

@asset(partitions_def=partitions_def)
def static_multipartitioned_asset(context):
def static_multipartitioned_asset(context: AssetExecutionContext):
with pytest.raises(
DagsterInvariantViolationError,
match="with a single time dimension",
):
context.asset_partitions_time_window_for_output()
_ = context.partition_time_window

context = build_asset_context(
partition_key="a|a",
Expand Down
Loading

0 comments on commit b2b732b

Please sign in to comment.