diff --git a/.github/workflows/pytests.yml b/.github/workflows/pytests.yml index 11550bb1..ceec692d 100644 --- a/.github/workflows/pytests.yml +++ b/.github/workflows/pytests.yml @@ -41,12 +41,12 @@ jobs: - name: Install dependencies from pip run: | echo "numpy<2" >> $PIP_CONSTRAINT - python3 -m pip install wheel setuptools numpy scipy click matplotlib pyyaml spglib rdkit flake8 pytest pytest-cov requests + python3 -m pip install wheel setuptools numpy scipy click matplotlib pyyaml spglib rdkit==2024.3.3 flake8 pytest pytest-cov requests - name: Install latest ASE from pypi run: | echo PIP_CONSTRAINT $PIP_CONSTRAINT - python3 -m pip install ase + python3 -m pip install ase echo -n "ASE VERSION " python3 -c "import ase; print(ase.__file__, ase.__version__)" @@ -105,15 +105,27 @@ jobs: run: | echo "search for torch version" set +o pipefail + + # echo "torch versions" + # python3 -m pip install torch== + # echo "torch versions to search" + # python3 -m pip install torch== 2>&1 | fgrep 'from versions' | + # sed -e 's/.*from versions: //' -e 's/)//' -e 's/,[ ]*/\n/g' | tac + # search for available torch version with +cpu support - for torch_version_test in $( python3 -m pip install torch== 2>&1 | fgrep 'from versions' | - sed -e 's/.*from versions: //' -e 's/)//' -e 's/,[ ]*/\n/g' | tac ); do + # for torch_version_test in $( python3 -m pip install torch== 2>&1 | fgrep 'from versions' | + # sed -e 's/.*from versions: //' -e 's/)//' -e 's/,[ ]*/\n/g' | tac ); do + wget https://pypi.org/pypi/torch/json -O torch_versions + for torch_version_test in $( python3 -c "import json; print(' '.join(json.load(open('torch_versions'))['releases'].keys()))" | sed 's/ /\n/g' | tac ); do + echo "check torch_version_test $torch_version_test" set +e python3 -m pip install --dry-run torch==${torch_version_test}+cpu \ - -f https://download.pytorch.org/whl/torch_stable.html > /dev/null 2>&1 + -f https://download.pytorch.org/whl/torch_stable.html 2>&1 search_stat=$? + echo "got search_stat $search_stat" set -e if [ $search_stat == 0 ]; then + echo "got valid +cpu version, exiting" torch_version=${torch_version_test} break fi diff --git a/tests/calculators/test_ase_fileio_caching.py b/tests/calculators/test_ase_fileio_caching.py new file mode 100644 index 00000000..00883e03 --- /dev/null +++ b/tests/calculators/test_ase_fileio_caching.py @@ -0,0 +1,78 @@ +import os + +import pytest + +from ase.atoms import Atoms + + +######################## +# test Vasp calculator + +from tests.calculators.test_vasp import test_vasp_mark +@pytest.mark.skipif(test_vasp_mark, reason='Vasp testing env vars missing') +def test_vasp_cache_timing(tmp_path, monkeypatch): + from ase.calculators.vasp import Vasp as Vasp_ase + from wfl.calculators.vasp import Vasp as Vasp_wrap + + config = Atoms('Si', positions=[[0, 0, 9]], cell=[2, 2, 2], pbc=[True, True, True]) + kwargs_ase = {'encut': 200, 'pp': os.environ['PYTEST_VASP_POTCAR_DIR']} + kwargs_wrapper = {'workdir': tmp_path} + # make sure 'pp' is relative to correct dir (see wfl.calculators.vasp) + if os.environ['PYTEST_VASP_POTCAR_DIR'].startswith('/'): + monkeypatch.setenv("VASP_PP_PATH", "/.") + else: + monkeypatch.setenv("VASP_PP_PATH", ".") + cache_timing(config, Vasp_ase, kwargs_ase, Vasp_wrap, kwargs_wrapper, tmp_path, monkeypatch) + +######################## +# test quantum espresso calculator +from tests.calculators.test_qe import espresso_avail, qe_pseudo +@pytest.mark.skipif(not espresso_avail, reason='qe testing env vars missing') +def test_qe_cache_timing(tmp_path, monkeypatch, qe_pseudo): + from ase.calculators.espresso import Espresso as Espresso_ASE + from wfl.calculators.espresso import Espresso as Espresso_wrap + + config = Atoms('Si', positions=[[0, 0, 9]], cell=[2, 2, 2], pbc=[True, True, True]) + + pspot = qe_pseudo + kwargs_ase = dict( + pseudopotentials=dict(Si=pspot.name), + pseudo_dir=pspot.parent, + input_data={"SYSTEM": {"ecutwfc": 40, "input_dft": "LDA",}}, + kpts=(2, 3, 4), + conv_thr=0.0001, + workdir=tmp_path + ) + + kwargs_wrapper = {} + cache_timing(config, Espresso_ASE, kwargs_ase, Espresso_wrap, kwargs_wrapper, tmp_path, monkeypatch) + + +######################## +# generic code used by all calculators + +import time + +from wfl.configset import ConfigSet, OutputSpec +from wfl.calculators import generic + +def cache_timing(config, calc_ase, kwargs_ase, calc_wfl, kwargs_wrapper, rundir, monkeypatch): + (rundir / "run_calc_ase").mkdir() + + calc = calc_ase(**kwargs_ase) + config.calc = calc + + monkeypatch.chdir(rundir / "run_calc_ase") + t0 = time.time() + E = config.get_potential_energy() + ase_time = time.time() - t0 + + monkeypatch.chdir(rundir) + t0 = time.time() + _ = generic.calculate(inputs=ConfigSet(config), outputs=OutputSpec(), + calculator=calc_wfl(**kwargs_wrapper, **kwargs_ase)) + wfl_time = time.time() - t0 + + print("ASE", ase_time, "WFL", wfl_time) + + assert wfl_time < ase_time * 1.25 diff --git a/tests/calculators/test_calc_generic.py b/tests/calculators/test_calc_generic.py index 5e828a9a..d08c12bb 100644 --- a/tests/calculators/test_calc_generic.py +++ b/tests/calculators/test_calc_generic.py @@ -139,7 +139,6 @@ def test_generic_autopara_defaults(): sys.stderr = sys.__stderr__ assert "num_inputs_per_python_subprocess=3" in l_stderr.getvalue() -@pytest.mark.xfail(reason="Waiting for update to work with ASE3.23") def test_generic_DFT_autopara_defaults(tmp_path, monkeypatch): ats = [Atoms('Al2', positions=[[0,0,0], [1,1,1]], cell=[10]*3, pbc=[True]*3) for _ in range(50)] @@ -151,6 +150,6 @@ def test_generic_DFT_autopara_defaults(tmp_path, monkeypatch): # try with a calculator that overrides an autopara default, namely a DFT calculator # that sets num_inputs_per_python_subprocess=1 sys.stderr = l_stderr - at_proc = generic.calculate(ci, os, Espresso(calculator_exec="_DUMMY_EXEC_", pseudo_dir="_DUMMY_DIR_", workdir=tmp_path)) + at_proc = generic.calculate(ci, os, Espresso(workdir=tmp_path)) sys.stderr = sys.__stderr__ assert "num_inputs_per_python_subprocess=1" in l_stderr.getvalue() diff --git a/tests/calculators/test_vasp.py b/tests/calculators/test_vasp.py index 3be3485e..6d323880 100644 --- a/tests/calculators/test_vasp.py +++ b/tests/calculators/test_vasp.py @@ -13,10 +13,11 @@ from wfl.calculators import generic from wfl.configset import ConfigSet, OutputSpec -pytestmark = pytest.mark.skipif('ASE_VASP_COMMAND' not in os.environ or - 'ASE_VASP_COMMAND_GAMMA' not in os.environ or - 'PYTEST_VASP_POTCAR_DIR' not in os.environ, - reason='missing env var ASE_VASP_COMMAND or ASE_VASP_COMMAND_GAMMA or PYTEST_VASP_POTCAR_DIR') +test_vasp_mark = ('ASE_VASP_COMMAND' not in os.environ or + 'ASE_VASP_COMMAND_GAMMA' not in os.environ or + 'PYTEST_VASP_POTCAR_DIR' not in os.environ) +pytestmark = pytest.mark.skipif(test_vasp_mark, reason='missing env var ASE_VASP_COMMAND or ASE_VASP_COMMAND_GAMMA ' + 'or PYTEST_VASP_POTCAR_DIR') def test_vasp_gamma(tmp_path, monkeypatch): diff --git a/tests/calculators/test_wrapped_calculator.py b/tests/calculators/test_wrapped_calculator.py new file mode 100644 index 00000000..4fe2b3e9 --- /dev/null +++ b/tests/calculators/test_wrapped_calculator.py @@ -0,0 +1,30 @@ +import pytest +from ase.atoms import Atoms +from wfl.configset import ConfigSet, OutputSpec +from wfl.calculators import generic + +######################## +# test a RuntimeWarning is raised when using the Espresso Calculator directly from ase +from tests.calculators.test_qe import espresso_avail, qe_pseudo +@pytest.mark.skipif(not espresso_avail, reason='qe testing env vars missing') +def test_wrapped_qe(tmp_path, qe_pseudo): + from ase.calculators.espresso import Espresso as Espresso_ASE + from wfl.calculators.espresso import Espresso as Espresso_wrap + + config = Atoms('Si', positions=[[0, 0, 9]], cell=[2, 2, 2], pbc=[True, True, True]) + + pspot = qe_pseudo + kwargs = dict( + pseudopotentials=dict(Si=pspot.name), + pseudo_dir=pspot.parent, + input_data={"SYSTEM": {"ecutwfc": 40, "input_dft": "LDA",}}, + kpts=(2, 3, 4), + conv_thr=0.0001, + workdir=tmp_path, + tstress=True, + tprnfor=True + ) + + direct_calc = (Espresso_ASE, [], kwargs) + kwargs_generic = dict(inputs=ConfigSet(config), outputs=OutputSpec(), calculator=direct_calc) + pytest.warns(RuntimeWarning, generic.calculate, **kwargs_generic) \ No newline at end of file diff --git a/wfl/calculators/generic.py b/wfl/calculators/generic.py index 50c4b350..544b8103 100644 --- a/wfl/calculators/generic.py +++ b/wfl/calculators/generic.py @@ -150,6 +150,20 @@ def calculate(*args, **kwargs): if calculator is None: calculator = args[2] + #check if calculator should be wrapped + if type(calculator) == tuple: + from ase.calculators.espresso import Espresso as ASE_Espresso + from ase.calculators.vasp.vasp import Vasp as ASE_Vasp + from ase.calculators.aims import Aims as ASE_Aims + from ase.calculators.castep import Castep as ASE_Castep + from ase.calculators.mopac import MOPAC as ASE_MOPAC + from ase.calculators.orca import ORCA as ASE_ORCA + wrapped_types = [ASE_Espresso, ASE_Vasp, ASE_Aims, ASE_Castep, ASE_MOPAC, ASE_ORCA] + + calc = calculator[0] + if calc in wrapped_types: + warnings.warn(f"{calc} should be imported from wfl.calculators rather than ase. Using {calc} directly can lead to duplicated singlepoints", RuntimeWarning) + default_autopara_info = getattr(calculator, "wfl_generic_default_autopara_info", {"num_inputs_per_python_subprocess": 10}) return autoparallelize(_run_autopara_wrappable, *args, default_autopara_info=default_autopara_info, **kwargs)