diff --git a/source/api_cc/include/DeepPotJAX.h b/source/api_cc/include/DeepPotJAX.h index 38fd0c68f4..76533fcc35 100644 --- a/source/api_cc/include/DeepPotJAX.h +++ b/source/api_cc/include/DeepPotJAX.h @@ -22,7 +22,7 @@ class DeepPotJAX : public DeepPotBackend { /** * @brief DP constructor with initialization. * @param[in] model The name of the frozen model file. - * @param[in] gpu_rank The GPU rank. Default is 0. + * @param[in] gpu_rank The GPU rank. Default is 0. If < 0, use CPU. * @param[in] file_content The content of the model file. If it is not empty, *DP will read from the string instead of the file. **/ @@ -32,7 +32,7 @@ class DeepPotJAX : public DeepPotBackend { /** * @brief Initialize the DP. * @param[in] model The name of the frozen model file. - * @param[in] gpu_rank The GPU rank. Default is 0. + * @param[in] gpu_rank The GPU rank. Default is 0. If < 0, use CPU. * @param[in] file_content The content of the model file. If it is not empty, *DP will read from the string instead of the file. **/ @@ -208,6 +208,42 @@ class DeepPotJAX : public DeepPotBackend { */ // neighbor list data NeighborListData nlist_data; + /** + * @brief Evaluate the energy, force, virial, atomic energy, and atomic virial + *by using this DP. + * @param[out] ener The system energy. + * @param[out] force The force on each atom. + * @param[out] virial The virial. + * @param[out] atom_energy The atomic energy. + * @param[out] atom_virial The atomic virial. + * @param[in] coord The coordinates of atoms. The array should be of size + *nframes x natoms x 3. + * @param[in] atype The atom types. The list should contain natoms ints. + * @param[in] box The cell of the region. The array should be of size nframes + *x 9. + * @param[in] fparam The frame parameter. The array can be of size : + * nframes x dim_fparam. + * dim_fparam. Then all frames are assumed to be provided with the same + *fparam. + * @param[in] aparam The atomic parameter The array can be of size : + * nframes x natoms x dim_aparam. + * natoms x dim_aparam. Then all frames are assumed to be provided with the + *same aparam. + * @param[in] atomic Whether to compute the atomic energy and virial. + **/ + template + void compute(std::vector& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic); + /** * @brief Evaluate the energy, force, virial, atomic energy, and atomic virial *by using this DP. diff --git a/source/api_cc/src/DeepPotJAX.cc b/source/api_cc/src/DeepPotJAX.cc index a376f9b195..66525facd3 100644 --- a/source/api_cc/src/DeepPotJAX.cc +++ b/source/api_cc/src/DeepPotJAX.cc @@ -6,6 +6,8 @@ #include #include +#include +#include #include #include #include @@ -228,6 +230,13 @@ void deepmd::DeepPotJAX::init(const std::string& model, status = TF_NewStatus(); sessionopts = TF_NewSessionOptions(); + int num_intra_nthreads, num_inter_nthreads; + get_env_nthreads(num_intra_nthreads, num_inter_nthreads); + // https://github.com/Neargye/hello_tf_c_api/blob/51516101cf59408a6bb456f7e5f3c6628e327b3a/src/tf_utils.cpp#L400-L401 + std::array config = { + {0x10, static_cast(num_intra_nthreads), 0x28, + static_cast(num_inter_nthreads)}}; + TF_SetConfig(sessionopts, config.data(), config.size(), status); TF_Buffer* runopts = NULL; const char* tags = "serve"; @@ -250,8 +259,8 @@ void deepmd::DeepPotJAX::init(const std::string& model, #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM int gpu_num; DPGetDeviceCount(gpu_num); // check current device environment - DPErrcheck(DPSetDevice(gpu_rank % gpu_num)); - if (gpu_num > 0) { + if (gpu_num > 0 && gpu_rank >= 0) { + DPErrcheck(DPSetDevice(gpu_rank % gpu_num)); device = "/gpu:" + std::to_string(gpu_rank % gpu_num); } else { device = "/cpu:0"; @@ -300,6 +309,153 @@ deepmd::DeepPotJAX::~DeepPotJAX() { } } +template +void deepmd::DeepPotJAX::compute(std::vector& ener, + std::vector& force_, + std::vector& virial, + std::vector& atom_energy_, + std::vector& atom_virial_, + const std::vector& dcoord, + const std::vector& datype, + const std::vector& box, + const std::vector& fparam, + const std::vector& aparam_, + const bool atomic) { + std::vector coord, force, aparam, atom_energy, atom_virial; + std::vector ener_double, force_double, virial_double, + atom_energy_double, atom_virial_double; + std::vector atype, fwd_map, bkw_map; + int nghost_real, nall_real, nloc_real; + int nall = datype.size(); + // nlist passed to the model + int nframes = nall > 0 ? (dcoord.size() / 3 / nall) : 1; + int nghost = 0; + + select_real_atoms_coord(coord, atype, aparam, nghost_real, fwd_map, bkw_map, + nall_real, nloc_real, dcoord, datype, aparam_, nghost, + ntypes, nframes, daparam, nall, false); + + if (nloc_real == 0) { + // no real atoms, fill 0 for all outputs + // this can prevent a Xla error + ener.resize(nframes, 0.0); + force_.resize(static_cast(nframes) * nall * 3, 0.0); + virial.resize(static_cast(nframes) * 9, 0.0); + atom_energy_.resize(static_cast(nframes) * nall, 0.0); + atom_virial_.resize(static_cast(nframes) * nall * 9, 0.0); + return; + } + + // cast coord, fparam, and aparam to double - I think it's useless to have a + // float model interface + std::vector coord_double(coord.begin(), coord.end()); + std::vector box_double(box.begin(), box.end()); + std::vector fparam_double(fparam.begin(), fparam.end()); + std::vector aparam_double(aparam.begin(), aparam.end()); + + TFE_Op* op; + if (atomic) { + op = get_func_op(ctx, "call_with_atomic_virial", func_vector, device, + status); + } else { + op = get_func_op(ctx, "call_without_atomic_virial", func_vector, device, + status); + } + std::vector input_list(5); + std::vector data_tensor(5); + // coord + std::vector coord_shape = {nframes, nloc_real, 3}; + input_list[0] = + add_input(op, coord_double, coord_shape, data_tensor[0], status); + // atype + std::vector atype_shape = {nframes, nloc_real}; + input_list[1] = add_input(op, atype, atype_shape, data_tensor[1], status); + // box + int box_size = box_double.size() > 0 ? 3 : 0; + std::vector box_shape = {nframes, box_size, box_size}; + input_list[2] = add_input(op, box_double, box_shape, data_tensor[2], status); + // fparam + std::vector fparam_shape = {nframes, dfparam}; + input_list[3] = + add_input(op, fparam_double, fparam_shape, data_tensor[3], status); + // aparam + std::vector aparam_shape = {nframes, nloc_real, daparam}; + input_list[4] = + add_input(op, aparam_double, aparam_shape, data_tensor[4], status); + // execute the function + int nretvals = 6; + TFE_TensorHandle* retvals[nretvals]; + + TFE_Execute(op, retvals, &nretvals, status); + check_status(status); + + // copy data + // for atom virial, the order is: + // Identity_15 energy -1, -1, 1 + // Identity_16 energy_derv_c -1, -1, 1, 9 (may pop) + // Identity_17 energy_derv_c_redu -1, 1, 9 + // Identity_18 energy_derv_r -1, -1, 1, 3 + // Identity_19 energy_redu -1, 1 + // Identity_20 mask (int32) -1, -1 + // + // for no atom virial, the order is: + // Identity_15 energy -1, -1, 1 + // Identity_16 energy_derv_c -1, 1, 9 + // Identity_17 energy_derv_r -1, -1, 1, 3 + // Identity_18 energy_redu -1, 1 + // Identity_19 mask (int32) -1, -1 + // + // it seems the order is the alphabet order? + // not sure whether it is safe to assume the order + if (atomic) { + tensor_to_vector(ener_double, retvals[4], status); + tensor_to_vector(force_double, retvals[3], status); + tensor_to_vector(virial_double, retvals[2], status); + tensor_to_vector(atom_energy_double, retvals[0], status); + tensor_to_vector(atom_virial_double, retvals[1], status); + } else { + tensor_to_vector(ener_double, retvals[3], status); + tensor_to_vector(force_double, retvals[2], status); + tensor_to_vector(virial_double, retvals[1], status); + tensor_to_vector(atom_energy_double, retvals[0], status); + } + + // cast back to VALUETYPE + ener = std::vector(ener_double.begin(), ener_double.end()); + force = std::vector(force_double.begin(), force_double.end()); + virial = std::vector(virial_double.begin(), virial_double.end()); + atom_energy = std::vector(atom_energy_double.begin(), + atom_energy_double.end()); + atom_virial = std::vector(atom_virial_double.begin(), + atom_virial_double.end()); + force.resize(static_cast(nframes) * nall_real * 3); + atom_virial.resize(static_cast(nframes) * nall_real * 9); + + // nall atom_energy is required in the C++ API; + // we always forget it! + atom_energy.resize(static_cast(nframes) * nall_real, 0.0); + + force_.resize(static_cast(nframes) * fwd_map.size() * 3); + atom_energy_.resize(static_cast(nframes) * fwd_map.size()); + atom_virial_.resize(static_cast(nframes) * fwd_map.size() * 9); + select_map(force_, force, bkw_map, 3, nframes, fwd_map.size(), + nall_real); + select_map(atom_energy_, atom_energy, bkw_map, 1, nframes, + fwd_map.size(), nall_real); + select_map(atom_virial_, atom_virial, bkw_map, 9, nframes, + fwd_map.size(), nall_real); + + // cleanup input_list, etc + for (size_t i = 0; i < 5; i++) { + TFE_DeleteTensorHandle(input_list[i]); + TF_DeleteTensor(data_tensor[i]); + } + for (size_t i = 0; i < nretvals; i++) { + TFE_DeleteTensorHandle(retvals[i]); + } + TFE_DeleteOp(op); +} + template void deepmd::DeepPotJAX::compute(std::vector& ener, std::vector& force_, @@ -523,7 +679,8 @@ void deepmd::DeepPotJAX::computew(std::vector& ener, const std::vector& fparam, const std::vector& aparam, const bool atomic) { - throw deepmd::deepmd_exception("not implemented"); + compute(ener, force, virial, atom_energy, atom_virial, coord, atype, box, + fparam, aparam, atomic); } void deepmd::DeepPotJAX::computew(std::vector& ener, std::vector& force, @@ -536,7 +693,8 @@ void deepmd::DeepPotJAX::computew(std::vector& ener, const std::vector& fparam, const std::vector& aparam, const bool atomic) { - throw deepmd::deepmd_exception("not implemented"); + compute(ener, force, virial, atom_energy, atom_virial, coord, atype, box, + fparam, aparam, atomic); } void deepmd::DeepPotJAX::computew(std::vector& ener, std::vector& force, diff --git a/source/api_cc/tests/test_deeppot_jax.cc b/source/api_cc/tests/test_deeppot_jax.cc index c9fe8ea3dd..e7095bee5e 100644 --- a/source/api_cc/tests/test_deeppot_jax.cc +++ b/source/api_cc/tests/test_deeppot_jax.cc @@ -71,7 +71,8 @@ class TestInferDeepPotAJAX : public ::testing::Test { void SetUp() override { std::string file_name = "../../tests/infer/deeppot_sea.savedmodel"; - dp.init(file_name); + // the model is generated for the CPU, so always use the CPU + dp.init(file_name, -1); natoms = expected_e.size(); EXPECT_EQ(natoms * 3, expected_f.size()); @@ -94,6 +95,121 @@ class TestInferDeepPotAJAX : public ::testing::Test { TYPED_TEST_SUITE(TestInferDeepPotAJAX, ValueTypes); +TYPED_TEST(TestInferDeepPotAJAX, cpu_build_nlist) { + using VALUETYPE = TypeParam; + std::vector& coord = this->coord; + std::vector& atype = this->atype; + std::vector& box = this->box; + std::vector& expected_e = this->expected_e; + std::vector& expected_f = this->expected_f; + std::vector& expected_v = this->expected_v; + int& natoms = this->natoms; + double& expected_tot_e = this->expected_tot_e; + std::vector& expected_tot_v = this->expected_tot_v; + deepmd::DeepPot& dp = this->dp; + double ener; + std::vector force, virial; + dp.compute(ener, force, virial, coord, atype, box); + + EXPECT_EQ(force.size(), natoms * 3); + EXPECT_EQ(virial.size(), 9); + + EXPECT_LT(fabs(ener - expected_tot_e), EPSILON); + for (int ii = 0; ii < natoms * 3; ++ii) { + EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON); + } + for (int ii = 0; ii < 3 * 3; ++ii) { + EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON); + } +} + +TYPED_TEST(TestInferDeepPotAJAX, cpu_build_nlist_numfv) { + using VALUETYPE = TypeParam; + std::vector& coord = this->coord; + std::vector& atype = this->atype; + std::vector& box = this->box; + std::vector& expected_e = this->expected_e; + std::vector& expected_f = this->expected_f; + std::vector& expected_v = this->expected_v; + int& natoms = this->natoms; + double& expected_tot_e = this->expected_tot_e; + std::vector& expected_tot_v = this->expected_tot_v; + deepmd::DeepPot& dp = this->dp; + class MyModel : public EnergyModelTest { + deepmd::DeepPot& mydp; + const std::vector atype; + + public: + MyModel(deepmd::DeepPot& dp_, const std::vector& atype_) + : mydp(dp_), atype(atype_) {}; + virtual void compute(double& ener, + std::vector& force, + std::vector& virial, + const std::vector& coord, + const std::vector& box) { + mydp.compute(ener, force, virial, coord, atype, box); + } + }; + MyModel model(dp, atype); + model.test_f(coord, box); + model.test_v(coord, box); + std::vector box_(box); + box_[1] -= 0.4; + model.test_f(coord, box_); + model.test_v(coord, box_); + box_[2] += 0.5; + model.test_f(coord, box_); + model.test_v(coord, box_); + box_[4] += 0.2; + model.test_f(coord, box_); + model.test_v(coord, box_); + box_[3] -= 0.3; + model.test_f(coord, box_); + model.test_v(coord, box_); + box_[6] -= 0.7; + model.test_f(coord, box_); + model.test_v(coord, box_); + box_[7] += 0.6; + model.test_f(coord, box_); + model.test_v(coord, box_); +} + +TYPED_TEST(TestInferDeepPotAJAX, cpu_build_nlist_atomic) { + using VALUETYPE = TypeParam; + std::vector& coord = this->coord; + std::vector& atype = this->atype; + std::vector& box = this->box; + std::vector& expected_e = this->expected_e; + std::vector& expected_f = this->expected_f; + std::vector& expected_v = this->expected_v; + int& natoms = this->natoms; + double& expected_tot_e = this->expected_tot_e; + std::vector& expected_tot_v = this->expected_tot_v; + deepmd::DeepPot& dp = this->dp; + double ener; + std::vector force, virial, atom_ener, atom_vir; + dp.compute(ener, force, virial, atom_ener, atom_vir, coord, atype, box); + + EXPECT_EQ(force.size(), natoms * 3); + EXPECT_EQ(virial.size(), 9); + EXPECT_EQ(atom_ener.size(), natoms); + EXPECT_EQ(atom_vir.size(), natoms * 9); + + EXPECT_LT(fabs(ener - expected_tot_e), EPSILON); + for (int ii = 0; ii < natoms * 3; ++ii) { + EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON); + } + for (int ii = 0; ii < 3 * 3; ++ii) { + EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON); + } + for (int ii = 0; ii < natoms; ++ii) { + EXPECT_LT(fabs(atom_ener[ii] - expected_e[ii]), EPSILON); + } + for (int ii = 0; ii < natoms * 9; ++ii) { + EXPECT_LT(fabs(atom_vir[ii] - expected_v[ii]), EPSILON); + } +} + TYPED_TEST(TestInferDeepPotAJAX, cpu_lmp_nlist) { using VALUETYPE = TypeParam; std::vector& coord = this->coord; diff --git a/source/lmp/tests/test_lammps_jax.py b/source/lmp/tests/test_lammps_jax.py index 6d67cd3203..2306081a8f 100644 --- a/source/lmp/tests/test_lammps_jax.py +++ b/source/lmp/tests/test_lammps_jax.py @@ -19,6 +19,12 @@ write_lmp_data, ) +pytest.skipif( + os.environ.get("CUDA_VISIBLE_DEVICES", "") != "", + reason="The model is generated with CPU", + allow_module_level=True, +) + pbtxt_file2 = ( Path(__file__).parent.parent.parent / "tests" / "infer" / "deeppot-1.pbtxt" )