Skip to content

Commit

Permalink
(only) add type hints
Browse files Browse the repository at this point in the history
Signed-off-by: Fabrice Normandin <[email protected]>
  • Loading branch information
lebrice committed Nov 13, 2023
1 parent 2fbcfa0 commit 59be64a
Show file tree
Hide file tree
Showing 6 changed files with 506 additions and 206 deletions.
51 changes: 27 additions & 24 deletions milatools/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from __future__ import annotations

import argparse
import functools
import operator
import os
import re
Expand All @@ -16,7 +15,6 @@
import time
import traceback
import typing
import warnings
import webbrowser
from argparse import ArgumentParser, _HelpAction
from contextlib import ExitStack
Expand All @@ -26,7 +24,7 @@
from urllib.parse import urlencode

import questionary as qn
from invoke import UnexpectedExit
from invoke.exceptions import UnexpectedExit
from typing_extensions import TypedDict

from ..version import version as mversion
Expand Down Expand Up @@ -61,7 +59,8 @@ def main():
".server.mila.quebec"
):
exit(
"ERROR: 'mila ...' should be run on your local machine and not on the Mila cluster"
"ERROR: 'mila ...' should be run on your local machine and not on "
"the Mila cluster"
)

try:
Expand All @@ -70,7 +69,8 @@ def main():
# These are user errors and should not be reported
print("ERROR:", exc, file=sys.stderr)
except SSHConnectionError as err:
# These are errors coming from paramiko's failure to connect to the host
# These are errors coming from paramiko's failure to connect to the
# host
print("ERROR:", f"{err}", file=sys.stderr)
except Exception:
print(T.red(traceback.format_exc()), file=sys.stderr)
Expand Down Expand Up @@ -633,6 +633,7 @@ def code(
print("To reconnect to this node:")
print(T.bold(f" mila code {path} --node {node_name}"))
print("To kill this allocation:")
assert "jobid" in data
print(T.bold(f" ssh mila scancel {data['jobid']}"))


Expand Down Expand Up @@ -715,25 +716,25 @@ def serve_list(purge: bool):

class StandardServerArgs(TypedDict):
alloc: Sequence[str]
"""Extra options to pass to slurm"""
"""Extra options to pass to slurm."""

job: str | None
"""Job ID to connect to"""
"""Job ID to connect to."""

name: str | None
"""Name of the persistent server"""
"""Name of the persistent server."""

node: str | None
"""Node to connect to"""
"""Node to connect to."""

persist: bool
"""Whether the server should persist or not"""
"""Whether the server should persist or not."""

port: int | None
"""Port to open on the local machine"""
"""Port to open on the local machine."""

profile: str | None
"""Name of the profile to use"""
"""Name of the profile to use."""


def lab(path: str | None, **kwargs: Unpack[StandardServerArgs]):
Expand Down Expand Up @@ -849,7 +850,7 @@ def _get_server_info(


class SortingHelpFormatter(argparse.HelpFormatter):
"""Taken and adapted from https://stackoverflow.com/a/12269143/6388696"""
"""Taken and adapted from https://stackoverflow.com/a/12269143/6388696."""

def add_arguments(self, actions):
actions = sorted(actions, key=operator.attrgetter("option_strings"))
Expand Down Expand Up @@ -914,8 +915,8 @@ def _standard_server(
path: str | None,
*,
program: str,
installers,
command,
installers: dict[str, str],
command: str,
profile: str | None,
persist: bool,
port: int | None,
Expand Down Expand Up @@ -1166,7 +1167,7 @@ def check_disk_quota(remote: Remote) -> None:


def _find_allocation(
remote,
remote: Remote,
node: str | None,
job: str | None,
alloc: Sequence[str],
Expand Down Expand Up @@ -1219,13 +1220,15 @@ def _forward(
args = [f"localhost:{port}:{to_forward}", node]

proc = local.popen(
"ssh",
"-o",
"UserKnownHostsFile=/dev/null",
"-o",
"StrictHostKeyChecking=no",
"-nNL",
*args,
[
"ssh",
"-o",
"UserKnownHostsFile=/dev/null",
"-o",
"StrictHostKeyChecking=no",
"-nNL",
*args,
]
)

url = f"http://localhost:{port}"
Expand All @@ -1245,7 +1248,7 @@ def _forward(
time.sleep(period)
try:
# This feels stupid, there's probably a better way
local.silent_get("nc", "-z", "localhost", str(port))
local.silent_get(["nc", "-z", "localhost", str(port)])
except subprocess.CalledProcessError:
continue
except Exception:
Expand Down
82 changes: 45 additions & 37 deletions milatools/cli/local.py
Original file line number Diff line number Diff line change
@@ -1,65 +1,73 @@
from __future__ import annotations

import shlex
import subprocess
from subprocess import CompletedProcess
from typing import IO, Any

from typing_extensions import deprecated

from .utils import CommandNotFoundError, T, shjoin


class Local:
def display(self, args):
print(T.bold_green(f"(local) $ ", shjoin(args)))
def display(self, args: list[str] | tuple[str, ...]) -> None:
print(T.bold_green("(local) $ ", shjoin(args)))

def silent_get(self, *args, **kwargs):
return subprocess.check_output(
args,
universal_newlines=True,
**kwargs,
)
def silent_get(self, *cmd: str) -> str:
return subprocess.check_output(cmd, universal_newlines=True)

def get(self, *args, **kwargs):
self.display(args)
return subprocess.check_output(
args,
universal_newlines=True,
**kwargs,
)
@deprecated("This isn't used and will probably be removed. Don't start using it.")
def get(self, *cmd: str) -> str:
self.display(cmd)
return subprocess.check_output(cmd, universal_newlines=True)

def run(self, *args, **kwargs):
self.display(args)
def run(
self,
*cmd: str,
stdout: int | IO[Any] | None = None,
stderr: int | IO[Any] | None = None,
capture_output: bool = False,
) -> CompletedProcess[str]:
self.display(cmd)
try:
return subprocess.run(
args,
cmd,
stdout=stdout,
stderr=stderr,
capture_output=capture_output,
universal_newlines=True,
**kwargs,
)
except FileNotFoundError as e:
if e.filename == args[0]:
raise CommandNotFoundError(e.filename)
else:
raise
if e.filename == cmd[0]:
raise CommandNotFoundError(e.filename) from e
raise

def popen(self, *args, **kwargs):
self.display(args)
def popen(
self,
*cmd: str,
stdout: int | IO[Any] | None = None,
stderr: int | IO[Any] | None = None,
) -> subprocess.Popen:
self.display(cmd)
return subprocess.Popen(
args,
universal_newlines=True,
**kwargs,
cmd, stdout=stdout, stderr=stderr, universal_newlines=True
)

def check_passwordless(self, host):
def check_passwordless(self, host: str) -> bool:
results = self.run(
"ssh",
"-oPreferredAuthentications=publickey",
host,
"echo OK",
*shlex.split(f"ssh -oPreferredAuthentications=publickey {host} 'echo OK'"),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
if results.returncode != 0:
if "Permission denied" in results.stderr:
return False
else:
print(results.stdout)
print(results.stderr)
exit(f"Failed to connect to {host}, could not understand error")
print(results.stdout)
print(results.stderr)
exit(f"Failed to connect to {host}, could not understand error")
# TODO: Perhaps we could actually check the output of the command here!
# elif "OK" in results.stdout:
else:
print("# OK")
return True
Loading

0 comments on commit 59be64a

Please sign in to comment.