Skip to content

Commit

Permalink
pt: add index input for use_spin
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Mar 13, 2024
1 parent 36fdf53 commit ff8c4ad
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
8 changes: 8 additions & 0 deletions deepmd/pt/model/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import copy
import json

import numpy as np

Check warning on line 17 in deepmd/pt/model/model/__init__.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/__init__.py#L17

Added line #L17 was not covered by tests

from deepmd.pt.model.atomic_model import (
DPAtomicModel,
PairTabAtomicModel,
Expand Down Expand Up @@ -57,6 +59,12 @@

def get_spin_model(model_params):
model_params = copy.deepcopy(model_params)
if model_params["spin"]["use_spin"] and isinstance(

Check warning on line 62 in deepmd/pt/model/model/__init__.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/__init__.py#L62

Added line #L62 was not covered by tests
model_params["spin"]["use_spin"][0], int
):
use_spin = np.full(len(model_params["type_map"]), False)
use_spin[model_params["spin"]["use_spin"]] = True
model_params["spin"]["use_spin"] = use_spin.tolist()

Check warning on line 67 in deepmd/pt/model/model/__init__.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/__init__.py#L65-L67

Added lines #L65 - L67 were not covered by tests
# include virtual spin and placeholder types
model_params["type_map"] += [item + "_spin" for item in model_params["type_map"]]
spin = Spin(
Expand Down
11 changes: 8 additions & 3 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,12 @@ def type_embedding_args():


def spin_args():
doc_use_spin = "Whether to use atomic spin model for each atom type"
doc_use_spin = (

Check warning on line 96 in deepmd/utils/argcheck.py

View check run for this annotation

Codecov / codecov/patch

deepmd/utils/argcheck.py#L96

Added line #L96 was not covered by tests
"Whether to use atomic spin model for each atom type. "
"List of boolean values with the shape of [ntypes] to specify which types use spin, "
f"or a list of integer values {doc_only_pt_supported} "
"to indicate the index of the type that uses spin."
)
doc_spin_norm = "The magnitude of atomic spin for each atom type with spin"
doc_virtual_len = "The distance between virtual atom representing spin and its corresponding real atom for each atom type with spin"
doc_virtual_scale = (
Expand All @@ -106,7 +111,7 @@ def spin_args():
)

return [
Argument("use_spin", List[bool], doc=doc_use_spin),
Argument("use_spin", [List[bool], List[int]], doc=doc_use_spin),
Argument(
"spin_norm",
List[float],
Expand All @@ -121,7 +126,7 @@ def spin_args():
),
Argument(
"virtual_scale",
List[float],
[List[float], float],
optional=True,
doc=doc_only_pt_supported + doc_virtual_scale,
),
Expand Down

0 comments on commit ff8c4ad

Please sign in to comment.