Skip to content

Commit

Permalink
nvnmd
Browse files Browse the repository at this point in the history
  • Loading branch information
MoPinghui committed Dec 13, 2023
1 parent 455e14a commit 7148e72
Show file tree
Hide file tree
Showing 8 changed files with 39 additions and 9 deletions.
1 change: 1 addition & 0 deletions deepmd/nvnmd/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@
},
"ctrl": {
# NSTDM
"MAX_NNEI": 128,
"NSTDM": 64,
"NSTDM_M1": 32,
"NSTDM_M2": 2,
Expand Down
19 changes: 15 additions & 4 deletions deepmd/nvnmd/entrypoints/mapt.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,22 @@ def __init__(self, config_file: str, weight_file: str, map_file: str):
jdata["weight_file"] = weight_file
jdata["enable"] = True

# 0 : xyz_scatter = xyz_scatter * two_embd + xyz_scatter;
# Gs + 1, Gt + 0
# 1 : xyz_scatter = xyz_scatter * two_embd + two_embd ;
# Gs + 0, Gt + 1
self.Gs_Gt_mode = 1

nvnmd_cfg.init_from_jdata(jdata)

def build_map(self):
if self.Gs_Gt_mode == 0:
self.shift_Gs = 1
self.shift_Gt = 0
if self.Gs_Gt_mode == 1:
self.shift_Gs = 0
self.shift_Gt = 1
#
M = nvnmd_cfg.dscp["M1"]
if nvnmd_cfg.version == 0:
ndim = nvnmd_cfg.dscp["ntype"]
Expand Down Expand Up @@ -482,7 +495,7 @@ def build_s2g_grad(self):
shift = 0
if nvnmd_cfg.version == 1:
ndim = 1
shift = 0
shift = self.shift_Gs
#
dic_ph = {}
dic_ph["s"] = tf.placeholder(tf.float64, [None, 1], "t_s")
Expand Down Expand Up @@ -575,10 +588,8 @@ def build_t2g(self):
[-1, two_side_type_embedding.shape[-1]],
)
# see se_atten.py in dp
# old version : xyz_scatter = xyz_scatter * two_embd + xyz_scatter; Gs + 1, Gt + 0
# new version : xyz_scatter = xyz_scatter * two_embd + two_embd ; Gs + 0, Gt + 1
wbs = [get_filter_type_weight(nvnmd_cfg.weight, ll) for ll in range(1, 5)]
dic_ph["gt"] = self.build_embedding_net(two_side_type_embedding, wbs) + 1
dic_ph["gt"] = self.build_embedding_net(two_side_type_embedding, wbs) + self.shift_Gt
return dic_ph

def run_t2g(self):
Expand Down
8 changes: 4 additions & 4 deletions deepmd/nvnmd/entrypoints/wrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,10 @@ def wrap_head(self, nhs, nws):
bs = e.dec2bin(SEL, NBIT_MODEL_HEAD)[0] + bs
# atom_ener
# fix the bug: the different energy between qnn and lammps
if "t_bias_atom_e" not in weight.keys():
log.error("NVNMD: There is not t_bias_atom_e in weight!")
exit(1)
atom_ener = weight['t_bias_atom_e']
if "t_bias_atom_e" in weight.keys():
atom_ener = weight['t_bias_atom_e']

Check warning on line 209 in deepmd/nvnmd/entrypoints/wrap.py

View check run for this annotation

Codecov / codecov/patch

deepmd/nvnmd/entrypoints/wrap.py#L209

Added line #L209 was not covered by tests
else:
atom_ener = [0] * 32
nlayer_fit = fitn["nlayer_fit"]
if VERSION == 0:
for tt in range(ntype):
Expand Down
3 changes: 3 additions & 0 deletions deepmd/nvnmd/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ def init_from_config(self, jdata):
r"""Initialize member element one by one."""
if "ctrl" in jdata.keys():
if "VERSION" in jdata["ctrl"].keys():
if "MAX_NNEI" not in jdata["ctrl"].keys():
jdata["ctrl"]["MAX_NNEI"] = 128
self.init_config_by_version(jdata["ctrl"]["VERSION"], jdata["ctrl"]["MAX_NNEI"])
#
self.config = FioDic().update(jdata, self.config)
Expand Down Expand Up @@ -336,6 +338,7 @@ def get_nvnmd_jdata(self):
r"""Generate `nvnmd` in input script."""
jdata = self.jdata_deepmd_input["nvnmd"]
jdata["net_size"] = self.net_size
jdata["max_nnei"] = self.max_nnei
jdata["config_file"] = self.config_file
jdata["weight_file"] = self.weight_file
jdata["map_file"] = self.map_file
Expand Down
11 changes: 11 additions & 0 deletions doc/nvnmd/nvnmd.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ The "nvnmd" section is defined as
```json
{
"version": 0,
"max_nnei":128,
"net_size":128,
"sel":[60, 60],
"rcut":6.0,
Expand All @@ -73,6 +74,7 @@ where items are defined as:
| Item | Mean | Optional Value |
| --------- | --------------------------- | --------------------------------------------- |
| version | the version of network structure | 0 or 1 |
| max_nnei | the maximum number of neighbors that do not distinguish element types | 128 or 256 |
| net_size | the size of nueral network | 128 |
| sel | the number of neighbors | version 0: integer list of lengths 1 to 4 are acceptable; version 1: integer |
| rcut | the cutoff radial | (0, 8.0] |
Expand Down Expand Up @@ -187,6 +189,15 @@ You can also restart the CNN training from the path prefix of checkpoint files (
dp train-nvnmd train_cnn.json -r nvnmd_cnn/model.ckpt -s s1
```

You can also initialize the CNN model and train it by

``` bash
mv nvnmd_cnn nvnmd_cnn_bck
cp train_cnn.json train_cnn2.json
# please edit train_cnn2.json
dp train-nvnmd train_cnn2.json -s s1 -i nvnmd_cnn_bck/model.ckpt
```


# Testing

Expand Down
1 change: 1 addition & 0 deletions examples/nvnmd/train/train_cnn.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{
"nvnmd": {
"version": 0,
"max_nnei": 128,
"net_size": 128,
"sel": [
60,
Expand Down
1 change: 1 addition & 0 deletions examples/nvnmd/train/train_qnn.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{
"nvnmd": {
"version": 0,
"max_nnei": 128,
"net_size": 128,
"sel": [
60,
Expand Down
4 changes: 3 additions & 1 deletion source/tests/test_nvnmd_entrypoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,7 @@ def test_mapt_cnn_v1(self):
map_file = str(tests_path / "nvnmd" / "out" / "map_v1_cnn.npy")
# mapt
mapObj = MapTable(config_file, weight_file, map_file)
mapObj.Gs_Gt_mode = 0
mapt = mapObj.build_map()
#
N = 32
Expand Down Expand Up @@ -859,8 +860,9 @@ def test_wrap_qnn_v1(self):
# test
data = FioBin().load(jdata["nvnmd_model"])
idx = [0, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384]
idx = [i + 128*4 for i in idx]
pred = [data[i] for i in idx]
red_dout = [1, 0, 0, 128, 0, 0, 0, 8, 249, 0, 0, 0, 91, 252, 183, 254]
red_dout = [249, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 254, 95, 24, 176]
np.testing.assert_equal(pred, red_dout)
# close
nvnmd_cfg.enable = False
Expand Down

0 comments on commit 7148e72

Please sign in to comment.