diff --git a/milatools/cli/init.py b/milatools/cli/init.py index caa81cab..8829a672 100644 --- a/milatools/cli/init.py +++ b/milatools/cli/init.py @@ -19,9 +19,10 @@ from milatools.cli import console from milatools.cli.utils import SSHConfig as SSHConfigWriter from milatools.cli.utils import T, running_inside_WSL, yn -from milatools.utils.local_v1 import check_passwordless, display +from milatools.utils.local_v1 import display from milatools.utils.local_v2 import LocalV2 from milatools.utils.remote_v1 import RemoteV1 +from milatools.utils.remote_v2 import control_socket_is_running, get_controlpath_for from milatools.utils.vscode_utils import ( get_expected_vscode_settings_json_path, vscode_installed, @@ -345,7 +346,6 @@ def setup_passwordless_ssh_access_to_cluster( # their SSH config proposed by the first part of `mila init`. # - Instead of making the code complicated with lots of corner cases, just raise an # error if the SSH config doesn't match what we expect to see after `mila init`. - raise NotImplementedError() if ssh_private_key_path is None: # TODO: What to do if there isn't a private key set in the SSH config, but there # is already a private key in the SSH dir? (it would be used by ssh). @@ -399,17 +399,21 @@ def setup_passwordless_ssh_access_to_cluster( subprocess.run(command, check=True, text=False, stdin=f) else: here.run( - "ssh-copy-id", - "-i", - str(ssh_private_key_path), - "-o", - "StrictHostKeyChecking=no", - cluster, - check=True, + ( + "ssh-copy-id", + "-i", + str(ssh_private_key_path), + "-o", + "StrictHostKeyChecking=no", + cluster, + ), ) # double-check that this worked. - if not check_passwordless(cluster): + if not control_socket_is_running( + cluster, + control_path=get_controlpath_for(cluster, ssh_config_path=ssh_config_path), + ): print(f"'ssh-copy-id {cluster}' appears to have failed!") return False return True @@ -699,7 +703,7 @@ def _get_drac_username(ssh_config: SSHConfigReader) -> str | None: if len(users_from_drac_config_entries) == 1: return users_from_drac_config_entries.pop().strip() - username: str | None + username: str | None = None # Note: If there are none, or more than one, then we'll ask the user for their # username, just to be sure. if yn("Do you also have an account on the ComputeCanada/DRAC clusters?"): diff --git a/milatools/utils/local_v1.py b/milatools/utils/local_v1.py index 6f5520bc..8f465059 100644 --- a/milatools/utils/local_v1.py +++ b/milatools/utils/local_v1.py @@ -1,19 +1,14 @@ from __future__ import annotations import shlex -import socket import subprocess -import sys from logging import getLogger as get_logger from subprocess import CompletedProcess from typing import IO, Any -import fabric -import paramiko.ssh_exception from typing_extensions import deprecated -from milatools.cli.utils import CommandNotFoundError, T, cluster_to_connect_kwargs -from milatools.utils.remote_v2 import SSH_CONFIG_FILE, is_already_logged_in +from milatools.cli.utils import CommandNotFoundError, T logger = get_logger(__name__) @@ -69,9 +64,6 @@ def popen( cmd, stdout=stdout, stderr=stderr, universal_newlines=True ) - def check_passwordless(self, host: str): - return check_passwordless(host) - def display(split_command: list[str] | tuple[str, ...] | str) -> None: if isinstance(split_command, str): @@ -79,44 +71,3 @@ def display(split_command: list[str] | tuple[str, ...] | str) -> None: else: command = shlex.join(split_command) print(T.bold_green("(local) $ ", command)) - - -def check_passwordless(host: str) -> bool: - if ( - sys.platform != "win32" - and SSH_CONFIG_FILE.exists() - and is_already_logged_in(host, ssh_config_path=SSH_CONFIG_FILE) - ): - return True - - try: - connect_kwargs_for_host = {"allow_agent": False} - if host in cluster_to_connect_kwargs: - connect_kwargs_for_host.update(cluster_to_connect_kwargs[host]) - with fabric.Connection( - host, - connect_kwargs=connect_kwargs_for_host, - ) as connection: - results: fabric.runners.Result = connection.run( - "echo OK", - in_stream=False, - echo=True, - echo_format=T.bold_cyan(f"({host})" + " $ {command}"), - ) - - except ( - paramiko.ssh_exception.SSHException, - paramiko.ssh_exception.NoValidConnectionsError, - socket.gaierror, - # BUG: Also getting ValueError("q must be exactlu 160, 224, or 256 bits long") - # with older versions of paramiko. - # ValueError, - ) as err: - logger.debug(f"Unable to connect to {host} without a password: {err}") - return False - - if "OK" in results.stdout: - return True - logger.error("Unexpected output from SSH command, output didn't contain 'OK'!") - logger.error(f"stdout: {results.stdout}, stderr: {results.stderr}") - return False diff --git a/tests/cli/test_init.py b/tests/cli/test_init.py index adbc531a..2cf3e112 100644 --- a/tests/cli/test_init.py +++ b/tests/cli/test_init.py @@ -11,7 +11,7 @@ import textwrap from functools import partial from logging import getLogger as get_logger -from pathlib import Path, PurePosixPath +from pathlib import Path, PosixPath, PurePosixPath from unittest.mock import Mock import invoke @@ -42,7 +42,7 @@ SSHConfig, running_inside_WSL, ) -from milatools.utils.local_v1 import LocalV1, check_passwordless +from milatools.utils.local_v1 import LocalV1 from milatools.utils.remote_v1 import RemoteV1 from milatools.utils.remote_v2 import ( SSH_CACHE_DIR, @@ -542,8 +542,8 @@ def test_with_existing_entries( User Bob """ ), - ["bob\r"], - "bob", + [], # user input doesn't matter (won't get asked). + "george", id="two_matching_entries", ), pytest.param( @@ -579,7 +579,7 @@ def test_get_username( ssh_config_path = tmp_path / "config" with open(ssh_config_path, "w") as f: f.write(contents) - ssh_config = SSHConfig(ssh_config_path) + ssh_config = paramiko.SSHConfig.from_path(str(ssh_config_path)) if not prompt_inputs: input_pipe.close() for prompt_input in prompt_inputs: @@ -639,8 +639,8 @@ def test_get_username( User Bob """ ), - ["y", "bob\r"], - "bob", + [], # will not get asked for input. + "george", id="two_matching_entries", ), pytest.param( @@ -675,7 +675,7 @@ def test_get_drac_username( ssh_config_path = tmp_path / "config" with open(ssh_config_path, "w") as f: f.write(contents) - ssh_config = SSHConfig(ssh_config_path) + ssh_config = paramiko.SSHConfig.from_path(str(ssh_config_path)) if not prompt_inputs: input_pipe.close() for prompt_input in prompt_inputs: @@ -863,7 +863,9 @@ def test_setup_windows_ssh_config_from_wsl( for prompt in user_inputs: input_pipe.send_text(prompt) - setup_windows_ssh_config_from_wsl(linux_ssh_config=linux_ssh_config) + setup_windows_ssh_config_from_wsl( + linux_ssh_config_path=PosixPath(linux_ssh_config.path) + ) assert windows_ssh_config_path.exists() assert windows_ssh_config_path.stat().st_mode & 0o777 == 0o600 @@ -1029,7 +1031,9 @@ def test_setup_windows_ssh_config_from_wsl_copies_keys( input_pipe.send_text("y") # accept creating the Windows config file input_pipe.send_text("y") # accept the changes - setup_windows_ssh_config_from_wsl(linux_ssh_config=linux_ssh_config) + setup_windows_ssh_config_from_wsl( + linux_ssh_config_path=PosixPath(linux_ssh_config.path) + ) windows_private_key_path = windows_home / ".ssh" / "id_rsa" windows_public_key_path = windows_private_key_path.with_suffix(".pub") @@ -1283,8 +1287,6 @@ def test_setup_passwordless_ssh_access_to_cluster( assert ssh_public_key_path.exists() def have_passwordless_ssh_access_to(cluster: str) -> bool: - if sys.platform == "win32": - return check_passwordless(cluster) return is_already_logged_in(cluster, ssh_config_path=SSH_CONFIG_FILE) def _exists(file: PurePosixPath | Path): @@ -1340,7 +1342,7 @@ def temporarily_disable_ssh_access_to_cluster(): # subprocess.check_call( # ["ssh", "-O", "exit", f"-oControlPath={control_path}", cluster] # ) - assert not check_passwordless(cluster) + assert not is_already_logged_in(cluster) else: assert passwordless_ssh_was_previously_setup if _exists(authorized_keys_file): @@ -1356,7 +1358,7 @@ def temporarily_disable_ssh_access_to_cluster(): hide=False, ) # todo: might not work well for the DRAC clusters! - assert not check_passwordless(cluster) + assert not is_already_logged_in(cluster) def reenable_ssh_access_to_cluster(): if cluster == "localhost": @@ -1376,7 +1378,7 @@ def reenable_ssh_access_to_cluster(): ssh_dir.mkdir(exist_ok=True, mode=0o700) assert not passwordless_ssh_was_previously_setup shutil.copy(backup_authorized_keys_file, authorized_keys_file) - assert check_passwordless(cluster) + assert is_already_logged_in(cluster) else: logger.info( f"Restoring the original {authorized_keys_file} from backup at " diff --git a/tests/utils/test_local_v1.py b/tests/utils/test_local_v1.py index 6d04ae50..fa1f1f3d 100644 --- a/tests/utils/test_local_v1.py +++ b/tests/utils/test_local_v1.py @@ -1,22 +1,17 @@ from __future__ import annotations -import sys from subprocess import PIPE import pytest from pytest_regressions.file_regression import FileRegressionFixture -from milatools.cli.utils import SSH_CONFIG_FILE -from milatools.utils.local_v1 import CommandNotFoundError, LocalV1, check_passwordless -from milatools.utils.remote_v2 import is_already_logged_in +from milatools.utils.local_v1 import CommandNotFoundError, LocalV1 from ..cli.common import ( in_github_CI, in_self_hosted_github_CI, output_tester, - passwordless_ssh_connection_to_localhost_is_setup, requires_no_s_flag, - skip_if_on_github_cloud_CI, xfails_on_windows, ) @@ -129,70 +124,3 @@ def test_popen( "and raises a ValueError." ), ) - - -# @PARAMIKO_SSH_BANNER_BUG -# @paramiko_openssh_key_parsing_issue -@pytest.mark.xfail( - reason="TODO: `check_passwordless` is incredibly flaky and needs to be reworked." -) -@pytest.mark.parametrize( - ("hostname", "expected"), - [ - pytest.param( - "localhost", - passwordless_ssh_connection_to_localhost_is_setup, - ), - ("blablabob@localhost", False), - pytest.param( - "mila", - True if (in_self_hosted_github_CI or not in_github_CI) else False, - ), - pytest.param( - "bobobobobobo@mila", - False, - marks=[ - paramiko_openssh_key_parsing_issue, - skip_if_on_github_cloud_CI, - ], - ), - # For the clusters with 2FA, we expect `check_passwordless` to return True if - # we've already setup the shared SSH connection. - pytest.param( - "blablabob@narval", - False, - marks=[ - skip_if_on_github_cloud_CI, - paramiko_openssh_key_parsing_issue, - ], - ), - *( - # note: can't properly test for the False case because of the 2FA - # prompt! - pytest.param( - drac_cluster, - True, - marks=pytest.mark.skipif( - sys.platform == "win32" - or not is_already_logged_in( - drac_cluster, ssh_config_path=SSH_CONFIG_FILE - ), - reason="Should give True when we're already logged in.", - ), - ) - for drac_cluster in ["narval", "beluga", "cedar", "graham"] - ), - pytest.param( - "niagara", - False, - marks=[ - skip_if_on_github_cloud_CI, - paramiko_openssh_key_parsing_issue, - ], - ), # SSH access to niagara isn't enabled by default. - ], -) -def test_check_passwordless(hostname: str, expected: bool): - # TODO: Maybe also test how `check_passwordless` behaves when using a key with a - # passphrase. - assert check_passwordless(hostname) == expected