Async runner improvements #2056

Merged: 17 commits, Dec 2, 2024
Changes from 14 commits
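The central refactor in this PR swaps the NamedTemporaryFile-based runner attribute files for named pipes (FIFOs): instead of polling a regular file on disk until the subprocess writes its attributes, the runner opens a FIFO and receives the content as soon as it arrives. The helper itself lives in metaflow/runner/utils.py and is not part of the excerpted diff; a minimal sketch of what it plausibly looks like, assuming os.mkfifo plus a non-blocking read descriptor (POSIX-only, like the pkill-based cleanup further down):

import os
import tempfile
from contextlib import contextmanager
from typing import Iterator, Tuple

@contextmanager
def temporary_fifo() -> Iterator[Tuple[str, int]]:
    """Yield (path, fd) for a temporary FIFO; fd is the read end."""
    with tempfile.TemporaryDirectory() as temp_dir:
        path = os.path.join(temp_dir, "fifo")
        os.mkfifo(path)  # create the named pipe inside the temp dir
        # Open non-blocking so the read end does not stall waiting for
        # the subprocess to open the write end.
        fd = os.open(path, os.O_RDONLY | os.O_NONBLOCK)
        try:
            yield path, fd
        finally:
            os.close(fd)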
12 changes: 4 additions & 8 deletions metaflow/plugins/argo/argo_workflows_deployer_objects.py
@@ -10,7 +10,7 @@
from metaflow.plugins.argo.argo_workflows import ArgoWorkflows
from metaflow.runner.deployer import Deployer, DeployedFlow, TriggeredRun

-from metaflow.runner.utils import get_lower_level_group, handle_timeout
+from metaflow.runner.utils import get_lower_level_group, handle_timeout, temporary_fifo


def generate_fake_flow_file_contents(
@@ -341,18 +341,14 @@ def trigger(self, **kwargs) -> ArgoWorkflowsTriggeredRun:
Exception
If there is an error during the trigger process.
"""
-        with tempfile.TemporaryDirectory() as temp_dir:
-            tfp_runner_attribute = tempfile.NamedTemporaryFile(
-                dir=temp_dir, delete=False
-            )
-
+        with temporary_fifo() as (attribute_file_path, attribute_file_fd):
# every subclass needs to have `self.deployer_kwargs`
command = get_lower_level_group(
self.deployer.api,
self.deployer.top_level_kwargs,
self.deployer.TYPE,
self.deployer.deployer_kwargs,
-            ).trigger(deployer_attribute_file=tfp_runner_attribute.name, **kwargs)
+            ).trigger(deployer_attribute_file=attribute_file_path, **kwargs)

pid = self.deployer.spm.run_command(
[sys.executable, *command],
@@ -363,7 +359,7 @@ def trigger(self, **kwargs) -> ArgoWorkflowsTriggeredRun:

command_obj = self.deployer.spm.get(pid)
content = handle_timeout(
-                tfp_runner_attribute, command_obj, self.deployer.file_read_timeout
+                attribute_file_fd, command_obj, self.deployer.file_read_timeout
)

if command_obj.process.returncode == 0:
metaflow/plugins/aws/step_functions/step_functions_deployer_objects.py
@@ -6,7 +6,7 @@
from metaflow.plugins.aws.step_functions.step_functions import StepFunctions
from metaflow.runner.deployer import DeployedFlow, TriggeredRun

-from metaflow.runner.utils import get_lower_level_group, handle_timeout
+from metaflow.runner.utils import get_lower_level_group, handle_timeout, temporary_fifo


class StepFunctionsTriggeredRun(TriggeredRun):
@@ -196,18 +196,14 @@ def trigger(self, **kwargs) -> StepFunctionsTriggeredRun:
Exception
If there is an error during the trigger process.
"""
-        with tempfile.TemporaryDirectory() as temp_dir:
-            tfp_runner_attribute = tempfile.NamedTemporaryFile(
-                dir=temp_dir, delete=False
-            )
-
+        with temporary_fifo() as (attribute_file_path, attribute_file_fd):
# every subclass needs to have `self.deployer_kwargs`
command = get_lower_level_group(
self.deployer.api,
self.deployer.top_level_kwargs,
self.deployer.TYPE,
self.deployer.deployer_kwargs,
-            ).trigger(deployer_attribute_file=tfp_runner_attribute.name, **kwargs)
+            ).trigger(deployer_attribute_file=attribute_file_path, **kwargs)

pid = self.deployer.spm.run_command(
[sys.executable, *command],
@@ -218,7 +214,7 @@ def trigger(self, **kwargs) -> StepFunctionsTriggeredRun:

command_obj = self.deployer.spm.get(pid)
content = handle_timeout(
-                tfp_runner_attribute, command_obj, self.deployer.file_read_timeout
+                attribute_file_fd, command_obj, self.deployer.file_read_timeout
)

if command_obj.process.returncode == 0:
12 changes: 4 additions & 8 deletions metaflow/runner/deployer_impl.py
@@ -2,12 +2,11 @@
import json
import os
import sys
-import tempfile

from typing import Any, ClassVar, Dict, Optional, TYPE_CHECKING, Type

from .subprocess_manager import SubprocessManager
-from .utils import get_lower_level_group, handle_timeout
+from .utils import get_lower_level_group, handle_timeout, temporary_fifo

if TYPE_CHECKING:
import metaflow.runner.deployer
@@ -121,14 +120,11 @@ def create(self, **kwargs) -> "metaflow.runner.deployer.DeployedFlow":
def _create(
self, create_class: Type["metaflow.runner.deployer.DeployedFlow"], **kwargs
) -> "metaflow.runner.deployer.DeployedFlow":
-        with tempfile.TemporaryDirectory() as temp_dir:
-            tfp_runner_attribute = tempfile.NamedTemporaryFile(
-                dir=temp_dir, delete=False
-            )
+        with temporary_fifo() as (attribute_file_path, attribute_file_fd):
# every subclass needs to have `self.deployer_kwargs`
command = get_lower_level_group(
self.api, self.top_level_kwargs, self.TYPE, self.deployer_kwargs
-            ).create(deployer_attribute_file=tfp_runner_attribute.name, **kwargs)
+            ).create(deployer_attribute_file=attribute_file_path, **kwargs)

pid = self.spm.run_command(
[sys.executable, *command],
@@ -139,7 +135,7 @@ def _create(

command_obj = self.spm.get(pid)
content = handle_timeout(
-                tfp_runner_attribute, command_obj, self.file_read_timeout
+                attribute_file_fd, command_obj, self.file_read_timeout
)
content = json.loads(content)
self.name = content.get("name")
63 changes: 33 additions & 30 deletions metaflow/runner/metaflow_runner.py
@@ -2,13 +2,16 @@
import os
import sys
import json
-import tempfile

from typing import Dict, Iterator, Optional, Tuple

from metaflow import Run

from .utils import handle_timeout
from .utils import (
temporary_fifo,
handle_timeout,
async_handle_timeout,
)
from .subprocess_manager import CommandManager, SubprocessManager


@@ -267,10 +270,8 @@ def __enter__(self) -> "Runner":
async def __aenter__(self) -> "Runner":
return self

-    def __get_executing_run(self, tfp_runner_attribute, command_obj):
-        content = handle_timeout(
-            tfp_runner_attribute, command_obj, self.file_read_timeout
-        )
+    def __get_executing_run(self, attribute_file_fd, command_obj):
+        content = handle_timeout(attribute_file_fd, command_obj, self.file_read_timeout)
content = json.loads(content)
pathspec = "%s/%s" % (content.get("flow_name"), content.get("run_id"))

@@ -282,6 +283,20 @@ def __get_executing_run(self, tfp_runner_attribute, command_obj):
)
return ExecutingRun(self, command_obj, run_object)

+    async def __async_get_executing_run(self, attribute_file_fd, command_obj):
+        content = await async_handle_timeout(
+            attribute_file_fd, command_obj, self.file_read_timeout
+        )
+        content = json.loads(content)
+        pathspec = "%s/%s" % (content.get("flow_name"), content.get("run_id"))
+
+        # Set the correct metadata from the runner_attribute file corresponding to this run.
+        metadata_for_flow = content.get("metadata")
+        metadata(metadata_for_flow)
Collaborator: We are no longer setting the metadata explicitly, i.e. metadata(metadata_for_flow) needs to be removed. Refer to the implementation of __get_executing_run and you will see the following:

    # Set the correct metadata from the runner_attribute file corresponding to this run.
    metadata_for_flow = content.get("metadata")

    run_object = Run(
        pathspec, _namespace_check=False, _current_metadata=metadata_for_flow
    )

Contributor (author): Thanks, fixed! Must've missed it while rebasing.


+        run_object = Run(pathspec, _namespace_check=False)
+        return ExecutingRun(self, command_obj, run_object)
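handle_timeout and async_handle_timeout (imported at the top of this file) are not shown in this excerpt. Below is a rough, hypothetical sketch of their FIFO-reading core, assuming the non-blocking read descriptor from temporary_fifo and a writer that emits the whole payload at once and then closes; the real helpers also take command_obj so they can surface the subprocess's logs on failure, which is omitted here. The async variant exists so that async_run and async_resume can poll without blocking the event loop.

import asyncio
import os
import time

def _read_fifo_once(fd: int) -> bytes:
    # Non-blocking read: BlockingIOError means a writer is attached but has
    # written nothing yet; b"" means no writer yet, or the writer closed.
    try:
        return os.read(fd, 65536)
    except BlockingIOError:
        return b""

def handle_timeout_sketch(fd: int, timeout: float) -> str:
    deadline = time.monotonic() + timeout
    chunks = []
    while time.monotonic() < deadline:
        chunk = _read_fifo_once(fd)
        if chunk:
            chunks.append(chunk)
        elif chunks:  # no more data after some was read: writer is done
            return b"".join(chunks).decode()
        else:
            time.sleep(0.05)
    raise TimeoutError("timed out waiting for the runner attribute file")

async def async_handle_timeout_sketch(fd: int, timeout: float) -> str:
    deadline = time.monotonic() + timeout
    chunks = []
    while time.monotonic() < deadline:
        chunk = _read_fifo_once(fd)
        if chunk:
            chunks.append(chunk)
        elif chunks:
            return b"".join(chunks).decode()
        else:
            await asyncio.sleep(0.05)  # yield instead of blocking the loop
    raise TimeoutError("timed out waiting for the runner attribute file")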

def run(self, **kwargs) -> ExecutingRun:
"""
Blocking execution of the run. This method will wait until
@@ -298,12 +313,9 @@ def run(self, **kwargs) -> ExecutingRun:
ExecutingRun
ExecutingRun containing the results of the run.
"""
-        with tempfile.TemporaryDirectory() as temp_dir:
-            tfp_runner_attribute = tempfile.NamedTemporaryFile(
-                dir=temp_dir, delete=False
-            )
+        with temporary_fifo() as (attribute_file_path, attribute_file_fd):
command = self.api(**self.top_level_kwargs).run(
-                runner_attribute_file=tfp_runner_attribute.name, **kwargs
+                runner_attribute_file=attribute_file_path, **kwargs
)

pid = self.spm.run_command(
@@ -314,7 +326,7 @@
)
command_obj = self.spm.get(pid)

-            return self.__get_executing_run(tfp_runner_attribute, command_obj)
+            return self.__get_executing_run(attribute_file_fd, command_obj)

def resume(self, **kwargs):
"""
@@ -332,12 +344,9 @@
ExecutingRun
ExecutingRun containing the results of the resumed run.
"""
-        with tempfile.TemporaryDirectory() as temp_dir:
-            tfp_runner_attribute = tempfile.NamedTemporaryFile(
-                dir=temp_dir, delete=False
-            )
+        with temporary_fifo() as (attribute_file_path, attribute_file_fd):
command = self.api(**self.top_level_kwargs).resume(
-                runner_attribute_file=tfp_runner_attribute.name, **kwargs
+                runner_attribute_file=attribute_file_path, **kwargs
)

pid = self.spm.run_command(
@@ -348,7 +357,7 @@
)
command_obj = self.spm.get(pid)

-            return self.__get_executing_run(tfp_runner_attribute, command_obj)
+            return self.__get_executing_run(attribute_file_fd, command_obj)

async def async_run(self, **kwargs) -> ExecutingRun:
"""
@@ -368,12 +377,9 @@ async def async_run(self, **kwargs) -> ExecutingRun:
ExecutingRun
ExecutingRun representing the run that was started.
"""
-        with tempfile.TemporaryDirectory() as temp_dir:
-            tfp_runner_attribute = tempfile.NamedTemporaryFile(
-                dir=temp_dir, delete=False
-            )
+        with temporary_fifo() as (attribute_file_path, attribute_file_fd):
command = self.api(**self.top_level_kwargs).run(
-                runner_attribute_file=tfp_runner_attribute.name, **kwargs
+                runner_attribute_file=attribute_file_path, **kwargs
)

pid = await self.spm.async_run_command(
@@ -383,7 +389,7 @@
)
command_obj = self.spm.get(pid)

-            return self.__get_executing_run(tfp_runner_attribute, command_obj)
+            return await self.__async_get_executing_run(attribute_file_fd, command_obj)

async def async_resume(self, **kwargs):
"""
@@ -403,12 +409,9 @@
ExecutingRun
ExecutingRun representing the resumed run that was started.
"""
-        with tempfile.TemporaryDirectory() as temp_dir:
-            tfp_runner_attribute = tempfile.NamedTemporaryFile(
-                dir=temp_dir, delete=False
-            )
+        with temporary_fifo() as (attribute_file_path, attribute_file_fd):
command = self.api(**self.top_level_kwargs).resume(
-                runner_attribute_file=tfp_runner_attribute.name, **kwargs
+                runner_attribute_file=attribute_file_path, **kwargs
)

pid = await self.spm.async_run_command(
@@ -418,7 +421,7 @@
)
command_obj = self.spm.get(pid)

-            return self.__get_executing_run(tfp_runner_attribute, command_obj)
+            return await self.__async_get_executing_run(attribute_file_fd, command_obj)

def __exit__(self, exc_type, exc_value, traceback):
self.spm.cleanup()
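Note on usage: with async_run and async_resume now delegating to __async_get_executing_run, several runs can be launched and awaited concurrently from one event loop, without the FIFO read blocking it. A hypothetical sketch (the flow file name and the alpha parameter are placeholders; it assumes the top-level Runner export and that ExecutingRun.wait is a coroutine, as in the runner API):

import asyncio
from metaflow import Runner

async def main():
    async with Runner("train.py") as runner:
        # Each async_run returns an ExecutingRun as soon as the run id
        # arrives over the FIFO, so both runs proceed concurrently.
        first = await runner.async_run(alpha=0.1)
        second = await runner.async_run(alpha=0.2)
        await asyncio.gather(first.wait(), second.wait())
        print(first.run, second.run)  # the underlying metaflow.Run objects

asyncio.run(main())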
55 changes: 46 additions & 9 deletions metaflow/runner/subprocess_manager.py
@@ -9,26 +9,49 @@
import threading
from typing import Callable, Dict, Iterator, List, Optional, Tuple

+from .utils import check_process_exited


-def kill_process_and_descendants(pid, termination_timeout):
+def kill_processes_and_descendants(pids: List[str], termination_timeout: float):
# TODO: there's a race condition that new descendants might
# spawn b/w the invocations of 'pkill' and 'kill'.
# Needs to be fixed in future.
try:
subprocess.check_call(["pkill", "-TERM", "-P", str(pid)])
subprocess.check_call(["kill", "-TERM", str(pid)])
subprocess.check_call(["pkill", "-TERM", "-P", *pids])
subprocess.check_call(["kill", "-TERM", *pids])
except subprocess.CalledProcessError:
pass

time.sleep(termination_timeout)

try:
subprocess.check_call(["pkill", "-KILL", "-P", str(pid)])
subprocess.check_call(["kill", "-KILL", str(pid)])
subprocess.check_call(["pkill", "-KILL", "-P", *pids])
subprocess.check_call(["kill", "-KILL", *pids])
except subprocess.CalledProcessError:
pass


+async def async_kill_processes_and_descendants(
+    pids: List[str], termination_timeout: float
+):
+    # TODO: there's a race condition that new descendants might
+    # spawn b/w the invocations of 'pkill' and 'kill'.
+    # Needs to be fixed in future.
+    sub_term = await asyncio.create_subprocess_exec("pkill", "-TERM", "-P", *pids)
+    await sub_term.wait()
+
+    main_term = await asyncio.create_subprocess_exec("kill", "-TERM", *pids)
+    await main_term.wait()
+
+    await asyncio.sleep(termination_timeout)
+
+    sub_kill = await asyncio.create_subprocess_exec("pkill", "-KILL", "-P", *pids)
+    await sub_kill.wait()
+
+    main_kill = await asyncio.create_subprocess_exec("kill", "-KILL", *pids)
+    await main_kill.wait()

Collaborator: I think we can encapsulate the SIGTERM part and the SIGKILL part inside a try/except each, just like we did in the sync version, i.e. kill_processes_and_descendants.

Contributor (author): I was checking the asyncio docs as well as the implementation while writing this part and didn't see any exceptions being raised, but I may have missed them; I'm not familiar with the internals. Probably a good idea to wrap the calls regardless.
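A sketch of the suggested hardening (not part of this diff as shown): asyncio.create_subprocess_exec does not raise on a non-zero exit status, so unlike subprocess.check_call there is no CalledProcessError to catch; the hypothetical wrapper below guards exec-time failures instead, while letting pkill's "no processes matched" exit code pass silently.

import asyncio
from typing import List

async def _signal_quietly(*argv: str) -> None:
    # Run one kill/pkill invocation and swallow failures.
    try:
        proc = await asyncio.create_subprocess_exec(*argv)
        await proc.wait()  # non-zero exit (e.g. pkill with no match) is fine
    except OSError:
        pass  # exec failure, e.g. the binary is missing

async def async_kill_processes_and_descendants(
    pids: List[str], termination_timeout: float
):
    await _signal_quietly("pkill", "-TERM", "-P", *pids)
    await _signal_quietly("kill", "-TERM", *pids)
    await asyncio.sleep(termination_timeout)
    await _signal_quietly("pkill", "-KILL", "-P", *pids)
    await _signal_quietly("kill", "-KILL", *pids)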


class LogReadTimeoutError(Exception):
"""Exception raised when reading logs times out."""

@@ -46,14 +69,28 @@ def __init__(self):
loop = asyncio.get_running_loop()
loop.add_signal_handler(
signal.SIGINT,
-                lambda: self._handle_sigint(signum=signal.SIGINT, frame=None),
+                lambda: asyncio.create_task(self._async_handle_sigint()),
)
except RuntimeError:
signal.signal(signal.SIGINT, self._handle_sigint)

+    async def _async_handle_sigint(self):
+        pids = [
+            str(command.process.pid)
+            for command in self.commands.values()
+            if command.process and not check_process_exited(command)
+        ]
+        if pids:
+            await async_kill_processes_and_descendants(pids, termination_timeout=2)

def _handle_sigint(self, signum, frame):
-        for each_command in self.commands.values():
-            each_command.kill(termination_timeout=2)
+        pids = [
+            str(command.process.pid)
+            for command in self.commands.values()
+            if command.process and not check_process_exited(command)
+        ]
+        if pids:
+            kill_processes_and_descendants(pids, termination_timeout=2)
Collaborator: Small nitpick: earlier, termination_timeout was 2 seconds per PID. Maybe we should pass 2*len(pids) here, since this call now handles all of them collectively?

Collaborator: Similarly for the async version, if we choose to do this.

Contributor (author): Shouldn't be necessary anymore. All the pids are enumerated before the call and are sent out simultaneously by kill/pkill, so they all get the same amount of time as before.
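For reference, check_process_exited (imported at the top of this file) presumably looks something like the sketch below; the run_called flag and process attribute mirror CommandManager's fields in this module, but the exact shape is an assumption:

def check_process_exited(command_obj) -> bool:
    # returncode stays None while a Popen/asyncio subprocess is running,
    # so a non-None value means the process has already exited.
    return (
        command_obj.run_called
        and command_obj.process is not None
        and command_obj.process.returncode is not None
    )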


async def __aenter__(self) -> "SubprocessManager":
return self
@@ -472,7 +509,7 @@ def kill(self, termination_timeout: float = 2):
"""

if self.process is not None:
-            kill_process_and_descendants(self.process.pid, termination_timeout)
+            kill_processes_and_descendants([str(self.process.pid)], termination_timeout)
else:
print("No process to kill.")

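One design note on the SIGINT change above: loop.add_signal_handler only accepts a plain callable, so the coroutine-based cleanup has to be scheduled with asyncio.create_task rather than called directly. A standalone illustration of the idiom (not taken from the PR; POSIX-only, since add_signal_handler is unavailable on Windows event loops):

import asyncio
import signal

async def async_cleanup():
    # Stand-in for the awaited pkill/kill escalation in _async_handle_sigint.
    print("cleaning up asynchronously")

async def main():
    loop = asyncio.get_running_loop()
    # Signal handlers must be synchronous; wrapping the coroutine in a
    # task makes the event loop run it on its next iteration.
    loop.add_signal_handler(
        signal.SIGINT, lambda: asyncio.create_task(async_cleanup())
    )
    await asyncio.sleep(30)  # press Ctrl-C within 30s to trigger the handler

asyncio.run(main())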