feat(jax): passing mapping from LAMMPS to DPA-2 #4316
Conversation
Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Warning: Rate limit exceeded. @github-actions[bot] has exceeded the limit for the number of commits or files that can be reviewed per hour. Please wait 14 minutes and 20 seconds before requesting another review.

⌛ How to resolve this issue? After the wait time has elapsed, a review can be triggered. We recommend that you space out your commits to avoid hitting the rate limit.

🚦 How do rate limits work? CodeRabbit enforces hourly rate limits for each developer per organization. Our paid plans have higher rate limits than the trial, open-source and free plans. In all cases, we re-allow further reviews after a brief timeout. Please see our FAQ for further information.
📝 Walkthrough

The pull request introduces several updates across multiple files, primarily focusing on enhancing support for the JAX backend within the DeePMD-kit framework.

Changes
Possibly related issues
Possibly related PRs
Suggested labels
Suggested reviewers
Actionable comments posted: 13
🧹 Outside diff range and nitpick comments (39)
doc/model/dpa2.md (1)
21-27: Documentation looks good, with some suggestions for enhancement.

The new section clearly documents the JAX backend limitations and requirements. Consider enhancing it further by:
- Explaining why multiple MPI ranks are not supported
- Adding a link to LAMMPS MPI documentation
- Mentioning if multi-rank support is planned for future releases
doc/backend.md (2)
34-35: Critical information about JAX model limitations.

The documentation clearly states important limitations that users need to be aware of:

- Only the `.savedmodel` format supports C++ inference
- Models are device-specific (GPU models won't run on CPU)

These are crucial pieces of information that help prevent runtime issues.

Consider adding examples of common error messages users might encounter when:

- Trying to use non-`.savedmodel` formats with C++ inference
- Attempting to run GPU models on CPU

This would make troubleshooting easier for users.

🧰 Tools
🪛 LanguageTool

[typographical] ~35-~35: The conjunction “so that” does not require a comma.
Context: ... interface. The model is device-specific, so that the model generated on the GPU device c... (SO_THAT_UNNECESSARY_COMMA)
35-35: Fix unnecessary comma.

-The model is device-specific, so that the model generated on the GPU device cannot be run on the CPUs.
+The model is device-specific so that the model generated on the GPU device cannot be run on the CPUs.
source/api_cc/include/common.h (1)
16-16: Consider adding documentation for enum values.

The `DPBackend` enum would benefit from documentation describing each backend option and their implications.

Consider adding doxygen-style documentation:

+/**
+ * @brief Supported deep learning backends
+ * @enum DPBackend
+ * @var TensorFlow TensorFlow backend support
+ * @var PyTorch PyTorch backend support
+ * @var Paddle PaddlePaddle backend support
+ * @var JAX JAX backend support
+ * @var Unknown Unrecognized backend
+ */
 enum DPBackend { TensorFlow, PyTorch, Paddle, JAX, Unknown };

source/lib/include/neighbor_list.h (2)
47-48: Consider using smart pointers for safer memory management.

The raw pointer `mapping` could lead to memory leaks or dangling pointers. Consider:

- Using `std::unique_ptr` or `std::shared_ptr` for automatic memory management
- Adding a size member variable to explicitly track the mapping array size
- Adding bounds checking in `set_mapping`

-  /// mapping from all atoms to real atoms, in the size of nall
-  int* mapping = nullptr;
+  /// mapping from all atoms to real atoms, in the size of nall
+ private:
+  std::unique_ptr<int[]> mapping = nullptr;
+  size_t mapping_size = 0;
+ public:
+  const int* get_mapping() const { return mapping.get(); }
Based on the evidence gathered from the codebase, the GPU implementation needs to be updated for the mapping feature.

The `convert_nlist_gpu_device` function currently doesn't handle the new `mapping` member, which could lead to inconsistencies between CPU and GPU implementations. Specific updates needed:

- `convert_nlist_gpu_device` in `source/lib/src/neighbor_list.cc` needs to allocate and copy the mapping array to GPU memory
- `free_nlist_gpu_device` should be updated to free the GPU mapping array
- GPU-specific test cases should be added to verify the mapping functionality
🔗 Analysis chain

Line range hint 47-107: Consider GPU implications of the new mapping feature.

The new mapping functionality might need corresponding support in GPU-related functions:

- Update `convert_nlist_gpu_device` to handle the mapping array
- Consider memory management in `free_nlist_gpu_device`
- Evaluate if `use_nlist_map` needs modifications

Let's check for related GPU functions:
Let me check the implementation of GPU-related functions to verify if they need updates for the mapping feature.
Let me check one more aspect to see if there are any GPU-related test cases that might show the expected behavior with mapping.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Search for GPU-related functions that might need updates
rg "nlist.*gpu" --type cpp

Length of output: 4129
Script:
#!/bin/bash
# Check the implementation of convert_nlist_gpu_device
rg -A 20 "convert_nlist_gpu_device.*InputNlist" source/lib/src/neighbor_list.cc
# Check if mapping is used in GPU-related functions
rg "mapping" source/lib/src/neighbor_list.cc

Length of output: 1278
Script:
#!/bin/bash
# Check test cases for GPU neighbor list implementation
rg -B 2 -A 5 "test.*nlist.*gpu" source/lib/tests/test_neighbor_list.cc
# Look for any mapping-related GPU tests
rg "mapping.*gpu" --type cpp source/lib/tests/

Length of output: 121
doc/install/install-from-source.md (3)
300-302: Add version compatibility information for JAX.

The documentation clearly explains that both TensorFlow and JAX backends use the TensorFlow C++ library. However, it would be helpful to specify which JAX versions are compatible with which TensorFlow versions to help users avoid version mismatches.
Line range hint 380-396: Clarify the relationship between TensorFlow and JAX enablement.

While the documentation correctly indicates that both TensorFlow and JAX backends use these CMake variables, it would be helpful to explicitly state:
- Whether enabling TensorFlow automatically enables JAX support
- If there are any additional steps or variables needed specifically for JAX
- If there are any limitations when using both backends simultaneously
This would help users better understand the configuration options available to them.
Line range hint 1-500: Documentation successfully integrates JAX backend support.

The documentation changes effectively integrate JAX backend support while maintaining clarity and consistency. The shared infrastructure with TensorFlow is well explained, and the installation process is clearly documented. The changes align well with the PR objectives of adding JAX backend support.
Consider adding a troubleshooting section specific to JAX installation to help users resolve common issues they might encounter during the setup process.
source/api_cc/tests/test_deeppot_jax.cc (3)
19-32: Consider moving the Python code to a separate file or improving its formatting.

The Python code in comments shows how the test data was generated, but it's not properly formatted and makes the code harder to read.

Consider either:

- Moving this code to a separate `.py` file and referencing it in the comments
- Improving the formatting of the inline Python code:

- // import numpy as np
- // from deepmd.infer import DeepPot
- // coord = np.array([
- //     12.83, 2.56, 2.18, 12.09, 2.87, 2.74,
- //     00.25, 3.32, 1.68, 3.36, 3.00, 1.81,
- //     3.51, 2.51, 2.60, 4.27, 3.22, 1.56
- // ]).reshape(1, -1)
+ // Python code used to generate test data:
+ // ```python
+ // import numpy as np
+ // from deepmd.infer import DeepPot
+ //
+ // coord = np.array([
+ //     12.83, 2.56, 2.18, 12.09, 2.87, 2.74,
+ //     00.25, 3.32, 1.68, 3.36, 3.00, 1.81,
+ //     3.51, 2.51, 2.60, 4.27, 3.22, 1.56
+ // ]).reshape(1, -1)
+ // ```
319-324: Define magic numbers as named constants.

The test cases use magic numbers for array sizes and indices. This makes the code harder to understand and maintain.

Consider defining constants:

+// Constants for virtual atom tests
+constexpr int kNumVirtualAtoms = 2;
+constexpr int kVirtualAtomType = 2;
+
 // add vir atoms
-int nvir = 2;
-std::vector<VALUETYPE> coord_vir(nvir * 3);
-std::vector<int> atype_vir(nvir, 2);
+int nvir = kNumVirtualAtoms;
+std::vector<VALUETYPE> coord_vir(nvir * 3);
+std::vector<int> atype_vir(nvir, kVirtualAtomType);

Also applies to: 380-386
97-97: Add documentation for test purposes.

Each test case lacks documentation explaining its purpose and what specific aspect it's testing.

Add descriptive comments for each test case:

+/**
+ * Tests the basic functionality of DeepPot with LAMMPS neighbor lists.
+ * Verifies energy, force, and virial calculations.
+ */
 TYPED_TEST(TestInferDeepPotAJAX, cpu_lmp_nlist) {

Also applies to: 159-159, 242-242, 304-304, 366-366, 429-429, 434-434
source/api_c/include/c_api.h (1)
81-89: Documentation could be more detailed.

While the function declaration is well-structured, the documentation could be enhanced with:

- More details about the expected format and constraints of the `mapping` array
- Example usage or common use cases
- Error handling behavior

Consider expanding the documentation:

 /**
  * @brief Set mapping for a neighbor list.
  *
  * @param nl Neighbor list.
- * @param mapping mapping from all atoms to real atoms, in size nall.
+ * @param mapping Array mapping indices from all atoms to real atoms. Size must be equal to
+ *                the total number of atoms (nall). Each element should be a valid index
+ *                into the real atoms array. If NULL, the mapping is reset.
+ * @throws ValueError if any mapping index is invalid or if the array size is incorrect.
  * @since API version 24
  **/

source/api_c/include/deepmd.hpp (1)
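To make the expected layout of the mapping array concrete, here is a hedged, self-contained sketch of how a caller might assemble it before passing it to the neighbor list. The helper name `build_mapping` and the `ghost_owner` input are hypothetical, not part of the API; the sketch only encodes the documented convention that the array has size nall and maps every atom (real or ghost) to a real-atom index.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper: local (real) atoms map to themselves; each ghost atom
// maps back to the index of the real atom it is a periodic image of.
// `ghost_owner[j]` is assumed to hold the real-atom index for ghost atom j.
std::vector<int> build_mapping(int nlocal, const std::vector<int>& ghost_owner) {
  std::vector<int> mapping(nlocal + ghost_owner.size());
  for (int ii = 0; ii < nlocal; ++ii) {
    mapping[ii] = ii;  // real atoms map to themselves
  }
  for (std::size_t jj = 0; jj < ghost_owner.size(); ++jj) {
    mapping[nlocal + jj] = ghost_owner[jj];  // ghosts map to their owners
  }
  return mapping;
}
```

The resulting vector would then be handed to the neighbor list (e.g. via `set_mapping(mapping.data())` in the C++ wrapper), with the caller keeping the vector alive for as long as the list uses it.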
618-622: Enhance documentation for the `set_mapping` method.

The documentation should clarify:

- The ownership and lifetime requirements of the `mapping` pointer
- Whether the pointer can be null
- The expected size of the mapping array

Apply this diff to improve the documentation:

 /**
  * @brief Set mapping for this neighbor list.
  * @param mapping mapping from all atoms to real atoms, in size nall.
+ * @note The mapping array must remain valid for the lifetime of this neighbor list.
+ *       The array is not copied, only the pointer is stored.
+ * @warning The size of the mapping array must match the total number of atoms (nall).
+ *          Passing nullptr will clear any existing mapping.
  */
 void set_mapping(int *mapping) { DP_NlistSetMapping(nl, mapping); };

source/api_cc/include/DeepPotJAX.h (1)
233-247: Consider adding error handling to the `compute` method.

The `compute` method currently lacks error handling for potential issues such as:

- Invalid input sizes (e.g., mismatch between the number of atoms and the size of the `atype` vector)
- Out-of-range atom types
- Invalid or inconsistent neighbor list data

To improve the robustness of the code, consider adding appropriate error checks and throwing exceptions with informative error messages when invalid input is detected. This will help users identify and fix issues more easily.
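As a rough sketch of the kind of validation the comment asks for (not the project's actual code; `validate_inputs` and `ntypes` are hypothetical names), the checks could be gathered into one helper that `compute` calls first:

```cpp
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical input validation: the coordinate array must hold 3 values per
// atom, and every type index must be within [0, ntypes).
void validate_inputs(const std::vector<double>& coord,
                     const std::vector<int>& atype,
                     int ntypes) {
  if (coord.size() != atype.size() * 3) {
    throw std::invalid_argument(
        "coord size (" + std::to_string(coord.size()) +
        ") does not match 3 * natoms (" + std::to_string(atype.size() * 3) +
        ")");
  }
  for (std::size_t ii = 0; ii < atype.size(); ++ii) {
    if (atype[ii] < 0 || atype[ii] >= ntypes) {
      throw std::out_of_range("atom " + std::to_string(ii) +
                              " has out-of-range type " +
                              std::to_string(atype[ii]));
    }
  }
}
```

The informative messages (sizes and offending indices included) are what make such checks useful in practice: the user sees immediately which input was malformed.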
source/lmp/fix_dplr.cpp (1)
445-445: Use consistent data types for loop indices.

The loop variable `ii` is declared as `size_t`, whereas `nall` is of type `int`. Mixing signed and unsigned types may lead to type mismatch warnings or unintended behavior. Consider declaring `ii` as an `int` to match the type of `nall`.

source/api_cc/src/DeepPotJAX.cc (10)
20-25: Ensure consistent error handling across the codebase.

The `check_status` function provides a convenient way to check the status of TensorFlow operations and throw exceptions in case of errors. Consider using this function consistently throughout the codebase to ensure uniform error handling and improve code readability.
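As a standalone illustration of this pattern (using a hypothetical `Status` struct in place of `TF_Status`, and `std::runtime_error` in place of the project's exception type), a check-and-throw helper concentrates error translation in one place so call sites stay short:

```cpp
#include <stdexcept>
#include <string>

// Stand-in for TF_Status: an error code plus a message. 0 means OK.
struct Status {
  int code = 0;
  std::string message;
};

// Sketch of a check_status-style helper: inspect the status object after each
// operation and convert any failure into an exception.
inline void check_status(const Status& status) {
  if (status.code != 0) {
    throw std::runtime_error("operation failed: " + status.message);
  }
}
```

Call sites then read as `do_op(&status); check_status(status);`, which is the uniform shape the review comment recommends.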
27-45: Consider using a more efficient search algorithm.

The `find_function` function performs a linear search to find a specific function by name in the `funcs` vector. If the number of functions is large, this can be inefficient. Consider using a more efficient search algorithm, such as binary search or a hash table, to improve performance.

🧰 Tools
🪛 cppcheck

[performance] 29-29: Function parameter 'func_name' should be passed by const reference. (passedByValue)

[performance] 37-37: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead. (uselessCallsSubstr)
47-61: Use a template function to deduce the data type.

Instead of defining separate overloads of the `get_data_tensor_type` function for each data type, consider using a template function that can deduce the data type based on the input vector's type. This will reduce code duplication and improve maintainability.

template <typename T>
inline TF_DataType get_data_tensor_type(const std::vector<T>& data) {
  if constexpr (std::is_same_v<T, double>) {
    return TF_DOUBLE;
  } else if constexpr (std::is_same_v<T, float>) {
    return TF_FLOAT;
  } else if constexpr (std::is_same_v<T, int32_t>) {
    return TF_INT32;
  } else if constexpr (std::is_same_v<T, int64_t>) {
    return TF_INT64;
  } else {
    static_assert(always_false_v<T>, "Unsupported data type");
  }
}
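The `if constexpr` dispatch above can be exercised standalone if `TF_DataType` is replaced by a plain enum and `always_false_v` (a common idiom, not a standard-library name) is defined explicitly; this sketch only demonstrates the compile-time branching, not the real TensorFlow types:

```cpp
#include <cstdint>
#include <type_traits>
#include <vector>

// Stand-in for TF_DataType, used purely for illustration.
enum DataType { DT_DOUBLE, DT_FLOAT, DT_INT32, DT_INT64 };

// Helper so the static_assert only fires when an unsupported T is
// instantiated; always_false_v is an idiom, not part of std.
template <typename T>
inline constexpr bool always_false_v = false;

template <typename T>
inline DataType get_data_tensor_type(const std::vector<T>& /*data*/) {
  if constexpr (std::is_same_v<T, double>) {
    return DT_DOUBLE;
  } else if constexpr (std::is_same_v<T, float>) {
    return DT_FLOAT;
  } else if constexpr (std::is_same_v<T, int32_t>) {
    return DT_INT32;
  } else if constexpr (std::is_same_v<T, int64_t>) {
    return DT_INT64;
  } else {
    static_assert(always_false_v<T>, "Unsupported data type");
  }
}
```

Because the unsupported branch is only instantiated for types outside the list, passing, say, a `std::vector<char>` becomes a compile-time error instead of a silent fallback.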
63-82: Refactor the `get_func_op` function to improve readability.

The `get_func_op` function performs several steps to retrieve a TensorFlow operation for a given function name and context. Consider breaking down the function into smaller, more focused functions to improve readability and maintainability. For example, you can extract the code for finding the function and adding it to the context into separate functions.

🧰 Tools
🪛 cppcheck

[performance] 64-64: Function parameter 'func_name' should be passed by const reference. (passedByValue)

[performance] 66-66: Function parameter 'device' should be passed by const reference. (passedByValue)
84-107: Use `std::optional` to handle the case when the function is not found.

In the `get_scalar` function, instead of returning a default value when the function is not found, consider using `std::optional` to explicitly handle the case when the function is not found. This will make the code more expressive and less error-prone.

template <typename T>
inline std::optional<T> get_scalar(TFE_Context* ctx,
                                   const std::string& func_name,
                                   const std::vector<TF_Function*>& funcs,
                                   const std::string& device,
                                   TF_Status* status) {
  TFE_Op* op = get_func_op(ctx, func_name, funcs, device, status);
  check_status(status);
  TFE_TensorHandle* retvals[1];
  int nretvals = 1;
  TFE_Execute(op, retvals, &nretvals, status);
  check_status(status);
  TFE_TensorHandle* retval = retvals[0];
  TF_Tensor* tensor = TFE_TensorHandleResolve(retval, status);
  check_status(status);
  T* data = static_cast<T*>(TF_TensorData(tensor));
  if (data == nullptr) {
    return std::nullopt;
  }
  T result = *data;
  TFE_DeleteOp(op);
  TF_DeleteTensor(tensor);
  TFE_DeleteTensorHandle(retval);
  return result;
}

🧰 Tools
🪛 cppcheck

[performance] 86-86: Function parameter 'func_name' should be passed by const reference. (passedByValue)

[performance] 88-88: Function parameter 'device' should be passed by const reference. (passedByValue)
109-129: Use `std::vector::resize` instead of creating a new vector.

In the `get_vector` function, instead of creating a new `result` vector and resizing it later, consider resizing the `result` vector directly using `std::vector::resize`. This will avoid unnecessary memory allocations and improve performance.

template <typename T>
inline std::vector<T> get_vector(TFE_Context* ctx,
                                 const std::string& func_name,
                                 const std::vector<TF_Function*>& funcs,
                                 const std::string& device,
                                 TF_Status* status) {
  TFE_Op* op = get_func_op(ctx, func_name, funcs, device, status);
  check_status(status);
  TFE_TensorHandle* retvals[1];
  int nretvals = 1;
  TFE_Execute(op, retvals, &nretvals, status);
  check_status(status);
  TFE_TensorHandle* retval = retvals[0];
  std::vector<T> result;
  tensor_to_vector(result, retval, status);
  TFE_DeleteTensorHandle(retval);
  TFE_DeleteOp(op);
  return result;
}

🧰 Tools
🪛 cppcheck

[performance] 111-111: Function parameter 'func_name' should be passed by const reference. (passedByValue)

[performance] 113-113: Function parameter 'device' should be passed by const reference. (passedByValue)
131-166: Use `std::string_view` to avoid unnecessary string copies.

In the `get_vector_string` function, consider using `std::string_view` instead of `std::string` to avoid unnecessary string copies when pushing back the strings into the `result` vector. This will improve performance, especially when dealing with large strings.

inline std::vector<std::string> get_vector_string(
    TFE_Context* ctx,
    const std::string& func_name,
    const std::vector<TF_Function*>& funcs,
    const std::string& device,
    TF_Status* status) {
  TFE_Op* op = get_func_op(ctx, func_name, funcs, device, status);
  check_status(status);
  TFE_TensorHandle* retvals[1];
  int nretvals = 1;
  TFE_Execute(op, retvals, &nretvals, status);
  check_status(status);
  TFE_TensorHandle* retval = retvals[0];
  TF_Tensor* tensor = TFE_TensorHandleResolve(retval, status);
  check_status(status);
  const void* data = TF_TensorData(tensor);
  int64_t bytes_each_string =
      TF_TensorByteSize(tensor) / TF_TensorElementCount(tensor);
  std::vector<std::string> result;
  for (int ii = 0; ii < TF_TensorElementCount(tensor); ++ii) {
    const TF_TString* datastr =
        static_cast<const TF_TString*>(static_cast<const void*>(
            static_cast<const char*>(data) + ii * bytes_each_string));
    const char* dst = TF_TString_GetDataPointer(datastr);
    size_t dst_len = TF_TString_GetSize(datastr);
    result.emplace_back(std::string_view(dst, dst_len));
  }
  TFE_DeleteOp(op);
  TF_DeleteTensor(tensor);
  TFE_DeleteTensorHandle(retval);
  return result;
}

🧰 Tools
🪛 cppcheck

[performance] 133-133: Function parameter 'func_name' should be passed by const reference. (passedByValue)

[performance] 135-135: Function parameter 'device' should be passed by const reference. (passedByValue)
168-176: Use `std::vector::data` instead of `&data[0]`.

In the `create_tensor` function, instead of using `&data[0]` to get a pointer to the underlying data of the `data` vector, consider using `std::vector::data`. This is a more idiomatic and safer way to obtain a pointer to the vector's data.

template <typename T>
inline TF_Tensor* create_tensor(const std::vector<T>& data,
                                const std::vector<int64_t>& shape) {
  TF_Tensor* tensor =
      TF_AllocateTensor(get_data_tensor_type(data), shape.data(), shape.size(),
                        data.size() * sizeof(T));
  std::memcpy(TF_TensorData(tensor), data.data(), TF_TensorByteSize(tensor));
  return tensor;
}
193-207: Use `std::copy` instead of a raw loop to copy data.

In the `tensor_to_vector` function, instead of using a raw loop to copy data from the TensorFlow tensor to the `result` vector, consider using `std::copy`. This will make the code more readable and less error-prone.

template <typename T>
inline void tensor_to_vector(std::vector<T>& result,
                             TFE_TensorHandle* retval,
                             TF_Status* status) {
  TF_Tensor* tensor = TFE_TensorHandleResolve(retval, status);
  check_status(status);
  T* data = static_cast<T*>(TF_TensorData(tensor));
  result.resize(TF_TensorElementCount(tensor));
  std::copy(data, data + result.size(), result.begin());
  TF_DeleteTensor(tensor);
}
209-215: Use a member initializer list for constructor initialization.

In the `DeepPotJAX` constructor, consider using a member initializer list to initialize the `inited` member variable instead of assigning it in the constructor body. This is more efficient and follows best practices for constructor initialization.

deepmd::DeepPotJAX::DeepPotJAX(const std::string& model,
                               const int& gpu_rank,
                               const std::string& file_content)
    : inited(false) {
  init(model, gpu_rank, file_content);
}

source/lmp/tests/test_lammps_jax.py (2)
225-227: Add error handling for the model conversion command.

The use of `subprocess.check_output()` without error handling may cause the test to fail silently if the conversion command fails. To improve robustness, consider capturing exceptions and handling errors appropriately.

Apply this diff to handle potential exceptions:

+try:
     sp.check_output(
         f"{sys.executable} -m deepmd convert-from pbtxt -i {pbtxt_file2.resolve()} -o {pb_file2.resolve()}".split()
     )
+except sp.CalledProcessError as e:
+    raise RuntimeError(f"Model conversion failed: {e}")
246-276: Simplify unit handling in the `_lammps()` function.

The unit handling logic uses repetitive `if-elif` statements for setting neighbor distances, masses, and timesteps. Refactoring this code using dictionaries can improve readability and maintainability.

Consider refactoring as follows:

 def _lammps(data_file, units="metal") -> PyLammps:
     lammps = PyLammps()
     lammps.units(units)
     lammps.boundary("p p p")
     lammps.atom_style("atomic")
-    if units == "metal" or units == "real":
-        lammps.neighbor("2.0 bin")
-    elif units == "si":
-        lammps.neighbor("2.0e-10 bin")
-    else:
-        raise ValueError("units should be metal, real, or si")
+    neighbor_settings = {
+        "metal": "2.0 bin",
+        "real": "2.0 bin",
+        "si": "2.0e-10 bin",
+    }
+    lammps.neighbor(neighbor_settings.get(units, "2.0 bin"))
     lammps.neigh_modify("every 10 delay 0 check no")
     lammps.read_data(data_file.resolve())
-    if units == "metal" or units == "real":
-        lammps.mass("1 16")
-        lammps.mass("2 2")
-    elif units == "si":
-        lammps.mass("1 %.10e" % (16 * constants.mass_metal2si))
-        lammps.mass("2 %.10e" % (2 * constants.mass_metal2si))
-    else:
-        raise ValueError("units should be metal, real, or si")
+    mass_settings = {
+        "metal": [("1 16"), ("2 2")],
+        "real": [("1 16"), ("2 2")],
+        "si": [
+            ("1 %.10e" % (16 * constants.mass_metal2si)),
+            ("2 %.10e" % (2 * constants.mass_metal2si)),
+        ],
+    }
+    for mass in mass_settings.get(units, []):
+        lammps.mass(mass)
-    if units == "metal":
-        lammps.timestep(0.0005)
-    elif units == "real":
-        lammps.timestep(0.5)
-    elif units == "si":
-        lammps.timestep(5e-16)
-    else:
-        raise ValueError("units should be metal, real, or si")
+    timestep_settings = {
+        "metal": 0.0005,
+        "real": 0.5,
+        "si": 5e-16,
+    }
+    lammps.timestep(timestep_settings.get(units, 0.0005))
     lammps.fix("1 all nve")
     return lammps

source/lmp/tests/test_lammps_dpa_jax.py (6)
246-268: Refactor unit handling to avoid code duplication.

The `_lammps` function repeats similar `if-elif-else` blocks for unit handling multiple times. This can be refactored to improve readability and maintainability.

Apply this diff to refactor the unit handling:

 def _lammps(data_file, units="metal") -> PyLammps:
     lammps = PyLammps()
     lammps.units(units)
     lammps.boundary("p p p")
     lammps.atom_style("atomic")
     # Requires for DPA-2
     lammps.atom_modify("map yes")
-    if units == "metal" or units == "real":
-        lammps.neighbor("2.0 bin")
-    elif units == "si":
-        lammps.neighbor("2.0e-10 bin")
-    else:
-        raise ValueError("units should be metal, real, or si")
+    neighbor_distance = {
+        "metal": "2.0 bin",
+        "real": "2.0 bin",
+        "si": "2.0e-10 bin",
+    }.get(units)
+    if neighbor_distance is None:
+        raise ValueError("units should be 'metal', 'real', or 'si'")
+    lammps.neighbor(neighbor_distance)
     lammps.neigh_modify("every 10 delay 0 check no")
     lammps.read_data(data_file.resolve())
-    if units == "metal" or units == "real":
-        lammps.mass("1 16")
-        lammps.mass("2 2")
-    elif units == "si":
-        lammps.mass("1 %.10e" % (16 * constants.mass_metal2si))
-        lammps.mass("2 %.10e" % (2 * constants.mass_metal2si))
-    else:
-        raise ValueError("units should be metal, real, or si")
+    if units in ["metal", "real"]:
+        mass1 = "16"
+        mass2 = "2"
+    elif units == "si":
+        mass1 = "%.10e" % (16 * constants.mass_metal2si)
+        mass2 = "%.10e" % (2 * constants.mass_metal2si)
+    else:
+        raise ValueError("units should be 'metal', 'real', or 'si'")
+    lammps.mass(f"1 {mass1}")
+    lammps.mass(f"2 {mass2}")
-    if units == "metal":
-        lammps.timestep(0.0005)
-    elif units == "real":
-        lammps.timestep(0.5)
-    elif units == "si":
-        lammps.timestep(5e-16)
-    else:
-        raise ValueError("units should be metal, real, or si")
+    timestep = {
+        "metal": 0.0005,
+        "real": 0.5,
+        "si": 5e-16,
+    }.get(units)
+    if timestep is None:
+        raise ValueError("units should be 'metal', 'real', or 'si'")
+    lammps.timestep(timestep)
     lammps.fix("1 all nve")
     return lammps
312-318: Ensure proper cleanup of generated files in tests.

The test `test_pair_deepmd` does not clean up the generated dump files, which may lead to clutter or issues in subsequent tests.

Consider adding code to remove or manage temporary files created during the test.
679-681: Remove unnecessary import of `importlib`.

The import of `importlib` is only used for checking if `mpi4py` is installed. If MPI support is not required, consider removing this import to streamline the code.

If you decide to keep MPI tests, ensure that `importlib` is necessary; otherwise, remove it.
468-521: Consolidate similar test functions to reduce duplication.

The functions `test_pair_deepmd_real` and `test_pair_deepmd_virial_real` share similar setup code. Refactoring them to use a common helper function or parameterization could reduce code duplication and improve maintainability.

Consider refactoring as follows:

- Extract common code into a helper function.
- Use `@pytest.mark.parametrize` to run tests with varying parameters.
690-726: Reconsider skipping the MPI test function.

The `test_pair_deepmd_mpi` function is marked to be skipped unconditionally with `@pytest.mark.skip`. Given that there are checks for MPI installation, you might want to enable this test when the environment supports it.

If MPI support is now available, remove the skip decorator or adjust the condition to enable the test appropriately.
230-619: Enhance test coverage for different unit systems.

The tests for different unit systems (`metal`, `real`, `si`) could be parameterized to improve readability and reduce redundancy.

Use `@pytest.mark.parametrize` to run the same test logic over different units:

@pytest.mark.parametrize("units", ["metal", "real", "si"])
def test_pair_deepmd_units(units):
    lammps_instance = _lammps(data_file=data_file, units=units)
    # ... rest of the test logic ...

source/api_cc/src/DeepPot.cc (2)
45-47: Replace magic numbers with constants for file extension checks
To enhance readability and maintainability, consider defining constants for the file extensions instead of using magic numbers like 11. This approach makes the code clearer and simplifies updates if file extensions change in the future. Apply this diff to define a constant:

+#define SAVEDMODEL_EXTENSION ".savedmodel"
+...
 } else if (model.length() >= strlen(SAVEDMODEL_EXTENSION) &&
-    model.substr(model.length() - 11) == ".savedmodel") {
+    model.substr(model.length() - strlen(SAVEDMODEL_EXTENSION)) == SAVEDMODEL_EXTENSION) {
   backend = deepmd::DPBackend::JAX;
65-72: Consider introducing a separate build flag for JAX support
Currently, JAX support is conditioned on the BUILD_TENSORFLOW build flag. For better modularity and clarity, consider adding a separate build flag (e.g., BUILD_JAX) to control JAX-related compilation independently. This allows for more flexible build configurations and clearer dependency management.

source/lmp/pair_deepmd.cpp (3)
524-531: Mapping vector initialization is limited to single-process runs
The mapping_vec is initialized and populated only when comm->nprocs == 1. If mapping functionality is required for multi-process runs, consider extending this logic to support multiple processes.
575-577: Setting mapping in lmp_list is restricted to single-process runs
The mapping is set in lmp_list only when comm->nprocs == 1. If mapping is needed in multi-process configurations, ensure that mapping_vec is properly initialized and lmp_list.set_mapping() is called accordingly.
588-590: Consistent mapping setting in extend_lmp_list for multi-process runs
Similar to the previous cases, extend_lmp_list.set_mapping() is called only when comm->nprocs == 1. Consider extending support for multi-process scenarios if mapping is required.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (22)
- .pre-commit-config.yaml (1 hunks)
- doc/backend.md (1 hunks)
- doc/install/install-from-source.md (4 hunks)
- doc/model/dpa2.md (1 hunks)
- source/api_c/include/c_api.h (4 hunks)
- source/api_c/include/deepmd.hpp (1 hunks)
- source/api_c/src/c_api.cc (1 hunks)
- source/api_cc/include/DeepPotJAX.h (1 hunks)
- source/api_cc/include/common.h (1 hunks)
- source/api_cc/src/DeepPot.cc (3 hunks)
- source/api_cc/src/DeepPotJAX.cc (1 hunks)
- source/api_cc/tests/test_deeppot_jax.cc (1 hunks)
- source/cmake/googletest.cmake.in (1 hunks)
- source/lib/include/neighbor_list.h (2 hunks)
- source/lmp/fix_dplr.cpp (2 hunks)
- source/lmp/pair_deepmd.cpp (3 hunks)
- source/lmp/tests/test_lammps_dpa_jax.py (1 hunks)
- source/lmp/tests/test_lammps_jax.py (1 hunks)
- source/tests/infer/deeppot_dpa.savedmodel/.gitignore (1 hunks)
- source/tests/infer/deeppot_dpa.savedmodel/fingerprint.pb (1 hunks)
- source/tests/infer/deeppot_sea.savedmodel/.gitignore (1 hunks)
- source/tests/infer/deeppot_sea.savedmodel/fingerprint.pb (1 hunks)
✅ Files skipped from review due to trivial changes (4)
- source/tests/infer/deeppot_dpa.savedmodel/.gitignore
- source/tests/infer/deeppot_dpa.savedmodel/fingerprint.pb
- source/tests/infer/deeppot_sea.savedmodel/.gitignore
- source/tests/infer/deeppot_sea.savedmodel/fingerprint.pb
🧰 Additional context used
🪛 LanguageTool
doc/backend.md
[typographical] ~35-~35: The conjunction “so that” does not require a comma.
Context: ... interface. The model is device-specific, so that the model generated on the GPU device c...
(SO_THAT_UNNECESSARY_COMMA)
🪛 cppcheck
source/api_cc/src/DeepPotJAX.cc
[performance] 29-29: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 64-64: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 66-66: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 86-86: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 88-88: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 111-111: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 113-113: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 133-133: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 135-135: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 37-37: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.
(uselessCallsSubstr)
🔇 Additional comments (16)
source/cmake/googletest.cmake.in (1)
14-14
: LGTM! Upgrading Google Test to v1.14.0.
The upgrade from release-1.12.1 to v1.14.0 is a good improvement that brings support for C++20 and various testing enhancements. This is a backward-compatible change.
Let's verify if there are any test files using C++20 features that would benefit from this upgrade:
✅ Verification successful
Upgrade to Google Test v1.14.0 is safe and beneficial
The upgrade from release-1.12.1 to v1.14.0 is safe. While we found no direct usage of C++20 features in test files, the codebase contains test utilities and the upgrade provides better testing infrastructure with improved death test implementation and custom printer support, which are valuable for the existing test suite.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Look for C++20 features in test files
# Expected: Files using C++20 features would benefit from this upgrade
# Search for potential C++20 features in test files
rg -l "concept|requires|constinit|consteval|co_await|co_yield|co_return|ranges::|views::|<span>|<concepts>" --type cpp
Length of output: 197
Script:
#!/bin/bash
# Let's check if these files are actually test files and examine their C++20 usage
rg -C3 "concept|requires|constinit|consteval|co_await|co_yield|co_return|ranges::|views::|<span>|<concepts>" source/lmp/pppm_dplr.cpp source/lmp/pair_deepmd.cpp source/3rdparty/json.hpp
# Also let's check if there are any actual test files in the repository
fd "test.*\.(cpp|h|hpp)$"
Length of output: 11735
doc/model/dpa2.md (1)
23-23
: Verify the MPI rank limitation.
Let's confirm that the JAX backend indeed doesn't support multiple MPI ranks.
✅ Verification successful
The test files show clear evidence that MPI is not supported in the JAX backend. Specifically, in source/lmp/tests/test_lammps_dpa_jax.py
, there's an explicit skip decorator @pytest.mark.skip("MPI is not supported")
for the MPI test. Let me verify if there are any additional implementation details or documentation about this limitation.
MPI rank limitation for JAX backend is verified
The limitation is confirmed through multiple pieces of evidence in the codebase:
- Explicit test skip with message "MPI is not supported" in source/lmp/tests/test_lammps_dpa_jax.py
- Documentation in doc/model/dpa2.md correctly states the limitation
- No MPI-specific implementation found in the JAX backend code
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Search for JAX-related MPI rank limitations in the codebase
# Look for JAX backend implementation files
echo "Searching for JAX backend files..."
fd -e py -e cpp -e h "jax|JAX"
# Search for MPI-related code in JAX implementation
echo "Searching for MPI-related code in JAX implementation..."
rg -i "mpi|rank" $(fd -e py -e cpp -e h "jax|JAX")
# Look for test files that might verify this limitation
echo "Searching for relevant test files..."
fd -e py "test.*jax"
Length of output: 2116
Script:
#!/bin/bash
# Search for documentation or comments about JAX MPI limitations
rg -i "jax.*mpi|mpi.*jax" --type md --type py --type cpp
# Look for any JAX-specific configuration or initialization code
ast-grep --pattern 'def $_($$$ jax $$$) {
$$$
}'
Length of output: 301
.pre-commit-config.yaml (1)
18-19
: LGTM! Please verify if other model files need exclusion.
The addition of deeppot_dpa.savedmodel/saved_model.pb
to the exclusion list is appropriate for the model file. However, let's verify if there are other similar model files that should also be excluded.
✅ Verification successful
Exclusion pattern is correctly configured
The verification shows that source/tests/infer/deeppot_dpa.savedmodel/saved_model.pb
(1.5M) is the only large .pb
file in the repository. The current exclusion in .pre-commit-config.yaml
correctly addresses this file, and no other model files need to be excluded at this time.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Check for other large .pb files that might need exclusion
# Expected: List of .pb files larger than 1024KB that aren't already excluded
# Find all .pb files
echo "Large .pb files (>1024KB) that might need exclusion:"
fd -e pb -x sh -c '
size=$(stat -f%z {} 2>/dev/null || stat -c%s {});
if [ "$size" -gt 1048576 ]; then
echo "{} ($(numfmt --to=iec-i --suffix=B $size))";
fi
'
Length of output: 485
Script:
#!/bin/bash
# Let's try a different approach using find and ls
echo "Large .pb files (>1024KB) that might need exclusion:"
find . -type f -name "*.pb" -exec ls -lh {} \; | awk '{if($5 ~ /[0-9.]+[MG]/) print $9, "("$5")"}'
Length of output: 284
doc/install/install-from-source.md (1)
343-343
: LGTM!
The CMake configuration section correctly indicates that JAX uses the same configuration process as TensorFlow, which is consistent with their shared dependency on TensorFlow's C++ library.
source/api_c/include/c_api.h (2)
Line range hint 34-71
: Documentation enhancement improves clarity.
The expanded documentation for DP_NewNlist_comm
provides clearer context about its parallel communication capabilities and parameter usage.
15-15
: API version increment is appropriate.
The increment from 23 to 24 correctly reflects the addition of new functionality (DP_NlistSetMapping
).
source/api_c/src/c_api.cc (1)
46-48
: Verify consistent error handling across the codebase.
Let's check if other neighbor list functions follow similar error handling patterns.
source/api_c/include/deepmd.hpp (1)
622-622
: LGTM: Implementation follows established patterns.
The implementation correctly delegates to the C API function and follows the same pattern as other methods in the class.
source/api_cc/include/DeepPotJAX.h (2)
1-249
: LGTM!
The DeepPotJAX
class is well-structured and follows the existing design patterns in the codebase. It provides the necessary functionality to integrate JAX models into the DeePMD-kit framework. The class inherits from DeepPotBase
, ensuring compatibility with the existing API. The constructors, destructor, and initialization methods are properly defined, allowing for flexible instantiation of the class.
The class includes methods to retrieve important model properties such as cutoff radius, number of types, and parameter dimensions. It also provides a get_type_map
method to access the type map information. The computew
and computew_mixed_type
methods are overloaded to handle different input types (double
and float
) and enable the computation of energy, force, virial, and atomic properties based on the provided input data.
The private member variables are appropriately named and store relevant information such as initialization status, device information, model parameters, and TensorFlow C API objects for graph and session management.
Error handling is implemented using assertions to ensure that methods are called only after initialization, preventing potential runtime errors.
The class is well-documented with Doxygen-style comments, providing clear explanations of the purpose and parameters of each method.
Overall, the DeepPotJAX
class is a solid addition to the DeePMD-kit framework, enabling seamless integration of JAX models.
4-5
: Verify the TensorFlow C API header inclusion.
Please ensure that the TensorFlow C API headers tensorflow/c/c_api.h
and tensorflow/c/eager/c_api.h
are available in the include path and that the TensorFlow library is properly linked to the project.
source/api_cc/src/DeepPotJAX.cc (1)
1-581
: Comprehensive implementation of the DeepPotJAX class for JAX integration.
The file provides a thorough implementation of the DeepPotJAX
class, enabling seamless integration of the JAX backend within the deepmd-kit framework. It covers essential aspects such as:
- Loading and initializing the JAX model
- Managing TensorFlow resources and contexts
- Implementing the core compute method for energy, force, and virial calculations
- Handling different data types and input formats
- Ensuring proper memory management and cleanup
The code follows best practices, including:
- Appropriate error handling and status checks
- Efficient memory management using smart pointers and RAII
- Templated functions for generic programming
- Consistent naming conventions and code style
Overall, the implementation is well-structured, modular, and adheres to the SOLID principles, promoting maintainability and extensibility.
🧰 Tools
🪛 cppcheck
[performance] 29-29: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 64-64: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 66-66: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 86-86: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 88-88: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 111-111: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 113-113: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 133-133: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 135-135: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 37-37: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.
(uselessCallsSubstr)
source/lmp/tests/test_lammps_jax.py (1)
677-682: Verify the effectiveness of @pytest.mark.skipif decorators
The @pytest.mark.skipif decorators check for mpirun and mpi4py at import time. If the environment changes between import and execution, tests may not skip as intended.
Run the following script to ensure that tests are skipped appropriately:
✅ Verification successful
Skip conditions for MPI tests are correctly implemented
The skip conditions are properly implemented for the following reasons:
- The shutil.which("mpirun") check is appropriate as it verifies the MPI executable's presence in the system PATH
- The importlib.util.find_spec("mpi4py") check is the recommended way to verify Python package availability
- Both checks are used together, ensuring all MPI dependencies are available before test execution
- The test function itself uses these dependencies (the mpirun command and, implicitly, mpi4py through the MPI script), confirming the necessity of these checks
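The two checks described above can be reproduced outside pytest to see what they evaluate to on a given machine; this is a sketch of the check logic itself, not the test file:

```python
import importlib.util
import shutil

# shutil.which returns the executable's full path, or None when it is
# absent from PATH
mpirun_path = shutil.which("mpirun")

# importlib.util.find_spec returns a ModuleSpec, or None when the package
# cannot be imported -- without actually importing it
mpi4py_spec = importlib.util.find_spec("mpi4py")

# the MPI test runs only when both checks pass, mirroring the stacked
# @pytest.mark.skipif decorators in the test file
mpi_available = mpirun_path is not None and mpi4py_spec is not None
print("MPI test would run:", mpi_available)
```

Because both calls are evaluated at decorator time (module import), the result is frozen for the whole test session, which is exactly the import-time caveat noted in the review.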
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Verify that the skip conditions are evaluated correctly.
# Expected behavior: Tests should skip if `mpirun` or `mpi4py` are not available.
pytest source/lmp/tests/test_lammps_jax.py -v -k test_pair_deepmd_mpi
Length of output: 145
Script:
#!/bin/bash
# Let's analyze the test file to understand the skip conditions and test implementation
# First, check the test function implementation
rg "def test_pair_deepmd_mpi" source/lmp/tests/test_lammps_jax.py -A 10
# Then, check all skipif decorators in the file
rg "@pytest.mark.skipif" source/lmp/tests/test_lammps_jax.py -A 2
# Check if there are any other MPI-related tests and their skip conditions
rg "mpi" source/lmp/tests/test_lammps_jax.py
Length of output: 1039
source/lmp/tests/test_lammps_dpa_jax.py (2)
680-684: Correct the MPI skip condition and message
The skip condition for the MPI tests may not be accurate, and the skip message "MPI is not supported" contradicts the previous checks for MPI installation.
Apply this diff to correct the skip condition and message:
@pytest.mark.skipif(
shutil.which("mpirun") is None, reason="MPI is not installed on this system"
)
@pytest.mark.skipif(
importlib.util.find_spec("mpi4py") is None, reason="mpi4py is not installed"
)
-@pytest.mark.skip("MPI is not supported")
+@pytest.mark.skip(reason="MPI is currently not supported in this test")
def test_pair_deepmd_mpi(balance_args: list):
Alternatively, consider enabling the test if MPI support is now available.
Likely invalid or redundant comment.
356-364: Verify the correctness of model deviation calculations
In the function test_pair_deepmd_model_devi, the computation of expected_md_v uses np.sum(expected_v, axis=0), which sums over atoms, possibly leading to incorrect virial deviation calculations.
Please ensure that the virial deviations are calculated correctly per atom rather than summing over atoms.
Run the following script to verify the computation:
If the virial deviations should not sum over atoms, consider modifying the calculation accordingly.
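The distinction between summing over atoms before or after taking the deviation can be made concrete with a small NumPy illustration. The array shapes (2 models, 3 atoms, 9 virial components) are invented for the example and are not taken from the test:

```python
import numpy as np

rng = np.random.default_rng(0)
# per-model, per-atom virials: shape (nmodels, natoms, 9)
expected_v = rng.standard_normal((2, 3, 9))

# summing over atoms first collapses the per-atom information before
# the deviation across models is taken ...
devi_summed = expected_v.sum(axis=1).std(axis=0)  # shape (9,)

# ... whereas taking the deviation first keeps one value per atom
devi_per_atom = expected_v.std(axis=0)            # shape (3, 9)

print(devi_summed.shape, devi_per_atom.shape)
```

The two quantities have different shapes and, in general, different values, so which axis is reduced first determines what "virial deviation" the test actually checks.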
source/api_cc/src/DeepPot.cc (2)
10-10: Including DeepPotJAX.h under BUILD_TENSORFLOW
Including DeepPotJAX.h within the #ifdef BUILD_TENSORFLOW block aligns with the dependency that JAX support requires TensorFlow. This ensures that JAX-related code is only compiled when TensorFlow support is enabled.
65-72
: Appropriate initialization and error handling for JAX backend
The code correctly initializes the JAX backend when TensorFlow support is available and provides clear exception handling when it is not. This ensures users are informed about the dependency on TensorFlow for loading JAX2TF SavedModels.
@@ -13,7 +13,7 @@
 namespace deepmd {

 typedef double ENERGYTYPE;
-enum DPBackend { TensorFlow, PyTorch, Paddle, Unknown };
+enum DPBackend { TensorFlow, PyTorch, Paddle, JAX, Unknown };
💡 Codebase verification
Missing JAX case in some backend handlers
The verification reveals that the JAX backend case is missing in two files:
- source/api_cc/src/DeepTensor.cc: JAX case not handled in if-else chain
- source/api_cc/src/DataModifier.cc: JAX case not handled in if-else chain
While DeepPot.cc properly handles the JAX backend, the other backend handlers need to be updated for consistency.
🔗 Analysis chain
LGTM! Verify switch statements for the new backend.
The addition of JAX to the DPBackend enum is clean and properly placed before the Unknown value.
Let's verify that all switch statements handling DPBackend are updated to include the JAX case:
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Search for switch statements or if-else chains handling DPBackend
rg -A 10 "switch.*DPBackend|if.*DPBackend.*==" source/
Length of output: 4502
/**
 * @brief Set mapping for this neighbor list.
 */
void set_mapping(int* mapping_) { mapping = mapping_; };
🛠️ Refactor suggestion
Enhance method robustness and documentation.
The set_mapping method needs:
- Parameter validation
- Clear ownership semantics
- Size validation
- More detailed documentation
- /**
- * @brief Set mapping for this neighbor list.
- */
- void set_mapping(int* mapping_) { mapping = mapping_; };
+ /**
+ * @brief Set mapping from all atoms to real atoms.
+ * @param mapping_ Pointer to mapping array of size nall. Must not be nullptr.
+ * @param size Size of the mapping array.
+ * @throws std::invalid_argument if mapping_ is nullptr or size is invalid.
+ * @note The class takes ownership of the mapping array.
+ */
+ void set_mapping(int* mapping_, size_t size) {
+ if (!mapping_ || size == 0) {
+ throw std::invalid_argument("Invalid mapping array");
+ }
+ mapping.reset(mapping_);
+ mapping_size = size;
+ };
Committable suggestion skipped: line range outside the PR's diff.
TYPED_TEST(TestInferDeepPotAJAX, cpu_lmp_nlist) {
  using VALUETYPE = TypeParam;
  std::vector<VALUETYPE>& coord = this->coord;
  std::vector<int>& atype = this->atype;
  std::vector<VALUETYPE>& box = this->box;
  std::vector<VALUETYPE>& expected_e = this->expected_e;
  std::vector<VALUETYPE>& expected_f = this->expected_f;
  std::vector<VALUETYPE>& expected_v = this->expected_v;
  int& natoms = this->natoms;
  double& expected_tot_e = this->expected_tot_e;
  std::vector<VALUETYPE>& expected_tot_v = this->expected_tot_v;
  deepmd::DeepPot& dp = this->dp;
  float rc = dp.cutoff();
  int nloc = coord.size() / 3;
  std::vector<VALUETYPE> coord_cpy;
  std::vector<int> atype_cpy, mapping;
  std::vector<std::vector<int> > nlist_data;
  _build_nlist<VALUETYPE>(nlist_data, coord_cpy, atype_cpy, mapping, coord,
                          atype, box, rc);
  int nall = coord_cpy.size() / 3;
  std::vector<int> ilist(nloc), numneigh(nloc);
  std::vector<int*> firstneigh(nloc);
  deepmd::InputNlist inlist(nloc, &ilist[0], &numneigh[0], &firstneigh[0]);
  convert_nlist(inlist, nlist_data);

  double ener;
  std::vector<VALUETYPE> force_, virial;
  dp.compute(ener, force_, virial, coord_cpy, atype_cpy, box, nall - nloc,
             inlist, 0);
  std::vector<VALUETYPE> force;
  _fold_back<VALUETYPE>(force, force_, mapping, nloc, nall, 3);

  EXPECT_EQ(force.size(), natoms * 3);
  EXPECT_EQ(virial.size(), 9);

  EXPECT_LT(fabs(ener - expected_tot_e), EPSILON);
  for (int ii = 0; ii < natoms * 3; ++ii) {
    EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON);
  }
  for (int ii = 0; ii < 3 * 3; ++ii) {
    EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON);
  }

  ener = 0.;
  std::fill(force_.begin(), force_.end(), 0.0);
  std::fill(virial.begin(), virial.end(), 0.0);
  dp.compute(ener, force_, virial, coord_cpy, atype_cpy, box, nall - nloc,
             inlist, 1);
  _fold_back<VALUETYPE>(force, force_, mapping, nloc, nall, 3);

  EXPECT_EQ(force.size(), natoms * 3);
  EXPECT_EQ(virial.size(), 9);

  EXPECT_LT(fabs(ener - expected_tot_e), EPSILON);
  for (int ii = 0; ii < natoms * 3; ++ii) {
    EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON);
  }
  for (int ii = 0; ii < 3 * 3; ++ii) {
    EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON);
  }
}
🛠️ Refactor suggestion
Reduce code duplication in test cases.
The test cases share significant amounts of code for setup and validation. This makes maintenance harder and increases the chance of inconsistencies.
Consider extracting common setup and validation code into helper methods:
template <typename VALUETYPE>
void ValidateResults(
const std::vector<VALUETYPE>& force,
const std::vector<VALUETYPE>& virial,
const std::vector<VALUETYPE>& expected_f,
const std::vector<VALUETYPE>& expected_tot_v,
double ener,
double expected_tot_e,
int natoms) {
EXPECT_EQ(force.size(), natoms * 3);
EXPECT_EQ(virial.size(), 9);
EXPECT_LT(fabs(ener - expected_tot_e), EPSILON);
for (int ii = 0; ii < natoms * 3; ++ii) {
EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON);
}
for (int ii = 0; ii < 3 * 3; ++ii) {
EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON);
}
}
Also applies to: 159-240, 242-302, 304-364, 366-427
TYPED_TEST(TestInferDeepPotAJAX, print_summary) {
  deepmd::DeepPot& dp = this->dp;
  dp.print_summary("");
}
🛠️ Refactor suggestion
Enhance print_summary test case.
The print_summary test case doesn't verify the output. It should validate that the summary contains expected information. Consider capturing and validating the output:
TYPED_TEST(TestInferDeepPotAJAX, print_summary) {
deepmd::DeepPot& dp = this->dp;
- dp.print_summary("");
+ testing::internal::CaptureStdout();
+ dp.print_summary("");
+ std::string output = testing::internal::GetCapturedStdout();
+
+ EXPECT_THAT(output, testing::HasSubstr("Summary of DeepPot Model"));
+ EXPECT_THAT(output, testing::HasSubstr("Cutoff radius"));
+ EXPECT_THAT(output, testing::HasSubstr("Model type"));
}
void SetUp() override {
  std::string file_name = "../../tests/infer/deeppot_sea.savedmodel";

  dp.init(file_name);

  natoms = expected_e.size();
  EXPECT_EQ(natoms * 3, expected_f.size());
  EXPECT_EQ(natoms * 9, expected_v.size());
  expected_tot_e = 0.;
  expected_tot_v.resize(9);
  std::fill(expected_tot_v.begin(), expected_tot_v.end(), 0.);
  for (int ii = 0; ii < natoms; ++ii) {
    expected_tot_e += expected_e[ii];
  }
  for (int ii = 0; ii < natoms; ++ii) {
    for (int dd = 0; dd < 9; ++dd) {
      expected_tot_v[dd] += expected_v[ii * 9 + dd];
    }
  }
}
🛠️ Refactor suggestion
Add error handling in SetUp method.
The SetUp method should handle potential errors when initializing the DeepPot model and validating test data. Consider adding error handling:
void SetUp() override {
std::string file_name = "../../tests/infer/deeppot_sea.savedmodel";
-
- dp.init(file_name);
+ try {
+ dp.init(file_name);
+ } catch (const std::exception& e) {
+ FAIL() << "Failed to initialize DeepPot model: " << e.what();
+ }
natoms = expected_e.size();
- EXPECT_EQ(natoms * 3, expected_f.size());
- EXPECT_EQ(natoms * 9, expected_v.size());
+ ASSERT_EQ(natoms * 3, expected_f.size()) << "Force array size mismatch";
+ ASSERT_EQ(natoms * 9, expected_v.size()) << "Virial array size mismatch";
void deepmd::DeepPotJAX::init(const std::string& model,
                              const int& gpu_rank,
                              const std::string& file_content) {
  if (inited) {
    std::cerr << "WARNING: deepmd-kit should not be initialized twice, do "
                 "nothing at the second call of initializer"
              << std::endl;
    return;
  }

  const char* saved_model_dir = model.c_str();
  graph = TF_NewGraph();
  status = TF_NewStatus();

  sessionopts = TF_NewSessionOptions();
  TF_Buffer* runopts = NULL;

  const char* tags = "serve";
  int ntags = 1;

  session = TF_LoadSessionFromSavedModel(sessionopts, runopts, saved_model_dir,
                                         &tags, ntags, graph, NULL, status);
  check_status(status);

  int nfuncs = TF_GraphNumFunctions(graph);
  // allocate memory for the TF_Function* array
  func_vector.resize(nfuncs);
  TF_Function** funcs = func_vector.data();
  TF_GraphGetFunctions(graph, funcs, nfuncs, status);
  check_status(status);

  ctx_opts = TFE_NewContextOptions();
  ctx = TFE_NewContext(ctx_opts, status);
  check_status(status);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  int gpu_num;
  DPGetDeviceCount(gpu_num);  // check current device environment
  DPErrcheck(DPSetDevice(gpu_rank % gpu_num));
  if (gpu_num > 0) {
    device = "/gpu:" + std::to_string(gpu_rank % gpu_num);
  } else {
    device = "/cpu:0";
  }
#else
  device = "/cpu:0";
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

  rcut = get_scalar<double>(ctx, "get_rcut", func_vector, device, status);
  dfparam =
      get_scalar<int64_t>(ctx, "get_dim_fparam", func_vector, device, status);
  daparam =
      get_scalar<int64_t>(ctx, "get_dim_aparam", func_vector, device, status);
  std::vector<std::string> type_map_ =
      get_vector_string(ctx, "get_type_map", func_vector, device, status);
  // deepmd-kit stores type_map as a concatenated string, split by ' '
  type_map = type_map_[0];
  for (size_t i = 1; i < type_map_.size(); i++) {
    type_map += " " + type_map_[i];
  }
  ntypes = type_map_.size();
  sel = get_vector<int64_t>(ctx, "get_sel", func_vector, device, status);
  nnei = std::accumulate(sel.begin(), sel.end(), decltype(sel)::value_type(0));
  inited = true;
}
🛠️ Refactor suggestion
Use std::unique_ptr to manage TensorFlow resources.
In the init function, consider using std::unique_ptr with custom deleters to manage TensorFlow resources such as TF_Graph, TF_Status, TF_SessionOptions, and TFE_ContextOptions. This will ensure proper cleanup of these resources and make the code more exception-safe.
```cpp
void deepmd::DeepPotJAX::init(const std::string& model,
                              const int& gpu_rank,
                              const std::string& file_content) {
  if (inited) {
    std::cerr << "WARNING: deepmd-kit should not be initialized twice, do "
                 "nothing at the second call of initializer"
              << std::endl;
    return;
  }
  const char* saved_model_dir = model.c_str();
  auto graph = std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)>(
      TF_NewGraph(), TF_DeleteGraph);
  auto status = std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)>(
      TF_NewStatus(), TF_DeleteStatus);
  auto sessionopts =
      std::unique_ptr<TF_SessionOptions, decltype(&TF_DeleteSessionOptions)>(
          TF_NewSessionOptions(), TF_DeleteSessionOptions);
  TF_Buffer* runopts = nullptr;
  const char* tags = "serve";
  int ntags = 1;
  session = TF_LoadSessionFromSavedModel(sessionopts.get(), runopts,
                                         saved_model_dir, &tags, ntags,
                                         graph.get(), nullptr, status.get());
  check_status(status.get());
  int nfuncs = TF_GraphNumFunctions(graph.get());
  func_vector.resize(nfuncs);
  TF_GraphGetFunctions(graph.get(), func_vector.data(), nfuncs, status.get());
  check_status(status.get());
  auto ctx_opts =
      std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)>(
          TFE_NewContextOptions(), TFE_DeleteContextOptions);
  ctx = TFE_NewContext(ctx_opts.get(), status.get());
  check_status(status.get());
  // ... (rest of the code remains the same)
}
```
```cpp
deepmd::DeepPotJAX::~DeepPotJAX() {
  if (inited) {
    TF_DeleteSession(session, status);
    TF_DeleteGraph(graph);
    TF_DeleteSessionOptions(sessionopts);
    TF_DeleteStatus(status);
    TFE_DeleteContext(ctx);
    TFE_DeleteContextOptions(ctx_opts);
    for (size_t i = 0; i < func_vector.size(); i++) {
      TF_DeleteFunction(func_vector[i]);
    }
  }
}
```
🛠️ Refactor suggestion

**Use `std::unique_ptr` to manage TensorFlow resources in the destructor.**

In the `DeepPotJAX` destructor, consider using `std::unique_ptr` with custom deleters to manage the cleanup of TensorFlow resources. This will ensure proper cleanup and make the code more readable and maintainable.
```python
lammps.pair_style(
    f"deepmd {pb_file.resolve()} {pb_file2.resolve()} out_file {md_file.resolve()} out_freq 1 atomic"
)
```
**Remove temporary model deviation file after tests**

The `md_file` is created during tests but not removed afterward, which may clutter the file system with temporary files.

Add cleanup code to `teardown_module()`:

```diff
 def teardown_module():
     os.remove(data_file)
     os.remove(data_type_map_file)
+    os.remove(md_file)
```

Committable suggestion skipped: line range outside the PR's diff.
```python
def teardown_module():
    os.remove(data_file)
    os.remove(data_type_map_file)
```
**Ensure all generated data files are removed in `teardown_module()`**

The `teardown_module()` function currently removes `data_file` and `data_type_map_file` but does not remove `data_file_si`. This may leave temporary files after tests are run.

Apply this diff to remove `data_file_si` during teardown:

```diff
 def teardown_module():
     os.remove(data_file)
     os.remove(data_type_map_file)
+    os.remove(data_file_si)
```

Committable suggestion skipped: line range outside the PR's diff.
```python
def setup_module():
    write_lmp_data(box, coord, type_OH, data_file)
    write_lmp_data(box, coord, type_HO, data_type_map_file)
    write_lmp_data(
        box * constants.dist_metal2si,
        coord * constants.dist_metal2si,
        type_OH,
        data_file_si,
    )
```
**Add missing teardown for `data_file_si`**

The `setup_module` function creates `data_file_si`, but `teardown_module` does not remove it, which may leave residual files after tests are run.

Apply this diff to remove `data_file_si` during teardown:

```diff
 def teardown_module():
     os.remove(data_file)
     os.remove(data_type_map_file)
+    os.remove(data_file_si)
```

📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
def setup_module():
    write_lmp_data(box, coord, type_OH, data_file)
    write_lmp_data(box, coord, type_HO, data_type_map_file)
    write_lmp_data(
        box * constants.dist_metal2si,
        coord * constants.dist_metal2si,
        type_OH,
        data_file_si,
    )


def teardown_module():
    os.remove(data_file)
    os.remove(data_type_map_file)
    os.remove(data_file_si)
```
**Codecov Report**

Attention: Patch coverage is

Additional details and impacted files:

```
@@           Coverage Diff            @@
##            devel    #4316    +/-  ##
=========================================
- Coverage   84.40%   84.31%   -0.09%
=========================================
  Files         570      573      +3
  Lines       53071    53581    +510
  Branches     3054     3110     +56
=========================================
+ Hits        44794    45177    +383
- Misses       7318     7429    +111
- Partials      959      975     +16
```

☔ View full report in Codecov by Sentry.
📝 Walkthrough

This pull request introduces several updates across multiple files, primarily focusing on the integration of the JAX backend into the DeePMD-kit framework. Key changes include modifications to pre-commit configurations to exclude specific files, enhancements to documentation regarding backend support, and the addition of new functionalities in the C API and related classes. The JAX backend is now fully supported, with corresponding updates to the installation instructions, API headers, and various source files to ensure compatibility and functionality across the codebase.

Changes
Possibly related issues
Possibly related PRs
Suggested reviewers
Actionable comments posted: 6
🧹 Outside diff range and nitpick comments (22)
source/cmake/googletest.cmake.in (1)

Line range hint `1-1`: **Consider updating minimum CMake version requirement.**

The current minimum CMake version (2.8.2) is quite old. Consider updating to a more recent version (e.g., 3.10 or later) to leverage modern CMake features and best practices.
doc/model/dpa2.md (1)

`21-27`: **Documentation looks good, but could be more detailed.**

The new section clearly documents the limitations and provides the necessary command. Consider enhancing it with:

- A brief explanation of why these limitations exist
- Potential performance implications of using `map yes`
- A more prominent warning about the MPI rank limitation

````diff
 ## Limiations of the JAX backend with LAMMPS {{ jax_icon }}
-When using the JAX backend, 2 or more MPI ranks are not supported. One must set `map` to `yes` using the [`atom_modify`](https://docs.lammps.org/atom_modify.html) command.
+:::warning
+When using the JAX backend:
+- 2 or more MPI ranks are not supported due to JAX's current parallel processing limitations
+- The `map` option must be set to `yes` using the [`atom_modify`](https://docs.lammps.org/atom_modify.html) command to ensure proper atom indexing
+
+Note: Using `map yes` may have a small performance overhead but is required for correct operation with the JAX backend.
+:::
 ```lammps
 atom_modify map yes
 ```
````
</blockquote></details>
<details>
<summary>doc/backend.md (2)</summary><blockquote>

`34-36`: **Fix typographical error and enhance clarity of limitations.**

The limitations section for the JAX backend has a minor typographical issue and could be clearer about the implications. Apply these changes:

```diff
-The model is device-specific, so that the model generated on the GPU device cannot be run on the CPUs.
+The model is device-specific, so models generated on GPU devices cannot be run on CPUs.
```
🧰 Tools
🪛 LanguageTool
[typographical] ~35-~35: The conjunction “so that” does not require a comma.
Context: ... interface. The model is device-specific, so that the model generated on the GPU device c...(SO_THAT_UNNECESSARY_COMMA)
`34-36`: **Consider adding migration guidance for users.**

The limitations described (C++ inference only with SavedModel, device specificity, and no training support) are significant. Consider adding guidance for users migrating from other backends.

Would you like me to help draft a migration guide section that covers:

- When to use JAX vs other backends
- Migration paths from TensorFlow/PyTorch
- Best practices for handling device-specific models
source/lib/include/neighbor_list.h (1)

`47-48`: **Enhance documentation for the `mapping` member.**

While the basic purpose is documented, please add more details about:

- Memory ownership (who allocates/deallocates)
- Size requirements (how "nall" is determined)
- Expected values in the mapping array
doc/install/install-from-source.md (2)
300-302
: **Enhance clarity regarding JAX backend requirements and dependencies.**

The current documentation groups TensorFlow and JAX backends together, which might be confusing. Consider:
- Separating JAX into its own tab-item section to clearly distinguish it from TensorFlow
- Adding JAX-specific version requirements
- Clarifying why JAX backend needs TensorFlow C++ library
```diff
-:::{tab-item} TensorFlow {{ tensorflow_icon }} / JAX {{ jax_icon }}
+:::{tab-item} TensorFlow {{ tensorflow_icon }}
+
+:::
+
+:::{tab-item} JAX {{ jax_icon }}
+
+Note: The JAX backend requires JAX version 0.4.33 or above and uses TensorFlow C++ library for its C++ interface implementation.
```
396-396
: **Clarify `TENSORFLOW_ROOT` usage for JAX backend.**

The description should better explain why this path is needed for the JAX backend and how it's used.
```diff
-{{ tensorflow_icon }} {{ jax_icon }} The Path to TensorFlow's C++ interface.
+{{ tensorflow_icon }} {{ jax_icon }} The Path to TensorFlow's C++ interface. This is required for both TensorFlow and JAX backends as JAX's C++ interface is implemented using TensorFlow's C++ libraries.
```

source/api_c/include/c_api.h (1)
81-89
: **Documentation could be more detailed.**

While the function documentation follows the established format, consider enhancing it with:
- Description of the expected size and format of the mapping array
- Documentation of the return value (void)
- Example usage or typical use case
```diff
 /**
  * @brief Set mapping for a neighbor list.
  *
- * @param nl Neighbor list.
- * @param mapping mapping from all atoms to real atoms, in size nall.
+ * @param[in] nl Pointer to the neighbor list to be modified.
+ * @param[in] mapping Array of size nall that maps from all atoms to real atoms.
+ * @return void
+ * @note Typical use case: When working with ghost atoms or periodic boundary
+ * conditions, this mapping helps translate between local and global atom
+ * indices.
  * @since API version 24
  *
  **/
```

source/api_c/include/deepmd.hpp (1)
618-622
: **Enhance documentation for the `set_mapping` method.**

The implementation looks good, but the documentation could be more detailed to clarify:
mapping
pointer- Whether nullptr is a valid input
- The expected size of the mapping array
Consider updating the documentation like this:
/** * @brief Set mapping for this neighbor list. - * @param mapping mapping from all atoms to real atoms, in size nall. + * @param mapping Pointer to an array that maps from all atoms to real atoms. The array size must match the total number of atoms (nall). + * The pointer must remain valid for the lifetime of the neighbor list. + * A nullptr can be passed to reset/clear the mapping. */source/api_cc/include/DeepPotJAX.h (2)
46-49
: Add a const qualifier to thecutoff()
method.Since the
cutoff()
method does not modify the object's state and only returns thercut
member variable, consider adding theconst
qualifier to the method to indicate that it is a read-only operation. This helps improve code clarity and allows the method to be called on const objects.-double cutoff() const { +double cutoff() const { assert(inited); return rcut; };
233-247
: Consider usingstd::size_t
for array indexing and sizes.In the
compute
template method, consider usingstd::size_t
instead ofint
for variables that represent array sizes or indices, such asnghost
andago
. This ensures compatibility with the size type returned bystd::vector::size()
and avoids potential issues with signed/unsigned conversions.template <typename VALUETYPE> void compute(std::vector<ENERGYTYPE>& ener, std::vector<VALUETYPE>& force, std::vector<VALUETYPE>& virial, std::vector<VALUETYPE>& atom_energy, std::vector<VALUETYPE>& atom_virial, const std::vector<VALUETYPE>& coord, const std::vector<int>& atype, const std::vector<VALUETYPE>& box, - const int nghost, + const std::size_t nghost, const InputNlist& lmp_list, - const int& ago, + const std::size_t& ago, const std::vector<VALUETYPE>& fparam, const std::vector<VALUETYPE>& aparam, const bool atomic);source/api_cc/src/DeepPotJAX.cc (6)
27-45
: Consider passingfunc_name
by const reference for better performance.To avoid unnecessary string copying, consider passing the
func_name
parameter by const reference:-inline void find_function(TF_Function*& found_func, - const std::vector<TF_Function*>& funcs, - const std::string func_name) { +inline void find_function(TF_Function*& found_func, + const std::vector<TF_Function*>& funcs, + const std::string& func_name) {🧰 Tools
🪛 cppcheck
[performance] 29-29: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 37-37: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.
(uselessCallsSubstr)
37-38
: Consider usingresize()
orpop_back()
instead ofsubstr()
for better performance.Assigning a prefix of the string to itself using
substr()
is ineffective. Consider usingresize()
orpop_back()
instead for better performance:- name_ = name_.substr(0, pos + 1); + name_.resize(pos + 1);🧰 Tools
🪛 cppcheck
[performance] 37-37: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.
(uselessCallsSubstr)
63-82
: Consider passingfunc_name
anddevice
by const reference for better performance.To avoid unnecessary string copying, consider passing the
func_name
anddevice
parameters by const reference:-inline TFE_Op* get_func_op(TFE_Context* ctx, - const std::string func_name, - const std::vector<TF_Function*>& funcs, - const std::string device, - TF_Status* status) { +inline TFE_Op* get_func_op(TFE_Context* ctx, + const std::string& func_name, + const std::vector<TF_Function*>& funcs, + const std::string& device, + TF_Status* status) {🧰 Tools
🪛 cppcheck
[performance] 64-64: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 66-66: Function parameter 'device' should be passed by const reference.
(passedByValue)
84-107
: Consider passingfunc_name
anddevice
by const reference for better performance.To avoid unnecessary string copying, consider passing the
func_name
anddevice
parameters by const reference:-template <typename T> -inline T get_scalar(TFE_Context* ctx, - const std::string func_name, - const std::vector<TF_Function*>& funcs, - const std::string device, - TF_Status* status) { +template <typename T> +inline T get_scalar(TFE_Context* ctx, + const std::string& func_name, + const std::vector<TF_Function*>& funcs, + const std::string& device, + TF_Status* status) {🧰 Tools
🪛 cppcheck
[performance] 86-86: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 88-88: Function parameter 'device' should be passed by const reference.
(passedByValue)
109-129
: Consider passingfunc_name
anddevice
by const reference for better performance.To avoid unnecessary string copying, consider passing the
func_name
anddevice
parameters by const reference:-template <typename T> -inline std::vector<T> get_vector(TFE_Context* ctx, - const std::string func_name, - const std::vector<TF_Function*>& funcs, - const std::string device, - TF_Status* status) { +template <typename T> +inline std::vector<T> get_vector(TFE_Context* ctx, + const std::string& func_name, + const std::vector<TF_Function*>& funcs, + const std::string& device, + TF_Status* status) {🧰 Tools
🪛 cppcheck
[performance] 111-111: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 113-113: Function parameter 'device' should be passed by const reference.
(passedByValue)
131-166
: Consider passingfunc_name
anddevice
by const reference for better performance.To avoid unnecessary string copying, consider passing the
func_name
anddevice
parameters by const reference:-inline std::vector<std::string> get_vector_string( - TFE_Context* ctx, - const std::string func_name, - const std::vector<TF_Function*>& funcs, - const std::string device, - TF_Status* status) { +inline std::vector<std::string> get_vector_string( + TFE_Context* ctx, + const std::string& func_name, + const std::vector<TF_Function*>& funcs, + const std::string& device, + TF_Status* status) {🧰 Tools
🪛 cppcheck
[performance] 133-133: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 135-135: Function parameter 'device' should be passed by const reference.
(passedByValue)
source/lmp/tests/test_lammps_jax.py (2)
246-276
: Consider adding error handling for file read operations.The
_lammps
function sets up a LAMMPS instance with the provided data file and units. Consider adding error handling for thelammps.read_data
operation to gracefully handle cases where the data file is missing or inaccessible.try: lammps.read_data(data_file.resolve()) except FileNotFoundError: raise FileNotFoundError(f"Data file not found: {data_file}") except Exception as e: raise RuntimeError(f"Error reading data file: {data_file}. {str(e)}")
687-723
: Consider adding error handling for the subprocess call.The
test_pair_deepmd_mpi
function runs therun_mpi_pair_deepmd.py
script using MPI and checks the potential energy and model deviation output against the expected values. Consider adding error handling for thesp.check_call
to gracefully handle cases where the script fails to run.try: sp.check_call(...) except sp.CalledProcessError as e: raise RuntimeError(f"Failed to run MPI script: {e}")source/lmp/tests/test_lammps_dpa_jax.py (1)
246-279
: Consider extracting common LAMMPS setup code into a separate function.The
_lammps
function contains a lot of common setup code for creating a LAMMPS instance with specific settings. Consider extracting this setup code into a separate function to improve readability and maintainability.source/lmp/pair_deepmd.cpp (2)
527-529
: Use consistent integer types in the loopIn the loop starting at line 527,
ii
is declared assize_t
, whilenall
is of typeint
. Mixing signed and unsigned integer types can lead to potential issues. Consider changingii
toint
to match the type ofnall
.Apply this diff to change the loop variable type:
-for (size_t ii = 0; ii < nall; ++ii) { +for (int ii = 0; ii < nall; ++ii) {
575-577
: **Refactor duplicate mapping logic into a helper function.**

The mapping logic in lines 575-577 and 588-590 is duplicated. To improve maintainability and reduce code duplication, consider refactoring this code into a helper function or lambda.
Also applies to: 588-590
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (22)
.pre-commit-config.yaml
(1 hunks)doc/backend.md
(1 hunks)doc/install/install-from-source.md
(4 hunks)doc/model/dpa2.md
(1 hunks)source/api_c/include/c_api.h
(4 hunks)source/api_c/include/deepmd.hpp
(1 hunks)source/api_c/src/c_api.cc
(1 hunks)source/api_cc/include/DeepPotJAX.h
(1 hunks)source/api_cc/include/common.h
(1 hunks)source/api_cc/src/DeepPot.cc
(3 hunks)source/api_cc/src/DeepPotJAX.cc
(1 hunks)source/api_cc/tests/test_deeppot_jax.cc
(1 hunks)source/cmake/googletest.cmake.in
(1 hunks)source/lib/include/neighbor_list.h
(2 hunks)source/lmp/fix_dplr.cpp
(2 hunks)source/lmp/pair_deepmd.cpp
(3 hunks)source/lmp/tests/test_lammps_dpa_jax.py
(1 hunks)source/lmp/tests/test_lammps_jax.py
(1 hunks)source/tests/infer/deeppot_dpa.savedmodel/.gitignore
(1 hunks)source/tests/infer/deeppot_dpa.savedmodel/fingerprint.pb
(1 hunks)source/tests/infer/deeppot_sea.savedmodel/.gitignore
(1 hunks)source/tests/infer/deeppot_sea.savedmodel/fingerprint.pb
(1 hunks)
✅ Files skipped from review due to trivial changes (4)
- source/tests/infer/deeppot_dpa.savedmodel/.gitignore
- source/tests/infer/deeppot_dpa.savedmodel/fingerprint.pb
- source/tests/infer/deeppot_sea.savedmodel/.gitignore
- source/tests/infer/deeppot_sea.savedmodel/fingerprint.pb
🧰 Additional context used
🪛 LanguageTool
doc/backend.md
[typographical] ~35-~35: The conjunction “so that” does not require a comma.
Context: ... interface. The model is device-specific, so that the model generated on the GPU device c...
(SO_THAT_UNNECESSARY_COMMA)
🪛 cppcheck
source/api_cc/src/DeepPotJAX.cc
[performance] 29-29: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 64-64: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 66-66: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 86-86: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 88-88: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 111-111: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 113-113: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 133-133: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 135-135: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 37-37: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.
(uselessCallsSubstr)
🔇 Additional comments (91)
source/cmake/googletest.cmake.in (1)
14-14
: Verify compatibility with Google Test v1.14.0.
The upgrade from release-1.12.1 to v1.14.0 looks good. This version brings improvements including C++20 support and various bug fixes.
Let's verify if there are any potential compatibility issues:
✅ Verification successful
Based on the search results, I can generate the final response since we have sufficient information:
Google Test v1.14.0 upgrade appears safe to proceed
The codebase search reveals:
- No usage of death tests (
ASSERT_DEATH
,EXPECT_DEATH
,ASSERT_EXIT
,EXPECT_EXIT
) that could be affected by the improvements in v1.14.0 - While there are files containing C++20 related keywords, they are primarily in documentation or non-test files
- The test files under
source/tests/
don't show any patterns that would be impacted by breaking changes between v1.12.1 and v1.14.0
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Check for Google Test usage patterns that might be affected by the version upgrade
# Test: Look for death test usage which had improvements in v1.14.0
rg "ASSERT_DEATH|EXPECT_DEATH|ASSERT_EXIT|EXPECT_EXIT"
# Test: Check for C++20 features usage that might benefit from the upgrade
rg -l "concept|requires|constinit|consteval"
Length of output: 2447
doc/model/dpa2.md (1)
23-23
: Verify MPI rank limitation in tests and code.
The documentation states that "2 or more MPI ranks are not supported". Let's verify this limitation is properly enforced in the codebase.
✅ Verification successful
MPI rank limitation is properly enforced in JAX tests
The codebase correctly enforces the MPI rank limitation for the JAX backend through test files. Specifically in source/lmp/tests/test_lammps_dpa_jax.py
, there is an explicit test skip with the message "MPI is not supported" for the JAX backend:
@pytest.mark.skip("MPI is not supported")
def test_pair_deepmd_mpi(balance_args: list):
This test skip enforces the documented limitation that "2 or more MPI ranks are not supported" when using the JAX backend. The implementation matches the documentation's requirements.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Check for MPI rank validation in JAX-related code and tests
# Look for MPI rank checks in JAX-related files
echo "Searching for MPI rank validation in source files..."
rg -l "MPI|mpi.*rank|n.*ranks" --type cpp --type python source/
# Look for test cases that verify this limitation
echo "Searching for related test cases..."
rg -l "test.*mpi.*rank|test.*jax.*mpi" --type python source/lmp/tests/
Length of output: 396
Script:
#!/bin/bash
# Let's try a different approach to find JAX and MPI related code
# First, let's find all Python and C++ files that mention JAX
echo "Finding JAX-related files..."
rg -l "jax" --type-add 'python:*.py' --type-add 'cpp:*.{cpp,h,hpp}'
# Then search for MPI rank validation in those files
echo -e "\nSearching for MPI/rank related content in JAX files..."
rg "MPI|mpi.*rank|n.*ranks|world_size" --type-add 'python:*.py' --type-add 'cpp:*.{cpp,h,hpp}' -C 2
# Look specifically for test files
echo -e "\nSearching for test files mentioning both JAX and MPI..."
rg -l "jax.*test|test.*jax" --type-add 'python:*.py'
Length of output: 107291
Script:
#!/bin/bash
# Let's specifically look for JAX MPI rank validation in the codebase
# Search for JAX-related MPI rank checks in Python files
echo "Searching for JAX MPI rank validation..."
rg -C 3 "jax.*mpi|mpi.*jax|rank.*jax|jax.*rank" --type-add 'python:*.py'
# Look for JAX test files that handle MPI ranks
echo -e "\nSearching for JAX test files with MPI rank handling..."
rg -C 3 "test.*jax.*rank|test.*jax.*mpi" --type-add 'python:*.py'
# Look for any JAX-related MPI limitations or assertions
echo -e "\nSearching for JAX MPI limitations or assertions..."
rg -C 3 "assert.*rank|rank.*assert|mpi.*not.*support.*jax|jax.*not.*support.*mpi" --type-add 'python:*.py'
Length of output: 584
Script:
#!/bin/bash
# Let's look for evidence of JAX MPI limitations in LAMMPS integration code
# Search in LAMMPS-related files for JAX backend limitations
echo "Searching LAMMPS integration files for JAX limitations..."
rg -C 3 "jax.*backend|backend.*jax" source/lmp/
# Look for test files that skip MPI tests for JAX
echo -e "\nSearching for skipped MPI tests with JAX..."
rg -C 3 "skip.*mpi.*jax|skip.*jax.*mpi" source/lmp/tests/
# Look for atom_modify map requirements for JAX
echo -e "\nSearching for atom_modify map requirements..."
rg -C 3 "atom_modify.*map.*jax|jax.*atom_modify.*map" source/lmp/
Length of output: 512
Script:
#!/bin/bash
# Let's check the test files more thoroughly for JAX MPI limitations
# Look for skipped tests in JAX-related test files
echo "Checking JAX test files for MPI limitations..."
rg -C 5 "@pytest.mark.skip" source/lmp/tests/test_lammps_dpa_jax.py source/lmp/tests/test_lammps_jax.py
# Check for any MPI-related comments or docstrings in JAX tests
echo -e "\nChecking for MPI-related comments in JAX tests..."
rg -C 3 "# .*mpi|# .*MPI|\"\"\".*mpi|\"\"\".*MPI" source/lmp/tests/test_lammps_dpa_jax.py source/lmp/tests/test_lammps_jax.py
# Look for any JAX-specific test fixtures or setup
echo -e "\nChecking for JAX-specific test setups..."
rg -C 3 "def test.*jax" source/lmp/tests/test_lammps_dpa_jax.py source/lmp/tests/test_lammps_jax.py
Length of output: 2985
doc/backend.md (1)
34-36
: Verify TensorFlow dependency information.
Since JAX backend requires TensorFlow C++ interface for inference, we should verify if the minimum TensorFlow version requirement is documented correctly.
🧰 Tools
🪛 LanguageTool
[typographical] ~35-~35: The conjunction “so that” does not require a comma.
Context: ... interface. The model is device-specific, so that the model generated on the GPU device c...
(SO_THAT_UNNECESSARY_COMMA)
source/api_cc/include/common.h (1)
16-16
: LGTM! Verify the enum usage across the codebase.
The addition of JAX to the DPBackend enum is well-placed and maintains the Unknown value as the last enum option. This change aligns with the PR's objective of integrating JAX backend support.
Let's verify the usage of this enum across the codebase:
✅ Verification successful
Based on the search results, I can now make a definitive assessment of the review comment. The code shows that the DPBackend
enum is properly used across the codebase:
-
The enum is used in multiple files to determine the backend type based on model file extensions:
.pth
-> PyTorch.pb
-> TensorFlow.savedmodel
-> JAX- Unknown for unrecognized formats
-
The enum is used consistently in initialization code to create the appropriate backend implementation:
DeepPot.cc
,DeepTensor.cc
, andDataModifier.cc
all use the enum for backend selection- Each backend has proper error handling when not built/supported
- The
Unknown
value is properly handled as the last enum option
-
The order of enum values is preserved across the codebase, maintaining ABI compatibility:
enum DPBackend { TensorFlow, PyTorch, Paddle, JAX, Unknown };
-
The addition of JAX follows the existing pattern and is properly integrated into the backend selection logic.
The enum modification is well-implemented and properly integrated across the codebase.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Check for DPBackend enum usage and ensure proper handling of JAX backend
# Test 1: Find all switch/if statements handling DPBackend to ensure JAX is properly handled
echo "Checking DPBackend handling..."
rg -A 5 "DPBackend"
# Test 2: Look for potential string conversions of backend types
echo "Checking backend string conversions..."
rg -A 5 "Unknown|TensorFlow|PyTorch|Paddle|JAX"
# Test 3: Find potential serialization/deserialization of backend types
echo "Checking serialization code..."
ast-grep --pattern 'switch ($backend) { $$$ }'
Length of output: 282653
source/lib/include/neighbor_list.h (1)
Line range hint 47-107
: Verify the mapping implementation across the codebase.
Let's ensure the mapping functionality is properly integrated:
✅ Verification successful
Memory management and usage of mapping is properly implemented
The verification shows that the mapping implementation is correctly integrated:
- Memory management is properly handled:
  - In LAMMPS integration (fix_dplr.cpp, pair_deepmd.cpp), mapping is managed via std::vector
  - Device memory allocation/deallocation is properly paired in tests using malloc_device_memory/delete_device_memory
- Usage is consistent across the codebase:
  - Only set when single process and atom mapping is enabled (comm->nprocs == 1 && atom->map_style != Atom::MAP_NONE)
  - Properly initialized with -1 default values
  - Correctly maps atom tags using LAMMPS' atom->map() function
  - C API properly exposes the mapping functionality
The implementation follows good practices for memory safety and has consistent usage patterns across the codebase.
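The conventions summarized above (fill the mapping with -1 defaults, then resolve each atom's tag to the local index that owns it) can be sketched in Python; this is a hedged illustration, with a dict-based tag lookup standing in for LAMMPS' atom->map():

```python
def build_mapping(tags, nall):
    """Build a mapping from each of the nall (local + ghost) atoms to the
    local index owning its tag; -1 marks entries that cannot be resolved.

    tags: per-atom tags for all nall atoms; local atoms come first, and
    ghost copies carry the same tag as the local atom they image.
    """
    mapping = [-1] * nall  # default: unmapped
    tag_to_local = {}
    for i, tag in enumerate(tags):
        if tag not in tag_to_local:  # first occurrence = local atom
            tag_to_local[tag] = i
    for i, tag in enumerate(tags):
        mapping[i] = tag_to_local.get(tag, -1)
    return mapping


# 3 local atoms (tags 1, 2, 3) followed by 2 ghost copies of atoms 1 and 3
tags = [1, 2, 3, 1, 3]
print(build_mapping(tags, len(tags)))  # [0, 1, 2, 0, 2]
```

This also shows why the single-process condition matters: with multiple processes, a ghost atom's owner may live on another rank, so a purely local tag lookup cannot resolve it.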
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Check usage and implementation of the new mapping functionality
# Check for mapping usage in implementation files
echo "Checking mapping usage in implementation files..."
rg -l "set_mapping|->mapping" --type cpp
# Check for potential memory management of mapping
echo "Checking memory management..."
rg -B 3 -A 3 "delete.*mapping|free.*mapping" --type cpp
# Look for LAMMPS integration points
echo "Checking LAMMPS integration..."
rg -l "LAMMPS.*mapping|mapping.*LAMMPS" --type cpp
Length of output: 3124
Script:
#!/bin/bash
# Let's check the actual usage of mapping in the identified files
echo "Checking mapping usage in fix_dplr.cpp..."
rg -C 5 "mapping" source/lmp/fix_dplr.cpp
echo -e "\nChecking mapping usage in pair_deepmd.cpp..."
rg -C 5 "mapping" source/lmp/pair_deepmd.cpp
echo -e "\nChecking mapping usage in c_api.cc..."
rg -C 5 "mapping" source/api_c/src/c_api.cc
echo -e "\nChecking mapping usage in deepmd.hpp..."
rg -C 5 "mapping" source/api_c/include/deepmd.hpp
# Let's also check for any initialization or allocation of mapping
echo -e "\nChecking mapping initialization..."
rg -C 5 "mapping\s*=\s*new|malloc.*mapping" --type cpp
Length of output: 8714
source/api_c/include/c_api.h (3)
15-15: API version increment is appropriate.
The increment from version 23 to 24 correctly reflects the addition of new functionality (DP_NlistSetMapping).
Line range hint 34-71: Documentation improvements enhance API clarity.
The expanded documentation for DP_NewNlist_comm provides clear and comprehensive parameter descriptions, improving API usability.
81-89: Verify implementation consistency.
Let's verify that the implementation in c_api.cc matches this interface declaration.
✅ Verification successful
Implementation is consistent with interface
The implementation is correctly aligned with the interface:
- The DP_NlistSetMapping function in c_api.cc correctly calls the set_mapping method of the InputNlist struct
- The InputNlist struct in neighbor_list.h has the mapping member variable and set_mapping method properly defined
- The mapping functionality is properly documented in both interface and implementation
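The wrapper relationship verified above (a flat C entry point that forwards to a method on the struct) follows a common C-API pattern; here is a minimal Python stand-in, with names mirroring the real C API but the bodies purely illustrative:

```python
class InputNlist:
    """Stand-in for the C++ InputNlist struct."""

    def __init__(self):
        # Optional mapping from ghost atoms to local atoms; None until set.
        self.mapping = None

    def set_mapping(self, mapping):
        self.mapping = mapping


def DP_NlistSetMapping(nlist, mapping):
    """Flat C-style entry point that forwards to the struct's method."""
    nlist.set_mapping(mapping)
```

Keeping the C function a thin forwarder means the struct remains the single owner of the mapping semantics, and every language binding built on the C API gets the same behavior.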
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Check implementation of DP_NlistSetMapping in c_api.cc
# and related changes in neighbor_list.h
# Test 1: Verify implementation exists in c_api.cc
rg -A 5 "DP_NlistSetMapping" "source/api_c/src/c_api.cc"
# Test 2: Verify InputNlist structure changes
rg -A 5 "class\s+InputNlist" "source/lib/include/neighbor_list.h"
Length of output: 285
Script:
#!/bin/bash
# Let's verify the InputNlist class and its set_mapping method
# Check for InputNlist class definition with more context
rg -B 2 -A 10 "class\s+InputNlist" source/lib/include/neighbor_list.h
# Check for set_mapping method declaration
rg "set_mapping" source/lib/include/neighbor_list.h
# Look for any tests related to mapping functionality
rg -l "set_mapping" source/lib/tests/
Length of output: 220
Script:
#!/bin/bash
# Let's try a different approach to find the InputNlist class definition
# using ast-grep which might handle the class structure better
# Search for InputNlist class definition
ast-grep --pattern 'class InputNlist {
$$$
}'
# Also check if there are any other files that might contain the implementation
fd -e h -e cc -e cpp | xargs rg -l "InputNlist"
Length of output: 2729
Script:
#!/bin/bash
# Let's examine the neighbor_list.h file more carefully
# and the test files to verify the implementation
# Check neighbor_list.h content
cat source/lib/include/neighbor_list.h
# Check the test implementation for neighbor list
cat source/lib/tests/test_neighbor_list.cc
Length of output: 16925
source/api_c/include/deepmd.hpp (1)
618-622: Verify LAMMPS integration for the mapping functionality.
The method is well-integrated into the InputNlist structure and provides the necessary functionality for LAMMPS integration. To ensure correctness:
- Verify that LAMMPS correctly passes the atom mapping array
- Confirm the mapping is properly used in molecular dynamics simulations
- Test with both real and ghost atoms to ensure proper handling
Let's verify the usage of this new mapping functionality:
✅ Verification successful
The mapping functionality is correctly integrated with LAMMPS
The verification shows proper integration:
- The set_mapping method is correctly used in both pair_deepmd.cpp and fix_dplr.cpp
- The mapping is only applied when running on a single processor (comm->nprocs == 1) and when LAMMPS atom mapping is enabled (atom->map_style != Atom::MAP_NONE)
- The implementation is consistent across both regular and extended neighbor lists
- Documentation confirms the intended use for atom type mapping in LAMMPS integration
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Check for LAMMPS integration code that uses the new mapping functionality
# Test: Search for LAMMPS code that calls set_mapping
rg -t cpp "set_mapping"
# Test: Search for tests that verify the mapping functionality
rg -t cpp "test.*mapping"
Length of output: 534
Script:
#!/bin/bash
# Let's verify the integration more thoroughly
# Check the context of LAMMPS usage
rg -A 5 -B 5 "set_mapping" source/lmp/pair_deepmd.cpp
rg -A 5 -B 5 "set_mapping" source/lmp/fix_dplr.cpp
# Check if there are any tests for the mapping functionality
fd "test.*\.cpp" | xargs rg "mapping.*vec"
# Look for any documentation about mapping
fd "README|\.md" | xargs rg -i "mapping"
Length of output: 2787
source/api_cc/include/DeepPotJAX.h (2)
1-249: Comprehensive implementation of the DeepPotJAX class.
The DeepPotJAX class provides a well-structured and comprehensive implementation for integrating the JAX backend into the DeePMD-kit framework. It follows good practices by:
- Inheriting from the DeepPotBase class to ensure a consistent interface.
- Providing multiple constructors and an init method for flexible initialization.
- Implementing accessor methods for retrieving model properties.
- Overloading computew methods to support different data types and scenarios.
- Utilizing assertions for error handling and ensuring proper initialization.
- Encapsulating TensorFlow C API objects for graph and session management.
- Integrating with the NeighborListData class for handling neighbor list data.
- Providing a template compute method for evaluating energy and forces.
The class is thoroughly documented with Doxygen-style comments, enhancing readability and maintainability.
4-5: Verify the compatibility of TensorFlow C API with JAX.
Please ensure that the TensorFlow C API (tensorflow/c/c_api.h and tensorflow/c/eager/c_api.h) is compatible with the JAX backend and does not introduce any conflicts or limitations. It's important to thoroughly test the integration to ensure smooth functionality.
✅ Verification successful
Let me gather more information about the JAX implementation and its relationship with TensorFlow C API.
Based on the code analysis, I can now provide a final response regarding the compatibility of TensorFlow C API with JAX:
The TensorFlow C API is compatible and properly integrated with JAX through JAX2TF conversion
The codebase shows a well-designed integration where:
- JAX models are converted to TensorFlow SavedModel format (JAX2TF) which allows using TensorFlow C API
- The DeepPotJAX class properly initializes and uses the TensorFlow C API to load and execute the converted JAX models
- The implementation includes proper error handling and memory management for the TF C API resources
- The code has been tested as evidenced by the test files (test_deeppot_jax.cc)
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Search for potential compatibility issues or limitations in the codebase
rg -i "compatibility|limitation|issue|conflict" -C 5 --iglob "*.cc" --iglob "*.h"
# Check if there are any TODOs, FIXMEs, or HACKs related to JAX and TensorFlow C API integration
rg -i "TODO|FIXME|HACK" -C 5 --iglob "*.cc" --iglob "*.h" | rg -i "jax|tensorflow"
Length of output: 5496
Script:
#!/bin/bash
# Check for JAX-related implementations and backend configurations
rg -i "jax" -C 5 --iglob "*.cc" --iglob "*.h"
# Look for backend-related code and configurations
rg -i "backend" -C 5 --iglob "*.cc" --iglob "*.h"
# Search for DeepPotJAX implementation details
fd "DeepPotJAX" --exec cat {}
Length of output: 64163
source/api_cc/tests/test_deeppot_jax.cc (9)
1-15: LGTM!
The header includes and namespace usages look good. The necessary headers for Google Test, DeepPot, and other utility functions are included correctly.
16-90: LGTM!
The TestInferDeepPotAJAX class is set up correctly as a parameterized test fixture. The member variables for coordinates, atom types, box dimensions, and expected values are initialized appropriately in the SetUp method. The SetUp method also verifies the sizes of the expected force and virial vectors relative to the number of atoms.
97-157: LGTM!
The cpu_lmp_nlist test case looks good:
- It correctly sets up the input data and neighbor list.
- It calls the compute method of DeepPot with the appropriate arguments.
- It folds back the computed forces using the _fold_back helper function.
- It verifies the sizes of the computed force and virial vectors.
- It checks the computed energy, forces, and virial against the expected values using appropriate tolerances.
- It repeats the computation with a different output index to ensure consistency.
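The fold-back step mentioned above (accumulating forces computed for the extended local-plus-ghost atoms back onto the real atoms via the mapping) can be sketched as follows; this is a hedged illustration of what a helper like _fold_back typically does, not the test's actual code:

```python
def fold_back_forces(forces_ext, mapping, nloc):
    """Accumulate per-atom forces computed for the nall extended atoms
    (local + ghost) onto the nloc real atoms.

    forces_ext: flat [nall * 3] list of force components;
    mapping[i]: index of the real atom owning extended atom i.
    """
    folded = [0.0] * (nloc * 3)
    for i, owner in enumerate(mapping):
        for d in range(3):
            folded[owner * 3 + d] += forces_ext[i * 3 + d]
    return folded


# 2 real atoms; atom 0 has one ghost image (extended index 2)
mapping = [0, 1, 0]
forces_ext = [1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.5, 0.0, 0.0]
print(fold_back_forces(forces_ext, mapping, 2))  # [1.5, 0.0, 0.0, 0.0, 2.0, 0.0]
```

The ghost contribution (0.5 on the x-component) is summed into real atom 0, which is why the test compares the folded forces, not the raw extended ones, against the expected values.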
159-240: LGTM!
The cpu_lmp_nlist_atomic test case looks good:
- It follows a similar structure to the cpu_lmp_nlist test case.
- It additionally computes and verifies atomic energies and virials.
- It checks the computed atomic energies and virials against the expected values using appropriate tolerances.
- It repeats the computation with a different output index to ensure consistency.
242-302: LGTM!
The cpu_lmp_nlist_2rc test case looks good:
- It sets up the neighbor list with a cutoff of 2*rc.
- It calls the compute method of DeepPot with the appropriate arguments.
- It verifies the sizes of the computed force and virial vectors.
- It checks the computed energy, forces, and virial against the expected values using appropriate tolerances.
- It repeats the computation with a different output index to ensure consistency.
304-364: LGTM!
The cpu_lmp_nlist_type_sel test case looks good:
- It adds virtual atoms of a different type to the input data.
- It sets up the neighbor list with the updated atom coordinates and types.
- It calls the compute method of DeepPot with the appropriate arguments.
- It verifies the sizes of the computed force and virial vectors.
- It checks the computed energy, forces, and virial against the expected values (including the virtual atoms) using appropriate tolerances.
366-427: LGTM!
The cpu_lmp_nlist_type_sel_atomic test case looks good:
- It follows a similar structure to the cpu_lmp_nlist_type_sel test case.
- It additionally computes and verifies atomic energies and virials.
- It checks the computed atomic energies and virials against the expected values using appropriate tolerances.
429-432: LGTM!
The print_summary test case looks good. It calls the print_summary method of DeepPot with an empty string argument.
434-439: LGTM!
The get_type_map test case looks good:
- It calls the get_type_map method of DeepPot to retrieve the type map.
- It verifies that the retrieved type map matches the expected value of "O H".
source/lmp/fix_dplr.cpp (2)
442-449: Verify mapping functionality in multi-processor runs
The current implementation initializes and populates mapping_vec only when running on a single processor (comm->nprocs == 1). If the mapping is required for correct behavior in multi-processor runs, omitting this initialization may lead to issues. Please verify whether the mapping should be handled when running on multiple processors and update the code accordingly if necessary.
482-484: Confirm the necessity of mapping in multi-processor scenarios
Similarly, the neighbor list mapping is set only when comm->nprocs == 1. If the mapping is essential for computations across multiple processors, this condition might prevent proper functionality in parallel executions. Ensure that the mapping is appropriately managed in multi-processor environments to avoid potential inconsistencies.
source/api_cc/src/DeepPotJAX.cc (16)
1-2: LGTM!
The SPDX license identifier is correctly specified.
4-18: LGTM!
The necessary header files are included correctly.
20-25: LGTM!
The check_status function is implemented correctly to check the TensorFlow operation status and throw an exception with the error message if an error occurs.
47-61: LGTM!
The get_data_tensor_type functions are implemented correctly to return the appropriate TensorFlow data type based on the input data type.
168-176: LGTM!
The create_tensor function is implemented correctly to create a TensorFlow tensor from the input data and shape.
178-191: LGTM!
The add_input function is implemented correctly to create a TensorFlow tensor from the input data and shape, and add it as an input to the TensorFlow operation.
193-207: LGTM!
The tensor_to_vector function is implemented correctly to copy the data from a TensorFlow tensor to a vector.
209-215: LGTM!
The constructors for the DeepPotJAX class are implemented correctly.
216-279: LGTM!
The init function is implemented correctly to initialize the DeepPotJAX class by loading the TensorFlow saved model, creating a TensorFlow session, and retrieving the necessary parameters.
281-293: LGTM!
The destructor for the DeepPotJAX class is implemented correctly to clean up the TensorFlow resources if the class was initialized.
295-454: LGTM!
The compute function is implemented correctly to perform the computation using the TensorFlow saved model. It handles the input and output tensors, executes the TensorFlow function, and copies the results back to the output vectors.
456-470: LGTM!
The template specialization for the compute function with double data type is implemented correctly.
472-486: LGTM!
The template specialization for the compute function with float data type is implemented correctly.
488-490: LGTM!
The get_type_map function is implemented correctly to retrieve the type map.
492-552: Verify the correctness of throwing exceptions for unimplemented methods.
The computew methods with different parameter types are currently throwing exceptions to indicate that they are not implemented. Ensure that this behavior is intended and aligns with the overall design and requirements of the codebase.
553-580: Verify the correctness of throwing exceptions for unimplemented methods.
The computew_mixed_type methods with different parameter types are currently throwing exceptions to indicate that they are not implemented. Ensure that this behavior is intended and aligns with the overall design and requirements of the codebase.
source/lmp/tests/test_lammps_jax.py (27)
1-11: LGTM!
The import statements look good and follow the standard Python conventions.
35-68: LGTM!
The definition of the expected atomic energies (expected_ae) and the calculation of the expected total energy (expected_e) look correct.
70-208: LGTM!
The definitions of the expected forces (expected_f and expected_f2) and virials (expected_v and expected_v2) look correct. The reshaping of the arrays is also done correctly.
210-223: LGTM!
The definitions of the simulation box (box), atomic coordinates (coord), and atom types (type_OH and type_HO) look correct.
230-238: LGTM!
The setup_module function correctly sets up the necessary data files for the tests using the write_lmp_data function.
241-244: LGTM!
The teardown_module function correctly removes the data files created during the setup phase.
279-283: LGTM!
The lammps fixture correctly creates a LAMMPS instance using the _lammps function and yields it for use in tests. The fixture also closes the LAMMPS instance after the test.
286-290: LGTM!
The lammps_type_map fixture correctly creates a LAMMPS instance with a type map using the _lammps function and yields it for use in tests. The fixture also closes the LAMMPS instance after the test.
293-297: LGTM!
The lammps_real fixture correctly creates a LAMMPS instance with real units using the _lammps function and yields it for use in tests. The fixture also closes the LAMMPS instance after the test.
300-304: LGTM!
The lammps_si fixture correctly creates a LAMMPS instance with SI units using the _lammps function and yields it for use in tests. The fixture also closes the LAMMPS instance after the test.
307-317: LGTM!
The test_pair_deepmd function correctly tests the DeepMD pair style by setting up the pair style, running the simulation, and asserting the potential energy and forces against the expected values.
319-340: LGTM!
The test_pair_deepmd_virial function correctly tests the DeepMD pair style with virial calculations by setting up the pair style, computing the virial tensor, running the simulation, and asserting the potential energy, forces, and virial tensor components against the expected values.
342-366: LGTM!
The test_pair_deepmd_model_devi function correctly tests the DeepMD pair style with model deviation output by setting up the pair style with two models, running the simulation, asserting the potential energy and forces against the expected values, and verifying the model deviation output against the expected values.
368-403: LGTM!
The test_pair_deepmd_model_devi_virial function correctly tests the DeepMD pair style with model deviation output and virial calculations by setting up the pair style with two models, computing the virial tensor, running the simulation, asserting the potential energy, forces, and virial tensor components against the expected values, and verifying the model deviation output against the expected values.
406-433: LGTM!
The test_pair_deepmd_model_devi_atomic_relative function correctly tests the DeepMD pair style with model deviation output and relative atomic deviations by setting up the pair style with two models and the relative parameter, running the simulation, asserting the potential energy and forces against the expected values, and verifying the model deviation output against the expected values calculated with the relative deviations.
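For context, the relative keyword rescales the absolute force deviation by the force magnitude plus a level parameter; a minimal sketch of that normalization, assuming the devi / (|f| + level) form used for DeePMD-kit's relative model-deviation output:

```python
import math


def relative_force_devi(devi, force, level):
    """Relative per-atom force deviation: devi / (|force| + level).

    devi: absolute deviation of the force on one atom (scalar)
    force: (fx, fy, fz) mean force on that atom across the models
    level: the 'relative' keyword value, shifting the denominator so
    near-zero forces do not blow up the relative deviation
    """
    fnorm = math.sqrt(sum(c * c for c in force))
    return devi / (fnorm + level)


# |force| = 5, level = 1 -> denominator 6
print(relative_force_devi(0.2, (3.0, 0.0, 4.0), 1.0))  # 0.2 / 6.0
```

This is what the test's expected values "calculated with the relative deviations" refer to: the same absolute deviations divided through by the shifted force norm.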
435-465: LGTM!
The test_pair_deepmd_model_devi_atomic_relative_v function correctly tests the DeepMD pair style with model deviation output and relative virial deviations by setting up the pair style with two models and the relative_v parameter, running the simulation, asserting the potential energy and forces against the expected values, and verifying the model deviation output against the expected values calculated with the relative virial deviations.
468-477: LGTM!
The test_pair_deepmd_type_map function correctly tests the DeepMD pair style with a type map by setting up the pair style with a type map, running the simulation, and asserting the potential energy and forces against the expected values.
480-491: LGTM!
The test_pair_deepmd_real function correctly tests the DeepMD pair style with real units by setting up the pair style, running the simulation, and asserting the potential energy and forces against the expected values converted to real units.
494-519: LGTM!
The test_pair_deepmd_virial_real function correctly tests the DeepMD pair style with virial calculations and real units by setting up the pair style, computing the virial tensor, running the simulation, and asserting the potential energy, forces, and virial tensor components against the expected values converted to real units.
521-548: LGTM!
The test_pair_deepmd_model_devi_real function correctly tests the DeepMD pair style with model deviation output and real units by setting up the pair style with two models, running the simulation, asserting the potential energy and forces against the expected values converted to real units, and verifying the model deviation output against the expected values converted to real units.
551-592: LGTM!
The test_pair_deepmd_model_devi_virial_real function correctly tests the DeepMD pair style with model deviation output, virial calculations, and real units by setting up the pair style with two models, computing the virial tensor, running the simulation, asserting the potential energy, forces, and virial tensor components against the expected values converted to real units, and verifying the model deviation output against the expected values converted to real units.
595-625: LGTM!
The test_pair_deepmd_model_devi_atomic_relative_real function correctly tests the DeepMD pair style with model deviation output, relative atomic deviations, and real units by setting up the pair style with two models and the relative parameter converted to real units, running the simulation, asserting the potential energy and forces against the expected values converted to real units, and verifying the model deviation output against the expected values calculated with the relative deviations and converted to real units.
628-662: LGTM!
The test_pair_deepmd_model_devi_atomic_relative_v_real function correctly tests the DeepMD pair style with model deviation output, relative virial deviations, and real units by setting up the pair style with two models and the relative_v parameter converted to real units, running the simulation, asserting the potential energy and forces against the expected values converted to real units, and verifying the model deviation output against the expected values calculated with the relative virial deviations and converted to real units.
665-674: LGTM!
The test_pair_deepmd_si function correctly tests the DeepMD pair style with SI units by setting up the pair style, running the simulation, and asserting the potential energy and forces against the expected values converted to SI units.
677-686: LGTM!
The test_pair_deepmd_mpi function is correctly skipped if MPI or mpi4py is not installed on the system. The function is parametrized with different load balancing arguments.
28-33: Verify the existence and accessibility of the referenced files.
Similar to the previous comment, ensure that the file paths system_file, data_file, data_file_si, and data_type_map_file point to existing and accessible files in the repository. If not, update the paths accordingly.
#!/bin/bash
# Description: Verify the existence of the referenced files.
# Test: Check if the files exist. Expect: No output if files exist, else list missing files.
fd -e lmp . | xargs -I {} sh -c 'test -f "$1" || echo "$1 not found"' -- {}
22-27: Verify the existence and accessibility of the referenced files.
Ensure that the file paths pbtxt_file2, pb_file, and pb_file2 point to existing and accessible files in the repository. If not, update the paths accordingly.
✅ Verification successful
Both referenced files exist and are accessible
The verification confirms that both files exist and are accessible in the repository:
- deeppot-1.pbtxt exists at source/tests/infer/deeppot-1.pbtxt
- deeppot_sea.savedmodel exists at source/tests/infer/deeppot_sea.savedmodel (as a directory containing model files)
The paths in the code are correctly referencing these files relative to the test file's location.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Verify the existence of the referenced files.
# Test: Check if the files exist. Expect: No output if files exist, else list missing files.
fd -e pbtxt -e pb . | xargs -I {} sh -c 'test -f "$1" || echo "$1 not found"' -- {}
Length of output: 161
Script:
#!/bin/bash
# Let's try a different approach to verify the file paths
# First, let's check for the pbtxt file
ls -l source/tests/infer/deeppot-1.pbtxt 2>/dev/null || echo "deeppot-1.pbtxt not found"
# Then check for the savedmodel file
ls -l source/tests/infer/deeppot_sea.savedmodel 2>/dev/null || echo "deeppot_sea.savedmodel not found"
# Let's also try to find these files anywhere in the repository
echo "Searching for files in the repository:"
find . -name "deeppot-1.pbtxt" -o -name "deeppot_sea.savedmodel"
Length of output: 612
source/lmp/tests/test_lammps_dpa_jax.py (22)
1-11: LGTM!
The imports look good and cover the necessary dependencies for the test suite.
36-138: LGTM!
The expected values for energy, forces, and virial stress are correctly defined as NumPy arrays.
210-223: LGTM!
The box dimensions, coordinates, and atom types are correctly defined as NumPy arrays.
230-238: LGTM!
The setup_module function correctly sets up the necessary data files for the tests using the write_lmp_data function.
241-244: LGTM!
The teardown_module function correctly removes the data files after the tests are completed.
281-306: LGTM!
The fixture functions correctly create and close LAMMPS instances for different test scenarios (default, type map, real units, SI units).
309-319: LGTM!
The test_pair_deepmd function correctly tests the DeePMD pair style by setting up the LAMMPS instance, running the simulation, and comparing the potential energy and forces against the expected values.
321-342: LGTM!
The test_pair_deepmd_virial function correctly tests the computation of virial stress using the DeePMD pair style. It sets up the necessary compute and variables, runs the simulation, and compares the virial stress values against the expected values.
344-368: LGTM!
The test_pair_deepmd_model_devi function correctly tests the model deviation output of the DeePMD pair style. It sets up the pair style with two models, runs the simulation, and compares the model deviation values against the expected values.
370-406: LGTM!
The test_pair_deepmd_model_devi_virial function correctly tests the model deviation output along with the virial stress computation. It sets up the necessary compute, variables, and pair style, runs the simulation, and compares the model deviation and virial stress values against the expected values.
408-435: LGTM!
The test_pair_deepmd_model_devi_atomic_relative function correctly tests the model deviation output with the atomic and relative keywords. It sets up the pair style with the appropriate keywords, runs the simulation, and compares the model deviation values against the expected values calculated with the relative factor.
437-467: LGTM!
The test_pair_deepmd_model_devi_atomic_relative_v function correctly tests the model deviation output with the atomic and relative_v keywords. It sets up the pair style with the appropriate keywords, runs the simulation, and compares the model deviation values against the expected values calculated with the relative factor for the virial stress.
470-480: LGTM!
The test_pair_deepmd_type_map function correctly tests the DeePMD pair style with a type map. It sets up the LAMMPS instance with the type map data file, runs the simulation, and compares the potential energy and forces against the expected values.
482-493: LGTM!
The test_pair_deepmd_real function correctly tests the DeePMD pair style with real units. It sets up the LAMMPS instance with real units, runs the simulation, and compares the potential energy and forces against the expected values converted to real units.
496-521: LGTM!
The test_pair_deepmd_virial_real function correctly tests the computation of virial stress using the DeePMD pair style with real units. It sets up the necessary compute, variables, and pair style, runs the simulation, and compares the virial stress values against the expected values converted to real units.
523-551: LGTM!
The test_pair_deepmd_model_devi_real function correctly tests the model deviation output of the DeePMD pair style with real units. It sets up the pair style with two models, runs the simulation, and compares the model deviation values against the expected values converted to real units.
553-594: LGTM!
The test_pair_deepmd_model_devi_virial_real function correctly tests the model deviation output along with the virial stress computation using real units. It sets up the necessary compute, variables, and pair style, runs the simulation, and compares the model deviation and virial stress values against the expected values converted to real units.
597-627: LGTM!
The test_pair_deepmd_model_devi_atomic_relative_real function correctly tests the model deviation output with the atomic and relative keywords using real units. It sets up the pair style with the appropriate keywords, runs the simulation, and compares the model deviation values against the expected values calculated with the relative factor converted to real units.
630-664: LGTM!
The test_pair_deepmd_model_devi_atomic_relative_v_real function correctly tests the model deviation output with the atomic and relative_v keywords using real units. It sets up the pair style with the appropriate keywords, runs the simulation, and compares the model deviation values against the expected values calculated with the relative factor for the virial stress converted to real units.
667-676: LGTM!
The test_pair_deepmd_si function correctly tests the DeePMD pair style with SI units. It sets up the LAMMPS instance with SI units, runs the simulation, and compares the potential energy and forces against the expected values converted to SI units.
225-227: Verify the command to convert pbtxt to pb format.
The command to convert the pbtxt file to pb format looks correct. However, please ensure that the deepmd package is installed and the convert-from command is available.
#!/bin/bash
# Description: Verify the deepmd package is installed and the convert-from command is available.
# Test: Check if the deepmd package is installed. Expect: No output if installed.
python -c "import deepmd"
# Test: Check if the convert-from command is available. Expect: Usage information.
python -m deepmd convert-from --help
679-726: Verify the MPI test setup and execution.
The MPI test test_pair_deepmd_mpi is currently skipped due to the @pytest.mark.skip("MPI is not supported") decorator. If MPI support is intended to be tested, please ensure that the necessary dependencies (MPI and mpi4py) are installed and remove the skip decorator.
Additionally, verify that the run_mpi_pair_deepmd.py script exists at the specified location and contains the correct code to run the MPI test.
#!/bin/bash
# Description: Verify the MPI test setup and execution.
# Test: Check if the run_mpi_pair_deepmd.py script exists. Expect: File path.
fd run_mpi_pair_deepmd.py tests
# Test: Check if the script runs without errors. Expect: No output.
mpirun -n 2 python tests/run_mpi_pair_deepmd.py
source/api_cc/src/DeepPot.cc (3)
10-10: Inclusion of DeepPotJAX.h is appropriate
Including "DeepPotJAX.h" under #ifdef BUILD_TENSORFLOW ensures that the JAX backend is only included when TensorFlow support is available.
45-47: Correctly identify and set JAX backend based on model file extension
The added condition checks for the .savedmodel extension to determine if the JAX backend should be used. This logic correctly allows the code to recognize and initialize JAX models.
65-72: Initialization logic for JAX backend is well-implemented
The code properly initializes the JAX backend when BUILD_TENSORFLOW is defined by creating a DeepPotJAX instance. If TensorFlow support is not available, it throws a clear exception indicating that TensorFlow is required to load JAX2TF SavedModels.
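The initialization flow described above (suffix check, then either constructing the JAX implementation or raising a clear error when the TensorFlow-enabled build is absent) can be sketched as follows; the class and function names here are Python placeholders mirroring the C++ ones, not the actual API:

```python
class DeepPotJAXStub:
    """Placeholder standing in for the C++ DeepPotJAX implementation."""

    def __init__(self, model_file):
        self.model_file = model_file


def load_model(model_file, tensorflow_built=True):
    """Dispatch on the model file suffix; a .savedmodel path means a
    JAX2TF SavedModel, which requires the TensorFlow-enabled build."""
    if model_file.endswith(".savedmodel"):
        if not tensorflow_built:
            raise RuntimeError(
                "TensorFlow backend is required to load JAX2TF SavedModels"
            )
        return DeepPotJAXStub(model_file)
    raise ValueError(f"unsupported model format: {model_file}")
```

Failing fast with an explicit message at load time, rather than deep inside the compute path, is what makes the missing-TensorFlow case easy to diagnose.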
source/lmp/pair_deepmd.cpp (1)
526-530: Verify whether mapping initialization is required for multiple processes.
The mapping vector `mapping_vec` is initialized only when `comm->nprocs == 1` and `atom->map_style != Atom::MAP_NONE`. If the mapping is also necessary when running with multiple processes, consider adjusting the condition to ensure `mapping_vec` is properly initialized in all required scenarios.
To verify whether mapping should be initialized for multiple processes, search the codebase for other usages of `atom->map` in parallel contexts:
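The guard in question reduces to a small predicate. A sketch with stand-in names (`MAP_NONE` here is an assumed placeholder mirroring LAMMPS's `Atom::MAP_NONE`; the actual check lives in `pair_deepmd.cpp`):

```python
MAP_NONE = 0  # stand-in for LAMMPS Atom::MAP_NONE

def should_build_mapping(nprocs: int, map_style: int) -> bool:
    """Build mapping_vec only on a single MPI rank with an atom map enabled,
    mirroring the condition discussed in the review comment."""
    return nprocs == 1 and map_style != MAP_NONE
```

Extending support to multiple ranks would mean relaxing the `nprocs == 1` term and ensuring ghost atoms still map back to valid local indices on each rank.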
- source/tests/infer/deeppolar_new.pbtxt
- source/tests/infer/deeppot_dpa.savedmodel/saved_model.pb
💡 Codebase verification
Additional large model files need exclusion
The scan revealed several model files near or over the size limit that are not currently excluded:
- source/tests/infer/deepdipole_new.pbtxt (916KB)
- source/tests/infer/deepdipole_fake.pbtxt (1002KB)
- source/tests/infer/deeppot_sea.savedmodel/saved_model.pb (630KB)

These files should also be added to the exclusion list in `.pre-commit-config.yaml` to maintain consistency with the current approach of excluding large model files.
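To keep the exclusion list in sync, one could periodically scan for oversized model files. A sketch (the helper name and the 1024 KB threshold are assumptions following the limit mentioned above, not this repository's actual tooling):

```python
from pathlib import Path

def files_over_limit(root: str, limit_kb: int = 1024) -> list:
    """Return model files under `root` larger than `limit_kb` kilobytes,
    i.e. candidates for the pre-commit large-file exclusion list."""
    hits = []
    for path in Path(root).rglob("*"):
        if path.is_file() and path.suffix in {".pb", ".pbtxt"}:
            if path.stat().st_size > limit_kb * 1024:
                hits.append(str(path))
    return sorted(hits)
```

Sorting the result gives a stable list that can be diffed against the `exclude` patterns in `.pre-commit-config.yaml`.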
🔗 Analysis chain
LGTM! Consider checking for additional model files.
The exclusion patterns for the new model files are consistent with the existing pattern. However, since this PR is part of the JAX backend integration, we should verify if there are other similar model files that might need exclusion.
Let me try a different approach to find potential large model files by using `fd` and checking file sizes directly.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Check for other potential large model files that might need exclusion
# Look for similar files in the test directories that might exceed the size limit
echo "Searching for .pb and .pbtxt files in test directories..."
find source/tests -type f \( -name "*.pb" -o -name "*.pbtxt" \) -exec ls -lh {} \; | \
awk '{if($5+0 > 1024) print "Warning: Large file found:", $9, "Size:", $5}'
Length of output: 292
Script:
#!/bin/bash
# Search for model files and their sizes in the source/tests directory
echo "Searching for model files in tests directory..."
fd -t f "\.(pb|pbtxt|savedmodel)$" source/tests -x ls -lh {}
# Also check the pre-commit config for current exclusions
echo -e "\nCurrent exclusions in pre-commit config:"
grep -A 10 "check-added-large-files" .pre-commit-config.yaml
Length of output: 2875
/**
 * @brief Set mapping for this neighbor list.
 */
void set_mapping(int* mapping_) { mapping = mapping_; };
🛠️ Refactor suggestion
Improve setter method safety and documentation.
The current implementation could benefit from:
- Input validation to prevent null pointer assignment
- Documentation about:
- Pointer ownership expectations
- Required size of the mapping array
- Thread safety considerations
Consider this safer implementation:
/**
* @brief Set mapping for this neighbor list.
+ * @param mapping_ Pointer to an array of size nall containing indices mapping all atoms to real atoms.
+ * The pointer must remain valid for the lifetime of this object.
+ * @throw std::invalid_argument if mapping_ is null
*/
- void set_mapping(int* mapping_) { mapping = mapping_; };
+ void set_mapping(int* mapping_) {
+ if (!mapping_) {
+ throw std::invalid_argument("mapping pointer cannot be null");
+ }
+ mapping = mapping_;
+ }
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
/**
 * @brief Set mapping for this neighbor list.
 * @param mapping_ Pointer to an array of size nall containing indices mapping all atoms to real atoms.
 *                 The pointer must remain valid for the lifetime of this object.
 * @throw std::invalid_argument if mapping_ is null
 */
void set_mapping(int* mapping_) {
  if (!mapping_) {
    throw std::invalid_argument("mapping pointer cannot be null");
  }
  mapping = mapping_;
}
@@ -375,7 +377,7 @@ One may add the following CMake variables to `cmake` using the [`-D <var>=<value

 **Type**: `BOOL` (`ON`/`OFF`), Default: `OFF`

-{{ tensorflow_icon }} Whether building the TensorFlow backend.
+{{ tensorflow_icon }} {{ jax_icon }} Whether building the TensorFlow backend and the JAX backend.
🛠️ Refactor suggestion
Consider separating backend enable flags.
The current `ENABLE_TENSORFLOW` variable name doesn't accurately reflect that it enables both the TensorFlow and JAX backends. This could be confusing for users who only want to use JAX.
Consider either:
- renaming it to something more inclusive, like `ENABLE_TF_JAX_BACKEND`, or
- creating a separate `ENABLE_JAX` flag.
void DP_NlistSetMapping(DP_Nlist* nl, int* mapping) {
  nl->nl.set_mapping(mapping);
}
Add safety checks and documentation for the new mapping function.
The implementation needs several improvements for robustness and clarity:
- Add null pointer validation for both parameters
- Document the requirements for the mapping array (size, lifetime, ownership)
- Consider adding error handling similar to other functions in the codebase
Consider applying this diff:
+/**
+ * @brief Set the mapping array for the neighbor list
+ * @param[in] nl The neighbor list object
+ * @param[in] mapping The mapping array that will be used by the neighbor list.
+ * The array must remain valid for the lifetime of the neighbor list
+ * or until a new mapping is set.
+ * @note The function does not take ownership of the mapping array.
+ */
void DP_NlistSetMapping(DP_Nlist* nl, int* mapping) {
+ if (nl == nullptr || mapping == nullptr) {
+ return;
+ }
nl->nl.set_mapping(mapping);
}
/**
 * @brief Set the mapping array for the neighbor list.
 * @param[in] nl The neighbor list object.
 * @param[in] mapping The mapping array that will be used by the neighbor list.
 *                    The array must remain valid for the lifetime of the neighbor list
 *                    or until a new mapping is set.
 * @note The function does not take ownership of the mapping array.
 */
void DP_NlistSetMapping(DP_Nlist* nl, int* mapping) {
  if (nl == nullptr || mapping == nullptr) {
    return;
  }
  nl->nl.set_mapping(mapping);
}
sp.check_output(
    f"{sys.executable} -m deepmd convert-from pbtxt -i {pbtxt_file2.resolve()} -o {pb_file2.resolve()}".split()
)
💡 Codebase verification
Files are not present at the expected locations.
The code expects the files at:
- source/tests/infer/deeppot-1.pbtxt (`pbtxt_file2`)
- source/lmp/tests/graph2.pb (`pb_file2`)

While `deeppot-1.pbtxt` exists at the correct location, `graph2.pb` is missing. This file is expected to be generated by the `convert-from` command, but its parent directory must exist first.
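Since the `.pb` file is produced by the conversion rather than checked in, the test setup can create the destination directory up front. A sketch with a hypothetical helper name (not code from this PR):

```python
from pathlib import Path

def ensure_output_path(pb_file) -> Path:
    """Resolve the convert-from output path and make sure its parent
    directory exists, so the conversion has somewhere to write."""
    pb_file = Path(pb_file).resolve()
    pb_file.parent.mkdir(parents=True, exist_ok=True)
    return pb_file
```

For example, calling `ensure_output_path(pb_file2)` before invoking `deepmd convert-from` would guarantee the parent directory exists.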
🔗 Analysis chain
Verify the existence and accessibility of `pbtxt_file2` and `pb_file2`.
Ensure that the `pbtxt_file2` and `pb_file2` files exist and are accessible; if not, the `convert-from` command will fail.
Let me search for the actual paths of these files in the codebase.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Verify the existence of the `pbtxt_file2` and `pb_file2` files.
# Test: Check if the files exist. Expect: No output if files exist, else list missing files.
test -f "tests/infer/deeppot-1.pbtxt" || echo "tests/infer/deeppot-1.pbtxt not found"
test -f "source/lmp/tests/graph2.pb" || echo "source/lmp/tests/graph2.pb not found"
Length of output: 242
Script:
#!/bin/bash
# Search for pbtxt_file2 and pb_file2 variable assignments
rg -A 3 "pbtxt_file2|pb_file2" source/lmp/tests/test_lammps_jax.py
# Find any .pbtxt files in the repository
fd -e pbtxt
# Find any .pb files in the repository
fd -e pb
Length of output: 3216
pbtxt_file2 = (
    Path(__file__).parent.parent.parent / "tests" / "infer" / "deeppot-1.pbtxt"
)
pb_file = (
    Path(__file__).parent.parent.parent / "tests" / "infer" / "deeppot_dpa.savedmodel"
)
pb_file2 = Path(__file__).parent / "graph2.pb"
system_file = Path(__file__).parent.parent.parent / "tests"
data_file = Path(__file__).parent / "data.lmp"
data_file_si = Path(__file__).parent / "data.si"
data_type_map_file = Path(__file__).parent / "data_type_map.lmp"
md_file = Path(__file__).parent / "md.out"
💡 Codebase verification
File paths need to be corrected.
The verification reveals several issues with the file paths. While `pbtxt_file2` (deeppot-1.pbtxt) and `pb_file` (deeppot_dpa.savedmodel) exist in the correct location at source/tests/infer/, the other files are missing:
- `pb_file2` (graph2.pb)
- `data_file` (data.lmp)
- `data_file_si` (data.si)
- `data_type_map_file` (data_type_map.lmp)
- `md_file` (md.out)

These test data files need to be created in the source/lmp/tests/ directory, or their paths need to be updated to point to existing files.
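A quick guard can surface missing fixtures before the LAMMPS tests run instead of failing mid-test. A sketch with a hypothetical helper name:

```python
from pathlib import Path

def missing_fixtures(paths) -> list:
    """Return the names of required test-data files that are absent, so a
    test module can skip or fail early with a clear message."""
    return sorted(name for name, p in paths.items() if not Path(p).exists())
```

A test module could then call something like `pytest.skip(f"missing fixtures: {names}")` when the returned list is non-empty.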
🔗 Analysis chain
Verify the file paths are correct and exist.
Please ensure that the specified file paths for `pbtxt_file2`, `pb_file`, `pb_file2`, `system_file`, `data_file`, `data_file_si`, `data_type_map_file`, and `md_file` are correct and that the files exist at those locations.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Verify the specified file paths exist.
# Test: Check if the files exist. Expect: No output if files exist.
fd -e pbtxt -e pb -e lmp -e out . tests/infer
Length of output: 445
Script:
#!/bin/bash
# Let's verify the existence of files using more specific paths based on the code
# Check for pbtxt and savedmodel files in tests/infer
ls -la source/tests/infer/deeppot-1.pbtxt source/tests/infer/deeppot_dpa.savedmodel 2>/dev/null
# Check for files in the test directory
ls -la source/lmp/tests/graph2.pb source/lmp/tests/data.lmp source/lmp/tests/data.si source/lmp/tests/data_type_map.lmp source/lmp/tests/md.out 2>/dev/null
# Additional search to find these files in case paths are different
find source -type f \( -name "deeppot-1.pbtxt" -o -name "deeppot_dpa.savedmodel" -o -name "graph2.pb" -o -name "data.lmp" -o -name "data.si" -o -name "data_type_map.lmp" -o -name "md.out" \)
Length of output: 956
Script:
#!/bin/bash
# Let's check the actual test directory structure to understand the correct paths
# First, let's find the exact location of test_lammps_dpa_jax.py to understand its context
find source -type f -name "test_lammps_dpa_jax.py"
# Then, let's look at the directory structure around this file
find source -type d -name "tests" -o -name "infer"
# Finally, let's check for the test data files in the same directory as test_lammps_dpa_jax.py
find source -type f -name "*.lmp" -o -name "*.si" -o -name "*.pb" -o -name "*.out"
Length of output: 594
📝 Walkthrough
This pull request introduces several enhancements across various files, focusing on the integration of the JAX backend into the DeePMD-kit framework. Key changes include updates to the pre-commit configuration, documentation for backend support, and the introduction of new functionalities in the API and testing files. The JAX backend is now supported alongside TensorFlow and PyTorch, with corresponding updates made to installation instructions, code structures, and testing frameworks to ensure comprehensive functionality and compatibility.
Changes
Possibly related issues
Possibly related PRs
Suggested reviewers
Actionable comments posted: 18
🧹 Outside diff range and nitpick comments (23)
source/cmake/googletest.cmake.in (1)
Line range hint
1-1
: Consider updating minimum CMake version requirement.The current minimum CMake version (2.8.2) is quite old. Google Test 1.14.0 might benefit from newer CMake features. Consider updating to a more recent version (e.g., 3.10 or newer) to ensure better compatibility and access to modern CMake features.
doc/model/dpa2.md (1)
23-24
: Consider adding more context about the MPI limitation.While the limitation is clearly stated, it would be helpful to provide more context about why this limitation exists and whether there are plans to support multiple MPI ranks in the future.
Consider expanding the explanation:
-When using the JAX backend, 2 or more MPI ranks are not supported. One must set `map` to `yes` using the [`atom_modify`](https://docs.lammps.org/atom_modify.html) command. +When using the JAX backend, 2 or more MPI ranks are not currently supported due to JAX's parallel processing model. One must set `map` to `yes` using the [`atom_modify`](https://docs.lammps.org/atom_modify.html) command to ensure proper atom indexing and data mapping between LAMMPS and DPA-2.doc/backend.md (2)
34-36
: Enhance clarity of JAX backend documentation.The documentation would benefit from the following improvements:
- Explain the differences and use cases for
.xlo
vs.jax
formats- Provide more details about GPU device specificity, such as:
- Whether this applies to all formats or just specific ones
- How users can identify if a model is GPU-specific
- Consider reorganizing the version requirements to be more prominent
Here's a suggested improvement:
Only the `.savedmodel` format supports C++ inference, which needs the TensorFlow C++ interface. -The model is device-specific, so that the model generated on the GPU device cannot be run on the CPUs. +The model is device-specific: models generated on GPU devices cannot be executed on CPUs. This applies to all JAX model formats (.xlo, .savedmodel, and .jax). You can identify GPU-specific models by checking the device information in the model metadata. Currently, this backend is developed actively, and has no support for training.🧰 Tools
🪛 LanguageTool
[typographical] ~35-~35: The conjunction “so that” does not require a comma.
Context: ... interface. The model is device-specific, so that the model generated on the GPU device c...(SO_THAT_UNNECESSARY_COMMA)
35-35
: Remove unnecessary comma before "so that".The comma before "so that" is grammatically incorrect and should be removed.
-The model is device-specific, so that the model generated on the GPU device cannot be run on the CPUs. +The model is device-specific so that the model generated on the GPU device cannot be run on the CPUs.🧰 Tools
🪛 LanguageTool
[typographical] ~35-~35: The conjunction “so that” does not require a comma.
Context: ... interface. The model is device-specific, so that the model generated on the GPU device c...(SO_THAT_UNNECESSARY_COMMA)
source/api_cc/include/common.h (1)
16-16
: Consider adding enum documentation.Since this is a public API header, consider adding documentation comments for the
DPBackend
enum to describe each backend option and their implications.Example improvement:
+/** + * @brief Supported deep learning backends + * @details + * - TensorFlow: TensorFlow backend support + * - PyTorch: PyTorch backend support + * - Paddle: PaddlePaddle backend support + * - JAX: JAX backend support (requires JAX 0.4.33+) + * - Unknown: Represents an unrecognized backend + */ enum DPBackend { TensorFlow, PyTorch, Paddle, JAX, Unknown };source/lib/include/neighbor_list.h (3)
47-48
: Documentation could be more specific about mapping requirements.The comment should clarify:
- The expected size relationship with
nall
- Whether negative values are allowed in the mapping
- The ownership/lifetime of the pointer
- The relationship with LAMMPS atom indexing
Consider expanding the comment to:
- /// mapping from all atoms to real atoms, in the size of nall + /// Mapping array from all atoms (including ghost atoms) to real atoms. + /// Size must match nall (total number of atoms including ghost atoms). + /// The caller retains ownership of the pointer.
104-107
: Enhance documentation and consider adding validation.The method documentation should be as detailed as other methods in this file. Also, consider adding nullptr validation.
Consider these improvements:
/** - * @brief Set mapping for this neighbor list. + * @brief Set the mapping array for this neighbor list. + * @param mapping_ Pointer to an integer array mapping all atoms to real atoms. + * Must not be nullptr and must have size matching nall. + * @note The caller retains ownership of the mapping array and must ensure + * its lifetime exceeds that of the neighbor list. */ - void set_mapping(int* mapping_) { mapping = mapping_; }; + void set_mapping(int* mapping_) { + assert(mapping_ != nullptr); + mapping = mapping_; + };
Line range hint
47-107
: Consider architectural improvements for safer pointer management.While the current implementation follows existing patterns, consider these architectural improvements:
- Store
nall
as a member to enable size validation- Consider using
std::vector<int>
orstd::unique_ptr<int[]>
for clearer ownership- Add a method to validate the mapping array size
These changes would improve safety but would require more significant refactoring.
Would you like me to propose a more detailed design for these improvements?
doc/install/install-from-source.md (2)
300-302
: Add version compatibility information for JAX backend.While the documentation correctly states that JAX backend uses TensorFlow's C++ library, it would be helpful to specify:
- Minimum supported JAX version
- Version compatibility requirements between TensorFlow and JAX
Line range hint
380-396
: Clarify JAX-specific configuration options.The documentation updates for
ENABLE_TENSORFLOW
andTENSORFLOW_ROOT
now include JAX backend support. However, please clarify:
- Are there any JAX-specific CMake variables that users need to set?
- Are there any additional configuration steps needed when using JAX vs TensorFlow?
source/api_cc/tests/test_deeppot_jax.cc (2)
72-73
: Consider making the model file path configurable.The model file path is hardcoded which could make the tests less portable and harder to maintain. Consider:
- Using environment variables
- Making it a configurable parameter
- Using a test fixture to manage test resources
- std::string file_name = "../../tests/infer/deeppot_sea.savedmodel"; + const char* model_path = std::getenv("DEEPMD_TEST_MODEL_PATH"); + std::string file_name = model_path ? model_path : "../../tests/infer/deeppot_sea.savedmodel";
97-427
: Consider reducing code duplication in test cases.The test cases share similar setup and verification patterns. Consider extracting common test logic into helper functions to improve maintainability and reduce duplication. For example:
+ template <typename VALUETYPE> + void verify_results( + const std::vector<VALUETYPE>& force, + const std::vector<VALUETYPE>& virial, + const std::vector<VALUETYPE>& expected_f, + const std::vector<VALUETYPE>& expected_tot_v, + double ener, + double expected_tot_e, + int natoms) { + EXPECT_EQ(force.size(), natoms * 3); + EXPECT_EQ(virial.size(), 9); + EXPECT_LT(fabs(ener - expected_tot_e), EPSILON); + for (int ii = 0; ii < natoms * 3; ++ii) { + EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON); + } + for (int ii = 0; ii < 3 * 3; ++ii) { + EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON); + } + }source/api_c/src/c_api.cc (1)
46-48
: Consider consistent error handling.The function should follow the established error handling pattern used throughout the codebase. Consider using the
DP_REQUIRES_OK
macro or setting the exception string in theDP_Nlist
object when errors occur.void DP_NlistSetMapping(DP_Nlist* nl, int* mapping) { - nl->nl.set_mapping(mapping); + try { + nl->nl.set_mapping(mapping); + } catch (const std::exception& e) { + nl->exception = std::string(e.what()); + } }source/api_c/include/deepmd.hpp (1)
618-622
: Enhance documentation while implementation looks good.The implementation is correct and follows the established pattern of delegating to the C API. However, the documentation could be more detailed to help users understand:
- The expected size and lifetime requirements of the mapping array
- Whether the pointer is stored or just used temporarily
- The purpose and typical use cases for this mapping
Consider expanding the documentation like this:
/** * @brief Set mapping for this neighbor list. * @param mapping mapping from all atoms to real atoms, in size nall. + * @details The mapping array should remain valid for the lifetime of the neighbor list + * or until the next call to set_mapping. The mapping is typically used to handle + * ghost/virtual atoms by mapping them to their corresponding real atoms. + * @note The size of the mapping array should match the total number of atoms (nall) + * in the system. */source/api_cc/include/DeepPotJAX.h (2)
29-31
: Pass integer parameters by value instead of byconst int&
Passing integers like
gpu_rank
byconst int&
introduces unnecessary indirection since integers are small and copying them is inexpensive. It is more efficient and idiomatic in C++ to pass them by value.Apply this diff to update the parameter passing:
-DeepPotJAX(const std::string& model, - const int& gpu_rank = 0, - const std::string& file_content = ""); +DeepPotJAX(const std::string& model, + int gpu_rank = 0, + const std::string& file_content = ""); -void init(const std::string& model, - const int& gpu_rank = 0, - const std::string& file_content = ""); +void init(const std::string& model, + int gpu_rank = 0, + const std::string& file_content = "");Also applies to: 39-41
63-65
: Clarify the constant return value innumb_types_spin()
The method
numb_types_spin()
always returns0
. If spin types are not supported inDeepPotJAX
, consider documenting this behavior or modifying the method to reflect the intended use.You could update the method to throw an exception or assert if spin types are not applicable:
-int numb_types_spin() const { - assert(inited); - return 0; -}; +int numb_types_spin() const { + assert(inited); + throw std::runtime_error("Spin types are not supported in DeepPotJAX."); +};Alternatively, update the documentation to specify that this method returns
0
because spin types are unsupported in this implementation.source/api_cc/src/DeepPotJAX.cc (2)
35-38
: Optimize string manipulation by using 'resize' instead of 'substr'In
find_function
, the callname_ = name_.substr(0, pos + 1);
may result in self-assignment when the substring is the same as the original string. Usingresize
is more efficient and avoids unnecessary copying.Apply this diff to improve efficiency:
std::string::size_type pos = name_.find_last_not_of("0123456789_"); if (pos != std::string::npos) { - name_ = name_.substr(0, pos + 1); + name_.resize(pos + 1); }🧰 Tools
🪛 cppcheck
[performance] 37-37: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.
(uselessCallsSubstr)
284-286
: Ensure proper cleanup in destructor by checking 'status'In the destructor
~DeepPotJAX()
, after callingTF_DeleteSession
, it's good practice to check the status to ensure that the session was deleted without errors.Consider adding a status check:
TF_DeleteSession(session, status); +check_status(status); TF_DeleteGraph(graph);
source/lmp/tests/test_lammps_jax.py (1)
307-723
: Refactor repetitive test code into helper functions.Multiple test functions contain similar code blocks, such as setting up
pair_style
,pair_coeff
, running simulations, and performing assertions. Refactoring these blocks into helper functions can enhance readability and maintainability.source/lmp/tests/test_lammps_dpa_jax.py (2)
246-279
: Consider extracting common LAMMPS setup code into a separate function.The
_lammps
function contains a lot of common setup code for initializing the LAMMPS instance with specific units, boundary conditions, atom styles, etc. Consider extracting this into a separatesetup_lammps
function that can be reused across tests for better code organization and reusability.
679-726
: Consider enabling the skipped MPI tests if possible.The MPI tests are currently skipped due to MPI and mpi4py not being installed. Consider enabling these tests if possible by installing the necessary dependencies. MPI tests are important to ensure the DeepMD pair style works correctly in parallel.
source/api_cc/src/DeepPot.cc (2)
45-47
: Refactor file extension checks into a helper functionThe repeated pattern of checking model file extensions (e.g.,
.pth
,.pb
,.savedmodel
) could be refactored into a helper function to improve maintainability and reduce code duplication.
69-71
: Improve clarity of exception messageConsider rephrasing the exception message for better clarity.
Apply this diff to update the exception message:
throw deepmd::deepmd_exception( - "TensorFlow backend is not built, which is used to load JAX2TF " + "TensorFlow backend is not built; it is required to load JAX2TF " "SavedModels");
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (22)
- .pre-commit-config.yaml (1 hunks)
- doc/backend.md (1 hunks)
- doc/install/install-from-source.md (4 hunks)
- doc/model/dpa2.md (1 hunks)
- source/api_c/include/c_api.h (4 hunks)
- source/api_c/include/deepmd.hpp (1 hunks)
- source/api_c/src/c_api.cc (1 hunks)
- source/api_cc/include/DeepPotJAX.h (1 hunks)
- source/api_cc/include/common.h (1 hunks)
- source/api_cc/src/DeepPot.cc (3 hunks)
- source/api_cc/src/DeepPotJAX.cc (1 hunks)
- source/api_cc/tests/test_deeppot_jax.cc (1 hunks)
- source/cmake/googletest.cmake.in (1 hunks)
- source/lib/include/neighbor_list.h (2 hunks)
- source/lmp/fix_dplr.cpp (2 hunks)
- source/lmp/pair_deepmd.cpp (3 hunks)
- source/lmp/tests/test_lammps_dpa_jax.py (1 hunks)
- source/lmp/tests/test_lammps_jax.py (1 hunks)
- source/tests/infer/deeppot_dpa.savedmodel/.gitignore (1 hunks)
- source/tests/infer/deeppot_dpa.savedmodel/fingerprint.pb (1 hunks)
- source/tests/infer/deeppot_sea.savedmodel/.gitignore (1 hunks)
- source/tests/infer/deeppot_sea.savedmodel/fingerprint.pb (1 hunks)
✅ Files skipped from review due to trivial changes (4)
- source/tests/infer/deeppot_dpa.savedmodel/.gitignore
- source/tests/infer/deeppot_dpa.savedmodel/fingerprint.pb
- source/tests/infer/deeppot_sea.savedmodel/.gitignore
- source/tests/infer/deeppot_sea.savedmodel/fingerprint.pb
🧰 Additional context used
🪛 LanguageTool
doc/backend.md
[typographical] ~35-~35: The conjunction “so that” does not require a comma.
Context: ... interface. The model is device-specific, so that the model generated on the GPU device c...
(SO_THAT_UNNECESSARY_COMMA)
🪛 cppcheck
source/api_cc/src/DeepPotJAX.cc
[performance] 29-29: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 64-64: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 66-66: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 86-86: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 88-88: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 111-111: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 113-113: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 133-133: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 135-135: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 37-37: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.
(uselessCallsSubstr)
🔇 Additional comments (41)
source/cmake/googletest.cmake.in (2)
Line range hint 4-11
: LGTM: Repository fallback mechanism.
The fallback mechanism to use the Gitee mirror when `USE_GITEE_GTEST` is set is a good practice for users in regions with limited GitHub access.
14-14
: Verify compatibility with Google Test v1.14.0.
The upgrade from v1.12.1 to v1.14.0 brings improvements but may introduce breaking changes. Notable changes in v1.14.0 include:
- Support for C++20
- Improved death test implementation
- New matchers and assertions
Let's check for potential compatibility issues:
✅ Verification successful
Based on the analysis of the test files and Google Test usage patterns, I can now provide a definitive response about the compatibility with Google Test v1.14.0:
The upgrade to Google Test v1.14.0 is safe to proceed.
The codebase uses standard Google Test features that are stable across versions:
- Basic test fixtures using
::testing::Test
- Standard test macros like
TEST_F
andTYPED_TEST
- No usage of death tests (
ASSERT_DEATH
,EXPECT_DEATH
, etc.) - No custom test event listeners or printers
- No deprecated features that were removed in v1.14.0
The test suite primarily uses:
- Type-parameterized tests with
TYPED_TEST_SUITE
andTYPED_TEST
- Test fixtures with basic class inheritance from
::testing::Test
- Standard assertions and test declarations
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:

```bash
#!/bin/bash
# Description: Search for Google Test usage patterns that might be affected by the version upgrade
# Look for death tests that might need updates
rg -A 5 "ASSERT_DEATH|EXPECT_DEATH|ASSERT_EXIT|EXPECT_EXIT"
# Check for custom test event listeners that might need updates
ast-grep --pattern 'class $_ : public testing::TestEventListener'
# Look for custom test printers that might need updates
ast-grep --pattern 'class $_ : public testing::PrintTo'
```

Length of output: 187

Script:

```bash
#!/bin/bash
# Let's check for actual test files and Google Test usage patterns
# Find test files
fd -e cpp -e cc -e h test
# Look for Google Test includes and main patterns
rg -l "gtest|gmock"
# Check CMake files for any Google Test specific configurations
fd -e cmake | xargs cat
# Look for any test executables or test suites
rg -l "TEST_F|TEST_P|TEST|INSTANTIATE_TEST"
```

Length of output: 30214

Script:

```bash
#!/bin/bash
# Let's examine the test files to understand the Google Test usage patterns
# Look for specific Google Test features that might be affected by version changes
rg -A 3 "TEST_F|TEST_P|TYPED_TEST|INSTANTIATE_TEST" source/lib/tests source/api_c/tests source/api_cc/tests
# Check for any custom Google Test configurations or setup
rg -A 3 "testing::InitGoogleTest|testing::Environment|testing::Test" source/lib/tests source/api_c/tests source/api_cc/tests
# Look for any death test implementations
rg -A 3 "ASSERT_DEATH|EXPECT_DEATH|ASSERT_EXIT|EXPECT_EXIT" source/lib/tests source/api_c/tests source/api_cc/tests
# Check for any test fixture setup that might be affected
rg -A 3 "class.*: public ::testing::Test" source/lib/tests source/api_c/tests source/api_cc/tests
```

Length of output: 179258
doc/model/dpa2.md (2)

21-22: LGTM! Clear section header with appropriate backend icon.
The section header is well-formatted and correctly uses the JAX icon to indicate backend-specific content.

25-27: LGTM! Clear example with proper LAMMPS syntax.
The code example is correctly formatted and demonstrates the required command.

.pre-commit-config.yaml (1)

18-19: LGTM! Exclusion patterns properly added for model files.
The additional exclusions for `deeppolar_new.pbtxt` and `deeppot_dpa.savedmodel/saved_model.pb` are correctly formatted and aligned with the PR's objective of supporting DPA-2 models.
Let's verify the size of these excluded files to ensure they indeed need to be excluded:
✅ Verification successful
Exclusion patterns correctly added for large model files
The verification confirms that both files are indeed large binary files that warrant exclusion:
- `source/tests/infer/deeppolar_new.pbtxt`: 2.6 MB
- `source/tests/infer/deeppot_dpa.savedmodel/saved_model.pb`: 1.5 MB

These sizes justify their exclusion from pre-commit checks to maintain good performance of the git hooks.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:

```bash
#!/bin/bash
# Description: Check if the excluded files exist and their sizes
# to validate the need for exclusion from large file checks
echo "Checking sizes of excluded files..."
find . -type f \( -name "deeppolar_new.pbtxt" -o -name "saved_model.pb" \) -exec ls -lh {} \;
```

Length of output: 483
source/api_cc/tests/test_deeppot_jax.cc (4)

1-15: LGTM! Well-organized includes and proper license header.
The file structure follows best practices with system headers, C++ standard library headers, and project headers properly organized.

16-93: LGTM! Well-structured test class with comprehensive documentation.
The test class is well-documented with the numpy code that generated the test data, making it easy to understand and maintain. The setup properly validates data sizes and initializes all necessary components.

429-439: LGTM! Good coverage of utility functions.
The tests properly verify the summary printing and type map retrieval functionality.

132-138: Verify EPSILON definition.
The code uses EPSILON for floating-point comparisons, but its definition is not visible in this file. Ensure it's properly defined with an appropriate value for the precision requirements.
✅ Verification successful
Let me verify if the test file includes the necessary header with EPSILON definition.
EPSILON is properly defined and included
The test file correctly includes "test_utils.h", which defines EPSILON as `(std::is_same<VALUETYPE, double>::value ? 1e-10 : 1e-4)`. This is a suitable definition that adjusts the precision threshold based on the value type being used (1e-10 for double, 1e-4 for float).
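The same dtype-dependent tolerance idea can be sketched outside of Google Test. The snippet below is an illustrative Python analogue, not project code; `epsilon` is a hypothetical helper mirroring the `test_utils.h` definition (1e-10 for double precision, 1e-4 for single precision).

```python
import numpy as np

def epsilon(dtype) -> float:
    # Mirror of the C++ idiom: a tight tolerance for double
    # precision, a looser one for single precision.
    return 1e-10 if np.dtype(dtype) == np.float64 else 1e-4

# Compare computed vs. expected values at the dtype's tolerance.
expected = np.array([1.0, 2.0, 3.0], dtype=np.float32)
computed = expected + 1e-6  # well within the float32 tolerance
assert np.allclose(computed, expected, atol=epsilon(np.float32))
assert epsilon(np.float64) == 1e-10
```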
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:

```bash
#!/bin/bash
# Search for EPSILON definition in the codebase
rg -A 1 "^#define\s+EPSILON\b"
```

Length of output: 657

Script:

```bash
#!/bin/bash
# Check if test_utils.h is included in test_deeppot_jax.cc
rg -A 1 '#include.*test_utils\.h' source/api_cc/tests/test_deeppot_jax.cc
```
Length of output: 100
source/api_c/include/c_api.h (1)

15-15: LGTM: API version increment is appropriate.
The increment from 23 to 24 aligns with the addition of the new `DP_NlistSetMapping` function.

source/api_c/src/c_api.cc (2)

46-48: LGTM - Good integration with existing codebase.
The function is well-integrated into the codebase:
- Follows established naming conventions
- Placed appropriately with other DP_Nlist functions
- Maintains consistent implementation pattern

46-48: Verify usage and impact of the new mapping functionality.
Let's verify how this new mapping functionality is used in the codebase and ensure it doesn't affect existing neighbor list behavior.
✅ Verification successful
New mapping functionality is properly integrated and used in existing codebase
The verification shows that:
- The `set_mapping` functionality is a well-integrated part of the neighbor list implementation in `source/lib/include/neighbor_list.h`
- It's actively used in production code:
  - LAMMPS integration (`pair_deepmd.cpp`, `fix_dplr.cpp`)
  - Core neighbor list implementation
  - C/C++ API layers
- The implementation is consistent across all layers (core, API, and integrations)
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:

```bash
#!/bin/bash
# Description: Check for potential usage and impact of the new mapping functionality
# Test 1: Find potential callers of the new function
echo "Searching for potential callers of DP_NlistSetMapping..."
rg -l "DP_NlistSetMapping"
# Test 2: Find related neighbor list mapping code
echo "Searching for related neighbor list mapping code..."
rg -l "set_mapping.*neighbor"
# Test 3: Check for potential test files
echo "Searching for related test files..."
fd -g "*test*" | rg -l "Nlist|mapping"
```

Length of output: 503

Script:

```bash
#!/bin/bash
# Let's examine the implementation and usage more thoroughly
# Check the header file definition
echo "=== C API Header Definition ==="
cat source/api_c/include/c_api.h | grep -A 5 "DP_NlistSetMapping"
# Check the C++ header definition
echo -e "\n=== C++ Header Definition ==="
cat source/api_c/include/deepmd.hpp | grep -A 5 "DP_NlistSetMapping"
# Look for the base Nlist class implementation
echo -e "\n=== Searching for Nlist class implementation ==="
rg -l "class.*Nlist"
# If we find the class file, let's examine the set_mapping method
echo -e "\n=== Examining set_mapping implementation ==="
rg "set_mapping" --type cpp -A 5
# Check if there are any existing usages of mapping in neighbor list context
echo -e "\n=== Checking existing mapping usages ==="
rg "mapping.*neighbor|neighbor.*mapping" --type cpp
```
Length of output: 5281
source/api_c/include/deepmd.hpp (1)

618-622: Verify the usage of set_mapping in the codebase.
The method is well-integrated and follows the established pattern. Let's verify its usage across the codebase to ensure proper adoption.
✅ Verification successful
The `set_mapping` method is properly integrated and used as intended
The verification shows the method is well-integrated and used appropriately:
- Implemented in `neighbor_list.h` as core functionality
- Called by LAMMPS integration code in `pair_deepmd.cpp` and `fix_dplr.cpp` to map atom tags
- Properly wrapped in the C API (`c_api.cc`) for external usage

The implementation follows the same pattern across all usage points, mapping from atom tags to real atoms consistently.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:

```bash
#!/bin/bash
# Description: Check for potential callers of set_mapping to ensure proper integration
# Test 1: Find potential callers in the codebase
echo "Searching for potential callers of set_mapping..."
rg -l "InputNlist" | xargs rg "set_mapping"
# Test 2: Look for related LAMMPS integration code
echo "Checking LAMMPS integration..."
rg -l "LAMMPS" | xargs rg "mapping.*atom"
```
Length of output: 1277
source/api_cc/include/DeepPotJAX.h (1)

93-96: Verify the implementation of is_aparam_nall()
The method `is_aparam_nall()` always returns `false`. Please confirm if this is the intended behavior. If `aparam` is never of dimension `nall`, it might be clearer to document this explicitly or adjust the method to better reflect its purpose.

source/api_cc/src/DeepPotJAX.cc (1)

334-339: Avoid unnecessary casting to double when using float model
The comment suggests that casting to double may be unnecessary if using a float model. Ensure that this casting is intentional and necessary; otherwise, it could lead to precision issues or unnecessary conversions.
Please confirm whether this casting is required. If not, consider modifying the code to avoid unnecessary conversions.
source/lmp/tests/test_lammps_dpa_jax.py (21)

1-11: LGTM!
The imports look good and follow the standard convention.

35-138: LGTM!
The expected values for energy, forces, and virial are defined correctly using numpy arrays.

210-223: LGTM!
The box coordinates, atom coordinates, and atom types are defined correctly using numpy arrays.

230-238: LGTM!
The `setup_module` function correctly writes the LAMMPS data files using the `write_lmp_data` function with the defined box, coordinates, and atom types.

241-244: LGTM!
The `teardown_module` function removes the data files after the tests complete.

281-306: LGTM!
The pytest fixtures for creating LAMMPS instances with different units and data files are set up correctly.

309-319: LGTM!
The `test_pair_deepmd` function correctly tests the DeepMD pair style by comparing the potential energy and forces against expected values.

321-342: LGTM!
The `test_pair_deepmd_virial` function correctly tests the virial computation with the DeepMD pair style by comparing against expected values.

344-368: LGTM!
The `test_pair_deepmd_model_devi` function correctly tests the model deviation output by comparing against expected values loaded from the `md_file`.

370-406: LGTM!
The `test_pair_deepmd_model_devi_virial` function correctly tests the model deviation output with virial computation by comparing against expected values.

408-435: LGTM!
The `test_pair_deepmd_model_devi_atomic_relative` function correctly tests the model deviation output with relative atomic deviations.

437-468: LGTM!
The `test_pair_deepmd_model_devi_atomic_relative_v` function correctly tests the model deviation output with relative atomic deviations for virials.

470-480: LGTM!
The `test_pair_deepmd_type_map` function correctly tests the DeepMD pair style with a type map by comparing against expected values.

482-493: LGTM!
The `test_pair_deepmd_real` function correctly tests the DeepMD pair style with real units by comparing against expected values converted to real units.

496-521: LGTM!
The `test_pair_deepmd_virial_real` function correctly tests the virial computation with real units by comparing against expected values converted to real units.

523-551: LGTM!
The `test_pair_deepmd_model_devi_real` function correctly tests the model deviation output with real units by comparing against expected values converted to real units.

553-594: LGTM!
The `test_pair_deepmd_model_devi_virial_real` function correctly tests the model deviation output with virial computation in real units.

597-627: LGTM!
The `test_pair_deepmd_model_devi_atomic_relative_real` function correctly tests the model deviation output with relative atomic deviations in real units.

630-664: LGTM!
The `test_pair_deepmd_model_devi_atomic_relative_v_real` function correctly tests the model deviation output with relative atomic deviations for virials in real units.

667-676: LGTM!
The `test_pair_deepmd_si` function correctly tests the DeepMD pair style with SI units by comparing against expected values converted to SI units.

225-227: Ensure the command to convert pbtxt to pb runs successfully.
Please verify that the command to convert the pbtxt file to pb file using `deepmd-kit` runs without any errors.
```bash
#!/bin/bash
# Description: Verify pbtxt to pb conversion command runs successfully.
# Test: Run the conversion command. Expect: No output on successful execution.
python -m deepmd convert-from pbtxt -i source/tests/infer/deeppot-1.pbtxt -o source/lmp/tests/graph2.pb 2>&1 >/dev/null || echo "Failed to convert pbtxt to pb"
```
source/api_cc/src/DeepPot.cc (2)

10-10: Include `DeepPotJAX.h` conditionally
Including `DeepPotJAX.h` within the `#ifdef BUILD_TENSORFLOW` block ensures that the header is only included when TensorFlow support is enabled, which is appropriate.

65-72: Correct handling of JAX backend initialization
The addition of JAX backend support with appropriate conditional compilation and exception handling is correctly implemented.

source/lmp/pair_deepmd.cpp (3)

575-577: Verify the mapping is correctly set for the main neighbor list.
The code correctly sets the mapping for the main neighbor list `lmp_list` when the number of processes is one and the mapping style is not `MAP_NONE`. This ensures that the atom indices in the neighbor list correctly correspond to the global atom indices.

588-590: Verify the mapping is correctly set for the extended neighbor list.
The code correctly sets the mapping for the extended neighbor list `extend_lmp_list` when the number of processes is one and the mapping style is not `MAP_NONE`. This ensures that the atom indices in the extended neighbor list, which includes ghost atoms, correctly correspond to the global atom indices.

524-530: Verify the mapping vector is correctly set for all atoms.
The code correctly sets the `mapping_vec` for all atoms when the number of processes is one and the mapping style is not `MAP_NONE`. However, ensure that the mapping is correctly utilized in other parts of the code where atom indices are used, such as neighbor lists, force calculations, etc.
Run the following script to verify the mapping usage:
Run the following script to verify the mapping usage:
✅ Verification successful
Let me gather more information about how this mapping vector is used in the codebase.
Let me try one more search to understand how this mapping vector is utilized:
Based on the search results, I can now provide a final response:
The mapping vector implementation is correct and properly utilized.
The `mapping_vec` is consistently used in both `pair_deepmd.cpp` and `fix_dplr.cpp` with the same implementation pattern:
- It's only used when running on a single process (`comm->nprocs == 1`) and when atom mapping is enabled (`atom->map_style != Atom::MAP_NONE`)
- The mapping is properly passed to the neighbor list through `lmp_list.set_mapping(mapping_vec.data())`
- The implementation correctly maps all atoms (both local and ghost) as required
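The mapping pattern discussed here can be sketched in a language-agnostic way. The snippet below is an illustrative Python analogue, not code from the patch; `build_mapping`, `tags`, and `nlocal` are hypothetical names. It mimics what a single-rank `mapping_vec[i] = atom->map(tag[i])` loop does: every atom (local or ghost) is mapped to the index of the real (local) atom carrying the same global tag.

```python
def build_mapping(tags, nlocal):
    """Map every atom (local + ghost) to the index of the real
    (local) atom with the same global tag."""
    # Real atoms are the first nlocal entries; ghosts follow.
    tag_to_local = {tags[i]: i for i in range(nlocal)}
    return [tag_to_local[t] for t in tags]

# 3 real atoms (tags 1, 2, 3) followed by 2 periodic-image ghosts
tags = [1, 2, 3, 1, 3]
mapping = build_mapping(tags, nlocal=3)
assert mapping == [0, 1, 2, 0, 2]  # ghosts point back at their real atoms
```

This also illustrates why the feature is restricted to one MPI rank: with domain decomposition, the real atom for a ghost's tag may live on another rank and have no local index.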
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:

```bash
#!/bin/bash
# Description: Verify `mapping_vec` is correctly used for atom indices in the codebase.
# Test: Search for potential mapping issues. Expect: No concerning usages.
rg -A 5 -B 5 $'atom->map\(\w+\)|mapping_vec\[\w+\]' -g '*.cpp' -g '*.h' -g '*.cu'
```

Length of output: 1291

Script:

```bash
#!/bin/bash
# Search for usage of mapping_vec in the codebase
rg -A 10 -B 10 "mapping_vec" --type cpp --type hpp --type cu
# Search for any potential atom mapping or indexing related patterns
ast-grep --pattern 'std::vector<int> mapping_vec = $_'
# Look for the broader context of the files where mapping_vec is used
rg -l "mapping_vec" | xargs -I {} head -n 1 {}
```

Length of output: 288

Script:

```bash
#!/bin/bash
# Search for the complete function context where mapping_vec is used
rg -A 30 -B 30 "mapping_vec" source/lmp/pair_deepmd.cpp source/lmp/fix_dplr.cpp
# Look for any atom mapping related functions or methods
ast-grep --pattern 'void $_($$$) {
  $$$
  mapping_vec
  $$$
}'
```
Length of output: 13444
```diff
@@ -13,7 +13,7 @@
 namespace deepmd {
 typedef double ENERGYTYPE;
-enum DPBackend { TensorFlow, PyTorch, Paddle, Unknown };
+enum DPBackend { TensorFlow, PyTorch, Paddle, JAX, Unknown };
```
💡 Codebase verification
Update if-else chains in multiple files to handle JAX backend
The codebase has several if-else chains that need to be updated to handle the new JAX backend:
- `source/api_cc/src/DeepPot.cc`: Missing JAX case in backend comparison
- `source/api_cc/src/DataModifier.cc`: Missing JAX case in backend comparison
- `source/api_cc/src/DeepTensor.cc`: Missing JAX case in backend comparison

Each of these files needs to add a new condition `else if (deepmd::DPBackend::JAX == backend)` before the final else clause to properly handle the JAX backend.
🔗 Analysis chain
LGTM! Verify enum usage across codebase.
The addition of `JAX` to the `DPBackend` enum before `Unknown` is correct and maintains backward compatibility.
Let's verify the enum usage across the codebase:
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:

```bash
#!/bin/bash
# Description: Check for DPBackend enum usage and potential switch statements that need updating
# Search for switch statements on DPBackend that might need updating
rg -A 10 "switch.*DPBackend"
# Search for direct enum value comparisons
rg "DPBackend::(TensorFlow|PyTorch|Paddle|Unknown)"
```
Length of output: 1157
```c
/**
 * @brief Set mapping for a neighbor list.
 *
 * @param nl Neighbor list.
 * @param mapping mapping from all atoms to real atoms, in size nall.
 * @since API version 24
 *
 **/
extern void DP_NlistSetMapping(DP_Nlist* nl, int* mapping);
```
🛠️ Refactor suggestion
Consider adding size parameter for safety.
While the function signature and documentation are well-structured, consider adding a size parameter for the mapping array to prevent potential buffer overflows. This would allow runtime validation of the mapping array size against the expected size (`nall`).

```diff
-extern void DP_NlistSetMapping(DP_Nlist* nl, int* mapping);
+extern void DP_NlistSetMapping(DP_Nlist* nl, int* mapping, int mapping_size);
```
💡 Codebase verification
Based on the code inspection, I can now provide the final response:
Buffer overflow risk confirmed - no size validation for mapping array
The implementation in both the C API (`source/api_c/src/c_api.cc`) and the underlying C++ class (`source/lib/include/neighbor_list.h`) shows that the `set_mapping` function directly assigns the pointer without validating its size against `nall`. This could lead to buffer overflows if the provided array is smaller than expected.
Key findings:
- The C API implementation simply forwards the pointer: `nl->nl.set_mapping(mapping);`
- The C++ implementation in the `InputNlist` class also directly assigns without validation: `void set_mapping(int* mapping_) { mapping = mapping_; };`
- The mapping array is documented to require size `nall`, but this is not enforced

🔗 Analysis chain
Add validation for mapping array size.
The function should validate that the mapping array size matches the expected size (`nall`) to prevent buffer overflows.
Let's verify if size validation is implemented in the source:
Let me check the C++ implementation of the `set_mapping` method to see if size validation is performed there.
Let me check the `NeighborList` class implementation to understand how the mapping is used and if there are any size validations.
Let me try a different approach to find the NeighborList implementation.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:

```bash
#!/bin/bash
# Description: Check if the implementation validates the mapping array size
# Test: Search for size validation in the implementation
rg -A 5 "DP_NlistSetMapping.*\{" source/api_c/src/
```

Length of output: 373

Script:

```bash
#!/bin/bash
# Search for set_mapping implementation in the C++ code
ast-grep --pattern 'set_mapping($$$) { $$$ }'
# Also search for any references to mapping size validation
rg -i "set_mapping|mapping.*size" source/
```

Length of output: 2662

Script:

```bash
#!/bin/bash
# Search for NeighborList class implementation and usage of mapping
rg -A 10 "class.*NeighborList" source/lib/include/
rg -B 2 -A 5 "mapping\s*=" source/lib/src/neighbor_list.cc
```

Length of output: 113

Script:

```bash
#!/bin/bash
# Search for the NeighborList class definition and implementation
rg -A 10 "class.*NeighborList" source/
cat source/lib/include/neighbor_list.h
```
Length of output: 13092
```cpp
void DP_NlistSetMapping(DP_Nlist* nl, int* mapping) {
  nl->nl.set_mapping(mapping);
}
```
🛠️ Refactor suggestion
Add parameter validation and documentation.
While the implementation is correct, consider adding:
- Null pointer validation for parameters
- Documentation explaining the expected format and size of the mapping array
- Error handling for invalid inputs

Note that `nl` must be checked before `mapping`, since recording an error on a null `nl` would itself dereference a null pointer:

```diff
+// Set the mapping for the neighbor list
+// @param nl: Pointer to the neighbor list. Must not be null.
+// @param mapping: Array of integers defining the mapping. Must not be null.
 void DP_NlistSetMapping(DP_Nlist* nl, int* mapping) {
+  if (nl == nullptr) {
+    return;  // cannot record an error without a valid handle
+  }
+  if (mapping == nullptr) {
+    nl->exception = "Invalid null pointer in DP_NlistSetMapping";
+    return;
+  }
   nl->nl.set_mapping(mapping);
 }
```
```cpp
 * @param[in] atomic Whether to compute atomic energy and virial.
 **/
template <typename VALUETYPE>
void compute(std::vector<ENERGYTYPE>& ener,
```
Undefined type `ENERGYTYPE` in template function
The template function `compute` uses `ENERGYTYPE` for the `ener` parameter, but `ENERGYTYPE` is not defined within this header file. This may lead to compilation errors.
Please ensure that `ENERGYTYPE` is defined or include the appropriate header file where `ENERGYTYPE` is declared.
```cpp
TF_Graph* graph;
TF_Status* status;
TF_Session* session;
TF_SessionOptions* sessionopts;
TFE_ContextOptions* ctx_opts;
TFE_Context* ctx;
std::vector<TF_Function*> func_vector;
```
Ensure proper cleanup of TensorFlow C API objects in destructor
The class utilizes several TensorFlow C API objects (`graph`, `status`, `session`, `sessionopts`, `ctx_opts`, `ctx`, `func_vector`). To prevent memory leaks, it's crucial to release these resources appropriately in the destructor.
Consider adding cleanup code in the destructor `~DeepPotJAX()`:

```diff
+DeepPotJAX::~DeepPotJAX() {
+  if (session) TF_DeleteSession(session, status);
+  if (graph) TF_DeleteGraph(graph);
+  if (status) TF_DeleteStatus(status);
+  if (sessionopts) TF_DeleteSessionOptions(sessionopts);
+  if (ctx_opts) TFE_DeleteContextOptions(ctx_opts);
+  if (ctx) TFE_DeleteContext(ctx);
+  for (auto func : func_vector) {
+    TF_DeleteFunction(func);
+  }
+}
```

Ensure that all TensorFlow objects are properly deleted and that error checking is implemented where necessary.
Committable suggestion skipped: line range outside the PR's diff.
```cpp
inline T get_scalar(TFE_Context* ctx,
                    const std::string func_name,
                    const std::vector<TF_Function*>& funcs,
                    const std::string device,
                    TF_Status* status) {
```
🛠️ Refactor suggestion
Pass 'func_name' and 'device' by const reference in 'get_scalar'
In the function `get_scalar`, passing `func_name` and `device` by `const std::string&` avoids unnecessary string copying and enhances performance.
Apply this diff to update the parameter passing:

```diff
 template <typename T>
 inline T get_scalar(TFE_Context* ctx,
-                    const std::string func_name,
+                    const std::string& func_name,
                     const std::vector<TF_Function*>& funcs,
-                    const std::string device,
+                    const std::string& device,
                     TF_Status* status) {
```
🧰 Tools
🪛 cppcheck
[performance] 86-86: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 88-88: Function parameter 'device' should be passed by const reference.
(passedByValue)
```cpp
DPGetDeviceCount(gpu_num);  // check current device environment
DPErrcheck(DPSetDevice(gpu_rank % gpu_num));
if (gpu_num > 0) {
  device = "/gpu:" + std::to_string(gpu_rank % gpu_num);
```
Handle potential division by zero when computing 'gpu_rank % gpu_num'
In the code block where `gpu_num` is used, ensure that `gpu_num` is not zero before performing the modulo operation to avoid a division by zero error.
Apply this diff to add a check for `gpu_num`:

```diff
 int gpu_num;
 DPGetDeviceCount(gpu_num);  // check current device environment
+DPErrcheck(gpu_num != 0, "GPU device count is zero.");
 DPErrcheck(DPSetDevice(gpu_rank % gpu_num));
```

Or modify the logic to handle the case when `gpu_num` is zero.
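The guard can also be expressed as a fallback to the CPU device. The sketch below is an illustrative Python analogue, not project code; `pick_device`, `gpu_rank`, and `gpu_num` are hypothetical names standing in for the C++ variables:

```python
def pick_device(gpu_rank: int, gpu_num: int) -> str:
    # gpu_rank % gpu_num is undefined when gpu_num == 0,
    # so fall back to the CPU device in that case.
    if gpu_num == 0:
        return "/cpu:0"
    return "/gpu:" + str(gpu_rank % gpu_num)

assert pick_device(0, 0) == "/cpu:0"   # no GPU visible: no modulo taken
assert pick_device(3, 2) == "/gpu:1"   # round-robin over available GPUs
```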
Committable suggestion skipped: line range outside the PR's diff.
```python
def teardown_module():
    os.remove(data_file)
    os.remove(data_type_map_file)
```
Add cleanup for `data_file_si` in `teardown_module()`.
The file `data_file_si` created in `setup_module()` is not being removed in `teardown_module()`. This could lead to leftover files after tests are run.
Apply this diff to fix the issue:

```diff
 def teardown_module():
     os.remove(data_file)
     os.remove(data_type_map_file)
+    os.remove(data_file_si)
```
Ensure `pb_file2` is removed in `teardown_module()`.
The file `pb_file2` generated in the setup is not being deleted in `teardown_module()`, potentially leaving temporary files after tests are executed.
Apply this diff to fix the issue:

```diff
 def teardown_module():
     os.remove(data_file)
     os.remove(data_type_map_file)
+    os.remove(pb_file2)
```
```python
sp.check_call(
    [
        "mpirun",
        "-n",
        "2",
        sys.executable,
        Path(__file__).parent / "run_mpi_pair_deepmd.py",
        data_file,
        pb_file,
        pb_file2,
        md_file,
        f.name,
        *balance_args,
    ]
)
arr = np.loadtxt(f.name, ndmin=1)
```
Avoid using `NamedTemporaryFile` with subprocess on Windows.
Using `tempfile.NamedTemporaryFile()` may cause issues on Windows when passing the filename to a subprocess because the file is already open. This can lead to permission errors.
Consider using `tempfile.mkstemp()` to create a temporary file safely (note that dropping the `with` block also requires dedenting its body):

```diff
 def test_pair_deepmd_mpi(balance_args: list):
-    with tempfile.NamedTemporaryFile() as f:
-        sp.check_call(
-            [
-                "mpirun",
-                "-n",
-                "2",
-                sys.executable,
-                Path(__file__).parent / "run_mpi_pair_deepmd.py",
-                data_file,
-                pb_file,
-                pb_file2,
-                md_file,
-                f.name,
-                *balance_args,
-            ]
-        )
-        arr = np.loadtxt(f.name, ndmin=1)
+    fd, temp_path = tempfile.mkstemp()
+    os.close(fd)
+    sp.check_call(
+        [
+            "mpirun",
+            "-n",
+            "2",
+            sys.executable,
+            Path(__file__).parent / "run_mpi_pair_deepmd.py",
+            data_file,
+            pb_file,
+            pb_file2,
+            md_file,
+            temp_path,
+            *balance_args,
+        ]
+    )
+    arr = np.loadtxt(temp_path, ndmin=1)
+    os.remove(temp_path)
```
```python
pbtxt_file2 = (
    Path(__file__).parent.parent.parent / "tests" / "infer" / "deeppot-1.pbtxt"
)
pb_file = (
    Path(__file__).parent.parent.parent / "tests" / "infer" / "deeppot_dpa.savedmodel"
)
pb_file2 = Path(__file__).parent / "graph2.pb"
system_file = Path(__file__).parent.parent.parent / "tests"
data_file = Path(__file__).parent / "data.lmp"
data_file_si = Path(__file__).parent / "data.si"
data_type_map_file = Path(__file__).parent / "data_type_map.lmp"
md_file = Path(__file__).parent / "md.out"
```
💡 Codebase verification
Several test data files are missing from the expected locations
The verification reveals that while some source files exist (`deeppot-1.pbtxt` and `deeppot_dpa.savedmodel`), the following test data files are missing:
source/lmp/tests/graph2.pb
source/lmp/tests/data.lmp
source/lmp/tests/data.si
source/lmp/tests/data_type_map.lmp
source/lmp/tests/md.out
These files appear to be test data files that should be present for the tests to run properly. You'll need to either:
- Generate these test data files if they're meant to be created during test setup
- Add these missing test data files to the repository
- Update the test to use existing test data files from the correct locations
🔗 Analysis chain
Verify the file paths are correct and exist.
Please ensure that the file paths used for `pbtxt_file2`, `pb_file`, `pb_file2`, `system_file`, `data_file`, `data_file_si`, `data_type_map_file`, and `md_file` point to the correct locations and that the files exist.
Let me verify the existence of these files with a more direct approach.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
```shell
#!/bin/bash
# Description: Verify the file paths used in the test file exist.
# Test: Check if the files exist. Expect: No output if files exist.
fd -e pbtxt -e pb -e lmp -e out | xargs -I {} sh -c 'test -f "$1" || echo "$1 does not exist"' -- {}
```
Length of output: 248
Script:
```shell
#!/bin/bash
# Check for the existence of specific files mentioned in the test
ls -l source/tests/infer/deeppot-1.pbtxt 2>/dev/null || echo "deeppot-1.pbtxt not found"
ls -l source/tests/infer/deeppot_dpa.savedmodel 2>/dev/null || echo "deeppot_dpa.savedmodel not found"
ls -l source/lmp/tests/graph2.pb 2>/dev/null || echo "graph2.pb not found"
ls -l source/lmp/tests/data.lmp 2>/dev/null || echo "data.lmp not found"
ls -l source/lmp/tests/data.si 2>/dev/null || echo "data.si not found"
ls -l source/lmp/tests/data_type_map.lmp 2>/dev/null || echo "data_type_map.lmp not found"
ls -l source/lmp/tests/md.out 2>/dev/null || echo "md.out not found"
# Also check the directory structure
tree source/tests/infer/
tree source/lmp/tests/
```
Length of output: 2384
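The same existence check can be done from Python with `pathlib`, which is handy as a test-setup guard. The base directories and file names below simply mirror the paths discussed in this review and are assumptions about the repository layout:

```python
from pathlib import Path

# Hypothetical locations, mirroring the test data paths discussed above.
REQUIRED = [
    "source/tests/infer/deeppot-1.pbtxt",
    "source/tests/infer/deeppot_dpa.savedmodel",
    "source/lmp/tests/graph2.pb",
    "source/lmp/tests/data.lmp",
    "source/lmp/tests/data.si",
    "source/lmp/tests/data_type_map.lmp",
    "source/lmp/tests/md.out",
]


def missing_files(paths):
    """Return the subset of paths that do not exist on disk."""
    return [p for p in paths if not Path(p).exists()]


if __name__ == "__main__":
    # Report each missing test data file, one per line.
    for p in missing_files(REQUIRED):
        print(f"{p} not found")
```

A check like this could be called from a test fixture to skip (or fail early in) the LAMMPS tests with a clear message when the generated model files have not been produced yet.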
I will merge this PR to #4307.
Based on #4307 (which should be merged first). The testing model is generated after #4315 is applied.
Summary by CodeRabbit

Release Notes

New Features
- Introduced the `DeepPotJAX` class for TensorFlow integration, facilitating advanced computations.

Documentation

Bug Fixes
- Fixes in the `FixDPLR` and `PairDeepMD` classes.

Tests