-
Notifications
You must be signed in to change notification settings - Fork 520
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
dp model format for dp models. add missing files...
- Loading branch information
Han Wang
committed
Jan 31, 2024
1 parent
27f14cd
commit fc3cc8e
Showing
3 changed files
with
218 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
from .dp_atomic_model import ( | ||
DPAtomicModel, | ||
) | ||
from .make_model import ( | ||
make_model, | ||
) | ||
|
||
DPModel = make_model(DPAtomicModel) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
from typing import ( | ||
Dict, | ||
Optional, | ||
) | ||
|
||
import numpy as np | ||
|
||
from .nlist import ( | ||
build_neighbor_list, | ||
extend_coord_with_ghosts, | ||
) | ||
from .output_def import ( | ||
ModelOutputDef, | ||
) | ||
from .region import ( | ||
normalize_coord, | ||
) | ||
from .transform_output import ( | ||
communicate_extended_output, | ||
fit_output_to_model_output, | ||
) | ||
|
||
|
||
def make_model(T_AtomicModel): | ||
class CM(T_AtomicModel): | ||
def __init__( | ||
self, | ||
*args, | ||
**kwargs, | ||
): | ||
super().__init__( | ||
*args, | ||
**kwargs, | ||
) | ||
|
||
def get_model_output_def(self): | ||
return ModelOutputDef(self.get_fitting_output_def()) | ||
|
||
def call( | ||
self, | ||
coord, | ||
atype, | ||
box: Optional[np.ndarray] = None, | ||
do_atomic_virial: bool = False, | ||
) -> Dict[str, np.ndarray]: | ||
"""Return total energy of the system. | ||
Args: | ||
- coord: Atom coordinates with shape [nframes, natoms[1]*3]. | ||
- atype: Atom types with shape [nframes, natoms[1]]. | ||
- natoms: Atom statisics with shape [self.ntypes+2]. | ||
- box: Simulation box with shape [nframes, 9]. | ||
- atomic_virial: Whether or not compoute the atomic virial. | ||
Returns | ||
------- | ||
- energy: Energy per atom. | ||
- force: XYZ force per atom. | ||
""" | ||
nframes, nloc = atype.shape[:2] | ||
if box is not None: | ||
coord_normalized = normalize_coord( | ||
coord.reshape(nframes, nloc, 3), | ||
box.reshape(nframes, 3, 3), | ||
) | ||
else: | ||
coord_normalized = coord.clone() | ||
extended_coord, extended_atype, mapping = extend_coord_with_ghosts( | ||
coord_normalized, atype, box, self.get_rcut() | ||
) | ||
nlist = build_neighbor_list( | ||
extended_coord, | ||
extended_atype, | ||
nloc, | ||
self.get_rcut(), | ||
self.get_sel(), | ||
distinguish_types=self.distinguish_types(), | ||
) | ||
extended_coord = extended_coord.reshape(nframes, -1, 3) | ||
model_predict_lower = self.call_lower( | ||
extended_coord, | ||
extended_atype, | ||
nlist, | ||
mapping, | ||
do_atomic_virial=do_atomic_virial, | ||
) | ||
model_predict = communicate_extended_output( | ||
model_predict_lower, | ||
self.get_model_output_def(), | ||
mapping, | ||
do_atomic_virial=do_atomic_virial, | ||
) | ||
return model_predict | ||
|
||
def call_lower( | ||
self, | ||
extended_coord, | ||
extended_atype, | ||
nlist, | ||
mapping: Optional[np.ndarray] = None, | ||
do_atomic_virial: bool = False, | ||
): | ||
"""Return model prediction. | ||
Parameters | ||
---------- | ||
extended_coord | ||
coodinates in extended region | ||
extended_atype | ||
atomic type in extended region | ||
nlist | ||
neighbor list. nf x nloc x nsel | ||
mapping | ||
mapps the extended indices to local indices | ||
do_atomic_virial | ||
whether do atomic virial | ||
Returns | ||
------- | ||
result_dict | ||
the result dict, defined by the fitting net output def. | ||
""" | ||
nframes, nall = extended_atype.shape[:2] | ||
extended_coord = extended_coord.reshape(nframes, -1, 3) | ||
atomic_ret = self.forward_atomic( | ||
extended_coord, | ||
extended_atype, | ||
nlist, | ||
mapping=mapping, | ||
) | ||
model_predict = fit_output_to_model_output( | ||
atomic_ret, | ||
self.get_fitting_output_def(), | ||
extended_coord, | ||
do_atomic_virial=do_atomic_virial, | ||
) | ||
return model_predict | ||
|
||
return CM |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
from typing import ( | ||
Dict, | ||
) | ||
|
||
import numpy as np | ||
|
||
from .output_def import ( | ||
FittingOutputDef, | ||
ModelOutputDef, | ||
get_deriv_name, | ||
get_reduce_name, | ||
) | ||
|
||
|
||
def fit_output_to_model_output( | ||
fit_ret: Dict[str, np.ndarray], | ||
fit_output_def: FittingOutputDef, | ||
coord_ext: np.ndarray, | ||
do_atomic_virial: bool = False, | ||
) -> Dict[str, np.ndarray]: | ||
"""Transform the output of the fitting network to | ||
the model output. | ||
""" | ||
model_ret = dict(fit_ret.items()) | ||
for kk, vv in fit_ret.items(): | ||
vdef = fit_output_def[kk] | ||
shap = vdef.shape | ||
atom_axis = -(len(shap) + 1) | ||
if vdef.reduciable: | ||
kk_redu = get_reduce_name(kk) | ||
model_ret[kk_redu] = np.sum(vv, axis=atom_axis) | ||
if vdef.differentiable: | ||
kk_derv_r, kk_derv_c = get_deriv_name(kk) | ||
# name-holders | ||
model_ret[kk_derv_r] = None | ||
model_ret[kk_derv_c] = None | ||
return model_ret | ||
|
||
|
||
def communicate_extended_output( | ||
model_ret: Dict[str, np.ndarray], | ||
model_output_def: ModelOutputDef, | ||
mapping: np.ndarray, # nf x nloc | ||
do_atomic_virial: bool = False, | ||
) -> Dict[str, np.ndarray]: | ||
"""Transform the output of the model network defined on | ||
local and ghost (extended) atoms to local atoms. | ||
""" | ||
new_ret = {} | ||
for kk in model_output_def.keys_outp(): | ||
vv = model_ret[kk] | ||
vdef = model_output_def[kk] | ||
new_ret[kk] = vv | ||
if vdef.reduciable: | ||
kk_redu = get_reduce_name(kk) | ||
new_ret[kk_redu] = model_ret[kk_redu] | ||
if vdef.differentiable: | ||
kk_derv_r, kk_derv_c = get_deriv_name(kk) | ||
# name holders | ||
new_ret[kk_derv_r] = None | ||
new_ret[kk_derv_c] = None | ||
new_ret[kk_derv_c + "_redu"] = None | ||
if not do_atomic_virial: | ||
# pop atomic virial, because it is not correctly calculated. | ||
new_ret.pop(kk_derv_c) | ||
return new_ret |