Skip to content

Commit

Permalink
start fixing test suite
Browse files Browse the repository at this point in the history
  • Loading branch information
jamiedemaria committed Sep 20, 2023
1 parent 02bed98 commit 7c8f4eb
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1268,7 +1268,7 @@ def asset_check_spec(self) -> AssetCheckSpec:

OP_EXECUTION_CONTEXT_ONLY_METHODS = set(
[
"describe_op",
"describe_op", # TODO - used by internals
"file_manager",
"has_assets_def",
"get_mapping_key",
Expand All @@ -1281,12 +1281,12 @@ def asset_check_spec(self) -> AssetCheckSpec:
"op_handle",
"step_launcher",
"has_events", # TODO - used by internals
"consume_events",
"consume_events", # TODO - used by internals
]
)


PARTITION_KEY_RANGE_AS_ALT = "use partition_key_range or partition_key_range_for_asset instead"
PARTITION_KEY_RANGE_AS_ALT = "use partition_key_range or partition_key_range_for_asset_key instead"
INPUT_OUTPUT_ALT = "not use input or output names and instead use asset keys directly"
OUTPUT_METADATA_ALT = "return MaterializeResult from the asset instead"

Expand All @@ -1313,7 +1313,7 @@ def asset_check_spec(self) -> AssetCheckSpec:

ALTERNATE_AVAILABLE_METHODS = {
"has_tag": "use dagster_run.has_tag instead",
"get_tag": "use dagster_run.get_tag instead",
"get_tag": "use dagster_run.tags.get instead",
"run_tags": "use dagster_run.tags instead",
"set_data_version": "use MaterializeResult instead",
"run": "use dagster_run instead.",
Expand Down Expand Up @@ -1406,7 +1406,7 @@ def partition_key(self) -> str:
@public
@property
def partition_key_range(self) -> PartitionKeyRange:
return self._op_execution_context.asset_partition_key_range
return self._op_execution_context.partition_key_range

@property
def partition_time_window(self) -> TimeWindow:
Expand Down Expand Up @@ -1759,6 +1759,7 @@ def get_step_execution_context(self) -> StepExecutionContext:
def has_events(self) -> bool:
return self.op_execution_context.has_events()


def build_execution_context(
step_context: StepExecutionContext,
) -> Union[OpExecutionContext, AssetExecutionContext]:
Expand Down
11 changes: 7 additions & 4 deletions python_modules/dagster/dagster/_core/execution/plan/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from dagster._core.definitions.result import MaterializeResult
from dagster._core.errors import DagsterExecutionStepExecutionError, DagsterInvariantViolationError
from dagster._core.events import DagsterEvent
from dagster._core.execution.context.compute import build_execution_context
from dagster._core.execution.context.compute import AssetExecutionContext, build_execution_context
from dagster._core.execution.context.system import StepExecutionContext
from dagster._core.system_config.objects import ResolvedRunConfig
from dagster._utils import iterate_with_context
Expand Down Expand Up @@ -168,6 +168,9 @@ def _yield_compute_results(
user_event_generator = gen_from_async_gen(user_event_generator)

op_label = step_context.describe_op()
op_execution_context = (
context.op_execution_context if isinstance(context, AssetExecutionContext) else context
)

for event in iterate_with_context(
lambda: op_execution_error_boundary(
Expand All @@ -180,12 +183,12 @@ def _yield_compute_results(
),
user_event_generator,
):
if context.has_events():
if op_execution_context.has_events():
yield from context.consume_events()
yield _validate_event(event, step_context)

if context.has_events():
yield from context.consume_events()
if op_execution_context.has_events():
yield from op_execution_context.consume_events()


def execute_core_compute(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from dagster._utils import is_named_tuple_instance
from dagster._utils.warnings import disable_dagster_warnings

from ..context.compute import OpExecutionContext
from ..context.compute import AssetExecutionContext, OpExecutionContext


class NoAnnotationSentinel:
Expand Down Expand Up @@ -244,6 +244,8 @@ def _check_output_object_name(
def validate_and_coerce_op_result_to_iterator(
result: Any, context: OpExecutionContext, output_defs: Sequence[OutputDefinition]
) -> Iterator[Any]:
if isinstance(context, AssetExecutionContext):
context = context.op_execution_context
if inspect.isgenerator(result):
# this happens when a user explicitly returns a generator in the op
for event in result:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def the_asset():
def test_materialize_config():
@asset(config_schema={"foo_str": str})
def the_asset_reqs_config(context):
assert context.op_config["foo_str"] == "foo"
assert context.op_execution_context.op_config["foo_str"] == "foo"

with instance_for_test() as instance:
assert materialize(
Expand Down Expand Up @@ -268,7 +268,7 @@ def multi_asset_with_internal_deps(thing):
def test_materialize_tags():
@asset
def the_asset(context):
assert context.get_tag("key1") == "value1"
assert context.dagster_run.tags.get("key1") == "value1"

with instance_for_test() as instance:
result = materialize([the_asset], instance=instance, tags={"key1": "value1"})
Expand All @@ -279,7 +279,7 @@ def the_asset(context):
def test_materialize_partition_key():
@asset(partitions_def=DailyPartitionsDefinition(start_date="2022-01-01"))
def the_asset(context):
assert context.asset_partition_key_for_output() == "2022-02-02"
assert context.partition_key == "2022-02-02"

with instance_for_test() as instance:
result = materialize([the_asset], partition_key="2022-02-02", instance=instance)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,8 @@ def load_input(self, context):

@asset(partitions_def=partitions_def)
def my_asset(context):
assert context.asset_partitions_def_for_output() == partitions_def
# TODO - no partitions_def property on AssetExecutionContext
assert context.op_execution_context.asset_partitions_def_for_output() == partitions_def

my_job = build_assets_job(
"my_job",
Expand Down Expand Up @@ -320,9 +321,10 @@ def upstream_asset():

@asset(partitions_def=partitions_def)
def downstream_asset(context, upstream_asset):
assert context.asset_partitions_time_window_for_input("upstream_asset") == TimeWindow(
pendulum.parse("2021-06-06"), pendulum.parse("2021-06-07")
)
# TODO - getting partition time windows is nasty rn - need to solve
assert context.op_execution_context.asset_partitions_time_window_for_input(
"upstream_asset"
) == TimeWindow(pendulum.parse("2021-06-06"), pendulum.parse("2021-06-07"))
assert upstream_asset is None

assert materialize(
Expand Down Expand Up @@ -533,7 +535,7 @@ def test_job_config_with_asset_partitions():

@asset(config_schema={"a": int}, partitions_def=daily_partitions_def)
def asset1(context):
assert context.op_config["a"] == 5
assert context.op_execution_context.op_config["a"] == 5
assert context.partition_key == "2020-01-01"

the_job = define_asset_job(
Expand All @@ -555,7 +557,7 @@ def test_job_partitioned_config_with_asset_partitions():

@asset(config_schema={"day_of_month": int}, partitions_def=daily_partitions_def)
def asset1(context):
assert context.op_config["day_of_month"] == 1
assert context.op_execution_context.op_config["day_of_month"] == 1
assert context.partition_key == "2020-01-01"

@daily_partitioned_config(start_date="2020-01-01")
Expand Down Expand Up @@ -593,6 +595,9 @@ def myconfig(start, _end):
)


@pytest.mark.skip(
"partition_key_range_for_asset_key not implemented in this PR, will implement in upstack"
) # TODO - remove
def test_partition_range_single_run():
partitions_def = DailyPartitionsDefinition(start_date="2020-01-01")

Expand All @@ -604,10 +609,10 @@ def upstream_asset(context: AssetExecutionContext) -> None:

@asset(partitions_def=partitions_def, deps=["upstream_asset"])
def downstream_asset(context) -> None:
assert context.asset_partition_key_range_for_input("upstream_asset") == PartitionKeyRange(
assert context.partition_key_range_for_asset_key("upstream_asset") == PartitionKeyRange(
start="2020-01-01", end="2020-01-03"
)
assert context.asset_partition_key_range_for_output() == PartitionKeyRange(
assert context.partition_key_range == PartitionKeyRange(
start="2020-01-01", end="2020-01-03"
)

Expand Down Expand Up @@ -643,15 +648,12 @@ def test_multipartition_range_single_run():
@asset(partitions_def=partitions_def)
def multipartitioned_asset(context: AssetExecutionContext) -> None:
key_range = context.partition_key_range

assert isinstance(key_range.start, MultiPartitionKey)
assert isinstance(key_range.end, MultiPartitionKey)
assert key_range.start == MultiPartitionKey({"date": "2020-01-01", "abc": "a"})
assert key_range.end == MultiPartitionKey({"date": "2020-01-03", "abc": "a"})

assert all(
isinstance(key, MultiPartitionKey)
for key in partitions_def.get_partitions_keys_in_range(context.partition_key_range)
for key in partitions_def.get_partition_keys_in_range(context.partition_key_range)
)

the_job = define_asset_job("job").resolve(
Expand Down

0 comments on commit 7c8f4eb

Please sign in to comment.