diff --git a/source/op/paddle_ops/srcs/pd_prod_env_mat_multi_devices_cpu.cc b/source/op/paddle_ops/srcs/pd_prod_env_mat_multi_devices_cpu.cc index 0a35bba1a9..903207f10c 100644 --- a/source/op/paddle_ops/srcs/pd_prod_env_mat_multi_devices_cpu.cc +++ b/source/op/paddle_ops/srcs/pd_prod_env_mat_multi_devices_cpu.cc @@ -8,7 +8,7 @@ #include #define CHECK_INPUT(x) PD_CHECK(x.place() == paddle::PlaceType::kCPU, #x " must be a CPU Tensor.") -#define CHECK_INPUT_READY(x) PD_CHECK(x.IsInitialized(), #x " must be initialized before usage.") +#define CHECK_INPUT_READY(x) PD_CHECK(x.is_initialized(), #x " must be initialized before usage.") #define CHECK_INPUT_DIM(x, value) PD_CHECK(x.shape().size() == value, #x "'s dim should be " #value ".") template static int diff --git a/source/op/paddle_ops/srcs/pd_prod_force_se_a_multi_devices_cpu.cc b/source/op/paddle_ops/srcs/pd_prod_force_se_a_multi_devices_cpu.cc index 9591e414d0..d872ffdb44 100644 --- a/source/op/paddle_ops/srcs/pd_prod_force_se_a_multi_devices_cpu.cc +++ b/source/op/paddle_ops/srcs/pd_prod_force_se_a_multi_devices_cpu.cc @@ -5,7 +5,7 @@ #define CHECK_INPUT(x) PD_CHECK(x.place() == paddle::PlaceType::kCPU, #x " must be a CPU Tensor.") -#define CHECK_INPUT_READY(x) PD_CHECK(x.IsInitialized(), #x " must be initialized before usage.") +#define CHECK_INPUT_READY(x) PD_CHECK(x.is_initialized(), #x " must be initialized before usage.") #define CHECK_INPUT_DIM(x, value) PD_CHECK(x.shape().size() == value, #x "'s dim should be " #value ".") diff --git a/source/op/paddle_ops/srcs/pd_prod_virial_se_a_multi_devices_cpu.cc b/source/op/paddle_ops/srcs/pd_prod_virial_se_a_multi_devices_cpu.cc index 381ff24bcc..97df33752c 100644 --- a/source/op/paddle_ops/srcs/pd_prod_virial_se_a_multi_devices_cpu.cc +++ b/source/op/paddle_ops/srcs/pd_prod_virial_se_a_multi_devices_cpu.cc @@ -5,7 +5,7 @@ #define CHECK_INPUT(x) PD_CHECK(x.place() == paddle::PlaceType::kCPU, #x " must be a CPU Tensor.") -#define CHECK_INPUT_READY(x) 
PD_CHECK(x.IsInitialized(), #x " must be initialized before usage.") +#define CHECK_INPUT_READY(x) PD_CHECK(x.is_initialized(), #x " must be initialized before usage.") #define CHECK_INPUT_DIM(x, value) PD_CHECK(x.shape().size() == value, #x "'s dim should be " #value ".") @@ -47,11 +47,11 @@ const paddle::Tensor& natoms_tensor, int n_a_sel, int n_r_sel ){ - CHECK_INPUT(net_deriv_tensor); - CHECK_INPUT(in_deriv_tensor); - CHECK_INPUT(rij_tensor); - CHECK_INPUT(nlist_tensor); - CHECK_INPUT(natoms_tensor); + CHECK_INPUT_READY(net_deriv_tensor); + CHECK_INPUT_READY(in_deriv_tensor); + CHECK_INPUT_READY(rij_tensor); + CHECK_INPUT_READY(nlist_tensor); + CHECK_INPUT_READY(natoms_tensor); CHECK_INPUT_DIM(net_deriv_tensor, 2); CHECK_INPUT_DIM(in_deriv_tensor, 2); @@ -60,7 +60,13 @@ int n_r_sel CHECK_INPUT_DIM(natoms_tensor, 1); PD_CHECK(natoms_tensor.shape()[0] >= 3, "number of atoms should be larger than (or equal to) 3"); - const int* natoms = natoms_tensor.data(); + // TODO:(jiabin) This code should be removed when virial cuda kernel fixed. 
+  paddle::Tensor natoms_cpu_tensor = natoms_tensor; +  if(natoms_tensor.place() != paddle::PlaceType::kCPU){ +    // NOTE(review): keep the CPU copy in a named local — taking .data() of the copy_to() temporary would leave natoms dangling +    natoms_cpu_tensor = natoms_tensor.copy_to(paddle::PlaceType::kCPU); +  } +  const int* natoms = natoms_cpu_tensor.data(); int nloc = natoms[0]; int nall = natoms[1]; int nnei = nlist_tensor.shape()[1] / nloc; @@ -79,14 +85,29 @@ int n_r_sel paddle::Tensor virial_tensor = paddle::Tensor(paddle::PlaceType::kCPU, virial_shape); paddle::Tensor atom_virial_tensor = paddle::Tensor(paddle::PlaceType::kCPU, atom_virial_shape); - PD_DISPATCH_FLOATING_TYPES( - net_deriv_tensor.type(), "pd_prod_virial_se_a_cpu_forward_kernel", ([&] { - PdProdVirialSeAOpForwardCPUKernel( - nloc, nall, ndescrpt, nnei, nframes, - virial_tensor.mutable_data(), atom_virial_tensor.mutable_data(), - net_deriv_tensor.data(), in_deriv_tensor.data(), - rij_tensor.data(), nlist_tensor.data()); - })); + if(natoms_tensor.place() == paddle::PlaceType::kCPU){ + PD_DISPATCH_FLOATING_TYPES( + net_deriv_tensor.type(), "pd_prod_virial_se_a_cpu_forward_kernel", ([&] { + PdProdVirialSeAOpForwardCPUKernel( + nloc, nall, ndescrpt, nnei, nframes, + virial_tensor.mutable_data(), atom_virial_tensor.mutable_data(), + net_deriv_tensor.data(), in_deriv_tensor.data(), + rij_tensor.data(), nlist_tensor.data()); + })); + }else{ + PD_DISPATCH_FLOATING_TYPES( + net_deriv_tensor.type(), "pd_prod_virial_se_a_cpu_forward_kernel", ([&] { + PdProdVirialSeAOpForwardCPUKernel( + nloc, nall, ndescrpt, nnei, nframes, + virial_tensor.mutable_data(), + atom_virial_tensor.mutable_data(), + net_deriv_tensor.copy_to(paddle::PlaceType::kCPU).data(), + in_deriv_tensor.copy_to(paddle::PlaceType::kCPU).data(), + rij_tensor.copy_to(paddle::PlaceType::kCPU).data(), + nlist_tensor.copy_to(paddle::PlaceType::kCPU).data()); + })); + } + return {virial_tensor, atom_virial_tensor}; } @@ -210,13 +231,15 @@ const paddle::Tensor& nlist_tensor, const paddle::Tensor& natoms_tensor, int n_a_sel, int n_r_sel){ - if(net_deriv_tensor.place() == paddle::PlaceType::kCPU){ - return
PdProdVirialSeAOpCPUForward(net_deriv_tensor, in_deriv_tensor, rij_tensor, nlist_tensor, natoms_tensor, n_a_sel, n_r_sel); - }else if(net_deriv_tensor.place() == paddle::PlaceType::kGPU){ - return PdProdVirialSeAOpCUDAForward(net_deriv_tensor, in_deriv_tensor, rij_tensor, nlist_tensor, natoms_tensor, n_a_sel, n_r_sel); - }else{ - PD_THROW("No Such kernel for PdFrodForceSeAForward!"); - } + return PdProdVirialSeAOpCPUForward(net_deriv_tensor, in_deriv_tensor, rij_tensor, nlist_tensor, natoms_tensor, n_a_sel, n_r_sel); + // TODO:(jiabin) Support this when virial cuda kernel fixed. + // if(net_deriv_tensor.place() == paddle::PlaceType::kCPU){ + // return PdProdVirialSeAOpCPUForward(net_deriv_tensor, in_deriv_tensor, rij_tensor, nlist_tensor, natoms_tensor, n_a_sel, n_r_sel); + // }else if(net_deriv_tensor.place() == paddle::PlaceType::kGPU){ + // return PdProdVirialSeAOpCUDAForward(net_deriv_tensor, in_deriv_tensor, rij_tensor, nlist_tensor, natoms_tensor, n_a_sel, n_r_sel); + // }else{ + // PD_THROW("No Such kernel for PdFrodForceSeAForward!"); + // } } std::vector PdProdVirialSeABackward( diff --git a/source/tests/test_pd_prod_force_and_virial.py b/source/tests/test_pd_prod_force_and_virial.py index 1da998e5ea..a71e2d44c0 100644 --- a/source/tests/test_pd_prod_force_and_virial.py +++ b/source/tests/test_pd_prod_force_and_virial.py @@ -216,7 +216,7 @@ def virial_test (inter, hh = global_default_fv_hh, suffix = '') : # TODO:(jiabin): Remove this line when virial lib fixed by Wang Han - paddle.set_device("cpu") + paddle.set_device("gpu:0") # set weights w0 = np.ones (inter.ndescrpt) inter.net_w_i = np.copy(w0)