Skip to content

Commit

Permalink
Merge pull request #512 from JiabinYang/deemd2paddle
Browse files Browse the repository at this point in the history
Support virial forward run on gpu with cpu kernel
  • Loading branch information
amcadmus authored Apr 14, 2021
2 parents 9426f1c + dde23ee commit e988870
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#include <vector>

#define CHECK_INPUT(x) PD_CHECK(x.place() == paddle::PlaceType::kCPU, #x " must be a CPU Tensor.")
#define CHECK_INPUT_READY(x) PD_CHECK(x.IsInitialized(), #x " must be initialized before usage.")
#define CHECK_INPUT_READY(x) PD_CHECK(x.is_initialized(), #x " must be initialized before usage.")
#define CHECK_INPUT_DIM(x, value) PD_CHECK(x.shape().size() == value, #x "'s dim should be " #value ".")
template <typename FPTYPE>
static int
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


#define CHECK_INPUT(x) PD_CHECK(x.place() == paddle::PlaceType::kCPU, #x " must be a CPU Tensor.")
#define CHECK_INPUT_READY(x) PD_CHECK(x.IsInitialized(), #x " must be initialized before usage.")
#define CHECK_INPUT_READY(x) PD_CHECK(x.is_initialized(), #x " must be initialized before usage.")
#define CHECK_INPUT_DIM(x, value) PD_CHECK(x.shape().size() == value, #x "'s dim should be " #value ".")


Expand Down
67 changes: 45 additions & 22 deletions source/op/paddle_ops/srcs/pd_prod_virial_se_a_multi_devices_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


#define CHECK_INPUT(x) PD_CHECK(x.place() == paddle::PlaceType::kCPU, #x " must be a CPU Tensor.")
#define CHECK_INPUT_READY(x) PD_CHECK(x.IsInitialized(), #x " must be initialized before usage.")
#define CHECK_INPUT_READY(x) PD_CHECK(x.is_initialized(), #x " must be initialized before usage.")
#define CHECK_INPUT_DIM(x, value) PD_CHECK(x.shape().size() == value, #x "'s dim should be " #value ".")


Expand Down Expand Up @@ -47,11 +47,11 @@ const paddle::Tensor& natoms_tensor,
int n_a_sel,
int n_r_sel
){
CHECK_INPUT(net_deriv_tensor);
CHECK_INPUT(in_deriv_tensor);
CHECK_INPUT(rij_tensor);
CHECK_INPUT(nlist_tensor);
CHECK_INPUT(natoms_tensor);
CHECK_INPUT_READY(net_deriv_tensor);
CHECK_INPUT_READY(in_deriv_tensor);
CHECK_INPUT_READY(rij_tensor);
CHECK_INPUT_READY(nlist_tensor);
CHECK_INPUT_READY(natoms_tensor);

CHECK_INPUT_DIM(net_deriv_tensor, 2);
CHECK_INPUT_DIM(in_deriv_tensor, 2);
Expand All @@ -60,7 +60,13 @@ int n_r_sel
CHECK_INPUT_DIM(natoms_tensor, 1);

PD_CHECK(natoms_tensor.shape()[0] >= 3, "number of atoms should be larger than (or equal to) 3");
const int* natoms = natoms_tensor.data<int>();
// TODO:(jiabin) This code should be removed when virial cuda kernel fixed.
const int* natoms = nullptr;
if(natoms_tensor.place() != paddle::PlaceType::kCPU){
natoms = natoms_tensor.copy_to<int>(paddle::PlaceType::kCPU).data<int>();
}else{
natoms = natoms_tensor.data<int>();
}
int nloc = natoms[0];
int nall = natoms[1];
int nnei = nlist_tensor.shape()[1] / nloc;
Expand All @@ -79,14 +85,29 @@ int n_r_sel
paddle::Tensor virial_tensor = paddle::Tensor(paddle::PlaceType::kCPU, virial_shape);
paddle::Tensor atom_virial_tensor = paddle::Tensor(paddle::PlaceType::kCPU, atom_virial_shape);

PD_DISPATCH_FLOATING_TYPES(
net_deriv_tensor.type(), "pd_prod_virial_se_a_cpu_forward_kernel", ([&] {
PdProdVirialSeAOpForwardCPUKernel<data_t>(
nloc, nall, ndescrpt, nnei, nframes,
virial_tensor.mutable_data<data_t>(), atom_virial_tensor.mutable_data<data_t>(),
net_deriv_tensor.data<data_t>(), in_deriv_tensor.data<data_t>(),
rij_tensor.data<data_t>(), nlist_tensor.data<int>());
}));
if(natoms_tensor.place() == paddle::PlaceType::kCPU){
PD_DISPATCH_FLOATING_TYPES(
net_deriv_tensor.type(), "pd_prod_virial_se_a_cpu_forward_kernel", ([&] {
PdProdVirialSeAOpForwardCPUKernel<data_t>(
nloc, nall, ndescrpt, nnei, nframes,
virial_tensor.mutable_data<data_t>(), atom_virial_tensor.mutable_data<data_t>(),
net_deriv_tensor.data<data_t>(), in_deriv_tensor.data<data_t>(),
rij_tensor.data<data_t>(), nlist_tensor.data<int>());
}));
}else{
PD_DISPATCH_FLOATING_TYPES(
net_deriv_tensor.type(), "pd_prod_virial_se_a_cpu_forward_kernel", ([&] {
PdProdVirialSeAOpForwardCPUKernel<data_t>(
nloc, nall, ndescrpt, nnei, nframes,
virial_tensor.mutable_data<data_t>(),
atom_virial_tensor.mutable_data<data_t>(),
net_deriv_tensor.copy_to<data_t>(paddle::PlaceType::kCPU).data<data_t>(),
in_deriv_tensor.copy_to<data_t>(paddle::PlaceType::kCPU).data<data_t>(),
rij_tensor.copy_to<data_t>(paddle::PlaceType::kCPU).data<data_t>(),
nlist_tensor.copy_to<int>(paddle::PlaceType::kCPU).data<int>());
}));
}


return {virial_tensor, atom_virial_tensor};
}
Expand Down Expand Up @@ -210,13 +231,15 @@ const paddle::Tensor& nlist_tensor,
const paddle::Tensor& natoms_tensor,
int n_a_sel,
int n_r_sel){
if(net_deriv_tensor.place() == paddle::PlaceType::kCPU){
return PdProdVirialSeAOpCPUForward(net_deriv_tensor, in_deriv_tensor, rij_tensor, nlist_tensor, natoms_tensor, n_a_sel, n_r_sel);
}else if(net_deriv_tensor.place() == paddle::PlaceType::kGPU){
return PdProdVirialSeAOpCUDAForward(net_deriv_tensor, in_deriv_tensor, rij_tensor, nlist_tensor, natoms_tensor, n_a_sel, n_r_sel);
}else{
PD_THROW("No Such kernel for PdFrodForceSeAForward!");
}
return PdProdVirialSeAOpCPUForward(net_deriv_tensor, in_deriv_tensor, rij_tensor, nlist_tensor, natoms_tensor, n_a_sel, n_r_sel);
// TODO:(jiabin) Support this when virial cuda kernel fixed.
// if(net_deriv_tensor.place() == paddle::PlaceType::kCPU){
// return PdProdVirialSeAOpCPUForward(net_deriv_tensor, in_deriv_tensor, rij_tensor, nlist_tensor, natoms_tensor, n_a_sel, n_r_sel);
// }else if(net_deriv_tensor.place() == paddle::PlaceType::kGPU){
// return PdProdVirialSeAOpCUDAForward(net_deriv_tensor, in_deriv_tensor, rij_tensor, nlist_tensor, natoms_tensor, n_a_sel, n_r_sel);
// }else{
// PD_THROW("No Such kernel for PdFrodForceSeAForward!");
// }
}

std::vector<paddle::Tensor> PdProdVirialSeABackward(
Expand Down
2 changes: 1 addition & 1 deletion source/tests/test_pd_prod_force_and_virial.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def virial_test (inter,
hh = global_default_fv_hh,
suffix = '') :
# TODO:(jiabin): Remove this line when virial lib fixed by Wang Han
paddle.set_device("cpu")
paddle.set_device("gpu:0")
# set weights
w0 = np.ones (inter.ndescrpt)
inter.net_w_i = np.copy(w0)
Expand Down

0 comments on commit e988870

Please sign in to comment.