Skip to content

Commit

Permalink
Remove check_passwordless and fix test_init.py
Browse files Browse the repository at this point in the history
Signed-off-by: Fabrice Normandin <[email protected]>
  • Loading branch information
lebrice committed May 24, 2024
1 parent cccb609 commit 9d10a17
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 149 deletions.
26 changes: 15 additions & 11 deletions milatools/cli/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@
from milatools.cli import console
from milatools.cli.utils import SSHConfig as SSHConfigWriter
from milatools.cli.utils import T, running_inside_WSL, yn
from milatools.utils.local_v1 import check_passwordless, display
from milatools.utils.local_v1 import display
from milatools.utils.local_v2 import LocalV2
from milatools.utils.remote_v1 import RemoteV1
from milatools.utils.remote_v2 import control_socket_is_running, get_controlpath_for
from milatools.utils.vscode_utils import (
get_expected_vscode_settings_json_path,
vscode_installed,
Expand Down Expand Up @@ -345,7 +346,6 @@ def setup_passwordless_ssh_access_to_cluster(
# their SSH config proposed by the first part of `mila init`.
# - Instead of making the code complicated with lots of corner cases, just raise an
# error if the SSH config doesn't match what we expect to see after `mila init`.
raise NotImplementedError()
if ssh_private_key_path is None:
# TODO: What to do if there isn't a private key set in the SSH config, but there
# is already a private key in the SSH dir? (it would be used by ssh).
Expand Down Expand Up @@ -399,17 +399,21 @@ def setup_passwordless_ssh_access_to_cluster(
subprocess.run(command, check=True, text=False, stdin=f)
else:
here.run(
"ssh-copy-id",
"-i",
str(ssh_private_key_path),
"-o",
"StrictHostKeyChecking=no",
cluster,
check=True,
(
"ssh-copy-id",
"-i",
str(ssh_private_key_path),
"-o",
"StrictHostKeyChecking=no",
cluster,
),
)

# double-check that this worked.
if not check_passwordless(cluster):
if not control_socket_is_running(
cluster,
control_path=get_controlpath_for(cluster, ssh_config_path=ssh_config_path),
):
print(f"'ssh-copy-id {cluster}' appears to have failed!")
return False
return True
Expand Down Expand Up @@ -699,7 +703,7 @@ def _get_drac_username(ssh_config: SSHConfigReader) -> str | None:
if len(users_from_drac_config_entries) == 1:
return users_from_drac_config_entries.pop().strip()

username: str | None
username: str | None = None
# Note: If there are none, or more than one, then we'll ask the user for their
# username, just to be sure.
if yn("Do you also have an account on the ComputeCanada/DRAC clusters?"):
Expand Down
51 changes: 1 addition & 50 deletions milatools/utils/local_v1.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,14 @@
from __future__ import annotations

import shlex
import socket
import subprocess
import sys
from logging import getLogger as get_logger
from subprocess import CompletedProcess
from typing import IO, Any

import fabric
import paramiko.ssh_exception
from typing_extensions import deprecated

from milatools.cli.utils import CommandNotFoundError, T, cluster_to_connect_kwargs
from milatools.utils.remote_v2 import SSH_CONFIG_FILE, is_already_logged_in
from milatools.cli.utils import CommandNotFoundError, T

logger = get_logger(__name__)

Expand Down Expand Up @@ -69,54 +64,10 @@ def popen(
cmd, stdout=stdout, stderr=stderr, universal_newlines=True
)

def check_passwordless(self, host: str):
return check_passwordless(host)


def display(split_command: list[str] | tuple[str, ...] | str) -> None:
if isinstance(split_command, str):
command = split_command
else:
command = shlex.join(split_command)
print(T.bold_green("(local) $ ", command))


def check_passwordless(host: str) -> bool:
if (
sys.platform != "win32"
and SSH_CONFIG_FILE.exists()
and is_already_logged_in(host, ssh_config_path=SSH_CONFIG_FILE)
):
return True

try:
connect_kwargs_for_host = {"allow_agent": False}
if host in cluster_to_connect_kwargs:
connect_kwargs_for_host.update(cluster_to_connect_kwargs[host])
with fabric.Connection(
host,
connect_kwargs=connect_kwargs_for_host,
) as connection:
results: fabric.runners.Result = connection.run(
"echo OK",
in_stream=False,
echo=True,
echo_format=T.bold_cyan(f"({host})" + " $ {command}"),
)

except (
paramiko.ssh_exception.SSHException,
paramiko.ssh_exception.NoValidConnectionsError,
socket.gaierror,
# BUG: Also getting ValueError("q must be exactlu 160, 224, or 256 bits long")
# with older versions of paramiko.
# ValueError,
) as err:
logger.debug(f"Unable to connect to {host} without a password: {err}")
return False

if "OK" in results.stdout:
return True
logger.error("Unexpected output from SSH command, output didn't contain 'OK'!")
logger.error(f"stdout: {results.stdout}, stderr: {results.stderr}")
return False
32 changes: 17 additions & 15 deletions tests/cli/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import textwrap
from functools import partial
from logging import getLogger as get_logger
from pathlib import Path, PurePosixPath
from pathlib import Path, PosixPath, PurePosixPath
from unittest.mock import Mock

import invoke
Expand Down Expand Up @@ -42,7 +42,7 @@
SSHConfig,
running_inside_WSL,
)
from milatools.utils.local_v1 import LocalV1, check_passwordless
from milatools.utils.local_v1 import LocalV1
from milatools.utils.remote_v1 import RemoteV1
from milatools.utils.remote_v2 import (
SSH_CACHE_DIR,
Expand Down Expand Up @@ -542,8 +542,8 @@ def test_with_existing_entries(
User Bob
"""
),
["bob\r"],
"bob",
[], # user input doesn't matter (won't get asked).
"george",
id="two_matching_entries",
),
pytest.param(
Expand Down Expand Up @@ -579,7 +579,7 @@ def test_get_username(
ssh_config_path = tmp_path / "config"
with open(ssh_config_path, "w") as f:
f.write(contents)
ssh_config = SSHConfig(ssh_config_path)
ssh_config = paramiko.SSHConfig.from_path(str(ssh_config_path))
if not prompt_inputs:
input_pipe.close()
for prompt_input in prompt_inputs:
Expand Down Expand Up @@ -639,8 +639,8 @@ def test_get_username(
User Bob
"""
),
["y", "bob\r"],
"bob",
[], # will not get asked for input.
"george",
id="two_matching_entries",
),
pytest.param(
Expand Down Expand Up @@ -675,7 +675,7 @@ def test_get_drac_username(
ssh_config_path = tmp_path / "config"
with open(ssh_config_path, "w") as f:
f.write(contents)
ssh_config = SSHConfig(ssh_config_path)
ssh_config = paramiko.SSHConfig.from_path(str(ssh_config_path))
if not prompt_inputs:
input_pipe.close()
for prompt_input in prompt_inputs:
Expand Down Expand Up @@ -863,7 +863,9 @@ def test_setup_windows_ssh_config_from_wsl(
for prompt in user_inputs:
input_pipe.send_text(prompt)

setup_windows_ssh_config_from_wsl(linux_ssh_config=linux_ssh_config)
setup_windows_ssh_config_from_wsl(
linux_ssh_config_path=PosixPath(linux_ssh_config.path)
)

assert windows_ssh_config_path.exists()
assert windows_ssh_config_path.stat().st_mode & 0o777 == 0o600
Expand Down Expand Up @@ -1029,7 +1031,9 @@ def test_setup_windows_ssh_config_from_wsl_copies_keys(
input_pipe.send_text("y") # accept creating the Windows config file
input_pipe.send_text("y") # accept the changes

setup_windows_ssh_config_from_wsl(linux_ssh_config=linux_ssh_config)
setup_windows_ssh_config_from_wsl(
linux_ssh_config_path=PosixPath(linux_ssh_config.path)
)

windows_private_key_path = windows_home / ".ssh" / "id_rsa"
windows_public_key_path = windows_private_key_path.with_suffix(".pub")
Expand Down Expand Up @@ -1283,8 +1287,6 @@ def test_setup_passwordless_ssh_access_to_cluster(
assert ssh_public_key_path.exists()

def have_passwordless_ssh_access_to(cluster: str) -> bool:
if sys.platform == "win32":
return check_passwordless(cluster)
return is_already_logged_in(cluster, ssh_config_path=SSH_CONFIG_FILE)

def _exists(file: PurePosixPath | Path):
Expand Down Expand Up @@ -1340,7 +1342,7 @@ def temporarily_disable_ssh_access_to_cluster():
# subprocess.check_call(
# ["ssh", "-O", "exit", f"-oControlPath={control_path}", cluster]
# )
assert not check_passwordless(cluster)
assert not is_already_logged_in(cluster)
else:
assert passwordless_ssh_was_previously_setup
if _exists(authorized_keys_file):
Expand All @@ -1356,7 +1358,7 @@ def temporarily_disable_ssh_access_to_cluster():
hide=False,
)
# todo: might not work well for the DRAC clusters!
assert not check_passwordless(cluster)
assert not is_already_logged_in(cluster)

def reenable_ssh_access_to_cluster():
if cluster == "localhost":
Expand All @@ -1376,7 +1378,7 @@ def reenable_ssh_access_to_cluster():
ssh_dir.mkdir(exist_ok=True, mode=0o700)
assert not passwordless_ssh_was_previously_setup
shutil.copy(backup_authorized_keys_file, authorized_keys_file)
assert check_passwordless(cluster)
assert is_already_logged_in(cluster)
else:
logger.info(
f"Restoring the original {authorized_keys_file} from backup at "
Expand Down
74 changes: 1 addition & 73 deletions tests/utils/test_local_v1.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,17 @@
from __future__ import annotations

import sys
from subprocess import PIPE

import pytest
from pytest_regressions.file_regression import FileRegressionFixture

from milatools.cli.utils import SSH_CONFIG_FILE
from milatools.utils.local_v1 import CommandNotFoundError, LocalV1, check_passwordless
from milatools.utils.remote_v2 import is_already_logged_in
from milatools.utils.local_v1 import CommandNotFoundError, LocalV1

from ..cli.common import (
in_github_CI,
in_self_hosted_github_CI,
output_tester,
passwordless_ssh_connection_to_localhost_is_setup,
requires_no_s_flag,
skip_if_on_github_cloud_CI,
xfails_on_windows,
)

Expand Down Expand Up @@ -129,70 +124,3 @@ def test_popen(
"and raises a ValueError."
),
)


# @PARAMIKO_SSH_BANNER_BUG
# @paramiko_openssh_key_parsing_issue
@pytest.mark.xfail(
reason="TODO: `check_passwordless` is incredibly flaky and needs to be reworked."
)
@pytest.mark.parametrize(
("hostname", "expected"),
[
pytest.param(
"localhost",
passwordless_ssh_connection_to_localhost_is_setup,
),
("blablabob@localhost", False),
pytest.param(
"mila",
True if (in_self_hosted_github_CI or not in_github_CI) else False,
),
pytest.param(
"bobobobobobo@mila",
False,
marks=[
paramiko_openssh_key_parsing_issue,
skip_if_on_github_cloud_CI,
],
),
# For the clusters with 2FA, we expect `check_passwordless` to return True if
# we've already setup the shared SSH connection.
pytest.param(
"blablabob@narval",
False,
marks=[
skip_if_on_github_cloud_CI,
paramiko_openssh_key_parsing_issue,
],
),
*(
# note: can't properly test for the False case because of the 2FA
# prompt!
pytest.param(
drac_cluster,
True,
marks=pytest.mark.skipif(
sys.platform == "win32"
or not is_already_logged_in(
drac_cluster, ssh_config_path=SSH_CONFIG_FILE
),
reason="Should give True when we're already logged in.",
),
)
for drac_cluster in ["narval", "beluga", "cedar", "graham"]
),
pytest.param(
"niagara",
False,
marks=[
skip_if_on_github_cloud_CI,
paramiko_openssh_key_parsing_issue,
],
), # SSH access to niagara isn't enabled by default.
],
)
def test_check_passwordless(hostname: str, expected: bool):
# TODO: Maybe also test how `check_passwordless` behaves when using a key with a
# passphrase.
assert check_passwordless(hostname) == expected

0 comments on commit 9d10a17

Please sign in to comment.