Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jan 22, 2024
1 parent dac7a76 commit d3befec
Show file tree
Hide file tree
Showing 5 changed files with 178 additions and 163 deletions.
29 changes: 14 additions & 15 deletions source/api_cc/include/DeepPotPT.h
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
// SPDX-License-Identifier: LGPL-3.0-or-later
#pragma once

#include <torch/torch.h>

#include "DeepPot.h"
#include "commonPT.h"
#include <torch/torch.h>

namespace deepmd {
/**
Expand Down Expand Up @@ -33,8 +34,7 @@ class DeepPotPT : public DeepPotBase {
* @param[in] file_content The content of the model file. If it is not empty,
*DP will read from the string instead of the file.
**/
void init(const std::string& model,
const int& gpu_rank = 0);
void init(const std::string& model, const int& gpu_rank = 0);

private:
/**
Expand Down Expand Up @@ -63,13 +63,13 @@ class DeepPotPT : public DeepPotBase {
void compute(ENERGYVTYPE& ener,
std::vector<VALUETYPE>& force,
std::vector<VALUETYPE>& virial,
// std::vector<VALUETYPE>& atom_energy,
// std::vector<VALUETYPE>& atom_virial,
// std::vector<VALUETYPE>& atom_energy,
// std::vector<VALUETYPE>& atom_virial,
const std::vector<VALUETYPE>& coord,
const std::vector<int>& atype,
const std::vector<VALUETYPE>& box);
// const std::vector<VALUETYPE>& fparam = std::vector<VALUETYPE>(),
// const std::vector<VALUETYPE>& aparam = std::vector<VALUETYPE>());
// const std::vector<VALUETYPE>& fparam = std::vector<VALUETYPE>(),
// const std::vector<VALUETYPE>& aparam = std::vector<VALUETYPE>());
/**
* @brief Evaluate the energy, force, virial, atomic energy, and atomic virial
*by using this DP.
Expand Down Expand Up @@ -99,16 +99,16 @@ class DeepPotPT : public DeepPotBase {
void compute(ENERGYVTYPE& ener,
std::vector<VALUETYPE>& force,
std::vector<VALUETYPE>& virial,
// std::vector<VALUETYPE>& atom_energy,
// std::vector<VALUETYPE>& atom_virial,
// std::vector<VALUETYPE>& atom_energy,
// std::vector<VALUETYPE>& atom_virial,
const std::vector<VALUETYPE>& coord,
const std::vector<int>& atype,
const std::vector<VALUETYPE>& box,
// const int nghost,
// const int nghost,
const InputNlist& lmp_list,
const int& ago);
// const std::vector<VALUETYPE>& fparam = std::vector<VALUETYPE>(),
// const std::vector<VALUETYPE>& aparam = std::vector<VALUETYPE>());
// const std::vector<VALUETYPE>& fparam = std::vector<VALUETYPE>(),
// const std::vector<VALUETYPE>& aparam = std::vector<VALUETYPE>());
/**
* @brief Evaluate the energy, force, and virial with the mixed type
*by using this DP.
Expand Down Expand Up @@ -310,7 +310,6 @@ class DeepPotPT : public DeepPotBase {
const std::vector<float>& aparam = std::vector<float>());

private:

bool inited;
int ntypes;
int ntypes_spin;
Expand All @@ -320,10 +319,10 @@ class DeepPotPT : public DeepPotBase {
torch::jit::script::Module module;
double rcut;
NeighborListData nlist_data;
//InputNlist nlist;
// InputNlist nlist;
int max_num_neighbors;
int gpu_id;
at::Tensor firstneigh_tensor;
at::Tensor firstneigh_tensor;
};

} // namespace deepmd
19 changes: 11 additions & 8 deletions source/api_cc/include/commonPT.h
Original file line number Diff line number Diff line change
@@ -1,26 +1,29 @@
// SPDX-License-Identifier: LGPL-3.0-or-later
#include <torch/script.h>
#ifndef COMMON_H
#define COMMON_H
#include <iostream>
#include <cstdlib>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <vector>
#include "neighbor_list.h"

#include "neighbor_list.h"

struct NeighborListData {
/// Array stores the core region atom's index
std::vector<int> ilist;
/// Array stores the core region atom's neighbor index
//std::vector<std::vector<int>> jlist;
int *jlist;
// std::vector<std::vector<int>> jlist;
int* jlist;
/// Array stores the number of neighbors of core region atoms
std::vector<int> numneigh;
/// Array stores the the location of the first neighbor of core region atoms
std::vector<int*> firstneigh;

public:
void copy_from_nlist(const InputNlist& inlist, int& max_num_neighbors,int nnei);
//void make_inlist(InputNlist& inlist);
void copy_from_nlist(const InputNlist& inlist,
int& max_num_neighbors,
int nnei);
// void make_inlist(InputNlist& inlist);
};
#endif
#endif
2 changes: 1 addition & 1 deletion source/api_cc/src/DeepPot.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ void DeepPot::init(const std::string& model,
// TODO: throw errors if TF backend is not built, without mentioning TF
dp = std::make_shared<deepmd::DeepPotTF>(model, gpu_rank, file_content);
} else if (deepmd::DPBackend::PyTorch == backend) {
//throw deepmd::deepmd_exception("PyTorch backend is not supported yet");
// throw deepmd::deepmd_exception("PyTorch backend is not supported yet");
dp = std::make_shared<deepmd::DeepPotPT>(model, gpu_rank, file_content);
} else if (deepmd::DPBackend::Paddle == backend) {
throw deepmd::deepmd_exception("PaddlePaddle backend is not supported yet");
Expand Down
Loading

0 comments on commit d3befec

Please sign in to comment.