Skip to content

Commit

Permalink
resolve conversations
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Feb 12, 2024
1 parent 5e8734c commit 3e4ed5f
Show file tree
Hide file tree
Showing 6 changed files with 13 additions and 23 deletions.
6 changes: 2 additions & 4 deletions deepmd/pt/model/descriptor/dpa1.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
from typing import (
List,
Optional,
Expand All @@ -18,8 +17,6 @@
DescrptBlockSeAtten,
)

log = logging.getLogger(__name__)


@Descriptor.register("dpa1")
@Descriptor.register("se_atten")
Expand Down Expand Up @@ -128,7 +125,7 @@ def compute_input_stats(self, merged):
def init_desc_stat(
self, sumr=None, suma=None, sumn=None, sumr2=None, suma2=None, **kwargs
):
assert True not in [x is None for x in [sumr, suma, sumn, sumr2, suma2]]
assert all(x is not None for x in [sumr, suma, sumn, sumr2, suma2])
self.se_atten.init_desc_stat(sumr, suma, sumn, sumr2, suma2)

@classmethod
Expand All @@ -141,6 +138,7 @@ def get_stat_name(
"""
descrpt_type = type_name
assert descrpt_type in ["dpa1", "se_atten"]
assert all(x is not None for x in [rcut, rcut_smth, sel])
return f"stat_file_descrpt_dpa1_rcut{rcut:.2f}_smth{rcut_smth:.2f}_sel{sel}_ntypes{ntypes}.npz"

@classmethod
Expand Down
11 changes: 4 additions & 7 deletions deepmd/pt/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
from typing import (
List,
Optional,
Expand Down Expand Up @@ -27,8 +26,6 @@
DescrptBlockSeAtten,
)

log = logging.getLogger(__name__)


@Descriptor.register("dpa2")
class DescrptDPA2(Descriptor):
Expand Down Expand Up @@ -316,7 +313,7 @@ def compute_input_stats(self, merged):
def init_desc_stat(
self, sumr=None, suma=None, sumn=None, sumr2=None, suma2=None, **kwargs
):
assert True not in [x is None for x in [sumr, suma, sumn, sumr2, suma2]]
assert all(x is not None for x in [sumr, suma, sumn, sumr2, suma2])
for ii, descrpt in enumerate([self.repinit, self.repformers]):
stat_dict_ii = {
"sumr": sumr[ii],
Expand Down Expand Up @@ -346,8 +343,8 @@ def get_stat_name(
"""
descrpt_type = type_name
assert descrpt_type in ["dpa2"]
assert True not in [
x is None
assert all(
x is not None
for x in [
repinit_rcut,
repinit_rcut_smth,
Expand All @@ -356,7 +353,7 @@ def get_stat_name(
repformer_rcut_smth,
repformer_nsel,
]
]
)
return (
f"stat_file_descrpt_dpa2_repinit_rcut{repinit_rcut:.2f}_smth{repinit_rcut_smth:.2f}_sel{repinit_nsel}"
f"_repformer_rcut{repformer_rcut:.2f}_smth{repformer_rcut_smth:.2f}_sel{repformer_nsel}_ntypes{ntypes}.npz"
Expand Down
6 changes: 3 additions & 3 deletions deepmd/pt/model/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,9 @@ def distinguish_types(self) -> bool:
"""Returns if the descriptor requires a neighbor list that distinguish different
atomic types or not.
"""
return True in [
return any(
descriptor.distinguish_types() for descriptor in self.descriptor_list
]
)

@property
def dim_out(self):
Expand Down Expand Up @@ -178,7 +178,7 @@ def compute_input_stats(self, merged):
def init_desc_stat(
self, sumr=None, suma=None, sumn=None, sumr2=None, suma2=None, **kwargs
):
assert True not in [x is None for x in [sumr, suma, sumn, sumr2, suma2]]
assert all(x is not None for x in [sumr, suma, sumn, sumr2, suma2])
for ii, descrpt in enumerate(self.descriptor_list):
stat_dict_ii = {
"sumr": sumr[ii],
Expand Down
9 changes: 2 additions & 7 deletions deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
from typing import (
ClassVar,
List,
Expand Down Expand Up @@ -41,10 +40,6 @@
process_input,
)

log = logging.getLogger(__name__)

log = logging.getLogger(__name__)


@Descriptor.register("se_e2_a")
class DescrptSeA(Descriptor):
Expand Down Expand Up @@ -119,7 +114,7 @@ def compute_input_stats(self, merged):
def init_desc_stat(
self, sumr=None, suma=None, sumn=None, sumr2=None, suma2=None, **kwargs
):
assert True not in [x is None for x in [sumr, suma, sumn, sumr2, suma2]]
assert all(x is not None for x in [sumr, suma, sumn, sumr2, suma2])
self.sea.init_desc_stat(sumr, suma, sumn, sumr2, suma2)

@classmethod
Expand All @@ -132,7 +127,7 @@ def get_stat_name(
"""
descrpt_type = type_name
assert descrpt_type in ["se_e2_a"]
assert True not in [x is None for x in [rcut, rcut_smth, sel]]
assert all(x is not None for x in [rcut, rcut_smth, sel])
return f"stat_file_descrpt_sea_rcut{rcut:.2f}_smth{rcut_smth:.2f}_sel{sel}_ntypes{ntypes}.npz"

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/task/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def compute_output_stats(self, merged):
return {"bias_atom_e": bias_atom_e}

def init_fitting_stat(self, bias_atom_e=None, **kwargs):
assert True not in [x is None for x in [bias_atom_e]]
assert all(x is not None for x in [bias_atom_e])
self.bias_atom_e.copy_(
torch.tensor(bias_atom_e, device=env.DEVICE).view(
[self.ntypes, self.dim_out]
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/utils/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,4 +105,4 @@ def process_stat_path(
has_stat_file_path_list = [
os.path.exists(stat_file_path[key]) for key in stat_file_dict
]
return stat_file_path, False not in has_stat_file_path_list
return stat_file_path, all(has_stat_file_path_list)

0 comments on commit 3e4ed5f

Please sign in to comment.