Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

copy id_rsa wsl to windows #135

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions milatools/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,16 +519,17 @@ def init():

ssh_config = setup_ssh_config()

success = setup_passwordless_ssh_access(ssh_config=ssh_config)
if not success:
exit()

# if we're running on WSL, we actually just copy the id_rsa + id_rsa.pub and the
# ~/.ssh/config to the Windows ssh directory (taking care to remove the
# ControlMaster-related entries) so that the user doesn't need to install Python on
# the Windows side.
if running_inside_WSL():
setup_windows_ssh_config_from_wsl(linux_ssh_config=ssh_config)

success = setup_passwordless_ssh_access(ssh_config=ssh_config)
if not success:
exit()
setup_keys_on_login_node()
setup_vscode_settings()
print_welcome_message()
Expand Down
17 changes: 13 additions & 4 deletions milatools/cli/init_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,12 +209,13 @@ def setup_windows_ssh_config_from_wsl(linux_ssh_config: SSHConfig):
windows_ssh_config.save()
else:
print(f"Did not change ssh config at path {windows_ssh_config.path}")
return # also skip copying the SSH keys.

# if running inside WSL, copy the keys to the Windows folder.
# Copy the SSH key to the windows folder so that passwordless SSH also works on
# Windows.
# TODO: This will need to change if we support using a non-default location at some
# point.
assert running_inside_WSL()
# TODO: Get the key to copy from the SSH config instead of assuming ~/.ssh/id_rsa.
windows_home = get_windows_home_path_in_wsl()
linux_private_key_file = Path.home() / ".ssh/id_rsa"
windows_private_key_file = windows_home / ".ssh/id_rsa"

Expand Down Expand Up @@ -254,6 +255,7 @@ def setup_passwordless_ssh_access(ssh_config: SSHConfig) -> bool:
# TODO: This uses the public key set in the SSH config file, which may (or may not)
# be the random id*.pub file that was just checked for above.
success = setup_passwordless_ssh_access_to_cluster("mila")

if not success:
return False
setup_keys_on_login_node("mila")
Expand Down Expand Up @@ -427,14 +429,21 @@ def print_welcome_message():


def _copy_if_needed(linux_key_file: Path, windows_key_file: Path):
if linux_key_file.exists() and not windows_key_file.exists():
assert linux_key_file.exists()
if not windows_key_file.exists():
print(
f"Copying {linux_key_file} over to the Windows ssh folder at "
f"{windows_key_file}."
)
shutil.copy2(src=linux_key_file, dst=windows_key_file)
return

print(
f"{windows_key_file} already exists. Not overwriting it with contents of {linux_key_file}."
)


@functools.lru_cache
def get_windows_home_path_in_wsl() -> Path:
assert running_inside_WSL()
windows_username = subprocess.getoutput("powershell.exe '$env:UserName'").strip()
Expand Down
23 changes: 0 additions & 23 deletions tests/cli/test_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,33 +10,10 @@

import milatools.cli.code
import milatools.cli.utils
from milatools.cli.utils import running_inside_WSL
from milatools.utils.compute_node import ComputeNode
from milatools.utils.local_v2 import LocalV2


@pytest.fixture
def pretend_to_be_in_WSL(
request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch
):
# By default, pretend to be in WSL. Indirect parametrization can be used to
# overwrite this value for a given test (as is done below).
in_wsl = getattr(request, "param", True)

_mock_running_inside_WSL = Mock(spec=running_inside_WSL, return_value=in_wsl)
monkeypatch.setattr(
milatools.cli.utils,
running_inside_WSL.__name__, # type: ignore
_mock_running_inside_WSL,
)
monkeypatch.setattr(
milatools.cli.code,
running_inside_WSL.__name__, # type: ignore
_mock_running_inside_WSL,
)
return in_wsl


@pytest.mark.parametrize("pretend_to_be_in_WSL", [True, False], indirect=True)
@pytest.mark.asyncio
async def test_code_from_WSL(
Expand Down
112 changes: 77 additions & 35 deletions tests/cli/test_init_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -832,28 +832,17 @@ def linux_ssh_config(

@pytest.mark.parametrize("accept_changes", [True, False], ids=["accept", "reject"])
def test_setup_windows_ssh_config_from_wsl(
tmp_path: Path,
pretend_to_be_in_WSL, # here even if `windows_home` already uses it (more explicit)
windows_home: Path,
linux_ssh_config: SSHConfig,
input_pipe: PipeInput,
file_regression: FileRegressionFixture,
monkeypatch: pytest.MonkeyPatch,
fake_linux_ssh_keypair: tuple[Path, Path], # add this fixture so the keys exist.
accept_changes: bool,
):
initial_contents = linux_ssh_config.cfg.config()
windows_home = tmp_path / "fake_windows_home"
windows_home.mkdir(exist_ok=False)
windows_ssh_config_path = windows_home / ".ssh" / "config"

monkeypatch.setattr(
init_command,
running_inside_WSL.__name__,
Mock(spec=running_inside_WSL, return_value=True),
)
monkeypatch.setattr(
init_command,
get_windows_home_path_in_wsl.__name__,
Mock(spec=get_windows_home_path_in_wsl, return_value=windows_home),
)
user_inputs: list[str] = []
if not windows_ssh_config_path.exists():
# We accept creating the Windows SSH config file for now.
Expand Down Expand Up @@ -900,6 +889,43 @@ def test_setup_windows_ssh_config_from_wsl(
file_regression.check(expected_text, extension=".md")


@pytest.fixture
def windows_ssh_config(
linux_ssh_config: SSHConfig,
windows_home: Path,
input_pipe: PipeInput,
monkeypatch: pytest.MonkeyPatch,
) -> SSHConfig:
"""Returns the Windows ssh config as it would be when we create it from WSL."""
windows_ssh_config_path = windows_home / ".ssh" / "config"
monkeypatch.setattr(
init_command,
running_inside_WSL.__name__, # type: ignore
Mock(spec=running_inside_WSL, return_value=True),
)
monkeypatch.setattr(
init_command,
get_windows_home_path_in_wsl.__name__, # type: ignore
Mock(spec=get_windows_home_path_in_wsl, return_value=windows_home),
)
user_inputs: list[str] = []
if not windows_ssh_config_path.exists():
# We accept creating the Windows SSH config file for now.
user_inputs.append("y")
user_inputs.append("y") # accept changes.

for prompt in user_inputs:
input_pipe.send_text(prompt)

setup_windows_ssh_config_from_wsl(linux_ssh_config=linux_ssh_config)

assert windows_ssh_config_path.exists()
assert windows_ssh_config_path.stat().st_mode & 0o777 == 0o600
assert windows_ssh_config_path.parent.stat().st_mode & 0o777 == 0o700

return SSHConfig(windows_ssh_config_path)


@xfails_on_windows(
raises=AssertionError, reason="TODO: buggy test: getting assert None is not None."
)
Expand Down Expand Up @@ -992,52 +1018,68 @@ def test_setup_vscode_settings(
file_regression.check(expected_text, extension=".md")


def test_setup_windows_ssh_config_from_wsl_copies_keys(
tmp_path: Path,
linux_ssh_config: SSHConfig,
input_pipe: PipeInput,
monkeypatch: pytest.MonkeyPatch,
):
@pytest.fixture
def linux_home(tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
"""Creates a fake home directory where we will make a fake SSH directory for
tests."""
linux_home = tmp_path / "fake_linux_home"
linux_home.mkdir(exist_ok=False)
windows_home = tmp_path / "fake_windows_home"
windows_home.mkdir(exist_ok=False)

monkeypatch.setattr(Path, "home", Mock(spec=Path.home, return_value=linux_home))
return linux_home

monkeypatch.setattr(
init_command,
running_inside_WSL.__name__,
Mock(spec=running_inside_WSL, return_value=True),
)
monkeypatch.setattr(
init_command,
get_windows_home_path_in_wsl.__name__,
Mock(spec=get_windows_home_path_in_wsl, return_value=windows_home),
)

@pytest.fixture
def fake_linux_ssh_keypair(linux_home: Path):
"""Creates a fake ssh key pair in some mock ssh directory.

Used in tests related to mila init and WSL.
"""

fake_linux_ssh_dir = linux_home / ".ssh"
fake_linux_ssh_dir.mkdir(mode=0o700)

private_key_text = "THIS IS A PRIVATE KEY"
linux_private_key_path = fake_linux_ssh_dir / "id_rsa"
linux_private_key_path.write_text(private_key_text)
linux_private_key_path.chmod(mode=0o600)

public_key_text = "THIS IS A PUBLIC KEY"
linux_public_key_path = linux_private_key_path.with_suffix(".pub")
linux_public_key_path.write_text(public_key_text)
linux_public_key_path.chmod(mode=0o600)

return linux_public_key_path, linux_private_key_path


def test_setup_windows_ssh_config_from_wsl_copies_keys(
linux_ssh_config: SSHConfig,
input_pipe: PipeInput,
windows_home: Path,
linux_home: Path,
fake_linux_ssh_keypair: tuple[Path, Path],
):
linux_public_key_path, linux_private_key_path = fake_linux_ssh_keypair

input_pipe.send_text("y") # accept creating the Windows config file
input_pipe.send_text("y") # accept the changes

setup_windows_ssh_config_from_wsl(linux_ssh_config=linux_ssh_config)

windows_private_key_path = windows_home / ".ssh" / "id_rsa"
windows_private_key_path = windows_home / linux_private_key_path.relative_to(
linux_home
)
windows_public_key_path = windows_private_key_path.with_suffix(".pub")

# TODO: Check that the copied key has the correct permissions (and content) on **WINDOWS**.
assert windows_private_key_path.exists()
assert windows_private_key_path.read_text() == private_key_text
assert windows_private_key_path.stat().st_mode & 0o777 == 0o600
assert windows_public_key_path.exists()
assert windows_public_key_path.read_text() == public_key_text
assert windows_public_key_path.stat().st_mode & 0o777 == 0o600
# todo: Might have to manually add the weird CRLF line endings to the public/private
# key file?
assert windows_private_key_path.read_text() == linux_private_key_path.read_text()
assert windows_public_key_path.read_text() == linux_public_key_path.read_text()


BACKUP_SSH_DIR = Path.home() / ".ssh_backup"
Expand Down
46 changes: 44 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,16 @@

import milatools.cli.code
import milatools.cli.commands
import milatools.cli.init_command
import milatools.cli.utils
import milatools.utils.compute_node
import milatools.utils.disk_quota
import milatools.utils.local_v2
import milatools.utils.parallel_progress
import milatools.utils.remote_v2
from milatools.cli import console
from milatools.cli.init_command import setup_ssh_config
from milatools.cli.utils import SSH_CONFIG_FILE
from milatools.cli.init_command import get_windows_home_path_in_wsl, setup_ssh_config
from milatools.cli.utils import SSH_CONFIG_FILE, running_inside_WSL
from milatools.utils.compute_node import get_queued_milatools_job_ids
from milatools.utils.remote_v1 import RemoteV1
from milatools.utils.remote_v2 import (
Expand Down Expand Up @@ -436,3 +438,43 @@ def _mock_text(message: str, *args, **kwargs):
setup_ssh_config(ssh_config_path)
assert ssh_config_path.exists()
return ssh_config_path


@pytest.fixture
def pretend_to_be_in_WSL(
request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch
):
# By default, pretend to be in WSL. Indirect parametrization can be used to
# overwrite this value for a given test (as is done below).
in_wsl = getattr(request, "param", True)
_mock_running_inside_WSL = Mock(spec=running_inside_WSL, return_value=in_wsl)
monkeypatch.setattr(
milatools.cli.utils, # defined here
running_inside_WSL.__name__, # type: ignore
_mock_running_inside_WSL,
)
# Unfortunately we have to also patch this everywhere we import it in other modules.
for place_that_imports_it in [
milatools.cli.code,
milatools.cli.init_command,
milatools.cli.commands,
]:
monkeypatch.setattr(
place_that_imports_it,
running_inside_WSL.__name__, # type: ignore
_mock_running_inside_WSL,
)

return in_wsl


@pytest.fixture
def windows_home(pretend_to_be_in_WSL, tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
windows_home = tmp_path / "fake_windows_home"
windows_home.mkdir(exist_ok=False)
monkeypatch.setattr(
milatools.cli.init_command,
get_windows_home_path_in_wsl.__name__, # type: ignore
Mock(spec=get_windows_home_path_in_wsl, return_value=windows_home),
)
return windows_home
Loading