diff --git a/.github/workflows/test_cc.yml b/.github/workflows/test_cc.yml index 2082e7e4cc..d98f8ca58d 100644 --- a/.github/workflows/test_cc.yml +++ b/.github/workflows/test_cc.yml @@ -56,7 +56,7 @@ jobs: TF_INTRA_OP_PARALLELISM_THREADS: 1 TF_INTER_OP_PARALLELISM_THREADS: 1 LAMMPS_PLUGIN_PATH: ${{ github.workspace }}/dp_test/lib/deepmd_lmp - LD_LIBRARY_PATH: ${{ github.workspace }}/dp_test/lib + LD_LIBRARY_PATH: ${{ github.workspace }}/dp_test/lib:${{ github.workspace }}/libtorch/lib if: ${{ !matrix.check_memleak }} # test ipi - run: pytest --cov=deepmd source/ipi/tests @@ -65,7 +65,7 @@ jobs: TF_INTRA_OP_PARALLELISM_THREADS: 1 TF_INTER_OP_PARALLELISM_THREADS: 1 PATH: ${{ github.workspace }}/dp_test/bin:$PATH - LD_LIBRARY_PATH: ${{ github.workspace }}/dp_test/lib + LD_LIBRARY_PATH: ${{ github.workspace }}/dp_test/lib:${{ github.workspace }}/libtorch/lib if: ${{ !matrix.check_memleak }} - uses: codecov/codecov-action@v4 env: diff --git a/.github/workflows/test_cuda.yml b/.github/workflows/test_cuda.yml index 0d934e6d77..915d983663 100644 --- a/.github/workflows/test_cuda.yml +++ b/.github/workflows/test_cuda.yml @@ -38,6 +38,8 @@ jobs: with: useLocalCache: true useCloudCache: false + - name: Install wget and unzip + run: apt-get update && apt-get install -y wget unzip - run: | wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.0-1_all.deb \ && sudo dpkg -i cuda-keyring_1.0-1_all.deb \ @@ -53,7 +55,13 @@ jobs: DP_ENABLE_NATIVE_OPTIMIZATION: 1 - run: dp --version - run: python -m pytest source/tests --durations=0 - - run: source/install/test_cc_local.sh + - name: Download libtorch + run: | + wget https://download.pytorch.org/libtorch/cu121/libtorch-cxx11-abi-shared-with-deps-2.2.1%2Bcu121.zip -O libtorch.zip + unzip libtorch.zip + - run: | + export CMAKE_PREFIX_PATH=$GITHUB_WORKSPACE/libtorch + source/install/test_cc_local.sh env: OMP_NUM_THREADS: 1 TF_INTRA_OP_PARALLELISM_THREADS: 1 @@ -63,7 +71,7 @@ jobs: DP_VARIANT: cuda DP_USE_MPICH2: 1 - run: | - export LD_LIBRARY_PATH=$GITHUB_WORKSPACE/dp_test/lib:$CUDA_PATH/lib64:$LD_LIBRARY_PATH + export LD_LIBRARY_PATH=$GITHUB_WORKSPACE/dp_test/lib:$GITHUB_WORKSPACE/libtorch/lib:$CUDA_PATH/lib64:$LD_LIBRARY_PATH export PATH=$GITHUB_WORKSPACE/dp_test/bin:$PATH python -m pytest source/lmp/tests python -m pytest source/ipi/tests diff --git a/deepmd/pt/model/model/make_model.py b/deepmd/pt/model/model/make_model.py index 79634186e4..b6478d297f 100644 --- a/deepmd/pt/model/model/make_model.py +++ b/deepmd/pt/model/model/make_model.py @@ -17,6 +17,9 @@ communicate_extended_output, fit_output_to_model_output, ) +from deepmd.pt.utils import ( + env, +) from deepmd.pt.utils.nlist import ( extend_input_and_build_neighbor_list, nlist_distinguish_types, @@ -115,6 +118,9 @@ def forward_common( The keys are defined by the `ModelOutputDef`. """ + coord = coord.to(env.GLOBAL_PT_FLOAT_PRECISION) + if box is not None: + box = box.to(env.GLOBAL_PT_FLOAT_PRECISION) ( extended_coord, extended_atype, @@ -183,6 +189,7 @@ def forward_common_lower( the result dict, defined by the `FittingOutputDef`. """ + extended_coord = extended_coord.to(env.GLOBAL_PT_FLOAT_PRECISION) nframes, nall = extended_atype.shape[:2] extended_coord = extended_coord.view(nframes, -1, 3) nlist = self.format_nlist(extended_coord, extended_atype, nlist) diff --git a/source/api_cc/include/DeepPotPT.h b/source/api_cc/include/DeepPotPT.h new file mode 100644 index 0000000000..1b757069c3 --- /dev/null +++ b/source/api_cc/include/DeepPotPT.h @@ -0,0 +1,332 @@ +// SPDX-License-Identifier: LGPL-3.0-or-later +#pragma once + +#include + +#include "DeepPot.h" +#include "commonPT.h" + +namespace deepmd { +/** + * @brief PyTorch implementation for Deep Potential. + **/ +class DeepPotPT : public DeepPotBase { + public: + /** + * @brief DP constructor without initialization. + **/ + DeepPotPT(); + ~DeepPotPT(); + /** + * @brief DP constructor with initialization. + * @param[in] model The name of the frozen model file. + * @param[in] gpu_rank The GPU rank. Default is 0. + * @param[in] file_content The content of the model file. If it is not empty, + *DP will read from the string instead of the file. + **/ + DeepPotPT(const std::string& model, + const int& gpu_rank = 0, + const std::string& file_content = ""); + /** + * @brief Initialize the DP. + * @param[in] model The name of the frozen model file. + * @param[in] gpu_rank The GPU rank. Default is 0. + * @param[in] file_content The content of the model file. If it is not empty, + *DP will read from the string instead of the file. + **/ + void init(const std::string& model, + const int& gpu_rank = 0, + const std::string& file_content = ""); + + private: + /** + * @brief Evaluate the energy, force, virial, atomic energy, and atomic virial + *by using this DP. + * @param[out] ener The system energy. + * @param[out] force The force on each atom. + * @param[out] virial The virial. + * @param[out] atom_energy The atomic energy. + * @param[out] atom_virial The atomic virial. + * @param[in] coord The coordinates of atoms. The array should be of size + *nframes x natoms x 3. + * @param[in] atype The atom types. The list should contain natoms ints. + * @param[in] box The cell of the region. The array should be of size nframes + *x 9. + * @param[in] fparam The frame parameter. The array can be of size : + * nframes x dim_fparam. + * dim_fparam. Then all frames are assumed to be provided with the same + *fparam. + * @param[in] aparam The atomic parameter The array can be of size : + * nframes x natoms x dim_aparam. + * natoms x dim_aparam. Then all frames are assumed to be provided with the + *same aparam. + **/ + template + void compute(ENERGYVTYPE& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& atype, + const std::vector& box); + // const std::vector& fparam = std::vector(), + // const std::vector& aparam = std::vector()); + /** + * @brief Evaluate the energy, force, virial, atomic energy, and atomic virial + *by using this DP. + * @param[out] ener The system energy. + * @param[out] force The force on each atom. + * @param[out] virial The virial. + * @param[out] atom_energy The atomic energy. + * @param[out] atom_virial The atomic virial. + * @param[in] coord The coordinates of atoms. The array should be of size + *nframes x natoms x 3. + * @param[in] atype The atom types. The list should contain natoms ints. + * @param[in] box The cell of the region. The array should be of size nframes + *x 9. + * @param[in] nghost The number of ghost atoms. + * @param[in] lmp_list The input neighbour list. + * @param[in] ago Update the internal neighbour list if ago is 0. + * @param[in] fparam The frame parameter. The array can be of size : + * nframes x dim_fparam. + * dim_fparam. Then all frames are assumed to be provided with the same + *fparam. + * @param[in] aparam The atomic parameter The array can be of size : + * nframes x natoms x dim_aparam. + * natoms x dim_aparam. Then all frames are assumed to be provided with the + *same aparam. + **/ + template + void compute(ENERGYVTYPE& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + // const int nghost, + const InputNlist& lmp_list, + const int& ago); + // const std::vector& fparam = std::vector(), + // const std::vector& aparam = std::vector()); + /** + * @brief Evaluate the energy, force, and virial with the mixed type + *by using this DP. + * @param[out] ener The system energy. + * @param[out] force The force on each atom. + * @param[out] virial The virial. + * @param[in] nframes The number of frames. + * @param[in] coord The coordinates of atoms. The array should be of size + *nframes x natoms x 3. + * @param[in] atype The atom types. The array should be of size nframes x + *natoms. + * @param[in] box The cell of the region. The array should be of size nframes + *x 9. + * @param[in] fparam The frame parameter. The array can be of size : + * nframes x dim_fparam. + * dim_fparam. Then all frames are assumed to be provided with the same + *fparam. + * @param[in] aparam The atomic parameter The array can be of size : + * nframes x natoms x dim_aparam. + * natoms x dim_aparam. Then all frames are assumed to be provided with the + *same aparam. + **/ + template + void compute_mixed_type( + ENERGYVTYPE& ener, + std::vector& force, + std::vector& virial, + const int& nframes, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const std::vector& fparam = std::vector(), + const std::vector& aparam = std::vector()); + /** + * @brief Evaluate the energy, force, and virial with the mixed type + *by using this DP. + * @param[out] ener The system energy. + * @param[out] force The force on each atom. + * @param[out] virial The virial. + * @param[out] atom_energy The atomic energy. + * @param[out] atom_virial The atomic virial. + * @param[in] nframes The number of frames. + * @param[in] coord The coordinates of atoms. The array should be of size + *nframes x natoms x 3. + * @param[in] atype The atom types. The array should be of size nframes x + *natoms. + * @param[in] box The cell of the region. The array should be of size nframes + *x 9. + * @param[in] fparam The frame parameter. The array can be of size : + * nframes x dim_fparam. + * dim_fparam. Then all frames are assumed to be provided with the same + *fparam. + * @param[in] aparam The atomic parameter The array can be of size : + * nframes x natoms x dim_aparam. + * natoms x dim_aparam. Then all frames are assumed to be provided with the + *same aparam. + **/ + template + void compute_mixed_type( + ENERGYVTYPE& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const int& nframes, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const std::vector& fparam = std::vector(), + const std::vector& aparam = std::vector()); + + public: + /** + * @brief Get the cutoff radius. + * @return The cutoff radius. + **/ + double cutoff() const { + assert(inited); + return rcut; + }; + /** + * @brief Get the number of types. + * @return The number of types. + **/ + int numb_types() const { + assert(inited); + return ntypes; + }; + /** + * @brief Get the number of types with spin. + * @return The number of types with spin. + **/ + int numb_types_spin() const { + assert(inited); + return ntypes_spin; + }; + /** + * @brief Get the dimension of the frame parameter. + * @return The dimension of the frame parameter. + **/ + int dim_fparam() const { + assert(inited); + return dfparam; + }; + /** + * @brief Get the dimension of the atomic parameter. + * @return The dimension of the atomic parameter. + **/ + int dim_aparam() const { + assert(inited); + return daparam; + }; + /** + * @brief Get the type map (element name of the atom types) of this model. + * @param[out] type_map The type map of this model. + **/ + void get_type_map(std::string& type_map); + + /** + * @brief Get whether the atom dimension of aparam is nall instead of fparam. + * @param[out] aparam_nall whether the atom dimension of aparam is nall + *instead of fparam. + **/ + bool is_aparam_nall() const { + assert(inited); + return aparam_nall; + }; + + // forward to template class + void computew(std::vector& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const std::vector& fparam = std::vector(), + const std::vector& aparam = std::vector()); + void computew(std::vector& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const std::vector& fparam = std::vector(), + const std::vector& aparam = std::vector()); + void computew(std::vector& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const int nghost, + const InputNlist& inlist, + const int& ago, + const std::vector& fparam = std::vector(), + const std::vector& aparam = std::vector()); + void computew(std::vector& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const int nghost, + const InputNlist& inlist, + const int& ago, + const std::vector& fparam = std::vector(), + const std::vector& aparam = std::vector()); + void computew_mixed_type( + std::vector& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const int& nframes, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const std::vector& fparam = std::vector(), + const std::vector& aparam = std::vector()); + void computew_mixed_type( + std::vector& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const int& nframes, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const std::vector& fparam = std::vector(), + const std::vector& aparam = std::vector()); + + private: + int num_intra_nthreads, num_inter_nthreads; + bool inited; + int ntypes; + int ntypes_spin; + int dfparam; + int daparam; + bool aparam_nall; + // copy neighbor list info from host + torch::jit::script::Module module; + double rcut; + NeighborListDataPT nlist_data; + int max_num_neighbors; + int gpu_id; + bool gpu_enabled; + at::Tensor firstneigh_tensor; +}; + +} // namespace deepmd diff --git a/source/api_cc/include/commonPT.h b/source/api_cc/include/commonPT.h new file mode 100644 index 0000000000..57ffd5b295 --- /dev/null +++ b/source/api_cc/include/commonPT.h @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: LGPL-3.0-or-later +#include + +#include +#include +#include +#include + +#include "neighbor_list.h" +namespace deepmd { +struct NeighborListDataPT { + /// Array stores the core region atom's index + std::vector ilist; + /// Array stores the core region atom's neighbor index + std::vector jlist; + /// Array stores the number of neighbors of core region atoms + std::vector numneigh; + /// Array stores the the location of the first neighbor of core region atoms + std::vector firstneigh; + + public: + void copy_from_nlist(const InputNlist& inlist, int& max_num_neighbors); +}; +} // namespace deepmd diff --git a/source/api_cc/src/DeepPot.cc b/source/api_cc/src/DeepPot.cc index c598549844..442e2d90cc 100644 --- a/source/api_cc/src/DeepPot.cc +++ b/source/api_cc/src/DeepPot.cc @@ -10,6 +10,9 @@ #ifdef BUILD_TENSORFLOW #include "DeepPotTF.h" #endif +#ifdef BUILD_PYTORCH +#include "DeepPotPT.h" +#endif #include "device.h" using namespace deepmd; @@ -34,8 +37,14 @@ void DeepPot::init(const std::string& model, << std::endl; return; } - // TODO: To implement detect_backend - DPBackend backend = deepmd::DPBackend::TensorFlow; + DPBackend backend; + if (model.length() >= 4 && model.substr(model.length() - 4) == ".pth") { + backend = deepmd::DPBackend::PyTorch; + } else if (model.length() >= 3 && model.substr(model.length() - 3) == ".pb") { + backend = deepmd::DPBackend::TensorFlow; + } else { + throw deepmd::deepmd_exception("Unsupported model file format"); + } if (deepmd::DPBackend::TensorFlow == backend) { #ifdef BUILD_TENSORFLOW dp = std::make_shared(model, gpu_rank, file_content); @@ -43,7 +52,11 @@ void DeepPot::init(const std::string& model, throw deepmd::deepmd_exception("TensorFlow backend is not built"); #endif } else if (deepmd::DPBackend::PyTorch == backend) { - throw deepmd::deepmd_exception("PyTorch backend is not supported yet"); +#ifdef BUILD_PYTORCH + dp = std::make_shared(model, gpu_rank, file_content); +#else + throw deepmd::deepmd_exception("PyTorch backend is not built"); +#endif } else if (deepmd::DPBackend::Paddle == backend) { throw deepmd::deepmd_exception("PaddlePaddle backend is not supported yet"); } else { diff --git a/source/api_cc/src/DeepPotPT.cc b/source/api_cc/src/DeepPotPT.cc index c94fb4247b..f05e27b9b2 100644 --- a/source/api_cc/src/DeepPotPT.cc +++ b/source/api_cc/src/DeepPotPT.cc @@ -1,8 +1,358 @@ // SPDX-License-Identifier: LGPL-3.0-or-later #ifdef BUILD_PYTORCH -#include +#include "DeepPotPT.h" -void test_function_please_remove_after_torch_is_actually_used() { - torch::Tensor tensor = torch::rand({2, 3}); +#include "common.h" +using namespace deepmd; +DeepPotPT::DeepPotPT() : inited(false) {} +DeepPotPT::DeepPotPT(const std::string& model, + const int& gpu_rank, + const std::string& file_content) + : inited(false) { + try { + init(model, gpu_rank, file_content); + } catch (...) { + // Clean up and rethrow, as the destructor will not be called + throw; + } +} +void DeepPotPT::init(const std::string& model, + const int& gpu_rank, + const std::string& file_content) { + if (inited) { + std::cerr << "WARNING: deepmd-kit should not be initialized twice, do " + "nothing at the second call of initializer" + << std::endl; + return; + } + gpu_id = gpu_rank; + torch::Device device(torch::kCUDA, gpu_rank); + gpu_enabled = torch::cuda::is_available(); + if (!gpu_enabled) { + device = torch::Device(torch::kCPU); + std::cout << "load model from: " << model << " to cpu " << gpu_rank + << std::endl; + } else { + std::cout << "load model from: " << model << " to gpu " << gpu_rank + << std::endl; + } + module = torch::jit::load(model, device); + + torch::jit::FusionStrategy strategy; + strategy = {{torch::jit::FusionBehavior::DYNAMIC, 10}}; + torch::jit::setFusionStrategy(strategy); + + get_env_nthreads(num_intra_nthreads, + num_inter_nthreads); // need to be fixed as + // DP_INTRA_OP_PARALLELISM_THREADS + if (num_inter_nthreads) { + try { + at::set_num_interop_threads(num_inter_nthreads); + } catch (...) { + } + } + if (num_intra_nthreads) { + try { + at::set_num_threads(num_intra_nthreads); + } catch (...) { + } + } + + auto rcut_ = module.run_method("get_rcut").toDouble(); + rcut = static_cast(rcut_); + ntypes = 0; + ntypes_spin = 0; + dfparam = 0; + daparam = 0; + aparam_nall = false; + inited = true; +} +DeepPotPT::~DeepPotPT() {} + +template +void DeepPotPT::compute(ENERGYVTYPE& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const InputNlist& lmp_list, + const int& ago) { + torch::Device device(torch::kCUDA, gpu_id); + if (!gpu_enabled) { + device = torch::Device(torch::kCPU); + } + std::vector coord_wrapped = coord; + int natoms = atype.size(); + auto options = torch::TensorOptions().dtype(torch::kFloat64); + torch::ScalarType floatType = torch::kFloat64; + if (std::is_same_v) { + options = torch::TensorOptions().dtype(torch::kFloat32); + floatType = torch::kFloat32; + } + auto int_options = torch::TensorOptions().dtype(torch::kInt64); + auto int32_options = torch::TensorOptions().dtype(torch::kInt32); + at::Tensor coord_wrapped_Tensor = + torch::from_blob(coord_wrapped.data(), {1, natoms, 3}, options) + .to(device); + std::vector atype_64(atype.begin(), atype.end()); + at::Tensor atype_Tensor = + torch::from_blob(atype_64.data(), {1, natoms}, int_options).to(device); + if (ago == 0) { + nlist_data.copy_from_nlist(lmp_list, max_num_neighbors); + } + at::Tensor firstneigh = + torch::from_blob(nlist_data.jlist.data(), + {1, lmp_list.inum, max_num_neighbors}, int32_options); + firstneigh_tensor = firstneigh.to(torch::kInt64).to(device); + bool do_atom_virial_tensor = true; + c10::optional optional_tensor; + c10::Dict outputs = + module + .run_method("forward_lower", coord_wrapped_Tensor, atype_Tensor, + firstneigh_tensor, optional_tensor, optional_tensor, + optional_tensor, do_atom_virial_tensor) + .toGenericDict(); + c10::IValue energy_ = outputs.at("energy"); + c10::IValue force_ = outputs.at("extended_force"); + c10::IValue virial_ = outputs.at("virial"); + c10::IValue atom_virial_ = outputs.at("extended_virial"); + c10::IValue atom_energy_ = outputs.at("atom_energy"); + torch::Tensor flat_energy_ = energy_.toTensor().view({-1}); + torch::Tensor cpu_energy_ = flat_energy_.to(torch::kCPU); + ener.assign(cpu_energy_.data_ptr(), + cpu_energy_.data_ptr() + cpu_energy_.numel()); + torch::Tensor flat_atom_energy_ = + atom_energy_.toTensor().view({-1}).to(floatType); + torch::Tensor cpu_atom_energy_ = flat_atom_energy_.to(torch::kCPU); + atom_energy.resize(natoms, 0.0); // resize to nall to be consistenet with TF. + atom_energy.assign( + cpu_atom_energy_.data_ptr(), + cpu_atom_energy_.data_ptr() + cpu_atom_energy_.numel()); + torch::Tensor flat_force_ = force_.toTensor().view({-1}).to(floatType); + torch::Tensor cpu_force_ = flat_force_.to(torch::kCPU); + force.assign(cpu_force_.data_ptr(), + cpu_force_.data_ptr() + cpu_force_.numel()); + torch::Tensor flat_virial_ = virial_.toTensor().view({-1}).to(floatType); + torch::Tensor cpu_virial_ = flat_virial_.to(torch::kCPU); + virial.assign(cpu_virial_.data_ptr(), + cpu_virial_.data_ptr() + cpu_virial_.numel()); + torch::Tensor flat_atom_virial_ = + atom_virial_.toTensor().view({-1}).to(floatType); + torch::Tensor cpu_atom_virial_ = flat_atom_virial_.to(torch::kCPU); + atom_virial.assign( + cpu_atom_virial_.data_ptr(), + cpu_atom_virial_.data_ptr() + cpu_atom_virial_.numel()); +} +template void DeepPotPT::compute>( + std::vector& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const InputNlist& lmp_list, + const int& ago); +template void DeepPotPT::compute>( + std::vector& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const InputNlist& lmp_list, + const int& ago); +template +void DeepPotPT::compute(ENERGYVTYPE& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& atype, + const std::vector& box) { + torch::Device device(torch::kCUDA, gpu_id); + if (!gpu_enabled) { + device = torch::Device(torch::kCPU); + } + std::vector coord_wrapped = coord; + int natoms = atype.size(); + auto options = torch::TensorOptions().dtype(torch::kFloat64); + torch::ScalarType floatType = torch::kFloat64; + if (std::is_same_v) { + options = torch::TensorOptions().dtype(torch::kFloat32); + floatType = torch::kFloat32; + } + auto int_options = torch::TensorOptions().dtype(torch::kInt64); + std::vector inputs; + at::Tensor coord_wrapped_Tensor = + torch::from_blob(coord_wrapped.data(), {1, natoms, 3}, options) + .to(device); + inputs.push_back(coord_wrapped_Tensor); + std::vector atype_64(atype.begin(), atype.end()); + at::Tensor atype_Tensor = + torch::from_blob(atype_64.data(), {1, natoms}, int_options).to(device); + inputs.push_back(atype_Tensor); + c10::optional box_Tensor; + if (!box.empty()) { + box_Tensor = + torch::from_blob(const_cast(box.data()), {1, 9}, options) + .to(device); + } + inputs.push_back(box_Tensor); + c10::optional fparam_tensor; + inputs.push_back(fparam_tensor); + c10::optional aparam_tensor; + inputs.push_back(aparam_tensor); + bool do_atom_virial_tensor = true; + inputs.push_back(do_atom_virial_tensor); + c10::Dict outputs = + module.forward(inputs).toGenericDict(); + c10::IValue energy_ = outputs.at("energy"); + c10::IValue force_ = outputs.at("force"); + c10::IValue virial_ = outputs.at("virial"); + c10::IValue atom_virial_ = outputs.at("atom_virial"); + c10::IValue atom_energy_ = outputs.at("atom_energy"); + torch::Tensor flat_energy_ = energy_.toTensor().view({-1}); + torch::Tensor cpu_energy_ = flat_energy_.to(torch::kCPU); + ener.assign(cpu_energy_.data_ptr(), + cpu_energy_.data_ptr() + cpu_energy_.numel()); + torch::Tensor flat_atom_energy_ = + atom_energy_.toTensor().view({-1}).to(floatType); + torch::Tensor cpu_atom_energy_ = flat_atom_energy_.to(torch::kCPU); + atom_energy.assign( + cpu_atom_energy_.data_ptr(), + cpu_atom_energy_.data_ptr() + cpu_atom_energy_.numel()); + torch::Tensor flat_force_ = force_.toTensor().view({-1}).to(floatType); + torch::Tensor cpu_force_ = flat_force_.to(torch::kCPU); + force.assign(cpu_force_.data_ptr(), + cpu_force_.data_ptr() + cpu_force_.numel()); + torch::Tensor flat_virial_ = virial_.toTensor().view({-1}).to(floatType); + torch::Tensor cpu_virial_ = flat_virial_.to(torch::kCPU); + virial.assign(cpu_virial_.data_ptr(), + cpu_virial_.data_ptr() + cpu_virial_.numel()); + torch::Tensor flat_atom_virial_ = + atom_virial_.toTensor().view({-1}).to(floatType); + torch::Tensor cpu_atom_virial_ = flat_atom_virial_.to(torch::kCPU); + atom_virial.assign( + cpu_atom_virial_.data_ptr(), + cpu_atom_virial_.data_ptr() + cpu_atom_virial_.numel()); +} + +template void DeepPotPT::compute>( + std::vector& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& atype, + const std::vector& box); +template void DeepPotPT::compute>( + std::vector& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& atype, + const std::vector& box); +void DeepPotPT::get_type_map(std::string& type_map) { + auto ret = module.run_method("get_type_map").toList(); + for (const torch::IValue& element : ret) { + type_map += torch::str(element); // Convert each element to a string + type_map += " "; // Add a space between elements + } +} + +// forward to template method +void DeepPotPT::computew(std::vector& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const std::vector& fparam, + const std::vector& aparam) { + compute(ener, force, virial, atom_energy, atom_virial, coord, atype, box); +} +void DeepPotPT::computew(std::vector& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const std::vector& fparam, + const std::vector& aparam) { + compute(ener, force, virial, atom_energy, atom_virial, coord, atype, box); +} +void DeepPotPT::computew(std::vector& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const int nghost, + const InputNlist& inlist, + const int& ago, + const std::vector& fparam, + const std::vector& aparam) { + // TODO: atomic compute unsupported + compute(ener, force, virial, atom_energy, atom_virial, coord, atype, box, + inlist, ago); +} +void DeepPotPT::computew(std::vector& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const int nghost, + const InputNlist& inlist, + const int& ago, + const std::vector& fparam, + const std::vector& aparam) { + compute(ener, force, virial, atom_energy, atom_virial, coord, atype, box, + inlist, ago); +} +void DeepPotPT::computew_mixed_type(std::vector& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const int& nframes, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const std::vector& fparam, + const std::vector& aparam) { + throw deepmd::deepmd_exception("computew_mixed_type is not implemented"); +} +void DeepPotPT::computew_mixed_type(std::vector& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const int& nframes, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const std::vector& fparam, + const std::vector& aparam) { + throw deepmd::deepmd_exception("computew_mixed_type is not implemented"); } #endif diff --git a/source/api_cc/src/commonPT.cc b/source/api_cc/src/commonPT.cc new file mode 100644 index 0000000000..4ed3b21fe8 --- /dev/null +++ b/source/api_cc/src/commonPT.cc @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: LGPL-3.0-or-later +#ifdef BUILD_PYTORCH +#include "commonPT.h" +using namespace deepmd; +void NeighborListDataPT::copy_from_nlist(const InputNlist& inlist, + int& max_num_neighbors) { + int inum = inlist.inum; + ilist.resize(inum); + numneigh.resize(inum); + memcpy(&ilist[0], inlist.ilist, inum * sizeof(int)); + int* max_element = std::max_element(inlist.numneigh, inlist.numneigh + inum); + max_num_neighbors = *max_element; + unsigned long nlist_size = (unsigned long)inum * max_num_neighbors; + jlist.resize(nlist_size); + jlist.assign(nlist_size, -1); + for (int ii = 0; ii < inum; ++ii) { + int jnum = inlist.numneigh[ii]; + numneigh[ii] = inlist.numneigh[ii]; + memcpy(&jlist[(unsigned long)ii * max_num_neighbors], inlist.firstneigh[ii], + jnum * sizeof(int)); + } +} +#endif diff --git a/source/api_cc/tests/test_deeppot_pt.cc b/source/api_cc/tests/test_deeppot_pt.cc new file mode 100644 index 0000000000..e0e90ac75c --- /dev/null +++ b/source/api_cc/tests/test_deeppot_pt.cc @@ -0,0 +1,625 @@ +// SPDX-License-Identifier: LGPL-3.0-or-later +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "DeepPot.h" +#include "neighbor_list.h" +#include "test_utils.h" + +template +class TestInferDeepPotAPt : public ::testing::Test { + protected: + std::vector coord = {12.83, 2.56, 2.18, 12.09, 2.87, 2.74, + 00.25, 3.32, 1.68, 3.36, 3.00, 1.81, + 3.51, 2.51, 2.60, 4.27, 3.22, 1.56}; + std::vector atype = {0, 1, 1, 0, 1, 1}; + std::vector box = {13., 0., 0., 0., 13., 0., 0., 0., 13.}; + std::vector expected_e = { + + -93.016873944029, -185.923296645958, -185.927096544970, + -93.019371018039, -185.926179995548, -185.924351901852}; + std::vector expected_f = { + + 0.006277522211, -0.001117962774, 0.000618580445, 0.009928999655, + 0.003026035654, -0.006941982227, 0.000667853212, -0.002449963843, + 0.006506463508, -0.007284129115, 0.000530662205, -0.000028806821, + 0.000068097781, 0.006121331983, -0.009019754602, -0.009658343745, + -0.006110103225, 0.008865499697}; + std::vector expected_v = { + -0.000155238009, 0.000116605516, -0.007869862476, 0.000465578340, + 0.008182547185, -0.002398713212, -0.008112887338, -0.002423738425, + 0.007210716605, -0.019203504012, 0.001724938709, 0.009909211091, + 0.001153857542, -0.001600015103, -0.000560024090, 0.010727836276, + -0.001034836404, -0.007973454377, -0.021517399106, -0.004064359664, + 0.004866398692, -0.003360038617, -0.007241406162, 0.005920941051, + 0.004899151657, 0.006290788591, -0.006478820311, 0.001921504710, + 0.001313470921, -0.000304091236, 0.001684345981, 0.004124109256, + -0.006396084465, -0.000701095618, -0.006356507032, 0.009818550859, + -0.015230664587, -0.000110244376, 0.000690319396, 0.000045953023, + -0.005726548770, 0.008769818495, -0.000572380210, 0.008860603423, + -0.013819348050, -0.021227082558, -0.004977781343, 0.006646239696, + -0.005987066507, -0.002767831232, 0.003746502525, 0.007697590397, + 0.003746130152, -0.005172634748}; + int natoms; + double expected_tot_e; + std::vector expected_tot_v; + + deepmd::DeepPot dp; + + void SetUp() override { + std::string file_name = "../../tests/infer/deeppot_sea.pth"; + + dp.init(file_name); + + natoms = expected_e.size(); + EXPECT_EQ(natoms * 3, expected_f.size()); + EXPECT_EQ(natoms * 9, expected_v.size()); + expected_tot_e = 0.; + expected_tot_v.resize(9); + std::fill(expected_tot_v.begin(), expected_tot_v.end(), 0.); + for (int ii = 0; ii < natoms; ++ii) { + expected_tot_e += expected_e[ii]; + } + for (int ii = 0; ii < natoms; ++ii) { + for (int dd = 0; dd < 9; ++dd) { + expected_tot_v[dd] += expected_v[ii * 9 + dd]; + } + } + }; + + void TearDown() override { remove("deeppot.pb"); }; +}; + +TYPED_TEST_SUITE(TestInferDeepPotAPt, ValueTypes); + +TYPED_TEST(TestInferDeepPotAPt, cpu_build_nlist) { + using VALUETYPE = TypeParam; + std::vector& coord = this->coord; + std::vector& atype = this->atype; + std::vector& box = this->box; + std::vector& expected_e = this->expected_e; + std::vector& expected_f = this->expected_f; + std::vector& expected_v = this->expected_v; + int& natoms = this->natoms; + double& expected_tot_e = this->expected_tot_e; + std::vector& expected_tot_v = this->expected_tot_v; + deepmd::DeepPot& dp = this->dp; + double ener; + std::vector force, virial; + dp.compute(ener, force, virial, coord, atype, box); + + EXPECT_EQ(force.size(), natoms * 3); + EXPECT_EQ(virial.size(), 9); + + EXPECT_LT(fabs(ener - expected_tot_e), EPSILON); + for (int ii = 0; ii < natoms * 3; ++ii) { + EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON); + } + for (int ii = 0; ii < 3 * 3; ++ii) { + EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON); + } +} + +TYPED_TEST(TestInferDeepPotAPt, cpu_build_nlist_numfv) { + using VALUETYPE = TypeParam; + std::vector& coord = this->coord; + std::vector& atype = this->atype; + std::vector& box = this->box; + std::vector& expected_e = this->expected_e; + std::vector& expected_f = this->expected_f; + std::vector& expected_v = this->expected_v; + int& natoms = this->natoms; + double& expected_tot_e = this->expected_tot_e; + std::vector& expected_tot_v = this->expected_tot_v; + deepmd::DeepPot& dp = this->dp; + class MyModel : public EnergyModelTest { + deepmd::DeepPot& mydp; + const std::vector atype; + + public: + MyModel(deepmd::DeepPot& dp_, const std::vector& atype_) + : mydp(dp_), atype(atype_){}; + virtual void compute(double& ener, + std::vector& force, + std::vector& virial, + const std::vector& coord, + const std::vector& box) { + mydp.compute(ener, force, virial, coord, atype, box); + } + }; + MyModel model(dp, atype); + model.test_f(coord, box); + model.test_v(coord, box); + std::vector box_(box); + box_[1] -= 0.4; + model.test_f(coord, box_); + model.test_v(coord, box_); + box_[2] += 0.5; + model.test_f(coord, box_); + model.test_v(coord, box_); + box_[4] += 0.2; + model.test_f(coord, box_); + model.test_v(coord, box_); + box_[3] -= 0.3; + model.test_f(coord, box_); + model.test_v(coord, box_); + box_[6] -= 0.7; + model.test_f(coord, box_); + model.test_v(coord, box_); + box_[7] += 0.6; + model.test_f(coord, box_); + model.test_v(coord, box_); +} + +TYPED_TEST(TestInferDeepPotAPt, cpu_build_nlist_atomic) { + using VALUETYPE = TypeParam; + std::vector& coord = this->coord; + std::vector& atype = this->atype; + std::vector& box = this->box; + std::vector& expected_e = this->expected_e; + std::vector& expected_f = this->expected_f; + std::vector& expected_v = this->expected_v; + int& natoms = this->natoms; + double& expected_tot_e = this->expected_tot_e; + std::vector& expected_tot_v = this->expected_tot_v; + deepmd::DeepPot& dp = this->dp; + double ener; + std::vector force, virial, atom_ener, atom_vir; + dp.compute(ener, force, virial, atom_ener, atom_vir, coord, atype, box); + + EXPECT_EQ(force.size(), natoms * 3); + EXPECT_EQ(virial.size(), 9); + EXPECT_EQ(atom_ener.size(), natoms); + EXPECT_EQ(atom_vir.size(), natoms * 9); + + EXPECT_LT(fabs(ener - expected_tot_e), EPSILON); + for (int ii = 0; ii < natoms * 3; ++ii) { + EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON); + } + for (int ii = 0; ii < 3 * 3; ++ii) { + EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON); + } + for (int ii = 0; ii < natoms; ++ii) { + EXPECT_LT(fabs(atom_ener[ii] - expected_e[ii]), EPSILON); + } + for (int ii = 0; ii < natoms * 9; ++ii) { + EXPECT_LT(fabs(atom_vir[ii] - expected_v[ii]), EPSILON); + } +} + +TYPED_TEST(TestInferDeepPotAPt, cpu_lmp_nlist) { + using VALUETYPE = TypeParam; + std::vector& coord = this->coord; + std::vector& atype = this->atype; + std::vector& box = this->box; + std::vector& expected_e = this->expected_e; + std::vector& expected_f = this->expected_f; + std::vector& expected_v = this->expected_v; + int& natoms = this->natoms; + double& expected_tot_e = this->expected_tot_e; + std::vector& expected_tot_v = this->expected_tot_v; + deepmd::DeepPot& dp = this->dp; + float rc = dp.cutoff(); + int nloc = coord.size() / 3; + std::vector coord_cpy; + std::vector atype_cpy, mapping; + std::vector > nlist_data; + _build_nlist(nlist_data, coord_cpy, atype_cpy, mapping, coord, + atype, box, rc); + int nall = coord_cpy.size() / 3; + std::vector ilist(nloc), numneigh(nloc); + std::vector firstneigh(nloc); + deepmd::InputNlist inlist(nloc, &ilist[0], &numneigh[0], &firstneigh[0]); + convert_nlist(inlist, nlist_data); + + double ener; + std::vector force_, virial; + dp.compute(ener, force_, virial, coord_cpy, atype_cpy, box, nall - nloc, + inlist, 0); + std::vector force; + _fold_back(force, force_, mapping, nloc, nall, 3); + + EXPECT_EQ(force.size(), natoms * 3); + EXPECT_EQ(virial.size(), 9); + + EXPECT_LT(fabs(ener - expected_tot_e), EPSILON); + for (int ii = 0; ii < natoms * 3; ++ii) { + EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON); + } + for (int ii = 0; ii < 3 * 3; ++ii) { + EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON); + } + + ener = 0.; + std::fill(force_.begin(), force_.end(), 0.0); + std::fill(virial.begin(), virial.end(), 0.0); + dp.compute(ener, force_, virial, coord_cpy, atype_cpy, box, nall - nloc, + inlist, 1); + _fold_back(force, force_, mapping, nloc, nall, 3); + + EXPECT_EQ(force.size(), natoms * 3); + EXPECT_EQ(virial.size(), 9); + + EXPECT_LT(fabs(ener - expected_tot_e), EPSILON); + for (int ii = 0; ii < natoms * 3; ++ii) { + EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON); + } + for (int ii = 0; ii < 3 * 3; ++ii) { + EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON); + } +} + +TYPED_TEST(TestInferDeepPotAPt, cpu_lmp_nlist_atomic) { + using VALUETYPE = TypeParam; + std::vector& coord = this->coord; + std::vector& atype = this->atype; + std::vector& box = this->box; + std::vector& expected_e = this->expected_e; + std::vector& expected_f = this->expected_f; + std::vector& expected_v = this->expected_v; + int& natoms = this->natoms; + double& expected_tot_e = this->expected_tot_e; + std::vector& expected_tot_v = this->expected_tot_v; + deepmd::DeepPot& dp = this->dp; + float rc = dp.cutoff(); + int nloc = coord.size() / 3; + std::vector coord_cpy; + std::vector atype_cpy, mapping; + std::vector > nlist_data; + _build_nlist(nlist_data, coord_cpy, atype_cpy, mapping, coord, + atype, box, rc); + int nall = coord_cpy.size() / 3; + std::vector ilist(nloc), numneigh(nloc); + std::vector firstneigh(nloc); + deepmd::InputNlist inlist(nloc, &ilist[0], &numneigh[0], &firstneigh[0]); + convert_nlist(inlist, nlist_data); + double ener; + std::vector force_, atom_ener_, atom_vir_, virial; + std::vector force, atom_ener, atom_vir; + dp.compute(ener, force_, virial, atom_ener_, atom_vir_, coord_cpy, atype_cpy, + box, nall - nloc, inlist, 0); + _fold_back(force, force_, mapping, nloc, nall, 3); + _fold_back(atom_ener, atom_ener_, mapping, nloc, nall, 1); + _fold_back(atom_vir, atom_vir_, mapping, nloc, nall, 9); + + EXPECT_EQ(force.size(), natoms * 3); + EXPECT_EQ(virial.size(), 9); + EXPECT_EQ(atom_ener.size(), natoms); + EXPECT_EQ(atom_vir.size(), natoms * 9); + + EXPECT_LT(fabs(ener - expected_tot_e), EPSILON); + for (int ii = 0; ii < natoms * 3; ++ii) { + EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON); + } + for (int ii = 0; ii < 3 * 3; ++ii) { + EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON); + } + for (int ii = 0; ii < natoms; ++ii) { + EXPECT_LT(fabs(atom_ener[ii] - expected_e[ii]), EPSILON); + } + for (int ii = 0; ii < natoms * 9; ++ii) { + EXPECT_LT(fabs(atom_vir[ii] - expected_v[ii]), EPSILON); + } + + ener = 0.; + std::fill(force_.begin(), force_.end(), 0.0); + std::fill(virial.begin(), virial.end(), 0.0); + std::fill(atom_ener_.begin(), atom_ener_.end(), 0.0); + std::fill(atom_vir_.begin(), atom_vir_.end(), 0.0); + dp.compute(ener, force_, virial, atom_ener_, atom_vir_, coord_cpy, atype_cpy, + box, nall - nloc, inlist, 1); + _fold_back(force, force_, mapping, nloc, nall, 3); + _fold_back(atom_ener, atom_ener_, mapping, nloc, nall, 1); + _fold_back(atom_vir, atom_vir_, mapping, nloc, nall, 9); + + EXPECT_EQ(force.size(), natoms * 3); + EXPECT_EQ(virial.size(), 9); + EXPECT_EQ(atom_ener.size(), natoms); + EXPECT_EQ(atom_vir.size(), natoms * 9); + + EXPECT_LT(fabs(ener - expected_tot_e), EPSILON); + for (int ii = 0; ii < natoms * 3; ++ii) { + EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON); + } + for (int ii = 0; ii < 3 * 3; ++ii) { + EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON); + } + for (int ii = 0; ii < natoms; ++ii) { + EXPECT_LT(fabs(atom_ener[ii] - expected_e[ii]), EPSILON); + } + for (int ii = 0; ii < natoms * 9; ++ii) { + EXPECT_LT(fabs(atom_vir[ii] - expected_v[ii]), EPSILON); + } +} + +TYPED_TEST(TestInferDeepPotAPt, cpu_lmp_nlist_2rc) { + using VALUETYPE = TypeParam; + std::vector& coord = this->coord; + std::vector& atype = this->atype; + std::vector& box = this->box; + std::vector& expected_e = this->expected_e; + std::vector& expected_f = this->expected_f; + std::vector& expected_v = this->expected_v; + int& natoms = this->natoms; + double& expected_tot_e = this->expected_tot_e; + std::vector& expected_tot_v = this->expected_tot_v; + deepmd::DeepPot& dp = this->dp; + float rc = dp.cutoff(); + int nloc = coord.size() / 3; + std::vector coord_cpy; + std::vector atype_cpy, mapping; + std::vector > nlist_data; + _build_nlist(nlist_data, coord_cpy, atype_cpy, mapping, coord, + atype, box, rc * 2); + int nall = coord_cpy.size() / 3; + std::vector ilist(nloc), numneigh(nloc); + std::vector firstneigh(nloc); + deepmd::InputNlist inlist(nloc, &ilist[0], &numneigh[0], &firstneigh[0]); + convert_nlist(inlist, nlist_data); + + double ener; + std::vector force_(nall * 3, 0.0), virial(9, 0.0); + dp.compute(ener, force_, virial, coord_cpy, atype_cpy, box, nall - nloc, + inlist, 0); + std::vector force; + _fold_back(force, force_, mapping, nloc, nall, 3); + + EXPECT_EQ(force.size(), natoms * 3); + EXPECT_EQ(virial.size(), 9); + + EXPECT_LT(fabs(ener - expected_tot_e), EPSILON); + for (int ii = 0; ii < natoms * 3; ++ii) { + EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON); + } + for (int ii = 0; ii < 3 * 3; ++ii) { + EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON); + } + + ener = 0.; + std::fill(force_.begin(), force_.end(), 0.0); + std::fill(virial.begin(), virial.end(), 0.0); + dp.compute(ener, force_, virial, coord_cpy, atype_cpy, box, nall - nloc, + inlist, 1); + _fold_back(force, force_, mapping, nloc, nall, 3); + + EXPECT_EQ(force.size(), natoms * 3); + EXPECT_EQ(virial.size(), 9); + + EXPECT_LT(fabs(ener - expected_tot_e), EPSILON); + for (int ii = 0; ii < natoms * 3; ++ii) { + EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON); + } + for (int ii = 0; ii < 3 * 3; ++ii) { + EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON); + } +} + +TYPED_TEST(TestInferDeepPotAPt, cpu_lmp_nlist_type_sel) { + GTEST_SKIP() << "Skipping this test for unsupported"; + using VALUETYPE = TypeParam; + std::vector& coord = this->coord; + std::vector& atype = this->atype; + std::vector& box = this->box; + std::vector& expected_e = this->expected_e; + std::vector& expected_f = this->expected_f; + std::vector& expected_v = this->expected_v; + int& natoms = this->natoms; + double& expected_tot_e = this->expected_tot_e; + std::vector& expected_tot_v = this->expected_tot_v; + deepmd::DeepPot& dp = this->dp; + float rc = dp.cutoff(); + + // add vir atoms + int nvir = 2; + std::vector coord_vir(nvir * 3); + std::vector atype_vir(nvir, 2); + for (int ii = 0; ii < nvir; ++ii) { + coord_vir[ii] = coord[ii]; + } + coord.insert(coord.begin(), coord_vir.begin(), coord_vir.end()); + atype.insert(atype.begin(), atype_vir.begin(), atype_vir.end()); + natoms += nvir; + std::vector expected_f_vir(nvir * 3, 0.0); + expected_f.insert(expected_f.begin(), expected_f_vir.begin(), + expected_f_vir.end()); + + // build nlist + int nloc = coord.size() / 3; + std::vector coord_cpy; + std::vector atype_cpy, mapping; + std::vector > nlist_data; + _build_nlist(nlist_data, coord_cpy, atype_cpy, mapping, coord, + atype, box, rc); + int nall = coord_cpy.size() / 3; + std::vector ilist(nloc), numneigh(nloc); + std::vector firstneigh(nloc); + deepmd::InputNlist inlist(nloc, &ilist[0], &numneigh[0], &firstneigh[0]); + convert_nlist(inlist, nlist_data); + + // dp compute + double ener; + std::vector force_(nall * 3, 0.0), virial(9, 0.0); + dp.compute(ener, force_, virial, coord_cpy, atype_cpy, box, nall - nloc, + inlist, 0); + // fold back + std::vector force; + _fold_back(force, force_, mapping, nloc, nall, 3); + + EXPECT_EQ(force.size(), natoms * 3); + EXPECT_EQ(virial.size(), 9); + + EXPECT_LT(fabs(ener - expected_tot_e), EPSILON); + for (int ii = 0; ii < natoms * 3; ++ii) { + EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON); + } + for (int ii = 0; ii < 3 * 3; ++ii) { + EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON); + } +} + +TYPED_TEST(TestInferDeepPotAPt, cpu_lmp_nlist_type_sel_atomic) { + GTEST_SKIP() << "Skipping this test for unsupported"; + using VALUETYPE = TypeParam; + std::vector& coord = this->coord; + std::vector& atype = this->atype; + std::vector& box = this->box; + std::vector& expected_e = this->expected_e; + std::vector& expected_f = this->expected_f; + std::vector& expected_v = this->expected_v; + int& natoms = this->natoms; + double& expected_tot_e = this->expected_tot_e; + std::vector& expected_tot_v = this->expected_tot_v; + deepmd::DeepPot& dp = this->dp; + float rc = dp.cutoff(); + + // add vir atoms + int nvir = 2; + std::vector coord_vir(nvir * 3); + std::vector atype_vir(nvir, 2); + for (int ii = 0; ii < nvir; ++ii) { + coord_vir[ii] = coord[ii]; + } + coord.insert(coord.begin(), coord_vir.begin(), coord_vir.end()); + atype.insert(atype.begin(), atype_vir.begin(), atype_vir.end()); + natoms += nvir; + std::vector expected_f_vir(nvir * 3, 0.0); + expected_f.insert(expected_f.begin(), expected_f_vir.begin(), + expected_f_vir.end()); + + // build nlist + int nloc = coord.size() / 3; + std::vector coord_cpy; + std::vector atype_cpy, mapping; + std::vector > nlist_data; + _build_nlist(nlist_data, coord_cpy, atype_cpy, mapping, coord, + atype, box, rc); + int nall = coord_cpy.size() / 3; + std::vector ilist(nloc), numneigh(nloc); + std::vector firstneigh(nloc); + deepmd::InputNlist inlist(nloc, &ilist[0], &numneigh[0], &firstneigh[0]); + convert_nlist(inlist, nlist_data); + + // dp compute + double ener; + std::vector force_(nall * 3, 0.0), virial(9, 0.0), atomic_energy, + atomic_virial; + dp.compute(ener, force_, virial, atomic_energy, atomic_virial, coord_cpy, + atype_cpy, box, nall - nloc, inlist, 0); + // fold back + std::vector force; + _fold_back(force, force_, mapping, nloc, nall, 3); + + EXPECT_EQ(force.size(), natoms * 3); + EXPECT_EQ(virial.size(), 9); + + EXPECT_LT(fabs(ener - expected_tot_e), EPSILON); + for (int ii = 0; ii < natoms * 3; ++ii) { + EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON); + } + for (int ii = 0; ii < 3 * 3; ++ii) { + EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON); + } +} + +TYPED_TEST(TestInferDeepPotAPt, print_summary) { + deepmd::DeepPot& dp = this->dp; + dp.print_summary(""); +} + +template +class TestInferDeepPotAPtNoPbc : public ::testing::Test { + protected: + std::vector coord = {12.83, 2.56, 2.18, 12.09, 2.87, 2.74, + 00.25, 3.32, 1.68, 3.36, 3.00, 1.81, + 3.51, 2.51, 2.60, 4.27, 3.22, 1.56}; + std::vector atype = {0, 1, 1, 0, 1, 1}; + std::vector box = {}; + std::vector expected_e = {-93.003304908874, -185.915806542480, + -185.928116717624, -93.017934934346, + -185.924393412278, -185.923906740801}; + std::vector expected_f = { + 0.000868182637, -0.000363698132, -0.000657003077, -0.000868182637, + 0.000363698132, 0.000657003077, 0.007932614680, -0.001003609844, + 0.000737731722, -0.003883788858, 0.000686896282, -0.000578400682, + 0.004064895086, 0.006115547962, -0.008747097814, -0.008113720908, + -0.005798834400, 0.008587766774}; + std::vector expected_v = { + 0.007762485364, -0.003251851977, -0.005874313248, -0.003251851977, + 0.001362262315, 0.002460860955, -0.005874313248, 0.002460860955, + 0.004445426242, -0.007120030212, 0.002982715359, 0.005388130971, + 0.002982715359, -0.001249515894, -0.002257190002, 0.005388130971, + -0.002257190002, -0.004077504519, -0.015805863589, 0.001952684835, + -0.001522876482, 0.001796574704, -0.000358803950, 0.000369710813, + -0.001108943040, 0.000332585300, -0.000395481309, 0.008873525623, + 0.001919112114, -0.001486235522, 0.002002929532, 0.004222469272, + -0.006517211126, -0.001656192522, -0.006501210045, 0.010118622295, + -0.006548889778, -0.000465126991, 0.001002876603, 0.000240398734, + -0.005794489784, 0.008940685179, -0.000121727685, 0.008931999051, + -0.013852797563, -0.017962955675, -0.004645050453, 0.006214692837, + -0.005278283465, -0.002662692758, 0.003618275905, 0.007095320684, + 0.003648086464, -0.005023397513}; + int natoms; + double expected_tot_e; + std::vector expected_tot_v; + + deepmd::DeepPot dp; + + void SetUp() override { + std::string file_name = "../../tests/infer/deeppot_sea.pth"; + dp.init(file_name); + + natoms = expected_e.size(); + EXPECT_EQ(natoms * 3, expected_f.size()); + EXPECT_EQ(natoms * 9, expected_v.size()); + expected_tot_e = 0.; + expected_tot_v.resize(9); + std::fill(expected_tot_v.begin(), expected_tot_v.end(), 0.); + for (int ii = 0; ii < natoms; ++ii) { + expected_tot_e += expected_e[ii]; + } + for (int ii = 0; ii < natoms; ++ii) { + for (int dd = 0; dd < 9; ++dd) { + expected_tot_v[dd] += expected_v[ii * 9 + dd]; + } + } + }; + + void TearDown() override { remove("deeppot.pb"); }; +}; + +TYPED_TEST_SUITE(TestInferDeepPotAPtNoPbc, ValueTypes); + +TYPED_TEST(TestInferDeepPotAPtNoPbc, cpu_build_nlist) { + using VALUETYPE = TypeParam; + std::vector& coord = this->coord; + std::vector& atype = this->atype; + std::vector& box = this->box; + std::vector& expected_e = this->expected_e; + std::vector& expected_f = this->expected_f; + std::vector& expected_v = this->expected_v; + int& natoms = this->natoms; + double& expected_tot_e = this->expected_tot_e; + std::vector& expected_tot_v = this->expected_tot_v; + deepmd::DeepPot& dp = this->dp; + double ener; + std::vector force, virial; + dp.compute(ener, force, virial, coord, atype, box); + + EXPECT_EQ(force.size(), natoms * 3); + EXPECT_EQ(virial.size(), 9); + + EXPECT_LT(fabs(ener - expected_tot_e), EPSILON); + for (int ii = 0; ii < natoms * 3; ++ii) { + EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON); + } + for (int ii = 0; ii < 3 * 3; ++ii) { + EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON); + } +} diff --git a/source/install/test_cc_local.sh b/source/install/test_cc_local.sh index 22d22a27f6..73aa74ed90 100755 --- a/source/install/test_cc_local.sh +++ b/source/install/test_cc_local.sh @@ -18,7 +18,15 @@ INSTALL_PREFIX=${SCRIPT_PATH}/../../dp_test BUILD_TMP_DIR=${SCRIPT_PATH}/../build_tests mkdir -p ${BUILD_TMP_DIR} cd ${BUILD_TMP_DIR} -cmake -DINSTALL_TENSORFLOW=FALSE -DUSE_TF_PYTHON_LIBS=TRUE -DCMAKE_INSTALL_PREFIX=${INSTALL_PREFIX} -DBUILD_TESTING:BOOL=TRUE -DLAMMPS_VERSION=stable_2Aug2023_update2 ${CUDA_ARGS} .. +cmake \ + -D ENABLE_TENSORFLOW=TRUE \ + -D ENABLE_PYTORCH=TRUE \ + -D INSTALL_TENSORFLOW=FALSE \ + -D USE_TF_PYTHON_LIBS=TRUE \ + -D CMAKE_INSTALL_PREFIX=${INSTALL_PREFIX} \ + -D BUILD_TESTING:BOOL=TRUE \ + -D LAMMPS_VERSION=stable_2Aug2023_update2 \ + ${CUDA_ARGS} .. cmake --build . -j${NPROC} cmake --install . ctest --output-on-failure diff --git a/source/ipi/tests/test_driver.py b/source/ipi/tests/test_driver.py index 1b2e1dd951..b0fbf53b01 100644 --- a/source/ipi/tests/test_driver.py +++ b/source/ipi/tests/test_driver.py @@ -251,3 +251,108 @@ def test_normalize_coords(self): ) expected_se = np.sum(self.expected_e.reshape([nframes, -1]), axis=1) np.testing.assert_almost_equal(ee.ravel(), expected_se.ravel(), default_places) + + +class TestDPIPIPt(TestDPIPI): + @classmethod + def setUpClass(cls): + cls.model_file = str(tests_path / "infer" / "deeppot_sea.pth") + + def setUp(self): + super().setUp() + + self.box = np.array([13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0]) + self.expected_e = np.array( + [ + -93.016873944029, + -185.923296645958, + -185.927096544970, + -93.019371018039, + -185.926179995548, + -185.924351901852, + ] + ) + self.expected_f = np.array( + [ + 0.006277522211, + -0.001117962774, + 0.000618580445, + 0.009928999655, + 0.003026035654, + -0.006941982227, + 0.000667853212, + -0.002449963843, + 0.006506463508, + -0.007284129115, + 0.000530662205, + -0.000028806821, + 0.000068097781, + 0.006121331983, + -0.009019754602, + -0.009658343745, + -0.006110103225, + 0.008865499697, + ] + ) + self.expected_v = np.array( + [ + -0.000155238009, + 0.000116605516, + -0.007869862476, + 0.000465578340, + 0.008182547185, + -0.002398713212, + -0.008112887338, + -0.002423738425, + 0.007210716605, + -0.019203504012, + 0.001724938709, + 0.009909211091, + 0.001153857542, + -0.001600015103, + -0.000560024090, + 0.010727836276, + -0.001034836404, + -0.007973454377, + -0.021517399106, + -0.004064359664, + 0.004866398692, + -0.003360038617, + -0.007241406162, + 0.005920941051, + 0.004899151657, + 0.006290788591, + -0.006478820311, + 0.001921504710, + 0.001313470921, + -0.000304091236, + 0.001684345981, + 0.004124109256, + -0.006396084465, + -0.000701095618, + -0.006356507032, + 0.009818550859, + -0.015230664587, + -0.000110244376, + 0.000690319396, + 0.000045953023, + -0.005726548770, + 0.008769818495, + -0.000572380210, + 0.008860603423, + -0.013819348050, + -0.021227082558, + -0.004977781343, + 0.006646239696, + -0.005987066507, + -0.002767831232, + 0.003746502525, + 0.007697590397, + 0.003746130152, + -0.005172634748, + ] + ) + + @classmethod + def tearDownClass(cls): + cls.dp = None diff --git a/source/lmp/tests/test_lammps_pt.py b/source/lmp/tests/test_lammps_pt.py new file mode 100644 index 0000000000..bf1ef97e2b --- /dev/null +++ b/source/lmp/tests/test_lammps_pt.py @@ -0,0 +1,695 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import os +import subprocess as sp +import sys +from pathlib import ( + Path, +) + +import constants +import numpy as np +import pytest +from lammps import ( + PyLammps, +) +from write_lmp_data import ( + write_lmp_data, +) + +pbtxt_file2 = ( + Path(__file__).parent.parent.parent / "tests" / "infer" / "deeppot-1.pbtxt" +) +pb_file = Path(__file__).parent.parent.parent / "tests" / "infer" / "deeppot_sea.pth" +pb_file2 = Path(__file__).parent / "graph2.pb" +system_file = Path(__file__).parent.parent.parent / "tests" +data_file = Path(__file__).parent / "data.lmp" +data_file_si = Path(__file__).parent / "data.si" +data_type_map_file = Path(__file__).parent / "data_type_map.lmp" +md_file = Path(__file__).parent / "md.out" + +# this is as the same as python and c++ tests, test_deeppot_a.py +expected_ae = np.array( + [ + -93.016873944029, + -185.923296645958, + -185.927096544970, + -93.019371018039, + -185.926179995548, + -185.924351901852, + ] +) +expected_e = np.sum(expected_ae) +expected_f = np.array( + [ + 0.006277522211, + -0.001117962774, + 0.000618580445, + 0.009928999655, + 0.003026035654, + -0.006941982227, + 0.000667853212, + -0.002449963843, + 0.006506463508, + -0.007284129115, + 0.000530662205, + -0.000028806821, + 0.000068097781, + 0.006121331983, + -0.009019754602, + -0.009658343745, + -0.006110103225, + 0.008865499697, + ] +).reshape(6, 3) + +expected_f2 = np.array( + [ + [-0.6454949, 1.72457783, 0.18897958], + [1.68936514, -0.36995299, -1.36044464], + [-1.09902692, -1.35487928, 1.17416702], + [1.68426111, -0.50835585, 0.98340415], + [0.05771758, 1.12515818, -1.77561531], + [-1.686822, -0.61654789, 0.78950921], + ] +) + +expected_v = -np.array( + [ + -0.000155238009, + 0.000116605516, + -0.007869862476, + 0.000465578340, + 0.008182547185, + -0.002398713212, + -0.008112887338, + -0.002423738425, + 0.007210716605, + -0.019203504012, + 0.001724938709, + 0.009909211091, + 0.001153857542, + -0.001600015103, + -0.000560024090, + 0.010727836276, + -0.001034836404, + -0.007973454377, + -0.021517399106, + -0.004064359664, + 0.004866398692, + -0.003360038617, + -0.007241406162, + 0.005920941051, + 0.004899151657, + 0.006290788591, + -0.006478820311, + 0.001921504710, + 0.001313470921, + -0.000304091236, + 0.001684345981, + 0.004124109256, + -0.006396084465, + -0.000701095618, + -0.006356507032, + 0.009818550859, + -0.015230664587, + -0.000110244376, + 0.000690319396, + 0.000045953023, + -0.005726548770, + 0.008769818495, + -0.000572380210, + 0.008860603423, + -0.013819348050, + -0.021227082558, + -0.004977781343, + 0.006646239696, + -0.005987066507, + -0.002767831232, + 0.003746502525, + 0.007697590397, + 0.003746130152, + -0.005172634748, + ] +).reshape(6, 9) +expected_v2 = -np.array( + [ + [ + -0.70008436, + -0.06399891, + 0.63678391, + -0.07642171, + -0.70580035, + 0.20506145, + 0.64098364, + 0.20305781, + -0.57906794, + ], + [ + -0.6372635, + 0.14315552, + 0.51952246, + 0.04604049, + -0.06003681, + -0.02688702, + 0.54489318, + -0.10951559, + -0.43730539, + ], + [ + -0.25090748, + -0.37466262, + 0.34085833, + -0.26690852, + -0.37676917, + 0.29080825, + 0.31600481, + 0.37558276, + -0.33251064, + ], + [ + -0.80195614, + -0.10273138, + 0.06935364, + -0.10429256, + -0.29693811, + 0.45643496, + 0.07247872, + 0.45604679, + -0.71048816, + ], + [ + -0.03840668, + -0.07680205, + 0.10940472, + -0.02374189, + -0.27610266, + 0.4336071, + 0.02465248, + 0.4290638, + -0.67496763, + ], + [ + -0.61475065, + -0.21163135, + 0.26652929, + -0.26134659, + -0.11560267, + 0.15415902, + 0.34343952, + 0.1589482, + -0.21370642, + ], + ] +).reshape(6, 9) + +box = np.array([0, 13, 0, 13, 0, 13, 0, 0, 0]) +coord = np.array( + [ + [12.83, 2.56, 2.18], + [12.09, 2.87, 2.74], + [0.25, 3.32, 1.68], + [3.36, 3.00, 1.81], + [3.51, 2.51, 2.60], + [4.27, 3.22, 1.56], + ] +) +type_OH = np.array([1, 2, 2, 1, 2, 2]) +type_HO = np.array([2, 1, 1, 2, 1, 1]) + + +sp.check_output( + "{} -m deepmd convert-from pbtxt -i {} -o {}".format( + sys.executable, + pbtxt_file2.resolve(), + pb_file2.resolve(), + ).split() +) + + +def setup_module(): + write_lmp_data(box, coord, type_OH, data_file) + write_lmp_data(box, coord, type_HO, data_type_map_file) + write_lmp_data( + box * constants.dist_metal2si, + coord * constants.dist_metal2si, + type_OH, + data_file_si, + ) + + +def teardown_module(): + os.remove(data_file) + os.remove(data_type_map_file) + + +def _lammps(data_file, units="metal") -> PyLammps: + lammps = PyLammps() + lammps.units(units) + lammps.boundary("p p p") + lammps.atom_style("atomic") + if units == "metal" or units == "real": + lammps.neighbor("2.0 bin") + elif units == "si": + lammps.neighbor("2.0e-10 bin") + else: + raise ValueError("units should be metal, real, or si") + lammps.neigh_modify("every 10 delay 0 check no") + lammps.read_data(data_file.resolve()) + if units == "metal" or units == "real": + lammps.mass("1 16") + lammps.mass("2 2") + elif units == "si": + lammps.mass("1 %.10e" % (16 * constants.mass_metal2si)) + lammps.mass("2 %.10e" % (2 * constants.mass_metal2si)) + else: + raise ValueError("units should be metal, real, or si") + if units == "metal": + lammps.timestep(0.0005) + elif units == "real": + lammps.timestep(0.5) + elif units == "si": + lammps.timestep(5e-16) + else: + raise ValueError("units should be metal, real, or si") + lammps.fix("1 all nve") + return lammps + + +@pytest.fixture +def lammps(): + lmp = _lammps(data_file=data_file) + yield lmp + lmp.close() + + +@pytest.fixture +def lammps_type_map(): + lmp = _lammps(data_file=data_type_map_file) + yield lmp + lmp.close() + + +@pytest.fixture +def lammps_real(): + lmp = _lammps(data_file=data_file, units="real") + yield lmp + lmp.close() + + +@pytest.fixture +def lammps_si(): + lmp = _lammps(data_file=data_file_si, units="si") + yield lmp + lmp.close() + + +def test_pair_deepmd(lammps): + lammps.pair_style(f"deepmd {pb_file.resolve()}") + lammps.pair_coeff("* *") + lammps.run(0) + assert lammps.eval("pe") == pytest.approx(expected_e) + for ii in range(6): + assert lammps.atoms[ii].force == pytest.approx( + expected_f[lammps.atoms[ii].id - 1] + ) + lammps.run(1) + + +def test_pair_deepmd_virial(lammps): + lammps.pair_style(f"deepmd {pb_file.resolve()}") + lammps.pair_coeff("* *") + lammps.compute("virial all centroid/stress/atom NULL pair") + for ii in range(9): + jj = [0, 4, 8, 3, 6, 7, 1, 2, 5][ii] + lammps.variable(f"virial{jj} atom c_virial[{ii+1}]") + lammps.dump( + "1 all custom 1 dump id " + " ".join([f"v_virial{ii}" for ii in range(9)]) + ) + lammps.run(0) + assert lammps.eval("pe") == pytest.approx(expected_e) + for ii in range(6): + assert lammps.atoms[ii].force == pytest.approx( + expected_f[lammps.atoms[ii].id - 1] + ) + idx_map = lammps.lmp.numpy.extract_atom("id") - 1 + for ii in range(9): + assert np.array( + lammps.variables[f"virial{ii}"].value + ) / constants.nktv2p == pytest.approx(expected_v[idx_map, ii]) + + +def test_pair_deepmd_model_devi(lammps): + lammps.pair_style( + "deepmd {} {} out_file {} out_freq 1 atomic".format( + pb_file.resolve(), pb_file2.resolve(), md_file.resolve() + ) + ) + lammps.pair_coeff("* *") + lammps.run(0) + assert lammps.eval("pe") == pytest.approx(expected_e) + for ii in range(6): + assert lammps.atoms[ii].force == pytest.approx( + expected_f[lammps.atoms[ii].id - 1] + ) + # load model devi + md = np.loadtxt(md_file.resolve()) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + assert md[7:] == pytest.approx(expected_md_f) + assert md[4] == pytest.approx(np.max(expected_md_f)) + assert md[5] == pytest.approx(np.min(expected_md_f)) + assert md[6] == pytest.approx(np.mean(expected_md_f)) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + assert md[1] == pytest.approx(np.max(expected_md_v)) + assert md[2] == pytest.approx(np.min(expected_md_v)) + assert md[3] == pytest.approx(np.sqrt(np.mean(np.square(expected_md_v)))) + + +def test_pair_deepmd_model_devi_virial(lammps): + lammps.pair_style( + "deepmd {} {} out_file {} out_freq 1 atomic".format( + pb_file.resolve(), pb_file2.resolve(), md_file.resolve() + ) + ) + lammps.pair_coeff("* *") + lammps.compute("virial all centroid/stress/atom NULL pair") + for ii in range(9): + jj = [0, 4, 8, 3, 6, 7, 1, 2, 5][ii] + lammps.variable(f"virial{jj} atom c_virial[{ii+1}]") + lammps.dump( + "1 all custom 1 dump id " + " ".join([f"v_virial{ii}" for ii in range(9)]) + ) + lammps.run(0) + assert lammps.eval("pe") == pytest.approx(expected_e) + for ii in range(6): + assert lammps.atoms[ii].force == pytest.approx( + expected_f[lammps.atoms[ii].id - 1] + ) + idx_map = lammps.lmp.numpy.extract_atom("id") - 1 + for ii in range(9): + assert np.array( + lammps.variables[f"virial{ii}"].value + ) / constants.nktv2p == pytest.approx(expected_v[idx_map, ii]) + # load model devi + md = np.loadtxt(md_file.resolve()) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + assert md[7:] == pytest.approx(expected_md_f) + assert md[4] == pytest.approx(np.max(expected_md_f)) + assert md[5] == pytest.approx(np.min(expected_md_f)) + assert md[6] == pytest.approx(np.mean(expected_md_f)) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + assert md[1] == pytest.approx(np.max(expected_md_v)) + assert md[2] == pytest.approx(np.min(expected_md_v)) + assert md[3] == pytest.approx(np.sqrt(np.mean(np.square(expected_md_v)))) + + +def test_pair_deepmd_model_devi_atomic_relative(lammps): + relative = 1.0 + lammps.pair_style( + "deepmd {} {} out_file {} out_freq 1 atomic relative {}".format( + pb_file.resolve(), pb_file2.resolve(), md_file.resolve(), relative + ) + ) + lammps.pair_coeff("* *") + lammps.run(0) + assert lammps.eval("pe") == pytest.approx(expected_e) + for ii in range(6): + assert lammps.atoms[ii].force == pytest.approx( + expected_f[lammps.atoms[ii].id - 1] + ) + # load model devi + md = np.loadtxt(md_file.resolve()) + norm = np.linalg.norm(np.mean([expected_f, expected_f2], axis=0), axis=1) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + expected_md_f /= norm + relative + assert md[7:] == pytest.approx(expected_md_f) + assert md[4] == pytest.approx(np.max(expected_md_f)) + assert md[5] == pytest.approx(np.min(expected_md_f)) + assert md[6] == pytest.approx(np.mean(expected_md_f)) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + assert md[1] == pytest.approx(np.max(expected_md_v)) + assert md[2] == pytest.approx(np.min(expected_md_v)) + assert md[3] == pytest.approx(np.sqrt(np.mean(np.square(expected_md_v)))) + + +def test_pair_deepmd_model_devi_atomic_relative_v(lammps): + relative = 1.0 + lammps.pair_style( + "deepmd {} {} out_file {} out_freq 1 atomic relative_v {}".format( + pb_file.resolve(), pb_file2.resolve(), md_file.resolve(), relative + ) + ) + lammps.pair_coeff("* *") + lammps.run(0) + assert lammps.eval("pe") == pytest.approx(expected_e) + for ii in range(6): + assert lammps.atoms[ii].force == pytest.approx( + expected_f[lammps.atoms[ii].id - 1] + ) + md = np.loadtxt(md_file.resolve()) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + assert md[7:] == pytest.approx(expected_md_f) + assert md[4] == pytest.approx(np.max(expected_md_f)) + assert md[5] == pytest.approx(np.min(expected_md_f)) + assert md[6] == pytest.approx(np.mean(expected_md_f)) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + norm = ( + np.abs( + np.mean([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) + ) + / 6 + ) + expected_md_v /= norm + relative + assert md[1] == pytest.approx(np.max(expected_md_v)) + assert md[2] == pytest.approx(np.min(expected_md_v)) + assert md[3] == pytest.approx(np.sqrt(np.mean(np.square(expected_md_v)))) + + +def test_pair_deepmd_type_map(lammps_type_map): + lammps_type_map.pair_style(f"deepmd {pb_file.resolve()}") + lammps_type_map.pair_coeff("* * H O") + lammps_type_map.run(0) + assert lammps_type_map.eval("pe") == pytest.approx(expected_e) + for ii in range(6): + assert lammps_type_map.atoms[ii].force == pytest.approx( + expected_f[lammps_type_map.atoms[ii].id - 1] + ) + lammps_type_map.run(1) + + +def test_pair_deepmd_real(lammps_real): + lammps_real.pair_style(f"deepmd {pb_file.resolve()}") + lammps_real.pair_coeff("* *") + lammps_real.run(0) + assert lammps_real.eval("pe") == pytest.approx( + expected_e * constants.ener_metal2real + ) + for ii in range(6): + assert lammps_real.atoms[ii].force == pytest.approx( + expected_f[lammps_real.atoms[ii].id - 1] * constants.force_metal2real + ) + lammps_real.run(1) + + +def test_pair_deepmd_virial_real(lammps_real): + lammps_real.pair_style(f"deepmd {pb_file.resolve()}") + lammps_real.pair_coeff("* *") + lammps_real.compute("virial all centroid/stress/atom NULL pair") + for ii in range(9): + jj = [0, 4, 8, 3, 6, 7, 1, 2, 5][ii] + lammps_real.variable(f"virial{jj} atom c_virial[{ii+1}]") + lammps_real.dump( + "1 all custom 1 dump id " + " ".join([f"v_virial{ii}" for ii in range(9)]) + ) + lammps_real.run(0) + assert lammps_real.eval("pe") == pytest.approx( + expected_e * constants.ener_metal2real + ) + for ii in range(6): + assert lammps_real.atoms[ii].force == pytest.approx( + expected_f[lammps_real.atoms[ii].id - 1] * constants.force_metal2real + ) + idx_map = lammps_real.lmp.numpy.extract_atom("id") - 1 + for ii in range(9): + assert np.array( + lammps_real.variables[f"virial{ii}"].value + ) / constants.nktv2p_real == pytest.approx( + expected_v[idx_map, ii] * constants.ener_metal2real + ) + + +def test_pair_deepmd_model_devi_real(lammps_real): + lammps_real.pair_style( + "deepmd {} {} out_file {} out_freq 1 atomic".format( + pb_file.resolve(), pb_file2.resolve(), md_file.resolve() + ) + ) + lammps_real.pair_coeff("* *") + lammps_real.run(0) + assert lammps_real.eval("pe") == pytest.approx( + expected_e * constants.ener_metal2real + ) + for ii in range(6): + assert lammps_real.atoms[ii].force == pytest.approx( + expected_f[lammps_real.atoms[ii].id - 1] * constants.force_metal2real + ) + # load model devi + md = np.loadtxt(md_file.resolve()) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + assert md[7:] == pytest.approx(expected_md_f * constants.force_metal2real) + assert md[4] == pytest.approx(np.max(expected_md_f) * constants.force_metal2real) + assert md[5] == pytest.approx(np.min(expected_md_f) * constants.force_metal2real) + assert md[6] == pytest.approx(np.mean(expected_md_f) * constants.force_metal2real) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + assert md[1] == pytest.approx(np.max(expected_md_v) * constants.ener_metal2real) + assert md[2] == pytest.approx(np.min(expected_md_v) * constants.ener_metal2real) + assert md[3] == pytest.approx( + np.sqrt(np.mean(np.square(expected_md_v))) * constants.ener_metal2real + ) + + +def test_pair_deepmd_model_devi_virial_real(lammps_real): + lammps_real.pair_style( + "deepmd {} {} out_file {} out_freq 1 atomic".format( + pb_file.resolve(), pb_file2.resolve(), md_file.resolve() + ) + ) + lammps_real.pair_coeff("* *") + lammps_real.compute("virial all centroid/stress/atom NULL pair") + for ii in range(9): + jj = [0, 4, 8, 3, 6, 7, 1, 2, 5][ii] + lammps_real.variable(f"virial{jj} atom c_virial[{ii+1}]") + lammps_real.dump( + "1 all custom 1 dump id " + " ".join([f"v_virial{ii}" for ii in range(9)]) + ) + lammps_real.run(0) + assert lammps_real.eval("pe") == pytest.approx( + expected_e * constants.ener_metal2real + ) + for ii in range(6): + assert lammps_real.atoms[ii].force == pytest.approx( + expected_f[lammps_real.atoms[ii].id - 1] * constants.force_metal2real + ) + idx_map = lammps_real.lmp.numpy.extract_atom("id") - 1 + for ii in range(9): + assert np.array( + lammps_real.variables[f"virial{ii}"].value + ) / constants.nktv2p_real == pytest.approx( + expected_v[idx_map, ii] * constants.ener_metal2real + ) + # load model devi + md = np.loadtxt(md_file.resolve()) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + assert md[7:] == pytest.approx(expected_md_f * constants.force_metal2real) + assert md[4] == pytest.approx(np.max(expected_md_f) * constants.force_metal2real) + assert md[5] == pytest.approx(np.min(expected_md_f) * constants.force_metal2real) + assert md[6] == pytest.approx(np.mean(expected_md_f) * constants.force_metal2real) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + assert md[1] == pytest.approx(np.max(expected_md_v) * constants.ener_metal2real) + assert md[2] == pytest.approx(np.min(expected_md_v) * constants.ener_metal2real) + assert md[3] == pytest.approx( + np.sqrt(np.mean(np.square(expected_md_v))) * constants.ener_metal2real + ) + + +def test_pair_deepmd_model_devi_atomic_relative_real(lammps_real): + relative = 1.0 + lammps_real.pair_style( + "deepmd {} {} out_file {} out_freq 1 atomic relative {}".format( + pb_file.resolve(), + pb_file2.resolve(), + md_file.resolve(), + relative * constants.force_metal2real, + ) + ) + lammps_real.pair_coeff("* *") + lammps_real.run(0) + assert lammps_real.eval("pe") == pytest.approx( + expected_e * constants.ener_metal2real + ) + for ii in range(6): + assert lammps_real.atoms[ii].force == pytest.approx( + expected_f[lammps_real.atoms[ii].id - 1] * constants.force_metal2real + ) + # load model devi + md = np.loadtxt(md_file.resolve()) + norm = np.linalg.norm(np.mean([expected_f, expected_f2], axis=0), axis=1) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + expected_md_f /= norm + relative + assert md[7:] == pytest.approx(expected_md_f * constants.force_metal2real) + assert md[4] == pytest.approx(np.max(expected_md_f) * constants.force_metal2real) + assert md[5] == pytest.approx(np.min(expected_md_f) * constants.force_metal2real) + assert md[6] == pytest.approx(np.mean(expected_md_f) * constants.force_metal2real) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + assert md[1] == pytest.approx(np.max(expected_md_v) * constants.ener_metal2real) + assert md[2] == pytest.approx(np.min(expected_md_v) * constants.ener_metal2real) + assert md[3] == pytest.approx( + np.sqrt(np.mean(np.square(expected_md_v))) * constants.ener_metal2real + ) + + +def test_pair_deepmd_model_devi_atomic_relative_v_real(lammps_real): + relative = 1.0 + lammps_real.pair_style( + "deepmd {} {} out_file {} out_freq 1 atomic relative_v {}".format( + pb_file.resolve(), + pb_file2.resolve(), + md_file.resolve(), + relative * constants.ener_metal2real, + ) + ) + lammps_real.pair_coeff("* *") + lammps_real.run(0) + assert lammps_real.eval("pe") == pytest.approx( + expected_e * constants.ener_metal2real + ) + for ii in range(6): + assert lammps_real.atoms[ii].force == pytest.approx( + expected_f[lammps_real.atoms[ii].id - 1] * constants.force_metal2real + ) + md = np.loadtxt(md_file.resolve()) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + assert md[7:] == pytest.approx(expected_md_f * constants.force_metal2real) + assert md[4] == pytest.approx(np.max(expected_md_f) * constants.force_metal2real) + assert md[5] == pytest.approx(np.min(expected_md_f) * constants.force_metal2real) + assert md[6] == pytest.approx(np.mean(expected_md_f) * constants.force_metal2real) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + norm = ( + np.abs( + np.mean([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) + ) + / 6 + ) + expected_md_v /= norm + relative + assert md[1] == pytest.approx(np.max(expected_md_v) * constants.ener_metal2real) + assert md[2] == pytest.approx(np.min(expected_md_v) * constants.ener_metal2real) + assert md[3] == pytest.approx( + np.sqrt(np.mean(np.square(expected_md_v))) * constants.ener_metal2real + ) + + +def test_pair_deepmd_si(lammps_si): + lammps_si.pair_style(f"deepmd {pb_file.resolve()}") + lammps_si.pair_coeff("* *") + lammps_si.run(0) + assert lammps_si.eval("pe") == pytest.approx(expected_e * constants.ener_metal2si) + for ii in range(6): + assert lammps_si.atoms[ii].force == pytest.approx( + expected_f[lammps_si.atoms[ii].id - 1] * constants.force_metal2si + ) + lammps_si.run(1) diff --git a/source/tests/infer/deeppot_sea.pth b/source/tests/infer/deeppot_sea.pth new file mode 100644 index 0000000000..98aaa8a2ad Binary files /dev/null and b/source/tests/infer/deeppot_sea.pth differ