Se atten grad grad (#2898)
nahso authored Oct 7, 2023
1 parent ba3376b commit 14c9964
Showing 5 changed files with 196 additions and 79 deletions.
deepmd/op/_tabulate_grad.py (5 changes: 3 additions & 2 deletions)
@@ -55,16 +55,17 @@ def _tabulate_fusion_se_atten_grad_cc(op, dy):
         op.outputs[0],
         is_sorted=op.get_attr("is_sorted"),
     )
-    return [None, None, dy_dx, dy_df, dy_dtwo]
+    return [None, None, dy_dx, dy_df, None]
 
 
 @ops.RegisterGradient("TabulateFusionSeAttenGrad")
 def _tabulate_fusion_se_atten_grad_grad_cc(op, dy, dy_, dy_dtwo):
-    dz_dy = op_module.tabulate_fusion_se_a_grad_grad(
+    dz_dy = op_module.tabulate_fusion_se_atten_grad_grad(
         op.inputs[0],
         op.inputs[1],
         op.inputs[2],
         op.inputs[3],
+        op.inputs[4],
         dy,
         dy_,
         op.inputs[6],
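
Together, these two hunks fix the double-backward path for the attention descriptor: the gradient returned for the two_embed input changes, and the second-order gradient now dispatches to the se_atten-specific kernel, forwarding op.inputs[4] (the attention embedding two_embed) instead of reusing the se_a kernel, which never sees that embedding. Below is a minimal numpy sketch, not the real kernels, of why that input matters; it assumes the scaling y = p(x) * (1 + t) that the C++ hunks further down make explicit.

# Minimal numpy sketch (hypothetical, not the real kernels). With attention,
# the tabulated output is scaled as y = p(x) * (1 + t), so dy/dx carries the
# same (1 + t) factor; a kernel that never sees t (the se_a path) returns
# wrong higher-order gradients.
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal()             # one em_x entry
t = rng.normal()             # the matching two_embed entry
coef = rng.normal(size=6)    # fifth-order polynomial coefficients (the table)

p = np.polynomial.Polynomial(coef)
dp = p.deriv()

dy_dx_atten = dp(x) * (1.0 + t)   # se_atten: derivative of p(x) * (1 + t)
dy_dx_se_a = dp(x)                # se_a: no attention factor

# Central-difference check of the attention-scaled derivative.
h = 1e-6
fd = (p(x + h) - p(x - h)) / (2 * h) * (1.0 + t)
assert np.isclose(dy_dx_atten, fd, rtol=1e-4)
assert not np.isclose(dy_dx_se_a, fd)  # dropping (1 + t) is measurably wrong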
source/lib/include/tabulate.h (2 changes: 2 additions & 0 deletions)
@@ -35,6 +35,7 @@ void tabulate_fusion_se_a_grad_grad_cpu(FPTYPE* dz_dy,
                                         const FPTYPE* table_info,
                                         const FPTYPE* em_x,
                                         const FPTYPE* em,
+                                        const FPTYPE* two_embed,
                                         const FPTYPE* dz_dy_dem_x,
                                         const FPTYPE* dz_dy_dem,
                                         const int nloc,
@@ -141,6 +142,7 @@ void tabulate_fusion_se_a_grad_grad_gpu(FPTYPE* dz_dy,
                                         const FPTYPE* table_info,
                                         const FPTYPE* em_x,
                                         const FPTYPE* em,
+                                        const FPTYPE* two_embed,
                                         const FPTYPE* dz_dy_dem_x,
                                         const FPTYPE* dz_dy_dem,
                                         const int nloc,
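
Only signatures change in this header: the CPU and GPU grad-grad entry points each gain a two_embed pointer. Judging from the implementations below, existing se_a callers stay compatible by convention rather than by an overload, since a null two_embed disables the attention term. A hypothetical Python restatement of that convention, with None standing in for nullptr:

# Hypothetical sketch of the nullable-parameter convention: one entry point
# serves both descriptors, and two_embed=None selects the plain se_a path.
def grad_grad_scaling(value, grad, two_embed=None):
    if two_embed is not None:  # mirrors: enable_se_atten = two_embed != nullptr
        value = value * (1.0 + two_embed)
        grad = grad * (1.0 + two_embed)
    return value, grad


print(grad_grad_scaling(2.0, 3.0))       # se_a path: (2.0, 3.0)
print(grad_grad_scaling(2.0, 3.0, 0.5))  # se_atten path: (3.0, 4.5)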
source/lib/src/gpu/tabulate.cu (17 changes: 14 additions & 3 deletions)
@@ -354,6 +354,7 @@ __global__ void tabulate_fusion_se_a_grad_grad_fifth_order_polynomial(
     const FPTYPE* table,
     const FPTYPE* em_x,
     const FPTYPE* em,
+    const FPTYPE* two_embed,
     const FPTYPE* dz_dy_dem_x,
     const FPTYPE* dz_dy_dem,
     const FPTYPE lower,
@@ -364,6 +365,7 @@
     const int nnei,
     const int last_layer_size,
     const bool is_sorted) {
+  bool enable_se_atten = two_embed != nullptr;
   GPU_DYNAMIC_SHARED_MEM_DECL(int, _data);
   const int_64 block_idx = blockIdx.x;   // nloc
   const int thread_idx = threadIdx.x;    // last_layer_size
@@ -402,6 +404,12 @@ __global__ void tabulate_fusion_se_a_grad_grad_fifth_order_polynomial(
              ((FPTYPE)4. * var[4] + (FPTYPE)5. * var[5] * xx) * xx) *
                 xx) *
            xx;
+    if (enable_se_atten) {
+      FPTYPE t = two_embed[block_idx * nnei * last_layer_size +
+                           ii * last_layer_size + thread_idx];
+      res += res * t;
+      res_grad += res_grad * t;
+    }
 
     for (int kk = 0; kk < MTILE; kk++) {
       int em_index = block_idx * nnei * MTILE + ii * MTILE + kk;
@@ -769,6 +777,7 @@ void tabulate_fusion_se_a_grad_grad_gpu(FPTYPE* dz_dy,
                                         const FPTYPE* table_info,
                                         const FPTYPE* em_x,
                                         const FPTYPE* em,
+                                        const FPTYPE* two_embed,
                                         const FPTYPE* dz_dy_dem_x,
                                         const FPTYPE* dz_dy_dem,
                                         const int nloc,
@@ -783,9 +792,9 @@ void tabulate_fusion_se_a_grad_grad_gpu(FPTYPE* dz_dy,
   DPErrcheck(gpuMemset(dz_dy, 0, sizeof(FPTYPE) * nloc * 4 * last_layer_size));
   tabulate_fusion_se_a_grad_grad_fifth_order_polynomial<FPTYPE, MM, KK>
       <<<nloc, last_layer_size, sizeof(FPTYPE) * MM * last_layer_size>>>(
-          dz_dy, table, em_x, em, dz_dy_dem_x, dz_dy_dem, table_info[0],
-          table_info[1], table_info[2], table_info[3], table_info[4], nnei,
-          last_layer_size, is_sorted);
+          dz_dy, table, em_x, em, two_embed, dz_dy_dem_x, dz_dy_dem,
+          table_info[0], table_info[1], table_info[2], table_info[3],
+          table_info[4], nnei, last_layer_size, is_sorted);
   DPErrcheck(gpuGetLastError());
   DPErrcheck(gpuDeviceSynchronize());
 }
@@ -989,6 +998,7 @@ template void tabulate_fusion_se_a_grad_grad_gpu<float>(
     const float* table_info,
     const float* em_x,
     const float* em,
+    const float* two_embed,
     const float* dz_dy_dem_x,
     const float* dz_dy_dem,
     const int nloc,
@@ -1001,6 +1011,7 @@ template void tabulate_fusion_se_a_grad_grad_gpu<double>(
     const double* table_info,
     const double* em_x,
     const double* em,
+    const double* two_embed,
     const double* dz_dy_dem_x,
     const double* dz_dy_dem,
     const int nloc,
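
The heart of the CUDA change is the enable_se_atten branch. Writing p for the tabulated fifth-order polynomial, x for the abscissa, and t for the two_embed entry addressed by (block_idx, ii, thread_idx), the updates res += res * t and res_grad += res_grad * t are exactly multiplication by (1 + t):

    \tilde{p}(x) = p(x)\,(1 + t), \qquad \frac{d\tilde{p}}{dx}(x) = p'(x)\,(1 + t).

Value and derivative pick up the identical factor, and no product-rule cross term appears because t does not depend on x.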
source/lib/src/tabulate.cc (10 changes: 10 additions & 0 deletions)
@@ -247,12 +247,14 @@ void deepmd::tabulate_fusion_se_a_grad_grad_cpu(FPTYPE* dz_dy,
                                                 const FPTYPE* table_info,
                                                 const FPTYPE* em_x,
                                                 const FPTYPE* em,
+                                                const FPTYPE* two_embed,
                                                 const FPTYPE* dz_dy_dem_x,
                                                 const FPTYPE* dz_dy_dem,
                                                 const int nloc,
                                                 const int nnei,
                                                 const int last_layer_size,
                                                 const bool is_sorted) {
+  bool enable_se_atten = two_embed != nullptr;
   memset(dz_dy, 0, sizeof(FPTYPE) * nloc * 4 * last_layer_size);
   const FPTYPE lower = table_info[0];
   const FPTYPE upper = table_info[1];
@@ -298,6 +300,12 @@ void deepmd::tabulate_fusion_se_a_grad_grad_cpu(FPTYPE* dz_dy,
               ((FPTYPE)3. * a3 + ((FPTYPE)4. * a4 + (FPTYPE)5. * a5 * xx) * xx) *
                   xx) *
              xx;
+      if (enable_se_atten) {
+        FPTYPE t = two_embed[ii * nnei * last_layer_size +
+                             jj * last_layer_size + kk];
+        var += var * t;
+        var_grad += var_grad * t;
+      }
       if (unloop) {
         dz_dy[ii * last_layer_size * 4 + 0 * last_layer_size + kk] +=
             (nnei - jj) * (var * hh[0] + dz_xx * var_grad * ll[0]);
@@ -660,6 +668,7 @@ template void deepmd::tabulate_fusion_se_a_grad_grad_cpu<float>(
     const float* table_info,
     const float* em_x,
     const float* em,
+    const float* two_embed,
     const float* dz_dy_dem_x,
     const float* dz_dy_dem,
     const int nloc,
@@ -672,6 +681,7 @@ template void deepmd::tabulate_fusion_se_a_grad_grad_cpu<double>(
     const double* table_info,
     const double* em_x,
     const double* em,
+    const double* two_embed,
     const double* dz_dy_dem_x,
     const double* dz_dy_dem,
     const int nloc,
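
The CPU loop mirrors the GPU kernel; the flat index two_embed[ii * nnei * last_layer_size + jj * last_layer_size + kk] implies a row-major [nloc, nnei, last_layer_size] layout. A short numpy sketch with hypothetical shapes shows the same update applied tensor-wide:

# Hypothetical numpy restatement of the scalar updates var += var * t and
# var_grad += var_grad * t: both are elementwise multiplication by (1 + t),
# with two_embed stored flat in row-major [nloc, nnei, last_layer_size] order.
import numpy as np

nloc, nnei, last_layer_size = 2, 3, 4
rng = np.random.default_rng(1)
var = rng.normal(size=(nloc, nnei, last_layer_size))        # p(x) values
var_grad = rng.normal(size=(nloc, nnei, last_layer_size))   # p'(x) values
two_embed = rng.normal(size=nloc * nnei * last_layer_size)  # flat, as in C++

t = two_embed.reshape(nloc, nnei, last_layer_size)
assert np.allclose(var + var * t, var * (1 + t))
assert np.allclose(var_grad + var_grad * t, var_grad * (1 + t))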
[Diff for the fifth changed file was not rendered on this page.]
