Skip to content

Commit

Permalink
Split TestComputeNode::test_connect in two cases
Browse files Browse the repository at this point in the history
Signed-off-by: Fabrice Normandin <[email protected]>
  • Loading branch information
lebrice committed May 8, 2024
1 parent 8aa42d0 commit 9b6bc09
Showing 1 changed file with 17 additions and 3 deletions.
20 changes: 17 additions & 3 deletions tests/utils/test_compute_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,30 +221,44 @@ async def test_run_gets_executed_in_job_step(
assert job_step_b == job_step_a + 1

@pytest.mark.asyncio
async def test_connect(self, runner: ComputeNode):
async def test_connect_with_job_id(self, runner: ComputeNode):
login_node = runner.login_node
job_id = runner.job_id
node_hostname = runner.hostname
# Connect with the job id:
compute_node_with_jobid = await ComputeNode.connect(
login_node, job_id_or_node_name=job_id
)
assert compute_node_with_jobid.salloc_subprocess is None
# Note: __eq__ from dataclass ignores some fields in this case (those with
# compare=False), in this case the salloc_subprocess is different, but the
# equality check succeeds
assert compute_node_with_jobid == runner

@pytest.mark.asyncio
async def test_connect_with_node_name(self, runner: ComputeNode):
"""Test connecting to a compute node using the node name.
This will probably eventually be deprecated in favour of reconnecting with the
job ID instead, since it's unambiguous in the case of multiple jobs running on
the same node.
"""
# Need to connect with the node name, not the full node hostname.
# For the `mila` cluster, we don't currently have a `cn-?????` entry in the ssh
# config (although we could!)
# Therefore, we need to connect to the node with the full hostname. However
# squeue expects the node name, so we have to truncate it manually for now.
login_node = runner.login_node
node_hostname = runner.hostname
if node_hostname.endswith(".server.mila.quebec"):
node_name = removesuffix(node_hostname, ".server.mila.quebec")
else:
node_name = node_hostname

# Connect with the node name:
compute_node_with_node_name = await ComputeNode.connect(
login_node, job_id_or_node_name=node_name
)
assert compute_node_with_jobid.salloc_subprocess is None
assert compute_node_with_node_name.salloc_subprocess is None
assert compute_node_with_node_name == runner

@pytest.mark.parametrize("use_async", [False, True], ids=["sync", "async"])
Expand Down

0 comments on commit 9b6bc09

Please sign in to comment.