Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(jax): SavedModel C++ interface (including DPA-2 supports) #4307

Merged
merged 81 commits into from
Nov 13, 2024

Conversation

njzjz
Copy link
Member

@njzjz njzjz commented Nov 4, 2024

Including nlist and no nlist interface.

The limitation: A SavedModel created on a device cannot be run on another. For example, a CUDA model cannot be run on the CPU.

The model is generated using #4336.

Summary by CodeRabbit

Release Notes

  • New Features

    • Added support for the JAX backend, including specific model and checkpoint file formats.
    • Introduced a new shell script for model conversion to enhance usability.
    • Updated installation documentation to clarify JAX support and requirements.
    • New section in documentation detailing limitations of the JAX backend with LAMMPS.
  • Bug Fixes

    • Enhanced error handling for model initialization and backend compatibility.
  • Documentation

    • Updated backend documentation to include JAX details and limitations.
    • Improved clarity in installation instructions for both TensorFlow and JAX.
  • Tests

    • Added comprehensive unit tests for JAX integration with the Deep Potential class.
    • Expanded test coverage for LAMMPS integration with DeepMD.
  • Chores

    • Updated CMake configurations and workflow files for improved testing and dependency management.

njzjz added 6 commits November 4, 2024 03:50
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
Copy link
Contributor

coderabbitai bot commented Nov 4, 2024

📝 Walkthrough

Walkthrough

The pull request introduces support for the JAX backend in DeePMD-kit, updating documentation and codebase to accommodate this addition. Key changes include the inclusion of JAX-specific model and checkpoint file formats, installation instructions, and the implementation of a new DeepPotJAX class that interfaces with TensorFlow's C API. The workflow configurations for testing have also been modified to include model conversion steps. Additionally, extensive testing frameworks have been established for validating the functionality of the JAX backend with LAMMPS and DeepMD.

Changes

File Change Summary
doc/backend.md Updated documentation to include JAX as a supported backend, specifying model and checkpoint filename extensions, version requirements, and limitations.
doc/install/install-from-source.md Enhanced installation instructions for JAX alongside TensorFlow, clarified compiler requirements, and updated CMake configurations.
source/api_cc/include/DeepPotJAX.h Introduced DeepPotJAX class for JAX support, with multiple constructors and methods for model interaction.
source/api_cc/include/common.h Updated DPBackend enum to include JAX.
source/api_cc/src/DeepPot.cc Modified DeepPot class to support JAX backend initialization with appropriate error handling.
source/api_cc/src/DeepPotJAX.cc Implemented TensorFlow C API interface for JAX, including initialization and computation methods.
source/api_cc/tests/test_deeppot_jax.cc Added unit tests for the DeepPot class focusing on JAX functionality using Google Test framework.
source/cmake/googletest.cmake.in Updated Google Test library version from release-1.12.1 to v1.14.0.
source/lmp/tests/test_lammps_jax.py Implemented tests for LAMMPS integration with DeepMD using JAX, including setup and assertions.
doc/model/dpa2.md Added section on limitations of the JAX backend with LAMMPS.
source/api_c/include/c_api.h Incremented API version to 25 and added DP_NlistSetMapping function.
source/api_c/include/deepmd.hpp Introduced set_mapping method in InputNlist structure.
source/api_c/src/c_api.cc Added DP_NlistSetMapping function to DP_Nlist class.
source/lib/include/neighbor_list.h Enhanced InputNlist structure with mapping capabilities.
source/lmp/fix_dplr.cpp Improved atom mapping handling in FixDPLR class.
source/lmp/pair_deepmd.cpp Enhanced atom mapping handling in PairDeepMD class for JAX.
source/lmp/tests/test_lammps_dpa_jax.py Established tests for DPA using LAMMPS with JAX integration.
.github/workflows/test_cc.yml Updated CI workflow for Python dependencies and added model conversion step.
.github/workflows/test_cuda.yml Modified CI workflow for CUDA, including Docker image update and model conversion step.
examples/water/lmp/jax_dpa2.lammps Added configuration file for simulating water with LAMMPS using JAX backend.
source/tests/infer/convert-models.sh New script for converting model files using the dp convert-backend command.

Possibly related PRs

Suggested reviewers

  • CaRoLZhangxy
  • wanghan-iapcm
  • iProzd

📜 Recent review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between cfff834 and ca02625.

📒 Files selected for processing (1)
  • .github/workflows/test_cuda.yml (3 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • .github/workflows/test_cuda.yml

Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media?

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Generate unit testing code for this file.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai generate unit testing code for this file.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and generate unit testing code.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

CodeRabbit Configuration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • Please see the configuration documentation for more information.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 10

🧹 Outside diff range and nitpick comments (14)
source/cmake/googletest.cmake.in (1)

14-14: Document the reason for upgrading Google Test.

Consider adding a comment explaining why the upgrade to v1.14.0 was necessary. This helps future maintainers understand the motivation behind version changes.

-  GIT_TAG           v1.14.0
+  # Upgraded to v1.14.0 to support new testing features needed for JAX backend
+  GIT_TAG           v1.14.0
source/api_cc/include/common.h (2)

16-16: Document JAX backend limitations.

Given the known limitations mentioned in the PR description:

  1. Neighbor list only support
  2. Device compatibility restrictions
  3. Memory leak concerns

Consider adding documentation comments above the enum to clarify these limitations for API users.

Example addition:

+/**
+ * @brief Backend types supported by DeePMD-kit
+ * @note JAX backend has the following limitations:
+ * - Only supports neighbor list operations
+ * - Models created on one device (e.g., CUDA) cannot run on different devices (e.g., CPU)
+ * - Known memory leaks exist in the TensorFlow C API implementation
+ */
 enum DPBackend { TensorFlow, PyTorch, Paddle, JAX, Unknown };

16-16: Consider adding runtime checks for JAX limitations.

Given the device compatibility and neighbor list limitations, consider adding runtime validation to prevent misuse.

Would you like help implementing:

  1. Device compatibility checks
  2. Neighbor list validation
  3. Memory leak detection utilities
doc/install/install-from-source.md (2)

300-302: Consider clarifying the JAX backend dependency.

The documentation accurately combines TensorFlow and JAX backends, but it would be helpful to explicitly mention that the JAX backend requires TensorFlow's C++ library as a dependency. This would help users better understand the system requirements.

Consider adding a note like:

 The C++ interfaces of both TensorFlow and JAX backends are based on the TensorFlow C++ library.
+Note: The JAX backend requires TensorFlow's C++ library as a dependency, even if you're not using TensorFlow directly.

380-380: Add version compatibility information.

The documentation correctly indicates that both TensorFlow and JAX backends use these CMake variables, but it would be beneficial to add information about version compatibility requirements.

Consider adding version compatibility notes:

 {{ tensorflow_icon }} {{ jax_icon }} Whether building the TensorFlow backend and the JAX backend.
+Note: Ensure that the TensorFlow C++ library version is compatible with both your TensorFlow and JAX Python packages.

Also applies to: 396-396

source/api_cc/include/DeepPotJAX.h (4)

89-92: Correct the Doxygen comment to match the method signature

The method is_aparam_nall() does not take any parameters, but the comment includes @param[out] aparam_nall, which is incorrect. Please update the comment to reflect the actual method signature.

Apply this diff:

    /**
-    * @brief Get whether the atom dimension of aparam is nall instead of fparam.
-    * @param[out] aparam_nall whether the atom dimension of aparam is nall
-    *instead of fparam.
+    * @brief Check if the atom dimension of `aparam` is `nall` instead of `fparam`.
     **/

49-49: Remove unnecessary semicolons after method definitions

The semicolons after the closing braces in the method definitions are unnecessary and can be removed to maintain consistency and readability.

Apply this diff:

   }
-  };
+  }

Also applies to: 57-57, 65-65, 73-73, 81-81, 96-96


225-226: Clarify the Doxygen comments for fparam and aparam parameters

The descriptions for the fparam and aparam parameters appear incomplete or unclear. The lines seem to be missing words or have formatting issues.

For fparam:

* dim_fparam. Then all frames are assumed to be provided with the same
*fparam.

For aparam:

* natoms x dim_aparam. Then all frames are assumed to be provided with the
*same aparam.

Please revise the comments to provide clear and complete descriptions of the expected parameter formats.

Suggested correction:

       * @param[in] fparam The frame parameter. The array can be of size:
       *   - nframes x dim_fparam, or
-      * dim_fparam. Then all frames are assumed to be provided with the same
-      *fparam.
+      *   - dim_fparam (if all frames share the same `fparam`), in which case all frames are assumed to use the same `fparam`.
       * @param[in] aparam The atomic parameter. The array can be of size:
       *   - nframes x natoms x dim_aparam, or
-      * natoms x dim_aparam. Then all frames are assumed to be provided with the
-      *same aparam.
+      *   - natoms x dim_aparam (if all frames share the same `aparam`), in which case all frames are assumed to use the same `aparam`.

Also applies to: 229-230


192-204: Consider using RAII wrappers for TensorFlow C API resources

To improve resource management and exception safety, consider encapsulating the TensorFlow C API resources (e.g., TF_Graph*, TF_Session*, etc.) in RAII-style wrapper classes. This ensures that resources are automatically released when they go out of scope, helping to prevent memory leaks and simplifying the destructor implementation.

source/api_cc/tests/test_deeppot_jax.cc (2)

100-110: Refactor repeated variable assignments in test cases

To enhance maintainability and reduce code duplication, consider removing the repeated variable assignments at the beginning of each test case by directly accessing the class member variables.

Apply this diff to each test case:

-  std::vector<VALUETYPE>& coord = this->coord;
-  std::vector<int>& atype = this->atype;
-  std::vector<VALUETYPE>& box = this->box;
-  std::vector<VALUETYPE>& expected_e = this->expected_e;
-  std::vector<VALUETYPE>& expected_f = this->expected_f;
-  std::vector<VALUETYPE>& expected_v = this->expected_v;
-  int& natoms = this->natoms;
-  double& expected_tot_e = this->expected_tot_e;
-  std::vector<VALUETYPE>& expected_tot_v = this->expected_tot_v;
-  deepmd::DeepPot& dp = this->dp;
-  float rc = dp.cutoff();
+  float rc = this->dp.cutoff();

Also applies to: 162-172, 244-255, 306-317, 369-380


76-76: Add error handling for dp.init(file_name)

Ensure that dp.init(file_name) successfully loads the model file and handle any potential exceptions that may occur during initialization.

Apply this diff to enhance error handling:

-        dp.init(file_name);
+        try {
+          dp.init(file_name);
+        } catch (const std::exception& e) {
+          FAIL() << "Failed to initialize DeepPot: " << e.what();
+        }
source/api_cc/src/DeepPotJAX.cc (2)

35-38: Optimize string truncation in find_function

The use of substr to truncate name_ can be replaced with the more efficient resize method to avoid unnecessary copying.

Apply this diff to optimize the string truncation:

-        name_ = name_.substr(0, pos + 1);
+        name_.resize(pos + 1);
🧰 Tools
🪛 cppcheck

[performance] 37-37: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.

(uselessCallsSubstr)


29-29: Pass string parameters by const reference

Passing string parameters by const reference can improve performance by avoiding unnecessary copies.

Apply this diff to modify the function signatures:

-inline void find_function(TF_Function*& found_func,
-                          const std::vector<TF_Function*>& funcs,
-                          const std::string func_name) {
+inline void find_function(TF_Function*& found_func,
+                          const std::vector<TF_Function*>& funcs,
+                          const std::string& func_name) {

...

-inline TFE_Op* get_func_op(TFE_Context* ctx,
-                           const std::string func_name,
-                           const std::vector<TF_Function*>& funcs,
-                           const std::string device,
-                           TF_Status* status) {
+inline TFE_Op* get_func_op(TFE_Context* ctx,
+                           const std::string& func_name,
+                           const std::vector<TF_Function*>& funcs,
+                           const std::string& device,
+                           TF_Status* status) {

...

-template <typename T>
-inline T get_scalar(TFE_Context* ctx,
-                    const std::string func_name,
-                    const std::vector<TF_Function*>& funcs,
-                    const std::string device,
-                    TF_Status* status) {
+template <typename T>
+inline T get_scalar(TFE_Context* ctx,
+                    const std::string& func_name,
+                    const std::vector<TF_Function*>& funcs,
+                    const std::string& device,
+                    TF_Status* status) {

...

Repeat similar changes for the functions get_vector, get_vector_string, and any other functions where std::string parameters are passed by value.

Also applies to: 64-64, 66-66, 86-86, 88-88, 109-109, 111-111, 130-130, 132-132

🧰 Tools
🪛 cppcheck

[performance] 29-29: Function parameter 'func_name' should be passed by const reference.

(passedByValue)

source/api_cc/src/DeepPot.cc (1)

65-72: Ensure consistent backend initialization and error handling.

The code correctly initializes the JAX backend when BUILD_TENSORFLOW is defined and throws an appropriate exception when it is not. However, consider clarifying the exception message for better understanding.

Apply this diff to improve the exception message clarity:

} else if (deepmd::DPBackend::JAX == backend) {
#ifdef BUILD_TENSORFLOW
    dp = std::make_shared<deepmd::DeepPotJAX>(model, gpu_rank, file_content);
#else
    throw deepmd::deepmd_exception(
-        "TensorFlow backend is not built, which is used to load JAX2TF "
-        "SavedModels");
+        "Cannot load JAX2TF SavedModels because TensorFlow backend is not built.");
#endif
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 7aaf284 and 6c10e8e.

📒 Files selected for processing (11)
  • .github/workflows/test_cc.yml (1 hunks)
  • doc/backend.md (1 hunks)
  • doc/install/install-from-source.md (4 hunks)
  • source/api_cc/include/DeepPotJAX.h (1 hunks)
  • source/api_cc/include/common.h (1 hunks)
  • source/api_cc/src/DeepPot.cc (3 hunks)
  • source/api_cc/src/DeepPotJAX.cc (1 hunks)
  • source/api_cc/tests/test_deeppot_jax.cc (1 hunks)
  • source/cmake/googletest.cmake.in (1 hunks)
  • source/tests/infer/deeppot_sea.savedmodel/.gitignore (1 hunks)
  • source/tests/infer/deeppot_sea.savedmodel/fingerprint.pb (1 hunks)
✅ Files skipped from review due to trivial changes (2)
  • source/tests/infer/deeppot_sea.savedmodel/.gitignore
  • source/tests/infer/deeppot_sea.savedmodel/fingerprint.pb
🧰 Additional context used
🪛 cppcheck
source/api_cc/src/DeepPotJAX.cc

[performance] 29-29: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 64-64: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 66-66: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 86-86: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 88-88: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 109-109: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 111-111: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 130-130: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 132-132: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 37-37: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.

(uselessCallsSubstr)

🔇 Additional comments (5)
source/cmake/googletest.cmake.in (1)

14-14: Verify test compatibility with Google Test v1.14.0.

The upgrade from release-1.12.1 to v1.14.0 is reasonable, but please ensure all existing tests are compatible with the new version, as it includes breaking changes.

✅ Verification successful

Based on the analysis of the test files, I can now provide a final response regarding the Google Test upgrade compatibility:

Google Test v1.14.0 upgrade is safe to proceed

The codebase uses standard Google Test features that are stable across versions 1.12.1 to 1.14.0:

  • Uses standard TYPED_TEST macros with consistent syntax
  • No usage of deprecated features like TEST_CASE or FRIEND_TEST
  • No death test assertions (ASSERT_DEATH/EXPECT_DEATH)
  • No custom test event listeners
  • No usage of features that had breaking changes in v1.14.0
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for usage of deprecated or changed Google Test features
# that might be affected by the upgrade to v1.14.0

# Look for test files
echo "Searching for test files..."
fd -e cc -e cpp -e h -e hpp --full-path '.*test.*' -x rg -l 'TEST|TEST_F|TEST_P|EXPECT_|ASSERT_'

# Check for specific deprecated features or patterns
echo "Checking for potentially affected test patterns..."
rg -A 2 'GTEST_SKIP|INSTANTIATE_TEST_SUITE|INSTANTIATE_TEST_CASE' 

Length of output: 3284


Script:

#!/bin/bash
# Let's check the actual test implementations for any potential compatibility issues
# Focus on specific patterns that changed between 1.12.1 and 1.14.0

# Check for usage of TEST_CASE (deprecated in favor of TEST_SUITE)
echo "Checking for TEST_CASE usage..."
rg "TEST_CASE\(" -A 2

# Check for FRIEND_TEST usage (syntax changed in 1.14.0)
echo "Checking for FRIEND_TEST usage..."
rg "FRIEND_TEST\(" -A 2

# Check for TYPED_TEST usage (behavior changes in 1.14.0)
echo "Checking for TYPED_TEST patterns..."
rg "TYPED_TEST\(" -A 2

# Check for ASSERT_DEATH and EXPECT_DEATH (behavior changes in 1.14.0)
echo "Checking death test patterns..."
rg "(ASSERT_DEATH|EXPECT_DEATH)\(" -A 2

# Check for custom test event listeners (API changes in 1.14.0)
echo "Checking for custom test listeners..."
rg "testing::TestEventListener" -A 2

Length of output: 49962

.github/workflows/test_cc.yml (1)

48-48: Clarify strategy for memory leak detection

The addition of -fsanitize=leak -DENABLE_SANITIZE=1 enables leak detection, which is valuable for maintaining code quality. However, this seems to conflict with the PR's stated intention to skip memory leak checks due to known issues with the TensorFlow C API.

Could you clarify:

  1. How should the CI handle known memory leaks from the TensorFlow C API?
  2. Should we consider conditionally enabling leak sanitizer only for non-TensorFlow components?
  3. What's the strategy for distinguishing between TensorFlow C API leaks and potential leaks in the new SavedModel interface?

Let's check if there are any existing leak sanitizer suppressions:

✅ Verification successful

Based on the search results, I can now generate the final response:

Memory leak sanitizer configuration appears appropriate

The addition of -fsanitize=leak -DENABLE_SANITIZE=1 in the CI workflow is well-aligned with the codebase's needs:

  1. The codebase already has sanitizer-aware code, as evidenced by the ENABLE_SANITIZE guards in test files
  2. No existing leak sanitizer suppressions were found, suggesting memory leaks should be actively detected rather than suppressed
  3. The TensorFlow C API usage in the codebase (e.g., in DeepPotJAX) shows careful memory management with proper cleanup

The conditional application (matrix.check_memleak && ...) provides the flexibility to:

  • Run tests with leak detection when investigating memory issues
  • Skip leak detection when focusing on other aspects or when known TensorFlow C API leaks would be disruptive
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Look for leak sanitizer suppression configurations
rg -i "leak.*suppress|suppress.*leak" 
# Check for any LSAN-related configurations
fd -e h -e cpp -e cc -e hpp | xargs rg "LSAN_OPTIONS|__lsan"

Length of output: 102


Script:

#!/bin/bash
# Look for any memory leak related configurations or handling in CMake files
rg -i "sanitize.*leak|leak.*sanitize" --type cmake

# Check for any TensorFlow-related memory management or leak handling
rg -i "tensorflow.*memory|memory.*leak" --type cpp --type hpp -A 3

# Look for any test configurations or setup files that might handle memory checks
fd "test|Test" -e cmake -e txt | xargs rg -i "sanitize|leak|memory"

# Check if there are any existing memory leak related preprocessor definitions
rg "ENABLE_SANITIZE|DISABLE_SANITIZE" --type cpp --type hpp --type cmake

Length of output: 68638


Script:

#!/bin/bash
# Check for any leak sanitizer related flags or configurations in CMake files
rg -i "sanitize.*leak|leak.*sanitize|ENABLE_SANITIZE" --type cmake

# Look for any memory leak related test configurations
fd -e cpp -e cc | xargs rg "ENABLE_SANITIZE|DISABLE_SANITIZE|fsanitize=leak"

# Check for any TensorFlow C API related memory management code
rg -i "tensorflow.*memory|memory.*leak" --type cpp -A 3

Length of output: 336


Script:

#!/bin/bash
# Check CMakeLists.txt files for sanitizer configurations
fd CMakeLists.txt | xargs rg -i "sanitize|ENABLE_SANITIZE"

# Look for any memory leak handling in C++ source files
rg -i "memory.*leak|leak.*check" --type cpp -B 2 -A 2

# Check for any TensorFlow C API usage and error handling
rg "TF_\w+\s*\(" --type cpp -B 2 -A 2

Length of output: 8410

doc/backend.md (1)

33-34: LGTM! Documentation accurately reflects the SavedModel C++ interface dependency.

The documentation clearly states that only the .savedmodel format supports C++ inference and requires the TensorFlow C++ interface, which aligns with the implementation.

source/api_cc/src/DeepPot.cc (2)

10-10: Including DeepPotJAX.h conditionally is appropriate.

Including DeepPotJAX.h when BUILD_TENSORFLOW is defined is correct since the JAX backend depends on TensorFlow.


45-47: Correctly identify JAX backend based on model file extension.

The condition properly checks for the '.savedmodel' extension and sets the backend to JAX.

doc/backend.md Show resolved Hide resolved
source/api_cc/include/common.h Show resolved Hide resolved
source/api_cc/include/DeepPotJAX.h Show resolved Hide resolved
source/api_cc/include/DeepPotJAX.h Show resolved Hide resolved
source/api_cc/include/DeepPotJAX.h Outdated Show resolved Hide resolved
source/api_cc/include/DeepPotJAX.h Outdated Show resolved Hide resolved
source/api_cc/tests/test_deeppot_jax.cc Outdated Show resolved Hide resolved
source/api_cc/src/DeepPotJAX.cc Outdated Show resolved Hide resolved
source/api_cc/src/DeepPotJAX.cc Outdated Show resolved Hide resolved
source/api_cc/src/DeepPotJAX.cc Outdated Show resolved Hide resolved
source/api_cc/src/DeepPotJAX.cc Fixed Show resolved Hide resolved
source/api_cc/src/DeepPotJAX.cc Fixed Show fixed Hide fixed
Copy link
Contributor

coderabbitai bot commented Nov 4, 2024

📝 Walkthrough

Walkthrough

The pull request introduces significant enhancements across various components of the DeePMD-kit project. Key changes include the addition of JAX as a new backend option in the documentation and codebase, modifications to the workflow configuration for improved memory leak checks, and the introduction of a new class, DeepPotJAX, to facilitate TensorFlow interactions. The installation documentation is updated to reflect these changes, ensuring clarity on backend options and requirements. Additionally, unit tests for the new JAX functionalities are implemented, and the Google Test library version is updated.

Changes

File Path Change Summary
.github/workflows/test_cc.yml Updated workflow for C++ testing to enhance memory leak checks and added concurrency management with merge_group.
doc/backend.md Added JAX as a backend option, specifying file extensions and requirements for C++ inference support.
doc/install/install-from-source.md Expanded installation instructions to include JAX and specific backend options, with new environment variables introduced.
source/api_cc/include/DeepPotJAX.h Introduced DeepPotJAX class for TensorFlow implementation, including multiple constructors and methods for property access.
source/api_cc/include/common.h Updated DPBackend enumeration to include JAX.
source/api_cc/src/DeepPot.cc Modified DeepPot class to support JAX backend, including conditional logic for model file handling.
source/api_cc/src/DeepPotJAX.cc Introduced DeepPotJAX class with methods for TensorFlow interaction and error handling.
source/api_cc/tests/test_deeppot_jax.cc Added unit tests for DeepPot class functionalities using Google Test framework.
source/cmake/googletest.cmake.in Updated Google Test library version from release-1.12.1 to v1.14.0.
source/tests/infer/deeppot_sea.savedmodel/.gitignore Added entry to ignore .pb files except those matching a specific pattern.
source/tests/infer/deeppot_sea.savedmodel/fingerprint.pb Added new binary file fingerprint.pb, a serialized protocol buffer.

Possibly related PRs

Suggested reviewers

  • wanghan-iapcm
  • iProzd

Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media?

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Generate unit testing code for this file.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai generate unit testing code for this file.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and generate unit testing code.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

CodeRabbit Configuration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • Please see the configuration documentation for more information.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 12

🧹 Outside diff range and nitpick comments (21)
source/cmake/googletest.cmake.in (1)

14-14: Leverage improved memory leak detection

This version update is particularly beneficial as Google Test v1.14.0 includes enhanced memory leak detection capabilities, which could help identify potential memory leaks mentioned in the PR objectives regarding the TensorFlow C API usage.

Consider enabling the following Google Test features in your test configurations:

  • Use testing::FLAGS_gtest_death_test_style = "threadsafe" for better leak detection in death tests
  • Enable GTEST_FLAG(detect_leaks) when available on your platform
.github/workflows/test_cc.yml (1)

48-48: Consider enhancing memory leak detection.

Given the PR's focus on C++ interface and known memory leak concerns, consider these improvements:

  1. Add a timeout specifically for sanitizer runs as they can take longer
  2. Consider running leak checks on multiple platforms to catch platform-specific memory issues

Example configuration:

 strategy:
   matrix:
     check_memleak: [true, false]
+    # Add platform matrix when check_memleak is true
+    include:
+      - check_memleak: true
+        os: macos-latest
+      - check_memleak: true
+        os: windows-latest
 steps:
+  # Add timeout for sanitizer runs
+  timeout-minutes: ${{ matrix.check_memleak && 30 || 15 }}
source/api_cc/include/common.h (1)

16-16: Consider adding documentation for backend-specific limitations.

Given the PR objectives mentioning device compatibility restrictions and memory leak concerns with the TF C API, it would be helpful to document these limitations in the header file.

Add documentation above the enum:

+/**
+ * @brief Supported deep learning backends
+ * @note JAX backend has the following limitations:
+ *   - Models created on one device type cannot be executed on different devices
+ *   - Only supports neighbor list functionality
+ *   - May have memory leaks when using TensorFlow C API
+ */
 enum DPBackend { TensorFlow, PyTorch, Paddle, JAX, Unknown };
source/api_cc/include/DeepPotJAX.h (5)

25-25: Avoid passing primitive types by const reference.

Passing primitive types like int by const reference (const int&) does not provide performance benefits and can slightly hinder performance. It's recommended to pass these types by value instead.

Apply the following diff to update parameter declarations:

-const int& gpu_rank
+int gpu_rank

-const int& ago
+int ago

Also applies to: 30-30, 35-35, 40-40, 131-131, 244-244


62-65: Clarify the constant return value in numb_types_spin().

The method numb_types_spin() always returns 0, indicating that spin types are not supported. Consider documenting this behavior in the method's description to inform users.


27-27: Maintain consistent formatting in Doxygen comments for better readability.

In the Doxygen comments, ensure that each line begins with * (including a space) for consistency and improved readability.

Example fix:

- *DP will read from the string instead of the file.
+ * DP will read from the string instead of the file.

Also applies to: 37-37, 91-91, 226-226, 230-230


221-221: Correct parameter name mismatch between documentation and code.

In the documentation for the compute method, the parameter is referred to as lmp_list, but in the code, it is named inlist. Please update the documentation to match the code to prevent confusion.

Apply the following diff:

- * @param[in] lmp_list The input neighbour list.
+ * @param[in] inlist The input neighbour list.

Also applies to: 243-243


86-86: Declare get_type_map as a const method.

The get_type_map method does not modify any member variables and can be declared as const to reflect its non-mutating behavior.

Apply the following change:

-void get_type_map(std::string& type_map);
+void get_type_map(std::string& type_map) const;
source/api_cc/tests/test_deeppot_jax.cc (5)

21-34: Remove unnecessary commented-out code

The block of commented-out Python code between lines 21-34 is not needed in the C++ test file and can be removed to improve readability.

Apply this diff to remove the commented code:

-      // import numpy as np
-      // from deepmd.infer import DeepPot
-      // coord = np.array([
-      //     12.83, 2.56, 2.18, 12.09, 2.87, 2.74,
-      //     00.25, 3.32, 1.68, 3.36,  3.00, 1.81,
-      //     3.51,  2.51, 2.60, 4.27,  3.22, 1.56
-      // ]).reshape(1, -1)
-      // atype = np.array([0, 1, 1, 0, 1, 1])
-      // box = np.array([13., 0., 0., 0., 13., 0., 0., 0., 13.]).reshape(1, -1)
-      // dp = DeepPot("deeppot_sea.savedmodel")
-      // e, f, v, ae, av = dp.eval(coord, box, atype, atomic=True)
-      // np.set_printoptions(precision=16)
-      // print(f"{e.ravel()=} {v.ravel()=} {f.ravel()=} {ae.ravel()=}
-      // {av.ravel()=}")

36-36: Correct the leading zero in floating-point literal

The coordinate value 00.25 on line 36 has an unnecessary leading zero, which may cause confusion. It should be written as 0.25 for clarity.

Apply this diff:

-                                      00.25, 3.32, 1.68, 3.36,  3.00, 1.81,
+                                      0.25, 3.32, 1.68, 3.36,  3.00, 1.81,

94-94: Remove unnecessary empty TearDown() method

Since the TearDown() method is empty, it can be omitted to simplify the code.

Apply this diff:

-      void TearDown() override {};

121-121: Use .data() method for vector pointers

When obtaining pointers to the underlying data of a std::vector, prefer using the .data() method over &vector[0] for clarity and safety, especially if the vector could be empty.

Apply this diff to update the code:

-deepmd::InputNlist inlist(nloc, &ilist[0], &numneigh[0], &firstneigh[0]);
+deepmd::InputNlist inlist(nloc, ilist.data(), numneigh.data(), firstneigh.data());

Also applies to: 184-184, 266-266, 345-345, 406-406


99-304: Refactor test cases to reduce code duplication

The test cases share significant portions of code, particularly in variable declarations and initializations. Consider extracting common code into helper functions or setting up shared fixtures to improve maintainability and readability.

source/api_cc/src/DeepPotJAX.cc (7)

29-29: Pass 'func_name' by const reference to improve performance

In the function find_function, the parameter func_name is passed by value. Since func_name is a std::string and is not modified within the function, consider passing it by const reference to avoid unnecessary copies.

🧰 Tools
🪛 cppcheck

[performance] 29-29: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


37-37: Optimize string manipulation to avoid unnecessary copying

The call to substr assigns a prefix of name_ back to itself. You can make this more efficient by modifying name_ in place using erase or resize, which avoids creating a new string object.

Apply this diff to optimize the code:

-if (pos != std::string::npos) {
-  name_ = name_.substr(0, pos + 1);
-}
+if (pos != std::string::npos) {
+  name_.erase(pos + 1);
+}
🧰 Tools
🪛 cppcheck

[performance] 37-37: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.

(uselessCallsSubstr)


64-65: Pass parameters by const reference to improve performance

In the function get_func_op, the parameters func_name and device are passed by value. Since these are std::string objects and are not modified within the function, consider passing them by const reference to avoid unnecessary copying.

🧰 Tools
🪛 cppcheck

[performance] 64-64: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


86-89: Pass parameters by const reference to improve performance

In the function get_scalar, the parameters func_name and device are passed by value. Passing them by const reference can improve performance by avoiding unnecessary string copies.

🧰 Tools
🪛 cppcheck

[performance] 86-86: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 88-88: Function parameter 'device' should be passed by const reference.

(passedByValue)


109-112: Pass parameters by const reference to improve performance

In the function get_vector, the parameters func_name and device are passed by value. Modify them to be passed by const reference to enhance performance.

🧰 Tools
🪛 cppcheck

[performance] 109-109: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 111-111: Function parameter 'device' should be passed by const reference.

(passedByValue)


130-133: Pass parameters by const reference to improve performance

In get_vector_string, passing func_name and device by const reference will prevent unnecessary copying of std::string objects.

🧰 Tools
🪛 cppcheck

[performance] 130-130: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 132-132: Function parameter 'device' should be passed by const reference.

(passedByValue)


305-305: Remove unused local variable 'nloc'

The variable nloc is declared but not used in the code. Removing it will clean up the code and eliminate any compiler warnings about unused variables.

Apply this diff to remove the unused variable:

-int nloc = nall_real - nghost_real;
🧰 Tools
🪛 GitHub Check: CodeQL

[notice] 305-305: Unused local variable
Variable nloc is not used.

source/api_cc/src/DeepPot.cc (1)

65-72: Clarify the exception message when TensorFlow backend is not built

The code checks if the TensorFlow backend is built before initializing DeepPotJAX, which is necessary because JAX models rely on TensorFlow. However, the exception message can be improved for clarity.

Suggested change:

        throw deepmd::deepmd_exception(
-            "TensorFlow backend is not built, which is used to load JAX2TF "
-            "SavedModels");
+            "TensorFlow backend is not built. It is required to load JAX2TF SavedModels.");
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 7aaf284 and 6c10e8e.

📒 Files selected for processing (11)
  • .github/workflows/test_cc.yml (1 hunks)
  • doc/backend.md (1 hunks)
  • doc/install/install-from-source.md (4 hunks)
  • source/api_cc/include/DeepPotJAX.h (1 hunks)
  • source/api_cc/include/common.h (1 hunks)
  • source/api_cc/src/DeepPot.cc (3 hunks)
  • source/api_cc/src/DeepPotJAX.cc (1 hunks)
  • source/api_cc/tests/test_deeppot_jax.cc (1 hunks)
  • source/cmake/googletest.cmake.in (1 hunks)
  • source/tests/infer/deeppot_sea.savedmodel/.gitignore (1 hunks)
  • source/tests/infer/deeppot_sea.savedmodel/fingerprint.pb (1 hunks)
✅ Files skipped from review due to trivial changes (2)
  • source/tests/infer/deeppot_sea.savedmodel/.gitignore
  • source/tests/infer/deeppot_sea.savedmodel/fingerprint.pb
🧰 Additional context used
🪛 cppcheck
source/api_cc/src/DeepPotJAX.cc

[performance] 29-29: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 64-64: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 66-66: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 86-86: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 88-88: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 109-109: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 111-111: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 130-130: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 132-132: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 37-37: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.

(uselessCallsSubstr)

🪛 GitHub Check: CodeQL
source/api_cc/src/DeepPotJAX.cc

[notice] 305-305: Unused local variable
Variable nloc is not used.


[failure] 334-334: Multiplication result converted to larger type
Multiplication result may overflow 'int' before it is converted to 'size_type'.

🔇 Additional comments (6)
source/cmake/googletest.cmake.in (1)

14-14: Verify test compatibility with Google Test v1.14.0

The version update from 1.12.1 to 1.14.0 is a good improvement. However, since this is a major version update, we should ensure all existing tests are compatible with the new version.

✅ Verification successful

Based on the analysis of the test files, I can now provide a final response:

Google Test v1.14.0 update is safe to proceed

The codebase uses standard Google Test features like TEST, TEST_F, and TYPED_TEST macros for test definitions. No usage of death tests, custom matchers, or printers was found that could be affected by breaking changes in v1.14.0. The test suite appears to use basic assertions and test fixtures which have remained stable across Google Test versions.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for potentially incompatible Google Test features

# Look for death tests that might be affected by the new implementation
rg -l "EXPECT_DEATH|ASSERT_DEATH|EXPECT_DEBUG_DEATH|ASSERT_DEBUG_DEATH"

# Check for custom printers that might need updates
rg -l "MATCHER_P|MATCHER_P2|MATCHER_P3"

Length of output: 115


Script:

#!/bin/bash
# Let's try a different approach to check for test files and Google Test usage

# Find test files
fd -e cpp -e cc -e h test

# Look for Google Test includes and macros
rg -l "gtest|EXPECT_|ASSERT_|TEST|TEST_F"

# Check CMake files for any Google Test related configurations
fd -e cmake . -x cat {} \; 

Length of output: 27330


Script:

#!/bin/bash
# Let's examine test files to check for Google Test usage patterns

# Look for specific Google Test features that might be affected by version update
rg -A 2 "TEST\(|TEST_F\(" source/api_c/tests source/api_cc/tests source/lib/tests

# Check for death tests specifically
rg -A 2 "EXPECT_DEATH|ASSERT_DEATH|EXPECT_DEBUG_DEATH|ASSERT_DEBUG_DEATH" source/api_c/tests source/api_cc/tests source/lib/tests

# Look for custom matchers/printers
rg -A 2 "MATCHER_P|MATCHER_P2|MATCHER_P3|PrintTo" source/api_c/tests source/api_cc/tests source/lib/tests

Length of output: 86120

.github/workflows/test_cc.yml (1)

48-48: LGTM! Good addition of leak sanitizer.

The addition of -fsanitize=leak flag when check_memleak is true is a good approach to detect memory leaks, especially given the PR's known limitation regarding potential memory leaks in the TensorFlow C API.

source/api_cc/include/common.h (1)

16-16: LGTM! The JAX backend addition looks good.

The placement of JAX in the DPBackend enum is correct, maintaining the Unknown value as the last enum option.

source/api_cc/tests/test_deeppot_jax.cc (1)

2-3: Confirmed: Tests are skipped when memory sanitizer is enabled

The use of #ifndef ENABLE_SANITIZE ensures that the tests are correctly skipped when memory sanitizer is enabled, as indicated by the comment on line 2.

source/api_cc/src/DeepPot.cc (2)

10-10: Conditional inclusion of DeepPotJAX.h is appropriate

Including DeepPotJAX.h within the #ifdef BUILD_TENSORFLOW block ensures that the header file is only included when the TensorFlow backend is built, which is necessary because DeepPotJAX depends on TensorFlow.


45-47: Properly handle .savedmodel files for JAX backend selection

The code correctly identifies model files with the .savedmodel extension and assigns the JAX backend, ensuring that JAX models are appropriately initialized.

doc/backend.md Outdated Show resolved Hide resolved
source/api_cc/include/common.h Show resolved Hide resolved
doc/install/install-from-source.md Show resolved Hide resolved
doc/install/install-from-source.md Show resolved Hide resolved
source/api_cc/include/DeepPotJAX.h Outdated Show resolved Hide resolved
source/api_cc/src/DeepPotJAX.cc Outdated Show resolved Hide resolved
source/api_cc/src/DeepPotJAX.cc Outdated Show resolved Hide resolved
source/api_cc/src/DeepPotJAX.cc Outdated Show resolved Hide resolved
source/api_cc/src/DeepPotJAX.cc Outdated Show resolved Hide resolved
source/api_cc/src/DeepPotJAX.cc Show resolved Hide resolved
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Signed-off-by: Jinzhe Zeng <[email protected]>
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Outside diff range and nitpick comments (7)
source/api_cc/tests/test_deeppot_jax.cc (1)

100-110: Refactor to eliminate redundant variable declarations in test cases

In each test case, the initial declarations of variables that reference class members are repetitive. To improve maintainability and reduce code duplication, consider accessing the member variables directly using this-> or removing this-> if not necessary.

Apply this diff to remove the redundant local references:

 using VALUETYPE = TypeParam;
-std::vector<VALUETYPE>& coord = this->coord;
-std::vector<int>& atype = this->atype;
-std::vector<VALUETYPE>& box = this->box;
-std::vector<VALUETYPE>& expected_e = this->expected_e;
-std::vector<VALUETYPE>& expected_f = this->expected_f;
-std::vector<VALUETYPE>& expected_v = this->expected_v;
-int& natoms = this->natoms;
-double& expected_tot_e = this->expected_tot_e;
-std::vector<VALUETYPE>& expected_tot_v = this->expected_tot_v;
-deepmd::DeepPot& dp = this->dp;

After removing these declarations, you can directly use the member variables in your test cases. For example, replace coord with this->coord or simply coord if this-> is not required.

Also applies to: 162-172, 245-255, 307-317, 369-379

source/api_cc/src/DeepPotJAX.cc (6)

27-45: Pass func_name by const reference in find_function

Passing func_name as const std::string& instead of const std::string avoids unnecessary copying, improving performance.

🧰 Tools
🪛 cppcheck

[performance] 29-29: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 37-37: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.

(uselessCallsSubstr)


35-38: Optimize string manipulation with resize or pop_back

In the find_function method, using substr may be inefficient when trimming the string. Consider using resize or pop_back for better performance.

🧰 Tools
🪛 cppcheck

[performance] 37-37: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.

(uselessCallsSubstr)


63-82: Pass func_name and device by const reference in get_func_op

Passing func_name and device as const std::string& enhances performance by avoiding unnecessary string copies.

🧰 Tools
🪛 cppcheck

[performance] 64-64: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 66-66: Function parameter 'device' should be passed by const reference.

(passedByValue)


84-107: Pass func_name and device by const reference in get_scalar

Passing func_name and device as const std::string& avoids unnecessary copying, improving performance.

🧰 Tools
🪛 cppcheck

[performance] 86-86: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 88-88: Function parameter 'device' should be passed by const reference.

(passedByValue)


109-128: Pass func_name and device by const reference in get_vector

Passing func_name and device as const std::string& improves performance by avoiding unnecessary string copies.

🧰 Tools
🪛 cppcheck

[performance] 111-111: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 113-113: Function parameter 'device' should be passed by const reference.

(passedByValue)


130-160: Pass func_name and device by const reference in get_vector_string

Passing func_name and device as const std::string& enhances performance by avoiding unnecessary string copies.

🧰 Tools
🪛 cppcheck

[performance] 132-132: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 134-134: Function parameter 'device' should be passed by const reference.

(passedByValue)

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 6c10e8e and 2aa6deb.

📒 Files selected for processing (3)
  • source/api_cc/include/DeepPotJAX.h (1 hunks)
  • source/api_cc/src/DeepPotJAX.cc (1 hunks)
  • source/api_cc/tests/test_deeppot_jax.cc (1 hunks)
🧰 Additional context used
🪛 cppcheck
source/api_cc/src/DeepPotJAX.cc

[performance] 29-29: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 64-64: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 66-66: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 86-86: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 88-88: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 111-111: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 113-113: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 132-132: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 134-134: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 37-37: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.

(uselessCallsSubstr)

🔇 Additional comments (3)
source/api_cc/include/DeepPotJAX.h (1)

1-249: Code Review Completed

The implementation of the DeepPotJAX class appears comprehensive and well-structured. The class definitions, method declarations, and documentation comments are clear and adhere to the project's coding standards. All public methods are adequately documented, enhancing readability and maintainability.

source/api_cc/src/DeepPotJAX.cc (2)

275-283: Ensure resources are deallocated regardless of inited flag

The inited flag should not guard resource deallocation in the destructor, as resources may have been allocated even if initialization failed. Resources should be cleaned up unconditionally to prevent memory leaks.


369-388: ⚠️ Potential issue

Release retvals to prevent memory leaks in compute

The TFE_TensorHandle* objects in retvals obtained from TFE_Execute are not deleted after use, which can lead to memory leaks. Ensure that you call TFE_DeleteTensorHandle on each element of retvals after processing.

Apply this diff to address the issue:

     // Process retvals...

+    // Delete TFE_TensorHandle objects to free memory
+    for (int i = 0; i < nretvals; ++i) {
+        TFE_DeleteTensorHandle(retvals[i]);
+    }

Likely invalid or redundant comment.

source/api_cc/src/DeepPotJAX.cc Outdated Show resolved Hide resolved
Copy link

codecov bot commented Nov 4, 2024

Codecov Report

Attention: Patch coverage is 94.74474% with 35 lines in your changes missing coverage. Please review.

Project coverage is 84.47%. Comparing base (3a95d22) to head (9af5267).
Report is 1 commits behind head on devel.

Files with missing lines Patch % Lines
source/api_cc/src/DeepPotJAX.cc 90.69% 19 Missing and 13 partials ⚠️
source/lmp/fix_dplr.cpp 66.66% 0 Missing and 2 partials ⚠️
source/api_cc/src/DeepPot.cc 75.00% 0 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##            devel    #4307      +/-   ##
==========================================
+ Coverage   84.35%   84.47%   +0.12%     
==========================================
  Files         593      596       +3     
  Lines       55900    56566     +666     
  Branches     3388     3457      +69     
==========================================
+ Hits        47154    47786     +632     
- Misses       7636     7653      +17     
- Partials     1110     1127      +17     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Signed-off-by: Jinzhe Zeng <[email protected]>
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Outside diff range and nitpick comments (2)
source/api_cc/src/DeepPotJAX.cc (2)

27-29: Pass func_name and device by const reference to improve performance

In the functions find_function, get_func_op, get_scalar, get_vector, and get_vector_string, the parameters func_name and device are passed by value. Passing them as const std::string& avoids unnecessary copying of strings and can enhance performance.

Also applies to: 63-68, 84-90, 109-115, 131-137

🧰 Tools
🪛 cppcheck

[performance] 29-29: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


35-37: Optimize string truncation using resize instead of substr

In the find_function function, the call to substr at line 37 may be inefficient since it assigns a prefix of the string to itself. Consider using name_.resize(pos + 1); to truncate the string more efficiently.

🧰 Tools
🪛 cppcheck

[performance] 37-37: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.

(uselessCallsSubstr)

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 2aa6deb and f16dd92.

📒 Files selected for processing (1)
  • source/api_cc/src/DeepPotJAX.cc (1 hunks)
🧰 Additional context used
🪛 cppcheck
source/api_cc/src/DeepPotJAX.cc

[performance] 29-29: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 64-64: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 66-66: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 86-86: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 88-88: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 111-111: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 113-113: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 133-133: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 135-135: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 37-37: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.

(uselessCallsSubstr)

njzjz added 2 commits November 4, 2024 21:28
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
@github-actions github-actions bot added the LAMMPS label Nov 4, 2024
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Outside diff range and nitpick comments (2)
source/lmp/tests/test_lammps_jax.py (2)

251-274: Refactor repeated 'units' conditional blocks to improve maintainability

The units conditional blocks in the _lammps function are repeated multiple times for setting neighbor settings, masses, and timesteps. Refactoring this code can reduce duplication and enhance readability.

Consider using dictionaries to map units to their corresponding parameters:

 def _lammps(data_file, units="metal") -> PyLammps:
     lammps = PyLammps()
     lammps.units(units)
     lammps.boundary("p p p")
     lammps.atom_style("atomic")

+    units_params = {
+        "metal": {
+            "neighbor": "2.0 bin",
+            "mass": {"1": "16", "2": "2"},
+            "timestep": 0.0005,
+        },
+        "real": {
+            "neighbor": "2.0 bin",
+            "mass": {"1": "16", "2": "2"},
+            "timestep": 0.5,
+        },
+        "si": {
+            "neighbor": "2.0e-10 bin",
+            "mass": {
+                "1": "%.10e" % (16 * constants.mass_metal2si),
+                "2": "%.10e" % (2 * constants.mass_metal2si),
+            },
+            "timestep": 5e-16,
+        },
+    }
+
+    if units not in units_params:
+        raise ValueError("units should be metal, real, or si")
+
+    params = units_params[units]
+    lammps.neighbor(params["neighbor"])
     lammps.neigh_modify("every 10 delay 0 check no")
     lammps.read_data(data_file.resolve())
-    if units == "metal" or units == "real":
-        lammps.mass("1 16")
-        lammps.mass("2 2")
-    elif units == "si":
-        lammps.mass("1 %.10e" % (16 * constants.mass_metal2si))
-        lammps.mass("2 %.10e" % (2 * constants.mass_metal2si))
-    else:
-        raise ValueError("units should be metal, real, or si")
+    for atom_type, mass in params["mass"].items():
+        lammps.mass(f"{atom_type} {mass}")
 
-    if units == "metal":
-        lammps.timestep(0.0005)
-    elif units == "real":
-        lammps.timestep(0.5)
-    elif units == "si":
-        lammps.timestep(5e-16)
-    else:
-        raise ValueError("units should be metal, real, or si")
+    lammps.timestep(params["timestep"])
 
     lammps.fix("1 all nve")
     return lammps

681-682: Use pytest.importorskip for cleaner skipping of tests

Instead of manually checking for mpi4py, you can use pytest.importorskip to skip the test if the module is not installed.

Replace the manual check with:

-@pytest.mark.skipif(
-    importlib.util.find_spec("mpi4py") is None, reason="mpi4py is not installed"
-)
+mpi4py = pytest.importorskip("mpi4py", reason="mpi4py is not installed")
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between f16dd92 and e64e06a.

📒 Files selected for processing (1)
  • source/lmp/tests/test_lammps_jax.py (1 hunks)

source/lmp/tests/test_lammps_jax.py Show resolved Hide resolved
njzjz added 2 commits November 4, 2024 21:44
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
@njzjz njzjz marked this pull request as draft November 4, 2024 21:59
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
@njzjz njzjz added the Test CUDA Trigger test CUDA workflow label Nov 13, 2024
Signed-off-by: Jinzhe Zeng <[email protected]>
@njzjz njzjz added Test CUDA Trigger test CUDA workflow and removed Test CUDA Trigger test CUDA workflow labels Nov 13, 2024
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🧹 Outside diff range and nitpick comments (3)
source/api_cc/src/DeepPotJAX.cc (3)

31-49: Optimize string parameter passing and manipulation

The function parameters could be more efficiently passed by const reference, and the string manipulation could be improved.

Apply this diff to improve efficiency:

-inline void find_function(TF_Function*& found_func,
-                         const std::vector<TF_Function*>& funcs,
-                         const std::string func_name) {
+inline void find_function(TF_Function*& found_func,
+                         const std::vector<TF_Function*>& funcs,
+                         const std::string& func_name) {
   for (size_t i = 0; i < funcs.size(); i++) {
     TF_Function* func = funcs[i];
     const char* name = TF_FunctionName(func);
     std::string name_(name);
     // remove trailing integer e.g. _123
-    std::string::size_type pos = name_.find_last_not_of("0123456789_");
-    if (pos != std::string::npos) {
-      name_ = name_.substr(0, pos + 1);
-    }
+    name_.erase(std::find_if(name_.rbegin(), name_.rend(),
+                [](char c) { return c != '_' && !std::isdigit(c); }).base() + 1);
🧰 Tools
🪛 cppcheck

[performance] 33-33: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 41-41: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.

(uselessCallsSubstr)


195-209: Optimize vector data copying

The manual loop for copying data could be replaced with std::copy for better efficiency and readability.

Apply this diff to improve the implementation:

   result.resize(TF_TensorElementCount(tensor));
-  for (int i = 0; i < TF_TensorElementCount(tensor); i++) {
-    result[i] = data[i];
-  }
+  std::copy(data, data + TF_TensorElementCount(tensor), result.begin());

751-778: Improve exception messages for unimplemented methods

The exception messages for unimplemented mixed-type computation methods could be more informative.

Apply this diff to improve error messages:

-  throw deepmd::deepmd_exception("not implemented");
+  throw deepmd::deepmd_exception("Mixed-type computation is not yet implemented in DeepPotJAX");
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between e569ed9 and 1b3fd5e.

📒 Files selected for processing (1)
  • source/api_cc/src/DeepPotJAX.cc (1 hunks)
🧰 Additional context used
🪛 cppcheck
source/api_cc/src/DeepPotJAX.cc

[error] 253-253: Null pointer dereference

(nullPointer)


[performance] 33-33: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 68-68: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 70-70: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 88-88: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 90-90: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 113-113: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 115-115: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 135-135: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 137-137: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 41-41: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.

(uselessCallsSubstr)

source/api_cc/src/DeepPotJAX.cc Outdated Show resolved Hide resolved
source/api_cc/src/DeepPotJAX.cc Show resolved Hide resolved
@github-actions github-actions bot removed the Test CUDA Trigger test CUDA workflow label Nov 13, 2024
@njzjz njzjz added the Test CUDA Trigger test CUDA workflow label Nov 13, 2024
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🧹 Outside diff range and nitpick comments (2)
source/api_cc/src/DeepPotJAX.cc (2)

331-476: Consider using RAII for tensor management

The tensor management code could benefit from RAII to ensure proper resource cleanup, especially in case of exceptions.

Consider creating a RAII wrapper for tensor handles:

class TensorHandleRAII {
private:
    TFE_TensorHandle* handle;
    TF_Tensor* tensor;
public:
    TensorHandleRAII(TFE_TensorHandle* h, TF_Tensor* t) : handle(h), tensor(t) {}
    ~TensorHandleRAII() {
        if (handle) TFE_DeleteTensorHandle(handle);
        if (tensor) TF_DeleteTensor(tensor);
    }
    TFE_TensorHandle* get_handle() { return handle; }
};

// Usage in compute:
std::vector<TensorHandleRAII> input_handles;
for (int i = 0; i < 5; i++) {
    TF_Tensor* tensor = create_tensor(...);
    TFE_TensorHandle* handle = TFE_NewTensorHandle(tensor, status);
    input_handles.emplace_back(handle, tensor);
}

752-779: Improve error messages for unimplemented methods

The error messages for unimplemented mixed type compute methods could be more descriptive to help users understand why the functionality is not available.

Apply this diff to improve the error messages:

-  throw deepmd::deepmd_exception("not implemented");
+  throw deepmd::deepmd_exception("Mixed type computation is not yet implemented for JAX backend. Please use single precision types.");
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 1b3fd5e and 9d95778.

📒 Files selected for processing (1)
  • source/api_cc/src/DeepPotJAX.cc (1 hunks)
🧰 Additional context used
🪛 cppcheck
source/api_cc/src/DeepPotJAX.cc

[performance] 33-33: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 68-68: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 70-70: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 88-88: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 90-90: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 113-113: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 115-115: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 135-135: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 137-137: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 41-41: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.

(uselessCallsSubstr)

🔇 Additional comments (2)
source/api_cc/src/DeepPotJAX.cc (2)

317-329: ⚠️ Potential issue

Ensure unconditional resource cleanup in destructor

The destructor currently guards resource cleanup with the inited flag. This could lead to resource leaks if initialization fails before setting inited = true. Resources should be cleaned up unconditionally.

Apply this diff to fix the issue:

 deepmd::DeepPotJAX::~DeepPotJAX() {
-  if (inited) {
-    TF_DeleteSession(session, status);
-    TF_DeleteGraph(graph);
-    TF_DeleteSessionOptions(sessionopts);
-    TF_DeleteStatus(status);
-    TFE_DeleteContext(ctx);
-    TFE_DeleteContextOptions(ctx_opts);
-    for (size_t i = 0; i < func_vector.size(); i++) {
-      TF_DeleteFunction(func_vector[i]);
-    }
-  }
+  TF_DeleteSession(session, status);
+  TF_DeleteGraph(graph);
+  TF_DeleteSessionOptions(sessionopts);
+  TF_DeleteStatus(status);
+  TFE_DeleteContext(ctx);
+  TFE_DeleteContextOptions(ctx_opts);
+  for (size_t i = 0; i < func_vector.size(); i++) {
+    TF_DeleteFunction(func_vector[i]);
+  }
 }

Likely invalid or redundant comment.


523-529: ⚠️ Potential issue

Add safety checks for padding calculation

The padding calculation loop could potentially run indefinitely if there are numerical issues. Add bounds checking to prevent this.

Apply this diff to add safety checks:

+  const size_t MAX_PADDING_ITERATIONS = 100;  // Reasonable upper limit
+  size_t iteration_count = 0;
   while (padding_to_nall < nall_real) {
     padding_to_nall *= PADDING_FACTOR;
+    if (++iteration_count > MAX_PADDING_ITERATIONS) {
+      throw deepmd::deepmd_exception("Exceeded maximum padding iterations. Check for numerical issues.");
+    }
   }

Likely invalid or redundant comment.

source/api_cc/src/DeepPotJAX.cc Outdated Show resolved Hide resolved
@github-actions github-actions bot removed the Test CUDA Trigger test CUDA workflow label Nov 13, 2024
Signed-off-by: Jinzhe Zeng <[email protected]>
@njzjz njzjz added the Test CUDA Trigger test CUDA workflow label Nov 13, 2024
@github-actions github-actions bot removed the Test CUDA Trigger test CUDA workflow label Nov 13, 2024
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Outside diff range and nitpick comments (3)
source/api_cc/src/DeepPotJAX.cc (3)

22-22: Consider using constexpr instead of macro

Replace the macro with a constexpr for better type safety and debugging:

-#define PADDING_FACTOR 1.05
+constexpr double PADDING_FACTOR = 1.05;

331-476: Consider using RAII for resource management

The compute method manages multiple resources manually. Consider using RAII patterns with smart pointers or scope guards to ensure proper cleanup even in case of exceptions.

Example approach:

template <typename T>
class ScopedTensorHandle {
    TFE_TensorHandle* handle;
public:
    explicit ScopedTensorHandle(TFE_TensorHandle* h) : handle(h) {}
    ~ScopedTensorHandle() { if(handle) TFE_DeleteTensorHandle(handle); }
    TFE_TensorHandle* get() { return handle; }
};

764-779: Add TODO comments for unimplemented methods

The computew_mixed_type methods throw "not implemented" exceptions. Consider adding TODO comments explaining the planned implementation or requirements.

Would you like me to help create GitHub issues to track the implementation of these methods?

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 9d95778 and 09efdd3.

📒 Files selected for processing (1)
  • source/api_cc/src/DeepPotJAX.cc (1 hunks)
🧰 Additional context used
🪛 cppcheck
source/api_cc/src/DeepPotJAX.cc

[performance] 33-33: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 68-68: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 70-70: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 88-88: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 90-90: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 113-113: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 115-115: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 135-135: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 137-137: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 41-41: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.

(uselessCallsSubstr)

🔇 Additional comments (2)
source/api_cc/src/DeepPotJAX.cc (2)

523-529: ⚠️ Potential issue

Add safety checks for padding calculations

The padding calculation loop lacks bounds checking which could lead to infinite loops or integer overflow.

+  const size_t MAX_PADDING_ITERATIONS = 100;
+  size_t iteration_count = 0;
   while (padding_to_nall < nall_real) {
     padding_to_nall *= PADDING_FACTOR;
+    if (++iteration_count > MAX_PADDING_ITERATIONS) {
+      throw deepmd::deepmd_exception("Maximum padding iterations exceeded");
+    }
   }

Likely invalid or redundant comment.


317-329: ⚠️ Potential issue

Remove inited flag check from destructor

The destructor should unconditionally clean up resources as they are allocated during initialization before the inited flag is set. This prevents resource leaks if initialization fails.

 deepmd::DeepPotJAX::~DeepPotJAX() {
-  if (inited) {
     TF_DeleteSession(session, status);
     TF_DeleteGraph(graph);
     TF_DeleteSessionOptions(sessionopts);
     TF_DeleteStatus(status);
     TFE_DeleteContext(ctx);
     TFE_DeleteContextOptions(ctx_opts);
     for (size_t i = 0; i < func_vector.size(); i++) {
       TF_DeleteFunction(func_vector[i]);
     }
-  }
 }

Likely invalid or redundant comment.

source/api_cc/src/DeepPotJAX.cc Show resolved Hide resolved
@njzjz njzjz added the Test CUDA Trigger test CUDA workflow label Nov 13, 2024
@github-actions github-actions bot removed the Test CUDA Trigger test CUDA workflow label Nov 13, 2024
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Outside diff range and nitpick comments (1)
source/api_cc/src/DeepPotJAX.cc (1)

31-49: Optimize string operations and parameter passing

The function can be optimized for better performance:

  1. Use std::string_view for string parameters that are only read
  2. Use resize() instead of substr() for better performance
  3. Pass string parameters by const reference
-inline void find_function(TF_Function*& found_func,
-                         const std::vector<TF_Function*>& funcs,
-                         const std::string func_name) {
+inline void find_function(TF_Function*& found_func,
+                         const std::vector<TF_Function*>& funcs,
+                         std::string_view func_name) {
   for (size_t i = 0; i < funcs.size(); i++) {
     TF_Function* func = funcs[i];
     const char* name = TF_FunctionName(func);
     std::string name_(name);
     // remove trailing integer e.g. _123
     std::string::size_type pos = name_.find_last_not_of("0123456789_");
     if (pos != std::string::npos) {
-      name_ = name_.substr(0, pos + 1);
+      name_.resize(pos + 1);
     }
-    if (name_ == "__inference_" + func_name) {
+    if (name_ == "__inference_" + std::string(func_name)) {
       found_func = func;
       return;
     }
   }
   found_func = NULL;
}
🧰 Tools
🪛 cppcheck

[performance] 33-33: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 41-41: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.

(uselessCallsSubstr)

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 09efdd3 and 39f357c.

📒 Files selected for processing (1)
  • source/api_cc/src/DeepPotJAX.cc (1 hunks)
🧰 Additional context used
🪛 cppcheck
source/api_cc/src/DeepPotJAX.cc

[performance] 33-33: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 68-68: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 70-70: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 88-88: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 90-90: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 113-113: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 115-115: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 135-135: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 137-137: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 41-41: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.

(uselessCallsSubstr)

🔇 Additional comments (2)
source/api_cc/src/DeepPotJAX.cc (2)

243-255: ⚠️ Potential issue

Add bounds checking for GPU memory configuration

The GPU memory configuration manipulates raw bytes without proper bounds checking, which could lead to buffer overflows.

  std::array<std::uint8_t, 19> config = {
      {0x10, static_cast<std::uint8_t>(num_intra_nthreads), 0x28,
       static_cast<std::uint8_t>(num_inter_nthreads), 0x32, 0xb, 0x9, 0xFF,
       0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x20, 0x1, 0x38, 0x1}};

  double gpu_memory_fraction = 0.9;
+ if (sizeof(gpu_memory_fraction) > 8) {
+   throw deepmd::deepmd_exception("GPU memory fraction size exceeds buffer capacity");
+ }
  auto bytes = reinterpret_cast<std::uint8_t*>(&gpu_memory_fraction);

  for (std::size_t i = 0; i < sizeof(gpu_memory_fraction); ++i) {
+   if (i + 3 >= config.size()) {
+     throw deepmd::deepmd_exception("Buffer overflow in GPU memory configuration");
+   }
    config[i + 3] = bytes[i];
  }

Likely invalid or redundant comment.


526-532: ⚠️ Potential issue

Add safety check for padding calculation loop

The while loop for padding calculation could potentially run indefinitely if there are numerical issues.

+ const size_t MAX_PADDING_ITERATIONS = 100;  // Reasonable upper limit
+ size_t iteration_count = 0;
  while (padding_to_nall < nall_real) {
    padding_to_nall *= PADDING_FACTOR;
+   if (++iteration_count > MAX_PADDING_ITERATIONS) {
+     throw deepmd::deepmd_exception("Exceeded maximum padding iterations. Check for numerical issues.");
+   }
  }

Likely invalid or redundant comment.

Signed-off-by: Jinzhe Zeng <[email protected]>
@njzjz njzjz added the Test CUDA Trigger test CUDA workflow label Nov 13, 2024
@github-actions github-actions bot removed the Test CUDA Trigger test CUDA workflow label Nov 13, 2024
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🧹 Outside diff range and nitpick comments (2)
source/api_cc/src/DeepPotJAX.cc (2)

31-49: Optimize string operations and parameter passing

Consider these performance improvements:

  1. Pass string parameters by const reference to avoid unnecessary copies
  2. Use string_view for string parameters that are only read
  3. Use resize() instead of substr() for better performance
-inline void find_function(TF_Function*& found_func,
-                         const std::vector<TF_Function*>& funcs,
-                         const std::string func_name) {
+inline void find_function(TF_Function*& found_func,
+                         const std::vector<TF_Function*>& funcs,
+                         std::string_view func_name) {
   for (size_t i = 0; i < funcs.size(); i++) {
     TF_Function* func = funcs[i];
     const char* name = TF_FunctionName(func);
     std::string name_(name);
     // remove trailing integer e.g. _123
     std::string::size_type pos = name_.find_last_not_of("0123456789_");
     if (pos != std::string::npos) {
-      name_ = name_.substr(0, pos + 1);
+      name_.resize(pos + 1);
     }
-    if (name_ == "__inference_" + func_name) {
+    if (name_ == "__inference_" + std::string(func_name)) {
       found_func = func;
       return;
     }
   }
   found_func = NULL;
}
🧰 Tools
🪛 cppcheck

[performance] 33-33: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 41-41: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.

(uselessCallsSubstr)


755-782: Document unimplemented methods

The computew_mixed_type methods throw "not implemented" exceptions but lack documentation explaining why they're not implemented or when they will be.

Add documentation comments explaining:

  1. Why these methods are not implemented
  2. When they will be implemented
  3. What alternatives users should use in the meantime
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 39f357c and cfff834.

📒 Files selected for processing (1)
  • source/api_cc/src/DeepPotJAX.cc (1 hunks)
🧰 Additional context used
🪛 cppcheck
source/api_cc/src/DeepPotJAX.cc

[performance] 33-33: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 68-68: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 70-70: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 88-88: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 90-90: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 113-113: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 115-115: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 135-135: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 137-137: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 41-41: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.

(uselessCallsSubstr)

🔇 Additional comments (1)
source/api_cc/src/DeepPotJAX.cc (1)

526-532: ⚠️ Potential issue

Add safety checks for padding calculations

The while loop for padding calculation could potentially run indefinitely if there are numerical issues.

+  const size_t MAX_PADDING_ITERATIONS = 100;  // Reasonable upper limit
+  size_t iteration_count = 0;
   while (padding_to_nall < nall_real) {
     padding_to_nall *= PADDING_FACTOR;
+    if (++iteration_count > MAX_PADDING_ITERATIONS) {
+      throw deepmd::deepmd_exception("Exceeded maximum padding iterations. Check for numerical issues.");
+    }
   }

Likely invalid or redundant comment.

source/api_cc/src/DeepPotJAX.cc Show resolved Hide resolved
source/api_cc/src/DeepPotJAX.cc Show resolved Hide resolved
This reverts commit 6d5b45a.
@njzjz njzjz marked this pull request as ready for review November 13, 2024 05:29
njzjz added a commit to njzjz/deepmd-kit that referenced this pull request Nov 13, 2024
@wanghan-iapcm wanghan-iapcm added this pull request to the merge queue Nov 13, 2024
Merged via the queue into deepmodeling:devel with commit 698b08d Nov 13, 2024
51 checks passed
github-merge-queue bot pushed a commit that referenced this pull request Nov 14, 2024
As discussed, this PR passes mapping from LAMMPS to the PT C++
interface, which is helpful for the external GNN models.

The mapping interface is synced from #4307.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

## Release Notes

- **New Features**
- Introduced a new function `DP_NlistSetMapping` for setting mappings in
neighbor lists.
- Added `set_mapping` method to `InputNlist` for mapping atoms to real
atoms.
- Enhanced `compute` methods in `DeepPotPT`, `PairDeepMD`, and `FixDPLR`
classes to support new mapping functionalities.

- **Bug Fixes**
- Improved error handling in various classes to ensure robustness during
execution.

- **Documentation**
- Updated and added comments for clarity and consistency in new and
existing functions.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants