Skip to content

Commit

Permalink
fix(tf): fix bugs in tensor training and migrate to reformat data (#3581
Browse files Browse the repository at this point in the history
)

Fix #3499.

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Mar 25, 2024
1 parent 48f06fe commit 3c3e2ce
Show file tree
Hide file tree
Showing 50 changed files with 70 additions and 35 deletions.
4 changes: 3 additions & 1 deletion deepmd/tf/fit/polar.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,9 @@ def compute_output_stats(self, all_stat):

polar_bias.append(
np.sum(
all_stat["atomic_polarizability"][ss][:, index_lis, :]
all_stat["atomic_polarizability"][ss].reshape(
nframes, len(atom_has_polar), -1
)[:, index_lis, :]
/ nframes,
axis=(0, 1),
).reshape((1, 9))
Expand Down
2 changes: 1 addition & 1 deletion deepmd/tf/model/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def data_stat(self, data):
all_stat = make_stat_input(data, self.data_stat_nbatch, merge_sys=False)
m_all_stat = merge_sys_stat(all_stat)
self._compute_input_stat(m_all_stat, protection=self.data_stat_protect)
self._compute_output_stat(all_stat)
self._compute_output_stat(m_all_stat)

def _compute_input_stat(self, all_stat, protection=1e-2):
self.descrpt.compute_input_stats(
Expand Down
3 changes: 2 additions & 1 deletion deepmd/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,7 @@ def _load_data(
pass
else:
sel_mask = np.isin(self.atom_type, type_sel)
data = data.reshape([nframes, natoms, ndof_])
data = data[:, sel_mask]
natoms = natoms_sel
idx_map = idx_map_sel
Expand All @@ -669,7 +670,7 @@ def _load_data(
elif must:
raise RuntimeError("%s not found!" % path)
else:
if type_sel is not None and not output_natoms_for_type_sel:
if atomic and type_sel is not None and not output_natoms_for_type_sel:
ndof = ndof_ * natoms_sel
data = np.full([nframes, ndof], default, dtype=dtype)
if repeat != 1:
Expand Down
8 changes: 4 additions & 4 deletions examples/water_tensor/dipole/dipole_input.json
Original file line number Diff line number Diff line change
Expand Up @@ -58,16 +58,16 @@
"training": {
"training_data": {
"systems": [
"./training_data/atomic_system",
"./training_data/global_system"
"./training_data_reformat/atomic_system",
"./training_data_reformat/global_system"
],
"batch_size": "auto",
"_comment8": "that's all"
},
"validation_data": {
"systems": [
"./validation_data/atomic_system",
"./validation_data/global_system"
"./validation_data_reformat/atomic_system",
"./validation_data_reformat/global_system"
],
"batch_size": 1,
"numb_btch": 3,
Expand Down
Empty file.
Binary file not shown.
Binary file not shown.
Binary file not shown.

This file was deleted.

This file was deleted.

Empty file.
Binary file not shown.
Binary file not shown.
Binary file not shown.

This file was deleted.

This file was deleted.

Empty file.
Binary file not shown.
Binary file not shown.
Binary file not shown.

This file was deleted.

This file was deleted.

Empty file.
Binary file not shown.
Binary file not shown.
Binary file not shown.

This file was deleted.

This file was deleted.

8 changes: 4 additions & 4 deletions examples/water_tensor/polar/polar_input.json
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,16 @@
"training": {
"training_data": {
"systems": [
"./training_data/atomic_system",
"./training_data/global_system"
"./training_data_reformat/atomic_system",
"./training_data_reformat/global_system"
],
"batch_size": "auto",
"_comment8": "that's all"
},
"validation_data": {
"systems": [
"./validation_data/atomic_system",
"./validation_data/global_system"
"./validation_data_reformat/atomic_system",
"./validation_data_reformat/global_system"
],
"batch_size": 1,
"numb_btch": 3,
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.

This file was deleted.

This file was deleted.

Binary file not shown.
Binary file not shown.
Binary file not shown.

This file was deleted.

This file was deleted.

Binary file not shown.
Binary file not shown.
Binary file not shown.

This file was deleted.

This file was deleted.

Binary file not shown.
Binary file not shown.
Binary file not shown.

This file was deleted.

This file was deleted.

56 changes: 56 additions & 0 deletions source/tests/tf/test_polar_se_a.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from pathlib import (
Path,
)

import numpy as np

from deepmd.tf.common import (
Expand All @@ -16,6 +20,9 @@
from deepmd.tf.model import (
PolarModel,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)

from .common import (
DataSystem,
Expand Down Expand Up @@ -198,3 +205,52 @@ def test_model(self):
# make sure atomic virial sum to virial
places = 10
np.testing.assert_almost_equal(pv, spv, places)

def test_data_stat(self):
jfile = "polar_se_a.json"
jdata = j_loader(jfile)

systems = [
str(
Path(__file__).parent.parent
/ "pt"
/ "water_tensor"
/ "polar"
/ "global_system"
),
str(
Path(__file__).parent.parent
/ "pt"
/ "water_tensor"
/ "polar"
/ "atomic_system"
),
]

batch_size = 1
test_size = 1
rcut = j_must_have(jdata["model"]["descriptor"], "rcut")

data = DeepmdDataSystem(systems, batch_size, test_size, rcut)
data.add(
"atomic_polarizability",
9,
atomic=True,
type_sel=jdata["model"]["fitting_net"]["sel_type"],
)
data.add(
"polarizability",
9,
atomic=False,
)

jdata["model"]["descriptor"].pop("type", None)
jdata["model"]["fitting_net"].pop("type", None)
descrpt = DescrptSeA(**jdata["model"]["descriptor"], uniform_seed=True)
jdata["model"]["fitting_net"]["ntypes"] = descrpt.get_ntypes()
jdata["model"]["fitting_net"]["dim_descrpt"] = descrpt.get_dim_out()
jdata["model"]["fitting_net"]["embedding_width"] = descrpt.get_dim_rot_mat_1()
fitting = PolarFittingSeA(**jdata["model"]["fitting_net"], uniform_seed=True)
model = PolarModel(descrpt, fitting)

model.data_stat(data)

0 comments on commit 3c3e2ce

Please sign in to comment.