Skip to content

Commit

Permalink
Make activate script that runs on compute node
Browse files Browse the repository at this point in the history
Without this we rely on some external mechanism to make sure we have
the required environment on the compute cluster, for example the user
sourcing their vitrualenvironment in their bashrc/cshrc file.
  • Loading branch information
oyvindeide committed Nov 27, 2024
1 parent 1d62a28 commit 17e775f
Show file tree
Hide file tree
Showing 15 changed files with 133 additions and 15 deletions.
9 changes: 9 additions & 0 deletions src/ert/config/ert_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
ForwardModelStepKeys,
HistorySource,
HookRuntime,
QueueSystemWithGeneric,
init_forward_model_schema,
init_site_config_schema,
init_user_config_schema,
Expand Down Expand Up @@ -260,6 +261,7 @@ class ErtConfig:
DEFAULT_RUNPATH_FILE: ClassVar[str] = ".ert_runpath_list"
PREINSTALLED_FORWARD_MODEL_STEPS: ClassVar[Dict[str, ForwardModelStep]] = {}
ENV_PR_FM_STEP: ClassVar[Dict[str, Dict[str, Any]]] = {}
ACTIVATE_SCRIPT: Optional[str] = None

substitutions: Substitutions = field(default_factory=Substitutions)
ensemble_config: EnsembleConfig = field(default_factory=EnsembleConfig)
Expand Down Expand Up @@ -347,6 +349,7 @@ class ErtConfigWithPlugins(ErtConfig):
Dict[str, ForwardModelStepPlugin]
] = preinstalled_fm_steps
ENV_PR_FM_STEP: ClassVar[Dict[str, Dict[str, Any]]] = env_pr_fm_step
ACTIVATE_SCRIPT = ErtPluginManager().activate_script()

assert issubclass(ErtConfigWithPlugins, ErtConfig)
return ErtConfigWithPlugins
Expand Down Expand Up @@ -675,6 +678,12 @@ def _merge_user_and_site_config(
user_config_dict[keyword] = value + original_entries
elif keyword not in user_config_dict:
user_config_dict[keyword] = value
if cls.ACTIVATE_SCRIPT:
if "QUEUE_OPTION" not in user_config_dict:
user_config_dict["QUEUE_OPTION"] = []
user_config_dict["QUEUE_OPTION"].append(
[QueueSystemWithGeneric.GENERIC, "ACTIVATE_SCRIPT", cls.ACTIVATE_SCRIPT]
)
return user_config_dict

@classmethod
Expand Down
10 changes: 9 additions & 1 deletion src/ert/config/queue_config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
import os
import re
import shutil
from abc import abstractmethod
Expand All @@ -26,12 +27,20 @@
NonEmptyString = Annotated[str, pydantic.StringConstraints(min_length=1)]


def activate_script() -> str:
venv = os.environ.get("VIRTUAL_ENV")
if not venv:
return ""
return f"source {venv}/bin/activate"


@pydantic.dataclasses.dataclass(config={"extra": "forbid", "validate_assignment": True})
class QueueOptions:
name: str
max_running: pydantic.NonNegativeInt = 0
submit_sleep: pydantic.NonNegativeFloat = 0.0
project_code: Optional[str] = None
activate_script: str = field(default_factory=activate_script)

@staticmethod
def create_queue_options(
Expand Down Expand Up @@ -292,7 +301,6 @@ def from_dict(cls, config_dict: ConfigDict) -> QueueConfig:
_grouped_queue_options = _group_queue_options_by_queue_system(
_raw_queue_options
)

_log_duplicated_queue_options(_raw_queue_options)
_raise_for_defaulted_invalid_options(_raw_queue_options)

Expand Down
1 change: 1 addition & 0 deletions src/ert/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def inner(*args: P.args, **kwargs: P.kwargs) -> Any:
"flow_config_path",
"help_links",
"site_config_lines",
"activate_script",
]
and res is not None
):
Expand Down
2 changes: 2 additions & 0 deletions src/ert/plugins/hook_specifications/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .activate_script import activate_script
from .ecl_config import (
ecl100_config_path,
ecl300_config_path,
Expand All @@ -21,6 +22,7 @@
from .site_config import site_config_lines

__all__ = [
"activate_script",
"add_log_handle_to_root",
"add_span_processor",
"ecl100_config_path",
Expand Down
19 changes: 19 additions & 0 deletions src/ert/plugins/hook_specifications/activate_script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from ert.plugins.plugin_manager import hook_specification


@hook_specification
def activate_script() -> str: # type: ignore
"""
Allows the plugin to provide a script that will be run when
the driver submits to the cluster. The script will run in
bash.
Example:
import ert
@ert.plugin(name="my_plugin")
def activate_script():
return "source /private/venv/my_env/bin/activate
:return: Activate script
"""
11 changes: 11 additions & 0 deletions src/ert/plugins/plugin_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,17 @@ def _site_config_lines(self) -> List[str]:
]
return list(chain.from_iterable(reversed(plugin_site_config_lines)))

def activate_script(self) -> str:
plugin_responses = self.hook.activate_script()
if not plugin_responses:
return ""
if len(plugin_responses) > 1:
raise ValueError(
f"Only one activate script is allowed, got {[plugin.plugin_metadata.plugin_name for plugin in plugin_responses]}"
)
else:
return plugin_responses[0].data

def get_installable_workflow_jobs(self) -> Dict[str, str]:
config_workflow_jobs = self._get_config_workflow_jobs()
hooked_workflow_jobs = self.get_ertscript_workflows().get_workflows()
Expand Down
8 changes: 6 additions & 2 deletions src/ert/scheduler/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@
"""Bash and other shells add an offset of 128 to the signal value when a process exited due to a signal"""


def create_submit_script(runpath: Path, executable: str, args: tuple[str, ...]) -> str:
def create_submit_script(
runpath: Path, executable: str, args: tuple[str, ...], activate_script: str
) -> str:
return (
"#!/usr/bin/env bash\n"
f"cd {shlex.quote(str(runpath))}\n"
f"{activate_script}\n"
f"exec -a {shlex.quote(executable)} {executable} {shlex.join(args)}\n"
)

Expand All @@ -28,9 +31,10 @@ class FailedSubmit(RuntimeError):
class Driver(ABC):
"""Adapter for the HPC cluster."""

def __init__(self, **kwargs: Dict[str, str]) -> None:
def __init__(self, activate_script: str = "") -> None:
self._event_queue: Optional[asyncio.Queue[Event]] = None
self._job_error_message_by_iens: Dict[int, str] = {}
self.activate_script = activate_script

@property
def event_queue(self) -> asyncio.Queue[Event]:
Expand Down
1 change: 1 addition & 0 deletions src/ert/scheduler/local_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ async def submit(
runpath: Optional[Path] = None,
num_cpu: Optional[int] = 1,
realization_memory: Optional[int] = 0,
activate_script: str = "",
) -> None:
self._tasks[iens] = asyncio.create_task(self._run(iens, executable, *args))
with suppress(KeyError):
Expand Down
5 changes: 3 additions & 2 deletions src/ert/scheduler/lsf_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,9 @@ def __init__(
bjobs_cmd: Optional[str] = None,
bkill_cmd: Optional[str] = None,
bhist_cmd: Optional[str] = None,
activate_script: str = "",
) -> None:
super().__init__()
super().__init__(activate_script)
self._queue_name = queue_name
self._project_code = project_code
self._resource_requirement = resource_requirement
Expand Down Expand Up @@ -309,7 +310,7 @@ async def submit(

arg_queue_name = ["-q", self._queue_name] if self._queue_name else []
arg_project_code = ["-P", self._project_code] if self._project_code else []
script = create_submit_script(runpath, executable, args)
script = create_submit_script(runpath, executable, args, self.activate_script)
script_path: Optional[Path] = None
try:
with NamedTemporaryFile(
Expand Down
5 changes: 3 additions & 2 deletions src/ert/scheduler/openpbs_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,9 @@ def __init__(
qsub_cmd: Optional[str] = None,
qstat_cmd: Optional[str] = None,
qdel_cmd: Optional[str] = None,
activate_script: str = "",
) -> None:
super().__init__()
super().__init__(activate_script)

self._queue_name = queue_name
self._project_code = project_code
Expand Down Expand Up @@ -241,7 +242,7 @@ async def submit(
[] if self._keep_qsub_output else ["-o", "/dev/null", "-e", "/dev/null"]
)

script = create_submit_script(runpath, executable, args)
script = create_submit_script(runpath, executable, args, self.activate_script)
name_prefix = self._job_prefix or ""
qsub_with_args: List[str] = [
str(self._qsub_cmd),
Expand Down
5 changes: 3 additions & 2 deletions src/ert/scheduler/slurm_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def __init__(
max_runtime: Optional[float] = None,
squeue_timeout: float = 2,
project_code: Optional[str] = None,
activate_script: str = "",
) -> None:
"""
The arguments "memory" and "realization_memory" are currently both
Expand All @@ -90,7 +91,7 @@ def __init__(
zero "realization memory" is the default and means no intended
memory allocation.
"""
super().__init__()
super().__init__(activate_script)
self._submit_locks: dict[int, asyncio.Lock] = {}
self._iens2jobid: dict[int, str] = {}
self._jobs: dict[str, JobData] = {}
Expand Down Expand Up @@ -181,7 +182,7 @@ async def submit(
if runpath is None:
runpath = Path.cwd()

script = create_submit_script(runpath, executable, args)
script = create_submit_script(runpath, executable, args, self.activate_script)
script_path: Optional[Path] = None
try:
with NamedTemporaryFile(
Expand Down
7 changes: 5 additions & 2 deletions src/everest/detached/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
SlurmQueueOptions,
TorqueQueueOptions,
)
from ert.plugins import ErtPluginManager
from ert.scheduler import create_driver
from ert.scheduler.driver import Driver, FailedSubmit
from ert.scheduler.event import StartedEvent
Expand Down Expand Up @@ -279,22 +280,24 @@ def get_server_queue_options(
simulator: Optional[SimulatorConfig],
server: Optional[ServerConfig],
) -> QueueOptions:
activate_script = ErtPluginManager().activate_script()
queue_system = _find_res_queue_system(simulator, server)
ever_queue_config = server if server is not None else simulator

if queue_system == QueueSystem.LSF:
queue = LsfQueueOptions(
activate_script=activate_script,
lsf_queue=ever_queue_config.name,
lsf_resource=ever_queue_config.options,
)
elif queue_system == QueueSystem.SLURM:
queue = SlurmQueueOptions(
activate_script=activate_script,
exclude_host=ever_queue_config.exclude_host,
include_host=ever_queue_config.include_host,
partition=ever_queue_config.name,
)
elif queue_system == QueueSystem.TORQUE:
queue = TorqueQueueOptions()
queue = TorqueQueueOptions(activate_script=activate_script)
elif queue_system == QueueSystem.LOCAL:
queue = LocalQueueOptions()
else:
Expand Down
13 changes: 13 additions & 0 deletions tests/ert/unit_tests/config/test_queue_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ert.config.queue_config import (
LocalQueueOptions,
LsfQueueOptions,
QueueOptions,
SlurmQueueOptions,
TorqueQueueOptions,
)
Expand Down Expand Up @@ -509,3 +510,15 @@ def test_driver_initialization_from_defaults(queue_system):
LocalDriver(**LocalQueueOptions().driver_options)
if queue_system == QueueSystem.SLURM:
SlurmDriver(**SlurmQueueOptions().driver_options)


@pytest.mark.parametrize(
"venv, expected", [("my_env", "source my_env/bin/activate"), (None, "")]
)
def test_default_activate_script_generation(expected, monkeypatch, venv):
if venv:
monkeypatch.setenv("VIRTUAL_ENV", venv)
else:
monkeypatch.delenv("VIRTUAL_ENV", raising=False)
options = QueueOptions(name="local")
assert options.activate_script == expected
48 changes: 45 additions & 3 deletions tests/ert/unit_tests/plugins/test_plugin_manager.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import json
import logging
import tempfile
from unittest.mock import Mock
from functools import partial
from unittest.mock import Mock, patch

import pytest
from opentelemetry.sdk.trace import TracerProvider

import ert.plugins.hook_implementations
from ert import plugin
from ert.plugins import ErtPluginManager
from ert.config import ErtConfig
from ert.plugins import ErtPluginManager, plugin
from tests.ert.unit_tests.plugins import dummy_plugins
from tests.ert.unit_tests.plugins.dummy_plugins import (
DummyFMStep,
Expand Down Expand Up @@ -279,3 +280,44 @@ def test_that_forward_model_step_is_registered(tmpdir):
with tmpdir.as_cwd():
pm = ErtPluginManager(plugins=[dummy_plugins])
assert pm.forward_model_steps == [DummyFMStep]


class ActivatePlugin:
@plugin(name="first")
def activate_script(self):
return "source something"


class AnotherActivatePlugin:
@plugin(name="second")
def activate_script(self):
return "Something"


class EmptyActivatePlugin:
@plugin(name="empty")
def activate_script(self):
return None


@pytest.mark.parametrize(
"plugins", [[ActivatePlugin()], [ActivatePlugin(), EmptyActivatePlugin()]]
)
def test_activate_script_hook(plugins):
pm = ErtPluginManager(plugins=plugins)
assert pm.activate_script() == "source something"


def test_multiple_activate_script_hook():
pm = ErtPluginManager(plugins=[ActivatePlugin(), AnotherActivatePlugin()])
with pytest.raises(ValueError, match="one activate script is allowed"):
pm.activate_script()


def test_activate_script_plugin_integration():
patched = partial(
ert.config.ert_config.ErtPluginManager, plugins=[ActivatePlugin()]
)
with patch("ert.config.ert_config.ErtPluginManager", patched):
config = ErtConfig.with_plugins().from_file_contents("NUM_REALIZATIONS 1\n")
assert config.queue_config.queue_options.activate_script == "source something"
4 changes: 3 additions & 1 deletion tests/everest/test_res_initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,9 @@ def test_snake_everest_to_ert_torque(copy_test_data_to_tmp):
qc = ert_config.queue_config
qo = qc.queue_options
assert qc.queue_system == "TORQUE"
assert {k: v for k, v in qo.driver_options.items() if v is not None} == {
driver_options = qo.driver_options
driver_options.pop("activate_script")
assert {k: v for k, v in driver_options.items() if v is not None} == {
"project_code": "snake_oil_pc",
"qsub_cmd": "qsub",
"qstat_cmd": "qstat",
Expand Down

0 comments on commit 17e775f

Please sign in to comment.