diff --git a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py
index b1aa7d94c815cd..00538dc58c575d 100644
--- a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py
+++ b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py
@@ -80,6 +80,7 @@
     "log_double_grad",
     "where_double_grad",
     "bmm_double_grad",
+    "index_put_double_grad",
 ]
 
 # white ops list whose kernel can automatically do type promotion.
diff --git a/paddle/fluid/pir/dialect/op_generator/vjp_interface_black_list.py b/paddle/fluid/pir/dialect/op_generator/vjp_interface_black_list.py
index 0281e2c574f98b..c06dc36f038459 100644
--- a/paddle/fluid/pir/dialect/op_generator/vjp_interface_black_list.py
+++ b/paddle/fluid/pir/dialect/op_generator/vjp_interface_black_list.py
@@ -33,4 +33,5 @@
     'abs_double_grad',
     'where_grad',
     'bmm_grad',
+    'index_put_grad',
 ]
diff --git a/paddle/fluid/prim/api/api.yaml b/paddle/fluid/prim/api/api.yaml
index 33401c213366d1..38756f461e6222 100644
--- a/paddle/fluid/prim/api/api.yaml
+++ b/paddle/fluid/prim/api/api.yaml
@@ -50,3 +50,4 @@
 - tanh
 - sign
 - sigmoid
+- index_put
diff --git a/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h
index d40d637097fe71..746fe2b9a89c4b 100644
--- a/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h
+++ b/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h
@@ -984,5 +984,79 @@ void bmm_double_grad(const Tensor& x,
   }
 }
 
+// Composite backward rule of index_put_grad (double grad of index_put).
+//
+// index_put_grad is linear in out_grad, so grad_out_grad is assembled from
+// grad_x_grad (ddx) and grad_value_grad (ddv) by re-applying index_put.
+// Both inputs are optional; a missing one is treated as an all-zero tensor.
+// Nothing is computed when grad_out_grad is null.
+template <typename T>
+void index_put_double_grad(const Tensor& x,
+                           const std::vector<Tensor>& indices,
+                           const Tensor& value,
+                           const paddle::optional<Tensor>& grad_x_grad,
+                           const paddle::optional<Tensor>& grad_value_grad,
+                           const bool& accumulate,
+                           Tensor* grad_out_grad) {
+  if (grad_out_grad) {
+    if (grad_x_grad && grad_value_grad) {
+      /*
+        ddout_{i,j} = {
+          ddx_{i, j},           (i, j) \notin indices,
+          ddv_{k},              (i, j) \in indices and accumulate is false.
+          ddx_{i, j} + ddv_{k}, (i, j) \in indices and accumulate is true.
+        }
+      */
+      Tensor grad_out_grad_tmp = grad_x_grad.get();
+      grad_out_grad_tmp = index_put<T>(
+          grad_out_grad_tmp, indices, grad_value_grad.get(), accumulate);
+      set_output<T>(grad_out_grad_tmp, grad_out_grad);
+
+    } else if (grad_x_grad) {
+      /*
+        ddout_{i,j} = {
+          ddx_{i, j}, (i, j) \notin indices,
+          0,          (i, j) \in indices and accumulate is false.
+          ddx_{i, j}, (i, j) \in indices and accumulate is true.
+        }
+      */
+      Tensor grad_out_grad_tmp = grad_x_grad.get();
+      if (!accumulate) {
+        auto zero_to_fill =
+            full<T>(common::vectorize(value.dims()), 0, value.dtype());
+        grad_out_grad_tmp =
+            index_put<T>(grad_out_grad_tmp, indices, zero_to_fill, accumulate);
+      }
+      set_output<T>(grad_out_grad_tmp, grad_out_grad);
+
+    } else if (grad_value_grad) {
+      /*
+        ddout_{i,j} = {
+          0,              (i, j) \notin indices,
+          ddv_{k},        (i, j) \in indices and accumulate is false.
+          \sum_k ddv_{k}, (i, j) \in indices and accumulate is true.
+        }
+      */
+      // Forward `accumulate` (as the two-input branch does) so that
+      // repeated indices sum their ddv contributions on a zero base.
+      Tensor grad_out_grad_tmp =
+          full<T>(common::vectorize(x.dims()), 0, x.dtype());
+      grad_out_grad_tmp = index_put<T>(grad_out_grad_tmp,
+                                       indices,
+                                       grad_value_grad.get(),
+                                       accumulate);
+      set_output<T>(grad_out_grad_tmp, grad_out_grad);
+
+    } else {
+      /*
+        ddout_{i,j} = 0
+      */
+      Tensor grad_out_grad_tmp =
+          full<T>(common::vectorize(x.dims()), 0, x.dtype());
+      set_output<T>(grad_out_grad_tmp, grad_out_grad);
+    }
+  }
+}
+
 }  // namespace prim
 }  // namespace paddle
diff --git a/paddle/phi/kernels/gpu/index_put_kernel.cu b/paddle/phi/kernels/gpu/index_put_kernel.cu
index ea7440fabfce47..e92c0244b4eded 100644
--- a/paddle/phi/kernels/gpu/index_put_kernel.cu
+++ b/paddle/phi/kernels/gpu/index_put_kernel.cu
@@ -15,6 +15,7 @@
 #include "paddle/phi/kernels/index_put_kernel.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
+#include "paddle/phi/backends/gpu/gpu_primitives.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/cast_kernel.h"
 #include "paddle/phi/kernels/funcs/index_put_utils.h"
@@ -54,7 +55,7 @@ __global__ void IndexPutCudaKernel(const T* x,
   }
 
   if (accumulate) {
-    *(out + offset) += *(vals + (idx & is_single_val_tensor));
+    phi::CudaAtomicAdd(out + offset, *(vals + (idx & is_single_val_tensor)));
   } else {
     *(out + offset) = *(vals + (idx & is_single_val_tensor));
   }
diff --git a/paddle/phi/ops/yaml/backward.yaml b/paddle/phi/ops/yaml/backward.yaml
index d4c6db2f59528a..741139f6ed55b2 100644
--- a/paddle/phi/ops/yaml/backward.yaml
+++ b/paddle/phi/ops/yaml/backward.yaml
@@ -1553,6 +1553,18 @@
     data_type : out_grad
   inplace : (out_grad -> x_grad)
 
+- backward_op : index_put_double_grad
+  forward : index_put_grad (Tensor x, Tensor[] indices, Tensor value, Tensor out_grad, bool accumulate=false) -> Tensor(grad_x), Tensor(grad_value)
+  args : (Tensor x, Tensor[] indices, Tensor value, Tensor grad_x_grad, Tensor grad_value_grad, bool accumulate=false)
+  output : Tensor(out_grad_grad)
+  infer_meta :
+    func : UnchangedInferMeta
+    param: [x]
+  data_transform :
+    skip_transform : indices
+  composite : index_put_double_grad(x, indices, value, grad_x_grad, grad_value_grad, accumulate, out_grad_grad)
+  optional: grad_x_grad, grad_value_grad
+
 - backward_op : index_put_grad
   forward : index_put (Tensor x, Tensor[] indices, Tensor value, bool accumulate=false) -> Tensor(out)
   args : (Tensor x, Tensor[] indices, Tensor value, Tensor out_grad, bool accumulate=false)
@@ -1564,6 +1576,7 @@
     data_type : out_grad
   data_transform :
     skip_transform : indices
+  backward : index_put_double_grad
 
 - backward_op : index_sample_grad
   forward : index_sample (Tensor x, Tensor index) -> Tensor(out)
diff --git a/test/prim/prim/vjp/test_comp_high_grad.py b/test/prim/prim/vjp/test_comp_high_grad.py
index 5ee3413817d4af..db1393e0a1590a 100644
--- a/test/prim/prim/vjp/test_comp_high_grad.py
+++ b/test/prim/prim/vjp/test_comp_high_grad.py
@@ -1110,5 +1110,104 @@ def test_high_grad(self):
                         self.func_triple(p, x_stop, y_stop)
 
 
+@param.parameterized_class(
+    ('x_shape', 'indices_shape', 'value_shape'),
+    [
+        ([16], [10], [10]),
+        ([16, 16], [20, 2], [20]),
+        ([12, 13, 14], [88, 1], [88, 13, 14]),
+        ([12, 13, 14], [88, 2], [88, 14]),
+        ([12, 13, 14], [88, 3], [88]),
+    ],
+)
+class TestIndexPutHighGradCheck(unittest.TestCase):
+    """Smoke-test double and triple backward of ``paddle.index_put``.
+
+    ``indices_shape`` is ``[n_indices, n_indexed_dims]``; a missing second
+    entry means one indexed dim. The checks only require the high order
+    graphs to build and run; no numerical reference is compared.
+    """
+
+    @classmethod
+    def setUpClass(cls):
+        cls.x_shape = cls.x_shape
+        cls.indices_shape = cls.indices_shape
+        cls.value_shape = cls.value_shape
+
+    def _grad(self, y, x, order):
+        """Differentiate ``y`` by ``x`` ``order`` times, keeping the graph."""
+        u = y
+        dx = paddle.ones_like(x)
+        for _ in range(order):
+            dx = paddle.grad(u, x, create_graph=True)[0]
+            u = dx
+        return dx
+
+    def _make_inputs(self, place, x_stop, y_stop):
+        """Build random ``x``, one index tensor per indexed dim, ``value``."""
+        x = paddle.randn(self.x_shape).astype("float32").to(device=place)
+        x.stop_gradient = x_stop
+        n_indices = self.indices_shape[0]
+        index_dim_size = (
+            self.indices_shape[1] if len(self.indices_shape) > 1 else 1
+        )
+        self.assertEqual(n_indices, self.value_shape[0])
+        # Indices are drawn independently, so repeated positions can occur.
+        indices = tuple(
+            paddle.randint(0, self.x_shape[i], shape=[n_indices]).to(place)
+            for i in range(max(index_dim_size, 1))
+        )
+        value = (
+            paddle.randn(self.value_shape).astype("float32").to(device=place)
+        )
+        value.stop_gradient = y_stop
+        return x, indices, value
+
+    def func_double(self, place, x_stop, y_stop, accumulate=False):
+        """Run second order backward through ``tanh(index_put(...))``."""
+        x, indices, value = self._make_inputs(place, x_stop, y_stop)
+
+        z = paddle.index_put(x, indices, value, accumulate)
+        z = paddle.tanh(z)
+
+        if not x.stop_gradient:
+            self._grad(z, x, 2)
+        if not value.stop_gradient:
+            self._grad(z, value, 2)
+
+    def func_triple(self, place, x_stop, y_stop, accumulate=False):
+        """Run third order backward through ``tanh(index_put(...))``."""
+        x, indices, value = self._make_inputs(place, x_stop, y_stop)
+
+        # wrapping with tanh to enable high order gradient
+        z = paddle.index_put(
+            paddle.tanh(x), indices, paddle.tanh(value), accumulate
+        )
+        z = paddle.tanh(z)
+
+        if not x.stop_gradient:
+            self._grad(z, x, 3)
+        if not value.stop_gradient:
+            self._grad(z, value, 3)
+
+    def test_high_grad(self):
+        places = []
+        if (
+            os.environ.get('FLAGS_CI_both_cpu_and_gpu', 'False').lower()
+            in ['1', 'true', 'on']
+            or not core.is_compiled_with_cuda()
+        ):
+            places.append(base.CPUPlace())
+        if core.is_compiled_with_cuda():
+            places.append(base.CUDAPlace(0))
+        for p in places:
+            for x_stop in [False, True]:
+                for y_stop in [False, True]:
+                    for accumulate in [False, True]:
+                        with dygraph_guard():
+                            self.func_double(p, x_stop, y_stop, accumulate)
+                            self.func_triple(p, x_stop, y_stop, accumulate)
+
+
 if __name__ == '__main__':
     unittest.main()