From b97571e4bb8b41717b3a819bcc1455b7c6bd3603 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Tue, 24 Sep 2024 14:03:44 +0800 Subject: [PATCH] fix bugs --- backend/dynamic_metadata.py | 2 +- backend/read_env.py | 2 +- deepmd/pd/model/descriptor/se_atten.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/backend/dynamic_metadata.py b/backend/dynamic_metadata.py index 83123e6e41..138375e072 100644 --- a/backend/dynamic_metadata.py +++ b/backend/dynamic_metadata.py @@ -36,7 +36,7 @@ def dynamic_metadata( settings: Optional[Dict[str, object]] = None, ): assert field in ["optional-dependencies", "entry-points", "scripts"] - _, _, find_libpython_requires, extra_scripts, tf_version, pt_version = ( + _, _, find_libpython_requires, extra_scripts, tf_version, pt_version, pd_version = ( get_argument_from_env() ) with Path("pyproject.toml").open("rb") as f: diff --git a/backend/read_env.py b/backend/read_env.py index 8c595c34c3..582b08e1bb 100644 --- a/backend/read_env.py +++ b/backend/read_env.py @@ -153,6 +153,6 @@ def get_argument_from_env() -> Tuple[str, list, list, dict, str, str]: def set_scikit_build_env(): """Set scikit-build environment variables before executing scikit-build.""" - cmake_minimum_required_version, cmake_args, _, _, _, _ = get_argument_from_env() + cmake_minimum_required_version, cmake_args, _, _, _, _, _ = get_argument_from_env() os.environ["SKBUILD_CMAKE_MINIMUM_VERSION"] = cmake_minimum_required_version os.environ["SKBUILD_CMAKE_ARGS"] = ";".join(cmake_args) diff --git a/deepmd/pd/model/descriptor/se_atten.py b/deepmd/pd/model/descriptor/se_atten.py index 2b9d150dbb..93fe052b06 100644 --- a/deepmd/pd/model/descriptor/se_atten.py +++ b/deepmd/pd/model/descriptor/se_atten.py @@ -484,9 +484,9 @@ def forward( ) # nb x nloc x nnei exclude_mask = self.emask(nlist, extended_atype) - nlist = paddle.where(exclude_mask != 0, nlist, -1) + nlist = paddle.where(exclude_mask != 0, nlist, paddle.full_like(nlist, -1)) nlist_mask = nlist != -1 - nlist = paddle.where(nlist == -1, 0, nlist) + nlist = paddle.where(nlist == -1, paddle.zeros_like(nlist), nlist) sw = paddle.squeeze(sw, -1) # nf x nloc x nt -> nf x nloc x nnei x nt atype_tebd = extended_atype_embd[:, :nloc, :]