Skip to content

Commit

Permalink
add explanations for se_a_grad_grad (#2903)
Browse files Browse the repository at this point in the history
  • Loading branch information
nahso authored Oct 9, 2023
1 parent d5b1423 commit d8ee74b
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 0 deletions.
20 changes: 20 additions & 0 deletions source/lib/src/gpu/tabulate.cu
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,26 @@ __global__ void tabulate_fusion_se_a_grad_grad_fifth_order_polynomial(
res_grad += res_grad * t;
}

/*
* `dz_dy`(or `iteratorC`) represents the derivative of the variable `out`
* in the function `tabulate_fusion_se_a_fifth_order_polynomial`.
*
* The expression `em[em_index] * res_grad * dz_xx + dz_dy_dem[em_index] *
* res` utilizes the product rule of derivatives: `(f * g)' = f' * g + f *
* g'`.
*
* This expression can be alternatively expressed as:
* `dz_dy_dem[em_index] * res + em[em_index] * (res_grad * dz_xx)`.
* Note that we can refer to `dz_dy_dem` as `em'`
*
* Therefore, we can rewrite this expression as: `em' * res + em * res'`,
* where `em'` is the derivative of `em` and `res'` is the derivative of
* `res`. Additionally, `res'` can be further represented as: `res_grad *
* dz_xx`.
*
* If `enable_se_atten` is true, `res` will be `res * t + res`, and `res'`
* will become `(res_grad * t + res_grad) * dz_xx`.
*/
for (int kk = 0; kk < MTILE; kk++) {
int em_index = block_idx * nnei * MTILE + ii * MTILE + kk;
iteratorC[kk * last_layer_size + thread_idx] +=
Expand Down
21 changes: 21 additions & 0 deletions source/lib/src/tabulate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,27 @@ void deepmd::tabulate_fusion_se_a_grad_grad_cpu(FPTYPE* dz_dy,
var += var * t;
var_grad += var_grad * t;
}

/*
* `dz_dy` represents the derivative of the variable `out` in the
* function `deepmd::tabulate_fusion_se_a_cpu`.
*
* The expression `var * hh[0] + dz_xx * var_grad * ll[0]` utilizes the
* product rule of derivatives: `(f * g)' = f' * g + f * g'`.
*
* This expression can be alternatively expressed as:
* `hh[0] * var + ll[0] * (dz_xx * var_grad)`.
* Note that `hh[0]` is one element of `em`, and `ll[0]` is one element
* of `dz_dy_dem` which is `em'`.
*
* Therefore, we can rewrite this expression as: `em' * var + em *
* var'`, where `em'` is the derivative of `em` and `var'` is the
* derivative of `var`. Additionally, `var'` can be further represented
* as: `var_grad * dz_xx`.
*
* If `enable_se_atten` is true, `var` will be `var * t + var`, and
* `var'` will be `(var_grad * t + var_grad) * dz_xx`.
*/
if (unloop) {
dz_dy[ii * last_layer_size * 4 + 0 * last_layer_size + kk] +=
(nnei - jj) * (var * hh[0] + dz_xx * var_grad * ll[0]);
Expand Down

0 comments on commit d8ee74b

Please sign in to comment.