diff --git a/deepmd/op/_tabulate_grad.py b/deepmd/op/_tabulate_grad.py
index e91aa5fd2f..9076ee3213 100644
--- a/deepmd/op/_tabulate_grad.py
+++ b/deepmd/op/_tabulate_grad.py
@@ -55,16 +55,17 @@ def _tabulate_fusion_se_atten_grad_cc(op, dy):
         op.outputs[0],
         is_sorted=op.get_attr("is_sorted"),
     )
-    return [None, None, dy_dx, dy_df, dy_dtwo]
+    return [None, None, dy_dx, dy_df, None]


 @ops.RegisterGradient("TabulateFusionSeAttenGrad")
 def _tabulate_fusion_se_atten_grad_grad_cc(op, dy, dy_, dy_dtwo):
-    dz_dy = op_module.tabulate_fusion_se_a_grad_grad(
+    dz_dy = op_module.tabulate_fusion_se_atten_grad_grad(
         op.inputs[0],
         op.inputs[1],
         op.inputs[2],
         op.inputs[3],
+        op.inputs[4],
         dy,
         dy_,
         op.inputs[6],
diff --git a/source/lib/include/tabulate.h b/source/lib/include/tabulate.h
index 76a46bbe6c..93992cea5b 100644
--- a/source/lib/include/tabulate.h
+++ b/source/lib/include/tabulate.h
@@ -35,6 +35,7 @@ void tabulate_fusion_se_a_grad_grad_cpu(FPTYPE* dz_dy,
                                         const FPTYPE* table_info,
                                         const FPTYPE* em_x,
                                         const FPTYPE* em,
+                                        const FPTYPE* two_embed,
                                         const FPTYPE* dz_dy_dem_x,
                                         const FPTYPE* dz_dy_dem,
                                         const int nloc,
@@ -141,6 +142,7 @@ void tabulate_fusion_se_a_grad_grad_gpu(FPTYPE* dz_dy,
                                         const FPTYPE* table_info,
                                         const FPTYPE* em_x,
                                         const FPTYPE* em,
+                                        const FPTYPE* two_embed,
                                         const FPTYPE* dz_dy_dem_x,
                                         const FPTYPE* dz_dy_dem,
                                         const int nloc,
diff --git a/source/lib/src/gpu/tabulate.cu b/source/lib/src/gpu/tabulate.cu
index f424006940..9f924efd9b 100644
--- a/source/lib/src/gpu/tabulate.cu
+++ b/source/lib/src/gpu/tabulate.cu
@@ -354,6 +354,7 @@ __global__ void tabulate_fusion_se_a_grad_grad_fifth_order_polynomial(
     const FPTYPE* table,
     const FPTYPE* em_x,
     const FPTYPE* em,
+    const FPTYPE* two_embed,
     const FPTYPE* dz_dy_dem_x,
     const FPTYPE* dz_dy_dem,
     const FPTYPE lower,
@@ -364,6 +365,7 @@ __global__ void tabulate_fusion_se_a_grad_grad_fifth_order_polynomial(
     const int nnei,
     const int last_layer_size,
     const bool is_sorted) {
+  bool enable_se_atten = two_embed != nullptr;
   GPU_DYNAMIC_SHARED_MEM_DECL(int, _data);
   const int_64 block_idx = blockIdx.x;   // nloc
   const int thread_idx = threadIdx.x;    // last_layer_size
@@ -402,6 +404,12 @@ __global__ void tabulate_fusion_se_a_grad_grad_fifth_order_polynomial(
              ((FPTYPE)4. * var[4] + (FPTYPE)5.
               * var[5] * xx) * xx) * xx) * xx;
+    if (enable_se_atten) {
+      FPTYPE t = two_embed[block_idx * nnei * last_layer_size +
+                           ii * last_layer_size + thread_idx];
+      res += res * t;
+      res_grad += res_grad * t;
+    }
     for (int kk = 0; kk < MTILE; kk++) {
       int em_index = block_idx * nnei * MTILE + ii * MTILE + kk;
@@ -769,6 +777,7 @@ void tabulate_fusion_se_a_grad_grad_gpu(FPTYPE* dz_dy,
                                         const FPTYPE* table_info,
                                         const FPTYPE* em_x,
                                         const FPTYPE* em,
+                                        const FPTYPE* two_embed,
                                         const FPTYPE* dz_dy_dem_x,
                                         const FPTYPE* dz_dy_dem,
                                         const int nloc,
@@ -783,9 +792,9 @@ void tabulate_fusion_se_a_grad_grad_gpu(FPTYPE* dz_dy,
   DPErrcheck(gpuMemset(dz_dy, 0, sizeof(FPTYPE) * nloc * 4 * last_layer_size));
   tabulate_fusion_se_a_grad_grad_fifth_order_polynomial<FPTYPE, MM, KK>
       <<<nloc, last_layer_size, sizeof(FPTYPE) * MM * last_layer_size>>>(
-          dz_dy, table, em_x, em, dz_dy_dem_x, dz_dy_dem, table_info[0],
-          table_info[1], table_info[2], table_info[3], table_info[4], nnei,
-          last_layer_size, is_sorted);
+          dz_dy, table, em_x, em, two_embed, dz_dy_dem_x, dz_dy_dem,
+          table_info[0], table_info[1], table_info[2], table_info[3],
+          table_info[4], nnei, last_layer_size, is_sorted);
   DPErrcheck(gpuGetLastError());
   DPErrcheck(gpuDeviceSynchronize());
 }
@@ -989,6 +998,7 @@ template void tabulate_fusion_se_a_grad_grad_gpu<float>(
     const float* table_info,
     const float* em_x,
     const float* em,
+    const float* two_embed,
     const float* dz_dy_dem_x,
     const float* dz_dy_dem,
     const int nloc,
@@ -1001,6 +1011,7 @@ template void tabulate_fusion_se_a_grad_grad_gpu<double>(
     const double* table_info,
     const double* em_x,
     const double* em,
+    const double* two_embed,
     const double* dz_dy_dem_x,
     const double* dz_dy_dem,
     const int nloc,
diff --git a/source/lib/src/tabulate.cc b/source/lib/src/tabulate.cc
index 1cafd36ee2..9b659269e0 100644
--- a/source/lib/src/tabulate.cc
+++ b/source/lib/src/tabulate.cc
@@ -247,12 +247,14 @@ void deepmd::tabulate_fusion_se_a_grad_grad_cpu(FPTYPE* dz_dy,
                                                 const FPTYPE* table_info,
                                                 const FPTYPE* em_x,
                                                 const FPTYPE* em,
+                                                const FPTYPE* two_embed,
                                                 const FPTYPE* dz_dy_dem_x,
                                                 const FPTYPE* dz_dy_dem,
                                                 const int nloc,
                                                 const int nnei,
                                                 const int last_layer_size,
                                                 const bool is_sorted) {
+  bool enable_se_atten = two_embed != nullptr;
   memset(dz_dy, 0, sizeof(FPTYPE) * nloc * 4 * last_layer_size);
   const FPTYPE lower = table_info[0];
   const FPTYPE upper = table_info[1];
@@ -298,6 +300,12 @@ void deepmd::tabulate_fusion_se_a_grad_grad_cpu(FPTYPE* dz_dy,
               ((FPTYPE)3. * a3 +
                ((FPTYPE)4. * a4 + (FPTYPE)5.
                * a5 * xx) * xx) * xx) * xx;
+      if (enable_se_atten) {
+        FPTYPE t = two_embed[ii * nnei * last_layer_size +
+                             jj * last_layer_size + kk];
+        var += var * t;
+        var_grad += var_grad * t;
+      }
       if (unloop) {
         dz_dy[ii * last_layer_size * 4 + 0 * last_layer_size + kk] +=
             (nnei - jj) * (var * hh[0] + dz_xx * var_grad * ll[0]);
@@ -660,6 +668,7 @@ template void deepmd::tabulate_fusion_se_a_grad_grad_cpu<float>(
     const float* table_info,
     const float* em_x,
     const float* em,
+    const float* two_embed,
     const float* dz_dy_dem_x,
     const float* dz_dy_dem,
     const int nloc,
@@ -672,6 +681,7 @@ template void deepmd::tabulate_fusion_se_a_grad_grad_cpu<double>(
     const double* table_info,
     const double* em_x,
     const double* em,
+    const double* two_embed,
     const double* dz_dy_dem_x,
     const double* dz_dy_dem,
     const int nloc,
diff --git a/source/op/tabulate_multi_device.cc b/source/op/tabulate_multi_device.cc
index 85ea82803a..488a99bd7d 100644
--- a/source/op/tabulate_multi_device.cc
+++ b/source/op/tabulate_multi_device.cc
@@ -91,6 +91,19 @@ REGISTER_OP("TabulateFusionSeAttenGrad")
     .Output("dy_dtwo: T")
     .Attr("is_sorted: bool = true");

+REGISTER_OP("TabulateFusionSeAttenGradGrad")
+    .Attr("T: {float, double}")
+    .Input("table: T")
+    .Input("table_info: T")
+    .Input("em_x: T")
+    .Input("em: T")
+    .Input("two_embed: T")
+    .Input("dz_dy_dem_x: T")
+    .Input("dz_dy_dem: T")
+    .Input("descriptor: T")
+    .Output("dz_dy: T")
+    .Attr("is_sorted: bool = true");
+
 REGISTER_OP("TabulateFusionSeT")
     .Attr("T: {float, double} = DT_DOUBLE")
     .Input("table: T")
@@ -312,6 +325,7 @@ class TabulateFusionSeAGradGradOp : public OpKernel {
     const FPTYPE* table_info = table_info_tensor.flat<FPTYPE>().data();
     const FPTYPE* em_x = em_x_tensor.flat<FPTYPE>().data();
     const FPTYPE* em = em_tensor.flat<FPTYPE>().data();
+    const FPTYPE* two_embed = nullptr;
     const FPTYPE* dz_dy_dem_x = dz_dy_dem_x_tensor.flat<FPTYPE>().data();
     const FPTYPE* dz_dy_dem = dz_dy_dem_tensor.flat<FPTYPE>().data();
     const int nloc = em_tensor.shape().dim_size(0);
@@ -321,8 +335,8 @@ class TabulateFusionSeAGradGradOp : public OpKernel {
     if (device == "GPU") {
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
       deepmd::tabulate_fusion_se_a_grad_grad_gpu(
-          dz_dy, table, table_info, em_x, em, dz_dy_dem_x, dz_dy_dem, nloc,
-          nnei, last_layer_size, is_sorted);
+          dz_dy, table, table_info, em_x, em, two_embed, dz_dy_dem_x, dz_dy_dem,
+          nloc, nnei, last_layer_size, is_sorted);
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
       OP_REQUIRES(context, (last_layer_size <= 1024),
                   errors::InvalidArgument(
                       "In the process of model compression, the size of the "
                       "last layer of embedding net must be less than 1024!"));
     } else if (device == "CPU") {
       deepmd::tabulate_fusion_se_a_grad_grad_cpu(
-          dz_dy, table, table_info, em_x, em, dz_dy_dem_x, dz_dy_dem, nloc,
-          nnei, last_layer_size, is_sorted);
+          dz_dy, table, table_info, em_x, em, two_embed, dz_dy_dem_x, dz_dy_dem,
+          nloc, nnei, last_layer_size, is_sorted);
     }
   }

@@ -484,6 +498,76 @@ class TabulateFusionSeAttenGradOp : public OpKernel {
   std::string device;
 };

+template <typename Device, typename FPTYPE>
+class TabulateFusionSeAttenGradGradOp : public OpKernel {
+ public:
+  explicit TabulateFusionSeAttenGradGradOp(OpKernelConstruction* context)
+      : OpKernel(context) {
+    OP_REQUIRES_OK(context, context->GetAttr("is_sorted", &is_sorted));
+  }
+  void Compute(OpKernelContext* context) override {
+    deepmd::safe_compute(
+        context, [this](OpKernelContext* context) { this->_Compute(context); });
+  }
+
+  void _Compute(OpKernelContext* context) {
+    // Grab the input tensor
+    int context_input_index = 0;
+    const Tensor& table_tensor =
+        context->input(context_input_index++);
+    const Tensor& table_info_tensor = context->input(context_input_index++);
+    const Tensor& em_x_tensor = context->input(context_input_index++);
+    const Tensor& em_tensor = context->input(context_input_index++);
+    const Tensor& two_embed_tensor = context->input(context_input_index++);
+    const Tensor& dz_dy_dem_x_tensor = context->input(context_input_index++);
+    const Tensor& dz_dy_dem_tensor = context->input(context_input_index++);
+    const Tensor& descriptor_tensor = context->input(context_input_index++);
+    // set size of the sample
+    OP_REQUIRES(context, (dz_dy_dem_x_tensor.shape().dims() == 2),
+                errors::InvalidArgument("Dim of input should be 2"));
+    OP_REQUIRES(context, (dz_dy_dem_tensor.shape().dims() == 3),
+                errors::InvalidArgument("Dim of input should be 3"));
+    int context_output_index = 0;
+    Tensor* dz_dy_tensor = NULL;
+    OP_REQUIRES_OK(context, context->allocate_output(context_output_index++,
+                                                     descriptor_tensor.shape(),
+                                                     &dz_dy_tensor));
+    DeviceFunctor()(device, context->eigen_device<Device>());
+
+    // flat the tensors
+    FPTYPE* dz_dy = dz_dy_tensor->flat<FPTYPE>().data();
+    const FPTYPE* table = table_tensor.flat<FPTYPE>().data();
+    const FPTYPE* table_info = table_info_tensor.flat<FPTYPE>().data();
+    const FPTYPE* em_x = em_x_tensor.flat<FPTYPE>().data();
+    const FPTYPE* em = em_tensor.flat<FPTYPE>().data();
+    const FPTYPE* two_embed = two_embed_tensor.flat<FPTYPE>().data();
+    const FPTYPE* dz_dy_dem_x = dz_dy_dem_x_tensor.flat<FPTYPE>().data();
+    const FPTYPE* dz_dy_dem = dz_dy_dem_tensor.flat<FPTYPE>().data();
+    const int nloc = em_tensor.shape().dim_size(0);
+    const int nnei = em_tensor.shape().dim_size(1);
+    const int last_layer_size = descriptor_tensor.shape().dim_size(2);
+
+    if (device == "GPU") {
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+      deepmd::tabulate_fusion_se_a_grad_grad_gpu(
+          dz_dy, table, table_info, em_x, em, two_embed, dz_dy_dem_x, dz_dy_dem,
+          nloc, nnei, last_layer_size, is_sorted);
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+      OP_REQUIRES(context, (last_layer_size <= 1024),
+                  errors::InvalidArgument(
+                      "In the process of model compression, the size of the "
+                      "last layer of embedding net must be less than 1024!"));
+    } else if (device == "CPU") {
+      deepmd::tabulate_fusion_se_a_grad_grad_cpu(
+          dz_dy, table, table_info, em_x, em, two_embed, dz_dy_dem_x, dz_dy_dem,
+          nloc, nnei, last_layer_size, is_sorted);
+    }
+  }
+
+ private:
+  bool is_sorted;
+  std::string device;
+};
+
 template <typename Device, typename FPTYPE>
 class TabulateFusionSeTOp : public OpKernel {
  public:
@@ -863,6 +947,10 @@ class TabulateFusionSeRGradGradOp : public OpKernel {
                               .Device(DEVICE_CPU) \
                               .TypeConstraint<T>("T"), \
                           TabulateFusionSeAttenGradOp<CPUDevice, T>); \
+  REGISTER_KERNEL_BUILDER(Name("TabulateFusionSeAttenGradGrad") \
+                              .Device(DEVICE_CPU) \
+                              .TypeConstraint<T>("T"), \
+                          TabulateFusionSeAttenGradGradOp<CPUDevice, T>); \
   REGISTER_KERNEL_BUILDER( \
       Name("TabulateFusionSeT").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
       TabulateFusionSeTOp<CPUDevice, T>); \
@@ -887,76 +975,81 @@
 REGISTER_CPU(float);
 REGISTER_CPU(double);

 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-#define REGISTER_GPU(T) \
-  REGISTER_KERNEL_BUILDER(Name("TabulateFusion") \
-                              .Device(DEVICE_GPU) \
-                              .TypeConstraint<T>("T") \
-                              .HostMemory("table_info"), \
-                          TabulateFusionSeAOp<GPUDevice, T>); \
-  REGISTER_KERNEL_BUILDER(Name("TabulateFusionGrad") \
-                              .Device(DEVICE_GPU) \
-                              .TypeConstraint<T>("T") \
-                              .HostMemory("table_info"), \
-                          TabulateFusionSeAGradOp<GPUDevice, T>); \
-  REGISTER_KERNEL_BUILDER(Name("TabulateFusionGradGrad") \
-                              .Device(DEVICE_GPU) \
-                              .TypeConstraint<T>("T") \
-                              .HostMemory("table_info"), \
-                          TabulateFusionSeAGradGradOp<GPUDevice, T>); \
-  REGISTER_KERNEL_BUILDER(Name("TabulateFusionSeA") \
-                              .Device(DEVICE_GPU) \
-                              .TypeConstraint<T>("T") \
-                              .HostMemory("table_info"), \
-                          TabulateFusionSeAOp<GPUDevice, T>); \
-  REGISTER_KERNEL_BUILDER(Name("TabulateFusionSeAGrad") \
-                              .Device(DEVICE_GPU) \
-                              .TypeConstraint<T>("T") \
-                              .HostMemory("table_info"), \
-                          TabulateFusionSeAGradOp<GPUDevice, T>); \
-  REGISTER_KERNEL_BUILDER(Name("TabulateFusionSeAGradGrad") \
-                              .Device(DEVICE_GPU) \
-                              .TypeConstraint<T>("T") \
-                              .HostMemory("table_info"), \
-                          TabulateFusionSeAGradGradOp<GPUDevice, T>); \
-  REGISTER_KERNEL_BUILDER(Name("TabulateFusionSeAtten") \
-                              .Device(DEVICE_GPU) \
-                              .TypeConstraint<T>("T") \
-                              .HostMemory("table_info"), \
-                          TabulateFusionSeAttenOp<GPUDevice, T>); \
-  REGISTER_KERNEL_BUILDER(Name("TabulateFusionSeAttenGrad") \
-                              .Device(DEVICE_GPU) \
-                              .TypeConstraint<T>("T") \
-                              .HostMemory("table_info"), \
-                          TabulateFusionSeAttenGradOp<GPUDevice, T>); \
-  REGISTER_KERNEL_BUILDER(Name("TabulateFusionSeT") \
-                              .Device(DEVICE_GPU) \
-                              .TypeConstraint<T>("T") \
-                              .HostMemory("table_info"), \
-                          TabulateFusionSeTOp<GPUDevice, T>); \
-  REGISTER_KERNEL_BUILDER(Name("TabulateFusionSeTGrad") \
-                              .Device(DEVICE_GPU) \
-                              .TypeConstraint<T>("T") \
-                              .HostMemory("table_info"), \
-                          TabulateFusionSeTGradOp<GPUDevice, T>); \
-  REGISTER_KERNEL_BUILDER(Name("TabulateFusionSeTGradGrad") \
-                              .Device(DEVICE_GPU) \
-                              .TypeConstraint<T>("T") \
-                              .HostMemory("table_info"), \
-                          TabulateFusionSeTGradGradOp<GPUDevice, T>); \
-  REGISTER_KERNEL_BUILDER(Name("TabulateFusionSeR") \
-                              .Device(DEVICE_GPU) \
-                              .TypeConstraint<T>("T") \
-                              .HostMemory("table_info"), \
-                          TabulateFusionSeROp<GPUDevice, T>); \
-  REGISTER_KERNEL_BUILDER(Name("TabulateFusionSeRGrad") \
-                              .Device(DEVICE_GPU) \
-                              .TypeConstraint<T>("T") \
-                              .HostMemory("table_info"), \
-                          TabulateFusionSeRGradOp<GPUDevice, T>); \
-  REGISTER_KERNEL_BUILDER(Name("TabulateFusionSeRGradGrad") \
-                              .Device(DEVICE_GPU) \
-                              .TypeConstraint<T>("T") \
-                              .HostMemory("table_info"), \
+#define REGISTER_GPU(T) \
+  REGISTER_KERNEL_BUILDER(Name("TabulateFusion") \
+                              .Device(DEVICE_GPU) \
+                              .TypeConstraint<T>("T") \
+                              .HostMemory("table_info"), \
+                          TabulateFusionSeAOp<GPUDevice, T>); \
+  REGISTER_KERNEL_BUILDER(Name("TabulateFusionGrad") \
+                              .Device(DEVICE_GPU) \
+                              .TypeConstraint<T>("T") \
+                              .HostMemory("table_info"), \
+                          TabulateFusionSeAGradOp<GPUDevice, T>); \
+  REGISTER_KERNEL_BUILDER(Name("TabulateFusionGradGrad") \
+                              .Device(DEVICE_GPU) \
+                              .TypeConstraint<T>("T") \
+                              .HostMemory("table_info"), \
+                          TabulateFusionSeAGradGradOp<GPUDevice, T>); \
+  REGISTER_KERNEL_BUILDER(Name("TabulateFusionSeA") \
+                              .Device(DEVICE_GPU) \
+                              .TypeConstraint<T>("T") \
+                              .HostMemory("table_info"), \
+                          TabulateFusionSeAOp<GPUDevice, T>); \
+  REGISTER_KERNEL_BUILDER(Name("TabulateFusionSeAGrad") \
+                              .Device(DEVICE_GPU) \
+                              .TypeConstraint<T>("T") \
+                              .HostMemory("table_info"), \
+                          TabulateFusionSeAGradOp<GPUDevice, T>); \
+  REGISTER_KERNEL_BUILDER(Name("TabulateFusionSeAGradGrad") \
+                              .Device(DEVICE_GPU) \
+                              .TypeConstraint<T>("T") \
+                              .HostMemory("table_info"), \
+                          TabulateFusionSeAGradGradOp<GPUDevice, T>); \
+  REGISTER_KERNEL_BUILDER(Name("TabulateFusionSeAtten") \
+                              .Device(DEVICE_GPU) \
+                              .TypeConstraint<T>("T") \
+                              .HostMemory("table_info"), \
+                          TabulateFusionSeAttenOp<GPUDevice, T>); \
+  REGISTER_KERNEL_BUILDER(Name("TabulateFusionSeAttenGrad") \
+                              .Device(DEVICE_GPU) \
+                              .TypeConstraint<T>("T") \
+                              .HostMemory("table_info"), \
+                          TabulateFusionSeAttenGradOp<GPUDevice, T>); \
+  REGISTER_KERNEL_BUILDER(Name("TabulateFusionSeAttenGradGrad") \
+                              .Device(DEVICE_GPU) \
+                              .TypeConstraint<T>("T") \
+                              .HostMemory("table_info"), \
+                          TabulateFusionSeAttenGradGradOp<GPUDevice, T>); \
+  REGISTER_KERNEL_BUILDER(Name("TabulateFusionSeT") \
+                              .Device(DEVICE_GPU) \
+                              .TypeConstraint<T>("T") \
+                              .HostMemory("table_info"), \
+                          TabulateFusionSeTOp<GPUDevice, T>); \
+  REGISTER_KERNEL_BUILDER(Name("TabulateFusionSeTGrad") \
+                              .Device(DEVICE_GPU) \
+                              .TypeConstraint<T>("T") \
+                              .HostMemory("table_info"), \
+                          TabulateFusionSeTGradOp<GPUDevice, T>); \
+  REGISTER_KERNEL_BUILDER(Name("TabulateFusionSeTGradGrad") \
+                              .Device(DEVICE_GPU) \
+                              .TypeConstraint<T>("T") \
+                              .HostMemory("table_info"), \
+                          TabulateFusionSeTGradGradOp<GPUDevice, T>); \
+  REGISTER_KERNEL_BUILDER(Name("TabulateFusionSeR") \
+                              .Device(DEVICE_GPU) \
+                              .TypeConstraint<T>("T") \
+                              .HostMemory("table_info"), \
+                          TabulateFusionSeROp<GPUDevice, T>); \
+  REGISTER_KERNEL_BUILDER(Name("TabulateFusionSeRGrad") \
+                              .Device(DEVICE_GPU) \
+                              .TypeConstraint<T>("T") \
+                              .HostMemory("table_info"), \
+                          TabulateFusionSeRGradOp<GPUDevice, T>); \
+  REGISTER_KERNEL_BUILDER(Name("TabulateFusionSeRGradGrad") \
+                              .Device(DEVICE_GPU) \
+                              .TypeConstraint<T>("T") \
+                              .HostMemory("table_info"), \
                           TabulateFusionSeRGradGradOp<GPUDevice, T>);
 REGISTER_GPU(float);
 REGISTER_GPU(double);