diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py index 7a2070e476..1675215d7b 100644 --- a/deepmd/pt/model/model/__init__.py +++ b/deepmd/pt/model/model/__init__.py @@ -14,6 +14,8 @@ import copy import json +import numpy as np + from deepmd.pt.model.atomic_model import ( DPAtomicModel, PairTabAtomicModel, @@ -57,6 +59,12 @@ def get_spin_model(model_params): model_params = copy.deepcopy(model_params) + if not model_params["spin"]["use_spin"] or isinstance( + model_params["spin"]["use_spin"][0], int + ): + use_spin = np.full(len(model_params["type_map"]), False) + use_spin[model_params["spin"]["use_spin"]] = True + model_params["spin"]["use_spin"] = use_spin.tolist() # include virtual spin and placeholder types model_params["type_map"] += [item + "_spin" for item in model_params["type_map"]] spin = Spin( diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 1c2b7935b5..6081839c34 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -93,7 +93,12 @@ def type_embedding_args(): def spin_args(): - doc_use_spin = "Whether to use atomic spin model for each atom type" + doc_use_spin = ( + "Whether to use atomic spin model for each atom type. " + "List of boolean values with the shape of [ntypes] to specify which types use spin, " + f"or a list of integer values {doc_only_pt_supported} " + "to indicate the index of the type that uses spin." + ) doc_spin_norm = "The magnitude of atomic spin for each atom type with spin" doc_virtual_len = "The distance between virtual atom representing spin and its corresponding real atom for each atom type with spin" doc_virtual_scale = ( @@ -106,7 +111,7 @@ def spin_args(): ) return [ - Argument("use_spin", List[bool], doc=doc_use_spin), + Argument("use_spin", [List[bool], List[int]], doc=doc_use_spin), Argument( "spin_norm", List[float], @@ -121,7 +126,7 @@ def spin_args(): ), Argument( "virtual_scale", - List[float], + [List[float], float], optional=True, doc=doc_only_pt_supported + doc_virtual_scale, ),