From e2a686237ac4efba14b56defd42c48236fd8b6d2 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Tue, 22 Feb 2022 02:56:47 +0100 Subject: [PATCH 1/3] =?UTF-8?q?=F0=9F=A7=AA=20TESTS:=20Entirely=20remove?= =?UTF-8?q?=20`AiidaTestCase`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .github/system_tests/test_plugin_testcase.py | 117 -- aiida/manage/tests/__init__.py | 3 - aiida/manage/tests/main.py | 6 +- aiida/manage/tests/unittest_classes.py | 100 - aiida/storage/testbase.py | 175 -- tests/cmdline/commands/test_calcjob.py | 201 +- tests/cmdline/commands/test_computer.py | 133 +- tests/cmdline/commands/test_data.py | 564 +++--- tests/cmdline/commands/test_group.py | 306 ++- tests/cmdline/commands/test_help.py | 25 +- tests/cmdline/commands/test_node.py | 346 ++-- tests/cmdline/commands/test_process.py | 251 ++- tests/cmdline/commands/test_profile.py | 125 +- tests/cmdline/commands/test_run.py | 112 +- tests/cmdline/commands/test_setup.py | 92 +- tests/cmdline/commands/test_user.py | 36 +- .../params/options/test_conditional.py | 4 +- .../cmdline/params/types/test_calculation.py | 47 +- tests/cmdline/params/types/test_data.py | 39 +- tests/cmdline/params/types/test_identifier.py | 30 +- tests/cmdline/params/types/test_node.py | 40 +- tests/cmdline/params/types/test_path.py | 14 +- tests/cmdline/params/types/test_plugin.py | 131 +- tests/common/test_hashing.py | 176 +- tests/common/test_links.py | 33 +- tests/conftest.py | 2 +- tests/engine/daemon/test_client.py | 42 +- .../processes/calcjobs/test_calc_job.py | 96 +- .../engine/processes/workchains/test_utils.py | 14 +- tests/engine/test_calcfunctions.py | 75 +- tests/engine/test_class_loader.py | 24 +- tests/engine/test_futures.py | 9 +- tests/engine/test_launch.py | 119 +- tests/engine/test_manager.py | 49 +- tests/engine/test_persistence.py | 33 +- tests/engine/test_ports.py | 46 +- tests/engine/test_process.py | 142 +- tests/engine/test_process_function.py | 193 +- 
tests/engine/test_process_spec.py | 55 +- tests/engine/test_rmq.py | 58 +- tests/engine/test_run.py | 7 +- tests/engine/test_transport.py | 36 +- tests/engine/test_utils.py | 77 +- tests/engine/test_work_chain.py | 385 ++-- tests/engine/test_workfunctions.py | 35 +- tests/manage/configuration/test_options.py | 32 +- tests/manage/configuration/test_profile.py | 49 +- tests/orm/implementation/test_comments.py | 172 +- tests/orm/implementation/test_logs.py | 146 +- tests/orm/implementation/test_nodes.py | 361 ++-- tests/orm/implementation/test_utils.py | 14 +- tests/orm/nodes/data/test_kpoints.py | 31 +- tests/orm/nodes/data/test_orbital.py | 30 +- tests/orm/nodes/data/test_trajectory.py | 31 +- tests/orm/nodes/data/test_upf.py | 96 +- tests/orm/nodes/test_calcjob.py | 24 +- tests/orm/test_authinfos.py | 23 +- tests/orm/test_comments.py | 75 +- tests/orm/test_computers.py | 38 +- tests/orm/test_entities.py | 20 +- tests/orm/test_groups.py | 186 +- tests/orm/test_logs.py | 102 +- tests/orm/test_mixins.py | 11 +- tests/orm/utils/test_loaders.py | 108 +- tests/orm/utils/test_managers.py | 4 +- tests/orm/utils/test_node.py | 20 +- tests/parsers/test_parser.py | 35 +- tests/plugins/test_utils.py | 21 +- tests/restapi/test_routes.py | 450 ++--- tests/storage/psql_dos/test_nodes.py | 44 +- tests/storage/psql_dos/test_query.py | 59 +- tests/storage/psql_dos/test_schema.py | 63 +- tests/test_calculation_node.py | 61 +- tests/test_dataclasses.py | 1712 ++++++++--------- tests/test_dbimporters.py | 213 +- tests/test_generic.py | 132 +- tests/test_nodes.py | 573 +++--- tests/tools/data/orbital/test_orbitals.py | 75 +- tests/tools/dbimporters/test_icsd.py | 29 +- .../dbimporters/test_materialsproject.py | 5 +- tests/tools/graph/test_age.py | 111 +- tests/tools/graph/test_graph_traversers.py | 52 +- tests/tools/visualization/test_graph.py | 73 +- 83 files changed, 4473 insertions(+), 5411 deletions(-) delete mode 100644 .github/system_tests/test_plugin_testcase.py delete mode 
100644 aiida/manage/tests/unittest_classes.py delete mode 100644 aiida/storage/testbase.py diff --git a/.github/system_tests/test_plugin_testcase.py b/.github/system_tests/test_plugin_testcase.py deleted file mode 100644 index afc841da75..0000000000 --- a/.github/system_tests/test_plugin_testcase.py +++ /dev/null @@ -1,117 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -""" -Test the plugin test case - -This must be in a standalone script because it would clash with other tests, -Since the dbenv gets loaded on the temporary profile. -""" - -import shutil -import sys -import tempfile -import unittest - -from aiida.manage.tests.unittest_classes import PluginTestCase, TestRunner - - -class PluginTestCase1(PluginTestCase): - """ - Test the PluginTestCase from utils.fixtures - """ - - def setUp(self): - self.temp_dir = tempfile.mkdtemp() - self.data = self.get_data() - self.data_pk = self.data.pk - self.computer = self.get_computer(temp_dir=self.temp_dir) - - def tearDown(self): - super().tearDown() - shutil.rmtree(self.temp_dir) - - @staticmethod - def get_data(): - """ - Return some Dict - """ - from aiida.plugins import DataFactory - data = DataFactory('core.dict')(dict={'data': 'test'}) - data.store() - return data - - @classmethod - def get_computer(cls, temp_dir): - """ - Create and store a new computer, and return it - """ - from aiida import orm - - computer = orm.Computer( - label='localhost', - hostname='localhost', - description='my computer', - transport_type='core.local', - scheduler_type='core.direct', - 
workdir=temp_dir, - backend=cls.backend - ).store() - return computer - - def test_data_loaded(self): - """ - Check that the data node is indeed in the DB when calling load_node - """ - from aiida import orm - self.assertEqual(orm.load_node(self.data_pk).uuid, self.data.uuid) - - def test_computer_loaded(self): - """ - Check that the computer is indeed in the DB when calling load_node - - Note: Important to have at least two test functions in order to verify things - work after resetting the DB. - """ - from aiida import orm - self.assertEqual(orm.Computer.objects.get(label='localhost').uuid, self.computer.uuid) - - def test_tear_down(self): - """ - Check that after tearing down, the previously stored nodes - are not there anymore. - """ - from aiida.orm import load_node - super().tearDown() # reset DB - with self.assertRaises(Exception): - load_node(self.data_pk) - - -class PluginTestCase2(PluginTestCase): - """ - Second PluginTestCase. - """ - - def test_dummy(self): - """ - Dummy test for 2nd PluginTestCase class. - - Just making sure that setup/teardown is safe for - multiple testcase classes (this was broken in #1425). 
- """ - super().tearDown() - - -if __name__ == '__main__': - MODULE = sys.modules[__name__] - SUITE = unittest.defaultTestLoader.loadTestsFromModule(MODULE) - RESULT = TestRunner().run(SUITE) - - EXIT_CODE = int(not RESULT.wasSuccessful()) - sys.exit(EXIT_CODE) diff --git a/aiida/manage/tests/__init__.py b/aiida/manage/tests/__init__.py index 813f38628c..f1efdf4603 100644 --- a/aiida/manage/tests/__init__.py +++ b/aiida/manage/tests/__init__.py @@ -17,15 +17,12 @@ # pylint: disable=wildcard-import from .main import * -from .unittest_classes import * __all__ = ( - 'PluginTestCase', 'ProfileManager', 'TemporaryProfileManager', 'TestManager', 'TestManagerError', - 'TestRunner', 'get_test_backend_name', 'get_test_profile_name', 'get_user_dict', diff --git a/aiida/manage/tests/main.py b/aiida/manage/tests/main.py index 3d8857e427..450c8a0450 100644 --- a/aiida/manage/tests/main.py +++ b/aiida/manage/tests/main.py @@ -147,14 +147,16 @@ def __init__(self, profile_name): :param profile_name: Name of the profile to be loaded """ from aiida import load_profile - from aiida.storage.testbase import check_if_tests_can_run self._profile = None try: self._profile = load_profile(profile_name) except Exception: raise TestManagerError(f'Unable to load test profile `{profile_name}`.') - check_if_tests_can_run() + if self._profile is None: + raise TestManagerError(f'Unable to load test profile `{profile_name}`.') + if not self._profile.is_test_profile: + raise TestManagerError(f'Profile `{profile_name}` is not a valid test profile.') @staticmethod def clear_profile(): diff --git a/aiida/manage/tests/unittest_classes.py b/aiida/manage/tests/unittest_classes.py deleted file mode 100644 index 55f58aedcd..0000000000 --- a/aiida/manage/tests/unittest_classes.py +++ /dev/null @@ -1,100 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -""" -Test classes and test runners for testing AiiDA plugins with unittest. -""" -import unittest -import warnings - -from aiida.common.warnings import AiidaDeprecationWarning -from aiida.manage import get_manager - -from .main import _GLOBAL_TEST_MANAGER, get_test_backend_name, get_test_profile_name, test_manager - -__all__ = ('PluginTestCase', 'TestRunner') - -warnings.warn( # pylint: disable=no-member - 'This module has been deprecated and will be removed soon. Please use the `pytest` fixtures instead.\n' - 'See https://github.com/aiidateam/aiida-core/wiki/AiiDA-2.0-plugin-migration-guide#unit-tests', - AiidaDeprecationWarning -) - - -class PluginTestCase(unittest.TestCase): - """ - Set up a complete temporary AiiDA environment for plugin tests. - - Note: This test class needs to be run through the :py:class:`~aiida.manage.tests.unittest_classes.TestRunner` - and will **not** work simply with `python -m unittest discover`. - - Usage example:: - - MyTestCase(aiida.manage.tests.unittest_classes.PluginTestCase): - - def setUp(self): - # load my tests data - - # optionally extend setUpClass / tearDownClass / tearDown if needed - - def test_my_plugin(self): - # execute tests - """ - # Filled in during setUpClass - backend = None # type :class:`aiida.orm.implementation.Backend` - - @classmethod - def setUpClass(cls): - cls.test_manager = _GLOBAL_TEST_MANAGER - if not cls.test_manager.has_profile_open(): - raise ValueError( - 'Fixture mananger has no open profile.' + - 'Please use aiida.manage.tests.unittest_classes.TestRunner to run these tests.' 
- ) - - cls.backend = get_manager().get_profile_storage() - - def tearDown(self): - manager = get_manager() - if manager.profile_storage_loaded: - manager.get_profile_storage()._clear(recreate_user=True) # pylint: disable=protected-access - - -class TestRunner(unittest.runner.TextTestRunner): - """ - Testrunner for unit tests using the fixture manager. - - Usage example:: - - import unittest - from aiida.manage.tests.unittest_classes import TestRunner - - tests = unittest.defaultTestLoader.discover('.') - TestRunner().run(tests) - - """ - - # pylint: disable=arguments-differ - def run(self, suite, backend=None, profile_name=None): - """ - Run tests using fixture manager for specified backend. - - :param suite: A suite of tests, as returned e.g. by :py:meth:`unittest.TestLoader.discover` - :param backend: name of database backend to be used. - :param profile_name: name of test profile to be used or None (will use temporary profile) - """ - warnings.warn( # pylint: disable=no-member - 'Please use "pytest" for testing AiiDA plugins. Support for "unittest" will be removed soon', - AiidaDeprecationWarning - ) - - with test_manager( - backend=backend or get_test_backend_name(), profile_name=profile_name or get_test_profile_name() - ): - return super().run(suite) diff --git a/aiida/storage/testbase.py b/aiida/storage/testbase.py deleted file mode 100644 index 30a093f658..0000000000 --- a/aiida/storage/testbase.py +++ /dev/null @@ -1,175 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Basic test classes.""" -import traceback -from typing import Optional -import unittest - -from aiida import orm -from aiida.common.exceptions import TestsNotAllowedError -from aiida.common.lang import classproperty -from aiida.manage import configuration, get_manager -from aiida.orm.implementation import StorageBackend - -TEST_KEYWORD = 'test_' - - -def check_if_tests_can_run(): - """Verify that the currently loaded profile is a test profile, otherwise raise `TestsNotAllowedError`.""" - profile = configuration.get_profile() - if not profile: - raise TestsNotAllowedError('No profile is loaded.') - if not profile.is_test_profile: - raise TestsNotAllowedError(f'currently loaded profile {profile.name} is not a valid test profile') - - -class AiidaTestCase(unittest.TestCase): - """This is the base class for AiiDA tests, independent of the backend.""" - _class_was_setup = False - backend: Optional[StorageBackend] = None - - @classmethod - def setUpClass(cls): - """Set up test class.""" - # Note: this will raise an exception, that will be seen as a test - # failure. 
To be safe, you should do the same check also in the tearDownClass - # to avoid that it is run - check_if_tests_can_run() - - # Force the loading of the backend which will load the required database environment - cls._class_was_setup = True - cls.clean_db() - cls.backend = get_manager().get_profile_storage() - - @classmethod - def tearDownClass(cls): - """Tear down test class, by clearing all backend storage.""" - # Double check for double security to avoid to run the tearDown - # if this is not a test profile - - check_if_tests_can_run() - cls.clean_db() - - def tearDown(self): - manager = get_manager() - # this should really call reset profile, but that also resets the storage backend - # and causes issues for some existing tests that set class level entities - # manager.reset_profile() - # pylint: disable=protected-access - if manager._communicator is not None: - manager._communicator.close() - if manager._runner is not None: - manager._runner.stop() - manager._communicator = None - manager._runner = None - manager._daemon_client = None - manager._process_controller = None - manager._persister = None - - ### storage methods - - @classmethod - def clean_db(cls): - """Clean up database and reset caches. - - Resets AiiDA manager cache, which could otherwise be left in an inconsistent state when cleaning the database. - """ - from aiida.common.exceptions import InvalidOperation - - # Note: this will raise an exception, that will be seen as a test - # failure. 
To be safe, you should do the same check also in the tearDownClass - # to avoid that it is run - check_if_tests_can_run() - - if not cls._class_was_setup: - raise InvalidOperation('You cannot call clean_db before running the setUpClass') - - manager = get_manager() - manager.get_profile_storage()._clear(recreate_user=True) # pylint: disable=protected-access - manager.reset_profile() - - @classmethod - def refurbish_db(cls): - """Clean up database and repopulate with initial data.""" - cls.clean_db() - - @classproperty - def computer(cls) -> orm.Computer: # pylint: disable=no-self-argument - """Get the default computer for this test - - :return: the test computer - """ - created, computer = orm.Computer.objects.get_or_create( - label='localhost', - hostname='localhost', - transport_type='core.local', - scheduler_type='core.direct', - workdir='/tmp/aiida', - ) - if created: - computer.store() - return computer - - @classproperty - def user(cls) -> orm.User: # pylint: disable=no-self-argument - return get_default_user() - - @classproperty - def user_email(cls) -> str: # pylint: disable=no-self-argument - return cls.user.email # pylint: disable=no-member - - ### Usability methods - - def assertClickSuccess(self, cli_result): # pylint: disable=invalid-name - self.assertEqual(cli_result.exit_code, 0, cli_result.output) - self.assertClickResultNoException(cli_result) - - def assertClickResultNoException(self, cli_result): # pylint: disable=invalid-name - self.assertIsNone(cli_result.exception, ''.join(traceback.format_exception(*cli_result.exc_info))) - - -class AiidaPostgresTestCase(AiidaTestCase): - """Setup postgres tests.""" - - @classmethod - def setUpClass(cls, *args, **kwargs): - """Setup the PGTest postgres test cluster.""" - from pgtest.pgtest import PGTest - cls.pg_test = PGTest() - super().setUpClass(*args, **kwargs) - - @classmethod - def tearDownClass(cls, *args, **kwargs): - """Close the PGTest postgres test cluster.""" - super().tearDownClass(*args, 
**kwargs) - cls.pg_test.close() - - -def get_default_user(**kwargs): - """Creates and stores the default user in the database. - - Default user email is taken from current profile. - No-op if user already exists. - The same is done in `verdi setup`. - - :param kwargs: Additional information to use for new user, i.e. 'first_name', 'last_name' or 'institution'. - :returns: the :py:class:`~aiida.orm.User` - """ - email = configuration.get_profile().default_user_email - - if kwargs.pop('email', None): - raise ValueError('Do not specify the user email (must coincide with default user email of profile).') - - # Create the AiiDA user if it does not yet exist - created, user = orm.User.objects.get_or_create(email=email, **kwargs) - if created: - user.store() - - return user diff --git a/tests/cmdline/commands/test_calcjob.py b/tests/cmdline/commands/test_calcjob.py index 9c75795392..7107d447ed 100644 --- a/tests/cmdline/commands/test_calcjob.py +++ b/tests/cmdline/commands/test_calcjob.py @@ -12,14 +12,16 @@ import io from click.testing import CliRunner +import pytest from aiida import orm from aiida.cmdline.commands import cmd_calcjob as command from aiida.common.datastructures import CalcJobState +from aiida.common.links import LinkType +from aiida.engine import ProcessState from aiida.orm.nodes.data.remote.base import RemoteData from aiida.plugins import CalculationFactory from aiida.plugins.entry_point import get_entry_point_string_from_class -from aiida.storage.testbase import AiidaTestCase from tests.utils.archives import import_test_archive @@ -27,31 +29,19 @@ def get_result_lines(result): return [e for e in result.output.split('\n') if e] -class TestVerdiCalculation(AiidaTestCase): +class TestVerdiCalculation: """Tests for `verdi calcjob`.""" - @classmethod - def setUpClass(cls, *args, **kwargs): - super().setUpClass(*args, **kwargs) - from aiida.common.links import LinkType - from aiida.engine import ProcessState - - cls.computer = orm.Computer( - label='comp', - 
hostname='localhost', - transport_type='core.local', - scheduler_type='core.direct', - workdir='/tmp/aiida' - ).store() - - cls.code = orm.Code(remote_computer_exec=(cls.computer, '/bin/true')).store() - cls.group = orm.Group(label='test_group').store() - cls.node = orm.Data().store() - cls.calcs = [] - - user = orm.User.objects.get_default() - authinfo = orm.AuthInfo(computer=cls.computer, user=user) - authinfo.store() + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean, aiida_localhost): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init + + self.computer = aiida_localhost + self.code = orm.Code(remote_computer_exec=(self.computer, '/bin/true')).store() + self.group = orm.Group(label='test_group').store() + self.node = orm.Data().store() + self.calcs = [] process_class = CalculationFactory('core.templatereplacer') process_type = get_entry_point_string_from_class(process_class.__module__, process_class.__name__) @@ -59,7 +49,7 @@ def setUpClass(cls, *args, **kwargs): # Create 5 CalcJobNodes (one for each CalculationState) for calculation_state in CalcJobState: - calc = orm.CalcJobNode(computer=cls.computer, process_type=process_type) + calc = orm.CalcJobNode(computer=self.computer, process_type=process_type) calc.set_option('resources', {'num_machines': 1, 'num_mpiprocs_per_machine': 1}) calc.set_remote_workdir('/tmp/aiida/work') remote = RemoteData(remote_path='/tmp/aiida/work') @@ -69,40 +59,40 @@ def setUpClass(cls, *args, **kwargs): remote.store() calc.set_process_state(ProcessState.RUNNING) - cls.calcs.append(calc) + self.calcs.append(calc) if calculation_state == CalcJobState.PARSING: - cls.KEY_ONE = 'key_one' - cls.KEY_TWO = 'key_two' - cls.VAL_ONE = 'val_one' - cls.VAL_TWO = 'val_two' + self.KEY_ONE = 'key_one' + self.KEY_TWO = 'key_two' + self.VAL_ONE = 'val_one' + self.VAL_TWO = 'val_two' output_parameters = orm.Dict(dict={ - cls.KEY_ONE: cls.VAL_ONE, - 
cls.KEY_TWO: cls.VAL_TWO, + self.KEY_ONE: self.VAL_ONE, + self.KEY_TWO: self.VAL_TWO, }).store() output_parameters.add_incoming(calc, LinkType.CREATE, 'output_parameters') # Create shortcut for easy dereferencing - cls.result_job = calc + self.result_job = calc # Add a single calc to a group - cls.group.add_nodes([calc]) + self.group.add_nodes([calc]) # Create a single failed CalcJobNode - cls.EXIT_STATUS = 100 - calc = orm.CalcJobNode(computer=cls.computer) + self.EXIT_STATUS = 100 + calc = orm.CalcJobNode(computer=self.computer) calc.set_option('resources', {'num_machines': 1, 'num_mpiprocs_per_machine': 1}) calc.store() - calc.set_exit_status(cls.EXIT_STATUS) + calc.set_exit_status(self.EXIT_STATUS) calc.set_process_state(ProcessState.FINISHED) calc.set_remote_workdir('/tmp/aiida/work') remote = RemoteData(remote_path='/tmp/aiida/work') remote.computer = calc.computer remote.add_incoming(calc, LinkType.CREATE, link_label='remote_folder') remote.store() - cls.calcs.append(calc) + self.calcs.append(calc) # Load the fixture containing a single ArithmeticAddCalculation node import_test_archive('calcjob/arithmetic.add.aiida') @@ -110,101 +100,98 @@ def setUpClass(cls, *args, **kwargs): # Get the imported ArithmeticAddCalculation node ArithmeticAddCalculation = CalculationFactory('core.arithmetic.add') calculations = orm.QueryBuilder().append(ArithmeticAddCalculation).all()[0] - cls.arithmetic_job = calculations[0] - print(cls.arithmetic_job.repository_metadata) + self.arithmetic_job = calculations[0] - def setUp(self): - super().setUp() self.cli_runner = CliRunner() def test_calcjob_res(self): """Test verdi calcjob res""" options = [str(self.result_job.uuid)] result = self.cli_runner.invoke(command.calcjob_res, options) - self.assertIsNone(result.exception, result.output) - self.assertIn(self.KEY_ONE, result.output) - self.assertIn(self.VAL_ONE, result.output) - self.assertIn(self.KEY_TWO, result.output) - self.assertIn(self.VAL_TWO, result.output) + assert 
result.exception is None, result.output + assert self.KEY_ONE in result.output + assert self.VAL_ONE in result.output + assert self.KEY_TWO in result.output + assert self.VAL_TWO in result.output for flag in ['-k', '--keys']: options = [flag, self.KEY_ONE, '--', str(self.result_job.uuid)] result = self.cli_runner.invoke(command.calcjob_res, options) - self.assertIsNone(result.exception, result.output) - self.assertIn(self.KEY_ONE, result.output) - self.assertIn(self.VAL_ONE, result.output) - self.assertNotIn(self.KEY_TWO, result.output) - self.assertNotIn(self.VAL_TWO, result.output) + assert result.exception is None, result.output + assert self.KEY_ONE in result.output + assert self.VAL_ONE in result.output + assert self.KEY_TWO not in result.output + assert self.VAL_TWO not in result.output def test_calcjob_inputls(self): """Test verdi calcjob inputls""" options = [] result = self.cli_runner.invoke(command.calcjob_inputls, options) - self.assertIsNotNone(result.exception) + assert result.exception is not None options = [self.arithmetic_job.uuid] result = self.cli_runner.invoke(command.calcjob_inputls, options) - self.assertIsNone(result.exception, result.output) + assert result.exception is None, result.output # There is also an additional fourth file added by hand to test retrieval of binary content # see comments in test_calcjob_inputcat - self.assertEqual(len(get_result_lines(result)), 4) - self.assertIn('.aiida', get_result_lines(result)) - self.assertIn('aiida.in', get_result_lines(result)) - self.assertIn('_aiidasubmit.sh', get_result_lines(result)) - self.assertIn('in_gzipped_data', get_result_lines(result)) + assert len(get_result_lines(result)) == 4 + assert '.aiida' in get_result_lines(result) + assert 'aiida.in' in get_result_lines(result) + assert '_aiidasubmit.sh' in get_result_lines(result) + assert 'in_gzipped_data' in get_result_lines(result) options = [self.arithmetic_job.uuid, '.aiida'] result = self.cli_runner.invoke(command.calcjob_inputls, 
options) - self.assertIsNone(result.exception, result.output) - self.assertEqual(len(get_result_lines(result)), 2) - self.assertIn('calcinfo.json', get_result_lines(result)) - self.assertIn('job_tmpl.json', get_result_lines(result)) + assert result.exception is None, result.output + assert len(get_result_lines(result)) == 2 + assert 'calcinfo.json' in get_result_lines(result) + assert 'job_tmpl.json' in get_result_lines(result) options = [self.arithmetic_job.uuid, 'non-existing-folder'] result = self.cli_runner.invoke(command.calcjob_inputls, options) - self.assertIsNotNone(result.exception) - self.assertIn('does not exist for the given node', result.output) + assert result.exception is not None + assert 'does not exist for the given node' in result.output def test_calcjob_outputls(self): """Test verdi calcjob outputls""" options = [] result = self.cli_runner.invoke(command.calcjob_outputls, options) - self.assertIsNotNone(result.exception, msg=result.output) + assert result.exception is not None, result.output options = [self.arithmetic_job.uuid] result = self.cli_runner.invoke(command.calcjob_outputls, options) - self.assertIsNone(result.exception, result.output) + assert result.exception is None, result.output # There is also an additional fourth file added by hand to test retrieval of binary content # see comments in test_calcjob_outputcat - self.assertEqual(len(get_result_lines(result)), 4) - self.assertIn('_scheduler-stderr.txt', get_result_lines(result)) - self.assertIn('_scheduler-stdout.txt', get_result_lines(result)) - self.assertIn('aiida.out', get_result_lines(result)) - self.assertIn('gzipped_data', get_result_lines(result)) + assert len(get_result_lines(result)) == 4 + assert '_scheduler-stderr.txt' in get_result_lines(result) + assert '_scheduler-stdout.txt' in get_result_lines(result) + assert 'aiida.out' in get_result_lines(result) + assert 'gzipped_data' in get_result_lines(result) options = [self.arithmetic_job.uuid, 'non-existing-folder'] result 
= self.cli_runner.invoke(command.calcjob_inputls, options) - self.assertIsNotNone(result.exception) - self.assertIn('does not exist for the given node', result.output) + assert result.exception is not None + assert 'does not exist for the given node' in result.output def test_calcjob_inputcat(self): """Test verdi calcjob inputcat""" options = [] result = self.cli_runner.invoke(command.calcjob_inputcat, options) - self.assertIsNotNone(result.exception, msg=result.output) + assert result.exception is not None, result.output options = [self.arithmetic_job.uuid] result = self.cli_runner.invoke(command.calcjob_inputcat, options) - self.assertIsNone(result.exception, msg=result.output) - self.assertEqual(len(get_result_lines(result)), 1) - self.assertEqual(get_result_lines(result)[0], '2 3') + assert result.exception is None, result.output + assert len(get_result_lines(result)) == 1 + assert get_result_lines(result)[0] == '2 3' options = [self.arithmetic_job.uuid, 'aiida.in'] result = self.cli_runner.invoke(command.calcjob_inputcat, options) - self.assertIsNone(result.exception, result.output) - self.assertEqual(len(get_result_lines(result)), 1) - self.assertEqual(get_result_lines(result)[0], '2 3') + assert result.exception is None, result.output + assert len(get_result_lines(result)) == 1 + assert get_result_lines(result)[0] == '2 3' # Test cat binary files self.arithmetic_job._repository.put_object_from_filelike(io.BytesIO(b'COMPRESS'), 'aiida.in') @@ -223,19 +210,19 @@ def test_calcjob_outputcat(self): options = [] result = self.cli_runner.invoke(command.calcjob_outputcat, options) - self.assertIsNotNone(result.exception) + assert result.exception is not None options = [self.arithmetic_job.uuid] result = self.cli_runner.invoke(command.calcjob_outputcat, options) - self.assertIsNone(result.exception, result.output) - self.assertEqual(len(get_result_lines(result)), 1) - self.assertEqual(get_result_lines(result)[0], '5') + assert result.exception is None, result.output 
+ assert len(get_result_lines(result)) == 1 + assert get_result_lines(result)[0] == '5' options = [self.arithmetic_job.uuid, 'aiida.out'] result = self.cli_runner.invoke(command.calcjob_outputcat, options) - self.assertIsNone(result.exception, result.output) - self.assertEqual(len(get_result_lines(result)), 1) - self.assertEqual(get_result_lines(result)[0], '5') + assert result.exception is None, result.output + assert len(get_result_lines(result)) == 1 + assert get_result_lines(result)[0] == '5' # Test cat binary files retrieved = self.arithmetic_job.outputs.retrieved @@ -256,22 +243,22 @@ def test_calcjob_cleanworkdir(self): # Specifying no filtering options and no explicit calcjobs should exit with non-zero status options = [] result = self.cli_runner.invoke(command.calcjob_cleanworkdir, options) - self.assertIsNotNone(result.exception) + assert result.exception is not None # Without the force flag it should fail options = [str(self.result_job.uuid)] result = self.cli_runner.invoke(command.calcjob_cleanworkdir, options) - self.assertIsNotNone(result.exception) + assert result.exception is not None # With force flag we should find one calcjob options = ['-f', str(self.result_job.uuid)] result = self.cli_runner.invoke(command.calcjob_cleanworkdir, options) - self.assertIsNone(result.exception, result.output) + assert result.exception is None, result.output # Do it again should fail as the calcjob has been cleaned options = ['-f', str(self.result_job.uuid)] result = self.cli_runner.invoke(command.calcjob_cleanworkdir, options) - self.assertIsNotNone(result.exception, result.output) + assert result.exception is not None, result.output # The flag should have been set assert self.result_job.outputs.remote_folder.get_extra('cleaned') is True @@ -282,29 +269,29 @@ def test_calcjob_cleanworkdir(self): options = [flag_p, '5', flag_o, '1', '-f'] result = self.cli_runner.invoke(command.calcjob_cleanworkdir, options) # This should fail - the node was just created in the test 
- self.assertIsNotNone(result.exception) + assert result.exception is not None options = [flag_p, '5', flag_o, '0', '-f'] result = self.cli_runner.invoke(command.calcjob_cleanworkdir, options) # This should pass fine - self.assertIsNone(result.exception) + assert result.exception is None self.result_job.outputs.remote_folder.delete_extra('cleaned') options = [flag_p, '0', flag_o, '0', '-f'] result = self.cli_runner.invoke(command.calcjob_cleanworkdir, options) # This should not pass - self.assertIsNotNone(result.exception) + assert result.exception is not None # Should fail because the exit code is not 999 - using the failed job for testing options = [str(self.calcs[-1].uuid), '-E', '999', '-f'] result = self.cli_runner.invoke(command.calcjob_cleanworkdir, options) - self.assertIsNotNone(result.exception) + assert result.exception is not None # Should be fine because the exit code is 100 self.calcs[-1].outputs.remote_folder.set_extra('cleaned', False) options = [str(self.calcs[-1].uuid), '-E', '100', '-f'] result = self.cli_runner.invoke(command.calcjob_cleanworkdir, options) - self.assertIsNone(result.exception) + assert result.exception is None def test_calcjob_inoutputcat_old(self): """Test most recent process class / plug-in can be successfully used to find filenames""" @@ -322,23 +309,19 @@ def test_calcjob_inoutputcat_old(self): break assert add_job # Make sure add_job does not specify options 'input_filename' and 'output_filename' - self.assertIsNone( - add_job.get_option('input_filename'), msg=f"'input_filename' should not be an option for {add_job}" - ) - self.assertIsNone( - add_job.get_option('output_filename'), msg=f"'output_filename' should not be an option for {add_job}" - ) + assert add_job.get_option('input_filename') is None, f"'input_filename' should not be an option for {add_job}" + assert add_job.get_option('output_filename') is None, f"'output_filename' should not be an option for {add_job}" # Run `verdi calcjob inputcat add_job` options = 
[add_job.uuid] result = self.cli_runner.invoke(command.calcjob_inputcat, options) - self.assertIsNone(result.exception, result.output) - self.assertEqual(len(get_result_lines(result)), 1) - self.assertEqual(get_result_lines(result)[0], '2 3') + assert result.exception is None, result.output + assert len(get_result_lines(result)) == 1 + assert get_result_lines(result)[0] == '2 3' # Run `verdi calcjob outputcat add_job` options = [add_job.uuid] result = self.cli_runner.invoke(command.calcjob_outputcat, options) - self.assertIsNone(result.exception, result.output) - self.assertEqual(len(get_result_lines(result)), 1) - self.assertEqual(get_result_lines(result)[0], '5') + assert result.exception is None, result.output + assert len(get_result_lines(result)) == 1 + assert get_result_lines(result)[0] == '5' diff --git a/tests/cmdline/commands/test_computer.py b/tests/cmdline/commands/test_computer.py index c64b72c5d2..8e32ec517e 100644 --- a/tests/cmdline/commands/test_computer.py +++ b/tests/cmdline/commands/test_computer.py @@ -13,7 +13,6 @@ import os import tempfile -from click.testing import CliRunner import pytest from aiida import orm @@ -27,7 +26,6 @@ computer_show, computer_test, ) -from aiida.storage.testbase import AiidaTestCase def generate_setup_options_dict(replace_args=None, non_interactive=True): @@ -338,13 +336,15 @@ def test_noninteractive_from_config(run_cli_command): assert isinstance(orm.Computer.objects.get(label=label), orm.Computer) -class TestVerdiComputerConfigure(AiidaTestCase): +class TestVerdiComputerConfigure: """Test the ``verdi computer configure`` command.""" - def setUp(self): - """Prepare computer builder with common properties.""" + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean, run_cli_command): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init from aiida.orm.utils.builders.computer import ComputerBuilder - self.cli_runner = CliRunner() + 
self.cli_runner = run_cli_command self.user = orm.User.objects.get_default() self.comp_builder = ComputerBuilder(label='test_comp_setup') self.comp_builder.hostname = 'localhost' @@ -360,7 +360,7 @@ def setUp(self): def test_top_help(self): """Test help option of verdi computer configure.""" - result = self.cli_runner.invoke(computer_configure, ['--help'], catch_exceptions=False) + result = self.cli_runner(computer_configure, ['--help'], catch_exceptions=False) assert 'core.ssh' in result.output assert 'core.local' in result.output @@ -387,7 +387,7 @@ def test_local_ni_empty(self): comp.store() options = ['core.local', comp.label, '--non-interactive', '--safe-interval', '0'] - result = self.cli_runner.invoke(computer_configure, options, catch_exceptions=False) + result = self.cli_runner(computer_configure, options, catch_exceptions=False) assert comp.is_user_configured(self.user), result.output self.comp_builder.label = 'test_local_ni_empty_mismatch' @@ -396,7 +396,7 @@ def test_local_ni_empty(self): comp_mismatch.store() options = ['core.local', comp_mismatch.label, '--non-interactive'] - result = self.cli_runner.invoke(computer_configure, options, catch_exceptions=False) + result = self.cli_runner(computer_configure, options, catch_exceptions=False) assert result.exception is not None assert 'core.ssh' in result.output assert 'core.local' in result.output @@ -410,8 +410,8 @@ def test_local_interactive(self): invalid = 'n' valid = '1.0' - result = self.cli_runner.invoke( - computer_configure, ['core.local', comp.label], input=f'{invalid}\n{valid}\n', catch_exceptions=False + result = self.cli_runner( + computer_configure, ['core.local', comp.label], user_input=f'{invalid}\n{valid}\n', catch_exceptions=False ) assert comp.is_user_configured(self.user), result.output @@ -448,8 +448,8 @@ def test_ssh_interactive(self): key_filename=key_filename ) - result = self.cli_runner.invoke( - computer_configure, ['core.ssh', comp.label], input=command_input, 
catch_exceptions=False + result = self.cli_runner( + computer_configure, ['core.ssh', comp.label], user_input=command_input, catch_exceptions=False ) assert comp.is_user_configured(self.user), result.output new_auth_params = comp.get_authinfo(self.user).get_auth_params() @@ -476,9 +476,8 @@ def test_local_from_config(self): handle.flush() options = ['core.local', computer.label, '--config', os.path.realpath(handle.name)] - result = self.cli_runner.invoke(computer_configure, options) + self.cli_runner(computer_configure, options) - self.assertClickResultNoException(result) assert computer.get_configuration()['safe_interval'] == interval def test_ssh_ni_empty(self): @@ -496,7 +495,7 @@ def test_ssh_ni_empty(self): comp.store() options = ['core.ssh', comp.label, '--non-interactive', '--safe-interval', '1'] - result = self.cli_runner.invoke(computer_configure, options, catch_exceptions=False) + result = self.cli_runner(computer_configure, options, catch_exceptions=False) assert comp.is_user_configured(self.user), result.output self.comp_builder.label = 'test_ssh_ni_empty_mismatch' @@ -505,7 +504,7 @@ def test_ssh_ni_empty(self): comp_mismatch.store() options = ['core.ssh', comp_mismatch.label, '--non-interactive'] - result = self.cli_runner.invoke(computer_configure, options, catch_exceptions=False) + result = self.cli_runner(computer_configure, options, catch_exceptions=False) assert result.exception is not None assert 'core.local' in result.output assert 'core.ssh' in result.output @@ -519,7 +518,7 @@ def test_ssh_ni_username(self): username = 'TEST' options = ['core.ssh', comp.label, '--non-interactive', f'--username={username}', '--safe-interval', '1'] - result = self.cli_runner.invoke(computer_configure, options, catch_exceptions=False) + result = self.cli_runner(computer_configure, options, catch_exceptions=False) auth_info = orm.AuthInfo.objects.get(dbcomputer_id=comp.id, aiidauser_id=self.user.id) assert comp.is_user_configured(self.user), result.output assert 
auth_info.get_auth_params()['username'] == username @@ -531,65 +530,55 @@ def test_show(self): comp = self.comp_builder.new() comp.store() - result = self.cli_runner.invoke(computer_configure, ['show', comp.label], catch_exceptions=False) + result = self.cli_runner(computer_configure, ['show', comp.label], catch_exceptions=False) - result = self.cli_runner.invoke(computer_configure, ['show', comp.label, '--defaults'], catch_exceptions=False) + result = self.cli_runner(computer_configure, ['show', comp.label, '--defaults'], catch_exceptions=False) assert '* username' in result.output - result = self.cli_runner.invoke( + result = self.cli_runner( computer_configure, ['show', comp.label, '--defaults', '--as-option-string'], catch_exceptions=False ) assert '--username=' in result.output config_cmd = ['core.ssh', comp.label, '--non-interactive'] config_cmd.extend(result.output.replace("'", '').split(' ')) - result_config = self.cli_runner.invoke(computer_configure, config_cmd, catch_exceptions=False) + result_config = self.cli_runner(computer_configure, config_cmd, catch_exceptions=False) assert comp.is_user_configured(self.user), result_config.output - result_cur = self.cli_runner.invoke( + result_cur = self.cli_runner( computer_configure, ['show', comp.label, '--as-option-string'], catch_exceptions=False ) assert '--username=' in result.output assert result_cur.output == result.output -class TestVerdiComputerCommands(AiidaTestCase): +class TestVerdiComputerCommands: """Testing verdi computer commands. Testing everything besides `computer setup`. 
""" - @classmethod - def setUpClass(cls, *args, **kwargs): - """Create a new computer> I create a new one because I want to configure it and I don't want to - interfere with other tests""" - super().setUpClass(*args, **kwargs) - cls.computer_name = 'comp_cli_test_computer' - cls.comp = orm.Computer( - label=cls.computer_name, + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean, aiida_localhost, run_cli_command): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init + self.computer_name = 'comp_cli_test_computer' + self.comp = orm.Computer( + label=self.computer_name, hostname='localhost', transport_type='core.local', scheduler_type='core.direct', workdir='/tmp/aiida' ) - cls.comp.set_default_mpiprocs_per_machine(1) - cls.comp.set_default_memory_per_machine(1000000) - cls.comp.set_prepend_text('text to prepend') - cls.comp.set_append_text('text to append') - cls.comp.store() - - def setUp(self): - """ - Prepare the computer and user - """ - self.user = orm.User.objects.get_default() - - # I need to configure the computer here; being 'core.local', - # there should not be any options asked here + self.comp.set_default_mpiprocs_per_machine(1) + self.comp.set_default_memory_per_machine(1000000) + self.comp.set_prepend_text('text to prepend') + self.comp.set_append_text('text to append') + self.comp.store() self.comp.configure() - + self.user = orm.User.objects.get_default() assert self.comp.is_user_configured(self.user), 'There was a problem configuring the test computer' - self.cli_runner = CliRunner() + self.cli_runner = run_cli_command def test_computer_test(self): """ @@ -598,31 +587,23 @@ def test_computer_test(self): It should work as it is a local connection """ # Testing the wrong computer will fail - result = self.cli_runner.invoke(computer_test, ['non-existent-computer']) - # An exception should arise - assert result.exception is not None + self.cli_runner(computer_test, 
['non-existent-computer'], raises=True) # Testing the right computer should pass locally - result = self.cli_runner.invoke(computer_test, ['comp_cli_test_computer']) - # No exceptions should arise - assert result.exception is None, result.output + self.cli_runner(computer_test, ['comp_cli_test_computer']) def test_computer_list(self): """ Test if 'verdi computer list' command works """ # Check the vanilla command works - result = self.cli_runner.invoke(computer_list, []) - # No exceptions should arise - assert result.exception is None, result.output + result = self.cli_runner(computer_list, []) # Something should be printed to stdout assert result.output is not None # Check all options run for opt in ['-r', '--raw', '-a', '--all']: - result = self.cli_runner.invoke(computer_list, [opt]) - # No exceptions should arise - assert result.exception is None, result.output + result = self.cli_runner(computer_list, [opt]) # Something should be printed to stdout assert result.output is not None @@ -631,17 +612,12 @@ def test_computer_show(self): Test if 'verdi computer show' command works """ # See if we can display info about the test computer. - result = self.cli_runner.invoke(computer_show, ['comp_cli_test_computer']) - - # No exceptions should arise - self.assertClickResultNoException(result) + result = self.cli_runner(computer_show, ['comp_cli_test_computer']) # Something should be printed to stdout assert result.output is not None # See if a non-existent computer will raise an error. 
- result = self.cli_runner.invoke(computer_show, 'non_existent_computer_name') - # Exceptions should arise - assert result.exception is not None + result = self.cli_runner(computer_show, 'non_existent_computer_name', raises=True) def test_computer_relabel(self): """ @@ -651,23 +627,19 @@ def test_computer_relabel(self): # See if the command complains about not getting an invalid computer options = ['not_existent_computer_label'] - result = self.cli_runner.invoke(computer_relabel, options) - assert result.exception is not None + self.cli_runner(computer_relabel, options, raises=True) # See if the command complains about not getting both labels options = ['comp_cli_test_computer'] - result = self.cli_runner.invoke(computer_relabel, options) - assert result.exception is not None + self.cli_runner(computer_relabel, options, raises=True) # The new label must be different to the old one options = ['comp_cli_test_computer', 'comp_cli_test_computer'] - result = self.cli_runner.invoke(computer_relabel, options) - assert result.exception is not None + self.cli_runner(computer_relabel, options, raises=True) # Change a computer label successully. 
options = ['comp_cli_test_computer', 'relabeled_test_computer'] - result = self.cli_runner.invoke(computer_relabel, options) - assert result.exception is None, result.output + self.cli_runner(computer_relabel, options) # Check that the label really was changed # The old label should not be available @@ -678,8 +650,7 @@ def test_computer_relabel(self): # Now change the label back options = ['relabeled_test_computer', 'comp_cli_test_computer'] - result = self.cli_runner.invoke(computer_relabel, options) - assert result.exception is None, result.output + self.cli_runner(computer_relabel, options) # Check that the label really was changed # The old label should not be available @@ -705,15 +676,11 @@ def test_computer_delete(self): # See if the command complains about not getting an invalid computer options = ['non_existent_computer_name'] - result = self.cli_runner.invoke(computer_delete, options) - # Exception should be raised - assert result.exception is not None + self.cli_runner(computer_delete, options, raises=True) # Delete a computer name successully. 
options = ['computer_for_test_delete'] - result = self.cli_runner.invoke(computer_delete, options) - # Exception should be not be raised - self.assertClickResultNoException(result) + self.cli_runner(computer_delete, options) # Check that the computer really was deleted with pytest.raises(NotExistent): orm.Computer.objects.get(label='computer_for_test_delete') diff --git a/tests/cmdline/commands/test_data.py b/tests/cmdline/commands/test_data.py index 053bdb442e..f9b4cc4871 100644 --- a/tests/cmdline/commands/test_data.py +++ b/tests/cmdline/commands/test_data.py @@ -7,7 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=no-member, too-many-lines +# pylint: disable=no-member,too-many-lines,no-self-use """Test data-related verdi commands.""" import asyncio @@ -16,9 +16,7 @@ import shutil import subprocess as sp import tempfile -import unittest -from click.testing import CliRunner import numpy as np import pytest @@ -38,7 +36,6 @@ from aiida.engine import calcfunction from aiida.orm import ArrayData, BandsData, CifData, Dict, Group, KpointsData, RemoteData, StructureData, TrajectoryData from aiida.orm.nodes.data.cif import has_pycifrw -from aiida.storage.testbase import AiidaTestCase from tests.static import STATIC_DIR @@ -51,7 +48,7 @@ class DummyVerdiDataExportable: NON_EMPTY_GROUP_ID_STR = 'non_empty_group_id' NON_EMPTY_GROUP_NAME_STR = 'non_empty_group' - @unittest.skipUnless(has_pycifrw(), 'Unable to import PyCifRW') + @pytest.mark.skipif(not has_pycifrw(), reason='Unable to import PyCifRW') def data_export_test(self, datatype, ids, supported_formats): """This method tests that the data listing works as expected with all possible flags and arguments for different datatypes.""" @@ -68,14 +65,14 @@ def data_export_test(self, datatype, ids, supported_formats): # Check that the simple 
command works as expected options = [str(ids[self.NODE_ID_STR])] - res = self.cli_runner.invoke(export_cmd, options, catch_exceptions=False) - self.assertEqual(res.exit_code, 0, 'The command did not finish correctly') + res = self.cli_runner(export_cmd, options, catch_exceptions=False) + assert res.exit_code == 0, 'The command did not finish correctly' for flag in ['-F', '--format']: for frmt in supported_formats: options = [flag, frmt, str(ids[self.NODE_ID_STR])] - res = self.cli_runner.invoke(export_cmd, options, catch_exceptions=False) - self.assertEqual(res.exit_code, 0, f'The command did not finish correctly. Output:\n{res.output}') + res = self.cli_runner(export_cmd, options, catch_exceptions=False) + assert res.exit_code == 0, f'The command did not finish correctly. Output:\n{res.output}' # Check that the output to file flags work correctly: # -o, --output @@ -85,19 +82,19 @@ def data_export_test(self, datatype, ids, supported_formats): tmpd = tempfile.mkdtemp() filepath = os.path.join(tmpd, 'output_file.txt') options = [flag, filepath, str(ids[self.NODE_ID_STR])] - res = self.cli_runner.invoke(export_cmd, options, catch_exceptions=False) - self.assertEqual(res.exit_code, 0, f'The command should finish correctly.Output:\n{res.output}') + res = self.cli_runner(export_cmd, options, catch_exceptions=False) + assert res.exit_code == 0, f'The command should finish correctly.Output:\n{res.output}' # Try to export it again. 
It should fail because the # file exists - res = self.cli_runner.invoke(export_cmd, options, catch_exceptions=False) - self.assertNotEqual(res.exit_code, 0, 'The command should fail because the file already exists') + res = self.cli_runner(export_cmd, options, catch_exceptions=False) + assert res.exit_code != 0, 'The command should fail because the file already exists' # Now we force the export of the file and it should overwrite # existing files options = [flag, filepath, '-f', str(ids[self.NODE_ID_STR])] - res = self.cli_runner.invoke(export_cmd, options, catch_exceptions=False) - self.assertEqual(res.exit_code, 0, f'The command should finish correctly.Output: {res.output}') + res = self.cli_runner(export_cmd, options, catch_exceptions=False) + assert res.exit_code == 0, f'The command should finish correctly.Output: {res.output}' finally: shutil.rmtree(tmpd) @@ -138,22 +135,20 @@ def data_listing_test(self, datatype, search_string, ids): search_string_bytes = search_string.encode('utf-8') # Check that the normal listing works as expected - res = self.cli_runner.invoke(listing_cmd, [], catch_exceptions=False) - self.assertIn(search_string_bytes, res.stdout_bytes, f'The string {search_string} was not found in the listing') + res = self.cli_runner(listing_cmd, [], catch_exceptions=False) + assert search_string_bytes in res.stdout_bytes, f'The string {search_string} was not found in the listing' # Check that the past days filter works as expected past_days_flags = ['-p', '--past-days'] for flag in past_days_flags: options = [flag, '1'] - res = self.cli_runner.invoke(listing_cmd, options, catch_exceptions=False) - self.assertIn( - search_string_bytes, res.stdout_bytes, f'The string {search_string} was not found in the listing' - ) + res = self.cli_runner(listing_cmd, options, catch_exceptions=False) + assert search_string_bytes in res.stdout_bytes, f'The string {search_string} was not found in the listing' options = [flag, '0'] - res = 
self.cli_runner.invoke(listing_cmd, options, catch_exceptions=False) - self.assertNotIn( - search_string_bytes, res.stdout_bytes, f'A not expected string {search_string} was found in the listing' + res = self.cli_runner(listing_cmd, options, catch_exceptions=False) + assert search_string_bytes not in res.stdout_bytes, ( + f'A not expected string {search_string} was found in the listing' ) # Check that the group filter works as expected @@ -163,43 +158,35 @@ def data_listing_test(self, datatype, search_string, ids): # Non empty group for non_empty in [self.NON_EMPTY_GROUP_NAME_STR, str(ids[self.NON_EMPTY_GROUP_ID_STR])]: options = [flag, non_empty] - res = self.cli_runner.invoke(listing_cmd, options, catch_exceptions=False) - self.assertIn(search_string_bytes, res.stdout_bytes, 'The string {} was not found in the listing') + res = self.cli_runner(listing_cmd, options, catch_exceptions=False) + assert search_string_bytes in res.stdout_bytes, 'The string {} was not found in the listing' # Empty group for empty in [self.EMPTY_GROUP_NAME_STR, str(ids[self.EMPTY_GROUP_ID_STR])]: options = [flag, empty] - res = self.cli_runner.invoke(listing_cmd, options, catch_exceptions=False) - self.assertNotIn( - search_string_bytes, res.stdout_bytes, 'A not expected string {} was found in the listing' - ) + res = self.cli_runner(listing_cmd, options, catch_exceptions=False) + assert search_string_bytes not in res.stdout_bytes, 'A not expected string {} was found in the listing' # Group combination for non_empty in [self.NON_EMPTY_GROUP_NAME_STR, str(ids[self.NON_EMPTY_GROUP_ID_STR])]: for empty in [self.EMPTY_GROUP_NAME_STR, str(ids[self.EMPTY_GROUP_ID_STR])]: options = [flag, non_empty, empty] - res = self.cli_runner.invoke(listing_cmd, options, catch_exceptions=False) - self.assertIn(search_string_bytes, res.stdout_bytes, 'The string {} was not found in the listing') + res = self.cli_runner(listing_cmd, options, catch_exceptions=False) + assert search_string_bytes in 
res.stdout_bytes, 'The string {} was not found in the listing' # Check raw flag raw_flags = ['-r', '--raw'] for flag in raw_flags: options = [flag] - res = self.cli_runner.invoke(listing_cmd, options, catch_exceptions=False) + res = self.cli_runner(listing_cmd, options, catch_exceptions=False) for header in headers_mapping[datatype]: - self.assertNotIn(header.encode('utf-8'), res.stdout_bytes) + assert header.encode('utf-8') not in res.stdout_bytes -class TestVerdiData(AiidaTestCase): +@pytest.mark.usefixtures('aiida_profile_clean') +class TestVerdiData: """Testing reachability of the verdi data subcommands.""" - @classmethod - def setUpClass(cls): # pylint: disable=arguments-differ - super().setUpClass() - - def setUp(self): - pass - def test_reachable(self): """Testing reachability of the following commands: verdi data array @@ -216,37 +203,46 @@ def test_reachable(self): ] for sub_cmd in subcommands: output = sp.check_output(['verdi', 'data', sub_cmd, '--help']) - self.assertIn(b'Usage:', output, f'Sub-command verdi data {sub_cmd} --help failed.') + assert b'Usage:' in output, f'Sub-command verdi data {sub_cmd} --help failed.' 
-class TestVerdiDataArray(AiidaTestCase): +class TestVerdiDataArray: """Testing verdi data array.""" - @classmethod - def setUpClass(cls): # pylint: disable=arguments-differ - super().setUpClass() - - def setUp(self): + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean, run_cli_command): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init self.arr = ArrayData() self.arr.set_array('test_array', np.array([0, 1, 3])) self.arr.store() - - self.cli_runner = CliRunner() + self.cli_runner = run_cli_command def test_arrayshowhelp(self): output = sp.check_output(['verdi', 'data', 'core.array', 'show', '--help']) - self.assertIn(b'Usage:', output, 'Sub-command verdi data array show --help failed.') + assert b'Usage:' in output, 'Sub-command verdi data array show --help failed.' def test_arrayshow(self): options = [str(self.arr.id)] - res = self.cli_runner.invoke(cmd_array.array_show, options, catch_exceptions=False) - self.assertEqual(res.exit_code, 0, 'The command did not finish correctly') + res = self.cli_runner(cmd_array.array_show, options, catch_exceptions=False) + assert res.exit_code == 0, 'The command did not finish correctly' @pytest.mark.requires_rmq -class TestVerdiDataBands(AiidaTestCase, DummyVerdiDataListable): +class TestVerdiDataBands(DummyVerdiDataListable): """Testing verdi data bands.""" + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean, run_cli_command): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + self.ids = self.create_structure_bands() + self.cli_runner = run_cli_command + yield + self.loop.close() + @staticmethod def create_structure_bands(): """Create bands structure object.""" @@ -311,30 +307,13 @@ def connect_structure_bands(strct): # pylint: disable=unused-argument 
DummyVerdiDataListable.EMPTY_GROUP_ID_STR: g_e.id } - @classmethod - def setUpClass(cls): # pylint: disable=arguments-differ - super().setUpClass() - - # create a new event loop since the privious one is closed by other test case - cls.loop = asyncio.new_event_loop() - asyncio.set_event_loop(cls.loop) - cls.ids = cls.create_structure_bands() - - @classmethod - def tearDownClass(cls): # pylint: disable=arguments-differ - cls.loop.close() - super().tearDownClass() - - def setUp(self): - self.cli_runner = CliRunner() - def test_bandsshowhelp(self): output = sp.check_output(['verdi', 'data', 'core.bands', 'show', '--help']) - self.assertIn(b'Usage:', output, 'Sub-command verdi data bands show --help failed.') + assert b'Usage:' in output, 'Sub-command verdi data bands show --help failed.' def test_bandlistshelp(self): output = sp.check_output(['verdi', 'data', 'core.bands', 'list', '--help']) - self.assertIn(b'Usage:', output, 'Sub-command verdi data bands show --help failed.') + assert b'Usage:' in output, 'Sub-command verdi data bands show --help failed.' 
def test_bandslist(self): self.data_listing_test(BandsData, 'FeO', self.ids) @@ -342,19 +321,19 @@ def test_bandslist(self): def test_bandslist_with_elements(self): options = ['-e', 'Fe'] - res = self.cli_runner.invoke(cmd_bands.bands_list, options, catch_exceptions=False) - self.assertIn(b'FeO', res.stdout_bytes, 'The string "FeO" was not found in the listing') - self.assertNotIn(b'<>', res.stdout_bytes, 'The string "<>" should not in the listing') + res = self.cli_runner(cmd_bands.bands_list, options, catch_exceptions=False) + assert b'FeO' in res.stdout_bytes, 'The string "FeO" was not found in the listing' + assert b'<>' not in res.stdout_bytes, 'The string "<>" should not in the listing' def test_bandexporthelp(self): output = sp.check_output(['verdi', 'data', 'core.bands', 'export', '--help']) - self.assertIn(b'Usage:', output, 'Sub-command verdi data bands export --help failed.') + assert b'Usage:' in output, 'Sub-command verdi data bands export --help failed.' def test_bandsexport(self): options = [str(self.ids[DummyVerdiDataListable.NODE_ID_STR])] - res = self.cli_runner.invoke(cmd_bands.bands_export, options, catch_exceptions=False) - self.assertEqual(res.exit_code, 0, 'The command did not finish correctly') - self.assertIn(b'[1.0, 3.0]', res.stdout_bytes, 'The string [1.0, 3.0] was not found in the bands export') + res = self.cli_runner(cmd_bands.bands_export, options, catch_exceptions=False) + assert res.exit_code == 0, 'The command did not finish correctly' + assert b'[1.0, 3.0]' in res.stdout_bytes, 'The string [1.0, 3.0] was not found in the bands export' def test_bandsexport_single_kp(self): """ @@ -370,116 +349,110 @@ def test_bandsexport_single_kp(self): # matplotlib options = [str(bands.id), '--format', 'mpl_singlefile'] - res = self.cli_runner.invoke(cmd_bands.bands_export, options, catch_exceptions=False) - self.assertIn(b'p.scatter', res.stdout_bytes, 'The string p.scatter was not found in the bands mpl export') + res = 
self.cli_runner(cmd_bands.bands_export, options, catch_exceptions=False) + assert b'p.scatter' in res.stdout_bytes, 'The string p.scatter was not found in the bands mpl export' # gnuplot - with self.cli_runner.isolated_filesystem(): + from click.testing import CliRunner + with CliRunner().isolated_filesystem(): options = [str(bands.id), '--format', 'gnuplot', '-o', 'bands.gnu'] - self.cli_runner.invoke(cmd_bands.bands_export, options, catch_exceptions=False) + self.cli_runner(cmd_bands.bands_export, options, catch_exceptions=False) with open('bands.gnu', 'r', encoding='utf8') as gnu_file: res = gnu_file.read() - self.assertIn('vectors nohead', res, 'The string "vectors nohead" was not found in the gnuplot script') + assert 'vectors nohead' in res, 'The string "vectors nohead" was not found in the gnuplot script' -class TestVerdiDataDict(AiidaTestCase): +class TestVerdiDataDict: """Testing verdi data dict.""" - @classmethod - def setUpClass(cls): # pylint: disable=arguments-differ - super().setUpClass() - - def setUp(self): + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean, run_cli_command): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init self.dct = Dict() self.dct.set_dict({'a': 1, 'b': 2}) self.dct.store() - - self.cli_runner = CliRunner() + self.cli_runner = run_cli_command def test_dictshowhelp(self): output = sp.check_output(['verdi', 'data', 'core.dict', 'show', '--help']) - self.assertIn(b'Usage:', output, 'Sub-command verdi data dict show --help failed.') + assert b'Usage:' in output, 'Sub-command verdi data dict show --help failed.' 
def test_dictshow(self): """Test verdi data dict show.""" options = [str(self.dct.id)] - res = self.cli_runner.invoke(cmd_dict.dictionary_show, options, catch_exceptions=False) - self.assertEqual(res.exit_code, 0, 'The command verdi data dict show did not finish correctly') - self.assertIn( - b'"a": 1', res.stdout_bytes, 'The string "a": 1 was not found in the output' + res = self.cli_runner(cmd_dict.dictionary_show, options, catch_exceptions=False) + assert res.exit_code == 0, 'The command verdi data dict show did not finish correctly' + assert b'"a": 1' in res.stdout_bytes, 'The string "a": 1 was not found in the output' \ ' of verdi data dict show' - ) -class TestVerdiDataRemote(AiidaTestCase): +class TestVerdiDataRemote: """Testing verdi data remote.""" - @classmethod - def setUpClass(cls): # pylint: disable=arguments-differ - super().setUpClass() - user = orm.User.objects.get_default() - orm.AuthInfo(cls.computer, user).store() - - def setUp(self): - comp = self.computer + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean, aiida_localhost, run_cli_command, tmp_path): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init + tmp_path.joinpath('file.txt').write_text('test string', encoding='utf8') self.rmt = RemoteData() - path = tempfile.mkdtemp() - self.rmt.set_remote_path(path) - with open(os.path.join(path, 'file.txt'), 'w', encoding='utf8') as fhandle: - fhandle.write('test string') - self.rmt.computer = comp + self.rmt.set_remote_path(str(tmp_path.absolute())) + self.rmt.computer = aiida_localhost self.rmt.store() - - self.cli_runner = CliRunner() + self.cli_runner = run_cli_command def test_remoteshowhelp(self): output = sp.check_output(['verdi', 'data', 'core.remote', 'show', '--help']) - self.assertIn(b'Usage:', output, 'Sub-command verdi data remote show --help failed.') + assert b'Usage:' in output, 'Sub-command verdi data remote show --help failed.' 
def test_remoteshow(self): """Test verdi data remote show.""" options = [str(self.rmt.id)] - res = self.cli_runner.invoke(cmd_remote.remote_show, options, catch_exceptions=False) - self.assertEqual(res.exit_code, 0, 'The command verdi data remote show did not finish correctly') - self.assertIn( - b'Remote computer name:', res.stdout_bytes, 'The string "Remote computer name:" was not found in the' - ' output of verdi data remote show' + res = self.cli_runner(cmd_remote.remote_show, options, catch_exceptions=False) + assert res.exit_code == 0, 'The command verdi data remote show did not finish correctly' + assert b'Remote computer name:' in res.stdout_bytes, ( + 'The string "Remote computer name:" was not found in the output of verdi data remote show' ) - self.assertIn( - b'Remote folder full path:', res.stdout_bytes, 'The string "Remote folder full path:" was not found in the' - ' output of verdi data remote show' + assert b'Remote folder full path:' in res.stdout_bytes, ( + 'The string "Remote folder full path:" was not found in the output of verdi data remote show' ) def test_remotelshelp(self): output = sp.check_output(['verdi', 'data', 'core.remote', 'ls', '--help']) - self.assertIn(b'Usage:', output, 'Sub-command verdi data remote ls --help failed.') + assert b'Usage:' in output, 'Sub-command verdi data remote ls --help failed.' 
def test_remotels(self): options = ['--long', str(self.rmt.id)] - res = self.cli_runner.invoke(cmd_remote.remote_ls, options, catch_exceptions=False) - self.assertEqual(res.exit_code, 0, 'The command verdi data remote ls did not finish correctly') - self.assertIn( - b'file.txt', res.stdout_bytes, 'The file "file.txt" was not found in the output' + res = self.cli_runner(cmd_remote.remote_ls, options, catch_exceptions=False) + assert res.exit_code == 0, 'The command verdi data remote ls did not finish correctly' + assert b'file.txt' in res.stdout_bytes, 'The file "file.txt" was not found in the output' \ ' of verdi data remote ls' - ) def test_remotecathelp(self): output = sp.check_output(['verdi', 'data', 'core.remote', 'cat', '--help']) - self.assertIn(b'Usage:', output, 'Sub-command verdi data remote cat --help failed.') + assert b'Usage:' in output, 'Sub-command verdi data remote cat --help failed.' def test_remotecat(self): options = [str(self.rmt.id), 'file.txt'] - res = self.cli_runner.invoke(cmd_remote.remote_cat, options, catch_exceptions=False) - self.assertEqual(res.exit_code, 0, 'The command verdi data remote cat did not finish correctly') - self.assertIn( - b'test string', res.stdout_bytes, 'The string "test string" was not found in the output' + res = self.cli_runner(cmd_remote.remote_cat, options, catch_exceptions=False) + assert res.exit_code == 0, 'The command verdi data remote cat did not finish correctly' + assert b'test string' in res.stdout_bytes, 'The string "test string" was not found in the output' \ ' of verdi data remote cat file.txt' - ) -class TestVerdiDataTrajectory(AiidaTestCase, DummyVerdiDataListable, DummyVerdiDataExportable): +class TestVerdiDataTrajectory(DummyVerdiDataListable, DummyVerdiDataExportable): """Test verdi data trajectory.""" + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean, aiida_localhost, run_cli_command): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: 
disable=attribute-defined-outside-init + self.comp = aiida_localhost + self.this_folder = os.path.dirname(__file__) + self.this_file = os.path.basename(__file__) + self.ids = self.create_trajectory_data() + self.cli_runner = run_cli_command + @staticmethod def create_trajectory_data(): """Create TrajectoryData object with two arrays.""" @@ -541,46 +514,36 @@ def create_trajectory_data(): DummyVerdiDataListable.EMPTY_GROUP_ID_STR: g_e.id } - @classmethod - def setUpClass(cls): # pylint: disable=arguments-differ - super().setUpClass() - orm.Computer( - label='comp', - hostname='localhost', - transport_type='core.local', - scheduler_type='core.direct', - workdir='/tmp/aiida' - ).store() - cls.ids = cls.create_trajectory_data() - - def setUp(self): - self.comp = self.computer - self.runner = CliRunner() - self.this_folder = os.path.dirname(__file__) - self.this_file = os.path.basename(__file__) - - self.cli_runner = CliRunner() - def test_showhelp(self): - res = self.runner.invoke(cmd_trajectory.trajectory_show, ['--help']) - self.assertIn( - b'Usage:', res.stdout_bytes, 'The string "Usage: " was not found in the output' + res = self.cli_runner(cmd_trajectory.trajectory_show, ['--help']) + assert b'Usage:' in res.stdout_bytes, 'The string "Usage: " was not found in the output' \ ' of verdi data trajecotry show --help' - ) def test_list(self): self.data_listing_test(TrajectoryData, str(self.ids[DummyVerdiDataListable.NODE_ID_STR]), self.ids) - @unittest.skipUnless(has_pycifrw(), 'Unable to import PyCifRW') + @pytest.mark.skipif(not has_pycifrw(), reason='Unable to import PyCifRW') def test_export(self): new_supported_formats = list(cmd_trajectory.EXPORT_FORMATS) self.data_export_test(TrajectoryData, self.ids, new_supported_formats) -class TestVerdiDataStructure(AiidaTestCase, DummyVerdiDataListable, DummyVerdiDataExportable): +class TestVerdiDataStructure(DummyVerdiDataListable, DummyVerdiDataExportable): """Test verdi data structure.""" from 
aiida.orm.nodes.data.structure import has_ase + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean, aiida_localhost, run_cli_command): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init + self.comp = aiida_localhost + self.this_folder = os.path.dirname(__file__) + self.this_file = os.path.basename(__file__) + self.ids = self.create_structure_data() + for group_label in ['xyz structure group', 'ase structure group']: + Group(label=group_label).store() + self.cli_runner = run_cli_command + @staticmethod def create_structure_data(): """Create StructureData object.""" @@ -626,49 +589,20 @@ def create_structure_data(): DummyVerdiDataListable.EMPTY_GROUP_ID_STR: g_e.id } - @classmethod - def setUpClass(cls): # pylint: disable=arguments-differ - super().setUpClass() - orm.Computer( - label='comp', - hostname='localhost', - transport_type='core.local', - scheduler_type='core.direct', - workdir='/tmp/aiida' - ).store() - cls.ids = cls.create_structure_data() - - for group_label in ['xyz structure group', 'ase structure group']: - Group(label=group_label).store() - - def setUp(self): - self.comp = self.computer - self.runner = CliRunner() - self.this_folder = os.path.dirname(__file__) - self.this_file = os.path.basename(__file__) - - self.cli_runner = CliRunner() - def test_importhelp(self): - res = self.runner.invoke(cmd_structure.structure_import, ['--help']) - self.assertIn( - b'Usage:', res.stdout_bytes, 'The string "Usage: " was not found in the output' + res = self.cli_runner(cmd_structure.structure_import, ['--help']) + assert b'Usage:' in res.stdout_bytes, 'The string "Usage: " was not found in the output' \ ' of verdi data structure import --help' - ) def test_importhelp_ase(self): - res = self.runner.invoke(cmd_structure.import_ase, ['--help']) - self.assertIn( - b'Usage:', res.stdout_bytes, 'The string "Usage: " was not found in the output' + res = 
self.cli_runner(cmd_structure.import_ase, ['--help']) + assert b'Usage:' in res.stdout_bytes, 'The string "Usage: " was not found in the output' \ ' of verdi data structure import ase --help' - ) def test_importhelp_aiida_xyz(self): - res = self.runner.invoke(cmd_structure.import_aiida_xyz, ['--help']) - self.assertIn( - b'Usage:', res.stdout_bytes, 'The string "Usage: " was not found in the output' + res = self.cli_runner(cmd_structure.import_aiida_xyz, ['--help']) + assert b'Usage:' in res.stdout_bytes, 'The string "Usage: " was not found in the output' \ ' of verdi data structure import aiida-xyz --help' - ) def test_import_aiida_xyz(self): """Test import xyz file.""" @@ -692,16 +626,12 @@ def test_import_aiida_xyz(self): '1', '1', ] - res = self.cli_runner.invoke(cmd_structure.import_aiida_xyz, options, catch_exceptions=False) - self.assertIn( - b'Successfully imported', res.stdout_bytes, - 'The string "Successfully imported" was not found in the output' + res = self.cli_runner(cmd_structure.import_aiida_xyz, options, catch_exceptions=False) + assert b'Successfully imported' in res.stdout_bytes, \ + 'The string "Successfully imported" was not found in the output' \ ' of verdi data structure import.' - ) - self.assertIn( - b'PK', res.stdout_bytes, 'The string "PK" was not found in the output' + assert b'PK' in res.stdout_bytes, 'The string "PK" was not found in the output' \ ' of verdi data structure import.' 
- ) def test_import_aiida_xyz_2(self): """Test import xyz file.""" @@ -718,16 +648,12 @@ def test_import_aiida_xyz_2(self): fhandle.name, '-n' # dry-run ] - res = self.cli_runner.invoke(cmd_structure.import_aiida_xyz, options, catch_exceptions=False) - self.assertIn( - b'Successfully imported', res.stdout_bytes, - 'The string "Successfully imported" was not found in the output' + res = self.cli_runner(cmd_structure.import_aiida_xyz, options, catch_exceptions=False) + assert b'Successfully imported' in res.stdout_bytes, \ + 'The string "Successfully imported" was not found in the output' \ ' of verdi data structure import.' - ) - self.assertIn( - b'dry-run', res.stdout_bytes, 'The string "dry-run" was not found in the output' + assert b'dry-run' in res.stdout_bytes, 'The string "dry-run" was not found in the output' \ ' of verdi data structure import.' - ) def test_import_aiida_xyz_w_group_label(self): """Test import xyz file including setting label and group.""" @@ -756,22 +682,17 @@ def test_import_aiida_xyz_w_group_label(self): '--group', group_label, ] - res = self.cli_runner.invoke(cmd_structure.import_aiida_xyz, options, catch_exceptions=False) - self.assertIn( - b'Successfully imported', res.stdout_bytes, - 'The string "Successfully imported" was not found in the output' + res = self.cli_runner(cmd_structure.import_aiida_xyz, options, catch_exceptions=False) + assert b'Successfully imported' in res.stdout_bytes, \ + 'The string "Successfully imported" was not found in the output' \ ' of verdi data structure import.' - ) - self.assertIn( - b'PK', res.stdout_bytes, 'The string "PK" was not found in the output' + assert b'PK' in res.stdout_bytes, 'The string "PK" was not found in the output' \ ' of verdi data structure import.' 
- ) - res = self.cli_runner.invoke(cmd_group.group_show, [group_label]) - self.assertClickResultNoException(res) + res = self.cli_runner(cmd_group.group_show, [group_label]) for grpline in [group_label, 'StructureData']: - self.assertIn(grpline, res.output) + assert grpline in res.output - @unittest.skipIf(not has_ase(), 'Unable to import ase') + @pytest.mark.skipif(not has_ase(), reason='Unable to import ase') def test_import_ase(self): """Trying to import an xsf file through ase.""" xsfcontent = '''CRYSTAL @@ -790,18 +711,14 @@ def test_import_ase(self): options = [ fhandle.name, ] - res = self.cli_runner.invoke(cmd_structure.import_ase, options, catch_exceptions=False) - self.assertIn( - b'Successfully imported', res.stdout_bytes, - 'The string "Successfully imported" was not found in the output' + res = self.cli_runner(cmd_structure.import_ase, options, catch_exceptions=False) + assert b'Successfully imported' in res.stdout_bytes, \ + 'The string "Successfully imported" was not found in the output' \ ' of verdi data structure import.' - ) - self.assertIn( - b'PK', res.stdout_bytes, 'The string "PK" was not found in the output' + assert b'PK' in res.stdout_bytes, 'The string "PK" was not found in the output' \ ' of verdi data structure import.' 
- ) - @unittest.skipIf(not has_ase(), 'Unable to import ase') + @pytest.mark.skipif(not has_ase(), reason='Unable to import ase') def test_import_ase_w_group_label(self): """Trying to import an xsf file through ase including setting label and group.""" xsfcontent = '''CRYSTAL @@ -819,20 +736,15 @@ def test_import_ase_w_group_label(self): fhandle.write(xsfcontent) fhandle.flush() options = [fhandle.name, '--label', 'another  structure', '--group', group_label] - res = self.cli_runner.invoke(cmd_structure.import_ase, options, catch_exceptions=False) - self.assertIn( - b'Successfully imported', res.stdout_bytes, - 'The string "Successfully imported" was not found in the output' + res = self.cli_runner(cmd_structure.import_ase, options, catch_exceptions=False) + assert b'Successfully imported' in res.stdout_bytes, \ + 'The string "Successfully imported" was not found in the output' \ ' of verdi data structure import.' - ) - self.assertIn( - b'PK', res.stdout_bytes, 'The string "PK" was not found in the output' + assert b'PK' in res.stdout_bytes, 'The string "PK" was not found in the output' \ ' of verdi data structure import.' 
- ) - res = self.cli_runner.invoke(cmd_group.group_show, [group_label]) - self.assertClickResultNoException(res) + res = self.cli_runner(cmd_group.group_show, [group_label]) for grpline in [group_label, 'StructureData']: - self.assertIn(grpline, res.output) + assert grpline in res.output def test_list(self): self.data_listing_test(StructureData, 'BaO3Ti', self.ids) @@ -841,8 +753,8 @@ def test_export(self): self.data_export_test(StructureData, self.ids, cmd_structure.EXPORT_FORMATS) -@unittest.skipUnless(has_pycifrw(), 'Unable to import PyCifRW') -class TestVerdiDataCif(AiidaTestCase, DummyVerdiDataListable, DummyVerdiDataExportable): +@pytest.mark.skipif(not has_pycifrw(), reason='Unable to import PyCifRW') +class TestVerdiDataCif(DummyVerdiDataListable, DummyVerdiDataExportable): """Test verdi data cif.""" valid_sample_cif_str = ''' data_test @@ -864,12 +776,21 @@ class TestVerdiDataCif(AiidaTestCase, DummyVerdiDataListable, DummyVerdiDataExpo H 0.75 0.75 0.75 0 ''' - @classmethod - def create_cif_data(cls): + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean, aiida_localhost, run_cli_command): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init + self.comp = aiida_localhost + self.this_folder = os.path.dirname(__file__) + self.this_file = os.path.basename(__file__) + self.ids = self.create_cif_data() + self.cli_runner = run_cli_command + + def create_cif_data(self): """Create CifData object.""" with tempfile.NamedTemporaryFile(mode='w+') as fhandle: filename = fhandle.name - fhandle.write(cls.valid_sample_cif_str) + fhandle.write(self.valid_sample_cif_str) fhandle.flush() a_cif = CifData(file=filename, source={'version': '1234', 'db_name': 'COD', 'id': '0000001'}) a_cif.store() @@ -881,7 +802,7 @@ def create_cif_data(cls): g_e = Group(label='empty_group') g_e.store() - cls.cif = a_cif + self.cif = a_cif # pylint: disable=attribute-defined-outside-init return { 
DummyVerdiDataListable.NODE_ID_STR: a_cif.id, @@ -889,29 +810,6 @@ def create_cif_data(cls): DummyVerdiDataListable.EMPTY_GROUP_ID_STR: g_e.id } - @classmethod - def setUpClass(cls): # pylint: disable=arguments-differ - """Setup class to test CifData.""" - super().setUpClass() - orm.Computer( - label='comp', - hostname='localhost', - transport_type='core.local', - scheduler_type='core.direct', - workdir='/tmp/aiida' - ).store() - - cls.ids = cls.create_cif_data() - - def setUp(self): - super().setUp() - self.comp = self.computer - self.runner = CliRunner() - self.this_folder = os.path.dirname(__file__) - self.this_file = os.path.basename(__file__) - - self.cli_runner = CliRunner() - def test_list(self): """ This method tests that the Cif listing works as expected with all @@ -921,19 +819,15 @@ def test_list(self): def test_showhelp(self): options = ['--help'] - res = self.cli_runner.invoke(cmd_cif.cif_show, options, catch_exceptions=False) - self.assertIn( - b'Usage:', res.stdout_bytes, 'The string "Usage: " was not found in the output' + res = self.cli_runner(cmd_cif.cif_show, options, catch_exceptions=False) + assert b'Usage:' in res.stdout_bytes, 'The string "Usage: " was not found in the output' \ ' of verdi data show help' - ) def test_importhelp(self): options = ['--help'] - res = self.cli_runner.invoke(cmd_cif.cif_import, options, catch_exceptions=False) - self.assertIn( - b'Usage:', res.stdout_bytes, 'The string "Usage: " was not found in the output' + res = self.cli_runner(cmd_cif.cif_import, options, catch_exceptions=False) + assert b'Usage:' in res.stdout_bytes, 'The string "Usage: " was not found in the output' \ ' of verdi data import help' - ) def test_import(self): """Test verdi data cif import.""" @@ -941,19 +835,17 @@ def test_import(self): fhandle.write(self.valid_sample_cif_str) fhandle.flush() options = [fhandle.name] - res = self.cli_runner.invoke(cmd_cif.cif_import, options, catch_exceptions=False) - self.assertIn( - b'imported uuid', 
res.stdout_bytes, 'The string "imported uuid" was not found in the output' + res = self.cli_runner(cmd_cif.cif_import, options, catch_exceptions=False) + assert b'imported uuid' in res.stdout_bytes, 'The string "imported uuid" was not found in the output' \ ' of verdi data import.' - ) def test_content(self): """Test that `verdi data cif content` returns the content of the file.""" options = [str(self.cif.uuid)] - result = self.cli_runner.invoke(cmd_cif.cif_content, options, catch_exceptions=False) + result = self.cli_runner(cmd_cif.cif_content, options, catch_exceptions=False) for line in result.output.split('\n'): - self.assertIn(line, self.valid_sample_cif_str) + assert line in self.valid_sample_cif_str def test_export(self): """This method checks if the Cif export works as expected with all @@ -961,24 +853,20 @@ def test_export(self): self.data_export_test(CifData, self.ids, cmd_cif.EXPORT_FORMATS) -class TestVerdiDataSinglefile(AiidaTestCase, DummyVerdiDataListable, DummyVerdiDataExportable): +class TestVerdiDataSinglefile(DummyVerdiDataListable, DummyVerdiDataExportable): """Test verdi data singlefile.""" sample_str = ''' data_test ''' - @classmethod - def setUpClass(cls): # pylint: disable=arguments-differ - super().setUpClass() - - def setUp(self): - super().setUp() - self.comp = self.computer - self.runner = CliRunner() + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean, aiida_localhost, run_cli_command): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init + self.comp = aiida_localhost self.this_folder = os.path.dirname(__file__) self.this_file = os.path.basename(__file__) - - self.cli_runner = CliRunner() + self.cli_runner = run_cli_command def test_content(self): """Test that `verdi data singlefile content` returns the content of the file.""" @@ -986,44 +874,40 @@ def test_content(self): singlefile = 
orm.SinglefileData(file=io.BytesIO(content.encode('utf8'))).store() options = [str(singlefile.uuid)] - result = self.cli_runner.invoke(cmd_singlefile.singlefile_content, options, catch_exceptions=False) + result = self.cli_runner(cmd_singlefile.singlefile_content, options, catch_exceptions=False) for line in result.output.split('\n'): - self.assertIn(line, content) + assert line in content -class TestVerdiDataUpf(AiidaTestCase): +class TestVerdiDataUpf: """Testing verdi data upf.""" - @classmethod - def setUpClass(cls): # pylint: disable=arguments-differ - super().setUpClass() - - def setUp(self): + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean, run_cli_command): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init self.filepath_pseudos = os.path.join(STATIC_DIR, 'pseudos') - self.cli_runner = CliRunner() + self.cli_runner = run_cli_command def upload_family(self): options = [self.filepath_pseudos, 'test_group', 'test description'] - res = self.cli_runner.invoke(cmd_upf.upf_uploadfamily, options, catch_exceptions=False) - self.assertIn( - b'UPF files found: 4', res.stdout_bytes, 'The string "UPF files found: 4" was not found in the' + res = self.cli_runner(cmd_upf.upf_uploadfamily, options, catch_exceptions=False) + assert b'UPF files found: 4' in res.stdout_bytes, 'The string "UPF files found: 4" was not found in the' \ ' output of verdi data upf uploadfamily' - ) def test_uploadfamilyhelp(self): output = sp.check_output(['verdi', 'data', 'core.upf', 'uploadfamily', '--help']) - self.assertIn(b'Usage:', output, f'Sub-command verdi data upf uploadfamily --help failed: {output}') + assert b'Usage:' in output, f'Sub-command verdi data upf uploadfamily --help failed: {output}' def test_uploadfamily(self): self.upload_family() options = [self.filepath_pseudos, 'test_group', 'test description', '--stop-if-existing'] - with self.assertRaises(ValueError): - 
self.cli_runner.invoke(cmd_upf.upf_uploadfamily, options, catch_exceptions=False) + self.cli_runner(cmd_upf.upf_uploadfamily, options, raises=True) def test_exportfamilyhelp(self): output = sp.check_output(['verdi', 'data', 'core.upf', 'exportfamily', '--help']) - self.assertIn(b'Usage:', output, 'Sub-command verdi data upf exportfamily --help failed.') + assert b'Usage:' in output, 'Sub-command verdi data upf exportfamily --help failed.' def test_exportfamily(self): """Test verdi data upf exportfamily.""" @@ -1031,60 +915,46 @@ def test_exportfamily(self): path = tempfile.mkdtemp() options = [path, 'test_group'] - res = self.cli_runner.invoke(cmd_upf.upf_exportfamily, options, catch_exceptions=False) - self.assertClickResultNoException(res) + self.cli_runner(cmd_upf.upf_exportfamily, options, catch_exceptions=False) output = sp.check_output(['ls', path]) - self.assertIn( - b'Ba.pbesol-spn-rrkjus_psl.0.2.3-tot-pslib030.UPF', output, + assert b'Ba.pbesol-spn-rrkjus_psl.0.2.3-tot-pslib030.UPF' in output, \ f'Sub-command verdi data upf exportfamily --help failed: {output}' - ) - self.assertIn( - b'O.pbesol-n-rrkjus_psl.0.1-tested-pslib030.UPF', output, + assert b'O.pbesol-n-rrkjus_psl.0.1-tested-pslib030.UPF' in output, \ 'Sub-command verdi data upf exportfamily --help failed.' - ) - self.assertIn( - b'Ti.pbesol-spn-rrkjus_psl.0.2.3-tot-pslib030.UPF', output, + assert b'Ti.pbesol-spn-rrkjus_psl.0.2.3-tot-pslib030.UPF' in output, \ 'Sub-command verdi data upf exportfamily --help failed.' - ) - self.assertIn(b'C_pbe_v1.2.uspp.F.UPF', output, 'Sub-command verdi data upf exportfamily --help failed.') + assert b'C_pbe_v1.2.uspp.F.UPF' in output, 'Sub-command verdi data upf exportfamily --help failed.' 
def test_listfamilieshelp(self): output = sp.check_output(['verdi', 'data', 'core.upf', 'listfamilies', '--help']) - self.assertIn(b'Usage:', output, 'Sub-command verdi data upf listfamilies --help failed.') + assert b'Usage:' in output, 'Sub-command verdi data upf listfamilies --help failed.' def test_listfamilies(self): """Test verdi data upf listfamilies""" self.upload_family() options = ['-d', '-e', 'Ba'] - res = self.cli_runner.invoke(cmd_upf.upf_listfamilies, options, catch_exceptions=False) + res = self.cli_runner(cmd_upf.upf_listfamilies, options, catch_exceptions=False) - self.assertIn( - b'test_group', res.stdout_bytes, 'The string "test_group" was not found in the' + assert b'test_group' in res.stdout_bytes, 'The string "test_group" was not found in the' \ ' output of verdi data upf listfamilies: {}'.format(res.output) - ) - self.assertIn( - b'test description', res.stdout_bytes, 'The string "test_group" was not found in the' + assert b'test description' in res.stdout_bytes, 'The string "test_group" was not found in the' \ ' output of verdi data upf listfamilies' - ) options = ['-d', '-e', 'Fe'] - res = self.cli_runner.invoke(cmd_upf.upf_listfamilies, options, catch_exceptions=False) - self.assertIn( - b'No valid UPF pseudopotential', res.stdout_bytes, 'The string "No valid UPF pseudopotential" was not' - ' found in the output of verdi data upf listfamilies' + res = self.cli_runner(cmd_upf.upf_listfamilies, options, catch_exceptions=False) + assert b'No valid UPF pseudopotential' in res.stdout_bytes, ( + 'The string "No valid UPF pseudopotential" was not found in the output of verdi data upf listfamilies' ) def test_importhelp(self): output = sp.check_output(['verdi', 'data', 'core.upf', 'import', '--help']) - self.assertIn(b'Usage:', output, 'Sub-command verdi data upf listfamilies --help failed.') + assert b'Usage:' in output, 'Sub-command verdi data upf listfamilies --help failed.' 
def test_import(self): options = [os.path.join(self.filepath_pseudos, 'Ti.pbesol-spn-rrkjus_psl.0.2.3-tot-pslib030.UPF')] - res = self.cli_runner.invoke(cmd_upf.upf_import, options, catch_exceptions=False) + res = self.cli_runner(cmd_upf.upf_import, options, catch_exceptions=False) - self.assertIn( - b'Imported', res.stdout_bytes, 'The string "Imported" was not' + assert b'Imported' in res.stdout_bytes, 'The string "Imported" was not' \ ' found in the output of verdi data import: {}'.format(res.output) - ) diff --git a/tests/cmdline/commands/test_group.py b/tests/cmdline/commands/test_group.py index f92b503665..997e9d0392 100644 --- a/tests/cmdline/commands/test_group.py +++ b/tests/cmdline/commands/test_group.py @@ -8,125 +8,103 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Tests for the `verdi group` command.""" +import pytest + from aiida import orm from aiida.cmdline.commands import cmd_group from aiida.cmdline.utils.echo import ExitCode from aiida.common import exceptions -from aiida.storage.testbase import AiidaTestCase -class TestVerdiGroup(AiidaTestCase): +class TestVerdiGroup: """Tests for the `verdi group` command.""" - def setUp(self): - """Create runner object to run tests.""" - from click.testing import CliRunner - self.cli_runner = CliRunner() - + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean, run_cli_command): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init for group in ['dummygroup1', 'dummygroup2', 'dummygroup3', 'dummygroup4']: orm.Group(label=group).store() - - def tearDown(self): - """Delete all created group objects.""" - for group in orm.Group.objects.all(): - orm.Group.objects.delete(group.pk) + self.cli_runner = run_cli_command def test_help(self): """Tests help text for all group sub commands.""" options = ['--help'] # verdi group list - result = 
self.cli_runner.invoke(cmd_group.group_list, options) - self.assertIsNone(result.exception, result.output) - self.assertIn('Usage', result.output) + result = self.cli_runner(cmd_group.group_list, options) + assert 'Usage' in result.output # verdi group create - result = self.cli_runner.invoke(cmd_group.group_create, options) - self.assertIsNone(result.exception, result.output) - self.assertIn('Usage', result.output) + result = self.cli_runner(cmd_group.group_create, options) + assert 'Usage' in result.output # verdi group delete - result = self.cli_runner.invoke(cmd_group.group_delete, options) - self.assertIsNone(result.exception, result.output) - self.assertIn('Usage', result.output) + result = self.cli_runner(cmd_group.group_delete, options) + assert 'Usage' in result.output # verdi group relabel - result = self.cli_runner.invoke(cmd_group.group_relabel, options) - self.assertIsNone(result.exception, result.output) - self.assertIn('Usage', result.output) + result = self.cli_runner(cmd_group.group_relabel, options) + assert 'Usage' in result.output # verdi group description - result = self.cli_runner.invoke(cmd_group.group_description, options) - self.assertIsNone(result.exception, result.output) - self.assertIn('Usage', result.output) + result = self.cli_runner(cmd_group.group_description, options) + assert 'Usage' in result.output # verdi group addnodes - result = self.cli_runner.invoke(cmd_group.group_add_nodes, options) - self.assertIsNone(result.exception, result.output) - self.assertIn('Usage', result.output) + result = self.cli_runner(cmd_group.group_add_nodes, options) + assert 'Usage' in result.output # verdi group removenodes - result = self.cli_runner.invoke(cmd_group.group_remove_nodes, options) - self.assertIsNone(result.exception, result.output) - self.assertIn('Usage', result.output) + result = self.cli_runner(cmd_group.group_remove_nodes, options) + assert 'Usage' in result.output # verdi group show - result = 
self.cli_runner.invoke(cmd_group.group_show, options) - self.assertIsNone(result.exception, result.output) - self.assertIn('Usage', result.output) + result = self.cli_runner(cmd_group.group_show, options) + assert 'Usage' in result.output # verdi group copy - result = self.cli_runner.invoke(cmd_group.group_copy, options) - self.assertIsNone(result.exception, result.output) - self.assertIn('Usage', result.output) + result = self.cli_runner(cmd_group.group_copy, options) + assert 'Usage' in result.output def test_create(self): """Test `verdi group create` command.""" - result = self.cli_runner.invoke(cmd_group.group_create, ['dummygroup5']) - self.assertClickResultNoException(result) + result = self.cli_runner(cmd_group.group_create, ['dummygroup5']) # check if newly added group in present in list - result = self.cli_runner.invoke(cmd_group.group_list) - self.assertClickResultNoException(result) - - self.assertIn('dummygroup5', result.output) + result = self.cli_runner(cmd_group.group_list) + assert 'dummygroup5' in result.output def test_list(self): """Test `verdi group list` command.""" - result = self.cli_runner.invoke(cmd_group.group_list) - self.assertClickResultNoException(result) - + result = self.cli_runner(cmd_group.group_list) for grp in ['dummygroup1', 'dummygroup2']: - self.assertIn(grp, result.output) + assert grp in result.output def test_list_order(self): """Test `verdi group list` command with ordering options.""" orm.Group(label='agroup').store() options = [] - result = self.cli_runner.invoke(cmd_group.group_list, options) - self.assertClickResultNoException(result) + result = self.cli_runner(cmd_group.group_list, options) group_ordering = [l.split()[1] for l in result.output.split('\n')[3:] if l] - self.assertEqual(['dummygroup1', 'dummygroup2', 'dummygroup3', 'dummygroup4', 'agroup'], group_ordering) + assert ['dummygroup1', 'dummygroup2', 'dummygroup3', 'dummygroup4', 'agroup'] == group_ordering options = ['--order-by', 'label'] - result = 
self.cli_runner.invoke(cmd_group.group_list, options) - self.assertClickResultNoException(result) + result = self.cli_runner(cmd_group.group_list, options) group_ordering = [l.split()[1] for l in result.output.split('\n')[3:] if l] - self.assertEqual(['agroup', 'dummygroup1', 'dummygroup2', 'dummygroup3', 'dummygroup4'], group_ordering) + assert ['agroup', 'dummygroup1', 'dummygroup2', 'dummygroup3', 'dummygroup4'] == group_ordering options = ['--order-by', 'id', '--order-direction', 'desc'] - result = self.cli_runner.invoke(cmd_group.group_list, options) - self.assertClickResultNoException(result) + result = self.cli_runner(cmd_group.group_list, options) + group_ordering = [l.split()[1] for l in result.output.split('\n')[3:] if l] - self.assertEqual(['agroup', 'dummygroup4', 'dummygroup3', 'dummygroup2', 'dummygroup1'], group_ordering) + assert ['agroup', 'dummygroup4', 'dummygroup3', 'dummygroup2', 'dummygroup1'] == group_ordering def test_copy(self): """Test `verdi group copy` command.""" - result = self.cli_runner.invoke(cmd_group.group_copy, ['dummygroup1', 'dummygroup2']) - self.assertClickResultNoException(result) - - self.assertIn('Success', result.output) + result = self.cli_runner(cmd_group.group_copy, ['dummygroup1', 'dummygroup2']) + assert 'Success' in result.output def test_delete(self): """Test `verdi group delete` command.""" @@ -135,17 +113,14 @@ def test_delete(self): orm.Group(label='group_test_delete_03').store() # dry run - result = self.cli_runner.invoke(cmd_group.group_delete, ['--dry-run', 'group_test_delete_01']) - self.assertClickResultNoException(result) + result = self.cli_runner(cmd_group.group_delete, ['--dry-run', 'group_test_delete_01']) orm.load_group(label='group_test_delete_01') - result = self.cli_runner.invoke(cmd_group.group_delete, ['--force', 'group_test_delete_01']) - self.assertClickResultNoException(result) + result = self.cli_runner(cmd_group.group_delete, ['--force', 'group_test_delete_01']) # Verify that removed group 
is not present in list - result = self.cli_runner.invoke(cmd_group.group_list) - self.assertClickResultNoException(result) - self.assertNotIn('group_test_delete_01', result.output) + result = self.cli_runner(cmd_group.group_list) + assert 'group_test_delete_01' not in result.output node_01 = orm.CalculationNode().store() node_02 = orm.CalculationNode().store() @@ -154,11 +129,11 @@ def test_delete(self): # Add some nodes and then use `verdi group delete` to delete a group that contains nodes group = orm.load_group(label='group_test_delete_02') group.add_nodes([node_01, node_02]) - self.assertEqual(group.count(), 2) + assert group.count() == 2 - result = self.cli_runner.invoke(cmd_group.group_delete, ['--force', 'group_test_delete_02']) + result = self.cli_runner(cmd_group.group_delete, ['--force', 'group_test_delete_02']) - with self.assertRaises(exceptions.NotExistent): + with pytest.raises(exceptions.NotExistent): orm.load_group(label='group_test_delete_02') # check nodes still exist @@ -168,25 +143,22 @@ def test_delete(self): # delete the group and the nodes it contains group = orm.load_group(label='group_test_delete_03') group.add_nodes([node_01, node_02]) - result = self.cli_runner.invoke(cmd_group.group_delete, ['--force', '--delete-nodes', 'group_test_delete_03']) - self.assertClickResultNoException(result) + result = self.cli_runner(cmd_group.group_delete, ['--force', '--delete-nodes', 'group_test_delete_03']) # check group and nodes no longer exist - with self.assertRaises(exceptions.NotExistent): + with pytest.raises(exceptions.NotExistent): orm.load_group(label='group_test_delete_03') for pk in node_pks: - with self.assertRaises(exceptions.NotExistent): + with pytest.raises(exceptions.NotExistent): orm.load_node(pk) def test_show(self): """Test `verdi group show` command.""" - result = self.cli_runner.invoke(cmd_group.group_show, ['dummygroup1']) - self.assertClickResultNoException(result) - + result = self.cli_runner(cmd_group.group_show, 
['dummygroup1']) for grpline in [ 'Group label', 'dummygroup1', 'Group type_string', 'core', 'Group description', '' ]: - self.assertIn(grpline, result.output) + assert grpline in result.output def test_show_limit(self): """Test `--limit` option of the `verdi group show` command.""" @@ -196,55 +168,48 @@ def test_show_limit(self): group.add_nodes(nodes) # Default should include all nodes in the output - result = self.cli_runner.invoke(cmd_group.group_show, [label]) - self.assertClickResultNoException(result) - + result = self.cli_runner(cmd_group.group_show, [label]) for node in nodes: - self.assertIn(str(node.pk), result.output) + assert str(node.pk) in result.output # Repeat test with `limit=1`, use also the `--raw` option to only display nodes - result = self.cli_runner.invoke(cmd_group.group_show, [label, '--limit', '1', '--raw']) - self.assertClickResultNoException(result) + result = self.cli_runner(cmd_group.group_show, [label, '--limit', '1', '--raw']) # The current `verdi group show` does not support ordering so we cannot rely on that for now to test if only # one of the nodes is shown - self.assertEqual(len(result.output.strip().split('\n')), 1) - self.assertTrue(str(nodes[0].pk) in result.output or str(nodes[1].pk) in result.output) + assert len(result.output.strip().split('\n')) == 1 + assert str(nodes[0].pk) in result.output or str(nodes[1].pk) in result.output # Repeat test with `limit=1` but without the `--raw` flag as it has a different code path that is affected - result = self.cli_runner.invoke(cmd_group.group_show, [label, '--limit', '1']) - self.assertClickResultNoException(result) + result = self.cli_runner(cmd_group.group_show, [label, '--limit', '1']) # Check that one, and only one pk appears in the output - self.assertTrue(str(nodes[0].pk) in result.output or str(nodes[1].pk) in result.output) - self.assertTrue(not (str(nodes[0].pk) in result.output and str(nodes[1].pk) in result.output)) + assert str(nodes[0].pk) in result.output or 
str(nodes[1].pk) in result.output + assert not (str(nodes[0].pk) in result.output and str(nodes[1].pk) in result.output) def test_description(self): """Test `verdi group description` command.""" description = 'It is a new description' group = orm.load_group(label='dummygroup2') - self.assertNotEqual(group.description, description) + assert group.description != description # Change the description of the group - result = self.cli_runner.invoke(cmd_group.group_description, [group.label, description]) - self.assertClickResultNoException(result) - self.assertEqual(group.description, description) + result = self.cli_runner(cmd_group.group_description, [group.label, description]) + assert group.description == description # When no description argument is passed the command should just echo the current description - result = self.cli_runner.invoke(cmd_group.group_description, [group.label]) - self.assertClickResultNoException(result) - self.assertIn(description, result.output) + result = self.cli_runner(cmd_group.group_description, [group.label]) + assert description in result.output def test_relabel(self): """Test `verdi group relabel` command.""" - result = self.cli_runner.invoke(cmd_group.group_relabel, ['dummygroup4', 'relabeled_group']) - self.assertIsNone(result.exception, result.output) + result = self.cli_runner(cmd_group.group_relabel, ['dummygroup4', 'relabeled_group']) # check if group list command shows changed group name - result = self.cli_runner.invoke(cmd_group.group_list) - self.assertClickResultNoException(result) - self.assertNotIn('dummygroup4', result.output) - self.assertIn('relabeled_group', result.output) + result = self.cli_runner(cmd_group.group_list) + + assert 'dummygroup4' not in result.output + assert 'relabeled_group' in result.output def test_add_remove_nodes(self): """Test `verdi group remove-nodes` command.""" @@ -252,62 +217,66 @@ def test_add_remove_nodes(self): node_02 = orm.CalculationNode().store() node_03 = 
orm.CalculationNode().store() - result = self.cli_runner.invoke(cmd_group.group_add_nodes, ['--force', '--group=dummygroup1', node_01.uuid]) - self.assertClickResultNoException(result) + result = self.cli_runner(cmd_group.group_add_nodes, ['--force', '--group=dummygroup1', node_01.uuid]) # Check if node is added in group using group show command - result = self.cli_runner.invoke(cmd_group.group_show, ['dummygroup1']) - self.assertClickResultNoException(result) - self.assertIn('CalculationNode', result.output) - self.assertIn(str(node_01.pk), result.output) + result = self.cli_runner(cmd_group.group_show, ['dummygroup1']) + + assert 'CalculationNode' in result.output + assert str(node_01.pk) in result.output # Remove same node - result = self.cli_runner.invoke(cmd_group.group_remove_nodes, ['--force', '--group=dummygroup1', node_01.uuid]) - self.assertIsNone(result.exception, result.output) + result = self.cli_runner(cmd_group.group_remove_nodes, ['--force', '--group=dummygroup1', node_01.uuid]) # Check that the node is no longer in the group - result = self.cli_runner.invoke(cmd_group.group_show, ['-r', 'dummygroup1']) - self.assertClickResultNoException(result) - self.assertNotIn('CalculationNode', result.output) - self.assertNotIn(str(node_01.pk), result.output) + result = self.cli_runner(cmd_group.group_show, ['-r', 'dummygroup1']) + + assert 'CalculationNode' not in result.output + assert str(node_01.pk) not in result.output # Add all three nodes and then use `verdi group remove-nodes --clear` to remove them all group = orm.load_group(label='dummygroup1') group.add_nodes([node_01, node_02, node_03]) - self.assertEqual(group.count(), 3) + assert group.count() == 3 - result = self.cli_runner.invoke(cmd_group.group_remove_nodes, ['--force', '--clear', '--group=dummygroup1']) - self.assertClickResultNoException(result) - self.assertEqual(group.count(), 0) + result = self.cli_runner(cmd_group.group_remove_nodes, ['--force', '--clear', '--group=dummygroup1']) + + 
assert group.count() == 0 # Try to remove node that isn't in the group - result = self.cli_runner.invoke(cmd_group.group_remove_nodes, ['--group=dummygroup1', node_01.uuid]) - self.assertEqual(result.exit_code, ExitCode.CRITICAL) + result = self.cli_runner(cmd_group.group_remove_nodes, ['--group=dummygroup1', node_01.uuid], raises=True) + assert result.exit_code == ExitCode.CRITICAL # Try to remove no nodes nor clear the group - result = self.cli_runner.invoke(cmd_group.group_remove_nodes, ['--group=dummygroup1']) - self.assertEqual(result.exit_code, ExitCode.CRITICAL) + result = self.cli_runner(cmd_group.group_remove_nodes, ['--group=dummygroup1'], raises=True) + assert result.exit_code == ExitCode.CRITICAL # Try to remove both nodes and clear the group - result = self.cli_runner.invoke(cmd_group.group_remove_nodes, ['--group=dummygroup1', '--clear', node_01.uuid]) - self.assertEqual(result.exit_code, ExitCode.CRITICAL) + result = self.cli_runner( + cmd_group.group_remove_nodes, ['--group=dummygroup1', '--clear', node_01.uuid], raises=True + ) + assert result.exit_code == ExitCode.CRITICAL # Add a node with confirmation - result = self.cli_runner.invoke(cmd_group.group_add_nodes, ['--group=dummygroup1', node_01.uuid], input='y') - self.assertEqual(group.count(), 1) + result = self.cli_runner(cmd_group.group_add_nodes, ['--group=dummygroup1', node_01.uuid], user_input='y') + assert group.count() == 1 # Try to remove two nodes, one that isn't in the group, but abort - result = self.cli_runner.invoke( - cmd_group.group_remove_nodes, ['--group=dummygroup1', node_01.uuid, node_02.uuid], input='N' + result = self.cli_runner( + cmd_group.group_remove_nodes, ['--group=dummygroup1', node_01.uuid, node_02.uuid], + user_input='N', + raises=True ) - self.assertIn('Warning', result.output) - self.assertEqual(group.count(), 1) + assert 'Warning' in result.output + assert group.count() == 1 # Try to clear all nodes from the group, but abort - result = 
self.cli_runner.invoke(cmd_group.group_remove_nodes, ['--group=dummygroup1', '--clear'], input='N') - self.assertIn('Are you sure you want to remove ALL', result.output) - self.assertIn('Aborted', result.output) - self.assertEqual(group.count(), 1) + result = self.cli_runner( + cmd_group.group_remove_nodes, ['--group=dummygroup1', '--clear'], user_input='N', raises=True + ) + assert 'Are you sure you want to remove ALL' in result.output + assert 'Aborted' in result.output + assert group.count() == 1 def test_move_nodes(self): """Test `verdi group move-nodes` command.""" @@ -321,28 +290,31 @@ def test_move_nodes(self): group1.add_nodes([node_01, node_02]) # Moving the nodes to the same group - result = self.cli_runner.invoke( - cmd_group.group_move_nodes, ['-s', 'dummygroup1', '-t', 'dummygroup1', node_01.uuid, node_02.uuid] + result = self.cli_runner( + cmd_group.group_move_nodes, ['-s', 'dummygroup1', '-t', 'dummygroup1', node_01.uuid, node_02.uuid], + raises=True ) - self.assertIn('Source and target group are the same:', result.output) + assert 'Source and target group are the same:' in result.output # Not specifying NODES or `--all` - result = self.cli_runner.invoke(cmd_group.group_move_nodes, ['-s', 'dummygroup1', '-t', 'dummygroup2']) - self.assertIn('Neither NODES or the `-a, --all` option was specified.', result.output) + result = self.cli_runner(cmd_group.group_move_nodes, ['-s', 'dummygroup1', '-t', 'dummygroup2'], raises=True) + assert 'Neither NODES or the `-a, --all` option was specified.' 
in result.output # Moving the nodes from the empty group - result = self.cli_runner.invoke( - cmd_group.group_move_nodes, ['-s', 'dummygroup2', '-t', 'dummygroup1', node_01.uuid, node_02.uuid] + result = self.cli_runner( + cmd_group.group_move_nodes, ['-s', 'dummygroup2', '-t', 'dummygroup1', node_01.uuid, node_02.uuid], + raises=True ) - self.assertIn('None of the specified nodes are in', result.output) + assert 'None of the specified nodes are in' in result.output # Move two nodes to the second dummy group, but specify a missing uuid - result = self.cli_runner.invoke( - cmd_group.group_move_nodes, ['-s', 'dummygroup1', '-t', 'dummygroup2', node_01.uuid, node_03.uuid] + result = self.cli_runner( + cmd_group.group_move_nodes, ['-s', 'dummygroup1', '-t', 'dummygroup2', node_01.uuid, node_03.uuid], + raises=True ) - self.assertIn(f'1 nodes with PK {{{node_03.pk}}} are not in', result.output) + assert f'1 nodes with PK {{{node_03.pk}}} are not in' in result.output # Check that the node that is present is actually moved - result = self.cli_runner.invoke( + result = self.cli_runner( cmd_group.group_move_nodes, ['-f', '-s', 'dummygroup1', '-t', 'dummygroup2', node_01.uuid, node_03.uuid], ) @@ -351,28 +323,26 @@ def test_move_nodes(self): # Add the first node back to the first group, and try to move it from the second one group1.add_nodes(node_01) - result = self.cli_runner.invoke( - cmd_group.group_move_nodes, ['-s', 'dummygroup2', '-t', 'dummygroup1', node_01.uuid] + result = self.cli_runner( + cmd_group.group_move_nodes, ['-s', 'dummygroup2', '-t', 'dummygroup1', node_01.uuid], raises=True ) - self.assertIn(f'1 nodes with PK {{{node_01.pk}}} are already', result.output) + assert f'1 nodes with PK {{{node_01.pk}}} are already' in result.output # Check that it is still removed from the second group - result = self.cli_runner.invoke( + result = self.cli_runner( cmd_group.group_move_nodes, ['-f', '-s', 'dummygroup2', '-t', 'dummygroup1', node_01.uuid], ) assert node_01 not 
in group2.nodes # Force move the two nodes to the second dummy group - result = self.cli_runner.invoke( + result = self.cli_runner( cmd_group.group_move_nodes, ['-f', '-s', 'dummygroup1', '-t', 'dummygroup2', node_01.uuid, node_02.uuid] ) assert node_01 in group2.nodes assert node_02 in group2.nodes # Force move all nodes back to the first dummy group - result = self.cli_runner.invoke( - cmd_group.group_move_nodes, ['-f', '-s', 'dummygroup2', '-t', 'dummygroup1', '--all'] - ) + result = self.cli_runner(cmd_group.group_move_nodes, ['-f', '-s', 'dummygroup2', '-t', 'dummygroup1', '--all']) assert node_01 not in group2.nodes assert node_02 not in group2.nodes assert node_01 in group1.nodes @@ -392,28 +362,22 @@ def test_copy_existing_group(self): # Copy using `verdi group copy` - making sure all is successful options = [source_label, dest_label] - result = self.cli_runner.invoke(cmd_group.group_copy, options) - self.assertClickResultNoException(result) - self.assertIn( - f'Success: Nodes copied from {source_group} to {source_group.__class__.__name__}<{dest_label}>.', + result = self.cli_runner(cmd_group.group_copy, options) + assert f'Success: Nodes copied from {source_group} to {source_group.__class__.__name__}<{dest_label}>.' 
in \ result.output, result.exception - ) # Check destination group exists with source group's nodes dest_group = orm.load_group(label=dest_label) - self.assertEqual(dest_group.count(), 2) + assert dest_group.count() == 2 nodes_dest_group = {str(node.uuid) for node in dest_group.nodes} - self.assertSetEqual(nodes_source_group, nodes_dest_group) + assert nodes_source_group == nodes_dest_group # Copy again, making sure an abort error is raised, since no user input can be made and default is abort - result = self.cli_runner.invoke(cmd_group.group_copy, options) - self.assertIsNotNone(result.exception, result.output) - self.assertIn( - f'Warning: Destination {dest_group} already exists and is not empty.', result.output, result.exception - ) + result = self.cli_runner(cmd_group.group_copy, options, raises=True) + assert f'Warning: Destination {dest_group} already exists and is not empty.' in result.output, result.exception # Check destination group is unchanged dest_group = orm.load_group(label=dest_label) - self.assertEqual(dest_group.count(), 2) + assert dest_group.count() == 2 nodes_dest_group = {str(node.uuid) for node in dest_group.nodes} - self.assertSetEqual(nodes_source_group, nodes_dest_group) + assert nodes_source_group == nodes_dest_group diff --git a/tests/cmdline/commands/test_help.py b/tests/cmdline/commands/test_help.py index 83c632cd0f..c34d1fe79e 100644 --- a/tests/cmdline/commands/test_help.py +++ b/tests/cmdline/commands/test_help.py @@ -8,20 +8,19 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Tests for `verdi help`.""" -from click.testing import CliRunner import pytest from aiida.cmdline.commands import cmd_verdi -from aiida.storage.testbase import AiidaTestCase -@pytest.mark.usefixtures('config_with_profile') -class TestVerdiHelpCommand(AiidaTestCase): +class TestVerdiHelpCommand: """Tests for `verdi help`.""" - def setUp(self): - super().setUp() - 
self.cli_runner = CliRunner() + @pytest.fixture(autouse=True) + def init_profile(self, config_with_profile, run_cli_command): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init + self.cli_runner = run_cli_command def test_without_arg(self): """ @@ -30,12 +29,12 @@ def test_without_arg(self): """ # don't invoke the cmd directly to make sure ctx.parent is properly populated # as it would be when called as a cli - result_help = self.cli_runner.invoke(cmd_verdi.verdi, ['help'], catch_exceptions=False) - result_verdi = self.cli_runner.invoke(cmd_verdi.verdi, [], catch_exceptions=False) - self.assertEqual(result_help.output, result_verdi.output) + result_help = self.cli_runner(cmd_verdi.verdi, ['help'], catch_exceptions=False) + result_verdi = self.cli_runner(cmd_verdi.verdi, [], catch_exceptions=False) + assert result_help.output == result_verdi.output def test_cmd_help(self): """Ensure we get the same help for `verdi user --help` and `verdi help user`""" - result_help = self.cli_runner.invoke(cmd_verdi.verdi, ['help', 'user'], catch_exceptions=False) - result_user = self.cli_runner.invoke(cmd_verdi.verdi, ['user', '--help'], catch_exceptions=False) - self.assertEqual(result_help.output, result_user.output) + result_help = self.cli_runner(cmd_verdi.verdi, ['help', 'user'], catch_exceptions=False) + result_user = self.cli_runner(cmd_verdi.verdi, ['user', '--help'], catch_exceptions=False) + assert result_help.output == result_user.output diff --git a/tests/cmdline/commands/test_node.py b/tests/cmdline/commands/test_node.py index 28c11280c3..b09bccf46e 100644 --- a/tests/cmdline/commands/test_node.py +++ b/tests/cmdline/commands/test_node.py @@ -15,47 +15,42 @@ import pathlib import tempfile -from click.testing import CliRunner import pytest from aiida import orm from aiida.cmdline.commands import cmd_node -from aiida.storage.testbase import AiidaTestCase def get_result_lines(result): return [e for e in 
result.output.split('\n') if e] -class TestVerdiNode(AiidaTestCase): +class TestVerdiNode: """Tests for `verdi node`.""" - @classmethod - def setUpClass(cls, *args, **kwargs): - super().setUpClass(*args, **kwargs) - + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean, run_cli_command): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init,invalid-name node = orm.Data() - cls.ATTR_KEY_ONE = 'a' - cls.ATTR_VAL_ONE = '1' - cls.ATTR_KEY_TWO = 'b' - cls.ATTR_VAL_TWO = 'test' + self.ATTR_KEY_ONE = 'a' + self.ATTR_VAL_ONE = '1' + self.ATTR_KEY_TWO = 'b' + self.ATTR_VAL_TWO = 'test' - node.set_attribute_many({cls.ATTR_KEY_ONE: cls.ATTR_VAL_ONE, cls.ATTR_KEY_TWO: cls.ATTR_VAL_TWO}) + node.set_attribute_many({self.ATTR_KEY_ONE: self.ATTR_VAL_ONE, self.ATTR_KEY_TWO: self.ATTR_VAL_TWO}) - cls.EXTRA_KEY_ONE = 'x' - cls.EXTRA_VAL_ONE = '2' - cls.EXTRA_KEY_TWO = 'y' - cls.EXTRA_VAL_TWO = 'other' + self.EXTRA_KEY_ONE = 'x' + self.EXTRA_VAL_ONE = '2' + self.EXTRA_KEY_TWO = 'y' + self.EXTRA_VAL_TWO = 'other' - node.set_extra_many({cls.EXTRA_KEY_ONE: cls.EXTRA_VAL_ONE, cls.EXTRA_KEY_TWO: cls.EXTRA_VAL_TWO}) + node.set_extra_many({self.EXTRA_KEY_ONE: self.EXTRA_VAL_ONE, self.EXTRA_KEY_TWO: self.EXTRA_VAL_TWO}) node.store() - - cls.node = node - - def setUp(self): - self.cli_runner = CliRunner() + self.node = node + self.cli_runner = run_cli_command @classmethod def get_unstored_folder_node(cls): @@ -77,17 +72,15 @@ def test_node_show(self): node = orm.Data().store() node.label = 'SOMELABEL' options = [str(node.pk)] - result = self.cli_runner.invoke(cmd_node.node_show, options) - self.assertClickResultNoException(result) + result = self.cli_runner(cmd_node.node_show, options) # Let's check some content in the output. 
At least the UUID and the label should be in there - self.assertIn(node.label, result.output) - self.assertIn(node.uuid, result.output) + assert node.label in result.output + assert node.uuid in result.output # Let's now test the '--print-groups' option options.append('--print-groups') - result = self.cli_runner.invoke(cmd_node.node_show, options) - self.assertClickResultNoException(result) + result = self.cli_runner(cmd_node.node_show, options) # I don't check the list of groups - it might be in an autogroup # Let's create a group and put the node in there @@ -95,96 +88,85 @@ def test_node_show(self): group = orm.Group(group_name).store() group.add_nodes(node) - result = self.cli_runner.invoke(cmd_node.node_show, options) - self.assertClickResultNoException(result) + result = self.cli_runner(cmd_node.node_show, options) + # Now the group should be in there - self.assertIn(group_name, result.output) + assert group_name in result.output def test_node_attributes(self): """Test verdi node attributes""" options = [str(self.node.uuid)] - result = self.cli_runner.invoke(cmd_node.attributes, options) - self.assertIsNone(result.exception, result.output) - self.assertIn(self.ATTR_KEY_ONE, result.output) - self.assertIn(self.ATTR_VAL_ONE, result.output) - self.assertIn(self.ATTR_KEY_TWO, result.output) - self.assertIn(self.ATTR_VAL_TWO, result.output) + result = self.cli_runner(cmd_node.attributes, options) + assert self.ATTR_KEY_ONE in result.output + assert self.ATTR_VAL_ONE in result.output + assert self.ATTR_KEY_TWO in result.output + assert self.ATTR_VAL_TWO in result.output for flag in ['-k', '--keys']: options = [flag, self.ATTR_KEY_ONE, '--', str(self.node.uuid)] - result = self.cli_runner.invoke(cmd_node.attributes, options) - self.assertIsNone(result.exception, result.output) - self.assertIn(self.ATTR_KEY_ONE, result.output) - self.assertIn(self.ATTR_VAL_ONE, result.output) - self.assertNotIn(self.ATTR_KEY_TWO, result.output) - self.assertNotIn(self.ATTR_VAL_TWO, 
result.output) + result = self.cli_runner(cmd_node.attributes, options) + assert self.ATTR_KEY_ONE in result.output + assert self.ATTR_VAL_ONE in result.output + assert self.ATTR_KEY_TWO not in result.output + assert self.ATTR_VAL_TWO not in result.output for flag in ['-r', '--raw']: options = [flag, str(self.node.uuid)] - result = self.cli_runner.invoke(cmd_node.attributes, options) - self.assertIsNone(result.exception, result.output) + self.cli_runner(cmd_node.attributes, options) for flag in ['-f', '--format']: for fmt in ['json+date', 'yaml', 'yaml_expanded']: options = [flag, fmt, str(self.node.uuid)] - result = self.cli_runner.invoke(cmd_node.attributes, options) - self.assertIsNone(result.exception, result.output) + self.cli_runner(cmd_node.attributes, options) for flag in ['-i', '--identifier']: for fmt in ['pk', 'uuid']: options = [flag, fmt, str(self.node.uuid)] - result = self.cli_runner.invoke(cmd_node.attributes, options) - self.assertIsNone(result.exception, result.output) + self.cli_runner(cmd_node.attributes, options) def test_node_extras(self): """Test verdi node extras""" options = [str(self.node.uuid)] - result = self.cli_runner.invoke(cmd_node.extras, options) - self.assertIsNone(result.exception, result.output) - self.assertIn(self.EXTRA_KEY_ONE, result.output) - self.assertIn(self.EXTRA_VAL_ONE, result.output) - self.assertIn(self.EXTRA_KEY_TWO, result.output) - self.assertIn(self.EXTRA_VAL_TWO, result.output) + result = self.cli_runner(cmd_node.extras, options) + assert self.EXTRA_KEY_ONE in result.output + assert self.EXTRA_VAL_ONE in result.output + assert self.EXTRA_KEY_TWO in result.output + assert self.EXTRA_VAL_TWO in result.output for flag in ['-k', '--keys']: options = [flag, self.EXTRA_KEY_ONE, '--', str(self.node.uuid)] - result = self.cli_runner.invoke(cmd_node.extras, options) - self.assertIsNone(result.exception, result.output) - self.assertIn(self.EXTRA_KEY_ONE, result.output) - self.assertIn(self.EXTRA_VAL_ONE, result.output) - 
self.assertNotIn(self.EXTRA_KEY_TWO, result.output) - self.assertNotIn(self.EXTRA_VAL_TWO, result.output) + result = self.cli_runner(cmd_node.extras, options) + assert self.EXTRA_KEY_ONE in result.output + assert self.EXTRA_VAL_ONE in result.output + assert self.EXTRA_KEY_TWO not in result.output + assert self.EXTRA_VAL_TWO not in result.output for flag in ['-r', '--raw']: options = [flag, str(self.node.uuid)] - result = self.cli_runner.invoke(cmd_node.extras, options) - self.assertIsNone(result.exception, result.output) + result = self.cli_runner(cmd_node.extras, options) for flag in ['-f', '--format']: for fmt in ['json+date', 'yaml', 'yaml_expanded']: options = [flag, fmt, str(self.node.uuid)] - result = self.cli_runner.invoke(cmd_node.extras, options) - self.assertIsNone(result.exception, result.output) + self.cli_runner(cmd_node.extras, options) for flag in ['-i', '--identifier']: for fmt in ['pk', 'uuid']: options = [flag, fmt, str(self.node.uuid)] - result = self.cli_runner.invoke(cmd_node.extras, options) - self.assertIsNone(result.exception, result.output) + self.cli_runner(cmd_node.extras, options) def test_node_repo_ls(self): """Test 'verdi node repo ls' command.""" folder_node = self.get_unstored_folder_node().store() options = [str(folder_node.pk), 'some/nested/folder'] - result = self.cli_runner.invoke(cmd_node.repo_ls, options, catch_exceptions=False) - self.assertClickResultNoException(result) - self.assertIn('filename.txt', result.output) + result = self.cli_runner(cmd_node.repo_ls, options, catch_exceptions=False) + + assert 'filename.txt' in result.output options = [str(folder_node.pk), 'some/non-existing-folder'] - result = self.cli_runner.invoke(cmd_node.repo_ls, options, catch_exceptions=False) - self.assertIsNotNone(result.exception) - self.assertIn('does not exist for the given node', result.output) + result = self.cli_runner(cmd_node.repo_ls, options, catch_exceptions=False, raises=True) + assert 'does not exist for the given node' in 
result.output def test_node_repo_cat(self): """Test 'verdi node repo cat' command.""" @@ -195,7 +177,7 @@ def test_node_repo_cat(self): folder_node.store() options = [str(folder_node.pk), 'filename.txt.gz'] - result = self.cli_runner.invoke(cmd_node.repo_cat, options) + result = self.cli_runner(cmd_node.repo_cat, options) assert gzip.decompress(result.stdout_bytes) == b'COMPRESS' def test_node_repo_dump(self): @@ -205,16 +187,16 @@ def test_node_repo_dump(self): with tempfile.TemporaryDirectory() as tmp_dir: out_path = pathlib.Path(tmp_dir) / 'out_dir' options = [str(folder_node.uuid), str(out_path)] - res = self.cli_runner.invoke(cmd_node.repo_dump, options, catch_exceptions=False) - self.assertFalse(res.stdout) + res = self.cli_runner(cmd_node.repo_dump, options, catch_exceptions=False) + assert not res.stdout for file_key, content in [(self.key_file1, self.content_file1), (self.key_file2, self.content_file2)]: curr_path = out_path for key_part in file_key.split('/'): curr_path /= key_part - self.assertTrue(curr_path.exists()) + assert curr_path.exists() with curr_path.open('r') as res_file: - self.assertEqual(res_file.read(), content) + assert res_file.read() == content def test_node_repo_dump_to_nested_folder(self): """Test 'verdi node repo dump' command, with an output folder whose parent does not exist.""" @@ -223,16 +205,16 @@ def test_node_repo_dump_to_nested_folder(self): with tempfile.TemporaryDirectory() as tmp_dir: out_path = pathlib.Path(tmp_dir) / 'out_dir' / 'nested' / 'path' options = [str(folder_node.uuid), str(out_path)] - res = self.cli_runner.invoke(cmd_node.repo_dump, options, catch_exceptions=False) - self.assertFalse(res.stdout) + res = self.cli_runner(cmd_node.repo_dump, options, catch_exceptions=False) + assert not res.stdout for file_key, content in [(self.key_file1, self.content_file1), (self.key_file2, self.content_file2)]: curr_path = out_path for key_part in file_key.split('/'): curr_path /= key_part - 
self.assertTrue(curr_path.exists()) + assert curr_path.exists() with curr_path.open('r') as res_file: - self.assertEqual(res_file.read(), content) + assert res_file.read() == content def test_node_repo_existing_out_dir(self): """Test 'verdi node repo dump' command, check that an existing output directory is not overwritten.""" @@ -247,9 +229,9 @@ def test_node_repo_existing_out_dir(self): with some_file.open('w') as file_handle: file_handle.write(some_file_content) options = [str(folder_node.uuid), str(out_path)] - res = self.cli_runner.invoke(cmd_node.repo_dump, options, catch_exceptions=False) - self.assertIn('exists', res.stdout) - self.assertIn('Critical:', res.stdout) + res = self.cli_runner(cmd_node.repo_dump, options, catch_exceptions=False) + assert 'exists' in res.stdout + assert 'Critical:' in res.stdout # Make sure the directory content is still there with some_file.open('r') as file_handle: @@ -272,29 +254,26 @@ def delete_temporary_file(filepath): pass -class TestVerdiGraph(AiidaTestCase): +class TestVerdiGraph: """Tests for the ``verdi node graph`` command.""" - @classmethod - def setUpClass(cls): - super().setUpClass() + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean, run_cli_command, tmp_path): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init from aiida.orm import Data - cls.node = Data().store() + self.node = Data().store() + self.cli_runner = run_cli_command # some of the export tests write in the current directory, # make sure it is writeable and we don't pollute the current one - cls.old_cwd = os.getcwd() - cls.cwd = tempfile.mkdtemp(__name__) - os.chdir(cls.cwd) - - @classmethod - def tearDownClass(cls): - os.chdir(cls.old_cwd) - os.rmdir(cls.cwd) - - def setUp(self): - self.cli_runner = CliRunner() + self.old_cwd = os.getcwd() + self.cwd = str(tmp_path.absolute()) + os.chdir(self.cwd) + yield + os.chdir(self.old_cwd) + os.rmdir(self.cwd) def 
test_generate_graph(self): """ @@ -306,9 +285,8 @@ def test_generate_graph(self): filename = f'{root_node}.dot.pdf' options = [root_node] try: - result = self.cli_runner.invoke(cmd_node.graph_generate, options) - self.assertIsNone(result.exception, result.output) - self.assertTrue(os.path.isfile(filename)) + self.cli_runner(cmd_node.graph_generate, options) + assert os.path.isfile(filename) finally: delete_temporary_file(filename) @@ -325,9 +303,8 @@ def test_catch_bad_pk(self): options = [root_node] filename = f'{root_node}.dot.pdf' try: - result = self.cli_runner.invoke(cmd_node.graph_generate, options) - self.assertIsNotNone(result.exception) - self.assertFalse(os.path.isfile(filename)) + self.cli_runner(cmd_node.graph_generate, options, raises=True) + assert not os.path.isfile(filename) finally: delete_temporary_file(filename) @@ -337,16 +314,15 @@ def test_catch_bad_pk(self): root_node = 123456789 try: node = load_node(pk=root_node) - self.assertIsNone(node) + assert node is None except NotExistent: pass # Make sure verdi graph rejects this non-existant pk try: filename = f'{str(root_node)}.dot.pdf' options = [str(root_node)] - result = self.cli_runner.invoke(cmd_node.graph_generate, options) - self.assertIsNotNone(result.exception) - self.assertFalse(os.path.isfile(filename)) + self.cli_runner(cmd_node.graph_generate, options, raises=True) + assert not os.path.isfile(filename) finally: delete_temporary_file(filename) @@ -363,9 +339,8 @@ def test_check_recursion_flags(self): for opt in ['-a', '--ancestor-depth', '-d', '--descendant-depth']: options = [opt, None, root_node] try: - result = self.cli_runner.invoke(cmd_node.graph_generate, options) - self.assertIsNone(result.exception, result.output) - self.assertTrue(os.path.isfile(filename)) + self.cli_runner(cmd_node.graph_generate, options) + assert os.path.isfile(filename) finally: delete_temporary_file(filename) @@ -374,9 +349,8 @@ def test_check_recursion_flags(self): for value in ['0', '1']: options = 
[opt, value, root_node] try: - result = self.cli_runner.invoke(cmd_node.graph_generate, options) - self.assertIsNone(result.exception, result.output) - self.assertTrue(os.path.isfile(filename)) + self.cli_runner(cmd_node.graph_generate, options) + assert os.path.isfile(filename) finally: delete_temporary_file(filename) @@ -385,9 +359,8 @@ def test_check_recursion_flags(self): for badvalue in ['xyz', '3.14', '-5']: options = [flag, badvalue, root_node] try: - result = self.cli_runner.invoke(cmd_node.graph_generate, options) - self.assertIsNotNone(result.exception) - self.assertFalse(os.path.isfile(filename)) + self.cli_runner(cmd_node.graph_generate, options, raises=True) + assert not os.path.isfile(filename) finally: delete_temporary_file(filename) @@ -401,9 +374,8 @@ def test_check_io_flags(self): for flag in ['-i', '--process-in', '-o', '--process-out']: options = [flag, root_node] try: - result = self.cli_runner.invoke(cmd_node.graph_generate, options) - self.assertIsNone(result.exception, result.output) - self.assertTrue(os.path.isfile(filename)) + self.cli_runner(cmd_node.graph_generate, options) + assert os.path.isfile(filename) finally: delete_temporary_file(filename) @@ -423,9 +395,8 @@ def test_output_format(self): filename = f'{root_node}.dot.{fileformat}' options = [option, fileformat, root_node] try: - result = self.cli_runner.invoke(cmd_node.graph_generate, options) - self.assertIsNone(result.exception, result.output) - self.assertTrue(os.path.isfile(filename)) + self.cli_runner(cmd_node.graph_generate, options) + assert os.path.isfile(filename) finally: delete_temporary_file(filename) @@ -439,9 +410,8 @@ def test_node_id_label_format(self): for id_label_type in ['uuid', 'pk', 'label']: options = ['--identifier', id_label_type, root_node] try: - result = self.cli_runner.invoke(cmd_node.graph_generate, options) - self.assertIsNone(result.exception, result.output) - self.assertTrue(os.path.isfile(filename)) + self.cli_runner(cmd_node.graph_generate, 
options) + assert os.path.isfile(filename) finally: delete_temporary_file(filename) @@ -449,153 +419,149 @@ def test_node_id_label_format(self): COMMENT = 'Well I never...' -class TestVerdiUserCommand(AiidaTestCase): +class TestVerdiUserCommand: """Tests for the ``verdi node comment`` command.""" - def setUp(self): - self.cli_runner = CliRunner() + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean, run_cli_command): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init,invalid-name + self.cli_runner = run_cli_command self.node = orm.Data().store() def test_comment_show_simple(self): """Test simply calling the show command (without data to show).""" - result = self.cli_runner.invoke(cmd_node.comment_show, [], catch_exceptions=False) - self.assertEqual(result.output, '') - self.assertEqual(result.exit_code, 0) + result = self.cli_runner(cmd_node.comment_show, [], catch_exceptions=False) + assert result.output == '' + assert result.exit_code == 0 def test_comment_show(self): """Test showing an existing comment.""" self.node.add_comment(COMMENT) options = [str(self.node.pk)] - result = self.cli_runner.invoke(cmd_node.comment_show, options, catch_exceptions=False) - self.assertNotEqual(result.output.find(COMMENT), -1) - self.assertEqual(result.exit_code, 0) + result = self.cli_runner(cmd_node.comment_show, options, catch_exceptions=False) + assert result.output.find(COMMENT) != -1 + assert result.exit_code == 0 def test_comment_add(self): """Test adding a comment.""" options = ['-N', str(self.node.pk), '--', f'{COMMENT}'] - result = self.cli_runner.invoke(cmd_node.comment_add, options, catch_exceptions=False) - self.assertEqual(result.exit_code, 0) + result = self.cli_runner(cmd_node.comment_add, options, catch_exceptions=False) + assert result.exit_code == 0 comment = self.node.get_comments() - self.assertEqual(len(comment), 1) - self.assertEqual(comment[0].content, COMMENT) + 
assert len(comment) == 1 + assert comment[0].content == COMMENT def test_comment_remove(self): """Test removing a comment.""" comment = self.node.add_comment(COMMENT) - self.assertEqual(len(self.node.get_comments()), 1) + assert len(self.node.get_comments()) == 1 options = [str(comment.pk), '--force'] - result = self.cli_runner.invoke(cmd_node.comment_remove, options, catch_exceptions=False) - self.assertEqual(result.exit_code, 0, result.output) - self.assertEqual(len(self.node.get_comments()), 0) + result = self.cli_runner(cmd_node.comment_remove, options, catch_exceptions=False) + assert result.exit_code == 0, result.output + assert len(self.node.get_comments()) == 0 -class TestVerdiRehash(AiidaTestCase): +class TestVerdiRehash: """Tests for the ``verdi node rehash`` command.""" - @classmethod - def setUpClass(cls, *args, **kwargs): - super().setUpClass(*args, **kwargs) + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean, run_cli_command): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init,invalid-name from aiida.orm import Bool, Data, Float, Int + self.cli_runner = run_cli_command - cls.node_base = Data().store() - cls.node_bool_true = Bool(True).store() - cls.node_bool_false = Bool(False).store() - cls.node_float = Float(1.0).store() - cls.node_int = Int(1).store() - - def setUp(self): - self.cli_runner = CliRunner() + self.node_base = Data().store() + self.node_bool_true = Bool(True).store() + self.node_bool_false = Bool(False).store() + self.node_float = Float(1.0).store() + self.node_int = Int(1).store() def test_rehash_interactive_yes(self): """Passing no options and answering 'Y' to the command will rehash all 5 nodes.""" expected_node_count = 5 options = [] # no option, will ask in the prompt - result = self.cli_runner.invoke(cmd_node.rehash, options, input='y') - self.assertClickResultNoException(result) - self.assertTrue(f'{expected_node_count} nodes' in 
result.output) + result = self.cli_runner(cmd_node.rehash, options, user_input='y') + assert f'{expected_node_count} nodes' in result.output def test_rehash_interactive_no(self): """Passing no options and answering 'N' to the command will abort the command.""" options = [] # no option, will ask in the prompt - result = self.cli_runner.invoke(cmd_node.rehash, options, input='n') - self.assertIsInstance(result.exception, SystemExit) - self.assertIn('ExitCode.CRITICAL', str(result.exception)) + result = self.cli_runner(cmd_node.rehash, options, user_input='n', raises=True) + assert isinstance(result.exception, SystemExit) + assert 'ExitCode.CRITICAL' in str(result.exception) def test_rehash(self): """Passing no options to the command will rehash all 5 nodes.""" expected_node_count = 5 options = ['-f'] # force, so no questions are asked - result = self.cli_runner.invoke(cmd_node.rehash, options) - self.assertClickResultNoException(result) - self.assertTrue(f'{expected_node_count} nodes' in result.output) + result = self.cli_runner(cmd_node.rehash, options) + assert f'{expected_node_count} nodes' in result.output def test_rehash_bool(self): """Limiting the queryset by defining an entry point, in this case bool, should limit nodes to 2.""" expected_node_count = 2 options = ['-f', '-e', 'aiida.data:core.bool'] - result = self.cli_runner.invoke(cmd_node.rehash, options) - self.assertClickResultNoException(result) - self.assertTrue(f'{expected_node_count} nodes' in result.output) + result = self.cli_runner(cmd_node.rehash, options) + + assert f'{expected_node_count} nodes' in result.output def test_rehash_float(self): """Limiting the queryset by defining an entry point, in this case float, should limit nodes to 1.""" expected_node_count = 1 options = ['-f', '-e', 'aiida.data:core.float'] - result = self.cli_runner.invoke(cmd_node.rehash, options) - self.assertClickResultNoException(result) - self.assertTrue(f'{expected_node_count} nodes' in result.output) + result = 
self.cli_runner(cmd_node.rehash, options) + + assert f'{expected_node_count} nodes' in result.output def test_rehash_int(self): """Limiting the queryset by defining an entry point, in this case int, should limit nodes to 1.""" expected_node_count = 1 options = ['-f', '-e', 'aiida.data:core.int'] - result = self.cli_runner.invoke(cmd_node.rehash, options) - self.assertClickResultNoException(result) - self.assertTrue(f'{expected_node_count} nodes' in result.output) + result = self.cli_runner(cmd_node.rehash, options) + + assert f'{expected_node_count} nodes' in result.output def test_rehash_explicit_pk(self): """Limiting the queryset by defining explicit identifiers, should limit nodes to 2 in this example.""" expected_node_count = 2 options = ['-f', str(self.node_bool_true.pk), str(self.node_float.uuid)] - result = self.cli_runner.invoke(cmd_node.rehash, options) - self.assertClickResultNoException(result) - self.assertTrue(f'{expected_node_count} nodes' in result.output) + result = self.cli_runner(cmd_node.rehash, options) + + assert f'{expected_node_count} nodes' in result.output def test_rehash_explicit_pk_and_entry_point(self): """Limiting the queryset by defining explicit identifiers and entry point, should limit nodes to 1.""" expected_node_count = 1 options = ['-f', '-e', 'aiida.data:core.bool', str(self.node_bool_true.pk), str(self.node_float.uuid)] - result = self.cli_runner.invoke(cmd_node.rehash, options) - self.assertClickResultNoException(result) - self.assertTrue(f'{expected_node_count} nodes' in result.output) + result = self.cli_runner(cmd_node.rehash, options) + + assert f'{expected_node_count} nodes' in result.output def test_rehash_entry_point_no_matches(self): """Limiting the queryset by defining explicit entry point, with no nodes should exit with non-zero status.""" options = ['-f', '-e', 'aiida.data:core.structure'] - result = self.cli_runner.invoke(cmd_node.rehash, options) - self.assertIsNotNone(result.exception) + 
self.cli_runner(cmd_node.rehash, options, raises=True) def test_rehash_invalid_entry_point(self): """Passing an invalid entry point should exit with non-zero status.""" # Incorrect entry point group options = ['-f', '-e', 'data:core.structure'] - result = self.cli_runner.invoke(cmd_node.rehash, options) - self.assertIsNotNone(result.exception) + self.cli_runner(cmd_node.rehash, options, raises=True) # Non-existent entry point name options = ['-f', '-e', 'aiida.data:inexistant'] - result = self.cli_runner.invoke(cmd_node.rehash, options) - self.assertIsNotNone(result.exception) + self.cli_runner(cmd_node.rehash, options, raises=True) # Incorrect syntax, no colon to join entry point group and name options = ['-f', '-e', 'aiida.data.structure'] - result = self.cli_runner.invoke(cmd_node.rehash, options) - self.assertIsNotNone(result.exception) + self.cli_runner(cmd_node.rehash, options, raises=True) @pytest.mark.parametrize( diff --git a/tests/cmdline/commands/test_process.py b/tests/cmdline/commands/test_process.py index d3e228aca6..868cc02fe6 100644 --- a/tests/cmdline/commands/test_process.py +++ b/tests/cmdline/commands/test_process.py @@ -12,7 +12,6 @@ from concurrent.futures import Future import time -from click.testing import CliRunner import kiwipy import plumpy import pytest @@ -22,7 +21,6 @@ from aiida.common.links import LinkType from aiida.common.log import LOG_LEVEL_REPORT from aiida.orm import CalcJobNode, WorkChainNode, WorkflowNode, WorkFunctionNode -from aiida.storage.testbase import AiidaTestCase from tests.utils import processes as test_processes @@ -30,19 +28,20 @@ def get_result_lines(result): return [e for e in result.output.split('\n') if e] -class TestVerdiProcess(AiidaTestCase): +class TestVerdiProcess: """Tests for `verdi process`.""" TEST_TIMEOUT = 5. 
- @classmethod - def setUpClass(cls, *args, **kwargs): - super().setUpClass(*args, **kwargs) + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean, run_cli_command): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init from aiida.engine import ProcessState from aiida.orm.groups import Group - cls.calcs = [] - cls.process_label = 'SomeDummyWorkFunctionNode' + self.calcs = [] + self.process_label = 'SomeDummyWorkFunctionNode' # Create 6 WorkFunctionNodes and WorkChainNodes (one for each ProcessState) for state in ProcessState: @@ -55,10 +54,10 @@ def setUpClass(cls, *args, **kwargs): calc.set_exit_status(0) # Give a `process_label` to the `WorkFunctionNodes` so the `--process-label` option can be tested - calc.set_attribute('process_label', cls.process_label) + calc.set_attribute('process_label', self.process_label) calc.store() - cls.calcs.append(calc) + self.calcs.append(calc) calc = WorkChainNode() calc.set_process_state(state) @@ -72,118 +71,112 @@ def setUpClass(cls, *args, **kwargs): calc.pause() calc.store() - cls.calcs.append(calc) + self.calcs.append(calc) - cls.group = Group('some_group').store() - cls.group.add_nodes(cls.calcs[0]) - - def setUp(self): - super().setUp() - self.cli_runner = CliRunner() + self.group = Group('some_group').store() + self.group.add_nodes(self.calcs[0]) + self.cli_runner = run_cli_command def test_list_non_raw(self): """Test the list command as the user would run it (e.g. 
without -r).""" - result = self.cli_runner.invoke(cmd_process.process_list) - self.assertIsNone(result.exception, result.output) - self.assertIn('Total results:', result.output) - self.assertIn('last time an entry changed state', result.output) + result = self.cli_runner(cmd_process.process_list) + + assert 'Total results:' in result.output + assert 'last time an entry changed state' in result.output def test_list(self): """Test the list command.""" # pylint: disable=too-many-branches # Default behavior should yield all active states (CREATED, RUNNING and WAITING) so six in total - result = self.cli_runner.invoke(cmd_process.process_list, ['-r']) - self.assertIsNone(result.exception, result.output) - self.assertEqual(len(get_result_lines(result)), 6) + result = self.cli_runner(cmd_process.process_list, ['-r']) + + assert len(get_result_lines(result)) == 6 # Ordering shouldn't change the number of results, for flag in ['-O', '--order-by']: for flag_value in ['id', 'ctime']: - result = self.cli_runner.invoke(cmd_process.process_list, ['-r', flag, flag_value]) - self.assertIsNone(result.exception, result.output) - self.assertEqual(len(get_result_lines(result)), 6) + result = self.cli_runner(cmd_process.process_list, ['-r', flag, flag_value]) + + assert len(get_result_lines(result)) == 6 # but the orders should be inverse for flag in ['-D', '--order-direction']: flag_value = 'asc' - result = self.cli_runner.invoke(cmd_process.process_list, ['-r', '-O', 'id', flag, flag_value]) - self.assertIsNone(result.exception, result.output) + result = self.cli_runner(cmd_process.process_list, ['-r', '-O', 'id', flag, flag_value]) + result_num_asc = [line.split()[0] for line in get_result_lines(result)] - self.assertEqual(len(result_num_asc), 6) + assert len(result_num_asc) == 6 flag_value = 'desc' - result = self.cli_runner.invoke(cmd_process.process_list, ['-r', '-O', 'id', flag, flag_value]) - self.assertIsNone(result.exception, result.output) + result = 
self.cli_runner(cmd_process.process_list, ['-r', '-O', 'id', flag, flag_value]) + result_num_desc = [line.split()[0] for line in get_result_lines(result)] - self.assertEqual(len(result_num_desc), 6) + assert len(result_num_desc) == 6 - self.assertEqual(result_num_asc, list(reversed(result_num_desc))) + assert result_num_asc == list(reversed(result_num_desc)) # Adding the all option should return all entries regardless of process state for flag in ['-a', '--all']: - result = self.cli_runner.invoke(cmd_process.process_list, ['-r', flag]) - self.assertIsNone(result.exception, result.output) - self.assertEqual(len(get_result_lines(result)), 12) + result = self.cli_runner(cmd_process.process_list, ['-r', flag]) + + assert len(get_result_lines(result)) == 12 # Passing the limit option should limit the results for flag in ['-l', '--limit']: - result = self.cli_runner.invoke(cmd_process.process_list, ['-r', flag, '6']) - self.assertIsNone(result.exception, result.output) - self.assertEqual(len(get_result_lines(result)), 6) + result = self.cli_runner(cmd_process.process_list, ['-r', flag, '6']) + assert len(get_result_lines(result)) == 6 # Filtering for a specific process state for flag in ['-S', '--process-state']: for flag_value in ['created', 'running', 'waiting', 'killed', 'excepted', 'finished']: - result = self.cli_runner.invoke(cmd_process.process_list, ['-r', flag, flag_value]) - self.assertIsNone(result.exception, result.output) - self.assertEqual(len(get_result_lines(result)), 2) + result = self.cli_runner(cmd_process.process_list, ['-r', flag, flag_value]) + assert len(get_result_lines(result)) == 2 # Filtering for exit status should only get us one for flag in ['-E', '--exit-status']: for exit_status in ['0', '1']: - result = self.cli_runner.invoke(cmd_process.process_list, ['-r', flag, exit_status]) - self.assertIsNone(result.exception, result.output) - self.assertEqual(len(get_result_lines(result)), 1) + result = self.cli_runner(cmd_process.process_list, 
['-r', flag, exit_status]) + assert len(get_result_lines(result)) == 1 # Passing the failed flag as a shortcut for FINISHED + non-zero exit status for flag in ['-X', '--failed']: - result = self.cli_runner.invoke(cmd_process.process_list, ['-r', flag]) - self.assertIsNone(result.exception, result.output) - self.assertEqual(len(get_result_lines(result)), 1) + result = self.cli_runner(cmd_process.process_list, ['-r', flag]) + + assert len(get_result_lines(result)) == 1 # Projecting on pk should allow us to verify all the pks for flag in ['-P', '--project']: - result = self.cli_runner.invoke(cmd_process.process_list, ['-r', flag, 'pk']) - self.assertIsNone(result.exception, result.output) - self.assertEqual(len(get_result_lines(result)), 6) + result = self.cli_runner(cmd_process.process_list, ['-r', flag, 'pk']) + + assert len(get_result_lines(result)) == 6 for line in get_result_lines(result): - self.assertIn(line.strip(), [str(calc.pk) for calc in self.calcs]) + assert line.strip() in [str(calc.pk) for calc in self.calcs] # The group option should limit the query set to nodes in the group for flag in ['-G', '--group']: - result = self.cli_runner.invoke(cmd_process.process_list, ['-r', '-P', 'pk', flag, str(self.group.pk)]) - self.assertClickResultNoException(result) - self.assertEqual(len(get_result_lines(result)), 1) - self.assertEqual(get_result_lines(result)[0], str(self.calcs[0].pk)) + result = self.cli_runner(cmd_process.process_list, ['-r', '-P', 'pk', flag, str(self.group.pk)]) + + assert len(get_result_lines(result)) == 1 + assert get_result_lines(result)[0] == str(self.calcs[0].pk) # The process label should limit the query set to nodes with the given `process_label` attribute for flag in ['-L', '--process-label']: for process_label in [self.process_label, self.process_label.replace('Dummy', '%')]: - result = self.cli_runner.invoke(cmd_process.process_list, ['-r', flag, process_label]) - self.assertClickResultNoException(result) - 
self.assertEqual(len(get_result_lines(result)), 3) # Should only match the active `WorkFunctionNodes` + result = self.cli_runner(cmd_process.process_list, ['-r', flag, process_label]) + + assert len(get_result_lines(result)) == 3 # Should only match the active `WorkFunctionNodes` for line in get_result_lines(result): - self.assertIn(self.process_label, line.strip()) + assert self.process_label in line.strip() # There should be exactly one paused for flag in ['--paused']: - result = self.cli_runner.invoke(cmd_process.process_list, ['-r', flag]) - self.assertClickResultNoException(result) - self.assertEqual(len(get_result_lines(result)), 1) + result = self.cli_runner(cmd_process.process_list, ['-r', flag]) + + assert len(get_result_lines(result)) == 1 def test_process_show(self): """Test verdi process show""" @@ -210,25 +203,25 @@ def test_process_show(self): # Running without identifiers should not except and not print anything options = [] - result = self.cli_runner.invoke(cmd_process.process_show, options) - self.assertIsNone(result.exception, result.output) - self.assertEqual(len(get_result_lines(result)), 0) + result = self.cli_runner(cmd_process.process_show, options) + + assert len(get_result_lines(result)) == 0 # Giving a single identifier should print a non empty string message options = [str(workchain_one.pk)] - result = self.cli_runner.invoke(cmd_process.process_show, options) + result = self.cli_runner(cmd_process.process_show, options) lines = get_result_lines(result) - self.assertClickResultNoException(result) - self.assertTrue(len(lines) > 0) - self.assertIn('workchain_one_caller', result.output) - self.assertIn('process_label_one', lines[-2]) - self.assertIn('process_label_two', lines[-1]) + + assert len(lines) > 0 + assert 'workchain_one_caller' in result.output + assert 'process_label_one' in lines[-2] + assert 'process_label_two' in lines[-1] # Giving multiple identifiers should print a non empty string message options = [str(node.pk) for node in 
workchains] - result = self.cli_runner.invoke(cmd_process.process_show, options) - self.assertIsNone(result.exception, result.output) - self.assertTrue(len(get_result_lines(result)) > 0) + result = self.cli_runner(cmd_process.process_show, options) + + assert len(get_result_lines(result)) > 0 def test_process_report(self): """Test verdi process report""" @@ -236,21 +229,21 @@ def test_process_report(self): # Running without identifiers should not except and not print anything options = [] - result = self.cli_runner.invoke(cmd_process.process_report, options) - self.assertIsNone(result.exception, result.output) - self.assertEqual(len(get_result_lines(result)), 0) + result = self.cli_runner(cmd_process.process_report, options) + + assert len(get_result_lines(result)) == 0 # Giving a single identifier should print a non empty string message options = [str(node.pk)] - result = self.cli_runner.invoke(cmd_process.process_report, options) - self.assertIsNone(result.exception, result.output) - self.assertTrue(len(get_result_lines(result)) > 0) + result = self.cli_runner(cmd_process.process_report, options) + + assert len(get_result_lines(result)) > 0 # Giving multiple identifiers should print a non empty string message options = [str(calc.pk) for calc in [node]] - result = self.cli_runner.invoke(cmd_process.process_report, options) - self.assertIsNone(result.exception, result.output) - self.assertTrue(len(get_result_lines(result)) > 0) + result = self.cli_runner(cmd_process.process_report, options) + + assert len(get_result_lines(result)) > 0 def test_report(self): """Test the report command.""" @@ -267,33 +260,31 @@ def test_report(self): parent.logger.log(LOG_LEVEL_REPORT, 'parent_message') child.logger.log(LOG_LEVEL_REPORT, 'child_message') - result = self.cli_runner.invoke(cmd_process.process_report, [str(grandparent.pk)]) - self.assertClickResultNoException(result) - self.assertEqual(len(get_result_lines(result)), 3) + result = 
self.cli_runner(cmd_process.process_report, [str(grandparent.pk)]) - result = self.cli_runner.invoke(cmd_process.process_report, [str(parent.pk)]) - self.assertIsNone(result.exception, result.output) - self.assertEqual(len(get_result_lines(result)), 2) + assert len(get_result_lines(result)) == 3 - result = self.cli_runner.invoke(cmd_process.process_report, [str(child.pk)]) - self.assertIsNone(result.exception, result.output) - self.assertEqual(len(get_result_lines(result)), 1) + result = self.cli_runner(cmd_process.process_report, [str(parent.pk)]) + + assert len(get_result_lines(result)) == 2 + + result = self.cli_runner(cmd_process.process_report, [str(child.pk)]) + + assert len(get_result_lines(result)) == 1 # Max depth should limit nesting level for flag in ['-m', '--max-depth']: for flag_value in [1, 2]: - result = self.cli_runner.invoke( - cmd_process.process_report, [str(grandparent.pk), flag, str(flag_value)] - ) - self.assertIsNone(result.exception, result.output) - self.assertEqual(len(get_result_lines(result)), flag_value) + result = self.cli_runner(cmd_process.process_report, [str(grandparent.pk), flag, str(flag_value)]) + + assert len(get_result_lines(result)) == flag_value # Filtering for other level name such as WARNING should not have any hits and only print the no log message for flag in ['-l', '--levelname']: - result = self.cli_runner.invoke(cmd_process.process_report, [str(grandparent.pk), flag, 'WARNING']) - self.assertIsNone(result.exception, result.output) - self.assertEqual(len(get_result_lines(result)), 1, get_result_lines(result)) - self.assertEqual(get_result_lines(result)[0], 'No log messages recorded for this entry') + result = self.cli_runner(cmd_process.process_report, [str(grandparent.pk), flag, 'WARNING']) + + assert len(get_result_lines(result)) == 1, get_result_lines(result) + assert get_result_lines(result)[0] == 'No log messages recorded for this entry' @pytest.mark.usefixtures('aiida_profile_clean') @@ -336,55 +327,51 @@ def 
test_list_worker_slot_warning(run_cli_command, monkeypatch): assert any(warning_phrase in line for line in result.output_lines) -class TestVerdiProcessCallRoot(AiidaTestCase): +class TestVerdiProcessCallRoot: """Tests for `verdi process call-root`.""" - @classmethod - def setUpClass(cls, *args, **kwargs): - super().setUpClass(*args, **kwargs) - cls.node_root = WorkflowNode() - cls.node_middle = WorkflowNode() - cls.node_terminal = WorkflowNode() + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean, run_cli_command): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init + self.node_root = WorkflowNode() + self.node_middle = WorkflowNode() + self.node_terminal = WorkflowNode() - cls.node_root.store() + self.node_root.store() - cls.node_middle.add_incoming(cls.node_root, link_type=LinkType.CALL_WORK, link_label='call_middle') - cls.node_middle.store() + self.node_middle.add_incoming(self.node_root, link_type=LinkType.CALL_WORK, link_label='call_middle') + self.node_middle.store() - cls.node_terminal.add_incoming(cls.node_middle, link_type=LinkType.CALL_WORK, link_label='call_terminal') - cls.node_terminal.store() + self.node_terminal.add_incoming(self.node_middle, link_type=LinkType.CALL_WORK, link_label='call_terminal') + self.node_terminal.store() - def setUp(self): - super().setUp() - self.cli_runner = CliRunner() + self.cli_runner = run_cli_command def test_no_caller(self): """Test `verdi process call-root` when passing single process without caller.""" options = [str(self.node_root.pk)] - result = self.cli_runner.invoke(cmd_process.process_call_root, options) - self.assertClickResultNoException(result) - self.assertTrue(len(get_result_lines(result)) == 1) - self.assertIn('No callers found', get_result_lines(result)[0]) + result = self.cli_runner(cmd_process.process_call_root, options) + assert len(get_result_lines(result)) == 1 + assert 'No callers found' in 
get_result_lines(result)[0] def test_single_caller(self): """Test `verdi process call-root` when passing single process with call root.""" # Both the middle and terminal node should have the `root` node as call root. for node in [self.node_middle, self.node_terminal]: options = [str(node.pk)] - result = self.cli_runner.invoke(cmd_process.process_call_root, options) - self.assertClickResultNoException(result) - self.assertTrue(len(get_result_lines(result)) == 1) - self.assertIn(str(self.node_root.pk), get_result_lines(result)[0]) + result = self.cli_runner(cmd_process.process_call_root, options) + assert len(get_result_lines(result)) == 1 + assert str(self.node_root.pk) in get_result_lines(result)[0] def test_multiple_processes(self): """Test `verdi process call-root` when passing multiple processes.""" options = [str(self.node_root.pk), str(self.node_middle.pk), str(self.node_terminal.pk)] - result = self.cli_runner.invoke(cmd_process.process_call_root, options) - self.assertClickResultNoException(result) - self.assertTrue(len(get_result_lines(result)) == 3) - self.assertIn('No callers found', get_result_lines(result)[0]) - self.assertIn(str(self.node_root.pk), get_result_lines(result)[1]) - self.assertIn(str(self.node_root.pk), get_result_lines(result)[2]) + result = self.cli_runner(cmd_process.process_call_root, options) + assert len(get_result_lines(result)) == 3 + assert 'No callers found' in get_result_lines(result)[0] + assert str(self.node_root.pk) in get_result_lines(result)[1] + assert str(self.node_root.pk) in get_result_lines(result)[2] @pytest.mark.skip(reason='fails to complete randomly (see issue #4731)') diff --git a/tests/cmdline/commands/test_profile.py b/tests/cmdline/commands/test_profile.py index 0bbcd4ac7e..13af5ebb8f 100644 --- a/tests/cmdline/commands/test_profile.py +++ b/tests/cmdline/commands/test_profile.py @@ -8,25 +8,33 @@ # For further information please visit http://www.aiida.net # 
########################################################################### """Tests for `verdi profile`.""" - -from click.testing import CliRunner +from pgtest.pgtest import PGTest import pytest from aiida.cmdline.commands import cmd_profile, cmd_verdi -from aiida.manage import configuration, get_manager -from aiida.storage.testbase import AiidaPostgresTestCase +from aiida.manage import configuration from tests.utils.configuration import create_mock_profile -@pytest.mark.usefixtures('config_with_profile') -class TestVerdiProfileSetup(AiidaPostgresTestCase): +@pytest.fixture(scope='class') +def pg_test_cluster(): + """Create a standalone Postgres cluster, for setup tests.""" + pg_test = PGTest() + yield pg_test + pg_test.close() + + +class TestVerdiProfileSetup: """Tests for `verdi profile`.""" - def setUp(self): - """Create a CLI runner to invoke the CLI commands.""" - super().setUp() - self.cli_runner = CliRunner() - self.config = None + @pytest.fixture(autouse=True) + def init_profile(self, pg_test_cluster, empty_config, run_cli_command): # pylint: disable=redefined-outer-name,unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init + self.storage_backend_name = 'psql_dos' + self.pg_test = pg_test_cluster + self.cli_runner = run_cli_command + self.config = configuration.get_config() self.profile_list = [] def mock_profiles(self, **kwargs): @@ -35,7 +43,7 @@ def mock_profiles(self, **kwargs): Note: this cannot be done in the `setUp` or `setUpClass` methods, because the temporary configuration instance is not generated until the test function is entered, which calls the `config_with_profile` test fixture. 
""" - self.config = configuration.get_config() + # pylint: disable=attribute-defined-outside-init self.profile_list = ['mock_profile1', 'mock_profile2', 'mock_profile3', 'mock_profile4'] for profile_name in self.profile_list: @@ -50,45 +58,36 @@ def test_help(self): options = ['--help'] - result = self.cli_runner.invoke(cmd_profile.profile_list, options) - self.assertClickSuccess(result) - self.assertIn('Usage', result.output) + result = self.cli_runner(cmd_profile.profile_list, options) + assert 'Usage' in result.output - result = self.cli_runner.invoke(cmd_profile.profile_setdefault, options) - self.assertClickSuccess(result) - self.assertIn('Usage', result.output) + result = self.cli_runner(cmd_profile.profile_setdefault, options) + assert 'Usage' in result.output - result = self.cli_runner.invoke(cmd_profile.profile_delete, options) - self.assertClickSuccess(result) - self.assertIn('Usage', result.output) + result = self.cli_runner(cmd_profile.profile_delete, options) + assert 'Usage' in result.output - result = self.cli_runner.invoke(cmd_profile.profile_show, options) - self.assertClickSuccess(result) - self.assertIn('Usage', result.output) + result = self.cli_runner(cmd_profile.profile_show, options) + assert 'Usage' in result.output def test_list(self): """Test the `verdi profile list` command.""" self.mock_profiles() - result = self.cli_runner.invoke(cmd_profile.profile_list) - self.assertClickSuccess(result) - self.assertIn(f'Report: configuration folder: {self.config.dirpath}', result.output) - self.assertIn(f'* {self.profile_list[0]}', result.output) - self.assertIn(self.profile_list[1], result.output) + result = self.cli_runner(cmd_profile.profile_list) + assert f'Report: configuration folder: {self.config.dirpath}' in result.output + assert f'* {self.profile_list[0]}' in result.output + assert self.profile_list[1] in result.output def test_setdefault(self): """Test the `verdi profile setdefault` command.""" self.mock_profiles() - result = 
self.cli_runner.invoke(cmd_profile.profile_setdefault, [self.profile_list[1]]) - self.assertClickSuccess(result) + self.cli_runner(cmd_profile.profile_setdefault, [self.profile_list[1]]) + result = self.cli_runner(cmd_profile.profile_list) - result = self.cli_runner.invoke(cmd_profile.profile_list) - - self.assertClickSuccess(result) - self.assertIn(f'Report: configuration folder: {self.config.dirpath}', result.output) - self.assertIn(f'* {self.profile_list[1]}', result.output) - self.assertClickSuccess(result) + assert f'Report: configuration folder: {self.config.dirpath}' in result.output + assert f'* {self.profile_list[1]}' in result.output def test_show(self): """Test the `verdi profile show` command.""" @@ -98,29 +97,25 @@ def test_show(self): profile_name = self.profile_list[0] profile = config.get_profile(profile_name) - result = self.cli_runner.invoke(cmd_profile.profile_show, [profile_name]) - self.assertClickSuccess(result) + result = self.cli_runner(cmd_profile.profile_show, [profile_name]) for key, value in profile.dictionary.items(): if isinstance(value, str): - self.assertIn(key, result.output) - self.assertIn(value, result.output) + assert key in result.output + assert value in result.output def test_show_with_profile_option(self): """Test the `verdi profile show` command in combination with `-p/--profile.""" - get_manager().unload_profile() self.mock_profiles() profile_name_non_default = self.profile_list[1] # Specifying the non-default profile as argument should override the default - result = self.cli_runner.invoke(cmd_profile.profile_show, [profile_name_non_default]) - self.assertClickSuccess(result) - self.assertTrue(profile_name_non_default in result.output) + result = self.cli_runner(cmd_profile.profile_show, [profile_name_non_default]) + assert profile_name_non_default in result.output # Specifying `-p/--profile` should not override the argument default (which should be the default profile) - result = self.cli_runner.invoke(cmd_verdi.verdi, 
['-p', profile_name_non_default, 'profile', 'show']) - self.assertClickSuccess(result) - self.assertTrue(profile_name_non_default not in result.output) + result = self.cli_runner(cmd_verdi.verdi, ['-p', profile_name_non_default, 'profile', 'show']) + assert profile_name_non_default not in result.output def test_delete_partial(self): """Test the `verdi profile delete` command. @@ -130,38 +125,24 @@ def test_delete_partial(self): """ self.mock_profiles() - result = self.cli_runner.invoke(cmd_profile.profile_delete, ['--force', '--skip-db', self.profile_list[1]]) - self.assertClickSuccess(result) - - result = self.cli_runner.invoke(cmd_profile.profile_list) - self.assertClickSuccess(result) - self.assertNotIn(self.profile_list[1], result.output) + self.cli_runner(cmd_profile.profile_delete, ['--force', '--skip-db', self.profile_list[1]]) + result = self.cli_runner(cmd_profile.profile_list) + assert self.profile_list[1] not in result.output def test_delete(self): """Test for verdi profile delete command.""" from aiida.cmdline.commands.cmd_profile import profile_delete, profile_list - get_manager().unload_profile() - kwargs = {'database_port': self.pg_test.dsn['port']} self.mock_profiles(**kwargs) # Delete single profile - result = self.cli_runner.invoke(profile_delete, ['--force', self.profile_list[1]]) - self.assertIsNone(result.exception, result.output) - - result = self.cli_runner.invoke(profile_list) - self.assertIsNone(result.exception, result.output) - - self.assertNotIn(self.profile_list[1], result.output) - self.assertIsNone(result.exception, result.output) + self.cli_runner(profile_delete, ['--force', self.profile_list[1]]) + result = self.cli_runner(profile_list) + assert self.profile_list[1] not in result.output # Delete multiple profiles - result = self.cli_runner.invoke(profile_delete, ['--force', self.profile_list[2], self.profile_list[3]]) - self.assertIsNone(result.exception, result.output) - - result = self.cli_runner.invoke(profile_list) - 
self.assertIsNone(result.exception, result.output) - self.assertNotIn(self.profile_list[2], result.output) - self.assertNotIn(self.profile_list[3], result.output) - self.assertIsNone(result.exception, result.output) + self.cli_runner(profile_delete, ['--force', self.profile_list[2], self.profile_list[3]]) + result = self.cli_runner(profile_list) + assert self.profile_list[2] not in result.output + assert self.profile_list[3] not in result.output diff --git a/tests/cmdline/commands/test_run.py b/tests/cmdline/commands/test_run.py index dd0672b0bc..6fd599cca1 100644 --- a/tests/cmdline/commands/test_run.py +++ b/tests/cmdline/commands/test_run.py @@ -11,20 +11,20 @@ import tempfile import textwrap -from click.testing import CliRunner import pytest from aiida.cmdline.commands import cmd_run from aiida.common.log import override_log_level -from aiida.storage.testbase import AiidaTestCase -class TestVerdiRun(AiidaTestCase): +class TestVerdiRun: """Tests for `verdi run`.""" - def setUp(self): - super().setUp() - self.cli_runner = CliRunner() + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean, run_cli_command): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init + self.cli_runner = run_cli_command @pytest.mark.requires_rmq def test_run_workfunction(self): @@ -59,26 +59,26 @@ def wf(): fhandle.flush() options = [fhandle.name] - result = self.cli_runner.invoke(cmd_run.run, options) - self.assertClickResultNoException(result) + result = self.cli_runner(cmd_run.run, options) # Try to load the function calculation node from the printed pk in the output pk = int(result.output.splitlines()[-1]) node = load_node(pk) # Verify that the node has the correct function name and content - self.assertTrue(isinstance(node, WorkFunctionNode)) - self.assertEqual(node.function_name, 'wf') - self.assertEqual(node.get_function_source_code(), script_content) + assert isinstance(node, WorkFunctionNode) + 
assert node.function_name == 'wf' + assert node.get_function_source_code() == script_content -class TestAutoGroups(AiidaTestCase): +class TestAutoGroups: """Test the autogroup functionality.""" - def setUp(self): - """Setup the CLI runner to run command line commands.""" - super().setUp() - self.cli_runner = CliRunner() + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean, run_cli_command): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init + self.cli_runner = run_cli_command def test_autogroup(self): """Check if the autogroup is properly generated.""" @@ -97,8 +97,7 @@ def test_autogroup(self): fhandle.flush() options = ['--auto-group', fhandle.name] - result = self.cli_runner.invoke(cmd_run.run, options) - self.assertClickResultNoException(result) + result = self.cli_runner(cmd_run.run, options) pk = int(result.output) _ = load_node(pk) # Check if the node can be loaded @@ -106,9 +105,7 @@ def test_autogroup(self): queryb = QueryBuilder().append(Node, filters={'id': pk}, tag='node') queryb.append(AutoGroup, with_node='node', project='*') all_auto_groups = queryb.all() - self.assertEqual( - len(all_auto_groups), 1, 'There should be only one autogroup associated with the node just created' - ) + assert len(all_auto_groups) == 1, 'There should be only one autogroup associated with the node just created' def test_autogroup_custom_label(self): """Check if the autogroup is properly generated with the label specified.""" @@ -128,8 +125,7 @@ def test_autogroup_custom_label(self): fhandle.flush() options = [fhandle.name, '--auto-group', '--auto-group-label-prefix', autogroup_label] - result = self.cli_runner.invoke(cmd_run.run, options) - self.assertClickResultNoException(result) + result = self.cli_runner(cmd_run.run, options) pk = int(result.output) _ = load_node(pk) # Check if the node can be loaded @@ -137,10 +133,8 @@ def test_autogroup_custom_label(self): queryb = 
QueryBuilder().append(Node, filters={'id': pk}, tag='node') queryb.append(AutoGroup, with_node='node', project='*') all_auto_groups = queryb.all() - self.assertEqual( - len(all_auto_groups), 1, 'There should be only one autogroup associated with the node just created' - ) - self.assertEqual(all_auto_groups[0][0].label, autogroup_label) + assert len(all_auto_groups) == 1, 'There should be only one autogroup associated with the node just created' + assert all_auto_groups[0][0].label == autogroup_label def test_no_autogroup(self): """Check if the autogroup is not generated if ``verdi run`` is asked not to.""" @@ -159,8 +153,7 @@ def test_no_autogroup(self): fhandle.flush() options = [fhandle.name] # Not storing an autogroup by default - result = self.cli_runner.invoke(cmd_run.run, options) - self.assertClickResultNoException(result) + result = self.cli_runner(cmd_run.run, options) pk = int(result.output) _ = load_node(pk) # Check if the node can be loaded @@ -168,7 +161,7 @@ def test_no_autogroup(self): queryb = QueryBuilder().append(Node, filters={'id': pk}, tag='node') queryb.append(AutoGroup, with_node='node', project='*') all_auto_groups = queryb.all() - self.assertEqual(len(all_auto_groups), 0, 'There should be no autogroup generated') + assert len(all_auto_groups) == 0, 'There should be no autogroup generated' @pytest.mark.requires_rmq def test_autogroup_filter_class(self): # pylint: disable=too-many-locals @@ -270,8 +263,7 @@ def test_autogroup_filter_class(self): # pylint: disable=too-many-locals options = ['--auto-group'] + flags + ['--', fhandle.name, str(idx)] with override_log_level(): - result = self.cli_runner.invoke(cmd_run.run, options) - self.assertClickResultNoException(result) + result = self.cli_runner(cmd_run.run, options) pk1_str, pk2_str, pk3_str, pk4_str, pk5_str, pk6_str = result.output.split() pk1 = int(pk1_str) @@ -311,36 +303,24 @@ def test_autogroup_filter_class(self): # pylint: disable=too-many-locals queryb.append(AutoGroup, 
with_node='node', project='*') all_auto_groups_calcarithmetic = queryb.all() - self.assertEqual( - len(all_auto_groups_kptdata), 1 if kptdata_in_autogroup else 0, - 'Wrong number of nodes in autogroup associated with the KpointsData node ' + assert len(all_auto_groups_kptdata) == (1 if kptdata_in_autogroup else 0), \ + 'Wrong number of nodes in autogroup associated with the KpointsData node ' \ "just created with flags '{}'".format(' '.join(flags)) - ) - self.assertEqual( - len(all_auto_groups_arraydata), 1 if arraydata_in_autogroup else 0, - 'Wrong number of nodes in autogroup associated with the ArrayData node ' + assert len(all_auto_groups_arraydata) == (1 if arraydata_in_autogroup else 0), \ + 'Wrong number of nodes in autogroup associated with the ArrayData node ' \ "just created with flags '{}'".format(' '.join(flags)) - ) - self.assertEqual( - len(all_auto_groups_int), 1 if int_in_autogroup else 0, - 'Wrong number of nodes in autogroup associated with the Int node ' + assert len(all_auto_groups_int) == (1 if int_in_autogroup else 0), \ + 'Wrong number of nodes in autogroup associated with the Int node ' \ "just created with flags '{}'".format(' '.join(flags)) - ) - self.assertEqual( - len(all_auto_groups_calc), 1 if calc_in_autogroup else 0, - 'Wrong number of nodes in autogroup associated with the CalculationNode ' + assert len(all_auto_groups_calc) == (1 if calc_in_autogroup else 0), \ + 'Wrong number of nodes in autogroup associated with the CalculationNode ' \ "just created with flags '{}'".format(' '.join(flags)) - ) - self.assertEqual( - len(all_auto_groups_wf), 1 if wf_in_autogroup else 0, - 'Wrong number of nodes in autogroup associated with the WorkflowNode ' + assert len(all_auto_groups_wf) == (1 if wf_in_autogroup else 0), \ + 'Wrong number of nodes in autogroup associated with the WorkflowNode ' \ "just created with flags '{}'".format(' '.join(flags)) - ) - self.assertEqual( - len(all_auto_groups_calcarithmetic), 1 if calcarithmetic_in_autogroup 
else 0, - 'Wrong number of nodes in autogroup associated with the ArithmeticAdd CalcJobNode ' + assert len(all_auto_groups_calcarithmetic) == (1 if calcarithmetic_in_autogroup else 0), \ + 'Wrong number of nodes in autogroup associated with the ArithmeticAdd CalcJobNode ' \ "just created with flags '{}'".format(' '.join(flags)) - ) def test_autogroup_clashing_label(self): """Check if the autogroup label is properly (re)generated when it clashes with an existing group name.""" @@ -361,31 +341,27 @@ def test_autogroup_clashing_label(self): # First run options = [fhandle.name, '--auto-group', '--auto-group-label-prefix', autogroup_label] - result = self.cli_runner.invoke(cmd_run.run, options) - self.assertClickResultNoException(result) + result = self.cli_runner(cmd_run.run, options) pk = int(result.output) _ = load_node(pk) # Check if the node can be loaded queryb = QueryBuilder().append(Node, filters={'id': pk}, tag='node') queryb.append(AutoGroup, with_node='node', project='*') all_auto_groups = queryb.all() - self.assertEqual( - len(all_auto_groups), 1, 'There should be only one autogroup associated with the node just created' - ) - self.assertEqual(all_auto_groups[0][0].label, autogroup_label) + assert len(all_auto_groups) == 1, 'There should be only one autogroup associated with the node just created' + assert all_auto_groups[0][0].label == autogroup_label # A few more runs with the same label - it should not crash but append something to the group name for _ in range(10): options = [fhandle.name, '--auto-group', '--auto-group-label-prefix', autogroup_label] - result = self.cli_runner.invoke(cmd_run.run, options) - self.assertClickResultNoException(result) + result = self.cli_runner(cmd_run.run, options) pk = int(result.output) _ = load_node(pk) # Check if the node can be loaded queryb = QueryBuilder().append(Node, filters={'id': pk}, tag='node') queryb.append(AutoGroup, with_node='node', project='*') all_auto_groups = queryb.all() - self.assertEqual( - 
len(all_auto_groups), 1, 'There should be only one autogroup associated with the node just created' - ) - self.assertTrue(all_auto_groups[0][0].label.startswith(autogroup_label)) + assert len( + all_auto_groups + ) == 1, 'There should be only one autogroup associated with the node just created' + assert all_auto_groups[0][0].label.startswith(autogroup_label) diff --git a/tests/cmdline/commands/test_setup.py b/tests/cmdline/commands/test_setup.py index 76507a8fbc..3e5d900134 100644 --- a/tests/cmdline/commands/test_setup.py +++ b/tests/cmdline/commands/test_setup.py @@ -10,27 +10,34 @@ """Tests for `verdi profile`.""" import os import tempfile -import traceback -from click.testing import CliRunner +from pgtest.pgtest import PGTest import pytest from aiida import orm from aiida.cmdline.commands import cmd_setup -from aiida.manage import configuration, get_manager +from aiida.manage import configuration from aiida.manage.external.postgres import Postgres -from aiida.storage.testbase import AiidaPostgresTestCase -@pytest.mark.usefixtures('config_with_profile') -class TestVerdiSetup(AiidaPostgresTestCase): +@pytest.fixture(scope='class') +def pg_test_cluster(): + """Create a standalone Postgres cluster, for setup tests.""" + pg_test = PGTest() + yield pg_test + pg_test.close() + + +class TestVerdiSetup: """Tests for `verdi setup` and `verdi quicksetup`.""" - def setUp(self): - """Create a CLI runner to invoke the CLI commands.""" - super().setUp() - self.backend = configuration.get_profile().storage_backend - self.cli_runner = CliRunner() + @pytest.fixture(autouse=True) + def init_profile(self, pg_test_cluster, empty_config, run_cli_command): # pylint: disable=redefined-outer-name,unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init + self.storage_backend_name = 'psql_dos' + self.pg_test = pg_test_cluster + self.cli_runner = run_cli_command def test_help(self): """Check that the `--help` option is eager, is not overruled and 
will properly display the help message. @@ -38,21 +45,11 @@ def test_help(self): If this test hangs, most likely the `--help` eagerness is overruled by another option that has started the prompt cycle, which by waiting for input, will block the test from continuing. """ - self.cli_runner.invoke(cmd_setup.setup, ['--help'], catch_exceptions=False) - self.cli_runner.invoke(cmd_setup.quicksetup, ['--help'], catch_exceptions=False) + self.cli_runner(cmd_setup.setup, ['--help'], catch_exceptions=False) + self.cli_runner(cmd_setup.quicksetup, ['--help'], catch_exceptions=False) def test_quicksetup(self): """Test `verdi quicksetup`.""" - config = configuration.get_config() - get_manager().unload_profile() - profile_name = 'testing' - user_email = 'some@email.com' - user_first_name = 'John' - user_last_name = 'Smith' - user_institution = 'ECMA' - - config = configuration.get_config() - profile_name = 'testing' user_email = 'some@email.com' user_first_name = 'John' @@ -62,34 +59,31 @@ def test_quicksetup(self): options = [ '--non-interactive', '--profile', profile_name, '--email', user_email, '--first-name', user_first_name, '--last-name', user_last_name, '--institution', user_institution, '--db-port', self.pg_test.dsn['port'], - '--db-backend', self.backend + '--db-backend', self.storage_backend_name ] - result = self.cli_runner.invoke(cmd_setup.quicksetup, options) - self.assertClickResultNoException(result) - self.assertClickSuccess(result) + self.cli_runner(cmd_setup.quicksetup, options) config = configuration.get_config() - self.assertIn(profile_name, config.profile_names) + assert profile_name in config.profile_names profile = config.get_profile(profile_name) profile.default_user_email = user_email # Verify that the backend type of the created profile matches that of the profile for the current test session - self.assertEqual(self.backend, profile.storage_backend) + assert self.storage_backend_name == profile.storage_backend user = 
orm.User.objects.get(email=user_email) - self.assertEqual(user.first_name, user_first_name) - self.assertEqual(user.last_name, user_last_name) - self.assertEqual(user.institution, user_institution) + assert user.first_name == user_first_name + assert user.last_name == user_last_name + assert user.institution == user_institution # Check that the repository UUID was stored in the database backend = profile.storage_cls(profile) - self.assertEqual(backend.get_global_variable('repository|uuid'), backend.get_repository().uuid) + assert backend.get_global_variable('repository|uuid') == backend.get_repository().uuid def test_quicksetup_from_config_file(self): """Test `verdi quicksetup` from configuration file.""" - get_manager().unload_profile() with tempfile.NamedTemporaryFile('w') as handle: handle.write( f"""--- @@ -97,17 +91,15 @@ def test_quicksetup_from_config_file(self): first_name: Leopold last_name: Talirz institution: EPFL -db_backend: {self.backend} +db_backend: {self.storage_backend_name} +db_port: {self.pg_test.dsn['port']} email: 123@234.de""" ) handle.flush() - result = self.cli_runner.invoke(cmd_setup.quicksetup, ['--config', os.path.realpath(handle.name)]) - self.assertClickResultNoException(result) + self.cli_runner(cmd_setup.quicksetup, ['--config', os.path.realpath(handle.name)]) def test_quicksetup_wrong_port(self): """Test `verdi quicksetup` exits if port is wrong.""" - get_manager().unload_profile() - profile_name = 'testing' user_email = 'some@email.com' user_first_name = 'John' @@ -120,8 +112,7 @@ def test_quicksetup_wrong_port(self): self.pg_test.dsn['port'] + 100 ] - result = self.cli_runner.invoke(cmd_setup.quicksetup, options) - self.assertIsNotNone(result.exception, ''.join(traceback.format_exception(*result.exc_info))) + self.cli_runner(cmd_setup.quicksetup, options, raises=True) def test_setup(self): """Test `verdi setup` (non-interactive).""" @@ -132,7 +123,6 @@ def test_setup(self): db_pass = 'aiida_test_setup' 
postgres.create_dbuser(db_user, db_pass) postgres.create_db(db_user, db_name) - get_manager().unload_profile() profile_name = 'testing' user_email = 'some@email.com' @@ -146,27 +136,25 @@ def test_setup(self): options = [ '--non-interactive', '--email', user_email, '--first-name', user_first_name, '--last-name', user_last_name, '--institution', user_institution, '--db-name', db_name, '--db-username', db_user, '--db-password', db_pass, - '--db-port', self.pg_test.dsn['port'], '--db-backend', self.backend, '--profile', profile_name + '--db-port', self.pg_test.dsn['port'], '--db-backend', self.storage_backend_name, '--profile', profile_name ] - result = self.cli_runner.invoke(cmd_setup.setup, options) - self.assertClickResultNoException(result) - self.assertClickSuccess(result) + self.cli_runner(cmd_setup.setup, options) config = configuration.get_config() - self.assertIn(profile_name, config.profile_names) + assert profile_name in config.profile_names profile = config.get_profile(profile_name) profile.default_user_email = user_email # Verify that the backend type of the created profile matches that of the profile for the current test session - self.assertEqual(self.backend, profile.storage_backend) + assert self.storage_backend_name == profile.storage_backend user = orm.User.objects.get(email=user_email) - self.assertEqual(user.first_name, user_first_name) - self.assertEqual(user.last_name, user_last_name) - self.assertEqual(user.institution, user_institution) + assert user.first_name == user_first_name + assert user.last_name == user_last_name + assert user.institution == user_institution # Check that the repository UUID was stored in the database backend = profile.storage_cls(profile) - self.assertEqual(backend.get_global_variable('repository|uuid'), backend.get_repository().uuid) + assert backend.get_global_variable('repository|uuid') == backend.get_repository().uuid diff --git a/tests/cmdline/commands/test_user.py b/tests/cmdline/commands/test_user.py index 
f8d4160c5f..cb4b29aaeb 100644 --- a/tests/cmdline/commands/test_user.py +++ b/tests/cmdline/commands/test_user.py @@ -8,12 +8,10 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Tests for `verdi user`.""" - -from click.testing import CliRunner +import pytest from aiida import orm from aiida.cmdline.commands import cmd_user -from aiida.storage.testbase import AiidaTestCase USER_1 = { # pylint: disable=invalid-name 'email': 'testuser1@localhost', @@ -29,11 +27,14 @@ } -class TestVerdiUserCommand(AiidaTestCase): +class TestVerdiUserCommand: """Test verdi user.""" - def setUp(self): - super().setUp() + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean, run_cli_command): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init + self.cli_runner = run_cli_command created, user = orm.User.objects.get_or_create(email=USER_1['email']) for key, value in USER_1.items(): @@ -41,14 +42,13 @@ def setUp(self): setattr(user, key, value) if created: orm.User(**USER_1).store() - self.cli_runner = CliRunner() def test_user_list(self): """Test `verdi user list`.""" from aiida.cmdline.commands.cmd_user import user_list as list_user - result = self.cli_runner.invoke(list_user, [], catch_exceptions=False) - self.assertTrue(USER_1['email'] in result.output) + result = self.cli_runner(list_user, [], catch_exceptions=False) + assert USER_1['email'] in result.output def test_user_create(self): """Create a new user with `verdi user configure`.""" @@ -63,14 +63,14 @@ def test_user_create(self): USER_2['institution'], ] - result = self.cli_runner.invoke(cmd_user.user_configure, cli_options, catch_exceptions=False) - self.assertTrue(USER_2['email'] in result.output) - self.assertTrue('created' in result.output) - self.assertTrue('updated' not in result.output) + result = self.cli_runner(cmd_user.user_configure, 
cli_options, catch_exceptions=False) + assert USER_2['email'] in result.output + assert 'created' in result.output + assert 'updated' not in result.output user_obj = orm.User.objects.get(email=USER_2['email']) for key, val in USER_2.items(): - self.assertEqual(val, getattr(user_obj, key)) + assert val == getattr(user_obj, key) def test_user_update(self): """Reconfigure an existing user with `verdi user configure`.""" @@ -87,10 +87,10 @@ def test_user_update(self): USER_2['institution'], ] - result = self.cli_runner.invoke(cmd_user.user_configure, cli_options, catch_exceptions=False) - self.assertTrue(email in result.output) - self.assertTrue('updated' in result.output) - self.assertTrue('created' not in result.output) + result = self.cli_runner(cmd_user.user_configure, cli_options, catch_exceptions=False) + assert email in result.output + assert 'updated' in result.output + assert 'created' not in result.output # Check it's all been changed to user2's attributes except the email for key, _ in USER_2.items(): diff --git a/tests/cmdline/params/options/test_conditional.py b/tests/cmdline/params/options/test_conditional.py index 830835a2b2..33b761f848 100644 --- a/tests/cmdline/params/options/test_conditional.py +++ b/tests/cmdline/params/options/test_conditional.py @@ -40,8 +40,8 @@ def command_multi_non_eager(a_or_b, opt_a, opt_b): """Return a command that has two scenarios. 
* flag a_or_b (--a/--b) - * opt-a required if a_or_b == True - * opt-b required if a_or_b == False + * opt-a required if a_or_b is True + * opt-b required if a_or_b is False """ # pylint: disable=unused-argument click.echo(f'{opt_a} / {opt_b}') diff --git a/tests/cmdline/params/types/test_calculation.py b/tests/cmdline/params/types/test_calculation.py index fcacd3052f..9077d16804 100644 --- a/tests/cmdline/params/types/test_calculation.py +++ b/tests/cmdline/params/types/test_calculation.py @@ -8,38 +8,37 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Tests for the `CalculationParamType`.""" +import pytest from aiida.cmdline.params.types import CalculationParamType from aiida.orm import CalcFunctionNode, CalcJobNode, CalculationNode, WorkChainNode, WorkFunctionNode from aiida.orm.utils.loaders import OrmEntityLoader -from aiida.storage.testbase import AiidaTestCase -class TestCalculationParamType(AiidaTestCase): +class TestCalculationParamType: """Tests for the `CalculationParamType`.""" - @classmethod - def setUpClass(cls, *args, **kwargs): + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean): # pylint: disable=unused-argument """ Create some code to test the CalculationParamType parameter type for the command line infrastructure We create an initial code with a random name and then on purpose create two code with a name that matches exactly the ID and UUID, respectively, of the first one. 
This allows us to test the rules implemented to solve ambiguities that arise when determing the identifier type """ - super().setUpClass(*args, **kwargs) + # pylint: disable=attribute-defined-outside-init + self.param = CalculationParamType() + self.entity_01 = CalculationNode().store() + self.entity_02 = CalculationNode().store() + self.entity_03 = CalculationNode().store() + self.entity_04 = WorkFunctionNode() + self.entity_05 = CalcFunctionNode() + self.entity_06 = CalcJobNode() + self.entity_07 = WorkChainNode() - cls.param = CalculationParamType() - cls.entity_01 = CalculationNode().store() - cls.entity_02 = CalculationNode().store() - cls.entity_03 = CalculationNode().store() - cls.entity_04 = WorkFunctionNode() - cls.entity_05 = CalcFunctionNode() - cls.entity_06 = CalcJobNode() - cls.entity_07 = WorkChainNode() - - cls.entity_01.label = 'calculation_01' - cls.entity_02.label = str(cls.entity_01.pk) - cls.entity_03.label = str(cls.entity_01.uuid) + self.entity_01.label = 'calculation_01' + self.entity_02.label = str(self.entity_01.pk) + self.entity_03.label = str(self.entity_01.uuid) def test_get_by_id(self): """ @@ -47,7 +46,7 @@ def test_get_by_id(self): """ identifier = f'{self.entity_01.pk}' result = self.param.convert(identifier, None, None) - self.assertEqual(result.uuid, self.entity_01.uuid) + assert result.uuid == self.entity_01.uuid def test_get_by_uuid(self): """ @@ -55,7 +54,7 @@ def test_get_by_uuid(self): """ identifier = f'{self.entity_01.uuid}' result = self.param.convert(identifier, None, None) - self.assertEqual(result.uuid, self.entity_01.uuid) + assert result.uuid == self.entity_01.uuid def test_get_by_label(self): """ @@ -63,7 +62,7 @@ def test_get_by_label(self): """ identifier = f'{self.entity_01.label}' result = self.param.convert(identifier, None, None) - self.assertEqual(result.uuid, self.entity_01.uuid) + assert result.uuid == self.entity_01.uuid def test_ambiguous_label_pk(self): """ @@ -74,11 +73,11 @@ def 
test_ambiguous_label_pk(self): """ identifier = f'{self.entity_02.label}' result = self.param.convert(identifier, None, None) - self.assertEqual(result.uuid, self.entity_01.uuid) + assert result.uuid == self.entity_01.uuid identifier = f'{self.entity_02.label}{OrmEntityLoader.label_ambiguity_breaker}' result = self.param.convert(identifier, None, None) - self.assertEqual(result.uuid, self.entity_02.uuid) + assert result.uuid == self.entity_02.uuid def test_ambiguous_label_uuid(self): """ @@ -89,8 +88,8 @@ def test_ambiguous_label_uuid(self): """ identifier = f'{self.entity_03.label}' result = self.param.convert(identifier, None, None) - self.assertEqual(result.uuid, self.entity_01.uuid) + assert result.uuid == self.entity_01.uuid identifier = f'{self.entity_03.label}{OrmEntityLoader.label_ambiguity_breaker}' result = self.param.convert(identifier, None, None) - self.assertEqual(result.uuid, self.entity_03.uuid) + assert result.uuid == self.entity_03.uuid diff --git a/tests/cmdline/params/types/test_data.py b/tests/cmdline/params/types/test_data.py index 30dd277a44..4c2f1d8733 100644 --- a/tests/cmdline/params/types/test_data.py +++ b/tests/cmdline/params/types/test_data.py @@ -8,34 +8,33 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Tests for the `DataParamType`.""" +import pytest from aiida.cmdline.params.types import DataParamType from aiida.orm import Data from aiida.orm.utils.loaders import OrmEntityLoader -from aiida.storage.testbase import AiidaTestCase -class TestDataParamType(AiidaTestCase): +class TestDataParamType: """Tests for the `DataParamType`.""" - @classmethod - def setUpClass(cls, *args, **kwargs): + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean): # pylint: disable=unused-argument """ Create some code to test the DataParamType parameter type for the command line infrastructure We create an initial code with a random name 
and then on purpose create two code with a name that matches exactly the ID and UUID, respectively, of the first one. This allows us to test the rules implemented to solve ambiguities that arise when determing the identifier type """ - super().setUpClass(*args, **kwargs) + # pylint: disable=attribute-defined-outside-init + self.param = DataParamType() + self.entity_01 = Data().store() + self.entity_02 = Data().store() + self.entity_03 = Data().store() - cls.param = DataParamType() - cls.entity_01 = Data().store() - cls.entity_02 = Data().store() - cls.entity_03 = Data().store() - - cls.entity_01.label = 'data_01' - cls.entity_02.label = str(cls.entity_01.pk) - cls.entity_03.label = str(cls.entity_01.uuid) + self.entity_01.label = 'data_01' + self.entity_02.label = str(self.entity_01.pk) + self.entity_03.label = str(self.entity_01.uuid) def test_get_by_id(self): """ @@ -43,7 +42,7 @@ def test_get_by_id(self): """ identifier = f'{self.entity_01.pk}' result = self.param.convert(identifier, None, None) - self.assertEqual(result.uuid, self.entity_01.uuid) + assert result.uuid == self.entity_01.uuid def test_get_by_uuid(self): """ @@ -51,7 +50,7 @@ def test_get_by_uuid(self): """ identifier = f'{self.entity_01.uuid}' result = self.param.convert(identifier, None, None) - self.assertEqual(result.uuid, self.entity_01.uuid) + assert result.uuid == self.entity_01.uuid def test_get_by_label(self): """ @@ -59,7 +58,7 @@ def test_get_by_label(self): """ identifier = f'{self.entity_01.label}' result = self.param.convert(identifier, None, None) - self.assertEqual(result.uuid, self.entity_01.uuid) + assert result.uuid == self.entity_01.uuid def test_ambiguous_label_pk(self): """ @@ -70,11 +69,11 @@ def test_ambiguous_label_pk(self): """ identifier = f'{self.entity_02.label}' result = self.param.convert(identifier, None, None) - self.assertEqual(result.uuid, self.entity_01.uuid) + assert result.uuid == self.entity_01.uuid identifier = 
f'{self.entity_02.label}{OrmEntityLoader.label_ambiguity_breaker}' result = self.param.convert(identifier, None, None) - self.assertEqual(result.uuid, self.entity_02.uuid) + assert result.uuid == self.entity_02.uuid def test_ambiguous_label_uuid(self): """ @@ -85,8 +84,8 @@ def test_ambiguous_label_uuid(self): """ identifier = f'{self.entity_03.label}' result = self.param.convert(identifier, None, None) - self.assertEqual(result.uuid, self.entity_01.uuid) + assert result.uuid == self.entity_01.uuid identifier = f'{self.entity_03.label}{OrmEntityLoader.label_ambiguity_breaker}' result = self.param.convert(identifier, None, None) - self.assertEqual(result.uuid, self.entity_03.uuid) + assert result.uuid == self.entity_03.uuid diff --git a/tests/cmdline/params/types/test_identifier.py b/tests/cmdline/params/types/test_identifier.py index c9cb1c6f64..365c8be49b 100644 --- a/tests/cmdline/params/types/test_identifier.py +++ b/tests/cmdline/params/types/test_identifier.py @@ -8,42 +8,44 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Tests for the `IdentifierParamType`.""" - import click +import pytest from aiida.cmdline.params.types import IdentifierParamType, NodeParamType from aiida.orm import Bool, Float, Int -from aiida.storage.testbase import AiidaTestCase -class TestIdentifierParamType(AiidaTestCase): +@pytest.mark.usefixtures('aiida_profile_clean') +class TestIdentifierParamType: """Tests for the `IdentifierParamType`.""" + # pylint: disable=no-self-use + def test_base_class(self): """ The base class is abstract and should not be constructable """ - with self.assertRaises(TypeError): + with pytest.raises(TypeError): IdentifierParamType() # pylint: disable=abstract-class-instantiated def test_identifier_sub_invalid_type(self): """ The sub_classes keyword argument should expect a tuple """ - with self.assertRaises(TypeError): + with pytest.raises(TypeError): 
NodeParamType(sub_classes='aiida.data:core.structure') - with self.assertRaises(TypeError): + with pytest.raises(TypeError): NodeParamType(sub_classes=(None,)) def test_identifier_sub_invalid_entry_point(self): """ The sub_classes keyword argument should expect a tuple of valid entry point strings """ - with self.assertRaises(ValueError): + with pytest.raises(ValueError): NodeParamType(sub_classes=('aiida.data.structure',)) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): NodeParamType(sub_classes=('aiida.data:not_existent',)) def test_identifier_sub_classes(self): @@ -58,14 +60,14 @@ def test_identifier_sub_classes(self): param_type_scoped = NodeParamType(sub_classes=('aiida.data:core.bool', 'aiida.data:core.float')) # For the base NodeParamType all node types should be matched - self.assertEqual(param_type_normal.convert(str(node_bool.pk), None, None).uuid, node_bool.uuid) - self.assertEqual(param_type_normal.convert(str(node_float.pk), None, None).uuid, node_float.uuid) - self.assertEqual(param_type_normal.convert(str(node_int.pk), None, None).uuid, node_int.uuid) + assert param_type_normal.convert(str(node_bool.pk), None, None).uuid == node_bool.uuid + assert param_type_normal.convert(str(node_float.pk), None, None).uuid == node_float.uuid + assert param_type_normal.convert(str(node_int.pk), None, None).uuid == node_int.uuid # The scoped NodeParamType should only match Bool and Float - self.assertEqual(param_type_scoped.convert(str(node_bool.pk), None, None).uuid, node_bool.uuid) - self.assertEqual(param_type_scoped.convert(str(node_float.pk), None, None).uuid, node_float.uuid) + assert param_type_scoped.convert(str(node_bool.pk), None, None).uuid == node_bool.uuid + assert param_type_scoped.convert(str(node_float.pk), None, None).uuid == node_float.uuid # The Int should not be found and raise - with self.assertRaises(click.BadParameter): + with pytest.raises(click.BadParameter): param_type_scoped.convert(str(node_int.pk), None, None) 
diff --git a/tests/cmdline/params/types/test_node.py b/tests/cmdline/params/types/test_node.py index 9b9ac85ce9..abec73ecb4 100644 --- a/tests/cmdline/params/types/test_node.py +++ b/tests/cmdline/params/types/test_node.py @@ -8,33 +8,33 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Tests for the `NodeParamType`.""" +import pytest + from aiida.cmdline.params.types import NodeParamType from aiida.orm import Data from aiida.orm.utils.loaders import OrmEntityLoader -from aiida.storage.testbase import AiidaTestCase -class TestNodeParamType(AiidaTestCase): +class TestNodeParamType: """Tests for the `NodeParamType`.""" - @classmethod - def setUpClass(cls, *args, **kwargs): + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean): # pylint: disable=unused-argument """ Create some code to test the NodeParamType parameter type for the command line infrastructure We create an initial code with a random name and then on purpose create two code with a name that matches exactly the ID and UUID, respectively, of the first one. 
This allows us to test the rules implemented to solve ambiguities that arise when determing the identifier type """ - super().setUpClass(*args, **kwargs) - - cls.param = NodeParamType() - cls.entity_01 = Data().store() - cls.entity_02 = Data().store() - cls.entity_03 = Data().store() + # pylint: disable=attribute-defined-outside-init + self.param = NodeParamType() + self.entity_01 = Data().store() + self.entity_02 = Data().store() + self.entity_03 = Data().store() - cls.entity_01.label = 'data_01' - cls.entity_02.label = str(cls.entity_01.pk) - cls.entity_03.label = str(cls.entity_01.uuid) + self.entity_01.label = 'data_01' + self.entity_02.label = str(self.entity_01.pk) + self.entity_03.label = str(self.entity_01.uuid) def test_get_by_id(self): """ @@ -42,7 +42,7 @@ def test_get_by_id(self): """ identifier = f'{self.entity_01.pk}' result = self.param.convert(identifier, None, None) - self.assertEqual(result.uuid, self.entity_01.uuid) + assert result.uuid == self.entity_01.uuid def test_get_by_uuid(self): """ @@ -50,7 +50,7 @@ def test_get_by_uuid(self): """ identifier = f'{self.entity_01.uuid}' result = self.param.convert(identifier, None, None) - self.assertEqual(result.uuid, self.entity_01.uuid) + assert result.uuid == self.entity_01.uuid def test_get_by_label(self): """ @@ -58,7 +58,7 @@ def test_get_by_label(self): """ identifier = f'{self.entity_01.label}' result = self.param.convert(identifier, None, None) - self.assertEqual(result.uuid, self.entity_01.uuid) + assert result.uuid == self.entity_01.uuid def test_ambiguous_label_pk(self): """ @@ -69,11 +69,11 @@ def test_ambiguous_label_pk(self): """ identifier = f'{self.entity_02.label}' result = self.param.convert(identifier, None, None) - self.assertEqual(result.uuid, self.entity_01.uuid) + assert result.uuid == self.entity_01.uuid identifier = f'{self.entity_02.label}{OrmEntityLoader.label_ambiguity_breaker}' result = self.param.convert(identifier, None, None) - self.assertEqual(result.uuid, 
self.entity_02.uuid) + assert result.uuid == self.entity_02.uuid def test_ambiguous_label_uuid(self): """ @@ -84,8 +84,8 @@ def test_ambiguous_label_uuid(self): """ identifier = f'{self.entity_03.label}' result = self.param.convert(identifier, None, None) - self.assertEqual(result.uuid, self.entity_01.uuid) + assert result.uuid == self.entity_01.uuid identifier = f'{self.entity_03.label}{OrmEntityLoader.label_ambiguity_breaker}' result = self.param.convert(identifier, None, None) - self.assertEqual(result.uuid, self.entity_03.uuid) + assert result.uuid == self.entity_03.uuid diff --git a/tests/cmdline/params/types/test_path.py b/tests/cmdline/params/types/test_path.py index 5acd65ca27..a8702431e7 100644 --- a/tests/cmdline/params/types/test_path.py +++ b/tests/cmdline/params/types/test_path.py @@ -7,12 +7,14 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=no-self-use """Tests for Path types""" +import pytest + from aiida.cmdline.params.types.path import PathOrUrl, check_timeout_seconds -from aiida.storage.testbase import AiidaTestCase -class TestPath(AiidaTestCase): +class TestPath: """Tests for `PathOrUrl` and `FileOrUrl`""" def test_default_timeout(self): @@ -21,7 +23,7 @@ def test_default_timeout(self): import_path = PathOrUrl() - self.assertEqual(import_path.timeout_seconds, URL_TIMEOUT_SECONDS) + assert import_path.timeout_seconds == URL_TIMEOUT_SECONDS def test_timeout_checks(self): """Test that timeout check handles different values. 
@@ -34,12 +36,12 @@ def test_timeout_checks(self): valid_values = [42, '42'] for value in valid_values: - self.assertEqual(check_timeout_seconds(value), int(value)) + assert check_timeout_seconds(value) == int(value) for invalid in [None, 'test']: - with self.assertRaises(TypeError): + with pytest.raises(TypeError): check_timeout_seconds(invalid) for invalid in [-5, 65]: - with self.assertRaises(ValueError): + with pytest.raises(ValueError): check_timeout_seconds(invalid) diff --git a/tests/cmdline/params/types/test_plugin.py b/tests/cmdline/params/types/test_plugin.py index 760f58f0f5..1aad413580 100644 --- a/tests/cmdline/params/types/test_plugin.py +++ b/tests/cmdline/params/types/test_plugin.py @@ -8,47 +8,48 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Tests for the `PluginParamType`.""" - import click +import pytest from aiida.cmdline.params.types.plugin import PluginParamType from aiida.plugins.entry_point import get_entry_point_from_string -from aiida.storage.testbase import AiidaTestCase -class TestPluginParamType(AiidaTestCase): +class TestPluginParamType: """Tests for the `PluginParamType`.""" + # pylint: disable=no-self-use + def test_group_definition(self): """ Test the various accepted syntaxes of defining supported entry point groups. Both single values as well as tuples should be allowed. The `aiida.` prefix should also be optional. 
""" param = PluginParamType(group='calculations') - self.assertIn('aiida.calculations', param.groups) - self.assertTrue(len(param.groups), 1) + assert 'aiida.calculations' in param.groups + assert len(param.groups) == 1 param = PluginParamType(group='aiida.calculations') - self.assertIn('aiida.calculations', param.groups) - self.assertTrue(len(param.groups), 1) + assert 'aiida.calculations' in param.groups + assert len(param.groups) == 1 param = PluginParamType(group=('calculations',)) - self.assertIn('aiida.calculations', param.groups) - self.assertTrue(len(param.groups), 1) + assert 'aiida.calculations' in param.groups + assert len(param.groups) == 1 param = PluginParamType(group=('aiida.calculations',)) - self.assertIn('aiida.calculations', param.groups) - self.assertTrue(len(param.groups), 1) + assert 'aiida.calculations' in param.groups + assert len(param.groups) == 1 param = PluginParamType(group=('aiida.calculations', 'aiida.data')) - self.assertIn('aiida.calculations', param.groups) - self.assertIn('aiida.data', param.groups) - self.assertTrue(len(param.groups), 2) + assert 'aiida.calculations' in param.groups + assert 'aiida.data' in param.groups + assert len(param.groups) == 2 param = PluginParamType(group=('aiida.calculations', 'data')) - self.assertIn('aiida.calculations', param.groups) - self.assertIn('aiida.data', param.groups) - self.assertTrue(len(param.groups), 2) + assert 'aiida.calculations' in param.groups + assert 'aiida.data' in param.groups + assert len(param.groups) == 2 def test_get_entry_point_from_string(self): """ @@ -59,39 +60,39 @@ def test_get_entry_point_from_string(self): entry_point = get_entry_point_from_string('aiida.transports:core.ssh') # Invalid entry point strings - with self.assertRaises(ValueError): + with pytest.raises(ValueError): param.get_entry_point_from_string('aiida.transport:core.ssh') - with self.assertRaises(ValueError): + with pytest.raises(ValueError): 
param.get_entry_point_from_string('aiid.transports:core.ssh') - with self.assertRaises(ValueError): + with pytest.raises(ValueError): param.get_entry_point_from_string('aiida..transports:core.ssh') # Unsupported entry points for all formats - with self.assertRaises(ValueError): + with pytest.raises(ValueError): param.get_entry_point_from_string('aiida.data:core.structure') - with self.assertRaises(ValueError): + with pytest.raises(ValueError): param.get_entry_point_from_string('data:structure') - with self.assertRaises(ValueError): + with pytest.raises(ValueError): param.get_entry_point_from_string('structure') # Non-existent entry points for all formats - with self.assertRaises(ValueError): + with pytest.raises(ValueError): param.get_entry_point_from_string('aiida.transports:not_existent') - with self.assertRaises(ValueError): + with pytest.raises(ValueError): param.get_entry_point_from_string('transports:not_existent') - with self.assertRaises(ValueError): + with pytest.raises(ValueError): param.get_entry_point_from_string('not_existent') # Valid entry point strings - self.assertEqual(param.get_entry_point_from_string('aiida.transports:core.ssh').name, entry_point.name) - self.assertEqual(param.get_entry_point_from_string('transports:core.ssh').name, entry_point.name) - self.assertEqual(param.get_entry_point_from_string('core.ssh').name, entry_point.name) + assert param.get_entry_point_from_string('aiida.transports:core.ssh').name == entry_point.name + assert param.get_entry_point_from_string('transports:core.ssh').name == entry_point.name + assert param.get_entry_point_from_string('core.ssh').name == entry_point.name def test_get_entry_point_from_ambiguous(self): """ @@ -102,14 +103,12 @@ def test_get_entry_point_from_ambiguous(self): entry_point = get_entry_point_from_string('aiida.calculations:core.arithmetic.add') # Both groups contain entry point `arithmetic.add` so passing only name is ambiguous and should raise - with self.assertRaises(ValueError): + with 
pytest.raises(ValueError): param.get_entry_point_from_string('core.arithmetic.add') # Passing PARTIAL or FULL should allow entry point to be returned - self.assertEqual( - param.get_entry_point_from_string('aiida.calculations:core.arithmetic.add').name, entry_point.name - ) - self.assertEqual(param.get_entry_point_from_string('calculations:core.arithmetic.add').name, entry_point.name) + assert param.get_entry_point_from_string('aiida.calculations:core.arithmetic.add').name == entry_point.name + assert param.get_entry_point_from_string('calculations:core.arithmetic.add').name == entry_point.name def test_convert(self): """ @@ -118,24 +117,24 @@ def test_convert(self): param = PluginParamType(group=('transports', 'data')) entry_point = param.convert('aiida.transports:core.ssh', None, None) - self.assertEqual(entry_point.name, 'core.ssh') + assert entry_point.name == 'core.ssh' entry_point = param.convert('transports:core.ssh', None, None) - self.assertEqual(entry_point.name, 'core.ssh') + assert entry_point.name == 'core.ssh' entry_point = param.convert('core.ssh', None, None) - self.assertEqual(entry_point.name, 'core.ssh') + assert entry_point.name == 'core.ssh' entry_point = param.convert('aiida.data:core.structure', None, None) - self.assertEqual(entry_point.name, 'core.structure') + assert entry_point.name == 'core.structure' entry_point = param.convert('data:core.structure', None, None) - self.assertEqual(entry_point.name, 'core.structure') + assert entry_point.name == 'core.structure' entry_point = param.convert('core.structure', None, None) - self.assertEqual(entry_point.name, 'core.structure') + assert entry_point.name == 'core.structure' - with self.assertRaises(click.BadParameter): + with pytest.raises(click.BadParameter): param.convert('not_existent', None, None) def test_convert_load(self): @@ -147,24 +146,24 @@ def test_convert_load(self): entry_point_structure = get_entry_point_from_string('aiida.data:core.structure') entry_point = 
param.convert('aiida.transports:core.ssh', None, None) - self.assertTrue(entry_point, entry_point_ssh) + assert entry_point, entry_point_ssh entry_point = param.convert('transports:core.ssh', None, None) - self.assertTrue(entry_point, entry_point_ssh) + assert entry_point, entry_point_ssh entry_point = param.convert('core.ssh', None, None) - self.assertTrue(entry_point, entry_point_ssh) + assert entry_point, entry_point_ssh entry_point = param.convert('aiida.data:core.structure', None, None) - self.assertTrue(entry_point, entry_point_structure) + assert entry_point, entry_point_structure entry_point = param.convert('data:core.structure', None, None) - self.assertTrue(entry_point, entry_point_structure) + assert entry_point, entry_point_structure entry_point = param.convert('core.structure', None, None) - self.assertTrue(entry_point, entry_point_structure) + assert entry_point, entry_point_structure - with self.assertRaises(click.BadParameter): + with pytest.raises(click.BadParameter): param.convert('not_existent', None, None) def test_complete_single_group(self): @@ -179,22 +178,22 @@ def test_complete_single_group(self): entry_point_full = 'aiida.transports:core.ssh' options = [item[0] for item in param.complete(None, 'core.ss')] - self.assertIn(entry_point_minimal, options) + assert entry_point_minimal in options options = [item[0] for item in param.complete(None, 'core.ssh')] - self.assertIn(entry_point_minimal, options) + assert entry_point_minimal in options options = [item[0] for item in param.complete(None, 'transports:core.ss')] - self.assertIn(entry_point_partial, options) + assert entry_point_partial in options options = [item[0] for item in param.complete(None, 'transports:core.ssh')] - self.assertIn(entry_point_partial, options) + assert entry_point_partial in options options = [item[0] for item in param.complete(None, 'aiida.transports:core.ss')] - self.assertIn(entry_point_full, options) + assert entry_point_full in options options = [item[0] for item 
in param.complete(None, 'aiida.transports:core.ssh')] - self.assertIn(entry_point_full, options) + assert entry_point_full in options def test_complete_amibguity(self): """ @@ -207,30 +206,30 @@ def test_complete_amibguity(self): entry_point_full_parsers = 'aiida.parsers:core.arithmetic.add' options = [item[0] for item in param.complete(None, 'aiida.calculations:core.arith')] - self.assertIn(entry_point_full_calculations, options) + assert entry_point_full_calculations in options options = [item[0] for item in param.complete(None, 'aiida.calculations:core.arithmetic.add')] - self.assertIn(entry_point_full_calculations, options) + assert entry_point_full_calculations in options options = [item[0] for item in param.complete(None, 'aiida.parsers:core.arith')] - self.assertIn(entry_point_full_parsers, options) + assert entry_point_full_parsers in options options = [item[0] for item in param.complete(None, 'aiida.parsers:core.arithmetic.add')] - self.assertIn(entry_point_full_parsers, options) + assert entry_point_full_parsers in options # PARTIAL or MINIMAL string formats will not be autocompleted options = [item[0] for item in param.complete(None, 'parsers:core.arith')] - self.assertNotIn(entry_point_full_calculations, options) - self.assertNotIn(entry_point_full_parsers, options) + assert entry_point_full_calculations not in options + assert entry_point_full_parsers not in options options = [item[0] for item in param.complete(None, 'parsers:core.arithmetic.add')] - self.assertNotIn(entry_point_full_calculations, options) - self.assertNotIn(entry_point_full_parsers, options) + assert entry_point_full_calculations not in options + assert entry_point_full_parsers not in options options = [item[0] for item in param.complete(None, 'core.arith')] - self.assertNotIn(entry_point_full_calculations, options) - self.assertNotIn(entry_point_full_parsers, options) + assert entry_point_full_calculations not in options + assert entry_point_full_parsers not in options options = 
[item[0] for item in param.complete(None, 'core.arithmetic.add')] - self.assertNotIn(entry_point_full_calculations, options) - self.assertNotIn(entry_point_full_parsers, options) + assert entry_point_full_calculations not in options + assert entry_point_full_parsers not in options diff --git a/tests/common/test_hashing.py b/tests/common/test_hashing.py index 7a71fd8929..d5dd97c597 100644 --- a/tests/common/test_hashing.py +++ b/tests/common/test_hashing.py @@ -7,10 +7,10 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=missing-docstring,no-self-use """ Unittests for aiida.common.hashing:make_hash with hardcoded hash values """ - import collections from datetime import datetime from decimal import Decimal @@ -19,103 +19,82 @@ import uuid import numpy as np +import pytest import pytz -try: - import unittest2 as unittest -except ImportError: - import unittest - from aiida.common.exceptions import HashingError from aiida.common.folders import SandboxFolder from aiida.common.hashing import chunked_file_hash, float_to_text, make_hash from aiida.common.utils import DatetimePrecision from aiida.orm import Dict -from aiida.storage.testbase import AiidaTestCase -class FloatToTextTest(unittest.TestCase): +class TestFloatToTextTest: """ Tests for the float_to_text methods """ def test_subnormal(self): - self.assertEqual(float_to_text(-0.00, sig=2), '0') # 0 is always printed as '0' - self.assertEqual(float_to_text(3.555, sig=2), '3.6') - self.assertEqual(float_to_text(3.555, sig=3), '3.56') - self.assertEqual(float_to_text(3.141592653589793238462643383279502884197, sig=14), '3.1415926535898') - - -class MakeHashTest(unittest.TestCase): + assert float_to_text(-0.00, sig=2) == '0' # 0 is always printed as '0' + assert float_to_text(3.555, sig=2) == '3.6' + assert float_to_text(3.555, sig=3) == '3.56' + 
assert float_to_text(3.141592653589793238462643383279502884197, sig=14) == '3.1415926535898' + + +@pytest.mark.parametrize( + 'value,digest', [ + ('something in ASCII', '06e87857590c91280d25e02f05637cd2381002bd1425dff3e36ca860bbb26a29'), + (42, '9468692328de958d7a8039e8a2eb05cd6888b7911bbc3794d0dfebd8df3482cd'), + (3.141, 'b3302aad550413e14fe44d5ead10b3aeda9884055fca77f9368c48517916d4be'), + (complex(1, 2), '287c6bb18d4fb00fd5f3a6fb6931a85cd8ae4b1f43be4707a76964fbc322872e'), + (True, '31ad5fa163a0c478d966c7f7568f3248f0c58d294372b2e8f7cb0560d8c8b12f'), + (None, '1729486cc7e56a6383542b1ec73125ccb26093651a5da05e04657ac416a74b8f'), + ] +) +def test_builtin_types(value, digest): + assert make_hash(value) == digest + + +class TestMakeHashTest: """ Tests for the make_hash function. """ - # pylint: disable=missing-docstring - - def test_builtin_types(self): - test_data = { - 'something in ASCII': '06e87857590c91280d25e02f05637cd2381002bd1425dff3e36ca860bbb26a29', - 42: '9468692328de958d7a8039e8a2eb05cd6888b7911bbc3794d0dfebd8df3482cd', - 3.141: 'b3302aad550413e14fe44d5ead10b3aeda9884055fca77f9368c48517916d4be', - complex(1, 2): '287c6bb18d4fb00fd5f3a6fb6931a85cd8ae4b1f43be4707a76964fbc322872e', - True: '31ad5fa163a0c478d966c7f7568f3248f0c58d294372b2e8f7cb0560d8c8b12f', - None: '1729486cc7e56a6383542b1ec73125ccb26093651a5da05e04657ac416a74b8f', - } - - for val, digest in test_data.items(): - with self.subTest(val=val): - self.assertEqual(make_hash(val), digest) - def test_unicode_string(self): - self.assertEqual( - make_hash('something still in ASCII'), 'd55e492596cf214d877e165cdc3394f27e82e011838474f5ba5b9824074b9e91' - ) + assert make_hash( + 'something still in ASCII' + ) == 'd55e492596cf214d877e165cdc3394f27e82e011838474f5ba5b9824074b9e91' - self.assertEqual( - make_hash('öpis mit Umluut wie ä, ö, ü und emene ß'), + assert make_hash('öpis mit Umluut wie ä, ö, ü und emene ß') == \ 'c404bf9a62cba3518de5c2bae8c67010aff6e4051cce565fa247a7f1d71f1fc7' - ) def 
test_collection_with_ordered_sets(self): - self.assertEqual(make_hash((1, 2, 3)), 'b6b13d50e3bee7e58371af2b303f629edf32d1be2f7717c9d14193b4b8b23e04') - self.assertEqual(make_hash([1, 2, 3]), 'b6b13d50e3bee7e58371af2b303f629edf32d1be2f7717c9d14193b4b8b23e04') + assert make_hash((1, 2, 3)) == 'b6b13d50e3bee7e58371af2b303f629edf32d1be2f7717c9d14193b4b8b23e04' + assert make_hash([1, 2, 3]) == 'b6b13d50e3bee7e58371af2b303f629edf32d1be2f7717c9d14193b4b8b23e04' for perm in itertools.permutations([1, 2, 3]): - with self.subTest(orig=[1, 2, 3], perm=perm): - self.assertNotEqual(make_hash(perm), make_hash({1, 2, 3})) + assert make_hash(perm) != make_hash({1, 2, 3}) def test_collisions_with_nested_objs(self): - self.assertNotEqual(make_hash([[1, 2], 3]), make_hash([[1, 2, 3]])) - self.assertNotEqual(make_hash({1, 2}), make_hash({1: 2})) + assert make_hash([[1, 2], 3]) != make_hash([[1, 2, 3]]) + assert make_hash({1, 2}) != make_hash({1: 2}) def test_collection_with_unordered_sets(self): - self.assertEqual(make_hash({1, 2, 3}), 'a11cff8e62b57e1aefb7de908bd50096816b66796eb7e11ad78edeaf2629f89c') - self.assertEqual(make_hash({1, 2, 3}), make_hash({2, 1, 3})) + assert make_hash({1, 2, 3}) == 'a11cff8e62b57e1aefb7de908bd50096816b66796eb7e11ad78edeaf2629f89c' + assert make_hash({1, 2, 3}) == make_hash({2, 1, 3}) def test_collection_with_dicts(self): - self.assertEqual( - make_hash({ - 'a': 'b', - 'c': 'd' - }), '656ef313d44684c44977b0c75f48f27a43686c63ae44c8778ea0fe05f629b3b9' - ) + assert make_hash({'a': 'b', 'c': 'd'}) == '656ef313d44684c44977b0c75f48f27a43686c63ae44c8778ea0fe05f629b3b9' # order changes in dictionaries should give the same hashes - self.assertEqual( - make_hash(collections.OrderedDict([('c', 'd'), ('a', 'b')]), odict_as_unordered=True), + assert make_hash(collections.OrderedDict([('c', 'd'), ('a', 'b')]), odict_as_unordered=True) == \ make_hash(collections.OrderedDict([('a', 'b'), ('c', 'd')]), odict_as_unordered=True) - ) def test_collection_with_odicts(self): # 
ordered dicts should always give a different hash (because they are a different type), unless told otherwise: - self.assertNotEqual( - make_hash(collections.OrderedDict([('a', 'b'), ('c', 'd')])), make_hash(dict([('a', 'b'), ('c', 'd')])) - ) - self.assertEqual( - make_hash(collections.OrderedDict([('a', 'b'), ('c', 'd')]), odict_as_unordered=True), + assert make_hash(collections.OrderedDict([('a', 'b'), ('c', 'd')])) != make_hash(dict([('a', 'b'), ('c', 'd')])) + assert make_hash(collections.OrderedDict([('a', 'b'), ('c', 'd')]), odict_as_unordered=True) == \ make_hash(dict([('a', 'b'), ('c', 'd')])) - ) def test_nested_collections(self): obj_a = collections.OrderedDict([ @@ -137,70 +116,60 @@ def test_nested_collections(self): '1': 'hello', }), ('3', 4), (3, 4)]) - self.assertEqual( - make_hash(obj_a, odict_as_unordered=True), + assert make_hash(obj_a, odict_as_unordered=True) == \ 'e27bf6081c23afcb3db0ee3a24a64c73171c062c7f227fecc7f17189996add44' - ) - self.assertEqual(make_hash(obj_a, odict_as_unordered=True), make_hash(obj_b, odict_as_unordered=True)) + assert make_hash(obj_a, odict_as_unordered=True) == make_hash(obj_b, odict_as_unordered=True) def test_bytes(self): - self.assertEqual(make_hash(b'foo'), '459062c44082269b2d07f78c1b6e8c98b93448606bfb1cc1f48284cdfcea74e3') + assert make_hash(b'foo') == '459062c44082269b2d07f78c1b6e8c98b93448606bfb1cc1f48284cdfcea74e3' def test_uuid(self): some_uuid = uuid.UUID('62c42d58-56e8-4ade-9d5e-18de3a7baacd') - self.assertEqual(make_hash(some_uuid), '3df6ae6dd5930e4cf8b22de123e5ac4f004f63ab396dff6225e656acc42dcf6f') - self.assertNotEqual(make_hash(some_uuid), make_hash(str(some_uuid))) + assert make_hash(some_uuid) == '3df6ae6dd5930e4cf8b22de123e5ac4f004f63ab396dff6225e656acc42dcf6f' + assert make_hash(some_uuid) != make_hash(str(some_uuid)) def test_datetime(self): # test for timezone-naive datetime: - self.assertEqual( - make_hash(datetime(2018, 8, 18, 8, 18)), 
'714138f1114daa5fdc74c3483260742952b71b568d634c6093bb838afad76646' - ) - self.assertEqual( - make_hash(datetime.utcfromtimestamp(0)), 'b4d97d9d486937775bcc25a5cba073f048348c3cd93d4460174a4f72a6feb285' - ) + assert make_hash( + datetime(2018, 8, 18, 8, 18) + ) == '714138f1114daa5fdc74c3483260742952b71b568d634c6093bb838afad76646' + assert make_hash( + datetime.utcfromtimestamp(0) + ) == 'b4d97d9d486937775bcc25a5cba073f048348c3cd93d4460174a4f72a6feb285' # test with timezone-aware datetime: - self.assertEqual( - make_hash(datetime(2018, 8, 18, 8, 18).replace(tzinfo=pytz.timezone('US/Eastern'))), + assert make_hash(datetime(2018, 8, 18, 8, 18).replace(tzinfo=pytz.timezone('US/Eastern'))) == \ '194478834b3b8bd0518cf6ca6fefacc13bea15f9c0b8f5d585a0adf2ebbd562f' - ) - self.assertEqual( - make_hash(datetime(2018, 8, 18, 8, 18).replace(tzinfo=pytz.timezone('Europe/Amsterdam'))), + assert make_hash(datetime(2018, 8, 18, 8, 18).replace(tzinfo=pytz.timezone('Europe/Amsterdam'))) == \ 'be7c7c7faaff07d796db4cbef4d3d07ed29fdfd4a38c9aded00a4c2da2b89b9c' - ) def test_datetime_precision_hashing(self): dt_prec = DatetimePrecision(datetime(2018, 8, 18, 8, 18), 10) - self.assertEqual(make_hash(dt_prec), '837ab70b3b7bd04c1718834a0394a2230d81242c442e4aa088abeab15622df37') + assert make_hash(dt_prec) == '837ab70b3b7bd04c1718834a0394a2230d81242c442e4aa088abeab15622df37' dt_prec_utc = DatetimePrecision(datetime.utcfromtimestamp(0), 0) - self.assertEqual(make_hash(dt_prec_utc), '8c756ee99eaf9655bb00166839b9d40aa44eac97684b28f6e3c07d4331ae644e') + assert make_hash(dt_prec_utc) == '8c756ee99eaf9655bb00166839b9d40aa44eac97684b28f6e3c07d4331ae644e' def test_numpy_types(self): - self.assertEqual( - make_hash(np.float64(3.141)), 'b3302aad550413e14fe44d5ead10b3aeda9884055fca77f9368c48517916d4be' - ) # pylint: disable=no-member - self.assertEqual(make_hash(np.int64(42)), '9468692328de958d7a8039e8a2eb05cd6888b7911bbc3794d0dfebd8df3482cd') # pylint: disable=no-member + assert 
make_hash(np.float64(3.141)) == 'b3302aad550413e14fe44d5ead10b3aeda9884055fca77f9368c48517916d4be' # pylint: disable=no-member + assert make_hash(np.int64(42)) == '9468692328de958d7a8039e8a2eb05cd6888b7911bbc3794d0dfebd8df3482cd' # pylint: disable=no-member def test_decimal(self): - self.assertEqual( - make_hash(Decimal('3.141')), 'b3302aad550413e14fe44d5ead10b3aeda9884055fca77f9368c48517916d4be' - ) # pylint: disable=no-member + assert make_hash(Decimal('3.141')) == 'b3302aad550413e14fe44d5ead10b3aeda9884055fca77f9368c48517916d4be' # pylint: disable=no-member # make sure we get the same hashes as for corresponding float or int - self.assertEqual(make_hash(Decimal('3.141')), make_hash(3.141)) # pylint: disable=no-member + assert make_hash(Decimal('3.141')) == make_hash(3.141) # pylint: disable=no-member - self.assertEqual(make_hash(Decimal('3.')), make_hash(3)) # pylint: disable=no-member + assert make_hash(Decimal('3.')) == make_hash(3) # pylint: disable=no-member - self.assertEqual(make_hash(Decimal('3141')), make_hash(3141)) # pylint: disable=no-member + assert make_hash(Decimal('3141')) == make_hash(3141) # pylint: disable=no-member def test_unhashable_type(self): class MadeupClass: pass - with self.assertRaises(HashingError): + with pytest.raises(HashingError): make_hash(MadeupClass()) def test_folder(self): @@ -212,35 +181,36 @@ def test_folder(self): fhandle.close() folder_hash = make_hash(folder) - self.assertEqual(folder_hash, '47d9cdb2247e75eca492035f60f09fdd0daf87bbba40bb658d2d7e84f21f26c5') + assert folder_hash == '47d9cdb2247e75eca492035f60f09fdd0daf87bbba40bb658d2d7e84f21f26c5' nested_obj = ['1.0.0a2', {'array|a': [1001]}, folder, None] - self.assertEqual(make_hash(nested_obj), 'd3e7ff24708bc60b75a01571454ac0a664fa94ff2145848b584fb9ecc7e4fcbe') + assert make_hash(nested_obj) == 'd3e7ff24708bc60b75a01571454ac0a664fa94ff2145848b584fb9ecc7e4fcbe' with folder.open('file3.npy', 'wb') as fhandle: np.save(fhandle, np.arange(10)) # after adding a file, the 
folder hash should have changed - self.assertNotEqual(make_hash(folder), folder_hash) + assert make_hash(folder) != folder_hash # ... unless we explicitly tell it to ignore the new file - self.assertEqual(make_hash(folder, ignored_folder_content='file3.npy'), folder_hash) + assert make_hash(folder, ignored_folder_content='file3.npy') == folder_hash subfolder = folder.get_subfolder('some_subdir', create=True) with subfolder.open('file4.npy', 'wb') as fhandle: np.save(fhandle, np.arange(5)) - self.assertNotEqual(make_hash(folder), folder_hash) - self.assertEqual(make_hash(folder, ignored_folder_content=['file3.npy', 'some_subdir']), folder_hash) + assert make_hash(folder) != folder_hash + assert make_hash(folder, ignored_folder_content=['file3.npy', 'some_subdir']) == folder_hash -class CheckDBRoundTrip(AiidaTestCase): +@pytest.mark.usefixtures('aiida_profile_clean') +class TestCheckDBRoundTrip: """ - Check that the hash does not change after a roundtrip via the DB. Note that this class must inherit from - AiiDATestCase since it's working with the DB. + Check that the hash does not change after a roundtrip via the DB. """ - def test_attribute_storing(self): + @staticmethod + def test_attribute_storing(): """ I test that when storing different types of data as attributes (using a dict), the hash is the same before and after storing. 
@@ -277,7 +247,7 @@ def test_attribute_storing(self): first_hash = node.get_extra('_aiida_hash') recomputed_hash = node.get_hash() - self.assertEqual(first_hash, recomputed_hash) + assert first_hash == recomputed_hash def test_chunked_file_hash(tmp_path): diff --git a/tests/common/test_links.py b/tests/common/test_links.py index 6880270140..d58aff0496 100644 --- a/tests/common/test_links.py +++ b/tests/common/test_links.py @@ -8,27 +8,24 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Tests for the links utilities.""" +import pytest from aiida.common.links import validate_link_label -from aiida.storage.testbase import AiidaTestCase -class TestValidateLinkLabel(AiidaTestCase): - """Tests for `validate_link_label` function.""" +def test_validate_link_label(): + """Test that illegal link labels will raise a `ValueError`.""" - def test_validate_link_label(self): - """Test that illegal link labels will raise a `ValueError`.""" + illegal_link_labels = [ + '_leading_underscore', + 'trailing_underscore_', + 'non_numeric_%', + 'including.period', + 'disallowed👻unicodecharacters', + 'white space', + 'das-hes', + ] - illegal_link_labels = [ - '_leading_underscore', - 'trailing_underscore_', - 'non_numeric_%', - 'including.period', - 'disallowed👻unicodecharacters', - 'white space', - 'das-hes', - ] - - for link_label in illegal_link_labels: - with self.assertRaises(ValueError): - validate_link_label(link_label) + for link_label in illegal_link_labels: + with pytest.raises(ValueError): + validate_link_label(link_label) diff --git a/tests/conftest.py b/tests/conftest.py index ea08471917..314552fe2a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -447,7 +447,7 @@ def _run_cli_command( :param user_input: string with data to be provided at the prompt. Can include newline characters to simulate responses to multiple prompts. 
:param raises: whether the command is expected to raise an exception. - :param catch_exceptions: if True and ``raise == False``, will assert that the exception is ``None`` and the exit + :param catch_exceptions: if True and ``raise is False``, will assert that the exception is ``None`` and the exit code of the result of the invoked command equals zero. :param kwargs: keyword arguments that will be psased to the command invocation. :return: test result. diff --git a/tests/engine/daemon/test_client.py b/tests/engine/daemon/test_client.py index 8056ed4394..8cfda7802b 100644 --- a/tests/engine/daemon/test_client.py +++ b/tests/engine/daemon/test_client.py @@ -7,35 +7,33 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=no-self-use """Unit tests for the `DaemonClient` class.""" - +import pytest import zmq from aiida.engine.daemon.client import get_daemon_client -from aiida.storage.testbase import AiidaTestCase - -class TestDaemonClient(AiidaTestCase): - """Unit tests for the `DaemonClient` class.""" - def test_ipc_socket_file_length_limit(self): - """ - The maximum length of socket filepaths is often limited by the operating system. - For MacOS it is limited to 103 bytes, versus 107 bytes on Unix. This limit is - exposed by the Zmq library which is used by Circus library that is used to - daemonize the daemon runners. This test verifies that the three endpoints used - for the Circus client have a filepath that does not exceed that path limit. +@pytest.mark.usefixtures('aiida_profile_clean') +def test_ipc_socket_file_length_limit(): + """ + The maximum length of socket filepaths is often limited by the operating system. + For MacOS it is limited to 103 bytes, versus 107 bytes on Unix. 
This limit is + exposed by the Zmq library which is used by Circus library that is used to + daemonize the daemon runners. This test verifies that the three endpoints used + for the Circus client have a filepath that does not exceed that path limit. - See issue #1317 and pull request #1403 for the discussion - """ - # pylint: disable=no-member + See issue #1317 and pull request #1403 for the discussion + """ + # pylint: disable=no-member - daemon_client = get_daemon_client() + daemon_client = get_daemon_client() - controller_endpoint = daemon_client.get_controller_endpoint() - pubsub_endpoint = daemon_client.get_pubsub_endpoint() - stats_endpoint = daemon_client.get_stats_endpoint() + controller_endpoint = daemon_client.get_controller_endpoint() + pubsub_endpoint = daemon_client.get_pubsub_endpoint() + stats_endpoint = daemon_client.get_stats_endpoint() - self.assertTrue(len(controller_endpoint) <= zmq.IPC_PATH_MAX_LEN) - self.assertTrue(len(pubsub_endpoint) <= zmq.IPC_PATH_MAX_LEN) - self.assertTrue(len(stats_endpoint) <= zmq.IPC_PATH_MAX_LEN) + assert len(controller_endpoint) <= zmq.IPC_PATH_MAX_LEN + assert len(pubsub_endpoint) <= zmq.IPC_PATH_MAX_LEN + assert len(stats_endpoint) <= zmq.IPC_PATH_MAX_LEN diff --git a/tests/engine/processes/calcjobs/test_calc_job.py b/tests/engine/processes/calcjobs/test_calc_job.py index 811287e550..da818baecc 100644 --- a/tests/engine/processes/calcjobs/test_calc_job.py +++ b/tests/engine/processes/calcjobs/test_calc_job.py @@ -25,7 +25,6 @@ from aiida.engine.processes.calcjobs.calcjob import validate_stash_options from aiida.engine.processes.ports import PortNamespace from aiida.plugins import CalculationFactory -from aiida.storage.testbase import AiidaTestCase ArithmeticAddCalculation = CalculationFactory('core.arithmetic.add') # pylint: disable=invalid-name @@ -168,16 +167,20 @@ def test_multi_codes_run_withmpi(aiida_local_code_factory, file_regression, calc @pytest.mark.requires_rmq -class TestCalcJob(AiidaTestCase): +class 
TestCalcJob: """Test for the `CalcJob` process sub class.""" - @classmethod - def setUpClass(cls, *args, **kwargs): - super().setUpClass(*args, **kwargs) - cls.computer.configure() # pylint: disable=no-member - cls.remote_code = orm.Code(remote_computer_exec=(cls.computer, '/bin/bash')).store() - cls.local_code = orm.Code(local_executable='bash', files=['/bin/bash']).store() - cls.inputs = {'x': orm.Int(1), 'y': orm.Int(2), 'metadata': {'options': {}}} + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean, aiida_localhost): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init + assert Process.current() is None + self.computer = aiida_localhost + self.remote_code = orm.Code(remote_computer_exec=(self.computer, '/bin/bash')).store() + self.local_code = orm.Code(local_executable='bash', files=['/bin/bash']).store() + self.inputs = {'x': orm.Int(1), 'y': orm.Int(2), 'metadata': {'options': {}}} + yield + assert Process.current() is None def instantiate_process(self, state=CalcJobState.PARSING): """Instantiate a process with default inputs and return the `Process` instance.""" @@ -196,29 +199,21 @@ def instantiate_process(self, state=CalcJobState.PARSING): return process - def setUp(self): - super().setUp() - self.assertIsNone(Process.current()) - - def tearDown(self): - super().tearDown() - self.assertIsNone(Process.current()) - def test_run_base_class(self): """Verify that it is impossible to run, submit or instantiate a base `CalcJob` class.""" - with self.assertRaises(exceptions.InvalidOperation): + with pytest.raises(exceptions.InvalidOperation): CalcJob() - with self.assertRaises(exceptions.InvalidOperation): + with pytest.raises(exceptions.InvalidOperation): launch.run(CalcJob) - with self.assertRaises(exceptions.InvalidOperation): + with pytest.raises(exceptions.InvalidOperation): launch.run_get_node(CalcJob) - with self.assertRaises(exceptions.InvalidOperation): + with 
pytest.raises(exceptions.InvalidOperation): launch.run_get_pk(CalcJob) - with self.assertRaises(exceptions.InvalidOperation): + with pytest.raises(exceptions.InvalidOperation): launch.submit(CalcJob) def test_define_not_calling_super(self): @@ -234,13 +229,13 @@ def define(cls, spec): def prepare_for_submission(self, folder): pass - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): launch.run(IncompleteDefineCalcJob) def test_spec_options_property(self): """`CalcJob.spec_options` should return the options port namespace of its spec.""" - self.assertIsInstance(CalcJob.spec_options, PortNamespace) - self.assertEqual(CalcJob.spec_options, CalcJob.spec().inputs['metadata']['options']) + assert isinstance(CalcJob.spec_options, PortNamespace) + assert CalcJob.spec_options == CalcJob.spec().inputs['metadata']['options'] def test_invalid_options_type(self): """Verify that passing an invalid type to `metadata.options` raises a `TypeError`.""" @@ -256,7 +251,7 @@ def prepare_for_submission(self, folder): pass # The `metadata.options` input expects a plain dict and not a node `Dict` - with self.assertRaises(TypeError): + with pytest.raises(TypeError): launch.run(SimpleCalcJob, code=self.remote_code, metadata={'options': orm.Dict(dict={'a': 1})}) def test_remote_code_set_computer_implicit(self): @@ -269,8 +264,8 @@ def test_remote_code_set_computer_implicit(self): inputs['code'] = self.remote_code process = ArithmeticAddCalculation(inputs=inputs) - self.assertTrue(process.node.is_stored) - self.assertEqual(process.node.computer.uuid, self.remote_code.computer.uuid) + assert process.node.is_stored + assert process.node.computer.uuid == self.remote_code.computer.uuid def test_remote_code_unstored_computer(self): """Test launching a `CalcJob` with an unstored computer which should raise.""" @@ -278,7 +273,7 @@ def test_remote_code_unstored_computer(self): inputs['code'] = self.remote_code inputs['metadata']['computer'] = orm.Computer('different', 
'localhost', 'desc', 'core.local', 'core.direct') - with self.assertRaises(ValueError): + with pytest.raises(ValueError): ArithmeticAddCalculation(inputs=inputs) def test_remote_code_set_computer_explicit(self): @@ -291,7 +286,7 @@ def test_remote_code_set_computer_explicit(self): inputs['code'] = self.remote_code # Setting explicitly a computer that is not the same as that of the `code` should raise - with self.assertRaises(ValueError): + with pytest.raises(ValueError): inputs['metadata']['computer'] = orm.Computer( 'different', 'localhost', 'desc', 'core.local', 'core.direct' ).store() @@ -300,8 +295,8 @@ def test_remote_code_set_computer_explicit(self): # Setting the same computer as that of the `code` effectively accomplishes nothing but should be fine inputs['metadata']['computer'] = self.remote_code.computer process = ArithmeticAddCalculation(inputs=inputs) - self.assertTrue(process.node.is_stored) - self.assertEqual(process.node.computer.uuid, self.remote_code.computer.uuid) + assert process.node.is_stored + assert process.node.computer.uuid == self.remote_code.computer.uuid def test_local_code_set_computer(self): """Test launching a `CalcJob` with a local code *with* explicitly defining a computer, which should work.""" @@ -310,15 +305,15 @@ def test_local_code_set_computer(self): inputs['metadata']['computer'] = self.computer process = ArithmeticAddCalculation(inputs=inputs) - self.assertTrue(process.node.is_stored) - self.assertEqual(process.node.computer.uuid, self.computer.uuid) # pylint: disable=no-member + assert process.node.is_stored + assert process.node.computer.uuid == self.computer.uuid # pylint: disable=no-member def test_local_code_no_computer(self): """Test launching a `CalcJob` with a local code *without* explicitly defining a computer, which should raise.""" inputs = deepcopy(self.inputs) inputs['code'] = self.local_code - with self.assertRaises(ValueError): + with pytest.raises(ValueError): ArithmeticAddCalculation(inputs=inputs) def 
test_invalid_parser_name(self): @@ -327,7 +322,7 @@ def test_invalid_parser_name(self): inputs['code'] = self.remote_code inputs['metadata']['options']['parser_name'] = 'invalid_parser' - with self.assertRaises(ValueError): + with pytest.raises(ValueError): ArithmeticAddCalculation(inputs=inputs) def test_invalid_resources(self): @@ -336,7 +331,7 @@ def test_invalid_resources(self): inputs['code'] = self.remote_code inputs['metadata']['options']['resources'] = {'num_machines': 'invalid_type'} - with self.assertRaises(ValueError): + with pytest.raises(ValueError): ArithmeticAddCalculation(inputs=inputs) def test_par_env_resources_computer(self): @@ -369,11 +364,9 @@ def test_exception_presubmit(self): """ from aiida.engine.processes.calcjobs.tasks import PreSubmitException - with self.assertRaises(PreSubmitException) as context: + with pytest.raises(PreSubmitException, match='exception occurred in presubmit call'): launch.run(ArithmeticAddCalculation, code=self.remote_code, **self.inputs) - self.assertIn('exception occurred in presubmit call', str(context.exception)) - @pytest.mark.usefixtures('chdir_tmp_path') def test_run_local_code(self): """Run a dry-run with local code.""" @@ -388,7 +381,7 @@ def test_run_local_code(self): # Since the repository will only contain files on the top-level due to `Code.set_files` we only check those for filename in self.local_code.list_object_names(): - self.assertTrue(filename in uploaded_files) + assert filename in uploaded_files @pytest.mark.usefixtures('chdir_tmp_path') def test_rerunnable(self): @@ -446,13 +439,13 @@ def test_provenance_exclude_list(self): # written to the node's repository so we can check it contains the expected contents. _, node = launch.run_get_node(FileCalcJob, **inputs) - self.assertIn('folder', node.dry_run_info) + assert 'folder' in node.dry_run_info # Verify that the folder (representing the node's repository) indeed do not contain the input files. 
Note, # however, that the directory hierarchy should be there, albeit empty - self.assertIn('base', node.list_object_names()) - self.assertEqual(sorted(['b']), sorted(node.list_object_names(os.path.join('base')))) - self.assertEqual(['two'], node.list_object_names(os.path.join('base', 'b'))) + assert 'base' in node.list_object_names() + assert sorted(['b']) == sorted(node.list_object_names(os.path.join('base'))) + assert ['two'] == node.list_object_names(os.path.join('base', 'b')) def test_parse_no_retrieved_folder(self): """Test the `CalcJob.parse` method when there is no retrieved folder.""" @@ -784,14 +777,15 @@ def test_validate_stash_options(stash_options, expected): assert expected in validate_stash_options(stash_options, None) -class TestImport(AiidaTestCase): +class TestImport: """Test the functionality to import existing calculations completed outside of AiiDA.""" - @classmethod - def setUpClass(cls, *args, **kwargs): - super().setUpClass(*args, **kwargs) - cls.computer.configure() # pylint: disable=no-member - cls.inputs = { + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean, aiida_localhost): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init + self.computer = aiida_localhost + self.inputs = { 'x': orm.Int(1), 'y': orm.Int(2), 'metadata': { diff --git a/tests/engine/processes/workchains/test_utils.py b/tests/engine/processes/workchains/test_utils.py index d0171c5ad5..63c2d12e7a 100644 --- a/tests/engine/processes/workchains/test_utils.py +++ b/tests/engine/processes/workchains/test_utils.py @@ -16,18 +16,18 @@ from aiida.engine.processes.workchains.utils import ProcessHandlerReport, process_handler from aiida.orm import ProcessNode from aiida.plugins import CalculationFactory -from aiida.storage.testbase import AiidaTestCase ArithmeticAddCalculation = CalculationFactory('core.arithmetic.add') @pytest.mark.requires_rmq -class 
TestRegisterProcessHandler(AiidaTestCase): +@pytest.mark.usefixtures('aiida_profile_clean') +class TestRegisterProcessHandler: """Tests for the `process_handler` decorator.""" def test_priority_keyword_only(self): """The `priority` should be keyword only.""" - with self.assertRaises(TypeError): + with pytest.raises(TypeError): class SomeWorkChain(BaseRestartWorkChain): @@ -43,7 +43,7 @@ def _(self, node): def test_priority_type(self): """The `priority` should be an integer.""" - with self.assertRaises(TypeError): + with pytest.raises(TypeError): class SomeWorkChain(BaseRestartWorkChain): @@ -109,7 +109,7 @@ def handler_04(self, node): def test_exit_codes_keyword_only(self): """The `exit_codes` should be keyword only.""" - with self.assertRaises(TypeError): + with pytest.raises(TypeError): class SomeWorkChain(BaseRestartWorkChain): @@ -132,7 +132,7 @@ def test_exit_codes_type(self): [400], ] - with self.assertRaises(TypeError): + with pytest.raises(TypeError): for incorrect_type in incorrect_types: class SomeWorkChain(BaseRestartWorkChain): @@ -188,7 +188,7 @@ def _(self, node): def test_enabled_keyword_only(self): """The `enabled` should be keyword only.""" - with self.assertRaises(TypeError): + with pytest.raises(TypeError): class SomeWorkChain(BaseRestartWorkChain): diff --git a/tests/engine/test_calcfunctions.py b/tests/engine/test_calcfunctions.py index d3d029297d..d8ec5c2bef 100644 --- a/tests/engine/test_calcfunctions.py +++ b/tests/engine/test_calcfunctions.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=no-self-use """Tests for the calcfunction decorator and CalcFunctionNode.""" import pytest @@ -15,7 +16,6 @@ from aiida.engine import Process, calcfunction from aiida.manage.caching import enable_caching from aiida.orm import CalcFunctionNode, Int -from 
aiida.storage.testbase import AiidaTestCase # Global required for one of the caching tests to keep track of the number of times the calculation function is executed EXECUTION_COUNTER = 0 @@ -39,66 +39,65 @@ def execution_counter_calcfunction(data): @pytest.mark.requires_rmq -class TestCalcFunction(AiidaTestCase): +class TestCalcFunction: """Tests for calcfunctions. .. note: tests common to all process functions should go in `tests.engine.test_process_function.py` """ - def setUp(self): - super().setUp() - self.assertIsNone(Process.current()) + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init + assert Process.current() is None self.default_int = Int(256) - self.test_calcfunction = add_calcfunction - - def tearDown(self): - super().tearDown() - self.assertIsNone(Process.current()) + yield + assert Process.current() is None def test_calcfunction_node_type(self): """Verify that a calcfunction gets a CalcFunctionNode as node instance.""" _, node = self.test_calcfunction.run_get_node(self.default_int) - self.assertIsInstance(node, CalcFunctionNode) + assert isinstance(node, CalcFunctionNode) def test_calcfunction_links(self): """Verify that a calcfunction can only get CREATE links and no RETURN links.""" _, node = self.test_calcfunction.run_get_node(self.default_int) - self.assertEqual(len(node.get_outgoing(link_type=LinkType.CREATE).all()), 1) - self.assertEqual(len(node.get_outgoing(link_type=LinkType.RETURN).all()), 0) + assert len(node.get_outgoing(link_type=LinkType.CREATE).all()) == 1 + assert len(node.get_outgoing(link_type=LinkType.RETURN).all()) == 0 def test_calcfunction_return_stored(self): """Verify that a calcfunction will raise when a stored node is returned.""" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): return_stored_calcfunction.run_get_node() def 
test_calcfunction_default_linkname(self): """Verify that a calcfunction that returns a single Data node gets a default link label.""" _, node = self.test_calcfunction.run_get_node(self.default_int) - self.assertEqual(node.outputs.result, self.default_int.value + 1) - self.assertEqual(getattr(node.outputs, Process.SINGLE_OUTPUT_LINKNAME), self.default_int.value + 1) - self.assertEqual(node.outputs[Process.SINGLE_OUTPUT_LINKNAME], self.default_int.value + 1) + assert node.outputs.result == self.default_int.value + 1 + assert getattr(node.outputs, Process.SINGLE_OUTPUT_LINKNAME) == self.default_int.value + 1 + assert node.outputs[Process.SINGLE_OUTPUT_LINKNAME] == self.default_int.value + 1 def test_calcfunction_caching(self): """Verify that a calcfunction can be cached.""" - self.assertEqual(EXECUTION_COUNTER, 0) + assert EXECUTION_COUNTER == 0 _, original = execution_counter_calcfunction.run_get_node(Int(5)) - self.assertEqual(EXECUTION_COUNTER, 1) + assert EXECUTION_COUNTER == 1 # Caching a CalcFunctionNode should be possible with enable_caching(identifier='*.execution_counter_calcfunction'): input_node = Int(5) result, cached = execution_counter_calcfunction.run_get_node(input_node) - self.assertEqual(EXECUTION_COUNTER, 1) # Calculation function body should not have been executed - self.assertTrue(result.is_stored) - self.assertTrue(cached.is_created_from_cache) - self.assertIn(cached.get_cache_source(), original.uuid) - self.assertEqual(cached.get_incoming().one().node.uuid, input_node.uuid) + assert EXECUTION_COUNTER == 1 # Calculation function body should not have been executed + assert result.is_stored + assert cached.is_created_from_cache + assert cached.get_cache_source() in original.uuid + assert cached.get_incoming().one().node.uuid == input_node.uuid def test_calcfunction_caching_change_code(self): """Verify that changing the source codde of a calcfunction invalidates any existing cached nodes.""" @@ -113,21 +112,21 @@ def add_calcfunction(data): # 
pylint: disable=redefined-outer-name with enable_caching(identifier='*.add_calcfunction'): result_cached, cached = add_calcfunction.run_get_node(self.default_int) - self.assertNotEqual(result_original, result_cached) - self.assertFalse(cached.is_created_from_cache) + assert result_original != result_cached + assert not cached.is_created_from_cache # Test that the locally-created calcfunction can be cached in principle result2_cached, cached2 = add_calcfunction.run_get_node(self.default_int) - self.assertNotEqual(result_original, result2_cached) - self.assertTrue(cached2.is_created_from_cache) + assert result_original != result2_cached + assert cached2.is_created_from_cache def test_calcfunction_do_not_store_provenance(self): """Run the function without storing the provenance.""" data = Int(1) result, node = self.test_calcfunction.run_get_node(data, metadata={'store_provenance': False}) # pylint: disable=unexpected-keyword-arg - self.assertFalse(result.is_stored) - self.assertFalse(data.is_stored) - self.assertFalse(node.is_stored) - self.assertEqual(result, data + 1) + assert not result.is_stored + assert not data.is_stored + assert not node.is_stored + assert result == data + 1 def test_calculation_cannot_call(self): """Verify that calling another process from within a calcfunction raises as it is forbidden.""" @@ -136,11 +135,11 @@ def test_calculation_cannot_call(self): def test_calcfunction_caller(data): self.test_calcfunction(data) - with self.assertRaises(exceptions.InvalidOperation): + with pytest.raises(exceptions.InvalidOperation): test_calcfunction_caller(self.default_int) def test_calculation_call_store_provenance_false(self): # pylint: disable=invalid-name - """Verify that a `calcfunction` can call another calcfunction as long as `store_provenance` is False.""" + """Verify that a `calcfunction` can call another calcfunction as long as `store_provenance` is False.""" @calcfunction def test_calcfunction_caller(data): @@ -148,9 +147,9 @@ def 
test_calcfunction_caller(data): result, node = test_calcfunction_caller.run_get_node(self.default_int) - self.assertTrue(isinstance(result, Int)) - self.assertTrue(isinstance(node, CalcFunctionNode)) + assert isinstance(result, Int) + assert isinstance(node, CalcFunctionNode) # The node of the outermost `calcfunction` should have a single `CREATE` link and no `CALL_CALC` links - self.assertEqual(len(node.get_outgoing(link_type=LinkType.CREATE).all()), 1) - self.assertEqual(len(node.get_outgoing(link_type=LinkType.CALL_CALC).all()), 0) + assert len(node.get_outgoing(link_type=LinkType.CREATE).all()) == 1 + assert len(node.get_outgoing(link_type=LinkType.CALL_CALC).all()) == 0 diff --git a/tests/engine/test_class_loader.py b/tests/engine/test_class_loader.py index e04e836bd4..e3ce02f7d9 100644 --- a/tests/engine/test_class_loader.py +++ b/tests/engine/test_class_loader.py @@ -7,23 +7,25 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=no-self-use """A module to test class loader factories.""" +import pytest + import aiida from aiida.engine import Process from aiida.plugins import CalculationFactory -from aiida.storage.testbase import AiidaTestCase -class TestCalcJob(AiidaTestCase): +class TestCalcJob: """Test CalcJob.""" - def setUp(self): - super().setUp() - self.assertIsNone(Process.current()) - - def tearDown(self): - super().tearDown() - self.assertIsNone(Process.current()) + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init + assert Process.current() is None + yield + assert Process.current() is None def test_class_loader(self): """Test that CalculationFactory works.""" @@ -33,5 +35,5 @@ def test_class_loader(self): class_name = 
loader.identify_object(process) loaded_class = loader.load_object(class_name) - self.assertEqual(process.__name__, loaded_class.__name__) - self.assertEqual(class_name, loader.identify_object(loaded_class)) + assert process.__name__ == loaded_class.__name__ + assert class_name == loader.identify_object(loaded_class) diff --git a/tests/engine/test_futures.py b/tests/engine/test_futures.py index 33fea6c2ba..8411b5664b 100644 --- a/tests/engine/test_futures.py +++ b/tests/engine/test_futures.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=no-self-use """Module to test process futures.""" import asyncio @@ -14,12 +15,12 @@ from aiida.engine import processes, run from aiida.manage import get_manager -from aiida.storage.testbase import AiidaTestCase from tests.utils import processes as test_processes @pytest.mark.requires_rmq -class TestWf(AiidaTestCase): +@pytest.mark.usefixtures('aiida_profile_clean') +class TestWf: """Test process futures.""" TIMEOUT = 5.0 # seconds @@ -37,7 +38,7 @@ def test_calculation_future_broadcasts(self): run(process) calc_node = runner.run_until_complete(asyncio.wait_for(future, self.TIMEOUT)) - self.assertEqual(process.node.pk, calc_node.pk) + assert process.node.pk == calc_node.pk def test_calculation_future_polling(self): """Test calculation future polling.""" @@ -51,4 +52,4 @@ def test_calculation_future_polling(self): runner.run(process) calc_node = runner.run_until_complete(asyncio.wait_for(future, self.TIMEOUT)) - self.assertEqual(process.node.pk, calc_node.pk) + assert process.node.pk == calc_node.pk diff --git a/tests/engine/test_launch.py b/tests/engine/test_launch.py index 345585e28b..c187ac31c8 100644 --- a/tests/engine/test_launch.py +++ b/tests/engine/test_launch.py @@ -7,13 +7,16 @@ # For further information on the license, see the 
LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=no-self-use """Module to test processess launch.""" +import os +import shutil + import pytest from aiida import orm from aiida.common import exceptions from aiida.engine import CalcJob, Process, WorkChain, calcfunction, launch -from aiida.storage.testbase import AiidaTestCase @calcfunction @@ -65,53 +68,52 @@ def add(self): @pytest.mark.requires_rmq -class TestLaunchers(AiidaTestCase): +class TestLaunchers: """Class to test process launchers.""" - def setUp(self): - super().setUp() - self.assertIsNone(Process.current()) + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init + assert Process.current() is None self.term_a = orm.Int(1) self.term_b = orm.Int(2) self.result = 3 - - def tearDown(self): - super().tearDown() - self.assertIsNone(Process.current()) + assert Process.current() is None def test_calcfunction_run(self): """Test calcfunction run.""" result = launch.run(add, term_a=self.term_a, term_b=self.term_b) - self.assertEqual(result, self.result) + assert result == self.result def test_calcfunction_run_get_node(self): """Test calcfunction run by run_get_node.""" result, node = launch.run_get_node(add, term_a=self.term_a, term_b=self.term_b) - self.assertEqual(result, self.result) - self.assertTrue(isinstance(node, orm.CalcFunctionNode)) + assert result == self.result + assert isinstance(node, orm.CalcFunctionNode) def test_calcfunction_run_get_pk(self): """Test calcfunction run by run_get_pk.""" result, pk = launch.run_get_pk(add, term_a=self.term_a, term_b=self.term_b) - self.assertEqual(result, self.result) - self.assertTrue(isinstance(pk, int)) + assert result == self.result + assert isinstance(pk, int) def test_workchain_run(self): """Test 
workchain run.""" result = launch.run(AddWorkChain, term_a=self.term_a, term_b=self.term_b) - self.assertEqual(result['result'], self.result) + assert result['result'] == self.result def test_workchain_run_get_node(self): """Test workchain run by run_get_node.""" result, node = launch.run_get_node(AddWorkChain, term_a=self.term_a, term_b=self.term_b) - self.assertEqual(result['result'], self.result) - self.assertTrue(isinstance(node, orm.WorkChainNode)) + assert result['result'] == self.result + assert isinstance(node, orm.WorkChainNode) def test_workchain_run_get_pk(self): """Test workchain run by run_get_pk.""" result, pk = launch.run_get_pk(AddWorkChain, term_a=self.term_a, term_b=self.term_b) - self.assertEqual(result['result'], self.result) - self.assertTrue(isinstance(pk, int)) + assert result['result'] == self.result + assert isinstance(pk, int) def test_workchain_builder_run(self): """Test workchain builder run.""" @@ -119,7 +121,7 @@ def test_workchain_builder_run(self): builder.term_a = self.term_a builder.term_b = self.term_b result = launch.run(builder) - self.assertEqual(result['result'], self.result) + assert result['result'] == self.result def test_workchain_builder_run_get_node(self): """Test workchain builder that run by run_get_node.""" @@ -127,8 +129,8 @@ def test_workchain_builder_run_get_node(self): builder.term_a = self.term_a builder.term_b = self.term_b result, node = launch.run_get_node(builder) - self.assertEqual(result['result'], self.result) - self.assertTrue(isinstance(node, orm.WorkChainNode)) + assert result['result'] == self.result + assert isinstance(node, orm.WorkChainNode) def test_workchain_builder_run_get_pk(self): """Test workchain builder that run by run_get_pk.""" @@ -136,32 +138,28 @@ def test_workchain_builder_run_get_pk(self): builder.term_a = self.term_a builder.term_b = self.term_b result, pk = launch.run_get_pk(builder) - self.assertEqual(result['result'], self.result) - self.assertTrue(isinstance(pk, int)) + assert 
result['result'] == self.result + assert isinstance(pk, int) def test_submit_store_provenance_false(self): """Verify that submitting with `store_provenance=False` raises.""" - with self.assertRaises(exceptions.InvalidOperation): + with pytest.raises(exceptions.InvalidOperation): launch.submit(AddWorkChain, term_a=self.term_a, term_b=self.term_b, metadata={'store_provenance': False}) @pytest.mark.requires_rmq -class TestLaunchersDryRun(AiidaTestCase): +class TestLaunchersDryRun: """Test the launchers when performing a dry-run.""" - def setUp(self): - super().setUp() - self.assertIsNone(Process.current()) - - def tearDown(self): - import os - import shutil - + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean, aiida_localhost): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init from aiida.common.folders import CALC_JOB_DRY_RUN_BASE_PATH - - super().tearDown() - self.assertIsNone(Process.current()) - + assert Process.current() is None + self.computer = aiida_localhost + yield + assert Process.current() is None # Make sure to clean the test directory that will be generated by the dry-run filepath = os.path.join(os.getcwd(), CALC_JOB_DRY_RUN_BASE_PATH) try: @@ -194,21 +192,21 @@ def test_launchers_dry_run(self): } result = launch.run(ArithmeticAddCalculation, **inputs) - self.assertEqual(result, {}) + assert result == {} result, pk = launch.run_get_pk(ArithmeticAddCalculation, **inputs) - self.assertEqual(result, {}) - self.assertIsInstance(pk, int) + assert result == {} + assert isinstance(pk, int) result, node = launch.run_get_node(ArithmeticAddCalculation, **inputs) - self.assertEqual(result, {}) - self.assertIsInstance(node, orm.CalcJobNode) - self.assertIsInstance(node.dry_run_info, dict) - self.assertIn('folder', node.dry_run_info) - self.assertIn('script_filename', node.dry_run_info) + assert result == {} + assert isinstance(node, orm.CalcJobNode) + assert 
isinstance(node.dry_run_info, dict) + assert 'folder' in node.dry_run_info + assert 'script_filename' in node.dry_run_info node = launch.submit(ArithmeticAddCalculation, **inputs) - self.assertIsInstance(node, orm.CalcJobNode) + assert isinstance(node, orm.CalcJobNode) def test_launchers_dry_run_no_provenance(self): """Test the launchers in `dry_run` mode with `store_provenance=False`.""" @@ -236,23 +234,23 @@ def test_launchers_dry_run_no_provenance(self): } result = launch.run(ArithmeticAddCalculation, **inputs) - self.assertEqual(result, {}) + assert result == {} result, pk = launch.run_get_pk(ArithmeticAddCalculation, **inputs) - self.assertEqual(result, {}) - self.assertIsNone(pk) + assert result == {} + assert pk is None result, node = launch.run_get_node(ArithmeticAddCalculation, **inputs) - self.assertEqual(result, {}) - self.assertIsInstance(node, orm.CalcJobNode) - self.assertFalse(node.is_stored) - self.assertIsInstance(node.dry_run_info, dict) - self.assertIn('folder', node.dry_run_info) - self.assertIn('script_filename', node.dry_run_info) + assert result == {} + assert isinstance(node, orm.CalcJobNode) + assert not node.is_stored + assert isinstance(node.dry_run_info, dict) + assert 'folder' in node.dry_run_info + assert 'script_filename' in node.dry_run_info node = launch.submit(ArithmeticAddCalculation, **inputs) - self.assertIsInstance(node, orm.CalcJobNode) - self.assertFalse(node.is_stored) + assert isinstance(node, orm.CalcJobNode) + assert not node.is_stored def test_calcjob_dry_run_no_provenance(self): """Test that dry run with `store_provenance=False` still works for unstored inputs. @@ -262,7 +260,6 @@ def test_calcjob_dry_run_no_provenance(self): which is not the case in the `store_provenance=False` mode with unstored nodes. Note that it also explicitly tests nested namespaces as that is a non-trivial case. 
""" - import os import tempfile code = orm.Code(input_plugin_name='core.arithmetic.add', remote_computer_exec=[self.computer, @@ -295,6 +292,6 @@ def test_calcjob_dry_run_no_provenance(self): } _, node = launch.run_get_node(FileCalcJob, **inputs) - self.assertIn('folder', node.dry_run_info) + assert 'folder' in node.dry_run_info for filename in ['path', 'file_one', 'file_two']: - self.assertIn(filename, os.listdir(node.dry_run_info['folder'])) + assert filename in os.listdir(node.dry_run_info['folder']) diff --git a/tests/engine/test_manager.py b/tests/engine/test_manager.py index 89958a2b14..2781fac2e5 100644 --- a/tests/engine/test_manager.py +++ b/tests/engine/test_manager.py @@ -7,71 +7,70 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=no-self-use """Tests for the classes in `aiida.engine.processes.calcjobs.manager`.""" - import asyncio import time +import pytest + from aiida.engine.processes.calcjobs.manager import JobManager, JobsList from aiida.engine.transports import TransportQueue -from aiida.orm import AuthInfo, User -from aiida.storage.testbase import AiidaTestCase +from aiida.orm import User -class TestJobManager(AiidaTestCase): +class TestJobManager: """Test the `aiida.engine.processes.calcjobs.manager.JobManager` class.""" - def setUp(self): - super().setUp() + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean, aiida_localhost): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init self.loop = asyncio.get_event_loop() self.transport_queue = TransportQueue(self.loop) self.user = User.objects.get_default() - self.auth_info = AuthInfo(self.computer, self.user).store() + self.computer = aiida_localhost + self.auth_info = self.computer.get_authinfo(self.user) self.manager = 
JobManager(self.transport_queue) - def tearDown(self): - super().tearDown() - AuthInfo.objects.delete(self.auth_info.pk) - def test_get_jobs_list(self): """Test the `JobManager.get_jobs_list` method.""" jobs_list = self.manager.get_jobs_list(self.auth_info) - self.assertIsInstance(jobs_list, JobsList) + assert isinstance(jobs_list, JobsList) # Calling the method again, should return the exact same instance of `JobsList` - self.assertEqual(self.manager.get_jobs_list(self.auth_info), jobs_list) + assert self.manager.get_jobs_list(self.auth_info) == jobs_list def test_request_job_info_update(self): """Test the `JobManager.request_job_info_update` method.""" with self.manager.request_job_info_update(self.auth_info, job_id=1) as request: - self.assertIsInstance(request, asyncio.Future) + assert isinstance(request, asyncio.Future) -class TestJobsList(AiidaTestCase): +class TestJobsList: """Test the `aiida.engine.processes.calcjobs.manager.JobsList` class.""" - def setUp(self): - super().setUp() + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean, aiida_localhost): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init self.loop = asyncio.get_event_loop() self.transport_queue = TransportQueue(self.loop) self.user = User.objects.get_default() - self.auth_info = AuthInfo(self.computer, self.user).store() + self.computer = aiida_localhost + self.auth_info = self.computer.get_authinfo(self.user) self.jobs_list = JobsList(self.auth_info, self.transport_queue) - def tearDown(self): - super().tearDown() - AuthInfo.objects.delete(self.auth_info.pk) - def test_get_minimum_update_interval(self): """Test the `JobsList.get_minimum_update_interval` method.""" minimum_poll_interval = self.auth_info.computer.get_minimum_job_poll_interval() - self.assertEqual(self.jobs_list.get_minimum_update_interval(), minimum_poll_interval) + assert self.jobs_list.get_minimum_update_interval() == 
minimum_poll_interval def test_last_updated(self): """Test the `JobsList.last_updated` method.""" jobs_list = JobsList(self.auth_info, self.transport_queue) - self.assertEqual(jobs_list.last_updated, None) + assert jobs_list.last_updated is None last_updated = time.time() jobs_list = JobsList(self.auth_info, self.transport_queue, last_updated=last_updated) - self.assertEqual(jobs_list.last_updated, last_updated) + assert jobs_list.last_updated == last_updated diff --git a/tests/engine/test_persistence.py b/tests/engine/test_persistence.py index e54703ccdc..5a0d320f51 100644 --- a/tests/engine/test_persistence.py +++ b/tests/engine/test_persistence.py @@ -7,27 +7,26 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=no-self-use """Test persisting via the AiiDAPersister.""" import plumpy import pytest from aiida.engine import Process, run from aiida.engine.persistence import AiiDAPersister -from aiida.storage.testbase import AiidaTestCase from tests.utils.processes import DummyProcess @pytest.mark.requires_rmq -class TestProcess(AiidaTestCase): +class TestProcess: """Test the basic saving and loading of process states.""" - def setUp(self): - super().setUp() - self.assertIsNone(Process.current()) - - def tearDown(self): - super().tearDown() - self.assertIsNone(Process.current()) + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean): # pylint: disable=unused-argument + """Initialize the profile.""" + assert Process.current() is None + yield + assert Process.current() is None def test_save_load(self): """Test load saved state.""" @@ -38,16 +37,18 @@ def test_save_load(self): loaded_process = saved_state.unbundle() run(loaded_process) - self.assertEqual(loaded_process.state, plumpy.ProcessState.FINISHED) + assert loaded_process.state == plumpy.ProcessState.FINISHED 
@pytest.mark.requires_rmq -class TestAiiDAPersister(AiidaTestCase): +class TestAiiDAPersister: """Test AiiDAPersister.""" maxDiff = 1024 - def setUp(self): - super().setUp() + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init self.persister = AiiDAPersister() def test_save_load_checkpoint(self): @@ -56,14 +57,14 @@ def test_save_load_checkpoint(self): bundle_saved = self.persister.save_checkpoint(process) bundle_loaded = self.persister.load_checkpoint(process.node.pk) - self.assertDictEqual(bundle_saved, bundle_loaded) + assert bundle_saved == bundle_loaded def test_delete_checkpoint(self): """Test checkpoint deletion.""" process = DummyProcess() self.persister.save_checkpoint(process) - self.assertTrue(isinstance(process.node.checkpoint, str)) + assert isinstance(process.node.checkpoint, str) self.persister.delete_checkpoint(process.pid) - self.assertEqual(process.node.checkpoint, None) + assert process.node.checkpoint is None diff --git a/tests/engine/test_ports.py b/tests/engine/test_ports.py index e5f26cf5b2..9243fc50ce 100644 --- a/tests/engine/test_ports.py +++ b/tests/engine/test_ports.py @@ -7,14 +7,16 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=no-self-use """Tests for process spec ports.""" +import pytest + from aiida.engine.processes.ports import InputPort, PortNamespace from aiida.orm import Dict, Int -from aiida.storage.testbase import AiidaTestCase -class TestInputPort(AiidaTestCase): +class TestInputPort: """Tests for the `InputPort` class.""" def test_with_non_db(self): @@ -22,26 +24,26 @@ def test_with_non_db(self): # When not specifying, it should get the default value and `non_db_explicitly_set` should be `False` port = 
InputPort('port') - self.assertEqual(port.non_db, False) - self.assertEqual(port.non_db_explicitly_set, False) + assert port.non_db is False + assert port.non_db_explicitly_set is False # Using the setter to change the value should toggle both properties port.non_db = True - self.assertEqual(port.non_db, True) - self.assertEqual(port.non_db_explicitly_set, True) + assert port.non_db is True + assert port.non_db_explicitly_set is True # Explicitly setting to `False` upon construction port = InputPort('port', non_db=False) - self.assertEqual(port.non_db, False) - self.assertEqual(port.non_db_explicitly_set, True) + assert port.non_db is False + assert port.non_db_explicitly_set is True # Explicitly setting to `True` upon construction port = InputPort('port', non_db=True) - self.assertEqual(port.non_db, True) - self.assertEqual(port.non_db_explicitly_set, True) + assert port.non_db is True + assert port.non_db_explicitly_set is True -class TestPortNamespace(AiidaTestCase): +class TestPortNamespace: """Tests for the `PortNamespace` class.""" def test_with_non_db(self): @@ -52,16 +54,16 @@ def test_with_non_db(self): # When explicitly set upon port construction, value should not be inherited even when different port = InputPort('storable', non_db=False) port_namespace['storable'] = port - self.assertEqual(port.non_db, False) + assert port.non_db is False port = InputPort('not_storable', non_db=True) port_namespace['not_storable'] = port - self.assertEqual(port.non_db, True) + assert port.non_db is True # If not explicitly defined, it should inherit from parent namespace port = InputPort('not_storable') port_namespace['not_storable'] = port - self.assertEqual(port.non_db, namespace_non_db) + assert port.non_db == namespace_non_db def test_validate_port_name(self): """This test will ensure that illegal port names will raise a `ValueError` when trying to add it.""" @@ -81,7 +83,7 @@ def test_validate_port_name(self): ] for port_name in illegal_port_names: - with 
self.assertRaises(ValueError): + with pytest.raises(ValueError): port_namespace[port_name] = port def test_serialize_type_check(self): @@ -91,7 +93,7 @@ def test_serialize_type_check(self): port_namespace = PortNamespace(base_namespace) port_namespace.create_port_namespace(nested_namespace) - with self.assertRaisesRegex(TypeError, f'.*{base_namespace}.*{nested_namespace}.*'): + with pytest.raises(TypeError, match=f'.*{base_namespace}.*{nested_namespace}.*'): port_namespace.serialize({'some': {'nested': {'namespace': Dict()}}}) def test_lambda_default(self): @@ -103,21 +105,21 @@ def test_lambda_default(self): # However, pre processing the namespace, which shall evaluate the default followed by validation will fail inputs = port_namespace.pre_process({}) - self.assertIsNotNone(port_namespace.validate(inputs)) + assert port_namespace.validate(inputs) is not None # Passing an explicit value for the port will forego the default and validation on returned inputs should pass inputs = port_namespace.pre_process({'port': Int(5)}) - self.assertIsNone(port_namespace.validate(inputs)) + assert port_namespace.validate(inputs) is None # Redefining the port, this time with a correct default port_namespace['port'] = InputPort('port', valid_type=Int, default=lambda: Int(5)) # Pre processing the namespace shall evaluate the default and return the int node inputs = port_namespace.pre_process({}) - self.assertIsInstance(inputs['port'], Int) - self.assertEqual(inputs['port'].value, 5) + assert isinstance(inputs['port'], Int) + assert inputs['port'].value == 5 # Passing an explicit value for the port will forego the default inputs = port_namespace.pre_process({'port': Int(3)}) - self.assertIsInstance(inputs['port'], Int) - self.assertEqual(inputs['port'].value, 3) + assert isinstance(inputs['port'], Int) + assert inputs['port'].value == 3 diff --git a/tests/engine/test_process.py b/tests/engine/test_process.py index bdf9144986..a095e27d32 100644 --- a/tests/engine/test_process.py +++ 
b/tests/engine/test_process.py @@ -7,7 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=no-member,too-many-public-methods +# pylint: disable=no-member,too-many-public-methods,no-self-use """Module to test AiiDA processes.""" import threading @@ -21,7 +21,6 @@ from aiida.engine.processes.ports import PortNamespace from aiida.manage.caching import enable_caching from aiida.plugins import CalculationFactory -from aiida.storage.testbase import AiidaTestCase from tests.utils import processes as test_processes @@ -37,16 +36,16 @@ def define(cls, spec): @pytest.mark.requires_rmq -class TestProcessNamespace(AiidaTestCase): +class TestProcessNamespace: """Test process namespace""" - def setUp(self): - super().setUp() - self.assertIsNone(Process.current()) - - def tearDown(self): - super().tearDown() - self.assertIsNone(Process.current()) + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init + assert Process.current() is None + yield + assert Process.current() is None def test_namespaced_process(self): """Test that inputs in nested namespaces are properly validated and the link labels @@ -54,21 +53,21 @@ def test_namespaced_process(self): proc = NameSpacedProcess(inputs={'some': {'name': {'space': {'a': orm.Int(5)}}}}) # Test that the namespaced inputs are AttributesFrozenDicts - self.assertIsInstance(proc.inputs, AttributesFrozendict) - self.assertIsInstance(proc.inputs.some, AttributesFrozendict) - self.assertIsInstance(proc.inputs.some.name, AttributesFrozendict) - self.assertIsInstance(proc.inputs.some.name.space, AttributesFrozendict) + assert isinstance(proc.inputs, AttributesFrozendict) + assert isinstance(proc.inputs.some, AttributesFrozendict) + 
assert isinstance(proc.inputs.some.name, AttributesFrozendict) + assert isinstance(proc.inputs.some.name.space, AttributesFrozendict) # Test that the input node is in the inputs of the process input_node = proc.inputs.some.name.space.a - self.assertTrue(isinstance(input_node, orm.Int)) - self.assertEqual(input_node.value, 5) + assert isinstance(input_node, orm.Int) + assert input_node.value == 5 # Check that the link of the process node has the correct link name - self.assertTrue('some__name__space__a' in proc.node.get_incoming().all_link_labels()) - self.assertEqual(proc.node.get_incoming().get_node_by_label('some__name__space__a'), 5) - self.assertEqual(proc.node.inputs.some.name.space.a, 5) - self.assertEqual(proc.node.inputs['some']['name']['space']['a'], 5) + assert 'some__name__space__a' in proc.node.get_incoming().all_link_labels() + assert proc.node.get_incoming().get_node_by_label('some__name__space__a') == 5 + assert proc.node.inputs.some.name.space.a == 5 + assert proc.node.inputs['some']['name']['space']['a'] == 5 class ProcessStackTest(Process): @@ -95,29 +94,30 @@ def on_stop(self): @pytest.mark.requires_rmq -class TestProcess(AiidaTestCase): +class TestProcess: """Test AiiDA process.""" - def setUp(self): - super().setUp() - self.assertIsNone(Process.current()) - - def tearDown(self): - super().tearDown() - self.assertIsNone(Process.current()) + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean, aiida_localhost): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init + assert Process.current() is None + self.computer = aiida_localhost + yield + assert Process.current() is None @staticmethod def test_process_stack(): run(ProcessStackTest) def test_inputs(self): - with self.assertRaises(ValueError): + with pytest.raises(ValueError): run(test_processes.BadOutput) def test_spec_metadata_property(self): """`Process.spec_metadata` should return the metadata port 
namespace of its spec.""" - self.assertIsInstance(Process.spec_metadata, PortNamespace) - self.assertEqual(Process.spec_metadata, Process.spec().inputs['metadata']) + assert isinstance(Process.spec_metadata, PortNamespace) + assert Process.spec_metadata == Process.spec().inputs['metadata'] def test_input_link_creation(self): """Test input link creation.""" @@ -128,12 +128,12 @@ def test_input_link_creation(self): process = test_processes.DummyProcess(inputs) for entry in process.node.get_incoming().all(): - self.assertTrue(entry.link_label in inputs) - self.assertEqual(entry.link_label, entry.node.value) + assert entry.link_label in inputs + assert entry.link_label == entry.node.value dummy_inputs.remove(entry.link_label) # Make sure there are no other inputs - self.assertFalse(dummy_inputs) + assert not dummy_inputs @staticmethod def test_none_input(): @@ -145,34 +145,34 @@ def test_input_after_stored(self): from aiida.common import LinkType process = test_processes.DummyProcess() - with self.assertRaises(ValueError): + with pytest.raises(ValueError): process.node.add_incoming(orm.Int(1), link_type=LinkType.INPUT_WORK, link_label='illegal_link') def test_seal(self): _, p_k = run_get_pk(test_processes.DummyProcess) - self.assertTrue(orm.load_node(pk=p_k).is_sealed) + assert orm.load_node(pk=p_k).is_sealed def test_description(self): """Testing setting a process description.""" dummy_process = test_processes.DummyProcess(inputs={'metadata': {'description': "Rockin' process"}}) - self.assertEqual(dummy_process.node.description, "Rockin' process") + assert dummy_process.node.description == "Rockin' process" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): test_processes.DummyProcess(inputs={'metadata': {'description': 5}}) def test_label(self): """Test setting a label.""" dummy_process = test_processes.DummyProcess(inputs={'metadata': {'label': 'My label'}}) - self.assertEqual(dummy_process.node.label, 'My label') + assert 
dummy_process.node.label == 'My label' - with self.assertRaises(ValueError): + with pytest.raises(ValueError): test_processes.DummyProcess(inputs={'label': 5}) def test_work_calc_finish(self): process = test_processes.DummyProcess() - self.assertFalse(process.node.is_finished_ok) + assert not process.node.is_finished_ok run(process) - self.assertTrue(process.node.is_finished_ok) + assert process.node.is_finished_ok @staticmethod def test_save_instance_state(): @@ -188,16 +188,16 @@ def test_exit_codes(self): ArithmeticAddCalculation = CalculationFactory('core.arithmetic.add') # pylint: disable=invalid-name exit_codes = ArithmeticAddCalculation.exit_codes - self.assertIsInstance(exit_codes, ExitCodesNamespace) + assert isinstance(exit_codes, ExitCodesNamespace) for _, value in exit_codes.items(): - self.assertIsInstance(value, ExitCode) + assert isinstance(value, ExitCode) exit_statuses = ArithmeticAddCalculation.get_exit_statuses(['ERROR_NO_RETRIEVED_FOLDER']) - self.assertIsInstance(exit_statuses, list) + assert isinstance(exit_statuses, list) for entry in exit_statuses: - self.assertIsInstance(entry, int) + assert isinstance(entry, int) - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): ArithmeticAddCalculation.get_exit_statuses(['NON_EXISTING_EXIT_CODE_LABEL']) def test_exit_codes_invalidate_cache(self): @@ -209,14 +209,14 @@ def test_exit_codes_invalidate_cache(self): with enable_caching(): _, node1 = run_get_node(test_processes.InvalidateCaching, return_exit_code=orm.Bool(False)) _, node2 = run_get_node(test_processes.InvalidateCaching, return_exit_code=orm.Bool(False)) - self.assertEqual(node1.get_extra('_aiida_hash'), node2.get_extra('_aiida_hash')) - self.assertIn('_aiida_cached_from', node2.extras) + assert node1.get_extra('_aiida_hash') == node2.get_extra('_aiida_hash') + assert '_aiida_cached_from' in node2.extras with enable_caching(): _, node3 = run_get_node(test_processes.InvalidateCaching, return_exit_code=orm.Bool(True)) 
_, node4 = run_get_node(test_processes.InvalidateCaching, return_exit_code=orm.Bool(True)) - self.assertEqual(node3.get_extra('_aiida_hash'), node4.get_extra('_aiida_hash')) - self.assertNotIn('_aiida_cached_from', node4.extras) + assert node3.get_extra('_aiida_hash') == node4.get_extra('_aiida_hash') + assert '_aiida_cached_from' not in node4.extras def test_valid_cache_hook(self): """ @@ -227,14 +227,14 @@ def test_valid_cache_hook(self): with enable_caching(): _, node1 = run_get_node(test_processes.IsValidCacheHook) _, node2 = run_get_node(test_processes.IsValidCacheHook) - self.assertEqual(node1.get_extra('_aiida_hash'), node2.get_extra('_aiida_hash')) - self.assertIn('_aiida_cached_from', node2.extras) + assert node1.get_extra('_aiida_hash') == node2.get_extra('_aiida_hash') + assert '_aiida_cached_from' in node2.extras with enable_caching(): _, node3 = run_get_node(test_processes.IsValidCacheHook, not_valid_cache=orm.Bool(True)) _, node4 = run_get_node(test_processes.IsValidCacheHook, not_valid_cache=orm.Bool(True)) - self.assertEqual(node3.get_extra('_aiida_hash'), node4.get_extra('_aiida_hash')) - self.assertNotIn('_aiida_cached_from', node4.extras) + assert node3.get_extra('_aiida_hash') == node4.get_extra('_aiida_hash') + assert '_aiida_cached_from' not in node4.extras def test_process_type_with_entry_point(self): """For a process with a registered entry point, the process_type will be its formatted entry point string.""" @@ -268,11 +268,11 @@ def test_process_type_with_entry_point(self): process = process_class(inputs=inputs) expected_process_type = f'aiida.calculations:{entry_point}' - self.assertEqual(process.node.process_type, expected_process_type) + assert process.node.process_type == expected_process_type # Verify that process_class on the calculation node returns the original entry point class recovered_process = process.node.process_class - self.assertEqual(recovered_process, process_class) + assert recovered_process == process_class def 
test_process_type_without_entry_point(self): """ @@ -281,11 +281,11 @@ def test_process_type_without_entry_point(self): """ process = test_processes.DummyProcess() expected_process_type = f'{process.__class__.__module__}.{process.__class__.__name__}' - self.assertEqual(process.node.process_type, expected_process_type) + assert process.node.process_type == expected_process_type # Verify that process_class on the calculation node returns the original entry point class recovered_process = process.node.process_class - self.assertEqual(recovered_process, process.__class__) + assert recovered_process == process.__class__ def test_output_dictionary(self): """Verify that a dictionary can be passed as an output for a namespace.""" @@ -306,9 +306,9 @@ def run(self): results, node = run_get_node(TestProcess1, namespace={'alpha': orm.Int(1), 'beta': orm.Int(2)}) - self.assertTrue(node.is_finished_ok) - self.assertEqual(results['namespace']['alpha'], orm.Int(1)) - self.assertEqual(results['namespace']['beta'], orm.Int(2)) + assert node.is_finished_ok + assert results['namespace']['alpha'] == orm.Int(1) + assert results['namespace']['beta'] == orm.Int(2) def test_output_validation_error(self): """Test that a process is marked as failed if its output namespace validation fails.""" @@ -334,16 +334,16 @@ def run(self): # For default inputs, no outputs will be attached, causing the validation to fail at the end so an internal # exit status will be set, which is a negative integer - self.assertTrue(node.is_finished) - self.assertFalse(node.is_finished_ok) - self.assertEqual(node.exit_status, TestProcess1.exit_codes.ERROR_MISSING_OUTPUT.status) - self.assertEqual(node.exit_message, TestProcess1.exit_codes.ERROR_MISSING_OUTPUT.message) + assert node.is_finished + assert not node.is_finished_ok + assert node.exit_status == TestProcess1.exit_codes.ERROR_MISSING_OUTPUT.status + assert node.exit_message == TestProcess1.exit_codes.ERROR_MISSING_OUTPUT.message # When settings `add_outputs` 
to True, the outputs should be added and validation should pass _, node = run_get_node(TestProcess1, add_outputs=orm.Bool(True)) - self.assertTrue(node.is_finished) - self.assertTrue(node.is_finished_ok) - self.assertEqual(node.exit_status, 0) + assert node.is_finished + assert node.is_finished_ok + assert node.exit_status == 0 def test_exposed_outputs(self): """Test the ``Process.exposed_outputs`` method.""" @@ -392,7 +392,7 @@ def define(cls, spec): }, 'output': node_output, }) - self.assertEqual(exposed_outputs, expected) + assert exposed_outputs == expected def test_exposed_outputs_non_existing_namespace(self): """Test the ``Process.exposed_outputs`` method for non-existing namespace.""" @@ -434,5 +434,5 @@ def define(cls, spec): process = instantiate_process(runner, ParentProcess, input=orm.Int(1)) # If the ``namespace`` does not exist, for example because it is slightly misspelled, a ``KeyError`` is raised - with self.assertRaises(KeyError): + with pytest.raises(KeyError): process.exposed_outputs(node_child, ChildProcess, namespace='cildh') diff --git a/tests/engine/test_process_function.py b/tests/engine/test_process_function.py index c118b965f7..ca25bc1216 100644 --- a/tests/engine/test_process_function.py +++ b/tests/engine/test_process_function.py @@ -7,13 +7,13 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=no-self-use """Tests for the process_function decorator.""" import pytest from aiida import orm from aiida.engine import ExitCode, Process, calcfunction, run, run_get_node, submit, workfunction from aiida.orm.nodes.data.bool import get_true_node -from aiida.storage.testbase import AiidaTestCase from aiida.workflows.arithmetic.add_multiply import add_multiply DEFAULT_INT = 256 @@ -24,7 +24,7 @@ @pytest.mark.requires_rmq -class TestProcessFunction(AiidaTestCase): +class 
TestProcessFunction: """ Note that here we use `@workfunctions` and `@calculations`, the concrete versions of the `@process_function` decorator, even though we are testing only the shared functionality @@ -38,9 +38,11 @@ class TestProcessFunction(AiidaTestCase): # pylint: disable=too-many-public-methods,too-many-instance-attributes - def setUp(self): - super().setUp() - self.assertIsNone(Process.current()) + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init + assert Process.current() is None @workfunction def function_return_input(data): @@ -117,16 +119,15 @@ def function_out_unstored(): self.function_excepts = function_excepts self.function_out_unstored = function_out_unstored - def tearDown(self): - super().tearDown() - self.assertIsNone(Process.current()) + yield + assert Process.current() is None def test_properties(self): """Test that the `is_process_function` and `node_class` attributes are set.""" - self.assertEqual(self.function_return_input.is_process_function, True) - self.assertEqual(self.function_return_input.node_class, orm.WorkFunctionNode) - self.assertEqual(self.function_return_true.is_process_function, True) - self.assertEqual(self.function_return_true.node_class, orm.CalcFunctionNode) + assert self.function_return_input.is_process_function is True + assert self.function_return_input.node_class == orm.WorkFunctionNode + assert self.function_return_true.is_process_function is True + assert self.function_return_true.node_class == orm.CalcFunctionNode def test_plugin_version(self): """Test the version attributes of a process function.""" @@ -137,32 +138,32 @@ def test_plugin_version(self): # Since the "plugin" i.e. 
the process function is defined in `aiida-core` the `version.plugin` is the same as # the version of `aiida-core` itself version_info = node.get_attribute('version') - self.assertEqual(version_info['core'], version_core) - self.assertEqual(version_info['plugin'], version_core) + assert version_info['core'] == version_core + assert version_info['plugin'] == version_core def test_process_state(self): """Test the process state for a process function.""" _, node = self.function_args_with_default.run_get_node() - self.assertEqual(node.is_terminated, True) - self.assertEqual(node.is_excepted, False) - self.assertEqual(node.is_killed, False) - self.assertEqual(node.is_finished, True) - self.assertEqual(node.is_finished_ok, True) - self.assertEqual(node.is_failed, False) + assert node.is_terminated is True + assert node.is_excepted is False + assert node.is_killed is False + assert node.is_finished is True + assert node.is_finished_ok is True + assert node.is_failed is False def test_process_type(self): """Test that the process type correctly contains the module and name of original decorated function.""" _, node = self.function_defaults.run_get_node() process_type = f'{self.function_defaults.__module__}.{self.function_defaults.__name__}' - self.assertEqual(node.process_type, process_type) + assert node.process_type == process_type def test_exit_status(self): """A FINISHED process function has to have an exit status of 0""" _, node = self.function_args_with_default.run_get_node() - self.assertEqual(node.exit_status, 0) - self.assertEqual(node.is_finished_ok, True) - self.assertEqual(node.is_failed, False) + assert node.exit_status == 0 + assert node.is_finished_ok is True + assert node.is_failed is False def test_source_code_attributes(self): """Verify function properties are properly introspected and stored in the nodes attributes and repository.""" @@ -178,18 +179,18 @@ def test_process_function(data): function_source_code = node.get_function_source_code().split('\n') # 
Verify that the function name is correct and the first source code linenumber is stored - self.assertEqual(node.function_name, function_name) - self.assertIsInstance(node.function_starting_line_number, int) + assert node.function_name == function_name + assert isinstance(node.function_starting_line_number, int) # Check that first line number is correct. Note that the first line should correspond # to the `@workfunction` directive, but since the list is zero-indexed we actually get the # following line, which should correspond to the function name i.e. `def test_process_function(data)` function_name_from_source = function_source_code[node.function_starting_line_number] - self.assertTrue(node.function_name in function_name_from_source) + assert node.function_name in function_name_from_source def test_function_varargs(self): """Variadic arguments are not supported and should raise.""" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): @workfunction def function_varargs(*args): # pylint: disable=unused-variable @@ -199,24 +200,24 @@ def test_function_args(self): """Simple process function that defines a single positional argument.""" arg = 1 - with self.assertRaises(ValueError): + with pytest.raises(ValueError): result = self.function_args() # pylint: disable=no-value-for-parameter result = self.function_args(data_a=orm.Int(arg)) - self.assertTrue(isinstance(result, orm.Int)) - self.assertEqual(result, arg) + assert isinstance(result, orm.Int) + assert result == arg def test_function_args_with_default(self): """Simple process function that defines a single argument with a default.""" arg = 1 result = self.function_args_with_default() - self.assertTrue(isinstance(result, orm.Int)) - self.assertEqual(result, orm.Int(DEFAULT_INT)) + assert isinstance(result, orm.Int) + assert result == orm.Int(DEFAULT_INT) result = self.function_args_with_default(data_a=orm.Int(arg)) - self.assertTrue(isinstance(result, orm.Int)) - self.assertEqual(result, arg) + 
assert isinstance(result, orm.Int) + assert result == arg def test_function_with_none_default(self): """Simple process function that defines a keyword with `None` as default value.""" @@ -225,32 +226,32 @@ def test_function_with_none_default(self): int_c = orm.Int(3) result = self.function_with_none_default(int_a, int_b) - self.assertTrue(isinstance(result, orm.Int)) - self.assertEqual(result, orm.Int(3)) + assert isinstance(result, orm.Int) + assert result == orm.Int(3) result = self.function_with_none_default(int_a, int_b, int_c) - self.assertTrue(isinstance(result, orm.Int)) - self.assertEqual(result, orm.Int(6)) + assert isinstance(result, orm.Int) + assert result == orm.Int(6) def test_function_kwargs(self): """Simple process function that defines keyword arguments.""" kwargs = {'data_a': orm.Int(DEFAULT_INT)} result, node = self.function_kwargs.run_get_node() - self.assertTrue(isinstance(result, dict)) - self.assertEqual(len(node.get_incoming().all()), 0) - self.assertEqual(result, {}) + assert isinstance(result, dict) + assert len(node.get_incoming().all()) == 0 + assert result == {} result, node = self.function_kwargs.run_get_node(**kwargs) - self.assertTrue(isinstance(result, dict)) - self.assertEqual(len(node.get_incoming().all()), 1) - self.assertEqual(result, kwargs) + assert isinstance(result, dict) + assert len(node.get_incoming().all()) == 1 + assert result == kwargs # Calling with any number of positional arguments should raise - with self.assertRaises(TypeError): + with pytest.raises(TypeError): self.function_kwargs.run_get_node(orm.Int(1)) - with self.assertRaises(TypeError): + with pytest.raises(TypeError): self.function_kwargs.run_get_node(orm.Int(1), b=orm.Int(2)) def test_function_args_and_kwargs(self): @@ -260,18 +261,18 @@ def test_function_args_and_kwargs(self): kwargs = {'data_b': orm.Int(arg)} result = self.function_args_and_kwargs(*args) - self.assertTrue(isinstance(result, dict)) - self.assertEqual(result, {'data_a': args[0]}) + assert 
isinstance(result, dict) + assert result == {'data_a': args[0]} result = self.function_args_and_kwargs(*args, **kwargs) - self.assertTrue(isinstance(result, dict)) - self.assertEqual(result, {'data_a': args[0], 'data_b': kwargs['data_b']}) + assert isinstance(result, dict) + assert result == {'data_a': args[0], 'data_b': kwargs['data_b']} # Calling with more positional arguments than defined in the signature should raise - with self.assertRaises(TypeError): + with pytest.raises(TypeError): self.function_kwargs.run_get_node(orm.Int(1), orm.Int(2)) - with self.assertRaises(TypeError): + with pytest.raises(TypeError): self.function_kwargs.run_get_node(orm.Int(1), orm.Int(2), b=orm.Int(2)) def test_function_args_and_kwargs_default(self): @@ -281,18 +282,18 @@ def test_function_args_and_kwargs_default(self): args_input_explicit = (orm.Int(DEFAULT_INT), orm.Int(arg)) result = self.function_args_and_default(*args_input_default) - self.assertTrue(isinstance(result, dict)) - self.assertEqual(result, {'data_a': args_input_default[0], 'data_b': orm.Int(DEFAULT_INT)}) + assert isinstance(result, dict) + assert result == {'data_a': args_input_default[0], 'data_b': orm.Int(DEFAULT_INT)} result = self.function_args_and_default(*args_input_explicit) - self.assertTrue(isinstance(result, dict)) - self.assertEqual(result, {'data_a': args_input_explicit[0], 'data_b': args_input_explicit[1]}) + assert isinstance(result, dict) + assert result == {'data_a': args_input_explicit[0], 'data_b': args_input_explicit[1]} def test_function_args_passing_kwargs(self): """Cannot pass kwargs if the function does not explicitly define it accepts kwargs.""" arg = 1 - with self.assertRaises(ValueError): + with pytest.raises(ValueError): self.function_args(data_a=orm.Int(arg), data_b=orm.Int(arg)) # pylint: disable=unexpected-keyword-arg def test_function_set_label_description(self): @@ -300,58 +301,58 @@ def test_function_set_label_description(self): metadata = {'label': CUSTOM_LABEL, 'description': 
CUSTOM_DESCRIPTION} _, node = self.function_args.run_get_node(data_a=orm.Int(DEFAULT_INT), metadata=metadata) - self.assertEqual(node.label, CUSTOM_LABEL) - self.assertEqual(node.description, CUSTOM_DESCRIPTION) + assert node.label == CUSTOM_LABEL + assert node.description == CUSTOM_DESCRIPTION _, node = self.function_args_with_default.run_get_node(metadata=metadata) - self.assertEqual(node.label, CUSTOM_LABEL) - self.assertEqual(node.description, CUSTOM_DESCRIPTION) + assert node.label == CUSTOM_LABEL + assert node.description == CUSTOM_DESCRIPTION _, node = self.function_kwargs.run_get_node(metadata=metadata) - self.assertEqual(node.label, CUSTOM_LABEL) - self.assertEqual(node.description, CUSTOM_DESCRIPTION) + assert node.label == CUSTOM_LABEL + assert node.description == CUSTOM_DESCRIPTION _, node = self.function_args_and_kwargs.run_get_node(data_a=orm.Int(DEFAULT_INT), metadata=metadata) - self.assertEqual(node.label, CUSTOM_LABEL) - self.assertEqual(node.description, CUSTOM_DESCRIPTION) + assert node.label == CUSTOM_LABEL + assert node.description == CUSTOM_DESCRIPTION _, node = self.function_args_and_default.run_get_node(data_a=orm.Int(DEFAULT_INT), metadata=metadata) - self.assertEqual(node.label, CUSTOM_LABEL) - self.assertEqual(node.description, CUSTOM_DESCRIPTION) + assert node.label == CUSTOM_LABEL + assert node.description == CUSTOM_DESCRIPTION def test_function_defaults(self): """Verify that a process function can define a default label and description but can be overriden.""" metadata = {'label': CUSTOM_LABEL, 'description': CUSTOM_DESCRIPTION} _, node = self.function_defaults.run_get_node(data_a=orm.Int(DEFAULT_INT)) - self.assertEqual(node.label, DEFAULT_LABEL) - self.assertEqual(node.description, DEFAULT_DESCRIPTION) + assert node.label == DEFAULT_LABEL + assert node.description == DEFAULT_DESCRIPTION _, node = self.function_defaults.run_get_node(metadata=metadata) - self.assertEqual(node.label, CUSTOM_LABEL) - self.assertEqual(node.description, 
CUSTOM_DESCRIPTION) + assert node.label == CUSTOM_LABEL + assert node.description == CUSTOM_DESCRIPTION def test_function_default_label(self): """Verify unless specified label is taken from function name.""" metadata = {'label': CUSTOM_LABEL, 'description': CUSTOM_DESCRIPTION} _, node = self.function_default_label.run_get_node() - self.assertEqual(node.label, 'function_default_label') - self.assertEqual(node.description, '') + assert node.label == 'function_default_label' + assert node.description == '' _, node = self.function_default_label.run_get_node(metadata=metadata) - self.assertEqual(node.label, CUSTOM_LABEL) - self.assertEqual(node.description, CUSTOM_DESCRIPTION) + assert node.label == CUSTOM_LABEL + assert node.description == CUSTOM_DESCRIPTION def test_launchers(self): """Verify that the various launchers are working.""" result = run(self.function_return_true) - self.assertTrue(result) + assert result result, node = run_get_node(self.function_return_true) - self.assertTrue(result) - self.assertEqual(result, get_true_node()) - self.assertTrue(isinstance(node, orm.CalcFunctionNode)) + assert result + assert result == get_true_node() + assert isinstance(node, orm.CalcFunctionNode) # Process function can be submitted and will be run by a daemon worker as long as the function is importable # Note that the actual running is not tested here but is done so in `.github/system_tests/test_daemon.py`. 
@@ -368,23 +369,23 @@ def test_return_exit_code(self): message = orm.Str(exit_message) _, node = self.function_exit_code.run_get_node(exit_status=orm.Int(exit_status), exit_message=message) - self.assertTrue(node.is_finished) - self.assertFalse(node.is_finished_ok) - self.assertEqual(node.exit_status, exit_status) - self.assertEqual(node.exit_message, exit_message) + assert node.is_finished + assert not node.is_finished_ok + assert node.exit_status == exit_status + assert node.exit_message == exit_message def test_normal_exception(self): """If a process, for example a FunctionProcess, excepts, the exception should be stored in the node.""" exception = 'This process function excepted' - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): _, node = self.function_excepts.run_get_node(exception=orm.Str(exception)) - self.assertTrue(node.is_excepted) - self.assertEqual(node.exception, exception) + assert node.is_excepted + assert node.exception == exception def test_function_out_unstored(self): """A workfunction that returns an unstored node should raise as it indicates users tried to create data.""" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): self.function_out_unstored() def test_simple_workflow(self): @@ -404,21 +405,21 @@ def add_mul_wf(data_a, data_b, data_c): result, node = add_mul_wf.run_get_node(orm.Int(3), orm.Int(4), orm.Int(5)) - self.assertEqual(result, (3 + 4) * 5) - self.assertIsInstance(node, orm.WorkFunctionNode) + assert result == (3 + 4) * 5 + assert isinstance(node, orm.WorkFunctionNode) def test_hashes(self): """Test that the hashes generated for identical process functions with identical inputs are the same.""" _, node1 = self.function_return_input.run_get_node(data=orm.Int(2)) _, node2 = self.function_return_input.run_get_node(data=orm.Int(2)) - self.assertEqual(node1.get_hash(), node1.get_extra('_aiida_hash')) - self.assertEqual(node2.get_hash(), node2.get_extra('_aiida_hash')) - 
self.assertEqual(node1.get_hash(), node2.get_hash()) + assert node1.get_hash() == node1.get_extra('_aiida_hash') + assert node2.get_hash() == node2.get_extra('_aiida_hash') + assert node1.get_hash() == node2.get_hash() def test_hashes_different(self): """Test that the hashes generated for identical process functions with different inputs are the different.""" _, node1 = self.function_return_input.run_get_node(data=orm.Int(2)) _, node2 = self.function_return_input.run_get_node(data=orm.Int(3)) - self.assertEqual(node1.get_hash(), node1.get_extra('_aiida_hash')) - self.assertEqual(node2.get_hash(), node2.get_extra('_aiida_hash')) - self.assertNotEqual(node1.get_hash(), node2.get_hash()) + assert node1.get_hash() == node1.get_extra('_aiida_hash') + assert node2.get_hash() == node2.get_extra('_aiida_hash') + assert node1.get_hash() != node2.get_hash() diff --git a/tests/engine/test_process_spec.py b/tests/engine/test_process_spec.py index 345e0d6f11..3349af4e49 100644 --- a/tests/engine/test_process_spec.py +++ b/tests/engine/test_process_spec.py @@ -7,44 +7,45 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=no-self-use """Tests for the `ProcessSpec` class.""" +import pytest from aiida.engine import Process from aiida.orm import Data, Node -from aiida.storage.testbase import AiidaTestCase -class TestProcessSpec(AiidaTestCase): +class TestProcessSpec: """Tests for the `ProcessSpec` class.""" - def setUp(self): - super().setUp() - self.assertIsNone(Process.current()) + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init + assert Process.current() is None self.spec = Process.spec() self.spec.inputs.valid_type = Data self.spec.outputs.valid_type = Data - - 
def tearDown(self): - super().tearDown() - self.assertIsNone(Process.current()) + yield + assert Process.current() is None def test_dynamic_input(self): """Test a process spec with dynamic input enabled.""" node = Node() data = Data() - self.assertIsNotNone(self.spec.inputs.validate({'key': 'foo'})) - self.assertIsNotNone(self.spec.inputs.validate({'key': 5})) - self.assertIsNotNone(self.spec.inputs.validate({'key': node})) - self.assertIsNone(self.spec.inputs.validate({'key': data})) + assert self.spec.inputs.validate({'key': 'foo'}) is not None + assert self.spec.inputs.validate({'key': 5}) is not None + assert self.spec.inputs.validate({'key': node}) is not None + assert self.spec.inputs.validate({'key': data}) is None def test_dynamic_output(self): """Test a process spec with dynamic output enabled.""" node = Node() data = Data() - self.assertIsNotNone(self.spec.outputs.validate({'key': 'foo'})) - self.assertIsNotNone(self.spec.outputs.validate({'key': 5})) - self.assertIsNotNone(self.spec.outputs.validate({'key': node})) - self.assertIsNone(self.spec.outputs.validate({'key': data})) + assert self.spec.outputs.validate({'key': 'foo'}) is not None + assert self.spec.outputs.validate({'key': 5}) is not None + assert self.spec.outputs.validate({'key': node}) is not None + assert self.spec.outputs.validate({'key': data}) is None def test_exit_code(self): """Test the definition of error codes through the ProcessSpec.""" @@ -54,14 +55,14 @@ def test_exit_code(self): self.spec.exit_code(status, label, message) - self.assertEqual(self.spec.exit_codes.SOME_EXIT_CODE.status, status) - self.assertEqual(self.spec.exit_codes.SOME_EXIT_CODE.message, message) + assert self.spec.exit_codes.SOME_EXIT_CODE.status == status + assert self.spec.exit_codes.SOME_EXIT_CODE.message == message - self.assertEqual(self.spec.exit_codes['SOME_EXIT_CODE'].status, status) - self.assertEqual(self.spec.exit_codes['SOME_EXIT_CODE'].message, message) + assert 
self.spec.exit_codes['SOME_EXIT_CODE'].status == status + assert self.spec.exit_codes['SOME_EXIT_CODE'].message == message - self.assertEqual(self.spec.exit_codes[label].status, status) - self.assertEqual(self.spec.exit_codes[label].message, message) + assert self.spec.exit_codes[label].status == status + assert self.spec.exit_codes[label].message == message def test_exit_code_invalid(self): """Test type validation for registering new error codes.""" @@ -69,14 +70,14 @@ def test_exit_code_invalid(self): label = 'SOME_EXIT_CODE' message = 'I am a teapot' - with self.assertRaises(TypeError): + with pytest.raises(TypeError): self.spec.exit_code(status, 256, message) - with self.assertRaises(TypeError): + with pytest.raises(TypeError): self.spec.exit_code('string', label, message) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): self.spec.exit_code(-256, label, message) - with self.assertRaises(TypeError): + with pytest.raises(TypeError): self.spec.exit_code(status, label, 8) diff --git a/tests/engine/test_rmq.py b/tests/engine/test_rmq.py index 5b49308ee0..8d854ffa14 100644 --- a/tests/engine/test_rmq.py +++ b/tests/engine/test_rmq.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=no-self-use """Module to test RabbitMQ.""" import asyncio @@ -16,18 +17,19 @@ from aiida.engine import ProcessState from aiida.manage import get_manager from aiida.orm import Int -from aiida.storage.testbase import AiidaTestCase from tests.utils import processes as test_processes @pytest.mark.requires_rmq -class TestProcessControl(AiidaTestCase): +class TestProcessControl: """Test AiiDA's RabbitMQ functionalities.""" TIMEOUT = 2. 
- def setUp(self): - super().setUp() + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init # The coroutine defined in testcase should run in runner's loop # and process need submit by runner.submit rather than `submit` import from @@ -35,10 +37,6 @@ def setUp(self): manager = get_manager() self.runner = manager.get_runner() - def tearDown(self): - self.runner.close() - super().tearDown() - def test_submit_simple(self): """"Launch the process.""" @@ -46,8 +44,8 @@ async def do_submit(): calc_node = self.runner.submit(test_processes.DummyProcess) await self.wait_for_process(calc_node) - self.assertTrue(calc_node.is_finished_ok) - self.assertEqual(calc_node.process_state.value, plumpy.ProcessState.FINISHED.value) + assert calc_node.is_finished_ok + assert calc_node.process_state.value == plumpy.ProcessState.FINISHED.value self.runner.loop.run_until_complete(do_submit()) @@ -60,13 +58,13 @@ async def do_launch(): calc_node = self.runner.submit(test_processes.AddProcess, a=term_a, b=term_b) await self.wait_for_process(calc_node) - self.assertTrue(calc_node.is_finished_ok) - self.assertEqual(calc_node.process_state.value, plumpy.ProcessState.FINISHED.value) + assert calc_node.is_finished_ok + assert calc_node.process_state.value == plumpy.ProcessState.FINISHED.value self.runner.loop.run_until_complete(do_launch()) def test_submit_bad_input(self): - with self.assertRaises(ValueError): + with pytest.raises(ValueError): self.runner.submit(test_processes.AddProcess, a=Int(5)) def test_exception_process(self): @@ -76,8 +74,8 @@ async def do_exception(): calc_node = self.runner.submit(test_processes.ExceptionProcess) await self.wait_for_process(calc_node) - self.assertFalse(calc_node.is_finished_ok) - self.assertEqual(calc_node.process_state.value, plumpy.ProcessState.EXCEPTED.value) + assert not calc_node.is_finished_ok + assert 
calc_node.process_state.value == plumpy.ProcessState.EXCEPTED.value self.runner.loop.run_until_complete(do_exception()) @@ -91,19 +89,19 @@ async def do_pause(): while calc_node.process_state != ProcessState.WAITING: await asyncio.sleep(0.1) - self.assertFalse(calc_node.paused) + assert not calc_node.paused pause_future = controller.pause_process(calc_node.pk) future = await with_timeout(asyncio.wrap_future(pause_future)) result = await self.wait_future(asyncio.wrap_future(future)) - self.assertTrue(result) - self.assertTrue(calc_node.paused) + assert result + assert calc_node.paused kill_message = 'Sorry, you have to go mate' kill_future = controller.kill_process(calc_node.pk, msg=kill_message) future = await with_timeout(asyncio.wrap_future(kill_future)) result = await self.wait_future(asyncio.wrap_future(future)) - self.assertTrue(result) + assert result self.runner.loop.run_until_complete(do_pause()) @@ -114,7 +112,7 @@ def test_pause_play(self): async def do_pause_play(): calc_node = self.runner.submit(test_processes.WaitProcess) - self.assertFalse(calc_node.paused) + assert not calc_node.paused while calc_node.process_state != ProcessState.WAITING: await asyncio.sleep(0.1) @@ -122,22 +120,22 @@ async def do_pause_play(): pause_future = controller.pause_process(calc_node.pk, msg=pause_message) future = await with_timeout(asyncio.wrap_future(pause_future)) result = await self.wait_future(asyncio.wrap_future(future)) - self.assertTrue(calc_node.paused) - self.assertEqual(calc_node.process_status, pause_message) + assert calc_node.paused + assert calc_node.process_status == pause_message play_future = controller.play_process(calc_node.pk) future = await with_timeout(asyncio.wrap_future(play_future)) result = await self.wait_future(asyncio.wrap_future(future)) - self.assertTrue(result) - self.assertFalse(calc_node.paused) - self.assertEqual(calc_node.process_status, None) + assert result + assert not calc_node.paused + assert calc_node.process_status is None 
kill_message = 'Sorry, you have to go mate' kill_future = controller.kill_process(calc_node.pk, msg=kill_message) future = await with_timeout(asyncio.wrap_future(kill_future)) result = await self.wait_future(asyncio.wrap_future(future)) - self.assertTrue(result) + assert result self.runner.loop.run_until_complete(do_pause_play()) @@ -148,7 +146,7 @@ def test_kill(self): async def do_kill(): calc_node = self.runner.submit(test_processes.WaitProcess) - self.assertFalse(calc_node.is_killed) + assert not calc_node.is_killed while calc_node.process_state != ProcessState.WAITING: await asyncio.sleep(0.1) @@ -156,11 +154,11 @@ async def do_kill(): kill_future = controller.kill_process(calc_node.pk, msg=kill_message) future = await with_timeout(asyncio.wrap_future(kill_future)) result = await self.wait_future(asyncio.wrap_future(future)) - self.assertTrue(result) + assert result await self.wait_for_process(calc_node) - self.assertTrue(calc_node.is_killed) - self.assertEqual(calc_node.process_status, kill_message) + assert calc_node.is_killed + assert calc_node.process_status == kill_message self.runner.loop.run_until_complete(do_kill()) diff --git a/tests/engine/test_run.py b/tests/engine/test_run.py index a150f21b28..221f049b6c 100644 --- a/tests/engine/test_run.py +++ b/tests/engine/test_run.py @@ -7,17 +7,18 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=no-self-use """Tests for the `run` functions.""" import pytest from aiida.engine import run, run_get_node from aiida.orm import Int, ProcessNode, Str -from aiida.storage.testbase import AiidaTestCase from tests.utils.processes import DummyProcess @pytest.mark.requires_rmq -class TestRun(AiidaTestCase): +@pytest.mark.usefixtures('aiida_profile_clean') +class TestRun: """Tests for the `run` functions.""" @staticmethod @@ -30,4 +31,4 @@ def 
test_run_get_node(self): """Test the `run_get_node` function.""" inputs = {'a': Int(2), 'b': Str('test')} result, node = run_get_node(DummyProcess, **inputs) # pylint: disable=unused-variable - self.assertIsInstance(node, ProcessNode) + assert isinstance(node, ProcessNode) diff --git a/tests/engine/test_transport.py b/tests/engine/test_transport.py index 7bd0e519e1..f9354d2615 100644 --- a/tests/engine/test_transport.py +++ b/tests/engine/test_transport.py @@ -7,25 +7,25 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=no-self-use """Module to test transport.""" import asyncio +import pytest + from aiida import orm from aiida.engine.transports import TransportQueue -from aiida.storage.testbase import AiidaTestCase -class TestTransportQueue(AiidaTestCase): +class TestTransportQueue: """Tests for the transport queue.""" - def setUp(self, *args, **kwargs): # pylint: disable=arguments-differ - """ Set up a simple authinfo and for later use """ - super().setUp(*args, **kwargs) - self.authinfo = orm.AuthInfo(computer=self.computer, user=orm.User.objects.get_default()).store() - - def tearDown(self, *args, **kwargs): # pylint: disable=arguments-differ - orm.AuthInfo.objects.delete(self.authinfo.id) - super().tearDown(*args, **kwargs) + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean, aiida_localhost): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init + self.computer = aiida_localhost + self.authinfo = self.computer.get_authinfo(orm.User.objects.get_default()) def test_simple_request(self): """ Test a simple transport request """ @@ -36,8 +36,8 @@ async def test(): trans = None with queue.request_transport(self.authinfo) as request: trans = await request - self.assertTrue(trans.is_open) - 
self.assertFalse(trans.is_open) + assert trans.is_open + assert not trans.is_open loop.run_until_complete(test()) @@ -49,11 +49,11 @@ def test_get_transport_nested(self): async def nested(queue, authinfo): with queue.request_transport(authinfo) as request1: trans1 = await request1 - self.assertTrue(trans1.is_open) + assert trans1.is_open with queue.request_transport(authinfo) as request2: trans2 = await request2 - self.assertIs(trans1, trans2) - self.assertTrue(trans2.is_open) + assert trans1 is trans2 + assert trans2.is_open loop.run_until_complete(nested(transport_queue, self.authinfo)) @@ -79,7 +79,7 @@ async def test(): return trans.is_open retval = loop.run_until_complete(test()) - self.assertTrue(retval) + assert retval def test_open_fail(self): """Test that if opening fails.""" @@ -98,7 +98,7 @@ def broken_open(trans): # Let's put in a broken open method original = self.authinfo.get_transport().__class__.open self.authinfo.get_transport().__class__.open = broken_open - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): loop.run_until_complete(test()) finally: self.authinfo.get_transport().__class__.open = original @@ -126,7 +126,7 @@ async def test(iteration): time_current = time.time() time_elapsed = time_current - time_start time_minimum = trans.get_safe_open_interval() * (iteration + 1) - self.assertTrue(time_elapsed > time_minimum, 'transport safe interval was violated') + assert time_elapsed > time_minimum, 'transport safe interval was violated' for iteration in range(5): loop.run_until_complete(test(iteration)) diff --git a/tests/engine/test_utils.py b/tests/engine/test_utils.py index 8018ed0ef0..06ed7c69fa 100644 --- a/tests/engine/test_utils.py +++ b/tests/engine/test_utils.py @@ -16,21 +16,20 @@ from aiida import orm from aiida.engine import calcfunction, workfunction from aiida.engine.utils import InterruptableFuture, exponential_backoff_retry, interruptable_task, is_process_function -from aiida.storage.testbase import 
AiidaTestCase ITERATION = 0 MAX_ITERATIONS = 3 -class TestExponentialBackoffRetry(AiidaTestCase): +class TestExponentialBackoffRetry: """Tests for the exponential backoff retry coroutine.""" - @classmethod - def setUpClass(cls, *args, **kwargs): - """Set up a simple authinfo and for later use.""" - super().setUpClass(*args, **kwargs) - cls.authinfo = orm.AuthInfo(computer=cls.computer, user=orm.User.objects.get_default()) - cls.authinfo.store() + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean, aiida_localhost): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init + self.computer = aiida_localhost + self.authinfo = self.computer.get_authinfo(orm.User.objects.get_default()) @staticmethod def test_exp_backoff_success(): @@ -63,39 +62,31 @@ def coro(): raise RuntimeError max_attempts = MAX_ITERATIONS - 1 - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): loop.run_until_complete(exponential_backoff_retry(coro, initial_interval=0.1, max_attempts=max_attempts)) -class TestUtils(AiidaTestCase): - """ Tests for engine utils.""" +def test_is_process_function(): + """Test the `is_process_function` utility.""" - def test_is_process_function(self): - """Test the `is_process_function` utility.""" - - def normal_function(): - pass - - @calcfunction - def calc_function(): - pass - - @workfunction - def work_function(): - pass - - self.assertEqual(is_process_function(normal_function), False) - self.assertEqual(is_process_function(calc_function), True) - self.assertEqual(is_process_function(work_function), True) + def normal_function(): + pass - def test_is_process_scoped(self): + @calcfunction + def calc_function(): pass - def test_loop_scope(self): + @workfunction + def work_function(): pass + assert is_process_function(normal_function) is False + assert is_process_function(calc_function) is True + assert is_process_function(work_function) is True + -class 
TestInterruptable(AiidaTestCase): +@pytest.mark.usefixtures('aiida_profile_clean') +class TestInterruptable: """ Tests for InterruptableFuture and interruptable_task.""" def test_normal_future(self): @@ -109,8 +100,8 @@ async def task(): fut.set_result('I am done') loop.run_until_complete(interruptable.with_interrupt(task())) - self.assertFalse(interruptable.done()) - self.assertEqual(fut.result(), 'I am done') + assert not interruptable.done() + assert fut.result() == 'I am done' def test_interrupt(self): """Test interrupt future being interrupted""" @@ -121,11 +112,11 @@ def test_interrupt(self): try: loop.run_until_complete(interruptable.with_interrupt(asyncio.sleep(10.))) except RuntimeError as err: - self.assertEqual(str(err), 'STOP') + assert str(err) == 'STOP' else: - self.fail('ExpectedException not raised') + pytest.fail('ExpectedException not raised') - self.assertTrue(interruptable.done()) + assert interruptable.done() def test_inside_interrupted(self): """Test interrupt future being interrupted from inside of coroutine""" @@ -142,12 +133,12 @@ async def task(): try: loop.run_until_complete(interruptable.with_interrupt(task())) except RuntimeError as err: - self.assertEqual(str(err), 'STOP') + assert str(err) == 'STOP' else: - self.fail('ExpectedException not raised') + pytest.fail('ExpectedException not raised') - self.assertTrue(interruptable.done()) - self.assertEqual(fut.result(), 'I got set.') + assert interruptable.done() + assert fut.result() == 'I got set.' 
def test_interruptable_future_set(self): """Test interrupt future being set before coroutine is done""" @@ -162,11 +153,11 @@ async def task(): try: loop.run_until_complete(interruptable.with_interrupt(asyncio.sleep(20.))) except RuntimeError as err: - self.assertEqual(str(err), "This interruptible future had it's result set unexpectedly to 'NOT ME!!!'") + assert str(err) == "This interruptible future had it's result set unexpectedly to 'NOT ME!!!'" else: - self.fail('ExpectedException not raised') + pytest.fail('ExpectedException not raised') - self.assertTrue(interruptable.done()) + assert interruptable.done() @pytest.mark.requires_rmq diff --git a/tests/engine/test_work_chain.py b/tests/engine/test_work_chain.py index 65ec106910..67318a18b2 100644 --- a/tests/engine/test_work_chain.py +++ b/tests/engine/test_work_chain.py @@ -23,7 +23,6 @@ from aiida.engine.persistence import ObjectLoader from aiida.manage import get_manager from aiida.orm import Bool, Float, Int, Str, load_node -from aiida.storage.testbase import AiidaTestCase def run_until_paused(proc): @@ -187,7 +186,8 @@ def success(self): @pytest.mark.requires_rmq -class TestExitStatus(AiidaTestCase): +@pytest.mark.usefixtures('aiida_profile_clean') +class TestExitStatus: """ This class should test the various ways that one can exit from the outline flow of a WorkChain, other than it running it all the way through. 
Currently this can be done directly in the outline by calling the `return_` @@ -196,53 +196,49 @@ class TestExitStatus(AiidaTestCase): def test_failing_workchain_through_integer(self): _, node = launch.run.get_node(PotentialFailureWorkChain, success=Bool(False)) - self.assertEqual(node.exit_status, PotentialFailureWorkChain.EXIT_STATUS) - self.assertEqual(node.exit_message, None) - self.assertEqual(node.is_finished, True) - self.assertEqual(node.is_finished_ok, False) - self.assertEqual(node.is_failed, True) - self.assertNotIn(PotentialFailureWorkChain.OUTPUT_LABEL, node.get_outgoing().all_link_labels()) + assert node.exit_status == PotentialFailureWorkChain.EXIT_STATUS + assert node.exit_message is None + assert node.is_finished is True + assert node.is_finished_ok is False + assert node.is_failed is True + assert PotentialFailureWorkChain.OUTPUT_LABEL not in node.get_outgoing().all_link_labels() def test_failing_workchain_through_exit_code(self): _, node = launch.run.get_node(PotentialFailureWorkChain, success=Bool(False), through_exit_code=Bool(True)) - self.assertEqual(node.exit_status, PotentialFailureWorkChain.EXIT_STATUS) - self.assertEqual(node.exit_message, PotentialFailureWorkChain.EXIT_MESSAGE) - self.assertEqual(node.is_finished, True) - self.assertEqual(node.is_finished_ok, False) - self.assertEqual(node.is_failed, True) - self.assertNotIn(PotentialFailureWorkChain.OUTPUT_LABEL, node.get_outgoing().all_link_labels()) + assert node.exit_status == PotentialFailureWorkChain.EXIT_STATUS + assert node.exit_message == PotentialFailureWorkChain.EXIT_MESSAGE + assert node.is_finished is True + assert node.is_finished_ok is False + assert node.is_failed is True + assert PotentialFailureWorkChain.OUTPUT_LABEL not in node.get_outgoing().all_link_labels() def test_successful_workchain_through_integer(self): _, node = launch.run.get_node(PotentialFailureWorkChain, success=Bool(True)) - self.assertEqual(node.exit_status, 0) - self.assertEqual(node.is_finished, True) 
- self.assertEqual(node.is_finished_ok, True) - self.assertEqual(node.is_failed, False) - self.assertIn(PotentialFailureWorkChain.OUTPUT_LABEL, node.get_outgoing().all_link_labels()) - self.assertEqual( - node.get_outgoing().get_node_by_label(PotentialFailureWorkChain.OUTPUT_LABEL), + assert node.exit_status == 0 + assert node.is_finished is True + assert node.is_finished_ok is True + assert node.is_failed is False + assert PotentialFailureWorkChain.OUTPUT_LABEL in node.get_outgoing().all_link_labels() + assert node.get_outgoing().get_node_by_label(PotentialFailureWorkChain.OUTPUT_LABEL) == \ PotentialFailureWorkChain.OUTPUT_VALUE - ) def test_successful_workchain_through_exit_code(self): _, node = launch.run.get_node(PotentialFailureWorkChain, success=Bool(True), through_exit_code=Bool(True)) - self.assertEqual(node.exit_status, 0) - self.assertEqual(node.is_finished, True) - self.assertEqual(node.is_finished_ok, True) - self.assertEqual(node.is_failed, False) - self.assertIn(PotentialFailureWorkChain.OUTPUT_LABEL, node.get_outgoing().all_link_labels()) - self.assertEqual( - node.get_outgoing().get_node_by_label(PotentialFailureWorkChain.OUTPUT_LABEL), + assert node.exit_status == 0 + assert node.is_finished is True + assert node.is_finished_ok is True + assert node.is_failed is False + assert PotentialFailureWorkChain.OUTPUT_LABEL in node.get_outgoing().all_link_labels() + assert node.get_outgoing().get_node_by_label(PotentialFailureWorkChain.OUTPUT_LABEL) == \ PotentialFailureWorkChain.OUTPUT_VALUE - ) def test_return_out_of_outline(self): _, node = launch.run.get_node(PotentialFailureWorkChain, success=Bool(True), through_return=Bool(True)) - self.assertEqual(node.exit_status, PotentialFailureWorkChain.EXIT_STATUS) - self.assertEqual(node.is_finished, True) - self.assertEqual(node.is_finished_ok, False) - self.assertEqual(node.is_failed, True) - self.assertNotIn(PotentialFailureWorkChain.OUTPUT_LABEL, node.get_outgoing().all_link_labels()) + assert 
node.exit_status == PotentialFailureWorkChain.EXIT_STATUS + assert node.is_finished is True + assert node.is_finished_ok is False + assert node.is_failed is True + assert PotentialFailureWorkChain.OUTPUT_LABEL not in node.get_outgoing().all_link_labels() class IfTest(WorkChain): @@ -269,55 +265,55 @@ def step2(self): @pytest.mark.requires_rmq -class TestContext(AiidaTestCase): +@pytest.mark.usefixtures('aiida_profile_clean') +class TestContext: def test_attributes(self): wc = IfTest() wc.ctx.new_attr = 5 - self.assertEqual(wc.ctx.new_attr, 5) + assert wc.ctx.new_attr == 5 del wc.ctx.new_attr - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): wc.ctx.new_attr # pylint: disable=pointless-statement def test_dict(self): wc = IfTest() wc.ctx['new_attr'] = 5 - self.assertEqual(wc.ctx['new_attr'], 5) + assert wc.ctx['new_attr'] == 5 del wc.ctx['new_attr'] - with self.assertRaises(KeyError): + with pytest.raises(KeyError): wc.ctx['new_attr'] # pylint: disable=pointless-statement @pytest.mark.requires_rmq -class TestWorkchain(AiidaTestCase): +class TestWorkchain: # pylint: disable=too-many-public-methods - def setUp(self): - super().setUp() - self.assertIsNone(Process.current()) - - def tearDown(self): - super().tearDown() - self.assertIsNone(Process.current()) + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean): # pylint: disable=unused-argument + """Initialize the profile.""" + assert Process.current() is None + yield + assert Process.current() is None def test_run_base_class(self): """Verify that it is impossible to run, submit or instantiate a base `WorkChain` class.""" - with self.assertRaises(exceptions.InvalidOperation): + with pytest.raises(exceptions.InvalidOperation): WorkChain() - with self.assertRaises(exceptions.InvalidOperation): + with pytest.raises(exceptions.InvalidOperation): launch.run(WorkChain) - with self.assertRaises(exceptions.InvalidOperation): + with pytest.raises(exceptions.InvalidOperation): 
launch.run.get_node(WorkChain) - with self.assertRaises(exceptions.InvalidOperation): + with pytest.raises(exceptions.InvalidOperation): launch.run.get_pk(WorkChain) - with self.assertRaises(exceptions.InvalidOperation): + with pytest.raises(exceptions.InvalidOperation): launch.submit(WorkChain) def test_run(self): @@ -331,21 +327,21 @@ def test_run(self): # Check the steps that should have been run for step, finished in Wf.finished_steps.items(): if step not in ['step3', 'step4', 'is_b']: - self.assertTrue(finished, f'Step {step} was not called by workflow') + assert finished, f'Step {step} was not called by workflow' # Try the elif(..) part finished_steps = launch.run(Wf, value=B, n=three) # Check the steps that should have been run for step, finished in finished_steps.items(): if step not in ['is_a', 'step2', 'step4']: - self.assertTrue(finished, f'Step {step} was not called by workflow') + assert finished, f'Step {step} was not called by workflow' # Try the else... part finished_steps = launch.run(Wf, value=C, n=three) # Check the steps that should have been run for step, finished in finished_steps.items(): if step not in ['is_a', 'step2', 'is_b', 'step3']: - self.assertTrue(finished, f'Step {step} was not called by workflow') + assert finished, f'Step {step} was not called by workflow' def test_incorrect_outline(self): @@ -357,7 +353,7 @@ def define(cls, spec): # Try defining an invalid outline spec.outline(5) - with self.assertRaises(TypeError): + with pytest.raises(TypeError): IncorrectOutline.spec() def test_define_not_calling_super(self): @@ -369,7 +365,7 @@ class IncompleteDefineWorkChain(WorkChain): def define(cls, spec): pass - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): launch.run(IncompleteDefineWorkChain) def test_out_unstored(self): @@ -389,7 +385,7 @@ def define(cls, spec): def illegal(self): self.out('not_allowed', orm.Int(2)) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): 
launch.run(IllegalWorkChain) def test_same_input_node(self): @@ -415,8 +411,6 @@ def test_context(self): A = Str('a').store() B = Str('b').store() - test_case = self - class ReturnA(WorkChain): @classmethod @@ -450,15 +444,15 @@ def s1(self): return ToContext(r1=self.submit(ReturnA), r2=self.submit(ReturnB)) def s2(self): - test_case.assertEqual(self.ctx.r1.outputs.res, A) - test_case.assertEqual(self.ctx.r2.outputs.res, B) + assert self.ctx.r1.outputs.res == A + assert self.ctx.r2.outputs.res == B # Try overwriting r1 return ToContext(r1=self.submit(ReturnB)) def s3(self): - test_case.assertEqual(self.ctx.r1.outputs.res, B) - test_case.assertEqual(self.ctx.r2.outputs.res, B) + assert self.ctx.r1.outputs.res == B + assert self.ctx.r2.outputs.res == B run_and_check_success(OverrideContextWorkChain) @@ -481,7 +475,7 @@ def read_context(self): run_and_check_success(TestWorkChain) def test_str(self): - self.assertIsInstance(str(Wf.spec()), str) + assert isinstance(str(Wf.spec()), str) def test_malformed_outline(self): """ @@ -491,11 +485,11 @@ def test_malformed_outline(self): spec = WorkChainSpec() - with self.assertRaises(TypeError): + with pytest.raises(TypeError): spec.outline(5) # Test a function with wrong number of args - with self.assertRaises(TypeError): + with pytest.raises(TypeError): spec.outline(lambda x, y: None) def test_checkpointing(self): @@ -509,21 +503,21 @@ def test_checkpointing(self): # Check the steps that should have been run for step, finished in finished_steps.items(): if step not in ['step3', 'step4', 'is_b']: - self.assertTrue(finished, f'Step {step} was not called by workflow') + assert finished, f'Step {step} was not called by workflow' # Try the elif(..) 
part finished_steps = self._run_with_checkpoints(Wf, inputs={'value': B, 'n': three}) # Check the steps that should have been run for step, finished in finished_steps.items(): if step not in ['is_a', 'step2', 'step4']: - self.assertTrue(finished, f'Step {step} was not called by workflow') + assert finished, f'Step {step} was not called by workflow' # Try the else... part finished_steps = self._run_with_checkpoints(Wf, inputs={'value': C, 'n': three}) # Check the steps that should have been run for step, finished in finished_steps.items(): if step not in ['is_a', 'step2', 'is_b', 'step3']: - self.assertTrue(finished, f'Step {step} was not called by workflow') + assert finished, f'Step {step} was not called by workflow' def test_return(self): @@ -579,11 +573,11 @@ class SubWorkChain(WorkChain): # Verify that the `CALL` link of the calculation function is there with the correct label link_triple = process.node.get_outgoing(link_type=LinkType.CALL_CALC, link_label_filter=label_calcfunction).one() - self.assertIsInstance(link_triple.node, orm.CalcFunctionNode) + assert isinstance(link_triple.node, orm.CalcFunctionNode) # Verify that the `CALL` link of the work chain is there with the correct label link_triple = process.node.get_outgoing(link_type=LinkType.CALL_WORK, link_label_filter=label_workchain).one() - self.assertIsInstance(link_triple.node, orm.WorkChainNode) + assert isinstance(link_triple.node, orm.WorkChainNode) def test_tocontext_submit_workchain_no_daemon(self): @@ -691,8 +685,8 @@ async def run_async(workchain): # run the original workchain until paused await run_until_paused(workchain) - self.assertTrue(workchain.ctx.s1) - self.assertFalse(workchain.ctx.s2) + assert workchain.ctx.s1 + assert not workchain.ctx.s2 # Now bundle the workchain bundle = plumpy.Bundle(workchain) @@ -701,19 +695,19 @@ async def run_async(workchain): # Load from saved state workchain2 = bundle.unbundle() - self.assertTrue(workchain2.ctx.s1) - self.assertFalse(workchain2.ctx.s2) + 
assert workchain2.ctx.s1 + assert not workchain2.ctx.s2 # check bundling again creates the same saved state bundle2 = plumpy.Bundle(workchain2) - self.assertDictEqual(bundle, bundle2) + assert bundle == bundle2 # run the loaded workchain to completion runner.schedule(workchain2) workchain2.play() await workchain2.future() - self.assertTrue(workchain2.ctx.s1) - self.assertTrue(workchain2.ctx.s2) + assert workchain2.ctx.s1 + assert workchain2.ctx.s2 # ensure the original paused workchain future is finalised # to avoid warnings @@ -750,8 +744,6 @@ def check(self): def test_to_context(self): val = Int(5).store() - test_case = self - class SimpleWc(WorkChain): @classmethod @@ -775,16 +767,14 @@ def begin(self): return ToContext(result_b=self.submit(SimpleWc)) def result(self): - test_case.assertEqual(self.ctx.result_a.outputs.result, val) - test_case.assertEqual(self.ctx.result_b.outputs.result, val) + assert self.ctx.result_a.outputs.result == val + assert self.ctx.result_b.outputs.result == val run_and_check_success(Workchain) def test_nested_to_context(self): val = Int(5).store() - test_case = self - class SimpleWc(WorkChain): @classmethod @@ -808,8 +798,8 @@ def begin(self): return ToContext(**{'sub1.sub2.result_b': self.submit(SimpleWc)}) def result(self): - test_case.assertEqual(self.ctx.sub1.sub2.result_a.outputs.result, val) - test_case.assertEqual(self.ctx.sub1.sub2.result_b.outputs.result, val) + assert self.ctx.sub1.sub2.result_a.outputs.result == val + assert self.ctx.sub1.sub2.result_b.outputs.result == val run_and_check_success(Workchain) @@ -817,8 +807,6 @@ def test_nested_to_context_with_append(self): val1 = Int(5).store() val2 = Int(6).store() - test_case = self - class SimpleWc1(WorkChain): @classmethod @@ -853,8 +841,8 @@ def begin(self): return ToContext(**{'sub1.workchains': append_(self.submit(SimpleWc2))}) def result(self): - test_case.assertEqual(self.ctx.sub1.workchains[0].outputs.result, val1) - 
test_case.assertEqual(self.ctx.sub1.workchains[1].outputs.result, val2) + assert self.ctx.sub1.workchains[0].outputs.result == val1 + assert self.ctx.sub1.workchains[1].outputs.result == val2 run_and_check_success(Workchain) @@ -1014,21 +1002,21 @@ def run(self): wc = ExitCodeWorkChain() # The exit code can be gotten by calling it with the status or label, as well as using attribute dereferencing - self.assertEqual(wc.exit_codes(status).status, status) # pylint: disable=too-many-function-args - self.assertEqual(wc.exit_codes(label).status, status) # pylint: disable=too-many-function-args - self.assertEqual(wc.exit_codes.SOME_EXIT_CODE.status, status) # pylint: disable=no-member + assert wc.exit_codes(status).status == status # pylint: disable=too-many-function-args + assert wc.exit_codes(label).status == status # pylint: disable=too-many-function-args + assert wc.exit_codes.SOME_EXIT_CODE.status == status # pylint: disable=no-member - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): wc.exit_codes.NON_EXISTENT_ERROR # pylint: disable=no-member,pointless-statement - self.assertEqual(ExitCodeWorkChain.exit_codes.SOME_EXIT_CODE.status, status) # pylint: disable=no-member - self.assertEqual(ExitCodeWorkChain.exit_codes.SOME_EXIT_CODE.message, message) # pylint: disable=no-member + assert ExitCodeWorkChain.exit_codes.SOME_EXIT_CODE.status == status # pylint: disable=no-member + assert ExitCodeWorkChain.exit_codes.SOME_EXIT_CODE.message == message # pylint: disable=no-member - self.assertEqual(ExitCodeWorkChain.exit_codes['SOME_EXIT_CODE'].status, status) # pylint: disable=unsubscriptable-object - self.assertEqual(ExitCodeWorkChain.exit_codes['SOME_EXIT_CODE'].message, message) # pylint: disable=unsubscriptable-object + assert ExitCodeWorkChain.exit_codes['SOME_EXIT_CODE'].status == status # pylint: disable=unsubscriptable-object + assert ExitCodeWorkChain.exit_codes['SOME_EXIT_CODE'].message == message # pylint: disable=unsubscriptable-object 
- self.assertEqual(ExitCodeWorkChain.exit_codes[label].status, status) # pylint: disable=unsubscriptable-object - self.assertEqual(ExitCodeWorkChain.exit_codes[label].message, message) # pylint: disable=unsubscriptable-object + assert ExitCodeWorkChain.exit_codes[label].status == status # pylint: disable=unsubscriptable-object + assert ExitCodeWorkChain.exit_codes[label].message == message # pylint: disable=unsubscriptable-object @staticmethod def _run_with_checkpoints(wf_class, inputs=None): @@ -1039,18 +1027,17 @@ def _run_with_checkpoints(wf_class, inputs=None): @pytest.mark.requires_rmq -class TestWorkChainAbort(AiidaTestCase): +class TestWorkChainAbort: """ Test the functionality to abort a workchain """ - def setUp(self): - super().setUp() - self.assertIsNone(Process.current()) - - def tearDown(self): - super().tearDown() - self.assertIsNone(Process.current()) + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean): # pylint: disable=unused-argument + """Initialize the profile.""" + assert Process.current() is None + yield + assert Process.current() is None class AbortableWorkChain(WorkChain): @@ -1079,15 +1066,15 @@ async def run_async(): process.play() with Capturing(): - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): await process.future() runner.schedule(process) runner.loop.run_until_complete(run_async()) - self.assertEqual(process.node.is_finished_ok, False) - self.assertEqual(process.node.is_excepted, True) - self.assertEqual(process.node.is_killed, False) + assert process.node.is_finished_ok is False + assert process.node.is_excepted is True + assert process.node.is_killed is False def test_simple_kill_through_process(self): """ @@ -1101,34 +1088,33 @@ def test_simple_kill_through_process(self): async def run_async(): await run_until_paused(process) - self.assertTrue(process.paused) + assert process.paused process.kill() - with self.assertRaises(plumpy.ClosedError): + with 
pytest.raises(plumpy.ClosedError): launch.run(process) runner.schedule(process) runner.loop.run_until_complete(run_async()) - self.assertEqual(process.node.is_finished_ok, False) - self.assertEqual(process.node.is_excepted, False) - self.assertEqual(process.node.is_killed, True) + assert process.node.is_finished_ok is False + assert process.node.is_excepted is False + assert process.node.is_killed is True @pytest.mark.requires_rmq -class TestWorkChainAbortChildren(AiidaTestCase): +class TestWorkChainAbortChildren: """ Test the functionality to abort a workchain and verify that children are also aborted appropriately """ - def setUp(self): - super().setUp() - self.assertIsNone(Process.current()) - - def tearDown(self): - super().tearDown() - self.assertIsNone(Process.current()) + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean): # pylint: disable=unused-argument + """Initialize the profile.""" + assert Process.current() is None + yield + assert Process.current() is None class SubWorkChain(WorkChain): @@ -1170,12 +1156,12 @@ def test_simple_run(self): process = TestWorkChainAbortChildren.MainWorkChain() with Capturing(): - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): launch.run(process) - self.assertEqual(process.node.is_finished_ok, False) - self.assertEqual(process.node.is_excepted, True) - self.assertEqual(process.node.is_killed, False) + assert process.node.is_finished_ok is False + assert process.node.is_excepted is True + assert process.node.is_killed is False def test_simple_kill_through_process(self): """ @@ -1192,41 +1178,39 @@ async def run_async(): if asyncio.isfuture(result): await result - with self.assertRaises(plumpy.KilledError): + with pytest.raises(plumpy.KilledError): await process.future() runner.schedule(process) runner.loop.run_until_complete(run_async()) child = process.node.get_outgoing(link_type=LinkType.CALL_WORK).first().node - self.assertEqual(child.is_finished_ok, False) - 
self.assertEqual(child.is_excepted, False) - self.assertEqual(child.is_killed, True) + assert child.is_finished_ok is False + assert child.is_excepted is False + assert child.is_killed is True - self.assertEqual(process.node.is_finished_ok, False) - self.assertEqual(process.node.is_excepted, False) - self.assertEqual(process.node.is_killed, True) + assert process.node.is_finished_ok is False + assert process.node.is_excepted is False + assert process.node.is_killed is True @pytest.mark.requires_rmq -class TestImmutableInputWorkchain(AiidaTestCase): +class TestImmutableInputWorkchain: """ Test that inputs cannot be modified """ - def setUp(self): - super().setUp() - self.assertIsNone(Process.current()) - - def tearDown(self): - super().tearDown() - self.assertIsNone(Process.current()) + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean): # pylint: disable=unused-argument + """Initialize the profile.""" + assert Process.current() is None + yield + assert Process.current() is None def test_immutable_input(self): """ Check that from within the WorkChain self.inputs returns an AttributesFrozendict which should be immutable """ - test_class = self class FrozenDictWorkChain(WorkChain): @@ -1242,19 +1226,19 @@ def define(cls, spec): def step_one(self): # Attempt to manipulate the inputs dictionary which since it is a AttributesFrozendict should raise - with test_class.assertRaises(TypeError): + with pytest.raises(TypeError): self.inputs['a'] = Int(3) - with test_class.assertRaises(AttributeError): + with pytest.raises(AttributeError): self.inputs.pop('b') - with test_class.assertRaises(TypeError): + with pytest.raises(TypeError): self.inputs['c'] = Int(4) def step_two(self): # Verify that original inputs are still there with same value and no inputs were added - test_class.assertIn('a', self.inputs) - test_class.assertIn('b', self.inputs) - test_class.assertNotIn('c', self.inputs) - test_class.assertEqual(self.inputs['a'].value, 1) + assert 'a' in 
self.inputs + assert 'b' in self.inputs + assert 'c' not in self.inputs + assert self.inputs['a'].value == 1 run_and_check_success(FrozenDictWorkChain, a=Int(1), b=Int(2)) @@ -1262,7 +1246,6 @@ def test_immutable_input_groups(self): """ Check that namespaced inputs also return AttributeFrozendicts and are hence immutable """ - test_class = self class ImmutableGroups(WorkChain): @@ -1277,19 +1260,19 @@ def define(cls, spec): def step_one(self): # Attempt to manipulate the namespaced inputs dictionary which should raise - with test_class.assertRaises(TypeError): + with pytest.raises(TypeError): self.inputs.subspace['one'] = Int(3) - with test_class.assertRaises(AttributeError): + with pytest.raises(AttributeError): self.inputs.subspace.pop('two') - with test_class.assertRaises(TypeError): + with pytest.raises(TypeError): self.inputs.subspace['four'] = Int(4) def step_two(self): # Verify that original inputs are still there with same value and no inputs were added - test_class.assertIn('one', self.inputs.subspace) - test_class.assertIn('two', self.inputs.subspace) - test_class.assertNotIn('four', self.inputs.subspace) - test_class.assertEqual(self.inputs.subspace['one'].value, 1) + assert 'one' in self.inputs.subspace + assert 'two' in self.inputs.subspace + assert 'four' not in self.inputs.subspace + assert self.inputs.subspace['one'].value == 1 run_and_check_success(ImmutableGroups, subspace={'one': Int(1), 'two': Int(2)}) @@ -1315,18 +1298,17 @@ def do_test(self): @pytest.mark.requires_rmq -class TestSerializeWorkChain(AiidaTestCase): +class TestSerializeWorkChain: """ Test workchains with serialized input / output. 
""" - def setUp(self): - super().setUp() - self.assertIsNone(Process.current()) - - def tearDown(self): - super().tearDown() - self.assertIsNone(Process.current()) + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean): # pylint: disable=unused-argument + """Initialize the profile.""" + assert Process.current() is None + yield + assert Process.current() is None @staticmethod def test_serialize(): @@ -1442,7 +1424,8 @@ def do_run(self): @pytest.mark.requires_rmq -class TestWorkChainExpose(AiidaTestCase): +@pytest.mark.usefixtures('aiida_profile_clean') +class TestWorkChainExpose: """ Test the expose inputs / outputs functionality """ @@ -1462,21 +1445,19 @@ def test_expose(self): } }, ) - self.assertEqual( - res, { - 'a': Float(2.2), - 'sub_1': { - 'b': Float(2.3), - 'c': Bool(True) - }, - 'sub_2': { - 'b': Float(1.2), - 'sub_3': { - 'c': Bool(False) - } + assert res == { + 'a': Float(2.2), + 'sub_1': { + 'b': Float(2.3), + 'c': Bool(True) + }, + 'sub_2': { + 'b': Float(1.2), + 'sub_3': { + 'c': Bool(False) } } - ) + } def test_nested_expose(self): res = launch.run( @@ -1497,25 +1478,23 @@ def test_nested_expose(self): ) ) ) - self.assertEqual( - res, { + assert res == { + 'sub': { 'sub': { - 'sub': { - 'a': Float(2.2), - 'sub_1': { - 'b': Float(2.3), - 'c': Bool(True) - }, - 'sub_2': { - 'b': Float(1.2), - 'sub_3': { - 'c': Bool(False) - } + 'a': Float(2.2), + 'sub_1': { + 'b': Float(2.3), + 'c': Bool(True) + }, + 'sub_2': { + 'b': Float(1.2), + 'sub_3': { + 'c': Bool(False) } } } } - ) + } @pytest.mark.filterwarnings('ignore::UserWarning') def test_issue_1741_expose_inputs(self): @@ -1553,7 +1532,8 @@ def step1(self): @pytest.mark.requires_rmq -class TestWorkChainMisc(AiidaTestCase): +@pytest.mark.usefixtures('aiida_profile_clean') +class TestWorkChainMisc: class PointlessWorkChain(WorkChain): @@ -1585,12 +1565,13 @@ def test_run_pointless_workchain(): def test_global_submit_raises(self): """Using top-level submit should raise.""" - with 
self.assertRaises(exceptions.InvalidOperation): + with pytest.raises(exceptions.InvalidOperation): launch.run(TestWorkChainMisc.IllegalSubmitWorkChain) @pytest.mark.requires_rmq -class TestDefaultUniqueness(AiidaTestCase): +@pytest.mark.usefixtures('aiida_profile_clean') +class TestDefaultUniqueness: """Test that default inputs of exposed nodes will get unique UUIDS.""" class Parent(WorkChain): @@ -1638,4 +1619,4 @@ def test_unique_default_inputs(self): # Trying to load one of the inputs through the UUID should fail, # as both `child_one.a` and `child_two.a` should have the same UUID. node = load_node(uuid=node.get_incoming().get_node_by_label('child_one__a').uuid) - self.assertEqual(len(uuids), len(nodes), f'Only {len(uuids)} unique UUIDS for {len(nodes)} input nodes') + assert len(uuids) == len(nodes), f'Only {len(uuids)} unique UUIDS for {len(nodes)} input nodes' diff --git a/tests/engine/test_workfunctions.py b/tests/engine/test_workfunctions.py index e47ecae994..1255f229d0 100644 --- a/tests/engine/test_workfunctions.py +++ b/tests/engine/test_workfunctions.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=no-self-use """Tests for the workfunction decorator and WorkFunctionNode.""" import pytest @@ -14,19 +15,20 @@ from aiida.engine import Process, calcfunction, workfunction from aiida.manage.caching import enable_caching from aiida.orm import CalcFunctionNode, Int, WorkFunctionNode -from aiida.storage.testbase import AiidaTestCase @pytest.mark.requires_rmq -class TestWorkFunction(AiidaTestCase): +class TestWorkFunction: """Tests for workfunctions. .. 
note: tests common to all process functions should go in `tests.engine.test_process_function.py` """ - def setUp(self): - super().setUp() - self.assertIsNone(Process.current()) + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init + assert Process.current() is None self.default_int = Int(256) @workfunction @@ -35,21 +37,20 @@ def test_workfunction(data): self.test_workfunction = test_workfunction - def tearDown(self): - super().tearDown() - self.assertIsNone(Process.current()) + yield + assert Process.current() is None def test_workfunction_node_type(self): """Verify that a workfunction gets a WorkFunctionNode as node instance.""" _, node = self.test_workfunction.run_get_node(self.default_int) - self.assertIsInstance(node, WorkFunctionNode) + assert isinstance(node, WorkFunctionNode) def test_workfunction_links(self): """Verify that a workfunction can only get RETURN links and no CREATE links.""" _, node = self.test_workfunction.run_get_node(self.default_int) - self.assertEqual(len(node.get_outgoing(link_type=LinkType.RETURN).all()), 1) - self.assertEqual(len(node.get_outgoing(link_type=LinkType.CREATE).all()), 0) + assert len(node.get_outgoing(link_type=LinkType.RETURN).all()) == 1 + assert len(node.get_outgoing(link_type=LinkType.CREATE).all()) == 0 def test_workfunction_return_unstored(self): """Verify that a workfunction will raise when an unstored node is returned.""" @@ -58,16 +59,16 @@ def test_workfunction_return_unstored(self): def test_workfunction(): return Int(2) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): test_workfunction.run_get_node() def test_workfunction_default_linkname(self): """Verify that a workfunction that returns a single Data node gets a default link label.""" _, node = self.test_workfunction.run_get_node(self.default_int) - self.assertEqual(node.outputs.result, 
self.default_int) - self.assertEqual(getattr(node.outputs, Process.SINGLE_OUTPUT_LINKNAME), self.default_int) - self.assertEqual(node.outputs[Process.SINGLE_OUTPUT_LINKNAME], self.default_int) + assert node.outputs.result == self.default_int + assert getattr(node.outputs, Process.SINGLE_OUTPUT_LINKNAME) == self.default_int + assert node.outputs[Process.SINGLE_OUTPUT_LINKNAME] == self.default_int def test_workfunction_caching(self): """Verify that a workfunction cannot be cached.""" @@ -76,7 +77,7 @@ def test_workfunction_caching(self): # Caching should always be disabled for a WorkFunctionNode with enable_caching(): _, cached = self.test_workfunction.run_get_node(self.default_int) - self.assertFalse(cached.is_created_from_cache) + assert not cached.is_created_from_cache def test_call_link_label(self): """Verify that setting a `call_link_label` on a `calcfunction` called from a `workfunction` works.""" @@ -95,4 +96,4 @@ def caller(): # Verify that the `CALL` link of the calculation function is there with the correct label link_triple = node.get_outgoing(link_type=LinkType.CALL_CALC, link_label_filter=link_label).one() - self.assertIsInstance(link_triple.node, CalcFunctionNode) + assert isinstance(link_triple.node, CalcFunctionNode) diff --git a/tests/manage/configuration/test_options.py b/tests/manage/configuration/test_options.py index f9d783c250..213689f7a0 100644 --- a/tests/manage/configuration/test_options.py +++ b/tests/manage/configuration/test_options.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=no-self-use """Tests for the configuration options.""" import pytest @@ -14,46 +15,45 @@ from aiida.common.exceptions import ConfigurationError from aiida.manage.configuration import ConfigValidationError, get_config, get_config_option from 
aiida.manage.configuration.options import Option, get_option, get_option_names, parse_option -from aiida.storage.testbase import AiidaTestCase -class TestConfigurationOptions(AiidaTestCase): +@pytest.mark.usefixtures('config_with_profile') +class TestConfigurationOptions: """Tests for the Options class.""" def test_get_option_names(self): """Test `get_option_names` function.""" - self.assertIsInstance(get_option_names(), list) - self.assertEqual(len(get_option_names()), 27) + assert isinstance(get_option_names(), list) + assert len(get_option_names()) == 27 def test_get_option(self): """Test `get_option` function.""" - with self.assertRaises(ConfigurationError): + with pytest.raises(ConfigurationError): get_option('no_existing_option') option_name = get_option_names()[0] option = get_option(option_name) - self.assertIsInstance(option, Option) - self.assertEqual(option.name, option_name) + assert isinstance(option, Option) + assert option.name == option_name def test_parse_option(self): """Test `parse_option` function.""" - with self.assertRaises(ConfigValidationError): + with pytest.raises(ConfigValidationError): parse_option('logging.aiida_loglevel', 1) - with self.assertRaises(ConfigValidationError): + with pytest.raises(ConfigValidationError): parse_option('logging.aiida_loglevel', 'INVALID_LOG_LEVEL') def test_options(self): """Test that all defined options can be converted into Option namedtuples.""" for option_name in get_option_names(): option = get_option(option_name) - self.assertEqual(option.name, option_name) - self.assertIsInstance(option.description, str) + assert option.name == option_name + assert isinstance(option.description, str) option.valid_type # pylint: disable=pointless-statement option.default # pylint: disable=pointless-statement - @pytest.mark.usefixtures('config_with_profile') def test_get_config_option_default(self): """Tests that `get_option` return option default if not specified globally or for current profile.""" option_name = 
'logging.aiida_loglevel' @@ -61,9 +61,8 @@ def test_get_config_option_default(self): # If we haven't set the option explicitly, `get_config_option` should return the option default option_value = get_config_option(option_name) - self.assertEqual(option_value, option.default) + assert option_value == option.default - @pytest.mark.usefixtures('config_with_profile') def test_get_config_option_profile_specific(self): """Tests that `get_option` correctly gets a configuration option if specified for the current profile.""" config = get_config() @@ -75,9 +74,8 @@ def test_get_config_option_profile_specific(self): # Setting a specific value for the current profile which should then be returned by `get_config_option` config.set_option(option_name, option_value_profile, scope=profile.name) option_value = get_config_option(option_name) - self.assertEqual(option_value, option_value_profile) + assert option_value == option_value_profile - @pytest.mark.usefixtures('config_with_profile') def test_get_config_option_global(self): """Tests that `get_option` correctly agglomerates upwards and so retrieves globally set config options.""" config = get_config() @@ -88,4 +86,4 @@ def test_get_config_option_global(self): # Setting a specific value globally which should then be returned by `get_config_option` due to agglomeration config.set_option(option_name, option_value_global) option_value = get_config_option(option_name) - self.assertEqual(option_value, option_value_global) + assert option_value == option_value_global diff --git a/tests/manage/configuration/test_profile.py b/tests/manage/configuration/test_profile.py index cd2bc13c21..d2752edbd4 100644 --- a/tests/manage/configuration/test_profile.py +++ b/tests/manage/configuration/test_profile.py @@ -7,36 +7,37 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: 
disable=no-self-use """Tests for the Profile class.""" - import os import uuid +import pytest + from aiida.manage.configuration import Profile -from aiida.storage.testbase import AiidaTestCase from tests.utils.configuration import create_mock_profile -class TestProfile(AiidaTestCase): +class TestProfile: """Tests for the Profile class.""" - @classmethod - def setUpClass(cls, *args, **kwargs): - """Setup a mock profile.""" - super().setUpClass(*args, **kwargs) - cls.profile_name = 'test_profile' - cls.profile_dictionary = { + @pytest.fixture(autouse=True) + def init_profile(self): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init + self.profile_name = 'test_profile' + self.profile_dictionary = { 'default_user_email': 'dummy@localhost', 'storage': { 'backend': 'psql_dos', 'config': { 'database_engine': 'postgresql_psycopg2', - 'database_name': cls.profile_name, + 'database_name': self.profile_name, 'database_port': '5432', 'database_hostname': 'localhost', 'database_username': 'user', 'database_password': 'pass', - 'repository_uri': f"file:///{os.path.join('/some/path', f'repository_{cls.profile_name}')}", + 'repository_uri': f"file:///{os.path.join('/some/path', f'repository_{self.profile_name}')}", } }, 'process_control': { @@ -51,25 +52,25 @@ def setUpClass(cls, *args, **kwargs): } } } - cls.profile = Profile(cls.profile_name, cls.profile_dictionary) + self.profile = Profile(self.profile_name, self.profile_dictionary) def test_base_properties(self): """Test the basic properties of a Profile instance.""" - self.assertEqual(self.profile.name, self.profile_name) + assert self.profile.name == self.profile_name - self.assertEqual(self.profile.storage_backend, 'psql_dos') - self.assertEqual(self.profile.storage_config, self.profile_dictionary['storage']['config']) - self.assertEqual(self.profile.process_control_backend, 'rabbitmq') - self.assertEqual(self.profile.process_control_config, 
self.profile_dictionary['process_control']['config']) + assert self.profile.storage_backend == 'psql_dos' + assert self.profile.storage_config == self.profile_dictionary['storage']['config'] + assert self.profile.process_control_backend == 'rabbitmq' + assert self.profile.process_control_config == self.profile_dictionary['process_control']['config'] # Verify that the uuid property returns a valid UUID by attempting to construct an UUID instance from it uuid.UUID(self.profile.uuid) # Check that the default user email field is not None - self.assertIsNotNone(self.profile.default_user_email) + assert self.profile.default_user_email is not None # The RabbitMQ prefix should contain the profile UUID - self.assertIn(self.profile.uuid, self.profile.rmq_prefix) + assert self.profile.uuid in self.profile.rmq_prefix def test_is_test_profile(self): """Test that a profile whose name starts with `test_` is marked as a test profile.""" @@ -77,10 +78,10 @@ def test_is_test_profile(self): profile = create_mock_profile(name=profile_name) # The one constructed in the setUpClass should be a test profile - self.assertTrue(self.profile.is_test_profile) + assert self.profile.is_test_profile # The profile created here should *not* be a test profile - self.assertFalse(profile.is_test_profile) + assert not profile.is_test_profile def test_set_option(self): """Test the `set_option` method.""" @@ -90,12 +91,12 @@ def test_set_option(self): # Setting an option if it does not exist should work self.profile.set_option(option_key, option_value_one) - self.assertEqual(self.profile.get_option(option_key), option_value_one) + assert self.profile.get_option(option_key) == option_value_one # Setting it again will override it by default self.profile.set_option(option_key, option_value_two) - self.assertEqual(self.profile.get_option(option_key), option_value_two) + assert self.profile.get_option(option_key) == option_value_two # If we set override to False, it should not override, big surprise 
self.profile.set_option(option_key, option_value_one, override=False) - self.assertEqual(self.profile.get_option(option_key), option_value_two) + assert self.profile.get_option(option_key) == option_value_two diff --git a/tests/orm/implementation/test_comments.py b/tests/orm/implementation/test_comments.py index a380806014..30e7ac826e 100644 --- a/tests/orm/implementation/test_comments.py +++ b/tests/orm/implementation/test_comments.py @@ -7,29 +7,29 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=no-self-use """Unit tests for the BackendComment and BackendCommentCollection classes.""" from datetime import datetime from uuid import UUID +import pytest import pytz from aiida import orm from aiida.common import exceptions, timezone -from aiida.storage.testbase import AiidaTestCase -class TestBackendComment(AiidaTestCase): +class TestBackendComment: """Test BackendComment.""" - @classmethod - def setUpClass(cls, *args, **kwargs): - super().setUpClass(*args, **kwargs) - cls.computer = cls.computer.backend_entity # Unwrap the `Computer` instance to `BackendComputer` - cls.user = cls.backend.users.create(email='tester@localhost').store() - - def setUp(self): - super().setUp() + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean, aiida_localhost, backend): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init + self.backend = backend + self.computer = aiida_localhost.backend_entity # Unwrap the `Computer` instance to `BackendComputer` + self.user = backend.users.create(email='tester@localhost').store() self.node = self.backend.nodes.create( node_type='', user=self.user, computer=self.computer, label='label', description='description' ).store() @@ -51,46 +51,46 @@ def test_creation(self): comment = 
self.backend.comments.create(node=self.node, user=self.user, content=self.comment_content) # Before storing - self.assertIsNone(comment.id) - self.assertIsNone(comment.pk) - self.assertTrue(isinstance(comment.uuid, str)) - self.assertTrue(comment.node, self.node) - self.assertTrue(isinstance(comment.ctime, datetime)) - self.assertIsNone(comment.mtime) - self.assertTrue(comment.user, self.user) - self.assertEqual(comment.content, self.comment_content) + assert comment.id is None + assert comment.pk is None + assert isinstance(comment.uuid, str) + assert comment.node, self.node + assert isinstance(comment.ctime, datetime) + assert comment.mtime is None + assert comment.user, self.user + assert comment.content == self.comment_content # Store the comment.ctime before the store as a reference now = timezone.now() comment_ctime_before_store = comment.ctime - self.assertTrue(now > comment.ctime, f'{comment.ctime} is not smaller than now {now}') + assert now > comment.ctime, f'{comment.ctime} is not smaller than now {now}' comment.store() comment_ctime = comment.ctime comment_mtime = comment.mtime # The comment.ctime should have been unchanged, but the comment.mtime should have changed - self.assertEqual(comment.ctime, comment_ctime_before_store) - self.assertIsNotNone(comment.mtime) - self.assertTrue(now < comment.mtime, f'{comment.mtime} is not larger than now {now}') + assert comment.ctime == comment_ctime_before_store + assert comment.mtime is not None + assert now < comment.mtime, f'{comment.mtime} is not larger than now {now}' # After storing - self.assertTrue(isinstance(comment.id, int)) - self.assertTrue(isinstance(comment.pk, int)) - self.assertTrue(isinstance(comment.uuid, str)) - self.assertTrue(comment.node, self.node) - self.assertTrue(isinstance(comment.ctime, datetime)) - self.assertTrue(isinstance(comment.mtime, datetime)) - self.assertTrue(comment.user, self.user) - self.assertEqual(comment.content, self.comment_content) + assert isinstance(comment.id, 
int) + assert isinstance(comment.pk, int) + assert isinstance(comment.uuid, str) + assert comment.node, self.node + assert isinstance(comment.ctime, datetime) + assert isinstance(comment.mtime, datetime) + assert comment.user, self.user + assert comment.content == self.comment_content # Try to construct a UUID from the UUID value to prove that it has a valid UUID UUID(comment.uuid) # Change a column, which should trigger the save, update the mtime but leave the ctime untouched comment.set_content('test') - self.assertEqual(comment.ctime, comment_ctime) - self.assertTrue(comment.mtime > comment_mtime) + assert comment.ctime == comment_ctime + assert comment.mtime > comment_mtime def test_creation_with_time(self): """ @@ -105,14 +105,14 @@ def test_creation_with_time(self): ) # Check that the ctime and mtime are the given ones - self.assertEqual(comment.ctime, ctime) - self.assertEqual(comment.mtime, mtime) + assert comment.ctime == ctime + assert comment.mtime == mtime comment.store() # Check that the given values remain even after storing - self.assertEqual(comment.ctime, ctime) - self.assertEqual(comment.mtime, mtime) + assert comment.ctime == ctime + assert comment.mtime == mtime def test_delete(self): """Test `delete` method""" @@ -124,39 +124,37 @@ def test_delete(self): builder = orm.QueryBuilder().append(orm.Comment, project='uuid') no_of_comments = builder.count() found_comments_uuid = [_[0] for _ in builder.all()] - self.assertIn(comment_uuid, found_comments_uuid) + assert comment_uuid in found_comments_uuid # Delete Comment, making sure it was deleted self.backend.comments.delete(comment.id) builder = orm.QueryBuilder().append(orm.Comment, project='uuid') - self.assertEqual(builder.count(), no_of_comments - 1) + assert builder.count() == no_of_comments - 1 found_comments_uuid = [_[0] for _ in builder.all()] - self.assertNotIn(comment_uuid, found_comments_uuid) + assert comment_uuid not in found_comments_uuid def test_delete_all(self): """Test `delete_all` 
method""" self.create_comment().store() - self.assertGreater(len(orm.Comment.objects.all()), 0, msg='There should be Comments in the database') + assert len(orm.Comment.objects.all()) > 0, 'There should be Comments in the database' self.backend.comments.delete_all() - self.assertEqual(len(orm.Comment.objects.all()), 0, msg='All Comments should have been deleted') + assert len(orm.Comment.objects.all()) == 0, 'All Comments should have been deleted' def test_delete_many_no_filters(self): """Test `delete_many` method with empty filters""" self.create_comment().store() count = len(orm.Comment.objects.all()) - self.assertGreater(count, 0) + assert count > 0 # Pass empty filter to delete_many, making sure ValidationError is raised - with self.assertRaises(exceptions.ValidationError): + with pytest.raises(exceptions.ValidationError): self.backend.comments.delete_many({}) - self.assertEqual( - len(orm.Comment.objects.all()), - count, - msg='No Comments should have been deleted. There should still be {} Comment(s), ' + assert len(orm.Comment.objects.all()) == \ + count, \ + 'No Comments should have been deleted. 
There should still be {} Comment(s), ' \ 'however {} Comment(s) was/were found.'.format(count, len(orm.Comment.objects.all())) - ) def test_delete_many_ids(self): """Test `delete_many` method filtering on both `id` and `uuid`""" @@ -170,13 +168,9 @@ def test_delete_many_ids(self): # Make sure they exist count_comments_found = orm.QueryBuilder().append(orm.Comment, filters={'uuid': {'in': comment_uuids}}).count() - self.assertEqual( - count_comments_found, - len(comment_uuids), - msg='There should be {} Comments, instead {} Comment(s) was/were found'.format( - len(comment_uuids), count_comments_found - ) - ) + assert count_comments_found == \ + len(comment_uuids), \ + f'There should be {len(comment_uuids)} Comments, instead {count_comments_found} Comment(s) was/were found' # Delete last two comments (comment2, comment3) filters = {'or': [{'id': comment2.id}, {'uuid': str(comment3.uuid)}]} @@ -185,7 +179,7 @@ def test_delete_many_ids(self): # Check they were deleted builder = orm.QueryBuilder().append(orm.Comment, filters={'uuid': {'in': comment_uuids}}, project='uuid').all() found_comments_uuid = [_[0] for _ in builder] - self.assertEqual([comment_uuids[0]], found_comments_uuid) + assert [comment_uuids[0]] == found_comments_uuid def test_delete_many_dbnode_id(self): """Test `delete_many` method filtering on `dbnode_id`""" @@ -203,13 +197,9 @@ def test_delete_many_dbnode_id(self): # Make sure they exist count_comments_found = orm.QueryBuilder().append(orm.Comment, filters={'uuid': {'in': comment_uuids}}).count() - self.assertEqual( - count_comments_found, - len(comment_uuids), - msg='There should be {} Comments, instead {} Comment(s) was/were found'.format( - len(comment_uuids), count_comments_found - ) - ) + assert count_comments_found == \ + len(comment_uuids), \ + f'There should be {len(comment_uuids)} Comments, instead {count_comments_found} Comment(s) was/were found' # Delete comments for self.node filters = {'dbnode_id': self.node.id} @@ -218,7 +208,7 @@ def 
test_delete_many_dbnode_id(self): # Check they were deleted builder = orm.QueryBuilder().append(orm.Comment, filters={'uuid': {'in': comment_uuids}}, project='uuid').all() found_comments_uuid = [_[0] for _ in builder] - self.assertEqual([comment_uuids[0]], found_comments_uuid) + assert [comment_uuids[0]] == found_comments_uuid # pylint: disable=too-many-locals def test_delete_many_ctime_mtime(self): @@ -246,16 +236,16 @@ def test_delete_many_ctime_mtime(self): # Make sure they exist with the correct times builder = orm.QueryBuilder().append(orm.Comment, project=['ctime', 'mtime', 'uuid']) - self.assertGreater(builder.count(), 0) + assert builder.count() > 0 for comment in builder.all(): found_comments_ctime.append(comment[0]) found_comments_mtime.append(comment[1]) found_comments_uuid.append(comment[2]) for time, uuid in zip(comment_times, comment_uuids): - self.assertIn(time, found_comments_ctime) - self.assertIn(uuid, found_comments_uuid) + assert time in found_comments_ctime + assert uuid in found_comments_uuid if time != two_days_ago: - self.assertIn(time, found_comments_mtime) + assert time in found_comments_mtime # Delete comments that are created more than 1 hour ago, # unless they have been modified within 5 hours @@ -266,13 +256,13 @@ def test_delete_many_ctime_mtime(self): # Check only the most stale comment (comment3) was deleted builder = orm.QueryBuilder().append(orm.Comment, project='uuid') - self.assertGreater(builder.count(), 1) # There should still be at least 2 + assert builder.count() > 1 # There should still be at least 2 found_comments_uuid = [_[0] for _ in builder.all()] - self.assertNotIn(comment_uuids[2], found_comments_uuid) + assert comment_uuids[2] not in found_comments_uuid # Make sure the other comments were not deleted for comment_uuid in comment_uuids[:-1]: - self.assertIn(comment_uuid, found_comments_uuid) + assert comment_uuid in found_comments_uuid def test_delete_many_user_id(self): """Test `delete_many` method filtering on 
`user_id`""" @@ -288,10 +278,10 @@ def test_delete_many_user_id(self): # Make sure they exist builder = orm.QueryBuilder().append(orm.Comment, project='uuid') - self.assertGreater(builder.count(), 0) + assert builder.count() > 0 found_comments_uuid = [_[0] for _ in builder.all()] for comment_uuid in comment_uuids: - self.assertIn(comment_uuid, found_comments_uuid) + assert comment_uuid in found_comments_uuid # Delete last comments for `self.user` filters = {'user_id': self.user.id} @@ -300,12 +290,12 @@ def test_delete_many_user_id(self): # Check they were deleted builder = orm.QueryBuilder().append(orm.Comment, project='uuid') found_comments_uuid = [_[0] for _ in builder.all()] - self.assertGreater(builder.count(), 0) + assert builder.count() > 0 for comment_uuid in comment_uuids[1:]: - self.assertNotIn(comment_uuid, found_comments_uuid) + assert comment_uuid not in found_comments_uuid # Make sure the first comment (comment1) was not deleted - self.assertIn(comment_uuids[0], found_comments_uuid) + assert comment_uuids[0] in found_comments_uuid def test_deleting_non_existent_entities(self): """Test deleting non-existent Comments for different cases""" @@ -326,27 +316,25 @@ def test_deleting_non_existent_entities(self): # Try to delete non-existing Comment - using delete_many # delete_many should return an empty list deleted_entities = self.backend.comments.delete_many(filters={'id': id_}) - self.assertEqual( - deleted_entities, [], msg=f'No entities should have been deleted, since Comment id {id_} does not exist' - ) + assert deleted_entities == [], f'No entities should have been deleted, since Comment id {id_} does not exist' # Try to delete non-existing Comment - using delete # NotExistent should be raised, since no entities are found - with self.assertRaises(exceptions.NotExistent) as exc: + with pytest.raises(exceptions.NotExistent) as exc: self.backend.comments.delete(comment_id=id_) - self.assertIn(f"Comment with id '{id_}' not found", str(exc.exception)) + 
assert f"Comment with id '{id_}' not found" in str(exc.value) # Try to delete existing and non-existing Comment - using delete_many # delete_many should return a list that *only* includes the existing Comment filters = {'id': {'in': [id_, comment_id]}} deleted_entities = self.backend.comments.delete_many(filters=filters) - self.assertEqual([comment_id], - deleted_entities, - msg=f'Only Comment id {comment_id} should be returned from delete_many') + assert [comment_id] == \ + deleted_entities, \ + f'Only Comment id {comment_id} should be returned from delete_many' # Make sure the existing Comment was deleted builder = orm.QueryBuilder().append(orm.Comment, filters={'uuid': comment_uuid}) - self.assertEqual(builder.count(), 0) + assert builder.count() == 0 # Get a non-existent Node valid_node_found = True @@ -363,12 +351,10 @@ def test_deleting_non_existent_entities(self): filters = {'dbnode_id': id_} self.backend.comments.delete_many(filters=filters) comment_count_after = orm.QueryBuilder().append(orm.Comment).count() - self.assertEqual( - comment_count_after, - comment_count_before, - msg='The number of comments changed after performing `delete_many`, ' + assert comment_count_after == \ + comment_count_before, \ + 'The number of comments changed after performing `delete_many`, ' \ "while filtering for a non-existing 'dbnode_id'" - ) def test_delete_many_same_twice(self): """Test no exception is raised when entity is filtered by both `id` and `uuid`""" @@ -383,14 +369,14 @@ def test_delete_many_same_twice(self): # Make sure comment is removed builder = orm.QueryBuilder().append(orm.Comment, filters={'uuid': comment_uuid}) - self.assertEqual(builder.count(), 0) + assert builder.count() == 0 def test_delete_wrong_type(self): """Test TypeError is raised when `filters` is wrong type""" - with self.assertRaises(TypeError): + with pytest.raises(TypeError): self.backend.comments.delete(comment_id=None) def test_delete_many_wrong_type(self): """Test TypeError is raised when
`filters` is wrong type""" - with self.assertRaises(TypeError): + with pytest.raises(TypeError): self.backend.comments.delete_many(filters=None) diff --git a/tests/orm/implementation/test_logs.py b/tests/orm/implementation/test_logs.py index 7f7bcff730..8e8d6fc114 100644 --- a/tests/orm/implementation/test_logs.py +++ b/tests/orm/implementation/test_logs.py @@ -7,31 +7,31 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=no-self-use """Unit tests for the BackendLog and BackendLogCollection classes.""" from datetime import datetime import logging from uuid import UUID +import pytest import pytz from aiida import orm from aiida.common import exceptions, timezone from aiida.common.log import LOG_LEVEL_REPORT -from aiida.storage.testbase import AiidaTestCase -class TestBackendLog(AiidaTestCase): +class TestBackendLog: """Test BackendLog.""" - @classmethod - def setUpClass(cls, *args, **kwargs): - super().setUpClass(*args, **kwargs) - cls.computer = cls.computer.backend_entity # Unwrap the `Computer` instance to `BackendComputer` - cls.user = cls.backend.users.create(email='tester@localhost').store() - - def setUp(self): - super().setUp() + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean, aiida_localhost, backend): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init + self.backend = backend + self.computer = aiida_localhost.backend_entity # Unwrap the `Computer` instance to `BackendComputer` + self.user = self.backend.users.create(email='tester@localhost').store() self.node = self.backend.nodes.create( node_type='', user=self.user, computer=self.computer, label='label', description='description' ).store() @@ -56,34 +56,34 @@ def test_creation(self): log = self.create_log() # Before storing - 
self.assertIsNone(log.id) - self.assertIsNone(log.pk) - self.assertTrue(isinstance(log.uuid, str)) - self.assertTrue(isinstance(log.time, datetime)) - self.assertEqual(log.loggername, 'loggername') - self.assertTrue(isinstance(log.levelname, str)) - self.assertTrue(isinstance(log.dbnode_id, int)) - self.assertEqual(log.message, self.log_message) - self.assertEqual(log.metadata, {'content': 'test'}) + assert log.id is None + assert log.pk is None + assert isinstance(log.uuid, str) + assert isinstance(log.time, datetime) + assert log.loggername == 'loggername' + assert isinstance(log.levelname, str) + assert isinstance(log.dbnode_id, int) + assert log.message == self.log_message + assert log.metadata == {'content': 'test'} log.store() # After storing - self.assertTrue(isinstance(log.id, int)) - self.assertTrue(isinstance(log.pk, int)) - self.assertTrue(isinstance(log.uuid, str)) - self.assertTrue(isinstance(log.time, datetime)) - self.assertEqual(log.loggername, 'loggername') - self.assertTrue(isinstance(log.levelname, str)) - self.assertTrue(isinstance(log.dbnode_id, int)) - self.assertEqual(log.message, self.log_message) - self.assertEqual(log.metadata, {'content': 'test'}) + assert isinstance(log.id, int) + assert isinstance(log.pk, int) + assert isinstance(log.uuid, str) + assert isinstance(log.time, datetime) + assert log.loggername == 'loggername' + assert isinstance(log.levelname, str) + assert isinstance(log.dbnode_id, int) + assert log.message == self.log_message + assert log.metadata == {'content': 'test'} # Try to construct a UUID from the UUID value to prove that it has a valid UUID UUID(log.uuid) # Raise AttributeError when trying to change column - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): log.message = 'change message' def test_creation_with_static_time(self): @@ -96,15 +96,15 @@ def test_creation_with_static_time(self): log = self.create_log(time=time) # Check that the time is the given one - 
self.assertEqual(log.time, time) + assert log.time == time # Store - self.assertFalse(log.is_stored) + assert not log.is_stored log.store() - self.assertTrue(log.is_stored) + assert log.is_stored # Check that the given value remains even after storing - self.assertEqual(log.time, time) + assert log.time == time def test_delete(self): """Test `delete` method""" @@ -116,39 +116,37 @@ def test_delete(self): builder = orm.QueryBuilder().append(orm.Log, project='uuid') no_of_logs = builder.count() found_logs_uuid = [_[0] for _ in builder.all()] - self.assertIn(log_uuid, found_logs_uuid) + assert log_uuid in found_logs_uuid # Delete Log, making sure it was deleted self.backend.logs.delete(log.id) builder = orm.QueryBuilder().append(orm.Log, project='uuid') - self.assertEqual(builder.count(), no_of_logs - 1) + assert builder.count() == no_of_logs - 1 found_logs_uuid = [_[0] for _ in builder.all()] - self.assertNotIn(log_uuid, found_logs_uuid) + assert log_uuid not in found_logs_uuid def test_delete_all(self): """Test `delete_all` method""" self.create_log().store() - self.assertGreater(len(orm.Log.objects.all()), 0, msg='There should be Logs in the database') + assert len(orm.Log.objects.all()) > 0, 'There should be Logs in the database' self.backend.logs.delete_all() - self.assertEqual(len(orm.Log.objects.all()), 0, msg='All Logs should have been deleted') + assert len(orm.Log.objects.all()) == 0, 'All Logs should have been deleted' def test_delete_many_no_filters(self): """Test `delete_many` method with empty filters""" self.create_log().store() count = len(orm.Log.objects.all()) - self.assertGreater(count, 0) + assert count > 0 # Pass empty filter to delete_many, making sure ValidationError is raised - with self.assertRaises(exceptions.ValidationError): + with pytest.raises(exceptions.ValidationError): self.backend.logs.delete_many({}) - self.assertEqual( - len(orm.Log.objects.all()), - count, - msg='No Logs should have been deleted. 
There should still be {} Log(s), ' + assert len(orm.Log.objects.all()) == \ + count, \ + 'No Logs should have been deleted. There should still be {} Log(s), ' \ 'however {} Log(s) was/were found.'.format(count, len(orm.Log.objects.all())) - ) def test_delete_many_ids(self): """Test `delete_many` method filtering on both `id` and `uuid`""" @@ -163,11 +161,9 @@ def test_delete_many_ids(self): # Make sure they exist count_logs_found = orm.QueryBuilder().append(orm.Log, filters={'uuid': {'in': log_uuids}}).count() - self.assertEqual( - count_logs_found, - len(log_uuids), - msg=f'There should be {len(log_uuids)} Logs, instead {count_logs_found} Log(s) was/were found' - ) + assert count_logs_found == \ + len(log_uuids), \ + f'There should be {len(log_uuids)} Logs, instead {count_logs_found} Log(s) was/were found' # Delete last two logs (log2, log3) filters = {'or': [{'id': log2.id}, {'uuid': str(log3.uuid)}]} @@ -176,7 +172,7 @@ def test_delete_many_ids(self): # Check they were deleted builder = orm.QueryBuilder().append(orm.Log, filters={'uuid': {'in': log_uuids}}, project='uuid').all() found_logs_uuid = [_[0] for _ in builder] - self.assertEqual([log_uuids[0]], found_logs_uuid) + assert [log_uuids[0]] == found_logs_uuid def test_delete_many_dbnode_id(self): """Test `delete_many` method filtering on `dbnode_id`""" @@ -194,11 +190,9 @@ def test_delete_many_dbnode_id(self): # Make sure they exist count_logs_found = orm.QueryBuilder().append(orm.Log, filters={'uuid': {'in': log_uuids}}).count() - self.assertEqual( - count_logs_found, - len(log_uuids), - msg=f'There should be {len(log_uuids)} Logs, instead {count_logs_found} Log(s) was/were found' - ) + assert count_logs_found == \ + len(log_uuids), \ + f'There should be {len(log_uuids)} Logs, instead {count_logs_found} Log(s) was/were found' # Delete logs for self.node filters = {'dbnode_id': self.node.id} @@ -207,7 +201,7 @@ def test_delete_many_dbnode_id(self): # Check they were deleted builder = 
orm.QueryBuilder().append(orm.Log, filters={'uuid': {'in': log_uuids}}, project='uuid').all() found_logs_uuid = [_[0] for _ in builder] - self.assertEqual([log_uuids[0]], found_logs_uuid) + assert [log_uuids[0]] == found_logs_uuid def test_delete_many_time(self): """Test `delete_many` method filtering on `time`""" @@ -233,14 +227,14 @@ def test_delete_many_time(self): # Make sure they exist with the correct times builder = orm.QueryBuilder().append(orm.Log, project=['time', 'uuid']) - self.assertGreater(builder.count(), 0) + assert builder.count() > 0 for log in builder.all(): found_logs_time.append(log[0]) found_logs_uuid.append(log[1]) for log_time in log_times: - self.assertIn(log_time, found_logs_time) + assert log_time in found_logs_time for log_uuid in log_uuids: - self.assertIn(log_uuid, found_logs_uuid) + assert log_uuid in found_logs_uuid # Delete logs that are older than 1 hour turning_point = now - timedelta(seconds=60 * 60) @@ -249,13 +243,13 @@ def test_delete_many_time(self): # Check they were deleted builder = orm.QueryBuilder().append(orm.Log, project='uuid') - self.assertGreater(builder.count(), 0) # There should still be at least 1 + assert builder.count() > 0 # There should still be at least 1 found_logs_uuid = [_[0] for _ in builder.all()] for log_uuid in log_uuids[1:]: - self.assertNotIn(log_uuid, found_logs_uuid) + assert log_uuid not in found_logs_uuid # Make sure the newest log (log1) was not deleted - self.assertIn(log_uuids[0], found_logs_uuid) + assert log_uuids[0] in found_logs_uuid def test_deleting_non_existent_entities(self): """Test deleting non-existent Logs for different cases""" @@ -277,25 +271,23 @@ def test_deleting_non_existent_entities(self): # Try to delete non-existing Log - using delete_many # delete_many should return an empty list deleted_entities = self.backend.logs.delete_many(filters={'id': id_}) - self.assertEqual( - deleted_entities, [], msg=f'No entities should have been deleted, since Log id {id_} does not exist' - 
) + assert deleted_entities == [], f'No entities should have been deleted, since Log id {id_} does not exist' # Try to delete non-existing Log - using delete # NotExistent should be raised, since no entities are found - with self.assertRaises(exceptions.NotExistent) as exc: + with pytest.raises(exceptions.NotExistent) as exc: self.backend.logs.delete(log_id=id_) - self.assertIn(f"Log with id '{id_}' not found", str(exc.exception)) + assert f"Log with id '{id_}' not found" in str(exc) # Try to delete existing and non-existing Log - using delete_many # delete_many should return a list that *only* includes the existing Logs filters = {'id': {'in': [id_, log_id]}} deleted_entities = self.backend.logs.delete_many(filters=filters) - self.assertEqual([log_id], deleted_entities, msg=f'Only Log id {log_id} should be returned from delete_many') + assert [log_id] == deleted_entities, f'Only Log id {log_id} should be returned from delete_many' # Make sure the existing Log was deleted builder = orm.QueryBuilder().append(orm.Log, filters={'uuid': log_uuid}) - self.assertEqual(builder.count(), 0) + assert builder.count() == 0 # Get a non-existent Node valid_node_found = True @@ -312,12 +304,10 @@ def test_deleting_non_existent_entities(self): filters = {'dbnode_id': id_} self.backend.logs.delete_many(filters=filters) log_count_after = orm.QueryBuilder().append(orm.Log).count() - self.assertEqual( - log_count_after, - log_count_before, - msg='The number of logs changed after performing `delete_many`, ' + assert log_count_after == \ + log_count_before, \ + 'The number of logs changed after performing `delete_many`, ' \ "while filtering for a non-existing 'dbnode_id'" - ) def test_delete_many_same_twice(self): """Test no exception is raised when entity is filtered by both `id` and `uuid`""" @@ -332,14 +322,14 @@ def test_delete_many_same_twice(self): # Make sure log is removed builder = orm.QueryBuilder().append(orm.Log, filters={'uuid': log_uuid}) - 
self.assertEqual(builder.count(), 0) + assert builder.count() == 0 def test_delete_wrong_type(self): """Test TypeError is raised when `filters` is wrong type""" - with self.assertRaises(TypeError): + with pytest.raises(TypeError): self.backend.logs.delete(log_id=None) def test_delete_many_wrong_type(self): """Test TypeError is raised when `filters` is wrong type""" - with self.assertRaises(TypeError): + with pytest.raises(TypeError): self.backend.logs.delete_many(filters=None) diff --git a/tests/orm/implementation/test_nodes.py b/tests/orm/implementation/test_nodes.py index 22faff9bf8..a2249c2c50 100644 --- a/tests/orm/implementation/test_nodes.py +++ b/tests/orm/implementation/test_nodes.py @@ -7,34 +7,33 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=too-many-public-methods +# pylint: disable=too-many-public-methods,no-self-use """Unit tests for the BackendNode and BackendNodeCollection classes.""" from collections import OrderedDict from datetime import datetime from uuid import UUID +import pytest import pytz from aiida.common import exceptions, timezone -from aiida.storage.testbase import AiidaTestCase -class TestBackendNode(AiidaTestCase): +class TestBackendNode: """Test BackendNode.""" - @classmethod - def setUpClass(cls, *args, **kwargs): - super().setUpClass(*args, **kwargs) - cls.computer = cls.computer.backend_entity # Unwrap the `Computer` instance to `BackendComputer` - cls.user = cls.backend.users.create(email='tester@localhost').store() - - def setUp(self): - super().setUp() + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean, aiida_localhost, backend): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init + self.backend = backend + self.computer = aiida_localhost.backend_entity # 
Unwrap the `Computer` instance to `BackendComputer` + self.user = backend.users.create(email='tester@localhost').store() self.node_type = '' self.node_label = 'label' self.node_description = 'description' - self.node = self.backend.nodes.create( + self.node = backend.nodes.create( node_type=self.node_type, user=self.user, computer=self.computer, @@ -52,54 +51,54 @@ def test_creation(self): ) # Before storing - self.assertIsNone(node.id) - self.assertIsNone(node.pk) - self.assertTrue(isinstance(node.uuid, str)) - self.assertTrue(isinstance(node.ctime, datetime)) - self.assertIsNone(node.mtime) - self.assertIsNone(node.process_type) - self.assertEqual(node.attributes, {}) - self.assertEqual(node.extras, {}) - self.assertEqual(node.repository_metadata, {}) - self.assertEqual(node.node_type, self.node_type) - self.assertEqual(node.label, self.node_label) - self.assertEqual(node.description, self.node_description) + assert node.id is None + assert node.pk is None + assert isinstance(node.uuid, str) + assert isinstance(node.ctime, datetime) + assert node.mtime is None + assert node.process_type is None + assert node.attributes == {} + assert node.extras == {} + assert node.repository_metadata == {} + assert node.node_type == self.node_type + assert node.label == self.node_label + assert node.description == self.node_description # Store the node.ctime before the store as a reference now = timezone.now() node_ctime_before_store = node.ctime - self.assertTrue(now > node.ctime, f'{node.ctime} is not smaller than now {now}') + assert now > node.ctime, f'{node.ctime} is not smaller than now {now}' node.store() node_ctime = node.ctime node_mtime = node.mtime # The node.ctime should have been unchanged, but the node.mtime should have changed - self.assertEqual(node.ctime, node_ctime_before_store) - self.assertIsNotNone(node.mtime) - self.assertTrue(now < node.mtime, f'{node.mtime} is not larger than now {now}') + assert node.ctime == node_ctime_before_store + assert node.mtime 
is not None + assert now < node.mtime, f'{node.mtime} is not larger than now {now}' # After storing - self.assertTrue(isinstance(node.id, int)) - self.assertTrue(isinstance(node.pk, int)) - self.assertTrue(isinstance(node.uuid, str)) - self.assertTrue(isinstance(node.ctime, datetime)) - self.assertTrue(isinstance(node.mtime, datetime)) - self.assertIsNone(node.process_type) - self.assertEqual(node.attributes, {}) - self.assertEqual(node.extras, {}) - self.assertEqual(node.repository_metadata, {}) - self.assertEqual(node.node_type, self.node_type) - self.assertEqual(node.label, self.node_label) - self.assertEqual(node.description, self.node_description) + assert isinstance(node.id, int) + assert isinstance(node.pk, int) + assert isinstance(node.uuid, str) + assert isinstance(node.ctime, datetime) + assert isinstance(node.mtime, datetime) + assert node.process_type is None + assert node.attributes == {} + assert node.extras == {} + assert node.repository_metadata == {} + assert node.node_type == self.node_type + assert node.label == self.node_label + assert node.description == self.node_description # Try to construct a UUID from the UUID value to prove that it has a valid UUID UUID(node.uuid) # Change a column, which should trigger the save, update the mtime but leave the ctime untouched node.label = 'test' - self.assertEqual(node.ctime, node_ctime) - self.assertTrue(node.mtime > node_mtime) + assert node.ctime == node_ctime + assert node.mtime > node_mtime def test_creation_with_time(self): """ @@ -119,14 +118,14 @@ def test_creation_with_time(self): ) # Check that the ctime and mtime are the given ones - self.assertEqual(node.ctime, ctime) - self.assertEqual(node.mtime, mtime) + assert node.ctime == ctime + assert node.mtime == mtime node.store() # Check that the given values remain even after storing - self.assertEqual(node.ctime, ctime) - self.assertEqual(node.mtime, mtime) + assert node.ctime == ctime + assert node.mtime == mtime def test_mtime(self): """Test 
the `mtime` is automatically updated when a database field is updated.""" @@ -134,7 +133,7 @@ def test_mtime(self): node_mtime = node.mtime node.label = 'changed label' - self.assertTrue(node.mtime > node_mtime) + assert node.mtime > node_mtime def test_clone(self): """Test the `clone` method.""" @@ -142,14 +141,14 @@ def test_clone(self): clone = node.clone() # Check that the clone is unstored, i.e. has *no* id, has a different UUID, but all other props are the same - self.assertIsNone(clone.id) - self.assertNotEqual(clone.uuid, node.uuid) - self.assertEqual(clone.label, node.label) - self.assertEqual(clone.description, node.description) - self.assertEqual(clone.user.id, node.user.id) - self.assertEqual(clone.computer.id, node.computer.id) - self.assertEqual(clone.attributes, node.attributes) - self.assertEqual(clone.extras, node.extras) + assert clone.id is None + assert clone.uuid != node.uuid + assert clone.label == node.label + assert clone.description == node.description + assert clone.user.id == node.user.id + assert clone.computer.id == node.computer.id + assert clone.attributes == node.attributes + assert clone.extras == node.extras def test_property_setters(self): """Test the property setters of a BackendNode.""" @@ -159,22 +158,22 @@ def test_property_setters(self): self.node.label = label self.node.description = description - self.assertEqual(self.node.label, label) - self.assertEqual(self.node.description, description) + assert self.node.label == label + assert self.node.description == description def test_computer_methods(self): """Test the computer methods of a BackendNode.""" new_computer = self.backend.computers.create(label='localhost2', hostname='localhost').store() - self.assertEqual(self.node.computer.id, self.computer.id) + assert self.node.computer.id == self.computer.id self.node.computer = new_computer - self.assertEqual(self.node.computer.id, new_computer.id) + assert self.node.computer.id == new_computer.id def test_user_methods(self): 
"""Test the user methods of a BackendNode.""" new_user = self.backend.users.create(email='newuser@localhost').store() - self.assertEqual(self.node.user.id, self.user.id) + assert self.node.user.id == self.user.id self.node.user = new_user - self.assertEqual(self.node.user.id, new_user.id) + assert self.node.user.id == new_user.id def test_get_set_attribute(self): """Test the `get_attribute` and `set_attribute` method of a BackendNode.""" @@ -185,45 +184,45 @@ def test_get_set_attribute(self): attribute_2_value = '2' attribute_3_value = '3' - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): self.node.get_attribute(attribute_1_name) - self.assertFalse(self.node.is_stored) + assert not self.node.is_stored self.node.set_attribute(attribute_1_name, attribute_1_value) # Check that the attribute is set, but the node is not stored - self.assertFalse(self.node.is_stored) - self.assertEqual(self.node.get_attribute(attribute_1_name), attribute_1_value) + assert not self.node.is_stored + assert self.node.get_attribute(attribute_1_name) == attribute_1_value self.node.store() # Check that the attribute is set, and the node is stored - self.assertTrue(self.node.is_stored) - self.assertEqual(self.node.get_attribute(attribute_1_name), attribute_1_value) + assert self.node.is_stored + assert self.node.get_attribute(attribute_1_name) == attribute_1_value self.node.set_attribute(attribute_2_name, attribute_2_value) - self.assertEqual(self.node.get_attribute(attribute_1_name), attribute_1_value) - self.assertEqual(self.node.get_attribute(attribute_2_name), attribute_2_value) + assert self.node.get_attribute(attribute_1_name) == attribute_1_value + assert self.node.get_attribute(attribute_2_name) == attribute_2_value reloaded = self.backend.nodes.get(self.node.pk) - self.assertEqual(self.node.get_attribute(attribute_1_name), attribute_1_value) - self.assertEqual(self.node.get_attribute(attribute_2_name), attribute_2_value) + assert 
self.node.get_attribute(attribute_1_name) == attribute_1_value + assert self.node.get_attribute(attribute_2_name) == attribute_2_value reloaded.set_attribute(attribute_3_name, attribute_3_value) - self.assertEqual(reloaded.get_attribute(attribute_1_name), attribute_1_value) - self.assertEqual(reloaded.get_attribute(attribute_2_name), attribute_2_value) - self.assertEqual(reloaded.get_attribute(attribute_3_name), attribute_3_value) + assert reloaded.get_attribute(attribute_1_name) == attribute_1_value + assert reloaded.get_attribute(attribute_2_name) == attribute_2_value + assert reloaded.get_attribute(attribute_3_name) == attribute_3_value # Check deletion of a single reloaded.delete_attribute(attribute_1_name) - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): reloaded.get_attribute(attribute_1_name) - self.assertEqual(reloaded.get_attribute(attribute_2_name), attribute_2_value) - self.assertEqual(reloaded.get_attribute(attribute_3_name), attribute_3_value) + assert reloaded.get_attribute(attribute_2_name) == attribute_2_value + assert reloaded.get_attribute(attribute_3_name) == attribute_3_value - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): self.node.get_attribute(attribute_1_name) def test_get_set_extras(self): @@ -235,86 +234,86 @@ def test_get_set_extras(self): extra_2_value = '2' extra_3_value = '3' - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): self.node.get_extra(extra_1_name) - self.assertFalse(self.node.is_stored) + assert not self.node.is_stored self.node.set_extra(extra_1_name, extra_1_value) # Check that the extra is set, but the node is not stored - self.assertFalse(self.node.is_stored) - self.assertEqual(self.node.get_extra(extra_1_name), extra_1_value) + assert not self.node.is_stored + assert self.node.get_extra(extra_1_name) == extra_1_value self.node.store() # Check that the extra is set, and the node is stored - self.assertTrue(self.node.is_stored) - 
self.assertEqual(self.node.get_extra(extra_1_name), extra_1_value) + assert self.node.is_stored + assert self.node.get_extra(extra_1_name) == extra_1_value self.node.set_extra(extra_2_name, extra_2_value) - self.assertEqual(self.node.get_extra(extra_1_name), extra_1_value) - self.assertEqual(self.node.get_extra(extra_2_name), extra_2_value) + assert self.node.get_extra(extra_1_name) == extra_1_value + assert self.node.get_extra(extra_2_name) == extra_2_value reloaded = self.backend.nodes.get(self.node.pk) - self.assertEqual(self.node.get_extra(extra_1_name), extra_1_value) - self.assertEqual(self.node.get_extra(extra_2_name), extra_2_value) + assert self.node.get_extra(extra_1_name) == extra_1_value + assert self.node.get_extra(extra_2_name) == extra_2_value reloaded.set_extra(extra_3_name, extra_3_value) - self.assertEqual(reloaded.get_extra(extra_1_name), extra_1_value) - self.assertEqual(reloaded.get_extra(extra_2_name), extra_2_value) - self.assertEqual(reloaded.get_extra(extra_3_name), extra_3_value) + assert reloaded.get_extra(extra_1_name) == extra_1_value + assert reloaded.get_extra(extra_2_name) == extra_2_value + assert reloaded.get_extra(extra_3_name) == extra_3_value # Check deletion of a single reloaded.delete_extra(extra_1_name) - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): reloaded.get_extra(extra_1_name) - self.assertEqual(reloaded.get_extra(extra_2_name), extra_2_value) - self.assertEqual(reloaded.get_extra(extra_3_name), extra_3_value) + assert reloaded.get_extra(extra_2_name) == extra_2_value + assert reloaded.get_extra(extra_3_name) == extra_3_value - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): self.node.get_extra(extra_1_name) def test_attributes(self): """Test the `BackendNode.attributes` property.""" node = self.create_node() - self.assertEqual(node.attributes, {}) + assert node.attributes == {} node.set_attribute('attribute', 'value') - self.assertEqual(node.attributes, 
{'attribute': 'value'}) + assert node.attributes == {'attribute': 'value'} node.store() - self.assertEqual(node.attributes, {'attribute': 'value'}) + assert node.attributes == {'attribute': 'value'} def test_get_attribute(self): """Test the `BackendNode.get_attribute` method.""" node = self.create_node() - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): node.get_attribute('attribute') node.set_attribute('attribute', 'value') - self.assertEqual(node.get_attribute('attribute'), 'value') + assert node.get_attribute('attribute') == 'value' node.store() - self.assertEqual(node.get_attribute('attribute'), 'value') + assert node.get_attribute('attribute') == 'value' def test_get_attribute_many(self): """Test the `BackendNode.get_attribute_many` method.""" node = self.create_node() - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): node.get_attribute_many(['attribute']) node.set_attribute_many({'attribute': 'value', 'another': 'case'}) - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): node.get_attribute_many(['attribute', 'unexisting']) - self.assertEqual(node.get_attribute_many(['attribute', 'another']), ['value', 'case']) + assert node.get_attribute_many(['attribute', 'another']) == ['value', 'case'] node.store() - self.assertEqual(node.get_attribute_many(['attribute', 'another']), ['value', 'case']) + assert node.get_attribute_many(['attribute', 'another']) == ['value', 'case'] def test_set_attribute(self): """Test the `BackendNode.set_attribute` method.""" @@ -325,16 +324,16 @@ def test_set_attribute(self): node.set_attribute('attribute_valid', 'value') # Calling store should cause the values to be cleaned which should raise - with self.assertRaises(exceptions.ValidationError): + with pytest.raises(exceptions.ValidationError): node.store() # Replace the original invalid with a valid value node.set_attribute('attribute_invalid', 'actually valid') node.store() - 
self.assertEqual(node.get_attribute_many(['attribute_invalid', 'attribute_valid']), ['actually valid', 'value']) + assert node.get_attribute_many(['attribute_invalid', 'attribute_valid']) == ['actually valid', 'value'] # Raises immediately when setting it if already stored - with self.assertRaises(exceptions.ValidationError): + with pytest.raises(exceptions.ValidationError): node.set_attribute('attribute', object()) def test_set_attribute_many(self): @@ -346,23 +345,23 @@ def test_set_attribute_many(self): node.set_attribute_many({'attribute_invalid': object(), 'attribute_valid': 'value'}) # Calling store should cause the values to be cleaned which should raise - with self.assertRaises(exceptions.ValidationError): + with pytest.raises(exceptions.ValidationError): node.store() # Replace the original invalid with a valid value node.set_attribute_many({'attribute_invalid': 'actually valid'}) node.store() - self.assertEqual(node.get_attribute_many(['attribute_invalid', 'attribute_valid']), ['actually valid', 'value']) + assert node.get_attribute_many(['attribute_invalid', 'attribute_valid']) == ['actually valid', 'value'] attributes = OrderedDict() attributes['another_attribute'] = 'value' attributes['attribute_invalid'] = object() # Raises immediately when setting it if already stored - with self.assertRaises(exceptions.ValidationError): + with pytest.raises(exceptions.ValidationError): node.set_attribute_many(attributes) - self.assertTrue('another_attribute' not in node.attributes) + assert 'another_attribute' not in node.attributes attributes = {'attribute_one': 1, 'attribute_two': 2} # Calling `set_attribute_many` on a stored node @@ -370,7 +369,7 @@ def test_set_attribute_many(self): node.store() node.set_attribute_many(attributes) - self.assertEqual(node.attributes, attributes) + assert node.attributes == attributes def test_reset_attributes(self): """Test the `BackendNode.reset_attributes` method.""" @@ -380,79 +379,79 @@ def test_reset_attributes(self): # Reset 
attributes on an unstored node node.set_attribute_many(attributes_before) - self.assertEqual(node.attributes, attributes_before) + assert node.attributes == attributes_before node.reset_attributes(attributes_after) - self.assertEqual(node.attributes, attributes_after) + assert node.attributes == attributes_after # Reset attributes on stored node node = self.create_node() node.store() node.set_attribute_many(attributes_before) - self.assertEqual(node.attributes, attributes_before) + assert node.attributes == attributes_before node.reset_attributes(attributes_after) - self.assertEqual(node.attributes, attributes_after) + assert node.attributes == attributes_after def test_delete_attribute(self): """Test the `BackendNode.delete_attribute` method.""" node = self.create_node() - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): node.delete_attribute('notexisting') node.set_attribute('attribute', 'value') node.delete_attribute('attribute') - self.assertEqual(node.attributes, {}) + assert node.attributes == {} # Now for a stored node node = self.create_node().store() - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): node.delete_attribute('notexisting') node.set_attribute('attribute', 'value') node.delete_attribute('attribute') - self.assertEqual(node.attributes, {}) + assert node.attributes == {} def test_delete_attribute_many(self): """Test the `BackendNode.delete_attribute_many` method.""" node = self.create_node() attributes = {'attribute_one': 1, 'attribute_two': 2} - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): node.delete_attribute_many(['notexisting', 'some']) node.set_attribute_many(attributes) - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): node.delete_attribute_many(['attribute_one', 'notexisting']) # Because one key failed during delete, none of the attributes should have been deleted - self.assertTrue('attribute_one' in node.attributes) + 
assert 'attribute_one' in node.attributes # Now delete the keys that actually should exist node.delete_attribute_many(attributes.keys()) - self.assertEqual(node.attributes, {}) + assert node.attributes == {} # Now for a stored node node = self.create_node().store() - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): node.delete_attribute_many(['notexisting', 'some']) node.set_attribute_many(attributes) - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): node.delete_attribute_many(['attribute_one', 'notexisting']) # Because one key failed during delete, none of the attributes should have been deleted - self.assertTrue('attribute_one' in node.attributes) + assert 'attribute_one' in node.attributes # Now delete the keys that actually should exist node.delete_attribute_many(attributes.keys()) - self.assertEqual(node.attributes, {}) + assert node.attributes == {} def test_clear_attributes(self): """Test the `BackendNode.clear_attributes` method.""" @@ -460,17 +459,17 @@ def test_clear_attributes(self): attributes = {'attribute_one': 1, 'attribute_two': 2} node.set_attribute_many(attributes) - self.assertEqual(node.attributes, attributes) + assert node.attributes == attributes node.clear_attributes() - self.assertEqual(node.attributes, {}) + assert node.attributes == {} # Now for a stored node node = self.create_node().store() node.set_attribute_many(attributes) - self.assertEqual(node.attributes, attributes) + assert node.attributes == attributes node.clear_attributes() - self.assertEqual(node.attributes, {}) + assert node.attributes == {} def test_attribute_items(self): """Test the `BackendNode.attribute_items` generator.""" @@ -478,14 +477,14 @@ def test_attribute_items(self): attributes = {'attribute_one': 1, 'attribute_two': 2} node.set_attribute_many(attributes) - self.assertEqual(attributes, dict(node.attributes_items())) + assert attributes == dict(node.attributes_items()) # Repeat for a stored node node = 
self.create_node().store() attributes = {'attribute_one': 1, 'attribute_two': 2} node.set_attribute_many(attributes) - self.assertEqual(attributes, dict(node.attributes_items())) + assert attributes == dict(node.attributes_items()) def test_attribute_keys(self): """Test the `BackendNode.attribute_keys` generator.""" @@ -493,14 +492,14 @@ def test_attribute_keys(self): attributes = {'attribute_one': 1, 'attribute_two': 2} node.set_attribute_many(attributes) - self.assertEqual(set(attributes), set(node.attributes_keys())) + assert set(attributes) == set(node.attributes_keys()) # Repeat for a stored node node = self.create_node().store() attributes = {'attribute_one': 1, 'attribute_two': 2} node.set_attribute_many(attributes) - self.assertEqual(set(attributes), set(node.attributes_keys())) + assert set(attributes) == set(node.attributes_keys()) def test_attribute_flush_specifically(self): """Test that changing `attributes` only flushes that property and does not affect others like extras. @@ -522,47 +521,47 @@ def test_attribute_flush_specifically(self): # Reload the node yet again and verify that the `extra_three` extra is still there rereloaded = self.backend.nodes.get(node.pk) - self.assertIn('extra_three', rereloaded.extras.keys()) + assert 'extra_three' in rereloaded.extras.keys() def test_extras(self): """Test the `BackendNode.extras` property.""" node = self.create_node() - self.assertEqual(node.extras, {}) + assert node.extras == {} node.set_extra('extra', 'value') - self.assertEqual(node.extras, {'extra': 'value'}) + assert node.extras == {'extra': 'value'} node.store() - self.assertEqual(node.extras, {'extra': 'value'}) + assert node.extras == {'extra': 'value'} def test_get_extra(self): """Test the `BackendNode.get_extra` method.""" node = self.create_node() - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): node.get_extra('extra') node.set_extra('extra', 'value') - self.assertEqual(node.get_extra('extra'), 'value') + assert 
node.get_extra('extra') == 'value' node.store() - self.assertEqual(node.get_extra('extra'), 'value') + assert node.get_extra('extra') == 'value' def test_get_extra_many(self): """Test the `BackendNode.get_extra_many` method.""" node = self.create_node() - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): node.get_extra_many(['extra']) node.set_extra_many({'extra': 'value', 'another': 'case'}) - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): node.get_extra_many(['extra', 'unexisting']) - self.assertEqual(node.get_extra_many(['extra', 'another']), ['value', 'case']) + assert node.get_extra_many(['extra', 'another']) == ['value', 'case'] node.store() - self.assertEqual(node.get_extra_many(['extra', 'another']), ['value', 'case']) + assert node.get_extra_many(['extra', 'another']) == ['value', 'case'] def test_set_extra(self): """Test the `BackendNode.set_extra` method.""" @@ -573,16 +572,16 @@ def test_set_extra(self): node.set_extra('extra_valid', 'value') # Calling store should cause the values to be cleaned which should raise - with self.assertRaises(exceptions.ValidationError): + with pytest.raises(exceptions.ValidationError): node.store() # Replace the original invalid with a valid value node.set_extra('extra_invalid', 'actually valid') node.store() - self.assertEqual(node.get_extra_many(['extra_invalid', 'extra_valid']), ['actually valid', 'value']) + assert node.get_extra_many(['extra_invalid', 'extra_valid']) == ['actually valid', 'value'] # Raises immediately when setting it if already stored - with self.assertRaises(exceptions.ValidationError): + with pytest.raises(exceptions.ValidationError): node.set_extra('extra', object()) def test_set_extra_many(self): @@ -594,23 +593,23 @@ def test_set_extra_many(self): node.set_extra_many({'extra_invalid': object(), 'extra_valid': 'value'}) # Calling store should cause the values to be cleaned which should raise - with 
self.assertRaises(exceptions.ValidationError): + with pytest.raises(exceptions.ValidationError): node.store() # Replace the original invalid with a valid value node.set_extra_many({'extra_invalid': 'actually valid'}) node.store() - self.assertEqual(node.get_extra_many(['extra_invalid', 'extra_valid']), ['actually valid', 'value']) + assert node.get_extra_many(['extra_invalid', 'extra_valid']) == ['actually valid', 'value'] extras = OrderedDict() extras['another_extra'] = 'value' extras['extra_invalid'] = object() # Raises immediately when setting it if already stored - with self.assertRaises(exceptions.ValidationError): + with pytest.raises(exceptions.ValidationError): node.set_extra_many(extras) - self.assertTrue('another_extra' not in node.extras) + assert 'another_extra' not in node.extras extras = {'extra_one': 1, 'extra_two': 2} # Calling `set_extra_many` on a stored node @@ -618,7 +617,7 @@ def test_set_extra_many(self): node.store() node.set_extra_many(extras) - self.assertEqual(node.extras, extras) + assert node.extras == extras def test_reset_extras(self): """Test the `BackendNode.reset_extras` method.""" @@ -628,79 +627,79 @@ def test_reset_extras(self): # Reset extras on an unstored node node.set_extra_many(extras_before) - self.assertEqual(node.extras, extras_before) + assert node.extras == extras_before node.reset_extras(extras_after) - self.assertEqual(node.extras, extras_after) + assert node.extras == extras_after # Reset extras on stored node node = self.create_node() node.store() node.set_extra_many(extras_before) - self.assertEqual(node.extras, extras_before) + assert node.extras == extras_before node.reset_extras(extras_after) - self.assertEqual(node.extras, extras_after) + assert node.extras == extras_after def test_delete_extra(self): """Test the `BackendNode.delete_extra` method.""" node = self.create_node() - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): node.delete_extra('notexisting') node.set_extra('extra', 
'value') node.delete_extra('extra') - self.assertEqual(node.extras, {}) + assert node.extras == {} # Now for a stored node node = self.create_node().store() - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): node.delete_extra('notexisting') node.set_extra('extra', 'value') node.delete_extra('extra') - self.assertEqual(node.extras, {}) + assert node.extras == {} def test_delete_extra_many(self): """Test the `BackendNode.delete_extra_many` method.""" node = self.create_node() extras = {'extra_one': 1, 'extra_two': 2} - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): node.delete_extra_many(['notexisting', 'some']) node.set_extra_many(extras) - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): node.delete_extra_many(['extra_one', 'notexisting']) # Because one key failed during delete, none of the extras should have been deleted - self.assertTrue('extra_one' in node.extras) + assert 'extra_one' in node.extras # Now delete the keys that actually should exist node.delete_extra_many(extras.keys()) - self.assertEqual(node.extras, {}) + assert node.extras == {} # Now for a stored node node = self.create_node().store() - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): node.delete_extra_many(['notexisting', 'some']) node.set_extra_many(extras) - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): node.delete_extra_many(['extra_one', 'notexisting']) # Because one key failed during delete, none of the extras should have been deleted - self.assertTrue('extra_one' in node.extras) + assert 'extra_one' in node.extras # Now delete the keys that actually should exist node.delete_extra_many(extras.keys()) - self.assertEqual(node.extras, {}) + assert node.extras == {} def test_clear_extras(self): """Test the `BackendNode.clear_extras` method.""" @@ -708,17 +707,17 @@ def test_clear_extras(self): extras = {'extra_one': 1, 'extra_two': 2} 
node.set_extra_many(extras) - self.assertEqual(node.extras, extras) + assert node.extras == extras node.clear_extras() - self.assertEqual(node.extras, {}) + assert node.extras == {} # Now for a stored node node = self.create_node().store() node.set_extra_many(extras) - self.assertEqual(node.extras, extras) + assert node.extras == extras node.clear_extras() - self.assertEqual(node.extras, {}) + assert node.extras == {} def test_extra_items(self): """Test the `BackendNode.extra_items` generator.""" @@ -726,14 +725,14 @@ def test_extra_items(self): extras = {'extra_one': 1, 'extra_two': 2} node.set_extra_many(extras) - self.assertEqual(extras, dict(node.extras_items())) + assert extras == dict(node.extras_items()) # Repeat for a stored node node = self.create_node().store() extras = {'extra_one': 1, 'extra_two': 2} node.set_extra_many(extras) - self.assertEqual(extras, dict(node.extras_items())) + assert extras == dict(node.extras_items()) def test_extra_keys(self): """Test the `BackendNode.extra_keys` generator.""" @@ -741,14 +740,14 @@ def test_extra_keys(self): extras = {'extra_one': 1, 'extra_two': 2} node.set_extra_many(extras) - self.assertEqual(set(extras), set(node.extras_keys())) + assert set(extras) == set(node.extras_keys()) # Repeat for a stored node node = self.create_node().store() extras = {'extra_one': 1, 'extra_two': 2} node.set_extra_many(extras) - self.assertEqual(set(extras), set(node.extras_keys())) + assert set(extras) == set(node.extras_keys()) def test_extra_flush_specifically(self): """Test that changing `extras` only flushes that property and does not affect others like attributes. 
@@ -770,4 +769,4 @@ def test_extra_flush_specifically(self): # Reload the node yet again and verify that the `attribute_three` attribute is still there rereloaded = self.backend.nodes.get(node.pk) - self.assertIn('attribute_three', rereloaded.attributes.keys()) + assert 'attribute_three' in rereloaded.attributes.keys() diff --git a/tests/orm/implementation/test_utils.py b/tests/orm/implementation/test_utils.py index f7e7d60925..a47e85f13b 100644 --- a/tests/orm/implementation/test_utils.py +++ b/tests/orm/implementation/test_utils.py @@ -7,15 +7,17 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=no-self-use """Unit tests for the backend non-specific utility methods.""" import math +import pytest + from aiida.common import exceptions from aiida.orm.implementation.utils import FIELD_SEPARATOR, clean_value, validate_attribute_extra_key -from aiida.storage.testbase import AiidaTestCase -class TestOrmImplementationUtils(AiidaTestCase): +class TestOrmImplementationUtils: """Test the utility methods in aiida.orm.implementation.utils""" def test_invalid_attribute_extra_key(self): @@ -23,10 +25,10 @@ def test_invalid_attribute_extra_key(self): non_string_key = 5 field_separator_key = f'invalid{FIELD_SEPARATOR}key' - with self.assertRaises(exceptions.ValidationError): + with pytest.raises(exceptions.ValidationError): validate_attribute_extra_key(non_string_key) - with self.assertRaises(exceptions.ValidationError): + with pytest.raises(exceptions.ValidationError): validate_attribute_extra_key(field_separator_key) def test_invalid_value(self): @@ -34,8 +36,8 @@ def test_invalid_value(self): nan_value = math.nan inf_value = math.inf - with self.assertRaises(exceptions.ValidationError): + with pytest.raises(exceptions.ValidationError): clean_value(nan_value) - with 
self.assertRaises(exceptions.ValidationError): + with pytest.raises(exceptions.ValidationError): clean_value(inf_value) diff --git a/tests/orm/nodes/data/test_kpoints.py b/tests/orm/nodes/data/test_kpoints.py index 1a8005d2ae..3cf123f4c1 100644 --- a/tests/orm/nodes/data/test_kpoints.py +++ b/tests/orm/nodes/data/test_kpoints.py @@ -7,20 +7,21 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=no-self-use """Tests for the `KpointsData` class.""" - import numpy as np +import pytest from aiida.orm import KpointsData, StructureData, load_node -from aiida.storage.testbase import AiidaTestCase -class TestKpoints(AiidaTestCase): +class TestKpoints: """Test for the `Kpointsdata` class.""" - @classmethod - def setUpClass(cls, *args, **kwargs): - super().setUpClass(*args, **kwargs) + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init alat = 5.430 # angstrom cell = [[ @@ -32,14 +33,14 @@ def setUpClass(cls, *args, **kwargs): 0.5 * alat, 0.5 * alat, ], [0.5 * alat, 0., 0.5 * alat]] - cls.alat = alat + self.alat = alat structure = StructureData(cell=cell) structure.append_atom(position=(0.000 * alat, 0.000 * alat, 0.000 * alat), symbols=['Si']) structure.append_atom(position=(0.250 * alat, 0.250 * alat, 0.250 * alat), symbols=['Si']) - cls.structure = structure + self.structure = structure # Define the expected reciprocal cell val = 2. 
* np.pi / alat - cls.expected_reciprocal_cell = np.array([[val, val, -val], [-val, val, val], [val, -val, val]]) + self.expected_reciprocal_cell = np.array([[val, val, -val], [-val, val, val], [val, -val, val]]) def test_reciprocal_cell(self): """ @@ -50,12 +51,12 @@ def test_reciprocal_cell(self): kpt = KpointsData() kpt.set_cell_from_structure(self.structure) - self.assertEqual(np.abs(kpt.reciprocal_cell - self.expected_reciprocal_cell).sum(), 0.) + assert np.abs(kpt.reciprocal_cell - self.expected_reciprocal_cell).sum() == 0. # Check also after storing kpt.store() kpt2 = load_node(kpt.pk) - self.assertEqual(np.abs(kpt2.reciprocal_cell - self.expected_reciprocal_cell).sum(), 0.) + assert np.abs(kpt2.reciprocal_cell - self.expected_reciprocal_cell).sum() == 0. def test_get_kpoints(self): """Test the `get_kpoints` method.""" @@ -73,11 +74,11 @@ def test_get_kpoints(self): ] kpt.set_kpoints(kpoints) - self.assertEqual(np.abs(kpt.get_kpoints() - np.array(kpoints)).sum(), 0.) - self.assertEqual(np.abs(kpt.get_kpoints(cartesian=True) - np.array(cartesian_kpoints)).sum(), 0.) + assert np.abs(kpt.get_kpoints() - np.array(kpoints)).sum() == 0. + assert np.abs(kpt.get_kpoints(cartesian=True) - np.array(cartesian_kpoints)).sum() == 0. # Check also after storing kpt.store() kpt2 = load_node(kpt.pk) - self.assertEqual(np.abs(kpt2.get_kpoints() - np.array(kpoints)).sum(), 0.) - self.assertEqual(np.abs(kpt2.get_kpoints(cartesian=True) - np.array(cartesian_kpoints)).sum(), 0.) + assert np.abs(kpt2.get_kpoints() - np.array(kpoints)).sum() == 0. + assert np.abs(kpt2.get_kpoints(cartesian=True) - np.array(cartesian_kpoints)).sum() == 0. 
diff --git a/tests/orm/nodes/data/test_orbital.py b/tests/orm/nodes/data/test_orbital.py index 09d9c57701..8239b5f206 100644 --- a/tests/orm/nodes/data/test_orbital.py +++ b/tests/orm/nodes/data/test_orbital.py @@ -7,24 +7,26 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=no-self-use """Tests for the `OrbitalData` class.""" - import copy +import pytest + from aiida.common import ValidationError from aiida.orm import OrbitalData from aiida.plugins import OrbitalFactory -from aiida.storage.testbase import AiidaTestCase -class TestOrbitalData(AiidaTestCase): +class TestOrbitalData: """Test for the `OrbitalData` class.""" - @classmethod - def setUpClass(cls, *args, **kwargs): - super().setUpClass(*args, **kwargs) + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init - cls.my_real_hydrogen_dict = { + self.my_real_hydrogen_dict = { 'angular_momentum': -3, 'diffusivity': None, 'kind_name': 'As', @@ -44,25 +46,27 @@ def test_real_hydrogen(self): orbitaldata = OrbitalData() #Check that there is a failure if get_orbital is called for setting orbitals - self.assertRaises(AttributeError, orbitaldata.get_orbitals) + with pytest.raises(AttributeError): + orbitaldata.get_orbitals() #Check that only one orbital has been assiigned orbitaldata.set_orbitals(orbitals=orbital) - self.assertEqual(len(orbitaldata.get_orbitals()), 1) + assert len(orbitaldata.get_orbitals()) == 1 #Check the orbital dict has been assigned correctly retrieved_real_hydrogen_dict = orbitaldata.get_orbitals()[0].get_orbital_dict() - self.assertEqual(retrieved_real_hydrogen_dict.pop('_orbital_type'), 'core.realhydrogen') - self.assertDictEqual(retrieved_real_hydrogen_dict, 
self.my_real_hydrogen_dict) + assert retrieved_real_hydrogen_dict.pop('_orbital_type') == 'core.realhydrogen' + assert retrieved_real_hydrogen_dict == self.my_real_hydrogen_dict #Check that a corrupted OribtalData fails on get_orbitals corrupted_orbitaldata = copy.deepcopy(orbitaldata) del corrupted_orbitaldata.get_attribute('orbital_dicts')[0]['_orbital_type'] - self.assertRaises(ValidationError, corrupted_orbitaldata.get_orbitals) + with pytest.raises(ValidationError): + corrupted_orbitaldata.get_orbitals() #Check that clear_orbitals empties the data of orbitals orbitaldata.clear_orbitals() - self.assertEqual(len(orbitaldata.get_orbitals()), 0) + assert len(orbitaldata.get_orbitals()) == 0 # verdi -p test_aiida devel tests db.orm.data.orbital diff --git a/tests/orm/nodes/data/test_trajectory.py b/tests/orm/nodes/data/test_trajectory.py index aaf6abe284..d73835126e 100644 --- a/tests/orm/nodes/data/test_trajectory.py +++ b/tests/orm/nodes/data/test_trajectory.py @@ -1,24 +1,25 @@ # -*- coding: utf-8 -*- +# pylint: disable=no-self-use """Tests for the `TrajectoryData` class.""" - import numpy as np +import pytest from aiida.orm import TrajectoryData, load_node -from aiida.storage.testbase import AiidaTestCase -class TestTrajectory(AiidaTestCase): +class TestTrajectory: """Test for the `TrajectoryData` class.""" - @classmethod - def setUpClass(cls, *args, **kwargs): - super().setUpClass(*args, **kwargs) + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean, aiida_localhost): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init n_atoms = 5 n_steps = 30 - cls.symbols = [chr(_) for _ in range(ord('A'), ord('A') + n_atoms)] - cls.positions = np.array(np.arange(n_steps * n_atoms * 3).reshape(n_steps, n_atoms, 3), dtype=float) + self.symbols = [chr(_) for _ in range(ord('A'), ord('A') + n_atoms)] + self.positions = np.array(np.arange(n_steps * n_atoms * 3).reshape(n_steps, n_atoms, 3), 
dtype=float) def test_get_attribute_tryexcept_default(self): """ @@ -36,7 +37,7 @@ def test_get_attribute_tryexcept_default(self): positions_unit = 'A' except KeyError: times_unit = 'FAILED_tryexc' - self.assertEqual(positions_unit, 'A') + assert positions_unit == 'A' try: times_unit = tjd.get_attribute('units|times') @@ -44,7 +45,7 @@ def test_get_attribute_tryexcept_default(self): times_unit = 'ps' except KeyError: times_unit = 'FAILED_tryexc' - self.assertEqual(times_unit, 'ps') + assert times_unit == 'ps' positions = 1 try: @@ -55,21 +56,21 @@ def test_get_attribute_tryexcept_default(self): pass except KeyError: positions = 'FAILED_tryexc' - self.assertEqual(positions, 1) + assert positions == 1 def test_units(self): """Test the setting of units attributes.""" tjd = TrajectoryData() tjd.set_attribute('units|positions', 'some_random_pos_unit') - self.assertEqual(tjd.get_attribute('units|positions'), 'some_random_pos_unit') + assert tjd.get_attribute('units|positions') == 'some_random_pos_unit' tjd.set_attribute('units|times', 'some_random_time_unit') - self.assertEqual(tjd.get_attribute('units|times'), 'some_random_time_unit') + assert tjd.get_attribute('units|times') == 'some_random_time_unit' # Test after storing tjd.set_trajectory(self.symbols, self.positions) tjd.store() tjd2 = load_node(tjd.pk) - self.assertEqual(tjd2.get_attribute('units|positions'), 'some_random_pos_unit') - self.assertEqual(tjd2.get_attribute('units|times'), 'some_random_time_unit') + assert tjd2.get_attribute('units|positions') == 'some_random_pos_unit' + assert tjd2.get_attribute('units|times') == 'some_random_time_unit' diff --git a/tests/orm/nodes/data/test_upf.py b/tests/orm/nodes/data/test_upf.py index 4922aa1bcf..7c0a4e19de 100644 --- a/tests/orm/nodes/data/test_upf.py +++ b/tests/orm/nodes/data/test_upf.py @@ -7,23 +7,20 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # 
########################################################################### -# pylint: disable=unspecified-encoding +# pylint: disable=unspecified-encoding,no-self-use """ This module contains tests for UpfData and UpfData related functions. """ -import errno import json import os -import shutil -import tempfile import numpy from numpy import array, isclose +import pytest from aiida import orm from aiida.common.exceptions import ParsingError from aiida.orm.nodes.data.upf import parse_upf -from aiida.storage.testbase import AiidaTestCase from tests.static import STATIC_DIR @@ -77,48 +74,31 @@ def compare(dd1, dd2): return all(compare(dd1, dd2)) -class TestUpfParser(AiidaTestCase): +class TestUpfParser: """Tests UPF version / element_name parser function.""" - @classmethod - def setUpClass(cls, *args, **kwargs): - super().setUpClass(*args, **kwargs) + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean, tmp_path): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init filepath_base = os.path.abspath(os.path.join(STATIC_DIR, 'pseudos')) - cls.filepath_barium = os.path.join(filepath_base, 'Ba.pbesol-spn-rrkjus_psl.0.2.3-tot-pslib030.UPF') - cls.filepath_oxygen = os.path.join(filepath_base, 'O.pbesol-n-rrkjus_psl.0.1-tested-pslib030.UPF') - cls.filepath_carbon = os.path.join(filepath_base, 'C_pbe_v1.2.uspp.F.UPF') - cls.pseudo_barium = orm.UpfData(file=cls.filepath_barium).store() - cls.pseudo_oxygen = orm.UpfData(file=cls.filepath_oxygen).store() - cls.pseudo_carbon = orm.UpfData(file=cls.filepath_carbon).store() - - def setUp(self): - """Setup a temporary directory to store UPF files.""" - self.temp_dir = tempfile.mkdtemp() - - def tearDown(self): - """Delete all groups and destroy the temporary directory created.""" - for group in orm.UpfFamily.objects.find(): - orm.UpfFamily.objects.delete(group.pk) - - try: - shutil.rmtree(self.temp_dir) - except OSError as exception: - if 
exception.errno == errno.ENOENT: - pass - elif exception.errno == errno.ENOTDIR: - os.remove(self.temp_dir) - else: - raise IOError(exception) + self.filepath_barium = os.path.join(filepath_base, 'Ba.pbesol-spn-rrkjus_psl.0.2.3-tot-pslib030.UPF') + self.filepath_oxygen = os.path.join(filepath_base, 'O.pbesol-n-rrkjus_psl.0.1-tested-pslib030.UPF') + self.filepath_carbon = os.path.join(filepath_base, 'C_pbe_v1.2.uspp.F.UPF') + self.pseudo_barium = orm.UpfData(file=self.filepath_barium).store() + self.pseudo_oxygen = orm.UpfData(file=self.filepath_oxygen).store() + self.pseudo_carbon = orm.UpfData(file=self.filepath_carbon).store() + self.temp_dir = str(tmp_path) def test_constructor(self): """Tests for the constructor of `UpfData`.""" filename = 'C.some_custom_filename.upf' upf = orm.UpfData(file=self.filepath_carbon, filename=filename) - self.assertEqual(upf.filename, filename) + assert upf.filename == filename # Store and check that the filename is unchanged upf.store() - self.assertEqual(upf.filename, filename) + assert upf.filename == filename def test_get_upf_family_names(self): """Test the `UpfData.get_upf_family_names` method.""" @@ -128,8 +108,8 @@ def test_get_upf_family_names(self): family.add_nodes([self.pseudo_barium]) family.store() - self.assertEqual({group.label for group in orm.UpfFamily.objects.all()}, {label}) - self.assertEqual(self.pseudo_barium.get_upf_family_names(), [label]) + assert {group.label for group in orm.UpfFamily.objects.all()} == {label} + assert self.pseudo_barium.get_upf_family_names() == [label] def test_get_upf_groups(self): """Test the `UpfData.get_upf_groups` class method.""" @@ -138,36 +118,36 @@ def test_get_upf_groups(self): user = orm.User(email='alternate@localhost').store() - self.assertEqual(orm.UpfFamily.objects.all(), []) + assert orm.UpfFamily.objects.all() == [] # Create group with default user and add `Ba` pseudo family_01, _ = orm.UpfFamily.objects.get_or_create(label=label_01) 
family_01.add_nodes([self.pseudo_barium]) family_01.store() - self.assertEqual({group.label for group in orm.UpfData.get_upf_groups()}, {label_01}) + assert {group.label for group in orm.UpfData.get_upf_groups()} == {label_01} # Create group with different user and add `O` pseudo family_02, _ = orm.UpfFamily.objects.get_or_create(label=label_02, user=user) family_02.add_nodes([self.pseudo_oxygen]) family_02.store() - self.assertEqual({group.label for group in orm.UpfData.get_upf_groups()}, {label_01, label_02}) + assert {group.label for group in orm.UpfData.get_upf_groups()} == {label_01, label_02} # Filter on a given user - self.assertEqual({group.label for group in orm.UpfData.get_upf_groups(user=user.email)}, {label_02}) + assert {group.label for group in orm.UpfData.get_upf_groups(user=user.email)} == {label_02} # Filter on a given element groups = {group.label for group in orm.UpfData.get_upf_groups(filter_elements='O')} - self.assertEqual(groups, {label_02}) + assert groups == {label_02} # Filter on a given element and user groups = {group.label for group in orm.UpfData.get_upf_groups(filter_elements='O', user=user.email)} - self.assertEqual(groups, {label_02}) + assert groups == {label_02} # Filter on element and user that should not match anything groups = {group.label for group in orm.UpfData.get_upf_groups(filter_elements='Ba', user=user.email)} - self.assertEqual(groups, set([])) + assert groups == set([]) def test_upf_version_one(self): """Check if parsing for regular UPF file (version 1) succeeds.""" @@ -189,8 +169,8 @@ def test_upf_version_one(self): # try to parse version / element name from UPF file contents parsed_data = parse_upf(path_to_upf, check_filename=True) # check that parsed data matches the expected one - self.assertEqual(parsed_data['version'], '1') - self.assertEqual(parsed_data['element'], 'O') + assert parsed_data['version'] == '1' + assert parsed_data['element'] == 'O' def test_upf_version_two(self): """Check if parsing for regular 
UPF file (version 2) succeeds.""" @@ -212,8 +192,8 @@ def test_upf_version_two(self): # try to parse version / element name from UPF file contents parsed_data = parse_upf(path_to_upf, check_filename=True) # check that parsed data matches the expected one - self.assertEqual(parsed_data['version'], '2.0.1') - self.assertEqual(parsed_data['element'], 'Al') + assert parsed_data['version'] == '2.0.1' + assert parsed_data['element'] == 'Al' def test_additional_header_line(self): """Regression #2228: check if parsing succeeds if additional header line is present.""" @@ -237,8 +217,8 @@ def test_additional_header_line(self): # try to parse version / element name from UPF file contents parsed_data = parse_upf(path_to_upf, check_filename=True) # check that parsed data matches the expected one - self.assertEqual(parsed_data['version'], '2.0.1') - self.assertEqual(parsed_data['element'], 'Pt') + assert parsed_data['version'] == '2.0.1' + assert parsed_data['element'] == 'Pt' def test_check_filename(self): """Test built-in check for if file name matches element""" @@ -258,7 +238,7 @@ def test_check_filename(self): with open(path_to_upf, 'w') as upf_file: upf_file.write(upf_contents) # Check if parser raises the desired ParsingError - with self.assertRaises(ParsingError): + with pytest.raises(ParsingError): _ = parse_upf(path_to_upf, check_filename=True) def test_missing_element_upf_v2(self): @@ -279,7 +259,7 @@ def test_missing_element_upf_v2(self): with open(path_to_upf, 'w') as upf_file: upf_file.write(upf_contents) # Check if parser raises the desired ParsingError - with self.assertRaises(ParsingError): + with pytest.raises(ParsingError): _ = parse_upf(path_to_upf, check_filename=True) def test_invalid_element_upf_v2(self): @@ -300,7 +280,7 @@ def test_invalid_element_upf_v2(self): with open(path_to_upf, 'w') as upf_file: upf_file.write(upf_contents) # Check if parser raises the desired ParsingError - with self.assertRaises(ParsingError): + with pytest.raises(ParsingError): 
_ = parse_upf(path_to_upf, check_filename=True) def test_missing_element_upf_v1(self): @@ -321,7 +301,7 @@ def test_missing_element_upf_v1(self): with open(path_to_upf, 'w') as upf_file: upf_file.write(upf_contents) # Check if parser raises the desired ParsingError - with self.assertRaises(ParsingError): + with pytest.raises(ParsingError): _ = parse_upf(path_to_upf, check_filename=True) def test_upf1_to_json_carbon(self): @@ -335,7 +315,7 @@ def test_upf1_to_json_carbon(self): # remove path information pp_dict['pseudo_potential']['header']['original_upf_file'] = '' reference_dict['pseudo_potential']['header']['original_upf_file'] = '' - self.assertTrue(compare_dicts(pp_dict, reference_dict)) + assert compare_dicts(pp_dict, reference_dict) def test_upf2_to_json_barium(self): """Test UPF check Bariium UPF1 pp conversion""" @@ -348,7 +328,7 @@ def test_upf2_to_json_barium(self): # remove path information pp_dict['pseudo_potential']['header']['original_upf_file'] = '' reference_dict['pseudo_potential']['header']['original_upf_file'] = '' - self.assertTrue(compare_dicts(pp_dict, reference_dict)) + assert compare_dicts(pp_dict, reference_dict) def test_invalid_element_upf_v1(self): """Test parsers exception on invalid element name in UPF v1.""" @@ -368,5 +348,5 @@ def test_invalid_element_upf_v1(self): with open(path_to_upf, 'w') as upf_file: upf_file.write(upf_contents) # Check if parser raises the desired ParsingError - with self.assertRaises(ParsingError): + with pytest.raises(ParsingError): _ = parse_upf(path_to_upf, check_filename=True) diff --git a/tests/orm/nodes/test_calcjob.py b/tests/orm/nodes/test_calcjob.py index a963c50021..89b0a32386 100644 --- a/tests/orm/nodes/test_calcjob.py +++ b/tests/orm/nodes/test_calcjob.py @@ -7,31 +7,39 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: 
disable=no-self-use """Tests for the `CalcJobNode` node sub class.""" import io +import pytest + from aiida.common import CalcJobState, LinkType from aiida.orm import CalcJobNode, FolderData -from aiida.storage.testbase import AiidaTestCase -class TestCalcJobNode(AiidaTestCase): +class TestCalcJobNode: """Tests for the `CalcJobNode` node sub class.""" + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean, aiida_localhost): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init + self.computer = aiida_localhost + def test_get_set_state(self): """Test the `get_state` and `set_state` method.""" node = CalcJobNode(computer=self.computer,) - self.assertEqual(node.get_state(), None) + assert node.get_state() is None - with self.assertRaises(ValueError): + with pytest.raises(ValueError): node.set_state('INVALID') node.set_state(CalcJobState.UPLOADING) - self.assertEqual(node.get_state(), CalcJobState.UPLOADING) + assert node.get_state() == CalcJobState.UPLOADING # Setting an illegal calculation job state, the `get_state` should not fail but return `None` node.set_attribute(node.CALC_JOB_STATE_KEY, 'INVALID') - self.assertEqual(node.get_state(), None) + assert node.get_state() is None def test_get_scheduler_stdout(self): """Verify that the repository sandbox folder is cleaned after the node instance is garbage collected.""" @@ -57,7 +65,7 @@ def test_get_scheduler_stdout(self): # It should return `None` if no scheduler output is there (file not there, or option not set), # while it should return the content if both are set - self.assertEqual(node.get_scheduler_stdout(), stdout if with_file and with_option else None) + assert node.get_scheduler_stdout() == (stdout if with_file and with_option else None) def test_get_scheduler_stderr(self): """Verify that the repository sandbox folder is cleaned after the node instance is garbage collected.""" @@ -83,4 +91,4 @@ def 
test_get_scheduler_stderr(self): # It should return `None` if no scheduler output is there (file not there, or option not set), # while it should return the content if both are set - self.assertEqual(node.get_scheduler_stderr(), stderr if with_file and with_option else None) + assert node.get_scheduler_stderr() == (stderr if with_file and with_option else None) diff --git a/tests/orm/test_authinfos.py b/tests/orm/test_authinfos.py index d4a4127f4d..e28e341313 100644 --- a/tests/orm/test_authinfos.py +++ b/tests/orm/test_authinfos.py @@ -7,21 +7,22 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=no-self-use """Unit tests for the AuthInfo ORM class.""" +import pytest from aiida.common import exceptions from aiida.orm import authinfos -from aiida.storage.testbase import AiidaTestCase -class TestAuthinfo(AiidaTestCase): +class TestAuthinfo: """Unit tests for the AuthInfo ORM class.""" - def setUp(self): - super().setUp() - for auth_info in authinfos.AuthInfo.objects.all(): - authinfos.AuthInfo.objects.delete(auth_info.pk) - + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean, aiida_localhost): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init + self.computer = aiida_localhost self.auth_info = self.computer.configure() # pylint: disable=no-member def test_set_auth_params(self): @@ -29,15 +30,15 @@ def test_set_auth_params(self): auth_params = {'safe_interval': 100} self.auth_info.set_auth_params(auth_params) - self.assertEqual(self.auth_info.get_auth_params(), auth_params) + assert self.auth_info.get_auth_params() == auth_params def test_delete(self): """Test deleting a single AuthInfo.""" pk = self.auth_info.pk - self.assertEqual(len(authinfos.AuthInfo.objects.all()), 1) + assert 
len(authinfos.AuthInfo.objects.all()) == 1 authinfos.AuthInfo.objects.delete(pk) - self.assertEqual(len(authinfos.AuthInfo.objects.all()), 0) + assert len(authinfos.AuthInfo.objects.all()) == 0 - with self.assertRaises(exceptions.NotExistent): + with pytest.raises(exceptions.NotExistent): authinfos.AuthInfo.objects.delete(pk) diff --git a/tests/orm/test_comments.py b/tests/orm/test_comments.py index 6cd3c77764..3783cda649 100644 --- a/tests/orm/test_comments.py +++ b/tests/orm/test_comments.py @@ -7,53 +7,52 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=no-self-use """Unit tests for the Comment ORM class.""" +import pytest from aiida import orm from aiida.common import exceptions from aiida.orm.comments import Comment -from aiida.storage.testbase import AiidaTestCase -class TestComment(AiidaTestCase): +class TestComment: """Unit tests for the Comment ORM class.""" - def setUp(self): - super().setUp() + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init self.node = orm.Data().store() self.user = orm.User.objects.get_default() self.content = 'Sometimes when I am freestyling, I lose confidence' self.comment = Comment(self.node, self.user, self.content).store() - def tearDown(self): - super().tearDown() - Comment.objects.delete_all() - def test_comment_content(self): """Test getting and setting content of a Comment.""" content = 'Be more constructive with your feedback' self.comment.set_content(content) - self.assertEqual(self.comment.content, content) + assert self.comment.content == content def test_comment_mtime(self): """Test getting and setting mtime of a Comment.""" mtime = self.comment.mtime self.comment.set_content('Changing an attribute 
should automatically change the mtime') - self.assertEqual(self.comment.content, 'Changing an attribute should automatically change the mtime') - self.assertNotEqual(self.comment.mtime, mtime) + assert self.comment.content == 'Changing an attribute should automatically change the mtime' + assert self.comment.mtime != mtime def test_comment_node(self): """Test getting the node of a Comment.""" - self.assertEqual(self.comment.node.uuid, self.node.uuid) + assert self.comment.node.uuid == self.node.uuid def test_comment_user(self): """Test getting the user of a Comment.""" - self.assertEqual(self.comment.user.uuid, self.user.uuid) + assert self.comment.user.uuid == self.user.uuid def test_comment_collection_get(self): """Test retrieving a Comment through the collection.""" comment = Comment.objects.get(id=self.comment.pk) - self.assertEqual(self.comment.uuid, comment.uuid) + assert self.comment.uuid == comment.uuid def test_comment_collection_delete(self): """Test deleting a Comment through the collection.""" @@ -62,10 +61,10 @@ def test_comment_collection_delete(self): Comment.objects.delete(comment.pk) - with self.assertRaises(exceptions.NotExistent): + with pytest.raises(exceptions.NotExistent): Comment.objects.delete(comment_pk) - with self.assertRaises(exceptions.NotExistent): + with pytest.raises(exceptions.NotExistent): Comment.objects.get(id=comment_pk) def test_comment_collection_delete_all(self): @@ -75,15 +74,15 @@ def test_comment_collection_delete_all(self): comment_pk = comment.pk # Assert the comments exist - self.assertEqual(len(Comment.objects.all()), 3) + assert len(Comment.objects.all()) == 3 # Delete all Comments Comment.objects.delete_all() - with self.assertRaises(exceptions.NotExistent): + with pytest.raises(exceptions.NotExistent): Comment.objects.delete(comment_pk) - with self.assertRaises(exceptions.NotExistent): + with pytest.raises(exceptions.NotExistent): Comment.objects.get(id=comment_pk) def test_comment_collection_delete_many(self): @@ 
-93,7 +92,7 @@ def test_comment_collection_delete_many(self): comment_ids = [_.id for _ in [comment_one, comment_two]] # Assert the Comments exist - self.assertEqual(len(Comment.objects.all()), 3) + assert len(Comment.objects.all()) == 3 # Delete new Comments using filter filters = {'id': {'in': comment_ids}} @@ -101,14 +100,14 @@ def test_comment_collection_delete_many(self): # Make sure only the setUp Comment is left builder = orm.QueryBuilder().append(Comment, project='id') - self.assertEqual(builder.count(), 1) - self.assertEqual(builder.all()[0][0], self.comment.id) + assert builder.count() == 1 + assert builder.all()[0][0] == self.comment.id for comment_pk in comment_ids: - with self.assertRaises(exceptions.NotExistent): + with pytest.raises(exceptions.NotExistent): Comment.objects.delete(comment_pk) - with self.assertRaises(exceptions.NotExistent): + with pytest.raises(exceptions.NotExistent): Comment.objects.get(id=comment_pk) def test_comment_querybuilder(self): @@ -137,9 +136,9 @@ def test_comment_querybuilder(self): builder.append(orm.Node, with_comment='comment', project=['uuid']) nodes = builder.all() - self.assertEqual(len(nodes), 1) + assert len(nodes) == 1 for node in nodes: - self.assertIn(str(node[0]), [node_one.uuid]) + assert str(node[0]) in [node_one.uuid] # Retrieve a comment by joining on a specific node builder = orm.QueryBuilder() @@ -147,9 +146,9 @@ def test_comment_querybuilder(self): builder.append(Comment, with_node='node', project=['uuid']) comments = builder.all() - self.assertEqual(len(comments), 2) + assert len(comments) == 2 for comment in comments: - self.assertIn(str(comment[0]), [comment_two.uuid, comment_three.uuid]) + assert str(comment[0]) in [comment_two.uuid, comment_three.uuid] # Retrieve a user by joining on a specific comment builder = orm.QueryBuilder() @@ -157,9 +156,9 @@ def test_comment_querybuilder(self): builder.append(orm.User, with_comment='comment', project=['email']) users = builder.all() - 
self.assertEqual(len(users), 1) + assert len(users) == 1 for user in users: - self.assertEqual(str(user[0]), user_two.email) + assert str(user[0]) == user_two.email # Retrieve a comment by joining on a specific user builder = orm.QueryBuilder() @@ -167,12 +166,10 @@ def test_comment_querybuilder(self): builder.append(Comment, with_user='user', project=['uuid']) comments = builder.all() - self.assertEqual(len(comments), 5) + assert len(comments) == 5 for comment in comments: - self.assertIn( - str(comment[0]), + assert str(comment[0]) in \ [self.comment.uuid, comment_one.uuid, comment_two.uuid, comment_three.uuid, comment_five.uuid] - ) # Retrieve users from comments of a single node by joining specific node builder = orm.QueryBuilder() @@ -181,19 +178,19 @@ def test_comment_querybuilder(self): builder.append(orm.User, with_comment='comments', project=['email']) comments_and_users = builder.all() - self.assertEqual(len(comments_and_users), 2) + assert len(comments_and_users) == 2 for entry in comments_and_users: - self.assertEqual(len(entry), 2) + assert len(entry) == 2 comment_uuid = str(entry[0]) user_email = str(entry[1]) - self.assertIn(comment_uuid, [comment_five.uuid, comment_six.uuid]) - self.assertIn(user_email, [user_one.email, user_two.email]) + assert comment_uuid in [comment_five.uuid, comment_six.uuid] + assert user_email in [user_one.email, user_two.email] def test_objects_get(self): """Test getting a comment from the collection""" node = orm.Data().store() comment = node.add_comment('Check out the comment on _this_ one') gotten_comment = Comment.objects.get(id=comment.id) - self.assertIsInstance(gotten_comment, Comment) + assert isinstance(gotten_comment, Comment) diff --git a/tests/orm/test_computers.py b/tests/orm/test_computers.py index fa43e17450..2642fc98cc 100644 --- a/tests/orm/test_computers.py +++ b/tests/orm/test_computers.py @@ -7,14 +7,16 @@ # For further information on the license, see the LICENSE.txt file # # For further information 
please visit http://www.aiida.net # ########################################################################### +# pylint: disable=no-self-use """Tests for the `Computer` ORM class.""" +import pytest from aiida import orm from aiida.common import exceptions -from aiida.storage.testbase import AiidaTestCase -class TestComputer(AiidaTestCase): +@pytest.mark.usefixtures('aiida_profile_clean') +class TestComputer: """Tests for the `Computer` ORM class.""" def test_get_transport(self): @@ -40,9 +42,9 @@ def test_get_transport(self): # It's on localhost, so I see files that I create with transport: with tempfile.NamedTemporaryFile() as handle: - self.assertEqual(transport.isfile(handle.name), True) + assert transport.isfile(handle.name) is True # Here the file should have been deleted - self.assertEqual(transport.isfile(handle.name), False) + assert transport.isfile(handle.name) is False def test_delete(self): """Test the deletion of a `Computer` instance.""" @@ -57,19 +59,21 @@ def test_delete(self): comp_pk = new_comp.pk check_computer = orm.Computer.objects.get(id=comp_pk) - self.assertEqual(comp_pk, check_computer.pk) + assert comp_pk == check_computer.pk orm.Computer.objects.delete(comp_pk) - with self.assertRaises(exceptions.NotExistent): + with pytest.raises(exceptions.NotExistent): orm.Computer.objects.get(id=comp_pk) -class TestComputerConfigure(AiidaTestCase): +class TestComputerConfigure: """Tests for the configuring of instance of the `Computer` ORM class.""" - def setUp(self): + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean): # pylint: disable=unused-argument """Prepare current user and computer builder with common properties.""" + # pylint: disable=attribute-defined-outside-init from aiida.orm.utils.builders.computer import ComputerBuilder self.comp_builder = ComputerBuilder(label='test', description='computer', hostname='localhost') @@ -91,7 +95,7 @@ def test_configure_local(self): comp.store() comp.configure() - 
self.assertTrue(comp.is_user_configured(self.user)) + assert comp.is_user_configured(self.user) def test_configure_ssh(self): """Configure a computer for ssh transport and check it is configured.""" @@ -101,7 +105,7 @@ def test_configure_ssh(self): comp.store() comp.configure(username='radames', port='22') - self.assertTrue(comp.is_user_configured(self.user)) + assert comp.is_user_configured(self.user) def test_configure_ssh_invalid(self): """Try to configure computer with invalid auth params and check it fails.""" @@ -110,7 +114,7 @@ def test_configure_ssh_invalid(self): comp = self.comp_builder.new() comp.store() - with self.assertRaises(ValueError): + with pytest.raises(ValueError): comp.configure(username='radames', invalid_auth_param='TEST') def test_non_configure_error(self): @@ -120,11 +124,11 @@ def test_non_configure_error(self): comp = self.comp_builder.new() comp.store() - with self.assertRaises(exceptions.NotExistent) as exc: + with pytest.raises(exceptions.NotExistent) as exc: comp.get_authinfo(self.user) - self.assertIn(str(comp.id), str(exc.exception)) - self.assertIn(comp.label, str(exc.exception)) - self.assertIn(self.user.get_short_name(), str(exc.exception)) - self.assertIn(str(self.user.id), str(exc.exception)) - self.assertIn('verdi computer configure', str(exc.exception)) + assert str(comp.id) in str(exc.value) + assert comp.label in str(exc.value) + assert self.user.get_short_name() in str(exc.value) + assert str(self.user.id) in str(exc.value) + assert 'verdi computer configure' in str(exc.value) diff --git a/tests/orm/test_entities.py b/tests/orm/test_entities.py index f35e193793..a530963f6d 100644 --- a/tests/orm/test_entities.py +++ b/tests/orm/test_entities.py @@ -7,33 +7,33 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=no-self-use """Test for general backend entities""" +import pytest 
from aiida import orm -from aiida.storage.testbase import AiidaTestCase -class TestBackendEntitiesAndCollections(AiidaTestCase): +@pytest.mark.usefixtures('aiida_profile_clean') +class TestBackendEntitiesAndCollections: """Test backend entities and their collections""" def test_collections_cache(self): """Make sure that we're not recreating collections each time .objects is called""" # Check directly user_collection = orm.User.objects - self.assertIs(user_collection, orm.User.objects) + assert user_collection is orm.User.objects # Now check passing an explicit backend backend = user_collection.backend - self.assertIs(user_collection, user_collection(backend)) + assert user_collection is user_collection(backend) def test_collections_count(self): """Make sure count() works for collections""" user_collection_count = orm.User.objects.count() number_of_users = orm.QueryBuilder().append(orm.User).count() - self.assertGreater(number_of_users, 0, msg='There should be more than 0 Users in the DB') - self.assertEqual( - user_collection_count, - number_of_users, - msg='{} User(s) was/were found using Collections\' count() method, ' + assert number_of_users > 0, 'There should be more than 0 Users in the DB' + assert user_collection_count == \ + number_of_users, \ + '{} User(s) was/were found using Collections\' count() method, ' \ 'but {} User(s) was/were found using QueryBuilder directly'.format(user_collection_count, number_of_users) - ) diff --git a/tests/orm/test_groups.py b/tests/orm/test_groups.py index ff65d3d852..d32604bdb2 100644 --- a/tests/orm/test_groups.py +++ b/tests/orm/test_groups.py @@ -7,22 +7,18 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=no-self-use """Test for the Group ORM class.""" import pytest from aiida import orm from aiida.common import exceptions -from 
aiida.storage.testbase import AiidaTestCase -class TestGroups(AiidaTestCase): +@pytest.mark.usefixtures('aiida_profile_clean') +class TestGroups: """Test backend entities and their collections""" - def setUp(self): - """Remove all existing Groups.""" - for group in orm.Group.objects.all(): - orm.Group.objects.delete(group.id) - def test_count(self): """Test the `count` method.""" node_00 = orm.Data().store() @@ -32,7 +28,7 @@ def test_count(self): group = orm.Group(label='label', description='description').store() group.add_nodes(nodes) - self.assertEqual(group.count(), len(nodes)) + assert group.count() == len(nodes) def test_creation(self): """Test the creation of Groups.""" @@ -41,25 +37,25 @@ def test_creation(self): group = orm.Group(label='testgroup') - with self.assertRaises(exceptions.ModificationNotAllowed): + with pytest.raises(exceptions.ModificationNotAllowed): # group unstored group.add_nodes(node) - with self.assertRaises(exceptions.ModificationNotAllowed): + with pytest.raises(exceptions.ModificationNotAllowed): # group unstored group.add_nodes(stored_node) group.store() - with self.assertRaises(ValueError): + with pytest.raises(ValueError): # node unstored group.add_nodes(node) group.add_nodes(stored_node) nodes = list(group.nodes) - self.assertEqual(len(nodes), 1) - self.assertEqual(nodes[0].pk, stored_node.pk) + assert len(nodes) == 1 + assert nodes[0].pk == stored_node.pk def test_node_iterator(self): """Test the indexing and slicing functionality of the node iterator.""" @@ -74,15 +70,15 @@ def test_node_iterator(self): # Indexing node_indexed = group.nodes[0] - self.assertTrue(isinstance(node_indexed, orm.Data)) - self.assertIn(node_indexed.uuid, [node.uuid for node in nodes]) + assert isinstance(node_indexed, orm.Data) + assert node_indexed.uuid in [node.uuid for node in nodes] # Slicing nodes_sliced = group.nodes[1:3] - self.assertTrue(isinstance(nodes_sliced, list)) - self.assertEqual(len(nodes_sliced), 2) - 
self.assertTrue(all(isinstance(node, orm.Data) for node in nodes_sliced)) - self.assertTrue(all(node.uuid in set(node.uuid for node in nodes) for node in nodes_sliced)) + assert isinstance(nodes_sliced, list) + assert len(nodes_sliced) == 2 + assert all(isinstance(node, orm.Data) for node in nodes_sliced) + assert all(node.uuid in set(node.uuid for node in nodes) for node in nodes_sliced) def test_description(self): """Test the update of the description both for stored and unstored groups.""" @@ -94,10 +90,10 @@ def test_description(self): group_02 = orm.Group(label='testgroupdescription2', description='group_02') # Preliminary checks - self.assertTrue(group_01.is_stored) - self.assertFalse(group_02.is_stored) - self.assertEqual(group_01.description, 'group_01') - self.assertEqual(group_02.description, 'group_02') + assert group_01.is_stored + assert not group_02.is_stored + assert group_01.description == 'group_01' + assert group_02.description == 'group_02' # Change group_01.description = 'new1' @@ -105,15 +101,15 @@ def test_description(self): # Test that the groups remained in their proper stored state and that # the description was updated - self.assertTrue(group_01.is_stored) - self.assertFalse(group_02.is_stored) - self.assertEqual(group_01.description, 'new1') - self.assertEqual(group_02.description, 'new2') + assert group_01.is_stored + assert not group_02.is_stored + assert group_01.description == 'new1' + assert group_02.description == 'new2' # Store group_02 and check that the description is OK group_02.store() - self.assertTrue(group_02.is_stored) - self.assertEqual(group_02.description, 'new2') + assert group_02.is_stored + assert group_02.description == 'new2' def test_add_nodes(self): """Test different ways of adding nodes.""" @@ -129,11 +125,11 @@ def test_add_nodes(self): group.add_nodes([node_02, node_03]) # Check - self.assertEqual(set(_.pk for _ in nodes), set(_.pk for _ in group.nodes)) + assert set(_.pk for _ in nodes) == set(_.pk for _ in 
group.nodes) # Try to add a node that is already present: there should be no problem group.add_nodes(node_01) - self.assertEqual(set(_.pk for _ in nodes), set(_.pk for _ in group.nodes)) + assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes) def test_remove_nodes(self): """Test node removal.""" @@ -146,22 +142,22 @@ def test_remove_nodes(self): # Add initial nodes group.add_nodes(nodes) - self.assertEqual(set(_.pk for _ in nodes), set(_.pk for _ in group.nodes)) + assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes) # Remove a node that is not in the group: nothing should happen group.remove_nodes(node_04) - self.assertEqual(set(_.pk for _ in nodes), set(_.pk for _ in group.nodes)) + assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes) # Remove one orm.Node nodes.remove(node_03) group.remove_nodes(node_03) - self.assertEqual(set(_.pk for _ in nodes), set(_.pk for _ in group.nodes)) + assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes) # Remove a list of Nodes and check nodes.remove(node_01) nodes.remove(node_02) group.remove_nodes([node_01, node_02]) - self.assertEqual(set(_.pk for _ in nodes), set(_.pk for _ in group.nodes)) + assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes) def test_clear(self): """Test the `clear` method to remove all nodes.""" @@ -173,23 +169,23 @@ def test_clear(self): # Add initial nodes group.add_nodes(nodes) - self.assertEqual(set(_.pk for _ in nodes), set(_.pk for _ in group.nodes)) + assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes) group.clear() - self.assertEqual(list(group.nodes), []) + assert list(group.nodes) == [] def test_name_desc(self): """Test Group description.""" group = orm.Group(label='testgroup2', description='some desc') - self.assertEqual(group.label, 'testgroup2') - self.assertEqual(group.description, 'some desc') - self.assertTrue(group.is_user_defined) + assert group.label == 'testgroup2' + assert group.description == 'some desc' + 
assert group.is_user_defined group.store() # Same checks after storing - self.assertEqual(group.label, 'testgroup2') - self.assertTrue(group.is_user_defined) - self.assertEqual(group.description, 'some desc') + assert group.label == 'testgroup2' + assert group.is_user_defined + assert group.description == 'some desc' # To avoid to find it in further tests orm.Group.objects.delete(group.pk) @@ -200,14 +196,14 @@ def test_delete(self): group = orm.Group(label='testgroup3', description='some other desc').store() group_copy = orm.Group.get(label='testgroup3') - self.assertEqual(group.uuid, group_copy.uuid) + assert group.uuid == group_copy.uuid group.add_nodes(node) - self.assertEqual(len(group.nodes), 1) + assert len(group.nodes) == 1 orm.Group.objects.delete(group.pk) - with self.assertRaises(exceptions.NotExistent): + with pytest.raises(exceptions.NotExistent): # The group does not exist anymore orm.Group.get(label='testgroup3') @@ -219,18 +215,18 @@ def test_rename(self): group = orm.Group(label=label_original, description='I will be renamed') # Check name changes work before storing - self.assertEqual(group.label, label_original) + assert group.label == label_original group.label = label_changed - self.assertEqual(group.label, label_changed) + assert group.label == label_changed # Revert the name to its original and store it group.label = label_original group.store() # Check name changes work after storing - self.assertEqual(group.label, label_original) + assert group.label == label_original group.label = label_changed - self.assertEqual(group.label, label_changed) + assert group.label == label_changed def test_rename_existing(self): """Test that renaming to an already existing name is not permitted.""" @@ -243,7 +239,7 @@ def test_rename_existing(self): group_b = orm.Group(label=label_group_a, description='They will try to rename me') # Storing for duplicate group name should trigger UniquenessError - with self.assertRaises(exceptions.IntegrityError): + with 
pytest.raises(exceptions.IntegrityError): group_b.store() # Reverting to unique name before storing @@ -251,7 +247,7 @@ def test_rename_existing(self): group_b.store() # After storing name change to existing should raise - with self.assertRaises(exceptions.IntegrityError): + with pytest.raises(exceptions.IntegrityError): group_b.label = label_group_a def test_group_uuid_hashing_for_querybuidler(self): @@ -272,18 +268,14 @@ def test_group_uuid_hashing_for_querybuidler(self): builder.all() # And that the results are correct - self.assertEqual(builder.count(), 1) - self.assertEqual(builder.first()[0], group.id) + assert builder.count() == 1 + assert builder.first()[0] == group.id -class TestGroupsSubclasses(AiidaTestCase): +@pytest.mark.usefixtures('aiida_profile_clean') +class TestGroupsSubclasses: """Test rules around creating `Group` subclasses.""" - def setUp(self): - """Remove all existing Groups.""" - for group in orm.Group.objects.all(): - orm.Group.objects.delete(group.id) - @staticmethod def test_creation_registered(): """Test rules around creating registered `Group` subclasses.""" @@ -405,13 +397,13 @@ def test_query_with_group(): assert loaded.pk == group.pk -class TestGroupExtras(AiidaTestCase): +class TestGroupExtras: """Test the property and methods of group extras.""" - def setUp(self): - super().setUp() - for group in orm.Group.objects.all(): - orm.Group.objects.delete(group.id) + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init self.group = orm.Group('test_extras') def test_extras(self): @@ -420,10 +412,10 @@ def test_extras(self): self.group.set_extra('key', original_extra) group_extras = self.group.extras - self.assertEqual(group_extras['key'], original_extra) + assert group_extras['key'] == original_extra group_extras['key']['nested']['a'] = 2 - self.assertEqual(original_extra['nested']['a'], 2) + assert 
original_extra['nested']['a'] == 2 # Now store the group and verify that `extras` then returns a deep copy self.group.store() @@ -431,7 +423,7 @@ def test_extras(self): # We change the returned group extras but the original extra should remain unchanged group_extras['key']['nested']['a'] = 3 - self.assertEqual(original_extra['nested']['a'], 2) + assert original_extra['nested']['a'] == 2 def test_get_extra(self): """Test the `Group.get_extra` method.""" @@ -439,14 +431,14 @@ def test_get_extra(self): self.group.set_extra('key', original_extra) group_extra = self.group.get_extra('key') - self.assertEqual(group_extra, original_extra) + assert group_extra == original_extra group_extra['nested']['a'] = 2 - self.assertEqual(original_extra['nested']['a'], 2) + assert original_extra['nested']['a'] == 2 default = 'default' - self.assertEqual(self.group.get_extra('not_existing', default=default), default) - with self.assertRaises(AttributeError): + assert self.group.get_extra('not_existing', default=default) == default + with pytest.raises(AttributeError): self.group.get_extra('not_existing') # Now store the group and verify that `get_extra` then returns a deep copy @@ -455,11 +447,11 @@ def test_get_extra(self): # We change the returned group extras but the original extra should remain unchanged group_extra['nested']['a'] = 3 - self.assertEqual(original_extra['nested']['a'], 2) + assert original_extra['nested']['a'] == 2 default = 'default' - self.assertEqual(self.group.get_extra('not_existing', default=default), default) - with self.assertRaises(AttributeError): + assert self.group.get_extra('not_existing', default=default) == default + with pytest.raises(AttributeError): self.group.get_extra('not_existing') def test_get_extra_many(self): @@ -468,10 +460,10 @@ def test_get_extra_many(self): self.group.set_extra('key', original_extra) group_extra = self.group.get_extra_many(['key'])[0] - self.assertEqual(group_extra, original_extra) + assert group_extra == original_extra 
group_extra['nested']['a'] = 2 - self.assertEqual(original_extra['nested']['a'], 2) + assert original_extra['nested']['a'] == 2 # Now store the group and verify that `get_extra` then returns a deep copy self.group.store() @@ -479,29 +471,29 @@ def test_get_extra_many(self): # We change the returned group extras but the original extra should remain unchanged group_extra['nested']['a'] = 3 - self.assertEqual(original_extra['nested']['a'], 2) + assert original_extra['nested']['a'] == 2 def test_set_extra(self): """Test the `Group.set_extra` method.""" - with self.assertRaises(exceptions.ValidationError): + with pytest.raises(exceptions.ValidationError): self.group.set_extra('illegal.key', 'value') self.group.set_extra('valid_key', 'value') self.group.store() self.group.set_extra('valid_key', 'changed') - self.assertEqual(orm.load_group(self.group.pk).get_extra('valid_key'), 'changed') + assert orm.load_group(self.group.pk).get_extra('valid_key') == 'changed' def test_set_extra_many(self): """Test the `Group.set_extra` method.""" - with self.assertRaises(exceptions.ValidationError): + with pytest.raises(exceptions.ValidationError): self.group.set_extra_many({'illegal.key': 'value', 'valid_key': 'value'}) self.group.set_extra_many({'valid_key': 'value'}) self.group.store() self.group.set_extra_many({'valid_key': 'changed'}) - self.assertEqual(orm.load_group(self.group.pk).get_extra('valid_key'), 'changed') + assert orm.load_group(self.group.pk).get_extra('valid_key') == 'changed' def test_reset_extra(self): """Test the `Group.reset_extra` method.""" @@ -510,25 +502,25 @@ def test_reset_extra(self): extras_illegal = {'extra.illegal': 'value', 'extra_four': 'value'} self.group.set_extra_many(extras_before) - self.assertEqual(self.group.extras, extras_before) + assert self.group.extras == extras_before self.group.reset_extras(extras_after) - self.assertEqual(self.group.extras, extras_after) + assert self.group.extras == extras_after - with 
self.assertRaises(exceptions.ValidationError): + with pytest.raises(exceptions.ValidationError): self.group.reset_extras(extras_illegal) self.group.store() self.group.reset_extras(extras_after) - self.assertEqual(orm.load_group(self.group.pk).extras, extras_after) + assert orm.load_group(self.group.pk).extras == extras_after def test_delete_extra(self): """Test the `Group.delete_extra` method.""" self.group.set_extra('valid_key', 'value') - self.assertEqual(self.group.get_extra('valid_key'), 'value') + assert self.group.get_extra('valid_key') == 'value' self.group.delete_extra('valid_key') - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): self.group.delete_extra('valid_key') # Repeat with stored group @@ -536,7 +528,7 @@ def test_delete_extra(self): self.group.store() self.group.delete_extra('valid_key') - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): orm.load_group(self.group.pk).get_extra('valid_key') def test_delete_extra_many(self): @@ -546,39 +538,39 @@ def test_delete_extra_many(self): invalid_keys = ['extra_one', 'invalid_key'] self.group.set_extra_many(extras_valid) - self.assertEqual(self.group.extras, extras_valid) + assert self.group.extras == extras_valid - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): self.group.delete_extra_many(invalid_keys) self.group.store() self.group.delete_extra_many(valid_keys) - self.assertEqual(orm.load_group(self.group.pk).extras, {}) + assert orm.load_group(self.group.pk).extras == {} def test_clear_extras(self): """Test the `Group.clear_extras` method.""" extras = {'extra_one': 'value', 'extra_two': 'value'} self.group.set_extra_many(extras) - self.assertEqual(self.group.extras, extras) + assert self.group.extras == extras self.group.clear_extras() - self.assertEqual(self.group.extras, {}) + assert self.group.extras == {} # Repeat for stored group self.group.store() self.group.clear_extras() - 
self.assertEqual(orm.load_group(self.group.pk).extras, {}) + assert orm.load_group(self.group.pk).extras == {} def test_extras_items(self): """Test the `Group.extras_items` generator.""" extras = {'extra_one': 'value', 'extra_two': 'value'} self.group.set_extra_many(extras) - self.assertEqual(dict(self.group.extras_items()), extras) + assert dict(self.group.extras_items()) == extras def test_extras_keys(self): """Test the `Group.extras_keys` generator.""" extras = {'extra_one': 'value', 'extra_two': 'value'} self.group.set_extra_many(extras) - self.assertEqual(set(self.group.extras_keys()), set(extras)) + assert set(self.group.extras_keys()) == set(extras) diff --git a/tests/orm/test_logs.py b/tests/orm/test_logs.py index e65b380265..66f342ab61 100644 --- a/tests/orm/test_logs.py +++ b/tests/orm/test_logs.py @@ -7,23 +7,26 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=no-self-use """ORM Log tests""" - import logging +import pytest + from aiida import orm from aiida.common import exceptions from aiida.common.log import LOG_LEVEL_REPORT from aiida.common.timezone import now from aiida.orm import Log -from aiida.storage.testbase import AiidaTestCase -class TestBackendLog(AiidaTestCase): +class TestBackendLog: """Test the Log entity""" - def setUp(self): - super().setUp() + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init self.log_record = { 'time': now(), 'loggername': 'loggername', @@ -35,13 +38,6 @@ def setUp(self): }, } - def tearDown(self): - """ - Delete all the created log entries - """ - super().tearDown() - Log.objects.delete_all() - def create_log(self): node = orm.CalculationNode().store() record = self.log_record @@ -54,13 +50,13 @@ 
def test_create_log_message(self): """ entry, node = self.create_log() - self.assertEqual(entry.time, self.log_record['time']) - self.assertEqual(entry.loggername, self.log_record['loggername']) - self.assertEqual(entry.levelname, self.log_record['levelname']) - self.assertEqual(entry.message, self.log_record['message']) - self.assertEqual(entry.metadata, self.log_record['metadata']) - self.assertEqual(entry.dbnode_id, self.log_record['dbnode_id']) - self.assertEqual(entry.dbnode_id, node.id) + assert entry.time == self.log_record['time'] + assert entry.loggername == self.log_record['loggername'] + assert entry.levelname == self.log_record['levelname'] + assert entry.message == self.log_record['message'] + assert entry.metadata == self.log_record['metadata'] + assert entry.dbnode_id == self.log_record['dbnode_id'] + assert entry.dbnode_id == node.id def test_create_log_unserializable_metadata(self): """Test that unserializable data will be removed before reaching the database causing an error.""" @@ -85,21 +81,21 @@ def unbound_method(argument): except ValueError: node.logger.exception('caught an exception') - self.assertEqual(len(Log.objects.all()), 3) + assert len(Log.objects.all()) == 3 def test_log_delete_single(self): """Test that a single log entry can be deleted through the collection.""" entry, _ = self.create_log() log_id = entry.id - self.assertEqual(len(Log.objects.all()), 1) + assert len(Log.objects.all()) == 1 # Deleting the entry Log.objects.delete(log_id) - self.assertEqual(len(Log.objects.all()), 0) + assert len(Log.objects.all()) == 0 # Deleting a non-existing entry should raise - with self.assertRaises(exceptions.NotExistent): + with pytest.raises(exceptions.NotExistent): Log.objects.delete(log_id) def test_log_collection_delete_all(self): @@ -109,18 +105,18 @@ def test_log_collection_delete_all(self): self.create_log() log_id = Log.objects.find(limit=1)[0].id - self.assertEqual(len(Log.objects.all()), count) + assert len(Log.objects.all()) == 
count # Delete all Log.objects.delete_all() # Checks - self.assertEqual(len(Log.objects.all()), 0) + assert len(Log.objects.all()) == 0 - with self.assertRaises(exceptions.NotExistent): + with pytest.raises(exceptions.NotExistent): Log.objects.delete(log_id) - with self.assertRaises(exceptions.NotExistent): + with pytest.raises(exceptions.NotExistent): Log.objects.get(id=log_id) def test_log_collection_delete_many(self): @@ -133,7 +129,7 @@ def test_log_collection_delete_many(self): special_log, _ = self.create_log() # Assert the Logs exist - self.assertEqual(len(Log.objects.all()), count + 1) + assert len(Log.objects.all()) == count + 1 # Delete new Logs using filter filters = {'id': {'in': log_ids}} @@ -141,14 +137,14 @@ def test_log_collection_delete_many(self): # Make sure only the special_log Log is left builder = orm.QueryBuilder().append(Log, project='id') - self.assertEqual(builder.count(), 1) - self.assertEqual(builder.all()[0][0], special_log.id) + assert builder.count() == 1 + assert builder.all()[0][0] == special_log.id for log_id in log_ids: - with self.assertRaises(exceptions.NotExistent): + with pytest.raises(exceptions.NotExistent): Log.objects.delete(log_id) - with self.assertRaises(exceptions.NotExistent): + with pytest.raises(exceptions.NotExistent): Log.objects.get(id=log_id) def test_objects_find(self): @@ -160,8 +156,8 @@ def test_objects_find(self): Log(**record) entries = Log.objects.all() - self.assertEqual(10, len(entries)) - self.assertIsInstance(entries[0], Log) + assert len(entries) == 10 + assert isinstance(entries[0], Log) def test_find_orderby(self): """ @@ -177,11 +173,11 @@ def test_find_orderby(self): order_by = [OrderSpecifier('dbnode_id', ASCENDING)] res_entries = Log.objects.find(order_by=order_by) - self.assertEqual(res_entries[0].dbnode_id, node_ids[0]) + assert res_entries[0].dbnode_id == node_ids[0] order_by = [OrderSpecifier('dbnode_id', DESCENDING)] res_entries = Log.objects.find(order_by=order_by) - 
self.assertEqual(res_entries[0].dbnode_id, node_ids[-1]) + assert res_entries[0].dbnode_id == node_ids[-1] def test_find_limit(self): """ @@ -193,7 +189,7 @@ def test_find_limit(self): self.log_record['dbnode_id'] = node.id Log(**self.log_record) entries = Log.objects.find(limit=limit) - self.assertEqual(len(entries), limit) + assert len(entries) == limit def test_find_filter(self): """ @@ -209,8 +205,8 @@ def test_find_filter(self): node_id_of_choice = node_ids.pop(randint(0, 9)) entries = Log.objects.find(filters={'dbnode_id': node_id_of_choice}) - self.assertEqual(len(entries), 1) - self.assertEqual(entries[0].dbnode_id, node_id_of_choice) + assert len(entries) == 1 + assert entries[0].dbnode_id == node_id_of_choice def test_db_log_handler(self): """ @@ -228,15 +224,15 @@ def test_db_log_handler(self): logs = Log.objects.find() - self.assertEqual(len(logs), 0) + assert len(logs) == 0 # After storing the node, logs above log level should be stored node.store() node.logger.critical(message) logs = Log.objects.find() - self.assertEqual(len(logs), 1) - self.assertEqual(logs[0].message, message) + assert len(logs) == 1 + assert logs[0].message == message # Launching a second log message ensuring that both messages are correctly stored message2 = f'{message} - Second message' @@ -245,9 +241,9 @@ def test_db_log_handler(self): order_by = [OrderSpecifier('time', ASCENDING)] logs = Log.objects.find(order_by=order_by) - self.assertEqual(len(logs), 2) - self.assertEqual(logs[0].message, message) - self.assertEqual(logs[1].message, message2) + assert len(logs) == 2 + assert logs[0].message == message + assert logs[1].message == message2 def test_log_querybuilder(self): """ Test querying for logs by joining on nodes in the QueryBuilder """ @@ -264,9 +260,9 @@ def test_log_querybuilder(self): builder.append(orm.CalculationNode, with_log='log', project=['uuid']) nodes = builder.all() - self.assertEqual(len(nodes), 1) + assert len(nodes) == 1 for node in nodes: - 
self.assertIn(str(node[0]), [calc.uuid]) + assert str(node[0]) in [calc.uuid] # Retrieve all logs for a specific node by joining on a said node builder = QueryBuilder() @@ -274,9 +270,9 @@ def test_log_querybuilder(self): builder.append(Log, with_node='calc', project=['uuid']) logs = builder.all() - self.assertEqual(len(logs), 3) + assert len(logs) == 3 for log in logs: - self.assertIn(str(log[0]), [str(log_1.uuid), str(log_2.uuid), str(log_3.uuid)]) + assert str(log[0]) in [str(log_1.uuid), str(log_2.uuid), str(log_3.uuid)] def test_raise_wrong_metadata_type_error(self): """ @@ -302,7 +298,7 @@ def test_raise_wrong_metadata_type_error(self): json_metadata_format = json.loads(json.dumps(correct_metadata_format)) # Check an error is raised when creating a Log with wrong metadata - with self.assertRaises(TypeError): + with pytest.raises(TypeError): Log( now(), 'loggername', @@ -323,7 +319,7 @@ def test_raise_wrong_metadata_type_error(self): ) # Check metadata is correctly created - self.assertEqual(correct_metadata_log.metadata, correct_metadata_format) + assert correct_metadata_log.metadata == correct_metadata_format # Create Log with json metadata, making sure TypeError is NOT raised json_metadata_log = Log( @@ -336,7 +332,7 @@ def test_raise_wrong_metadata_type_error(self): ) # Check metadata is correctly created - self.assertEqual(json_metadata_log.metadata, json_metadata_format) + assert json_metadata_log.metadata == json_metadata_format # Check no error is raised if no metadata is given no_metadata_log = Log( @@ -349,4 +345,4 @@ def test_raise_wrong_metadata_type_error(self): ) # Check metadata is an empty dict for no_metadata_log - self.assertEqual(no_metadata_log.metadata, {}) + assert no_metadata_log.metadata == {} diff --git a/tests/orm/test_mixins.py b/tests/orm/test_mixins.py index d176b59b67..dd822be348 100644 --- a/tests/orm/test_mixins.py +++ b/tests/orm/test_mixins.py @@ -7,16 +7,19 @@ # For further information on the license, see the LICENSE.txt file 
# # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=no-self-use """Tests for the ORM mixin classes.""" +import pytest + from aiida.common import exceptions from aiida.common.links import LinkType from aiida.orm import CalculationNode, Int from aiida.orm.utils.mixins import Sealable -from aiida.storage.testbase import AiidaTestCase -class TestSealable(AiidaTestCase): +@pytest.mark.usefixtures('aiida_profile_clean') +class TestSealable: """Tests for the `Sealable` mixin class.""" @staticmethod @@ -35,7 +38,7 @@ def test_validate_incoming_sealed(self): node = CalculationNode().store() node.seal() - with self.assertRaises(exceptions.ModificationNotAllowed): + with pytest.raises(exceptions.ModificationNotAllowed): node.validate_incoming(data, link_type=LinkType.INPUT_CALC, link_label='input') def test_validate_outgoing_sealed(self): @@ -44,5 +47,5 @@ def test_validate_outgoing_sealed(self): node = CalculationNode().store() node.seal() - with self.assertRaises(exceptions.ModificationNotAllowed): + with pytest.raises(exceptions.ModificationNotAllowed): node.validate_outgoing(data, link_type=LinkType.CREATE, link_label='create') diff --git a/tests/orm/utils/test_loaders.py b/tests/orm/utils/test_loaders.py index 26f867d72d..3171926190 100644 --- a/tests/orm/utils/test_loaders.py +++ b/tests/orm/utils/test_loaders.py @@ -7,39 +7,47 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=no-self-use """Module to test orm utilities to load nodes, codes etc.""" +import pytest + from aiida.common.exceptions import NotExistent from aiida.orm import Data, Group, Node from aiida.orm.utils import load_code, load_computer, load_entity, load_group, load_node from aiida.orm.utils.loaders import 
NodeEntityLoader -from aiida.storage.testbase import AiidaTestCase -class TestOrmUtils(AiidaTestCase): +class TestOrmUtils: """Test orm utils.""" + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean, aiida_localhost): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init + self.computer = aiida_localhost + def test_load_entity(self): """Test the functionality of load_entity which is the base function for the other loader functions.""" entity_loader = NodeEntityLoader - with self.assertRaises(TypeError): + with pytest.raises(TypeError): load_entity(entity_loader=None) # No identifier keyword arguments specified - with self.assertRaises(ValueError): + with pytest.raises(ValueError): load_entity(entity_loader) # More than one identifier keyword arguments specified - with self.assertRaises(ValueError): + with pytest.raises(ValueError): load_entity(entity_loader, identifier='a', pk=1) - with self.assertRaises(TypeError): + with pytest.raises(TypeError): load_entity(entity_loader, pk='1') - with self.assertRaises(TypeError): + with pytest.raises(TypeError): load_entity(entity_loader, uuid=1) - with self.assertRaises(TypeError): + with pytest.raises(TypeError): load_entity(entity_loader, label=1) def test_load_code(self): @@ -54,45 +62,45 @@ def test_load_code(self): # Load through full label loaded_code = load_code(code.full_label) - self.assertEqual(loaded_code.uuid, code.uuid) + assert loaded_code.uuid == code.uuid # Load through label loaded_code = load_code(code.label) - self.assertEqual(loaded_code.uuid, code.uuid) + assert loaded_code.uuid == code.uuid # Load through uuid loaded_code = load_code(code.uuid) - self.assertEqual(loaded_code.uuid, code.uuid) + assert loaded_code.uuid == code.uuid # Load through pk loaded_code = load_code(code.pk) - self.assertEqual(loaded_code.uuid, code.uuid) + assert loaded_code.uuid == code.uuid # Load through full label explicitly loaded_code = 
load_code(label=code.full_label) - self.assertEqual(loaded_code.uuid, code.uuid) + assert loaded_code.uuid == code.uuid # Load through label explicitly loaded_code = load_code(label=code.label) - self.assertEqual(loaded_code.uuid, code.uuid) + assert loaded_code.uuid == code.uuid # Load through uuid explicitly loaded_code = load_code(uuid=code.uuid) - self.assertEqual(loaded_code.uuid, code.uuid) + assert loaded_code.uuid == code.uuid # Load through pk explicitly loaded_code = load_code(pk=code.pk) - self.assertEqual(loaded_code.uuid, code.uuid) + assert loaded_code.uuid == code.uuid # Load through partial uuid without a dash loaded_code = load_code(uuid=code.uuid[:8]) - self.assertEqual(loaded_code.uuid, code.uuid) + assert loaded_code.uuid == code.uuid # Load through partial uuid including a dash loaded_code = load_code(uuid=code.uuid[:10]) - self.assertEqual(loaded_code.uuid, code.uuid) + assert loaded_code.uuid == code.uuid - with self.assertRaises(NotExistent): + with pytest.raises(NotExistent): load_code('non-existent-uuid') def test_load_computer(self): @@ -102,37 +110,37 @@ def test_load_computer(self): # Load through label loaded_computer = load_computer(computer.label) - self.assertEqual(loaded_computer.uuid, computer.uuid) + assert loaded_computer.uuid == computer.uuid # Load through uuid loaded_computer = load_computer(computer.uuid) - self.assertEqual(loaded_computer.uuid, computer.uuid) + assert loaded_computer.uuid == computer.uuid # Load through pk loaded_computer = load_computer(computer.pk) - self.assertEqual(loaded_computer.uuid, computer.uuid) + assert loaded_computer.uuid == computer.uuid # Load through label explicitly loaded_computer = load_computer(label=computer.label) - self.assertEqual(loaded_computer.uuid, computer.uuid) + assert loaded_computer.uuid == computer.uuid # Load through uuid explicitly loaded_computer = load_computer(uuid=computer.uuid) - self.assertEqual(loaded_computer.uuid, computer.uuid) + assert loaded_computer.uuid == 
computer.uuid # Load through pk explicitly loaded_computer = load_computer(pk=computer.pk) - self.assertEqual(loaded_computer.uuid, computer.uuid) + assert loaded_computer.uuid == computer.uuid # Load through partial uuid without a dash loaded_computer = load_computer(uuid=computer.uuid[:8]) - self.assertEqual(loaded_computer.uuid, computer.uuid) + assert loaded_computer.uuid == computer.uuid # Load through partial uuid including a dash loaded_computer = load_computer(uuid=computer.uuid[:10]) - self.assertEqual(loaded_computer.uuid, computer.uuid) + assert loaded_computer.uuid == computer.uuid - with self.assertRaises(NotExistent): + with pytest.raises(NotExistent): load_computer('non-existent-uuid') def test_load_group(self): @@ -142,37 +150,37 @@ def test_load_group(self): # Load through label loaded_group = load_group(group.label) - self.assertEqual(loaded_group.uuid, group.uuid) + assert loaded_group.uuid == group.uuid # Load through uuid loaded_group = load_group(group.uuid) - self.assertEqual(loaded_group.uuid, group.uuid) + assert loaded_group.uuid == group.uuid # Load through pk loaded_group = load_group(group.pk) - self.assertEqual(loaded_group.uuid, group.uuid) + assert loaded_group.uuid == group.uuid # Load through label explicitly loaded_group = load_group(label=group.label) - self.assertEqual(loaded_group.uuid, group.uuid) + assert loaded_group.uuid == group.uuid # Load through uuid explicitly loaded_group = load_group(uuid=group.uuid) - self.assertEqual(loaded_group.uuid, group.uuid) + assert loaded_group.uuid == group.uuid # Load through pk explicitly loaded_group = load_group(pk=group.pk) - self.assertEqual(loaded_group.uuid, group.uuid) + assert loaded_group.uuid == group.uuid # Load through partial uuid without a dash loaded_group = load_group(uuid=group.uuid[:8]) - self.assertEqual(loaded_group.uuid, group.uuid) + assert loaded_group.uuid == group.uuid # Load through partial uuid including a dash loaded_group = load_group(uuid=group.uuid[:10]) - 
self.assertEqual(loaded_group.uuid, group.uuid) + assert loaded_group.uuid == group.uuid - with self.assertRaises(NotExistent): + with pytest.raises(NotExistent): load_group('non-existent-uuid') def test_load_node(self): @@ -181,33 +189,33 @@ def test_load_node(self): # Load through uuid loaded_node = load_node(node.uuid) - self.assertIsInstance(loaded_node, Node) - self.assertEqual(loaded_node.uuid, node.uuid) + assert isinstance(loaded_node, Node) + assert loaded_node.uuid == node.uuid # Load through pk loaded_node = load_node(node.pk) - self.assertIsInstance(loaded_node, Node) - self.assertEqual(loaded_node.uuid, node.uuid) + assert isinstance(loaded_node, Node) + assert loaded_node.uuid == node.uuid # Load through uuid explicitly loaded_node = load_node(uuid=node.uuid) - self.assertIsInstance(loaded_node, Node) - self.assertEqual(loaded_node.uuid, node.uuid) + assert isinstance(loaded_node, Node) + assert loaded_node.uuid == node.uuid # Load through pk explicitly loaded_node = load_node(pk=node.pk) - self.assertIsInstance(loaded_node, Node) - self.assertEqual(loaded_node.uuid, node.uuid) + assert isinstance(loaded_node, Node) + assert loaded_node.uuid == node.uuid # Load through partial uuid without a dash loaded_node = load_node(uuid=node.uuid[:8]) - self.assertIsInstance(loaded_node, Node) - self.assertEqual(loaded_node.uuid, node.uuid) + assert isinstance(loaded_node, Node) + assert loaded_node.uuid == node.uuid # Load through partial uuid including a dashs loaded_node = load_node(uuid=node.uuid[:10]) - self.assertIsInstance(loaded_node, Node) - self.assertEqual(loaded_node.uuid, node.uuid) + assert isinstance(loaded_node, Node) + assert loaded_node.uuid == node.uuid - with self.assertRaises(NotExistent): + with pytest.raises(NotExistent): load_group('non-existent-uuid') diff --git a/tests/orm/utils/test_managers.py b/tests/orm/utils/test_managers.py index 64f3f50d9a..3e72975d59 100644 --- a/tests/orm/utils/test_managers.py +++ 
b/tests/orm/utils/test_managers.py @@ -30,9 +30,9 @@ def test_dot_dict_manager(aiida_profile_clean): assert set(dict_node.dict) == set(dict_content) for key, val in dict_content.items(): - # dict_node.dict.a == True, ... + # dict_node.dict.a is True, ... assert getattr(dict_node.dict, key) == val - # dict_node.dict['a'] == True, ... + # dict_node.dict['a'] is True, ... assert dict_node.dict[key] == val # I check the attribute fetching directly diff --git a/tests/orm/utils/test_node.py b/tests/orm/utils/test_node.py index 1b8214e631..4cf8caaa25 100644 --- a/tests/orm/utils/test_node.py +++ b/tests/orm/utils/test_node.py @@ -12,18 +12,14 @@ from aiida.orm import Data from aiida.orm.utils.node import load_node_class -from aiida.storage.testbase import AiidaTestCase -class TestLoadNodeClass(AiidaTestCase): - """Tests for the node plugin type generator and loaders.""" +def test_load_node_class_fallback(): + """Verify that `load_node_class` will fall back to `Data` class if entry point cannot be loaded.""" + loaded_class = load_node_class('data.core.some.non.existing.plugin.') + assert loaded_class == Data - def test_load_node_class_fallback(self): - """Verify that `load_node_class` will fall back to `Data` class if entry point cannot be loaded.""" - loaded_class = load_node_class('data.core.some.non.existing.plugin.') - self.assertEqual(loaded_class, Data) - - # For really unresolvable type strings, we fall back onto the `Data` class - with pytest.warns(UserWarning): - loaded_class = load_node_class('__main__.SubData.') - self.assertEqual(loaded_class, Data) + # For really unresolvable type strings, we fall back onto the `Data` class + with pytest.warns(UserWarning): + loaded_class = load_node_class('__main__.SubData.') + assert loaded_class == Data diff --git a/tests/parsers/test_parser.py b/tests/parsers/test_parser.py index 04bd7154f8..d42ebb2231 100644 --- a/tests/parsers/test_parser.py +++ b/tests/parsers/test_parser.py @@ -7,8 +7,8 @@ # For further information on 
the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=no-self-use """Test for the `Parser` base class.""" - import io import pytest @@ -19,7 +19,6 @@ from aiida.parsers import Parser from aiida.parsers.plugins.arithmetic.add import SimpleArithmeticAddParser # for demonstration purposes only from aiida.plugins import CalculationFactory, ParserFactory -from aiida.storage.testbase import AiidaTestCase ArithmeticAddCalculation = CalculationFactory('core.arithmetic.add') # pylint: disable=invalid-name ArithmeticAddParser = ParserFactory('core.arithmetic.add') # pylint: disable=invalid-name @@ -39,12 +38,18 @@ def prepare_for_submission(self): # pylint: disable=arguments-differ pass -class TestParser(AiidaTestCase): +class TestParser: """Test backend entities and their collections""" + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean, aiida_localhost): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init + self.computer = aiida_localhost + def test_abstract_parse_method(self): """Verify that trying to instantiate base class will raise `TypeError` because of abstract `parse` method.""" - with self.assertRaises(TypeError): + with pytest.raises(TypeError): Parser() # pylint: disable=abstract-class-instantiated,no-value-for-parameter def test_parser_retrieved(self): @@ -58,14 +63,14 @@ def test_parser_retrieved(self): retrieved.add_incoming(node, link_type=LinkType.CREATE, link_label='retrieved') parser = ArithmeticAddParser(node) - self.assertEqual(parser.node.uuid, node.uuid) - self.assertEqual(parser.retrieved.uuid, retrieved.uuid) + assert parser.node.uuid == node.uuid + assert parser.retrieved.uuid == retrieved.uuid def test_parser_exit_codes(self): """Ensure that exit codes from the `CalcJob` can be retrieved through the parser instance.""" 
node = orm.CalcJobNode(computer=self.computer, process_type=ArithmeticAddCalculation.build_process_type()) parser = ArithmeticAddParser(node) - self.assertEqual(parser.exit_codes, ArithmeticAddCalculation.spec().exit_codes) + assert parser.exit_codes == ArithmeticAddCalculation.spec().exit_codes def test_parser_get_outputs_for_parsing(self): """Make sure that the `get_output_for_parsing` method returns the correct output nodes.""" @@ -83,10 +88,10 @@ def test_parser_get_outputs_for_parsing(self): parser = ArithmeticAddParser(node) outputs_for_parsing = parser.get_outputs_for_parsing() - self.assertIn('retrieved', outputs_for_parsing) - self.assertEqual(outputs_for_parsing['retrieved'].uuid, retrieved.uuid) - self.assertIn('output', outputs_for_parsing) - self.assertEqual(outputs_for_parsing['output'].uuid, output.uuid) + assert 'retrieved' in outputs_for_parsing + assert outputs_for_parsing['retrieved'].uuid == retrieved.uuid + assert 'output' in outputs_for_parsing + assert outputs_for_parsing['output'].uuid == output.uuid @pytest.mark.requires_rmq def test_parse_from_node(self): @@ -113,10 +118,10 @@ def test_parse_from_node(self): for cls in [ArithmeticAddParser, SimpleArithmeticAddParser]: result, calcfunction = cls.parse_from_node(node) - self.assertIsInstance(result['sum'], orm.Int) - self.assertEqual(result['sum'].value, summed) - self.assertIsInstance(calcfunction, orm.CalcFunctionNode) - self.assertEqual(calcfunction.exit_status, 0) + assert isinstance(result['sum'], orm.Int) + assert result['sum'].value == summed + assert isinstance(calcfunction, orm.CalcFunctionNode) + assert calcfunction.exit_status == 0 # Verify that the `retrieved_temporary_folder` keyword can be passed, there is no validation though result, calcfunction = ArithmeticAddParser.parse_from_node(node, retrieved_temporary_folder='/some/path') diff --git a/tests/plugins/test_utils.py b/tests/plugins/test_utils.py index 8b259c1988..f06529d91b 100644 --- a/tests/plugins/test_utils.py +++ 
b/tests/plugins/test_utils.py @@ -12,14 +12,13 @@ from aiida.engine import WorkChain, calcfunction from aiida.plugins import CalculationFactory from aiida.plugins.utils import PluginVersionProvider -from aiida.storage.testbase import AiidaTestCase -class TestPluginVersionProvider(AiidaTestCase): +class TestPluginVersionProvider: """Tests for the :py:class:`~aiida.plugins.utils.PluginVersionProvider` utility class.""" - def setUp(self): - super().setUp() + def setup_method(self): + # pylint: disable=attribute-defined-outside-init self.provider = PluginVersionProvider() @staticmethod @@ -58,7 +57,7 @@ class DummyCalcJob(): dynamic_plugin = self.create_dynamic_plugin_module(DummyCalcJob, version_plugin, add_module_to_sys=False) expected_version = {'version': {'core': version_core}} - self.assertEqual(self.provider.get_version_info(dynamic_plugin), expected_version) + assert self.provider.get_version_info(dynamic_plugin) == expected_version def test_external_module_no_version_attribute(self): """Test that mapper does not except even if external module does not define `__version__` attribute.""" @@ -70,7 +69,7 @@ class DummyCalcJob(): dynamic_plugin = self.create_dynamic_plugin_module(DummyCalcJob, version_plugin, add_version=False) expected_version = {'version': {'core': version_core}} - self.assertEqual(self.provider.get_version_info(dynamic_plugin), expected_version) + assert self.provider.get_version_info(dynamic_plugin) == expected_version def test_external_module_class(self): """Test the mapper works for a class from an external module.""" @@ -82,7 +81,7 @@ class DummyCalcJob(): dynamic_plugin = self.create_dynamic_plugin_module(DummyCalcJob, version_plugin) expected_version = {'version': {'core': version_core, 'plugin': version_plugin}} - self.assertEqual(self.provider.get_version_info(dynamic_plugin), expected_version) + assert self.provider.get_version_info(dynamic_plugin) == expected_version def test_external_module_function(self): """Test the mapper works for 
a function from an external module.""" @@ -95,7 +94,7 @@ def test_calcfunction(): dynamic_plugin = self.create_dynamic_plugin_module(test_calcfunction, version_plugin) expected_version = {'version': {'core': version_core, 'plugin': version_plugin}} - self.assertEqual(self.provider.get_version_info(dynamic_plugin), expected_version) + assert self.provider.get_version_info(dynamic_plugin) == expected_version def test_calcfunction(self): """Test the mapper for a `calcfunction`.""" @@ -105,14 +104,14 @@ def test_calcfunction(): return expected_version = {'version': {'core': version_core, 'plugin': version_core}} - self.assertEqual(self.provider.get_version_info(test_calcfunction), expected_version) + assert self.provider.get_version_info(test_calcfunction) == expected_version def test_calc_job(self): """Test the mapper for a `CalcJob`.""" AddArithmeticCalculation = CalculationFactory('core.arithmetic.add') # pylint: disable=invalid-name expected_version = {'version': {'core': version_core, 'plugin': version_core}} - self.assertEqual(self.provider.get_version_info(AddArithmeticCalculation), expected_version) + assert self.provider.get_version_info(AddArithmeticCalculation) == expected_version def test_work_chain(self): """Test the mapper for a `WorkChain`.""" @@ -121,4 +120,4 @@ class SomeWorkChain(WorkChain): """Need to create a dummy class since there is no built-in work chain with entry point in `aiida-core`.""" expected_version = {'version': {'core': version_core, 'plugin': version_core}} - self.assertEqual(self.provider.get_version_info(SomeWorkChain), expected_version) + assert self.provider.get_version_info(SomeWorkChain) == expected_version diff --git a/tests/restapi/test_routes.py b/tests/restapi/test_routes.py index 1a71fc30b0..2d8252622d 100644 --- a/tests/restapi/test_routes.py +++ b/tests/restapi/test_routes.py @@ -7,21 +7,22 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # 
########################################################################### -# pylint: disable=too-many-lines +# pylint: disable=too-many-lines,too-many-public-methods """Unittests for REST API.""" from datetime import date import io from flask_cors.core import ACL_ORIGIN +import pytest from aiida import orm from aiida.common import json from aiida.common.links import LinkType +from aiida.manage import get_manager from aiida.restapi.run_api import configure_api -from aiida.storage.testbase import AiidaTestCase -class RESTApiTestCase(AiidaTestCase): +class TestRestApi: """ Setup of the tests for the AiiDA RESTful-api """ @@ -30,16 +31,16 @@ class RESTApiTestCase(AiidaTestCase): _PERPAGE_DEFAULT = 20 _LIMIT_DEFAULT = 400 - @classmethod - def setUpClass(cls): # pylint: disable=too-many-locals, too-many-statements - """ - Add objects to the database for different requests/filters/orderings etc. - """ - super().setUpClass() + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean, aiida_localhost): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init,disable=too-many-locals,too-many-statements api = configure_api(catch_internal_server=True) - cls.app = api.app - cls.app.config['TESTING'] = True + self.app = api.app + self.app.config['TESTING'] = True + + self.user = orm.User.objects.get_default() # create test inputs cell = ((2., 0., 0.), (0., 2., 0.), (0., 0., 2.)) @@ -64,10 +65,11 @@ def setUpClass(cls): # pylint: disable=too-many-locals, too-many-statements resources = {'num_machines': 1, 'num_mpiprocs_per_machine': 1} - calcfunc = orm.CalcFunctionNode(computer=cls.computer) + self.computer = aiida_localhost + calcfunc = orm.CalcFunctionNode(computer=self.computer) calcfunc.store() - calc = orm.CalcJobNode(computer=cls.computer) + calc = orm.CalcJobNode(computer=self.computer) calc.set_option('resources', resources) calc.set_attribute('attr1', 'OK') calc.set_attribute('attr2', 'OK') 
@@ -106,7 +108,7 @@ def setUpClass(cls): # pylint: disable=too-many-locals, too-many-statements kpoint.add_incoming(calc, link_type=LinkType.CREATE, link_label='create') - calc1 = orm.CalcJobNode(computer=cls.computer) + calc1 = orm.CalcJobNode(computer=self.computer) calc1.set_option('resources', resources) calc1.store() @@ -137,15 +139,14 @@ def setUpClass(cls): # pylint: disable=too-many-locals, too-many-statements computer.store() # Prepare typical REST responses - cls.process_dummy_data() + self.process_dummy_data() - @classmethod - def tearDownClass(cls): - # we need to reset the default user here, - # because the REST API's close_thread_connection decorator wil have closed its session, - # meaning the `PsqlDosBackend._clear` method will fail - orm.User.objects.reset() - super().tearDownClass() + yield + + # because the `close_thread_connection` decorator, currently, directly closes the SQLA session, + # the default user will be detached from the session, and the `_clean` method will fail. + # So, we need to reattach the default user to the session. 
+ get_manager().get_profile_storage().get_session().add(self.user.backend_entity.bare_model) def get_dummy_data(self): return self._dummy_data @@ -153,8 +154,7 @@ def get_dummy_data(self): def get_url_prefix(self): return self._url_prefix - @classmethod - def process_dummy_data(cls): + def process_dummy_data(self): # pylint: disable=fixme """ This functions prepare atomic chunks of typical responses from the @@ -178,7 +178,7 @@ def process_dummy_data(cls): for comp in computers: if comp['uuid'] is not None: comp['uuid'] = str(comp['uuid']) - cls._dummy_data['computers'] = computers + self._dummy_data['computers'] = computers calculation_projections = ['id', 'uuid', 'user_id', 'node_type'] calculations = orm.QueryBuilder().append(orm.CalculationNode, tag='calc', @@ -194,7 +194,7 @@ def process_dummy_data(cls): for calc in calculations: if calc['uuid'] is not None: calc['uuid'] = str(calc['uuid']) - cls._dummy_data['calculations'] = calculations + self._dummy_data['calculations'] = calculations data_projections = ['id', 'uuid', 'user_id', 'node_type'] data_types = { @@ -217,7 +217,7 @@ def process_dummy_data(cls): if datum['uuid'] is not None: datum['uuid'] = str(datum['uuid']) - cls._dummy_data[label] = data + self._dummy_data[label] = data def split_path(self, url): # pylint: disable=no-self-use @@ -250,13 +250,13 @@ def compare_extra_response_data(self, node_type, url, response, uuid=None): """ path, query_string = self.split_path(url) - self.assertEqual(response['method'], 'GET') - self.assertEqual(response['resource_type'], node_type) - self.assertEqual(response['path'], path) - self.assertEqual(response['id'], uuid) - self.assertEqual(response['query_string'], query_string) - self.assertEqual(response['url'], f'http://localhost{url}') - self.assertEqual(response['url_root'], 'http://localhost/') + assert response['method'] == 'GET' + assert response['resource_type'] == node_type + assert response['path'] == path + assert response['id'] == uuid + assert 
response['query_string'] == query_string + assert response['url'] == f'http://localhost{url}' + assert response['url_root'] == 'http://localhost/' # node details and list with limit, offset, page, perpage def process_test( @@ -305,7 +305,7 @@ def process_test( response = json.loads(rv_response.data) if expected_errormsg: - self.assertEqual(response['message'], expected_errormsg) + assert response['message'] == expected_errormsg else: if full_list: expected_data = self._dummy_data[result_node_type] @@ -321,17 +321,10 @@ def process_test( expected_node_uuids = [node['uuid'] for node in expected_data] result_node_uuids = [node['uuid'] for node in response['data'][result_name]] - self.assertEqual(expected_node_uuids, result_node_uuids) + assert expected_node_uuids == result_node_uuids self.compare_extra_response_data(entity_type, url, response, uuid) - -class RESTApiTestSuite(RESTApiTestCase): - # pylint: disable=too-many-public-methods - """ - Define unittests for rest api - """ - ############### generic endpoints ######################## def test_server(self): @@ -345,8 +338,8 @@ def test_server(self): response = client.get(url) data = json.loads(response.data)['data'] - self.assertEqual(__version__, data['AiiDA_version']) - self.assertEqual(self.get_url_prefix(), data['API_prefix']) + assert __version__ == data['AiiDA_version'] + assert self.get_url_prefix() == data['API_prefix'] def test_base_url(self): """ @@ -356,8 +349,8 @@ def test_base_url(self): data_base = json.loads(client.get(self.get_url_prefix() + '/').data)['data'] data_server = json.loads(client.get(self.get_url_prefix() + '/server/endpoints').data)['data'] - self.assertTrue(len(data_base['available_endpoints']) > 0) - self.assertDictEqual(data_base, data_server) + assert len(data_base['available_endpoints']) > 0 + assert data_base == data_server def test_cors_headers(self): """ @@ -368,7 +361,7 @@ def test_cors_headers(self): with self.app.test_client() as client: response = client.get(url) headers = 
response.headers - self.assertEqual(headers.get(ACL_ORIGIN), '*') + assert headers.get(ACL_ORIGIN) == '*' ############### computers endpoint ######################## @@ -377,15 +370,13 @@ def test_computers_details(self): Requests the details of single computer """ node_uuid = self.get_dummy_data()['computers'][1]['uuid'] - RESTApiTestCase.process_test( - self, 'computers', f'/computers/{str(node_uuid)}', expected_list_ids=[1], uuid=node_uuid - ) + self.process_test('computers', f'/computers/{str(node_uuid)}', expected_list_ids=[1], uuid=node_uuid) def test_computers_list(self): """ Get the full list of computers from database """ - RESTApiTestCase.process_test(self, 'computers', '/computers?orderby=+id', full_list=True) + self.process_test('computers', '/computers?orderby=+id', full_list=True) def test_computers_list_limit_offset(self): """ @@ -394,9 +385,7 @@ def test_computers_list_limit_offset(self): It should return the no of rows specified in limit from database starting from the no. specified in offset """ - RESTApiTestCase.process_test( - self, 'computers', '/computers?limit=2&offset=2&orderby=+id', expected_range=[2, 4] - ) + self.process_test('computers', '/computers?limit=2&offset=2&orderby=+id', expected_range=[2, 4]) def test_computers_list_limit_only(self): """ @@ -405,7 +394,7 @@ def test_computers_list_limit_only(self): It should return the no of rows specified in limit from database. """ - RESTApiTestCase.process_test(self, 'computers', '/computers?limit=2&orderby=+id', expected_range=[None, 2]) + self.process_test('computers', '/computers?limit=2&orderby=+id', expected_range=[None, 2]) def test_computers_list_offset_only(self): """ @@ -414,7 +403,7 @@ def test_computers_list_offset_only(self): It should return all the rows from database starting from the no. 
specified in offset """ - RESTApiTestCase.process_test(self, 'computers', '/computers?offset=2&orderby=+id', expected_range=[2, None]) + self.process_test('computers', '/computers?offset=2&orderby=+id', expected_range=[2, None]) def test_computers_list_limit_offset_perpage(self): """ @@ -422,8 +411,8 @@ def test_computers_list_limit_offset_perpage(self): would return the error message. """ expected_error = 'perpage key is incompatible with limit and offset' - RESTApiTestCase.process_test( - self, 'computers', '/computers?offset=2&limit=1&perpage=2&orderby=+id', expected_errormsg=expected_error + self.process_test( + 'computers', '/computers?offset=2&limit=1&perpage=2&orderby=+id', expected_errormsg=expected_error ) def test_computers_list_page_limit_offset(self): @@ -433,8 +422,8 @@ def test_computers_list_page_limit_offset(self): """ expected_error = 'requesting a specific page is incompatible with ' \ 'limit and offset' - RESTApiTestCase.process_test( - self, 'computers', '/computers/page/2?offset=2&limit=1&orderby=+id', expected_errormsg=expected_error + self.process_test( + 'computers', '/computers/page/2?offset=2&limit=1&orderby=+id', expected_errormsg=expected_error ) def test_complist_pagelimitoffset_perpage(self): @@ -443,11 +432,8 @@ def test_complist_pagelimitoffset_perpage(self): would return the error message. 
""" expected_error = 'perpage key is incompatible with limit and offset' - RESTApiTestCase.process_test( - self, - 'computers', - '/computers/page/2?offset=2&limit=1&perpage=2&orderby=+id', - expected_errormsg=expected_error + self.process_test( + 'computers', '/computers/page/2?offset=2&limit=1&perpage=2&orderby=+id', expected_errormsg=expected_error ) def test_computers_list_page_default(self): @@ -459,16 +445,14 @@ def test_computers_list_page_default(self): "/page" acts as "/page/1?perpage=default_value" """ - RESTApiTestCase.process_test(self, 'computers', '/computers/page?orderby=+id', full_list=True) + self.process_test('computers', '/computers/page?orderby=+id', full_list=True) def test_computers_list_page_perpage(self): """ no.of pages = total no. of computers in database / perpage Using this formula it returns the no. of rows for requested page """ - RESTApiTestCase.process_test( - self, 'computers', '/computers/page/1?perpage=2&orderby=+id', expected_range=[None, 2] - ) + self.process_test('computers', '/computers/page/1?perpage=2&orderby=+id', expected_range=[None, 2]) def test_computers_list_page_perpage_exceed(self): """ @@ -479,9 +463,7 @@ def test_computers_list_page_perpage_exceed(self): """ expected_error = 'Non existent page requested. 
The page range is [1 : ' \ '3]' - RESTApiTestCase.process_test( - self, 'computers', '/computers/page/4?perpage=2&orderby=+id', expected_errormsg=expected_error - ) + self.process_test('computers', '/computers/page/4?perpage=2&orderby=+id', expected_errormsg=expected_error) ############### list filters ######################## def test_computers_filter_id1(self): @@ -491,7 +473,7 @@ def test_computers_filter_id1(self): """ node_pk = self.get_dummy_data()['computers'][1]['id'] - RESTApiTestCase.process_test(self, 'computers', f'/computers?id={str(node_pk)}', expected_list_ids=[1]) + self.process_test('computers', f'/computers?id={str(node_pk)}', expected_list_ids=[1]) def test_computers_filter_id2(self): """ @@ -499,9 +481,7 @@ def test_computers_filter_id2(self): list (e.g. id > 2) """ node_pk = self.get_dummy_data()['computers'][1]['id'] - RESTApiTestCase.process_test( - self, 'computers', f'/computers?id>{str(node_pk)}&orderby=+id', expected_range=[2, None] - ) + self.process_test('computers', f'/computers?id>{str(node_pk)}&orderby=+id', expected_range=[2, None]) def test_computers_filter_pk(self): """ @@ -509,21 +489,21 @@ def test_computers_filter_pk(self): list (e.g. 
id=1) """ node_pk = self.get_dummy_data()['computers'][1]['id'] - RESTApiTestCase.process_test(self, 'computers', f'/computers?pk={str(node_pk)}', expected_list_ids=[1]) + self.process_test('computers', f'/computers?pk={str(node_pk)}', expected_list_ids=[1]) def test_computers_filter_name(self): """ Add filter for the label of computer and get the filtered computer list """ - RESTApiTestCase.process_test(self, 'computers', '/computers?label="test1"', expected_list_ids=[1]) + self.process_test('computers', '/computers?label="test1"', expected_list_ids=[1]) def test_computers_filter_hostname(self): """ Add filter for the hostname of computer and get the filtered computer list """ - RESTApiTestCase.process_test(self, 'computers', '/computers?hostname="test1.epfl.ch"', expected_list_ids=[1]) + self.process_test('computers', '/computers?hostname="test1.epfl.ch"', expected_list_ids=[1]) def test_computers_filter_transport_type(self): """ @@ -531,11 +511,8 @@ def test_computers_filter_transport_type(self): computer list """ - RESTApiTestCase.process_test( - self, - 'computers', - '/computers?transport_type="core.local"&label="test3"&orderby=+id', - expected_list_ids=[3] + self.process_test( + 'computers', '/computers?transport_type="core.local"&label="test3"&orderby=+id', expected_list_ids=[3] ) ############### list orderby ######################## @@ -544,21 +521,21 @@ def test_computers_orderby_id_asc(self): Returns the computers list ordered by "id" in ascending order """ - RESTApiTestCase.process_test(self, 'computers', '/computers?orderby=id', full_list=True) + self.process_test('computers', '/computers?orderby=id', full_list=True) def test_computers_orderby_id_asc_sign(self): """ Returns the computers list ordered by "+id" in ascending order """ - RESTApiTestCase.process_test(self, 'computers', '/computers?orderby=+id', full_list=True) + self.process_test('computers', '/computers?orderby=+id', full_list=True) def test_computers_orderby_id_desc(self): """ Returns the 
computers list ordered by "id" in descending order """ - RESTApiTestCase.process_test(self, 'computers', '/computers?orderby=-id', expected_list_ids=[4, 3, 2, 1, 0]) + self.process_test('computers', '/computers?orderby=-id', expected_list_ids=[4, 3, 2, 1, 0]) def test_computers_orderby_label_asc(self): """ @@ -566,9 +543,7 @@ def test_computers_orderby_label_asc(self): order """ node_pk = self.get_dummy_data()['computers'][0]['id'] - RESTApiTestCase.process_test( - self, 'computers', f'/computers?pk>{str(node_pk)}&orderby=label', expected_list_ids=[1, 2, 3, 4] - ) + self.process_test('computers', f'/computers?pk>{str(node_pk)}&orderby=label', expected_list_ids=[1, 2, 3, 4]) def test_computers_orderby_label_asc_sign(self): """ @@ -576,9 +551,7 @@ def test_computers_orderby_label_asc_sign(self): order """ node_pk = self.get_dummy_data()['computers'][0]['id'] - RESTApiTestCase.process_test( - self, 'computers', f'/computers?pk>{str(node_pk)}&orderby=+label', expected_list_ids=[1, 2, 3, 4] - ) + self.process_test('computers', f'/computers?pk>{str(node_pk)}&orderby=+label', expected_list_ids=[1, 2, 3, 4]) def test_computers_orderby_label_desc(self): """ @@ -586,9 +559,7 @@ def test_computers_orderby_label_desc(self): order """ node_pk = self.get_dummy_data()['computers'][0]['id'] - RESTApiTestCase.process_test( - self, 'computers', f'/computers?pk>{str(node_pk)}&orderby=-label', expected_list_ids=[4, 3, 2, 1] - ) + self.process_test('computers', f'/computers?pk>{str(node_pk)}&orderby=-label', expected_list_ids=[4, 3, 2, 1]) def test_computers_orderby_scheduler_type_asc(self): """ @@ -596,8 +567,7 @@ def test_computers_orderby_scheduler_type_asc(self): order """ node_pk = self.get_dummy_data()['computers'][0]['id'] - RESTApiTestCase.process_test( - self, + self.process_test( 'computers', f"/computers?transport_type=\"core.ssh\"&pk>{str(node_pk)}&orderby=scheduler_type", expected_list_ids=[1, 4, 2] @@ -609,8 +579,7 @@ def test_comp_orderby_scheduler_ascsign(self): order 
""" node_pk = self.get_dummy_data()['computers'][0]['id'] - RESTApiTestCase.process_test( - self, + self.process_test( 'computers', f"/computers?transport_type=\"core.ssh\"&pk>{str(node_pk)}&orderby=+scheduler_type", expected_list_ids=[1, 4, 2] @@ -622,8 +591,7 @@ def test_computers_orderby_schedulertype_desc(self): order """ node_pk = self.get_dummy_data()['computers'][0]['id'] - RESTApiTestCase.process_test( - self, + self.process_test( 'computers', f"/computers?pk>{str(node_pk)}&transport_type=\"core.ssh\"&orderby=-scheduler_type", expected_list_ids=[2, 4, 1] @@ -637,11 +605,8 @@ def test_computers_orderby_mixed1(self): by "id" """ node_pk = self.get_dummy_data()['computers'][0]['id'] - RESTApiTestCase.process_test( - self, - 'computers', - f'/computers?pk>{str(node_pk)}&orderby=transport_type,id', - expected_list_ids=[3, 1, 2, 4] + self.process_test( + 'computers', f'/computers?pk>{str(node_pk)}&orderby=transport_type,id', expected_list_ids=[3, 1, 2, 4] ) def test_computers_orderby_mixed2(self): @@ -651,11 +616,8 @@ def test_computers_orderby_mixed2(self): by "name" """ node_pk = self.get_dummy_data()['computers'][0]['id'] - RESTApiTestCase.process_test( - self, - 'computers', - f'/computers?pk>{str(node_pk)}&orderby=-scheduler_type,label', - expected_list_ids=[2, 3, 4, 1] + self.process_test( + 'computers', f'/computers?pk>{str(node_pk)}&orderby=-scheduler_type,label', expected_list_ids=[2, 3, 4, 1] ) def test_computers_orderby_mixed3(self): @@ -680,7 +642,7 @@ def test_computers_orderby_mixed3(self): test1 test4 - RESTApiTestCase.process_test(self, "computers", + self.process_test("computers", "/computers?orderby=+scheduler_type, -hostname", expected_list_ids=[1,0,4,3,2]) @@ -693,8 +655,8 @@ def test_computers_filter_mixed1(self): filtered computer list """ node_pk = self.get_dummy_data()['computers'][0]['id'] - RESTApiTestCase.process_test( - self, 'computers', f"/computers?id>{str(node_pk)}&hostname=\"test1.epfl.ch\"", expected_list_ids=[1] + 
self.process_test( + 'computers', f"/computers?id>{str(node_pk)}&hostname=\"test1.epfl.ch\"", expected_list_ids=[1] ) def test_computers_filter_mixed2(self): @@ -703,8 +665,7 @@ def test_computers_filter_mixed2(self): and get the filtered computer list """ node_pk = self.get_dummy_data()['computers'][0]['id'] - RESTApiTestCase.process_test( - self, + self.process_test( 'computers', f"/computers?id>{str(node_pk)}&hostname=\"test3.epfl.ch\"&transport_type=\"core.ssh\"", empty_list=True @@ -716,8 +677,8 @@ def test_computers_mixed1(self): url parameters: id, limit and offset """ node_pk = self.get_dummy_data()['computers'][0]['id'] - RESTApiTestCase.process_test( - self, 'computers', f'/computers?id>{str(node_pk)}&limit=2&offset=3&orderby=+id', expected_list_ids=[4] + self.process_test( + 'computers', f'/computers?id>{str(node_pk)}&limit=2&offset=3&orderby=+id', expected_list_ids=[4] ) def test_computers_mixed2(self): @@ -725,8 +686,8 @@ def test_computers_mixed2(self): url parameters: id, page, perpage """ node_pk = self.get_dummy_data()['computers'][0]['id'] - RESTApiTestCase.process_test( - self, 'computers', f'/computers/page/2?id>{str(node_pk)}&perpage=2&orderby=+id', expected_list_ids=[3, 4] + self.process_test( + 'computers', f'/computers/page/2?id>{str(node_pk)}&perpage=2&orderby=+id', expected_list_ids=[3, 4] ) def test_computers_mixed3(self): @@ -734,8 +695,7 @@ def test_computers_mixed3(self): url parameters: id, transport_type, orderby """ node_pk = self.get_dummy_data()['computers'][0]['id'] - RESTApiTestCase.process_test( - self, + self.process_test( 'computers', f"/computers?id>={str(node_pk)}&transport_type=\"core.ssh\"&orderby=-id&limit=2", expected_list_ids=[4, 2] @@ -747,7 +707,7 @@ def test_computers_unknown_param(self): url parameters: id, limit and offset from aiida.common.exceptions import InputValidationError - RESTApiTestCase.node_exception(self, "/computers?aa=bb&id=2", InputValidationError) + self.node_exception("/computers?aa=bb&id=2", 
InputValidationError) """ ############### calculation retrieved_inputs and retrieved_outputs ############# @@ -760,7 +720,7 @@ def test_calculation_retrieved_inputs(self): with self.app.test_client() as client: response_value = client.get(url) response = json.loads(response_value.data) - self.assertEqual(response['data'], [{'name': 'calcjob_inputs', 'type': 'DIRECTORY'}]) + assert response['data'] == [{'name': 'calcjob_inputs', 'type': 'DIRECTORY'}] def test_calculation_retrieved_outputs(self): """ @@ -771,7 +731,7 @@ def test_calculation_retrieved_outputs(self): with self.app.test_client() as client: response_value = client.get(url) response = json.loads(response_value.data) - self.assertEqual(response['data'], [{'name': 'calcjob_outputs', 'type': 'DIRECTORY'}]) + assert response['data'] == [{'name': 'calcjob_outputs', 'type': 'DIRECTORY'}] ############### calculation incoming ############# def test_calculation_inputs(self): @@ -811,17 +771,17 @@ def test_calculation_iotree(self): with self.app.test_client() as client: response_value = client.get(url) response = json.loads(response_value.data) - self.assertEqual(len(response['data']['nodes']), 1) - self.assertEqual(len(response['data']['nodes'][0]['incoming']), 1) - self.assertEqual(len(response['data']['nodes'][0]['outgoing']), 1) - self.assertEqual(len(response['data']['metadata']), 1) + assert len(response['data']['nodes']) == 1 + assert len(response['data']['nodes'][0]['incoming']) == 1 + assert len(response['data']['nodes'][0]['outgoing']) == 1 + assert len(response['data']['metadata']) == 1 expected_attr = [ 'ctime', 'mtime', 'id', 'node_label', 'node_type', 'uuid', 'description', 'incoming', 'outgoing' ] received_attr = response['data']['nodes'][0].keys() for attr in expected_attr: - self.assertIn(attr, received_attr) - RESTApiTestCase.compare_extra_response_data(self, 'nodes', url, response, uuid=node_uuid) + assert attr in received_attr + self.compare_extra_response_data('nodes', url, response, 
uuid=node_uuid) ############### calculation attributes ############# def test_calculation_attributes(self): @@ -841,9 +801,9 @@ def test_calculation_attributes(self): with self.app.test_client() as client: rv_obj = client.get(url) response = json.loads(rv_obj.data) - self.assertNotIn('message', response) - self.assertEqual(response['data']['attributes'], attributes) - RESTApiTestCase.compare_extra_response_data(self, 'nodes', url, response, uuid=node_uuid) + assert 'message' not in response + assert response['data']['attributes'] == attributes + self.compare_extra_response_data('nodes', url, response, uuid=node_uuid) def test_contents_attributes_filter(self): """ @@ -854,9 +814,9 @@ def test_contents_attributes_filter(self): with self.app.test_client() as client: rv_obj = client.get(url) response = json.loads(rv_obj.data) - self.assertNotIn('message', response) - self.assertEqual(response['data']['attributes'], {'attr1': 'OK'}) - RESTApiTestCase.compare_extra_response_data(self, 'nodes', url, response, uuid=node_uuid) + assert 'message' not in response + assert response['data']['attributes'] == {'attr1': 'OK'} + self.compare_extra_response_data('nodes', url, response, uuid=node_uuid) ############### calculation node attributes filter ############# def test_calculation_attributes_filter(self): @@ -876,7 +836,7 @@ def test_calculation_attributes_filter(self): with self.app.test_client() as client: response_value = client.get(url) response = json.loads(response_value.data) - self.assertEqual(response['data']['nodes'][0]['attributes'], attributes) + assert response['data']['nodes'][0]['attributes'] == attributes ############### calculation node extras_filter ############# def test_calculation_extras_filter(self): @@ -889,8 +849,8 @@ def test_calculation_extras_filter(self): with self.app.test_client() as client: response_value = client.get(url) response = json.loads(response_value.data) - self.assertEqual(response['data']['nodes'][0]['extras']['extra1'], 
extras['extra1']) - self.assertEqual(response['data']['nodes'][0]['extras']['extra2'], extras['extra2']) + assert response['data']['nodes'][0]['extras']['extra1'] == extras['extra1'] + assert response['data']['nodes'][0]['extras']['extra2'] == extras['extra2'] ############### structure node attributes filter ############# def test_structure_attributes_filter(self): @@ -903,7 +863,7 @@ def test_structure_attributes_filter(self): with self.app.test_client() as client: rv_obj = client.get(url) response = json.loads(rv_obj.data) - self.assertEqual(response['data']['nodes'][0]['attributes']['cell'], cell) + assert response['data']['nodes'][0]['attributes']['cell'] == cell ############### node attributes_filter with pagination ############# def test_node_attributes_filter_pagination(self): @@ -916,14 +876,14 @@ def test_node_attributes_filter_pagination(self): with self.app.test_client() as client: response_value = client.get(url) response = json.loads(response_value.data) - self.assertNotEqual(len(response['data']['nodes']), 0) + assert len(response['data']['nodes']) != 0 for node in response['data']['nodes']: - self.assertIn('attributes', node) - self.assertNotIn('attributes.resources', node) - self.assertNotIn('attributes.cell', node) - self.assertEqual(len(node['attributes']), len(expected_attributes)) + assert 'attributes' in node + assert 'attributes.resources' not in node + assert 'attributes.cell' not in node + assert len(node['attributes']) == len(expected_attributes) for attr in expected_attributes: - self.assertIn(attr, node['attributes']) + assert attr in node['attributes'] ############### node get one attributes_filter with pagination ############# def test_node_single_attributes_filter(self): @@ -936,9 +896,9 @@ def test_node_single_attributes_filter(self): with self.app.test_client() as client: response_value = client.get(url) response = json.loads(response_value.data) - self.assertNotEqual(len(response['data']['nodes']), 0) + assert 
len(response['data']['nodes']) != 0 for node in response['data']['nodes']: - self.assertEqual(list(node['attributes'].keys()), expected_attribute) + assert list(node['attributes'].keys()) == expected_attribute ############### node extras_filter with pagination ############# def test_node_extras_filter_pagination(self): @@ -951,14 +911,14 @@ def test_node_extras_filter_pagination(self): with self.app.test_client() as client: response_value = client.get(url) response = json.loads(response_value.data) - self.assertNotEqual(len(response['data']['nodes']), 0) + assert len(response['data']['nodes']) != 0 for node in response['data']['nodes']: - self.assertIn('extras', node) - self.assertNotIn('extras.extra1', node) - self.assertNotIn('extras.extra2', node) - self.assertEqual(len(node['extras']), len(expected_extras)) + assert 'extras' in node + assert 'extras.extra1' not in node + assert 'extras.extra2' not in node + assert len(node['extras']) == len(expected_extras) for extra in expected_extras: - self.assertIn(extra, node['extras']) + assert extra in node['extras'] ############### node get one extras_filter with pagination ############# def test_node_single_extras_filter(self): @@ -971,9 +931,9 @@ def test_node_single_extras_filter(self): with self.app.test_client() as client: response_value = client.get(url) response = json.loads(response_value.data) - self.assertNotEqual(len(response['data']['nodes']), 0) + assert len(response['data']['nodes']) != 0 for node in response['data']['nodes']: - self.assertEqual(list(node['extras'].keys()), expected_extra) + assert list(node['extras'].keys()) == expected_extra ############### node full_type filter ############# def test_nodes_full_type_filter(self): @@ -990,7 +950,7 @@ def test_nodes_full_type_filter(self): rv_obj = client.get(url) response = json.loads(rv_obj.data) for node in response['data']['nodes']: - self.assertIn(node['uuid'], expected_node_uuids) + assert node['uuid'] in expected_node_uuids def 
test_nodes_time_filters(self): """ @@ -1009,7 +969,7 @@ def test_nodes_time_filters(self): rv_obj = client.get(url) response = json.loads(rv_obj.data) for node in response['data']['nodes']: - self.assertIn(node['uuid'], expected_node_uuids) + assert node['uuid'] in expected_node_uuids # mtime filter test url = f"{self.get_url_prefix()}/nodes/?mtime={today}&full_type=\"process.calculation.calcjob.CalcJobNode.|\"" @@ -1017,7 +977,7 @@ def test_nodes_time_filters(self): rv_obj = client.get(url) response = json.loads(rv_obj.data) for node in response['data']['nodes']: - self.assertIn(node['uuid'], expected_node_uuids) + assert node['uuid'] in expected_node_uuids ############### Structure visualization and download ############# def test_structure_derived_properties(self): @@ -1029,16 +989,14 @@ def test_structure_derived_properties(self): with self.app.test_client() as client: rv_obj = client.get(url) response = json.loads(rv_obj.data) - self.assertNotIn('message', response) - self.assertEqual( - response['data']['derived_properties']['dimensionality'], { - 'dim': 3, - 'value': 8.0, - 'label': 'volume' - } - ) - self.assertEqual(response['data']['derived_properties']['formula'], 'Ba') - RESTApiTestCase.compare_extra_response_data(self, 'nodes', url, response, uuid=node_uuid) + assert 'message' not in response + assert response['data']['derived_properties']['dimensionality'] == { + 'dim': 3, + 'value': 8.0, + 'label': 'volume' + } + assert response['data']['derived_properties']['formula'] == 'Ba' + self.compare_extra_response_data('nodes', url, response, uuid=node_uuid) def test_structure_download(self): """ @@ -1051,7 +1009,7 @@ def test_structure_download(self): with self.app.test_client() as client: rv_obj = client.get(url) structure_data = load_node(node_uuid)._exportcontent('xsf')[0] # pylint: disable=protected-access - self.assertEqual(rv_obj.data, structure_data) + assert rv_obj.data == structure_data def test_cif(self): """ @@ -1064,7 +1022,7 @@ def 
test_cif(self): with self.app.test_client() as client: rv_obj = client.get(url) cif = load_node(node_uuid)._prepare_cif()[0] # pylint: disable=protected-access - self.assertEqual(rv_obj.data, cif) + assert rv_obj.data == cif ############### projectable_properties ############# def test_projectable_properties(self): @@ -1076,7 +1034,7 @@ def test_projectable_properties(self): with self.app.test_client() as client: rv_obj = client.get(url) response = json.loads(rv_obj.data) - self.assertNotIn('message', response) + assert 'message' not in response expected_keys = ['display_name', 'help_text', 'is_display', 'is_foreign_key', 'type'] @@ -1084,12 +1042,12 @@ def test_projectable_properties(self): for _, pinfo in response['data']['fields'].items(): available_keys = pinfo.keys() for prop in expected_keys: - self.assertIn(prop, available_keys) + assert prop in available_keys # check order available_properties = response['data']['fields'].keys() for prop in response['data']['ordering']: - self.assertIn(prop, available_properties) + assert prop in available_properties def test_node_namespace(self): """ @@ -1107,8 +1065,8 @@ def test_node_namespace(self): response = json.loads(rv_obj.data) response_keys = response['data'].keys() for dkay in expected_data_keys: - self.assertIn(dkay, response_keys) - RESTApiTestCase.compare_extra_response_data(self, 'nodes', url, response) + assert dkay in response_keys + self.compare_extra_response_data('nodes', url, response) def test_comments(self): """ @@ -1122,7 +1080,7 @@ def test_comments(self): all_comments = [] for comment in response: all_comments.append(comment['message']) - self.assertEqual(sorted(all_comments), sorted(['This is test comment.', 'Add another comment.'])) + assert sorted(all_comments) == sorted(['This is test comment.', 'Add another comment.']) def test_repo(self): """ @@ -1135,13 +1093,13 @@ def test_repo(self): with self.app.test_client() as client: response_value = client.get(url) response = 
json.loads(response_value.data) - self.assertEqual(response['data']['repo_list'], [{'type': 'FILE', 'name': 'aiida.in'}]) + assert response['data']['repo_list'] == [{'type': 'FILE', 'name': 'aiida.in'}] url = f"{self.get_url_prefix()}/nodes/{str(node_uuid)}/repo/contents?filename=\"calcjob_inputs/aiida.in\"" with self.app.test_client() as client: response_obj = client.get(url) input_file = load_node(node_uuid).get_object_content('calcjob_inputs/aiida.in', mode='rb') - self.assertEqual(response_obj.data, input_file) + assert response_obj.data == input_file def test_process_report(self): """ @@ -1155,11 +1113,11 @@ def test_process_report(self): expected_keys = response['data'].keys() for key in ['logs']: - self.assertIn(key, expected_keys) + assert key in expected_keys expected_log_keys = response['data']['logs'][0].keys() for key in ['time', 'loggername', 'levelname', 'dbnode_id', 'message']: - self.assertIn(key, expected_log_keys) + assert key in expected_log_keys def test_download_formats(self): """ @@ -1171,10 +1129,10 @@ def test_download_formats(self): response = json.loads(response_value.data) for key in ['data.core.structure.StructureData.|', 'data.core.cif.CifData.|']: - self.assertIn(key, response['data'].keys()) + assert key in response['data'].keys() for key in ['cif', 'xsf', 'xyz']: - self.assertIn(key, response['data']['data.core.structure.StructureData.|']) - self.assertIn('cif', response['data']['data.core.cif.CifData.|']) + assert key in response['data']['data.core.structure.StructureData.|'] + assert 'cif' in response['data']['data.core.cif.CifData.|'] ############### querybuilder ############### def test_querybuilder(self): @@ -1203,23 +1161,19 @@ def test_querybuilder(self): with self.app.test_client() as client: response = client.post(f'{self.get_url_prefix()}/querybuilder', json=query_dict).json - self.assertEqual('POST', response.get('method', '')) - self.assertEqual('QueryBuilder', response.get('resource_type', '')) + assert 
response.get('method', '') == 'POST' + assert response.get('resource_type', '') == 'QueryBuilder' - self.assertEqual( - len(expected_node_uuids), - len(response.get('data', {}).get('calc', [])), - msg=json.dumps(response, indent=2), - ) - self.assertListEqual( - expected_node_uuids, - [_.get('uuid', '') for _ in response.get('data', {}).get('calc', [])], - ) + assert len(expected_node_uuids) == \ + len(response.get('data', {}).get('calc', [])), \ + json.dumps(response, indent=2) + assert expected_node_uuids == \ + [_.get('uuid', '') for _ in response.get('data', {}).get('calc', [])] for entities in response.get('data', {}).values(): for entity in entities: # All are Nodes, but neither `node_type` or `process_type` are requested, # hence `full_type` should not be present. - self.assertFalse('full_type' in entity) + assert 'full_type' not in entity def test_get_querybuilder(self): """Test GETting the /querybuilder endpoint @@ -1233,12 +1187,12 @@ def test_get_querybuilder(self): response_value = client.get(f'{self.get_url_prefix()}/querybuilder') response = response_value.json - self.assertEqual(response_value.status_code, 405) - self.assertEqual(response_value.status, '405 METHOD NOT ALLOWED') + assert response_value.status_code == 405 + assert response_value.status == '405 METHOD NOT ALLOWED' - self.assertEqual('GET', response.get('method', '')) - self.assertEqual('QueryBuilder', response.get('resource_type', '')) - self.assertEqual(qb_api.GET_MESSAGE, response.get('data', {}).get('message', '')) + assert response.get('method', '') == 'GET' + assert response.get('resource_type', '') == 'QueryBuilder' + assert qb_api.GET_MESSAGE == response.get('data', {}).get('message', '') def test_querybuilder_user(self): """Retrieve a User through the use of the /querybuilder endpoint @@ -1270,26 +1224,20 @@ def test_querybuilder_user(self): with self.app.test_client() as client: response = client.post(f'{self.get_url_prefix()}/querybuilder', json=query_dict).json - 
self.assertEqual('POST', response.get('method', '')) - self.assertEqual('QueryBuilder', response.get('resource_type', '')) + assert response.get('method', '') == 'POST' + assert response.get('resource_type', '') == 'QueryBuilder' - self.assertEqual( - len(expected_user_ids), - len(response.get('data', {}).get('users', [])), - msg=json.dumps(response, indent=2), - ) - self.assertListEqual( - expected_user_ids, - [_.get('id', '') for _ in response.get('data', {}).get('users', [])], - ) - self.assertListEqual( - expected_user_ids, - [_.get('user_id', '') for _ in response.get('data', {}).get('calc', [])], - ) + assert len(expected_user_ids) == \ + len(response.get('data', {}).get('users', [])), \ + json.dumps(response, indent=2) + assert expected_user_ids == \ + [_.get('id', '') for _ in response.get('data', {}).get('users', [])] + assert expected_user_ids == \ + [_.get('user_id', '') for _ in response.get('data', {}).get('calc', [])] for entities in response.get('data', {}).values(): for entity in entities: # User is not a Node (no full_type) - self.assertFalse('full_type' in entity) + assert 'full_type' not in entity def test_querybuilder_project_explicit(self): """Expliticly project everything from the resulting entities @@ -1323,32 +1271,24 @@ def test_querybuilder_project_explicit(self): with self.app.test_client() as client: response = client.post(f'{self.get_url_prefix()}/querybuilder', json=query_dict).json - self.assertEqual('POST', response.get('method', '')) - self.assertEqual('QueryBuilder', response.get('resource_type', '')) - - self.assertEqual( - len(expected_calc_uuids), - len(response.get('data', {}).get('calc', [])), - msg=json.dumps(response, indent=2), - ) - self.assertEqual( - len(expected_data_uuids), - len(response.get('data', {}).get('data', [])), - msg=json.dumps(response, indent=2), - ) - self.assertListEqual( - expected_calc_uuids, - [_.get('uuid', '') for _ in response.get('data', {}).get('calc', [])], - ) - self.assertListEqual( - 
expected_data_uuids, - [_.get('uuid', '') for _ in response.get('data', {}).get('data', [])], - ) + assert response.get('method', '') == 'POST' + assert response.get('resource_type', '') == 'QueryBuilder' + + assert len(expected_calc_uuids) == \ + len(response.get('data', {}).get('calc', [])), \ + json.dumps(response, indent=2) + assert len(expected_data_uuids) == \ + len(response.get('data', {}).get('data', [])), \ + json.dumps(response, indent=2) + assert expected_calc_uuids == \ + [_.get('uuid', '') for _ in response.get('data', {}).get('calc', [])] + assert expected_data_uuids == \ + [_.get('uuid', '') for _ in response.get('data', {}).get('data', [])] for entities in response.get('data', {}).values(): for entity in entities: # All are Nodes, and all properties are projected, full_type should be present - self.assertTrue('full_type' in entity) - self.assertTrue('attributes' in entity) + assert 'full_type' in entity + assert 'attributes' in entity def test_querybuilder_project_implicit(self): """Implicitly project everything from the resulting entities @@ -1374,21 +1314,17 @@ def test_querybuilder_project_implicit(self): with self.app.test_client() as client: response = client.post(f'{self.get_url_prefix()}/querybuilder', json=query_dict).json - self.assertEqual('POST', response.get('method', '')) - self.assertEqual('QueryBuilder', response.get('resource_type', '')) + assert response.get('method', '') == 'POST' + assert response.get('resource_type', '') == 'QueryBuilder' - self.assertListEqual(['data'], list(response.get('data', {}).keys())) - self.assertEqual( - len(expected_data_uuids), - len(response.get('data', {}).get('data', [])), - msg=json.dumps(response, indent=2), - ) - self.assertListEqual( - expected_data_uuids, - [_.get('uuid', '') for _ in response.get('data', {}).get('data', [])], - ) + assert ['data'] == list(response.get('data', {}).keys()) + assert len(expected_data_uuids) == \ + len(response.get('data', {}).get('data', [])), \ + 
json.dumps(response, indent=2) + assert expected_data_uuids == \ + [_.get('uuid', '') for _ in response.get('data', {}).get('data', [])] for entities in response.get('data', {}).values(): for entity in entities: # All are Nodes, and all properties are projected, full_type should be present - self.assertTrue('full_type' in entity) - self.assertTrue('attributes' in entity) + assert 'full_type' in entity + assert 'attributes' in entity diff --git a/tests/storage/psql_dos/test_nodes.py b/tests/storage/psql_dos/test_nodes.py index 104fe2ff50..4bee5a3958 100644 --- a/tests/storage/psql_dos/test_nodes.py +++ b/tests/storage/psql_dos/test_nodes.py @@ -7,62 +7,66 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=import-error,no-name-in-module +# pylint: disable=import-error,no-name-in-module,no-self-use """Tests for nodes, attributes and links.""" +import pytest from aiida import orm -from aiida.orm import Data -from aiida.storage.testbase import AiidaTestCase +from aiida.orm import Data, load_node -class TestNodeBasicSQLA(AiidaTestCase): +class TestNodeBasicSQLA: """These tests check the basic features of nodes(setting of attributes, copying of files, ...).""" + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean, backend): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init + self.backend = backend + def test_load_nodes(self): """Test for load_node() function.""" - from aiida.orm import load_node - a_obj = Data() a_obj.store() - self.assertEqual(a_obj.pk, load_node(identifier=a_obj.pk).pk) - self.assertEqual(a_obj.pk, load_node(identifier=a_obj.uuid).pk) - self.assertEqual(a_obj.pk, load_node(pk=a_obj.pk).pk) - self.assertEqual(a_obj.pk, load_node(uuid=a_obj.uuid).pk) + assert a_obj.pk == 
load_node(identifier=a_obj.pk).pk + assert a_obj.pk == load_node(identifier=a_obj.uuid).pk + assert a_obj.pk == load_node(pk=a_obj.pk).pk + assert a_obj.pk == load_node(uuid=a_obj.uuid).pk session = self.backend.get_session() try: session.begin_nested() - with self.assertRaises(ValueError): + with pytest.raises(ValueError): load_node(identifier=a_obj.pk, pk=a_obj.pk) finally: session.rollback() try: session.begin_nested() - with self.assertRaises(ValueError): + with pytest.raises(ValueError): load_node(pk=a_obj.pk, uuid=a_obj.uuid) finally: session.rollback() try: session.begin_nested() - with self.assertRaises(TypeError): + with pytest.raises(TypeError): load_node(pk=a_obj.uuid) finally: session.rollback() try: session.begin_nested() - with self.assertRaises(TypeError): + with pytest.raises(TypeError): load_node(uuid=a_obj.pk) finally: session.rollback() try: session.begin_nested() - with self.assertRaises(ValueError): + with pytest.raises(ValueError): load_node() finally: session.rollback() @@ -77,7 +81,7 @@ def test_multiple_node_creation(self): from aiida.storage.psql_dos.models.node import DbNode # Get the automatic user - dbuser = self.backend.users.create(f'{self.id()}@aiida.net').store().bare_model + dbuser = self.backend.users.create('user@aiida.net').store().bare_model # Create a new node but don't add it to the session node_uuid = get_new_uuid() DbNode(user=dbuser, uuid=node_uuid, node_type=None) @@ -86,14 +90,14 @@ def test_multiple_node_creation(self): # Query the session before commit res = session.query(DbNode.uuid).filter(DbNode.uuid == node_uuid).all() - self.assertEqual(len(res), 0, 'There should not be any nodes with this UUID in the session/DB.') + assert len(res) == 0, 'There should not be any nodes with this UUID in the session/DB.' 
# Commit the transaction session.commit() # Check again that the node is not in the DB res = session.query(DbNode.uuid).filter(DbNode.uuid == node_uuid).all() - self.assertEqual(len(res), 0, 'There should not be any nodes with this UUID in the session/DB.') + assert len(res) == 0, 'There should not be any nodes with this UUID in the session/DB.' # Get the automatic user dbuser = orm.User.objects.get_default().backend_entity.bare_model @@ -104,11 +108,11 @@ def test_multiple_node_creation(self): # Query the session before commit res = session.query(DbNode.uuid).filter(DbNode.uuid == node_uuid).all() - self.assertEqual(len(res), 1, f'There should be a node in the session/DB with the UUID {node_uuid}') + assert len(res) == 1, f'There should be a node in the session/DB with the UUID {node_uuid}' # Commit the transaction session.commit() # Check again that the node is in the db res = session.query(DbNode.uuid).filter(DbNode.uuid == node_uuid).all() - self.assertEqual(len(res), 1, f'There should be a node in the session/DB with the UUID {node_uuid}') + assert len(res) == 1, f'There should be a node in the session/DB with the UUID {node_uuid}' diff --git a/tests/storage/psql_dos/test_query.py b/tests/storage/psql_dos/test_query.py index ce57b6c3b0..27340d2021 100644 --- a/tests/storage/psql_dos/test_query.py +++ b/tests/storage/psql_dos/test_query.py @@ -8,45 +8,40 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Tests for generic queries.""" +import pytest from aiida.orm import Computer, Data, Group, Node, ProcessNode, QueryBuilder, User -from aiida.storage.testbase import AiidaTestCase -class TestQueryBuilderSQLA(AiidaTestCase): - """Test QueryBuilder for SQLA objects.""" +def test_qb_clsf_sqla(): + """Test SQLA classifiers""" + from aiida.orm.querybuilder import _get_ormclass - def test_clsf_sqla(self): - """Test SQLA classifiers""" - from aiida.orm.querybuilder import 
_get_ormclass + for aiida_cls, orm_name in zip((Group, User, Computer, Node, Data, ProcessNode), + ('group', 'user', 'computer', 'node', 'node', 'node')): + cls, _ = _get_ormclass(aiida_cls, None) - for aiida_cls, orm_name in zip((Group, User, Computer, Node, Data, ProcessNode), - ('group', 'user', 'computer', 'node', 'node', 'node')): - cls, _ = _get_ormclass(aiida_cls, None) + assert cls.value == orm_name - self.assertEqual(cls.value, orm_name) +@pytest.mark.usefixtures('aiida_profile_clean') +def test_qb_ordering_limits_offsets_sqla(): + """Test ordering limits offsets of SQLA query results.""" + # Creating 10 nodes with an attribute that can be ordered + for i in range(10): + node = Data() + node.set_attribute('foo', i) + node.store() + q_b = QueryBuilder().append(Node, project='attributes.foo').order_by({Node: {'attributes.foo': {'cast': 'i'}}}) + res = next(zip(*q_b.all())) + assert res == tuple(range(10)) -class QueryBuilderLimitOffsetsTestSQLA(AiidaTestCase): - """Test query builder limits.""" + # Now applying an offset: + q_b.offset(5) + res = next(zip(*q_b.all())) + assert res == tuple(range(5, 10)) - def test_ordering_limits_offsets_sqla(self): - """Test ordering limits offsets of SQLA query results.""" - # Creating 10 nodes with an attribute that can be ordered - for i in range(10): - node = Data() - node.set_attribute('foo', i) - node.store() - q_b = QueryBuilder().append(Node, project='attributes.foo').order_by({Node: {'attributes.foo': {'cast': 'i'}}}) - res = next(zip(*q_b.all())) - self.assertEqual(res, tuple(range(10))) - - # Now applying an offset: - q_b.offset(5) - res = next(zip(*q_b.all())) - self.assertEqual(res, tuple(range(5, 10))) - - # Now also applying a limit: - q_b.limit(3) - res = next(zip(*q_b.all())) - self.assertEqual(res, tuple(range(5, 8))) + # Now also applying a limit: + q_b.limit(3) + res = next(zip(*q_b.all())) + assert res == tuple(range(5, 8)) diff --git a/tests/storage/psql_dos/test_schema.py 
b/tests/storage/psql_dos/test_schema.py index affd43937f..7df924be6a 100644 --- a/tests/storage/psql_dos/test_schema.py +++ b/tests/storage/psql_dos/test_schema.py @@ -7,10 +7,11 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=import-error,no-name-in-module +# pylint: disable=import-error,no-name-in-module,no-self-use """Test object relationships in the database.""" import warnings +import pytest from sqlalchemy import exc as sa_exc from aiida.common.links import LinkType @@ -19,10 +20,10 @@ from aiida.orm import CalculationNode, Data from aiida.storage.psql_dos.models.node import DbNode from aiida.storage.psql_dos.models.user import DbUser -from aiida.storage.testbase import AiidaTestCase -class TestRelationshipsSQLA(AiidaTestCase): +@pytest.mark.usefixtures('aiida_profile_clean') +class TestRelationshipsSQLA: """Class of tests concerning the schema and the correct implementation of relationships within the AiiDA ORM @@ -44,17 +45,17 @@ def test_outputs_children_relationship(self): n_3.add_incoming(n_2, link_type=LinkType.CREATE, link_label='N2') # Check that the result of outputs is a list - self.assertIsInstance(n_1.backend_entity.bare_model.outputs, list, 'This is expected to be a list') + assert isinstance(n_1.backend_entity.bare_model.outputs, list), 'This is expected to be a list' # Check that the result of outputs_q is a query from sqlalchemy.orm.dynamic import AppenderQuery - self.assertIsInstance( - n_1.backend_entity.bare_model.outputs_q, AppenderQuery, 'This is expected to be an AppenderQuery' - ) + assert isinstance( + n_1.backend_entity.bare_model.outputs_q, AppenderQuery + ), 'This is expected to be an AppenderQuery' # Check that the result of outputs is correct out = {_.pk for _ in n_1.backend_entity.bare_model.outputs} - self.assertEqual(out, set([n_2.pk])) + assert out 
== set([n_2.pk]) def test_inputs_parents_relationship(self): """This test checks that the inputs_q, parents_q relationship and the @@ -69,17 +70,17 @@ def test_inputs_parents_relationship(self): n_3.add_incoming(n_2, link_type=LinkType.CREATE, link_label='N2') # Check that the result of outputs is a list - self.assertIsInstance(n_1.backend_entity.bare_model.inputs, list, 'This is expected to be a list') + assert isinstance(n_1.backend_entity.bare_model.inputs, list), 'This is expected to be a list' # Check that the result of outputs_q is a query from sqlalchemy.orm.dynamic import AppenderQuery - self.assertIsInstance( - n_1.backend_entity.bare_model.inputs_q, AppenderQuery, 'This is expected to be an AppenderQuery' - ) + assert isinstance( + n_1.backend_entity.bare_model.inputs_q, AppenderQuery + ), 'This is expected to be an AppenderQuery' # Check that the result of inputs is correct out = {_.pk for _ in n_3.backend_entity.bare_model.inputs} - self.assertEqual(out, set([n_2.pk])) + assert out == set([n_2.pk]) def test_user_node_1(self): """Test that when a user and a node having that user are created, @@ -95,8 +96,8 @@ def test_user_node_1(self): dbn_1 = DbNode(**node_dict) # Check that the two are neither flushed nor committed - self.assertIsNone(dbu1.id) - self.assertIsNone(dbn_1.id) + assert dbu1.id is None + assert dbn_1.id is None session = get_manager().get_profile_storage().get_session() # Add only the node and commit @@ -105,8 +106,8 @@ def test_user_node_1(self): # Check that a pk has been assigned, which means that things have # been flushed into the database - self.assertIsNotNone(dbn_1.id) - self.assertIsNotNone(dbu1.id) + assert dbn_1.id is not None + assert dbu1.id is not None def test_user_node_2(self): """Test that when a user and a node having that user are created, @@ -121,8 +122,8 @@ def test_user_node_2(self): dbn_1 = DbNode(**node_dict) # Check that the two are neither flushed nor committed - self.assertIsNone(dbu1.id) - 
self.assertIsNone(dbn_1.id) + assert dbu1.id is None + assert dbn_1.id is None session = get_manager().get_profile_storage().get_session() @@ -136,8 +137,8 @@ def test_user_node_2(self): # Check that a pk has been assigned (or not), which means that things # have been flushed into the database - self.assertIsNotNone(dbu1.id) - self.assertIsNone(dbn_1.id) + assert dbu1.id is not None + assert dbn_1.id is None def test_user_node_3(self): """Test that when a user and two nodes having that user are created, @@ -155,9 +156,9 @@ def test_user_node_3(self): dbn_2 = DbNode(**node_dict) # Check that the two are neither flushed nor committed - self.assertIsNone(dbu1.id) - self.assertIsNone(dbn_1.id) - self.assertIsNone(dbn_2.id) + assert dbu1.id is None + assert dbn_1.id is None + assert dbn_2.id is None session = get_manager().get_profile_storage().get_session() @@ -170,9 +171,9 @@ def test_user_node_3(self): # Check for which object a pk has been assigned, which means that # things have been at least flushed into the database - self.assertIsNotNone(dbu1.id) - self.assertIsNotNone(dbn_1.id) - self.assertIsNone(dbn_2.id) + assert dbu1.id is not None + assert dbn_1.id is not None + assert dbn_2.id is None def test_user_node_4(self): """Test that when several nodes are created with the same user and each @@ -195,8 +196,8 @@ def test_user_node_4(self): dbn_1 = DbNode(user=dbu1, uuid=get_new_uuid()) # Check that the two are neither flushed nor committed - self.assertIsNone(dbu1.id) - self.assertIsNone(dbn_1.id) + assert dbu1.id is None + assert dbn_1.id is None session = get_manager().get_profile_storage().get_session() @@ -209,5 +210,5 @@ def test_user_node_4(self): # Check for which object a pk has been assigned, which means that # things have been at least flushed into the database - self.assertIsNotNone(dbu1.id) - self.assertIsNotNone(dbn_1.id) + assert dbu1.id is not None + assert dbn_1.id is not None diff --git a/tests/test_calculation_node.py 
b/tests/test_calculation_node.py index 2906deca15..ea99b606fd 100644 --- a/tests/test_calculation_node.py +++ b/tests/test_calculation_node.py @@ -8,14 +8,14 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Tests for the CalculationNode and CalcJobNode class.""" +import pytest from aiida.common.datastructures import CalcJobState from aiida.common.exceptions import ModificationNotAllowed from aiida.orm import CalcJobNode, CalculationNode -from aiida.storage.testbase import AiidaTestCase -class TestProcessNode(AiidaTestCase): +class TestProcessNode: """ These tests check the features of process nodes that differ from the base Node type """ @@ -45,29 +45,28 @@ class TestProcessNode(AiidaTestCase): emptydict = {} emptylist = [] - @classmethod - def setUpClass(cls, *args, **kwargs): - super().setUpClass(*args, **kwargs) - cls.computer.configure() # pylint: disable=no-member - cls.construction_options = {'resources': {'num_machines': 1, 'num_mpiprocs_per_machine': 1}} - - cls.calcjob = CalcJobNode() - cls.calcjob.computer = cls.computer - cls.calcjob.set_options(cls.construction_options) - cls.calcjob.store() - - def test_process_state(self): + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean, aiida_localhost): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init + self.calcjob = CalcJobNode() + self.calcjob.computer = aiida_localhost + self.calcjob.set_options({'resources': {'num_machines': 1, 'num_mpiprocs_per_machine': 1}}) + self.calcjob.store() + + @staticmethod + def test_process_state(): """ Check the properties of a newly created bare CalculationNode """ process_node = CalculationNode() - self.assertEqual(process_node.is_terminated, False) - self.assertEqual(process_node.is_excepted, False) - self.assertEqual(process_node.is_killed, False) - 
self.assertEqual(process_node.is_finished, False) - self.assertEqual(process_node.is_finished_ok, False) - self.assertEqual(process_node.is_failed, False) + assert process_node.is_terminated is False + assert process_node.is_excepted is False + assert process_node.is_killed is False + assert process_node.is_finished is False + assert process_node.is_finished_ok is False + assert process_node.is_failed is False def test_process_node_updatable_attribute(self): """Check that updatable attributes and only those can be mutated for a stored but unsealed CalculationNode.""" @@ -87,49 +86,49 @@ def test_process_node_updatable_attribute(self): # Check before storing node.set_attribute(CalculationNode.PROCESS_STATE_KEY, self.stateval) - self.assertEqual(node.get_attribute(CalculationNode.PROCESS_STATE_KEY), self.stateval) + assert node.get_attribute(CalculationNode.PROCESS_STATE_KEY) == self.stateval node.store() # Check after storing - self.assertEqual(node.get_attribute(CalculationNode.PROCESS_STATE_KEY), self.stateval) + assert node.get_attribute(CalculationNode.PROCESS_STATE_KEY) == self.stateval # I should be able to mutate the updatable attribute but not the others node.set_attribute(CalculationNode.PROCESS_STATE_KEY, 'FINISHED') node.delete_attribute(CalculationNode.PROCESS_STATE_KEY) # Deleting non-existing attribute should raise attribute error - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): node.delete_attribute(CalculationNode.PROCESS_STATE_KEY) - with self.assertRaises(ModificationNotAllowed): + with pytest.raises(ModificationNotAllowed): node.set_attribute('bool', False) - with self.assertRaises(ModificationNotAllowed): + with pytest.raises(ModificationNotAllowed): node.delete_attribute('bool') node.seal() # After sealing, even updatable attributes should be immutable - with self.assertRaises(ModificationNotAllowed): + with pytest.raises(ModificationNotAllowed): node.set_attribute(CalculationNode.PROCESS_STATE_KEY, 'FINISHED') - 
with self.assertRaises(ModificationNotAllowed): + with pytest.raises(ModificationNotAllowed): node.delete_attribute(CalculationNode.PROCESS_STATE_KEY) def test_get_description(self): - self.assertEqual(self.calcjob.get_description(), '') + assert self.calcjob.get_description() == '' self.calcjob.set_state(CalcJobState.PARSING) - self.assertEqual(self.calcjob.get_description(), CalcJobState.PARSING.value) + assert self.calcjob.get_description() == CalcJobState.PARSING.value def test_get_authinfo(self): """Test that we can get the AuthInfo object from the calculation instance.""" from aiida.orm import AuthInfo authinfo = self.calcjob.get_authinfo() - self.assertIsInstance(authinfo, AuthInfo) + assert isinstance(authinfo, AuthInfo) def test_get_transport(self): """Test that we can get the Transport object from the calculation instance.""" from aiida.transports import Transport transport = self.calcjob.get_transport() - self.assertIsInstance(transport, Transport) + assert isinstance(transport, Transport) diff --git a/tests/test_dataclasses.py b/tests/test_dataclasses.py index 2df6ca3e8c..fbe159d2e1 100644 --- a/tests/test_dataclasses.py +++ b/tests/test_dataclasses.py @@ -7,20 +7,22 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=too-many-lines,invalid-name,no-member +# pylint: disable=too-many-lines,invalid-name,no-member,too-many-public-methods,no-self-use """Tests for specific subclasses of Data.""" import os import tempfile -import unittest +import numpy as np import pytest from aiida.common.exceptions import ModificationNotAllowed from aiida.common.utils import Capturing from aiida.orm import ArrayData, BandsData, CifData, Dict, KpointsData, StructureData, TrajectoryData, load_node +from aiida.orm.nodes.data.cif import has_pycifrw from aiida.orm.nodes.data.structure import ( Kind, Site, 
+ _atomic_masses, ase_refine_cell, get_formula, get_pymatgen_version, @@ -28,7 +30,6 @@ has_pymatgen, has_spglib, ) -from aiida.storage.testbase import AiidaTestCase def has_seekpath(): @@ -59,16 +60,20 @@ def simplify(string): return '\n'.join(s.strip() for s in string.split()) -@pytest.mark.skipif(not has_pymatgen(), reason='pymatgen not installed') +skip_ase = pytest.mark.skipif(not has_ase(), reason='Unable to import ase') +skip_spglib = pytest.mark.skipif(not has_spglib(), reason='Unable to import spglib') +skip_pycifrw = pytest.mark.skipif(not has_pycifrw(), reason='Unable to import PyCifRW') +skip_pymatgen = pytest.mark.skipif(not has_pymatgen(), reason='Unable to import pymatgen') + + +@skip_pymatgen def test_get_pymatgen_version(): assert isinstance(get_pymatgen_version(), str) -class TestCifData(AiidaTestCase): +@pytest.mark.usefixtures('aiida_profile_clean') +class TestCifData: """Tests for CifData class.""" - from distutils.version import StrictVersion - - from aiida.orm.nodes.data.cif import has_pycifrw valid_sample_cif_str = ''' data_test @@ -109,7 +114,7 @@ class TestCifData(AiidaTestCase): O 0.5 0.5 0.5 . 
''' - @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') + @skip_pycifrw def test_reload_cifdata(self): """Test `CifData` cycle.""" file_content = 'data_test _cell_length_a 10(1)' @@ -121,35 +126,35 @@ def test_reload_cifdata(self): a = CifData(file=filename, source={'version': '1234', 'db_name': 'COD', 'id': '0000001'}) # Key 'db_kind' is not allowed in source description: - with self.assertRaises(KeyError): + with pytest.raises(KeyError): a.source = {'db_kind': 'small molecule'} the_uuid = a.uuid - self.assertEqual(a.list_object_names(), [basename]) + assert a.list_object_names() == [basename] with a.open() as fhandle: - self.assertEqual(fhandle.read(), file_content) + assert fhandle.read() == file_content a.store() - self.assertEqual(a.source, { + assert a.source == { 'db_name': 'COD', 'id': '0000001', 'version': '1234', - }) + } with a.open() as fhandle: - self.assertEqual(fhandle.read(), file_content) - self.assertEqual(a.list_object_names(), [basename]) + assert fhandle.read() == file_content + assert a.list_object_names() == [basename] b = load_node(the_uuid) # I check the retrieved object - self.assertTrue(isinstance(b, CifData)) - self.assertEqual(b.list_object_names(), [basename]) + assert isinstance(b, CifData) + assert b.list_object_names() == [basename] with b.open() as fhandle: - self.assertEqual(fhandle.read(), file_content) + assert fhandle.read() == file_content # Checking the get_or_create() method: with tempfile.NamedTemporaryFile(mode='w+') as tmpf: @@ -157,9 +162,9 @@ def test_reload_cifdata(self): tmpf.flush() c, created = CifData.get_or_create(tmpf.name, store_cif=False) - self.assertTrue(isinstance(c, CifData)) - self.assertTrue(not created) - self.assertEqual(c.get_content(), file_content) + assert isinstance(c, CifData) + assert not created + assert c.get_content() == file_content other_content = 'data_test _cell_length_b 10(1)' with tempfile.NamedTemporaryFile(mode='w+') as tmpf: @@ -167,11 +172,11 @@ def 
test_reload_cifdata(self): tmpf.flush() c, created = CifData.get_or_create(tmpf.name, store_cif=False) - self.assertTrue(isinstance(c, CifData)) - self.assertTrue(created) - self.assertEqual(c.get_content(), other_content) + assert isinstance(c, CifData) + assert created + assert c.get_content() == other_content - @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') + @skip_pycifrw def test_parse_cifdata(self): """Test parsing a CIF file.""" file_content = 'data_test _cell_length_a 10(1)' @@ -180,9 +185,9 @@ def test_parse_cifdata(self): tmpf.flush() a = CifData(file=tmpf.name) - self.assertEqual(list(a.values.keys()), ['test']) + assert list(a.values.keys()) == ['test'] - @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') + @skip_pycifrw def test_change_cifdata_file(self): """Test changing file for `CifData` before storing.""" file_content_1 = 'data_test _cell_length_a 10(1)' @@ -192,17 +197,17 @@ def test_change_cifdata_file(self): tmpf.flush() a = CifData(file=tmpf.name) - self.assertEqual(a.values['test']['_cell_length_a'], '10(1)') + assert a.values['test']['_cell_length_a'] == '10(1)' with tempfile.NamedTemporaryFile(mode='w+') as tmpf: tmpf.write(file_content_2) tmpf.flush() a.set_file(tmpf.name) - self.assertEqual(a.values['test']['_cell_length_a'], '11(1)') + assert a.values['test']['_cell_length_a'] == '11(1)' - @unittest.skipIf(not has_ase(), 'Unable to import ase') - @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') + @skip_ase + @skip_pycifrw @pytest.mark.requires_rmq def test_get_structure(self): """Test `CifData.get_structure`.""" @@ -232,15 +237,15 @@ def test_get_structure(self): tmpf.flush() a = CifData(file=tmpf.name) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): a.get_structure(converter='none') c = a.get_structure() - self.assertEqual(c.get_kind_names(), ['C', 'O']) + assert c.get_kind_names() == ['C', 'O'] - @unittest.skipIf(not has_ase(), 'Unable to import ase') - 
@unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') + @skip_ase + @skip_pycifrw @pytest.mark.requires_rmq def test_ase_primitive_and_conventional_cells_ase(self): """Checking the number of atoms per primitive/conventional cell @@ -276,17 +281,17 @@ def test_ase_primitive_and_conventional_cells_ase(self): c = CifData(file=tmpf.name) ase = c.get_structure(converter='ase', primitive_cell=False).get_ase() - self.assertEqual(ase.get_global_number_of_atoms(), 15) + assert ase.get_global_number_of_atoms() == 15 ase = c.get_structure(converter='ase').get_ase() - self.assertEqual(ase.get_global_number_of_atoms(), 15) + assert ase.get_global_number_of_atoms() == 15 ase = c.get_structure(converter='ase', primitive_cell=True, subtrans_included=False).get_ase() - self.assertEqual(ase.get_global_number_of_atoms(), 5) + assert ase.get_global_number_of_atoms() == 5 - @unittest.skipIf(not has_ase(), 'Unable to import ase') - @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') - @unittest.skipIf(not has_pymatgen(), 'Unable to import pymatgen') + @skip_ase + @skip_pycifrw + @skip_pymatgen @pytest.mark.requires_rmq def test_ase_primitive_and_conventional_cells_pymatgen(self): """Checking the number of atoms per primitive/conventional cell @@ -332,15 +337,15 @@ def test_ase_primitive_and_conventional_cells_pymatgen(self): c = CifData(file=tmpf.name) ase = c.get_structure(converter='pymatgen', primitive_cell=False).get_ase() - self.assertEqual(ase.get_global_number_of_atoms(), 15) + assert ase.get_global_number_of_atoms() == 15 ase = c.get_structure(converter='pymatgen').get_ase() - self.assertEqual(ase.get_global_number_of_atoms(), 15) + assert ase.get_global_number_of_atoms() == 15 ase = c.get_structure(converter='pymatgen', primitive_cell=True).get_ase() - self.assertEqual(ase.get_global_number_of_atoms(), 5) + assert ase.get_global_number_of_atoms() == 5 - @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') + @skip_pycifrw def 
test_pycifrw_from_datablocks(self): """ Tests CifData.pycifrw_from_cif() @@ -360,8 +365,7 @@ def test_pycifrw_from_datablocks(self): for line in lines: if not re.search('^#', line): non_comments.append(line) - self.assertEqual( - simplify('\n'.join(non_comments)), + assert simplify('\n'.join(non_comments)) == \ simplify( ''' data_0 @@ -380,7 +384,6 @@ def test_pycifrw_from_datablocks(self): _publ_section_title 'Test CIF' ''' ) - ) loops = {'_atom_site': ['_atom_site_label', '_atom_site_occupancy']} with Capturing(): @@ -389,8 +392,7 @@ def test_pycifrw_from_datablocks(self): for line in lines: if not re.search('^#', line): non_comments.append(line) - self.assertEqual( - simplify('\n'.join(non_comments)), + assert simplify('\n'.join(non_comments)) == \ simplify( ''' data_0 @@ -404,9 +406,8 @@ def test_pycifrw_from_datablocks(self): _publ_section_title 'Test CIF' ''' ) - ) - @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') + @skip_pycifrw def test_pycifrw_syntax(self): """Tests CifData.pycifrw_from_cif() - check syntax pb in PyCifRW 3.6.""" import re @@ -422,15 +423,13 @@ def test_pycifrw_syntax(self): for line in lines: if not re.search('^#', line): non_comments.append(line) - self.assertEqual( - simplify('\n'.join(non_comments)), + assert simplify('\n'.join(non_comments)) == \ simplify(''' data_0 _tag '[value]' ''') - ) - @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') + @skip_pycifrw @staticmethod def test_cif_with_long_line(): """Tests CifData - check that long lines (longer than 2048 characters) are supported. 
@@ -443,8 +442,8 @@ def test_cif_with_long_line(): tmpf.flush() _ = CifData(file=tmpf.name) - @unittest.skipIf(not has_ase(), 'Unable to import ase') - @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') + @skip_ase + @skip_pycifrw def test_cif_roundtrip(self): """Test the `CifData` roundtrip.""" with tempfile.NamedTemporaryFile(mode='w+') as tmpf: @@ -473,26 +472,24 @@ def test_cif_roundtrip(self): b = CifData(values=a.values) c = CifData(values=b.values) - self.assertEqual(b._prepare_cif(), c._prepare_cif()) # pylint: disable=protected-access + assert b._prepare_cif() == c._prepare_cif() # pylint: disable=protected-access b = CifData(ase=a.ase) c = CifData(ase=b.ase) - self.assertEqual(b._prepare_cif(), c._prepare_cif()) # pylint: disable=protected-access + assert b._prepare_cif() == c._prepare_cif() # pylint: disable=protected-access def test_symop_string_from_symop_matrix_tr(self): """Test symmetry operations.""" from aiida.tools.data.cif import symop_string_from_symop_matrix_tr - self.assertEqual(symop_string_from_symop_matrix_tr([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), 'x,y,z') + assert symop_string_from_symop_matrix_tr([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) == 'x,y,z' - self.assertEqual(symop_string_from_symop_matrix_tr([[1, 0, 0], [0, -1, 0], [0, 1, 1]]), 'x,-y,y+z') + assert symop_string_from_symop_matrix_tr([[1, 0, 0], [0, -1, 0], [0, 1, 1]]) == 'x,-y,y+z' - self.assertEqual( - symop_string_from_symop_matrix_tr([[-1, 0, 0], [0, 1, 0], [0, 0, 1]], [1, -1, 0]), '-x+1,y-1,z' - ) + assert symop_string_from_symop_matrix_tr([[-1, 0, 0], [0, 1, 0], [0, 0, 1]], [1, -1, 0]) == '-x+1,y-1,z' - @unittest.skipIf(not has_ase(), 'Unable to import ase') - @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') + @skip_ase + @skip_pycifrw def test_attached_hydrogens(self): """Test parsing of file with attached hydrogens.""" with tempfile.NamedTemporaryFile(mode='w+') as tmpf: @@ -519,7 +516,7 @@ def test_attached_hydrogens(self): tmpf.flush() a = 
CifData(file=tmpf.name) - self.assertEqual(a.has_attached_hydrogens, False) + assert a.has_attached_hydrogens is False with tempfile.NamedTemporaryFile(mode='w+') as tmpf: tmpf.write( @@ -545,11 +542,11 @@ def test_attached_hydrogens(self): tmpf.flush() a = CifData(file=tmpf.name) - self.assertEqual(a.has_attached_hydrogens, True) + assert a.has_attached_hydrogens is True - @unittest.skipIf(not has_ase(), 'Unable to import ase') - @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') - @unittest.skipIf(not has_spglib(), 'Unable to import spglib') + @skip_ase + @skip_pycifrw + @skip_spglib @pytest.mark.requires_rmq def test_refine(self): """ @@ -583,14 +580,12 @@ def test_refine(self): ret_dict = refine_inline(a) b = ret_dict['cif'] - self.assertEqual(list(b.values.keys()), ['test']) - self.assertEqual(b.values['test']['_chemical_formula_sum'], 'C O2') - self.assertEqual( - b.values['test']['_symmetry_equiv_pos_as_xyz'], [ - 'x,y,z', '-x,-y,-z', '-y,x,z', 'y,-x,-z', '-x,-y,z', 'x,y,-z', 'y,-x,z', '-y,x,-z', 'x,-y,-z', '-x,y,z', - '-y,-x,-z', 'y,x,z', '-x,y,-z', 'x,-y,z', 'y,x,-z', '-y,-x,z' - ] - ) + assert list(b.values.keys()) == ['test'] + assert b.values['test']['_chemical_formula_sum'] == 'C O2' + assert b.values['test']['_symmetry_equiv_pos_as_xyz'] == [ + 'x,y,z', '-x,-y,-z', '-y,x,z', 'y,-x,-z', '-x,-y,z', 'x,y,-z', 'y,-x,z', '-y,x,-z', 'x,-y,-z', '-x,y,z', + '-y,-x,-z', 'y,x,z', '-x,y,-z', 'x,-y,z', 'y,x,-z', '-y,-x,z' + ] with tempfile.NamedTemporaryFile(mode='w+') as tmpf: tmpf.write(''' @@ -600,10 +595,10 @@ def test_refine(self): tmpf.flush() c = CifData(file=tmpf.name) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): ret_dict = refine_inline(c) - @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') + @skip_pycifrw def test_scan_type(self): """Check that different scan_types of PyCifRW produce the same result.""" with tempfile.NamedTemporaryFile(mode='w+') as tmpf: @@ -612,12 +607,12 @@ def 
test_scan_type(self): default = CifData(file=tmpf.name) default2 = CifData(file=tmpf.name, scan_type='standard') - self.assertEqual(default._prepare_cif(), default2._prepare_cif()) # pylint: disable=protected-access + assert default._prepare_cif() == default2._prepare_cif() # pylint: disable=protected-access flex = CifData(file=tmpf.name, scan_type='flex') - self.assertEqual(default._prepare_cif(), flex._prepare_cif()) # pylint: disable=protected-access + assert default._prepare_cif() == flex._prepare_cif() # pylint: disable=protected-access - @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') + @skip_pycifrw def test_empty_cif(self): """Test empty CifData @@ -630,7 +625,7 @@ def test_empty_cif(self): a = CifData() # but it does not have a file - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): _ = a.filename #now it has @@ -639,7 +634,7 @@ def test_empty_cif(self): a.store() - @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') + @skip_pycifrw def test_parse_policy(self): """Test that loading of CIF file occurs as defined by parse_policy.""" with tempfile.NamedTemporaryFile(mode='w+') as tmpf: @@ -648,20 +643,20 @@ def test_parse_policy(self): # this will parse the cif eager = CifData(file=tmpf.name, parse_policy='eager') - self.assertIsNot(eager._values, None) # pylint: disable=protected-access + assert eager._values is not None # pylint: disable=protected-access # this should not parse the cif lazy = CifData(file=tmpf.name, parse_policy='lazy') - self.assertIs(lazy._values, None) # pylint: disable=protected-access + assert lazy._values is None # pylint: disable=protected-access # also lazy-loaded nodes should be storable lazy.store() # this should parse the cif _ = lazy.values - self.assertIsNot(lazy._values, None) # pylint: disable=protected-access + assert lazy._values is not None # pylint: disable=protected-access - @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') + @skip_pycifrw def 
test_set_file(self): """Test that setting a new file clears formulae and spacegroups.""" with tempfile.NamedTemporaryFile(mode='w+') as tmpf: @@ -670,7 +665,7 @@ def test_set_file(self): a = CifData(file=tmpf.name) f1 = a.get_formulae() - self.assertIsNot(f1, None) + assert f1 is not None with tempfile.NamedTemporaryFile(mode='w+') as tmpf: tmpf.write(self.valid_sample_cif_str_2) @@ -678,27 +673,27 @@ def test_set_file(self): # this should reset formulae and spacegroup_numbers a.set_file(tmpf.name) - self.assertIs(a.get_attribute('formulae'), None) - self.assertIs(a.get_attribute('spacegroup_numbers'), None) + assert a.get_attribute('formulae') is None + assert a.get_attribute('spacegroup_numbers') is None # this should populate formulae a.parse() f2 = a.get_formulae() - self.assertIsNot(f2, None) + assert f2 is not None # empty cifdata should be possible a = CifData() # but it does not have a file - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): _ = a.filename # now it has a.set_file(tmpf.name) a.parse() _ = a.filename - self.assertNotEqual(f1, f2) + assert f1 != f2 - @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') + @skip_pycifrw def test_has_partial_occupancies(self): """Test structure with partial occupancies.""" tests = [ @@ -728,9 +723,9 @@ def test_has_partial_occupancies(self): ) handle.flush() cif = CifData(file=handle.name) - self.assertEqual(cif.has_partial_occupancies, result) + assert cif.has_partial_occupancies == result - @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') + @skip_pycifrw def test_has_unknown_species(self): """Test structure with unknown species.""" tests = [ @@ -746,9 +741,9 @@ def test_has_unknown_species(self): handle.write(f"""data_test\n{formula_string}\n""") handle.flush() cif = CifData(file=handle.name) - self.assertEqual(cif.has_unknown_species, result, formula_string) + assert cif.has_unknown_species == result, formula_string - @unittest.skipIf(not has_pycifrw(), 
'Unable to import PyCifRW') + @skip_pycifrw def test_has_undefined_atomic_sites(self): """Test structure with undefined atomic sites.""" tests = [ @@ -764,20 +759,20 @@ def test_has_undefined_atomic_sites(self): handle.write(f"""data_test\n{atomic_site_string}\n""") handle.flush() cif = CifData(file=handle.name) - self.assertEqual(cif.has_undefined_atomic_sites, result) + assert cif.has_undefined_atomic_sites == result -class TestKindValidSymbols(AiidaTestCase): +class TestKindValidSymbols: """Tests the symbol validation of the aiida.orm.nodes.data.structure.Kind class.""" def test_bad_symbol(self): """Should not accept a non-existing symbol.""" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): Kind(symbols='Hxx') def test_empty_list_symbols(self): """Should not accept an empty list.""" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): Kind(symbols=[]) @staticmethod @@ -791,35 +786,35 @@ def test_unknown_symbol(): Kind(symbols=['X']) -class TestSiteValidWeights(AiidaTestCase): +class TestSiteValidWeights: """Tests valid weight lists.""" def test_isnot_list(self): """Should not accept a non-list, non-number weight.""" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): Kind(symbols='Ba', weights='aaa') def test_empty_list_weights(self): """Should not accept an empty list.""" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): Kind(symbols='Ba', weights=[]) def test_symbol_weight_mismatch(self): """Should not accept a size mismatch of the symbols and weights list.""" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): Kind(symbols=['Ba', 'C'], weights=[1.]) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): Kind(symbols=['Ba'], weights=[0.1, 0.2]) def test_negative_value(self): """Should not accept a negative weight.""" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): Kind(symbols=['Ba', 'C'], weights=[-0.1, 0.3]) def 
test_sum_greater_one(self): """Should not accept a sum of weights larger than one.""" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): Kind(symbols=['Ba', 'C'], weights=[0.5, 0.6]) @staticmethod @@ -838,24 +833,24 @@ def test_none(): Kind(symbols='Ba', weights=None) -class TestKindTestGeneral(AiidaTestCase): +class TestKindTestGeneral: """Tests the creation of Kind objects and their methods.""" def test_sum_one_general(self): """Should accept a sum equal to one.""" a = Kind(symbols=['Ba', 'C'], weights=[1. / 3., 2. / 3.]) - self.assertTrue(a.is_alloy) - self.assertFalse(a.has_vacancies) + assert a.is_alloy + assert not a.has_vacancies def test_sum_less_one_general(self): """Should accept a sum equal less than one.""" a = Kind(symbols=['Ba', 'C'], weights=[1. / 3., 1. / 3.]) - self.assertTrue(a.is_alloy) - self.assertTrue(a.has_vacancies) + assert a.is_alloy + assert a.has_vacancies def test_no_position(self): """Should not accept a 'positions' parameter.""" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): Kind(position=[0., 0., 0.], symbols=['Ba'], weights=[1.]) def test_simple(self): @@ -863,45 +858,45 @@ def test_simple(self): Should recognize a simple element. """ a = Kind(symbols='Ba') - self.assertFalse(a.is_alloy) - self.assertFalse(a.has_vacancies) + assert not a.is_alloy + assert not a.has_vacancies b = Kind(symbols='Ba', weights=1.) - self.assertFalse(b.is_alloy) - self.assertFalse(b.has_vacancies) + assert not b.is_alloy + assert not b.has_vacancies c = Kind(symbols='Ba', weights=None) - self.assertFalse(c.is_alloy) - self.assertFalse(c.has_vacancies) + assert not c.is_alloy + assert not c.has_vacancies def test_automatic_name(self): """ Check the automatic name generator. """ a = Kind(symbols='Ba') - self.assertEqual(a.name, 'Ba') + assert a.name == 'Ba' a = Kind(symbols='X') - self.assertEqual(a.name, 'X') + assert a.name == 'X' a = Kind(symbols=('Si', 'Ge'), weights=(1. / 3., 2. 
/ 3.)) - self.assertEqual(a.name, 'GeSi') + assert a.name == 'GeSi' a = Kind(symbols=('Si', 'X'), weights=(1. / 3., 2. / 3.)) - self.assertEqual(a.name, 'SiX') + assert a.name == 'SiX' a = Kind(symbols=('Si', 'Ge'), weights=(0.4, 0.5)) - self.assertEqual(a.name, 'GeSiX') + assert a.name == 'GeSiX' a = Kind(symbols=('Si', 'X'), weights=(0.4, 0.5)) - self.assertEqual(a.name, 'SiXX') + assert a.name == 'SiXX' # Manually setting the name of the species a.name = 'newstring' - self.assertEqual(a.name, 'newstring') + assert a.name == 'newstring' -class TestKindTestMasses(AiidaTestCase): +class TestKindTestMasses: """ Tests the management of masses during the creation of Kind objects. """ @@ -910,38 +905,33 @@ def test_auto_mass_one(self): """ mass for elements with sum one """ - from aiida.orm.nodes.data.structure import _atomic_masses - a = Kind(symbols=['Ba', 'C'], weights=[1. / 3., 2. / 3.]) - self.assertAlmostEqual(a.mass, (_atomic_masses['Ba'] + 2. * _atomic_masses['C']) / 3.) + assert round(abs(a.mass - (_atomic_masses['Ba'] + 2. * _atomic_masses['C']) / 3.), 7) == 0 def test_sum_less_one_masses(self): """ mass for elements with sum less than one """ - from aiida.orm.nodes.data.structure import _atomic_masses - a = Kind(symbols=['Ba', 'C'], weights=[1. / 3., 1. / 3.]) - self.assertAlmostEqual(a.mass, (_atomic_masses['Ba'] + _atomic_masses['C']) / 2.) + assert round(abs(a.mass - (_atomic_masses['Ba'] + _atomic_masses['C']) / 2.), 7) == 0 def test_sum_less_one_singleelem(self): """ mass for a single element """ - from aiida.orm.nodes.data.structure import _atomic_masses - a = Kind(symbols=['Ba']) - self.assertAlmostEqual(a.mass, _atomic_masses['Ba']) + assert round(abs(a.mass - _atomic_masses['Ba']), 7) == 0 def test_manual_mass(self): """ mass set manually """ a = Kind(symbols=['Ba', 'C'], weights=[1. / 3., 1. / 3.], mass=1000.) - self.assertAlmostEqual(a.mass, 1000.) 
+ assert round(abs(a.mass - 1000.), 7) == 0 -class TestStructureDataInit(AiidaTestCase): +@pytest.mark.usefixtures('aiida_profile_clean') +class TestStructureDataInit: """ Tests the creation of StructureData objects (cell and pbc). """ @@ -950,28 +940,28 @@ def test_cell_wrong_size_1(self): """ Wrong cell size (not 3x3) """ - with self.assertRaises(ValueError): + with pytest.raises(ValueError): StructureData(cell=((1., 2., 3.),)) def test_cell_wrong_size_2(self): """ Wrong cell size (not 3x3) """ - with self.assertRaises(ValueError): + with pytest.raises(ValueError): StructureData(cell=((1., 0., 0.), (0., 0., 3.), (0., 3.))) def test_cell_zero_vector(self): """ Wrong cell (one vector has zero length) """ - with self.assertRaises(ValueError): + with pytest.raises(ValueError): StructureData(cell=((0., 0., 0.), (0., 1., 0.), (0., 0., 1.))) def test_cell_zero_volume(self): """ Wrong cell (volume is zero) """ - with self.assertRaises(ValueError): + with pytest.raises(ValueError): StructureData(cell=((1., 0., 0.), (0., 1., 0.), (1., 1., 0.))) def test_cell_ok_init(self): @@ -984,20 +974,20 @@ def test_cell_ok_init(self): for i in range(3): for j in range(3): - self.assertAlmostEqual(cell[i][j], out_cell[i][j]) + assert round(abs(cell[i][j] - out_cell[i][j]), 7) == 0 def test_volume(self): """ Check the volume calculation """ a = StructureData(cell=((1., 0., 0.), (0., 2., 0.), (0., 0., 3.))) - self.assertAlmostEqual(a.get_cell_volume(), 6.) 
+ assert round(abs(a.get_cell_volume() - 6.), 7) == 0 def test_wrong_pbc_1(self): """ Wrong pbc parameter (not bool or iterable) """ - with self.assertRaises(ValueError): + with pytest.raises(ValueError): cell = ((1., 0., 0.), (0., 2., 0.), (0., 0., 3.)) StructureData(cell=cell, pbc=1) @@ -1005,7 +995,7 @@ def test_wrong_pbc_2(self): """ Wrong pbc parameter (iterable but with wrong len) """ - with self.assertRaises(ValueError): + with pytest.raises(ValueError): cell = ((1., 0., 0.), (0., 2., 0.), (0., 0., 3.)) StructureData(cell=cell, pbc=[True, True]) @@ -1013,7 +1003,7 @@ def test_wrong_pbc_3(self): """ Wrong pbc parameter (iterable but with wrong len) """ - with self.assertRaises(ValueError): + with pytest.raises(ValueError): cell = ((1., 0., 0.), (0., 2., 0.), (0., 0., 3.)) StructureData(cell=cell, pbc=[]) @@ -1023,10 +1013,10 @@ def test_ok_pbc_1(self): """ cell = ((1., 0., 0.), (0., 2., 0.), (0., 0., 3.)) a = StructureData(cell=cell, pbc=True) - self.assertEqual(a.pbc, tuple([True, True, True])) + assert a.pbc == tuple([True, True, True]) a = StructureData(cell=cell, pbc=False) - self.assertEqual(a.pbc, tuple([False, False, False])) + assert a.pbc == tuple([False, False, False]) def test_ok_pbc_2(self): """ @@ -1034,10 +1024,10 @@ def test_ok_pbc_2(self): """ cell = ((1., 0., 0.), (0., 2., 0.), (0., 0., 3.)) a = StructureData(cell=cell, pbc=[True]) - self.assertEqual(a.pbc, tuple([True, True, True])) + assert a.pbc == tuple([True, True, True]) a = StructureData(cell=cell, pbc=[False]) - self.assertEqual(a.pbc, tuple([False, False, False])) + assert a.pbc == tuple([False, False, False]) def test_ok_pbc_3(self): """ @@ -1045,15 +1035,14 @@ def test_ok_pbc_3(self): """ cell = ((1., 0., 0.), (0., 2., 0.), (0., 0., 3.)) a = StructureData(cell=cell, pbc=[True, False, True]) - self.assertEqual(a.pbc, tuple([True, False, True])) + assert a.pbc == tuple([True, False, True]) -class TestStructureData(AiidaTestCase): +@pytest.mark.usefixtures('aiida_profile_clean_class') 
+class TestStructureData: """ Tests the creation of StructureData objects (cell and pbc). """ - # pylint: disable=too-many-public-methods - from aiida.orm.nodes.data.cif import has_pycifrw def test_cell_ok_and_atoms(self): """ @@ -1063,29 +1052,29 @@ def test_cell_ok_and_atoms(self): a = StructureData(cell=cell) out_cell = a.cell - self.assertAlmostEqual(cell, out_cell) + np.testing.assert_allclose(out_cell, cell) a.append_atom(position=(0., 0., 0.), symbols=['Ba']) a.append_atom(position=(1., 1., 1.), symbols=['Ti']) a.append_atom(position=(1.2, 1.4, 1.6), symbols=['Ti']) - self.assertFalse(a.is_alloy) - self.assertFalse(a.has_vacancies) + assert not a.is_alloy + assert not a.has_vacancies # There should be only two kinds! (two atoms of kind Ti should # belong to the same kind) - self.assertEqual(len(a.kinds), 2) + assert len(a.kinds) == 2 a.append_atom(position=(0.5, 1., 1.5), symbols=['O', 'C'], weights=[0.5, 0.5]) - self.assertTrue(a.is_alloy) - self.assertFalse(a.has_vacancies) + assert a.is_alloy + assert not a.has_vacancies a.append_atom(position=(0.5, 1., 1.5), symbols=['O'], weights=[0.5]) - self.assertTrue(a.is_alloy) - self.assertTrue(a.has_vacancies) + assert a.is_alloy + assert a.has_vacancies a.clear_kinds() a.append_atom(position=(0.5, 1., 1.5), symbols=['O'], weights=[0.5]) - self.assertFalse(a.is_alloy) - self.assertTrue(a.has_vacancies) + assert not a.is_alloy + assert a.has_vacancies def test_cell_ok_and_unknown_atoms(self): """ @@ -1096,29 +1085,29 @@ def test_cell_ok_and_unknown_atoms(self): a = StructureData(cell=cell) out_cell = a.cell - self.assertAlmostEqual(cell, out_cell) + np.testing.assert_allclose(out_cell, cell) a.append_atom(position=(0., 0., 0.), symbols=['Ba']) a.append_atom(position=(1., 1., 1.), symbols=['X']) a.append_atom(position=(1.2, 1.4, 1.6), symbols=['X']) - self.assertFalse(a.is_alloy) - self.assertFalse(a.has_vacancies) + assert not a.is_alloy + assert not a.has_vacancies # There should be only two kinds! 
(two atoms of kind X should # belong to the same kind) - self.assertEqual(len(a.kinds), 2) + assert len(a.kinds) == 2 a.append_atom(position=(0.5, 1., 1.5), symbols=['O', 'C'], weights=[0.5, 0.5]) - self.assertTrue(a.is_alloy) - self.assertFalse(a.has_vacancies) + assert a.is_alloy + assert not a.has_vacancies a.append_atom(position=(0.5, 1., 1.5), symbols=['O'], weights=[0.5]) - self.assertTrue(a.is_alloy) - self.assertTrue(a.has_vacancies) + assert a.is_alloy + assert a.has_vacancies a.clear_kinds() a.append_atom(position=(0.5, 1., 1.5), symbols=['X'], weights=[0.5]) - self.assertFalse(a.is_alloy) - self.assertTrue(a.has_vacancies) + assert not a.is_alloy + assert a.has_vacancies def test_kind_1(self): """ @@ -1131,9 +1120,9 @@ def test_kind_1(self): a.append_atom(position=(0.5, 0.5, 0.5), symbols=['Ba']) a.append_atom(position=(1., 1., 1.), symbols=['Ti']) - self.assertEqual(len(a.kinds), 2) # I should only have two types + assert len(a.kinds) == 2 # I should only have two types # I check for the default names of kinds - self.assertEqual(set(k.name for k in a.kinds), set(('Ba', 'Ti'))) + assert set(k.name for k in a.kinds) == set(('Ba', 'Ti')) def test_kind_1_unknown(self): """ @@ -1146,9 +1135,9 @@ def test_kind_1_unknown(self): a.append_atom(position=(0.5, 0.5, 0.5), symbols=['X']) a.append_atom(position=(1., 1., 1.), symbols=['Ti']) - self.assertEqual(len(a.kinds), 2) # I should only have two types + assert len(a.kinds) == 2 # I should only have two types # I check for the default names of kinds - self.assertEqual(set(k.name for k in a.kinds), set(('X', 'Ti'))) + assert set(k.name for k in a.kinds) == set(('X', 'Ti')) def test_kind_2(self): """ @@ -1161,8 +1150,8 @@ def test_kind_2(self): a.append_atom(position=(1., 1., 1.), symbols=['Ti']) kind_list = a.kinds - self.assertEqual(len(kind_list), 3) # I should have now three kinds - self.assertEqual(set(k.name for k in kind_list), set(('Ba1', 'Ba2', 'Ti'))) + assert len(kind_list) == 3 # I should have now three 
kinds + assert set(k.name for k in kind_list) == set(('Ba1', 'Ba2', 'Ti')) def test_kind_2_unknown(self): """ @@ -1176,8 +1165,8 @@ def test_kind_2_unknown(self): a.append_atom(position=(1., 1., 1.), symbols=['Ti']) kind_list = a.kinds - self.assertEqual(len(kind_list), 3) # I should have now three kinds - self.assertEqual(set(k.name for k in kind_list), set(('X1', 'X2', 'Ti'))) + assert len(kind_list) == 3 # I should have now three kinds + assert set(k.name for k in kind_list) == set(('X1', 'X2', 'Ti')) def test_kind_3(self): """ @@ -1186,7 +1175,7 @@ def test_kind_3(self): a = StructureData(cell=((2., 0., 0.), (0., 2., 0.), (0., 0., 2.))) a.append_atom(position=(0., 0., 0.), symbols=['Ba'], mass=100.) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): # Shouldn't allow, I am adding two sites with the same name 'Ba' a.append_atom(position=(0.5, 0.5, 0.5), symbols=['Ba'], mass=101., name='Ba') @@ -1195,9 +1184,9 @@ def test_kind_3(self): a.append_atom(position=(1., 1., 1.), symbols=['Ti']) - self.assertEqual(len(a.kinds), 3) # I should have now three types - self.assertEqual(len(a.sites), 3) # and 3 sites - self.assertEqual(set(k.name for k in a.kinds), set(('Ba', 'Ba2', 'Ti'))) + assert len(a.kinds) == 3 # I should have now three types + assert len(a.sites) == 3 # and 3 sites + assert set(k.name for k in a.kinds) == set(('Ba', 'Ba2', 'Ti')) def test_kind_3_unknown(self): """ @@ -1207,7 +1196,7 @@ def test_kind_3_unknown(self): a = StructureData(cell=((2., 0., 0.), (0., 2., 0.), (0., 0., 2.))) a.append_atom(position=(0., 0., 0.), symbols=['X'], mass=100.) 
- with self.assertRaises(ValueError): + with pytest.raises(ValueError): # Shouldn't allow, I am adding two sites with the same name 'Ba' a.append_atom(position=(0.5, 0.5, 0.5), symbols=['X'], mass=101., name='X') @@ -1216,9 +1205,9 @@ def test_kind_3_unknown(self): a.append_atom(position=(1., 1., 1.), symbols=['Ti']) - self.assertEqual(len(a.kinds), 3) # I should have now three types - self.assertEqual(len(a.sites), 3) # and 3 sites - self.assertEqual(set(k.name for k in a.kinds), set(('X', 'X2', 'Ti'))) + assert len(a.kinds) == 3 # I should have now three types + assert len(a.sites) == 3 # and 3 sites + assert set(k.name for k in a.kinds) == set(('X', 'X2', 'Ti')) def test_kind_4(self): """ @@ -1229,26 +1218,26 @@ def test_kind_4(self): a.append_atom(position=(0., 0., 0.), symbols=['Ba', 'Ti'], weights=(1., 0.), name='mytype') - with self.assertRaises(ValueError): + with pytest.raises(ValueError): # Shouldn't allow, different weights a.append_atom(position=(0.5, 0.5, 0.5), symbols=['Ba', 'Ti'], weights=(0.9, 0.1), name='mytype') - with self.assertRaises(ValueError): + with pytest.raises(ValueError): # Shouldn't allow, different weights (with vacancy) a.append_atom(position=(0.5, 0.5, 0.5), symbols=['Ba', 'Ti'], weights=(0.8, 0.1), name='mytype') - with self.assertRaises(ValueError): + with pytest.raises(ValueError): # Shouldn't allow, different symbols list a.append_atom(position=(0.5, 0.5, 0.5), symbols=['Ba'], name='mytype') - with self.assertRaises(ValueError): + with pytest.raises(ValueError): # Shouldn't allow, different symbols list a.append_atom(position=(0.5, 0.5, 0.5), symbols=['Si', 'Ti'], weights=(1., 0.), name='mytype') # should allow because every property is identical a.append_atom(position=(0., 0., 0.), symbols=['Ba', 'Ti'], weights=(1., 0.), name='mytype') - self.assertEqual(len(a.kinds), 1) + assert len(a.kinds) == 1 def test_kind_4_unknown(self): """ @@ -1259,26 +1248,26 @@ def test_kind_4_unknown(self): a.append_atom(position=(0., 0., 0.), 
symbols=['X', 'Ti'], weights=(1., 0.), name='mytype') - with self.assertRaises(ValueError): + with pytest.raises(ValueError): # Shouldn't allow, different weights a.append_atom(position=(0.5, 0.5, 0.5), symbols=['X', 'Ti'], weights=(0.9, 0.1), name='mytype') - with self.assertRaises(ValueError): + with pytest.raises(ValueError): # Shouldn't allow, different weights (with vacancy) a.append_atom(position=(0.5, 0.5, 0.5), symbols=['X', 'Ti'], weights=(0.8, 0.1), name='mytype') - with self.assertRaises(ValueError): + with pytest.raises(ValueError): # Shouldn't allow, different symbols list a.append_atom(position=(0.5, 0.5, 0.5), symbols=['X'], name='mytype') - with self.assertRaises(ValueError): + with pytest.raises(ValueError): # Shouldn't allow, different symbols list a.append_atom(position=(0.5, 0.5, 0.5), symbols=['Si', 'Ti'], weights=(1., 0.), name='mytype') # should allow because every property is identical a.append_atom(position=(0., 0., 0.), symbols=['X', 'Ti'], weights=(1., 0.), name='mytype') - self.assertEqual(len(a.kinds), 1) + assert len(a.kinds) == 1 def test_kind_5(self): """ @@ -1294,15 +1283,15 @@ def test_kind_5(self): # The name already exists, but the properties are identical => OK a.append_atom(position=(1., 1., 1.), symbols='Ti', name='Ti2') # The name already exists, but the properties are different! - with self.assertRaises(ValueError): + with pytest.raises(ValueError): a.append_atom(position=(1., 1., 1.), symbols='Ti', mass=100., name='Ti2') # Should not complain, should create a new type a.append_atom(position=(0., 0., 0.), symbols='Ba', mass=150.) 
# There should be 4 kinds, the automatic name for the last one # should be Ba1 - self.assertEqual([k.name for k in a.kinds], ['Ba', 'Ti', 'Ti2', 'Ba1']) - self.assertEqual(len(a.sites), 5) + assert [k.name for k in a.kinds] == ['Ba', 'Ti', 'Ti2', 'Ba1'] + assert len(a.sites) == 5 def test_kind_5_unknown(self): """ @@ -1319,15 +1308,15 @@ def test_kind_5_unknown(self): # The name already exists, but the properties are identical => OK a.append_atom(position=(1., 1., 1.), symbols='Ti', name='Ti2') # The name already exists, but the properties are different! - with self.assertRaises(ValueError): + with pytest.raises(ValueError): a.append_atom(position=(1., 1., 1.), symbols='Ti', mass=100., name='Ti2') # Should not complain, should create a new type a.append_atom(position=(0., 0., 0.), symbols='X', mass=150.) # There should be 4 kinds, the automatic name for the last one # should be Ba1 - self.assertEqual([k.name for k in a.kinds], ['X', 'Ti', 'Ti2', 'X1']) - self.assertEqual(len(a.sites), 5) + assert [k.name for k in a.kinds] == ['X', 'Ti', 'Ti2', 'X1'] + assert len(a.sites) == 5 def test_kind_5_bis(self): """Test the management of kinds (automatic creation of new kind @@ -1346,10 +1335,10 @@ def test_kind_5_bis(self): # I expect only two species, the first one with name 'Fe', mass 12, # and referencing the first three atoms; the second with name # 'Fe1', mass = elements[26]['mass'], and referencing the last two atoms - self.assertEqual({(k.name, k.mass) for k in s.kinds}, {('Fe', 12.0), ('Fe1', elements[26]['mass'])}) + assert {(k.name, k.mass) for k in s.kinds} == {('Fe', 12.0), ('Fe1', elements[26]['mass'])} kind_of_each_site = [site.kind_name for site in s.sites] - self.assertEqual(kind_of_each_site, ['Fe', 'Fe', 'Fe', 'Fe1', 'Fe1']) + assert kind_of_each_site == ['Fe', 'Fe', 'Fe', 'Fe1', 'Fe1'] def test_kind_5_bis_unknown(self): """Test the management of kinds (automatic creation of new kind @@ -1369,12 +1358,12 @@ def test_kind_5_bis_unknown(self): # I expect 
only two species, the first one with name 'X', mass 12, # and referencing the first three atoms; the second with name # 'X', mass = elements[0]['mass'], and referencing the last two atoms - self.assertEqual({(k.name, k.mass) for k in s.kinds}, {('X', 12.0), ('X1', elements[0]['mass'])}) + assert {(k.name, k.mass) for k in s.kinds} == {('X', 12.0), ('X1', elements[0]['mass'])} kind_of_each_site = [site.kind_name for site in s.sites] - self.assertEqual(kind_of_each_site, ['X', 'X', 'X', 'X1', 'X1']) + assert kind_of_each_site == ['X', 'X', 'X', 'X1', 'X1'] - @unittest.skipIf(not has_ase(), 'Unable to import ase') + @skip_ase def test_kind_5_bis_ase(self): """ Same test as test_kind_5_bis, but using ase @@ -1399,12 +1388,12 @@ def test_kind_5_bis_ase(self): # I expect only two species, the first one with name 'Fe', mass 12, # and referencing the first three atoms; the second with name # 'Fe1', mass = elements[26]['mass'], and referencing the last two atoms - self.assertEqual({(k.name, k.mass) for k in s.kinds}, {('Fe', 12.0), ('Fe1', asecell[3].mass)}) + assert {(k.name, k.mass) for k in s.kinds} == {('Fe', 12.0), ('Fe1', asecell[3].mass)} kind_of_each_site = [site.kind_name for site in s.sites] - self.assertEqual(kind_of_each_site, ['Fe', 'Fe', 'Fe', 'Fe1', 'Fe1']) + assert kind_of_each_site == ['Fe', 'Fe', 'Fe', 'Fe1', 'Fe1'] - @unittest.skipIf(not has_ase(), 'Unable to import ase') + @skip_ase def test_kind_5_bis_ase_unknown(self): """ Same test as test_kind_5_bis_unknown, but using ase @@ -1429,10 +1418,10 @@ def test_kind_5_bis_ase_unknown(self): # I expect only two species, the first one with name 'X', mass 12, # and referencing the first three atoms; the second with name # 'X1', mass = elements[26]['mass'], and referencing the last two atoms - self.assertEqual({(k.name, k.mass) for k in s.kinds}, {('X', 12.0), ('X1', asecell[3].mass)}) + assert {(k.name, k.mass) for k in s.kinds} == {('X', 12.0), ('X1', asecell[3].mass)} kind_of_each_site = [site.kind_name for 
site in s.sites] - self.assertEqual(kind_of_each_site, ['X', 'X', 'X', 'X1', 'X1']) + assert kind_of_each_site == ['X', 'X', 'X', 'X1', 'X1'] def test_kind_6(self): """ @@ -1451,15 +1440,15 @@ def test_kind_6(self): a.append_atom(position=(0., 0., 0.), symbols='Ba', mass=150.) # There should be 4 kinds, the automatic name for the last one # should be Ba1 (same check of test_kind_5 - self.assertEqual([k.name for k in a.kinds], ['Ba', 'Ti', 'Ti2', 'Ba1']) + assert [k.name for k in a.kinds] == ['Ba', 'Ti', 'Ti2', 'Ba1'] ############################# # Here I start the real tests # No such kind - with self.assertRaises(ValueError): + with pytest.raises(ValueError): a.get_kind('Ti3') k = a.get_kind('Ba1') - self.assertEqual(k.symbols, ('Ba',)) - self.assertAlmostEqual(k.mass, 150.) + assert k.symbols == ('Ba',) + assert round(abs(k.mass - 150.), 7) == 0 def test_kind_6_unknown(self): """ @@ -1478,15 +1467,15 @@ def test_kind_6_unknown(self): a.append_atom(position=(0., 0., 0.), symbols='X', mass=150.) # There should be 4 kinds, the automatic name for the last one # should be Ba1 (same check of test_kind_5 - self.assertEqual([k.name for k in a.kinds], ['X', 'Ti', 'Ti2', 'X1']) + assert [k.name for k in a.kinds] == ['X', 'Ti', 'Ti2', 'X1'] ############################# # Here I start the real tests # No such kind - with self.assertRaises(ValueError): + with pytest.raises(ValueError): a.get_kind('Ti3') k = a.get_kind('X1') - self.assertEqual(k.symbols, ('X',)) - self.assertAlmostEqual(k.mass, 150.) + assert k.symbols == ('X',) + assert round(abs(k.mass - 150.), 7) == 0 def test_kind_7(self): """ @@ -1501,7 +1490,7 @@ def test_kind_7(self): # The name already exists, but the properties are identical => OK a.append_atom(position=(0., 0., 0.), symbols=['O', 'H'], weights=[0.9, 0.1], mass=15.) 
- self.assertEqual(a.get_symbols_set(), set(['Ba', 'Ti', 'O', 'H'])) + assert a.get_symbols_set() == set(['Ba', 'Ti', 'O', 'H']) def test_kind_7_unknown(self): """ @@ -1517,10 +1506,10 @@ def test_kind_7_unknown(self): # The name already exists, but the properties are identical => OK a.append_atom(position=(0., 0., 0.), symbols=['O', 'H'], weights=[0.9, 0.1], mass=15.) - self.assertEqual(a.get_symbols_set(), set(['Ba', 'X', 'O', 'H'])) + assert a.get_symbols_set() == set(['Ba', 'X', 'O', 'H']) - @unittest.skipIf(not has_ase(), 'Unable to import ase') - @unittest.skipIf(not has_spglib(), 'Unable to import spglib') + @skip_ase + @skip_spglib def test_kind_8(self): """ Test the ase_refine_cell() function @@ -1529,7 +1518,6 @@ def test_kind_8(self): import math import ase - import numpy a = ase.Atoms(cell=[10, 10, 10]) a.append(ase.Atom('C', [0, 0, 0])) @@ -1537,9 +1525,9 @@ def test_kind_8(self): b, sym = ase_refine_cell(a) sym.pop('rotations') sym.pop('translations') - self.assertEqual(b.get_chemical_symbols(), ['C']) - self.assertEqual(b.cell.tolist(), [[10, 0, 0], [0, 10, 0], [0, 0, 5]]) - self.assertEqual(sym, {'hall': '-P 4 2', 'hm': 'P4/mmm', 'tables': 123}) + assert b.get_chemical_symbols() == ['C'] + assert b.cell.tolist() == [[10, 0, 0], [0, 10, 0], [0, 0, 5]] + assert sym == {'hall': '-P 4 2', 'hm': 'P4/mmm', 'tables': 123} a = ase.Atoms(cell=[10, 2 * math.sqrt(75), 10]) a.append(ase.Atom('C', [0, 0, 0])) @@ -1547,9 +1535,9 @@ def test_kind_8(self): b, sym = ase_refine_cell(a) sym.pop('rotations') sym.pop('translations') - self.assertEqual(b.get_chemical_symbols(), ['C']) - self.assertEqual(numpy.round(b.cell, 2).tolist(), [[10, 0, 0], [-5, 8.66, 0], [0, 0, 10]]) - self.assertEqual(sym, {'hall': '-P 6 2', 'hm': 'P6/mmm', 'tables': 191}) + assert b.get_chemical_symbols() == ['C'] + assert np.round(b.cell, 2).tolist() == [[10, 0, 0], [-5, 8.66, 0], [0, 0, 10]] + assert sym == {'hall': '-P 6 2', 'hm': 'P6/mmm', 'tables': 191} a = ase.Atoms(cell=[[10, 0, 0], 
[-10, 10, 0], [0, 0, 10]]) a.append(ase.Atom('C', [5, 5, 5])) @@ -1557,10 +1545,10 @@ def test_kind_8(self): b, sym = ase_refine_cell(a) sym.pop('rotations') sym.pop('translations') - self.assertEqual(b.get_chemical_symbols(), ['C', 'F']) - self.assertEqual(b.cell.tolist(), [[10, 0, 0], [0, 10, 0], [0, 0, 10]]) - self.assertEqual(b.get_scaled_positions().tolist(), [[0.5, 0.5, 0.5], [0, 0, 0]]) - self.assertEqual(sym, {'hall': '-P 4 2 3', 'hm': 'Pm-3m', 'tables': 221}) + assert b.get_chemical_symbols() == ['C', 'F'] + assert b.cell.tolist() == [[10, 0, 0], [0, 10, 0], [0, 0, 10]] + assert b.get_scaled_positions().tolist() == [[0.5, 0.5, 0.5], [0, 0, 0]] + assert sym == {'hall': '-P 4 2 3', 'hm': 'Pm-3m', 'tables': 221} a = ase.Atoms(cell=[[10, 0, 0], [-10, 10, 0], [0, 0, 10]]) a.append(ase.Atom('C', [0, 0, 0])) @@ -1568,18 +1556,18 @@ def test_kind_8(self): b, sym = ase_refine_cell(a) sym.pop('rotations') sym.pop('translations') - self.assertEqual(b.get_chemical_symbols(), ['C', 'F']) - self.assertEqual(b.cell.tolist(), [[10, 0, 0], [0, 10, 0], [0, 0, 10]]) - self.assertEqual(b.get_scaled_positions().tolist(), [[0, 0, 0], [0.5, 0.5, 0.5]]) - self.assertEqual(sym, {'hall': '-P 4 2 3', 'hm': 'Pm-3m', 'tables': 221}) + assert b.get_chemical_symbols() == ['C', 'F'] + assert b.cell.tolist() == [[10, 0, 0], [0, 10, 0], [0, 0, 10]] + assert b.get_scaled_positions().tolist() == [[0, 0, 0], [0.5, 0.5, 0.5]] + assert sym == {'hall': '-P 4 2 3', 'hm': 'Pm-3m', 'tables': 221} a = ase.Atoms(cell=[[12.132, 0, 0], [0, 6.0606, 0], [0, 0, 8.0956]]) a.append(ase.Atom('Ba', [1.5334848, 1.3999986, 2.00042276])) b, sym = ase_refine_cell(a) sym.pop('rotations') sym.pop('translations') - self.assertEqual(b.cell.tolist(), [[6.0606, 0, 0], [0, 8.0956, 0], [0, 0, 12.132]]) - self.assertEqual(b.get_scaled_positions().tolist(), [[0, 0, 0]]) + assert b.cell.tolist() == [[6.0606, 0, 0], [0, 8.0956, 0], [0, 0, 12.132]] + assert b.get_scaled_positions().tolist() == [[0, 0, 0]] a = 
ase.Atoms(cell=[10, 10, 10]) a.append(ase.Atom('C', [5, 5, 5])) @@ -1588,8 +1576,8 @@ def test_kind_8(self): b, sym = ase_refine_cell(a) sym.pop('rotations') sym.pop('translations') - self.assertEqual(b.get_chemical_symbols(), ['C', 'O']) - self.assertEqual(sym, {'hall': '-P 4 2', 'hm': 'P4/mmm', 'tables': 123}) + assert b.get_chemical_symbols() == ['C', 'O'] + assert sym == {'hall': '-P 4 2', 'hm': 'P4/mmm', 'tables': 123} # Generated from COD entry 1507756 # (http://www.crystallography.net/cod/1507756.cif@87343) @@ -1600,65 +1588,65 @@ def test_kind_8(self): b, sym = ase_refine_cell(a) sym.pop('rotations') sym.pop('translations') - self.assertEqual(b.get_chemical_symbols(), ['Ba', 'Ti', 'O', 'O']) - self.assertEqual(sym, {'hall': 'P 4 -2', 'hm': 'P4mm', 'tables': 99}) + assert b.get_chemical_symbols() == ['Ba', 'Ti', 'O', 'O'] + assert sym == {'hall': 'P 4 -2', 'hm': 'P4mm', 'tables': 99} def test_get_formula(self): """ Tests the generation of formula """ - self.assertEqual(get_formula(['Ba', 'Ti'] + ['O'] * 3), 'BaO3Ti') - self.assertEqual(get_formula(['Ba', 'Ti', 'C'] + ['O'] * 3, separator=' '), 'C Ba O3 Ti') - self.assertEqual(get_formula(['H'] * 6 + ['C'] * 6), 'C6H6') - self.assertEqual(get_formula(['H'] * 6 + ['C'] * 6, mode='hill_compact'), 'CH') - self.assertEqual(get_formula((['Ba', 'Ti'] + ['O'] * 3) * 2 + \ + assert get_formula(['Ba', 'Ti'] + ['O'] * 3) == 'BaO3Ti' + assert get_formula(['Ba', 'Ti', 'C'] + ['O'] * 3, separator=' ') == 'C Ba O3 Ti' + assert get_formula(['H'] * 6 + ['C'] * 6) == 'C6H6' + assert get_formula(['H'] * 6 + ['C'] * 6, mode='hill_compact') == 'CH' + assert get_formula((['Ba', 'Ti'] + ['O'] * 3) * 2 + \ ['Ba'] + ['Ti'] * 2 + ['O'] * 3, - mode='group'), - '(BaTiO3)2BaTi2O3') - self.assertEqual(get_formula((['Ba', 'Ti'] + ['O'] * 3) * 2 + \ + mode='group') == \ + '(BaTiO3)2BaTi2O3' + assert get_formula((['Ba', 'Ti'] + ['O'] * 3) * 2 + \ ['Ba'] + ['Ti'] * 2 + ['O'] * 3, - mode='group', separator=' '), - '(Ba Ti O3)2 Ba Ti2 O3') - 
self.assertEqual(get_formula((['Ba', 'Ti'] + ['O'] * 3) * 2 + \ + mode='group', separator=' ') == \ + '(Ba Ti O3)2 Ba Ti2 O3' + assert get_formula((['Ba', 'Ti'] + ['O'] * 3) * 2 + \ ['Ba'] + ['Ti'] * 2 + ['O'] * 3, - mode='reduce'), - 'BaTiO3BaTiO3BaTi2O3') - self.assertEqual(get_formula((['Ba', 'Ti'] + ['O'] * 3) * 2 + \ + mode='reduce') == \ + 'BaTiO3BaTiO3BaTi2O3' + assert get_formula((['Ba', 'Ti'] + ['O'] * 3) * 2 + \ ['Ba'] + ['Ti'] * 2 + ['O'] * 3, - mode='reduce', separator=', '), - 'Ba, Ti, O3, Ba, Ti, O3, Ba, Ti2, O3') - self.assertEqual(get_formula((['Ba', 'Ti'] + ['O'] * 3) * 2, mode='count'), 'Ba2Ti2O6') - self.assertEqual(get_formula((['Ba', 'Ti'] + ['O'] * 3) * 2, mode='count_compact'), 'BaTiO3') + mode='reduce', separator=', ') == \ + 'Ba, Ti, O3, Ba, Ti, O3, Ba, Ti2, O3' + assert get_formula((['Ba', 'Ti'] + ['O'] * 3) * 2, mode='count') == 'Ba2Ti2O6' + assert get_formula((['Ba', 'Ti'] + ['O'] * 3) * 2, mode='count_compact') == 'BaTiO3' def test_get_formula_unknown(self): """ Tests the generation of formula, including unknown entry. 
""" - self.assertEqual(get_formula(['Ba', 'Ti'] + ['X'] * 3), 'BaTiX3') - self.assertEqual(get_formula(['Ba', 'Ti', 'C'] + ['X'] * 3, separator=' '), 'C Ba Ti X3') - self.assertEqual(get_formula(['X'] * 6 + ['C'] * 6), 'C6X6') - self.assertEqual(get_formula(['X'] * 6 + ['C'] * 6, mode='hill_compact'), 'CX') - self.assertEqual(get_formula((['Ba', 'Ti'] + ['X'] * 3) * 2 + \ + assert get_formula(['Ba', 'Ti'] + ['X'] * 3) == 'BaTiX3' + assert get_formula(['Ba', 'Ti', 'C'] + ['X'] * 3, separator=' ') == 'C Ba Ti X3' + assert get_formula(['X'] * 6 + ['C'] * 6) == 'C6X6' + assert get_formula(['X'] * 6 + ['C'] * 6, mode='hill_compact') == 'CX' + assert get_formula((['Ba', 'Ti'] + ['X'] * 3) * 2 + \ ['Ba'] + ['X'] * 2 + ['O'] * 3, - mode='group'), - '(BaTiX3)2BaX2O3') - self.assertEqual(get_formula((['Ba', 'Ti'] + ['X'] * 3) * 2 + \ + mode='group') == \ + '(BaTiX3)2BaX2O3' + assert get_formula((['Ba', 'Ti'] + ['X'] * 3) * 2 + \ ['Ba'] + ['X'] * 2 + ['O'] * 3, - mode='group', separator=' '), - '(Ba Ti X3)2 Ba X2 O3') - self.assertEqual(get_formula((['Ba', 'Ti'] + ['X'] * 3) * 2 + \ + mode='group', separator=' ') == \ + '(Ba Ti X3)2 Ba X2 O3' + assert get_formula((['Ba', 'Ti'] + ['X'] * 3) * 2 + \ ['Ba'] + ['Ti'] * 2 + ['X'] * 3, - mode='reduce'), - 'BaTiX3BaTiX3BaTi2X3') - self.assertEqual(get_formula((['Ba', 'Ti'] + ['X'] * 3) * 2 + \ + mode='reduce') == \ + 'BaTiX3BaTiX3BaTi2X3' + assert get_formula((['Ba', 'Ti'] + ['X'] * 3) * 2 + \ ['Ba'] + ['Ti'] * 2 + ['X'] * 3, - mode='reduce', separator=', '), - 'Ba, Ti, X3, Ba, Ti, X3, Ba, Ti2, X3') - self.assertEqual(get_formula((['Ba', 'Ti'] + ['O'] * 3) * 2, mode='count'), 'Ba2Ti2O6') - self.assertEqual(get_formula((['Ba', 'Ti'] + ['X'] * 3) * 2, mode='count_compact'), 'BaTiX3') + mode='reduce', separator=', ') == \ + 'Ba, Ti, X3, Ba, Ti, X3, Ba, Ti2, X3' + assert get_formula((['Ba', 'Ti'] + ['O'] * 3) * 2, mode='count') == 'Ba2Ti2O6' + assert get_formula((['Ba', 'Ti'] + ['X'] * 3) * 2, mode='count_compact') == 'BaTiX3' - 
@unittest.skipIf(not has_ase(), 'Unable to import ase') - @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') + @skip_ase + @skip_pycifrw @pytest.mark.requires_rmq def test_get_cif(self): """ @@ -1678,8 +1666,7 @@ def test_get_cif(self): for line in lines: if not re.search('^#', line): non_comments.append(line) - self.assertEqual( - simplify('\n'.join(non_comments)), + assert simplify('\n'.join(non_comments)) == \ simplify( """ data_0 @@ -1707,11 +1694,9 @@ def test_get_cif(self): _symmetry_space_group_name_H-M 'P 1' """ ) - ) def test_xyz_parser(self): """Test XYZ parser.""" - import numpy as np xyz_string1 = """ 3 @@ -1739,13 +1724,13 @@ def test_xyz_parser(self): s._parse_xyz(xyz_string) # pylint: disable=protected-access # Making sure that the periodic boundary condition are not True # because I cannot parse a cell! - self.assertTrue(not any(s.pbc)) + assert not any(s.pbc) # Making sure that the structure has sites, kinds and a cell - self.assertTrue(s.sites) - self.assertTrue(s.kinds) - self.assertTrue(s.cell) + assert s.sites + assert s.kinds + assert s.cell # The default cell is given in these cases: - self.assertEqual(s.cell, np.diag([1, 1, 1]).tolist()) + assert s.cell == np.diag([1, 1, 1]).tolist() # Testing a case where 1 xyz_string4 = """ @@ -1771,59 +1756,58 @@ def test_xyz_parser(self): # The above cases have to fail because the number of atoms is wrong for xyz_string in (xyz_string4, xyz_string5, xyz_string6): - with self.assertRaises(TypeError): + with pytest.raises(TypeError): StructureData()._parse_xyz(xyz_string) # pylint: disable=protected-access -class TestStructureDataLock(AiidaTestCase): - """Tests that the structure is locked after storage.""" +@pytest.mark.usefixtures('aiida_profile_clean') +def test_lock(): + """Test that the structure is locked after storage.""" + cell = ((1., 0., 0.), (0., 2., 0.), (0., 0., 3.)) + a = StructureData(cell=cell) - def test_lock(self): - """Start from a StructureData object, convert to raw and 
then back.""" - cell = ((1., 0., 0.), (0., 2., 0.), (0., 0., 3.)) - a = StructureData(cell=cell) + a.pbc = [False, True, True] - a.pbc = [False, True, True] + k = Kind(symbols='Ba', name='Ba') + s = Site(position=(0., 0., 0.), kind_name='Ba') + a.append_kind(k) + a.append_site(s) - k = Kind(symbols='Ba', name='Ba') - s = Site(position=(0., 0., 0.), kind_name='Ba') - a.append_kind(k) - a.append_site(s) + a.append_atom(symbols='Ti', position=[0., 0., 0.]) - a.append_atom(symbols='Ti', position=[0., 0., 0.]) + a.store() - a.store() - - k2 = Kind(symbols='Ba', name='Ba') - # Nothing should be changed after store() - with self.assertRaises(ModificationNotAllowed): - a.append_kind(k2) - with self.assertRaises(ModificationNotAllowed): - a.append_site(s) - with self.assertRaises(ModificationNotAllowed): - a.clear_sites() - with self.assertRaises(ModificationNotAllowed): - a.clear_kinds() - with self.assertRaises(ModificationNotAllowed): - a.cell = cell - with self.assertRaises(ModificationNotAllowed): - a.pbc = [True, True, True] - - _ = a.get_cell_volume() - _ = a.is_alloy - _ = a.has_vacancies - - b = a.clone() - # I check that clone returned an unstored copy and so can be altered - b.append_site(s) - b.clear_sites() - # I check that the original did not change - self.assertNotEqual(len(a.sites), 0) - b.cell = cell - b.pbc = [True, True, True] - - -class TestStructureDataReload(AiidaTestCase): + k2 = Kind(symbols='Ba', name='Ba') + # Nothing should be changed after store() + with pytest.raises(ModificationNotAllowed): + a.append_kind(k2) + with pytest.raises(ModificationNotAllowed): + a.append_site(s) + with pytest.raises(ModificationNotAllowed): + a.clear_sites() + with pytest.raises(ModificationNotAllowed): + a.clear_kinds() + with pytest.raises(ModificationNotAllowed): + a.cell = cell + with pytest.raises(ModificationNotAllowed): + a.pbc = [True, True, True] + + _ = a.get_cell_volume() + _ = a.is_alloy + _ = a.has_vacancies + + b = a.clone() + # I check that clone 
returned an unstored copy and so can be altered + b.append_site(s) + b.clear_sites() + # I check that the original did not change + assert len(a.sites) != 0 + b.cell = cell + b.pbc = [True, True, True] + + +@pytest.mark.usefixtures('aiida_profile_clean_class') +class TestStructureDataReload: """ Tests the creation of StructureData, converting it to a raw format and converting it back. @@ -1847,32 +1831,32 @@ def test_reload(self): for i in range(3): for j in range(3): - self.assertAlmostEqual(cell[i][j], b.cell[i][j]) + assert round(abs(cell[i][j] - b.cell[i][j]), 7) == 0 - self.assertEqual(b.pbc, (False, True, True)) - self.assertEqual(len(b.sites), 2) - self.assertEqual(b.kinds[0].symbols[0], 'Ba') - self.assertEqual(b.kinds[1].symbols[0], 'Ti') + assert b.pbc == (False, True, True) + assert len(b.sites) == 2 + assert b.kinds[0].symbols[0] == 'Ba' + assert b.kinds[1].symbols[0] == 'Ti' for i in range(3): - self.assertAlmostEqual(b.sites[0].position[i], 0.) + assert round(abs(b.sites[0].position[i] - 0.), 7) == 0 for i in range(3): - self.assertAlmostEqual(b.sites[1].position[i], 1.) + assert round(abs(b.sites[1].position[i] - 1.), 7) == 0 # Fully reload from UUID b = load_node(a.uuid, sub_classes=(StructureData,)) for i in range(3): for j in range(3): - self.assertAlmostEqual(cell[i][j], b.cell[i][j]) + assert round(abs(cell[i][j] - b.cell[i][j]), 7) == 0 - self.assertEqual(b.pbc, (False, True, True)) - self.assertEqual(len(b.sites), 2) - self.assertEqual(b.kinds[0].symbols[0], 'Ba') - self.assertEqual(b.kinds[1].symbols[0], 'Ti') + assert b.pbc == (False, True, True) + assert len(b.sites) == 2 + assert b.kinds[0].symbols[0] == 'Ba' + assert b.kinds[1].symbols[0] == 'Ti' for i in range(3): - self.assertAlmostEqual(b.sites[0].position[i], 0.) + assert round(abs(b.sites[0].position[i] - 0.), 7) == 0 for i in range(3): - self.assertAlmostEqual(b.sites[1].position[i], 1.) 
+ assert round(abs(b.sites[1].position[i] - 1.), 7) == 0 def test_clone(self): """ @@ -1890,17 +1874,17 @@ def test_clone(self): for i in range(3): for j in range(3): - self.assertAlmostEqual(cell[i][j], b.cell[i][j]) + assert round(abs(cell[i][j] - b.cell[i][j]), 7) == 0 - self.assertEqual(b.pbc, (False, True, True)) - self.assertEqual(len(b.kinds), 2) - self.assertEqual(len(b.sites), 2) - self.assertEqual(b.kinds[0].symbols[0], 'Ba') - self.assertEqual(b.kinds[1].symbols[0], 'Ti') + assert b.pbc == (False, True, True) + assert len(b.kinds) == 2 + assert len(b.sites) == 2 + assert b.kinds[0].symbols[0] == 'Ba' + assert b.kinds[1].symbols[0] == 'Ti' for i in range(3): - self.assertAlmostEqual(b.sites[0].position[i], 0.) + assert round(abs(b.sites[0].position[i] - 0.), 7) == 0 for i in range(3): - self.assertAlmostEqual(b.sites[1].position[i], 1.) + assert round(abs(b.sites[1].position[i] - 1.), 7) == 0 a.store() @@ -1908,23 +1892,24 @@ def test_clone(self): c = a.clone() for i in range(3): for j in range(3): - self.assertAlmostEqual(cell[i][j], c.cell[i][j]) + assert round(abs(cell[i][j] - c.cell[i][j]), 7) == 0 - self.assertEqual(c.pbc, (False, True, True)) - self.assertEqual(len(c.kinds), 2) - self.assertEqual(len(c.sites), 2) - self.assertEqual(c.kinds[0].symbols[0], 'Ba') - self.assertEqual(c.kinds[1].symbols[0], 'Ti') + assert c.pbc == (False, True, True) + assert len(c.kinds) == 2 + assert len(c.sites) == 2 + assert c.kinds[0].symbols[0] == 'Ba' + assert c.kinds[1].symbols[0] == 'Ti' for i in range(3): - self.assertAlmostEqual(c.sites[0].position[i], 0.) + assert round(abs(c.sites[0].position[i] - 0.), 7) == 0 for i in range(3): - self.assertAlmostEqual(c.sites[1].position[i], 1.) 
+ assert round(abs(c.sites[1].position[i] - 1.), 7) == 0 -class TestStructureDataFromAse(AiidaTestCase): +@pytest.mark.usefixtures('aiida_profile_clean_class') +class TestStructureDataFromAse: """Tests the creation of Sites from/to a ASE object.""" - @unittest.skipIf(not has_ase(), 'Unable to import ase') + @skip_ase def test_ase(self): """Tests roundtrip ASE -> StructureData -> ASE.""" import ase @@ -1939,17 +1924,17 @@ def test_ase(self): b = StructureData(ase=a) c = b.get_ase() - self.assertEqual(a[0].symbol, c[0].symbol) - self.assertEqual(a[1].symbol, c[1].symbol) + assert a[0].symbol == c[0].symbol + assert a[1].symbol == c[1].symbol for i in range(3): - self.assertAlmostEqual(a[0].position[i], c[0].position[i]) + assert round(abs(a[0].position[i] - c[0].position[i]), 7) == 0 for i in range(3): for j in range(3): - self.assertAlmostEqual(a.cell[i][j], c.cell[i][j]) + assert round(abs(a.cell[i][j] - c.cell[i][j]), 7) == 0 - self.assertAlmostEqual(c[1].mass, 110.2) + assert round(abs(c[1].mass - 110.2), 7) == 0 - @unittest.skipIf(not has_ase(), 'Unable to import ase') + @skip_ase def test_conversion_of_types_1(self): """ Tests roundtrip ASE -> StructureData -> ASE, with tags @@ -1971,14 +1956,14 @@ def test_conversion_of_types_1(self): a.set_tags((0, 1, 2, 3, 4, 5, 6, 7)) b = StructureData(ase=a) - self.assertEqual([k.name for k in b.kinds], ['Si', 'Si1', 'Si2', 'Si3', 'Ge4', 'Ge5', 'Ge6', 'Ge7']) + assert [k.name for k in b.kinds] == ['Si', 'Si1', 'Si2', 'Si3', 'Ge4', 'Ge5', 'Ge6', 'Ge7'] c = b.get_ase() a_tags = list(a.get_tags()) c_tags = list(c.get_tags()) - self.assertEqual(a_tags, c_tags) + assert a_tags == c_tags - @unittest.skipIf(not has_ase(), 'Unable to import ase') + @skip_ase def test_conversion_of_types_2(self): """ Tests roundtrip ASE -> StructureData -> ASE, with tags, and @@ -2002,7 +1987,7 @@ def test_conversion_of_types_2(self): # This will give funny names to the kinds, because I am using # both tags and different properties (mass). 
I just check to have # 4 kinds - self.assertEqual(len(b.kinds), 4) + assert len(b.kinds) == 4 # Do I get the same tags after one full iteration back and forth? c = b.get_ase() @@ -2010,9 +1995,9 @@ def test_conversion_of_types_2(self): e = d.get_ase() c_tags = list(c.get_tags()) e_tags = list(e.get_tags()) - self.assertEqual(c_tags, e_tags) + assert c_tags == e_tags - @unittest.skipIf(not has_ase(), 'Unable to import ase') + @skip_ase def test_conversion_of_types_3(self): """ Tests StructureData -> ASE, with all sorts of kind names @@ -2034,13 +2019,13 @@ def test_conversion_of_types_3(self): # Just to be sure that the species were saved with the correct name # in the first place - self.assertEqual([k.name for k in a.kinds], ['Ba', 'Ba1', 'Cu', 'Cu2', 'Cu_my', 'a_name', 'Fe', 'cu1']) + assert [k.name for k in a.kinds] == ['Ba', 'Ba1', 'Cu', 'Cu2', 'Cu_my', 'a_name', 'Fe', 'cu1'] b = a.get_ase() - self.assertEqual(b.get_chemical_symbols(), ['Ba', 'Ba', 'Cu', 'Cu', 'Cu', 'Cu', 'Cu', 'Cu']) - self.assertEqual(list(b.get_tags()), [0, 1, 0, 2, 3, 4, 5, 6]) + assert b.get_chemical_symbols() == ['Ba', 'Ba', 'Cu', 'Cu', 'Cu', 'Cu', 'Cu', 'Cu'] + assert list(b.get_tags()) == [0, 1, 0, 2, 3, 4, 5, 6] - @unittest.skipIf(not has_ase(), 'Unable to import ase') + @skip_ase def test_conversion_of_types_4(self): """ Tests ASE -> StructureData -> ASE, in particular conversion tags / kind names @@ -2054,14 +2039,14 @@ def test_conversion_of_types_4(self): atoms.set_cell([1, 1, 1]) s = StructureData(ase=atoms) kindnames = {k.name for k in s.kinds} - self.assertEqual(kindnames, set(['Fe', 'Fe1', 'Fe4'])) + assert kindnames == set(['Fe', 'Fe1', 'Fe4']) # check roundtrip ASE -> StructureData -> ASE atoms2 = s.get_ase() - self.assertEqual(list(atoms2.get_tags()), list(atoms.get_tags())) - self.assertEqual(list(atoms2.get_chemical_symbols()), list(atoms.get_chemical_symbols())) - self.assertEqual(atoms2.get_chemical_formula(), 'Fe5') + assert list(atoms2.get_tags()) == 
list(atoms.get_tags()) + assert list(atoms2.get_chemical_symbols()) == list(atoms.get_chemical_symbols()) + assert atoms2.get_chemical_formula() == 'Fe5' - @unittest.skipIf(not has_ase(), 'Unable to import ase') + @skip_ase def test_conversion_of_types_5(self): """ Tests ASE -> StructureData -> ASE, in particular conversion tags / kind names @@ -2076,14 +2061,14 @@ def test_conversion_of_types_5(self): atoms.set_cell([1, 1, 1]) s = StructureData(ase=atoms) kindnames = {k.name for k in s.kinds} - self.assertEqual(kindnames, set(['Fe', 'Fe1', 'Fe4'])) + assert kindnames == set(['Fe', 'Fe1', 'Fe4']) # check roundtrip ASE -> StructureData -> ASE atoms2 = s.get_ase() - self.assertEqual(list(atoms2.get_tags()), list(atoms.get_tags())) - self.assertEqual(list(atoms2.get_chemical_symbols()), list(atoms.get_chemical_symbols())) - self.assertEqual(atoms2.get_chemical_formula(), 'Fe5') + assert list(atoms2.get_tags()) == list(atoms.get_tags()) + assert list(atoms2.get_chemical_symbols()) == list(atoms.get_chemical_symbols()) + assert atoms2.get_chemical_formula() == 'Fe5' - @unittest.skipIf(not has_ase(), 'Unable to import ase') + @skip_ase def test_conversion_of_types_6(self): """ Tests roundtrip StructureData -> ASE -> StructureData, with tags/kind names @@ -2095,22 +2080,23 @@ def test_conversion_of_types_6(self): a.append_atom(position=(1, 3, 1), symbols='Cl', name='Cl') b = a.get_ase() - self.assertEqual(b.get_chemical_symbols(), ['Ni', 'Ni', 'Cl', 'Cl']) - self.assertEqual(list(b.get_tags()), [1, 2, 0, 0]) + assert b.get_chemical_symbols() == ['Ni', 'Ni', 'Cl', 'Cl'] + assert list(b.get_tags()) == [1, 2, 0, 0] c = StructureData(ase=b) - self.assertEqual(c.get_site_kindnames(), ['Ni1', 'Ni2', 'Cl', 'Cl']) - self.assertEqual([k.symbol for k in c.kinds], ['Ni', 'Ni', 'Cl']) - self.assertEqual([s.position for s in c.sites], [(0., 0., 0.), (2., 2., 2.), (1., 0., 1.), (1., 3., 1.)]) + assert c.get_site_kindnames() == ['Ni1', 'Ni2', 'Cl', 'Cl'] + assert [k.symbol for k in 
c.kinds] == ['Ni', 'Ni', 'Cl'] + assert [s.position for s in c.sites] == [(0., 0., 0.), (2., 2., 2.), (1., 0., 1.), (1., 3., 1.)] -class TestStructureDataFromPymatgen(AiidaTestCase): +@pytest.mark.usefixtures('aiida_profile_clean_class') +class TestStructureDataFromPymatgen: """ Tests the creation of StructureData from a pymatgen Structure and Molecule objects. """ - @unittest.skipIf(not has_pymatgen(), 'Unable to import pymatgen') + @skip_pymatgen def test_1(self): """ Tests roundtrip pymatgen -> StructureData -> pymatgen @@ -2154,17 +2140,17 @@ def test_1(self): structs_to_test = [StructureData(pymatgen=pymatgen_struct), StructureData(pymatgen_structure=pymatgen_struct)] for struct in structs_to_test: - self.assertEqual(struct.get_site_kindnames(), ['Bi', 'Bi', 'SeTe', 'SeTe', 'SeTe']) + assert struct.get_site_kindnames() == ['Bi', 'Bi', 'SeTe', 'SeTe', 'SeTe'] # Pymatgen's Composition does not guarantee any particular ordering of the kinds, # see the definition of its internal datatype at # pymatgen/core/composition.py#L135 (d4fe64c18a52949a4e22bfcf7b45de5b87242c51) - self.assertEqual([sorted(x.symbols) for x in struct.kinds], [[ + assert [sorted(x.symbols) for x in struct.kinds] == [[ 'Bi', - ], ['Se', 'Te']]) - self.assertEqual([sorted(x.weights) for x in struct.kinds], [[ + ], ['Se', 'Te']] + assert [sorted(x.weights) for x in struct.kinds] == [[ 1.0, - ], [0.33333, 0.66667]]) + ], [0.33333, 0.66667]] struct = StructureData(pymatgen_structure=pymatgen_struct) @@ -2196,7 +2182,7 @@ def recursively_compare_values(left, right): recursively_compare_values(dict1, dict2) - @unittest.skipIf(not has_pymatgen(), 'Unable to import pymatgen') + @skip_pymatgen def test_2(self): """ Tests xyz -> pymatgen -> StructureData @@ -2219,15 +2205,15 @@ def test_2(self): pymatgen_mol = pymatgen_xyz.molecule for struct in [StructureData(pymatgen=pymatgen_mol), StructureData(pymatgen_molecule=pymatgen_mol)]: - self.assertEqual(struct.get_site_kindnames(), ['H', 'H', 'H', 'H', 'C']) 
- self.assertEqual(struct.pbc, (False, False, False)) - self.assertEqual([round(x, 2) for x in list(struct.sites[0].position)], [5.77, 5.89, 6.81]) - self.assertEqual([round(x, 2) for x in list(struct.sites[1].position)], [6.8, 5.89, 5.36]) - self.assertEqual([round(x, 2) for x in list(struct.sites[2].position)], [5.26, 5.0, 5.36]) - self.assertEqual([round(x, 2) for x in list(struct.sites[3].position)], [5.26, 6.78, 5.36]) - self.assertEqual([round(x, 2) for x in list(struct.sites[4].position)], [5.77, 5.89, 5.73]) - - @unittest.skipIf(not has_pymatgen(), 'Unable to import pymatgen') + assert struct.get_site_kindnames() == ['H', 'H', 'H', 'H', 'C'] + assert struct.pbc == (False, False, False) + assert [round(x, 2) for x in list(struct.sites[0].position)] == [5.77, 5.89, 6.81] + assert [round(x, 2) for x in list(struct.sites[1].position)] == [6.8, 5.89, 5.36] + assert [round(x, 2) for x in list(struct.sites[2].position)] == [5.26, 5.0, 5.36] + assert [round(x, 2) for x in list(struct.sites[3].position)] == [5.26, 6.78, 5.36] + assert [round(x, 2) for x in list(struct.sites[4].position)] == [5.77, 5.89, 5.73] + + @skip_pymatgen def test_partial_occ_and_spin(self): """ Tests pymatgen -> StructureData, with partial occupancies and spins. 
@@ -2247,7 +2233,7 @@ def test_partial_occ_and_spin(self): lattice=[[4, 0, 0], [0, 4, 0], [0, 0, 4]], species=[FeMn1, FeMn2], coords=[[0, 0, 0], [0.5, 0.5, 0.5]] ) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): StructureData(pymatgen=a) # same, with vacancies @@ -2257,10 +2243,10 @@ def test_partial_occ_and_spin(self): lattice=[[4, 0, 0], [0, 4, 0], [0, 0, 4]], species=[Fe1, Fe2], coords=[[0, 0, 0], [0.5, 0.5, 0.5]] ) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): StructureData(pymatgen=a) - @unittest.skipIf(not has_pymatgen(), 'Unable to import pymatgen') + @skip_pymatgen @staticmethod def test_multiple_kinds_partial_occupancies(): """Tests that a structure with multiple sites with the same element but different @@ -2277,7 +2263,7 @@ def test_multiple_kinds_partial_occupancies(): StructureData(pymatgen=a) - @unittest.skipIf(not has_pymatgen(), 'Unable to import pymatgen') + @skip_pymatgen @staticmethod def test_multiple_kinds_alloy(): """ @@ -2299,11 +2285,12 @@ def test_multiple_kinds_alloy(): StructureData(pymatgen=a) -class TestPymatgenFromStructureData(AiidaTestCase): +@pytest.mark.usefixtures('aiida_profile_clean_class') +class TestPymatgenFromStructureData: """Tests the creation of pymatgen Structure and Molecule objects from StructureData.""" - @unittest.skipIf(not has_pymatgen(), 'Unable to import pymatgen') + @skip_pymatgen def test_1(self): """Tests the check of periodic boundary conditions.""" struct = StructureData() @@ -2312,11 +2299,11 @@ def test_1(self): struct.get_pymatgen_structure() struct.pbc = [True, True, False] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): struct.get_pymatgen_structure() - @unittest.skipIf(not has_ase(), 'Unable to import ase') - @unittest.skipIf(not has_pymatgen(), 'Unable to import pymatgen') + @skip_ase + @skip_pymatgen def test_2(self): """Tests ASE -> StructureData -> pymatgen.""" import ase @@ -2337,10 +2324,10 @@ def test_2(self): for i, _ in 
enumerate(coord_array): coord_array[i] = [round(x, 2) for x in coord_array[i]] - self.assertEqual(coord_array, [[0.0, 0.0, 0.0], [0.1, 0.1, 0.1], [0.2, 0.2, 0.2], [0.3, 0.3, 0.3]]) + assert coord_array == [[0.0, 0.0, 0.0], [0.1, 0.1, 0.1], [0.2, 0.2, 0.2], [0.3, 0.3, 0.3]] - @unittest.skipIf(not has_ase(), 'Unable to import ase') - @unittest.skipIf(not has_pymatgen(), 'Unable to import pymatgen') + @skip_ase + @skip_pymatgen def test_3(self): """ Tests the conversion of StructureData to pymatgen's Molecule @@ -2360,10 +2347,10 @@ def test_3(self): p_mol = a_struct.get_pymatgen_molecule() p_mol_dict = p_mol.as_dict() - self.assertEqual([x['xyz'] for x in p_mol_dict['sites']], - [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]]) + assert [x['xyz'] for x in p_mol_dict['sites']] == \ + [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]] - @unittest.skipIf(not has_pymatgen(), 'Unable to import pymatgen') + @skip_pymatgen def test_roundtrip(self): """ Tests roundtrip StructureData -> pymatgen -> StructureData @@ -2381,12 +2368,12 @@ def test_roundtrip(self): b = a.get_pymatgen() c = StructureData(pymatgen=b) - self.assertEqual(c.get_site_kindnames(), ['Cl', 'Cl', 'Cl', 'Cl', 'Na', 'Na', 'Na', 'Na']) - self.assertEqual([k.symbol for k in c.kinds], ['Cl', 'Na']) - self.assertEqual([s.position for s in c.sites], [(0., 0., 0.), (2.8, 0, 2.8), (0, 2.8, 2.8), (2.8, 2.8, 0), - (2.8, 2.8, 2.8), (2.8, 0, 0), (0, 2.8, 0), (0, 0, 2.8)]) + assert c.get_site_kindnames() == ['Cl', 'Cl', 'Cl', 'Cl', 'Na', 'Na', 'Na', 'Na'] + assert [k.symbol for k in c.kinds] == ['Cl', 'Na'] + assert [s.position for s in c.sites] == [(0., 0., 0.), (2.8, 0, 2.8), (0, 2.8, 2.8), (2.8, 2.8, 0), + (2.8, 2.8, 2.8), (2.8, 0, 0), (0, 2.8, 0), (0, 0, 2.8)] - @unittest.skipIf(not has_pymatgen(), 'Unable to import pymatgen') + @skip_pymatgen def test_roundtrip_kindnames(self): """ Tests roundtrip StructureData -> pymatgen -> StructureData @@ -2403,16 +2390,16 @@ def 
test_roundtrip_kindnames(self): a.append_atom(position=(0, 0, 2.8), symbols='Na', name='Na4') b = a.get_pymatgen() - self.assertEqual([site.properties['kind_name'] for site in b.sites], - ['Cl', 'Cl10', 'Cla', 'cl_x', 'Na1', 'Na2', 'Na_Na', 'Na4']) + assert [site.properties['kind_name'] for site in b.sites] == \ + ['Cl', 'Cl10', 'Cla', 'cl_x', 'Na1', 'Na2', 'Na_Na', 'Na4'] c = StructureData(pymatgen=b) - self.assertEqual(c.get_site_kindnames(), ['Cl', 'Cl10', 'Cla', 'cl_x', 'Na1', 'Na2', 'Na_Na', 'Na4']) - self.assertEqual(c.get_symbols_set(), set(['Cl', 'Na'])) - self.assertEqual([s.position for s in c.sites], [(0., 0., 0.), (2.8, 0, 2.8), (0, 2.8, 2.8), (2.8, 2.8, 0), - (2.8, 2.8, 2.8), (2.8, 0, 0), (0, 2.8, 0), (0, 0, 2.8)]) + assert c.get_site_kindnames() == ['Cl', 'Cl10', 'Cla', 'cl_x', 'Na1', 'Na2', 'Na_Na', 'Na4'] + assert c.get_symbols_set() == set(['Cl', 'Na']) + assert [s.position for s in c.sites] == [(0., 0., 0.), (2.8, 0, 2.8), (0, 2.8, 2.8), (2.8, 2.8, 0), + (2.8, 2.8, 2.8), (2.8, 0, 0), (0, 2.8, 0), (0, 0, 2.8)] - @unittest.skipIf(not has_pymatgen(), 'Unable to import pymatgen') + @skip_pymatgen def test_roundtrip_spins(self): """ Tests roundtrip StructureData -> pymatgen -> StructureData @@ -2430,15 +2417,15 @@ def test_roundtrip_spins(self): b = a.get_pymatgen(add_spin=True) # check the spins - self.assertEqual([s.as_dict()['properties']['spin'] for s in b.species], [-1, -1, -1, -1, 1, 1, 1, 1]) + assert [s.as_dict()['properties']['spin'] for s in b.species] == [-1, -1, -1, -1, 1, 1, 1, 1] # back to StructureData c = StructureData(pymatgen=b) - self.assertEqual(c.get_site_kindnames(), ['Mn1', 'Mn1', 'Mn1', 'Mn1', 'Mn2', 'Mn2', 'Mn2', 'Mn2']) - self.assertEqual([k.symbol for k in c.kinds], ['Mn', 'Mn']) - self.assertEqual([s.position for s in c.sites], [(0., 0., 0.), (2.8, 0, 2.8), (0, 2.8, 2.8), (2.8, 2.8, 0), - (2.8, 2.8, 2.8), (2.8, 0, 0), (0, 2.8, 0), (0, 0, 2.8)]) + assert c.get_site_kindnames() == ['Mn1', 'Mn1', 'Mn1', 'Mn1', 'Mn2', 'Mn2', 
'Mn2', 'Mn2'] + assert [k.symbol for k in c.kinds] == ['Mn', 'Mn'] + assert [s.position for s in c.sites] == [(0., 0., 0.), (2.8, 0, 2.8), (0, 2.8, 2.8), (2.8, 2.8, 0), + (2.8, 2.8, 2.8), (2.8, 0, 0), (0, 2.8, 0), (0, 0, 2.8)] - @unittest.skipIf(not has_pymatgen(), 'Unable to import pymatgen') + @skip_pymatgen def test_roundtrip_partial_occ(self): """ Tests roundtrip StructureData -> pymatgen -> StructureData @@ -2461,13 +2448,13 @@ def test_roundtrip_partial_occ(self): a.append_atom(position=(2., 1., 9.5), symbols='N') # a few checks on the structure kinds and symbols - self.assertEqual(a.get_symbols_set(), set(['Mn', 'Si', 'N'])) - self.assertEqual(a.get_site_kindnames(), ['Mn', 'Mn', 'Mn', 'Mn', 'MnX', 'MnX', 'Si', 'Si', 'N', 'N', 'N', 'N']) - self.assertEqual(a.get_formula(), 'Mn4N4Si2{Mn0.80X0.20}2') + assert a.get_symbols_set() == set(['Mn', 'Si', 'N']) + assert a.get_site_kindnames() == ['Mn', 'Mn', 'Mn', 'Mn', 'MnX', 'MnX', 'Si', 'Si', 'N', 'N', 'N', 'N'] + assert a.get_formula() == 'Mn4N4Si2{Mn0.80X0.20}2' b = a.get_pymatgen() # check the partial occupancies - self.assertEqual([s.as_dict() for s in b.species_and_occu], [{ + assert [s.as_dict() for s in b.species_and_occu] == [{ 'Mn': 1.0 }, { 'Mn': 1.0 @@ -2491,20 +2478,20 @@ def test_roundtrip_partial_occ(self): 'N': 1.0 }, { 'N': 1.0 - }]) + }] # back to StructureData c = StructureData(pymatgen=b) - self.assertEqual(c.cell, [[4., 0.0, 0.0], [-2., 3.5, 0.0], [0.0, 0.0, 16.]]) - self.assertEqual(c.get_symbols_set(), set(['Mn', 'Si', 'N'])) - self.assertEqual(c.get_site_kindnames(), ['Mn', 'Mn', 'Mn', 'Mn', 'MnX', 'MnX', 'Si', 'Si', 'N', 'N', 'N', 'N']) - self.assertEqual(c.get_formula(), 'Mn4N4Si2{Mn0.80X0.20}2') + assert c.cell == [[4., 0.0, 0.0], [-2., 3.5, 0.0], [0.0, 0.0, 16.]] + assert c.get_symbols_set() == set(['Mn', 'Si', 'N']) + assert c.get_site_kindnames() == ['Mn', 'Mn', 'Mn', 'Mn', 'MnX', 'MnX', 'Si', 'Si', 'N', 'N', 'N', 'N'] + assert c.get_formula() == 'Mn4N4Si2{Mn0.80X0.20}2' 
testing.assert_allclose([s.position for s in c.sites], [(0.0, 0.0, 13.5), (0.0, 0.0, 2.6), (0.0, 0.0, 5.5), (0.0, 0.0, 11.), (2., 1., 12.), (0.0, 2.2, 4.), (0.0, 2.2, 12.), (2., 1., 4.), (2., 1., 15.), (0.0, 2.2, 1.5), (0.0, 2.2, 7.), (2., 1., 9.5)]) - @unittest.skipIf(not has_pymatgen(), 'Unable to import pymatgen') + @skip_pymatgen def test_partial_occ_and_spin(self): """Tests StructureData -> pymatgen, with partial occupancies and spins. This should raise a ValueError.""" @@ -2513,11 +2500,11 @@ def test_partial_occ_and_spin(self): a.append_atom(position=(2, 2, 2), symbols=('Fe', 'Al'), weights=(0.8, 0.2), name='FeAl2') # a few checks on the structure kinds and symbols - self.assertEqual(a.get_symbols_set(), set(['Fe', 'Al'])) - self.assertEqual(a.get_site_kindnames(), ['FeAl1', 'FeAl2']) - self.assertEqual(a.get_formula(), '{Al0.20Fe0.80}2') + assert a.get_symbols_set() == set(['Fe', 'Al']) + assert a.get_site_kindnames() == ['FeAl1', 'FeAl2'] + assert a.get_formula() == '{Al0.20Fe0.80}2' - with self.assertRaises(ValueError): + with pytest.raises(ValueError): a.get_pymatgen(add_spin=True) # same, with vacancies @@ -2526,15 +2513,16 @@ def test_partial_occ_and_spin(self): a.append_atom(position=(2, 2, 2), symbols='Fe', weights=0.8, name='FeX2') # a few checks on the structure kinds and symbols - self.assertEqual(a.get_symbols_set(), set(['Fe'])) - self.assertEqual(a.get_site_kindnames(), ['FeX1', 'FeX2']) - self.assertEqual(a.get_formula(), '{Fe0.80X0.20}2') + assert a.get_symbols_set() == set(['Fe']) + assert a.get_site_kindnames() == ['FeX1', 'FeX2'] + assert a.get_formula() == '{Fe0.80X0.20}2' - with self.assertRaises(ValueError): + with pytest.raises(ValueError): a.get_pymatgen(add_spin=True) -class TestArrayData(AiidaTestCase): +@pytest.mark.usefixtures('aiida_profile_clean_class') +class TestArrayData: """Tests the ArrayData objects.""" def test_creation(self): @@ -2543,134 +2531,131 @@ def test_creation(self): array shapes. 
""" # pylint: disable=too-many-statements - import numpy # Create a node with two arrays n = ArrayData() - first = numpy.random.rand(2, 3, 4) + first = np.random.rand(2, 3, 4) n.set_array('first', first) - second = numpy.arange(10) + second = np.arange(10) n.set_array('second', second) - third = numpy.random.rand(6, 6) + third = np.random.rand(6, 6) n.set_array('third', third) # Check if the arrays are there - self.assertEqual(set(['first', 'second', 'third']), set(n.get_arraynames())) - self.assertAlmostEqual(abs(first - n.get_array('first')).max(), 0.) - self.assertAlmostEqual(abs(second - n.get_array('second')).max(), 0.) - self.assertAlmostEqual(abs(third - n.get_array('third')).max(), 0.) - self.assertEqual(first.shape, n.get_shape('first')) - self.assertEqual(second.shape, n.get_shape('second')) - self.assertEqual(third.shape, n.get_shape('third')) - - with self.assertRaises(KeyError): + assert set(['first', 'second', 'third']) == set(n.get_arraynames()) + assert round(abs(abs(first - n.get_array('first')).max() - 0.), 7) == 0 + assert round(abs(abs(second - n.get_array('second')).max() - 0.), 7) == 0 + assert round(abs(abs(third - n.get_array('third')).max() - 0.), 7) == 0 + assert first.shape == n.get_shape('first') + assert second.shape == n.get_shape('second') + assert third.shape == n.get_shape('third') + + with pytest.raises(KeyError): n.get_array('nonexistent_array') # Delete an array, and try to delete a non-existing one n.delete_array('third') - with self.assertRaises(KeyError): + with pytest.raises(KeyError): n.delete_array('nonexistent_array') # Overwrite an array - first = numpy.random.rand(4, 5, 6) + first = np.random.rand(4, 5, 6) n.set_array('first', first) # Check if the arrays are there, and if I am getting the new one - self.assertEqual(set(['first', 'second']), set(n.get_arraynames())) - self.assertAlmostEqual(abs(first - n.get_array('first')).max(), 0.) - self.assertAlmostEqual(abs(second - n.get_array('second')).max(), 0.) 
- self.assertEqual(first.shape, n.get_shape('first')) - self.assertEqual(second.shape, n.get_shape('second')) + assert set(['first', 'second']) == set(n.get_arraynames()) + assert round(abs(abs(first - n.get_array('first')).max() - 0.), 7) == 0 + assert round(abs(abs(second - n.get_array('second')).max() - 0.), 7) == 0 + assert first.shape == n.get_shape('first') + assert second.shape == n.get_shape('second') n.store() # Same checks, after storing - self.assertEqual(set(['first', 'second']), set(n.get_arraynames())) - self.assertAlmostEqual(abs(first - n.get_array('first')).max(), 0.) - self.assertAlmostEqual(abs(second - n.get_array('second')).max(), 0.) - self.assertEqual(first.shape, n.get_shape('first')) - self.assertEqual(second.shape, n.get_shape('second')) + assert set(['first', 'second']) == set(n.get_arraynames()) + assert round(abs(abs(first - n.get_array('first')).max() - 0.), 7) == 0 + assert round(abs(abs(second - n.get_array('second')).max() - 0.), 7) == 0 + assert first.shape == n.get_shape('first') + assert second.shape == n.get_shape('second') # Same checks, again (this is checking the caching features) - self.assertEqual(set(['first', 'second']), set(n.get_arraynames())) - self.assertAlmostEqual(abs(first - n.get_array('first')).max(), 0.) - self.assertAlmostEqual(abs(second - n.get_array('second')).max(), 0.) - self.assertEqual(first.shape, n.get_shape('first')) - self.assertEqual(second.shape, n.get_shape('second')) + assert set(['first', 'second']) == set(n.get_arraynames()) + assert round(abs(abs(first - n.get_array('first')).max() - 0.), 7) == 0 + assert round(abs(abs(second - n.get_array('second')).max() - 0.), 7) == 0 + assert first.shape == n.get_shape('first') + assert second.shape == n.get_shape('second') # Same checks, after reloading n2 = load_node(uuid=n.uuid) - self.assertEqual(set(['first', 'second']), set(n2.get_arraynames())) - self.assertAlmostEqual(abs(first - n2.get_array('first')).max(), 0.) 
- self.assertAlmostEqual(abs(second - n2.get_array('second')).max(), 0.) - self.assertEqual(first.shape, n2.get_shape('first')) - self.assertEqual(second.shape, n2.get_shape('second')) + assert set(['first', 'second']) == set(n2.get_arraynames()) + assert round(abs(abs(first - n2.get_array('first')).max() - 0.), 7) == 0 + assert round(abs(abs(second - n2.get_array('second')).max() - 0.), 7) == 0 + assert first.shape == n2.get_shape('first') + assert second.shape == n2.get_shape('second') # Same checks, after reloading with UUID n2 = load_node(n.uuid, sub_classes=(ArrayData,)) - self.assertEqual(set(['first', 'second']), set(n2.get_arraynames())) - self.assertAlmostEqual(abs(first - n2.get_array('first')).max(), 0.) - self.assertAlmostEqual(abs(second - n2.get_array('second')).max(), 0.) - self.assertEqual(first.shape, n2.get_shape('first')) - self.assertEqual(second.shape, n2.get_shape('second')) + assert set(['first', 'second']) == set(n2.get_arraynames()) + assert round(abs(abs(first - n2.get_array('first')).max() - 0.), 7) == 0 + assert round(abs(abs(second - n2.get_array('second')).max() - 0.), 7) == 0 + assert first.shape == n2.get_shape('first') + assert second.shape == n2.get_shape('second') # Check that I cannot modify the node after storing - with self.assertRaises(ModificationNotAllowed): + with pytest.raises(ModificationNotAllowed): n.delete_array('first') - with self.assertRaises(ModificationNotAllowed): + with pytest.raises(ModificationNotAllowed): n.set_array('second', first) # Again same checks, to verify that the attempts to delete/overwrite # arrays did not damage the node content - self.assertEqual(set(['first', 'second']), set(n.get_arraynames())) - self.assertAlmostEqual(abs(first - n.get_array('first')).max(), 0.) - self.assertAlmostEqual(abs(second - n.get_array('second')).max(), 0.) 
- self.assertEqual(first.shape, n.get_shape('first')) - self.assertEqual(second.shape, n.get_shape('second')) + assert set(['first', 'second']) == set(n.get_arraynames()) + assert round(abs(abs(first - n.get_array('first')).max() - 0.), 7) == 0 + assert round(abs(abs(second - n.get_array('second')).max() - 0.), 7) == 0 + assert first.shape == n.get_shape('first') + assert second.shape == n.get_shape('second') def test_iteration(self): """ Check the functionality of the get_iterarrays() iterator """ - import numpy - # Create a node with two arrays n = ArrayData() - first = numpy.random.rand(2, 3, 4) + first = np.random.rand(2, 3, 4) n.set_array('first', first) - second = numpy.arange(10) + second = np.arange(10) n.set_array('second', second) - third = numpy.random.rand(6, 6) + third = np.random.rand(6, 6) n.set_array('third', third) for name, array in n.get_iterarrays(): if name == 'first': - self.assertAlmostEqual(abs(first - array).max(), 0.) + assert round(abs(abs(first - array).max() - 0.), 7) == 0 if name == 'second': - self.assertAlmostEqual(abs(second - array).max(), 0.) + assert round(abs(abs(second - array).max() - 0.), 7) == 0 if name == 'third': - self.assertAlmostEqual(abs(third - array).max(), 0.) 
+ assert round(abs(abs(third - array).max() - 0.), 7) == 0 -class TestTrajectoryData(AiidaTestCase): +@pytest.mark.usefixtures('aiida_profile_clean_class') +class TestTrajectoryData: """Tests the TrajectoryData objects.""" def test_creation(self): """Check the methods to set and retrieve a trajectory.""" # pylint: disable=too-many-statements - import numpy # Create a node with two arrays n = TrajectoryData() # I create sample data - stepids = numpy.array([60, 70]) + stepids = np.array([60, 70]) times = stepids * 0.01 - cells = numpy.array([[[ + cells = np.array([[[ 2., 0., 0., @@ -2696,10 +2681,10 @@ def test_creation(self): 3., ]]]) symbols = ['H', 'O', 'C'] - positions = numpy.array([[[0., 0., 0.], [0.5, 0.5, 0.5], [1.5, 1.5, 1.5]], - [[0., 0., 0.], [0.5, 0.5, 0.5], [1.5, 1.5, 1.5]]]) - velocities = numpy.array([[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], - [[0.5, 0.5, 0.5], [0.5, 0.5, 0.5], [-0.5, -0.5, -0.5]]]) + positions = np.array([[[0., 0., 0.], [0.5, 0.5, 0.5], [1.5, 1.5, 1.5]], + [[0., 0., 0.], [0.5, 0.5, 0.5], [1.5, 1.5, 1.5]]]) + velocities = np.array([[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0.5, 0.5, 0.5], [0.5, 0.5, 0.5], [-0.5, -0.5, -0.5]]]) # I set the node n.set_trajectory( @@ -2707,27 +2692,27 @@ def test_creation(self): ) # Generic checks - self.assertEqual(n.numsites, 3) - self.assertEqual(n.numsteps, 2) - self.assertAlmostEqual(abs(stepids - n.get_stepids()).sum(), 0.) - self.assertAlmostEqual(abs(times - n.get_times()).sum(), 0.) - self.assertAlmostEqual(abs(cells - n.get_cells()).sum(), 0.) - self.assertEqual(symbols, n.symbols) - self.assertAlmostEqual(abs(positions - n.get_positions()).sum(), 0.) - self.assertAlmostEqual(abs(velocities - n.get_velocities()).sum(), 0.) 
+ assert n.numsites == 3 + assert n.numsteps == 2 + assert round(abs(abs(stepids - n.get_stepids()).sum() - 0.), 7) == 0 + assert round(abs(abs(times - n.get_times()).sum() - 0.), 7) == 0 + assert round(abs(abs(cells - n.get_cells()).sum() - 0.), 7) == 0 + assert symbols == n.symbols + assert round(abs(abs(positions - n.get_positions()).sum() - 0.), 7) == 0 + assert round(abs(abs(velocities - n.get_velocities()).sum() - 0.), 7) == 0 # get_step_data function check data = n.get_step_data(1) - self.assertEqual(data[0], stepids[1]) - self.assertAlmostEqual(data[1], times[1]) - self.assertAlmostEqual(abs(cells[1] - data[2]).sum(), 0.) - self.assertEqual(symbols, data[3]) - self.assertAlmostEqual(abs(data[4] - positions[1]).sum(), 0.) - self.assertAlmostEqual(abs(data[5] - velocities[1]).sum(), 0.) + assert data[0] == stepids[1] + assert round(abs(data[1] - times[1]), 7) == 0 + assert round(abs(abs(cells[1] - data[2]).sum() - 0.), 7) == 0 + assert symbols == data[3] + assert round(abs(abs(data[4] - positions[1]).sum() - 0.), 7) == 0 + assert round(abs(abs(data[5] - velocities[1]).sum() - 0.), 7) == 0 # Step 70 has index 1 - self.assertEqual(1, n.get_index_from_stepid(70)) - with self.assertRaises(ValueError): + assert n.get_index_from_stepid(70) == 1 + with pytest.raises(ValueError): # Step 66 does not exist n.get_index_from_stepid(66) @@ -2735,79 +2720,79 @@ def test_creation(self): # I set the node, this time without times or velocities (the same node) n.set_trajectory(stepids=stepids, cells=cells, symbols=symbols, positions=positions) # Generic checks - self.assertEqual(n.numsites, 3) - self.assertEqual(n.numsteps, 2) - self.assertAlmostEqual(abs(stepids - n.get_stepids()).sum(), 0.) - self.assertIsNone(n.get_times()) - self.assertAlmostEqual(abs(cells - n.get_cells()).sum(), 0.) - self.assertEqual(symbols, n.symbols) - self.assertAlmostEqual(abs(positions - n.get_positions()).sum(), 0.) 
- self.assertIsNone(n.get_velocities()) + assert n.numsites == 3 + assert n.numsteps == 2 + assert round(abs(abs(stepids - n.get_stepids()).sum() - 0.), 7) == 0 + assert n.get_times() is None + assert round(abs(abs(cells - n.get_cells()).sum() - 0.), 7) == 0 + assert symbols == n.symbols + assert round(abs(abs(positions - n.get_positions()).sum() - 0.), 7) == 0 + assert n.get_velocities() is None # Same thing, but for a new node n = TrajectoryData() n.set_trajectory(stepids=stepids, cells=cells, symbols=symbols, positions=positions) # Generic checks - self.assertEqual(n.numsites, 3) - self.assertEqual(n.numsteps, 2) - self.assertAlmostEqual(abs(stepids - n.get_stepids()).sum(), 0.) - self.assertIsNone(n.get_times()) - self.assertAlmostEqual(abs(cells - n.get_cells()).sum(), 0.) - self.assertEqual(symbols, n.symbols) - self.assertAlmostEqual(abs(positions - n.get_positions()).sum(), 0.) - self.assertIsNone(n.get_velocities()) + assert n.numsites == 3 + assert n.numsteps == 2 + assert round(abs(abs(stepids - n.get_stepids()).sum() - 0.), 7) == 0 + assert n.get_times() is None + assert round(abs(abs(cells - n.get_cells()).sum() - 0.), 7) == 0 + assert symbols == n.symbols + assert round(abs(abs(positions - n.get_positions()).sum() - 0.), 7) == 0 + assert n.get_velocities() is None ######################################################## # I set the node, this time without velocities (the same node) n.set_trajectory(stepids=stepids, cells=cells, symbols=symbols, positions=positions, times=times) # Generic checks - self.assertEqual(n.numsites, 3) - self.assertEqual(n.numsteps, 2) - self.assertAlmostEqual(abs(stepids - n.get_stepids()).sum(), 0.) - self.assertAlmostEqual(abs(times - n.get_times()).sum(), 0.) - self.assertAlmostEqual(abs(cells - n.get_cells()).sum(), 0.) - self.assertEqual(symbols, n.symbols) - self.assertAlmostEqual(abs(positions - n.get_positions()).sum(), 0.) 
- self.assertIsNone(n.get_velocities()) + assert n.numsites == 3 + assert n.numsteps == 2 + assert round(abs(abs(stepids - n.get_stepids()).sum() - 0.), 7) == 0 + assert round(abs(abs(times - n.get_times()).sum() - 0.), 7) == 0 + assert round(abs(abs(cells - n.get_cells()).sum() - 0.), 7) == 0 + assert symbols == n.symbols + assert round(abs(abs(positions - n.get_positions()).sum() - 0.), 7) == 0 + assert n.get_velocities() is None # Same thing, but for a new node n = TrajectoryData() n.set_trajectory(stepids=stepids, cells=cells, symbols=symbols, positions=positions, times=times) # Generic checks - self.assertEqual(n.numsites, 3) - self.assertEqual(n.numsteps, 2) - self.assertAlmostEqual(abs(stepids - n.get_stepids()).sum(), 0.) - self.assertAlmostEqual(abs(times - n.get_times()).sum(), 0.) - self.assertAlmostEqual(abs(cells - n.get_cells()).sum(), 0.) - self.assertEqual(symbols, n.symbols) - self.assertAlmostEqual(abs(positions - n.get_positions()).sum(), 0.) - self.assertIsNone(n.get_velocities()) + assert n.numsites == 3 + assert n.numsteps == 2 + assert round(abs(abs(stepids - n.get_stepids()).sum() - 0.), 7) == 0 + assert round(abs(abs(times - n.get_times()).sum() - 0.), 7) == 0 + assert round(abs(abs(cells - n.get_cells()).sum() - 0.), 7) == 0 + assert symbols == n.symbols + assert round(abs(abs(positions - n.get_positions()).sum() - 0.), 7) == 0 + assert n.get_velocities() is None n.store() # Again same checks, but after storing # Generic checks - self.assertEqual(n.numsites, 3) - self.assertEqual(n.numsteps, 2) - self.assertAlmostEqual(abs(stepids - n.get_stepids()).sum(), 0.) - self.assertAlmostEqual(abs(times - n.get_times()).sum(), 0.) - self.assertAlmostEqual(abs(cells - n.get_cells()).sum(), 0.) - self.assertEqual(symbols, n.symbols) - self.assertAlmostEqual(abs(positions - n.get_positions()).sum(), 0.) 
- self.assertIsNone(n.get_velocities()) + assert n.numsites == 3 + assert n.numsteps == 2 + assert round(abs(abs(stepids - n.get_stepids()).sum() - 0.), 7) == 0 + assert round(abs(abs(times - n.get_times()).sum() - 0.), 7) == 0 + assert round(abs(abs(cells - n.get_cells()).sum() - 0.), 7) == 0 + assert symbols == n.symbols + assert round(abs(abs(positions - n.get_positions()).sum() - 0.), 7) == 0 + assert n.get_velocities() is None # get_step_data function check data = n.get_step_data(1) - self.assertEqual(data[0], stepids[1]) - self.assertAlmostEqual(data[1], times[1]) - self.assertAlmostEqual(abs(cells[1] - data[2]).sum(), 0.) - self.assertEqual(symbols, data[3]) - self.assertAlmostEqual(abs(data[4] - positions[1]).sum(), 0.) - self.assertIsNone(data[5]) + assert data[0] == stepids[1] + assert round(abs(data[1] - times[1]), 7) == 0 + assert round(abs(abs(cells[1] - data[2]).sum() - 0.), 7) == 0 + assert symbols == data[3] + assert round(abs(abs(data[4] - positions[1]).sum() - 0.), 7) == 0 + assert data[5] is None # Step 70 has index 1 - self.assertEqual(1, n.get_index_from_stepid(70)) - with self.assertRaises(ValueError): + assert n.get_index_from_stepid(70) == 1 + with pytest.raises(ValueError): # Step 66 does not exist n.get_index_from_stepid(66) @@ -2815,27 +2800,27 @@ def test_creation(self): # Again, but after reloading from uuid n = load_node(n.uuid, sub_classes=(TrajectoryData,)) # Generic checks - self.assertEqual(n.numsites, 3) - self.assertEqual(n.numsteps, 2) - self.assertAlmostEqual(abs(stepids - n.get_stepids()).sum(), 0.) - self.assertAlmostEqual(abs(times - n.get_times()).sum(), 0.) - self.assertAlmostEqual(abs(cells - n.get_cells()).sum(), 0.) - self.assertEqual(symbols, n.symbols) - self.assertAlmostEqual(abs(positions - n.get_positions()).sum(), 0.) 
- self.assertIsNone(n.get_velocities()) + assert n.numsites == 3 + assert n.numsteps == 2 + assert round(abs(abs(stepids - n.get_stepids()).sum() - 0.), 7) == 0 + assert round(abs(abs(times - n.get_times()).sum() - 0.), 7) == 0 + assert round(abs(abs(cells - n.get_cells()).sum() - 0.), 7) == 0 + assert symbols == n.symbols + assert round(abs(abs(positions - n.get_positions()).sum() - 0.), 7) == 0 + assert n.get_velocities() is None # get_step_data function check data = n.get_step_data(1) - self.assertEqual(data[0], stepids[1]) - self.assertAlmostEqual(data[1], times[1]) - self.assertAlmostEqual(abs(cells[1] - data[2]).sum(), 0.) - self.assertEqual(symbols, data[3]) - self.assertAlmostEqual(abs(data[4] - positions[1]).sum(), 0.) - self.assertIsNone(data[5]) + assert data[0] == stepids[1] + assert round(abs(data[1] - times[1]), 7) == 0 + assert round(abs(abs(cells[1] - data[2]).sum() - 0.), 7) == 0 + assert symbols == data[3] + assert round(abs(abs(data[4] - positions[1]).sum() - 0.), 7) == 0 + assert data[5] is None # Step 70 has index 1 - self.assertEqual(1, n.get_index_from_stepid(70)) - with self.assertRaises(ValueError): + assert n.get_index_from_stepid(70) == 1 + with pytest.raises(ValueError): # Step 66 does not exist n.get_index_from_stepid(66) @@ -2844,15 +2829,14 @@ def test_conversion_to_structure(self): """ Check the methods to export a given time step to a StructureData node. 
""" - import numpy # Create a node with two arrays n = TrajectoryData() # I create sample data - stepids = numpy.array([60, 70]) + stepids = np.array([60, 70]) times = stepids * 0.01 - cells = numpy.array([[[ + cells = np.array([[[ 2., 0., 0., @@ -2878,10 +2862,10 @@ def test_conversion_to_structure(self): 3., ]]]) symbols = ['H', 'O', 'C'] - positions = numpy.array([[[0., 0., 0.], [0.5, 0.5, 0.5], [1.5, 1.5, 1.5]], - [[0., 0., 0.], [0.5, 0.5, 0.5], [1.5, 1.5, 1.5]]]) - velocities = numpy.array([[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], - [[0.5, 0.5, 0.5], [0.5, 0.5, 0.5], [-0.5, -0.5, -0.5]]]) + positions = np.array([[[0., 0., 0.], [0.5, 0.5, 0.5], [1.5, 1.5, 1.5]], + [[0., 0., 0.], [0.5, 0.5, 0.5], [1.5, 1.5, 1.5]]]) + velocities = np.array([[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0.5, 0.5, 0.5], [0.5, 0.5, 0.5], [-0.5, -0.5, -0.5]]]) # I set the node n.set_trajectory( @@ -2892,12 +2876,12 @@ def test_conversion_to_structure(self): from_get_structure = n.get_structure(index=1) for struc in [from_step, from_get_structure]: - self.assertEqual(len(struc.sites), 3) # 3 sites - self.assertAlmostEqual(abs(numpy.array(struc.cell) - cells[1]).sum(), 0) - newpos = numpy.array([s.position for s in struc.sites]) - self.assertAlmostEqual(abs(newpos - positions[1]).sum(), 0) + assert len(struc.sites) == 3 # 3 sites + assert round(abs(abs(np.array(struc.cell) - cells[1]).sum() - 0), 7) == 0 + newpos = np.array([s.position for s in struc.sites]) + assert round(abs(abs(newpos - positions[1]).sum() - 0), 7) == 0 newkinds = [s.kind_name for s in struc.sites] - self.assertEqual(newkinds, symbols) + assert newkinds == symbols # Weird assignments (nobody should ever do this, but it is possible in # principle and we want to check @@ -2906,19 +2890,19 @@ def test_conversion_to_structure(self): k3 = Kind(name='O', symbols='Os', mass=100.) 
k4 = Kind(name='Ge', symbols='Ge') - with self.assertRaises(ValueError): + with pytest.raises(ValueError): # Not enough kinds struc = n.get_step_structure(1, custom_kinds=[k1, k2]) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): # Too many kinds struc = n.get_step_structure(1, custom_kinds=[k1, k2, k3, k4]) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): # Wrong kinds struc = n.get_step_structure(1, custom_kinds=[k1, k2, k4]) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): # Two kinds with the same name struc = n.get_step_structure(1, custom_kinds=[k1, k2, k3, k3]) @@ -2926,18 +2910,18 @@ def test_conversion_to_structure(self): struc = n.get_step_structure(1, custom_kinds=[k1, k2, k3]) # Checks - self.assertEqual(len(struc.sites), 3) # 3 sites - self.assertAlmostEqual(abs(numpy.array(struc.cell) - cells[1]).sum(), 0) - newpos = numpy.array([s.position for s in struc.sites]) - self.assertAlmostEqual(abs(newpos - positions[1]).sum(), 0) + assert len(struc.sites) == 3 # 3 sites + assert round(abs(abs(np.array(struc.cell) - cells[1]).sum() - 0), 7) == 0 + newpos = np.array([s.position for s in struc.sites]) + assert round(abs(abs(newpos - positions[1]).sum() - 0), 7) == 0 newkinds = [s.kind_name for s in struc.sites] # Kinds are in the same order as given in the custm_kinds list - self.assertEqual(newkinds, symbols) + assert newkinds == symbols newatomtypes = [struc.get_kind(s.kind_name).symbols[0] for s in struc.sites] # Atoms remain in the same order as given in the positions list - self.assertEqual(newatomtypes, ['He', 'Os', 'Cu']) + assert newatomtypes == ['He', 'Os', 'Cu'] # Check the mass of the kind of the second atom ('O' _> symbol Os, mass 100) - self.assertAlmostEqual(struc.get_kind(struc.sites[1].kind_name).mass, 100.) 
+ assert round(abs(struc.get_kind(struc.sites[1].kind_name).mass - 100.), 7) == 0 def test_conversion_from_structurelist(self): """ @@ -2980,9 +2964,9 @@ def test_conversion_from_structurelist(self): structurelist.append(struct) td = TrajectoryData(structurelist=structurelist) - self.assertEqual(td.get_cells().tolist(), cells) - self.assertEqual(td.symbols, symbols[0]) - self.assertEqual(td.get_positions().tolist(), positions) + assert td.get_cells().tolist() == cells + assert td.symbols == symbols[0] + assert td.get_positions().tolist() == positions symbols = [['H', 'O', 'C'], ['H', 'O', 'P']] structurelist = [] @@ -2992,22 +2976,18 @@ def test_conversion_from_structurelist(self): struct.append_atom(symbols=symbol, position=positions[i][j]) structurelist.append(struct) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): td = TrajectoryData(structurelist=structurelist) @staticmethod def test_export_to_file(): """Export the band structure on a file, check if it is working.""" - import numpy - - from aiida.orm.nodes.data.cif import has_pycifrw - n = TrajectoryData() # I create sample data - stepids = numpy.array([60, 70]) + stepids = np.array([60, 70]) times = stepids * 0.01 - cells = numpy.array([[[ + cells = np.array([[[ 2., 0., 0., @@ -3033,10 +3013,10 @@ def test_export_to_file(): 3., ]]]) symbols = ['H', 'O', 'C'] - positions = numpy.array([[[0., 0., 0.], [0.5, 0.5, 0.5], [1.5, 1.5, 1.5]], - [[0., 0., 0.], [0.5, 0.5, 0.5], [1.5, 1.5, 1.5]]]) - velocities = numpy.array([[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], - [[0.5, 0.5, 0.5], [0.5, 0.5, 0.5], [-0.5, -0.5, -0.5]]]) + positions = np.array([[[0., 0., 0.], [0.5, 0.5, 0.5], [1.5, 1.5, 1.5]], + [[0., 0., 0.], [0.5, 0.5, 0.5], [1.5, 1.5, 1.5]]]) + velocities = np.array([[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0.5, 0.5, 0.5], [0.5, 0.5, 0.5], [-0.5, -0.5, -0.5]]]) # I set the node n.set_trajectory( @@ -3068,7 +3048,8 @@ def test_export_to_file(): os.remove(file) -class 
TestKpointsData(AiidaTestCase): +@pytest.mark.usefixtures('aiida_profile_clean_class') +class TestKpointsData: """Tests the KpointsData objects.""" def test_mesh(self): @@ -3082,38 +3063,37 @@ def test_mesh(self): input_mesh = [4, 4, 4] k.set_kpoints_mesh(input_mesh) mesh, offset = k.get_kpoints_mesh() - self.assertEqual(mesh, input_mesh) - self.assertEqual(offset, [0., 0., 0.]) # must be a tuple of three 0 by default + assert mesh == input_mesh + assert offset == [0., 0., 0.] # must be a tuple of three 0 by default # a too long list should fail - with self.assertRaises(ValueError): + with pytest.raises(ValueError): k.set_kpoints_mesh([4, 4, 4, 4]) # now try to put explicitely an offset input_offset = [0.5, 0.5, 0.5] k.set_kpoints_mesh(input_mesh, input_offset) mesh, offset = k.get_kpoints_mesh() - self.assertEqual(mesh, input_mesh) - self.assertEqual(offset, input_offset) + assert mesh == input_mesh + assert offset == input_offset # verify the same but after storing k.store() - self.assertEqual(mesh, input_mesh) - self.assertEqual(offset, input_offset) + assert mesh == input_mesh + assert offset == input_offset # cannot modify it after storage - with self.assertRaises(ModificationNotAllowed): + with pytest.raises(ModificationNotAllowed): k.set_kpoints_mesh(input_mesh) def test_list(self): """ Test the method to set and retrieve a kpoint list. 
""" - import numpy k = KpointsData() - input_klist = numpy.array([ + input_klist = np.array([ (0.0, 0.0, 0.0), (0.2, 0.0, 0.0), (0.0, 0.2, 0.0), @@ -3125,44 +3105,43 @@ def test_list(self): klist = k.get_kpoints() # try to get the same - self.assertTrue(numpy.array_equal(input_klist, klist)) + assert np.array_equal(input_klist, klist) # if no cell is set, cannot convert into cartesian - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): _ = k.get_kpoints(cartesian=True) # try to set also weights # should fail if the weights length do not match kpoints - input_weights = numpy.ones(6) - with self.assertRaises(ValueError): + input_weights = np.ones(6) + with pytest.raises(ValueError): k.set_kpoints(input_klist, weights=input_weights) # try a right one - input_weights = numpy.ones(4) + input_weights = np.ones(4) k.set_kpoints(input_klist, weights=input_weights) klist, weights = k.get_kpoints(also_weights=True) - self.assertTrue(numpy.array_equal(weights, input_weights)) - self.assertTrue(numpy.array_equal(klist, input_klist)) + assert np.array_equal(weights, input_weights) + assert np.array_equal(klist, input_klist) # verify the same, but after storing k.store() klist, weights = k.get_kpoints(also_weights=True) - self.assertTrue(numpy.array_equal(weights, input_weights)) - self.assertTrue(numpy.array_equal(klist, input_klist)) + assert np.array_equal(weights, input_weights) + assert np.array_equal(klist, input_klist) # cannot modify it after storage - with self.assertRaises(ModificationNotAllowed): + with pytest.raises(ModificationNotAllowed): k.set_kpoints(input_klist) def test_kpoints_to_cartesian(self): """ Test how the list of kpoints is converted to cartesian coordinates """ - import numpy k = KpointsData() - input_klist = numpy.array([ + input_klist = np.array([ (0.0, 0.0, 0.0), (0.2, 0.0, 0.0), (0.0, 0.2, 0.0), @@ -3171,7 +3150,7 @@ def test_kpoints_to_cartesian(self): # define a cell alat = 4. 
- cell = numpy.array([ + cell = np.array([ [alat, 0., 0.], [0., alat, 0.], [0., 0., alat], @@ -3185,13 +3164,13 @@ def test_kpoints_to_cartesian(self): # verify that it is not the same of the input # (at least I check that there something has been done) klist = k.get_kpoints(cartesian=True) - self.assertFalse(numpy.array_equal(klist, input_klist)) + assert not np.array_equal(klist, input_klist) # put the kpoints in cartesian and get them back, they should be equal # internally it is doing two matrix transforms k.set_kpoints(input_klist, cartesian=True) klist = k.get_kpoints(cartesian=True) - self.assertTrue(numpy.allclose(klist, input_klist, atol=1e-16)) + assert np.allclose(klist, input_klist, atol=1e-16) def test_path_wrapper_legacy(self): """ @@ -3199,17 +3178,16 @@ def test_path_wrapper_legacy(self): calling the deprecated legacy implementation. This tests that the wrapper maintains the same behavior of the old implementation """ - import numpy from aiida.tools.data.array.kpoints import get_explicit_kpoints_path # Shouldn't get anything without having set the cell - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): get_explicit_kpoints_path(None) # Define a cell alat = 4. - cell = numpy.array([ + cell = np.array([ [alat, 0., 0.], [0., alat, 0.], [0., 0., alat], @@ -3232,12 +3210,12 @@ def test_path_wrapper_legacy(self): ]) # at least 2 points per segment - with self.assertRaises(ValueError): + with pytest.raises(ValueError): get_explicit_kpoints_path(structure, method='legacy', value=[ ('G', 'M', 1), ]) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): get_explicit_kpoints_path(structure, method='legacy', value=[ ('G', (0., 0., 0.), 'M', (1., 1., 1.), 1), ]) @@ -3251,7 +3229,6 @@ def test_tetra_z_wrapper_legacy(self): calling the deprecated legacy implementation. 
This tests that the wrapper maintains the same behavior of the old implementation """ - import numpy from aiida.tools.data.array.kpoints import get_kpoints_path @@ -3260,23 +3237,22 @@ def test_tetra_z_wrapper_legacy(self): s = StructureData(cell=cell_x) result = get_kpoints_path(s, method='legacy', cartesian=True) - self.assertIsInstance(result['parameters'], Dict) + assert isinstance(result['parameters'], Dict) point_coords = result['parameters'].dict.point_coords - self.assertAlmostEqual(point_coords['Z'][2], numpy.pi / alat) - self.assertAlmostEqual(point_coords['Z'][0], 0.) + assert round(abs(point_coords['Z'][2] - np.pi / alat), 7) == 0 + assert round(abs(point_coords['Z'][0] - 0.), 7) == 0 -class TestSpglibTupleConversion(AiidaTestCase): +@pytest.mark.usefixtures('aiida_profile_clean_class') +class TestSpglibTupleConversion: """Tests for conversion of Spglib tuples.""" def test_simple_to_aiida(self): """ Test conversion of a simple tuple to an AiiDA structure """ - import numpy as np - from aiida.tools import spglib_tuple_to_structure cell = np.array([[4., 1., 0.], [0., 4., 0.], [0., 0., 4.]]) @@ -3290,16 +3266,14 @@ def test_simple_to_aiida(self): struc = spglib_tuple_to_structure((cell, relcoords, numbers)) - self.assertAlmostEqual(np.sum(np.abs(np.array(struc.cell) - np.array(cell))), 0.) - self.assertAlmostEqual( - np.sum(np.abs(np.array([site.position for site in struc.sites]) - np.array(abscoords))), 0. 
- ) - self.assertEqual([site.kind_name for site in struc.sites], ['Ba', 'Ti', 'O', 'O', 'O']) + assert round(abs(np.sum(np.abs(np.array(struc.cell) - np.array(cell))) - 0.), 7) == 0 + assert round( + abs(np.sum(np.abs(np.array([site.position for site in struc.sites]) - np.array(abscoords))) - 0.), 7 + ) == 0 + assert [site.kind_name for site in struc.sites] == ['Ba', 'Ti', 'O', 'O', 'O'] def test_complex1_to_aiida(self): """Test conversion of a tuple to an AiiDA structure when passing also information on the kinds.""" - import numpy as np - from aiida.tools import spglib_tuple_to_structure cell = np.array([[4., 1., 0.], [0., 4., 0.], [0., 0., 4.]]) @@ -3337,31 +3311,29 @@ def test_complex1_to_aiida(self): ] # Must specify also kind_info and kinds - with self.assertRaises(ValueError): + with pytest.raises(ValueError): struc = spglib_tuple_to_structure((cell, relcoords, numbers),) # There is no kind_info for one of the numbers - with self.assertRaises(ValueError): + with pytest.raises(ValueError): struc = spglib_tuple_to_structure((cell, relcoords, numbers), kind_info=kind_info_wrong, kinds=kinds) # There is no kind in the kinds for one of the labels # specified in kind_info - with self.assertRaises(ValueError): + with pytest.raises(ValueError): struc = spglib_tuple_to_structure((cell, relcoords, numbers), kind_info=kind_info, kinds=kinds_wrong) struc = spglib_tuple_to_structure((cell, relcoords, numbers), kind_info=kind_info, kinds=kinds) - self.assertAlmostEqual(np.sum(np.abs(np.array(struc.cell) - np.array(cell))), 0.) - self.assertAlmostEqual( - np.sum(np.abs(np.array([site.position for site in struc.sites]) - np.array(abscoords))), 0. 
- ) - self.assertEqual([site.kind_name for site in struc.sites], - ['Ba', 'Ti', 'O', 'O', 'O', 'Ba2', 'BaTi', 'BaTi2', 'Ba3']) + assert round(abs(np.sum(np.abs(np.array(struc.cell) - np.array(cell))) - 0.), 7) == 0 + assert round( + abs(np.sum(np.abs(np.array([site.position for site in struc.sites]) - np.array(abscoords))) - 0.), 7 + ) == 0 + assert [site.kind_name for site in struc.sites] == \ + ['Ba', 'Ti', 'O', 'O', 'O', 'Ba2', 'BaTi', 'BaTi2', 'Ba3'] def test_from_aiida(self): """Test conversion of an AiiDA structure to a spglib tuple.""" - import numpy as np - from aiida.tools import structure_to_spglib_tuple cell = np.array([[4., 1., 0.], [0., 4., 0.], [0., 0., 4.]]) @@ -3380,18 +3352,16 @@ def test_from_aiida(self): abscoords = np.array([_.position for _ in struc.sites]) struc_relpos = np.dot(np.linalg.inv(cell.T), abscoords.T).T - self.assertAlmostEqual(np.sum(np.abs(np.array(struc.cell) - np.array(struc_tuple[0]))), 0.) - self.assertAlmostEqual(np.sum(np.abs(np.array(struc_tuple[1]) - struc_relpos)), 0.) + assert round(abs(np.sum(np.abs(np.array(struc.cell) - np.array(struc_tuple[0]))) - 0.), 7) == 0 + assert round(abs(np.sum(np.abs(np.array(struc_tuple[1]) - struc_relpos)) - 0.), 7) == 0 expected_kind_info = [kind_info[site.kind_name] for site in struc.sites] - self.assertEqual(struc_tuple[2], expected_kind_info) + assert struc_tuple[2] == expected_kind_info def test_aiida_roundtrip(self): """ Convert an AiiDA structure to a tuple and go back to see if we get the same results """ - import numpy as np - from aiida.tools import spglib_tuple_to_structure, structure_to_spglib_tuple cell = np.array([[4., 1., 0.], [0., 4., 0.], [0., 0., 4.]]) @@ -3408,57 +3378,47 @@ def test_aiida_roundtrip(self): struc_tuple, kind_info, kinds = structure_to_spglib_tuple(struc) roundtrip_struc = spglib_tuple_to_structure(struc_tuple, kind_info, kinds) - self.assertAlmostEqual(np.sum(np.abs(np.array(struc.cell) - np.array(roundtrip_struc.cell))), 0.) 
- self.assertEqual(struc.get_attribute('kinds'), roundtrip_struc.get_attribute('kinds')) - self.assertEqual([_.kind_name for _ in struc.sites], [_.kind_name for _ in roundtrip_struc.sites]) - self.assertEqual( - np.sum( - np.abs( - np.array([_.position for _ in struc.sites]) - np.array([_.position for _ in roundtrip_struc.sites]) - ) - ), 0. - ) - - -class TestSeekpathExplicitPath(AiidaTestCase): - """Tests for the `get_explicit_kpoints_path` from SeeK-path.""" - - @unittest.skipIf(not has_seekpath(), 'No seekpath available') - def test_simple(self): - """Test a simple case.""" - import numpy as np - - from aiida.plugins import DataFactory - from aiida.tools import get_explicit_kpoints_path - - structure = DataFactory('core.structure')(cell=[[4, 0, 0], [0, 4, 0], [0, 0, 6]]) - structure.append_atom(symbols='Ba', position=[0, 0, 0]) - structure.append_atom(symbols='Ti', position=[2, 2, 3]) - structure.append_atom(symbols='O', position=[2, 2, 0]) - structure.append_atom(symbols='O', position=[2, 0, 3]) - structure.append_atom(symbols='O', position=[0, 2, 3]) - - params = {'with_time_reversal': True, 'reference_distance': 0.025, 'recipe': 'hpkot', 'threshold': 1.e-7} - - return_value = get_explicit_kpoints_path(structure, method='seekpath', **params) - retdict = return_value['parameters'].get_dict() - - self.assertTrue(retdict['has_inversion_symmetry']) - self.assertFalse(retdict['augmented_path']) - self.assertAlmostEqual(retdict['volume_original_wrt_prim'], 1.0) - self.assertEqual( - to_list_of_lists(retdict['explicit_segments']), - [[0, 31], [30, 61], [60, 104], [103, 123], [122, 153], [152, 183], [182, 226], [226, 246], [246, 266]] - ) - - ret_k = return_value['explicit_kpoints'] - self.assertEqual( - to_list_of_lists(ret_k.labels), [[0, 'GAMMA'], [30, 'X'], [60, 'M'], [103, 'GAMMA'], [122, 'Z'], [152, 'R'], - [182, 'A'], [225, 'Z'], [226, 'X'], [245, 'R'], [246, 'M'], [265, 'A']] - ) - kpts = ret_k.get_kpoints(cartesian=False) - highsympoints_relcoords = 
[kpts[idx] for idx, label in ret_k.labels] - self.assertAlmostEqual( + assert round(abs(np.sum(np.abs(np.array(struc.cell) - np.array(roundtrip_struc.cell))) - 0.), 7) == 0 + assert struc.get_attribute('kinds') == roundtrip_struc.get_attribute('kinds') + assert [_.kind_name for _ in struc.sites] == [_.kind_name for _ in roundtrip_struc.sites] + assert np.sum( + np.abs(np.array([_.position for _ in struc.sites]) - np.array([_.position for _ in roundtrip_struc.sites])) + ) == 0. + + +@pytest.mark.skipif(not has_seekpath(), reason='No seekpath available') +@pytest.mark.usefixtures('aiida_profile_clean') +def test_seekpath_explicit_path(): + """"Tests the `get_explicit_kpoints_path` from SeeK-path.""" + from aiida.plugins import DataFactory + from aiida.tools import get_explicit_kpoints_path + + structure = DataFactory('core.structure')(cell=[[4, 0, 0], [0, 4, 0], [0, 0, 6]]) + structure.append_atom(symbols='Ba', position=[0, 0, 0]) + structure.append_atom(symbols='Ti', position=[2, 2, 3]) + structure.append_atom(symbols='O', position=[2, 2, 0]) + structure.append_atom(symbols='O', position=[2, 0, 3]) + structure.append_atom(symbols='O', position=[0, 2, 3]) + + params = {'with_time_reversal': True, 'reference_distance': 0.025, 'recipe': 'hpkot', 'threshold': 1.e-7} + + return_value = get_explicit_kpoints_path(structure, method='seekpath', **params) + retdict = return_value['parameters'].get_dict() + + assert retdict['has_inversion_symmetry'] + assert not retdict['augmented_path'] + assert round(abs(retdict['volume_original_wrt_prim'] - 1.0), 7) == 0 + assert to_list_of_lists(retdict['explicit_segments']) == \ + [[0, 31], [30, 61], [60, 104], [103, 123], [122, 153], [152, 183], [182, 226], [226, 246], [246, 266]] + + ret_k = return_value['explicit_kpoints'] + assert to_list_of_lists(ret_k.labels) == [[0, 'GAMMA'], [30, 'X'], [60, 'M'], [103, 'GAMMA'], [122, 'Z'], + [152, 'R'], [182, 'A'], [225, 'Z'], [226, 'X'], [245, 'R'], [246, 'M'], + [265, 'A']] + kpts = 
ret_k.get_kpoints(cartesian=False) + highsympoints_relcoords = [kpts[idx] for idx, label in ret_k.labels] + assert round( + abs( np.sum( np.abs( np.array([ @@ -3476,111 +3436,95 @@ def test_simple(self): [0.5, 0.5, 0.5], # A ]) - np.array(highsympoints_relcoords) ) - ), - 0. - ) - - ret_prims = return_value['primitive_structure'] - ret_convs = return_value['conv_structure'] - # The primitive structure should be the same as the one I input - self.assertAlmostEqual(np.sum(np.abs(np.array(structure.cell) - np.array(ret_prims.cell))), 0.) - self.assertEqual([_.kind_name for _ in structure.sites], [_.kind_name for _ in ret_prims.sites]) - self.assertEqual( - np.sum( - np. - abs(np.array([_.position for _ in structure.sites]) - np.array([_.position for _ in ret_prims.sites])) - ), 0. - ) - - # Also the conventional structure should be the same as the one I input - self.assertAlmostEqual(np.sum(np.abs(np.array(structure.cell) - np.array(ret_convs.cell))), 0.) - self.assertEqual([_.kind_name for _ in structure.sites], [_.kind_name for _ in ret_convs.sites]) - self.assertEqual( - np.sum( - np. - abs(np.array([_.position for _ in structure.sites]) - np.array([_.position for _ in ret_convs.sites])) - ), 0. 
- ) - - -class TestSeekpathPath(AiidaTestCase): - """Test Seekpath.""" - - @unittest.skipIf(not has_seekpath(), 'No seekpath available') - def test_simple(self): - """Test SeekPath for BaTiO3 structure.""" - import numpy as np - - from aiida.plugins import DataFactory - from aiida.tools import get_kpoints_path - - structure = DataFactory('core.structure')(cell=[[4, 0, 0], [0, 4, 0], [0, 0, 6]]) - structure.append_atom(symbols='Ba', position=[0, 0, 0]) - structure.append_atom(symbols='Ti', position=[2, 2, 3]) - structure.append_atom(symbols='O', position=[2, 2, 0]) - structure.append_atom(symbols='O', position=[2, 0, 3]) - structure.append_atom(symbols='O', position=[0, 2, 3]) - - params = {'with_time_reversal': True, 'recipe': 'hpkot', 'threshold': 1.e-7} - - return_value = get_kpoints_path(structure, method='seekpath', **params) - retdict = return_value['parameters'].get_dict() - - self.assertTrue(retdict['has_inversion_symmetry']) - self.assertFalse(retdict['augmented_path']) - self.assertAlmostEqual(retdict['volume_original_wrt_prim'], 1.0) - self.assertAlmostEqual(retdict['volume_original_wrt_conv'], 1.0) - self.assertEqual(retdict['bravais_lattice'], 'tP') - self.assertEqual(retdict['bravais_lattice_extended'], 'tP1') - self.assertEqual( - to_list_of_lists(retdict['path']), [['GAMMA', 'X'], ['X', 'M'], ['M', 'GAMMA'], ['GAMMA', 'Z'], ['Z', 'R'], - ['R', 'A'], ['A', 'Z'], ['X', 'R'], ['M', 'A']] - ) - - self.assertEqual( - retdict['point_coords'], { - 'A': [0.5, 0.5, 0.5], - 'M': [0.5, 0.5, 0.0], - 'R': [0.0, 0.5, 0.5], - 'X': [0.0, 0.5, 0.0], - 'Z': [0.0, 0.0, 0.5], - 'GAMMA': [0.0, 0.0, 0.0] - } - ) - - self.assertAlmostEqual( + ) - 0. 
+ ), + 7 + ) == 0 + + ret_prims = return_value['primitive_structure'] + ret_convs = return_value['conv_structure'] + # The primitive structure should be the same as the one I input + assert round(abs(np.sum(np.abs(np.array(structure.cell) - np.array(ret_prims.cell))) - 0.), 7) == 0 + assert [_.kind_name for _ in structure.sites] == [_.kind_name for _ in ret_prims.sites] + assert np.sum( + np.abs(np.array([_.position for _ in structure.sites]) - np.array([_.position for _ in ret_prims.sites])) + ) == 0. + + # Also the conventional structure should be the same as the one I input + assert round(abs(np.sum(np.abs(np.array(structure.cell) - np.array(ret_convs.cell))) - 0.), 7) == 0 + assert [_.kind_name for _ in structure.sites] == [_.kind_name for _ in ret_convs.sites] + assert np.sum( + np.abs(np.array([_.position for _ in structure.sites]) - np.array([_.position for _ in ret_convs.sites])) + ) == 0. + + +@pytest.mark.skipif(not has_seekpath(), reason='No seekpath available') +@pytest.mark.usefixtures('aiida_profile_clean') +def test_seekpath(): + """Test SeekPath for BaTiO3 structure.""" + from aiida.plugins import DataFactory + from aiida.tools import get_kpoints_path + + structure = DataFactory('core.structure')(cell=[[4, 0, 0], [0, 4, 0], [0, 0, 6]]) + structure.append_atom(symbols='Ba', position=[0, 0, 0]) + structure.append_atom(symbols='Ti', position=[2, 2, 3]) + structure.append_atom(symbols='O', position=[2, 2, 0]) + structure.append_atom(symbols='O', position=[2, 0, 3]) + structure.append_atom(symbols='O', position=[0, 2, 3]) + + params = {'with_time_reversal': True, 'recipe': 'hpkot', 'threshold': 1.e-7} + + return_value = get_kpoints_path(structure, method='seekpath', **params) + retdict = return_value['parameters'].get_dict() + + assert retdict['has_inversion_symmetry'] + assert not retdict['augmented_path'] + assert round(abs(retdict['volume_original_wrt_prim'] - 1.0), 7) == 0 + assert round(abs(retdict['volume_original_wrt_conv'] - 1.0), 7) == 0 + 
assert retdict['bravais_lattice'] == 'tP' + assert retdict['bravais_lattice_extended'] == 'tP1' + assert to_list_of_lists(retdict['path']) == [['GAMMA', 'X'], ['X', 'M'], ['M', 'GAMMA'], ['GAMMA', 'Z'], ['Z', 'R'], + ['R', 'A'], ['A', 'Z'], ['X', 'R'], ['M', 'A']] + + assert retdict['point_coords'] == { + 'A': [0.5, 0.5, 0.5], + 'M': [0.5, 0.5, 0.0], + 'R': [0.0, 0.5, 0.5], + 'X': [0.0, 0.5, 0.0], + 'Z': [0.0, 0.0, 0.5], + 'GAMMA': [0.0, 0.0, 0.0] + } + + assert round( + abs( np.sum( np.abs( np.array(retdict['inverse_primitive_transformation_matrix']) - np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) ) - ), 0. - ) - - ret_prims = return_value['primitive_structure'] - ret_convs = return_value['conv_structure'] - # The primitive structure should be the same as the one I input - self.assertAlmostEqual(np.sum(np.abs(np.array(structure.cell) - np.array(ret_prims.cell))), 0.) - self.assertEqual([_.kind_name for _ in structure.sites], [_.kind_name for _ in ret_prims.sites]) - self.assertEqual( - np.sum( - np. - abs(np.array([_.position for _ in structure.sites]) - np.array([_.position for _ in ret_prims.sites])) - ), 0. - ) - - # Also the conventional structure should be the same as the one I input - self.assertAlmostEqual(np.sum(np.abs(np.array(structure.cell) - np.array(ret_convs.cell))), 0.) - self.assertEqual([_.kind_name for _ in structure.sites], [_.kind_name for _ in ret_convs.sites]) - self.assertEqual( - np.sum( - np. - abs(np.array([_.position for _ in structure.sites]) - np.array([_.position for _ in ret_convs.sites])) - ), 0. - ) - - -class TestBandsData(AiidaTestCase): + ) - 0. 
+ ), 7 + ) == 0 + + ret_prims = return_value['primitive_structure'] + ret_convs = return_value['conv_structure'] + # The primitive structure should be the same as the one I input + assert round(abs(np.sum(np.abs(np.array(structure.cell) - np.array(ret_prims.cell))) - 0.), 7) == 0 + assert [_.kind_name for _ in structure.sites] == [_.kind_name for _ in ret_prims.sites] + assert np.sum( + np.abs(np.array([_.position for _ in structure.sites]) - np.array([_.position for _ in ret_prims.sites])) + ) == 0. + + # Also the conventional structure should be the same as the one I input + assert round(abs(np.sum(np.abs(np.array(structure.cell) - np.array(ret_convs.cell))) - 0.), 7) == 0 + assert [_.kind_name for _ in structure.sites] == [_.kind_name for _ in ret_convs.sites] + assert np.sum( + np.abs(np.array([_.position for _ in structure.sites]) - np.array([_.position for _ in ret_convs.sites])) + ) == 0. + + +@pytest.mark.usefixtures('aiida_profile_clean_class') +class TestBandsData: """ Tests the BandsData objects. """ @@ -3589,11 +3533,10 @@ def test_band(self): """ Check the methods to set and retrieve a mesh. """ - import numpy # define a cell alat = 4. 
- cell = numpy.array([ + cell = np.array([ [alat, 0., 0.], [0., alat, 0.], [0., 0., alat], @@ -3605,36 +3548,35 @@ def test_band(self): b = BandsData() b.set_kpointsdata(k) - self.assertTrue(numpy.array_equal(b.cell, k.cell)) + assert np.array_equal(b.cell, k.cell) - input_bands = numpy.array([numpy.ones(4) for i in range(k.get_kpoints().shape[0])]) + input_bands = np.array([np.ones(4) for i in range(k.get_kpoints().shape[0])]) input_occupations = input_bands b.set_bands(input_bands, occupations=input_occupations, units='ev') b.set_bands(input_bands, units='ev') b.set_bands(input_bands, occupations=input_occupations) - with self.assertRaises(TypeError): + with pytest.raises(TypeError): b.set_bands(occupations=input_occupations, units='ev') # pylint: disable=no-value-for-parameter b.set_bands(input_bands, occupations=input_occupations, units='ev') bands, occupations = b.get_bands(also_occupations=True) - self.assertTrue(numpy.array_equal(bands, input_bands)) - self.assertTrue(numpy.array_equal(occupations, input_occupations)) - self.assertTrue(b.units == 'ev') + assert np.array_equal(bands, input_bands) + assert np.array_equal(occupations, input_occupations) + assert b.units == 'ev' b.store() - with self.assertRaises(ModificationNotAllowed): + with pytest.raises(ModificationNotAllowed): b.set_bands(bands) @staticmethod def test_export_to_file(): """Export the band structure on a file, check if it is working.""" - import numpy # define a cell alat = 4. 
- cell = numpy.array([ + cell = np.array([ [alat, 0., 0.], [0., alat, 0.], [0., 0., alat], @@ -3649,7 +3591,7 @@ def test_export_to_file(): # 4 bands with linearly increasing energies, it does not make sense # but is good for testing - input_bands = numpy.array([numpy.ones(4) * i for i in range(k.get_kpoints().shape[0])]) + input_bands = np.array([np.ones(4) * i for i in range(k.get_kpoints().shape[0])]) b.set_bands(input_bands, units='eV') diff --git a/tests/test_dbimporters.py b/tests/test_dbimporters.py index 9c23165515..cd19edf7a3 100644 --- a/tests/test_dbimporters.py +++ b/tests/test_dbimporters.py @@ -7,14 +7,15 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=no-self-use """Tests for subclasses of DbImporter, DbSearchResults and DbEntry""" -import unittest +import pytest -from aiida.storage.testbase import AiidaTestCase from tests.static import STATIC_DIR -class TestCodDbImporter(AiidaTestCase): +@pytest.mark.usefixtures('aiida_profile_clean_class') +class TestCodDbImporter: """Test the CodDbImporter class.""" from aiida.orm.nodes.data.cif import has_pycifrw @@ -45,28 +46,28 @@ def test_query_construction_1(self): q_sql = re.sub(r'(\d\.\d{6})\d+', r'\1', q_sql) q_sql = re.sub(r'(120.00)39+', r'\g<1>4', q_sql) - self.assertEqual(q_sql, \ - 'SELECT file, svnrevision FROM data WHERE ' - "(status IS NULL OR status != 'retracted') AND " - '(a BETWEEN 3.332333 AND 3.334333 OR ' - 'a BETWEEN 0.999 AND 1.001) AND ' - '(alpha BETWEEN 1.665666 AND 1.667666 OR ' - 'alpha BETWEEN -0.001 AND 0.001) AND ' - "(chemname LIKE '%caffeine%' OR " - "chemname LIKE '%serotonine%') AND " - "(method IN ('single crystal') OR method IS NULL) AND " - "(formula REGEXP ' C[0-9 ]' AND " - "formula REGEXP ' H[0-9 ]' AND " - "formula REGEXP ' Cl[0-9 ]') AND " - "(formula IN ('- C6 H6 -')) AND " - '(file IN 
(1000000, 3000000)) AND ' - '(cellpressure BETWEEN 999 AND 1001 OR ' - 'cellpressure BETWEEN 1000 AND 1002) AND ' - '(celltemp BETWEEN -0.001 AND 0.001 OR ' - 'celltemp BETWEEN 10.499 AND 10.501) AND ' - "(nel IN (5)) AND (sg IN ('P -1')) AND " - '(vol BETWEEN 99.999 AND 100.001 OR ' - 'vol BETWEEN 120.004 AND 120.006)') + assert q_sql == \ + 'SELECT file, svnrevision FROM data WHERE ' \ + "(status IS NULL OR status != 'retracted') AND " \ + '(a BETWEEN 3.332333 AND 3.334333 OR ' \ + 'a BETWEEN 0.999 AND 1.001) AND ' \ + '(alpha BETWEEN 1.665666 AND 1.667666 OR ' \ + 'alpha BETWEEN -0.001 AND 0.001) AND ' \ + "(chemname LIKE '%caffeine%' OR " \ + "chemname LIKE '%serotonine%') AND " \ + "(method IN ('single crystal') OR method IS NULL) AND " \ + "(formula REGEXP ' C[0-9 ]' AND " \ + "formula REGEXP ' H[0-9 ]' AND " \ + "formula REGEXP ' Cl[0-9 ]') AND " \ + "(formula IN ('- C6 H6 -')) AND " \ + '(file IN (1000000, 3000000)) AND ' \ + '(cellpressure BETWEEN 999 AND 1001 OR ' \ + 'cellpressure BETWEEN 1000 AND 1002) AND ' \ + '(celltemp BETWEEN -0.001 AND 0.001 OR ' \ + 'celltemp BETWEEN 10.499 AND 10.501) AND ' \ + "(nel IN (5)) AND (sg IN ('P -1')) AND " \ + '(vol BETWEEN 99.999 AND 100.001 OR ' \ + 'vol BETWEEN 120.004 AND 120.006)' def test_datatype_checks(self): """Rather complicated, but wide-coverage test for data types, accepted @@ -100,7 +101,7 @@ def test_datatype_checks(self): methods[i]('test', 'test', [values[j]]) except ValueError as exc: message = str(exc) - self.assertEqual(message, messages[results[i][j]]) + assert message == messages[results[i][j]] def test_dbentry_creation(self): """Tests the creation of CodEntry from CodSearchResults.""" @@ -116,25 +117,23 @@ def test_dbentry_creation(self): 'id': '2000000', 'svnrevision': '1234' }]) - self.assertEqual(len(results), 3) - self.assertEqual( - results.at(1).source, { - 'db_name': 'Crystallography Open Database', - 'db_uri': 'http://www.crystallography.net/cod', - 'extras': {}, - 'id': '1000001', - 
'license': 'CC0', - 'source_md5': None, - 'uri': 'http://www.crystallography.net/cod/1000001.cif@1234', - 'version': '1234', - } - ) - self.assertEqual([x.source['uri'] for x in results], [ + assert len(results) == 3 + assert results.at(1).source == { + 'db_name': 'Crystallography Open Database', + 'db_uri': 'http://www.crystallography.net/cod', + 'extras': {}, + 'id': '1000001', + 'license': 'CC0', + 'source_md5': None, + 'uri': 'http://www.crystallography.net/cod/1000001.cif@1234', + 'version': '1234', + } + assert [x.source['uri'] for x in results] == [ 'http://www.crystallography.net/cod/1000000.cif', 'http://www.crystallography.net/cod/1000001.cif@1234', 'http://www.crystallography.net/cod/2000000.cif@1234' - ]) + ] - @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') + @pytest.mark.skipif(not has_pycifrw(), reason='Unable to import PyCifRW') def test_dbentry_to_cif_node(self): """Tests the creation of CifData node from CodEntry.""" from aiida.orm import CifData @@ -144,23 +143,22 @@ def test_dbentry_to_cif_node(self): entry.cif = "data_test _publ_section_title 'Test structure'" cif = entry.get_cif_node() - self.assertEqual(isinstance(cif, CifData), True) - self.assertEqual(cif.get_attribute('md5'), '070711e8e99108aade31d20cd5c94c48') - self.assertEqual( - cif.source, { - 'db_name': 'Crystallography Open Database', - 'db_uri': 'http://www.crystallography.net/cod', - 'id': None, - 'version': None, - 'extras': {}, - 'source_md5': '070711e8e99108aade31d20cd5c94c48', - 'uri': 'http://www.crystallography.net/cod/1000000.cif', - 'license': 'CC0', - } - ) - - -class TestTcodDbImporter(AiidaTestCase): + assert isinstance(cif, CifData) is True + assert cif.get_attribute('md5') == '070711e8e99108aade31d20cd5c94c48' + assert cif.source == { + 'db_name': 'Crystallography Open Database', + 'db_uri': 'http://www.crystallography.net/cod', + 'id': None, + 'version': None, + 'extras': {}, + 'source_md5': '070711e8e99108aade31d20cd5c94c48', + 'uri': 
'http://www.crystallography.net/cod/1000000.cif', + 'license': 'CC0', + } + + +@pytest.mark.usefixtures('aiida_profile_clean_class') +class TestTcodDbImporter: """Test the TcodDbImporter class.""" def test_dbentry_creation(self): @@ -177,26 +175,25 @@ def test_dbentry_creation(self): 'id': '20000000', 'svnrevision': '1234' }]) - self.assertEqual(len(results), 3) - self.assertEqual( - results.at(1).source, { - 'db_name': 'Theoretical Crystallography Open Database', - 'db_uri': 'http://www.crystallography.net/tcod', - 'extras': {}, - 'id': '10000001', - 'license': 'CC0', - 'source_md5': None, - 'uri': 'http://www.crystallography.net/tcod/10000001.cif@1234', - 'version': '1234', - } - ) - self.assertEqual([x.source['uri'] for x in results], [ + assert len(results) == 3 + assert results.at(1).source == { + 'db_name': 'Theoretical Crystallography Open Database', + 'db_uri': 'http://www.crystallography.net/tcod', + 'extras': {}, + 'id': '10000001', + 'license': 'CC0', + 'source_md5': None, + 'uri': 'http://www.crystallography.net/tcod/10000001.cif@1234', + 'version': '1234', + } + assert [x.source['uri'] for x in results] == [ 'http://www.crystallography.net/tcod/10000000.cif', 'http://www.crystallography.net/tcod/10000001.cif@1234', 'http://www.crystallography.net/tcod/20000000.cif@1234' - ]) + ] -class TestPcodDbImporter(AiidaTestCase): +@pytest.mark.usefixtures('aiida_profile_clean_class') +class TestPcodDbImporter: """Test the PcodDbImporter class.""" def test_dbentry_creation(self): @@ -204,22 +201,21 @@ def test_dbentry_creation(self): from aiida.tools.dbimporters.plugins.pcod import PcodSearchResults results = PcodSearchResults([{'id': '12345678'}]) - self.assertEqual(len(results), 1) - self.assertEqual( - results.at(0).source, { - 'db_name': 'Predicted Crystallography Open Database', - 'db_uri': 'http://www.crystallography.net/pcod', - 'extras': {}, - 'id': '12345678', - 'license': 'CC0', - 'source_md5': None, - 'uri': 
'http://www.crystallography.net/pcod/cif/1/123/12345678.cif', - 'version': None, - } - ) - - -class TestMpodDbImporter(AiidaTestCase): + assert len(results) == 1 + assert results.at(0).source == { + 'db_name': 'Predicted Crystallography Open Database', + 'db_uri': 'http://www.crystallography.net/pcod', + 'extras': {}, + 'id': '12345678', + 'license': 'CC0', + 'source_md5': None, + 'uri': 'http://www.crystallography.net/pcod/cif/1/123/12345678.cif', + 'version': None, + } + + +@pytest.mark.usefixtures('aiida_profile_clean_class') +class TestMpodDbImporter: """Test the MpodDbImporter class.""" def test_dbentry_creation(self): @@ -227,22 +223,21 @@ def test_dbentry_creation(self): from aiida.tools.dbimporters.plugins.mpod import MpodSearchResults results = MpodSearchResults([{'id': '1234567'}]) - self.assertEqual(len(results), 1) - self.assertEqual( - results.at(0).source, { - 'db_name': 'Material Properties Open Database', - 'db_uri': 'http://mpod.cimav.edu.mx', - 'extras': {}, - 'id': '1234567', - 'license': None, - 'source_md5': None, - 'uri': 'http://mpod.cimav.edu.mx/datafiles/1234567.mpod', - 'version': None, - } - ) - - -class TestNnincDbImporter(AiidaTestCase): + assert len(results) == 1 + assert results.at(0).source == { + 'db_name': 'Material Properties Open Database', + 'db_uri': 'http://mpod.cimav.edu.mx', + 'extras': {}, + 'id': '1234567', + 'license': None, + 'source_md5': None, + 'uri': 'http://mpod.cimav.edu.mx/datafiles/1234567.mpod', + 'version': None, + } + + +@pytest.mark.usefixtures('aiida_profile_clean_class') +class TestNnincDbImporter: """Test the UpfEntry class.""" def test_upfentry_creation(self): @@ -262,7 +257,7 @@ def test_upfentry_creation(self): entry._contents = fpntr.read() # pylint: disable=protected-access upfnode = entry.get_upf_node() - self.assertEqual(upfnode.element, 'Ba') + assert upfnode.element == 'Ba' entry.source = {'id': 'O.pbesol-n-rrkjus_psl.0.1-tested-pslib030.UPF'} @@ -270,5 +265,5 @@ def test_upfentry_creation(self): 
# thus UpfData parser will complain about the mismatch of chemical # element, mentioned in file name, and the one described in the # pseudopotential file. - with self.assertRaises(ParsingError): + with pytest.raises(ParsingError): upfnode = entry.get_upf_node() diff --git a/tests/test_generic.py b/tests/test_generic.py index 4df00cbee1..c0e98792ee 100644 --- a/tests/test_generic.py +++ b/tests/test_generic.py @@ -7,96 +7,98 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=unused-argument """Generic tests that need the use of the DB.""" +import pytest from aiida import orm -from aiida.storage.testbase import AiidaTestCase -class TestCode(AiidaTestCase): - """Test the Code class.""" +def test_code_local(aiida_profile_clean, aiida_localhost): + """Test local code.""" + import tempfile - def test_code_local(self): - """Test local code.""" - import tempfile + from aiida.common.exceptions import ValidationError + from aiida.orm import Code - from aiida.common.exceptions import ValidationError - from aiida.orm import Code - - code = Code(local_executable='test.sh') - with self.assertRaises(ValidationError): - # No file with name test.sh - code.store() + code = Code(local_executable='test.sh') + with pytest.raises(ValidationError): + # No file with name test.sh + code.store() - with tempfile.NamedTemporaryFile(mode='w+') as fhandle: - fhandle.write('#/bin/bash\n\necho test run\n') - fhandle.flush() - code.put_object_from_filelike(fhandle, 'test.sh') + with tempfile.NamedTemporaryFile(mode='w+') as fhandle: + fhandle.write('#/bin/bash\n\necho test run\n') + fhandle.flush() + code.put_object_from_filelike(fhandle, 'test.sh') - code.store() - self.assertTrue(code.can_run_on(self.computer)) - self.assertTrue(code.get_local_executable(), 'test.sh') - self.assertTrue(code.get_execname(), 'stest.sh') + 
code.store() + assert code.can_run_on(aiida_localhost) + assert code.get_local_executable(), 'test.sh' + assert code.get_execname(), 'stest.sh' - def test_remote(self): - """Test remote code.""" - import tempfile - from aiida.common.exceptions import ValidationError - from aiida.orm import Code +def test_code_remote(aiida_profile_clean, aiida_localhost): + """Test remote code.""" + import tempfile - with self.assertRaises(ValueError): - # remote_computer_exec has length 2 but is not a list or tuple - Code(remote_computer_exec='ab') + from aiida.common.exceptions import ValidationError + from aiida.orm import Code - # invalid code path - with self.assertRaises(ValueError): - Code(remote_computer_exec=(self.computer, '')) + with pytest.raises(ValueError): + # remote_computer_exec has length 2 but is not a list or tuple + Code(remote_computer_exec='ab') - # Relative path is invalid for remote code - with self.assertRaises(ValueError): - Code(remote_computer_exec=(self.computer, 'subdir/run.exe')) + # invalid code path + with pytest.raises(ValueError): + Code(remote_computer_exec=(aiida_localhost, '')) - # first argument should be a computer, not a string - with self.assertRaises(TypeError): - Code(remote_computer_exec=('localhost', '/bin/ls')) + # Relative path is invalid for remote code + with pytest.raises(ValueError): + Code(remote_computer_exec=(aiida_localhost, 'subdir/run.exe')) - code = Code(remote_computer_exec=(self.computer, '/bin/ls')) - with tempfile.NamedTemporaryFile(mode='w+') as fhandle: - fhandle.write('#/bin/bash\n\necho test run\n') - fhandle.flush() - code.put_object_from_filelike(fhandle, 'test.sh') + # first argument should be a computer, not a string + with pytest.raises(TypeError): + Code(remote_computer_exec=('localhost', '/bin/ls')) - with self.assertRaises(ValidationError): - # There are files inside - code.store() + code = Code(remote_computer_exec=(aiida_localhost, '/bin/ls')) + with tempfile.NamedTemporaryFile(mode='w+') as fhandle: + 
fhandle.write('#/bin/bash\n\necho test run\n') + fhandle.flush() + code.put_object_from_filelike(fhandle, 'test.sh') - # If there are no files, I can store - code.delete_object('test.sh') + with pytest.raises(ValidationError): + # There are files inside code.store() - self.assertEqual(code.get_remote_computer().pk, self.computer.pk) # pylint: disable=no-member - self.assertEqual(code.get_remote_exec_path(), '/bin/ls') - self.assertEqual(code.get_execname(), '/bin/ls') + # If there are no files, I can store + code.delete_object('test.sh') + code.store() + + assert code.get_remote_computer().pk == aiida_localhost.pk # pylint: disable=no-member + assert code.get_remote_exec_path() == '/bin/ls' + assert code.get_execname() == '/bin/ls' - self.assertTrue(code.can_run_on(self.computer)) - othercomputer = orm.Computer( - label='another_localhost', - hostname='localhost', - transport_type='core.local', - scheduler_type='core.pbspro', - workdir='/tmp/aiida' - ).store() - self.assertFalse(code.can_run_on(othercomputer)) + assert code.can_run_on(aiida_localhost) + othercomputer = orm.Computer( + label='another_localhost', + hostname='localhost', + transport_type='core.local', + scheduler_type='core.pbspro', + workdir='/tmp/aiida' + ).store() + assert not code.can_run_on(othercomputer) -class TestBool(AiidaTestCase): +@pytest.mark.usefixtures('aiida_profile_clean_class') +class TestBool: """Test AiiDA Bool class.""" - def test_bool_conversion(self): + @staticmethod + def test_bool_conversion(): for val in [True, False]: - self.assertEqual(val, bool(orm.Bool(val))) + assert val == bool(orm.Bool(val)) - def test_int_conversion(self): + @staticmethod + def test_int_conversion(): for val in [True, False]: - self.assertEqual(int(val), int(orm.Bool(val))) + assert int(val) == int(orm.Bool(val)) diff --git a/tests/test_nodes.py b/tests/test_nodes.py index ce55867571..5fdeae14da 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -7,31 +7,30 @@ # For further information on 
the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=too-many-lines,invalid-name,protected-access -# pylint: disable=missing-docstring,too-many-locals,too-many-statements -# pylint: disable=too-many-public-methods,no-member +# pylint: disable=too-many-lines,invalid-name,protected-access,missing-docstring,too-many-locals,too-many-statements +# pylint: disable=too-many-public-methods,no-member,no-self-use import copy import io import tempfile import pytest -from aiida import orm +from aiida import get_profile, orm from aiida.common.exceptions import InvalidOperation, ModificationNotAllowed, StoringNotAllowed, ValidationError from aiida.common.links import LinkType -from aiida.storage.testbase import AiidaTestCase from aiida.tools import delete_group_nodes, delete_nodes -class TestNodeIsStorable(AiidaTestCase): +@pytest.mark.usefixtures('aiida_profile_clean_class') +class TestNodeIsStorable: """Test that checks on storability of certain node sub classes work correctly.""" def test_base_classes(self): """Test storability of `Node` base sub classes.""" - with self.assertRaises(StoringNotAllowed): + with pytest.raises(StoringNotAllowed): orm.Node().store() - with self.assertRaises(StoringNotAllowed): + with pytest.raises(StoringNotAllowed): orm.ProcessNode().store() # The following base classes are storable @@ -45,27 +44,29 @@ def test_unregistered_sub_class(self): class SubData(orm.Data): pass - with self.assertRaises(StoringNotAllowed): + with pytest.raises(StoringNotAllowed): SubData().store() -class TestNodeCopyDeepcopy(AiidaTestCase): +@pytest.mark.usefixtures('aiida_profile_clean_class') +class TestNodeCopyDeepcopy: """Test that calling copy and deepcopy on a Node does the right thing.""" def test_copy_not_supported(self): """Copying a base Node instance is not supported.""" node = orm.Node() - with 
self.assertRaises(InvalidOperation): + with pytest.raises(InvalidOperation): copy.copy(node) def test_deepcopy_not_supported(self): """Deep copying a base Node instance is not supported.""" node = orm.Node() - with self.assertRaises(InvalidOperation): + with pytest.raises(InvalidOperation): copy.deepcopy(node) -class TestNodeHashing(AiidaTestCase): +@pytest.mark.usefixtures('aiida_profile_clean_class') +class TestNodeHashing: """ Tests the functionality of hashing a node """ @@ -97,8 +98,8 @@ def test_node_uuid_hashing_for_querybuidler(self): # Check that the query doesn't fail qb.all() # And that the results are correct - self.assertEqual(qb.count(), 1) - self.assertEqual(qb.first()[0], n.id) + assert qb.count() == 1 + assert qb.first()[0] == n.id @staticmethod def create_folderdata_with_empty_file(): @@ -129,11 +130,12 @@ def test_updatable_attributes(self): hash1 = node.get_hash() node.set_process_state('finished') hash2 = node.get_hash() - self.assertNotEqual(hash1, None) - self.assertEqual(hash1, hash2) + assert hash1 is not None + assert hash1 == hash2 -class TestTransitiveNoLoops(AiidaTestCase): +@pytest.mark.usefixtures('aiida_profile_clean_class') +class TestTransitiveNoLoops: """ Test the transitive closure functionality """ @@ -150,11 +152,12 @@ def test_loop_not_allowed(self): c2.add_incoming(d2, link_type=LinkType.INPUT_CALC, link_label='link') c2.store() - with self.assertRaises(ValueError): # This would generate a loop + with pytest.raises(ValueError): # This would generate a loop d1.add_incoming(c2, link_type=LinkType.CREATE, link_label='link') -class TestTypes(AiidaTestCase): +@pytest.mark.usefixtures('aiida_profile_clean_class') +class TestTypes: """ Generic test class to test types """ @@ -168,16 +171,22 @@ def test_uuid_type(self): results = orm.QueryBuilder().append(orm.Data, project=('uuid', '*')).all() for uuid, data in results: - self.assertTrue(isinstance(uuid, str)) - self.assertTrue(isinstance(data.uuid, str)) + assert isinstance(uuid, 
str) + assert isinstance(data.uuid, str) -class TestQueryWithAiidaObjects(AiidaTestCase): +class TestQueryWithAiidaObjects: """ Test if queries work properly also with aiida.orm.Node classes instead of backend model objects. """ + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean, aiida_localhost): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init + self.computer = aiida_localhost + def test_with_subclasses(self): from aiida.plugins import DataFactory @@ -208,28 +217,28 @@ def test_with_subclasses(self): results = qb.all(flat=True) # a3, a4 should not be found because they are not CalcJobNodes. # a6, a7 should not be found because they have not the attribute set. - self.assertEqual({i.pk for i in results}, {a1.pk}) + assert {i.pk for i in results} == {a1.pk} # Same query, but by the generic Node class qb = orm.QueryBuilder() qb.append(orm.Node, filters={'extras': {'has_key': extra_name}}) results = qb.all(flat=True) - self.assertEqual({i.pk for i in results}, {a1.pk, a3.pk, a4.pk}) + assert {i.pk for i in results} == {a1.pk, a3.pk, a4.pk} # Same query, but by the Data class qb = orm.QueryBuilder() qb.append(orm.Data, filters={'extras': {'has_key': extra_name}}) results = qb.all(flat=True) - self.assertEqual({i.pk for i in results}, {a3.pk, a4.pk}) + assert {i.pk for i in results} == {a3.pk, a4.pk} # Same query, but by the Dict subclass qb = orm.QueryBuilder() qb.append(orm.Dict, filters={'extras': {'has_key': extra_name}}) results = qb.all(flat=True) - self.assertEqual({i.pk for i in results}, {a4.pk}) + assert {i.pk for i in results} == {a4.pk} -class TestNodeBasic(AiidaTestCase): +class TestNodeBasic: """ These tests check the basic features of nodes (setting of attributes, copying of files, ...) 
@@ -260,6 +269,12 @@ class TestNodeBasic(AiidaTestCase): emptydict = {} emptylist = [] + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean, aiida_localhost): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init + self.computer = aiida_localhost + def test_uuid_uniquess(self): """ A uniqueness constraint on the UUID column of the Node model should prevent multiple nodes with identical UUID @@ -271,7 +286,7 @@ def test_uuid_uniquess(self): b.backend_entity.bare_model.uuid = a.uuid a.store() - with self.assertRaises(SqlaIntegrityError): + with pytest.raises(SqlaIntegrityError): b.store() def test_attribute_mutability(self): @@ -285,10 +300,10 @@ def test_attribute_mutability(self): a.store() # After storing attributes should now be immutable - with self.assertRaises(ModificationNotAllowed): + with pytest.raises(ModificationNotAllowed): a.delete_attribute('bool') - with self.assertRaises(ModificationNotAllowed): + with pytest.raises(ModificationNotAllowed): a.set_attribute('integer', self.intval) def test_attr_before_storing(self): @@ -304,15 +319,15 @@ def test_attr_before_storing(self): a.set_attribute('k9', None) # Now I check if I can retrieve them, before the storage - self.assertEqual(self.boolval, a.get_attribute('k1')) - self.assertEqual(self.intval, a.get_attribute('k2')) - self.assertEqual(self.floatval, a.get_attribute('k3')) - self.assertEqual(self.stringval, a.get_attribute('k4')) - self.assertEqual(self.dictval, a.get_attribute('k5')) - self.assertEqual(self.listval, a.get_attribute('k6')) - self.assertEqual(self.emptydict, a.get_attribute('k7')) - self.assertEqual(self.emptylist, a.get_attribute('k8')) - self.assertIsNone(a.get_attribute('k9')) + assert self.boolval == a.get_attribute('k1') + assert self.intval == a.get_attribute('k2') + assert self.floatval == a.get_attribute('k3') + assert self.stringval == a.get_attribute('k4') + assert self.dictval == 
a.get_attribute('k5') + assert self.listval == a.get_attribute('k6') + assert self.emptydict == a.get_attribute('k7') + assert self.emptylist == a.get_attribute('k8') + assert a.get_attribute('k9') is None # And now I try to delete the keys a.delete_attribute('k1') @@ -325,19 +340,19 @@ def test_attr_before_storing(self): a.delete_attribute('k8') a.delete_attribute('k9') - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): # I delete twice the same attribute a.delete_attribute('k1') - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): # I delete a non-existing attribute a.delete_attribute('nonexisting') - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): # I get a deleted attribute a.get_attribute('k1') - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): # I get a non-existing attribute a.get_attribute('nonexisting') @@ -366,7 +381,7 @@ def test_get_attrs_before_storing(self): } # Now I check if I can retrieve them, before the storage - self.assertEqual(a.attributes, target_attrs) + assert a.attributes == target_attrs # And now I try to delete the keys a.delete_attribute('k1') @@ -379,7 +394,7 @@ def test_get_attrs_before_storing(self): a.delete_attribute('k8') a.delete_attribute('k9') - self.assertEqual(a.attributes, {}) + assert a.attributes == {} def test_get_attrs_after_storing(self): a = orm.Data() @@ -408,19 +423,19 @@ def test_get_attrs_after_storing(self): } # Now I check if I can retrieve them, before the storage - self.assertEqual(a.attributes, target_attrs) + assert a.attributes == target_attrs def test_store_object(self): """Trying to set objects as attributes should fail, because they are not json-serializable.""" a = orm.Data() a.set_attribute('object', object()) - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): a.store() b = orm.Data() b.set_attribute('object_list', [object(), object()]) - with 
self.assertRaises(ValidationError): + with pytest.raises(ValidationError): b.store() def test_attributes_on_clone(self): @@ -451,9 +466,9 @@ def test_attributes_on_clone(self): b_expected_attributes['new'] = 'cvb' # I check before storing that the attributes are ok - self.assertEqual(b.attributes, b_expected_attributes) + assert b.attributes == b_expected_attributes # Note that during copy, I do not copy the extras! - self.assertEqual(b.extras, {}) + assert b.extras == {} # I store now b.store() @@ -462,11 +477,11 @@ def test_attributes_on_clone(self): b_expected_extras = {'meta': 'textofext', '_aiida_hash': AnyValue()} # Now I check that the attributes of the original node have not changed - self.assertEqual(a.attributes, attrs_to_set) + assert a.attributes == attrs_to_set # I check then on the 'b' copy - self.assertEqual(b.attributes, b_expected_attributes) - self.assertEqual(b.extras, b_expected_extras) + assert b.attributes == b_expected_attributes + assert b.extras == b_expected_extras def test_files(self): a = orm.Data() @@ -480,21 +495,21 @@ def test_files(self): a.put_object_from_file(handle.name, 'file1.txt') a.put_object_from_file(handle.name, 'file2.txt') - self.assertEqual(set(a.list_object_names()), set(['file1.txt', 'file2.txt'])) + assert set(a.list_object_names()) == set(['file1.txt', 'file2.txt']) with a.open('file1.txt') as fhandle: - self.assertEqual(fhandle.read(), file_content) + assert fhandle.read() == file_content with a.open('file2.txt') as fhandle: - self.assertEqual(fhandle.read(), file_content) + assert fhandle.read() == file_content b = a.clone() - self.assertNotEqual(a.uuid, b.uuid) + assert a.uuid != b.uuid # Check that the content is there - self.assertEqual(set(b.list_object_names()), set(['file1.txt', 'file2.txt'])) + assert set(b.list_object_names()) == set(['file1.txt', 'file2.txt']) with b.open('file1.txt') as handle: - self.assertEqual(handle.read(), file_content) + assert handle.read() == file_content with b.open('file2.txt') 
as handle: - self.assertEqual(handle.read(), file_content) + assert handle.read() == file_content # I overwrite a file and create a new one in the clone only with tempfile.NamedTemporaryFile(mode='w+') as handle: @@ -504,18 +519,18 @@ def test_files(self): b.put_object_from_file(handle.name, 'file3.txt') # I check the new content, and that the old one has not changed - self.assertEqual(set(a.list_object_names()), set(['file1.txt', 'file2.txt'])) + assert set(a.list_object_names()) == set(['file1.txt', 'file2.txt']) with a.open('file1.txt') as handle: - self.assertEqual(handle.read(), file_content) + assert handle.read() == file_content with a.open('file2.txt') as handle: - self.assertEqual(handle.read(), file_content) - self.assertEqual(set(b.list_object_names()), set(['file1.txt', 'file2.txt', 'file3.txt'])) + assert handle.read() == file_content + assert set(b.list_object_names()) == set(['file1.txt', 'file2.txt', 'file3.txt']) with b.open('file1.txt') as handle: - self.assertEqual(handle.read(), file_content) + assert handle.read() == file_content with b.open('file2.txt') as handle: - self.assertEqual(handle.read(), file_content_different) + assert handle.read() == file_content_different with b.open('file3.txt') as handle: - self.assertEqual(handle.read(), file_content_different) + assert handle.read() == file_content_different # This should in principle change the location of the files, # so I recheck @@ -530,19 +545,19 @@ def test_files(self): c.put_object_from_file(handle.name, 'file1.txt') c.put_object_from_file(handle.name, 'file4.txt') - self.assertEqual(set(a.list_object_names()), set(['file1.txt', 'file2.txt'])) + assert set(a.list_object_names()) == set(['file1.txt', 'file2.txt']) with a.open('file1.txt') as handle: - self.assertEqual(handle.read(), file_content) + assert handle.read() == file_content with a.open('file2.txt') as handle: - self.assertEqual(handle.read(), file_content) + assert handle.read() == file_content - 
self.assertEqual(set(c.list_object_names()), set(['file1.txt', 'file2.txt', 'file4.txt'])) + assert set(c.list_object_names()) == set(['file1.txt', 'file2.txt', 'file4.txt']) with c.open('file1.txt') as handle: - self.assertEqual(handle.read(), file_content_different) + assert handle.read() == file_content_different with c.open('file2.txt') as handle: - self.assertEqual(handle.read(), file_content) + assert handle.read() == file_content with c.open('file4.txt') as handle: - self.assertEqual(handle.read(), file_content_different) + assert handle.read() == file_content_different @pytest.mark.skip('relies on deleting folders from the repo which is not yet implemented') def test_folders(self): @@ -582,30 +597,30 @@ def test_folders(self): a.put_object_from_tree(tree_1, 'tree_1') # verify if the node has the structure I expect - self.assertEqual(set(a.list_object_names()), set(['tree_1'])) - self.assertEqual(set(a.list_object_names('tree_1')), set(['file1.txt', 'dir1'])) - self.assertEqual(set(a.list_object_names(os.path.join('tree_1', 'dir1'))), set(['dir2', 'file2.txt'])) + assert set(a.list_object_names()) == set(['tree_1']) + assert set(a.list_object_names('tree_1')) == set(['file1.txt', 'dir1']) + assert set(a.list_object_names(os.path.join('tree_1', 'dir1'))) == set(['dir2', 'file2.txt']) with a.open(os.path.join('tree_1', 'file1.txt')) as fhandle: - self.assertEqual(fhandle.read(), file_content) + assert fhandle.read() == file_content with a.open(os.path.join('tree_1', 'dir1', 'file2.txt')) as fhandle: - self.assertEqual(fhandle.read(), file_content) + assert fhandle.read() == file_content # try to exit from the folder - with self.assertRaises(FileNotFoundError): + with pytest.raises(FileNotFoundError): a.list_object_names('..') # clone into a new node b = a.clone() - self.assertNotEqual(a.uuid, b.uuid) + assert a.uuid != b.uuid # Check that the content is there - self.assertEqual(set(b.list_object_names()), set(['tree_1'])) - 
self.assertEqual(set(b.list_object_names('tree_1')), set(['file1.txt', 'dir1'])) - self.assertEqual(set(b.list_object_names(os.path.join('tree_1', 'dir1'))), set(['dir2', 'file2.txt'])) + assert set(b.list_object_names()) == set(['tree_1']) + assert set(b.list_object_names('tree_1')) == set(['file1.txt', 'dir1']) + assert set(b.list_object_names(os.path.join('tree_1', 'dir1'))) == set(['dir2', 'file2.txt']) with b.open(os.path.join('tree_1', 'file1.txt')) as fhandle: - self.assertEqual(fhandle.read(), file_content) + assert fhandle.read() == file_content with b.open(os.path.join('tree_1', 'dir1', 'file2.txt')) as fhandle: - self.assertEqual(fhandle.read(), file_content) + assert fhandle.read() == file_content # I overwrite a file and create a new one in the copy only dir3 = os.path.join(directory, 'dir3') @@ -613,28 +628,28 @@ def test_folders(self): b.put_object_from_tree(dir3, os.path.join('tree_1', 'dir3')) # no absolute path here - with self.assertRaises(TypeError): + with pytest.raises(TypeError): b.put_object_from_tree('dir3', os.path.join('tree_1', 'dir3')) stream = io.StringIO(file_content_different) b.put_object_from_filelike(stream, 'file3.txt') # I check the new content, and that the old one has not changed old - self.assertEqual(set(a.list_object_names()), set(['tree_1'])) - self.assertEqual(set(a.list_object_names('tree_1')), set(['file1.txt', 'dir1'])) - self.assertEqual(set(a.list_object_names(os.path.join('tree_1', 'dir1'))), set(['dir2', 'file2.txt'])) + assert set(a.list_object_names()) == set(['tree_1']) + assert set(a.list_object_names('tree_1')) == set(['file1.txt', 'dir1']) + assert set(a.list_object_names(os.path.join('tree_1', 'dir1'))) == set(['dir2', 'file2.txt']) with a.open(os.path.join('tree_1', 'file1.txt')) as fhandle: - self.assertEqual(fhandle.read(), file_content) + assert fhandle.read() == file_content with a.open(os.path.join('tree_1', 'dir1', 'file2.txt')) as fhandle: - self.assertEqual(fhandle.read(), file_content) + assert 
fhandle.read() == file_content # new - self.assertEqual(set(b.list_object_names()), set(['tree_1', 'file3.txt'])) - self.assertEqual(set(b.list_object_names('tree_1')), set(['file1.txt', 'dir1', 'dir3'])) - self.assertEqual(set(b.list_object_names(os.path.join('tree_1', 'dir1'))), set(['dir2', 'file2.txt'])) + assert set(b.list_object_names()) == set(['tree_1', 'file3.txt']) + assert set(b.list_object_names('tree_1')) == set(['file1.txt', 'dir1', 'dir3']) + assert set(b.list_object_names(os.path.join('tree_1', 'dir1'))) == set(['dir2', 'file2.txt']) with b.open(os.path.join('tree_1', 'file1.txt')) as fhandle: - self.assertEqual(fhandle.read(), file_content) + assert fhandle.read() == file_content with b.open(os.path.join('tree_1', 'dir1', 'file2.txt')) as fhandle: - self.assertEqual(fhandle.read(), file_content) + assert fhandle.read() == file_content # This should in principle change the location of the files, so I recheck a.store() @@ -649,22 +664,22 @@ def test_folders(self): c.delete_object(os.path.join('tree_1', 'dir1', 'dir2')) # check old - self.assertEqual(set(a.list_object_names()), set(['tree_1'])) - self.assertEqual(set(a.list_object_names('tree_1')), set(['file1.txt', 'dir1'])) - self.assertEqual(set(a.list_object_names(os.path.join('tree_1', 'dir1'))), set(['dir2', 'file2.txt'])) + assert set(a.list_object_names()) == set(['tree_1']) + assert set(a.list_object_names('tree_1')) == set(['file1.txt', 'dir1']) + assert set(a.list_object_names(os.path.join('tree_1', 'dir1'))) == set(['dir2', 'file2.txt']) with a.open(os.path.join('tree_1', 'file1.txt')) as fhandle: - self.assertEqual(fhandle.read(), file_content) + assert fhandle.read() == file_content with a.open(os.path.join('tree_1', 'dir1', 'file2.txt')) as fhandle: - self.assertEqual(fhandle.read(), file_content) + assert fhandle.read() == file_content # check new - self.assertEqual(set(c.list_object_names()), set(['tree_1'])) - self.assertEqual(set(c.list_object_names('tree_1')), set(['file1.txt', 
'dir1'])) - self.assertEqual(set(c.list_object_names(os.path.join('tree_1', 'dir1'))), set(['file2.txt', 'file4.txt'])) + assert set(c.list_object_names()) == set(['tree_1']) + assert set(c.list_object_names('tree_1')) == set(['file1.txt', 'dir1']) + assert set(c.list_object_names(os.path.join('tree_1', 'dir1'))) == set(['file2.txt', 'file4.txt']) with c.open(os.path.join('tree_1', 'file1.txt')) as fhandle: - self.assertEqual(fhandle.read(), file_content_different) + assert fhandle.read() == file_content_different with c.open(os.path.join('tree_1', 'dir1', 'file2.txt')) as fhandle: - self.assertEqual(fhandle.read(), file_content) + assert fhandle.read() == file_content # garbage cleaning shutil.rmtree(directory) @@ -682,13 +697,13 @@ def test_attr_after_storing(self): a.store() # Now I check if I can retrieve them, before the storage - self.assertIsNone(a.get_attribute('none')) - self.assertEqual(self.boolval, a.get_attribute('bool')) - self.assertEqual(self.intval, a.get_attribute('integer')) - self.assertEqual(self.floatval, a.get_attribute('float')) - self.assertEqual(self.stringval, a.get_attribute('string')) - self.assertEqual(self.dictval, a.get_attribute('dict')) - self.assertEqual(self.listval, a.get_attribute('list')) + assert a.get_attribute('none') is None + assert self.boolval == a.get_attribute('bool') + assert self.intval == a.get_attribute('integer') + assert self.floatval == a.get_attribute('float') + assert self.stringval == a.get_attribute('string') + assert self.dictval == a.get_attribute('dict') + assert self.listval == a.get_attribute('list') def test_attr_with_reload(self): a = orm.Data() @@ -703,13 +718,13 @@ def test_attr_with_reload(self): a.store() b = orm.load_node(uuid=a.uuid) - self.assertIsNone(a.get_attribute('none')) - self.assertEqual(self.boolval, b.get_attribute('bool')) - self.assertEqual(self.intval, b.get_attribute('integer')) - self.assertEqual(self.floatval, b.get_attribute('float')) - self.assertEqual(self.stringval, 
b.get_attribute('string')) - self.assertEqual(self.dictval, b.get_attribute('dict')) - self.assertEqual(self.listval, b.get_attribute('list')) + assert a.get_attribute('none') is None + assert self.boolval == b.get_attribute('bool') + assert self.intval == b.get_attribute('integer') + assert self.floatval == b.get_attribute('float') + assert self.stringval == b.get_attribute('string') + assert self.dictval == b.get_attribute('dict') + assert self.listval == b.get_attribute('list') def test_extra_with_reload(self): a = orm.Data() @@ -722,42 +737,42 @@ def test_extra_with_reload(self): a.set_extra('list', self.listval) # Check before storing - self.assertEqual(self.boolval, a.get_extra('bool')) - self.assertEqual(self.intval, a.get_extra('integer')) - self.assertEqual(self.floatval, a.get_extra('float')) - self.assertEqual(self.stringval, a.get_extra('string')) - self.assertEqual(self.dictval, a.get_extra('dict')) - self.assertEqual(self.listval, a.get_extra('list')) + assert self.boolval == a.get_extra('bool') + assert self.intval == a.get_extra('integer') + assert self.floatval == a.get_extra('float') + assert self.stringval == a.get_extra('string') + assert self.dictval == a.get_extra('dict') + assert self.listval == a.get_extra('list') a.store() # Check after storing - self.assertEqual(self.boolval, a.get_extra('bool')) - self.assertEqual(self.intval, a.get_extra('integer')) - self.assertEqual(self.floatval, a.get_extra('float')) - self.assertEqual(self.stringval, a.get_extra('string')) - self.assertEqual(self.dictval, a.get_extra('dict')) - self.assertEqual(self.listval, a.get_extra('list')) + assert self.boolval == a.get_extra('bool') + assert self.intval == a.get_extra('integer') + assert self.floatval == a.get_extra('float') + assert self.stringval == a.get_extra('string') + assert self.dictval == a.get_extra('dict') + assert self.listval == a.get_extra('list') b = orm.load_node(uuid=a.uuid) - self.assertIsNone(a.get_extra('none')) - 
self.assertEqual(self.boolval, b.get_extra('bool')) - self.assertEqual(self.intval, b.get_extra('integer')) - self.assertEqual(self.floatval, b.get_extra('float')) - self.assertEqual(self.stringval, b.get_extra('string')) - self.assertEqual(self.dictval, b.get_extra('dict')) - self.assertEqual(self.listval, b.get_extra('list')) + assert a.get_extra('none') is None + assert self.boolval == b.get_extra('bool') + assert self.intval == b.get_extra('integer') + assert self.floatval == b.get_extra('float') + assert self.stringval == b.get_extra('string') + assert self.dictval == b.get_extra('dict') + assert self.listval == b.get_extra('list') def test_get_extras_with_default(self): a = orm.Data() a.store() a.set_extra('a', 'b') - self.assertEqual(a.get_extra('a'), 'b') - with self.assertRaises(AttributeError): + assert a.get_extra('a') == 'b' + with pytest.raises(AttributeError): a.get_extra('c') - self.assertEqual(a.get_extra('c', 'def'), 'def') + assert a.get_extra('c', 'def') == 'def' @staticmethod def test_attr_and_extras_multikey(): @@ -801,12 +816,12 @@ def test_attr_listing(self): all_extras = dict(_aiida_hash=AnyValue(), **extras_to_set) - self.assertEqual(set(list(a.attributes.keys())), set(attrs_to_set.keys())) - self.assertEqual(set(list(a.extras.keys())), set(all_extras.keys())) + assert set(list(a.attributes.keys())) == set(attrs_to_set.keys()) + assert set(list(a.extras.keys())) == set(all_extras.keys()) - self.assertEqual(a.attributes, attrs_to_set) + assert a.attributes == attrs_to_set - self.assertEqual(a.extras, all_extras) + assert a.extras == all_extras def test_delete_extras(self): """ @@ -830,7 +845,7 @@ def test_delete_extras(self): for k, v in extras_to_set.items(): a.set_extra(k, v) - self.assertEqual(a.extras, all_extras) + assert a.extras == all_extras # I pregenerate it, it cannot change during iteration list_keys = list(extras_to_set.keys()) @@ -839,7 +854,7 @@ def test_delete_extras(self): # performed correctly a.delete_extra(k) del 
all_extras[k] - self.assertEqual(a.extras, all_extras) + assert a.extras == all_extras def test_replace_extras_1(self): """ @@ -882,7 +897,7 @@ def test_replace_extras_1(self): for k, v in extras_to_set.items(): a.set_extra(k, v) - self.assertEqual(a.extras, all_extras) + assert a.extras == all_extras for k, v in new_extras.items(): # I delete one by one the keys and check if the operation is @@ -892,7 +907,7 @@ def test_replace_extras_1(self): # I update extras_to_set with the new entries, and do the comparison # again all_extras.update(new_extras) - self.assertEqual(a.extras, all_extras) + assert a.extras == all_extras def test_basetype_as_attr(self): """ @@ -910,28 +925,28 @@ def test_basetype_as_attr(self): # Manages to store, and value is converted to its base type p = orm.Dict(dict={'b': orm.Str('sometext'), 'c': l1}) p.store() - self.assertEqual(p.get_attribute('b'), 'sometext') - self.assertIsInstance(p.get_attribute('b'), str) - self.assertEqual(p.get_attribute('c'), ['b', [1, 2]]) - self.assertIsInstance(p.get_attribute('c'), (list, tuple)) + assert p.get_attribute('b') == 'sometext' + assert isinstance(p.get_attribute('b'), str) + assert p.get_attribute('c') == ['b', [1, 2]] + assert isinstance(p.get_attribute('c'), (list, tuple)) # Check also before storing n = orm.Data() n.set_attribute('a', orm.Str('sometext2')) n.set_attribute('b', l2) - self.assertEqual(n.get_attribute('a').value, 'sometext2') - self.assertIsInstance(n.get_attribute('a'), orm.Str) - self.assertEqual(n.get_attribute('b').get_list(), ['f', True, {'gg': None}]) - self.assertIsInstance(n.get_attribute('b'), orm.List) + assert n.get_attribute('a').value == 'sometext2' + assert isinstance(n.get_attribute('a'), orm.Str) + assert n.get_attribute('b').get_list() == ['f', True, {'gg': None}] + assert isinstance(n.get_attribute('b'), orm.List) # Check also deep in a dictionary/list n = orm.Data() n.set_attribute('a', {'b': [orm.Str('sometext3')]}) - 
self.assertEqual(n.get_attribute('a')['b'][0].value, 'sometext3') - self.assertIsInstance(n.get_attribute('a')['b'][0], orm.Str) + assert n.get_attribute('a')['b'][0].value == 'sometext3' + assert isinstance(n.get_attribute('a')['b'][0], orm.Str) n.store() - self.assertEqual(n.get_attribute('a')['b'][0], 'sometext3') - self.assertIsInstance(n.get_attribute('a')['b'][0], str) + assert n.get_attribute('a')['b'][0] == 'sometext3' + assert isinstance(n.get_attribute('a')['b'][0], str) def test_basetype_as_extra(self): """ @@ -952,19 +967,19 @@ def test_basetype_as_extra(self): n.set_extra('a', orm.Str('sometext2')) n.set_extra('c', l1) n.set_extra('d', l2) - self.assertEqual(n.get_extra('a'), 'sometext2') - self.assertIsInstance(n.get_extra('a'), str) - self.assertEqual(n.get_extra('c'), ['b', [1, 2]]) - self.assertIsInstance(n.get_extra('c'), (list, tuple)) - self.assertEqual(n.get_extra('d'), ['f', True, {'gg': None}]) - self.assertIsInstance(n.get_extra('d'), (list, tuple)) + assert n.get_extra('a') == 'sometext2' + assert isinstance(n.get_extra('a'), str) + assert n.get_extra('c') == ['b', [1, 2]] + assert isinstance(n.get_extra('c'), (list, tuple)) + assert n.get_extra('d') == ['f', True, {'gg': None}] + assert isinstance(n.get_extra('d'), (list, tuple)) # Check also deep in a dictionary/list n = orm.Data() n.store() n.set_extra('a', {'b': [orm.Str('sometext3')]}) - self.assertEqual(n.get_extra('a')['b'][0], 'sometext3') - self.assertIsInstance(n.get_extra('a')['b'][0], str) + assert n.get_extra('a')['b'][0] == 'sometext3' + assert isinstance(n.get_extra('a')['b'][0], str) def test_comments(self): # This is the best way to compare dates with the stored ones, instead @@ -978,11 +993,11 @@ def test_comments(self): user = orm.User.objects.get_default() a = orm.Data() - with self.assertRaises(ModificationNotAllowed): + with pytest.raises(ModificationNotAllowed): a.add_comment('text', user=user) a.store() - self.assertEqual(a.get_comments(), []) + assert 
a.get_comments() == [] before = timezone.now() - timedelta(seconds=1) a.add_comment('text', user=user) @@ -997,13 +1012,15 @@ def test_comments(self): times = [i.ctime for i in comments] for time in times: - self.assertTrue(time > before) - self.assertTrue(time < after) + assert time > before + assert time < after - self.assertEqual([(i.user.email, i.content) for i in comments], [ - (self.user_email, 'text'), - (self.user_email, 'text2'), - ]) + default_user_email = get_profile().default_user_email + + assert [(i.user.email, i.content) for i in comments] == [ + (default_user_email, 'text'), + (default_user_email, 'text2'), + ] def test_code_loading_from_string(self): """ @@ -1024,22 +1041,22 @@ def test_code_loading_from_string(self): # Test that the code1 can be loaded correctly with its label q_code_1 = orm.Code.get_from_string(code1.label) - self.assertEqual(q_code_1.id, code1.id) - self.assertEqual(q_code_1.label, code1.label) - self.assertEqual(q_code_1.get_remote_exec_path(), code1.get_remote_exec_path()) + assert q_code_1.id == code1.id + assert q_code_1.label == code1.label + assert q_code_1.get_remote_exec_path() == code1.get_remote_exec_path() # Test that the code2 can be loaded correctly with its label q_code_2 = orm.Code.get_from_string(f'{code2.label}@{self.computer.label}') # pylint: disable=no-member - self.assertEqual(q_code_2.id, code2.id) - self.assertEqual(q_code_2.label, code2.label) - self.assertEqual(q_code_2.get_remote_exec_path(), code2.get_remote_exec_path()) + assert q_code_2.id == code2.id + assert q_code_2.label == code2.label + assert q_code_2.get_remote_exec_path() == code2.get_remote_exec_path() # Calling get_from_string for a non string type raises exception - with self.assertRaises(TypeError): + with pytest.raises(TypeError): orm.Code.get_from_string(code1.id) # Test that the lookup of a nonexistent code works as expected - with self.assertRaises(NotExistent): + with pytest.raises(NotExistent): 
orm.Code.get_from_string('nonexistent_code') # Add another code with the label of code1 @@ -1049,7 +1066,7 @@ def test_code_loading_from_string(self): code3.store() # Query with the common label - with self.assertRaises(MultipleObjectsError): + with pytest.raises(MultipleObjectsError): orm.Code.get_from_string(code3.label) def test_code_loading_using_get(self): @@ -1071,30 +1088,30 @@ def test_code_loading_using_get(self): # Test that the code1 can be loaded correctly with its label only q_code_1 = orm.Code.get(label=code1.label) - self.assertEqual(q_code_1.id, code1.id) - self.assertEqual(q_code_1.label, code1.label) - self.assertEqual(q_code_1.get_remote_exec_path(), code1.get_remote_exec_path()) + assert q_code_1.id == code1.id + assert q_code_1.label == code1.label + assert q_code_1.get_remote_exec_path() == code1.get_remote_exec_path() # Test that the code1 can be loaded correctly with its id/pk q_code_1 = orm.Code.get(code1.id) - self.assertEqual(q_code_1.id, code1.id) - self.assertEqual(q_code_1.label, code1.label) - self.assertEqual(q_code_1.get_remote_exec_path(), code1.get_remote_exec_path()) + assert q_code_1.id == code1.id + assert q_code_1.label == code1.label + assert q_code_1.get_remote_exec_path() == code1.get_remote_exec_path() # Test that the code2 can be loaded correctly with its label and computername q_code_2 = orm.Code.get(label=code2.label, machinename=self.computer.label) # pylint: disable=no-member - self.assertEqual(q_code_2.id, code2.id) - self.assertEqual(q_code_2.label, code2.label) - self.assertEqual(q_code_2.get_remote_exec_path(), code2.get_remote_exec_path()) + assert q_code_2.id == code2.id + assert q_code_2.label == code2.label + assert q_code_2.get_remote_exec_path() == code2.get_remote_exec_path() # Test that the code2 can be loaded correctly with its id/pk q_code_2 = orm.Code.get(code2.id) - self.assertEqual(q_code_2.id, code2.id) - self.assertEqual(q_code_2.label, code2.label) - 
self.assertEqual(q_code_2.get_remote_exec_path(), code2.get_remote_exec_path()) + assert q_code_2.id == code2.id + assert q_code_2.label == code2.label + assert q_code_2.get_remote_exec_path() == code2.get_remote_exec_path() # Test that the lookup of a nonexistent code works as expected - with self.assertRaises(NotExistent): + with pytest.raises(NotExistent): orm.Code.get(label='nonexistent_code') # Add another code with the label of code1 @@ -1104,7 +1121,7 @@ def test_code_loading_using_get(self): code3.store() # Query with the common label - with self.assertRaises(MultipleObjectsError): + with pytest.raises(MultipleObjectsError): orm.Code.get(label=code3.label) # Add another code whose label is equal to pk of another code @@ -1118,9 +1135,9 @@ def test_code_loading_using_get(self): # Code.get(pk_label_duplicate) should return code1, as the pk takes # precedence q_code_4 = orm.Code.get(code4.label) - self.assertEqual(q_code_4.id, code1.id) - self.assertEqual(q_code_4.label, code1.label) - self.assertEqual(q_code_4.get_remote_exec_path(), code1.get_remote_exec_path()) + assert q_code_4.id == code1.id + assert q_code_4.label == code1.label + assert q_code_4.get_remote_exec_path() == code1.get_remote_exec_path() def test_code_description(self): """ @@ -1135,10 +1152,10 @@ def test_code_description(self): code.store() q_code1 = orm.Code.get(label=code.label) - self.assertEqual(code.description, str(q_code1.description)) + assert code.description == str(q_code1.description) q_code2 = orm.Code.get(code.id) - self.assertEqual(code.description, str(q_code2.description)) + assert code.description == str(q_code2.description) def test_list_for_plugin(self): """ @@ -1157,10 +1174,10 @@ def test_list_for_plugin(self): code2.store() retrieved_pks = set(orm.Code.list_for_plugin('plugin_name', labels=False)) - self.assertEqual(retrieved_pks, set([code1.pk, code2.pk])) + assert retrieved_pks == set([code1.pk, code2.pk]) retrieved_labels = 
set(orm.Code.list_for_plugin('plugin_name', labels=True)) - self.assertEqual(retrieved_labels, set([code1.label, code2.label])) + assert retrieved_labels == set([code1.label, code2.label]) def test_load_node(self): """ @@ -1172,38 +1189,44 @@ def test_load_node(self): node = orm.Data().store() uuid_stored = node.uuid # convenience to store the uuid # Simple test to see whether I load correctly from the pk: - self.assertEqual(uuid_stored, orm.load_node(pk=node.pk).uuid) + assert uuid_stored == orm.load_node(pk=node.pk).uuid # Testing the loading with the uuid: - self.assertEqual(uuid_stored, orm.load_node(uuid=uuid_stored).uuid) + assert uuid_stored == orm.load_node(uuid=uuid_stored).uuid # Here I'm testing whether loading the node with the beginnings of a uuid works for i in range(10, len(uuid_stored), 2): start_uuid = uuid_stored[:i] - self.assertEqual(uuid_stored, orm.load_node(uuid=start_uuid).uuid) + assert uuid_stored == orm.load_node(uuid=start_uuid).uuid # Testing whether loading the node with part of UUID works, removing the dashes for i in range(10, len(uuid_stored), 2): start_uuid = uuid_stored[:i].replace('-', '') - self.assertEqual(uuid_stored, orm.load_node(uuid=start_uuid).uuid) + assert uuid_stored == orm.load_node(uuid=start_uuid).uuid # If I don't allow load_node to fix the dashes, this has to raise: - with self.assertRaises(NotExistent): + with pytest.raises(NotExistent): orm.load_node(uuid=start_uuid, query_with_dashes=False) # Now I am reverting the order of the uuid, this will raise a NotExistent error: - with self.assertRaises(NotExistent): + with pytest.raises(NotExistent): orm.load_node(uuid=uuid_stored[::-1]) # I am giving a non-sensical pk, this should also raise - with self.assertRaises(NotExistent): + with pytest.raises(NotExistent): orm.load_node(-1) # Last check, when asking for specific subclass, this should raise: for spec in (node.pk, uuid_stored): - with self.assertRaises(NotExistent): + with pytest.raises(NotExistent): 
orm.load_node(spec, sub_classes=(orm.ArrayData,)) -class TestSubNodesAndLinks(AiidaTestCase): +class TestSubNodesAndLinks: + + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean, aiida_localhost): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init + self.computer = aiida_localhost def test_cachelink(self): """Test the proper functionality of the links cache, with different scenarios.""" @@ -1218,29 +1241,28 @@ def test_cachelink(self): # Try also reverse storage endcalc.add_incoming(n2, LinkType.INPUT_CALC, 'N2') - self.assertEqual({(i.link_label, i.node.uuid) for i in endcalc.get_incoming()}, {('N1', n1.uuid), - ('N2', n2.uuid)}) + assert {(i.link_label, i.node.uuid) for i in endcalc.get_incoming()} == {('N1', n1.uuid), ('N2', n2.uuid)} # Endnode not stored yet, n3 and n4 already stored endcalc.add_incoming(n3, LinkType.INPUT_CALC, 'N3') # Try also reverse storage endcalc.add_incoming(n4, LinkType.INPUT_CALC, 'N4') - self.assertEqual({(i.link_label, i.node.uuid) for i in endcalc.get_incoming()}, - {('N1', n1.uuid), ('N2', n2.uuid), ('N3', n3.uuid), ('N4', n4.uuid)}) + assert {(i.link_label, i.node.uuid) for i in endcalc.get_incoming()} == \ + {('N1', n1.uuid), ('N2', n2.uuid), ('N3', n3.uuid), ('N4', n4.uuid)} # Some parent nodes are not stored yet - with self.assertRaises(ModificationNotAllowed): + with pytest.raises(ModificationNotAllowed): endcalc.store() - self.assertEqual({(i.link_label, i.node.uuid) for i in endcalc.get_incoming()}, - {('N1', n1.uuid), ('N2', n2.uuid), ('N3', n3.uuid), ('N4', n4.uuid)}) + assert {(i.link_label, i.node.uuid) for i in endcalc.get_incoming()} == \ + {('N1', n1.uuid), ('N2', n2.uuid), ('N3', n3.uuid), ('N4', n4.uuid)} # This will also store n1 and n2! 
endcalc.store_all() - self.assertEqual({(i.link_label, i.node.uuid) for i in endcalc.get_incoming()}, - {('N1', n1.uuid), ('N2', n2.uuid), ('N3', n3.uuid), ('N4', n4.uuid)}) + assert {(i.link_label, i.node.uuid) for i in endcalc.get_incoming()} == \ + {('N1', n1.uuid), ('N2', n2.uuid), ('N3', n3.uuid), ('N4', n4.uuid)} def test_store_with_unstored_parents(self): """ @@ -1254,15 +1276,14 @@ def test_store_with_unstored_parents(self): endcalc.add_incoming(n2, LinkType.INPUT_CALC, 'N2') # Some parent nodes are not stored yet - with self.assertRaises(ModificationNotAllowed): + with pytest.raises(ModificationNotAllowed): endcalc.store() n1.store() # Now I can store endcalc.store() - self.assertEqual({(i.link_label, i.node.uuid) for i in endcalc.get_incoming()}, {('N1', n1.uuid), - ('N2', n2.uuid)}) + assert {(i.link_label, i.node.uuid) for i in endcalc.get_incoming()} == {('N1', n1.uuid), ('N2', n2.uuid)} def test_storeall_with_unstored_grandparents(self): """ @@ -1276,7 +1297,7 @@ def test_storeall_with_unstored_grandparents(self): endcalc.add_incoming(n2, LinkType.INPUT_CALC, 'N2') # Grandparents are unstored - with self.assertRaises(ModificationNotAllowed): + with pytest.raises(ModificationNotAllowed): endcalc.store_all() n1.store() @@ -1284,10 +1305,9 @@ def test_storeall_with_unstored_grandparents(self): endcalc.store_all() # Check the parents... 
- self.assertEqual({(i.link_label, i.node.uuid) for i in n2.get_incoming()}, {('N1', n1.uuid)}) - self.assertEqual({(i.link_label, i.node.uuid) for i in endcalc.get_incoming()}, {('N2', n2.uuid)}) + assert {(i.link_label, i.node.uuid) for i in n2.get_incoming()} == {('N1', n1.uuid)} + assert {(i.link_label, i.node.uuid) for i in endcalc.get_incoming()} == {('N2', n2.uuid)} - # pylint: disable=unused-variable,no-member,no-self-use def test_calculation_load(self): from aiida.orm import CalcJobNode @@ -1296,7 +1316,7 @@ def test_calculation_load(self): calc.set_option('resources', {'num_machines': 1, 'num_mpiprocs_per_machine': 1}) calc.store() - with self.assertRaises(Exception): + with pytest.raises(Exception): # I should get an error if I ask for a computer id/pk that doesn't exist CalcJobNode(computer=self.computer.id + 100000).store() @@ -1311,7 +1331,7 @@ def test_links_label_constraints(self): calc2b = orm.CalculationNode() calc.add_incoming(d1, LinkType.INPUT_CALC, link_label='label1') - with self.assertRaises(ValueError): + with pytest.raises(ValueError): calc.add_incoming(d1bis, LinkType.INPUT_CALC, link_label='label1') calc.store() @@ -1323,7 +1343,7 @@ def test_links_label_constraints(self): # This shouldn't be allowed, it's an output CREATE link with # the same same of an existing output CREATE link - with self.assertRaises(ValueError): + with pytest.raises(ValueError): d4.add_incoming(calc, LinkType.CREATE, link_label='label2') # instead, for outputs, I can have multiple times the same label @@ -1347,29 +1367,29 @@ def test_link_with_unstored(self): n3.add_incoming(n2, link_type=LinkType.CALL_CALC, link_label='l2') # Twice the same link name - with self.assertRaises(ValueError): + with pytest.raises(ValueError): n3.add_incoming(n4, link_type=LinkType.INPUT_CALC, link_label='l3') n2.store_all() n3.store_all() n2_in_links = [(n.link_label, n.node.uuid) for n in n2.get_incoming()] - self.assertEqual(sorted(n2_in_links), sorted([ + assert sorted(n2_in_links) 
== sorted([ ('l1', n1.uuid), - ])) + ]) n3_in_links = [(n.link_label, n.node.uuid) for n in n3.get_incoming()] - self.assertEqual(sorted(n3_in_links), sorted([ + assert sorted(n3_in_links) == sorted([ ('l2', n2.uuid), ('l3', n1.uuid), - ])) + ]) n1_out_links = [(entry.link_label, entry.node.pk) for entry in n1.get_outgoing()] - self.assertEqual(sorted(n1_out_links), sorted([ + assert sorted(n1_out_links) == sorted([ ('l1', n2.pk), ('l3', n3.pk), - ])) + ]) n2_out_links = [(entry.link_label, entry.node.pk) for entry in n2.get_outgoing()] - self.assertEqual(sorted(n2_out_links), sorted([('l2', n3.pk)])) + assert sorted(n2_out_links) == sorted([('l2', n3.pk)]) def test_multiple_create_links(self): """ @@ -1381,7 +1401,7 @@ def test_multiple_create_links(self): # Caching the links n3.add_incoming(n1, link_type=LinkType.CREATE, link_label='CREATE') - with self.assertRaises(ValueError): + with pytest.raises(ValueError): n3.add_incoming(n2, link_type=LinkType.CREATE, link_label='CREATE') def test_valid_links(self): @@ -1398,7 +1418,7 @@ def test_valid_links(self): label='localhost2', hostname='localhost', scheduler_type='core.direct', transport_type='core.local' ) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): # I need to save the localhost entry first orm.CalcJobNode(computer=unsavedcomputer).store() @@ -1414,17 +1434,17 @@ def test_valid_links(self): calc.add_incoming(d2, link_type=LinkType.INPUT_CALC, link_label='some_label') # Cannot link to itself - with self.assertRaises(ValueError): + with pytest.raises(ValueError): d1.add_incoming(d1, link_type=LinkType.INPUT_CALC, link_label='link') # I try to add wrong links (data to data, calc to calc, etc.) 
- with self.assertRaises(ValueError): + with pytest.raises(ValueError): d2.add_incoming(d1, link_type=LinkType.INPUT_CALC, link_label='link') - with self.assertRaises(ValueError): + with pytest.raises(ValueError): d1.add_incoming(d2, link_type=LinkType.INPUT_CALC, link_label='link') - with self.assertRaises(ValueError): + with pytest.raises(ValueError): calc.add_incoming(calc2, link_type=LinkType.INPUT_CALC, link_label='link') calc.store() @@ -1439,13 +1459,13 @@ def test_valid_links(self): data_node = orm.Data().store() data_node.add_incoming(calc_a, link_type=LinkType.CREATE, link_label='link') # A data cannot have two input calculations - with self.assertRaises(ValueError): + with pytest.raises(ValueError): data_node.add_incoming(calc_b, link_type=LinkType.CREATE, link_label='link') calculation_inputs = calc.get_incoming().all() # This calculation has two data inputs - self.assertEqual(len(calculation_inputs), 2) + assert len(calculation_inputs) == 2 def test_check_single_calc_source(self): """ @@ -1463,7 +1483,7 @@ def test_check_single_calc_source(self): d1.add_incoming(calc, link_type=LinkType.CREATE, link_label='link') # more than one input to the same data object! 
- with self.assertRaises(ValueError): + with pytest.raises(ValueError): d1.add_incoming(calc2, link_type=LinkType.CREATE, link_label='link') def test_node_get_incoming_outgoing_links(self): @@ -1494,19 +1514,19 @@ def test_node_get_incoming_outgoing_links(self): node_return.add_incoming(node_origin, link_type=LinkType.RETURN, link_label='return2') # All incoming and outgoing - self.assertEqual(len(node_origin.get_incoming().all()), 2) - self.assertEqual(len(node_origin.get_outgoing().all()), 3) + assert len(node_origin.get_incoming().all()) == 2 + assert len(node_origin.get_outgoing().all()) == 3 # Link specific incoming - self.assertEqual(len(node_origin.get_incoming(link_type=LinkType.CALL_WORK).all()), 1) - self.assertEqual(len(node_origin2.get_incoming(link_type=LinkType.CALL_WORK).all()), 1) - self.assertEqual(len(node_origin.get_incoming(link_type=LinkType.INPUT_WORK).all()), 1) - self.assertEqual(len(node_origin.get_incoming(link_label_filter='in_ut%').all()), 1) - self.assertEqual(len(node_origin.get_incoming(node_class=orm.Node).all()), 2) + assert len(node_origin.get_incoming(link_type=LinkType.CALL_WORK).all()) == 1 + assert len(node_origin2.get_incoming(link_type=LinkType.CALL_WORK).all()) == 1 + assert len(node_origin.get_incoming(link_type=LinkType.INPUT_WORK).all()) == 1 + assert len(node_origin.get_incoming(link_label_filter='in_ut%').all()) == 1 + assert len(node_origin.get_incoming(node_class=orm.Node).all()) == 2 # Link specific outgoing - self.assertEqual(len(node_origin.get_outgoing(link_type=LinkType.CALL_WORK).all()), 1) - self.assertEqual(len(node_origin.get_outgoing(link_type=LinkType.RETURN).all()), 2) + assert len(node_origin.get_outgoing(link_type=LinkType.CALL_WORK).all()) == 1 + assert len(node_origin.get_outgoing(link_type=LinkType.RETURN).all()) == 2 class AnyValue: @@ -1518,7 +1538,8 @@ def __eq__(self, other): return True -class TestNodeDeletion(AiidaTestCase): +@pytest.mark.usefixtures('aiida_profile_clean_class') +class 
TestNodeDeletion: def _check_existence(self, uuids_check_existence, uuids_check_deleted): """ @@ -1539,7 +1560,7 @@ def _check_existence(self, uuids_check_existence, uuids_check_deleted): orm.load_node(uuid) for uuid in uuids_check_deleted: # I check that it raises - with self.assertRaises(NotExistent): + with pytest.raises(NotExistent): orm.load_node(uuid) @staticmethod @@ -1553,8 +1574,8 @@ def test_deletion_dry_run_true(self): node = orm.Data().store() node_pk = node.pk deleted_pks, was_deleted = delete_nodes([node_pk], dry_run=True) - self.assertTrue(not was_deleted) - self.assertSetEqual(deleted_pks, {node_pk}) + assert not was_deleted + assert deleted_pks == {node_pk} orm.load_node(node_pk) def test_deletion_dry_run_callback(self): @@ -1569,11 +1590,11 @@ def _callback(pks): return False deleted_pks, was_deleted = delete_nodes([node_pk], dry_run=_callback) - self.assertTrue(was_deleted) - self.assertSetEqual(deleted_pks, {node_pk}) - with self.assertRaises(NotExistent): + assert was_deleted + assert deleted_pks == {node_pk} + with pytest.raises(NotExistent): orm.load_node(node_pk) - self.assertListEqual(callback_pks, [node_pk]) + assert callback_pks == [node_pk] # TEST BASIC CASES @@ -2057,8 +2078,8 @@ def test_delete_group_nodes(self): node_uuids = {node.uuid for node in nodes} group.add_nodes(nodes) deleted_pks, was_deleted = delete_group_nodes([group.pk], dry_run=False) - self.assertTrue(was_deleted) - self.assertSetEqual(deleted_pks, node_pks) + assert was_deleted + assert deleted_pks == node_pks self._check_existence([], node_uuids) def test_delete_group_nodes_dry_run_true(self): @@ -2069,6 +2090,6 @@ def test_delete_group_nodes_dry_run_true(self): node_uuids = {node.uuid for node in nodes} group.add_nodes(nodes) deleted_pks, was_deleted = delete_group_nodes([group.pk], dry_run=True) - self.assertTrue(not was_deleted) - self.assertSetEqual(deleted_pks, node_pks) + assert not was_deleted + assert deleted_pks == node_pks self._check_existence(node_uuids, 
[]) diff --git a/tests/tools/data/orbital/test_orbitals.py b/tests/tools/data/orbital/test_orbitals.py index 25442dccab..f56dc748ab 100644 --- a/tests/tools/data/orbital/test_orbitals.py +++ b/tests/tools/data/orbital/test_orbitals.py @@ -7,16 +7,16 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=no-self-use """Test for the `Orbital` class and subclasses.""" +import pytest from aiida.common.exceptions import ValidationError from aiida.plugins import OrbitalFactory -#from aiida import orm -from aiida.storage.testbase import AiidaTestCase from aiida.tools.data.orbital import Orbital -class TestOrbital(AiidaTestCase): +class TestOrbital: """Test the Orbital base class""" def test_orbital_str(self): @@ -24,42 +24,41 @@ def test_orbital_str(self): orbital = Orbital(position=(1, 2, 3)) expected_output = 'Orbital @ 1.0000,2.0000,3.0000' - self.assertEqual(str(orbital), expected_output) + assert str(orbital) == expected_output def test_required_fields(self): """Verify that required fields are validated.""" # position is required - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): Orbital() # position must be a list of three floats - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): Orbital(position=(1, 2)) - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): Orbital(position=(1, 2, 'a')) orbital = Orbital(position=(1, 2, 3)) - self.assertAlmostEqual(orbital.get_orbital_dict()['position'][0], 1.) - self.assertAlmostEqual(orbital.get_orbital_dict()['position'][1], 2.) - self.assertAlmostEqual(orbital.get_orbital_dict()['position'][2], 3.) 
+ assert round(abs(orbital.get_orbital_dict()['position'][0] - 1.), 7) == 0 + assert round(abs(orbital.get_orbital_dict()['position'][1] - 2.), 7) == 0 + assert round(abs(orbital.get_orbital_dict()['position'][2] - 3.), 7) == 0 def test_unknown_fields(self): """Verify that unkwown fields raise a validation error.""" # position is required # position must be a list of three floats - with self.assertRaises(ValidationError) as exc: + with pytest.raises(ValidationError, match='some_strange_key'): Orbital(position=(1, 2, 3), some_strange_key=1) - self.assertIn('some_strange_key', str(exc.exception)) -class TestRealhydrogenOrbital(AiidaTestCase): +class TestRealhydrogenOrbital: """Test the Orbital base class""" def test_required_fields(self): """Verify that required fields are validated.""" RealhydrogenOrbital = OrbitalFactory('core.realhydrogen') # pylint: disable=invalid-name # Check that the required fields of the base class are not enough - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): RealhydrogenOrbital(position=(1, 2, 3)) orbital = RealhydrogenOrbital( @@ -70,18 +69,18 @@ def test_required_fields(self): 'radial_nodes': 2 } ) - self.assertAlmostEqual(orbital.get_orbital_dict()['position'][0], -1.) - self.assertAlmostEqual(orbital.get_orbital_dict()['position'][1], -2.) - self.assertAlmostEqual(orbital.get_orbital_dict()['position'][2], -3.) 
- self.assertAlmostEqual(orbital.get_orbital_dict()['angular_momentum'], 1) - self.assertAlmostEqual(orbital.get_orbital_dict()['magnetic_number'], 0) - self.assertAlmostEqual(orbital.get_orbital_dict()['radial_nodes'], 2) + assert round(abs(orbital.get_orbital_dict()['position'][0] - -1.), 7) == 0 + assert round(abs(orbital.get_orbital_dict()['position'][1] - -2.), 7) == 0 + assert round(abs(orbital.get_orbital_dict()['position'][2] - -3.), 7) == 0 + assert round(abs(orbital.get_orbital_dict()['angular_momentum'] - 1), 7) == 0 + assert round(abs(orbital.get_orbital_dict()['magnetic_number'] - 0), 7) == 0 + assert round(abs(orbital.get_orbital_dict()['radial_nodes'] - 2), 7) == 0 def test_validation_for_fields(self): """Verify that the values are properly validated""" RealhydrogenOrbital = OrbitalFactory('core.realhydrogen') # pylint: disable=invalid-name - with self.assertRaises(ValidationError) as exc: + with pytest.raises(ValidationError, match='angular_momentum'): RealhydrogenOrbital( **{ 'position': (-1, -2, -3), @@ -90,9 +89,8 @@ def test_validation_for_fields(self): 'radial_nodes': 2 } ) - self.assertIn('angular_momentum', str(exc.exception)) - with self.assertRaises(ValidationError) as exc: + with pytest.raises(ValidationError, match='magnetic number must be in the range'): RealhydrogenOrbital( **{ 'position': (-1, -2, -3), @@ -101,9 +99,8 @@ def test_validation_for_fields(self): 'radial_nodes': 2 } ) - self.assertIn('magnetic number must be in the range', str(exc.exception)) - with self.assertRaises(ValidationError) as exc: + with pytest.raises(ValidationError, match='radial_nodes'): RealhydrogenOrbital( **{ 'position': (-1, -2, -3), @@ -112,7 +109,6 @@ def test_validation_for_fields(self): 'radial_nodes': 100 } ) - self.assertIn('radial_nodes', str(exc.exception)) def test_optional_fields(self): """ @@ -130,8 +126,8 @@ def test_optional_fields(self): } ) # Check that the optional value is there and has its default value - 
self.assertEqual(orbital.get_orbital_dict()['spin'], 0) - self.assertEqual(orbital.get_orbital_dict()['diffusivity'], None) + assert orbital.get_orbital_dict()['spin'] == 0 + assert orbital.get_orbital_dict()['diffusivity'] is None orbital = RealhydrogenOrbital( **{ @@ -143,10 +139,10 @@ def test_optional_fields(self): 'diffusivity': 3.1 } ) - self.assertEqual(orbital.get_orbital_dict()['spin'], 1) - self.assertEqual(orbital.get_orbital_dict()['diffusivity'], 3.1) + assert orbital.get_orbital_dict()['spin'] == 1 + assert orbital.get_orbital_dict()['diffusivity'] == 3.1 - with self.assertRaises(ValidationError) as exc: + with pytest.raises(ValidationError, match='diffusivity'): RealhydrogenOrbital( **{ 'position': (-1, -2, -3), @@ -157,9 +153,8 @@ def test_optional_fields(self): 'diffusivity': 'a' } ) - self.assertIn('diffusivity', str(exc.exception)) - with self.assertRaises(ValidationError) as exc: + with pytest.raises(ValidationError, match='spin'): RealhydrogenOrbital( **{ 'position': (-1, -2, -3), @@ -170,13 +165,12 @@ def test_optional_fields(self): 'diffusivity': 3.1 } ) - self.assertIn('spin', str(exc.exception)) def test_unknown_fields(self): """Verify that unkwown fields raise a validation error.""" RealhydrogenOrbital = OrbitalFactory('core.realhydrogen') # pylint: disable=invalid-name - with self.assertRaises(ValidationError) as exc: + with pytest.raises(ValidationError, match='some_strange_key'): RealhydrogenOrbital( **{ 'position': (-1, -2, -3), @@ -186,7 +180,6 @@ def test_unknown_fields(self): 'some_strange_key': 1 } ) - self.assertIn('some_strange_key', str(exc.exception)) def test_get_name_from_quantum_numbers(self): """ @@ -195,16 +188,16 @@ def test_get_name_from_quantum_numbers(self): RealhydrogenOrbital = OrbitalFactory('core.realhydrogen') # pylint: disable=invalid-name name = RealhydrogenOrbital.get_name_from_quantum_numbers(angular_momentum=1) - self.assertEqual(name, 'P') + assert name == 'P' name = 
RealhydrogenOrbital.get_name_from_quantum_numbers(angular_momentum=0) - self.assertEqual(name, 'S') + assert name == 'S' name = RealhydrogenOrbital.get_name_from_quantum_numbers(angular_momentum=0, magnetic_number=0) - self.assertEqual(name, 'S') + assert name == 'S' name = RealhydrogenOrbital.get_name_from_quantum_numbers(angular_momentum=1, magnetic_number=1) - self.assertEqual(name, 'PX') + assert name == 'PX' name = RealhydrogenOrbital.get_name_from_quantum_numbers(angular_momentum=2, magnetic_number=4) - self.assertEqual(name, 'DXY') + assert name == 'DXY' diff --git a/tests/tools/dbimporters/test_icsd.py b/tests/tools/dbimporters/test_icsd.py index 55374659f8..50f14d3cb0 100644 --- a/tests/tools/dbimporters/test_icsd.py +++ b/tests/tools/dbimporters/test_icsd.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=no-self-use """ Tests for IcsdDbImporter """ @@ -14,7 +15,7 @@ import pytest -from aiida.storage.testbase import AiidaTestCase +from aiida import get_profile from aiida.tools.dbimporters.plugins import icsd @@ -37,7 +38,6 @@ def has_icsd_config(): """ :return: True if the currently loaded profile has a ICSD configuration """ - from aiida.manage.configuration import get_profile profile = get_profile() required_keywords = { @@ -47,18 +47,17 @@ def has_icsd_config(): return required_keywords <= set(profile.dictionary.keys()) -class TestIcsd(AiidaTestCase): +class TestIcsd: """ Tests for the ICSD importer """ - def setUp(self): - """ - Set up IcsdDbImporter for web and mysql db query. 
- """ + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean): # pylint: disable=unused-argument + """Initialize the profile and set up IcsdDbImporter for web and mysql db query.""" + # pylint: disable=attribute-defined-outside-init if not (has_mysqldb() and has_icsd_config()): pytest.skip('ICSD configuration in profile required') - from aiida.manage.configuration import get_profile profile = get_profile() self.server = profile.dictionary['ICSD_SERVER_URL'] @@ -94,7 +93,7 @@ def test_web_zero_results(self): """ No results should be obtained from year 3000. """ - with self.assertRaises(icsd.NoResultsWebExp): + with pytest.raises(icsd.NoResultsWebExp): self.importerweb.query(year='3000') def test_web_collcode_155006(self): @@ -103,13 +102,13 @@ def test_web_collcode_155006(self): """ queryresults = self.importerweb.query(id='155006') - self.assertEqual(queryresults.number_of_results, 1) + assert queryresults.number_of_results == 1 - with self.assertRaises(StopIteration): + with pytest.raises(StopIteration): next(queryresults) next(queryresults) - with self.assertRaises(IndexError): + with pytest.raises(IndexError): queryresults.at(10) def test_dbquery_zero_results(self): @@ -119,10 +118,10 @@ def test_dbquery_zero_results(self): importer = icsd.IcsdDbImporter(server=self.server, host=self.host) noresults = importer.query(year='3000') # which should work at least for the next 85 years.. 
- self.assertEqual(noresults.number_of_results, 0) + assert noresults.number_of_results == 0 - with self.assertRaises(StopIteration): + with pytest.raises(StopIteration): next(noresults) - with self.assertRaises(IndexError): + with pytest.raises(IndexError): noresults.at(0) diff --git a/tests/tools/dbimporters/test_materialsproject.py b/tests/tools/dbimporters/test_materialsproject.py index 2a670e23c1..80f1a97943 100644 --- a/tests/tools/dbimporters/test_materialsproject.py +++ b/tests/tools/dbimporters/test_materialsproject.py @@ -8,11 +8,9 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module that contains the class definitions necessary to offer support for queries to Materials Project.""" - import pytest from aiida.plugins import DbImporterFactory -from aiida.storage.testbase import AiidaTestCase def run_materialsproject_api_tests(): @@ -21,7 +19,8 @@ def run_materialsproject_api_tests(): return profile.dictionary.get('run_materialsproject_api_tests', False) -class TestMaterialsProject(AiidaTestCase): +@pytest.mark.usefixtures('aiida_profile_clean') +class TestMaterialsProject: """ Contains the tests to verify the functionality of the Materials Project importer functions. 
diff --git a/tests/tools/graph/test_age.py b/tests/tools/graph/test_age.py index b3e2850e75..2f8dbf4c96 100644 --- a/tests/tools/graph/test_age.py +++ b/tests/tools/graph/test_age.py @@ -7,14 +7,13 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=too-many-locals, too-many-statements +# pylint: disable=too-many-locals,too-many-statements,no-self-use """AGE tests""" - import numpy as np +import pytest from aiida import orm from aiida.common.links import LinkType -from aiida.storage.testbase import AiidaTestCase from aiida.tools.graph.age_entities import AiidaEntitySet, Basket, DirectedEdgeSet, GroupNodeEdge from aiida.tools.graph.age_rules import ReplaceRule, RuleSaveWalkers, RuleSequence, RuleSetWalkers, UpdateRule @@ -80,13 +79,10 @@ def create_tree(max_depth=3, branching=3, starting_cls=orm.Data): return result -class TestAiidaGraphExplorer(AiidaTestCase): +@pytest.mark.usefixtures('aiida_profile_clean') +class TestAiidaGraphExplorer: """Tests for the AGE""" - def setUp(self): - super().setUp() - self.refurbish_db() - @staticmethod def _create_basic_graph(): """ @@ -173,7 +169,7 @@ def test_basic_graph(self): obtained = uprule.run(basket_w1.copy())['nodes'].keyset expected = set((nodes['work_1'].id, nodes['calc_0'].id, nodes['data_o'].id)) - self.assertEqual(obtained, expected) + assert obtained == expected # Find all the descendants of work_1 through call_calc (calc_0) edge_cacalc = {'type': {'in': [LinkType.CALL_CALC.value]}} @@ -184,7 +180,7 @@ def test_basic_graph(self): obtained = uprule.run(basket_w1.copy())['nodes'].keyset expected = set((nodes['work_1'].id, nodes['calc_0'].id)) - self.assertEqual(obtained, expected) + assert obtained == expected # Find all the descendants of work_1 that are data nodes (data_o) queryb = orm.QueryBuilder() @@ -194,7 +190,7 @@ def 
test_basic_graph(self): obtained = uprule.run(basket_w1.copy())['nodes'].keyset expected = set((nodes['work_1'].id, nodes['data_o'].id)) - self.assertEqual(obtained, expected) + assert obtained == expected # Find all the ascendants of work_1 queryb = orm.QueryBuilder() @@ -204,7 +200,7 @@ def test_basic_graph(self): obtained = uprule.run(basket_w1.copy())['nodes'].keyset expected = set((nodes['work_1'].id, nodes['work_2'].id, nodes['data_i'].id)) - self.assertEqual(obtained, expected) + assert obtained == expected # Find all the ascendants of work_1 through input_work (data_i) edge_inpwork = {'type': {'in': [LinkType.INPUT_WORK.value]}} @@ -215,7 +211,7 @@ def test_basic_graph(self): obtained = uprule.run(basket_w1.copy())['nodes'].keyset expected = set((nodes['work_1'].id, nodes['data_i'].id)) - self.assertEqual(obtained, expected) + assert obtained == expected # Find all the ascendants of work_1 that are workflow nodes (work_2) queryb = orm.QueryBuilder() @@ -225,7 +221,7 @@ def test_basic_graph(self): obtained = uprule.run(basket_w1.copy())['nodes'].keyset expected = set((nodes['work_1'].id, nodes['work_2'].id)) - self.assertEqual(obtained, expected) + assert obtained == expected # Only get the descendants that are direct (1st level) (work_1, data_o) queryb = orm.QueryBuilder() @@ -235,7 +231,7 @@ def test_basic_graph(self): obtained = rerule.run(basket_w2.copy())['nodes'].keyset expected = set((nodes['work_1'].id, nodes['data_o'].id)) - self.assertEqual(obtained, expected) + assert obtained == expected # Only get the descendants of the descendants (2nd level) (calc_0, data_o) queryb = orm.QueryBuilder() @@ -245,7 +241,7 @@ def test_basic_graph(self): obtained = rerule.run(basket_w2.copy())['nodes'].keyset expected = set((nodes['calc_0'].id, nodes['data_o'].id)) - self.assertEqual(obtained, expected) + assert obtained == expected def test_cycle(self): """ @@ -267,23 +263,23 @@ def test_cycle(self): uprule = UpdateRule(queryb, max_iterations=np.inf) obtained = 
uprule.run(basket.copy())['nodes'].keyset expected = set([data_node.id, work_node.id]) - self.assertEqual(obtained, expected) + assert obtained == expected rerule1 = ReplaceRule(queryb, max_iterations=1) result1 = rerule1.run(basket.copy())['nodes'].keyset - self.assertEqual(result1, set([work_node.id])) + assert result1 == set([work_node.id]) rerule2 = ReplaceRule(queryb, max_iterations=2) result2 = rerule2.run(basket.copy())['nodes'].keyset - self.assertEqual(result2, set([data_node.id])) + assert result2 == set([data_node.id]) rerule3 = ReplaceRule(queryb, max_iterations=3) result3 = rerule3.run(basket.copy())['nodes'].keyset - self.assertEqual(result3, set([work_node.id])) + assert result3 == set([work_node.id]) rerule4 = ReplaceRule(queryb, max_iterations=4) result4 = rerule4.run(basket.copy())['nodes'].keyset - self.assertEqual(result4, set([data_node.id])) + assert result4 == set([data_node.id]) @staticmethod def _create_branchy_graph(): @@ -365,13 +361,13 @@ def test_stash(self): rule_seq = RuleSequence((uprule_out, uprule_inp)) obtained = rule_seq.run(basket.copy())['nodes'].keyset expected = expect_base.union(set([nodes['data_i'].id])) - self.assertEqual(obtained, expected) + assert obtained == expected # First get inputs, then outputs. rule_seq = RuleSequence((uprule_inp, uprule_out)) obtained = rule_seq.run(basket.copy())['nodes'].keyset expected = expect_base.union(set([nodes['data_o'].id])) - self.assertEqual(obtained, expected) + assert obtained == expected # Now using the stash option in either order. 
stash = basket.get_template() @@ -383,22 +379,22 @@ def test_stash(self): # set, whereas the stash contains the same data as the starting point) obtained = rule_save.run(basket.copy()) expected = basket.copy() - self.assertEqual(obtained, expected) - self.assertEqual(stash, basket) + assert obtained == expected + assert stash == basket stash = basket.get_template() rule_save = RuleSaveWalkers(stash) rule_load = RuleSetWalkers(stash) serule_io = RuleSequence((rule_save, uprule_inp, rule_load, uprule_out)) result_io = serule_io.run(basket.copy())['nodes'].keyset - self.assertEqual(result_io, expect_base) + assert result_io == expect_base stash = basket.get_template() rule_save = RuleSaveWalkers(stash) rule_load = RuleSetWalkers(stash) serule_oi = RuleSequence((rule_save, uprule_out, rule_load, uprule_inp)) result_oi = serule_oi.run(basket.copy())['nodes'].keyset - self.assertEqual(result_oi, expect_base) + assert result_oi == expect_base def test_returns_calls(self): """Tests return calls (?)""" @@ -465,7 +461,7 @@ def test_returns_calls(self): ruleseq = RuleSequence(rules, max_iterations=np.inf) resulting_set = ruleseq.run(basket.copy()) expecting_set = resulting_set - self.assertEqual(expecting_set, resulting_set) + assert expecting_set == resulting_set def test_groups(self): """ @@ -502,11 +498,11 @@ def test_groups(self): obtained = basket_out['nodes'].keyset expected = set([node2.id, node3.id]) - self.assertEqual(obtained, expected) + assert obtained == expected obtained = basket_out['groups'].keyset expected = set() - self.assertEqual(obtained, expected) + assert obtained == expected # But two rules chained should get both nodes and groups... 
queryb = orm.QueryBuilder() @@ -528,11 +524,11 @@ def test_groups(self): obtained = basket_out['nodes'].keyset expected = set([node2.id, node3.id]) - self.assertEqual(obtained, expected) + assert obtained == expected obtained = basket_out['groups'].keyset expected = set([group2.id]) - self.assertEqual(obtained, expected) + assert obtained == expected # ...and starting with a group initial_group = [group3.id] @@ -541,11 +537,11 @@ def test_groups(self): obtained = basket_out['nodes'].keyset expected = set([node4.id]) - self.assertEqual(obtained, expected) + assert obtained == expected obtained = basket_out['groups'].keyset expected = set([group3.id, group4.id]) - self.assertEqual(obtained, expected) + assert obtained == expected # Testing a "group chain" total_groups = 10 @@ -583,15 +579,15 @@ def test_groups(self): obtained = basket_out['nodes'].keyset expected = set(n.id for n in nodes) - self.assertEqual(obtained, expected) + assert obtained == expected obtained = basket_out['groups'].keyset expected = set(g.id for g in groups) - self.assertEqual(obtained, expected) + assert obtained == expected # testing the edges between groups and nodes: result = basket_out['groups_nodes'].keyset - self.assertEqual(result, edges) + assert result == edges def test_edges(self): """ @@ -609,7 +605,7 @@ def test_edges(self): obtained = uprule_result['nodes'].keyset expected = set(anode.id for _, anode in nodes.items()) - self.assertEqual(obtained, expected) + assert obtained == expected obtained = set() for data in uprule_result['nodes_nodes'].keyset: @@ -625,7 +621,7 @@ def test_edges(self): (nodes['work_2'].id, nodes['work_1'].id), (nodes['work_1'].id, nodes['calc_0'].id), } - self.assertEqual(obtained, expected) + assert obtained == expected # Backwards traversal (check partial traversal and link direction) basket = Basket(nodes=[nodes['data_o'].id]) @@ -638,7 +634,7 @@ def test_edges(self): obtained = uprule_result['nodes'].keyset expected = set(anode.id for _, anode in 
nodes.items()) expected = expected.difference(set([nodes['data_i'].id])) - self.assertEqual(obtained, expected) + assert obtained == expected obtained = set() for data in uprule_result['nodes_nodes'].keyset: @@ -649,7 +645,7 @@ def test_edges(self): (nodes['work_1'].id, nodes['data_o'].id), (nodes['work_2'].id, nodes['data_o'].id), } - self.assertEqual(obtained, expected) + assert obtained == expected def test_empty_input(self): """ @@ -660,16 +656,13 @@ def test_empty_input(self): queryb.append(orm.Node).append(orm.Node) uprule = UpdateRule(queryb, max_iterations=np.inf) result = uprule.run(basket.copy())['nodes'].keyset - self.assertEqual(result, set()) + assert result == set() -class TestAiidaEntitySet(AiidaTestCase): +@pytest.mark.usefixtures('aiida_profile_clean') +class TestAiidaEntitySet: """Tests for AiidaEntitySets""" - def setUp(self): - super().setUp() - self.refurbish_db() - def test_class_mismatch(self): """ Test the case where an AiidaEntitySet is trying to be used in an operation @@ -681,16 +674,16 @@ def test_class_mismatch(self): des_node_node = DirectedEdgeSet(orm.Node, orm.Node) python_set = {1, 2, 3, 4} - with self.assertRaises(TypeError): + with pytest.raises(TypeError): _ = aes_node + aes_group - with self.assertRaises(TypeError): + with pytest.raises(TypeError): _ = aes_node + des_node_node - with self.assertRaises(TypeError): + with pytest.raises(TypeError): _ = aes_group + des_node_node - with self.assertRaises(TypeError): + with pytest.raises(TypeError): _ = aes_node + python_set def test_algebra(self): @@ -718,22 +711,22 @@ def test_algebra(self): aes2 = aes0 + aes1 union01 = aes0.keyset | aes1.keyset - self.assertEqual(aes2.keyset, union01) + assert aes2.keyset == union01 aes0_copy = aes0.copy() aes0_copy += aes1 - self.assertEqual(aes0_copy.keyset, union01) + assert aes0_copy.keyset == union01 aes3 = aes0_copy - aes1 - self.assertEqual(aes0.keyset, aes3.keyset) - self.assertEqual(aes0, aes3) + assert aes0.keyset == aes3.keyset + assert 
aes0 == aes3 aes0_copy -= aes1 - self.assertEqual(aes0.keyset, aes3.keyset, aes0_copy.keyset) - self.assertEqual(aes0, aes3, aes0_copy) + assert aes0.keyset == aes3.keyset, aes0_copy.keyset + assert aes0 == aes3, aes0_copy aes4 = aes0 - aes0 - self.assertEqual(aes4.keyset, set()) + assert aes4.keyset == set() aes0_copy -= aes0 - self.assertEqual(aes0_copy.keyset, set()) + assert aes0_copy.keyset == set() diff --git a/tests/tools/graph/test_graph_traversers.py b/tests/tools/graph/test_graph_traversers.py index 8db6af587e..58e06e2d1e 100644 --- a/tests/tools/graph/test_graph_traversers.py +++ b/tests/tools/graph/test_graph_traversers.py @@ -7,10 +7,11 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=no-self-use """Tests for aiida.tools.graph.graph_traversers""" +import pytest from aiida.common.links import LinkType -from aiida.storage.testbase import AiidaTestCase from aiida.tools.graph.graph_traversers import get_nodes_delete, traverse_graph @@ -82,7 +83,8 @@ def create_minimal_graph(): return output_dict -class TestTraverseGraph(AiidaTestCase): +@pytest.mark.usefixtures('aiida_profile_clean') +class TestTraverseGraph: """Test class for traverse_graph""" def _single_test(self, starting_nodes=(), expanded_nodes=(), links_forward=(), links_backward=()): @@ -93,7 +95,7 @@ def _single_test(self, starting_nodes=(), expanded_nodes=(), links_forward=(), l links_backward=links_backward, )['nodes'] expected_nodes = set(starting_nodes + expanded_nodes) - self.assertEqual(obtained_nodes, expected_nodes) + assert obtained_nodes == expected_nodes def test_traversal_individually(self): """ @@ -302,7 +304,7 @@ def test_traversal_cycle(self): for single_node in every_node: expected_nodes = set([single_node]) obtained_nodes = traverse_graph([single_node])['nodes'] - self.assertEqual(obtained_nodes, 
expected_nodes) + assert obtained_nodes == expected_nodes links_forward = [LinkType.INPUT_WORK, LinkType.RETURN] links_backward = [] @@ -311,18 +313,18 @@ def test_traversal_cycle(self): obtained_nodes = traverse_graph([data_drop], links_forward=links_forward, links_backward=links_backward)['nodes'] expected_nodes = set(every_node) - self.assertEqual(obtained_nodes, expected_nodes) + assert obtained_nodes == expected_nodes # Forward: data_take to (input) work_select (data_drop is not returned) obtained_nodes = traverse_graph([data_take], links_forward=links_forward, links_backward=links_backward)['nodes'] expected_nodes = set([work_select, data_take]) - self.assertEqual(obtained_nodes, expected_nodes) + assert obtained_nodes == expected_nodes # Forward: work_select to (return) data_take (data_drop is not returned) obtained_nodes = traverse_graph([work_select], links_forward=links_forward, links_backward=links_backward)['nodes'] - self.assertEqual(obtained_nodes, expected_nodes) + assert obtained_nodes == expected_nodes links_forward = [] links_backward = [LinkType.INPUT_WORK, LinkType.RETURN] @@ -331,18 +333,18 @@ def test_traversal_cycle(self): expected_nodes = set([data_drop]) obtained_nodes = traverse_graph([data_drop], links_forward=links_forward, links_backward=links_backward)['nodes'] - self.assertEqual(obtained_nodes, expected_nodes) + assert obtained_nodes == expected_nodes # Backward: data_take to (return) work_select to (input) data_drop expected_nodes = set(every_node) obtained_nodes = traverse_graph([data_take], links_forward=links_forward, links_backward=links_backward)['nodes'] - self.assertEqual(obtained_nodes, expected_nodes) + assert obtained_nodes == expected_nodes # Backward: work_select to (input) data_take and data_drop obtained_nodes = traverse_graph([work_select], links_forward=links_forward, links_backward=links_backward)['nodes'] - self.assertEqual(obtained_nodes, expected_nodes) + assert obtained_nodes == expected_nodes def 
test_traversal_errors(self): """This will test the errors of the traversers.""" @@ -352,19 +354,19 @@ def test_traversal_errors(self): test_node = orm.Data().store() false_node = -1 - with self.assertRaises(NotExistent): + with pytest.raises(NotExistent): _ = traverse_graph([false_node]) - with self.assertRaises(TypeError): + with pytest.raises(TypeError): _ = traverse_graph(['not a node']) - with self.assertRaises(TypeError): + with pytest.raises(TypeError): _ = traverse_graph('not a list') - with self.assertRaises(TypeError): + with pytest.raises(TypeError): _ = traverse_graph([test_node], links_forward=1984) - with self.assertRaises(TypeError): + with pytest.raises(TypeError): _ = traverse_graph([test_node], links_backward=['not a link']) def test_empty_input(self): @@ -376,12 +378,12 @@ def test_empty_input(self): ] obtained_results = traverse_graph([], links_forward=all_links, links_backward=all_links) - self.assertEqual(obtained_results['nodes'], set()) - self.assertEqual(obtained_results['links'], None) + assert obtained_results['nodes'] == set() + assert obtained_results['links'] is None obtained_results = traverse_graph([], get_links=True, links_forward=all_links, links_backward=all_links) - self.assertEqual(obtained_results['nodes'], set()) - self.assertEqual(obtained_results['links'], set()) + assert obtained_results['nodes'] == set() + assert obtained_results['links'] == set() def test_delete_aux(self): """Tests for the get_nodes_delete function""" @@ -391,23 +393,23 @@ def test_delete_aux(self): obtained_nodes = get_nodes_delete([nodes_dict['data_i'].pk])['nodes'] expected_nodes = set(nodes_pklist) - self.assertEqual(obtained_nodes, expected_nodes) + assert obtained_nodes == expected_nodes obtained_nodes = get_nodes_delete([nodes_dict['data_o'].pk])['nodes'] expected_nodes = set(nodes_pklist).difference(set([nodes_dict['data_i'].pk])) - self.assertEqual(obtained_nodes, expected_nodes) + assert obtained_nodes == expected_nodes obtained_nodes = 
get_nodes_delete([nodes_dict['work_1'].pk], call_calc_forward=False)['nodes'] expected_nodes = set([nodes_dict['work_1'].pk, nodes_dict['work_2'].pk]) - self.assertEqual(obtained_nodes, expected_nodes) + assert obtained_nodes == expected_nodes obtained_nodes = get_nodes_delete([nodes_dict['work_2'].pk], call_work_forward=False)['nodes'] expected_nodes = set([nodes_dict['work_2'].pk]) - self.assertEqual(obtained_nodes, expected_nodes) + assert obtained_nodes == expected_nodes obtained_nodes = get_nodes_delete([nodes_dict['calc_0'].pk], create_forward=False)['nodes'] expected_nodes = set([nodes_dict['calc_0'].pk, nodes_dict['work_1'].pk, nodes_dict['work_2'].pk]) - self.assertEqual(obtained_nodes, expected_nodes) + assert obtained_nodes == expected_nodes - with self.assertRaises(ValueError): + with pytest.raises(ValueError): _ = get_nodes_delete([nodes_dict['data_o'].pk], create_backward=False) diff --git a/tests/tools/visualization/test_graph.py b/tests/tools/visualization/test_graph.py index 434b6a747b..aedfe99fc7 100644 --- a/tests/tools/visualization/test_graph.py +++ b/tests/tools/visualization/test_graph.py @@ -8,22 +8,24 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Tests for creating graphs (using graphviz)""" +import pytest from aiida import orm from aiida.common import AttributeDict from aiida.common.links import LinkType from aiida.engine import ProcessState from aiida.orm.utils.links import LinkPair -from aiida.storage.testbase import AiidaTestCase from aiida.tools.visualization import graph as graph_mod -class TestVisGraph(AiidaTestCase): +class TestVisGraph: """Tests for verdi graph""" - def setUp(self): - super().setUp() - self.refurbish_db() + @pytest.fixture(autouse=True) + def init_profile(self, aiida_profile_clean, aiida_localhost): # pylint: disable=unused-argument + """Initialize the profile.""" + # pylint: disable=attribute-defined-outside-init + 
self.computer = aiida_localhost def create_provenance(self): """create an example provenance graph @@ -110,16 +112,16 @@ def test_graph_add_node(self): graph = graph_mod.Graph() graph.add_node(nodes.pd0) - self.assertEqual(graph.nodes, set([nodes.pd0.pk])) - self.assertEqual(graph.edges, set()) + assert graph.nodes == set([nodes.pd0.pk]) + assert graph.edges == set() # try adding a second time graph.add_node(nodes.pd0) - self.assertEqual(graph.nodes, set([nodes.pd0.pk])) + assert graph.nodes == set([nodes.pd0.pk]) # add second node graph.add_node(nodes.pd1) - self.assertEqual(graph.nodes, set([nodes.pd0.pk, nodes.pd1.pk])) + assert graph.nodes == set([nodes.pd0.pk, nodes.pd1.pk]) def test_graph_add_edge(self): """ test adding an edge to the graph """ @@ -129,8 +131,8 @@ def test_graph_add_edge(self): graph.add_node(nodes.pd0) graph.add_node(nodes.rd1) graph.add_edge(nodes.pd0, nodes.rd1) - self.assertEqual(graph.nodes, set([nodes.pd0.pk, nodes.rd1.pk])) - self.assertEqual(graph.edges, set([(nodes.pd0.pk, nodes.rd1.pk, None)])) + assert graph.nodes == set([nodes.pd0.pk, nodes.rd1.pk]) + assert graph.edges == set([(nodes.pd0.pk, nodes.rd1.pk, None)]) def test_graph_add_incoming(self): """ test adding a node and all its incoming nodes to a graph""" @@ -139,13 +141,11 @@ def test_graph_add_incoming(self): graph = graph_mod.Graph() graph.add_incoming(nodes.calc1) - self.assertEqual(graph.nodes, set([nodes.calc1.pk, nodes.pd0.pk, nodes.pd1.pk, nodes.wc1.pk])) - self.assertEqual( - graph.edges, + assert graph.nodes == set([nodes.calc1.pk, nodes.pd0.pk, nodes.pd1.pk, nodes.wc1.pk]) + assert graph.edges == \ set([(nodes.pd0.pk, nodes.calc1.pk, LinkPair(LinkType.INPUT_CALC, 'input1')), (nodes.pd1.pk, nodes.calc1.pk, LinkPair(LinkType.INPUT_CALC, 'input2')), (nodes.wc1.pk, nodes.calc1.pk, LinkPair(LinkType.CALL_CALC, 'call1'))]) - ) def test_graph_add_outgoing(self): """ test adding a node and all its outgoing nodes to a graph""" @@ -154,12 +154,10 @@ def 
test_graph_add_outgoing(self): graph = graph_mod.Graph() graph.add_outgoing(nodes.calcf1) - self.assertEqual(graph.nodes, set([nodes.calcf1.pk, nodes.pd3.pk, nodes.fd1.pk])) - self.assertEqual( - graph.edges, + assert graph.nodes == set([nodes.calcf1.pk, nodes.pd3.pk, nodes.fd1.pk]) + assert graph.edges == \ set([(nodes.calcf1.pk, nodes.pd3.pk, LinkPair(LinkType.CREATE, 'output1')), (nodes.calcf1.pk, nodes.fd1.pk, LinkPair(LinkType.CREATE, 'output2'))]) - ) def test_graph_recurse_ancestors(self): """ test adding nodes and all its (recursed) incoming nodes to a graph""" @@ -168,16 +166,14 @@ def test_graph_recurse_ancestors(self): graph = graph_mod.Graph() graph.recurse_ancestors(nodes.rd1) - self.assertEqual(graph.nodes, set([nodes.rd1.pk, nodes.calc1.pk, nodes.pd0.pk, nodes.pd1.pk, nodes.wc1.pk])) - self.assertEqual( - graph.edges, + assert graph.nodes == set([nodes.rd1.pk, nodes.calc1.pk, nodes.pd0.pk, nodes.pd1.pk, nodes.wc1.pk]) + assert graph.edges == \ set([(nodes.calc1.pk, nodes.rd1.pk, LinkPair(LinkType.CREATE, 'output')), (nodes.pd0.pk, nodes.calc1.pk, LinkPair(LinkType.INPUT_CALC, 'input1')), (nodes.pd1.pk, nodes.calc1.pk, LinkPair(LinkType.INPUT_CALC, 'input2')), (nodes.wc1.pk, nodes.calc1.pk, LinkPair(LinkType.CALL_CALC, 'call1')), (nodes.pd0.pk, nodes.wc1.pk, LinkPair(LinkType.INPUT_WORK, 'input1')), (nodes.pd1.pk, nodes.wc1.pk, LinkPair(LinkType.INPUT_WORK, 'input2'))]) - ) def test_graph_recurse_spot_highlight_classes(self): """ test adding nodes and all its (recursed) incoming nodes to a graph""" @@ -200,14 +196,13 @@ def test_graph_recurse_spot_highlight_classes(self): expected_diff = """\ +State: running" color=lightgray fillcolor=white penwidth=2 shape=rectangle style=filled] - +@localhost" color=lightgray fillcolor=white penwidth=2 shape=ellipse style=filled] + +@localhost-test" color=lightgray fillcolor=white penwidth=2 shape=ellipse style=filled] +Exit Code: 200" color=lightgray fillcolor=white penwidth=2 shape=rectangle style=filled] 
+\tN{fd1} [label="FolderData ({fd1})" color=lightgray fillcolor=white penwidth=2 shape=ellipse style=filled] +++""".format(**{k: v.pk for k, v in nodes.items()}) - self.assertEqual( - sorted([l.strip() for l in got_diff.splitlines()]), sorted([l.strip() for l in expected_diff.splitlines()]) - ) + assert sorted([l.strip() for l in got_diff.splitlines()] + ) == sorted([l.strip() for l in expected_diff.splitlines()]) def test_graph_recurse_ancestors_filter_links(self): """ test adding nodes and all its (recursed) incoming nodes to a graph, but filter link types""" @@ -216,13 +211,11 @@ def test_graph_recurse_ancestors_filter_links(self): graph = graph_mod.Graph() graph.recurse_ancestors(nodes.rd1, link_types=['create', 'input_calc']) - self.assertEqual(graph.nodes, set([nodes.rd1.pk, nodes.calc1.pk, nodes.pd0.pk, nodes.pd1.pk])) - self.assertEqual( - graph.edges, + assert graph.nodes == set([nodes.rd1.pk, nodes.calc1.pk, nodes.pd0.pk, nodes.pd1.pk]) + assert graph.edges == \ set([(nodes.calc1.pk, nodes.rd1.pk, LinkPair(LinkType.CREATE, 'output')), (nodes.pd0.pk, nodes.calc1.pk, LinkPair(LinkType.INPUT_CALC, 'input1')), (nodes.pd1.pk, nodes.calc1.pk, LinkPair(LinkType.INPUT_CALC, 'input2'))]) - ) def test_graph_recurse_descendants(self): """ test adding nodes and all its (recursed) incoming nodes to a graph""" @@ -231,15 +224,13 @@ def test_graph_recurse_descendants(self): graph = graph_mod.Graph() graph.recurse_descendants(nodes.rd1) - self.assertEqual(graph.nodes, set([nodes.rd1.pk, nodes.calcf1.pk, nodes.pd3.pk, nodes.fd1.pk])) - self.assertEqual( - graph.edges, + assert graph.nodes == set([nodes.rd1.pk, nodes.calcf1.pk, nodes.pd3.pk, nodes.fd1.pk]) + assert graph.edges == \ set([ (nodes.rd1.pk, nodes.calcf1.pk, LinkPair(LinkType.INPUT_CALC, 'input1')), (nodes.calcf1.pk, nodes.pd3.pk, LinkPair(LinkType.CREATE, 'output1')), (nodes.calcf1.pk, nodes.fd1.pk, LinkPair(LinkType.CREATE, 'output2')), ]) - ) def test_graph_graphviz_source(self): """ test the output of 
graphviz source """ @@ -259,7 +250,7 @@ def test_graph_graphviz_source(self): State: running" fillcolor="#e38851ff" penwidth=0 shape=rectangle style=filled] N{pd0} -> N{wc1} [color="#000000" style=dashed] N{rd1} [label="RemoteData ({rd1}) - @localhost" fillcolor="#8cd499ff" penwidth=0 shape=ellipse style=filled] + @localhost-test" fillcolor="#8cd499ff" penwidth=0 shape=ellipse style=filled] N{calc1} -> N{rd1} [color="#000000" style=solid] N{fd1} [label="FolderData ({fd1})" fillcolor="#8cd499ff" penwidth=0 shape=ellipse style=filled] N{wc1} -> N{fd1} [color="#000000" style=dashed] @@ -276,10 +267,8 @@ def test_graph_graphviz_source(self): }}""".format(**{k: v.pk for k, v in nodes.items()}) # dedent before comparison - self.assertEqual( - sorted([l.strip() for l in graph.graphviz.source.splitlines()]), + assert sorted([l.strip() for l in graph.graphviz.source.splitlines()]) == \ sorted([l.strip() for l in expected.splitlines()]) - ) def test_graph_graphviz_source_pstate(self): """ test the output of graphviz source, with the `pstate_node_styles` function """ @@ -303,7 +292,7 @@ def test_graph_graphviz_source_pstate(self): State: running" fillcolor="#e38851ff" penwidth=0 shape=polygon sides=6 style=filled] N{pd0} -> N{wc1} [color="#000000" style=dashed] N{rd1} [label="RemoteData ({rd1}) - @localhost" pencolor=black shape=rectangle] + @localhost-test" pencolor=black shape=rectangle] N{calc1} -> N{rd1} [color="#000000" style=solid] N{fd1} [label="FolderData ({fd1})" pencolor=black shape=rectangle] N{wc1} -> N{fd1} [color="#000000" style=dashed] @@ -320,7 +309,5 @@ def test_graph_graphviz_source_pstate(self): }}""".format(**{k: v.pk for k, v in nodes.items()}) # dedent before comparison - self.assertEqual( - sorted([l.strip() for l in graph.graphviz.source.splitlines()]), + assert sorted([l.strip() for l in graph.graphviz.source.splitlines()]) == \ sorted([l.strip() for l in expected.splitlines()]) - ) From 4cf468e208cc639579b93ba9f05a2a6eb1b09f64 Mon Sep 17 00:00:00 
2001 From: Chris Sewell Date: Tue, 22 Feb 2022 02:58:58 +0100 Subject: [PATCH 2/3] Update tests.sh --- .github/workflows/tests.sh | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/tests.sh b/.github/workflows/tests.sh index 2127c4f06f..fe31c799f9 100755 --- a/.github/workflows/tests.sh +++ b/.github/workflows/tests.sh @@ -8,7 +8,6 @@ SYSTEM_TESTS="${GITHUB_WORKSPACE}/.github/system_tests" pytest --cov aiida --verbose --noconftest ${SYSTEM_TESTS}/test_test_manager.py pytest --cov aiida --verbose --noconftest ${SYSTEM_TESTS}/test_ipython_magics.py pytest --cov aiida --verbose --noconftest ${SYSTEM_TESTS}/test_profile_manager.py -python ${SYSTEM_TESTS}/test_plugin_testcase.py # uses custom unittest test runner # Until the `${SYSTEM_TESTS}/pytest` tests are moved within `tests` we have to run them separately and pass in the path to the # `conftest.py` explicitly, because otherwise it won't be able to find the fixtures it provides From 42dbc55a2ea24170f09c57e56f464e9144bccdfd Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Tue, 22 Feb 2022 03:30:42 +0100 Subject: [PATCH 3/3] Update test_config.py --- tests/restapi/test_config.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/tests/restapi/test_config.py b/tests/restapi/test_config.py index 9742640730..bccdb77de3 100644 --- a/tests/restapi/test_config.py +++ b/tests/restapi/test_config.py @@ -11,19 +11,29 @@ # pylint: disable=redefined-outer-name import pytest +from aiida import orm +from aiida.manage import get_manager + @pytest.fixture -def create_app(): +def create_app(aiida_profile_clean): # pylint: disable=unused-argument """Set up Flask App""" from aiida.restapi.run_api import configure_api + user = orm.User.objects.get_default() + def _create_app(**kwargs): catch_internal_server = kwargs.pop('catch_internal_server', True) api = configure_api(catch_internal_server=catch_internal_server, **kwargs) api.app.config['TESTING'] = True return api.app - return 
_create_app + yield _create_app + + # because the `close_thread_connection` decorator, currently, directly closes the SQLA session, + # the default user will be detached from the session, and the `_clean` method will fail. + # So, we need to reattach the default user to the session. + get_manager().get_profile_storage().get_session().add(user.backend_entity.bare_model) def test_posting(create_app):