diff --git a/.docker/tests/conftest.py b/.docker/tests/conftest.py index e077d3916d..419559ba09 100644 --- a/.docker/tests/conftest.py +++ b/.docker/tests/conftest.py @@ -1,8 +1,7 @@ # -*- coding: utf-8 -*- -# pylint: disable=missing-docstring, redefined-outer-name import json -from pathlib import Path import time +from pathlib import Path import pytest @@ -13,13 +12,12 @@ def variant(request): @pytest.fixture(scope='session') -def docker_compose_file(pytestconfig, variant): # pylint: disable=unused-argument +def docker_compose_file(pytestconfig, variant): return f'docker-compose.{variant}.yml' @pytest.fixture(scope='session') def docker_compose(docker_services): - # pylint: disable=protected-access return docker_services._docker_compose @@ -31,7 +29,6 @@ def is_container_ready(docker_compose): @pytest.fixture(scope='session', autouse=True) def _docker_service_wait(docker_services): """Container startup wait.""" - time.sleep(30) @@ -42,7 +39,6 @@ def container_user(): @pytest.fixture def aiida_exec(docker_compose): - def execute(command, user=None, **kwargs): if user: command = f'exec -T --user={user} aiida {command}' diff --git a/.docker/tests/test_aiida.py b/.docker/tests/test_aiida.py index d0a073d733..9b03ce1d8e 100644 --- a/.docker/tests/test_aiida.py +++ b/.docker/tests/test_aiida.py @@ -1,9 +1,8 @@ # -*- coding: utf-8 -*- -# pylint: disable=missing-docstring import json -from packaging.version import parse import pytest +from packaging.version import parse def test_correct_python_version_installed(aiida_exec, python_version): diff --git a/.github/system_tests/test_daemon.py b/.github/system_tests/test_daemon.py index c3f8c14ecb..3bbcce353b 100644 --- a/.github/system_tests/test_daemon.py +++ b/.github/system_tests/test_daemon.py @@ -7,7 +7,6 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=no-name-in-module """Tests to run with a running daemon.""" import os import re @@ -40,7 +39,7 @@ from aiida.orm.nodes.caching import NodeCaching from aiida.plugins import CalculationFactory, WorkflowFactory from aiida.workflows.arithmetic.add_multiply import add, add_multiply -from tests.utils.memory import get_instances # pylint: disable=import-error +from tests.utils.memory import get_instances CODENAME_ADD = 'add@localhost' CODENAME_DOUBLER = 'doubler@localhost' @@ -56,10 +55,12 @@ def print_daemon_log(): print(f"Output of 'cat {daemon_log}':") try: - print(subprocess.check_output( - ['cat', f'{daemon_log}'], - stderr=subprocess.STDOUT, - )) + print( + subprocess.check_output( + ['cat', f'{daemon_log}'], + stderr=subprocess.STDOUT, + ) + ) except subprocess.CalledProcessError as exception: print(f'Note: the command failed, message: {exception}') @@ -81,10 +82,12 @@ def print_report(pk): """Print the process report for given pk.""" print(f"Output of 'verdi process report {pk}':") try: - print(subprocess.check_output( - ['verdi', 'process', 'report', f'{pk}'], - stderr=subprocess.STDOUT, - )) + print( + subprocess.check_output( + ['verdi', 'process', 'report', f'{pk}'], + stderr=subprocess.STDOUT, + ) + ) except subprocess.CalledProcessError as exception: print(f'Note: the command failed, message: {exception}') @@ -193,12 +196,9 @@ def validate_workchains(expected_results): def validate_cached(cached_calcs): - """ - Check that the calculations with created with caching are indeed cached. 
- """ + """Check that the calculations with created with caching are indeed cached.""" valid = True for calc in cached_calcs: - if not calc.is_finished_ok: print( 'Cached calculation<{}> not finished ok: process_state<{}> exit_status<{}>'.format( @@ -208,8 +208,9 @@ def validate_cached(cached_calcs): print_report(calc.pk) valid = False - if NodeCaching.CACHED_FROM_KEY not in calc.base.extras or calc.base.caching.get_hash( - ) != calc.base.extras.get('_aiida_hash'): + if NodeCaching.CACHED_FROM_KEY not in calc.base.extras or calc.base.caching.get_hash() != calc.base.extras.get( + '_aiida_hash' + ): print(f'Cached calculation<{calc.pk}> has invalid hash') print_report(calc.pk) valid = False @@ -270,9 +271,7 @@ def launch_workfunction(inputval): def launch_calculation(code, counter, inputval): - """ - Launch calculations to the daemon through the Process layer - """ + """Launch calculations to the daemon through the Process layer""" process, inputs, expected_result = create_calculation_process(code=code, inputval=inputval) calc = submit(process, **inputs) print(f'[{counter}] launched calculation {calc.uuid}, pk={calc.pk}') @@ -280,9 +279,7 @@ def launch_calculation(code, counter, inputval): def run_calculation(code, counter, inputval): - """ - Run a calculation through the Process layer. - """ + """Run a calculation through the Process layer.""" process, inputs, expected_result = create_calculation_process(code=code, inputval=inputval) _, calc = run.get_node(process, **inputs) print(f'[{counter}] ran calculation {calc.uuid}, pk={calc.pk}') @@ -290,28 +287,25 @@ def run_calculation(code, counter, inputval): def create_calculation_process(code, inputval): - """ - Create the process and inputs for a submitting / running a calculation. - """ - TemplatereplacerCalculation = CalculationFactory('core.templatereplacer') + """Create the process and inputs for a submitting / running a calculation.""" parameters = Dict({'value': inputval}) - template = Dict({ - # The following line adds a significant sleep time. - # I set it to 1 second to speed up tests - # I keep it to a non-zero value because I want - # To test the case when AiiDA finds some calcs - # in a queued state - # 'cmdline_params': ["{}".format(counter % 3)], # Sleep time - 'cmdline_params': ['1'], - 'input_file_template': '{value}', # File just contains the value to double - 'input_file_name': 'value_to_double.txt', - 'output_file_name': 'output.txt', - 'retrieve_temporary_files': ['triple_value.tmp'] - }) + template = Dict( + { + # The following line adds a significant sleep time. 
+ # I set it to 1 second to speed up tests + # I keep it to a non-zero value because I want + # To test the case when AiiDA finds some calcs + # in a queued state + # 'cmdline_params': ["{}".format(counter % 3)], # Sleep time + 'cmdline_params': ['1'], + 'input_file_template': '{value}', # File just contains the value to double + 'input_file_name': 'value_to_double.txt', + 'output_file_name': 'output.txt', + 'retrieve_temporary_files': ['triple_value.tmp'], + } + ) options = { - 'resources': { - 'num_machines': 1 - }, + 'resources': {'num_machines': 1}, 'max_wallclock_seconds': 5 * 60, 'withmpi': False, 'parser_name': 'core.templatereplacer', @@ -325,15 +319,13 @@ def create_calculation_process(code, inputval): 'template': template, 'metadata': { 'options': options, - } + }, } - return TemplatereplacerCalculation, inputs, expected_result + return CalculationFactory('core.templatereplacer'), inputs, expected_result def run_arithmetic_add(): """Run the `ArithmeticAddCalculation`.""" - ArithmeticAddCalculation = CalculationFactory('core.arithmetic.add') - code = load_code(CODENAME_ADD) inputs = { 'x': Int(1), @@ -342,7 +334,7 @@ def run_arithmetic_add(): } # Normal inputs should run just fine - results, node = run.get_node(ArithmeticAddCalculation, **inputs) + results, node = run.get_node(CalculationFactory('core.arithmetic.add'), **inputs) assert node.is_finished_ok, node.exit_status assert results['sum'] == 3 @@ -378,7 +370,7 @@ def run_base_restart_workchain(): inputs['add']['y'] = Int(10) results, node = run.get_node(ArithmeticAddBaseWorkChain, **inputs) assert not node.is_finished_ok, node.process_state - assert node.exit_status == ArithmeticAddBaseWorkChain.exit_codes.ERROR_TOO_BIG.status, node.exit_status # pylint: disable=no-member + assert node.exit_status == ArithmeticAddBaseWorkChain.exit_codes.ERROR_TOO_BIG.status, node.exit_status assert len(node.called) == 1 # Check that overriding default handler enabled status works @@ -386,14 +378,12 @@ def run_base_restart_workchain(): inputs['handler_overrides'] = Dict({'disabled_handler': True}) results, node = run.get_node(ArithmeticAddBaseWorkChain, **inputs) assert not node.is_finished_ok, node.process_state - assert node.exit_status == ArithmeticAddBaseWorkChain.exit_codes.ERROR_ENABLED_DOOM.status, node.exit_status # pylint: disable=no-member + assert node.exit_status == ArithmeticAddBaseWorkChain.exit_codes.ERROR_ENABLED_DOOM.status, node.exit_status assert len(node.called) == 1 def run_multiply_add_workchain(): """Run the `MultiplyAddWorkChain`.""" - MultiplyAddWorkChain = WorkflowFactory('core.arithmetic.multiply_add') - code = load_code(CODENAME_ADD) inputs = { 'x': Int(1), @@ -403,7 +393,7 @@ def run_multiply_add_workchain(): } # Normal inputs should run just fine - results, node = run.get_node(MultiplyAddWorkChain, **inputs) + results, node = run.get_node(WorkflowFactory('core.arithmetic.multiply_add'), **inputs) assert node.is_finished_ok, node.exit_status assert len(node.called) == 2 assert 'result' in results @@ -429,7 +419,6 @@ def launch_all(): :returns: dictionary with expected results and pks of all launched calculations and workchains """ - # pylint: disable=too-many-locals,too-many-statements expected_results_process_functions = {} expected_results_calculations = {} expected_results_workchains = {} @@ -451,7 +440,6 @@ def launch_all(): print('Testing the stashing functionality') process, inputs, expected_result = create_calculation_process(code=code_doubler, inputval=1) with tempfile.TemporaryDirectory() as tmpdir: - # 
Delete the temporary directory to test that the stashing functionality will create it if necessary shutil.rmtree(tmpdir, ignore_errors=True) @@ -571,8 +559,10 @@ def relaunch_cached(results): results['calculations'][calc.pk] = expected_result if not ( - validate_calculations(results['calculations']) and validate_workchains(results['workchains']) and - validate_cached(cached_calcs) and validate_process_functions(results['process_functions']) + validate_calculations(results['calculations']) + and validate_workchains(results['workchains']) + and validate_cached(cached_calcs) + and validate_process_functions(results['process_functions']) ): print_daemon_log() print('') @@ -586,7 +576,6 @@ def relaunch_cached(results): def main(): """Launch a bunch of calculation jobs and workchains.""" - results = launch_all() print('Waiting for end of execution...') @@ -603,19 +592,23 @@ def main(): print('#' * 78) print("Output of 'verdi process list -a':") try: - print(subprocess.check_output( - ['verdi', 'process', 'list', '-a'], - stderr=subprocess.STDOUT, - )) + print( + subprocess.check_output( + ['verdi', 'process', 'list', '-a'], + stderr=subprocess.STDOUT, + ) + ) except subprocess.CalledProcessError as exception: print(f'Note: the command failed, message: {exception}') print("Output of 'verdi daemon status':") try: - print(subprocess.check_output( - ['verdi', 'daemon', 'status'], - stderr=subprocess.STDOUT, - )) + print( + subprocess.check_output( + ['verdi', 'daemon', 'status'], + stderr=subprocess.STDOUT, + ) + ) except subprocess.CalledProcessError as exception: print(f'Note: the command failed, message: {exception}') diff --git a/.github/system_tests/workchains.py b/.github/system_tests/workchains.py index af2ef91c4f..43eed8c961 100644 --- a/.github/system_tests/workchains.py +++ b/.github/system_tests/workchains.py @@ -7,7 +7,6 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=invalid-name """Work chain implementations for testing purposes.""" from aiida.common import AttributeDict from aiida.engine import ( @@ -64,15 +63,15 @@ def setup(self): def sanity_check_not_too_big(self, node): """My puny brain cannot deal with numbers that I cannot count on my hand.""" if node.is_finished_ok and node.outputs.sum > 10: - return ProcessHandlerReport(True, self.exit_codes.ERROR_TOO_BIG) # pylint: disable=no-member + return ProcessHandlerReport(True, self.exit_codes.ERROR_TOO_BIG) @process_handler(priority=460, enabled=False) - def disabled_handler(self, node): # pylint: disable=unused-argument + def disabled_handler(self, node): """By default this is not enabled and so should never be called, irrespective of exit codes of sub process.""" - return ProcessHandlerReport(True, self.exit_codes.ERROR_ENABLED_DOOM) # pylint: disable=no-member + return ProcessHandlerReport(True, self.exit_codes.ERROR_ENABLED_DOOM) @process_handler(priority=450, exit_codes=ExitCode(1000, 'Unicorn encountered')) - def a_magic_unicorn_appeared(self, node): # pylint: disable=no-self-argument + def a_magic_unicorn_appeared(self, node): """As we all know unicorns do not exist so we should never have to deal with it.""" raise RuntimeError('this handler should never even have been called') @@ -85,9 +84,7 @@ def error_negative_sum(self, node): class NestedWorkChain(WorkChain): - """ - Nested workchain which creates a workflow where the nesting level is equal to 
its input.
-    """
+    """Nested workchain which creates a workflow where the nesting level is equal to its input."""

     @classmethod
     def define(cls, spec):
@@ -216,9 +213,7 @@ def do_test(self):


 class CalcFunctionRunnerWorkChain(WorkChain):
-    """
-    WorkChain which calls an InlineCalculation in its step.
-    """
+    """WorkChain which calls an InlineCalculation in its step."""

     @classmethod
     def define(cls, spec):
@@ -234,9 +229,7 @@ def do_run(self):


 class WorkFunctionRunnerWorkChain(WorkChain):
-    """
-    WorkChain which calls a workfunction in its step
-    """
+    """WorkChain which calls a workfunction in its step."""

     @classmethod
     def define(cls, spec):
diff --git a/.github/workflows/check_release_tag.py b/.github/workflows/check_release_tag.py
index 47b45865c5..b6552849f6 100644
--- a/.github/workflows/check_release_tag.py
+++ b/.github/workflows/check_release_tag.py
@@ -14,8 +14,11 @@ def get_version_from_module(content: str) -> str:
         raise IOError(f'Unable to parse module: {exc}')
     try:
         return next(
-            ast.literal_eval(statement.value) for statement in module.body if isinstance(statement, ast.Assign)
-            for target in statement.targets if isinstance(target, ast.Name) and target.id == '__version__'
+            ast.literal_eval(statement.value)
+            for statement in module.body
+            if isinstance(statement, ast.Assign)
+            for target in statement.targets
+            if isinstance(target, ast.Name) and target.id == '__version__'
         )
     except StopIteration:
         raise IOError('Unable to find __version__ in module')
diff --git a/.molecule/default/files/polish/cli.py b/.molecule/default/files/polish/cli.py
index 34f7ff0a5d..78de183426 100755
--- a/.molecule/default/files/polish/cli.py
+++ b/.molecule/default/files/polish/cli.py
@@ -25,7 +25,7 @@
 @options.CODE(
     type=types.CodeParamType(entry_point='core.arithmetic.add'),
     required=False,
-    help='Code to perform the add operations with. Required if -C flag is specified'
+    help='Code to perform the add operations with. Required if -C flag is specified',
 )
 @click.option(
     '-C',
     '--use-calculations',
     is_flag=True,
     default=False,
     show_default=True,
-    help='Use job calculations to perform all additions'
+    help='Use job calculations to perform all additions',
 )
 @click.option(
     '-F',
     '--use-calcfunctions',
     is_flag=True,
     default=False,
     show_default=True,
-    help='Use calcfunctions to perform all substractions'
+    help='Use calcfunctions to perform all subtractions',
 )
 @click.option(
     '-s',
     '--sleep',
     type=click.INT,
     default=5,
     show_default=True,
-    help='When submitting to the daemon, the number of seconds to sleep between polling the workchain process state'
+    help='When submitting to the daemon, the number of seconds to sleep between polling the workchain process state',
 )
 @click.option(
     '-t',
     '--timeout',
     type=click.INT,
     default=60,
     show_default=True,
-    help='When submitting to the daemon, the number of seconds to wait for a workchain to finish before timing out'
+    help='When submitting to the daemon, the number of seconds to wait for a workchain to finish before timing out',
 )
 @click.option(
     '-m',
     '--modulo',
     type=click.INT,
     default=1000000,
     show_default=True,
-    help='Specify an integer to modulo all intermediate and the final result to avoid integer overflow'
+    help='Specify an integer to modulo all intermediate and the final result to avoid integer overflow',
 )
 @click.option(
     '-n',
     '--dry-run',
     is_flag=True,
     default=False,
-    help='Only evaluate the expression and generate the workchain but do not launch it'
+    help='Only evaluate the expression and generate the workchain but do not launch it',
 )
 @decorators.with_dbenv()
 def launch(expression, code, use_calculations, use_calcfunctions, sleep, timeout, modulo, dry_run, daemon):
-    """
-    Evaluate the expression in Reverse Polish Notation in both a normal way and by procedurally generating
+    """Evaluate the expression in Reverse Polish Notation in both a normal way and by procedurally generating
     a workchain that encodes the sequence of operators and gets the stack of operands as an input. Multiplications
     are modelled by a 'while_' construct and addition will be done performed by an addition or a subtraction,
     depending on the sign, branched by the 'if_' construct. Powers will be simulated by nested workchains.
@@ -98,7 +97,6 @@ def launch(expression, code, use_calculations, use_calcfunctions, sleep, timeout

     If no expression is specified, a random one will be generated that adheres to these rules
     """
-    # pylint: disable=too-many-arguments,too-many-locals,too-many-statements,too-many-branches
     from aiida.engine import run_get_node
     from aiida.orm import AbstractCode, Int, Str

@@ -199,4 +197,4 @@ def run_via_daemon(workchains, inputs, sleep, timeout):


 if __name__ == '__main__':
-    launch()  # pylint: disable=no-value-for-parameter
+    launch()
diff --git a/.molecule/default/files/polish/lib/expression.py b/.molecule/default/files/polish/lib/expression.py
index 1bf2123970..83eae2f98a 100644
--- a/.molecule/default/files/polish/lib/expression.py
+++ b/.molecule/default/files/polish/lib/expression.py
@@ -20,8 +20,7 @@


 def generate(min_operator_count=3, max_operator_count=5, min_operand_value=-5, max_operand_value=5):
-    """
-    Generate a random valid expression in Reverse Polish Notation. There are a few limitations:
+    """Generate a random valid expression in Reverse Polish Notation.
There are a few limitations: * Only integers are supported * Only the addition, multiplication and power operators (+, * and ^, respectively) are supported @@ -58,8 +57,7 @@ def generate(min_operator_count=3, max_operator_count=5, min_operand_value=-5, m def validate(expression): - """ - Validate an expression in Reverse Polish Notation. In addition to normal rules, the following restrictions apply: + """Validate an expression in Reverse Polish Notation. In addition to normal rules, the following restrictions apply: * Only integers are supported * Only the addition, multiplication and power operators (+, * and ^, respectively) are supported @@ -69,7 +67,6 @@ def validate(expression): :param expression: the expression in Reverse Polish Notation :return: tuple(Bool, list) indicating whether expression is valid and if not a list of error messages """ - # pylint: disable=too-many-return-statements try: symbols = expression.split() except ValueError as exception: @@ -106,8 +103,7 @@ def validate(expression): def evaluate(expression, modulo=None): - """ - Evaluate an expression in Reverse Polish Notation. There are a few limitations: + """Evaluate an expression in Reverse Polish Notation. There are a few limitations: * Only integers are supported * Only the addition, multiplication and power operators (+, * and ^, respectively) are supported diff --git a/.molecule/default/files/polish/lib/workchain.py b/.molecule/default/files/polish/lib/workchain.py index a77e7f6b29..f73d08acf8 100644 --- a/.molecule/default/files/polish/lib/workchain.py +++ b/.molecule/default/files/polish/lib/workchain.py @@ -15,7 +15,7 @@ from pathlib import Path from string import Template -from .expression import OPERATORS # pylint: disable=relative-beyond-top-level +from .expression import OPERATORS INDENTATION_WIDTH = 4 @@ -71,8 +71,7 @@ def generate_outlines(expression): - """ - For a given expression in Reverse Polish Notation, generate the nested symbolic structure of the outlines. + """For a given expression in Reverse Polish Notation, generate the nested symbolic structure of the outlines. 
:param expression: a valid expression :return: a nested list structure of strings representing the structure of the outlines @@ -82,7 +81,6 @@ def generate_outlines(expression): outline = [['add']] for part in expression.split(): - if part not in OPERATORS: stack.appendleft(part) values.append(part) @@ -107,8 +105,7 @@ def generate_outlines(expression): def format_outlines(outlines, use_calculations=False, use_calcfunctions=False): - """ - Given the symbolic structure of the workchain outlines produced by ``generate_outlines``, format the actual + """Given the symbolic structure of the workchain outlines produced by ``generate_outlines``, format the actual string form of those workchain outlines :param outlines: the list of symbolic outline structures @@ -119,7 +116,6 @@ def format_outlines(outlines, use_calculations=False, use_calcfunctions=False): outline_strings = [] for sub_outline in outlines: - outline_string = '' for instruction in sub_outline: @@ -140,8 +136,7 @@ def format_outlines(outlines, use_calculations=False, use_calcfunctions=False): def format_block(instruction, level=0, use_calculations=False, use_calcfunctions=False): - """ - Format the instruction into its proper string form + """Format the instruction into its proper string form :param use_calculations: use CalcJobs for the add operations :param use_calcfunctions: use calcfunctions for the subtract operations @@ -176,8 +171,7 @@ def format_block(instruction, level=0, use_calculations=False, use_calcfunctions def format_indent(level=0, width=INDENTATION_WIDTH): - """ - Format the indentation for the given indentation level and indentation width + """Format the indentation for the given indentation level and indentation width :param level: the level of indentation :param width: the width in spaces of a single indentation @@ -187,8 +181,7 @@ def format_indent(level=0, width=INDENTATION_WIDTH): def write_workchain(outlines, directory=None) -> Path: - """ - Given a list of string formatted outlines, write the corresponding workchains to file + """Given a list of string formatted outlines, write the corresponding workchains to file :returns: file path """ @@ -219,10 +212,9 @@ def write_workchain(outlines, directory=None) -> Path: counter = len(outlines) - 1 for outline in outlines: - outline_string = '' for subline in outline.split('\n'): - outline_string += f'\t\t\t{subline}\n' # pylint: disable=consider-using-join + outline_string += f'\t\t\t{subline}\n' if counter == len(outlines) - 1: child_class = None diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 130fe378dc..bb16a675aa 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ ci: autoupdate_schedule: monthly autofix_prs: true - skip: [mypy, pylint, dm-generate-all, dependencies, verdi-autodocs] + skip: [mypy, dm-generate-all, dependencies, verdi-autodocs] repos: - repo: https://github.com/pre-commit/pre-commit-hooks @@ -21,12 +21,6 @@ repos: exclude: *exclude_pre_commit_hooks - id: check-yaml - -- repo: https://github.com/pycqa/isort - rev: 5.12.0 - hooks: - - id: isort - - repo: https://github.com/ikamensh/flynt/ rev: '1.0.1' hooks: @@ -36,18 +30,18 @@ repos: '--fail-on-change', ] -- repo: https://github.com/google/yapf - rev: v0.40.0 +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: 'v0.1.3' hooks: - - id: yapf - name: yapf - types: [python] - exclude: | + - id: ruff-format + exclude: &exclude_ruff > (?x)^( - docs/.*| + docs/source/topics/processes/include/snippets/functions/parse_docstring_expose_ipython.py| 
+ docs/source/topics/processes/include/snippets/functions/signature_plain_python_call_illegal.py| )$ - args: ['-i'] - additional_dependencies: ['toml'] + - id: ruff + exclude: *exclude_ruff + args: [--fix, --exit-non-zero-on-fix, --show-fixes] - repo: local @@ -194,17 +188,6 @@ repos: )$ - - id: pylint - name: pylint - entry: pylint - types: [python] - language: system - exclude: | - (?x)^( - docs/.*| - .docker/.*| - )$ - - id: dm-generate-all name: Update all requirements files entry: python ./utils/dependency_management.py generate-all diff --git a/aiida/__init__.py b/aiida/__init__.py index c7386ac933..9cadf4fa0d 100644 --- a/aiida/__init__.py +++ b/aiida/__init__.py @@ -7,8 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -""" -AiiDA is a flexible and scalable informatics' infrastructure to manage, +"""AiiDA is a flexible and scalable informatics' infrastructure to manage, preserve, and disseminate the simulations, data, and workflows of modern-day computational science. @@ -20,8 +19,8 @@ More information at http://www.aiida.net """ -from aiida.common.log import configure_logging -from aiida.manage.configuration import get_config_option, get_profile, load_profile, profile_context +from aiida.common.log import configure_logging # noqa: F401 +from aiida.manage.configuration import get_config_option, get_profile, load_profile, profile_context # noqa: F401 __copyright__ = ( 'Copyright (c), This file is part of the AiiDA platform. ' @@ -38,15 +37,15 @@ def get_strict_version(): - """ - Return a distutils StrictVersion instance with the current distribution version + """Return a distutils StrictVersion instance with the current distribution version :returns: StrictVersion instance with the current version :rtype: :class:`!distutils.version.StrictVersion` """ - from distutils.version import StrictVersion # pylint: disable=deprecated-module + from distutils.version import StrictVersion from aiida.common.warnings import warn_deprecation + warn_deprecation( 'This method is deprecated as the `distutils` package it uses will be removed in Python 3.12.', version=3 ) @@ -54,8 +53,7 @@ def get_strict_version(): def get_version() -> str: - """ - Return the current AiiDA distribution version + """Return the current AiiDA distribution version :returns: the current version """ @@ -63,8 +61,7 @@ def get_version() -> str: def _get_raw_file_header() -> str: - """ - Get the default header for source AiiDA source code files. + """Get the default header for source AiiDA source code files. Note: is not preceded by comment character. :return: default AiiDA source file header @@ -76,8 +73,7 @@ def _get_raw_file_header() -> str: def get_file_header(comment_char: str = '# ') -> str: - """ - Get the default header for source AiiDA source code files. + """Get the default header for source AiiDA source code files. .. 
note:: @@ -94,4 +90,5 @@ def get_file_header(comment_char: str = '# ') -> str: def load_ipython_extension(ipython): """Load the AiiDA IPython extension, using ``%load_ext aiida``.""" from .tools.ipython.ipython_magics import AiiDALoaderMagics + ipython.register_magics(AiiDALoaderMagics) diff --git a/aiida/__main__.py b/aiida/__main__.py index bf661ecdfe..f828d55752 100644 --- a/aiida/__main__.py +++ b/aiida/__main__.py @@ -12,4 +12,5 @@ if __name__ == '__main__': from aiida.cmdline.commands.cmd_verdi import verdi + sys.exit(verdi()) diff --git a/aiida/calculations/diff_tutorial/calculations.py b/aiida/calculations/diff_tutorial/calculations.py index 5e3887a90b..c6ea08590a 100644 --- a/aiida/calculations/diff_tutorial/calculations.py +++ b/aiida/calculations/diff_tutorial/calculations.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- -""" -Calculations provided by aiida_diff tutorial plugin. +"""Calculations provided by aiida_diff tutorial plugin. Register calculations via the "aiida.calculations" entry point in the pyproject.toml file. """ @@ -10,8 +9,7 @@ class DiffCalculation(CalcJob): - """ - AiiDA calculation plugin wrapping the diff executable. + """AiiDA calculation plugin wrapping the diff executable. Simple AiiDA plugin wrapper for 'diffing' two files. """ @@ -19,7 +17,6 @@ class DiffCalculation(CalcJob): @classmethod def define(cls, spec): """Define inputs and outputs of the calculation.""" - # yapf: disable super(DiffCalculation, cls).define(spec) # new ports @@ -29,18 +26,17 @@ def define(cls, spec): spec.input('metadata.options.output_filename', valid_type=str, default='patch.diff') spec.inputs['metadata']['options']['resources'].default = { - 'num_machines': 1, - 'num_mpiprocs_per_machine': 1, - } + 'num_machines': 1, + 'num_mpiprocs_per_machine': 1, + } spec.inputs['metadata']['options']['parser_name'].default = 'diff-tutorial' - spec.exit_code(300, 'ERROR_MISSING_OUTPUT_FILES', - message='Calculation did not produce all expected output files.') - + spec.exit_code( + 300, 'ERROR_MISSING_OUTPUT_FILES', message='Calculation did not produce all expected output files.' + ) def prepare_for_submission(self, folder): - """ - Create input files. + """Create input files. :param folder: an `aiida.common.folders.Folder` where the plugin should temporarily place all files needed by the calculation. diff --git a/aiida/calculations/monitors/base.py b/aiida/calculations/monitors/base.py index 9b4f2fd55e..5dfd8ceee6 100644 --- a/aiida/calculations/monitors/base.py +++ b/aiida/calculations/monitors/base.py @@ -8,7 +8,7 @@ from aiida.transports import Transport -def always_kill(node: CalcJobNode, transport: Transport) -> str | None: # pylint: disable=unused-argument +def always_kill(node: CalcJobNode, transport: Transport) -> str | None: """Retrieve and inspect files in working directory of job to determine whether the job should be killed. This particular implementation is just for demonstration purposes and will kill the job as long as there is a diff --git a/aiida/calculations/templatereplacer.py b/aiida/calculations/templatereplacer.py index 0e3da70496..ba9a3f999c 100644 --- a/aiida/calculations/templatereplacer.py +++ b/aiida/calculations/templatereplacer.py @@ -17,8 +17,7 @@ class TemplatereplacerCalculation(CalcJob): - """ - Simple stub of a plugin that can be used to replace some text in a given template. + """Simple stub of a plugin that can be used to replace some text in a given template. Can be used for many different codes, or as a starting point to develop a new plugin. 
This simple plugin takes two node inputs, both of type Dict, with the labels @@ -68,7 +67,7 @@ def define(cls, spec): 'parameters', valid_type=orm.Dict, required=False, - help='Parameters used to replace placeholders in the template.' + help='Parameters used to replace placeholders in the template.', ) spec.input_namespace('files', valid_type=(orm.RemoteData, orm.SinglefileData), required=False, dynamic=True) @@ -79,37 +78,35 @@ def define(cls, spec): 301, 'ERROR_NO_TEMPORARY_RETRIEVED_FOLDER', invalidates_cache=True, - message='The temporary retrieved folder data node could not be accessed.' + message='The temporary retrieved folder data node could not be accessed.', ) spec.exit_code( 305, 'ERROR_NO_OUTPUT_FILE_NAME_DEFINED', invalidates_cache=True, - message='The `template` input node did not specify the key `output_file_name`.' + message='The `template` input node did not specify the key `output_file_name`.', ) spec.exit_code( 310, 'ERROR_READING_OUTPUT_FILE', invalidates_cache=True, - message='The output file could not be read from the retrieved folder.' + message='The output file could not be read from the retrieved folder.', ) spec.exit_code( 311, 'ERROR_READING_TEMPORARY_RETRIEVED_FILE', invalidates_cache=True, - message='A temporary retrieved file could not be read from the temporary retrieved folder.' + message='A temporary retrieved file could not be read from the temporary retrieved folder.', ) spec.exit_code( 320, 'ERROR_INVALID_OUTPUT', invalidates_cache=True, message='The output file contains invalid output.' ) def prepare_for_submission(self, folder): - """ - This is the routine to be called when you want to create the input files and related stuff with a plugin. + """This is the routine to be called when you want to create the input files and related stuff with a plugin. :param folder: a aiida.common.folders.Folder subclass where the plugin should put all its files. 
""" - # pylint: disable=too-many-locals,too-many-statements,too-many-branches from aiida.common.exceptions import ValidationError from aiida.common.utils import validate_list_of_string_tuples @@ -174,9 +171,8 @@ def prepare_for_submission(self, folder): input_content = input_file_template.format(**parameters) if input_file_name: folder.create_file_from_filelike(io.StringIO(input_content), input_file_name, 'w', encoding='utf8') - else: - if input_file_template: - self.logger.warning('No input file name passed, but a input file template is present') + elif input_file_template: + self.logger.warning('No input file name passed, but a input file template is present') cmdline_params = [i.format(**parameters) for i in cmdline_params_tmpl] diff --git a/aiida/calculations/transfer.py b/aiida/calculations/transfer.py index 45ded1e2f3..3bd1268887 100644 --- a/aiida/calculations/transfer.py +++ b/aiida/calculations/transfer.py @@ -18,7 +18,6 @@ def validate_instructions(instructions, _): """Check that the instructions dict contains the necessary keywords""" - instructions_dict = instructions.get_dict() retrieve_files = instructions_dict.get('retrieve_files', None) @@ -59,7 +58,6 @@ def validate_instructions(instructions, _): def validate_transfer_inputs(inputs, _): """Check that the instructions dict and the source nodes are consistent""" - source_nodes = inputs['source_nodes'] instructions = inputs['instructions'] computer = inputs['metadata']['computer'] @@ -117,7 +115,6 @@ def validate_transfer_inputs(inputs, _): def check_node_type(list_name, node_label, node_object, node_type): """Common utility function to check the type of a node""" - if node_object is None: return f' > node `{node_label}` requested on list `{list_name}` not found among inputs' @@ -217,34 +214,37 @@ def prepare_for_submission(self, folder): retrieve_paths = [] for source_label, source_relpath, target_relpath in local_files: - source_node = source_nodes[source_label] retrieve_paths.append(target_relpath) - calc_info.local_copy_list.append(( - source_node.uuid, - source_relpath, - target_relpath, - )) + calc_info.local_copy_list.append( + ( + source_node.uuid, + source_relpath, + target_relpath, + ) + ) for source_label, source_relpath, target_relpath in remote_files: - source_node = source_nodes[source_label] retrieve_paths.append(target_relpath) - calc_info.remote_copy_list.append(( - source_node.computer.uuid, - os.path.join(source_node.get_remote_path(), source_relpath), - target_relpath, - )) + calc_info.remote_copy_list.append( + ( + source_node.computer.uuid, + os.path.join(source_node.get_remote_path(), source_relpath), + target_relpath, + ) + ) for source_label, source_relpath, target_relpath in symlink_files: - source_node = source_nodes[source_label] retrieve_paths.append(target_relpath) - calc_info.remote_symlink_list.append(( - source_node.computer.uuid, - os.path.join(source_node.get_remote_path(), source_relpath), - target_relpath, - )) + calc_info.remote_symlink_list.append( + ( + source_node.computer.uuid, + os.path.join(source_node.get_remote_path(), source_relpath), + target_relpath, + ) + ) if retrieve_files: calc_info.retrieve_list = retrieve_paths diff --git a/aiida/cmdline/__init__.py b/aiida/cmdline/__init__.py index 4b3e166a76..4a08002a2b 100644 --- a/aiida/cmdline/__init__.py +++ b/aiida/cmdline/__init__.py @@ -11,8 +11,7 @@ # AUTO-GENERATED -# yapf: disable -# pylint: disable=wildcard-import +# fmt: off from .groups import * from .params import * @@ -61,4 +60,4 @@ 'with_dbenv', ) -# yapf: enable +# fmt: 
on diff --git a/aiida/cmdline/commands/__init__.py b/aiida/cmdline/commands/__init__.py index 8b99390a26..00a3dc45e1 100644 --- a/aiida/cmdline/commands/__init__.py +++ b/aiida/cmdline/commands/__init__.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# ruff: noqa: F401, E402 """Sub commands of the ``verdi`` command line interface. The commands need to be imported here for them to be registered with the top-level command group. diff --git a/aiida/cmdline/commands/cmd_archive.py b/aiida/cmdline/commands/cmd_archive.py index 1c933c0e64..c69ab22d4b 100644 --- a/aiida/cmdline/commands/cmd_archive.py +++ b/aiida/cmdline/commands/cmd_archive.py @@ -7,12 +7,11 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=too-many-arguments,import-error,too-many-locals,broad-except """`verdi archive` command.""" -from enum import Enum import logging -from pathlib import Path import traceback +from enum import Enum +from pathlib import Path from typing import List, Tuple import click @@ -43,6 +42,7 @@ def archive_version(path): # note: this mirrors `cmd_storage:storage_version` # it is currently hardcoded to the `SqliteZipBackend`, but could be generalized in the future from aiida.storage.sqlite_zip.backend import SqliteZipBackend + storage_cls = SqliteZipBackend profile = storage_cls.create_profile(path) head_version = storage_cls.version_head() @@ -62,6 +62,7 @@ def archive_info(path, detailed): # note: this mirrors `cmd_storage:storage_info` # it is currently hardcoded to the `SqliteZipBackend`, but could be generalized in the future from aiida.storage.sqlite_zip.backend import SqliteZipBackend + try: storage = SqliteZipBackend(SqliteZipBackend.create_profile(path)) except (UnreachableStorage, CorruptStorage) as exc: @@ -87,7 +88,7 @@ def archive_info(path, detailed): 'Please call `verdi archive version` or `verdi archive info` instead.\n' ) @click.pass_context -def inspect(ctx, archive, version, meta_data, database): # pylint: disable=unused-argument +def inspect(ctx, archive, version, meta_data, database): """Inspect contents of an archive without importing it. .. deprecated:: v2.0.0, use `verdi archive version` or `verdi archive info` instead. @@ -113,19 +114,19 @@ def inspect(ctx, archive, version, meta_data, database): # pylint: disable=unus '--include-logs/--exclude-logs', default=True, show_default=True, - help='Include or exclude logs for node(s) in export.' + help='Include or exclude logs for node(s) in export.', ) @click.option( '--include-comments/--exclude-comments', default=True, show_default=True, - help='Include or exclude comments for node(s) in export. (Will also export extra users who commented).' + help='Include or exclude comments for node(s) in export. (Will also export extra users who commented).', ) @click.option( '--include-authinfos/--exclude-authinfos', default=False, show_default=True, - help='Include or exclude authentication information for computer(s) in export.' 
+ help='Include or exclude authentication information for computer(s) in export.', ) @click.option('--compress', default=6, show_default=True, type=int, help='Level of compression to use (0-9).') @click.option( @@ -134,9 +135,25 @@ def inspect(ctx, archive, version, meta_data, database): # pylint: disable=unus @click.option('--test-run', is_flag=True, help='Determine entities to export, but do not create the archive.') @decorators.with_dbenv() def create( - output_file, all_entries, codes, computers, groups, nodes, force, input_calc_forward, input_work_forward, - create_backward, return_backward, call_calc_backward, call_work_backward, include_comments, include_logs, - include_authinfos, compress, batch_size, test_run + output_file, + all_entries, + codes, + computers, + groups, + nodes, + force, + input_calc_forward, + input_work_forward, + create_backward, + return_backward, + call_calc_backward, + call_work_backward, + include_comments, + include_logs, + include_authinfos, + compress, + batch_size, + test_run, ): """Create an archive from all or part of a profiles's data. @@ -146,7 +163,6 @@ def create( their provenance, according to the rules outlined in the documentation. You can modify some of those rules using options of this command. """ - # pylint: disable=too-many-branches from aiida.common.progress_reporter import set_progress_bar_tqdm, set_progress_reporter from aiida.tools.archive.abstract import get_format from aiida.tools.archive.create import create_archive @@ -184,10 +200,10 @@ def create( 'overwrite': force, 'compression': compress, 'batch_size': batch_size, - 'test_run': test_run + 'test_run': test_run, } - if AIIDA_LOGGER.level <= logging.REPORT: # pylint: disable=no-member + if AIIDA_LOGGER.level <= logging.REPORT: set_progress_bar_tqdm(leave=AIIDA_LOGGER.level <= logging.INFO) else: set_progress_reporter(None) @@ -230,7 +246,7 @@ def migrate(input_file, output_file, force, in_place, version): 'no output file specified. Please add --in-place flag if you would like to migrate in place.' ) - if AIIDA_LOGGER.level <= logging.REPORT: # pylint: disable=no-member + if AIIDA_LOGGER.level <= logging.REPORT: set_progress_bar_tqdm(leave=AIIDA_LOGGER.level <= logging.INFO) else: set_progress_reporter(None) @@ -242,7 +258,7 @@ def migrate(input_file, output_file, force, in_place, version): try: archive_format.migrate(input_file, output_file, version, force=force, compression=6) - except Exception as error: # pylint: disable=broad-except + except Exception as error: if AIIDA_LOGGER.level <= logging.DEBUG: raise echo.echo_critical( @@ -255,7 +271,7 @@ def migrate(input_file, output_file, force, in_place, version): class ExtrasImportCode(Enum): """Exit codes for the verdi command line.""" - # pylint: disable=invalid-name + keep_existing = ('k', 'c', 'l') update_existing = ('k', 'c', 'u') mirror = ('n', 'c', 'u') @@ -270,18 +286,18 @@ class ExtrasImportCode(Enum): type=click.STRING, cls=options.MultipleValueOption, help='Discover all URL targets pointing to files with the .aiida extension for these HTTP addresses. ' - 'Automatically discovered archive URLs will be downloaded and added to ARCHIVES for importing.' 
+ 'Automatically discovered archive URLs will be downloaded and added to ARCHIVES for importing.', ) @click.option( '--import-group/--no-import-group', default=True, show_default=True, - help='Add all imported nodes to the specified group, or an automatically created one' + help='Add all imported nodes to the specified group, or an automatically created one', ) @options.GROUP( type=GroupParamType(create_if_not_exist=True), help='Specify group to which all the import nodes will be added. If such a group does not exist, it will be' - ' created automatically.' + ' created automatically.', ) @click.option( '-e', @@ -293,16 +309,14 @@ class ExtrasImportCode(Enum): 'none: do not import any extras.' 'keep_existing: import all extras and keep original value of existing extras. ' 'update_existing: import all extras and overwrite value of existing extras. ' - 'mirror: import all extras and remove any existing extras that are not present in the archive. ' + 'mirror: import all extras and remove any existing extras that are not present in the archive. ', ) @click.option( '-n', '--extras-mode-new', type=click.Choice(EXTRAS_MODE_NEW), default='import', - help='Specify whether to import extras of new nodes: ' - 'import: import extras. ' - 'none: do not import extras.' + help='Specify whether to import extras of new nodes: ' 'import: import extras. ' 'none: do not import extras.', ) @click.option( '--comment-mode', @@ -311,19 +325,19 @@ class ExtrasImportCode(Enum): help='Specify the way to import Comments with identical UUIDs: ' 'leave: Leave the existing Comments in the database (default).' 'newest: Use only the newest Comments (based on mtime).' - 'overwrite: Replace existing Comments with those from the import file.' + 'overwrite: Replace existing Comments with those from the import file.', ) @click.option( '--include-authinfos/--exclude-authinfos', default=False, show_default=True, - help='Include or exclude authentication information for computer(s) in import.' + help='Include or exclude authentication information for computer(s) in import.', ) @click.option( '--migration/--no-migration', default=True, show_default=True, - help='Force migration of archive file archives, if needed.' + help='Force migration of archive file archives, if needed.', ) @click.option( '-b', '--batch-size', default=1000, type=int, help='Stream database rows in batches, to reduce memory usage.' @@ -332,17 +346,26 @@ class ExtrasImportCode(Enum): @decorators.with_dbenv() @click.pass_context def import_archive( - ctx, archives, webpages, extras_mode_existing, extras_mode_new, comment_mode, include_authinfos, migration, - batch_size, import_group, group, test_run + ctx, + archives, + webpages, + extras_mode_existing, + extras_mode_new, + comment_mode, + include_authinfos, + migration, + batch_size, + import_group, + group, + test_run, ): """Import archived data to a profile. The archive can be specified by its relative or absolute file path, or its HTTP URL. 
""" - # pylint: disable=unused-argument from aiida.common.progress_reporter import set_progress_bar_tqdm, set_progress_reporter - if AIIDA_LOGGER.level <= logging.REPORT: # pylint: disable=no-member + if AIIDA_LOGGER.level <= logging.REPORT: set_progress_bar_tqdm(leave=AIIDA_LOGGER.level <= logging.INFO) else: set_progress_reporter(None) @@ -378,7 +401,8 @@ def _echo_exception(msg: str, exception, warn_only: bool = False): """ from aiida.tools.archive.imports import IMPORT_LOGGER - message = f'{msg}: {exception.__class__.__name__}: {str(exception)}' + + message = f'{msg}: {exception.__class__.__name__}: {exception!s}' if warn_only: echo.echo_warning(message) else: @@ -441,7 +465,6 @@ def _import_archive_and_migrate( filepath = ctx.obj['config'].get_option('storage.sandbox') or None with SandboxFolder(filepath=filepath) as temp_folder: - archive_path = archive if web_based: @@ -460,7 +483,6 @@ def _import_archive_and_migrate( _import_archive(archive_path, archive_format=archive_format, **import_kwargs) except IncompatibleStorageSchema as exception: if try_migration: - echo.echo_report(f'incompatible version detected for {archive}, trying migration') try: new_path = temp_folder.get_abs_path('migrated_archive.aiida') diff --git a/aiida/cmdline/commands/cmd_calcjob.py b/aiida/cmdline/commands/cmd_calcjob.py index 14eaa2b387..fa87a26856 100644 --- a/aiida/cmdline/commands/cmd_calcjob.py +++ b/aiida/cmdline/commands/cmd_calcjob.py @@ -7,7 +7,6 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=invalid-name,too-many-locals """`verdi calcjob` commands.""" import os @@ -27,8 +26,7 @@ def verdi_calcjob(): @verdi_calcjob.command('gotocomputer') @arguments.CALCULATION('calcjob', type=CalculationParamType(sub_classes=('aiida.node:process.calculation.calcjob',))) def calcjob_gotocomputer(calcjob): - """ - Open a shell in the remote folder on the calcjob. + """Open a shell in the remote folder on the calcjob. This command opens a ssh connection to the folder on the remote computer on which the calcjob is being/has been executed. @@ -80,16 +78,15 @@ def calcjob_res(calcjob, fmt, keys): @click.argument('path', type=click.STRING, required=False) @decorators.with_dbenv() def calcjob_inputcat(calcjob, path): - """ - Show the contents of one of the calcjob input files. + """Show the contents of one of the calcjob input files. You can specify the relative PATH in the raw input folder of the CalcJob. If PATH is not specified, the default input file path will be used, if defined by the calcjob plugin class. """ import errno - from shutil import copyfileobj import sys + from shutil import copyfileobj # Get path from the given CalcJobNode if not defined by user if path is None: @@ -151,8 +148,7 @@ def calcjob_remotecat(calcjob, path): @click.argument('path', type=click.STRING, required=False) @decorators.with_dbenv() def calcjob_outputcat(calcjob, path): - """ - Show the contents of one of the calcjob retrieved outputs. + """Show the contents of one of the calcjob retrieved outputs. You can specify the relative PATH in the retrieved folder of the CalcJob. @@ -160,8 +156,8 @@ def calcjob_outputcat(calcjob, path): Content can only be shown after the daemon has retrieved the remote files. 
""" import errno - from shutil import copyfileobj import sys + from shutil import copyfileobj try: retrieved = calcjob.outputs.retrieved @@ -204,8 +200,7 @@ def calcjob_outputcat(calcjob, path): @click.argument('path', type=click.STRING, required=False) @click.option('-c', '--color', 'color', is_flag=True, default=False, help='color folders with a different color') def calcjob_inputls(calcjob, path, color): - """ - Show the list of the generated calcjob input files. + """Show the list of the generated calcjob input files. You can specify a relative PATH in the raw input folder of the CalcJob. @@ -225,8 +220,7 @@ def calcjob_inputls(calcjob, path, color): @click.argument('path', type=click.STRING, required=False) @click.option('-c', '--color', 'color', is_flag=True, default=False, help='color folders with a different color') def calcjob_outputls(calcjob, path, color): - """ - Show the list of the retrieved calcjob output files. + """Show the list of the retrieved calcjob output files. You can specify a relative PATH in the retrieved folder of the CalcJob. @@ -255,8 +249,7 @@ def calcjob_outputls(calcjob, path, color): @options.FORCE() @options.EXIT_STATUS() def calcjob_cleanworkdir(calcjobs, past_days, older_than, computers, force, exit_status): - """ - Clean all content of all output remote folders of calcjobs. + """Clean all content of all output remote folders of calcjobs. If no explicit calcjobs are specified as arguments, one or both of the -p and -o options has to be specified. If both are specified, a logical AND is done between the two, i.e. the calcjobs that will be cleaned have been @@ -266,11 +259,10 @@ def calcjob_cleanworkdir(calcjobs, past_days, older_than, computers, force, exit from aiida.orm.utils.remote import get_calcjob_remote_paths if calcjobs: - if (past_days is not None and older_than is not None): + if past_days is not None and older_than is not None: echo.echo_critical('specify either explicit calcjobs or use the filtering options') - else: - if (past_days is None and older_than is None): - echo.echo_critical('if no explicit calcjobs are specified, at least one filtering option is required') + elif past_days is None and older_than is None: + echo.echo_critical('if no explicit calcjobs are specified, at least one filtering option is required') calcjobs_pks = [calcjob.pk for calcjob in calcjobs] path_mapping = get_calcjob_remote_paths( @@ -293,14 +285,13 @@ def calcjob_cleanworkdir(calcjobs, past_days, older_than, computers, force, exit user = orm.User.collection.get_default() for computer_uuid, paths in path_mapping.items(): - counter = 0 computer = orm.load_computer(uuid=computer_uuid) transport = orm.AuthInfo.collection.get(dbcomputer_id=computer.pk, aiidauser_id=user.pk).get_transport() with transport: for remote_folder in paths: - remote_folder._clean(transport=transport) # pylint:disable=protected-access + remote_folder._clean(transport=transport) counter += 1 echo.echo_success(f'{counter} remote folders cleaned on {computer.label}') diff --git a/aiida/cmdline/commands/cmd_code.py b/aiida/cmdline/commands/cmd_code.py index eb4cc1ea7c..26f9e8d6d8 100644 --- a/aiida/cmdline/commands/cmd_code.py +++ b/aiida/cmdline/commands/cmd_code.py @@ -27,7 +27,7 @@ def verdi_code(): """Setup and manage codes.""" -def create_code(ctx: click.Context, cls, non_interactive: bool, **kwargs): # pylint: disable=unused-argument +def create_code(ctx: click.Context, cls, non_interactive: bool, **kwargs): """Create a new `Code` instance.""" try: instance = cls(**kwargs) @@ -47,15 +47,14 
@@ def create_code(ctx: click.Context, cls, non_interactive: bool, **kwargs): # py cls=DynamicEntryPointCommandGroup, command=create_code, entry_point_group='aiida.data', - entry_point_name_filter=r'core\.code\..*' + entry_point_name_filter=r'core\.code\..*', ) def code_create(): """Create a new code.""" def get_default(key, ctx): - """ - Get the default argument using a user instance property + """Get the default argument using a user instance property :param value: The name of the property to use :param ctx: The click context (which will be used to get the user) :return: The default value, or None @@ -78,10 +77,10 @@ def get_on_computer(ctx): return not getattr(ctx.code_builder, 'is_local')() -# pylint: disable=unused-argument def set_code_builder(ctx, param, value): """Set the code spec for defaults of following options.""" from aiida.orm.utils.builders.code import CodeBuilder + ctx.code_builder = CodeBuilder.from_code(value) return value @@ -131,7 +130,7 @@ def setup_code(ctx, non_interactive, **kwargs): try: code.store() - except Exception as exception: # pylint: disable=broad-except + except Exception as exception: echo.echo_critical(f'Unable to store the Code: {exception}') echo.echo_success(f'Code<{code.pk}> {code.full_label} created') @@ -339,7 +338,7 @@ def relabel(code, label): '-d', '--default-calc-job-plugin', type=types.PluginParamType(group='calculations', load=False), - help='Filter codes by their optional default calculation job plugin.' + help='Filter codes by their optional default calculation job plugin.', ) @options.ALL(help='Include hidden codes.') @options.ALL_USERS(help='Include codes from all users.') @@ -348,7 +347,6 @@ def relabel(code, label): @click.option('-o', '--show-owner', 'show_owner', is_flag=True, default=False, help='Show owners of codes.') @with_dbenv() def code_list(computer, default_calc_job_plugin, all_entries, all_users, raw, show_owner, project): - # pylint: disable=too-many-branches,too-many-locals """List the available codes.""" from aiida import orm from aiida.orm.utils.node import load_node_class @@ -388,14 +386,10 @@ def code_list(computer, default_calc_job_plugin, all_entries, all_users, raw, sh tag='computer', with_node='code', project=projections.get('computer', None), - filters=filters.get('computer', None) + filters=filters.get('computer', None), ) query.append( - orm.User, - tag='user', - with_node='code', - project=projections.get('user', None), - filters=filters.get('user', None) + orm.User, tag='user', with_node='code', project=projections.get('user', None), filters=filters.get('user', None) ) query.order_by({'code': {'id': 'asc'}}) diff --git a/aiida/cmdline/commands/cmd_computer.py b/aiida/cmdline/commands/cmd_computer.py index 27fba6bfc9..66303551cd 100644 --- a/aiida/cmdline/commands/cmd_computer.py +++ b/aiida/cmdline/commands/cmd_computer.py @@ -7,7 +7,6 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=invalid-name,too-many-statements,too-many-branches """`verdi computer` command.""" from copy import deepcopy from functools import partial @@ -30,10 +29,9 @@ def verdi_computer(): def get_computer_names(): - """ - Retrieve the list of computers in the DB. 
- """ + """Retrieve the list of computers in the DB.""" from aiida.orm.querybuilder import QueryBuilder + builder = QueryBuilder() builder.append(entity_type='computer', project=['label']) if builder.count() > 0: @@ -42,11 +40,11 @@ def get_computer_names(): return [] -def prompt_for_computer_configuration(computer): # pylint: disable=unused-argument +def prompt_for_computer_configuration(computer): pass -def _computer_test_get_jobs(transport, scheduler, authinfo, computer): # pylint: disable=unused-argument +def _computer_test_get_jobs(transport, scheduler, authinfo, computer): """Internal test to check if it is possible to check the queue state. :param transport: an open transport @@ -58,7 +56,7 @@ def _computer_test_get_jobs(transport, scheduler, authinfo, computer): # pylint return True, f'{len(found_jobs)} jobs found in the queue' -def _computer_test_no_unexpected_output(transport, scheduler, authinfo, computer): # pylint: disable=unused-argument +def _computer_test_no_unexpected_output(transport, scheduler, authinfo, computer): """Test that there is no unexpected output from the connection. This can happen if e.g. there is some spurious command in the @@ -93,7 +91,7 @@ def _computer_test_no_unexpected_output(transport, scheduler, authinfo, computer return True, None -def _computer_get_remote_username(transport, scheduler, authinfo, computer): # pylint: disable=unused-argument +def _computer_get_remote_username(transport, scheduler, authinfo, computer): """Internal test to check if it is possible to determine the username on the remote. :param transport: an open transport @@ -105,9 +103,8 @@ def _computer_get_remote_username(transport, scheduler, authinfo, computer): # return True, remote_user -def _computer_create_temp_file(transport, scheduler, authinfo, computer): # pylint: disable=unused-argument - """ - Internal test to check if it is possible to create a temporary file +def _computer_create_temp_file(transport, scheduler, authinfo, computer): + """Internal test to check if it is possible to create a temporary file and then delete it in the work directory :note: exceptions could be raised @@ -187,7 +184,7 @@ def time_use_login_shell(authinfo, auth_params, use_login_shell: bool, iteration return sum(timings) / iterations -def _computer_use_login_shell_performance(transport, scheduler, authinfo, computer): # pylint: disable=unused-argument +def _computer_use_login_shell_performance(transport, scheduler, authinfo, computer): """Execute a command over the transport with and without the ``use_login_shell`` option enabled. 
By default, AiiDA uses a login shell when connecting to a computer in order to operate in the same environment as a
@@ -232,8 +229,7 @@ def _computer_use_login_shell_performance(transport, scheduler, authinfo, comput
 
 
 def get_parameter_default(parameter, ctx):
-    """
-    Get the value for a specific parameter from the computer_builder or the default value of that option
+    """Get the value for a specific parameter from the computer_builder or the default value of that option
 
     :param parameter: parameter name
     :param ctx: click context of the command
@@ -255,10 +251,10 @@ def get_parameter_default(parameter, ctx):
     return value
 
 
-# pylint: disable=unused-argument
 def set_computer_builder(ctx, param, value):
     """Set the computer spec for defaults of following options."""
     from aiida.orm.utils.builders.computer import ComputerBuilder
+
     ctx.computer_builder = ComputerBuilder.from_computer(value)
     return value
 
@@ -398,7 +394,8 @@ def computer_enable(computer, user):
 def computer_disable(computer, user):
     """Disable the computer for the given user.
 
-    Thi can be useful, for example, when a computer is under maintenance."""
+    This can be useful, for example, when a computer is under maintenance.
+    """
     from aiida.common.exceptions import NotExistent
 
     try:
@@ -433,9 +430,9 @@ def computer_list(all_entries, raw):
     if not computers:
         echo.echo_report("No computers configured yet. Use 'verdi computer setup'")
 
-    sort = lambda computer: computer.label
-    highlight = lambda comp: comp.is_configured and comp.is_user_enabled(user)
-    hide = lambda comp: not (comp.is_configured and comp.is_user_enabled(user)) and not all_entries
+    sort = lambda computer: computer.label  # noqa: E731
+    highlight = lambda comp: comp.is_configured and comp.is_user_enabled(user)  # noqa: E731
+    hide = lambda comp: not (comp.is_configured and comp.is_user_enabled(user)) and not all_entries  # noqa: E731
     echo.echo_formatted_list(computers, ['label'], sort=sort, highlight=highlight, hide=hide)
 
 
@@ -499,8 +496,7 @@ def computer_relabel(computer, label):
 @arguments.COMPUTER()
 @with_dbenv()
 def computer_test(user, print_traceback, computer):
-    """
-    Test the connection to a computer.
+    """Test the connection to a computer.
 
     It tries to connect, to get the list of calculations on the queue and
     to perform other tests.
@@ -551,20 +547,19 @@ def computer_test(user, print_traceback, computer):
             scheduler.set_transport(transport)
 
             for test, test_label in tests.items():
-
                 echo.echo(f'* {test_label}... ', nl=False)
                 num_tests += 1
                 try:
                     success, message = test(
                         transport=transport, scheduler=scheduler, authinfo=authinfo, computer=computer
                     )
-                except Exception as exception:  # pylint:disable=broad-except
+                except Exception as exception:
                     success = False
-                    message = f'{exception.__class__.__name__}: {str(exception)}'
+                    message = f'{exception.__class__.__name__}: {exception!s}'
 
                     if print_traceback:
                         message += '\n  Full traceback:\n'
-                        message += '\n'.join([f'  {l}' for l in traceback.format_exc().splitlines()])
+                        message += '\n'.join([f'  {line}' for line in traceback.format_exc().splitlines()])
                     else:
                         message += '\n  Use the `--print-traceback` option to see the full traceback.'
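The `{exception!s}` rewrites in the hunk above replace explicit `str()` calls with the f-string `!s` conversion, and the loop variable `l` is renamed to `line`, most likely to satisfy the E741 ambiguous-name lint rule. A minimal standalone sketch of the `!s` conversion (illustrative only, not part of the patch):

    # Sketch: `!s` applies str() to the interpolated value, so the two
    # formatted messages below are identical.
    exc = ValueError('bad value')
    assert f'{exc.__class__.__name__}: {str(exc)}' == f'{exc.__class__.__name__}: {exc!s}'
    print(f'{exc.__class__.__name__}: {exc!s}')  # prints: ValueError: bad value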
@@ -575,25 +570,24 @@ def computer_test(user, print_traceback, computer): echo.echo(message) else: echo.echo('[Failed]', fg=echo.COLORS['error']) + elif message: + echo.echo('[OK]: ', fg=echo.COLORS['success'], nl=False) + echo.echo(message) else: - if message: - echo.echo('[OK]: ', fg=echo.COLORS['success'], nl=False) - echo.echo(message) - else: - echo.echo('[OK]', fg=echo.COLORS['success']) + echo.echo('[OK]', fg=echo.COLORS['success']) if num_failures: echo.echo_warning(f'{num_failures} out of {num_tests} tests failed') else: echo.echo_success(f'all {num_tests} tests succeeded') - except Exception: # pylint:disable=broad-except + except Exception: echo.echo('[FAILED]: ', fg=echo.COLORS['error'], nl=False) message = 'Error while trying to connect to the computer' if print_traceback: message += '\n Full traceback:\n' - message += '\n'.join([f' {l}' for l in traceback.format_exc().splitlines()]) + message += '\n'.join([f' {line}' for line in traceback.format_exc().splitlines()]) else: message += '\n Use the `--print-traceback` option to see the full traceback.' @@ -605,8 +599,7 @@ def computer_test(user, print_traceback, computer): @arguments.COMPUTER() @with_dbenv() def computer_delete(computer): - """ - Delete a computer. + """Delete a computer. Note that it is not possible to delete the computer if there are calculations that are using it. """ @@ -631,8 +624,9 @@ def list_commands(self, ctx): subcommands.extend(get_entry_point_names('aiida.transports')) return subcommands - def get_command(self, ctx, name): # pylint: disable=arguments-renamed + def get_command(self, ctx, name): from aiida.transports import cli as transport_cli + try: command = transport_cli.create_configure_cmd(name) except EntryPointError: @@ -661,7 +655,8 @@ def computer_config_show(computer, user, defaults, as_option_string): transport_cls = computer.get_transport_class() option_list = [ - param for param in transport_cli.create_configure_cmd(computer.transport_type).params + param + for param in transport_cli.create_configure_cmd(computer.transport_type).params if isinstance(param, click.core.Option) ] option_list = [option for option in option_list if option.name in transport_cls.get_valid_auth_params()] @@ -677,12 +672,13 @@ def computer_config_show(computer, user, defaults, as_option_string): t_opt = transport_cls.auth_options[option.name] if config.get(option.name) or config.get(option.name) is False: if t_opt.get('switch'): - option_value = option.opts[-1] if config.get( - option.name - ) else f"--no-{option.name.replace('_', '-')}" + option_value = ( + option.opts[-1] if config.get(option.name) else f"--no-{option.name.replace('_', '-')}" + ) elif t_opt.get('is_flag'): - is_default = config.get(option.name - ) == transport_cli.transport_option_default(option.name, computer) + is_default = config.get(option.name) == transport_cli.transport_option_default( + option.name, computer + ) option_value = option.opts[-1] if is_default else '' else: option_value = f'{option.opts[-1]}={option.type(config[option.name])}' diff --git a/aiida/cmdline/commands/cmd_config.py b/aiida/cmdline/commands/cmd_config.py index 95e0533d7c..0255d5d635 100644 --- a/aiida/cmdline/commands/cmd_config.py +++ b/aiida/cmdline/commands/cmd_config.py @@ -11,8 +11,8 @@ from __future__ import annotations import json -from pathlib import Path import textwrap +from pathlib import Path import click @@ -50,20 +50,24 @@ def verdi_config_list(ctx, prefix, description: bool): option_values = config.get_options(profile.name if profile else None) def 
_join(val): - """split arrays into multiple lines.""" + """Split arrays into multiple lines.""" if isinstance(val, list): return '\n'.join(str(v) for v in val) return val if description: - table = [[name, source, _join(value), '\n'.join(textwrap.wrap(c.description))] - for name, (c, source, value) in option_values.items() - if name.startswith(prefix)] + table = [ + [name, source, _join(value), '\n'.join(textwrap.wrap(c.description))] + for name, (c, source, value) in option_values.items() + if name.startswith(prefix) + ] headers = ['name', 'source', 'value', 'description'] else: - table = [[name, source, _join(value)] - for name, (c, source, value) in option_values.items() - if name.startswith(prefix)] + table = [ + [name, source, _join(value)] + for name, (c, source, value) in option_values.items() + if name.startswith(prefix) + ] headers = ['name', 'source', 'value'] # sort by name @@ -86,7 +90,7 @@ def verdi_config_show(ctx, option): 'values': { 'default': '' if option.default is None else option.default, 'global': config.options.get(option.name, ''), - } + }, } if not profile: diff --git a/aiida/cmdline/commands/cmd_daemon.py b/aiida/cmdline/commands/cmd_daemon.py index bc568d56ba..47801adbe8 100644 --- a/aiida/cmdline/commands/cmd_daemon.py +++ b/aiida/cmdline/commands/cmd_daemon.py @@ -21,7 +21,7 @@ from aiida.cmdline.utils import decorators, echo -def validate_daemon_workers(ctx, param, value): # pylint: disable=unused-argument,invalid-name +def validate_daemon_workers(ctx, param, value): """Validate the value for the number of daemon workers to start with default set by config.""" if value is None: value = ctx.obj.config.get_option('daemon.default_workers', ctx.obj.profile.name) @@ -251,7 +251,8 @@ def start_circus(foreground, number): .. note:: this should not be called directly from the commandline! """ from aiida.engine.daemon.client import get_daemon_client - get_daemon_client()._start_daemon(number_workers=number, foreground=foreground) # pylint: disable=protected-access + + get_daemon_client()._start_daemon(number_workers=number, foreground=foreground) @verdi_daemon.command('worker') @@ -259,4 +260,5 @@ def start_circus(foreground, number): def worker(): """Run a single daemon worker in the current interpreter.""" from aiida.engine.daemon.worker import start_daemon_worker + start_daemon_worker() diff --git a/aiida/cmdline/commands/cmd_data/cmd_bands.py b/aiida/cmdline/commands/cmd_data/cmd_bands.py index 8b916e843b..d04b7b1f64 100644 --- a/aiida/cmdline/commands/cmd_data/cmd_bands.py +++ b/aiida/cmdline/commands/cmd_data/cmd_bands.py @@ -20,8 +20,16 @@ LIST_PROJECT_HEADERS = ['ID', 'Formula', 'Ctime', 'Label'] EXPORT_FORMATS = [ - 'agr', 'agr_batch', 'dat_blocks', 'dat_multicolumn', 'gnuplot', 'json', 'mpl_pdf', 'mpl_png', 'mpl_singlefile', - 'mpl_withjson' + 'agr', + 'agr_batch', + 'dat_blocks', + 'dat_multicolumn', + 'gnuplot', + 'json', + 'mpl_pdf', + 'mpl_png', + 'mpl_singlefile', + 'mpl_withjson', ] VISUALIZATION_FORMATS = ['xmgrace'] @@ -31,7 +39,6 @@ def bands(): """Manipulate BandsData objects (band structures).""" -# pylint: disable=too-many-arguments @bands.command('list') @decorators.with_dbenv() @list_options @@ -100,15 +107,13 @@ def bands_show(data, fmt): '--y-min-lim', type=click.FLOAT, default=None, - help='The minimum value for the y axis.' - ' Default: minimum of all bands' + help='The minimum value for the y axis.' ' Default: minimum of all bands', ) @click.option( '--y-max-lim', type=click.FLOAT, default=None, - help='The maximum value for the y axis.' 
- ' Default: maximum of all bands' + help='The maximum value for the y axis.' ' Default: maximum of all bands', ) @click.option( '-o', @@ -117,14 +122,14 @@ def bands_show(data, fmt): default=None, help='If present, store the output directly on a file ' 'with the given name. It is essential to use this option ' - 'if more than one file needs to be created.' + 'if more than one file needs to be created.', ) @options.FORCE(help='If passed, overwrite files without checking.') @click.option( '--prettify-format', default=None, type=click.Choice(Prettifier.get_prettifiers()), - help='The style of labels for the prettifier' + help='The style of labels for the prettifier', ) @decorators.with_dbenv() def bands_export(fmt, y_min_lim, y_max_lim, output, force, prettify_format, datum): diff --git a/aiida/cmdline/commands/cmd_data/cmd_cif.py b/aiida/cmdline/commands/cmd_data/cmd_cif.py index 7f5a588ba3..6ea57988ef 100644 --- a/aiida/cmdline/commands/cmd_data/cmd_cif.py +++ b/aiida/cmdline/commands/cmd_data/cmd_cif.py @@ -93,7 +93,7 @@ def cif_content(data): try: echo.echo(node.get_content()) except IOError as exception: - echo.echo_warning(f'could not read the content for CifData<{node.pk}>: {str(exception)}') + echo.echo_warning(f'could not read the content for CifData<{node.pk}>: {exception!s}') @cif.command('export') @@ -122,6 +122,6 @@ def cif_import(filename): try: node, _ = CifData.get_or_create(filename) - echo.echo_success(f'imported {str(node)}') + echo.echo_success(f'imported {node!s}') except ValueError as err: echo.echo_critical(str(err)) diff --git a/aiida/cmdline/commands/cmd_data/cmd_export.py b/aiida/cmdline/commands/cmd_data/cmd_export.py index c4f5cb56f0..6ce72fb652 100644 --- a/aiida/cmdline/commands/cmd_data/cmd_export.py +++ b/aiida/cmdline/commands/cmd_data/cmd_export.py @@ -7,8 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -""" -This module provides export functionality to all data types +"""This module provides export functionality to all data types """ import click @@ -22,7 +21,7 @@ 'reduce_symmetry', is_flag=True, default=None, - help='Do (default) or do not perform symmetry reduction.' + help='Do (default) or do not perform symmetry reduction.', ), click.option( '--parameter-data', @@ -34,29 +33,28 @@ ' Dict in the output, aforementioned' ' Dict is picked automatically. Instead, the' ' option is used in the case the calculation produces' - ' more than a single instance of Dict.' + ' more than a single instance of Dict.', ), click.option( '--dump-aiida-database/--no-dump-aiida-database', 'dump_aiida_database', is_flag=True, default=None, - help='Export (default) or do not export AiiDA database to the CIF file.' + help='Export (default) or do not export AiiDA database to the CIF file.', ), click.option( '--exclude-external-contents/--no-exclude-external-contents', 'exclude_external_contents', is_flag=True, default=None, - help='Do not (default) or do save the contents for external resources even if URIs are provided' + help='Do not (default) or do save the contents for external resources even if URIs are provided', ), click.option('--gzip/--no-gzip', is_flag=True, default=None, help='Do or do not (default) gzip large files.'), click.option( '--gzip-threshold', type=click.INT, default=None, - help='Specify the minimum size of exported file which should' - ' be gzipped.' 
+ help='Specify the minimum size of exported file which should' ' be gzipped.', ), click.option( '-o', @@ -65,7 +63,7 @@ default=None, help='If present, store the output directly on a file ' 'with the given name. It is essential to use this option ' - 'if more than one file needs to be created.' + 'if more than one file needs to be created.', ), options.FORCE(help='Overwrite files without checking.'), ] @@ -79,8 +77,7 @@ def export_options(func): def data_export(node, output_fname, fileformat, other_args=None, overwrite=False): - """ - Depending on the parameters, either print the (single) output file on + """Depending on the parameters, either print the (single) output file on screen, or store the file(s) on disk. :param node: the Data node to print or store on disk @@ -97,7 +94,6 @@ def data_export(node, output_fname, fileformat, other_args=None, overwrite=False if other_args is None: other_args = {} try: - # pylint: disable=protected-access if output_fname: try: node.export(output_fname, fileformat=fileformat, overwrite=overwrite, **other_args) diff --git a/aiida/cmdline/commands/cmd_data/cmd_list.py b/aiida/cmdline/commands/cmd_data/cmd_list.py index c518fce366..5cd094a538 100644 --- a/aiida/cmdline/commands/cmd_data/cmd_list.py +++ b/aiida/cmdline/commands/cmd_data/cmd_list.py @@ -7,8 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -""" -This module provides list functionality to all data types. +"""This module provides list functionality to all data types. """ from aiida.cmdline.params import options @@ -30,9 +29,7 @@ def list_options(func): def query(datatype, project, past_days, group_pks, all_users): - """ - Perform the query - """ + """Perform the query""" import datetime from aiida import orm @@ -79,11 +76,8 @@ def query(datatype, project, past_days, group_pks, all_users): return results -# pylint: disable=unused-argument,too-many-arguments def data_list(datatype, columns, elements, elements_only, formula_mode, past_days, groups, all_users): - """ - List stored objects - """ + """List stored objects""" columns_dict = { 'ID': 'id', 'Id': 'id', diff --git a/aiida/cmdline/commands/cmd_data/cmd_remote.py b/aiida/cmdline/commands/cmd_data/cmd_remote.py index 84e38ca37f..6dc8d3755b 100644 --- a/aiida/cmdline/commands/cmd_data/cmd_remote.py +++ b/aiida/cmdline/commands/cmd_data/cmd_remote.py @@ -25,7 +25,8 @@ def remote(): Computers set up in AiiDA (e.g. where a CalcJob will run). This folder is called "remote" in the sense that it is on a Computer and not in the AiiDA repository. Note, however, that the "remote" computer - could also be "localhost".""" + could also be "localhost". 
+    """
 
 
 @remote.command('ls')
@@ -35,18 +36,18 @@ def remote():
 def remote_ls(ls_long, path, datum):
     """List content of a (sub)directory in a RemoteData object."""
     import datetime
+
     try:
         content = datum.listdir_withattributes(path=path)
     except (IOError, OSError) as err:
-        echo.echo_critical(
-            f'Unable to access the remote folder or file, check if it exists.\nOriginal error: {str(err)}'
-        )
+        echo.echo_critical(f'Unable to access the remote folder or file, check if it exists.\nOriginal error: {err!s}')
 
     for metadata in content:
         if ls_long:
             mtime = datetime.datetime.fromtimestamp(metadata['attributes'].st_mtime)
             pre_line = '{} {:10} {} '.format(
-                stat.filemode(metadata['attributes'].st_mode), metadata['attributes'].st_size,
-                mtime.strftime('%d %b %Y %H:%M')
+                stat.filemode(metadata['attributes'].st_mode),
+                metadata['attributes'].st_size,
+                mtime.strftime('%d %b %Y %H:%M'),
             )
             echo.echo(pre_line, nl=False)
         if metadata['isdir']:
@@ -63,6 +64,7 @@ def remote_cat(datum, path):
     import os
     import sys
     import tempfile
+
     try:
         with tempfile.NamedTemporaryFile(delete=False) as tmpf:
             tmpf.close()
@@ -70,7 +72,7 @@ def remote_cat(datum, path):
             with open(tmpf.name, encoding='utf8') as fhandle:
                 sys.stdout.write(fhandle.read())
     except IOError as err:
-        echo.echo_critical(f'{err.errno}: {str(err)}')
+        echo.echo_critical(f'{err.errno}: {err!s}')
 
     try:
         os.remove(tmpf.name)
diff --git a/aiida/cmdline/commands/cmd_data/cmd_show.py b/aiida/cmdline/commands/cmd_data/cmd_show.py
index 26f87f2382..464923a72f 100644
--- a/aiida/cmdline/commands/cmd_data/cmd_show.py
+++ b/aiida/cmdline/commands/cmd_data/cmd_show.py
@@ -7,8 +7,7 @@
 # For further information on the license, see the LICENSE.txt file #
 # For further information please visit http://www.aiida.net #
 ###########################################################################
-"""
-This allows to manage showfunctionality to all data types.
+"""This allows to manage show functionality to all data types.
 """
 import pathlib
 
@@ -17,24 +16,20 @@
 
 
 def has_executable(exec_name):
-    """
-    :return: True if executable can be found in PATH, False otherwise.
-    """
+    """:return: True if executable can be found in PATH, False otherwise."""
     import shutil
+
     return shutil.which(exec_name) is not None
 
 
 def _show_jmol(exec_name, trajectory_list, **_kwargs):
-    """
-    Plugin for jmol
-    """
+    """Plugin for jmol"""
     import subprocess
     import tempfile
 
     if not has_executable(exec_name):
         echo.echo_critical(f"No executable '{exec_name}' found. 
Add to the path, or try with an absolute path.") - # pylint: disable=protected-access with tempfile.NamedTemporaryFile(mode='w+b') as handle: for trajectory in trajectory_list: handle.write(trajectory._exportcontent('cif')[0]) @@ -48,9 +43,7 @@ def _show_jmol(exec_name, trajectory_list, **_kwargs): def _show_xcrysden(exec_name, trajectory_list, **_kwargs): - """ - Plugin for xcrysden - """ + """Plugin for xcrysden""" import subprocess import tempfile @@ -61,9 +54,7 @@ def _show_xcrysden(exec_name, trajectory_list, **_kwargs): if not has_executable(exec_name): echo.echo_critical(f"No executable '{exec_name}' found.") - # pylint: disable=protected-access with tempfile.NamedTemporaryFile(mode='w+b', suffix='.xsf') as tmpf: - tmpf.write(obj._exportcontent('xsf')[0]) tmpf.flush() @@ -74,40 +65,31 @@ def _show_xcrysden(exec_name, trajectory_list, **_kwargs): echo.echo_error(f'the call to {exec_name} ended with an error.') -# pylint: disable=unused-argument def _show_mpl_pos(exec_name, trajectory_list, **kwargs): - """ - Produces a matplotlib plot of the trajectory - """ + """Produces a matplotlib plot of the trajectory""" for traj in trajectory_list: traj.show_mpl_pos(**kwargs) -# pylint: disable=unused-argument def _show_mpl_heatmap(exec_name, trajectory_list, **kwargs): - """ - Produces a matplotlib plot of the trajectory - """ + """Produces a matplotlib plot of the trajectory""" for traj in trajectory_list: traj.show_mpl_heatmap(**kwargs) -# pylint: disable=unused-argument def _show_ase(exec_name, structure_list): - """ - Plugin to show the structure with the ASE visualizer - """ + """Plugin to show the structure with the ASE visualizer""" try: from ase.visualize import view + for structure in structure_list: view(structure.get_ase()) - except ImportError: # pylint: disable=try-except-raise + except ImportError: raise def _show_vesta(exec_name, structure_list): - """ - Plugin for VESTA + """Plugin for VESTA This VESTA plugin was added by Yue-Wen FANG and Abel Carreras at Kyoto University in the group of Prof. Isao Tanaka's lab @@ -115,7 +97,6 @@ def _show_vesta(exec_name, structure_list): import subprocess import tempfile - # pylint: disable=protected-access with tempfile.NamedTemporaryFile(mode='w+b', suffix='.cif') as tmpf: for structure in structure_list: tmpf.write(structure._exportcontent('cif')[0]) @@ -134,9 +115,7 @@ def _show_vesta(exec_name, structure_list): def _show_vmd(exec_name, structure_list): - """ - Plugin for vmd - """ + """Plugin for vmd""" import subprocess import tempfile @@ -144,7 +123,6 @@ def _show_vmd(exec_name, structure_list): raise MultipleObjectsError('Visualization of multiple objects is not implemented') structure = structure_list[0] - # pylint: disable=protected-access with tempfile.NamedTemporaryFile(mode='w+b', suffix='.xsf') as tmpf: tmpf.write(structure._exportcontent('xsf')[0]) tmpf.flush() @@ -162,9 +140,7 @@ def _show_vmd(exec_name, structure_list): def _show_xmgrace(exec_name, list_bands): - """ - Plugin for showing the bands with the XMGrace plotting software. 
-    """
+    """Plugin for showing the bands with the XMGrace plotting software."""
     import subprocess
     import sys
     import tempfile
@@ -175,13 +151,12 @@ def _show_xmgrace(exec_name, list_bands):
     current_band_number = 0
 
     with tempfile.TemporaryDirectory() as tmpdir:
-
         dirpath = pathlib.Path(tmpdir)
 
         for iband, bnds in enumerate(list_bands):
             # extract number of bands
             nbnds = bnds.get_bands().shape[1]
-            text, _ = bnds._exportcontent(  # pylint: disable=protected-access
+            text, _ = bnds._exportcontent(
                 'agr', setnumber_offset=current_band_number, color_number=(iband + 1 % MAX_NUM_AGR_COLORS)
             )
             # write a tempfile
diff --git a/aiida/cmdline/commands/cmd_data/cmd_singlefile.py b/aiida/cmdline/commands/cmd_data/cmd_singlefile.py
index 3863510b74..a4dd615ba7 100644
--- a/aiida/cmdline/commands/cmd_data/cmd_singlefile.py
+++ b/aiida/cmdline/commands/cmd_data/cmd_singlefile.py
@@ -27,4 +27,4 @@ def singlefile_content(datum):
     try:
         echo.echo(datum.get_content())
     except (IOError, OSError) as exception:
-        echo.echo_critical(f'could not read the content for SinglefileData<{datum.pk}>: {str(exception)}')
+        echo.echo_critical(f'could not read the content for SinglefileData<{datum.pk}>: {exception!s}')
diff --git a/aiida/cmdline/commands/cmd_data/cmd_structure.py b/aiida/cmdline/commands/cmd_data/cmd_structure.py
index 628c338244..6b7f6d13f2 100644
--- a/aiida/cmdline/commands/cmd_data/cmd_structure.py
+++ b/aiida/cmdline/commands/cmd_data/cmd_structure.py
@@ -24,8 +24,7 @@
 
 
 def _store_structure(new_structure, dry_run):
-    """
-    Store a structure and print a message (or don't store it if it's a dry_run)
+    """Store a structure and print a message (or don't store it if it's a dry_run)
 
     This is a utility function to avoid code duplication.
 
@@ -46,7 +45,6 @@ def structure():
     """Manipulate StructureData objects (crystal structures)."""
 
 
-# pylint: disable=too-many-locals,too-many-branches
 @structure.command('list')
 @options.FORMULA_MODE()
 @options.WITH_ELEMENTS()
@@ -60,8 +58,14 @@ def structure_list(elements, raw, formula_mode, past_days, groups, all_users):
         elements_only = False
 
     lst = data_list(
-        StructureData, ['Id', 'Label', 'Kinds', 'Sites'], elements, elements_only, formula_mode, past_days, groups,
-        all_users
+        StructureData,
+        ['Id', 'Label', 'Kinds', 'Sites'],
+        elements,
+        elements_only,
+        formula_mode,
+        past_days,
+        groups,
+        all_users,
     )
 
     entry_list = []
@@ -164,14 +168,14 @@ def structure_import():
     type=click.FLOAT,
     show_default=True,
     default=1.0,
-    help='The factor by which the cell accomodating the structure should be increased (angstrom).'
+    help='The factor by which the cell accommodating the structure should be increased (angstrom).',
 )
 @click.option(
     '--vacuum-addition',
     type=click.FLOAT,
     show_default=True,
     default=10.0,
-    help='The distance to add to the unit cell after vacuum factor was applied to expand in each dimension (angstrom).'
+    help='The distance to add to the unit cell after vacuum factor was applied to expand in each dimension (angstrom).',
 )
 @click.option(
     '--pbc',
     nargs=3,
     show_default=True,
     default=[0, 0, 0],
-    help='Set periodic boundary conditions for each lattice direction, where 0 means periodic and 1 means periodic.' 
+    help='Set periodic boundary conditions for each lattice direction, where 0 means non-periodic and 1 means periodic.',
 )
 @click.option('--label', type=click.STRING, show_default=False, help='Set the structure node label (empty by default)')
 @options.GROUP()
 @options.DRY_RUN()
 @decorators.with_dbenv()
 def import_aiida_xyz(filename, vacuum_factor, vacuum_addition, pbc, label, group, dry_run):
-    """
-    Import structure in XYZ format using AiiDA's internal importer
-    """
+    """Import structure in XYZ format using AiiDA's internal importer"""
     from aiida.orm import StructureData
 
     with open(filename, encoding='utf8') as fobj:
@@ -205,11 +207,8 @@ def import_aiida_xyz(filename, vacuum_factor, vacuum_addition, pbc, label, group
         raise click.BadParameter('values for pbc must be either 0 or 1', param_hint='pbc')
 
     try:
-        new_structure._parse_xyz(xyz_txt)  # pylint: disable=protected-access
-        new_structure._adjust_default_cell(  # pylint: disable=protected-access
-            vacuum_addition=vacuum_addition,
-            vacuum_factor=vacuum_factor,
-            pbc=pbc_bools)
+        new_structure._parse_xyz(xyz_txt)
+        new_structure._adjust_default_cell(vacuum_addition=vacuum_addition, vacuum_factor=vacuum_factor, pbc=pbc_bools)
     except (ValueError, TypeError) as err:
         echo.echo_critical(str(err))
 
@@ -230,9 +229,7 @@ def import_aiida_xyz(filename, vacuum_factor, vacuum_addition, pbc, label, group
 @options.DRY_RUN()
 @decorators.with_dbenv()
 def import_ase(filename, label, group, dry_run):
-    """
-    Import structure with the ase library that supports a number of different formats
-    """
+    """Import structure with the ase library that supports a number of different formats"""
     from aiida.orm import StructureData
 
     try:
diff --git a/aiida/cmdline/commands/cmd_data/cmd_trajectory.py b/aiida/cmdline/commands/cmd_data/cmd_trajectory.py
index 5a4cf3eb61..266cd32cd8 100644
--- a/aiida/cmdline/commands/cmd_data/cmd_trajectory.py
+++ b/aiida/cmdline/commands/cmd_data/cmd_trajectory.py
@@ -74,13 +74,13 @@ def trajectory_list(raw, past_days, groups, all_users):
     '--sampling-stepsize',
     type=click.INT,
     default=None,
-    help='Sample positions in plot every sampling_stepsize timestep'
+    help='Sample positions in plot every sampling_stepsize timestep',
 )
 @click.option(
     '--stepsize',
     type=click.INT,
     default=None,
-    help='The stepsize for the trajectory, set it higher to reduce number of points'
+    help='The stepsize for the trajectory, set it higher to reduce number of points',
 )
 @click.option('--mintime', type=click.INT, default=None, help='The time to plot from')
 @click.option('--maxtime', type=click.INT, default=None, help='The time to plot to')
diff --git a/aiida/cmdline/commands/cmd_data/cmd_upf.py b/aiida/cmdline/commands/cmd_data/cmd_upf.py
index 1dd0b37c13..7799ceb327 100644
--- a/aiida/cmdline/commands/cmd_data/cmd_upf.py
+++ b/aiida/cmdline/commands/cmd_data/cmd_upf.py
@@ -35,18 +35,18 @@ def upf():
     '--stop-if-existing',
     is_flag=True,
     default=False,
-    help='Interrupt pseudos import if a pseudo was already present in the AiiDA database'
+    help='Interrupt pseudos import if a pseudo was already present in the AiiDA database',
 )
 @decorators.with_dbenv()
 def upf_uploadfamily(folder, group_label, group_description, stop_if_existing):
-    """
-    Create a new UPF family from a folder of UPF files.
+    """Create a new UPF family from a folder of UPF files.
 
     Returns the numbers of files found and the number of nodes uploaded.
 
     Call without parameters to get some help. 
""" from aiida.orm.nodes.data.upf import upload_upf_family + files_found, files_uploaded = upload_upf_family(folder, group_label, group_description, stop_if_existing) echo.echo_success(f'UPF files found: {files_found}. New files uploaded: {files_uploaded}') @@ -61,18 +61,16 @@ def upf_uploadfamily(folder, group_label, group_description, stop_if_existing): 'with_description', is_flag=True, default=False, - help='Show also the description for the UPF family' + help='Show also the description for the UPF family', ) @options.WITH_ELEMENTS() @decorators.with_dbenv() def upf_listfamilies(elements, with_description): - """ - List all UPF families that exist in the database. - """ + """List all UPF families that exist in the database.""" from aiida import orm from aiida.plugins import DataFactory - UpfData = DataFactory('core.upf') # pylint: disable=invalid-name + UpfData = DataFactory('core.upf') # noqa: N806 query = orm.QueryBuilder() query.append(UpfData, tag='upfdata') if elements is not None: @@ -107,8 +105,7 @@ def upf_listfamilies(elements, with_description): @arguments.GROUP() @decorators.with_dbenv() def upf_exportfamily(folder, group): - """ - Export a pseudopotential family into a folder. + """Export a pseudopotential family into a folder. Call without parameters to get some help. """ if group.is_empty: @@ -130,9 +127,7 @@ def upf_exportfamily(folder, group): @click.argument('filename', type=click.Path(exists=True, dir_okay=False, resolve_path=True)) @decorators.with_dbenv() def upf_import(filename): - """ - Import a UPF pseudopotential from a file. - """ + """Import a UPF pseudopotential from a file.""" from aiida.orm import UpfData node, _ = UpfData.get_or_create(filename) diff --git a/aiida/cmdline/commands/cmd_database.py b/aiida/cmdline/commands/cmd_database.py index 2653759f81..b738f20509 100644 --- a/aiida/cmdline/commands/cmd_database.py +++ b/aiida/cmdline/commands/cmd_database.py @@ -8,7 +8,6 @@ # For further information please visit http://www.aiida.net # ########################################################################### """`verdi database` commands.""" -# pylint: disable=unused-argument import click @@ -52,6 +51,7 @@ def database_migrate(ctx, force): .. deprecated:: v2.0.0 """ from aiida.cmdline.commands.cmd_storage import storage_migrate + ctx.forward(storage_migrate) @@ -69,7 +69,7 @@ def verdi_database_integrity(): '--table', default='db_dbnode', type=click.Choice(('db_dbcomment', 'db_dbcomputer', 'db_dbgroup', 'db_dbnode')), - help='The database table to operate on.' + help='The database table to operate on.', ) @click.option( '-a', '--apply-patch', is_flag=True, help='Actually apply the proposed changes instead of performing a dry run.' 
diff --git a/aiida/cmdline/commands/cmd_devel.py b/aiida/cmdline/commands/cmd_devel.py index d123030f15..e6e98268df 100644 --- a/aiida/cmdline/commands/cmd_devel.py +++ b/aiida/cmdline/commands/cmd_devel.py @@ -110,6 +110,7 @@ def devel_run_sql(sql): from sqlalchemy import text from aiida.storage.psql_dos.utils import create_sqlalchemy_engine + assert get_profile().storage_backend == 'core.psql_dos' with create_sqlalchemy_engine(get_profile().storage_config).connect() as connection: result = connection.execute(text(sql)).fetchall() @@ -125,6 +126,7 @@ def devel_run_sql(sql): def devel_play(): """Play the Aida triumphal march by Giuseppe Verdi.""" import webbrowser + webbrowser.open_new('http://upload.wikimedia.org/wikipedia/commons/3/32/Triumphal_March_from_Aida.ogg') @@ -155,7 +157,7 @@ def devel_launch_arithmetic_add(code, daemon, sleep): label='bash', computer=localhost, filepath_executable=which('bash'), - default_calc_job_plugin=default_calc_job_plugin + default_calc_job_plugin=default_calc_job_plugin, ).store() else: assert code.default_calc_job_plugin == default_calc_job_plugin @@ -203,8 +205,8 @@ def prepare_localhost(): scheduler_type='core.direct', workdir=tempfile.gettempdir(), ).store() - computer.configure(safe_interval=0.) - computer.set_minimum_job_poll_interval(0.) + computer.configure(safe_interval=0.0) + computer.set_minimum_job_poll_interval(0.0) if not computer.is_configured: computer.configure() diff --git a/aiida/cmdline/commands/cmd_group.py b/aiida/cmdline/commands/cmd_group.py index 42e25054cd..6b1643a683 100644 --- a/aiida/cmdline/commands/cmd_group.py +++ b/aiida/cmdline/commands/cmd_group.py @@ -52,7 +52,6 @@ def group_remove_nodes(group, nodes, clear, force): ) if not force: - if nodes: node_pks = [node.pk for node in nodes] @@ -138,8 +137,7 @@ def group_move_nodes(source_group, target_group, force, nodes, all_entries): if not force: click.confirm( - f'Are you sure you want to move {len(nodes)} nodes from {source_group} ' - f'to {target_group}?', abort=True + f'Are you sure you want to move {len(nodes)} nodes from {source_group} ' f'to {target_group}?', abort=True ) source_group.remove_nodes(nodes) @@ -222,7 +220,7 @@ def group_description(group, description): '--uuid', is_flag=True, default=False, - help='Show UUIDs together with PKs. Note: if the --raw option is also passed, PKs are not printed, but only UUIDs.' + help='Show UUIDs together with PKs. Note: if the --raw option is also passed, PKs are not printed, but only UUIDs.', ) @arguments.GROUP() @with_dbenv() @@ -277,12 +275,7 @@ def group_show(group, raw, limit, uuid): @options.ALL(help='Show groups of all types.') @options.TYPE_STRING() @click.option( - '-d', - '--with-description', - 'with_description', - is_flag=True, - default=False, - help='Show also the group description.' + '-d', '--with-description', 'with_description', is_flag=True, default=False, help='Show also the group description.' ) @click.option('-C', '--count', is_flag=True, default=False, help='Show also the number of nodes in the group.') @options.PAST_DAYS(help='Add a filter to show only groups created in the past N days.', default=None) @@ -291,32 +284,42 @@ def group_show(group, raw, limit, uuid): '--startswith', type=click.STRING, default=None, - help='Add a filter to show only groups for which the label begins with STRING.' 
+ help='Add a filter to show only groups for which the label begins with STRING.', ) @click.option( '-e', '--endswith', type=click.STRING, default=None, - help='Add a filter to show only groups for which the label ends with STRING.' + help='Add a filter to show only groups for which the label ends with STRING.', ) @click.option( '-c', '--contains', type=click.STRING, default=None, - help='Add a filter to show only groups for which the label contains STRING.' + help='Add a filter to show only groups for which the label contains STRING.', ) @options.ORDER_BY(type=click.Choice(['id', 'label', 'ctime']), default='label') @options.ORDER_DIRECTION() @options.NODE(help='Show only the groups that contain the node.') @with_dbenv() def group_list( - all_users, user, all_entries, type_string, with_description, count, past_days, startswith, endswith, contains, - order_by, order_dir, node + all_users, + user, + all_entries, + type_string, + with_description, + count, + past_days, + startswith, + endswith, + contains, + order_by, + order_dir, + node, ): """Show a list of existing groups.""" - # pylint: disable=too-many-branches,too-many-arguments,too-many-locals,too-many-statements import datetime from tabulate import tabulate @@ -378,7 +381,7 @@ def group_list( 'type_string': lambda group: group.type_string, 'count': lambda group: group.count(), 'user': lambda group: group.user.email.strip(), - 'description': lambda group: group.description + 'description': lambda group: group.description, } table = [] @@ -428,7 +431,8 @@ def group_copy(source_group, destination_group): """Duplicate a group. More in detail, add all nodes from the source group to the destination group. - Note that the destination group may not exist.""" + Note that the destination group may not exist. + """ from aiida import orm dest_group, created = orm.Group.collection.get_or_create(label=destination_group) @@ -454,24 +458,18 @@ def verdi_group_path(): @click.option('-R', '--recursive', is_flag=True, default=False, help='Recursively list sub-paths encountered.') @click.option('-l', '--long', 'as_table', is_flag=True, default=False, help='List as a table, with sub-group count.') @click.option( - '-d', - '--with-description', - 'with_description', - is_flag=True, - default=False, - help='Show also the group description.' + '-d', '--with-description', 'with_description', is_flag=True, default=False, help='Show also the group description.' ) @click.option( '--no-virtual', 'no_virtual', is_flag=True, default=False, - help='Only show paths that fully correspond to an existing group.' 
+ help='Only show paths that fully correspond to an existing group.', ) @click.option('--no-warn', is_flag=True, default=False, help='Do not issue a warning if any paths are invalid.') @with_dbenv() def group_path_ls(path, type_string, recursive, as_table, no_virtual, with_description, no_warn): - # pylint: disable=too-many-arguments,too-many-branches """Show a list of existing group paths.""" from aiida.plugins import GroupFactory from aiida.tools.groups.paths import GroupPath, InvalidPath @@ -488,6 +486,7 @@ def group_path_ls(path, type_string, recursive, as_table, no_virtual, with_descr if as_table or with_description: from tabulate import tabulate + headers = ['Path', 'Sub-Groups'] if with_description: headers.append('Description') @@ -497,7 +496,7 @@ def group_path_ls(path, type_string, recursive, as_table, no_virtual, with_descr continue row = [ child.path if child.is_virtual else click.style(child.path, bold=True), - len([c for c in child.walk() if not c.is_virtual]) + len([c for c in child.walk() if not c.is_virtual]), ] if with_description: row.append('-' if child.is_virtual else child.get_group().description) diff --git a/aiida/cmdline/commands/cmd_help.py b/aiida/cmdline/commands/cmd_help.py index 8b5412aaee..a3c36e415d 100644 --- a/aiida/cmdline/commands/cmd_help.py +++ b/aiida/cmdline/commands/cmd_help.py @@ -20,7 +20,6 @@ @click.argument('command', type=click.STRING, required=False) def verdi_help(ctx, command): """Show help for given command.""" - cmdctx = ctx.parent if command: diff --git a/aiida/cmdline/commands/cmd_node.py b/aiida/cmdline/commands/cmd_node.py index 728c27b185..7014ce1eb1 100644 --- a/aiida/cmdline/commands/cmd_node.py +++ b/aiida/cmdline/commands/cmd_node.py @@ -48,7 +48,7 @@ def repo_cat(node, relative_path): if not relative_path: if not isinstance(node, SinglefileData): - raise click.BadArgumentUsage('Missing argument \'RELATIVE_PATH\'.') + raise click.BadArgumentUsage("Missing argument 'RELATIVE_PATH'.") relative_path = node.filename @@ -102,9 +102,8 @@ def repo_dump(node, output_directory): except FileExistsError: echo.echo_critical(f'Invalid value for "OUTPUT_DIRECTORY": Path "{output_directory}" exists.') - def _copy_tree(key, output_dir): # pylint: disable=too-many-branches - """ - Recursively copy the content at the ``key`` path in the given node to + def _copy_tree(key, output_dir): + """Recursively copy the content at the ``key`` path in the given node to the ``output_dir``. 
""" for file in node.base.repository.list_objects(key): @@ -140,7 +139,6 @@ def node_label(nodes, label, raw, force): if label is None: for node in nodes: - if raw: table.append([node.label]) else: @@ -174,7 +172,6 @@ def node_description(nodes, description, force, raw): if description is None: for node in nodes: - if raw: table.append([node.description]) else: @@ -205,17 +202,15 @@ def node_show(nodes, print_groups): from aiida.cmdline.utils.common import get_node_info for node in nodes: - # pylint: disable=fixme # TODO: Add a check here on the node type, otherwise it might try to access # attributes such as code which are not necessarily there echo.echo(get_node_info(node)) if print_groups: - from aiida.orm import Node # pylint: disable=redefined-outer-name + from aiida.orm import Node from aiida.orm.groups import Group from aiida.orm.querybuilder import QueryBuilder - # pylint: disable=invalid-name qb = QueryBuilder() qb.append(Node, tag='node', filters={'id': {'==': node.pk}}) qb.append(Group, tag='groups', with_node='node', project=['id', 'label', 'type_string']) @@ -332,7 +327,7 @@ def _dry_run_callback(pks): '--entry-point', type=PluginParamType(group=('aiida.calculations', 'aiida.data', 'aiida.node', 'aiida.workflows'), load=True), default=None, - help='Only include nodes that are class or sub class of the class identified by this entry point.' + help='Only include nodes that are class or sub class of the class identified by this entry point.', ) @options.FORCE() @with_dbenv() @@ -377,7 +372,7 @@ def rehash(nodes, entry_point, force): echo.echo_critical('no matching nodes found') with click.progressbar(to_hash, label='Rehashing Nodes:') as iter_hash: - for node, in iter_hash: + for (node,) in iter_hash: node.base.caching.rehash() echo.echo_success(f'{num_nodes} nodes re-hashed.') @@ -399,57 +394,63 @@ def verdi_graph(): "'logic' includes only 'input_work' and 'return' links (logical provenance only)." ), default='all', - type=click.Choice(['all', 'data', 'logic']) + type=click.Choice(['all', 'data', 'logic']), ) @click.option( '--identifier', help='the type of identifier to use within the node text', default='uuid', - type=click.Choice(['pk', 'uuid', 'label']) + type=click.Choice(['pk', 'uuid', 'label']), ) @click.option( '-a', '--ancestor-depth', help='The maximum depth when recursing upwards, if not set it will recurse to the end.', - type=click.IntRange(min=0) + type=click.IntRange(min=0), ) @click.option( '-d', '--descendant-depth', help='The maximum depth when recursing through the descendants. If not set it will recurse to the end.', - type=click.IntRange(min=0) + type=click.IntRange(min=0), ) @click.option('-o', '--process-out', is_flag=True, help='Show outgoing links for all processes.') @click.option('-i', '--process-in', is_flag=True, help='Show incoming links for all processes.') @click.option( '-e', '--engine', - help="The graphviz engine, e.g. 'dot', 'circo', ... " - '(see http://www.graphviz.org/doc/info/output.html)', - default='dot' + help="The graphviz engine, e.g. 'dot', 'circo', ... " '(see http://www.graphviz.org/doc/info/output.html)', + default='dot', ) @click.option('-f', '--output-format', help="The output format used for rendering ('pdf', 'png', etc.).", default='pdf') @click.option( '-c', '--highlight-classes', - help= - "Only color nodes of specific class label (as displayed in the graph, e.g. 'StructureData', 'FolderData', etc.).", + help='Only color nodes of specific class label (as displayed in the graph e.g. 
StructureData, FolderData, etc.).', type=click.STRING, default=None, - multiple=True + multiple=True, ) @click.option('-s', '--show', is_flag=True, help='Open the rendered result with the default application.') @arguments.OUTPUT_FILE(required=False) @decorators.with_dbenv() def graph_generate( - root_node, link_types, identifier, ancestor_depth, descendant_depth, process_out, process_in, engine, output_format, - highlight_classes, show, output_file + root_node, + link_types, + identifier, + ancestor_depth, + descendant_depth, + process_out, + process_in, + engine, + output_format, + highlight_classes, + show, + output_file, ): - """ - Generate a graph from a ROOT_NODE (specified by pk or uuid). - """ - # pylint: disable=too-many-arguments + """Generate a graph from a ROOT_NODE (specified by pk or uuid).""" from aiida.tools.visualization import Graph + link_types = {'all': (), 'logic': ('input_work', 'return'), 'data': ('input_calc', 'create')}[link_types] echo.echo_info(f'Initiating graphviz engine: {engine}') diff --git a/aiida/cmdline/commands/cmd_process.py b/aiida/cmdline/commands/cmd_process.py index cf4278760a..feb089b70f 100644 --- a/aiida/cmdline/commands/cmd_process.py +++ b/aiida/cmdline/commands/cmd_process.py @@ -7,7 +7,6 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=too-many-arguments """`verdi process` command.""" import click @@ -32,6 +31,7 @@ def valid_projections(): This indirection is necessary to prevent loading the imported module which slows down tab-completion. """ from aiida.tools.query.calculation import CalculationQueryBuilder + return CalculationQueryBuilder.valid_projections @@ -41,6 +41,7 @@ def default_projections(): This indirection is necessary to prevent loading the imported module which slows down tab-completion. """ from aiida.tools.query.calculation import CalculationQueryBuilder + return CalculationQueryBuilder.default_projections @@ -50,7 +51,7 @@ def verdi_process(): @verdi_process.command('list') -@options.PROJECT(type=types.LazyChoice(valid_projections), default=lambda: default_projections()) # pylint: disable=unnecessary-lambda +@options.PROJECT(type=types.LazyChoice(valid_projections), default=lambda: default_projections()) @options.ORDER_BY() @options.ORDER_DIRECTION() @options.GROUP(help='Only include entries that are a member of this group.') @@ -66,14 +67,25 @@ def verdi_process(): @click.pass_context @decorators.with_dbenv() def process_list( - ctx, all_entries, group, process_state, process_label, paused, exit_status, failed, past_days, limit, project, raw, - order_by, order_dir + ctx, + all_entries, + group, + process_state, + process_label, + paused, + exit_status, + failed, + past_days, + limit, + project, + raw, + order_by, + order_dir, ): """Show a list of running or terminated processes. By default, only those that are still running are shown, but there are options to show also the finished ones. """ - # pylint: disable=too-many-locals from tabulate import tabulate from aiida.cmdline.commands.cmd_daemon import execute_client_command @@ -129,13 +141,11 @@ def process_list( # Second query to get active process count. Currently this is slow but will be fixed with issue #2770. It is # placed at the end of the command so that the user can Ctrl+C after getting the process table. 
slots_per_worker = ctx.obj.config.get_option('daemon.worker_process_slots', scope=ctx.obj.profile.name) - active_processes = QueryBuilder().append( - ProcessNode, filters={ - 'attributes.process_state': { - 'in': ('created', 'waiting', 'running') - } - } - ).count() + active_processes = ( + QueryBuilder() + .append(ProcessNode, filters={'attributes.process_state': {'in': ('created', 'waiting', 'running')}}) + .count() + ) available_slots = active_workers * slots_per_worker percent_load = active_processes / available_slots if percent_load > 0.9: # 90% @@ -162,7 +172,6 @@ def process_show(processes): def process_call_root(processes): """Show root process of the call stack for the given processes.""" for process in processes: - caller = process.caller if caller is None: @@ -188,7 +197,7 @@ def process_call_root(processes): '--levelname', type=click.Choice(list(LOG_LEVELS)), default='REPORT', - help='Filter the results by name of the log level.' + help='Filter the results by name of the log level.', ) @click.option( '-m', '--max-depth', 'max_depth', type=int, default=None, help='Limit the number of levels to be printed.' @@ -309,7 +318,7 @@ def process_watch(processes): from kiwipy import BroadcastFilter - def _print(communicator, body, sender, subject, correlation_id): # pylint: disable=unused-argument + def _print(communicator, body, sender, subject, correlation_id): """Format the incoming broadcast data into a message and echo it to stdout.""" if body is None: body = 'No message specified' @@ -323,7 +332,6 @@ def _print(communicator, body, sender, subject, correlation_id): # pylint: disa echo.echo_report('watching for broadcasted messages, press CTRL+C to stop...') for process in processes: - if process.is_terminated: echo.echo_error(f'Process<{process.pk}> is already terminated') continue diff --git a/aiida/cmdline/commands/cmd_profile.py b/aiida/cmdline/commands/cmd_profile.py index e004d9daad..d43a73a4a7 100644 --- a/aiida/cmdline/commands/cmd_profile.py +++ b/aiida/cmdline/commands/cmd_profile.py @@ -27,13 +27,8 @@ def verdi_profile(): def command_create_profile( - ctx: click.Context, - storage_cls, - non_interactive: bool, - profile: Profile, - set_as_default: bool = True, - **kwargs -): # pylint: disable=unused-argument + ctx: click.Context, storage_cls, non_interactive: bool, profile: Profile, set_as_default: bool = True, **kwargs +): """Create a new profile, initialise its storage and create a default user. :param ctx: The context of the CLI command. @@ -70,7 +65,7 @@ def command_create_profile( setup.SETUP_USER_FIRST_NAME(), setup.SETUP_USER_LAST_NAME(), setup.SETUP_USER_INSTITUTION(), - ] + ], ) def profile_setup(): """Set up a new profile.""" @@ -86,6 +81,7 @@ def profile_list(): # to be able to see the configuration directory, for instance for those who have set `AIIDA_PATH`. This way # they can at least verify that it is correctly set. 
from aiida.manage.configuration.settings import AIIDA_CONFIG_FOLDER + echo.echo_report(f'configuration folder: {AIIDA_CONFIG_FOLDER}') echo.echo_critical(str(exception)) else: @@ -94,8 +90,8 @@ def profile_list(): if not config.profiles: echo.echo_warning('no profiles configured: run `verdi setup` to create one') else: - sort = lambda profile: profile.name # pylint: disable=unnecessary-lambda-assignment - highlight = lambda profile: profile.name == config.default_profile_name # pylint: disable=unnecessary-lambda-assignment + sort = lambda profile: profile.name # noqa: E731 + highlight = lambda profile: profile.name == config.default_profile_name # noqa: E731 echo.echo_formatted_list(config.profiles, ['name'], sort=sort, highlight=highlight) @@ -112,7 +108,6 @@ def _strip_private_keys(dct: dict): @arguments.PROFILE(default=defaults.get_default_profile) def profile_show(profile): """Show details for a profile.""" - if profile is None: echo.echo_critical('no profile to show') @@ -139,7 +134,7 @@ def profile_setdefault(profile): @click.option( '--delete-data/--keep-data', default=None, - help='Whether to delete the storage with all its data or not. This flag has to be explicitly specified' + help='Whether to delete the storage with all its data or not. This flag has to be explicitly specified', ) @arguments.PROFILES(required=True) def profile_delete(force, delete_data, profiles): diff --git a/aiida/cmdline/commands/cmd_rabbitmq.py b/aiida/cmdline/commands/cmd_rabbitmq.py index c10a201c68..57c9058731 100644 --- a/aiida/cmdline/commands/cmd_rabbitmq.py +++ b/aiida/cmdline/commands/cmd_rabbitmq.py @@ -92,6 +92,7 @@ def echo_response(response: 'requests.Response', exit_on_error: bool = True) -> :param exit_on_error: Boolean, if ``True``, call ``sys.exit`` with the status code of the response. """ import requests + try: response.raise_for_status() except requests.HTTPError: @@ -107,8 +108,8 @@ def echo_response(response: 'requests.Response', exit_on_error: bool = True) -> @click.pass_context def with_client(ctx, wrapped, _, args, kwargs): """Decorate a function injecting a :class:`aiida.manage.external.rmq.client.RabbitmqManagementClient`.""" - from aiida.manage.external.rmq.client import RabbitmqManagementClient + config = ctx.obj.profile.process_control_config client = RabbitmqManagementClient( username=config['broker_username'], @@ -146,7 +147,7 @@ def cmd_server_properties(manager): '--project', type=click.Choice(AVAILABLE_PROJECTORS), cls=options.MultipleValueOption, - default=('name', 'messages', 'state') + default=('name', 'messages', 'state'), ) @options.RAW() @click.option('-f', '--filter-name', type=str, help='Provide a regex pattern to filter queues based on their name. ') @@ -240,6 +241,7 @@ def cmd_tasks_analyze(ctx, fix): Use ``-v INFO`` to be more verbose and print more information. """ from .cmd_process import process_repair + ctx.invoke(process_repair, dry_run=not fix) diff --git a/aiida/cmdline/commands/cmd_restapi.py b/aiida/cmdline/commands/cmd_restapi.py index 799d6350ab..fc2b5ac6e9 100644 --- a/aiida/cmdline/commands/cmd_restapi.py +++ b/aiida/cmdline/commands/cmd_restapi.py @@ -7,8 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -""" -This allows to hook-up the AiiDA built-in RESTful API. +"""This allows to hook-up the AiiDA built-in RESTful API. 
Main advantage of doing this by means of a verdi command is that different profiles can be selected at hook-up (-p flag). """ @@ -28,14 +27,14 @@ '--config-dir', type=click.Path(exists=True), default=config.CLI_DEFAULTS['CONFIG_DIR'], - help='Path to the configuration directory' + help='Path to the configuration directory', ) @DEBUG(default=config.APP_CONFIG['DEBUG']) @click.option( '--wsgi-profile', is_flag=True, default=config.CLI_DEFAULTS['WSGI_PROFILE'], - help='Whether to enable WSGI profiler middleware for finding bottlenecks' + help='Whether to enable WSGI profiler middleware for finding bottlenecks', ) @click.option( '--posting/--no-posting', @@ -47,8 +46,7 @@ ) @click.pass_context def restapi(ctx, hostname, port, config_dir, debug, wsgi_profile, posting): - """ - Run the AiiDA REST API server. + """Run the AiiDA REST API server. Example Usage: diff --git a/aiida/cmdline/commands/cmd_run.py b/aiida/cmdline/commands/cmd_run.py index 6d1fc0bd4c..c9adfc208b 100644 --- a/aiida/cmdline/commands/cmd_run.py +++ b/aiida/cmdline/commands/cmd_run.py @@ -33,8 +33,8 @@ def update_environment(argv): yield finally: # Restore old parameters when exiting from the context manager - sys.argv = _argv # pylint - sys.path = _path # pylint + sys.argv = _argv + sys.path = _path def validate_entry_point_strings(_, __, value): @@ -44,12 +44,17 @@ def validate_entry_point_strings(_, __, value): try: autogroup.AutogroupManager.validate(value) except (TypeError, ValueError) as exc: - raise click.BadParameter(f'{str(exc)}: `{value}`') + raise click.BadParameter(f'{exc!s}: `{value}`') return value -@verdi.command('run', context_settings=dict(ignore_unknown_options=True,)) +@verdi.command( + 'run', + context_settings=dict( + ignore_unknown_options=True, + ), +) @click.argument('filepath', type=click.Path(exists=True, readable=True, dir_okay=False, path_type=pathlib.Path)) @click.argument('varargs', nargs=-1, type=click.UNPROCESSED) @click.option('--auto-group', is_flag=True, help='Enables the autogrouping') @@ -59,7 +64,7 @@ def validate_entry_point_strings(_, __, value): type=click.STRING, required=False, help='Specify the prefix of the label of the auto group (numbers might be automatically ' - 'appended to generate unique names per run).' 
+ 'appended to generate unique names per run).', ) @click.option( '-e', @@ -68,7 +73,7 @@ def validate_entry_point_strings(_, __, value): cls=MultipleValueOption, default=None, help='Exclude these classes from auto grouping (use full entrypoint strings).', - callback=validate_entry_point_strings + callback=validate_entry_point_strings, ) @click.option( '-i', @@ -77,7 +82,7 @@ def validate_entry_point_strings(_, __, value): cls=MultipleValueOption, default=None, help='Include these classes from auto grouping (use full entrypoint strings or "all").', - callback=validate_entry_point_strings + callback=validate_entry_point_strings, ) @decorators.with_dbenv() def run(filepath, varargs, auto_group, auto_group_label_prefix, exclude, include): @@ -93,7 +98,7 @@ def run(filepath, varargs, auto_group, auto_group_label_prefix, exclude, include '__name__': '__main__', '__file__': filepath.name, '__doc__': None, - '__package__': None + '__package__': None, } # Dynamically load modules (the same of verdi shell) - but in globals_dict, not in the current environment @@ -112,8 +117,8 @@ def run(filepath, varargs, auto_group, auto_group_label_prefix, exclude, include with filepath.open('r', encoding='utf-8') as handle: with update_environment(argv=[str(filepath)] + list(varargs)): # Compile the script for execution and pass it to exec with the globals_dict - exec(compile(handle.read(), str(filepath), 'exec', dont_inherit=True), globals_dict) # pylint: disable=exec-used - except SystemExit: # pylint: disable=try-except-raise + exec(compile(handle.read(), str(filepath), 'exec', dont_inherit=True), globals_dict) + except SystemExit: # Script called ``sys.exit()``, re-raise the exception to have the error code properly returned at the end raise finally: diff --git a/aiida/cmdline/commands/cmd_setup.py b/aiida/cmdline/commands/cmd_setup.py index b715cf6fd5..879916eff6 100644 --- a/aiida/cmdline/commands/cmd_setup.py +++ b/aiida/cmdline/commands/cmd_setup.py @@ -43,15 +43,34 @@ @options.CONFIG_FILE() @click.pass_context def setup( - ctx, non_interactive, profile: Profile, email, first_name, last_name, institution, db_engine, db_backend, db_host, - db_port, db_name, db_username, db_password, broker_protocol, broker_username, broker_password, broker_host, - broker_port, broker_virtual_host, repository, test_profile, profile_uuid + ctx, + non_interactive, + profile: Profile, + email, + first_name, + last_name, + institution, + db_engine, + db_backend, + db_host, + db_port, + db_name, + db_username, + db_password, + broker_protocol, + broker_username, + broker_password, + broker_host, + broker_port, + broker_virtual_host, + repository, + test_profile, + profile_uuid, ): """Setup a new profile. This method assumes that an empty PSQL database has been created and that the database user has been created. 
""" - # pylint: disable=too-many-arguments,too-many-locals,unused-argument from aiida import orm # store default user settings so user does not have to re-enter them @@ -61,7 +80,8 @@ def setup( profile.uuid = profile_uuid profile.set_storage( - db_backend, { + db_backend, + { 'database_engine': db_engine, 'database_hostname': db_host, 'database_port': db_port, @@ -69,17 +89,18 @@ def setup( 'database_username': db_username, 'database_password': db_password, 'repository_uri': f'file://{repository}', - } + }, ) profile.set_process_controller( - 'rabbitmq', { + 'rabbitmq', + { 'broker_protocol': broker_protocol, 'broker_username': broker_username, 'broker_password': broker_password, 'broker_host': broker_host, 'broker_port': broker_port, 'broker_virtual_host': broker_virtual_host, - } + }, ) profile.is_test_profile = test_profile @@ -95,7 +116,7 @@ def setup( try: profile.storage_cls.initialise(profile) - except Exception as exception: # pylint: disable=broad-except + except Exception as exception: echo.echo_critical( f'storage initialisation failed, probably because connection details are incorrect:\n{exception}' ) @@ -146,12 +167,33 @@ def setup( @options.CONFIG_FILE() @click.pass_context def quicksetup( - ctx, non_interactive, profile, email, first_name, last_name, institution, db_engine, db_backend, db_host, db_port, - db_name, db_username, db_password, su_db_name, su_db_username, su_db_password, broker_protocol, broker_username, - broker_password, broker_host, broker_port, broker_virtual_host, repository, test_profile + ctx, + non_interactive, + profile, + email, + first_name, + last_name, + institution, + db_engine, + db_backend, + db_host, + db_port, + db_name, + db_username, + db_password, + su_db_name, + su_db_username, + su_db_password, + broker_protocol, + broker_username, + broker_password, + broker_host, + broker_port, + broker_virtual_host, + repository, + test_profile, ): """Setup a new profile in a fully automated fashion.""" - # pylint: disable=too-many-arguments,too-many-locals from aiida.manage.external.postgres import Postgres, manual_setup_instructions # store default user settings so user does not have to re-enter them @@ -173,13 +215,12 @@ def quicksetup( db_username, db_name = postgres.create_dbuser_db_safe(dbname=db_name, dbuser=db_username, dbpass=db_password) except Exception as exception: echo.echo_error( - '\n'.join([ - 'Oops! quicksetup was unable to create the AiiDA database for you.', - 'See `verdi quicksetup -h` for how to specify non-standard parameters for the postgresql connection.\n' - 'Alternatively, create the AiiDA database yourself: ', - manual_setup_instructions(db_username=db_username, - db_name=db_name), '', 'and then use `verdi setup` instead', '' - ]) + f"""Oops! quicksetup was unable to create the AiiDA database for you. + See `verdi quicksetup -h` for how to specify non-standard parameters for the postgresql connection. + Alternatively, create the AiiDA database yourself:\n + {manual_setup_instructions(db_username=db_username, db_name=db_name)}\n + and then use `verdi setup` instead. + """ ) raise exception diff --git a/aiida/cmdline/commands/cmd_shell.py b/aiida/cmdline/commands/cmd_shell.py index 91f50d3338..06d5168980 100644 --- a/aiida/cmdline/commands/cmd_shell.py +++ b/aiida/cmdline/commands/cmd_shell.py @@ -24,17 +24,16 @@ @click.option( '--no-startup', is_flag=True, - help='When using plain Python, ignore the PYTHONSTARTUP environment variable and ~/.pythonrc.py script.' 
+    help='When using plain Python, ignore the PYTHONSTARTUP environment variable and ~/.pythonrc.py script.',
 )
 @click.option(
     '-i',
     '--interface',
     type=click.Choice(AVAILABLE_SHELLS.keys()),
-    help='Specify an interactive interpreter interface.'
+    help='Specify an interactive interpreter interface.',
 )
 def shell(plain, no_startup, interface):
     """Start a python shell with preloaded AiiDA environment."""
-
     try:
         if plain:
             # Don't bother loading IPython, because the user wants plain Python.
@@ -70,15 +69,13 @@ def shell(plain, no_startup, interface):
     for pythonrc in (os.environ.get('PYTHONSTARTUP'), '~/.pythonrc.py'):
         if not pythonrc:
             continue
-        pythonrc = os.path.expanduser(pythonrc)
-        if not os.path.isfile(pythonrc):
+        pythonrc_expanded = os.path.expanduser(pythonrc)
+        if not os.path.isfile(pythonrc_expanded):
             continue
         try:
-            with open(pythonrc, encoding='utf8') as handle:
-                exec(compile(handle.read(), pythonrc, 'exec'), imported_objects)  # pylint: disable=exec-used
+            with open(pythonrc_expanded, encoding='utf8') as handle:
+                exec(compile(handle.read(), pythonrc_expanded, 'exec'), imported_objects)
         except NameError:
             pass
-    # The pylint disabler is necessary because the builtin code module
-    # clashes with the local commands.code module here.
-    code.interact(local=imported_objects)  # pylint: disable=no-member
+    code.interact(local=imported_objects)
diff --git a/aiida/cmdline/commands/cmd_status.py b/aiida/cmdline/commands/cmd_status.py
index e4d3f68dee..afaec306e0 100644
--- a/aiida/cmdline/commands/cmd_status.py
+++ b/aiida/cmdline/commands/cmd_status.py
@@ -19,12 +19,13 @@
 from aiida.common.exceptions import CorruptStorage, IncompatibleStorageSchema, UnreachableStorage
 from aiida.common.log import override_log_level
 
-from ..utils.echo import ExitCode  # pylint: disable=import-error,no-name-in-module
+from ..utils.echo import ExitCode
 
 
 class ServiceStatus(enum.IntEnum):
     """Describe status of services for 'verdi status' command."""
-    UP = 0  # pylint: disable=invalid-name
+
+    UP = 0
     ERROR = 1
     WARNING = 2
     DOWN = 3
@@ -55,7 +56,6 @@ class ServiceStatus(enum.IntEnum):
 @click.option('--no-rmq', is_flag=True, help='Do not check RabbitMQ status')
 def verdi_status(print_traceback, no_rmq):
     """Print status of AiiDA services."""
-    # pylint: disable=broad-except,too-many-statements,too-many-branches,too-many-locals,
     from aiida import __version__
     from aiida.common.utils import Capturing
     from aiida.engine.daemon.client import DaemonException, DaemonNotRunningException
@@ -92,10 +92,10 @@ def verdi_status(print_traceback, no_rmq):
         storage_head_version = storage_cls.version_head()
         storage_backend = storage_cls(profile)
     except UnreachableStorage as exc:
-        message = 'Unable to connect to profile\'s storage.'
+        message = "Unable to connect to profile's storage."
         print_status(ServiceStatus.DOWN, 'storage', message, exception=exc, print_traceback=print_traceback)
         exit_code = ExitCode.CRITICAL
-    except IncompatibleStorageSchema as exc:
+    except IncompatibleStorageSchema:
         message = (
             f'Storage schema version is incompatible with the code version {storage_head_version!r}. '
             'Run `verdi storage migrate` to solve this.'
         )
@@ -107,7 +107,7 @@
         print_status(ServiceStatus.DOWN, 'storage', message, exception=exc, print_traceback=print_traceback)
         exit_code = ExitCode.CRITICAL
     except Exception as exc:
-        message = 'Unable to instatiate profile\'s storage.'
+        message = "Unable to instantiate profile's storage."
print_status(ServiceStatus.ERROR, 'storage', message, exception=exc, print_traceback=print_traceback) exit_code = ExitCode.CRITICAL else: @@ -171,4 +171,5 @@ def print_status(status, service, msg='', exception=None, print_traceback=False) if print_traceback: import traceback + traceback.print_exc() diff --git a/aiida/cmdline/commands/cmd_storage.py b/aiida/cmdline/commands/cmd_storage.py index cf32e9bf99..ac540a58cb 100644 --- a/aiida/cmdline/commands/cmd_storage.py +++ b/aiida/cmdline/commands/cmd_storage.py @@ -26,6 +26,7 @@ def verdi_storage(): def storage_version(): """Print the current version of the storage schema.""" from aiida import get_profile + profile = get_profile() head_version = profile.storage_cls.version_head() profile_version = profile.storage_cls.version_profile(profile) @@ -49,7 +50,6 @@ def storage_migrate(force): storage_cls = profile.storage_cls if not force: - echo.echo_warning('Migrating your storage might take a while and is not reversible.') echo.echo_warning('Before continuing, make sure you have completed the following steps:') echo.echo_warning('') @@ -107,7 +107,7 @@ def storage_info(detailed): @click.option( '--full', is_flag=True, - help='Perform all maintenance tasks, including the ones that should not be executed while the profile is in use.' + help='Perform all maintenance tasks, including the ones that should not be executed while the profile is in use.', ) @click.option( '--no-repack', is_flag=True, help='Disable the repacking of the storage when running a `full maintenance`.' @@ -116,8 +116,7 @@ def storage_info(detailed): @click.option( '--dry-run', is_flag=True, - help= - 'Run the maintenance in dry-run mode which will print actions that would be taken without actually executing them.' + help='Run the maintenance in dry-run mode to print actions that would be taken without actually executing them.', ) @click.option( '--compress', is_flag=True, default=False, help='Use compression if possible when carrying out maintenance tasks.' diff --git a/aiida/cmdline/commands/cmd_user.py b/aiida/cmdline/commands/cmd_user.py index 8f8dbeff3c..ffd03dca12 100644 --- a/aiida/cmdline/commands/cmd_user.py +++ b/aiida/cmdline/commands/cmd_user.py @@ -52,7 +52,7 @@ def user_list(): prompt='User email', help='Email address that serves as the user name and a way to identify data created by it.', type=types.UserParamType(create=True), - cls=options.interactive.InteractiveOption + cls=options.interactive.InteractiveOption, ) @options_setup.SETUP_USER_FIRST_NAME(contextual_default=lambda ctx: ctx.params['user'].first_name) @options_setup.SETUP_USER_LAST_NAME(contextual_default=lambda ctx: ctx.params['user'].last_name) @@ -63,11 +63,11 @@ def user_list(): help='Set the user as the default user for the current profile.', is_flag=True, cls=options.interactive.InteractiveOption, - contextual_default=lambda ctx: ctx.params['user'].is_default + contextual_default=lambda ctx: ctx.params['user'].is_default, ) @click.pass_context @decorators.with_dbenv() -def user_configure(ctx, user, first_name, last_name, institution, set_default): # pylint: disable=too-many-arguments +def user_configure(ctx, user, first_name, last_name, institution, set_default): """Configure a new or existing user. An e-mail address is used as the user name. 
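[Editor's note — illustrative aside, not part of the patch: the `contextual_default` callbacks in the
`user_configure` options above are evaluated lazily against the click context, so an option can default to
an attribute of a parameter that was parsed earlier in the same command. A minimal sketch of the pattern,
mirroring the decorators shown above::

    @options_setup.SETUP_USER_FIRST_NAME(
        # `ctx.params['user']` is already resolved because the `user` option is processed first,
        # so an existing user's stored first name seeds the interactive prompt's default.
        contextual_default=lambda ctx: ctx.params['user'].first_name,
    )
]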
@@ -91,5 +91,6 @@ def user_configure(ctx, user, first_name, last_name, institution, set_default): def user_set_default(ctx, user): """Set a user as the default user for the profile.""" from aiida.manage import get_manager + get_manager().set_default_user_email(ctx.obj.profile, user.email) echo.echo_success(f'Set `{user.email}` as the default user for profile `{ctx.obj.profile.name}.`') diff --git a/aiida/cmdline/groups/__init__.py b/aiida/cmdline/groups/__init__.py index 3403f5c550..1ec2629a1e 100644 --- a/aiida/cmdline/groups/__init__.py +++ b/aiida/cmdline/groups/__init__.py @@ -3,8 +3,7 @@ # AUTO-GENERATED -# yapf: disable -# pylint: disable=wildcard-import +# fmt: off from .dynamic import * from .verdi import * @@ -14,4 +13,4 @@ 'VerdiCommandGroup', ) -# yapf: enable +# fmt: on diff --git a/aiida/cmdline/groups/dynamic.py b/aiida/cmdline/groups/dynamic.py index 475dbfb48f..38b1e61dc1 100644 --- a/aiida/cmdline/groups/dynamic.py +++ b/aiida/cmdline/groups/dynamic.py @@ -49,7 +49,7 @@ def __init__( entry_point_group: str, entry_point_name_filter: str = r'.*', shared_options: list[click.Option] | None = None, - **kwargs + **kwargs, ): super().__init__(**kwargs) self._command = command @@ -66,11 +66,14 @@ def list_commands(self, ctx: click.Context) -> list[str]: :param ctx: The :class:`click.Context`. """ commands = super().list_commands(ctx) - commands.extend([ - entry_point for entry_point in get_entry_point_names(self.entry_point_group) - if re.match(self.entry_point_name_filter, entry_point) and - getattr(self.factory(entry_point), 'cli_exposed', True) - ]) + commands.extend( + [ + entry_point + for entry_point in get_entry_point_names(self.entry_point_group) + if re.match(self.entry_point_name_filter, entry_point) + and getattr(self.factory(entry_point), 'cli_exposed', True) + ] + ) return sorted(commands) def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None: @@ -138,9 +141,15 @@ def list_options(self, entry_point: str) -> list: # '`pydantic.BaseModel` that should be assigned to the `Config` class attribute.', # version=3 # ) + from aiida.common.warnings import warn_deprecation + + warn_deprecation( + 'Relying on `_get_cli_options` is deprecated. 
The options should be defined through a '
+                '`pydantic.BaseModel` that should be assigned to the `Config` class attribute.',
+                version=3,
+            )
             options_spec = self.factory(entry_point).get_cli_options()  # type: ignore[union-attr]
         else:
             options_spec = {}
-
             for key, field_info in cls.Configuration.model_fields.items():
diff --git a/aiida/cmdline/params/__init__.py b/aiida/cmdline/params/__init__.py
index 128abf2797..0940d42249 100644
--- a/aiida/cmdline/params/__init__.py
+++ b/aiida/cmdline/params/__init__.py
@@ -11,8 +11,7 @@
 
 # AUTO-GENERATED
 
-# yapf: disable
-# pylint: disable=wildcard-import
+# fmt: off
 
 from .types import *
 
@@ -44,4 +43,4 @@
     'WorkflowParamType',
 )
 
-# yapf: enable
+# fmt: on
diff --git a/aiida/cmdline/params/arguments/__init__.py b/aiida/cmdline/params/arguments/__init__.py
index 0c891e6691..6ec4ff8cae 100644
--- a/aiida/cmdline/params/arguments/__init__.py
+++ b/aiida/cmdline/params/arguments/__init__.py
@@ -7,13 +7,10 @@
 # For further information on the license, see the LICENSE.txt file          #
 # For further information please visit http://www.aiida.net                 #
 ###########################################################################
-# yapf: disable
-"""Module with pre-defined reusable commandline arguments that can be used as `click` decorators."""
 
 # AUTO-GENERATED
 
-# yapf: disable
-# pylint: disable=wildcard-import
+# fmt: off
 
 from .main import *
 from .overridable import *
 
@@ -45,4 +42,4 @@
     'WORKFLOWS',
 )
 
-# yapf: enable
+# fmt: on
diff --git a/aiida/cmdline/params/arguments/main.py b/aiida/cmdline/params/arguments/main.py
index 71bb8c2544..298f1128f8 100644
--- a/aiida/cmdline/params/arguments/main.py
+++ b/aiida/cmdline/params/arguments/main.py
@@ -7,7 +7,7 @@
 # For further information on the license, see the LICENSE.txt file          #
 # For further information please visit http://www.aiida.net                 #
 ###########################################################################
-# yapf: disable
+
 """Module with pre-defined reusable commandline arguments that can be used as `click` decorators."""
 
 import click
 
@@ -16,9 +16,29 @@
 from .overridable import OverridableArgument
 
 __all__ = (
-    'PROFILE', 'PROFILES', 'CALCULATION', 'CALCULATIONS', 'CODE', 'CODES', 'COMPUTER', 'COMPUTERS', 'DATUM', 'DATA',
-    'GROUP', 'GROUPS', 'NODE', 'NODES', 'PROCESS', 'PROCESSES', 'WORKFLOW', 'WORKFLOWS', 'INPUT_FILE', 'OUTPUT_FILE',
-    'LABEL', 'USER', 'CONFIG_OPTION'
+    'PROFILE',
+    'PROFILES',
+    'CALCULATION',
+    'CALCULATIONS',
+    'CODE',
+    'CODES',
+    'COMPUTER',
+    'COMPUTERS',
+    'DATUM',
+    'DATA',
+    'GROUP',
+    'GROUPS',
+    'NODE',
+    'NODES',
+    'PROCESS',
+    'PROCESSES',
+    'WORKFLOW',
+    'WORKFLOWS',
+    'INPUT_FILE',
+    'OUTPUT_FILE',
+    'LABEL',
+    'USER',
+    'CONFIG_OPTION',
 )
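[Editor's note — illustrative aside, not part of the patch: the names re-exported above are instances of
`OverridableArgument` (next file in this diff), intended to be applied as decorators. A minimal usage
sketch under that documented behaviour (hypothetical command name, not from this diff)::

    import click
    from aiida.cmdline.params import arguments

    @click.command()
    @arguments.CALCULATION()  # reusable argument; kwargs passed here override the stored defaults
    def show(calculation):
        click.echo(calculation.pk)
]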
diff --git a/aiida/cmdline/params/arguments/overridable.py b/aiida/cmdline/params/arguments/overridable.py
index 72ddff6ff7..da84d938e3 100644
--- a/aiida/cmdline/params/arguments/overridable.py
+++ b/aiida/cmdline/params/arguments/overridable.py
@@ -7,22 +7,17 @@
 # For further information on the license, see the LICENSE.txt file          #
 # For further information please visit http://www.aiida.net                 #
 ###########################################################################
-"""
-.. py:module::overridable
-    :synopsis: Convenience class which can be used to defined a set of commonly used arguments that
-    can be easily reused and which improves consistency across the command line interface
-"""
+"""Convenience class which can be used to define a set of commonly used arguments that can be easily reused."""
 import click
 
 __all__ = ('OverridableArgument',)
 
 
 class OverridableArgument:
-    """
-    Wrapper around click.argument that increases reusability
+    """Wrapper around click.argument that increases reusability.
 
-    Once defined, the argument can be reused with a consistent name and sensible defaults while
-    other details can be customized on a per-command basis
+    Once defined, the argument can be reused with a consistent name and sensible defaults while other details can be
+    customized on a per-command basis.
 
     Example::
 
@@ -37,22 +32,20 @@
         def print_code_pks(codes):
             click.echo([c.pk for c in codes])
 
     Notice that the arguments, which are used to define the name of the argument and based on which
-    the function argument name is determined, can be overriden
+    the function argument name is determined, can be overridden.
     """
 
     def __init__(self, *args, **kwargs):
-        """
-        Store the default args and kwargs
-        """
+        """Store the default args and kwargs"""
         self.args = args
         self.kwargs = kwargs
 
     def __call__(self, *args, **kwargs):
-        """
-        Override the stored kwargs with the passed kwargs and return the argument, using the stored args
-        only if they are not provided. This allows the user to override the variable name, which is
-        useful if for example they want to allow multiple value with nargs=-1 and want to pluralize
-        the function argument for consistency
-        """
+        """Override the stored kwargs with the passed kwargs and return the argument.
+
+        The stored args are used only if they are not provided. This allows the user to override the variable name,
+        which is useful if for example they want to allow multiple values with ``nargs=-1`` and want to pluralize the
+        function argument for consistency.
+        """
         kw_copy = self.kwargs.copy()
         kw_copy.update(kwargs)
diff --git a/aiida/cmdline/params/options/__init__.py b/aiida/cmdline/params/options/__init__.py
index e5dd70f91b..1c24205993 100644
--- a/aiida/cmdline/params/options/__init__.py
+++ b/aiida/cmdline/params/options/__init__.py
@@ -11,8 +11,7 @@
 
 # AUTO-GENERATED
 
-# yapf: disable
-# pylint: disable=wildcard-import
+# fmt: off
 
 from .callable import *
 from .config import *
 
@@ -115,4 +114,4 @@
     'valid_process_states',
 )
 
-# yapf: enable
+# fmt: on
diff --git a/aiida/cmdline/params/options/callable.py b/aiida/cmdline/params/options/callable.py
index 4fde0e96b6..5cb6cfd23f 100644
--- a/aiida/cmdline/params/options/callable.py
+++ b/aiida/cmdline/params/options/callable.py
@@ -7,10 +7,7 @@
 # For further information on the license, see the LICENSE.txt file          #
 # For further information please visit http://www.aiida.net                 #
 ###########################################################################
-"""
-.. 
py:module::callable - :synopsis: A monkey-patched subclass of click.Option that does not evaluate callable default during tab completion -""" +"""A monkey-patched subclass of click.Option that does not evaluate callable default during tab completion.""" import typing as t @@ -27,8 +24,7 @@ class CallableDefaultOption(click.Option): """ def get_default(self, ctx: click.Context, call: bool = True) -> t.Optional[t.Union[t.Any, t.Callable[[], t.Any]]]: - """provides the functionality of :meth:`click.Option.get_default`, - but ensures we do not evaluate callable defaults when in tab-completion context.""" + """Return default unless in tab-completion context.""" if ctx.resilient_parsing: return None return super().get_default(ctx=ctx, call=call) diff --git a/aiida/cmdline/params/options/commands/code.py b/aiida/cmdline/params/options/commands/code.py index 98f0f8b3d8..ce1341fe28 100644 --- a/aiida/cmdline/params/options/commands/code.py +++ b/aiida/cmdline/params/options/commands/code.py @@ -75,7 +75,7 @@ def validate_label_uniqueness(ctx, _, value): cls=InteractiveOption, prompt='Installed on target computer?', help='Whether the code is installed on the target computer, or should be copied to the target computer each time ' - 'from a local path.' + 'from a local path.', ) REMOTE_ABS_PATH = OverridableOption( @@ -85,7 +85,7 @@ def validate_label_uniqueness(ctx, _, value): prompt_fn=is_on_computer, type=types.AbsolutePathParamType(dir_okay=False), cls=InteractiveOption, - help='[if --on-computer]: Absolute path to the executable on the target computer.' + help='[if --on-computer]: Absolute path to the executable on the target computer.', ) FOLDER = OverridableOption( @@ -96,7 +96,7 @@ def validate_label_uniqueness(ctx, _, value): type=click.Path(file_okay=False, exists=True, readable=True), cls=InteractiveOption, help='[if --store-in-db]: Absolute path to directory containing the executable and all other files necessary for ' - 'running it (to be copied to target computer).' + 'running it (to be copied to target computer).', ) REL_PATH = OverridableOption( @@ -106,7 +106,7 @@ def validate_label_uniqueness(ctx, _, value): prompt_fn=is_not_on_computer, type=click.Path(dir_okay=False), cls=InteractiveOption, - help='[if --store-in-db]: Relative path of the executable inside the code-folder.' + help='[if --store-in-db]: Relative path of the executable inside the code-folder.', ) USE_DOUBLE_QUOTES = OverridableOption( @@ -115,7 +115,7 @@ def validate_label_uniqueness(ctx, _, value): cls=InteractiveOption, prompt='Escape CLI arguments in double quotes', help='Whether the executable and arguments of the code in the submission script should be escaped with single ' - 'or double quotes.' + 'or double quotes.', ) LABEL = options.LABEL.clone( @@ -123,20 +123,20 @@ def validate_label_uniqueness(ctx, _, value): callback=validate_label_uniqueness, cls=InteractiveOption, help="This label can be used to identify the code (using 'label@computerlabel'), as long as labels are unique per " - 'computer.' + 'computer.', ) DESCRIPTION = options.DESCRIPTION.clone( prompt='Description', cls=InteractiveOption, - help='A human-readable description of this code, ideally including version and compilation environment.' 
+ help='A human-readable description of this code, ideally including version and compilation environment.', ) INPUT_PLUGIN = options.INPUT_PLUGIN.clone( required=False, prompt='Default calculation input plugin', cls=InteractiveOption, - help="Entry point name of the default calculation plugin (as listed in 'verdi plugin list aiida.calculations')." + help="Entry point name of the default calculation plugin (as listed in 'verdi plugin list aiida.calculations').", ) COMPUTER = options.COMPUTER.clone( @@ -144,7 +144,7 @@ def validate_label_uniqueness(ctx, _, value): cls=InteractiveOption, required_fn=is_on_computer, prompt_fn=is_on_computer, - help='Name of the computer, on which the code is installed.' + help='Name of the computer, on which the code is installed.', ) PREPEND_TEXT = OverridableOption( @@ -157,7 +157,7 @@ def validate_label_uniqueness(ctx, _, value): extension='.bash', header='PREPEND_TEXT: if there is any bash commands that should be prepended to the executable call in all ' 'submit scripts for this code, type that between the equal signs below and save the file.', - footer='All lines that start with `#=` will be ignored.' + footer='All lines that start with `#=` will be ignored.', ) APPEND_TEXT = OverridableOption( @@ -170,5 +170,5 @@ def validate_label_uniqueness(ctx, _, value): extension='.bash', header='APPEND_TEXT: if there is any bash commands that should be appended to the executable call in all ' 'submit scripts for this code, type that between the equal signs below and save the file.', - footer='All lines that start with `#=` will be ignored.' + footer='All lines that start with `#=` will be ignored.', ) diff --git a/aiida/cmdline/params/options/commands/computer.py b/aiida/cmdline/params/options/commands/computer.py index 7e1cfc0859..2203df1c41 100644 --- a/aiida/cmdline/params/options/commands/computer.py +++ b/aiida/cmdline/params/options/commands/computer.py @@ -16,9 +16,7 @@ def get_job_resource_cls(ctx): - """ - Return job resource cls from ctx. - """ + """Return job resource cls from ctx.""" from aiida.common.exceptions import ValidationError scheduler_ep = ctx.params['scheduler'] @@ -35,9 +33,8 @@ def get_job_resource_cls(ctx): return scheduler_cls.job_resource_class -def should_call_default_mpiprocs_per_machine(ctx): # pylint: disable=invalid-name - """ - Return whether the selected scheduler type accepts `default_mpiprocs_per_machine`. +def should_call_default_mpiprocs_per_machine(ctx): + """Return whether the selected scheduler type accepts `default_mpiprocs_per_machine`. :return: `True` if the scheduler type accepts `default_mpiprocs_per_machine`, `False` otherwise. If the scheduler class could not be loaded `False` is returned by default. @@ -51,9 +48,8 @@ def should_call_default_mpiprocs_per_machine(ctx): # pylint: disable=invalid-na return job_resource_cls.accepts_default_mpiprocs_per_machine() -def should_call_default_memory_per_machine(ctx): # pylint: disable=invalid-name - """ - Return whether the selected scheduler type accepts `default_memory_per_machine`. +def should_call_default_memory_per_machine(ctx): + """Return whether the selected scheduler type accepts `default_memory_per_machine`. :return: `True` if the scheduler type accepts `default_memory_per_machine`, `False` otherwise. If the scheduler class could not be loaded `False` is returned by default. 
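[Editor's note — illustrative aside, not part of the patch: the two `should_call_*` helpers above are
consumed as `prompt_fn` callbacks by the option definitions that follow; `InteractiveOption` (see
`params/options/interactive.py` later in this diff) calls `prompt_fn(ctx)` to decide whether to prompt at
all. An abridged sketch of the pattern, paraphrased from the definitions below (flag name and prompt text
are illustrative)::

    MPI_PROCS_PER_MACHINE = OverridableOption(
        '--mpiprocs-per-machine',
        prompt='Default number of MPI processes per machine',
        prompt_fn=should_call_default_mpiprocs_per_machine,  # skip the prompt when the scheduler rejects it
        required=False,
        type=click.INT,
        cls=InteractiveOption,
    )
]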
@@ -71,7 +67,7 @@ def should_call_default_memory_per_machine(ctx):  # pylint: disable=invalid-name
     prompt='Computer label',
     cls=InteractiveOption,
     required=True,
-    help='Unique, human-readable label for this computer.'
+    help='Unique, human-readable label for this computer.',
 )
 
 HOSTNAME = options.HOSTNAME.clone(
@@ -96,7 +92,7 @@ def should_call_default_memory_per_machine(ctx):  # pylint: disable=invalid-name
     default='#!/bin/bash',
     cls=InteractiveOption,
     help='Specify the first line of the submission script for this computer (only the bash shell is supported).',
-    type=types.ShebangParamType()
+    type=types.ShebangParamType(),
 )
 
 WORKDIR = OverridableOption(
@@ -107,7 +103,7 @@ def should_call_default_memory_per_machine(ctx):  # pylint: disable=invalid-name
     cls=InteractiveOption,
     help='The absolute path of the directory on the computer where AiiDA will '
     'run the calculations (often a "scratch" directory).'
-    'The {username} string will be replaced by your username on the remote computer.'
+    ' The {username} string will be replaced by your username on the remote computer.',
 )
 
 MPI_RUN_COMMAND = OverridableOption(
@@ -118,7 +114,7 @@ def should_call_default_memory_per_machine(ctx):  # pylint: disable=invalid-name
     cls=InteractiveOption,
     help='The mpirun command needed on the cluster to run parallel MPI programs. The {tot_num_mpiprocs} string will be '
     'replaced by the total number of cpus. See the scheduler docs for further scheduler-dependent template variables.',
-    type=types.MpirunCommandParamType()
+    type=types.MpirunCommandParamType(),
 )
 
 MPI_PROCS_PER_MACHINE = OverridableOption(
@@ -139,7 +135,7 @@ def should_call_default_memory_per_machine(ctx):  # pylint: disable=invalid-name
     prompt_fn=should_call_default_memory_per_machine,
     required=False,
     type=click.INT,
-    help='The default amount of RAM (kB) that should be allocated per machine (node), if not otherwise specified.'
+    help='The default amount of RAM (kB) that should be allocated per machine (node), if not otherwise specified.',
 )
 
 USE_DOUBLE_QUOTES = OverridableOption(
@@ -148,7 +144,7 @@ def should_call_default_memory_per_machine(ctx):  # pylint: disable=invalid-name
     cls=InteractiveOption,
     prompt='Escape CLI arguments in double quotes',
     help='Whether the command line arguments before and after the executable in the submission script should be '
-    'escaped with single or double quotes.'
+    'escaped with single or double quotes.',
 )
 
 PREPEND_TEXT = OverridableOption(
@@ -161,7 +157,7 @@ def should_call_default_memory_per_machine(ctx):  # pylint: disable=invalid-name
     extension='.bash',
     header='PREPEND_TEXT: if there is any bash commands that should be prepended to the executable call in all '
     'submit scripts for this computer, type that between the equal signs below and save the file.',
-    footer='All lines that start with `#=` will be ignored.'
+    footer='All lines that start with `#=` will be ignored.',
 )
 
 APPEND_TEXT = OverridableOption(
@@ -174,5 +170,5 @@ def should_call_default_memory_per_machine(ctx):  # pylint: disable=invalid-name
     extension='.bash',
     header='APPEND_TEXT: if there is any bash commands that should be appended to the executable call in all '
     'submit scripts for this computer, type that between the equal signs below and save the file.',
-    footer='All lines that start with `#=` will be ignored.'
+ footer='All lines that start with `#=` will be ignored.', ) diff --git a/aiida/cmdline/params/options/commands/setup.py b/aiida/cmdline/params/options/commands/setup.py index b8cd6521ff..786aa4735a 100644 --- a/aiida/cmdline/params/options/commands/setup.py +++ b/aiida/cmdline/params/options/commands/setup.py @@ -18,7 +18,7 @@ from aiida.manage.external.postgres import DEFAULT_DBINFO from aiida.manage.external.rmq import BROKER_DEFAULTS -PASSWORD_UNCHANGED = '***' # noqa +PASSWORD_UNCHANGED = '***' def validate_profile_parameter(ctx): @@ -71,7 +71,7 @@ def get_repository_uri_default(ctx): return os.path.join(AIIDA_CONFIG_FOLDER, 'repository', ctx.params['profile'].name) -def get_quicksetup_repository_uri(ctx, param, value): # pylint: disable=unused-argument +def get_quicksetup_repository_uri(ctx, param, value): """Return the repository URI to be used as default in `verdi quicksetup` :param ctx: click context which should contain the contextual parameters @@ -80,7 +80,7 @@ def get_quicksetup_repository_uri(ctx, param, value): # pylint: disable=unused- return get_repository_uri_default(ctx) -def get_quicksetup_database_name(ctx, param, value): # pylint: disable=unused-argument +def get_quicksetup_database_name(ctx, param, value): """Determine the database name to be used as default for the Postgres connection in `verdi quicksetup` If a value is explicitly passed, that value is returned unchanged. @@ -106,7 +106,7 @@ def get_quicksetup_database_name(ctx, param, value): # pylint: disable=unused-a return database_name -def get_quicksetup_username(ctx, param, value): # pylint: disable=unused-argument +def get_quicksetup_username(ctx, param, value): """Determine the username to be used as default for the Postgres connection in `verdi quicksetup` If a value is explicitly passed, that value is returned. If there is no value, the name will be based on the @@ -127,7 +127,7 @@ def get_quicksetup_username(ctx, param, value): # pylint: disable=unused-argume return username -def get_quicksetup_password(ctx, param, value): # pylint: disable=unused-argument +def get_quicksetup_password(ctx, param, value): """Determine the password to be used as default for the Postgres connection in `verdi quicksetup` If a value is explicitly passed, that value is returned. 
If there is no value, the current username in the context @@ -165,7 +165,7 @@ def get_quicksetup_password(ctx, param, value): # pylint: disable=unused-argume required=False, hidden=True, type=str, - cls=options.interactive.InteractiveOption + cls=options.interactive.InteractiveOption, ) SETUP_PROFILE = options.OverridableOption( @@ -174,7 +174,7 @@ def get_quicksetup_password(ctx, param, value): # pylint: disable=unused-argume help='The name of the new profile.', required=True, type=types.ProfileParamType(cannot_exist=True), - cls=options.interactive.InteractiveOption + cls=options.interactive.InteractiveOption, ) SETUP_PROFILE_SET_AS_DEFAULT = options.OverridableOption( @@ -183,35 +183,35 @@ def get_quicksetup_password(ctx, param, value): # pylint: disable=unused-argume help='Whether to set the profile as the default.', is_flag=True, default=True, - cls=options.interactive.InteractiveOption + cls=options.interactive.InteractiveOption, ) SETUP_USER_EMAIL = options.USER_EMAIL.clone( prompt='Email Address (for sharing data)', default=functools.partial(get_config_option, 'autofill.user.email'), required=True, - cls=options.interactive.InteractiveOption + cls=options.interactive.InteractiveOption, ) SETUP_USER_FIRST_NAME = options.USER_FIRST_NAME.clone( prompt='First name', default=lambda: get_config_option('autofill.user.first_name') or 'John', required=True, - cls=options.interactive.InteractiveOption + cls=options.interactive.InteractiveOption, ) SETUP_USER_LAST_NAME = options.USER_LAST_NAME.clone( prompt='Last name', default=lambda: get_config_option('autofill.user.last_name') or 'Doe', required=True, - cls=options.interactive.InteractiveOption + cls=options.interactive.InteractiveOption, ) SETUP_USER_INSTITUTION = options.USER_INSTITUTION.clone( prompt='Institution', default=lambda: get_config_option('autofill.user.institution') or 'Unknown', required=True, - cls=options.interactive.InteractiveOption + cls=options.interactive.InteractiveOption, ) QUICKSETUP_DATABASE_ENGINE = options.DB_ENGINE @@ -226,7 +226,7 @@ def get_quicksetup_password(ctx, param, value): # pylint: disable=unused-argume '--db-name', help='Name of the database to create.', type=types.NonEmptyStringParamType(), - callback=get_quicksetup_database_name + callback=get_quicksetup_database_name, ) QUICKSETUP_DATABASE_USERNAME = options.DB_USERNAME.clone( @@ -243,7 +243,7 @@ def get_quicksetup_password(ctx, param, value): # pylint: disable=unused-argume '--su-db-name', help='Name of the template database to connect to as the database superuser.', type=click.STRING, - default=DEFAULT_DBINFO['database'] + default=DEFAULT_DBINFO['database'], ) QUICKSETUP_SUPERUSER_DATABASE_PASSWORD = options.OverridableOption( @@ -275,13 +275,13 @@ def get_quicksetup_password(ctx, param, value): # pylint: disable=unused-argume contextual_default=functools.partial( get_profile_attribute_default, ('storage.config.database_engine', 'postgresql_psycopg2') ), - cls=options.interactive.InteractiveOption + cls=options.interactive.InteractiveOption, ) SETUP_DATABASE_BACKEND = QUICKSETUP_DATABASE_BACKEND.clone( prompt='Database backend', contextual_default=functools.partial(get_profile_attribute_default, ('storage_backend', 'core.psql_dos')), - cls=options.interactive.InteractiveOption + cls=options.interactive.InteractiveOption, ) SETUP_DATABASE_HOSTNAME = QUICKSETUP_DATABASE_HOSTNAME.clone( @@ -289,7 +289,7 @@ def get_quicksetup_password(ctx, param, value): # pylint: disable=unused-argume contextual_default=functools.partial( 
get_profile_attribute_default, ('storage.config.database_hostname', 'localhost') ), - cls=options.interactive.InteractiveOption + cls=options.interactive.InteractiveOption, ) SETUP_DATABASE_PORT = QUICKSETUP_DATABASE_PORT.clone( @@ -297,28 +297,28 @@ def get_quicksetup_password(ctx, param, value): # pylint: disable=unused-argume contextual_default=functools.partial( get_profile_attribute_default, ('storage.config.database_port', DEFAULT_DBINFO['port']) ), - cls=options.interactive.InteractiveOption + cls=options.interactive.InteractiveOption, ) SETUP_DATABASE_NAME = QUICKSETUP_DATABASE_NAME.clone( prompt='Database name', required=True, contextual_default=functools.partial(get_profile_attribute_default, ('storage.config.database_name', None)), - cls=options.interactive.InteractiveOption + cls=options.interactive.InteractiveOption, ) SETUP_DATABASE_USERNAME = QUICKSETUP_DATABASE_USERNAME.clone( prompt='Database username', required=True, contextual_default=functools.partial(get_profile_attribute_default, ('storage.config.database_username', None)), - cls=options.interactive.InteractiveOption + cls=options.interactive.InteractiveOption, ) SETUP_DATABASE_PASSWORD = QUICKSETUP_DATABASE_PASSWORD.clone( prompt='Database password', required=True, contextual_default=functools.partial(get_profile_attribute_default, ('storage.config.database_password', None)), - cls=options.interactive.InteractiveOption + cls=options.interactive.InteractiveOption, ) SETUP_BROKER_PROTOCOL = QUICKSETUP_BROKER_PROTOCOL.clone( @@ -327,7 +327,7 @@ def get_quicksetup_password(ctx, param, value): # pylint: disable=unused-argume contextual_default=functools.partial( get_profile_attribute_default, ('process_control.config.broker_protocol', BROKER_DEFAULTS.protocol) ), - cls=options.interactive.InteractiveOption + cls=options.interactive.InteractiveOption, ) SETUP_BROKER_USERNAME = QUICKSETUP_BROKER_USERNAME.clone( @@ -336,7 +336,7 @@ def get_quicksetup_password(ctx, param, value): # pylint: disable=unused-argume contextual_default=functools.partial( get_profile_attribute_default, ('process_control.config.broker_username', BROKER_DEFAULTS.username) ), - cls=options.interactive.InteractiveOption + cls=options.interactive.InteractiveOption, ) SETUP_BROKER_PASSWORD = QUICKSETUP_BROKER_PASSWORD.clone( @@ -345,7 +345,7 @@ def get_quicksetup_password(ctx, param, value): # pylint: disable=unused-argume contextual_default=functools.partial( get_profile_attribute_default, ('process_control.config.broker_password', BROKER_DEFAULTS.password) ), - cls=options.interactive.InteractiveOption + cls=options.interactive.InteractiveOption, ) SETUP_BROKER_HOST = QUICKSETUP_BROKER_HOST.clone( @@ -354,7 +354,7 @@ def get_quicksetup_password(ctx, param, value): # pylint: disable=unused-argume contextual_default=functools.partial( get_profile_attribute_default, ('process_control.config.broker_host', BROKER_DEFAULTS.host) ), - cls=options.interactive.InteractiveOption + cls=options.interactive.InteractiveOption, ) SETUP_BROKER_PORT = QUICKSETUP_BROKER_PORT.clone( @@ -363,7 +363,7 @@ def get_quicksetup_password(ctx, param, value): # pylint: disable=unused-argume contextual_default=functools.partial( get_profile_attribute_default, ('process_control.config.broker_port', BROKER_DEFAULTS.port) ), - cls=options.interactive.InteractiveOption + cls=options.interactive.InteractiveOption, ) SETUP_BROKER_VIRTUAL_HOST = QUICKSETUP_BROKER_VIRTUAL_HOST.clone( @@ -372,7 +372,7 @@ def get_quicksetup_password(ctx, param, value): # pylint: disable=unused-argume 
contextual_default=functools.partial( get_profile_attribute_default, ('process_control.config.broker_virtual_host', BROKER_DEFAULTS.virtual_host) ), - cls=options.interactive.InteractiveOption + cls=options.interactive.InteractiveOption, ) SETUP_REPOSITORY_URI = QUICKSETUP_REPOSITORY_URI.clone( @@ -380,7 +380,7 @@ def get_quicksetup_password(ctx, param, value): # pylint: disable=unused-argume required=True, callback=None, # Unset the `callback` to define the default, which is instead done by the `contextual_default` contextual_default=get_repository_uri_default, - cls=options.interactive.InteractiveOption + cls=options.interactive.InteractiveOption, ) SETUP_TEST_PROFILE = options.OverridableOption( diff --git a/aiida/cmdline/params/options/conditional.py b/aiida/cmdline/params/options/conditional.py index f865a84233..df62279f3a 100644 --- a/aiida/cmdline/params/options/conditional.py +++ b/aiida/cmdline/params/options/conditional.py @@ -46,7 +46,7 @@ def process_value(self, ctx, value): return value def is_required(self, ctx): - """runs the given check on the context to determine requiredness""" + """Runs the given check on the context to determine requiredness""" if self.required_fn: return self.required_fn(ctx) diff --git a/aiida/cmdline/params/options/config.py b/aiida/cmdline/params/options/config.py index d8db0465b4..60193ecd1f 100644 --- a/aiida/cmdline/params/options/config.py +++ b/aiida/cmdline/params/options/config.py @@ -7,9 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=cyclic-import -""" -.. py:module::config +""".. py:module::config :synopsis: Convenience class for configuration file option The functions :func:`configuration_callback` and :func:`configuration_option` were directly taken from the repository @@ -30,7 +28,7 @@ __all__ = ('ConfigFileOption',) -def yaml_config_file_provider(handle, cmd_name): # pylint: disable=unused-argument +def yaml_config_file_provider(handle, cmd_name): """Read yaml config file from file handle.""" import yaml @@ -130,7 +128,7 @@ def decorator(func): 'dir_okay': False, 'writable': False, 'readable': True, - 'resolve_path': False + 'resolve_path': False, } path_params = {k: attrs.pop(k, v) for k, v in path_default_params.items()} attrs['type'] = attrs.get('type', click.Path(**path_params)) diff --git a/aiida/cmdline/params/options/interactive.py b/aiida/cmdline/params/options/interactive.py index f2f5489b56..198567186b 100644 --- a/aiida/cmdline/params/options/interactive.py +++ b/aiida/cmdline/params/options/interactive.py @@ -7,11 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -""" -.. module::interactive - :synopsis: Tools and an option class for interactive parameter entry with - additional features such as help lookup. -""" +"""Tools and an option class for interactive parameter entry with additional features such as help lookup.""" import typing as t import click @@ -22,9 +18,7 @@ class InteractiveOption(ConditionalOption): - """ - Prompts for input, intercepting certain keyword arguments to replace :mod:`click`'s prompting - behaviour with a more feature-rich one. + """Prompts for input, intercepting certain keyword arguments to provide more feature-rich behavior. .. 
note:: This class has a parameter ``required_fn`` that can be passed to its ``__init__`` (inherited from the superclass :py:class:`~aiida.cmdline.params.options.conditional.ConditionalOption`) and a @@ -46,6 +40,7 @@ class InteractiveOption(ConditionalOption): @click.option('label', prompt='Label', cls=InteractiveOption) def foo(label): click.echo(f'Labeling with label: {label}') + """ PROMPT_COLOR = echo.COLORS['warning'] @@ -53,7 +48,8 @@ def foo(label): CHARACTER_IGNORE_DEFAULT = '!' def __init__(self, param_decls=None, prompt_fn=None, contextual_default=None, **kwargs): - """ + """Construct a new instance. + :param param_decls: relayed to :class:`click.Option` :param prompt_fn: callable(ctx) -> Bool, returns True if the option should be prompted for in interactive mode. :param contextual_default: An optional callback function to get a default which is passed the click context. @@ -151,7 +147,7 @@ def get_help_message(self): return message def get_default(self, ctx: click.Context, call: bool = True) -> t.Optional[t.Union[t.Any, t.Callable[[], t.Any]]]: - """provides the functionality of :meth:`click.Option.get_default`""" + """Provides the functionality of :meth:`click.Option.get_default`""" if ctx.resilient_parsing: return None diff --git a/aiida/cmdline/params/options/main.py b/aiida/cmdline/params/options/main.py index ec431c3cb7..72085fd8ad 100644 --- a/aiida/cmdline/params/options/main.py +++ b/aiida/cmdline/params/options/main.py @@ -14,26 +14,102 @@ from aiida.manage.external.postgres import DEFAULT_DBINFO from aiida.manage.external.rmq import BROKER_DEFAULTS +from ...utils import defaults, echo from .. import types -from ...utils import defaults, echo # pylint: disable=no-name-in-module from .callable import CallableDefaultOption from .config import ConfigFileOption from .multivalue import MultipleValueOption from .overridable import OverridableOption __all__ = ( - 'ALL', 'ALL_STATES', 'ALL_USERS', 'APPEND_TEXT', 'ARCHIVE_FORMAT', 'BROKER_HOST', 'BROKER_PASSWORD', 'BROKER_PORT', - 'BROKER_PROTOCOL', 'BROKER_USERNAME', 'BROKER_VIRTUAL_HOST', 'CALCULATION', 'CALCULATIONS', 'CALC_JOB_STATE', - 'CODE', 'CODES', 'COMPUTER', 'COMPUTERS', 'CONFIG_FILE', 'DATA', 'DATUM', 'DB_BACKEND', 'DB_ENGINE', 'DB_HOST', - 'DB_NAME', 'DB_PASSWORD', 'DB_PORT', 'DB_USERNAME', 'DEBUG', 'DESCRIPTION', 'DICT_FORMAT', 'DICT_KEYS', 'DRY_RUN', - 'EXIT_STATUS', 'EXPORT_FORMAT', 'FAILED', 'FORCE', 'FORMULA_MODE', 'FREQUENCY', 'GROUP', 'GROUPS', 'GROUP_CLEAR', - 'HOSTNAME', 'IDENTIFIER', 'INPUT_FORMAT', 'INPUT_PLUGIN', 'LABEL', 'LIMIT', 'NODE', 'NODES', 'NON_INTERACTIVE', - 'OLDER_THAN', 'ORDER_BY', 'ORDER_DIRECTION', 'PAST_DAYS', 'PAUSED', 'PORT', 'PREPEND_TEXT', 'PRINT_TRACEBACK', - 'PROCESS_LABEL', 'PROCESS_STATE', 'PROFILE', 'PROFILE_ONLY_CONFIG', 'PROFILE_SET_DEFAULT', 'PROJECT', 'RAW', - 'REPOSITORY_PATH', 'SCHEDULER', 'SILENT', 'TIMEOUT', 'TRAJECTORY_INDEX', 'TRANSPORT', 'TRAVERSAL_RULE_HELP_STRING', - 'TYPE_STRING', 'USER', 'USER_EMAIL', 'USER_FIRST_NAME', 'USER_INSTITUTION', 'USER_LAST_NAME', 'VERBOSITY', - 'VISUALIZATION_FORMAT', 'WAIT', 'WITH_ELEMENTS', 'WITH_ELEMENTS_EXCLUSIVE', 'active_process_states', - 'graph_traversal_rules', 'valid_calc_job_states', 'valid_process_states' + 'ALL', + 'ALL_STATES', + 'ALL_USERS', + 'APPEND_TEXT', + 'ARCHIVE_FORMAT', + 'BROKER_HOST', + 'BROKER_PASSWORD', + 'BROKER_PORT', + 'BROKER_PROTOCOL', + 'BROKER_USERNAME', + 'BROKER_VIRTUAL_HOST', + 'CALCULATION', + 'CALCULATIONS', + 'CALC_JOB_STATE', + 'CODE', + 'CODES', + 'COMPUTER', + 'COMPUTERS', + 'CONFIG_FILE', + 
'DATA', + 'DATUM', + 'DB_BACKEND', + 'DB_ENGINE', + 'DB_HOST', + 'DB_NAME', + 'DB_PASSWORD', + 'DB_PORT', + 'DB_USERNAME', + 'DEBUG', + 'DESCRIPTION', + 'DICT_FORMAT', + 'DICT_KEYS', + 'DRY_RUN', + 'EXIT_STATUS', + 'EXPORT_FORMAT', + 'FAILED', + 'FORCE', + 'FORMULA_MODE', + 'FREQUENCY', + 'GROUP', + 'GROUPS', + 'GROUP_CLEAR', + 'HOSTNAME', + 'IDENTIFIER', + 'INPUT_FORMAT', + 'INPUT_PLUGIN', + 'LABEL', + 'LIMIT', + 'NODE', + 'NODES', + 'NON_INTERACTIVE', + 'OLDER_THAN', + 'ORDER_BY', + 'ORDER_DIRECTION', + 'PAST_DAYS', + 'PAUSED', + 'PORT', + 'PREPEND_TEXT', + 'PRINT_TRACEBACK', + 'PROCESS_LABEL', + 'PROCESS_STATE', + 'PROFILE', + 'PROFILE_ONLY_CONFIG', + 'PROFILE_SET_DEFAULT', + 'PROJECT', + 'RAW', + 'REPOSITORY_PATH', + 'SCHEDULER', + 'SILENT', + 'TIMEOUT', + 'TRAJECTORY_INDEX', + 'TRANSPORT', + 'TRAVERSAL_RULE_HELP_STRING', + 'TYPE_STRING', + 'USER', + 'USER_EMAIL', + 'USER_FIRST_NAME', + 'USER_INSTITUTION', + 'USER_LAST_NAME', + 'VERBOSITY', + 'VISUALIZATION_FORMAT', + 'WAIT', + 'WITH_ELEMENTS', + 'WITH_ELEMENTS_EXCLUSIVE', + 'active_process_states', + 'graph_traversal_rules', + 'valid_calc_job_states', + 'valid_process_states', ) TRAVERSAL_RULE_HELP_STRING = { @@ -55,23 +131,26 @@ def valid_process_states(): """Return a list of valid values for the ProcessState enum.""" from plumpy import ProcessState + return tuple(state.value for state in ProcessState) def valid_calc_job_states(): """Return a list of valid values for the CalcState enum.""" from aiida.common.datastructures import CalcJobState + return tuple(state.value for state in CalcJobState) def active_process_states(): """Return a list of process states that are considered active.""" from plumpy import ProcessState - return ([ + + return [ ProcessState.CREATED.value, ProcessState.WAITING.value, ProcessState.RUNNING.value, - ]) + ] def graph_traversal_rules(rules): @@ -139,7 +218,7 @@ def set_log_level(_ctx, _param, value): type=click.Choice(tuple(map(str.lower, LOG_LEVELS.keys())), case_sensitive=False), callback=set_log_level, expose_value=False, # Ensures that the option is not actually passed to the command, because it doesn't need it - help='Set the verbosity of the output.' + help='Set the verbosity of the output.', ) PROFILE = OverridableOption( @@ -149,7 +228,7 @@ def set_log_level(_ctx, _param, value): type=types.ProfileParamType(), default=defaults.get_default_profile, cls=CallableDefaultOption, - help='Execute the command for this profile instead of the default profile.' + help='Execute the command for this profile instead of the default profile.', ) CALCULATION = OverridableOption( @@ -157,7 +236,7 @@ def set_log_level(_ctx, _param, value): '--calculation', 'calculation', type=types.CalculationParamType(), - help='A single calculation identified by its ID or UUID.' + help='A single calculation identified by its ID or UUID.', ) CALCULATIONS = OverridableOption( @@ -166,7 +245,7 @@ def set_log_level(_ctx, _param, value): 'calculations', type=types.CalculationParamType(), cls=MultipleValueOption, - help='One or multiple calculations identified by their ID or UUID.' + help='One or multiple calculations identified by their ID or UUID.', ) CODE = OverridableOption( @@ -179,7 +258,7 @@ def set_log_level(_ctx, _param, value): 'codes', type=types.CodeParamType(), cls=MultipleValueOption, - help='One or multiple codes identified by their ID, UUID or label.' 
+ help='One or multiple codes identified by their ID, UUID or label.', ) COMPUTER = OverridableOption( @@ -187,7 +266,7 @@ def set_log_level(_ctx, _param, value): '--computer', 'computer', type=types.ComputerParamType(), - help='A single computer identified by its ID, UUID or label.' + help='A single computer identified by its ID, UUID or label.', ) COMPUTERS = OverridableOption( @@ -196,7 +275,7 @@ def set_log_level(_ctx, _param, value): 'computers', type=types.ComputerParamType(), cls=MultipleValueOption, - help='One or multiple computers identified by their ID, UUID or label.' + help='One or multiple computers identified by their ID, UUID or label.', ) DATUM = OverridableOption( @@ -209,7 +288,7 @@ def set_log_level(_ctx, _param, value): 'data', type=types.DataParamType(), cls=MultipleValueOption, - help='One or multiple data identified by their ID, UUID or label.' + help='One or multiple data identified by their ID, UUID or label.', ) GROUP = OverridableOption( @@ -222,7 +301,7 @@ def set_log_level(_ctx, _param, value): 'groups', type=types.GroupParamType(), cls=MultipleValueOption, - help='One or multiple groups identified by their ID, UUID or label.' + help='One or multiple groups identified by their ID, UUID or label.', ) NODE = OverridableOption( @@ -235,7 +314,7 @@ def set_log_level(_ctx, _param, value): 'nodes', type=types.NodeParamType(), cls=MultipleValueOption, - help='One or multiple nodes identified by their ID or UUID.' + help='One or multiple nodes identified by their ID or UUID.', ) FORCE = OverridableOption('-f', '--force', is_flag=True, default=False, help='Do not ask for confirmation.') @@ -256,7 +335,7 @@ def set_log_level(_ctx, _param, value): type=click.Choice(['zip', 'zip-uncompressed', 'tar.gz']), default='zip', show_default=True, - help='The format of the archive file.' + help='The format of the archive file.', ) NON_INTERACTIVE = OverridableOption( @@ -264,7 +343,7 @@ def set_log_level(_ctx, _param, value): '--non-interactive', is_flag=True, is_eager=True, - help='In non-interactive mode, the CLI never prompts but simply uses default values for options that define one.' + help='In non-interactive mode, the CLI never prompts but simply uses default values for options that define one.', ) DRY_RUN = OverridableOption('-n', '--dry-run', is_flag=True, help='Perform a dry run.') @@ -274,7 +353,7 @@ def set_log_level(_ctx, _param, value): 'email', type=types.EmailType(), help='Email address associated with the data you generate. The email address is exported along with the data, ' - 'when sharing it.' + 'when sharing it.', ) USER_FIRST_NAME = OverridableOption( @@ -292,7 +371,7 @@ def set_log_level(_ctx, _param, value): required=True, help='Engine to use to connect to the database.', default='postgresql_psycopg2', - type=click.Choice(['postgresql_psycopg2']) + type=click.Choice(['postgresql_psycopg2']), ) DB_BACKEND = OverridableOption( @@ -300,7 +379,7 @@ def set_log_level(_ctx, _param, value): required=True, type=click.Choice(['core.psql_dos']), default='core.psql_dos', - help='Database backend to use.' + help='Database backend to use.', ) DB_HOST = OverridableOption( @@ -308,7 +387,7 @@ def set_log_level(_ctx, _param, value): required=True, type=types.HostnameType(), help='Database server host. 
Leave empty for "peer" authentication.', - default='localhost' + default='localhost', ) DB_PORT = OverridableOption( @@ -337,7 +416,7 @@ def set_log_level(_ctx, _param, value): type=click.Choice(('amqp', 'amqps')), default=BROKER_DEFAULTS.protocol, show_default=True, - help='Protocol to use for the message broker.' + help='Protocol to use for the message broker.', ) BROKER_USERNAME = OverridableOption( @@ -345,7 +424,7 @@ def set_log_level(_ctx, _param, value): type=types.NonEmptyStringParamType(), default=BROKER_DEFAULTS.username, show_default=True, - help='Username to use for authentication with the message broker.' + help='Username to use for authentication with the message broker.', ) BROKER_PASSWORD = OverridableOption( @@ -362,7 +441,7 @@ def set_log_level(_ctx, _param, value): type=types.HostnameType(), default=BROKER_DEFAULTS.host, show_default=True, - help='Hostname for the message broker.' + help='Hostname for the message broker.', ) BROKER_PORT = OverridableOption( @@ -378,7 +457,7 @@ def set_log_level(_ctx, _param, value): type=click.types.StringParamType(), default=BROKER_DEFAULTS.virtual_host, show_default=True, - help='Name of the virtual host for the message broker without leading forward slash.' + help='Name of the virtual host for the message broker without leading forward slash.', ) REPOSITORY_PATH = OverridableOption( @@ -410,14 +489,14 @@ def set_log_level(_ctx, _param, value): metavar='DESCRIPTION', default='', required=False, - help='A detailed description.' + help='A detailed description.', ) INPUT_PLUGIN = OverridableOption( '-P', '--input-plugin', type=types.PluginParamType(group='calculations', load=False), - help='Calculation input plugin string.' + help='Calculation input plugin string.', ) CALC_JOB_STATE = OverridableOption( @@ -426,7 +505,7 @@ def set_log_level(_ctx, _param, value): 'calc_job_state', type=types.LazyChoice(valid_calc_job_states), cls=MultipleValueOption, - help='Only include entries with this calculation job state.' + help='Only include entries with this calculation job state.', ) PROCESS_STATE = OverridableOption( @@ -436,7 +515,7 @@ def set_log_level(_ctx, _param, value): type=types.LazyChoice(valid_process_states), cls=MultipleValueOption, default=active_process_states, - help='Only include entries with this process state.' + help='Only include entries with this process state.', ) PAUSED = OverridableOption('--paused', 'paused', is_flag=True, help='Only include entries that are paused.') @@ -447,7 +526,7 @@ def set_log_level(_ctx, _param, value): 'process_label', type=click.STRING, required=False, - help='Only include entries whose process label matches this filter.' + help='Only include entries whose process label matches this filter.', ) TYPE_STRING = OverridableOption( @@ -457,7 +536,7 @@ def set_log_level(_ctx, _param, value): type=click.STRING, required=False, help='Only include entries whose type string matches this filter. Can include `_` to match a single arbitrary ' - 'character or `%` to match any number of characters.' + 'character or `%` to match any number of characters.', ) EXIT_STATUS = OverridableOption( @@ -483,7 +562,7 @@ def set_log_level(_ctx, _param, value): type=click.Choice(['id', 'ctime']), default='ctime', show_default=True, - help='Order the entries by this attribute.' 
+ help='Order the entries by this attribute.', ) ORDER_DIRECTION = OverridableOption( @@ -493,7 +572,7 @@ def set_log_level(_ctx, _param, value): type=click.Choice(['asc', 'desc']), default='asc', show_default=True, - help='List the entries in ascending or descending order' + help='List the entries in ascending or descending order', ) PAST_DAYS = OverridableOption( @@ -502,7 +581,7 @@ def set_log_level(_ctx, _param, value): 'past_days', type=click.INT, metavar='PAST_DAYS', - help='Only include entries created in the last PAST_DAYS number of days.' + help='Only include entries created in the last PAST_DAYS number of days.', ) OLDER_THAN = OverridableOption( @@ -511,7 +590,7 @@ def set_log_level(_ctx, _param, value): 'older_than', type=click.INT, metavar='OLDER_THAN', - help='Only include entries created before OLDER_THAN days ago.' + help='Only include entries created before OLDER_THAN days ago.', ) ALL = OverridableOption( @@ -520,7 +599,7 @@ def set_log_level(_ctx, _param, value): 'all_entries', is_flag=True, default=False, - help='Include all entries, disregarding all other filter options and flags.' + help='Include all entries, disregarding all other filter options and flags.', ) ALL_STATES = OverridableOption('-A', '--all-states', is_flag=True, help='Do not limit to items in running state.') @@ -539,7 +618,7 @@ def set_log_level(_ctx, _param, value): 'raw', is_flag=True, default=False, - help='Display only raw query results, without any headers or footers.' + help='Display only raw query results, without any headers or footers.', ) HOSTNAME = OverridableOption('-H', '--hostname', type=types.HostnameType(), help='Hostname.') @@ -549,7 +628,7 @@ def set_log_level(_ctx, _param, value): '--transport', type=types.PluginParamType(group='transports'), required=True, - help='A transport plugin (as listed in `verdi plugin list aiida.transports`).' + help='A transport plugin (as listed in `verdi plugin list aiida.transports`).', ) SCHEDULER = OverridableOption( @@ -557,7 +636,7 @@ def set_log_level(_ctx, _param, value): '--scheduler', type=types.PluginParamType(group='schedulers'), required=True, - help='A scheduler plugin (as listed in `verdi plugin list aiida.schedulers`).' + help='A scheduler plugin (as listed in `verdi plugin list aiida.schedulers`).', ) USER = OverridableOption('-u', '--user', 'user', type=types.UserParamType(), help='Email address of the user.') @@ -572,13 +651,13 @@ def set_log_level(_ctx, _param, value): type=click.FLOAT, default=5.0, show_default=True, - help='Time in seconds to wait for a response before timing out.' + help='Time in seconds to wait for a response before timing out.', ) WAIT = OverridableOption( '--wait/--no-wait', default=False, - help='Wait for the action to be completed otherwise return as soon as it is scheduled.' + help='Wait for the action to be completed otherwise return as soon as it is scheduled.', ) FORMULA_MODE = OverridableOption( @@ -586,7 +665,7 @@ def set_log_level(_ctx, _param, value): '--formula-mode', type=click.Choice(['hill', 'hill_compact', 'reduce', 'group', 'count', 'count_compact']), default='hill', - help='Mode for printing the chemical formula.' + help='Mode for printing the chemical formula.', ) TRAJECTORY_INDEX = OverridableOption( @@ -595,7 +674,7 @@ def set_log_level(_ctx, _param, value): 'trajectory_index', type=click.INT, default=None, - help='Specific step of the Trajectory to select.' 
+ help='Specific step of the Trajectory to select.', ) WITH_ELEMENTS = OverridableOption( @@ -605,7 +684,7 @@ def set_log_level(_ctx, _param, value): type=click.STRING, cls=MultipleValueOption, default=None, - help='Only select objects containing these elements.' + help='Only select objects containing these elements.', ) WITH_ELEMENTS_EXCLUSIVE = OverridableOption( @@ -615,13 +694,13 @@ def set_log_level(_ctx, _param, value): type=click.STRING, cls=MultipleValueOption, default=None, - help='Only select objects containing only these and no other elements.' + help='Only select objects containing only these and no other elements.', ) CONFIG_FILE = ConfigFileOption( '--config', type=types.FileOrUrl(), - help='Load option values from configuration file in yaml format (local path or URL).' + help='Load option values from configuration file in yaml format (local path or URL).', ) IDENTIFIER = OverridableOption( @@ -630,7 +709,7 @@ def set_log_level(_ctx, _param, value): 'identifier', help='The type of identifier used for specifying each node.', default='pk', - type=click.Choice(['pk', 'uuid']) + type=click.Choice(['pk', 'uuid']), ) DICT_FORMAT = OverridableOption( @@ -638,8 +717,8 @@ def set_log_level(_ctx, _param, value): '--format', 'fmt', type=click.Choice(list(echo.VALID_DICT_FORMATS_MAPPING.keys())), - default=list(echo.VALID_DICT_FORMATS_MAPPING.keys())[0], - help='The format of the output data.' + default=next(iter(echo.VALID_DICT_FORMATS_MAPPING.keys())), + help='The format of the output data.', ) DICT_KEYS = OverridableOption( diff --git a/aiida/cmdline/params/options/multivalue.py b/aiida/cmdline/params/options/multivalue.py index e9f8968144..535e8d3521 100644 --- a/aiida/cmdline/params/options/multivalue.py +++ b/aiida/cmdline/params/options/multivalue.py @@ -7,8 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -""" -Module to define multi value options for click. +"""Module to define multi value options for click. """ import click @@ -38,8 +37,7 @@ def collect_usage_pieces(self, ctx): class MultipleValueOption(click.Option): - """ - An option that can handle multiple values with a single flag. For example:: + """An option that can handle multiple values with a single flag. For example:: @click.option('-n', '--nodes', cls=MultipleValueOption) @@ -62,22 +60,18 @@ def __init__(self, *args, **kwargs): self._eat_all_parser = None def add_to_parser(self, parser, ctx): - """ - Override built in click method that allows us to specify a custom parser + """Override built in click method that allows us to specify a custom parser to eat up parameters until the following flag or 'endopt' (i.e. 
--)
        """
-        # pylint: disable=protected-access
        super().add_to_parser(parser, ctx)

        def parser_process(value, state):
-            """
-            The actual function that parses the options
+            """The actual function that parses the options

            :param value: The value to parse
            :param state: The state of the parser
            """
-            # pylint: disable=invalid-name
-            ENDOPTS = '--'
+            ENDOPTS = '--'  # noqa: N806
            done = False
            value = [value]
diff --git a/aiida/cmdline/params/options/overridable.py b/aiida/cmdline/params/options/overridable.py
index fae2ca0aff..b874565aa5 100644
--- a/aiida/cmdline/params/options/overridable.py
+++ b/aiida/cmdline/params/options/overridable.py
@@ -7,12 +7,7 @@
 # For further information on the license, see the LICENSE.txt file      #
 # For further information please visit http://www.aiida.net             #
 ###########################################################################
-# pylint: disable=cyclic-import
-"""
-.. py:module::overridable
-    :synopsis: Convenience class which can be used to defined a set of commonly used options that
-    can be easily reused and which improves consistency across the command line interface
-"""
+"""Convenience class which can be used to define a set of commonly used options that can be easily reused."""

 import click

@@ -20,8 +15,7 @@

 class OverridableOption:
-    """
-    Wrapper around click option that increases reusability
+    """Wrapper around click option that increases reusability

     Click options are reusable already but sometimes it can improve the user interface to for example customize a
     help message for an option on a per-command basis. Sometimes the option should be prompted for if it is not given
@@ -43,11 +37,11 @@
         def ls_or_create(folder):

         @FOLDER(help='An existing folder', type=click.Path(exists=True, file_okay=False, readable=True))
         def ls(folder):
             click.echo(os.listdir(folder))
+
     """

     def __init__(self, *args, **kwargs):
-        """
-        Store the default args and kwargs.
+        """Store the default args and kwargs.

         :param args: default arguments to be used for the click option
         :param kwargs: default keyword arguments to be used that can be overridden in the call
@@ -56,8 +50,7 @@
         self.kwargs = kwargs

     def __call__(self, **kwargs):
-        """
-        Override the stored kwargs, (ignoring args as we do not allow option name changes) and return the option.
+        """Override the stored kwargs (ignoring args as we do not allow option name changes) and return the option.

         :param kwargs: keyword arguments that will override those set in the construction
         :return: click option constructed with args and kwargs defined during construction and call of this instance
@@ -67,8 +60,7 @@
         return click.option(*self.args, **kw_copy)

     def clone(self, **kwargs):
-        """
-        Create a new instance of the OverridableOption by cloning it and updating the stored kwargs with those passed.
+        """Create a new instance by cloning the current instance and updating the stored kwargs with those passed.

         This can be useful when an already predefined OverridableOption needs to be further specified and reused by a
         set of sub commands.
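A minimal usage sketch for the MultipleValueOption hunk above, assuming click and aiida-core are installed; the command and option names are hypothetical. Values are consumed greedily until the next flag or the `--` end-of-options marker::

    import click

    from aiida.cmdline.params.options.multivalue import MultipleValueOption

    @click.command()
    @click.option('-n', '--nodes', cls=MultipleValueOption, help='One or more node identifiers.')
    @click.option('-v', '--verbose', is_flag=True)
    def show(nodes, verbose):
        # Invoked as `show -n 10 11 12 -v`, parsing stops at `-v`,
        # so `nodes` receives ('10', '11', '12') and `verbose` is True.
        click.echo(f'nodes={nodes}, verbose={verbose}')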
Example:: @@ -84,6 +76,7 @@ def clone(self, **kwargs): :return: OverridableOption instance with stored keyword arguments updated """ import copy + clone = copy.deepcopy(self) clone.kwargs.update(kwargs) return clone diff --git a/aiida/cmdline/params/types/__init__.py b/aiida/cmdline/params/types/__init__.py index 4607b6dcbe..22cce35f44 100644 --- a/aiida/cmdline/params/types/__init__.py +++ b/aiida/cmdline/params/types/__init__.py @@ -11,8 +11,7 @@ # AUTO-GENERATED -# yapf: disable -# pylint: disable=wildcard-import +# fmt: off from .calculation import * from .choice import * @@ -60,4 +59,4 @@ 'WorkflowParamType', ) -# yapf: enable +# fmt: on diff --git a/aiida/cmdline/params/types/calculation.py b/aiida/cmdline/params/types/calculation.py index 2e4c0d0750..31cbeca722 100644 --- a/aiida/cmdline/params/types/calculation.py +++ b/aiida/cmdline/params/types/calculation.py @@ -7,8 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -""" -Module for the calculation parameter type +"""Module for the calculation parameter type """ from .identifier import IdentifierParamType @@ -17,19 +16,17 @@ class CalculationParamType(IdentifierParamType): - """ - The ParamType for identifying Calculation entities or its subclasses - """ + """The ParamType for identifying Calculation entities or its subclasses""" name = 'Calculation' @property def orm_class_loader(self): - """ - Return the orm entity loader class, which should be a subclass of OrmEntityLoader. This class is supposed + """Return the orm entity loader class, which should be a subclass of OrmEntityLoader. This class is supposed to be used to load the entity for a given identifier :return: the orm entity loader class for this ParamType """ from aiida.orm.utils.loaders import CalculationEntityLoader + return CalculationEntityLoader diff --git a/aiida/cmdline/params/types/choice.py b/aiida/cmdline/params/types/choice.py index c3d1ead2a0..8e1a20b3f5 100644 --- a/aiida/cmdline/params/types/choice.py +++ b/aiida/cmdline/params/types/choice.py @@ -7,8 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -""" -A custom click type that defines a lazy choice +"""A custom click type that defines a lazy choice """ import click @@ -16,8 +15,7 @@ class LazyChoice(click.ParamType): - """ - This is a delegate of click's Choice ParamType that evaluates the set of choices + """This is a delegate of click's Choice ParamType that evaluates the set of choices lazily. This is useful if the choices set requires an import that is slow. Using the vanilla click.Choice will call this on import which will slow down verdi and its autocomplete. This type will generate the choices set lazily through the @@ -37,8 +35,7 @@ def __init__(self, get_choices): @property def _click_choice(self): - """ - Get the internal click Choice object that we delegate functionality to. + """Get the internal click Choice object that we delegate functionality to. Will construct it lazily if necessary. 
:return: The click Choice diff --git a/aiida/cmdline/params/types/code.py b/aiida/cmdline/params/types/code.py index 32367323ff..014c14d1c8 100644 --- a/aiida/cmdline/params/types/code.py +++ b/aiida/cmdline/params/types/code.py @@ -18,9 +18,7 @@ class CodeParamType(IdentifierParamType): - """ - The ParamType for identifying Code entities or its subclasses - """ + """The ParamType for identifying Code entities or its subclasses""" name = 'Code' @@ -35,24 +33,24 @@ def __init__(self, sub_classes=None, entry_point=None): @property def orm_class_loader(self): - """ - Return the orm entity loader class, which should be a subclass of OrmEntityLoader. This class is supposed + """Return the orm entity loader class, which should be a subclass of OrmEntityLoader. This class is supposed to be used to load the entity for a given identifier :return: the orm entity loader class for this ParamType """ from aiida.orm.utils.loaders import CodeEntityLoader + return CodeEntityLoader @decorators.with_dbenv() - def shell_complete(self, ctx, param, incomplete): # pylint: disable=unused-argument + def shell_complete(self, ctx, param, incomplete): """Return possible completions based on an incomplete value. :returns: list of tuples of valid entry points (matching incomplete) and a description """ return [ click.shell_completion.CompletionItem(option) - for option, in self.orm_class_loader.get_options(incomplete, project='label') + for (option,) in self.orm_class_loader.get_options(incomplete, project='label') ] def convert(self, value, param, ctx): diff --git a/aiida/cmdline/params/types/computer.py b/aiida/cmdline/params/types/computer.py index 97dcfdc2a0..a1fb2e9333 100644 --- a/aiida/cmdline/params/types/computer.py +++ b/aiida/cmdline/params/types/computer.py @@ -7,49 +7,45 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -""" -Module for the custom click param type computer +"""Module for the custom click param type computer """ from click.shell_completion import CompletionItem from click.types import StringParamType -from ...utils import decorators # pylint: disable=no-name-in-module +from ...utils import decorators from .identifier import IdentifierParamType __all__ = ('ComputerParamType', 'ShebangParamType', 'MpirunCommandParamType') class ComputerParamType(IdentifierParamType): - """ - The ParamType for identifying Computer entities or its subclasses - """ + """The ParamType for identifying Computer entities or its subclasses""" name = 'Computer' @property def orm_class_loader(self): - """ - Return the orm entity loader class, which should be a subclass of OrmEntityLoader. This class is supposed + """Return the orm entity loader class, which should be a subclass of OrmEntityLoader. This class is supposed to be used to load the entity for a given identifier :return: the orm entity loader class for this ParamType """ from aiida.orm.utils.loaders import ComputerEntityLoader + return ComputerEntityLoader @decorators.with_dbenv() - def shell_complete(self, ctx, param, incomplete): # pylint: disable=unused-argument + def shell_complete(self, ctx, param, incomplete): """Return possible completions based on an incomplete value. 
:returns: list of tuples of valid entry points (matching incomplete) and a description """ - return [CompletionItem(option) for option, in self.orm_class_loader.get_options(incomplete, project='label')] + return [CompletionItem(option) for (option,) in self.orm_class_loader.get_options(incomplete, project='label')] class ShebangParamType(StringParamType): - """ - Custom click param type for shebang line - """ + """Custom click param type for shebang line""" + name = 'shebangline' def convert(self, value, param, ctx): @@ -65,8 +61,7 @@ def __repr__(self): class MpirunCommandParamType(StringParamType): - """ - Custom click param type for mpirun-command + """Custom click param type for mpirun-command .. note:: requires also a scheduler to be provided, and the scheduler must be called first! @@ -76,6 +71,7 @@ class MpirunCommandParamType(StringParamType): Return a list of arguments (by using 'value.strip().split(" ") on the input string) """ + name = 'mpiruncommandstring' def __repr__(self): diff --git a/aiida/cmdline/params/types/config.py b/aiida/cmdline/params/types/config.py index 195516554c..231525bb71 100644 --- a/aiida/cmdline/params/types/config.py +++ b/aiida/cmdline/params/types/config.py @@ -26,9 +26,8 @@ def convert(self, value, param, ctx): return get_option(value) - def shell_complete(self, ctx, param, incomplete): # pylint: disable=unused-argument - """ - Return possible completions based on an incomplete value + def shell_complete(self, ctx, param, incomplete): + """Return possible completions based on an incomplete value :returns: list of tuples of valid entry points (matching incomplete) and a description """ diff --git a/aiida/cmdline/params/types/data.py b/aiida/cmdline/params/types/data.py index 742dec10eb..c74af032f3 100644 --- a/aiida/cmdline/params/types/data.py +++ b/aiida/cmdline/params/types/data.py @@ -7,8 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -""" -Module for the custom click param type for data +"""Module for the custom click param type for data """ from .identifier import IdentifierParamType @@ -16,19 +15,17 @@ class DataParamType(IdentifierParamType): - """ - The ParamType for identifying Data entities or its subclasses - """ + """The ParamType for identifying Data entities or its subclasses""" name = 'Data' @property def orm_class_loader(self): - """ - Return the orm entity loader class, which should be a subclass of OrmEntityLoader. This class is supposed + """Return the orm entity loader class, which should be a subclass of OrmEntityLoader. This class is supposed to be used to load the entity for a given identifier :return: the orm entity loader class for this ParamType """ from aiida.orm.utils.loaders import DataEntityLoader + return DataEntityLoader diff --git a/aiida/cmdline/params/types/group.py b/aiida/cmdline/params/types/group.py index fe55c7694c..5f2f91db28 100644 --- a/aiida/cmdline/params/types/group.py +++ b/aiida/cmdline/params/types/group.py @@ -56,17 +56,18 @@ def orm_class_loader(self): :return: the orm entity loader class for this `ParamType` """ from aiida.orm.utils.loaders import GroupEntityLoader + return GroupEntityLoader @decorators.with_dbenv() - def shell_complete(self, ctx, param, incomplete): # pylint: disable=unused-argument + def shell_complete(self, ctx, param, incomplete): """Return possible completions based on an incomplete value. 
:returns: list of tuples of valid entry points (matching incomplete) and a description """ return [ click.shell_completion.CompletionItem(option) - for option, in self.orm_class_loader.get_options(incomplete, project='label') + for (option,) in self.orm_class_loader.get_options(incomplete, project='label') ] @decorators.with_dbenv() diff --git a/aiida/cmdline/params/types/identifier.py b/aiida/cmdline/params/types/identifier.py index 17358d5763..065db0d1f2 100644 --- a/aiida/cmdline/params/types/identifier.py +++ b/aiida/cmdline/params/types/identifier.py @@ -7,14 +7,13 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -""" -Module for custom click param type identifier +"""Module for custom click param type identifier """ from __future__ import annotations +import typing as t from abc import ABC, abstractmethod from functools import cached_property -import typing as t import click @@ -30,8 +29,7 @@ class IdentifierParamType(click.ParamType, ABC): - """ - An extension of click.ParamType for a generic identifier parameter. In AiiDA, orm entities can often be + """An extension of click.ParamType for a generic identifier parameter. In AiiDA, orm entities can often be identified by either their ID, UUID or optionally some LABEL identifier. This parameter type implements the convert method, which attempts to convert a value passed to the command for a parameter with this type, to an orm entity. The actual loading of the entity is delegated to the orm class loader. Subclasses of this @@ -40,8 +38,7 @@ class IdentifierParamType(click.ParamType, ABC): """ def __init__(self, sub_classes: tuple[str, ...] | None = None): - """ - Construct the parameter type, optionally specifying a tuple of entry points that reference classes + """Construct the parameter type, optionally specifying a tuple of entry points that reference classes that should be a sub class of the base orm class of the orm class loader. The classes pointed to by these entry points will be passed to the OrmEntityLoader when converting an identifier and they will restrict the query set by demanding that the class of the corresponding entity matches these sub classes. @@ -82,8 +79,7 @@ def _entry_points(self) -> list[EntryPoint]: @abstractmethod @with_dbenv() # type: ignore[misc] def orm_class_loader(self) -> OrmEntityLoader: - """ - Return the orm entity loader class, which should be a subclass of OrmEntityLoader. This class is supposed + """Return the orm entity loader class, which should be a subclass of OrmEntityLoader. This class is supposed to be used to load the entity for a given identifier :return: the orm entity loader class for this ParamType @@ -91,8 +87,7 @@ def orm_class_loader(self) -> OrmEntityLoader: @with_dbenv() # type: ignore[misc] def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context) -> t.Any: - """ - Attempt to convert the given value to an instance of the orm class using the orm class loader. + """Attempt to convert the given value to an instance of the orm class using the orm class loader. :return: the loaded orm entity :raises click.BadParameter: if the value is ambiguous and leads to multiple entities @@ -116,7 +111,6 @@ def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Contex # sub classes of the orm class loader and then pass it as the sub_class parameter to the load_entity call. 
# We store the loaded entry points in an instance variable, such that the loading only has to be done once. if self._entry_points and self._sub_classes is None: - sub_classes = [] for entry_point in self._entry_points: diff --git a/aiida/cmdline/params/types/multiple.py b/aiida/cmdline/params/types/multiple.py index 60154d3cc1..9b38fa4205 100644 --- a/aiida/cmdline/params/types/multiple.py +++ b/aiida/cmdline/params/types/multiple.py @@ -7,8 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -""" -Module to define custom click param type for multiple values +"""Module to define custom click param type for multiple values """ import click @@ -16,9 +15,7 @@ class MultipleValueParamType(click.ParamType): - """ - An extension of click.ParamType that can parse multiple values for a given ParamType - """ + """An extension of click.ParamType that can parse multiple values for a given ParamType""" def __init__(self, param_type): """Construct a new instance.""" diff --git a/aiida/cmdline/params/types/node.py b/aiida/cmdline/params/types/node.py index 7642eb22d5..4c26eb359a 100644 --- a/aiida/cmdline/params/types/node.py +++ b/aiida/cmdline/params/types/node.py @@ -7,8 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -""" -Module to define the custom click param type for node +"""Module to define the custom click param type for node """ from .identifier import IdentifierParamType @@ -16,19 +15,17 @@ class NodeParamType(IdentifierParamType): - """ - The ParamType for identifying Node entities or its subclasses - """ + """The ParamType for identifying Node entities or its subclasses""" name = 'Node' @property def orm_class_loader(self): - """ - Return the orm entity loader class, which should be a subclass of OrmEntityLoader. This class is supposed + """Return the orm entity loader class, which should be a subclass of OrmEntityLoader. This class is supposed to be used to load the entity for a given identifier :return: the orm entity loader class for this ParamType """ from aiida.orm.utils.loaders import NodeEntityLoader + return NodeEntityLoader diff --git a/aiida/cmdline/params/types/path.py b/aiida/cmdline/params/types/path.py index 2d1e03196b..4ce946473f 100644 --- a/aiida/cmdline/params/types/path.py +++ b/aiida/cmdline/params/types/path.py @@ -129,6 +129,6 @@ def get_url(self, url, param, ctx): import urllib.request try: - return urllib.request.urlopen(url, timeout=self.timeout_seconds) # pylint: disable=consider-using-with + return urllib.request.urlopen(url, timeout=self.timeout_seconds) except (urllib.error.URLError, urllib.error.HTTPError, socket.timeout): self.fail(f'{self.name} "{url}" could not be reached within {self.timeout_seconds} s.\n', param, ctx) diff --git a/aiida/cmdline/params/types/plugin.py b/aiida/cmdline/params/types/plugin.py index eecc6b0755..b411f12694 100644 --- a/aiida/cmdline/params/types/plugin.py +++ b/aiida/cmdline/params/types/plugin.py @@ -37,8 +37,7 @@ class PluginParamType(EntryPointType): - """ - AiiDA Plugin name parameter type. + """AiiDA Plugin name parameter type. :param group: string or tuple of strings, where each is a valid entry point group. Adding the `aiida.` prefix is optional. 
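As a usage sketch for PluginParamType (assuming the standard `calculations` entry point group; the command itself is hypothetical)::

    import click

    from aiida.cmdline.params.types.plugin import PluginParamType

    @click.command()
    @click.option('-P', '--plugin', type=PluginParamType(group='calculations', load=True))
    def info(plugin):
        # With load=True the converted value is the loaded plugin class;
        # with load=False it would be the matching EntryPoint instead.
        click.echo(str(plugin))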
If it is not detected it will be prepended internally.
@@ -53,6 +52,7 @@

         click.option(... type=PluginParamType(group=('calculations', 'data'))

     """
+
     name = 'plugin'

     _factory_mapping = {
@@ -68,11 +68,9 @@
     }

     def __init__(self, group: str | tuple[str] | None = None, load: bool = False, *args, **kwargs):
-        """
-        group should be either a string or a tuple of valid entry point groups.
+        """Group should be either a string or a tuple of valid entry point groups.

         If it is not specified we use the tuple of all recognized entry point groups.
         """
-        # pylint: disable=keyword-arg-before-vararg
         self.load = load

         self._input_group = group
@@ -81,7 +79,6 @@
     @functools.cached_property
     def groups(self) -> tuple[str, ...]:
         """Returns a tuple of valid groups for this instance"""
-
         group = self._input_group

         valid_entry_point_groups = get_entry_point_groups()
@@ -98,9 +95,8 @@
             groups = []

             for grp in unvalidated_groups:
-
                 if not grp.startswith(ENTRY_POINT_GROUP_PREFIX):
-                    grp = ENTRY_POINT_GROUP_PREFIX + grp
+                    grp = ENTRY_POINT_GROUP_PREFIX + grp  # noqa: PLW2901

                 if grp not in valid_entry_point_groups:
                     raise ValueError(f'entry point group {grp} is not recognized')
@@ -119,15 +115,13 @@ def _entry_point_names(self) -> list[str]:

     @property
     def has_potential_ambiguity(self) -> bool:
-        """
-        Returns whether the set of supported entry point groups can lead to ambiguity when only an entry point name
+        """Returns whether the set of supported entry point groups can lead to ambiguity when only an entry point name
         is specified. This will happen if one or more groups share an entry point with a common name
         """
         return len(self._entry_point_names) != len(set(self._entry_point_names))

     def get_valid_arguments(self) -> list[str]:
-        """
-        Return a list of all available plugin names for the groups configured for this PluginParamType instance.
+        """Return a list of all available plugin names for the groups configured for this PluginParamType instance.
If the entry point names are not unique, because there are multiple groups that contain an entry point that has an identical name, we need to prefix the names with the full group name @@ -140,9 +134,7 @@ def get_valid_arguments(self) -> list[str]: return sorted(self._entry_point_names) def get_possibilities(self, incomplete: str = '') -> list[str]: - """ - Return a list of plugins starting with incomplete - """ + """Return a list of plugins starting with incomplete""" if incomplete == '': return self.get_valid_arguments() @@ -166,31 +158,27 @@ def get_possibilities(self, incomplete: str = '') -> list[str]: def shell_complete( self, ctx: click.Context | None, param: click.Parameter | None, incomplete: str - ) -> list[click.shell_completion.CompletionItem]: # pylint: disable=unused-argument - """ - Return possible completions based on an incomplete value + ) -> list[click.shell_completion.CompletionItem]: + """Return possible completions based on an incomplete value :returns: list of tuples of valid entry points (matching incomplete) and a description """ return [click.shell_completion.CompletionItem(p) for p in self.get_possibilities(incomplete=incomplete)] - def get_missing_message(self, param: click.Parameter) -> str: # pylint: disable=unused-argument + def get_missing_message(self, param: click.Parameter) -> str: return 'Possible arguments are:\n\n' + '\n'.join(self.get_valid_arguments()) def get_entry_point_from_string(self, entry_point_string: str) -> EntryPoint: - """ - Validate a given entry point string, which means that it should have a valid entry point string format + """Validate a given entry point string, which means that it should have a valid entry point string format and that the entry point unambiguously corresponds to an entry point in the groups configured for this instance of PluginParameterType. 
:returns: the entry point if valid :raises: ValueError if the entry point string is invalid """ - entry_point_format = get_entry_point_string_format(entry_point_string) if entry_point_format in (EntryPointFormat.FULL, EntryPointFormat.PARTIAL): - group, name = entry_point_string.split(ENTRY_POINT_STRING_SEPARATOR) if entry_point_format == EntryPointFormat.PARTIAL: @@ -199,7 +187,6 @@ def get_entry_point_from_string(self, entry_point_string: str) -> EntryPoint: self.validate_entry_point_group(group) elif entry_point_format == EntryPointFormat.MINIMAL: - name = entry_point_string matching_groups = {group for group, entry_point in self._entry_points if entry_point.name == name} @@ -212,8 +199,9 @@ def get_entry_point_from_string(self, entry_point_string: str) -> EntryPoint: ) elif not matching_groups: raise ValueError( - "entry point '{}' is not valid for any of the allowed " - 'entry point groups: {}'.format(name, ' '.join(self.groups)) + "entry point '{}' is not valid for any of the allowed " 'entry point groups: {}'.format( + name, ' '.join(self.groups) + ) ) group = matching_groups.pop() @@ -230,10 +218,10 @@ def validate_entry_point_group(self, group: str) -> None: if group not in self.groups: raise ValueError(f'entry point group `{group}` is not supported by this parameter.') - def convert(self, value: t.Any, param: click.Parameter | None, - ctx: click.Context | None) -> t.Union[EntryPoint, t.Any]: - """ - Convert the string value to an entry point instance, if the value can be successfully parsed + def convert( + self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None + ) -> t.Union[EntryPoint, t.Any]: + """Convert the string value to an entry point instance, if the value can be successfully parsed into an actual entry point. Will raise click.BadParameter if validation fails. """ from importlib_metadata import EntryPoint diff --git a/aiida/cmdline/params/types/process.py b/aiida/cmdline/params/types/process.py index 0cbe5abf65..52319e7243 100644 --- a/aiida/cmdline/params/types/process.py +++ b/aiida/cmdline/params/types/process.py @@ -7,8 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -""" -Module for the process node parameter type +"""Module for the process node parameter type """ from .identifier import IdentifierParamType @@ -17,19 +16,17 @@ class ProcessParamType(IdentifierParamType): - """ - The ParamType for identifying ProcessNode entities or its subclasses - """ + """The ParamType for identifying ProcessNode entities or its subclasses""" name = 'Process' @property def orm_class_loader(self): - """ - Return the orm entity loader class, which should be a subclass of OrmEntityLoader. This class is supposed + """Return the orm entity loader class, which should be a subclass of OrmEntityLoader. 
This class is supposed to be used to load the entity for a given identifier :return: the orm entity loader class for this ParamType """ from aiida.orm.utils.loaders import ProcessEntityLoader + return ProcessEntityLoader diff --git a/aiida/cmdline/params/types/profile.py b/aiida/cmdline/params/types/profile.py index 562bbcaefa..ecd06bb901 100644 --- a/aiida/cmdline/params/types/profile.py +++ b/aiida/cmdline/params/types/profile.py @@ -73,7 +73,7 @@ def convert(self, value, param, ctx): return profile - def shell_complete(self, ctx, param, incomplete): # pylint: disable=unused-argument + def shell_complete(self, ctx, param, incomplete): """Return possible completions based on an incomplete value :returns: list of tuples of valid entry points (matching incomplete) and a description diff --git a/aiida/cmdline/params/types/strings.py b/aiida/cmdline/params/types/strings.py index a81b1bceab..47cd72535a 100644 --- a/aiida/cmdline/params/types/strings.py +++ b/aiida/cmdline/params/types/strings.py @@ -7,8 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -""" -Module for various text-based string validation. +"""Module for various text-based string validation. """ import re @@ -20,6 +19,7 @@ class NonEmptyStringParamType(StringParamType): """Parameter whose values have to be string and non-empty.""" + name = 'nonemptystring' def convert(self, value, param, ctx): @@ -45,6 +45,7 @@ class LabelStringType(NonEmptyStringParamType): [1] See https://docs.python.org/3/library/re.html """ + name = 'labelstring' ALPHABET = r'\w\.\-' @@ -61,8 +62,10 @@ def __repr__(self): return 'LABELSTRING' -HOSTNAME_REGEX = \ -r'^([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])(\.([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]{0,61}[a-zA-Z0-9]))*$' +HOSTNAME_REGEX = re.compile( + r'^([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])' + r'(\.([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]{0,61}[a-zA-Z0-9]))*$' +) class HostnameType(StringParamType): @@ -70,12 +73,13 @@ class HostnameType(StringParamType): Regex according to https://stackoverflow.com/a/3824105/1069467 """ + name = 'hostname' def convert(self, value, param, ctx): newval = super().convert(value, param, ctx) - if newval and not re.match(HOSTNAME_REGEX, newval): + if newval and not HOSTNAME_REGEX.match(newval): self.fail('Please enter a valid hostname.') return newval @@ -89,6 +93,7 @@ class EmailType(StringParamType): .. note:: For the moment, we do not require the domain suffix, i.e. 'aiida@localhost' is still valid. """ + name = 'email' def convert(self, value, param, ctx): @@ -108,6 +113,7 @@ class EntryPointType(NonEmptyStringParamType): See https://packaging.python.org/en/latest/specifications/entry-points/ """ + name = 'entrypoint' def convert(self, value, param, ctx): diff --git a/aiida/cmdline/params/types/user.py b/aiida/cmdline/params/types/user.py index 69e9c24c30..f4e5963c5f 100644 --- a/aiida/cmdline/params/types/user.py +++ b/aiida/cmdline/params/types/user.py @@ -16,15 +16,12 @@ class UserParamType(click.ParamType): - """ - The user parameter type for click. Can get or create a user. - """ + """The user parameter type for click. Can get or create a user.""" + name = 'user' def __init__(self, create=False): - """ - :param create: If the user does not exist, create a new instance (unstored). 
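The precompiled HOSTNAME_REGEX introduced above in `aiida/cmdline/params/types/strings.py` can be sanity-checked in isolation; the sample hostnames are illustrative::

    import re

    HOSTNAME_REGEX = re.compile(
        r'^([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])'
        r'(\.([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]{0,61}[a-zA-Z0-9]))*$'
    )

    assert HOSTNAME_REGEX.match('localhost')
    assert HOSTNAME_REGEX.match('compute-01.example.com')
    assert HOSTNAME_REGEX.match('-invalid-.example.com') is None  # labels cannot start or end with '-'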
- """ + """:param create: If the user does not exist, create a new instance (unstored).""" self._create = create @with_dbenv() @@ -45,9 +42,8 @@ def convert(self, value, param, ctx): return results[0] @with_dbenv() - def shell_complete(self, ctx, param, incomplete): # pylint: disable=unused-argument - """ - Return possible completions based on an incomplete value + def shell_complete(self, ctx, param, incomplete): + """Return possible completions based on an incomplete value :returns: list of tuples of valid entry points (matching incomplete) and a description """ diff --git a/aiida/cmdline/params/types/workflow.py b/aiida/cmdline/params/types/workflow.py index 7403ff99f7..f4cc3d0874 100644 --- a/aiida/cmdline/params/types/workflow.py +++ b/aiida/cmdline/params/types/workflow.py @@ -7,8 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -""" -Module for the workflow parameter type +"""Module for the workflow parameter type """ from .identifier import IdentifierParamType @@ -17,19 +16,17 @@ class WorkflowParamType(IdentifierParamType): - """ - The ParamType for identifying WorkflowNode entities or its subclasses - """ + """The ParamType for identifying WorkflowNode entities or its subclasses""" name = 'WorkflowNode' @property def orm_class_loader(self): - """ - Return the orm entity loader class, which should be a subclass of OrmEntityLoader. This class is supposed + """Return the orm entity loader class, which should be a subclass of OrmEntityLoader. This class is supposed to be used to load the entity for a given identifier :return: the orm entity loader class for this ParamType """ from aiida.orm.utils.loaders import WorkflowEntityLoader + return WorkflowEntityLoader diff --git a/aiida/cmdline/utils/__init__.py b/aiida/cmdline/utils/__init__.py index 31fe49878c..c6e62c925c 100644 --- a/aiida/cmdline/utils/__init__.py +++ b/aiida/cmdline/utils/__init__.py @@ -12,8 +12,7 @@ # AUTO-GENERATED -# yapf: disable -# pylint: disable=wildcard-import +# fmt: off from .ascii_vis import * from .common import * @@ -36,4 +35,4 @@ 'with_dbenv', ) -# yapf: enable +# fmt: on diff --git a/aiida/cmdline/utils/ascii_vis.py b/aiida/cmdline/utils/ascii_vis.py index c8e5751776..8b75edce13 100644 --- a/aiida/cmdline/utils/ascii_vis.py +++ b/aiida/cmdline/utils/ascii_vis.py @@ -8,6 +8,8 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Utility functions to draw ASCII diagrams to the command line.""" +from typing import Optional + __all__ = ('format_call_graph',) TREE_LAST_ENTRY = '\u2514\u2500\u2500 ' @@ -32,11 +34,12 @@ def calc_info(node, call_link_label: bool = False) -> str: if call_link_label and (caller := node.caller): from aiida.common.links import LinkType - call_link = [ + + call_link = next( triple.link_label for triple in caller.base.links.get_outgoing(link_type=(LinkType.CALL_CALC, LinkType.CALL_WORK)).all() if triple.node.pk == node.pk - ][0] + ) else: call_link = None @@ -54,7 +57,7 @@ def calc_info(node, call_link_label: bool = False) -> str: return string -def format_call_graph(calc_node, max_depth: int = None, call_link_label: bool = False, info_fn=calc_info): +def format_call_graph(calc_node, max_depth: Optional[int] = None, call_link_label: bool = False, info_fn=calc_info): """Print a tree like the POSIX tree command for the calculation 
call graph. :param calc_node: The calculation node @@ -67,7 +70,9 @@ def format_call_graph(calc_node, max_depth: int = None, call_link_label: bool = return format_tree_descending(call_tree) -def build_call_graph(calc_node, max_depth: int = None, call_link_label: bool = False, info_fn=calc_info) -> str: +def build_call_graph( + calc_node, max_depth: Optional[int] = None, call_link_label: bool = False, info_fn=calc_info +) -> str: """Build the call graph of a given node. :param calc_node: The calculation node @@ -97,7 +102,6 @@ def build_call_graph(calc_node, max_depth: int = None, call_link_label: bool = F def format_tree_descending(tree, prefix='', pos=-1): """Format a descending tree.""" - # pylint: disable=too-many-branches text = [] if isinstance(tree, tuple): diff --git a/aiida/cmdline/utils/common.py b/aiida/cmdline/utils/common.py index cea7f22621..d05c15ba31 100644 --- a/aiida/cmdline/utils/common.py +++ b/aiida/cmdline/utils/common.py @@ -27,6 +27,7 @@ def tabulate(table, **kwargs): """A dummy wrapper to hide the import cost of tabulate""" import tabulate as tb + return tb.tabulate(table, **kwargs) @@ -49,7 +50,7 @@ def get_env_with_venv_bin(): warn_deprecation( '`get_env_with_venv_bin` function is deprecated use `aiida.engine.daemon.client.DaemonClient.get_env` instead.', - version=3 + version=3, ) config = get_config() @@ -63,8 +64,7 @@ def get_env_with_venv_bin(): def format_local_time(timestamp, format_str='%Y-%m-%d %H:%M:%S'): - """ - Format a datetime object or UNIX timestamp in a human readable format + """Format a datetime object or UNIX timestamp in a human readable format :param timestamp: a datetime object or a float representing a UNIX timestamp :param format_str: optional string format to pass to strftime @@ -78,8 +78,7 @@ def format_local_time(timestamp, format_str='%Y-%m-%d %H:%M:%S'): def print_last_process_state_change(process_type=None): - """ - Print the last time that a process of the specified type has changed its state. + """Print the last time that a process of the specified type has changed its state. :param process_type: optional process type for which to get the latest state change timestamp. Valid process types are either 'calculation' or 'work'. 
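For reference, the call-graph formatter from `aiida/cmdline/utils/ascii_vis.py` touched above is typically driven as follows; a sketch that assumes a loaded profile and uses a hypothetical pk::

    from aiida import load_profile, orm
    from aiida.cmdline.utils.ascii_vis import format_call_graph

    load_profile()
    node = orm.load_node(1234)  # hypothetical pk of a process node
    print(format_call_graph(node, max_depth=2, call_link_label=True))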
@@ -208,10 +207,9 @@ def format_flat_links(links, headers): table = [] for link_triple in links: - table.append([ - link_triple.link_label, link_triple.node.pk, - link_triple.node.base.attributes.get('process_label', '') - ]) + table.append( + [link_triple.link_label, link_triple.node.pk, link_triple.node.base.attributes.get('process_label', '')] + ) result = f'\n{tabulate(table, headers=headers)}' @@ -256,8 +254,7 @@ def format_recursive(links, depth=0): def get_calcjob_report(calcjob): - """ - Return a multi line string representation of the log messages and output of a given calcjob + """Return a multi line string representation of the log messages and output of a given calcjob :param calcjob: the calcjob node :return: a string representation of the log messages and scheduler output @@ -310,8 +307,7 @@ def get_calcjob_report(calcjob): def get_process_function_report(node): - """ - Return a multi line string representation of the log messages and output of a given process function node + """Return a multi line string representation of the log messages and output of a given process function node :param node: the node :return: a string representation of the log messages @@ -327,13 +323,11 @@ def get_process_function_report(node): def get_workchain_report(node: 'WorkChainNode', levelname, indent_size=4, max_depth=None): - """ - Return a multi line string representation of the log messages and output of a given workchain + """Return a multi line string representation of the log messages and output of a given workchain :param node: the workchain node :return: a nested string representation of the log messages """ - # pylint: disable=too-many-locals import itertools from aiida import orm @@ -349,8 +343,7 @@ def get_report_messages(uuid, depth, levelname): return [(_, depth) for _ in entries] def get_subtree(uuid, level=0): - """ - Get a nested tree of work calculation nodes and their nesting level starting from this uuid. + """Get a nested tree of work calculation nodes and their nesting level starting from this uuid. The result is a list of uuid of these nodes. """ builder = orm.QueryBuilder(backend=node.backend) @@ -362,7 +355,7 @@ def get_subtree(uuid, level=0): # for now, CALL links are the only ones allowing calc-calc # (we here really want instead to follow CALL links) with_incoming='workcalculation', - tag='subworkchains' + tag='subworkchains', ) result = builder.all(flat=True) @@ -400,7 +393,7 @@ def get_subtree(uuid, level=0): time=entry.time, width_id=width_id, width_levelname=width_levelname, - indent=' ' * (depth * indent_size) + indent=' ' * (depth * indent_size), ) report.append(line) @@ -437,7 +430,6 @@ def build_entries(ports): result = [] for name, port in sorted(ports.items(), key=lambda x: (not x[1].required, x[0])): - if name.startswith('_'): continue diff --git a/aiida/cmdline/utils/decorators.py b/aiida/cmdline/utils/decorators.py index 3f70c7612d..74c2e18ba7 100644 --- a/aiida/cmdline/utils/decorators.py +++ b/aiida/cmdline/utils/decorators.py @@ -7,8 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -""" -Various decorators useful for creating verdi commands, for example loading the dbenv lazily. +"""Various decorators useful for creating verdi commands, for example loading the dbenv lazily. Always avoids trying to load the dbenv twice. When it has to be loaded, a spinner ASCII widget is displayed. 
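A corresponding sketch for `get_workchain_report` from `aiida/cmdline/utils/common.py` above, again under the same assumptions (loaded profile, hypothetical pk)::

    from aiida import load_profile, orm
    from aiida.cmdline.utils.common import get_workchain_report

    load_profile()
    node = orm.load_node(1234)  # hypothetical pk of a WorkChainNode
    print(get_workchain_report(node, levelname='REPORT'))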
@@ -35,6 +34,7 @@
 def with_manager(wrapped, _, args, kwargs):
     """Decorate a function injecting a :class:`kiwipy.rmq.communicator.RmqCommunicator`."""
     from aiida.manage import get_manager
+
     kwargs['manager'] = get_manager()
     return wrapped(*args, **kwargs)

@@ -85,8 +85,7 @@ def wrapper(wrapped, _, args, kwargs):

 @contextmanager
 def dbenv():
-    """
-    Loads the dbenv for a specific region of code, does not unload afterwards
+    """Loads the dbenv for a specific region of code, does not unload afterwards

     Only use when it makes it possible to avoid loading the dbenv for certain code paths

@@ -195,8 +194,7 @@ def wrapper(wrapped, _, args, kwargs):

 @decorator
 def check_circus_zmq_version(wrapped, _, args, kwargs):
-    """
-    Function decorator to check for the right ZMQ version before trying to run circus.
+    """Function decorator to check for the right ZMQ version before trying to run circus.

     Example::

@@ -206,6 +204,7 @@ def do_circus_stuff():
         pass
     """
     import zmq
+
     try:
         zmq_version = [int(part) for part in zmq.__version__.split('.')[:2]]
         if len(zmq_version) < 2:
@@ -222,12 +221,12 @@ def do_circus_stuff():

 def deprecated_command(message):
     """Function decorator that will mark a click command as deprecated when invoked.

-    Example::
-
-    @click.command()
-    @deprecated_command('This command has been deprecated in AiiDA v1.0, please use 'foo' instead.)
-    def mycommand():
-        pass
+    Example::
+
+        @click.command()
+        @deprecated_command("This command has been deprecated in AiiDA v1.0, please use 'foo' instead.")
+        def mycommand():
+            pass
     """

     @decorator
diff --git a/aiida/cmdline/utils/defaults.py b/aiida/cmdline/utils/defaults.py
index 43f825e827..cdaf8ab5c7 100644
--- a/aiida/cmdline/utils/defaults.py
+++ b/aiida/cmdline/utils/defaults.py
@@ -14,7 +14,7 @@
 from aiida.manage.configuration import get_config


-def get_default_profile():  # pylint: disable=unused-argument
+def get_default_profile():
     """Try to get the name of the default profile.

     This utility function should only be used for defaults or callbacks in command line interface parameters.
diff --git a/aiida/cmdline/utils/echo.py b/aiida/cmdline/utils/echo.py
index 887fcba732..e3154911c1 100644
--- a/aiida/cmdline/utils/echo.py
+++ b/aiida/cmdline/utils/echo.py
@@ -20,13 +20,20 @@
 CMDLINE_LOGGER = logging.getLogger('verdi')

 __all__ = (
-    'echo_report', 'echo_info', 'echo_success', 'echo_warning', 'echo_error', 'echo_critical', 'echo_tabulate',
-    'echo_dictionary'
+    'echo_report',
+    'echo_info',
+    'echo_success',
+    'echo_warning',
+    'echo_error',
+    'echo_critical',
+    'echo_tabulate',
+    'echo_dictionary',
 )


 class ExitCode(enum.IntEnum):
     """Exit codes for the verdi command line."""
+
     CRITICAL = 1
     DEPRECATED = 80
     UNKNOWN = 99
@@ -185,7 +192,6 @@ def echo_deprecated(message: str, bold: bool = False, nl: bool = True, err: bool
     :param err: whether to log to stderr.
:param exit: whether to exit after printing the message """ - # pylint: disable=redefined-builtin prefix = click.style('Deprecated: ', fg=COLORS['deprecated'], bold=True) echo_warning(prefix + message, bold=bold, nl=nl, err=err, prefix=False) @@ -232,7 +238,7 @@ def default_jsondump(data): if isinstance(data, datetime.datetime): return timezone.localtime(data).strftime('%Y-%m-%dT%H:%M:%S.%f%z') - raise TypeError(f'{repr(data)} is not JSON serializable') + raise TypeError(f'{data!r} is not JSON serializable') return json.dumps(dictionary, indent=4, sort_keys=sort_keys, default=default_jsondump) @@ -266,6 +272,7 @@ def echo_tabulate(table, **kwargs): :param kwargs: Additional arguments passed to :meth:`tabulate.tabulate`. """ from tabulate import tabulate + echo(tabulate(table, **kwargs)) @@ -294,5 +301,4 @@ def is_stdout_redirected(): echo.echo_info("Found {} results".format(qb.count()), err=echo.is_stdout_redirected) echo.echo(tabulate.tabulate(qb.all())) """ - # pylint: disable=no-member return not sys.stdout.isatty() diff --git a/aiida/cmdline/utils/log.py b/aiida/cmdline/utils/log.py index 893004d15a..2a8bc1f957 100644 --- a/aiida/cmdline/utils/log.py +++ b/aiida/cmdline/utils/log.py @@ -36,7 +36,7 @@ def emit(self, record): try: msg = self.format(record) click.echo(msg, err=err, nl=nl) - except Exception: # pylint: disable=broad-except + except Exception: self.handleError(record) diff --git a/aiida/cmdline/utils/multi_line_input.py b/aiida/cmdline/utils/multi_line_input.py index def61c2720..fa2cbc90b1 100644 --- a/aiida/cmdline/utils/multi_line_input.py +++ b/aiida/cmdline/utils/multi_line_input.py @@ -24,6 +24,7 @@ def edit_multiline_template(template_name, comment_marker='#=', extension=None, ``click.edit`` returned ``None``. """ from aiida.cmdline.utils.templates import env + template = env.get_template(template_name) rendered = template.render(**kwargs) content = click.edit(rendered, extension=extension) @@ -36,10 +37,9 @@ def edit_multiline_template(template_name, comment_marker='#=', extension=None, def edit_comment(old_cmt=''): - """ - call up an editor to edit comments to nodes in the database - """ + """Call up an editor to edit comments to nodes in the database""" from aiida.cmdline.utils.templates import env + template = env.get_template('new_cmt.txt.tpl') content = template.render(old_comment=old_cmt) mlinput = click.edit(content, extension='.txt') diff --git a/aiida/cmdline/utils/pluginable.py b/aiida/cmdline/utils/pluginable.py index 8850f11c88..d47a7c1f7e 100644 --- a/aiida/cmdline/utils/pluginable.py +++ b/aiida/cmdline/utils/pluginable.py @@ -31,7 +31,7 @@ def list_commands(self, ctx): return subcommands - def get_command(self, ctx, name): # pylint: disable=arguments-renamed + def get_command(self, ctx, name): """Try to load a subcommand from entry points, else defer to super.""" command = None if not self._exclude_external_plugins: diff --git a/aiida/cmdline/utils/query/calculation.py b/aiida/cmdline/utils/query/calculation.py index 3cf8c9898f..24aef422b2 100644 --- a/aiida/cmdline/utils/query/calculation.py +++ b/aiida/cmdline/utils/query/calculation.py @@ -7,9 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=unused-import """A utility module with a factory of standard QueryBuilder instances for Calculation nodes.""" from aiida.common.warnings import warn_deprecation -from 
aiida.tools.query.calculation import CalculationQueryBuilder warn_deprecation('This module is deprecated, use `aiida.tools.query.calculation` instead.', version=3) diff --git a/aiida/cmdline/utils/query/formatting.py b/aiida/cmdline/utils/query/formatting.py index 67549a27a0..ca95ac50fc 100644 --- a/aiida/cmdline/utils/query/formatting.py +++ b/aiida/cmdline/utils/query/formatting.py @@ -7,9 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=unused-import """A utility module with simple functions to format variables into strings for cli outputs.""" from aiida.common.warnings import warn_deprecation -from aiida.tools.query.formatting import format_process_state, format_relative_time, format_sealed, format_state warn_deprecation('This module is deprecated, use `aiida.tools.query.formatting` instead.', version=3) diff --git a/aiida/cmdline/utils/query/mapping.py b/aiida/cmdline/utils/query/mapping.py index 7c70831f4b..67f867f21e 100644 --- a/aiida/cmdline/utils/query/mapping.py +++ b/aiida/cmdline/utils/query/mapping.py @@ -7,9 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=unused-import """A utility module with mapper objects that map database entities projections on attributes and labels.""" from aiida.common.warnings import warn_deprecation -from aiida.tools.query.mapping import CalculationProjectionMapper, ProjectionMapper warn_deprecation('This module is deprecated, use `aiida.tools.query.mapping` instead.', version=3) diff --git a/aiida/cmdline/utils/repository.py b/aiida/cmdline/utils/repository.py index d50f73f25a..2e5cea71d2 100644 --- a/aiida/cmdline/utils/repository.py +++ b/aiida/cmdline/utils/repository.py @@ -23,7 +23,5 @@ def list_repository_contents(node, path, color): for entry in node.base.repository.list_objects(path): bold = bool(entry.file_type == FileType.DIRECTORY) echo.echo( - entry.name, - bold=bold, - fg=echo.COLORS['report'] if color and entry.file_type == FileType.DIRECTORY else None + entry.name, bold=bold, fg=echo.COLORS['report'] if color and entry.file_type == FileType.DIRECTORY else None ) diff --git a/aiida/cmdline/utils/shell.py b/aiida/cmdline/utils/shell.py index 2101583aec..30839d86ae 100644 --- a/aiida/cmdline/utils/shell.py +++ b/aiida/cmdline/utils/shell.py @@ -57,7 +57,7 @@ def ipython(): def bpython(): """Start a bpython shell.""" - import bpython as bpy_shell # pylint: disable=import-error + import bpython as bpy_shell user_ns = get_start_namespace() if user_ns: @@ -71,7 +71,6 @@ def bpython(): def run_shell(interface=None): """Start the chosen external shell.""" - available_shells = [AVAILABLE_SHELLS[interface]] if interface else AVAILABLE_SHELLS.values() # Try the specified or the available shells one by one until you @@ -114,7 +113,7 @@ def get_start_namespace(): def _ipython_pre_011(): """Start IPython pre-0.11""" - from IPython.Shell import IPShell # pylint: disable=import-error,no-name-in-module + from IPython.Shell import IPShell user_ns = get_start_namespace() if user_ns: @@ -126,7 +125,7 @@ def _ipython_pre_011(): def _ipython_pre_100(): """Start IPython pre-1.0.0""" - from IPython.frontend.terminal.ipapp import TerminalIPythonApp # pylint: 
disable=import-error,no-name-in-module + from IPython.frontend.terminal.ipapp import TerminalIPythonApp app = TerminalIPythonApp.instance() app.initialize(argv=[]) @@ -138,7 +137,7 @@ def _ipython_pre_100(): def _ipython(): """Start IPython >= 1.0""" - from IPython import start_ipython # pylint: disable=import-error,no-name-in-module + from IPython import start_ipython user_ns = get_start_namespace() if user_ns: diff --git a/aiida/cmdline/utils/templates.py b/aiida/cmdline/utils/templates.py index 591a805eb3..58428c456f 100644 --- a/aiida/cmdline/utils/templates.py +++ b/aiida/cmdline/utils/templates.py @@ -10,5 +10,4 @@ """Templates for input/output of verdi commands.""" from jinja2 import Environment, PackageLoader -#pylint: disable=invalid-name env = Environment(loader=PackageLoader('aiida', 'cmdline/templates')) diff --git a/aiida/common/__init__.py b/aiida/common/__init__.py index 3c68731ff0..6db17def4d 100644 --- a/aiida/common/__init__.py +++ b/aiida/common/__init__.py @@ -7,8 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -""" -Common data structures, utility classes and functions +"""Common data structures, utility classes and functions .. note:: Modules in this sub package have to run without a loaded database environment @@ -16,8 +15,7 @@ # AUTO-GENERATED -# yapf: disable -# pylint: disable=wildcard-import +# fmt: off from .datastructures import * from .exceptions import * @@ -91,4 +89,4 @@ 'validate_link_label', ) -# yapf: enable +# fmt: on diff --git a/aiida/common/constants.py b/aiida/common/constants.py index 7a80dade9e..18e86b513e 100644 --- a/aiida/common/constants.py +++ b/aiida/common/constants.py @@ -28,580 +28,120 @@ # Element table, from NIST (http://www.nist.gov/pml/data/index.cfm) # Retrieved in October 2014 for atomic numbers 1-103, and in May 2016 or atomic numbers 104-112, 114 and 116. # In addition, element X is added to support unknown elements. 
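The reformatted `elements` table below carries the same data as before, just one dict literal per element; a quick sanity check using the silicon entry from this hunk::

    from aiida.common.constants import elements

    assert elements[14] == {'mass': 28.0855, 'name': 'Silicon', 'symbol': 'Si'}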
-elements = {  # pylint: disable=invalid-name
-    0: {
-        'mass': 1.00000,
-        'name': 'Unknown',
-        'symbol': 'X'
-    },
-    1: {
-        'mass': 1.00794,
-        'name': 'Hydrogen',
-        'symbol': 'H'
-    },
-    2: {
-        'mass': 4.002602,
-        'name': 'Helium',
-        'symbol': 'He'
-    },
-    3: {
-        'mass': 6.941,
-        'name': 'Lithium',
-        'symbol': 'Li'
-    },
-    4: {
-        'mass': 9.012182,
-        'name': 'Beryllium',
-        'symbol': 'Be'
-    },
-    5: {
-        'mass': 10.811,
-        'name': 'Boron',
-        'symbol': 'B'
-    },
-    6: {
-        'mass': 12.0107,
-        'name': 'Carbon',
-        'symbol': 'C'
-    },
-    7: {
-        'mass': 14.0067,
-        'name': 'Nitrogen',
-        'symbol': 'N'
-    },
-    8: {
-        'mass': 15.9994,
-        'name': 'Oxygen',
-        'symbol': 'O'
-    },
-    9: {
-        'mass': 18.9984032,
-        'name': 'Fluorine',
-        'symbol': 'F'
-    },
-    10: {
-        'mass': 20.1797,
-        'name': 'Neon',
-        'symbol': 'Ne'
-    },
-    11: {
-        'mass': 22.98977,
-        'name': 'Sodium',
-        'symbol': 'Na'
-    },
-    12: {
-        'mass': 24.305,
-        'name': 'Magnesium',
-        'symbol': 'Mg'
-    },
-    13: {
-        'mass': 26.981538,
-        'name': 'Aluminium',
-        'symbol': 'Al'
-    },
-    14: {
-        'mass': 28.0855,
-        'name': 'Silicon',
-        'symbol': 'Si'
-    },
-    15: {
-        'mass': 30.973761,
-        'name': 'Phosphorus',
-        'symbol': 'P'
-    },
-    16: {
-        'mass': 32.065,
-        'name': 'Sulfur',
-        'symbol': 'S'
-    },
-    17: {
-        'mass': 35.453,
-        'name': 'Chlorine',
-        'symbol': 'Cl'
-    },
-    18: {
-        'mass': 39.948,
-        'name': 'Argon',
-        'symbol': 'Ar'
-    },
-    19: {
-        'mass': 39.0983,
-        'name': 'Potassium',
-        'symbol': 'K'
-    },
-    20: {
-        'mass': 40.078,
-        'name': 'Calcium',
-        'symbol': 'Ca'
-    },
-    21: {
-        'mass': 44.955912,
-        'name': 'Scandium',
-        'symbol': 'Sc'
-    },
-    22: {
-        'mass': 47.867,
-        'name': 'Titanium',
-        'symbol': 'Ti'
-    },
-    23: {
-        'mass': 50.9415,
-        'name': 'Vanadium',
-        'symbol': 'V'
-    },
-    24: {
-        'mass': 51.9961,
-        'name': 'Chromium',
-        'symbol': 'Cr'
-    },
-    25: {
-        'mass': 54.938045,
-        'name': 'Manganese',
-        'symbol': 'Mn'
-    },
-    26: {
-        'mass': 55.845,
-        'name': 'Iron',
-        'symbol': 'Fe'
-    },
-    27: {
-        'mass': 58.933195,
-        'name': 'Cobalt',
-        'symbol': 'Co'
-    },
-    28: {
-        'mass': 58.6934,
-        'name': 'Nickel',
-        'symbol': 'Ni'
-    },
-    29: {
-        'mass': 63.546,
-        'name': 'Copper',
-        'symbol': 'Cu'
-    },
-    30: {
-        'mass': 65.38,
-        'name': 'Zinc',
-        'symbol': 'Zn'
-    },
-    31: {
-        'mass': 69.723,
-        'name': 'Gallium',
-        'symbol': 'Ga'
-    },
-    32: {
-        'mass': 72.64,
-        'name': 'Germanium',
-        'symbol': 'Ge'
-    },
-    33: {
-        'mass': 74.9216,
-        'name': 'Arsenic',
-        'symbol': 'As'
-    },
-    34: {
-        'mass': 78.96,
-        'name': 'Selenium',
-        'symbol': 'Se'
-    },
-    35: {
-        'mass': 79.904,
-        'name': 'Bromine',
-        'symbol': 'Br'
-    },
-    36: {
-        'mass': 83.798,
-        'name': 'Krypton',
-        'symbol': 'Kr'
-    },
-    37: {
-        'mass': 85.4678,
-        'name': 'Rubidium',
-        'symbol': 'Rb'
-    },
-    38: {
-        'mass': 87.62,
-        'name': 'Strontium',
-        'symbol': 'Sr'
-    },
-    39: {
-        'mass': 88.90585,
-        'name': 'Yttrium',
-        'symbol': 'Y'
-    },
-    40: {
-        'mass': 91.224,
-        'name': 'Zirconium',
-        'symbol': 'Zr'
-    },
-    41: {
-        'mass': 92.90638,
-        'name': 'Niobium',
-        'symbol': 'Nb'
-    },
-    42: {
-        'mass': 95.96,
-        'name': 'Molybdenum',
-        'symbol': 'Mo'
-    },
-    43: {
-        'mass': 98.0,
-        'name': 'Technetium',
-        'symbol': 'Tc'
-    },
-    44: {
-        'mass': 101.07,
-        'name': 'Ruthenium',
-        'symbol': 'Ru'
-    },
-    45: {
-        'mass': 102.9055,
-        'name': 'Rhodium',
-        'symbol': 'Rh'
-    },
-    46: {
-        'mass': 106.42,
-        'name': 'Palladium',
-        'symbol': 'Pd'
-    },
-    47: {
-        'mass': 107.8682,
-        'name': 'Silver',
-        'symbol': 'Ag'
-    },
-    48: {
-        'mass': 112.411,
-        'name': 'Cadmium',
-        'symbol': 'Cd'
-    },
-    49: {
-        'mass': 114.818,
-        'name': 'Indium',
-        'symbol': 'In'
-    },
-    50: {
-        'mass': 118.71,
-        'name': 'Tin',
-        'symbol': 'Sn'
-    },
-    51: {
-        'mass': 121.76,
-        'name': 'Antimony',
-        'symbol': 'Sb'
-    },
-    52: {
-        'mass': 127.6,
-        'name': 'Tellurium',
-        'symbol': 'Te'
-    },
-    53: {
-        'mass': 126.90447,
-        'name': 'Iodine',
-        'symbol': 'I'
-    },
-    54: {
-        'mass': 131.293,
-        'name': 'Xenon',
-        'symbol': 'Xe'
-    },
-    55: {
-        'mass': 132.9054519,
-        'name': 'Caesium',
-        'symbol': 'Cs'
-    },
-    56: {
-        'mass': 137.327,
-        'name': 'Barium',
-        'symbol': 'Ba'
-    },
-    57: {
-        'mass': 138.90547,
-        'name': 'Lanthanum',
-        'symbol': 'La'
-    },
-    58: {
-        'mass': 140.116,
-        'name': 'Cerium',
-        'symbol': 'Ce'
-    },
-    59: {
-        'mass': 140.90765,
-        'name': 'Praseodymium',
-        'symbol': 'Pr'
-    },
-    60: {
-        'mass': 144.242,
-        'name': 'Neodymium',
-        'symbol': 'Nd'
-    },
-    61: {
-        'mass': 145.0,
-        'name': 'Promethium',
-        'symbol': 'Pm'
-    },
-    62: {
-        'mass': 150.36,
-        'name': 'Samarium',
-        'symbol': 'Sm'
-    },
-    63: {
-        'mass': 151.964,
-        'name': 'Europium',
-        'symbol': 'Eu'
-    },
-    64: {
-        'mass': 157.25,
-        'name': 'Gadolinium',
-        'symbol': 'Gd'
-    },
-    65: {
-        'mass': 158.92535,
-        'name': 'Terbium',
-        'symbol': 'Tb'
-    },
-    66: {
-        'mass': 162.5,
-        'name': 'Dysprosium',
-        'symbol': 'Dy'
-    },
-    67: {
-        'mass': 164.93032,
-        'name': 'Holmium',
-        'symbol': 'Ho'
-    },
-    68: {
-        'mass': 167.259,
-        'name': 'Erbium',
-        'symbol': 'Er'
-    },
-    69: {
-        'mass': 168.93421,
-        'name': 'Thulium',
-        'symbol': 'Tm'
-    },
-    70: {
-        'mass': 173.054,
-        'name': 'Ytterbium',
-        'symbol': 'Yb'
-    },
-    71: {
-        'mass': 174.9668,
-        'name': 'Lutetium',
-        'symbol': 'Lu'
-    },
-    72: {
-        'mass': 178.49,
-        'name': 'Hafnium',
-        'symbol': 'Hf'
-    },
-    73: {
-        'mass': 180.94788,
-        'name': 'Tantalum',
-        'symbol': 'Ta'
-    },
-    74: {
-        'mass': 183.84,
-        'name': 'Tungsten',
-        'symbol': 'W'
-    },
-    75: {
-        'mass': 186.207,
-        'name': 'Rhenium',
-        'symbol': 'Re'
-    },
-    76: {
-        'mass': 190.23,
-        'name': 'Osmium',
-        'symbol': 'Os'
-    },
-    77: {
-        'mass': 192.217,
-        'name': 'Iridium',
-        'symbol': 'Ir'
-    },
-    78: {
-        'mass': 195.084,
-        'name': 'Platinum',
-        'symbol': 'Pt'
-    },
-    79: {
-        'mass': 196.966569,
-        'name': 'Gold',
-        'symbol': 'Au'
-    },
-    80: {
-        'mass': 200.59,
-        'name': 'Mercury',
-        'symbol': 'Hg'
-    },
-    81: {
-        'mass': 204.3833,
-        'name': 'Thallium',
-        'symbol': 'Tl'
-    },
-    82: {
-        'mass': 207.2,
-        'name': 'Lead',
-        'symbol': 'Pb'
-    },
-    83: {
-        'mass': 208.9804,
-        'name': 'Bismuth',
-        'symbol': 'Bi'
-    },
-    84: {
-        'mass': 209.0,
-        'name': 'Polonium',
-        'symbol': 'Po'
-    },
-    85: {
-        'mass': 210.0,
-        'name': 'Astatine',
-        'symbol': 'At'
-    },
-    86: {
-        'mass': 222.0,
-        'name': 'Radon',
-        'symbol': 'Rn'
-    },
-    87: {
-        'mass': 223.0,
-        'name': 'Francium',
-        'symbol': 'Fr'
-    },
-    88: {
-        'mass': 226.0,
-        'name': 'Radium',
-        'symbol': 'Ra'
-    },
-    89: {
-        'mass': 227.0,
-        'name': 'Actinium',
-        'symbol': 'Ac'
-    },
-    90: {
-        'mass': 232.03806,
-        'name': 'Thorium',
-        'symbol': 'Th'
-    },
-    91: {
-        'mass': 231.03588,
-        'name': 'Protactinium',
-        'symbol': 'Pa'
-    },
-    92: {
-        'mass': 238.02891,
-        'name': 'Uranium',
-        'symbol': 'U'
-    },
-    93: {
-        'mass': 237.0,
-        'name': 'Neptunium',
-        'symbol': 'Np'
-    },
-    94: {
-        'mass': 244.0,
-        'name': 'Plutonium',
-        'symbol': 'Pu'
-    },
-    95: {
-        'mass': 243.0,
-        'name': 'Americium',
-        'symbol': 'Am'
-    },
-    96: {
-        'mass': 247.0,
-        'name': 'Curium',
-        'symbol': 'Cm'
-    },
-    97: {
-        'mass': 247.0,
-        'name': 'Berkelium',
-        'symbol': 'Bk'
-    },
-    98: {
-        'mass': 251.0,
-        'name': 'Californium',
-        'symbol': 'Cf'
-    },
-    99: {
-        'mass': 252.0,
-        'name': 'Einsteinium',
-        'symbol': 'Es'
-    },
-    100: {
-        'mass': 257.0,
-        'name': 'Fermium',
-        'symbol': 'Fm'
-    },
-    101: {
-        'mass': 258.0,
-        'name': 'Mendelevium',
-        'symbol': 'Md'
-    },
-    102: {
-        'mass': 259.0,
-        'name': 'Nobelium',
-        'symbol': 'No'
-    },
-    103: {
-        'mass': 262.0,
-        'name': 'Lawrencium',
-        'symbol': 'Lr'
-    },
-    104: {
-        'mass': 267.0,
-        'name': 'Rutherfordium',
-        'symbol': 'Rf'
-    },
-    105: {
-        'mass': 268.0,
-        'name': 'Dubnium',
-        'symbol': 'Db'
-    },
-    106: {
-        'mass': 271.0,
-        'name': 'Seaborgium',
-        'symbol': 'Sg'
-    },
-    107: {
-        'mass': 272.0,
-        'name': 'Bohrium',
-        'symbol': 'Bh'
-    },
-    108: {
-        'mass': 270.0,
-        'name': 'Hassium',
-        'symbol': 'Hs'
-    },
-    109: {
-        'mass': 276.0,
-        'name': 'Meitnerium',
-        'symbol': 'Mt'
-    },
-    110: {
-        'mass': 281.0,
-        'name': 'Darmstadtium',
-        'symbol': 'Ds'
-    },
-    111: {
-        'mass': 280.0,
-        'name': 'Roentgenium',
-        'symbol': 'Rg'
-    },
-    112: {
-        'mass': 285.0,
-        'name': 'Copernicium',
-        'symbol': 'Cn'
-    },
-    114: {
-        'mass': 289.0,
-        'name': 'Flerovium',
-        'symbol': 'Fl'
-    },
-    116: {
-        'mass': 293.0,
-        'name': 'Livermorium',
-        'symbol': 'Lv'
-    },
+elements = {
+    0: {'mass': 1.00000, 'name': 'Unknown', 'symbol': 'X'},
+    1: {'mass': 1.00794, 'name': 'Hydrogen', 'symbol': 'H'},
+    2: {'mass': 4.002602, 'name': 'Helium', 'symbol': 'He'},
+    3: {'mass': 6.941, 'name': 'Lithium', 'symbol': 'Li'},
+    4: {'mass': 9.012182, 'name': 'Beryllium', 'symbol': 'Be'},
+    5: {'mass': 10.811, 'name': 'Boron', 'symbol': 'B'},
+    6: {'mass': 12.0107, 'name': 'Carbon', 'symbol': 'C'},
+    7: {'mass': 14.0067, 'name': 'Nitrogen', 'symbol': 'N'},
+    8: {'mass': 15.9994, 'name': 'Oxygen', 'symbol': 'O'},
+    9: {'mass': 18.9984032, 'name': 'Fluorine', 'symbol': 'F'},
+    10: {'mass': 20.1797, 'name': 'Neon', 'symbol': 'Ne'},
+    11: {'mass': 22.98977, 'name': 'Sodium', 'symbol': 'Na'},
+    12: {'mass': 24.305, 'name': 'Magnesium', 'symbol': 'Mg'},
+    13: {'mass': 26.981538, 'name': 'Aluminium', 'symbol': 'Al'},
+    14: {'mass': 28.0855, 'name': 'Silicon', 'symbol': 'Si'},
+    15: {'mass': 30.973761, 'name': 'Phosphorus', 'symbol': 'P'},
+    16: {'mass': 32.065, 'name': 'Sulfur', 'symbol': 'S'},
+    17: {'mass': 35.453, 'name': 'Chlorine', 'symbol': 'Cl'},
+    18: {'mass': 39.948, 'name': 'Argon', 'symbol': 'Ar'},
+    19: {'mass': 39.0983, 'name': 'Potassium', 'symbol': 'K'},
+    20: {'mass': 40.078, 'name': 'Calcium', 'symbol': 'Ca'},
+    21: {'mass': 44.955912, 'name': 'Scandium', 'symbol': 'Sc'},
+    22: {'mass': 47.867, 'name': 'Titanium', 'symbol': 'Ti'},
+    23: {'mass': 50.9415, 'name': 'Vanadium', 'symbol': 'V'},
+    24: {'mass': 51.9961, 'name': 'Chromium', 'symbol': 'Cr'},
+    25: {'mass': 54.938045, 'name': 'Manganese', 'symbol': 'Mn'},
+    26: {'mass': 55.845, 'name': 'Iron', 'symbol': 'Fe'},
+    27: {'mass': 58.933195, 'name': 'Cobalt', 'symbol': 'Co'},
+    28: {'mass': 58.6934, 'name': 'Nickel', 'symbol': 'Ni'},
+    29: {'mass': 63.546, 'name': 'Copper', 'symbol': 'Cu'},
+    30: {'mass': 65.38, 'name': 'Zinc', 'symbol': 'Zn'},
+    31: {'mass': 69.723, 'name': 'Gallium', 'symbol': 'Ga'},
+    32: {'mass': 72.64, 'name': 'Germanium', 'symbol': 'Ge'},
+    33: {'mass': 74.9216, 'name': 'Arsenic', 'symbol': 'As'},
+    34: {'mass': 78.96, 'name': 'Selenium', 'symbol': 'Se'},
+    35: {'mass': 79.904, 'name': 'Bromine', 'symbol': 'Br'},
+    36: {'mass': 83.798, 'name': 'Krypton', 'symbol': 'Kr'},
+    37: {'mass': 85.4678, 'name': 'Rubidium', 'symbol': 'Rb'},
+    38: {'mass': 87.62, 'name': 'Strontium', 'symbol': 'Sr'},
+    39: {'mass': 88.90585, 'name': 'Yttrium', 'symbol': 'Y'},
+    40: {'mass': 91.224, 'name': 'Zirconium', 'symbol': 'Zr'},
+    41: {'mass': 92.90638, 'name': 'Niobium', 'symbol': 'Nb'},
+    42: {'mass': 95.96, 'name': 'Molybdenum', 'symbol': 'Mo'},
+    43: {'mass': 98.0, 'name': 'Technetium', 'symbol': 'Tc'},
+    44: {'mass': 101.07, 'name': 'Ruthenium', 'symbol': 'Ru'},
+    45: {'mass': 102.9055, 'name': 'Rhodium', 'symbol': 'Rh'},
+    46: {'mass': 106.42, 'name': 'Palladium', 'symbol': 'Pd'},
+    47: {'mass': 107.8682, 'name': 'Silver', 'symbol': 'Ag'},
+    48: {'mass': 112.411, 'name': 'Cadmium', 'symbol': 'Cd'},
+    49: {'mass': 114.818, 'name': 'Indium', 'symbol': 'In'},
+    50: {'mass': 118.71, 'name': 'Tin', 'symbol': 'Sn'},
+    51: {'mass': 121.76, 'name': 'Antimony', 'symbol': 'Sb'},
+    52: {'mass': 127.6, 'name': 'Tellurium', 'symbol': 'Te'},
+    53: {'mass': 126.90447, 'name': 'Iodine', 'symbol': 'I'},
+    54: {'mass': 131.293, 'name': 'Xenon', 'symbol': 'Xe'},
+    55: {'mass': 132.9054519, 'name': 'Caesium', 'symbol': 'Cs'},
+    56: {'mass': 137.327, 'name': 'Barium', 'symbol': 'Ba'},
+    57: {'mass': 138.90547, 'name': 'Lanthanum', 'symbol': 'La'},
+    58: {'mass': 140.116, 'name': 'Cerium', 'symbol': 'Ce'},
+    59: {'mass': 140.90765, 'name': 'Praseodymium', 'symbol': 'Pr'},
+    60: {'mass': 144.242, 'name': 'Neodymium', 'symbol': 'Nd'},
+    61: {'mass': 145.0, 'name': 'Promethium', 'symbol': 'Pm'},
+    62: {'mass': 150.36, 'name': 'Samarium', 'symbol': 'Sm'},
+    63: {'mass': 151.964, 'name': 'Europium', 'symbol': 'Eu'},
+    64: {'mass': 157.25, 'name': 'Gadolinium', 'symbol': 'Gd'},
+    65: {'mass': 158.92535, 'name': 'Terbium', 'symbol': 'Tb'},
+    66: {'mass': 162.5, 'name': 'Dysprosium', 'symbol': 'Dy'},
+    67: {'mass': 164.93032, 'name': 'Holmium', 'symbol': 'Ho'},
+    68: {'mass': 167.259, 'name': 'Erbium', 'symbol': 'Er'},
+    69: {'mass': 168.93421, 'name': 'Thulium', 'symbol': 'Tm'},
+    70: {'mass': 173.054, 'name': 'Ytterbium', 'symbol': 'Yb'},
+    71: {'mass': 174.9668, 'name': 'Lutetium', 'symbol': 'Lu'},
+    72: {'mass': 178.49, 'name': 'Hafnium', 'symbol': 'Hf'},
+    73: {'mass': 180.94788, 'name': 'Tantalum', 'symbol': 'Ta'},
+    74: {'mass': 183.84, 'name': 'Tungsten', 'symbol': 'W'},
+    75: {'mass': 186.207, 'name': 'Rhenium', 'symbol': 'Re'},
+    76: {'mass': 190.23, 'name': 'Osmium', 'symbol': 'Os'},
+    77: {'mass': 192.217, 'name': 'Iridium', 'symbol': 'Ir'},
+    78: {'mass': 195.084, 'name': 'Platinum', 'symbol': 'Pt'},
+    79: {'mass': 196.966569, 'name': 'Gold', 'symbol': 'Au'},
+    80: {'mass': 200.59, 'name': 'Mercury', 'symbol': 'Hg'},
+    81: {'mass': 204.3833, 'name': 'Thallium', 'symbol': 'Tl'},
+    82: {'mass': 207.2, 'name': 'Lead', 'symbol': 'Pb'},
+    83: {'mass': 208.9804, 'name': 'Bismuth', 'symbol': 'Bi'},
+    84: {'mass': 209.0, 'name': 'Polonium', 'symbol': 'Po'},
+    85: {'mass': 210.0, 'name': 'Astatine', 'symbol': 'At'},
+    86: {'mass': 222.0, 'name': 'Radon', 'symbol': 'Rn'},
+    87: {'mass': 223.0, 'name': 'Francium', 'symbol': 'Fr'},
+    88: {'mass': 226.0, 'name': 'Radium', 'symbol': 'Ra'},
+    89: {'mass': 227.0, 'name': 'Actinium', 'symbol': 'Ac'},
+    90: {'mass': 232.03806, 'name': 'Thorium', 'symbol': 'Th'},
+    91: {'mass': 231.03588, 'name': 'Protactinium', 'symbol': 'Pa'},
+    92: {'mass': 238.02891, 'name': 'Uranium', 'symbol': 'U'},
+    93: {'mass': 237.0, 'name': 'Neptunium', 'symbol': 'Np'},
+    94: {'mass': 244.0, 'name': 'Plutonium', 'symbol': 'Pu'},
+    95: {'mass': 243.0, 'name': 'Americium', 'symbol': 'Am'},
+    96: {'mass': 247.0, 'name': 'Curium', 'symbol': 'Cm'},
+    97: {'mass': 247.0, 'name': 'Berkelium', 'symbol': 'Bk'},
+    98: {'mass': 251.0, 'name': 'Californium', 'symbol': 'Cf'},
+    99: {'mass': 252.0, 'name': 'Einsteinium', 'symbol': 'Es'},
+    100: {'mass': 257.0, 'name': 'Fermium', 'symbol': 'Fm'},
+    101: {'mass': 258.0, 'name': 'Mendelevium', 'symbol': 'Md'},
+    102: {'mass': 259.0, 'name': 'Nobelium', 'symbol': 'No'},
+    103: {'mass': 262.0, 'name': 'Lawrencium', 'symbol': 'Lr'},
+    104: {'mass': 267.0, 'name': 'Rutherfordium', 'symbol': 'Rf'},
+    105: {'mass': 268.0, 'name': 'Dubnium', 'symbol': 'Db'},
+    106: {'mass': 271.0, 'name': 'Seaborgium', 'symbol': 'Sg'},
+    107: {'mass': 272.0, 'name': 'Bohrium', 'symbol': 'Bh'},
+    108: {'mass': 270.0, 'name': 'Hassium', 'symbol': 'Hs'},
+    109: {'mass': 276.0, 'name': 'Meitnerium', 'symbol': 'Mt'},
+    110: {'mass': 281.0, 'name': 'Darmstadtium', 'symbol': 'Ds'},
+    111: {'mass': 280.0, 'name': 'Roentgenium', 'symbol': 'Rg'},
+    112: {'mass': 285.0, 'name': 'Copernicium', 'symbol': 'Cn'},
+    114: {'mass': 289.0, 'name': 'Flerovium', 'symbol': 'Fl'},
+    116: {'mass': 293.0, 'name': 'Livermorium', 'symbol': 'Lv'},
 }
diff --git a/aiida/common/datastructures.py b/aiida/common/datastructures.py
index 731237ec5d..4822c16715 100644
--- a/aiida/common/datastructures.py
+++ b/aiida/common/datastructures.py
@@ -36,8 +36,7 @@ class CalcJobState(Enum):
 
 
 class CalcInfo(DefaultFieldsAttributeDict):
-    """
-    This object will store the data returned by the calculation plugin and to be
+    """This object will store the data returned by the calculation plugin and to be
     passed to the ExecManager.
 
     In the following descriptions all paths have to be considered relative
@@ -90,14 +89,31 @@ class CalcInfo(DefaultFieldsAttributeDict):
     """
 
     _default_fields = (
-        'job_environment', 'email', 'email_on_started', 'email_on_terminated', 'uuid', 'prepend_text', 'append_text',
-        'num_machines', 'num_mpiprocs_per_machine', 'priority', 'max_wallclock_seconds', 'max_memory_kb', 'rerunnable',
-        'retrieve_list', 'retrieve_temporary_list', 'local_copy_list', 'remote_copy_list', 'remote_symlink_list',
-        'provenance_exclude_list', 'codes_info', 'codes_run_mode', 'skip_submit'
+        'job_environment',
+        'email',
+        'email_on_started',
+        'email_on_terminated',
+        'uuid',
+        'prepend_text',
+        'append_text',
+        'num_machines',
+        'num_mpiprocs_per_machine',
+        'priority',
+        'max_wallclock_seconds',
+        'max_memory_kb',
+        'rerunnable',
+        'retrieve_list',
+        'retrieve_temporary_list',
+        'local_copy_list',
+        'remote_copy_list',
+        'remote_symlink_list',
+        'provenance_exclude_list',
+        'codes_info',
+        'codes_run_mode',
+        'skip_submit',
     )
 
     if TYPE_CHECKING:
-
         job_environment: None | dict[str, str]
         email: None | str
         email_on_started: bool
@@ -123,8 +139,7 @@ class CalcInfo(DefaultFieldsAttributeDict):
 
 
 class CodeInfo(DefaultFieldsAttributeDict):
-    """
-    This attribute-dictionary contains the information needed to execute a code.
+    """This attribute-dictionary contains the information needed to execute a code.
 
     Possible attributes are:
 
     * ``cmdline_params``: a list of strings, containing parameters to be written on
@@ -166,6 +181,7 @@ class CodeInfo(DefaultFieldsAttributeDict):
       on the remote computer)
     * ``code_uuid``: the uuid of the code associated to the CodeInfo
     """
+
     _default_fields = (
         'cmdline_params',  # as a list of strings
         'stdin_name',
@@ -173,11 +189,10 @@ class CodeInfo(DefaultFieldsAttributeDict):
         'stderr_name',
         'join_files',
         'withmpi',
-        'code_uuid'
+        'code_uuid',
     )
 
     if TYPE_CHECKING:
-
         cmdline_params: None | list[str]
         stdin_name: None | str
         stdout_name: None | str
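# --- Illustrative sketch (not part of the patch): how a CalcJob plugin typically
# fills the CalcInfo/CodeInfo structures reformatted above. The field names come
# from the `_default_fields` tuples in this file; the concrete values are
# hypothetical, so treat this as a sketch rather than the canonical recipe.
from aiida.common.datastructures import CalcInfo, CodeInfo

code_info = CodeInfo()
code_info.cmdline_params = ['--input', 'aiida.in']  # appended after the executable
code_info.stdout_name = 'aiida.out'                 # file capturing stdout

calc_info = CalcInfo()
calc_info.codes_info = [code_info]                  # one CodeInfo per code to run
calc_info.retrieve_list = ['aiida.out']             # files to retrieve afterwards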
diff --git a/aiida/common/escaping.py b/aiida/common/escaping.py
index 170cf5b82c..bd35bc0968 100644
--- a/aiida/common/escaping.py
+++ b/aiida/common/escaping.py
@@ -13,8 +13,7 @@
 
 
 def escape_for_bash(str_to_escape, use_double_quotes=False):
-    """
-    This function takes any string and escapes it in a way that
+    """This function takes any string and escapes it in a way that
     bash will interpret it as a single string.
 
     Explanation:
@@ -77,8 +76,7 @@ def escape_for_sql_like(string):
 
 
 def get_regex_pattern_from_sql(sql_pattern):
-    r"""
-    Convert a string providing a pattern to match in SQL
+    r"""Convert a string providing a pattern to match in SQL
     syntax into a string performing the same match as a regex.
 
     SQL LIKE syntax summary:
@@ -100,8 +98,7 @@ def get_regex_pattern_from_sql(sql_pattern):
     """
 
     def tokenizer(string, tokens_to_apply):
-        """
-        Recursive function that tokenizes a string using the provided tokens
+        """Recursive function that tokenizes a string using the provided tokens
 
        :param string: the string to tokenize
        :param tokens_to_apply: the list of tokens still to process (in order: the first will be processed first)
@@ -122,8 +119,11 @@ def tokenizer(string, tokens_to_apply):
             # with ALL tokens passed (there could be more occurrences of tokens_to_apply[0])
             # Instead, for the first part, we know that we found the FIRST occurrence of tokens_to_apply[0]
             # so I pass the list without the first element
-            return tokenizer(first, tokens_to_apply=tokens_to_apply[1:]
-                             ) + dict(SQL_TO_REGEX_TOKENS)[sep] + tokenizer(rest, tokens_to_apply=tokens_to_apply)
+            return (
+                tokenizer(first, tokens_to_apply=tokens_to_apply[1:])
+                + dict(SQL_TO_REGEX_TOKENS)[sep]
+                + tokenizer(rest, tokens_to_apply=tokens_to_apply)
+            )
         # Here sep is empty: it means also rest is empty, and we just
         # return (recursively) the tokenizer on the first part, avoiding
         # infinite loops
@@ -139,8 +139,7 @@ def tokenizer(string, tokens_to_apply):
 
 
 def sql_string_match(string, pattern):
-    """
-    Check if the string matches the provided pattern,
+    """Check if the string matches the provided pattern,
     specified using SQL syntax.
 
     See documentation of :py:func:`~aiida.common.escaping.get_regex_pattern_from_sql`
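# --- Illustrative sketch (not part of the patch): behaviour of the SQL-LIKE
# helpers touched above, based on their docstrings. In SQL LIKE syntax '%'
# matches any run of characters and '_' exactly one.
from aiida.common.escaping import escape_for_bash, sql_string_match

assert sql_string_match('aiida', 'ai%')      # '%' matches the remaining 'ida'
assert not sql_string_match('aiida', 'ai_')  # '_' matches only a single character
safe = escape_for_bash("it's")               # quoted so bash sees one single string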
diff --git a/aiida/common/exceptions.py b/aiida/common/exceptions.py
index eec8b94446..fc5e5a7e9e 100644
--- a/aiida/common/exceptions.py
+++ b/aiida/common/exceptions.py
@@ -10,21 +10,53 @@
 """Module that define the exceptions that are thrown by AiiDA's internal code."""
 
 __all__ = (
-    'AiidaException', 'NotExistent', 'NotExistentAttributeError', 'NotExistentKeyError', 'MultipleObjectsError',
-    'RemoteOperationError', 'ContentNotExistent', 'FailedError', 'StoringNotAllowed', 'ModificationNotAllowed',
-    'IntegrityError', 'UniquenessError', 'EntryPointError', 'MissingEntryPointError', 'MultipleEntryPointError',
-    'LoadingEntryPointError', 'InvalidEntryPointTypeError', 'InvalidOperation', 'ParsingError', 'InternalError',
-    'PluginInternalError', 'ValidationError', 'ConfigurationError', 'ProfileConfigurationError',
-    'MissingConfigurationError', 'ConfigurationVersionError', 'IncompatibleStorageSchema', 'CorruptStorage',
-    'DbContentError', 'InputValidationError', 'FeatureNotAvailable', 'FeatureDisabled', 'LicensingException',
-    'TestsNotAllowedError', 'UnsupportedSpeciesError', 'TransportTaskException', 'OutputParsingError', 'HashingError',
-    'StorageMigrationError', 'LockedProfileError', 'LockingProfileError', 'ClosedStorage'
+    'AiidaException',
+    'NotExistent',
+    'NotExistentAttributeError',
+    'NotExistentKeyError',
+    'MultipleObjectsError',
+    'RemoteOperationError',
+    'ContentNotExistent',
+    'FailedError',
+    'StoringNotAllowed',
+    'ModificationNotAllowed',
+    'IntegrityError',
+    'UniquenessError',
+    'EntryPointError',
+    'MissingEntryPointError',
+    'MultipleEntryPointError',
+    'LoadingEntryPointError',
+    'InvalidEntryPointTypeError',
+    'InvalidOperation',
+    'ParsingError',
+    'InternalError',
+    'PluginInternalError',
+    'ValidationError',
+    'ConfigurationError',
+    'ProfileConfigurationError',
+    'MissingConfigurationError',
+    'ConfigurationVersionError',
+    'IncompatibleStorageSchema',
+    'CorruptStorage',
+    'DbContentError',
+    'InputValidationError',
+    'FeatureNotAvailable',
+    'FeatureDisabled',
+    'LicensingException',
+    'TestsNotAllowedError',
+    'UnsupportedSpeciesError',
+    'TransportTaskException',
+    'OutputParsingError',
+    'HashingError',
+    'StorageMigrationError',
+    'LockedProfileError',
+    'LockingProfileError',
+    'ClosedStorage',
 )
 
 
-class AiidaException(Exception):
-    """
-    Base class for all AiiDA exceptions.
+class AiidaException(Exception):  # noqa: N818
+    """Base class for all AiiDA exceptions.
 
     Each module will have its own subclass, inherited from this
     (e.g. ExecManagerException, TransportException, ...)
@@ -32,74 +64,58 @@ class AiidaException(Exception):
 
 
 class NotExistent(AiidaException):
-    """
-    Raised when the required entity does not exist.
-    """
+    """Raised when the required entity does not exist."""
 
 
 class NotExistentAttributeError(AttributeError, NotExistent):
-    """
-    Raised when the required entity does not exist, when fetched as an attribute.
-    """
+    """Raised when the required entity does not exist, when fetched as an attribute."""
 
 
 class NotExistentKeyError(KeyError, NotExistent):
-    """
-    Raised when the required entity does not exist, when fetched as a dictionary key.
-    """
+    """Raised when the required entity does not exist, when fetched as a dictionary key."""
 
 
 class MultipleObjectsError(AiidaException):
-    """
-    Raised when more than one entity is found in the DB, but only one was
+    """Raised when more than one entity is found in the DB, but only one was
     expected.
     """
 
 
 class RemoteOperationError(AiidaException):
-    """
-    Raised when an error in a remote operation occurs, as in a failed kill()
+    """Raised when an error in a remote operation occurs, as in a failed kill()
     of a scheduler job.
     """
 
 
 class ContentNotExistent(NotExistent):
-    """
-    Raised when trying to access an attribute, a key or a file in the result
+    """Raised when trying to access an attribute, a key or a file in the result
     nodes that is not present
     """
 
 
 class FailedError(AiidaException):
-    """
-    Raised when accessing a calculation that is in the FAILED status
-    """
+    """Raised when accessing a calculation that is in the FAILED status"""
 
 
 class StoringNotAllowed(AiidaException):
-    """
-    Raised when the user tries to store an unstorable node (e.g. a base Node class)
-    """
+    """Raised when the user tries to store an unstorable node (e.g. a base Node class)"""
 
 
 class ModificationNotAllowed(AiidaException):
-    """
-    Raised when the user tries to modify a field, object, property, ... that should not
+    """Raised when the user tries to modify a field, object, property, ... that should not
     be modified.
     """
 
 
 class IntegrityError(AiidaException):
-    """
-    Raised when there is an underlying data integrity error. This can be database related
+    """Raised when there is an underlying data integrity error. This can be database related
     or a general data integrity error. This can happen if, e.g., a foreign key check fails.
     See PEP 249 for details.
     """
 
 
 class UniquenessError(AiidaException):
-    """
-    Raised when the user tries to violate a uniqueness constraint (on the
+    """Raised when the user tries to violate a uniqueness constraint (on the
     DB, for instance).
     """
 
@@ -125,60 +141,46 @@ class InvalidEntryPointTypeError(EntryPointError):
 
 
 class InvalidOperation(AiidaException):
-    """
-    The allowed operation is not valid (e.g., when trying to add a non-internal attribute
+    """The allowed operation is not valid (e.g., when trying to add a non-internal attribute
     before saving the entry), or deleting an entry that is protected (e.g.,
     because it is referenced by foreign keys)
     """
 
 
 class ParsingError(AiidaException):
-    """
-    Generic error raised when there is a parsing error
-    """
+    """Generic error raised when there is a parsing error"""
 
 
 class InternalError(AiidaException):
-    """
-    Error raised when there is an internal error of AiiDA.
-    """
+    """Error raised when there is an internal error of AiiDA."""
 
 
 class PluginInternalError(InternalError):
-    """
-    Error raised when there is an internal error which is due to a plugin
+    """Error raised when there is an internal error which is due to a plugin
     and not to the AiiDA infrastructure.
     """
 
 
 class ValidationError(AiidaException):
-    """
-    Error raised when there is an error during the validation phase
+    """Error raised when there is an error during the validation phase
     of a property.
     """
 
 
 class ConfigurationError(AiidaException):
-    """
-    Error raised when there is a configuration error in AiiDA.
-    """
+    """Error raised when there is a configuration error in AiiDA."""
 
 
 class ProfileConfigurationError(ConfigurationError):
-    """
-    Configuration error raised when a wrong/inexistent profile is requested.
-    """
+    """Configuration error raised when a wrong/inexistent profile is requested."""
 
 
 class MissingConfigurationError(ConfigurationError):
-    """
-    Configuration error raised when the configuration file is missing.
-    """
+    """Configuration error raised when the configuration file is missing."""
 
 
 class ConfigurationVersionError(ConfigurationError):
-    """
-    Configuration error raised when the configuration file version is not
+    """Configuration error raised when the configuration file version is not
     compatible with the current version.
     """
 
@@ -187,11 +189,11 @@ class ClosedStorage(AiidaException):
     """Raised when trying to access data from a closed storage backend."""
 
 
-class UnreachableStorage(ConfigurationError):
+class UnreachableStorage(ConfigurationError):  # noqa: N818
     """Raised when a connection to the storage backend fails."""
 
 
-class IncompatibleDatabaseSchema(ConfigurationError):
+class IncompatibleDatabaseSchema(ConfigurationError):  # noqa: N818
     """Raised when the storage schema is incompatible with that of the code.
 
     Deprecated for ``IncompatibleStorageSchema``
@@ -202,7 +204,7 @@ class IncompatibleStorageSchema(IncompatibleDatabaseSchema):
     """Raised when the storage schema is incompatible with that of the code."""
 
 
-class CorruptStorage(ConfigurationError):
+class CorruptStorage(ConfigurationError):  # noqa: N818
     """Raised when the storage is not found to be internally consistent on validation."""
 
 
@@ -218,84 +220,62 @@ class StorageMigrationError(DatabaseMigrationError):
 
 
 class DbContentError(AiidaException):
-    """
-    Raised when the content of the DB is not valid.
+    """Raised when the content of the DB is not valid.
     This should never happen if the user does not play directly
     with the DB.
     """
 
 
 class InputValidationError(ValidationError):
-    """
-    The input data for a calculation did not validate (e.g., missing
+    """The input data for a calculation did not validate (e.g., missing
     required input data, wrong data, ...)
     """
 
 
 class FeatureNotAvailable(AiidaException):
-    """
-    Raised when a feature is requested from a plugin, that is not available.
-    """
+    """Raised when a feature is requested from a plugin, that is not available."""
 
 
 class FeatureDisabled(AiidaException):
-    """
-    Raised when a feature is requested, but the user has chosen to disable
+    """Raised when a feature is requested, but the user has chosen to disable
     it (e.g., for submissions on disabled computers).
     """
 
 
 class LicensingException(AiidaException):
-    """
-    Raised when requirements for data licensing are not met.
-    """
+    """Raised when requirements for data licensing are not met."""
 
 
 class TestsNotAllowedError(AiidaException):
-    """
-    Raised when tests are required to be run/loaded, but we are not in a testing environment.
+    """Raised when tests are required to be run/loaded, but we are not in a testing environment.
 
     This is to prevent data loss.
     """
 
 
 class UnsupportedSpeciesError(ValueError):
-    """
-    Raised when StructureData operations are fed species that are not supported by AiiDA such as Deuterium
-    """
+    """Raised when StructureData operations are fed species that are not supported by AiiDA such as Deuterium"""
 
 
 class TransportTaskException(AiidaException):
-    """
-    Raised when a TransportTask, an task to be completed by the engine that requires transport, fails
-    """
+    """Raised when a TransportTask, a task to be completed by the engine that requires transport, fails"""
 
 
 class OutputParsingError(ParsingError):
-    """
-    Can be raised by a Parser when it fails to parse the output generated by a `CalcJob` process.
-    """
+    """Can be raised by a Parser when it fails to parse the output generated by a `CalcJob` process."""
 
 
 class CircusCallError(AiidaException):
-    """
-    Raised when an attempt to contact Circus returns an error in the response
-    """
+    """Raised when an attempt to contact Circus returns an error in the response"""
 
 
 class HashingError(AiidaException):
-    """
-    Raised when an attempt to hash an object fails via a known failure mode
-    """
+    """Raised when an attempt to hash an object fails via a known failure mode"""
 
 
 class LockedProfileError(AiidaException):
-    """
-    Raised if attempting to access a locked profile
-    """
+    """Raised if attempting to access a locked profile"""
 
 
 class LockingProfileError(AiidaException):
-    """
-    Raised if the profile can`t be locked
-    """
+    """Raised if the profile can't be locked"""
diff --git a/aiida/common/extendeddicts.py b/aiida/common/extendeddicts.py
index 9a3c3ee3b8..0f5192afb3 100644
--- a/aiida/common/extendeddicts.py
+++ b/aiida/common/extendeddicts.py
@@ -15,9 +15,8 @@
 __all__ = ('AttributeDict', 'FixedFieldsAttributeDict', 'DefaultFieldsAttributeDict')
 
 
-class AttributeDict(dict):  # pylint: disable=too-many-instance-attributes
-    """
-    This class internally stores values in a dictionary, but exposes
+class AttributeDict(dict):
+    """This class internally stores values in a dictionary, but exposes
     the keys also as attributes, i.e. asking for attrdict.key
     will return the value of attrdict['key'] and so on.
 
@@ -95,8 +94,7 @@ def __dir__(self):
 
 
 class FixedFieldsAttributeDict(AttributeDict):
-    """
-    A dictionary with access to the keys as attributes, and with filtering
+    """A dictionary with access to the keys as attributes, and with filtering
     of valid attributes.
     This is only the base class, without valid attributes;
     use a derived class to do the actual work.
@@ -105,6 +103,7 @@ class FixedFieldsAttributeDict(AttributeDict):
         class TestExample(FixedFieldsAttributeDict):
             _valid_fields = ('a','b','c')
     """
+
     _valid_fields = tuple()
 
     def __init__(self, init=None):
@@ -118,18 +117,14 @@ def __init__(self, init=None):
         super().__init__(init)
 
     def __setitem__(self, item, value):
-        """
-        Set a key as an attribute.
-        """
+        """Set a key as an attribute."""
         if item not in self._valid_fields:
             errmsg = f"'{item}' is not a valid key for object '{self.__class__.__name__}'"
             raise KeyError(errmsg)
         super().__setitem__(item, value)
 
     def __setattr__(self, attr, value):
-        """
-        Overridden to allow direct access to fields with underscore.
-        """
+        """Overridden to allow direct access to fields with underscore."""
         if attr.startswith('_'):
             object.__setattr__(self, attr, value)
         else:
@@ -137,9 +132,7 @@ def __setattr__(self, attr, value):
 
     @classmethod
     def get_valid_fields(cls):
-        """
-        Return the list of valid fields.
-        """
+        """Return the list of valid fields."""
         return cls._valid_fields
 
     def __dir__(self):
@@ -147,8 +140,7 @@ def __dir__(self):
 
 
 class DefaultFieldsAttributeDict(AttributeDict):
-    """
-    A dictionary with access to the keys as attributes, and with an
+    """A dictionary with access to the keys as attributes, and with an
     internal value storing the 'default' keys to be distinguished
     from extra fields.
 
@@ -199,13 +191,11 @@ class TestExample(DefaultFieldsAttributeDict):
 
     See if we want that setting a default field to None means deleting it.
     """
-    # pylint: disable=invalid-name
+
     _default_fields = tuple()
 
     def validate(self):
-        """
-        Validate the keys, if any ``validate_*`` method is available.
-        """
+        """Validate the keys, if any ``validate_*`` method is available."""
         for key in self.get_default_fields():
             # I get the attribute starting with validate_ and containing the name of the key
             # I set a dummy function if there is no validate_KEY function defined
@@ -217,17 +207,14 @@ def validate(self):
                 raise exceptions.ValidationError(f"Invalid value for key '{key}' [{exc.__class__.__name__}]: {exc}")
 
     def __setattr__(self, attr, value):
-        """
-        Overridden to allow direct access to fields with underscore.
-        """
+        """Overridden to allow direct access to fields with underscore."""
         if attr.startswith('_'):
             object.__setattr__(self, attr, value)
         else:
             super().__setattr__(attr, value)
 
     def __getitem__(self, key):
-        """
-        Return None instead of raising an exception if the key does not exist
+        """Return None instead of raising an exception if the key does not exist
         but is in the list of default fields.
         """
         try:
@@ -239,19 +226,13 @@ def __getitem__(self, key):
 
     @classmethod
     def get_default_fields(cls):
-        """
-        Return the list of default fields, either defined in the instance or not.
-        """
+        """Return the list of default fields, either defined in the instance or not."""
         return list(cls._default_fields)
 
     def defaultkeys(self):
-        """
-        Return the default keys defined in the instance.
-        """
+        """Return the default keys defined in the instance."""
         return [_ for _ in self.keys() if _ in self._default_fields]
 
     def extrakeys(self):
-        """
-        Return the extra keys defined in the instance.
-        """
+        """Return the extra keys defined in the instance."""
         return [_ for _ in self.keys() if _ not in self._default_fields]
- """ + """Validate the keys, if any ``validate_*`` method is available.""" for key in self.get_default_fields(): # I get the attribute starting with validate_ and containing the name of the key # I set a dummy function if there is no validate_KEY function defined @@ -217,17 +207,14 @@ def validate(self): raise exceptions.ValidationError(f"Invalid value for key '{key}' [{exc.__class__.__name__}]: {exc}") def __setattr__(self, attr, value): - """ - Overridden to allow direct access to fields with underscore. - """ + """Overridden to allow direct access to fields with underscore.""" if attr.startswith('_'): object.__setattr__(self, attr, value) else: super().__setattr__(attr, value) def __getitem__(self, key): - """ - Return None instead of raising an exception if the key does not exist + """Return None instead of raising an exception if the key does not exist but is in the list of default fields. """ try: @@ -239,19 +226,13 @@ def __getitem__(self, key): @classmethod def get_default_fields(cls): - """ - Return the list of default fields, either defined in the instance or not. - """ + """Return the list of default fields, either defined in the instance or not.""" return list(cls._default_fields) def defaultkeys(self): - """ - Return the default keys defined in the instance. - """ + """Return the default keys defined in the instance.""" return [_ for _ in self.keys() if _ in self._default_fields] def extrakeys(self): - """ - Return the extra keys defined in the instance. - """ + """Return the extra keys defined in the instance.""" return [_ for _ in self.keys() if _ not in self._default_fields] diff --git a/aiida/common/files.py b/aiida/common/files.py index e3e8b3a0ae..05844a54a8 100644 --- a/aiida/common/files.py +++ b/aiida/common/files.py @@ -44,8 +44,7 @@ def md5_file(filepath, block_size_factor=128): def sha1_file(filename, block_size_factor=128): - """ - Open a file and return its sha1sum (hexdigested). + """Open a file and return its sha1sum (hexdigested). :param filename: the filename of the file for which we want the sha1sum :param block_size_factor: the file is read at chunks of size diff --git a/aiida/common/folders.py b/aiida/common/folders.py index de97b65647..dbf7809e8c 100644 --- a/aiida/common/folders.py +++ b/aiida/common/folders.py @@ -30,8 +30,7 @@ class Folder: - """ - A class to manage generic folders, avoiding to get out of + """A class to manage generic folders, avoiding to get out of specific given folder borders. .. todo:: @@ -66,9 +65,7 @@ def __init__(self, abspath, folder_limit=None): @property def mode_dir(self): - """ - Return the mode with which the folders should be created - """ + """Return the mode with which the folders should be created""" if GROUP_WRITABLE: return 0o770 @@ -76,17 +73,14 @@ def mode_dir(self): @property def mode_file(self): - """ - Return the mode with which the files should be created - """ + """Return the mode with which the files should be created""" if GROUP_WRITABLE: return 0o660 return 0o600 def get_subfolder(self, subfolder, create=False, reset_limit=False): - """ - Return a Folder object pointing to a subfolder. + """Return a Folder object pointing to a subfolder. :param subfolder: a string with the relative path of the subfolder, relative to the absolute path of this object. Note that @@ -163,7 +157,6 @@ def insert_path(self, src, dest_name=None, overwrite=True): :param overwrite: if ``False``, raises an error on existing destination; otherwise, delete it first. 
""" - # pylint: disable=too-many-branches if dest_name is None: filename = str(os.path.basename(src)) else: @@ -276,8 +269,7 @@ def get_abs_path(self, relpath, check_existence=False): @contextlib.contextmanager def open(self, name, mode='r', encoding='utf8', check_existence=False): - """ - Open a file in the current folder and return the corresponding file object. + """Open a file in the current folder and return the corresponding file object. :param check_existence: if False, just return the file path. Otherwise, also check if the file or directory actually exists. @@ -291,41 +283,32 @@ def open(self, name, mode='r', encoding='utf8', check_existence=False): @property def abspath(self): - """ - The absolute path of the folder. - """ + """The absolute path of the folder.""" return self._abspath @property def folder_limit(self): - """ - The folder limit that cannot be crossed when creating files and folders. - """ + """The folder limit that cannot be crossed when creating files and folders.""" return self._folder_limit def exists(self): - """ - Return True if the folder exists, False otherwise. - """ + """Return True if the folder exists, False otherwise.""" return os.path.exists(self.abspath) def isfile(self, relpath): - """ - Return True if 'relpath' exists inside the folder and is a file, + """Return True if 'relpath' exists inside the folder and is a file, False otherwise. """ return os.path.isfile(os.path.join(self.abspath, relpath)) def isdir(self, relpath): - """ - Return True if 'relpath' exists inside the folder and is a directory, + """Return True if 'relpath' exists inside the folder and is a directory, False otherwise. """ return os.path.isdir(os.path.join(self.abspath, relpath)) def erase(self, create_empty_folder=False): - """ - Erases the folder. Should be called only in very specific cases, + """Erases the folder. Should be called only in very specific cases, in general folder should not be erased! Doesn't complain if the folder does not exist. @@ -339,8 +322,7 @@ def erase(self, create_empty_folder=False): self.create() def create(self): - """ - Creates the folder, if it does not exist on the disk yet. + """Creates the folder, if it does not exist on the disk yet. It will also create top directories, if absent. 
diff --git a/aiida/common/hashing.py b/aiida/common/hashing.py
index 7cb06a3f81..423fba2ab5 100644
--- a/aiida/common/hashing.py
+++ b/aiida/common/hashing.py
@@ -8,18 +8,18 @@
 # For further information please visit http://www.aiida.net        #
 ###########################################################################
 """Common password and hash generation functions."""
-from collections import OrderedDict, abc
-from datetime import date, datetime, timezone
-from decimal import Decimal
-from functools import singledispatch
 import hashlib
-from itertools import chain
 import numbers
-from operator import itemgetter
 import secrets
 import string
 import typing
 import uuid
+from collections import OrderedDict, abc
+from datetime import date, datetime, timezone
+from decimal import Decimal
+from functools import singledispatch
+from itertools import chain
+from operator import itemgetter
 
 from aiida.common.constants import AIIDA_FLOAT_PRECISION
 from aiida.common.exceptions import HashingError
@@ -75,8 +75,7 @@ def chunked_file_hash(
 
 
 def make_hash(object_to_hash, **kwargs):
-    """
-    Makes a hash from a dictionary, list, tuple or set to any level, that contains
+    """Makes a hash from a dictionary, list, tuple or set to any level, that contains
     only other hashable or nonhashable types (including lists, tuples, sets, and
     dictionaries).
 
@@ -95,7 +94,7 @@ def make_hash(object_to_hash, **kwargs):
     hashing iteratively. Uses python's sorted function to sort unsorted
     sets and dictionaries by sorting the hashed keys.
     """
-    hashes = _make_hash(object_to_hash, **kwargs)  # pylint: disable=assignment-from-no-return
+    hashes = _make_hash(object_to_hash, **kwargs)
 
     # use the Unlimited fanout hashing protocol outlined in
     # https://blake2.net/blake2_20130129.pdf
@@ -112,8 +111,7 @@ def make_hash(object_to_hash, **kwargs):
 
 @singledispatch
 def _make_hash(object_to_hash, **_):
-    """
-    Implementation of the ``make_hash`` function. The hash is created as a
+    """Implementation of the ``make_hash`` function. The hash is created as a
     28 byte integer, and only later converted to a string.
     """
     raise HashingError(f'Value of type {type(object_to_hash)} cannot be hashed')
@@ -141,16 +139,22 @@ def _(val, **kwargs):
 @_make_hash.register(abc.Sequence)
 def _(sequence_obj, **kwargs):
     # unpack the list and use the elements
-    return [_single_digest('list(')] + list(chain.from_iterable(_make_hash(i, **kwargs) for i in sequence_obj)
-                                            ) + [_END_DIGEST]
+    return (
+        [_single_digest('list(')]
+        + list(chain.from_iterable(_make_hash(i, **kwargs) for i in sequence_obj))
+        + [_END_DIGEST]
+    )
 
 
 @_make_hash.register(abc.Set)
 def _(set_obj, **kwargs):
     # turn the set objects into a list of hashes which are always sortable,
     # then return a flattened list of the hashes
-    return [_single_digest('set(')] + list(chain.from_iterable(sorted(_make_hash(i, **kwargs) for i in set_obj))
-                                           ) + [_END_DIGEST]
+    return (
+        [_single_digest('set(')]
+        + list(chain.from_iterable(sorted(_make_hash(i, **kwargs) for i in set_obj)))
+        + [_END_DIGEST]
+    )
 
 
 @_make_hash.register(abc.Mapping)
@@ -161,33 +165,39 @@ def hashed_key_mapping():
     for key, value in mapping.items():
         yield (_make_hash(key, **kwargs), value)
 
-    return [_single_digest('dict(')] + list(
-        chain.from_iterable(
-            (k_digest + _make_hash(val, **kwargs)) for k_digest, val in sorted(hashed_key_mapping(), key=itemgetter(0))
+    return (
+        [_single_digest('dict(')]
+        + list(
+            chain.from_iterable(
+                (k_digest + _make_hash(val, **kwargs))
+                for k_digest, val in sorted(hashed_key_mapping(), key=itemgetter(0))
+            )
         )
-    ) + [_END_DIGEST]
+        + [_END_DIGEST]
+    )
 
 
 @_make_hash.register(OrderedDict)
 def _(mapping, **kwargs):
-    """
-    Hashing of OrderedDicts
+    """Hashing of OrderedDicts
 
     :param odict_as_unordered: hash OrderedDicts as normal dicts (mostly for testing)
     """
-
     if kwargs.get('odict_as_unordered', False):
         return _make_hash.registry[abc.Mapping](mapping)
 
-    return ([_single_digest('odict(')] + list(
-        chain.from_iterable((_make_hash(key, **kwargs) + _make_hash(val, **kwargs)) for key, val in mapping.items())
-    ) + [_END_DIGEST])
+    return (
+        [_single_digest('odict(')]
+        + list(
+            chain.from_iterable((_make_hash(key, **kwargs) + _make_hash(val, **kwargs)) for key, val in mapping.items())
+        )
+        + [_END_DIGEST]
+    )
 
 
 @_make_hash.register(numbers.Real)
 def _(val, **kwargs):
-    """
-    Before hashing a float, convert to a string (via rounding) and with a fixed number of digits after the comma.
+    """Before hashing a float, convert to a string (via rounding) and with a fixed number of digits after the comma.
 
     Note that the `_single_digest` requires a bytes object so we need to encode the utf-8 string first
     """
     return [_single_digest('float', float_to_text(val, sig=AIIDA_FLOAT_PRECISION).encode('utf-8'))]
 
 
 @_make_hash.register(Decimal)
 def _(val, **kwargs):
-    """
-    While a decimal can be converted exactly to a string which captures all characteristics of the underlying
+    """While a decimal can be converted exactly to a string which captures all characteristics of the underlying
     implementation, we also need compatibility with "equal" representations as int or float. Hence we are checking
     for the exponent (which is negative if there is a fractional component, 0 otherwise) and get the same hash
     as for a corresponding float or int.
@@ -208,21 +217,20 @@ def _(val, **kwargs):
 
 @_make_hash.register(numbers.Complex)
 def _(val, **kwargs):
-    """
-    In case of a complex number, use the same encoding of two floats and join them with a special symbol (a ! here).
-    """
+    """In case of a complex number, use the same encoding of two floats and join with a special symbol (a ! here)."""
     return [
         _single_digest(
-            'complex', '{}!{}'.format(
+            'complex',
+            '{}!{}'.format(
                 float_to_text(val.real, sig=AIIDA_FLOAT_PRECISION), float_to_text(val.imag, sig=AIIDA_FLOAT_PRECISION)
-            ).encode('utf-8')
+            ).encode('utf-8'),
         )
     ]
 
 
 @_make_hash.register(numbers.Integral)
 def _(val, **kwargs):
-    """get the hash of the little-endian signed long long representation of the integer"""
+    """Get the hash of the little-endian signed long long representation of the integer"""
     return [_single_digest('int', f'{val}'.encode('utf-8'))]
 
 
@@ -238,7 +246,7 @@ def _(val, **kwargs):
 
 @_make_hash.register(datetime)
 def _(val, **kwargs):
-    """hashes the little-endian rep of the float <epoch timestamp>."""
+    """Hashes the little-endian rep of the float <epoch timestamp>."""
     # see also https://stackoverflow.com/a/8778548 for an excellent elaboration
     if val.tzinfo is None or val.utcoffset() is None:
         val = val.replace(tzinfo=timezone.utc)
@@ -260,24 +268,27 @@ def _(val, **kwargs):
 
 @_make_hash.register(DatetimePrecision)
 def _(datetime_precision, **kwargs):
-    """ Hashes for DatetimePrecision object
-    """
-    return [_single_digest('dt_prec')] + list(
-        chain.from_iterable(_make_hash(i, **kwargs) for i in [datetime_precision.dtobj, datetime_precision.precision])
-    ) + [_END_DIGEST]
+    """Hashes for DatetimePrecision object"""
+    return (
+        [_single_digest('dt_prec')]
+        + list(
+            chain.from_iterable(
+                _make_hash(i, **kwargs) for i in [datetime_precision.dtobj, datetime_precision.precision]
+            )
+        )
+        + [_END_DIGEST]
+    )
 
 
 @_make_hash.register(Folder)
 def _(folder, **kwargs):
-    """
-    Hash the content of a Folder object. The name of the folder itself is actually ignored
+    """Hash the content of a Folder object. The name of the folder itself is actually ignored
 
     :param ignored_folder_content: list of filenames to be ignored for the hashing
     """
-
     ignored_folder_content = kwargs.get('ignored_folder_content', [])
 
     def folder_digests(subfolder):
-        """traverses the given folder and yields digests for the contained objects"""
+        """Traverses the given folder and yields digests for the contained objects"""
         for name, isfile in sorted(subfolder.get_content_list(only_paths=False), key=itemgetter(0)):
             if name in ignored_folder_content:
                 continue
@@ -296,14 +307,13 @@ def folder_digests(subfolder):
 
 
 def float_to_text(value, sig):
-    """
-    Convert float to text string for computing hash.
+    """Convert float to text string for computing hash.
 
     Preseve up to N significant number given by sig.
 
     :param value: the float value to convert
     :param sig: choose how many digits after the comma should be output
     """
     if value == 0:
-        value = 0.  # Identify value of -0. and overwrite with 0.
+        value = 0.0  # Identify value of -0. and overwrite with 0.
     fmt = f'{{:.{sig}g}}'
     return fmt.format(value)
diff --git a/aiida/common/lang.py b/aiida/common/lang.py
index 7206fde896..84a17e2954 100644
--- a/aiida/common/lang.py
+++ b/aiida/common/lang.py
@@ -50,7 +50,7 @@ def type_check(what, of_type, msg=None, allow_none=False):
 def override_decorator(check=False) -> Callable[[MethodType], MethodType]:
     """Decorator to signal that a method from a base class is being overridden completely."""
 
-    def wrap(func: MethodType) -> MethodType:  # pylint: disable=missing-docstring
+    def wrap(func: MethodType) -> MethodType:
         import inspect
 
         if isinstance(func, property):
@@ -63,7 +63,7 @@ def wrap(func: MethodType) -> MethodType:
         if check:
 
             @functools.wraps(func)
-            def wrapped_fn(self, *args, **kwargs):  # pylint: disable=missing-docstring
+            def wrapped_fn(self, *args, **kwargs):
                 try:
                     getattr(super(), func.__name__)
                 except AttributeError:
@@ -78,15 +78,14 @@ def wrapped_fn(self, *args, **kwargs):
     return wrap
 
 
-override = override_decorator(check=False)  # pylint: disable=invalid-name
+override = override_decorator(check=False)
 
 ReturnType = TypeVar('ReturnType')
 SelfType = TypeVar('SelfType')
 
 
-class classproperty(Generic[ReturnType]):  # pylint: disable=invalid-name
-    """
-    A class that, when used as a decorator, works as if the
+class classproperty(Generic[ReturnType]):  # noqa: N801
+    """A class that, when used as a decorator, works as if the
     two decorators @property and @classmethod where applied together
     (i.e., the object works as a property, both for the Class and for any
     of its instance; and is called with the class cls rather than with the
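# --- Illustrative sketch (not part of the patch): classproperty, renamed with a
# noqa above, behaves like @property and @classmethod combined, per its
# docstring. `Node` here is a hypothetical stand-in class.
from aiida.common.lang import classproperty

class Node:
    @classproperty
    def label(cls) -> str:
        return cls.__name__.lower()   # receives the class, not an instance

assert Node.label == 'node'           # works on the class ...
assert Node().label == 'node'         # ... and on its instances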
diff --git a/aiida/common/links.py b/aiida/common/links.py
index 7e8b1fcb7b..ca609824fe 100644
--- a/aiida/common/links.py
+++ b/aiida/common/links.py
@@ -58,7 +58,7 @@ class GraphTraversalRules(Enum):
         'call_calc_forward': GraphTraversalRule(LinkType.CALL_CALC, 'forward', True, False),
         'call_calc_backward': GraphTraversalRule(LinkType.CALL_CALC, 'backward', True, False),
         'call_work_forward': GraphTraversalRule(LinkType.CALL_WORK, 'forward', True, False),
-        'call_work_backward': GraphTraversalRule(LinkType.CALL_WORK, 'backward', True, False)
+        'call_work_backward': GraphTraversalRule(LinkType.CALL_WORK, 'backward', True, False),
     }
 
     DELETE = {
@@ -73,7 +73,7 @@ class GraphTraversalRules(Enum):
         'call_calc_forward': GraphTraversalRule(LinkType.CALL_CALC, 'forward', True, True),
         'call_calc_backward': GraphTraversalRule(LinkType.CALL_CALC, 'backward', False, True),
         'call_work_forward': GraphTraversalRule(LinkType.CALL_WORK, 'forward', True, True),
-        'call_work_backward': GraphTraversalRule(LinkType.CALL_WORK, 'backward', False, True)
+        'call_work_backward': GraphTraversalRule(LinkType.CALL_WORK, 'backward', False, True),
     }
 
     EXPORT = {
@@ -88,7 +88,7 @@ class GraphTraversalRules(Enum):
         'call_calc_forward': GraphTraversalRule(LinkType.CALL_CALC, 'forward', False, True),
         'call_calc_backward': GraphTraversalRule(LinkType.CALL_CALC, 'backward', True, True),
         'call_work_forward': GraphTraversalRule(LinkType.CALL_WORK, 'forward', False, True),
-        'call_work_backward': GraphTraversalRule(LinkType.CALL_WORK, 'backward', True, True)
+        'call_work_backward': GraphTraversalRule(LinkType.CALL_WORK, 'backward', True, True),
     }
 
diff --git a/aiida/common/log.py b/aiida/common/log.py
index e993950423..7e4530fb4c 100644
--- a/aiida/common/log.py
+++ b/aiida/common/log.py
@@ -35,7 +35,6 @@ def report(self: logging.Logger, msg, *args, **kwargs):
 
 
 class AiidaLoggerType(logging.Logger):
-
     def report(self, msg: str, *args, **kwargs) -> None:
         """Log a message at the ``REPORT`` level."""
 
@@ -76,26 +75,20 @@ def get_logging_config():
         'disable_existing_loggers': False,
         'formatters': {
             'verbose': {
-                'format': '%(levelname)s %(asctime)s %(module)s %(process)d '
-                          '%(thread)d %(message)s',
+                'format': '%(levelname)s %(asctime)s %(module)s %(process)d ' '%(thread)d %(message)s',
             },
             'halfverbose': {
                 'format': '%(asctime)s <%(process)d> %(name)s: [%(levelname)s] %(message)s',
                 'datefmt': '%m/%d/%Y %I:%M:%S %p',
             },
-            'cli': {
-                'class': 'aiida.cmdline.utils.log.CliFormatter'
-            }
+            'cli': {'class': 'aiida.cmdline.utils.log.CliFormatter'},
         },
         'handlers': {
             'console': {
                 'class': 'logging.StreamHandler',
                 'formatter': 'halfverbose',
             },
-            'cli': {
-                'class': 'aiida.cmdline.utils.log.CliHandler',
-                'formatter': 'cli'
-            },
+            'cli': {'class': 'aiida.cmdline.utils.log.CliHandler', 'formatter': 'cli'},
         },
         'loggers': {
             'aiida': {
@@ -158,7 +151,7 @@ def evaluate_logging_configuration(dictionary):
     for key, value in dictionary.items():
         if isinstance(value, collections.abc.Mapping):
             result[key] = evaluate_logging_configuration(value)
-        elif isinstance(value, types.LambdaType):  # pylint: disable=no-member
+        elif isinstance(value, types.LambdaType):
             result[key] = value()
         else:
             result[key] = value
@@ -167,8 +160,7 @@ def evaluate_logging_configuration(dictionary):
 
 
 def configure_logging(with_orm=False, daemon=False, daemon_log_file=None):
-    """
-    Setup the logging by retrieving the LOGGING dictionary from aiida and passing it to
+    """Setup the logging by retrieving the LOGGING dictionary from aiida and passing it to
     the python module logging.config.dictConfig. If the logging needs to be setup for the
     daemon, set the argument 'daemon' to True and specify the path to the log file. This
     will cause a 'daemon_handler' to be added to all the configured loggers, that is a
@@ -189,7 +181,6 @@ def configure_logging(with_orm=False, daemon=False, daemon_log_file=None):
 
     # Add the daemon file handler to all loggers if daemon=True
     if daemon is True:
-
         # Daemon always needs to run with ORM enabled
         with_orm = True
 
@@ -233,9 +224,10 @@ def configure_logging(with_orm=False, daemon=False, daemon_log_file=None):
     # Add the `DbLogHandler` if `with_orm` is `True`
     if with_orm:
         from aiida.manage.configuration import get_config_option
+
         config['handlers']['db_logger'] = {
             'level': get_config_option('logging.db_loglevel'),
-            'class': 'aiida.orm.utils.log.DBLogHandler'
+            'class': 'aiida.orm.utils.log.DBLogHandler',
         }
         config['loggers']['aiida']['handlers'].append('db_logger')
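# --- Illustrative sketch (not part of the patch): applying the logging setup
# reformatted above. The custom REPORT level handled by `AiidaLoggerType.report`
# sits between INFO and WARNING (an assumption about the level value, not shown
# in this hunk), so reports remain visible at the default verbosity.
import logging
from aiida.common.log import configure_logging

configure_logging()  # feeds the dictConfig shown above to logging.config
logging.getLogger('aiida').info('logging configured')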
diff --git a/aiida/common/progress_reporter.py b/aiida/common/progress_reporter.py
index 9633946d58..1610fb7c72 100644
--- a/aiida/common/progress_reporter.py
+++ b/aiida/common/progress_reporter.py
@@ -7,7 +7,6 @@
 # For further information on the license, see the LICENSE.txt file        #
 # For further information please visit http://www.aiida.net               #
 ###########################################################################
-# pylint: disable=global-statement,unused-argument
 """Provide a singleton progress reporter implementation.
 
 The interface is inspired by `tqdm <https://github.com/tqdm/tqdm>`__,
@@ -22,8 +21,12 @@
 from typing import Any, Callable, Optional, Type
 
 __all__ = (
-    'get_progress_reporter', 'set_progress_reporter', 'set_progress_bar_tqdm', 'ProgressReporterAbstract',
-    'TQDM_BAR_FORMAT', 'create_callback'
+    'get_progress_reporter',
+    'set_progress_reporter',
+    'set_progress_bar_tqdm',
+    'ProgressReporterAbstract',
+    'TQDM_BAR_FORMAT',
+    'create_callback',
 )
 
 TQDM_BAR_FORMAT = '{desc:40.40}{percentage:6.1f}%|{bar}| {n_fmt}/{total_fmt}'
@@ -65,7 +68,7 @@ def desc(self) -> Optional[str]:
         return self._desc
 
     @property
-    def n(self) -> int:  # pylint: disable=invalid-name
+    def n(self) -> int:
         """Return the current iteration."""
         # note using `n` as the attribute name is necessary for compatibility with tqdm
         return self._increment
@@ -89,7 +92,7 @@ def set_description_str(self, text: Optional[str] = None, refresh: bool = True):
         """
         self._desc = text
 
-    def update(self, n: int = 1):  # pylint: disable=invalid-name
+    def update(self, n: int = 1):
         """Update the progress counter.
 
         :param n: Increment to add to the internal counter of iterations
@@ -115,7 +118,7 @@ class ProgressReporterNull(ProgressReporterAbstract):
     """
 
 
-PROGRESS_REPORTER: Type[ProgressReporterAbstract] = ProgressReporterNull  # pylint: disable=invalid-name
+PROGRESS_REPORTER: Type[ProgressReporterAbstract] = ProgressReporterNull
 
 
 def get_progress_reporter() -> Type[ProgressReporterAbstract]:
@@ -129,7 +132,7 @@ def get_progress_reporter() -> Type[ProgressReporterAbstract]:
             progress.update()
 
     """
-    global PROGRESS_REPORTER  # pylint: disable=global-variable-not-assigned
+    global PROGRESS_REPORTER  # noqa: PLW0602
    return PROGRESS_REPORTER
 
 
@@ -152,7 +155,7 @@ def set_progress_reporter(reporter: Optional[Type[ProgressReporterAbstract]] = N
             progress.update()
 
     """
-    global PROGRESS_REPORTER
+    global PROGRESS_REPORTER  # noqa: PLW0603
     if reporter is None:
         PROGRESS_REPORTER = ProgressReporterNull
     elif kwargs:
@@ -173,6 +176,7 @@ def set_progress_bar_tqdm(bar_format: Optional[str] = TQDM_BAR_FORMAT, leave: Op
 
     """
     from tqdm import tqdm
+
     set_progress_reporter(tqdm, bar_format=bar_format, leave=leave, **kwargs)
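# --- Illustrative sketch (not part of the patch): the usage pattern shown in the
# docstrings of get_progress_reporter above. With the default null reporter this
# is a no-op; calling set_progress_bar_tqdm() first (tqdm required) renders a bar.
from aiida.common.progress_reporter import get_progress_reporter, set_progress_bar_tqdm

set_progress_bar_tqdm()
with get_progress_reporter()(total=3, desc='processing') as progress:
    for _ in range(3):
        progress.update()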
diff --git a/aiida/common/utils.py b/aiida/common/utils.py
index 011ed06b7e..baf1103e91 100644
--- a/aiida/common/utils.py
+++ b/aiida/common/utils.py
@@ -8,13 +8,13 @@
 # For further information please visit http://www.aiida.net        #
 ###########################################################################
 """Miscellaneous generic utility functions and classes."""
-from datetime import datetime
 import filecmp
 import inspect
 import io
 import os
 import re
 import sys
+from datetime import datetime
 from typing import Any, Dict
 from uuid import UUID
 
@@ -22,10 +22,9 @@
 
 
 def get_new_uuid():
-    """
-    Return a new UUID (typically to be used for new nodes).
-    """
+    """Return a new UUID (typically to be used for new nodes)."""
     import uuid
+
     return str(uuid.uuid4())
 
 
@@ -43,8 +42,7 @@ def validate_uuid(given_uuid: str) -> bool:
 
 
 def validate_list_of_string_tuples(val, tuple_length):
-    """
-    Check that:
+    """Check that:
 
     1. ``val`` is a list or tuple
     2. each element of the list:
@@ -68,8 +66,9 @@ def validate_list_of_string_tuples(val, tuple_length):
 
     for element in val:
         if (
-            not isinstance(element, (list, tuple)) or (len(element) != tuple_length) or
-            not all(isinstance(s, str) for s in element)
+            not isinstance(element, (list, tuple))
+            or (len(element) != tuple_length)
+            or not all(isinstance(s, str) for s in element)
         ):
             raise ValidationError(err_msg)
 
@@ -77,8 +76,7 @@ def validate_list_of_string_tuples(val, tuple_length):
 
 
 def get_unique_filename(filename, list_of_filenames):
-    """
-    Return a unique filename that can be added to the list_of_filenames.
+    """Return a unique filename that can be added to the list_of_filenames.
 
     If filename is not in list_of_filenames, it simply returns the filename
     string itself. Otherwise, it appends a integer number to the filename
@@ -106,9 +104,8 @@ def get_unique_filename(filename, list_of_filenames):
     return new_filename
 
 
-def str_timedelta(dt, max_num_fields=3, short=False, negative_to_zero=False):  # pylint: disable=invalid-name
-    """
-    Given a dt in seconds, return it in a HH:MM:SS format.
+def str_timedelta(dt, max_num_fields=3, short=False, negative_to_zero=False):
+    """Given a dt in seconds, return it in a HH:MM:SS format.
 
     :param dt: a TimeDelta object
     :param max_num_fields: maximum number of non-zero fields to show
@@ -174,8 +171,7 @@ def str_timedelta(dt, max_num_fields=3, short=False, negative_to_zero=False):
 
 
 def get_class_string(obj):
-    """
-    Return the string identifying the class of the object (module + object name,
+    """Return the string identifying the class of the object (module + object name,
     joined by dots).
 
     It works both for classes and for class instances.
@@ -187,8 +183,7 @@ def get_class_string(obj):
 
 
 def get_object_from_string(class_string):
-    """
-    Given a string identifying an object (as returned by the get_class_string
+    """Given a string identifying an object (as returned by the get_class_string
     method) load and return the actual object.
     """
     import importlib
@@ -198,9 +193,8 @@ def get_object_from_string(class_string):
     return getattr(importlib.import_module(the_module), the_name)
 
 
-def grouper(n, iterable):  # pylint: disable=invalid-name
-    """
-    Given an iterable, returns an iterable that returns tuples of groups of
+def grouper(n, iterable):
+    """Given an iterable, returns an iterable that returns tuples of groups of
     elements from iterable of length n, except the last one that has the
     required length to exaust iterable (i.e., there is no filling applied).
 
@@ -219,10 +213,10 @@ def grouper(n, iterable):
 
 
 class ArrayCounter:
-    """
-    A counter & a method that increments it and returns its value.
+    """A counter & a method that increments it and returns its value.
     It is used in various tests.
     """
+
     seq = None
 
     def __init__(self):
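# --- Illustrative sketch (not part of the patch): `grouper`, reformatted above,
# chunks an iterable into tuples of length n without padding the last group,
# per its docstring.
from aiida.common.utils import grouper

assert list(grouper(2, range(5))) == [(0, 1), (2, 3), (4,)]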
""" - # Directory comparison dirs_cmp = filecmp.dircmp(dir1, dir2) if dirs_cmp.left_only or dirs_cmp.right_only or dirs_cmp.funny_files: return ( - False, 'Left directory: {}, right directory: {}, files only ' + False, + 'Left directory: {}, right directory: {}, files only ' 'in left directory: {}, files only in right directory: ' '{}, not comparable files: {}'.format( dir1, dir2, dirs_cmp.left_only, dirs_cmp.right_only, dirs_cmp.funny_files - ) + ), ) # If the directories contain the same files, compare the common files @@ -275,15 +268,13 @@ def are_dir_trees_equal(dir1, dir2): class Prettifier: - """ - Class to manage prettifiers (typically for labels of kpoints + """Class to manage prettifiers (typically for labels of kpoints in band plots) """ @classmethod def _prettify_label_pass(cls, label): - """ - No-op prettifier, simply returns the same label + """No-op prettifier, simply returns the same label :param label: a string to prettify """ @@ -291,29 +282,24 @@ def _prettify_label_pass(cls, label): @classmethod def _prettify_label_agr(cls, label): - """ - Prettifier for XMGrace + """Prettifier for XMGrace :param label: a string to prettify """ - label = ( - label - .replace('GAMMA', r'\xG\f{}') - .replace('DELTA', r'\xD\f{}') - .replace('LAMBDA', r'\xL\f{}') - .replace('SIGMA', r'\xS\f{}') - ) # yapf:disable + label.replace('GAMMA', r'\xG\f{}') + .replace('DELTA', r'\xD\f{}') + .replace('LAMBDA', r'\xL\f{}') + .replace('SIGMA', r'\xS\f{}') + ) return re.sub(r'_(.?)', r'\\s\1\\N', label) @classmethod def _prettify_label_agr_simple(cls, label): - """ - Prettifier for XMGrace (for old label names) + """Prettifier for XMGrace (for old label names) :param label: a string to prettify """ - if label == 'G': return r'\xG' @@ -321,33 +307,23 @@ def _prettify_label_agr_simple(cls, label): @classmethod def _prettify_label_gnuplot(cls, label): - """ - Prettifier for Gnuplot + """Prettifier for Gnuplot :note: uses unicode, returns unicode strings (potentially, if needed) :param label: a string to prettify """ - - label = ( - label - .replace('GAMMA', 'Γ') - .replace('DELTA', 'Δ') - .replace('LAMBDA', 'Λ') - .replace('SIGMA', 'Σ') - ) # yapf:disable + label = label.replace('GAMMA', 'Γ').replace('DELTA', 'Δ').replace('LAMBDA', 'Λ').replace('SIGMA', 'Σ') return re.sub(r'_(.?)', r'_{\1}', label) @classmethod def _prettify_label_gnuplot_simple(cls, label): - """ - Prettifier for Gnuplot (for old label names) + """Prettifier for Gnuplot (for old label names) :note: uses unicode, returns unicode strings (potentially, if needed) :param label: a string to prettify """ - if label == 'G': return 'Γ' @@ -355,19 +331,16 @@ def _prettify_label_gnuplot_simple(cls, label): @classmethod def _prettify_label_latex(cls, label): - """ - Prettifier for matplotlib, using LaTeX syntax + """Prettifier for matplotlib, using LaTeX syntax :param label: a string to prettify """ - label = ( - label - .replace('GAMMA', r'$\Gamma$') - .replace('DELTA', r'$\Delta$') - .replace('LAMBDA', r'$\Lambda$') - .replace('SIGMA', r'$\Sigma$') - ) # yapf:disable + label.replace('GAMMA', r'$\Gamma$') + .replace('DELTA', r'$\Delta$') + .replace('LAMBDA', r'$\Lambda$') + .replace('SIGMA', r'$\Sigma$') + ) label = re.sub(r'_(.?)', r'$_{\1}$', label) # label += r"$_{\vphantom{0}}$" @@ -376,8 +349,7 @@ def _prettify_label_latex(cls, label): @classmethod def _prettify_label_latex_simple(cls, label): - """ - Prettifier for matplotlib, using LaTeX syntax (for old label names) + """Prettifier for matplotlib, using LaTeX syntax (for old label names) 
:param label: a string to prettify """ @@ -387,9 +359,8 @@ def _prettify_label_latex_simple(cls, label): return re.sub(r'(\d+)', r'$_{\1}$', label) @classproperty - def prettifiers(cls) -> Dict[str, Any]: # pylint: disable=no-self-argument - """ - Property that returns a dictionary that for each string associates + def prettifiers(cls) -> Dict[str, Any]: # noqa: N805 + """Property that returns a dictionary that for each string associates the function to prettify a label :return: a dictionary where keys are strings and values are functions @@ -406,16 +377,14 @@ def prettifiers(cls) -> Dict[str, Any]: # pylint: disable=no-self-argument @classmethod def get_prettifiers(cls): - """ - Return a list of valid prettifier strings + """Return a list of valid prettifier strings :return: a list of strings """ return sorted(cls.prettifiers.keys()) - def __init__(self, format): # pylint: disable=redefined-builtin - """ - Create a class to pretttify strings of a given format + def __init__(self, format): + """Create a class to pretttify strings of a given format :param format: a string with the format to use to prettify. Valid formats are obtained from self.prettifiers @@ -424,13 +393,12 @@ def __init__(self, format): # pylint: disable=redefined-builtin format = 'pass' try: - self._prettifier_f = self.prettifiers[format] # pylint: disable=unsubscriptable-object + self._prettifier_f = self.prettifiers[format] except KeyError: raise ValueError(f"Unknown prettifier format {format}; valid formats: {', '.join(self.get_prettifiers())}") def prettify(self, label): - """ - Prettify a label using the format passed in the initializer + """Prettify a label using the format passed in the initializer :param label: the string to prettify :return: a prettified string @@ -438,9 +406,8 @@ def prettify(self, label): return self._prettifier_f(label) -def prettify_labels(labels, format=None): # pylint: disable=redefined-builtin - """ - Prettify label for typesetting in various formats +def prettify_labels(labels, format=None): + """Prettify label for typesetting in various formats :param labels: a list of length-2 tuples, in the format(position, label) :param format: a string with the format for the prettifier (e.g. 'agr', @@ -453,9 +420,8 @@ def prettify_labels(labels, format=None): # pylint: disable=redefined-builtin return [(pos, prettifier.prettify(label)) for pos, label in labels] -def join_labels(labels, join_symbol='|', threshold=1.e-6): - """ - Join labels with a joining symbol when they are very close +def join_labels(labels, join_symbol='|', threshold=1.0e-6): + """Join labels with a joining symbol when they are very close :param labels: a list of length-2 tuples, in the format(position, label) :param join_symbol: the string to use to join different paths. By default, a pipe @@ -481,8 +447,7 @@ def join_labels(labels, join_symbol='|', threshold=1.e-6): def strip_prefix(full_string, prefix): - """ - Strip the prefix from the given string and return it. If the prefix is not present + """Strip the prefix from the given string and return it. If the prefix is not present the original string will be returned unaltered :param full_string: the string from which to remove the prefix @@ -496,8 +461,7 @@ def strip_prefix(full_string, prefix): class Capturing: - """ - This class captures stdout and returns it + """This class captures stdout and returns it (as a list, split by lines). Note: if you raise a SystemExit, you have to catch it outside. @@ -516,8 +480,6 @@ class Capturing: lines, use obj.stderr_lines. 
If False, obj.stderr_lines is None. """ - # pylint: disable=attribute-defined-outside-init - def __init__(self, capture_stderr=False): """Construct a new instance.""" self.stdout_lines = [] @@ -558,8 +520,7 @@ def __iter__(self): class ErrorAccumulator: - """ - Allows to run a number of functions and collect all the errors they raise + """Allows to run a number of functions and collect all the errors they raise This allows to validate multiple things and tell the user about all the errors encountered at once. Works best if the individual functions do not depend on each other. @@ -592,8 +553,7 @@ def raise_errors(self, raise_cls): class DatetimePrecision: - """ - A simple class which stores a datetime object with its precision. No + """A simple class which stores a datetime object with its precision. No internal check is done (cause itis not possible). precision: 1 (only full date) @@ -603,8 +563,7 @@ class DatetimePrecision: """ def __init__(self, dtobj, precision): - """ Constructor to check valid datetime object and precision """ - + """Constructor to check valid datetime object and precision""" if not isinstance(dtobj, datetime): raise TypeError('dtobj argument has to be a datetime object') diff --git a/aiida/common/warnings.py b/aiida/common/warnings.py index fac19444ab..52ef848777 100644 --- a/aiida/common/warnings.py +++ b/aiida/common/warnings.py @@ -13,8 +13,7 @@ class AiidaDeprecationWarning(Warning): - """ - Class for AiiDA deprecations. + """Class for AiiDA deprecations. It does *not* inherit, on purpose, from `DeprecationWarning` as this would be filtered out by default. @@ -25,15 +24,11 @@ class AiidaDeprecationWarning(Warning): class AiidaEntryPointWarning(Warning): - """ - Class for warnings concerning AiiDA entry points. - """ + """Class for warnings concerning AiiDA entry points.""" class AiidaTestWarning(Warning): - """ - Class for warnings concerning the AiiDA testing infrastructure. 
- """ + """Class for warnings concerning the AiiDA testing infrastructure.""" def warn_deprecation(message: str, version: int, stacklevel=2) -> None: diff --git a/aiida/engine/__init__.py b/aiida/engine/__init__.py index f9b4577579..1d9dc3e1ac 100644 --- a/aiida/engine/__init__.py +++ b/aiida/engine/__init__.py @@ -11,8 +11,7 @@ # AUTO-GENERATED -# yapf: disable -# pylint: disable=wildcard-import +# fmt: off from .daemon import * from .exceptions import * @@ -77,4 +76,4 @@ 'workfunction', ) -# yapf: enable +# fmt: on diff --git a/aiida/engine/daemon/__init__.py b/aiida/engine/daemon/__init__.py index 4012ec5d62..1ae6a5a38f 100644 --- a/aiida/engine/daemon/__init__.py +++ b/aiida/engine/daemon/__init__.py @@ -11,8 +11,7 @@ # AUTO-GENERATED -# yapf: disable -# pylint: disable=wildcard-import +# fmt: off from .client import * @@ -21,4 +20,4 @@ 'get_daemon_client', ) -# yapf: enable +# fmt: on diff --git a/aiida/engine/daemon/client.py b/aiida/engine/daemon/client.py index 7274bc2fb4..e0f2fc5831 100644 --- a/aiida/engine/daemon/client.py +++ b/aiida/engine/daemon/client.py @@ -80,7 +80,7 @@ def get_daemon_client(profile_name: str | None = None) -> 'DaemonClient': return DaemonClient(profile) -class DaemonClient: # pylint: disable=too-many-public-methods +class DaemonClient: """Client to interact with the daemon.""" _DAEMON_NAME = 'aiida-{name}' @@ -239,7 +239,6 @@ def get_circus_socket_directory(self) -> str: except (ValueError, IOError): raise RuntimeError('daemon is running so sockets file should have been there but could not read it') else: - # The SOCKET_DIRECTORY is already set, a temporary directory was already created and the same should be used if self._socket_directory is not None: return self._socket_directory @@ -519,7 +518,7 @@ def start_daemon( timeout = timeout or self._daemon_timeout try: - subprocess.check_output(command, env=env, stderr=subprocess.STDOUT) # pylint: disable=unexpected-keyword-arg + subprocess.check_output(command, env=env, stderr=subprocess.STDOUT) except subprocess.CalledProcessError as exception: raise DaemonException('The daemon failed to start.') from exception @@ -647,7 +646,6 @@ def _await_condition(condition: t.Callable, exception: Exception, timeout: int = start_time = time.time() while not condition(): - time.sleep(interval) if time.time() - start_time > timeout: @@ -687,23 +685,25 @@ def _start_daemon(self, number_workers: int = 1, foreground: bool = False) -> No 'debug': False, 'statsd': True, 'pidfile': self.circus_pid_file, - 'watchers': [{ - 'cmd': ' '.join(self.cmd_start_daemon_worker), - 'name': self.daemon_name, - 'numprocesses': number_workers, - 'virtualenv': self.virtualenv, - 'copy_env': True, - 'stdout_stream': { - 'class': 'FileStream', - 'filename': self.daemon_log_file, - }, - 'stderr_stream': { - 'class': 'FileStream', - 'filename': self.daemon_log_file, - }, - 'env': self.get_env(), - }] - } # yapf: disable + 'watchers': [ + { + 'cmd': ' '.join(self.cmd_start_daemon_worker), + 'name': self.daemon_name, + 'numprocesses': number_workers, + 'virtualenv': self.virtualenv, + 'copy_env': True, + 'stdout_stream': { + 'class': 'FileStream', + 'filename': self.daemon_log_file, + }, + 'stderr_stream': { + 'class': 'FileStream', + 'filename': self.daemon_log_file, + }, + 'env': self.get_env(), + } + ], + } if not foreground: daemonize() @@ -724,10 +724,10 @@ def _start_daemon(self, number_workers: int = 1, foreground: bool = False) -> No future = arbiter.start() should_restart = False if check_future_exception_and_log(future) is None: - 
should_restart = arbiter._restarting # pylint: disable=protected-access + should_restart = arbiter._restarting except Exception as exception: # Emergency stop - arbiter.loop.run_sync(arbiter._emergency_stop) # pylint: disable=protected-access + arbiter.loop.run_sync(arbiter._emergency_stop) raise exception except KeyboardInterrupt: pass diff --git a/aiida/engine/daemon/execmanager.py b/aiida/engine/daemon/execmanager.py index 5ff8a432a8..5a3e605533 100644 --- a/aiida/engine/daemon/execmanager.py +++ b/aiida/engine/daemon/execmanager.py @@ -7,23 +7,21 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -""" -This file contains the main routines to submit, check and retrieve calculation +"""This file contains the main routines to submit, check and retrieve calculation results. These are general and contain only the main logic; where appropriate, the routines make reference to the suitable plugins for all plugin-specific operations. """ from __future__ import annotations -from collections.abc import Mapping -from logging import LoggerAdapter import os import pathlib import shutil +from collections.abc import Mapping +from logging import LoggerAdapter from tempfile import NamedTemporaryFile +from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union from typing import Mapping as MappingType -from typing import Optional, Tuple, Union -from typing import TYPE_CHECKING, Any, List from aiida.common import AIIDA_LOGGER, exceptions from aiida.common.datastructures import CalcInfo @@ -70,7 +68,7 @@ def upload_calculation( calc_info: CalcInfo, folder: SandboxFolder, inputs: Optional[MappingType[str, Any]] = None, - dry_run: bool = False + dry_run: bool = False, ) -> None: """Upload a `CalcJob` instance @@ -79,8 +77,6 @@ def upload_calculation( :param calc_info: the calculation info datastructure returned by `CalcJob.presubmit` :param folder: temporary local file system folder containing the inputs written by `CalcJob.prepare_for_submission` """ - # pylint: disable=too-many-locals,too-many-branches,too-many-statements - # If the calculation already has a `remote_folder`, simply return. The upload was apparently already completed # before, which can happen if the daemon is restarted and it shuts down after uploading but before getting the # chance to perform the state transition. Upon reloading this calculation, it will re-attempt the upload. @@ -216,7 +212,6 @@ def upload_calculation( if data_node is None: logger.warning(f'failed to load Node<{uuid}> specified in the `local_copy_list`') else: - # If no explicit source filename is defined, we assume the top-level directory filename_source = filename or '.' 
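# --- Illustration (editor's addition, not part of the patch) ----------------
# The `local_copy_list` consumed in the hunk above is a list of
# (node_uuid, source_relpath, target_relpath) triples; leaving the source
# path unset falls back to '.' and copies the node's whole top-level
# directory. A hedged sketch of a plugin populating it; the directory path
# and file name here are made up:
from aiida.common.datastructures import CalcInfo
from aiida.orm import FolderData

folder = FolderData(tree='/path/to/local/inputs').store()  # hypothetical input files
calc_info = CalcInfo()
calc_info.local_copy_list = [(folder.uuid, 'data.txt', 'data.txt')]  # copy a single file
# -----------------------------------------------------------------------------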
filename_target = target or '' @@ -246,7 +241,7 @@ def upload_calculation( logger.debug(f'[submission of calculation {node.pk}] copying file/folder {filename}...') transport.put(folder.get_abs_path(filename), filename) - for (remote_computer_uuid, remote_abs_path, dest_rel_path) in remote_copy_list: + for remote_computer_uuid, remote_abs_path, dest_rel_path in remote_copy_list: if remote_computer_uuid == computer.uuid: logger.debug( f'[submission of calculation {node.pk}] copying {dest_rel_path} ' @@ -271,7 +266,7 @@ def upload_calculation( 'not implemented yet' ) - for (remote_computer_uuid, remote_abs_path, dest_rel_path) in remote_symlink_list: + for remote_computer_uuid, remote_abs_path, dest_rel_path in remote_symlink_list: if remote_computer_uuid == computer.uuid: logger.debug( f'[submission of calculation {node.pk}] copying {dest_rel_path} remotely, ' @@ -292,7 +287,6 @@ def upload_calculation( f'It is not possible to create a symlink between two different machines for calculation {node.pk}' ) else: - if remote_copy_list: filepath = os.path.join(workdir, '_aiida_remote_copy_list.txt') with open(filepath, 'w', encoding='utf-8') as handle: # type: ignore[assignment] @@ -339,12 +333,12 @@ def upload_calculation( dirname not in provenance_exclude_list for dirname in dirnames ): with open(filepath, 'rb') as handle: # type: ignore[assignment] - node.base.repository._repository.put_object_from_filelike(handle, relpath) # pylint: disable=protected-access + node.base.repository._repository.put_object_from_filelike(handle, relpath) # Since the node is already stored, we cannot use the normal repository interface since it will raise a # `ModificationNotAllowed` error. To bypass it, we go straight to the underlying repository instance to store the # files, however, this means we have to manually update the node's repository metadata. - node.base.repository._update_repository_metadata() # pylint: disable=protected-access + node.base.repository._update_repository_metadata() if not dry_run: # Make sure that attaching the `remote_folder` with a link is the last thing we do. This gives the biggest @@ -423,7 +417,6 @@ def stash_calculation(calculation: CalcJobNode, transport: Transport) -> None: target_basepath = pathlib.Path(stash_options['target_base']) / uuid[:2] / uuid[2:4] / uuid[4:] for source_filename in source_list: - if transport.has_magic(source_filename): copy_instructions = [] for globbed_filename in transport.glob(str(source_basepath / source_filename)): @@ -505,7 +498,7 @@ def retrieve_calculation(calculation: CalcJobNode, transport: Transport, retriev for filename in os.listdir(retrieved_temporary_folder): EXEC_LOGGER.debug( f"[retrieval of calc {calculation.pk}] Retrieved temporary file or folder '{filename}'", - extra=logger_extra + extra=logger_extra, ) # Store everything @@ -523,8 +516,7 @@ def retrieve_calculation(calculation: CalcJobNode, transport: Transport, retriev def kill_calculation(calculation: CalcJobNode, transport: Transport) -> None: - """ - Kill the calculation through the scheduler + """Kill the calculation through the scheduler :param calculation: the instance of CalcJobNode to kill. 
:param transport: an already opened transport to use to address the scheduler @@ -543,7 +535,6 @@ def kill_calculation(calculation: CalcJobNode, transport: Transport) -> None: result = scheduler.kill(job_id) if result is not True: - # Failed to kill because the job might have already been completed running_jobs = scheduler.get_jobs(jobs=[job_id], as_dict=True) job = running_jobs.get(job_id, None) @@ -558,11 +549,12 @@ def kill_calculation(calculation: CalcJobNode, transport: Transport) -> None: def retrieve_files_from_list( - calculation: CalcJobNode, transport: Transport, folder: str, retrieve_list: List[Union[str, Tuple[str, str, int], - list]] + calculation: CalcJobNode, + transport: Transport, + folder: str, + retrieve_list: List[Union[str, Tuple[str, str, int], list]], ) -> None: - """ - Retrieve all the files in the retrieve_list from the remote into the + """Retrieve all the files in the retrieve_list from the remote into the local folder instance through the transport. The entries in the retrieve_list can be of two types: @@ -584,7 +576,6 @@ def retrieve_files_from_list( :param folder: an absolute path to a folder that contains the files to copy. :param retrieve_list: the list of files to retrieve. """ - # pylint: disable=too-many-branches for item in retrieve_list: if isinstance(item, (list, tuple)): tmp_rname, tmp_lname, depth = item @@ -607,13 +598,12 @@ def retrieve_files_from_list( new_folder = os.path.join(folder, os.path.split(this_local_file)[0]) if not os.path.exists(new_folder): os.makedirs(new_folder) - else: # it is a string - if transport.has_magic(item): - remote_names = transport.glob(item) - local_names = [os.path.split(rem)[1] for rem in remote_names] - else: - remote_names = [item] - local_names = [os.path.split(item)[1]] + elif transport.has_magic(item): # it is a string + remote_names = transport.glob(item) + local_names = [os.path.split(rem)[1] for rem in remote_names] + else: + remote_names = [item] + local_names = [os.path.split(item)[1]] for rem, loc in zip(remote_names, local_names): transport.logger.debug(f"[retrieval of calc {calculation.pk}] Trying to retrieve remote item '{rem}'") diff --git a/aiida/engine/daemon/worker.py b/aiida/engine/daemon/worker.py index f90362dff9..78308015b9 100644 --- a/aiida/engine/daemon/worker.py +++ b/aiida/engine/daemon/worker.py @@ -55,7 +55,7 @@ def start_daemon_worker() -> None: sys.setrecursionlimit(rlimit) signals = (signal.SIGTERM, signal.SIGINT) - for s in signals: # pylint: disable=invalid-name + for s in signals: runner.loop.add_signal_handler(s, lambda s=s: asyncio.create_task(shutdown_worker(runner))) try: diff --git a/aiida/engine/launch.py b/aiida/engine/launch.py index 0fc72b8762..26c2ae430c 100644 --- a/aiida/engine/launch.py +++ b/aiida/engine/launch.py @@ -23,13 +23,13 @@ from .processes.functions import FunctionProcess from .processes.process import Process from .runners import ResultAndPk -from .utils import instantiate_process, is_process_scoped, prepare_inputs # pylint: disable=no-name-in-module +from .utils import instantiate_process, is_process_scoped, prepare_inputs __all__ = ('run', 'run_get_pk', 'run_get_node', 'submit', 'await_processes') -TYPE_RUN_PROCESS = t.Union[Process, t.Type[Process], ProcessBuilder] # pylint: disable=invalid-name +TYPE_RUN_PROCESS = t.Union[Process, t.Type[Process], ProcessBuilder] # run can also be process function, but it is not clear what type this should be -TYPE_SUBMIT_PROCESS = t.Union[Process, t.Type[Process], ProcessBuilder] # pylint: disable=invalid-name 
+TYPE_SUBMIT_PROCESS = t.Union[Process, t.Type[Process], ProcessBuilder] LOGGER = AIIDA_LOGGER.getChild('engine.launch') @@ -48,9 +48,9 @@ def run(process: TYPE_RUN_PROCESS, inputs: dict[str, t.Any] | None = None, **kwa return runner.run(process, inputs, **kwargs) -def run_get_node(process: TYPE_RUN_PROCESS, - inputs: dict[str, t.Any] | None = None, - **kwargs: t.Any) -> tuple[dict[str, t.Any], ProcessNode]: +def run_get_node( + process: TYPE_RUN_PROCESS, inputs: dict[str, t.Any] | None = None, **kwargs: t.Any +) -> tuple[dict[str, t.Any], ProcessNode]: """Run the process with the supplied inputs in a local runner that will block until the process is completed. :param process: the process class, instance, builder or function to run @@ -86,7 +86,7 @@ def submit( *, wait: bool = False, wait_interval: int = 5, - **kwargs: t.Any + **kwargs: t.Any, ) -> ProcessNode: """Submit the process with the supplied inputs to the daemon immediately returning control to the interpreter. diff --git a/aiida/engine/persistence.py b/aiida/engine/persistence.py index 4a6168b589..6ef41cd282 100644 --- a/aiida/engine/persistence.py +++ b/aiida/engine/persistence.py @@ -7,7 +7,6 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=global-statement """Definition of AiiDA's process persister and the necessary object loaders.""" import importlib @@ -15,9 +14,9 @@ import traceback from typing import TYPE_CHECKING, Any, Hashable, Optional -from plumpy.exceptions import PersistenceError import plumpy.loaders import plumpy.persistence +from plumpy.exceptions import PersistenceError from aiida.orm.utils import serialize @@ -61,7 +60,7 @@ def get_object_loader() -> ObjectLoader: :return: The global object loader """ - global OBJECT_LOADER + global OBJECT_LOADER # noqa: PLW0603 if OBJECT_LOADER is None: OBJECT_LOADER = ObjectLoader() return OBJECT_LOADER @@ -140,7 +139,7 @@ def get_process_checkpoints(self, pid: Hashable): :return: list of PersistedCheckpoint tuples with element containing the process id and optional checkpoint tag. """ - def delete_checkpoint(self, pid: Hashable, tag: Optional[str] = None) -> None: # pylint: disable=unused-argument + def delete_checkpoint(self, pid: Hashable, tag: Optional[str] = None) -> None: """Delete a persisted process checkpoint, where no error will be raised if the checkpoint does not exist. 
:param pid: the process id of the :class:`plumpy.Process` diff --git a/aiida/engine/processes/__init__.py b/aiida/engine/processes/__init__.py index 20668be208..07c9d318fd 100644 --- a/aiida/engine/processes/__init__.py +++ b/aiida/engine/processes/__init__.py @@ -11,8 +11,7 @@ # AUTO-GENERATED -# yapf: disable -# pylint: disable=wildcard-import +# fmt: off from .builder import * from .calcjobs import * @@ -64,4 +63,4 @@ 'workfunction', ) -# yapf: enable +# fmt: on diff --git a/aiida/engine/processes/builder.py b/aiida/engine/processes/builder.py index 9362288d26..2bd5b68a61 100644 --- a/aiida/engine/processes/builder.py +++ b/aiida/engine/processes/builder.py @@ -8,8 +8,8 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Convenience classes to help building the input dictionaries for Processes.""" -from collections.abc import Mapping, MutableMapping import json +from collections.abc import Mapping, MutableMapping from typing import TYPE_CHECKING, Any, Type from uuid import uuid4 @@ -28,7 +28,7 @@ class PrettyEncoder(json.JSONEncoder): """JSON encoder for returning a pretty representation of an AiiDA ``ProcessBuilder``.""" - def default(self, o): # pylint: disable=arguments-differ + def default(self, o): if isinstance(o, (ProcessBuilder, ProcessBuilderNamespace)): return dict(o) if isinstance(o, Dict): @@ -55,7 +55,6 @@ def __init__(self, port_namespace: PortNamespace) -> None: :param port_namespace: the inputs PortNamespace for which to construct the builder """ - # pylint: disable=super-init-not-called self._port_namespace = port_namespace self._valid_fields = [] self._data = {} @@ -67,7 +66,6 @@ def __init__(self, port_namespace: PortNamespace) -> None: # saved. If they are used directly in the body, it will try to capture the value from # its enclosing scope at the time of being called. for name, port in port_namespace.items(): - self._valid_fields.append(name) if isinstance(port, PortNamespace): @@ -77,7 +75,7 @@ def fgetter(self, name=name): return self._data.get(name) elif port.has_default(): - def fgetter(self, name=name, default=port.default): # type: ignore[misc] # pylint: disable=cell-var-from-loop + def fgetter(self, name=name, default=port.default): # type: ignore[misc] return self._data.get(name, default) else: @@ -89,7 +87,7 @@ def fsetter(self, value, name=name): fgetter.__doc__ = str(port) getter = property(fgetter) - getter.setter(fsetter) # pylint: disable=too-many-function-args + getter.setter(fsetter) dynamic_properties[name] = getter # The dynamic property can only be attached to a class and not an instance, however, we cannot attach it to @@ -204,13 +202,13 @@ def _update(self, *args, **kwds): if isinstance(value, Mapping): self[key].update(value) else: - self.__setattr__(key, value) # pylint: disable=unnecessary-dunder-call + self.__setattr__(key, value) for key, value in kwds.items(): if isinstance(value, Mapping): self[key].update(value) else: - self.__setattr__(key, value) # pylint: disable=unnecessary-dunder-call + self.__setattr__(key, value) def _inputs(self, prune: bool = False) -> dict: """Return the entire mapping of inputs specified for this builder. 
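# --- Illustration (editor's addition, not part of the patch) ----------------
# The dynamically generated properties above are what make attribute-style
# input assignment on a builder work. A hedged sketch using the
# `core.arithmetic.add` plugin that ships with aiida-core; the values are
# made up and a configured profile is assumed:
from aiida import load_profile, orm
from aiida.plugins import CalculationFactory

load_profile()

builder = CalculationFactory('core.arithmetic.add').get_builder()
builder.x = orm.Int(1)  # handled by the generated fsetter
builder.y = orm.Int(2)
builder.metadata.options.resources = {'num_machines': 1}  # nested namespace access
print(builder._inputs(prune=True))  # only the inputs that were actually set
# -----------------------------------------------------------------------------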
@@ -224,7 +222,7 @@ def _inputs(self, prune: bool = False) -> dict: return dict(self) -class ProcessBuilder(ProcessBuilderNamespace): # pylint: disable=too-many-ancestors +class ProcessBuilder(ProcessBuilderNamespace): """A process builder that helps setting up the inputs for creating a new process.""" def __init__(self, process_class: Type['Process']): @@ -241,7 +239,7 @@ def process_class(self) -> Type['Process']: """Return the process class for which this builder is constructed.""" return self._process_class - def _repr_pretty_(self, p, _) -> str: # pylint: disable=invalid-name + def _repr_pretty_(self, p, _) -> str: """Pretty representation for in the IPython console and notebooks.""" import yaml diff --git a/aiida/engine/processes/calcjobs/__init__.py b/aiida/engine/processes/calcjobs/__init__.py index 77686c9969..3f174ea7c4 100644 --- a/aiida/engine/processes/calcjobs/__init__.py +++ b/aiida/engine/processes/calcjobs/__init__.py @@ -11,8 +11,7 @@ # AUTO-GENERATED -# yapf: disable -# pylint: disable=wildcard-import +# fmt: off from .calcjob import * from .importer import * @@ -25,4 +24,4 @@ 'JobsList', ) -# yapf: enable +# fmt: on diff --git a/aiida/engine/processes/calcjobs/calcjob.py b/aiida/engine/processes/calcjobs/calcjob.py index c4fca3dd5e..0940d3654b 100644 --- a/aiida/engine/processes/calcjobs/calcjob.py +++ b/aiida/engine/processes/calcjobs/calcjob.py @@ -7,7 +7,6 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=too-many-lines """Implementation of the CalcJob process.""" from __future__ import annotations @@ -39,7 +38,7 @@ __all__ = ('CalcJob',) -def validate_calc_job(inputs: Any, ctx: PortNamespace) -> Optional[str]: # pylint: disable=too-many-return-statements +def validate_calc_job(inputs: Any, ctx: PortNamespace) -> Optional[str]: """Validate the entire set of inputs passed to the `CalcJob` constructor. Reasons that will cause this validation to raise an `InputValidationError`: @@ -82,9 +81,7 @@ def validate_calc_job(inputs: Any, ctx: PortNamespace) -> Optional[str]: # pyli if computer_from_code and computer_from_metadata and computer_from_code.uuid != computer_from_metadata.uuid: return ( 'Computer<{}> explicitly defined in `metadata.computer` is different from Computer<{}> which is the ' - 'computer of Code<{}> defined as the `code` input.'.format( - computer_from_metadata, computer_from_code, code - ) + 'computer of Code<{}> defined as the `code` input.'.format(computer_from_metadata, computer_from_code, code) ) try: @@ -124,9 +121,8 @@ def validate_stash_options(stash_options: Any, _: Any) -> Optional[str]: if not isinstance(target_base, str) or not os.path.isabs(target_base): return f'`metadata.options.stash.target_base` should be an absolute filepath, got: {target_base}' - if ( - not isinstance(source_list, (list, tuple)) or - any(not isinstance(src, str) or os.path.isabs(src) for src in source_list) + if not isinstance(source_list, (list, tuple)) or any( + not isinstance(src, str) or os.path.isabs(src) for src in source_list ): port = 'metadata.options.stash.source_list' return f'`{port}` should be a list or tuple of relative filepaths, got: {source_list}' @@ -212,7 +208,7 @@ def define(cls, spec: CalcJobProcessSpec) -> None: # type: ignore[override] valid_type=orm.AbstractCode, required=False, help='The `Code` to use for this job. 
This input is required, unless the `remote_folder` input is ' - 'specified, which means an existing job is being imported and no code will actually be run.' + 'specified, which means an existing job is being imported and no code will actually be run.', ) spec.input_namespace( 'monitors', @@ -220,7 +216,7 @@ def define(cls, spec: CalcJobProcessSpec) -> None: # type: ignore[override] required=False, validator=validate_monitors, help='Add monitoring functions that can inspect output files while the job is running and decide to ' - 'prematurely terminate the job.' + 'prematurely terminate the job.', ) spec.input( 'remote_folder', @@ -230,50 +226,50 @@ def define(cls, spec: CalcJobProcessSpec) -> None: # type: ignore[override] 'inputs should be passed to the `CalcJob` as normal but instead of launching the actual job, the ' 'engine will recreate the input files and then proceed straight to the retrieve step where the files ' 'of this `RemoteData` will be retrieved as if it had been actually launched through AiiDA. If a ' - 'parser is defined in the inputs, the results are parsed and attached as output nodes as usual.' + 'parser is defined in the inputs, the results are parsed and attached as output nodes as usual.', ) spec.input( 'metadata.dry_run', valid_type=bool, default=False, - help='When set to `True` will prepare the calculation job for submission but not actually launch it.' + help='When set to `True` will prepare the calculation job for submission but not actually launch it.', ) spec.input( 'metadata.computer', valid_type=orm.Computer, required=False, - help='When using a "local" code, set the computer on which the calculation should be run.' + help='When using a "local" code, set the computer on which the calculation should be run.', ) spec.input_namespace(f'{spec.metadata_key}.{spec.options_key}', required=False) spec.input( 'metadata.options.input_filename', valid_type=str, required=False, - help='Filename to which the input for the code that is to be run is written.' + help='Filename to which the input for the code that is to be run is written.', ) spec.input( 'metadata.options.output_filename', valid_type=str, required=False, - help='Filename to which the content of stdout of the code that is to be run is written.' + help='Filename to which the content of stdout of the code that is to be run is written.', ) spec.input( 'metadata.options.submit_script_filename', valid_type=str, default='_aiidasubmit.sh', - help='Filename to which the job submission script is written.' + help='Filename to which the job submission script is written.', ) spec.input( 'metadata.options.scheduler_stdout', valid_type=str, default='_scheduler-stdout.txt', - help='Filename to which the content of stdout of the scheduler is written.' + help='Filename to which the content of stdout of the scheduler is written.', ) spec.input( 'metadata.options.scheduler_stderr', valid_type=str, default='_scheduler-stderr.txt', - help='Filename to which the content of stderr of the scheduler is written.' + help='Filename to which the content of stderr of the scheduler is written.', ) spec.input( 'metadata.options.resources', @@ -281,13 +277,13 @@ def define(cls, spec: CalcJobProcessSpec) -> None: # type: ignore[override] required=True, help='Set the dictionary of resources to be used by the scheduler plugin, like the number of nodes, ' 'cpus etc. This dictionary is scheduler-plugin dependent. Look at the documentation of the ' - 'scheduler for more details.' 
+ 'scheduler for more details.', ) spec.input( 'metadata.options.max_wallclock_seconds', valid_type=int, required=False, - help='Set the wallclock in seconds asked to the scheduler' + help='Set the wallclock in seconds asked to the scheduler', ) spec.input( 'metadata.options.custom_scheduler_commands', @@ -296,31 +292,31 @@ def define(cls, spec: CalcJobProcessSpec) -> None: # type: ignore[override] help='Set a (possibly multiline) string with the commands that the user wants to manually set for the ' 'scheduler. The difference of this option with respect to the `prepend_text` is the position in ' 'the scheduler submission file where such text is inserted: with this option, the string is ' - 'inserted before any non-scheduler command' + 'inserted before any non-scheduler command', ) spec.input( 'metadata.options.queue_name', valid_type=str, required=False, - help='Set the name of the queue on the remote computer' + help='Set the name of the queue on the remote computer', ) spec.input( 'metadata.options.rerunnable', valid_type=bool, required=False, - help='Determines if the calculation can be requeued / rerun.' + help='Determines if the calculation can be requeued / rerun.', ) spec.input( 'metadata.options.account', valid_type=str, required=False, - help='Set the account to use in for the queue on the remote computer' + help='Set the account to use in for the queue on the remote computer', ) spec.input( 'metadata.options.qos', valid_type=str, required=False, - help='Set the quality of service to use in for the queue on the remote computer' + help='Set the quality of service to use in for the queue on the remote computer', ) spec.input( 'metadata.options.withmpi', @@ -355,16 +351,13 @@ def define(cls, spec: CalcJobProcessSpec) -> None: # type: ignore[override] 'specified in ``environment_variables``.', ) spec.input( - 'metadata.options.priority', - valid_type=str, - required=False, - help='Set the priority of the job to be queued' + 'metadata.options.priority', valid_type=str, required=False, help='Set the priority of the job to be queued' ) spec.input( 'metadata.options.max_memory_kb', valid_type=int, required=False, - help='Set the maximum memory (in KiloBytes) to be asked to the scheduler' + help='Set the maximum memory (in KiloBytes) to be asked to the scheduler', ) spec.input( 'metadata.options.prepend_text', @@ -385,66 +378,66 @@ def define(cls, spec: CalcJobProcessSpec) -> None: # type: ignore[override] valid_type=str, required=False, validator=validate_parser, - help='Set a string for the output parser. Can be None if no output plugin is available or needed' + help='Set a string for the output parser. Can be None if no output plugin is available or needed', ) spec.input( 'metadata.options.additional_retrieve_list', required=False, valid_type=(list, tuple), validator=validate_additional_retrieve_list, - help='List of relative file paths that should be retrieved in addition to what the plugin specifies.' + help='List of relative file paths that should be retrieved in addition to what the plugin specifies.', ) spec.input_namespace( 'metadata.options.stash', required=False, populate_defaults=False, validator=validate_stash_options, - help='Optional directives to stash files after the calculation job has completed.' + help='Optional directives to stash files after the calculation job has completed.', ) spec.input( 'metadata.options.stash.target_base', valid_type=str, required=False, help='The base location to where the files should be stashd. 
For example, for the `copy` stash mode, this ' - 'should be an absolute filepath on the remote computer.' + 'should be an absolute filepath on the remote computer.', ) spec.input( 'metadata.options.stash.source_list', valid_type=(tuple, list), required=False, - help='Sequence of relative filepaths representing files in the remote directory that should be stashed.' + help='Sequence of relative filepaths representing files in the remote directory that should be stashed.', ) spec.input( 'metadata.options.stash.stash_mode', valid_type=str, required=False, - help='Mode with which to perform the stashing, should be value of `aiida.common.datastructures.StashMode`.' + help='Mode with which to perform the stashing, should be value of `aiida.common.datastructures.StashMode`.', ) spec.output( 'remote_folder', valid_type=orm.RemoteData, - help='Input files necessary to run the process will be stored in this folder node.' + help='Input files necessary to run the process will be stored in this folder node.', ) spec.output( 'remote_stash', valid_type=orm.RemoteStashData, required=False, - help='Contents of the `stash.source_list` option are stored in this remote folder after job completion.' + help='Contents of the `stash.source_list` option are stored in this remote folder after job completion.', ) spec.output( cls.link_label_retrieved, valid_type=orm.FolderData, pass_to_parser=True, help='Files that are retrieved by the daemon will be stored in this node. By default the stdout and stderr ' - 'of the scheduler will be added, but one can add more by specifying them in `CalcInfo.retrieve_list`.' + 'of the scheduler will be added, but one can add more by specifying them in `CalcInfo.retrieve_list`.', ) spec.exit_code( 100, 'ERROR_NO_RETRIEVED_FOLDER', invalidates_cache=True, - message='The process did not have the required `retrieved` output.' + message='The process did not have the required `retrieved` output.', ) spec.exit_code( 110, 'ERROR_SCHEDULER_OUT_OF_MEMORY', invalidates_cache=True, message='The job ran out of memory.' @@ -461,13 +454,13 @@ def define(cls, spec: CalcJobProcessSpec) -> None: # type: ignore[override] spec.exit_code(150, 'STOPPED_BY_MONITOR', invalidates_cache=True, message='{message}') @classproperty - def spec_options(cls): # pylint: disable=no-self-argument + def spec_options(cls): # noqa: N805 """Return the metadata options port namespace of the process specification of this process. :return: options dictionary :rtype: dict """ - return cls.spec_metadata['options'] # pylint: disable=unsubscriptable-object + return cls.spec_metadata['options'] @classmethod def get_importer(cls, entry_point_name: str | None = None) -> CalcJobImporter: @@ -556,7 +549,7 @@ def run(self) -> Union[plumpy.process_states.Stop, int, plumpy.process_states.Wa # this case, the parser will not be called. The outputs will already have been added to the process node # though, so all that needs to be done here is just also assign them to the process instance. This such that # when the process returns its results, it returns the actual outputs and not an empty dictionary. 
- self._outputs = self.node.get_outgoing(link_type=LinkType.CREATE).nested() # pylint: disable=attribute-defined-outside-init + self._outputs = self.node.get_outgoing(link_type=LinkType.CREATE).nested() return self.node.exit_status # Launch the upload operation @@ -618,7 +611,7 @@ def _perform_dry_run(self): upload_calculation(self.node, transport, calc_info, folder, inputs=self.inputs, dry_run=True) self.node.dry_run_info = { # type: ignore[attr-defined] 'folder': folder.abspath, - 'script_filename': self.node.get_option('submit_script_filename') + 'script_filename': self.node.get_option('submit_script_filename'), } def _perform_import(self): @@ -659,7 +652,7 @@ def parse( try: retrieved = self.node.outputs.retrieved except exceptions.NotExistent: - return self.exit_codes.ERROR_NO_RETRIEVED_FOLDER # pylint: disable=no-member + return self.exit_codes.ERROR_NO_RETRIEVED_FOLDER # Call the scheduler output parser exit_code_scheduler = self.parse_scheduler_output(retrieved) @@ -763,7 +756,7 @@ def parse_scheduler_output(self, retrieved: orm.Node) -> Optional[ExitCode]: except exceptions.FeatureNotAvailable: self.logger.info(f'`{scheduler.__class__.__name__}` does not implement scheduler output parsing') return None - except Exception as exception: # pylint: disable=broad-except + except Exception as exception: self.logger.error(f'the `parse_output` method of the scheduler excepted: {exception}') return None @@ -793,7 +786,7 @@ def parse_retrieved_output(self, retrieved_temporary_folder: Optional[str] = Non self.out(link_label, node) except ValueError as exception: self.logger.error(f'invalid value {node} specified with label {link_label}: {exception}') - exit_code = self.exit_codes.ERROR_INVALID_OUTPUT # pylint: disable=no-member + exit_code = self.exit_codes.ERROR_INVALID_OUTPUT break if exit_code is not None and not isinstance(exit_code, ExitCode): @@ -810,7 +803,6 @@ def presubmit(self, folder: Folder) -> CalcInfo: :return calcinfo: the CalcInfo object containing the information needed by the daemon to handle operations. 
""" - # pylint: disable=too-many-locals,too-many-statements,too-many-branches from aiida.common.datastructures import CodeInfo, CodeRunMode from aiida.common.exceptions import InputValidationError, InvalidOperation, PluginInternalError, ValidationError from aiida.common.utils import validate_list_of_string_tuples @@ -856,10 +848,10 @@ def presubmit(self, folder: Folder) -> CalcInfo: # Set retrieve path, add also scheduler STDOUT and STDERR retrieve_list = calc_info.retrieve_list or [] - if (job_tmpl.sched_output_path is not None and job_tmpl.sched_output_path not in retrieve_list): + if job_tmpl.sched_output_path is not None and job_tmpl.sched_output_path not in retrieve_list: retrieve_list.append(job_tmpl.sched_output_path) if not job_tmpl.sched_join_files: - if (job_tmpl.sched_error_path is not None and job_tmpl.sched_error_path not in retrieve_list): + if job_tmpl.sched_error_path is not None and job_tmpl.sched_error_path not in retrieve_list: retrieve_list.append(job_tmpl.sched_error_path) retrieve_list.extend(self.node.get_option('additional_retrieve_list') or []) self.node.set_retrieve_list(retrieve_list) @@ -881,14 +873,18 @@ def presubmit(self, folder: Folder) -> CalcInfo: # - most importantly, skips the cases in which one of the methods # would return None, in which case the join method would raise # an exception - prepend_texts = [computer.get_prepend_text()] + \ - [code.prepend_text for code in codes] + \ - [calc_info.prepend_text, self.node.get_option('prepend_text')] + prepend_texts = ( + [computer.get_prepend_text()] + + [code.prepend_text for code in codes] + + [calc_info.prepend_text, self.node.get_option('prepend_text')] + ) job_tmpl.prepend_text = '\n\n'.join(prepend_text for prepend_text in prepend_texts if prepend_text) - append_texts = [self.node.get_option('append_text'), calc_info.append_text] + \ - [code.append_text for code in codes] + \ - [computer.get_append_text()] + append_texts = ( + [self.node.get_option('append_text'), calc_info.append_text] + + [code.append_text for code in codes] + + [computer.get_append_text()] + ) job_tmpl.append_text = '\n\n'.join(append_text for append_text in append_texts if append_text) # Set resources, also with get_default_mpiprocs_per_machine @@ -909,7 +905,6 @@ def presubmit(self, folder: Folder) -> CalcInfo: tmpl_codes_info = [] for code_info in calc_info.codes_info: - if not isinstance(code_info, CodeInfo): raise PluginInternalError('Invalid codes_info, must be a list of CodeInfo objects') @@ -1050,9 +1045,9 @@ def encoder(obj): f'[presubmission of calc {this_pk}] remote_copy_list format problem: {exception}' ) from exception - for (remote_computer_uuid, _, dest_rel_path) in remote_copy_list: + for remote_computer_uuid, _, dest_rel_path in remote_copy_list: try: - Computer.collection.get(uuid=remote_computer_uuid) # pylint: disable=unused-variable + Computer.collection.get(uuid=remote_computer_uuid) except exceptions.NotExistent as exception: raise PluginInternalError( '[presubmission of calc {}] ' @@ -1062,9 +1057,9 @@ def encoder(obj): ) from exception if os.path.isabs(dest_rel_path): raise PluginInternalError( - '[presubmission of calc {}] ' - 'The destination path of the remote copy ' - 'is absolute! ({})'.format(this_pk, dest_rel_path) + '[presubmission of calc {}] ' 'The destination path of the remote copy ' 'is absolute! 
({})'.format( + this_pk, dest_rel_path + ) ) return calc_info diff --git a/aiida/engine/processes/calcjobs/manager.py b/aiida/engine/processes/calcjobs/manager.py index b04381e4d6..4b6aa0a323 100644 --- a/aiida/engine/processes/calcjobs/manager.py +++ b/aiida/engine/processes/calcjobs/manager.py @@ -201,10 +201,7 @@ async def updating(): @staticmethod def _has_job_state_changed(old: Optional['JobInfo'], new: Optional['JobInfo']) -> bool: - """Return whether the states `old` and `new` are different. - - - """ + """Return whether the states `old` and `new` are different.""" if old is None and new is None: return False @@ -225,13 +222,13 @@ def _get_next_update_delay(self) -> float: """ if self.last_updated is None: # Never updated, so do it straight away - return 0. + return 0.0 # Make sure to actually 'get' the minimum interval here, in case the user changed since last time minimum_interval = self.get_minimum_update_interval() elapsed = time.time() - self.last_updated - delay = max(minimum_interval - elapsed, 0.) + delay = max(minimum_interval - elapsed, 0.0) return delay diff --git a/aiida/engine/processes/calcjobs/monitors.py b/aiida/engine/processes/calcjobs/monitors.py index 9bc08b6a0f..d41d223c20 100644 --- a/aiida/engine/processes/calcjobs/monitors.py +++ b/aiida/engine/processes/calcjobs/monitors.py @@ -4,10 +4,10 @@ import collections import dataclasses -from datetime import datetime, timedelta import enum import inspect import typing as t +from datetime import datetime, timedelta from aiida.common.lang import type_check from aiida.common.log import AIIDA_LOGGER @@ -182,14 +182,14 @@ def process( :returns: ``None`` or a monitor result. """ for key, monitor in self.monitors.items(): - if monitor.disabled: LOGGER.debug(f'monitor`{key}` is disabled, skipping') continue if ( - monitor.minimum_poll_interval and monitor.call_timestamp and - datetime.now() - monitor.call_timestamp < timedelta(seconds=monitor.minimum_poll_interval) + monitor.minimum_poll_interval + and monitor.call_timestamp + and datetime.now() - monitor.call_timestamp < timedelta(seconds=monitor.minimum_poll_interval) ): LOGGER.debug(f'skipping monitor `{key}` because minimum poll interval has not expired yet.') continue diff --git a/aiida/engine/processes/calcjobs/tasks.py b/aiida/engine/processes/calcjobs/tasks.py index 5d7e14eaf2..149521b91b 100644 --- a/aiida/engine/processes/calcjobs/tasks.py +++ b/aiida/engine/processes/calcjobs/tasks.py @@ -48,10 +48,10 @@ RETRY_INTERVAL_OPTION = 'transport.task_retry_initial_interval' MAX_ATTEMPTS_OPTION = 'transport.task_maximum_attempts' -logger = logging.getLogger(__name__) # pylint: disable=invalid-name +logger = logging.getLogger(__name__) -class PreSubmitException(Exception): +class PreSubmitException(Exception): # noqa: N818 """Raise in the `do_upload` coroutine when an exception is raised in `CalcJob.presubmit`.""" @@ -89,7 +89,7 @@ async def do_upload(): # Any exception thrown in `presubmit` call is not transient so we circumvent the exponential backoff try: calc_info = process.presubmit(folder) - except Exception as exception: # pylint: disable=broad-except + except Exception as exception: raise PreSubmitException('exception occurred in presubmit call') from exception else: execmanager.upload_calculation(node, transport, calc_info, folder) @@ -105,7 +105,7 @@ async def do_upload(): ) except PreSubmitException: raise - except (plumpy.futures.CancelledError, plumpy.process_states.Interruption): # pylint: disable=try-except-raise + except (plumpy.futures.CancelledError, 
plumpy.process_states.Interruption): raise except Exception as exception: logger.warning(f'uploading CalcJob<{node.pk}> failed') @@ -151,7 +151,7 @@ async def do_submit(): result = await exponential_backoff_retry( do_submit, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions ) - except (plumpy.futures.CancelledError, plumpy.process_states.Interruption): # pylint: disable=try-except-raise + except (plumpy.futures.CancelledError, plumpy.process_states.Interruption): raise except Exception as exception: logger.warning(f'submitting CalcJob<{node.pk}> failed') @@ -209,7 +209,7 @@ async def do_update(): job_done = await exponential_backoff_retry( do_update, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions ) - except (plumpy.futures.CancelledError, plumpy.process_states.Interruption): # pylint: disable=try-except-raise + except (plumpy.futures.CancelledError, plumpy.process_states.Interruption): raise except Exception as exception: logger.warning(f'updating CalcJob<{node.pk}> failed') @@ -249,7 +249,6 @@ async def task_monitor_job( authinfo = node.get_authinfo() async def do_monitor(): - with transport_queue.request_transport(authinfo) as request: transport = await cancellable.with_interrupt(request) transport.chdir(node.get_remote_workdir()) @@ -261,7 +260,7 @@ async def do_monitor(): monitor_result = await exponential_backoff_retry( do_monitor, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions ) - except (plumpy.futures.CancelledError, plumpy.process_states.Interruption): # pylint: disable=try-except-raise + except (plumpy.futures.CancelledError, plumpy.process_states.Interruption): raise except Exception as exception: logger.warning(f'monitoring CalcJob<{node.pk}> failed') @@ -272,8 +271,10 @@ async def do_monitor(): async def task_retrieve_job( - node: CalcJobNode, transport_queue: TransportQueue, retrieved_temporary_folder: str, - cancellable: InterruptableFuture + node: CalcJobNode, + transport_queue: TransportQueue, + retrieved_temporary_folder: str, + cancellable: InterruptableFuture, ): """Transport task that will attempt to retrieve all files of a completed job calculation. 
@@ -328,7 +329,7 @@ async def do_retrieve(): result = await exponential_backoff_retry( do_retrieve, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions ) - except (plumpy.futures.CancelledError, plumpy.process_states.Interruption): # pylint: disable=try-except-raise + except (plumpy.futures.CancelledError, plumpy.process_states.Interruption): raise except Exception as exception: logger.warning(f'retrieving CalcJob<{node.pk}> failed') @@ -375,7 +376,7 @@ async def do_stash(): initial_interval, max_attempts, logger=node.logger, - ignore_exceptions=plumpy.process_states.Interruption + ignore_exceptions=plumpy.process_states.Interruption, ) except plumpy.process_states.Interruption: raise @@ -439,11 +440,9 @@ def __init__( process: 'CalcJob', done_callback: Optional[Callable[..., Any]], msg: Optional[str] = None, - data: Optional[Any] = None + data: Optional[Any] = None, ): - """ - :param process: The process this state belongs to - """ + """:param process: The process this state belongs to""" super().__init__(process, done_callback, msg, data) self._task: InterruptableFuture | None = None self._killing: plumpy.futures.Future | None = None @@ -473,9 +472,7 @@ def monitors(self) -> CalcJobMonitors | None: @property def process(self) -> 'CalcJob': - """ - :return: The process - """ + """:return: The process""" return self.state_machine # type: ignore[return-value] def load_instance_state(self, saved_state, load_context): @@ -483,9 +480,8 @@ def load_instance_state(self, saved_state, load_context): self._task = None self._killing = None - async def execute(self) -> plumpy.process_states.State: # type: ignore[override] # pylint: disable=invalid-overridden-method + async def execute(self) -> plumpy.process_states.State: # type: ignore[override] """Override the execute coroutine of the base `Waiting` state.""" - # pylint: disable=too-many-branches,too-many-statements,too-many-nested-blocks node = self.process.node transport_queue = self.process.runner.transport result: plumpy.process_states.State = self @@ -494,7 +490,6 @@ async def execute(self) -> plumpy.process_states.State: # type: ignore[override node.set_process_status(process_status) try: - if self._command == UPLOAD_COMMAND: skip_submit = await self._launch_task(task_upload_job, self.process, transport_queue) if skip_submit: @@ -529,9 +524,7 @@ async def execute(self) -> plumpy.process_states.State: # type: ignore[override if monitor_result and not monitor_result.retrieve: exit_code = self.process.exit_codes.STOPPED_BY_MONITOR.format(message=monitor_result.message) - return self.create_state( - ProcessState.RUNNING, self.process.terminate, exit_code - ) # type: ignore[return-value] + return self.create_state(ProcessState.RUNNING, self.process.terminate, exit_code) # type: ignore[return-value] result = self.stash(monitor_result=monitor_result) diff --git a/aiida/engine/processes/control.py b/aiida/engine/processes/control.py index 5341b378d8..c10848c3e0 100644 --- a/aiida/engine/processes/control.py +++ b/aiida/engine/processes/control.py @@ -93,7 +93,6 @@ def revive_processes(processes: list[ProcessNode], *, wait: bool = False) -> Non process_controller = get_manager().get_process_controller() for process in processes: - future = process_controller.continue_process(process.pk, nowait=not wait, no_reply=False) if future: @@ -110,11 +109,7 @@ def revive_processes(processes: list[ProcessNode], *, wait: bool = False) -> Non def play_processes( - processes: list[ProcessNode] | None = None, - *, - all_entries: 
bool = False, - timeout: float = 5.0, - wait: bool = False + processes: list[ProcessNode] | None = None, *, all_entries: bool = False, timeout: float = 5.0, wait: bool = False ) -> None: """Play (unpause) paused processes. @@ -149,7 +144,7 @@ def pause_processes( message: str = 'Paused through `aiida.engine.processes.control.pause_processes`', all_entries: bool = False, timeout: float = 5.0, - wait: bool = False + wait: bool = False, ) -> None: """Pause running processes. @@ -184,7 +179,7 @@ def kill_processes( message: str = 'Killed through `aiida.engine.processes.control.kill_processes`', all_entries: bool = False, timeout: float = 5.0, - wait: bool = False + wait: bool = False, ) -> None: """Kill running processes. @@ -218,9 +213,9 @@ def _perform_actions( action: t.Callable, infinitive: str, present: str, - timeout: float = None, + timeout: t.Optional[float] = None, wait: bool = False, - **kwargs: t.Any + **kwargs: t.Any, ) -> None: """Perform an action on a list of processes. @@ -237,7 +232,6 @@ def _perform_actions( futures = {} for process in processes: - if process.is_terminated: LOGGER.error(f'Process<{process.pk}> is already terminated.') continue @@ -257,7 +251,7 @@ def _resolve_futures( infinitive: str, present: str, wait: bool = False, - timeout: float = None + timeout: t.Optional[float] = None, ) -> None: """Process a mapping of futures representing an action on an active process. @@ -286,16 +280,15 @@ def handle_result(result): try: for future in concurrent.futures.as_completed(futures.keys(), timeout=timeout): - process = futures[future] try: # unwrap is need here since LoopCommunicator will also wrap a future - future = unwrap_kiwi_future(future) - result = future.result() + unwrapped = unwrap_kiwi_future(future) + result = unwrapped.result() except communications.TimeoutError: LOGGER.error(f'call to {infinitive} Process<{process.pk}> timed out') - except Exception as exception: # pylint: disable=broad-except + except Exception as exception: LOGGER.error(f'failed to {infinitive} Process<{process.pk}>: {exception}') else: if isinstance(result, kiwipy.Future): @@ -314,7 +307,7 @@ def handle_result(result): try: result = future.result() - except Exception as exception: # pylint: disable=broad-except + except Exception as exception: LOGGER.error(f'failed to {infinitive} Process<{process.pk}>: {exception}') else: handle_result(result) diff --git a/aiida/engine/processes/functions.py b/aiida/engine/processes/functions.py index 9201150f16..f902557418 100644 --- a/aiida/engine/processes/functions.py +++ b/aiida/engine/processes/functions.py @@ -47,7 +47,7 @@ UnionType = types.UnionType except AttributeError: # This type is not available for Python 3.9 and older - UnionType = None # type: ignore[assignment,misc] # pylint: disable=invalid-name + UnionType = None # type: ignore[assignment,misc] try: from typing import ParamSpec @@ -59,7 +59,7 @@ get_annotations = inspect.get_annotations except AttributeError: # This is the backport for Python 3.9 and older - from get_annotations import get_annotations # type: ignore[no-redef] # pylint: disable=import-error + from get_annotations import get_annotations # type: ignore[no-redef] if TYPE_CHECKING: from .exit_code import ExitCode @@ -68,7 +68,7 @@ LOGGER = logging.getLogger(__name__) -FunctionType = t.TypeVar('FunctionType', bound=t.Callable[..., t.Any]) # pylint: disable=invalid-name +FunctionType = t.TypeVar('FunctionType', bound=t.Callable[..., t.Any]) def get_stack_size(size: int = 2) -> int: # type: ignore[return] @@ -83,9 +83,9 @@ 
def get_stack_size(size: int = 2) -> int: # type: ignore[return] :param size: Hint for the expected stack size. :returns: The stack size for caller's frame. """ - frame = sys._getframe(size) # pylint: disable=protected-access + frame = sys._getframe(size) try: - for size in itertools.count(size, 8): # pylint: disable=redefined-argument-from-local + for size in itertools.count(size, 8): frame = frame.f_back.f_back.f_back.f_back.f_back.f_back.f_back.f_back # type: ignore[assignment,union-attr] except AttributeError: while frame: # type: ignore[truthy-bool] @@ -126,8 +126,7 @@ def run_get_node(self, *args: P.args, **kwargs: P.kwargs) -> tuple[dict[str, t.A def calcfunction(function: t.Callable[P, R_co]) -> ProcessFunctionType[P, R_co, CalcFunctionNode]: - """ - A decorator to turn a standard python function into a calcfunction. + """A decorator to turn a standard python function into a calcfunction. Example usage: >>> from aiida.orm import Int @@ -153,8 +152,7 @@ def calcfunction(function: t.Callable[P, R_co]) -> ProcessFunctionType[P, R_co, def workfunction(function: t.Callable[P, R_co]) -> ProcessFunctionType[P, R_co, WorkFunctionNode]: - """ - A decorator to turn a standard python function into a workfunction. + """A decorator to turn a standard python function into a workfunction. Example usage: >>> from aiida.orm import Int @@ -180,15 +178,13 @@ def workfunction(function: t.Callable[P, R_co]) -> ProcessFunctionType[P, R_co, def process_function(node_class: t.Type['ProcessNode']) -> t.Callable[[FunctionType], FunctionType]: - """ - The base function decorator to create a FunctionProcess out of a normal python function. + """The base function decorator to create a FunctionProcess out of a normal python function. :param node_class: the ORM class to be used as the Node record for the FunctionProcess """ def decorator(function: FunctionType) -> FunctionType: - """ - Turn the decorated function into a FunctionProcess. + """Turn the decorated function into a FunctionProcess. :param callable function: the actual decorated function that the FunctionProcess represents :return callable: The decorated function. @@ -196,8 +192,7 @@ def decorator(function: FunctionType) -> FunctionType: process_class = FunctionProcess.build(function, node_class=node_class) def run_get_node(*args, **kwargs) -> tuple[dict[str, t.Any] | None, 'ProcessNode']: - """ - Run the FunctionProcess with the supplied inputs in a local runner. + """Run the FunctionProcess with the supplied inputs in a local runner. :param args: input arguments to construct the FunctionProcess :param kwargs: input keyword arguments to construct the FunctionProcess @@ -213,7 +208,9 @@ def run_get_node(*args, **kwargs) -> tuple[dict[str, t.Any] | None, 'ProcessNode if frame_count > min(0.8 * stack_limit, stack_limit - 200): LOGGER.warning( 'Current stack contains %d frames which is close to the limit of %d. 
Increasing the limit by %d', - frame_count, stack_limit, frame_delta + frame_count, + stack_limit, + frame_delta, ) sys.setrecursionlimit(stack_limit + frame_delta) @@ -258,8 +255,8 @@ def kill_process(_num, _frame): store_provenance = inputs.get('metadata', {}).get('store_provenance', True) if not store_provenance: - process.node._storable = False # pylint: disable=protected-access - process.node._unstorable_message = 'cannot store node because it was run with `store_provenance=False`' # pylint: disable=protected-access + process.node._storable = False + process.node._unstorable_message = 'cannot store node because it was run with `store_provenance=False`' return result, process.node @@ -341,16 +338,14 @@ class FunctionProcess(Process): @staticmethod def _func(*_args, **_kwargs) -> dict: - """ - This is used internally to store the actual function that is being + """This is used internally to store the actual function that is being wrapped and will be replaced by the build method. """ return {} @staticmethod def build(func: FunctionType, node_class: t.Type['ProcessNode']) -> t.Type['FunctionProcess']: - """ - Build a Process from the given function. + """Build a Process from the given function. All function arguments will be assigned as process inputs. If keyword arguments are specified then these will also become inputs. @@ -362,10 +357,9 @@ def build(func: FunctionType, node_class: t.Type['ProcessNode']) -> t.Type['Func :return: A Process class that represents the function """ - # pylint: disable=too-many-statements if ( - not issubclass(node_class, ProcessNode) or # type: ignore[redundant-expr] - not issubclass(node_class, FunctionCalculationMixin) # type: ignore[unreachable] + not issubclass(node_class, ProcessNode) # type: ignore[redundant-expr] + or not issubclass(node_class, FunctionCalculationMixin) # type: ignore[unreachable] ): raise TypeError('the node_class should be a sub class of `ProcessNode` and `FunctionCalculationMixin`') @@ -377,7 +371,7 @@ def build(func: FunctionType, node_class: t.Type['ProcessNode']) -> t.Type['Func try: annotations = get_annotations(func, eval_str=True) - except Exception as exception: # pylint: disable=broad-except + except Exception as exception: # Since we are running with ``eval_str=True`` to unstringize the annotations, the call can except if the # annotations are incorrect. In this case we simply want to log a warning and continue with type inference. 
LOGGER.warning(f'function `{func.__name__}` has invalid type hints: {exception}') @@ -385,7 +379,7 @@ def build(func: FunctionType, node_class: t.Type['ProcessNode']) -> t.Type['Func try: parsed_docstring = docstring_parser.parse(func.__doc__) - except Exception as exception: # pylint: disable=broad-except + except Exception as exception: LOGGER.warning(f'function `{func.__name__}` has a docstring that could not be parsed: {exception}') param_help_string = {} namespace_help_string = None @@ -396,7 +390,6 @@ def build(func: FunctionType, node_class: t.Type['ProcessNode']) -> t.Type['Func namespace_help_string += f'\n\n{parsed_docstring.long_description}' for key, parameter in signature.parameters.items(): - if parameter.kind in [parameter.POSITIONAL_ONLY, parameter.POSITIONAL_OR_KEYWORD, parameter.KEYWORD_ONLY]: args.append(key) @@ -406,14 +399,13 @@ def build(func: FunctionType, node_class: t.Type['ProcessNode']) -> t.Type['Func if parameter.kind is parameter.VAR_KEYWORD: var_keyword = key - def define(cls, spec): # pylint: disable=unused-argument + def define(cls, spec): """Define the spec dynamically""" from plumpy.ports import UNSPECIFIED super().define(spec) for parameter in signature.parameters.values(): - if parameter.kind in [parameter.VAR_POSITIONAL, parameter.VAR_KEYWORD]: continue @@ -438,10 +430,14 @@ def define(cls, spec): # pylint: disable=unused-argument # done lazily using a lambda, just as any port defaults should not define node instances directly as is # also checked by the ``spec.input`` call. if ( - default is not None and default != UNSPECIFIED and not isinstance(default, Data) and - not callable(default) + default is not None + and default != UNSPECIFIED + and not isinstance(default, Data) + and not callable(default) ): - indirect_default = lambda value=default: to_aiida_type(value) # pylint: disable=unnecessary-lambda-assignment + + def indirect_default(value=default): + return to_aiida_type(value) else: indirect_default = default @@ -470,7 +466,9 @@ def define(cls, spec): # pylint: disable=unused-argument spec.outputs.valid_type = (Data, dict) return type( - func.__qualname__, (FunctionProcess,), { + func.__qualname__, + (FunctionProcess,), + { '__module__': func.__module__, '__name__': func.__name__, '__qualname__': func.__qualname__, @@ -479,12 +477,12 @@ def define(cls, spec): # pylint: disable=unused-argument '_func_args': args, '_var_positional': var_positional, '_var_keyword': var_keyword, - '_node_class': node_class - } + '_node_class': node_class, + }, ) @classmethod - def validate_inputs(cls, *args: t.Any, **kwargs: t.Any) -> None: # pylint: disable=unused-argument + def validate_inputs(cls, *args: t.Any, **kwargs: t.Any) -> None: """Validate the positional and keyword arguments passed in the function call. :raises TypeError: if more positional arguments are passed than the function defines @@ -545,8 +543,7 @@ def __init__(self, *args, **kwargs) -> None: @property def process_class(self) -> t.Callable[..., t.Any]: - """ - Return the class that represents this Process, for the FunctionProcess this is the function itself. + """Return the class that represents this Process, for the FunctionProcess this is the function itself. For a standard Process or sub class of Process, this is the class itself. However, for legacy reasons, the Process class is a wrapper around another class. This function returns that original class, i.e. 
the diff --git a/aiida/engine/processes/futures.py b/aiida/engine/processes/futures.py index 9bd8ebec20..a7cb6b799c 100644 --- a/aiida/engine/processes/futures.py +++ b/aiida/engine/processes/futures.py @@ -7,7 +7,6 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=cyclic-import """Futures that can poll or receive broadcasted messages while waiting for a task to be completed.""" import asyncio from typing import Optional, Union @@ -29,7 +28,7 @@ def __init__( pk: int, loop: Optional[asyncio.AbstractEventLoop] = None, poll_interval: Union[None, int, float] = None, - communicator: Optional[kiwipy.Communicator] = None + communicator: Optional[kiwipy.Communicator] = None, ): """Construct a future for a process node being finished. @@ -60,7 +59,7 @@ def __init__( # Try setting up a filtered broadcast subscriber if self._communicator is not None: - def _subscriber(*args, **kwargs): # pylint: disable=unused-argument + def _subscriber(*args, **kwargs): if not self.done(): self.set_result(node) diff --git a/aiida/engine/processes/ports.py b/aiida/engine/processes/ports.py index 4ffd0e5998..5c2f4864fb 100644 --- a/aiida/engine/processes/ports.py +++ b/aiida/engine/processes/ports.py @@ -8,10 +8,10 @@ # For further information please visit http://www.aiida.net # ########################################################################### """AiiDA specific implementation of plumpy Ports and PortNamespaces for the ProcessSpec.""" -from collections.abc import Mapping import re -from typing import Any, Callable, Dict, Optional, Sequence import warnings +from collections.abc import Mapping +from typing import Any, Callable, Dict, Optional, Sequence from plumpy import ports from plumpy.ports import breadcrumbs_to_port @@ -20,13 +20,18 @@ from aiida.orm import Data, Node __all__ = ( - 'PortNamespace', 'InputPort', 'OutputPort', 'CalcJobOutputPort', 'WithNonDb', 'WithSerialize', - 'PORT_NAMESPACE_SEPARATOR' + 'PortNamespace', + 'InputPort', + 'OutputPort', + 'CalcJobOutputPort', + 'WithNonDb', + 'WithSerialize', + 'PORT_NAMESPACE_SEPARATOR', ) -PORT_NAME_MAX_CONSECUTIVE_UNDERSCORES = 1 # pylint: disable=invalid-name +PORT_NAME_MAX_CONSECUTIVE_UNDERSCORES = 1 PORT_NAMESPACE_SEPARATOR = '__' # The character sequence to represent a nested port namespace in a flat link label -OutputPort = ports.OutputPort # pylint: disable=invalid-name +OutputPort = ports.OutputPort class WithNonDb: @@ -101,8 +106,7 @@ def is_metadata(self, is_metadata: bool) -> None: class WithSerialize: - """ - A mixin that adds support for a serialization function which is automatically applied on inputs + """A mixin that adds support for a serialization function which is automatically applied on inputs that are not AiiDA data types. """ @@ -124,8 +128,7 @@ def serialize(self, value: Any) -> 'Data': class InputPort(WithMetadata, WithSerialize, WithNonDb, ports.InputPort): - """ - Sub class of plumpy.InputPort which mixes in the WithSerialize and WithNonDb mixins to support automatic + """Sub class of plumpy.InputPort which mixes in the WithSerialize and WithNonDb mixins to support automatic value serialization to database storable types and support non database storable input types as well. 
The mixins have to go before the main port class in the superclass order to make sure they have the chance to @@ -134,15 +137,16 @@ class InputPort(WithMetadata, WithSerialize, WithNonDb, ports.InputPort): def __init__(self, *args, **kwargs) -> None: """Override the constructor to check the type of the default if set and warn if not immutable.""" - # pylint: disable=redefined-builtin,too-many-arguments if 'default' in kwargs: default = kwargs['default'] # If the default is specified and it is a node instance, raise a warning. This is to try and prevent that # people set node instances as defaults which can cause various problems. if default is not ports.UNSPECIFIED and isinstance(default, Node): - message = 'default of input port `{}` is a `Node` instance, which can lead to unexpected side effects.'\ + message = ( + 'default of input port `{}` is a `Node` instance, which can lead to unexpected side effects.' ' It is advised to use a lambda instead, e.g.: `default=lambda: orm.Int(5)`.'.format(args[0]) - warnings.warn(UserWarning(message)) # pylint: disable=no-member + ) + warnings.warn(UserWarning(message)) # If the port is not required and defines ``valid_type``, automatically add ``None`` as a valid type valid_type = kwargs.get('valid_type', ()) @@ -181,8 +185,7 @@ def pass_to_parser(self) -> bool: class PortNamespace(WithMetadata, WithNonDb, ports.PortNamespace): - """ - Sub class of plumpy.PortNamespace which implements the serialize method to support automatic recursive + """Sub class of plumpy.PortNamespace which implements the serialize method to support automatic recursive serialization of a given mapping onto the ports of the PortNamespace. """ @@ -200,9 +203,7 @@ def __setitem__(self, key: str, port: ports.Port) -> None: self.validate_port_name(key) - if hasattr( - port, 'is_metadata_explicitly_set' - ) and not port.is_metadata_explicitly_set: # type: ignore[attr-defined] + if hasattr(port, 'is_metadata_explicitly_set') and not port.is_metadata_explicitly_set: # type: ignore[attr-defined] port.is_metadata = self.is_metadata # type: ignore[attr-defined] if hasattr(port, 'non_db_explicitly_set') and not port.non_db_explicitly_set: # type: ignore[attr-defined] @@ -263,7 +264,6 @@ def serialize(self, mapping: Optional[Dict[str, Any]], breadcrumbs: Sequence[str for name, value in mapping.items(): if name in self: - port = self[name] if isinstance(port, PortNamespace): result[name] = port.serialize(value, breadcrumbs) diff --git a/aiida/engine/processes/process.py b/aiida/engine/processes/process.py index 18aefd88db..c402a479a2 100644 --- a/aiida/engine/processes/process.py +++ b/aiida/engine/processes/process.py @@ -7,16 +7,15 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=too-many-lines """The AiiDA process class""" import asyncio import collections -from collections.abc import Mapping import copy import enum import inspect import logging import traceback +from collections.abc import Mapping from types import TracebackType from typing import ( TYPE_CHECKING, @@ -34,13 +33,13 @@ ) from uuid import UUID -from aio_pika.exceptions import ConnectionClosed -from kiwipy.communications import UnroutableError import plumpy.exceptions import plumpy.futures import plumpy.persistence -from plumpy.process_states import Finished, ProcessState import plumpy.processes +from aio_pika.exceptions import ConnectionClosed 
+from kiwipy.communications import UnroutableError +from plumpy.process_states import Finished, ProcessState from plumpy.utils import AttributesFrozendict from aiida import orm @@ -66,11 +65,9 @@ @plumpy.persistence.auto_persist('_parent_pid', '_enable_persistence') class Process(plumpy.processes.Process): - """ - This class represents an AiiDA process which can be executed and will + """This class represents an AiiDA process which can be executed and will have full provenance saved in the database. """ - # pylint: disable=too-many-public-methods _node_class = orm.ProcessNode _spec_class = ProcessSpec @@ -78,9 +75,8 @@ class Process(plumpy.processes.Process): SINGLE_OUTPUT_LINKNAME: str = 'result' class SaveKeys(enum.Enum): - """ - Keys used to identify things in the saved instance state bundle. - """ + """Keys used to identify things in the saved instance state bundle.""" + CALC_ID: str = 'calc_id' @classmethod @@ -100,13 +96,13 @@ def define(cls, spec: ProcessSpec) -> None: # type: ignore[override] f'{spec.metadata_key}.store_provenance', valid_type=bool, default=True, - help='If set to `False` provenance will not be stored in the database.' + help='If set to `False` provenance will not be stored in the database.', ) spec.input( f'{spec.metadata_key}.description', valid_type=str, required=False, - help='Description to set on the process node.' + help='Description to set on the process node.', ) spec.input( f'{spec.metadata_key}.label', valid_type=str, required=False, help='Label to set on the process node.' @@ -115,7 +111,7 @@ def define(cls, spec: ProcessSpec) -> None: # type: ignore[override] f'{spec.metadata_key}.call_link_label', valid_type=str, default='CALL', - help='The label to use for the `CALL` link if the process is called by another process.' + help='The label to use for the `CALL` link if the process is called by another process.', ) spec.inputs.valid_type = orm.Data spec.inputs.dynamic = False # Settings a ``valid_type`` automatically makes it dynamic, so we reset it again @@ -132,7 +128,7 @@ def define(cls, spec: ProcessSpec) -> None: # type: ignore[override] 11, 'ERROR_MISSING_OUTPUT', invalidates_cache=True, - message='The process did not register a required output.' + message='The process did not register a required output.', ) @classmethod @@ -141,8 +137,7 @@ def get_builder(cls) -> ProcessBuilder: @classmethod def get_or_create_db_record(cls) -> orm.ProcessNode: - """ - Create a process node that represents what happened in this process. + """Create a process node that represents what happened in this process. :return: A process node """ @@ -154,9 +149,9 @@ def __init__( logger: Optional[logging.Logger] = None, runner: Optional['Runner'] = None, parent_pid: Optional[int] = None, - enable_persistence: bool = True + enable_persistence: bool = True, ) -> None: - """ Process constructor. + """Process constructor. 
:param inputs: process inputs :param logger: aiida logger @@ -174,7 +169,7 @@ def __init__( inputs=self.spec().inputs.serialize(inputs), logger=logger, loop=self._runner.loop, - communicator=self._runner.communicator + communicator=self._runner.communicator, ) self._node: Optional[orm.ProcessNode] = None @@ -201,7 +196,7 @@ def get_exit_statuses(cls, exit_code_labels: Iterable[str]) -> List[int]: return [getattr(exit_codes, label).status for label in exit_code_labels] @classproperty - def exit_codes(cls) -> ExitCodesNamespace: # pylint: disable=no-self-argument + def exit_codes(cls) -> ExitCodesNamespace: # noqa: N805 """Return the namespace of exit codes defined for this WorkChain through its ProcessSpec. The namespace supports getitem and getattr operations with an ExitCode label to retrieve a specific code. @@ -213,7 +208,7 @@ def exit_codes(cls) -> ExitCodesNamespace: # pylint: disable=no-self-argument return cls.spec().exit_codes @classproperty - def spec_metadata(cls) -> PortNamespace: # pylint: disable=no-self-argument + def spec_metadata(cls) -> PortNamespace: # noqa: N805 """Return the metadata port namespace of the process specification of this process.""" return cls.spec().inputs['metadata'] # type: ignore[return-value] @@ -256,8 +251,7 @@ def metadata(self) -> AttributeDict: return AttributeDict() def _save_checkpoint(self) -> None: - """ - Save the current state in a chechpoint if persistence is enabled and the process state is not terminal + """Save the current state in a checkpoint if persistence is enabled and the process state is not terminal If the persistence call excepts with a PersistenceError, it will be caught and a warning will be logged. """ @@ -320,15 +314,14 @@ def load_instance_state( if self.SaveKeys.CALC_ID.value in saved_state: self._node = orm.load_node(saved_state[self.SaveKeys.CALC_ID.value]) # type: ignore[assignment] - self._pid = self.node.pk # pylint: disable=attribute-defined-outside-init + self._pid = self.node.pk else: - self._pid = self._create_and_setup_db_record() # pylint: disable=attribute-defined-outside-init + self._pid = self._create_and_setup_db_record() self.node.logger.info(f'Loaded process<{self.node.pk}> from saved state') def kill(self, msg: Union[str, None] = None) -> Union[bool, plumpy.futures.Future]: - """ - Kill the process and all the children calculations it called + """Kill the process and all the child calculations it called :param msg: message """ @@ -408,12 +401,11 @@ def on_create(self) -> None: current = Process.current() if isinstance(current, Process): self._parent_pid = current.pid # type: ignore[assignment] - self._pid = self._create_and_setup_db_record() # pylint: disable=attribute-defined-outside-init + self._pid = self._create_and_setup_db_record() @override def on_entered(self, from_state: Optional[plumpy.process_states.State]) -> None: """After entering a new state, save a checkpoint and update the latest process state change timestamp.""" - # pylint: disable=cyclic-import from aiida.engine.utils import set_process_state_change_timestamp # For reasons unknown, it is important to update the outputs first, before doing anything else, otherwise there @@ -424,7 +416,7 @@ def on_entered(self, from_state: Optional[plumpy.process_states.State]) -> None: # if the process is transitioning to the terminal excepted state.
try: self.update_outputs() - except ValueError: # pylint: disable=try-except-raise + except ValueError: raise finally: self.node.set_process_state(self._state.LABEL) # type: ignore[arg-type] @@ -441,7 +433,7 @@ def on_terminated(self) -> None: try: assert self.runner.persister is not None self.runner.persister.delete_checkpoint(self.pid) - except Exception as error: # pylint: disable=broad-except + except Exception as error: self.logger.exception('Failed to delete checkpoint: %s', error) try: @@ -451,8 +443,7 @@ def on_terminated(self) -> None: @override def on_except(self, exc_info: Tuple[Any, Exception, TracebackType]) -> None: - """ - Log the exception by calling the report method with formatted stack trace from exception info object + """Log the exception by calling the report method with formatted stack trace from exception info object and store the exception string as a node attribute :param exc_info: the sys.exc_info() object (type, value, traceback) @@ -463,7 +454,7 @@ def on_except(self, exc_info: Tuple[Any, Exception, TracebackType]) -> None: @override def on_finish(self, result: Union[int, ExitCode, None], successful: bool) -> None: - """ Set the finish status on the process node. + """Set the finish status on the process node. :param result: result of the process :param successful: whether execution was successful @@ -473,7 +464,7 @@ def on_finish(self, result: Union[int, ExitCode, None], successful: bool) -> Non if result is None: if not successful: - result = self.exit_codes.ERROR_MISSING_OUTPUT # pylint: disable=no-member + result = self.exit_codes.ERROR_MISSING_OUTPUT else: result = ExitCode() @@ -489,8 +480,7 @@ def on_finish(self, result: Union[int, ExitCode, None], successful: bool) -> Non @override def on_paused(self, msg: Optional[str] = None) -> None: - """ - The Process was paused so set the paused attribute on the process node + """The Process was paused so set the paused attribute on the process node :param msg: message @@ -501,16 +491,13 @@ def on_paused(self, msg: Optional[str] = None) -> None: @override def on_playing(self) -> None: - """ - The Process was unpaused so remove the paused attribute on the process node - """ + """The Process was unpaused so remove the paused attribute on the process node""" super().on_playing() self.node.unpause() @override def on_output_emitting(self, output_port: str, value: Any) -> None: - """ - The process has emitted a value on the given output port. + """The process has emitted a value on the given output port. :param output_port: The output port name the value was emitted on :param value: The value emitted @@ -523,8 +510,7 @@ def on_output_emitting(self, output_port: str, value: Any) -> None: raise TypeError(f'Processes can only return `orm.Data` instances as output, got {value.__class__}') def set_status(self, status: Optional[str]) -> None: - """ - The status of the Process is about to be changed, so we reflect this is in node's attribute proxy. + """The status of the Process is about to be changed, so we reflect this is in node's attribute proxy. :param status: the status message @@ -547,8 +533,7 @@ def runner(self) -> 'Runner': return self._runner def get_parent_calc(self) -> Optional[orm.ProcessNode]: - """ - Get the parent process node + """Get the parent process node :return: the parent process node if there is one @@ -561,8 +546,7 @@ def get_parent_calc(self) -> Optional[orm.ProcessNode]: @classmethod def build_process_type(cls) -> str: - """ - The process type. + """The process type. 
:return: string of the process type @@ -597,8 +581,7 @@ def report(self, msg: str, *args, **kwargs) -> None: self.logger.log(LOG_LEVEL_REPORT, message, *args, **kwargs) def _create_and_setup_db_record(self) -> Union[int, UUID]: - """ - Create and setup the database record for this process + """Create and setup the database record for this process :return: the uuid or pk of the process @@ -609,7 +592,7 @@ def _create_and_setup_db_record(self) -> Union[int, UUID]: try: self.node.store_all() if self.node.is_finished_ok: - self._state = Finished(self, None, True) # pylint: disable=attribute-defined-outside-init + self._state = Finished(self, None, True) for entry in self.node.base.links.get_outgoing(link_type=LinkType.RETURN): if entry.link_label.endswith(f'_{entry.node.pk}'): continue @@ -634,8 +617,7 @@ def _create_and_setup_db_record(self) -> Union[int, UUID]: @override def encode_input_args(self, inputs: Dict[str, Any]) -> str: - """ - Encode input arguments such that they may be saved in a Bundle + """Encode input arguments such that they may be saved in a Bundle :param inputs: A mapping of the inputs as passed to the process :return: The encoded (serialized) inputs @@ -644,8 +626,7 @@ def encode_input_args(self, inputs: Dict[str, Any]) -> str: @override def decode_input_args(self, encoded: str) -> Dict[str, Any]: - """ - Decode saved input arguments as they came from the saved instance state Bundle + """Decode saved input arguments as they came from the saved instance state Bundle :param encoded: encoded (serialized) inputs :return: The decoded input args @@ -661,12 +642,12 @@ def update_outputs(self) -> None: return outputs_flat = self._flat_outputs() - outputs_stored = self.node.base.links.get_outgoing(link_type=(LinkType.CREATE, LinkType.RETURN) - ).all_link_labels() + outputs_stored = self.node.base.links.get_outgoing( + link_type=(LinkType.CREATE, LinkType.RETURN) + ).all_link_labels() outputs_new = set(outputs_flat.keys()) - set(outputs_stored) for link_label, output in outputs_flat.items(): - if link_label not in outputs_new: continue @@ -688,8 +669,7 @@ def _build_process_label(self) -> str: return self.__class__.__name__ def _setup_db_record(self) -> None: - """ - Create the database record for this process and the links with respect to its inputs + """Create the database record for this process and the links with respect to its inputs This function will set various attributes on the node that serve as a proxy for attributes of the Process. This is essential as otherwise this information could only be introspected through the Process itself, which @@ -710,7 +690,6 @@ def _setup_db_record(self) -> None: parent_calc = self.get_parent_calc() if parent_calc and self.metadata.store_provenance: - if isinstance(parent_calc, orm.CalculationNode): raise exceptions.InvalidOperation('calling processes from a calculation type process is forbidden.') @@ -764,7 +743,6 @@ def _setup_metadata(self, metadata: dict) -> None: def _setup_inputs(self) -> None: """Create the links between the input nodes and the ProcessNode that represents this process.""" for name, node in self._flat_inputs().items(): - # Certain processes allow to specify ports with `None` as acceptable values if node is None: continue @@ -821,8 +799,7 @@ def _filter_serializable_metadata( return result or None def _flat_inputs(self) -> Dict[str, Any]: - """ - Return a flattened version of the parsed inputs dictionary. + """Return a flattened version of the parsed inputs dictionary. 
The eventual keys will be a concatenation of the nested keys. Note that the `metadata` dictionary, if present, is not passed, as those are dealt with separately in `_setup_metadata`. @@ -834,8 +811,7 @@ def _flat_inputs(self) -> Dict[str, Any]: return dict(self._flatten_inputs(self.spec().inputs, inputs)) def _flat_outputs(self) -> Dict[str, Any]: - """ - Return a flattened version of the registered outputs dictionary. + """Return a flattened version of the registered outputs dictionary. The eventual keys will be a concatenation of the nested keys. @@ -848,10 +824,9 @@ def _flatten_inputs( port: Union[None, InputPort, PortNamespace], port_value: Any, parent_name: str = '', - separator: str = PORT_NAMESPACE_SEPARATOR + separator: str = PORT_NAMESPACE_SEPARATOR, ) -> List[Tuple[str, Any]]: - """ - Function that will recursively flatten the inputs dictionary, omitting inputs for ports that + """Function that will recursively flatten the inputs dictionary, omitting inputs for ports that are marked as being non database storable :param port: port against which to map the port value, can be InputPort or PortNamespace @@ -861,15 +836,14 @@ def _flatten_inputs( :return: flat list of inputs """ - if (port is None and - isinstance(port_value, - orm.Node)) or (isinstance(port, InputPort) and not (port.is_metadata or port.non_db)): + if (port is None and isinstance(port_value, orm.Node)) or ( + isinstance(port, InputPort) and not (port.is_metadata or port.non_db) + ): return [(parent_name, port_value)] if port is None and isinstance(port_value, Mapping) or isinstance(port, PortNamespace): items = [] for name, value in port_value.items(): - prefixed_key = parent_name + separator + name if parent_name else name try: @@ -893,10 +867,9 @@ def _flatten_outputs( port: Union[None, OutputPort, PortNamespace], port_value: Any, parent_name: str = '', - separator: str = PORT_NAMESPACE_SEPARATOR + separator: str = PORT_NAMESPACE_SEPARATOR, ) -> List[Tuple[str, Any]]: - """ - Function that will recursively flatten the outputs dictionary. + """Function that will recursively flatten the outputs dictionary. :param port: port against which to map the port value, can be OutputPort or PortNamespace :param port_value: value for the current port, can be a Mapping @@ -909,10 +882,9 @@ def _flatten_outputs( if port is None and isinstance(port_value, orm.Node) or isinstance(port, OutputPort): return [(parent_name, port_value)] - if (port is None and isinstance(port_value, Mapping) or isinstance(port, PortNamespace)): + if port is None and isinstance(port_value, Mapping) or isinstance(port, PortNamespace): items = [] for name, value in port_value.items(): - prefixed_key = parent_name + separator + name if parent_name else name try: @@ -930,10 +902,7 @@ def _flatten_outputs( return [] def exposed_inputs( - self, - process_class: Type['Process'], - namespace: Optional[str] = None, - agglomerate: bool = True + self, process_class: Type['Process'], namespace: Optional[str] = None, agglomerate: bool = True ) -> AttributeDict: """Gather a dictionary of the inputs that were exposed for a given Process class under an optional namespace. 
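As a usage sketch for the expose mechanism that ``exposed_inputs`` belongs to — ``spec.expose_inputs`` re-exports a sub-process's ports at ``define`` time and ``self.exposed_inputs`` gathers the supplied values at run time — consider the following minimal pair of work chains. The names ``ChildWorkChain``, ``ParentWorkChain`` and the ``sub`` namespace are illustrative and not part of this diff; only the API calls come from the code above.

from aiida.engine import ToContext, WorkChain


class ChildWorkChain(WorkChain):
    """Hypothetical sub-process whose input ports get exposed by the parent."""

    @classmethod
    def define(cls, spec):
        super().define(spec)
        spec.input('x')
        spec.outline(cls.step)

    def step(self):
        """Do nothing; only the port structure matters for this sketch."""


class ParentWorkChain(WorkChain):
    """Re-export the child's ports under ``sub`` and forward the values on submission."""

    @classmethod
    def define(cls, spec):
        super().define(spec)
        spec.expose_inputs(ChildWorkChain, namespace='sub')
        spec.outline(cls.run_child)

    def run_child(self):
        # ``exposed_inputs`` returns an ``AttributeDict`` holding exactly the
        # inputs that were exposed for ``ChildWorkChain`` under ``sub``.
        inputs = self.exposed_inputs(ChildWorkChain, namespace='sub')
        return ToContext(child=self.submit(ChildWorkChain, **inputs))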
@@ -949,7 +918,6 @@ def exposed_inputs( namespace_list = self._get_namespace_list(namespace=namespace, agglomerate=agglomerate) for sub_namespace in namespace_list: - # The sub_namespace None indicates the base level sub_namespace if sub_namespace is None: inputs = self.inputs @@ -964,7 +932,7 @@ def exposed_inputs( raise ValueError(f'this process does not contain the "{sub_namespace}" input namespace') # Get the list of ports that were exposed for the given Process class in the current sub_namespace - exposed_inputs_list = self.spec()._exposed_inputs[sub_namespace][process_class] # pylint: disable=protected-access + exposed_inputs_list = self.spec()._exposed_inputs[sub_namespace][process_class] for name in port_namespace.ports.keys(): if inputs and name in inputs and name in exposed_inputs_list: @@ -977,7 +945,7 @@ def exposed_outputs( node: orm.ProcessNode, process_class: Type['Process'], namespace: Optional[str] = None, - agglomerate: bool = True + agglomerate: bool = True, ) -> AttributeDict: """Return the outputs which were exposed from the ``process_class`` and emitted by the specific ``node`` @@ -1004,9 +972,9 @@ def exposed_outputs( for port_namespace in self._get_namespace_list(namespace=namespace, agglomerate=agglomerate): # only the top-level key is stored in _exposed_outputs for top_name in top_namespace_map: - if namespace is not None and namespace not in self.spec()._exposed_outputs: # pylint: disable=protected-access + if namespace is not None and namespace not in self.spec()._exposed_outputs: raise KeyError(f'the namespace `{namespace}` is not an exposed namespace.') - if top_name in self.spec()._exposed_outputs[port_namespace][process_class]: # pylint: disable=protected-access + if top_name in self.spec()._exposed_outputs[port_namespace][process_class]: output_key_map[top_name] = port_namespace result = {} @@ -1063,9 +1031,8 @@ def is_valid_cache(cls, node: orm.ProcessNode) -> bool: return True -def get_query_string_from_process_type_string(process_type_string: str) -> str: # pylint: disable=invalid-name - """ - Take the process type string of a Node and create the queryable type string. +def get_query_string_from_process_type_string(process_type_string: str) -> str: + """Take the process type string of a Node and create the queryable type string. 
:param process_type_string: the process type string :type process_type_string: str diff --git a/aiida/engine/processes/process_spec.py b/aiida/engine/processes/process_spec.py index 75cc0af015..df01f58786 100644 --- a/aiida/engine/processes/process_spec.py +++ b/aiida/engine/processes/process_spec.py @@ -46,16 +46,14 @@ def options_key(self) -> str: @property def exit_codes(self) -> ExitCodesNamespace: - """ - Return the namespace of exit codes defined for this ProcessSpec + """Return the namespace of exit codes defined for this ProcessSpec :returns: ExitCodesNamespace of ExitCode named tuples """ return self._exit_codes def exit_code(self, status: int, label: str, message: str, invalidates_cache: bool = False) -> None: - """ - Add an exit code to the ProcessSpec + """Add an exit code to the ProcessSpec :param status: the exit status integer :param label: a label by which the exit code can be addressed diff --git a/aiida/engine/processes/workchains/__init__.py b/aiida/engine/processes/workchains/__init__.py index 56b6a94d2d..a429003d86 100644 --- a/aiida/engine/processes/workchains/__init__.py +++ b/aiida/engine/processes/workchains/__init__.py @@ -11,8 +11,7 @@ # AUTO-GENERATED -# yapf: disable -# pylint: disable=wildcard-import +# fmt: off from .awaitable import * from .context import * @@ -37,4 +36,4 @@ 'while_', ) -# yapf: enable +# fmt: on diff --git a/aiida/engine/processes/workchains/awaitable.py b/aiida/engine/processes/workchains/awaitable.py index 2c8e90dffb..33e7068068 100644 --- a/aiida/engine/processes/workchains/awaitable.py +++ b/aiida/engine/processes/workchains/awaitable.py @@ -24,18 +24,19 @@ class Awaitable(AttributesDict): class AwaitableTarget(Enum): """Enum that describes the class of the target a given awaitable.""" + PROCESS = 'process' class AwaitableAction(Enum): """Enum that describes the action to be taken for a given awaitable.""" + ASSIGN = 'assign' APPEND = 'append' def construct_awaitable(target: Union[Awaitable, ProcessNode]) -> Awaitable: - """ - Construct an instance of the Awaitable class that will contain the information + """Construct an instance of the Awaitable class that will contain the information related to the action to be taken with respect to the context once the awaitable object is completed. diff --git a/aiida/engine/processes/workchains/context.py b/aiida/engine/processes/workchains/context.py index 13092ad63e..a8f1e53675 100644 --- a/aiida/engine/processes/workchains/context.py +++ b/aiida/engine/processes/workchains/context.py @@ -20,8 +20,7 @@ def assign_(target: Union[Awaitable, ProcessNode]) -> Awaitable: - """ - Convenience function that will construct an Awaitable for a given class instance + """Convenience function that will construct an Awaitable for a given class instance with the context action set to ASSIGN. When the awaitable target is completed it will be assigned to the context for a key that is to be defined later @@ -36,8 +35,7 @@ def assign_(target: Union[Awaitable, ProcessNode]) -> Awaitable: def append_(target: Union[Awaitable, ProcessNode]) -> Awaitable: - """ - Convenience function that will construct an Awaitable for a given class instance + """Convenience function that will construct an Awaitable for a given class instance with the context action set to APPEND. 
When the awaitable target is completed it will be appended to a list in the context for a key that is to be defined later diff --git a/aiida/engine/processes/workchains/restart.py b/aiida/engine/processes/workchains/restart.py index 1aa94a2397..fd202d6324 100644 --- a/aiida/engine/processes/workchains/restart.py +++ b/aiida/engine/processes/workchains/restart.py @@ -19,7 +19,7 @@ from aiida.common.warnings import warn_deprecation from .context import ToContext, append_ -from .utils import ProcessHandlerReport, process_handler # pylint: disable=no-name-in-module +from .utils import ProcessHandlerReport, process_handler from .workchain import WorkChain if TYPE_CHECKING: @@ -29,9 +29,7 @@ def validate_handler_overrides( - process_class: 'BaseRestartWorkChain', - handler_overrides: Optional[orm.Dict], - ctx: 'PortNamespace' # pylint: disable=unused-argument + process_class: 'BaseRestartWorkChain', handler_overrides: Optional[orm.Dict], ctx: 'PortNamespace' ) -> Optional[str]: """Validator for the ``handler_overrides`` input port of the ``BaseRestartWorkChain``. @@ -64,7 +62,7 @@ def validate_handler_overrides( warn_deprecation( 'Setting a boolean as value for `handler_overrides` is deprecated. Use ' "`{'handler_name': {'enabled': " + f'{overrides}' + '}` instead.', - version=3 + version=3, ) if isinstance(overrides, dict): @@ -133,7 +131,8 @@ def handle_problem(self, node): @property def process_class(self) -> Type['Process']: """Return the process class to run in the loop.""" - from ..process import Process # pylint: disable=cyclic-import + from ..process import Process + if self._process_class is None or not issubclass(self._process_class, Process): raise ValueError('no valid Process class defined for `_process_class` attribute') return self._process_class @@ -141,32 +140,45 @@ def process_class(self) -> Type['Process']: @classmethod def define(cls, spec: 'ProcessSpec') -> None: # type: ignore[override] """Define the process specification.""" - # yapf: disable super().define(spec) - spec.input('max_iterations', valid_type=orm.Int, default=lambda: orm.Int(5), - help='Maximum number of iterations the work chain will restart the process to finish successfully.') - spec.input('clean_workdir', valid_type=orm.Bool, default=lambda: orm.Bool(False), - help='If `True`, work directories of all called calculation jobs will be cleaned at the end of execution.') - spec.input('handler_overrides', - valid_type=orm.Dict, required=False, validator=functools.partial(validate_handler_overrides, cls), + spec.input( + 'max_iterations', + valid_type=orm.Int, + default=lambda: orm.Int(5), + help='Maximum number of iterations the work chain will restart the process to finish successfully.', + ) + spec.input( + 'clean_workdir', + valid_type=orm.Bool, + default=lambda: orm.Bool(False), + help='If `True`, work directories of all called calculation jobs will be cleaned at the end of execution.', + ) + spec.input( + 'handler_overrides', + valid_type=orm.Dict, + required=False, + validator=functools.partial(validate_handler_overrides, cls), serializer=orm.to_aiida_type, help='Mapping where keys are process handler names and the values are a dictionary, where each dictionary ' - 'can define the ``enabled`` and ``priority`` key, which can be used to toggle the values set on ' - 'the original process handler declaration.') - spec.exit_code(301, 'ERROR_SUB_PROCESS_EXCEPTED', - message='The sub process excepted.') - spec.exit_code(302, 'ERROR_SUB_PROCESS_KILLED', - message='The sub process was killed.') - spec.exit_code(401, 
'ERROR_MAXIMUM_ITERATIONS_EXCEEDED', - message='The maximum number of iterations was exceeded.') - spec.exit_code(402, 'ERROR_SECOND_CONSECUTIVE_UNHANDLED_FAILURE', - message='The process failed for an unknown reason, twice in a row.') - # yapf: enable + 'can define the ``enabled`` and ``priority`` key, which can be used to toggle the values set on ' + 'the original process handler declaration.', + ) + spec.exit_code(301, 'ERROR_SUB_PROCESS_EXCEPTED', message='The sub process excepted.') + spec.exit_code(302, 'ERROR_SUB_PROCESS_KILLED', message='The sub process was killed.') + spec.exit_code( + 401, 'ERROR_MAXIMUM_ITERATIONS_EXCEEDED', message='The maximum number of iterations was exceeded.' + ) + spec.exit_code( + 402, + 'ERROR_SECOND_CONSECUTIVE_UNHANDLED_FAILURE', + message='The process failed for an unknown reason, twice in a row.', + ) def setup(self) -> None: """Initialize context variables that are used during the logical flow of the `BaseRestartWorkChain`.""" - overrides = self.inputs.handler_overrides.get_dict() if (self.inputs and - 'handler_overrides' in self.inputs) else {} + overrides = ( + self.inputs.handler_overrides.get_dict() if (self.inputs and 'handler_overrides' in self.inputs) else {} + ) self.ctx.handler_overrides = overrides self.ctx.process_name = self.process_class.__name__ self.ctx.unhandled_failure = False @@ -208,7 +220,7 @@ def run_process(self) -> ToContext: return ToContext(children=append_(node)) - def inspect_process(self) -> Optional['ExitCode']: # pylint: disable=too-many-branches + def inspect_process(self) -> Optional['ExitCode']: """Analyse the results of the previous process and call the handlers when necessary. If the process is excepted or killed, the work chain will abort. Otherwise any attached handlers will be called @@ -232,16 +244,15 @@ def inspect_process(self) -> Optional['ExitCode']: # pylint: disable=too-many-b node = self.ctx.children[self.ctx.iteration - 1] if node.is_excepted: - return self.exit_codes.ERROR_SUB_PROCESS_EXCEPTED # pylint: disable=no-member + return self.exit_codes.ERROR_SUB_PROCESS_EXCEPTED if node.is_killed: - return self.exit_codes.ERROR_SUB_PROCESS_KILLED # pylint: disable=no-member + return self.exit_codes.ERROR_SUB_PROCESS_KILLED last_report = None # Sort the handlers with a priority defined, based on their priority in reverse order for _, handler in sorted(self.get_process_handlers_by_priority(), key=lambda e: e[0], reverse=True): - # Even though the ``handler`` is an instance method, the ``get_process_handlers_by_priority`` method returns # unbound methods so we have to pass in ``self`` manually. Also, always pass the ``node`` as an argument # because the ``process_handler`` decorator with which the handler is decorated relies on this behavior. 
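To make the handler contract described in the comment above concrete, here is a minimal sketch of a restart work chain with a single registered handler. The work chain name, handler name and priority are illustrative; ``ArithmeticAddCalculation`` (entry point ``core.arithmetic.add``) is a real calculation shipped with aiida-core, used here only as a stand-in sub-process.

from aiida.engine import BaseRestartWorkChain, ProcessHandlerReport, process_handler, while_
from aiida.plugins import CalculationFactory

ArithmeticAddCalculation = CalculationFactory('core.arithmetic.add')


class AddRestartWorkChain(BaseRestartWorkChain):
    """Restart the ``ArithmeticAddCalculation`` whenever it fails."""

    _process_class = ArithmeticAddCalculation

    @classmethod
    def define(cls, spec):
        super().define(spec)
        spec.expose_inputs(ArithmeticAddCalculation, namespace='add')
        spec.outline(
            cls.setup,
            while_(cls.should_run_process)(
                cls.run_process,
                cls.inspect_process,
            ),
            cls.results,
        )

    def setup(self):
        # ``run_process`` takes the inputs for each iteration from ``self.ctx.inputs``.
        super().setup()
        self.ctx.inputs = self.exposed_inputs(ArithmeticAddCalculation, namespace='add')

    @process_handler(priority=400)
    def handle_any_failure(self, node):
        # ``node`` is the node of the last sub-process, passed in explicitly as
        # noted in the comment above.
        if node.is_failed:
            self.report(f'{node.process_label}<{node.pk}> failed, restarting')
            return ProcessHandlerReport(do_break=True)
        return None

Returning ``None`` means the handler took no action; ``do_break=True`` tells ``inspect_process`` to skip the remaining, lower-priority handlers for the current iteration.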
@@ -266,7 +277,7 @@ def inspect_process(self) -> Optional['ExitCode']: # pylint: disable=too-many-b if self.ctx.unhandled_failure: template = '{}<{}> failed and error was not handled for the second consecutive time, aborting' self.report(template.format(*report_args)) - return self.exit_codes.ERROR_SECOND_CONSECUTIVE_UNHANDLED_FAILURE # pylint: disable=no-member + return self.exit_codes.ERROR_SECOND_CONSECUTIVE_UNHANDLED_FAILURE self.ctx.unhandled_failure = True self.report('{}<{}> failed and error was not handled, restarting once more'.format(*report_args)) @@ -318,7 +329,7 @@ def results(self) -> Optional['ExitCode']: f'reached the maximum number of iterations {max_iterations}: ' f'last ran {self.ctx.process_name}<{node.pk}>' ) - return self.exit_codes.ERROR_MAXIMUM_ITERATIONS_EXCEEDED # pylint: disable=no-member + return self.exit_codes.ERROR_MAXIMUM_ITERATIONS_EXCEEDED self.report(f'work chain completed after {self.ctx.iteration} iterations') self._attach_outputs(node) @@ -334,7 +345,6 @@ def _attach_outputs(self, node) -> Mapping[str, orm.Node]: existing_outputs = self.node.base.links.get_outgoing(link_type=LinkType.RETURN).all_link_labels() for name, port in self.spec().outputs.items(): - try: output = outputs[name] except KeyError: @@ -356,7 +366,7 @@ def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) # try retrieving process class - self.process_class # pylint: disable=pointless-statement + self.process_class @classmethod def is_process_handler(cls, process_handler_name: Union[str, FunctionType]) -> bool: @@ -365,7 +375,6 @@ def is_process_handler(cls, process_handler_name: Union[str, FunctionType]) -> b :param process_handler_name: string name of the instance method :return: boolean, True if corresponds to process handler, False otherwise """ - # pylint: disable=comparison-with-callable if isinstance(process_handler_name, str): handler = getattr(cls, process_handler_name, {}) else: @@ -382,7 +391,6 @@ def get_process_handlers_by_priority(self) -> List[Tuple[int, FunctionType]]: handlers = [] for handler in self.get_process_handlers(): - overrides = self.ctx.handler_overrides.get(handler.__name__, {}) enabled = None @@ -414,7 +422,7 @@ def on_terminated(self): for called_descendant in self.node.called_descendants: if isinstance(called_descendant, orm.CalcJobNode): try: - called_descendant.outputs.remote_folder._clean() # pylint: disable=protected-access + called_descendant.outputs.remote_folder._clean() cleaned_calcs.append(str(called_descendant.pk)) except (IOError, OSError, KeyError): pass @@ -434,14 +442,12 @@ def _wrap_bare_dict_inputs(self, port_namespace: 'PortNamespace', inputs: Dict[s wrapped = {} for key, value in inputs.items(): - if key not in port_namespace: wrapped[key] = value continue port = port_namespace[key] - valid_types = port.valid_type \ - if isinstance(port.valid_type, (list, tuple)) else (port.valid_type,) # type: ignore[redundant-expr] + valid_types = port.valid_type if isinstance(port.valid_type, (list, tuple)) else (port.valid_type,) # type: ignore[redundant-expr] if isinstance(port, PortNamespace): wrapped[key] = self._wrap_bare_dict_inputs(port, value) diff --git a/aiida/engine/processes/workchains/utils.py b/aiida/engine/processes/workchains/utils.py index f52a886ee4..6ee279596b 100644 --- a/aiida/engine/processes/workchains/utils.py +++ b/aiida/engine/processes/workchains/utils.py @@ -10,7 +10,7 @@ """Utilities for `WorkChain` implementations.""" from functools import partial from inspect import getfullargspec -from types 
import FunctionType # pylint: disable=no-name-in-module +from types import FunctionType from typing import List, NamedTuple, Optional, Union from wrapt import decorator @@ -37,6 +37,7 @@ class ProcessHandlerReport(NamedTuple): which has status `0` meaning that the work chain step will be considered successful and the work chain will continue to the next step. """ + do_break: bool = False exit_code: ExitCode = ExitCode() @@ -46,7 +47,7 @@ def process_handler( *, priority: int = 0, exit_codes: Union[None, ExitCode, List[ExitCode]] = None, - enabled: bool = True + enabled: bool = True, ) -> FunctionType: """Decorator to register a :class:`~aiida.engine.BaseRestartWorkChain` instance method as a process handler. @@ -78,9 +79,7 @@ def process_handler( basis through the input `handler_overrides`. """ if wrapped is None: - return partial( - process_handler, priority=priority, exit_codes=exit_codes, enabled=enabled - ) # type: ignore[return-value] + return partial(process_handler, priority=priority, exit_codes=exit_codes, enabled=enabled) # type: ignore[return-value] if not isinstance(wrapped, FunctionType): raise TypeError('first argument can only be an instance method, use keywords for decorator arguments.') @@ -108,7 +107,6 @@ def process_handler( @decorator def wrapper(wrapped, instance, args, kwargs): - # When the handler will be called by the `BaseRestartWorkChain` it will pass the node as the only argument node = args[0] @@ -119,7 +117,7 @@ def wrapper(wrapped, instance, args, kwargs): # Append the name and return value of the current process handler to the `considered_handlers` extra. try: - considered_handlers = instance.node.base.extras.get(instance._considered_handlers_extra, []) # pylint: disable=protected-access + considered_handlers = instance.node.base.extras.get(instance._considered_handlers_extra, []) current_process = considered_handlers[-1] except IndexError: # The extra was never initialized, so we skip this functionality @@ -130,8 +128,8 @@ def wrapper(wrapped, instance, args, kwargs): if isinstance(serialized, ProcessHandlerReport): serialized = {'do_break': serialized.do_break, 'exit_status': serialized.exit_code.status} current_process.append((wrapped.__name__, serialized)) - instance.node.base.extras.set(instance._considered_handlers_extra, considered_handlers) # pylint: disable=protected-access + instance.node.base.extras.set(instance._considered_handlers_extra, considered_handlers) return result - return wrapper(wrapped) # pylint: disable=no-value-for-parameter + return wrapper(wrapped) diff --git a/aiida/engine/processes/workchains/workchain.py b/aiida/engine/processes/workchains/workchain.py index e17f816c37..c4d74fef8c 100644 --- a/aiida/engine/processes/workchains/workchain.py +++ b/aiida/engine/processes/workchains/workchain.py @@ -18,9 +18,8 @@ from plumpy.persistence import auto_persist from plumpy.process_states import Continue, Wait from plumpy.processes import ProcessStateMachineMeta -from plumpy.workchains import Stepper +from plumpy.workchains import Stepper, _PropagateReturn, if_, return_, while_ from plumpy.workchains import WorkChainSpec as PlumpyWorkChainSpec -from plumpy.workchains import _PropagateReturn, if_, return_, while_ from aiida.common import exceptions from aiida.common.extendeddicts import AttributeDict @@ -34,7 +33,7 @@ from .awaitable import Awaitable, AwaitableAction, AwaitableTarget, construct_awaitable if t.TYPE_CHECKING: - from aiida.engine.runners import Runner # pylint: disable=unused-import + from aiida.engine.runners import Runner 
+ from aiida.engine.runners import Runner
__all__ = ('WorkChain', 'if_', 'while_', 'return_') @@ -62,7 +61,7 @@ def private_method(self): __SENTINEL = object() - def __new__(mcs, name, bases, namespace, **kwargs): + def __new__(mcs, name, bases, namespace, **kwargs): # noqa: N804 """Collect all methods that were marked as protected and raise if the subclass defines it. :raises RuntimeError: If the new class defines (i.e. overrides) a method that was decorated with ``final``. @@ -76,25 +75,25 @@ def __new__(mcs, name, bases, namespace, **kwargs): return super().__new__(mcs, name, bases, namespace, **kwargs) @classmethod - def __is_final(mcs, method) -> bool: # pylint: disable=unused-private-member + def __is_final(mcs, method) -> bool: # noqa: N804 """Return whether the method has been decorated by the ``final`` classmethod. :return: Boolean, ``True`` if the method is marked as final, ``False`` otherwise. """ try: - return method.__final is mcs.__SENTINEL # pylint: disable=protected-access + return method.__final is mcs.__SENTINEL except AttributeError: return False @classmethod - def final(mcs, method: MethodType) -> MethodType: + def final(mcs, method: MethodType) -> MethodType: # noqa: N804 """Decorate a method with this method to protect it from being overridden. Adds the ``__SENTINEL`` object as the ``__final`` private attribute to the given ``method`` and wraps it in the ``typing.final`` decorator. The latter indicates to typing systems that it cannot be overridden in subclasses. """ - method.__final = mcs.__SENTINEL # type: ignore[attr-defined] # pylint: disable=protected-access,unused-private-member + method.__final = mcs.__SENTINEL # type: ignore[attr-defined] return t.final(method) @@ -112,7 +111,7 @@ def __init__( inputs: dict | None = None, logger: logging.Logger | None = None, runner: 'Runner' | None = None, - enable_persistence: bool = True + enable_persistence: bool = True, ) -> None: """Construct a WorkChain instance. @@ -187,8 +186,7 @@ def on_run(self): self.node.set_stepper_state_info(str(self._stepper)) def _resolve_nested_context(self, key: str) -> tuple[AttributeDict, str]: - """ - Returns a reference to a sub-dictionary of the context and the last key, + """Returns a reference to a sub-dictionary of the context and the last key, after resolving a potentially segmented key where required sub-dictionaries are created as needed. :param key: A key into the context, where words before a dot are interpreted as a key for a sub-dictionary @@ -210,7 +208,7 @@ def _resolve_nested_context(self, key: str) -> tuple[AttributeDict, str]: # (subclasses of AttributeDict) but after resolution of an Awaitable this will be the value itself # * assumption: a resolved value is never a plain AttributeDict, on the other hand if a resolved Awaitable # would be an AttributeDict we can append things to it since the order of tasks is maintained. 
- if type(ctx) != AttributeDict: # pylint: disable=C0123 + if type(ctx) != AttributeDict: raise ValueError( f'Can not update the context for key `{key}`:' f' found instance of `{type(ctx)}` at `{".".join(ctx_path[:index+1])}`, expected AttributeDict' @@ -248,7 +246,6 @@ def _resolve_awaitable(self, awaitable: Awaitable, value: t.Any) -> None: :param awaitable: the awaitable to resolve """ - ctx, key = self._resolve_nested_context(awaitable.key) if awaitable.action == AwaitableAction.ASSIGN: @@ -363,7 +360,7 @@ def on_exiting(self) -> None: super().on_exiting() try: self._store_nodes(self.ctx) - except Exception: # pylint: disable=broad-except + except Exception: # An uncaught exception here will have bizarre and disastrous consequences self.logger.exception('exception in _store_nodes called in on_exiting') @@ -388,7 +385,7 @@ def _action_awaitables(self) -> None: callback = functools.partial(self.call_soon, self._on_awaitable_finished, awaitable) self.runner.call_on_process_finish(awaitable.pk, callback) else: - assert f"invalid awaitable target '{awaitable.target}'" + raise AssertionError(f"invalid awaitable target '{awaitable.target}'") def _on_awaitable_finished(self, awaitable: Awaitable) -> None: """Callback function, for when an awaitable process instance is completed. diff --git a/aiida/engine/runners.py b/aiida/engine/runners.py index 18260f7806..ba74126bcb 100644 --- a/aiida/engine/runners.py +++ b/aiida/engine/runners.py @@ -7,7 +7,6 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=global-statement """Runners that can run and submit processes.""" from __future__ import annotations @@ -16,8 +15,8 @@ import logging import signal import threading -from typing import Any, Callable, Dict, NamedTuple, Optional, Tuple, Type, Union import uuid +from typing import Any, Callable, Dict, NamedTuple, Optional, Tuple, Type, Union import kiwipy from plumpy.communications import wrap_communicator @@ -48,12 +47,12 @@ class ResultAndPk(NamedTuple): pk: int | None -TYPE_RUN_PROCESS = Union[Process, Type[Process], ProcessBuilder] # pylint: disable=invalid-name +TYPE_RUN_PROCESS = Union[Process, Type[Process], ProcessBuilder] # run can also be process function, but it is not clear what type this should be -TYPE_SUBMIT_PROCESS = Union[Process, Type[Process], ProcessBuilder] # pylint: disable=invalid-name +TYPE_SUBMIT_PROCESS = Union[Process, Type[Process], ProcessBuilder] -class Runner: # pylint: disable=too-many-public-methods +class Runner: """Class that can launch processes by running in the current interpreter or by submitting them to the daemon.""" _persister: Optional[Persister] = None @@ -67,7 +66,7 @@ def __init__( loop: Optional[asyncio.AbstractEventLoop] = None, communicator: Optional[kiwipy.Communicator] = None, rmq_submit: bool = False, - persister: Optional[Persister] = None + persister: Optional[Persister] = None, ): """Construct a new runner. 
@@ -78,8 +77,9 @@ def __init__( :param persister: the persister to use to persist processes """ - assert not (rmq_submit and persister is None), \ - 'Must supply a persister if you want to submit using communicator' + assert not ( + rmq_submit and persister is None + ), 'Must supply a persister if you want to submit using communicator' set_event_loop_policy() self._loop = loop if loop is not None else asyncio.get_event_loop() @@ -167,13 +167,14 @@ def close(self) -> None: self._closed = True def instantiate_process(self, process: TYPE_RUN_PROCESS, **inputs): - from .utils import instantiate_process # pylint: disable=no-name-in-module + from .utils import instantiate_process + return instantiate_process(self, process, **inputs) def submit(self, process: TYPE_SUBMIT_PROCESS, inputs: dict[str, Any] | None = None, **kwargs: Any): - """ - Submit the process with the supplied inputs to this runner immediately returning control to - the interpreter. The return value will be the calculation node of the submitted process + """Submit the process with the supplied inputs to this runner immediately returning control to the interpreter. + + The return value will be the calculation node of the submitted process :param process: the process class to submit :param inputs: the inputs to be passed to the process @@ -205,8 +206,7 @@ def submit(self, process: TYPE_SUBMIT_PROCESS, inputs: dict[str, Any] | None = N def schedule( self, process: TYPE_SUBMIT_PROCESS, inputs: dict[str, Any] | None = None, **kwargs: Any ) -> ProcessNode: - """ - Schedule a process to be executed by this runner + """Schedule a process to be executed by this runner. :param process: the process class to submit :param inputs: the inputs to be passed to the process @@ -220,12 +220,11 @@ def schedule( self.loop.create_task(process_inited.step_until_terminated()) return process_inited.node - def _run(self, - process: TYPE_RUN_PROCESS, - inputs: dict[str, Any] | None = None, - **kwargs: Any) -> Tuple[Dict[str, Any], ProcessNode]: - """ - Run the process with the supplied inputs in this runner that will block until the process is completed. + def _run( + self, process: TYPE_RUN_PROCESS, inputs: dict[str, Any] | None = None, **kwargs: Any + ) -> Tuple[Dict[str, Any], ProcessNode]: + """Run the process with the supplied inputs in this runner that will block until the process is completed. + The return value will be the results of the completed process :param process: the process class or process function to run @@ -265,8 +264,8 @@ def kill_process(_num, _frame): return process_inited.outputs, process_inited.node def run(self, process: TYPE_RUN_PROCESS, inputs: dict[str, Any] | None = None, **kwargs: Any) -> Dict[str, Any]: - """ - Run the process with the supplied inputs in this runner that will block until the process is completed. + """Run the process with the supplied inputs in this runner that will block until the process is completed. + The return value will be the results of the completed process :param process: the process class or process function to run @@ -279,8 +278,8 @@ def run(self, process: TYPE_RUN_PROCESS, inputs: dict[str, Any] | None = None, * def run_get_node( self, process: TYPE_RUN_PROCESS, inputs: dict[str, Any] | None = None, **kwargs: Any ) -> ResultAndNode: - """ - Run the process with the supplied inputs in this runner that will block until the process is completed. + """Run the process with the supplied inputs in this runner that will block until the process is completed. 
+ The return value will be the results of the completed process :param process: the process class or process function to run @@ -291,8 +290,8 @@ def run_get_node( return ResultAndNode(result, node) def run_get_pk(self, process: TYPE_RUN_PROCESS, inputs: dict[str, Any] | None = None, **kwargs: Any) -> ResultAndPk: - """ - Run the process with the supplied inputs in this runner that will block until the process is completed. + """Run the process with the supplied inputs in this runner that will block until the process is completed. + The return value will be the results of the completed process :param process: the process class or process function to run @@ -316,7 +315,7 @@ def call_on_process_finish(self, pk: int, callback: Callable[[], Any]) -> None: subscriber_identifier = str(uuid.uuid4()) event = threading.Event() - def inline_callback(event, *args, **kwargs): # pylint: disable=unused-argument + def inline_callback(event, *args, **kwargs): """Callback to wrap the actual callback, that will always remove the subscriber that will be registered. As soon as the callback is called successfully once, the `event` instance is toggled, such that if this diff --git a/aiida/engine/transports.py b/aiida/engine/transports.py index df891e7c96..4c257fb132 100644 --- a/aiida/engine/transports.py +++ b/aiida/engine/transports.py @@ -24,7 +24,7 @@ class TransportRequest: - """ Information kept about request for a transport object """ + """Information kept about request for a transport object""" def __init__(self): super().__init__() @@ -33,8 +33,7 @@ def __init__(self): class TransportQueue: - """ - A queue to get transport objects from authinfo. This class allows clients + """A queue to get transport objects from authinfo. This class allows clients to register their interest in a transport object which will be provided at some point in the future. @@ -45,21 +44,18 @@ class TransportQueue: """ def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None): - """ - :param loop: An asyncio event, will use `asyncio.get_event_loop()` if not supplied - """ + """:param loop: An asyncio event, will use `asyncio.get_event_loop()` if not supplied""" self._loop = loop if loop is not None else asyncio.get_event_loop() self._transport_requests: Dict[Hashable, TransportRequest] = {} @property def loop(self) -> asyncio.AbstractEventLoop: - """ Get the loop being used by this transport queue """ + """Get the loop being used by this transport queue""" return self._loop @contextlib.contextmanager def request_transport(self, authinfo: AuthInfo) -> Iterator[Awaitable['Transport']]: - """ - Request a transport from an authinfo. Because the client is not allowed to + """Request a transport from an authinfo. 
Because the client is not allowed to request a transport immediately they will instead be given back a future that can be awaited to get the transport:: @@ -83,13 +79,13 @@ async def transport_task(transport_queue, authinfo): safe_open_interval = transport.get_safe_open_interval() def do_open(): - """ Actually open the transport """ + """Actually open the transport""" if transport_request.count > 0: # The user still wants the transport so open it _LOGGER.debug('Transport request opening transport for %s', authinfo) try: transport.open() - except Exception as exception: # pylint: disable=broad-except + except Exception as exception: _LOGGER.error('exception occurred while trying to open transport:\n %s', exception) transport_request.future.set_exception(exception) @@ -108,7 +104,7 @@ def do_open(): try: transport_request.count += 1 yield transport_request.future - except asyncio.CancelledError: # pylint: disable=try-except-raise + except asyncio.CancelledError: # note this is only required in python<=3.7, # where asyncio.CancelledError inherits from Exception _LOGGER.debug('Transport task cancelled') diff --git a/aiida/engine/utils.py b/aiida/engine/utils.py index bb40f1d908..42f0a55800 100644 --- a/aiida/engine/utils.py +++ b/aiida/engine/utils.py @@ -7,15 +7,14 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=invalid-name """Utilities for the workflow engine.""" from __future__ import annotations import asyncio import contextlib -from datetime import datetime import inspect import logging +from datetime import datetime from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterator, List, Optional, Tuple, Type, Union if TYPE_CHECKING: @@ -52,8 +51,7 @@ def prepare_inputs(inputs: dict[str, Any] | None = None, **kwargs: Any) -> dict[ def instantiate_process( runner: 'Runner', process: Union['Process', Type['Process'], 'ProcessBuilder'], **inputs ) -> 'Process': - """ - Return an instance of the process with the given inputs. The function can deal with various types + """Return an instance of the process with the given inputs. 
The function can deal with various types of the `process`: * Process instance: will simply return the instance @@ -75,7 +73,7 @@ def instantiate_process( if isinstance(process, ProcessBuilder): builder = process process_class = builder.process_class - inputs.update(**builder._inputs(prune=True)) # pylint: disable=protected-access + inputs.update(**builder._inputs(prune=True)) elif is_process_function(process): process_class = process.process_class # type: ignore[attr-defined] elif inspect.isclass(process) and issubclass(process, Process): @@ -96,8 +94,7 @@ def interrupt(self, reason: Exception) -> None: self.set_exception(reason) async def with_interrupt(self, coro: Awaitable[Any]) -> Any: - """ - return result of a coroutine which will be interrupted if this future is interrupted :: + """Return result of a coroutine which will be interrupted if this future is interrupted :: import asyncio loop = asyncio.get_event_loop() @@ -121,17 +118,14 @@ async def with_interrupt(self, coro: Awaitable[Any]) -> Any: def interruptable_task( - coro: Callable[[InterruptableFuture], Awaitable[Any]], - loop: Optional[asyncio.AbstractEventLoop] = None + coro: Callable[[InterruptableFuture], Awaitable[Any]], loop: Optional[asyncio.AbstractEventLoop] = None ) -> InterruptableFuture: - """ - Turn the given coroutine into an interruptable task by turning it into an InterruptableFuture and returning it. + """Turn the given coroutine into an interruptable task by turning it into an InterruptableFuture and returning it. :param coro: the coroutine that should be made interruptable with object of InterruptableFuture as last parameter :param loop: the event loop in which to run the coroutine, by default uses asyncio.get_event_loop() :return: an InterruptableFuture """ - loop = loop or asyncio.get_event_loop() future = InterruptableFuture() @@ -139,13 +133,15 @@ async def execute_coroutine(): """Coroutine that wraps the original coroutine and sets its result on the future only if not already set.""" try: result = await coro(future) - except Exception as exception: # pylint: disable=broad-except + except Exception as exception: if not future.done(): future.set_exception(exception) else: LOGGER.warning( - 'Interruptable future set to %s before its coro %s is done. %s', future.result(), coro.__name__, - str(exception) + 'Interruptable future set to %s before its coro %s is done. %s', + future.result(), + coro.__name__, + str(exception), ) else: # If the future has not been set elsewhere, i.e.
by the interrupt call, by the time that the coroutine @@ -159,8 +155,7 @@ async def execute_coroutine(): def ensure_coroutine(fct: Callable[..., Any]) -> Callable[..., Awaitable[Any]]: - """ - Ensure that the given function ``fct`` is a coroutine + """Ensure that the given function ``fct`` is a coroutine If the passed function is not already a coroutine, it will be made to be a coroutine @@ -181,10 +176,9 @@ async def exponential_backoff_retry( initial_interval: Union[int, float] = 10.0, max_attempts: int = 5, logger: Optional[logging.Logger] = None, - ignore_exceptions: Union[None, Type[Exception], Tuple[Type[Exception], ...]] = None + ignore_exceptions: Union[None, Type[Exception], Tuple[Type[Exception], ...]] = None, ) -> Any: - """ - Coroutine to call a function, recalling it with an exponential backoff in the case of an exception + """Coroutine to call a function, recalling it with an exponential backoff in the case of an exception This coroutine will loop ``max_attempts`` times, calling the ``fct`` function, breaking immediately when the call finished without raising an exception, at which point the result will be returned. If an exception is caught, the @@ -208,8 +202,7 @@ async def exponential_backoff_retry( try: result = await coro() break # Finished successfully - except Exception as exception: # pylint: disable=broad-except - + except Exception as exception: # Re-raise exceptions that should be ignored if ignore_exceptions is not None and isinstance(exception, ignore_exceptions): raise @@ -247,13 +240,13 @@ def is_process_scoped() -> bool: :returns: True if the current scope is within a nested process, False otherwise """ from .processes.process import Process + return Process.current() is not None @contextlib.contextmanager def loop_scope(loop) -> Iterator[None]: - """ - Make an event loop current for the scope of the context + """Make an event loop current for the scope of the context :param loop: The event loop to make current for the duration of the scope """ @@ -267,15 +260,14 @@ def loop_scope(loop) -> Iterator[None]: def set_process_state_change_timestamp(process: 'Process') -> None: - """ - Set the global setting that reflects the last time a process changed state, for the process type + """Set the global setting that reflects the last time a process changed state, for the process type of the given process, to the current timestamp. The process type will be determined based on the class of the calculation node it has as its database container. :param process: the Process instance that changed its state """ from aiida.common import timezone - from aiida.manage import get_manager # pylint: disable=cyclic-import + from aiida.manage import get_manager from aiida.orm import CalculationNode, ProcessNode, WorkflowNode if isinstance(process.node, CalculationNode): @@ -297,8 +289,7 @@ def set_process_state_change_timestamp(process: 'Process') -> None: def get_process_state_change_timestamp(process_type: Optional[str] = None) -> Optional[datetime]: - """ - Get the global setting that reflects the last time a process of the given process type changed its state. + """Get the global setting that reflects the last time a process of the given process type changed its state. The returned value will be the corresponding timestamp or None if the setting does not exist. :param process_type: optional process type for which to get the latest state change timestamp. 
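The retry coroutine documented in the hunks above can be exercised on its own. A minimal sketch, assuming an AiiDA installation is importable; the flaky callable and the short intervals are invented for the demo, and exhausting ``max_attempts`` re-raises the last exception:

```python
# Minimal demo of exponential_backoff_retry: the callable fails twice and then
# succeeds; waits grow as 0.1 s, 0.2 s, ... up to max_attempts. The flaky
# function is a stand-in for e.g. an unreliable transport operation.
import asyncio

from aiida.engine.utils import exponential_backoff_retry

attempts = {'count': 0}


def flaky():
    attempts['count'] += 1
    if attempts['count'] < 3:
        raise ConnectionError('transient failure')
    return 'ok'


async def main():
    result = await exponential_backoff_retry(flaky, initial_interval=0.1, max_attempts=5)
    print(f'{result} after {attempts["count"]} attempts')


asyncio.run(main())
```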
@@ -306,7 +297,7 @@ def get_process_state_change_timestamp(process_type: Optional[str] = None) -> Op known process types will be returned. :return: a timestamp or None """ - from aiida.manage import get_manager # pylint: disable=cyclic-import + from aiida.manage import get_manager valid_process_types = ['calculation', 'work'] diff --git a/aiida/manage/__init__.py b/aiida/manage/__init__.py index a745def690..cf7524f5ab 100644 --- a/aiida/manage/__init__.py +++ b/aiida/manage/__init__.py @@ -7,8 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -""" -Managing an AiiDA instance: +"""Managing an AiiDA instance: * configuration file * profiles @@ -22,8 +21,7 @@ # AUTO-GENERATED -# yapf: disable -# pylint: disable=wildcard-import +# fmt: off from .caching import * from .configuration import * @@ -58,4 +56,4 @@ 'upgrade_config', ) -# yapf: enable +# fmt: on diff --git a/aiida/manage/caching.py b/aiida/manage/caching.py index be31d024a1..c728e2b811 100644 --- a/aiida/manage/caching.py +++ b/aiida/manage/caching.py @@ -10,11 +10,11 @@ """Definition of caching mechanism and configuration for calculations.""" from __future__ import annotations +import keyword +import re from collections import namedtuple from contextlib import contextmanager, suppress from enum import Enum -import keyword -import re from aiida.common import exceptions from aiida.common.lang import type_check @@ -66,7 +66,6 @@ def get_options(self, strict: bool = False): :param strict: When set to ``True``, the function will actually try to resolve the identifier by loading it and if it fails, an exception is raised. """ - if self._default_all == 'disable': return False, [], [] @@ -238,7 +237,6 @@ def _validate_identifier_pattern(*, identifier: str, strict: bool = False): :raises ValueError: If the identifier is an invalid identifier. :raises ValueError: If ``strict=True`` and the identifier cannot be successfully loaded. """ - # pylint: disable=too-many-branches import importlib from aiida.common.exceptions import EntryPointError @@ -260,8 +258,8 @@ def _validate_identifier_pattern(*, identifier: str, strict: bool = False): for group_name in ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP ): raise ValueError( - common_error_msg + - f'Group name pattern `{group_pattern}` does not match any of the AiiDA entry point group names.' + common_error_msg + + f'Group name pattern `{group_pattern}` does not match any of the AiiDA entry point group names.' ) # If strict mode is enabled and the identifier is explicit, i.e., doesn't contain a wildcard, try to load it. @@ -291,8 +289,8 @@ def _validate_identifier_pattern(*, identifier: str, strict: bool = False): if '*' in identifier_part: if not identifier_part.replace('*', 'a').isidentifier(): raise ValueError( - common_error_msg + - f'Identifier part `{identifier_part}` can not match a fully qualified Python name.' + common_error_msg + + f'Identifier part `{identifier_part}` can not match a fully qualified Python name.' 
) else: if not identifier_part.isidentifier(): diff --git a/aiida/manage/configuration/__init__.py b/aiida/manage/configuration/__init__.py index e4cb4cda37..9e24efd07a 100644 --- a/aiida/manage/configuration/__init__.py +++ b/aiida/manage/configuration/__init__.py @@ -7,12 +7,12 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# ruff: noqa: E402 """Modules related to the configuration of an AiiDA instance.""" # AUTO-GENERATED -# yapf: disable -# pylint: disable=wildcard-import +# fmt: off from .config import * from .migrations import * @@ -36,25 +36,30 @@ 'upgrade_config', ) -# yapf: enable +# fmt: on # END AUTO-GENERATED -# pylint: disable=global-statement,redefined-outer-name,wrong-import-order __all__ += ( - 'get_config', 'get_config_option', 'get_config_path', 'get_profile', 'load_profile', 'reset_config', 'CONFIG' + 'get_config', + 'get_config_option', + 'get_config_path', + 'get_profile', + 'load_profile', + 'reset_config', + 'CONFIG', ) -from contextlib import contextmanager import os -from typing import TYPE_CHECKING, Any, Optional import warnings +from contextlib import contextmanager +from typing import TYPE_CHECKING, Any, Optional from aiida.common.warnings import AiidaDeprecationWarning if TYPE_CHECKING: - from aiida.manage.configuration import Config, Profile # pylint: disable=import-self + from aiida.manage.configuration import Config, Profile # global variables for aiida CONFIG: Optional['Config'] = None @@ -121,7 +126,7 @@ def _merge_deprecated_cache_yaml(config, filepath): 'cache_config.yml use is deprecated and support will be removed in `v3.0`. Merging into config.json and ' f'moving to: {cache_path_backup}', AiidaDeprecationWarning, - stacklevel=2 + stacklevel=2, ) with open(cache_path, 'r', encoding='utf8') as handle: @@ -130,8 +135,11 @@ def _merge_deprecated_cache_yaml(config, filepath): if profile_name not in config.profile_names: warnings.warn(f"Profile '{profile_name}' from cache_config.yml not in config.json, skipping", UserWarning) continue - for key, option_name in [('default', 'caching.default_enabled'), ('enabled', 'caching.enabled_for'), - ('disabled', 'caching.disabled_for')]: + for key, option_name in [ + ('default', 'caching.default_enabled'), + ('enabled', 'caching.enabled_for'), + ('disabled', 'caching.disabled_for'), + ]: if key in data: value = data[key] # in case of empty key @@ -154,6 +162,7 @@ def load_profile(profile: Optional[str] = None, allow_switch=False) -> 'Profile' if another profile has already been loaded and allow_switch is False """ from aiida.manage import get_manager + return get_manager().load_profile(profile, allow_switch) @@ -163,6 +172,7 @@ def get_profile() -> Optional['Profile']: :return: the globally loaded `Profile` instance or `None` """ from aiida.manage import get_manager + return get_manager().get_profile() @@ -176,6 +186,7 @@ def profile_context(profile: Optional[str] = None, allow_switch=False) -> 'Profi :return: a context manager for temporarily loading a profile """ from aiida.manage import get_manager + manager = get_manager() current_profile = manager.get_profile() manager.load_profile(profile, allow_switch) @@ -195,7 +206,7 @@ def create_profile( first_name: Optional[str] = None, last_name: Optional[str] = None, institution: Optional[str] = None, - **kwargs + **kwargs, ) -> Profile: """Create a new profile, initialise its storage and create a default user. 
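A usage sketch for the ``profile_context`` manager shown earlier in this file's hunks (the profile name is invented; on exit the previously loaded profile, if any, is restored):

```python
# Hedged sketch: temporarily switch to another profile for the duration of a
# block. 'archive-profile' is a made-up name; allow_switch permits replacing
# an already-loaded profile.
from aiida.manage.configuration import get_profile, profile_context

with profile_context('archive-profile', allow_switch=True):
    print(f'inside: {get_profile().name}')

print(f'outside: {get_profile()}')  # the previous profile (or None) is back
```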
@@ -233,7 +244,7 @@ def reset_config(): .. warning:: This is experimental functionality and should for now be used only internally. If the reset is unclean weird unknown side-effects may occur that end up corrupting or destroying data. """ - global CONFIG + global CONFIG # noqa: PLW0603 CONFIG = None @@ -255,7 +266,7 @@ def get_config(create=False): :rtype: :class:`~aiida.manage.configuration.config.Config` :raises aiida.common.ConfigurationError: if the configuration file could not be found, read or deserialized """ - global CONFIG + global CONFIG # noqa: PLW0603 if not CONFIG: CONFIG = load_config(create=create) @@ -264,10 +275,10 @@ def get_config(create=False): # If the user does not want to get AiiDA deprecation warnings, we disable them - this can be achieved with:: # verdi config warnings.showdeprecations False # Note that the AiidaDeprecationWarning does NOT inherit from DeprecationWarning - warnings.simplefilter('default', AiidaDeprecationWarning) # pylint: disable=no-member + warnings.simplefilter('default', AiidaDeprecationWarning) # This should default to 'once', i.e. once per different message else: - warnings.simplefilter('ignore', AiidaDeprecationWarning) # pylint: disable=no-member + warnings.simplefilter('ignore', AiidaDeprecationWarning) return CONFIG @@ -286,4 +297,5 @@ def get_config_option(option_name: str) -> Any: :raises `aiida.common.exceptions.ConfigurationError`: if the option is not found """ from aiida.manage import get_manager + return get_manager().get_option(option_name) diff --git a/aiida/manage/configuration/config.py b/aiida/manage/configuration/config.py index ff6c9cbcdf..e974c4c983 100644 --- a/aiida/manage/configuration/config.py +++ b/aiida/manage/configuration/config.py @@ -20,10 +20,10 @@ import io import json import os -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type import uuid +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type -from pydantic import ( # pylint: disable=no-name-in-module +from pydantic import ( BaseModel, ConfigDict, Field, @@ -64,8 +64,7 @@ class ProfileOptionsSchema(BaseModel, defer_build=True): ) daemon__timeout: int = Field( 2, - description= - 'Used to set default timeout in the :class:`aiida.engine.daemon.client.DaemonClient` for calls to the daemon.' + description='Used to set default timeout in the `DaemonClient` for calls to the daemon.', ) daemon__worker_process_slots: int = Field( 200, description='Maximum number of concurrent process tasks that each daemon worker can handle.' @@ -74,11 +73,11 @@ class ProfileOptionsSchema(BaseModel, defer_build=True): db__batch_size: int = Field( 100000, description='Batch size for bulk CREATE operations in the database. Avoids hitting MaxAllocSize of PostgreSQL ' - '(1GB) when creating large numbers of database records in one go.' + '(1GB) when creating large numbers of database records in one go.', ) verdi__shell__auto_import: str = Field( ':', - description='Additional modules/functions/classes to be automatically loaded in `verdi shell`, split by `:`.' + description='Additional modules/functions/classes to be automatically loaded in `verdi shell`, split by `:`.', ) logging__aiida_loglevel: LogLevels = Field( 'REPORT', description='Minimum level to log to daemon log and the `DbLog` table for the `aiida` logger.' 
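A note on the field names in this schema: the double underscores encode the dots of the user-facing option names (``daemon__timeout`` corresponds to ``daemon.timeout``), and lookups translate back with ``key.replace('__', '.')``, as the options module later in this diff also shows. A stand-in sketch of the convention:

```python
# Stand-in model illustrating the naming convention of the options schemas:
# pydantic field names use '__' where option names use '.'.
from pydantic import BaseModel, Field


class OptionsSketch(BaseModel):
    daemon__timeout: int = Field(2, description='Default timeout for daemon calls.')
    daemon__worker_process_slots: int = Field(200, description='Max tasks per daemon worker.')


def option_names(model: type[BaseModel]) -> list:
    """Translate field names back to dotted option names."""
    return [name.replace('__', '.') for name in model.model_fields]


print(option_names(OptionsSketch))  # ['daemon.timeout', 'daemon.worker_process_slots']
print(OptionsSketch().daemon__timeout)  # 2
```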
@@ -131,6 +130,7 @@ class ProfileOptionsSchema(BaseModel, defer_build=True): def validate_caching_identifier_pattern(cls, value: List[str]) -> List[str]: """Validate the caching identifier patterns.""" from aiida.manage.caching import _validate_identifier_pattern + for identifier in value: _validate_identifier_pattern(identifier=identifier, strict=True) return value @@ -138,6 +138,7 @@ def validate_caching_identifier_pattern(cls, value: List[str]) -> List[str]: class GlobalOptionsSchema(ProfileOptionsSchema, defer_build=True): """Schema for the global options of an AiiDA instance.""" + autofill__user__email: Optional[str] = Field( None, description='Default user email to use when creating new profiles.' ) @@ -155,7 +156,7 @@ class GlobalOptionsSchema(ProfileOptionsSchema, defer_build=True): ) warnings__development_version: bool = Field( True, - description='Whether to print a warning when a profile is loaded while a development version is installed.' + description='Whether to print a warning when a profile is loaded while a development version is installed.', ) @@ -175,8 +176,9 @@ class ProcessControlConfig(BaseModel, defer_build=True): broker_host: str = Field('127.0.0.1', description='Hostname of the message broker.') broker_port: int = Field(5432, description='Port of the message broker.') broker_virtual_host: str = Field('', description='Virtual host to use for the message broker.') - broker_parameters: dict[ - str, Any] = Field(default_factory=dict, description='Arguments to be encoded as query parameters.') + broker_parameters: dict[str, Any] = Field( + default_factory=dict, description='Arguments to be encoded as query parameters.' + ) class ProfileSchema(BaseModel, defer_build=True): @@ -203,7 +205,7 @@ class ConfigSchema(BaseModel, defer_build=True): default_profile: Optional[str] = None -class Config: # pylint: disable=too-many-public-methods +class Config: """Object that represents the configuration file of an AiiDA instance.""" KEY_VERSION = 'CONFIG_VERSION' @@ -279,7 +281,7 @@ def validate(config: dict, filepath: Optional[str] = None): try: ConfigSchema(**config) except ValidationError as exception: - raise ConfigurationError(f'invalid config schema: {filepath}: {str(exception)}') + raise ConfigurationError(f'invalid config schema: {filepath}: {exception!s}') def __init__(self, filepath: str, config: dict, validate: bool = True): """Instantiate a configuration object from a configuration dictionary and its filepath. 
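The validation pattern used by ``Config.validate`` in the hunks above, parsing the raw dictionary with a pydantic model and converting ``ValidationError`` into a domain error that carries the file path, is self-contained enough to sketch. The error class and the trimmed schema below are stand-ins, not AiiDA's actual classes:

```python
# Stand-in sketch of the Config.validate pattern: pydantic does the schema
# work, the caller re-raises with file-path context.
from typing import Optional

from pydantic import BaseModel, ValidationError


class ConfigError(Exception):
    """Stand-in for aiida.common.ConfigurationError."""


class ConfigSketch(BaseModel):
    CONFIG_VERSION: dict
    profiles: dict = {}
    default_profile: Optional[str] = None


def validate(config: dict, filepath: Optional[str] = None) -> None:
    try:
        ConfigSketch(**config)
    except ValidationError as exception:
        raise ConfigError(f'invalid config schema: {filepath}: {exception!s}') from exception


validate({'CONFIG_VERSION': {'CURRENT': 9, 'OLDEST_COMPATIBLE': 9}}, 'config.json')  # passes
try:
    validate({'CONFIG_VERSION': 'not-a-dict'}, 'config.json')
except ConfigError as error:
    print(error)
```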
@@ -340,6 +342,7 @@ def handle_invalid(self, message): :param message: a string message to echo with describing the infraction """ from aiida.cmdline.utils import echo + filepath_backup = self._backup(self.filepath) echo.echo_warning(message) echo.echo_warning(f'backup of the original config file written to: `{filepath_backup}`') @@ -385,7 +388,7 @@ def version_oldest_compatible(self, version_oldest_compatible): def version_settings(self): return { self.KEY_VERSION_CURRENT: self.version, - self.KEY_VERSION_OLDEST_COMPATIBLE: self.version_oldest_compatible + self.KEY_VERSION_OLDEST_COMPATIBLE: self.version_oldest_compatible, } @property @@ -482,7 +485,8 @@ def create_profile(self, name: str, storage_cls: Type['StorageBackend'], storage raise EntryPointError(f'`{storage_cls}` does not have a registered entry point.') profile = Profile( - name, { + name, + { 'storage': { 'backend': storage_entry_point.name, 'config': storage_config, @@ -495,17 +499,17 @@ def create_profile(self, name: str, storage_cls: Type['StorageBackend'], storage 'broker_password': 'guest', 'broker_host': '127.0.0.1', 'broker_port': 5672, - 'broker_virtual_host': '' - } + 'broker_virtual_host': '', + }, }, - } + }, ) LOGGER.report('Initialising the storage backend.') try: with contextlib.redirect_stdout(io.StringIO()): profile.storage_cls.initialise(profile) - except Exception as exception: # pylint: disable=broad-except + except Exception as exception: raise StorageMigrationError( f'Storage backend initialisation failed, probably because the configuration is incorrect:\n{exception}' ) @@ -635,9 +639,8 @@ def set_option(self, option_name, option_value, scope=None, override=True): if not option.global_only and scope is not None: self.get_profile(scope).set_option(option.name, value, override=override) - else: - if option.name not in self.options or override: - self.options[option.name] = value + elif option.name not in self.options or override: + self.options[option.name] = value return value diff --git a/aiida/manage/configuration/migrations/__init__.py b/aiida/manage/configuration/migrations/__init__.py index 5eb7bf3bba..72fed50ab4 100644 --- a/aiida/manage/configuration/migrations/__init__.py +++ b/aiida/manage/configuration/migrations/__init__.py @@ -11,8 +11,7 @@ # AUTO-GENERATED -# yapf: disable -# pylint: disable=wildcard-import +# fmt: off from .migrations import * @@ -27,4 +26,4 @@ 'upgrade_config', ) -# yapf: enable +# fmt: on diff --git a/aiida/manage/configuration/migrations/migrations.py b/aiida/manage/configuration/migrations/migrations.py index 01ddfe7d4f..f9f1ba1b6c 100644 --- a/aiida/manage/configuration/migrations/migrations.py +++ b/aiida/manage/configuration/migrations/migrations.py @@ -14,8 +14,14 @@ from aiida.common.log import AIIDA_LOGGER __all__ = ( - 'CURRENT_CONFIG_VERSION', 'OLDEST_COMPATIBLE_CONFIG_VERSION', 'get_current_version', 'check_and_migrate_config', - 'config_needs_migrating', 'upgrade_config', 'downgrade_config', 'MIGRATIONS' + 'CURRENT_CONFIG_VERSION', + 'OLDEST_COMPATIBLE_CONFIG_VERSION', + 'get_current_version', + 'check_and_migrate_config', + 'config_needs_migrating', + 'upgrade_config', + 'downgrade_config', + 'MIGRATIONS', ) ConfigType = Dict[str, Any] @@ -55,6 +61,7 @@ def downgrade(self, config: ConfigType) -> None: class Initial(SingleMigration): """Base migration (no-op).""" + down_revision = 0 down_compatible = 0 up_revision = 1 @@ -75,6 +82,7 @@ class AddProfileUuid(SingleMigration): The profile uuid will be used as a general purpose identifier for the profile, in for 
example the RabbitMQ message queues and exchanges. """ + down_revision = 1 down_compatible = 0 up_revision = 2 @@ -82,6 +90,7 @@ class AddProfileUuid(SingleMigration): def upgrade(self, config: ConfigType) -> None: from uuid import uuid4 # we require this import here, to patch it in the tests + for profile in config.get('profiles', {}).values(): profile.setdefault('PROFILE_UUID', uuid4().hex) @@ -97,6 +106,7 @@ class SimplifyDefaultProfiles(SingleMigration): configuration no longer needs a value per process ('verdi', 'daemon'). We remove the dictionary 'default_profiles' and replace it with a simple value 'default_profile'. """ + down_revision = 2 down_compatible = 0 up_revision = 3 @@ -123,6 +133,7 @@ def downgrade(self, config: ConfigType) -> None: class AddMessageBroker(SingleMigration): """Add the configuration for the message broker, which was not configurable up to now.""" + down_revision = 3 down_compatible = 3 up_revision = 4 @@ -130,6 +141,7 @@ class AddMessageBroker(SingleMigration): def upgrade(self, config: ConfigType) -> None: from aiida.manage.external.rmq import BROKER_DEFAULTS + defaults = [ ('broker_protocol', BROKER_DEFAULTS.protocol), ('broker_username', BROKER_DEFAULTS.username), @@ -150,6 +162,7 @@ def downgrade(self, config: ConfigType) -> None: class SimplifyOptions(SingleMigration): """Remove unnecessary difference between file/internal representation of options""" + down_revision = 4 down_compatible = 3 up_revision = 5 @@ -205,6 +218,7 @@ class AbstractStorageAndProcess(SingleMigration): This allows for different storage backends to have different configuration. """ + down_revision = 5 down_compatible = 5 up_revision = 6 @@ -270,6 +284,7 @@ class MergeStorageBackendTypes(SingleMigration): The legacy name is stored under the `_v6_backend` key, to allow for downgrades. """ + down_revision = 6 down_compatible = 6 up_revision = 7 @@ -298,6 +313,7 @@ def downgrade(self, config: ConfigType) -> None: class AddTestProfileKey(SingleMigration): """Add the ``test_profile`` key.""" + down_revision = 7 down_compatible = 7 up_revision = 8 @@ -313,7 +329,6 @@ def downgrade(self, config: ConfigType) -> None: # Iterate over the fixed list of the profile names, since we are mutating the profiles dictionary. for profile_name in profile_names: - profile = profiles.pop(profile_name) profile_name_new = None test_profile = profile.pop('test_profile', False) # If absent, assume it is not a test profile @@ -331,14 +346,13 @@ def downgrade(self, config: ConfigType) -> None: ) if profile_name_new is not None: - if profile_name_new in profile_names: raise exceptions.ConfigurationError( f'cannot change `{profile_name}` to `{profile_name_new}` because it already exists.' ) CONFIG_LOGGER.warning(f'changing profile name from `{profile_name}` to `{profile_name_new}`.') - profile_name = profile_name_new + profile_name = profile_name_new # noqa: PLW2901 profile['test_profile'] = test_profile profiles[profile_name] = profile @@ -350,6 +364,7 @@ class AddPrefixToStorageBackendTypes(SingleMigration): At this point, it should only ever contain ``psql_dos`` which should therefore become ``core.psql_dos``. To cover for cases where people manually added a read only ``sqlite_zip`` profile, we also migrate that. 
""" + down_revision = 8 down_compatible = 8 up_revision = 9 @@ -406,9 +421,7 @@ def get_oldest_compatible_version(config): def upgrade_config( - config: ConfigType, - target: int = CURRENT_CONFIG_VERSION, - migrations: Iterable[Type[SingleMigration]] = MIGRATIONS + config: ConfigType, target: int = CURRENT_CONFIG_VERSION, migrations: Iterable[Type[SingleMigration]] = MIGRATIONS ) -> ConfigType: """Run the registered configuration migrations up to the target version. diff --git a/aiida/manage/configuration/options.py b/aiida/manage/configuration/options.py index d5925bf47d..eec455854f 100644 --- a/aiida/manage/configuration/options.py +++ b/aiida/manage/configuration/options.py @@ -49,6 +49,7 @@ def description(self) -> str: @property def global_only(self) -> bool: from .config import ProfileOptionsSchema + return self._name.replace('.', '__') not in ProfileOptionsSchema.model_fields def validate(self, value: Any) -> Any: @@ -87,12 +88,14 @@ def validate(self, value: Any) -> Any: def get_option_names() -> List[str]: """Return a list of available option names.""" from .config import GlobalOptionsSchema + return [key.replace('__', '.') for key in GlobalOptionsSchema.model_fields] def get_option(name: str) -> Option: """Return option.""" from .config import GlobalOptionsSchema + options = GlobalOptionsSchema.model_fields option_name = name.replace('.', '__') if option_name not in options: diff --git a/aiida/manage/configuration/profile.py b/aiida/manage/configuration/profile.py index 24d4ac640a..03b2dd470e 100644 --- a/aiida/manage/configuration/profile.py +++ b/aiida/manage/configuration/profile.py @@ -9,9 +9,9 @@ ########################################################################### """AiiDA profile related code""" import collections -from copy import deepcopy import os import pathlib +from copy import deepcopy from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Type from aiida.common import exceptions @@ -24,7 +24,7 @@ __all__ = ('Profile',) -class Profile: # pylint: disable=too-many-public-methods +class Profile: """Class that models a profile as it is stored in the configuration file of an AiiDA instance.""" KEY_UUID = 'PROFILE_UUID' @@ -59,6 +59,7 @@ def __init__(self, name: str, config: Mapping[str, Any], validate=True): # Create a default UUID if not specified if self._attributes.get(self.KEY_UUID, None) is None: from uuid import uuid4 + self._attributes[self.KEY_UUID] = uuid4().hex def __repr__(self) -> str: @@ -114,6 +115,7 @@ def set_storage(self, name: str, config: Dict[str, Any]) -> None: def storage_cls(self) -> Type['StorageBackend']: """Return the storage backend class for this profile.""" from aiida.plugins import StorageFactory + return StorageFactory(self.storage_backend) @property @@ -263,10 +265,10 @@ def filepaths(self): 'controller': 'circus.c.sock', 'pubsub': 'circus.p.sock', 'stats': 'circus.s.sock', - } + }, }, 'daemon': { 'log': str(DAEMON_LOG_DIR / f'aiida-{self.name}.log'), 'pid': str(DAEMON_DIR / f'aiida-{self.name}.pid'), - } + }, } diff --git a/aiida/manage/configuration/settings.py b/aiida/manage/configuration/settings.py index 0f68c4339c..553b3fe372 100644 --- a/aiida/manage/configuration/settings.py +++ b/aiida/manage/configuration/settings.py @@ -56,7 +56,6 @@ def create_instance_directories() -> None: try: for path in list_of_paths: - if path is directory_base and not path.exists(): warnings.warn(f'Creating AiiDA configuration folder `{path}`.') @@ -82,21 +81,17 @@ def set_configuration_directory(aiida_config_folder: pathlib.Path | None = 
None) In principle then, a configuration folder should always be found or automatically created. """ - # pylint: disable = global-statement - global AIIDA_CONFIG_FOLDER - global DAEMON_DIR - global DAEMON_LOG_DIR - global ACCESS_CONTROL_DIR + global AIIDA_CONFIG_FOLDER # noqa: PLW0603 + global DAEMON_DIR # noqa: PLW0603 + global DAEMON_LOG_DIR # noqa: PLW0603 + global ACCESS_CONTROL_DIR # noqa: PLW0603 if aiida_config_folder is not None: - AIIDA_CONFIG_FOLDER = aiida_config_folder elif environment_variable := os.environ.get(DEFAULT_AIIDA_PATH_VARIABLE): - # Loop over all the paths in the `AIIDA_PATH` variable to see if any of them contain a configuration folder for base_dir_path in [path for path in environment_variable.split(':') if path]: - AIIDA_CONFIG_FOLDER = pathlib.Path(base_dir_path).expanduser() # Only add the base config directory name to the base path if it does not already do so diff --git a/aiida/manage/external/__init__.py b/aiida/manage/external/__init__.py index 7ec4dcbd01..8c6e8a7d55 100644 --- a/aiida/manage/external/__init__.py +++ b/aiida/manage/external/__init__.py @@ -11,8 +11,7 @@ # AUTO-GENERATED -# yapf: disable -# pylint: disable=wildcard-import +# fmt: off from .rmq import * @@ -26,4 +25,4 @@ 'get_task_exchange_name', ) -# yapf: enable +# fmt: on diff --git a/aiida/manage/external/postgres.py b/aiida/manage/external/postgres.py index cd2b471667..79e00cc165 100644 --- a/aiida/manage/external/postgres.py +++ b/aiida/manage/external/postgres.py @@ -17,7 +17,7 @@ """ from typing import TYPE_CHECKING -from pgsu import DEFAULT_DSN as DEFAULT_DBINFO # pylint: disable=no-name-in-module +from pgsu import DEFAULT_DSN as DEFAULT_DBINFO from pgsu import PGSU, PostgresConnectionMode if TYPE_CHECKING: @@ -28,7 +28,7 @@ _DROP_USER_COMMAND = 'DROP USER "{}"' _CREATE_DB_COMMAND = ( 'CREATE DATABASE "{}" OWNER "{}" ENCODING \'UTF8\' ' - 'LC_COLLATE=\'en_US.UTF-8\' LC_CTYPE=\'en_US.UTF-8\' ' + "LC_COLLATE='en_US.UTF-8' LC_CTYPE='en_US.UTF-8' " 'TEMPLATE=template0' ) _DROP_DB_COMMAND = 'DROP DATABASE "{}"' @@ -39,8 +39,7 @@ class Postgres(PGSU): - """ - Adds convenience functions to :py:class:`pgsu.PGSU`. + """Adds convenience functions to :py:class:`pgsu.PGSU`. Provides convenience functions for * creating/dropping users @@ -76,7 +75,7 @@ def from_profile(cls, profile: 'Profile', **kwargs): dbinfo.update( dict( host=profile.storage_config['database_hostname'] or DEFAULT_DBINFO['host'], - port=profile.storage_config['database_port'] or DEFAULT_DBINFO['port'] + port=profile.storage_config['database_port'] or DEFAULT_DBINFO['port'], ) ) @@ -85,8 +84,7 @@ def from_profile(cls, profile: 'Profile', **kwargs): ### DB user functions ### def dbuser_exists(self, dbuser): - """ - Find out if postgres user with name dbuser exists + """Find out if postgres user with name dbuser exists :param str dbuser: database user to check for :return: (bool) True if user exists, False otherwise @@ -94,8 +92,7 @@ def dbuser_exists(self, dbuser): return bool(self.execute(_USER_EXISTS_COMMAND.format(dbuser))) def create_dbuser(self, dbuser, dbpass, privileges=''): - """ - Create a database user in postgres + """Create a database user in postgres :param str dbuser: Name of the user to be created. :param str dbpass: Password the user should be given. 
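Pulled out of the settings hunk above, the ``AIIDA_PATH`` resolution reads as follows. The split/expanduser loop mirrors the hunk; the ``.aiida`` folder name and the suffix check are assumptions based on the surrounding comments:

```python
# Standalone sketch of AIIDA_PATH resolution: split on ':', drop empty
# entries, expand '~', and append '.aiida' unless the path already ends in it.
# The suffix step is an assumption; the loop mirrors the hunk above.
import os
import pathlib

DEFAULT_AIIDA_PATH_VARIABLE = 'AIIDA_PATH'

os.environ[DEFAULT_AIIDA_PATH_VARIABLE] = '~/project-a::/tmp/project-b'

candidates = []
for base_dir_path in [path for path in os.environ[DEFAULT_AIIDA_PATH_VARIABLE].split(':') if path]:
    config_folder = pathlib.Path(base_dir_path).expanduser()
    if config_folder.name != '.aiida':
        config_folder = config_folder / '.aiida'
    candidates.append(config_folder)

print(candidates)  # [.../project-a/.aiida, /tmp/project-b/.aiida]
```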
@@ -108,8 +105,7 @@ def create_dbuser(self, dbuser, dbpass, privileges=''): self.execute(_GRANT_ROLE_COMMAND.format(dbuser)) def drop_dbuser(self, dbuser): - """ - Drop a database user in postgres + """Drop a database user in postgres :param str dbuser: Name of the user to be dropped. """ @@ -141,8 +137,9 @@ def can_user_authenticate(self, dbuser, dbpass): :param dbpass: the database password :return: True if the credentials are valid, False otherwise """ - from pgsu import _execute_psyco import psycopg2 + import psycopg2 + from pgsu import _execute_psyco + dsn = self.dsn.copy() dsn['user'] = dbuser dsn['password'] = dbpass @@ -157,8 +154,7 @@ def can_user_authenticate(self, dbuser, dbpass): ### DB functions ### def db_exists(self, dbname): - """ - Check wether a postgres database with dbname exists + """Check whether a postgres database with dbname exists :param str dbname: Name of the database to check for :return: (bool), True if database exists, False otherwise @@ -166,8 +162,7 @@ def db_exists(self, dbname): return bool(self.execute(_CHECK_DB_EXISTS_COMMAND.format(dbname))) def create_db(self, dbuser, dbname): - """ - Create a database in postgres + """Create a database in postgres :param str dbuser: Name of the user which should own the db. :param str dbname: Name of the database. @@ -175,8 +170,7 @@ def create_db(self, dbuser, dbname): self.execute(_CREATE_DB_COMMAND.format(dbname, dbuser)) def drop_db(self, dbname): - """ - Drop a database in postgres + """Drop a database in postgres :param str dbname: Name of the database. """ @@ -257,12 +251,14 @@ def dbinfo(self): def manual_setup_instructions(db_username, db_name): """Create a message with instructions for manually creating a database""" db_pass = '' - instructions = '\n'.join([ - 'Run the following commands as a UNIX user with access to PostgreSQL (Ubuntu: $ sudo su postgres):', - '', - '\t$ psql template1', - f' ==> {_CREATE_USER_COMMAND.format(db_username, db_pass, "")}', - f' ==> {_GRANT_ROLE_COMMAND.format(db_username)}', - f' ==> {_CREATE_DB_COMMAND.format(db_name, db_username)}', - ]) + instructions = '\n'.join( + [ + 'Run the following commands as a UNIX user with access to PostgreSQL (Ubuntu: $ sudo su postgres):', + '', + '\t$ psql template1', + f' ==> {_CREATE_USER_COMMAND.format(db_username, db_pass, "")}', + f' ==> {_GRANT_ROLE_COMMAND.format(db_username)}', + f' ==> {_CREATE_DB_COMMAND.format(db_name, db_username)}', + ] + ) return instructions diff --git a/aiida/manage/external/rmq/__init__.py b/aiida/manage/external/rmq/__init__.py index 1af41bff0e..6191b5c037 100644 --- a/aiida/manage/external/rmq/__init__.py +++ b/aiida/manage/external/rmq/__init__.py @@ -11,8 +11,7 @@ # AUTO-GENERATED -# yapf: disable -# pylint: disable=wildcard-import +# fmt: off from .client import * from .defaults import * @@ -28,4 +27,4 @@ 'get_task_exchange_name', ) -# yapf: enable +# fmt: on diff --git a/aiida/manage/external/rmq/client.py b/aiida/manage/external/rmq/client.py index 476c63111a..24ef5b3193 100644 --- a/aiida/manage/external/rmq/client.py +++ b/aiida/manage/external/rmq/client.py @@ -32,6 +32,7 @@ def __init__(self, username: str, password: str, hostname: str, virtual_host: st :param virtual_host: The virtual host. """ import requests + self._username = username self._password = password self._hostname = hostname @@ -73,6 +74,7 @@ def request( :raises `ManagementApiConnectionError`: If connection to the API cannot be made.
""" import requests + url = self.format_url(url, url_params) try: return requests.request(method, url, auth=self._authentication, params=params or {}, timeout=5) diff --git a/aiida/manage/external/rmq/defaults.py b/aiida/manage/external/rmq/defaults.py index 16058d8a52..953fef8912 100644 --- a/aiida/manage/external/rmq/defaults.py +++ b/aiida/manage/external/rmq/defaults.py @@ -8,12 +8,14 @@ MESSAGE_EXCHANGE = 'messages' TASK_EXCHANGE = 'tasks' -BROKER_DEFAULTS = AttributeDict({ - 'protocol': 'amqp', - 'username': 'guest', - 'password': 'guest', - 'host': '127.0.0.1', - 'port': 5672, - 'virtual_host': '', - 'heartbeat': 600, -}) +BROKER_DEFAULTS = AttributeDict( + { + 'protocol': 'amqp', + 'username': 'guest', + 'password': 'guest', + 'host': '127.0.0.1', + 'port': 5672, + 'virtual_host': '', + 'heartbeat': 600, + } +) diff --git a/aiida/manage/external/rmq/launcher.py b/aiida/manage/external/rmq/launcher.py index aa89b016f1..a97c9e8517 100644 --- a/aiida/manage/external/rmq/launcher.py +++ b/aiida/manage/external/rmq/launcher.py @@ -68,15 +68,14 @@ async def _continue(self, communicator, pid, nowait, tag=None): return False if node.is_terminated: - LOGGER.info('not continuing process<%d> which is already terminated with state %s', pid, node.process_state) future = kiwipy.Future() if node.is_finished: - future.set_result({ - entry.link_label: entry.node for entry in node.base.links.get_outgoing(node_class=Data) - }) + future.set_result( + {entry.link_label: entry.node for entry in node.base.links.get_outgoing(node_class=Data)} + ) elif node.is_excepted: future.set_exception(PastException(node.exception)) elif node.is_killed: @@ -114,10 +113,11 @@ async def _continue(self, communicator, pid, nowait, tag=None): # server based. LOGGER.exception( 'A subscriber with the process id<%d> already exists, which most likely means this worker is already ' - 'working on it and this task was sent as a duplicate by mistake. Deleting the task now.', pid + 'working on it and this task was sent as a duplicate by mistake. Deleting the task now.', + pid, ) return False - except asyncio.CancelledError: # pylint: disable=try-except-raise + except asyncio.CancelledError: # note this is only required in python<=3.7, # where asyncio.CancelledError inherits from Exception raise diff --git a/aiida/manage/manager.py b/aiida/manage/manager.py index 3e33f822c3..1f918f02f0 100644 --- a/aiida/manage/manager.py +++ b/aiida/manage/manager.py @@ -7,7 +7,6 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=cyclic-import """AiiDA manager for global settings""" import functools from typing import TYPE_CHECKING, Any, Optional, Union @@ -32,13 +31,13 @@ def get_manager() -> 'Manager': """Return the AiiDA global manager instance.""" - global MANAGER # pylint: disable=global-statement + global MANAGER # noqa: PLW0603 if MANAGER is None: MANAGER = Manager() return MANAGER -class Manager: # pylint: disable=too-many-public-methods +class Manager: """Manager singleton for globally loaded resources. AiiDA can have the following global resources loaded: @@ -85,6 +84,7 @@ def get_config(create=False) -> 'Config': """ from .configuration import get_config + return get_config(create=create) def get_profile(self) -> Optional['Profile']: @@ -235,6 +235,7 @@ def get_backend(self) -> 'StorageBackend': Deprecated: use `get_profile_storage` instead. 
""" from aiida.common.warnings import warn_deprecation + warn_deprecation('get_backend() is deprecated, use get_profile_storage() instead', version=3) return self.get_profile_storage() @@ -371,6 +372,7 @@ def get_process_controller(self) -> 'RemoteProcessThreadController': """ from plumpy.process_comms import RemoteProcessThreadController + if self._process_controller is None: self._process_controller = RemoteProcessThreadController(self.get_communicator()) @@ -451,7 +453,7 @@ def create_daemon_runner(self, loop: Optional['asyncio.AbstractEventLoop'] = Non loop=runner_loop, persister=self.get_persister(), load_context=LoadSaveContext(runner=runner), - loader=persistence.get_object_loader() + loader=persistence.get_object_loader(), ) assert runner.communicator is not None, 'communicator not set for runner' @@ -507,6 +509,7 @@ def is_rabbitmq_version_supported(communicator: 'RmqThreadCommunicator') -> bool :return: boolean whether the current RabbitMQ version is supported. """ from packaging.version import parse + version = get_rabbitmq_version(communicator) return parse('3.6.0') <= version < parse('3.8.15') @@ -517,4 +520,5 @@ def get_rabbitmq_version(communicator: 'RmqThreadCommunicator'): :return: :class:`packaging.version.Version` """ from packaging.version import parse + return parse(communicator.server_properties['version'].decode('utf-8')) diff --git a/aiida/manage/profile_access.py b/aiida/manage/profile_access.py index 305aae407d..b50c7338ae 100644 --- a/aiida/manage/profile_access.py +++ b/aiida/manage/profile_access.py @@ -10,8 +10,8 @@ """Module for the ProfileAccessManager that tracks process access to the profile.""" import contextlib import os -from pathlib import Path import typing +from pathlib import Path import psutil @@ -62,8 +62,7 @@ def request_access(self) -> None: :raises ~aiida.common.exceptions.LockedProfileError: if the profile is locked. """ error_message = ( - f'process {self.process.pid} cannot access profile `{self.profile.name}` ' - f'because it is being locked.' + f'process {self.process.pid} cannot access profile `{self.profile.name}` ' f'because it is being locked.' ) self._raise_if_locked(error_message) @@ -82,8 +81,7 @@ def request_access(self) -> None: # Check again in case a lock was created in the time between the first check and creating the # access record file. error_message = ( - f'profile `{self.profile.name}` was locked while process ' - f'{self.process.pid} was requesting access.' + f'profile `{self.profile.name}` was locked while process ' f'{self.process.pid} was requesting access.' ) self._raise_if_locked(error_message) @@ -104,16 +102,14 @@ def lock(self): :raises ~aiida.common.exceptions.LockedProfileError: if there currently already is a lock on the profile. """ error_message = ( - f'process {self.process.pid} cannot lock profile `{self.profile.name}` ' - f'because it is already locked.' + f'process {self.process.pid} cannot lock profile `{self.profile.name}` ' f'because it is already locked.' ) self._raise_if_locked(error_message) self._clear_stale_pid_files() error_message = ( - f'process {self.process.pid} cannot lock profile `{self.profile.name}` ' - f'because it is being accessed.' + f'process {self.process.pid} cannot lock profile `{self.profile.name}` ' f'because it is being accessed.' 
) self._raise_if_active(error_message) diff --git a/aiida/manage/tests/__init__.py b/aiida/manage/tests/__init__.py index 8e913f7bef..2a23415c1e 100644 --- a/aiida/manage/tests/__init__.py +++ b/aiida/manage/tests/__init__.py @@ -7,6 +7,5 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -""" -Testing infrastructure for easy testing of AiiDA plugins. +"""Testing infrastructure for easy testing of AiiDA plugins. """ diff --git a/aiida/manage/tests/pytest_fixtures.py b/aiida/manage/tests/pytest_fixtures.py index 8f8a0efabe..8b95e5f2de 100644 --- a/aiida/manage/tests/pytest_fixtures.py +++ b/aiida/manage/tests/pytest_fixtures.py @@ -7,7 +7,6 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=redefined-outer-name,unused-argument """Collection of ``pytest`` fixtures that are intended for use in plugin packages. To use these fixtures, simply create a ``conftest.py`` in the tests folder and add the following line: @@ -32,10 +31,10 @@ import uuid import warnings -from importlib_metadata import EntryPoint, EntryPoints import plumpy import pytest import wrapt +from importlib_metadata import EntryPoint, EntryPoints from aiida import plugins from aiida.common.exceptions import NotExistent @@ -56,7 +55,7 @@ def recursive_merge(left: dict[t.Any, t.Any], right: dict[t.Any, t.Any]) -> None :param right: Dictionary to recursively merge on top of ``left`` dictionary. """ for key, value in right.items(): - if (key in left and isinstance(left[key], dict) and isinstance(value, dict)): + if key in left and isinstance(left[key], dict) and isinstance(value, dict): recursive_merge(left[key], value) else: left[key] = value @@ -73,9 +72,7 @@ def aiida_caplog(caplog): @pytest.fixture(scope='session') def postgres_cluster( - database_name: str | None = None, - database_username: str | None = None, - database_password: str | None = None + database_name: str | None = None, database_username: str | None = None, database_password: str | None = None ) -> t.Generator[dict[str, str], None, None]: """Create a temporary and isolated PostgreSQL cluster using ``pgtest`` and clean up after the yield.
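The ``recursive_merge`` helper touched in this file is small enough to demonstrate inline: nested dictionaries merge key by key, while non-dict values from ``right`` overwrite those in ``left``:

```python
# Demo of the recursive_merge semantics from this file: dicts merge deeply,
# everything else is overwritten in place on the left-hand dictionary.
from typing import Any


def recursive_merge(left: dict, right: dict) -> None:
    """Recursively merge ``right`` into ``left``, mutating ``left``."""
    for key, value in right.items():
        if key in left and isinstance(left[key], dict) and isinstance(value, dict):
            recursive_merge(left[key], value)
        else:
            left[key] = value


configuration: dict[str, Any] = {'storage': {'backend': 'core.psql_dos', 'config': {'database_port': 5432}}}
custom: dict[str, Any] = {'storage': {'config': {'database_port': 5433}}, 'test_profile': True}
recursive_merge(configuration, custom)
print(configuration)
# {'storage': {'backend': 'core.psql_dos', 'config': {'database_port': 5433}}, 'test_profile': True}
```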
@@ -204,7 +201,7 @@ def factory(custom_configuration: dict[str, t.Any] | None = None) -> dict[str, t 'config': { **postgres_cluster, 'repository_uri': f'file://{tmp_path_factory.mktemp("repository")}', - } + }, } } recursive_merge(configuration, custom_configuration or {}) @@ -227,7 +224,7 @@ def clear_profile(): daemon_client.stop_daemon(wait=True) manager = get_manager() - manager.get_profile_storage()._clear() # pylint: disable=protected-access + manager.get_profile_storage()._clear() manager.reset_communicator() manager.reset_runner() @@ -265,12 +262,12 @@ def factory(custom_configuration: dict[str, t.Any]) -> Profile: 'broker_host': '127.0.0.1', 'broker_port': 5672, 'broker_virtual_host': '', - } + }, }, 'options': { 'warnings.development_version': False, 'warnings.rabbitmq_version': False, - } + }, } recursive_merge(configuration, custom_configuration or {}) configuration['test_profile'] = True @@ -452,10 +449,7 @@ def get_code(entry_point, executable, computer=aiida_localhost, label=None, **kw builder = QueryBuilder().append(Computer, filters={'uuid': computer.uuid}, tag='computer') builder.append( - InstalledCode, filters={ - 'label': label, - 'attributes.input_plugin': entry_point - }, with_computer='computer' + InstalledCode, filters={'label': label, 'attributes.input_plugin': entry_point}, with_computer='computer' ) try: @@ -475,7 +469,7 @@ def get_code(entry_point, executable, computer=aiida_localhost, label=None, **kw default_calc_job_plugin=entry_point, computer=computer, filepath_executable=executable_path, - **kwargs + **kwargs, ) return code.store() @@ -530,11 +524,11 @@ def aiida_computer(tmp_path) -> t.Callable[[], Computer]: """Factory to return a :class:`aiida.orm.computers.Computer` instance.""" def factory( - label: str = None, + label: t.Optional[str] = None, minimum_job_poll_interval: int = 0, default_mpiprocs_per_machine: int = 1, - configuration_kwargs: dict[t.Any, t.Any] = None, - **kwargs + configuration_kwargs: t.Optional[dict[t.Any, t.Any]] = None, + **kwargs, ) -> Computer: """Return a :class:`aiida.orm.computers.Computer` instance. @@ -579,7 +573,7 @@ def factory( def aiida_computer_local(aiida_computer) -> t.Callable[[], Computer]: """Factory to return a :class:`aiida.orm.computers.Computer` instance with ``core.local`` transport.""" - def factory(label: str = None, configure: bool = True) -> Computer: + def factory(label: t.Optional[str] = None, configure: bool = True) -> Computer: """Return a :class:`aiida.orm.computers.Computer` instance representing localhost with ``core.local`` transport. The database is queried for an existing computer with the given label. If it exists, it is returned, otherwise a @@ -606,7 +600,7 @@ def factory(label: str = None, configure: bool = True) -> Computer: def aiida_computer_ssh(aiida_computer, ssh_key) -> t.Callable[[], Computer]: """Factory to return a :class:`aiida.orm.computers.Computer` instance with ``core.ssh`` transport.""" - def factory(label: str = None, configure: bool = True) -> Computer: + def factory(label: t.Optional[str] = None, configure: bool = True) -> Computer: """Return a :class:`aiida.orm.computers.Computer` instance representing localhost with ``core.ssh`` transport. The database is queried for an existing computer with the given label. If it exists, it is returned, otherwise a @@ -667,7 +661,7 @@ def daemon_client(aiida_profile): # Give an additional grace period by manually waiting for the daemon to be stopped. 
In certain unit test # scenarios, the built in wait time in ``daemon_client.stop_daemon`` is not sufficient and even though the # daemon is stopped, ``daemon_client.is_daemon_running`` will return false for a little bit longer. - daemon_client._await_condition( # pylint: disable=protected-access + daemon_client._await_condition( lambda: not daemon_client.is_daemon_running, DaemonTimeoutException('The daemon failed to stop.'), ) @@ -691,7 +685,7 @@ def stopped_daemon_client(daemon_client): # Give an additional grace period by manually waiting for the daemon to be stopped. In certain unit test # scenarios, the built in wait time in ``daemon_client.stop_daemon`` is not sufficient and even though the # daemon is stopped, ``daemon_client.is_daemon_running`` will return false for a little bit longer. - daemon_client._await_condition( # pylint: disable=protected-access + daemon_client._await_condition( lambda: not daemon_client.is_daemon_running, DaemonTimeoutException('The daemon failed to stop.'), ) @@ -707,7 +701,7 @@ def _factory( submittable: Process | ProcessBuilder | ProcessNode, state: plumpy.ProcessState = plumpy.ProcessState.FINISHED, timeout: int = 20, - **kwargs + **kwargs, ): """Submit a process and wait for it to achieve the given state. @@ -730,7 +724,6 @@ def _factory( start_time = time.time() while node.process_state is not state: - if node.is_excepted: raise RuntimeError(f'The process excepted: {node.exception}') @@ -804,7 +797,7 @@ def add( entry_point_string: str | None = None, *, name: str | None = None, - group: str | None = None + group: str | None = None, ) -> None: """Add an entry point. diff --git a/aiida/orm/__init__.py b/aiida/orm/__init__.py index e9a26461ed..15098eaf73 100644 --- a/aiida/orm/__init__.py +++ b/aiida/orm/__init__.py @@ -11,8 +11,7 @@ # AUTO-GENERATED -# yapf: disable -# pylint: disable=wildcard-import +# fmt: off from .authinfos import * from .comments import * @@ -117,4 +116,4 @@ 'validate_link', ) -# yapf: enable +# fmt: on diff --git a/aiida/orm/authinfos.py b/aiida/orm/authinfos.py index e8131ff652..bf61fc1cca 100644 --- a/aiida/orm/authinfos.py +++ b/aiida/orm/authinfos.py @@ -11,7 +11,6 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Type from aiida.common import exceptions -from aiida.common.lang import classproperty from aiida.manage import get_manager from aiida.plugins import TransportFactory @@ -19,7 +18,8 @@ if TYPE_CHECKING: from aiida.orm import Computer, User - from aiida.orm.implementation import BackendAuthInfo, StorageBackend + from aiida.orm.implementation import StorageBackend + from aiida.orm.implementation.authinfos import BackendAuthInfo # noqa: F401 from aiida.transports import Transport __all__ = ('AuthInfo',) @@ -83,7 +83,8 @@ def enabled(self, enabled: bool) -> None: @property def computer(self) -> 'Computer': """Return the computer associated with this instance.""" - from . import computers # pylint: disable=cyclic-import + from . import computers + return entities.from_backend_entity(computers.Computer, self._backend_entity.computer) @property diff --git a/aiida/orm/autogroup.py b/aiida/orm/autogroup.py index fa5b2d8838..c782cd0cf2 100644 --- a/aiida/orm/autogroup.py +++ b/aiida/orm/autogroup.py @@ -67,18 +67,21 @@ def disable(self) -> None: def get_exclude(self) -> list[str] | None: """Return the list of classes to exclude from autogrouping. - Returns ``None`` if no exclusion list has been set.""" + Returns ``None`` if no exclusion list has been set. 
+ """ return self._exclude def get_include(self) -> list[str] | None: """Return the list of classes to include in the autogrouping. - Returns ``None`` if no inclusion list has been set.""" + Returns ``None`` if no inclusion list has been set. + """ return self._include def get_group_label_prefix(self) -> str: """Get the prefix of the label of the group. - If no group label prefix was set, it will set a default one by itself.""" + If no group label prefix was set, it will set a default one by itself. + """ return self._group_label_prefix @staticmethod @@ -215,19 +218,14 @@ def get_or_create_group(self) -> AutoGroup: queryb = QueryBuilder(self._backend).append( AutoGroup, filters={ - 'or': [{ - 'label': { - '==': label_prefix - } - }, { - 'label': { - 'like': f"{escape_for_sql_like(f'{label_prefix}_')}%" - } - }] + 'or': [ + {'label': {'==': label_prefix}}, + {'label': {'like': f"{escape_for_sql_like(f'{label_prefix}_')}%"}}, + ] }, - project='label' + project='label', ) - existing_group_labels = [res[0][len(label_prefix):] for res in queryb.all()] + existing_group_labels = [res[0][len(label_prefix) :] for res in queryb.all()] existing_group_ints = [] for label in existing_group_labels: if label == '': diff --git a/aiida/orm/comments.py b/aiida/orm/comments.py index 6ab090ada0..80ea473be8 100644 --- a/aiida/orm/comments.py +++ b/aiida/orm/comments.py @@ -11,14 +11,13 @@ from datetime import datetime from typing import TYPE_CHECKING, List, Optional, Type -from aiida.common.lang import classproperty from aiida.manage import get_manager from . import entities, users if TYPE_CHECKING: from aiida.orm import Node, User - from aiida.orm.implementation import BackendComment, StorageBackend + from aiida.orm.implementation import StorageBackend __all__ = ('Comment',) @@ -31,8 +30,7 @@ def _entity_base_cls() -> Type['Comment']: return Comment def delete(self, pk: int) -> None: - """ - Remove a Comment from the collection with the given id + """Remove a Comment from the collection with the given id :param pk: the id of the comment to delete @@ -42,16 +40,14 @@ def delete(self, pk: int) -> None: self._backend.comments.delete(pk) def delete_all(self) -> None: - """ - Delete all Comments from the Collection + """Delete all Comments from the Collection :raises `~aiida.common.exceptions.IntegrityError`: if all Comments could not be deleted """ self._backend.comments.delete_all() def delete_many(self, filters: dict) -> List[int]: - """ - Delete Comments from the Collection based on ``filters`` + """Delete Comments from the Collection based on ``filters`` :param filters: similar to QueryBuilder filter diff --git a/aiida/orm/computers.py b/aiida/orm/computers.py index 30f5e3c46f..df7dd6d245 100644 --- a/aiida/orm/computers.py +++ b/aiida/orm/computers.py @@ -13,7 +13,6 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union from aiida.common import exceptions -from aiida.common.lang import classproperty from aiida.manage import get_manager from aiida.plugins import SchedulerFactory, TransportFactory @@ -21,7 +20,7 @@ if TYPE_CHECKING: from aiida.orm import AuthInfo, User - from aiida.orm.implementation import BackendComputer, StorageBackend + from aiida.orm.implementation import StorageBackend from aiida.schedulers import Scheduler from aiida.transports import Transport @@ -36,8 +35,7 @@ def _entity_base_cls() -> Type['Computer']: return Computer def get_or_create(self, label: Optional[str] = None, **kwargs) -> Tuple[bool, 'Computer']: - """ - Try to retrieve a Computer from the DB 
with the given arguments; + """Try to retrieve a Computer from the DB with the given arguments; create (and store) a new Computer if such a Computer was not present yet. :param label: computer label @@ -63,21 +61,18 @@ def delete(self, pk: int) -> None: class Computer(entities.Entity['BackendComputer', ComputerCollection]): - """ - Computer entity. - """ - # pylint: disable=too-many-public-methods + """Computer entity.""" _logger = logging.getLogger(__name__) - PROPERTY_MINIMUM_SCHEDULER_POLL_INTERVAL = 'minimum_scheduler_poll_interval' # pylint: disable=invalid-name - PROPERTY_MINIMUM_SCHEDULER_POLL_INTERVAL__DEFAULT = 10. # pylint: disable=invalid-name + PROPERTY_MINIMUM_SCHEDULER_POLL_INTERVAL = 'minimum_scheduler_poll_interval' + PROPERTY_MINIMUM_SCHEDULER_POLL_INTERVAL__DEFAULT = 10.0 PROPERTY_WORKDIR = 'workdir' PROPERTY_SHEBANG = 'shebang' _CLS_COLLECTION = ComputerCollection - def __init__( # pylint: disable=too-many-arguments + def __init__( self, label: Optional[str] = None, hostname: str = '', @@ -94,14 +89,14 @@ def __init__( # pylint: disable=too-many-arguments hostname=hostname, description=description, transport_type=transport_type, - scheduler_type=scheduler_type + scheduler_type=scheduler_type, ) super().__init__(model) if workdir is not None: self.set_workdir(workdir) def __repr__(self): - return f'<{self.__class__.__name__}: {str(self)}>' + return f'<{self.__class__.__name__}: {self!s}>' def __str__(self): return f'{self.label} ({self.hostname}), pk: {self.pk}' @@ -122,64 +117,50 @@ def logger(self) -> logging.Logger: @classmethod def _label_validator(cls, label: str) -> None: - """ - Validates the label. - """ + """Validates the label.""" if not label.strip(): raise exceptions.ValidationError('No label specified') @classmethod def _hostname_validator(cls, hostname: str) -> None: - """ - Validates the hostname. - """ + """Validates the hostname.""" if not (hostname or hostname.strip()): raise exceptions.ValidationError('No hostname specified') @classmethod def _description_validator(cls, description: str) -> None: - """ - Validates the description. - """ + """Validates the description.""" # The description is always valid @classmethod def _transport_type_validator(cls, transport_type: str) -> None: - """ - Validates the transport string. - """ + """Validates the transport string.""" from aiida.plugins.entry_point import get_entry_point_names + if transport_type not in get_entry_point_names('aiida.transports'): raise exceptions.ValidationError('The specified transport is not a valid one') @classmethod def _scheduler_type_validator(cls, scheduler_type: str) -> None: - """ - Validates the transport string. - """ + """Validates the transport string.""" from aiida.plugins.entry_point import get_entry_point_names + if scheduler_type not in get_entry_point_names('aiida.schedulers'): raise exceptions.ValidationError(f'The specified scheduler `{scheduler_type}` is not a valid one') @classmethod def _prepend_text_validator(cls, prepend_text: str) -> None: - """ - Validates the prepend text string. - """ + """Validates the prepend text string.""" # no validation done @classmethod def _append_text_validator(cls, append_text: str) -> None: - """ - Validates the append text string. - """ + """Validates the append text string.""" # no validation done @classmethod def _workdir_validator(cls, workdir: str) -> None: - """ - Validates the transport string. 
- """ + """Validates the transport string.""" if not workdir.strip(): raise exceptions.ValidationError('No workdir specified') @@ -194,8 +175,7 @@ def _workdir_validator(cls, workdir: str) -> None: raise exceptions.ValidationError('The workdir must be an absolute path') def _mpirun_command_validator(self, mpirun_cmd: Union[List[str], Tuple[str, ...]]) -> None: - """ - Validates the mpirun_command variable. MUST be called after properly + """Validates the mpirun_command variable. MUST be called after properly checking for a valid scheduler. """ if not isinstance(mpirun_cmd, (tuple, list)) or not all(isinstance(i, str) for i in mpirun_cmd): @@ -218,8 +198,7 @@ def _mpirun_command_validator(self, mpirun_cmd: Union[List[str], Tuple[str, ...] raise exceptions.ValidationError(f"Error in the string: '{exc}'") def validate(self) -> None: - """ - Check if the attributes and files retrieved from the DB are valid. + """Check if the attributes and files retrieved from the DB are valid. Raise a ValidationError if something is wrong. Must be able to work even before storing: therefore, use the get_attr and similar methods @@ -249,9 +228,7 @@ def validate(self) -> None: @classmethod def _default_mpiprocs_per_machine_validator(cls, def_cpus_per_machine: Optional[int]) -> None: - """ - Validates the default number of CPUs per machine (node) - """ + """Validates the default number of CPUs per machine (node)""" if def_cpus_per_machine is None: return @@ -273,14 +250,11 @@ def default_memory_per_machine_validator(cls, def_memory_per_machine: Optional[i ) def copy(self) -> 'Computer': - """ - Return a copy of the current object to work with, not stored yet. - """ + """Return a copy of the current object to work with, not stored yet.""" return entities.from_backend_entity(Computer, self._backend_entity.copy()) def store(self) -> 'Computer': - """ - Store the computer in the DB. + """Store the computer in the DB. Differently from Nodes, a computer can be re-stored if its properties are to be changed (e.g. a new mpirun command, etc.) @@ -385,8 +359,7 @@ def metadata(self, value: Dict[str, Any]) -> None: self._backend_entity.set_metadata(value) def delete_property(self, name: str, raise_exception: bool = True) -> None: - """ - Delete a property from this computer + """Delete a property from this computer :param name: the name of the property :param raise_exception: if True raise if the property does not exist, otherwise return None @@ -452,12 +425,12 @@ def set_use_double_quotes(self, val: bool) -> None: :param use_double_quotes: True if to escape with double quotes, False otherwise. """ from aiida.common.lang import type_check + type_check(val, bool) self.set_property('use_double_quotes', val) def get_mpirun_command(self) -> List[str]: - """ - Return the mpirun command. Must be a list of strings, that will be + """Return the mpirun command. Must be a list of strings, that will be then joined with spaces when submitting. I also provide a sensible default that may be ok in many cases. @@ -465,8 +438,7 @@ def get_mpirun_command(self) -> List[str]: return self.get_property('mpirun_command', ['mpirun', '-np', '{tot_num_mpiprocs}']) def set_mpirun_command(self, val: Union[List[str], Tuple[str, ...]]) -> None: - """ - Set the mpirun command. It must be a list of strings (you can use + """Set the mpirun command. It must be a list of strings (you can use string.split() if you have a single, space-separated string). 
""" if not isinstance(val, (tuple, list)) or not all(isinstance(i, str) for i in val): @@ -474,15 +446,13 @@ def set_mpirun_command(self, val: Union[List[str], Tuple[str, ...]]) -> None: self.set_property('mpirun_command', val) def get_default_mpiprocs_per_machine(self) -> Optional[int]: - """ - Return the default number of CPUs per machine (node) for this computer, + """Return the default number of CPUs per machine (node) for this computer, or None if it was not set. """ return self.get_property('default_mpiprocs_per_machine', None) def set_default_mpiprocs_per_machine(self, def_cpus_per_machine: Optional[int]) -> None: - """ - Set the default number of CPUs per machine (node) for this computer. + """Set the default number of CPUs per machine (node) for this computer. Accepts None if you do not want to set this value. """ if def_cpus_per_machine is None: @@ -492,15 +462,13 @@ def set_default_mpiprocs_per_machine(self, def_cpus_per_machine: Optional[int]) self.set_property('default_mpiprocs_per_machine', def_cpus_per_machine) def get_default_memory_per_machine(self) -> Optional[int]: - """ - Return the default amount of memory (kB) per machine (node) for this computer, + """Return the default amount of memory (kB) per machine (node) for this computer, or None if it was not set. """ return self.get_property('default_memory_per_machine', None) def set_default_memory_per_machine(self, def_memory_per_machine: Optional[int]) -> None: - """ - Set the default amount of memory (kB) per machine (node) for this computer. + """Set the default amount of memory (kB) per machine (node) for this computer. Accepts None if you do not want to set this value. """ self.default_memory_per_machine_validator(def_memory_per_machine) @@ -524,8 +492,7 @@ def get_minimum_job_poll_interval(self) -> float: return self.get_property(self.PROPERTY_MINIMUM_SCHEDULER_POLL_INTERVAL, default) def set_minimum_job_poll_interval(self, interval: float) -> None: - """ - Set the minimum interval between subsequent requests to update the list + """Set the minimum interval between subsequent requests to update the list of jobs currently running on this computer. :param interval: The minimum interval in seconds @@ -533,8 +500,7 @@ def set_minimum_job_poll_interval(self, interval: float) -> None: self.set_property(self.PROPERTY_MINIMUM_SCHEDULER_POLL_INTERVAL, interval) def get_workdir(self) -> str: - """ - Get the working directory for this computer + """Get the working directory for this computer :return: The currently configured working directory """ return self.get_property(self.PROPERTY_WORKDIR, '/scratch/{username}/aiida_run/') @@ -546,9 +512,7 @@ def get_shebang(self) -> str: return self.get_property(self.PROPERTY_SHEBANG, '#!/bin/bash') def set_shebang(self, val: str) -> None: - """ - :param str val: A valid shebang line - """ + """:param str val: A valid shebang line""" if not isinstance(val, str): raise ValueError(f'{val} is invalid. Input has to be a string') if not val.startswith('#!'): @@ -558,8 +522,7 @@ def set_shebang(self, val: str) -> None: self.metadata = metadata def get_authinfo(self, user: 'User') -> 'AuthInfo': - """ - Return the aiida.orm.authinfo.AuthInfo instance for the + """Return the aiida.orm.authinfo.AuthInfo instance for the given user on this computer, if the computer is configured for the given user. 
@@ -589,8 +552,7 @@ def is_configured(self) -> bool: return self.is_user_configured(users.User.get_collection(self.backend).get_default()) def is_user_configured(self, user: 'User') -> bool: - """ - Is the user configured on this computer? + """Is the user configured on this computer? :param user: the user to check :return: True if configured, False otherwise @@ -602,8 +564,7 @@ def is_user_configured(self, user: 'User') -> bool: return False def is_user_enabled(self, user: 'User') -> bool: - """ - Is the given user enabled to run on this computer? + """Is the given user enabled to run on this computer? :param user: the user to check :return: True if enabled, False otherwise @@ -616,8 +577,7 @@ def is_user_enabled(self, user: 'User') -> bool: return False def get_transport(self, user: Optional['User'] = None) -> 'Transport': - """ - Return a Transport class, configured with all correct parameters. + """Return a Transport class, configured with all correct parameters. The Transport is closed (meaning that if you want to run any operation with it, you have to open it first (i.e., e.g. for a SSH transport, you have to open a connection). To do this you can call ``transports.open()``, or simply @@ -634,7 +594,7 @@ def get_transport(self, user: Optional['User'] = None) -> 'Transport': parameters to the supercomputer, as configured with ``verdi computer configure`` for the user specified as a parameter ``user``. """ - from . import authinfos # pylint: disable=cyclic-import + from . import authinfos user = user or users.User.get_collection(self.backend).get_default() authinfo = authinfos.AuthInfo.get_collection(self.backend).get(dbcomputer=self, aiidauser=user) diff --git a/aiida/orm/convert.py b/aiida/orm/convert.py index e959fe3a08..1f7eb71272 100644 --- a/aiida/orm/convert.py +++ b/aiida/orm/convert.py @@ -7,7 +7,6 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=cyclic-import """Module for converting backend entities into frontend, ORM, entities""" from collections.abc import Iterator, Mapping, Sized from functools import singledispatch @@ -72,6 +71,7 @@ def _(backend_entity): @get_orm_entity.register(BackendGroup) def _(backend_entity): from .groups import load_group_class + group_class = load_group_class(backend_entity.type_string) return from_backend_entity(group_class, backend_entity) @@ -79,43 +79,48 @@ def _(backend_entity): @get_orm_entity.register(BackendComputer) def _(backend_entity): from . import computers + return from_backend_entity(computers.Computer, backend_entity) @get_orm_entity.register(BackendUser) def _(backend_entity): from . import users + return from_backend_entity(users.User, backend_entity) @get_orm_entity.register(BackendAuthInfo) def _(backend_entity): from . import authinfos + return from_backend_entity(authinfos.AuthInfo, backend_entity) @get_orm_entity.register(BackendLog) def _(backend_entity): from . import logs + return from_backend_entity(logs.Log, backend_entity) @get_orm_entity.register(BackendComment) def _(backend_entity): from . 
import comments + return from_backend_entity(comments.Comment, backend_entity) @get_orm_entity.register(BackendNode) def _(backend_entity): - from .utils.node import load_node_class # pylint: disable=import-error,no-name-in-module + from .utils.node import load_node_class + node_class = load_node_class(backend_entity.node_type) return from_backend_entity(node_class, backend_entity) class ConvertIterator(Iterator, Sized): - """ - Iterator that converts backend entities into frontend ORM entities as needed + """Iterator that converts backend entities into frontend ORM entities as needed See :func:`aiida.orm.Group.nodes` for an example. """ diff --git a/aiida/orm/entities.py b/aiida/orm/entities.py index fdf0639db2..6729b74aa8 100644 --- a/aiida/orm/entities.py +++ b/aiida/orm/entities.py @@ -35,6 +35,7 @@ class EntityTypes(Enum): """Enum for referring to ORM entities in a backend-agnostic manner.""" + AUTHINFO = 'authinfo' COMMENT = 'comment' COMPUTER = 'computer' @@ -62,16 +63,18 @@ def get_cached(cls, entity_class: Type[EntityType], backend: 'StorageBackend'): :param backend: the backend instance to get the collection for """ from aiida.orm.implementation import StorageBackend + type_check(backend, StorageBackend) return cls(entity_class, backend=backend) def __init__(self, entity_class: Type[EntityType], backend: Optional['StorageBackend'] = None) -> None: - """ Construct a new entity collection. + """Construct a new entity collection. :param entity_class: the entity type e.g. User, Computer, etc :param backend: the backend instance to get the collection for, or use the default """ from aiida.orm.implementation import StorageBackend + type_check(backend, StorageBackend, allow_none=True) assert issubclass(entity_class, self._entity_base_cls()) self._backend = backend or get_manager().get_profile_storage() @@ -98,7 +101,7 @@ def query( filters: Optional['FilterType'] = None, order_by: Optional['OrderByType'] = None, limit: Optional[int] = None, - offset: Optional[int] = None + offset: Optional[int] = None, ) -> 'QueryBuilder': """Get a query builder for the objects of this collection. @@ -131,7 +134,7 @@ def find( self, filters: Optional['FilterType'] = None, order_by: Optional['OrderByType'] = None, - limit: Optional[int] = None + limit: Optional[int] = None, ) -> List[EntityType]: """Find collection entries matching the filter criteria. @@ -149,7 +152,7 @@ def all(self) -> List[EntityType]: :return: A list of all entities """ - return cast(List[EntityType], self.query().all(flat=True)) # pylint: disable=no-member + return cast(List[EntityType], self.query().all(flat=True)) def count(self, filters: Optional['FilterType'] = None) -> int: """Count entities in this collection according to criteria. @@ -167,7 +170,7 @@ class Entity(abc.ABC, Generic[BackendEntityType, CollectionType]): _CLS_COLLECTION: Type[CollectionType] = Collection # type: ignore[assignment] @classproperty - def objects(cls: EntityType) -> CollectionType: # pylint: disable=no-self-argument + def objects(cls: EntityType) -> CollectionType: # noqa: N805 """Get a collection for objects of this type, with the default backend. .. deprecated:: This will be removed in v3, use ``collection`` instead. 
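
Since the `objects` classproperty shown above is deprecated in favour of `collection`, a short sketch of the replacement pattern; the filter values are illustrative, and the `order_by` syntax follows the `[{'id': 'asc'}]` form used elsewhere in this diff:

```python
from aiida import load_profile, orm

load_profile()

# Deprecated access, removed in v3: orm.Computer.objects.all()
all_computers = orm.Computer.collection.all()

# `count` and `find` take QueryBuilder-style filters, as per the Collection methods above.
n_localhost = orm.Computer.collection.count(filters={'label': 'localhost'})
first_five = orm.Computer.collection.find(order_by=[{'id': 'asc'}], limit=5)
```
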
@@ -178,7 +181,7 @@ def objects(cls: EntityType) -> CollectionType: # pylint: disable=no-self-argum return cls.collection @classproperty - def collection(cls) -> CollectionType: # pylint: disable=no-self-argument + def collection(cls) -> CollectionType: # noqa: N805 """Get a collection for objects of this type, with the default backend. :return: an object that can be used to access entities of this type @@ -207,14 +210,12 @@ def get(cls, **kwargs): warn_deprecation( f'`{cls.__name__}.get` method is deprecated, use `{cls.__name__}.collection.get` instead.', version=3, - stacklevel=2 + stacklevel=2, ) - return cls.collection.get(**kwargs) # pylint: disable=no-member + return cls.collection.get(**kwargs) def __init__(self, backend_entity: BackendEntityType) -> None: - """ - :param backend_entity: the backend model supporting this entity - """ + """:param backend_entity: the backend model supporting this entity""" self._backend_entity = backend_entity call_with_super_check(self.initialize) @@ -230,7 +231,7 @@ def initialize(self) -> None: """ @property - def id(self) -> int | None: # pylint: disable=invalid-name + def id(self) -> int | None: """Return the id for this entity. This identifier is guaranteed to be unique amongst entities of the same type for a single backend instance. @@ -284,6 +285,6 @@ def from_backend_entity(cls: Type[EntityType], backend_entity: BackendEntityType type_check(backend_entity, BackendEntity) entity = cls.__new__(cls) - entity._backend_entity = backend_entity # pylint: disable=protected-access + entity._backend_entity = backend_entity call_with_super_check(entity.initialize) return entity diff --git a/aiida/orm/extras.py b/aiida/orm/extras.py index 6b5a1ca37e..717b76bf00 100644 --- a/aiida/orm/extras.py +++ b/aiida/orm/extras.py @@ -7,7 +7,6 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=too-many-lines,too-many-arguments """Interface to the extras of a node instance.""" import copy from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Tuple, Union diff --git a/aiida/orm/groups.py b/aiida/orm/groups.py index 7f7c4daaff..3382600355 100644 --- a/aiida/orm/groups.py +++ b/aiida/orm/groups.py @@ -8,9 +8,9 @@ # For further information please visit http://www.aiida.net # ########################################################################### """AiiDA Group entites""" +import warnings from functools import cached_property from typing import TYPE_CHECKING, Any, ClassVar, Optional, Sequence, Tuple, Type, TypeVar, Union, cast -import warnings from aiida.common import exceptions from aiida.common.lang import classproperty, type_check @@ -23,7 +23,8 @@ from importlib_metadata import EntryPoint from aiida.orm import Node, User - from aiida.orm.implementation import BackendGroup, StorageBackend + from aiida.orm.implementation import StorageBackend + from aiida.orm.implementation.groups import BackendGroup # noqa: F401 __all__ = ('Group', 'AutoGroup', 'ImportGroup', 'UpfFamily') @@ -45,7 +46,7 @@ def load_group_class(type_string: str) -> Type['Group']: group_class = load_entry_point('aiida.groups', type_string) except EntryPointError: message = f'could not load entry point `{type_string}`, falling back onto `Group` base class.' 
- warnings.warn(message) # pylint: disable=no-member + warnings.warn(message) group_class = Group return group_class @@ -59,8 +60,7 @@ def _entity_base_cls() -> Type['Group']: return Group def get_or_create(self, label: Optional[str] = None, **kwargs) -> Tuple['Group', bool]: - """ - Try to retrieve a group from the DB with the given arguments; + """Try to retrieve a group from the DB with the given arguments; create (and store) a new group if such a group was not present yet. :param label: group label @@ -82,8 +82,7 @@ def get_or_create(self, label: Optional[str] = None, **kwargs) -> Tuple['Group', return res[0], False def delete(self, pk: int) -> None: - """ - Delete a group + """Delete a group :param pk: the id of the group to delete """ @@ -116,10 +115,9 @@ def __init__( user: Optional['User'] = None, description: str = '', type_string: Optional[str] = None, - backend: Optional['StorageBackend'] = None + backend: Optional['StorageBackend'] = None, ): - """ - Create a new group. Either pass a dbgroup parameter, to reload + """Create a new group. Either pass a dbgroup parameter, to reload a group from the DB (and then, no further parameters are allowed), or pass the parameters for the Group creation. @@ -143,7 +141,7 @@ def __init__( super().__init__(model) @classproperty - def _type_string(cls) -> Optional[str]: + def _type_string(cls) -> Optional[str]: # noqa: N805 from aiida.plugins.entry_point import get_entry_point_from_class if hasattr(cls, '__type_string'): @@ -153,12 +151,12 @@ def _type_string(cls) -> Optional[str]: entry_point_group, entry_point = get_entry_point_from_class(mod, name) if entry_point_group is None or entry_point_group != 'aiida.groups': - cls.__type_string = None # type: ignore[misc] # pylint: disable=protected-access + cls.__type_string = None # type: ignore[misc] message = f'no registered entry point for `{mod}:{name}` so its instances will not be storable.' - warnings.warn(message) # pylint: disable=no-member + warnings.warn(message) else: assert entry_point is not None - cls.__type_string = entry_point.name # type: ignore[misc] # pylint: disable=protected-access + cls.__type_string = entry_point.name # type: ignore[misc] return cls.__type_string @cached_property @@ -183,7 +181,7 @@ def store(self: SelfType) -> SelfType: return super().store() @classproperty - def entry_point(cls) -> Optional['EntryPoint']: + def entry_point(cls) -> Optional['EntryPoint']: # noqa: N805 """Return the entry point associated this group type. :return: the associated entry point or ``None`` if it isn't known. @@ -204,15 +202,12 @@ def uuid(self) -> str: @property def label(self) -> str: - """ - :return: the label of the group as a string - """ + """:return: the label of the group as a string""" return self._backend_entity.label @label.setter def label(self, label: str) -> None: - """ - Attempt to change the label of the group instance. If the group is already stored + """Attempt to change the label of the group instance. 
If the group is already stored and the another group of the same type already exists with the desired label, a UniquenessError will be raised @@ -225,30 +220,22 @@ def label(self, label: str) -> None: @property def description(self) -> str: - """ - :return: the description of the group as a string - """ + """:return: the description of the group as a string""" return self._backend_entity.description or '' @description.setter def description(self, description: str) -> None: - """ - :param description: the description of the group as a string - """ + """:param description: the description of the group as a string""" self._backend_entity.description = description @property def type_string(self) -> str: - """ - :return: the string defining the type of the group - """ + """:return: the string defining the type of the group""" return self._backend_entity.type_string @property def user(self) -> 'User': - """ - :return: the user associated with this group - """ + """:return: the user associated with this group""" return entities.from_backend_entity(users.User, self._backend_entity.user) @user.setter @@ -269,8 +256,7 @@ def count(self) -> int: @property def nodes(self) -> convert.ConvertIterator: - """ - Return a generator/iterator that iterates over all nodes and returns + """Return a generator/iterator that iterates over all nodes and returns the respective AiiDA subclasses of Node, and also allows to ask for the number of nodes in the group using len(). """ @@ -335,9 +321,7 @@ def remove_nodes(self, nodes: Union['Node', Sequence['Node']]) -> None: self._backend_entity.remove_nodes([node.backend_entity for node in nodes]) def is_user_defined(self) -> bool: - """ - :return: True if the group is user defined, False otherwise - """ + """:return: True if the group is user defined, False otherwise""" return not self.type_string _deprecated_extra_methods = { @@ -355,8 +339,7 @@ def is_user_defined(self) -> bool: } def __getattr__(self, name: str) -> Any: - """ - This method is called when an extras is not found in the instance. + """This method is called when an extras is not found in the instance. It allows for the handling of deprecated mixin methods. 
""" diff --git a/aiida/orm/implementation/__init__.py b/aiida/orm/implementation/__init__.py index 0f02fcbf65..4c906a5408 100644 --- a/aiida/orm/implementation/__init__.py +++ b/aiida/orm/implementation/__init__.py @@ -11,8 +11,7 @@ # AUTO-GENERATED -# yapf: disable -# pylint: disable=wildcard-import +# fmt: off from .authinfos import * from .comments import * @@ -51,4 +50,4 @@ 'validate_attribute_extra_key', ) -# yapf: enable +# fmt: on diff --git a/aiida/orm/implementation/comments.py b/aiida/orm/implementation/comments.py index b44d1932d1..c2ae1965c1 100644 --- a/aiida/orm/implementation/comments.py +++ b/aiida/orm/implementation/comments.py @@ -76,10 +76,10 @@ class BackendCommentCollection(BackendCollection[BackendComment]): ENTITY_CLASS = BackendComment @abc.abstractmethod - def create( # type: ignore[override] # pylint: disable=arguments-differ - self, node: 'BackendNode', user: 'BackendUser', content: Optional[str] = None, **kwargs): - """ - Create a Comment for a given node and user + def create( # type: ignore[override] + self, node: 'BackendNode', user: 'BackendUser', content: Optional[str] = None, **kwargs + ): + """Create a Comment for a given node and user :param node: a Node instance :param user: a User instance @@ -89,8 +89,7 @@ def create( # type: ignore[override] # pylint: disable=arguments-differ @abc.abstractmethod def delete(self, comment_id: int) -> None: - """ - Remove a Comment from the collection with the given id + """Remove a Comment from the collection with the given id :param comment_id: the id of the comment to delete @@ -100,16 +99,14 @@ def delete(self, comment_id: int) -> None: @abc.abstractmethod def delete_all(self) -> None: - """ - Delete all Comment entries. + """Delete all Comment entries. :raises `~aiida.common.exceptions.IntegrityError`: if all Comments could not be deleted """ @abc.abstractmethod def delete_many(self, filters: dict) -> List[int]: - """ - Delete Comments based on ``filters`` + """Delete Comments based on ``filters`` :param filters: similar to QueryBuilder filter diff --git a/aiida/orm/implementation/computers.py b/aiida/orm/implementation/computers.py index 804ce24011..05467cba06 100644 --- a/aiida/orm/implementation/computers.py +++ b/aiida/orm/implementation/computers.py @@ -24,7 +24,6 @@ class BackendComputer(BackendEntity): It has an associated transport_type, which points to a plugin for connecting to the resource and passing data, and a scheduler_type, which points to a plugin for scheduling calculations. 
""" - # pylint: disable=too-many-public-methods _logger = logging.getLogger(__name__) @@ -58,8 +57,7 @@ def hostname(self) -> str: @abc.abstractmethod def set_hostname(self, val: str) -> None: - """ - Set the hostname of this computer + """Set the hostname of this computer :param val: The new hostname """ @@ -102,8 +100,7 @@ class BackendComputerCollection(BackendCollection[BackendComputer]): @abc.abstractmethod def delete(self, pk: int) -> None: - """ - Delete an entry with the given pk + """Delete an entry with the given pk :param pk: the pk of the entry to delete """ diff --git a/aiida/orm/implementation/entities.py b/aiida/orm/implementation/entities.py index 52320777b3..2363ea34be 100644 --- a/aiida/orm/implementation/entities.py +++ b/aiida/orm/implementation/entities.py @@ -18,13 +18,13 @@ __all__ = ('BackendEntity', 'BackendCollection', 'EntityType', 'BackendEntityExtrasMixin') -EntityType = TypeVar('EntityType', bound='BackendEntity') # pylint: disable=invalid-name +EntityType = TypeVar('EntityType', bound='BackendEntity') class BackendEntity(abc.ABC): """An first-class entity in the backend""" - def __init__(self, backend: 'StorageBackend', **kwargs: Any): # pylint: disable=unused-argument + def __init__(self, backend: 'StorageBackend', **kwargs: Any): self._backend = backend @property @@ -37,7 +37,7 @@ def backend(self) -> 'StorageBackend': @property @abc.abstractmethod - def id(self) -> int: # pylint: disable=invalid-name + def id(self) -> int: """Return the id for this entity. This is unique only amongst entities of this type for a particular backend. @@ -77,9 +77,7 @@ class BackendCollection(Generic[EntityType]): ENTITY_CLASS: ClassVar[Type[EntityType]] # type: ignore[misc] def __init__(self, backend: 'StorageBackend'): - """ - :param backend: the backend this collection belongs to - """ + """:param backend: the backend this collection belongs to""" assert issubclass(self.ENTITY_CLASS, BackendEntity), 'Must set the ENTRY_CLASS class variable to an entity type' self._backend = backend @@ -89,8 +87,7 @@ def backend(self) -> 'StorageBackend': return self._backend def create(self, **kwargs: Any) -> EntityType: - """ - Create new a entry and set the attributes to those specified in the keyword arguments + """Create new a entry and set the attributes to those specified in the keyword arguments :return: the newly created entry of type ENTITY_CLASS """ diff --git a/aiida/orm/implementation/groups.py b/aiida/orm/implementation/groups.py index 87b33a8679..0543801e9c 100644 --- a/aiida/orm/implementation/groups.py +++ b/aiida/orm/implementation/groups.py @@ -23,7 +23,7 @@ class NodeIterator(Protocol): """Protocol for iterating over nodes in a group""" - def __iter__(self) -> 'NodeIterator': # pylint: disable=non-iterator-returned + def __iter__(self) -> 'NodeIterator': """Return an iterator over the nodes in the group.""" def __next__(self) -> BackendNode: @@ -32,7 +32,7 @@ def __next__(self) -> BackendNode: def __getitem__(self, value: Union[int, slice]) -> Union[BackendNode, List[BackendNode]]: """Index node(s) from the group.""" - def __len__(self) -> int: # pylint: disable=invalid-length-returned + def __len__(self) -> int: """Return the number of nodes in the group.""" @@ -50,8 +50,7 @@ def label(self) -> str: @label.setter @abc.abstractmethod def label(self, name: str) -> None: - """ - Attempt to change the name of the group instance. If the group is already stored + """Attempt to change the name of the group instance. 
If the group is already stored and the another group of the same type already exists with the desired name, a UniquenessError will be raised @@ -92,8 +91,7 @@ def uuid(self) -> str: @property @abc.abstractmethod def nodes(self) -> NodeIterator: - """ - Return a generator/iterator that iterates over all nodes and returns + """Return a generator/iterator that iterates over all nodes and returns the respective AiiDA subclasses of Node, and also allows to ask for the number of nodes in the group using len(). """ @@ -109,7 +107,7 @@ def count(self) -> int: def clear(self) -> None: """Remove all the nodes from this group.""" - def add_nodes(self, nodes: Sequence[BackendNode], **kwargs): # pylint: disable=unused-argument + def add_nodes(self, nodes: Sequence[BackendNode], **kwargs): """Add a set of nodes to the group. :note: all the nodes *and* the group itself have to be stored. @@ -142,7 +140,7 @@ def remove_nodes(self, nodes: Sequence[BackendNode]) -> None: raise TypeError(f'nodes have to be of type {BackendNode}') def __repr__(self) -> str: - return f'<{self.__class__.__name__}: {str(self)}>' + return f'<{self.__class__.__name__}: {self!s}>' def __str__(self) -> str: if self.type_string: @@ -157,9 +155,8 @@ class BackendGroupCollection(BackendCollection[BackendGroup]): ENTITY_CLASS = BackendGroup @abc.abstractmethod - def delete(self, id: int) -> None: # pylint: disable=redefined-builtin, invalid-name - """ - Delete a group with the given id + def delete(self, id: int) -> None: + """Delete a group with the given id :param id: the id of the group to delete """ diff --git a/aiida/orm/implementation/logs.py b/aiida/orm/implementation/logs.py index 1cb3fec884..69ed315dae 100644 --- a/aiida/orm/implementation/logs.py +++ b/aiida/orm/implementation/logs.py @@ -66,8 +66,7 @@ class BackendLogCollection(BackendCollection[BackendLog]): @abc.abstractmethod def delete(self, log_id: int) -> None: - """ - Remove a Log entry from the collection with the given id + """Remove a Log entry from the collection with the given id :param log_id: id of the Log to delete @@ -77,16 +76,14 @@ def delete(self, log_id: int) -> None: @abc.abstractmethod def delete_all(self) -> None: - """ - Delete all Log entries. + """Delete all Log entries. :raises `~aiida.common.exceptions.IntegrityError`: if all Logs could not be deleted """ @abc.abstractmethod def delete_many(self, filters: dict) -> List[int]: - """ - Delete Logs based on ``filters`` + """Delete Logs based on ``filters`` :param filters: similar to QueryBuilder filter diff --git a/aiida/orm/implementation/nodes.py b/aiida/orm/implementation/nodes.py index 915672d158..3d0e6660fd 100644 --- a/aiida/orm/implementation/nodes.py +++ b/aiida/orm/implementation/nodes.py @@ -30,8 +30,6 @@ class BackendNode(BackendEntity, BackendEntityExtrasMixin, metaclass=abc.ABCMeta A node stores data input or output from a computation. """ - # pylint: disable=too-many-public-methods - @abc.abstractmethod def clone(self: BackendNodeType) -> BackendNodeType: """Return an unstored clone of ourselves. @@ -181,10 +179,8 @@ def add_incoming(self, source: 'BackendNode', link_type, link_label): """ @abc.abstractmethod - def store( # pylint: disable=arguments-differ - self: BackendNodeType, - links: Optional[Sequence['LinkTriple']] = None, - clean: bool = True + def store( + self: BackendNodeType, links: Optional[Sequence['LinkTriple']] = None, clean: bool = True ) -> BackendNodeType: """Store the node in the database. 
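
The `__repr__` rewrites in this file (and in `computers.py` earlier) swap `{str(self)}` for the `!s` conversion. A toy illustration of why the two are interchangeable, independent of any AiiDA machinery:

```python
class Demo:
    def __str__(self) -> str:
        return 'demo (localhost), pk: 1'

    def __repr__(self) -> str:
        # `{self!s}` applies str() inside the f-string, exactly like `{str(self)}`.
        return f'<{self.__class__.__name__}: {self!s}>'

assert repr(Demo()) == '<Demo: demo (localhost), pk: 1>'
```
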
diff --git a/aiida/orm/implementation/querybuilder.py b/aiida/orm/implementation/querybuilder.py index 55e649aac3..4dc30f1dfc 100644 --- a/aiida/orm/implementation/querybuilder.py +++ b/aiida/orm/implementation/querybuilder.py @@ -29,8 +29,15 @@ EntityTypes.GROUP.value: {'with_node', 'with_user'}, EntityTypes.LOG.value: {'with_node'}, EntityTypes.NODE.value: { - 'with_comment', 'with_log', 'with_incoming', 'with_outgoing', 'with_descendants', 'with_ancestors', - 'with_computer', 'with_user', 'with_group' + 'with_comment', + 'with_log', + 'with_incoming', + 'with_outgoing', + 'with_descendants', + 'with_ancestors', + 'with_computer', + 'with_user', + 'with_group', }, EntityTypes.USER.value: {'with_authinfo', 'with_comment', 'with_group', 'with_node'}, EntityTypes.LINK.value: set(), @@ -81,10 +88,9 @@ class BackendQueryBuilder(abc.ABC): """Backend query builder interface""" def __init__(self, backend: 'StorageBackend'): - """ - :param backend: the backend - """ + """:param backend: the backend""" from .storage_backend import StorageBackend + type_check(backend, StorageBackend) self._backend = backend diff --git a/aiida/orm/implementation/storage_backend.py b/aiida/orm/implementation/storage_backend.py index 248f6392d9..3a4fdb28fb 100644 --- a/aiida/orm/implementation/storage_backend.py +++ b/aiida/orm/implementation/storage_backend.py @@ -32,10 +32,10 @@ __all__ = ('StorageBackend',) -TransactionType = TypeVar('TransactionType') # pylint: disable=invalid-name +TransactionType = TypeVar('TransactionType') -class StorageBackend(abc.ABC): # pylint: disable=too-many-public-methods +class StorageBackend(abc.ABC): """Abstraction for a backend to read/write persistent data for a profile's provenance graph. AiiDA splits data storage into two sources: @@ -103,6 +103,7 @@ def __init__(self, profile: 'Profile') -> None: :raises: :raises: :class:`aiida.common.exceptions.CorruptStorage` if the storage is internally inconsistent """ from aiida.orm.autogroup import AutogroupManager + self._profile = profile self._default_user: Optional['User'] = None self._autogroup = AutogroupManager(self) @@ -143,6 +144,7 @@ def _clear(self) -> None: .. warning:: This is a destructive operation, and should only be used for testing purposes. """ from aiida.orm.autogroup import AutogroupManager + self.reset_default_user() self._autogroup = AutogroupManager(self) @@ -208,8 +210,7 @@ def query(self) -> 'BackendQueryBuilder': @abc.abstractmethod def transaction(self) -> ContextManager[Any]: - """ - Get a context manager that can be used as a transaction context for a series of backend operations. + """Get a context manager that can be used as a transaction context for a series of backend operations. If there is an exception within the context then the changes will be rolled back and the state will be as before entering. Transactions can be nested. 
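
The reformatted mapping above enumerates the join keywords each entity type accepts in the query builder. A sketch of one of them (`with_group`) in use; the group label is a hypothetical placeholder:

```python
from aiida import load_profile
from aiida.orm import Group, Node, QueryBuilder

load_profile()

qb = QueryBuilder()
qb.append(Group, filters={'label': 'my_group'}, tag='group')  # 'my_group' is hypothetical
qb.append(Node, with_group='group', project=['id', 'node_type'])

for pk, node_type in qb.iterall():
    print(pk, node_type)
```
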
@@ -325,29 +326,33 @@ def get_orm_entities(self, detailed: bool = False) -> dict: query_user = QueryBuilder(self).append(User, project=['email']) data['Users'] = {'count': query_user.count()} if detailed: - data['Users']['emails'] = sorted({email for email, in query_user.iterall() if email is not None}) + data['Users']['emails'] = sorted({email for (email,) in query_user.iterall() if email is not None}) query_comp = QueryBuilder(self).append(Computer, project=['label']) data['Computers'] = {'count': query_comp.count()} if detailed: - data['Computers']['labels'] = sorted({comp for comp, in query_comp.iterall() if comp is not None}) + data['Computers']['labels'] = sorted({comp for (comp,) in query_comp.iterall() if comp is not None}) count = QueryBuilder(self).append(Node).count() data['Nodes'] = {'count': count} if detailed: - node_types = sorted({ - typ for typ, in QueryBuilder(self).append(Node, project=['node_type']).iterall() if typ is not None - }) + node_types = sorted( + {typ for (typ,) in QueryBuilder(self).append(Node, project=['node_type']).iterall() if typ is not None} + ) data['Nodes']['node_types'] = node_types - process_types = sorted({ - typ for typ, in QueryBuilder(self).append(Node, project=['process_type']).iterall() if typ is not None - }) + process_types = sorted( + { + typ + for (typ,) in QueryBuilder(self).append(Node, project=['process_type']).iterall() + if typ is not None + } + ) data['Nodes']['process_types'] = [p for p in process_types if p] query_group = QueryBuilder(self).append(Group, project=['type_string']) data['Groups'] = {'count': query_group.count()} if detailed: - data['Groups']['type_strings'] = sorted({typ for typ, in query_group.iterall() if typ is not None}) + data['Groups']['type_strings'] = sorted({typ for (typ,) in query_group.iterall() if typ is not None}) count = QueryBuilder(self).append(Comment).count() data['Comments'] = {'count': count} diff --git a/aiida/orm/implementation/users.py b/aiida/orm/implementation/users.py index c67d13e805..9d7e03f82f 100644 --- a/aiida/orm/implementation/users.py +++ b/aiida/orm/implementation/users.py @@ -24,8 +24,7 @@ class BackendUser(BackendEntity): @property @abc.abstractmethod def email(self) -> str: - """ - Get the email address of the user + """Get the email address of the user :return: the email address """ @@ -33,8 +32,7 @@ def email(self) -> str: @email.setter @abc.abstractmethod def email(self, val: str) -> None: - """ - Set the email address of the user + """Set the email address of the user :param val: the new email address """ @@ -42,8 +40,7 @@ def email(self, val: str) -> None: @property @abc.abstractmethod def first_name(self) -> str: - """ - Get the user's first name + """Get the user's first name :return: the first name """ @@ -51,8 +48,7 @@ def first_name(self) -> str: @first_name.setter @abc.abstractmethod def first_name(self, val: str) -> None: - """ - Set the user's first name + """Set the user's first name :param val: the new first name """ @@ -60,8 +56,7 @@ def first_name(self, val: str) -> None: @property @abc.abstractmethod def last_name(self) -> str: - """ - Get the user's last name + """Get the user's last name :return: the last name """ @@ -69,8 +64,7 @@ def last_name(self) -> str: @last_name.setter @abc.abstractmethod def last_name(self, val: str) -> None: - """ - Set the user's last name + """Set the user's last name :param val: the new last name """ @@ -78,8 +72,7 @@ def last_name(self, val: str) -> None: @property @abc.abstractmethod def institution(self) -> str: - """ - Get the 
user's institution + """Get the user's institution :return: the institution """ @@ -87,13 +80,11 @@ def institution(self) -> str: @institution.setter @abc.abstractmethod def institution(self, val: str) -> None: - """ - Set the user's institution + """Set the user's institution :param val: the new institution """ class BackendUserCollection(BackendCollection[BackendUser]): - ENTITY_CLASS = BackendUser diff --git a/aiida/orm/implementation/utils.py b/aiida/orm/implementation/utils.py index 76791336c2..7afabd8309 100644 --- a/aiida/orm/implementation/utils.py +++ b/aiida/orm/implementation/utils.py @@ -8,10 +8,10 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Utility methods for backend non-specific implementations.""" -from collections.abc import Iterable, Mapping -from decimal import Decimal import math import numbers +from collections.abc import Iterable, Mapping +from decimal import Decimal from aiida.common import exceptions from aiida.common.constants import AIIDA_FLOAT_PRECISION @@ -38,8 +38,7 @@ def validate_attribute_extra_key(key): def clean_value(value): - """ - Get value from input and (recursively) replace, if needed, all occurrences + """Get value from input and (recursively) replace, if needed, all occurrences of BaseType AiiDA data nodes with their value, and List with a standard list. It also makes a deep copy of everything The purpose of this function is to convert data to a type which can be serialized and deserialized @@ -57,8 +56,7 @@ def clean_value(value): from aiida.orm import BaseType def clean_builtin(val): - """ - A function to clean build-in python values (`BaseType`). + """A function to clean build-in python values (`BaseType`). It mainly checks that we don't store NaN or Inf. """ @@ -103,7 +101,7 @@ def clean_builtin(val): # Check dictionary before iterables return {k: clean_value(v) for k, v in value.items()} - if (isinstance(value, Iterable) and not isinstance(value, str)): + if isinstance(value, Iterable) and not isinstance(value, str): # list, tuple, ... but not a string # This should also properly take care of dealing with the # basedatatypes.List object diff --git a/aiida/orm/logs.py b/aiida/orm/logs.py index 88cc6612cc..f398264a61 100644 --- a/aiida/orm/logs.py +++ b/aiida/orm/logs.py @@ -8,19 +8,19 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module for orm logging abstract classes""" -from datetime import datetime import logging +from datetime import datetime from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type from aiida.common import timezone -from aiida.common.lang import classproperty from aiida.manage import get_manager from . 
import entities if TYPE_CHECKING: from aiida.orm import Node - from aiida.orm.implementation import BackendLog, StorageBackend + from aiida.orm.implementation import StorageBackend + from aiida.orm.implementation.logs import BackendLog # noqa: F401 from aiida.orm.querybuilder import FilterType, OrderByType __all__ = ('Log', 'OrderSpecifier', 'ASCENDING', 'DESCENDING') @@ -29,13 +29,12 @@ DESCENDING = 'desc' -def OrderSpecifier(field, direction): # pylint: disable=invalid-name +def OrderSpecifier(field, direction): # noqa: N802 return {field: direction} class LogCollection(entities.Collection['Log']): - """ - This class represents the collection of logs and can be used to create + """This class represents the collection of logs and can be used to create and retrieve logs. """ @@ -60,6 +59,7 @@ def create_entry_from_record(self, record: logging.LogRecord) -> Optional['Log'] # If an `exc_info` is present, the log message was an exception, so format the full traceback try: import traceback + exc_info = metadata.pop('exc_info') message = ''.join(traceback.format_exception(*exc_info)) except (TypeError, KeyError): @@ -77,7 +77,7 @@ def create_entry_from_record(self, record: logging.LogRecord) -> Optional['Log'] dbnode_id=dbnode_id, message=message, metadata=metadata, - backend=self.backend + backend=self.backend, ) def get_logs_for(self, entity: 'Node', order_by: Optional['OrderByType'] = None) -> List['Log']: @@ -91,7 +91,7 @@ def get_logs_for(self, entity: 'Node', order_by: Optional['OrderByType'] = None) from . import nodes if not isinstance(entity, nodes.Node): - raise Exception('Only node logs are stored') # pylint: disable=broad-exception-raised + raise Exception('Only node logs are stored') return self.find({'dbnode_id': entity.pk}, order_by=order_by) @@ -124,9 +124,7 @@ def delete_many(self, filters: 'FilterType') -> List[int]: class Log(entities.Entity['BackendLog', LogCollection]): - """ - An AiiDA Log entity. Corresponds to a logged message against a particular AiiDA node. - """ + """An AiiDA Log entity. 
Corresponds to a logged message against a particular AiiDA node.""" _CLS_COLLECTION = LogCollection @@ -138,8 +136,8 @@ def __init__( dbnode_id: int, message: str = '', metadata: Optional[Dict[str, Any]] = None, - backend: Optional['StorageBackend'] = None - ): # pylint: disable=too-many-arguments + backend: Optional['StorageBackend'] = None, + ): """Construct a new log :param time: time @@ -165,7 +163,7 @@ def __init__( levelname=levelname, dbnode_id=dbnode_id, message=message, - metadata=metadata + metadata=metadata, ) super().__init__(model) self.store() # Logs are immutable and automatically stored @@ -182,8 +180,7 @@ def uuid(self) -> str: @property def time(self) -> datetime: - """ - Get the time corresponding to the entry + """Get the time corresponding to the entry :return: The entry timestamp """ @@ -191,8 +188,7 @@ def time(self) -> datetime: @property def loggername(self) -> str: - """ - The name of the logger that created this entry + """The name of the logger that created this entry :return: The entry loggername """ @@ -200,8 +196,7 @@ def loggername(self) -> str: @property def levelname(self) -> str: - """ - The name of the log level + """The name of the log level :return: The entry log level name """ @@ -209,8 +204,7 @@ def levelname(self) -> str: @property def dbnode_id(self) -> int: - """ - Get the id of the object that created the log entry + """Get the id of the object that created the log entry :return: The id of the object that created the log entry """ @@ -218,8 +212,7 @@ def dbnode_id(self) -> int: @property def message(self) -> str: - """ - Get the message corresponding to the entry + """Get the message corresponding to the entry :return: The entry message """ @@ -227,8 +220,7 @@ def message(self) -> str: @property def metadata(self) -> Dict[str, Any]: - """ - Get the metadata corresponding to the entry + """Get the metadata corresponding to the entry :return: The entry metadata """ diff --git a/aiida/orm/nodes/__init__.py b/aiida/orm/nodes/__init__.py index 3af33b89cc..98eddd8eb5 100644 --- a/aiida/orm/nodes/__init__.py +++ b/aiida/orm/nodes/__init__.py @@ -11,8 +11,7 @@ # AUTO-GENERATED -# yapf: disable -# pylint: disable=wildcard-import +# fmt: off from .attributes import * from .data import * @@ -71,4 +70,4 @@ 'to_aiida_type', ) -# yapf: enable +# fmt: on diff --git a/aiida/orm/nodes/attributes.py b/aiida/orm/nodes/attributes.py index 0fbceac200..d11279d4ab 100644 --- a/aiida/orm/nodes/attributes.py +++ b/aiida/orm/nodes/attributes.py @@ -115,7 +115,7 @@ def set(self, key: str, value: Any) -> None: :raise aiida.common.ValidationError: if the key is invalid, i.e. contains periods :raise aiida.common.ModificationNotAllowed: if the entity is stored """ - self._node._check_mutability_attributes([key]) # pylint: disable=protected-access + self._node._check_mutability_attributes([key]) self._backend_node.set_attribute(key, value) def set_many(self, attributes: Dict[str, Any]) -> None: @@ -127,7 +127,7 @@ def set_many(self, attributes: Dict[str, Any]) -> None: :raise aiida.common.ValidationError: if any of the keys are invalid, i.e. 
contain periods :raise aiida.common.ModificationNotAllowed: if the entity is stored """ - self._node._check_mutability_attributes(list(attributes)) # pylint: disable=protected-access + self._node._check_mutability_attributes(list(attributes)) self._backend_node.set_attribute_many(attributes) def reset(self, attributes: Dict[str, Any]) -> None: @@ -139,7 +139,7 @@ def reset(self, attributes: Dict[str, Any]) -> None: :raise aiida.common.ValidationError: if any of the keys are invalid, i.e. contain periods :raise aiida.common.ModificationNotAllowed: if the entity is stored """ - self._node._check_mutability_attributes() # pylint: disable=protected-access + self._node._check_mutability_attributes() self._backend_node.reset_attributes(attributes) def delete(self, key: str) -> None: @@ -149,7 +149,7 @@ def delete(self, key: str) -> None: :raises AttributeError: if the attribute does not exist :raise aiida.common.ModificationNotAllowed: if the entity is stored """ - self._node._check_mutability_attributes([key]) # pylint: disable=protected-access + self._node._check_mutability_attributes([key]) self._backend_node.delete_attribute(key) def delete_many(self, keys: List[str]) -> None: @@ -159,12 +159,12 @@ def delete_many(self, keys: List[str]) -> None: :raises AttributeError: if at least one of the attribute does not exist :raise aiida.common.ModificationNotAllowed: if the entity is stored """ - self._node._check_mutability_attributes(keys) # pylint: disable=protected-access + self._node._check_mutability_attributes(keys) self._backend_node.delete_attribute_many(keys) def clear(self) -> None: """Delete all attributes.""" - self._node._check_mutability_attributes() # pylint: disable=protected-access + self._node._check_mutability_attributes() self._backend_node.clear_attributes() def items(self) -> Iterable[Tuple[str, Any]]: diff --git a/aiida/orm/nodes/caching.py b/aiida/orm/nodes/caching.py index 1fe342e5af..1e0cac1904 100644 --- a/aiida/orm/nodes/caching.py +++ b/aiida/orm/nodes/caching.py @@ -11,6 +11,9 @@ from ..querybuilder import QueryBuilder +if t.TYPE_CHECKING: + from .node import Node + class NodeCaching: """Interface to control caching of a node instance.""" @@ -35,8 +38,7 @@ def get_hash(self, ignore_errors: bool = True, **kwargs: t.Any) -> str | None: return self._get_hash(ignore_errors=ignore_errors, **kwargs) def _get_hash(self, ignore_errors: bool = True, **kwargs: t.Any) -> str | None: - """ - Return the hash for this node based on its attributes. + """Return the hash for this node based on its attributes. This will always work, even before storing. @@ -63,10 +65,10 @@ def _get_objects_to_hash(self) -> list[t.Any]: { key: val for key, val in self._node.base.attributes.items() - if key not in self._node._hash_ignored_attributes and key not in self._node._updatable_attributes # pylint: disable=unsupported-membership-test,protected-access + if key not in self._node._hash_ignored_attributes and key not in self._node._updatable_attributes }, self._node.base.repository.hash(), - self._node.computer.uuid if self._node.computer is not None else None + self._node.computer.uuid if self._node.computer is not None else None, ] return objects @@ -121,8 +123,7 @@ def get_all_same_nodes(self) -> list['Node']: return list(self._iter_all_same_nodes()) def _iter_all_same_nodes(self, allow_before_store=False) -> t.Iterator['Node']: - """ - Returns an iterator of all same nodes. + """Returns an iterator of all same nodes. 
Note: this should be only called on stored nodes, or internally from .store() since it first calls clean_value() on the attributes to normalise them. @@ -132,14 +133,16 @@ def _iter_all_same_nodes(self, allow_before_store=False) -> t.Iterator['Node']: node_hash = self._get_hash() - if not node_hash or not self._node._cachable: # pylint: disable=protected-access + if not node_hash or not self._node._cachable: return iter(()) builder = QueryBuilder(backend=self._node.backend) builder.append(self._node.__class__, filters={f'extras.{self._HASH_EXTRA_KEY}': node_hash}, subclassing=False) return ( - node for node, in builder.iterall() if node.base.caching.is_valid_cache # type: ignore[misc,union-attr] + node + for (node,) in builder.iterall() + if node.base.caching.is_valid_cache # type: ignore[misc,union-attr] ) @property diff --git a/aiida/orm/nodes/comments.py b/aiida/orm/nodes/comments.py index 8284bc9261..69aaf1c07e 100644 --- a/aiida/orm/nodes/comments.py +++ b/aiida/orm/nodes/comments.py @@ -7,6 +7,9 @@ from ..comments import Comment from ..users import User +if t.TYPE_CHECKING: + from .node import Node + class NodeComments: """Interface for comments of a node instance.""" @@ -40,10 +43,9 @@ def all(self) -> list[Comment]: :return: the list of comments, sorted by pk """ - return Comment.get_collection(self._node.backend - ).find(filters={'dbnode_id': self._node.pk}, order_by=[{ - 'id': 'asc' - }]) + return Comment.get_collection(self._node.backend).find( + filters={'dbnode_id': self._node.pk}, order_by=[{'id': 'asc'}] + ) def update(self, identifier: int, content: str) -> None: """Update the content of an existing comment. diff --git a/aiida/orm/nodes/data/__init__.py b/aiida/orm/nodes/data/__init__.py index 395de5f979..254ec884e3 100644 --- a/aiida/orm/nodes/data/__init__.py +++ b/aiida/orm/nodes/data/__init__.py @@ -11,8 +11,7 @@ # AUTO-GENERATED -# yapf: disable -# pylint: disable=wildcard-import +# fmt: off from .array import * from .base import * @@ -76,4 +75,4 @@ 'to_aiida_type', ) -# yapf: enable +# fmt: on diff --git a/aiida/orm/nodes/data/array/__init__.py b/aiida/orm/nodes/data/array/__init__.py index f12feedfbe..2d0ccb8c97 100644 --- a/aiida/orm/nodes/data/array/__init__.py +++ b/aiida/orm/nodes/data/array/__init__.py @@ -11,8 +11,7 @@ # AUTO-GENERATED -# yapf: disable -# pylint: disable=wildcard-import +# fmt: off from .array import * from .bands import * @@ -31,4 +30,4 @@ 'find_bandgap', ) -# yapf: enable +# fmt: on diff --git a/aiida/orm/nodes/data/array/array.py b/aiida/orm/nodes/data/array/array.py index cd7c0f5a0c..e4b1a07628 100644 --- a/aiida/orm/nodes/data/array/array.py +++ b/aiida/orm/nodes/data/array/array.py @@ -7,8 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -""" -AiiDA ORM data class storing (numpy) arrays +"""AiiDA ORM data class storing (numpy) arrays """ from __future__ import annotations @@ -28,21 +27,22 @@ def _(value): class ArrayData(Data): - """ - Store a set of arrays on disk (rather than on the database) in an efficient - way using numpy.save() (therefore, this class requires numpy to be - installed). + """Store a set of arrays on disk (rather than on the database) in an efficient way + + Arrays are stored using numpy and therefore this class requires numpy to be installed. Each array is stored within the Node folder as a different .npy file. 
:note: Before storing, no caching is done: if you perform a - :py:meth:`.get_array` call, the array will be re-read from disk. - If instead the ArrayData node has already been stored, - the array is cached in memory after the first read, and the cached array - is used thereafter. - If too much RAM memory is used, you can clear the - cache with the :py:meth:`.clear_internal_cache` method. + :py:meth:`.get_array` call, the array will be re-read from disk. + If instead the ArrayData node has already been stored, + the array is cached in memory after the first read, and the cached array + is used thereafter. + If too much RAM memory is used, you can clear the + cache with the :py:meth:`.clear_internal_cache` method. + """ + array_prefix = 'array|' default_array_name = 'default' @@ -75,8 +75,7 @@ def initialize(self): self._cached_arrays = {} def delete_array(self, name: str) -> None: - """ - Delete an array from the node. Can only be called before storing. + """Delete an array from the node. Can only be called before storing. :param name: The name of the array to delete from the node. """ @@ -93,8 +92,7 @@ def delete_array(self, name: str) -> None: pass def get_arraynames(self) -> list[str]: - """ - Return a list of all arrays stored in the node, listing the files (and + """Return a list of all arrays stored in the node, listing the files (and not relying on the properties). .. versionadded:: 0.7 @@ -103,22 +101,19 @@ def get_arraynames(self) -> list[str]: return self._arraynames_from_properties() def _arraynames_from_files(self) -> list[str]: - """ - Return a list of all arrays stored in the node, listing the files (and + """Return a list of all arrays stored in the node, listing the files (and not relying on the properties). """ return [i[:-4] for i in self.base.repository.list_object_names() if i.endswith('.npy')] def _arraynames_from_properties(self) -> list[str]: - """ - Return a list of all arrays stored in the node, listing the attributes + """Return a list of all arrays stored in the node, listing the attributes starting with the correct prefix. """ - return [i[len(self.array_prefix):] for i in self.base.attributes.keys() if i.startswith(self.array_prefix)] + return [i[len(self.array_prefix) :] for i in self.base.attributes.keys() if i.startswith(self.array_prefix)] def get_shape(self, name: str) -> tuple[int, ...]: - """ - Return the shape of an array (read from the value cached in the + """Return the shape of an array (read from the value cached in the properties for efficiency reasons). :param name: The name of the array. @@ -126,8 +121,7 @@ def get_shape(self, name: str) -> tuple[int, ...]: return tuple(self.base.attributes.get(f'{self.array_prefix}{name}')) def get_iterarrays(self) -> Iterator[tuple[str, ndarray]]: - """ - Iterator that returns tuples (name, array) for each array stored in the node. + """Iterator that returns tuples (name, array) for each array stored in the node. .. versionadded:: 1.0 Renamed from iterarrays @@ -136,8 +130,7 @@ def get_iterarrays(self) -> Iterator[tuple[str, ndarray]]: yield (name, self.get_array(name)) def get_array(self, name: str | None = None) -> ndarray: - """ - Return an array stored in the node + """Return an array stored in the node :param name: The name of the array to return. The name can be omitted in case the node contains only a single array, which will be returned in that case. 
If ``name`` is ``None`` and the node contains multiple arrays or @@ -166,7 +159,7 @@ def get_array_from_file(self, name: str) -> ndarray: # Open a handle in binary read mode as the arrays are written as binary files as well with self.base.repository.open(filename, mode='rb') as handle: - return numpy.load(handle, allow_pickle=False) # pylint: disable=unexpected-keyword-arg + return numpy.load(handle, allow_pickle=False) # Return with proper caching if the node is stored, otherwise always re-read from disk if not self.is_stored: @@ -178,8 +171,7 @@ def get_array_from_file(self, name: str) -> ndarray: return self._cached_arrays[name] def clear_internal_cache(self) -> None: - """ - Clear the internal memory cache where the arrays are stored after being + """Clear the internal memory cache where the arrays are stored after being read from disk (used in order to reduce at minimum the readings from disk). This function is useful if you want to keep the node in memory, but you @@ -188,8 +180,7 @@ def clear_internal_cache(self) -> None: self._cached_arrays = {} def set_array(self, name: str, array: ndarray) -> None: - """ - Store a new numpy array inside the node. Possibly overwrite the array + """Store a new numpy array inside the node. Possibly overwrite the array if it already existed. Internally, it stores a name.npy file in numpy format. @@ -227,8 +218,7 @@ def set_array(self, name: str, array: ndarray) -> None: self.base.attributes.set(f'{self.array_prefix}{name}', list(array.shape)) def _validate(self) -> bool: - """ - Check if the list of .npy files stored inside the node and the + """Check if the list of .npy files stored inside the node and the list of properties match. Just a name check, no check on the size since this would require to reload all arrays and this may take time and memory. @@ -251,14 +241,12 @@ def _get_array_entries(self) -> dict[str, Any]: the value is the numpy array transformed into a list. This is so that it can be transformed into a json object. """ - array_dict = {} for key, val in self.get_iterarrays(): - array_dict[key] = clean_array(val) return array_dict - def _prepare_json(self, main_file_name='', comments=True) -> tuple[bytes, dict]: # pylint: disable=unused-argument + def _prepare_json(self, main_file_name='', comments=True) -> tuple[bytes, dict]: """Dump the content of the arrays stored in this node into JSON format. :param comments: if True, includes comments (if it makes sense for the given format) @@ -277,8 +265,7 @@ def _prepare_json(self, main_file_name='', comments=True) -> tuple[bytes, dict]: def clean_array(array: ndarray) -> list: - """ - Replacing np.nan and np.inf/-np.inf for Nones. + """Replacing np.nan and np.inf/-np.inf for Nones. The function will also sanitize the array removing ``np.nan`` and ``np.inf`` for ``None`` of this way the resulting JSON is always valid. 
@@ -293,9 +280,10 @@ def clean_array(array: ndarray) -> list: import numpy as np output = np.reshape( - np.asarray([ - entry if not np.isnan(entry) and not np.isinf(entry) else None for entry in array.flatten().tolist() - ]), array.shape + np.asarray( + [entry if not np.isnan(entry) and not np.isinf(entry) else None for entry in array.flatten().tolist()] + ), + array.shape, ) return output.tolist() diff --git a/aiida/orm/nodes/data/array/bands.py b/aiida/orm/nodes/data/array/bands.py index 6bf54eaafd..df480082eb 100644 --- a/aiida/orm/nodes/data/array/bands.py +++ b/aiida/orm/nodes/data/array/bands.py @@ -7,9 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=too-many-lines -""" -This module defines the classes related to band structures or dispersions +"""This module defines the classes related to band structures or dispersions in a Brillouin zone, and how to operate on them. """ import json @@ -44,8 +42,7 @@ def prepare_header_comment(uuid, plot_info, comment_char='#'): def find_bandgap(bandsdata, number_electrons=None, fermi_energy=None): - """ - Tries to guess whether the bandsdata represent an insulator. + """Tries to guess whether the bandsdata represent an insulator. This method is meant to be used only for electronic bands (not phonons) By default, it will try to use the occupations to guess the number of electrons and find the Fermi Energy, otherwise, it can be provided @@ -73,15 +70,11 @@ def find_bandgap(bandsdata, number_electrons=None, fermi_energy=None): equal to the lumo (e.g. in semi-metals). """ - # pylint: disable=too-many-return-statements,too-many-branches,too-many-statements,no-else-return - def nint(num): - """ - Stable rounding function - """ + """Stable rounding function""" if num > 0: - return int(num + .5) - return int(num - .5) + return int(num + 0.5) + return int(num - 0.5) if fermi_energy and number_electrons: raise ValueError('Specify either the number of electrons or the Fermi energy, but not both') @@ -100,7 +93,6 @@ def nint(num): # analysis on occupations: if fermi_energy is None: - num_kpoints = len(bands) if number_electrons is None: @@ -125,9 +117,11 @@ def nint(num): # sort the bands by energy, and reorder the occupations accordingly # since after joining the two spins, I might have unsorted stuff bands, occupations = [ - numpy.array(y) for y in zip( + numpy.array(y) + for y in zip( *[ - list(zip(*j)) for j in [ + list(zip(*j)) + for j in [ sorted(zip(i[0].tolist(), i[1].tolist()), key=lambda x: x[0]) for i in zip(bands, occupations) ] @@ -145,8 +139,7 @@ def nint(num): lumo = [_[0][_[1] + 1] for _ in zip(bands, homo_indexes)] except IndexError: raise ValueError( - 'To understand if it is a metal or insulator, ' - 'need more bands than n_band=number_electrons' + 'To understand if it is a metal or insulator, ' 'need more bands than n_band=number_electrons' ) else: @@ -163,8 +156,7 @@ def nint(num): lumo = [i[number_electrons // number_electrons_per_band] for i in bands] # take the n+1th level except IndexError: raise ValueError( - 'To understand if it is a metal or insulator, ' - 'need more bands than n_band=number_electrons' + 'To understand if it is a metal or insulator, ' 'need more bands than n_band=number_electrons' ) if number_electrons % 2 == 1 and len(stored_bands.shape) == 2: @@ -174,10 +166,10 @@ def nint(num): # if the nth band crosses the (n+1)th, it is an 
insulator gap = min(lumo) - max(homo) - if gap == 0.: - return False, 0. + if gap == 0.0: + return False, 0.0 - if gap < 0.: + if gap < 0.0: return False, None return True, gap @@ -198,32 +190,29 @@ def nint(num): raise ValueError("The Fermi energy is below all band energies, don't know what to do.") # one band is crossed by the fermi energy - if any(i[1] < fermi_energy and fermi_energy < i[0] for i in max_mins): # pylint: disable=chained-comparison + if any(i[1] < fermi_energy and fermi_energy < i[0] for i in max_mins): return False, None # case of semimetals, fermi energy at the crossing of two bands # this will only work if the dirac point is computed! - if (any(i[0] == fermi_energy for i in max_mins) and any(i[1] == fermi_energy for i in max_mins)): - return False, 0. + if any(i[0] == fermi_energy for i in max_mins) and any(i[1] == fermi_energy for i in max_mins): + return False, 0.0 # insulating case, take the max of the band maxima below the fermi energy homo = max(i[0] for i in max_mins if i[0] < fermi_energy) # take the min of the band minima above the fermi energy lumo = min(i[1] for i in max_mins if i[1] > fermi_energy) gap = lumo - homo - if gap <= 0.: + if gap <= 0.0: raise RuntimeError('Something wrong has been implemented. Revise the code!') return True, gap class BandsData(KpointsData): - """ - Class to handle bands data - """ + """Class to handle bands data""" def set_kpointsdata(self, kpointsdata): - """ - Load the kpoints from a kpoint object. + """Load the kpoints from a kpoint object. :param kpointsdata: an instance of KpointsData class """ if not isinstance(kpointsdata, KpointsData): @@ -251,14 +240,12 @@ def set_kpointsdata(self, kpointsdata): self.labels = [] def _validate_bands_occupations(self, bands, occupations=None, labels=None): - """ - Validate the list of bands and of occupations before storage. + """Validate the list of bands and of occupations before storage. Kpoints must be set in advance. Bands and occupations must be convertible into arrays of Nkpoints x Nbands floats or Nspins x Nkpoints x Nbands; Nkpoints must correspond to the number of kpoints. """ - # pylint: disable=too-many-branches try: kpoints = self.get_kpoints() except AttributeError: @@ -311,8 +298,9 @@ def _validate_bands_occupations(self, bands, occupations=None, labels=None): the_labels = [str(_) for _ in labels] else: raise ValidationError( - 'Band labels have an unrecognized type ({})' - 'but should be a string or a list of strings'.format(labels.__class__) + 'Band labels have an unrecognized type ({})' 'but should be a string or a list of strings'.format( + labels.__class__ + ) ) if len(the_bands.shape) == 2 and len(the_labels) != 1: @@ -325,8 +313,7 @@ def _validate_bands_occupations(self, bands, occupations=None, labels=None): return the_bands, the_occupations, the_labels def set_bands(self, bands, units=None, occupations=None, labels=None): - """ - Set an array of band energies of dimension (nkpoints x nbands). + """Set an array of band energies of dimension (nkpoints x nbands). Kpoints must be set in advance. Can contain floats or None. 
:param bands: a list of nkpoints lists of nbands bands, or a 2D array of shape (nkpoints x nbands), with band energies for each kpoint @@ -349,32 +336,25 @@ def set_bands(self, bands, units=None, occupations=None, labels=None): @property def array_labels(self): - """ - Get the labels associated with the band arrays - """ + """Get the labels associated with the band arrays""" return self.base.attributes.get('array_labels', None) @property def units(self): - """ - Units in which the data in bands were stored. A string - """ + """Units in which the data in bands were stored. A string""" # return copy.deepcopy(self._pbc) return self.base.attributes.get('units') @units.setter def units(self, value): - """ - Set the value of pbc, i.e. a tuple of three booleans, indicating if the + """Set the value of pbc, i.e. a tuple of three booleans, indicating if the cell is periodic in the 1,2,3 crystal direction """ the_str = str(value) self.base.attributes.set('units', the_str) def _set_pbc(self, value): - """ - validate the pbc, then store them - """ + """Validate the pbc, then store them""" from aiida.common.exceptions import ModificationNotAllowed from aiida.orm.nodes.data.structure import get_valid_pbc @@ -386,8 +366,7 @@ def _set_pbc(self, value): self.base.attributes.set('pbc3', the_pbc[2]) def get_bands(self, also_occupations=False, also_labels=False): - """ - Returns an array (nkpoints x num_bands or nspins x nkpoints x num_bands) + """Returns an array (nkpoints x num_bands or nspins x nkpoints x num_bands) of energies. :param also_occupations: if True, returns also the occupations array. Default = False @@ -414,9 +393,8 @@ def get_bands(self, also_occupations=False, also_labels=False): return to_return - def _get_bandplot_data(self, cartesian, prettify_format=None, join_symbol=None, get_segments=False, y_origin=0.): - """ - Get data to plot a band structure + def _get_bandplot_data(self, cartesian, prettify_format=None, join_symbol=None, get_segments=False, y_origin=0.0): + """Get data to plot a band structure :param cartesian: if True, distances (for the x-axis) are computed in cartesian coordinates, otherwise they are computed in reciprocal @@ -438,7 +416,6 @@ def _get_bandplot_data(self, cartesian, prettify_format=None, join_symbol=None, depending on the type of spin; the length is always equalt to the total number of bands per kpoint). """ - # pylint: disable=too-many-locals,too-many-branches,too-many-statements # load the x and y's of the graph stored_bands = self.get_bands() if len(stored_bands.shape) == 2: @@ -475,8 +452,9 @@ def _get_bandplot_data(self, cartesian, prettify_format=None, join_symbol=None, # as a result, where there are discontinuities in the path, # I have two consecutive points with the same x coordinate distances = [ - numpy.linalg.norm(kpoints[i] - - kpoints[i - 1]) if not (i in labels_indices and i - 1 in labels_indices) else 0. 
+ numpy.linalg.norm(kpoints[i] - kpoints[i - 1]) + if not (i in labels_indices and i - 1 in labels_indices) + else 0.0 for i in range(1, len(kpoints)) ] x = [float(sum(distances[:i])) for i in range(len(distances) + 1)] @@ -520,8 +498,8 @@ def _get_bandplot_data(self, cartesian, prettify_format=None, join_symbol=None, 'length': position_to - position_from, 'from': label_from, 'to': label_to, - 'values': bands[position_from:position_to + 1, :].transpose().tolist(), - 'x': x[position_from:position_to + 1], + 'values': bands[position_from : position_to + 1, :].transpose().tolist(), + 'x': x[position_from : position_to + 1], 'two_band_types': two_band_types, } plot_info['paths'].append(path_dict) @@ -542,8 +520,7 @@ def _get_bandplot_data(self, cartesian, prettify_format=None, join_symbol=None, return plot_info def _prepare_agr_batch(self, main_file_name='', comments=True, prettify_format=None): - """ - Prepare two files, data and batch, to be plot with xmgrace as: + """Prepare two files, data and batch, to be plot with xmgrace as: xmgrace -batch file.dat :param main_file_name: if the user asks to write the main content on a @@ -555,7 +532,6 @@ def _prepare_agr_batch(self, main_file_name='', comments=True, prettify_format=N :param prettify_format: if None, use the default prettify format. Otherwise specify a string with the prettifier to use. """ - # pylint: disable=too-many-locals import os dat_filename = os.path.splitext(main_file_name)[0] + '_data.dat' @@ -623,9 +599,8 @@ def _prepare_agr_batch(self, main_file_name='', comments=True, prettify_format=N return batch_data.encode('utf-8'), extra_files - def _prepare_dat_multicolumn(self, main_file_name='', comments=True): # pylint: disable=unused-argument - """ - Write an N x M matrix. First column is the distance between kpoints, + def _prepare_dat_multicolumn(self, main_file_name='', comments=True): + """Write an N x M matrix. First column is the distance between kpoints, The other columns are the bands. Header contains number of kpoints and the number of bands (commented). @@ -647,9 +622,8 @@ def _prepare_dat_multicolumn(self, main_file_name='', comments=True): # pylint: return ('\n'.join(return_text) + '\n').encode('utf-8'), {} - def _prepare_dat_blocks(self, main_file_name='', comments=True): # pylint: disable=unused-argument - """ - Format suitable for gnuplot using blocks. + def _prepare_dat_blocks(self, main_file_name='', comments=True): + """Format suitable for gnuplot using blocks. Columns with x and y (path and band energy). Several blocks, separated by two empty lines, one per energy band. @@ -683,12 +657,11 @@ def _matplotlib_get_dict( legend2=None, y_max_lim=None, y_min_lim=None, - y_origin=0., + y_origin=0.0, prettify_format=None, - **kwargs - ): # pylint: disable=unused-argument - """ - Prepare the data to send to the python-matplotlib plotting script. + **kwargs, + ): + """Prepare the data to send to the python-matplotlib plotting script. 
:param comments: if True, print comments (if it makes sense for the given format) @@ -713,8 +686,6 @@ def _matplotlib_get_dict( :param kwargs: additional customization variables; only a subset is accepted, see internal variable 'valid_additional_keywords """ - # pylint: disable=too-many-arguments,too-many-locals - # Only these keywords are accepted in kwargs, and then set into the json valid_additional_keywords = [ 'bands_color', # Color of band lines @@ -759,7 +730,7 @@ def _matplotlib_get_dict( prettify_format=prettify_format, join_symbol=join_symbol, get_segments=True, - y_origin=y_origin + y_origin=y_origin, ) all_data = {} @@ -806,8 +777,7 @@ def _matplotlib_get_dict( return all_data def _prepare_mpl_singlefile(self, *args, **kwargs): - """ - Prepare a python script using matplotlib to plot the bands + """Prepare a python script using matplotlib to plot the bands For the possible parameters, see documentation of :py:meth:`~aiida.orm.nodes.data.array.bands.BandsData._matplotlib_get_dict` @@ -823,9 +793,8 @@ def _prepare_mpl_singlefile(self, *args, **kwargs): return string.encode('utf-8'), {} - def _prepare_mpl_withjson(self, main_file_name='', *args, **kwargs): # pylint: disable=keyword-arg-before-vararg - """ - Prepare a python script using matplotlib to plot the bands, with the JSON + def _prepare_mpl_withjson(self, main_file_name='', *args, **kwargs): + """Prepare a python script using matplotlib to plot the bands, with the JSON returned as an independent file. For the possible parameters, see documentation of @@ -837,7 +806,7 @@ def _prepare_mpl_withjson(self, main_file_name='', *args, **kwargs): # pylint: json_fname = os.path.splitext(main_file_name)[0] + '_data.json' # Escape double_quotes - json_fname = json_fname.replace('"', '\"') + json_fname = json_fname.replace('"', '"') ext_files = {json_fname: json.dumps(all_data, indent=2).encode('utf-8')} @@ -850,9 +819,8 @@ def _prepare_mpl_withjson(self, main_file_name='', *args, **kwargs): # pylint: return string.encode('utf-8'), ext_files - def _prepare_mpl_pdf(self, main_file_name='', *args, **kwargs): # pylint: disable=keyword-arg-before-vararg,unused-argument - """ - Prepare a python script using matplotlib to plot the bands, with the JSON + def _prepare_mpl_pdf(self, main_file_name='', *args, **kwargs): + """Prepare a python script using matplotlib to plot the bands, with the JSON returned as an independent file. For the possible parameters, see documentation of @@ -875,7 +843,7 @@ def _prepare_mpl_pdf(self, main_file_name='', *args, **kwargs): # pylint: disab os.close(handle) os.remove(filename) - escaped_fname = filename.replace('"', '\"') + escaped_fname = filename.replace('"', '"') s_footer = MATPLOTLIB_FOOTER_TEMPLATE_EXPORTFILE.substitute(fname=escaped_fname, format='pdf') @@ -898,9 +866,8 @@ def _prepare_mpl_pdf(self, main_file_name='', *args, **kwargs): # pylint: disab return imgdata, {} - def _prepare_mpl_png(self, main_file_name='', *args, **kwargs): # pylint: disable=keyword-arg-before-vararg,unused-argument - """ - Prepare a python script using matplotlib to plot the bands, with the JSON + def _prepare_mpl_png(self, main_file_name='', *args, **kwargs): + """Prepare a python script using matplotlib to plot the bands, with the JSON returned as an independent file. 
For the possible parameters, see documentation of @@ -923,7 +890,7 @@ def _prepare_mpl_png(self, main_file_name='', *args, **kwargs): # pylint: disab os.close(handle) os.remove(filename) - escaped_fname = filename.replace('"', '\"') + escaped_fname = filename.replace('"', '"') s_footer = MATPLOTLIB_FOOTER_TEMPLATE_EXPORTFILE_WITH_DPI.substitute(fname=escaped_fname, format='png', dpi=300) @@ -948,9 +915,7 @@ def _prepare_mpl_png(self, main_file_name='', *args, **kwargs): # pylint: disab @staticmethod def _get_mpl_body_template(paths): - """ - :param paths: paths of k-points - """ + """:param paths: paths of k-points""" if len(paths) == 1: s_body = MATPLOTLIB_BODY_TEMPLATE.substitute(plot_code=SINGLE_KP) else: @@ -958,14 +923,13 @@ def _get_mpl_body_template(paths): return s_body def show_mpl(self, **kwargs): - """ - Call a show() command for the band structure using matplotlib. + """Call a show() command for the band structure using matplotlib. This uses internally the 'mpl_singlefile' format, with empty main_file_name. Other kwargs are passed to self._exportcontent. """ - exec(*self._exportcontent(fileformat='mpl_singlefile', main_file_name='', **kwargs)) # pylint: disable=exec-used + exec(*self._exportcontent(fileformat='mpl_singlefile', main_file_name='', **kwargs)) def _prepare_gnuplot( self, @@ -975,10 +939,9 @@ def _prepare_gnuplot( prettify_format=None, y_max_lim=None, y_min_lim=None, - y_origin=0. + y_origin=0.0, ): - """ - Prepare an gnuplot script to plot the bands, with the .dat file + """Prepare an gnuplot script to plot the bands, with the .dat file returned as an independent file. :param main_file_name: if the user asks to write the main content on a @@ -991,7 +954,6 @@ def _prepare_gnuplot( :param prettify_format: if None, use the default prettify format. Otherwise specify a string with the prettifier to use. """ - # pylint: disable=too-many-arguments,too-many-locals import os main_file_name = main_file_name or 'band.dat' @@ -1053,21 +1015,21 @@ def _prepare_gnuplot( script.append(f'set xtics ({xtics_string})') script.append('unset key') script.append(f'set yrange [{y_min_lim}:{y_max_lim}]') - script.append(f"set ylabel \"Dispersion ({self.units})\"") + script.append(f'set ylabel "Dispersion ({self.units})"') if title: - script.append('set title "{}"'.format(title.replace('"', '\"'))) + script.append('set title "{}"'.format(title.replace('"', '"'))) # Plot, escaping filename if len(x) > 1: script.append(f'set xrange [{x_min_lim}:{x_max_lim}]') script.append('set grid xtics lt 1 lc rgb "#888888"') - script.append('plot "{}" with l lc rgb "#000000"'.format(os.path.basename(dat_filename).replace('"', '\"'))) + script.append('plot "{}" with l lc rgb "#000000"'.format(os.path.basename(dat_filename).replace('"', '"'))) else: script.append('set xrange [-1.0:1.0]') script.append( 'plot "{}" using ($1-0.25):($2):(0.5):(0) with vectors nohead lc rgb "#000000"'.format( - os.path.basename(dat_filename).replace('"', '\"') + os.path.basename(dat_filename).replace('"', '"') ) ) @@ -1087,11 +1049,10 @@ def _prepare_agr( title='', y_max_lim=None, y_min_lim=None, - y_origin=0., - prettify_format=None + y_origin=0.0, + prettify_format=None, ): - """ - Prepare an xmgrace agr file. + """Prepare an xmgrace agr file. :param comments: if True, print comments (if it makes sense for the given format) @@ -1117,7 +1078,6 @@ def _prepare_agr( :param prettify_format: if None, use the default prettify format. Otherwise specify a string with the prettifier to use. 
""" - # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,unused-argument if prettify_format is None: # Default. Specified like this to allow caller functions to pass 'None' prettify_format = 'agr_seekpath' @@ -1147,7 +1107,7 @@ def _prepare_agr( y_min_lim = the_bands.min() x_min_lim = min(x) # this isn't a numpy array, but a list x_max_lim = max(x) - ytick_spacing = 10**int(math.log10((y_max_lim - y_min_lim))) + ytick_spacing = 10 ** int(math.log10((y_max_lim - y_min_lim))) # prepare xticks labels sx1 = '' @@ -1182,7 +1142,7 @@ def _prepare_agr( set_number=i + setnumber_offset, linewidth=width, color_number=linecolor, - legend=legend if i == 0 else '' + legend=legend if i == 0 else '', ) units = self.units @@ -1223,9 +1183,8 @@ def _get_band_segments(self, cartesian): return out_dict - def _prepare_json(self, main_file_name='', comments=True): # pylint: disable=unused-argument - """ - Prepare a json file in a format compatible with the AiiDA band visualizer + def _prepare_json(self, main_file_name='', comments=True): + """Prepare a json file in a format compatible with the AiiDA band visualizer :param comments: if True, print comments (if it makes sense for the given format) @@ -1372,10 +1331,12 @@ def _prepare_json(self, main_file_name='', comments=True): # pylint: disable=un """ ) -AGR_XTICKS_TEMPLATE = Template(""" +AGR_XTICKS_TEMPLATE = Template( + """ @ xaxis tick spec $num_labels $single_xtick_templates - """) + """ +) AGR_SINGLE_XTICK_TEMPLATE = Template( """ @@ -1599,11 +1560,13 @@ def _prepare_json(self, main_file_name='', comments=True): # pylint: disable=un """ ) -AGR_SINGLESET_TEMPLATE = Template(""" +AGR_SINGLESET_TEMPLATE = Template( + """ @target G0.S$set_number @type xy $xydata - """) + """ +) MATPLOTLIB_HEADER_AGG_TEMPLATE = Template( """# -*- coding: utf-8 -*- @@ -1650,8 +1613,10 @@ def _prepare_json(self, main_file_name='', comments=True): # pylint: disable=un """ ) -MATPLOTLIB_IMPORT_DATA_INLINE_TEMPLATE = Template('''all_data_str = r"""$all_data_json""" -''') +MATPLOTLIB_IMPORT_DATA_INLINE_TEMPLATE = Template( + '''all_data_str = r"""$all_data_json""" +''' +) MATPLOTLIB_IMPORT_DATA_FROMFILE_TEMPLATE = Template( """with open("$json_fname", encoding='utf8') as f: @@ -1793,8 +1758,6 @@ def get_bands_and_parents_structure(args, backend=None): A list of sublists, each latter containing (in order): pk as string, formula as string, creation date, bandsdata-label """ - # pylint: disable=too-many-locals,too-many-branches - import datetime from aiida import orm @@ -1840,7 +1803,7 @@ def get_bands_and_parents_structure(args, backend=None): tag='sdata', with_descendants='bdata', # We don't care about the creator of StructureData - project=['id', 'attributes.kinds', 'attributes.sites', 'ctime'] + project=['id', 'attributes.kinds', 'attributes.sites', 'ctime'], ) q_build.order_by({orm.StructureData: {'ctime': 'desc'}}) @@ -1854,7 +1817,6 @@ def get_bands_and_parents_structure(args, backend=None): already_visited_bdata = set() for [bid, blabel, bdate] in bands_list_data: - # We process only one StructureData per BandsData. # We want to process the closest StructureData to # every BandsData. 
@@ -1870,11 +1832,10 @@ def get_bands_and_parents_structure(args, backend=None): if strct is not None: akinds, asites = strct formula = _extract_formula(akinds, asites, args) + elif args.element is not None or args.element_only is not None: + formula = None else: - if args.element is not None or args.element_only is not None: - formula = None - else: - formula = '<>' + formula = '<>' if formula is None: continue @@ -1884,8 +1845,7 @@ def get_bands_and_parents_structure(args, backend=None): def _extract_formula(akinds, asites, args): - """ - Extract formula from the structure object. + """Extract formula from the structure object. :param akinds: list of kinds, e.g. [{'mass': 55.845, 'name': 'Fe', 'symbols': ['Fe'], 'weights': [1.0]}, {'mass': 15.9994, 'name': 'O', 'symbols': ['O'], 'weights': [1.0]}] diff --git a/aiida/orm/nodes/data/array/kpoints.py b/aiida/orm/nodes/data/array/kpoints.py index 1f7432d8f2..168bf0b917 100644 --- a/aiida/orm/nodes/data/array/kpoints.py +++ b/aiida/orm/nodes/data/array/kpoints.py @@ -7,8 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -""" -Module of the KpointsData class, defining the AiiDA data type for storing +"""Module of the KpointsData class, defining the AiiDA data type for storing lists and meshes of k-points (i.e., points in the reciprocal space of a periodic crystal structure). """ @@ -23,8 +22,7 @@ class KpointsData(ArrayData): - """ - Class to handle array of kpoints in the Brillouin zone. Provide methods to + """Class to handle array of kpoints in the Brillouin zone. Provide methods to generate either user-defined k-points or path of k-points along symmetry lines. Internally, all k-points are defined in terms of crystal (fractional) @@ -37,8 +35,7 @@ class KpointsData(ArrayData): """ def get_description(self): - """ - Returns a string with infos retrieved from kpoints node's properties. + """Returns a string with infos retrieved from kpoints node's properties. :param node: :return: retstr """ @@ -55,23 +52,20 @@ def get_description(self): @property def cell(self): - """ - The crystal unit cell. Rows are the crystal vectors in Angstroms. + """The crystal unit cell. Rows are the crystal vectors in Angstroms. :return: a 3x3 numpy.array """ return numpy.array(self.base.attributes.get('cell')) @cell.setter def cell(self, value): - """ - Set the crystal unit cell + """Set the crystal unit cell :param value: a 3x3 list/tuple/array of numbers (units = Angstroms). """ self._set_cell(value) def _set_cell(self, value): - """ - Validate if 'value' is a allowed crystal unit cell + """Validate if 'value' is a allowed crystal unit cell :param value: something compatible with a 3x3 tuple of floats """ from aiida.common.exceptions import ModificationNotAllowed @@ -86,8 +80,7 @@ def _set_cell(self, value): @property def pbc(self): - """ - The periodic boundary conditions along the vectors a1,a2,a3. + """The periodic boundary conditions along the vectors a1,a2,a3. :return: a tuple of three booleans, each one tells if there are periodic boundary conditions for the i-th real-space direction (i=1,2,3) @@ -97,16 +90,13 @@ def pbc(self): @pbc.setter def pbc(self, value): - """ - Set the value of pbc, i.e. a tuple of three booleans, indicating if the + """Set the value of pbc, i.e. 
a tuple of three booleans, indicating if the
         cell is periodic in the 1,2,3 crystal direction
         """
         self._set_pbc(value)
 
     def _set_pbc(self, value):
-        """
-        validate the pbc, then store them
-        """
+        """Validate the pbc, then store them"""
         from aiida.common.exceptions import ModificationNotAllowed
         from aiida.orm.nodes.data.structure import get_valid_pbc
 
@@ -119,8 +109,7 @@
 
     @property
     def labels(self):
-        """
-        Labels associated with the list of kpoints.
+        """Labels associated with the list of kpoints.
         List of tuples with kpoint index and kpoint name: ``[(0,'G'),(13,'M'),...]``
         """
         label_numbers = self.base.attributes.get('label_numbers', None)
@@ -134,9 +123,7 @@
         self._set_labels(value)
 
     def _set_labels(self, value):
-        """
-        set label names. Must pass in input a list like: ``[[0,'X'],[34,'L'],... ]``
-        """
+        """Set label names. Must pass in input a list like: ``[[0,'X'],[34,'L'],... ]``"""
         # check if kpoints were set
         try:
             self.get_kpoints()
@@ -159,8 +146,7 @@
         self.base.attributes.set('labels', labels)
 
     def _change_reference(self, kpoints, to_cartesian=True):
-        """
-        Change reference system, from cartesian to crystal coordinates (units of b1,b2,b3) or viceversa.
+        """Change reference system, from cartesian to crystal coordinates (units of b1,b2,b3) or vice versa.
         :param kpoints: a list of (3) point coordinates
         :return kpoints: a list of (3) point coordinates in the new reference
         """
@@ -184,8 +170,7 @@
         return numpy.transpose(numpy.dot(matrix, numpy.transpose(kpoints)))
 
     def set_cell_from_structure(self, structuredata):
-        """
-        Set a cell to be used for symmetry analysis from an AiiDA structure.
+        """Set a cell to be used for symmetry analysis from an AiiDA structure.
         Inherits both the cell and the pbc's.
         To set manually a cell, use "set_cell"
 
@@ -195,15 +180,15 @@
 
         if not isinstance(structuredata, StructureData):
             raise ValueError(
-                'An instance of StructureData should be passed to '
-                'the KpointsData, found instead {}'.format(structuredata.__class__)
+                'An instance of StructureData should be passed to ' 'the KpointsData, found instead {}'.format(
+                    structuredata.__class__
+                )
             )
         cell = structuredata.cell
         self.set_cell(cell, structuredata.pbc)
 
     def set_cell(self, cell, pbc=None):
-        """
-        Set a cell to be used for symmetry analysis.
+        """Set a cell to be used for symmetry analysis.
         To set a cell from an AiiDA structure, use "set_cell_from_structure".
 
         :param cell: 3x3 matrix of cell vectors. Orientation: each row
@@ -218,19 +203,17 @@
 
     @property
     def reciprocal_cell(self):
-        """
-        Compute reciprocal cell from the internally set cell.
+        """Compute reciprocal cell from the internally set cell.
 
         :returns: reciprocal cell in units of 1/Angstrom with cell vectors stored as rows.
             Use e.g. reciprocal_cell[0] to access the first reciprocal cell vector.
         """
         the_cell = numpy.array(self.cell)
-        reciprocal_cell = 2. * numpy.pi * numpy.linalg.inv(the_cell).transpose()
+        reciprocal_cell = 2.0 * numpy.pi * numpy.linalg.inv(the_cell).transpose()
         return reciprocal_cell
 
     def set_kpoints_mesh(self, mesh, offset=None):
-        """
-        Set KpointsData to represent a uniformily spaced mesh of kpoints in the
+        """Set KpointsData to represent a uniformly spaced mesh of kpoints in the
         Brillouin zone.
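As a brief aside (not part of the diff), the mesh methods around this hunk in a minimal sketch, assuming a configured AiiDA profile and a toy cell:

    import numpy as np
    from aiida import load_profile, orm

    load_profile()  # assumes a configured AiiDA profile

    kpoints = orm.KpointsData()
    kpoints.set_cell(np.eye(3) * 5.0)  # toy cubic cell, Angstrom
    kpoints.set_kpoints_mesh([4, 4, 4], offset=[0.5, 0.5, 0.5])
    mesh, offset = kpoints.get_kpoints_mesh()

    # Or derive the mesh from a maximum spacing (1/Angstrom) between points:
    dense = orm.KpointsData()
    dense.set_cell(np.eye(3) * 5.0)
    dense.set_kpoints_mesh_from_density(0.3, force_parity=True)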
This excludes the possibility of set/get kpoints :param mesh: a list of three integers, representing the size of the @@ -251,7 +234,7 @@ def set_kpoints_mesh(self, mesh, offset=None): except (IndexError, ValueError, TypeError): raise ValueError('The kpoint mesh must be a list of three integers') if offset is None: - offset = [0., 0., 0.] + offset = [0.0, 0.0, 0.0] try: the_offset = [float(i) for i in offset] if len(the_offset) != 3: @@ -271,8 +254,7 @@ def set_kpoints_mesh(self, mesh, offset=None): self.base.attributes.set('offset', the_offset) def get_kpoints_mesh(self, print_list=False): - """ - Get the mesh of kpoints. + """Get the mesh of kpoints. :param print_list: default=False. If True, prints the mesh of kpoints as a list @@ -288,7 +270,7 @@ def get_kpoints_mesh(self, print_list=False): if not print_list: return mesh, offset - kpoints = numpy.mgrid[0:mesh[0], 0:mesh[1], 0:mesh[2]] + kpoints = numpy.mgrid[0 : mesh[0], 0 : mesh[1], 0 : mesh[2]] kpoints = kpoints.reshape(3, -1).T offset_kpoints = kpoints + numpy.array(offset) offset_kpoints[:, 0] /= mesh[0] @@ -297,8 +279,7 @@ def get_kpoints_mesh(self, print_list=False): return offset_kpoints def set_kpoints_mesh_from_density(self, distance, offset=None, force_parity=False): - """ - Set a kpoints mesh using a kpoints density, expressed as the maximum + """Set a kpoints mesh using a kpoints density, expressed as the maximum distance between adjacent points along a reciprocal axis :param distance: distance (in 1/Angstrom) between adjacent @@ -316,7 +297,7 @@ def set_kpoints_mesh_from_density(self, distance, offset=None, force_parity=Fals :note: the number of kpoints along non-periodic axes is always 1. """ if offset is None: - offset = [0., 0., 0.] + offset = [0.0, 0.0, 0.0] try: rec_cell = self.reciprocal_cell @@ -335,8 +316,7 @@ def set_kpoints_mesh_from_density(self, distance, offset=None, force_parity=Fals @property def _dimension(self): - """ - Dimensionality of the structure, found from its pbc (i.e. 1 if it's a 1D + """Dimensionality of the structure, found from its pbc (i.e. 1 if it's a 1D structure, 2 if its 2D, 3 if it's 3D ...). :return dimensionality: 0, 1, 2 or 3 :note: will return 3 if pbc has not been set beforehand @@ -347,8 +327,7 @@ def _dimension(self): return 3 def _validate_kpoints_weights(self, kpoints, weights): - """ - Validate the list of kpoints and of weights before storage. + """Validate the list of kpoints and of weights before storage. Kpoints and weights must be convertible respectively to an array of N x dimension and N floats """ @@ -356,10 +335,10 @@ def _validate_kpoints_weights(self, kpoints, weights): # I cannot just use `if not kpoints` because it's a numpy array and # `not` of a numpy array does not work - if len(kpoints) == 0: # pylint: disable=len-as-condition + if len(kpoints) == 0: if self._dimension == 0: # replace empty list by Gamma point - kpoints = numpy.array([[0., 0., 0.]]) + kpoints = numpy.array([[0.0, 0.0, 0.0]]) else: raise ValueError( 'empty kpoints list is valid only in zero dimension' @@ -394,26 +373,27 @@ def _validate_kpoints_weights(self, kpoints, weights): return kpoints, weights def set_kpoints(self, kpoints, cartesian=False, labels=None, weights=None, fill_values=0): - """ - Set the list of kpoints. If a mesh has already been stored, raise a - ModificationNotAllowed + """Set the list of kpoints. 
If a mesh has already been stored, raise a ModificationNotAllowed :param kpoints: a list of kpoints, each kpoint being a list of one, two or three coordinates, depending on self.pbc: if structure is 1D (only one True in self.pbc) one allows singletons or scalars for each k-point, if it's 2D it can be a length-2 list, and in all cases it can be a length-3 list. - Examples: - * [[0.,0.,0.],[0.1,0.1,0.1],...] for 1D, 2D or 3D - * [[0.,0.],[0.1,0.1,],...] for 1D or 2D - * [[0.],[0.1],...] for 1D - * [0., 0.1, ...] for 1D (list of scalars) + Examples + -------- + + * [[0.,0.,0.],[0.1,0.1,0.1],...] for 1D, 2D or 3D + * [[0.,0.],[0.1,0.1,],...] for 1D or 2D + * [[0.],[0.1],...] for 1D + * [0., 0.1, ...] for 1D (list of scalars) For 0D (all pbc are False), the list can be any of the above or empty - then only Gamma point is set. The value of k for the non-periodic dimension(s) is set by fill_values + :param cartesian: if True, the coordinates given in input are treated as in cartesian units. If False, the coordinates are crystal, i.e. in units of b1,b2,b3. Default = False @@ -478,8 +458,7 @@ def set_kpoints(self, kpoints, cartesian=False, labels=None, weights=None, fill_ self.labels = labels def get_kpoints(self, also_weights=False, cartesian=False): - """ - Return the list of kpoints + """Return the list of kpoints :param also_weights: if True, returns also the list of weights. Default = False diff --git a/aiida/orm/nodes/data/array/projection.py b/aiida/orm/nodes/data/array/projection.py index 7cbcdb0c04..5bb23024e9 100644 --- a/aiida/orm/nodes/data/array/projection.py +++ b/aiida/orm/nodes/data/array/projection.py @@ -23,8 +23,7 @@ class ProjectionData(OrbitalData, ArrayData): - """ - A class to handle arrays of projected wavefunction data. That is projections + """A class to handle arrays of projected wavefunction data. That is projections of a orbitals, usually an atomic-hydrogen orbital, onto a given bloch wavefunction, the bloch wavefunction being indexed by s, n, and k. E.g. 
the elements are the projections described as
     < orbital | Bloch wavefunction (s,n,k) >
     """
 
     def _check_projections_bands(self, projection_array):
-        """
-        Checks to make sure that a reference bandsdata is already set, and that
+        """Checks to make sure that a reference bandsdata is already set, and that
         projection_array is of the same shape of the bands data
 
         :param projwfc_arrays: nk x nb x nwfc array, to be
@@ -53,8 +51,7 @@
             raise AttributeError('These arrays are not the same shape as the bands')
 
     def set_reference_bandsdata(self, value):
-        """
-        Sets a reference bandsdata, creates a uuid link between this data
+        """Sets a reference bandsdata, creates a uuid link between this data
         object and a bandsdata object, must be set before any projection arrays
 
         :param value: a BandsData instance, a uuid or a pk
@@ -74,7 +71,7 @@
             try:
                 bands = load_node(uuid=uuid)
                 uuid = bands.uuid
-            except Exception:  # pylint: disable=bare-except
+            except Exception:
                 raise exceptions.NotExistent(
                     'The value passed to set_reference_bandsdata was not associated to any bandsdata'
                 )
@@ -82,8 +79,7 @@
         self.base.attributes.set('reference_bandsdata_uuid', uuid)
 
     def get_reference_bandsdata(self):
-        """
-        Returns the reference BandsData, using the set uuid via
+        """Returns the reference BandsData, using the set uuid via
         set_reference_bandsdata
 
         :return: a BandsData instance
@@ -91,6 +87,7 @@
         :raise exceptions.NotExistent: if the bandsdata uuid did not retrieve bandsdata
         """
         from aiida.orm import load_node
+
         try:
             uuid = self.base.attributes.get('reference_bandsdata_uuid')
         except AttributeError:
@@ -102,8 +99,7 @@
         return bands
 
     def _find_orbitals_and_indices(self, **kwargs):
-        """
-        Finds all the orbitals and their indicies associated with kwargs
+        """Finds all the orbitals and their indices associated with kwargs
         essential for retrieving the other indexed array parameters
 
         :param kwargs: kwargs that can call orbitals as in get_orbitals()
@@ -119,8 +115,7 @@
         return retrieve_indices, all_orbitals
 
     def get_pdos(self, **kwargs):
-        """
-        Retrieves all the pdos arrays corresponding to the input kwargs
+        """Retrieves all the pdos arrays corresponding to the input kwargs
 
         :param kwargs: inputs describing the orbitals associated with the pdos
                        arrays
@@ -129,15 +124,18 @@
         """
         retrieve_indices, all_orbitals = self._find_orbitals_and_indices(**kwargs)
 
-        out_list = [(
-            all_orbitals[i], self.get_array(f'pdos_{self._from_index_to_arrayname(i)}'),
-            self.get_array(f'energy_{self._from_index_to_arrayname(i)}')
-        ) for i in retrieve_indices]
+        out_list = [
+            (
+                all_orbitals[i],
+                self.get_array(f'pdos_{self._from_index_to_arrayname(i)}'),
+                self.get_array(f'energy_{self._from_index_to_arrayname(i)}'),
+            )
+            for i in retrieve_indices
+        ]
         return out_list
 
     def get_projections(self, **kwargs):
-        """
-        Retrieves all the pdos arrays corresponding to the input kwargs
+        """Retrieves all the pdos arrays corresponding to the input kwargs
 
         :param kwargs: inputs describing the orbitals associated with the pdos
                        arrays
@@ -153,9 +151,7 @@
 
     @staticmethod
     def _from_index_to_arrayname(index):
-        """
-        Used internally to determine the array names.
- """ + """Used internally to determine the array names.""" return f'array_{index}' def set_projectiondata( @@ -165,10 +161,9 @@ def set_projectiondata( list_of_energy=None, list_of_pdos=None, tags=None, - bands_check=True + bands_check=True, ): - """ - Stores the projwfc_array using the projwfc_label, after validating both. + """Stores the projwfc_array using the projwfc_label, after validating both. :param list_of_orbitals: list of orbitals, of class orbital data. They should be the ones up on which the @@ -196,11 +191,8 @@ def set_projectiondata( cannot be called """ - # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements - def single_to_list(item): - """ - Checks if the item is a list or tuple, and converts it to a list + """Checks if the item is a list or tuple, and converts it to a list if it is not already a list or tuple :param item: an object which may or may not be a list or tuple @@ -213,8 +205,7 @@ def single_to_list(item): return [item] def array_list_checker(array_list, array_name, orb_length): - """ - Does basic checks over everything in the array_list. Makes sure that + """Does basic checks over everything in the array_list. Makes sure that all the arrays are np.ndarray floats, that the length is same as required_length, raises exception using array_name if there is a failure @@ -287,13 +278,10 @@ def array_list_checker(array_list, array_name, orb_length): raise exceptions.ValidationError('Tags must set a list of strings') self.base.attributes.set('tags', tags) - def set_orbitals(self, **kwargs): # pylint: disable=arguments-differ - """ - This method is inherited from OrbitalData, but is blocked here. + def set_orbitals(self, **kwargs): + """This method is inherited from OrbitalData, but is blocked here. If used will raise a NotImplementedError """ raise NotImplementedError( - 'You cannot set orbitals using this class!' - ' This class is for setting orbitals and ' - ' projections only!' + 'You cannot set orbitals using this class!' ' This class is for setting orbitals and ' ' projections only!' ) diff --git a/aiida/orm/nodes/data/array/trajectory.py b/aiida/orm/nodes/data/array/trajectory.py index 3bc4726609..9b820979ca 100644 --- a/aiida/orm/nodes/data/array/trajectory.py +++ b/aiida/orm/nodes/data/array/trajectory.py @@ -7,8 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -""" -AiiDA class to deal with crystal structure trajectories. +"""AiiDA class to deal with crystal structure trajectories. """ import collections.abc @@ -18,8 +17,7 @@ class TrajectoryData(ArrayData): - """ - Stores a trajectory (a sequence of crystal structures with timestamps, and + """Stores a trajectory (a sequence of crystal structures with timestamps, and possibly with velocities). """ @@ -28,9 +26,8 @@ def __init__(self, structurelist=None, **kwargs): if structurelist is not None: self.set_structurelist(structurelist) - def _internal_validate(self, stepids, cells, symbols, positions, times, velocities): # pylint: disable=too-many-arguments,too-many-locals,too-many-branches - """ - Internal function to validate the type and shape of the arrays. See + def _internal_validate(self, stepids, cells, symbols, positions, times, velocities): + """Internal function to validate the type and shape of the arrays. 
See the documentation of py:meth:`.set_trajectory` for a description of the valid shape and type of the parameters. """ @@ -66,8 +63,7 @@ def _internal_validate(self, stepids, cells, symbols, positions, times, velociti numatoms = len(symbols) if positions.shape != (numsteps, numatoms, 3): raise ValueError( - 'TrajectoryData.positions must have shape (s,n,3), ' - 'with s=number of steps and n=number of symbols' + 'TrajectoryData.positions must have shape (s,n,3), ' 'with s=number of steps and n=number of symbols' ) if times is not None: if times.shape != (numsteps,): @@ -80,9 +76,8 @@ def _internal_validate(self, stepids, cells, symbols, positions, times, velociti 'with s=number of steps and n=number of symbols' ) - def set_trajectory(self, symbols, positions, stepids=None, cells=None, times=None, velocities=None): # pylint: disable=too-many-arguments - r""" - Store the whole trajectory, after checking that types and dimensions + def set_trajectory(self, symbols, positions, stepids=None, cells=None, times=None, velocities=None): + r"""Store the whole trajectory, after checking that types and dimensions are correct. Parameters ``stepids``, ``cells`` and ``velocities`` are optional @@ -132,7 +127,6 @@ def set_trajectory(self, symbols, positions, stepids=None, cells=None, times=Non .. todo :: Choose suitable units for velocities """ - import numpy self._internal_validate(stepids, cells, symbols, positions, times, velocities) @@ -169,8 +163,7 @@ def set_trajectory(self, symbols, positions, stepids=None, cells=None, times=Non pass def set_structurelist(self, structurelist): - """ - Create trajectory from the list of + """Create trajectory from the list of :py:class:`aiida.orm.nodes.data.structure.StructureData` instances. :param structurelist: a list of @@ -192,8 +185,7 @@ def set_structurelist(self, structurelist): self.set_trajectory(stepids=stepids, cells=cells, symbols=symbols, positions=positions) def _validate(self): - """ - Verify that the required arrays are present and that their type and + """Verify that the required arrays are present and that their type and dimension are correct. """ # check dimensions, types @@ -201,8 +193,12 @@ def _validate(self): try: self._internal_validate( - self.get_stepids(), self.get_cells(), self.symbols, self.get_positions(), self.get_times(), - self.get_velocities() + self.get_stepids(), + self.get_cells(), + self.symbols, + self.get_positions(), + self.get_times(), + self.get_velocities(), ) # Should catch TypeErrors, ValueErrors, and KeyErrors for missing arrays except Exception as exception: @@ -212,9 +208,7 @@ def _validate(self): @property def numsteps(self): - """ - Return the number of stored steps, or zero if nothing has been stored yet. - """ + """Return the number of stored steps, or zero if nothing has been stored yet.""" try: return self.get_shape('steps')[0] except (AttributeError, KeyError, IndexError): @@ -222,17 +216,14 @@ def numsteps(self): @property def numsites(self): - """ - Return the number of stored sites, or zero if nothing has been stored yet. - """ + """Return the number of stored sites, or zero if nothing has been stored yet.""" try: return len(self.symbols) except (AttributeError, KeyError, IndexError): return 0 def get_stepids(self): - """ - Return the array of steps, if it has already been set. + """Return the array of steps, if it has already been set. .. 
versionadded:: 0.7 Renamed from get_steps @@ -242,8 +233,7 @@ def get_stepids(self): return self.get_array('steps') def get_times(self): - """ - Return the array of times (in ps), if it has already been set. + """Return the array of times (in ps), if it has already been set. :raises KeyError: if the trajectory has not been set yet. """ @@ -253,8 +243,7 @@ def get_times(self): return None def get_cells(self): - """ - Return the array of cells, if it has already been set. + """Return the array of cells, if it has already been set. :raises KeyError: if the trajectory has not been set yet. """ @@ -265,24 +254,21 @@ def get_cells(self): @property def symbols(self): - """ - Return the array of symbols, if it has already been set. + """Return the array of symbols, if it has already been set. :raises KeyError: if the trajectory has not been set yet. """ return self.base.attributes.get('symbols') def get_positions(self): - """ - Return the array of positions, if it has already been set. + """Return the array of positions, if it has already been set. :raises KeyError: if the trajectory has not been set yet. """ return self.get_array('positions') def get_velocities(self): - """ - Return the array of velocities, if it has already been set. + """Return the array of velocities, if it has already been set. .. note :: This function (differently from all other ``get_*`` functions, will not raise an exception if the velocities are not @@ -295,8 +281,7 @@ def get_velocities(self): return None def get_index_from_stepid(self, stepid): - """ - Given a value for the stepid (i.e., a value among those of the ``steps`` + """Given a value for the stepid (i.e., a value among those of the ``steps`` array), return the array index of that stepid, that can be used in other methods such as :py:meth:`.get_step_data` or :py:meth:`.get_step_structure`. @@ -318,8 +303,7 @@ def get_index_from_stepid(self, stepid): raise ValueError(f'{stepid} not among the stepids') def get_step_data(self, index): - """ - Return a tuple with all information concerning the stepid with given + """Return a tuple with all information concerning the stepid with given index (0 is the first step, 1 the second step and so on). If you know only the step value, use the :py:meth:`.get_index_from_stepid` method to get the corresponding index. @@ -356,8 +340,7 @@ def get_step_data(self, index): return (self.get_stepids()[index], time, cell, self.symbols, self.get_positions()[index, :, :], vel) def get_step_structure(self, index, custom_kinds=None): - """ - Return an AiiDA :py:class:`aiida.orm.nodes.data.structure.StructureData` node + """Return an AiiDA :py:class:`aiida.orm.nodes.data.structure.StructureData` node (not stored yet!) with the coordinates of the given step, identified by its index. If you know only the step value, use the :py:meth:`.get_index_from_stepid` method to get the corresponding index. 
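A short sketch of the step-access methods documented above (an aside, not part of the diff; it assumes a configured AiiDA profile, and the two-step hydrogen trajectory is made up):

    import numpy as np
    from aiida import load_profile, orm

    load_profile()  # assumes a configured AiiDA profile

    traj = orm.TrajectoryData()
    traj.set_trajectory(
        symbols=['H', 'H'],
        positions=np.array([
            [[0.0, 0.0, 0.0], [0.0, 0.0, 0.74]],
            [[0.0, 0.0, 0.0], [0.0, 0.0, 0.76]],
        ]),  # shape (steps, atoms, 3)
        cells=np.array([np.eye(3) * 10.0] * 2),  # shape (steps, 3, 3)
        stepids=np.array([10, 20]),
        times=np.array([0.0, 0.5]),  # ps
    )

    index = traj.get_index_from_stepid(20)      # -> 1
    structure = traj.get_step_structure(index)  # unstored StructureData for that step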
@@ -390,8 +373,7 @@ def get_step_structure(self, index, custom_kinds=None): for k in custom_kinds: if not isinstance(k, Kind): raise TypeError( - 'Each element of the custom_kinds list must ' - 'be a aiida.orm.nodes.data.structure.Kind object' + 'Each element of the custom_kinds list must ' 'be a aiida.orm.nodes.data.structure.Kind object' ) kind_names.append(k.name) if len(kind_names) != len(set(kind_names)): @@ -417,11 +399,10 @@ def get_step_structure(self, index, custom_kinds=None): return struc - def _prepare_xsf(self, index=None, main_file_name=''): # pylint: disable=unused-argument - """ - Write the given trajectory to a string of format XSF (for XCrySDen). - """ + def _prepare_xsf(self, index=None, main_file_name=''): + """Write the given trajectory to a string of format XSF (for XCrySDen).""" from aiida.common.constants import elements + _atomic_numbers = {data['symbol']: num for num, data in elements.items()} indices = list(range(self.numsteps)) @@ -455,10 +436,8 @@ def _prepare_xsf(self, index=None, main_file_name=''): # pylint: disable=unused raise return return_string.encode('utf-8'), {} - def _prepare_cif(self, trajectory_index=None, main_file_name=''): # pylint: disable=unused-argument - """ - Write the given trajectory to a string of format CIF. - """ + def _prepare_cif(self, trajectory_index=None, main_file_name=''): + """Write the given trajectory to a string of format CIF.""" from aiida.common.utils import Capturing from aiida.orm.nodes.data.cif import ase_loops, cif_from_ase, pycifrw_from_cif @@ -474,8 +453,7 @@ def _prepare_cif(self, trajectory_index=None, main_file_name=''): # pylint: dis return cif.encode('utf-8'), {} def get_structure(self, store=False, **kwargs): - """ - Creates :py:class:`aiida.orm.nodes.data.structure.StructureData`. + """Creates :py:class:`aiida.orm.nodes.data.structure.StructureData`. .. versionadded:: 1.0 Renamed from _get_aiida_structure @@ -504,12 +482,11 @@ def get_structure(self, store=False, **kwargs): param = Dict(kwargs) - ret_dict = _get_aiida_structure_inline(trajectory=self, parameters=param, metadata={'store_provenance': store}) # pylint: disable=unexpected-keyword-arg + ret_dict = _get_aiida_structure_inline(trajectory=self, parameters=param, metadata={'store_provenance': store}) return ret_dict['structure'] def get_cif(self, index=None, **kwargs): - """ - Creates :py:class:`aiida.orm.nodes.data.cif.CifData` + """Creates :py:class:`aiida.orm.nodes.data.cif.CifData` .. versionadded:: 1.0 Renamed from _get_cif @@ -519,8 +496,7 @@ def get_cif(self, index=None, **kwargs): return cif def _parse_xyz_pos(self, inputstring): - """ - Load positions from a XYZ file. + """Load positions from a XYZ file. .. note:: The steps and symbols must be set manually before calling this import function as a consistency measure. Even though the symbols @@ -540,7 +516,6 @@ def _parse_xyz_pos(self, inputstring): t.set_array('symbols', array([site.kind for site in s.sites])) t.importfile('some-calc/AIIDA-PROJECT-pos-1.xyz', 'xyz_pos') """ - from numpy import array from aiida.common.exceptions import ValidationError @@ -568,14 +543,12 @@ def _parse_xyz_pos(self, inputstring): self.set_array('positions', positions) def _parse_xyz_vel(self, inputstring): - """ - Load velocities from a XYZ file. + """Load velocities from a XYZ file. .. note:: The steps and symbols must be set manually before calling this import function as a consistency measure. 
See also comment for :py:meth:`._parse_xyz_pos` """ - from numpy import array from aiida.common.exceptions import ValidationError @@ -602,9 +575,8 @@ def _parse_xyz_vel(self, inputstring): self.set_array('velocities', velocities) - def show_mpl_pos(self, **kwargs): # pylint: disable=too-many-locals - """ - Shows the positions as a function of time, separate for XYZ coordinates + def show_mpl_pos(self, **kwargs): + """Shows the positions as a function of time, separately for the X, Y and Z coordinates :param int stepsize: The stepsize for the trajectory, set higher than 1 to reduce number of points @@ -691,12 +663,11 @@ def show_mpl_pos(self, **kwargs): # pylint: disable=too-many-locals maxtime, ) - def show_mpl_heatmap(self, **kwargs): # pylint: disable=invalid-name,too-many-arguments,too-many-locals,too-many-statements,too-many-branches - """ - Show a heatmap of the trajectory with matplotlib. - """ + def show_mpl_heatmap(self, **kwargs): + """Show a heatmap of the trajectory with matplotlib.""" import numpy as np from scipy import stats + try: from mayavi import mlab except ImportError: @@ -707,19 +678,17 @@ def show_mpl_heatmap(self, **kwargs): # pylint: disable=invalid-name,too-many-a from ase.data import atomic_numbers from ase.data.colors import jmol_colors - # pylint: disable=invalid-name - def collapse_into_unit_cell(point, cell): - """ - Applies linear transformation to coordinate system based on crystal + """Applies a linear transformation to the coordinate system based on the crystal lattice vectors. The product of the inverse transformation matrix with the given point expresses the point as multiples of the lattice vectors. Then take the integer part of the rows to find how many times you have to shift - the point back""" - invcell = np.matrix(cell).T.I # pylint: disable=no-member + the point back """ + invcell = np.matrix(cell).T.I # point in crystal coordinates points_in_crystal = np.dot(invcell, point).tolist()[0] - #point collapsed into unit cell + # point collapsed into unit cell points_in_unit_cell = [i % 1 for i in points_in_crystal] return np.dot(cell.T, points_in_unit_cell).tolist() @@ -792,16 +761,16 @@ def collapse_into_unit_cell(point, cell): xmin, ymin, zmin = _x.min(), _y.min(), _z.min() xmax, ymax, zmax = _x.max(), _y.max(), _z.max() - _xi, _yi, _zi = np.mgrid[xmin:xmax:60j, ymin:ymax:30j, zmin:zmax:30j] # pylint: disable=invalid-slice-index + _xi, _yi, _zi = np.mgrid[xmin:xmax:60j, ymin:ymax:30j, zmin:zmax:30j] coords = np.vstack([item.ravel() for item in [_xi, _yi, _zi]]) density = kde(coords).reshape(_xi.shape) # Plot scatter with mayavi - #~ figure = mlab.figure('DensityPlot') + # ~ figure = mlab.figure('DensityPlot') grid = mlab.pipeline.scalar_field(_xi, _yi, _zi, density) - #~ min = density.min() + # ~ min = density.min() maxdens = density.max() - #~ mlab.pipeline.volume(grid, vmin=min, vmax=min + .5*(max-min)) + # ~ mlab.pipeline.volume(grid, vmin=min, vmax=min + .5*(max-min)) surf = mlab.pipeline.iso_surface(grid, opacity=0.5, colormap='cool', contours=(maxdens * contours).tolist()) lut = surf.module_manager.scalar_lut_manager.lut.table.to_array() @@ -823,27 +792,27 @@ def collapse_into_unit_cell(point, cell): color=tuple(jmol_colors[atomic_numbers[ele]].tolist()), scale_mode='none', scale_factor=0.3, - opacity=0.3 + opacity=0.3, ) mlab.view(azimuth=155, elevation=70, distance='auto') mlab.show() -def plot_positions_XYZ( # pylint: disable=too-many-arguments,too-many-locals,invalid-name - times, - positions, - indices_to_show, - color_list, - label, -
positions_unit='A', - times_unit='ps', - dont_block=False, - mintime=None, - maxtime=None, - n_labels=10): - """ - Plot with matplotlib the positions of the coordinates of the atoms +def plot_positions_XYZ( # noqa: N802 + times, + positions, + indices_to_show, + color_list, + label, + positions_unit='A', + times_unit='ps', + dont_block=False, + mintime=None, + maxtime=None, + n_labels=10, +): + """Plot with matplotlib the positions of the atoms over time for a trajectory :param times: array of times @@ -858,9 +827,9 @@ def plot_positions_XYZ( # pylint: disable=too-many-arguments,too-many-locals,in :param maxtime: if specified, cut the time axis at the specified max value :param n_labels: how many labels (t, coord) to put """ + import numpy as np from matplotlib import pyplot as plt from matplotlib.gridspec import GridSpec - import numpy as np tlim = [times[0], times[-1]] index_range = [0, len(times) - 1] diff --git a/aiida/orm/nodes/data/array/xy.py b/aiida/orm/nodes/data/array/xy.py index 05907b3dd0..923c22c544 100644 --- a/aiida/orm/nodes/data/array/xy.py +++ b/aiida/orm/nodes/data/array/xy.py @@ -7,8 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -""" -This module defines the classes related to Xy data. That is data that contains +"""This module defines the classes related to Xy data. That is, data that contains collections of y-arrays bound to a single x-array, and the methods to operate on them. """ @@ -29,8 +28,7 @@ def check_convert_single_to_tuple(item: Any | Sequence[Any]) -> Sequence[Any]: - """ - Checks if the item is a list or tuple, and converts it to a list if it is + """Checks if the item is a list or tuple, and converts it to a list if it is not already a list or tuple :param item: an object which may or may not be a list or tuple @@ -44,8 +42,7 @@ def check_convert_single_to_tuple(item: Any | Sequence[Any]) -> Sequence[Any]: class XyData(ArrayData): - """ - A subclass designed to handle arrays that have an "XY" relationship to + """A subclass designed to handle arrays that have an "XY" relationship to each other. That is, there is one array, the X array, and there are several Y arrays, which can be considered functions of X. """ @@ -59,7 +56,7 @@ def __init__( x_units: str | None = None, y_names: str | list[str] | None = None, y_units: str | list[str] | None = None, - **kwargs + **kwargs, ): """Construct a new instance, optionally setting the x and y arrays. @@ -80,8 +77,7 @@ def __init__( @staticmethod def _arrayandname_validator(array: 'ndarray', name: str, units: str) -> None: - """ - Validates that the array is an numpy.ndarray and that the name is + """Validates that the array is a numpy.ndarray and that the name is of type str. Raises TypeError or ValueError if this is not the case. """ if not isinstance(name, str): @@ -97,8 +93,7 @@ def _arrayandname_validator(array: 'ndarray', name: str, units: str) -> None: raise TypeError('The units must always be a str.') def set_x(self, x_array: 'ndarray', x_name: str, x_units: str) -> None: - """ - Sets the array and the name for the x values.
:param x_array: A numpy.ndarray, containing only floats :param x_name: a string for the x array name @@ -112,8 +107,7 @@ def set_x(self, x_array: 'ndarray', x_name: str, x_units: str) -> None: def set_y( self, y_arrays: 'ndarray' | Sequence['ndarray'], y_names: str | Sequence[str], y_units: str | Sequence[str] ) -> None: - """ - Set array(s) for the y part of the dataset. Also checks if the + """Set array(s) for the y part of the dataset. Also checks if the x_array has already been set, and that, the shape of the y_arrays agree with the x_array. :param y_arrays: A list of y_arrays, numpy.ndarray @@ -148,8 +142,7 @@ def set_y( self.base.attributes.set('y_units', y_units) def get_x(self) -> tuple[str, 'ndarray', str]: - """ - Tries to retrieve the x array and x name raises a NotExistent + """Tries to retrieve the x array and x name raises a NotExistent exception if no x array has been set yet. :return x_name: the name set for the x_array :return x_array: the x array set earlier @@ -164,8 +157,7 @@ def get_x(self) -> tuple[str, 'ndarray', str]: return x_name, x_array, x_units def get_y(self) -> list[tuple[str, 'ndarray', str]]: - """ - Tries to retrieve the y arrays and the y names, raises a + """Tries to retrieve the y arrays and the y names, raises a NotExistent exception if they have not been set yet, or cannot be retrieved :return y_names: list of strings naming the y_arrays diff --git a/aiida/orm/nodes/data/base.py b/aiida/orm/nodes/data/base.py index f95cacaa2e..34677839fb 100644 --- a/aiida/orm/nodes/data/base.py +++ b/aiida/orm/nodes/data/base.py @@ -32,7 +32,7 @@ def __init__(self, value=None, **kwargs): super().__init__(**kwargs) - self.value = value or self._type() # pylint: disable=no-member + self.value = value or self._type() @property def value(self): @@ -40,7 +40,7 @@ def value(self): @value.setter def value(self, value): - self.base.attributes.set('value', self._type(value)) # pylint: disable=no-member + self.base.attributes.set('value', self._type(value)) def __str__(self): return f'{super().__str__()} value: {self.value}' diff --git a/aiida/orm/nodes/data/cif.py b/aiida/orm/nodes/data/cif.py index 460bc2d9cb..05aad657d9 100644 --- a/aiida/orm/nodes/data/cif.py +++ b/aiida/orm/nodes/data/cif.py @@ -7,7 +7,6 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=invalid-name,too-many-locals,too-many-statements """Tools for handling Crystallographic Information Files (CIF)""" import re @@ -37,21 +36,17 @@ def has_pycifrw(): - """ - :return: True if the PyCifRW module can be imported, False otherwise. - """ - # pylint: disable=unused-variable,unused-import + """:return: True if the PyCifRW module can be imported, False otherwise.""" try: - import CifFile - from CifFile import CifBlock + import CifFile # noqa: F401 + from CifFile import CifBlock # noqa: F401 except ImportError: return False return True def cif_from_ase(ase, full_occupancies=False, add_fake_biso=False): - """ - Construct a CIF datablock from the ASE structure. The code is taken + """Construct a CIF datablock from the ASE structure. 
The code is taken from https://wiki.fysik.dtu.dk/ase/ase/io/formatoptions.html#ase.io.cif.write_cif, as the original ASE code contains a bug in printing the @@ -74,9 +69,9 @@ def cif_from_ase(ase, full_occupancies=False, add_fake_biso=False): a = norm(cell[0]) b = norm(cell[1]) c = norm(cell[2]) - alpha = arccos(dot(cell[1], cell[2]) / (b * c)) * 180. / pi - beta = arccos(dot(cell[0], cell[2]) / (a * c)) * 180. / pi - gamma = arccos(dot(cell[0], cell[1]) / (a * b)) * 180. / pi + alpha = arccos(dot(cell[1], cell[2]) / (b * c)) * 180.0 / pi + beta = arccos(dot(cell[0], cell[2]) / (a * c)) * 180.0 / pi + gamma = arccos(dot(cell[0], cell[1]) / (a * b)) * 180.0 / pi datablock['_cell_length_a'] = str(a) datablock['_cell_length_b'] = str(b) @@ -126,10 +121,8 @@ def cif_from_ase(ase, full_occupancies=False, add_fake_biso=False): return datablocks -# pylint: disable=too-many-branches def pycifrw_from_cif(datablocks, loops=None, names=None): - """ - Constructs PyCifRW's CifFile from an array of CIF datablocks. + """Constructs PyCifRW's CifFile from an array of CIF datablocks. :param datablocks: an array of CIF datablocks :param loops: optional dict of lists of CIF tag loops. @@ -140,12 +133,12 @@ def pycifrw_from_cif(datablocks, loops=None, names=None): import CifFile from CifFile import CifBlock except ImportError as exc: - raise ImportError(f'{str(exc)}. You need to install the PyCifRW package.') + raise ImportError(f'{exc!s}. You need to install the PyCifRW package.') if loops is None: loops = {} - cif = CifFile.CifFile() # pylint: disable=no-member + cif = CifFile.CifFile() try: cif.set_grammar('1.1') except AttributeError: @@ -185,7 +178,7 @@ def pycifrw_from_cif(datablocks, loops=None, names=None): if row_size is not None and row_size > 0: datablock.CreateLoop(datanames=tags_seen) for tag in sorted(values.keys()): - if not tag in tags_in_loops: + if tag not in tags_in_loops: datablock.AddItem(tag, values[tag]) # create automatically a loop for non-scalar values if isinstance(values[tag], (tuple, list)) and tag not in loops.keys(): @@ -194,8 +187,7 @@ def pycifrw_from_cif(datablocks, loops=None, names=None): def parse_formula(formula): - """ - Parses the Hill formulae. Does not need spaces as separators. + """Parses the Hill formulae. Does not need spaces as separators. Works also for partial occupancies and for chemical groups enclosed in round/square/curly brackets. Elements are counted and a dictionary is returned. e.g. 'C[NH2]3NO3' --> {'C': 1, 'N': 4, 'H': 6, 'O': 3} @@ -238,18 +230,16 @@ def chemcount_str_to_number(string): return contents -# pylint: disable=abstract-method,too-many-public-methods # Note: Method 'query' is abstract in class 'Node' but is not overridden class CifData(SinglefileData): - """ - Wrapper for Crystallographic Interchange File (CIF) + """Wrapper for Crystallographic Interchange File (CIF) .. note:: the file (physical) is held as the authoritative source of information, so all conversions are done through the physical file: when setting ``ase`` or ``values``, a physical CIF file is generated first, the values are updated from the physical CIF file. """ - # pylint: disable=abstract-method, too-many-public-methods + _SET_INCOMPATIBILITIES = [('ase', 'file'), ('ase', 'values'), ('file', 'values')] _SCAN_TYPES = ('standard', 'flex') _SCAN_TYPE_DEFAULT = 'standard' @@ -270,9 +260,6 @@ def __init__(self, ase=None, file=None, filename=None, values=None, scan_type=No :param scan_type: scan type string for parsing with PyCIFRW ('standard' or 'flex'). 
See CifFile.ReadCif :param parse_policy: 'eager' (parse CIF file on set_file) or 'lazy' (defer parsing until needed) """ - - # pylint: disable=too-many-arguments, redefined-builtin - args = { 'ase': ase, 'file': file, @@ -298,8 +285,7 @@ def __init__(self, ase=None, file=None, filename=None, values=None, scan_type=No @staticmethod def read_cif(fileobj, index=-1, **kwargs): - """ - A wrapper method that simulates the behavior of the old + """A wrapper method that simulates the behavior of the old function ase.io.cif.read_cif by using the new generic ase.io.read function. @@ -327,21 +313,20 @@ def read_cif(fileobj, index=-1, **kwargs): @classmethod def from_md5(cls, md5, backend=None): - """ - Return a list of all CIF files that match a given MD5 hash. + """Return a list of all CIF files that match a given MD5 hash. .. note:: the hash has to be stored in a ``_md5`` attribute, otherwise the CIF file will not be found. """ from aiida.orm.querybuilder import QueryBuilder + builder = QueryBuilder(backend=backend) builder.append(cls, filters={'attributes.md5': {'==': md5}}) return builder.all(flat=True) @classmethod def get_or_create(cls, filename, use_first=False, store_cif=True): - """ - Pass the same parameter of the init; if a file with the same md5 + """Pass the same parameter of the init; if a file with the same md5 is found, that CifData is returned. :param filename: an absolute filename on disk @@ -375,18 +360,16 @@ def get_or_create(cls, filename, use_first=False, store_cif=True): return (cifs[0], False) raise ValueError( - 'More than one copy of a CIF file ' - 'with the same MD5 has been found in ' - 'the DB. pks={}'.format(','.join([str(i.pk) for i in cifs])) + 'More than one copy of a CIF file ' 'with the same MD5 has been found in ' 'the DB. pks={}'.format( + ','.join([str(i.pk) for i in cifs]) + ) ) return cifs[0], False - # pylint: disable=attribute-defined-outside-init @property def ase(self): - """ - ASE object, representing the CIF. + """ASE object, representing the CIF. .. note:: requires ASE module. """ @@ -395,8 +378,7 @@ def ase(self): return self._ase def get_ase(self, **kwargs): - """ - Returns ASE object, representing the CIF. This function differs + """Returns ASE object, representing the CIF. This function differs from the property ``ase`` by the possibility to pass the keyworded arguments (kwargs) to ase.io.cif.read_cif(). @@ -408,12 +390,12 @@ def get_ase(self, **kwargs): return CifData.read_cif(handle, **kwargs) def set_ase(self, aseatoms): - """ - Set the contents of the CifData starting from an ASE atoms object + """Set the contents of the CifData starting from an ASE atoms object :param aseatoms: the ASE atoms object """ import tempfile + cif = cif_from_ase(aseatoms) with tempfile.NamedTemporaryFile(mode='w+') as tmpf: with Capturing(): @@ -427,25 +409,23 @@ def ase(self, aseatoms): @property def values(self): - """ - PyCifRW structure, representing the CIF datablocks. + """PyCifRW structure, representing the CIF datablocks. .. note:: requires PyCifRW module. 
""" if self._values is None: import CifFile - from CifFile import CifBlock # pylint: disable=no-name-in-module + from CifFile import CifBlock with self.open() as handle: - c = CifFile.ReadCif(handle, scantype=self.base.attributes.get('scan_type', CifData._SCAN_TYPE_DEFAULT)) # pylint: disable=no-member + c = CifFile.ReadCif(handle, scantype=self.base.attributes.get('scan_type', CifData._SCAN_TYPE_DEFAULT)) for k, v in c.items(): c.dictionary[k] = CifBlock(v) self._values = c return self._values def set_values(self, values): - """ - Set internal representation to `values`. + """Set internal representation to `values`. Warning: This also writes a new CIF file. @@ -454,6 +434,7 @@ def set_values(self, values): .. note:: requires PyCifRW module. """ import tempfile + with tempfile.NamedTemporaryFile(mode='w+') as tmpf: with Capturing(): tmpf.write(values.WriteOut()) @@ -468,8 +449,7 @@ def values(self, values): self.set_values(values) def parse(self, scan_type=None): - """ - Parses CIF file and sets attributes. + """Parses CIF file and sets attributes. :param scan_type: See set_scan_type """ @@ -480,18 +460,15 @@ def parse(self, scan_type=None): self.base.attributes.set('formulae', self.get_formulae()) self.base.attributes.set('spacegroup_numbers', self.get_spacegroup_numbers()) - def store(self, *args, **kwargs): # pylint: disable=signature-differs - """ - Store the node. - """ + def store(self, *args, **kwargs): + """Store the node.""" if not self.is_stored: self.base.attributes.set('md5', self.generate_md5()) return super().store(*args, **kwargs) def set_file(self, file, filename=None): - """ - Set the file. + """Set the file. If the source is set and the MD5 checksum of new file is different from the source, the source has to be deleted. @@ -500,12 +477,13 @@ def set_file(self, file, filename=None): Hint: Pass io.BytesIO(b"my string") to construct the file directly from a string. :param filename: specify filename to use (defaults to name of provided file). """ - # pylint: disable=redefined-builtin super().set_file(file, filename=filename) md5sum = self.generate_md5() - if isinstance(self.source, dict) and \ - self.source.get('source_md5', None) is not None and \ - self.source['source_md5'] != md5sum: + if ( + isinstance(self.source, dict) + and self.source.get('source_md5', None) is not None + and self.source['source_md5'] != md5sum + ): self.source = {} self.base.attributes.set('md5', md5sum) @@ -515,8 +493,7 @@ def set_file(self, file, filename=None): self.base.attributes.set('spacegroup_numbers', None) def set_scan_type(self, scan_type): - """ - Set the scan_type for PyCifRW. + """Set the scan_type for PyCifRW. The 'flex' scan_type of PyCifRW is faster for large CIF files but does not yet support the CIF2 format as of 02/2018. @@ -530,8 +507,7 @@ def set_scan_type(self, scan_type): raise ValueError(f'Got unknown scan_type {scan_type}') def set_parse_policy(self, parse_policy): - """ - Set the parse policy. + """Set the parse policy. :param parse_policy: Either 'eager' (parse CIF file on set_file) or 'lazy' (defer parsing until needed) @@ -542,8 +518,7 @@ def set_parse_policy(self, parse_policy): raise ValueError(f'Got unknown parse_policy {parse_policy}') def get_formulae(self, mode='sum', custom_tags=None): - """ - Return chemical formulae specified in CIF file. + """Return chemical formulae specified in CIF file. Note: This does not compute the formula, it only reads it from the appropriate tag. Use refine_inline to compute formulae. 
@@ -568,9 +543,7 @@ def get_formulae(self, mode='sum', custom_tags=None): return formulae def get_spacegroup_numbers(self): - """ - Get the spacegroup international number. - """ + """Get the spacegroup international number.""" # note: If spacegroup_numbers are not None, they could be returned # directly (but the function is very cheap anyhow). spg_tags = ['_space_group.it_number', '_space_group_it_number', '_symmetry_int_tables_number'] @@ -589,8 +562,7 @@ def get_spacegroup_numbers(self): @property def has_partial_occupancies(self): - """ - Return if the cif data contains partial occupancies + """Return if the cif data contains partial occupancies A partial occupancy is defined as site with an occupancy that differs from unity, within a precision of 1E-6 @@ -619,8 +591,7 @@ def has_partial_occupancies(self): @property def has_attached_hydrogens(self): - """ - Check if there are hydrogens without coordinates, specified as attached + """Check if there are hydrogens without coordinates, specified as attached to the atoms of the structure. :returns: True if there are attached hydrogens, False otherwise. @@ -636,8 +607,7 @@ def has_attached_hydrogens(self): @property def has_undefined_atomic_sites(self): - """ - Return whether the cif data contains any undefined atomic sites. + """Return whether the cif data contains any undefined atomic sites. An undefined atomic site is defined as a site where at least one of the fractional coordinates specified in the `_atom_site_fract_*` tags, cannot be successfully interpreted as a float. If the cif data contains any site that @@ -658,7 +628,6 @@ def has_undefined_atomic_sites(self): for tag in [tag_x, tag_y, tag_z]: if tag in self.values[datablock].keys(): for position in self.values[datablock][tag]: - # The CifData contains at least one `_atom_site_fract_*` tag has_tags = True @@ -674,8 +643,7 @@ def has_undefined_atomic_sites(self): @property def has_atomic_sites(self): - """ - Returns whether there are any atomic sites defined in the cif data. That + """Returns whether there are any atomic sites defined in the cif data. That is to say, it will check all the values for the `_atom_site_fract_*` tags and if they are all equal to `?` that means there are no relevant atomic sites defined and the function will return False. In all other cases the @@ -697,8 +665,7 @@ def has_atomic_sites(self): @property def has_unknown_species(self): - """ - Returns whether the cif contains atomic species that are not recognized by AiiDA. + """Returns whether the cif contains atomic species that are not recognized by AiiDA. The known species are taken from the elements dictionary in `aiida.common.constants`, with the exception of the "unknown" placeholder element with symbol 'X', as this could not be used to construct a real structure. @@ -713,7 +680,6 @@ def has_unknown_species(self): known_species = [element['symbol'] for element in elements.values() if element['symbol'] != 'X'] for formula in self.get_formulae(): - if formula is None: return None @@ -724,9 +690,7 @@ def has_unknown_species(self): return False def generate_md5(self): - """ - Computes and returns MD5 hash of the CIF file. 
- """ + """Computes and returns MD5 hash of the CIF file.""" from aiida.common.files import md5_from_filelike # Open in binary mode which is required for generating the md5 checksum @@ -734,8 +698,7 @@ def generate_md5(self): return md5_from_filelike(handle) def get_structure(self, converter='pymatgen', store=False, **kwargs): - """ - Creates :py:class:`aiida.orm.nodes.data.structure.StructureData`. + """Creates :py:class:`aiida.orm.nodes.data.structure.StructureData`. .. versionadded:: 1.0 Renamed from _get_aiida_structure @@ -765,7 +728,7 @@ def get_structure(self, converter='pymatgen', store=False, **kwargs): return result['structure'] - def _prepare_cif(self, **kwargs): # pylint: disable=unused-argument + def _prepare_cif(self, **kwargs): """Return CIF string of CifData object. If parsed values are present, a CIF string is created and written to file. If no parsed values are present, the @@ -775,25 +738,21 @@ def _prepare_cif(self, **kwargs): # pylint: disable=unused-argument return handle.read(), {} def _get_object_ase(self): - """ - Converts CifData to ase.Atoms + """Converts CifData to ase.Atoms :return: an ase.Atoms object """ return self.ase def _get_object_pycifrw(self): - """ - Converts CifData to PyCIFRW.CifFile + """Converts CifData to PyCIFRW.CifFile :return: a PyCIFRW.CifFile object """ return self.values def _validate(self): - """ - Validates MD5 hash of CIF file. - """ + """Validates MD5 hash of CIF file.""" from aiida.common.exceptions import ValidationError super()._validate() diff --git a/aiida/orm/nodes/data/code/__init__.py b/aiida/orm/nodes/data/code/__init__.py index 9e85d34d06..de734ea3bd 100644 --- a/aiida/orm/nodes/data/code/__init__.py +++ b/aiida/orm/nodes/data/code/__init__.py @@ -3,8 +3,7 @@ # AUTO-GENERATED -# yapf: disable -# pylint: disable=wildcard-import +# fmt: off from .abstract import * from .containerized import * @@ -20,4 +19,4 @@ 'PortableCode', ) -# yapf: enable +# fmt: on diff --git a/aiida/orm/nodes/data/code/abstract.py b/aiida/orm/nodes/data/code/abstract.py index 5a2f4ec0e2..e32596d539 100644 --- a/aiida/orm/nodes/data/code/abstract.py +++ b/aiida/orm/nodes/data/code/abstract.py @@ -50,7 +50,7 @@ def __init__( with_mpi: bool | None = None, is_hidden: bool = False, wrap_cmdline_params: bool = False, - **kwargs + **kwargs, ): """Construct a new instance. @@ -347,13 +347,13 @@ def _get_cli_options(cls) -> dict: 'short_name': '-D', 'type': click.STRING, 'prompt': 'Description', - 'help': 'Human-readable description of this code ideally including version and compilation environment.' + 'help': 'Human-readable description, ideally including version and compilation environment.', }, 'default_calc_job_plugin': { 'short_name': '-P', 'type': click.STRING, 'prompt': 'Default `CalcJob` plugin', - 'help': 'Entry point name of the default plugin (as listed in `verdi plugin list aiida.calculations`).' + 'help': 'Entry point name of the default plugin (as listed in `verdi plugin list aiida.calculations`).', }, 'use_double_quotes': { 'is_flag': True, @@ -381,7 +381,7 @@ def _get_cli_options(cls) -> dict: 'extension': '.bash', 'header': 'PREPEND_TEXT: if there is any bash commands that should be prepended to the executable call ' 'in all submit scripts for this code, type that between the equal signs below and save the file.', - 'footer': 'All lines that start with `#=`: will be ignored.' 
+ 'footer': 'All lines that start with `#=`: will be ignored.', }, 'append_text': { 'cls': TemplateInteractiveOption, @@ -392,6 +392,6 @@ def _get_cli_options(cls) -> dict: 'extension': '.bash', 'header': 'APPEND_TEXT: if there is any bash commands that should be appended to the executable call ' 'in all submit scripts for this code, type that between the equal signs below and save the file.', - 'footer': 'All lines that start with `#=`: will be ignored.' + 'footer': 'All lines that start with `#=`: will be ignored.', }, } diff --git a/aiida/orm/nodes/data/code/containerized.py b/aiida/orm/nodes/data/code/containerized.py index 99230434b4..c7ab67aec4 100644 --- a/aiida/orm/nodes/data/code/containerized.py +++ b/aiida/orm/nodes/data/code/containerized.py @@ -25,6 +25,7 @@ class ContainerizedCode(InstalledCode): """Data plugin representing an executable code in a container on a remote computer.""" + _KEY_ATTRIBUTE_ENGINE_COMMAND: str = 'engine_command' _KEY_ATTRIBUTE_IMAGE_NAME: str = 'image_name' @@ -97,7 +98,8 @@ def get_prepend_cmdline_params( ) -> list[str]: """Return the list of prepend cmdline params for MPI setting - :return: list of prepend cmdline parameters.""" + :return: list of prepend cmdline parameters. + """ engine_cmdline = self.engine_command.format(image_name=self.image_name) engine_cmdline_params = engine_cmdline.split() @@ -130,7 +132,7 @@ def _get_cli_options(cls) -> dict: 'help': 'Whether all command line parameters to be passed to the engine command should be wrapped in ' 'a double quotes to form a single argument. This should be set to `True` for Docker.', 'prompt': 'Wrap command line parameters', - } + }, } options.update(**super()._get_cli_options()) diff --git a/aiida/orm/nodes/data/code/installed.py b/aiida/orm/nodes/data/code/installed.py index 9379c15eef..a4fb1a718c 100644 --- a/aiida/orm/nodes/data/code/installed.py +++ b/aiida/orm/nodes/data/code/installed.py @@ -49,7 +49,7 @@ def _validate(self): :raises :class:`aiida.common.exceptions.ValidationError`: If the state of the node is invalid. """ - super(Code, self)._validate() # Change to ``super()._validate()`` once deprecated ``Code`` class is removed. # pylint: disable=bad-super-call + super(Code, self)._validate() # Change to ``super()._validate()`` once deprecated ``Code`` class is removed. if not self.computer: # type: ignore[truthy-bool] raise exceptions.ValidationError('The `computer` is undefined.') @@ -78,7 +78,7 @@ def validate_filepath_executable(self): with override_log_level(): # Temporarily suppress noisy logging with self.computer.get_transport() as transport: file_exists = transport.isfile(str(self.filepath_executable)) - except Exception as exception: # pylint: disable=broad-except + except Exception as exception: raise exceptions.ValidationError( 'Could not connect to the configured computer to determine whether the specified executable exists.' ) from exception @@ -196,7 +196,7 @@ def _get_cli_options(cls) -> dict: 'type': click.Path(exists=False), 'prompt': 'Absolute filepath executable', 'help': 'Absolute filepath of the executable on the remote computer.', - } + }, } options.update(**super()._get_cli_options()) diff --git a/aiida/orm/nodes/data/code/legacy.py b/aiida/orm/nodes/data/code/legacy.py index ba4d4ae915..026fa1df84 100644 --- a/aiida/orm/nodes/data/code/legacy.py +++ b/aiida/orm/nodes/data/code/legacy.py @@ -22,8 +22,7 @@ class Code(AbstractCode): - """ - A code entity. + """A code entity. It can either be 'local' or 'remote'.
* Local code: it is a collection of files/dirs (added using the add_path() method), where one \ @@ -38,8 +37,6 @@ class Code(AbstractCode): for the code to be run). """ - # pylint: disable=too-many-public-methods - def __init__(self, remote_computer_exec=None, local_executable=None, input_plugin_name=None, files=None, **kwargs): super().__init__(**kwargs) @@ -48,7 +45,7 @@ def __init__(self, remote_computer_exec=None, local_executable=None, input_plugi '`aiida.orm.nodes.data.code.installed.InstalledCode` or `aiida.orm.nodes.data.code.portable.PortableCode` ' 'for a "remote" or "local" code, respectively. If you are using this class to compare type, e.g. in ' '`isinstance`, use `aiida.orm.nodes.data.code.abstract.AbstractCode`.', - version=3 + version=3, ) if remote_computer_exec and local_executable: @@ -100,15 +97,12 @@ def get_executable(self) -> pathlib.PurePosixPath: return pathlib.PurePosixPath(exec_path) def hide(self): - """ - Hide the code (prevents from showing it in the verdi code list) - """ + """Hide the code (prevents from showing it in the verdi code list)""" warn_deprecation('`Code.hide` property is deprecated, use the `Code.is_hidden` property instead.', version=3) self.is_hidden = True def reveal(self): - """ - Reveal the code (allows to show it in the verdi code list) + """Reveal the code (allows to show it in the verdi code list) By default, it is revealed """ warn_deprecation('`Code.reveal` property is deprecated, use the `Code.is_hidden` property instead.', version=3) @@ -116,22 +110,18 @@ def reveal(self): @property def hidden(self): - """ - Determines whether the Code is hidden or not - """ + """Determines whether the Code is hidden or not""" warn_deprecation('`Code.hidden` property is deprecated, use the `Code.is_hidden` property instead.', version=3) return self.is_hidden def set_files(self, files): - """ - Given a list of filenames (or a single filename string), + """Given a list of filenames (or a single filename string), add it to the path (all at level zero, i.e. without folders). Therefore, be careful for files with the same name! :todo: decide whether to check if the Code must be a local executable to be able to call this function. 
""" - if isinstance(files, str): files = [files] @@ -150,7 +140,7 @@ def get_computer_label(self): """Get label of this code's computer.""" warn_deprecation( '`Code.get_computer_label` method is deprecated, use the `InstalledCode.computer.label` property instead.', - version=3 + version=3, ) return 'repository' if self.computer is None else self.computer.label @@ -171,7 +161,7 @@ def relabel(self, new_label): if self.computer is not None: suffix = f'@{self.computer.label}' if new_label.endswith(suffix): - new_label = new_label[:-len(suffix)] + new_label = new_label[: -len(suffix)] self.label = new_label @@ -187,8 +177,7 @@ def get_description(self): @classmethod def get_code_helper(cls, label, machinename=None, backend=None): - """ - :param label: the code label identifying the code to load + """:param label: the code label identifying the code to load :param machinename: the machine name where code is setup :raise aiida.common.NotExistent: if no code identified by the given string is found @@ -213,7 +202,7 @@ def get_code_helper(cls, label, machinename=None, backend=None): codes = query.all(flat=True) retstr = f"There are multiple codes with label '{label}', having IDs: " retstr += f"{', '.join(sorted([str(c.pk) for c in codes]))}.\n" # type: ignore[union-attr] - retstr += ('Relabel them (using their ID), or refer to them with their ID.') + retstr += 'Relabel them (using their ID), or refer to them with their ID.' raise MultipleObjectsError(retstr) else: result = query.first() @@ -224,8 +213,7 @@ def get_code_helper(cls, label, machinename=None, backend=None): @classmethod def get(cls, pk=None, label=None, machinename=None): - """ - Get a Computer object with given identifier string, that can either be + """Get a Computer object with given identifier string, that can either be the numeric ID (pk), or the label (and computername) (if unique). :param pk: the numeric ID (pk) for code @@ -236,7 +224,6 @@ def get(cls, pk=None, label=None, machinename=None): :raise aiida.common.MultipleObjectsError: if the string cannot identify uniquely a code :raise ValueError: if neither a pk nor a label was passed in """ - # pylint: disable=arguments-differ from aiida.orm.utils import load_code warn_deprecation('`Code.get` classmethod is deprecated, use `aiida.orm.load_code` instead.', version=3) @@ -260,8 +247,7 @@ def get(cls, pk=None, label=None, machinename=None): @classmethod def get_from_string(cls, code_string): - """ - Get a Computer object with given identifier string in the format + """Get a Computer object with given identifier string in the format label@machinename. See the note below for details on the string detection algorithm. @@ -298,8 +284,7 @@ def get_from_string(cls, code_string): @classmethod def list_for_plugin(cls, plugin, labels=True, backend=None): - """ - Return a list of valid code strings for a given plugin. + """Return a list of valid code strings for a given plugin. :param plugin: The string of the plugin. 
:param labels: if True, return a list of code names, otherwise @@ -329,8 +314,7 @@ def _validate(self): if self.is_local(): if not self.get_local_executable(): raise exceptions.ValidationError( - 'You have to set which file is the local executable ' - 'using the set_exec_filename() method' + 'You have to set which file is the local executable ' 'using the set_exec_filename() method' ) if self.get_local_executable() not in self.base.repository.list_object_names(): raise exceptions.ValidationError( @@ -357,7 +341,7 @@ def validate_remote_exec_path(self): warn_deprecation( '`Code.validate_remote_exec_path` method is deprecated, use the ' '`InstalledCode.validate_filepath_executable` property instead.', - version=3 + version=3, ) filepath = self.get_remote_exec_path() @@ -368,7 +352,7 @@ def validate_remote_exec_path(self): with override_log_level(): # Temporarily suppress noisy logging with self.computer.get_transport() as transport: file_exists = transport.isfile(filepath) - except Exception: # pylint: disable=broad-except + except Exception: raise exceptions.ValidationError( 'Could not connect to the configured computer to determine whether the specified executable exists.' ) @@ -379,8 +363,7 @@ def validate_remote_exec_path(self): ) def set_prepend_text(self, code): - """ - Pass a string of code that will be put in the scheduler script before the + """Pass a string of code that will be put in the scheduler script before the execution of the code. """ warn_deprecation( @@ -389,8 +372,7 @@ def set_prepend_text(self, code): self.prepend_text = code def get_prepend_text(self): - """ - Return the code that will be put in the scheduler script before the + """Return the code that will be put in the scheduler script before the execution, or an empty string if no pre-exec code was defined. """ warn_deprecation( @@ -399,24 +381,22 @@ def get_prepend_text(self): return self.prepend_text def set_input_plugin_name(self, input_plugin): - """ - Set the name of the default input plugin, to be used for the automatic + """Set the name of the default input plugin, to be used for the automatic generation of a new calculation. """ warn_deprecation( '`Code.set_input_plugin_name` method is deprecated, use the `default_calc_job_plugin` property instead.', - version=3 + version=3, ) self.default_calc_job_plugin = input_plugin def get_input_plugin_name(self): - """ - Return the name of the default input plugin (or None if no input plugin + """Return the name of the default input plugin (or None if no input plugin was set. """ warn_deprecation( '`Code.get_input_plugin_name` method is deprecated, use the `default_calc_job_plugin` property instead.', - version=3 + version=3, ) return self.default_calc_job_plugin @@ -427,7 +407,7 @@ def set_use_double_quotes(self, use_double_quotes: bool): """ warn_deprecation( '`Code.set_use_double_quotes` method is deprecated, use the `use_double_quotes` property instead.', - version=3 + version=3, ) self.use_double_quotes = use_double_quotes @@ -438,13 +418,12 @@ def get_use_double_quotes(self) -> bool: """ warn_deprecation( '`Code.get_use_double_quotes` method is deprecated, use the `use_double_quotes` property instead.', - version=3 + version=3, ) return self.use_double_quotes def set_append_text(self, code): - """ - Pass a string of code that will be put in the scheduler script after the + """Pass a string of code that will be put in the scheduler script after the execution of the code. 
""" warn_deprecation( @@ -453,17 +432,14 @@ def set_append_text(self, code): self.append_text = code def get_append_text(self): - """ - Return the postexec_code, or an empty string if no post-exec code was defined. - """ + """Return the postexec_code, or an empty string if no post-exec code was defined.""" warn_deprecation( '`Code.get_append_text` method is deprecated, use the `append_text` property instead.', version=3 ) return self.append_text def set_local_executable(self, exec_name): - """ - Set the filename of the local executable. + """Set the filename of the local executable. Implicitly set the code as local. """ warn_deprecation('`Code.set_local_executable` method is deprecated, use `PortableCode`.', version=3) @@ -475,13 +451,12 @@ def get_local_executable(self): """Return the local executable.""" warn_deprecation( '`Code.get_local_executable` method is deprecated, use `PortableCode.filepath_executable` instead.', - version=3 + version=3, ) return self.filepath_executable def set_remote_computer_exec(self, remote_computer_exec): - """ - Set the code as remote, and pass the computer on which it resides + """Set the code as remote, and pass the computer on which it resides and the absolute path on that computer. :param remote_computer_exec: a tuple (computer, remote_exec_path), where computer is a aiida.orm.Computer and @@ -492,7 +467,7 @@ def set_remote_computer_exec(self, remote_computer_exec): warn_deprecation('`Code.set_remote_computer_exec` method is deprecated, use `InstalledCode`.', version=3) - if (not isinstance(remote_computer_exec, (list, tuple)) or len(remote_computer_exec) != 2): + if not isinstance(remote_computer_exec, (list, tuple)) or len(remote_computer_exec) != 2: raise ValueError( 'remote_computer_exec must be a list or tuple of length 2, with machine and executable name' ) @@ -512,7 +487,7 @@ def get_remote_exec_path(self): """Return the ``remote_exec_path`` attribute.""" warn_deprecation( '`Code.get_remote_exec_path` method is deprecated, use `InstalledCode.filepath_executable` instead.', - version=3 + version=3, ) if self.is_local(): raise ValueError('The code is local') @@ -529,8 +504,7 @@ def get_remote_computer(self): return self.computer def _set_local(self): - """ - Set the code as a 'local' code, meaning that all the files belonging to the code + """Set the code as a 'local' code, meaning that all the files belonging to the code will be copied to the cluster, and the file set with set_exec_filename will be run. @@ -544,8 +518,7 @@ def _set_local(self): pass def _set_remote(self): - """ - Set the code as a 'remote' code, meaning that the code itself has no files attached, + """Set the code as a 'remote' code, meaning that the code itself has no files attached, but only a location on a remote computer (with an absolute path of the executable on the remote computer). @@ -558,8 +531,7 @@ def _set_remote(self): pass def is_local(self): - """ - Return True if the code is 'local', False if it is 'remote' (see also documentation + """Return True if the code is 'local', False if it is 'remote' (see also documentation of the set_local and set_remote functions). """ warn_deprecation( @@ -568,8 +540,7 @@ def is_local(self): return self.base.attributes.get('is_local', None) def can_run_on(self, computer): - """ - Return True if this code can run on the given computer, False otherwise. + """Return True if this code can run on the given computer, False otherwise. Local codes can run on any machine; remote codes can run only on the machine on which they reside. 
@@ -588,8 +559,7 @@ def can_run_on(self, computer): return computer.pk == self.get_remote_computer().pk def get_execname(self): - """ - Return the executable string to be put in the script. + """Return the executable string to be put in the script. For local codes, it is ./LOCAL_EXECUTABLE_NAME For remote codes, it is the absolute path to the executable. """ diff --git a/aiida/orm/nodes/data/code/portable.py b/aiida/orm/nodes/data/code/portable.py index 6fac593818..426fac16a6 100644 --- a/aiida/orm/nodes/data/code/portable.py +++ b/aiida/orm/nodes/data/code/portable.py @@ -62,7 +62,7 @@ def _validate(self): :raises :class:`aiida.common.exceptions.ValidationError`: If the state of the node is invalid. """ - super(Code, self)._validate() # Change to ``super()._validate()`` once deprecated ``Code`` class is removed. # pylint: disable=bad-super-call + super(Code, self)._validate() # Change to ``super()._validate()`` once deprecated ``Code`` class is removed. try: filepath_executable = self.filepath_executable @@ -162,7 +162,7 @@ def _get_cli_options(cls) -> dict: 'type': click.Path(exists=True, file_okay=False, dir_okay=True, path_type=pathlib.Path), 'prompt': 'Code directory', 'help': 'Filepath to directory containing code files.', - } + }, } options.update(**super()._get_cli_options()) diff --git a/aiida/orm/nodes/data/data.py b/aiida/orm/nodes/data/data.py index 564139691f..988f866750 100644 --- a/aiida/orm/nodes/data/data.py +++ b/aiida/orm/nodes/data/data.py @@ -21,8 +21,7 @@ class Data(Node): - """ - The base class for all Data nodes. + """The base class for all Data nodes. AiiDA Data classes are subclasses of Node and must support multiple inheritance. @@ -30,6 +29,7 @@ class Data(Node): Calculation plugins are responsible for converting raw output data from simulation codes to Data nodes. Nodes are responsible for validating their content (see _validate method). """ + _source_attributes = ['db_name', 'db_uri', 'uri', 'id', 'version', 'extras', 'source_md5', 'description', 'license'] # Replace this with a dictionary in each subclass that, given a file @@ -56,8 +56,7 @@ def __copy__(self): raise exceptions.InvalidOperation('copying a Data node is not supported, use copy.deepcopy') def __deepcopy__(self, memo): - """ - Create a clone of the Data node by piping through to the clone method and return the result. + """Create a clone of the Data node by piping through to the clone method and return the result. :returns: an unstored clone of this Data node """ @@ -73,14 +72,13 @@ def clone(self): backend_clone = self.backend_entity.clone() clone = from_backend_entity(self.__class__, backend_clone) clone.base.attributes.reset(copy.deepcopy(self.base.attributes.all)) - clone.base.repository._clone(self.base.repository) # pylint: disable=protected-access + clone.base.repository._clone(self.base.repository) return clone @property def source(self): - """ - Gets the dictionary describing the source of Data object. Possible fields: + """Gets the dictionary describing the source of Data object. Possible fields: * **db_name**: name of the source database. * **db_uri**: URI of the source database. @@ -100,8 +98,7 @@ def source(self): @source.setter def source(self, source): - """ - Sets the dictionary describing the source of Data object. + """Sets the dictionary describing the source of Data object. :raise KeyError: if dictionary contains unknown field. :raise ValueError: if supplied source description is not a dictionary. 
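A short sketch of the ``source`` validation and ``clone`` behaviour shown above, assuming a configured profile; the source values are invented:

from aiida import load_profile
from aiida.orm import Int

load_profile()  # assumes a configured AiiDA profile

node = Int(5)
node.source = {'db_name': 'COD', 'id': '1234567', 'version': '2024'}  # only whitelisted keys pass
clone = node.clone()                  # copy.deepcopy(node) routes through clone() too
assert clone.source == node.source and not clone.is_stored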
@@ -115,9 +112,7 @@ def source(self, source): self.base.attributes.set('source', source) def set_source(self, source): - """ - Sets the dictionary describing the source of Data object. - """ + """Sets the dictionary describing the source of Data object.""" self.source = source @property @@ -135,8 +130,7 @@ def creator(self): @override def _exportcontent(self, fileformat, main_file_name='', **kwargs): - """ - Converts a Data node to one (or multiple) files. + """Converts a Data node to one (or multiple) files. Note: Export plugins should return utf8-encoded **bytes**, which can be directly dumped to file. @@ -161,15 +155,15 @@ def _exportcontent(self, fileformat, main_file_name='', **kwargs): except KeyError: if exporters.keys(): raise ValueError( - 'The format {} is not implemented for {}. ' - 'Currently implemented are: {}.'.format( + 'The format {} is not implemented for {}. ' 'Currently implemented are: {}.'.format( fileformat, self.__class__.__name__, ','.join(exporters.keys()) ) ) else: raise ValueError( - 'The format {} is not implemented for {}. ' - 'No formats are implemented yet.'.format(fileformat, self.__class__.__name__) + 'The format {} is not implemented for {}. ' 'No formats are implemented yet.'.format( + fileformat, self.__class__.__name__ + ) ) string, dictionary = func(main_file_name=main_file_name, **kwargs) @@ -179,8 +173,7 @@ def _exportcontent(self, fileformat, main_file_name='', **kwargs): @override def export(self, path, fileformat=None, overwrite=False, **kwargs): - """ - Save a Data object to a file. + """Save a Data object to a file. :param fname: string with file name. Can be an absolute or relative path. :param fileformat: kind of format to use for the export. If not present, @@ -201,7 +194,7 @@ def export(self, path, fileformat=None, overwrite=False, **kwargs): if fileformat is None: extension = os.path.splitext(path)[1] if extension.startswith(os.path.extsep): - extension = extension[len(os.path.extsep):] + extension = extension[len(os.path.extsep) :] if not extension: raise ValueError('Cannot recognized the fileformat from the extension') @@ -233,8 +226,7 @@ def export(self, path, fileformat=None, overwrite=False, **kwargs): return retlist def _get_exporters(self): - """ - Get all implemented export formats. + """Get all implemented export formats. The convention is to find all _prepare_... methods. Returns a dictionary of method_name: method_function """ @@ -247,21 +239,19 @@ def _get_exporters(self): @classmethod def get_export_formats(cls): - """ - Get the list of valid export format strings + """Get the list of valid export format strings :return: a list of valid formats """ exporter_prefix = '_prepare_' method_names = dir(cls) # get list of class methods names valid_format_names = [ - i[len(exporter_prefix):] for i in method_names if i.startswith(exporter_prefix) + i[len(exporter_prefix) :] for i in method_names if i.startswith(exporter_prefix) ] # filter them return sorted(valid_format_names) def importstring(self, inputstring, fileformat, **kwargs): - """ - Converts a Data object to other text format. + """Converts a Data object to other text format. :param fileformat: a string (the extension) to describe the file format. :returns: a string with the structure description. @@ -273,23 +263,22 @@ def importstring(self, inputstring, fileformat, **kwargs): except KeyError: if importers.keys(): raise ValueError( - 'The format {} is not implemented for {}. ' - 'Currently implemented are: {}.'.format( + 'The format {} is not implemented for {}. 
' 'Currently implemented are: {}.'.format( fileformat, self.__class__.__name__, ','.join(importers.keys()) ) ) else: raise ValueError( - 'The format {} is not implemented for {}. ' - 'No formats are implemented yet.'.format(fileformat, self.__class__.__name__) + 'The format {} is not implemented for {}. ' 'No formats are implemented yet.'.format( + fileformat, self.__class__.__name__ + ) ) # func is bound to self by getattr in _get_importers() func(inputstring, **kwargs) def importfile(self, fname, fileformat=None): - """ - Populate a Data object from a file. + """Populate a Data object from a file. :param fname: string with file name. Can be an absolute or relative path. :param fileformat: kind of format to use for the export. If not present, @@ -301,8 +290,7 @@ def importfile(self, fname, fileformat=None): self.importstring(fhandle.read(), fileformat) def _get_importers(self): - """ - Get all implemented import formats. + """Get all implemented import formats. The convention is to find all _parse_... methods. Returns a list of strings. """ @@ -310,18 +298,15 @@ def _get_importers(self): # _parse_"" with the name of the new format importer_prefix = '_parse_' method_names = dir(self) # get list of class methods names - valid_format_names = [i[len(importer_prefix):] for i in method_names if i.startswith(importer_prefix)] + valid_format_names = [i[len(importer_prefix) :] for i in method_names if i.startswith(importer_prefix)] valid_formats = {k: getattr(self, importer_prefix + k) for k in valid_format_names} return valid_formats def convert(self, object_format=None, *args): - """ - Convert the AiiDA StructureData into another python object + """Convert the AiiDA StructureData into another python object :param object_format: Specify the output format """ - # pylint: disable=keyword-arg-before-vararg - if object_format is None: raise ValueError('object_format must be provided') @@ -335,22 +320,21 @@ def convert(self, object_format=None, *args): except KeyError: if converters.keys(): raise ValueError( - 'The format {} is not implemented for {}. ' - 'Currently implemented are: {}.'.format( + 'The format {} is not implemented for {}. ' 'Currently implemented are: {}.'.format( object_format, self.__class__.__name__, ','.join(converters.keys()) ) ) else: raise ValueError( - 'The format {} is not implemented for {}. ' - 'No formats are implemented yet.'.format(object_format, self.__class__.__name__) + 'The format {} is not implemented for {}. ' 'No formats are implemented yet.'.format( + object_format, self.__class__.__name__ + ) ) return func(*args) def _get_converters(self): - """ - Get all implemented converter formats. + """Get all implemented converter formats. The convention is to find all _get_object_... methods. Returns a list of strings. 
""" @@ -358,6 +342,6 @@ def _get_converters(self): # _prepare_"" with the name of the new format exporter_prefix = '_get_object_' method_names = dir(self) # get list of class methods names - valid_format_names = [i[len(exporter_prefix):] for i in method_names if i.startswith(exporter_prefix)] + valid_format_names = [i[len(exporter_prefix) :] for i in method_names if i.startswith(exporter_prefix)] valid_formats = {k: getattr(self, exporter_prefix + k) for k in valid_format_names} return valid_formats diff --git a/aiida/orm/nodes/data/dict.py b/aiida/orm/nodes/data/dict.py index 833c3ca4a0..38e8e7b8f1 100644 --- a/aiida/orm/nodes/data/dict.py +++ b/aiida/orm/nodes/data/dict.py @@ -82,7 +82,7 @@ def __contains__(self, key: str) -> bool: """Return whether the node contains a key.""" return key in self.base.attributes - def get(self, key: str, default: t.Any | None = None, /): # type: ignore[override] # pylint: disable=arguments-differ + def get(self, key: str, default: t.Any | None = None, /): # type: ignore[override] """Return the value for key if key is in the dictionary, else default. :param key: The key whose value to return. @@ -102,7 +102,7 @@ def set_dict(self, dictionary): # Clear existing attributes and set the new dictionary self.base.attributes.clear() self.update_dict(dictionary) - except exceptions.ModificationNotAllowed: # pylint: disable=try-except-raise + except exceptions.ModificationNotAllowed: # I reraise here to avoid to go in the generic 'except' below that would raise the same exception again raise except Exception: @@ -150,6 +150,7 @@ def dict(self): :return: an instance of the `AttributeResultManager`. """ from aiida.orm.utils.managers import AttributeManager + return AttributeManager(self) diff --git a/aiida/orm/nodes/data/enum.py b/aiida/orm/nodes/data/enum.py index 8c74c043f8..a0e9b3b2de 100644 --- a/aiida/orm/nodes/data/enum.py +++ b/aiida/orm/nodes/data/enum.py @@ -15,8 +15,8 @@ class Color(Enum): members (or enum members) and are functionally constants. The enum members have names and values: the name of ``Color.RED`` is ``RED`` and the value of ``Color.RED`` is ``1``. """ -from enum import Enum import typing as t +from enum import Enum from plumpy.loaders import get_object_loader @@ -57,7 +57,7 @@ def __init__(self, member: Enum, *args, **kwargs): data = { self.KEY_NAME: member.name, self.KEY_VALUE: member.value, - self.KEY_IDENTIFIER: get_object_loader().identify_object(member.__class__) + self.KEY_IDENTIFIER: get_object_loader().identify_object(member.__class__), } self.base.attributes.set_many(data) diff --git a/aiida/orm/nodes/data/jsonable.py b/aiida/orm/nodes/data/jsonable.py index d16670a4e3..796c5e8d5a 100644 --- a/aiida/orm/nodes/data/jsonable.py +++ b/aiida/orm/nodes/data/jsonable.py @@ -10,7 +10,6 @@ class JsonSerializableProtocol(typing.Protocol): - def as_dict(self) -> typing.MutableMapping[typing.Any, typing.Any]: ... 
diff --git a/aiida/orm/nodes/data/list.py b/aiida/orm/nodes/data/list.py index ef6a8e98ed..c0a022b1fe 100644 --- a/aiida/orm/nodes/data/list.py +++ b/aiida/orm/nodes/data/list.py @@ -62,13 +62,13 @@ def append(self, value): if not self._using_list_reference(): self.set_list(data) - def extend(self, value): # pylint: disable=arguments-renamed + def extend(self, value): data = self.get_list() data.extend(value) if not self._using_list_reference(): self.set_list(data) - def insert(self, i, value): # pylint: disable=arguments-renamed + def insert(self, i, value): data = self.get_list() data.insert(i, value) if not self._using_list_reference(): @@ -81,7 +81,7 @@ def remove(self, value): self.set_list(data) return item - def pop(self, **kwargs): # pylint: disable=arguments-differ + def pop(self, **kwargs): """Remove and return item at index (default last).""" data = self.get_list() item = data.pop(**kwargs) @@ -89,7 +89,7 @@ def pop(self, **kwargs): # pylint: disable=arguments-differ self.set_list(data) return item - def index(self, value): # pylint: disable=arguments-differ + def index(self, value): """Return first index of value.""" return self.get_list().index(value) @@ -130,8 +130,7 @@ def set_list(self, data): self.base.attributes.set(self._LIST_KEY, data.copy()) def _using_list_reference(self): - """ - This function tells the class if we are using a list reference. This + """This function tells the class if we are using a list reference. This means that calls to self.get_list return a reference rather than a copy of the underlying list and therefore self.set_list need not be called. This knowledge is essential to make sure this class is performant. diff --git a/aiida/orm/nodes/data/orbital.py b/aiida/orm/nodes/data/orbital.py index 32f1640cac..5294aed108 100644 --- a/aiida/orm/nodes/data/orbital.py +++ b/aiida/orm/nodes/data/orbital.py @@ -19,21 +19,18 @@ class OrbitalData(Data): - """ - Used for storing collections of orbitals, as well as + """Used for storing collections of orbitals, as well as providing methods for accessing them internally. """ def clear_orbitals(self): - """ - Remove all orbitals that were added to the class + """Remove all orbitals that were added to the class. Cannot work if OrbitalData has already been stored """ self.base.attributes.set('orbital_dicts', []) def get_orbitals(self, **kwargs): - """ - Returns all orbitals by default. If a site is provided, returns + """Returns all orbitals by default. If a site is provided, returns all orbitals corresponding to the location of that site; additional arguments may be provided, which act as filters on the retrieved orbitals. @@ -42,7 +39,6 @@ def get_orbitals(self, **kwargs): :kwargs: attributes that can filter the set of returned orbitals :return list_of_outputs: a list of orbitals """ - orbital_dicts = copy.deepcopy(self.base.attributes.get('orbital_dicts', None)) if orbital_dicts is None: raise AttributeError('Orbitals must be set before being retrieved') @@ -66,8 +62,7 @@ def get_orbitals(self, **kwargs): return list_of_outputs def set_orbitals(self, orbitals): - """ - Sets the orbitals into the database. Uses the orbital's inherent + """Sets the orbitals into the database. Uses the orbital's inherent set_orbital_dict method to generate an orbital dict string.
:param orbital: an orbital or list of orbitals to be set diff --git a/aiida/orm/nodes/data/remote/__init__.py b/aiida/orm/nodes/data/remote/__init__.py index ae1b5dbc4f..f1746f1ab4 100644 --- a/aiida/orm/nodes/data/remote/__init__.py +++ b/aiida/orm/nodes/data/remote/__init__.py @@ -3,8 +3,7 @@ # AUTO-GENERATED -# yapf: disable -# pylint: disable=wildcard-import +# fmt: off from .base import * from .stash import * @@ -15,4 +14,4 @@ 'RemoteStashFolderData', ) -# yapf: enable +# fmt: on diff --git a/aiida/orm/nodes/data/remote/base.py b/aiida/orm/nodes/data/remote/base.py index 48a9c03985..ca89604bb2 100644 --- a/aiida/orm/nodes/data/remote/base.py +++ b/aiida/orm/nodes/data/remote/base.py @@ -18,8 +18,7 @@ class RemoteData(Data): - """ - Store a link to a file or folder on a remote machine. + """Store a link to a file or folder on a remote machine. Remember to pass a computer! """ @@ -44,9 +43,7 @@ def is_cleaned(self): @property def is_empty(self): - """ - Check if remote folder is empty - """ + """Check if remote folder is empty""" if self.is_cleaned: return True @@ -63,8 +60,7 @@ def is_empty(self): return not transport.listdir() def getfile(self, relpath, destpath): - """ - Connects to the remote folder and retrieves the content of a file. + """Connects to the remote folder and retrieves the content of a file. :param relpath: The relative path of the file on the remote to retrieve. :param destpath: The absolute path of where to store the file on the local machine. @@ -79,15 +75,13 @@ def getfile(self, relpath, destpath): if exception.errno == 2: # file does not exist raise IOError( 'The required remote file {} on {} does not exist or has been deleted.'.format( - full_path, - self.computer.label # pylint: disable=no-member + full_path, self.computer.label ) ) from exception raise def listdir(self, relpath='.'): - """ - Connects to the remote folder and lists the directory content. + """Connects to the remote folder and lists the directory content. :param relpath: If 'relpath' is specified, lists the content of the given subfolder. :return: a flat list of file/directory names (as strings). @@ -101,8 +95,8 @@ def listdir(self, relpath='.'): except IOError as exception: if exception.errno in (2, 20): # directory not existing or not a directory exc = IOError( - 'The required remote folder {} on {} does not exist, is not a directory or has been deleted.'. - format(full_path, self.computer.label) # pylint: disable=no-member + f'The required remote folder {full_path} on {self.computer.label} does not exist, is not a ' + 'directory or has been deleted.' ) exc.errno = exception.errno raise exc from exception @@ -114,8 +108,8 @@ def listdir(self, relpath='.'): except IOError as exception: if exception.errno in (2, 20): # directory not existing or not a directory exc = IOError( - 'The required remote folder {} on {} does not exist, is not a directory or has been deleted.'. - format(full_path, self.computer.label) # pylint: disable=no-member + f'The required remote folder {full_path} on {self.computer.label} does not exist, is not a ' + 'directory or has been deleted.' ) exc.errno = exception.errno raise exc from exception @@ -123,8 +117,7 @@ def listdir(self, relpath='.'): raise def listdir_withattributes(self, path='.'): - """ - Connects to the remote folder and lists the directory content. + """Connects to the remote folder and lists the directory content. :param relpath: If 'relpath' is specified, lists the content of the given subfolder. 
:return: a list of dictionaries, where the documentation is in :py:class:Transport.listdir_withattributes. @@ -138,8 +131,8 @@ def listdir_withattributes(self, path='.'): except IOError as exception: if exception.errno in (2, 20): # directory not existing or not a directory exc = IOError( - 'The required remote folder {} on {} does not exist, is not a directory or has been deleted.'. - format(full_path, self.computer.label) # pylint: disable=no-member + f'The required remote folder {full_path} on {self.computer.label} does not exist, is not a ' + 'directory or has been deleted.' ) exc.errno = exception.errno raise exc from exception @@ -151,8 +144,8 @@ def listdir_withattributes(self, path='.'): except IOError as exception: if exception.errno in (2, 20): # directory not existing or not a directory exc = IOError( - 'The required remote folder {} on {} does not exist, is not a directory or has been deleted.'. - format(full_path, self.computer.label) # pylint: disable=no-member + f'The required remote folder {full_path} on {self.computer.label} does not exist, is not a ' + 'directory or has been deleted.' ) exc.errno = exception.errno raise exc from exception @@ -175,7 +168,7 @@ def _clean(self, transport=None): remote_dir = self.get_remote_path() if transport is None: - with self.get_authinfo().get_transport() as transport: # pylint: disable=redefined-argument-from-local + with self.get_authinfo().get_transport() as transport: clean_remote(transport, remote_dir) else: if transport.hostname != self.computer.hostname: diff --git a/aiida/orm/nodes/data/remote/stash/__init__.py b/aiida/orm/nodes/data/remote/stash/__init__.py index e06481e842..a6f25b367c 100644 --- a/aiida/orm/nodes/data/remote/stash/__init__.py +++ b/aiida/orm/nodes/data/remote/stash/__init__.py @@ -3,8 +3,7 @@ # AUTO-GENERATED -# yapf: disable -# pylint: disable=wildcard-import +# fmt: off from .base import * from .folder import * @@ -14,4 +13,4 @@ 'RemoteStashFolderData', ) -# yapf: enable +# fmt: on diff --git a/aiida/orm/nodes/data/singlefile.py b/aiida/orm/nodes/data/singlefile.py index ed9d9079df..c0f4cd4f9d 100644 --- a/aiida/orm/nodes/data/singlefile.py +++ b/aiida/orm/nodes/data/singlefile.py @@ -48,7 +48,6 @@ def __init__( Hint: Pass io.BytesIO(b"my string") to construct the SinglefileData directly from a string. :param filename: specify filename to use (defaults to name of provided file). """ - # pylint: disable=redefined-builtin super().__init__(**kwargs) if file is not None: @@ -85,9 +84,9 @@ def open(self, path: None = None, mode: t.Literal['rb'] = ...) -> t.Iterator[t.B ... @contextlib.contextmanager - def open(self, - path: FilePath | None = None, - mode: t.Literal['r', 'rb'] = 'r') -> t.Iterator[t.BinaryIO] | t.Iterator[t.TextIO]: + def open( + self, path: FilePath | None = None, mode: t.Literal['r', 'rb'] = 'r' + ) -> t.Iterator[t.BinaryIO] | t.Iterator[t.TextIO]: """Return an open file handle to the content of this data node. :param path: the relative path of the object within the repository. @@ -128,8 +127,6 @@ def set_file(self, file: str | pathlib.Path | t.IO, filename: str | pathlib.Path Hint: Pass io.BytesIO(b"my string") to construct the file directly from a string. :param filename: specify filename to use (defaults to name of provided file). 
""" - # pylint: disable=redefined-builtin - if isinstance(file, (str, pathlib.Path)): is_filelike = False diff --git a/aiida/orm/nodes/data/structure.py b/aiida/orm/nodes/data/structure.py index 20800a5497..6884d29a5e 100644 --- a/aiida/orm/nodes/data/structure.py +++ b/aiida/orm/nodes/data/structure.py @@ -7,9 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=too-many-lines -""" -This module defines the classes for structures and all related +"""This module defines the classes for structures and all related functions to operate on them. """ import copy @@ -26,9 +24,9 @@ # Threshold used to check if the mass of two different Site objects is the same. -_MASS_THRESHOLD = 1.e-3 +_MASS_THRESHOLD = 1.0e-3 # Threshold to check if the sum is one or not -_SUM_THRESHOLD = 1.e-6 +_SUM_THRESHOLD = 1.0e-6 # Default cell _DEFAULT_CELL = ((0, 0, 0), (0, 0, 0), (0, 0, 0)) @@ -38,8 +36,7 @@ def _get_valid_cell(inputcell): - """ - Return the cell in a valid format from a generic input. + """Return the cell in a valid format from a generic input. :raise ValueError: whenever the format is not valid. """ @@ -56,8 +53,7 @@ def _get_valid_cell(inputcell): def get_valid_pbc(inputpbc): - """ - Return a list of three booleans for the periodic boundary conditions, + """Return a list of three booleans for the periodic boundary conditions, in a valid format from a generic input. :raise ValueError: if the format is not valid. @@ -87,31 +83,25 @@ def get_valid_pbc(inputpbc): def has_ase(): - """ - :return: True if the ase module can be imported, False otherwise. - """ + """:return: True if the ase module can be imported, False otherwise.""" try: - import ase # pylint: disable=unused-import + import ase # noqa: F401 except ImportError: return False return True def has_pymatgen(): - """ - :return: True if the pymatgen module can be imported, False otherwise. - """ + """:return: True if the pymatgen module can be imported, False otherwise.""" try: - import pymatgen # pylint: disable=unused-import + import pymatgen # noqa: F401 except ImportError: return False return True def get_pymatgen_version(): - """ - :return: string with pymatgen version, None if can not import. - """ + """:return: string with pymatgen version, None if can not import.""" if not has_pymatgen(): return None try: @@ -123,30 +113,27 @@ def get_pymatgen_version(): def has_spglib(): - """ - :return: True if the spglib module can be imported, False otherwise. - """ + """:return: True if the spglib module can be imported, False otherwise.""" try: - import spglib # pylint: disable=unused-import + import spglib # noqa: F401 except ImportError: return False return True def calc_cell_volume(cell): - """ - Compute the three-dimensional cell volume in Angstrom^3. + """Compute the three-dimensional cell volume in Angstrom^3. :param cell: the cell vectors; the must be a 3x3 list of lists of floats :returns: the cell volume. """ import numpy as np + return np.abs(np.dot(cell[0], np.cross(cell[1], cell[2]))) def _create_symbols_tuple(symbols): - """ - Returns a tuple with the symbols provided. If a string is provided, + """Returns a tuple with the symbols provided. If a string is provided, this is converted to a tuple with one single element. 
""" if isinstance(symbols, str): @@ -157,15 +144,14 @@ def _create_symbols_tuple(symbols): def _create_weights_tuple(weights): - """ - Returns a tuple with the weights provided. If a number is provided, + """Returns a tuple with the weights provided. If a number is provided, this is converted to a tuple with one single element. If None is provided, this is converted to the tuple (1.,) """ import numbers if weights is None: - weights_tuple = (1.,) + weights_tuple = (1.0,) elif isinstance(weights, numbers.Number): weights_tuple = (weights,) else: @@ -174,8 +160,7 @@ def _create_weights_tuple(weights): def create_automatic_kind_name(symbols, weights): - """ - Create a string obtained with the symbols appended one + """Create a string obtained with the symbols appended one after the other, without spaces, in alphabetical order; if the site has a vacancy, a X is appended at the end too. """ @@ -188,8 +173,7 @@ def create_automatic_kind_name(symbols, weights): def validate_weights_tuple(weights_tuple, threshold): - """ - Validates the weight of the atomic kinds. + """Validates the weight of the atomic kinds. :raise: ValueError if the weights_tuple is not valid. @@ -202,13 +186,12 @@ def validate_weights_tuple(weights_tuple, threshold): Each element of the list must be >= 0, and the sum must be <= 1. """ w_sum = sum(weights_tuple) - if (any(i < 0. for i in weights_tuple) or (w_sum - 1. > threshold)): + if any(i < 0.0 for i in weights_tuple) or (w_sum - 1.0 > threshold): raise ValueError('The weight list is not valid (each element must be positive, and the sum must be <= 1).') def is_valid_symbol(symbol): - """ - Validates the chemical symbol name. + """Validates the chemical symbol name. :return: True if the symbol is a valid chemical symbol (with correct capitalization), or the dummy X, False otherwise. @@ -220,8 +203,7 @@ def is_valid_symbol(symbol): def validate_symbols_tuple(symbols_tuple): - """ - Used to validate whether the chemical species are valid. + """Used to validate whether the chemical species are valid. :param symbols_tuple: a tuple (or list) with the chemical symbols name. :raises: UnsupportedSpeciesError if any symbol in the tuple is not a valid chemical @@ -240,8 +222,7 @@ def validate_symbols_tuple(symbols_tuple): def is_ase_atoms(ase_atoms): - """ - Check if the ase_atoms parameter is actually a ase.Atoms object. + """Check if the ase_atoms parameter is actually a ase.Atoms object. :param ase_atoms: an object, expected to be an ase.Atoms. :return: a boolean. @@ -249,16 +230,16 @@ def is_ase_atoms(ase_atoms): Requires the ability to import ase, by doing 'import ase'. """ import ase + return isinstance(ase_atoms, ase.Atoms) def group_symbols(_list): - """ - Group a list of symbols to a list containing the number of consecutive + """Group a list of symbols to a list containing the number of consecutive identical symbols, and the symbol itself. - Examples: - + Examples + -------- * ``['Ba','Ti','O','O','O','Ba']`` will return ``[[1,'Ba'],[1,'Ti'],[3,'O'],[1,'Ba']]`` @@ -268,7 +249,6 @@ def group_symbols(_list): :param _list: a list of elements representing a chemical formula :return: a list of length-2 lists of the form [ multiplicity , element ] """ - the_list = copy.deepcopy(_list) the_list.reverse() grouped_list = [[1, the_list.pop()]] @@ -284,9 +264,10 @@ def group_symbols(_list): def get_formula_from_symbol_list(_list, separator=''): - """ - Return a string with the formula obtained from the list of symbols. 
- Examples: + """Return a string with the formula obtained from the list of symbols. + + Examples + -------- * ``[[1,'Ba'],[1,'Ti'],[3,'O']]`` will return ``'BaTiO3'`` * ``[[2, [ [1, 'Ba'], [1, 'Ti'] ] ]]`` will return ``'(BaTi)2'`` @@ -296,7 +277,6 @@ def get_formula_from_symbol_list(_list, separator=''): :return: a string """ - list_str = [] for elem in _list: if elem[0] == 1: @@ -315,8 +295,7 @@ def get_formula_from_symbol_list(_list, separator=''): def get_formula_group(symbol_list, separator=''): - """ - Return a string with the chemical formula from a list of chemical symbols. + """Return a string with the chemical formula from a list of chemical symbols. The formula is written in a compact" way, i.e. trying to group as much as possible parts of the formula. @@ -334,8 +313,7 @@ def get_formula_group(symbol_list, separator=''): """ def group_together(_list, group_size, offset): - """ - :param _list: a list + """:param _list: a list :param group_size: size of the groups :param offset: beginning grouping after offset elements :return : a list of lists made of groups of size group_size @@ -345,7 +323,6 @@ def group_together(_list, group_size, offset): ``group_together(['O','Ba','Ti','Ba','Ti'],2,1) = ['O',['Ba','Ti'],['Ba','Ti']]`` """ - the_list = copy.deepcopy(_list) the_list.reverse() grouped_list = [] @@ -362,8 +339,7 @@ def group_together(_list, group_size, offset): return grouped_list def cleanout_symbol_list(_list): - """ - :param _list: a list of groups of symbols and multiplicities + """:param _list: a list of groups of symbols and multiplicities :return : a list where all groups with multiplicity 1 have been reduced to minimum example: ``[[1,[[1,'Ba']]]]`` will return ``[[1,'Ba']]`` @@ -378,8 +354,7 @@ def cleanout_symbol_list(_list): return the_list def group_together_symbols(_list, group_size): - """ - Successive application of group_together, group_symbols and + """Successive application of group_together, group_symbols and cleanout_symbol_list, in order to group a symbol list, scanning all possible offsets, for a given group size :param _list: the symbol list (see function group_symbols) @@ -403,8 +378,7 @@ def group_together_symbols(_list, group_size): return the_symbol_list, has_grouped def group_all_together_symbols(_list): - """ - Successive application of the function group_together_symbols, to group + """Successive application of the function group_together_symbols, to group a symbol list, scanning all possible offsets and group sizes :param _list: the symbol list (see function group_symbols) :return: the new grouped symbol list @@ -437,8 +411,7 @@ def group_all_together_symbols(_list): def get_formula(symbol_list, mode='hill', separator=''): - """ - Return a string with the chemical formula. + """Return a string with the chemical formula. :param symbol_list: a list of symbols, e.g. ``['H','H','O']`` :param mode: a string to specify how to generate the formula, can @@ -449,7 +422,7 @@ def get_formula(symbol_list, mode='hill', separator=''): first if one or several C atom(s) is (are) present, e.g. ``['C','H','H','H','O','C','H','H','H']`` will return ``'C2H6O'`` ``['S','O','O','H','O','H','O']`` will return ``'H2O4S'`` - From E. A. Hill, J. Am. Chem. Soc., 22 (8), pp 478–494 (1900) + From E. A. Hill, J. Am. Chem. Soc., 22 (8), pp 478-494 (1900) * 'hill_compact': same as hill but the number of atoms for each species is divided by the greatest common divisor of all of them, e.g. 
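# ---------------------------------------------------------------------------
# Illustration: the formula helpers reformatted above are plain module-level
# functions, so their documented behaviour can be checked directly without a
# profile; the expected values below are the ones stated in the docstrings
# (plus one `hill_compact` case computed from the documented gcd rule).
from aiida.orm.nodes.data.structure import get_formula, group_symbols

assert group_symbols(['Ba', 'Ti', 'O', 'O', 'O', 'Ba']) == [[1, 'Ba'], [1, 'Ti'], [3, 'O'], [1, 'Ba']]
assert get_formula(['C', 'H', 'H', 'H', 'O', 'C', 'H', 'H', 'H'], mode='hill') == 'C2H6O'
assert get_formula(['Ba', 'Ti', 'O', 'O', 'O'] * 2, mode='hill_compact') == 'BaO3Ti'
# ---------------------------------------------------------------------------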
@@ -485,7 +458,6 @@ def get_formula(symbol_list, mode='hill', separator=''): initial order in which the atoms were appended by the user is used to group and/or order the symbols in the formula """ - if mode == 'group': return get_formula_group(symbol_list, separator=separator) @@ -511,6 +483,7 @@ def get_formula(symbol_list, mode='hill', separator=''): if mode in ['hill_compact', 'count_compact']: from math import gcd + the_gcd = functools.reduce(gcd, [e[0] for e in the_symbol_list]) the_symbol_list = [[e[0] // the_gcd, e[1]] for e in the_symbol_list] @@ -518,8 +491,7 @@ def get_formula(symbol_list, mode='hill', separator=''): def get_symbols_string(symbols, weights): - """ - Return a string that tries to match as good as possible the symbols + """Return a string that tries to match as good as possible the symbols and weights. If there is only one symbol (no alloy) with 100% occupancy, just returns the symbol name. Otherwise, groups the full string in curly brackets, and try to write also the composition @@ -533,7 +505,7 @@ def get_symbols_string(symbols, weights): .. note:: Note the difference with respect to the symbols and the symbol properties! """ - if len(symbols) == 1 and weights[0] == 1.: + if len(symbols) == 1 and weights[0] == 1.0: return symbols[0] pieces = [] @@ -545,19 +517,17 @@ def get_symbols_string(symbols, weights): def has_vacancies(weights): - """ - Returns True if the sum of the weights is less than one. + """Returns True if the sum of the weights is less than one. It uses the internal variable _SUM_THRESHOLD as a threshold. :param weights: the weights :return: a boolean """ w_sum = sum(weights) - return not 1. - w_sum < _SUM_THRESHOLD + return not 1.0 - w_sum < _SUM_THRESHOLD def symop_ortho_from_fract(cell): - """ - Creates a matrix for conversion from orthogonal to fractional + """Creates a matrix for conversion from orthogonal to fractional coordinates. Taken from @@ -566,7 +536,6 @@ def symop_ortho_from_fract(cell): :param cell: array of cell parameters (three lengths and three angles) """ - # pylint: disable=invalid-name import math import numpy @@ -576,13 +545,17 @@ def symop_ortho_from_fract(cell): ca, cb, cg = [math.cos(x) for x in [alpha, beta, gamma]] sg = math.sin(gamma) - return numpy.array([[a, b * cg, c * cb], [0, b * sg, c * (ca - cb * cg) / sg], - [0, 0, c * math.sqrt(sg * sg - ca * ca - cb * cb + 2 * ca * cb * cg) / sg]]) + return numpy.array( + [ + [a, b * cg, c * cb], + [0, b * sg, c * (ca - cb * cg) / sg], + [0, 0, c * math.sqrt(sg * sg - ca * ca - cb * cb + 2 * ca * cb * cg) / sg], + ] + ) def symop_fract_from_ortho(cell): - """ - Creates a matrix for conversion from fractional to orthogonal + """Creates a matrix for conversion from fractional to orthogonal coordinates. 
Taken from @@ -591,7 +564,6 @@ def symop_fract_from_ortho(cell): :param cell: array of cell parameters (three lengths and three angles) """ - # pylint: disable=invalid-name import math import numpy @@ -601,18 +573,19 @@ def symop_fract_from_ortho(cell): ca, cb, cg = [math.cos(x) for x in [alpha, beta, gamma]] sg = math.sin(gamma) ctg = cg / sg - D = math.sqrt(sg * sg - cb * cb - ca * ca + 2 * ca * cb * cg) + D = math.sqrt(sg * sg - cb * cb - ca * ca + 2 * ca * cb * cg) # noqa: N806 - return numpy.array([ - [1.0 / a, -(1.0 / a) * ctg, (ca * cg - cb) / (a * D)], - [0, 1.0 / (b * sg), -(ca - cb * cg) / (b * D * sg)], - [0, 0, sg / (c * D)], - ]) + return numpy.array( + [ + [1.0 / a, -(1.0 / a) * ctg, (ca * cg - cb) / (a * D)], + [0, 1.0 / (b * sg), -(ca - cb * cg) / (b * D * sg)], + [0, 0, sg / (c * D)], + ] + ) def ase_refine_cell(aseatoms, **kwargs): - """ - Detect the symmetry of the structure, remove symmetric atoms and + """Detect the symmetry of the structure, remove symmetric atoms and refine unit cell. :param aseatoms: an ase.atoms.Atoms instance @@ -622,6 +595,7 @@ def ase_refine_cell(aseatoms, **kwargs): """ from ase.atoms import Atoms from spglib import get_symmetry_dataset, refine_cell + cell, positions, numbers = refine_cell(aseatoms, **kwargs) refined_atoms = Atoms(numbers, scaled_positions=positions, cell=cell, pbc=True) @@ -642,36 +616,35 @@ def ase_refine_cell(aseatoms, **kwargs): 'hall': sym_dataset['hall'], 'tables': sym_dataset['number'], 'rotations': sym_dataset['rotations'], - 'translations': sym_dataset['translations'] + 'translations': sym_dataset['translations'], } def atom_kinds_to_html(atom_kind): - """ - - Construct in html format + """Construct in html format an alloy with 0.5 Ge, 0.4 Si and 0.1 vacancy is represented as Ge<sub>0.5</sub> + Si<sub>0.4</sub> + vacancy<sub>0.1</sub> Args: + ----- atom_kind: a string with the name of the atomic kind, as printed by kind.get_symbols_string(), e.g. Ba0.80Ca0.10X0.10 Returns: + -------- html code for rendered formula """ - # Parse the formula (TODO can be made more robust though never fails if # it takes strings generated with kind.get_symbols_string()) import re + matched_elements = re.findall(r'([A-Z][a-z]*)([0-1][.[0-9]*]?)?', atom_kind) # Compose the html string html_formula_pieces = [] for element in matched_elements: - # replace element X by 'vacancy' species = element[0] if element[0] != 'X' else 'vacancy' weight = element[1] if element[1] != '' else None @@ -687,34 +660,35 @@ class StructureData(Data): - """ - This class contains the information about a given structure, i.e. a - collection of sites together with a cell, the - boundary conditions (whether they are periodic or not) and other - related useful information. - """ + """Data class that represents an atomic structure. - # pylint: disable=too-many-public-methods + The data is organized as a collection of sites together with a cell, the boundary conditions (whether they are + periodic or not) and other related useful information.
+ """ - _set_incompatibilities = [('ase', 'cell'), ('ase', 'pbc'), ('ase', 'pymatgen'), ('ase', 'pymatgen_molecule'), - ('ase', 'pymatgen_structure'), ('cell', 'pymatgen'), ('cell', 'pymatgen_molecule'), - ('cell', 'pymatgen_structure'), ('pbc', 'pymatgen'), ('pbc', 'pymatgen_molecule'), - ('pbc', 'pymatgen_structure'), ('pymatgen', 'pymatgen_molecule'), - ('pymatgen', 'pymatgen_structure'), ('pymatgen_molecule', 'pymatgen_structure')] + _set_incompatibilities = [ + ('ase', 'cell'), + ('ase', 'pbc'), + ('ase', 'pymatgen'), + ('ase', 'pymatgen_molecule'), + ('ase', 'pymatgen_structure'), + ('cell', 'pymatgen'), + ('cell', 'pymatgen_molecule'), + ('cell', 'pymatgen_structure'), + ('pbc', 'pymatgen'), + ('pbc', 'pymatgen_molecule'), + ('pbc', 'pymatgen_structure'), + ('pymatgen', 'pymatgen_molecule'), + ('pymatgen', 'pymatgen_structure'), + ('pymatgen_molecule', 'pymatgen_structure'), + ] _dimensionality_label = {0: '', 1: 'length', 2: 'surface', 3: 'volume'} _internal_kind_tags = None def __init__( - self, - cell=None, - pbc=None, - ase=None, - pymatgen=None, - pymatgen_structure=None, - pymatgen_molecule=None, - **kwargs - ): # pylint: disable=too-many-arguments + self, cell=None, pbc=None, ase=None, pymatgen=None, pymatgen_structure=None, pymatgen_molecule=None, **kwargs + ): args = { 'cell': cell, 'pbc': pbc, @@ -731,7 +705,6 @@ def __init__( super().__init__(**kwargs) if any(ext is not None for ext in [ase, pymatgen, pymatgen_structure, pymatgen_molecule]): - if ase is not None: self.set_ase(ase) @@ -754,8 +727,7 @@ def __init__( self.set_pbc(pbc) def get_dimensionality(self): - """ - Return the dimensionality of the structure and its length/surface/volume. + """Return the dimensionality of the structure and its length/surface/volume. Zero-dimensional structures are assigned "volume" 0. @@ -765,9 +737,7 @@ def get_dimensionality(self): return _get_dimensionality(self.pbc, self.cell) def set_ase(self, aseatoms): - """ - Load the structure from a ASE object - """ + """Load the structure from a ASE object""" if is_ase_atoms(aseatoms): # Read the ase structure self.cell = aseatoms.cell @@ -779,8 +749,7 @@ def set_ase(self, aseatoms): raise TypeError('The value is not an ase.Atoms object') def set_pymatgen(self, obj, **kwargs): - """ - Load the structure from a pymatgen object. + """Load the structure from a pymatgen object. .. note:: Requires the pymatgen module (version >= 3.0.13, usage of earlier versions may cause errors). @@ -793,8 +762,7 @@ def set_pymatgen(self, obj, **kwargs): func(obj, **kwargs) def set_pymatgen_molecule(self, mol, margin=5): - """ - Load the structure from a pymatgen Molecule object. + """Load the structure from a pymatgen Molecule object. :param margin: the margin to be added in all directions of the bounding box of the molecule. @@ -805,14 +773,13 @@ def set_pymatgen_molecule(self, mol, margin=5): box = [ max(x.coords.tolist()[0] for x in mol.sites) - min(x.coords.tolist()[0] for x in mol.sites) + 2 * margin, max(x.coords.tolist()[1] for x in mol.sites) - min(x.coords.tolist()[1] for x in mol.sites) + 2 * margin, - max(x.coords.tolist()[2] for x in mol.sites) - min(x.coords.tolist()[2] for x in mol.sites) + 2 * margin + max(x.coords.tolist()[2] for x in mol.sites) - min(x.coords.tolist()[2] for x in mol.sites) + 2 * margin, ] self.set_pymatgen_structure(mol.get_boxed_structure(*box)) self.pbc = [False, False, False] def set_pymatgen_structure(self, struct): - """ - Load the structure from a pymatgen Structure object. 
+ """Load the structure from a pymatgen Structure object. .. note:: periodic boundary conditions are set to True in all three directions. @@ -823,8 +790,7 @@ def set_pymatgen_structure(self, struct): """ def build_kind_name(species_and_occu): - """ - Build a kind name from a pymatgen Composition, including an additional ordinal if spin is included, + """Build a kind name from a pymatgen Composition, including an additional ordinal if spin is included, e.g. it returns '1' for an atom with spin < 0 and '2' for an atom with spin > 0, otherwise (no spin) it returns None @@ -845,13 +811,12 @@ def build_kind_name(species_and_occu): else: has_spin = any(specie.as_dict().get('properties', {}).get('spin', 0) != 0 for specie in species) - has_partial_occupancies = (len(occupations) != 1 or occupations[0] != 1.0) + has_partial_occupancies = len(occupations) != 1 or occupations[0] != 1.0 if has_partial_occupancies and has_spin: raise ValueError('Cannot set partial occupancies and spins at the same time') if has_spin: - symbols = [specie.symbol for specie in species] kind_name = create_automatic_kind_name(symbols, occupations) @@ -876,7 +841,6 @@ def build_kind_name(species_and_occu): self.clear_kinds() for site in struct.sites: - species_and_occu = site.species if 'kind_name' in site.properties: @@ -887,7 +851,7 @@ def build_kind_name(species_and_occu): inputs = { 'symbols': [x.symbol for x in species_and_occu.keys()], 'weights': list(species_and_occu.values()), - 'position': site.coords.tolist() + 'position': site.coords.tolist(), } if kind_name is not None: @@ -896,10 +860,7 @@ def build_kind_name(species_and_occu): self.append_atom(**inputs) def _validate(self): - """ - Performs some standard validation tests. - """ - + """Performs some standard validation tests.""" from aiida.common.exceptions import ValidationError super()._validate() @@ -939,16 +900,14 @@ def _validate(self): if site.kind_name not in [k.name for k in kinds]: raise ValidationError(f'A site has kind {site.kind_name}, but no specie with that name exists') - kinds_without_sites = (set(k.name for k in kinds) - set(s.kind_name for s in sites)) + kinds_without_sites = set(k.name for k in kinds) - set(s.kind_name for s in sites) if kinds_without_sites: raise ValidationError( f'The following kinds are defined, but there are no sites with that kind: {list(kinds_without_sites)}' ) - def _prepare_xsf(self, main_file_name=''): # pylint: disable=unused-argument - """ - Write the given structure to a string of format XSF (for XCrySDen). - """ + def _prepare_xsf(self, main_file_name=''): + """Write the given structure to a string of format XSF (for XCrySDen).""" if self.is_alloy or self.has_vacancies: raise NotImplementedError('XSF for alloys or systems with vacancies not implemented.') @@ -967,20 +926,15 @@ def _prepare_xsf(self, main_file_name=''): # pylint: disable=unused-argument return_string += '%18.10f %18.10f %18.10f\n' % tuple(site.position) return return_string.encode('utf-8'), {} - def _prepare_cif(self, main_file_name=''): # pylint: disable=unused-argument - """ - Write the given structure to a string of format CIF. 
- """ + def _prepare_cif(self, main_file_name=''): + """Write the given structure to a string of format CIF.""" from aiida.orm import CifData cif = CifData(ase=self.get_ase()) - return cif._prepare_cif() # pylint: disable=protected-access + return cif._prepare_cif() - def _prepare_chemdoodle(self, main_file_name=''): # pylint: disable=unused-argument - """ - Write the given structure to a string of format required by ChemDoodle. - """ - # pylint: disable=too-many-locals,invalid-name + def _prepare_chemdoodle(self, main_file_name=''): + """Write the given structure to a string of format required by ChemDoodle.""" from itertools import product import numpy as np @@ -1006,23 +960,24 @@ def _prepare_chemdoodle(self, main_file_name=''): # pylint: disable=unused-argu atoms_json = [] # Manual recenter of the structure - center = (lattice_vectors[0] + lattice_vectors[1] + lattice_vectors[2]) / 2. + center = (lattice_vectors[0] + lattice_vectors[1] + lattice_vectors[2]) / 2.0 for ix, iy, iz in product(grid1, grid2, grid3): for base_site in base_sites: - shift = (ix * lattice_vectors[0] + iy * lattice_vectors[1] + \ - iz * lattice_vectors[2] - center).tolist() + shift = (ix * lattice_vectors[0] + iy * lattice_vectors[1] + iz * lattice_vectors[2] - center).tolist() kind_name = base_site['kind_name'] kind_string = self.get_kind(kind_name).get_symbols_string() - atoms_json.append({ - 'l': kind_string, - 'x': base_site['position'][0] + shift[0], - 'y': base_site['position'][1] + shift[1], - 'z': base_site['position'][2] + shift[2], - 'atomic_elements_html': atom_kinds_to_html(kind_string) - }) + atoms_json.append( + { + 'l': kind_string, + 'x': base_site['position'][0] + shift[0], + 'y': base_site['position'][1] + shift[1], + 'z': base_site['position'][2] + shift[2], + 'atomic_elements_html': atom_kinds_to_html(kind_string), + } + ) cell_json = { 't': 'UnitCell', @@ -1041,10 +996,8 @@ def _prepare_chemdoodle(self, main_file_name=''): # pylint: disable=unused-argu return json.dumps(return_dict).encode('utf-8'), {} - def _prepare_xyz(self, main_file_name=''): # pylint: disable=unused-argument - """ - Write the given structure to a string of format XYZ. - """ + def _prepare_xyz(self, main_file_name=''): + """Write the given structure to a string of format XYZ.""" if self.is_alloy or self.has_vacancies: raise NotImplementedError('XYZ for alloys or systems with vacancies not implemented.') @@ -1054,8 +1007,18 @@ def _prepare_xyz(self, main_file_name=''): # pylint: disable=unused-argument return_list = [f'{len(sites)}'] return_list.append( 'Lattice="{} {} {} {} {} {} {} {} {}" pbc="{} {} {}"'.format( - cell[0][0], cell[0][1], cell[0][2], cell[1][0], cell[1][1], cell[1][2], cell[2][0], cell[2][1], - cell[2][2], self.pbc[0], self.pbc[1], self.pbc[2] + cell[0][0], + cell[0][1], + cell[0][2], + cell[1][0], + cell[1][1], + cell[1][2], + cell[2][0], + cell[2][1], + cell[2][2], + self.pbc[0], + self.pbc[1], + self.pbc[2], ) ) for site in sites: @@ -1071,9 +1034,7 @@ def _prepare_xyz(self, main_file_name=''): # pylint: disable=unused-argument return return_string.encode('utf-8'), {} def _parse_xyz(self, inputstring): - """ - Read the structure from a string of format XYZ. 
- """ + """Read the structure from a string of format XYZ.""" from aiida.tools.data.structure import xyz_parser_iterator # idiom to get to the last block @@ -1091,17 +1052,13 @@ def _parse_xyz(self, inputstring): self.append_atom(symbols=sym, position=position) def _adjust_default_cell(self, vacuum_factor=1.0, vacuum_addition=10.0, pbc=(False, False, False)): - """ - If the structure was imported from an xyz file, it lacks a cell. + """If the structure was imported from an xyz file, it lacks a cell. This method will adjust the cell """ - # pylint: disable=invalid-name import numpy as np def get_extremas_from_positions(positions): - """ - returns the minimum and maximum value for each dimension in the positions given - """ + """Returns the minimum and maximum value for each dimension in the positions given""" return list(zip(*[(min(values), max(values)) for values in zip(*positions)])) # Calculating the minimal cell: @@ -1130,8 +1087,7 @@ def get_extremas_from_positions(positions): return self def get_description(self): - """ - Returns a string with infos retrieved from StructureData node's properties + """Returns a string with infos retrieved from StructureData node's properties :param self: the StructureData node :return: retsrt: the description string @@ -1139,8 +1095,7 @@ def get_description(self): return self.get_formula(mode='hill_compact') def get_symbols_set(self): - """ - Return a set containing the names of all elements involved in + """Return a set containing the names of all elements involved in this structure (i.e., for it joins the list of symbols for each kind k in the structure). @@ -1149,8 +1104,7 @@ def get_symbols_set(self): return set(itertools.chain.from_iterable(kind.symbols for kind in self.kinds)) def get_formula(self, mode='hill', separator=''): - """ - Return a string with the chemical formula. + """Return a string with the chemical formula. :param mode: a string to specify how to generate the formula, can assume one of the following values: @@ -1160,7 +1114,7 @@ def get_formula(self, mode='hill', separator=''): first if one or several C atom(s) is (are) present, e.g. ``['C','H','H','H','O','C','H','H','H']`` will return ``'C2H6O'`` ``['S','O','O','H','O','H','O']`` will return ``'H2O4S'`` - From E. A. Hill, J. Am. Chem. Soc., 22 (8), pp 478–494 (1900) + From E. A. Hill, J. Am. Chem. Soc., 22 (8), pp 478-494 (1900) * 'hill_compact': same as hill but the number of atoms for each species is divided by the greatest common divisor of all of them, e.g. @@ -1196,14 +1150,12 @@ def get_formula(self, mode='hill', separator=''): initial order in which the atoms were appended by the user is used to group and/or order the symbols in the formula """ - symbol_list = [self.get_kind(s.kind_name).get_symbols_string() for s in self.sites] return get_formula(symbol_list, mode=mode, separator=separator) def get_site_kindnames(self): - """ - Return a list with length equal to the number of sites of this structure, + """Return a list with length equal to the number of sites of this structure, where each element of the list is the kind name of the corresponding site. .. note:: This is NOT necessarily a list of chemical symbols! Use @@ -1215,8 +1167,7 @@ def get_site_kindnames(self): return [this_site.kind_name for this_site in self.sites] def get_composition(self, mode='full'): - """ - Returns the chemical composition of this structure as a dictionary, + """Returns the chemical composition of this structure as a dictionary, where each key is the kind symbol (e.g. 
H, Li, Ba), and each value is the number of occurrences of that element in this structure. @@ -1230,6 +1181,7 @@ def get_composition(self, mode='full'): :returns: a dictionary with the composition """ import numpy as np + symbols_list = [self.get_kind(s.kind_name).get_symbols_string() for s in self.sites] symbols_set = set(symbols_list) @@ -1247,8 +1199,7 @@ def get_composition(self, mode='full'): raise ValueError(f'mode `{mode}` is invalid, choose from `full`, `reduced` or `fractional`.') def get_ase(self): - """ - Get the ASE object. + """Get the ASE object. Requires the ability to import ase. :return: an ASE object corresponding to this @@ -1261,8 +1212,7 @@ def get_ase(self): return self._get_object_ase() def get_pymatgen(self, **kwargs): - """ - Get pymatgen object. Returns Structure for structures with + """Get pymatgen object. Returns Structure for structures with periodic boundary conditions (in three dimensions) and Molecule otherwise. :param add_spin: True to add the spins to the pymatgen structure. @@ -1280,8 +1230,7 @@ def get_pymatgen(self, **kwargs): return self._get_object_pymatgen(**kwargs) def get_pymatgen_structure(self, **kwargs): - """ - Get the pymatgen Structure object. + """Get the pymatgen Structure object. :param add_spin: True to add the spins to the pymatgen structure. Default is False (no spin added). @@ -1303,8 +1252,7 @@ def get_pymatgen_structure(self, **kwargs): return self._get_object_pymatgen_structure(**kwargs) def get_pymatgen_molecule(self): - """ - Get the pymatgen Molecule object. + """Get the pymatgen Molecule object. .. note:: Requires the pymatgen module (version >= 3.0.13, usage of earlier versions may cause errors). @@ -1316,8 +1264,7 @@ def get_pymatgen_molecule(self): return self._get_object_pymatgen_molecule() def append_kind(self, kind): - """ - Append a kind to the + """Append a kind to the :py:class:`StructureData <aiida.orm.nodes.data.structure.StructureData>`. It makes a copy of the kind. @@ -1339,11 +1286,10 @@ def append_kind(self, kind): if self._internal_kind_tags is None: self._internal_kind_tags = {} - self._internal_kind_tags[len(self.base.attributes.get('kinds')) - 1] = kind._internal_tag # pylint: disable=protected-access + self._internal_kind_tags[len(self.base.attributes.get('kinds')) - 1] = kind._internal_tag def append_site(self, site): - """ - Append a site to the + """Append a site to the :py:class:`StructureData <aiida.orm.nodes.data.structure.StructureData>`. It makes a copy of the site. @@ -1365,8 +1311,7 @@ def append_site(self, site): self.base.attributes.all.setdefault('sites', []).append(new_site.get_raw()) def append_atom(self, **kwargs): - """ - Append an atom to the Structure, taking care of creating the + """Append an atom to the Structure, taking care of creating the corresponding kind. :param ase: the ase Atom object from which we want to create a new atom @@ -1396,14 +1341,11 @@ def append_atom(self, **kwargs): .. note :: checks of equality of species are done using the :py:meth:`~aiida.orm.nodes.data.structure.Kind.compare_with` method.
""" - # pylint: disable=too-many-branches aseatom = kwargs.pop('ase', None) if aseatom is not None: if kwargs: raise ValueError( - "If you pass 'ase' as a parameter to " - 'append_atom, you cannot pass any further' - 'parameter' + "If you pass 'ase' as a parameter to " 'append_atom, you cannot pass any further' 'parameter' ) position = aseatom.position kind = Kind(ase=aseatom) @@ -1423,7 +1365,7 @@ def append_atom(self, **kwargs): exists_already = False for idx, existing_kind in enumerate(_kinds): try: - existing_kind._internal_tag = self._internal_kind_tags[idx] # pylint: disable=protected-access + existing_kind._internal_tag = self._internal_kind_tags[idx] except KeyError: # self._internal_kind_tags does not contain any info for # the kind in position idx: I don't have to add anything @@ -1470,8 +1412,7 @@ def append_atom(self, **kwargs): self.append_site(site) def clear_kinds(self): - """ - Removes all kinds for the StructureData object. + """Removes all kinds for the StructureData object. .. note:: Also clear all sites! """ @@ -1485,9 +1426,7 @@ def clear_kinds(self): self.clear_sites() def clear_sites(self): - """ - Removes all sites for the StructureData object. - """ + """Removes all sites for the StructureData object.""" from aiida.common.exceptions import ModificationNotAllowed if self.is_stored: @@ -1497,9 +1436,7 @@ def clear_sites(self): @property def sites(self): - """ - Returns a list of sites. - """ + """Returns a list of sites.""" try: raw_sites = self.base.attributes.get('sites') except AttributeError: @@ -1508,9 +1445,7 @@ def sites(self): @property def kinds(self): - """ - Returns a list of kinds. - """ + """Returns a list of kinds.""" try: raw_kinds = self.base.attributes.get('kinds') except AttributeError: @@ -1518,8 +1453,7 @@ def kinds(self): return [Kind(raw=i) for i in raw_kinds] def get_kind(self, kind_name): - """ - Return the kind object associated with the given kind name. + """Return the kind object associated with the given kind name. :param kind_name: String, the name of the kind you want to get @@ -1533,7 +1467,7 @@ def get_kind(self, kind_name): try: kinds_dict = self._kinds_cache except AttributeError: - self._kinds_cache = {_.name: _ for _ in self.kinds} # pylint: disable=attribute-defined-outside-init + self._kinds_cache = {_.name: _ for _ in self.kinds} kinds_dict = self._kinds_cache else: kinds_dict = {_.name: _ for _ in self.kinds} @@ -1545,8 +1479,7 @@ def get_kind(self, kind_name): raise ValueError(f"Kind name '{kind_name}' unknown") def get_kind_names(self): - """ - Return a list of kind names (in the same order of the ``self.kinds`` + """Return a list of kind names (in the same order of the ``self.kinds`` property, but return the names rather than Kind objects) .. note:: This is NOT necessarily a list of chemical symbols! Use @@ -1558,8 +1491,7 @@ def get_kind_names(self): @property def cell(self): - """ - Returns the cell shape. + """Returns the cell shape. :return: a 3x3 list of lists. """ @@ -1581,8 +1513,7 @@ def set_cell(self, value): self.base.attributes.set('cell', the_cell) def reset_cell(self, new_cell): - """ - Reset the cell of a structure not yet stored to a new value. + """Reset the cell of a structure not yet stored to a new value. 
:param new_cell: list specifying the cell vectors @@ -1597,8 +1528,7 @@ def reset_cell(self, new_cell): self.base.attributes.set('cell', new_cell) def reset_sites_positions(self, new_positions, conserve_particle=True): - """ - Replace all the Site positions attached to the Structure + """Replace all the Site positions attached to the Structure :param new_positions: list of (3D) positions for every sites. @@ -1620,7 +1550,6 @@ def reset_sites_positions(self, new_positions, conserve_particle=True): if not conserve_particle: raise NotImplementedError else: - # test consistency of th enew input n_sites = len(self.sites) if n_sites != len(new_positions) and conserve_particle: @@ -1648,8 +1577,7 @@ def reset_sites_positions(self, new_positions, conserve_particle=True): @property def pbc(self): - """ - Get the periodic boundary conditions. + """Get the periodic boundary conditions. :return: a tuple of three booleans, each one tells if there are periodic boundary conditions for the i-th real-space direction (i=1,2,3) @@ -1677,9 +1605,7 @@ def set_pbc(self, value): @property def cell_lengths(self): - """ - Get the lengths of cell lattice vectors in angstroms. - """ + """Get the lengths of cell lattice vectors in angstroms.""" import numpy cell = self.cell @@ -1698,15 +1624,14 @@ def set_cell_lengths(self, value): @property def cell_angles(self): - """ - Get the angles between the cell lattice vectors in degrees. - """ + """Get the angles between the cell lattice vectors in degrees.""" import numpy cell = self.cell lengths = self.cell_lengths return [ - float(numpy.arccos(x) / numpy.pi * 180) for x in [ + float(numpy.arccos(x) / numpy.pi * 180) + for x in [ numpy.vdot(cell[1], cell[2]) / lengths[1] / lengths[2], numpy.vdot(cell[0], cell[2]) / lengths[0] / lengths[2], numpy.vdot(cell[0], cell[1]) / lengths[0] / lengths[1], @@ -1737,8 +1662,7 @@ def has_vacancies(self): return any(kind.has_vacancies for kind in self.kinds) def get_cell_volume(self): - """ - Returns the three-dimensional cell volume in Angstrom^3. + """Returns the three-dimensional cell volume in Angstrom^3. Use the `get_dimensionality` method in order to get the area/length of lower-dimensional cells. @@ -1747,8 +1671,7 @@ def get_cell_volume(self): return calc_cell_volume(self.cell) def get_cif(self, converter='ase', store=False, **kwargs): - """ - Creates :py:class:`aiida.orm.nodes.data.cif.CifData`. + """Creates :py:class:`aiida.orm.nodes.data.cif.CifData`. .. versionadded:: 1.0 Renamed from _get_cif @@ -1771,12 +1694,11 @@ def get_cif(self, converter='ase', store=False, **kwargs): return ret_dict['cif'] def _get_object_phonopyatoms(self): - """ - Converts StructureData to PhonopyAtoms + """Converts StructureData to PhonopyAtoms :return: a PhonopyAtoms object """ - from phonopy.structure.atoms import PhonopyAtoms # pylint: disable=import-error,no-name-in-module + from phonopy.structure.atoms import PhonopyAtoms atoms = PhonopyAtoms(symbols=[_.kind_name for _ in self.sites]) # Phonopy internally uses scaled positions, so you must store cell first! 
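# ---------------------------------------------------------------------------
# Illustration: a sketch tying together the cell/pbc accessors reformatted
# above. It assumes a configured AiiDA profile; the 4 Å cubic cell and the
# Si atom are arbitrary example values.
from aiida import load_profile
from aiida.orm import StructureData

load_profile()

structure = StructureData(cell=[[4.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0]])
structure.append_atom(symbols='Si', position=(0.0, 0.0, 0.0))

assert structure.pbc == (True, True, True)             # periodic by default
assert abs(structure.get_cell_volume() - 64.0) < 1e-9  # |a . (b x c)|
assert structure.get_dimensionality() == {'dim': 3, 'label': 'volume', 'value': 64.0}
# ---------------------------------------------------------------------------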
@@ -1786,8 +1708,7 @@ def _get_object_phonopyatoms(self): return atoms def _get_object_ase(self): - """ - Converts + """Converts :py:class:`StructureData <aiida.orm.nodes.data.structure.StructureData>` to ase.Atoms @@ -1803,8 +1724,7 @@ def _get_object_ase(self): return asecell def _get_object_pymatgen(self, **kwargs): - """ - Converts + """Converts :py:class:`StructureData <aiida.orm.nodes.data.structure.StructureData>` to pymatgen object @@ -1820,8 +1740,7 @@ def _get_object_pymatgen(self, **kwargs): return self._get_object_pymatgen_molecule(**kwargs) def _get_object_pymatgen_structure(self, **kwargs): - """ - Converts + """Converts :py:class:`StructureData <aiida.orm.nodes.data.structure.StructureData>` to pymatgen Structure object :param add_spin: True to add the spins to the pymatgen structure. @@ -1851,23 +1770,24 @@ def _get_object_pymatgen_structure(self, **kwargs): species = [] additional_kwargs = {} - if (kwargs.pop('add_spin', False) and any(n.endswith('1') or n.endswith('2') for n in self.get_kind_names())): + if kwargs.pop('add_spin', False) and any(n.endswith('1') or n.endswith('2') for n in self.get_kind_names()): # case when spins are defined -> no partial occupancy allowed from pymatgen.core.periodic_table import Specie + oxidation_state = 0 # now I always set the oxidation_state to zero for site in self.sites: kind = self.get_kind(site.kind_name) - if len(kind.symbols) != 1 or (len(kind.weights) != 1 or sum(kind.weights) < 1.): + if len(kind.symbols) != 1 or (len(kind.weights) != 1 or sum(kind.weights) < 1.0): raise ValueError('Cannot set partial occupancies and spins at the same time') spin = -1 if kind.name.endswith('1') else 1 if kind.name.endswith('2') else 0 try: - specie = Specie(kind.symbols[0], oxidation_state, properties={'spin': spin}) # pylint: disable=unexpected-keyword-arg + specie = Specie(kind.symbols[0], oxidation_state, properties={'spin': spin}) except TypeError: # As of v2023.9.2, the ``properties`` argument is removed and the ``spin`` argument should be used. # See: https://github.com/materialsproject/pymatgen/commit/118c245d6082fe0b13e19d348fc1db9c0d512019 # The ``spin`` argument was introduced in v2023.6.28. # See: https://github.com/materialsproject/pymatgen/commit/9f2b3939af45d5129e0778d371d814811924aeb6 - specie = Specie(kind.symbols[0], oxidation_state, spin=spin) # pylint: disable=unexpected-keyword-arg + specie = Specie(kind.symbols[0], oxidation_state, spin=spin) species.append(specie) else: # case when no spins are defined @@ -1875,8 +1795,8 @@ def _get_object_pymatgen_structure(self, **kwargs): kind = self.get_kind(site.kind_name) species.append(dict(zip(kind.symbols, kind.weights))) if any( - create_automatic_kind_name(self.get_kind(name).symbols, - self.get_kind(name).weights) != name for name in self.get_site_kindnames() + create_automatic_kind_name(self.get_kind(name).symbols, self.get_kind(name).weights) != name + for name in self.get_site_kindnames() ): # add "kind_name" as a property to each site, whenever # the kind_name cannot be automatically obtained from the symbols @@ -1889,8 +1809,7 @@ def _get_object_pymatgen_structure(self, **kwargs): return Structure(self.cell, species, positions, coords_are_cartesian=True, **additional_kwargs) def _get_object_pymatgen_molecule(self, **kwargs): - """ - Converts + """Converts :py:class:`StructureData <aiida.orm.nodes.data.structure.StructureData>` to pymatgen Molecule object @@ -1916,15 +1835,13 @@ def _get_object_pymatgen_molecule(self, **kwargs): class Kind: - """ - This class contains the information about the species (kinds) of the system. + """This class contains the information about the species (kinds) of the system.
It can be a single atom, or an alloy, or even contain vacancies. """ def __init__(self, **kwargs): - """ - Create a site. + """Create a site. One can either pass: :param raw: the raw python dictionary that will be converted to a @@ -1946,7 +1863,6 @@ def __init__(self, **kwargs): :param name: a string that uniquely identifies the kind, and that is used to identify the sites. """ - # pylint: disable=too-many-branches,too-many-statements # Internal variables self._mass = None self._symbols = None @@ -2005,7 +1921,8 @@ def __init__(self, **kwargs): try: import numpy - self.set_symbols_and_weights([aseatom.symbol], [1.]) + + self.set_symbols_and_weights([aseatom.symbol], [1.0]) # ASE sets mass to numpy.nan for unstable species if not numpy.isnan(aseatom.mass): self.mass = aseatom.mass @@ -2043,8 +1960,7 @@ def __init__(self, **kwargs): raise ValueError(f'Unrecognized parameters passed to Kind constructor: {kwargs.keys()}') def get_raw(self): - """ - Return the raw version of the site, mapped to a suitable dictionary. + """Return the raw version of the site, mapped to a suitable dictionary. This is the format that is actually used to store each kind of the structure in the DB. @@ -2058,8 +1974,7 @@ def get_raw(self): } def reset_mass(self): - """ - Reset the mass to the automatic calculated value. + """Reset the mass to the automatic calculated value. The mass can be set manually; by default, if not provided, it is the mass of the constituent atoms, weighted with their @@ -2084,8 +1999,7 @@ def reset_mass(self): @property def name(self): - """ - Return the name of this kind. + """Return the name of this kind. The name of a kind is used to identify the species of a site. :return: a string @@ -2094,14 +2008,11 @@ def name(self): @name.setter def name(self, value): - """ - Set the name of this site (a string). - """ + """Set the name of this site (a string).""" self._name = str(value) def set_automatic_kind_name(self, tag=None): - """ - Set the type to a string obtained with the symbols appended one + """Set the type to a string obtained with the symbols appended one after the other, without spaces, in alphabetical order; if the site has a vacancy, a X is appended at the end too. """ @@ -2112,8 +2023,7 @@ def set_automatic_kind_name(self, tag=None): self.name = f'{name_string}{tag}' def compare_with(self, other_kind): - """ - Compare with another Kind object to check if they are different. + """Compare with another Kind object to check if they are different. .. note:: This does NOT check the 'type' attribute. Instead, it compares (with reasonable thresholds, where applicable): the mass, and the list @@ -2147,12 +2057,8 @@ def compare_with(self, other_kind): if abs(self.mass - other_kind.mass) > _MASS_THRESHOLD: return (False, f'Masses are different ({self.mass} vs. {other_kind.mass})') - if self._internal_tag != other_kind._internal_tag: # pylint: disable=protected-access - return ( - False, - 'Internal tags are different ({} vs. {})' - ''.format(self._internal_tag, other_kind._internal_tag) # pylint: disable=protected-access - ) + if self._internal_tag != other_kind._internal_tag: + return (False, f'Internal tags are different ({self._internal_tag} vs. {other_kind._internal_tag})') # If we got here, the two Site objects are similar enough # to be considered of the same kind @@ -2160,8 +2066,7 @@ def compare_with(self, other_kind): @property def mass(self): - """ - The mass of this species kind. + """The mass of this species kind. 
:return: a float """ @@ -2176,16 +2081,14 @@ def mass(self, value): @property def weights(self): - """ - Weights for this species kind. Refer also to + """Weights for this species kind. Refer also to :func:validate_symbols_tuple for the validation rules on the weights. """ return copy.deepcopy(self._weights) @weights.setter def weights(self, value): - """ - If value is a number, a single weight is used. Otherwise, a list or + """If value is a number, a single weight is used. Otherwise, a list or tuple of numbers is expected. None is also accepted, corresponding to the list [1.]. """ @@ -2193,16 +2096,14 @@ def weights(self, value): if len(weights_tuple) != len(self._symbols): raise ValueError( - 'Cannot change the number of weights. Use the ' - 'set_symbols_and_weights function instead.' + 'Cannot change the number of weights. Use the ' 'set_symbols_and_weights function instead.' ) validate_weights_tuple(weights_tuple, _SUM_THRESHOLD) self._weights = weights_tuple def get_symbols_string(self): - """ - Return a string that tries to match as good as possible the symbols + """Return a string that tries to match as good as possible the symbols of this kind. If there is only one symbol (no alloy) with 100% occupancy, just returns the symbol name. Otherwise, groups the full string in curly brackets, and try to write also the composition @@ -2219,8 +2120,7 @@ def get_symbols_string(self): @property def symbol(self): - """ - If the kind has only one symbol, return it; otherwise, raise a + """If the kind has only one symbol, return it; otherwise, raise a ValueError. """ if len(self._symbols) == 1: @@ -2230,8 +2130,7 @@ def symbol(self): @property def symbols(self): - """ - List of symbols for this site. If the site is a single atom, + """List of symbols for this site. If the site is a single atom, pass a list of one element only, or simply the string for that atom. For alloys, a list of elements. @@ -2242,8 +2141,7 @@ def symbols(self): @symbols.setter def symbols(self, value): - """ - If value is a string, a single symbol is used. Otherwise, a list or + """If value is a string, a single symbol is used. Otherwise, a list or tuple of strings is expected. I set a copy of the list, so to avoid that the content changes @@ -2253,16 +2151,14 @@ def symbols(self, value): if len(symbols_tuple) != len(self._weights): raise ValueError( - 'Cannot change the number of symbols. Use the ' - 'set_symbols_and_weights function instead.' + 'Cannot change the number of symbols. Use the ' 'set_symbols_and_weights function instead.' ) validate_symbols_tuple(symbols_tuple) self._symbols = symbols_tuple def set_symbols_and_weights(self, symbols, weights): - """ - Set the chemical symbols and the weights for the site. + """Set the chemical symbols and the weights for the site. .. note:: Note that the kind name remains unchanged. """ @@ -2294,7 +2190,7 @@ def has_vacancies(self): return has_vacancies(self._weights) def __repr__(self): - return f'<{self.__class__.__name__}: {str(self)}>' + return f'<{self.__class__.__name__}: {self!s}>' def __str__(self): symbol = self.get_symbols_string() @@ -2302,15 +2198,13 @@ def __str__(self): class Site: - """ - This class contains the information about a given site of the system. + """This class contains the information about a given site of the system. It can be a single atom, or an alloy, or even contain vacancies. """ def __init__(self, **kwargs): - """ - Create a site. + """Create a site. :param kind_name: a string that identifies the kind (species) of this site. 
This has to be found in the list of kinds of the StructureData @@ -2351,8 +2245,7 @@ def __init__(self, **kwargs): raise ValueError(f'Unrecognized parameters: {kwargs.keys}') def get_raw(self): - """ - Return the raw version of the site, mapped to a suitable dictionary. + """Return the raw version of the site, mapped to a suitable dictionary. This is the format that is actually used to store each site of the structure in the DB. @@ -2364,15 +2257,13 @@ def get_raw(self): } def get_ase(self, kinds): - """ - Return a ase.Atom object for this site. + """Return a ase.Atom object for this site. :param kinds: the list of kinds from the StructureData object. .. note:: If any site is an alloy or has vacancies, a ValueError is raised (from the site.get_ase() routine). """ - # pylint: disable=too-many-branches from collections import defaultdict import ase @@ -2430,13 +2321,12 @@ def get_ase(self, kinds): raise ValueError('Cannot convert to ASE if the kind represents an alloy or it has vacancies.') aseatom = ase.Atom(position=self.position, symbol=str(kind.symbols[0]), mass=kind.mass) if tag is not None: - aseatom.tag = tag # pylint: disable=assigning-non-slot + aseatom.tag = tag return aseatom @property def kind_name(self): - """ - Return the kind name of this site (a string). + """Return the kind name of this site (a string). The type of a site is used to decide whether two sites are identical (same mass, symbols, weights, ...) or not. @@ -2445,23 +2335,19 @@ def kind_name(self): @kind_name.setter def kind_name(self, value): - """ - Set the type of this site (a string). - """ + """Set the type of this site (a string).""" self._kind_name = str(value) @property def position(self): - """ - Return the position of this site in absolute coordinates, + """Return the position of this site in absolute coordinates, in angstrom. """ return copy.deepcopy(self._position) @position.setter def position(self, value): - """ - Set the position of this site in absolute coordinates, + """Set the position of this site in absolute coordinates, in angstrom. """ try: @@ -2474,22 +2360,20 @@ def position(self, value): self._position = internal_pos def __repr__(self): - return f'<{self.__class__.__name__}: {str(self)}>' + return f'<{self.__class__.__name__}: {self!s}>' def __str__(self): return f"kind name '{self.kind_name}' @ {self.position[0]},{self.position[1]},{self.position[2]}" def _get_dimensionality(pbc, cell): - """ - Return the dimensionality of the structure and its length/surface/volume. + """Return the dimensionality of the structure and its length/surface/volume. Zero-dimensional structures are assigned "volume" 0. :return: returns a dictionary with keys "dim" (dimensionality integer), "label" (dimensionality label) and "value" (numerical length/surface/volume). """ - import numpy as np retdict = {} @@ -2500,7 +2384,7 @@ def _get_dimensionality(pbc, cell): dim = len(pbc[pbc]) retdict['dim'] = dim - retdict['label'] = StructureData._dimensionality_label[dim] # pylint: disable=protected-access + retdict['label'] = StructureData._dimensionality_label[dim] if dim not in (0, 1, 2, 3): raise ValueError(f'Dimensionality {dim} must be one of 0, 1, 2, 3') @@ -2520,9 +2404,7 @@ def _get_dimensionality(pbc, cell): def _validate_dimensionality(pbc, cell): - """ - Check whether the given pbc and cell vectors are consistent. 
- """ + """Check whether the given pbc and cell vectors are consistent.""" dim = _get_dimensionality(pbc, cell) # 0-d structures put no constraints on the cell diff --git a/aiida/orm/nodes/data/upf.py b/aiida/orm/nodes/data/upf.py index f937af8021..378a579215 100644 --- a/aiida/orm/nodes/data/upf.py +++ b/aiida/orm/nodes/data/upf.py @@ -24,20 +24,26 @@ def emit_deprecation(): warn_deprecation( 'The `aiida.orm.nodes.data.upf` module is deprecated. For details how to replace it, please see ' 'https://aiida-pseudo.readthedocs.io/en/latest/howto.html#migrate-from-legacy-upfdata-from-aiida-core.', - version=3 + version=3, ) -REGEX_UPF_VERSION = re.compile(r""" +REGEX_UPF_VERSION = re.compile( + r""" \s*.*)"> - """, re.VERBOSE) + """, + re.VERBOSE, +) -REGEX_ELEMENT_V1 = re.compile(r""" +REGEX_ELEMENT_V1 = re.compile( + r""" (?P[a-zA-Z]{1,2}) \s+ Element - """, re.VERBOSE) + """, + re.VERBOSE, +) REGEX_ELEMENT_V2 = re.compile( r""" @@ -45,7 +51,8 @@ def emit_deprecation(): element\s*=\s*(?P['"])\s* (?P[a-zA-Z]{1,2})\s* (?P=quote_symbol) - """, re.VERBOSE + """, + re.VERBOSE, ) @@ -93,7 +100,6 @@ def upload_upf_family(folder, group_label, group_description, stop_if_existing=T :param stop_if_existing: if True, check for the md5 of the files and, if the file already exists in the DB, raises a MultipleObjectsError. If False, simply adds the existing UPFData node to the group. """ - # pylint: disable=too-many-locals,too-many-branches,too-many-statements import os from aiida import orm @@ -196,15 +202,13 @@ def upload_upf_family(folder, group_label, group_description, stop_if_existing=T def parse_upf(fname, check_filename=True, encoding='utf-8'): - """ - Try to get relevant information from the UPF. For the moment, only the + """Try to get relevant information from the UPF. For the moment, only the element name. Note that even UPF v.2 cannot be parsed with the XML minidom! (e.g. due to the & characters in the human-readable section). If check_filename is True, raise a ParsingError exception if the filename does not start with the element name. """ - # pylint: disable=too-many-branches import os from aiida.common import AIIDA_LOGGER @@ -260,9 +264,9 @@ def parse_upf(fname, check_filename=True, encoding='utf-8'): if check_filename: if not os.path.basename(fname).lower().startswith(element.lower()): raise ParsingError( - 'Filename {0} was recognized for element ' - '{1}, but the filename does not start ' - 'with {1}'.format(fname, element) + 'Filename {0} was recognized for element ' '{1}, but the filename does not start ' 'with {1}'.format( + fname, element + ) ) parsed_data['element'] = element @@ -316,7 +320,7 @@ def __init__(self, *args, **kwargs): emit_deprecation() super().__init__(*args, **kwargs) - def store(self, *args, **kwargs): # pylint: disable=signature-differs + def store(self, *args, **kwargs): """Store the node, reparsing the file so that the md5 and the element are correctly reset.""" from aiida.common.exceptions import ParsingError from aiida.common.files import md5_from_filelike @@ -356,6 +360,7 @@ def from_md5(cls, md5, backend=None): :return: list of existing `UpfData` nodes that have the same md5 hash """ from aiida.orm.querybuilder import QueryBuilder + builder = QueryBuilder(backend=backend) builder.append(cls, filters={'attributes.md5': {'==': md5}}) return builder.all(flat=True) @@ -367,7 +372,6 @@ def set_file(self, file, filename=None): Hint: Pass io.BytesIO(b"my string") to construct the file directly from a string. 
:param filename: specify filename to use (defaults to name of provided file). """ - # pylint: disable=redefined-builtin from aiida.common.exceptions import ParsingError from aiida.common.files import md5_file, md5_from_filelike @@ -454,10 +458,7 @@ def _validate(self): raise ValidationError(f"Attribute 'md5' says '{attr_md5}' but '{md5}' was parsed instead.") def _prepare_upf(self, main_file_name=''): - """ - Return UPF content. - """ - # pylint: disable=unused-argument + """Return UPF content.""" return_string = self.get_content() return return_string.encode('utf-8'), {} @@ -502,11 +503,8 @@ def get_upf_groups(cls, filter_elements=None, user=None, backend=None): return builder.all(flat=True) - # pylint: disable=unused-argument def _prepare_json(self, main_file_name=''): - """ - Returns UPF PP in json format. - """ + """Returns UPF PP in json format.""" with self.open() as file_handle: upf_json = upf_to_json(file_handle.read(), fname=self.filename) return json.dumps(upf_json).encode('utf-8'), {} diff --git a/aiida/orm/nodes/links.py b/aiida/orm/nodes/links.py index 46e157eadf..49ceee793a 100644 --- a/aiida/orm/nodes/links.py +++ b/aiida/orm/nodes/links.py @@ -14,7 +14,7 @@ from ..utils.links import LinkManager, LinkTriple if t.TYPE_CHECKING: - from .node import Node # pylint: disable=unused-import + from .node import Node class NodeLinks: @@ -79,19 +79,21 @@ def validate_incoming(self, source: 'Node', link_type: LinkType, link_label: str """ from aiida.orm.utils.links import validate_link - from .node import Node # pylint: disable=redefined-outer-name + from .node import Node validate_link(source, self._node, link_type, link_label, backend=self._node.backend) # Check if the proposed link would introduce a cycle in the graph following ancestor/descendant rules if link_type in [LinkType.CREATE, LinkType.INPUT_CALC, LinkType.INPUT_WORK]: - builder = QueryBuilder(backend=self._node.backend).append( - Node, filters={'id': self._node.pk}, tag='parent').append( - Node, filters={'id': source.pk}, tag='child', with_ancestors='parent') # yapf:disable + builder = ( + QueryBuilder(backend=self._node.backend) + .append(Node, filters={'id': self._node.pk}, tag='parent') + .append(Node, filters={'id': source.pk}, tag='child', with_ancestors='parent') + ) if builder.count() > 0: raise ValueError('the link you are attempting to create would generate a cycle in the graph') - def validate_outgoing(self, target: 'Node', link_type: LinkType, link_label: str) -> None: # pylint: disable=unused-argument + def validate_outgoing(self, target: 'Node', link_type: LinkType, link_label: str) -> None: """Validate adding a link of the given type from ourself to a given node. The validity of the triple (source, link, target) should be validated in the `validate_incoming` call. 
@@ -104,7 +106,8 @@ def validate_outgoing(self, target: 'Node', link_type: LinkType, link_label: str :raise TypeError: if `target` is not a Node instance or `link_type` is not a `LinkType` enum :raise ValueError: if the proposed link is invalid """ - from .node import Node # pylint: disable=redefined-outer-name + from .node import Node + type_check(link_type, LinkType, f'link_type should be a LinkType enum but got: {type(link_type)}') type_check(target, Node, f'target should be a `Node` instance but got: {type(target)}') @@ -114,7 +117,7 @@ def get_stored_link_triples( link_type: t.Union[LinkType, t.Sequence[LinkType]] = (), link_label_filter: t.Optional[str] = None, link_direction: str = 'incoming', - only_uuid: bool = False + only_uuid: bool = False, ) -> list[LinkTriple]: """Return the list of stored link triples directly incoming to or outgoing of this node. @@ -127,7 +130,7 @@ def get_stored_link_triples( :param link_direction: `incoming` or `outgoing` to get the incoming or outgoing links, respectively. :param only_uuid: project only the node UUID instead of the instance onto the `NodeTriple.node` entries """ - from .node import Node # pylint: disable=redefined-outer-name + from .node import Node if not isinstance(link_type, (tuple, list)): link_type = cast(t.Sequence[LinkType], (link_type,)) @@ -155,7 +158,7 @@ def get_stored_link_triples( with_incoming='main', project=node_project, edge_project=['type', 'label'], - edge_filters=edge_filters + edge_filters=edge_filters, ) else: builder.append( @@ -163,7 +166,7 @@ def get_stored_link_triples( with_outgoing='main', project=node_project, edge_project=['type', 'label'], - edge_filters=edge_filters + edge_filters=edge_filters, ) return [LinkTriple(entry[0], LinkType(entry[1]), entry[2]) for entry in builder.all()] @@ -173,7 +176,7 @@ def get_incoming( node_class: Optional[t.Type['Node']] = None, link_type: t.Union[LinkType, t.Sequence[LinkType]] = (), link_label_filter: t.Optional[str] = None, - only_uuid: bool = False + only_uuid: bool = False, ) -> LinkManager: """Return a list of link triples that are (directly) incoming into this node. @@ -197,9 +200,8 @@ def get_incoming( # Get all cached link triples for link_triple in self.incoming_cache: - if only_uuid: - link_triple = LinkTriple( + link_triple = LinkTriple( # noqa: PLW2901 link_triple.node.uuid, # type: ignore[arg-type] link_triple.link_type, link_triple.link_label, @@ -224,7 +226,7 @@ def get_outgoing( node_class: Optional[t.Type['Node']] = None, link_type: t.Union[LinkType, t.Sequence[LinkType]] = (), link_label_filter: t.Optional[str] = None, - only_uuid: bool = False + only_uuid: bool = False, ) -> LinkManager: """Return a list of link triples that are (directly) outgoing of this node. 
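A minimal usage sketch of the link-traversal API refactored above, assuming a configured AiiDA profile; the pk value is hypothetical:

from aiida import load_profile, orm
from aiida.common.links import LinkType

load_profile()  # assumes a configured default profile

node = orm.load_node(1234)  # hypothetical pk of a stored calculation node

# Iterate over the incoming link triples, filtering on the input link types.
for triple in node.base.links.get_incoming(link_type=(LinkType.INPUT_CALC, LinkType.INPUT_WORK)):
    print(triple.link_label, triple.link_type, triple.node.uuid)

# With only_uuid=True, each triple carries just the node UUID instead of the full instance.
for triple in node.base.links.get_incoming(only_uuid=True):
    print(triple.node)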
diff --git a/aiida/orm/nodes/node.py b/aiida/orm/nodes/node.py index f7f951c987..fa00a319f6 100644 --- a/aiida/orm/nodes/node.py +++ b/aiida/orm/nodes/node.py @@ -7,7 +7,6 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=too-many-arguments """Package for node ORM classes.""" import datetime from functools import cached_property @@ -20,7 +19,7 @@ from aiida.common.links import LinkType from aiida.common.warnings import warn_deprecation from aiida.manage import get_manager -from aiida.orm.utils.node import ( # pylint: disable=unused-import +from aiida.orm.utils.node import ( AbstractNodeMeta, get_query_type_from_type_string, get_type_string_from_class, @@ -41,11 +40,12 @@ if TYPE_CHECKING: from importlib_metadata import EntryPoint - from ..implementation import BackendNode, StorageBackend + from ..implementation import StorageBackend + from ..implementation.nodes import BackendNode # noqa: F401 __all__ = ('Node',) -NodeType = TypeVar('NodeType', bound='Node') # pylint: disable=invalid-name +NodeType = TypeVar('NodeType', bound='Node') class NodeCollection(EntityCollection[NodeType], Generic[NodeType]): @@ -73,10 +73,9 @@ def delete(self, pk: int) -> None: self._backend.nodes.delete(pk) - def iter_repo_keys(self, - filters: Optional[dict] = None, - subclassing: bool = True, - batch_size: int = 100) -> Iterator[str]: + def iter_repo_keys( + self, filters: Optional[dict] = None, subclassing: bool = True, batch_size: int = 100 + ) -> Iterator[str]: """Iterate over all repository object keys for this ``Node`` class .. note:: keys will not be deduplicated, wrap in a ``set`` to achieve this @@ -86,9 +85,10 @@ def iter_repo_keys(self, :param batch_size: The number of nodes to fetch data for at once """ from aiida.repository import Repository + query = QueryBuilder(backend=self.backend) query.append(self.entity_type, subclassing=subclassing, filters=filters, project=['repository_metadata']) - for metadata, in query.iterall(batch_size=batch_size): + for (metadata,) in query.iterall(batch_size=batch_size): for key in Repository.flatten(metadata).values(): if key is not None: yield key @@ -109,7 +109,7 @@ def repository(self) -> 'NodeRepository': @cached_property def caching(self) -> 'NodeCaching': """Return an interface to interact with the caching of this node.""" - return self._node._CLS_NODE_CACHING(self._node) # pylint: disable=protected-access + return self._node._CLS_NODE_CACHING(self._node) @cached_property def comments(self) -> 'NodeComments': @@ -129,12 +129,11 @@ def extras(self) -> 'EntityExtras': @cached_property def links(self) -> 'NodeLinks': """Return an interface to interact with the links of this node.""" - return self._node._CLS_NODE_LINKS(self._node) # pylint: disable=protected-access + return self._node._CLS_NODE_LINKS(self._node) class Node(Entity['BackendNode', NodeCollection], metaclass=AbstractNodeMeta): - """ - Base class for all nodes in AiiDA. + """Base class for all nodes in AiiDA. Stores attributes starting with an underscore. @@ -147,7 +146,6 @@ class Node(Entity['BackendNode', NodeCollection], metaclass=AbstractNodeMeta): In the plugin, also set the _plugin_type_string, to be set in the DB in the 'type' field. 
""" - # pylint: disable=too-many-public-methods _CLS_COLLECTION = NodeCollection _CLS_NODE_LINKS = NodeLinks @@ -157,14 +155,14 @@ class Node(Entity['BackendNode', NodeCollection], metaclass=AbstractNodeMeta): __query_type_string: ClassVar[str] @classproperty - def _plugin_type_string(cls) -> str: + def _plugin_type_string(cls) -> str: # noqa: N805 """Return the plugin type string of this node class.""" if not hasattr(cls, '__plugin_type_string'): cls.__plugin_type_string = get_type_string_from_class(cls.__module__, cls.__name__) # type: ignore[misc] return cls.__plugin_type_string @classproperty - def _query_type_string(cls) -> str: + def _query_type_string(cls) -> str: # noqa: N805 """Return the query type string of this node class.""" if not hasattr(cls, '__query_type_string'): cls.__query_type_string = get_query_type_from_type_string(cls._plugin_type_string) # type: ignore[misc] @@ -192,7 +190,7 @@ def __init__( backend: Optional['StorageBackend'] = None, user: Optional[User] = None, computer: Optional[Computer] = None, - **kwargs: Any + **kwargs: Any, ) -> None: backend = backend or get_manager().get_profile_storage() @@ -215,7 +213,7 @@ def base(self) -> NodeBase: """Return the node base namespace.""" return NodeBase(self) - def _check_mutability_attributes(self, keys: Optional[List[str]] = None) -> None: # pylint: disable=unused-argument + def _check_mutability_attributes(self, keys: Optional[List[str]] = None) -> None: """Check if the entity is mutable and raise an exception if not. This is called from `NodeAttributes` methods that modify the attributes. @@ -236,7 +234,7 @@ def __hash__(self) -> int: return UUID(self.uuid).int def __repr__(self) -> str: - return f'<{self.__class__.__name__}: {str(self)}>' + return f'<{self.__class__.__name__}: {self!s}>' def __str__(self) -> str: if not self.is_stored: @@ -283,19 +281,18 @@ def _validate_storability(self) -> None: ) @classproperty - def class_node_type(cls) -> str: + def class_node_type(cls) -> str: # noqa: N805 """Returns the node type of this node (sub) class.""" - # pylint: disable=no-self-argument,no-member return cls._plugin_type_string @classproperty - def entry_point(cls) -> Optional['EntryPoint']: + def entry_point(cls) -> Optional['EntryPoint']: # noqa: N805 """Return the entry point associated this node class. :return: the associated entry point or ``None`` if it isn't known. """ - # pylint: disable=no-self-argument from aiida.plugins.entry_point import get_entry_point_from_class + return get_entry_point_from_class(cls.__module__, cls.__name__)[1] @property @@ -434,7 +431,7 @@ def store_all(self) -> 'Node': # For each node of a cached incoming link, check that all its incoming links are stored for link_triple in self.base.links.incoming_cache: - link_triple.node._verify_are_parents_stored() # pylint: disable=protected-access + link_triple.node._verify_are_parents_stored() for link_triple in self.base.links.incoming_cache: if not link_triple.node.is_stored: @@ -442,7 +439,7 @@ def store_all(self) -> 'Node': return self.store() - def store(self) -> 'Node': # pylint: disable=arguments-differ + def store(self) -> 'Node': """Store the node in the database while saving its attributes and repository directory. After being called attributes cannot be changed anymore! 
Instead, extras can be changed only AFTER calling @@ -454,7 +451,6 @@ def store(self) -> 'Node': # pylint: disable=arguments-differ from aiida.manage.caching import get_use_cache if not self.is_stored: - # Call `_validate_storability` directly and not in `_validate` in case sub class forgets to call the super. self._validate_storability() self._validate() @@ -470,7 +466,7 @@ def store(self) -> 'Node': # pylint: disable=arguments-differ self._backend_entity.clean_values() # Retrieve the cached node. - same_node = self.base.caching._get_same_node() if use_cache else None # pylint: disable=protected-access + same_node = self.base.caching._get_same_node() if use_cache else None if same_node is not None: self._store_from_cache(same_node) @@ -488,7 +484,7 @@ def _store(self, clean: bool = True) -> 'Node': :param clean: boolean, if True, will clean the attributes and extras before attempting to store """ - self.base.repository._store() # pylint: disable=protected-access + self.base.repository._store() links = self.base.links.incoming_cache self._backend_entity.store(links, clean=clean) @@ -521,6 +517,7 @@ def _store_from_cache(self, cache_node: 'Node') -> None: """ from aiida.orm.utils.mixins import Sealable + assert self.node_type == cache_node.node_type # Make sure the node doesn't have any RETURN links @@ -531,7 +528,7 @@ def _store_from_cache(self, cache_node: 'Node') -> None: self.description = cache_node.description # Make sure to reinitialize the repository instance of the clone to that of the source node. - self.base.repository._copy(cache_node.base.repository) # pylint: disable=protected-access + self.base.repository._copy(cache_node.base.repository) for key, value in cache_node.base.attributes.all.items(): if key != Sealable.SEALED_KEY: @@ -567,7 +564,7 @@ def is_valid_cache(self) -> bool: warn_deprecation( f'`{kls}.is_valid_cache` is deprecated, use `{kls}.base.caching.is_valid_cache` instead.', version=3, - stacklevel=2 + stacklevel=2, ) return self.base.caching.is_valid_cache @@ -584,7 +581,7 @@ def is_valid_cache(self, valid: bool) -> None: warn_deprecation( f'`{kls}.is_valid_cache` is deprecated, use `{kls}.base.caching.is_valid_cache` instead.', version=3, - stacklevel=2 + stacklevel=2, ) self.base.caching.is_valid_cache = valid @@ -663,7 +660,7 @@ def is_valid_cache(self, valid: bool) -> None: } @classproperty - def Collection(cls): # pylint: disable=invalid-name + def Collection(cls): # noqa: N802, N805 """Return the collection type for this class. This used to be a class argument with the value ``NodeCollection``. 
The argument is deprecated and this property @@ -693,7 +690,7 @@ def __getattr__(self, name: str) -> Any: warn_deprecation( f'`{kls}.{name}` is deprecated, use `{kls}.base.attributes.{new_name}` instead.', version=3, - stacklevel=3 + stacklevel=3, ) return getattr(self.base.attributes, new_name) @@ -703,7 +700,7 @@ def __getattr__(self, name: str) -> Any: warn_deprecation( f'`{kls}.{name}` is deprecated, use `{kls}.base.repository.{new_name}` instead.', version=3, - stacklevel=3 + stacklevel=3, ) return getattr(self.base.repository, new_name) diff --git a/aiida/orm/nodes/process/__init__.py b/aiida/orm/nodes/process/__init__.py index 283b14e9b0..e12d4295b4 100644 --- a/aiida/orm/nodes/process/__init__.py +++ b/aiida/orm/nodes/process/__init__.py @@ -11,8 +11,7 @@ # AUTO-GENERATED -# yapf: disable -# pylint: disable=wildcard-import +# fmt: off from .calculation import * from .process import * @@ -28,4 +27,4 @@ 'WorkflowNode', ) -# yapf: enable +# fmt: on diff --git a/aiida/orm/nodes/process/calculation/__init__.py b/aiida/orm/nodes/process/calculation/__init__.py index 21af4e576e..f7712d6cbb 100644 --- a/aiida/orm/nodes/process/calculation/__init__.py +++ b/aiida/orm/nodes/process/calculation/__init__.py @@ -11,8 +11,7 @@ # AUTO-GENERATED -# yapf: disable -# pylint: disable=wildcard-import +# fmt: off from .calcfunction import * from .calcjob import * @@ -24,4 +23,4 @@ 'CalculationNode', ) -# yapf: enable +# fmt: on diff --git a/aiida/orm/nodes/process/calculation/calcjob.py b/aiida/orm/nodes/process/calculation/calcjob.py index ba94a23683..b37e034862 100644 --- a/aiida/orm/nodes/process/calculation/calcjob.py +++ b/aiida/orm/nodes/process/calculation/calcjob.py @@ -20,7 +20,6 @@ from .calculation import CalculationNode if TYPE_CHECKING: - from aiida.engine.processes.builder import ProcessBuilder from aiida.orm import FolderData from aiida.orm.authinfos import AuthInfo from aiida.orm.utils.calcjob import CalcJobResultManager @@ -45,19 +44,20 @@ def _get_objects_to_hash(self) -> List[Any]: has started and the input files have been written to the repository. 
""" from importlib import import_module + objects = [ import_module(self._node.__module__.split('.', 1)[0]).__version__, { key: val for key, val in self._node.base.attributes.items() - if key not in self._node._hash_ignored_attributes and key not in self._node._updatable_attributes # pylint: disable=unsupported-membership-test,protected-access + if key not in self._node._hash_ignored_attributes and key not in self._node._updatable_attributes }, - self._node.computer.uuid if self._node.computer is not None else None, # pylint: disable=no-member + self._node.computer.uuid if self._node.computer is not None else None, { entry.link_label: entry.node.base.caching.get_hash() for entry in self._node.base.links.get_incoming(link_type=(LinkType.INPUT_CALC, LinkType.INPUT_WORK)) if entry.link_label not in self._hash_ignored_inputs - } + }, ] return objects @@ -65,7 +65,6 @@ def _get_objects_to_hash(self) -> List[Any]: class CalcJobNode(CalculationNode): """ORM class for all nodes representing the execution of a CalcJob.""" - # pylint: disable=too-many-public-methods _CLS_NODE_CACHING = CalcJobNodeCaching CALC_JOB_STATE_KEY = 'state' @@ -114,7 +113,7 @@ def tools(self) -> 'CalculationTools': return self._tools @classproperty - def _updatable_attributes(cls) -> Tuple[str, ...]: # pylint: disable=no-self-argument + def _updatable_attributes(cls) -> Tuple[str, ...]: # noqa: N805 return super()._updatable_attributes + ( cls.CALC_JOB_STATE_KEY, cls.IMMIGRATED_KEY, @@ -129,7 +128,7 @@ def _updatable_attributes(cls) -> Tuple[str, ...]: # pylint: disable=no-self-ar ) @classproperty - def _hash_ignored_attributes(cls) -> Tuple[str, ...]: # pylint: disable=no-self-argument + def _hash_ignored_attributes(cls) -> Tuple[str, ...]: # noqa: N805 return super()._hash_ignored_attributes + ( 'queue_name', 'account', @@ -145,8 +144,7 @@ def is_imported(self) -> bool: return self.base.attributes.get(self.IMMIGRATED_KEY, None) is True def get_option(self, name: str) -> Optional[Any]: - """ - Retun the value of an option that was set for this CalcJobNode + """Retun the value of an option that was set for this CalcJobNode :param name: the option name :return: the option value or None @@ -155,8 +153,7 @@ def get_option(self, name: str) -> Optional[Any]: return self.base.attributes.get(name, None) def set_option(self, name: str, value: Any) -> None: - """ - Set an option to the given value + """Set an option to the given value :param name: the option name :param value: the value to set @@ -166,8 +163,7 @@ def set_option(self, name: str, value: Any) -> None: self.base.attributes.set(name, value) def get_options(self) -> Dict[str, Any]: - """ - Return the dictionary of options set for this CalcJobNode + """Return the dictionary of options set for this CalcJobNode :return: dictionary of the options and their values """ @@ -180,8 +176,7 @@ def get_options(self) -> Dict[str, Any]: return options def set_options(self, options: Dict[str, Any]) -> None: - """ - Set the options for this CalcJobNode + """Set the options for this CalcJobNode :param options: dictionary of option and their values to set """ @@ -249,7 +244,6 @@ def _validate_retrieval_directive(directives: Sequence[Union[str, Tuple[str, str raise TypeError('file retrieval directives has to be a list or tuple') for directive in directives: - # A string as a directive is valid, so we continue if isinstance(directive, str): continue @@ -414,7 +408,7 @@ def get_authinfo(self) -> 'AuthInfo': if computer is None: raise exceptions.NotExistent('No computer has been set for this 
calculation') - return computer.get_authinfo(self.user) # pylint: disable=no-member + return computer.get_authinfo(self.user) def get_transport(self) -> 'Transport': """Return the transport for this calculation. @@ -449,16 +443,19 @@ def get_retrieved_node(self) -> Optional['FolderData']: :return: the retrieved FolderData node or None if not found """ from aiida.orm import FolderData + try: - return self.base.links.get_outgoing(node_class=FolderData, - link_label_filter=self.link_label_retrieved).one().node + return ( + self.base.links.get_outgoing(node_class=FolderData, link_label_filter=self.link_label_retrieved) + .one() + .node + ) except ValueError: return None @property def res(self) -> 'CalcJobResultManager': - """ - To be used to get direct access to the parsed parameters. + """To be used to get direct access to the parsed parameters. :return: an instance of the CalcJobResultManager. @@ -467,6 +464,7 @@ def res(self) -> 'CalcJobResultManager': The command `calc.res.energy` will return such a list. """ from aiida.orm.utils.calcjob import CalcJobResultManager + return CalcJobResultManager(self) def get_scheduler_stdout(self) -> Optional[AnyStr]: diff --git a/aiida/orm/nodes/process/process.py b/aiida/orm/nodes/process/process.py index fa7a59ea9b..ec1c20a699 100644 --- a/aiida/orm/nodes/process/process.py +++ b/aiida/orm/nodes/process/process.py @@ -68,15 +68,15 @@ def is_valid_cache(self, valid: bool) -> None: super(ProcessNodeCaching, self.__class__).is_valid_cache.fset(self, valid) def _get_objects_to_hash(self) -> List[Any]: - """ - Return a list of objects which should be included in the hash. - """ - res = super()._get_objects_to_hash() # pylint: disable=protected-access - res.append({ - entry.link_label: entry.node.base.caching.get_hash() - for entry in self._node.base.links.get_incoming(link_type=(LinkType.INPUT_CALC, LinkType.INPUT_WORK)) - if entry.link_label not in self._hash_ignored_inputs - }) + """Return a list of objects which should be included in the hash.""" + res = super()._get_objects_to_hash() + res.append( + { + entry.link_label: entry.node.base.caching.get_hash() + for entry in self._node.base.links.get_incoming(link_type=(LinkType.INPUT_CALC, LinkType.INPUT_WORK)) + if entry.link_label not in self._hash_ignored_inputs + } + ) return res @@ -121,8 +121,7 @@ def validate_outgoing(self, target, link_type, link_label): class ProcessNode(Sealable, Node): - """ - Base class for all nodes representing the execution of a process + """Base class for all nodes representing the execution of a process This class and its subclasses serve as proxies in the database, for actual `Process` instances being run. The `Process` instance in memory will leverage an instance of this class (the exact sub class depends on the sub class @@ -130,7 +129,6 @@ class ProcessNode(Sealable, Node): inspect the state of the `Process` during its execution as well as a permanent record of its execution in the provenance graph, after the execution has terminated. 
""" - # pylint: disable=too-many-public-methods,abstract-method _CLS_NODE_LINKS = ProcessNodeLinks _CLS_NODE_CACHING = ProcessNodeCaching @@ -155,8 +153,7 @@ def __str__(self) -> str: return f'{base}' @classproperty - def _updatable_attributes(cls) -> Tuple[str, ...]: - # pylint: disable=no-self-argument + def _updatable_attributes(cls) -> Tuple[str, ...]: # noqa: N805 return super()._updatable_attributes + ( cls.PROCESS_PAUSED_KEY, cls.CHECKPOINT_KEY, @@ -178,12 +175,12 @@ def get_metadata_inputs(self) -> Optional[Dict[str, Any]]: @property def logger(self): - """ - Get the logger of the Calculation object, so that it also logs to the DB. + """Get the logger of the Calculation object, so that it also logs to the DB. :return: LoggerAdapter object, that works like a logger, but also has the 'extra' embedded """ from aiida.orm.utils.log import create_logger_adapter + return create_logger_adapter(self._logger, self) def get_builder_restart(self) -> 'ProcessBuilder': @@ -196,8 +193,8 @@ def get_builder_restart(self) -> 'ProcessBuilder': :return: `~aiida.engine.processes.builder.ProcessBuilder` instance """ builder = self.process_class.get_builder() - builder._update(self.base.links.get_incoming(link_type=(LinkType.INPUT_CALC, LinkType.INPUT_WORK)).nested()) # pylint: disable=protected-access - builder._merge(self.get_metadata_inputs() or {}) # pylint: disable=protected-access + builder._update(self.base.links.get_incoming(link_type=(LinkType.INPUT_CALC, LinkType.INPUT_WORK)).nested()) + builder._merge(self.get_metadata_inputs() or {}) return builder @@ -245,8 +242,7 @@ def str_rsplit_iter(string, sep='.'): return process_class def set_process_type(self, process_type_string: str) -> None: - """ - Set the process type string. + """Set the process type string. :param process_type: the process type string identifying the class using this process node as storage. """ @@ -254,16 +250,14 @@ def set_process_type(self, process_type_string: str) -> None: @property def process_label(self) -> Optional[str]: - """ - Return the process label + """Return the process label :returns: the process label """ return self.base.attributes.get(self.PROCESS_LABEL_KEY, None) def set_process_label(self, label: str) -> None: - """ - Set the process label + """Set the process label :param label: process label string """ @@ -271,8 +265,7 @@ def set_process_label(self, label: str) -> None: @property def process_state(self) -> Optional[ProcessState]: - """ - Return the process state + """Return the process state :returns: the process state instance of ProcessState enum """ @@ -284,8 +277,7 @@ def process_state(self) -> Optional[ProcessState]: return ProcessState(state) def set_process_state(self, state: Union[str, ProcessState, None]): - """ - Set the process state + """Set the process state :param state: value or instance of ProcessState enum """ @@ -295,8 +287,7 @@ def set_process_state(self, state: Union[str, ProcessState, None]): @property def process_status(self) -> Optional[str]: - """ - Return the process status + """Return the process status The process status is a generic status message e.g. the reason it might be paused or when it is being killed @@ -305,8 +296,7 @@ def process_status(self) -> Optional[str]: return self.base.attributes.get(self.PROCESS_STATUS_KEY, None) def set_process_status(self, status: Optional[str]) -> None: - """ - Set the process status + """Set the process status The process status is a generic status message e.g. the reason it might be paused or when it is being killed. 
If status is None, the corresponding attribute will be deleted. @@ -327,8 +317,7 @@ def set_process_status(self, status: Optional[str]) -> None: @property def is_terminated(self) -> bool: - """ - Return whether the process has terminated + """Return whether the process has terminated Terminated means that the process has reached any terminal state. @@ -339,8 +328,7 @@ def is_terminated(self) -> bool: @property def is_excepted(self) -> bool: - """ - Return whether the process has excepted + """Return whether the process has excepted Excepted means that during execution of the process, an exception was raised that was not caught. @@ -351,8 +339,7 @@ def is_excepted(self) -> bool: @property def is_killed(self) -> bool: - """ - Return whether the process was killed + """Return whether the process was killed Killed means the process was killed directly by the user or by the calling process being killed. @@ -363,8 +350,7 @@ def is_killed(self) -> bool: @property def is_finished(self) -> bool: - """ - Return whether the process has finished + """Return whether the process has finished Finished means that the process reached a terminal state nominally. Note that this does not necessarily mean successfully, but there were no exceptions and it was not killed. @@ -376,8 +362,7 @@ def is_finished(self) -> bool: @property def is_finished_ok(self) -> bool: - """ - Return whether the process has finished successfully + """Return whether the process has finished successfully Finished successfully means that it terminated nominally and had a zero exit status. @@ -388,8 +373,7 @@ def is_finished_ok(self) -> bool: @property def is_failed(self) -> bool: - """ - Return whether the process has failed + """Return whether the process has failed Failed means that the process terminated nominally but it had a non-zero exit status. @@ -418,16 +402,14 @@ def exit_code(self) -> Optional['ExitCode']: @property def exit_status(self) -> Optional[int]: - """ - Return the exit status of the process + """Return the exit status of the process :returns: the exit status, an integer exit code or None """ return self.base.attributes.get(self.EXIT_STATUS_KEY, None) def set_exit_status(self, status: Union[None, enum.Enum, int]) -> None: - """ - Set the exit status of the process + """Set the exit status of the process :param state: an integer exit code or None, which will be interpreted as zero """ @@ -444,16 +426,14 @@ def set_exit_status(self, status: Union[None, enum.Enum, int]) -> None: @property def exit_message(self) -> Optional[str]: - """ - Return the exit message of the process + """Return the exit message of the process :returns: the exit message """ return self.base.attributes.get(self.EXIT_MESSAGE_KEY, None) def set_exit_message(self, message: Optional[str]) -> None: - """ - Set the exit message of the process, if None nothing will be done + """Set the exit message of the process, if None nothing will be done :param message: a string message """ @@ -467,8 +447,7 @@ def set_exit_message(self, message: Optional[str]) -> None: @property def exception(self) -> Optional[str]: - """ - Return the exception of the process or None if the process is not excepted. + """Return the exception of the process or None if the process is not excepted. If the process is marked as excepted yet there is no exception attribute, an empty string will be returned. 
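A minimal sketch of how the terminal-state properties touched in these hunks are read in practice, assuming a configured AiiDA profile; the pk is hypothetical:

from aiida import load_profile, orm

load_profile()  # assumes a configured default profile

node = orm.load_node(4321)  # hypothetical pk of a terminated ProcessNode

if node.is_finished_ok:
    print('finished nominally with exit status 0')
elif node.is_failed:
    # exit_status and exit_message are read from the node attributes
    print(f'failed with exit status {node.exit_status}: {node.exit_message}')
elif node.is_excepted:
    print(f'excepted: {node.exception}')
elif node.is_killed:
    print(f'killed, status message: {node.process_status}')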
@@ -480,8 +459,7 @@ def exception(self) -> Optional[str]: return None def set_exception(self, exception: str) -> None: - """ - Set the exception of the process + """Set the exception of the process :param exception: the exception message """ @@ -492,25 +470,21 @@ def set_exception(self, exception: str) -> None: @property def checkpoint(self) -> Optional[str]: - """ - Return the checkpoint bundle set for the process + """Return the checkpoint bundle set for the process :returns: checkpoint bundle if it exists, None otherwise """ return self.base.attributes.get(self.CHECKPOINT_KEY, None) def set_checkpoint(self, checkpoint: str) -> None: - """ - Set the checkpoint bundle set for the process + """Set the checkpoint bundle set for the process :param state: string representation of the stepper state info """ return self.base.attributes.set(self.CHECKPOINT_KEY, checkpoint) def delete_checkpoint(self) -> None: - """ - Delete the checkpoint bundle set for the process - """ + """Delete the checkpoint bundle set for the process""" try: self.base.attributes.delete(self.CHECKPOINT_KEY) except AttributeError: @@ -518,16 +492,14 @@ def delete_checkpoint(self) -> None: @property def paused(self) -> bool: - """ - Return whether the process is paused + """Return whether the process is paused :returns: True if the Calculation is marked as paused, False otherwise """ return self.base.attributes.get(self.PROCESS_PAUSED_KEY, False) def pause(self) -> None: - """ - Mark the process as paused by setting the corresponding attribute. + """Mark the process as paused by setting the corresponding attribute. This serves only to reflect that the corresponding Process is paused and so this method should not be called by anyone but the Process instance itself. @@ -535,8 +507,7 @@ def pause(self) -> None: return self.base.attributes.set(self.PROCESS_PAUSED_KEY, True) def unpause(self) -> None: - """ - Mark the process as unpaused by removing the corresponding attribute. + """Mark the process as unpaused by removing the corresponding attribute. This serves only to reflect that the corresponding Process is unpaused and so this method should not be called by anyone but the Process instance itself. @@ -548,8 +519,7 @@ def unpause(self) -> None: @property def called(self) -> List['ProcessNode']: - """ - Return a list of nodes that the process called + """Return a list of nodes that the process called :returns: list of process nodes called by this process """ @@ -557,8 +527,7 @@ def called(self) -> List['ProcessNode']: @property def called_descendants(self) -> List['ProcessNode']: - """ - Return a list of all nodes that have been called downstream of this process + """Return a list of all nodes that have been called downstream of this process This will recursively find all the called processes for this process and its children. 
""" @@ -572,8 +541,7 @@ def called_descendants(self) -> List['ProcessNode']: @property def caller(self) -> Optional['ProcessNode']: - """ - Return the process node that called this process node, or None if it does not have a caller + """Return the process node that called this process node, or None if it does not have a caller :returns: process node that called this process node instance or None """ diff --git a/aiida/orm/nodes/process/workflow/__init__.py b/aiida/orm/nodes/process/workflow/__init__.py index f4125a4f8f..c5d13a6f0b 100644 --- a/aiida/orm/nodes/process/workflow/__init__.py +++ b/aiida/orm/nodes/process/workflow/__init__.py @@ -11,8 +11,7 @@ # AUTO-GENERATED -# yapf: disable -# pylint: disable=wildcard-import +# fmt: off from .workchain import * from .workflow import * @@ -24,4 +23,4 @@ 'WorkflowNode', ) -# yapf: enable +# fmt: on diff --git a/aiida/orm/nodes/process/workflow/workchain.py b/aiida/orm/nodes/process/workflow/workchain.py index eba864a25c..ba7a350e77 100644 --- a/aiida/orm/nodes/process/workflow/workchain.py +++ b/aiida/orm/nodes/process/workflow/workchain.py @@ -23,22 +23,19 @@ class WorkChainNode(WorkflowNode): STEPPER_STATE_INFO_KEY = 'stepper_state_info' @classproperty - def _updatable_attributes(cls) -> Tuple[str, ...]: # type: ignore[override] - # pylint: disable=no-self-argument + def _updatable_attributes(cls) -> Tuple[str, ...]: # type: ignore[override] # noqa: N805 return super()._updatable_attributes + (cls.STEPPER_STATE_INFO_KEY,) @property def stepper_state_info(self) -> Optional[str]: - """ - Return the stepper state info + """Return the stepper state info :returns: string representation of the stepper state info """ return self.base.attributes.get(self.STEPPER_STATE_INFO_KEY, None) def set_stepper_state_info(self, stepper_state_info: str) -> None: - """ - Set the stepper state info + """Set the stepper state info :param state: string representation of the stepper state info """ diff --git a/aiida/orm/nodes/repository.py b/aiida/orm/nodes/repository.py index e6d1b53c8e..9f0f035de2 100644 --- a/aiida/orm/nodes/repository.py +++ b/aiida/orm/nodes/repository.py @@ -117,7 +117,7 @@ def _copy(self, repo: 'NodeRepository') -> None: :param repo: the repository to clone. """ - self._repository = copy.copy(repo._repository) # pylint: disable=protected-access + self._repository = copy.copy(repo._repository) def _clone(self, repo: 'NodeRepository') -> None: """Clone the repository from another instance. @@ -126,7 +126,7 @@ def _clone(self, repo: 'NodeRepository') -> None: :param repo: the repository to clone. """ - self._repository.clone(repo._repository) # pylint: disable=protected-access + self._repository.clone(repo._repository) def serialize(self) -> dict: """Serialize the metadata of the repository content into a JSON-serializable format. 
@@ -209,7 +209,6 @@ def as_path(self, path: FilePath | None = None) -> t.Iterator[pathlib.Path]: obj = self.get_object(path) with tempfile.TemporaryDirectory() as tmp_path: - dirpath = pathlib.Path(tmp_path) if obj.is_dir(): @@ -283,7 +282,7 @@ def put_object_from_filelike(self, handle: io.BufferedReader, path: str): if isinstance(handle, io.StringIO): # type: ignore[unreachable] handle = io.BytesIO(handle.read().encode('utf-8')) # type: ignore[unreachable] - if isinstance(handle, tempfile._TemporaryFileWrapper): # type: ignore[unreachable] # pylint: disable=protected-access + if isinstance(handle, tempfile._TemporaryFileWrapper): # type: ignore[unreachable] if 'b' in handle.file.mode: # type: ignore[unreachable] handle = io.BytesIO(handle.read()) else: diff --git a/aiida/orm/querybuilder.py b/aiida/orm/querybuilder.py index 7379443a52..d3f1e2e4a5 100644 --- a/aiida/orm/querybuilder.py +++ b/aiida/orm/querybuilder.py @@ -7,9 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=too-many-lines -""" -The QueryBuilder: A class that allows you to query the AiiDA database, independent from backend. +"""The QueryBuilder: A class that allows you to query the AiiDA database, independent from backend. Note that the backend implementation is enforced and handled with a composition model! :func:`QueryBuilder` is the frontend class that the user can use. It inherits from *object* and contains backend-specific functionality. Backend specific functionality is provided by the implementation classes. @@ -21,6 +19,7 @@ """ from __future__ import annotations +import warnings from copy import deepcopy from inspect import isclass as inspect_isclass from typing import ( @@ -40,7 +39,6 @@ cast, overload, ) -import warnings from aiida.common.log import AIIDA_LOGGER from aiida.common.warnings import warn_deprecation @@ -57,16 +55,15 @@ from . import authinfos, comments, computers, convert, entities, groups, logs, nodes, users if TYPE_CHECKING: - # pylint: disable=ungrouped-imports from aiida.engine import Process from aiida.orm.implementation import StorageBackend __all__ = ('QueryBuilder',) # re-usable type annotations -EntityClsType = Type[Union[entities.Entity, 'Process']] # pylint: disable=invalid-name -ProjectType = Union[str, dict, Sequence[Union[str, dict]]] # pylint: disable=invalid-name -FilterType = Dict[str, Any] # pylint: disable=invalid-name +EntityClsType = Type[Union[entities.Entity, 'Process']] +ProjectType = Union[str, dict, Sequence[Union[str, dict]]] +FilterType = Dict[str, Any] OrderByType = Union[dict, List[dict], Tuple[dict, ...]] LOGGER = AIIDA_LOGGER.getChild('querybuilder') @@ -74,13 +71,13 @@ class Classifier(NamedTuple): """A classifier for an entity.""" + ormclass_type_string: str process_type_string: Optional[str] = None class QueryBuilder: - """ - The class to query the AiiDA database. + """The class to query the AiiDA database. Usage:: @@ -93,8 +90,6 @@ class QueryBuilder: """ - # pylint: disable=too-many-instance-attributes,too-many-public-methods - # This tag defines how edges are tagged (labeled) by the QueryBuilder default # namely tag of first entity + _EDGE_TAG_DELIM + tag of second entity _EDGE_TAG_DELIM = '--' @@ -113,8 +108,7 @@ def __init__( order_by: Optional[OrderByType] = None, distinct: bool = False, ) -> None: - """ - Instantiates a QueryBuilder instance. 
+ """Instantiates a QueryBuilder instance. Which backend is used decided here based on backend-settings (taken from the user profile). This cannot be overridden so far by the user. @@ -228,7 +222,7 @@ def as_dict(self, copy: bool = True) -> QueryDictType: @property def queryhelp(self) -> 'QueryDictType': - """"Legacy name for ``as_dict`` method.""" + """ "Legacy name for ``as_dict`` method.""" warn_deprecation('`QueryBuilder.queryhelp` is deprecated, use `QueryBuilder.as_dict()` instead', version=3) return self.as_dict() @@ -267,8 +261,7 @@ def get_used_tags(self, vertices: bool = True, edges: bool = True) -> List[str]: return given_tags def _get_unique_tag(self, classifiers: List[Classifier]) -> str: - """ - Using the function get_tag_from_type, I get a tag. + """Using the function get_tag_from_type, I get a tag. I increment an index that is appended to that tag until I have an unused tag. This function is called in :func:`QueryBuilder.append` when no tag is given. @@ -303,11 +296,10 @@ def append( outerjoin: bool = False, joining_keyword: Optional[str] = None, joining_value: Optional[Any] = None, - orm_base: Optional[str] = None, # pylint: disable=unused-argument - **kwargs: Any + orm_base: Optional[str] = None, + **kwargs: Any, ) -> 'QueryBuilder': - """ - Any iterative procedure to build the path for a graph query + """Any iterative procedure to build the path for a graph query needs to invoke this method to append to the path. :param cls: @@ -368,7 +360,6 @@ def append( :return: self """ - # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements # INPUT CHECKS ########################## # This function can be called by users, so I am checking the input now. # First of all, let's make sure the specified the class or the type (not both) @@ -459,7 +450,7 @@ def append( raise exception # JOINING ##################################### - # pylint: disable=too-many-nested-blocks + try: # Get the functions that are implemented: spec_to_function_map = set(EntityRelationships[ormclass.value]) @@ -521,9 +512,8 @@ def append( if edge_tag is None: edge_destination_tag = self._tags.get(joining_value) edge_tag = edge_destination_tag + self._EDGE_TAG_DELIM + tag - else: - if edge_tag in self._tags: - raise ValueError(f'The tag {edge_tag} is already in use') + elif edge_tag in self._tags: + raise ValueError(f'The tag {edge_tag} is already in use') LOGGER.debug('edge_tag chosen: %s', edge_tag) # edge tags do not have an ormclass @@ -579,8 +569,7 @@ def append( return self def order_by(self, order_by: OrderByType) -> 'QueryBuilder': - """ - Set the entity to order by + """Set the entity to order by :param order_by: This is a list of items, where each item is a dictionary specifies @@ -612,7 +601,6 @@ def order_by(self, order_by: OrderByType) -> 'QueryBuilder': qb.append(Node, tag='node') qb.order_by({'node':[{'id':'desc'}]}) """ - # pylint: disable=too-many-nested-blocks,too-many-branches self._order_by = [] allowed_keys = ('cast', 'order') possible_orders = ('asc', 'desc') @@ -629,12 +617,12 @@ def order_by(self, order_by: OrderByType) -> 'QueryBuilder': _order_spec: dict = {} for tagspec, items_to_order_by in order_spec.items(): if not isinstance(items_to_order_by, (tuple, list)): - items_to_order_by = [items_to_order_by] + items_to_order_by = [items_to_order_by] # noqa: PLW2901 tag = self._tags.get(tagspec) _order_spec[tag] = [] for item_to_order_by in items_to_order_by: if isinstance(item_to_order_by, str): - item_to_order_by = {item_to_order_by: {}} + 
item_to_order_by = {item_to_order_by: {}} # noqa: PLW2901 elif isinstance(item_to_order_by, dict): pass else: @@ -651,9 +639,9 @@ def order_by(self, order_by: OrderByType) -> 'QueryBuilder': this_order_spec = orderspec else: raise TypeError( - 'I was expecting a string or a dictionary\n' - 'You provided {} {}\n' - ''.format(type(orderspec), orderspec) + 'I was expecting a string or a dictionary\n' 'You provided {} {}\n' ''.format( + type(orderspec), orderspec + ) ) for key in this_order_spec: if key not in allowed_keys: @@ -679,8 +667,7 @@ def order_by(self, order_by: OrderByType) -> 'QueryBuilder': return self def add_filter(self, tagspec: Union[str, EntityClsType], filter_spec: FilterType) -> 'QueryBuilder': - """ - Adding a filter to my filters. + """Adding a filter to my filters. :param tagspec: A tag string or an ORM class which maps to an existing tag :param filter_spec: The specifications for the filter, has to be a dictionary @@ -720,8 +707,7 @@ def _process_filters(filters: FilterType) -> Dict[str, Any]: return processed_filters def _add_node_type_filter(self, tagspec: str, classifiers: List[Classifier], subclassing: bool): - """ - Add a filter based on node type. + """Add a filter based on node type. :param tagspec: The tag, which has to exist already as a key in self._filters :param classifiers: a dictionary with classifiers @@ -738,8 +724,7 @@ def _add_node_type_filter(self, tagspec: str, classifiers: List[Classifier], sub self.add_filter(tagspec, {'node_type': entity_type_filter}) def _add_process_type_filter(self, tagspec: str, classifiers: List[Classifier], subclassing: bool) -> None: - """ - Add a filter based on process type. + """Add a filter based on process type. :param tagspec: The tag, which has to exist already as a key in self._filters :param classifiers: a dictionary with classifiers @@ -757,14 +742,12 @@ def _add_process_type_filter(self, tagspec: str, classifiers: List[Classifier], if len(process_type_filter['or']) > 0: self.add_filter(tagspec, {'process_type': process_type_filter}) - else: - if classifiers[0].process_type_string is not None: - process_type_filter = _get_process_type_filter(classifiers[0], subclassing) - self.add_filter(tagspec, {'process_type': process_type_filter}) + elif classifiers[0].process_type_string is not None: + process_type_filter = _get_process_type_filter(classifiers[0], subclassing) + self.add_filter(tagspec, {'process_type': process_type_filter}) def _add_group_type_filter(self, tagspec: str, classifiers: List[Classifier], subclassing: bool) -> None: - """ - Add a filter based on group type. + """Add a filter based on group type. :param tagspec: The tag, which has to exist already as a key in self._filters :param classifiers: a dictionary with classifiers @@ -854,8 +837,7 @@ def add_projection(self, tag_spec: Union[str, EntityClsType], projection_spec: P self._projections[tag] = _projections def set_debug(self, debug: bool) -> 'QueryBuilder': - """ - Run in debug mode. This does not affect functionality, but prints intermediate stages + """Run in debug mode. This does not affect functionality, but prints intermediate stages when creating a query on screen. 
        :param debug: Turn debug on or off
@@ -879,8 +861,7 @@ def debug(self, msg: str, *objects: Any) -> None:
            print(f'DEBUG: {msg}' % objects)

    def limit(self, limit: Optional[int]) -> 'QueryBuilder':
-        """
-        Set the limit (nr of rows to return)
+        """Set the limit (nr of rows to return)

        :param limit: integer number of rows to return
        """
@@ -890,8 +871,7 @@ def limit(self, limit: Optional[int]) -> 'QueryBuilder':
        return self

    def offset(self, offset: Optional[int]) -> 'QueryBuilder':
-        """
-        Set the offset. If offset is set, that many rows are skipped before returning.
+        """Set the offset. If offset is set, that many rows are skipped before returning.
        *offset* = 0 is the same as omitting setting the offset.
        If both offset and limit appear,
        then *offset* rows are skipped before starting to count the *limit* rows
@@ -905,8 +885,7 @@ def offset(self, offset: Optional[int]) -> 'QueryBuilder':
        return self

    def distinct(self, value: bool = True) -> 'QueryBuilder':
-        """
-        Asks for distinct rows, which is the same as asking the backend to remove
+        """Asks for distinct rows, which is the same as asking the backend to remove
        duplicates.

        Does not execute the query!
@@ -928,48 +907,48 @@ def distinct(self, value: bool = True) -> 'QueryBuilder':
        return self

    def inputs(self, **kwargs: Any) -> 'QueryBuilder':
-        """
-        Join to inputs of previous vertice in path.
+        """Join to inputs of previous vertex in path.

        :returns: self
        """
        from aiida.orm import Node
+
        join_to = self._path[-1]['tag']
        cls = kwargs.pop('cls', Node)
        self.append(cls=cls, with_outgoing=join_to, **kwargs)
        return self

    def outputs(self, **kwargs: Any) -> 'QueryBuilder':
-        """
-        Join to outputs of previous vertice in path.
+        """Join to outputs of previous vertex in path.

        :returns: self
        """
        from aiida.orm import Node
+
        join_to = self._path[-1]['tag']
        cls = kwargs.pop('cls', Node)
        self.append(cls=cls, with_incoming=join_to, **kwargs)
        return self

    def children(self, **kwargs: Any) -> 'QueryBuilder':
-        """
-        Join to children/descendants of previous vertice in path.
+        """Join to children/descendants of previous vertex in path.

        :returns: self
        """
        from aiida.orm import Node
+
        join_to = self._path[-1]['tag']
        cls = kwargs.pop('cls', Node)
        self.append(cls=cls, with_ancestors=join_to, **kwargs)
        return self

    def parents(self, **kwargs: Any) -> 'QueryBuilder':
-        """
-        Join to parents/ancestors of previous vertice in path.
+        """Join to parents/ancestors of previous vertex in path.

        :returns: self
        """
        from aiida.orm import Node
+
        join_to = self._path[-1]['tag']
        cls = kwargs.pop('cls', Node)
        self.append(cls=cls, with_descendants=join_to, **kwargs)
@@ -1043,16 +1022,14 @@ def first(self, flat: bool = False) -> Optional[list[Any] | Any]:
        return result

    def count(self) -> int:
-        """
-        Counts the number of rows returned by the backend.
+        """Counts the number of rows returned by the backend.

        :returns: the number of rows as an integer
        """
        return self._impl.count(self.as_dict())

    def iterall(self, batch_size: Optional[int] = 100) -> Iterable[List[Any]]:
-        """
-        Same as :meth:`.all`, but returns a generator.
+        """Same as :meth:`.all`, but returns a generator.

        Be aware that this is only safe if no commit will take place during this transaction.
You might also want to read the SQLAlchemy documentation on https://docs.sqlalchemy.org/en/14/orm/query.html#sqlalchemy.orm.Query.yield_per @@ -1071,8 +1048,7 @@ def iterall(self, batch_size: Optional[int] = 100) -> Iterable[List[Any]]: yield item def iterdict(self, batch_size: Optional[int] = 100) -> Iterable[Dict[str, Dict[str, Any]]]: - """ - Same as :meth:`.dict`, but returns a generator. + """Same as :meth:`.dict`, but returns a generator. Be aware that this is only safe if no commit will take place during this transaction. You might also want to read the SQLAlchemy documentation on https://docs.sqlalchemy.org/en/14/orm/query.html#sqlalchemy.orm.Query.yield_per @@ -1117,6 +1093,7 @@ def one(self) -> List[Any]: :raises: NotExistent if no result was found """ from aiida.common.exceptions import MultipleObjectsError, NotExistent + limit = self._limit self.limit(2) try: @@ -1130,8 +1107,7 @@ def one(self) -> List[Any]: return res[0] def dict(self, batch_size: Optional[int] = None) -> List[Dict[str, Dict[str, Any]]]: - """ - Executes the full query with the order of the rows as returned by the backend. + """Executes the full query with the order of the rows as returned by the backend. the order inside each row is given by the order of the vertices in the path and the order of the projections for each vertice in the path. @@ -1220,8 +1196,7 @@ def _get_ormclass( def _get_ormclass_from_cls(cls: EntityClsType) -> Tuple[EntityTypes, Classifier]: - """ - Return the correct classifiers for the QueryBuilder from an ORM class. + """Return the correct classifiers for the QueryBuilder from an ORM class. :param cls: an AiiDA ORM class or backend ORM class. :param query: an instance of the appropriate QueryBuilder backend. @@ -1230,7 +1205,6 @@ def _get_ormclass_from_cls(cls: EntityClsType) -> Tuple[EntityTypes, Classifier] Note: the ormclass_type_string is currently hardcoded for group, computer etc. One could instead use something like aiida.orm.utils.node.get_type_string_from_class(cls.__module__, cls.__name__) """ - # pylint: disable=protected-access,too-many-branches,too-many-statements # Note: Unable to move this import to the top of the module for some reason from aiida.engine import Process from aiida.orm.utils.node import is_valid_node_type_string @@ -1312,8 +1286,7 @@ def _get_ormclass_from_str(type_string: str) -> Tuple[EntityTypes, Classifier]: def _get_node_type_filter(classifiers: Classifier, subclassing: bool) -> dict: - """ - Return filter dictionaries given a set of classifiers. + """Return filter dictionaries given a set of classifiers. :param classifiers: a dictionary with classifiers (note: does *not* support lists) :param subclassing: if True, allow for subclasses of the ormclass @@ -1322,6 +1295,7 @@ def _get_node_type_filter(classifiers: Classifier, subclassing: bool) -> dict: """ from aiida.common.escaping import escape_for_sql_like from aiida.orm.utils.node import get_query_type_from_type_string + value = classifiers.ormclass_type_string if not subclassing: @@ -1335,8 +1309,7 @@ def _get_node_type_filter(classifiers: Classifier, subclassing: bool) -> dict: def _get_process_type_filter(classifiers: Classifier, subclassing: bool) -> dict: - """ - Return filter dictionaries given a set of classifiers. + """Return filter dictionaries given a set of classifiers. 
:param classifiers: a dictionary with classifiers (note: does *not* support lists) :param subclassing: if True, allow for subclasses of the process type @@ -1356,47 +1329,39 @@ def _get_process_type_filter(classifiers: Classifier, subclassing: bool) -> dict if not subclassing: filters = {'==': value} + elif ':' in value: + # if value is an entry point, do usual subclassing + + # Note: the process_type_string stored in the database does *not* end in a dot. + # In order to avoid that querying for class 'Begin' will also find class 'BeginEnd', + # we need to search separately for equality and 'like'. + filters = { + 'or': [ + {'==': value}, + {'like': escape_for_sql_like(get_query_string_from_process_type_string(value))}, + ] + } + elif value.startswith('aiida.engine'): + # For core process types, a filter is not is needed since each process type has a corresponding + # ormclass type that already specifies everything. + # Note: This solution is fragile and will break as soon as there is not an exact one-to-one correspondence + # between process classes and node classes + + # Note: Improve this when issue https://github.com/aiidateam/aiida-core/issues/2475 is addressed + filters = {'like': '%'} else: - if ':' in value: - # if value is an entry point, do usual subclassing - - # Note: the process_type_string stored in the database does *not* end in a dot. - # In order to avoid that querying for class 'Begin' will also find class 'BeginEnd', - # we need to search separately for equality and 'like'. - filters = { - 'or': [ - { - '==': value - }, - { - 'like': escape_for_sql_like(get_query_string_from_process_type_string(value)) - }, - ] - } - elif value.startswith('aiida.engine'): - # For core process types, a filter is not is needed since each process type has a corresponding - # ormclass type that already specifies everything. - # Note: This solution is fragile and will break as soon as there is not an exact one-to-one correspondence - # between process classes and node classes - - # Note: Improve this when issue https://github.com/aiidateam/aiida-core/issues/2475 is addressed - filters = {'like': '%'} - else: - warnings.warn( - "Process type '{value}' does not correspond to a registered entry. " - 'This risks queries to fail once the location of the process class changes. ' - "Add an entry point for '{value}' to remove this warning.".format(value=value), AiidaEntryPointWarning - ) - filters = { - 'or': [ - { - '==': value - }, - { - 'like': escape_for_sql_like(get_query_string_from_process_type_string(value)) - }, - ] - } + warnings.warn( + "Process type '{value}' does not correspond to a registered entry. " + 'This risks queries to fail once the location of the process class changes. 
' + "Add an entry point for '{value}' to remove this warning.".format(value=value), + AiidaEntryPointWarning, + ) + filters = { + 'or': [ + {'==': value}, + {'like': escape_for_sql_like(get_query_string_from_process_type_string(value))}, + ] + } return filters @@ -1434,7 +1399,7 @@ def add( self, tag: str, etype: Union[None, EntityTypes] = None, - klasses: Union[None, EntityClsType, Sequence[EntityClsType]] = None + klasses: Union[None, EntityClsType, Sequence[EntityClsType]] = None, ) -> None: """Add a tag.""" self._tag_to_type[tag] = etype @@ -1464,7 +1429,7 @@ def get(self, tag_or_cls: Union[str, EntityClsType]) -> str: f'The object used as a tag ({tag_or_cls}) has multiple values associated with it: ' f'{self._cls_to_tag_map[tag_or_cls]}' ) - return list(self._cls_to_tag_map[tag_or_cls])[0] + return next(iter(self._cls_to_tag_map[tag_or_cls])) raise ValueError(f'The given object ({tag_or_cls}) has no tags associated with it.') @@ -1478,7 +1443,7 @@ def _get_group_type_filter(classifiers: Classifier, subclassing: bool) -> dict: """ from aiida.common.escaping import escape_for_sql_like - value = classifiers.ormclass_type_string[len(GROUP_ENTITY_TYPE_PREFIX):] + value = classifiers.ormclass_type_string[len(GROUP_ENTITY_TYPE_PREFIX) :] if not subclassing: filters = {'==': value} diff --git a/aiida/orm/users.py b/aiida/orm/users.py index fab3517d58..cce5d76de6 100644 --- a/aiida/orm/users.py +++ b/aiida/orm/users.py @@ -16,7 +16,8 @@ from . import entities if TYPE_CHECKING: - from aiida.orm.implementation import BackendUser, StorageBackend + from aiida.orm.implementation import StorageBackend + from aiida.orm.implementation.users import BackendUser # noqa: F401 __all__ = ('User',) @@ -57,10 +58,9 @@ def __init__( first_name: str = '', last_name: str = '', institution: str = '', - backend: Optional['StorageBackend'] = None + backend: Optional['StorageBackend'] = None, ): """Create a new `User`.""" - # pylint: disable=too-many-arguments backend = backend or get_manager().get_profile_storage() email = self.normalize_email(email) backend_entity = backend.users.create( @@ -125,8 +125,7 @@ def institution(self, institution: str) -> None: self._backend_entity.institution = institution def get_full_name(self) -> str: - """ - Return the user full name + """Return the user full name :return: the user full name """ @@ -142,8 +141,7 @@ def get_full_name(self) -> str: return full_name def get_short_name(self) -> str: - """ - Return the user short name (typically, this returns the email) + """Return the user short name (typically, this returns the email) :return: The short name """ @@ -151,7 +149,5 @@ def get_short_name(self) -> str: @property def uuid(self) -> None: - """ - For now users do not have UUIDs so always return None - """ + """For now users do not have UUIDs so always return None""" return None diff --git a/aiida/orm/utils/__init__.py b/aiida/orm/utils/__init__.py index 16e7b146c1..10e23252c9 100644 --- a/aiida/orm/utils/__init__.py +++ b/aiida/orm/utils/__init__.py @@ -11,8 +11,7 @@ # AUTO-GENERATED -# yapf: disable -# pylint: disable=wildcard-import +# fmt: off from .calcjob import * from .links import * @@ -46,4 +45,4 @@ 'validate_link', ) -# yapf: enable +# fmt: on diff --git a/aiida/orm/utils/builders/code.py b/aiida/orm/utils/builders/code.py index 0026146ded..2275aa8b0b 100644 --- a/aiida/orm/utils/builders/code.py +++ b/aiida/orm/utils/builders/code.py @@ -56,12 +56,12 @@ def new(self): if self._get_and_count('code_type', used) == self.CodeType.STORE_AND_UPLOAD: code = PortableCode( 
filepath_executable=self._get_and_count('code_rel_path', used), - filepath_files=pathlib.Path(self._get_and_count('code_folder', used)) + filepath_files=pathlib.Path(self._get_and_count('code_folder', used)), ) else: code = InstalledCode( computer=self._get_and_count('computer', used), - filepath_executable=self._get_and_count('remote_abs_path', used) + filepath_executable=self._get_and_count('remote_abs_path', used), ) code.label = self._get_and_count('label', used) @@ -128,16 +128,14 @@ def __getattr__(self, key): return None def _get(self, key): - """ - Return a spec, or None if not defined + """Return a spec, or None if not defined :param key: name of a code spec """ return self._code_spec.get(key) def _get_and_count(self, key, used): - """ - Return a spec, or raise if not defined. + """Return a spec, or raise if not defined. Moreover, add the key to the 'used' dict. :param key: name of a code spec @@ -199,11 +197,10 @@ def validate_installed(self): raise self.CodeValidationError(f'{messages}') class CodeValidationError(ValueError): - """ - A CodeBuilder instance may raise this + """A CodeBuilder instance may raise this - * when asked to instanciate a code with missing or invalid code attributes - * when asked for a code attibute that has not been set yet + * when asked to instantiate a code with missing or invalid code attributes + * when asked for a code attribute that has not been set yet """ def __init__(self, msg): diff --git a/aiida/orm/utils/builders/computer.py b/aiida/orm/utils/builders/computer.py index 4918425498..cc0b1de715 100644 --- a/aiida/orm/utils/builders/computer.py +++ b/aiida/orm/utils/builders/computer.py @@ -12,14 +12,15 @@ from aiida.common.utils import ErrorAccumulator -class ComputerBuilder: # pylint: disable=too-many-instance-attributes +class ComputerBuilder: """Build a computer with validation of attribute combinations""" @staticmethod def from_computer(computer): """Create ComputerBuilder from existing computer instance.
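# Editor's note: hedged sketch of driving CodeBuilder directly for an installed
# (on-computer) code, mirroring the InstalledCode branch above. Assumes a computer
# labelled 'localhost' is configured; all spec values are placeholders, and spec
# keys may vary slightly between aiida-core versions. Not part of the patch.
from aiida import load_profile, orm
from aiida.orm.utils.builders.code import CodeBuilder

load_profile()

builder = CodeBuilder(
    code_type=CodeBuilder.CodeType.ON_COMPUTER,
    computer=orm.load_computer('localhost'),
    remote_abs_path='/bin/bash',  # becomes filepath_executable of the InstalledCode
    label='bash',
    description='GNU bash',
    input_plugin='core.arithmetic.add',
    use_double_quotes=False,
    prepend_text='',
    append_text='',
)
code = builder.new()  # unstored InstalledCode; bad specs raise CodeValidationError
code.store()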
- See also :py:func:`~ComputerBuilder.get_computer_spec`""" + See also :py:func:`~ComputerBuilder.get_computer_spec` + """ spec = ComputerBuilder.get_computer_spec(computer) return ComputerBuilder(**spec) @@ -31,7 +32,9 @@ def get_computer_spec(computer): spec = ComputerBuilder.get_computer_spec(old_computer) builder = ComputerBuilder(**spec) - new_computer = builder.new()""" + new_computer = builder.new() + + """ spec = {} spec['label'] = computer.label spec['description'] = computer.description @@ -94,13 +97,11 @@ def new(self): mpiprocs_per_machine = int(mpiprocs_per_machine) except ValueError: raise self.ComputerValidationError( - 'Invalid value provided for mpiprocs_per_machine, ' - 'must be a valid integer' + 'Invalid value provided for mpiprocs_per_machine, ' 'must be a valid integer' ) if mpiprocs_per_machine <= 0: raise self.ComputerValidationError( - 'Invalid value provided for mpiprocs_per_machine, ' - 'must be positive' + 'Invalid value provided for mpiprocs_per_machine, ' 'must be positive' ) computer.set_default_mpiprocs_per_machine(mpiprocs_per_machine) @@ -120,7 +121,7 @@ def new(self): mpirun_command_internal = self._get_and_count('mpirun_command', used).strip().split(' ') if mpirun_command_internal == ['']: mpirun_command_internal = [] - computer._mpirun_command_validator(mpirun_command_internal) # pylint: disable=protected-access + computer._mpirun_command_validator(mpirun_command_internal) computer.set_mpirun_command(mpirun_command_internal) # Complain if there are keys that are passed but not used @@ -141,22 +142,21 @@ def __getattr__(self, key): return None def _get(self, key): - """ - Return a spec, or None if not defined + """Return a spec, or None if not defined - :param key: name of a computer spec""" + :param key: name of a computer spec + """ return self._computer_spec.get(key) def _get_and_count(self, key, used): - """ - Return a spec, or raise if not defined. + """Return a spec, or raise if not defined. Moreover, add the key to the 'used' dict. :param key: name of a computer spec :param used: should be a set of keys that you want to track. ``key`` will be added to this set if the value exists in the spec and can be retrieved. """ - retval = self.__getattr__(key) # pylint: disable=unnecessary-dunder-call + retval = self.__getattr__(key) # I first get a retval, so if I get an exception, I don't add it to the 'used' set used.add(key) return retval @@ -176,11 +176,11 @@ def _set_computer_attr(self, key, value): self.validate() class ComputerValidationError(Exception): - """ - A ComputerBuilder instance may raise this + """A ComputerBuilder instance may raise this. - * when asked to instanciate a code with missing or invalid computer attributes - * when asked for a computer attibute that has not been set yet.""" + * when asked to instantiate a computer with missing or invalid computer attributes + * when asked for a computer attribute that has not been set yet. + """ def __init__(self, msg): diff --git a/aiida/orm/utils/calcjob.py b/aiida/orm/utils/calcjob.py index 5fc58150a6..e3c2b16876 100644 --- a/aiida/orm/utils/calcjob.py +++ b/aiida/orm/utils/calcjob.py @@ -15,8 +15,7 @@ class CalcJobResultManager: - """ - Utility class to easily access the contents of the 'default output' node of a `CalcJobNode`. + """Utility class to easily access the contents of the 'default output' node of a `CalcJobNode`. A `CalcJob` process can mark one of its outputs as the 'default output'.
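# Editor's note: sketch of the spec round-trip documented in the docstring above,
# assuming a computer labelled 'localhost' already exists in the profile; not part
# of the patch.
from aiida import load_profile, orm
from aiida.orm.utils.builders.computer import ComputerBuilder

load_profile()

old_computer = orm.load_computer('localhost')
builder = ComputerBuilder.from_computer(old_computer)  # spec round-trip
builder.label = 'localhost-copy'  # tweak any spec before rebuilding
new_computer = builder.new()  # unstored; invalid specs raise ComputerValidationError
new_computer.store()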
The default output node will always be returned by the `CalcJob` and will always be a `Dict` node. diff --git a/aiida/orm/utils/links.py b/aiida/orm/utils/links.py index 2106a1d57c..aba8a5abaf 100644 --- a/aiida/orm/utils/links.py +++ b/aiida/orm/utils/links.py @@ -42,11 +42,7 @@ class LinkQuadruple(NamedTuple): def link_triple_exists( - source: 'Node', - target: 'Node', - link_type: 'LinkType', - link_label: str, - backend: Optional['StorageBackend'] = None + source: 'Node', target: 'Node', link_type: 'LinkType', link_label: str, backend: Optional['StorageBackend'] = None ) -> bool: """Return whether a link with the given type and label exists between the given source and target node. @@ -79,14 +75,9 @@ def link_triple_exists( def validate_link( - source: 'Node', - target: 'Node', - link_type: 'LinkType', - link_label: str, - backend: Optional['StorageBackend'] = None + source: 'Node', target: 'Node', link_type: 'LinkType', link_label: str, backend: Optional['StorageBackend'] = None ) -> None: - """ - Validate adding a link of the given type and label from a given node to ourself. + """Validate adding a link of the given type and label from a given node to ourself. This function will first validate the class types of the inputs and will subsequently validate whether a link of the specified type is allowed at all between the nodes types of the source and target. @@ -146,7 +137,6 @@ def validate_link( :raise TypeError: if `source` or `target` is not a Node instance, or `link_type` is not a `LinkType` enum :raise ValueError: if the proposed link is invalid """ - # yapf: disable from aiida.common.links import LinkType, validate_link_label from aiida.orm import CalculationNode, Data, Node, WorkflowNode @@ -199,33 +189,42 @@ def validate_link( raise ValueError(f'node<{source.uuid}> already has an outgoing {link_type} link') # If the outdegree is `unique_pair`, then the link labels for outgoing links of this type should be unique - elif outdegree == 'unique_pair' and source.base.links.get_outgoing( - link_type=link_type, only_uuid=True, link_label_filter=link_label).all(): + elif ( + outdegree == 'unique_pair' + and source.base.links.get_outgoing(link_type=link_type, only_uuid=True, link_label_filter=link_label).all() + ): raise ValueError(f'node<{source.uuid}> already has an outgoing {link_type} link with label "{link_label}"') # If the outdegree is `unique_triple`, then the link triples of link type, link label and target should be unique elif outdegree == 'unique_triple' and duplicate_link_triple: - raise ValueError('node<{}> already has an outgoing {} link with label "{}" from node<{}>'.format( - source.uuid, link_type, link_label, target.uuid)) + raise ValueError( + 'node<{}> already has an outgoing {} link with label "{}" from node<{}>'.format( + source.uuid, link_type, link_label, target.uuid + ) + ) # If the indegree is `unique` there cannot already be any other incoming links of that type if indegree == 'unique' and target.base.links.get_incoming(link_type=link_type, only_uuid=True).all(): raise ValueError(f'node<{target.uuid}> already has an incoming {link_type} link') # If the indegree is `unique_pair`, then the link labels for incoming links of this type should be unique - elif indegree == 'unique_pair' and target.base.links.get_incoming( - link_type=link_type, link_label_filter=link_label, only_uuid=True).all(): + elif ( + indegree == 'unique_pair' + and target.base.links.get_incoming(link_type=link_type, link_label_filter=link_label, only_uuid=True).all() + ): raise 
ValueError(f'node<{target.uuid}> already has an incoming {link_type} link with label "{link_label}"') # If the indegree is `unique_triple`, then the link triples of link type, link label and source should be unique elif indegree == 'unique_triple' and duplicate_link_triple: - raise ValueError('node<{}> already has an incoming {} link with label "{}" from node<{}>'.format( - target.uuid, link_type, link_label, source.uuid)) + raise ValueError( + 'node<{}> already has an incoming {} link with label "{}" from node<{}>'.format( + target.uuid, link_type, link_label, source.uuid + ) + ) class LinkManager: - """ - Class to convert a list of LinkTriple tuples into an iterator. + """Class to convert a list of LinkTriple tuples into an iterator. It defines convenience methods to retrieve certain subsets of LinkTriple while checking for consistency. For example:: @@ -331,9 +330,7 @@ def get_node_by_label(self, label: str) -> 'Node': if matching_entry is None: matching_entry = entry.node else: - raise exceptions.MultipleObjectsError( - f'more than one neighbor with the label {label} found' - ) + raise exceptions.MultipleObjectsError(f'more than one neighbor with the label {label} found') if matching_entry is None: raise exceptions.NotExistent(f'no neighbor with the label {label} found') @@ -355,7 +352,6 @@ def nested(self, sort=True): nested: dict = {} for entry in self.link_triples: - current_namespace = nested breadcrumbs = entry.link_label.split(PORT_NAMESPACE_SEPARATOR) diff --git a/aiida/orm/utils/loaders.py b/aiida/orm/utils/loaders.py index b3981373da..4f9bf97d1a 100644 --- a/aiida/orm/utils/loaders.py +++ b/aiida/orm/utils/loaders.py @@ -20,17 +20,25 @@ from aiida.orm import Code, Computer, Group, Node __all__ = ( - 'load_code', 'load_computer', 'load_group', 'load_node', 'load_entity', 'get_loader', 'OrmEntityLoader', - 'CalculationEntityLoader', 'CodeEntityLoader', 'ComputerEntityLoader', 'GroupEntityLoader', 'NodeEntityLoader' + 'load_code', + 'load_computer', + 'load_group', + 'load_node', + 'load_entity', + 'get_loader', + 'OrmEntityLoader', + 'CalculationEntityLoader', + 'CodeEntityLoader', + 'ComputerEntityLoader', + 'GroupEntityLoader', + 'NodeEntityLoader', ) def load_entity( entity_loader=None, identifier=None, pk=None, uuid=None, label=None, sub_classes=None, query_with_dashes=True ): - # pylint: disable=too-many-arguments - """ - Load an entity instance by one of its identifiers: pk, uuid or label + """Load an entity instance by one of its identifiers: pk, uuid or label If the type of the identifier is unknown simply pass it without a keyword and the loader will attempt to automatically infer the type. 
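# Editor's note: minimal sketch of the degree rules enforced above. A CREATE link's
# outdegree is 'unique_pair', so reusing the same (type, label) pair from one source
# node raises ValueError. Assumes a loaded profile; not part of the patch.
from aiida import load_profile, orm
from aiida.common.links import LinkType
from aiida.orm.utils.links import validate_link

load_profile()

calc = orm.CalculationNode().store()
data = orm.Data()

validate_link(calc, data, LinkType.CREATE, 'result')  # passes: no such link yet
data.base.links.add_incoming(calc, LinkType.CREATE, 'result')
data.store()

try:
    validate_link(calc, orm.Data(), LinkType.CREATE, 'result')
except ValueError as exc:
    print(exc)  # calc already has an outgoing CREATE link labelled 'result'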
@@ -59,7 +67,6 @@ def load_entity( raise ValueError("only one of parameters 'identifier', pk', 'uuid' or 'label' has to be specified") if pk is not None: - if not isinstance(pk, int): raise TypeError('a pk has to be an integer') @@ -67,7 +74,6 @@ def load_entity( identifier_type = IdentifierType.ID elif uuid is not None: - if not isinstance(uuid, str): raise TypeError('uuid has to be a string type') @@ -75,7 +81,6 @@ def load_entity( identifier_type = IdentifierType.UUID elif label is not None: - if not isinstance(label, str): raise TypeError('label has to be a string type') @@ -91,8 +96,7 @@ def load_entity( def load_code(identifier=None, pk=None, uuid=None, label=None, sub_classes=None, query_with_dashes=True) -> 'Code': - """ - Load a Code instance by one of its identifiers: pk, uuid or label + """Load a Code instance by one of its identifiers: pk, uuid or label If the type of the identifier is unknown simply pass it without a keyword and the loader will attempt to automatically infer the type. @@ -117,15 +121,14 @@ def load_code(identifier=None, pk=None, uuid=None, label=None, sub_classes=None, uuid=uuid, label=label, sub_classes=sub_classes, - query_with_dashes=query_with_dashes + query_with_dashes=query_with_dashes, ) def load_computer( identifier=None, pk=None, uuid=None, label=None, sub_classes=None, query_with_dashes=True ) -> 'Computer': - """ - Load a Computer instance by one of its identifiers: pk, uuid or label + """Load a Computer instance by one of its identifiers: pk, uuid or label If the type of the identifier is unknown simply pass it without a keyword and the loader will attempt to automatically infer the type. @@ -150,13 +153,12 @@ def load_computer( uuid=uuid, label=label, sub_classes=sub_classes, - query_with_dashes=query_with_dashes + query_with_dashes=query_with_dashes, ) def load_group(identifier=None, pk=None, uuid=None, label=None, sub_classes=None, query_with_dashes=True) -> 'Group': - """ - Load a Group instance by one of its identifiers: pk, uuid or label + """Load a Group instance by one of its identifiers: pk, uuid or label If the type of the identifier is unknown simply pass it without a keyword and the loader will attempt to automatically infer the type. @@ -181,13 +183,12 @@ def load_group(identifier=None, pk=None, uuid=None, label=None, sub_classes=None uuid=uuid, label=label, sub_classes=sub_classes, - query_with_dashes=query_with_dashes + query_with_dashes=query_with_dashes, ) def load_node(identifier=None, pk=None, uuid=None, label=None, sub_classes=None, query_with_dashes=True) -> 'Node': - """ - Load a node by one of its identifiers: pk or uuid. If the type of the identifier is unknown + """Load a node by one of its identifiers: pk or uuid. If the type of the identifier is unknown simply pass it without a keyword and the loader will attempt to infer the type :param identifier: pk (integer) or uuid (string) @@ -210,7 +211,7 @@ def load_node(identifier=None, pk=None, uuid=None, label=None, sub_classes=None, uuid=uuid, label=label, sub_classes=sub_classes, - query_with_dashes=query_with_dashes + query_with_dashes=query_with_dashes, ) @@ -239,12 +240,10 @@ def get_loader(orm_class): class IdentifierType(Enum): - """ - The enumeration that defines the three types of identifier that can be used to identify an orm entity. + """The enumeration that defines the three types of identifier that can be used to identify an orm entity. 
The ID is always an integer, the UUID a base 16 encoded integer with optional dashes and the LABEL can be any string based label or name, the format of which will vary per orm class """ - # pylint: disable=invalid-name ID = 'ID' UUID = 'UUID' @@ -258,8 +257,7 @@ class OrmEntityLoader: @classproperty def orm_base_class(self): - """ - Return the orm base class to which loaded entities should be mapped. Actual queries to load an entity + """Return the orm base class to which loaded entities should be mapped. Actual queries to load an entity may further narrow the query set by defining a more specific set of orm classes, as long as each of those is a strict sub class of the orm base class. @@ -270,8 +268,7 @@ def orm_base_class(self): @classmethod @abstractmethod def _get_query_builder_label_identifier(cls, identifier, classes, operator='==', project='*'): - """ - Return the query builder instance that attempts to map the identifier onto an entity of the orm class, + """Return the query builder instance that attempts to map the identifier onto an entity of the orm class, defined for this loader class, interpreting the identifier as a LABEL like identifier :param identifier: the LABEL identifier @@ -286,8 +283,7 @@ def _get_query_builder_label_identifier(cls, identifier, classes, operator='==', @classmethod def _get_query_builder_id_identifier(cls, identifier, classes): - """ - Return the query builder instance that attempts to map the identifier onto an entity of the orm class, + """Return the query builder instance that attempts to map the identifier onto an entity of the orm class, defined for this loader class, interpreting the identifier as an ID like identifier :param identifier: the ID identifier @@ -302,8 +298,7 @@ def _get_query_builder_id_identifier(cls, identifier, classes): @classmethod def _get_query_builder_uuid_identifier(cls, identifier, classes, query_with_dashes): - """ - Return the query builder instance that attempts to map the identifier onto an entity of the orm class, + """Return the query builder instance that attempts to map the identifier onto an entity of the orm class, defined for this loader class, interpreting the identifier as a UUID like identifier :param identifier: the UUID identifier @@ -336,8 +331,7 @@ def _get_query_builder_uuid_identifier(cls, identifier, classes, query_with_dash def get_query_builder( cls, identifier, identifier_type=None, sub_classes=None, query_with_dashes=True, operator='==', project='*' ): - """ - Return the query builder instance that attempts to map the identifier onto an entity of the orm class, + """Return the query builder instance that attempts to map the identifier onto an entity of the orm class, defined for this loader class, inferring the identifier type if it is not defined. :param identifier: the identifier @@ -348,7 +342,6 @@ def get_query_builder( :param project: the property or properties to project for entities matching the query :returns: the query builder instance and a dictionary of used query parameters """ - # pylint: disable=too-many-arguments classes = cls.get_query_classes(sub_classes) if identifier_type is None: @@ -385,8 +378,7 @@ def get_options(cls, incomplete, project='*'): @classmethod def load_entity(cls, identifier, identifier_type=None, sub_classes=None, query_with_dashes=True): - """ - Load an entity that uniquely corresponds to the provided identifier of the identifier type. + """Load an entity that uniquely corresponds to the provided identifier of the identifier type. 
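# Editor's note: hedged sketch of the identifier inference described above: an
# integer-like value is tried as an ID, a base-16 string as a UUID (prefixes work
# too, provided they are unambiguous). Assumes a loaded profile; not part of the patch.
from aiida import load_profile, orm

load_profile()

node = orm.Int(5).store()

assert orm.load_node(node.pk).uuid == node.uuid          # integer -> ID
assert orm.load_node(node.uuid).pk == node.pk            # full hex string -> UUID
assert orm.load_node(uuid=node.uuid[:8]).pk == node.pk   # UUID prefix, if unambiguous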
:param identifier: the identifier :param identifier_type: the type of the identifier @@ -416,8 +408,7 @@ def load_entity(cls, identifier, identifier_type=None, sub_classes=None, query_w @classmethod def get_query_classes(cls, sub_classes=None): - """ - Get the tuple of classes to be used for the entity query. If sub_classes is defined, each class will be + """Get the tuple of classes to be used for the entity query. If sub_classes is defined, each class will be validated by verifying that it is a sub class of the loader's orm base class. Validate a tuple of classes if a user passes in a specific one when attempting to load an entity. Each class should be a sub class of the entity loader's orm base class @@ -441,9 +432,7 @@ def get_query_classes(cls, sub_classes=None): @classmethod def infer_identifier_type(cls, value): - """ - This method will attempt to automatically distinguish which identifier type is implied for the given value, if - the value itself has no type from which it can be inferred. + """Attempt to automatically distinguish which identifier type is implied for the given value. The strategy is to first attempt to convert the value to an integer. If successful, it is assumed that the value represents an ID. If that fails, we attempt to interpret the value as a base 16 encoded integer, after having @@ -475,18 +464,15 @@ def infer_identifier_type(cls, value): # If the final character of the value is the special marker, we enforce LABEL interpretation if value[-1] == cls.label_ambiguity_breaker: - identifier = value.rstrip(cls.label_ambiguity_breaker) identifier_type = IdentifierType.LABEL else: - # If the value can be cast into an integer, interpret it as an ID try: identifier = int(value) identifier_type = IdentifierType.ID except ValueError: - # If the value is a valid base sixteen encoded integer, after dashes are removed, interpret it as a UUID try: int(value.replace('-', ''), 16) @@ -506,20 +492,19 @@ class ProcessEntityLoader(OrmEntityLoader): @classproperty def orm_base_class(self): - """ - Return the orm base class to which loaded entities should be mapped. Actual queries to load an entity + """Return the orm base class to which loaded entities should be mapped. Actual queries to load an entity may further narrow the query set by defining a more specific set of orm classes, as long as each of those is a strict sub class of the orm base class. :returns: the orm base class """ from aiida.orm import ProcessNode + return ProcessNode @classmethod def _get_query_builder_label_identifier(cls, identifier, classes, operator='==', project='*'): - """ - Return the query builder instance that attempts to map the identifier onto an entity of the orm class, + """Return the query builder instance that attempts to map the identifier onto an entity of the orm class, defined for this loader class, interpreting the identifier as a LABEL like identifier :param identifier: the LABEL identifier @@ -546,20 +531,19 @@ class CalculationEntityLoader(OrmEntityLoader): @classproperty def orm_base_class(self): - """ - Return the orm base class to which loaded entities should be mapped. Actual queries to load an entity + """Return the orm base class to which loaded entities should be mapped. Actual queries to load an entity may further narrow the query set by defining a more specific set of orm classes, as long as each of those is a strict sub class of the orm base class. 
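# Editor's note: sketch of the disambiguation rule above; a trailing '!'
# (OrmEntityLoader.label_ambiguity_breaker) forces LABEL interpretation of a value
# that would otherwise parse as an ID. No profile is needed for this check alone.
from aiida.orm.utils.loaders import IdentifierType, NodeEntityLoader

identifier, identifier_type = NodeEntityLoader.infer_identifier_type('12345')
assert identifier == 12345 and identifier_type is IdentifierType.ID

identifier, identifier_type = NodeEntityLoader.infer_identifier_type('12345!')
assert identifier == '12345' and identifier_type is IdentifierType.LABEL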
:returns: the orm base class """ from aiida.orm import CalculationNode + return CalculationNode @classmethod def _get_query_builder_label_identifier(cls, identifier, classes, operator='==', project='*'): - """ - Return the query builder instance that attempts to map the identifier onto an entity of the orm class, + """Return the query builder instance that attempts to map the identifier onto an entity of the orm class, defined for this loader class, interpreting the identifier as a LABEL like identifier :param identifier: the LABEL identifier @@ -586,20 +570,19 @@ class WorkflowEntityLoader(OrmEntityLoader): @classproperty def orm_base_class(self): - """ - Return the orm base class to which loaded entities should be mapped. Actual queries to load an entity + """Return the orm base class to which loaded entities should be mapped. Actual queries to load an entity may further narrow the query set by defining a more specific set of orm classes, as long as each of those is a strict sub class of the orm base class. :returns: the orm base class """ from aiida.orm import WorkflowNode + return WorkflowNode @classmethod def _get_query_builder_label_identifier(cls, identifier, classes, operator='==', project='*'): - """ - Return the query builder instance that attempts to map the identifier onto an entity of the orm class, + """Return the query builder instance that attempts to map the identifier onto an entity of the orm class, defined for this loader class, interpreting the identifier as a LABEL like identifier :param identifier: the LABEL identifier @@ -626,20 +609,19 @@ class CodeEntityLoader(OrmEntityLoader): @classproperty def orm_base_class(self): - """ - Return the orm base class to which loaded entities should be mapped. Actual queries to load an entity + """Return the orm base class to which loaded entities should be mapped. Actual queries to load an entity may further narrow the query set by defining a more specific set of orm classes, as long as each of those is a strict sub class of the orm base class. :returns: the orm base class """ from aiida.orm import Code + return Code @classmethod def _get_query_builder_label_identifier(cls, identifier, classes, operator='==', project='*'): - """ - Return the query builder instance that attempts to map the identifier onto an entity of the orm class, + """Return the query builder instance that attempts to map the identifier onto an entity of the orm class, defined for this loader class, interpreting the identifier as a LABEL like identifier :param identifier: the LABEL identifier @@ -675,20 +657,19 @@ class ComputerEntityLoader(OrmEntityLoader): @classproperty def orm_base_class(self): - """ - Return the orm base class to which loaded entities should be mapped. Actual queries to load an entity + """Return the orm base class to which loaded entities should be mapped. Actual queries to load an entity may further narrow the query set by defining a more specific set of orm classes, as long as each of those is a strict sub class of the orm base class. 
:returns: the orm base class """ from aiida.orm import Computer + return Computer @classmethod def _get_query_builder_label_identifier(cls, identifier, classes, operator='==', project='*'): - """ - Return the query builder instance that attempts to map the identifier onto an entity of the orm class, + """Return the query builder instance that attempts to map the identifier onto an entity of the orm class, defined for this loader class, interpreting the identifier as a LABEL like identifier :param identifier: the LABEL identifier @@ -715,20 +696,19 @@ class DataEntityLoader(OrmEntityLoader): @classproperty def orm_base_class(self): - """ - Return the orm base class to which loaded entities should be mapped. Actual queries to load an entity + """Return the orm base class to which loaded entities should be mapped. Actual queries to load an entity may further narrow the query set by defining a more specific set of orm classes, as long as each of those is a strict sub class of the orm base class. :returns: the orm base class """ from aiida.orm import Data + return Data @classmethod def _get_query_builder_label_identifier(cls, identifier, classes, operator='==', project='*'): - """ - Return the query builder instance that attempts to map the identifier onto an entity of the orm class, + """Return the query builder instance that attempts to map the identifier onto an entity of the orm class, defined for this loader class, interpreting the identifier as a LABEL like identifier :param identifier: the LABEL identifier @@ -755,20 +735,19 @@ class GroupEntityLoader(OrmEntityLoader): @classproperty def orm_base_class(self): - """ - Return the orm base class to which loaded entities should be mapped. Actual queries to load an entity + """Return the orm base class to which loaded entities should be mapped. Actual queries to load an entity may further narrow the query set by defining a more specific set of orm classes, as long as each of those is a strict sub class of the orm base class. :returns: the orm base class """ from aiida.orm.groups import Group + return Group @classmethod def _get_query_builder_label_identifier(cls, identifier, classes, operator='==', project='*'): - """ - Return the query builder instance that attempts to map the identifier onto an entity of the orm class, + """Return the query builder instance that attempts to map the identifier onto an entity of the orm class, defined for this loader class, interpreting the identifier as a LABEL like identifier :param identifier: the LABEL identifier @@ -795,20 +774,19 @@ class NodeEntityLoader(OrmEntityLoader): @classproperty def orm_base_class(self): - """ - Return the orm base class to which loaded entities should be mapped. Actual queries to load an entity + """Return the orm base class to which loaded entities should be mapped. Actual queries to load an entity may further narrow the query set by defining a more specific set of orm classes, as long as each of those is a strict sub class of the orm base class. 
:returns: the orm base class """ from aiida.orm import Node + return Node @classmethod def _get_query_builder_label_identifier(cls, identifier, classes, operator='==', project='*'): - """ - Return the query builder instance that attempts to map the identifier onto an entity of the orm class, + """Return the query builder instance that attempts to map the identifier onto an entity of the orm class, defined for this loader class, interpreting the identifier as a LABEL like identifier :param identifier: the LABEL identifier diff --git a/aiida/orm/utils/log.py b/aiida/orm/utils/log.py index bc279bec7a..34f685faee 100644 --- a/aiida/orm/utils/log.py +++ b/aiida/orm/utils/log.py @@ -31,10 +31,11 @@ def emit(self, record): # The backend should be set. We silently absorb this error pass - except Exception: # pylint: disable=broad-except + except Exception: # To avoid loops with the error handler, I just print. # Hopefully, though, this should not happen! import traceback + traceback.print_exc() raise diff --git a/aiida/orm/utils/managers.py b/aiida/orm/utils/managers.py index 3e04631751..8ddd672781 100644 --- a/aiida/orm/utils/managers.py +++ b/aiida/orm/utils/managers.py @@ -7,8 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -""" -Contain utility classes for "managers", i.e., classes that allow +"""Contain utility classes for "managers", i.e., classes that allow to access members of other classes via TAB-completable attributes (e.g. the class underlying `calculation.inputs` to allow to do `calculation.inputs.