Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test for ase fileio caching with VASP example #332

Merged
merged 10 commits into from
Aug 29, 2024
22 changes: 17 additions & 5 deletions .github/workflows/pytests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ jobs:
- name: Install dependencies from pip
run: |
echo "numpy<2" >> $PIP_CONSTRAINT
python3 -m pip install wheel setuptools numpy scipy click matplotlib pyyaml spglib rdkit flake8 pytest pytest-cov requests
python3 -m pip install wheel setuptools numpy scipy click matplotlib pyyaml spglib rdkit==2024.3.3 flake8 pytest pytest-cov requests

- name: Install latest ASE from pypi
run: |
echo PIP_CONSTRAINT $PIP_CONSTRAINT
python3 -m pip install ase
python3 -m pip install ase
echo -n "ASE VERSION "
python3 -c "import ase; print(ase.__file__, ase.__version__)"

Expand Down Expand Up @@ -105,15 +105,27 @@ jobs:
run: |
echo "search for torch version"
set +o pipefail

# echo "torch versions"
# python3 -m pip install torch==
# echo "torch versions to search"
# python3 -m pip install torch== 2>&1 | fgrep 'from versions' |
# sed -e 's/.*from versions: //' -e 's/)//' -e 's/,[ ]*/\n/g' | tac

# search for available torch version with +cpu support
for torch_version_test in $( python3 -m pip install torch== 2>&1 | fgrep 'from versions' |
sed -e 's/.*from versions: //' -e 's/)//' -e 's/,[ ]*/\n/g' | tac ); do
# for torch_version_test in $( python3 -m pip install torch== 2>&1 | fgrep 'from versions' |
# sed -e 's/.*from versions: //' -e 's/)//' -e 's/,[ ]*/\n/g' | tac ); do
wget https://pypi.org/pypi/torch/json -O torch_versions
for torch_version_test in $( python3 -c "import json; print(' '.join(json.load(open('torch_versions'))['releases'].keys()))" | sed 's/ /\n/g' | tac ); do
echo "check torch_version_test $torch_version_test"
set +e
python3 -m pip install --dry-run torch==${torch_version_test}+cpu \
-f https://download.pytorch.org/whl/torch_stable.html > /dev/null 2>&1
-f https://download.pytorch.org/whl/torch_stable.html 2>&1
search_stat=$?
echo "got search_stat $search_stat"
set -e
if [ $search_stat == 0 ]; then
echo "got valid +cpu version, exiting"
torch_version=${torch_version_test}
break
fi
Expand Down
78 changes: 78 additions & 0 deletions tests/calculators/test_ase_fileio_caching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import os

import pytest

from ase.atoms import Atoms


########################
# test Vasp calculator

from tests.calculators.test_vasp import test_vasp_mark
@pytest.mark.skipif(test_vasp_mark, reason='Vasp testing env vars missing')
def test_vasp_cache_timing(tmp_path, monkeypatch):
from ase.calculators.vasp import Vasp as Vasp_ase
from wfl.calculators.vasp import Vasp as Vasp_wrap

config = Atoms('Si', positions=[[0, 0, 9]], cell=[2, 2, 2], pbc=[True, True, True])
kwargs_ase = {'encut': 200, 'pp': os.environ['PYTEST_VASP_POTCAR_DIR']}
kwargs_wrapper = {'workdir': tmp_path}
# make sure 'pp' is relative to correct dir (see wfl.calculators.vasp)
if os.environ['PYTEST_VASP_POTCAR_DIR'].startswith('/'):
monkeypatch.setenv("VASP_PP_PATH", "/.")
else:
monkeypatch.setenv("VASP_PP_PATH", ".")
cache_timing(config, Vasp_ase, kwargs_ase, Vasp_wrap, kwargs_wrapper, tmp_path, monkeypatch)

########################
# test quantum espresso calculator
from tests.calculators.test_qe import espresso_avail, qe_pseudo
@pytest.mark.skipif(not espresso_avail, reason='qe testing env vars missing')
def test_qe_cache_timing(tmp_path, monkeypatch, qe_pseudo):
from ase.calculators.espresso import Espresso as Espresso_ASE
from wfl.calculators.espresso import Espresso as Espresso_wrap

config = Atoms('Si', positions=[[0, 0, 9]], cell=[2, 2, 2], pbc=[True, True, True])

pspot = qe_pseudo
kwargs_ase = dict(
pseudopotentials=dict(Si=pspot.name),
pseudo_dir=pspot.parent,
input_data={"SYSTEM": {"ecutwfc": 40, "input_dft": "LDA",}},
kpts=(2, 3, 4),
conv_thr=0.0001,
workdir=tmp_path
)

kwargs_wrapper = {}
cache_timing(config, Espresso_ASE, kwargs_ase, Espresso_wrap, kwargs_wrapper, tmp_path, monkeypatch)


########################
# generic code used by all calculators

import time

from wfl.configset import ConfigSet, OutputSpec
from wfl.calculators import generic

def cache_timing(config, calc_ase, kwargs_ase, calc_wfl, kwargs_wrapper, rundir, monkeypatch):
(rundir / "run_calc_ase").mkdir()

calc = calc_ase(**kwargs_ase)
config.calc = calc

monkeypatch.chdir(rundir / "run_calc_ase")
t0 = time.time()
E = config.get_potential_energy()
ase_time = time.time() - t0

monkeypatch.chdir(rundir)
t0 = time.time()
_ = generic.calculate(inputs=ConfigSet(config), outputs=OutputSpec(),
calculator=calc_wfl(**kwargs_wrapper, **kwargs_ase))
wfl_time = time.time() - t0

print("ASE", ase_time, "WFL", wfl_time)

assert wfl_time < ase_time * 1.25
3 changes: 1 addition & 2 deletions tests/calculators/test_calc_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,6 @@ def test_generic_autopara_defaults():
sys.stderr = sys.__stderr__
assert "num_inputs_per_python_subprocess=3" in l_stderr.getvalue()

@pytest.mark.xfail(reason="Waiting for update to work with ASE3.23")
def test_generic_DFT_autopara_defaults(tmp_path, monkeypatch):
ats = [Atoms('Al2', positions=[[0,0,0], [1,1,1]], cell=[10]*3, pbc=[True]*3) for _ in range(50)]

Expand All @@ -151,6 +150,6 @@ def test_generic_DFT_autopara_defaults(tmp_path, monkeypatch):
# try with a calculator that overrides an autopara default, namely a DFT calculator
# that sets num_inputs_per_python_subprocess=1
sys.stderr = l_stderr
at_proc = generic.calculate(ci, os, Espresso(calculator_exec="_DUMMY_EXEC_", pseudo_dir="_DUMMY_DIR_", workdir=tmp_path))
at_proc = generic.calculate(ci, os, Espresso(workdir=tmp_path))
sys.stderr = sys.__stderr__
assert "num_inputs_per_python_subprocess=1" in l_stderr.getvalue()
9 changes: 5 additions & 4 deletions tests/calculators/test_vasp.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
from wfl.calculators import generic
from wfl.configset import ConfigSet, OutputSpec

pytestmark = pytest.mark.skipif('ASE_VASP_COMMAND' not in os.environ or
'ASE_VASP_COMMAND_GAMMA' not in os.environ or
'PYTEST_VASP_POTCAR_DIR' not in os.environ,
reason='missing env var ASE_VASP_COMMAND or ASE_VASP_COMMAND_GAMMA or PYTEST_VASP_POTCAR_DIR')
test_vasp_mark = ('ASE_VASP_COMMAND' not in os.environ or
'ASE_VASP_COMMAND_GAMMA' not in os.environ or
'PYTEST_VASP_POTCAR_DIR' not in os.environ)
pytestmark = pytest.mark.skipif(test_vasp_mark, reason='missing env var ASE_VASP_COMMAND or ASE_VASP_COMMAND_GAMMA '
'or PYTEST_VASP_POTCAR_DIR')


def test_vasp_gamma(tmp_path, monkeypatch):
Expand Down
30 changes: 30 additions & 0 deletions tests/calculators/test_wrapped_calculator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import pytest
from ase.atoms import Atoms
from wfl.configset import ConfigSet, OutputSpec
from wfl.calculators import generic

########################
# test a RuntimeWarning is raised when using the Espresso Calculator directly from ase
from tests.calculators.test_qe import espresso_avail, qe_pseudo
@pytest.mark.skipif(not espresso_avail, reason='qe testing env vars missing')
def test_wrapped_qe(tmp_path, qe_pseudo):
from ase.calculators.espresso import Espresso as Espresso_ASE
from wfl.calculators.espresso import Espresso as Espresso_wrap

config = Atoms('Si', positions=[[0, 0, 9]], cell=[2, 2, 2], pbc=[True, True, True])

pspot = qe_pseudo
kwargs = dict(
pseudopotentials=dict(Si=pspot.name),
pseudo_dir=pspot.parent,
input_data={"SYSTEM": {"ecutwfc": 40, "input_dft": "LDA",}},
kpts=(2, 3, 4),
conv_thr=0.0001,
workdir=tmp_path,
tstress=True,
tprnfor=True
)

direct_calc = (Espresso_ASE, [], kwargs)
kwargs_generic = dict(inputs=ConfigSet(config), outputs=OutputSpec(), calculator=direct_calc)
pytest.warns(RuntimeWarning, generic.calculate, **kwargs_generic)
14 changes: 14 additions & 0 deletions wfl/calculators/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,20 @@ def calculate(*args, **kwargs):
if calculator is None:
calculator = args[2]

#check if calculator should be wrapped
if type(calculator) == tuple:
from ase.calculators.espresso import Espresso as ASE_Espresso
from ase.calculators.vasp.vasp import Vasp as ASE_Vasp
from ase.calculators.aims import Aims as ASE_Aims
from ase.calculators.castep import Castep as ASE_Castep
from ase.calculators.mopac import MOPAC as ASE_MOPAC
from ase.calculators.orca import ORCA as ASE_ORCA
wrapped_types = [ASE_Espresso, ASE_Vasp, ASE_Aims, ASE_Castep, ASE_MOPAC, ASE_ORCA]

calc = calculator[0]
if calc in wrapped_types:
warnings.warn(f"{calc} should be imported from wfl.calculators rather than ase. Using {calc} directly can lead to duplicated singlepoints", RuntimeWarning)

default_autopara_info = getattr(calculator, "wfl_generic_default_autopara_info", {"num_inputs_per_python_subprocess": 10})

return autoparallelize(_run_autopara_wrappable, *args, default_autopara_info=default_autopara_info, **kwargs)
Expand Down
Loading