Skip to content

Commit

Permalink
Fix unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Jun 3, 2024
1 parent aa08e30 commit ddaa38d
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 5 deletions.
6 changes: 6 additions & 0 deletions deepmd/pt/model/task/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,12 @@ def serialize(self) -> dict:
def deserialize(cls) -> "EnergyFittingNetDirect":
    """Deserialize into an ``EnergyFittingNetDirect``.

    Not supported for this class; always raises.
    NOTE(review): the ``cls`` first parameter suggests this is a
    ``@classmethod`` whose decorator lies outside the visible span —
    confirm against the full file.

    Raises
    ------
    NotImplementedError
        Unconditionally.
    """
    raise NotImplementedError

def slim_type_map(self, type_map: List[str]) -> None:
    """Slim this fitting net to the given ``type_map``.

    Not supported for this class; always raises.

    Parameters
    ----------
    type_map : List[str]
        The target type map (unused; the call never succeeds).

    Raises
    ------
    NotImplementedError
        Unconditionally.
    """
    raise NotImplementedError

Check warning on line 187 in deepmd/pt/model/task/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/ener.py#L187

Added line #L187 was not covered by tests

def get_type_map(self) -> List[str]:
    """Return the type map of this fitting net.

    Not supported for this class; always raises.

    Raises
    ------
    NotImplementedError
        Unconditionally.
    """
    raise NotImplementedError

Check warning on line 190 in deepmd/pt/model/task/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/ener.py#L190

Added line #L190 was not covered by tests

def forward(
self,
inputs: torch.Tensor,
Expand Down
10 changes: 7 additions & 3 deletions deepmd/tf/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,10 @@ def __init__(
self.descrpt = descriptor
else:
self.descrpt = Descriptor(
**descriptor, ntypes=len(self.get_type_map()), spin=self.spin
**descriptor,
ntypes=len(self.get_type_map()),
spin=self.spin,
type_map=type_map,
)

if isinstance(fitting_net, Fitting):
Expand All @@ -672,6 +675,7 @@ def __init__(
ntypes=self.descrpt.get_ntypes(),
dim_descrpt=self.descrpt.get_dim_out(),
mixed_types=type_embedding is not None or self.descrpt.explicit_ntypes,
type_map=type_map,
)
self.rcut = self.descrpt.get_rcut()
self.ntypes = self.descrpt.get_ntypes()
Expand All @@ -680,12 +684,11 @@ def __init__(
if type_embedding is not None and isinstance(type_embedding, TypeEmbedNet):
self.typeebd = type_embedding
elif type_embedding is not None:
if type_embedding.get("use_econf_tebd", False):
type_embedding["type_map"] = type_map
self.typeebd = TypeEmbedNet(
ntypes=self.ntypes,
**type_embedding,
padding=self.descrpt.explicit_ntypes,
type_map=type_map,
)
elif self.descrpt.explicit_ntypes:
default_args = type_embedding_args()
Expand All @@ -695,6 +698,7 @@ def __init__(
ntypes=self.ntypes,
**default_args_dict,
padding=True,
type_map=type_map,
)
else:
self.typeebd = None
Expand Down
3 changes: 1 addition & 2 deletions deepmd/tf/model/pairwise_dprc.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,12 @@ def __init__(
if isinstance(type_embedding, TypeEmbedNet):
self.typeebd = type_embedding
else:
if type_embedding.get("use_econf_tebd", False):
type_embedding["type_map"] = type_map
self.typeebd = TypeEmbedNet(
ntypes=self.ntypes,
**type_embedding,
# must use se_atten, so it must be True
padding=True,
type_map=type_map,
)

self.qm_model = Model(
Expand Down
6 changes: 6 additions & 0 deletions source/tests/pt/model/test_linear_atomic_model_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,12 @@ def output_def(self):
def serialize(self) -> dict:
    """Serialize this object to a plain ``dict``.

    Intentionally unimplemented in this test stub; always raises.

    Raises
    ------
    NotImplementedError
        Unconditionally.
    """
    raise NotImplementedError

def slim_type_map(self, type_map: List[str]) -> None:
    """Slim this object to the given ``type_map``.

    Intentionally unimplemented in this test stub; always raises.

    Parameters
    ----------
    type_map : List[str]
        The target type map (unused; the call never succeeds).

    Raises
    ------
    NotImplementedError
        Unconditionally.
    """
    raise NotImplementedError

def get_type_map(self) -> List[str]:
    """Return the type map of this object.

    Intentionally unimplemented in this test stub; always raises.

    Raises
    ------
    NotImplementedError
        Unconditionally.
    """
    raise NotImplementedError

def forward(
self,
descriptor: torch.Tensor,
Expand Down

0 comments on commit ddaa38d

Please sign in to comment.