Skip to content

Commit

Permalink
Replace deepmd.pt.utils.ase_calc with deepmd.calculator
Browse files Browse the repository at this point in the history
- Set `deepmd.pt.utils.ase_calc.DPCalculator` as an alias of `deepmd.calculator.DP`
- Replace `deepmd_pt` with  `deepmd.pt` in `deep_pot.py`
- Set pbc in `pt/test_calculator.py` as it requests stress

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Jan 27, 2024
1 parent f4d7c7e commit fbfbb27
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 64 deletions.
2 changes: 1 addition & 1 deletion deepmd/infer/deep_pot.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __new__(cls, model_file: str, *args, **kwargs):

return super().__new__(DeepPotTF)
elif backend == DPBackend.PyTorch:
from deepmd_pt.infer.deep_eval import DeepPot as DeepPotPT
from deepmd.pt.infer.deep_eval import DeepPot as DeepPotPT

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
deepmd.pt.infer.deep_eval
begins an import cycle.

Check warning on line 59 in deepmd/infer/deep_pot.py

View check run for this annotation

Codecov / codecov/patch

deepmd/infer/deep_pot.py#L59

Added line #L59 was not covered by tests

return super().__new__(DeepPotPT)
else:
Expand Down
67 changes: 4 additions & 63 deletions deepmd/pt/utils/ase_calc.py
Original file line number Diff line number Diff line change
@@ -1,65 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
ClassVar,
)
from deepmd.calculator import DP as DPCalculator

Check warning on line 2 in deepmd/pt/utils/ase_calc.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/ase_calc.py#L2

Added line #L2 was not covered by tests

import dpdata
import numpy as np
from ase import (
Atoms,
)
from ase.calculators.calculator import (
Calculator,
PropertyNotImplementedError,
)

from deepmd.pt.infer.deep_eval import (
DeepPot,
)


class DPCalculator(Calculator):
implemented_properties: ClassVar[list] = [
"energy",
"free_energy",
"forces",
"virial",
"stress",
]

def __init__(self, model):
Calculator.__init__(self)
self.dp = DeepPot(model)
self.type_map = self.dp.type_map

def calculate(self, atoms: Atoms, properties, system_changes) -> None:
Calculator.calculate(self, atoms, properties, system_changes)
system = dpdata.System(atoms, fmt="ase/structure")
type_trans = np.array(
[self.type_map.index(i) for i in system.data["atom_names"]]
)
input_coords = system.data["coords"]
input_cells = system.data["cells"]
input_types = list(type_trans[system.data["atom_types"]])
model_predict = self.dp.eval(input_coords, input_cells, input_types)
self.results = {
"energy": model_predict[0].item(),
"free_energy": model_predict[0].item(),
"forces": model_predict[1].reshape(-1, 3),
"virial": model_predict[2].reshape(3, 3),
}

# convert virial into stress for lattice relaxation
if "stress" in properties:
if sum(atoms.get_pbc()) > 0 or (atoms.cell is not None):
# the usual convention (tensile stress is positive)
# stress = -virial / volume
stress = (
-0.5
* (self.results["virial"].copy() + self.results["virial"].copy().T)
/ atoms.get_volume()
)
# Voigt notation
self.results["stress"] = stress.flat[[0, 4, 8, 5, 2, 1]]
else:
raise PropertyNotImplementedError
__all__ = [

Check warning on line 4 in deepmd/pt/utils/ase_calc.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/ase_calc.py#L4

Added line #L4 was not covered by tests
"DPCalculator",
]
2 changes: 2 additions & 0 deletions source/tests/pt/test_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def test_calculator(self):
# positions=[tuple(item) for item in coordinate],
cell=cell,
calculator=self.calculator,
pbc=True,
)
e0, f0 = ase_atoms0.get_potential_energy(), ase_atoms0.get_forces()
s0, v0 = (
Expand All @@ -79,6 +80,7 @@ def test_calculator(self):
# positions=[tuple(item) for item in coordinate],
cell=cell,
calculator=self.calculator,
pbc=True,
)
e1, f1 = ase_atoms1.get_potential_energy(), ase_atoms1.get_forces()
s1, v1 = (
Expand Down

0 comments on commit fbfbb27

Please sign in to comment.