Skip to content

Commit

Permalink
Use fifo for attribute file
Browse files Browse the repository at this point in the history
  • Loading branch information
filipcacky committed Oct 22, 2024
1 parent 1286f0f commit e48db92
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 115 deletions.
10 changes: 4 additions & 6 deletions metaflow/plugins/argo/argo_workflows_deployer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import sys
import tempfile
from typing import Optional, ClassVar

from metaflow.plugins.argo.argo_workflows import ArgoWorkflows
Expand All @@ -9,6 +8,7 @@
TriggeredRun,
get_lower_level_group,
handle_timeout,
temporary_fifo,
)


Expand Down Expand Up @@ -207,16 +207,14 @@ def trigger(instance: DeployedFlow, **kwargs):
Exception
If there is an error during the trigger process.
"""
with tempfile.TemporaryDirectory() as temp_dir:
tfp_runner_attribute = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False)

with temporary_fifo() as (attribute_file_path, attribute_file_fd):
# every subclass needs to have `self.deployer_kwargs`
command = get_lower_level_group(
instance.deployer.api,
instance.deployer.top_level_kwargs,
instance.deployer.TYPE,
instance.deployer.deployer_kwargs,
).trigger(deployer_attribute_file=tfp_runner_attribute.name, **kwargs)
).trigger(deployer_attribute_file=attribute_file_path, **kwargs)

pid = instance.deployer.spm.run_command(
[sys.executable, *command],
Expand All @@ -227,7 +225,7 @@ def trigger(instance: DeployedFlow, **kwargs):

command_obj = instance.deployer.spm.get(pid)
content = handle_timeout(
tfp_runner_attribute, command_obj, instance.deployer.file_read_timeout
attribute_file_fd, command_obj, instance.deployer.file_read_timeout
)

if command_obj.process.returncode == 0:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
TriggeredRun,
get_lower_level_group,
handle_timeout,
temporary_fifo,
)


Expand Down Expand Up @@ -174,16 +175,14 @@ def trigger(instance: DeployedFlow, **kwargs):
Exception
If there is an error during the trigger process.
"""
with tempfile.TemporaryDirectory() as temp_dir:
tfp_runner_attribute = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False)

with temporary_fifo() as (attribute_file_path, attribute_file_fd):
# every subclass needs to have `self.deployer_kwargs`
command = get_lower_level_group(
instance.deployer.api,
instance.deployer.top_level_kwargs,
instance.deployer.TYPE,
instance.deployer.deployer_kwargs,
).trigger(deployer_attribute_file=tfp_runner_attribute.name, **kwargs)
).trigger(deployer_attribute_file=attribute_file_path, **kwargs)

pid = instance.deployer.spm.run_command(
[sys.executable, *command],
Expand All @@ -194,7 +193,7 @@ def trigger(instance: DeployedFlow, **kwargs):

command_obj = instance.deployer.spm.get(pid)
content = handle_timeout(
tfp_runner_attribute, command_obj, instance.deployer.file_read_timeout
attribute_file_fd, command_obj, instance.deployer.file_read_timeout
)

if command_obj.process.returncode == 0:
Expand Down
12 changes: 4 additions & 8 deletions metaflow/runner/deployer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
import time
import importlib
import functools
import tempfile

from typing import Optional, Dict, ClassVar

from metaflow.exception import MetaflowNotFound
from metaflow.runner.subprocess_manager import SubprocessManager
from metaflow.runner.utils import handle_timeout
from metaflow.runner.utils import handle_timeout, temporary_fifo


def get_lower_level_group(
Expand Down Expand Up @@ -332,14 +331,11 @@ def create(self, **kwargs) -> DeployedFlow:
Exception
If there is an error during deployment.
"""
with tempfile.TemporaryDirectory() as temp_dir:
tfp_runner_attribute = tempfile.NamedTemporaryFile(
dir=temp_dir, delete=False
)
with temporary_fifo() as (attribute_file_path, attribute_file_fd):
# every subclass needs to have `self.deployer_kwargs`
command = get_lower_level_group(
self.api, self.top_level_kwargs, self.TYPE, self.deployer_kwargs
).create(deployer_attribute_file=tfp_runner_attribute.name, **kwargs)
).create(deployer_attribute_file=attribute_file_path, **kwargs)

pid = self.spm.run_command(
[sys.executable, *command],
Expand All @@ -350,7 +346,7 @@ def create(self, **kwargs) -> DeployedFlow:

command_obj = self.spm.get(pid)
content = handle_timeout(
tfp_runner_attribute, command_obj, self.file_read_timeout
attribute_file_fd, command_obj, self.file_read_timeout
)
content = json.loads(content)
self.name = content.get("name")
Expand Down
58 changes: 22 additions & 36 deletions metaflow/runner/metaflow_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@
import os
import sys
import json
import tempfile

from typing import Dict, Iterator, Optional, Tuple

from metaflow import Run, metadata

from .utils import handle_timeout, async_handle_timeout, clear_and_set_os_environ
from .utils import (
clear_and_set_os_environ,
temporary_fifo,
handle_timeout,
async_handle_timeout,
)
from .subprocess_manager import CommandManager, SubprocessManager


Expand Down Expand Up @@ -282,10 +286,8 @@ def __restore_env_and_metadata(self, content):
metadata_for_flow = content.get("metadata")
metadata(metadata_for_flow)

def __get_executing_run(self, tfp_runner_attribute, command_obj):
content = handle_timeout(
tfp_runner_attribute, command_obj, self.file_read_timeout
)
def __get_executing_run(self, attribute_file_fd, command_obj):
content = handle_timeout(attribute_file_fd, command_obj, self.file_read_timeout)
content = json.loads(content)

self.__restore_env_and_metadata(content)
Expand All @@ -294,9 +296,9 @@ def __get_executing_run(self, tfp_runner_attribute, command_obj):
run_object = Run(pathspec, _namespace_check=False)
return ExecutingRun(self, command_obj, run_object)

async def __async_get_executing_run(self, tfp_runner_attribute, command_obj):
async def __async_get_executing_run(self, attribute_file_fd, command_obj):
content = await async_handle_timeout(
tfp_runner_attribute, command_obj, self.file_read_timeout
attribute_file_fd, command_obj, self.file_read_timeout
)
content = json.loads(content)

Expand All @@ -322,12 +324,9 @@ def run(self, **kwargs) -> ExecutingRun:
ExecutingRun
ExecutingRun containing the results of the run.
"""
with tempfile.TemporaryDirectory() as temp_dir:
tfp_runner_attribute = tempfile.NamedTemporaryFile(
dir=temp_dir, delete=False
)
with temporary_fifo() as (attribute_file_path, attribute_file_fd):
command = self.api(**self.top_level_kwargs).run(
runner_attribute_file=tfp_runner_attribute.name, **kwargs
runner_attribute_file=attribute_file_path, **kwargs
)

pid = self.spm.run_command(
Expand All @@ -338,7 +337,7 @@ def run(self, **kwargs) -> ExecutingRun:
)
command_obj = self.spm.get(pid)

return self.__get_executing_run(tfp_runner_attribute, command_obj)
return self.__get_executing_run(attribute_file_fd, command_obj)

def resume(self, **kwargs):
"""
Expand All @@ -356,12 +355,9 @@ def resume(self, **kwargs):
ExecutingRun
ExecutingRun containing the results of the resumed run.
"""
with tempfile.TemporaryDirectory() as temp_dir:
tfp_runner_attribute = tempfile.NamedTemporaryFile(
dir=temp_dir, delete=False
)
with temporary_fifo() as (attribute_file_path, attribute_file_fd):
command = self.api(**self.top_level_kwargs).resume(
runner_attribute_file=tfp_runner_attribute.name, **kwargs
runner_attribute_file=attribute_file_path, **kwargs
)

pid = self.spm.run_command(
Expand All @@ -372,7 +368,7 @@ def resume(self, **kwargs):
)
command_obj = self.spm.get(pid)

return self.__get_executing_run(tfp_runner_attribute, command_obj)
return self.__get_executing_run(attribute_file_fd, command_obj)

async def async_run(self, **kwargs) -> ExecutingRun:
"""
Expand All @@ -392,12 +388,9 @@ async def async_run(self, **kwargs) -> ExecutingRun:
ExecutingRun
ExecutingRun representing the run that was started.
"""
with tempfile.TemporaryDirectory() as temp_dir:
tfp_runner_attribute = tempfile.NamedTemporaryFile(
dir=temp_dir, delete=False
)
with temporary_fifo() as (attribute_file_path, attribute_file_fd):
command = self.api(**self.top_level_kwargs).run(
runner_attribute_file=tfp_runner_attribute.name, **kwargs
runner_attribute_file=attribute_file_path, **kwargs
)

pid = await self.spm.async_run_command(
Expand All @@ -407,9 +400,7 @@ async def async_run(self, **kwargs) -> ExecutingRun:
)
command_obj = self.spm.get(pid)

return await self.__async_get_executing_run(
tfp_runner_attribute, command_obj
)
return await self.__async_get_executing_run(attribute_file_fd, command_obj)

async def async_resume(self, **kwargs):
"""
Expand All @@ -429,12 +420,9 @@ async def async_resume(self, **kwargs):
ExecutingRun
ExecutingRun representing the resumed run that was started.
"""
with tempfile.TemporaryDirectory() as temp_dir:
tfp_runner_attribute = tempfile.NamedTemporaryFile(
dir=temp_dir, delete=False
)
with temporary_fifo() as (attribute_file_path, attribute_file_fd):
command = self.api(**self.top_level_kwargs).resume(
runner_attribute_file=tfp_runner_attribute.name, **kwargs
runner_attribute_file=attribute_file_path, **kwargs
)

pid = await self.spm.async_run_command(
Expand All @@ -444,9 +432,7 @@ async def async_resume(self, **kwargs):
)
command_obj = self.spm.get(pid)

return await self.__async_get_executing_run(
tfp_runner_attribute, command_obj
)
return await self.__async_get_executing_run(attribute_file_fd, command_obj)

def __exit__(self, exc_type, exc_value, traceback):
self.spm.cleanup()
Expand Down
Loading

0 comments on commit e48db92

Please sign in to comment.