From abb32d1a0829bb8d4b81a934b2dc94e2acecd5c1 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 21 Nov 2024 14:05:44 -0500 Subject: [PATCH] chore: optimize compute_smooth_weight (#4390) New implmenetation is obviously more efficient. ## Summary by CodeRabbit - **New Features** - Enhanced the `compute_smooth_weight` functionality for improved efficiency and clarity by simplifying the distance handling logic. - Introduced configuration for the `array_api_strict` module to ensure compatibility with the latest API version. - **Bug Fixes** - Removed unnecessary masking conditions, ensuring smoother calculations within defined distance ranges. --------- Signed-off-by: Jinzhe Zeng Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- deepmd/dpmodel/utils/env_mat.py | 13 +++++-------- deepmd/pt/utils/preprocess.py | 9 ++++----- source/tests/array_api_strict/__init__.py | 6 ++++++ 3 files changed, 15 insertions(+), 13 deletions(-) diff --git a/deepmd/dpmodel/utils/env_mat.py b/deepmd/dpmodel/utils/env_mat.py index abbd68945b..bcecf62775 100644 --- a/deepmd/dpmodel/utils/env_mat.py +++ b/deepmd/dpmodel/utils/env_mat.py @@ -15,7 +15,7 @@ ) -@support_array_api(version="2022.12") +@support_array_api(version="2023.12") def compute_smooth_weight( distance: np.ndarray, rmin: float, @@ -25,14 +25,11 @@ def compute_smooth_weight( if rmin >= rmax: raise ValueError("rmin should be less than rmax.") xp = array_api_compat.array_namespace(distance) - min_mask = distance <= rmin - max_mask = distance >= rmax - mid_mask = xp.logical_not(xp.logical_or(min_mask, max_mask)) + distance = xp.clip(distance, min=rmin, max=rmax) uu = (distance - rmin) / (rmax - rmin) - vv = uu * uu * uu * (-6.0 * uu * uu + 15.0 * uu - 10.0) + 1.0 - return vv * xp.astype(mid_mask, distance.dtype) + xp.astype( - min_mask, distance.dtype - ) + uu2 = uu * uu + vv = uu2 * uu * (-6.0 * uu2 + 15.0 * uu - 10.0) + 1.0 + return vv def _make_env_mat( diff --git a/deepmd/pt/utils/preprocess.py b/deepmd/pt/utils/preprocess.py index 7d5b0cf314..8ab489dede 100644 --- a/deepmd/pt/utils/preprocess.py +++ b/deepmd/pt/utils/preprocess.py @@ -10,9 +10,8 @@ def compute_smooth_weight(distance, rmin: float, rmax: float): """Compute smooth weight for descriptor elements.""" if rmin >= rmax: raise ValueError("rmin should be less than rmax.") - min_mask = distance <= rmin - max_mask = distance >= rmax - mid_mask = torch.logical_not(torch.logical_or(min_mask, max_mask)) + distance = torch.clamp(distance, min=rmin, max=rmax) uu = (distance - rmin) / (rmax - rmin) - vv = uu * uu * uu * (-6 * uu * uu + 15 * uu - 10) + 1 - return vv * mid_mask + min_mask + uu2 = uu * uu + vv = uu2 * uu * (-6 * uu2 + 15 * uu - 10) + 1 + return vv diff --git a/source/tests/array_api_strict/__init__.py b/source/tests/array_api_strict/__init__.py index 27785c2fd5..27f15682e0 100644 --- a/source/tests/array_api_strict/__init__.py +++ b/source/tests/array_api_strict/__init__.py @@ -1,2 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later """Synchronize with deepmd.jax for test purpose only.""" + +import array_api_strict + +# this is the default version in the latest array_api_strict, +# but in old versions it may be 2022.12 +array_api_strict.set_array_api_strict_flags(api_version="2023.12")