fix(dpmodel): fix precision #4343
Conversation
Signed-off-by: Jinzhe Zeng <[email protected]>
📝 Walkthrough
The pull request introduces several updates across multiple files, primarily focusing on enhancing precision handling in descriptor computations and refining the Python testing workflow. Key changes include the addition of a new environment variable in the testing workflow, the implementation of a `cast_precision` decorator for descriptor and fitting computations, and corresponding updates to the consistency tests.
📜 Recent review details
Configuration used: CodeRabbit UI
📒 Files selected for processing (1)
🔇 Additional comments (4)
source/tests/common/dpmodel/test_network.py (4)
The new import of get_xp_precision is appropriate and supports creating test inputs at the requested precision.
The test properly creates input arrays with the correct precision, maintaining consistency with the precision parameter used in EmbeddingNet initialization.
The test properly handles precision-specific inputs and verifies both the network's ability to process inputs and the consistency of results between serialized and deserialized instances.
The test now properly creates input arrays with the correct precision based on the test parameter. This ensures thorough testing across different precision settings ("float32", "float64", "single", "double"). Let's verify that all precision options are handled correctly by get_xp_precision.
✅ Verification successful
Precision handling is correctly implemented and matches test requirements. The get_xp_precision function covers all of the precision options exercised by the test.
This confirms that the test's input array creation with dynamic precision is well-supported by the underlying implementation.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:

```bash
#!/bin/bash
# Description: Verify the implementation of get_xp_precision function
# Look for the function definition to ensure all precision options are handled
ast-grep --pattern 'def get_xp_precision($_, $_):
$$$
'
```
Length of output: 1677
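For readers following the review, a minimal sketch of what a helper like get_xp_precision plausibly looks like, inferred from the precision strings exercised above — the body is an assumption, not the verbatim deepmd implementation (the real helper also covers options such as "global", the package-wide default precision):

```python
# Hypothetical sketch of get_xp_precision, inferred from the test's
# precision options; not the actual deepmd code.
def get_xp_precision(xp, precision: str):
    """Map a precision string to a dtype in the array namespace `xp`."""
    if precision in ("float32", "single"):
        return xp.float32
    if precision in ("float64", "double"):
        return xp.float64
    raise ValueError(f"unsupported precision: {precision}")
```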
Actionable comments posted: 4
🧹 Outside diff range and nitpick comments (16)
source/tests/consistent/descriptor/test_se_t.py (1)
107-119: LGTM! Consider adding a return type hint.
The skip logic is well-implemented and consistent with the PR's objectives. The code is clean and well-structured.
Consider adding a return type hint to match the style of other properties:
```diff
-    def skip_array_api_strict(self):
+    def skip_array_api_strict(self) -> bool:
```
source/tests/consistent/fitting/test_dipole.py (2)
89-99: LGTM! Consider enhancing the comment clarity.
The implementation correctly skips float32 precision tests, aligning with the PR objective. However, the comment could be more explicit about the specific NumPy behavior being addressed.
Consider updating the comment to be more descriptive:
```diff
-        # NumPy doesn't throw errors for float64 x float32
+        # Skip float32 tests as NumPy silently promotes float32 to float64 during operations,
+        # which masks potential precision-related issues
```
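The NumPy behavior both comments refer to is easy to demonstrate directly:

```python
# float64 x float32 silently promotes to float64 rather than raising,
# so a cross-backend comparison cannot detect the precision mismatch.
import numpy as np

a = np.ones(3, dtype=np.float64)
b = np.ones(3, dtype=np.float32)
print((a * b).dtype)  # float64 — no error is raised
```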
101-111: LGTM! Consider reducing code duplication.
The implementation correctly handles both float32 precision and Array API strict mode availability. However, there's duplicated logic with skip_dp for float32 handling.
Consider extracting the shared float32 check into a private property:

```diff
+    @property
+    def _skip_float32(self) -> bool:
+        """Skip tests for float32 precision due to NumPy's silent promotion behavior."""
+        (_, precision, _) = self.param
+        return precision == "float32"
+
     @property
     def skip_dp(self) -> bool:
-        (
-            resnet_dt,
-            precision,
-            mixed_types,
-        ) = self.param
-        if precision == "float32":
-            # NumPy doesn't throw errors for float64 x float32
-            return True
+        if self._skip_float32:
+            return True
         return CommonTest.skip_dp

     @property
     def skip_array_api_strict(self) -> bool:
-        (
-            resnet_dt,
-            precision,
-            mixed_types,
-        ) = self.param
-        if precision == "float32":
-            # NumPy doesn't throw errors for float64 x float32
-            return True
+        if self._skip_float32:
+            return True
         return not INSTALLED_ARRAY_API_STRICT
```
source/tests/consistent/fitting/test_polar.py (1)
89-100: LGTM! Consider enhancing the comment for better clarity.
The implementation correctly handles skipping DP backend tests for float32 precision to avoid false positives.
Consider expanding the comment to better explain the implications:

```diff
-        # NumPy doesn't throw errors for float64 x float32
+        # Skip float32 tests as NumPy silently promotes float32 to float64 during operations,
+        # which could mask precision-related issues we want to catch
```
source/tests/consistent/descriptor/test_se_e2_a.py (1)
137-139: LGTM! Consider extracting the precision check.
The skip condition is consistent with the skip_dp implementation.
Consider extracting the precision check into a private method to avoid duplication:

```diff
+    def _should_skip_float32(self) -> bool:
+        (
+            resnet_dt,
+            type_one_side,
+            excluded_types,
+            precision,
+            env_protection,
+        ) = self.param
+        return precision == "float32"
+
     @property
     def skip_dp(self) -> bool:
-        if precision == "float32":
-            # NumPy doesn't throw errors for float64 x float32
-            return True
+        if self._should_skip_float32():
+            return True
         return CommonTest.skip_dp

     @property
     def skip_array_api_strict(self) -> bool:
-        if precision == "float32":
-            # NumPy doesn't throw errors for float64 x float32
-            return True
+        if self._should_skip_float32():
+            return True
         return not type_one_side or not INSTALLED_ARRAY_API_STRICT
```
source/tests/consistent/fitting/test_dos.py (1)
121-131: LGTM! Consider reducing code duplication.
The implementation correctly handles skipping array API strict tests for float32 precision. However, the float32 check logic is duplicated from skip_dp.
Consider extracting the common float32 check into a private method:

```diff
+    def _is_float32(self) -> bool:
+        (
+            resnet_dt,
+            precision,
+            mixed_types,
+            numb_fparam,
+            numb_aparam,
+            numb_dos,
+        ) = self.param
+        return precision == "float32"
+
     @property
     def skip_dp(self) -> bool:
-        (
-            resnet_dt,
-            precision,
-            mixed_types,
-            numb_fparam,
-            numb_aparam,
-            numb_dos,
-        ) = self.param
-        if precision == "float32":
+        if self._is_float32():
             # NumPy doesn't throw errors for float64 x float32
             return True
         return CommonTest.skip_dp

     @property
     def skip_array_api_strict(self) -> bool:
-        (
-            resnet_dt,
-            precision,
-            mixed_types,
-            numb_fparam,
-            numb_aparam,
-            numb_dos,
-        ) = self.param
-        if precision == "float32":
+        if self._is_float32():
             # NumPy doesn't throw errors for float64 x float32
             return True
         return not INSTALLED_ARRAY_API_STRICT
```
source/tests/consistent/descriptor/test_se_t_tebd.py (1)
153-172: LGTM! Consider extracting common precision check.
The implementation correctly handles array API strict mode skipping. However, the float32 check logic is duplicated from skip_dp.
Consider extracting the common precision check to reduce duplication:

```diff
+    @property
+    def _skip_float32(self) -> bool:
+        return self.param[8] == "float32"  # precision is param[8]
+
     @property
     def skip_dp(self) -> bool:
         (...param unpacking...)
-        if precision == "float32":
+        if self._skip_float32:
             return True
         return CommonTest.skip_dp

     @property
     def skip_array_api_strict(self) -> bool:
         (...param unpacking...)
-        if precision == "float32":
+        if self._skip_float32:
             return True
         return not INSTALLED_ARRAY_API_STRICT
```
source/tests/consistent/fitting/test_ener.py (1)
103-116: LGTM! Consider enhancing the documentation.
The skip_dp property correctly implements the skipping of float32 precision tests for the NumPy backend, which aligns with the PR objectives. The implementation is clean and the comment explains the rationale.
Consider enhancing the docstring to provide more context:

```diff
     @property
     def skip_dp(self) -> bool:
+        """Determines whether to skip tests for NumPy backend.
+
+        Returns:
+            bool: True if precision is float32 (to avoid NumPy's silent handling of
+            float64 x float32 operations) or if CommonTest.skip_dp is True.
+        """
```
source/tests/consistent/descriptor/test_se_atten_v2.py (1)
181-183: LGTM! Consider enhancing the comment.
The skip condition for float32 precision is correct. However, the comment could be more specific about the NumPy behavior.
Consider expanding the comment to explain the specific behavior:

```diff
-        # NumPy doesn't throw errors for float64 x float32
+        # Skip float32 tests as NumPy silently promotes float32 to float64 during operations,
+        # masking potential precision issues that would occur in other backends
```
source/tests/consistent/descriptor/test_dpa1.py (2)
185-187: LGTM! Consider enhancing the comment.
The skip condition for float32 precision is correct and aligns with the PR objectives. However, the comment could be more explicit about NumPy's type promotion behavior.
Consider expanding the comment to:

```diff
-        # NumPy doesn't throw errors for float64 x float32
+        # Skip float32 tests as NumPy silently promotes float32 to float64 during operations,
+        # which masks potential precision-related issues
```
244-246: LGTM! Consider extracting duplicated logic.
The skip condition is correct but duplicates the logic and comment from skip_dp. Consider extracting this common check into a private method.
Consider refactoring to:

```diff
+    def _should_skip_float32(self) -> bool:
+        """Check if tests should be skipped for float32 precision.
+
+        NumPy silently promotes float32 to float64 during operations,
+        which masks potential precision-related issues.
+        """
+        return self.param[-3] == "float32"  # precision parameter
+
     @property
     def skip_dp(self) -> bool:
-        if precision == "float32":
-            # NumPy doesn't throw errors for float64 x float32
-            return True
+        if self._should_skip_float32():
+            return True
         return CommonTest.skip_dp or ...

     @property
     def skip_array_api_strict(self) -> bool:
-        if precision == "float32":
-            # NumPy doesn't throw errors for float64 x float32
-            return True
+        if self._should_skip_float32():
+            return True
         return not INSTALLED_ARRAY_API_STRICT or ...
```
source/tests/consistent/descriptor/test_dpa2.py (3)
248-250: Enhance comment clarity regarding NumPy's mixed precision behavior.
The comment could be more specific about why NumPy's behavior with mixed precision operations necessitates skipping these tests.
Consider updating the comment to be more explicit:

```diff
-        # NumPy doesn't throw errors for float64 x float32
+        # Skip tests for float32 as NumPy silently promotes float32 to float64 during mixed precision operations,
+        # which masks potential precision-related issues that could occur in production
```
289-318: Simplify parameter unpacking in skip_array_api_strict.
The method unpacks all parameters but only uses 'precision'. This makes the code harder to maintain and violates the principle of minimal parameter usage.
Consider simplifying the parameter handling:

```diff
     @property
     def skip_array_api_strict(self) -> bool:
-        (
-            repinit_tebd_input_mode,
-            repinit_set_davg_zero,
-            repinit_type_one_side,
-            repinit_use_three_body,
-            # ... [removed for brevity]
-            use_tebd_bias,
-        ) = self.param
+        precision = self.param[23]  # Index of precision in param tuple
```

Alternatively, consider refactoring the test class to make precision more easily accessible without unpacking all parameters.
319-321: Duplicate comment about NumPy's behavior.
This comment is identical to the one in skip_dp. Consider creating a shared documentation string or constant to maintain consistency and make future updates easier.
Consider creating a constant at the class level:

```python
NUMPY_PRECISION_SKIP_REASON = "NumPy doesn't throw errors for float64 x float32"
```

Then use it in both methods:

```diff
     @property
     def skip_dp(self) -> bool:
         if precision == "float32":
-            # NumPy doesn't throw errors for float64 x float32
+            # {NUMPY_PRECISION_SKIP_REASON}
             return True

     @property
     def skip_array_api_strict(self) -> bool:
         if precision == "float32":
-            # NumPy doesn't throw errors for float64 x float32
+            # {NUMPY_PRECISION_SKIP_REASON}
             return True
```
deepmd/dpmodel/common.py (1)
150-156: Ensure that the Array type is defined or imported in the example.
In the example provided in the docstring, the type Array is used in the function signature, but it is not defined or imported. This may lead to confusion or errors when using the example code. Consider importing Array or replacing it with an appropriate type hint.
Apply this diff to update the example:

```diff
+from typing import Any
 ...
     @cast_precision
-    def f(x: Array, y: Array) -> Array:
+    def f(x: Any, y: Any) -> Any:
         return x**2 + y
```

Alternatively, if Array is defined elsewhere in your codebase, consider adding an import statement in the example.
deepmd/dpmodel/fitting/general_fitting.py (1)
367-371: Refactor repetitive type casting into a helper function.
The type casting of descriptor, fparam, and aparam to the specified precision is repeated. To enhance maintainability and reduce code duplication, consider creating a helper function or loop to handle these casts.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (21)
.github/workflows/test_python.yml (1 hunks)
deepmd/dpmodel/common.py (2 hunks)
deepmd/dpmodel/descriptor/dpa1.py (3 hunks)
deepmd/dpmodel/descriptor/dpa2.py (3 hunks)
deepmd/dpmodel/descriptor/se_e2_a.py (5 hunks)
deepmd/dpmodel/descriptor/se_r.py (2 hunks)
deepmd/dpmodel/descriptor/se_t.py (2 hunks)
deepmd/dpmodel/descriptor/se_t_tebd.py (3 hunks)
deepmd/dpmodel/fitting/general_fitting.py (2 hunks)
deepmd/jax/env.py (1 hunks)
source/tests/consistent/descriptor/test_dpa1.py (2 hunks)
source/tests/consistent/descriptor/test_dpa2.py (2 hunks)
source/tests/consistent/descriptor/test_se_atten_v2.py (2 hunks)
source/tests/consistent/descriptor/test_se_e2_a.py (2 hunks)
source/tests/consistent/descriptor/test_se_r.py (2 hunks)
source/tests/consistent/descriptor/test_se_t.py (2 hunks)
source/tests/consistent/descriptor/test_se_t_tebd.py (2 hunks)
source/tests/consistent/fitting/test_dipole.py (1 hunks)
source/tests/consistent/fitting/test_dos.py (1 hunks)
source/tests/consistent/fitting/test_ener.py (2 hunks)
source/tests/consistent/fitting/test_polar.py (1 hunks)
🔇 Additional comments (34)
deepmd/jax/env.py (1)
16-18: LGTM! The strict dtype promotion aligns with precision requirements.
The conditional configuration of strict dtype promotion is a good defensive measure to prevent silent precision loss during numerical operations. This complements the existing x64 precision setting and aligns with the PR's objective of fixing precision issues.
Let's verify how this affects the codebase:
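For context, the conditional configuration being verified might look roughly like the following sketch — illustrative of the described behavior, not the verbatim deepmd/jax/env.py:

```python
# Sketch: enable JAX's strict dtype promotion only when the environment
# variable is set, complementing the existing x64 precision setting.
import os

import jax

jax.config.update("jax_enable_x64", True)
if os.environ.get("DP_DTYPE_PROMOTION_STRICT") == "1":
    # "strict" makes JAX raise on implicit promotions such as
    # float64 * float32 instead of silently widening the result.
    jax.config.update("jax_numpy_dtype_promotion", "strict")
```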
.github/workflows/test_python.yml (1)
65-65: LGTM! Verify test coverage impact.
The addition of the DP_DTYPE_PROMOTION_STRICT environment variable aligns with the PR's objective of fixing precision issues. The placement in the TF2 eager mode test section is appropriate.
Let's verify if this affects test coverage:
✅ Verification successful
Environment variable is correctly placed and impacts JAX precision handling
The verification shows that DP_DTYPE_PROMOTION_STRICT is used in deepmd/jax/env.py to control dtype promotion behavior in JAX. The extensive test coverage (150+ test files) involving dtype/precision handling and JAX-specific tests confirms this change will properly enforce strict type checking during tests.
- The variable is referenced in deepmd/jax/env.py for JAX configuration
- The test suite includes comprehensive JAX tests under source/tests/consistent/ that will be affected
- A large number of tests involving dtype/precision handling will help catch any type-related issues
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:

```bash
#!/bin/bash
# Description: Check for any tests that might be affected by strict dtype promotion
# Look for tests involving dtype or precision
rg -l "dtype|precision|float32|float64" source/tests/
# Look for JAX-specific tests that might be affected
rg -l "jax" source/tests/
# Check for any existing usage of DP_DTYPE_PROMOTION_STRICT
rg "DP_DTYPE_PROMOTION_STRICT"
```
Length of output: 10429
source/tests/consistent/descriptor/test_se_r.py (2)
118-120: LGTM! Consistent handling of FP32 precision across test modes.
The implementation maintains consistency with the skip_dp property, ensuring that float32 precision tests are skipped in array API strict mode as well.
95-97: LGTM! Skipping FP32 tests for NumPy is appropriate.
The added check correctly prevents running tests with float32 precision since NumPy silently handles mixed-precision operations between float64 and float32, which could mask potential precision-related issues.
Let's verify if this change is consistent across other descriptor test files:
✅ Verification successful
Verified: The float32 precision skip is consistently implemented across descriptor tests
The verification shows that:
- The same float32 precision check pattern exists across multiple descriptor test files (test_dpa1.py, test_dpa2.py, test_se_atten_v2.py, test_se_e2_a.py, test_se_t.py, test_se_t_tebd.py)
- The implementation in test_se_r.py follows the same pattern used in other files, with identical comments about NumPy's behavior with float32
- All descriptor test files consistently handle float32 precision cases
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:

```bash
#!/bin/bash
# Description: Check if similar precision checks exist in other descriptor test files
# Expected: Similar skip_dp implementations in other test files
# Search for skip_dp methods in descriptor test files
rg -A 5 "def skip_dp.*precision.*float32" source/tests/consistent/descriptor/
```
Length of output: 79
Script:

```bash
#!/bin/bash
# Let's try a broader search to find skip_dp property implementations
rg -A 5 "@property\s+def skip_dp" source/tests/consistent/descriptor/
# Also search for similar float32 precision checks
rg "precision.*float32" source/tests/consistent/descriptor/
```
Length of output: 2545
Script:

```bash
#!/bin/bash
# Let's check the context of these float32 checks in one of the files
# to see if they follow the same pattern as test_se_r.py
rg -B 2 -A 2 "precision.*float32" source/tests/consistent/descriptor/test_dpa1.py
```
Length of output: 706
source/tests/consistent/descriptor/test_se_t.py (1)
92-94: LGTM! Verify similar changes in other test files.
The skip logic for float32 precision is correctly implemented and well-documented. The comment clearly explains why NumPy tests need to be skipped for float32 precision.
Let's verify that this skip logic is consistently implemented across other descriptor test files:
source/tests/consistent/fitting/test_polar.py (1)
101-112: LGTM! Verify test coverage for all precision combinations.
The implementation correctly combines precision-based skipping with installation status check.
Let's verify that the test parameterization covers all precision combinations:
✅ Verification successful
LGTM! Test coverage for precision combinations is comprehensive.
The test class is properly parameterized with all precision combinations:
- Tests run with both float64 and float32 precision
- For float32, array_api_strict tests are correctly skipped due to NumPy's mixed precision behavior
- For float64, tests run when the array_api_strict backend is installed
- The parameterization also covers combinations with resnet_dt and mixed_types
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:

```bash
#!/bin/bash
# Description: Check test parameterization and actual test executions
# Check test parameterization
rg -A 10 "@parameterized" source/tests/consistent/fitting/test_polar.py
# Check actual test executions and their skip status
rg "def test_" source/tests/consistent/fitting/test_polar.py
# Look for similar skip patterns in other test files
rg "skip_array_api_strict.*float32" source/tests
```
Length of output: 475
Script:

```bash
#!/bin/bash
# Let's get more context about the test class and its parameterization
# Get the full test class implementation
cat source/tests/consistent/fitting/test_polar.py
# Check for any other skip_array_api_strict implementations
rg -B 5 -A 5 "skip_array_api_strict" source/tests/
```
Length of output: 21551
source/tests/consistent/descriptor/test_se_e2_a.py (1)
101-103: LGTM! Verify NumPy's mixed precision behavior.
The skip condition for float32 precision is well-documented and aligns with the PR objectives.
Let's verify the NumPy behavior that necessitates this skip:
✅ Verification successful
Based on the search results, I can now verify the review comment. The codebase shows consistent handling of float32 precision across multiple test files, particularly in the context of NumPy's behavior with mixed precision operations. The documentation and test files confirm that NumPy doesn't throw errors for float64 x float32 operations, which is exactly what the skip condition is handling.
Skip condition for float32 precision is correctly implemented
The skip condition is well-justified as NumPy's mixed precision behavior requires special handling for float32 cases. This is consistently implemented across multiple test files in the codebase, including:
- Other descriptor tests like test_se_t.py, test_se_r.py, and test_se_atten_v2.py have similar skip conditions
- The comment accurately describes NumPy's behavior with mixed precision operations
- The implementation aligns with the project's precision handling patterns
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:

```bash
#!/bin/bash
# Description: Search for related test failures or issues with NumPy's mixed precision
# Search for related test failures or discussions
rg -l "float32.*float64|mixed precision|precision issue" --type md
# Search for similar skip conditions in other test files
rg "precision.*float32.*skip|skip.*float32" --type py
```
Length of output: 155
Script:

```bash
#!/bin/bash
# Let's check the content of the documentation file and search for more context
cat doc/train/training-advanced.md
# Search for any float32/float64 related tests or implementations
rg "float(32|64)" -A 3 --type py
# Search for NumPy specific precision handling
rg "numpy.*float|np\.(float32|float64)" -A 2 --type py
# Look for skip conditions in test files
fd "test.*\.py$" --exec grep -l "skip.*precision\|precision.*skip" {} \;
```
Length of output: 223836
source/tests/consistent/fitting/test_dos.py (1)
100-113: LGTM! Verify consistent skip behavior across test files.
The implementation correctly handles skipping tests for float32 precision due to NumPy's mixed precision behavior.
Let's verify that this skip behavior is consistently implemented across other test files:
✅ Verification successful
Skip behavior for float32 precision is consistently implemented across test files
The verification shows that all test files in the codebase consistently implement the same skip logic for float32 precision with identical comments explaining "NumPy doesn't throw errors for float64 x float32". This includes:
- All descriptor test files (se_e2_a, se_t, dpa1, se_atten_v2, se_t_tebd, dpa2, se_r)
- All fitting test files (dipole, ener, polar, dos)
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:

```bash
#!/bin/bash
# Description: Check for consistent implementation of skip_dp property across test files
# Expect: All test files should have similar skip logic for float32 precision
# Search for skip_dp property implementations
ast-grep --pattern 'def skip_dp(self) -> bool:
  $$$
  if $precision == "float32":
    $$$
    return True
  $$$'
# Also check for any inconsistent implementations
rg -l 'def skip_dp' | xargs rg 'float32|precision'
```
Length of output: 19917
source/tests/consistent/descriptor/test_se_t_tebd.py (1)
130-132: LGTM! Skipping float32 tests prevents silent precision issues.
The added check correctly implements the PR objective to skip float32 tests for NumPy, preventing potential silent precision issues when mixing float64 and float32 operations.
Let's verify this is consistent with other test files:
✅ Verification successful
I'm broadening the search to:
- Find any skip patterns related to precision across all test files
- Look specifically for float32 handling in the consistent tests directory
This will help verify if the approach is consistent with other test files.
The float32 skip pattern is consistent across the codebase
The search results show that the same float32 skip pattern with identical comment "NumPy doesn't throw errors for float64 x float32" is consistently implemented across multiple test files in the codebase, including:
- descriptor tests (se_t_tebd.py, se_t.py, se_r.py, se_e2_a.py, se_atten_v2.py, dpa1.py, dpa2.py)
- fitting tests (polar.py, property.py, dos.py, dipole.py, ener.py)
This confirms that the implementation in test_se_t_tebd.py follows the established pattern for handling float32 precision tests across the project.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:

```bash
#!/bin/bash
# Search for similar precision-based skip patterns in other test files
rg -l "precision.*float32.*return True" source/tests/
```
Length of output: 55
Script:

```bash
#!/bin/bash
# Let's try a broader search for precision-related skip patterns
rg -B2 -A2 "skip.*precision" source/tests/
# Also search for float32-related patterns
rg -B2 -A2 "float32" source/tests/consistent/
```
Length of output: 25430
source/tests/consistent/fitting/test_ener.py (1)
130-132: LGTM! Consistent precision handling.
The modification to the skip_array_api_strict property maintains consistency with skip_dp by applying the same precision-based skipping logic.
source/tests/consistent/descriptor/test_se_atten_v2.py (1)
244-246: LGTM! Consistent with DP implementation.
The skip condition for array_api_strict matches the DP implementation, ensuring consistent handling of float32 precision across different backends.
deepmd/dpmodel/descriptor/se_r.py (2)
17-17: LGTM: Import addition aligns with precision handling requirements.
The addition of the cast_precision import is appropriate for implementing precision control in the descriptor computations.
293-293: LGTM: Precision control enhancement via decorator.
The addition of the @cast_precision decorator appropriately addresses precision handling in descriptor computations. This change aligns with the PR objective of fixing precision issues.
Let's verify the numerical stability of the descriptor computations with the new precision handling:
✅ Verification successful
Let me gather more information about the precision-related tests by examining their implementation details.
Let me try a different approach to examine the test implementations.
Precision handling is properly tested and verified
The test suite comprehensively validates the @cast_precision decorator's functionality through:
- Tests with both float32 and float64 precision
- Numerical consistency checks between serialized/deserialized instances
- Validation across different tensor operations with appropriate tolerances
- Tests with various configurations (resnet_dt, exclude_mask)
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:

```bash
#!/bin/bash
# Description: Search for test files that verify numerical stability of descriptor computations
# Look for relevant test files
echo "Searching for relevant test files..."
rg -l "DescrptSeR" --type py | grep "test"
# Look for precision-related tests
echo "Searching for precision-related tests..."
rg -A 5 "precision|dtype|fp32|fp64" --type py | grep -B 2 "DescrptSeR"
```
Length of output: 3403
Script:

```bash
#!/bin/bash
# Check the actual test implementations for precision-related assertions
rg -B 2 -A 10 "test.*prec" source/tests/pt/model/test_descriptor_se_r.py
# Look for precision-related test cases in other test files
rg -B 2 -A 10 "test.*prec" source/tests/consistent/descriptor/test_se_r.py
# Check if there are any existing precision validation tests
rg -B 2 -A 10 "assert.*dtype" source/tests/pt/model/test_descriptor_se_r.py
```
Length of output: 229
Script:

```bash
#!/bin/bash
# Look for test methods in the descriptor test files
echo "Checking test_descriptor_se_r.py..."
rg -A 20 "def test_" source/tests/pt/model/test_descriptor_se_r.py
# Check for precision/dtype related variables and test setup
echo "Checking precision-related setup..."
rg -B 5 -A 5 "dtype|precision|fp32|fp64" source/tests/pt/model/test_descriptor_se_r.py
# Look for test assertions related to descriptor output
echo "Checking descriptor output assertions..."
rg -B 5 -A 5 "assert|self.assert" source/tests/pt/model/test_descriptor_se_r.py
```
Length of output: 7177
deepmd/dpmodel/descriptor/se_t.py (2)
17-17: LGTM: Import addition is well-organized.
The addition of the cast_precision import is properly grouped with related imports from the same module.
268-268: Verify numerical equivalence with previous implementation.
The addition of the @cast_precision decorator aligns with the PR's objective to fix precision issues. However, since this affects numerical computations in a scientific computing context, it's crucial to verify that the results remain numerically equivalent to the previous implementation.
Let's verify the numerical precision handling:
✅ Verification successful
Based on the search results, I can now generate the final response:
@cast_precision decorator consistently handles precision across descriptors
The addition of the @cast_precision decorator to the call method is part of a systematic approach to handle precision casting across the codebase. The search results show:
- The decorator is consistently used across multiple descriptor implementations (se_t, se_r, se_a, dpa1, etc.)
- It's properly defined in both dpmodel/common.py and tf/common.py for different backends
- The implementation uses PRECISION_DICT to ensure consistent precision handling
- No remaining explicit precision casting in the descriptor implementations that would conflict with the decorator
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:

```bash
#!/bin/bash
# Description: Search for any remaining explicit precision casting in descriptor implementations
# to ensure consistent precision handling across the codebase
# Look for any remaining astype calls that might indicate inconsistent precision handling
rg "astype.*prec|PRECISION_DICT" --type python
# Look for other descriptor implementations to verify consistent usage of cast_precision
ast-grep --pattern 'class Descrpt$_ {
  $$$
  @cast_precision
  def call($$$) {
    $$$
  }
  $$$
}'
```
Length of output: 182
Script:

```bash
#!/bin/bash
# Search for astype and precision related patterns without file type restriction
rg "astype.*prec|PRECISION_DICT"
# Look for cast_precision decorator usage
rg "@cast_precision" -A 2
# Look for descriptor class implementations
rg "class.*Descriptor|class.*Descrpt" -A 5
# Check the cast_precision decorator implementation
rg "def cast_precision"
```
Length of output: 86322
deepmd/dpmodel/descriptor/se_t_tebd.py (2)
171-171: LGTM: Added precision instance variable.
The addition of the precision instance variable ensures proper precision handling throughout the class lifecycle.
291-291: LGTM: Added precision casting decorator.
The @cast_precision decorator ensures consistent precision handling for input and output tensors in the call method.
Let's verify that all tensor operations in the method handle the precision correctly:
✅ Verification successful
Let me gather more information about the cast_precision decorator and its usage to ensure there are no conflicts.
Let me check the implementation of the cast_precision decorator to ensure it's compatible with the current usage.
Verified: Precision casting decorator is correctly applied
The @cast_precision decorator from deepmd/dpmodel/common.py is correctly implemented to:
- Cast input arrays from global precision to the class's precision before method execution
- Cast output arrays back to global precision after method execution
The decorator's usage aligns with the existing precision handling in the file, where:
- The class already uses PRECISION_DICT[self.precision] for numpy arrays
- Explicit dtype casting is properly managed with astype
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:

```bash
#!/bin/bash
# Check for any explicit dtype casting in the method that might conflict with the decorator
rg -A 2 "astype|dtype=" deepmd/dpmodel/descriptor/se_t_tebd.py
# Check for any numpy array creation that might need precision handling
rg "np\.(zeros|ones|full|empty)" deepmd/dpmodel/descriptor/se_t_tebd.py
```
Length of output: 693
Script:

```bash
#!/bin/bash
# Check the implementation of cast_precision decorator
rg -B 2 -A 5 "@cast_precision" deepmd/dpmodel/descriptor/se_t_tebd.py
# Check for any other precision-related code in the file
rg "precision" deepmd/dpmodel/descriptor/se_t_tebd.py
# Look for the decorator definition
fd -e py -x rg -l "def cast_precision"
```
Length of output: 1082
Script:

```bash
#!/bin/bash
# Check the implementation of cast_precision decorator in both files
rg -B 2 -A 10 "def cast_precision" deepmd/tf/common.py deepmd/dpmodel/common.py
# Check if the call method has any tensor operations that might be affected
rg -A 10 "def call" deepmd/dpmodel/descriptor/se_t_tebd.py
```
Length of output: 2199
deepmd/dpmodel/descriptor/dpa2.py (3)
17-17: LGTM: Import of cast_precision decorator.
The addition of the cast_precision import aligns with the precision handling enhancement objective.
762-762: LGTM: Addition of cast_precision decorator.
The addition of the @cast_precision decorator to the call method appropriately implements precision control for tensor operations, aligning with the PR's objective to fix precision handling.
Line range hint 762-924: Verify precision handling in tensor operations.
Let's ensure that precision handling is consistent across all tensor operations in the call method and related files.
deepmd/dpmodel/common.py (2)
6-8: Importing wraps from functools is appropriate.
The wraps decorator is correctly imported and used to preserve the metadata of the original function.
11-11: Importing Callable and Optional from typing.
The imports are necessary for the type hints used in the functions.
deepmd/dpmodel/fitting/general_fitting.py (4)
447-449: Efficient masking of atom_property using xp.where.
Using xp.where to apply the mask to atom_property improves code clarity and computational efficiency.
451-451: Consistent casting of outs to global precision.
Casting outs to global precision ensures consistency in computations and aligns with the precision used elsewhere in the codebase.
455-459: Correct application of per-atom biases in mixed types scenario.
The code correctly adds biases to outs based on each atom's type using xp.take and reshaping. This ensures that atoms receive the appropriate bias according to their type.
466-466: Effective use of exclude_mask with xp.where.
Applying exclude_mask to outs using xp.where efficiently zeroes out contributions from excluded atom types, enhancing code clarity.
deepmd/dpmodel/descriptor/se_e2_a.py (5)
18-18: Import cast_precision for consistent precision handling.
The addition of cast_precision to the imports allows for consistent precision casting throughout the module, aligning with the goal of fixing precision issues.
341-341: Apply @cast_precision decorator to call method.
Applying the @cast_precision decorator ensures that the inputs and outputs of the call method are correctly cast to the desired precision, improving numerical stability and consistency.
417-417: Remove explicit precision casting inside the call method.
By removing explicit casting to a global precision, the method now relies on the @cast_precision decorator to handle precision, simplifying the code and reducing redundancy.
506-506: Apply @cast_precision decorator to call method in DescrptSeAArrayAPI.
Adding the @cast_precision decorator to the call method of DescrptSeAArrayAPI ensures consistent precision handling, mirroring the update made in the base class DescrptSeA.
586-586: Remove explicit precision casting in return statement.
Removing the explicit casting in the return statement is appropriate because the @cast_precision decorator now manages precision, simplifying the code and ensuring consistency.
deepmd/dpmodel/descriptor/dpa1.py (3)
22-23: Import necessary utilities for precision handling.
The inclusion of cast_precision and to_numpy_array is appropriate and necessary for managing precision settings and array conversions within the module.
333-333: Store precision setting in class instance.
Assigning self.precision = precision in the constructor ensures that the precision configuration is consistently accessible throughout the class methods.
453-453: Apply @cast_precision decorator to call method.
Applying the @cast_precision decorator to the call method ensures that computations within this method adhere to the specified precision, enhancing numerical consistency and preventing potential precision-related issues.
Signed-off-by: Jinzhe Zeng <[email protected]>
This reverts commit 73b4227.
I reverted "skip FP32 for NP". The issue is that if dpmodel is skipped, the reference backend falls back to TF, which is not expected...
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (1)
deepmd/dpmodel/fitting/general_fitting.py (1)
456-459: Consider combining bias addition and exclude mask operations.
While the current implementation is correct, you could potentially improve performance by combining the bias addition and exclude mask application into a single operation:

```diff
-outs = xp.astype(outs, get_xp_precision(xp, "global"))
-outs += xp.reshape(
-    xp.take(self.bias_atom_e, xp.reshape(atype, [-1]), axis=0),
-    [nf, nloc, net_dim_out],
-)
-outs = xp.where(exclude_mask[:, :, None], outs, xp.zeros_like(outs))
+bias = xp.reshape(
+    xp.take(self.bias_atom_e, xp.reshape(atype, [-1]), axis=0),
+    [nf, nloc, net_dim_out],
+)
+outs = xp.where(
+    exclude_mask[:, :, None],
+    xp.astype(outs, get_xp_precision(xp, "global")) + bias,
+    xp.zeros_like(bias),
+)
```

This change:
- Reduces memory allocation by combining operations
- Avoids unnecessary computation for excluded atoms
Also applies to: 463-463
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
deepmd/dpmodel/fitting/general_fitting.py (2 hunks)
🔇 Additional comments (3)
deepmd/dpmodel/fitting/general_fitting.py (3)
367-371: LGTM: Proper precision casting for input tensors.
The explicit type casting ensures consistent precision handling across all input tensors using the get_xp_precision helper.
447-449: LGTM: Improved output handling using xp.where.
The use of xp.where for conditional output is more explicit and potentially more efficient than direct masked assignment.
452-452: LGTM: Proper handling of mixed types and vacuum contribution.
The code correctly processes outputs for mixed types and properly handles the vacuum contribution subtraction when required.
Also applies to: 455-455
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (1)
.github/workflows/test_python.yml (1)
65-65: Consider adding a comment explaining the variable's purpose.
To improve maintainability, consider adding a comment explaining why this environment variable is needed and its relationship with JAX's dtype promotion behavior.
Example addition:

```diff
+  # Enable strict dtype promotion in JAX to catch precision issues
   DP_DTYPE_PROMOTION_STRICT: 1
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
.github/workflows/test_python.yml (1 hunks)
🔇 Additional comments (1)
.github/workflows/test_python.yml (1)
65-65: LGTM! Verify variable usage across the codebase.
The addition of DP_DTYPE_PROMOTION_STRICT=1 for TF2 eager mode tests is appropriate for enforcing stricter type checking.
Let's verify how this environment variable is used:
✅ Verification successful
Verification Successful: Variable Usage Confirmed
The environment variable DP_DTYPE_PROMOTION_STRICT is used consistently within deepmd/jax/env.py to enforce strict dtype promotion.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:

```bash
#!/bin/bash
# Search for other occurrences of DP_DTYPE_PROMOTION_STRICT
rg "DP_DTYPE_PROMOTION_STRICT" -A 3
# Look for related JAX dtype promotion settings
rg "dtype.*promotion" -A 3
```
Length of output: 422
Signed-off-by: Jinzhe Zeng <[email protected]>
```python
@overload
def safe_cast_array(
    input: np.ndarray, from_precision: str, to_precision: str
) -> np.ndarray: ...
@overload
def safe_cast_array(input: None, from_precision: str, to_precision: str) -> None: ...
```

Check notice — Code scanning / CodeQL: Statement has no effect (Note; flagged on each overload's `...` body)
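A body consistent with these overloads might look like the sketch below — an assumption based on the behavior described in this review (cast only when the input's dtype matches the source precision; pass None through), not the verbatim deepmd implementation:

```python
# Hypothetical body for safe_cast_array. It relies on array_api_compat to
# resolve the array namespace, and on the get_xp_precision helper discussed
# earlier in this review to look up dtypes by name.
import array_api_compat

def safe_cast_array(input, from_precision: str, to_precision: str):
    if input is None:
        return None
    xp = array_api_compat.array_namespace(input)
    if input.dtype == get_xp_precision(xp, from_precision):
        return xp.astype(input, get_xp_precision(xp, to_precision))
    return input
```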
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (3)
deepmd/dpmodel/common.py (2)
156-157: Fix undefined type hint in docstring example.
In the example provided in the docstring, the type Array is used in the function signature, but it is not defined or imported. This may lead to confusion or errors when using the example code. Consider using np.ndarray or adding a proper import statement.

```diff
- ... def f(x: Array, y: Array) -> Array:
+ ... def f(x: np.ndarray, y: np.ndarray) -> np.ndarray:
```
204-205: Update docstring parameter and return types.
The docstring incorrectly refers to tf.Tensor. Update it to reflect the actual types used in the function.

```diff
- input : tf.Tensor
+ input : np.ndarray
      Input tensor
  ...
- tf.Tensor
+ np.ndarray
      casted Tensor
```

Also applies to: 213-214
deepmd/dpmodel/atomic_model/base_atomic_model.py (1)
210-210: Minor optimization: Simplify dict key lookup.
Use kk in ret_dict instead of kk in ret_dict.keys() for more efficient key membership testing.

```diff
- for kk in ret_dict.keys():
+ for kk in ret_dict:
```

🧰 Tools
🪛 Ruff
210-210: Use key in dict instead of key in dict.keys(). Remove .keys() (SIM118)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (2)
deepmd/dpmodel/atomic_model/base_atomic_model.py (1 hunks)
deepmd/dpmodel/common.py (2 hunks)
🧰 Additional context used
🪛 Ruff
deepmd/dpmodel/atomic_model/base_atomic_model.py
210-210: Use key in dict instead of key in dict.keys(). Remove .keys() (SIM118)
🪛 GitHub Check: CodeQL
deepmd/dpmodel/common.py
[notice] 189-189: Statement has no effect.
[notice] 191-191: Statement has no effect.
🔇 Additional comments (5)
deepmd/dpmodel/common.py (3)
6-8: LGTM: Imports are properly organized and necessary.
The new imports support the added functionality for precision management.
Also applies to: 11-13
160-183: LGTM: Robust implementation of precision casting.
The decorator implementation is well-structured and handles various input/output scenarios correctly:
The decorator implementation is well-structured and handles various input/output scenarios correctly:
- Properly processes both args and kwargs
- Handles multiple return types (tuple, dict)
- Uses safe type casting
186-220: LGTM: Well-structured type-safe implementation.
The function is properly typed with overloads and correctly handles both NumPy arrays and Array API objects.
🧰 Tools
🪛 GitHub Check: CodeQL
[notice] 189-189: Statement has no effect.
[notice] 191-191: Statement has no effect.
deepmd/dpmodel/atomic_model/base_atomic_model.py (2)
204-208: LGTM: Improved mask computation logic.
The change from element-wise multiplication to logical_and is more semantically correct for combining boolean masks and helps maintain precision by avoiding floating-point operations.
213-215: LGTM: Enhanced precision in output masking.
The use of xp.where with explicit zeros provides better precision control compared to direct multiplication with the mask. The array reshaping properly handles multi-dimensional outputs while maintaining array API compatibility.
Signed-off-by: Jinzhe Zeng <[email protected]>
Actionable comments posted: 1
🧹 Outside diff range and nitpick comments (1)
deepmd/dpmodel/common.py (1)
124-183: LGTM with minor docstring improvements.
The decorator implementation is solid and handles all cases correctly, including tuple and dict returns. Consider improving the example in the docstring for clarity.
Update the example to use proper type hints and actual numpy arrays:

```diff
 Examples
 --------
 >>> class A:
 ...     def __init__(self):
 ...         self.precision = "float32"
 ...
 ...     @cast_precision
-...     def f(x: Array, y: Array) -> Array:
+...     def f(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
 ...         return x**2 + y
+...
+>>> import numpy as np
+>>> a = A()
+>>> x = np.array([1.0, 2.0], dtype=np.float64)
+>>> y = np.array([3.0, 4.0], dtype=np.float64)
+>>> result = a.f(x, y)  # result will be float64, internally computed as float32
```
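To make the reviewed behavior concrete, here is a minimal sketch of a decorator with the semantics described in this thread — array inputs cast from the global precision to self.precision, outputs cast back, tuple returns handled element-wise. The names are illustrative rather than the verbatim deepmd code (the real version also unwraps dict returns), and it builds on the safe_cast_array helper sketched earlier:

```python
from functools import wraps

def cast_precision(func):
    """Cast array args/kwargs to self.precision and cast results back."""
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        returned = func(
            self,
            *(safe_cast_array(a, "global", self.precision) for a in args),
            **{
                k: safe_cast_array(v, "global", self.precision)
                for k, v in kwargs.items()
            },
        )
        if isinstance(returned, tuple):
            return tuple(
                safe_cast_array(r, self.precision, "global") for r in returned
            )
        return safe_cast_array(returned, self.precision, "global")
    return wrapper
```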
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
deepmd/dpmodel/common.py (2 hunks)
🧰 Additional context used
🪛 GitHub Check: CodeQL
deepmd/dpmodel/common.py
[notice] 189-189: Statement has no effect.
[notice] 191-191: Statement has no effect.
🔇 Additional comments (1)
deepmd/dpmodel/common.py (1)
6-13: LGTM: Import additions are appropriate.
The new imports support the decorator pattern and type system features appropriately.
Signed-off-by: Jinzhe Zeng <[email protected]>
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (2)
deepmd/dpmodel/fitting/dipole_fitting.py (1)
Additional Precision Handling Required in Multiple Modules
Several precision-sensitive operations using np.einsum, np.dot, and np.matmul were identified. It is recommended to apply the @cast_precision decorator to these methods to ensure consistent numerical stability across the codebase.
deepmd/utils/out_stat.py (Lines 70-72)
deepmd/dpmodel/descriptor/se_e2_a.py (Lines 408, 415)
source/ipi/tests/test_driver.py (Multiple lines)
deepmd/dpmodel/fitting/dipole_fitting.py (Additional lines)
- [Refer to full list in shell script output]
🔗 Analysis chain
Line range hint 180-238: Excellent addition of precision handling to dipole calculations.
The addition of the @cast_precision decorator improves the robustness of numerical computations in the dipole fitting calculations. The implementation:
- Ensures consistent precision across array operations
- Maintains clean matrix operations with modern syntax
- Preserves comprehensive type hints and documentation
Let's verify the precision handling impact:
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:

```bash
#!/bin/bash
# Description: Check for other precision-sensitive operations in the codebase
# that might need similar treatment
# Look for similar numerical computation patterns
ast-grep --pattern 'np.einsum($$$)'
# Check for other array operations that might need precision handling
rg -l 'import numpy as np' | xargs rg '@|np\.dot|np\.matmul'
```

Length of output: 77769
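As a usage illustration, applying the decorator to a method of a class that carries a precision attribute might look like this — a sketch assuming the decorator is importable from deepmd.dpmodel.common, the location discussed earlier in this review (requires deepmd-kit installed; class and method names are hypothetical):

```python
import numpy as np

from deepmd.dpmodel.common import cast_precision  # path per this review

class TinyFitting:
    def __init__(self, precision: str = "float32"):
        self.precision = precision

    @cast_precision
    def call(self, x: np.ndarray) -> np.ndarray:
        # x arrives cast to self.precision; the decorator casts the
        # result back to the global (default float64) precision.
        return x**2

out = TinyFitting().call(np.ones(3))  # computed in float32, returned float64
```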
deepmd/dpmodel/atomic_model/base_atomic_model.py (1)
210-210: Minor optimization: Use direct dictionary membership test.
Replace ret_dict.keys() with direct dictionary membership test for better performance.

```diff
-for kk in ret_dict.keys():
+for kk in ret_dict:
```

🧰 Tools
🪛 Ruff
210-210: Use key in dict instead of key in dict.keys(). Remove .keys() (SIM118)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (5)
deepmd/dpmodel/atomic_model/base_atomic_model.py (1 hunks)
deepmd/dpmodel/fitting/dipole_fitting.py (2 hunks)
deepmd/dpmodel/fitting/general_fitting.py (1 hunks)
deepmd/dpmodel/fitting/invar_fitting.py (2 hunks)
deepmd/dpmodel/fitting/polarizability_fitting.py (4 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- deepmd/dpmodel/fitting/general_fitting.py
🧰 Additional context used
🪛 Ruff
deepmd/dpmodel/atomic_model/base_atomic_model.py
210-210: Use key in dict instead of key in dict.keys(). Remove .keys() (SIM118)
🔇 Additional comments (8)
deepmd/dpmodel/fitting/dipole_fitting.py (1)
14-16: LGTM: Clean import addition for precision handling.
The addition of the cast_precision import aligns with the broader initiative to enhance precision handling across the codebase.
import aligns with the broader initiative to enhance precision handling across the codebase.
deepmd/dpmodel/fitting/invar_fitting.py (2)
13-15: LGTM: Clean import addition.
The import of cast_precision is properly organized with other imports from the same module.
209-209: Verify precision handling across different input types.
The addition of the @cast_precision decorator is appropriate for ensuring consistent precision across numerical computations. However, let's verify its behavior with different input types.
deepmd/dpmodel/atomic_model/base_atomic_model.py (2)
204-208: LGTM! Improved boolean mask computation.
The change from element-wise multiplication to logical AND is a good improvement as it:
- Makes the boolean intent more explicit
- Potentially improves performance through short-circuiting
- Enhances code maintainability
213-216: LGTM! Robust precision-preserving masking implementation.
The implementation correctly handles n-dimensional arrays while preserving precision through (a runnable sketch follows the list):
- Proper dimension handling with math.prod
- Precise masking using xp.zeros_like
- Shape preservation through careful reshaping
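A runnable sketch of that pattern, using NumPy as a stand-in for the array-API namespace xp (shapes and mask values are illustrative, not the actual deepmd code):

```python
import math

import numpy as np

xp = np  # stand-in for the array-API namespace
out = xp.ones((2, 3, 4, 5))  # (nframes, nloc, ...extra dims)
mask_a = xp.asarray([[True, True, False], [True, False, True]])
mask_b = xp.asarray([[True, False, True], [True, True, True]])

atom_mask = xp.logical_and(mask_a, mask_b)  # combine boolean masks
shape = out.shape
flat = xp.reshape(out, (shape[0], shape[1], math.prod(shape[2:])))
flat = xp.where(atom_mask[:, :, None], flat, xp.zeros_like(flat))
out = xp.reshape(flat, shape)  # masked entries are exactly zero
```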
deepmd/dpmodel/fitting/polarizability_fitting.py (3)
18-18: LGTM: Import addition is appropriate.
The addition of the cast_precision import aligns with the module's precision handling requirements.
245-245: LGTM: Decorator addition enhances precision handling.
The @cast_precision decorator ensures consistent precision handling across input and output tensors, which is crucial for numerical stability in scientific computations.
290-291: LGTM: Type casting ensures precision consistency.
The explicit type casting in both the scale and constant_matrix operations ensures consistent precision throughout the calculations. This is particularly important for:
- Preventing silent precision loss during tensor operations
- Maintaining numerical stability in scientific computations
However, let's verify that these changes don't introduce any precision-related regressions.
Also applies to: 314-318
✅ Verification successful
Type casting changes are correct and safe
The explicit type casting to match output tensor dtype follows established patterns in the codebase and maintains precision consistency. The changes are well-integrated with the global precision handling architecture.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:

```bash
#!/bin/bash
# Description: Check for any existing precision-related tests and their results
# Look for precision-related test files
echo "Searching for precision-related tests..."
rg -l "precision|dtype" --type py "tests/"
# Look for any existing precision-related issues
echo "Checking for precision-related issues in the codebase..."
rg "precision|dtype" --type py
```
Length of output: 271265
Co-authored-by: Han Wang <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
deepmd/dpmodel/fitting/general_fitting.py (1 hunks)
🧰 Additional context used
🪛 Ruff
deepmd/dpmodel/fitting/general_fitting.py
460-460: SyntaxError: Expected an expression or a '}'
🔇 Additional comments (3)
deepmd/dpmodel/fitting/general_fitting.py (3)
442-444: LGTM: Improved type masking implementation.
The use of xp.where provides a more explicit and safer way to handle type masking, ensuring proper zero initialization for non-matching types.
447-449: LGTM: Consistent vacuum contribution handling.
The implementation correctly handles vacuum contribution removal for mixed types, maintaining consistency with the non-mixed types case.
450-455: LGTM: Improved bias addition with type safety.
The implementation enhances type safety and efficiency by (a small runnable sketch follows):
- Using explicit type casting with xp.astype
- Efficiently applying bias per atom type using xp.take
- Properly reshaping the output to maintain dimensions
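A runnable sketch of the per-type bias gather described in this list (NumPy stand-in for the array-API namespace; names follow the review discussion, values are illustrative):

```python
import numpy as np

xp = np  # stand-in for the array-API namespace
nf, nloc, net_dim_out = 2, 3, 1
outs = xp.zeros((nf, nloc, net_dim_out))
bias_atom_e = xp.asarray([[0.1], [0.2], [0.3], [0.4]])  # one bias row per type
atype = xp.asarray([[0, 1, 2], [3, 0, 1]])              # per-atom type indices

# Gather each atom's bias by type, reshape to the output layout, and add.
outs += xp.reshape(
    xp.take(bias_atom_e, xp.reshape(atype, [-1]), axis=0),
    (nf, nloc, net_dim_out),
)
```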
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Signed-off-by: Jinzhe Zeng <[email protected]>
Tried to implement the decorator as in #4343, but encountered JIT errors.

Summary by CodeRabbit

Release Notes

New Features
- Enhanced precision handling across various descriptor classes and methods, ensuring consistent tensor operations.
- Updated output formats in several classes to improve clarity and usability.
- Introduced a new environment variable for stricter control over tensor precision handling.
- Added a new parameter to the DipoleFittingNet class for excluding specific types.

Bug Fixes
- Removed conditions that skipped tests for the "float32" data type, allowing all tests to run consistently.

Documentation
- Improved error messages for dimension mismatches and unsupported parameters, enhancing user understanding.

Tests
- Adjusted test parameters for consistency in handling fparam and aparam across multiple test cases.
- Simplified tensor handling in tests by removing unnecessary type conversions before compression.
Signed-off-by: Jinzhe Zeng <[email protected]>
Summary by CodeRabbit
Release Notes
New Features
- Introduced the DP_DTYPE_PROMOTION_STRICT environment variable to enhance precision handling in TensorFlow tests.
- Applied @cast_precision to several descriptor classes, improving precision management during computations.

Bug Fixes
- Fixes in the GeneralFitting class for better output predictions.
- Fixes in the BaseAtomicModel class.
- Fixes in the NativeLayer class.

Chores