From fc981807994fa3c34916e7bb1e5236cd8d3db14d Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Wed, 31 Jul 2024 14:25:08 -0400 Subject: [PATCH] ruff auto-fixes --- doc/source/conf.py | 6 +-- src/qtoolkit/core/base.py | 4 +- src/qtoolkit/core/data_objects.py | 19 ++++----- src/qtoolkit/core/exceptions.py | 4 +- src/qtoolkit/host/base.py | 7 +++- src/qtoolkit/host/local.py | 2 +- src/qtoolkit/host/remote.py | 20 ++++----- src/qtoolkit/io/base.py | 15 ++++--- src/qtoolkit/io/pbs.py | 28 ++++++------- src/qtoolkit/io/shell.py | 20 ++++----- src/qtoolkit/io/slurm.py | 40 ++++++++---------- src/qtoolkit/manager.py | 29 ++++++++----- src/qtoolkit/utils.py | 2 +- tests/__init__.py | 2 - tests/conftest.py | 2 +- tests/core/test_data_objects.py | 70 +++++++++++++++---------------- tests/io/test_base.py | 10 ++--- tests/io/test_slurm.py | 2 +- 18 files changed, 138 insertions(+), 144 deletions(-) diff --git a/doc/source/conf.py b/doc/source/conf.py index 2a210ab..ea437d1 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -1,4 +1,4 @@ -# +# # noqa: INP001 # Configuration file for the Sphinx documentation builder. # # This file does only contain a selection of the most common options. For a @@ -23,7 +23,7 @@ # -- Project information ----------------------------------------------------- project = "QToolKit" -copyright = "2023, Matgenix SRL" +copyright = "2023, Matgenix SRL" # noqa: A001 author = "Guido Petretto, David Waroquiers" @@ -109,7 +109,7 @@ }, "collapse_navigation": True, "announcement": ( - "

" "QToolKit is still in beta phase. The API may change at any time." "

" + "

QToolKit is still in beta phase. The API may change at any time.

" ), # "navbar_end": ["theme-switcher", "navbar-icon-links"], # "navbar_end": ["theme-switcher", "version-switcher", "navbar-icon-links"], diff --git a/src/qtoolkit/core/base.py b/src/qtoolkit/core/base.py index d978188..3e67b4a 100644 --- a/src/qtoolkit/core/base.py +++ b/src/qtoolkit/core/base.py @@ -25,8 +25,8 @@ def _validate_monty(cls, __input_value): """ try: super()._validate_monty(__input_value) - except ValueError as e: + except ValueError as exc: try: return cls(__input_value) except Exception: - raise e + raise exc # noqa: B904 diff --git a/src/qtoolkit/core/data_objects.py b/src/qtoolkit/core/data_objects.py index 9700f31..2c59a0d 100644 --- a/src/qtoolkit/core/data_objects.py +++ b/src/qtoolkit/core/data_objects.py @@ -2,11 +2,14 @@ import abc from dataclasses import dataclass, fields -from pathlib import Path +from typing import TYPE_CHECKING from qtoolkit.core.base import QTKEnum, QTKObject from qtoolkit.core.exceptions import UnsupportedResourcesError +if TYPE_CHECKING: + from pathlib import Path + class SubmissionStatus(QTKEnum): SUCCESSFUL = "SUCCESSFUL" @@ -190,7 +193,7 @@ class QResources(QTKObject): def __post_init__(self): if self.process_placement is None: if self.processes and not self.processes_per_node and not self.nodes: - self.process_placement = ProcessPlacement.NO_CONSTRAINTS # type: ignore # due to QTKEnum + self.process_placement = ProcessPlacement.NO_CONSTRAINTS elif self.nodes and self.processes_per_node and not self.processes: self.process_placement = ProcessPlacement.EVENLY_DISTRIBUTED elif not self._check_no_values(): @@ -203,18 +206,12 @@ def __post_init__(self): self.scheduler_kwargs = self.scheduler_kwargs or {} def _check_no_values(self) -> bool: - """ - Check if all the attributes are None or empty. - """ - for f in fields(self): - if self.__getattribute__(f.name): - return False - - return True + """Check if all the attributes are None or empty.""" + return all(not self.__getattribute__(f.name) for f in fields(self)) def check_empty(self) -> bool: """ - Check if the QResouces is empty and its content is coherent. + Check if the QResources is empty and its content is coherent. Raises an error if process_placement is None, but some attributes are set. """ if self.process_placement is not None: diff --git a/src/qtoolkit/core/exceptions.py b/src/qtoolkit/core/exceptions.py index 2c6d829..74b0104 100644 --- a/src/qtoolkit/core/exceptions.py +++ b/src/qtoolkit/core/exceptions.py @@ -1,7 +1,5 @@ class QTKException(Exception): - """ - Base class for all the exceptions generated by qtoolkit. - """ + """Base class for all the exceptions generated by qtoolkit.""" class CommandFailedError(QTKException): diff --git a/src/qtoolkit/host/base.py b/src/qtoolkit/host/base.py index fb29816..1f2760a 100644 --- a/src/qtoolkit/host/base.py +++ b/src/qtoolkit/host/base.py @@ -2,10 +2,13 @@ import abc from dataclasses import dataclass -from pathlib import Path +from typing import TYPE_CHECKING from qtoolkit.core.base import QTKObject +if TYPE_CHECKING: + from pathlib import Path + @dataclass class HostConfig(QTKObject): @@ -30,7 +33,7 @@ def execute( # stdout=None, # stderr=None, ): - """Execute the given command on the host + """Execute the given command on the host. Parameters ---------- diff --git a/src/qtoolkit/host/local.py b/src/qtoolkit/host/local.py index cdc1db2..4128818 100644 --- a/src/qtoolkit/host/local.py +++ b/src/qtoolkit/host/local.py @@ -11,7 +11,7 @@ class LocalHost(BaseHost): # def __init__(self, config): # self.config = config def execute(self, command: str | list[str], workdir: str | Path | None = None): - """Execute the given command on the host + """Execute the given command on the host. Note that the command is executed with shell=True, so commands can be exposed to command injection. Consider whether to escape part of diff --git a/src/qtoolkit/host/remote.py b/src/qtoolkit/host/remote.py index 9fc6080..24d87da 100644 --- a/src/qtoolkit/host/remote.py +++ b/src/qtoolkit/host/remote.py @@ -2,12 +2,15 @@ import io from dataclasses import dataclass, field -from pathlib import Path +from typing import TYPE_CHECKING import fabric from qtoolkit.host.base import BaseHost, HostConfig +if TYPE_CHECKING: + from pathlib import Path + # from fabric import Connection, Config @@ -128,7 +131,7 @@ class RemoteConfig(HostConfig): class RemoteHost(BaseHost): """ Execute commands on a remote host. - For some commands assumes the remote can run unix + For some commands assumes the remote can run unix. """ def __init__(self, config: RemoteConfig): @@ -143,7 +146,7 @@ def connection(self): return self._connection def execute(self, command: str | list[str], workdir: str | Path | None = None): - """Execute the given command on the host + """Execute the given command on the host. Parameters ---------- @@ -161,7 +164,6 @@ def execute(self, command: str | list[str], workdir: str | Path | None = None): exit_code : int Exit code of the command. """ - if isinstance(command, (list, tuple)): command = " ".join(command) @@ -169,10 +171,7 @@ def execute(self, command: str | list[str], workdir: str | Path | None = None): # connection from outside (not through a config) and we want to keep it alive ? # TODO: check if this works: - if not workdir: - workdir = "." - else: - workdir = str(workdir) + workdir = str(workdir) if workdir else "." with self.connection.cd(workdir): out = self.connection.run(command, hide=True, warn=True) @@ -185,10 +184,11 @@ def mkdir(self, directory, recursive: bool = True, exist_ok: bool = True) -> boo command += "-p " command += str(directory) try: - stdout, stderr, returncode = self.execute(command) - return returncode == 0 + _stdout, _stderr, returncode = self.execute(command) except Exception: return False + else: + return returncode == 0 def write_text_file(self, filepath, content): """Write content to a file on the host.""" diff --git a/src/qtoolkit/io/base.py b/src/qtoolkit/io/base.py index b935949..c044a9c 100644 --- a/src/qtoolkit/io/base.py +++ b/src/qtoolkit/io/base.py @@ -3,13 +3,16 @@ import abc import shlex from dataclasses import fields -from pathlib import Path from string import Template +from typing import TYPE_CHECKING from qtoolkit.core.base import QTKObject from qtoolkit.core.data_objects import CancelResult, QJob, QResources, SubmissionResult from qtoolkit.core.exceptions import UnsupportedResourcesError +if TYPE_CHECKING: + from pathlib import Path + class QTemplate(Template): delimiter = "$$" @@ -74,9 +77,8 @@ def generate_header(self, options: dict | QResources | None) -> str: options = options or {} - if isinstance(options, QResources): - if not options.check_empty(): - options = self.check_convert_qresources(options) + if isinstance(options, QResources) and not options.check_empty(): + options = self.check_convert_qresources(options) template = QTemplate(self.header_template) @@ -89,10 +91,7 @@ def generate_header(self, options: dict | QResources | None) -> str: unclean_header = template.safe_substitute(options) # Remove lines with leftover $$. - clean_header = [] - for line in unclean_header.split("\n"): - if "$$" not in line: - clean_header.append(line) + clean_header = [line for line in unclean_header.split("\n") if "$$" not in line] return "\n".join(clean_header) diff --git a/src/qtoolkit/io/pbs.py b/src/qtoolkit/io/pbs.py index 22f6fce..49fa3cb 100644 --- a/src/qtoolkit/io/pbs.py +++ b/src/qtoolkit/io/pbs.py @@ -2,6 +2,7 @@ import re from datetime import timedelta +from typing import ClassVar from qtoolkit.core.data_objects import ( CancelResult, @@ -156,9 +157,7 @@ def parse_cancel_output(self, exit_code, stdout, stderr) -> CancelResult: ) def _get_job_cmd(self, job_id: str): - cmd = f"qstat -f {job_id}" - - return cmd + return f"qstat -f {job_id}" def parse_job_output(self, exit_code, stdout, stderr) -> QJob | None: out = self.parse_jobs_list_output(exit_code, stdout, stderr) @@ -224,7 +223,7 @@ def parse_jobs_list_output(self, exit_code, stdout, stderr) -> list[QJob]: jobs_list = [] for chunk in jobs_chunks: - chunk = chunk.strip() + chunk = chunk.strip() # noqa: PLW2901 if not chunk: continue @@ -243,9 +242,9 @@ def parse_jobs_list_output(self, exit_code, stdout, stderr) -> list[QJob]: try: pbs_job_state = PBSState(job_state_string) - except ValueError: + except ValueError as exc: msg = f"Unknown job state {job_state_string} for job id {qjob.job_id}" - raise OutputParsingError(msg) + raise OutputParsingError(msg) from exc qjob.sub_state = pbs_job_state qjob.state = pbs_job_state.qstate @@ -299,7 +298,6 @@ def _convert_str_to_time(time_str: str | None): Convert a string in the format used by PBS DD:HH:MM:SS to a number of seconds. It may contain only H:M:S, only M:S or only S. """ - if not time_str: return None @@ -312,8 +310,8 @@ def _convert_str_to_time(time_str: str | None): for i, v in enumerate(reversed(time_split)): time[i] = int(v) - except ValueError: - raise OutputParsingError() + except ValueError as exc: + raise OutputParsingError from exc return time[3] * 86400 + time[2] * 3600 + time[1] * 60 + time[0] @@ -335,14 +333,14 @@ def _convert_memory_str(memory: str | None) -> int | None: raise OutputParsingError(f"Unknown units {units}") try: v = int(memory) - except ValueError: - raise OutputParsingError + except ValueError as exc: + raise OutputParsingError from exc return v * (1024 ** power_labels[units]) # helper attribute to match the values defined in QResources and # the dictionary that should be passed to the template - _qresources_mapping = { + _qresources_mapping: ClassVar = { "queue_name": "queue", "job_name": "job_name", "account": "account", @@ -353,22 +351,20 @@ def _convert_memory_str(memory: str | None) -> int | None: } @staticmethod - def _convert_time_to_str(time: int | float | timedelta) -> str: + def _convert_time_to_str(time: int | float | timedelta) -> str: # noqa: PYI041 if not isinstance(time, timedelta): time = timedelta(seconds=time) hours, remainder = divmod(int(time.total_seconds()), 3600) minutes, seconds = divmod(remainder, 60) - time_str = f"{hours}:{minutes}:{seconds}" - return time_str + return f"{hours}:{minutes}:{seconds}" def _convert_qresources(self, resources: QResources) -> dict: """ Converts a QResources instance to a dict that will be used to fill in the header of the submission script. """ - header_dict = {} for qr_field, pbs_field in self._qresources_mapping.items(): val = getattr(resources, qr_field) diff --git a/src/qtoolkit/io/shell.py b/src/qtoolkit/io/shell.py index 2b52fdc..9aafbc3 100644 --- a/src/qtoolkit/io/shell.py +++ b/src/qtoolkit/io/shell.py @@ -1,6 +1,6 @@ from __future__ import annotations -from pathlib import Path +from typing import TYPE_CHECKING from qtoolkit.core.data_objects import ( CancelResult, @@ -19,6 +19,9 @@ ) from qtoolkit.io.base import BaseSchedulerIO +if TYPE_CHECKING: + from pathlib import Path + # States in from ps command, extracted from man ps. # D uninterruptible sleep (usually IO) # R running or runnable (on run queue) @@ -152,9 +155,7 @@ def parse_cancel_output(self, exit_code, stdout, stderr) -> CancelResult: ) def _get_job_cmd(self, job_id: str): - cmd = self._get_jobs_list_cmd(job_ids=[job_id]) - - return cmd + return self._get_jobs_list_cmd(job_ids=[job_id]) def parse_job_output(self, exit_code, stdout, stderr) -> QJob | None: """Parse the output of the ps command and return the corresponding QJob object. @@ -239,9 +240,9 @@ def parse_jobs_list_output(self, exit_code, stdout, stderr) -> list[QJob]: try: shell_job_state = ShellState(data[3][0]) - except ValueError: + except ValueError as exc: msg = f"Unknown job state {data[3]} for job id {qjob.job_id}" - raise OutputParsingError(msg) + raise OutputParsingError(msg) from exc qjob.sub_state = shell_job_state qjob.state = shell_job_state.qstate @@ -276,7 +277,6 @@ def _convert_str_to_time(time_str: str | None) -> int | None: Convert a string in the format used in etime [[DD-]hh:]mm:ss to a number of seconds. """ - if not time_str: return None @@ -295,9 +295,9 @@ def _convert_str_to_time(time_str: str | None) -> int | None: elif len(time_split) == 2: minutes, seconds = (int(v) for v in time_split) else: - raise OutputParsingError() + raise OutputParsingError - except ValueError: - raise OutputParsingError() + except ValueError as exc: + raise OutputParsingError from exc return days * 86400 + hours * 3600 + minutes * 60 + seconds diff --git a/src/qtoolkit/io/slurm.py b/src/qtoolkit/io/slurm.py index e555092..48abfaf 100644 --- a/src/qtoolkit/io/slurm.py +++ b/src/qtoolkit/io/slurm.py @@ -2,6 +2,7 @@ import re from datetime import timedelta +from typing import ClassVar from qtoolkit.core.data_objects import ( CancelResult, @@ -72,7 +73,7 @@ # SI SIGNALING Job is being signaled. # # SE SPECIAL_EXIT The job was requeued in a special state. This state -# can be set by users, typically in Epi‐ +# can be set by users, typically in Epi- # logSlurmctld, if the job has terminated with a particular # exit value. # @@ -180,7 +181,7 @@ class SlurmIO(BaseSchedulerIO): "scancel -v" # The -v is needed as the default is to report nothing ) - squeue_fields = [ + squeue_fields: ClassVar = [ ("%i", "job_id"), # job or job step id ("%t", "state_raw"), # job state in compact form ("%r", "annotation"), # reason for the job being in its current state @@ -212,11 +213,11 @@ def parse_submit_output(self, exit_code, stdout, stderr) -> SubmissionResult: stderr=stderr, status=SubmissionStatus("FAILED"), ) - _SLURM_SUBMITTED_REGEXP = re.compile( + _slurm_submitted_regexp = re.compile( r"(.*:\s*)?([Gg]ranted job allocation|" r"[Ss]ubmitted batch job)\s+(?P\d+)" ) - match = _SLURM_SUBMITTED_REGEXP.match(stdout.strip()) + match = _slurm_submitted_regexp.match(stdout.strip()) job_id = match.group("jobid") if match else None status = ( SubmissionStatus("SUCCESSFUL") @@ -251,10 +252,10 @@ def parse_cancel_output(self, exit_code, stdout, stderr) -> CancelResult: stderr=stderr, status=CancelStatus("FAILED"), ) - _SLURM_CANCELLED_REGEXP = re.compile( + _slurm_cancelled_regexp = re.compile( r"(.*:\s*)?(Terminating job)\s+(?P\d+)" ) - match = _SLURM_CANCELLED_REGEXP.match(stderr.strip()) + match = _slurm_cancelled_regexp.match(stderr.strip()) job_id = match.group("jobid") if match else None status = ( CancelStatus("SUCCESSFUL") if job_id else CancelStatus("JOB_ID_UNKNOWN") @@ -432,9 +433,9 @@ def parse_jobs_list_output(self, exit_code, stdout, stderr) -> list[QJob]: try: slurm_job_state = SlurmState(job_state_string) - except ValueError: + except ValueError as exc: msg = f"Unknown job state {job_state_string} for job id {qjob.job_id}" - raise OutputParsingError(msg) + raise OutputParsingError(msg) from exc qjob.sub_state = slurm_job_state qjob.state = slurm_job_state.qstate @@ -482,10 +483,7 @@ def parse_jobs_list_output(self, exit_code, stdout, stderr) -> list[QJob]: @staticmethod def _convert_str_to_time(time_str: str | None) -> int | None: - """ - Convert a string in the format used by SLURM DD-HH:MM:SS to a number of seconds. - """ - + """Convert a string in the format used by SLURM DD-HH:MM:SS to a number of seconds.""" if not time_str: return None @@ -509,10 +507,10 @@ def _convert_str_to_time(time_str: str | None) -> int | None: elif len(time_split) == 1: minutes = int(time_split[0]) else: - raise OutputParsingError() + raise OutputParsingError - except ValueError: - raise OutputParsingError() + except ValueError as exc: + raise OutputParsingError from exc return days * 86400 + hours * 3600 + minutes * 60 + seconds @@ -532,14 +530,14 @@ def _convert_memory_str(memory: str | None) -> int | None: memory = memory[:-1] try: v = int(memory) - except ValueError: - raise OutputParsingError + except ValueError as exc: + raise OutputParsingError from exc power_labels = {"K": 0, "M": 1, "G": 2, "T": 3} return v * (1024 ** power_labels[units]) @staticmethod - def _convert_time_to_str(time: int | float | timedelta) -> str: + def _convert_time_to_str(time: int | float | timedelta) -> str: # noqa: PYI041 if not isinstance(time, timedelta): time = timedelta(seconds=time) @@ -547,12 +545,11 @@ def _convert_time_to_str(time: int | float | timedelta) -> str: hours, remainder = divmod(time.seconds, 3600) minutes, seconds = divmod(remainder, 60) - time_str = f"{days}-{hours}:{minutes}:{seconds}" - return time_str + return f"{days}-{hours}:{minutes}:{seconds}" # helper attribute to match the values defined in QResources and # the dictionary that should be passed to the template - _qresources_mapping = { + _qresources_mapping: ClassVar = { "queue_name": "partition", "job_name": "job_name", "memory_per_thread": "mem-per-cpu", @@ -568,7 +565,6 @@ def _convert_qresources(self, resources: QResources) -> dict: Converts a Qresources instance to a dict that will be used to fill in the header of the submission script. """ - header_dict = {} for qr_field, slurm_field in self._qresources_mapping.items(): val = getattr(resources, qr_field) diff --git a/src/qtoolkit/manager.py b/src/qtoolkit/manager.py index ed50b77..dcc212f 100644 --- a/src/qtoolkit/manager.py +++ b/src/qtoolkit/manager.py @@ -1,12 +1,20 @@ from __future__ import annotations from pathlib import Path +from typing import TYPE_CHECKING from qtoolkit.core.base import QTKObject -from qtoolkit.core.data_objects import CancelResult, QJob, QResources, SubmissionResult -from qtoolkit.host.base import BaseHost from qtoolkit.host.local import LocalHost -from qtoolkit.io.base import BaseSchedulerIO + +if TYPE_CHECKING: + from qtoolkit.core.data_objects import ( + CancelResult, + QJob, + QResources, + SubmissionResult, + ) + from qtoolkit.host.base import BaseHost + from qtoolkit.io.base import BaseSchedulerIO class QueueManager(QTKObject): @@ -69,12 +77,12 @@ def get_environment_setup(self, env_config) -> str: if env_config: env_setup = [] if "modules" in env_config: - env_setup.append("module purge") - for mod in env_config["modules"]: - env_setup.append(f"module load {mod}") + env_setup += [f"module load {mod}" for mod in env_config["modules"]] if "source_files" in env_config: - for source_file in env_config["source_files"]: - env_setup.append(f"source {source_file}") + env_setup += [ + f"source {source_file}" + for source_file in env_config["source_files"] + ] if "conda_environment" in env_config: env_setup.append(f'conda activate {env_config["conda_environment"]}') if "environ" in env_config: @@ -102,10 +110,9 @@ def get_pre_run(self, pre_run) -> str: def get_run_commands(self, commands) -> str: if isinstance(commands, str): return commands - elif isinstance(commands, list): + if isinstance(commands, list): return "\n".join(commands) - else: - raise ValueError("commands should be a str or a list of str.") + raise ValueError("commands should be a str or a list of str.") def get_post_run(self, post_run) -> str: pass diff --git a/src/qtoolkit/utils.py b/src/qtoolkit/utils.py index aefa8b1..a6b3466 100644 --- a/src/qtoolkit/utils.py +++ b/src/qtoolkit/utils.py @@ -10,7 +10,7 @@ def cd(path: str | Path): """ A Fabric-inspired cd context that temporarily changes directory for performing some tasks, and returns to the original working directory - afterwards. e.g., + afterwards. e.g. with cd("/my/path/"): do_something() diff --git a/tests/__init__.py b/tests/__init__.py index c806bfd..4639f3e 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -5,5 +5,3 @@ module_dir = Path(__file__).resolve().parent test_dir = module_dir / "test_data" TEST_DIR = test_dir.resolve() - -__all__ = [TEST_DIR] diff --git a/tests/conftest.py b/tests/conftest.py index abb98cc..bc5e8d6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -49,7 +49,7 @@ def clean_dir(debug_mode): @pytest.fixture() def tmp_dir(): - """Same as clean_dir but is fresh for every test""" + """Same as clean_dir but is fresh for every test.""" import os import shutil import tempfile diff --git a/tests/core/test_data_objects.py b/tests/core/test_data_objects.py index e873328..cea64df 100644 --- a/tests/core/test_data_objects.py +++ b/tests/core/test_data_objects.py @@ -378,37 +378,37 @@ def test_get_processes_distribution(self): ) proc_distr = qr.get_processes_distribution() assert proc_distr == [3, 3, 1] + qr = QResources( + process_placement=ProcessPlacement.SCATTERED, nodes=3, processes=4 + ) with pytest.raises( UnsupportedResourcesError, match=r"ProcessPlacement.SCATTERED is incompatible " r"with different values of nodes and processes", ): - qr = QResources( - process_placement=ProcessPlacement.SCATTERED, nodes=3, processes=4 - ) qr.get_processes_distribution() qr = QResources( process_placement=ProcessPlacement.SCATTERED, nodes=None, processes=None ) proc_distr = qr.get_processes_distribution() assert proc_distr == [1, 1, 1] + qr = QResources( + process_placement=ProcessPlacement.SCATTERED, + nodes=4, + processes=4, + processes_per_node=2, + ) with pytest.raises( UnsupportedResourcesError, match=r"ProcessPlacement.SCATTERED is incompatible " r"with 2 processes_per_node", ): - qr = QResources( - process_placement=ProcessPlacement.SCATTERED, - nodes=4, - processes=4, - processes_per_node=2, - ) qr.get_processes_distribution() + qr = QResources(process_placement=ProcessPlacement.SAME_NODE, nodes=4) with pytest.raises( UnsupportedResourcesError, - match=r"ProcessPlacement.SAME_NODE is incompatible " r"with 4 nodes", + match=r"ProcessPlacement.SAME_NODE is incompatible with 4 nodes", ): - qr = QResources(process_placement=ProcessPlacement.SAME_NODE, nodes=4) qr.get_processes_distribution() qr = QResources( process_placement=ProcessPlacement.SAME_NODE, @@ -426,17 +426,17 @@ def test_get_processes_distribution(self): ) proc_distr = qr.get_processes_distribution() assert proc_distr == [1, 6, 6] + qr = QResources( + process_placement=ProcessPlacement.SAME_NODE, + nodes=1, + processes=2, + processes_per_node=6, + ) with pytest.raises( UnsupportedResourcesError, match=r"ProcessPlacement.SAME_NODE is incompatible with " r"different values of nodes and processes", ): - qr = QResources( - process_placement=ProcessPlacement.SAME_NODE, - nodes=1, - processes=2, - processes_per_node=6, - ) qr.get_processes_distribution() qr = QResources( process_placement=ProcessPlacement.SAME_NODE, @@ -486,41 +486,41 @@ def test_get_processes_distribution(self): ) proc_distr = qr.get_processes_distribution() assert proc_distr == [4, None, 1] + qr = QResources( + process_placement=ProcessPlacement.EVENLY_DISTRIBUTED, + nodes=1, + processes=2, + processes_per_node=6, + ) with pytest.raises( UnsupportedResourcesError, match=r"ProcessPlacement.EVENLY_DISTRIBUTED " r"is incompatible with processes attribute", ): - qr = QResources( - process_placement=ProcessPlacement.EVENLY_DISTRIBUTED, - nodes=1, - processes=2, - processes_per_node=6, - ) qr.get_processes_distribution() + qr = QResources( + process_placement=ProcessPlacement.NO_CONSTRAINTS, + nodes=1, + processes=2, + processes_per_node=None, + ) with pytest.raises( UnsupportedResourcesError, match=r"ProcessPlacement.NO_CONSTRAINTS is incompatible " r"with processes_per_node and nodes attribute", ): - qr = QResources( - process_placement=ProcessPlacement.NO_CONSTRAINTS, - nodes=1, - processes=2, - processes_per_node=None, - ) qr.get_processes_distribution() + qr = QResources( + process_placement=ProcessPlacement.NO_CONSTRAINTS, + nodes=None, + processes=2, + processes_per_node=2, + ) with pytest.raises( UnsupportedResourcesError, match=r"ProcessPlacement.NO_CONSTRAINTS is incompatible " r"with processes_per_node and nodes attribute", ): - qr = QResources( - process_placement=ProcessPlacement.NO_CONSTRAINTS, - nodes=None, - processes=2, - processes_per_node=2, - ) qr.get_processes_distribution() qr = QResources( process_placement=ProcessPlacement.NO_CONSTRAINTS, diff --git a/tests/io/test_base.py b/tests/io/test_base.py index 09aa6a5..368672f 100644 --- a/tests/io/test_base.py +++ b/tests/io/test_base.py @@ -144,15 +144,15 @@ def test_generate_header(self, scheduler): #SPECCMD --nodes=4""" ) + res = QResources( + nodes=4, + processes_per_node=16, + scheduler_kwargs={"tata": "tata", "titi": "titi"}, + ) with pytest.raises( ValueError, match=r"The following keys are not present in the template: tata, titi", ): - res = QResources( - nodes=4, - processes_per_node=16, - scheduler_kwargs={"tata": "tata", "titi": "titi"}, - ) scheduler.generate_header(res) def test_generate_ids_list(self, scheduler): diff --git a/tests/io/test_slurm.py b/tests/io/test_slurm.py index f3ffe00..4fd77b3 100644 --- a/tests/io/test_slurm.py +++ b/tests/io/test_slurm.py @@ -23,7 +23,7 @@ def slurm_io(): class TestSlurmState: - @pytest.mark.parametrize("slurm_state", [s for s in SlurmState]) + @pytest.mark.parametrize("slurm_state", list(SlurmState)) def test_qstate(self, slurm_state): assert isinstance(slurm_state.qstate, QState) assert SlurmState("CA") == SlurmState.CANCELLED