diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index 722375fc68..76cff174af 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -1,5 +1,4 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import logging from typing import ( List, Optional, @@ -18,8 +17,6 @@ DescrptBlockSeAtten, ) -log = logging.getLogger(__name__) - @Descriptor.register("dpa1") @Descriptor.register("se_atten") @@ -128,7 +125,7 @@ def compute_input_stats(self, merged): def init_desc_stat( self, sumr=None, suma=None, sumn=None, sumr2=None, suma2=None, **kwargs ): - assert True not in [x is None for x in [sumr, suma, sumn, sumr2, suma2]] + assert all(x is not None for x in [sumr, suma, sumn, sumr2, suma2]) self.se_atten.init_desc_stat(sumr, suma, sumn, sumr2, suma2) @classmethod @@ -141,6 +138,7 @@ def get_stat_name( """ descrpt_type = type_name assert descrpt_type in ["dpa1", "se_atten"] + assert all(x is not None for x in [rcut, rcut_smth, sel]) return f"stat_file_descrpt_dpa1_rcut{rcut:.2f}_smth{rcut_smth:.2f}_sel{sel}_ntypes{ntypes}.npz" @classmethod diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index 05e7cec658..6cefaf6f38 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -1,5 +1,4 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import logging from typing import ( List, Optional, @@ -27,8 +26,6 @@ DescrptBlockSeAtten, ) -log = logging.getLogger(__name__) - @Descriptor.register("dpa2") class DescrptDPA2(Descriptor): @@ -316,7 +313,7 @@ def compute_input_stats(self, merged): def init_desc_stat( self, sumr=None, suma=None, sumn=None, sumr2=None, suma2=None, **kwargs ): - assert True not in [x is None for x in [sumr, suma, sumn, sumr2, suma2]] + assert all(x is not None for x in [sumr, suma, sumn, sumr2, suma2]) for ii, descrpt in enumerate([self.repinit, self.repformers]): stat_dict_ii = { "sumr": sumr[ii], @@ -346,8 +343,8 @@ def get_stat_name( """ descrpt_type = type_name assert descrpt_type in ["dpa2"] - assert True not in [ - x is None + assert all( + x is not None for x in [ repinit_rcut, repinit_rcut_smth, @@ -356,7 +353,7 @@ def get_stat_name( repformer_rcut_smth, repformer_nsel, ] - ] + ) return ( f"stat_file_descrpt_dpa2_repinit_rcut{repinit_rcut:.2f}_smth{repinit_rcut_smth:.2f}_sel{repinit_nsel}" f"_repformer_rcut{repformer_rcut:.2f}_smth{repformer_rcut_smth:.2f}_sel{repformer_nsel}_ntypes{ntypes}.npz" diff --git a/deepmd/pt/model/descriptor/hybrid.py b/deepmd/pt/model/descriptor/hybrid.py index e944d1277b..c5c08c760d 100644 --- a/deepmd/pt/model/descriptor/hybrid.py +++ b/deepmd/pt/model/descriptor/hybrid.py @@ -107,9 +107,9 @@ def distinguish_types(self) -> bool: """Returns if the descriptor requires a neighbor list that distinguish different atomic types or not. """ - return True in [ + return any( descriptor.distinguish_types() for descriptor in self.descriptor_list - ] + ) @property def dim_out(self): @@ -178,7 +178,7 @@ def compute_input_stats(self, merged): def init_desc_stat( self, sumr=None, suma=None, sumn=None, sumr2=None, suma2=None, **kwargs ): - assert True not in [x is None for x in [sumr, suma, sumn, sumr2, suma2]] + assert all(x is not None for x in [sumr, suma, sumn, sumr2, suma2]) for ii, descrpt in enumerate(self.descriptor_list): stat_dict_ii = { "sumr": sumr[ii], diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index fcc906b248..ddd2b7ca73 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -1,5 +1,4 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import logging from typing import ( ClassVar, List, @@ -41,10 +40,6 @@ process_input, ) -log = logging.getLogger(__name__) - -log = logging.getLogger(__name__) - @Descriptor.register("se_e2_a") class DescrptSeA(Descriptor): @@ -119,7 +114,7 @@ def compute_input_stats(self, merged): def init_desc_stat( self, sumr=None, suma=None, sumn=None, sumr2=None, suma2=None, **kwargs ): - assert True not in [x is None for x in [sumr, suma, sumn, sumr2, suma2]] + assert all(x is not None for x in [sumr, suma, sumn, sumr2, suma2]) self.sea.init_desc_stat(sumr, suma, sumn, sumr2, suma2) @classmethod @@ -132,7 +127,7 @@ def get_stat_name( """ descrpt_type = type_name assert descrpt_type in ["se_e2_a"] - assert True not in [x is None for x in [rcut, rcut_smth, sel]] + assert all(x is not None for x in [rcut, rcut_smth, sel]) return f"stat_file_descrpt_sea_rcut{rcut:.2f}_smth{rcut_smth:.2f}_sel{sel}_ntypes{ntypes}.npz" @classmethod diff --git a/deepmd/pt/model/task/ener.py b/deepmd/pt/model/task/ener.py index c8ade925c0..1b3e2c3d65 100644 --- a/deepmd/pt/model/task/ener.py +++ b/deepmd/pt/model/task/ener.py @@ -223,7 +223,7 @@ def compute_output_stats(self, merged): return {"bias_atom_e": bias_atom_e} def init_fitting_stat(self, bias_atom_e=None, **kwargs): - assert True not in [x is None for x in [bias_atom_e]] + assert all(x is not None for x in [bias_atom_e]) self.bias_atom_e.copy_( torch.tensor(bias_atom_e, device=env.DEVICE).view( [self.ntypes, self.dim_out] diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index edcdf742db..76b2afe41b 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -105,4 +105,4 @@ def process_stat_path( has_stat_file_path_list = [ os.path.exists(stat_file_path[key]) for key in stat_file_dict ] - return stat_file_path, False not in has_stat_file_path_list + return stat_file_path, all(has_stat_file_path_list)