Skip to content

Commit

Permalink
add show_table parameter and argument
Browse files Browse the repository at this point in the history
Signed-off-by: Fabrice Normandin <[email protected]>
  • Loading branch information
lebrice committed Apr 30, 2024
1 parent 412fd5d commit 8801272
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 5 deletions.
1 change: 1 addition & 0 deletions milatools/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ def mila():
formatter_class=SortingHelpFormatter,
)
run_parser.add_argument("--ssh_config_path", type=Path, default=SSH_CONFIG_FILE)
run_parser.add_argument("--show-table", action="store_true", default=False)
run_parser.add_argument(
"command", type=str, nargs=argparse.REMAINDER, help="The command to run."
)
Expand Down
30 changes: 25 additions & 5 deletions milatools/cli/run.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import shlex
import subprocess
import sys
from pathlib import Path

Expand All @@ -16,7 +17,9 @@


async def run_command(
command: str | list[str], ssh_config_path: Path = SSH_CONFIG_FILE
command: str | list[str],
ssh_config_path: Path = SSH_CONFIG_FILE,
show_table: bool = False,
):
command = shlex.join(command) if isinstance(command, list) else command
if command.startswith("'") and command.endswith("'"):
Expand Down Expand Up @@ -44,18 +47,36 @@ async def _is_slurm_cluster(remote: RemoteV2) -> bool:

results = await asyncio.gather(
*(
login_node.run_async(command=command, warn=True, display=True, hide=False)
login_node.run_async(command=command, warn=True, display=True, hide=True)
for login_node in cluster_login_nodes
)
)
if show_table:
_print_with_table(command, cluster_login_nodes, results)
else:
_print_with_prefix(command, cluster_login_nodes, results)
return results


def _print_with_prefix(
command: str,
cluster_login_nodes: list[RemoteV2],
results: list[subprocess.CompletedProcess[str]],
):
for remote, result in zip(cluster_login_nodes, results):
for line in result.stdout.splitlines():
print(f"({remote.hostname}) {line}")
console.print(f"[bold]({remote.hostname})[/bold] {line}", markup=True)
for line in result.stderr.splitlines():
print(f"({remote.hostname}) {line}", file=sys.stderr)

return results
# return results


def _print_with_table(
command: str,
cluster_login_nodes: list[RemoteV2],
results: list[subprocess.CompletedProcess[str]],
):
table = rich.table.Table(title=command)
table.add_column("Cluster")

Expand Down Expand Up @@ -88,7 +109,6 @@ async def _is_slurm_cluster(remote: RemoteV2) -> bool:
# table.add_column(remote.hostname, no_wrap=True)
# task = group.create_task(remote.run_async(command))
# task.add_done_callback(lambda _: table.add_row())
return results


if __name__ == "main":
Expand Down

0 comments on commit 8801272

Please sign in to comment.