diff --git a/source/lib/include/prod_force_grad.h b/source/lib/include/prod_force_grad.h index b4b95f2ac3..f6ac58269f 100644 --- a/source/lib/include/prod_force_grad.h +++ b/source/lib/include/prod_force_grad.h @@ -20,4 +20,24 @@ void prod_force_grad_r_cpu( const int nloc, const int nnei); +#if GOOGLE_CUDA +template +void prod_force_grad_a_gpu_cuda( + FPTYPE * grad_net, + const FPTYPE * grad, + const FPTYPE * env_deriv, + const int * nlist, + const int nloc, + const int nnei); + +template +void prod_force_grad_r_gpu_cuda( + FPTYPE * grad_net, + const FPTYPE * grad, + const FPTYPE * env_deriv, + const int * nlist, + const int nloc, + const int nnei); +#endif // GOOGLE_CUDA + } diff --git a/source/lib/include/prod_virial_grad.h b/source/lib/include/prod_virial_grad.h index ab0f84ffec..7a8c87c0dd 100644 --- a/source/lib/include/prod_virial_grad.h +++ b/source/lib/include/prod_virial_grad.h @@ -22,4 +22,26 @@ void prod_virial_grad_r_cpu( const int nloc, const int nnei); +#if GOOGLE_CUDA +template +void prod_virial_grad_a_gpu_cuda( + FPTYPE * grad_net, + const FPTYPE * grad, + const FPTYPE * env_deriv, + const FPTYPE * rij, + const int * nlist, + const int nloc, + const int nnei); + +template +void prod_virial_grad_r_gpu_cuda( + FPTYPE * grad_net, + const FPTYPE * grad, + const FPTYPE * env_deriv, + const FPTYPE * rij, + const int * nlist, + const int nloc, + const int nnei); +#endif // GOOGLE_CUDA + } diff --git a/source/lib/src/cuda/CMakeLists.txt b/source/lib/src/cuda/CMakeLists.txt index f9832caefd..41a2ea091e 100644 --- a/source/lib/src/cuda/CMakeLists.txt +++ b/source/lib/src/cuda/CMakeLists.txt @@ -106,7 +106,7 @@ endif() set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -DCUB_IGNORE_DEPRECATED_CPP_DIALECT -DCUB_IGNORE_DEPRECATED_CPP_DIALECT") set (SOURCE_FILES - prod_env_mat.cu prod_force.cu prod_virial.cu gelu.cu tabulate.cu coord.cu neighbor_list.cu region.cu + prod_env_mat.cu prod_force.cu prod_force_grad.cu prod_virial.cu prod_virial_grad.cu gelu.cu tabulate.cu coord.cu neighbor_list.cu region.cu ) cuda_add_library(deepmd_op_cuda SHARED ${SOURCE_FILES}) diff --git a/source/lib/src/cuda/prod_env_mat.cu b/source/lib/src/cuda/prod_env_mat.cu index fcc1aa970d..807b2c37de 100644 --- a/source/lib/src/cuda/prod_env_mat.cu +++ b/source/lib/src/cuda/prod_env_mat.cu @@ -114,7 +114,7 @@ __global__ void format_nlist_fill_a( } FPTYPE rr = sqrt(dev_dot(diff, diff)); if (rr <= rcut) { - key_in[idy] = type[j_idx] * 1E15+ (int_64)(rr * 1.0E13) / 10000000 * 10000000 + j_idx; + key_in[idy] = type[j_idx] * 1E14+ (int_64)(rr * 1.0E12) / 10000000 * 10000000 + j_idx; } } @@ -142,7 +142,7 @@ __global__ void format_nlist_fill_b( } for (unsigned int kk = 0; key_out[kk] != key_out[max_nbor_size - 1]; kk++) { - const int & nei_type = key_out[kk] / 1E15; + const int & nei_type = key_out[kk] / 1E14; if (nei_iter[nei_type] < sec[nei_type + 1]) { row_nlist[nei_iter[nei_type]++] = key_out[kk] % 10000000; } diff --git a/source/lib/src/cuda/prod_force_grad.cu b/source/lib/src/cuda/prod_force_grad.cu new file mode 100644 index 0000000000..7fd9359cfe --- /dev/null +++ b/source/lib/src/cuda/prod_force_grad.cu @@ -0,0 +1,143 @@ +#include "device.h" +#include "gpu_cuda.h" +#include "prod_force_grad.h" + +template +__device__ inline FPTYPE dev_dot( + const FPTYPE * arr1, + const FPTYPE * arr2) +{ + return arr1[0] * arr2[0] + arr1[1] * arr2[1] + arr1[2] * arr2[2]; +} + +template +__global__ void force_grad_wrt_center_atom( + FPTYPE * grad_net, + const FPTYPE * grad, + const FPTYPE * env_deriv, + const int ndescrpt) +{ + 
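  // force_grad_wrt_center_atom: one block row (blockIdx.x) per local atom,
  // while blockIdx.y tiles the ndescrpt descriptor entries. The three
  // components of the incoming force gradient of the center atom are staged
  // in shared memory, and each thread subtracts
  // dot(grad[i], env_deriv[i, descrpt_idx, :]) from grad_net[i, descrpt_idx].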
__shared__ FPTYPE grad_one[3]; + unsigned int center_idx = blockIdx.x; + unsigned int tid = threadIdx.x; + if(tid < 3){ + grad_one[tid] = grad[center_idx * 3 + tid]; + } + __syncthreads(); + unsigned int descrpt_idx = blockIdx.y * blockDim.x + tid; + if(descrpt_idx < ndescrpt){ + grad_net[center_idx * ndescrpt + descrpt_idx] -= dev_dot(grad_one, env_deriv + center_idx * ndescrpt * 3 + descrpt_idx * 3); + } +} + +template +__global__ void force_grad_wrt_neighbors_a( + FPTYPE * grad_net, + const FPTYPE * grad, + const FPTYPE * env_deriv, + const int * nlist, + const int nloc, + const int nnei) +{ + // idy -> nnei + const unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + const unsigned int idy = blockIdx.y; + const unsigned int idw = threadIdx.y; + if (idx >= nloc) { + return; + } + int j_idx = nlist[idx * nnei + idy]; + if (j_idx < 0) { + return; + } + if (j_idx >= nloc) j_idx = j_idx % nloc; + grad_net[idx * nnei * 4 + idy * 4 + idw] += dev_dot(grad + j_idx * 3, env_deriv + idx * nnei * 4 * 3 + idy * 4 * 3 + idw * 3); +} + +template +__global__ void force_grad_wrt_neighbors_r( + FPTYPE * grad_net, + const FPTYPE * grad, + const FPTYPE * env_deriv, + const int * nlist, + const int nloc, + const int nnei) +{ + // idy -> nnei + const unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + const unsigned int idy = blockIdx.y; + if (idx >= nloc) { + return; + } + int j_idx = nlist[idx * nnei + idy]; + if (j_idx < 0) { + return; + } + if (j_idx >= nloc) j_idx = j_idx % nloc; + grad_net[idx * nnei + idy] += dev_dot(grad + j_idx * 3, env_deriv + idx * nnei * 3 + idy * 3); +} + +namespace deepmd { +template +void prod_force_grad_a_gpu_cuda( + FPTYPE * grad_net, + const FPTYPE * grad, + const FPTYPE * env_deriv, + const int * nlist, + const int nloc, + const int nnei) +{ + const int ndescrpt = nnei * 4; + cudaErrcheck(cudaMemset( + grad_net, + 0.0, sizeof(FPTYPE) * nloc * ndescrpt)); + const int nblock = (ndescrpt + TPB - 1) / TPB; + dim3 block_grid(nloc, nblock); + dim3 thread_grid(TPB, 1); + force_grad_wrt_center_atom<<>>( + grad_net, + grad, env_deriv, ndescrpt); + + const int LEN = 128; + const int nblock_ = (nloc + LEN -1) / LEN; + dim3 block_grid_(nblock_, nnei); + dim3 thread_grid_(LEN, 4); + force_grad_wrt_neighbors_a<<>>( + grad_net, + grad, env_deriv, nlist, nloc, nnei); +} + +template +void prod_force_grad_r_gpu_cuda( + FPTYPE * grad_net, + const FPTYPE * grad, + const FPTYPE * env_deriv, + const int * nlist, + const int nloc, + const int nnei) +{ + const int ndescrpt = nnei * 1; + cudaErrcheck(cudaMemset( + grad_net, + 0.0, sizeof(FPTYPE) * nloc * ndescrpt)); + const int nblock = (ndescrpt + TPB - 1) / TPB; + dim3 block_grid(nloc, nblock); + dim3 thread_grid(TPB, 1); + force_grad_wrt_center_atom<<>>( + grad_net, + grad, env_deriv, ndescrpt); + + const int LEN = 128; + const int nblock_ = (nloc + LEN -1) / LEN; + dim3 block_grid_(nblock_, nnei); + dim3 thread_grid_(LEN, 1); + force_grad_wrt_neighbors_r<<>>( + grad_net, + grad, env_deriv, nlist, nloc, nnei); +} + +template void prod_force_grad_a_gpu_cuda(float * grad_net, const float * grad, const float * env_deriv, const int * nlist, const int nloc, const int nnei); +template void prod_force_grad_a_gpu_cuda(double * grad_net, const double * grad, const double * env_deriv, const int * nlist, const int nloc, const int nnei); +template void prod_force_grad_r_gpu_cuda(float * grad_net, const float * grad, const float * env_deriv, const int * nlist, const int nloc, const int nnei); +template void prod_force_grad_r_gpu_cuda(double 
* grad_net, const double * grad, const double * env_deriv, const int * nlist, const int nloc, const int nnei); +} \ No newline at end of file diff --git a/source/lib/src/cuda/prod_virial.cu b/source/lib/src/cuda/prod_virial.cu index 032e1b1c09..addb4df92a 100644 --- a/source/lib/src/cuda/prod_virial.cu +++ b/source/lib/src/cuda/prod_virial.cu @@ -1,6 +1,34 @@ +#include "device.h" #include "gpu_cuda.h" #include "prod_virial.h" +template < + typename FPTYPE, + int THREADS_PER_BLOCK> +__global__ void atom_virial_reduction( + FPTYPE * virial, + const FPTYPE * atom_virial, + const int nall) +{ + unsigned int bid = blockIdx.x; + unsigned int tid = threadIdx.x; + __shared__ FPTYPE data[THREADS_PER_BLOCK]; + data[tid] = 0.f; + for (int ii = tid; ii < nall; ii += THREADS_PER_BLOCK) { + data[tid] += atom_virial[ii * 9 + bid]; + } + __syncthreads(); + // do reduction in shared memory + for (int ii = THREADS_PER_BLOCK >> 1; ii > 0; ii >>= 1) { + if (tid < ii) { + data[tid] += data[tid + ii]; + } + __syncthreads(); + } + // write result for this block to global memory + if (tid == 0) virial[bid] = data[0]; +} + template __global__ void virial_deriv_wrt_neighbors_a( FPTYPE * virial, @@ -101,6 +129,10 @@ void prod_virial_a_gpu_cuda( virial_deriv_wrt_neighbors_a<<>>( virial, atom_virial, net_deriv, in_deriv, rij, nlist, nloc, nnei); + // reduction atom_virial to virial + atom_virial_reduction <<<9, TPB>>>( + virial, + atom_virial, nall); } template @@ -130,6 +162,10 @@ void prod_virial_r_gpu_cuda( virial_deriv_wrt_neighbors_r<<>>( virial, atom_virial, net_deriv, in_deriv, rij, nlist, nloc, nnei); + // reduction atom_virial to virial + atom_virial_reduction <<<9, TPB>>>( + virial, + atom_virial, nall); } template void prod_virial_a_gpu_cuda(float * virial, float * atom_virial, const float * net_deriv, const float * in_deriv, const float * rij, const int * nlist, const int nloc, const int nall, const int nnei); diff --git a/source/lib/src/cuda/prod_virial_grad.cu b/source/lib/src/cuda/prod_virial_grad.cu new file mode 100644 index 0000000000..2cdd25ec38 --- /dev/null +++ b/source/lib/src/cuda/prod_virial_grad.cu @@ -0,0 +1,141 @@ +#include "device.h" +#include "gpu_cuda.h" +#include "prod_virial_grad.h" + +template +__device__ inline FPTYPE dev_dot9( + const FPTYPE * arr1, + const FPTYPE * arr2) +{ + FPTYPE result = 0.0; + for(int ii=0; ii<9; ii++){ + result += arr1[ii] * arr2[ii]; + } + return result; +} + +template +__global__ void virial_grad_wrt_neighbors_a( + FPTYPE * grad_net, + const FPTYPE * grad, + const FPTYPE * env_deriv, + const FPTYPE * rij, + const int * nlist, + const int nloc, + const int nnei) +{ + // idy -> nnei + const unsigned int tid = threadIdx.x; + const unsigned int idx = blockIdx.x * blockDim.x + tid; + const unsigned int idy = blockIdx.y; + const unsigned int idw = threadIdx.y; + const int ndescrpt = nnei * 4; + __shared__ FPTYPE grad_one[9]; + if(tid < 9){ + grad_one[tid] = grad[tid]; + } + __syncthreads(); + if (idx >= nloc) { + return; + } + int j_idx = nlist[idx * nnei + idy]; + if (j_idx < 0) { + return; + } + FPTYPE tmp[9]; + for (int dd0 = 0; dd0 < 3; ++dd0){ + for (int dd1 = 0; dd1 < 3; ++dd1){ + tmp[dd0 * 3 + dd1] = rij[idx * nnei * 3 + idy * 3 + dd1] * env_deriv[idx * ndescrpt * 3 + idy * 4 * 3 + idw * 3 + dd0]; + } + } + grad_net[idx * ndescrpt + idy * 4 + idw] -= -1.0 * dev_dot9(grad_one, tmp); +} + +template +__global__ void virial_grad_wrt_neighbors_r( + FPTYPE * grad_net, + const FPTYPE * grad, + const FPTYPE * env_deriv, + const FPTYPE * rij, + const int * nlist, + 
const int nloc, + const int nnei) +{ + // idy -> nnei + const unsigned int tid = threadIdx.x; + const unsigned int idx = blockIdx.x * blockDim.x + tid; + const unsigned int idy = blockIdx.y; + const int ndescrpt = nnei; + __shared__ FPTYPE grad_one[9]; + if(tid < 9){ + grad_one[tid] = grad[tid]; + } + __syncthreads(); + if (idx >= nloc) { + return; + } + int j_idx = nlist[idx * nnei + idy]; + if (j_idx < 0) { + return; + } + FPTYPE tmp[9]; + for (int dd0 = 0; dd0 < 3; ++dd0){ + for (int dd1 = 0; dd1 < 3; ++dd1){ + tmp[dd0 * 3 + dd1] = rij[idx * nnei * 3 + idy * 3 + dd1] * env_deriv[idx * ndescrpt * 3 + idy * 3 + dd0]; + } + } + grad_net[idx * ndescrpt + idy] -= -1.0 * dev_dot9(grad_one, tmp); +} + +namespace deepmd { +template +void prod_virial_grad_a_gpu_cuda( + FPTYPE * grad_net, + const FPTYPE * grad, + const FPTYPE * env_deriv, + const FPTYPE * rij, + const int * nlist, + const int nloc, + const int nnei) +{ + const int ndescrpt = nnei * 4; + cudaErrcheck(cudaMemset( + grad_net, + 0.0, sizeof(FPTYPE) * nloc * ndescrpt)); + const int LEN = 128; + const int nblock = (nloc + LEN -1) / LEN; + dim3 block_grid(nblock, nnei); + dim3 thread_grid(LEN, 4); + virial_grad_wrt_neighbors_a<<>>( + grad_net, + grad, env_deriv, rij, nlist, nloc, nnei); +} + +template +void prod_virial_grad_r_gpu_cuda( + FPTYPE * grad_net, + const FPTYPE * grad, + const FPTYPE * env_deriv, + const FPTYPE * rij, + const int * nlist, + const int nloc, + const int nnei) +{ + const int ndescrpt = nnei; + cudaErrcheck(cudaMemset( + grad_net, + 0.0, sizeof(FPTYPE) * nloc * ndescrpt)); + const int LEN = 128; + const int nblock = (nloc + LEN -1) / LEN; + dim3 block_grid(nblock, nnei); + dim3 thread_grid(LEN, 1); + virial_grad_wrt_neighbors_r<<>>( + grad_net, + grad, env_deriv, rij, nlist, nloc, nnei); +} + +template void prod_virial_grad_a_gpu_cuda(float * grad_net, const float * grad, const float * env_deriv, const float * rij, const int * nlist, const int nloc, const int nnei); +template void prod_virial_grad_a_gpu_cuda(double * grad_net, const double * grad, const double * env_deriv, const double * rij, const int * nlist, const int nloc, const int nnei); +template void prod_virial_grad_r_gpu_cuda(float * grad_net, const float * grad, const float * env_deriv, const float * rij, const int * nlist, const int nloc, const int nnei); +template void prod_virial_grad_r_gpu_cuda(double * grad_net, const double * grad, const double * env_deriv, const double * rij, const int * nlist, const int nloc, const int nnei); +} \ No newline at end of file diff --git a/source/lib/tests/test_prod_force_grad_a.cc b/source/lib/tests/test_prod_force_grad_a.cc index 82ab484616..e456e2a8dd 100644 --- a/source/lib/tests/test_prod_force_grad_a.cc +++ b/source/lib/tests/test_prod_force_grad_a.cc @@ -4,6 +4,7 @@ #include "env_mat.h" #include "neighbor_list.h" #include "prod_force_grad.h" +#include "device.h" class TestProdForceGradA : public ::testing::Test { @@ -95,3 +96,33 @@ TEST_F(TestProdForceGradA, cpu) // } // printf("\n"); } + +#if GOOGLE_CUDA +TEST_F(TestProdForceGradA, gpu) +{ + std::vector grad_net(nloc * ndescrpt); + int * nlist_dev = NULL; + double * grad_net_dev = NULL, * grad_dev = NULL, * env_deriv_dev = NULL; + + deepmd::malloc_device_memory_sync(nlist_dev, nlist); + deepmd::malloc_device_memory_sync(grad_dev, grad); + deepmd::malloc_device_memory_sync(env_deriv_dev, env_deriv); + deepmd::malloc_device_memory(grad_net_dev, nloc * ndescrpt); + deepmd::prod_force_grad_a_gpu_cuda(grad_net_dev, grad_dev, env_deriv_dev, nlist_dev, nloc, nnei); 
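  // GPU path validated against the CPU reference: inputs are staged on the
  // device, prod_force_grad_a_gpu_cuda fills grad_net_dev, and the result is
  // copied back to the host before the device buffers are released and each
  // entry is compared against expected_grad_net within 1e-5.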
+ deepmd::memcpy_device_to_host(grad_net_dev, grad_net); + deepmd::delete_device_memory(nlist_dev); + deepmd::delete_device_memory(grad_dev); + deepmd::delete_device_memory(env_deriv_dev); + deepmd::delete_device_memory(grad_net_dev); + + EXPECT_EQ(grad_net.size(), nloc * ndescrpt); + EXPECT_EQ(grad_net.size(), expected_grad_net.size()); + for (int jj = 0; jj < grad_net.size(); ++jj){ + EXPECT_LT(fabs(grad_net[jj] - expected_grad_net[jj]) , 1e-5); + } + // for (int jj = 0; jj < nloc * ndescrpt; ++jj){ + // printf("%8.5f, ", grad_net[jj]); + // } + // printf("\n"); +} +#endif // GOOGLE_CUDA diff --git a/source/lib/tests/test_prod_force_grad_r.cc b/source/lib/tests/test_prod_force_grad_r.cc index 37534db7f8..da4ac96d3b 100644 --- a/source/lib/tests/test_prod_force_grad_r.cc +++ b/source/lib/tests/test_prod_force_grad_r.cc @@ -4,6 +4,7 @@ #include "env_mat.h" #include "neighbor_list.h" #include "prod_force_grad.h" +#include "device.h" class TestProdForceGradR : public ::testing::Test { @@ -95,3 +96,33 @@ TEST_F(TestProdForceGradR, cpu) // } // printf("\n"); } + +#if GOOGLE_CUDA +TEST_F(TestProdForceGradR, gpu) +{ + std::vector grad_net(nloc * ndescrpt); + int * nlist_dev = NULL; + double * grad_net_dev = NULL, * grad_dev = NULL, * env_deriv_dev = NULL; + + deepmd::malloc_device_memory_sync(nlist_dev, nlist); + deepmd::malloc_device_memory_sync(grad_dev, grad); + deepmd::malloc_device_memory_sync(env_deriv_dev, env_deriv); + deepmd::malloc_device_memory(grad_net_dev, nloc * ndescrpt); + deepmd::prod_force_grad_r_gpu_cuda(grad_net_dev, grad_dev, env_deriv_dev, nlist_dev, nloc, nnei); + deepmd::memcpy_device_to_host(grad_net_dev, grad_net); + deepmd::delete_device_memory(nlist_dev); + deepmd::delete_device_memory(grad_dev); + deepmd::delete_device_memory(env_deriv_dev); + deepmd::delete_device_memory(grad_net_dev); + + EXPECT_EQ(grad_net.size(), nloc * ndescrpt); + EXPECT_EQ(grad_net.size(), expected_grad_net.size()); + for (int jj = 0; jj < grad_net.size(); ++jj){ + EXPECT_LT(fabs(grad_net[jj] - expected_grad_net[jj]) , 1e-5); + } + // for (int jj = 0; jj < nloc * ndescrpt; ++jj){ + // printf("%8.5f, ", grad_net[jj]); + // } + // printf("\n"); +} +#endif // GOOGLE_CUDA diff --git a/source/lib/tests/test_prod_virial_a.cc b/source/lib/tests/test_prod_virial_a.cc index 4cade7c771..f63fc00fb5 100644 --- a/source/lib/tests/test_prod_virial_a.cc +++ b/source/lib/tests/test_prod_virial_a.cc @@ -141,13 +141,12 @@ TEST_F(TestProdVirialA, gpu_cuda) deepmd::delete_device_memory(env_deriv_dev); deepmd::delete_device_memory(rij_dev); // virial are not calculated in gpu currently; - for (int ii = 0; ii < 9; ii++) { - virial[ii] = 0; - } - for (int ii = 0; ii < nall * 9; ii++) { - virial[ii % 9] += atom_virial[ii]; - } - + // for (int ii = 0; ii < 9; ii++) { + // virial[ii] = 0; + // } + // for (int ii = 0; ii < nall * 9; ii++) { + // virial[ii % 9] += atom_virial[ii]; + // } EXPECT_EQ(virial.size(), 9); EXPECT_EQ(virial.size(), expected_virial.size()); EXPECT_EQ(atom_virial.size(), nall * 9); diff --git a/source/lib/tests/test_prod_virial_grad_a.cc b/source/lib/tests/test_prod_virial_grad_a.cc index 53ad63e965..461552c5a3 100644 --- a/source/lib/tests/test_prod_virial_grad_a.cc +++ b/source/lib/tests/test_prod_virial_grad_a.cc @@ -4,6 +4,7 @@ #include "env_mat.h" #include "neighbor_list.h" #include "prod_virial_grad.h" +#include "device.h" class TestProdVirialGradA : public ::testing::Test { @@ -99,3 +100,36 @@ TEST_F(TestProdVirialGradA, cpu) // } // printf("\n"); } + +#if GOOGLE_CUDA 
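// The virial-gradient GPU test follows the same device-staging pattern as the
// force-gradient tests above, with rij uploaded as well, since the virial
// gradient contracts the 9-component grad with products of rij and env_deriv.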
+TEST_F(TestProdVirialGradA, gpu) +{ + std::vector grad_net(nloc * ndescrpt); + int n_a_sel = nnei; + int * nlist_dev = NULL; + double * grad_net_dev = NULL, * grad_dev = NULL, * env_deriv_dev = NULL, * rij_dev = NULL; + + deepmd::malloc_device_memory_sync(nlist_dev, nlist); + deepmd::malloc_device_memory_sync(grad_dev, grad); + deepmd::malloc_device_memory_sync(env_deriv_dev, env_deriv); + deepmd::malloc_device_memory_sync(rij_dev, rij); + deepmd::malloc_device_memory(grad_net_dev, nloc * ndescrpt); + deepmd::prod_virial_grad_a_gpu_cuda(grad_net_dev, grad_dev, env_deriv_dev, rij_dev, nlist_dev, nloc, nnei); + deepmd::memcpy_device_to_host(grad_net_dev, grad_net); + deepmd::delete_device_memory(nlist_dev); + deepmd::delete_device_memory(grad_dev); + deepmd::delete_device_memory(env_deriv_dev); + deepmd::delete_device_memory(rij_dev); + deepmd::delete_device_memory(grad_net_dev); + + EXPECT_EQ(grad_net.size(), nloc * ndescrpt); + EXPECT_EQ(grad_net.size(), expected_grad_net.size()); + for (int jj = 0; jj < grad_net.size(); ++jj){ + EXPECT_LT(fabs(grad_net[jj] - expected_grad_net[jj]) , 1e-5); + } + // for (int jj = 0; jj < nloc * ndescrpt; ++jj){ + // printf("%8.5f, ", grad_net[jj]); + // } + // printf("\n"); +} +#endif // GOOGLE_CUDA \ No newline at end of file diff --git a/source/lib/tests/test_prod_virial_grad_r.cc b/source/lib/tests/test_prod_virial_grad_r.cc index 2cb0c91038..3f12599232 100644 --- a/source/lib/tests/test_prod_virial_grad_r.cc +++ b/source/lib/tests/test_prod_virial_grad_r.cc @@ -4,6 +4,7 @@ #include "env_mat.h" #include "neighbor_list.h" #include "prod_virial_grad.h" +#include "device.h" class TestProdVirialGradR : public ::testing::Test { @@ -99,3 +100,36 @@ TEST_F(TestProdVirialGradR, cpu) // } // printf("\n"); } + +#if GOOGLE_CUDA +TEST_F(TestProdVirialGradR, gpu) +{ + std::vector grad_net(nloc * ndescrpt); + int n_a_sel = nnei; + int * nlist_dev = NULL; + double * grad_net_dev = NULL, * grad_dev = NULL, * env_deriv_dev = NULL, * rij_dev = NULL; + + deepmd::malloc_device_memory_sync(nlist_dev, nlist); + deepmd::malloc_device_memory_sync(grad_dev, grad); + deepmd::malloc_device_memory_sync(env_deriv_dev, env_deriv); + deepmd::malloc_device_memory_sync(rij_dev, rij); + deepmd::malloc_device_memory(grad_net_dev, nloc * ndescrpt); + deepmd::prod_virial_grad_r_gpu_cuda(grad_net_dev, grad_dev, env_deriv_dev, rij_dev, nlist_dev, nloc, nnei); + deepmd::memcpy_device_to_host(grad_net_dev, grad_net); + deepmd::delete_device_memory(nlist_dev); + deepmd::delete_device_memory(grad_dev); + deepmd::delete_device_memory(env_deriv_dev); + deepmd::delete_device_memory(rij_dev); + deepmd::delete_device_memory(grad_net_dev); + + EXPECT_EQ(grad_net.size(), nloc * ndescrpt); + EXPECT_EQ(grad_net.size(), expected_grad_net.size()); + for (int jj = 0; jj < grad_net.size(); ++jj){ + EXPECT_LT(fabs(grad_net[jj] - expected_grad_net[jj]) , 1e-5); + } + // for (int jj = 0; jj < nloc * ndescrpt; ++jj){ + // printf("%8.5f, ", grad_net[jj]); + // } + // printf("\n"); +} +#endif // GOOGLE_CUDA diff --git a/source/lib/tests/test_prod_virial_r.cc b/source/lib/tests/test_prod_virial_r.cc index b321454b8e..be7e865962 100644 --- a/source/lib/tests/test_prod_virial_r.cc +++ b/source/lib/tests/test_prod_virial_r.cc @@ -141,13 +141,12 @@ TEST_F(TestProdVirialR, gpu_cuda) deepmd::delete_device_memory(env_deriv_dev); deepmd::delete_device_memory(rij_dev); // virial are not calculated in gpu currently; - for (int ii = 0; ii < 9; ii++) { - virial[ii] = 0; - } - for (int ii = 0; ii < nall * 9; ii++) { - 
virial[ii % 9] += atom_virial[ii]; - } - + // for (int ii = 0; ii < 9; ii++) { + // virial[ii] = 0; + // } + // for (int ii = 0; ii < nall * 9; ii++) { + // virial[ii % 9] += atom_virial[ii]; + // } EXPECT_EQ(virial.size(), 9); EXPECT_EQ(virial.size(), expected_virial.size()); EXPECT_EQ(atom_virial.size(), nall * 9); diff --git a/source/op/CMakeLists.txt b/source/op/CMakeLists.txt index 8ec98461be..c3dc7b6815 100644 --- a/source/op/CMakeLists.txt +++ b/source/op/CMakeLists.txt @@ -5,7 +5,7 @@ set(OP_LIB ${PROJECT_SOURCE_DIR}/lib/src/SimulationRegion.cpp ${PROJECT_SOURCE_D set (OP_CXX_FLAG -D_GLIBCXX_USE_CXX11_ABI=${OP_CXX_ABI} ) file(GLOB OP_SRC prod_force.cc prod_virial.cc descrpt.cc descrpt_se_a_ef.cc descrpt_se_a_ef.cc descrpt_se_a_ef_para.cc descrpt_se_a_ef_vert.cc pair_tab.cc prod_force_multi_device.cc prod_virial_multi_device.cc soft_min.cc soft_min_force.cc soft_min_virial.cc ewald_recp.cc gelu_multi_device.cc map_aparam.cc neighbor_stat.cc unaggregated_grad.cc tabulate_multi_device.cc prod_env_mat_multi_device.cc) file(GLOB OP_CUDA_SRC prod_force.cc prod_virial.cc descrpt.cc prod_env_mat_multi_device.cc pair_tab.cc prod_force_multi_device.cc prod_virial_multi_device.cc soft_min.cc soft_min_force.cc soft_min_virial.cc gelu_multi_device.cc tabulate_multi_device.cc) -file(GLOB OP_GRADS_SRC prod_force_grad.cc prod_force_se_a_grad.cc prod_force_se_r_grad.cc prod_virial_grad.cc prod_virial_se_a_grad.cc prod_virial_se_r_grad.cc soft_min_force_grad.cc soft_min_virial_grad.cc ) +file(GLOB OP_GRADS_SRC prod_force_grad.cc prod_force_grad_multi_device.cc prod_virial_grad.cc prod_virial_grad_multi_device.cc soft_min_force_grad.cc soft_min_virial_grad.cc ) file(GLOB OP_PY *.py) if (BUILD_CPP_IF) diff --git a/source/op/prod_force_grad_multi_device.cc b/source/op/prod_force_grad_multi_device.cc new file mode 100644 index 0000000000..1bda63903b --- /dev/null +++ b/source/op/prod_force_grad_multi_device.cc @@ -0,0 +1,251 @@ +#include "custom_op.h" +#include "prod_force_grad.h" + +REGISTER_OP("ProdForceSeAGrad") + .Attr("T: {float, double}") + .Input("grad: T") + .Input("net_deriv: T") + .Input("in_deriv: T") + .Input("nlist: int32") + .Input("natoms: int32") + .Attr("n_a_sel: int") + .Attr("n_r_sel: int") + .Output("grad_net: T"); + +REGISTER_OP("ProdForceSeRGrad") + .Attr("T: {float, double}") + .Input("grad: T") + .Input("net_deriv: T") + .Input("in_deriv: T") + .Input("nlist: int32") + .Input("natoms: int32") + .Output("grad_net: T"); + +template +class ProdForceSeAGradOp : public OpKernel { +public: + explicit ProdForceSeAGradOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("n_a_sel", &n_a_sel)); + OP_REQUIRES_OK(context, context->GetAttr("n_r_sel", &n_r_sel)); + n_a_shift = n_a_sel * 4; + } + + void Compute(OpKernelContext* context) override { + // Grab the input tensor + int context_input_index = 0; + const Tensor& grad_tensor = context->input(context_input_index++); + const Tensor& net_deriv_tensor = context->input(context_input_index++); + const Tensor& in_deriv_tensor = context->input(context_input_index++); + const Tensor& nlist_tensor = context->input(context_input_index++); + const Tensor& natoms_tensor = context->input(context_input_index++); + + // set size of the sample + TensorShape grad_shape = grad_tensor.shape(); + TensorShape net_deriv_shape = net_deriv_tensor.shape(); + TensorShape in_deriv_shape = in_deriv_tensor.shape(); + TensorShape nlist_shape = nlist_tensor.shape(); + + OP_REQUIRES (context, (grad_shape.dims() == 2), 
errors::InvalidArgument ("Dim of grad should be 2")); + OP_REQUIRES (context, (net_deriv_shape.dims() == 2),errors::InvalidArgument ("Dim of net deriv should be 2")); + OP_REQUIRES (context, (in_deriv_shape.dims() == 2), errors::InvalidArgument ("Dim of input deriv should be 2")); + OP_REQUIRES (context, (nlist_shape.dims() == 2), errors::InvalidArgument ("Dim of nlist should be 2")); + OP_REQUIRES (context, (natoms_tensor.shape().dims() == 1), errors::InvalidArgument ("Dim of natoms should be 1")); + + OP_REQUIRES (context, (natoms_tensor.shape().dim_size(0) >= 3), errors::InvalidArgument ("number of atoms should be larger than (or equal to) 3")); + auto natoms = natoms_tensor .flat(); + + int nframes = net_deriv_tensor.shape().dim_size(0); + int nloc = natoms(0); + int ndescrpt = net_deriv_tensor.shape().dim_size(1) / nloc; + int nnei = nlist_tensor.shape().dim_size(1) / nloc; + + // check the sizes + OP_REQUIRES (context, (nframes == grad_shape.dim_size(0)), errors::InvalidArgument ("number of frames should match")); + OP_REQUIRES (context, (nframes == in_deriv_shape.dim_size(0)), errors::InvalidArgument ("number of frames should match")); + OP_REQUIRES (context, (nframes == nlist_shape.dim_size(0)), errors::InvalidArgument ("number of frames should match")); + + OP_REQUIRES (context, (nloc * 3 == grad_shape.dim_size(1)), errors::InvalidArgument ("input grad shape should be 3 x natoms")); + OP_REQUIRES (context, (nloc * ndescrpt * 3 == in_deriv_shape.dim_size(1)),errors::InvalidArgument ("number of descriptors should match")); + OP_REQUIRES (context, (nnei == n_a_sel + n_r_sel), errors::InvalidArgument ("number of neighbors should match")); + + // Create an output tensor + TensorShape grad_net_shape ; + grad_net_shape.AddDim (nframes); + grad_net_shape.AddDim (nloc * ndescrpt); + + // allocate the output tensor + Tensor* grad_net_tensor = NULL; + int context_output_index = 0; + OP_REQUIRES_OK(context, context->allocate_output( + context_output_index++, + grad_net_shape, + &grad_net_tensor)); + DeviceFunctor() ( + device, + context->eigen_device() + ); + assert (nframes == grad_net_shape.dim_size(0)); + assert (nframes == grad_shape.dim_size(0)); + assert (nframes == net_deriv_tensor.shape().dim_size(0)); + assert (nframes == in_deriv_tensor.shape().dim_size(0)); + assert (nframes == nlist_tensor.shape().dim_size(0)); + assert (nloc * ndescrpt == grad_net_shape.dim_size(1)); + assert (nloc * 3 == grad_shape.dim_size(1)); + assert (nloc * ndescrpt == net_deriv_tensor.shape().dim_size(1)); + assert (nloc * ndescrpt * 3 == in_deriv_tensor.shape().dim_size(1)); + assert (nloc * nnei == nlist_tensor.shape().dim_size(1)); + assert (nnei * 4 == ndescrpt); + // flat the tensors + FPTYPE * p_grad_net = grad_net_tensor->flat().data(); + const FPTYPE * p_grad = grad_tensor.flat().data(); + const FPTYPE * p_net_deriv = net_deriv_tensor.flat().data(); + const FPTYPE * p_in_deriv = in_deriv_tensor.flat().data(); + const int * p_nlist = nlist_tensor.flat().data(); + + for (int kk = 0; kk < nframes; ++kk){ + FPTYPE * grad_net = p_grad_net + kk * nloc * ndescrpt; + const FPTYPE * grad = p_grad + kk * nloc * 3; + const FPTYPE * in_deriv = p_in_deriv + kk * nloc * ndescrpt * 3; + const int * nlist = p_nlist + kk * nloc * nnei; + if (device == "GPU") { + #if GOOGLE_CUDA + deepmd::prod_force_grad_a_gpu_cuda( + grad_net, + grad, in_deriv, nlist, nloc, nnei); + #endif // GOOGLE_CUDA + } + else if (device == "CPU") { + deepmd::prod_force_grad_a_cpu( + grad_net, + grad, in_deriv, nlist, nloc, nnei); + } + } + } 
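  // Compute() validates the input shapes, allocates grad_net with shape
  // [nframes, nloc * ndescrpt], resolves the execution device via
  // DeviceFunctor, and then loops over frames, dispatching each frame's
  // flattened buffers to deepmd::prod_force_grad_a_gpu_cuda (GPU) or
  // deepmd::prod_force_grad_a_cpu (CPU).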
+private: + std::string device; + int n_r_sel, n_a_sel, n_a_shift; +}; + +template +class ProdForceSeRGradOp : public OpKernel +{ +public: + explicit ProdForceSeRGradOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + // Grab the input tensor + int context_input_index = 0; + const Tensor& grad_tensor = context->input(context_input_index++); + const Tensor& net_deriv_tensor = context->input(context_input_index++); + const Tensor& in_deriv_tensor = context->input(context_input_index++); + const Tensor& nlist_tensor = context->input(context_input_index++); + const Tensor& natoms_tensor = context->input(context_input_index++); + + // set size of the sample + TensorShape grad_shape = grad_tensor.shape(); + TensorShape net_deriv_shape = net_deriv_tensor.shape(); + TensorShape in_deriv_shape = in_deriv_tensor.shape(); + TensorShape nlist_shape = nlist_tensor.shape(); + + OP_REQUIRES (context, (grad_shape.dims() == 2), errors::InvalidArgument ("Dim of grad should be 2")); + OP_REQUIRES (context, (net_deriv_shape.dims() == 2),errors::InvalidArgument ("Dim of net deriv should be 2")); + OP_REQUIRES (context, (in_deriv_shape.dims() == 2), errors::InvalidArgument ("Dim of input deriv should be 2")); + OP_REQUIRES (context, (nlist_shape.dims() == 2), errors::InvalidArgument ("Dim of nlist should be 2")); + OP_REQUIRES (context, (natoms_tensor.shape().dims() == 1), errors::InvalidArgument ("Dim of natoms should be 1")); + + OP_REQUIRES (context, (natoms_tensor.shape().dim_size(0) >= 3), errors::InvalidArgument ("number of atoms should be larger than (or equal to) 3")); + auto natoms = natoms_tensor .flat(); + + int nframes = net_deriv_tensor.shape().dim_size(0); + int nloc = natoms(0); + int ndescrpt = net_deriv_tensor.shape().dim_size(1) / nloc; + int nnei = nlist_tensor.shape().dim_size(1) / nloc; + + // check the sizes + OP_REQUIRES (context, (nframes == grad_shape.dim_size(0)), errors::InvalidArgument ("number of frames should match")); + OP_REQUIRES (context, (nframes == in_deriv_shape.dim_size(0)), errors::InvalidArgument ("number of frames should match")); + OP_REQUIRES (context, (nframes == nlist_shape.dim_size(0)), errors::InvalidArgument ("number of frames should match")); + + OP_REQUIRES (context, (nloc * 3 == grad_shape.dim_size(1)), errors::InvalidArgument ("input grad shape should be 3 x natoms")); + OP_REQUIRES (context, (nloc * ndescrpt * 3 == in_deriv_shape.dim_size(1)),errors::InvalidArgument ("number of descriptors should match")); + + // Create an output tensor + TensorShape grad_net_shape ; + grad_net_shape.AddDim (nframes); + grad_net_shape.AddDim (nloc * ndescrpt); + + // allocate the output tensor + Tensor* grad_net_tensor = NULL; + int context_output_index = 0; + OP_REQUIRES_OK(context, context->allocate_output( + context_output_index++, + grad_net_shape, + &grad_net_tensor)); + DeviceFunctor() ( + device, + context->eigen_device() + ); + assert (nframes == grad_net_shape.dim_size(0)); + assert (nframes == grad_shape.dim_size(0)); + assert (nframes == net_deriv_tensor.shape().dim_size(0)); + assert (nframes == in_deriv_tensor.shape().dim_size(0)); + assert (nframes == nlist_tensor.shape().dim_size(0)); + assert (nloc * ndescrpt == grad_net_shape.dim_size(1)); + assert (nloc * 3 == grad_shape.dim_size(1)); + assert (nloc * ndescrpt == net_deriv_tensor.shape().dim_size(1)); + assert (nloc * ndescrpt * 3 == in_deriv_tensor.shape().dim_size(1)); + assert (nloc * nnei == nlist_tensor.shape().dim_size(1)); + assert (nnei * 1 
== ndescrpt); + // flat the tensors + FPTYPE * p_grad_net = grad_net_tensor->flat().data(); + const FPTYPE * p_grad = grad_tensor.flat().data(); + const FPTYPE * p_net_deriv = net_deriv_tensor.flat().data(); + const FPTYPE * p_in_deriv = in_deriv_tensor.flat().data(); + const int * p_nlist = nlist_tensor.flat().data(); + + // loop over frames + for (int kk = 0; kk < nframes; ++kk){ + FPTYPE * grad_net = p_grad_net + kk * nloc * ndescrpt; + const FPTYPE * grad = p_grad + kk * nloc * 3; + const FPTYPE * in_deriv = p_in_deriv + kk * nloc * ndescrpt * 3; + const int * nlist = p_nlist + kk * nloc * nnei; + if (device == "GPU") { + #if GOOGLE_CUDA + deepmd::prod_force_grad_r_gpu_cuda( + grad_net, + grad, in_deriv, nlist, nloc, nnei); + #endif // GOOGLE_CUDA + } + else if (device == "CPU") { + deepmd::prod_force_grad_r_cpu( + grad_net, + grad, in_deriv, nlist, nloc, nnei); + } + } + } + private: + std::string device; +}; + +// Register the CPU kernels. +#define REGISTER_CPU(T) \ +REGISTER_KERNEL_BUILDER( \ + Name("ProdForceSeAGrad").Device(DEVICE_CPU).TypeConstraint("T"), \ + ProdForceSeAGradOp); \ +REGISTER_KERNEL_BUILDER( \ + Name("ProdForceSeRGrad").Device(DEVICE_CPU).TypeConstraint("T"), \ + ProdForceSeRGradOp); +REGISTER_CPU(float); +REGISTER_CPU(double); +// Register the GPU kernels. +#if GOOGLE_CUDA +#define REGISTER_GPU(T) \ +REGISTER_KERNEL_BUILDER( \ + Name("ProdForceSeAGrad").Device(DEVICE_GPU).TypeConstraint("T").HostMemory("natoms"), \ + ProdForceSeAGradOp); \ +REGISTER_KERNEL_BUILDER( \ + Name("ProdForceSeRGrad").Device(DEVICE_GPU).TypeConstraint("T").HostMemory("natoms"), \ + ProdForceSeRGradOp); +REGISTER_GPU(float); +REGISTER_GPU(double); +#endif // GOOGLE_CUDA \ No newline at end of file diff --git a/source/op/prod_virial_grad_multi_device.cc b/source/op/prod_virial_grad_multi_device.cc new file mode 100644 index 0000000000..ac74d1d141 --- /dev/null +++ b/source/op/prod_virial_grad_multi_device.cc @@ -0,0 +1,275 @@ +#include "custom_op.h" +#include "prod_virial_grad.h" + +REGISTER_OP("ProdVirialSeAGrad") + .Attr("T: {float, double}") + .Input("grad: T") + .Input("net_deriv: T") + .Input("in_deriv: T") + .Input("rij: T") + .Input("nlist: int32") + .Input("natoms: int32") + .Attr("n_a_sel: int") + .Attr("n_r_sel: int") + .Output("grad_net: T"); + +REGISTER_OP("ProdVirialSeRGrad") + .Attr("T: {float, double}") + .Input("grad: T") + .Input("net_deriv: T") + .Input("in_deriv: T") + .Input("rij: T") + .Input("nlist: int32") + .Input("natoms: int32") + .Output("grad_net: T"); + +template +class ProdVirialSeAGradOp : public OpKernel +{ +public: + explicit ProdVirialSeAGradOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("n_a_sel", &n_a_sel)); + OP_REQUIRES_OK(context, context->GetAttr("n_r_sel", &n_r_sel)); + n_a_shift = n_a_sel * 4; + } + + void Compute(OpKernelContext* context) override { + // Grab the input tensor + int context_input_index = 0; + const Tensor& grad_tensor = context->input(context_input_index++); + const Tensor& net_deriv_tensor = context->input(context_input_index++); + const Tensor& in_deriv_tensor = context->input(context_input_index++); + const Tensor& rij_tensor = context->input(context_input_index++); + const Tensor& nlist_tensor = context->input(context_input_index++); + const Tensor& natoms_tensor = context->input(context_input_index++); + + // set size of the sample + TensorShape grad_shape = grad_tensor.shape(); + TensorShape net_deriv_shape = net_deriv_tensor.shape(); + TensorShape in_deriv_shape = 
in_deriv_tensor.shape(); + TensorShape rij_shape = rij_tensor.shape(); + TensorShape nlist_shape = nlist_tensor.shape(); + + OP_REQUIRES (context, (grad_shape.dims() == 2), errors::InvalidArgument ("Dim of grad should be 2")); + OP_REQUIRES (context, (net_deriv_shape.dims() == 2),errors::InvalidArgument ("Dim of net deriv should be 2")); + OP_REQUIRES (context, (in_deriv_shape.dims() == 2), errors::InvalidArgument ("Dim of input deriv should be 2")); + OP_REQUIRES (context, (rij_shape.dims() == 2), errors::InvalidArgument ("Dim of rij should be 2")); + OP_REQUIRES (context, (nlist_shape.dims() == 2), errors::InvalidArgument ("Dim of nlist should be 2")); + OP_REQUIRES (context, (natoms_tensor.shape().dims() == 1), errors::InvalidArgument ("Dim of natoms should be 1")); + + OP_REQUIRES (context, (natoms_tensor.shape().dim_size(0) >= 3), errors::InvalidArgument ("number of atoms should be larger than (or equal to) 3")); + auto natoms = natoms_tensor .flat(); + + int nframes = net_deriv_tensor.shape().dim_size(0); + int nloc = natoms(0); + int ndescrpt = net_deriv_tensor.shape().dim_size(1) / nloc; + int nnei = nlist_tensor.shape().dim_size(1) / nloc; + + // check the sizes + OP_REQUIRES (context, (nframes == grad_shape.dim_size(0)), errors::InvalidArgument ("number of frames should match")); + OP_REQUIRES (context, (nframes == in_deriv_shape.dim_size(0)), errors::InvalidArgument ("number of frames should match")); + OP_REQUIRES (context, (nframes == rij_shape.dim_size(0)), errors::InvalidArgument ("number of frames should match")); + OP_REQUIRES (context, (nframes == nlist_shape.dim_size(0)), errors::InvalidArgument ("number of frames should match")); + + OP_REQUIRES (context, (9 == grad_shape.dim_size(1)), errors::InvalidArgument ("input grad shape should be 3 x natoms")); + OP_REQUIRES (context, (nloc * ndescrpt * 3 == in_deriv_shape.dim_size(1)),errors::InvalidArgument ("number of descriptors should match")); + OP_REQUIRES (context, (nloc * nnei * 3 == rij_shape.dim_size(1)), errors::InvalidArgument ("dim of rij should be nnei * 3")); + OP_REQUIRES (context, (nnei == n_a_sel + n_r_sel), errors::InvalidArgument ("number of neighbors should match")); + + // Create an output tensor + TensorShape grad_net_shape ; + grad_net_shape.AddDim (nframes); + grad_net_shape.AddDim (nloc * ndescrpt); + + // allocate the output tensor + Tensor* grad_net_tensor = NULL; + int context_output_index = 0; + OP_REQUIRES_OK(context, context->allocate_output( + context_output_index++, + grad_net_shape, + &grad_net_tensor)); + DeviceFunctor() ( + device, + context->eigen_device() + ); + assert (nframes == grad_net_shape.dim_size(0)); + assert (nframes == grad_shape.dim_size(0)); + assert (nframes == net_deriv_tensor.shape().dim_size(0)); + assert (nframes == in_deriv_tensor.shape().dim_size(0)); + assert (nframes == rij_tensor.shape().dim_size(0)); + assert (nframes == nlist_tensor.shape().dim_size(0)); + assert (nloc * ndescrpt == grad_net_shape.dim_size(1)); + assert (9 == grad_shape.dim_size(1)); + assert (nloc * ndescrpt == net_deriv_tensor.shape().dim_size(1)); + assert (nloc * ndescrpt * 3 == in_deriv_tensor.shape().dim_size(1)); + assert (nloc * nnei * 3 == rij_tensor.shape().dim_size(1)); + assert (nloc * nnei == nlist_tensor.shape().dim_size(1)); + assert (nnei * 4 == ndescrpt); + + // flat the tensors + FPTYPE * p_grad_net = grad_net_tensor->flat().data(); + const FPTYPE * p_grad = grad_tensor.flat().data(); + const FPTYPE * p_net_deriv = net_deriv_tensor.flat().data(); + const FPTYPE * p_in_deriv = 
in_deriv_tensor.flat().data(); + const FPTYPE * p_rij = rij_tensor.flat().data(); + const int * p_nlist = nlist_tensor.flat().data(); + + // loop over frames + for (int kk = 0; kk < nframes; ++kk){ + FPTYPE * grad_net = p_grad_net + kk * nloc * ndescrpt; + const FPTYPE * grad = p_grad + kk * nloc * 3; + const FPTYPE * in_deriv = p_in_deriv + kk * nloc * ndescrpt * 3; + const FPTYPE * rij = p_rij + kk * nloc * nnei * 3; + const int * nlist = p_nlist + kk * nloc * nnei; + if (device == "GPU") { + #if GOOGLE_CUDA + deepmd::prod_virial_grad_a_gpu_cuda( + grad_net, + grad, in_deriv, rij, nlist, nloc, nnei); + #endif // GOOGLE_CUDA + } + else if (device == "CPU") { + deepmd::prod_virial_grad_a_cpu( + grad_net, + grad, in_deriv, rij, nlist, nloc, nnei); + } + } + } +private: + std::string device; + int n_r_sel, n_a_sel, n_a_shift; +}; + +template +class ProdVirialSeRGradOp : public OpKernel +{ +public: + explicit ProdVirialSeRGradOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + // Grab the input tensor + int context_input_index = 0; + const Tensor& grad_tensor = context->input(context_input_index++); + const Tensor& net_deriv_tensor = context->input(context_input_index++); + const Tensor& in_deriv_tensor = context->input(context_input_index++); + const Tensor& rij_tensor = context->input(context_input_index++); + const Tensor& nlist_tensor = context->input(context_input_index++); + const Tensor& natoms_tensor = context->input(context_input_index++); + + // set size of the sample + TensorShape grad_shape = grad_tensor.shape(); + TensorShape net_deriv_shape = net_deriv_tensor.shape(); + TensorShape in_deriv_shape = in_deriv_tensor.shape(); + TensorShape rij_shape = rij_tensor.shape(); + TensorShape nlist_shape = nlist_tensor.shape(); + + OP_REQUIRES (context, (grad_shape.dims() == 2), errors::InvalidArgument ("Dim of grad should be 2")); + OP_REQUIRES (context, (net_deriv_shape.dims() == 2),errors::InvalidArgument ("Dim of net deriv should be 2")); + OP_REQUIRES (context, (in_deriv_shape.dims() == 2), errors::InvalidArgument ("Dim of input deriv should be 2")); + OP_REQUIRES (context, (rij_shape.dims() == 2), errors::InvalidArgument ("Dim of rij should be 2")); + OP_REQUIRES (context, (nlist_shape.dims() == 2), errors::InvalidArgument ("Dim of nlist should be 2")); + OP_REQUIRES (context, (natoms_tensor.shape().dims() == 1), errors::InvalidArgument ("Dim of natoms should be 1")); + + OP_REQUIRES (context, (natoms_tensor.shape().dim_size(0) >= 3), errors::InvalidArgument ("number of atoms should be larger than (or equal to) 3")); + auto natoms = natoms_tensor .flat(); + + int nframes = net_deriv_tensor.shape().dim_size(0); + int nloc = natoms(0); + int ndescrpt = net_deriv_tensor.shape().dim_size(1) / nloc; + int nnei = nlist_tensor.shape().dim_size(1) / nloc; + + // check the sizes + OP_REQUIRES (context, (nframes == grad_shape.dim_size(0)), errors::InvalidArgument ("number of frames should match")); + OP_REQUIRES (context, (nframes == in_deriv_shape.dim_size(0)), errors::InvalidArgument ("number of frames should match")); + OP_REQUIRES (context, (nframes == rij_shape.dim_size(0)), errors::InvalidArgument ("number of frames should match")); + OP_REQUIRES (context, (nframes == nlist_shape.dim_size(0)), errors::InvalidArgument ("number of frames should match")); + + OP_REQUIRES (context, (9 == grad_shape.dim_size(1)), errors::InvalidArgument ("input grad shape should be 3 x natoms")); + OP_REQUIRES (context, (nloc * ndescrpt * 3 == 
in_deriv_shape.dim_size(1)),errors::InvalidArgument ("number of descriptors should match")); + OP_REQUIRES (context, (nloc * nnei * 3 == rij_shape.dim_size(1)), errors::InvalidArgument ("dim of rij should be nnei * 3")); + + // Create an output tensor + TensorShape grad_net_shape ; + grad_net_shape.AddDim (nframes); + grad_net_shape.AddDim (nloc * ndescrpt); + + // allocate the output tensor + Tensor* grad_net_tensor = NULL; + int context_output_index = 0; + OP_REQUIRES_OK(context, context->allocate_output( + context_output_index++, + grad_net_shape, + &grad_net_tensor)); + DeviceFunctor() ( + device, + context->eigen_device() + ); + assert (nframes == grad_net_shape.dim_size(0)); + assert (nframes == grad_shape.dim_size(0)); + assert (nframes == net_deriv_tensor.shape().dim_size(0)); + assert (nframes == in_deriv_tensor.shape().dim_size(0)); + assert (nframes == rij_tensor.shape().dim_size(0)); + assert (nframes == nlist_tensor.shape().dim_size(0)); + assert (nloc * ndescrpt == grad_net_shape.dim_size(1)); + assert (9 == grad_shape.dim_size(1)); + assert (nloc * ndescrpt == net_deriv_tensor.shape().dim_size(1)); + assert (nloc * ndescrpt * 3 == in_deriv_tensor.shape().dim_size(1)); + assert (nloc * nnei * 3 == rij_tensor.shape().dim_size(1)); + assert (nloc * nnei == nlist_tensor.shape().dim_size(1)); + assert (nnei * 1 == ndescrpt); + + // flat the tensors + FPTYPE * p_grad_net = grad_net_tensor->flat().data(); + const FPTYPE * p_grad = grad_tensor.flat().data(); + const FPTYPE * p_net_deriv = net_deriv_tensor.flat().data(); + const FPTYPE * p_in_deriv = in_deriv_tensor.flat().data(); + const FPTYPE * p_rij = rij_tensor.flat().data(); + const int * p_nlist = nlist_tensor.flat().data(); + + // loop over frames + for (int kk = 0; kk < nframes; ++kk){ + FPTYPE * grad_net = p_grad_net + kk * nloc * ndescrpt; + const FPTYPE * grad = p_grad + kk * nloc * 3; + const FPTYPE * in_deriv = p_in_deriv + kk * nloc * ndescrpt * 3; + const FPTYPE * rij = p_rij + kk * nloc * nnei * 3; + const int * nlist = p_nlist + kk * nloc * nnei; + if (device == "GPU") { + #if GOOGLE_CUDA + deepmd::prod_virial_grad_r_gpu_cuda( + grad_net, + grad, in_deriv, rij, nlist, nloc, nnei); + #endif // GOOGLE_CUDA + } + else if (device == "CPU") { + deepmd::prod_virial_grad_r_cpu( + grad_net, + grad, in_deriv, rij, nlist, nloc, nnei); + } + } + } +private: + std::string device; +}; + +// Register the CPU kernels. +#define REGISTER_CPU(T) \ +REGISTER_KERNEL_BUILDER( \ + Name("ProdVirialSeAGrad").Device(DEVICE_CPU).TypeConstraint("T"), \ + ProdVirialSeAGradOp); \ +REGISTER_KERNEL_BUILDER( \ + Name("ProdVirialSeRGrad").Device(DEVICE_CPU).TypeConstraint("T"), \ + ProdVirialSeRGradOp); +REGISTER_CPU(float); +REGISTER_CPU(double); +// Register the GPU kernels. +#if GOOGLE_CUDA +#define REGISTER_GPU(T) \ +REGISTER_KERNEL_BUILDER( \ + Name("ProdVirialSeAGrad").Device(DEVICE_GPU).TypeConstraint("T").HostMemory("natoms"), \ + ProdVirialSeAGradOp); \ +REGISTER_KERNEL_BUILDER( \ + Name("ProdVirialSeRGrad").Device(DEVICE_GPU).TypeConstraint("T").HostMemory("natoms"), \ + ProdVirialSeRGradOp); +REGISTER_GPU(float); +REGISTER_GPU(double); +#endif // GOOGLE_CUDA \ No newline at end of file
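
For reference, the host-side call pattern exercised by the GPU unit tests above can be isolated into a small standalone driver. This is a minimal sketch, assuming the device helpers declared in device.h (malloc_device_memory_sync, malloc_device_memory, memcpy_device_to_host, delete_device_memory) and the new prod_force_grad_a_gpu_cuda entry point; the wrapper name run_force_grad_a is hypothetical.

#include <vector>
#include "device.h"
#include "prod_force_grad.h"

// Hypothetical driver: mirrors TestProdForceGradA::gpu without the gtest harness.
void run_force_grad_a(std::vector<double>& grad_net,        // out: nloc * nnei * 4
                      const std::vector<double>& grad,      // incoming gradient w.r.t. forces, nloc * 3
                      const std::vector<double>& env_deriv, // descriptor derivative, nloc * nnei * 4 * 3
                      const std::vector<int>& nlist,        // neighbor list, nloc * nnei
                      const int nloc, const int nnei)
{
  const int ndescrpt = nnei * 4;
  grad_net.resize(nloc * ndescrpt);
  int* nlist_dev = NULL;
  double *grad_net_dev = NULL, *grad_dev = NULL, *env_deriv_dev = NULL;
  // upload the inputs and allocate the output buffer on the device
  deepmd::malloc_device_memory_sync(nlist_dev, nlist);
  deepmd::malloc_device_memory_sync(grad_dev, grad);
  deepmd::malloc_device_memory_sync(env_deriv_dev, env_deriv);
  deepmd::malloc_device_memory(grad_net_dev, nloc * ndescrpt);
  // launch the CUDA kernels added in prod_force_grad.cu
  deepmd::prod_force_grad_a_gpu_cuda(grad_net_dev, grad_dev, env_deriv_dev, nlist_dev, nloc, nnei);
  // copy the result back and release the device buffers
  deepmd::memcpy_device_to_host(grad_net_dev, grad_net);
  deepmd::delete_device_memory(nlist_dev);
  deepmd::delete_device_memory(grad_dev);
  deepmd::delete_device_memory(env_deriv_dev);
  deepmd::delete_device_memory(grad_net_dev);
}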