Skip to content

Commit

Permalink
pt: add index input for use_spin (#3456)
Browse files Browse the repository at this point in the history
It's convenient for spin multitask input file to handle different spin
inputs.
  • Loading branch information
iProzd authored Mar 20, 2024
1 parent 47366f6 commit 9c861c2
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
8 changes: 8 additions & 0 deletions deepmd/pt/model/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import copy
import json

import numpy as np

from deepmd.pt.model.atomic_model import (
DPAtomicModel,
PairTabAtomicModel,
Expand Down Expand Up @@ -57,6 +59,12 @@

def get_spin_model(model_params):
model_params = copy.deepcopy(model_params)
if not model_params["spin"]["use_spin"] or isinstance(
model_params["spin"]["use_spin"][0], int
):
use_spin = np.full(len(model_params["type_map"]), False)
use_spin[model_params["spin"]["use_spin"]] = True
model_params["spin"]["use_spin"] = use_spin.tolist()
# include virtual spin and placeholder types
model_params["type_map"] += [item + "_spin" for item in model_params["type_map"]]
spin = Spin(
Expand Down
11 changes: 8 additions & 3 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,12 @@ def type_embedding_args():


def spin_args():
doc_use_spin = "Whether to use atomic spin model for each atom type"
doc_use_spin = (
"Whether to use atomic spin model for each atom type. "
"List of boolean values with the shape of [ntypes] to specify which types use spin, "
f"or a list of integer values {doc_only_pt_supported} "
"to indicate the index of the type that uses spin."
)
doc_spin_norm = "The magnitude of atomic spin for each atom type with spin"
doc_virtual_len = "The distance between virtual atom representing spin and its corresponding real atom for each atom type with spin"
doc_virtual_scale = (
Expand All @@ -106,7 +111,7 @@ def spin_args():
)

return [
Argument("use_spin", List[bool], doc=doc_use_spin),
Argument("use_spin", [List[bool], List[int]], doc=doc_use_spin),
Argument(
"spin_norm",
List[float],
Expand All @@ -121,7 +126,7 @@ def spin_args():
),
Argument(
"virtual_scale",
List[float],
[List[float], float],
optional=True,
doc=doc_only_pt_supported + doc_virtual_scale,
),
Expand Down

0 comments on commit 9c861c2

Please sign in to comment.