
feat(jax): passing mapping from LAMMPS to DPA-2 #4316

Closed
wants to merge 29 commits into from

Conversation

@njzjz (Member) commented Nov 6, 2024

Based on #4307 (which should be merged first). The testing model is generated after #4315 is applied.

Summary by CodeRabbit

Release Notes

  • New Features

    • Added support for JAX as a backend option in DeePMD-kit, enhancing model compatibility.
    • Introduced a new DeepPotJAX class that runs JAX models through the TensorFlow C++ library.
  • Documentation

    • Updated backend documentation to include JAX specifics and installation instructions.
    • Added a section on limitations for JAX with LAMMPS in the relevant documentation.
  • Bug Fixes

    • Improved error handling and mapping logic in the FixDPLR and PairDeepMD classes.
  • Tests

    • Introduced comprehensive testing suites for LAMMPS integration with DeepMD, ensuring accuracy across various configurations.

njzjz and others added 29 commits November 4, 2024 03:50
Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
@njzjz njzjz requested a review from wanghan-iapcm November 6, 2024 00:17

coderabbitai bot commented Nov 6, 2024

Warning

Rate limit exceeded

@github-actions[bot] has exceeded the limit for the number of commits or files that can be reviewed per hour. Please wait 14 minutes and 20 seconds before requesting another review.

⌛ How to resolve this issue?

After the wait time has elapsed, a review can be triggered using the @coderabbitai review command as a PR comment. Alternatively, push new commits to this PR.

We recommend that you space out your commits to avoid hitting the rate limit.

🚦 How do rate limits work?

CodeRabbit enforces hourly rate limits for each developer per organization.

Our paid plans have higher rate limits than the trial, open-source and free plans. In all cases, we re-allow further reviews after a brief timeout.

Please see our FAQ for further information.

📥 Commits

Reviewing files that changed from the base of the PR and between dabedd2 and 58dcf2b.


Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media?

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Generate unit testing code for this file.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai generate unit testing code for this file.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and generate unit testing code.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai resolve to resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

CodeRabbit Configuration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • Please see the configuration documentation for more information.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.



coderabbitai bot commented Nov 6, 2024

📝 Walkthrough

Walkthrough

The pull request introduces updates across multiple files to add support for the JAX backend in the DeePMD-kit framework. Key changes include modifications to .pre-commit-config.yaml for pre-commit hooks, updated documentation on JAX integration, and new classes and methods in the C++ API that support JAX functionality. New test files ensure the reliability of these features, and existing files have been updated to accommodate the new backend option and improve overall functionality.

Changes

File path and change summary:

  • .pre-commit-config.yaml: Updated check-added-large-files hook with a new exclusion; defined multiple insert-license hooks for different file types; configured pylint to run with a specific entry command.
  • doc/backend.md: Added JAX as a new backend option with specific model filename extensions; updated requirements for the JAX version; clarified support for C++ inference.
  • doc/install/install-from-source.md: Enhanced installation instructions for both TensorFlow and JAX backends; clarified Python interface setup; updated environment variables and CMake configuration.
  • doc/model/dpa2.md: Added a section on limitations of the JAX backend with LAMMPS.
  • source/api_c/include/c_api.h: Incremented API version to 24; added the DP_NlistSetMapping function.
  • source/api_c/include/deepmd.hpp: Introduced a set_mapping method in the InputNlist structure.
  • source/api_c/src/c_api.cc: Implemented the DP_NlistSetMapping function.
  • source/api_cc/include/DeepPotJAX.h: Added the DeepPotJAX class with multiple constructors and methods for Deep Potential.
  • source/api_cc/include/common.h: Updated the DPBackend enum to include JAX.
  • source/api_cc/src/DeepPot.cc: Modified the DeepPot class to handle JAX backend initialization.
  • source/api_cc/src/DeepPotJAX.cc: Implemented TensorFlow integration for the DeepPotJAX class with error handling and tensor management functions.
  • source/api_cc/tests/test_deeppot_jax.cc: Introduced unit tests for the DeepPotJAX class using the Google Test framework.
  • source/cmake/googletest.cmake.in: Updated the Google Test library version from release-1.12.1 to v1.14.0.
  • source/lib/include/neighbor_list.h: Added a mapping member and set_mapping method to InputNlist.
  • source/lmp/fix_dplr.cpp: Enhanced the FixDPLR class for improved mapping and force calculations.
  • source/lmp/pair_deepmd.cpp: Updated the PairDeepMD class to handle atom mappings and communication.
  • source/lmp/tests/test_lammps_dpa_jax.py: Added comprehensive tests for the DPMD framework with LAMMPS.
  • source/lmp/tests/test_lammps_jax.py: Established tests for LAMMPS integration with the DeepMD potential.
  • source/tests/infer/deeppot_dpa.savedmodel/.gitignore: Updated to ignore .pb files except for specified patterns.
  • source/tests/infer/deeppot_dpa.savedmodel/fingerprint.pb: Added binary file fingerprint.pb.
  • source/tests/infer/deeppot_sea.savedmodel/.gitignore: Updated to ignore .pb files except for specified patterns.
  • source/tests/infer/deeppot_sea.savedmodel/fingerprint.pb: Added binary file fingerprint.pb.

Suggested labels

Python, LAMMPS, Docs

Suggested reviewers

  • wanghan-iapcm
  • iProzd



coderabbitai bot left a comment


Actionable comments posted: 13

🧹 Outside diff range and nitpick comments (39)
doc/model/dpa2.md (1)

21-27: Documentation looks good, with some suggestions for enhancement.

The new section clearly documents the JAX backend limitations and requirements. Consider enhancing it further by:

  1. Explaining why multiple MPI ranks are not supported
  2. Adding a link to LAMMPS MPI documentation
  3. Mentioning if multi-rank support is planned for future releases
doc/backend.md (2)

34-35: Critical information about JAX model limitations

The documentation clearly states important limitations that users need to be aware of:

  1. Only .savedmodel format supports C++ inference
  2. Models are device-specific (GPU models won't run on CPU)

These are crucial pieces of information that help prevent runtime issues.

Consider adding examples of common error messages users might encounter when:

  1. Trying to use non-.savedmodel formats with C++ inference
  2. Attempting to run GPU models on CPU

This would make troubleshooting easier for users.
🧰 Tools
🪛 LanguageTool

[typographical] ~35-~35: The conjunction “so that” does not require a comma.
Context: ... interface. The model is device-specific, so that the model generated on the GPU device c...

(SO_THAT_UNNECESSARY_COMMA)


35-35: Fix unnecessary comma

-The model is device-specific, so that the model generated on the GPU device cannot be run on the CPUs.
+The model is device-specific so that the model generated on the GPU device cannot be run on the CPUs.

source/api_cc/include/common.h (1)

16-16: Consider adding documentation for enum values.

The DPBackend enum would benefit from documentation describing each backend option and their implications.

Consider adding doxygen-style documentation:

+/**
+ * @brief Supported deep learning backends
+ * @enum DPBackend
+ * @var TensorFlow TensorFlow backend support
+ * @var PyTorch PyTorch backend support
+ * @var Paddle PaddlePaddle backend support
+ * @var JAX JAX backend support
+ * @var Unknown Unrecognized backend
+ */
enum DPBackend { TensorFlow, PyTorch, Paddle, JAX, Unknown };
source/lib/include/neighbor_list.h (2)

47-48: Consider using smart pointers for safer memory management.

The raw pointer mapping could lead to memory leaks or dangling pointers. Consider:

  1. Using std::unique_ptr or std::shared_ptr for automatic memory management
  2. Adding a size member variable to explicitly track the mapping array size
  3. Adding bounds checking in set_mapping
- /// mapping from all atoms to real atoms, in the size of nall
- int* mapping = nullptr;
+ /// mapping from all atoms to real atoms, in the size of nall
+ private:
+   std::unique_ptr<int[]> mapping = nullptr;
+   size_t mapping_size = 0;
+ public:
+   const int* get_mapping() const { return mapping.get(); }

Based on the evidence gathered from the codebase:

GPU implementation needs to be updated for the mapping feature

The convert_nlist_gpu_device function currently doesn't handle the new mapping member, which could lead to inconsistencies between CPU and GPU implementations. Specific updates needed:

  • convert_nlist_gpu_device in source/lib/src/neighbor_list.cc needs to allocate and copy the mapping array to GPU memory
  • free_nlist_gpu_device should be updated to free the GPU mapping array
  • GPU-specific test cases should be added to verify the mapping functionality
🔗 Analysis chain

Line range hint 47-107: Consider GPU implications of the new mapping feature.

The new mapping functionality might need corresponding support in GPU-related functions:

  1. Update convert_nlist_gpu_device to handle the mapping array
  2. Consider memory management in free_nlist_gpu_device
  3. Evaluate if use_nlist_map needs modifications

Let's check for related GPU functions:


Let me check the implementation of GPU-related functions to verify if they need updates for the mapping feature.


Let me check one more aspect to see if there are any GPU-related test cases that might show the expected behavior with mapping.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for GPU-related functions that might need updates
rg "nlist.*gpu" --type cpp

Length of output: 4129


Script:

#!/bin/bash
# Check the implementation of convert_nlist_gpu_device
rg -A 20 "convert_nlist_gpu_device.*InputNlist" source/lib/src/neighbor_list.cc

# Check if mapping is used in GPU-related functions
rg "mapping" source/lib/src/neighbor_list.cc

Length of output: 1278


Script:

#!/bin/bash
# Check test cases for GPU neighbor list implementation
rg -B 2 -A 5 "test.*nlist.*gpu" source/lib/tests/test_neighbor_list.cc

# Look for any mapping-related GPU tests
rg "mapping.*gpu" --type cpp source/lib/tests/

Length of output: 121

doc/install/install-from-source.md (3)

300-302: Add version compatibility information for JAX.

The documentation clearly explains that both TensorFlow and JAX backends use the TensorFlow C++ library. However, it would be helpful to specify which JAX versions are compatible with which TensorFlow versions to help users avoid version mismatches.


Line range hint 380-396: Clarify the relationship between TensorFlow and JAX enablement.

While the documentation correctly indicates that both TensorFlow and JAX backends use these CMake variables, it would be helpful to explicitly state:

  1. Whether enabling TensorFlow automatically enables JAX support
  2. If there are any additional steps or variables needed specifically for JAX
  3. If there are any limitations when using both backends simultaneously

This would help users better understand the configuration options available to them.


Line range hint 1-500: Documentation successfully integrates JAX backend support.

The documentation changes effectively integrate JAX backend support while maintaining clarity and consistency. The shared infrastructure with TensorFlow is well explained, and the installation process is clearly documented. The changes align well with the PR objectives of adding JAX backend support.

Consider adding a troubleshooting section specific to JAX installation to help users resolve common issues they might encounter during the setup process.

source/api_cc/tests/test_deeppot_jax.cc (3)

19-32: Consider moving Python code to a separate file or improve formatting.

The Python code in comments shows how the test data was generated, but it's not properly formatted and makes the code harder to read.

Consider either:

  1. Moving this code to a separate .py file and reference it in the comments
  2. Improving the formatting of the inline Python code:
-  // import numpy as np
-  // from deepmd.infer import DeepPot
-  // coord = np.array([
-  //     12.83, 2.56, 2.18, 12.09, 2.87, 2.74,
-  //     00.25, 3.32, 1.68, 3.36,  3.00, 1.81,
-  //     3.51,  2.51, 2.60, 4.27,  3.22, 1.56
-  // ]).reshape(1, -1)
+  // Python code used to generate test data:
+  // ```python
+  // import numpy as np
+  // from deepmd.infer import DeepPot
+  // 
+  // coord = np.array([
+  //     12.83, 2.56, 2.18, 12.09, 2.87, 2.74,
+  //     00.25, 3.32, 1.68, 3.36,  3.00, 1.81,
+  //     3.51,  2.51, 2.60, 4.27,  3.22, 1.56
+  // ]).reshape(1, -1)
+  // ```

319-324: Define magic numbers as named constants.

The test cases use magic numbers for array sizes and indices. This makes the code harder to understand and maintain.

Consider defining constants:

+// Constants for virtual atom tests
+constexpr int kNumVirtualAtoms = 2;
+constexpr int kVirtualAtomType = 2;
+
 // add vir atoms
-int nvir = 2;
-std::vector<VALUETYPE> coord_vir(nvir * 3);
-std::vector<int> atype_vir(nvir, 2);
+int nvir = kNumVirtualAtoms;
+std::vector<VALUETYPE> coord_vir(nvir * 3);
+std::vector<int> atype_vir(nvir, kVirtualAtomType);

Also applies to: 380-386


97-97: Add documentation for test purposes.

Each test case lacks documentation explaining its purpose and what specific aspect it's testing.

Add descriptive comments for each test case:

+/**
+ * Tests the basic functionality of DeepPot with LAMMPS neighbor lists.
+ * Verifies energy, force, and virial calculations.
+ */
 TYPED_TEST(TestInferDeepPotAJAX, cpu_lmp_nlist) {

Also applies to: 159-159, 242-242, 304-304, 366-366, 429-429, 434-434

source/api_c/include/c_api.h (1)

81-89: Documentation could be more detailed.

While the function declaration is well-structured, the documentation could be enhanced with:

  1. More details about the expected format and constraints of the mapping array
  2. Example usage or common use cases
  3. Error handling behavior

Consider expanding the documentation:

 /**
  * @brief Set mapping for a neighbor list.
  *
  * @param nl Neighbor list.
- * @param mapping mapping from all atoms to real atoms, in size nall.
+ * @param mapping Array mapping indices from all atoms to real atoms. Size must be equal to
+ *               the total number of atoms (nall). Each element should be a valid index
+ *               into the real atoms array. If NULL, the mapping is reset.
+ * @note The C API does not validate the mapping; an out-of-range index or an undersized array leads to undefined behavior.
  * @since API version 24
  *
  **/
source/api_c/include/deepmd.hpp (1)

618-622: Enhance documentation for the set_mapping method.

The documentation should clarify:

  • The ownership and lifetime requirements of the mapping pointer
  • Whether the pointer can be null
  • The expected size of the mapping array

Apply this diff to improve the documentation:

  /**
   * @brief Set mapping for this neighbor list.
   * @param mapping mapping from all atoms to real atoms, in size nall.
+  * @note The mapping array must remain valid for the lifetime of this neighbor list.
+  *       The array is not copied, only the pointer is stored.
+  * @warning The size of the mapping array must match the total number of atoms (nall).
+  *          Passing nullptr will clear any existing mapping.
   */
  void set_mapping(int *mapping) { DP_NlistSetMapping(nl, mapping); };
source/api_cc/include/DeepPotJAX.h (1)

233-247: Consider adding error handling to the compute method.

The compute method currently lacks error handling for potential issues such as:

  • Invalid input sizes (e.g., mismatch between the number of atoms and the size of the atype vector)
  • Out-of-range atom types
  • Invalid or inconsistent neighbor list data

To improve the robustness of the code, consider adding appropriate error checks and throwing exceptions with informative error messages when invalid input is detected. This will help users identify and fix issues more easily.

source/lmp/fix_dplr.cpp (1)

445-445: Use consistent data types for loop indices

The loop variable ii is declared as size_t, whereas nall is of type int. Mixing signed and unsigned types may lead to potential type mismatch warnings or unintended behavior. Consider declaring ii as an int to match the type of nall.

source/api_cc/src/DeepPotJAX.cc (10)

20-25: Ensure consistent error handling across the codebase.

The check_status function provides a convenient way to check the status of TensorFlow operations and throw exceptions in case of errors. Consider using this function consistently throughout the codebase to ensure uniform error handling and improve code readability.


27-45: Consider using a more efficient search algorithm.

The find_function function performs a linear search to find a specific function by name in the funcs vector. If the number of functions is large, this can be inefficient. Consider using a more efficient search algorithm, such as binary search or a hash table, to improve performance.

🧰 Tools
🪛 cppcheck

[performance] 29-29: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 37-37: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.

(uselessCallsSubstr)


47-61: Use a template function to deduce the data type.

Instead of defining separate overloads of the get_data_tensor_type function for each data type, consider using a template function that can deduce the data type based on the input vector's type. This will reduce code duplication and improve maintainability.

// Helper so the static_assert only fires for unsupported instantiations.
template <typename>
inline constexpr bool always_false_v = false;

template <typename T>
inline TF_DataType get_data_tensor_type(const std::vector<T>& data) {
  if constexpr (std::is_same_v<T, double>) {
    return TF_DOUBLE;
  } else if constexpr (std::is_same_v<T, float>) {
    return TF_FLOAT;
  } else if constexpr (std::is_same_v<T, int32_t>) {
    return TF_INT32;
  } else if constexpr (std::is_same_v<T, int64_t>) {
    return TF_INT64;
  } else {
    static_assert(always_false_v<T>, "Unsupported data type");
  }
}

63-82: Refactor the get_func_op function to improve readability.

The get_func_op function performs several steps to retrieve a TensorFlow operation for a given function name and context. Consider breaking down the function into smaller, more focused functions to improve readability and maintainability. For example, you can extract the code for finding the function and adding it to the context into separate functions.

🧰 Tools
🪛 cppcheck

[performance] 64-64: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 66-66: Function parameter 'device' should be passed by const reference.

(passedByValue)


84-107: Use std::optional to handle the case when the function is not found.

In the get_scalar function, instead of returning a default value when the function is not found, consider using std::optional to explicitly handle the case when the function is not found. This will make the code more expressive and less error-prone.

template <typename T>
inline std::optional<T> get_scalar(TFE_Context* ctx,
                                   const std::string& func_name,
                                   const std::vector<TF_Function*>& funcs,
                                   const std::string& device,
                                   TF_Status* status) {
  TFE_Op* op = get_func_op(ctx, func_name, funcs, device, status);
  check_status(status);
  TFE_TensorHandle* retvals[1];
  int nretvals = 1;
  TFE_Execute(op, retvals, &nretvals, status);
  check_status(status);
  TFE_TensorHandle* retval = retvals[0];
  TF_Tensor* tensor = TFE_TensorHandleResolve(retval, status);
  check_status(status);
  T* data = static_cast<T*>(TF_TensorData(tensor));
  std::optional<T> result =
      (data != nullptr) ? std::optional<T>(*data) : std::nullopt;
  // Release resources on both the success and the nullopt path.
  TFE_DeleteOp(op);
  TF_DeleteTensor(tensor);
  TFE_DeleteTensorHandle(retval);
  return result;
}
🧰 Tools
🪛 cppcheck

[performance] 86-86: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 88-88: Function parameter 'device' should be passed by const reference.

(passedByValue)


109-129: Use std::vector::resize instead of creating a new vector.

In the get_vector function, instead of creating a new result vector and resizing it later, consider resizing the result vector directly using std::vector::resize. This will avoid unnecessary memory allocations and improve performance.

template <typename T>
inline std::vector<T> get_vector(TFE_Context* ctx,
                                 const std::string& func_name,
                                 const std::vector<TF_Function*>& funcs,
                                 const std::string& device,
                                 TF_Status* status) {
  TFE_Op* op = get_func_op(ctx, func_name, funcs, device, status);
  check_status(status);
  TFE_TensorHandle* retvals[1];
  int nretvals = 1;
  TFE_Execute(op, retvals, &nretvals, status);
  check_status(status);
  TFE_TensorHandle* retval = retvals[0];
  std::vector<T> result;
  tensor_to_vector(result, retval, status);
  TFE_DeleteTensorHandle(retval);
  TFE_DeleteOp(op);
  return result;
}
🧰 Tools
🪛 cppcheck

[performance] 111-111: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 113-113: Function parameter 'device' should be passed by const reference.

(passedByValue)


131-166: Construct the result strings in place.

In the get_vector_string function, construct each string directly inside the result vector with emplace_back(dst, dst_len) instead of building a temporary std::string and pushing it back. This avoids one copy per string, which matters when the strings are large; reserving the vector up front also avoids reallocations.

inline std::vector<std::string> get_vector_string(
    TFE_Context* ctx,
    const std::string& func_name,
    const std::vector<TF_Function*>& funcs,
    const std::string& device,
    TF_Status* status) {
  TFE_Op* op = get_func_op(ctx, func_name, funcs, device, status);
  check_status(status);
  TFE_TensorHandle* retvals[1];
  int nretvals = 1;
  TFE_Execute(op, retvals, &nretvals, status);
  check_status(status);
  TFE_TensorHandle* retval = retvals[0];
  TF_Tensor* tensor = TFE_TensorHandleResolve(retval, status);
  check_status(status);
  const void* data = TF_TensorData(tensor);
  int64_t bytes_each_string =
      TF_TensorByteSize(tensor) / TF_TensorElementCount(tensor);
  std::vector<std::string> result;
  result.reserve(TF_TensorElementCount(tensor));
  for (int64_t ii = 0; ii < TF_TensorElementCount(tensor); ++ii) {
    const TF_TString* datastr =
        static_cast<const TF_TString*>(static_cast<const void*>(
            static_cast<const char*>(data) + ii * bytes_each_string));
    const char* dst = TF_TString_GetDataPointer(datastr);
    size_t dst_len = TF_TString_GetSize(datastr);
    result.emplace_back(dst, dst_len);
  }
  TFE_DeleteOp(op);
  TF_DeleteTensor(tensor);
  TFE_DeleteTensorHandle(retval);
  return result;
}
🧰 Tools
🪛 cppcheck

[performance] 133-133: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 135-135: Function parameter 'device' should be passed by const reference.

(passedByValue)


168-176: Use std::vector::data instead of &data[0].

In the create_tensor function, use std::vector::data instead of &data[0] to obtain a pointer to the vector's underlying storage. Besides being more idiomatic, data() is well-defined even for an empty vector, whereas &data[0] is undefined behavior in that case.

template <typename T>
inline TF_Tensor* create_tensor(const std::vector<T>& data,
                                const std::vector<int64_t>& shape) {
  TF_Tensor* tensor =
      TF_AllocateTensor(get_data_tensor_type(data), shape.data(), shape.size(),
                        data.size() * sizeof(T));
  std::memcpy(TF_TensorData(tensor), data.data(), TF_TensorByteSize(tensor));
  return tensor;
}

193-207: Use std::copy instead of a raw loop to copy data.

In the tensor_to_vector function, instead of using a raw loop to copy data from the TensorFlow tensor to the result vector, consider using std::copy. This will make the code more readable and less error-prone.

template <typename T>
inline void tensor_to_vector(std::vector<T>& result,
                             TFE_TensorHandle* retval,
                             TF_Status* status) {
  TF_Tensor* tensor = TFE_TensorHandleResolve(retval, status);
  check_status(status);
  T* data = static_cast<T*>(TF_TensorData(tensor));
  result.resize(TF_TensorElementCount(tensor));
  std::copy(data, data + result.size(), result.begin());
  TF_DeleteTensor(tensor);
}

209-215: Use member initializer list for constructor initialization.

In the DeepPotJAX constructor, consider using a member initializer list to initialize the inited member variable instead of assigning it in the constructor body. This is more efficient and follows best practices for constructor initialization.

deepmd::DeepPotJAX::DeepPotJAX(const std::string& model,
                               const int& gpu_rank,
                               const std::string& file_content)
    : inited(false) {
  init(model, gpu_rank, file_content);
}
source/lmp/tests/test_lammps_jax.py (2)

225-227: Add error handling for model conversion command

sp.check_output() raises a bare CalledProcessError when the conversion command fails, with no context about which model was being converted. Wrapping the call surfaces a clearer failure message.

Apply this diff to wrap the call:

-sp.check_output(
-    f"{sys.executable} -m deepmd convert-from pbtxt -i {pbtxt_file2.resolve()} -o {pb_file2.resolve()}".split()
-)
+try:
+    sp.check_output(
+        f"{sys.executable} -m deepmd convert-from pbtxt -i {pbtxt_file2.resolve()} -o {pb_file2.resolve()}".split()
+    )
+except sp.CalledProcessError as e:
+    raise RuntimeError(f"Model conversion failed: {e}") from e

246-276: Simplify unit handling in _lammps() function

The unit handling logic uses repetitive if-elif statements for setting neighbor distances, masses, and timesteps. Refactoring this code using dictionaries can improve readability and maintainability.

Consider refactoring as follows, validating `units` once up front so invalid values still raise:

def _lammps(data_file, units="metal") -> PyLammps:
    lammps = PyLammps()
    lammps.units(units)
    lammps.boundary("p p p")
    lammps.atom_style("atomic")
+   if units not in ("metal", "real", "si"):
+       raise ValueError("units should be metal, real, or si")
-   if units == "metal" or units == "real":
-       lammps.neighbor("2.0 bin")
-   elif units == "si":
-       lammps.neighbor("2.0e-10 bin")
-   else:
-       raise ValueError("units should be metal, real, or si")
+   neighbor_settings = {
+       "metal": "2.0 bin",
+       "real": "2.0 bin",
+       "si": "2.0e-10 bin",
+   }
+   lammps.neighbor(neighbor_settings[units])
    lammps.neigh_modify("every 10 delay 0 check no")
    lammps.read_data(data_file.resolve())
-   if units == "metal" or units == "real":
-       lammps.mass("1 16")
-       lammps.mass("2 2")
-   elif units == "si":
-       lammps.mass("1 %.10e" % (16 * constants.mass_metal2si))
-       lammps.mass("2 %.10e" % (2 * constants.mass_metal2si))
-   else:
-       raise ValueError("units should be metal, real, or si")
+   mass_settings = {
+       "metal": ["1 16", "2 2"],
+       "real": ["1 16", "2 2"],
+       "si": [
+           "1 %.10e" % (16 * constants.mass_metal2si),
+           "2 %.10e" % (2 * constants.mass_metal2si),
+       ],
+   }
+   for mass in mass_settings[units]:
+       lammps.mass(mass)
-   if units == "metal":
-       lammps.timestep(0.0005)
-   elif units == "real":
-       lammps.timestep(0.5)
-   elif units == "si":
-       lammps.timestep(5e-16)
-   else:
-       raise ValueError("units should be metal, real, or si")
+   timestep_settings = {
+       "metal": 0.0005,
+       "real": 0.5,
+       "si": 5e-16,
+   }
+   lammps.timestep(timestep_settings[units])
    lammps.fix("1 all nve")
    return lammps
source/lmp/tests/test_lammps_dpa_jax.py (6)

246-268: Refactor unit handling to avoid code duplication

The _lammps function repeats similar if-elif-else blocks for unit handling multiple times. This can be refactored to improve readability and maintainability.

Apply this diff to refactor the unit handling:

 def _lammps(data_file, units="metal") -> PyLammps:
     lammps = PyLammps()
     lammps.units(units)
     lammps.boundary("p p p")
     lammps.atom_style("atomic")
     # Requires for DPA-2
     lammps.atom_modify("map yes")
-    if units == "metal" or units == "real":
-        lammps.neighbor("2.0 bin")
-    elif units == "si":
-        lammps.neighbor("2.0e-10 bin")
-    else:
-        raise ValueError("units should be metal, real, or si")
+    neighbor_distance = {
+        "metal": "2.0 bin",
+        "real": "2.0 bin",
+        "si": "2.0e-10 bin",
+    }.get(units)
+    if neighbor_distance is None:
+        raise ValueError("units should be 'metal', 'real', or 'si'")
+    lammps.neighbor(neighbor_distance)
     lammps.neigh_modify("every 10 delay 0 check no")
     lammps.read_data(data_file.resolve())
-    if units == "metal" or units == "real":
-        lammps.mass("1 16")
-        lammps.mass("2 2")
-    elif units == "si":
-        lammps.mass("1 %.10e" % (16 * constants.mass_metal2si))
-        lammps.mass("2 %.10e" % (2 * constants.mass_metal2si))
-    else:
-        raise ValueError("units should be metal, real, or si")
+    if units in ["metal", "real"]:
+        mass1 = "16"
+        mass2 = "2"
+    elif units == "si":
+        mass1 = "%.10e" % (16 * constants.mass_metal2si)
+        mass2 = "%.10e" % (2 * constants.mass_metal2si)
+    else:
+        raise ValueError("units should be 'metal', 'real', or 'si'")
+    lammps.mass(f"1 {mass1}")
+    lammps.mass(f"2 {mass2}")
-    if units == "metal":
-        lammps.timestep(0.0005)
-    elif units == "real":
-        lammps.timestep(0.5)
-    elif units == "si":
-        lammps.timestep(5e-16)
-    else:
-        raise ValueError("units should be metal, real, or si")
+    timestep = {
+        "metal": 0.0005,
+        "real": 0.5,
+        "si": 5e-16,
+    }.get(units)
+    if timestep is None:
+        raise ValueError("units should be 'metal', 'real', or 'si'")
+    lammps.timestep(timestep)
     lammps.fix("1 all nve")
     return lammps

312-318: Ensure proper cleanup of generated files in tests

The test test_pair_deepmd does not clean up the generated dump files, which may lead to clutter or issues in subsequent tests.

Consider adding code to remove or manage temporary files created during the test.
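In pytest this is usually handled with the built-in tmp_path fixture; the underlying idea can be sketched with tempfile (run_and_dump is a hypothetical stand-in for the LAMMPS run that writes a dump file):

```python
import os
import tempfile

# Write the dump into a temporary directory that is removed when the
# context exits, so no files leak between tests.
def run_and_dump(dump_path):
    with open(dump_path, "w") as f:
        f.write("ITEM: TIMESTEP\n0\n")

with tempfile.TemporaryDirectory() as tmpdir:
    dump_file = os.path.join(tmpdir, "lammps.dump")
    run_and_dump(dump_file)
    existed_inside = os.path.exists(dump_file)

# The directory and the dump file are gone once the block exits.
print(existed_inside, os.path.exists(dump_file))
```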


679-681: Remove unnecessary import of importlib

importlib is imported only to check whether mpi4py is available for the MPI test. If the MPI test remains skipped, remove the import to streamline the code; if you re-enable the test, keep it.


468-521: Consolidate similar test functions to reduce duplication

The functions test_pair_deepmd_real and test_pair_deepmd_virial_real share similar setup code. Refactoring them to use a common helper function or parameterization could reduce code duplication and improve maintainability.

Consider refactoring as follows:

  • Extract common code into a helper function.
  • Use @pytest.mark.parametrize to run tests with varying parameters.

690-726: Reconsider skipping the MPI test function

The test_pair_deepmd_mpi function is marked to be skipped unconditionally with @pytest.mark.skip. Given that there are checks for MPI installation, you might want to enable this test when the environment supports it.

If MPI support is now available, remove the skip decorator or adjust the condition to enable the test appropriately.


230-619: Enhance test coverage for different unit systems

The tests for different unit systems (metal, real, si) could be parameterized to improve readability and reduce redundancy.

Use @pytest.mark.parametrize to run the same test logic over different units:

@pytest.mark.parametrize("units", ["metal", "real", "si"])
def test_pair_deepmd_units(units):
    lammps_instance = _lammps(data_file=data_file, units=units)
    # ... rest of the test logic ...
source/api_cc/src/DeepPot.cc (2)

45-47: Replace magic numbers with constants for file extension checks

To enhance readability and maintainability, consider defining constants for the file extensions instead of using magic numbers like 11. This approach makes the code clearer and simplifies updates if file extensions change in the future.

Apply this diff to define a constant:

+constexpr std::string_view kSavedModelExt{".savedmodel"};
 ...
-} else if (model.length() >= 11 &&
-           model.substr(model.length() - 11) == ".savedmodel") {
+} else if (model.length() >= kSavedModelExt.size() &&
+           model.substr(model.length() - kSavedModelExt.size()) == kSavedModelExt) {
     backend = deepmd::DPBackend::JAX;

A constexpr std::string_view (requires the <string_view> header) avoids both the magic number and a runtime strlen call.

65-72: Consider introducing a separate build flag for JAX support

Currently, JAX support is conditioned on the BUILD_TENSORFLOW build flag. For better modularity and clarity, consider adding a separate build flag (e.g., BUILD_JAX) to control JAX-related compilation independently. This allows for more flexible build configurations and clearer dependency management.
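A hypothetical sketch of such a flag (the option name, dependency check, and target name are illustrative, not the project's actual CMake layout):

```cmake
# Separate toggle that still reuses the TensorFlow C library.
option(ENABLE_JAX "Build JAX backend support" OFF)
if(ENABLE_JAX)
  if(NOT ENABLE_TENSORFLOW)
    message(FATAL_ERROR "The JAX backend requires the TensorFlow C library")
  endif()
  target_compile_definitions(deepmd_cc PRIVATE BUILD_JAX)
endif()
```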

source/lmp/pair_deepmd.cpp (3)

524-531: Mapping vector initialization is limited to single-process runs

The mapping_vec is initialized and populated only when comm->nprocs == 1. If mapping functionality is required for multi-process runs, consider extending this logic to support multiple processes.


575-577: Setting mapping in lmp_list is restricted to single-process runs

The mapping is set in lmp_list only when comm->nprocs == 1. If mapping is needed in multi-process configurations, ensure that mapping_vec is properly initialized and lmp_list.set_mapping() is called accordingly.


588-590: Consistent mapping setting in extend_lmp_list for multi-process runs

Similar to the previous cases, extend_lmp_list.set_mapping() is called only when comm->nprocs == 1. Consider extending support for multi-process scenarios if mapping is required.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between dabedd2 and 58dcf2b.

📒 Files selected for processing (22)
  • .pre-commit-config.yaml (1 hunks)
  • doc/backend.md (1 hunks)
  • doc/install/install-from-source.md (4 hunks)
  • doc/model/dpa2.md (1 hunks)
  • source/api_c/include/c_api.h (4 hunks)
  • source/api_c/include/deepmd.hpp (1 hunks)
  • source/api_c/src/c_api.cc (1 hunks)
  • source/api_cc/include/DeepPotJAX.h (1 hunks)
  • source/api_cc/include/common.h (1 hunks)
  • source/api_cc/src/DeepPot.cc (3 hunks)
  • source/api_cc/src/DeepPotJAX.cc (1 hunks)
  • source/api_cc/tests/test_deeppot_jax.cc (1 hunks)
  • source/cmake/googletest.cmake.in (1 hunks)
  • source/lib/include/neighbor_list.h (2 hunks)
  • source/lmp/fix_dplr.cpp (2 hunks)
  • source/lmp/pair_deepmd.cpp (3 hunks)
  • source/lmp/tests/test_lammps_dpa_jax.py (1 hunks)
  • source/lmp/tests/test_lammps_jax.py (1 hunks)
  • source/tests/infer/deeppot_dpa.savedmodel/.gitignore (1 hunks)
  • source/tests/infer/deeppot_dpa.savedmodel/fingerprint.pb (1 hunks)
  • source/tests/infer/deeppot_sea.savedmodel/.gitignore (1 hunks)
  • source/tests/infer/deeppot_sea.savedmodel/fingerprint.pb (1 hunks)
✅ Files skipped from review due to trivial changes (4)
  • source/tests/infer/deeppot_dpa.savedmodel/.gitignore
  • source/tests/infer/deeppot_dpa.savedmodel/fingerprint.pb
  • source/tests/infer/deeppot_sea.savedmodel/.gitignore
  • source/tests/infer/deeppot_sea.savedmodel/fingerprint.pb
🧰 Additional context used
🪛 LanguageTool
doc/backend.md

[typographical] ~35-~35: The conjunction “so that” does not require a comma.
Context: ... interface. The model is device-specific, so that the model generated on the GPU device c...

(SO_THAT_UNNECESSARY_COMMA)

🪛 cppcheck
source/api_cc/src/DeepPotJAX.cc

[performance] 29-29: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 64-64: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 66-66: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 86-86: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 88-88: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 111-111: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 113-113: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 133-133: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 135-135: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 37-37: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.

(uselessCallsSubstr)

🔇 Additional comments (16)
source/cmake/googletest.cmake.in (1)

14-14: LGTM! Upgrading Google Test to v1.14.0.

The upgrade from release-1.12.1 to v1.14.0 is a good improvement that brings support for C++20 and various testing enhancements. This is a backward-compatible change.

Let's verify if there are any test files using C++20 features that would benefit from this upgrade:

✅ Verification successful

Upgrade to Google Test v1.14.0 is safe and beneficial

The upgrade from release-1.12.1 to v1.14.0 is safe. While we found no direct usage of C++20 features in test files, the codebase contains test utilities and the upgrade provides better testing infrastructure with improved death test implementation and custom printer support, which are valuable for the existing test suite.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Look for C++20 features in test files
# Expected: Files using C++20 features would benefit from this upgrade

# Search for potential C++20 features in test files
rg -l "concept|requires|constinit|consteval|co_await|co_yield|co_return|ranges::|views::|<span>|<concepts>" --type cpp

Length of output: 197


Script:

#!/bin/bash
# Let's check if these files are actually test files and examine their C++20 usage
rg -C3 "concept|requires|constinit|consteval|co_await|co_yield|co_return|ranges::|views::|<span>|<concepts>" source/lmp/pppm_dplr.cpp source/lmp/pair_deepmd.cpp source/3rdparty/json.hpp

# Also let's check if there are any actual test files in the repository
fd "test.*\.(cpp|h|hpp)$"

Length of output: 11735

doc/model/dpa2.md (1)

23-23: Verify the MPI rank limitation.

Let's confirm that the JAX backend indeed doesn't support multiple MPI ranks.

✅ Verification successful

The test files show clear evidence that MPI is not supported in the JAX backend. Specifically, in source/lmp/tests/test_lammps_dpa_jax.py, there's an explicit skip decorator @pytest.mark.skip("MPI is not supported") for the MPI test. Let me verify if there are any additional implementation details or documentation about this limitation.


MPI rank limitation for JAX backend is verified

The limitation is confirmed through multiple pieces of evidence in the codebase:

  • Explicit test skip with message "MPI is not supported" in source/lmp/tests/test_lammps_dpa_jax.py
  • Documentation in doc/model/dpa2.md correctly states the limitation
  • No MPI-specific implementation found in the JAX backend code
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Search for JAX-related MPI rank limitations in the codebase

# Look for JAX backend implementation files
echo "Searching for JAX backend files..."
fd -e py -e cpp -e h "jax|JAX"

# Search for MPI-related code in JAX implementation
echo "Searching for MPI-related code in JAX implementation..."
rg -i "mpi|rank" $(fd -e py -e cpp -e h "jax|JAX")

# Look for test files that might verify this limitation
echo "Searching for relevant test files..."
fd -e py "test.*jax"

Length of output: 2116


Script:

#!/bin/bash
# Search for documentation or comments about JAX MPI limitations
rg -i "jax.*mpi|mpi.*jax" --type md --type py --type cpp

# Look for any JAX-specific configuration or initialization code
ast-grep --pattern 'def $_($$$ jax $$$) {
  $$$
}'

Length of output: 301

.pre-commit-config.yaml (1)

18-19: LGTM! Please verify if other model files need exclusion.

The addition of deeppot_dpa.savedmodel/saved_model.pb to the exclusion list is appropriate for the model file. However, let's verify if there are other similar model files that should also be excluded.

✅ Verification successful

Exclusion pattern is correctly configured

The verification shows that source/tests/infer/deeppot_dpa.savedmodel/saved_model.pb (1.5M) is the only large .pb file in the repository. The current exclusion in .pre-commit-config.yaml correctly addresses this file, and no other model files need to be excluded at this time.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for other large .pb files that might need exclusion
# Expected: List of .pb files larger than 1024KB that aren't already excluded

# Find all .pb files
echo "Large .pb files (>1024KB) that might need exclusion:"
fd -e pb -x sh -c '
  size=$(stat -f%z {} 2>/dev/null || stat -c%s {});
  if [ "$size" -gt 1048576 ]; then
    echo "{} ($(numfmt --to=iec-i --suffix=B $size))";
  fi
'

Length of output: 485


Script:

#!/bin/bash
# Let's try a different approach using find and ls
echo "Large .pb files (>1024KB) that might need exclusion:"
find . -type f -name "*.pb" -exec ls -lh {} \; | awk '{if($5 ~ /[0-9.]+[MG]/) print $9, "("$5")"}'

Length of output: 284

doc/install/install-from-source.md (1)

343-343: LGTM!

The CMake configuration section correctly indicates that JAX uses the same configuration process as TensorFlow, which is consistent with their shared dependency on TensorFlow's C++ library.

source/api_c/include/c_api.h (2)

Line range hint 34-71: Documentation enhancement improves clarity.

The expanded documentation for DP_NewNlist_comm provides clearer context about its parallel communication capabilities and parameter usage.


15-15: API version increment is appropriate.

The increment from 23 to 24 correctly reflects the addition of new functionality (DP_NlistSetMapping).

source/api_c/src/c_api.cc (1)

46-48: Verify consistent error handling across the codebase.

Let's check if other neighbor list functions follow similar error handling patterns.

source/api_c/include/deepmd.hpp (1)

622-622: LGTM: Implementation follows established patterns.

The implementation correctly delegates to the C API function and follows the same pattern as other methods in the class.

source/api_cc/include/DeepPotJAX.h (2)

1-249: LGTM!

The DeepPotJAX class is well-structured and follows the existing design patterns in the codebase. It provides the necessary functionality to integrate JAX models into the DeePMD-kit framework. The class inherits from DeepPotBase, ensuring compatibility with the existing API. The constructors, destructor, and initialization methods are properly defined, allowing for flexible instantiation of the class.

The class includes methods to retrieve important model properties such as cutoff radius, number of types, and parameter dimensions. It also provides a get_type_map method to access the type map information. The computew and computew_mixed_type methods are overloaded to handle different input types (double and float) and enable the computation of energy, force, virial, and atomic properties based on the provided input data.

The private member variables are appropriately named and store relevant information such as initialization status, device information, model parameters, and TensorFlow C API objects for graph and session management.

Error handling is implemented using assertions to ensure that methods are called only after initialization, preventing potential runtime errors.

The class is well-documented with Doxygen-style comments, providing clear explanations of the purpose and parameters of each method.

Overall, the DeepPotJAX class is a solid addition to the DeePMD-kit framework, enabling seamless integration of JAX models.


4-5: Verify the TensorFlow C API header inclusion.

Please ensure that the TensorFlow C API headers tensorflow/c/c_api.h and tensorflow/c/eager/c_api.h are available in the include path and that the TensorFlow library is properly linked to the project.

source/api_cc/src/DeepPotJAX.cc (1)

1-581: Comprehensive implementation of the DeepPotJAX class for JAX integration.

The file provides a thorough implementation of the DeepPotJAX class, enabling seamless integration of the JAX backend within the deepmd-kit framework. It covers essential aspects such as:

  • Loading and initializing the JAX model
  • Managing TensorFlow resources and contexts
  • Implementing the core compute method for energy, force, and virial calculations
  • Handling different data types and input formats
  • Ensuring proper memory management and cleanup

The code follows best practices, including:

  • Appropriate error handling and status checks
  • Efficient memory management using smart pointers and RAII
  • Templated functions for generic programming
  • Consistent naming conventions and code style

Overall, the implementation is well-structured, modular, and adheres to the SOLID principles, promoting maintainability and extensibility.

🧰 Tools
🪛 cppcheck

[performance] 29-29: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 64-64: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 66-66: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 86-86: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 88-88: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 111-111: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 113-113: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 133-133: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 135-135: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 37-37: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.

(uselessCallsSubstr)

source/lmp/tests/test_lammps_jax.py (1)

677-682: Verify the effectiveness of @pytest.mark.skipif decorators

The @pytest.mark.skipif decorators check for mpirun and mpi4py at import time. If the environment changes between import and execution, tests may not skip as intended.

Run the following script to ensure that tests are skipped appropriately:

✅ Verification successful

Skip conditions for MPI tests are correctly implemented

The skip conditions are properly implemented for the following reasons:

  • The shutil.which("mpirun") check is appropriate as it verifies the MPI executable's presence in the system PATH
  • The importlib.util.find_spec("mpi4py") check is the recommended way to verify Python package availability
  • Both checks are used together, ensuring all MPI dependencies are available before test execution
  • The test function itself uses these dependencies (mpirun command and implicitly mpi4py through the MPI script), confirming the necessity of these checks
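For reference, the two import-time checks can be reproduced in a minimal standalone sketch (the helper `mpi_skip_reason` is illustrative and not part of the test file):

```python
import importlib.util
import shutil

# Both conditions are evaluated once, at import time, exactly as they are
# inside the @pytest.mark.skipif(...) decorators.
has_mpirun = shutil.which("mpirun") is not None
has_mpi4py = importlib.util.find_spec("mpi4py") is not None


def mpi_skip_reason():
    """Return the skip reason, or None when the MPI test can run."""
    if not has_mpirun:
        return "MPI is not installed on this system"
    if not has_mpi4py:
        return "mpi4py is not installed"
    return None
```

Because the flags are computed at import time, an environment change after import (for example, mpirun removed from PATH) is not picked up, which is the caveat raised above.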
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify that the skip conditions are evaluated correctly.

# Expected behavior: Tests should skip if `mpirun` or `mpi4py` are not available.

pytest source/lmp/tests/test_lammps_jax.py -v -k test_pair_deepmd_mpi

Length of output: 145


Script:

#!/bin/bash
# Let's analyze the test file to understand the skip conditions and test implementation

# First, check the test function implementation
rg "def test_pair_deepmd_mpi" source/lmp/tests/test_lammps_jax.py -A 10

# Then, check all skipif decorators in the file
rg "@pytest.mark.skipif" source/lmp/tests/test_lammps_jax.py -A 2

# Check if there are any other MPI-related tests and their skip conditions
rg "mpi" source/lmp/tests/test_lammps_jax.py

Length of output: 1039

source/lmp/tests/test_lammps_dpa_jax.py (2)

680-684: ⚠️ Potential issue

Correct the MPI skip condition and message

The skip condition for the MPI tests may not be accurate, and the skip message "MPI is not supported" contradicts the previous checks for MPI installation.

Apply this diff to correct the skip condition and message:

 @pytest.mark.skipif(
     shutil.which("mpirun") is None, reason="MPI is not installed on this system"
 )
 @pytest.mark.skipif(
     importlib.util.find_spec("mpi4py") is None, reason="mpi4py is not installed"
 )
-@pytest.mark.skip("MPI is not supported")
+@pytest.mark.skip(reason="MPI is currently not supported in this test")
 def test_pair_deepmd_mpi(balance_args: list):

Alternatively, consider enabling the test if MPI support is now available.

Likely invalid or redundant comment.


356-364: Verify the correctness of model deviation calculations

In the function test_pair_deepmd_model_devi, the computation of expected_md_v uses np.sum(expected_v, axis=0) which sums over atoms, possibly leading to incorrect virial deviation calculations.

Please ensure that the virial deviations are calculated correctly per atom rather than summing over atoms.

Run the following script to verify the computation:

If the virial deviations should not sum over atoms, consider modifying the calculation accordingly.
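To make the distinction concrete, here is a small NumPy sketch under assumed shapes (2 models, 3 atoms, 9 virial components per atom; the data are illustrative, not taken from the test):

```python
import numpy as np

# Per-atom virials predicted by two models: shape (n_models, n_atoms, 9).
virials = np.array(
    [
        [[1.0] * 9, [2.0] * 9, [3.0] * 9],  # model A
        [[2.0] * 9, [1.0] * 9, [3.0] * 9],  # model B
    ]
)

# Variant 1: sum over atoms first, then take the deviation of the nine
# total-virial components across models.
total_v = virials.sum(axis=1)       # shape (n_models, 9)
dev_total = total_v.std(axis=0)     # shape (9,); all zeros for this data

# Variant 2: deviation across models kept per atom.
dev_per_atom = virials.std(axis=0)  # shape (n_atoms, 9); nonzero for atoms 0 and 1
```

The two models disagree atom by atom, yet their summed virials coincide, so the summed deviation is identically zero; if a per-atom deviation is the intended quantity, summing over atoms first hides it.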

source/api_cc/src/DeepPot.cc (2)

10-10: Including DeepPotJAX.h under BUILD_TENSORFLOW

Including DeepPotJAX.h within the #ifdef BUILD_TENSORFLOW block aligns with the dependency that JAX support requires TensorFlow. This ensures that JAX-related code is only compiled when TensorFlow support is enabled.


65-72: Appropriate initialization and error handling for JAX backend

The code correctly initializes the JAX backend when TensorFlow support is available and provides clear exception handling when it is not. This ensures users are informed about the dependency on TensorFlow for loading JAX2TF SavedModels.

@@ -13,7 +13,7 @@
namespace deepmd {

typedef double ENERGYTYPE;
-enum DPBackend { TensorFlow, PyTorch, Paddle, Unknown };
+enum DPBackend { TensorFlow, PyTorch, Paddle, JAX, Unknown };
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codebase verification

Missing JAX case in some backend handlers

The verification reveals that JAX backend case is missing in two files:

  • source/api_cc/src/DeepTensor.cc: JAX case not handled in if-else chain
  • source/api_cc/src/DataModifier.cc: JAX case not handled in if-else chain

While DeepPot.cc properly handles the JAX backend, the other backend handlers need to be updated for consistency.

🔗 Analysis chain

LGTM! Verify switch statements for the new backend.

The addition of JAX to the DPBackend enum is clean and properly placed before the Unknown value.

Let's verify that all switch statements handling DPBackend are updated to include the JAX case:

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for switch statements or if-else chains handling DPBackend
rg -A 10 "switch.*DPBackend|if.*DPBackend.*==" source/

Length of output: 4502

Comment on lines +104 to +107
/**
* @brief Set mapping for this neighbor list.
*/
void set_mapping(int* mapping_) { mapping = mapping_; };

🛠️ Refactor suggestion

Enhance method robustness and documentation.

The set_mapping method needs:

  1. Parameter validation
  2. Clear ownership semantics
  3. Size validation
  4. More detailed documentation
- /**
-  * @brief Set mapping for this neighbor list.
-  */
- void set_mapping(int* mapping_) { mapping = mapping_; };
+ /**
+  * @brief Set mapping from all atoms to real atoms.
+  * @param mapping_ Pointer to mapping array of size nall. Must not be nullptr.
+  * @param size Size of the mapping array.
+  * @throws std::invalid_argument if mapping_ is nullptr or size is invalid.
+  * @note The class takes ownership of the mapping array.
+  */
+ void set_mapping(int* mapping_, size_t size) {
+   if (!mapping_ || size == 0) {
+     throw std::invalid_argument("Invalid mapping array");
+   }
+   mapping.reset(mapping_);
+   mapping_size = size;
+ };

Committable suggestion skipped: line range outside the PR's diff.

Comment on lines +97 to +157
TYPED_TEST(TestInferDeepPotAJAX, cpu_lmp_nlist) {
using VALUETYPE = TypeParam;
std::vector<VALUETYPE>& coord = this->coord;
std::vector<int>& atype = this->atype;
std::vector<VALUETYPE>& box = this->box;
std::vector<VALUETYPE>& expected_e = this->expected_e;
std::vector<VALUETYPE>& expected_f = this->expected_f;
std::vector<VALUETYPE>& expected_v = this->expected_v;
int& natoms = this->natoms;
double& expected_tot_e = this->expected_tot_e;
std::vector<VALUETYPE>& expected_tot_v = this->expected_tot_v;
deepmd::DeepPot& dp = this->dp;
float rc = dp.cutoff();
int nloc = coord.size() / 3;
std::vector<VALUETYPE> coord_cpy;
std::vector<int> atype_cpy, mapping;
std::vector<std::vector<int> > nlist_data;
_build_nlist<VALUETYPE>(nlist_data, coord_cpy, atype_cpy, mapping, coord,
atype, box, rc);
int nall = coord_cpy.size() / 3;
std::vector<int> ilist(nloc), numneigh(nloc);
std::vector<int*> firstneigh(nloc);
deepmd::InputNlist inlist(nloc, &ilist[0], &numneigh[0], &firstneigh[0]);
convert_nlist(inlist, nlist_data);

double ener;
std::vector<VALUETYPE> force_, virial;
dp.compute(ener, force_, virial, coord_cpy, atype_cpy, box, nall - nloc,
inlist, 0);
std::vector<VALUETYPE> force;
_fold_back<VALUETYPE>(force, force_, mapping, nloc, nall, 3);

EXPECT_EQ(force.size(), natoms * 3);
EXPECT_EQ(virial.size(), 9);

EXPECT_LT(fabs(ener - expected_tot_e), EPSILON);
for (int ii = 0; ii < natoms * 3; ++ii) {
EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON);
}
for (int ii = 0; ii < 3 * 3; ++ii) {
EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON);
}

ener = 0.;
std::fill(force_.begin(), force_.end(), 0.0);
std::fill(virial.begin(), virial.end(), 0.0);
dp.compute(ener, force_, virial, coord_cpy, atype_cpy, box, nall - nloc,
inlist, 1);
_fold_back<VALUETYPE>(force, force_, mapping, nloc, nall, 3);

EXPECT_EQ(force.size(), natoms * 3);
EXPECT_EQ(virial.size(), 9);

EXPECT_LT(fabs(ener - expected_tot_e), EPSILON);
for (int ii = 0; ii < natoms * 3; ++ii) {
EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON);
}
for (int ii = 0; ii < 3 * 3; ++ii) {
EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON);
}
}

🛠️ Refactor suggestion

Reduce code duplication in test cases.

The test cases share significant amounts of code for setup and validation. This makes maintenance harder and increases the chance of inconsistencies.

Consider extracting common setup and validation code into helper methods:

template <typename VALUETYPE>
void ValidateResults(
    const std::vector<VALUETYPE>& force,
    const std::vector<VALUETYPE>& virial,
    const std::vector<VALUETYPE>& expected_f,
    const std::vector<VALUETYPE>& expected_tot_v,
    double ener,
    double expected_tot_e,
    int natoms) {
  EXPECT_EQ(force.size(), natoms * 3);
  EXPECT_EQ(virial.size(), 9);
  EXPECT_LT(fabs(ener - expected_tot_e), EPSILON);
  
  for (int ii = 0; ii < natoms * 3; ++ii) {
    EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON);
  }
  for (int ii = 0; ii < 3 * 3; ++ii) {
    EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON);
  }
}

Also applies to: 159-240, 242-302, 304-364, 366-427

Comment on lines +429 to +432
TYPED_TEST(TestInferDeepPotAJAX, print_summary) {
deepmd::DeepPot& dp = this->dp;
dp.print_summary("");
}

🛠️ Refactor suggestion

Enhance print_summary test case.

The print_summary test case doesn't verify the output. It should validate that the summary contains expected information.

Consider capturing and validating the output:

 TYPED_TEST(TestInferDeepPotAJAX, print_summary) {
   deepmd::DeepPot& dp = this->dp;
-  dp.print_summary("");
+  testing::internal::CaptureStdout();
+  dp.print_summary("");
+  std::string output = testing::internal::GetCapturedStdout();
+  
+  EXPECT_THAT(output, testing::HasSubstr("Summary of DeepPot Model"));
+  EXPECT_THAT(output, testing::HasSubstr("Cutoff radius"));
+  EXPECT_THAT(output, testing::HasSubstr("Model type"));
 }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change

Before:

TYPED_TEST(TestInferDeepPotAJAX, print_summary) {
  deepmd::DeepPot& dp = this->dp;
  dp.print_summary("");
}

After:

TYPED_TEST(TestInferDeepPotAJAX, print_summary) {
  deepmd::DeepPot& dp = this->dp;
  testing::internal::CaptureStdout();
  dp.print_summary("");
  std::string output = testing::internal::GetCapturedStdout();
  EXPECT_THAT(output, testing::HasSubstr("Summary of DeepPot Model"));
  EXPECT_THAT(output, testing::HasSubstr("Cutoff radius"));
  EXPECT_THAT(output, testing::HasSubstr("Model type"));
}

Comment on lines +71 to +90
void SetUp() override {
std::string file_name = "../../tests/infer/deeppot_sea.savedmodel";

dp.init(file_name);

natoms = expected_e.size();
EXPECT_EQ(natoms * 3, expected_f.size());
EXPECT_EQ(natoms * 9, expected_v.size());
expected_tot_e = 0.;
expected_tot_v.resize(9);
std::fill(expected_tot_v.begin(), expected_tot_v.end(), 0.);
for (int ii = 0; ii < natoms; ++ii) {
expected_tot_e += expected_e[ii];
}
for (int ii = 0; ii < natoms; ++ii) {
for (int dd = 0; dd < 9; ++dd) {
expected_tot_v[dd] += expected_v[ii * 9 + dd];
}
}
}

🛠️ Refactor suggestion

Add error handling in SetUp method.

The SetUp method should handle potential errors when initializing the DeepPot model and validating test data.

Consider adding error handling:

 void SetUp() override {
   std::string file_name = "../../tests/infer/deeppot_sea.savedmodel";
-
-  dp.init(file_name);
+  try {
+    dp.init(file_name);
+  } catch (const std::exception& e) {
+    FAIL() << "Failed to initialize DeepPot model: " << e.what();
+  }
 
   natoms = expected_e.size();
-  EXPECT_EQ(natoms * 3, expected_f.size());
-  EXPECT_EQ(natoms * 9, expected_v.size());
+  ASSERT_EQ(natoms * 3, expected_f.size()) << "Force array size mismatch";
+  ASSERT_EQ(natoms * 9, expected_v.size()) << "Virial array size mismatch";
📝 Committable suggestion


Suggested change

Before:

void SetUp() override {
  std::string file_name = "../../tests/infer/deeppot_sea.savedmodel";
  dp.init(file_name);
  natoms = expected_e.size();
  EXPECT_EQ(natoms * 3, expected_f.size());
  EXPECT_EQ(natoms * 9, expected_v.size());
  expected_tot_e = 0.;
  expected_tot_v.resize(9);
  std::fill(expected_tot_v.begin(), expected_tot_v.end(), 0.);
  for (int ii = 0; ii < natoms; ++ii) {
    expected_tot_e += expected_e[ii];
  }
  for (int ii = 0; ii < natoms; ++ii) {
    for (int dd = 0; dd < 9; ++dd) {
      expected_tot_v[dd] += expected_v[ii * 9 + dd];
    }
  }
}

After:

void SetUp() override {
  std::string file_name = "../../tests/infer/deeppot_sea.savedmodel";
  try {
    dp.init(file_name);
  } catch (const std::exception& e) {
    FAIL() << "Failed to initialize DeepPot model: " << e.what();
  }
  natoms = expected_e.size();
  ASSERT_EQ(natoms * 3, expected_f.size()) << "Force array size mismatch";
  ASSERT_EQ(natoms * 9, expected_v.size()) << "Virial array size mismatch";
  expected_tot_e = 0.;
  expected_tot_v.resize(9);
  std::fill(expected_tot_v.begin(), expected_tot_v.end(), 0.);
  for (int ii = 0; ii < natoms; ++ii) {
    expected_tot_e += expected_e[ii];
  }
  for (int ii = 0; ii < natoms; ++ii) {
    for (int dd = 0; dd < 9; ++dd) {
      expected_tot_v[dd] += expected_v[ii * 9 + dd];
    }
  }
}

Comment on lines +216 to +279
void deepmd::DeepPotJAX::init(const std::string& model,
const int& gpu_rank,
const std::string& file_content) {
if (inited) {
std::cerr << "WARNING: deepmd-kit should not be initialized twice, do "
"nothing at the second call of initializer"
<< std::endl;
return;
}

const char* saved_model_dir = model.c_str();
graph = TF_NewGraph();
status = TF_NewStatus();

sessionopts = TF_NewSessionOptions();
TF_Buffer* runopts = NULL;

const char* tags = "serve";
int ntags = 1;

session = TF_LoadSessionFromSavedModel(sessionopts, runopts, saved_model_dir,
&tags, ntags, graph, NULL, status);
check_status(status);

int nfuncs = TF_GraphNumFunctions(graph);
// allocate memory for the TF_Function* array
func_vector.resize(nfuncs);
TF_Function** funcs = func_vector.data();
TF_GraphGetFunctions(graph, funcs, nfuncs, status);
check_status(status);

ctx_opts = TFE_NewContextOptions();
ctx = TFE_NewContext(ctx_opts, status);
check_status(status);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
int gpu_num;
DPGetDeviceCount(gpu_num); // check current device environment
DPErrcheck(DPSetDevice(gpu_rank % gpu_num));
if (gpu_num > 0) {
device = "/gpu:" + std::to_string(gpu_rank % gpu_num);
} else {
device = "/cpu:0";
}
#else
device = "/cpu:0";
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

rcut = get_scalar<double>(ctx, "get_rcut", func_vector, device, status);
dfparam =
get_scalar<int64_t>(ctx, "get_dim_fparam", func_vector, device, status);
daparam =
get_scalar<int64_t>(ctx, "get_dim_aparam", func_vector, device, status);
std::vector<std::string> type_map_ =
get_vector_string(ctx, "get_type_map", func_vector, device, status);
// deepmd-kit stores type_map as a concatenated string, split by ' '
type_map = type_map_[0];
for (size_t i = 1; i < type_map_.size(); i++) {
type_map += " " + type_map_[i];
}
ntypes = type_map_.size();
sel = get_vector<int64_t>(ctx, "get_sel", func_vector, device, status);
nnei = std::accumulate(sel.begin(), sel.end(), decltype(sel)::value_type(0));
inited = true;
}

🛠️ Refactor suggestion

Use std::unique_ptr to manage TensorFlow resources.

In the init function, consider using std::unique_ptr with custom deleters to manage TensorFlow resources, such as TF_Graph, TF_Status, TF_SessionOptions, and TFE_ContextOptions. This will ensure proper cleanup of these resources and make the code more exception-safe.

void deepmd::DeepPotJAX::init(const std::string& model,
                              const int& gpu_rank,
                              const std::string& file_content) {
  if (inited) {
    std::cerr << "WARNING: deepmd-kit should not be initialized twice, do "
                 "nothing at the second call of initializer"
              << std::endl;
    return;
  }

  const char* saved_model_dir = model.c_str();
  auto graph = std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)>(
      TF_NewGraph(), TF_DeleteGraph);
  auto status = std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)>(
      TF_NewStatus(), TF_DeleteStatus);

  auto sessionopts = std::unique_ptr<TF_SessionOptions, decltype(&TF_DeleteSessionOptions)>(
      TF_NewSessionOptions(), TF_DeleteSessionOptions);
  TF_Buffer* runopts = nullptr;

  const char* tags = "serve";
  int ntags = 1;

  session = TF_LoadSessionFromSavedModel(sessionopts.get(), runopts, saved_model_dir,
                                         &tags, ntags, graph.get(), nullptr, status.get());
  check_status(status.get());

  int nfuncs = TF_GraphNumFunctions(graph.get());
  func_vector.resize(nfuncs);
  TF_GraphGetFunctions(graph.get(), func_vector.data(), nfuncs, status.get());
  check_status(status.get());

  auto ctx_opts = std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)>(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  ctx = TFE_NewContext(ctx_opts.get(), status.get());
  check_status(status.get());

  // ... (rest of the code remains the same)
}

Comment on lines +281 to +293
deepmd::DeepPotJAX::~DeepPotJAX() {
if (inited) {
TF_DeleteSession(session, status);
TF_DeleteGraph(graph);
TF_DeleteSessionOptions(sessionopts);
TF_DeleteStatus(status);
TFE_DeleteContext(ctx);
TFE_DeleteContextOptions(ctx_opts);
for (size_t i = 0; i < func_vector.size(); i++) {
TF_DeleteFunction(func_vector[i]);
}
}
}

🛠️ Refactor suggestion

Use std::unique_ptr to manage TensorFlow resources in the destructor.

In the DeepPotJAX destructor, consider using std::unique_ptr with custom deleters to manage the cleanup of TensorFlow resources. This will ensure proper cleanup and make the code more readable and maintainable.

deepmd::De


Comment on lines +343 to +345
lammps.pair_style(
f"deepmd {pb_file.resolve()} {pb_file2.resolve()} out_file {md_file.resolve()} out_freq 1 atomic"
)

⚠️ Potential issue

Remove temporary model deviation file after tests

The md_file is created during tests but not removed afterward, which may clutter the file system with temporary files.

Add cleanup code to teardown_module():

def teardown_module():
    os.remove(data_file)
    os.remove(data_type_map_file)
+   os.remove(md_file)

Committable suggestion skipped: line range outside the PR's diff.

Comment on lines +241 to +244
def teardown_module():
os.remove(data_file)
os.remove(data_type_map_file)


⚠️ Potential issue

Ensure all generated data files are removed in teardown_module()

The teardown_module() function currently removes data_file and data_type_map_file but does not remove data_file_si. This may leave temporary files after tests are run.

Apply this diff to remove data_file_si during teardown:

def teardown_module():
    os.remove(data_file)
    os.remove(data_type_map_file)
+   os.remove(data_file_si)

Committable suggestion skipped: line range outside the PR's diff.

Comment on lines +230 to +239
def setup_module():
write_lmp_data(box, coord, type_OH, data_file)
write_lmp_data(box, coord, type_HO, data_type_map_file)
write_lmp_data(
box * constants.dist_metal2si,
coord * constants.dist_metal2si,
type_OH,
data_file_si,
)


⚠️ Potential issue

Add missing teardown for data_file_si

The setup_module function creates data_file_si, but teardown_module does not remove it, which may leave residual files after tests are run.

Apply this diff to remove data_file_si during teardown:

 def teardown_module():
     os.remove(data_file)
     os.remove(data_type_map_file)
+    os.remove(data_file_si)
📝 Committable suggestion


Suggested change

Before:

def setup_module():
    write_lmp_data(box, coord, type_OH, data_file)
    write_lmp_data(box, coord, type_HO, data_type_map_file)
    write_lmp_data(
        box * constants.dist_metal2si,
        coord * constants.dist_metal2si,
        type_OH,
        data_file_si,
    )

After:

def teardown_module():
    os.remove(data_file)
    os.remove(data_type_map_file)
    os.remove(data_file_si)


codecov bot commented Nov 6, 2024

Codecov Report

Attention: Patch coverage is 93.12377% with 35 lines in your changes missing coverage. Please review.

Project coverage is 84.31%. Comparing base (dabedd2) to head (58dcf2b).
Report is 19 commits behind head on devel.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| source/api_cc/src/DeepPotJAX.cc | 88.18% | 18 Missing and 12 partials ⚠️ |
| source/lmp/fix_dplr.cpp | 66.66% | 0 Missing and 2 partials ⚠️ |
| source/lmp/pair_deepmd.cpp | 75.00% | 2 Missing ⚠️ |
| source/api_cc/src/DeepPot.cc | 75.00% | 0 Missing and 1 partial ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##            devel    #4316      +/-   ##
==========================================
- Coverage   84.40%   84.31%   -0.09%     
==========================================
  Files         570      573       +3     
  Lines       53071    53581     +510     
  Branches     3054     3110      +56     
==========================================
+ Hits        44794    45177     +383     
- Misses       7318     7429     +111     
- Partials      959      975      +16     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.
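As a sanity check, the headline percentages follow directly from the hit and line counts in the coverage diff above (a small sketch; Codecov's own rounding of the last digit may differ):

```python
# Totals taken from the coverage diff: hits / tracked lines.
base_hits, base_lines = 44794, 53071
head_hits, head_lines = 45177, 53581

base_cov = 100 * base_hits / base_lines  # about 84.4%
head_cov = 100 * head_hits / head_lines  # about 84.3%
delta = head_cov - base_cov              # about -0.09 percentage points
```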


coderabbitai bot commented Nov 6, 2024

📝 Walkthrough

This pull request introduces several updates across multiple files, primarily focusing on the integration of the JAX backend into the DeePMD-kit framework. Key changes include modifications to pre-commit configurations to exclude specific files, enhancements to documentation regarding backend support, and the addition of new functionalities in the C API and related classes. The JAX backend is now fully supported, with corresponding updates to the installation instructions, API headers, and various source files to ensure compatibility and functionality across the codebase.

Changes

| File Path | Change Summary |
| --- | --- |
| .pre-commit-config.yaml | Updated hooks: check-added-large-files with new exclusion, insert-license for Python and C++ files to exclude source/3rdparty, pylint updated to set PYTHONPATH. |
| doc/backend.md | Added JAX as a backend option, specified model and checkpoint extensions, noted C++ inference limitations, and clarified GPU-specific model constraints. |
| doc/install/install-from-source.md | Updated installation instructions for C++ interface to include JAX, clarified Python version requirements, and reorganized environment variables section. |
| doc/model/dpa2.md | Added section on JAX backend limitations with LAMMPS. |
| source/api_c/include/c_api.h | Incremented API version to 24, added DP_NlistSetMapping function, and updated documentation for existing functions. |
| source/api_c/include/deepmd.hpp | Introduced set_mapping method in InputNlist structure. |
| source/api_c/src/c_api.cc | Added DP_NlistSetMapping function implementation. |
| source/api_cc/include/DeepPotJAX.h | Introduced DeepPotJAX class with multiple constructors and methods for TensorFlow integration. |
| source/api_cc/include/common.h | Added JAX value to DPBackend enum. |
| source/api_cc/src/DeepPot.cc | Updated DeepPot class to handle JAX backend initialization logic. |
| source/api_cc/src/DeepPotJAX.cc | Implemented functionalities for TensorFlow integration within DeepPotJAX. |
| source/api_cc/tests/test_deeppot_jax.cc | Introduced unit tests for DeepPotJAX class using Google Test framework. |
| source/cmake/googletest.cmake.in | Updated Google Test library version from release-1.12.1 to v1.14.0. |
| source/lib/include/neighbor_list.h | Added int* mapping member and set_mapping method to InputNlist. |
| source/lmp/fix_dplr.cpp | Enhanced FixDPLR class for atom mapping and communication in parallel contexts. |
| source/lmp/pair_deepmd.cpp | Updated PairDeepMD class to improve atom mapping handling. |
| source/lmp/tests/test_lammps_dpa_jax.py | Added comprehensive tests for DPA with LAMMPS. |
| source/lmp/tests/test_lammps_jax.py | Added tests for LAMMPS integration with DeepMD potential. |
| source/tests/infer/deeppot_dpa.savedmodel/.gitignore | Updated to ignore .pb files except those matching !*.pb. |
| source/tests/infer/deeppot_dpa.savedmodel/fingerprint.pb | Added binary file fingerprint.pb. |
| source/tests/infer/deeppot_sea.savedmodel/.gitignore | Updated to ignore .pb files. |
| source/tests/infer/deeppot_sea.savedmodel/fingerprint.pb | Added binary file fingerprint.pb. |

Possibly related issues

Possibly related PRs

Suggested reviewers

  • njzjz
  • wanghan-iapcm
  • iProzd



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 6

🧹 Outside diff range and nitpick comments (22)
source/cmake/googletest.cmake.in (1)

Line range hint 1-1: Consider updating minimum CMake version requirement.

The current minimum CMake version (2.8.2) is quite old. Consider updating to a more recent version (e.g., 3.10 or later) to leverage modern CMake features and best practices.

doc/model/dpa2.md (1)

21-27: Documentation looks good, but could be more detailed.

The new section clearly documents the limitations and provides the necessary command. Consider enhancing it with:

  • A brief explanation of why these limitations exist
  • Potential performance implications of using map yes
  • A more prominent warning about the MPI rank limitation
 ## Limiations of the JAX backend with LAMMPS {{ jax_icon }}
 
-When using the JAX backend, 2 or more MPI ranks are not supported. One must set `map` to `yes` using the [`atom_modify`](https://docs.lammps.org/atom_modify.html) command.
+:::warning
+When using the JAX backend:
+- 2 or more MPI ranks are not supported due to JAX's current parallel processing limitations
+- The `map` option must be set to `yes` using the [`atom_modify`](https://docs.lammps.org/atom_modify.html) command to ensure proper atom indexing
+
+Note: Using `map yes` may have a small performance overhead but is required for correct operation with the JAX backend.
+:::
 
 ```lammps
 atom_modify map yes
 ```

doc/backend.md (2)

34-36: Fix typographical error and enhance clarity of limitations.

The limitations section for JAX backend has a minor typographical issue and could be clearer about the implications.

Apply these changes:

```diff
-The model is device-specific, so that the model generated on the GPU device cannot be run on the CPUs.
+The model is device-specific, so models generated on GPU devices cannot be run on CPUs.
🧰 Tools
🪛 LanguageTool

[typographical] ~35-~35: The conjunction “so that” does not require a comma.
Context: ... interface. The model is device-specific, so that the model generated on the GPU device c...

(SO_THAT_UNNECESSARY_COMMA)


34-36: Consider adding migration guidance for users.

The limitations described (C++ inference only with SavedModel, device specificity, and no training support) are significant. Consider adding guidance for users migrating from other backends.

Would you like me to help draft a migration guide section that covers:

  • When to use JAX vs other backends
  • Migration paths from TensorFlow/PyTorch
  • Best practices for handling device-specific models
🧰 Tools
🪛 LanguageTool

[typographical] ~35-~35: The conjunction “so that” does not require a comma.
Context: ... interface. The model is device-specific, so that the model generated on the GPU device c...

(SO_THAT_UNNECESSARY_COMMA)

source/lib/include/neighbor_list.h (1)

47-48: Enhance documentation for the mapping member.

While the basic purpose is documented, please add more details about:

  • Memory ownership (who allocates/deallocates)
  • Size requirements (how "nall" is determined)
  • Expected values in the mapping array
doc/install/install-from-source.md (2)

300-302: Enhance clarity regarding JAX backend requirements and dependencies.

The current documentation groups TensorFlow and JAX backends together which might be confusing. Consider:

  1. Separating JAX into its own tab-item section to clearly distinguish it from TensorFlow
  2. Adding JAX-specific version requirements
  3. Clarifying why JAX backend needs TensorFlow C++ library
-:::{tab-item} TensorFlow {{ tensorflow_icon }} / JAX {{ jax_icon }}
+:::{tab-item} TensorFlow {{ tensorflow_icon }}
+
+:::
+
+:::{tab-item} JAX {{ jax_icon }}
+
+Note: The JAX backend requires JAX version 0.4.33 or above and uses TensorFlow C++ library for its C++ interface implementation.

396-396: Clarify TENSORFLOW_ROOT usage for JAX backend.

The description should better explain why this path is needed for the JAX backend and how it's used.

-{{ tensorflow_icon }} {{ jax_icon }} The Path to TensorFlow's C++ interface.
+{{ tensorflow_icon }} {{ jax_icon }} The Path to TensorFlow's C++ interface. This is required for both TensorFlow and JAX backends as JAX's C++ interface is implemented using TensorFlow's C++ libraries.
source/api_c/include/c_api.h (1)

81-89: Documentation could be more detailed.

While the function documentation follows the established format, consider enhancing it with:

  1. Description of the expected size and format of the mapping array
  2. Documentation of the return value (void)
  3. Example usage or typical use case
 /**
  * @brief Set mapping for a neighbor list.
  *
- * @param nl Neighbor list.
- * @param mapping mapping from all atoms to real atoms, in size nall.
+ * @param[in] nl Pointer to the neighbor list to be modified.
+ * @param[in] mapping Array of size nall that maps from all atoms to real atoms.
+ * @return void
+ * @note Typical use case: When working with ghost atoms or periodic boundary conditions,
+ *       this mapping helps translate between local and global atom indices.
  * @since API version 24
  *
  **/
source/api_c/include/deepmd.hpp (1)

618-622: Enhance documentation for the set_mapping method.

The implementation looks good, but the documentation could be more detailed to clarify:

  • The ownership and lifetime requirements of the mapping pointer
  • Whether nullptr is a valid input
  • The expected size of the mapping array

Consider updating the documentation like this:

  /**
   * @brief Set mapping for this neighbor list.
-  * @param mapping mapping from all atoms to real atoms, in size nall.
+  * @param mapping Pointer to an array that maps from all atoms to real atoms. The array size must match the total number of atoms (nall).
+  *               The pointer must remain valid for the lifetime of the neighbor list.
+  *               A nullptr can be passed to reset/clear the mapping.
   */
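To illustrate the lifetime concern raised above, here is a minimal stand-in for the non-owning pointer pattern the suggested doc comment describes; `NlistSketch` and `lookup_real_atom` are hypothetical names, not the real `InputNlist` from `deepmd.hpp`:

```cpp
#include <cassert>
#include <vector>

// Stand-in for a neighbor list that stores a non-owning mapping pointer: the
// caller must keep the backing array alive for as long as the list is used.
struct NlistSketch {
  const int* mapping = nullptr;
  void set_mapping(const int* m) { mapping = m; }
};

// Translates an atom index through the mapping when one is set; otherwise
// the index is returned unchanged.
int lookup_real_atom(const NlistSketch& nl, int atom_index) {
  return nl.mapping ? nl.mapping[atom_index] : atom_index;
}
```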
source/api_cc/include/DeepPotJAX.h (2)

46-49: Add a const qualifier to the cutoff() method.

Since the cutoff() method does not modify the object's state and only returns the rcut member variable, consider adding the const qualifier to the method to indicate that it is a read-only operation. This helps improve code clarity and allows the method to be called on const objects.

-double cutoff() {
+double cutoff() const {
   assert(inited);
   return rcut;
 };

233-247: Consider using std::size_t for array indexing and sizes.

In the compute template method, consider using std::size_t instead of int for variables that represent array sizes or indices, such as nghost and ago. This ensures compatibility with the size type returned by std::vector::size() and avoids potential issues with signed/unsigned conversions.

 template <typename VALUETYPE>
 void compute(std::vector<ENERGYTYPE>& ener,
              std::vector<VALUETYPE>& force,
              std::vector<VALUETYPE>& virial,
              std::vector<VALUETYPE>& atom_energy,
              std::vector<VALUETYPE>& atom_virial,
              const std::vector<VALUETYPE>& coord,
              const std::vector<int>& atype,
              const std::vector<VALUETYPE>& box,
-             const int nghost,
+             const std::size_t nghost,
              const InputNlist& lmp_list,
-             const int& ago,
+             const std::size_t& ago,
              const std::vector<VALUETYPE>& fparam,
              const std::vector<VALUETYPE>& aparam,
              const bool atomic);
source/api_cc/src/DeepPotJAX.cc (6)

27-45: Consider passing func_name by const reference for better performance.

To avoid unnecessary string copying, consider passing the func_name parameter by const reference:

-inline void find_function(TF_Function*& found_func,
-                          const std::vector<TF_Function*>& funcs,
-                          const std::string func_name) {
+inline void find_function(TF_Function*& found_func,
+                          const std::vector<TF_Function*>& funcs,
+                          const std::string& func_name) {
🧰 Tools
🪛 cppcheck

[performance] 29-29: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 37-37: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.

(uselessCallsSubstr)
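The cost difference behind the pass-by-const-reference suggestions can be demonstrated with a copy-counting type (illustrative only; this is not code from `DeepPotJAX.cc`):

```cpp
#include <cassert>
#include <string>
#include <utility>

// A by-value parameter copy-constructs its argument; a const reference binds
// to it directly. CopyCounter makes the copy visible.
struct CopyCounter {
  static int copies;
  std::string value;
  explicit CopyCounter(std::string v) : value(std::move(v)) {}
  CopyCounter(const CopyCounter& other) : value(other.value) { ++copies; }
};
int CopyCounter::copies = 0;

void take_by_value(CopyCounter c) { (void)c; }        // one copy per call
void take_by_cref(const CopyCounter& c) { (void)c; }  // no copy
```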


37-38: Consider using resize() or pop_back() instead of substr() for better performance.

Assigning a prefix of the string to itself using substr() is ineffective. Consider using resize() or pop_back() instead for better performance:

-      name_ = name_.substr(0, pos + 1);
+      name_.resize(pos + 1);
🧰 Tools
🪛 cppcheck

[performance] 37-37: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.

(uselessCallsSubstr)
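The two forms are equivalent for this truncation, which a small check confirms (a standalone sketch, not the actual `name_` handling in `DeepPotJAX.cc`):

```cpp
#include <cassert>
#include <string>

// substr builds a temporary string and assigns it back ...
std::string truncate_with_substr(std::string name, size_t pos) {
  name = name.substr(0, pos + 1);
  return name;
}

// ... while resize shrinks the string in place, with the same result.
std::string truncate_with_resize(std::string name, size_t pos) {
  name.resize(pos + 1);
  return name;
}
```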


63-82: Consider passing func_name and device by const reference for better performance.

To avoid unnecessary string copying, consider passing the func_name and device parameters by const reference:

-inline TFE_Op* get_func_op(TFE_Context* ctx,
-                           const std::string func_name,
-                           const std::vector<TF_Function*>& funcs,
-                           const std::string device,
-                           TF_Status* status) {
+inline TFE_Op* get_func_op(TFE_Context* ctx,
+                           const std::string& func_name,
+                           const std::vector<TF_Function*>& funcs,
+                           const std::string& device,
+                           TF_Status* status) {
🧰 Tools
🪛 cppcheck

[performance] 64-64: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 66-66: Function parameter 'device' should be passed by const reference.

(passedByValue)


84-107: Consider passing func_name and device by const reference for better performance.

To avoid unnecessary string copying, consider passing the func_name and device parameters by const reference:

-template <typename T>
-inline T get_scalar(TFE_Context* ctx,
-                    const std::string func_name,
-                    const std::vector<TF_Function*>& funcs,
-                    const std::string device,
-                    TF_Status* status) {
+template <typename T>
+inline T get_scalar(TFE_Context* ctx,
+                    const std::string& func_name,
+                    const std::vector<TF_Function*>& funcs,
+                    const std::string& device,
+                    TF_Status* status) {
🧰 Tools
🪛 cppcheck

[performance] 86-86: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 88-88: Function parameter 'device' should be passed by const reference.

(passedByValue)


109-129: Consider passing func_name and device by const reference for better performance.

To avoid unnecessary string copying, consider passing the func_name and device parameters by const reference:

-template <typename T>
-inline std::vector<T> get_vector(TFE_Context* ctx,
-                                 const std::string func_name,
-                                 const std::vector<TF_Function*>& funcs,
-                                 const std::string device,
-                                 TF_Status* status) {
+template <typename T>
+inline std::vector<T> get_vector(TFE_Context* ctx,
+                                 const std::string& func_name,
+                                 const std::vector<TF_Function*>& funcs,
+                                 const std::string& device,
+                                 TF_Status* status) {
🧰 Tools
🪛 cppcheck

[performance] 111-111: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 113-113: Function parameter 'device' should be passed by const reference.

(passedByValue)


131-166: Consider passing func_name and device by const reference for better performance.

To avoid unnecessary string copying, consider passing the func_name and device parameters by const reference:

-inline std::vector<std::string> get_vector_string(
-    TFE_Context* ctx,
-    const std::string func_name,
-    const std::vector<TF_Function*>& funcs,
-    const std::string device,
-    TF_Status* status) {
+inline std::vector<std::string> get_vector_string(
+    TFE_Context* ctx,
+    const std::string& func_name,
+    const std::vector<TF_Function*>& funcs,
+    const std::string& device,
+    TF_Status* status) {
🧰 Tools
🪛 cppcheck

[performance] 133-133: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 135-135: Function parameter 'device' should be passed by const reference.

(passedByValue)

source/lmp/tests/test_lammps_jax.py (2)

246-276: Consider adding error handling for file read operations.

The _lammps function sets up a LAMMPS instance with the provided data file and units. Consider adding error handling for the lammps.read_data operation to gracefully handle cases where the data file is missing or inaccessible.

try:
    lammps.read_data(data_file.resolve())
except FileNotFoundError:
    raise FileNotFoundError(f"Data file not found: {data_file}")
except Exception as e:
    raise RuntimeError(f"Error reading data file: {data_file}. {str(e)}")

687-723: Consider adding error handling for the subprocess call.

The test_pair_deepmd_mpi function runs the run_mpi_pair_deepmd.py script using MPI and checks the potential energy and model deviation output against the expected values. Consider adding error handling for the sp.check_call to gracefully handle cases where the script fails to run.

try:
    sp.check_call(...)
except sp.CalledProcessError as e:
    raise RuntimeError(f"Failed to run MPI script: {e}")
source/lmp/tests/test_lammps_dpa_jax.py (1)

246-279: Consider extracting common LAMMPS setup code into a separate function.

The _lammps function contains a lot of common setup code for creating a LAMMPS instance with specific settings. Consider extracting this setup code into a separate function to improve readability and maintainability.

source/lmp/pair_deepmd.cpp (2)

527-529: Use consistent integer types in the loop

In the loop starting at line 527, ii is declared as size_t, while nall is of type int. Mixing signed and unsigned integer types can lead to potential issues. Consider changing ii to int to match the type of nall.

Apply this diff to change the loop variable type:

-for (size_t ii = 0; ii < nall; ++ii) {
+for (int ii = 0; ii < nall; ++ii) {
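The pitfall behind this suggestion is that comparing a `size_t` counter against an `int` bound converts the bound to unsigned; the explicit cast below makes that conversion visible, so a negative bound silently becomes a huge value (a standalone sketch, not the actual `pair_deepmd.cpp` loop):

```cpp
#include <cassert>
#include <cstddef>

// With a size_t loop counter, the int bound is converted to unsigned, which
// is what an implicit mixed comparison would do. A negative bound wraps to
// SIZE_MAX, so the loop only stops at the safety guard.
int count_iterations_mixed(int nall) {
  int count = 0;
  for (size_t ii = 0; ii < static_cast<size_t>(nall); ++ii) {
    ++count;
    if (count > 8) break;  // guard so the illustration terminates
  }
  return count;
}
```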

575-577: Refactor duplicate mapping logic into a helper function

The mapping logic in lines 575-577 and 588-590 is duplicated. To improve maintainability and reduce code duplication, consider refactoring this code into a helper function or lambda.

Also applies to: 588-590

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between dabedd2 and 58dcf2b.

📒 Files selected for processing (22)
  • .pre-commit-config.yaml (1 hunks)
  • doc/backend.md (1 hunks)
  • doc/install/install-from-source.md (4 hunks)
  • doc/model/dpa2.md (1 hunks)
  • source/api_c/include/c_api.h (4 hunks)
  • source/api_c/include/deepmd.hpp (1 hunks)
  • source/api_c/src/c_api.cc (1 hunks)
  • source/api_cc/include/DeepPotJAX.h (1 hunks)
  • source/api_cc/include/common.h (1 hunks)
  • source/api_cc/src/DeepPot.cc (3 hunks)
  • source/api_cc/src/DeepPotJAX.cc (1 hunks)
  • source/api_cc/tests/test_deeppot_jax.cc (1 hunks)
  • source/cmake/googletest.cmake.in (1 hunks)
  • source/lib/include/neighbor_list.h (2 hunks)
  • source/lmp/fix_dplr.cpp (2 hunks)
  • source/lmp/pair_deepmd.cpp (3 hunks)
  • source/lmp/tests/test_lammps_dpa_jax.py (1 hunks)
  • source/lmp/tests/test_lammps_jax.py (1 hunks)
  • source/tests/infer/deeppot_dpa.savedmodel/.gitignore (1 hunks)
  • source/tests/infer/deeppot_dpa.savedmodel/fingerprint.pb (1 hunks)
  • source/tests/infer/deeppot_sea.savedmodel/.gitignore (1 hunks)
  • source/tests/infer/deeppot_sea.savedmodel/fingerprint.pb (1 hunks)
✅ Files skipped from review due to trivial changes (4)
  • source/tests/infer/deeppot_dpa.savedmodel/.gitignore
  • source/tests/infer/deeppot_dpa.savedmodel/fingerprint.pb
  • source/tests/infer/deeppot_sea.savedmodel/.gitignore
  • source/tests/infer/deeppot_sea.savedmodel/fingerprint.pb
🧰 Additional context used
🪛 LanguageTool
doc/backend.md

[typographical] ~35-~35: The conjunction “so that” does not require a comma.
Context: ... interface. The model is device-specific, so that the model generated on the GPU device c...

(SO_THAT_UNNECESSARY_COMMA)

🪛 cppcheck
source/api_cc/src/DeepPotJAX.cc

[performance] 29-29: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 64-64: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 66-66: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 86-86: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 88-88: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 111-111: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 113-113: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 133-133: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 135-135: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 37-37: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.

(uselessCallsSubstr)

🔇 Additional comments (91)
source/cmake/googletest.cmake.in (1)

14-14: Verify compatibility with Google Test v1.14.0.

The upgrade from release-1.12.1 to v1.14.0 looks good. This version brings improvements including C++20 support and various bug fixes.

Let's verify if there are any potential compatibility issues:

✅ Verification successful

Based on the search results, I can generate the final response since we have sufficient information:

Google Test v1.14.0 upgrade appears safe to proceed

The codebase search reveals:

  • No usage of death tests (ASSERT_DEATH, EXPECT_DEATH, ASSERT_EXIT, EXPECT_EXIT) that could be affected by the improvements in v1.14.0
  • While there are files containing C++20 related keywords, they are primarily in documentation or non-test files
  • The test files under source/tests/ don't show any patterns that would be impacted by breaking changes between v1.12.1 and v1.14.0
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for Google Test usage patterns that might be affected by the version upgrade

# Test: Look for death test usage which had improvements in v1.14.0
rg "ASSERT_DEATH|EXPECT_DEATH|ASSERT_EXIT|EXPECT_EXIT"

# Test: Check for C++20 features usage that might benefit from the upgrade
rg -l "concept|requires|constinit|consteval"

Length of output: 2447

doc/model/dpa2.md (1)

23-23: Verify MPI rank limitation in tests and code.

The documentation states that "2 or more MPI ranks are not supported". Let's verify this limitation is properly enforced in the codebase.

✅ Verification successful

MPI rank limitation is properly enforced in JAX tests

The codebase correctly enforces the MPI rank limitation for the JAX backend through test files. Specifically in source/lmp/tests/test_lammps_dpa_jax.py, there is an explicit test skip with the message "MPI is not supported" for the JAX backend:

@pytest.mark.skip("MPI is not supported")
def test_pair_deepmd_mpi(balance_args: list):

This test skip enforces the documented limitation that "2 or more MPI ranks are not supported" when using the JAX backend. The implementation matches the documentation's requirements.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for MPI rank validation in JAX-related code and tests

# Look for MPI rank checks in JAX-related files
echo "Searching for MPI rank validation in source files..."
rg -l "MPI|mpi.*rank|n.*ranks" --type cpp --type python source/

# Look for test cases that verify this limitation
echo "Searching for related test cases..."
rg -l "test.*mpi.*rank|test.*jax.*mpi" --type python source/lmp/tests/

Length of output: 396


Script:

#!/bin/bash
# Let's try a different approach to find JAX and MPI related code

# First, let's find all Python and C++ files that mention JAX
echo "Finding JAX-related files..."
rg -l "jax" --type-add 'python:*.py' --type-add 'cpp:*.{cpp,h,hpp}' 

# Then search for MPI rank validation in those files
echo -e "\nSearching for MPI/rank related content in JAX files..."
rg "MPI|mpi.*rank|n.*ranks|world_size" --type-add 'python:*.py' --type-add 'cpp:*.{cpp,h,hpp}' -C 2

# Look specifically for test files
echo -e "\nSearching for test files mentioning both JAX and MPI..."
rg -l "jax.*test|test.*jax" --type-add 'python:*.py'

Length of output: 107291


Script:

#!/bin/bash
# Let's specifically look for JAX MPI rank validation in the codebase

# Search for JAX-related MPI rank checks in Python files
echo "Searching for JAX MPI rank validation..."
rg -C 3 "jax.*mpi|mpi.*jax|rank.*jax|jax.*rank" --type-add 'python:*.py'

# Look for JAX test files that handle MPI ranks
echo -e "\nSearching for JAX test files with MPI rank handling..."
rg -C 3 "test.*jax.*rank|test.*jax.*mpi" --type-add 'python:*.py'

# Look for any JAX-related MPI limitations or assertions
echo -e "\nSearching for JAX MPI limitations or assertions..."
rg -C 3 "assert.*rank|rank.*assert|mpi.*not.*support.*jax|jax.*not.*support.*mpi" --type-add 'python:*.py'

Length of output: 584


Script:

#!/bin/bash
# Let's look for evidence of JAX MPI limitations in LAMMPS integration code

# Search in LAMMPS-related files for JAX backend limitations
echo "Searching LAMMPS integration files for JAX limitations..."
rg -C 3 "jax.*backend|backend.*jax" source/lmp/

# Look for test files that skip MPI tests for JAX
echo -e "\nSearching for skipped MPI tests with JAX..."
rg -C 3 "skip.*mpi.*jax|skip.*jax.*mpi" source/lmp/tests/

# Look for atom_modify map requirements for JAX
echo -e "\nSearching for atom_modify map requirements..."
rg -C 3 "atom_modify.*map.*jax|jax.*atom_modify.*map" source/lmp/

Length of output: 512


Script:

#!/bin/bash
# Let's check the test files more thoroughly for JAX MPI limitations

# Look for skipped tests in JAX-related test files
echo "Checking JAX test files for MPI limitations..."
rg -C 5 "@pytest.mark.skip" source/lmp/tests/test_lammps_dpa_jax.py source/lmp/tests/test_lammps_jax.py

# Check for any MPI-related comments or docstrings in JAX tests
echo -e "\nChecking for MPI-related comments in JAX tests..."
rg -C 3 "# .*mpi|# .*MPI|\"\"\".*mpi|\"\"\".*MPI" source/lmp/tests/test_lammps_dpa_jax.py source/lmp/tests/test_lammps_jax.py

# Look for any JAX-specific test fixtures or setup
echo -e "\nChecking for JAX-specific test setups..."
rg -C 3 "def test.*jax" source/lmp/tests/test_lammps_dpa_jax.py source/lmp/tests/test_lammps_jax.py

Length of output: 2985

doc/backend.md (1)

34-36: Verify TensorFlow dependency information.

Since JAX backend requires TensorFlow C++ interface for inference, we should verify if the minimum TensorFlow version requirement is documented correctly.

🧰 Tools
🪛 LanguageTool

[typographical] ~35-~35: The conjunction “so that” does not require a comma.
Context: ... interface. The model is device-specific, so that the model generated on the GPU device c...

(SO_THAT_UNNECESSARY_COMMA)

source/api_cc/include/common.h (1)

16-16: LGTM! Verify the enum usage across the codebase.

The addition of JAX to the DPBackend enum is well-placed and maintains the Unknown value as the last enum option. This change aligns with the PR's objective of integrating JAX backend support.

Let's verify the usage of this enum across the codebase:

✅ Verification successful

Based on the search results, I can now make a definitive assessment of the review comment. The code shows that the DPBackend enum is properly used across the codebase:

  1. The enum is used in multiple files to determine the backend type based on model file extensions:

    • .pth -> PyTorch
    • .pb -> TensorFlow
    • .savedmodel -> JAX
    • Unknown for unrecognized formats
  2. The enum is used consistently in initialization code to create the appropriate backend implementation:

    • DeepPot.cc, DeepTensor.cc, and DataModifier.cc all use the enum for backend selection
    • Each backend has proper error handling when not built/supported
    • The Unknown value is properly handled as the last enum option
  3. The order of enum values is preserved across the codebase, maintaining ABI compatibility:

    enum DPBackend { TensorFlow, PyTorch, Paddle, JAX, Unknown };
  4. The addition of JAX follows the existing pattern and is properly integrated into the backend selection logic.

The enum modification is well-implemented and properly integrated across the codebase.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for DPBackend enum usage and ensure proper handling of JAX backend

# Test 1: Find all switch/if statements handling DPBackend to ensure JAX is properly handled
echo "Checking DPBackend handling..."
rg -A 5 "DPBackend" 

# Test 2: Look for potential string conversions of backend types
echo "Checking backend string conversions..."
rg -A 5 "Unknown|TensorFlow|PyTorch|Paddle|JAX"

# Test 3: Find potential serialization/deserialization of backend types
echo "Checking serialization code..."
ast-grep --pattern 'switch ($backend) { $$$ }'

Length of output: 282653

source/lib/include/neighbor_list.h (1)

Line range hint 47-107: Verify the mapping implementation across the codebase.

Let's ensure the mapping functionality is properly integrated:

✅ Verification successful

Memory management and usage of mapping is properly implemented

The verification shows that the mapping implementation is correctly integrated:

  1. Memory management is properly handled:

    • In LAMMPS integration (fix_dplr.cpp, pair_deepmd.cpp), mapping is managed via std::vector
    • Device memory allocation/deallocation is properly paired in tests using malloc_device_memory/delete_device_memory
  2. Usage is consistent across the codebase:

    • Only set when single process and atom mapping is enabled (comm->nprocs == 1 && atom->map_style != Atom::MAP_NONE)
    • Properly initialized with -1 default values
    • Correctly maps atom tags using LAMMPS' atom->map() function
    • C API properly exposes the mapping functionality

The implementation follows good practices for memory safety and has consistent usage patterns across the codebase.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check usage and implementation of the new mapping functionality

# Check for mapping usage in implementation files
echo "Checking mapping usage in implementation files..."
rg -l "set_mapping|->mapping" --type cpp

# Check for potential memory management of mapping
echo "Checking memory management..."
rg -B 3 -A 3 "delete.*mapping|free.*mapping" --type cpp

# Look for LAMMPS integration points
echo "Checking LAMMPS integration..."
rg -l "LAMMPS.*mapping|mapping.*LAMMPS" --type cpp

Length of output: 3124


Script:

#!/bin/bash
# Let's check the actual usage of mapping in the identified files

echo "Checking mapping usage in fix_dplr.cpp..."
rg -C 5 "mapping" source/lmp/fix_dplr.cpp

echo -e "\nChecking mapping usage in pair_deepmd.cpp..."
rg -C 5 "mapping" source/lmp/pair_deepmd.cpp

echo -e "\nChecking mapping usage in c_api.cc..."
rg -C 5 "mapping" source/api_c/src/c_api.cc

echo -e "\nChecking mapping usage in deepmd.hpp..."
rg -C 5 "mapping" source/api_c/include/deepmd.hpp

# Let's also check for any initialization or allocation of mapping
echo -e "\nChecking mapping initialization..."
rg -C 5 "mapping\s*=\s*new|malloc.*mapping" --type cpp

Length of output: 8714

source/api_c/include/c_api.h (3)

15-15: API version increment is appropriate.

The increment from version 23 to 24 correctly reflects the addition of new functionality (DP_NlistSetMapping).


Line range hint 34-71: Documentation improvements enhance API clarity.

The expanded documentation for DP_NewNlist_comm provides clear and comprehensive parameter descriptions, improving API usability.


81-89: Verify implementation consistency.

Let's verify that the implementation in c_api.cc matches this interface declaration.

✅ Verification successful

Implementation is consistent with interface

The implementation is correctly aligned with the interface:

  • The DP_NlistSetMapping function in c_api.cc correctly calls the set_mapping method of the InputNlist struct
  • The InputNlist struct in neighbor_list.h has the mapping member variable and set_mapping method properly defined
  • The mapping functionality is properly documented in both interface and implementation
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check implementation of DP_NlistSetMapping in c_api.cc
# and related changes in neighbor_list.h

# Test 1: Verify implementation exists in c_api.cc
rg -A 5 "DP_NlistSetMapping" "source/api_c/src/c_api.cc"

# Test 2: Verify InputNlist structure changes
rg -A 5 "class\s+InputNlist" "source/lib/include/neighbor_list.h"

Length of output: 285


Script:

#!/bin/bash
# Let's verify the InputNlist class and its set_mapping method

# Check for InputNlist class definition with more context
rg -B 2 -A 10 "class\s+InputNlist" source/lib/include/neighbor_list.h

# Check for set_mapping method declaration
rg "set_mapping" source/lib/include/neighbor_list.h

# Look for any tests related to mapping functionality
rg -l "set_mapping" source/lib/tests/

Length of output: 220


Script:

#!/bin/bash
# Let's try a different approach to find the InputNlist class definition
# using ast-grep which might handle the class structure better

# Search for InputNlist class definition
ast-grep --pattern 'class InputNlist {
  $$$
}'

# Also check if there are any other files that might contain the implementation
fd -e h -e cc -e cpp | xargs rg -l "InputNlist"

Length of output: 2729


Script:

#!/bin/bash
# Let's examine the neighbor_list.h file more carefully
# and the test files to verify the implementation

# Check neighbor_list.h content
cat source/lib/include/neighbor_list.h

# Check the test implementation for neighbor list
cat source/lib/tests/test_neighbor_list.cc

Length of output: 16925

source/api_c/include/deepmd.hpp (1)

618-622: Verify LAMMPS integration for the mapping functionality.

The method is well-integrated into the InputNlist structure and provides the necessary functionality for LAMMPS integration. To ensure correctness:

  1. Verify that LAMMPS correctly passes the atom mapping array
  2. Confirm the mapping is properly used in molecular dynamics simulations
  3. Test with both real and ghost atoms to ensure proper handling

Let's verify the usage of this new mapping functionality:

✅ Verification successful

The mapping functionality is correctly integrated with LAMMPS

The verification shows proper integration:

  • The set_mapping is correctly used in both pair_deepmd.cpp and fix_dplr.cpp
  • The mapping is only applied when running on a single processor (comm->nprocs == 1) and when LAMMPS atom mapping is enabled (atom->map_style != Atom::MAP_NONE)
  • The implementation is consistent across both regular and extended neighbor lists
  • Documentation confirms the intended use for atom type mapping in LAMMPS integration
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for LAMMPS integration code that uses the new mapping functionality

# Test: Search for LAMMPS code that calls set_mapping
rg -t cpp "set_mapping" 

# Test: Search for tests that verify the mapping functionality
rg -t cpp "test.*mapping"

Length of output: 534


Script:

#!/bin/bash
# Let's verify the integration more thoroughly

# Check the context of LAMMPS usage
rg -A 5 -B 5 "set_mapping" source/lmp/pair_deepmd.cpp
rg -A 5 -B 5 "set_mapping" source/lmp/fix_dplr.cpp

# Check if there are any tests for the mapping functionality
fd "test.*\.cpp" | xargs rg "mapping.*vec"

# Look for any documentation about mapping
fd "README|\.md" | xargs rg -i "mapping"

Length of output: 2787

source/api_cc/include/DeepPotJAX.h (2)

1-249: Comprehensive implementation of the DeepPotJAX class.

The DeepPotJAX class provides a well-structured and comprehensive implementation for integrating the JAX backend into the DeePMD-kit framework. It follows good practices by:

  • Inheriting from the DeepPotBase class to ensure a consistent interface.
  • Providing multiple constructors and an init method for flexible initialization.
  • Implementing accessor methods for retrieving model properties.
  • Overloading computew methods to support different data types and scenarios.
  • Utilizing assertions for error handling and ensuring proper initialization.
  • Encapsulating TensorFlow C API objects for graph and session management.
  • Integrating with the NeighborListData class for handling neighbor list data.
  • Providing a template compute method for evaluating energy and forces.

The class is thoroughly documented with Doxygen-style comments, enhancing readability and maintainability.


4-5: Verify the compatibility of TensorFlow C API with JAX.

Please ensure that the TensorFlow C API (tensorflow/c/c_api.h and tensorflow/c/eager/c_api.h) is compatible with the JAX backend and does not introduce any conflicts or limitations. It's important to thoroughly test the integration to ensure smooth functionality.

✅ Verification successful

Let me gather more information about the JAX implementation and its relationship with TensorFlow C API.


Based on the code analysis, I can now provide a final response regarding the compatibility of TensorFlow C API with JAX:

The TensorFlow C API is compatible and properly integrated with JAX through JAX2TF conversion

The codebase shows a well-designed integration where:

  • JAX models are converted to TensorFlow SavedModel format (JAX2TF) which allows using TensorFlow C API
  • The DeepPotJAX class properly initializes and uses TensorFlow C API to load and execute the converted JAX models
  • The implementation includes proper error handling and memory management for the TF C API resources
  • The code has been tested as evidenced by the test files (test_deeppot_jax.cc)
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash

# Search for potential compatibility issues or limitations in the codebase
rg -i "compatibility|limitation|issue|conflict" -C 5 --iglob "*.cc" --iglob "*.h"

# Check if there are any TODOs, FIXMEs, or HACKs related to JAX and TensorFlow C API integration
rg -i "TODO|FIXME|HACK" -C 5 --iglob "*.cc" --iglob "*.h" | rg -i "jax|tensorflow"

Length of output: 5496


Script:

#!/bin/bash
# Check for JAX-related implementations and backend configurations
rg -i "jax" -C 5 --iglob "*.cc" --iglob "*.h"

# Look for backend-related code and configurations
rg -i "backend" -C 5 --iglob "*.cc" --iglob "*.h"

# Search for DeepPotJAX implementation details
fd "DeepPotJAX" --exec cat {}

Length of output: 64163

source/api_cc/tests/test_deeppot_jax.cc (9)

1-15: LGTM!

The header includes and namespace usages look good. The necessary headers for Google Test, DeepPot, and other utility functions are included correctly.


16-90: LGTM!

The TestInferDeepPotAJAX class is set up correctly as a parameterized test fixture. The member variables for coordinates, atom types, box dimensions, and expected values are initialized appropriately in the SetUp method. The SetUp method also verifies the sizes of the expected force and virial vectors relative to the number of atoms.


97-157: LGTM!

The cpu_lmp_nlist test case looks good:

  • It correctly sets up the input data and neighbor list.
  • It calls the compute method of DeepPot with the appropriate arguments.
  • It folds back the computed forces using the _fold_back helper function.
  • It verifies the sizes of the computed force and virial vectors.
  • It checks the computed energy, forces, and virial against the expected values using appropriate tolerances.
  • It repeats the computation with a different output index to ensure consistency.

159-240: LGTM!

The cpu_lmp_nlist_atomic test case looks good:

  • It follows a similar structure to the cpu_lmp_nlist test case.
  • It additionally computes and verifies atomic energies and virials.
  • It checks the computed atomic energies and virials against the expected values using appropriate tolerances.
  • It repeats the computation with a different output index to ensure consistency.

242-302: LGTM!

The cpu_lmp_nlist_2rc test case looks good:

  • It sets up the neighbor list with a cutoff of 2*rc.
  • It calls the compute method of DeepPot with the appropriate arguments.
  • It verifies the sizes of the computed force and virial vectors.
  • It checks the computed energy, forces, and virial against the expected values using appropriate tolerances.
  • It repeats the computation with a different output index to ensure consistency.

304-364: LGTM!

The cpu_lmp_nlist_type_sel test case looks good:

  • It adds virtual atoms of a different type to the input data.
  • It sets up the neighbor list with the updated atom coordinates and types.
  • It calls the compute method of DeepPot with the appropriate arguments.
  • It verifies the sizes of the computed force and virial vectors.
  • It checks the computed energy, forces, and virial against the expected values (including the virtual atoms) using appropriate tolerances.

366-427: LGTM!

The cpu_lmp_nlist_type_sel_atomic test case looks good:

  • It follows a similar structure to the cpu_lmp_nlist_type_sel test case.
  • It additionally computes and verifies atomic energies and virials.
  • It checks the computed atomic energies and virials against the expected values using appropriate tolerances.

429-432: LGTM!

The print_summary test case looks good. It calls the print_summary method of DeepPot with an empty string argument.


434-439: LGTM!

The get_type_map test case looks good:

  • It calls the get_type_map method of DeepPot to retrieve the type map.
  • It verifies that the retrieved type map matches the expected value of "O H".
source/lmp/fix_dplr.cpp (2)

442-449: Verify mapping functionality in multi-processor runs

The current implementation initializes and populates mapping_vec only when running on a single processor (comm->nprocs == 1). If the mapping is required for correct behavior in multi-processor runs, omitting this initialization may lead to issues. Please verify whether the mapping should be handled when running on multiple processors and update the code accordingly if necessary.


482-484: Confirm the necessity of mapping in multi-processor scenarios

Similarly, the neighbor list mapping is set only when comm->nprocs == 1. If the mapping is essential for computations across multiple processors, this condition might prevent proper functionality in parallel executions. Ensure that the mapping is appropriately managed in multi-processor environments to avoid potential inconsistencies.

source/api_cc/src/DeepPotJAX.cc (16)

1-2: LGTM!

The SPDX license identifier is correctly specified.


4-18: LGTM!

The necessary header files are included correctly.


20-25: LGTM!

The check_status function is implemented correctly to check the TensorFlow operation status and throw an exception with the error message if an error occurs.


47-61: LGTM!

The get_data_tensor_type functions are implemented correctly to return the appropriate TensorFlow data type based on the input data type.


168-176: LGTM!

The create_tensor function is implemented correctly to create a TensorFlow tensor from the input data and shape.


178-191: LGTM!

The add_input function is implemented correctly to create a TensorFlow tensor from the input data and shape, and add it as an input to the TensorFlow operation.


193-207: LGTM!

The tensor_to_vector function is implemented correctly to copy the data from a TensorFlow tensor to a vector.


209-215: LGTM!

The constructors for the DeepPotJAX class are implemented correctly.


216-279: LGTM!

The init function is implemented correctly to initialize the DeepPotJAX class by loading the TensorFlow saved model, creating a TensorFlow session, and retrieving the necessary parameters.


281-293: LGTM!

The destructor for the DeepPotJAX class is implemented correctly to clean up the TensorFlow resources if the class was initialized.


295-454: LGTM!

The compute function is implemented correctly to perform the computation using the TensorFlow saved model. It handles the input and output tensors, executes the TensorFlow function, and copies the results back to the output vectors.


456-470: LGTM!

The template specialization for the compute function with double data type is implemented correctly.


472-486: LGTM!

The template specialization for the compute function with float data type is implemented correctly.


488-490: LGTM!

The get_type_map function is implemented correctly to retrieve the type map.


492-552: Verify the correctness of throwing exceptions for unimplemented methods.

The computew methods with different parameter types are currently throwing exceptions to indicate that they are not implemented. Ensure that this behavior is intended and aligns with the overall design and requirements of the codebase.


553-580: Verify the correctness of throwing exceptions for unimplemented methods.

The computew_mixed_type methods with different parameter types are currently throwing exceptions to indicate that they are not implemented. Ensure that this behavior is intended and aligns with the overall design and requirements of the codebase.

source/lmp/tests/test_lammps_jax.py (27)

1-11: LGTM!

The import statements look good and follow the standard Python conventions.


35-68: LGTM!

The definition of the expected atomic energies (expected_ae) and the calculation of the expected total energy (expected_e) look correct.


70-208: LGTM!

The definitions of the expected forces (expected_f and expected_f2) and virials (expected_v and expected_v2) look correct. The reshaping of the arrays is also done correctly.


210-223: LGTM!

The definitions of the simulation box (box), atomic coordinates (coord), and atom types (type_OH and type_HO) look correct.


230-238: LGTM!

The setup_module function correctly sets up the necessary data files for the tests using the write_lmp_data function.


241-244: LGTM!

The teardown_module function correctly removes the data files created during the setup phase.


279-283: LGTM!

The lammps fixture correctly creates a LAMMPS instance using the _lammps function and yields it for use in tests. The fixture also closes the LAMMPS instance after the test.


286-290: LGTM!

The lammps_type_map fixture correctly creates a LAMMPS instance with a type map using the _lammps function and yields it for use in tests. The fixture also closes the LAMMPS instance after the test.


293-297: LGTM!

The lammps_real fixture correctly creates a LAMMPS instance with real units using the _lammps function and yields it for use in tests. The fixture also closes the LAMMPS instance after the test.


300-304: LGTM!

The lammps_si fixture correctly creates a LAMMPS instance with SI units using the _lammps function and yields it for use in tests. The fixture also closes the LAMMPS instance after the test.


307-317: LGTM!

The test_pair_deepmd function correctly tests the DeepMD pair style by setting up the pair style, running the simulation, and asserting the potential energy and forces against the expected values.


319-340: LGTM!

The test_pair_deepmd_virial function correctly tests the DeepMD pair style with virial calculations by setting up the pair style, computing the virial tensor, running the simulation, and asserting the potential energy, forces, and virial tensor components against the expected values.


342-366: LGTM!

The test_pair_deepmd_model_devi function correctly tests the DeepMD pair style with model deviation output by setting up the pair style with two models, running the simulation, asserting the potential energy and forces against the expected values, and verifying the model deviation output against the expected values.


368-403: LGTM!

The test_pair_deepmd_model_devi_virial function correctly tests the DeepMD pair style with model deviation output and virial calculations by setting up the pair style with two models, computing the virial tensor, running the simulation, asserting the potential energy, forces, and virial tensor components against the expected values, and verifying the model deviation output against the expected values.


406-433: LGTM!

The test_pair_deepmd_model_devi_atomic_relative function correctly tests the DeepMD pair style with model deviation output and relative atomic deviations by setting up the pair style with two models and the relative parameter, running the simulation, asserting the potential energy and forces against the expected values, and verifying the model deviation output against the expected values calculated with the relative deviations.


435-465: LGTM!

The test_pair_deepmd_model_devi_atomic_relative_v function correctly tests the DeepMD pair style with model deviation output and relative virial deviations by setting up the pair style with two models and the relative_v parameter, running the simulation, asserting the potential energy and forces against the expected values, and verifying the model deviation output against the expected values calculated with the relative virial deviations.


468-477: LGTM!

The test_pair_deepmd_type_map function correctly tests the DeepMD pair style with a type map by setting up the pair style with a type map, running the simulation, and asserting the potential energy and forces against the expected values.


480-491: LGTM!

The test_pair_deepmd_real function correctly tests the DeepMD pair style with real units by setting up the pair style, running the simulation, and asserting the potential energy and forces against the expected values converted to real units.


494-519: LGTM!

The test_pair_deepmd_virial_real function correctly tests the DeepMD pair style with virial calculations and real units by setting up the pair style, computing the virial tensor, running the simulation, and asserting the potential energy, forces, and virial tensor components against the expected values converted to real units.


521-548: LGTM!

The test_pair_deepmd_model_devi_real function correctly tests the DeepMD pair style with model deviation output and real units by setting up the pair style with two models, running the simulation, asserting the potential energy and forces against the expected values converted to real units, and verifying the model deviation output against the expected values converted to real units.


551-592: LGTM!

The test_pair_deepmd_model_devi_virial_real function correctly tests the DeepMD pair style with model deviation output, virial calculations, and real units by setting up the pair style with two models, computing the virial tensor, running the simulation, asserting the potential energy, forces, and virial tensor components against the expected values converted to real units, and verifying the model deviation output against the expected values converted to real units.


595-625: LGTM!

The test_pair_deepmd_model_devi_atomic_relative_real function correctly tests the DeepMD pair style with model deviation output, relative atomic deviations, and real units by setting up the pair style with two models and the relative parameter converted to real units, running the simulation, asserting the potential energy and forces against the expected values converted to real units, and verifying the model deviation output against the expected values calculated with the relative deviations and converted to real units.


628-662: LGTM!

The test_pair_deepmd_model_devi_atomic_relative_v_real function correctly tests the DeepMD pair style with model deviation output, relative virial deviations, and real units by setting up the pair style with two models and the relative_v parameter converted to real units, running the simulation, asserting the potential energy and forces against the expected values converted to real units, and verifying the model deviation output against the expected values calculated with the relative virial deviations and converted to real units.


665-674: LGTM!

The test_pair_deepmd_si function correctly tests the DeepMD pair style with SI units by setting up the pair style, running the simulation, and asserting the potential energy and forces against the expected values converted to SI units.


677-686: LGTM!

The test_pair_deepmd_mpi function is correctly skipped if MPI or mpi4py is not installed on the system. The function is parametrized with different load balancing arguments.


28-33: Verify the existence and accessibility of the referenced files.

Similar to the previous comment, ensure that the file paths system_file, data_file, data_file_si, and data_type_map_file point to existing and accessible files in the repository. If not, update the paths accordingly.

#!/bin/bash 
# Description: Verify the existence of the referenced files.

# Test: Check if the files exist. Expect: No output if files exist, else list missing files.
fd -e lmp . | xargs -I {} sh -c 'test -f "$1" || echo "$1 not found"' -- {}

22-27: Verify the existence and accessibility of the referenced files.

Ensure that the file paths pbtxt_file2, pb_file, and pb_file2 point to existing and accessible files in the repository. If not, update the paths accordingly.

✅ Verification successful

Both referenced files exist and are accessible

The verification confirms that both files exist and are accessible in the repository:

  • deeppot-1.pbtxt exists at source/tests/infer/deeppot-1.pbtxt
  • deeppot_sea.savedmodel exists at source/tests/infer/deeppot_sea.savedmodel (as a directory containing model files)

The paths in the code are correctly referencing these files relative to the test file's location.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify the existence of the referenced files.

# Test: Check if the files exist. Expect: No output if files exist, else list missing files.
fd -e pbtxt -e pb . | xargs -I {} sh -c 'test -f "$1" || echo "$1 not found"' -- {}

Length of output: 161


Script:

#!/bin/bash
# Let's try a different approach to verify the file paths

# First, let's check for the pbtxt file
ls -l source/tests/infer/deeppot-1.pbtxt 2>/dev/null || echo "deeppot-1.pbtxt not found"

# Then check for the savedmodel file
ls -l source/tests/infer/deeppot_sea.savedmodel 2>/dev/null || echo "deeppot_sea.savedmodel not found"

# Let's also try to find these files anywhere in the repository
echo "Searching for files in the repository:"
find . -name "deeppot-1.pbtxt" -o -name "deeppot_sea.savedmodel"

Length of output: 612

source/lmp/tests/test_lammps_dpa_jax.py (22)

1-11: LGTM!

The imports look good and cover the necessary dependencies for the test suite.


36-138: LGTM!

The expected values for energy, forces, and virial stress are correctly defined as NumPy arrays.


210-223: LGTM!

The box dimensions, coordinates, and atom types are correctly defined as NumPy arrays.


230-238: LGTM!

The setup_module function correctly sets up the necessary data files for the tests using the write_lmp_data function.


241-244: LGTM!

The teardown_module function correctly removes the data files after the tests are completed.


281-306: LGTM!

The fixture functions correctly create and close LAMMPS instances for different test scenarios (default, type map, real units, SI units).


309-319: LGTM!

The test_pair_deepmd function correctly tests the DeePMD pair style by setting up the LAMMPS instance, running the simulation, and comparing the potential energy and forces against the expected values.


321-342: LGTM!

The test_pair_deepmd_virial function correctly tests the computation of virial stress using the DeePMD pair style. It sets up the necessary compute and variables, runs the simulation, and compares the virial stress values against the expected values.


344-368: LGTM!

The test_pair_deepmd_model_devi function correctly tests the model deviation output of the DeePMD pair style. It sets up the pair style with two models, runs the simulation, and compares the model deviation values against the expected values.


370-406: LGTM!

The test_pair_deepmd_model_devi_virial function correctly tests the model deviation output along with the virial stress computation. It sets up the necessary compute, variables, and pair style, runs the simulation, and compares the model deviation and virial stress values against the expected values.


408-435: LGTM!

The test_pair_deepmd_model_devi_atomic_relative function correctly tests the model deviation output with the atomic and relative keywords. It sets up the pair style with the appropriate keywords, runs the simulation, and compares the model deviation values against the expected values calculated with the relative factor.


437-467: LGTM!

The test_pair_deepmd_model_devi_atomic_relative_v function correctly tests the model deviation output with the atomic and relative_v keywords. It sets up the pair style with the appropriate keywords, runs the simulation, and compares the model deviation values against the expected values calculated with the relative factor for the virial stress.


470-480: LGTM!

The test_pair_deepmd_type_map function correctly tests the DeePMD pair style with a type map. It sets up the LAMMPS instance with the type map data file, runs the simulation, and compares the potential energy and forces against the expected values.


482-493: LGTM!

The test_pair_deepmd_real function correctly tests the DeePMD pair style with real units. It sets up the LAMMPS instance with real units, runs the simulation, and compares the potential energy and forces against the expected values converted to real units.


496-521: LGTM!

The test_pair_deepmd_virial_real function correctly tests the computation of virial stress using the DeePMD pair style with real units. It sets up the necessary compute, variables, and pair style, runs the simulation, and compares the virial stress values against the expected values converted to real units.


523-551: LGTM!

The test_pair_deepmd_model_devi_real function correctly tests the model deviation output of the DeePMD pair style with real units. It sets up the pair style with two models, runs the simulation, and compares the model deviation values against the expected values converted to real units.


553-594: LGTM!

The test_pair_deepmd_model_devi_virial_real function correctly tests the model deviation output along with the virial stress computation using real units. It sets up the necessary compute, variables, and pair style, runs the simulation, and compares the model deviation and virial stress values against the expected values converted to real units.


597-627: LGTM!

The test_pair_deepmd_model_devi_atomic_relative_real function correctly tests the model deviation output with the atomic and relative keywords using real units. It sets up the pair style with the appropriate keywords, runs the simulation, and compares the model deviation values against the expected values calculated with the relative factor converted to real units.


630-664: LGTM!

The test_pair_deepmd_model_devi_atomic_relative_v_real function correctly tests the model deviation output with the atomic and relative_v keywords using real units. It sets up the pair style with the appropriate keywords, runs the simulation, and compares the model deviation values against the expected values calculated with the relative factor for the virial stress converted to real units.


667-676: LGTM!

The test_pair_deepmd_si function correctly tests the DeePMD pair style with SI units. It sets up the LAMMPS instance with SI units, runs the simulation, and compares the potential energy and forces against the expected values converted to SI units.


225-227: Verify the command to convert pbtxt to pb format.

The command to convert the pbtxt file to pb format looks correct. However, please ensure that the deepmd package is installed and the convert-from command is available.

#!/bin/bash
# Description: Verify the deepmd package is installed and the convert-from command is available.

# Test: Check if the deepmd package is installed. Expect: No output if installed.
python -c "import deepmd"

# Test: Check if the convert-from command is available. Expect: Usage information.
python -m deepmd convert-from --help

679-726: Verify the MPI test setup and execution.

The MPI test test_pair_deepmd_mpi is currently skipped due to the @pytest.mark.skip("MPI is not supported") decorator. If MPI support is intended to be tested, please ensure that the necessary dependencies (MPI and mpi4py) are installed and remove the skip decorator.

Additionally, verify that the run_mpi_pair_deepmd.py script exists at the specified location and contains the correct code to run the MPI test.

#!/bin/bash
# Description: Verify the MPI test setup and execution.

# Test: Check if the run_mpi_pair_deepmd.py script exists. Expect: File path.
fd run_mpi_pair_deepmd.py tests

# Test: Check if the script runs without errors. Expect: No output.
mpirun -n 2 python tests/run_mpi_pair_deepmd.py
source/api_cc/src/DeepPot.cc (3)

10-10: Inclusion of DeepPotJAX.h is appropriate

Including "DeepPotJAX.h" under #ifdef BUILD_TENSORFLOW ensures that the JAX backend is only included when TensorFlow support is available.


45-47: Correctly identify and set JAX backend based on model file extension

The added condition checks for the .savedmodel extension to determine if the JAX backend should be used. This logic correctly allows the code to recognize and initialize JAX models.


65-72: Initialization logic for JAX backend is well-implemented

The code properly initializes the JAX backend when BUILD_TENSORFLOW is defined by creating a DeepPotJAX instance. If TensorFlow support is not available, it throws a clear exception indicating that TensorFlow is required to load JAX2TF SavedModels.

source/lmp/pair_deepmd.cpp (1)

526-530: Verify if mapping initialization is required for multiple processes

The mapping vector mapping_vec is initialized only when comm->nprocs == 1 and atom->map_style != Atom::MAP_NONE. If the mapping is also necessary when running with multiple processes, consider adjusting the condition to ensure mapping_vec is properly initialized in all required scenarios.

To verify whether mapping should be initialized for multiple processes, search the codebase for other usages of atom->map in parallel contexts:

Comment on lines +18 to +19
source/tests/infer/deeppolar_new.pbtxt|
source/tests/infer/deeppot_dpa.savedmodel/saved_model.pb

💡 Codebase verification

Additional large model files need exclusion

The scan revealed several large model files close to the size limit that are not currently excluded:

  • source/tests/infer/deepdipole_new.pbtxt (916KB)
  • source/tests/infer/deepdipole_fake.pbtxt (1002KB)
  • source/tests/infer/deeppot_sea.savedmodel/saved_model.pb (630KB)

These files should also be added to the exclusion list in .pre-commit-config.yaml to maintain consistency with the current approach of excluding large model files.

🔗 Analysis chain

LGTM! Consider checking for additional model files.

The exclusion patterns for the new model files are consistent with the existing pattern. However, since this PR is part of the JAX backend integration, we should verify if there are other similar model files that might need exclusion.


Let me try a different approach to find potential large model files by using fd and checking file sizes directly.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for other potential large model files that might need exclusion
# Look for similar files in the test directories that might exceed the size limit

echo "Searching for .pb and .pbtxt files in test directories..."
find source/tests -type f \( -name "*.pb" -o -name "*.pbtxt" \) -exec ls -lh {} \; | \
  awk '{if($5+0 > 1024) print "Warning: Large file found:", $9, "Size:", $5}'

Length of output: 292


Script:

#!/bin/bash
# Search for model files and their sizes in the source/tests directory
echo "Searching for model files in tests directory..."
fd -t f "\.(pb|pbtxt|savedmodel)$" source/tests -x ls -lh {}

# Also check the pre-commit config for current exclusions
echo -e "\nCurrent exclusions in pre-commit config:"
grep -A 10 "check-added-large-files" .pre-commit-config.yaml

Length of output: 2875

Comment on lines +104 to +107
/**
* @brief Set mapping for this neighbor list.
*/
void set_mapping(int* mapping_) { mapping = mapping_; };

🛠️ Refactor suggestion

Improve setter method safety and documentation.

The current implementation could benefit from:

  1. Input validation to prevent null pointer assignment
  2. Documentation about:
    • Pointer ownership expectations
    • Required size of the mapping array
    • Thread safety considerations

Consider this safer implementation:

   /**
    * @brief Set mapping for this neighbor list.
+   * @param mapping_ Pointer to an array of size nall containing indices mapping all atoms to real atoms.
+   *                 The pointer must remain valid for the lifetime of this object.
+   * @throw std::invalid_argument if mapping_ is null
    */
-  void set_mapping(int* mapping_) { mapping = mapping_; };
+  void set_mapping(int* mapping_) {
+    if (!mapping_) {
+      throw std::invalid_argument("mapping pointer cannot be null");
+    }
+    mapping = mapping_;
+  }

@@ -375,7 +377,7 @@ One may add the following CMake variables to `cmake` using the [`-D <var>=<value

**Type**: `BOOL` (`ON`/`OFF`), Default: `OFF`

-{{ tensorflow_icon }} Whether building the TensorFlow backend.
+{{ tensorflow_icon }} {{ jax_icon }} Whether building the TensorFlow backend and the JAX backend.

🛠️ Refactor suggestion

Consider separating backend enable flags.

The current ENABLE_TENSORFLOW variable name doesn't accurately reflect that it enables both TensorFlow and JAX backends. This could be confusing for users who only want to use JAX.

Consider either:

  1. Renaming to something more inclusive like ENABLE_TF_JAX_BACKEND
  2. Creating a separate ENABLE_JAX flag

Comment on lines +46 to +48
void DP_NlistSetMapping(DP_Nlist* nl, int* mapping) {
nl->nl.set_mapping(mapping);
}

⚠️ Potential issue

Add safety checks and documentation for the new mapping function.

The implementation needs several improvements for robustness and clarity:

  1. Add null pointer validation for both parameters
  2. Document the requirements for the mapping array (size, lifetime, ownership)
  3. Consider adding error handling similar to other functions in the codebase

Consider applying this diff:

+/**
+ * @brief Set the mapping array for the neighbor list
+ * @param[in] nl The neighbor list object
+ * @param[in] mapping The mapping array that will be used by the neighbor list.
+ *                    The array must remain valid for the lifetime of the neighbor list
+ *                    or until a new mapping is set.
+ * @note The function does not take ownership of the mapping array.
+ */
 void DP_NlistSetMapping(DP_Nlist* nl, int* mapping) {
+  if (nl == nullptr || mapping == nullptr) {
+    return;
+  }
   nl->nl.set_mapping(mapping);
 }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
void DP_NlistSetMapping(DP_Nlist* nl, int* mapping) {
  nl->nl.set_mapping(mapping);
}
/**
 * @brief Set the mapping array for the neighbor list
 * @param[in] nl The neighbor list object
 * @param[in] mapping The mapping array that will be used by the neighbor list.
 *                    The array must remain valid for the lifetime of the neighbor list
 *                    or until a new mapping is set.
 * @note The function does not take ownership of the mapping array.
 */
void DP_NlistSetMapping(DP_Nlist* nl, int* mapping) {
  if (nl == nullptr || mapping == nullptr) {
    return;
  }
  nl->nl.set_mapping(mapping);
}
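As a self-contained illustration of this guarded-wrapper pattern, the sketch below uses mock stand-in types (MockDPNlist, MockInputNlist), not the real deepmd-kit DP_Nlist API:

```cpp
#include <cstddef>

// Mock stand-ins for DP_Nlist / InputNlist; only the mapping member is modeled.
struct MockInputNlist {
  int* mapping = nullptr;
  void set_mapping(int* mapping_) { mapping = mapping_; }
};

struct MockDPNlist {
  MockInputNlist nl;
};

// Guarded wrapper: a null handle or null mapping becomes a no-op instead of a
// crash, mirroring the null checks proposed above. The caller keeps ownership
// of the mapping array and must keep it alive while the list uses it.
void MockNlistSetMapping(MockDPNlist* nl, int* mapping) {
  if (nl == nullptr || mapping == nullptr) {
    return;
  }
  nl->nl.set_mapping(mapping);
}
```

A caller would allocate the mapping array itself (size nall, with ghost-atom entries pointing back to their real-atom indices) and pass it in once per neighbor-list build; nothing in the wrapper frees it.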

Comment on lines +225 to +227
sp.check_output(
f"{sys.executable} -m deepmd convert-from pbtxt -i {pbtxt_file2.resolve()} -o {pb_file2.resolve()}".split()
)

💡 Codebase verification

Not all files are present at the expected locations

The code expects the files at:

  • source/tests/infer/deeppot-1.pbtxt (pbtxt_file2)
  • source/lmp/tests/graph2.pb (pb_file2)

While deeppot-1.pbtxt exists at the correct location, graph2.pb is missing. This file is expected to be generated by the convert-from command but its parent directory must exist first.

🔗 Analysis chain

Verify the existence and accessibility of the pbtxt_file2 and pb_file2.

Ensure that the pbtxt_file2 and pb_file2 files exist and are accessible. If not, the convert-from command will fail.


Let me search for the actual paths of these files in the codebase.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify the existence of the `pbtxt_file2` and `pb_file2` files.

# Test: Check if the files exist. Expect: No output if files exist, else list missing files.
test -f "tests/infer/deeppot-1.pbtxt" || echo "tests/infer/deeppot-1.pbtxt not found"
test -f "source/lmp/tests/graph2.pb" || echo "source/lmp/tests/graph2.pb not found" 

Length of output: 242


Script:

#!/bin/bash
# Search for pbtxt_file2 and pb_file2 variable assignments
rg -A 3 "pbtxt_file2|pb_file2" source/lmp/tests/test_lammps_jax.py

# Find any .pbtxt files in the repository
fd -e pbtxt

# Find any .pb files in the repository
fd -e pb

Length of output: 3216

Comment on lines +22 to +34
pbtxt_file2 = (
Path(__file__).parent.parent.parent / "tests" / "infer" / "deeppot-1.pbtxt"
)
pb_file = (
Path(__file__).parent.parent.parent / "tests" / "infer" / "deeppot_dpa.savedmodel"
)
pb_file2 = Path(__file__).parent / "graph2.pb"
system_file = Path(__file__).parent.parent.parent / "tests"
data_file = Path(__file__).parent / "data.lmp"
data_file_si = Path(__file__).parent / "data.si"
data_type_map_file = Path(__file__).parent / "data_type_map.lmp"
md_file = Path(__file__).parent / "md.out"


💡 Codebase verification

File paths need to be corrected

The verification reveals several issues with the file paths:

  • While pbtxt_file2 (deeppot-1.pbtxt) and pb_file (deeppot_dpa.savedmodel) exist in the correct location at source/tests/infer/, the other files are missing:
    • pb_file2 (graph2.pb)
    • data_file (data.lmp)
    • data_file_si (data.si)
    • data_type_map_file (data_type_map.lmp)
    • md_file (md.out)

These test data files need to be created in the source/lmp/tests/ directory or their paths need to be updated to point to existing files.

🔗 Analysis chain

Verify the file paths are correct and exist.

Please ensure that the specified file paths for pbtxt_file2, pb_file, pb_file2, system_file, data_file, data_file_si, data_type_map_file, and md_file are correct and the files exist at those locations.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify the specified file paths exist.

# Test: Check if the files exist. Expect: No output if files exist.
fd -e pbtxt -e pb -e lmp -e out . tests/infer

Length of output: 445


Script:

#!/bin/bash
# Let's verify the existence of files using more specific paths based on the code

# Check for pbtxt and savedmodel files in tests/infer
ls -la source/tests/infer/deeppot-1.pbtxt source/tests/infer/deeppot_dpa.savedmodel 2>/dev/null

# Check for files in the test directory
ls -la source/lmp/tests/graph2.pb source/lmp/tests/data.lmp source/lmp/tests/data.si source/lmp/tests/data_type_map.lmp source/lmp/tests/md.out 2>/dev/null

# Additional search to find these files in case paths are different
find source -type f \( -name "deeppot-1.pbtxt" -o -name "deeppot_dpa.savedmodel" -o -name "graph2.pb" -o -name "data.lmp" -o -name "data.si" -o -name "data_type_map.lmp" -o -name "md.out" \)

Length of output: 956


Script:

#!/bin/bash
# Let's check the actual test directory structure to understand the correct paths

# First, let's find the exact location of test_lammps_dpa_jax.py to understand its context
find source -type f -name "test_lammps_dpa_jax.py"

# Then, let's look at the directory structure around this file
find source -type d -name "tests" -o -name "infer"

# Finally, let's check for the test data files in the same directory as test_lammps_dpa_jax.py
find source -type f -name "*.lmp" -o -name "*.si" -o -name "*.pb" -o -name "*.out"

Length of output: 594

coderabbitai bot commented Nov 6, 2024

📝 Walkthrough


This pull request integrates the JAX backend into the DeePMD-kit framework. Key changes include updates to the pre-commit configuration, backend documentation, and new functionality in the API and test files. JAX is now supported alongside TensorFlow and PyTorch, with corresponding updates to the installation instructions, code structure, and test suites.

Changes

• .pre-commit-config.yaml: Updated hooks (added disallow-caps, modified check-added-large-files exclusions, updated the pylint entry).
• doc/backend.md: Added JAX as a new backend option, specified model extensions, and clarified C++ inference support.
• doc/install/install-from-source.md: Expanded C++ interface installation details for TensorFlow and JAX, clarified compiler requirements, and added environment variables.
• doc/model/dpa2.md: Added a section on limitations of the JAX backend with LAMMPS.
• source/api_c/include/c_api.h: Incremented the API version to 24, added the DP_NlistSetMapping function, and updated documentation for existing functions.
• source/api_c/include/deepmd.hpp: Added a set_mapping method to the InputNlist structure.
• source/api_c/src/c_api.cc: Introduced the DP_NlistSetMapping function to set mappings for neighbor lists.
• source/api_cc/include/DeepPotJAX.h: Added the DeepPotJAX class with multiple constructors and methods for TensorFlow integration.
• source/api_cc/include/common.h: Updated the DPBackend enum to include JAX.
• source/api_cc/src/DeepPot.cc: Enhanced initialization logic to support the JAX backend.
• source/api_cc/src/DeepPotJAX.cc: Implemented TensorFlow operations and error handling for DeepPotJAX.
• source/api_cc/tests/test_deeppot_jax.cc: Added unit tests for the DeepPot class using the Google Test framework.
• source/cmake/googletest.cmake.in: Updated the googletest GIT_TAG from release-1.12.1 to v1.14.0.
• source/lib/include/neighbor_list.h: Added an int* mapping member to the InputNlist structure and a corresponding setter method.
• source/lmp/fix_dplr.cpp: Enhanced mapping and force calculation logic in the FixDPLR class.
• source/lmp/pair_deepmd.cpp: Improved atom mapping handling in the PairDeepMD class.
• source/lmp/tests/test_lammps_dpa_jax.py: Introduced tests for DPA using LAMMPS, validating energy and force calculations.
• source/lmp/tests/test_lammps_jax.py: Added tests for LAMMPS integration with DeepMD, focusing on potential energy and force accuracy.
• source/tests/infer/deeppot_dpa.savedmodel/.gitignore: Updated .gitignore to exclude .pb files from version control.
• source/tests/infer/deeppot_dpa.savedmodel/fingerprint.pb: Added binary file fingerprint.pb.
• source/tests/infer/deeppot_sea.savedmodel/.gitignore: Updated .gitignore to exclude .pb files from version control.
• source/tests/infer/deeppot_sea.savedmodel/fingerprint.pb: Added binary file fingerprint.pb.

Suggested reviewers

  • iProzd
  • wanghan-iapcm
  • CaRoLZhangxy


@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 18

🧹 Outside diff range and nitpick comments (23)
source/cmake/googletest.cmake.in (1)

Line range hint 1-1: Consider updating minimum CMake version requirement.

The current minimum CMake version (2.8.2) is quite old. Google Test 1.14.0 might benefit from newer CMake features. Consider updating to a more recent version (e.g., 3.10 or newer) to ensure better compatibility and access to modern CMake features.

doc/model/dpa2.md (1)

23-24: Consider adding more context about the MPI limitation.

While the limitation is clearly stated, it would be helpful to provide more context about why this limitation exists and whether there are plans to support multiple MPI ranks in the future.

Consider expanding the explanation:

-When using the JAX backend, 2 or more MPI ranks are not supported. One must set `map` to `yes` using the [`atom_modify`](https://docs.lammps.org/atom_modify.html) command.
+When using the JAX backend, 2 or more MPI ranks are not currently supported due to JAX's parallel processing model. One must set `map` to `yes` using the [`atom_modify`](https://docs.lammps.org/atom_modify.html) command to ensure proper atom indexing and data mapping between LAMMPS and DPA-2.
doc/backend.md (2)

34-36: Enhance clarity of JAX backend documentation.

The documentation would benefit from the following improvements:

  1. Explain the differences and use cases for .xlo vs .jax formats
  2. Provide more details about GPU device specificity, such as:
    • Whether this applies to all formats or just specific ones
    • How users can identify if a model is GPU-specific
  3. Consider reorganizing the version requirements to be more prominent

Here's a suggested improvement:

 Only the `.savedmodel` format supports C++ inference, which needs the TensorFlow C++ interface.
-The model is device-specific, so that the model generated on the GPU device cannot be run on the CPUs.
+The model is device-specific: models generated on GPU devices cannot be executed on CPUs. This applies to all JAX model formats (.xlo, .savedmodel, and .jax). You can identify GPU-specific models by checking the device information in the model metadata.
 Currently, this backend is developed actively, and has no support for training.
🧰 Tools
🪛 LanguageTool

[typographical] ~35-~35: The conjunction “so that” does not require a comma.
Context: ... interface. The model is device-specific, so that the model generated on the GPU device c...

(SO_THAT_UNNECESSARY_COMMA)


35-35: Remove unnecessary comma before "so that".

The comma before "so that" is grammatically incorrect and should be removed.

-The model is device-specific, so that the model generated on the GPU device cannot be run on the CPUs.
+The model is device-specific so that the model generated on the GPU device cannot be run on the CPUs.

source/api_cc/include/common.h (1)

16-16: Consider adding enum documentation.

Since this is a public API header, consider adding documentation comments for the DPBackend enum to describe each backend option and their implications.

Example improvement:

+/**
+ * @brief Supported deep learning backends
+ * @details
+ * - TensorFlow: TensorFlow backend support
+ * - PyTorch: PyTorch backend support
+ * - Paddle: PaddlePaddle backend support
+ * - JAX: JAX backend support (requires JAX 0.4.33+)
+ * - Unknown: Represents an unrecognized backend
+ */
enum DPBackend { TensorFlow, PyTorch, Paddle, JAX, Unknown };
source/lib/include/neighbor_list.h (3)

47-48: Documentation could be more specific about mapping requirements.

The comment should clarify:

  • The expected size relationship with nall
  • Whether negative values are allowed in the mapping
  • The ownership/lifetime of the pointer
  • The relationship with LAMMPS atom indexing

Consider expanding the comment to:

-  /// mapping from all atoms to real atoms, in the size of nall
+  /// Mapping array from all atoms (including ghost atoms) to real atoms.
+  /// Size must match nall (total number of atoms including ghost atoms).
+  /// The caller retains ownership of the pointer.

104-107: Enhance documentation and consider adding validation.

The method documentation should be as detailed as other methods in this file. Also, consider adding nullptr validation.

Consider these improvements:

   /**
-   * @brief Set mapping for this neighbor list.
+   * @brief Set the mapping array for this neighbor list.
+   * @param mapping_ Pointer to an integer array mapping all atoms to real atoms.
+   *                Must not be nullptr and must have size matching nall.
+   * @note The caller retains ownership of the mapping array and must ensure
+   *       its lifetime exceeds that of the neighbor list.
    */
-  void set_mapping(int* mapping_) { mapping = mapping_; };
+  void set_mapping(int* mapping_) {
+    assert(mapping_ != nullptr);
+    mapping = mapping_;
+  };

Line range hint 47-107: Consider architectural improvements for safer pointer management.

While the current implementation follows existing patterns, consider these architectural improvements:

  1. Store nall as a member to enable size validation
  2. Consider using std::vector<int> or std::unique_ptr<int[]> for clearer ownership
  3. Add a method to validate the mapping array size

These changes would improve safety but would require more significant refactoring.

Would you like me to propose a more detailed design for these improvements?
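For point 2 of the list above, a rough sketch of what owned storage with size validation could look like — the class name OwningNlist is hypothetical, and the real InputNlist keeps a raw int* for C-API compatibility:

```cpp
#include <cstddef>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical neighbor-list variant that owns its mapping and validates the
// size against nall (total atom count, real + ghost) at assignment time.
class OwningNlist {
 public:
  explicit OwningNlist(std::size_t nall) : nall_(nall) {}

  // Takes the mapping by value so callers can move it in; throws on mismatch.
  void set_mapping(std::vector<int> mapping) {
    if (mapping.size() != nall_) {
      throw std::invalid_argument("mapping size must equal nall");
    }
    mapping_ = std::move(mapping);
  }

  bool has_mapping() const { return !mapping_.empty(); }
  const int* mapping_data() const { return mapping_.data(); }

 private:
  std::size_t nall_;
  std::vector<int> mapping_;  // owned: no lifetime coupling to the caller
};
```

The trade-off is a copy (or move) per neighbor-list build versus the zero-copy raw pointer the current C API exposes.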

doc/install/install-from-source.md (2)

300-302: Add version compatibility information for JAX backend.

While the documentation correctly states that JAX backend uses TensorFlow's C++ library, it would be helpful to specify:

  1. Minimum supported JAX version
  2. Version compatibility requirements between TensorFlow and JAX

Line range hint 380-396: Clarify JAX-specific configuration options.

The documentation updates for ENABLE_TENSORFLOW and TENSORFLOW_ROOT now include JAX backend support. However, please clarify:

  1. Are there any JAX-specific CMake variables that users need to set?
  2. Are there any additional configuration steps needed when using JAX vs TensorFlow?
source/api_cc/tests/test_deeppot_jax.cc (2)

72-73: Consider making the model file path configurable.

The model file path is hardcoded which could make the tests less portable and harder to maintain. Consider:

  1. Using environment variables
  2. Making it a configurable parameter
  3. Using a test fixture to manage test resources
-    std::string file_name = "../../tests/infer/deeppot_sea.savedmodel";
+    const char* model_path = std::getenv("DEEPMD_TEST_MODEL_PATH");
+    std::string file_name = model_path ? model_path : "../../tests/infer/deeppot_sea.savedmodel";

97-427: Consider reducing code duplication in test cases.

The test cases share similar setup and verification patterns. Consider extracting common test logic into helper functions to improve maintainability and reduce duplication. For example:

+ template <typename VALUETYPE>
+ void verify_results(
+     const std::vector<VALUETYPE>& force,
+     const std::vector<VALUETYPE>& virial,
+     const std::vector<VALUETYPE>& expected_f,
+     const std::vector<VALUETYPE>& expected_tot_v,
+     double ener,
+     double expected_tot_e,
+     int natoms) {
+   EXPECT_EQ(force.size(), natoms * 3);
+   EXPECT_EQ(virial.size(), 9);
+   EXPECT_LT(fabs(ener - expected_tot_e), EPSILON);
+   for (int ii = 0; ii < natoms * 3; ++ii) {
+     EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON);
+   }
+   for (int ii = 0; ii < 3 * 3; ++ii) {
+     EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON);
+   }
+ }
source/api_c/src/c_api.cc (1)

46-48: Consider consistent error handling.

The function should follow the established error handling pattern used throughout the codebase. Consider using the DP_REQUIRES_OK macro or setting the exception string in the DP_Nlist object when errors occur.

 void DP_NlistSetMapping(DP_Nlist* nl, int* mapping) {
-  nl->nl.set_mapping(mapping);
+  try {
+    nl->nl.set_mapping(mapping);
+  } catch (const std::exception& e) {
+    nl->exception = std::string(e.what());
+  }
 }
source/api_c/include/deepmd.hpp (1)

618-622: Enhance documentation while implementation looks good.

The implementation is correct and follows the established pattern of delegating to the C API. However, the documentation could be more detailed to help users understand:

  • The expected size and lifetime requirements of the mapping array
  • Whether the pointer is stored or just used temporarily
  • The purpose and typical use cases for this mapping

Consider expanding the documentation like this:

  /**
   * @brief Set mapping for this neighbor list.
   * @param mapping mapping from all atoms to real atoms, in size nall.
+  * @details The mapping array should remain valid for the lifetime of the neighbor list
+  * or until the next call to set_mapping. The mapping is typically used to handle
+  * ghost/virtual atoms by mapping them to their corresponding real atoms.
+  * @note The size of the mapping array should match the total number of atoms (nall)
+  * in the system.
   */
source/api_cc/include/DeepPotJAX.h (2)

29-31: Pass integer parameters by value instead of by const int&

Passing integers like gpu_rank by const int& introduces unnecessary indirection since integers are small and copying them is inexpensive. It is more efficient and idiomatic in C++ to pass them by value.

Apply this diff to update the parameter passing:

-DeepPotJAX(const std::string& model,
-           const int& gpu_rank = 0,
-           const std::string& file_content = "");
+DeepPotJAX(const std::string& model,
+           int gpu_rank = 0,
+           const std::string& file_content = "");

-void init(const std::string& model,
-          const int& gpu_rank = 0,
-          const std::string& file_content = "");
+void init(const std::string& model,
+          int gpu_rank = 0,
+          const std::string& file_content = "");

Also applies to: 39-41


63-65: Clarify the constant return value in numb_types_spin()

The method numb_types_spin() always returns 0. If spin types are not supported in DeepPotJAX, consider documenting this behavior or modifying the method to reflect the intended use.

You could update the method to throw an exception or assert if spin types are not applicable:

-int numb_types_spin() const {
-    assert(inited);
-    return 0;
-};
+int numb_types_spin() const {
+    assert(inited);
+    throw std::runtime_error("Spin types are not supported in DeepPotJAX.");
+};

Alternatively, update the documentation to specify that this method returns 0 because spin types are unsupported in this implementation.

source/api_cc/src/DeepPotJAX.cc (2)

35-38: Optimize string manipulation by using 'resize' instead of 'substr'

In find_function, the call name_ = name_.substr(0, pos + 1); creates a temporary copy of a prefix of the string and assigns it back to the same string. Using resize truncates in place and avoids the unnecessary copy.

Apply this diff to improve efficiency:

 std::string::size_type pos = name_.find_last_not_of("0123456789_");
 if (pos != std::string::npos) {
-  name_ = name_.substr(0, pos + 1);
+  name_.resize(pos + 1);
 }
🧰 Tools
🪛 cppcheck

[performance] 37-37: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.

(uselessCallsSubstr)


284-286: Ensure proper cleanup in destructor by checking 'status'

In the destructor ~DeepPotJAX(), after calling TF_DeleteSession, it's good practice to check the status to ensure that the session was deleted without errors.

Consider adding a status check:

 TF_DeleteSession(session, status);
+check_status(status);
 TF_DeleteGraph(graph);
source/lmp/tests/test_lammps_jax.py (1)

307-723: Refactor repetitive test code into helper functions.

Multiple test functions contain similar code blocks, such as setting up pair_style, pair_coeff, running simulations, and performing assertions. Refactoring these blocks into helper functions can enhance readability and maintainability.

source/lmp/tests/test_lammps_dpa_jax.py (2)

246-279: Consider extracting common LAMMPS setup code into a separate function.

The _lammps function contains a lot of common setup code for initializing the LAMMPS instance with specific units, boundary conditions, atom styles, etc. Consider extracting this into a separate setup_lammps function that can be reused across tests for better code organization and reusability.


679-726: Consider enabling the skipped MPI tests if possible.

The MPI tests are currently skipped because MPI and mpi4py are not installed. Consider installing these dependencies so the tests can run; MPI tests are important to ensure the DeepMD pair style works correctly in parallel.

source/api_cc/src/DeepPot.cc (2)

45-47: Refactor file extension checks into a helper function

The repeated pattern of checking model file extensions (e.g., .pth, .pb, .savedmodel) could be refactored into a helper function to improve maintainability and reduce code duplication.
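One minimal shape such a helper could take is sketched below. The names has_suffix and guess_backend are illustrative, and the extension-to-backend mapping shown (.pb for TensorFlow, .pth for PyTorch, .savedmodel for JAX2TF) reflects only the extensions mentioned in this review, not necessarily the full set DeepPot.cc handles:

```cpp
#include <string>

// True if `model` ends with `suffix` (a pre-C++20 stand-in for ends_with).
bool has_suffix(const std::string& model, const std::string& suffix) {
  return model.size() >= suffix.size() &&
         model.compare(model.size() - suffix.size(), suffix.size(), suffix) == 0;
}

enum class Backend { TensorFlow, PyTorch, JAX, Unknown };

// Centralized extension-to-backend dispatch instead of repeated inline checks.
Backend guess_backend(const std::string& model) {
  if (has_suffix(model, ".pb")) {
    return Backend::TensorFlow;
  }
  if (has_suffix(model, ".pth")) {
    return Backend::PyTorch;
  }
  if (has_suffix(model, ".savedmodel")) {
    return Backend::JAX;  // JAX2TF SavedModel, loaded via the TF C++ library
  }
  return Backend::Unknown;
}
```

Initialization code would then branch on the returned enum once, instead of re-testing string suffixes at each decision point.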


69-71: Improve clarity of exception message

Consider rephrasing the exception message for better clarity.

Apply this diff to update the exception message:

         throw deepmd::deepmd_exception(
-            "TensorFlow backend is not built, which is used to load JAX2TF "
+            "TensorFlow backend is not built; it is required to load JAX2TF "
             "SavedModels");
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between dabedd2 and 58dcf2b.

📒 Files selected for processing (22)
  • .pre-commit-config.yaml (1 hunks)
  • doc/backend.md (1 hunks)
  • doc/install/install-from-source.md (4 hunks)
  • doc/model/dpa2.md (1 hunks)
  • source/api_c/include/c_api.h (4 hunks)
  • source/api_c/include/deepmd.hpp (1 hunks)
  • source/api_c/src/c_api.cc (1 hunks)
  • source/api_cc/include/DeepPotJAX.h (1 hunks)
  • source/api_cc/include/common.h (1 hunks)
  • source/api_cc/src/DeepPot.cc (3 hunks)
  • source/api_cc/src/DeepPotJAX.cc (1 hunks)
  • source/api_cc/tests/test_deeppot_jax.cc (1 hunks)
  • source/cmake/googletest.cmake.in (1 hunks)
  • source/lib/include/neighbor_list.h (2 hunks)
  • source/lmp/fix_dplr.cpp (2 hunks)
  • source/lmp/pair_deepmd.cpp (3 hunks)
  • source/lmp/tests/test_lammps_dpa_jax.py (1 hunks)
  • source/lmp/tests/test_lammps_jax.py (1 hunks)
  • source/tests/infer/deeppot_dpa.savedmodel/.gitignore (1 hunks)
  • source/tests/infer/deeppot_dpa.savedmodel/fingerprint.pb (1 hunks)
  • source/tests/infer/deeppot_sea.savedmodel/.gitignore (1 hunks)
  • source/tests/infer/deeppot_sea.savedmodel/fingerprint.pb (1 hunks)
✅ Files skipped from review due to trivial changes (4)
  • source/tests/infer/deeppot_dpa.savedmodel/.gitignore
  • source/tests/infer/deeppot_dpa.savedmodel/fingerprint.pb
  • source/tests/infer/deeppot_sea.savedmodel/.gitignore
  • source/tests/infer/deeppot_sea.savedmodel/fingerprint.pb
🧰 Additional context used
🪛 LanguageTool
doc/backend.md

[typographical] ~35-~35: The conjunction “so that” does not require a comma.
Context: ... interface. The model is device-specific, so that the model generated on the GPU device c...

(SO_THAT_UNNECESSARY_COMMA)

🪛 cppcheck
source/api_cc/src/DeepPotJAX.cc

[performance] 29-29: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 64-64: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 66-66: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 86-86: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 88-88: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 111-111: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 113-113: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 133-133: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 135-135: Function parameter 'device' should be passed by const reference.

(passedByValue)


[performance] 37-37: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.

(uselessCallsSubstr)

🔇 Additional comments (41)
source/cmake/googletest.cmake.in (2)

Line range hint 4-11: LGTM: Repository fallback mechanism.

The fallback mechanism to use Gitee mirror when USE_GITEE_GTEST is set is a good practice for users in regions with limited GitHub access.


14-14: Verify compatibility with Google Test v1.14.0.

The upgrade from v1.12.1 to v1.14.0 brings improvements but may introduce breaking changes. Notable changes in v1.14.0 include:

  • Support for C++20
  • Improved death test implementation
  • New matchers and assertions

Let's check for potential compatibility issues:

✅ Verification successful

Based on the analysis of the test files and Google Test usage patterns, I can now provide a definitive response about the compatibility with Google Test v1.14.0:

The upgrade to Google Test v1.14.0 is safe to proceed.

The codebase uses standard Google Test features that are stable across versions:

  • Basic test fixtures using ::testing::Test
  • Standard test macros like TEST_F and TYPED_TEST
  • No usage of death tests (ASSERT_DEATH, EXPECT_DEATH, etc.)
  • No custom test event listeners or printers
  • No deprecated features that were removed in v1.14.0

The test suite primarily uses:

  1. Type-parameterized tests with TYPED_TEST_SUITE and TYPED_TEST
  2. Test fixtures with basic class inheritance from ::testing::Test
  3. Standard assertions and test declarations
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Search for Google Test usage patterns that might be affected by the version upgrade

# Look for death tests that might need updates
rg -A 5 "ASSERT_DEATH|EXPECT_DEATH|ASSERT_EXIT|EXPECT_EXIT"

# Check for custom test event listeners that might need updates
ast-grep --pattern 'class $_ : public testing::TestEventListener'

# Look for custom test printers that might need updates
ast-grep --pattern 'class $_ : public testing::PrintTo'

Length of output: 187


Script:

#!/bin/bash
# Let's check for actual test files and Google Test usage patterns

# Find test files
fd -e cpp -e cc -e h test

# Look for Google Test includes and main patterns
rg -l "gtest|gmock"

# Check CMake files for any Google Test specific configurations
fd -e cmake | xargs cat

# Look for any test executables or test suites
rg -l "TEST_F|TEST_P|TEST|INSTANTIATE_TEST"

Length of output: 30214


Script:

#!/bin/bash
# Let's examine the test files to understand the Google Test usage patterns

# Look for specific Google Test features that might be affected by version changes
rg -A 3 "TEST_F|TEST_P|TYPED_TEST|INSTANTIATE_TEST" source/lib/tests source/api_c/tests source/api_cc/tests

# Check for any custom Google Test configurations or setup
rg -A 3 "testing::InitGoogleTest|testing::Environment|testing::Test" source/lib/tests source/api_c/tests source/api_cc/tests

# Look for any death test implementations
rg -A 3 "ASSERT_DEATH|EXPECT_DEATH|ASSERT_EXIT|EXPECT_EXIT" source/lib/tests source/api_c/tests source/api_cc/tests

# Check for any test fixture setup that might be affected
rg -A 3 "class.*: public ::testing::Test" source/lib/tests source/api_c/tests source/api_cc/tests

Length of output: 179258

doc/model/dpa2.md (2)

21-22: LGTM! Clear section header with appropriate backend icon.

The section header is well-formatted and correctly uses the JAX icon to indicate backend-specific content.


25-27: LGTM! Clear example with proper LAMMPS syntax.

The code example is correctly formatted and demonstrates the required command.

.pre-commit-config.yaml (1)

18-19: LGTM! Exclusion patterns properly added for model files.

The additional exclusions for deeppolar_new.pbtxt and deeppot_dpa.savedmodel/saved_model.pb are correctly formatted and aligned with the PR's objective of supporting DPA-2 models.

Let's verify the size of these excluded files to ensure they indeed need to be excluded:

✅ Verification successful

Exclusion patterns correctly added for large model files

The verification confirms that both files are indeed large binary files that warrant exclusion:

  • source/tests/infer/deeppolar_new.pbtxt: 2.6MB
  • source/tests/infer/deeppot_dpa.savedmodel/saved_model.pb: 1.5MB

These sizes justify their exclusion from pre-commit checks to maintain good performance of the git hooks.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check if the excluded files exist and their sizes
# to validate the need for exclusion from large file checks

echo "Checking sizes of excluded files..."
find . -type f \( -name "deeppolar_new.pbtxt" -o -name "saved_model.pb" \) -exec ls -lh {} \;

Length of output: 483

source/api_cc/tests/test_deeppot_jax.cc (4)

1-15: LGTM! Well-organized includes and proper license header.

The file structure follows best practices with system headers, C++ standard library headers, and project headers properly organized.


16-93: LGTM! Well-structured test class with comprehensive documentation.

The test class is well-documented with the numpy code that generated the test data, making it easy to understand and maintain. The setup properly validates data sizes and initializes all necessary components.


429-439: LGTM! Good coverage of utility functions.

The tests properly verify the summary printing and type map retrieval functionality.


132-138: Verify EPSILON definition.

The code uses EPSILON for floating-point comparisons, but its definition is not visible in this file. Ensure it's properly defined with an appropriate value for the precision requirements.

✅ Verification successful

Let me verify if the test file includes the necessary header with EPSILON definition.


EPSILON is properly defined and included

The test file correctly includes "test_utils.h" which defines EPSILON as (std::is_same<VALUETYPE, double>::value ? 1e-10 : 1e-4). This is a suitable definition that adjusts the precision threshold based on the value type being used (1e-10 for double, 1e-4 for float).

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for EPSILON definition in the codebase
rg -A 1 "^#define\s+EPSILON\b" 

Length of output: 657


Script:

#!/bin/bash
# Check if test_utils.h is included in test_deeppot_jax.cc
rg -A 1 '#include.*test_utils\.h' source/api_cc/tests/test_deeppot_jax.cc

Length of output: 100

source/api_c/include/c_api.h (1)

15-15: LGTM: API version increment is appropriate.

The increment from 23 to 24 aligns with the addition of the new DP_NlistSetMapping function.

source/api_c/src/c_api.cc (2)

46-48: LGTM - Good integration with existing codebase.

The function is well-integrated into the codebase:

  • Follows established naming conventions
  • Placed appropriately with other DP_Nlist functions
  • Maintains consistent implementation pattern

46-48: Verify usage and impact of the new mapping functionality.

Let's verify how this new mapping functionality is used in the codebase and ensure it doesn't affect existing neighbor list behavior.

✅ Verification successful

New mapping functionality is properly integrated and used in existing codebase

The verification shows that:

  • The set_mapping functionality is a well-integrated part of the neighbor list implementation in source/lib/include/neighbor_list.h
  • It's actively used in production code:
    • LAMMPS integration (pair_deepmd.cpp, fix_dplr.cpp)
    • Core neighbor list implementation
    • C/C++ API layers
  • The implementation is consistent across all layers (core, API, and integrations)
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for potential usage and impact of the new mapping functionality

# Test 1: Find potential callers of the new function
echo "Searching for potential callers of DP_NlistSetMapping..."
rg -l "DP_NlistSetMapping"

# Test 2: Find related neighbor list mapping code
echo "Searching for related neighbor list mapping code..."
rg -l "set_mapping.*neighbor"

# Test 3: Check for potential test files
echo "Searching for related test files..."
fd -g "*test*" | rg -l "Nlist|mapping"

Length of output: 503


Script:

#!/bin/bash
# Let's examine the implementation and usage more thoroughly

# Check the header file definition
echo "=== C API Header Definition ==="
cat source/api_c/include/c_api.h | grep -A 5 "DP_NlistSetMapping"

# Check the C++ header definition
echo -e "\n=== C++ Header Definition ==="
cat source/api_c/include/deepmd.hpp | grep -A 5 "DP_NlistSetMapping"

# Look for the base Nlist class implementation
echo -e "\n=== Searching for Nlist class implementation ==="
rg -l "class.*Nlist"

# If we find the class file, let's examine the set_mapping method
echo -e "\n=== Examining set_mapping implementation ==="
rg "set_mapping" --type cpp -A 5

# Check if there are any existing usages of mapping in neighbor list context
echo -e "\n=== Checking existing mapping usages ==="
rg "mapping.*neighbor|neighbor.*mapping" --type cpp

Length of output: 5281

source/api_c/include/deepmd.hpp (1)

618-622: Verify the usage of set_mapping in the codebase.

The method is well-integrated and follows the established pattern. Let's verify its usage across the codebase to ensure proper adoption.

✅ Verification successful

The set_mapping method is properly integrated and used as intended

The verification shows the method is well-integrated and used appropriately:

  • Implemented in neighbor_list.h as core functionality
  • Called by LAMMPS integration code in pair_deepmd.cpp and fix_dplr.cpp to map atom tags
  • Properly wrapped in the C API (c_api.cc) for external usage

The implementation follows the same pattern across all usage points, mapping from atom tags to real atoms consistently.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for potential callers of set_mapping to ensure proper integration

# Test 1: Find potential callers in the codebase
echo "Searching for potential callers of set_mapping..."
rg -l "InputNlist" | xargs rg "set_mapping"

# Test 2: Look for related LAMMPS integration code
echo "Checking LAMMPS integration..."
rg -l "LAMMPS" | xargs rg "mapping.*atom"

Length of output: 1277

source/api_cc/include/DeepPotJAX.h (1)

93-96: Verify the implementation of is_aparam_nall()

The method is_aparam_nall() always returns false. Please confirm if this is the intended behavior. If aparam is never of dimension nall, it might be clearer to document this explicitly or adjust the method to better reflect its purpose.

source/api_cc/src/DeepPotJAX.cc (1)

Avoid unnecessary casting to double when using a float model

The comment suggests that casting to double may be unnecessary if using a float model. Ensure that this casting is intentional and necessary; otherwise, it could lead to precision issues or unnecessary conversions.

Please confirm whether this casting is required. If not, consider modifying the code to avoid unnecessary conversions.

source/lmp/tests/test_lammps_dpa_jax.py (21)

1-11: LGTM!

The imports look good and follow the standard convention.


35-138: LGTM!

The expected values for energy, forces, and virial are defined correctly using numpy arrays.


210-223: LGTM!

The box coordinates, atom coordinates, and atom types are defined correctly using numpy arrays.


230-238: LGTM!

The setup_module function correctly writes the LAMMPS data files using the write_lmp_data function with the defined box, coordinates, and atom types.


241-244: LGTM!

The teardown_module function removes the data files after the tests complete.


281-306: LGTM!

The pytest fixtures for creating LAMMPS instances with different units and data files are set up correctly.


309-319: LGTM!

The test_pair_deepmd function correctly tests the DeepMD pair style by comparing the potential energy and forces against expected values.


321-342: LGTM!

The test_pair_deepmd_virial function correctly tests the virial computation with the DeepMD pair style by comparing against expected values.


344-368: LGTM!

The test_pair_deepmd_model_devi function correctly tests the model deviation output by comparing against expected values loaded from the md_file.


370-406: LGTM!

The test_pair_deepmd_model_devi_virial function correctly tests the model deviation output with virial computation by comparing against expected values.


408-435: LGTM!

The test_pair_deepmd_model_devi_atomic_relative function correctly tests the model deviation output with relative atomic deviations.


437-468: LGTM!

The test_pair_deepmd_model_devi_atomic_relative_v function correctly tests the model deviation output with relative atomic deviations for virials.


470-480: LGTM!

The test_pair_deepmd_type_map function correctly tests the DeepMD pair style with a type map by comparing against expected values.


482-493: LGTM!

The test_pair_deepmd_real function correctly tests the DeepMD pair style with real units by comparing against expected values converted to real units.


496-521: LGTM!

The test_pair_deepmd_virial_real function correctly tests the virial computation with real units by comparing against expected values converted to real units.


523-551: LGTM!

The test_pair_deepmd_model_devi_real function correctly tests the model deviation output with real units by comparing against expected values converted to real units.


553-594: LGTM!

The test_pair_deepmd_model_devi_virial_real function correctly tests the model deviation output with virial computation in real units.


597-627: LGTM!

The test_pair_deepmd_model_devi_atomic_relative_real function correctly tests the model deviation output with relative atomic deviations in real units.


630-664: LGTM!

The test_pair_deepmd_model_devi_atomic_relative_v_real function correctly tests the model deviation output with relative atomic deviations for virials in real units.


667-676: LGTM!

The test_pair_deepmd_si function correctly tests the DeepMD pair style with SI units by comparing against expected values converted to SI units.


225-227: Ensure the command to convert pbtxt to pb runs successfully.

Please verify that the command to convert the pbtxt file to pb file using deepmd-kit runs without any errors.

#!/bin/bash 
# Description: Verify pbtxt to pb conversion command runs successfully.

# Test: Run the conversion command. Expect: No output on successful execution.
python -m deepmd convert-from pbtxt -i source/tests/infer/deeppot-1.pbtxt -o source/lmp/tests/graph2.pb >/dev/null 2>&1 || echo "Failed to convert pbtxt to pb"
source/api_cc/src/DeepPot.cc (2)

10-10: Include DeepPotJAX.h conditionally

Including DeepPotJAX.h within the #ifdef BUILD_TENSORFLOW block ensures that the header is only included when TensorFlow support is enabled, which is appropriate.


65-72: Correct handling of JAX backend initialization

The addition of JAX backend support with appropriate conditional compilation and exception handling is correctly implemented.

source/lmp/pair_deepmd.cpp (3)

575-577: Verify the mapping is correctly set for the main neighbor list.

The code correctly sets the mapping for the main neighbor list lmp_list when the number of processes is one and the mapping style is not MAP_NONE. This ensures that the atom indices in the neighbor list correctly correspond to the global atom indices.


588-590: Verify the mapping is correctly set for the extended neighbor list.

The code correctly sets the mapping for the extended neighbor list extend_lmp_list when the number of processes is one and the mapping style is not MAP_NONE. This ensures that the atom indices in the extended neighbor list, which includes ghost atoms, correctly correspond to the global atom indices.


524-530: Verify the mapping vector is correctly set for all atoms.

The code correctly sets the mapping_vec for all atoms when the number of processes is one and the mapping style is not MAP_NONE. However, ensure that the mapping is correctly utilized in other parts of the code where atom indices are used, such as neighbor lists, force calculations, etc.

Run the following script to verify the mapping usage:

✅ Verification successful

Let me gather more information about how this mapping vector is used in the codebase.


Let me try one more search to understand how this mapping vector is utilized:


Based on the search results, I can now provide a final response:

The mapping vector implementation is correct and properly utilized.

The mapping_vec is consistently used in both pair_deepmd.cpp and fix_dplr.cpp with the same implementation pattern:

  • It's only used when running on a single process (comm->nprocs == 1) and when atom mapping is enabled (atom->map_style != Atom::MAP_NONE)
  • The mapping is properly passed to the neighbor list through lmp_list.set_mapping(mapping_vec.data())
  • The implementation correctly maps all atoms (both local and ghost) as required
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify `mapping_vec` is correctly used for atom indices in the codebase.

# Test: Search for potential mapping issues. Expect: No concerning usages.
rg -A 5 -B 5 $'atom->map\(\w+\)|mapping_vec\[\w+\]' -g '*.cpp' -g '*.h' -g '*.cu'

Length of output: 1291


Script:

#!/bin/bash
# Search for usage of mapping_vec in the codebase
rg -A 10 -B 10 "mapping_vec" --type cpp --type hpp --type cu

# Search for any potential atom mapping or indexing related patterns
ast-grep --pattern 'std::vector<int> mapping_vec = $_'

# Look for the broader context of the files where mapping_vec is used
rg -l "mapping_vec" | xargs -I {} head -n 1 {}

Length of output: 288


Script:

#!/bin/bash
# Search for the complete function context where mapping_vec is used
rg -A 30 -B 30 "mapping_vec" source/lmp/pair_deepmd.cpp source/lmp/fix_dplr.cpp

# Look for any atom mapping related functions or methods
ast-grep --pattern 'void $_($$$) {
  $$$
  mapping_vec
  $$$
}'

Length of output: 13444

@@ -13,7 +13,7 @@
namespace deepmd {

typedef double ENERGYTYPE;
enum DPBackend { TensorFlow, PyTorch, Paddle, Unknown };
enum DPBackend { TensorFlow, PyTorch, Paddle, JAX, Unknown };

💡 Codebase verification

Update if-else chains in multiple files to handle JAX backend

The codebase has several if-else chains that need to be updated to handle the new JAX backend:

  • source/api_cc/src/DeepPot.cc: Missing JAX case in backend comparison
  • source/api_cc/src/DataModifier.cc: Missing JAX case in backend comparison
  • source/api_cc/src/DeepTensor.cc: Missing JAX case in backend comparison

Each of these files needs to add a new condition else if (deepmd::DPBackend::JAX == backend) before the final else clause to properly handle the JAX backend.

🔗 Analysis chain

LGTM! Verify enum usage across codebase.

The addition of JAX to the DPBackend enum before Unknown is correct and maintains backward compatibility.

Let's verify the enum usage across the codebase:

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for DPBackend enum usage and potential switch statements that need updating

# Search for switch statements on DPBackend that might need updating
rg -A 10 "switch.*DPBackend" 

# Search for direct enum value comparisons
rg "DPBackend::(TensorFlow|PyTorch|Paddle|Unknown)"

Length of output: 1157

Comment on lines +81 to +89
/**
* @brief Set mapping for a neighbor list.
*
* @param nl Neighbor list.
* @param mapping mapping from all atoms to real atoms, in size nall.
* @since API version 24
*
**/
extern void DP_NlistSetMapping(DP_Nlist* nl, int* mapping);

🛠️ Refactor suggestion

Consider adding size parameter for safety.

While the function signature and documentation are well-structured, consider adding a size parameter for the mapping array to prevent potential buffer overflows. This would allow runtime validation of the mapping array size against the expected size (nall).

-extern void DP_NlistSetMapping(DP_Nlist* nl, int* mapping);
+extern void DP_NlistSetMapping(DP_Nlist* nl, int* mapping, int mapping_size);

💡 Codebase verification

⚠️ Potential issue

Based on the code inspection, I can now provide the final response:

Buffer overflow risk confirmed - no size validation for mapping array

The implementation in both C API (source/api_c/src/c_api.cc) and the underlying C++ class (source/lib/include/neighbor_list.h) shows that the set_mapping function directly assigns the pointer without validating its size against nall. This could lead to buffer overflows if the provided array is smaller than expected.

Key findings:

  • C API implementation simply forwards the pointer: nl->nl.set_mapping(mapping);
  • C++ implementation in InputNlist class also directly assigns without validation: void set_mapping(int* mapping_) { mapping = mapping_; };
  • The mapping array is documented to require size nall but this is not enforced
🔗 Analysis chain

Add validation for mapping array size.

The function should validate that the mapping array size matches the expected size (nall) to prevent buffer overflows.

Let's verify if size validation is implemented in the source:


Let me check the C++ implementation of the set_mapping method to see if size validation is performed there.


Let me check the NeighborList class implementation to understand how the mapping is used and if there are any size validations.


Let me try a different approach to find the NeighborList implementation.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check if the implementation validates the mapping array size
# Test: Search for size validation in the implementation
rg -A 5 "DP_NlistSetMapping.*\{" source/api_c/src/

Length of output: 373


Script:

#!/bin/bash
# Search for set_mapping implementation in the C++ code
ast-grep --pattern 'set_mapping($$$) { $$$ }'

# Also search for any references to mapping size validation
rg -i "set_mapping|mapping.*size" source/

Length of output: 2662


Script:

#!/bin/bash
# Search for NeighborList class implementation and usage of mapping
rg -A 10 "class.*NeighborList" source/lib/include/
rg -B 2 -A 5 "mapping\s*=" source/lib/src/neighbor_list.cc

Length of output: 113


Script:

#!/bin/bash
# Search for the NeighborList class definition and implementation
rg -A 10 "class.*NeighborList" source/
cat source/lib/include/neighbor_list.h

Length of output: 13092

Comment on lines +46 to +48
void DP_NlistSetMapping(DP_Nlist* nl, int* mapping) {
nl->nl.set_mapping(mapping);
}

🛠️ Refactor suggestion

Add parameter validation and documentation.

While the implementation is correct, consider adding:

  1. Null pointer validation for parameters
  2. Documentation explaining the expected format and size of the mapping array
  3. Error handling for invalid inputs
+// Set the mapping for the neighbor list
+// @param nl: Pointer to the neighbor list
+// @param mapping: Array of integers defining the mapping. Must not be null.
+// @return void
 void DP_NlistSetMapping(DP_Nlist* nl, int* mapping) {
+  if (nl == nullptr) {
+    return;  // cannot record an error on a null handle
+  }
+  if (mapping == nullptr) {
+    nl->exception = "Invalid null mapping in DP_NlistSetMapping";
+    return;
+  }
   nl->nl.set_mapping(mapping);
 }

* @param[in] atomic Whether to compute atomic energy and virial.
**/
template <typename VALUETYPE>
void compute(std::vector<ENERGYTYPE>& ener,

⚠️ Potential issue

Undefined type ENERGYTYPE in template function

The template function compute uses ENERGYTYPE for the ener parameter, but ENERGYTYPE is not defined within this header file. This may lead to compilation errors.

Please ensure that ENERGYTYPE is defined or include the appropriate header file where ENERGYTYPE is declared.

Comment on lines +195 to +201
TF_Graph* graph;
TF_Status* status;
TF_Session* session;
TF_SessionOptions* sessionopts;
TFE_ContextOptions* ctx_opts;
TFE_Context* ctx;
std::vector<TF_Function*> func_vector;

⚠️ Potential issue

Ensure proper cleanup of TensorFlow C API objects in destructor

The class utilizes several TensorFlow C API objects (graph, status, session, sessionopts, ctx_opts, ctx, func_vector). To prevent memory leaks, it's crucial to release these resources appropriately in the destructor.

Consider adding cleanup code in the destructor ~DeepPotJAX():

+DeepPotJAX::~DeepPotJAX() {
+    if (session) TF_DeleteSession(session, status);
+    if (graph) TF_DeleteGraph(graph);
+    if (status) TF_DeleteStatus(status);
+    if (sessionopts) TF_DeleteSessionOptions(sessionopts);
+    if (ctx_opts) TFE_DeleteContextOptions(ctx_opts);
+    if (ctx) TFE_DeleteContext(ctx);
+    for (auto func : func_vector) {
+        TF_DeleteFunction(func);
+    }
+}

Ensure that all TensorFlow objects are properly deleted and that error checking is implemented where necessary.


Comment on lines +85 to +89
inline T get_scalar(TFE_Context* ctx,
const std::string func_name,
const std::vector<TF_Function*>& funcs,
const std::string device,
TF_Status* status) {

🛠️ Refactor suggestion

Pass 'func_name' and 'device' by const reference in 'get_scalar'

In the function get_scalar, passing func_name and device by const std::string& avoids unnecessary string copying and enhances performance.

Apply this diff to update the parameter passing:

 template <typename T>
 inline T get_scalar(TFE_Context* ctx,
-                    const std::string func_name,
+                    const std::string& func_name,
                     const std::vector<TF_Function*>& funcs,
-                    const std::string device,
+                    const std::string& device,
                     TF_Status* status) {
🧰 Tools
🪛 cppcheck

[performance] 86-86: Function parameter 'func_name' should be passed by const reference.

(passedByValue)


[performance] 88-88: Function parameter 'device' should be passed by const reference.

(passedByValue)

Comment on lines +252 to +255
DPGetDeviceCount(gpu_num); // check current device environment
DPErrcheck(DPSetDevice(gpu_rank % gpu_num));
if (gpu_num > 0) {
device = "/gpu:" + std::to_string(gpu_rank % gpu_num);

⚠️ Potential issue

Handle potential division by zero when computing 'gpu_rank % gpu_num'

In the code block where gpu_num is used, ensure that gpu_num is not zero before performing the modulo operation to avoid a division by zero error.

Apply this diff to add a check for gpu_num:

 int gpu_num;
 DPGetDeviceCount(gpu_num);  // check current device environment
+DPErrcheck(gpu_num != 0, "GPU device count is zero.");
 DPErrcheck(DPSetDevice(gpu_rank % gpu_num));

Or modify the logic to handle the case when gpu_num is zero.


Comment on lines +241 to +244
def teardown_module():
os.remove(data_file)
os.remove(data_type_map_file)


⚠️ Potential issue

Add cleanup for data_file_si in teardown_module().

The file data_file_si created in setup_module() is not being removed in teardown_module(). This could lead to leftover files after tests are run.

Apply this diff to fix the issue:

 def teardown_module():
     os.remove(data_file)
     os.remove(data_type_map_file)
+    os.remove(data_file_si)

⚠️ Potential issue

Ensure pb_file2 is removed in teardown_module().

The file pb_file2 generated in the setup is not being deleted in teardown_module(), potentially leaving temporary files after tests are executed.

Apply this diff to fix the issue:

 def teardown_module():
     os.remove(data_file)
     os.remove(data_type_map_file)
+    os.remove(pb_file2)

Comment on lines +689 to +704
sp.check_call(
[
"mpirun",
"-n",
"2",
sys.executable,
Path(__file__).parent / "run_mpi_pair_deepmd.py",
data_file,
pb_file,
pb_file2,
md_file,
f.name,
*balance_args,
]
)
arr = np.loadtxt(f.name, ndmin=1)

⚠️ Potential issue

Avoid using NamedTemporaryFile with subprocess on Windows.

Using tempfile.NamedTemporaryFile() may cause issues on Windows when passing the filename to a subprocess because the file is already open. This can lead to permission errors.

Consider using tempfile.mkstemp() to create a temporary file safely:

-def test_pair_deepmd_mpi(balance_args: list):
-    with tempfile.NamedTemporaryFile() as f:
+def test_pair_deepmd_mpi(balance_args: list):
+    fd, temp_path = tempfile.mkstemp()
+    os.close(fd)
     sp.check_call(
         [
             "mpirun",
             "-n",
             "2",
             sys.executable,
             Path(__file__).parent / "run_mpi_pair_deepmd.py",
             data_file,
             pb_file,
             pb_file2,
             md_file,
-            f.name,
+            temp_path,
             *balance_args,
         ]
     )
-    arr = np.loadtxt(f.name, ndmin=1)
+    arr = np.loadtxt(temp_path, ndmin=1)
+    os.remove(temp_path)
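Isolated, the mkstemp pattern looks like this (the child command stands in for the mpirun invocation; the written value is arbitrary):

```python
import os
import subprocess
import sys
import tempfile

# Create the file but close our handle immediately, so a child process can
# open it by name even on Windows, where a file held open by
# NamedTemporaryFile cannot be reopened by another process.
fd, temp_path = tempfile.mkstemp()
os.close(fd)
try:
    # Child process writes to the path, as the MPI run would.
    subprocess.check_call(
        [sys.executable, "-c", f"open({temp_path!r}, 'w').write('1.5')"]
    )
    with open(temp_path) as f:
        value = float(f.read())
finally:
    os.remove(temp_path)  # cleanup happens even if the child fails

print(value)
```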

Comment on lines +22 to +34
pbtxt_file2 = (
Path(__file__).parent.parent.parent / "tests" / "infer" / "deeppot-1.pbtxt"
)
pb_file = (
Path(__file__).parent.parent.parent / "tests" / "infer" / "deeppot_dpa.savedmodel"
)
pb_file2 = Path(__file__).parent / "graph2.pb"
system_file = Path(__file__).parent.parent.parent / "tests"
data_file = Path(__file__).parent / "data.lmp"
data_file_si = Path(__file__).parent / "data.si"
data_type_map_file = Path(__file__).parent / "data_type_map.lmp"
md_file = Path(__file__).parent / "md.out"


💡 Codebase verification

Several test data files are missing from the expected locations

The verification reveals that while some source files exist (deeppot-1.pbtxt and deeppot_dpa.savedmodel), the following test data files are missing:

  • source/lmp/tests/graph2.pb
  • source/lmp/tests/data.lmp
  • source/lmp/tests/data.si
  • source/lmp/tests/data_type_map.lmp
  • source/lmp/tests/md.out

These test data files must be present for the tests to run. You'll need to do one of the following:
  1. Generate these test data files if they're meant to be created during test setup
  2. Add these missing test data files to the repository
  3. Update the test to use existing test data files from the correct locations
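For option 1, the setup could guard against the missing files explicitly so the suite skips cleanly instead of failing with `FileNotFoundError`. A hypothetical sketch (the file names mirror the paths in the snippet above; `missing_files` and `REQUIRED` are illustrative names, not part of the repository):

```python
import unittest
from pathlib import Path

# Hypothetical list of generated test data files this module depends on.
REQUIRED = ["graph2.pb", "data.lmp", "data.si", "data_type_map.lmp"]


def missing_files(base: Path) -> list:
    """Return the names from REQUIRED that are absent under *base*."""
    return [name for name in REQUIRED if not (base / name).exists()]


class TestLammpsDeepMD(unittest.TestCase):
    def setUp(self):
        # Skip rather than error out when the data-generation step
        # has not been run before the tests.
        missing = missing_files(Path(__file__).parent)
        if missing:
            self.skipTest(f"missing test data files: {missing}")
```

A skip message names the files that were expected, which makes it obvious that the data-generation step (rather than the test logic) is what needs attention.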
🔗 Analysis chain

Verify the file paths are correct and exist.

Please ensure that the file paths used for pbtxt_file2, pb_file, pb_file2, system_file, data_file, data_file_si, data_type_map_file, and md_file point to the correct locations and the files exist.


Let me verify the existence of these files with a more direct approach.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify the file paths used in the test file exist.

# Test: Check if the files exist. Expect: No output if files exist.
fd -e pbtxt -e pb -e lmp -e out | xargs -I {} sh -c 'test -f "$1" || echo "$1 does not exist"' -- {}

Length of output: 248


Script:

#!/bin/bash
# Check for the existence of specific files mentioned in the test
ls -l source/tests/infer/deeppot-1.pbtxt 2>/dev/null || echo "deeppot-1.pbtxt not found"
ls -l source/tests/infer/deeppot_dpa.savedmodel 2>/dev/null || echo "deeppot_dpa.savedmodel not found"
ls -l source/lmp/tests/graph2.pb 2>/dev/null || echo "graph2.pb not found"
ls -l source/lmp/tests/data.lmp 2>/dev/null || echo "data.lmp not found"
ls -l source/lmp/tests/data.si 2>/dev/null || echo "data.si not found"
ls -l source/lmp/tests/data_type_map.lmp 2>/dev/null || echo "data_type_map.lmp not found"
ls -l source/lmp/tests/md.out 2>/dev/null || echo "md.out not found"

# Also check the directory structure
tree source/tests/infer/
tree source/lmp/tests/

Length of output: 2384

@njzjz njzjz marked this pull request as draft November 10, 2024 22:38
@njzjz
Member Author

njzjz commented Nov 11, 2024

I will merge this PR into #4307.

@njzjz njzjz closed this Nov 11, 2024