Skip to content

Commit

Permalink
move has_spin to has_spin_pt in pt deep_eval
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Feb 27, 2024
1 parent 98d8232 commit b8e1dab
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions deepmd/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def __init__(
neighbor_list=neighbor_list,
**kwargs,
)
if getattr(self.deep_eval, "has_spin", False) and hasattr(
if getattr(self.deep_eval, "has_spin_pt", False) and hasattr(
self, "output_def_mag"
):
self.deep_eval.output_def = self.output_def_mag
Expand Down Expand Up @@ -500,7 +500,7 @@ def has_efield(self) -> bool:
@property
def has_spin(self) -> bool:
"""Check if the model has spin."""
return getattr(self.deep_eval, "has_spin", False)
return getattr(self.deep_eval, "has_spin_pt", False)

def get_ntypes_spin(self) -> int:
"""Get the number of spin atom types of this model."""
Expand Down
2 changes: 1 addition & 1 deletion deepmd/infer/deep_pot.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def eval(
force,
virial,
)
if getattr(self.deep_eval, "has_spin", False):
if getattr(self.deep_eval, "has_spin_pt", False):
force_mag = results["energy_derv_r_mag"].reshape(nframes, natoms, 3)
result = result + tuple(force_mag)
if atomic:
Expand Down
6 changes: 3 additions & 3 deletions deepmd/pt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,9 @@ def __init__(
self.auto_batch_size = auto_batch_size
else:
raise TypeError("auto_batch_size should be bool, int, or AutoBatchSize")
self.has_spin = getattr(self.dp.model["Default"], "has_spin", False)
if callable(self.has_spin):
self.has_spin = self.has_spin()
self.has_spin_pt = getattr(self.dp.model["Default"], "has_spin", False)
if callable(self.has_spin_pt):
self.has_spin_pt = self.has_spin_pt()

def get_rcut(self) -> float:
"""Get the cutoff radius of this model."""
Expand Down

0 comments on commit b8e1dab

Please sign in to comment.