-
Notifications
You must be signed in to change notification settings - Fork 520
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a dpdata driver via the plugin mechanism (override that in the dpdata package) so it can benefit from the multiple-backend DeepPot. Currently, the driver in the dpdata package has to support both v1 and v2 for backward compatibility. When shipped within the deepmd-kit package, it only needs to support the current deepmd-kit version. --------- Signed-off-by: Jinzhe Zeng <[email protected]>
- Loading branch information
Showing
9 changed files
with
144 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
"""dpdata driver.""" | ||
# Derived from https://github.com/deepmodeling/dpdata/blob/18a0ed5ebced8b1f6887038883d46f31ae9990a4/dpdata/plugins/deepmd.py#L361-L443 | ||
# under LGPL-3.0-or-later license. | ||
# The original deepmd driver maintained in the dpdata package will be overriden. | ||
# The class in the dpdata package needs to handle different situations for v1 and v2 interface, | ||
# which is too complex with the development of deepmd-kit. | ||
# So, it will be a good idea to ship it with DeePMD-kit itself. | ||
import dpdata | ||
from dpdata.utils import ( | ||
sort_atom_names, | ||
) | ||
|
||
|
||
@dpdata.driver.Driver.register("dp") | ||
@dpdata.driver.Driver.register("deepmd") | ||
@dpdata.driver.Driver.register("deepmd-kit") | ||
class DPDriver(dpdata.driver.Driver): | ||
"""DeePMD-kit driver. | ||
Parameters | ||
---------- | ||
dp : deepmd.DeepPot or str | ||
The deepmd-kit potential class or the filename of the model. | ||
Examples | ||
-------- | ||
>>> DPDriver("frozen_model.pb") | ||
""" | ||
|
||
def __init__(self, dp: str) -> None: | ||
from deepmd_utils.infer.deep_pot import ( | ||
DeepPot, | ||
) | ||
|
||
if not isinstance(dp, DeepPot): | ||
self.dp = DeepPot(dp, auto_batch_size=True) | ||
else: | ||
self.dp = dp | ||
|
||
def label(self, data: dict) -> dict: | ||
"""Label a system data by deepmd-kit. Returns new data with energy, forces, and virials. | ||
Parameters | ||
---------- | ||
data : dict | ||
data with coordinates and atom types | ||
Returns | ||
------- | ||
dict | ||
labeled data with energies and forces | ||
""" | ||
nframes = data["coords"].shape[0] | ||
natoms = data["coords"].shape[1] | ||
type_map = self.dp.get_type_map() | ||
# important: dpdata type_map may not be the same as the model type_map | ||
# note: while we want to change the type_map when feeding to DeepPot, | ||
# we don't want to change the type_map in the returned data | ||
sorted_data = sort_atom_names(data.copy(), type_map=type_map) | ||
atype = sorted_data["atom_types"] | ||
|
||
coord = data["coords"].reshape((nframes, natoms * 3)) | ||
if "nopbc" not in data: | ||
cell = data["cells"].reshape((nframes, 9)) | ||
else: | ||
cell = None | ||
e, f, v = self.dp.eval(coord, cell, atype) | ||
data = data.copy() | ||
data["energies"] = e.reshape((nframes,)) | ||
data["forces"] = f.reshape((nframes, natoms, 3)) | ||
data["virials"] = v.reshape((nframes, 3, 3)) | ||
return data |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# Use deep potential with dpdata | ||
|
||
DeePMD-kit provides a driver for [dpdata](https://github.com/deepmodeling/dpdata) >=0.2.7 via the plugin mechanism, making it possible to call the `predict` method for `System` class: | ||
|
||
```py | ||
import dpdata | ||
|
||
dsys = dpdata.LabeledSystem("OUTCAR") | ||
dp_sys = dsys.predict("frozen_model_compressed.pb", driver="dp") | ||
``` | ||
|
||
By inferring with the DP model `frozen_model_compressed.pb`, dpdata will generate a new labeled system `dp_sys` with inferred energies, forces, and virials. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters