From ff8c4ad6cbc12bc20a61733ca66f129ddc680e76 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Wed, 13 Mar 2024 12:11:50 +0800 Subject: [PATCH 1/2] pt: add index input for `use_spin` --- deepmd/pt/model/model/__init__.py | 8 ++++++++ deepmd/utils/argcheck.py | 11 ++++++++--- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py index f93ec88bde..71e776da3a 100644 --- a/deepmd/pt/model/model/__init__.py +++ b/deepmd/pt/model/model/__init__.py @@ -14,6 +14,8 @@ import copy import json +import numpy as np + from deepmd.pt.model.atomic_model import ( DPAtomicModel, PairTabAtomicModel, @@ -57,6 +59,12 @@ def get_spin_model(model_params): model_params = copy.deepcopy(model_params) + if model_params["spin"]["use_spin"] and isinstance( + model_params["spin"]["use_spin"][0], int + ): + use_spin = np.full(len(model_params["type_map"]), False) + use_spin[model_params["spin"]["use_spin"]] = True + model_params["spin"]["use_spin"] = use_spin.tolist() # include virtual spin and placeholder types model_params["type_map"] += [item + "_spin" for item in model_params["type_map"]] spin = Spin( diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 57f1145d55..5c4d0820d3 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -93,7 +93,12 @@ def type_embedding_args(): def spin_args(): - doc_use_spin = "Whether to use atomic spin model for each atom type" + doc_use_spin = ( + "Whether to use atomic spin model for each atom type. " + "List of boolean values with the shape of [ntypes] to specify which types use spin, " + f"or a list of integer values {doc_only_pt_supported} " + "to indicate the index of the type that uses spin." + ) doc_spin_norm = "The magnitude of atomic spin for each atom type with spin" doc_virtual_len = "The distance between virtual atom representing spin and its corresponding real atom for each atom type with spin" doc_virtual_scale = ( @@ -106,7 +111,7 @@ def spin_args(): ) return [ - Argument("use_spin", List[bool], doc=doc_use_spin), + Argument("use_spin", [List[bool], List[int]], doc=doc_use_spin), Argument( "spin_norm", List[float], @@ -121,7 +126,7 @@ def spin_args(): ), Argument( "virtual_scale", - List[float], + [List[float], float], optional=True, doc=doc_only_pt_supported + doc_virtual_scale, ), From 919f8293cf1a3beb88d208b090cc3833fbdb1002 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Tue, 19 Mar 2024 18:19:30 +0800 Subject: [PATCH 2/2] Update __init__.py --- deepmd/pt/model/model/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py index 30ef20d0e1..1675215d7b 100644 --- a/deepmd/pt/model/model/__init__.py +++ b/deepmd/pt/model/model/__init__.py @@ -59,7 +59,7 @@ def get_spin_model(model_params): model_params = copy.deepcopy(model_params) - if model_params["spin"]["use_spin"] and isinstance( + if not model_params["spin"]["use_spin"] or isinstance( model_params["spin"]["use_spin"][0], int ): use_spin = np.full(len(model_params["type_map"]), False)