From 9c861c24f1be3874b08fe36274d8f4afea5b89f1 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Wed, 20 Mar 2024 13:12:58 +0800 Subject: [PATCH] pt: add index input for `use_spin` (#3456) It's convenient for spin multitask input file to handle different spin inputs. --- deepmd/pt/model/model/__init__.py | 8 ++++++++ deepmd/utils/argcheck.py | 11 ++++++++--- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py index 7a2070e476..1675215d7b 100644 --- a/deepmd/pt/model/model/__init__.py +++ b/deepmd/pt/model/model/__init__.py @@ -14,6 +14,8 @@ import copy import json +import numpy as np + from deepmd.pt.model.atomic_model import ( DPAtomicModel, PairTabAtomicModel, @@ -57,6 +59,12 @@ def get_spin_model(model_params): model_params = copy.deepcopy(model_params) + if not model_params["spin"]["use_spin"] or isinstance( + model_params["spin"]["use_spin"][0], int + ): + use_spin = np.full(len(model_params["type_map"]), False) + use_spin[model_params["spin"]["use_spin"]] = True + model_params["spin"]["use_spin"] = use_spin.tolist() # include virtual spin and placeholder types model_params["type_map"] += [item + "_spin" for item in model_params["type_map"]] spin = Spin( diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 7a8e0e98cb..564039ccd0 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -93,7 +93,12 @@ def type_embedding_args(): def spin_args(): - doc_use_spin = "Whether to use atomic spin model for each atom type" + doc_use_spin = ( + "Whether to use atomic spin model for each atom type. " + "List of boolean values with the shape of [ntypes] to specify which types use spin, " + f"or a list of integer values {doc_only_pt_supported} " + "to indicate the index of the type that uses spin." + ) doc_spin_norm = "The magnitude of atomic spin for each atom type with spin" doc_virtual_len = "The distance between virtual atom representing spin and its corresponding real atom for each atom type with spin" doc_virtual_scale = ( @@ -106,7 +111,7 @@ def spin_args(): ) return [ - Argument("use_spin", List[bool], doc=doc_use_spin), + Argument("use_spin", [List[bool], List[int]], doc=doc_use_spin), Argument( "spin_norm", List[float], @@ -121,7 +126,7 @@ def spin_args(): ), Argument( "virtual_scale", - List[float], + [List[float], float], optional=True, doc=doc_only_pt_supported + doc_virtual_scale, ),