Skip to content

Commit

Permalink
chore: optimize compute_smooth_weight (deepmodeling#4390)
Browse files Browse the repository at this point in the history
New implmenetation is obviously more efficient.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Enhanced the `compute_smooth_weight` functionality for improved
efficiency and clarity by simplifying the distance handling logic.
- Introduced configuration for the `array_api_strict` module to ensure
compatibility with the latest API version.

- **Bug Fixes**
- Removed unnecessary masking conditions, ensuring smoother calculations
within defined distance ranges.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
njzjz and pre-commit-ci[bot] authored Nov 21, 2024
1 parent e9f9321 commit abb32d1
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 13 deletions.
13 changes: 5 additions & 8 deletions deepmd/dpmodel/utils/env_mat.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)


@support_array_api(version="2022.12")
@support_array_api(version="2023.12")
def compute_smooth_weight(
distance: np.ndarray,
rmin: float,
Expand All @@ -25,14 +25,11 @@ def compute_smooth_weight(
if rmin >= rmax:
raise ValueError("rmin should be less than rmax.")
xp = array_api_compat.array_namespace(distance)
min_mask = distance <= rmin
max_mask = distance >= rmax
mid_mask = xp.logical_not(xp.logical_or(min_mask, max_mask))
distance = xp.clip(distance, min=rmin, max=rmax)
uu = (distance - rmin) / (rmax - rmin)
vv = uu * uu * uu * (-6.0 * uu * uu + 15.0 * uu - 10.0) + 1.0
return vv * xp.astype(mid_mask, distance.dtype) + xp.astype(
min_mask, distance.dtype
)
uu2 = uu * uu
vv = uu2 * uu * (-6.0 * uu2 + 15.0 * uu - 10.0) + 1.0
return vv


def _make_env_mat(
Expand Down
9 changes: 4 additions & 5 deletions deepmd/pt/utils/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@ def compute_smooth_weight(distance, rmin: float, rmax: float):
"""Compute smooth weight for descriptor elements."""
if rmin >= rmax:
raise ValueError("rmin should be less than rmax.")
min_mask = distance <= rmin
max_mask = distance >= rmax
mid_mask = torch.logical_not(torch.logical_or(min_mask, max_mask))
distance = torch.clamp(distance, min=rmin, max=rmax)
uu = (distance - rmin) / (rmax - rmin)
vv = uu * uu * uu * (-6 * uu * uu + 15 * uu - 10) + 1
return vv * mid_mask + min_mask
uu2 = uu * uu
vv = uu2 * uu * (-6 * uu2 + 15 * uu - 10) + 1
return vv
6 changes: 6 additions & 0 deletions source/tests/array_api_strict/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,8 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Synchronize with deepmd.jax for test purpose only."""

import array_api_strict

# this is the default version in the latest array_api_strict,
# but in old versions it may be 2022.12
array_api_strict.set_array_api_strict_flags(api_version="2023.12")

0 comments on commit abb32d1

Please sign in to comment.