support GPU backward of force and virial
JiabinYang committed Apr 13, 2021
1 parent d596c18 commit c24e33e
Showing 3 changed files with 81 additions and 47 deletions.
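The substance of the change: the backward ops for the force and virial no longer insist that their inputs live on the CPU. Input validation is relaxed from CHECK_INPUT to CHECK_INPUT_READY, and GPU-resident tensors are copied to host memory before the existing CPU backward kernels run. The two macros are defined elsewhere in these source files and are not part of the hunks below; a minimal sketch of what they plausibly expand to, assuming the usual convention for Paddle custom ops (the real definitions may differ):

// Sketch only -- the actual macro definitions are not shown in this diff.
#include "paddle/extension.h"

// Old check: the tensor must already reside in host memory.
#define CHECK_INPUT(x) \
  PD_CHECK(x.place() == paddle::PlaceType::kCPU, #x " must be a CPU Tensor.")

// Relaxed check: the tensor only has to be initialized; it may live on
// CPU or GPU, and the kernel decides whether a device-to-host copy is needed.
#define CHECK_INPUT_READY(x) \
  PD_CHECK(x.is_initialized(), #x " must be initialized before usage.")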
56 changes: 35 additions & 21 deletions source/op/paddle_ops/srcs/pd_prod_force_se_a_multi_devices_cpu.cc
@@ -123,11 +123,11 @@ const paddle::Tensor& natoms_tensor,
int n_a_sel,
int n_r_sel
){
- CHECK_INPUT(grad_tensor);
- CHECK_INPUT(net_deriv_tensor);
- CHECK_INPUT(in_deriv_tensor);
- CHECK_INPUT(nlist_tensor);
- CHECK_INPUT(natoms_tensor);
+ CHECK_INPUT_READY(grad_tensor);
+ CHECK_INPUT_READY(net_deriv_tensor);
+ CHECK_INPUT_READY(in_deriv_tensor);
+ CHECK_INPUT_READY(nlist_tensor);
+ CHECK_INPUT_READY(natoms_tensor);

auto grad_shape = grad_tensor.shape();
auto net_deriv_shape = net_deriv_tensor.shape();
@@ -142,8 +142,12 @@ int n_r_sel
CHECK_INPUT_DIM(natoms_tensor, 1);

PD_CHECK(natoms_shape[0] >= 3, "number of atoms should be larger than (or equal to) 3");

- const int* natoms = natoms_tensor.data<int>();
+ const int* natoms = nullptr;
+ if(natoms_tensor.place() != paddle::PlaceType::kCPU){
+   natoms = natoms_tensor.copy_to<int>(paddle::PlaceType::kCPU).data<int>();
+ }else{
+   natoms = natoms_tensor.data<int>();
+ }
int nloc = natoms[0];
int nframes = net_deriv_shape[0];
int ndescrpt = net_deriv_shape[1] / nloc;
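A note on the scalar bookkeeping shared by both backward kernels: nloc is read from the first element of natoms, nframes from net_deriv_shape[0], and ndescrpt as net_deriv_shape[1] / nloc; nnei is derived from n_a_sel and n_r_sel in the lines not expanded here. As a purely hypothetical example, nloc = 192 with n_a_sel = 46 and n_r_sel = 92 would give nnei = 138 and, for the se_a descriptor's four values per neighbor, ndescrpt = 4 * 138 = 552, so net_deriv would be expected to have shape {nframes, 192 * 552}.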
@@ -158,15 +162,30 @@ int n_r_sel

std::vector<int64_t> grad_net_shape {nframes, nloc * ndescrpt};
paddle::Tensor grad_net_tensor = paddle::Tensor(paddle::PlaceType::kCPU, grad_net_shape);
+ if(grad_tensor.place() == paddle::PlaceType::kCPU){
+   PD_DISPATCH_FLOATING_TYPES(
+     grad_tensor.type(), "pd_prod_force_se_a_cpu_backward_kernel", ([&] {
+       PdProdForceSeAOpCPUBackwardKernel<data_t>(
+         nloc, nframes, ndescrpt, nnei,
+         grad_tensor.data<data_t>(),
+         net_deriv_tensor.data<data_t>(),
+         in_deriv_tensor.data<data_t>(),
+         nlist_tensor.data<int>(),
+         grad_net_tensor.mutable_data<data_t>());
+     }));
+ }else{
+   PD_DISPATCH_FLOATING_TYPES(
+     grad_tensor.type(), "pd_prod_force_se_a_cpu_backward_kernel", ([&] {
+       PdProdForceSeAOpCPUBackwardKernel<data_t>(
+         nloc, nframes, ndescrpt, nnei,
+         grad_tensor.copy_to<data_t>(paddle::PlaceType::kCPU).data<data_t>(),
+         net_deriv_tensor.copy_to<data_t>(paddle::PlaceType::kCPU).data<data_t>(),
+         in_deriv_tensor.copy_to<data_t>(paddle::PlaceType::kCPU).data<data_t>(),
+         nlist_tensor.copy_to<int>(paddle::PlaceType::kCPU).data<int>(),
+         grad_net_tensor.mutable_data<data_t>());
+     }));
+ }

- PD_DISPATCH_FLOATING_TYPES(
-   grad_tensor.type(), "pd_prod_force_se_a_cpu_backward_kernel", ([&] {
-     PdProdForceSeAOpCPUBackwardKernel<data_t>(
-       nloc, nframes, ndescrpt, nnei,
-       grad_tensor.data<data_t>(), net_deriv_tensor.data<data_t>(),
-       in_deriv_tensor.data<data_t>(), nlist_tensor.data<int>(),
-       grad_net_tensor.mutable_data<data_t>());
-   }));

return {grad_net_tensor};
}
@@ -195,15 +214,10 @@ const paddle::Tensor& nlist_tensor,
const paddle::Tensor& natoms_tensor,
int n_a_sel,
int n_r_sel){
- if(grad_tensor.place() == paddle::PlaceType::kCPU){
-   return PdProdForceSeAOpCPUBackward(
+ return PdProdForceSeAOpCPUBackward(
    grad_tensor, net_deriv_tensor, in_deriv_tensor,
    nlist_tensor, natoms_tensor, n_a_sel, n_r_sel);
- }else{
-   PD_THROW("No Such kernel for PdFrodForceSeABackward!");
- }
}

std::vector<std::vector<int64_t>> PdProdForceSeAOpForwardInferShape(
std::vector<int64_t> net_deriv_shape,
std::vector<int64_t> in_deriv_shape,
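Rather than adding a native GPU backward kernel, both backward ops accept GPU-resident tensors by copying them to host memory with copy_to<T>(paddle::PlaceType::kCPU) and reusing the existing CPU kernel; the gradient is therefore still computed on the host. A small illustrative helper, not part of this commit, that captures the pattern (the copied tensor must stay alive for as long as raw pointers obtained from it are used):

// Illustrative sketch only -- not part of this commit.
#include "paddle/extension.h"

// Return a tensor guaranteed to live in host memory, copying from the
// device only when necessary. The caller keeps the returned Tensor alive
// while holding pointers obtained through data<T>().
template <typename T>
paddle::Tensor EnsureCpu(const paddle::Tensor& t) {
  if (t.place() == paddle::PlaceType::kCPU) {
    return t;                                    // already on the host
  }
  return t.copy_to<T>(paddle::PlaceType::kCPU);  // device -> host copy
}

With such a helper the duplicated CPU/GPU branches inside PD_DISPATCH_FLOATING_TYPES could collapse into a single call path; the commit keeps the two branches explicit instead.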
64 changes: 41 additions & 23 deletions source/op/paddle_ops/srcs/pd_prod_virial_se_a_multi_devices_cpu.cc
@@ -128,12 +128,12 @@ const paddle::Tensor& natoms_tensor,
int n_a_sel,
int n_r_sel
){
- CHECK_INPUT(grad_tensor);
- CHECK_INPUT(net_deriv_tensor);
- CHECK_INPUT(in_deriv_tensor);
- CHECK_INPUT(rij_tensor);
- CHECK_INPUT(nlist_tensor);
- CHECK_INPUT(natoms_tensor);
+ CHECK_INPUT_READY(grad_tensor);
+ CHECK_INPUT_READY(net_deriv_tensor);
+ CHECK_INPUT_READY(in_deriv_tensor);
+ CHECK_INPUT_READY(rij_tensor);
+ CHECK_INPUT_READY(nlist_tensor);
+ CHECK_INPUT_READY(natoms_tensor);

auto grad_shape = grad_tensor.shape();
auto net_deriv_shape = net_deriv_tensor.shape();
@@ -151,7 +151,12 @@ int n_r_sel

PD_CHECK(natoms_shape[0] >= 3, "number of atoms should be larger than (or equal to) 3");

- const int* natoms = natoms_tensor.data<int>();
+ const int* natoms = nullptr;
+ if(natoms_tensor.place() != paddle::PlaceType::kCPU){
+   natoms = natoms_tensor.copy_to<int>(paddle::PlaceType::kCPU).data<int>();
+ }else{
+   natoms = natoms_tensor.data<int>();
+ }
int nframes = net_deriv_shape[0];
int nloc = natoms[0];
int ndescrpt = net_deriv_shape[1] / nloc;
@@ -169,14 +174,30 @@ int n_r_sel
std::vector<int64_t> grad_net_shape {nframes, nloc * ndescrpt};
paddle::Tensor grad_net_tensor = paddle::Tensor(paddle::PlaceType::kCPU, grad_net_shape);

- PD_DISPATCH_FLOATING_TYPES(
-   grad_tensor.type(), "pd_prod_force_se_a_cpu_backward_kernel", ([&] {
-     PdProdForceSeAOpCPUBackwardKernel<data_t>(
-       nloc, nframes, ndescrpt, nnei,
-       grad_tensor.data<data_t>(), net_deriv_tensor.data<data_t>(),
-       in_deriv_tensor.data<data_t>(), rij_tensor.data<data_t>(), nlist_tensor.data<int>(),
-       grad_net_tensor.mutable_data<data_t>());
-   }));
+ if(grad_tensor.place() == paddle::PlaceType::kCPU){
+   PD_DISPATCH_FLOATING_TYPES(
+     grad_tensor.type(), "pd_prod_force_se_a_cpu_backward_kernel", ([&] {
+       PdProdForceSeAOpCPUBackwardKernel<data_t>(
+         nloc, nframes, ndescrpt, nnei,
+         grad_tensor.data<data_t>(),
+         net_deriv_tensor.data<data_t>(),
+         in_deriv_tensor.data<data_t>(),
+         rij_tensor.data<data_t>(), nlist_tensor.data<int>(),
+         grad_net_tensor.mutable_data<data_t>());
+     }));
+ }else{
+   PD_DISPATCH_FLOATING_TYPES(
+     grad_tensor.type(), "pd_prod_force_se_a_cpu_backward_kernel", ([&] {
+       PdProdForceSeAOpCPUBackwardKernel<data_t>(
+         nloc, nframes, ndescrpt, nnei,
+         grad_tensor.copy_to<data_t>(paddle::PlaceType::kCPU).data<data_t>(),
+         net_deriv_tensor.copy_to<data_t>(paddle::PlaceType::kCPU).data<data_t>(),
+         in_deriv_tensor.copy_to<data_t>(paddle::PlaceType::kCPU).data<data_t>(),
+         rij_tensor.copy_to<data_t>(paddle::PlaceType::kCPU).data<data_t>(),
+         nlist_tensor.copy_to<int>(paddle::PlaceType::kCPU).data<int>(),
+         grad_net_tensor.mutable_data<data_t>());
+     }));
+ }

return {grad_net_tensor};
}
@@ -207,14 +228,11 @@ const paddle::Tensor& nlist_tensor,
const paddle::Tensor& natoms_tensor,
int n_a_sel,
int n_r_sel){
- if(grad_tensor.place() == paddle::PlaceType::kCPU){
-   return PdProdVirialSeAOpCPUBackward(
-     grad_tensor, net_deriv_tensor, in_deriv_tensor,
-     rij_tensor,
-     nlist_tensor, natoms_tensor, n_a_sel, n_r_sel);
- }else{
-   PD_THROW("No Such kernel for PdFrodForceSeABackward!");
- }
+ return PdProdVirialSeAOpCPUBackward(
+   grad_tensor, net_deriv_tensor, in_deriv_tensor,
+   rij_tensor,
+   nlist_tensor, natoms_tensor, n_a_sel, n_r_sel);

}

std::vector<std::vector<int64_t>> PdProdVirialSeAOpForwardInferShape(
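One consequence of the fallback approach: in both backward kernels grad_net_tensor is constructed with paddle::PlaceType::kCPU, so even for GPU inputs the returned gradient lives in host memory. A caller that needs it back on the device would have to copy it explicitly; an illustrative one-liner, assuming single-precision data and a hypothetical grad_net_cpu handle:

// Illustrative only: move the host-side gradient back to the GPU.
paddle::Tensor grad_net_gpu = grad_net_cpu.copy_to<float>(paddle::PlaceType::kGPU);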
