diff --git a/.firecrest-demo-config.json b/.firecrest-demo-config.json index 271dd6b..5f1bd36 100644 --- a/.firecrest-demo-config.json +++ b/.firecrest-demo-config.json @@ -7,5 +7,6 @@ "temp_directory": "", "small_file_size_mb": 5.0, "workdir": "", - "api_version": "1.16.0" + "api_version": "1.16.0", + "builder_metadata_options_custom_scheduler_commands": [] } diff --git a/aiida_firecrest/scheduler.py b/aiida_firecrest/scheduler.py index 553a9a9..edc5baa 100644 --- a/aiida_firecrest/scheduler.py +++ b/aiida_firecrest/scheduler.py @@ -221,14 +221,17 @@ def get_jobs( try: for page_iter in itertools.count(): results += transport._client.poll_active( - transport._machine, jobs, page_number=page_iter + transport._machine, + jobs, + page_number=page_iter, + page_size=self._DEFAULT_PAGE_SIZE, ) if len(results) < self._DEFAULT_PAGE_SIZE * (page_iter + 1): break except FirecrestException as exc: - # firecrest returns error if the job is completed # TODO: check what type of error is returned and handle it properly - if "Invalid job id specified" not in str(exc): + if "Invalid job id" not in str(exc): + # firecrest returns error if the job is completed, while aiida expects a silent return raise SchedulerError(str(exc)) from exc job_list = [] for raw_result in results: diff --git a/aiida_firecrest/transport.py b/aiida_firecrest/transport.py index 6a2780b..c3e5bec 100644 --- a/aiida_firecrest/transport.py +++ b/aiida_firecrest/transport.py @@ -401,6 +401,8 @@ def __init__( # aiida-core/src/aiida/orm/utils/remote:clean_remote() self._is_open = True + self.checksum_check = False + def __str__(self) -> str: """Return the name of the plugin.""" return self.__class__.__name__ @@ -723,7 +725,8 @@ def getfile( down_obj = self._client.external_download(self._machine, str(remote)) down_obj.finish_download(local) - self._validate_checksum(local, remote) + if self.checksum_check: + self._validate_checksum(local, remote) def _validate_checksum( self, localpath: str | Path, remotepath: str | 
FcPath @@ -965,7 +968,8 @@ def putfile( ) up_obj.finish_upload() - self._validate_checksum(localpath, str(remote)) + if self.checksum_check: + self._validate_checksum(localpath, str(remote)) def payoff(self, path: str | FcPath | Path) -> bool: """ diff --git a/tests/conftest.py b/tests/conftest.py index 7d13b7f..df53355 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,6 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field import hashlib import itertools import json @@ -8,18 +8,20 @@ from pathlib import Path import shutil import stat -from typing import Any, Callable +from typing import Any, Callable, ClassVar +from unittest.mock import MagicMock from urllib.parse import urlparse from aiida import orm import firecrest -import firecrest.path import pytest import requests -class Values: - _DEFAULT_PAGE_SIZE: int = 25 +class Slurm: + """Save the submitted job ids for testing purposes.""" + + all_jobs: ClassVar[list] = [] @pytest.fixture @@ -75,7 +77,12 @@ def submit( raise DeprecationWarning("local_file is not supported") if script_remote_path and not Path(script_remote_path).exists(): - raise FileNotFoundError(f"File {script_remote_path} does not exist") + # Firecrest raises FirecrestException instead of FileNotFoundError + mock_response = MagicMock() + mock_response.status_code = 999 # I don't really know + mock_response.json.return_value = {"error": "Mock error message"} + raise firecrest.FirecrestException(mock_response) + job_id = next(self.job_id_generator) # Filter out lines starting with '#SBATCH' @@ -98,9 +105,18 @@ def submit( os.chdir(Path(script_remote_path).parent) os.system(command) + Slurm.all_jobs.append(job_id) + return {"jobid": job_id} - def poll_active(self, machine: str, jobs: list[str], page_number: int = 0): + def cancel(self, machine: str, job_id: str): + job_id = int(job_id) + if job_id in Slurm.all_jobs: + Slurm.all_jobs.remove(job_id) + + def poll_active( + self, 
machine: str, jobs: list[str], page_number: int = 0, page_size: int = 25 + ): response = [] # 12 satets are defined in firecrest states = [ @@ -118,6 +134,11 @@ def poll_active(self, machine: str, jobs: list[str], page_number: int = 0): "COMPLETING", ] for i in range(len(jobs)): + if int(jobs[i]) not in Slurm.all_jobs: + mock_response = MagicMock() + mock_response.status_code = 999 # I don't really know + mock_response.json.return_value = {"error": "Invalid job id"} + raise firecrest.FirecrestException([mock_response]) response.append( { "job_data_err": "", @@ -139,11 +160,7 @@ def poll_active(self, machine: str, jobs: list[str], page_number: int = 0): } ) - return response[ - page_number - * Values._DEFAULT_PAGE_SIZE : (page_number + 1) - * Values._DEFAULT_PAGE_SIZE - ] + return response[page_number * page_size : (page_number + 1) * page_size] def whoami(self, machine: str): assert machine == "MACHINE_NAME" @@ -373,7 +390,22 @@ def __init__(self, *args, **kwargs): @dataclass class ComputerFirecrestConfig: - """Configuration of a computer using FirecREST as transport plugin.""" + """Configuration of a computer using FirecREST as transport plugin. + + :param url: The URL of the FirecREST server. + :param token_uri: The URI to receive tokens. + :param client_id: The client ID for the client credentials. + :param client_secret: The client secret for the client credentials. + :param compute_resource: The name of the compute resource. This is the name of the machine. + :param temp_directory: A temporary directory on the machine for transient zip files. + :param workdir: The aiida working directory on the machine. + :param api_version: The version of the FirecREST API. + :param builder_metadata_options_custom_scheduler_commands: A list of custom + scheduler commands when submitting a job, for example + ["#SBATCH --account=mr32", + "#SBATCH --constraint=mc", + "#SBATCH --mem=10K"]. 
+ :param small_file_size_mb: The maximum file size for direct upload & download.""" url: str token_uri: str @@ -384,6 +416,9 @@ class ComputerFirecrestConfig: workdir: str api_version: str small_file_size_mb: float = 1.0 + builder_metadata_options_custom_scheduler_commands: list[str] = field( + default_factory=list + ) class RequestTelemetry: @@ -514,4 +549,5 @@ def firecrest_config( small_file_size_mb=1.0, temp_directory=str(_temp_directory), api_version="2", + builder_metadata_options_custom_scheduler_commands=[], ) diff --git a/tests/test_calculation.py b/tests/test_calculation.py index a71e924..fc62159 100644 --- a/tests/test_calculation.py +++ b/tests/test_calculation.py @@ -20,8 +20,9 @@ def _no_retries(): manage.get_config().set_option(MAX_ATTEMPTS_OPTION, max_attempts) +@pytest.mark.timeout(180) @pytest.mark.usefixtures("aiida_profile_clean", "no_retries") -def test_calculation_basic(firecrest_computer: orm.Computer): +def test_calculation_basic(firecrest_computer: orm.Computer, firecrest_config): """Test running a simple `arithmetic.add` calculation.""" code = orm.InstalledCode( label="test_code", @@ -35,6 +36,10 @@ def test_calculation_basic(firecrest_computer: orm.Computer): builder = code.get_builder() builder.x = orm.Int(1) builder.y = orm.Int(2) + custom_scheduler_commands = "\n".join( + firecrest_config.builder_metadata_options_custom_scheduler_commands + ) + builder.metadata.options.custom_scheduler_commands = custom_scheduler_commands _, node = engine.run_get_node(builder) assert node.is_finished_ok diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 5939e60..f5e6327 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -1,47 +1,125 @@ from pathlib import Path +import textwrap +from time import sleep from aiida import orm +from aiida.schedulers import SchedulerError from aiida.schedulers.datastructures import CodeRunMode, JobTemplate import pytest from aiida_firecrest.scheduler import FirecrestScheduler -from 
conftest import Values @pytest.mark.usefixtures("aiida_profile_clean") -def test_submit_job(firecrest_computer: orm.Computer, tmp_path: Path): +def test_submit_job(firecrest_computer: orm.Computer, firecrest_config, tmpdir: Path): + """Test submitting a job to the scheduler. + Note: this test relies on a functional transport.put() method.""" + transport = firecrest_computer.get_transport() scheduler = FirecrestScheduler() scheduler.set_transport(transport) - with pytest.raises(FileNotFoundError): - scheduler.submit_job(transport.getcwd(), "unknown.sh") + # raise error if file not found + with pytest.raises(SchedulerError): + scheduler.submit_job(firecrest_config.workdir, "unknown.sh") + + custom_scheduler_commands = "\n ".join( + firecrest_config.builder_metadata_options_custom_scheduler_commands + ) + + shell_script = f""" + #!/bin/bash + #SBATCH --no-requeue + #SBATCH --job-name="aiida-1928" + #SBATCH --get-user-env + #SBATCH --output=_scheduler-stdout.txt + #SBATCH --error=_scheduler-stderr.txt + #SBATCH --nodes=1 + #SBATCH --ntasks-per-node=1 + {custom_scheduler_commands} + + echo 'hello world' + """ + + dedented_script = textwrap.dedent(shell_script).strip() + Path(tmpdir / "job.sh").write_text(dedented_script) + remote_ = transport._cwd.joinpath(firecrest_config.workdir, "job.sh") + transport.put(tmpdir / "job.sh", remote_) - _script = Path(tmp_path / "job.sh") - _script.write_text("#!/bin/bash\n\necho 'hello world'") + job_id = scheduler.submit_job(firecrest_config.workdir, "job.sh") - job_id = scheduler.submit_job(transport.getcwd(), _script) - # this is how aiida expects the job_id to be returned assert isinstance(job_id, str) +@pytest.mark.timeout(180) @pytest.mark.usefixtures("aiida_profile_clean") -def test_get_jobs(firecrest_computer: orm.Computer): +def test_get_and_kill_jobs( + firecrest_computer: orm.Computer, firecrest_config, tmpdir: Path +): + """Test getting and killing jobs from the scheduler. 
+ We test the two together for performance reasons, as this test might run against + a real server and we don't want to leave parasitic jobs behind. + It also means less billing for the user. + Note: this test relies on a functional transport.put() method. + """ + import time + transport = firecrest_computer.get_transport() scheduler = FirecrestScheduler() scheduler.set_transport(transport) - # test pagaination - scheduler._DEFAULT_PAGE_SIZE = 2 - Values._DEFAULT_PAGE_SIZE = 2 + # verify that no error is raised in the case of an invalid job id 000 + scheduler.get_jobs(["000"]) + + custom_scheduler_commands = "\n ".join( + firecrest_config.builder_metadata_options_custom_scheduler_commands + ) + shell_script = f""" + #!/bin/bash + #SBATCH --no-requeue + #SBATCH --job-name="aiida-1929" + #SBATCH --get-user-env + #SBATCH --output=_scheduler-stdout.txt + #SBATCH --error=_scheduler-stderr.txt + #SBATCH --nodes=1 + #SBATCH --ntasks-per-node=1 + {custom_scheduler_commands} + + sleep 180 + """ - joblist = ["111", "222", "333", "444", "555"] + joblist = [] + dedented_script = textwrap.dedent(shell_script).strip() + Path(tmpdir / "job.sh").write_text(dedented_script) + remote_ = transport._cwd.joinpath(firecrest_config.workdir, "job.sh") + transport.put(tmpdir / "job.sh", remote_) + + for _ in range(5): + joblist.append(scheduler.submit_job(firecrest_config.workdir, "job.sh")) + + # test pagination is working + scheduler._DEFAULT_PAGE_SIZE = 2 result = scheduler.get_jobs(joblist) assert len(result) == 5 for i in range(5): - assert result[i].job_id == str(joblist[i]) + assert result[i].job_id in joblist # TODO: one could check states as well + # test kill jobs + for jobid in joblist: + scheduler.kill_job(jobid) + + # sometimes it takes time for the server to actually kill the jobs + timeout_kill = 5 # seconds + start_time = time.time() + while time.time() - start_time < timeout_kill: + result = scheduler.get_jobs(joblist) + if not len(result): + break + sleep(0.5) + + assert not 
len(result) + def test_write_script_full(): # to avoid false positive (overwriting on existing file), @@ -68,9 +146,9 @@ def test_write_script_full(): #SBATCH --mem=1 test_command """ - expectaion_flat = "\n".join(line.strip() for line in expectaion.splitlines()).strip( - "\n" - ) + expectation_flat = "\n".join( + line.strip() for line in expectaion.splitlines() + ).strip("\n") scheduler = FirecrestScheduler() template = JobTemplate( { @@ -98,7 +176,7 @@ def test_write_script_full(): } ) try: - assert scheduler.get_submit_script(template).rstrip() == expectaion_flat + assert scheduler.get_submit_script(template).rstrip() == expectation_flat except AssertionError: print(scheduler.get_submit_script(template).rstrip()) print(expectaion) @@ -116,9 +194,9 @@ def test_write_script_minimal(): #SBATCH --ntasks-per-node=1 """ - expectaion_flat = "\n".join(line.strip() for line in expectaion.splitlines()).strip( - "\n" - ) + expectation_flat = "\n".join( + line.strip() for line in expectaion.splitlines() + ).strip("\n") scheduler = FirecrestScheduler() template = JobTemplate( { @@ -131,7 +209,7 @@ def test_write_script_minimal(): ) try: - assert scheduler.get_submit_script(template).rstrip() == expectaion_flat + assert scheduler.get_submit_script(template).rstrip() == expectation_flat except AssertionError: print(scheduler.get_submit_script(template).rstrip()) print(expectaion) diff --git a/tests/test_transport.py b/tests/test_transport.py index f9d970d..edd9f30 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -150,6 +150,27 @@ def test_putfile_getfile(firecrest_computer: orm.Computer, tmpdir: Path): assert not Path(_local_download / "remote_link").is_symlink() assert Path(_local_download / "remote_link").read_text() == "file1" + # test the self.checksum_check + with patch.object( + transport, "_validate_checksum", autospec=True + ) as mock_validate_checksum: + transport.checksum_check = True + transport.putfile(_local / "file1", _remote / 
"file1_checksum") + transport.getfile( + _remote / "file1_checksum", _local_download / "file1_checksum" + ) + assert mock_validate_checksum.call_count == 2 + + with patch.object( + transport, "_validate_checksum", autospec=True + ) as mock_validate_checksum: + transport.checksum_check = False + transport.putfile(_local / "file1", _remote / "file1_checksum2") + transport.getfile( + _remote / "file1_checksum2", _local_download / "file1_checksum2" + ) + assert mock_validate_checksum.call_count == 0 + @pytest.mark.usefixtures("aiida_profile_clean") def test_remove(firecrest_computer: orm.Computer, tmpdir: Path):