diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 721a0cd6eb..4113223e02 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -15,7 +15,8 @@ repos: exclude: | (?x)^( source/tests/infer/dipolecharge_e.pbtxt| - source/tests/infer/deeppolar_new.pbtxt + source/tests/infer/deeppolar_new.pbtxt| + source/tests/infer/deeppot_dpa.savedmodel/saved_model.pb )$ - id: check-merge-conflict - id: check-symlinks diff --git a/doc/backend.md b/doc/backend.md index 3fb70bee90..00e7e64ccc 100644 --- a/doc/backend.md +++ b/doc/backend.md @@ -31,7 +31,9 @@ While `.pth` and `.pt` are the same in the PyTorch package, they have different [JAX](https://jax.readthedocs.io/) 0.4.33 (which requires Python 3.10 or above) or above is required. Both `.xlo` and `.jax` are customized format extensions defined in DeePMD-kit, since JAX has no convention for file extensions. `.savedmodel` is the TensorFlow [SavedModel format](https://www.tensorflow.org/guide/saved_model) generated by [JAX2TF](https://www.tensorflow.org/guide/jax2tf), which needs the installation of TensorFlow. -Currently, this backend is developed actively, and has no support for training and the C++ interface. +Only the `.savedmodel` format supports C++ inference, which needs the TensorFlow C++ interface. +The model is device-specific, so that the model generated on the GPU device cannot be run on the CPUs. +Currently, this backend is developed actively, and has no support for training. ### DP {{ dpmodel_icon }} diff --git a/doc/install/install-from-source.md b/doc/install/install-from-source.md index 4a0a104b7e..0bf6fa5ee3 100644 --- a/doc/install/install-from-source.md +++ b/doc/install/install-from-source.md @@ -297,7 +297,9 @@ If one does not need to use DeePMD-kit with LAMMPS or i-PI, then the python inte ::::{tab-set} -:::{tab-item} TensorFlow {{ tensorflow_icon }} +:::{tab-item} TensorFlow {{ tensorflow_icon }} / JAX {{ jax_icon }} + +The C++ interfaces of both TensorFlow and JAX backends are based on the TensorFlow C++ library. Since TensorFlow 2.12, TensorFlow C++ library (`libtensorflow_cc`) is packaged inside the Python library. Thus, you can skip building TensorFlow C++ library manually. If that does not work for you, you can still build it manually. @@ -338,7 +340,7 @@ We recommend using [conda packages](https://docs.deepmodeling.org/faq/conda.html ::::{tab-set} -:::{tab-item} TensorFlow {{ tensorflow_icon }} +:::{tab-item} TensorFlow {{ tensorflow_icon }} / JAX {{ jax_icon }} I assume you have activated the TensorFlow Python environment and want to install DeePMD-kit into path `$deepmd_root`, then execute CMake @@ -375,7 +377,7 @@ One may add the following CMake variables to `cmake` using the [`-D ==nl.set_mask(mask); } +void DP_NlistSetMapping(DP_Nlist* nl, int* mapping) { + nl->nl.set_mapping(mapping); +} void DP_DeleteNlist(DP_Nlist* nl) { delete nl; } DP_DeepPot::DP_DeepPot() {} diff --git a/source/api_cc/include/DeepPotJAX.h b/source/api_cc/include/DeepPotJAX.h new file mode 100644 index 0000000000..606836de7e --- /dev/null +++ b/source/api_cc/include/DeepPotJAX.h @@ -0,0 +1,249 @@ +// SPDX-License-Identifier: LGPL-3.0-or-later +#pragma once + +#include +#include + +#include "DeepPot.h" +#include "common.h" +#include "neighbor_list.h" + +namespace deepmd { +/** + * @brief TensorFlow implementation for Deep Potential. + **/ +class DeepPotJAX : public DeepPotBase { + public: + /** + * @brief DP constructor without initialization. + **/ + DeepPotJAX(); + virtual ~DeepPotJAX(); + /** + * @brief DP constructor with initialization. + * @param[in] model The name of the frozen model file. + * @param[in] gpu_rank The GPU rank. Default is 0. + * @param[in] file_content The content of the model file. If it is not empty, + *DP will read from the string instead of the file. + **/ + DeepPotJAX(const std::string& model, + const int& gpu_rank = 0, + const std::string& file_content = ""); + /** + * @brief Initialize the DP. + * @param[in] model The name of the frozen model file. + * @param[in] gpu_rank The GPU rank. Default is 0. + * @param[in] file_content The content of the model file. If it is not empty, + *DP will read from the string instead of the file. + **/ + void init(const std::string& model, + const int& gpu_rank = 0, + const std::string& file_content = ""); + /** + * @brief Get the cutoff radius. + * @return The cutoff radius. + **/ + double cutoff() const { + assert(inited); + return rcut; + }; + /** + * @brief Get the number of types. + * @return The number of types. + **/ + int numb_types() const { + assert(inited); + return ntypes; + }; + /** + * @brief Get the number of types with spin. + * @return The number of types with spin. + **/ + int numb_types_spin() const { + assert(inited); + return 0; + }; + /** + * @brief Get the dimension of the frame parameter. + * @return The dimension of the frame parameter. + **/ + int dim_fparam() const { + assert(inited); + return dfparam; + }; + /** + * @brief Get the dimension of the atomic parameter. + * @return The dimension of the atomic parameter. + **/ + int dim_aparam() const { + assert(inited); + return daparam; + }; + /** + * @brief Get the type map (element name of the atom types) of this model. + * @param[out] type_map The type map of this model. + **/ + void get_type_map(std::string& type_map); + + /** + * @brief Get whether the atom dimension of aparam is nall instead of fparam. + * @param[out] aparam_nall whether the atom dimension of aparam is nall + *instead of fparam. + **/ + bool is_aparam_nall() const { + assert(inited); + return false; + }; + + // forward to template class + void computew(std::vector& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic); + void computew(std::vector& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic); + void computew(std::vector& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const int nghost, + const InputNlist& inlist, + const int& ago, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic); + void computew(std::vector& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const int nghost, + const InputNlist& inlist, + const int& ago, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic); + void computew_mixed_type(std::vector& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const int& nframes, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic); + void computew_mixed_type(std::vector& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const int& nframes, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic); + + private: + bool inited; + // device + std::string device; + // the cutoff radius + double rcut; + // the number of types + int ntypes; + // the dimension of the frame parameter + int dfparam; + // the dimension of the atomic parameter + int daparam; + // type map + std::string type_map; + // sel + std::vector sel; + // number of neighbors + int nnei; + /** TF C API objects. + * @{ + */ + TF_Graph* graph; + TF_Status* status; + TF_Session* session; + TF_SessionOptions* sessionopts; + TFE_ContextOptions* ctx_opts; + TFE_Context* ctx; + std::vector func_vector; + /** + * @} + */ + // neighbor list data + NeighborListData nlist_data; + /** + * @brief Evaluate the energy, force, virial, atomic energy, and atomic virial + *by using this DP. + * @param[out] ener The system energy. + * @param[out] force The force on each atom. + * @param[out] virial The virial. + * @param[out] atom_energy The atomic energy. + * @param[out] atom_virial The atomic virial. + * @param[in] coord The coordinates of atoms. The array should be of size + *nframes x natoms x 3. + * @param[in] atype The atom types. The list should contain natoms ints. + * @param[in] box The cell of the region. The array should be of size nframes + *x 9. + * @param[in] nghost The number of ghost atoms. + * @param[in] lmp_list The input neighbour list. + * @param[in] ago Update the internal neighbour list if ago is 0. + * @param[in] fparam The frame parameter. The array can be of size : + * nframes x dim_fparam. + * dim_fparam. Then all frames are assumed to be provided with the same + *fparam. + * @param[in] aparam The atomic parameter The array can be of size : + * nframes x natoms x dim_aparam. + * natoms x dim_aparam. Then all frames are assumed to be provided with the + *same aparam. + * @param[in] atomic Whether to compute atomic energy and virial. + **/ + template + void compute(std::vector& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const int nghost, + const InputNlist& lmp_list, + const int& ago, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic); +}; +} // namespace deepmd diff --git a/source/api_cc/include/common.h b/source/api_cc/include/common.h index 9b1adcbd62..def3df933b 100644 --- a/source/api_cc/include/common.h +++ b/source/api_cc/include/common.h @@ -13,7 +13,7 @@ namespace deepmd { typedef double ENERGYTYPE; -enum DPBackend { TensorFlow, PyTorch, Paddle, Unknown }; +enum DPBackend { TensorFlow, PyTorch, Paddle, JAX, Unknown }; struct NeighborListData { /// Array stores the core region atom's index diff --git a/source/api_cc/src/DeepPot.cc b/source/api_cc/src/DeepPot.cc index c184446288..193aabce0a 100644 --- a/source/api_cc/src/DeepPot.cc +++ b/source/api_cc/src/DeepPot.cc @@ -7,6 +7,7 @@ #include "AtomMap.h" #include "common.h" #ifdef BUILD_TENSORFLOW +#include "DeepPotJAX.h" #include "DeepPotTF.h" #endif #ifdef BUILD_PYTORCH @@ -41,6 +42,9 @@ void DeepPot::init(const std::string& model, backend = deepmd::DPBackend::PyTorch; } else if (model.length() >= 3 && model.substr(model.length() - 3) == ".pb") { backend = deepmd::DPBackend::TensorFlow; + } else if (model.length() >= 11 && + model.substr(model.length() - 11) == ".savedmodel") { + backend = deepmd::DPBackend::JAX; } else { throw deepmd::deepmd_exception("Unsupported model file format"); } @@ -58,6 +62,14 @@ void DeepPot::init(const std::string& model, #endif } else if (deepmd::DPBackend::Paddle == backend) { throw deepmd::deepmd_exception("PaddlePaddle backend is not supported yet"); + } else if (deepmd::DPBackend::JAX == backend) { +#ifdef BUILD_TENSORFLOW + dp = std::make_shared(model, gpu_rank, file_content); +#else + throw deepmd::deepmd_exception( + "TensorFlow backend is not built, which is used to load JAX2TF " + "SavedModels"); +#endif } else { throw deepmd::deepmd_exception("Unknown file type"); } diff --git a/source/api_cc/src/DeepPotJAX.cc b/source/api_cc/src/DeepPotJAX.cc new file mode 100644 index 0000000000..8e4a9eda64 --- /dev/null +++ b/source/api_cc/src/DeepPotJAX.cc @@ -0,0 +1,581 @@ +// SPDX-License-Identifier: LGPL-3.0-or-later +#ifdef BUILD_TENSORFLOW + +#include "DeepPotJAX.h" + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "common.h" +#include "device.h" +#include "errors.h" + +inline void check_status(TF_Status* status) { + if (TF_GetCode(status) != TF_OK) { + throw deepmd::deepmd_exception("TensorFlow C API Error: " + + std::string(TF_Message(status))); + } +} + +inline void find_function(TF_Function*& found_func, + const std::vector& funcs, + const std::string func_name) { + for (size_t i = 0; i < funcs.size(); i++) { + TF_Function* func = funcs[i]; + const char* name = TF_FunctionName(func); + std::string name_(name); + // remove trailing integer e.g. _123 + std::string::size_type pos = name_.find_last_not_of("0123456789_"); + if (pos != std::string::npos) { + name_ = name_.substr(0, pos + 1); + } + if (name_ == "__inference_" + func_name) { + found_func = func; + return; + } + } + found_func = NULL; +} + +inline TF_DataType get_data_tensor_type(const std::vector& data) { + return TF_DOUBLE; +} + +inline TF_DataType get_data_tensor_type(const std::vector& data) { + return TF_FLOAT; +} + +inline TF_DataType get_data_tensor_type(const std::vector& data) { + return TF_INT32; +} + +inline TF_DataType get_data_tensor_type(const std::vector& data) { + return TF_INT64; +} + +inline TFE_Op* get_func_op(TFE_Context* ctx, + const std::string func_name, + const std::vector& funcs, + const std::string device, + TF_Status* status) { + TF_Function* func = NULL; + find_function(func, funcs, func_name); + if (func == NULL) { + throw std::runtime_error("Function " + func_name + " not found"); + } + TFE_ContextAddFunction(ctx, func, status); + check_status(status); + const char* real_func_name = TF_FunctionName(func); + // execute the function + TFE_Op* op = TFE_NewOp(ctx, real_func_name, status); + check_status(status); + TFE_OpSetDevice(op, device.c_str(), status); + check_status(status); + return op; +} + +template +inline T get_scalar(TFE_Context* ctx, + const std::string func_name, + const std::vector& funcs, + const std::string device, + TF_Status* status) { + TFE_Op* op = get_func_op(ctx, func_name, funcs, device, status); + check_status(status); + TFE_TensorHandle* retvals[1]; + int nretvals = 1; + TFE_Execute(op, retvals, &nretvals, status); + check_status(status); + TFE_TensorHandle* retval = retvals[0]; + TF_Tensor* tensor = TFE_TensorHandleResolve(retval, status); + check_status(status); + T* data = (T*)TF_TensorData(tensor); + // copy data + T result = *data; + // deallocate + TFE_DeleteOp(op); + TF_DeleteTensor(tensor); + TFE_DeleteTensorHandle(retval); + return result; +} + +template +inline std::vector get_vector(TFE_Context* ctx, + const std::string func_name, + const std::vector& funcs, + const std::string device, + TF_Status* status) { + TFE_Op* op = get_func_op(ctx, func_name, funcs, device, status); + check_status(status); + TFE_TensorHandle* retvals[1]; + int nretvals = 1; + TFE_Execute(op, retvals, &nretvals, status); + check_status(status); + TFE_TensorHandle* retval = retvals[0]; + // copy data + std::vector result; + tensor_to_vector(result, retval, status); + // deallocate + TFE_DeleteTensorHandle(retval); + TFE_DeleteOp(op); + return result; +} + +inline std::vector get_vector_string( + TFE_Context* ctx, + const std::string func_name, + const std::vector& funcs, + const std::string device, + TF_Status* status) { + TFE_Op* op = get_func_op(ctx, func_name, funcs, device, status); + check_status(status); + TFE_TensorHandle* retvals[1]; + int nretvals = 1; + TFE_Execute(op, retvals, &nretvals, status); + check_status(status); + TFE_TensorHandle* retval = retvals[0]; + TF_Tensor* tensor = TFE_TensorHandleResolve(retval, status); + check_status(status); + // calculate the number of bytes in each string + const void* data = TF_TensorData(tensor); + int64_t bytes_each_string = + TF_TensorByteSize(tensor) / TF_TensorElementCount(tensor); + // copy data + std::vector result; + for (int ii = 0; ii < TF_TensorElementCount(tensor); ++ii) { + const TF_TString* datastr = + static_cast(static_cast( + static_cast(data) + ii * bytes_each_string)); + const char* dst = TF_TString_GetDataPointer(datastr); + size_t dst_len = TF_TString_GetSize(datastr); + result.push_back(std::string(dst, dst_len)); + } + + // deallocate + TFE_DeleteOp(op); + TF_DeleteTensor(tensor); + TFE_DeleteTensorHandle(retval); + return result; +} + +template +inline TF_Tensor* create_tensor(const std::vector& data, + const std::vector& shape) { + TF_Tensor* tensor = + TF_AllocateTensor(get_data_tensor_type(data), shape.data(), shape.size(), + data.size() * sizeof(T)); + memcpy(TF_TensorData(tensor), data.data(), TF_TensorByteSize(tensor)); + return tensor; +} + +template +inline TFE_TensorHandle* add_input(TFE_Op* op, + const std::vector& data, + const std::vector& data_shape, + TF_Tensor*& data_tensor, + TF_Status* status) { + data_tensor = create_tensor(data, data_shape); + TFE_TensorHandle* handle = TFE_NewTensorHandle(data_tensor, status); + check_status(status); + + TFE_OpAddInput(op, handle, status); + check_status(status); + return handle; +} + +template +inline void tensor_to_vector(std::vector& result, + TFE_TensorHandle* retval, + TF_Status* status) { + TF_Tensor* tensor = TFE_TensorHandleResolve(retval, status); + check_status(status); + T* data = (T*)TF_TensorData(tensor); + // copy data + result.resize(TF_TensorElementCount(tensor)); + for (int i = 0; i < TF_TensorElementCount(tensor); i++) { + result[i] = data[i]; + } + // Delete the tensor to free memory + TF_DeleteTensor(tensor); +} + +deepmd::DeepPotJAX::DeepPotJAX() : inited(false) {} +deepmd::DeepPotJAX::DeepPotJAX(const std::string& model, + const int& gpu_rank, + const std::string& file_content) + : inited(false) { + init(model, gpu_rank, file_content); +} +void deepmd::DeepPotJAX::init(const std::string& model, + const int& gpu_rank, + const std::string& file_content) { + if (inited) { + std::cerr << "WARNING: deepmd-kit should not be initialized twice, do " + "nothing at the second call of initializer" + << std::endl; + return; + } + + const char* saved_model_dir = model.c_str(); + graph = TF_NewGraph(); + status = TF_NewStatus(); + + sessionopts = TF_NewSessionOptions(); + TF_Buffer* runopts = NULL; + + const char* tags = "serve"; + int ntags = 1; + + session = TF_LoadSessionFromSavedModel(sessionopts, runopts, saved_model_dir, + &tags, ntags, graph, NULL, status); + check_status(status); + + int nfuncs = TF_GraphNumFunctions(graph); + // allocate memory for the TF_Function* array + func_vector.resize(nfuncs); + TF_Function** funcs = func_vector.data(); + TF_GraphGetFunctions(graph, funcs, nfuncs, status); + check_status(status); + + ctx_opts = TFE_NewContextOptions(); + ctx = TFE_NewContext(ctx_opts, status); + check_status(status); +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + int gpu_num; + DPGetDeviceCount(gpu_num); // check current device environment + DPErrcheck(DPSetDevice(gpu_rank % gpu_num)); + if (gpu_num > 0) { + device = "/gpu:" + std::to_string(gpu_rank % gpu_num); + } else { + device = "/cpu:0"; + } +#else + device = "/cpu:0"; +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + + rcut = get_scalar(ctx, "get_rcut", func_vector, device, status); + dfparam = + get_scalar(ctx, "get_dim_fparam", func_vector, device, status); + daparam = + get_scalar(ctx, "get_dim_aparam", func_vector, device, status); + std::vector type_map_ = + get_vector_string(ctx, "get_type_map", func_vector, device, status); + // deepmd-kit stores type_map as a concatenated string, split by ' ' + type_map = type_map_[0]; + for (size_t i = 1; i < type_map_.size(); i++) { + type_map += " " + type_map_[i]; + } + ntypes = type_map_.size(); + sel = get_vector(ctx, "get_sel", func_vector, device, status); + nnei = std::accumulate(sel.begin(), sel.end(), decltype(sel)::value_type(0)); + inited = true; +} + +deepmd::DeepPotJAX::~DeepPotJAX() { + if (inited) { + TF_DeleteSession(session, status); + TF_DeleteGraph(graph); + TF_DeleteSessionOptions(sessionopts); + TF_DeleteStatus(status); + TFE_DeleteContext(ctx); + TFE_DeleteContextOptions(ctx_opts); + for (size_t i = 0; i < func_vector.size(); i++) { + TF_DeleteFunction(func_vector[i]); + } + } +} + +template +void deepmd::DeepPotJAX::compute(std::vector& ener, + std::vector& force_, + std::vector& virial, + std::vector& atom_energy_, + std::vector& atom_virial_, + const std::vector& dcoord, + const std::vector& datype, + const std::vector& box, + const int nghost, + const InputNlist& lmp_list, + const int& ago, + const std::vector& fparam, + const std::vector& aparam_, + const bool atomic) { + std::vector coord, force, aparam, atom_energy, atom_virial; + std::vector ener_double, force_double, virial_double, + atom_energy_double, atom_virial_double; + std::vector atype, fwd_map, bkw_map; + int nghost_real, nall_real, nloc_real; + int nall = datype.size(); + // nlist passed to the model + int nframes = 1; + + select_real_atoms_coord(coord, atype, aparam, nghost_real, fwd_map, bkw_map, + nall_real, nloc_real, dcoord, datype, aparam_, nghost, + ntypes, nframes, daparam, nall, false); + + if (nloc_real == 0) { + // no real atoms, fill 0 for all outputs + // this can prevent a Xla error + ener.resize(nframes, 0.0); + force_.resize(static_cast(nframes) * nall * 3, 0.0); + virial.resize(static_cast(nframes) * 9, 0.0); + atom_energy_.resize(static_cast(nframes) * nall, 0.0); + atom_virial_.resize(static_cast(nframes) * nall * 9, 0.0); + return; + } + + // cast coord, fparam, and aparam to double - I think it's useless to have a + // float model interface + std::vector coord_double(coord.begin(), coord.end()); + std::vector fparam_double(fparam.begin(), fparam.end()); + std::vector aparam_double(aparam.begin(), aparam.end()); + + TFE_Op* op; + if (atomic) { + op = get_func_op(ctx, "call_lower_with_atomic_virial", func_vector, device, + status); + } else { + op = get_func_op(ctx, "call_lower_without_atomic_virial", func_vector, + device, status); + } + std::vector input_list(6); + std::vector data_tensor(6); + // coord + std::vector coord_shape = {nframes, nall_real, 3}; + input_list[0] = + add_input(op, coord_double, coord_shape, data_tensor[0], status); + // atype + std::vector atype_shape = {nframes, nall_real}; + input_list[1] = add_input(op, atype, atype_shape, data_tensor[1], status); + // nlist + if (ago == 0) { + nlist_data.copy_from_nlist(lmp_list); + nlist_data.shuffle_exclude_empty(fwd_map); + } + std::vector nlist_shape = {nframes, nloc_real, nnei}; + std::vector nlist(static_cast(nframes) * nloc_real * nnei); + // pass nlist_data.jlist to nlist + for (int ii = 0; ii < nloc_real; ii++) { + for (int jj = 0; jj < nnei; jj++) { + if (jj < nlist_data.jlist[ii].size()) { + nlist[ii * nnei + jj] = nlist_data.jlist[ii][jj]; + } else { + nlist[ii * nnei + jj] = -1; + } + } + if (nnei < nlist_data.jlist[ii].size()) { + std::cerr << "WARNING: nnei < nlist_data.jlist[ii].size(); JAX backend " + "never handles this." + << std::endl; + } + } + input_list[2] = add_input(op, nlist, nlist_shape, data_tensor[2], status); + // mapping; for now, set it to -1, assume it is not used + std::vector mapping_shape = {nframes, nall_real}; + std::vector mapping(nframes * nall_real, -1); + // pass mapping if it is given in the neighbor list + if (lmp_list.mapping) { + // assume nframes is 1 + for (size_t ii = 0; ii < nall_real; ii++) { + mapping[ii] = lmp_list.mapping[fwd_map[ii]]; + } + } + input_list[3] = add_input(op, mapping, mapping_shape, data_tensor[3], status); + // fparam + std::vector fparam_shape = {nframes, dfparam}; + input_list[4] = + add_input(op, fparam_double, fparam_shape, data_tensor[4], status); + // aparam + std::vector aparam_shape = {nframes, nloc_real, daparam}; + input_list[5] = + add_input(op, aparam_double, aparam_shape, data_tensor[5], status); + // execute the function + int nretvals = 6; + TFE_TensorHandle* retvals[nretvals]; + + TFE_Execute(op, retvals, &nretvals, status); + check_status(status); + + // copy data + // the order is: + // energy + // energy_derv_c + // energy_derv_c_redu + // energy_derv_r + // energy_redu + // mask + // it seems the order is the alphabet order? + // not sure whether it is safe to assume the order + tensor_to_vector(ener_double, retvals[4], status); + tensor_to_vector(force_double, retvals[3], status); + tensor_to_vector(virial_double, retvals[2], status); + tensor_to_vector(atom_energy_double, retvals[0], status); + tensor_to_vector(atom_virial_double, retvals[1], status); + + // cast back to VALUETYPE + ener = std::vector(ener_double.begin(), ener_double.end()); + force = std::vector(force_double.begin(), force_double.end()); + virial = std::vector(virial_double.begin(), virial_double.end()); + atom_energy = std::vector(atom_energy_double.begin(), + atom_energy_double.end()); + atom_virial = std::vector(atom_virial_double.begin(), + atom_virial_double.end()); + + // nall atom_energy is required in the C++ API; + // we always forget it! + atom_energy.resize(static_cast(nframes) * nall_real, 0.0); + + force_.resize(static_cast(nframes) * fwd_map.size() * 3); + atom_energy_.resize(static_cast(nframes) * fwd_map.size()); + atom_virial_.resize(static_cast(nframes) * fwd_map.size() * 9); + select_map(force_, force, bkw_map, 3, nframes, fwd_map.size(), + nall_real); + select_map(atom_energy_, atom_energy, bkw_map, 1, nframes, + fwd_map.size(), nall_real); + select_map(atom_virial_, atom_virial, bkw_map, 9, nframes, + fwd_map.size(), nall_real); + + // cleanup input_list, etc + for (size_t i = 0; i < 6; i++) { + TFE_DeleteTensorHandle(input_list[i]); + TF_DeleteTensor(data_tensor[i]); + } + for (size_t i = 0; i < nretvals; i++) { + TFE_DeleteTensorHandle(retvals[i]); + } + TFE_DeleteOp(op); +} + +template void deepmd::DeepPotJAX::compute( + std::vector& dener, + std::vector& dforce_, + std::vector& dvirial, + std::vector& datom_energy_, + std::vector& datom_virial_, + const std::vector& dcoord_, + const std::vector& datype_, + const std::vector& dbox, + const int nghost, + const InputNlist& lmp_list, + const int& ago, + const std::vector& fparam, + const std::vector& aparam_, + const bool atomic); + +template void deepmd::DeepPotJAX::compute( + std::vector& dener, + std::vector& dforce_, + std::vector& dvirial, + std::vector& datom_energy_, + std::vector& datom_virial_, + const std::vector& dcoord_, + const std::vector& datype_, + const std::vector& dbox, + const int nghost, + const InputNlist& lmp_list, + const int& ago, + const std::vector& fparam, + const std::vector& aparam_, + const bool atomic); + +void deepmd::DeepPotJAX::get_type_map(std::string& type_map_) { + type_map_ = type_map; +} + +// forward to template method +void deepmd::DeepPotJAX::computew(std::vector& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic) { + throw deepmd::deepmd_exception("not implemented"); +} +void deepmd::DeepPotJAX::computew(std::vector& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic) { + throw deepmd::deepmd_exception("not implemented"); +} +void deepmd::DeepPotJAX::computew(std::vector& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const int nghost, + const InputNlist& inlist, + const int& ago, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic) { + compute(ener, force, virial, atom_energy, atom_virial, coord, atype, box, + nghost, inlist, ago, fparam, aparam, atomic); +} +void deepmd::DeepPotJAX::computew(std::vector& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const int nghost, + const InputNlist& inlist, + const int& ago, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic) { + compute(ener, force, virial, atom_energy, atom_virial, coord, atype, box, + nghost, inlist, ago, fparam, aparam, atomic); +} +void deepmd::DeepPotJAX::computew_mixed_type(std::vector& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const int& nframes, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic) { + throw deepmd::deepmd_exception("not implemented"); +} +void deepmd::DeepPotJAX::computew_mixed_type(std::vector& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const int& nframes, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic) { + throw deepmd::deepmd_exception("not implemented"); +} +#endif diff --git a/source/api_cc/tests/test_deeppot_jax.cc b/source/api_cc/tests/test_deeppot_jax.cc new file mode 100644 index 0000000000..c9fe8ea3dd --- /dev/null +++ b/source/api_cc/tests/test_deeppot_jax.cc @@ -0,0 +1,439 @@ +// SPDX-License-Identifier: LGPL-3.0-or-later +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "DeepPot.h" +#include "neighbor_list.h" +#include "test_utils.h" + +template +class TestInferDeepPotAJAX : public ::testing::Test { + protected: + // import numpy as np + // from deepmd.infer import DeepPot + // coord = np.array([ + // 12.83, 2.56, 2.18, 12.09, 2.87, 2.74, + // 00.25, 3.32, 1.68, 3.36, 3.00, 1.81, + // 3.51, 2.51, 2.60, 4.27, 3.22, 1.56 + // ]).reshape(1, -1) + // atype = np.array([0, 1, 1, 0, 1, 1]) + // box = np.array([13., 0., 0., 0., 13., 0., 0., 0., 13.]).reshape(1, -1) + // dp = DeepPot("deeppot_sea.savedmodel") + // e, f, v, ae, av = dp.eval(coord, box, atype, atomic=True) + // np.set_printoptions(precision=16) + // print(f"{e.ravel()=} {v.ravel()=} {f.ravel()=} {ae.ravel()=} + // {av.ravel()=}") + std::vector coord = {12.83, 2.56, 2.18, 12.09, 2.87, 2.74, + 00.25, 3.32, 1.68, 3.36, 3.00, 1.81, + 3.51, 2.51, 2.60, 4.27, 3.22, 1.56}; + std::vector atype = {0, 1, 1, 0, 1, 1}; + std::vector box = {13., 0., 0., 0., 13., 0., 0., 0., 13.}; + // the data in this file is just copied from PT + std::vector expected_e = { + + -93.016873944029, -185.923296645958, -185.927096544970, + -93.019371018039, -185.926179995548, -185.924351901852}; + std::vector expected_f = { + + 0.006277522211, -0.001117962774, 0.000618580445, 0.009928999655, + 0.003026035654, -0.006941982227, 0.000667853212, -0.002449963843, + 0.006506463508, -0.007284129115, 0.000530662205, -0.000028806821, + 0.000068097781, 0.006121331983, -0.009019754602, -0.009658343745, + -0.006110103225, 0.008865499697}; + std::vector expected_v = { + -0.000155238009, 0.000116605516, -0.007869862476, 0.000465578340, + 0.008182547185, -0.002398713212, -0.008112887338, -0.002423738425, + 0.007210716605, -0.019203504012, 0.001724938709, 0.009909211091, + 0.001153857542, -0.001600015103, -0.000560024090, 0.010727836276, + -0.001034836404, -0.007973454377, -0.021517399106, -0.004064359664, + 0.004866398692, -0.003360038617, -0.007241406162, 0.005920941051, + 0.004899151657, 0.006290788591, -0.006478820311, 0.001921504710, + 0.001313470921, -0.000304091236, 0.001684345981, 0.004124109256, + -0.006396084465, -0.000701095618, -0.006356507032, 0.009818550859, + -0.015230664587, -0.000110244376, 0.000690319396, 0.000045953023, + -0.005726548770, 0.008769818495, -0.000572380210, 0.008860603423, + -0.013819348050, -0.021227082558, -0.004977781343, 0.006646239696, + -0.005987066507, -0.002767831232, 0.003746502525, 0.007697590397, + 0.003746130152, -0.005172634748}; + int natoms; + double expected_tot_e; + std::vector expected_tot_v; + + deepmd::DeepPot dp; + + void SetUp() override { + std::string file_name = "../../tests/infer/deeppot_sea.savedmodel"; + + dp.init(file_name); + + natoms = expected_e.size(); + EXPECT_EQ(natoms * 3, expected_f.size()); + EXPECT_EQ(natoms * 9, expected_v.size()); + expected_tot_e = 0.; + expected_tot_v.resize(9); + std::fill(expected_tot_v.begin(), expected_tot_v.end(), 0.); + for (int ii = 0; ii < natoms; ++ii) { + expected_tot_e += expected_e[ii]; + } + for (int ii = 0; ii < natoms; ++ii) { + for (int dd = 0; dd < 9; ++dd) { + expected_tot_v[dd] += expected_v[ii * 9 + dd]; + } + } + } + + void TearDown() override {} +}; + +TYPED_TEST_SUITE(TestInferDeepPotAJAX, ValueTypes); + +TYPED_TEST(TestInferDeepPotAJAX, cpu_lmp_nlist) { + using VALUETYPE = TypeParam; + std::vector& coord = this->coord; + std::vector& atype = this->atype; + std::vector& box = this->box; + std::vector& expected_e = this->expected_e; + std::vector& expected_f = this->expected_f; + std::vector& expected_v = this->expected_v; + int& natoms = this->natoms; + double& expected_tot_e = this->expected_tot_e; + std::vector& expected_tot_v = this->expected_tot_v; + deepmd::DeepPot& dp = this->dp; + float rc = dp.cutoff(); + int nloc = coord.size() / 3; + std::vector coord_cpy; + std::vector atype_cpy, mapping; + std::vector > nlist_data; + _build_nlist(nlist_data, coord_cpy, atype_cpy, mapping, coord, + atype, box, rc); + int nall = coord_cpy.size() / 3; + std::vector ilist(nloc), numneigh(nloc); + std::vector firstneigh(nloc); + deepmd::InputNlist inlist(nloc, &ilist[0], &numneigh[0], &firstneigh[0]); + convert_nlist(inlist, nlist_data); + + double ener; + std::vector force_, virial; + dp.compute(ener, force_, virial, coord_cpy, atype_cpy, box, nall - nloc, + inlist, 0); + std::vector force; + _fold_back(force, force_, mapping, nloc, nall, 3); + + EXPECT_EQ(force.size(), natoms * 3); + EXPECT_EQ(virial.size(), 9); + + EXPECT_LT(fabs(ener - expected_tot_e), EPSILON); + for (int ii = 0; ii < natoms * 3; ++ii) { + EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON); + } + for (int ii = 0; ii < 3 * 3; ++ii) { + EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON); + } + + ener = 0.; + std::fill(force_.begin(), force_.end(), 0.0); + std::fill(virial.begin(), virial.end(), 0.0); + dp.compute(ener, force_, virial, coord_cpy, atype_cpy, box, nall - nloc, + inlist, 1); + _fold_back(force, force_, mapping, nloc, nall, 3); + + EXPECT_EQ(force.size(), natoms * 3); + EXPECT_EQ(virial.size(), 9); + + EXPECT_LT(fabs(ener - expected_tot_e), EPSILON); + for (int ii = 0; ii < natoms * 3; ++ii) { + EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON); + } + for (int ii = 0; ii < 3 * 3; ++ii) { + EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON); + } +} + +TYPED_TEST(TestInferDeepPotAJAX, cpu_lmp_nlist_atomic) { + using VALUETYPE = TypeParam; + std::vector& coord = this->coord; + std::vector& atype = this->atype; + std::vector& box = this->box; + std::vector& expected_e = this->expected_e; + std::vector& expected_f = this->expected_f; + std::vector& expected_v = this->expected_v; + int& natoms = this->natoms; + double& expected_tot_e = this->expected_tot_e; + std::vector& expected_tot_v = this->expected_tot_v; + deepmd::DeepPot& dp = this->dp; + float rc = dp.cutoff(); + int nloc = coord.size() / 3; + std::vector coord_cpy; + std::vector atype_cpy, mapping; + std::vector > nlist_data; + _build_nlist(nlist_data, coord_cpy, atype_cpy, mapping, coord, + atype, box, rc); + int nall = coord_cpy.size() / 3; + std::vector ilist(nloc), numneigh(nloc); + std::vector firstneigh(nloc); + deepmd::InputNlist inlist(nloc, &ilist[0], &numneigh[0], &firstneigh[0]); + convert_nlist(inlist, nlist_data); + double ener; + std::vector force_, atom_ener_, atom_vir_, virial; + std::vector force, atom_ener, atom_vir; + dp.compute(ener, force_, virial, atom_ener_, atom_vir_, coord_cpy, atype_cpy, + box, nall - nloc, inlist, 0); + _fold_back(force, force_, mapping, nloc, nall, 3); + _fold_back(atom_ener, atom_ener_, mapping, nloc, nall, 1); + _fold_back(atom_vir, atom_vir_, mapping, nloc, nall, 9); + + EXPECT_EQ(force.size(), natoms * 3); + EXPECT_EQ(virial.size(), 9); + EXPECT_EQ(atom_ener.size(), natoms); + EXPECT_EQ(atom_vir.size(), natoms * 9); + + EXPECT_LT(fabs(ener - expected_tot_e), EPSILON); + for (int ii = 0; ii < natoms * 3; ++ii) { + EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON); + } + for (int ii = 0; ii < 3 * 3; ++ii) { + EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON); + } + for (int ii = 0; ii < natoms; ++ii) { + EXPECT_LT(fabs(atom_ener[ii] - expected_e[ii]), EPSILON); + } + for (int ii = 0; ii < natoms * 9; ++ii) { + EXPECT_LT(fabs(atom_vir[ii] - expected_v[ii]), EPSILON); + } + + ener = 0.; + std::fill(force_.begin(), force_.end(), 0.0); + std::fill(virial.begin(), virial.end(), 0.0); + std::fill(atom_ener_.begin(), atom_ener_.end(), 0.0); + std::fill(atom_vir_.begin(), atom_vir_.end(), 0.0); + dp.compute(ener, force_, virial, atom_ener_, atom_vir_, coord_cpy, atype_cpy, + box, nall - nloc, inlist, 1); + _fold_back(force, force_, mapping, nloc, nall, 3); + _fold_back(atom_ener, atom_ener_, mapping, nloc, nall, 1); + _fold_back(atom_vir, atom_vir_, mapping, nloc, nall, 9); + + EXPECT_EQ(force.size(), natoms * 3); + EXPECT_EQ(virial.size(), 9); + EXPECT_EQ(atom_ener.size(), natoms); + EXPECT_EQ(atom_vir.size(), natoms * 9); + + EXPECT_LT(fabs(ener - expected_tot_e), EPSILON); + for (int ii = 0; ii < natoms * 3; ++ii) { + EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON); + } + for (int ii = 0; ii < 3 * 3; ++ii) { + EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON); + } + for (int ii = 0; ii < natoms; ++ii) { + EXPECT_LT(fabs(atom_ener[ii] - expected_e[ii]), EPSILON); + } + for (int ii = 0; ii < natoms * 9; ++ii) { + EXPECT_LT(fabs(atom_vir[ii] - expected_v[ii]), EPSILON); + } +} + +TYPED_TEST(TestInferDeepPotAJAX, cpu_lmp_nlist_2rc) { + using VALUETYPE = TypeParam; + std::vector& coord = this->coord; + std::vector& atype = this->atype; + std::vector& box = this->box; + std::vector& expected_e = this->expected_e; + std::vector& expected_f = this->expected_f; + std::vector& expected_v = this->expected_v; + int& natoms = this->natoms; + double& expected_tot_e = this->expected_tot_e; + std::vector& expected_tot_v = this->expected_tot_v; + deepmd::DeepPot& dp = this->dp; + float rc = dp.cutoff(); + int nloc = coord.size() / 3; + std::vector coord_cpy; + std::vector atype_cpy, mapping; + std::vector > nlist_data; + _build_nlist(nlist_data, coord_cpy, atype_cpy, mapping, coord, + atype, box, rc * 2); + int nall = coord_cpy.size() / 3; + std::vector ilist(nloc), numneigh(nloc); + std::vector firstneigh(nloc); + deepmd::InputNlist inlist(nloc, &ilist[0], &numneigh[0], &firstneigh[0]); + convert_nlist(inlist, nlist_data); + + double ener; + std::vector force_(nall * 3, 0.0), virial(9, 0.0); + dp.compute(ener, force_, virial, coord_cpy, atype_cpy, box, nall - nloc, + inlist, 0); + std::vector force; + _fold_back(force, force_, mapping, nloc, nall, 3); + + EXPECT_EQ(force.size(), natoms * 3); + EXPECT_EQ(virial.size(), 9); + + EXPECT_LT(fabs(ener - expected_tot_e), EPSILON); + for (int ii = 0; ii < natoms * 3; ++ii) { + EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON); + } + for (int ii = 0; ii < 3 * 3; ++ii) { + EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON); + } + + ener = 0.; + std::fill(force_.begin(), force_.end(), 0.0); + std::fill(virial.begin(), virial.end(), 0.0); + dp.compute(ener, force_, virial, coord_cpy, atype_cpy, box, nall - nloc, + inlist, 1); + _fold_back(force, force_, mapping, nloc, nall, 3); + + EXPECT_EQ(force.size(), natoms * 3); + EXPECT_EQ(virial.size(), 9); + + EXPECT_LT(fabs(ener - expected_tot_e), EPSILON); + for (int ii = 0; ii < natoms * 3; ++ii) { + EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON); + } + for (int ii = 0; ii < 3 * 3; ++ii) { + EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON); + } +} + +TYPED_TEST(TestInferDeepPotAJAX, cpu_lmp_nlist_type_sel) { + using VALUETYPE = TypeParam; + std::vector& coord = this->coord; + std::vector& atype = this->atype; + std::vector& box = this->box; + std::vector& expected_e = this->expected_e; + std::vector& expected_f = this->expected_f; + std::vector& expected_v = this->expected_v; + int& natoms = this->natoms; + double& expected_tot_e = this->expected_tot_e; + std::vector& expected_tot_v = this->expected_tot_v; + deepmd::DeepPot& dp = this->dp; + float rc = dp.cutoff(); + + // add vir atoms + int nvir = 2; + std::vector coord_vir(nvir * 3); + std::vector atype_vir(nvir, 2); + for (int ii = 0; ii < nvir; ++ii) { + coord_vir[ii] = coord[ii]; + } + coord.insert(coord.begin(), coord_vir.begin(), coord_vir.end()); + atype.insert(atype.begin(), atype_vir.begin(), atype_vir.end()); + natoms += nvir; + std::vector expected_f_vir(nvir * 3, 0.0); + expected_f.insert(expected_f.begin(), expected_f_vir.begin(), + expected_f_vir.end()); + + // build nlist + int nloc = coord.size() / 3; + std::vector coord_cpy; + std::vector atype_cpy, mapping; + std::vector > nlist_data; + _build_nlist(nlist_data, coord_cpy, atype_cpy, mapping, coord, + atype, box, rc); + int nall = coord_cpy.size() / 3; + std::vector ilist(nloc), numneigh(nloc); + std::vector firstneigh(nloc); + deepmd::InputNlist inlist(nloc, &ilist[0], &numneigh[0], &firstneigh[0]); + convert_nlist(inlist, nlist_data); + + // dp compute + double ener; + std::vector force_(nall * 3, 0.0), virial(9, 0.0); + dp.compute(ener, force_, virial, coord_cpy, atype_cpy, box, nall - nloc, + inlist, 0); + // fold back + std::vector force; + _fold_back(force, force_, mapping, nloc, nall, 3); + + EXPECT_EQ(force.size(), natoms * 3); + EXPECT_EQ(virial.size(), 9); + + EXPECT_LT(fabs(ener - expected_tot_e), EPSILON); + for (int ii = 0; ii < natoms * 3; ++ii) { + EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON); + } + for (int ii = 0; ii < 3 * 3; ++ii) { + EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON); + } +} + +TYPED_TEST(TestInferDeepPotAJAX, cpu_lmp_nlist_type_sel_atomic) { + using VALUETYPE = TypeParam; + std::vector& coord = this->coord; + std::vector& atype = this->atype; + std::vector& box = this->box; + std::vector& expected_e = this->expected_e; + std::vector& expected_f = this->expected_f; + std::vector& expected_v = this->expected_v; + int& natoms = this->natoms; + double& expected_tot_e = this->expected_tot_e; + std::vector& expected_tot_v = this->expected_tot_v; + deepmd::DeepPot& dp = this->dp; + float rc = dp.cutoff(); + + // add vir atoms + int nvir = 2; + std::vector coord_vir(nvir * 3); + std::vector atype_vir(nvir, 2); + for (int ii = 0; ii < nvir; ++ii) { + coord_vir[ii] = coord[ii]; + } + coord.insert(coord.begin(), coord_vir.begin(), coord_vir.end()); + atype.insert(atype.begin(), atype_vir.begin(), atype_vir.end()); + natoms += nvir; + std::vector expected_f_vir(nvir * 3, 0.0); + expected_f.insert(expected_f.begin(), expected_f_vir.begin(), + expected_f_vir.end()); + + // build nlist + int nloc = coord.size() / 3; + std::vector coord_cpy; + std::vector atype_cpy, mapping; + std::vector > nlist_data; + _build_nlist(nlist_data, coord_cpy, atype_cpy, mapping, coord, + atype, box, rc); + int nall = coord_cpy.size() / 3; + std::vector ilist(nloc), numneigh(nloc); + std::vector firstneigh(nloc); + deepmd::InputNlist inlist(nloc, &ilist[0], &numneigh[0], &firstneigh[0]); + convert_nlist(inlist, nlist_data); + + // dp compute + double ener; + std::vector force_(nall * 3, 0.0), virial(9, 0.0), atomic_energy, + atomic_virial; + dp.compute(ener, force_, virial, atomic_energy, atomic_virial, coord_cpy, + atype_cpy, box, nall - nloc, inlist, 0); + // fold back + std::vector force; + _fold_back(force, force_, mapping, nloc, nall, 3); + + EXPECT_EQ(force.size(), natoms * 3); + EXPECT_EQ(virial.size(), 9); + + EXPECT_LT(fabs(ener - expected_tot_e), EPSILON); + for (int ii = 0; ii < natoms * 3; ++ii) { + EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON); + } + for (int ii = 0; ii < 3 * 3; ++ii) { + EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON); + } +} + +TYPED_TEST(TestInferDeepPotAJAX, print_summary) { + deepmd::DeepPot& dp = this->dp; + dp.print_summary(""); +} + +TYPED_TEST(TestInferDeepPotAJAX, get_type_map) { + deepmd::DeepPot& dp = this->dp; + std::string type_map; + dp.get_type_map(type_map); + EXPECT_EQ(type_map, "O H"); +} diff --git a/source/cmake/googletest.cmake.in b/source/cmake/googletest.cmake.in index 5d167cf774..85c3745c00 100644 --- a/source/cmake/googletest.cmake.in +++ b/source/cmake/googletest.cmake.in @@ -11,7 +11,7 @@ endif() include(ExternalProject) ExternalProject_Add(googletest GIT_REPOSITORY ${GTEST_REPO_ADDRESS} - GIT_TAG release-1.12.1 + GIT_TAG v1.14.0 GIT_SHALLOW TRUE SOURCE_DIR "@CMAKE_CURRENT_BINARY_DIR@/googletest-src" BINARY_DIR "@CMAKE_CURRENT_BINARY_DIR@/googletest-build" diff --git a/source/lib/include/neighbor_list.h b/source/lib/include/neighbor_list.h index bb4b8cf13c..5b39ea7454 100644 --- a/source/lib/include/neighbor_list.h +++ b/source/lib/include/neighbor_list.h @@ -44,6 +44,8 @@ struct InputNlist { void* world; /// mask to the neighbor index int mask = 0xFFFFFFFF; + /// mapping from all atoms to real atoms, in the size of nall + int* mapping = nullptr; InputNlist() : inum(0), ilist(NULL), @@ -99,6 +101,10 @@ struct InputNlist { * @brief Set mask for this neighbor list. */ void set_mask(int mask_) { mask = mask_; }; + /** + * @brief Set mapping for this neighbor list. + */ + void set_mapping(int* mapping_) { mapping = mapping_; }; }; /** diff --git a/source/lmp/fix_dplr.cpp b/source/lmp/fix_dplr.cpp index 34fd2515ed..82bbe558bf 100644 --- a/source/lmp/fix_dplr.cpp +++ b/source/lmp/fix_dplr.cpp @@ -439,6 +439,14 @@ void FixDPLR::pre_force(int vflag) { int nghost = atom->nghost; int nall = nlocal + nghost; + // mapping (for DPA-2 JAX) + std::vector mapping_vec(nall, -1); + if (comm->nprocs == 1 && atom->map_style != Atom::MAP_NONE) { + for (size_t ii = 0; ii < nall; ++ii) { + mapping_vec[ii] = atom->map(atom->tag[ii]); + } + } + // if (eflag_atom) { // error->all(FLERR,"atomic energy calculation is not supported by this // fix\n"); @@ -471,6 +479,9 @@ void FixDPLR::pre_force(int vflag) { deepmd_compat::InputNlist lmp_list(list->inum, list->ilist, list->numneigh, list->firstneigh); lmp_list.set_mask(NEIGHMASK); + if (comm->nprocs == 1 && atom->map_style != Atom::MAP_NONE) { + lmp_list.set_mapping(mapping_vec.data()); + } // declear output vector tensor; // compute diff --git a/source/lmp/pair_deepmd.cpp b/source/lmp/pair_deepmd.cpp index d741814aa5..0abaa586ff 100644 --- a/source/lmp/pair_deepmd.cpp +++ b/source/lmp/pair_deepmd.cpp @@ -521,6 +521,14 @@ void PairDeepMD::compute(int eflag, int vflag) { } } + // mapping (for DPA-2 JAX) + std::vector mapping_vec(nall, -1); + if (comm->nprocs == 1 && atom->map_style != Atom::MAP_NONE) { + for (size_t ii = 0; ii < nall; ++ii) { + mapping_vec[ii] = atom->map(atom->tag[ii]); + } + } + if (do_compute_aparam) { make_aparam_from_compute(daparam); } else if (aparam.size() > 0) { @@ -564,6 +572,9 @@ void PairDeepMD::compute(int eflag, int vflag) { commdata_->firstrecv, commdata_->sendlist, commdata_->sendproc, commdata_->recvproc, &world); lmp_list.set_mask(NEIGHMASK); + if (comm->nprocs == 1 && atom->map_style != Atom::MAP_NONE) { + lmp_list.set_mapping(mapping_vec.data()); + } deepmd_compat::InputNlist extend_lmp_list; if (atom->sp_flag) { extend(extend_inum, extend_ilist, extend_numneigh, extend_neigh, @@ -574,6 +585,9 @@ void PairDeepMD::compute(int eflag, int vflag) { deepmd_compat::InputNlist(extend_inum, &extend_ilist[0], &extend_numneigh[0], &extend_firstneigh[0]); extend_lmp_list.set_mask(NEIGHMASK); + if (comm->nprocs == 1 && atom->map_style != Atom::MAP_NONE) { + extend_lmp_list.set_mapping(mapping_vec.data()); + } } if (single_model || multi_models_no_mod_devi) { // cvflag_atom is the right flag for the cvatom matrix diff --git a/source/lmp/tests/test_lammps_dpa_jax.py b/source/lmp/tests/test_lammps_dpa_jax.py new file mode 100644 index 0000000000..10428b2374 --- /dev/null +++ b/source/lmp/tests/test_lammps_dpa_jax.py @@ -0,0 +1,726 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import importlib +import os +import shutil +import subprocess as sp +import sys +import tempfile +from pathlib import ( + Path, +) + +import constants +import numpy as np +import pytest +from lammps import ( + PyLammps, +) +from write_lmp_data import ( + write_lmp_data, +) + +pbtxt_file2 = ( + Path(__file__).parent.parent.parent / "tests" / "infer" / "deeppot-1.pbtxt" +) +pb_file = ( + Path(__file__).parent.parent.parent / "tests" / "infer" / "deeppot_dpa.savedmodel" +) +pb_file2 = Path(__file__).parent / "graph2.pb" +system_file = Path(__file__).parent.parent.parent / "tests" +data_file = Path(__file__).parent / "data.lmp" +data_file_si = Path(__file__).parent / "data.si" +data_type_map_file = Path(__file__).parent / "data_type_map.lmp" +md_file = Path(__file__).parent / "md.out" + +# this is as the same as python and c++ tests, test_deeppot_a.py +expected_ae = np.array( + [ + -94.24098099691867, + -187.8049502787117, + -187.80486052083617, + -94.24059525229518, + -187.80366985846246, + -187.8042377490619, + ] +) +expected_e = np.sum(expected_ae) +expected_f = np.array( + [ + -0.0020150115442053, + -0.0133389255924977, + -0.0014347177433057, + -0.0140757358179293, + 0.0031373814221557, + 0.0098594354314677, + 0.004755683505073, + 0.0099471082374397, + -0.0080868184532793, + -0.0086166721574536, + 0.0037803939137322, + -0.0075733131286482, + 0.0037437603038209, + -0.008452527996008, + 0.0134837461840424, + 0.0162079757106944, + 0.0049265700151781, + -0.0062483322902769, + ] +).reshape(6, 3) + +expected_f2 = np.array( + [ + [-0.6454949, 1.72457783, 0.18897958], + [1.68936514, -0.36995299, -1.36044464], + [-1.09902692, -1.35487928, 1.17416702], + [1.68426111, -0.50835585, 0.98340415], + [0.05771758, 1.12515818, -1.77561531], + [-1.686822, -0.61654789, 0.78950921], + ] +) + +expected_v = -np.array( + [ + 0.0133534319524089, + 0.0013445914938337, + -0.0029370551651952, + 0.0002611806151294, + 0.004662662211533, + -0.0002717443796319, + -0.0027779798869954, + -0.0003277976466339, + 0.0018284972283065, + 0.0085710118978246, + 0.0003865036653608, + -0.0057964032875089, + -0.0014358330222619, + 0.0002912625128908, + 0.001212630641674, + -0.0050582608957046, + -0.0001087907763249, + 0.0040068757134429, + 0.0116736349373084, + 0.0007055477968445, + -0.0019544933708784, + 0.0032997459258512, + 0.0037887116116712, + -0.0043140890650835, + -0.0034418738401156, + -0.0029420616852742, + 0.0038219676716965, + 0.0147134944025738, + 0.0005214313829998, + -0.0006524136175906, + 0.0003656980996363, + 0.0010046161607714, + -0.0017279359476254, + 0.000111127036911, + -0.0017063190420654, + 0.0030174567965904, + 0.0104435705455108, + -0.0008704394438241, + 0.0012354202650812, + 0.0009397615830053, + 0.0029105236407293, + -0.0044188897903449, + -0.0011461513500477, + -0.0045759080125852, + 0.0070310883421107, + 0.0089818851995049, + 0.0038819466696704, + -0.005443705549253, + 0.0025390283635246, + 0.0012121502955869, + -0.0016998728971157, + -0.0032355117893925, + -0.0015590242752438, + 0.0021980725909838, + ] +).reshape(6, 9) +expected_v2 = -np.array( + [ + [ + -0.70008436, + -0.06399891, + 0.63678391, + -0.07642171, + -0.70580035, + 0.20506145, + 0.64098364, + 0.20305781, + -0.57906794, + ], + [ + -0.6372635, + 0.14315552, + 0.51952246, + 0.04604049, + -0.06003681, + -0.02688702, + 0.54489318, + -0.10951559, + -0.43730539, + ], + [ + -0.25090748, + -0.37466262, + 0.34085833, + -0.26690852, + -0.37676917, + 0.29080825, + 0.31600481, + 0.37558276, + -0.33251064, + ], + [ + -0.80195614, + -0.10273138, + 0.06935364, + -0.10429256, + -0.29693811, + 0.45643496, + 0.07247872, + 0.45604679, + -0.71048816, + ], + [ + -0.03840668, + -0.07680205, + 0.10940472, + -0.02374189, + -0.27610266, + 0.4336071, + 0.02465248, + 0.4290638, + -0.67496763, + ], + [ + -0.61475065, + -0.21163135, + 0.26652929, + -0.26134659, + -0.11560267, + 0.15415902, + 0.34343952, + 0.1589482, + -0.21370642, + ], + ] +).reshape(6, 9) + +box = np.array([0, 13, 0, 13, 0, 13, 0, 0, 0]) +coord = np.array( + [ + [12.83, 2.56, 2.18], + [12.09, 2.87, 2.74], + [0.25, 3.32, 1.68], + [3.36, 3.00, 1.81], + [3.51, 2.51, 2.60], + [4.27, 3.22, 1.56], + ] +) +type_OH = np.array([1, 2, 2, 1, 2, 2]) +type_HO = np.array([2, 1, 1, 2, 1, 1]) + + +sp.check_output( + f"{sys.executable} -m deepmd convert-from pbtxt -i {pbtxt_file2.resolve()} -o {pb_file2.resolve()}".split() +) + + +def setup_module(): + write_lmp_data(box, coord, type_OH, data_file) + write_lmp_data(box, coord, type_HO, data_type_map_file) + write_lmp_data( + box * constants.dist_metal2si, + coord * constants.dist_metal2si, + type_OH, + data_file_si, + ) + + +def teardown_module(): + os.remove(data_file) + os.remove(data_type_map_file) + + +def _lammps(data_file, units="metal") -> PyLammps: + lammps = PyLammps() + lammps.units(units) + lammps.boundary("p p p") + lammps.atom_style("atomic") + # Requires for DPA-2 + lammps.atom_modify("map yes") + if units == "metal" or units == "real": + lammps.neighbor("2.0 bin") + elif units == "si": + lammps.neighbor("2.0e-10 bin") + else: + raise ValueError("units should be metal, real, or si") + lammps.neigh_modify("every 10 delay 0 check no") + lammps.read_data(data_file.resolve()) + if units == "metal" or units == "real": + lammps.mass("1 16") + lammps.mass("2 2") + elif units == "si": + lammps.mass("1 %.10e" % (16 * constants.mass_metal2si)) + lammps.mass("2 %.10e" % (2 * constants.mass_metal2si)) + else: + raise ValueError("units should be metal, real, or si") + if units == "metal": + lammps.timestep(0.0005) + elif units == "real": + lammps.timestep(0.5) + elif units == "si": + lammps.timestep(5e-16) + else: + raise ValueError("units should be metal, real, or si") + lammps.fix("1 all nve") + return lammps + + +@pytest.fixture +def lammps(): + lmp = _lammps(data_file=data_file) + yield lmp + lmp.close() + + +@pytest.fixture +def lammps_type_map(): + lmp = _lammps(data_file=data_type_map_file) + yield lmp + lmp.close() + + +@pytest.fixture +def lammps_real(): + lmp = _lammps(data_file=data_file, units="real") + yield lmp + lmp.close() + + +@pytest.fixture +def lammps_si(): + lmp = _lammps(data_file=data_file_si, units="si") + yield lmp + lmp.close() + + +def test_pair_deepmd(lammps): + lammps.pair_style(f"deepmd {pb_file.resolve()}") + lammps.pair_coeff("* *") + lammps.run(0) + assert lammps.eval("pe") == pytest.approx(expected_e) + for ii in range(6): + assert lammps.atoms[ii].force == pytest.approx( + expected_f[lammps.atoms[ii].id - 1] + ) + lammps.run(1) + + +def test_pair_deepmd_virial(lammps): + lammps.pair_style(f"deepmd {pb_file.resolve()}") + lammps.pair_coeff("* *") + lammps.compute("virial all centroid/stress/atom NULL pair") + for ii in range(9): + jj = [0, 4, 8, 3, 6, 7, 1, 2, 5][ii] + lammps.variable(f"virial{jj} atom c_virial[{ii+1}]") + lammps.dump( + "1 all custom 1 dump id " + " ".join([f"v_virial{ii}" for ii in range(9)]) + ) + lammps.run(0) + assert lammps.eval("pe") == pytest.approx(expected_e) + for ii in range(6): + assert lammps.atoms[ii].force == pytest.approx( + expected_f[lammps.atoms[ii].id - 1] + ) + idx_map = lammps.lmp.numpy.extract_atom("id") - 1 + for ii in range(9): + assert np.array( + lammps.variables[f"virial{ii}"].value + ) / constants.nktv2p == pytest.approx(expected_v[idx_map, ii]) + + +def test_pair_deepmd_model_devi(lammps): + lammps.pair_style( + f"deepmd {pb_file.resolve()} {pb_file2.resolve()} out_file {md_file.resolve()} out_freq 1 atomic" + ) + lammps.pair_coeff("* *") + lammps.run(0) + assert lammps.eval("pe") == pytest.approx(expected_e) + for ii in range(6): + assert lammps.atoms[ii].force == pytest.approx( + expected_f[lammps.atoms[ii].id - 1] + ) + # load model devi + md = np.loadtxt(md_file.resolve()) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + assert md[7:] == pytest.approx(expected_md_f) + assert md[4] == pytest.approx(np.max(expected_md_f)) + assert md[5] == pytest.approx(np.min(expected_md_f)) + assert md[6] == pytest.approx(np.mean(expected_md_f)) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + assert md[1] == pytest.approx(np.max(expected_md_v)) + assert md[2] == pytest.approx(np.min(expected_md_v)) + assert md[3] == pytest.approx(np.sqrt(np.mean(np.square(expected_md_v)))) + + +def test_pair_deepmd_model_devi_virial(lammps): + lammps.pair_style( + f"deepmd {pb_file.resolve()} {pb_file2.resolve()} out_file {md_file.resolve()} out_freq 1 atomic" + ) + lammps.pair_coeff("* *") + lammps.compute("virial all centroid/stress/atom NULL pair") + for ii in range(9): + jj = [0, 4, 8, 3, 6, 7, 1, 2, 5][ii] + lammps.variable(f"virial{jj} atom c_virial[{ii+1}]") + lammps.dump( + "1 all custom 1 dump id " + " ".join([f"v_virial{ii}" for ii in range(9)]) + ) + lammps.run(0) + assert lammps.eval("pe") == pytest.approx(expected_e) + for ii in range(6): + assert lammps.atoms[ii].force == pytest.approx( + expected_f[lammps.atoms[ii].id - 1] + ) + idx_map = lammps.lmp.numpy.extract_atom("id") - 1 + for ii in range(9): + assert np.array( + lammps.variables[f"virial{ii}"].value + ) / constants.nktv2p == pytest.approx(expected_v[idx_map, ii]) + # load model devi + md = np.loadtxt(md_file.resolve()) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + assert md[7:] == pytest.approx(expected_md_f) + assert md[4] == pytest.approx(np.max(expected_md_f)) + assert md[5] == pytest.approx(np.min(expected_md_f)) + assert md[6] == pytest.approx(np.mean(expected_md_f)) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + assert md[1] == pytest.approx(np.max(expected_md_v)) + assert md[2] == pytest.approx(np.min(expected_md_v)) + assert md[3] == pytest.approx(np.sqrt(np.mean(np.square(expected_md_v)))) + + +def test_pair_deepmd_model_devi_atomic_relative(lammps): + relative = 1.0 + lammps.pair_style( + f"deepmd {pb_file.resolve()} {pb_file2.resolve()} out_file {md_file.resolve()} out_freq 1 atomic relative {relative}" + ) + lammps.pair_coeff("* *") + lammps.run(0) + assert lammps.eval("pe") == pytest.approx(expected_e) + for ii in range(6): + assert lammps.atoms[ii].force == pytest.approx( + expected_f[lammps.atoms[ii].id - 1] + ) + # load model devi + md = np.loadtxt(md_file.resolve()) + norm = np.linalg.norm(np.mean([expected_f, expected_f2], axis=0), axis=1) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + expected_md_f /= norm + relative + assert md[7:] == pytest.approx(expected_md_f) + assert md[4] == pytest.approx(np.max(expected_md_f)) + assert md[5] == pytest.approx(np.min(expected_md_f)) + assert md[6] == pytest.approx(np.mean(expected_md_f)) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + assert md[1] == pytest.approx(np.max(expected_md_v)) + assert md[2] == pytest.approx(np.min(expected_md_v)) + assert md[3] == pytest.approx(np.sqrt(np.mean(np.square(expected_md_v)))) + + +def test_pair_deepmd_model_devi_atomic_relative_v(lammps): + relative = 1.0 + lammps.pair_style( + f"deepmd {pb_file.resolve()} {pb_file2.resolve()} out_file {md_file.resolve()} out_freq 1 atomic relative_v {relative}" + ) + lammps.pair_coeff("* *") + lammps.run(0) + assert lammps.eval("pe") == pytest.approx(expected_e) + for ii in range(6): + assert lammps.atoms[ii].force == pytest.approx( + expected_f[lammps.atoms[ii].id - 1] + ) + md = np.loadtxt(md_file.resolve()) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + assert md[7:] == pytest.approx(expected_md_f) + assert md[4] == pytest.approx(np.max(expected_md_f)) + assert md[5] == pytest.approx(np.min(expected_md_f)) + assert md[6] == pytest.approx(np.mean(expected_md_f)) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + norm = ( + np.abs( + np.mean([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) + ) + / 6 + ) + expected_md_v /= norm + relative + assert md[1] == pytest.approx(np.max(expected_md_v)) + assert md[2] == pytest.approx(np.min(expected_md_v)) + assert md[3] == pytest.approx(np.sqrt(np.mean(np.square(expected_md_v)))) + + +def test_pair_deepmd_type_map(lammps_type_map): + lammps_type_map.pair_style(f"deepmd {pb_file.resolve()}") + lammps_type_map.pair_coeff("* * H O") + lammps_type_map.run(0) + assert lammps_type_map.eval("pe") == pytest.approx(expected_e) + for ii in range(6): + assert lammps_type_map.atoms[ii].force == pytest.approx( + expected_f[lammps_type_map.atoms[ii].id - 1] + ) + lammps_type_map.run(1) + + +def test_pair_deepmd_real(lammps_real): + lammps_real.pair_style(f"deepmd {pb_file.resolve()}") + lammps_real.pair_coeff("* *") + lammps_real.run(0) + assert lammps_real.eval("pe") == pytest.approx( + expected_e * constants.ener_metal2real + ) + for ii in range(6): + assert lammps_real.atoms[ii].force == pytest.approx( + expected_f[lammps_real.atoms[ii].id - 1] * constants.force_metal2real + ) + lammps_real.run(1) + + +def test_pair_deepmd_virial_real(lammps_real): + lammps_real.pair_style(f"deepmd {pb_file.resolve()}") + lammps_real.pair_coeff("* *") + lammps_real.compute("virial all centroid/stress/atom NULL pair") + for ii in range(9): + jj = [0, 4, 8, 3, 6, 7, 1, 2, 5][ii] + lammps_real.variable(f"virial{jj} atom c_virial[{ii+1}]") + lammps_real.dump( + "1 all custom 1 dump id " + " ".join([f"v_virial{ii}" for ii in range(9)]) + ) + lammps_real.run(0) + assert lammps_real.eval("pe") == pytest.approx( + expected_e * constants.ener_metal2real + ) + for ii in range(6): + assert lammps_real.atoms[ii].force == pytest.approx( + expected_f[lammps_real.atoms[ii].id - 1] * constants.force_metal2real + ) + idx_map = lammps_real.lmp.numpy.extract_atom("id") - 1 + for ii in range(9): + assert np.array( + lammps_real.variables[f"virial{ii}"].value + ) / constants.nktv2p_real == pytest.approx( + expected_v[idx_map, ii] * constants.ener_metal2real + ) + + +def test_pair_deepmd_model_devi_real(lammps_real): + lammps_real.pair_style( + f"deepmd {pb_file.resolve()} {pb_file2.resolve()} out_file {md_file.resolve()} out_freq 1 atomic" + ) + lammps_real.pair_coeff("* *") + lammps_real.run(0) + assert lammps_real.eval("pe") == pytest.approx( + expected_e * constants.ener_metal2real + ) + for ii in range(6): + assert lammps_real.atoms[ii].force == pytest.approx( + expected_f[lammps_real.atoms[ii].id - 1] * constants.force_metal2real + ) + # load model devi + md = np.loadtxt(md_file.resolve()) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + assert md[7:] == pytest.approx(expected_md_f * constants.force_metal2real) + assert md[4] == pytest.approx(np.max(expected_md_f) * constants.force_metal2real) + assert md[5] == pytest.approx(np.min(expected_md_f) * constants.force_metal2real) + assert md[6] == pytest.approx(np.mean(expected_md_f) * constants.force_metal2real) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + assert md[1] == pytest.approx(np.max(expected_md_v) * constants.ener_metal2real) + assert md[2] == pytest.approx(np.min(expected_md_v) * constants.ener_metal2real) + assert md[3] == pytest.approx( + np.sqrt(np.mean(np.square(expected_md_v))) * constants.ener_metal2real + ) + + +def test_pair_deepmd_model_devi_virial_real(lammps_real): + lammps_real.pair_style( + f"deepmd {pb_file.resolve()} {pb_file2.resolve()} out_file {md_file.resolve()} out_freq 1 atomic" + ) + lammps_real.pair_coeff("* *") + lammps_real.compute("virial all centroid/stress/atom NULL pair") + for ii in range(9): + jj = [0, 4, 8, 3, 6, 7, 1, 2, 5][ii] + lammps_real.variable(f"virial{jj} atom c_virial[{ii+1}]") + lammps_real.dump( + "1 all custom 1 dump id " + " ".join([f"v_virial{ii}" for ii in range(9)]) + ) + lammps_real.run(0) + assert lammps_real.eval("pe") == pytest.approx( + expected_e * constants.ener_metal2real + ) + for ii in range(6): + assert lammps_real.atoms[ii].force == pytest.approx( + expected_f[lammps_real.atoms[ii].id - 1] * constants.force_metal2real + ) + idx_map = lammps_real.lmp.numpy.extract_atom("id") - 1 + for ii in range(9): + assert np.array( + lammps_real.variables[f"virial{ii}"].value + ) / constants.nktv2p_real == pytest.approx( + expected_v[idx_map, ii] * constants.ener_metal2real + ) + # load model devi + md = np.loadtxt(md_file.resolve()) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + assert md[7:] == pytest.approx(expected_md_f * constants.force_metal2real) + assert md[4] == pytest.approx(np.max(expected_md_f) * constants.force_metal2real) + assert md[5] == pytest.approx(np.min(expected_md_f) * constants.force_metal2real) + assert md[6] == pytest.approx(np.mean(expected_md_f) * constants.force_metal2real) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + assert md[1] == pytest.approx(np.max(expected_md_v) * constants.ener_metal2real) + assert md[2] == pytest.approx(np.min(expected_md_v) * constants.ener_metal2real) + assert md[3] == pytest.approx( + np.sqrt(np.mean(np.square(expected_md_v))) * constants.ener_metal2real + ) + + +def test_pair_deepmd_model_devi_atomic_relative_real(lammps_real): + relative = 1.0 + lammps_real.pair_style( + f"deepmd {pb_file.resolve()} {pb_file2.resolve()} out_file {md_file.resolve()} out_freq 1 atomic relative {relative * constants.force_metal2real}" + ) + lammps_real.pair_coeff("* *") + lammps_real.run(0) + assert lammps_real.eval("pe") == pytest.approx( + expected_e * constants.ener_metal2real + ) + for ii in range(6): + assert lammps_real.atoms[ii].force == pytest.approx( + expected_f[lammps_real.atoms[ii].id - 1] * constants.force_metal2real + ) + # load model devi + md = np.loadtxt(md_file.resolve()) + norm = np.linalg.norm(np.mean([expected_f, expected_f2], axis=0), axis=1) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + expected_md_f /= norm + relative + assert md[7:] == pytest.approx(expected_md_f * constants.force_metal2real) + assert md[4] == pytest.approx(np.max(expected_md_f) * constants.force_metal2real) + assert md[5] == pytest.approx(np.min(expected_md_f) * constants.force_metal2real) + assert md[6] == pytest.approx(np.mean(expected_md_f) * constants.force_metal2real) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + assert md[1] == pytest.approx(np.max(expected_md_v) * constants.ener_metal2real) + assert md[2] == pytest.approx(np.min(expected_md_v) * constants.ener_metal2real) + assert md[3] == pytest.approx( + np.sqrt(np.mean(np.square(expected_md_v))) * constants.ener_metal2real + ) + + +def test_pair_deepmd_model_devi_atomic_relative_v_real(lammps_real): + relative = 1.0 + lammps_real.pair_style( + f"deepmd {pb_file.resolve()} {pb_file2.resolve()} out_file {md_file.resolve()} out_freq 1 atomic relative_v {relative * constants.ener_metal2real}" + ) + lammps_real.pair_coeff("* *") + lammps_real.run(0) + assert lammps_real.eval("pe") == pytest.approx( + expected_e * constants.ener_metal2real + ) + for ii in range(6): + assert lammps_real.atoms[ii].force == pytest.approx( + expected_f[lammps_real.atoms[ii].id - 1] * constants.force_metal2real + ) + md = np.loadtxt(md_file.resolve()) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + assert md[7:] == pytest.approx(expected_md_f * constants.force_metal2real) + assert md[4] == pytest.approx(np.max(expected_md_f) * constants.force_metal2real) + assert md[5] == pytest.approx(np.min(expected_md_f) * constants.force_metal2real) + assert md[6] == pytest.approx(np.mean(expected_md_f) * constants.force_metal2real) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + norm = ( + np.abs( + np.mean([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) + ) + / 6 + ) + expected_md_v /= norm + relative + assert md[1] == pytest.approx(np.max(expected_md_v) * constants.ener_metal2real) + assert md[2] == pytest.approx(np.min(expected_md_v) * constants.ener_metal2real) + assert md[3] == pytest.approx( + np.sqrt(np.mean(np.square(expected_md_v))) * constants.ener_metal2real + ) + + +def test_pair_deepmd_si(lammps_si): + lammps_si.pair_style(f"deepmd {pb_file.resolve()}") + lammps_si.pair_coeff("* *") + lammps_si.run(0) + assert lammps_si.eval("pe") == pytest.approx(expected_e * constants.ener_metal2si) + for ii in range(6): + assert lammps_si.atoms[ii].force == pytest.approx( + expected_f[lammps_si.atoms[ii].id - 1] * constants.force_metal2si + ) + lammps_si.run(1) + + +@pytest.mark.skipif( + shutil.which("mpirun") is None, reason="MPI is not installed on this system" +) +@pytest.mark.skipif( + importlib.util.find_spec("mpi4py") is None, reason="mpi4py is not installed" +) +@pytest.mark.parametrize( + ("balance_args",), + [(["--balance"],), ([],)], +) +@pytest.mark.skip("MPI is not supported") +def test_pair_deepmd_mpi(balance_args: list): + with tempfile.NamedTemporaryFile() as f: + sp.check_call( + [ + "mpirun", + "-n", + "2", + sys.executable, + Path(__file__).parent / "run_mpi_pair_deepmd.py", + data_file, + pb_file, + pb_file2, + md_file, + f.name, + *balance_args, + ] + ) + arr = np.loadtxt(f.name, ndmin=1) + pe = arr[0] + + relative = 1.0 + assert pe == pytest.approx(expected_e) + # load model devi + md = np.loadtxt(md_file.resolve()) + norm = np.linalg.norm(np.mean([expected_f, expected_f2], axis=0), axis=1) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + expected_md_f /= norm + relative + assert md[7:] == pytest.approx(expected_md_f) + assert md[4] == pytest.approx(np.max(expected_md_f)) + assert md[5] == pytest.approx(np.min(expected_md_f)) + assert md[6] == pytest.approx(np.mean(expected_md_f)) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + assert md[1] == pytest.approx(np.max(expected_md_v)) + assert md[2] == pytest.approx(np.min(expected_md_v)) + assert md[3] == pytest.approx(np.sqrt(np.mean(np.square(expected_md_v)))) diff --git a/source/lmp/tests/test_lammps_jax.py b/source/lmp/tests/test_lammps_jax.py new file mode 100644 index 0000000000..6d67cd3203 --- /dev/null +++ b/source/lmp/tests/test_lammps_jax.py @@ -0,0 +1,723 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import importlib +import os +import shutil +import subprocess as sp +import sys +import tempfile +from pathlib import ( + Path, +) + +import constants +import numpy as np +import pytest +from lammps import ( + PyLammps, +) +from write_lmp_data import ( + write_lmp_data, +) + +pbtxt_file2 = ( + Path(__file__).parent.parent.parent / "tests" / "infer" / "deeppot-1.pbtxt" +) +pb_file = ( + Path(__file__).parent.parent.parent / "tests" / "infer" / "deeppot_sea.savedmodel" +) +pb_file2 = Path(__file__).parent / "graph2.pb" +system_file = Path(__file__).parent.parent.parent / "tests" +data_file = Path(__file__).parent / "data.lmp" +data_file_si = Path(__file__).parent / "data.si" +data_type_map_file = Path(__file__).parent / "data_type_map.lmp" +md_file = Path(__file__).parent / "md.out" + +# this is as the same as python and c++ tests, test_deeppot_a.py +expected_ae = np.array( + [ + -93.016873944029, + -185.923296645958, + -185.927096544970, + -93.019371018039, + -185.926179995548, + -185.924351901852, + ] +) +expected_e = np.sum(expected_ae) +expected_f = np.array( + [ + 0.006277522211, + -0.001117962774, + 0.000618580445, + 0.009928999655, + 0.003026035654, + -0.006941982227, + 0.000667853212, + -0.002449963843, + 0.006506463508, + -0.007284129115, + 0.000530662205, + -0.000028806821, + 0.000068097781, + 0.006121331983, + -0.009019754602, + -0.009658343745, + -0.006110103225, + 0.008865499697, + ] +).reshape(6, 3) + +expected_f2 = np.array( + [ + [-0.6454949, 1.72457783, 0.18897958], + [1.68936514, -0.36995299, -1.36044464], + [-1.09902692, -1.35487928, 1.17416702], + [1.68426111, -0.50835585, 0.98340415], + [0.05771758, 1.12515818, -1.77561531], + [-1.686822, -0.61654789, 0.78950921], + ] +) + +expected_v = -np.array( + [ + -0.000155238009, + 0.000116605516, + -0.007869862476, + 0.000465578340, + 0.008182547185, + -0.002398713212, + -0.008112887338, + -0.002423738425, + 0.007210716605, + -0.019203504012, + 0.001724938709, + 0.009909211091, + 0.001153857542, + -0.001600015103, + -0.000560024090, + 0.010727836276, + -0.001034836404, + -0.007973454377, + -0.021517399106, + -0.004064359664, + 0.004866398692, + -0.003360038617, + -0.007241406162, + 0.005920941051, + 0.004899151657, + 0.006290788591, + -0.006478820311, + 0.001921504710, + 0.001313470921, + -0.000304091236, + 0.001684345981, + 0.004124109256, + -0.006396084465, + -0.000701095618, + -0.006356507032, + 0.009818550859, + -0.015230664587, + -0.000110244376, + 0.000690319396, + 0.000045953023, + -0.005726548770, + 0.008769818495, + -0.000572380210, + 0.008860603423, + -0.013819348050, + -0.021227082558, + -0.004977781343, + 0.006646239696, + -0.005987066507, + -0.002767831232, + 0.003746502525, + 0.007697590397, + 0.003746130152, + -0.005172634748, + ] +).reshape(6, 9) +expected_v2 = -np.array( + [ + [ + -0.70008436, + -0.06399891, + 0.63678391, + -0.07642171, + -0.70580035, + 0.20506145, + 0.64098364, + 0.20305781, + -0.57906794, + ], + [ + -0.6372635, + 0.14315552, + 0.51952246, + 0.04604049, + -0.06003681, + -0.02688702, + 0.54489318, + -0.10951559, + -0.43730539, + ], + [ + -0.25090748, + -0.37466262, + 0.34085833, + -0.26690852, + -0.37676917, + 0.29080825, + 0.31600481, + 0.37558276, + -0.33251064, + ], + [ + -0.80195614, + -0.10273138, + 0.06935364, + -0.10429256, + -0.29693811, + 0.45643496, + 0.07247872, + 0.45604679, + -0.71048816, + ], + [ + -0.03840668, + -0.07680205, + 0.10940472, + -0.02374189, + -0.27610266, + 0.4336071, + 0.02465248, + 0.4290638, + -0.67496763, + ], + [ + -0.61475065, + -0.21163135, + 0.26652929, + -0.26134659, + -0.11560267, + 0.15415902, + 0.34343952, + 0.1589482, + -0.21370642, + ], + ] +).reshape(6, 9) + +box = np.array([0, 13, 0, 13, 0, 13, 0, 0, 0]) +coord = np.array( + [ + [12.83, 2.56, 2.18], + [12.09, 2.87, 2.74], + [0.25, 3.32, 1.68], + [3.36, 3.00, 1.81], + [3.51, 2.51, 2.60], + [4.27, 3.22, 1.56], + ] +) +type_OH = np.array([1, 2, 2, 1, 2, 2]) +type_HO = np.array([2, 1, 1, 2, 1, 1]) + + +sp.check_output( + f"{sys.executable} -m deepmd convert-from pbtxt -i {pbtxt_file2.resolve()} -o {pb_file2.resolve()}".split() +) + + +def setup_module(): + write_lmp_data(box, coord, type_OH, data_file) + write_lmp_data(box, coord, type_HO, data_type_map_file) + write_lmp_data( + box * constants.dist_metal2si, + coord * constants.dist_metal2si, + type_OH, + data_file_si, + ) + + +def teardown_module(): + os.remove(data_file) + os.remove(data_type_map_file) + + +def _lammps(data_file, units="metal") -> PyLammps: + lammps = PyLammps() + lammps.units(units) + lammps.boundary("p p p") + lammps.atom_style("atomic") + if units == "metal" or units == "real": + lammps.neighbor("2.0 bin") + elif units == "si": + lammps.neighbor("2.0e-10 bin") + else: + raise ValueError("units should be metal, real, or si") + lammps.neigh_modify("every 10 delay 0 check no") + lammps.read_data(data_file.resolve()) + if units == "metal" or units == "real": + lammps.mass("1 16") + lammps.mass("2 2") + elif units == "si": + lammps.mass("1 %.10e" % (16 * constants.mass_metal2si)) + lammps.mass("2 %.10e" % (2 * constants.mass_metal2si)) + else: + raise ValueError("units should be metal, real, or si") + if units == "metal": + lammps.timestep(0.0005) + elif units == "real": + lammps.timestep(0.5) + elif units == "si": + lammps.timestep(5e-16) + else: + raise ValueError("units should be metal, real, or si") + lammps.fix("1 all nve") + return lammps + + +@pytest.fixture +def lammps(): + lmp = _lammps(data_file=data_file) + yield lmp + lmp.close() + + +@pytest.fixture +def lammps_type_map(): + lmp = _lammps(data_file=data_type_map_file) + yield lmp + lmp.close() + + +@pytest.fixture +def lammps_real(): + lmp = _lammps(data_file=data_file, units="real") + yield lmp + lmp.close() + + +@pytest.fixture +def lammps_si(): + lmp = _lammps(data_file=data_file_si, units="si") + yield lmp + lmp.close() + + +def test_pair_deepmd(lammps): + lammps.pair_style(f"deepmd {pb_file.resolve()}") + lammps.pair_coeff("* *") + lammps.run(0) + assert lammps.eval("pe") == pytest.approx(expected_e) + for ii in range(6): + assert lammps.atoms[ii].force == pytest.approx( + expected_f[lammps.atoms[ii].id - 1] + ) + lammps.run(1) + + +def test_pair_deepmd_virial(lammps): + lammps.pair_style(f"deepmd {pb_file.resolve()}") + lammps.pair_coeff("* *") + lammps.compute("virial all centroid/stress/atom NULL pair") + for ii in range(9): + jj = [0, 4, 8, 3, 6, 7, 1, 2, 5][ii] + lammps.variable(f"virial{jj} atom c_virial[{ii+1}]") + lammps.dump( + "1 all custom 1 dump id " + " ".join([f"v_virial{ii}" for ii in range(9)]) + ) + lammps.run(0) + assert lammps.eval("pe") == pytest.approx(expected_e) + for ii in range(6): + assert lammps.atoms[ii].force == pytest.approx( + expected_f[lammps.atoms[ii].id - 1] + ) + idx_map = lammps.lmp.numpy.extract_atom("id") - 1 + for ii in range(9): + assert np.array( + lammps.variables[f"virial{ii}"].value + ) / constants.nktv2p == pytest.approx(expected_v[idx_map, ii]) + + +def test_pair_deepmd_model_devi(lammps): + lammps.pair_style( + f"deepmd {pb_file.resolve()} {pb_file2.resolve()} out_file {md_file.resolve()} out_freq 1 atomic" + ) + lammps.pair_coeff("* *") + lammps.run(0) + assert lammps.eval("pe") == pytest.approx(expected_e) + for ii in range(6): + assert lammps.atoms[ii].force == pytest.approx( + expected_f[lammps.atoms[ii].id - 1] + ) + # load model devi + md = np.loadtxt(md_file.resolve()) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + assert md[7:] == pytest.approx(expected_md_f) + assert md[4] == pytest.approx(np.max(expected_md_f)) + assert md[5] == pytest.approx(np.min(expected_md_f)) + assert md[6] == pytest.approx(np.mean(expected_md_f)) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + assert md[1] == pytest.approx(np.max(expected_md_v)) + assert md[2] == pytest.approx(np.min(expected_md_v)) + assert md[3] == pytest.approx(np.sqrt(np.mean(np.square(expected_md_v)))) + + +def test_pair_deepmd_model_devi_virial(lammps): + lammps.pair_style( + f"deepmd {pb_file.resolve()} {pb_file2.resolve()} out_file {md_file.resolve()} out_freq 1 atomic" + ) + lammps.pair_coeff("* *") + lammps.compute("virial all centroid/stress/atom NULL pair") + for ii in range(9): + jj = [0, 4, 8, 3, 6, 7, 1, 2, 5][ii] + lammps.variable(f"virial{jj} atom c_virial[{ii+1}]") + lammps.dump( + "1 all custom 1 dump id " + " ".join([f"v_virial{ii}" for ii in range(9)]) + ) + lammps.run(0) + assert lammps.eval("pe") == pytest.approx(expected_e) + for ii in range(6): + assert lammps.atoms[ii].force == pytest.approx( + expected_f[lammps.atoms[ii].id - 1] + ) + idx_map = lammps.lmp.numpy.extract_atom("id") - 1 + for ii in range(9): + assert np.array( + lammps.variables[f"virial{ii}"].value + ) / constants.nktv2p == pytest.approx(expected_v[idx_map, ii]) + # load model devi + md = np.loadtxt(md_file.resolve()) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + assert md[7:] == pytest.approx(expected_md_f) + assert md[4] == pytest.approx(np.max(expected_md_f)) + assert md[5] == pytest.approx(np.min(expected_md_f)) + assert md[6] == pytest.approx(np.mean(expected_md_f)) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + assert md[1] == pytest.approx(np.max(expected_md_v)) + assert md[2] == pytest.approx(np.min(expected_md_v)) + assert md[3] == pytest.approx(np.sqrt(np.mean(np.square(expected_md_v)))) + + +def test_pair_deepmd_model_devi_atomic_relative(lammps): + relative = 1.0 + lammps.pair_style( + f"deepmd {pb_file.resolve()} {pb_file2.resolve()} out_file {md_file.resolve()} out_freq 1 atomic relative {relative}" + ) + lammps.pair_coeff("* *") + lammps.run(0) + assert lammps.eval("pe") == pytest.approx(expected_e) + for ii in range(6): + assert lammps.atoms[ii].force == pytest.approx( + expected_f[lammps.atoms[ii].id - 1] + ) + # load model devi + md = np.loadtxt(md_file.resolve()) + norm = np.linalg.norm(np.mean([expected_f, expected_f2], axis=0), axis=1) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + expected_md_f /= norm + relative + assert md[7:] == pytest.approx(expected_md_f) + assert md[4] == pytest.approx(np.max(expected_md_f)) + assert md[5] == pytest.approx(np.min(expected_md_f)) + assert md[6] == pytest.approx(np.mean(expected_md_f)) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + assert md[1] == pytest.approx(np.max(expected_md_v)) + assert md[2] == pytest.approx(np.min(expected_md_v)) + assert md[3] == pytest.approx(np.sqrt(np.mean(np.square(expected_md_v)))) + + +def test_pair_deepmd_model_devi_atomic_relative_v(lammps): + relative = 1.0 + lammps.pair_style( + f"deepmd {pb_file.resolve()} {pb_file2.resolve()} out_file {md_file.resolve()} out_freq 1 atomic relative_v {relative}" + ) + lammps.pair_coeff("* *") + lammps.run(0) + assert lammps.eval("pe") == pytest.approx(expected_e) + for ii in range(6): + assert lammps.atoms[ii].force == pytest.approx( + expected_f[lammps.atoms[ii].id - 1] + ) + md = np.loadtxt(md_file.resolve()) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + assert md[7:] == pytest.approx(expected_md_f) + assert md[4] == pytest.approx(np.max(expected_md_f)) + assert md[5] == pytest.approx(np.min(expected_md_f)) + assert md[6] == pytest.approx(np.mean(expected_md_f)) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + norm = ( + np.abs( + np.mean([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) + ) + / 6 + ) + expected_md_v /= norm + relative + assert md[1] == pytest.approx(np.max(expected_md_v)) + assert md[2] == pytest.approx(np.min(expected_md_v)) + assert md[3] == pytest.approx(np.sqrt(np.mean(np.square(expected_md_v)))) + + +def test_pair_deepmd_type_map(lammps_type_map): + lammps_type_map.pair_style(f"deepmd {pb_file.resolve()}") + lammps_type_map.pair_coeff("* * H O") + lammps_type_map.run(0) + assert lammps_type_map.eval("pe") == pytest.approx(expected_e) + for ii in range(6): + assert lammps_type_map.atoms[ii].force == pytest.approx( + expected_f[lammps_type_map.atoms[ii].id - 1] + ) + lammps_type_map.run(1) + + +def test_pair_deepmd_real(lammps_real): + lammps_real.pair_style(f"deepmd {pb_file.resolve()}") + lammps_real.pair_coeff("* *") + lammps_real.run(0) + assert lammps_real.eval("pe") == pytest.approx( + expected_e * constants.ener_metal2real + ) + for ii in range(6): + assert lammps_real.atoms[ii].force == pytest.approx( + expected_f[lammps_real.atoms[ii].id - 1] * constants.force_metal2real + ) + lammps_real.run(1) + + +def test_pair_deepmd_virial_real(lammps_real): + lammps_real.pair_style(f"deepmd {pb_file.resolve()}") + lammps_real.pair_coeff("* *") + lammps_real.compute("virial all centroid/stress/atom NULL pair") + for ii in range(9): + jj = [0, 4, 8, 3, 6, 7, 1, 2, 5][ii] + lammps_real.variable(f"virial{jj} atom c_virial[{ii+1}]") + lammps_real.dump( + "1 all custom 1 dump id " + " ".join([f"v_virial{ii}" for ii in range(9)]) + ) + lammps_real.run(0) + assert lammps_real.eval("pe") == pytest.approx( + expected_e * constants.ener_metal2real + ) + for ii in range(6): + assert lammps_real.atoms[ii].force == pytest.approx( + expected_f[lammps_real.atoms[ii].id - 1] * constants.force_metal2real + ) + idx_map = lammps_real.lmp.numpy.extract_atom("id") - 1 + for ii in range(9): + assert np.array( + lammps_real.variables[f"virial{ii}"].value + ) / constants.nktv2p_real == pytest.approx( + expected_v[idx_map, ii] * constants.ener_metal2real + ) + + +def test_pair_deepmd_model_devi_real(lammps_real): + lammps_real.pair_style( + f"deepmd {pb_file.resolve()} {pb_file2.resolve()} out_file {md_file.resolve()} out_freq 1 atomic" + ) + lammps_real.pair_coeff("* *") + lammps_real.run(0) + assert lammps_real.eval("pe") == pytest.approx( + expected_e * constants.ener_metal2real + ) + for ii in range(6): + assert lammps_real.atoms[ii].force == pytest.approx( + expected_f[lammps_real.atoms[ii].id - 1] * constants.force_metal2real + ) + # load model devi + md = np.loadtxt(md_file.resolve()) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + assert md[7:] == pytest.approx(expected_md_f * constants.force_metal2real) + assert md[4] == pytest.approx(np.max(expected_md_f) * constants.force_metal2real) + assert md[5] == pytest.approx(np.min(expected_md_f) * constants.force_metal2real) + assert md[6] == pytest.approx(np.mean(expected_md_f) * constants.force_metal2real) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + assert md[1] == pytest.approx(np.max(expected_md_v) * constants.ener_metal2real) + assert md[2] == pytest.approx(np.min(expected_md_v) * constants.ener_metal2real) + assert md[3] == pytest.approx( + np.sqrt(np.mean(np.square(expected_md_v))) * constants.ener_metal2real + ) + + +def test_pair_deepmd_model_devi_virial_real(lammps_real): + lammps_real.pair_style( + f"deepmd {pb_file.resolve()} {pb_file2.resolve()} out_file {md_file.resolve()} out_freq 1 atomic" + ) + lammps_real.pair_coeff("* *") + lammps_real.compute("virial all centroid/stress/atom NULL pair") + for ii in range(9): + jj = [0, 4, 8, 3, 6, 7, 1, 2, 5][ii] + lammps_real.variable(f"virial{jj} atom c_virial[{ii+1}]") + lammps_real.dump( + "1 all custom 1 dump id " + " ".join([f"v_virial{ii}" for ii in range(9)]) + ) + lammps_real.run(0) + assert lammps_real.eval("pe") == pytest.approx( + expected_e * constants.ener_metal2real + ) + for ii in range(6): + assert lammps_real.atoms[ii].force == pytest.approx( + expected_f[lammps_real.atoms[ii].id - 1] * constants.force_metal2real + ) + idx_map = lammps_real.lmp.numpy.extract_atom("id") - 1 + for ii in range(9): + assert np.array( + lammps_real.variables[f"virial{ii}"].value + ) / constants.nktv2p_real == pytest.approx( + expected_v[idx_map, ii] * constants.ener_metal2real + ) + # load model devi + md = np.loadtxt(md_file.resolve()) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + assert md[7:] == pytest.approx(expected_md_f * constants.force_metal2real) + assert md[4] == pytest.approx(np.max(expected_md_f) * constants.force_metal2real) + assert md[5] == pytest.approx(np.min(expected_md_f) * constants.force_metal2real) + assert md[6] == pytest.approx(np.mean(expected_md_f) * constants.force_metal2real) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + assert md[1] == pytest.approx(np.max(expected_md_v) * constants.ener_metal2real) + assert md[2] == pytest.approx(np.min(expected_md_v) * constants.ener_metal2real) + assert md[3] == pytest.approx( + np.sqrt(np.mean(np.square(expected_md_v))) * constants.ener_metal2real + ) + + +def test_pair_deepmd_model_devi_atomic_relative_real(lammps_real): + relative = 1.0 + lammps_real.pair_style( + f"deepmd {pb_file.resolve()} {pb_file2.resolve()} out_file {md_file.resolve()} out_freq 1 atomic relative {relative * constants.force_metal2real}" + ) + lammps_real.pair_coeff("* *") + lammps_real.run(0) + assert lammps_real.eval("pe") == pytest.approx( + expected_e * constants.ener_metal2real + ) + for ii in range(6): + assert lammps_real.atoms[ii].force == pytest.approx( + expected_f[lammps_real.atoms[ii].id - 1] * constants.force_metal2real + ) + # load model devi + md = np.loadtxt(md_file.resolve()) + norm = np.linalg.norm(np.mean([expected_f, expected_f2], axis=0), axis=1) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + expected_md_f /= norm + relative + assert md[7:] == pytest.approx(expected_md_f * constants.force_metal2real) + assert md[4] == pytest.approx(np.max(expected_md_f) * constants.force_metal2real) + assert md[5] == pytest.approx(np.min(expected_md_f) * constants.force_metal2real) + assert md[6] == pytest.approx(np.mean(expected_md_f) * constants.force_metal2real) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + assert md[1] == pytest.approx(np.max(expected_md_v) * constants.ener_metal2real) + assert md[2] == pytest.approx(np.min(expected_md_v) * constants.ener_metal2real) + assert md[3] == pytest.approx( + np.sqrt(np.mean(np.square(expected_md_v))) * constants.ener_metal2real + ) + + +def test_pair_deepmd_model_devi_atomic_relative_v_real(lammps_real): + relative = 1.0 + lammps_real.pair_style( + f"deepmd {pb_file.resolve()} {pb_file2.resolve()} out_file {md_file.resolve()} out_freq 1 atomic relative_v {relative * constants.ener_metal2real}" + ) + lammps_real.pair_coeff("* *") + lammps_real.run(0) + assert lammps_real.eval("pe") == pytest.approx( + expected_e * constants.ener_metal2real + ) + for ii in range(6): + assert lammps_real.atoms[ii].force == pytest.approx( + expected_f[lammps_real.atoms[ii].id - 1] * constants.force_metal2real + ) + md = np.loadtxt(md_file.resolve()) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + assert md[7:] == pytest.approx(expected_md_f * constants.force_metal2real) + assert md[4] == pytest.approx(np.max(expected_md_f) * constants.force_metal2real) + assert md[5] == pytest.approx(np.min(expected_md_f) * constants.force_metal2real) + assert md[6] == pytest.approx(np.mean(expected_md_f) * constants.force_metal2real) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + norm = ( + np.abs( + np.mean([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) + ) + / 6 + ) + expected_md_v /= norm + relative + assert md[1] == pytest.approx(np.max(expected_md_v) * constants.ener_metal2real) + assert md[2] == pytest.approx(np.min(expected_md_v) * constants.ener_metal2real) + assert md[3] == pytest.approx( + np.sqrt(np.mean(np.square(expected_md_v))) * constants.ener_metal2real + ) + + +def test_pair_deepmd_si(lammps_si): + lammps_si.pair_style(f"deepmd {pb_file.resolve()}") + lammps_si.pair_coeff("* *") + lammps_si.run(0) + assert lammps_si.eval("pe") == pytest.approx(expected_e * constants.ener_metal2si) + for ii in range(6): + assert lammps_si.atoms[ii].force == pytest.approx( + expected_f[lammps_si.atoms[ii].id - 1] * constants.force_metal2si + ) + lammps_si.run(1) + + +@pytest.mark.skipif( + shutil.which("mpirun") is None, reason="MPI is not installed on this system" +) +@pytest.mark.skipif( + importlib.util.find_spec("mpi4py") is None, reason="mpi4py is not installed" +) +@pytest.mark.parametrize( + ("balance_args",), + [(["--balance"],), ([],)], +) +def test_pair_deepmd_mpi(balance_args: list): + with tempfile.NamedTemporaryFile() as f: + sp.check_call( + [ + "mpirun", + "-n", + "2", + sys.executable, + Path(__file__).parent / "run_mpi_pair_deepmd.py", + data_file, + pb_file, + pb_file2, + md_file, + f.name, + *balance_args, + ] + ) + arr = np.loadtxt(f.name, ndmin=1) + pe = arr[0] + + relative = 1.0 + assert pe == pytest.approx(expected_e) + # load model devi + md = np.loadtxt(md_file.resolve()) + norm = np.linalg.norm(np.mean([expected_f, expected_f2], axis=0), axis=1) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + expected_md_f /= norm + relative + assert md[7:] == pytest.approx(expected_md_f) + assert md[4] == pytest.approx(np.max(expected_md_f)) + assert md[5] == pytest.approx(np.min(expected_md_f)) + assert md[6] == pytest.approx(np.mean(expected_md_f)) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + assert md[1] == pytest.approx(np.max(expected_md_v)) + assert md[2] == pytest.approx(np.min(expected_md_v)) + assert md[3] == pytest.approx(np.sqrt(np.mean(np.square(expected_md_v)))) diff --git a/source/tests/infer/deeppot_dpa.savedmodel/.gitignore b/source/tests/infer/deeppot_dpa.savedmodel/.gitignore new file mode 100644 index 0000000000..dad5ee2642 --- /dev/null +++ b/source/tests/infer/deeppot_dpa.savedmodel/.gitignore @@ -0,0 +1 @@ +!*.pb diff --git a/source/tests/infer/deeppot_dpa.savedmodel/fingerprint.pb b/source/tests/infer/deeppot_dpa.savedmodel/fingerprint.pb new file mode 100644 index 0000000000..a930670465 --- /dev/null +++ b/source/tests/infer/deeppot_dpa.savedmodel/fingerprint.pb @@ -0,0 +1 @@ +|俖ў (Ňʋ2 \ No newline at end of file diff --git a/source/tests/infer/deeppot_dpa.savedmodel/saved_model.pb b/source/tests/infer/deeppot_dpa.savedmodel/saved_model.pb new file mode 100644 index 0000000000..9dca73e7cd Binary files /dev/null and b/source/tests/infer/deeppot_dpa.savedmodel/saved_model.pb differ diff --git a/source/tests/infer/deeppot_dpa.savedmodel/variables/variables.data-00000-of-00001 b/source/tests/infer/deeppot_dpa.savedmodel/variables/variables.data-00000-of-00001 new file mode 100644 index 0000000000..95a524f222 Binary files /dev/null and b/source/tests/infer/deeppot_dpa.savedmodel/variables/variables.data-00000-of-00001 differ diff --git a/source/tests/infer/deeppot_dpa.savedmodel/variables/variables.index b/source/tests/infer/deeppot_dpa.savedmodel/variables/variables.index new file mode 100644 index 0000000000..4dfa9afa72 Binary files /dev/null and b/source/tests/infer/deeppot_dpa.savedmodel/variables/variables.index differ diff --git a/source/tests/infer/deeppot_sea.savedmodel/.gitignore b/source/tests/infer/deeppot_sea.savedmodel/.gitignore new file mode 100644 index 0000000000..dad5ee2642 --- /dev/null +++ b/source/tests/infer/deeppot_sea.savedmodel/.gitignore @@ -0,0 +1 @@ +!*.pb diff --git a/source/tests/infer/deeppot_sea.savedmodel/fingerprint.pb b/source/tests/infer/deeppot_sea.savedmodel/fingerprint.pb new file mode 100644 index 0000000000..71dd8d955d --- /dev/null +++ b/source/tests/infer/deeppot_sea.savedmodel/fingerprint.pb @@ -0,0 +1 @@ +Ǎ俖ў ՝:(Ňʋ2 \ No newline at end of file diff --git a/source/tests/infer/deeppot_sea.savedmodel/saved_model.pb b/source/tests/infer/deeppot_sea.savedmodel/saved_model.pb new file mode 100644 index 0000000000..3897c71855 Binary files /dev/null and b/source/tests/infer/deeppot_sea.savedmodel/saved_model.pb differ diff --git a/source/tests/infer/deeppot_sea.savedmodel/variables/variables.data-00000-of-00001 b/source/tests/infer/deeppot_sea.savedmodel/variables/variables.data-00000-of-00001 new file mode 100644 index 0000000000..95a524f222 Binary files /dev/null and b/source/tests/infer/deeppot_sea.savedmodel/variables/variables.data-00000-of-00001 differ diff --git a/source/tests/infer/deeppot_sea.savedmodel/variables/variables.index b/source/tests/infer/deeppot_sea.savedmodel/variables/variables.index new file mode 100644 index 0000000000..4dfa9afa72 Binary files /dev/null and b/source/tests/infer/deeppot_sea.savedmodel/variables/variables.index differ