Skip to content

Commit

Permalink
add distinguish types in base descriptor. fix docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
Han Wang committed Feb 1, 2024
1 parent f412bf7 commit 87fabfb
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 12 deletions.
13 changes: 10 additions & 3 deletions deepmd/dpmodel/descriptor/make_base_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,15 @@ def get_rcut(self) -> float:

@abstractmethod
def get_sel(self) -> List[int]:
"""Returns the number of selected atoms for each type."""
"""Returns the number of selected neighboring atoms for each type."""
pass

Check warning on line 40 in deepmd/dpmodel/descriptor/make_base_descriptor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/make_base_descriptor.py#L40

Added line #L40 was not covered by tests

def get_nsel(self) -> int:
"""Returns the total number of selected atoms in the cut-off radius."""
"""Returns the total number of selected neighboring atoms in the cut-off radius."""
return sum(self.get_sel())

Check warning on line 44 in deepmd/dpmodel/descriptor/make_base_descriptor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/make_base_descriptor.py#L44

Added line #L44 was not covered by tests

def get_nnei(self) -> int:
"""Returns the total number of neighboring atoms in the cut-off radius."""
"""Returns the total number of selected neighboring atoms in the cut-off radius."""
return self.get_nsel()

Check warning on line 48 in deepmd/dpmodel/descriptor/make_base_descriptor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/make_base_descriptor.py#L48

Added line #L48 was not covered by tests

@abstractmethod
Expand All @@ -62,6 +62,13 @@ def get_dim_emb(self) -> int:
"""Returns the embedding dimension of g2."""
pass

Check warning on line 63 in deepmd/dpmodel/descriptor/make_base_descriptor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/make_base_descriptor.py#L63

Added line #L63 was not covered by tests

@abstractmethod
def distinguish_types(self) -> bool:
"""Returns if the descriptor requires a neighbor list that distinguish different
atomic types or not.
"""
pass

Check warning on line 70 in deepmd/dpmodel/descriptor/make_base_descriptor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/make_base_descriptor.py#L70

Added line #L70 was not covered by tests

@abstractmethod
def compute_input_stats(self, merged):
"""Update mean and stddev for descriptor elements."""
Expand Down
4 changes: 2 additions & 2 deletions deepmd/dpmodel/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,8 @@ def get_sel(self):
return self.sel

def distinguish_types(self):
"""Returns if the descriptor uses different nets for
different atomic types.
"""Returns if the descriptor requires a neighbor list that distinguish different
atomic types or not.
"""
return True

Expand Down
4 changes: 3 additions & 1 deletion deepmd/dpmodel/model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ def get_sel(self) -> List[int]:
return self.descriptor.get_sel()

def distinguish_types(self) -> bool:
"""If distinguish different types by sorting."""
"""Returns if model requires a neighbor list that distinguish different
atomic types or not.
"""
return self.descriptor.distinguish_types()

def forward_atomic(
Expand Down
14 changes: 14 additions & 0 deletions deepmd/dpmodel/model/make_base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,32 @@ class BAM(ABC):

@abstractmethod
def fitting_output_def(self) -> FittingOutputDef:
"""Get the fitting output def."""
pass

Check warning on line 25 in deepmd/dpmodel/model/make_base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/make_base_atomic_model.py#L25

Added line #L25 was not covered by tests

@abstractmethod
def get_rcut(self) -> float:
"""Get the cut-off radius."""
pass

Check warning on line 30 in deepmd/dpmodel/model/make_base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/make_base_atomic_model.py#L30

Added line #L30 was not covered by tests

@abstractmethod
def get_sel(self) -> List[int]:
"""Returns the number of selected atoms for each type."""
pass

Check warning on line 35 in deepmd/dpmodel/model/make_base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/make_base_atomic_model.py#L35

Added line #L35 was not covered by tests

def get_nsel(self) -> int:
"""Returns the total number of selected neighboring atoms in the cut-off radius."""
return sum(self.get_sel())

Check warning on line 39 in deepmd/dpmodel/model/make_base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/make_base_atomic_model.py#L39

Added line #L39 was not covered by tests

def get_nnei(self) -> int:
"""Returns the total number of selected neighboring atoms in the cut-off radius."""
return self.get_nsel()

Check warning on line 43 in deepmd/dpmodel/model/make_base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/make_base_atomic_model.py#L43

Added line #L43 was not covered by tests

@abstractmethod
def distinguish_types(self) -> bool:
"""Returns if the model requires a neighbor list that distinguish different
atomic types or not.
"""
pass

Check warning on line 50 in deepmd/dpmodel/model/make_base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/make_base_atomic_model.py#L50

Added line #L50 was not covered by tests

@abstractmethod
Expand Down
4 changes: 2 additions & 2 deletions deepmd/pt/model/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ def get_dim_emb(self) -> int:
return self.se_atten.dim_emb

def distinguish_types(self) -> bool:
"""Returns if the descriptor uses different nets for
different atomic types.
"""Returns if the descriptor requires a neighbor list that distinguish different
atomic types or not.
"""
return False

Expand Down
4 changes: 2 additions & 2 deletions deepmd/pt/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,8 @@ def get_dim_emb(self) -> int:
return self.repformers.dim_emb

def distinguish_types(self) -> bool:
"""Returns if the descriptor uses different nets for
different atomic types.
"""Returns if the descriptor requires a neighbor list that distinguish different
atomic types or not.
"""
return False

Expand Down
4 changes: 2 additions & 2 deletions deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ def get_dim_emb(self) -> int:
return self.sea.get_dim_emb()

Check warning on line 94 in deepmd/pt/model/descriptor/se_a.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_a.py#L94

Added line #L94 was not covered by tests

def distinguish_types(self):
"""Returns if the descriptor uses different nets for
different atomic types.
"""Returns if the descriptor requires a neighbor list that distinguish different
atomic types or not.
"""
return True

Expand Down

0 comments on commit 87fabfb

Please sign in to comment.