diff --git a/tests/utils/test_compute_node.py b/tests/utils/test_compute_node.py index bb16b84d..c9259324 100644 --- a/tests/utils/test_compute_node.py +++ b/tests/utils/test_compute_node.py @@ -221,30 +221,44 @@ async def test_run_gets_executed_in_job_step( assert job_step_b == job_step_a + 1 @pytest.mark.asyncio - async def test_connect(self, runner: ComputeNode): + async def test_connect_with_job_id(self, runner: ComputeNode): login_node = runner.login_node job_id = runner.job_id - node_hostname = runner.hostname + # Connect with the job id: compute_node_with_jobid = await ComputeNode.connect( login_node, job_id_or_node_name=job_id ) assert compute_node_with_jobid.salloc_subprocess is None + # Note: the dataclass-generated __eq__ ignores fields declared with + # compare=False. Here the salloc_subprocess differs between the two objects, but + # the equality check still succeeds. assert compute_node_with_jobid == runner + @pytest.mark.asyncio + async def test_connect_with_node_name(self, runner: ComputeNode): + """Test connecting to a compute node using the node name. + + This will probably eventually be deprecated in favour of reconnecting with the + job ID instead, since it's unambiguous in the case of multiple jobs running on + the same node. + """ # Need to connect with the node name, not the full node hostname. # For the `mila` cluster, we don't currently have a `cn-?????` entry in the ssh # config (although we could!) # Therefore, we need to connect to the node with the full hostname. However # squeue expects the node name, so we have to truncate it manually for now. 
+ login_node = runner.login_node + node_hostname = runner.hostname if node_hostname.endswith(".server.mila.quebec"): node_name = removesuffix(node_hostname, ".server.mila.quebec") else: node_name = node_hostname + # Connect with the node name: compute_node_with_node_name = await ComputeNode.connect( login_node, job_id_or_node_name=node_name ) - assert compute_node_with_jobid.salloc_subprocess is None + assert compute_node_with_node_name.salloc_subprocess is None assert compute_node_with_node_name == runner @pytest.mark.parametrize("use_async", [False, True], ids=["sync", "async"])