Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

wip Async/tasks #2927

Draft
wants to merge 27 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
3e1d5a2
eod
wild-endeavor Nov 2, 2024
b45c4b7
notes
wild-endeavor Nov 4, 2024
999cea7
changes
wild-endeavor Nov 6, 2024
8be77c5
need to verify tests
wild-endeavor Nov 6, 2024
11f3242
quick lint pass and async
wild-endeavor Nov 6, 2024
8ceec0a
more tests
wild-endeavor Nov 6, 2024
528b063
add some assertions even though they're not correct
wild-endeavor Nov 6, 2024
2b6f698
nested eager in real execution calls the backend
wild-endeavor Nov 7, 2024
3ad46d1
comment
wild-endeavor Nov 7, 2024
38581a3
note
wild-endeavor Nov 7, 2024
469c53f
Merge remote-tracking branch 'origin/master' into async/tasks
wild-endeavor Nov 11, 2024
adbf189
comments, pre-worker queu
wild-endeavor Nov 13, 2024
b45cc87
replace queue
wild-endeavor Nov 14, 2024
52555eb
add turning back to native values
wild-endeavor Nov 14, 2024
9530595
remote
wild-endeavor Nov 14, 2024
809e6b6
remote
wild-endeavor Nov 14, 2024
88537e2
remote
wild-endeavor Nov 14, 2024
27e7ee2
return
wild-endeavor Nov 14, 2024
a0fc558
remove older comments
wild-endeavor Nov 14, 2024
4dfd857
Async/tasks cleanup (#2937)
wild-endeavor Nov 19, 2024
e0c6b7b
merge in consistent exec ids and signals
wild-endeavor Nov 21, 2024
f9b45d6
add cb to watch function, add exception handler, add try/catch around…
wild-endeavor Nov 22, 2024
d4fad13
fmt
wild-endeavor Nov 22, 2024
7c62e20
fix one test and skip rest for now
wild-endeavor Nov 22, 2024
8c80bb6
wrong marker
wild-endeavor Nov 22, 2024
bb04b64
remove old eager tests
wild-endeavor Nov 22, 2024
65df34a
Merge branch 'master' into async/tasks
wild-endeavor Nov 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions flytekit/bin/entrypoint.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import contextlib
import datetime
import inspect
import os
import pathlib
import signal
Expand Down Expand Up @@ -177,10 +176,6 @@ def _dispatch_execute(
# Step2
# Invoke task - dispatch_execute
outputs = task_def.dispatch_execute(ctx, idl_input_literals)
if inspect.iscoroutine(outputs):
# Handle eager-mode (async) tasks
logger.info("Output is a coroutine")
outputs = _get_working_loop().run_until_complete(outputs)

# Step3a
if isinstance(outputs, VoidPromise):
Expand Down
27 changes: 0 additions & 27 deletions flytekit/core/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import asyncio
import collections
import datetime
import inspect
import warnings
from abc import abstractmethod
from base64 import b64encode
Expand Down Expand Up @@ -340,9 +339,6 @@ def local_execute(
# if one is changed and not the other.
outputs_literal_map = self.sandbox_execute(ctx, input_literal_map)

if inspect.iscoroutine(outputs_literal_map):
return outputs_literal_map

outputs_literals = outputs_literal_map.literals

# TODO maybe this is the part that should be done for local execution, we pass the outputs to some special
Expand Down Expand Up @@ -759,29 +755,6 @@ def dispatch_execute(
raise
raise FlyteUserRuntimeException(e) from e

if inspect.iscoroutine(native_outputs):
# If native outputs is a coroutine, then this is an eager workflow.
if exec_ctx.execution_state:
if exec_ctx.execution_state.mode == ExecutionState.Mode.LOCAL_TASK_EXECUTION:
# Just return task outputs as a coroutine if the eager workflow is being executed locally,
# outside of a workflow. This preserves the expectation that the eager workflow is an async
# function.
return native_outputs
elif exec_ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
# If executed inside of a workflow being executed locally, then run the coroutine to get the
# actual results.
return asyncio.run(
self._async_execute(
native_inputs,
native_outputs,
ctx,
exec_ctx,
new_user_params,
)
)

return self._async_execute(native_inputs, native_outputs, ctx, exec_ctx, new_user_params)

# Lets run the post_execute method. This may result in a IgnoreOutputs Exception, which is
# bubbled up to be handled at the callee layer.
native_outputs = self.post_execute(new_user_params, native_outputs)
Expand Down
11 changes: 11 additions & 0 deletions flytekit/core/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,14 @@

# Set this environment variable to true to force the task to return non-zero exit code on failure.
FLYTE_FAIL_ON_ERROR = "FLYTE_FAIL_ON_ERROR"

# Executions launched by the current eager task will be tagged with this key:current_eager_exec_name
EAGER_TAG_KEY = "eager-exec"

# Executions launched by the current eager task will be tagged with this key:root_eager_exec_name, only relevant
# for nested eager tasks. This is how you identify the root execution.
EAGER_TAG_ROOT_KEY = "eager-root-exec"

# The environment variable that will be set to the root eager task execution name. This is how you pass down the
# root eager execution.
EAGER_ROOT_ENV_NAME = "_F_EE_ROOT"
27 changes: 26 additions & 1 deletion flytekit/core/context_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,10 @@
from flytekit.models.core import identifier as _identifier

if typing.TYPE_CHECKING:
from flytekit import Deck
from flytekit.clients import friendly as friendly_client # noqa
from flytekit.clients.friendly import SynchronousFlyteClient
from flytekit.core.worker_queue import Controller
from flytekit.deck.deck import Deck

# TODO: resolve circular import from flytekit.core.python_auto_container import TaskResolverMixin

Expand Down Expand Up @@ -526,6 +528,10 @@ class Mode(Enum):
# This is the mode that is used to indicate a dynamic task
DYNAMIC_TASK_EXECUTION = 4

EAGER_EXECUTION = 5

EAGER_LOCAL_EXECUTION = 6

mode: Optional[ExecutionState.Mode]
working_dir: Union[os.PathLike, str]
engine_dir: Optional[Union[os.PathLike, str]]
Expand Down Expand Up @@ -586,6 +592,7 @@ def is_local_execution(self) -> bool:
return (
self.mode == ExecutionState.Mode.LOCAL_TASK_EXECUTION
or self.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION
or self.mode == ExecutionState.Mode.EAGER_LOCAL_EXECUTION
)


Expand Down Expand Up @@ -663,6 +670,7 @@ class FlyteContext(object):
in_a_condition: bool = False
origin_stackframe: Optional[traceback.FrameSummary] = None
output_metadata_tracker: Optional[OutputMetadataTracker] = None
worker_queue: Optional[Controller] = None

@property
def user_space_params(self) -> Optional[ExecutionParameters]:
Expand All @@ -689,6 +697,7 @@ def new_builder(self) -> Builder:
execution_state=self.execution_state,
in_a_condition=self.in_a_condition,
output_metadata_tracker=self.output_metadata_tracker,
worker_queue=self.worker_queue,
)

def enter_conditional_section(self) -> Builder:
Expand All @@ -713,6 +722,12 @@ def with_serialization_settings(self, ss: SerializationSettings) -> Builder:
def with_output_metadata_tracker(self, t: OutputMetadataTracker) -> Builder:
return self.new_builder().with_output_metadata_tracker(t)

def with_worker_queue(self, wq: Controller) -> Builder:
return self.new_builder().with_worker_queue(wq)

def with_client(self, c: SynchronousFlyteClient) -> Builder:
return self.new_builder().with_client(c)

def new_compilation_state(self, prefix: str = "") -> CompilationState:
"""
Creates and returns a default compilation state. For most of the code this should be the entrypoint
Expand Down Expand Up @@ -774,6 +789,7 @@ class Builder(object):
serialization_settings: Optional[SerializationSettings] = None
in_a_condition: bool = False
output_metadata_tracker: Optional[OutputMetadataTracker] = None
worker_queue: Optional[Controller] = None

def build(self) -> FlyteContext:
return FlyteContext(
Expand All @@ -785,6 +801,7 @@ def build(self) -> FlyteContext:
serialization_settings=self.serialization_settings,
in_a_condition=self.in_a_condition,
output_metadata_tracker=self.output_metadata_tracker,
worker_queue=self.worker_queue,
)

def enter_conditional_section(self) -> FlyteContext.Builder:
Expand Down Expand Up @@ -833,6 +850,14 @@ def with_output_metadata_tracker(self, t: OutputMetadataTracker) -> FlyteContext
self.output_metadata_tracker = t
return self

def with_worker_queue(self, wq: Controller) -> FlyteContext.Builder:
self.worker_queue = wq
return self

def with_client(self, c: SynchronousFlyteClient) -> FlyteContext.Builder:
self.flyte_client = c
return self

def new_compilation_state(self, prefix: str = "") -> CompilationState:
"""
Creates and returns a default compilation state. For most of the code this should be the entrypoint
Expand Down
4 changes: 0 additions & 4 deletions flytekit/core/options.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import typing
from dataclasses import dataclass
from typing import Callable, Optional

from flytekit.models import common as common_models
from flytekit.models import security
Expand Down Expand Up @@ -35,9 +34,6 @@ class Options(object):
notifications: typing.Optional[typing.List[common_models.Notification]] = None
disable_notifications: typing.Optional[bool] = None
overwrite_cache: typing.Optional[bool] = None
file_uploader: Optional[Callable] = (
None # This is used by the translator to upload task files, like pickled code etc
)

@classmethod
def default_from(
Expand Down
64 changes: 57 additions & 7 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from __future__ import annotations

import asyncio
import collections
import datetime
import inspect
import typing
from collections.abc import Iterable
from copy import deepcopy
from enum import Enum
from typing import Any, Coroutine, Dict, List, Optional, Set, Tuple, Union, cast
from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast

from google.protobuf import struct_pb2 as _struct
from typing_extensions import Annotated, Protocol, get_args, get_origin
Expand Down Expand Up @@ -109,6 +109,27 @@ def my_wf(in1: int, in2: int) -> int:
translate_inputs_to_literals = loop_manager.synced(_translate_inputs_to_literals)


async def _translate_inputs_to_native(
ctx: FlyteContext,
incoming_values: Dict[str, Any],
flyte_interface_types: Dict[str, _interface_models.Variable],
) -> Dict[str, _literals_models.Literal]:
if incoming_values is None:
raise AssertionError("Incoming values cannot be None, must be a dict")

result = {} # So as to not overwrite the input_kwargs
for k, v in incoming_values.items():
if k not in flyte_interface_types:
raise AssertionError(f"Received unexpected keyword argument {k}")
v = await resolve_attr_path_recursively(v)
result[k] = v

return result


translate_inputs_to_native = loop_manager.synced(_translate_inputs_to_native)


async def resolve_attr_path_recursively(v: Any) -> Any:
"""
This function resolves the attribute path in a nested structure recursively.
Expand Down Expand Up @@ -1386,9 +1407,10 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr
def local_execution_mode(self) -> ExecutionState.Mode: ...


def flyte_entity_call_handler(
# change this to async?
async def async_flyte_entity_call_handler(
entity: SupportsNodeCreation, *args, **kwargs
) -> Union[Tuple[Promise], Promise, VoidPromise, Tuple, Coroutine, None]:
) -> Union[Tuple[Promise], Promise, VoidPromise, Tuple, None]:
"""
This function is the call handler for tasks, workflows, and launch plans (which redirects to the underlying
workflow). The logic is the same for all three, but we did not want to create base class, hence this separate
Expand Down Expand Up @@ -1421,6 +1443,19 @@ def flyte_entity_call_handler(
kwargs[input_name] = arg

ctx = FlyteContextManager.current_context()
# todo: add condition here to let sync/async tasks be called during eager execution
# conditions to take care of
if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.EAGER_EXECUTION:
# for both nested eager, async, and sync tasks, submit to the informer.
if not ctx.worker_queue:
raise AssertionError("Worker queue missing, must be set when trying to execute tasks in an eager workflow")
loop = asyncio.get_running_loop()
fut = ctx.worker_queue.add(loop, entity, input_kwargs=kwargs)
result = await fut

return result
# if this is eager local execution, the proceed with normal local execution below

if ctx.execution_state and (
ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION
or ctx.execution_state.mode == ExecutionState.Mode.LOCAL_TASK_EXECUTION
Expand All @@ -1430,7 +1465,10 @@ def flyte_entity_call_handler(
)
if ctx.compilation_state is not None and ctx.compilation_state.mode == 1:
return create_and_link_node(ctx, entity=entity, **kwargs)

# This handles the case for when we're already in a local execution state
if ctx.execution_state and ctx.execution_state.is_local_execution():
original_mode = ctx.execution_state.mode
mode = cast(LocallyExecutable, entity).local_execution_mode()
omt = OutputMetadataTracker()
with FlyteContextManager.with_context(
Expand All @@ -1448,7 +1486,16 @@ def flyte_entity_call_handler(
return create_task_output(vals, entity.python_interface)
else:
return None
return cast(LocallyExecutable, entity).local_execute(ctx, **kwargs)
if original_mode == ExecutionState.Mode.EAGER_LOCAL_EXECUTION:
local_execute_results = cast(LocallyExecutable, entity).local_execute(ctx, **kwargs)
if mode == ExecutionState.Mode.EAGER_LOCAL_EXECUTION:
return local_execute_results

return create_native_named_tuple(ctx, local_execute_results, entity.python_interface)
else:
return cast(LocallyExecutable, entity).local_execute(ctx, **kwargs)

# This condition kicks off a new local execution.
else:
mode = cast(LocallyExecutable, entity).local_execution_mode()
omt = OutputMetadataTracker()
Expand All @@ -1465,10 +1512,10 @@ def flyte_entity_call_handler(
else:
raise ValueError(f"Received an output when workflow local execution expected None. Received: {result}")

if inspect.iscoroutine(result):
if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.DYNAMIC_TASK_EXECUTION:
return result

if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.DYNAMIC_TASK_EXECUTION:
if mode == ExecutionState.Mode.EAGER_LOCAL_EXECUTION:
return result

if (1 < expected_outputs == len(cast(Tuple[Promise], result))) or (
Expand All @@ -1481,3 +1528,6 @@ def flyte_entity_call_handler(
f"Result {result}. "
f"Python interface: {entity.python_interface}"
)


flyte_entity_call_handler = loop_manager.synced(async_flyte_entity_call_handler)
Loading
Loading