Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(dpmodel): fix precision #4343

Merged
merged 15 commits into from
Nov 14, 2024
Merged

Conversation

njzjz
Copy link
Member

@njzjz njzjz commented Nov 12, 2024

Summary by CodeRabbit

Release Notes

  • New Features

    • Introduced a new environment variable DP_DTYPE_PROMOTION_STRICT to enhance precision handling in TensorFlow tests.
    • Added a decorator @cast_precision to several descriptor classes, improving precision management during computations.
    • Updated JAX configuration to enable strict dtype promotion based on the new environment variable.
    • Enhanced serialization and deserialization processes to include precision attributes across multiple classes.
  • Bug Fixes

    • Enhanced type handling and input processing in the GeneralFitting class for better output predictions.
    • Improved handling of atomic contributions and exclusions in the BaseAtomicModel class.
    • Addressed potential type mismatches during matrix operations in the NativeLayer class.
  • Chores

    • Updated caching mechanisms in the testing workflow to ensure unique keys based on run parameters.

Copy link
Contributor

coderabbitai bot commented Nov 12, 2024

📝 Walkthrough
📝 Walkthrough

Walkthrough

The pull request introduces several updates across multiple files, primarily focusing on enhancing precision handling in descriptor computations and refining the Python testing workflow. Key changes include the addition of a new environment variable in the testing workflow, the implementation of a cast_precision decorator for managing precision in various descriptor classes, and modifications to the caching mechanism for test durations. These updates aim to improve the robustness and clarity of the code while maintaining existing functionalities.

Changes

File Path Change Summary
.github/workflows/test_python.yml Added merge_group for concurrency settings, introduced DP_DTYPE_PROMOTION_STRICT variable, refined caching keys, and updated upload processes.
deepmd/dpmodel/common.py Added cast_precision decorator and safe_cast_array function for precision handling in array operations.
deepmd/dpmodel/descriptor/dpa1.py Applied cast_precision decorator to call method, added precision attribute to DescrptDPA1 class.
deepmd/dpmodel/descriptor/dpa2.py Applied cast_precision decorator to call method, added precision parameter to DescrptDPA2 constructor.
deepmd/dpmodel/descriptor/se_e2_a.py Added cast_precision decorator to call methods of DescrptSeA and DescrptSeAArrayAPI.
deepmd/dpmodel/descriptor/se_r.py Added cast_precision decorator to call method of DescrptSeR, updated serialize method.
deepmd/dpmodel/descriptor/se_t.py Added cast_precision decorator to call method, removed explicit type casting for rr and result.
deepmd/dpmodel/descriptor/se_t_tebd.py Renamed get_xp_precision to cast_precision, updated call method and added precision attribute.
deepmd/dpmodel/fitting/general_fitting.py Enhanced type handling in _call_common, updated logic for output predictions, and strengthened error handling.
deepmd/jax/env.py Added conditional configuration for JAX based on DP_DTYPE_PROMOTION_STRICT environment variable.
deepmd/dpmodel/atomic_model/base_atomic_model.py Updated forward_common_atomic method to enhance atom exclusion handling and output reshaping.
deepmd/dpmodel/fitting/dipole_fitting.py Added cast_precision decorator to call method of DipoleFitting class.
deepmd/dpmodel/fitting/invar_fitting.py Added cast_precision decorator to call method of InvarFitting class.
deepmd/dpmodel/fitting/polarizability_fitting.py Added cast_precision decorator to call method of PolarFitting class.

Possibly related PRs

Suggested labels

Docs

Suggested reviewers

  • wanghan-iapcm

Warning

Rate limit exceeded

@njzjz has exceeded the limit for the number of commits or files that can be reviewed per hour. Please wait 11 minutes and 25 seconds before requesting another review.

⌛ How to resolve this issue?

After the wait time has elapsed, a review can be triggered using the @coderabbitai review command as a PR comment. Alternatively, push new commits to this PR.

We recommend that you space out your commits to avoid hitting the rate limit.

🚦 How do rate limits work?

CodeRabbit enforces hourly rate limits for each developer per organization.

Our paid plans have higher rate limits than the trial, open-source and free plans. In all cases, we re-allow further reviews after a brief timeout.

Please see our FAQ for further information.

📥 Commits

Reviewing files that changed from the base of the PR and between e43134d and 36b087d.


📜 Recent review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between e43134d and 36b087d.

📒 Files selected for processing (1)
  • source/tests/common/dpmodel/test_network.py (4 hunks)
🔇 Additional comments (4)
source/tests/common/dpmodel/test_network.py (4)

11-13: LGTM! Clear and well-organized import.

The new import of get_xp_precision is properly organized and the function name clearly describes its purpose of getting the precision for NumPy arrays.


176-176: LGTM! Consistent precision handling.

The test properly creates input arrays with the correct precision, maintaining consistency with the precision parameter used in EmbeddingNet initialization.


199-199: LGTM! Thorough precision verification.

The test properly handles precision-specific inputs and verifies both the network's ability to process inputs and the consistency of results between serialized and deserialized instances.


52-54: LGTM! Comprehensive precision testing.

The test now properly creates input arrays with the correct precision based on the test parameter. This ensures thorough testing across different precision settings ("float32", "float64", "single", "double").

Let's verify that all precision options are handled correctly by get_xp_precision:

✅ Verification successful

Precision handling is correctly implemented and matches test requirements

The get_xp_precision function in deepmd/dpmodel/common.py properly handles all precision options used in the test ("float32", "float64", "single", "double"). The function maps:

  • "float32" or "single" to xp.float32
  • "float64" or "double" to xp.float64

This confirms that the test's input array creation with dynamic precision is well-supported by the underlying implementation.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify the implementation of get_xp_precision function
# Look for the function definition to ensure all precision options are handled

ast-grep --pattern 'def get_xp_precision($_, $_):
  $$$
'

Length of output: 1677


Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media?

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Generate unit testing code for this file.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai generate unit testing code for this file.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and generate unit testing code.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

CodeRabbit Configuration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • Please see the configuration documentation for more information.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 4

🧹 Outside diff range and nitpick comments (16)
source/tests/consistent/descriptor/test_se_t.py (1)

107-119: LGTM! Consider adding return type hint.

The skip logic is well-implemented and consistent with the PR's objectives. The code is clean and well-structured.

Consider adding a return type hint to match the style of other properties:

-    def skip_array_api_strict(self):
+    def skip_array_api_strict(self) -> bool:
source/tests/consistent/fitting/test_dipole.py (2)

89-99: LGTM! Consider enhancing the comment clarity.

The implementation correctly skips float32 precision tests, aligning with the PR objective. However, the comment could be more explicit about the specific NumPy behavior being addressed.

Consider updating the comment to be more descriptive:

-            # NumPy doesn't throw errors for float64 x float32
+            # Skip float32 tests as NumPy silently promotes float32 to float64 during operations,
+            # which masks potential precision-related issues

101-111: LGTM! Consider reducing code duplication.

The implementation correctly handles both float32 precision and Array API strict mode availability. However, there's duplicated logic with skip_dp for float32 handling.

Consider extracting the shared float32 check into a private property:

+    @property
+    def _skip_float32(self) -> bool:
+        """Skip tests for float32 precision due to NumPy's silent promotion behavior."""
+        (_, precision, _) = self.param
+        return precision == "float32"
+
     @property
     def skip_dp(self) -> bool:
-        (
-            resnet_dt,
-            precision,
-            mixed_types,
-        ) = self.param
-        if precision == "float32":
-            # NumPy doesn't throw errors for float64 x float32
-            return True
+        if self._skip_float32:
+            return True
         return CommonTest.skip_dp

     @property
     def skip_array_api_strict(self) -> bool:
-        (
-            resnet_dt,
-            precision,
-            mixed_types,
-        ) = self.param
-        if precision == "float32":
-            # NumPy doesn't throw errors for float64 x float32
-            return True
+        if self._skip_float32:
+            return True
         return not INSTALLED_ARRAY_API_STRICT
source/tests/consistent/fitting/test_polar.py (1)

89-100: LGTM! Consider enhancing the comment for better clarity.

The implementation correctly handles skipping DP backend tests for float32 precision to avoid false positives.

Consider expanding the comment to better explain the implications:

-            # NumPy doesn't throw errors for float64 x float32
+            # Skip float32 tests as NumPy silently promotes float32 to float64 during operations,
+            # which could mask precision-related issues we want to catch
source/tests/consistent/descriptor/test_se_e2_a.py (1)

137-139: LGTM! Consider extracting the precision check.

The skip condition is consistent with the skip_dp implementation.

Consider extracting the precision check into a private method to avoid duplication:

+    def _should_skip_float32(self) -> bool:
+        (
+            resnet_dt,
+            type_one_side,
+            excluded_types,
+            precision,
+            env_protection,
+        ) = self.param
+        return precision == "float32"
+
     @property
     def skip_dp(self) -> bool:
-        if precision == "float32":
-            # NumPy doesn't throw errors for float64 x float32
-            return True
+        if self._should_skip_float32():
+            return True
         return CommonTest.skip_dp

     @property
     def skip_array_api_strict(self) -> bool:
-        if precision == "float32":
-            # NumPy doesn't throw errors for float64 x float32
-            return True
+        if self._should_skip_float32():
+            return True
         return not type_one_side or not INSTALLED_ARRAY_API_STRICT
source/tests/consistent/fitting/test_dos.py (1)

121-131: LGTM! Consider reducing code duplication.

The implementation correctly handles skipping array API strict tests for float32 precision. However, the float32 check logic is duplicated from skip_dp.

Consider extracting the common float32 check into a private method:

+    def _is_float32(self) -> bool:
+        (
+            resnet_dt,
+            precision,
+            mixed_types,
+            numb_fparam,
+            numb_aparam,
+            numb_dos,
+        ) = self.param
+        return precision == "float32"
+
     @property
     def skip_dp(self) -> bool:
-        (
-            resnet_dt,
-            precision,
-            mixed_types,
-            numb_fparam,
-            numb_aparam,
-            numb_dos,
-        ) = self.param
-        if precision == "float32":
+        if self._is_float32():
             # NumPy doesn't throw errors for float64 x float32
             return True
         return CommonTest.skip_dp

     @property
     def skip_array_api_strict(self) -> bool:
-        (
-            resnet_dt,
-            precision,
-            mixed_types,
-            numb_fparam,
-            numb_aparam,
-            numb_dos,
-        ) = self.param
-        if precision == "float32":
+        if self._is_float32():
             # NumPy doesn't throw errors for float64 x float32
             return True
         return not INSTALLED_ARRAY_API_STRICT
source/tests/consistent/descriptor/test_se_t_tebd.py (1)

153-172: LGTM! Consider extracting common precision check.

The implementation correctly handles array API strict mode skipping. However, the float32 check logic is duplicated from skip_dp.

Consider extracting the common precision check to reduce duplication:

+    @property
+    def _skip_float32(self) -> bool:
+        return self.param[8] == "float32"  # precision is param[8]
+
     @property
     def skip_dp(self) -> bool:
         (...param unpacking...)
-        if precision == "float32":
+        if self._skip_float32:
             return True
         return CommonTest.skip_dp

     @property
     def skip_array_api_strict(self) -> bool:
         (...param unpacking...)
-        if precision == "float32":
+        if self._skip_float32:
             return True
         return not INSTALLED_ARRAY_API_STRICT
source/tests/consistent/fitting/test_ener.py (1)

103-116: LGTM! Consider enhancing the documentation.

The skip_dp property correctly implements the skipping of float32 precision tests for NumPy backend, which aligns with the PR objectives. The implementation is clean and the comment explains the rationale.

Consider enhancing the docstring to provide more context:

 @property
 def skip_dp(self) -> bool:
+    """Determines whether to skip tests for NumPy backend.
+    
+    Returns:
+        bool: True if precision is float32 (to avoid NumPy's silent handling of
+        float64 x float32 operations) or if CommonTest.skip_dp is True.
+    """
source/tests/consistent/descriptor/test_se_atten_v2.py (1)

181-183: LGTM! Consider enhancing the comment.

The skip condition for float32 precision is correct. However, the comment could be more specific about the NumPy behavior.

Consider expanding the comment to explain the specific behavior:

-    # NumPy doesn't throw errors for float64 x float32
+    # Skip float32 tests as NumPy silently promotes float32 to float64 during operations,
+    # masking potential precision issues that would occur in other backends
source/tests/consistent/descriptor/test_dpa1.py (2)

185-187: LGTM! Consider enhancing the comment.

The skip condition for float32 precision is correct and aligns with the PR objectives. However, the comment could be more explicit about NumPy's type promotion behavior.

Consider expanding the comment to:

-    # NumPy doesn't throw errors for float64 x float32
+    # Skip float32 tests as NumPy silently promotes float32 to float64 during operations,
+    # which masks potential precision-related issues

244-246: LGTM! Consider extracting duplicated logic.

The skip condition is correct but duplicates the logic and comment from skip_dp. Consider extracting this common check into a private method.

Consider refactoring to:

+    def _should_skip_float32(self) -> bool:
+        """Check if tests should be skipped for float32 precision.
+        
+        NumPy silently promotes float32 to float64 during operations,
+        which masks potential precision-related issues.
+        """
+        return self.param[-3] == "float32"  # precision parameter
+
     @property
     def skip_dp(self) -> bool:
-        if precision == "float32":
-            # NumPy doesn't throw errors for float64 x float32
-            return True
+        if self._should_skip_float32():
+            return True
         return CommonTest.skip_dp or ...

     @property
     def skip_array_api_strict(self) -> bool:
-        if precision == "float32":
-            # NumPy doesn't throw errors for float64 x float32
-            return True
+        if self._should_skip_float32():
+            return True
         return not INSTALLED_ARRAY_API_STRICT or ...
source/tests/consistent/descriptor/test_dpa2.py (3)

248-250: Enhance comment clarity regarding NumPy's mixed precision behavior.

The comment could be more specific about why NumPy's behavior with mixed precision operations necessitates skipping these tests.

Consider updating the comment to be more explicit:

-    # NumPy doesn't throw errors for float64 x float32
+    # Skip tests for float32 as NumPy silently promotes float32 to float64 during mixed precision operations,
+    # which masks potential precision-related issues that could occur in production

289-318: Simplify parameter unpacking in skip_array_api_strict.

The method unpacks all parameters but only uses 'precision'. This makes the code harder to maintain and violates the principle of minimal parameter usage.

Consider simplifying the parameter handling:

     @property
     def skip_array_api_strict(self) -> bool:
-        (
-            repinit_tebd_input_mode,
-            repinit_set_davg_zero,
-            repinit_type_one_side,
-            repinit_use_three_body,
-            # ... [removed for brevity]
-            use_tebd_bias,
-        ) = self.param
+        precision = self.param[23]  # Index of precision in param tuple

Alternatively, consider refactoring the test class to make precision more easily accessible without unpacking all parameters.


319-321: Duplicate comment about NumPy's behavior.

This comment is identical to the one in skip_dp. Consider creating a shared documentation string or constant to maintain consistency and make future updates easier.

Consider creating a constant at the class level:

NUMPY_PRECISION_SKIP_REASON = "NumPy doesn't throw errors for float64 x float32"

Then use it in both methods:

     @property
     def skip_dp(self) -> bool:
         if precision == "float32":
-            # NumPy doesn't throw errors for float64 x float32
+            # {NUMPY_PRECISION_SKIP_REASON}
             return True

     @property
     def skip_array_api_strict(self) -> bool:
         if precision == "float32":
-            # NumPy doesn't throw errors for float64 x float32
+            # {NUMPY_PRECISION_SKIP_REASON}
             return True
deepmd/dpmodel/common.py (1)

150-156: Ensure that Array type is defined or imported in the example

In the example provided in the docstring, the type Array is used in the function signature, but it is not defined or imported. This may lead to confusion or errors when using the example code. Consider importing Array or replacing it with an appropriate type hint.

Apply this diff to update the example:

+from typing import Any
...
      @cast_precision
-     def f(x: Array, y: Array) -> Array:
+     def f(x: Any, y: Any) -> Any:
          return x**2 + y

Alternatively, if Array is defined elsewhere in your codebase, consider adding an import statement in the example.

deepmd/dpmodel/fitting/general_fitting.py (1)

367-371: Refactor repetitive type casting into a helper function

The type casting of descriptor, fparam, and aparam to the specified precision is repeated. To enhance maintainability and reduce code duplication, consider creating a helper function or loop to handle these casts.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 4793125 and 73b4227.

📒 Files selected for processing (21)
  • .github/workflows/test_python.yml (1 hunks)
  • deepmd/dpmodel/common.py (2 hunks)
  • deepmd/dpmodel/descriptor/dpa1.py (3 hunks)
  • deepmd/dpmodel/descriptor/dpa2.py (3 hunks)
  • deepmd/dpmodel/descriptor/se_e2_a.py (5 hunks)
  • deepmd/dpmodel/descriptor/se_r.py (2 hunks)
  • deepmd/dpmodel/descriptor/se_t.py (2 hunks)
  • deepmd/dpmodel/descriptor/se_t_tebd.py (3 hunks)
  • deepmd/dpmodel/fitting/general_fitting.py (2 hunks)
  • deepmd/jax/env.py (1 hunks)
  • source/tests/consistent/descriptor/test_dpa1.py (2 hunks)
  • source/tests/consistent/descriptor/test_dpa2.py (2 hunks)
  • source/tests/consistent/descriptor/test_se_atten_v2.py (2 hunks)
  • source/tests/consistent/descriptor/test_se_e2_a.py (2 hunks)
  • source/tests/consistent/descriptor/test_se_r.py (2 hunks)
  • source/tests/consistent/descriptor/test_se_t.py (2 hunks)
  • source/tests/consistent/descriptor/test_se_t_tebd.py (2 hunks)
  • source/tests/consistent/fitting/test_dipole.py (1 hunks)
  • source/tests/consistent/fitting/test_dos.py (1 hunks)
  • source/tests/consistent/fitting/test_ener.py (2 hunks)
  • source/tests/consistent/fitting/test_polar.py (1 hunks)
🔇 Additional comments (34)
deepmd/jax/env.py (1)

16-18: LGTM! The strict dtype promotion aligns with precision requirements.

The conditional configuration of strict dtype promotion is a good defensive measure to prevent silent precision loss during numerical operations. This complements the existing x64 precision setting and aligns with the PR's objective of fixing precision issues.

Let's verify how this affects the codebase:

.github/workflows/test_python.yml (1)

65-65: LGTM! Verify test coverage impact.

The addition of DP_DTYPE_PROMOTION_STRICT environment variable aligns with the PR's objective of fixing precision issues. The placement in the TF2 eager mode test section is appropriate.

Let's verify if this affects test coverage:

✅ Verification successful

Environment variable is correctly placed and impacts JAX precision handling

The verification shows that DP_DTYPE_PROMOTION_STRICT is used in deepmd/jax/env.py to control dtype promotion behavior in JAX. The extensive test coverage (150+ test files) involving dtype/precision handling and JAX-specific tests confirms this change will properly enforce strict type checking during tests.

  • The variable is referenced in deepmd/jax/env.py for JAX configuration
  • Test suite includes comprehensive JAX tests under source/tests/consistent/ that will be affected
  • Large number of tests involving dtype/precision handling will help catch any type-related issues
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for any tests that might be affected by strict dtype promotion

# Look for tests involving dtype or precision
rg -l "dtype|precision|float32|float64" source/tests/

# Look for JAX-specific tests that might be affected
rg -l "jax" source/tests/

# Check for any existing usage of DP_DTYPE_PROMOTION_STRICT
rg "DP_DTYPE_PROMOTION_STRICT" 

Length of output: 10429

source/tests/consistent/descriptor/test_se_r.py (2)

118-120: LGTM! Consistent handling of FP32 precision across test modes.

The implementation maintains consistency with the skip_dp property, ensuring that float32 precision tests are skipped in array API strict mode as well.


95-97: LGTM! Skipping FP32 tests for NumPy is appropriate.

The added check correctly prevents running tests with float32 precision since NumPy silently handles mixed-precision operations between float64 and float32, which could mask potential precision-related issues.

Let's verify if this change is consistent across other descriptor test files:

✅ Verification successful

Verified: The float32 precision skip is consistently implemented across descriptor tests

The verification shows that:

  • The same float32 precision check pattern exists across multiple descriptor test files (test_dpa1.py, test_dpa2.py, test_se_atten_v2.py, test_se_e2_a.py, test_se_t.py, test_se_t_tebd.py)
  • The implementation in test_se_r.py follows the same pattern used in other files, with identical comments about NumPy's behavior with float32
  • All descriptor test files consistently handle float32 precision cases
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check if similar precision checks exist in other descriptor test files
# Expected: Similar skip_dp implementations in other test files

# Search for skip_dp methods in descriptor test files
rg -A 5 "def skip_dp.*precision.*float32" source/tests/consistent/descriptor/

Length of output: 79


Script:

#!/bin/bash
# Let's try a broader search to find skip_dp property implementations
rg -A 5 "@property\s+def skip_dp" source/tests/consistent/descriptor/

# Also search for similar float32 precision checks
rg "precision.*float32" source/tests/consistent/descriptor/

Length of output: 2545


Script:

#!/bin/bash
# Let's check the context of these float32 checks in one of the files
# to see if they follow the same pattern as test_se_r.py
rg -B 2 -A 2 "precision.*float32" source/tests/consistent/descriptor/test_dpa1.py

Length of output: 706

source/tests/consistent/descriptor/test_se_t.py (1)

92-94: LGTM! Verify similar changes in other test files.

The skip logic for float32 precision is correctly implemented and well-documented. The comment clearly explains why NumPy tests need to be skipped for float32 precision.

Let's verify that this skip logic is consistently implemented across other descriptor test files:

source/tests/consistent/fitting/test_polar.py (1)

101-112: LGTM! Verify test coverage for all precision combinations.

The implementation correctly combines precision-based skipping with installation status check.

Let's verify that the test parameterization covers all precision combinations:

✅ Verification successful

LGTM! Test coverage for precision combinations is comprehensive.

The test class is properly parameterized with all precision combinations:

  • Test runs with both float64 and float32 precision
  • For float32, array_api_strict tests are correctly skipped due to NumPy's mixed precision behavior
  • For float64, tests run when array_api_strict backend is installed
  • The parameterization also covers combinations with resnet_dt and mixed_types
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check test parameterization and actual test executions

# Check test parameterization
rg -A 10 "@parameterized" source/tests/consistent/fitting/test_polar.py

# Check actual test executions and their skip status
rg "def test_" source/tests/consistent/fitting/test_polar.py

# Look for similar skip patterns in other test files
rg "skip_array_api_strict.*float32" source/tests

Length of output: 475


Script:

#!/bin/bash
# Let's get more context about the test class and its parameterization

# Get the full test class implementation
cat source/tests/consistent/fitting/test_polar.py

# Check for any other skip_array_api_strict implementations
rg -B 5 -A 5 "skip_array_api_strict" source/tests/

Length of output: 21551

source/tests/consistent/descriptor/test_se_e2_a.py (1)

101-103: LGTM! Verify NumPy's mixed precision behavior.

The skip condition for float32 precision is well-documented and aligns with the PR objectives.

Let's verify the NumPy behavior that necessitates this skip:

✅ Verification successful

Based on the search results, I can now verify the review comment. The codebase shows consistent handling of float32 precision across multiple test files, particularly in the context of NumPy's behavior with mixed precision operations. The documentation and test files confirm that NumPy doesn't throw errors for float64 x float32 operations, which is exactly what the skip condition is handling.

Skip condition for float32 precision is correctly implemented

The skip condition is well-justified as NumPy's mixed precision behavior requires special handling for float32 cases. This is consistently implemented across multiple test files in the codebase, including:

  • Other descriptor tests like test_se_t.py, test_se_r.py, test_se_atten_v2.py have similar skip conditions
  • The comment accurately describes NumPy's behavior with mixed precision operations
  • The implementation aligns with the project's precision handling patterns
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Search for related test failures or issues with NumPy's mixed precision

# Search for related test failures or discussions
rg -l "float32.*float64|mixed precision|precision issue" --type md

# Search for similar skip conditions in other test files
rg "precision.*float32.*skip|skip.*float32" --type py

Length of output: 155


Script:

#!/bin/bash
# Let's check the content of the documentation file and search for more context
cat doc/train/training-advanced.md

# Search for any float32/float64 related tests or implementations
rg "float(32|64)" -A 3 --type py

# Search for NumPy specific precision handling
rg "numpy.*float|np\.(float32|float64)" -A 2 --type py

# Look for skip conditions in test files
fd "test.*\.py$" --exec grep -l "skip.*precision\|precision.*skip" {} \;

Length of output: 223836

source/tests/consistent/fitting/test_dos.py (1)

100-113: LGTM! Verify consistent skip behavior across test files.

The implementation correctly handles skipping tests for float32 precision due to NumPy's mixed precision behavior.

Let's verify that this skip behavior is consistently implemented across other test files:

✅ Verification successful

Skip behavior for float32 precision is consistently implemented across test files

The verification shows that all test files in the codebase consistently implement the same skip logic for float32 precision with identical comments explaining "NumPy doesn't throw errors for float64 x float32". This includes:

  • All descriptor test files (se_e2_a, se_t, dpa1, se_atten_v2, se_t_tebd, dpa2, se_r)
  • All fitting test files (dipole, ener, polar, dos)
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for consistent implementation of skip_dp property across test files
# Expect: All test files should have similar skip logic for float32 precision

# Search for skip_dp property implementations
ast-grep --pattern 'def skip_dp(self) -> bool:
  $$$
  if $precision == "float32":
    $$$
    return True
  $$$'

# Also check for any inconsistent implementations
rg -l 'def skip_dp' | xargs rg 'float32|precision'

Length of output: 19917

source/tests/consistent/descriptor/test_se_t_tebd.py (1)

130-132: LGTM! Skipping float32 tests prevents silent precision issues.

The added check correctly implements the PR objective to skip float32 tests for NumPy, preventing potential silent precision issues when mixing float64 and float32 operations.

Let's verify this is consistent with other test files:

✅ Verification successful

I'm broadening the search to:

  1. Find any skip patterns related to precision across all test files
  2. Look specifically for float32 handling in the consistent tests directory
    This will help verify if the approach is consistent with other test files.

The float32 skip pattern is consistent across the codebase

The search results show that the same float32 skip pattern with identical comment "NumPy doesn't throw errors for float64 x float32" is consistently implemented across multiple test files in the codebase, including:

  • descriptor tests (se_t_tebd.py, se_t.py, se_r.py, se_e2_a.py, se_atten_v2.py, dpa1.py, dpa2.py)
  • fitting tests (polar.py, property.py, dos.py, dipole.py, ener.py)

This confirms that the implementation in test_se_t_tebd.py follows the established pattern for handling float32 precision tests across the project.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for similar precision-based skip patterns in other test files
rg -l "precision.*float32.*return True" source/tests/

Length of output: 55


Script:

#!/bin/bash
# Let's try a broader search for precision-related skip patterns
rg -B2 -A2 "skip.*precision" source/tests/

# Also search for float32-related patterns
rg -B2 -A2 "float32" source/tests/consistent/

Length of output: 25430

source/tests/consistent/fitting/test_ener.py (1)

130-132: LGTM! Consistent precision handling.

The modification to skip_array_api_strict property maintains consistency with skip_dp by applying the same precision-based skipping logic.

source/tests/consistent/descriptor/test_se_atten_v2.py (1)

244-246: LGTM! Consistent with DP implementation.

The skip condition for array_api_strict matches the DP implementation, ensuring consistent handling of float32 precision across different backends.

deepmd/dpmodel/descriptor/se_r.py (2)

17-17: LGTM: Import addition aligns with precision handling requirements

The addition of cast_precision import is appropriate for implementing precision control in the descriptor computations.


293-293: LGTM: Precision control enhancement via decorator

The addition of @cast_precision decorator appropriately addresses precision handling in descriptor computations. This change aligns with the PR objective of fixing precision issues.

Let's verify the numerical stability of the descriptor computations with the new precision handling:

✅ Verification successful

Let me gather more information about the precision-related tests by examining their implementation details.


Let me try a different approach to examine the test implementations.


Precision handling is properly tested and verified

The test suite comprehensively validates the @cast_precision decorator's functionality through:

  • Tests with both float32 and float64 precision
  • Numerical consistency checks between serialized/deserialized instances
  • Validation across different tensor operations with appropriate tolerances
  • Tests with various configurations (resnet_dt, exclude_mask)
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Search for test files that verify numerical stability of descriptor computations

# Look for relevant test files
echo "Searching for relevant test files..."
rg -l "DescrptSeR" --type py | grep "test"

# Look for precision-related tests
echo "Searching for precision-related tests..."
rg -A 5 "precision|dtype|fp32|fp64" --type py | grep -B 2 "DescrptSeR"

Length of output: 3403


Script:

#!/bin/bash
# Check the actual test implementations for precision-related assertions
rg -B 2 -A 10 "test.*prec" source/tests/pt/model/test_descriptor_se_r.py

# Look for precision-related test cases in other test files
rg -B 2 -A 10 "test.*prec" source/tests/consistent/descriptor/test_se_r.py

# Check if there are any existing precision validation tests
rg -B 2 -A 10 "assert.*dtype" source/tests/pt/model/test_descriptor_se_r.py

Length of output: 229


Script:

#!/bin/bash
# Look for test methods in the descriptor test files
echo "Checking test_descriptor_se_r.py..."
rg -A 20 "def test_" source/tests/pt/model/test_descriptor_se_r.py

# Check for precision/dtype related variables and test setup
echo "Checking precision-related setup..."
rg -B 5 -A 5 "dtype|precision|fp32|fp64" source/tests/pt/model/test_descriptor_se_r.py

# Look for test assertions related to descriptor output
echo "Checking descriptor output assertions..."
rg -B 5 -A 5 "assert|self.assert" source/tests/pt/model/test_descriptor_se_r.py

Length of output: 7177

deepmd/dpmodel/descriptor/se_t.py (2)

17-17: LGTM: Import addition is well-organized

The addition of cast_precision import is properly grouped with related imports from the same module.


268-268: Verify numerical equivalence with previous implementation

The addition of @cast_precision decorator aligns with the PR's objective to fix precision issues. However, since this affects numerical computations in a scientific computing context, it's crucial to verify that the results remain numerically equivalent to the previous implementation.

Let's verify the numerical precision handling:

✅ Verification successful

Based on the search results, I can now generate the final response:

@cast_precision decorator consistently handles precision across descriptors

The addition of @cast_precision decorator to the call method is part of a systematic approach to handle precision casting across the codebase. The search results show:

  • The decorator is consistently used across multiple descriptor implementations (se_t, se_r, se_a, dpa1, etc.)
  • It's properly defined in both dpmodel/common.py and tf/common.py for different backends
  • The implementation uses PRECISION_DICT to ensure consistent precision handling
  • No remaining explicit precision casting in the descriptor implementations that would conflict with the decorator
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Search for any remaining explicit precision casting in descriptor implementations
# to ensure consistent precision handling across the codebase

# Look for any remaining astype calls that might indicate inconsistent precision handling
rg "astype.*prec|PRECISION_DICT" --type python

# Look for other descriptor implementations to verify consistent usage of cast_precision
ast-grep --pattern 'class Descrpt$_ {
  $$$
  @cast_precision
  def call($$$) {
    $$$
  }
  $$$
}'

Length of output: 182


Script:

#!/bin/bash
# Search for astype and precision related patterns without file type restriction
rg "astype.*prec|PRECISION_DICT"

# Look for cast_precision decorator usage
rg "@cast_precision" -A 2

# Look for descriptor class implementations
rg "class.*Descriptor|class.*Descrpt" -A 5

# Check the cast_precision decorator implementation
rg "def cast_precision"

Length of output: 86322

deepmd/dpmodel/descriptor/se_t_tebd.py (2)

171-171: LGTM: Added precision instance variable

The addition of the precision instance variable ensures proper precision handling throughout the class lifecycle.


291-291: LGTM: Added precision casting decorator

The @cast_precision decorator ensures consistent precision handling for input and output tensors in the call method.

Let's verify that all tensor operations in the method handle the precision correctly:

✅ Verification successful

Let me gather more information about the cast_precision decorator and its usage to ensure there are no conflicts.


Let me check the implementation of the cast_precision decorator to ensure it's compatible with the current usage.


Verified: Precision casting decorator is correctly applied

The @cast_precision decorator from deepmd/dpmodel/common.py is correctly implemented to:

  1. Cast input arrays from global precision to the class's precision before method execution
  2. Cast output arrays back to global precision after method execution

The decorator's usage aligns with the existing precision handling in the file, where:

  • The class already uses PRECISION_DICT[self.precision] for numpy arrays
  • Explicit dtype casting is properly managed with astype
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Check for any explicit dtype casting in the method that might conflict with the decorator
rg -A 2 "astype|dtype=" deepmd/dpmodel/descriptor/se_t_tebd.py

# Check for any numpy array creation that might need precision handling
rg "np\.(zeros|ones|full|empty)" deepmd/dpmodel/descriptor/se_t_tebd.py

Length of output: 693


Script:

#!/bin/bash
# Check the implementation of cast_precision decorator
rg -B 2 -A 5 "@cast_precision" deepmd/dpmodel/descriptor/se_t_tebd.py

# Check for any other precision-related code in the file
rg "precision" deepmd/dpmodel/descriptor/se_t_tebd.py

# Look for the decorator definition
fd -e py -x rg -l "def cast_precision"

Length of output: 1082


Script:

#!/bin/bash
# Check the implementation of cast_precision decorator in both files
rg -B 2 -A 10 "def cast_precision" deepmd/tf/common.py deepmd/dpmodel/common.py

# Check if the call method has any tensor operations that might be affected
rg -A 10 "def call" deepmd/dpmodel/descriptor/se_t_tebd.py

Length of output: 2199

deepmd/dpmodel/descriptor/dpa2.py (3)

17-17: LGTM: Import of cast_precision decorator

The addition of the cast_precision import aligns with the precision handling enhancement objective.


762-762: LGTM: Addition of cast_precision decorator

The addition of the @cast_precision decorator to the call method appropriately implements precision control for tensor operations, aligning with the PR's objective to fix precision handling.


Line range hint 762-924: Verify precision handling in tensor operations

Let's ensure that precision handling is consistent across all tensor operations in the call method and related files.

deepmd/dpmodel/common.py (2)

6-8: Importing wraps from functools is appropriate

The wraps decorator is correctly imported and used to preserve the metadata of the original function.


11-11: Importing Callable and Optional from typing

The imports are necessary for the type hints used in the functions.

deepmd/dpmodel/fitting/general_fitting.py (4)

447-449: Efficient masking of atom_property using xp.where

Using xp.where to apply the mask to atom_property improves code clarity and computational efficiency.


451-451: Consistent casting of outs to global precision

Casting outs to global precision ensures consistency in computations and aligns with the precision used elsewhere in the codebase.


455-459: Correct application of per-atom biases in mixed types scenario

The code correctly adds biases to outs based on each atom's type using xp.take and reshaping. This ensures that atoms receive the appropriate bias according to their type.


466-466: Effective use of exclude_mask with xp.where

Applying exclude_mask to outs using xp.where efficiently zeroes out contributions from excluded atom types, enhancing code clarity.

deepmd/dpmodel/descriptor/se_e2_a.py (5)

18-18: Import cast_precision for consistent precision handling

The addition of cast_precision to the imports allows for consistent precision casting throughout the module, aligning with the goal of fixing precision issues.


341-341: Apply @cast_precision decorator to call method

Applying the @cast_precision decorator ensures that the inputs and outputs of the call method are correctly cast to the desired precision, improving numerical stability and consistency.


417-417: Remove explicit precision casting inside the call method

By removing explicit casting to a global precision, the method now relies on the @cast_precision decorator to handle precision, simplifying the code and reducing redundancy.


506-506: Apply @cast_precision decorator to call method in DescrptSeAArrayAPI

Adding the @cast_precision decorator to the call method of DescrptSeAArrayAPI ensures consistent precision handling, mirroring the update made in the base class DescrptSeA.


586-586: Remove explicit precision casting in return statement

Removing the explicit casting in the return statement is appropriate because the @cast_precision decorator now manages precision, simplifying the code and ensuring consistency.

deepmd/dpmodel/descriptor/dpa1.py (3)

22-23: Import necessary utilities for precision handling

The inclusion of cast_precision and to_numpy_array is appropriate and necessary for managing precision settings and array conversions within the module.


333-333: Store precision setting in class instance

Assigning self.precision = precision in the constructor ensures that the precision configuration is consistently accessible throughout the class methods.


453-453: Apply @cast_precision decorator to call method

Applying the @cast_precision decorator to the call method ensures that computations within this method adhere to the specified precision, enhancing numerical consistency and preventing potential precision-related issues.

deepmd/dpmodel/descriptor/dpa2.py Show resolved Hide resolved
deepmd/dpmodel/common.py Outdated Show resolved Hide resolved
deepmd/dpmodel/common.py Outdated Show resolved Hide resolved
deepmd/dpmodel/fitting/general_fitting.py Outdated Show resolved Hide resolved
Signed-off-by: Jinzhe Zeng <[email protected]>
@njzjz njzjz changed the title fix(dpmodel): fix precision; skip FP32 tests for NP fix(dpmodel): fix precision Nov 12, 2024
Copy link
Member Author

@njzjz njzjz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I revert "skip FP32 for NP". The issue is that if dpmodel is skipped, the reference backend is fallback to TF, which is not expected...

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Outside diff range and nitpick comments (1)
deepmd/dpmodel/fitting/general_fitting.py (1)

456-459: Consider combining bias addition and exclude mask operations

While the current implementation is correct, you could potentially improve performance by combining the bias addition and exclude mask application into a single operation:

-outs = xp.astype(outs, get_xp_precision(xp, "global"))
-outs += xp.reshape(
-    xp.take(self.bias_atom_e, xp.reshape(atype, [-1]), axis=0),
-    [nf, nloc, net_dim_out],
-)
-outs = xp.where(exclude_mask[:, :, None], outs, xp.zeros_like(outs))
+bias = xp.reshape(
+    xp.take(self.bias_atom_e, xp.reshape(atype, [-1]), axis=0),
+    [nf, nloc, net_dim_out],
+)
+outs = xp.where(
+    exclude_mask[:, :, None],
+    xp.astype(outs, get_xp_precision(xp, "global")) + bias,
+    xp.zeros_like(bias)
+)

This change:

  1. Reduces memory allocation by combining operations
  2. Avoids unnecessary computation for excluded atoms

Also applies to: 463-463

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 73b4227 and 918284a.

📒 Files selected for processing (1)
  • deepmd/dpmodel/fitting/general_fitting.py (2 hunks)
🔇 Additional comments (3)
deepmd/dpmodel/fitting/general_fitting.py (3)

367-371: LGTM: Proper precision casting for input tensors

The explicit type casting ensures consistent precision handling across all input tensors using the get_xp_precision helper.


447-449: LGTM: Improved output handling using xp.where

The use of xp.where for conditional output is more explicit and potentially more efficient than direct masked assignment.


452-452: LGTM: Proper handling of mixed types and vacuum contribution

The code correctly processes outputs for mixed types and properly handles the vacuum contribution subtraction when required.

Also applies to: 455-455

@iProzd iProzd mentioned this pull request Nov 12, 2024
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Outside diff range and nitpick comments (1)
.github/workflows/test_python.yml (1)

65-65: Consider adding a comment explaining the variable's purpose.

To improve maintainability, consider adding a comment explaining why this environment variable is needed and its relationship with JAX's dtype promotion behavior.

Example addition:

+        # Enable strict dtype promotion in JAX to catch precision issues
         DP_DTYPE_PROMOTION_STRICT: 1
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 918284a and 741c34d.

📒 Files selected for processing (1)
  • .github/workflows/test_python.yml (1 hunks)
🔇 Additional comments (1)
.github/workflows/test_python.yml (1)

65-65: LGTM! Verify variable usage across the codebase.

The addition of DP_DTYPE_PROMOTION_STRICT=1 for TF2 eager mode tests is appropriate for enforcing stricter type checking.

Let's verify how this environment variable is used:

✅ Verification successful

Verification Successful: Variable Usage Confirmed

The environment variable DP_DTYPE_PROMOTION_STRICT is used consistently within deepmd/jax/env.py to enforce strict dtype promotion.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for other occurrences of DP_DTYPE_PROMOTION_STRICT
rg "DP_DTYPE_PROMOTION_STRICT" -A 3

# Look for related JAX dtype promotion settings
rg "dtype.*promotion" -A 3

Length of output: 422

deepmd/dpmodel/fitting/general_fitting.py Show resolved Hide resolved
deepmd/dpmodel/common.py Show resolved Hide resolved
deepmd/dpmodel/common.py Show resolved Hide resolved
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
@njzjz njzjz requested a review from wanghan-iapcm November 12, 2024 22:35
@overload
def safe_cast_array(
input: np.ndarray, from_precision: str, to_precision: str
) -> np.ndarray: ...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.
input: np.ndarray, from_precision: str, to_precision: str
) -> np.ndarray: ...
@overload
def safe_cast_array(input: None, from_precision: str, to_precision: str) -> None: ...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Outside diff range and nitpick comments (3)
deepmd/dpmodel/common.py (2)

156-157: Fix undefined type hint in docstring example

The example uses Array as a type hint, but this type is not defined or imported. Consider using np.ndarray or adding a proper import statement.

-    ...     def f(x: Array, y: Array) -> Array:
+    ...     def f(x: np.ndarray, y: np.ndarray) -> np.ndarray:

204-205: Update docstring parameter and return types

The docstring incorrectly refers to tf.Tensor. Update it to reflect the actual types used in the function.

-    input : tf.Tensor
+    input : np.ndarray
         Input tensor
     ...
-    tf.Tensor
+    np.ndarray
         casted Tensor

Also applies to: 213-214

deepmd/dpmodel/atomic_model/base_atomic_model.py (1)

210-210: Minor optimization: Simplify dict key lookup

Use kk in ret_dict instead of kk in ret_dict.keys() for more efficient key membership testing.

-        for kk in ret_dict.keys():
+        for kk in ret_dict:
🧰 Tools
🪛 Ruff

210-210: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 741c34d and b7d2b32.

📒 Files selected for processing (2)
  • deepmd/dpmodel/atomic_model/base_atomic_model.py (1 hunks)
  • deepmd/dpmodel/common.py (2 hunks)
🧰 Additional context used
🪛 Ruff
deepmd/dpmodel/atomic_model/base_atomic_model.py

210-210: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

🪛 GitHub Check: CodeQL
deepmd/dpmodel/common.py

[notice] 189-189: Statement has no effect
This statement has no effect.


[notice] 191-191: Statement has no effect
This statement has no effect.

🔇 Additional comments (5)
deepmd/dpmodel/common.py (3)

6-8: LGTM: Imports are properly organized and necessary.

The new imports support the added functionality for precision management.

Also applies to: 11-13


160-183: LGTM: Robust implementation of precision casting

The decorator implementation is well-structured and handles various input/output scenarios correctly:

  • Properly processes both args and kwargs
  • Handles multiple return types (tuple, dict)
  • Uses safe type casting

186-220: LGTM: Well-structured type-safe implementation

The function is properly typed with overloads and correctly handles both NumPy arrays and Array API objects.

🧰 Tools
🪛 GitHub Check: CodeQL

[notice] 189-189: Statement has no effect
This statement has no effect.


[notice] 191-191: Statement has no effect
This statement has no effect.

deepmd/dpmodel/atomic_model/base_atomic_model.py (2)

204-208: LGTM: Improved mask computation logic

The change from element-wise multiplication to logical_and is more semantically correct for combining boolean masks and helps maintain precision by avoiding floating-point operations.


213-215: LGTM: Enhanced precision in output masking

The use of xp.where with explicit zeros provides better precision control compared to direct multiplication with the mask. The array reshaping properly handles multi-dimensional outputs while maintaining array API compatibility.

Signed-off-by: Jinzhe Zeng <[email protected]>
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Outside diff range and nitpick comments (1)
deepmd/dpmodel/common.py (1)

124-183: LGTM with minor docstring improvements

The decorator implementation is solid and handles all cases correctly, including tuple and dict returns. Consider improving the example in the docstring for clarity.

Update the example to use proper type hints and actual numpy arrays:

     Examples
     --------
     >>> class A:
     ...     def __init__(self):
     ...         self.precision = "float32"
     ...
     ...     @cast_precision
-    ...     def f(x: Array, y: Array) -> Array:
+    ...     def f(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
     ...         return x**2 + y
+    ...
+    >>> import numpy as np
+    >>> a = A()
+    >>> x = np.array([1.0, 2.0], dtype=np.float64)
+    >>> y = np.array([3.0, 4.0], dtype=np.float64)
+    >>> result = a.f(x, y)  # result will be float64, internally computed as float32
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between b7d2b32 and 201cf80.

📒 Files selected for processing (1)
  • deepmd/dpmodel/common.py (2 hunks)
🧰 Additional context used
🪛 GitHub Check: CodeQL
deepmd/dpmodel/common.py

[notice] 189-189: Statement has no effect
This statement has no effect.


[notice] 191-191: Statement has no effect
This statement has no effect.

🔇 Additional comments (1)
deepmd/dpmodel/common.py (1)

6-13: LGTM: Import additions are appropriate

The new imports support the decorator pattern and type system features appropriately.

deepmd/dpmodel/common.py Show resolved Hide resolved
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Outside diff range and nitpick comments (2)
deepmd/dpmodel/fitting/dipole_fitting.py (1)

Additional Precision Handling Required in Multiple Modules

Several precision-sensitive operations using np.einsum, np.dot, and np.matmul were identified. It is recommended to apply the @cast_precision decorator to these methods to ensure consistent numerical stability across the codebase.

  • deepmd/utils/out_stat.py (Lines 70-72)
  • deepmd/dpmodel/descriptor/se_e2_a.py (Lines 408, 415)
  • source/ipi/tests/test_driver.py (Multiple lines)
  • deepmd/dpmodel/fitting/dipole_fitting.py (Additional lines)
  • [Refer to full list in shell script output]
🔗 Analysis chain

Line range hint 180-238: Excellent addition of precision handling to dipole calculations

The addition of @cast_precision decorator improves the robustness of numerical computations in the dipole fitting calculations. The implementation:

  • Ensures consistent precision across array operations
  • Maintains clean matrix operations with modern syntax
  • Preserves comprehensive type hints and documentation

Let's verify the precision handling impact:

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for other precision-sensitive operations in the codebase
# that might need similar treatment

# Look for similar numerical computation patterns
ast-grep --pattern 'np.einsum($$$)' 

# Check for other array operations that might need precision handling
rg -l 'import numpy as np' | xargs rg '@|np\.dot|np\.matmul'

Length of output: 77769

deepmd/dpmodel/atomic_model/base_atomic_model.py (1)

210-210: Minor optimization: Use direct dictionary membership test

Replace ret_dict.keys() with direct dictionary membership test for better performance.

-for kk in ret_dict.keys():
+for kk in ret_dict:
🧰 Tools
🪛 Ruff

210-210: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 201cf80 and 41d80b9.

📒 Files selected for processing (5)
  • deepmd/dpmodel/atomic_model/base_atomic_model.py (1 hunks)
  • deepmd/dpmodel/fitting/dipole_fitting.py (2 hunks)
  • deepmd/dpmodel/fitting/general_fitting.py (1 hunks)
  • deepmd/dpmodel/fitting/invar_fitting.py (2 hunks)
  • deepmd/dpmodel/fitting/polarizability_fitting.py (4 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • deepmd/dpmodel/fitting/general_fitting.py
🧰 Additional context used
🪛 Ruff
deepmd/dpmodel/atomic_model/base_atomic_model.py

210-210: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

🔇 Additional comments (8)
deepmd/dpmodel/fitting/dipole_fitting.py (1)

14-16: LGTM: Clean import addition for precision handling

The addition of the cast_precision import aligns with the broader initiative to enhance precision handling across the codebase.

deepmd/dpmodel/fitting/invar_fitting.py (2)

13-15: LGTM: Clean import addition

The import of cast_precision is properly organized with other imports from the same module.


209-209: Verify precision handling across different input types

The addition of @cast_precision decorator is appropriate for ensuring consistent precision across numerical computations. However, let's verify its behavior with different input types.

deepmd/dpmodel/atomic_model/base_atomic_model.py (2)

204-208: LGTM! Improved boolean mask computation

The change from element-wise multiplication to logical AND is a good improvement as it:

  1. Makes the boolean intent more explicit
  2. Potentially improves performance through short-circuiting
  3. Enhances code maintainability

213-216: LGTM! Robust precision-preserving masking implementation

The implementation correctly handles n-dimensional arrays while preserving precision through:

  1. Proper dimension handling with math.prod
  2. Precise masking using xp.zeros_like
  3. Shape preservation through careful reshaping
deepmd/dpmodel/fitting/polarizability_fitting.py (3)

18-18: LGTM: Import addition is appropriate.

The addition of cast_precision import aligns with the module's precision handling requirements.


245-245: LGTM: Decorator addition enhances precision handling.

The @cast_precision decorator ensures consistent precision handling across input and output tensors, which is crucial for numerical stability in scientific computations.


290-291: LGTM: Type casting ensures precision consistency.

The explicit type casting in both scale and constant_matrix operations ensures consistent precision throughout the calculations. This is particularly important for:

  1. Preventing silent precision loss during tensor operations
  2. Maintaining numerical stability in scientific computations

However, let's verify that these changes don't introduce any precision-related regressions.

Also applies to: 314-318

✅ Verification successful

Type casting changes are correct and safe

The explicit type casting to match output tensor dtype follows established patterns in the codebase and maintains precision consistency. The changes are well-integrated with the global precision handling architecture.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for any existing precision-related tests and their results

# Look for precision-related test files
echo "Searching for precision-related tests..."
rg -l "precision|dtype" --type py "tests/"

# Look for any existing precision-related issues
echo "Checking for precision-related issues in the codebase..."
rg "precision|dtype" --type py

Length of output: 271265

Co-authored-by: Han Wang <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 41d80b9 and 30d33b7.

📒 Files selected for processing (1)
  • deepmd/dpmodel/fitting/general_fitting.py (1 hunks)
🧰 Additional context used
🪛 Ruff
deepmd/dpmodel/fitting/general_fitting.py

460-460: SyntaxError: Expected an expression or a '}'

🔇 Additional comments (3)
deepmd/dpmodel/fitting/general_fitting.py (3)

442-444: LGTM: Improved type masking implementation

The use of xp.where provides a more explicit and safer way to handle type masking, ensuring proper zero initialization for non-matching types.


447-449: LGTM: Consistent vacuum contribution handling

The implementation correctly handles vacuum contribution removal for mixed types, maintaining consistency with the non-mixed types case.


450-455: LGTM: Improved bias addition with type safety

The implementation enhances type safety and efficiency by:

  • Using explicit type casting with xp.astype
  • Efficiently applying bias per atom type using xp.take
  • Properly reshaping the output to maintain dimensions

deepmd/dpmodel/fitting/general_fitting.py Outdated Show resolved Hide resolved
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Signed-off-by: Jinzhe Zeng <[email protected]>
github-merge-queue bot pushed a commit that referenced this pull request Nov 13, 2024
Tried to implement the decorator as in #4343, but encountered JIT
errors.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

## Release Notes

- **New Features**
- Enhanced precision handling across various descriptor classes and
methods, ensuring consistent tensor operations.
- Updated output formats in several classes to improve clarity and
usability.
- Introduced a new environment variable for stricter control over tensor
precision handling.
- Added a new parameter to the `DipoleFittingNet` class for excluding
specific types.

- **Bug Fixes**
- Removed conditions that skipped tests for "float32" data type,
allowing all tests to run consistently.

- **Documentation**
- Improved error messages for dimension mismatches and unsupported
parameters, enhancing user understanding.

- **Tests**
- Adjusted test parameters for consistency in handling `fparam` and
`aparam` across multiple test cases.
- Simplified tensor handling in tests by removing unnecessary type
conversions before compression.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
@njzjz njzjz added this pull request to the merge queue Nov 14, 2024
Merged via the queue into deepmodeling:devel with commit 6e815a2 Nov 14, 2024
51 checks passed
@njzjz njzjz deleted the fix-dpmodel-precision branch November 14, 2024 11:09
@njzjz njzjz linked an issue Nov 15, 2024 that may be closed by this pull request
This was referenced Nov 21, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Feature Request] Interface/Internal precision design & consistency
3 participants