From 734c13f57d23cb2b6a3a20bd028f138db99dd1fe Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Mon, 15 Apr 2024 14:48:59 -0400 Subject: [PATCH] Simplify `is_already_logged_in` function Signed-off-by: Fabrice Normandin --- milatools/utils/remote_v2.py | 19 +++++++------------ tests/conftest.py | 2 +- tests/utils/test_remote_v2.py | 25 +++++++++++++------------ 3 files changed, 21 insertions(+), 25 deletions(-) diff --git a/milatools/utils/remote_v2.py b/milatools/utils/remote_v2.py index 34d4e8fa..7dd2078b 100644 --- a/milatools/utils/remote_v2.py +++ b/milatools/utils/remote_v2.py @@ -312,32 +312,26 @@ def option_dict_to_flags(options: dict[str, str]) -> list[str]: ] -def is_already_logged_in(cluster: str, also_run_command_to_check: bool = False) -> bool: +def is_already_logged_in(cluster: str) -> bool: """Checks whether we are already logged in to the given cluster. More specifically, this checks whether a reusable SSH control master is setup at the controlpath for the given cluster. NOTE: This function is not supported on Windows. - - Parameters - ---------- - cluster: Hostname of the cluster to connect to. - also_run_command_to_check: Whether we should also run a command over SSH to make - 100% sure that we are logged in. In most cases this isn't necessary so we can - skip it, since it can take a few seconds. """ - if not SSH_CONFIG_FILE.exists(): + ssh_config_path = SSH_CONFIG_FILE + if not ssh_config_path.exists(): return False - control_path = get_controlpath_for(cluster, ssh_config_path=SSH_CONFIG_FILE) + control_path = get_controlpath_for(cluster, ssh_config_path=ssh_config_path) if not control_path.exists(): logger.debug(f"ControlPath at {control_path} doesn't exist. Not logged in.") return False if not control_socket_is_running(cluster, control_path): return False - if not also_run_command_to_check: - return True + return True + # if also_run_command_to_check: return RemoteV2(cluster, control_path=control_path).get_output("echo OK") == "OK" @@ -363,6 +357,7 @@ def get_controlpath_for( ssh_config_values: dict[str, str] = {} if ssh_config_path.exists(): + # note: This also does the substitutions in the vars, e.g. %p -> port, etc. ssh_config_values = SSHConfig.from_path(str(ssh_config_path)).lookup(cluster) if control_path := ssh_config_values.get("controlpath"): diff --git a/tests/conftest.py b/tests/conftest.py index 7f4b302f..4e8d6b5f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -349,7 +349,7 @@ def already_logged_in( ) RemoteV2(cluster) - assert is_already_logged_in(cluster, also_run_command_to_check=True) + assert is_already_logged_in(cluster) yield True # TODO: Should we remove the connection after running the tests? return diff --git a/tests/utils/test_remote_v2.py b/tests/utils/test_remote_v2.py index 5f911cc2..7f5124bc 100644 --- a/tests/utils/test_remote_v2.py +++ b/tests/utils/test_remote_v2.py @@ -7,7 +7,7 @@ import milatools.utils.remote_v2 from milatools.cli.init_command import DRAC_CLUSTERS -from milatools.cli.utils import SSH_CONFIG_FILE +from milatools.cli.utils import SSH_CACHE_DIR, SSH_CONFIG_FILE from milatools.utils.local_v2 import run_async from milatools.utils.remote_v2 import ( RemoteV2, @@ -106,18 +106,19 @@ async def test_init_with_none_controlpath( # NOTE: The timeout here is a part of the test: if we are already connected, running the # command should be fast, and if we aren't connected, this should be able to tell fast # (in other words, it shouldn't wait for 2FA input or similar). -@pytest.mark.timeout(1, func_only=True) -@pytest.mark.parametrize("also_run_command_to_check", [False, True]) -def test_is_already_logged_in( - cluster: str, already_logged_in: bool, also_run_command_to_check: bool -): - assert ( - is_already_logged_in( - cluster, also_run_command_to_check=also_run_command_to_check +@pytest.mark.timeout(5, func_only=True) +@pytest.mark.asyncio +async def test_is_already_logged_in(cluster: str, already_logged_in: bool): + if is_already_logged_in(cluster): + remote = await RemoteV2.connect(cluster) + assert remote.control_path.exists() + assert (await remote.get_output_async("echo OK")) == "OK" + else: + # Can't really check all that much here. + control_path = get_controlpath_for( + cluster, ssh_config_path=SSH_CONFIG_FILE, ssh_cache_dir=SSH_CACHE_DIR ) - == already_logged_in - == get_controlpath_for(cluster).exists() - ) + assert not control_path.exists() def test_controlsocket_is_running(cluster: str, already_logged_in: bool):