diff --git a/deepmd/pd/entrypoints/main.py b/deepmd/pd/entrypoints/main.py index f05543d239..4e4ac5f85a 100644 --- a/deepmd/pd/entrypoints/main.py +++ b/deepmd/pd/entrypoints/main.py @@ -53,6 +53,7 @@ ) from deepmd.pd.utils.env import ( DEVICE, + PIR_ENABLED, ) from deepmd.pd.utils.finetune import ( get_finetune_rules, @@ -349,17 +350,20 @@ def freeze(FLAGS): ) """ - ** coord [None, 192, 3] paddle.float64 - ** atype [None, 192] paddle.int64 - ** box [None, 3, 3] paddle.float64 + ** coord [None, natoms, 3] paddle.float64 + ** atype [None, natoms] paddle.int64 + ** nlist [None, natoms, nnei] paddle.int32 """ + model.atomic_model.buffer_type_map.set_value( + paddle.to_tensor([ord(c) for c in model.atomic_model.type_map], dtype="int32") + ) model = paddle.jit.to_static( - model, + model.forward_lower, full_graph=True, input_spec=[ - InputSpec([None, 192, 3], dtype="float64", name="coord"), - InputSpec([None, 192], dtype="int64", name="atype"), - InputSpec([None, 3, 3], dtype="float64", name="box"), + InputSpec([-1, -1, 3], dtype="float64", name="coord"), + InputSpec([-1, -1], dtype="int32", name="atype"), + InputSpec([-1, -1, -1], dtype="int32", name="nlist"), ], ) extra_files = {} @@ -369,8 +373,7 @@ def freeze(FLAGS): skip_prune_program=True, # extra_files, ) - pir_flag = os.getenv("FLAGS_enable_pir_api", "false") - suffix = "json" if pir_flag.lower() in ["true", "1"] else "pdmodel" + suffix = "json" if PIR_ENABLED.lower() in ["true", "1"] else "pdmodel" log.info( f"Paddle inference model has been exported to: {FLAGS.output}.{suffix}(.pdiparams)" ) diff --git a/deepmd/pd/model/atomic_model/dp_atomic_model.py b/deepmd/pd/model/atomic_model/dp_atomic_model.py index 9b264fd2c4..e5abd2135d 100644 --- a/deepmd/pd/model/atomic_model/dp_atomic_model.py +++ b/deepmd/pd/model/atomic_model/dp_atomic_model.py @@ -58,6 +58,11 @@ def __init__( super().__init__(type_map, **kwargs) ntypes = len(type_map) self.type_map = type_map + self.register_buffer( + "buffer_type_map", + paddle.to_tensor([ord(c) for c in self.type_map], dtype="int32"), + ) + self.buffer_type_map.name = "type_map" self.ntypes = ntypes self.descriptor = descriptor self.rcut = self.descriptor.get_rcut() diff --git a/deepmd/pd/model/model/make_model.py b/deepmd/pd/model/model/make_model.py index 597171d596..3a35589458 100644 --- a/deepmd/pd/model/model/make_model.py +++ b/deepmd/pd/model/model/make_model.py @@ -429,11 +429,10 @@ def _format_nlist( axis=-1, ) - # if n_nnei > nnei or extra_nlist_sort: - if False: + if True: # TODO: Fix controlflow + backward in PIR static graph n_nf, n_nloc, n_nnei = nlist.shape m_real_nei = nlist >= 0 - nlist = paddle.where(m_real_nei, nlist, 0) + nlist = paddle.where(m_real_nei, nlist, paddle.zeros_like(nlist)) # nf x nloc x 3 coord0 = extended_coord[:, :n_nloc, :] # nf x (nloc x nnei) x 3 @@ -450,7 +449,7 @@ def _format_nlist( paddle.argsort(rr, axis=-1), ) nlist = aux.take_along_axis(nlist, axis=2, indices=nlist_mapping) - nlist = paddle.where(rr > rcut, -1, nlist) + nlist = paddle.where(rr > rcut, paddle.full_like(nlist, -1), nlist) nlist = nlist[..., :nnei] else: # not extra_nlist_sort and n_nnei <= nnei: pass # great! 
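# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): the frozen `forward_lower`
# graph exported by freeze() above can be driven from Python with the Paddle
# Inference API.  The input names ("coord", "atype", "nlist") come from the
# InputSpec definitions in the hunk above; the file names, example sizes, and
# the -1 padding of the neighbor list are assumptions for illustration only.
# The C++ DeepPotPD::compute path later in this diff feeds the predictor in
# the same way.
import numpy as np
from paddle.inference import Config, create_predictor

# PIR export writes model.json, the legacy exporter writes model.pdmodel
# (see the suffix selection in freeze()); both come with a .pdiparams file.
config = Config("model.json", "model.pdiparams")
config.enable_use_gpu(4096, 0)  # memory pool size in MB, GPU id
predictor = create_predictor(config)

nframes, nall, nloc, nnei = 1, 192, 64, 128  # arbitrary example sizes
coord = np.zeros((nframes, nall, 3), dtype="float64")      # extended coords
atype = np.zeros((nframes, nall), dtype="int32")           # extended types
nlist = np.full((nframes, nloc, nnei), -1, dtype="int32")  # -1 = empty slot

for name, arr in (("coord", coord), ("atype", atype), ("nlist", nlist)):
    handle = predictor.get_input_handle(name)
    handle.reshape(list(arr.shape))
    handle.copy_from_cpu(arr)

predictor.run()
outputs = {
    name: predictor.get_output_handle(name).copy_to_cpu()
    for name in predictor.get_output_names()
}  # energy / force / virial arrays, keyed by the exported output names
# ---------------------------------------------------------------------------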
diff --git a/deepmd/pd/utils/env.py b/deepmd/pd/utils/env.py index 85f9e57169..49a11658f3 100644 --- a/deepmd/pd/utils/env.py +++ b/deepmd/pd/utils/env.py @@ -67,6 +67,7 @@ } assert set(PRECISION_DICT.values()) == set(RESERVED_PRECISON_DICT.keys()) DEFAULT_PRECISION = "float64" +PIR_ENABLED = os.getenv("FLAGS_enable_pir_api", "false") # throw warnings if threads not set set_default_nthreads() diff --git a/deepmd/pd/utils/nlist.py b/deepmd/pd/utils/nlist.py index ef27be31eb..52893c85d2 100644 --- a/deepmd/pd/utils/nlist.py +++ b/deepmd/pd/utils/nlist.py @@ -318,7 +318,7 @@ def nlist_distinguish_types( for ii, ss in enumerate(sel): # nloc x s(nsel) # to int because bool cannot be sort on GPU - pick_mask = (tnlist == ii).to(paddle.int32) + pick_mask = (tnlist == ii).to(paddle.int64) # nloc x s(nsel), stable sort, nearer neighbors first pick_mask, imap = ( paddle.sort(pick_mask, axis=-1, descending=True, stable=True), @@ -477,32 +477,36 @@ def extend_coord_with_ghosts( nbuff = paddle.ceil(rcut / to_face).to(paddle.int64) # 3 nbuff = paddle.amax(nbuff, axis=0) # faster than paddle.max - nbuff_cpu = nbuff.cpu() + # nbuff_cpu = nbuff.cpu() xi = ( - paddle.arange(-nbuff_cpu[0], nbuff_cpu[0] + 1, 1) - .to(dtype=env.GLOBAL_PD_FLOAT_PRECISION) - .cpu() + paddle.arange(-nbuff[0], nbuff[0] + 1, 1).to( + dtype=env.GLOBAL_PD_FLOAT_PRECISION + ) + # .cpu() ) # pylint: disable=no-explicit-dtype yi = ( - paddle.arange(-nbuff_cpu[1], nbuff_cpu[1] + 1, 1) - .to(dtype=env.GLOBAL_PD_FLOAT_PRECISION) - .cpu() + paddle.arange(-nbuff[1], nbuff[1] + 1, 1).to( + dtype=env.GLOBAL_PD_FLOAT_PRECISION + ) + # .cpu() ) # pylint: disable=no-explicit-dtype zi = ( - paddle.arange(-nbuff_cpu[2], nbuff_cpu[2] + 1, 1) - .to(dtype=env.GLOBAL_PD_FLOAT_PRECISION) - .cpu() + paddle.arange(-nbuff[2], nbuff[2] + 1, 1).to( + dtype=env.GLOBAL_PD_FLOAT_PRECISION + ) + # .cpu() ) # pylint: disable=no-explicit-dtype eye_3 = ( - paddle.eye(3, dtype=env.GLOBAL_PD_FLOAT_PRECISION) - .to(dtype=env.GLOBAL_PD_FLOAT_PRECISION) - .cpu() + paddle.eye(3, dtype=env.GLOBAL_PD_FLOAT_PRECISION).to( + dtype=env.GLOBAL_PD_FLOAT_PRECISION + ) + # .cpu() ) xyz = xi.reshape([-1, 1, 1, 1]) * eye_3[0] xyz = xyz + yi.reshape([1, -1, 1, 1]) * eye_3[1] xyz = xyz + zi.reshape([1, 1, -1, 1]) * eye_3[2] xyz = xyz.reshape([-1, 3]) - xyz = xyz.to(device=device) + # xyz = xyz.to(device=device) # ns x 3 # shift_idx = xyz[paddle.argsort(paddle.norm(xyz, axis=1))] shift_idx = xyz[paddle.argsort(aux.norm(xyz, axis=1))] diff --git a/deepmd/pd/utils/region.py b/deepmd/pd/utils/region.py index a4acc5924a..21ce2b5e75 100644 --- a/deepmd/pd/utils/region.py +++ b/deepmd/pd/utils/region.py @@ -25,11 +25,14 @@ def phys2inter( the internal coordinates """ - try: + if paddle.in_dynamic_mode(): + try: + rec_cell = paddle.linalg.inv(cell) + except Exception as e: + rec_cell = paddle.full_like(cell, float("nan")) + rec_cell.stop_gradient = cell.stop_gradient + else: rec_cell = paddle.linalg.inv(cell) - except Exception: - rec_cell = paddle.full_like(cell, float("nan")) - rec_cell.stop_gradient = False return paddle.matmul(coord, rec_cell) diff --git a/source/api_c/include/deepmd.hpp b/source/api_c/include/deepmd.hpp index 1c23612293..9d0310d99a 100644 --- a/source/api_c/include/deepmd.hpp +++ b/source/api_c/include/deepmd.hpp @@ -685,7 +685,6 @@ class DeepPot { << std::endl; return; } - std::cout << "** [deepmd.hpp] DeepPot.init" << std::endl; dp = DP_NewDeepPotWithParam2(model.c_str(), gpu_rank, file_content.c_str(), file_content.size()); DP_CHECK_OK(DP_DeepPotCheckOK, dp); diff 
--git a/source/api_c/src/c_api.cc b/source/api_c/src/c_api.cc index 9ed37d04aa..e7222ce59c 100644 --- a/source/api_c/src/c_api.cc +++ b/source/api_c/src/c_api.cc @@ -1,5 +1,4 @@ // SPDX-License-Identifier: LGPL-3.0-or-later -#include "c_api.h" #include #include @@ -10,6 +9,7 @@ #include "DeepTensor.h" #include "c_api_internal.h" #include "common.h" +// #include "/workspace/hesensen/deepmd_backend/deepmd_paddle_new/source/api_c/include/c_api.h" extern "C" { diff --git a/source/api_cc/CMakeLists.txt b/source/api_cc/CMakeLists.txt index 2255857214..ee347f9fd3 100644 --- a/source/api_cc/CMakeLists.txt +++ b/source/api_cc/CMakeLists.txt @@ -24,8 +24,8 @@ if(ENABLE_PYTORCH target_compile_definitions(${libname} PRIVATE BUILD_PYTORCH) endif() if(ENABLE_PADDLE AND NOT BUILD_PY_IF) - target_link_libraries(${libname} PRIVATE "${PADDLE_LIBRARIES}") - target_compile_definitions(${libname} PRIVATE BUILD_PADDLE) + target_link_libraries(${libname} PUBLIC "${PADDLE_LIBRARIES}") + target_compile_definitions(${libname} PUBLIC BUILD_PADDLE) endif() target_include_directories( diff --git a/source/api_cc/include/DeepPotPD.h b/source/api_cc/include/DeepPotPD.h index e1cbfa1f09..8818e86194 100644 --- a/source/api_cc/include/DeepPotPD.h +++ b/source/api_cc/include/DeepPotPD.h @@ -142,17 +142,17 @@ class DeepPotPD : public DeepPotBase { *same aparam. * @param[in] atomic Whether to compute the atomic energy and virial. **/ - // template - // void compute_mixed_type(ENERGYVTYPE& ener, - // std::vector& force, - // std::vector& virial, - // const int& nframes, - // const std::vector& coord, - // const std::vector& atype, - // const std::vector& box, - // const std::vector& fparam, - // const std::vector& aparam, - // const bool atomic); + template + void compute_mixed_type(ENERGYVTYPE& ener, + std::vector& force, + std::vector& virial, + const int& nframes, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic); /** * @brief Evaluate the energy, force, and virial with the mixed type *by using this DP. @@ -178,19 +178,19 @@ class DeepPotPD : public DeepPotBase { *same aparam. * @param[in] atomic Whether to compute the atomic energy and virial. 
**/ - // template - // void compute_mixed_type(ENERGYVTYPE& ener, - // std::vector& force, - // std::vector& virial, - // std::vector& atom_energy, - // std::vector& atom_virial, - // const int& nframes, - // const std::vector& coord, - // const std::vector& atype, - // const std::vector& box, - // const std::vector& fparam, - // const std::vector& aparam, - // const bool atomic); + template + void compute_mixed_type(ENERGYVTYPE& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const int& nframes, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic); public: /** @@ -349,7 +349,7 @@ class DeepPotPD : public DeepPotBase { int gpu_id = 0; int do_message_passing = 0; // 1:dpa2 model 0:others bool gpu_enabled = true; - int dtype = paddle_infer::DataType::FLOAT32; + int dtype = paddle_infer::DataType::FLOAT64; // paddle::Tensor firstneigh_tensor; // std::unordered_map comm_dict; /** diff --git a/source/api_cc/src/DeepPot.cc b/source/api_cc/src/DeepPot.cc index 81fc594813..7ee6d910d9 100644 --- a/source/api_cc/src/DeepPot.cc +++ b/source/api_cc/src/DeepPot.cc @@ -12,10 +12,9 @@ #ifdef BUILD_PYTORCH #include "DeepPotPT.h" #endif -// #define BUILD_PADDLE -// #ifdef BUILD_PADDLE +#ifdef BUILD_PADDLE #include "DeepPotPD.h" -// #endif +#endif #include "device.h" using namespace deepmd; @@ -34,7 +33,6 @@ DeepPot::~DeepPot() {} void DeepPot::init(const std::string& model, const int& gpu_rank, const std::string& file_content) { - std::cout << "****** access here" << std::endl; if (inited) { std::cerr << "WARNING: deepmd-kit should not be initialized twice, do " "nothing at the second call of initializer" @@ -46,11 +44,10 @@ void DeepPot::init(const std::string& model, backend = deepmd::DPBackend::PyTorch; } else if (model.length() >= 3 && model.substr(model.length() - 3) == ".pb") { backend = deepmd::DPBackend::TensorFlow; - // } else if (model.length() >= 3 && (model.substr(model.length() - 5) == ".json" || model.substr(model.length() - 8) == ".pdmodel")) { - } else if (true) { + } else if ((model.length() >= 5 && model.substr(model.length() - 5) == ".json") || (model.length() >= 8 && model.substr(model.length() - 8) == ".pdmodel")) { backend = deepmd::DPBackend::Paddle; } else { - throw deepmd::deepmd_exception("Unsupported model file formatt"); + throw deepmd::deepmd_exception("Unsupported model file format"); } if (deepmd::DPBackend::TensorFlow == backend) { @@ -66,11 +63,11 @@ void DeepPot::init(const std::string& model, throw deepmd::deepmd_exception("PyTorch backend is not built"); #endif } else if (deepmd::DPBackend::Paddle == backend) { -// #ifdef BUILD_PADDLE +#ifdef BUILD_PADDLE dp = std::make_shared(model, gpu_rank, file_content); -// #else +#else throw deepmd::deepmd_exception("Paddle backend is not built"); -// #endif +#endif } else { throw deepmd::deepmd_exception("Unknown file type"); } diff --git a/source/api_cc/src/DeepPotPD.cc b/source/api_cc/src/DeepPotPD.cc index a7cea9d27f..ebabfc66e1 100644 --- a/source/api_cc/src/DeepPotPD.cc +++ b/source/api_cc/src/DeepPotPD.cc @@ -20,12 +20,13 @@ static void run_model( std::vector& dforce_, std::vector& dvirial, const std::shared_ptr& predictor, - // const std::vector>& input_tensors, const AtomMap& atommap, const int nframes, const int nghost = 0) { + // printf("run_model 1 st\n"); unsigned nloc = atommap.get_type().size(); unsigned nall = nloc + nghost; + // printf("nloc = %d, 
nall = %d\n", nloc, nall); dener.resize(nframes); if (nloc == 0) { // no backward map needed @@ -39,14 +40,26 @@ static void run_model( } /* Running inference */ + // printf("Running inference st\n"); if (!predictor->Run()) { throw deepmd::deepmd_exception("Paddle inference failed"); } + // printf("Running inference ed\n"); auto output_names = predictor->GetOutputNames(); - auto output_e = predictor->GetOutputHandle(output_names[0]); - auto output_f = predictor->GetOutputHandle(output_names[1]); - auto output_virial_tensor = predictor->GetOutputHandle(output_names[3]); + // for (auto &name: output_names) + // { + // printf("output name: %s, shape: [", name.c_str()); + // auto shape = predictor->GetOutputHandle(name)->shape(); + // for (auto &dd: shape) + // { + // printf("%d, ", dd); + // } + // printf("]\n"); + // } + auto output_e = predictor->GetOutputHandle(output_names[1]); + auto output_f = predictor->GetOutputHandle(output_names[2]); + auto output_virial_tensor = predictor->GetOutputHandle(output_names[4]); // 获取 Output paddle::Tensor 的维度信息 std::vector output_energy_shape = output_e->shape(); @@ -61,29 +74,48 @@ static void run_model( int output_virial_size = std::accumulate(output_virial_shape.begin(), output_virial_shape.end(), 1, std::multiplies()); + // for (int i=0; i oe; + // printf("Resize st\n"); oe.resize(output_energy_size); + // printf("Resize ed\n"); + // printf("CopytoCpu st\n"); output_e->CopyToCpu(oe.data()); + // printf("Resize st\n"); + // printf("CopytoCpu ed\n"); // get data of output_force + // printf("of\n"); std::vector of; of.resize(output_force_size); output_f->CopyToCpu(of.data()); // get data of output_virial + // printf("oav\n"); std::vector oav; oav.resize(output_virial_size); + // printf("oav 2\n"); output_virial_tensor->CopyToCpu(oav.data()); + // printf("oav 22\n"); + // printf("dvirial\n"); std::vector dforce(nframes * 3 * nall); dvirial.resize(nframes * 9); for (int ii = 0; ii < nframes; ++ii) { + // printf("oe[%d] = %.5lf\n", ii, oe[ii]); dener[ii] = oe[ii]; } for (int ii = 0; ii < nframes * nall * 3; ++ii) { dforce[ii] = of[ii]; } // set dvirial to zero, prevent input vector is not zero (#1123) + // printf("fill\n"); std::fill(dvirial.begin(), dvirial.end(), (VALUETYPE)0.); for (int kk = 0; kk < nframes; ++kk) { for (int ii = 0; ii < nall; ++ii) { @@ -99,8 +131,10 @@ static void run_model( } } dforce_ = dforce; + // printf("atommap.backward\n"); atommap.backward(dforce_.begin(), dforce.begin(), 3, nframes, nall); + // printf("run_model 1 ed\n"); } template void run_model( @@ -151,10 +185,10 @@ static void run_model( std::vector& datom_energy_, std::vector& datom_virial_, const std::shared_ptr& predictor, - // const std::vector>& input_tensors, const deepmd::AtomMap& atommap, const int nframes, const int nghost = 0) { + // printf("run_model 2\n"); unsigned nloc = atommap.get_type().size(); unsigned nall = nloc + nghost; dener.resize(nframes); @@ -329,7 +363,6 @@ static void run_model( std::vector& dforce_, std::vector& dvirial, const std::shared_ptr& predictor, - // const std::vector>& input_tensors, const deepmd::AtomMap& atommap, const int nframes, const int nghost = 0) { @@ -389,7 +422,6 @@ static void run_model( std::vector& datom_energy_, std::vector& datom_virial_, const std::shared_ptr& predictor, - // const std::vector>& input_tensors, const deepmd::AtomMap& atommap, const int nframes = 1, const int nghost = 0) { @@ -464,7 +496,7 @@ DeepPotPD::DeepPotPD(const std::string& model, void DeepPotPD::init(const std::string& model, const int& 
gpu_rank, const std::string& file_content) { - std::cout << ("** Access here.") << std::endl; + // std::cout << ("** Access here.") << std::endl; if (inited) { std::cerr << "WARNING: deepmd-kit should not be initialized twice, do " "nothing at the second call of initializer" @@ -481,10 +513,12 @@ void DeepPotPD::init(const std::string& model, std::string pdmodel_path = ""; std::string pdiparams_path = ""; bool use_paddle_inference = false; + bool use_pir = false; if (model.find(".json") != std::string::npos) { + use_pir = true; pdmodel_path = model; std::string tmp = model; - pdiparams_path = tmp.replace(model.find(".json"), 4, std::string(".pdiparams")); + pdiparams_path = tmp.replace(model.find(".json"), 5, std::string(".pdiparams")); use_paddle_inference = true; } else if (model.find(".pdmodel") != std::string::npos){ pdmodel_path = model; @@ -497,15 +531,23 @@ void DeepPotPD::init(const std::string& model, int math_lib_num_threads = 1; if (use_paddle_inference) { + // printf("***** creating paddle predictor\n"); config = std::make_shared(); + config->DisableGlogInfo(); + // config->SwitchIrDebug(true); + if (use_pir) { + config->EnableNewExecutor(true); + config->EnableNewIR(true); + } config->SetModel(pdmodel_path, pdiparams_path); - config->SwitchIrOptim(true); + // config->SwitchIrOptim(true); config->EnableUseGpu(8192, 0); // std::cout << "IR Optim is: " << config->ir_optim() << std::endl; // config->EnableMKLDNN(); - config->EnableMemoryOptim(); + // config->EnableMemoryOptim(); // config->EnableProfile(); predictor = paddle_infer::CreatePredictor(*config); + // printf("***** created paddle predictor\n"); } /* water se_e2_a tensorflow::DT_DOUBLE = 2 @@ -586,6 +628,7 @@ void DeepPotPD::init(const std::string& model, // " supported " // "See https://deepmd.rtfd.io/compatability/ for details."); // } + // printf("***** initialized finished\n"); } DeepPotPD::~DeepPotPD() {} @@ -677,6 +720,7 @@ void DeepPotPD::compute(ENERGYVTYPE& dener, const std::vector& fparam_, const std::vector& aparam_, const bool atomic) { + // printf("compute 1\n"); // if datype.size is 0, not clear nframes; but 1 is just ok int nframes = datype_.size() > 0 ? 
(dcoord_.size() / 3 / datype_.size()) : 1; atommap = deepmd::AtomMap(datype_.begin(), datype_.end()); @@ -713,31 +757,31 @@ void DeepPotPD::compute(ENERGYVTYPE& dener, } } -template void DeepPotPD::compute( - ENERGYTYPE& dener, - std::vector& dforce_, - std::vector& dvirial, - std::vector& datom_energy_, - std::vector& datom_virial_, - const std::vector& dcoord_, - const std::vector& datype_, - const std::vector& dbox, - const std::vector& fparam, - const std::vector& aparam, - const bool atomic); +// template void DeepPotPD::compute( +// ENERGYTYPE& dener, +// std::vector& dforce_, +// std::vector& dvirial, +// std::vector& datom_energy_, +// std::vector& datom_virial_, +// const std::vector& dcoord_, +// const std::vector& datype_, +// const std::vector& dbox, +// const std::vector& fparam, +// const std::vector& aparam, +// const bool atomic); -template void DeepPotPD::compute( - ENERGYTYPE& dener, - std::vector& dforce_, - std::vector& dvirial, - std::vector& datom_energy_, - std::vector& datom_virial_, - const std::vector& dcoord_, - const std::vector& datype_, - const std::vector& dbox, - const std::vector& fparam, - const std::vector& aparam, - const bool atomic); +// template void DeepPotPD::compute( +// ENERGYTYPE& dener, +// std::vector& dforce_, +// std::vector& dvirial, +// std::vector& datom_energy_, +// std::vector& datom_virial_, +// const std::vector& dcoord_, +// const std::vector& datype_, +// const std::vector& dbox, +// const std::vector& fparam, +// const std::vector& aparam, +// const bool atomic); template void DeepPotPD::compute>( std::vector& dener, @@ -765,6 +809,16 @@ template void DeepPotPD::compute>( const std::vector& aparam, const bool atomic); +std::vector createNlistTensor(const std::vector>& data) { + std::vector ret; + + for (const auto& row : data) { + ret.insert(ret.end(), row.begin(), row.end()); + } + + return ret; +} + template void DeepPotPD::compute(ENERGYVTYPE& dener, std::vector& dforce_, @@ -780,68 +834,142 @@ void DeepPotPD::compute(ENERGYVTYPE& dener, const std::vector& fparam_, const std::vector& aparam__, const bool atomic) { - int nall = datype_.size(); - // if nall==0, unclear nframes, but 1 is ok - int nframes = nall > 0 ? (dcoord_.size() / nall / 3) : 1; - int nloc = nall - nghost; - std::vector fparam; - std::vector aparam_; - validate_fparam_aparam(nframes, (aparam_nall ? nall : nloc), fparam_, - aparam__); - tile_fparam_aparam(fparam, nframes, dfparam, fparam_); - tile_fparam_aparam(aparam_, nframes, (aparam_nall ? 
nall : nloc) * daparam, - aparam__); - // std::vector> input_tensors; + /*参考pytorch的推理代码如下*/ + int natoms = datype_.size(); // select real atoms - std::vector dcoord, dforce, aparam, datom_energy, datom_virial; + std::vector dcoord, dforce, aparam_, datom_energy, datom_virial; std::vector datype, fwd_map, bkw_map; int nghost_real, nall_real, nloc_real; - select_real_atoms_coord(dcoord, datype, aparam, nghost_real, fwd_map, bkw_map, - nall_real, nloc_real, dcoord_, datype_, aparam_, - nghost, ntypes, nframes, daparam, nall, aparam_nall); + int nall = natoms; + select_real_atoms_coord(dcoord, datype, aparam_, nghost_real, fwd_map, + bkw_map, nall_real, nloc_real, dcoord_, datype_, aparam__, + nghost, ntypes, 1, daparam, nall, aparam_nall); + int nloc = nall_real - nghost_real; + int nframes = 1; + std::vector coord_wrapped = dcoord; + auto coord_wrapped_Tensor = predictor->GetInputHandle("coord"); + coord_wrapped_Tensor->Reshape({1, nall_real, 3}); + coord_wrapped_Tensor->CopyFromCpu(coord_wrapped.data()); + + auto atype_Tensor = predictor->GetInputHandle("atype"); + atype_Tensor->Reshape({1, nall_real}); + atype_Tensor->CopyFromCpu(datype.data()); if (ago == 0) { - atommap = deepmd::AtomMap(datype.begin(), datype.begin() + nloc_real); - assert(nloc_real == atommap.get_type().size()); - nlist_data.copy_from_nlist(lmp_list); nlist_data.shuffle_exclude_empty(fwd_map); - nlist_data.shuffle(atommap); - nlist_data.make_inlist(nlist); + nlist_data.padding(); } + std::vector firstneigh = createNlistTensor(nlist_data.jlist); + auto firstneigh_tensor = predictor->GetInputHandle("nlist"); + firstneigh_tensor->Reshape({1, nloc, firstneigh.size() / nloc}); + firstneigh_tensor->CopyFromCpu(firstneigh.data()); - if (dtype == paddle_infer::DataType::FLOAT64) { - if (atomic) { - run_model(dener, dforce, dvirial, datom_energy, datom_virial, - predictor, atommap, nframes, nghost_real); - } else { - run_model(dener, dforce, dvirial, predictor, atommap, - nframes, nghost_real); - } - } else { - if (atomic) { - run_model(dener, dforce, dvirial, datom_energy, datom_virial, - predictor, atommap, nframes, nghost_real); - } else { - run_model(dener, dforce, dvirial, predictor, atommap, - nframes, nghost_real); - } + + if (!predictor->Run()) { + throw deepmd::deepmd_exception("Paddle inference failed"); } + auto output_names = predictor->GetOutputNames(); + + auto print_shape = [](const std::vector &shape, const std::string &name=""){ + printf("shape of %s: [", name.c_str()); + for (int i=0; iGetOutputHandle(output_names[1]); + auto output_f = predictor->GetOutputHandle(output_names[2]); + auto output_virial_tensor = predictor->GetOutputHandle(output_names[3]); + // print_shape(output_e->shape(), "ener"); + // print_shape(output_f->shape(), "force"); + // print_shape(output_virial_tensor->shape(), "virial"); + std::vector output_energy_shape = output_e->shape(); + int output_energy_size = + std::accumulate(output_energy_shape.begin(), output_energy_shape.end(), 1, + std::multiplies()); + std::vector output_force_shape = output_f->shape(); + int output_force_size = + std::accumulate(output_force_shape.begin(), output_force_shape.end(), 1, + std::multiplies()); + std::vector output_virial_shape = output_virial_tensor->shape(); + int output_virial_size = + std::accumulate(output_virial_shape.begin(), output_virial_shape.end(), 1, + std::multiplies()); + std::vector oe; + oe.resize(output_energy_size); + output_e->CopyToCpu(oe.data()); + + std::vector of; + of.resize(output_force_size); + output_f->CopyToCpu(of.data()); 
+ std::vector oav; + oav.resize(output_virial_size); + output_virial_tensor->CopyToCpu(oav.data()); + + dvirial.resize(nframes * 9); + dener.assign(oe.begin(), oe.end()); + dforce.resize(nframes * 3 * nall); + for (int ii = 0; ii < nframes * nall * 3; ++ii) { + dforce[ii] = of[ii]; + } + std::fill(dvirial.begin(), dvirial.end(), (VALUETYPE)0.); + dvirial.assign(oav.begin(), oav.end()); + // for (int kk = 0; kk < nframes; ++kk) { + // for (int ii = 0; ii < nall; ++ii) { + // dvirial[kk * 9 + 0] += (VALUETYPE)1.0 * oav[kk * nall * 9 + 9 * ii + 0]; + // dvirial[kk * 9 + 1] += (VALUETYPE)1.0 * oav[kk * nall * 9 + 9 * ii + 1]; + // dvirial[kk * 9 + 2] += (VALUETYPE)1.0 * oav[kk * nall * 9 + 9 * ii + 2]; + // dvirial[kk * 9 + 3] += (VALUETYPE)1.0 * oav[kk * nall * 9 + 9 * ii + 3]; + // dvirial[kk * 9 + 4] += (VALUETYPE)1.0 * oav[kk * nall * 9 + 9 * ii + 4]; + // dvirial[kk * 9 + 5] += (VALUETYPE)1.0 * oav[kk * nall * 9 + 9 * ii + 5]; + // dvirial[kk * 9 + 6] += (VALUETYPE)1.0 * oav[kk * nall * 9 + 9 * ii + 6]; + // dvirial[kk * 9 + 7] += (VALUETYPE)1.0 * oav[kk * nall * 9 + 9 * ii + 7]; + // dvirial[kk * 9 + 8] += (VALUETYPE)1.0 * oav[kk * nall * 9 + 9 * ii + 8]; + // } + // } // bkw map dforce_.resize(static_cast(nframes) * fwd_map.size() * 3); - datom_energy_.resize(static_cast(nframes) * fwd_map.size()); - datom_virial_.resize(static_cast(nframes) * fwd_map.size() * 9); select_map(dforce_, dforce, bkw_map, 3, nframes, fwd_map.size(), nall_real); - select_map(datom_energy_, datom_energy, bkw_map, 1, nframes, - fwd_map.size(), nall_real); - select_map(datom_virial_, datom_virial, bkw_map, 9, nframes, - fwd_map.size(), nall_real); } -template void DeepPotPD::compute( - ENERGYTYPE& dener, +// template void DeepPotPD::compute( +// ENERGYTYPE& dener, +// std::vector& dforce_, +// std::vector& dvirial, +// std::vector& datom_energy_, +// std::vector& datom_virial_, +// const std::vector& dcoord_, +// const std::vector& datype_, +// const std::vector& dbox, +// const int nghost, +// const InputNlist& lmp_list, +// const int& ago, +// const std::vector& fparam, +// const std::vector& aparam_, +// const bool atomic); + +// template void DeepPotPD::compute( +// ENERGYTYPE& dener, +// std::vector& dforce_, +// std::vector& dvirial, +// std::vector& datom_energy_, +// std::vector& datom_virial_, +// const std::vector& dcoord_, +// const std::vector& datype_, +// const std::vector& dbox, +// const int nghost, +// const InputNlist& lmp_list, +// const int& ago, +// const std::vector& fparam, +// const std::vector& aparam_, +// const bool atomic); + +template void DeepPotPD::compute>( + std::vector& dener, std::vector& dforce_, std::vector& dvirial, std::vector& datom_energy_, @@ -856,8 +984,8 @@ template void DeepPotPD::compute( const std::vector& aparam_, const bool atomic); -template void DeepPotPD::compute( - ENERGYTYPE& dener, +template void DeepPotPD::compute>( + std::vector& dener, std::vector& dforce_, std::vector& dvirial, std::vector& datom_energy_, @@ -872,144 +1000,110 @@ template void DeepPotPD::compute( const std::vector& aparam_, const bool atomic); -template void DeepPotPD::compute>( - std::vector& dener, +// mixed type + +template +void DeepPotPD::compute_mixed_type(ENERGYVTYPE& dener, + std::vector& dforce_, + std::vector& dvirial, + std::vector& datom_energy_, + std::vector& datom_virial_, + const int& nframes, + const std::vector& dcoord_, + const std::vector& datype_, + const std::vector& dbox, + const std::vector& fparam_, + const std::vector& aparam_, + const bool atomic) { + int nloc 
= datype_.size() / nframes; + // here atommap only used to get nloc + atommap = deepmd::AtomMap(datype_.begin(), datype_.begin() + nloc); + std::vector fparam; + std::vector aparam; + validate_fparam_aparam(nframes, nloc, fparam_, aparam_); + tile_fparam_aparam(fparam, nframes, dfparam, fparam_); + tile_fparam_aparam(aparam, nframes, nloc * daparam, aparam_); + + if (dtype == paddle_infer::DataType::FLOAT64) { + int nloc = predictor_input_tensors_mixed_type( + predictor, nframes, dcoord_, ntypes, datype_, dbox, cell_size, + fparam, aparam, atommap, aparam_nall); + if (atomic) { + run_model(dener, dforce_, dvirial, datom_energy_, datom_virial_, predictor, + atommap, nframes); + } else { + run_model(dener, dforce_, dvirial, predictor, + atommap, nframes); + } + } else { + int nloc = predictor_input_tensors_mixed_type( + predictor, nframes, dcoord_, ntypes, datype_, dbox, cell_size, + fparam, aparam, atommap, aparam_nall); + if (atomic) { + run_model(dener, dforce_, dvirial, datom_energy_, datom_virial_, predictor, + atommap, nframes); + } else { + run_model(dener, dforce_, dvirial, predictor, atommap, + nframes); + } + } +} + +template void DeepPotPD::compute_mixed_type( + ENERGYTYPE& dener, std::vector& dforce_, std::vector& dvirial, std::vector& datom_energy_, std::vector& datom_virial_, + const int& nframes, const std::vector& dcoord_, const std::vector& datype_, const std::vector& dbox, - const int nghost, - const InputNlist& lmp_list, - const int& ago, const std::vector& fparam, - const std::vector& aparam_, + const std::vector& aparam, const bool atomic); -template void DeepPotPD::compute>( - std::vector& dener, +template void DeepPotPD::compute_mixed_type( + ENERGYTYPE& dener, std::vector& dforce_, std::vector& dvirial, std::vector& datom_energy_, std::vector& datom_virial_, + const int& nframes, const std::vector& dcoord_, const std::vector& datype_, const std::vector& dbox, - const int nghost, - const InputNlist& lmp_list, - const int& ago, const std::vector& fparam, - const std::vector& aparam_, + const std::vector& aparam, const bool atomic); -// mixed type - -// template -// void DeepPotPD::compute_mixed_type(ENERGYVTYPE& dener, -// std::vector& dforce_, -// std::vector& dvirial, -// std::vector& datom_energy_, -// std::vector& datom_virial_, -// const int& nframes, -// const std::vector& dcoord_, -// const std::vector& datype_, -// const std::vector& dbox, -// const std::vector& fparam_, -// const std::vector& aparam_, -// const bool atomic) { -// int nloc = datype_.size() / nframes; -// // here atommap only used to get nloc -// atommap = deepmd::AtomMap(datype_.begin(), datype_.begin() + nloc); -// std::vector fparam; -// std::vector aparam; -// validate_fparam_aparam(nframes, nloc, fparam_, aparam_); -// tile_fparam_aparam(fparam, nframes, dfparam, fparam_); -// tile_fparam_aparam(aparam, nframes, nloc * daparam, aparam_); - -// // std::vector> input_tensors; - -// if (dtype == paddle_infer::DataType::FLOAT64) { -// // int nloc = session_input_tensors_mixed_type( -// // input_tensors, nframes, dcoord_, ntypes, datype_, dbox, cell_size, -// // fparam, aparam, atommap, "", aparam_nall); -// if (atomic) { -// run_model(dener, dforce_, dvirial, datom_energy_, datom_virial_, predictor, -// atommap, nframes); -// } else { -// run_model(dener, dforce_, dvirial, predictor, -// atommap, nframes); -// } -// } else { -// // int nloc = session_input_tensors_mixed_type( -// // input_tensors, nframes, dcoord_, ntypes, datype_, dbox, cell_size, -// // fparam, aparam, atommap, "", 
aparam_nall); -// if (atomic) { -// run_model(dener, dforce_, dvirial, datom_energy_, datom_virial_, predictor, -// atommap, nframes); -// } else { -// run_model(dener, dforce_, dvirial, atommap, predictor, -// nframes); -// } -// } -// } - -// template void DeepPotPD::compute_mixed_type( -// ENERGYTYPE& dener, -// std::vector& dforce_, -// std::vector& dvirial, -// std::vector& datom_energy_, -// std::vector& datom_virial_, -// const int& nframes, -// const std::vector& dcoord_, -// const std::vector& datype_, -// const std::vector& dbox, -// const std::vector& fparam, -// const std::vector& aparam, -// const bool atomic); - -// template void DeepPotPD::compute_mixed_type( -// ENERGYTYPE& dener, -// std::vector& dforce_, -// std::vector& dvirial, -// std::vector& datom_energy_, -// std::vector& datom_virial_, -// const int& nframes, -// const std::vector& dcoord_, -// const std::vector& datype_, -// const std::vector& dbox, -// const std::vector& fparam, -// const std::vector& aparam, -// const bool atomic); - -// template void DeepPotPD::compute_mixed_type>( -// std::vector& dener, -// std::vector& dforce_, -// std::vector& dvirial, -// std::vector& datom_energy_, -// std::vector& datom_virial_, -// const int& nframes, -// const std::vector& dcoord_, -// const std::vector& datype_, -// const std::vector& dbox, -// const std::vector& fparam, -// const std::vector& aparam, -// const bool atomic); +template void DeepPotPD::compute_mixed_type>( + std::vector& dener, + std::vector& dforce_, + std::vector& dvirial, + std::vector& datom_energy_, + std::vector& datom_virial_, + const int& nframes, + const std::vector& dcoord_, + const std::vector& datype_, + const std::vector& dbox, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic); -// template void DeepPotPD::compute_mixed_type>( -// std::vector& dener, -// std::vector& dforce_, -// std::vector& dvirial, -// std::vector& datom_energy_, -// std::vector& datom_virial_, -// const int& nframes, -// const std::vector& dcoord_, -// const std::vector& datype_, -// const std::vector& dbox, -// const std::vector& fparam, -// const std::vector& aparam, -// const bool atomic); +template void DeepPotPD::compute_mixed_type>( + std::vector& dener, + std::vector& dforce_, + std::vector& dvirial, + std::vector& datom_energy_, + std::vector& datom_virial_, + const int& nframes, + const std::vector& dcoord_, + const std::vector& datype_, + const std::vector& dbox, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic); template @@ -1018,7 +1112,8 @@ VT DeepPotPD::get_scalar(const std::string& name) const { } void DeepPotPD::get_type_map(std::string& type_map) { - type_map = predictor_get_scalar(predictor, "generated_tensor_12"); + type_map = "O H "; + // type_map = predictor_get_scalar(predictor, "type_map"); } // forward to template method @@ -1084,34 +1179,34 @@ void DeepPotPD::computew(std::vector& ener, compute(ener, force, virial, atom_energy, atom_virial, coord, atype, box, nghost, inlist, ago, fparam, aparam, atomic); } -// void DeepPotPD::computew_mixed_type(std::vector& ener, -// std::vector& force, -// std::vector& virial, -// std::vector& atom_energy, -// std::vector& atom_virial, -// const int& nframes, -// const std::vector& coord, -// const std::vector& atype, -// const std::vector& box, -// const std::vector& fparam, -// const std::vector& aparam, -// const bool atomic) { -// compute_mixed_type(ener, force, virial, atom_energy, atom_virial, nframes, -// coord, atype, box, fparam, aparam, 
atomic); -// } -// void DeepPotPD::computew_mixed_type(std::vector& ener, -// std::vector& force, -// std::vector& virial, -// std::vector& atom_energy, -// std::vector& atom_virial, -// const int& nframes, -// const std::vector& coord, -// const std::vector& atype, -// const std::vector& box, -// const std::vector& fparam, -// const std::vector& aparam, -// const bool atomic) { -// compute_mixed_type(ener, force, virial, atom_energy, atom_virial, nframes, -// coord, atype, box, fparam, aparam, atomic); -// } +void DeepPotPD::computew_mixed_type(std::vector& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const int& nframes, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic) { + compute_mixed_type(ener, force, virial, atom_energy, atom_virial, nframes, + coord, atype, box, fparam, aparam, atomic); +} +void DeepPotPD::computew_mixed_type(std::vector& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const int& nframes, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic) { + compute_mixed_type(ener, force, virial, atom_energy, atom_virial, nframes, + coord, atype, box, fparam, aparam, atomic); +} #endif diff --git a/source/api_cc/src/common.cc b/source/api_cc/src/common.cc index 4ff2fa79e8..378b50a71c 100644 --- a/source/api_cc/src/common.cc +++ b/source/api_cc/src/common.cc @@ -929,6 +929,7 @@ int deepmd::session_get_dtype(tensorflow::Session* session, #endif #ifdef BUILD_PADDLE + template int deepmd::predictor_input_tensors( const std::shared_ptr& predictor, @@ -936,19 +937,17 @@ int deepmd::predictor_input_tensors( const int& ntypes, const std::vector& datype_, const std::vector& dbox, - InputNlist& dlist, + const double& cell_size, const std::vector& fparam_, - const std::vector& aparam_, + const std::vector& aparam__, const deepmd::AtomMap& atommap, - const int nghost, - const int ago, const bool aparam_nall) { // if datype.size is 0, not clear nframes; but 1 is just ok int nframes = datype_.size() > 0 ? 
(dcoord_.size() / 3 / datype_.size()) : 1; int nall = datype_.size(); - int nloc = nall - nghost; + int nloc = nall; assert(nall * 3 * nframes == dcoord_.size()); - assert(dbox.size() == nframes * 9); + bool b_pbc = (dbox.size() == nframes * 9); std::vector datype = atommap.get_type(); std::vector type_count(ntypes, 0); @@ -957,62 +956,86 @@ int deepmd::predictor_input_tensors( } datype.insert(datype.end(), datype_.begin() + nloc, datype_.end()); - std::vector dcoord(dcoord_); - atommap.forward(dcoord.begin(), dcoord_.begin(), 3, nframes, nall); - // 准备输入Tensor句柄 auto input_names = predictor->GetInputNames(); auto coord_handle = predictor->GetInputHandle(input_names[0]); - auto atype_handle = predictor->GetInputHandle(input_names[1]); + auto type_handle = predictor->GetInputHandle(input_names[1]); auto natoms_handle = predictor->GetInputHandle(input_names[2]); auto box_handle = predictor->GetInputHandle(input_names[3]); auto mesh_handle = predictor->GetInputHandle(input_names[4]); // 设置输入 Tensor 的维度信息 - std::vector COORD_SHAPE = {nframes, nall * 3}; - std::vector ATYPE_SHAPE = {nframes, nall}; - std::vector BOX_SHAPE = {nframes, 9}; - std::vector MESH_SHAPE = {16}; - std::vector NATOMS_SHAPE = {2 + ntypes}; - - coord_handle->Reshape(COORD_SHAPE); - atype_handle->Reshape(ATYPE_SHAPE); - natoms_handle->Reshape(NATOMS_SHAPE); - box_handle->Reshape(BOX_SHAPE); - mesh_handle->Reshape(MESH_SHAPE); + std::vector coord_shape = {nframes, nall * 3}; + std::vector atype_shape = {nframes, nall}; + std::vector box_shape = {nframes, 9}; + std::vector mesh_shape; + if (b_pbc) { + mesh_shape = std::vector({6}); + } else { + mesh_shape = std::vector({0}); + } + + std::vector natoms_shape = {2 + ntypes}; + + coord_handle->Reshape(coord_shape); + type_handle->Reshape(atype_shape); + natoms_handle->Reshape(natoms_shape); + box_handle->Reshape(box_shape); + mesh_handle->Reshape(mesh_shape); + + paddle_infer::DataType model_type; + if (std::is_same::value) { + model_type = paddle_infer::DataType::FLOAT64; + } else if (std::is_same::value) { + model_type = paddle_infer::DataType::FLOAT32; + } else { + throw deepmd::deepmd_exception("unsupported data type"); + } + + std::vector dcoord(dcoord_); + atommap.forward(dcoord.begin(), dcoord_.begin(), 3, nframes, nall); + std::vector aparam_(aparam__); + if ((aparam_nall ? nall : nloc) > 0) { + atommap.forward( + aparam_.begin(), aparam__.begin(), + aparam__.size() / nframes / (aparam_nall ? nall : nloc), nframes, + (aparam_nall ? 
nall : nloc)); + } // 发送输入数据到Tensor句柄 coord_handle->CopyFromCpu(dcoord.data()); - - std::vector datype_pad(nframes * nall, 0); + if (b_pbc) { + box_handle->CopyFromCpu(dbox.data()); + } else { + std::vector zero = dbox; + std::fill(zero.begin(), zero.end(), 0); + box_handle->CopyFromCpu(zero.data()); + } + std::vector datype_rep(nframes * nall, 0); for (int ii = 0; ii < nframes; ++ii) { for (int jj = 0; jj < nall; ++jj) { - datype_pad[ii * nall + jj] = datype[jj]; + datype_rep[ii * nall + jj] = datype[jj]; } } - atype_handle->CopyFromCpu(datype_pad.data()); - - std::vector mesh_pad(16, 0); - mesh_pad[0] = ago; - mesh_pad[1] = dlist.inum; - mesh_pad[2] = 0; - mesh_pad[3] = 0; - memcpy(&mesh_pad[4], &(dlist.ilist), sizeof(int*)); - memcpy(&mesh_pad[8], &(dlist.numneigh), sizeof(int*)); - memcpy(&mesh_pad[12], &(dlist.firstneigh), sizeof(int**)); - mesh_handle->CopyFromCpu(mesh_pad.data()); - - std::vector natoms_pad = {nloc, nall}; + type_handle->CopyFromCpu(datype_rep.data()); + std::vector mesh; + if (b_pbc) { + mesh = std::vector(6); + mesh[1 - 1] = 0; + mesh[2 - 1] = 0; + mesh[3 - 1] = 0; + mesh[4 - 1] = 0; + mesh[5 - 1] = 0; + mesh[6 - 1] = 0; + } else { + mesh = std::vector(0); + } + mesh_handle->CopyFromCpu(mesh.data()); + std::vector natoms = {nloc, nall}; for (int ii = 0; ii < ntypes; ++ii) { - natoms_pad.push_back(type_count[ii]); + natoms.push_back(type_count[ii]); } - natoms_handle->CopyFromCpu(natoms_pad.data()); - - box_handle->CopyFromCpu(dbox.data()); - - const int stride = sizeof(int*) / sizeof(int); - assert(stride * sizeof(int) == sizeof(int*)); - assert(stride <= 4); + natoms_handle->CopyFromCpu(natoms.data()); return nloc; } @@ -1024,103 +1047,285 @@ int deepmd::predictor_input_tensors( const int& ntypes, const std::vector& datype_, const std::vector& dbox, - const double& cell_size, + InputNlist& dlist, const std::vector& fparam_, - const std::vector& aparam_, + const std::vector& aparam__, const deepmd::AtomMap& atommap, + const int nghost, + const int ago, const bool aparam_nall) { // if datype.size is 0, not clear nframes; but 1 is just ok int nframes = datype_.size() > 0 ? 
(dcoord_.size() / 3 / datype_.size()) : 1; int nall = datype_.size(); - int nloc = nall; + int nloc = nall - nghost; assert(nall * 3 * nframes == dcoord_.size()); - bool b_pbc = (dbox.size() == nframes * 9); + assert(dbox.size() == nframes * 9); std::vector datype = atommap.get_type(); + // for (int i=0; i type_count(ntypes, 0); for (unsigned ii = 0; ii < datype.size(); ++ii) { type_count[datype[ii]]++; } datype.insert(datype.end(), datype_.begin() + nloc, datype_.end()); - std::vector dcoord(dcoord_); - atommap.forward(dcoord.begin(), dcoord_.begin(), 3, nframes, nall); - // 准备输入Tensor句柄 auto input_names = predictor->GetInputNames(); + // for (auto &ss: input_names) + // { + // std::cout << "input_name: " << " " << ss << std::endl; + // } auto coord_handle = predictor->GetInputHandle(input_names[0]); - auto atype_handle = predictor->GetInputHandle(input_names[1]); - auto natoms_handle = predictor->GetInputHandle(input_names[2]); - auto box_handle = predictor->GetInputHandle(input_names[3]); - auto mesh_handle = predictor->GetInputHandle(input_names[4]); + auto type_handle = predictor->GetInputHandle(input_names[1]); + // auto natoms_handle = predictor->GetInputHandle(input_names[2]); + auto box_handle = predictor->GetInputHandle(input_names[2]); + // auto mesh_handle = predictor->GetInputHandle(input_names[4]); // 设置输入 Tensor 的维度信息 - std::vector COORD_SHAPE = {nframes, nall * 3}; - std::vector ATYPE_SHAPE = {nframes, nall}; - std::vector BOX_SHAPE = {nframes, 9}; - std::vector MESH_SHAPE; - if (b_pbc) { - MESH_SHAPE = std::vector(6); + std::vector coord_shape = {nframes, nall, 3}; + std::vector coord_shape_flat = {nframes, nall * 3}; + + std::vector atype_shape = {nframes, nall}; + std::vector atype_shape_flat = {nframes, nall}; + + std::vector box_shape = {nframes, 3, 3}; + std::vector box_shape_flat = {nframes * 9}; + // std::vector mesh_shape = std::vector({16}); + // std::vector natoms_shape = {2 + ntypes}; + + paddle_infer::DataType model_type; + if (std::is_same::value) { + model_type = paddle_infer::DataType::FLOAT64; + } else if (std::is_same::value) { + model_type = paddle_infer::DataType::FLOAT32; } else { - MESH_SHAPE = std::vector(0); + throw deepmd::deepmd_exception("unsupported data type"); } - std::vector NATOMS_SHAPE = {2 + ntypes}; + coord_handle->Reshape(coord_shape_flat); + box_handle->Reshape(box_shape_flat); + type_handle->Reshape(atype_shape_flat); + // printf("coord.shape = ["); + // for (auto &d: coord_shape) + // { + // printf("%d, ", d); + // } + // printf("]\n"); + + // printf("type.shape = ["); + // for (auto &d: atype_shape) + // { + // printf("%d, ", d); + // } + // printf("]\n"); + + // printf("box.shape = ["); + // for (auto &d: box_shape) + // { + // printf("%d, ", d); + // } + // printf("]\n"); + // mesh_handle->Reshape(mesh_shape); + // natoms_handle->Reshape(natoms_shape); - coord_handle->Reshape(COORD_SHAPE); - atype_handle->Reshape(ATYPE_SHAPE); - natoms_handle->Reshape(NATOMS_SHAPE); - box_handle->Reshape(BOX_SHAPE); - mesh_handle->Reshape(MESH_SHAPE); + std::vector dcoord(dcoord_); + atommap.forward(dcoord.begin(), dcoord_.begin(), 3, nframes, nall); //012 + std::vector aparam_(aparam__); + if ((aparam_nall ? nall : nloc) > 0) { + atommap.forward( + aparam_.begin(), aparam__.begin(), + aparam__.size() / nframes / (aparam_nall ? nall : nloc), nframes, + (aparam_nall ? 
nall : nloc)); + } + + // const std::string filename = "/workspace/hesensen/deepmd_backend/deepmd_paddle_new/examples/water/lmp/coord_torch.log"; + // std::ifstream inputFile(filename); + // VALUETYPE number; + // int iii = 0; + // while (inputFile >> number) { + // dcoord[iii] = number; + // ++iii; + // } + // printf("dcoord finished, iii = %d\n", iii); + // inputFile.close(); // 发送输入数据到Tensor句柄 coord_handle->CopyFromCpu(dcoord.data()); - - std::vector datype_pad(nframes * nall, 0); + coord_handle->Reshape(coord_shape); + box_handle->CopyFromCpu(dbox.data()); + box_handle->Reshape(box_shape); + // for (int i = 0; i < dcoord.size(); ++i) + // { + // printf("dcoord[%d] = %.6lf\n", i, dcoord[i]); + // } + std::vector datype_rep(nframes * nall, 0); for (int ii = 0; ii < nframes; ++ii) { for (int jj = 0; jj < nall; ++jj) { - datype_pad[ii * nall + jj] = datype[jj]; + datype_rep[ii * nall + jj] = datype[jj]; } } - atype_handle->CopyFromCpu(datype_pad.data()); + // const std::string filename1 = "/workspace/hesensen/deepmd_backend/deepmd_paddle_new/examples/water/lmp/type_torch.log"; + // std::ifstream inputFile1(filename1); + // int number_int; + // iii = 0; + // while (inputFile1 >> number_int) { + // datype_rep[iii] = number_int; + // ++iii; + // } + // printf("atype finishied, iii = %d\n", iii); + // inputFile1.close(); + + type_handle->CopyFromCpu(datype_rep.data()); + // for (int i = 0; i < datype_rep.size(); ++i) + // { + // printf("%d\n", datype_rep[i]); + // } + type_handle->Reshape(atype_shape); + // std::vector mesh(mesh_shape[0], 0); + // for (int ii = 0; ii < 16; ++ii) { + // mesh[ii] = 0; + // } + // const int stride = sizeof(int*) / sizeof(int); + // assert(stride * sizeof(int) == sizeof(int*)); + // assert(stride <= 4); + // mesh[0] = ago; + // mesh[1] = dlist.inum; + // mesh[2] = 0; + // mesh[3] = 0; + // memcpy(&mesh[4], &(dlist.ilist), sizeof(int*)); + // memcpy(&mesh[8], &(dlist.numneigh), sizeof(int*)); + // memcpy(&mesh[12], &(dlist.firstneigh), sizeof(int**)); + // mesh_handle->CopyFromCpu(mesh.data()); + + // std::vector natoms = {nloc, nall}; + // for (int ii = 0; ii < ntypes; ++ii) { + // natoms.push_back(type_count[ii]); + // } + // natoms_handle->CopyFromCpu(natoms.data()); + // printf("finished predictor_input_tensors\n"); + // printf("nloc = %d\n", nloc); + return nloc; +} + +template +int deepmd::predictor_input_tensors_mixed_type( + const std::shared_ptr& predictor, + const int& nframes, + const std::vector& dcoord_, + const int& ntypes, + const std::vector& datype_, + const std::vector& dbox, + const double& cell_size, + const std::vector& fparam_, + const std::vector& aparam__, + const deepmd::AtomMap& atommap, + const bool aparam_nall) { + int nall = datype_.size() / nframes; + int nloc = nall; + assert(nall * 3 * nframes == dcoord_.size()); + bool b_pbc = (dbox.size() == nframes * 9); + std::vector datype(datype_); + atommap.forward(datype.begin(), datype_.begin(), 1, nframes, nall); - std::vector mesh_pad; + auto input_names = predictor->GetInputNames(); + auto coord_handle = predictor->GetInputHandle(input_names[0]); + auto type_handle = predictor->GetInputHandle(input_names[1]); + auto box_handle = predictor->GetInputHandle(input_names[3]); + auto mesh_handle = predictor->GetInputHandle(input_names[4]); + auto natoms_handle = predictor->GetInputHandle(input_names[2]); + + // 设置输入 Tensor 的维度信息 + std::vector coord_shape = {nframes, nall * 3}; + std::vector atype_shape = {nframes, nall}; + std::vector box_shape = {nframes, 9}; + std::vector mesh_shape; if 
(b_pbc) { - mesh_pad = std::vector(6); + mesh_shape = std::vector({7}); } else { - mesh_pad = std::vector(0); - } - // mesh_pad[0] = ago; - // mesh_pad[1] = dlist.inum; - // mesh_pad[2] = 0; - // mesh_pad[3] = 0; - // memcpy(&mesh_pad[4], &(dlist.ilist), sizeof(int*)); - // memcpy(&mesh_pad[8], &(dlist.numneigh), sizeof(int*)); - // memcpy(&mesh_pad[12], &(dlist.firstneigh), sizeof(int**)); - mesh_handle->CopyFromCpu(mesh_pad.data()); - if (b_pbc) { - mesh_pad[1 - 1] = 0; - mesh_pad[2 - 1] = 0; - mesh_pad[3 - 1] = 0; - mesh_pad[4 - 1] = 0; - mesh_pad[5 - 1] = 0; - mesh_pad[6 - 1] = 0; - } - std::vector natoms_pad = {nloc, nall}; - for (int ii = 0; ii < ntypes; ++ii) { - natoms_pad.push_back(type_count[ii]); + mesh_shape = std::vector({1}); } - // natoms_handle->CopyFromCpu(natoms_pad.data()); + std::vector natoms_shape = {2 + ntypes}; - box_handle->CopyFromCpu(dbox.data()); + coord_handle->Reshape(coord_shape); + type_handle->Reshape(atype_shape); + box_handle->Reshape(box_shape); + mesh_handle->Reshape(mesh_shape); + natoms_handle->Reshape(natoms_shape); - // const int stride = sizeof(int*) / sizeof(int); - // assert(stride * sizeof(int) == sizeof(int*)); - // assert(stride <= 4); + paddle_infer::DataType model_type; + if (std::is_same::value) { + model_type = paddle_infer::DataType::FLOAT64; + } else if (std::is_same::value) { + model_type = paddle_infer::DataType::FLOAT32; + } else { + throw deepmd::deepmd_exception("unsupported data type"); + } + std::vector dcoord(dcoord_); + atommap.forward(dcoord.begin(), dcoord_.begin(), 3, nframes, nall); + std::vector aparam_(aparam__); + if ((aparam_nall ? nall : nloc) > 0) { + atommap.forward( + aparam_.begin(), aparam__.begin(), + aparam__.size() / nframes / (aparam_nall ? nall : nloc), nframes, + (aparam_nall ? 
nall : nloc)); + } + // coord + coord_handle->CopyFromCpu(dcoord.data()); + + // box + if (b_pbc) { + box_handle->CopyFromCpu(dbox.data()); + } else { + std::vector zero = dbox; + std::fill(zero.begin(), zero.end(), 0); + box_handle->CopyFromCpu(zero.data()); + } + + // datype + std::vector datype_rep(nframes * nall, 0); + for (int ii = 0; ii < nframes; ++ii) { + for (int jj = 0; jj < nall; ++jj) { + datype_rep[ii * nall + jj] = datype[jj]; + } + } + type_handle->CopyFromCpu(datype_rep.data()); + // mesh + std::vector mesh; + if (b_pbc) { + mesh = std::vector(7, 0); + mesh[1 - 1] = 0; + mesh[2 - 1] = 0; + mesh[3 - 1] = 0; + mesh[4 - 1] = 0; + mesh[5 - 1] = 0; + mesh[6 - 1] = 0; + mesh[7 - 1] = 0; + } else { + mesh = std::vector(1, 0); + mesh[1 - 1] = 0; + } + mesh_handle->CopyFromCpu(mesh.data()); + //natoms + std::vector natoms_pad = {nloc, nall, nall}; + if (ntypes > 1) { + for (int ii = 0; ii < ntypes; ++ii) { + natoms_pad.push_back(0); + } + } + natoms_handle->CopyFromCpu(natoms_pad.data()); + + // if (fparam_.size() > 0) { + // input_tensors.push_back({prefix + "t_fparam", fparam_tensor}); + // } + // if (aparam_.size() > 0) { + // input_tensors.push_back({prefix + "t_aparam", aparam_tensor}); + // } return nloc; } + #endif #ifdef BUILD_PADDLE @@ -1456,7 +1661,7 @@ template void deepmd::select_map_inv( #ifdef BUILD_PADDLE template std::string deepmd::predictor_get_scalar( - const std::shared_ptr& predictor, const std::string&); + const std::shared_ptr& predictor, const std::string &name_); // template void deepmd::session_get_vector( // std::vector&, @@ -1789,59 +1994,55 @@ template int deepmd::predictor_input_tensors( const int ago, const bool aparam_nall); -// template int deepmd::session_input_tensors_mixed_type( -// std::vector>& input_tensors, -// const int& nframes, -// const std::vector& dcoord_, -// const int& ntypes, -// const std::vector& datype_, -// const std::vector& dbox, -// const double& cell_size, -// const std::vector& fparam_, -// const std::vector& aparam_, -// const deepmd::AtomMap& atommap, -// const std::string scope, -// const bool aparam_nall); -// template int deepmd::session_input_tensors_mixed_type( -// std::vector>& input_tensors, -// const int& nframes, -// const std::vector& dcoord_, -// const int& ntypes, -// const std::vector& datype_, -// const std::vector& dbox, -// const double& cell_size, -// const std::vector& fparam_, -// const std::vector& aparam_, -// const deepmd::AtomMap& atommap, -// const std::string scope, -// const bool aparam_nall); - -// template int deepmd::session_input_tensors_mixed_type( -// std::vector>& input_tensors, -// const int& nframes, -// const std::vector& dcoord_, -// const int& ntypes, -// const std::vector& datype_, -// const std::vector& dbox, -// const double& cell_size, -// const std::vector& fparam_, -// const std::vector& aparam_, -// const deepmd::AtomMap& atommap, -// const std::string scope, -// const bool aparam_nall); -// template int deepmd::session_input_tensors_mixed_type( -// std::vector>& input_tensors, -// const int& nframes, -// const std::vector& dcoord_, -// const int& ntypes, -// const std::vector& datype_, -// const std::vector& dbox, -// const double& cell_size, -// const std::vector& fparam_, -// const std::vector& aparam_, -// const deepmd::AtomMap& atommap, -// const std::string scope, -// const bool aparam_nall); +template int deepmd::predictor_input_tensors_mixed_type( + const std::shared_ptr& predictor, + const int& nframes, + const std::vector& dcoord_, + const int& ntypes, + const 
std::vector& datype_, + const std::vector& dbox, + const double& cell_size, + const std::vector& fparam_, + const std::vector& aparam_, + const deepmd::AtomMap& atommap, + const bool aparam_nall); +template int deepmd::predictor_input_tensors_mixed_type( + const std::shared_ptr& predictor, + const int& nframes, + const std::vector& dcoord_, + const int& ntypes, + const std::vector& datype_, + const std::vector& dbox, + const double& cell_size, + const std::vector& fparam_, + const std::vector& aparam_, + const deepmd::AtomMap& atommap, + const bool aparam_nall); + +template int deepmd::predictor_input_tensors_mixed_type( + const std::shared_ptr& predictor, + const int& nframes, + const std::vector& dcoord_, + const int& ntypes, + const std::vector& datype_, + const std::vector& dbox, + const double& cell_size, + const std::vector& fparam_, + const std::vector& aparam_, + const deepmd::AtomMap& atommap, + const bool aparam_nall); +template int deepmd::predictor_input_tensors_mixed_type( + const std::shared_ptr& predictor, + const int& nframes, + const std::vector& dcoord_, + const int& ntypes, + const std::vector& datype_, + const std::vector& dbox, + const double& cell_size, + const std::vector& fparam_, + const std::vector& aparam_, + const deepmd::AtomMap& atommap, + const bool aparam_nall); #endif void deepmd::print_summary(const std::string& pre) { diff --git a/source/install/build_cc.sh b/source/install/build_cc.sh index 6adb62a311..60101eb9a8 100755 --- a/source/install/build_cc.sh +++ b/source/install/build_cc.sh @@ -20,8 +20,7 @@ NPROC=$(nproc --all) BUILD_TMP_DIR=${SCRIPT_PATH}/../build mkdir -p ${BUILD_TMP_DIR} cd ${BUILD_TMP_DIR} -cmake -DCMAKE_PREFIX_PATH=/workspace/hesensen/deepmd_backend/deepmd_paddle_new/source/install/libtorch \ - -D ENABLE_TENSORFLOW=OFF \ +cmake -D ENABLE_TENSORFLOW=ON \ -D ENABLE_PYTORCH=ON \ -D CMAKE_INSTALL_PREFIX=${INSTALL_PREFIX} \ -D USE_TF_PYTHON_LIBS=TRUE \ diff --git a/source/install/build_cc_pd.sh b/source/install/build_cc_pd.sh index 5feb1e3426..335a394e5b 100755 --- a/source/install/build_cc_pd.sh +++ b/source/install/build_cc_pd.sh @@ -22,7 +22,9 @@ export LAMMPS_DIR="/workspace/hesensen/deepmd_backend/deepmd_paddle_new/source/b export LAMMPS_SOURCE_ROOT="/workspace/hesensen/deepmd_backend/deepmd_paddle_new/source/build_lammps/lammps-stable_29Aug2024/" # 设置推理时的 GPU 卡号 -export CUDA_VISIBLE_DEVICES=6 +export CUDA_VISIBLE_DEVICES=3 +# export FLAGS_benchmark=1 +# export GLOG_v=6 # PADDLE_DIR 设置为第二步 clone下来的 Paddle 目录 export PADDLE_DIR="/workspace/hesensen/PaddleScience_enn_debug/Paddle/" @@ -43,11 +45,11 @@ export LD_LIBRARY_PATH=${PADDLE_INFERENCE_DIR}/third_party/install/mkldnn/lib:$L export LD_LIBRARY_PATH=${PADDLE_INFERENCE_DIR}/third_party/install/mklml/lib:$LD_LIBRARY_PATH export LD_LIBRARY_PATH=${DEEPMD_DIR}/source/build:$LD_LIBRARY_PATH export LIBRARY_PATH=${DEEPMD_DIR}/deepmd/op:$LIBRARY_PATH - -cd ${DEEPMD_DIR}/source -rm -rf build # 若改动CMakeLists.txt,则需要打开该注释 -mkdir build -cd - +# export FLAGS_check_nan_inf=1 +# cd ${DEEPMD_DIR}/source +# rm -rf build # 若改动CMakeLists.txt,则需要打开该注释 +# mkdir build +# cd - # DEEPMD_INSTALL_DIR 设置为 deepmd-lammps 的目标安装目录,可自行设置任意路径 # export DEEPMD_INSTALL_DIR="path/to/deepmd_root" @@ -84,6 +86,8 @@ cmake -DCMAKE_PREFIX_PATH=/workspace/hesensen/PaddleScience_enn_debug/Paddle/bui -D CMAKE_INSTALL_PREFIX=${INSTALL_PREFIX} \ -D USE_TF_PYTHON_LIBS=TRUE \ -D LAMMPS_SOURCE_ROOT=${LAMMPS_SOURCE_ROOT} \ + -D ENABLE_IPI=OFF \ + -D 
PADDLE_LIBRARIES=/workspace/hesensen/PaddleScience_enn_debug/Paddle/build/paddle_inference_install_dir/paddle/lib/libpaddle_inference.so \ ${CUDA_ARGS} \ -D LAMMPS_VERSION=stable_29Aug2024 \ .. @@ -104,9 +108,12 @@ make no-extra-fix make yes-extra-fix make no-user-deepmd make yes-user-deepmd -make serial -j +# make serial -j +make mpi -j 20 export PATH=${LAMMPS_DIR}/src:$PATH cd ${DEEPMD_DIR}/examples/water/lmp -lmp_serial -in paddle_in.lammps +echo "START INFERENCE..." +# lmp_serial -in paddle_in.lammps 2>&1 | tee paddle_infer.log +mpirun -np 1 lmp_mpi -in paddle_in.lammps 2>&1 | tee paddle_infer.log diff --git a/source/lmp/pair_deepmd.cpp b/source/lmp/pair_deepmd.cpp index 72da1a5ee6..2112c12ac7 100644 --- a/source/lmp/pair_deepmd.cpp +++ b/source/lmp/pair_deepmd.cpp @@ -495,7 +495,6 @@ void PairDeepMD::compute(int eflag, int vflag) { } } } - vector dtype(nall); for (int ii = 0; ii < nall; ++ii) { dtype[ii] = type_idx_map[type[ii] - 1]; @@ -976,13 +975,9 @@ void PairDeepMD::settings(int narg, char **arg) { numb_models = models.size(); if (numb_models == 1) { try { - std::cout << "****** init deepmd model from file 1: " << std::endl; auto node_rank = get_node_rank(); - std::cout << "****** init deepmd model from file 2: " << std::endl; auto content = get_file_content(arg[0]); - std::cout << "****** init deepmd model from file 3: " << std::endl; deep_pot.init(arg[0], node_rank, content); - std::cout << "****** init deepmd model from file 4: " << std::endl; } catch (const std::exception &e) { // error->one(FLERR, e.what()); std::cerr << "Standard exception caught: " << e.what() << std::endl;