Skip to content

Commit

Permalink
Fix uts
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Feb 7, 2024
1 parent f6dc9c1 commit c1369c9
Show file tree
Hide file tree
Showing 8 changed files with 82 additions and 45 deletions.
15 changes: 8 additions & 7 deletions deepmd/pt/model/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,13 +304,14 @@ def share_params(self, base_class, shared_level, resume=False):
self.sumr2,
self.suma2,
)
base_class.init_desc_stat(
sumr_base + sumr,
suma_base + suma,
sumn_base + sumn,
sumr2_base + sumr2,
suma2_base + suma2,
)
stat_dict = {
"sumr": sumr_base + sumr,
"suma": suma_base + suma,
"sumn": sumn_base + sumn,
"sumr2": sumr2_base + sumr2,
"suma2": suma2_base + suma2,
}
base_class.init_desc_stat(stat_dict)
self.mean = base_class.mean
self.stddev = base_class.stddev
# self.load_state_dict(base_class.state_dict()) # this does not work, because it only inits the model
Expand Down
27 changes: 14 additions & 13 deletions deepmd/pt/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,18 +299,12 @@ def compute_input_stats(self, merged):
}
for item in merged
]
(
sumr_tmp,
suma_tmp,
sumn_tmp,
sumr2_tmp,
suma2_tmp,
) = descrpt.compute_input_stats(merged_tmp)
sumr.append(sumr_tmp["sumr"])
suma.append(suma_tmp["suma"])
sumn.append(sumn_tmp["sumn"])
sumr2.append(sumr2_tmp["sumr2"])
suma2.append(suma2_tmp["suma2"])
tmp_stat_dict = descrpt.compute_input_stats(merged_tmp)
sumr.append(tmp_stat_dict["sumr"])
suma.append(tmp_stat_dict["suma"])
sumn.append(tmp_stat_dict["sumn"])
sumr2.append(tmp_stat_dict["sumr2"])
suma2.append(tmp_stat_dict["suma2"])
return {
"sumr": sumr,
"suma": suma,
Expand All @@ -328,7 +322,14 @@ def init_desc_stat(self, stat_dict):
sumr2 = stat_dict["sumr2"]
suma2 = stat_dict["suma2"]
for ii, descrpt in enumerate([self.repinit, self.repformers]):
descrpt.init_desc_stat(sumr[ii], suma[ii], sumn[ii], sumr2[ii], suma2[ii])
stat_dict_ii = {
"sumr": sumr[ii],
"suma": suma[ii],
"sumn": sumn[ii],
"sumr2": sumr2[ii],
"suma2": suma2[ii],
}
descrpt.init_desc_stat(stat_dict_ii)

@classmethod
def get_stat_name(cls, config, ntypes):
Expand Down
44 changes: 29 additions & 15 deletions deepmd/pt/model/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,23 +153,37 @@ def compute_input_stats(self, merged):
}
for item in merged
]
(
sumr_tmp,
suma_tmp,
sumn_tmp,
sumr2_tmp,
suma2_tmp,
) = descrpt.compute_input_stats(merged_tmp)
sumr.append(sumr_tmp)
suma.append(suma_tmp)
sumn.append(sumn_tmp)
sumr2.append(sumr2_tmp)
suma2.append(suma2_tmp)
return sumr, suma, sumn, sumr2, suma2
tmp_stat_dict = descrpt.compute_input_stats(merged_tmp)
sumr.append(tmp_stat_dict["sumr"])
suma.append(tmp_stat_dict["suma"])
sumn.append(tmp_stat_dict["sumn"])
sumr2.append(tmp_stat_dict["sumr2"])
suma2.append(tmp_stat_dict["suma2"])
return {
"sumr": sumr,
"suma": suma,
"sumn": sumn,
"sumr2": sumr2,
"suma2": suma2,
}

def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2):
def init_desc_stat(self, stat_dict):
for key in ["sumr", "suma", "sumn", "sumr2", "suma2"]:
assert key in stat_dict, f"Statistics {key} not found in the dictionary!"
sumr = stat_dict["sumr"]
suma = stat_dict["suma"]
sumn = stat_dict["sumn"]
sumr2 = stat_dict["sumr2"]
suma2 = stat_dict["suma2"]
for ii, descrpt in enumerate(self.descriptor_list):
descrpt.init_desc_stat(sumr[ii], suma[ii], sumn[ii], sumr2[ii], suma2[ii])
stat_dict_ii = {
"sumr": sumr[ii],
"suma": suma[ii],
"sumn": sumn[ii],
"sumr2": sumr2[ii],
"suma2": suma2[ii],
}
descrpt.init_desc_stat(stat_dict_ii)

def forward(
self,
Expand Down
17 changes: 15 additions & 2 deletions deepmd/pt/model/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,9 +321,22 @@ def compute_input_stats(self, merged):
sumn = np.sum(sumn, axis=0)
sumr2 = np.sum(sumr2, axis=0)
suma2 = np.sum(suma2, axis=0)
return sumr, suma, sumn, sumr2, suma2
return {
"sumr": sumr,
"suma": suma,
"sumn": sumn,
"sumr2": sumr2,
"suma2": suma2,
}

def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2):
def init_desc_stat(self, stat_dict):
for key in ["sumr", "suma", "sumn", "sumr2", "suma2"]:
assert key in stat_dict, f"Statistics {key} not found in the dictionary!"
sumr = stat_dict["sumr"]
suma = stat_dict["suma"]
sumn = stat_dict["sumn"]
sumr2 = stat_dict["sumr2"]
suma2 = stat_dict["suma2"]
all_davg = []
all_dstd = []
for type_i in range(self.ntypes):
Expand Down
4 changes: 2 additions & 2 deletions deepmd/pt/model/task/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
PRECISION_DICT,
)
from deepmd.pt.utils.stat import (
compute_output_stats,
compute_output_bias,
)
from deepmd.pt.utils.utils import (
to_numpy_array,
Expand Down Expand Up @@ -213,7 +213,7 @@ def compute_output_stats(self, merged):
input_natoms = [item["real_natoms_vec"] for item in merged]
else:
input_natoms = [item["natoms"] for item in merged]
tmp = compute_output_stats(energy, input_natoms)
tmp = compute_output_bias(energy, input_natoms)
bias_atom_e = tmp[:, 0]
return {"bias_atom_e": bias_atom_e}

Expand Down
8 changes: 8 additions & 0 deletions deepmd/pt/model/task/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,14 @@ def share_params(self, base_class, shared_level, resume=False):
else:
raise NotImplementedError

def compute_output_stats(self, merged):
"""Update the output bias for fitting net."""
raise NotImplementedError

def init_fitting_stat(self, result_dict):
"""Initialize the model bias by the statistics."""
raise NotImplementedError

@classmethod
def get_stat_name(cls, config, ntypes):
"""
Expand Down
4 changes: 2 additions & 2 deletions deepmd/pt/utils/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ def make_stat_input(datasets, dataloaders, nbatches):
return lst


def compute_output_stats(energy, natoms, rcond=None):
"""Update mean and stddev for descriptor elements.
def compute_output_bias(energy, natoms, rcond=None):
"""Update output bias for fitting net.
Args:
- energy: Batched energy with shape [nframes, 1].
Expand Down
8 changes: 4 additions & 4 deletions source/tests/pt/test_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
DpLoaderSet,
)
from deepmd.pt.utils.stat import (
compute_output_stats,
compute_output_bias,
)
from deepmd.pt.utils.stat import make_stat_input as my_make
from deepmd.tf.common import (
Expand Down Expand Up @@ -124,7 +124,7 @@ def my_merge(energy, natoms):
energy, natoms = my_merge(energy, natoms)
dp_fn = EnerFitting(self.dp_d, self.n_neuron)
dp_fn.compute_output_stats(self.dp_sampled)
bias_atom_e = compute_output_stats(energy, natoms)
bias_atom_e = compute_output_bias(energy, natoms)
self.assertTrue(np.allclose(dp_fn.bias_atom_e, bias_atom_e[:, 0]))

# temporarily delete this function for performance of seeds in tf and pytorch may be different
Expand Down Expand Up @@ -172,8 +172,8 @@ def test_descriptor(self):
]:
if key in sys.keys():
sys[key] = sys[key].to(env.DEVICE)
sumr, suma, sumn, sumr2, suma2 = my_en.compute_input_stats(sampled)
my_en.init_desc_stat(sumr, suma, sumn, sumr2, suma2)
stat_dict = my_en.compute_input_stats(sampled)
my_en.init_desc_stat(stat_dict)
my_en.mean = my_en.mean
my_en.stddev = my_en.stddev
self.assertTrue(
Expand Down

0 comments on commit c1369c9

Please sign in to comment.