Skip to content

Commit

Permalink
Simplify the vscode utils for syncing extensions
Browse files Browse the repository at this point in the history
Signed-off-by: Fabrice Normandin <[email protected]>
  • Loading branch information
lebrice committed May 16, 2024
1 parent 4d9a5a9 commit ee33330
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 134 deletions.
2 changes: 2 additions & 0 deletions milatools/cli/code.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ async def code(
)
# todo: use the mila or the local machine as the reference for vscode
# extensions?
# TODO: If the remote is a cluster that doesn't yet have `vscode-server`, we
# could launch vscode at the same time (or before) syncing the vscode extensions?
sync_vscode_extensions_task = sync_vscode_extensions(
LocalV2(),
[login_node],
Expand Down
6 changes: 2 additions & 4 deletions milatools/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,7 @@
from milatools.cli.common import find_allocation
from milatools.utils.local_v1 import LocalV1
from milatools.utils.remote_v1 import RemoteV1
from milatools.utils.vscode_utils import (
sync_vscode_extensions_with_hostnames,
)
from milatools.utils.vscode_utils import sync_vscode_extensions

from ..__version__ import __version__
from .code import add_mila_code_arguments
Expand Down Expand Up @@ -226,7 +224,7 @@ def mila():
"extensions locally. Defaults to all the available SLURM clusters."
),
)
sync_vscode_parser.set_defaults(function=sync_vscode_extensions_with_hostnames)
sync_vscode_parser.set_defaults(function=sync_vscode_extensions)

# ----- mila serve ------

Expand Down
197 changes: 94 additions & 103 deletions milatools/utils/vscode_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
import functools
import os
import shutil
Expand All @@ -11,7 +12,6 @@
from typing import Literal, Sequence

from milatools.cli.utils import (
CLUSTERS,
CommandNotFoundError,
batched,
stripped_lines_of,
Expand All @@ -23,7 +23,6 @@
TaskID,
async_progress_bar,
)
from milatools.utils.remote_v1 import RemoteV1
from milatools.utils.remote_v2 import RemoteV2

logger = get_logger(__name__)
Expand Down Expand Up @@ -77,121 +76,114 @@ def vscode_installed() -> bool:
return True


async def sync_vscode_extensions_with_hostnames(
source: str,
destinations: list[str],
):
destinations = list(destinations)
if source in destinations:
if source == "mila" and destinations == CLUSTERS:
logger.info("Assuming you want to sync from mila to all DRAC/CC clusters.")
else:
logger.warning(
f"{source=} is also in the destinations to sync to. " f"Removing it."
)
destinations.remove(source)

if len(set(destinations)) != len(destinations):
raise ValueError(f"{destinations=} contains duplicate hostnames!")

source_obj = LocalV2() if source == "localhost" else await RemoteV2.connect(source)
return await sync_vscode_extensions(source_obj, destinations)


async def sync_vscode_extensions(
source: str | LocalV2 | RemoteV2,
destinations: Sequence[str | LocalV2 | RemoteV2],
):
) -> dict[str, ProgressDict]:
"""Syncs vscode extensions between `source` all all the clusters in `dest`.
This spawns a thread for each cluster in `dest` and displays a parallel progress bar
for the syncing of vscode extensions to each cluster.
"""
if isinstance(source, LocalV2):
source_hostname = "localhost"
source_extensions = await get_local_vscode_extensions()
elif isinstance(source, RemoteV2):
source_hostname = source.hostname
code_server_executable = await find_code_server_executable(
source, remote_vscode_server_dir="~/.vscode-server"
)
if not code_server_executable:
raise RuntimeError(
f"The vscode-server executable was not found on {source.hostname}."
)
source_extensions = await get_remote_vscode_extensions(
source, code_server_executable
)
else:
assert isinstance(source, str)
source_hostname = source
source = RemoteV2(source)
if isinstance(source, str):
if source == "localhost":
source = LocalV2()
else:
source = await RemoteV2.connect(source)

destinations = _remove_source_from_destinations(source, destinations)

if not destinations:
logger.info("No destinations to sync extensions to!")
return {}

source_hostname, source_extensions = await get_vscode_extensions(source)

task_hostnames: list[str] = []
tasks: list[AsyncTaskFn[ProgressDict]] = []
task_descriptions: list[str] = []

for dest_remote in destinations:
dest_hostname: str

if dest_remote == "localhost":
dest_hostname = dest_remote # type: ignore
dest_remote = LocalV2() # pickleable
elif isinstance(dest_remote, LocalV2):
dest_hostname = "localhost"
dest_remote = dest_remote # again, pickleable
elif isinstance(dest_remote, RemoteV2):
dest_hostname = dest_remote.hostname
dest_remote = dest_remote # pickleable
elif isinstance(dest_remote, RemoteV1):
# We unfortunately can't pass this kind of object to another process or
# thread because it uses `fabric.Connection` which don't appear to be
# pickleable. This means we will have to re-connect in the subprocess.
dest_hostname = dest_remote.hostname
dest_remote = None
else:
assert isinstance(dest_remote, str)
# The dest_remote is a hostname. Try to connect to it with a reusable SSH
# control socket so we can get the 2FA prompts out of the way in advance.
# NOTE: We could fallback to using the `Remote` class with paramiko inside
# the subprocess if this doesn't work, but it would suck because it messes
# up the UI, and you need to press 1 in the terminal to get the 2FA prompt,
# which screws up the progress bars.
dest_hostname = dest_remote
dest_remote = RemoteV2(hostname=dest_hostname)

# Connect to the remotes in parallel.
dest_runners_and_hostnames = await asyncio.gather(
*(_get_runner_and_hostname(dest) for dest in destinations)
)
for dest_runner, dest_hostname in dest_runners_and_hostnames:
task_hostnames.append(dest_hostname)
tasks.append(
functools.partial(
install_vscode_extensions_task_function,
dest_hostname=dest_hostname,
source_extensions=source_extensions,
remote=dest_remote,
remote=dest_runner,
source_name=source_hostname,
)
)
task_descriptions.append(f"{source_hostname} -> {dest_hostname}")

results: dict[str, ProgressDict] = {}
return {
hostname: result
for hostname, result in zip(
task_hostnames,
await async_progress_bar(
async_task_fns=tasks,
task_descriptions=task_descriptions,
overall_progress_task_description="[green]Syncing vscode extensions:",
),
)
}

for hostname, result in zip(
task_hostnames,
await async_progress_bar(
async_task_fns=tasks,
task_descriptions=task_descriptions,
overall_progress_task_description="[green]Syncing vscode extensions:",
),
):
results[hostname] = result
return results

def _remove_source_from_destinations(
source: LocalV2 | RemoteV2, destinations: Sequence[str | LocalV2 | RemoteV2]
):
dest_hostnames = [
dest if isinstance(dest, str) else dest.hostname for dest in destinations
]
if source.hostname in dest_hostnames:
logger.debug(f"{source.hostname!r} is also in the destinations, removing it.")
destinations = list(destinations)
destinations.pop(dest_hostnames.index(source.hostname))

if len(set(dest_hostnames)) != len(dest_hostnames):
raise ValueError(f"{dest_hostnames=} contains duplicate hostnames!")
return destinations


async def _get_runner_and_hostname(
dest_remote: str | LocalV2 | RemoteV2,
) -> tuple[LocalV2 | RemoteV2, str]:
if isinstance(dest_remote, str):
if dest_remote == "localhost":
return LocalV2(), dest_remote
dest_hostname = dest_remote
dest_remote = await RemoteV2.connect(dest_hostname)
return dest_remote, dest_hostname
return dest_remote, dest_remote.hostname


async def get_vscode_extensions(
source: LocalV2 | RemoteV2,
) -> tuple[str, dict[str, str]]:
if isinstance(source, LocalV2):
code_server_executable = get_local_vscode_executable_path(code_command=None)
else:
code_server_executable = await find_code_server_executable(
source, remote_vscode_server_dir="~/.vscode-server"
)
if not code_server_executable:
raise RuntimeError(
f"The vscode-server executable was not found on {source.hostname}."
)
source_extensions = await _get_vscode_extensions(source, code_server_executable)
return source.hostname, source_extensions


async def install_vscode_extensions_task_function(
task_progress_dict: dict[TaskID, ProgressDict],
task_id: TaskID,
dest_hostname: str | Literal["localhost"],
source_extensions: dict[str, str],
remote: RemoteV2 | LocalV2 | None,
remote: RemoteV2 | LocalV2,
source_name: str,
verbose: bool = False,
) -> ProgressDict:
Expand All @@ -208,8 +200,8 @@ async def install_vscode_extensions_task_function(
def _update_progress(
progress: int, status: str, total: int = len(source_extensions)
):
# Show progress to the parent process by setting an item in the task progress
# dict.
"""Shows progress to the parent process by setting an item in the task progress
dict."""
progress_dict: ProgressDict = {
"progress": progress,
"total": total,
Expand All @@ -218,7 +210,7 @@ def _update_progress(
task_progress_dict[task_id] = progress_dict
return progress_dict

if remote is None:
if not remote:
if dest_hostname == "localhost":
remote = LocalV2()
else:
Expand All @@ -228,7 +220,9 @@ def _update_progress(
if isinstance(remote, LocalV2):
assert dest_hostname == "localhost"
code_server_executable = get_local_vscode_executable_path()
extensions_on_dest = await get_local_vscode_extensions()
extensions_on_dest = await _get_vscode_extensions(
remote, code_server_executable
)
else:
dest_hostname = remote.hostname
remote_vscode_server_dir = "~/.vscode-server"
Expand All @@ -250,7 +244,7 @@ def _update_progress(
status="code-server executable not found!",
)
_update_progress(0, status="fetching installed extensions...")
extensions_on_dest = await get_remote_vscode_extensions(
extensions_on_dest = await _get_vscode_extensions(
remote, code_server_executable
)

Expand Down Expand Up @@ -317,35 +311,32 @@ async def install_vscode_extension(
return result


async def get_local_vscode_extensions(
async def _get_local_vscode_extensions(
code_command: str | None = None,
) -> dict[str, str]:
output = await LocalV2.get_output_async(
(
get_local_vscode_executable_path(code_command=code_command),
"--list-extensions",
"--show-versions",
return await _get_vscode_extensions(
LocalV2(),
code_server_executable=get_local_vscode_executable_path(
code_command=code_command
),
)
return parse_vscode_extensions_versions(stripped_lines_of(output))


async def get_remote_vscode_extensions(
remote: RemoteV2,
remote_code_server_executable: str,
async def _get_vscode_extensions(
remote: RemoteV2 | LocalV2,
code_server_executable: str,
) -> dict[str, str]:
"""Returns the list of isntalled extensions and the path to the code-server
executable."""
remote_extensions = parse_vscode_extensions_versions(
return parse_vscode_extensions_versions(
stripped_lines_of(
await remote.get_output_async(
f"{remote_code_server_executable} --list-extensions --show-versions",
f"{code_server_executable} --list-extensions --show-versions",
display=False,
hide=True,
)
)
)
return remote_extensions


def extensions_to_install(
Expand Down
Loading

0 comments on commit ee33330

Please sign in to comment.