Skip to content

Commit

Permalink
fix bugs in type_idx_map (deepmodeling#1943)
Browse files Browse the repository at this point in the history
fix bugs in type_idx_map using np.searchsorted
  • Loading branch information
iProzd authored Sep 23, 2022
1 parent f0973f0 commit ffb2fd8
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 3 deletions.
3 changes: 2 additions & 1 deletion deepmd/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ def __init__ (self,
atom_type_ = [type_map.index(self.type_map[ii]) for ii in self.atom_type]
self.atom_type = np.array(atom_type_, dtype = np.int32)
else:
type_idx_map = np.searchsorted(type_map, self.type_map)
sorter = np.argsort(type_map)
type_idx_map = sorter[np.searchsorted(type_map, self.type_map, sorter=sorter)]
try:
atom_type_mix_ = np.array(type_idx_map)[self.atom_type_mix].astype(np.int32)
except RuntimeError as e:
Expand Down
3 changes: 2 additions & 1 deletion source/tests/test_data_large_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ def test_data_mixed_type(self):
batch_size = 1
test_size = 1
rcut = j_must_have(jdata['model']['descriptor'], 'rcut')
type_map = j_must_have(jdata['model'], 'type_map')

data = DeepmdDataSystem(systems, batch_size, test_size, rcut)
data = DeepmdDataSystem(systems, batch_size, test_size, rcut, type_map=type_map)
data_requirement = {'energy': {'ndof': 1,
'atomic': False,
'must': False,
Expand Down
2 changes: 1 addition & 1 deletion source/tests/water_se_atten_mixed_type.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"_comment": " model parameters",
"model" : {
"type_map": ["O", "H"],
"type_map": ["foo", "bar"],
"type_embedding":{
"neuron": [8],
"resnet_dt": false,
Expand Down

0 comments on commit ffb2fd8

Please sign in to comment.