
Commit e957e10
move pairtab init to EnerModel - at this time no other models use it
njzjz authored May 18, 2024
1 parent 16550ee commit e957e10
Showing 2 changed files with 14 additions and 19 deletions.
10 changes: 10 additions & 0 deletions deepmd/tf/model/ener.py
@@ -106,6 +106,16 @@ def __init__(
         self.numb_fparam = self.fitting.get_numb_fparam()
         self.numb_aparam = self.fitting.get_numb_aparam()
 
+        self.srtab_name = use_srtab
+        if self.srtab_name is not None:
+            self.srtab = PairTab(self.srtab_name, rcut=self.get_rcut())
+            self.smin_alpha = smin_alpha
+            self.sw_rmin = sw_rmin
+            self.sw_rmax = sw_rmax
+            self.srtab_add_bias = srtab_add_bias
+        else:
+            self.srtab = None
+
     def get_rcut(self):
         return self.rcut
 
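For orientation, a hedged usage sketch (the file name and numeric values below are illustrative, not taken from the commit): after this change the short-range-table options are keyword arguments consumed by EnerModel itself, while other Model subclasses no longer accept them.

    # Illustrative values only; the descriptor and fitting objects are assumed
    # to be built elsewhere and are omitted here.
    srtab_kwargs = dict(
        use_srtab="pair_table.txt",  # tabulated short-range pairwise interaction
        smin_alpha=0.1,              # smoothing parameter for the soft-min switching
        sw_rmin=0.8,                 # lower bound of the table/DP interpolation region
        sw_rmax=1.0,                 # upper bound of the interpolation region
        srtab_add_bias=True,         # default value, as in the relocated initializer
    )
    # model = EnerModel(descriptor=..., fitting_net=..., type_map=..., **srtab_kwargs)
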
23 changes: 4 additions & 19 deletions deepmd/tf/model/model.py
@@ -116,11 +116,6 @@ def __init__(
         data_stat_nbatch: int = 10,
         data_bias_nsample: int = 10,
         data_stat_protect: float = 1e-2,
-        use_srtab: Optional[str] = None,
-        smin_alpha: Optional[float] = None,
-        sw_rmin: Optional[float] = None,
-        sw_rmax: Optional[float] = None,
-        srtab_add_bias: bool = True,
         spin: Optional[Spin] = None,
         compress: Optional[dict] = None,
         **kwargs,
@@ -142,15 +137,6 @@
         self.data_stat_nbatch = data_stat_nbatch
         self.data_bias_nsample = data_bias_nsample
         self.data_stat_protect = data_stat_protect
-        self.srtab_name = use_srtab
-        if self.srtab_name is not None:
-            self.srtab = PairTab(self.srtab_name, rcut=self.get_rcut())
-            self.smin_alpha = smin_alpha
-            self.sw_rmin = sw_rmin
-            self.sw_rmax = sw_rmax
-            self.srtab_add_bias = srtab_add_bias
-        else:
-            self.srtab = None
 
     def get_type_map(self) -> list:
         """Get the type map."""
@@ -649,6 +635,10 @@ def __init__(
         type_map: Optional[List[str]] = None,
         **kwargs,
     ) -> None:
+        super().__init__(
+            descriptor=descriptor, fitting=fitting_net, type_map=type_map, **kwargs
+        )
+
         if isinstance(descriptor, Descriptor):
             self.descrpt = descriptor
         else:
@@ -672,11 +662,6 @@
         self.rcut = self.descrpt.get_rcut()
         self.ntypes = self.descrpt.get_ntypes()
 
-        # need rcut
-        super().__init__(
-            descriptor=descriptor, fitting=fitting_net, type_map=type_map, **kwargs
-        )
-
         # type embedding
         if type_embedding is not None and isinstance(type_embedding, TypeEmbedNet):
             self.typeebd = type_embedding
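A plausible reading of the reordering in EnerModel.__init__ (inferred from the removed "# need rcut" comment, not stated in the commit message): the base Model.__init__ previously built the PairTab and therefore called self.get_rcut(), which forced EnerModel to set self.rcut before invoking super().__init__(). With the pair-table setup moved into EnerModel, the base initializer no longer touches the cutoff, so the super() call can come first. A simplified sketch of the old constraint, using stand-in classes rather than the real deepmd-kit ones:

    class Base:
        def __init__(self, use_srtab=None):
            if use_srtab is not None:
                # Old layout: the base class read the cutoff here, so subclasses
                # had to define self.rcut before calling super().__init__().
                self.table_rcut = self.get_rcut()


    class Ener(Base):
        def __init__(self, rcut, use_srtab=None):
            self.rcut = rcut                       # had to be set first ...
            super().__init__(use_srtab=use_srtab)  # ... so this call came last

        def get_rcut(self):
            return self.rcut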
