Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

deprecate asset_partition_*_for_output on AssetExecutionContext #19436

Merged
merged 6 commits into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 37 additions & 27 deletions python_modules/dagster/dagster/_core/execution/context/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -1364,6 +1364,11 @@ def _copy_docs_from_op_execution_context(obj):
"run_config": "run.run_config",
"run_tags": "run.tags",
"get_op_execution_context": "op_execution_context",
"asset_partition_key_for_output": "partition_key",
"asset_partitions_time_window_for_output": "partition_time_window",
"asset_partition_key_range_for_output": "partition_key_range",
"asset_partitions_def_for_output": "assets_def.partitions_def",
"asset_partition_keys_for_output": "partition_keys",
}

ALTERNATE_EXPRESSIONS = {
Expand Down Expand Up @@ -1501,6 +1506,38 @@ def get_tag(self, key: str) -> Optional[str]:
def get_op_execution_context(self) -> "OpExecutionContext":
return self.op_execution_context

@deprecated(**_get_deprecation_kwargs("asset_partition_key_for_output"))
@public
@_copy_docs_from_op_execution_context
def asset_partition_key_for_output(self, output_name: str = "result") -> str:
return self.op_execution_context.asset_partition_key_for_output(output_name=output_name)

@deprecated(**_get_deprecation_kwargs("asset_partitions_time_window_for_output"))
@public
@_copy_docs_from_op_execution_context
def asset_partitions_time_window_for_output(self, output_name: str = "result") -> TimeWindow:
return self.op_execution_context.asset_partitions_time_window_for_output(output_name)

@deprecated(**_get_deprecation_kwargs("asset_partition_key_range_for_output"))
@public
@_copy_docs_from_op_execution_context
def asset_partition_key_range_for_output(
self, output_name: str = "result"
) -> PartitionKeyRange:
return self.op_execution_context.asset_partition_key_range_for_output(output_name)

@deprecated(**_get_deprecation_kwargs("asset_partitions_def_for_output"))
@public
@_copy_docs_from_op_execution_context
def asset_partitions_def_for_output(self, output_name: str = "result") -> PartitionsDefinition:
return self.op_execution_context.asset_partitions_def_for_output(output_name=output_name)

@deprecated(**_get_deprecation_kwargs("asset_partition_keys_for_output"))
@public
@_copy_docs_from_op_execution_context
def asset_partition_keys_for_output(self, output_name: str = "result") -> Sequence[str]:
return self.op_execution_context.asset_partition_keys_for_output(output_name=output_name)

########## pass-through to op context

#### op related
Expand Down Expand Up @@ -1657,23 +1694,6 @@ def partition_key_range(self) -> PartitionKeyRange:
def partition_time_window(self) -> TimeWindow:
return self.op_execution_context.partition_time_window

@public
@_copy_docs_from_op_execution_context
def asset_partition_key_for_output(self, output_name: str = "result") -> str:
return self.op_execution_context.asset_partition_key_for_output(output_name=output_name)

@public
@_copy_docs_from_op_execution_context
def asset_partitions_time_window_for_output(self, output_name: str = "result") -> TimeWindow:
return self.op_execution_context.asset_partitions_time_window_for_output(output_name)

@public
@_copy_docs_from_op_execution_context
def asset_partition_key_range_for_output(
self, output_name: str = "result"
) -> PartitionKeyRange:
return self.op_execution_context.asset_partition_key_range_for_output(output_name)

@public
@_copy_docs_from_op_execution_context
def asset_partition_key_range_for_input(self, input_name: str) -> PartitionKeyRange:
Expand All @@ -1684,21 +1704,11 @@ def asset_partition_key_range_for_input(self, input_name: str) -> PartitionKeyRa
def asset_partition_key_for_input(self, input_name: str) -> str:
return self.op_execution_context.asset_partition_key_for_input(input_name)

@public
@_copy_docs_from_op_execution_context
def asset_partitions_def_for_output(self, output_name: str = "result") -> PartitionsDefinition:
return self.op_execution_context.asset_partitions_def_for_output(output_name=output_name)

@public
@_copy_docs_from_op_execution_context
def asset_partitions_def_for_input(self, input_name: str) -> PartitionsDefinition:
return self.op_execution_context.asset_partitions_def_for_input(input_name=input_name)

@public
@_copy_docs_from_op_execution_context
def asset_partition_keys_for_output(self, output_name: str = "result") -> Sequence[str]:
return self.op_execution_context.asset_partition_keys_for_output(output_name=output_name)

@public
@_copy_docs_from_op_execution_context
def asset_partition_keys_for_input(self, input_name: str) -> Sequence[str]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,10 @@ def asset_partitions_time_window_for_output(self, output_name: str = "result") -
Union[MultiPartitionsDefinition, TimeWindowPartitionsDefinition], partitions_def
).time_window_for_partition_key(self.partition_key)

@property
def partition_time_window(self) -> TimeWindow:
return self.asset_partitions_time_window_for_output()

def add_output_metadata(
self,
metadata: Mapping[str, Any],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,15 +105,15 @@ def load_input(self, context):
assert context.asset_partitions_def == upstream_partitions_def

@asset(partitions_def=upstream_partitions_def)
def upstream_asset(context):
assert context.asset_partition_key_for_output() == "2"
def upstream_asset(context: AssetExecutionContext):
assert context.partition_key == "2"

@asset(
partitions_def=downstream_partitions_def,
ins={"upstream_asset": AssetIn(partition_mapping=TrailingWindowPartitionMapping())},
)
def downstream_asset(context, upstream_asset):
assert context.asset_partition_key_for_output() == "2"
def downstream_asset(context: AssetExecutionContext, upstream_asset):
assert context.partition_key == "2"
assert upstream_asset is None
assert context.asset_partitions_def_for_input("upstream_asset") == upstream_partitions_def

Expand Down Expand Up @@ -341,9 +341,8 @@ def test_partition_keys_in_range():
]

@asset(partitions_def=DailyPartitionsDefinition(start_date="2022-09-11"))
def upstream(context):
assert context.asset_partition_keys_for_output("result") == ["2022-09-11"]
assert context.asset_partition_keys_for_output() == ["2022-09-11"]
def upstream(context: AssetExecutionContext):
assert context.partition_keys == ["2022-09-11"]

@asset(partitions_def=WeeklyPartitionsDefinition(start_date="2022-09-11"))
def downstream(context, upstream):
Expand Down Expand Up @@ -383,8 +382,8 @@ def test_dependency_resolution_partition_mapping():
partitions_def=DailyPartitionsDefinition(start_date="2020-01-01"),
key_prefix=["staging"],
)
def upstream(context):
partition_date_str = context.asset_partition_key_for_output()
def upstream(context: AssetExecutionContext):
partition_date_str = context.partition_key
return partition_date_str

@asset(
Expand Down Expand Up @@ -441,11 +440,8 @@ def upstream(context):
return 1

@asset(partitions_def=composite)
def downstream(context, upstream):
assert (
context.asset_partition_keys_for_input("upstream")
== context.asset_partition_keys_for_output()
)
def downstream(context: AssetExecutionContext, upstream):
assert context.asset_partition_keys_for_input("upstream") == context.partition_keys
return 1

asset_graph = AssetGraph.from_assets([upstream, downstream])
Expand All @@ -471,16 +467,14 @@ def test_multipartitions_def_partition_mapping_infer_single_dim_to_multi():
)

@asset(partitions_def=abc_def)
def upstream(context):
assert context.asset_partition_keys_for_output("result") == ["a"]
def upstream(context: AssetExecutionContext):
assert context.partition_keys == ["a"]
return 1

@asset(partitions_def=composite)
def downstream(context, upstream):
def downstream(context: AssetExecutionContext, upstream):
assert context.asset_partition_keys_for_input("upstream") == ["a"]
assert context.asset_partition_keys_for_output("result") == [
MultiPartitionKey({"abc": "a", "123": "1"})
]
assert context.partition_keys == [MultiPartitionKey({"abc": "a", "123": "1"})]
return 1

asset_graph = AssetGraph.from_assets([upstream, downstream])
Expand Down Expand Up @@ -533,9 +527,9 @@ def upstream(context):
return 1

@asset(partitions_def=abc_def)
def downstream(context, upstream):
def downstream(context: AssetExecutionContext, upstream):
assert set(context.asset_partition_keys_for_input("upstream")) == a_multipartition_keys
assert context.asset_partition_keys_for_output("result") == ["a"]
assert context.partition_keys == ["a"]
return 1

asset_graph = AssetGraph.from_assets([upstream, downstream])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import pytest
from dagster import (
AssetExecutionContext,
AssetKey,
AssetOut,
AssetsDefinition,
Expand Down Expand Up @@ -278,8 +279,8 @@ def the_asset(context):

def test_materialize_partition_key():
@asset(partitions_def=DailyPartitionsDefinition(start_date="2022-01-01"))
def the_asset(context):
assert context.asset_partition_key_for_output() == "2022-02-02"
def the_asset(context: AssetExecutionContext):
assert context.partition_key == "2022-02-02"

with instance_for_test() as instance:
result = materialize([the_asset], partition_key="2022-02-02", instance=instance)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,8 @@ def multi_asset_with_internal_deps(thing):

def test_materialize_to_memory_partition_key():
@asset(partitions_def=DailyPartitionsDefinition(start_date="2022-01-01"))
def the_asset(context):
assert context.asset_partition_key_for_output() == "2022-02-02"
def the_asset(context: AssetExecutionContext):
assert context.partition_key == "2022-02-02"

result = materialize_to_memory([the_asset], partition_key="2022-02-02")
assert result.success
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pendulum
import pytest
from dagster import (
AssetExecutionContext,
AssetMaterialization,
AssetOut,
AssetsDefinition,
Expand Down Expand Up @@ -155,8 +156,8 @@ def load_input(self, context):
assert False, "shouldn't get here"

@asset(partitions_def=partitions_def)
def my_asset(context):
assert context.asset_partitions_def_for_output() == partitions_def
def my_asset(context: AssetExecutionContext):
assert context.assets_def.partitions_def == partitions_def

my_job = build_assets_job(
"my_job",
Expand Down Expand Up @@ -212,24 +213,24 @@ def test_access_partition_keys_from_context_direct_invocation():
partitions_def = StaticPartitionsDefinition(["a"])

@asset(partitions_def=partitions_def)
def partitioned_asset(context):
assert context.asset_partition_key_for_output() == "a"
def partitioned_asset(context: AssetExecutionContext):
assert context.partition_key == "a"

context = build_asset_context(partition_key="a")

# check unbound context
assert context.asset_partition_key_for_output() == "a"
assert context.partition_key == "a"

# check bound context
partitioned_asset(context)

# check failure for non-partitioned asset
@asset
def non_partitioned_asset(context):
def non_partitioned_asset(context: AssetExecutionContext):
with pytest.raises(
CheckError, match="Tried to access partition_key for a non-partitioned run"
):
context.asset_partition_key_for_output()
_ = context.partition_key

context = build_asset_context()
non_partitioned_asset(context)
Expand Down Expand Up @@ -257,8 +258,8 @@ def load_input(self, context):
assert context.asset_partition_key_range == PartitionKeyRange("a", "c")

@asset(partitions_def=upstream_partitions_def)
def upstream_asset(context):
assert context.asset_partition_key_for_output() == "b"
def upstream_asset(context: AssetExecutionContext):
assert context.partition_key == "b"

@asset
def downstream_asset(upstream_asset):
Expand Down Expand Up @@ -606,7 +607,6 @@ def test_partition_range_single_run():
@asset(partitions_def=partitions_def)
def upstream_asset(context) -> None:
key_range = PartitionKeyRange(start="2020-01-01", end="2020-01-03")
assert context.asset_partition_key_range_for_output() == key_range
assert context.partition_key_range == key_range
assert context.partition_time_window == TimeWindow(
partitions_def.time_window_for_partition_key(key_range.start).start,
Expand All @@ -615,11 +615,11 @@ def upstream_asset(context) -> None:
assert context.partition_keys == partitions_def.get_partition_keys_in_range(key_range)

@asset(partitions_def=partitions_def, deps=["upstream_asset"])
def downstream_asset(context) -> None:
def downstream_asset(context: AssetExecutionContext) -> None:
assert context.asset_partition_key_range_for_input("upstream_asset") == PartitionKeyRange(
start="2020-01-01", end="2020-01-03"
)
assert context.asset_partition_key_range_for_output() == PartitionKeyRange(
assert context.partition_key_range == PartitionKeyRange(
start="2020-01-01", end="2020-01-03"
)

Expand Down Expand Up @@ -653,17 +653,15 @@ def test_multipartition_range_single_run():
)

@asset(partitions_def=partitions_def)
def multipartitioned_asset(context) -> None:
key_range = context.asset_partition_key_range_for_output()
def multipartitioned_asset(context: AssetExecutionContext) -> None:
key_range = context.partition_key_range

assert isinstance(key_range.start, MultiPartitionKey)
assert isinstance(key_range.end, MultiPartitionKey)
assert key_range.start == MultiPartitionKey({"date": "2020-01-01", "abc": "a"})
assert key_range.end == MultiPartitionKey({"date": "2020-01-03", "abc": "a"})

assert all(
isinstance(key, MultiPartitionKey) for key in context.asset_partition_keys_for_output()
)
assert all(isinstance(key, MultiPartitionKey) for key in context.partition_keys)

the_job = define_asset_job("job").resolve(
asset_graph=AssetGraph.from_assets([multipartitioned_asset])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,11 @@ def test_deprecation_warnings():
"asset_key_for_input",
"asset_key_for_output",
"asset_partition_key_for_input",
"asset_partition_key_for_output",
"asset_partition_key_range",
"asset_partition_key_range_for_input",
"asset_partition_key_range_for_output",
"asset_partition_keys_for_input",
"asset_partition_keys_for_output",
"asset_partitions_def_for_input",
"asset_partitions_def_for_output",
"asset_partitions_time_window_for_input",
"asset_partitions_time_window_for_output",
"assets_def",
"get_output_metadata",
"has_asset_checks_def",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1259,8 +1259,8 @@ def test_partitions_time_window_asset_invocation():
@asset(
partitions_def=partitions_def,
)
def partitioned_asset(context):
start, end = context.asset_partitions_time_window_for_output()
def partitioned_asset(context: AssetExecutionContext):
start, end = context.partition_time_window
assert start == pendulum.instance(datetime(2023, 2, 2), tz=partitions_def.timezone)
assert end == pendulum.instance(datetime(2023, 2, 3), tz=partitions_def.timezone)

Expand All @@ -1279,18 +1279,21 @@ def test_multipartitioned_time_window_asset_invocation():
)

@asset(partitions_def=partitions_def)
def my_asset(context):
def my_asset(context: AssetExecutionContext):
time_partition = get_time_partitions_def(partitions_def)
if time_partition is None:
assert False, "partitions def does not have a time component"
time_window = TimeWindow(
start=pendulum.instance(
datetime(year=2020, month=1, day=1),
tz=get_time_partitions_def(partitions_def).timezone,
tz=time_partition.timezone,
),
end=pendulum.instance(
datetime(year=2020, month=1, day=2),
tz=get_time_partitions_def(partitions_def).timezone,
tz=time_partition.timezone,
),
)
assert context.asset_partitions_time_window_for_output() == time_window
assert context.partition_time_window == time_window
return 1

context = build_asset_context(
Expand All @@ -1306,12 +1309,12 @@ def my_asset(context):
)

@asset(partitions_def=partitions_def)
def static_multipartitioned_asset(context):
def static_multipartitioned_asset(context: AssetExecutionContext):
with pytest.raises(
DagsterInvariantViolationError,
match="with a single time dimension",
):
context.asset_partitions_time_window_for_output()
_ = context.partition_time_window

context = build_asset_context(
partition_key="a|a",
Expand Down
Loading