-
Notifications
You must be signed in to change notification settings - Fork 13
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix the mila init
and mila code
commands on windows & WSL
#65
Merged
Merged
Changes from 24 commits
Commits
Show all changes
26 commits
Select commit
Hold shift + click to select a range
a5ecdcc
Dont add ControlMaster ssh entries on Windows
lebrice b5f6ce9
Remove the unused params of _add_ssh_entry
lebrice e43a4fb
Rebase on Master
lebrice 05b2b9a
Fix isort issues
lebrice c6978af
Fix the `mila init` command on windows
lebrice fb30bb8
Fix black formatting errors
lebrice 126a46b
Fix the equivalent of ssh-copy-id, now works!
lebrice 557c9ac
Attempt to fix running `mila code` from WSL
lebrice ddd7ad0
Fix whitespace errors and run `code` directly
lebrice 7f0b0c4
Set the remote.SSH settings in User Settings Json
lebrice 440d1e3
Show that VsCode settings are found
lebrice 1838323
Remove the VsCode settings thing for now
lebrice 0b92c8f
Use `here.run` to run the command
lebrice a0ad53f
Skip the hostname checking on Windows machines
lebrice e881705
Fix black formatting issue
lebrice 477997d
Fix missing condition in WSL check
lebrice 9b5fd37
Warn WSL users if `mila init` not done on Windows
lebrice cbe4d74
Also setup Windows SSH config from WSL
lebrice 611b2b6
Fix isort issues
lebrice f687847
Reuse existing fn to setup windows SSH file
lebrice 62bd443
Preserve ordering of SSH entries when copying
lebrice a2f4b86
Copy both public and private key files
lebrice 90b85b8
Add tests for the Windows SSH setup from WSL
lebrice dd87787
Remove need for pytest-mock dependency
lebrice 531b89a
Fix a bug in the added tests
lebrice f073c9e
Remove accidentally-added .pre-commit-config.yaml
lebrice File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,19 +1,30 @@ | ||
from __future__ import annotations | ||
|
||
import difflib | ||
import os | ||
import shutil | ||
import subprocess | ||
import sys | ||
import warnings | ||
from logging import getLogger as get_logger | ||
from pathlib import Path | ||
from typing import Any | ||
|
||
import questionary as qn | ||
|
||
from .utils import SSHConfig, T, yn | ||
from .utils import SSHConfig, T, running_inside_WSL, yn | ||
|
||
WINDOWS_UNSUPPORTED_KEYS = ["ControlMaster", "ControlPath", "ControlPersist"] | ||
|
||
logger = get_logger(__name__) | ||
|
||
HOSTS = ["mila", "mila-cpu", "*.server.mila.quebec !*login.server.mila.quebec"] | ||
"""List of host entries that get added to the SSH configurtion by `mila init`.""" | ||
|
||
|
||
def setup_ssh_config( | ||
ssh_config_path: str | Path = "~/.ssh/config", | ||
): | ||
) -> SSHConfig: | ||
"""Interactively sets up some useful entries in the ~/.ssh/config file on the local machine. | ||
|
||
Exits if the User cancels any of the prompts or doesn't confirm the changes when asked. | ||
|
@@ -27,6 +38,9 @@ def setup_ssh_config( | |
directly to compute nodes. | ||
|
||
TODO: Also ask if we should add entries for the ComputeCanada/DRAC clusters. | ||
|
||
Returns: | ||
The resulting SSHConfig if the changes are approved. | ||
""" | ||
|
||
ssh_config_path = _setup_ssh_config_file(ssh_config_path) | ||
|
@@ -39,21 +53,28 @@ def setup_ssh_config( | |
# sure that the directory actually exists. | ||
control_path_dir.expanduser().mkdir(exist_ok=True, parents=True) | ||
|
||
if sys.platform == "win32": | ||
ssh_multiplexing_config = {} | ||
else: | ||
ssh_multiplexing_config = { | ||
# Tries to reuse an existing connection, but if it fails, it will create a new one. | ||
"ControlMaster": "auto", | ||
# This makes a file per connection, like [email protected]:2222 | ||
"ControlPath": str(control_path_dir / r"%r@%h:%p"), | ||
# persist for 10 minutes after the last connection ends. | ||
"ControlPersist": 600, | ||
} | ||
|
||
_add_ssh_entry( | ||
ssh_config, | ||
"mila", | ||
host="mila", | ||
HostName="login.server.mila.quebec", | ||
User=username, | ||
PreferredAuthentications="publickey,keyboard-interactive", | ||
Port=2222, | ||
ServerAliveInterval=120, | ||
ServerAliveCountMax=5, | ||
# Tries to reuse an existing connection, but if it fails, it will create a new one. | ||
ControlMaster="auto", | ||
# This makes a file per connection, like [email protected]:2222 | ||
ControlPath=str(control_path_dir / r"%r@%h:%p"), | ||
# persist for 10 minutes after the last connection ends. | ||
ControlPersist=600, | ||
**ssh_multiplexing_config, | ||
) | ||
|
||
_add_ssh_entry( | ||
|
@@ -97,12 +118,7 @@ def setup_ssh_config( | |
HostName="%h", | ||
User=username, | ||
ProxyJump="mila", | ||
# Tries to reuse an existing connection, but if it fails, it will create a new one. | ||
ControlMaster="auto", | ||
# This makes a file per connection, like [email protected]:2222 | ||
ControlPath=str(control_path_dir / r"%r@%h:%p"), | ||
# persist for 10 minutes after the last connection ends. | ||
ControlPersist=600, | ||
**ssh_multiplexing_config, | ||
) | ||
|
||
new_config = ssh_config.cfg.config() | ||
|
@@ -113,6 +129,62 @@ def setup_ssh_config( | |
else: | ||
ssh_config.save() | ||
print(f"Wrote {ssh_config_path}") | ||
return ssh_config | ||
|
||
|
||
def setup_windows_ssh_config_from_wsl(linux_ssh_config: SSHConfig): | ||
"""Setup the Windows SSH configuration and public key from within WSL. | ||
|
||
This copies over the entries from the linux ssh configuration file, except for the | ||
values that aren't supported on Windows (e.g. "ControlMaster"). | ||
|
||
This also copies the public key file from the linux SSH directory over to the | ||
Windows SSH directory if it isn't already present. | ||
|
||
This makes it so the user doesn't need to install Python/Anaconda on the Windows | ||
side in order to use `mila code` from within WSL. | ||
""" | ||
assert running_inside_WSL() | ||
# NOTE: This also assumes that a public/private key pair has already been generated | ||
# at ~/.ssh/id_rsa.pub and ~/.ssh/id_rsa. | ||
windows_home = get_windows_home_path_in_wsl() | ||
windows_ssh_config_path = windows_home / ".ssh/config" | ||
windows_ssh_config_path = _setup_ssh_config_file(windows_ssh_config_path) | ||
|
||
_copy_valid_ssh_entries_to_windows_ssh_config_file( | ||
linux_ssh_config, windows_ssh_config_path | ||
) | ||
|
||
# copy the public_key_to_windows_ssh_folder | ||
# TODO: This might be different if the user selects a non-default location during | ||
# `ssh-keygen`. | ||
|
||
linux_private_key_file = Path.home() / ".ssh/id_rsa" | ||
windows_private_key_file = windows_home / ".ssh/id_rsa" | ||
|
||
for linux_key_file, windows_key_file in [ | ||
(linux_private_key_file, windows_private_key_file), | ||
( | ||
linux_private_key_file.with_suffix(".pub"), | ||
windows_private_key_file.with_suffix(".pub"), | ||
), | ||
]: | ||
_copy_if_needed(linux_key_file, windows_key_file) | ||
|
||
|
||
def _copy_if_needed(linux_key_file: Path, windows_key_file: Path): | ||
if linux_key_file.exists() and not windows_key_file.exists(): | ||
print( | ||
f"Copying {linux_key_file} over to the Windows ssh folder at " | ||
f"{windows_key_file}." | ||
) | ||
shutil.copy2(src=linux_key_file, dst=windows_key_file) | ||
|
||
|
||
def get_windows_home_path_in_wsl() -> Path: | ||
assert running_inside_WSL() | ||
windows_username = subprocess.getoutput("powershell.exe '$env:UserName'").strip() | ||
return Path(f"/mnt/c/Users/{windows_username}") | ||
|
||
|
||
def _setup_ssh_config_file(config_file_path: str | Path) -> Path: | ||
|
@@ -147,7 +219,12 @@ def _setup_ssh_config_file(config_file_path: str | Path) -> Path: | |
|
||
|
||
def _confirm_changes(ssh_config: SSHConfig, previous: str) -> bool: | ||
print(T.bold("The following modifications will be made to your ~/.ssh/config:\n")) | ||
print( | ||
T.bold( | ||
f"The following modifications will be made to your SSH config file at " | ||
f"{ssh_config.path}:\n" | ||
) | ||
) | ||
diff_lines = list( | ||
difflib.unified_diff( | ||
(previous + "\n").splitlines(True), | ||
|
@@ -206,6 +283,7 @@ def _add_ssh_entry( | |
ssh_config: SSHConfig, | ||
host: str, | ||
Host: str | None = None, | ||
*, | ||
_space_before: bool = True, | ||
_space_after: bool = False, | ||
**entry, | ||
|
@@ -223,12 +301,56 @@ def _add_ssh_entry( | |
existing_entry = ssh_config.host(host) | ||
existing_entry.update(entry) | ||
ssh_config.cfg.set(host, **existing_entry) | ||
logger.debug(f"Updated {host} entry in ssh config.") | ||
logger.debug(f"Updated {host} entry in ssh config at path {ssh_config.path}.") | ||
else: | ||
ssh_config.add( | ||
host, | ||
_space_before=_space_before, | ||
_space_after=_space_after, | ||
**entry, | ||
) | ||
logger.debug(f"Adding new {host} entry in ssh config.") | ||
logger.debug( | ||
f"Adding new {host} entry in ssh config at path {ssh_config.path}." | ||
) | ||
|
||
|
||
def _copy_valid_ssh_entries_to_windows_ssh_config_file( | ||
linux_ssh_config: SSHConfig, windows_ssh_config_path: Path | ||
): | ||
windows_ssh_config = SSHConfig(windows_ssh_config_path) | ||
initial_windows_config_contents = windows_ssh_config.cfg.config() | ||
|
||
unsupported_keys_lowercase = set(k.lower() for k in WINDOWS_UNSUPPORTED_KEYS) | ||
|
||
# NOTE: need to preserve the ordering of entries: | ||
for host in HOSTS + [ | ||
host for host in linux_ssh_config.hosts() if host not in HOSTS | ||
]: | ||
if host not in linux_ssh_config.hosts(): | ||
warnings.warn( | ||
RuntimeWarning( | ||
f"Weird, we expected to have a {host!r} entry in the SSH config..." | ||
) | ||
) | ||
continue | ||
linux_ssh_entry: dict[str, Any] = linux_ssh_config.host(host) | ||
_add_ssh_entry( | ||
windows_ssh_config, | ||
host, | ||
**{ | ||
key: value | ||
for key, value in linux_ssh_entry.items() | ||
if key.lower() not in unsupported_keys_lowercase | ||
}, | ||
) | ||
|
||
new_config_contents = windows_ssh_config.cfg.config() | ||
if new_config_contents == initial_windows_config_contents: | ||
print(f"Did not change ssh config at path {windows_ssh_config.path}") | ||
return | ||
if not _confirm_changes( | ||
windows_ssh_config, previous=initial_windows_config_contents | ||
): | ||
exit() | ||
# We made changes and they were accepted. | ||
windows_ssh_config.save() |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the
*
is doing? Is it to grab all extra args? Should an exception be thrown instead?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's making it so
_space_before
and_space_after
cannot be passed as positional arguments. This helps remove a typing error when calls to_add_ssh_entry
are made with different **kwargs (Windows vs non-Windows)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for reference:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ah get it thanks
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we, should we assert that nothing is in this
*
?Only thinking out loud here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Exactly! That's what using a
*
without a name does