Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

attaching User defined MD logger #328

Merged
merged 8 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions tests/test_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@
import pytest

import numpy as np
import os
from pathlib import Path

from ase import Atoms
import ase.io
from ase.build import bulk
from ase.calculators.emt import EMT
from ase.units import fs
from ase.md.logger import MDLogger
from wfl.autoparallelize import autoparainfo

from wfl.generate import md
Expand Down Expand Up @@ -216,3 +219,30 @@ def test_md_abort_function(cu_slab):
rng=np.random.default_rng(1))

assert len(list(atoms_traj)) < 501


def test_md_attach_logger(cu_slab):

calc = EMT()
autopara_info = autoparainfo.AutoparaInfo(num_python_subprocesses=2, num_inputs_per_python_subprocess=1, skip_failed=False)

inputs = ConfigSet([cu_slab, cu_slab])
outputs = OutputSpec()

logger_kwargs = {
"logger" : MDLogger,
"logfile" : "test_log",
}

atoms_traj = md.md(inputs, outputs, calculator=calc, integrator="Langevin", steps=300, dt=1.0,
temperature=500.0, temperature_tau=100/fs, logger_kwargs=logger_kwargs, logger_interval=1,
rng=np.random.default_rng(1), autopara_info=autopara_info,)

atoms_traj = list(atoms_traj)
atoms_final = atoms_traj[-1]

workdir = Path(os.getcwd())

assert len(atoms_traj) == 602
assert all([Path(workdir / "test_log.item_0").is_file(), Path(workdir / "test_log.item_1").is_file()])

23 changes: 21 additions & 2 deletions wfl/generate/md/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution, Stationary
from ase.md.verlet import VelocityVerlet
from ase.md.langevin import Langevin
from ase.md.logger import MDLogger
from ase.units import GPa, fs

from wfl.autoparallelize import autoparallelize, autoparallelize_docstring
Expand All @@ -21,8 +22,8 @@
def _sample_autopara_wrappable(atoms, calculator, steps, dt, integrator="NVTBerendsen", temperature=None, temperature_tau=None,
pressure=None, pressure_tau=None, compressibility_au=None, compressibility_fd_displ=0.01,
traj_step_interval=1, skip_failures=True, results_prefix='last_op__md_', verbose=False, update_config_type="append",
traj_select_during_func=lambda at: True, traj_select_after_func=None, abort_check=None, rng=None,
_autopara_per_item_info=None):
traj_select_during_func=lambda at: True, traj_select_after_func=None, abort_check=None,
logger_kwargs=None, logger_interval=None, rng=None, _autopara_per_item_info=None):
"""runs an MD trajectory with aggresive, not necessarily physical, integrators for
sampling configs. By default calculator properties for each frame stored in
keys prefixed with "last_op__md_", which may be overwritten by next operation.
Expand Down Expand Up @@ -81,6 +82,11 @@ def _sample_autopara_wrappable(atoms, calculator, steps, dt, integrator="NVTBere
checks the MD snapshots and aborts the simulation on some condition.
rng: numpy.random.Generator, default None
random number generator to use (needed for pressure sampling, initial temperature, or Langevin dynamics)
logger_kwargs: dict, default None
kwargs to MDLogger to attach to each MD run, including "logfile" as string to which
config number will be appended. User defined ase.md.MDLogger derived class can be provided with "logger" as key.
logger_interval: int, default None
interval for logger
_autopara_per_item_info: dict
INTERNALLY used by autoparallelization framework to make runs reproducible (see
wfl.autoparallelize.autoparallelize() docs)
Expand All @@ -100,6 +106,10 @@ def _sample_autopara_wrappable(atoms, calculator, steps, dt, integrator="NVTBere
else:
logfile = None

if logger_kwargs is not None:
logger_constructor = logger_kwargs.pop("logger", MDLogger)
logger_logfile = logger_kwargs["logfile"]

if temperature_tau is None and (temperature is not None and not isinstance(temperature, (float, int))):
raise RuntimeError(f'NVE (temperature_tau is None) can only accept temperature=float for initial T, got {type(temperature)}')

Expand Down Expand Up @@ -128,6 +138,7 @@ def _sample_autopara_wrappable(atoms, calculator, steps, dt, integrator="NVTBere
# get rng from autopara_per_item info if available ("rng" arg that was passed in was
# already used by autoparallelization framework to set "rng" key in per-item dict)
rng = _autopara_per_item_info[at_i].get("rng")
item_i = _autopara_per_item_info[at_i].get("item_i")

at.calc = calculator
if pressure is not None and compressibility_au is None:
Expand Down Expand Up @@ -211,6 +222,7 @@ def process_step(interval):
if not first_step_of_later_stage and cur_step % interval == 0:
at.info['MD_time_fs'] = cur_step * dt
at.info['MD_step'] = cur_step
at.info["MD_current_temperature"] = at.get_temperature()
at_save = at_copy_save_calc_results(at, prefix=results_prefix)

if traj_select_during_func(at):
Expand All @@ -234,7 +246,14 @@ def process_step(interval):
at.info['MD_temperature_K'] = stage_kwargs['temperature_K']

md = md_constructor(at, **stage_kwargs)

md.attach(process_step, 1, traj_step_interval)
if logger_kwargs is not None:
logger_kwargs["logfile"] = f"{logger_logfile}.item_{item_i}"
logger_kwargs["dyn"] = md
logger_kwargs["atoms"] = at
logger = logger_constructor(**logger_kwargs)
md.attach(logger, logger_interval)

if stage_i > 0:
first_step_of_later_stage = True
Expand Down
Loading