From 58dc64af8acb0e2776447458a3cc4ca480128fe2 Mon Sep 17 00:00:00 2001
From: Satya Ortiz-Gagne <satya.ortiz-gagne@mila.quebec>
Date: Thu, 8 Aug 2024 06:54:49 +0200
Subject: [PATCH] Fix cloud multi-nodes

* Copy ssh key to allow connections from master to workers
* Use the manager's local IP so that workers can find it and connect to it
* Fix incompatibility between pandas and numpy 2.0.0
---
 .github/workflows/cloud-ci.yml         |  7 +++-
 config/cloud-multinodes-system.yaml    |  1 +
 config/examples/test.yaml              | 12 +++----
 milabench/commands/__init__.py         | 13 +++++--
 milabench/remote.py                    |  5 ---
 milabench/scripts/covalent/__main__.py | 49 +++++++++++++++++++++++---
 milabench/system.py                    |  4 +++
 poetry.lock                            |  4 +--
 pyproject.toml                         |  4 ++-
 9 files changed, 76 insertions(+), 23 deletions(-)

diff --git a/.github/workflows/cloud-ci.yml b/.github/workflows/cloud-ci.yml
index b1c6e7fc6..7226fc026 100644
--- a/.github/workflows/cloud-ci.yml
+++ b/.github/workflows/cloud-ci.yml
@@ -118,6 +118,11 @@ jobs:
         env:
           GITHUB_TOKEN: ${{ github.token }}
 
+      - name: DEBUG state file
+        if: always()
+        run: |
+          cat /tmp/milabench/covalent_venv/lib/python*/site-packages/covalent_azure_plugin/infra/*.tfstate
+
       - name: teardown cloud
         if: always()
         run: |
@@ -130,7 +135,7 @@ jobs:
             --run-on ${{ matrix.run_on }} \
             --all
 
-      - name: debug logs
+      - name: DEBUG logs
         if: always()
         run: |
           cat ~/.cache/covalent/covalent_ui.log
diff --git a/config/cloud-multinodes-system.yaml b/config/cloud-multinodes-system.yaml
index e5dc14f2b..e3d3e37c1 100644
--- a/config/cloud-multinodes-system.yaml
+++ b/config/cloud-multinodes-system.yaml
@@ -5,6 +5,7 @@ system:
     - name: manager
       # Use 1.1.1.1 as an ip placeholder
       ip: 1.1.1.1
+      port: 5000
       # Use this node as the master node or not
       main: true
       # User to use in remote milabench operations
diff --git a/config/examples/test.yaml b/config/examples/test.yaml
index 6e155a0bf..4f74ac33b 100644
--- a/config/examples/test.yaml
+++ b/config/examples/test.yaml
@@ -7,18 +7,18 @@ _defaults:
 
 test:
   inherits: _defaults
-  group: test_remote
-  install_group: test_remote
-  definition: ../../benchmarks/_template
+  group: simple
+  install_group: test
+  definition: ../../benchmarks/_templates/simple
   plan:
     method: njobs
     n: 1
 
 testing:
   inherits: _defaults
-  definition: ../../benchmarks/_template
-  group: test_remote_2
-  install_group: test_remote_2
+  definition: ../../benchmarks/_templates/stdout
+  group: stdout
+  install_group: test
   plan:
     method: njobs
     n: 1
diff --git a/milabench/commands/__init__.py b/milabench/commands/__init__.py
index 0de00f756..6c166132b 100644
--- a/milabench/commands/__init__.py
+++ b/milabench/commands/__init__.py
@@ -939,6 +939,13 @@ def _get_main_and_workers(self):
     def _argv(self, **_) -> List:
         manager, nodes = self._get_main_and_workers()
 
+        # Find the manager's local IPv4 address so that workers can connect to its port
+        for manager_ip in manager["ipaddrlist"]:
+            if ":" in manager_ip or manager_ip == "127.0.0.1":
+                continue
+            if all(str.isnumeric(n) for n in manager_ip.split(".")):
+                break
+
         num_machines = max(1, len(nodes) + 1)
 
         # Cant do that maybe this run is constrained
@@ -976,9 +983,9 @@ def _argv(self, **_) -> List:
             f"--machine_rank={self.rank}",
             f"--num_machines={num_machines}",
             *deepspeed_argv,
-            f"--gradient_accumulation_steps={self.pack.config.get('gradient_accumulation_steps', 1)}",
-            f"--num_cpu_threads_per_process={cpu_per_process}",
-            f"--main_process_ip={manager['ip']}",
+            f"--gradient_accumulation_steps={self.pack.config['gradient_accumulation_steps']}",
+            f"--num_cpu_threads_per_process={self.pack.config['argv']['--cpus_per_gpu']}",
+            f"--main_process_ip={manager_ip}",
             f"--main_process_port={manager['port']}",
             f"--num_processes={nproc}",
             *self.accelerate_argv,
diff --git a/milabench/remote.py b/milabench/remote.py
index b657f98c5..78e8ad736 100644
--- a/milabench/remote.py
+++ b/milabench/remote.py
@@ -2,16 +2,11 @@
 import os
 import sys
 
-import yaml
-
-from milabench.fs import XPath
-
 from . import ROOT_FOLDER
 from .commands import (
     CmdCommand,
     Command,
     ListCommand,
-    SCPCommand,
     SequenceCommand,
     SSHCommand,
     VoidCommand,
diff --git a/milabench/scripts/covalent/__main__.py b/milabench/scripts/covalent/__main__.py
index 995cc856f..ec8a14f58 100644
--- a/milabench/scripts/covalent/__main__.py
+++ b/milabench/scripts/covalent/__main__.py
@@ -98,6 +98,7 @@ def _popen(cmd, *args, _env=None, **kwargs):
             assert result and result[0]
 
             all_connection_attributes, _ = result
+            ssh_key_file:str = None
             for hostname, connection_attributes in all_connection_attributes.items():
                 print(f"hostname::>{hostname}")
                 for attribute,value in connection_attributes.items():
@@ -105,17 +106,55 @@ def _popen(cmd, *args, _env=None, **kwargs):
                         continue
                     print(f"{attribute}::>{value}")
 
+                ssh_key_file = (
+                    ssh_key_file or connection_attributes["ssh_key_file"]
+                )
+
+            if len(all_connection_attributes) >= 1:
+                fn = pathlib.Path(ssh_key_file)
+                dispatch_id = ct.dispatch(
+                    ct.lattice(executor.cp_to_remote), disable_run=False
+                )(f".ssh/{fn.name.split('.')[0]}", str(fn))
+
+                result = ct.get_result(dispatch_id=dispatch_id, wait=True).result
+
         if argv:
             dispatch_id = ct.dispatch(
-                ct.lattice(
-                    lambda:ct.electron(_popen, executor=executor)(argv)
-                ),
-                disable_run=False
+                ct.lattice(executor.list_running_instances), disable_run=False
             )()
 
             result = ct.get_result(dispatch_id=dispatch_id, wait=True).result
 
-            return_code, _, _ = result if result is not None else (1, "", "")
+            assert result
+
+            dispatch_ids = set()
+            for connection_attributes in result.get(
+                (executor.state_prefix, executor.state_id),
+                {"env": None}
+            ).values():
+                kwargs = {
+                    **_get_executor_kwargs(args),
+                    **connection_attributes
+                }
+                del kwargs["env"]
+
+                _executor:ct.executor.BaseExecutor = executor_cls(**kwargs)
+
+                dispatch_ids.add(
+                    ct.dispatch(
+                        ct.lattice(
+                            lambda:ct.electron(_popen, executor=_executor)(argv)
+                        ),
+                        disable_run=False
+                    )()
+                )
+
+            for dispatch_id in dispatch_ids:
+                result = ct.get_result(dispatch_id=dispatch_id, wait=True).result
+
+                _return_code, _, _ = result if result is not None else (1, "", "")
+                return_code = return_code or _return_code
+
     finally:
         if args.teardown:
             result = executor.stop_cloud_instance().result
diff --git a/milabench/system.py b/milabench/system.py
index 8b9711514..8d137d642 100644
--- a/milabench/system.py
+++ b/milabench/system.py
@@ -258,6 +258,10 @@ def _resolve_ip(ip):
     if not offline:
         # Resolve the IP
         try:
+            # Workaround for an error with `gethostbyaddr` on Azure DNS names
+            # (like `inmako.eastus2.cloudapp.azure.com`). A proper fix might
+            # be a correct network config in terraform. The observed error:
+            # socket.herror: [Errno 1] Unknown host
             hostname, aliaslist, ipaddrlist = socket.gethostbyname_ex(ip)
             lazy_raise = None
         
diff --git a/poetry.lock b/poetry.lock
index ec0f16753..b910db129 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
 
 [[package]]
 name = "alabaster"
@@ -2190,4 +2190,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools",
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.10,<4.0"
-content-hash = "59901f6d97314b2a67cac2cf9c4300cb5bde2feba01b0198b20c8ac477adae05"
+content-hash = "e8817803c68c0acc023e37a954027d5870b08d0e29cf46e8dd673df7e9d6994d"
diff --git a/pyproject.toml b/pyproject.toml
index 6a1693bf6..e7f784793 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,7 +26,9 @@ blessed = "^1.19.1"
 pathspec = "^0.9.0"
 cp-template = "^0.3.0"
 pandas = ">=1.4.2"
-numpy = ">=1.23.0,<2.0.0"
+# Workaround for a compatibility issue between numpy 2.0.0 and pandas
+# https://github.com/numpy/numpy/issues/26710
+numpy = "^1.23.0"
 pynvml = "^11.4.1"
 tqdm = "^4.64.1"
 pip-tools = "^7.4.1"