Skip to content

Commit

Permalink
Use fifo for attribute file
Browse files Browse the repository at this point in the history
  • Loading branch information
filipcacky committed Oct 22, 2024
1 parent fcb652d commit e850e68
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 116 deletions.
10 changes: 4 additions & 6 deletions metaflow/plugins/argo/argo_workflows_deployer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import sys
import tempfile
from typing import Optional, ClassVar

from metaflow.plugins.argo.argo_workflows import ArgoWorkflows
Expand All @@ -9,6 +8,7 @@
TriggeredRun,
get_lower_level_group,
handle_timeout,
temporary_fifo,
)


Expand Down Expand Up @@ -207,16 +207,14 @@ def trigger(instance: DeployedFlow, **kwargs):
Exception
If there is an error during the trigger process.
"""
with tempfile.TemporaryDirectory() as temp_dir:
tfp_runner_attribute = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False)

with temporary_fifo() as (attribute_file_path, attribute_file_fd):
# every subclass needs to have `self.deployer_kwargs`
command = get_lower_level_group(
instance.deployer.api,
instance.deployer.top_level_kwargs,
instance.deployer.TYPE,
instance.deployer.deployer_kwargs,
).trigger(deployer_attribute_file=tfp_runner_attribute.name, **kwargs)
).trigger(deployer_attribute_file=attribute_file_path, **kwargs)

pid = instance.deployer.spm.run_command(
[sys.executable, *command],
Expand All @@ -227,7 +225,7 @@ def trigger(instance: DeployedFlow, **kwargs):

command_obj = instance.deployer.spm.get(pid)
content = handle_timeout(
tfp_runner_attribute, command_obj, instance.deployer.file_read_timeout
attribute_file_fd, command_obj, instance.deployer.file_read_timeout
)

if command_obj.process.returncode == 0:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
TriggeredRun,
get_lower_level_group,
handle_timeout,
temporary_fifo,
)


Expand Down Expand Up @@ -174,16 +175,14 @@ def trigger(instance: DeployedFlow, **kwargs):
Exception
If there is an error during the trigger process.
"""
with tempfile.TemporaryDirectory() as temp_dir:
tfp_runner_attribute = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False)

with temporary_fifo() as (attribute_file_path, attribute_file_fd):
# every subclass needs to have `self.deployer_kwargs`
command = get_lower_level_group(
instance.deployer.api,
instance.deployer.top_level_kwargs,
instance.deployer.TYPE,
instance.deployer.deployer_kwargs,
).trigger(deployer_attribute_file=tfp_runner_attribute.name, **kwargs)
).trigger(deployer_attribute_file=attribute_file_path, **kwargs)

pid = instance.deployer.spm.run_command(
[sys.executable, *command],
Expand All @@ -194,7 +193,7 @@ def trigger(instance: DeployedFlow, **kwargs):

command_obj = instance.deployer.spm.get(pid)
content = handle_timeout(
tfp_runner_attribute, command_obj, instance.deployer.file_read_timeout
attribute_file_fd, command_obj, instance.deployer.file_read_timeout
)

if command_obj.process.returncode == 0:
Expand Down
12 changes: 4 additions & 8 deletions metaflow/runner/deployer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
import time
import importlib
import functools
import tempfile

from typing import Optional, Dict, ClassVar

from metaflow.exception import MetaflowNotFound
from metaflow.runner.subprocess_manager import SubprocessManager
from metaflow.runner.utils import handle_timeout
from metaflow.runner.utils import handle_timeout, temporary_fifo


def get_lower_level_group(
Expand Down Expand Up @@ -332,14 +331,11 @@ def create(self, **kwargs) -> DeployedFlow:
Exception
If there is an error during deployment.
"""
with tempfile.TemporaryDirectory() as temp_dir:
tfp_runner_attribute = tempfile.NamedTemporaryFile(
dir=temp_dir, delete=False
)
with temporary_fifo() as (attribute_file_path, attribute_file_fd):
# every subclass needs to have `self.deployer_kwargs`
command = get_lower_level_group(
self.api, self.top_level_kwargs, self.TYPE, self.deployer_kwargs
).create(deployer_attribute_file=tfp_runner_attribute.name, **kwargs)
).create(deployer_attribute_file=attribute_file_path, **kwargs)

pid = self.spm.run_command(
[sys.executable, *command],
Expand All @@ -350,7 +346,7 @@ def create(self, **kwargs) -> DeployedFlow:

command_obj = self.spm.get(pid)
content = handle_timeout(
tfp_runner_attribute, command_obj, self.file_read_timeout
attribute_file_fd, command_obj, self.file_read_timeout
)
content = json.loads(content)
self.name = content.get("name")
Expand Down
57 changes: 21 additions & 36 deletions metaflow/runner/metaflow_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@
import os
import sys
import json
import tempfile

from typing import Dict, Iterator, Optional, Tuple

from metaflow import Run, metadata

from .utils import handle_timeout, async_handle_timeout
from .utils import (
temporary_fifo,
handle_timeout,
async_handle_timeout,
)
from .subprocess_manager import CommandManager, SubprocessManager


Expand Down Expand Up @@ -266,10 +269,8 @@ def __enter__(self) -> "Runner":
async def __aenter__(self) -> "Runner":
return self

def __get_executing_run(self, tfp_runner_attribute, command_obj):
content = handle_timeout(
tfp_runner_attribute, command_obj, self.file_read_timeout
)
def __get_executing_run(self, attribute_file_fd, command_obj):
content = handle_timeout(attribute_file_fd, command_obj, self.file_read_timeout)
content = json.loads(content)
pathspec = "%s/%s" % (content.get("flow_name"), content.get("run_id"))

Expand All @@ -280,9 +281,9 @@ def __get_executing_run(self, tfp_runner_attribute, command_obj):
run_object = Run(pathspec, _namespace_check=False)
return ExecutingRun(self, command_obj, run_object)

async def __async_get_executing_run(self, tfp_runner_attribute, command_obj):
async def __async_get_executing_run(self, attribute_file_fd, command_obj):
content = await async_handle_timeout(
tfp_runner_attribute, command_obj, self.file_read_timeout
attribute_file_fd, command_obj, self.file_read_timeout
)
content = json.loads(content)
pathspec = "%s/%s" % (content.get("flow_name"), content.get("run_id"))
Expand Down Expand Up @@ -310,12 +311,9 @@ def run(self, **kwargs) -> ExecutingRun:
ExecutingRun
ExecutingRun containing the results of the run.
"""
with tempfile.TemporaryDirectory() as temp_dir:
tfp_runner_attribute = tempfile.NamedTemporaryFile(
dir=temp_dir, delete=False
)
with temporary_fifo() as (attribute_file_path, attribute_file_fd):
command = self.api(**self.top_level_kwargs).run(
runner_attribute_file=tfp_runner_attribute.name, **kwargs
runner_attribute_file=attribute_file_path, **kwargs
)

pid = self.spm.run_command(
Expand All @@ -326,7 +324,7 @@ def run(self, **kwargs) -> ExecutingRun:
)
command_obj = self.spm.get(pid)

return self.__get_executing_run(tfp_runner_attribute, command_obj)
return self.__get_executing_run(attribute_file_fd, command_obj)

def resume(self, **kwargs):
"""
Expand All @@ -344,12 +342,9 @@ def resume(self, **kwargs):
ExecutingRun
ExecutingRun containing the results of the resumed run.
"""
with tempfile.TemporaryDirectory() as temp_dir:
tfp_runner_attribute = tempfile.NamedTemporaryFile(
dir=temp_dir, delete=False
)
with temporary_fifo() as (attribute_file_path, attribute_file_fd):
command = self.api(**self.top_level_kwargs).resume(
runner_attribute_file=tfp_runner_attribute.name, **kwargs
runner_attribute_file=attribute_file_path, **kwargs
)

pid = self.spm.run_command(
Expand All @@ -360,7 +355,7 @@ def resume(self, **kwargs):
)
command_obj = self.spm.get(pid)

return self.__get_executing_run(tfp_runner_attribute, command_obj)
return self.__get_executing_run(attribute_file_fd, command_obj)

async def async_run(self, **kwargs) -> ExecutingRun:
"""
Expand All @@ -380,12 +375,9 @@ async def async_run(self, **kwargs) -> ExecutingRun:
ExecutingRun
ExecutingRun representing the run that was started.
"""
with tempfile.TemporaryDirectory() as temp_dir:
tfp_runner_attribute = tempfile.NamedTemporaryFile(
dir=temp_dir, delete=False
)
with temporary_fifo() as (attribute_file_path, attribute_file_fd):
command = self.api(**self.top_level_kwargs).run(
runner_attribute_file=tfp_runner_attribute.name, **kwargs
runner_attribute_file=attribute_file_path, **kwargs
)

pid = await self.spm.async_run_command(
Expand All @@ -395,9 +387,7 @@ async def async_run(self, **kwargs) -> ExecutingRun:
)
command_obj = self.spm.get(pid)

return await self.__async_get_executing_run(
tfp_runner_attribute, command_obj
)
return await self.__async_get_executing_run(attribute_file_fd, command_obj)

async def async_resume(self, **kwargs):
"""
Expand All @@ -417,12 +407,9 @@ async def async_resume(self, **kwargs):
ExecutingRun
ExecutingRun representing the resumed run that was started.
"""
with tempfile.TemporaryDirectory() as temp_dir:
tfp_runner_attribute = tempfile.NamedTemporaryFile(
dir=temp_dir, delete=False
)
with temporary_fifo() as (attribute_file_path, attribute_file_fd):
command = self.api(**self.top_level_kwargs).resume(
runner_attribute_file=tfp_runner_attribute.name, **kwargs
runner_attribute_file=attribute_file_path, **kwargs
)

pid = await self.spm.async_run_command(
Expand All @@ -432,9 +419,7 @@ async def async_resume(self, **kwargs):
)
command_obj = self.spm.get(pid)

return await self.__async_get_executing_run(
tfp_runner_attribute, command_obj
)
return await self.__async_get_executing_run(attribute_file_fd, command_obj)

def __exit__(self, exc_type, exc_value, traceback):
self.spm.cleanup()
Expand Down
Loading

0 comments on commit e850e68

Please sign in to comment.