diff --git a/milabench/cli/cloud.py b/milabench/cli/cloud.py index 8d95a47d1..14a6eea1b 100644 --- a/milabench/cli/cloud.py +++ b/milabench/cli/cloud.py @@ -1,8 +1,8 @@ from copy import deepcopy import os -import socket import subprocess import sys +import warnings from coleo import Option, tooled from omegaconf import OmegaConf @@ -22,7 +22,7 @@ def _flatten_cli_args(**kwargs): return sum( ( - (f"--{str(k).replace('_', '-')}", *([str(v)] if str(v) else [])) + (f"--{str(k).replace('_', '-')}", *([str(v)] if v is not None else [])) for k, v in kwargs.items() ), () ) @@ -39,7 +39,7 @@ def manage_cloud(pack, run_on, action="setup"): "hostname":(lambda v: ("ip",v)), "username":(lambda v: ("user",v)), "ssh_key_file":(lambda v: ("key",v)), - "env":(lambda v: ("env",[".", v, ";", "conda", "activate", "milabench", "&&"])), + # "env":(lambda v: ("env",[".", v, ";", "conda", "activate", "milabench", "&&"])), } plan_params = deepcopy(pack.config["system"]["cloud_profiles"][run_on]) run_on, *profile = run_on.split("__") @@ -58,8 +58,9 @@ def manage_cloud(pack, run_on, action="setup"): plan_params["state_prefix"] = plan_params.get("state_prefix", default_state_prefix) plan_params["state_id"] = plan_params.get("state_id", default_state_id) plan_params["cluster_size"] = max(len(pack.config["system"]["nodes"]), i + 1) + plan_params["keep_alive"] = None - import milabench.cli.covalent as cv + import milabench.scripts.covalent as cv subprocess.run( [ @@ -106,12 +107,16 @@ def manage_cloud(pack, run_on, action="setup"): continue try: k, v = line_str.split("::>") - k, v = key_map[k](v) - if k == "ip" and n[k] != "1.1.1.1": - i, n = next(nodes) - n[k] = v except ValueError: - pass + continue + try: + k, v = key_map[k](v) + except KeyError: + warnings.warn(f"Ignoring invalid key received: {k}:{v}") + continue + if k == "ip" and n[k] != "1.1.1.1": + i, n = next(nodes) + n[k] = v _, stderr = p.communicate() stderr = stderr.decode("utf-8").strip() @@ -159,7 +164,6 @@ def _teardown(): if all: overrides = { "*": OmegaConf.to_object(OmegaConf.from_dotlist([ - f"system.cloud_profiles.{run_on}.state_prefix='*'", f"system.cloud_profiles.{run_on}.state_id='*'", ])) } diff --git a/milabench/common.py b/milabench/common.py index 8d37f59f0..f3bd698dc 100644 --- a/milabench/common.py +++ b/milabench/common.py @@ -335,7 +335,7 @@ def _push_reports(reports_repo, runs): "partial": "yellow", "failure": "red", } - import milabench.cli.badges as badges + import milabench.scripts.badges as badges _repo = git.repo.base.Repo(ROOT_FOLDER) try: diff --git a/milabench/cli/badges/__main__.py b/milabench/scripts/badges/__main__.py similarity index 100% rename from milabench/cli/badges/__main__.py rename to milabench/scripts/badges/__main__.py diff --git a/milabench/cli/badges/requirements.txt b/milabench/scripts/badges/requirements.txt similarity index 100% rename from milabench/cli/badges/requirements.txt rename to milabench/scripts/badges/requirements.txt diff --git a/milabench/cli/covalent/__main__.py b/milabench/scripts/covalent/__main__.py similarity index 67% rename from milabench/cli/covalent/__main__.py rename to milabench/scripts/covalent/__main__.py index eb602ee27..995cc856f 100644 --- a/milabench/cli/covalent/__main__.py +++ b/milabench/scripts/covalent/__main__.py @@ -1,5 +1,4 @@ import argparse -import asyncio import os import pathlib import subprocess @@ -17,17 +16,12 @@ def serve(*argv): def _get_executor_kwargs(args): return { **{k:v for k,v in vars(args).items() if k not in ("setup", "teardown")}, - **{"action":k for k,v in vars(args).items() if k in ("setup", "teardown") and v}, } def executor(executor_cls, args, *argv): import covalent as ct - executor:ct.executor.BaseExecutor = executor_cls( - **_get_executor_kwargs(args), - ) - def _popen(cmd, *args, _env=None, **kwargs): _env = _env if _env is not None else {} @@ -89,65 +83,43 @@ def _popen(cmd, *args, _env=None, **kwargs): ) return p.returncode, stdout, stderr - @ct.lattice - def lattice(argv=(), deps_bash = None): - return ct.electron( - _popen, - executor=executor, - deps_bash=deps_bash, - )( - argv, - ) - + executor:ct.executor.BaseExecutor = executor_cls( + **_get_executor_kwargs(args), + ) return_code = 0 try: - dispatch_id = None - result = None - deps_bash = None + if args.setup: + dispatch_id = ct.dispatch( + ct.lattice(executor.get_connection_attributes), disable_run=False + )() - if not argv and args.setup: - deps_bash = ct.DepsBash([]) - # Make sure pip is installed - argv = ["python3", "-m", "pip", "freeze"] + result = ct.get_result(dispatch_id=dispatch_id, wait=True).result - if argv: - dispatch_id = ct.dispatch(lattice, disable_run=False)(argv, deps_bash=deps_bash) - result = ct.get_result(dispatch_id=dispatch_id, wait=True) - return_code, _, _ = result.result if result.result is not None else (1, "", "") - - if return_code == 0 and args.setup: - _executor:ct.executor.BaseExecutor = executor_cls( - **{ - **_get_executor_kwargs(args), - **{"action": "teardown"}, - } - ) - asyncio.run(_executor.setup({})) + assert result and result[0] - assert _executor.hostnames - for hostname in _executor.hostnames: + all_connection_attributes, _ = result + for hostname, connection_attributes in all_connection_attributes.items(): print(f"hostname::>{hostname}") - print(f"username::>{_executor.username}") - print(f"ssh_key_file::>{_executor.ssh_key_file}") + for attribute,value in connection_attributes.items(): + if attribute == "hostname": + continue + print(f"{attribute}::>{value}") + + if argv: + dispatch_id = ct.dispatch( + ct.lattice( + lambda:ct.electron(_popen, executor=executor)(argv) + ), + disable_run=False + )() + + result = ct.get_result(dispatch_id=dispatch_id, wait=True).result + + return_code, _, _ = result if result is not None else (1, "", "") finally: - result = ct.get_result(dispatch_id=dispatch_id, wait=False) if dispatch_id else None - results_dir = result.results_dir if result else "" if args.teardown: - try: - _executor:ct.executor.BaseExecutor = executor_cls( - **{ - **_get_executor_kwargs(args), - **{"action": "teardown"}, - } - ) - asyncio.run(_executor.setup({})) - asyncio.run( - _executor.teardown( - {"dispatch_id": dispatch_id, "node_id": 0, "results_dir": results_dir} - ) - ) - except FileNotFoundError: - pass + result = executor.stop_cloud_instance().result + assert result is not None return return_code @@ -176,9 +148,14 @@ def main(argv=None): subparser.add_argument(f"--setup", action="store_true") subparser.add_argument(f"--teardown", action="store_true") for param, default in config.items(): - if param == "action": + if param.startswith("_"): continue - subparser.add_argument(f"--{param.replace('_', '-')}", default=default) + add_argument_kwargs = {} + if isinstance(default, bool): + add_argument_kwargs["action"] = "store_false" if default else "store_true" + else: + add_argument_kwargs["default"] = default + subparser.add_argument(f"--{param.replace('_', '-')}", **add_argument_kwargs) try: cv_argv, argv = argv[:argv.index("--")], argv[argv.index("--")+1:] diff --git a/milabench/cli/covalent/requirements.txt b/milabench/scripts/covalent/requirements.txt similarity index 100% rename from milabench/cli/covalent/requirements.txt rename to milabench/scripts/covalent/requirements.txt diff --git a/milabench/cli/utils.py b/milabench/scripts/utils.py similarity index 100% rename from milabench/cli/utils.py rename to milabench/scripts/utils.py