-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Dont add ControlMaster ssh entries on Windows Signed-off-by: Fabrice Normandin <[email protected]> * Remove the unused params of _add_ssh_entry Signed-off-by: Fabrice Normandin <[email protected]> * Rebase on Master Signed-off-by: Fabrice Normandin <[email protected]> * Fix isort issues Signed-off-by: Fabrice Normandin <[email protected]> * Fix the `mila init` command on windows Fixes #63 Signed-off-by: Fabrice Normandin <[email protected]> * Fix black formatting errors Signed-off-by: Fabrice Normandin <[email protected]> * Fix the equivalent of ssh-copy-id, now works! Signed-off-by: Fabrice Normandin <[email protected]> * Attempt to fix running `mila code` from WSL Signed-off-by: Fabrice <[email protected]> * Fix whitespace errors and run `code` directly Signed-off-by: Fabrice Normandin <[email protected]> * Set the remote.SSH settings in User Settings Json Signed-off-by: Fabrice Normandin <[email protected]> * Show that VsCode settings are found Signed-off-by: Fabrice Normandin <[email protected]> * Remove the VsCode settings thing for now Signed-off-by: Fabrice Normandin <[email protected]> * Use `here.run` to run the command Signed-off-by: Fabrice <[email protected]> * Skip the hostname checking on Windows machines Signed-off-by: Fabrice <[email protected]> * Fix black formatting issue Signed-off-by: Fabrice Normandin <[email protected]> * Fix missing condition in WSL check Signed-off-by: Fabrice Normandin <[email protected]> * Warn WSL users if `mila init` not done on Windows Signed-off-by: Fabrice Normandin <[email protected]> * Also setup Windows SSH config from WSL Signed-off-by: Fabrice Normandin <[email protected]> * Fix isort issues Signed-off-by: Fabrice Normandin <[email protected]> * Reuse existing fn to setup windows SSH file Signed-off-by: Fabrice Normandin <[email protected]> * Preserve ordering of SSH entries when copying Signed-off-by: Fabrice Normandin <[email protected]> * Copy both public and private key files Signed-off-by: Fabrice Normandin <[email protected]> * Add tests for the Windows SSH setup from WSL Signed-off-by: Fabrice Normandin <[email protected]> * Remove need for pytest-mock dependency Signed-off-by: Fabrice Normandin <[email protected]> * Fix a bug in the added tests Signed-off-by: Fabrice Normandin <[email protected]> * Remove accidentally-added .pre-commit-config.yaml Signed-off-by: Fabrice Normandin <[email protected]> --------- Signed-off-by: Fabrice Normandin <[email protected]> Signed-off-by: Fabrice <[email protected]>
- Loading branch information
Showing
6 changed files
with
449 additions
and
31 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,19 +1,29 @@ | ||
from __future__ import annotations | ||
|
||
import difflib | ||
import shutil | ||
import subprocess | ||
import sys | ||
import warnings | ||
from logging import getLogger as get_logger | ||
from pathlib import Path | ||
from typing import Any | ||
|
||
import questionary as qn | ||
|
||
from .utils import SSHConfig, T, yn | ||
from .utils import SSHConfig, T, running_inside_WSL, yn | ||
|
||
WINDOWS_UNSUPPORTED_KEYS = ["ControlMaster", "ControlPath", "ControlPersist"] | ||
|
||
logger = get_logger(__name__) | ||
|
||
HOSTS = ["mila", "mila-cpu", "*.server.mila.quebec !*login.server.mila.quebec"] | ||
"""List of host entries that get added to the SSH configurtion by `mila init`.""" | ||
|
||
|
||
def setup_ssh_config( | ||
ssh_config_path: str | Path = "~/.ssh/config", | ||
): | ||
) -> SSHConfig: | ||
"""Interactively sets up some useful entries in the ~/.ssh/config file on the local machine. | ||
Exits if the User cancels any of the prompts or doesn't confirm the changes when asked. | ||
|
@@ -27,6 +37,9 @@ def setup_ssh_config( | |
directly to compute nodes. | ||
TODO: Also ask if we should add entries for the ComputeCanada/DRAC clusters. | ||
Returns: | ||
The resulting SSHConfig if the changes are approved. | ||
""" | ||
|
||
ssh_config_path = _setup_ssh_config_file(ssh_config_path) | ||
|
@@ -39,21 +52,28 @@ def setup_ssh_config( | |
# sure that the directory actually exists. | ||
control_path_dir.expanduser().mkdir(exist_ok=True, parents=True) | ||
|
||
if sys.platform == "win32": | ||
ssh_multiplexing_config = {} | ||
else: | ||
ssh_multiplexing_config = { | ||
# Tries to reuse an existing connection, but if it fails, it will create a new one. | ||
"ControlMaster": "auto", | ||
# This makes a file per connection, like [email protected]:2222 | ||
"ControlPath": str(control_path_dir / r"%r@%h:%p"), | ||
# persist for 10 minutes after the last connection ends. | ||
"ControlPersist": 600, | ||
} | ||
|
||
_add_ssh_entry( | ||
ssh_config, | ||
"mila", | ||
host="mila", | ||
HostName="login.server.mila.quebec", | ||
User=username, | ||
PreferredAuthentications="publickey,keyboard-interactive", | ||
Port=2222, | ||
ServerAliveInterval=120, | ||
ServerAliveCountMax=5, | ||
# Tries to reuse an existing connection, but if it fails, it will create a new one. | ||
ControlMaster="auto", | ||
# This makes a file per connection, like [email protected]:2222 | ||
ControlPath=str(control_path_dir / r"%r@%h:%p"), | ||
# persist for 10 minutes after the last connection ends. | ||
ControlPersist=600, | ||
**ssh_multiplexing_config, | ||
) | ||
|
||
_add_ssh_entry( | ||
|
@@ -97,12 +117,7 @@ def setup_ssh_config( | |
HostName="%h", | ||
User=username, | ||
ProxyJump="mila", | ||
# Tries to reuse an existing connection, but if it fails, it will create a new one. | ||
ControlMaster="auto", | ||
# This makes a file per connection, like [email protected]:2222 | ||
ControlPath=str(control_path_dir / r"%r@%h:%p"), | ||
# persist for 10 minutes after the last connection ends. | ||
ControlPersist=600, | ||
**ssh_multiplexing_config, | ||
) | ||
|
||
new_config = ssh_config.cfg.config() | ||
|
@@ -113,6 +128,76 @@ def setup_ssh_config( | |
else: | ||
ssh_config.save() | ||
print(f"Wrote {ssh_config_path}") | ||
return ssh_config | ||
|
||
|
||
def setup_windows_ssh_config_from_wsl(linux_ssh_config: SSHConfig): | ||
"""Setup the Windows SSH configuration and public key from within WSL. | ||
This copies over the entries from the linux ssh configuration file, except for the | ||
values that aren't supported on Windows (e.g. "ControlMaster"). | ||
This also copies the public key file from the linux SSH directory over to the | ||
Windows SSH directory if it isn't already present. | ||
This makes it so the user doesn't need to install Python/Anaconda on the Windows | ||
side in order to use `mila code` from within WSL. | ||
""" | ||
assert running_inside_WSL() | ||
# NOTE: This also assumes that a public/private key pair has already been generated | ||
# at ~/.ssh/id_rsa.pub and ~/.ssh/id_rsa. | ||
windows_home = get_windows_home_path_in_wsl() | ||
windows_ssh_config_path = windows_home / ".ssh/config" | ||
windows_ssh_config_path = _setup_ssh_config_file(windows_ssh_config_path) | ||
|
||
windows_ssh_config = SSHConfig(windows_ssh_config_path) | ||
|
||
initial_windows_config_contents = windows_ssh_config.cfg.config() | ||
_copy_valid_ssh_entries_to_windows_ssh_config_file( | ||
linux_ssh_config, windows_ssh_config | ||
) | ||
new_windows_config_contents = windows_ssh_config.cfg.config() | ||
|
||
if ( | ||
new_windows_config_contents != initial_windows_config_contents | ||
and _confirm_changes(windows_ssh_config, initial_windows_config_contents) | ||
): | ||
# We made changes and they were accepted. | ||
windows_ssh_config.save() | ||
else: | ||
print(f"Did not change ssh config at path {windows_ssh_config.path}") | ||
return # also skip copying the SSH keys. | ||
|
||
# Copy the SSH key to the windows folder so that passwordless SSH also works on | ||
# Windows. | ||
# TODO: This will need to change if we support using a non-default location at some | ||
# point. | ||
linux_private_key_file = Path.home() / ".ssh/id_rsa" | ||
windows_private_key_file = windows_home / ".ssh/id_rsa" | ||
|
||
for linux_key_file, windows_key_file in [ | ||
(linux_private_key_file, windows_private_key_file), | ||
( | ||
linux_private_key_file.with_suffix(".pub"), | ||
windows_private_key_file.with_suffix(".pub"), | ||
), | ||
]: | ||
_copy_if_needed(linux_key_file, windows_key_file) | ||
|
||
|
||
def _copy_if_needed(linux_key_file: Path, windows_key_file: Path): | ||
if linux_key_file.exists() and not windows_key_file.exists(): | ||
print( | ||
f"Copying {linux_key_file} over to the Windows ssh folder at " | ||
f"{windows_key_file}." | ||
) | ||
shutil.copy2(src=linux_key_file, dst=windows_key_file) | ||
|
||
|
||
def get_windows_home_path_in_wsl() -> Path: | ||
assert running_inside_WSL() | ||
windows_username = subprocess.getoutput("powershell.exe '$env:UserName'").strip() | ||
return Path(f"/mnt/c/Users/{windows_username}") | ||
|
||
|
||
def _setup_ssh_config_file(config_file_path: str | Path) -> Path: | ||
|
@@ -147,7 +232,12 @@ def _setup_ssh_config_file(config_file_path: str | Path) -> Path: | |
|
||
|
||
def _confirm_changes(ssh_config: SSHConfig, previous: str) -> bool: | ||
print(T.bold("The following modifications will be made to your ~/.ssh/config:\n")) | ||
print( | ||
T.bold( | ||
f"The following modifications will be made to your SSH config file at " | ||
f"{ssh_config.path}:\n" | ||
) | ||
) | ||
diff_lines = list( | ||
difflib.unified_diff( | ||
(previous + "\n").splitlines(True), | ||
|
@@ -206,6 +296,7 @@ def _add_ssh_entry( | |
ssh_config: SSHConfig, | ||
host: str, | ||
Host: str | None = None, | ||
*, | ||
_space_before: bool = True, | ||
_space_after: bool = False, | ||
**entry, | ||
|
@@ -223,12 +314,42 @@ def _add_ssh_entry( | |
existing_entry = ssh_config.host(host) | ||
existing_entry.update(entry) | ||
ssh_config.cfg.set(host, **existing_entry) | ||
logger.debug(f"Updated {host} entry in ssh config.") | ||
logger.debug(f"Updated {host} entry in ssh config at path {ssh_config.path}.") | ||
else: | ||
ssh_config.add( | ||
host, | ||
_space_before=_space_before, | ||
_space_after=_space_after, | ||
**entry, | ||
) | ||
logger.debug(f"Adding new {host} entry in ssh config.") | ||
logger.debug( | ||
f"Adding new {host} entry in ssh config at path {ssh_config.path}." | ||
) | ||
|
||
|
||
def _copy_valid_ssh_entries_to_windows_ssh_config_file( | ||
linux_ssh_config: SSHConfig, windows_ssh_config: SSHConfig | ||
): | ||
unsupported_keys_lowercase = set(k.lower() for k in WINDOWS_UNSUPPORTED_KEYS) | ||
|
||
# NOTE: need to preserve the ordering of entries: | ||
for host in HOSTS + [ | ||
host for host in linux_ssh_config.hosts() if host not in HOSTS | ||
]: | ||
if host not in linux_ssh_config.hosts(): | ||
warnings.warn( | ||
RuntimeWarning( | ||
f"Weird, we expected to have a {host!r} entry in the SSH config..." | ||
) | ||
) | ||
continue | ||
linux_ssh_entry: dict[str, Any] = linux_ssh_config.host(host) | ||
_add_ssh_entry( | ||
windows_ssh_config, | ||
host, | ||
**{ | ||
key: value | ||
for key, value in linux_ssh_entry.items() | ||
if key.lower() not in unsupported_keys_lowercase | ||
}, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.