Skip to content

Commit

Permalink
fix errors
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Feb 1, 2024
1 parent b428fc2 commit eb0b7d2
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 13 deletions.
1 change: 1 addition & 0 deletions deepmd/infer/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def detect_backend(filename: str) -> DPBackend:
filename : str
The model file name
"""
filename = str(filename).lower()
if filename.endswith(".pb"):
return DPBackend.TensorFlow
elif filename.endswith(".pth") or filename.endswith(".pt"):
Expand Down
6 changes: 3 additions & 3 deletions deepmd/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def get_has_efield(self):
return False

Check warning on line 239 in deepmd/infer/deep_eval.py

View check run for this annotation

Codecov / codecov/patch

deepmd/infer/deep_eval.py#L239

Added line #L239 was not covered by tests

@abstractmethod
def get_ntypes_spin(self):
def get_ntypes_spin(self) -> int:
"""Get the number of spin atom types of this model."""


Expand Down Expand Up @@ -458,10 +458,10 @@ def _get_sel_natoms(self, atype) -> int:
return np.sum(np.isin(atype, self.get_sel_type()).astype(int))

Check warning on line 458 in deepmd/infer/deep_eval.py

View check run for this annotation

Codecov / codecov/patch

deepmd/infer/deep_eval.py#L458

Added line #L458 was not covered by tests

@property
def has_efield(self):
def has_efield(self) -> bool:
"""Check if the model has efield."""
return self.deep_eval.get_has_efield()

Check warning on line 463 in deepmd/infer/deep_eval.py

View check run for this annotation

Codecov / codecov/patch

deepmd/infer/deep_eval.py#L463

Added line #L463 was not covered by tests

def get_ntypes_spin(self):
def get_ntypes_spin(self) -> int:
"""Get the number of spin atom types of this model."""
return self.deep_eval.get_ntypes_spin()

Check warning on line 467 in deepmd/infer/deep_eval.py

View check run for this annotation

Codecov / codecov/patch

deepmd/infer/deep_eval.py#L467

Added line #L467 was not covered by tests
12 changes: 11 additions & 1 deletion deepmd/infer/deep_pot.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,17 @@ def eval(
virial = results["energy_derv_c_redu"].reshape(nframes, 9)

Check warning on line 148 in deepmd/infer/deep_pot.py

View check run for this annotation

Codecov / codecov/patch

deepmd/infer/deep_pot.py#L146-L148

Added lines #L146 - L148 were not covered by tests

if atomic:
atomic_energy = results["energy"].reshape(nframes, natoms, 1)
if self.get_ntypes_spin() > 0:
ntypes_real = self.get_ntypes() - self.get_ntypes_spin()
natoms_real = sum(

Check warning on line 153 in deepmd/infer/deep_pot.py

View check run for this annotation

Codecov / codecov/patch

deepmd/infer/deep_pot.py#L150-L153

Added lines #L150 - L153 were not covered by tests
[
np.count_nonzero(np.array(atom_types[0]) == ii)
for ii in range(ntypes_real)
]
)
else:
natoms_real = natoms
atomic_energy = results["energy"].reshape(nframes, natoms_real, 1)
atomic_virial = results["energy_derv_c"].reshape(nframes, natoms, 9)
return (

Check warning on line 163 in deepmd/infer/deep_pot.py

View check run for this annotation

Codecov / codecov/patch

deepmd/infer/deep_pot.py#L160-L163

Added lines #L160 - L163 were not covered by tests
energy,
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def get_sel_type(self) -> List[int]:

def get_numb_dos(self) -> int:

Check warning on line 113 in deepmd/pt/infer/deep_eval.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/infer/deep_eval.py#L113

Added line #L113 was not covered by tests
"""Get the number of DOS."""
raise 0
return 0

Check warning on line 115 in deepmd/pt/infer/deep_eval.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/infer/deep_eval.py#L115

Added line #L115 was not covered by tests

def get_has_efield(self):

Check warning on line 117 in deepmd/pt/infer/deep_eval.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/infer/deep_eval.py#L117

Added line #L117 was not covered by tests
"""Check if the model has efield."""
Expand Down
12 changes: 5 additions & 7 deletions deepmd/tf/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def _init_tensors(self):
"efield": "t_efield:0",
"fparam": "t_fparam:0",
"aparam": "t_aparam:0",
"ntypes_spin": "descrpt_attr/ntypes_spin:0",
"ntypes_spin": "spin_attr/ntypes_spin:0",
# descriptor
"descriptor": "o_descriptor:0",
}
Expand Down Expand Up @@ -435,7 +435,7 @@ def sort_input(
"""
natoms = atom_type.shape[1]
if sel_atoms is not None:
selection = [False] * natoms
selection = np.array([False] * natoms, dtype=bool)
for ii in sel_atoms:
selection += atom_type[0] == ii
sel_atom_type = atom_type[:, selection]
Expand Down Expand Up @@ -628,7 +628,7 @@ def get_ntypes(self) -> int:
"""Get the number of atom types of this model."""
return self.ntypes

Check warning on line 629 in deepmd/tf/infer/deep_eval.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/infer/deep_eval.py#L629

Added line #L629 was not covered by tests

def get_ntypes_spin(self):
def get_ntypes_spin(self) -> int:
"""Get the number of spin atom types of this model."""
return self.ntypes_spin

Check warning on line 633 in deepmd/tf/infer/deep_eval.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/infer/deep_eval.py#L633

Added line #L633 was not covered by tests

Expand Down Expand Up @@ -770,7 +770,6 @@ def eval(
output_dict["energy_redu"] += me.reshape(e.shape)
output_dict["energy_deri_r"] += mf.reshape(f.shape)
output_dict["energy_deri_c_redu"] += mv.reshape(v.shape)
output = tuple(output)
return output_dict

Check warning on line 773 in deepmd/tf/infer/deep_eval.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/infer/deep_eval.py#L764-L773

Added lines #L764 - L773 were not covered by tests

def _prepare_feed_dict(
Expand All @@ -787,7 +786,7 @@ def _prepare_feed_dict(
coords,
atom_types,
)
atom_types = np.array(atom_types, dtype=int).reshape([-1, natoms])
atom_types = np.array(atom_types, dtype=int).reshape([nframes, natoms])
coords = np.reshape(np.array(coords), [nframes, natoms * 3])
if cells is None:
pbc = False

Check warning on line 792 in deepmd/tf/infer/deep_eval.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/infer/deep_eval.py#L789-L792

Added lines #L789 - L792 were not covered by tests
Expand Down Expand Up @@ -942,7 +941,7 @@ def _eval_inner(
ntypes_real = self.ntypes - self.ntypes_spin
natoms_real = sum(

Check warning on line 942 in deepmd/tf/infer/deep_eval.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/infer/deep_eval.py#L937-L942

Added lines #L937 - L942 were not covered by tests
[
np.count_nonzero(np.array(atom_types) == ii)
np.count_nonzero(np.array(atom_types[0]) == ii)
for ii in range(ntypes_real)
]
)
Expand Down Expand Up @@ -977,7 +976,6 @@ def _eval_inner(
odef_shape = self._get_output_shape(

Check warning on line 976 in deepmd/tf/infer/deep_eval.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/infer/deep_eval.py#L976

Added line #L976 was not covered by tests
odef.name, nframes, natoms_real, odef.shape
)
# tmp_shape = [np.prod(odef_shape[:-2]), *odef_shape[-2:]]
v_out[ii] = self.reverse_map(

Check warning on line 979 in deepmd/tf/infer/deep_eval.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/infer/deep_eval.py#L979

Added line #L979 was not covered by tests
np.reshape(v_out[ii], odef_shape), sel_imap[:natoms_real]
)
Expand Down
10 changes: 9 additions & 1 deletion deepmd/tf/model/frozen.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
Union,
)

from deepmd.infer.deep_pot import (

Check warning on line 10 in deepmd/tf/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/model/frozen.py#L10

Added line #L10 was not covered by tests
DeepPot,
)
from deepmd.tf.env import (
GLOBAL_TF_FLOAT_PRECISION,
MODEL_VERSION,
Expand Down Expand Up @@ -40,7 +43,12 @@ def __init__(self, model_file: str, **kwargs):
super().__init__(**kwargs)
self.model_file = model_file
self.model = DeepPotential(model_file)
self.model_type = self.model.model_type
if isinstance(self.model, DeepPot):
self.model_type = "ener"

Check warning on line 47 in deepmd/tf/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/model/frozen.py#L46-L47

Added lines #L46 - L47 were not covered by tests
else:
raise NotImplementedError(

Check warning on line 49 in deepmd/tf/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/model/frozen.py#L49

Added line #L49 was not covered by tests
"This model type has not been implemented. " "Contribution is welcome!"
)

def build(
self,
Expand Down

0 comments on commit eb0b7d2

Please sign in to comment.