Skip to content

Commit

Permalink
Async runner improvements (#2056)
Browse files Browse the repository at this point in the history
* Add an async sigint handler

* Improve async runner interface

* Fix error handling in *handle_timeout

* Fix async_read_from_file_when_ready

* Simplify async_kill_processes_and_descendants

* Group kill_process_and_descendants into single command

* Call sigint manager only for live processes

* Guard against empty pid list in sigint handler

* Cleanup comments

* Reimplement CommandManager.kill in terms of kill_processes_and_descendants

* Use fifo for attribute file

* Fix `temporary_fifo` type annotation

* Fix `temporary_fifo` type annotation on py37 and py38

* Fix rewriting the current run metadata

* Wrap execs in `async_kill_processes_and_descendants` in try/except
  • Loading branch information
filipcacky authored Dec 2, 2024
1 parent 8ccbbbe commit 286f9ac
Show file tree
Hide file tree
Showing 6 changed files with 288 additions and 105 deletions.
12 changes: 4 additions & 8 deletions metaflow/plugins/argo/argo_workflows_deployer_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from metaflow.plugins.argo.argo_workflows import ArgoWorkflows
from metaflow.runner.deployer import Deployer, DeployedFlow, TriggeredRun

from metaflow.runner.utils import get_lower_level_group, handle_timeout
from metaflow.runner.utils import get_lower_level_group, handle_timeout, temporary_fifo


def generate_fake_flow_file_contents(
Expand Down Expand Up @@ -341,18 +341,14 @@ def trigger(self, **kwargs) -> ArgoWorkflowsTriggeredRun:
Exception
If there is an error during the trigger process.
"""
with tempfile.TemporaryDirectory() as temp_dir:
tfp_runner_attribute = tempfile.NamedTemporaryFile(
dir=temp_dir, delete=False
)

with temporary_fifo() as (attribute_file_path, attribute_file_fd):
# every subclass needs to have `self.deployer_kwargs`
command = get_lower_level_group(
self.deployer.api,
self.deployer.top_level_kwargs,
self.deployer.TYPE,
self.deployer.deployer_kwargs,
).trigger(deployer_attribute_file=tfp_runner_attribute.name, **kwargs)
).trigger(deployer_attribute_file=attribute_file_path, **kwargs)

pid = self.deployer.spm.run_command(
[sys.executable, *command],
Expand All @@ -363,7 +359,7 @@ def trigger(self, **kwargs) -> ArgoWorkflowsTriggeredRun:

command_obj = self.deployer.spm.get(pid)
content = handle_timeout(
tfp_runner_attribute, command_obj, self.deployer.file_read_timeout
attribute_file_fd, command_obj, self.deployer.file_read_timeout
)

if command_obj.process.returncode == 0:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from metaflow.plugins.aws.step_functions.step_functions import StepFunctions
from metaflow.runner.deployer import DeployedFlow, TriggeredRun

from metaflow.runner.utils import get_lower_level_group, handle_timeout
from metaflow.runner.utils import get_lower_level_group, handle_timeout, temporary_fifo


class StepFunctionsTriggeredRun(TriggeredRun):
Expand Down Expand Up @@ -196,18 +196,14 @@ def trigger(self, **kwargs) -> StepFunctionsTriggeredRun:
Exception
If there is an error during the trigger process.
"""
with tempfile.TemporaryDirectory() as temp_dir:
tfp_runner_attribute = tempfile.NamedTemporaryFile(
dir=temp_dir, delete=False
)

with temporary_fifo() as (attribute_file_path, attribute_file_fd):
# every subclass needs to have `self.deployer_kwargs`
command = get_lower_level_group(
self.deployer.api,
self.deployer.top_level_kwargs,
self.deployer.TYPE,
self.deployer.deployer_kwargs,
).trigger(deployer_attribute_file=tfp_runner_attribute.name, **kwargs)
).trigger(deployer_attribute_file=attribute_file_path, **kwargs)

pid = self.deployer.spm.run_command(
[sys.executable, *command],
Expand All @@ -218,7 +214,7 @@ def trigger(self, **kwargs) -> StepFunctionsTriggeredRun:

command_obj = self.deployer.spm.get(pid)
content = handle_timeout(
tfp_runner_attribute, command_obj, self.deployer.file_read_timeout
attribute_file_fd, command_obj, self.deployer.file_read_timeout
)

if command_obj.process.returncode == 0:
Expand Down
12 changes: 4 additions & 8 deletions metaflow/runner/deployer_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
import json
import os
import sys
import tempfile

from typing import Any, ClassVar, Dict, Optional, TYPE_CHECKING, Type

from .subprocess_manager import SubprocessManager
from .utils import get_lower_level_group, handle_timeout
from .utils import get_lower_level_group, handle_timeout, temporary_fifo

if TYPE_CHECKING:
import metaflow.runner.deployer
Expand Down Expand Up @@ -121,14 +120,11 @@ def create(self, **kwargs) -> "metaflow.runner.deployer.DeployedFlow":
def _create(
self, create_class: Type["metaflow.runner.deployer.DeployedFlow"], **kwargs
) -> "metaflow.runner.deployer.DeployedFlow":
with tempfile.TemporaryDirectory() as temp_dir:
tfp_runner_attribute = tempfile.NamedTemporaryFile(
dir=temp_dir, delete=False
)
with temporary_fifo() as (attribute_file_path, attribute_file_fd):
# every subclass needs to have `self.deployer_kwargs`
command = get_lower_level_group(
self.api, self.top_level_kwargs, self.TYPE, self.deployer_kwargs
).create(deployer_attribute_file=tfp_runner_attribute.name, **kwargs)
).create(deployer_attribute_file=attribute_file_path, **kwargs)

pid = self.spm.run_command(
[sys.executable, *command],
Expand All @@ -139,7 +135,7 @@ def _create(

command_obj = self.spm.get(pid)
content = handle_timeout(
tfp_runner_attribute, command_obj, self.file_read_timeout
attribute_file_fd, command_obj, self.file_read_timeout
)
content = json.loads(content)
self.name = content.get("name")
Expand Down
62 changes: 33 additions & 29 deletions metaflow/runner/metaflow_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@
import os
import sys
import json
import tempfile

from typing import Dict, Iterator, Optional, Tuple

from metaflow import Run

from .utils import handle_timeout
from .utils import (
temporary_fifo,
handle_timeout,
async_handle_timeout,
)
from .subprocess_manager import CommandManager, SubprocessManager


Expand Down Expand Up @@ -267,9 +270,22 @@ def __enter__(self) -> "Runner":
async def __aenter__(self) -> "Runner":
return self

def __get_executing_run(self, attribute_file_fd, command_obj):
    """Build an `ExecutingRun` for a subprocess that was just launched.

    Blocks (up to `self.file_read_timeout`) until the launched run writes
    its attributes to the FIFO behind *attribute_file_fd*, then resolves
    the corresponding `Run` object.

    Parameters
    ----------
    attribute_file_fd : int
        File descriptor of the FIFO the subprocess writes its
        runner-attribute JSON to.
    command_obj : CommandManager
        Handle of the subprocess executing the run.

    Returns
    -------
    ExecutingRun
        Wrapper around this runner, the subprocess and the resolved Run.
    """
    # Wait for the subprocess to publish its attributes (or time out).
    content = handle_timeout(attribute_file_fd, command_obj, self.file_read_timeout)
    content = json.loads(content)
    pathspec = "%s/%s" % (content.get("flow_name"), content.get("run_id"))

    # Set the correct metadata from the runner_attribute file corresponding to this run.
    metadata_for_flow = content.get("metadata")

    run_object = Run(
        pathspec, _namespace_check=False, _current_metadata=metadata_for_flow
    )
    return ExecutingRun(self, command_obj, run_object)

async def __async_get_executing_run(self, attribute_file_fd, command_obj):
content = await async_handle_timeout(
attribute_file_fd, command_obj, self.file_read_timeout
)
content = json.loads(content)
pathspec = "%s/%s" % (content.get("flow_name"), content.get("run_id"))
Expand Down Expand Up @@ -298,12 +314,9 @@ def run(self, **kwargs) -> ExecutingRun:
ExecutingRun
ExecutingRun containing the results of the run.
"""
with tempfile.TemporaryDirectory() as temp_dir:
tfp_runner_attribute = tempfile.NamedTemporaryFile(
dir=temp_dir, delete=False
)
with temporary_fifo() as (attribute_file_path, attribute_file_fd):
command = self.api(**self.top_level_kwargs).run(
runner_attribute_file=tfp_runner_attribute.name, **kwargs
runner_attribute_file=attribute_file_path, **kwargs
)

pid = self.spm.run_command(
Expand All @@ -314,7 +327,7 @@ def run(self, **kwargs) -> ExecutingRun:
)
command_obj = self.spm.get(pid)

return self.__get_executing_run(tfp_runner_attribute, command_obj)
return self.__get_executing_run(attribute_file_fd, command_obj)

def resume(self, **kwargs):
"""
Expand All @@ -332,12 +345,9 @@ def resume(self, **kwargs):
ExecutingRun
ExecutingRun containing the results of the resumed run.
"""
with tempfile.TemporaryDirectory() as temp_dir:
tfp_runner_attribute = tempfile.NamedTemporaryFile(
dir=temp_dir, delete=False
)
with temporary_fifo() as (attribute_file_path, attribute_file_fd):
command = self.api(**self.top_level_kwargs).resume(
runner_attribute_file=tfp_runner_attribute.name, **kwargs
runner_attribute_file=attribute_file_path, **kwargs
)

pid = self.spm.run_command(
Expand All @@ -348,7 +358,7 @@ def resume(self, **kwargs):
)
command_obj = self.spm.get(pid)

return self.__get_executing_run(tfp_runner_attribute, command_obj)
return self.__get_executing_run(attribute_file_fd, command_obj)

async def async_run(self, **kwargs) -> ExecutingRun:
"""
Expand All @@ -368,12 +378,9 @@ async def async_run(self, **kwargs) -> ExecutingRun:
ExecutingRun
ExecutingRun representing the run that was started.
"""
with tempfile.TemporaryDirectory() as temp_dir:
tfp_runner_attribute = tempfile.NamedTemporaryFile(
dir=temp_dir, delete=False
)
with temporary_fifo() as (attribute_file_path, attribute_file_fd):
command = self.api(**self.top_level_kwargs).run(
runner_attribute_file=tfp_runner_attribute.name, **kwargs
runner_attribute_file=attribute_file_path, **kwargs
)

pid = await self.spm.async_run_command(
Expand All @@ -383,7 +390,7 @@ async def async_run(self, **kwargs) -> ExecutingRun:
)
command_obj = self.spm.get(pid)

return self.__get_executing_run(tfp_runner_attribute, command_obj)
return await self.__async_get_executing_run(attribute_file_fd, command_obj)

async def async_resume(self, **kwargs):
"""
Expand All @@ -403,12 +410,9 @@ async def async_resume(self, **kwargs):
ExecutingRun
ExecutingRun representing the resumed run that was started.
"""
with tempfile.TemporaryDirectory() as temp_dir:
tfp_runner_attribute = tempfile.NamedTemporaryFile(
dir=temp_dir, delete=False
)
with temporary_fifo() as (attribute_file_path, attribute_file_fd):
command = self.api(**self.top_level_kwargs).resume(
runner_attribute_file=tfp_runner_attribute.name, **kwargs
runner_attribute_file=attribute_file_path, **kwargs
)

pid = await self.spm.async_run_command(
Expand All @@ -418,7 +422,7 @@ async def async_resume(self, **kwargs):
)
command_obj = self.spm.get(pid)

return self.__get_executing_run(tfp_runner_attribute, command_obj)
return await self.__async_get_executing_run(attribute_file_fd, command_obj)

def __exit__(self, exc_type, exc_value, traceback):
self.spm.cleanup()
Expand Down
67 changes: 58 additions & 9 deletions metaflow/runner/subprocess_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,61 @@
import threading
from typing import Callable, Dict, Iterator, List, Optional, Tuple

from .utils import check_process_exited

def kill_processes_and_descendants(pids: List[str], termination_timeout: float):
    """Best-effort terminate, then kill, *pids* and their direct children.

    Sends SIGTERM to every process in *pids* and to their children
    (via ``pkill -P``), waits *termination_timeout* seconds for a graceful
    shutdown, then escalates to SIGKILL the same way.

    Parameters
    ----------
    pids : List[str]
        Process ids (already stringified) to signal; must be non-empty,
        since ``kill``/``pkill`` require at least one argument.
    termination_timeout : float
        Seconds to wait between SIGTERM and SIGKILL.
    """
    # TODO: there's a race condition that new descendants might
    # spawn b/w the invocations of 'pkill' and 'kill'.
    # Needs to be fixed in future.
    try:
        # Children first, then the processes themselves. A non-zero exit
        # (e.g. no matching process) is expected for a best-effort kill.
        subprocess.check_call(["pkill", "-TERM", "-P", *pids])
        subprocess.check_call(["kill", "-TERM", *pids])
    except subprocess.CalledProcessError:
        pass

    time.sleep(termination_timeout)

    try:
        subprocess.check_call(["pkill", "-KILL", "-P", *pids])
        subprocess.check_call(["kill", "-KILL", *pids])
    except subprocess.CalledProcessError:
        pass


async def _async_best_effort_exec(*cmd: str):
    """Run *cmd* as a subprocess and wait for it, swallowing every error.

    Used for best-effort signalling: a missing binary, an already-dead
    pid or a non-zero exit status must not propagate to the caller.
    """
    try:
        process = await asyncio.create_subprocess_exec(*cmd)
        await process.wait()
    except Exception:
        pass


async def async_kill_processes_and_descendants(
    pids: List[str], termination_timeout: float
):
    """Async best-effort terminate, then kill, *pids* and their children.

    Sends SIGTERM to every process in *pids* and to their children
    (via ``pkill -P``), waits *termination_timeout* seconds, then
    escalates to SIGKILL the same way. Each signalling step is
    independent: failure of one does not skip the others.

    Parameters
    ----------
    pids : List[str]
        Process ids (already stringified) to signal; must be non-empty,
        since ``kill``/``pkill`` require at least one argument.
    termination_timeout : float
        Seconds to wait between SIGTERM and SIGKILL.
    """
    # TODO: there's a race condition that new descendants might
    # spawn b/w the invocations of 'pkill' and 'kill'.
    # Needs to be fixed in future.
    await _async_best_effort_exec("pkill", "-TERM", "-P", *pids)
    await _async_best_effort_exec("kill", "-TERM", *pids)

    await asyncio.sleep(termination_timeout)

    await _async_best_effort_exec("pkill", "-KILL", "-P", *pids)
    await _async_best_effort_exec("kill", "-KILL", *pids)


class LogReadTimeoutError(Exception):
    """Exception raised when reading logs times out.

    Raised by log-tailing helpers when no log line becomes available
    within the allotted timeout.
    """

Expand All @@ -46,14 +81,28 @@ def __init__(self):
loop = asyncio.get_running_loop()
loop.add_signal_handler(
signal.SIGINT,
lambda: self._handle_sigint(signum=signal.SIGINT, frame=None),
lambda: asyncio.create_task(self._async_handle_sigint()),
)
except RuntimeError:
signal.signal(signal.SIGINT, self._handle_sigint)

async def _async_handle_sigint(self):
    """Async SIGINT handler: kill all still-live managed subprocesses.

    Scheduled as a task by the event loop's signal handler; terminates
    every tracked process (and its descendants) that has started and
    not yet exited.
    """
    # Only signal processes that actually started and are still running.
    pids = [
        str(command.process.pid)
        for command in self.commands.values()
        if command.process and not check_process_exited(command)
    ]
    # Guard against an empty pid list: 'kill'/'pkill' need >= 1 argument.
    if pids:
        await async_kill_processes_and_descendants(pids, termination_timeout=2)

def _handle_sigint(self, signum, frame):
    """Sync SIGINT handler: kill all still-live managed subprocesses.

    Installed via `signal.signal` when no event loop is running;
    terminates every tracked process (and its descendants) that has
    started and not yet exited, as a single grouped kill.

    Parameters
    ----------
    signum, frame
        Standard `signal` handler arguments; unused here.
    """
    # Only signal processes that actually started and are still running.
    pids = [
        str(command.process.pid)
        for command in self.commands.values()
        if command.process and not check_process_exited(command)
    ]
    # Guard against an empty pid list: 'kill'/'pkill' need >= 1 argument.
    if pids:
        kill_processes_and_descendants(pids, termination_timeout=2)

async def __aenter__(self) -> "SubprocessManager":
return self
Expand Down Expand Up @@ -472,7 +521,7 @@ def kill(self, termination_timeout: float = 2):
"""

if self.process is not None:
kill_process_and_descendants(self.process.pid, termination_timeout)
kill_processes_and_descendants([str(self.process.pid)], termination_timeout)
else:
print("No process to kill.")

Expand Down
Loading

0 comments on commit 286f9ac

Please sign in to comment.