Merge pull request #490 from iProzd/api
API: GPU update for prod_force_grad & prod_virial_grad
amcadmus authored Apr 9, 2021
2 parents 0ae18d3 + 0cc22bc commit 7fa54f7
Showing 16 changed files with 1,034 additions and 18 deletions.
20 changes: 20 additions & 0 deletions source/lib/include/prod_force_grad.h
@@ -20,4 +20,24 @@ void prod_force_grad_r_cpu(
const int nloc,
const int nnei);

#if GOOGLE_CUDA
template<typename FPTYPE>
void prod_force_grad_a_gpu_cuda(
FPTYPE * grad_net,
const FPTYPE * grad,
const FPTYPE * env_deriv,
const int * nlist,
const int nloc,
const int nnei);

template<typename FPTYPE>
void prod_force_grad_r_gpu_cuda(
FPTYPE * grad_net,
const FPTYPE * grad,
const FPTYPE * env_deriv,
const int * nlist,
const int nloc,
const int nnei);
#endif // GOOGLE_CUDA

}
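
These declarations mirror the existing CPU entry points one-for-one. A minimal host-side usage sketch for the "a" variant follows; the buffer names and the allocation/upload steps are illustrative assumptions, not part of this commit. Per the implementation below, grad_net is zeroed inside the call, so the caller only allocates it and uploads the inputs.

// Hypothetical host-side call, assuming double precision and device buffers:
const int ndescrpt = nnei * 4;  // 4 descriptor components per neighbor in the "a" case
double *d_grad_net, *d_grad, *d_env_deriv;
int *d_nlist;
cudaMalloc(&d_grad_net, sizeof(double) * nloc * ndescrpt);
cudaMalloc(&d_grad, sizeof(double) * nloc * 3);
cudaMalloc(&d_env_deriv, sizeof(double) * nloc * ndescrpt * 3);
cudaMalloc(&d_nlist, sizeof(int) * nloc * nnei);
// ... cudaMemcpy grad, env_deriv and nlist to the device ...
deepmd::prod_force_grad_a_gpu_cuda(d_grad_net, d_grad, d_env_deriv, d_nlist, nloc, nnei);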
22 changes: 22 additions & 0 deletions source/lib/include/prod_virial_grad.h
@@ -22,4 +22,26 @@ void prod_virial_grad_r_cpu(
const int nloc,
const int nnei);

#if GOOGLE_CUDA
template<typename FPTYPE>
void prod_virial_grad_a_gpu_cuda(
FPTYPE * grad_net,
const FPTYPE * grad,
const FPTYPE * env_deriv,
const FPTYPE * rij,
const int * nlist,
const int nloc,
const int nnei);

template<typename FPTYPE>
void prod_virial_grad_r_gpu_cuda(
FPTYPE * grad_net,
const FPTYPE * grad,
const FPTYPE * env_deriv,
const FPTYPE * rij,
const int * nlist,
const int nloc,
const int nnei);
#endif // GOOGLE_CUDA

}
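
The virial variants take the same arguments plus the relative coordinates rij. A hypothetical caller-side dispatch (not part of this commit) shows how the GOOGLE_CUDA guard is typically consumed, assuming the CPU entry point shares the signature declared above:

template<typename FPTYPE>
void prod_virial_grad_a(
    FPTYPE * grad_net, const FPTYPE * grad, const FPTYPE * env_deriv,
    const FPTYPE * rij, const int * nlist, const int nloc, const int nnei)
{
#if GOOGLE_CUDA
  // expects device pointers
  deepmd::prod_virial_grad_a_gpu_cuda(grad_net, grad, env_deriv, rij, nlist, nloc, nnei);
#else
  // expects host pointers
  deepmd::prod_virial_grad_a_cpu(grad_net, grad, env_deriv, rij, nlist, nloc, nnei);
#endif
}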
2 changes: 1 addition & 1 deletion source/lib/src/cuda/CMakeLists.txt
@@ -106,7 +106,7 @@ endif()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -DCUB_IGNORE_DEPRECATED_CPP_DIALECT -DCUB_IGNORE_DEPRECATED_CPP_DIALECT")

set (SOURCE_FILES
- prod_env_mat.cu prod_force.cu prod_virial.cu gelu.cu tabulate.cu coord.cu neighbor_list.cu region.cu
+ prod_env_mat.cu prod_force.cu prod_force_grad.cu prod_virial.cu prod_virial_grad.cu gelu.cu tabulate.cu coord.cu neighbor_list.cu region.cu
)

cuda_add_library(deepmd_op_cuda SHARED ${SOURCE_FILES})
4 changes: 2 additions & 2 deletions source/lib/src/cuda/prod_env_mat.cu
@@ -114,7 +114,7 @@ __global__ void format_nlist_fill_a(
}
FPTYPE rr = sqrt(dev_dot(diff, diff));
if (rr <= rcut) {
- key_in[idy] = type[j_idx] * 1E15+ (int_64)(rr * 1.0E13) / 10000000 * 10000000 + j_idx;
+ key_in[idy] = type[j_idx] * 1E14+ (int_64)(rr * 1.0E12) / 10000000 * 10000000 + j_idx;
}
}

@@ -142,7 +142,7 @@ __global__ void format_nlist_fill_b(
}

for (unsigned int kk = 0; key_out[kk] != key_out[max_nbor_size - 1]; kk++) {
- const int & nei_type = key_out[kk] / 1E15;
+ const int & nei_type = key_out[kk] / 1E14;
if (nei_iter[nei_type] < sec[nei_type + 1]) {
row_nlist[nei_iter[nei_type]++] = key_out[kk] % 10000000;
}
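
The constants changed here pack (neighbor type, distance, neighbor index) into a single int_64 sort key: the type in the top decimal digits, the truncated distance in the middle, and j_idx in the low seven digits, which the / 10000000 * 10000000 step clears for it. A sketch of the encode/decode pair as this diff defines it (the motivation for moving from 1E15/1E13 to 1E14/1E12 is not stated in the commit; presumably it keeps the packed key within int_64 range):

// Encode: sorting the keys orders neighbors by type, then distance, then index.
int_64 key = type[j_idx] * 1E14                            // type in the high digits
           + (int_64)(rr * 1.0E12) / 10000000 * 10000000   // distance, low 7 digits zeroed
           + j_idx;                                        // requires j_idx < 10000000
// Decode:
const int nei_type = key / 1E14;      // recover the neighbor type
const int nei_idx  = key % 10000000;  // recover the neighbor index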
143 changes: 143 additions & 0 deletions source/lib/src/cuda/prod_force_grad.cu
@@ -0,0 +1,143 @@
#include "device.h"
#include "gpu_cuda.h"
#include "prod_force_grad.h"

template<typename FPTYPE>
__device__ inline FPTYPE dev_dot(
const FPTYPE * arr1,
const FPTYPE * arr2)
{
return arr1[0] * arr2[0] + arr1[1] * arr2[1] + arr1[2] * arr2[2];
}

template<typename FPTYPE>
__global__ void force_grad_wrt_center_atom(
FPTYPE * grad_net,
const FPTYPE * grad,
const FPTYPE * env_deriv,
const int ndescrpt)
{
__shared__ FPTYPE grad_one[3];
unsigned int center_idx = blockIdx.x;
unsigned int tid = threadIdx.x;
if(tid < 3){
grad_one[tid] = grad[center_idx * 3 + tid];
}
__syncthreads();
unsigned int descrpt_idx = blockIdx.y * blockDim.x + tid;
if(descrpt_idx < ndescrpt){
grad_net[center_idx * ndescrpt + descrpt_idx] -= dev_dot(grad_one, env_deriv + center_idx * ndescrpt * 3 + descrpt_idx * 3);
}
}

template<typename FPTYPE>
__global__ void force_grad_wrt_neighbors_a(
FPTYPE * grad_net,
const FPTYPE * grad,
const FPTYPE * env_deriv,
const int * nlist,
const int nloc,
const int nnei)
{
// idy -> nnei
const unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x;
const unsigned int idy = blockIdx.y;
const unsigned int idw = threadIdx.y;
if (idx >= nloc) {
return;
}
int j_idx = nlist[idx * nnei + idy];
if (j_idx < 0) {
return;
}
if (j_idx >= nloc) j_idx = j_idx % nloc;
grad_net[idx * nnei * 4 + idy * 4 + idw] += dev_dot(grad + j_idx * 3, env_deriv + idx * nnei * 4 * 3 + idy * 4 * 3 + idw * 3);
}

template<typename FPTYPE>
__global__ void force_grad_wrt_neighbors_r(
FPTYPE * grad_net,
const FPTYPE * grad,
const FPTYPE * env_deriv,
const int * nlist,
const int nloc,
const int nnei)
{
// idy -> nnei
const unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x;
const unsigned int idy = blockIdx.y;
if (idx >= nloc) {
return;
}
int j_idx = nlist[idx * nnei + idy];
if (j_idx < 0) {
return;
}
if (j_idx >= nloc) j_idx = j_idx % nloc;
grad_net[idx * nnei + idy] += dev_dot(grad + j_idx * 3, env_deriv + idx * nnei * 3 + idy * 3);
}

namespace deepmd {
template<typename FPTYPE>
void prod_force_grad_a_gpu_cuda(
FPTYPE * grad_net,
const FPTYPE * grad,
const FPTYPE * env_deriv,
const int * nlist,
const int nloc,
const int nnei)
{
const int ndescrpt = nnei * 4;
cudaErrcheck(cudaMemset(
grad_net,
0.0, sizeof(FPTYPE) * nloc * ndescrpt));
const int nblock = (ndescrpt + TPB - 1) / TPB;
dim3 block_grid(nloc, nblock);
dim3 thread_grid(TPB, 1);
force_grad_wrt_center_atom<<<block_grid, thread_grid>>>(
grad_net,
grad, env_deriv, ndescrpt);

const int LEN = 128;
const int nblock_ = (nloc + LEN -1) / LEN;
dim3 block_grid_(nblock_, nnei);
dim3 thread_grid_(LEN, 4);
force_grad_wrt_neighbors_a<<<block_grid_, thread_grid_>>>(
grad_net,
grad, env_deriv, nlist, nloc, nnei);
}

template<typename FPTYPE>
void prod_force_grad_r_gpu_cuda(
FPTYPE * grad_net,
const FPTYPE * grad,
const FPTYPE * env_deriv,
const int * nlist,
const int nloc,
const int nnei)
{
const int ndescrpt = nnei * 1;
cudaErrcheck(cudaMemset(
grad_net,
0.0, sizeof(FPTYPE) * nloc * ndescrpt));
const int nblock = (ndescrpt + TPB - 1) / TPB;
dim3 block_grid(nloc, nblock);
dim3 thread_grid(TPB, 1);
force_grad_wrt_center_atom<<<block_grid, thread_grid>>>(
grad_net,
grad, env_deriv, ndescrpt);

const int LEN = 128;
const int nblock_ = (nloc + LEN -1) / LEN;
dim3 block_grid_(nblock_, nnei);
dim3 thread_grid_(LEN, 1);
force_grad_wrt_neighbors_r<<<block_grid_, thread_grid_>>>(
grad_net,
grad, env_deriv, nlist, nloc, nnei);
}

template void prod_force_grad_a_gpu_cuda<float>(float * grad_net, const float * grad, const float * env_deriv, const int * nlist, const int nloc, const int nnei);
template void prod_force_grad_a_gpu_cuda<double>(double * grad_net, const double * grad, const double * env_deriv, const int * nlist, const int nloc, const int nnei);
template void prod_force_grad_r_gpu_cuda<float>(float * grad_net, const float * grad, const float * env_deriv, const int * nlist, const int nloc, const int nnei);
template void prod_force_grad_r_gpu_cuda<double>(double * grad_net, const double * grad, const double * env_deriv, const int * nlist, const int nloc, const int nnei);
}
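
Taken together, force_grad_wrt_center_atom and force_grad_wrt_neighbors_a accumulate, for each local atom i and descriptor slot d, dot products of the force gradient with the environment-matrix derivative. A serial sketch of the same arithmetic for the "a" case (a reading aid, not the library's CPU implementation):

for (int i = 0; i < nloc; ++i) {
  for (int d = 0; d < nnei * 4; ++d) {
    const FPTYPE * ed = env_deriv + (i * nnei * 4 + d) * 3;
    // center-atom term, subtracted
    FPTYPE acc = -(grad[i*3+0]*ed[0] + grad[i*3+1]*ed[1] + grad[i*3+2]*ed[2]);
    int j = nlist[i * nnei + d / 4];
    if (j >= 0) {
      if (j >= nloc) j %= nloc;  // fold ghost-atom indices back onto local atoms
      // neighbor term, added
      acc += grad[j*3+0]*ed[0] + grad[j*3+1]*ed[1] + grad[j*3+2]*ed[2];
    }
    grad_net[i * nnei * 4 + d] = acc;
  }
}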
36 changes: 36 additions & 0 deletions source/lib/src/cuda/prod_virial.cu
@@ -1,6 +1,34 @@
#include "device.h"
#include "gpu_cuda.h"
#include "prod_virial.h"

template <
typename FPTYPE,
int THREADS_PER_BLOCK>
__global__ void atom_virial_reduction(
FPTYPE * virial,
const FPTYPE * atom_virial,
const int nall)
{
unsigned int bid = blockIdx.x;
unsigned int tid = threadIdx.x;
__shared__ FPTYPE data[THREADS_PER_BLOCK];
data[tid] = 0.f;
for (int ii = tid; ii < nall; ii += THREADS_PER_BLOCK) {
data[tid] += atom_virial[ii * 9 + bid];
}
__syncthreads();
// do reduction in shared memory
for (int ii = THREADS_PER_BLOCK >> 1; ii > 0; ii >>= 1) {
if (tid < ii) {
data[tid] += data[tid + ii];
}
__syncthreads();
}
// write result for this block to global memory
if (tid == 0) virial[bid] = data[0];
}

template<typename FPTYPE>
__global__ void virial_deriv_wrt_neighbors_a(
FPTYPE * virial,
@@ -101,6 +129,10 @@ void prod_virial_a_gpu_cuda(
virial_deriv_wrt_neighbors_a<<<block_grid, thread_grid>>>(
virial, atom_virial,
net_deriv, in_deriv, rij, nlist, nloc, nnei);
// reduce atom_virial into virial
atom_virial_reduction<FPTYPE, TPB> <<<9, TPB>>>(
virial,
atom_virial, nall);
}

template<typename FPTYPE>
@@ -130,6 +162,10 @@ void prod_virial_r_gpu_cuda(
virial_deriv_wrt_neighbors_r<<<block_grid, thread_grid>>>(
virial, atom_virial,
net_deriv, in_deriv, rij, nlist, nloc, nnei);
// reduce atom_virial into virial
atom_virial_reduction<FPTYPE, TPB> <<<9, TPB>>>(
virial,
atom_virial, nall);
}

template void prod_virial_a_gpu_cuda<float>(float * virial, float * atom_virial, const float * net_deriv, const float * in_deriv, const float * rij, const int * nlist, const int nloc, const int nall, const int nnei);
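
atom_virial_reduction is launched with nine blocks, one per component of the 3x3 virial; each block strides its TPB threads over all nall atoms and then tree-reduces the partial sums in shared memory. Its serial equivalent (a sketch) is simply:

for (int comp = 0; comp < 9; ++comp) {
  FPTYPE sum = 0;
  for (int ii = 0; ii < nall; ++ii) {
    sum += atom_virial[ii * 9 + comp];
  }
  virial[comp] = sum;
}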