
Commit ba5e8db: merge devel
Han Wang committed Oct 13, 2023
2 parents: 7ceabd2 + 8bc4e3f
Showing 16 changed files with 260 additions and 170 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -2,7 +2,7 @@
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.5.0
hooks:
- id: trailing-whitespace
exclude: "^.+\\.pbtxt$"
23 changes: 14 additions & 9 deletions deepmd/descriptor/se_atten.py
@@ -564,7 +564,7 @@ def build(
self.filter_precision,
)
self.negative_mask = -(2 << 32) * (1.0 - self.nmask)
# hard coding the magnitude of attention weight shift
# hard coding the magnitude of attention weight shift
self.smth_attn_w_shift = 20.0
# only used when tensorboard was set as true
tf.summary.histogram("descrpt", self.descrpt)
@@ -601,7 +601,9 @@ def build(
)
self.recovered_r = (
tf.reshape(
tf.slice(tf.reshape(self.descrpt_reshape, [-1, 4]), [0, 0], [-1, 1]),
tf.slice(
tf.reshape(self.descrpt_reshape, [-1, 4]), [0, 0], [-1, 1]
),
[-1, natoms[0], self.sel_all_a[0]],
)
* self.std_looked_up
@@ -870,18 +872,21 @@ def _scaled_dot_attn(
if self.smooth:
# (nb x nloc) x nsel
nsel = self.sel_all_a[0]
attn = ((attn + self.smth_attn_w_shift) *
tf.reshape(self.recovered_switch, [-1,1,nsel]) *
tf.reshape(self.recovered_switch, [-1,nsel,1]) -
self.smth_attn_w_shift)
attn = (attn + self.smth_attn_w_shift) * tf.reshape(
self.recovered_switch, [-1, 1, nsel]
) * tf.reshape(
self.recovered_switch, [-1, nsel, 1]
) - self.smth_attn_w_shift
else:
attn *= self.nmask
attn += self.negative_mask
attn = tf.nn.softmax(attn, axis=-1)
if self.smooth:
attn = (attn *
tf.reshape(self.recovered_switch, [-1,1,nsel]) *
tf.reshape(self.recovered_switch, [-1,nsel,1]))
attn = (
attn
* tf.reshape(self.recovered_switch, [-1, 1, nsel])
* tf.reshape(self.recovered_switch, [-1, nsel, 1])
)
else:
attn *= tf.reshape(self.nmask, [-1, attn.shape[-1], 1])
if save_weights:
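A short reading aid for the hunk above (the reformatting itself is behavior-preserving, black-style line wrapping only): with `self.smooth` enabled, the attention logits appear to be shifted by `smth_attn_w_shift` (s = 20), damped by the recovered switch function along both the query and key axes, unshifted, passed through the softmax, and damped again. Written out, with `sw` the per-neighbor switch values, this reads roughly as

```latex
% Paraphrase of the smooth-attention masking read from the diff above; not part of the commit.
\tilde{a}_{ij} = (a_{ij} + s)\,\mathrm{sw}_i\,\mathrm{sw}_j - s, \qquad
w_{ij} = \operatorname{softmax}_j\!\bigl(\tilde{a}_{ij}\bigr)\,\mathrm{sw}_i\,\mathrm{sw}_j
```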
3 changes: 2 additions & 1 deletion deepmd/op/_tabulate_grad.py
@@ -55,7 +55,7 @@ def _tabulate_fusion_se_atten_grad_cc(op, dy):
op.outputs[0],
is_sorted=op.get_attr("is_sorted"),
)
return [None, None, dy_dx, dy_df, None]
return [None, None, dy_dx, dy_df, dy_dtwo]


@ops.RegisterGradient("TabulateFusionSeAttenGrad")
@@ -68,6 +68,7 @@ def _tabulate_fusion_se_atten_grad_grad_cc(op, dy, dy_, dy_dtwo):
op.inputs[4],
dy,
dy_,
dy_dtwo,
op.inputs[6],
is_sorted=op.get_attr("is_sorted"),
)
2 changes: 1 addition & 1 deletion doc/troubleshooting/howtoset_num_nodes.md
@@ -16,7 +16,7 @@ Set the number of processes with:
```bash
mpirun -np $num_nodes dp
```
Note that `mpirun` here should be the same as the MPI used to build software. For example, one can use `mpirun -h` and `lmp -h` to see if `mpirun` and LAMMPS has the same MPI version.
Note that `mpirun` here should be the same as the MPI used to build the software. For example, one can use `mpirun --version` and `lmp -h` to see if `mpirun` and LAMMPS have the same MPI version.

Sometimes, `$num_nodes` and the node information can be given directly by the HPC scheduler system, if the MPI used here is the same as the MPI used to build the scheduler system. Otherwise, one has to assign this information manually.

20 changes: 10 additions & 10 deletions source/api_cc/src/DataModifier.cc
@@ -101,18 +101,18 @@ void DipoleChargeModifier::run_model(
Tensor output_f = output_tensors[cc++];
Tensor output_v = output_tensors[cc++];
Tensor output_av = output_tensors[cc++];
assert(output_f.dims() == 2), "dim of output tensor should be 2";
assert(output_v.dims() == 2), "dim of output tensor should be 2";
assert(output_av.dims() == 2), "dim of output tensor should be 2";
assert(output_f.dims() == 2 && "dim of output tensor should be 2");
assert(output_v.dims() == 2 && "dim of output tensor should be 2");
assert(output_av.dims() == 2 && "dim of output tensor should be 2");
int nframes = output_f.dim_size(0);
int natoms = output_f.dim_size(1) / 3;
assert(output_f.dim_size(0) == 1), "nframes should match";
assert(natoms == nall), "natoms should be nall";
assert(output_v.dim_size(0) == nframes), "nframes should match";
assert(output_v.dim_size(1) == 9), "dof of virial should be 9";
assert(output_av.dim_size(0) == nframes), "nframes should match";
assert(output_av.dim_size(1) == natoms * 9),
"dof of atom virial should be 9 * natoms";
assert(output_f.dim_size(0) == 1 && "nframes should match");
assert(natoms == nall && "natoms should be nall");
assert(output_v.dim_size(0) == nframes && "nframes should match");
assert(output_v.dim_size(1) == 9 && "dof of virial should be 9");
assert(output_av.dim_size(0) == nframes && "nframes should match");
assert(output_av.dim_size(1) == natoms * 9 &&
"dof of atom virial should be 9 * natoms");

auto of = output_f.flat<MODELTYPE>();
auto ov = output_v.flat<MODELTYPE>();
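A note on the assert rewrites in this file and the C++ files below: the old form `assert(cond), "message";` uses the comma operator, so the message string is a separate expression that is evaluated and thrown away; it never reaches the assertion diagnostic. The new form `assert(cond && "message")` puts the message inside the asserted expression, where a string literal is always truthy, so the check is unchanged and the message is shown when the assertion fails. A minimal standalone sketch (not project code):

```cpp
// Standalone sketch of the assert idiom being fixed in these files;
// illustrative only, not code from the commit.
#include <cassert>

int main() {
  int nframes = 2;

  // Old form: `assert(cond), "msg";` is parsed with the comma operator, so the
  // string literal is a separate, discarded expression. The assertion still
  // checks `cond`, but "nframes should match" never appears in the diagnostic.
  // assert(nframes == 1), "nframes should match";

  // Fixed form: the message is part of the asserted expression. A string
  // literal is always truthy, so the condition is unchanged, and the message
  // is printed as part of the stringified expression when the assert fires.
  assert(nframes == 2 && "nframes should match");
  return 0;
}
```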
38 changes: 20 additions & 18 deletions source/api_cc/src/DeepTensor.cc
@@ -201,25 +201,27 @@ void DeepTensor::run_model(
Tensor output_at = output_tensors[3];
Tensor output_av = output_tensors[4];
// this is the new model, output has to be rank 2 tensor
assert(output_gt.dims() == 2), "dim of output tensor should be 2";
assert(output_f.dims() == 2), "dim of output tensor should be 2";
assert(output_v.dims() == 2), "dim of output tensor should be 2";
assert(output_at.dims() == 2), "dim of output tensor should be 2";
assert(output_av.dims() == 2), "dim of output tensor should be 2";
assert(output_gt.dims() == 2 && "dim of output tensor should be 2");
assert(output_f.dims() == 2 && "dim of output tensor should be 2");
assert(output_v.dims() == 2 && "dim of output tensor should be 2");
assert(output_at.dims() == 2 && "dim of output tensor should be 2");
assert(output_av.dims() == 2 && "dim of output tensor should be 2");
// also check the tensor shapes
assert(output_gt.dim_size(0) == 1), "nframes should match";
assert(output_gt.dim_size(1) == odim), "dof of global tensor should be odim";
assert(output_f.dim_size(0) == 1), "nframes should match";
assert(output_f.dim_size(1) == odim * nall * 3),
"dof of force should be odim * nall * 3";
assert(output_v.dim_size(0) == 1), "nframes should match";
assert(output_v.dim_size(1) == odim * 9), "dof of virial should be odim * 9";
assert(output_at.dim_size(0) == 1), "nframes should match";
assert(output_at.dim_size(1) == nsel * odim),
"dof of atomic tensor should be nsel * odim";
assert(output_av.dim_size(0) == 1), "nframes should match";
assert(output_av.dim_size(1) == odim * nall * 9),
"dof of atomic virial should be odim * nall * 9";
assert(output_gt.dim_size(0) == 1 && "nframes should match");
assert(output_gt.dim_size(1) == odim &&
"dof of global tensor should be odim");
assert(output_f.dim_size(0) == 1 && "nframes should match");
assert(output_f.dim_size(1) == odim * nall * 3 &&
"dof of force should be odim * nall * 3");
assert(output_v.dim_size(0) == 1 && "nframes should match");
assert(output_v.dim_size(1) == odim * 9 &&
"dof of virial should be odim * 9");
assert(output_at.dim_size(0) == 1 && "nframes should match");
assert(output_at.dim_size(1) == nsel * odim &&
"dof of atomic tensor should be nsel * odim");
assert(output_av.dim_size(0) == 1 && "nframes should match");
assert(output_av.dim_size(1) == odim * nall * 9 &&
"dof of atomic virial should be odim * nall * 9");

auto ogt = output_gt.flat<ENERGYTYPE>();
auto of = output_f.flat<MODELTYPE>();
18 changes: 9 additions & 9 deletions source/api_cc/src/common.cc
@@ -849,13 +849,13 @@ void deepmd::select_map(std::vector<VT>& out,
const int& nall2) {
for (int kk = 0; kk < nframes; ++kk) {
#ifdef DEBUG
assert(in.size() / stride * stride == in.size()),
"in size should be multiples of stride"
assert(in.size() / stride * stride == in.size() &&
"in size should be multiples of stride")
#endif
for (int ii = 0; ii < in.size() / stride / nframes; ++ii) {
#ifdef DEBUG
assert(ii < idx_map.size()), "idx goes over the idx map size";
assert(idx_map[ii] < out.size()), "mappped idx goes over the out size";
assert(ii < idx_map.size() && "idx goes over the idx map size");
assert(idx_map[ii] < out.size() && "mappped idx goes over the out size");
#endif
if (idx_map[ii] >= 0) {
int to_ii = idx_map[ii];
@@ -896,13 +896,13 @@ void deepmd::select_map_inv(std::vector<VT>& out,
const std::vector<int>& idx_map,
const int& stride) {
#ifdef DEBUG
assert(in.size() / stride * stride == in.size()),
"in size should be multiples of stride"
assert(in.size() / stride * stride == in.size() &&
"in size should be multiples of stride");
#endif
for (int ii = 0; ii < out.size() / stride; ++ii) {
for (int ii = 0; ii < out.size() / stride; ++ii) {
#ifdef DEBUG
assert(ii < idx_map.size()), "idx goes over the idx map size";
assert(idx_map[ii] < in.size()), "from idx goes over the in size";
assert(ii < idx_map.size() && "idx goes over the idx map size");
assert(idx_map[ii] < in.size() && "from idx goes over the in size");
#endif
if (idx_map[ii] >= 0) {
int from_ii = idx_map[ii];
12 changes: 6 additions & 6 deletions source/lib/include/ComputeDescriptor.h
@@ -829,8 +829,8 @@ void compute_descriptor_se_a_extf(std::vector<double> &descrpt_a,
ef[ii] = ef_[ii];
}
}
assert(fabs(deepmd::dot3(ef, ef) - 1.0) < 1e-12),
"ef should be a normalized std::vector";
assert(fabs(deepmd::dot3(ef, ef) - 1.0) < 1e-12 &&
"ef should be a normalized std::vector");

// compute the diff of the neighbors
std::vector<std::vector<double> > sel_a_diff(sec_a.back());
@@ -970,8 +970,8 @@ void compute_descriptor_se_a_ef_para(std::vector<double> &descrpt_a,
ef[ii] = ef_[ii];
}
}
assert(fabs(deepmd::dot3(ef, ef) - 1.0) < 1e-12),
"ef should be a normalized vector";
assert(fabs(deepmd::dot3(ef, ef) - 1.0) < 1e-12 &&
"ef should be a normalized vector");

// compute the diff of the neighbors
std::vector<std::vector<double> > sel_a_diff(sec_a.back());
@@ -1107,8 +1107,8 @@ void compute_descriptor_se_a_ef_vert(std::vector<double> &descrpt_a,
ef[ii] = ef_[ii];
}
}
assert(fabs(deepmd::dot3(ef, ef) - 1.0) < 1e-12),
"ef should be a normalized vector";
assert(fabs(deepmd::dot3(ef, ef) - 1.0) < 1e-12 &&
"ef should be a normalized vector");

// compute the diff of the neighbors
std::vector<std::vector<double> > sel_a_diff(sec_a.back());
4 changes: 4 additions & 0 deletions source/lib/include/tabulate.h
@@ -18,6 +18,7 @@ void tabulate_fusion_se_a_cpu(FPTYPE* out,
template <typename FPTYPE>
void tabulate_fusion_se_a_grad_cpu(FPTYPE* dy_dem_x,
FPTYPE* dy_dem,
FPTYPE* dy_dtwo,
const FPTYPE* table,
const FPTYPE* table_info,
const FPTYPE* em_x,
@@ -38,6 +39,7 @@ void tabulate_fusion_se_a_grad_grad_cpu(FPTYPE* dz_dy,
const FPTYPE* two_embed,
const FPTYPE* dz_dy_dem_x,
const FPTYPE* dz_dy_dem,
const FPTYPE* dz_dy_dtwo,
const int nloc,
const int nnei,
const int last_layer_size,
Expand Down Expand Up @@ -125,6 +127,7 @@ void tabulate_fusion_se_a_gpu(FPTYPE* out,
template <typename FPTYPE>
void tabulate_fusion_se_a_grad_gpu(FPTYPE* dy_dem_x,
FPTYPE* dy_dem,
FPTYPE* dy_dtwo,
const FPTYPE* table,
const FPTYPE* table_info,
const FPTYPE* em_x,
@@ -145,6 +148,7 @@ void tabulate_fusion_se_a_grad_grad_gpu(FPTYPE* dz_dy,
const FPTYPE* two_embed,
const FPTYPE* dz_dy_dem_x,
const FPTYPE* dz_dy_dem,
const FPTYPE* dz_dy_dtwo,
const int nloc,
const int nnei,
const int last_layer_size,
34 changes: 28 additions & 6 deletions source/lib/src/gpu/tabulate.cu
@@ -253,6 +253,7 @@ template <typename FPTYPE, int MTILE, int KTILE>
__global__ void tabulate_fusion_se_a_grad_fifth_order_polynomial(
FPTYPE* dy_dem_x,
FPTYPE* dy_dem,
FPTYPE* dy_dtwo,
const FPTYPE* table,
const FPTYPE* em_x,
const FPTYPE* em,
@@ -307,6 +308,7 @@ __global__ void tabulate_fusion_se_a_grad_fifth_order_polynomial(
(var[1] +
(var[2] + (var[3] + (var[4] + var[5] * xx) * xx) * xx) * xx) *
xx;
FPTYPE oldres = res;
FPTYPE t;
if (enable_se_atten) {
t = two_embed[block_idx * nnei * last_layer_size +
@@ -330,6 +332,13 @@ __global__ void tabulate_fusion_se_a_grad_fifth_order_polynomial(
xx) *
xx) *
(enable_se_atten ? res * t + res : res);
if (enable_se_atten) {
// from ii to ii + (nnei - breakpoint)
for (int ii2 = ii; ii2 < ii + nnei - breakpoint; ii2++) {
dy_dtwo[block_idx * nnei * last_layer_size + ii2 * last_layer_size +
jj] = oldres * res;
}
}
}
GpuSyncThreads();
for (int kk = 0; kk < MTILE; kk++) {
@@ -357,6 +366,7 @@ __global__ void tabulate_fusion_se_a_grad_grad_fifth_order_polynomial(
const FPTYPE* two_embed,
const FPTYPE* dz_dy_dem_x,
const FPTYPE* dz_dy_dem,
const FPTYPE* dz_dy_dtwo,
const FPTYPE lower,
const FPTYPE upper,
const FPTYPE max,
@@ -404,9 +414,15 @@ __global__ void tabulate_fusion_se_a_grad_grad_fifth_order_polynomial(
((FPTYPE)4. * var[4] + (FPTYPE)5. * var[5] * xx) * xx) *
xx) *
xx;
FPTYPE two_grad = 0.;
if (enable_se_atten) {
FPTYPE t = two_embed[block_idx * nnei * last_layer_size +
ii * last_layer_size + thread_idx];
// dz_dy_dtwo * res * em
// res above should be used instead of res + res * t below
two_grad = dz_dy_dtwo[block_idx * nnei * last_layer_size +
ii * last_layer_size + thread_idx] *
res;
res += res * t;
res_grad += res_grad * t;
}
@@ -434,8 +450,8 @@ __global__ void tabulate_fusion_se_a_grad_grad_fifth_order_polynomial(
for (int kk = 0; kk < MTILE; kk++) {
int em_index = block_idx * nnei * MTILE + ii * MTILE + kk;
iteratorC[kk * last_layer_size + thread_idx] +=
(nnei - breakpoint) *
(em[em_index] * res_grad * dz_xx + dz_dy_dem[em_index] * res);
(nnei - breakpoint) * (em[em_index] * (res_grad * dz_xx + two_grad) +
dz_dy_dem[em_index] * res);
}
mark_table_idx = table_idx;
if (unloop) {
@@ -764,6 +780,7 @@ void tabulate_fusion_se_a_gpu(FPTYPE* out,
template <typename FPTYPE>
void tabulate_fusion_se_a_grad_gpu(FPTYPE* dy_dem_x,
FPTYPE* dy_dem,
FPTYPE* dy_dtwo,
const FPTYPE* table,
const FPTYPE* table_info,
const FPTYPE* em_x,
@@ -784,9 +801,9 @@ void tabulate_fusion_se_a_grad_gpu(FPTYPE* dy_dem_x,

tabulate_fusion_se_a_grad_fifth_order_polynomial<FPTYPE, MM, KK>
<<<nloc, KK * WARP_SIZE, sizeof(FPTYPE) * MM * last_layer_size>>>(
dy_dem_x, dy_dem, table, em_x, em, two_embed, dy, table_info[0],
table_info[1], table_info[2], table_info[3], table_info[4], nnei,
last_layer_size, is_sorted);
dy_dem_x, dy_dem, dy_dtwo, table, em_x, em, two_embed, dy,
table_info[0], table_info[1], table_info[2], table_info[3],
table_info[4], nnei, last_layer_size, is_sorted);
DPErrcheck(gpuGetLastError());
DPErrcheck(gpuDeviceSynchronize());
}
@@ -800,6 +817,7 @@ void tabulate_fusion_se_a_grad_grad_gpu(FPTYPE* dz_dy,
const FPTYPE* two_embed,
const FPTYPE* dz_dy_dem_x,
const FPTYPE* dz_dy_dem,
const FPTYPE* dz_dy_dtwo,
const int nloc,
const int nnei,
const int last_layer_size,
@@ -812,7 +830,7 @@ void tabulate_fusion_se_a_grad_grad_gpu(FPTYPE* dz_dy,
DPErrcheck(gpuMemset(dz_dy, 0, sizeof(FPTYPE) * nloc * 4 * last_layer_size));
tabulate_fusion_se_a_grad_grad_fifth_order_polynomial<FPTYPE, MM, KK>
<<<nloc, last_layer_size, sizeof(FPTYPE) * MM * last_layer_size>>>(
dz_dy, table, em_x, em, two_embed, dz_dy_dem_x, dz_dy_dem,
dz_dy, table, em_x, em, two_embed, dz_dy_dem_x, dz_dy_dem, dz_dy_dtwo,
table_info[0], table_info[1], table_info[2], table_info[3],
table_info[4], nnei, last_layer_size, is_sorted);
DPErrcheck(gpuGetLastError());
Expand Down Expand Up @@ -990,6 +1008,7 @@ template void tabulate_fusion_se_a_gpu<double>(double* out,
const bool is_sorted);
template void tabulate_fusion_se_a_grad_gpu<float>(float* dy_dem_x,
float* dy_dem,
float* dy_dtwo,
const float* table,
const float* table_info,
const float* em_x,
@@ -1002,6 +1021,7 @@ template void tabulate_fusion_se_a_grad_gpu<float>(float* dy_dem_x,
const bool is_sorted);
template void tabulate_fusion_se_a_grad_gpu<double>(double* dy_dem_x,
double* dy_dem,
double* dy_dtwo,
const double* table,
const double* table_info,
const double* em_x,
@@ -1021,6 +1041,7 @@ template void tabulate_fusion_se_a_grad_grad_gpu<float>(
const float* two_embed,
const float* dz_dy_dem_x,
const float* dz_dy_dem,
const float* dz_dy_dtwo,
const int nloc,
const int nnei,
const int last_layer_size,
@@ -1034,6 +1055,7 @@ template void tabulate_fusion_se_a_grad_grad_gpu<double>(
const double* two_embed,
const double* dz_dy_dem_x,
const double* dz_dy_dem,
const double* dz_dy_dtwo,
const int nloc,
const int nnei,
const int last_layer_size,
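Taken together with the `_tabulate_grad.py` change above, these kernels now produce and consume a gradient with respect to the attention (`two_embed`) tensor (`dy_dtwo` / `dz_dy_dtwo`) instead of dropping it. My reading of the math, inferred from the `res += res * t` line in the grad-grad kernel and therefore only a sketch: with attention enabled, each tabulated value G is scaled to G' = G(1 + t), where t is the matching `two_embed` entry, so the backward pass additionally needs

```latex
% Inferred from the kernels above under the assumption G' = G(1+t); a sketch, not the commit's own derivation.
\frac{\partial G'}{\partial t} = G
\quad\Longrightarrow\quad
\frac{\partial L}{\partial t} = \frac{\partial L}{\partial G'}\cdot G
```

which is the quantity the new buffers carry back to the two-side embedding.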