diff --git a/milatools/cli/code.py b/milatools/cli/code.py index a2cc75f8..20268b4f 100644 --- a/milatools/cli/code.py +++ b/milatools/cli/code.py @@ -164,6 +164,8 @@ async def code( ) # todo: use the mila or the local machine as the reference for vscode # extensions? + # TODO: If the remote is a cluster that doesn't yet have `vscode-server`, we + # could launch vscode at the same time (or before) syncing the vscode extensions? sync_vscode_extensions_task = sync_vscode_extensions( LocalV2(), [login_node], diff --git a/milatools/cli/commands.py b/milatools/cli/commands.py index d08bad45..935fe979 100644 --- a/milatools/cli/commands.py +++ b/milatools/cli/commands.py @@ -34,9 +34,7 @@ from milatools.cli.common import find_allocation from milatools.utils.local_v1 import LocalV1 from milatools.utils.remote_v1 import RemoteV1 -from milatools.utils.vscode_utils import ( - sync_vscode_extensions_with_hostnames, -) +from milatools.utils.vscode_utils import sync_vscode_extensions from ..__version__ import __version__ from .code import add_mila_code_arguments @@ -226,7 +224,7 @@ def mila(): "extensions locally. Defaults to all the available SLURM clusters." ), ) - sync_vscode_parser.set_defaults(function=sync_vscode_extensions_with_hostnames) + sync_vscode_parser.set_defaults(function=sync_vscode_extensions) # ----- mila serve ------ diff --git a/milatools/utils/vscode_utils.py b/milatools/utils/vscode_utils.py index c4c30832..23790844 100644 --- a/milatools/utils/vscode_utils.py +++ b/milatools/utils/vscode_utils.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import functools import os import shutil @@ -11,7 +12,6 @@ from typing import Literal, Sequence from milatools.cli.utils import ( - CLUSTERS, CommandNotFoundError, batched, stripped_lines_of, @@ -23,7 +23,6 @@ TaskID, async_progress_bar, ) -from milatools.utils.remote_v1 import RemoteV1 from milatools.utils.remote_v2 import RemoteV2 logger = get_logger(__name__) @@ -77,113 +76,106 @@ def vscode_installed() -> bool: return True -async def sync_vscode_extensions_with_hostnames( - source: str, - destinations: list[str], -): - destinations = list(destinations) - if source in destinations: - if source == "mila" and destinations == CLUSTERS: - logger.info("Assuming you want to sync from mila to all DRAC/CC clusters.") - else: - logger.warning( - f"{source=} is also in the destinations to sync to. " f"Removing it." - ) - destinations.remove(source) - - if len(set(destinations)) != len(destinations): - raise ValueError(f"{destinations=} contains duplicate hostnames!") - - source_obj = LocalV2() if source == "localhost" else await RemoteV2.connect(source) - return await sync_vscode_extensions(source_obj, destinations) - - async def sync_vscode_extensions( source: str | LocalV2 | RemoteV2, destinations: Sequence[str | LocalV2 | RemoteV2], -): +) -> dict[str, ProgressDict]: """Syncs vscode extensions between `source` all all the clusters in `dest`. This spawns a thread for each cluster in `dest` and displays a parallel progress bar for the syncing of vscode extensions to each cluster. """ - if isinstance(source, LocalV2): - source_hostname = "localhost" - source_extensions = await get_local_vscode_extensions() - elif isinstance(source, RemoteV2): - source_hostname = source.hostname - code_server_executable = await find_code_server_executable( - source, remote_vscode_server_dir="~/.vscode-server" - ) - if not code_server_executable: - raise RuntimeError( - f"The vscode-server executable was not found on {source.hostname}." - ) - source_extensions = await get_remote_vscode_extensions( - source, code_server_executable - ) - else: - assert isinstance(source, str) - source_hostname = source - source = RemoteV2(source) + if isinstance(source, str): + if source == "localhost": + source = LocalV2() + else: + source = await RemoteV2.connect(source) + + destinations = _remove_source_from_destinations(source, destinations) + + if not destinations: + logger.info("No destinations to sync extensions to!") + return {} + + source_hostname, source_extensions = await get_vscode_extensions(source) task_hostnames: list[str] = [] tasks: list[AsyncTaskFn[ProgressDict]] = [] task_descriptions: list[str] = [] - for dest_remote in destinations: - dest_hostname: str - - if dest_remote == "localhost": - dest_hostname = dest_remote # type: ignore - dest_remote = LocalV2() # pickleable - elif isinstance(dest_remote, LocalV2): - dest_hostname = "localhost" - dest_remote = dest_remote # again, pickleable - elif isinstance(dest_remote, RemoteV2): - dest_hostname = dest_remote.hostname - dest_remote = dest_remote # pickleable - elif isinstance(dest_remote, RemoteV1): - # We unfortunately can't pass this kind of object to another process or - # thread because it uses `fabric.Connection` which don't appear to be - # pickleable. This means we will have to re-connect in the subprocess. - dest_hostname = dest_remote.hostname - dest_remote = None - else: - assert isinstance(dest_remote, str) - # The dest_remote is a hostname. Try to connect to it with a reusable SSH - # control socket so we can get the 2FA prompts out of the way in advance. - # NOTE: We could fallback to using the `Remote` class with paramiko inside - # the subprocess if this doesn't work, but it would suck because it messes - # up the UI, and you need to press 1 in the terminal to get the 2FA prompt, - # which screws up the progress bars. - dest_hostname = dest_remote - dest_remote = RemoteV2(hostname=dest_hostname) - + # Connect to the remotes in parallel. + dest_runners_and_hostnames = await asyncio.gather( + *(_get_runner_and_hostname(dest) for dest in destinations) + ) + for dest_runner, dest_hostname in dest_runners_and_hostnames: task_hostnames.append(dest_hostname) tasks.append( functools.partial( install_vscode_extensions_task_function, dest_hostname=dest_hostname, source_extensions=source_extensions, - remote=dest_remote, + remote=dest_runner, source_name=source_hostname, ) ) task_descriptions.append(f"{source_hostname} -> {dest_hostname}") - results: dict[str, ProgressDict] = {} + return { + hostname: result + for hostname, result in zip( + task_hostnames, + await async_progress_bar( + async_task_fns=tasks, + task_descriptions=task_descriptions, + overall_progress_task_description="[green]Syncing vscode extensions:", + ), + ) + } - for hostname, result in zip( - task_hostnames, - await async_progress_bar( - async_task_fns=tasks, - task_descriptions=task_descriptions, - overall_progress_task_description="[green]Syncing vscode extensions:", - ), - ): - results[hostname] = result - return results + +def _remove_source_from_destinations( + source: LocalV2 | RemoteV2, destinations: Sequence[str | LocalV2 | RemoteV2] +): + dest_hostnames = [ + dest if isinstance(dest, str) else dest.hostname for dest in destinations + ] + if source.hostname in dest_hostnames: + logger.debug(f"{source.hostname!r} is also in the destinations, removing it.") + destinations = list(destinations) + destinations.pop(dest_hostnames.index(source.hostname)) + + if len(set(dest_hostnames)) != len(dest_hostnames): + raise ValueError(f"{dest_hostnames=} contains duplicate hostnames!") + return destinations + + +async def _get_runner_and_hostname( + dest_remote: str | LocalV2 | RemoteV2, +) -> tuple[LocalV2 | RemoteV2, str]: + if isinstance(dest_remote, str): + if dest_remote == "localhost": + return LocalV2(), dest_remote + dest_hostname = dest_remote + dest_remote = await RemoteV2.connect(dest_hostname) + return dest_remote, dest_hostname + return dest_remote, dest_remote.hostname + + +async def get_vscode_extensions( + source: LocalV2 | RemoteV2, +) -> tuple[str, dict[str, str]]: + if isinstance(source, LocalV2): + code_server_executable = get_local_vscode_executable_path(code_command=None) + else: + code_server_executable = await find_code_server_executable( + source, remote_vscode_server_dir="~/.vscode-server" + ) + if not code_server_executable: + raise RuntimeError( + f"The vscode-server executable was not found on {source.hostname}." + ) + source_extensions = await _get_vscode_extensions(source, code_server_executable) + return source.hostname, source_extensions async def install_vscode_extensions_task_function( @@ -191,7 +183,7 @@ async def install_vscode_extensions_task_function( task_id: TaskID, dest_hostname: str | Literal["localhost"], source_extensions: dict[str, str], - remote: RemoteV2 | LocalV2 | None, + remote: RemoteV2 | LocalV2, source_name: str, verbose: bool = False, ) -> ProgressDict: @@ -208,8 +200,8 @@ async def install_vscode_extensions_task_function( def _update_progress( progress: int, status: str, total: int = len(source_extensions) ): - # Show progress to the parent process by setting an item in the task progress - # dict. + """Shows progress to the parent process by setting an item in the task progress + dict.""" progress_dict: ProgressDict = { "progress": progress, "total": total, @@ -218,7 +210,7 @@ def _update_progress( task_progress_dict[task_id] = progress_dict return progress_dict - if remote is None: + if not remote: if dest_hostname == "localhost": remote = LocalV2() else: @@ -228,7 +220,9 @@ def _update_progress( if isinstance(remote, LocalV2): assert dest_hostname == "localhost" code_server_executable = get_local_vscode_executable_path() - extensions_on_dest = await get_local_vscode_extensions() + extensions_on_dest = await _get_vscode_extensions( + remote, code_server_executable + ) else: dest_hostname = remote.hostname remote_vscode_server_dir = "~/.vscode-server" @@ -250,7 +244,7 @@ def _update_progress( status="code-server executable not found!", ) _update_progress(0, status="fetching installed extensions...") - extensions_on_dest = await get_remote_vscode_extensions( + extensions_on_dest = await _get_vscode_extensions( remote, code_server_executable ) @@ -317,35 +311,32 @@ async def install_vscode_extension( return result -async def get_local_vscode_extensions( +async def _get_local_vscode_extensions( code_command: str | None = None, ) -> dict[str, str]: - output = await LocalV2.get_output_async( - ( - get_local_vscode_executable_path(code_command=code_command), - "--list-extensions", - "--show-versions", + return await _get_vscode_extensions( + LocalV2(), + code_server_executable=get_local_vscode_executable_path( + code_command=code_command ), ) - return parse_vscode_extensions_versions(stripped_lines_of(output)) -async def get_remote_vscode_extensions( - remote: RemoteV2, - remote_code_server_executable: str, +async def _get_vscode_extensions( + remote: RemoteV2 | LocalV2, + code_server_executable: str, ) -> dict[str, str]: """Returns the list of isntalled extensions and the path to the code-server executable.""" - remote_extensions = parse_vscode_extensions_versions( + return parse_vscode_extensions_versions( stripped_lines_of( await remote.get_output_async( - f"{remote_code_server_executable} --list-extensions --show-versions", + f"{code_server_executable} --list-extensions --show-versions", display=False, hide=True, ) ) ) - return remote_extensions def extensions_to_install( diff --git a/tests/utils/test_vscode_utils.py b/tests/utils/test_vscode_utils.py index 6ac7f419..a479a4e9 100644 --- a/tests/utils/test_vscode_utils.py +++ b/tests/utils/test_vscode_utils.py @@ -11,22 +11,20 @@ import pytest_asyncio from milatools.cli.utils import MilatoolsUserError, running_inside_WSL -from milatools.utils.local_v2 import LocalV2 from milatools.utils.parallel_progress import ProgressDict from milatools.utils.remote_v1 import RemoteV1 -from milatools.utils.remote_v2 import RemoteV2, UnsupportedPlatformError +from milatools.utils.remote_v2 import RemoteV2 from milatools.utils.vscode_utils import ( + _get_local_vscode_extensions, + _get_vscode_extensions, extensions_to_install, find_code_server_executable, get_code_command, get_expected_vscode_settings_json_path, get_local_vscode_executable_path, - get_local_vscode_extensions, - get_remote_vscode_extensions, install_vscode_extension, install_vscode_extensions_task_function, sync_vscode_extensions, - sync_vscode_extensions_with_hostnames, vscode_installed, ) @@ -35,7 +33,6 @@ in_self_hosted_github_CI, requires_ssh_to_localhost, skip_if_on_github_cloud_CI, - xfails_on_windows, ) from .test_remote_v2 import uses_remote_v2 @@ -117,30 +114,26 @@ def mock_find_code_server_executable(monkeypatch: pytest.MonkeyPatch): return mock_find_code_server_executable -@xfails_on_windows(raises=UnsupportedPlatformError, reason="Uses RemoteV2", strict=True) +@uses_remote_v2 @requires_vscode @requires_ssh_to_localhost @pytest.mark.asyncio -async def test_sync_vscode_extensions_in_parallel_with_hostnames( +async def test_sync_vscode_extensions( mock_find_code_server_executable: Mock, ): - await sync_vscode_extensions_with_hostnames( + user = getpass.getuser() + results = await sync_vscode_extensions( "localhost", - # Make the destination slightly different so it actually gets wrapped as a - # `Remote(v2)` object. - destinations=[f"{getpass.getuser()}@localhost"], + # Make the destination slightly different to avoid the duplicate hostname + # detection that happens in `sync_vscode_extensions`. + destinations=[f"{user}@localhost"], ) + assert results == { + f"{user}@localhost": {"info": "Done.", "progress": 0, "total": 0} + } mock_find_code_server_executable.assert_called() -@requires_vscode -@requires_ssh_to_localhost -@pytest.mark.asyncio -async def test_sync_vscode_extensions_in_parallel(): - results = await sync_vscode_extensions(LocalV2(), destinations=[LocalV2()]) - assert results == {"localhost": {"info": "Done.", "progress": 0, "total": 0}} - - @pytest_asyncio.fixture async def vscode_extensions( request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch @@ -150,7 +143,7 @@ async def vscode_extensions( Here we pretend like some local vscode extensions are missing by patching the function that returns the local extensions to return only part of its actual result. """ - all_extensions = await get_local_vscode_extensions() + all_extensions = await _get_local_vscode_extensions() installed_extensions = all_extensions.copy() num_missing_extensions = 3 @@ -164,12 +157,12 @@ async def vscode_extensions( # `localhost` is the source, so it has all the extensions # the "remote" (just to localhost during tests) is missing some extensions mock_remote_extensions = AsyncMock( - spec=get_remote_vscode_extensions, + spec=_get_vscode_extensions, return_value=(installed_extensions, str(get_local_vscode_executable_path())), ) monkeypatch.setattr( milatools.utils.vscode_utils, - get_remote_vscode_extensions.__name__, + _get_vscode_extensions.__name__, mock_remote_extensions, ) @@ -265,7 +258,7 @@ async def test_install_vscode_extension(missing_extensions: dict[str, str]): @requires_vscode @pytest.mark.asyncio async def test_get_local_vscode_extensions(): - local_extensions = await get_local_vscode_extensions() + local_extensions = await _get_local_vscode_extensions() assert local_extensions and all( isinstance(ext, str) and isinstance(version, str) for ext, version in local_extensions.items() @@ -284,10 +277,10 @@ async def test_get_remote_vscode_extensions(mock_find_code_server_executable): local_vscode_executable = get_local_vscode_executable_path() assert local_vscode_executable is not None - fake_remote_extensions = await get_remote_vscode_extensions( - fake_remote, remote_code_server_executable=local_vscode_executable + fake_remote_extensions = await _get_vscode_extensions( + fake_remote, code_server_executable=local_vscode_executable ) - assert fake_remote_extensions == await get_local_vscode_extensions() + assert fake_remote_extensions == await _get_local_vscode_extensions() @requires_vscode