Skip to content

Commit

Permalink
Rename Local -> LocalV1 and Remote -> RemoteV1 (#121)
Browse files Browse the repository at this point in the history
* Rename `local.py`->`local_v1.py`

Signed-off-by: Fabrice Normandin <[email protected]>

* Rename `Local`->`LocalV1`

Signed-off-by: Fabrice Normandin <[email protected]>

* Move `cli/local_v1.py` under `utils`

Signed-off-by: Fabrice Normandin <[email protected]>

* Move tests for LocalV1 from tests/cli->tests/utils

Signed-off-by: Fabrice Normandin <[email protected]>

* Rename `remote.py` to `remote_v1.py`

Signed-off-by: Fabrice Normandin <[email protected]>

* Rename `Remote` -> `RemoteV1`

Signed-off-by: Fabrice Normandin <[email protected]>

* Move `cli/remote_v1.py` under `utils`

Signed-off-by: Fabrice Normandin <[email protected]>

* Move RemoteV1 tests from tests/cli->tests/utils

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix regression files after Remote->RemoteV1 rename

Signed-off-by: Fabrice Normandin <[email protected]>

---------

Signed-off-by: Fabrice Normandin <[email protected]>
  • Loading branch information
lebrice authored May 1, 2024
1 parent 320be40 commit 400bc30
Show file tree
Hide file tree
Showing 32 changed files with 138 additions and 132 deletions.
40 changes: 22 additions & 18 deletions milatools/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
)

from ..__version__ import __version__
from ..utils.local_v1 import LocalV1
from ..utils.remote_v1 import RemoteV1, SlurmRemote
from .init_command import (
print_welcome_message,
setup_keys_on_login_node,
Expand All @@ -46,9 +48,7 @@
setup_vscode_settings,
setup_windows_ssh_config_from_wsl,
)
from .local import Local
from .profile import ensure_program, setup_profile
from .remote import Remote, SlurmRemote
from .utils import (
CLUSTERS,
Cluster,
Expand Down Expand Up @@ -518,7 +518,7 @@ def forward(
pass

local_proc, _ = _forward(
local=Local(),
local=LocalV1(),
node=f"{node}.server.mila.quebec",
to_forward=remote_port,
page=page,
Expand Down Expand Up @@ -553,8 +553,8 @@ def code(
node: Node to connect to
alloc: Extra options to pass to slurm
"""
here = Local()
remote = Remote(cluster)
here = LocalV1()
remote = RemoteV1(cluster)

if cluster != "mila" and job is None and node is None:
if not any("--account" in flag for flag in alloc):
Expand Down Expand Up @@ -603,7 +603,7 @@ def code(
copy_vscode_extensions_process.start()
else:
sync_vscode_extensions(
Local(),
LocalV1(),
[cluster],
)

Expand Down Expand Up @@ -697,10 +697,10 @@ def code(
def connect(identifier: str, port: int | None):
"""Reconnect to a persistent server."""

remote = Remote("mila")
remote = RemoteV1("mila")
info = _get_server_info(remote, identifier)
local_proc, _ = _forward(
local=Local(),
local=LocalV1(),
node=f"{info['node_name']}.server.mila.quebec",
to_forward=info["to_forward"],
options={"token": info.get("token", None)},
Expand All @@ -718,7 +718,7 @@ def connect(identifier: str, port: int | None):

def kill(identifier: str | None, all: bool = False):
"""Kill a persistent server."""
remote = Remote("mila")
remote = RemoteV1("mila")

if all:
for identifier in remote.get_lines("ls .milatools/control", hide=True):
Expand All @@ -740,7 +740,7 @@ def kill(identifier: str | None, all: bool = False):

def serve_list(purge: bool):
"""List active servers."""
remote = Remote("mila")
remote = RemoteV1("mila")

to_purge = []

Expand Down Expand Up @@ -899,7 +899,7 @@ def aim(logdir: str, **kwargs: Unpack[StandardServerArgs]):


def _get_server_info(
remote: Remote, identifier: str, hide: bool = False
remote: RemoteV1, identifier: str, hide: bool = False
) -> dict[str, str]:
text = remote.get_output(f"cat .milatools/control/{identifier}", hide=hide)
info = dict(line.split(" = ") for line in text.split("\n") if line)
Expand Down Expand Up @@ -993,7 +993,7 @@ def _standard_server(
elif persist:
name = program

remote = Remote("mila")
remote = RemoteV1("mila")

path = path or "~"
if path == "~" or path.startswith("~/"):
Expand Down Expand Up @@ -1107,7 +1107,7 @@ def _standard_server(
options = {}

local_proc, local_port = _forward(
local=Local(),
local=LocalV1(),
node=get_fully_qualified_hostname_of_compute_node(node_name, cluster="mila"),
to_forward=to_forward,
options=options,
Expand Down Expand Up @@ -1181,7 +1181,7 @@ def _parse_lfs_quota_output(
return (used_gb, max_gb), (used_files, max_files)


def check_disk_quota(remote: Remote | RemoteV2) -> None:
def check_disk_quota(remote: RemoteV1 | RemoteV2) -> None:
cluster = remote.hostname

# NOTE: This is what the output of the command looks like on the Mila cluster:
Expand Down Expand Up @@ -1262,7 +1262,7 @@ def get_colour(used: float, max: float) -> str:


def _find_allocation(
remote: Remote,
remote: RemoteV1,
node: str | None,
job: str | None,
alloc: list[str],
Expand All @@ -1274,11 +1274,15 @@ def _find_allocation(

if node is not None:
node_name = get_fully_qualified_hostname_of_compute_node(node, cluster=cluster)
return Remote(node_name, connect_kwargs=cluster_to_connect_kwargs.get(cluster))
return RemoteV1(
node_name, connect_kwargs=cluster_to_connect_kwargs.get(cluster)
)

elif job is not None:
node_name = remote.get_output(f"squeue --jobs {job} -ho %N")
return Remote(node_name, connect_kwargs=cluster_to_connect_kwargs.get(cluster))
return RemoteV1(
node_name, connect_kwargs=cluster_to_connect_kwargs.get(cluster)
)

else:
alloc = ["-J", job_name, *alloc]
Expand All @@ -1290,7 +1294,7 @@ def _find_allocation(


def _forward(
local: Local,
local: LocalV1,
node: str,
to_forward: int | str,
port: int | None,
Expand Down
14 changes: 7 additions & 7 deletions milatools/cli/init_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@

from milatools.utils.remote_v2 import SSH_CONFIG_FILE

from ..utils.local_v1 import LocalV1, check_passwordless, display
from ..utils.remote_v1 import RemoteV1
from ..utils.vscode_utils import (
get_expected_vscode_settings_json_path,
vscode_installed,
)
from .local import Local, check_passwordless, display
from .remote import Remote
from .utils import SSHConfig, T, running_inside_WSL, yn

logger = get_logger(__name__)
Expand Down Expand Up @@ -238,7 +238,7 @@ def setup_passwordless_ssh_access(ssh_config: SSHConfig) -> bool:
"""
print("Checking passwordless authentication")

here = Local()
here = LocalV1()
sshdir = Path.home() / ".ssh"

# Check if there is a public key file in ~/.ssh
Expand Down Expand Up @@ -294,7 +294,7 @@ def setup_passwordless_ssh_access_to_cluster(cluster: str) -> bool:
Returns whether the operation completed successfully or not.
"""
here = Local()
here = LocalV1()
# Check that it is possible to connect without using a password.
print(f"Checking if passwordless SSH access is setup for the {cluster} cluster.")
# TODO: Potentially use a custom key like `~/.ssh/id_milatools.pub` instead of
Expand Down Expand Up @@ -371,7 +371,7 @@ def setup_keys_on_login_node(cluster: str = "mila"):
"This is required for `mila code` to work properly."
)
# todo: avoid re-creating the `Remote` here, since it goes through 2FA each time!
remote = Remote(cluster)
remote = RemoteV1(cluster)
try:
pubkeys = remote.get_lines("ls -t ~/.ssh/id*.pub")
print("# OK")
Expand Down Expand Up @@ -443,7 +443,7 @@ def get_windows_home_path_in_wsl() -> Path:

def create_ssh_keypair(
ssh_private_key_path: Path,
local: Local | None = None,
local: LocalV1 | None = None,
passphrase: str | None = "",
) -> None:
"""Creates a public/private key pair at the given path using ssh-keygen.
Expand All @@ -452,7 +452,7 @@ def create_ssh_keypair(
Otherwise, if passphrase is an empty string, no passphrase will be used (default).
If a string is passed, it is passed to ssh-keygen and used as the passphrase.
"""
local = local or Local()
local = local or LocalV1()
command = [
"ssh-keygen",
"-f",
Expand Down
18 changes: 9 additions & 9 deletions milatools/cli/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from .utils import askpath, yn

if typing.TYPE_CHECKING:
from milatools.cli.remote import Remote
from milatools.utils.remote_v1 import RemoteV1

style = qn.Style(
[
Expand All @@ -38,7 +38,7 @@ def _ask_name(message: str, default: str = "") -> str:
qn.print(f"Invalid name: {name}", style="bold red")


def setup_profile(remote: Remote, path: str) -> str:
def setup_profile(remote: RemoteV1, path: str) -> str:
profile = select_preferred(remote, path)
preferred = profile is not None
if not preferred:
Expand All @@ -58,7 +58,7 @@ def setup_profile(remote: Remote, path: str) -> str:
return profile


def select_preferred(remote: Remote, path: str) -> str | None:
def select_preferred(remote: RemoteV1, path: str) -> str | None:
preferred = f"{path}/.milatools-profile"
qn.print(f"Checking for preferred profile in {preferred}")

Expand All @@ -71,7 +71,7 @@ def select_preferred(remote: Remote, path: str) -> str | None:
return preferred


def select_profile(remote: Remote) -> str | None:
def select_profile(remote: RemoteV1) -> str | None:
profdir = "~/.milatools/profiles"

qn.print(f"Fetching profiles in {profdir}")
Expand Down Expand Up @@ -109,7 +109,7 @@ def select_profile(remote: Remote) -> str | None:
return profile


def create_profile(remote: Remote, path: str = "~"):
def create_profile(remote: RemoteV1, path: str = "~"):
modules = select_modules(remote)

mload = f"module load {' '.join(modules)}"
Expand Down Expand Up @@ -139,7 +139,7 @@ def create_profile(remote: Remote, path: str = "~"):
return prof_file


def select_modules(remote: Remote):
def select_modules(remote: RemoteV1):
choices = [
Choice(
title="miniconda/3",
Expand Down Expand Up @@ -202,7 +202,7 @@ def _env_basename(pth: str) -> str | None:
return base


def select_conda_environment(remote: Remote, loader: str = "module load miniconda/3"):
def select_conda_environment(remote: RemoteV1, loader: str = "module load miniconda/3"):
qn.print("Fetching the list of conda environments...")
envstr = remote.get_output("conda env list --json", hide=True)
envlist: list[str] = json.loads(envstr)["envs"]
Expand Down Expand Up @@ -254,7 +254,7 @@ def select_conda_environment(remote: Remote, loader: str = "module load minicond
return env


def select_virtual_environment(remote: Remote, path):
def select_virtual_environment(remote: RemoteV1, path):
envstr = remote.get_output(
(
f"ls -d {path}/venv {path}/.venv {path}/virtualenv ~/virtualenvs/* "
Expand Down Expand Up @@ -293,7 +293,7 @@ def select_virtual_environment(remote: Remote, path):
return env


def ensure_program(remote: Remote, program: str, installers: dict[str, str]):
def ensure_program(remote: RemoteV1, program: str, installers: dict[str, str]):
to_test = [program, *installers.keys()]
progs = [
Path(p).name
Expand Down
6 changes: 3 additions & 3 deletions milatools/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from typing_extensions import ParamSpec, TypeGuard

if typing.TYPE_CHECKING:
from milatools.cli.remote import Remote
from milatools.utils.remote_v1 import RemoteV1

control_file_var = contextvars.ContextVar("control_file", default="/dev/null")

Expand Down Expand Up @@ -96,7 +96,7 @@ def randname():


@contextmanager
def with_control_file(remote: Remote, name=None):
def with_control_file(remote: RemoteV1, name=None):
name = name or randname()
pth = f".milatools/control/{name}"
remote.run("mkdir -p ~/.milatools/control", hide=True)
Expand Down Expand Up @@ -172,7 +172,7 @@ def yn(prompt: str, default: bool = True) -> bool:
return qn.confirm(prompt, default=default).unsafe_ask()


def askpath(prompt: str, remote: Remote) -> str:
def askpath(prompt: str, remote: RemoteV1) -> str:
while True:
pth = qn.text(prompt).unsafe_ask()
try:
Expand Down
4 changes: 2 additions & 2 deletions milatools/cli/local.py → milatools/utils/local_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@

from milatools.utils.remote_v2 import SSH_CONFIG_FILE, is_already_logged_in

from .utils import CommandNotFoundError, T, cluster_to_connect_kwargs
from ..cli.utils import CommandNotFoundError, T, cluster_to_connect_kwargs

logger = get_logger(__name__)


class Local:
class LocalV1:
def display(self, args: list[str] | tuple[str, ...]) -> None:
display(args)

Expand Down
8 changes: 4 additions & 4 deletions milatools/cli/remote.py → milatools/utils/remote_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from fabric import Connection
from typing_extensions import Self, TypedDict, deprecated

from .utils import (
from ..cli.utils import (
SSHConnectionError,
T,
cluster_to_connect_kwargs,
Expand Down Expand Up @@ -109,7 +109,7 @@ def get_first_node_name(node_names_out: str) -> str:
return base + inside_brackets.split("-")[0]


class Remote:
class RemoteV1:
def __init__(
self,
hostname: str,
Expand Down Expand Up @@ -444,7 +444,7 @@ def extract_script(
return self.extract(shlex.join([dest, *args]), pattern=pattern, **kwargs)


class SlurmRemote(Remote):
class SlurmRemote(RemoteV1):
def __init__(
self,
connection: fabric.Connection,
Expand Down Expand Up @@ -533,7 +533,7 @@ def ensure_allocation(
"jobid": results["jobid"],
}, login_node_runner
else:
remote = Remote(hostname=self.hostname, connection=self.connection)
remote = RemoteV1(hostname=self.hostname, connection=self.connection)
command = shlex.join(["salloc", *self.alloc])
# We need to cd to $SCRATCH before we can run `salloc` on some clusters.
command = f"cd $SCRATCH && {command}"
Expand Down
2 changes: 1 addition & 1 deletion milatools/utils/remote_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from paramiko import SSHConfig

from milatools.cli import console
from milatools.cli.remote import Hide
from milatools.cli.utils import DRAC_CLUSTERS, MilatoolsUserError
from milatools.utils.remote_v1 import Hide

logger = get_logger(__name__)

Expand Down
Loading

0 comments on commit 400bc30

Please sign in to comment.