[Kernel&Prim] Fix IndexPutCudaKernel for thread safe and add index_put_double_grad #69095

Merged
@@ -80,6 +80,7 @@
"log_double_grad",
"where_double_grad",
"bmm_double_grad",
"index_put_double_grad",
]

# white ops list whose kernel can automatically do type promotion.
@@ -33,4 +33,5 @@
    'abs_double_grad',
    'where_grad',
    'bmm_grad',
    'index_put_grad',
]
1 change: 1 addition & 0 deletions paddle/fluid/prim/api/api.yaml
@@ -50,3 +50,4 @@
- tanh
- sign
- sigmoid
- index_put
@@ -984,5 +984,70 @@ void bmm_double_grad(const Tensor& x,
  }
}

template <typename T>
void index_put_double_grad(const Tensor& x,
                           const std::vector<Tensor>& indices,
                           const Tensor& value,
                           const paddle::optional<Tensor>& grad_x_grad,
                           const paddle::optional<Tensor>& grad_value_grad,
                           const bool& accumulate,
                           Tensor* grad_out_grad) {
  if (grad_out_grad) {
    if (grad_x_grad && grad_value_grad) {
      /*
        ddout_{i,j} = {
          ddx_{i, j},            (i, j) \notin indices,
          ddv_{k},               (i, j) \in indices and accumulate is false.
          ddx_{i, j} + ddv_{k},  (i, j) \in indices and accumulate is true.
        }
      */
      Tensor grad_out_grad_tmp = grad_x_grad.get();
      grad_out_grad_tmp = index_put<T>(
          grad_out_grad_tmp, indices, grad_value_grad.get(), accumulate);
      set_output<T>(grad_out_grad_tmp, grad_out_grad);

    } else if (grad_x_grad) {
      /*
        ddout_{i,j} = {
          ddx_{i, j},  (i, j) \notin indices,
          0,           (i, j) \in indices and accumulate is false.
          ddx_{i, j},  (i, j) \in indices and accumulate is true.
        }
      */
      Tensor grad_out_grad_tmp = grad_x_grad.get();
      if (!accumulate) {
        auto zero_to_fill =
            full<T>(common::vectorize(value.dims()), 0, value.dtype());
        grad_out_grad_tmp =
            index_put<T>(grad_out_grad_tmp, indices, zero_to_fill, accumulate);
      }
      set_output<T>(grad_out_grad_tmp, grad_out_grad);

    } else if (grad_value_grad) {
      /*
        ddout_{i,j} = {
          0,         (i, j) \notin indices,
          ddv_{k},   (i, j) \in indices.
        }
      */
      Tensor grad_out_grad_tmp =
          full<T>(common::vectorize(x.dims()), 0, x.dtype());
      grad_out_grad_tmp = index_put<T>(grad_out_grad_tmp,
                                       indices,
                                       grad_value_grad.get(),
                                       /*accumulate*/ false);
      set_output<T>(grad_out_grad_tmp, grad_out_grad);

    } else {
      /*
        ddout_{i,j} = 0
      */
      Tensor grad_out_grad_tmp =
          full<T>(common::vectorize(x.dims()), 0, x.dtype());
      set_output<T>(grad_out_grad_tmp, grad_out_grad);
    }
  }
}

} // namespace prim
} // namespace paddle
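For intuition, the branches above assemble ddout from the forward tangents ddx (grad_x_grad) and ddv (grad_value_grad) using index_put itself. The same assembly can be sketched with the public paddle.index_put API; the shapes and tensor names below are made up for illustration, and this is a sketch of the rule, not the prim C++ code:

import paddle

indices = (paddle.to_tensor([0, 2, 3]),)   # rows of x being written
ddx = paddle.rand([4, 3])                  # stands in for grad_x_grad
ddv = paddle.rand([3, 3])                  # stands in for grad_value_grad

# both tangents given: ddout = index_put(ddx, indices, ddv, accumulate)
ddout_both = paddle.index_put(ddx, indices, ddv, accumulate=False)

# only ddx given and accumulate is false: overwrite the indexed rows with zeros
ddout_x_only = paddle.index_put(ddx, indices, paddle.zeros_like(ddv), accumulate=False)

# only ddv given: scatter ddv into a zero tensor shaped like x
ddout_v_only = paddle.index_put(paddle.zeros_like(ddx), indices, ddv, accumulate=False)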
3 changes: 2 additions & 1 deletion paddle/phi/kernels/gpu/index_put_kernel.cu
@@ -15,6 +15,7 @@
#include "paddle/phi/kernels/index_put_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/funcs/index_put_utils.h"
@@ -54,7 +55,7 @@ __global__ void IndexPutCudaKernel(const T* x,
    }

    if (accumulate) {
      *(out + offset) += *(vals + (idx & is_single_val_tensor));
      phi::CudaAtomicAdd(out + offset, *(vals + (idx & is_single_val_tensor)));
    } else {
      *(out + offset) = *(vals + (idx & is_single_val_tensor));
    }
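The one-line kernel change above replaces the plain += in the accumulate branch with phi::CudaAtomicAdd, because with accumulate=true and repeated indices several CUDA threads can update the same output element and a non-atomic read-modify-write may drop contributions. A small dygraph sketch of the semantics the atomic update protects; values and shapes are made up:

import paddle

x = paddle.zeros([4])
indices = (paddle.to_tensor([1, 1, 1, 3]),)   # index 1 is written by three threads
value = paddle.to_tensor([1.0, 2.0, 3.0, 4.0])

out = paddle.index_put(x, indices, value, accumulate=True)
# every contribution to position 1 must survive: expected out == [0., 6., 0., 4.]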
13 changes: 13 additions & 0 deletions paddle/phi/ops/yaml/backward.yaml
@@ -1553,6 +1553,18 @@
    data_type : out_grad
  inplace : (out_grad -> x_grad)

- backward_op : index_put_double_grad
  forward : index_put_grad (Tensor x, Tensor[] indices, Tensor value, Tensor out_grad, bool accumulate=false) -> Tensor(grad_x), Tensor(grad_value)
  args : (Tensor x, Tensor[] indices, Tensor value, Tensor grad_x_grad, Tensor grad_value_grad, bool accumulate=false)
  output : Tensor(out_grad_grad)
  infer_meta :
    func : UnchangedInferMeta
    param: [x]
  data_transform :
    skip_transform : indices
  composite : index_put_double_grad(x, indices, value, grad_x_grad, grad_value_grad, accumulate, out_grad_grad)
  optional: grad_x_grad, grad_value_grad

- backward_op : index_put_grad
  forward : index_put (Tensor x, Tensor[] indices, Tensor value, bool accumulate=false) -> Tensor(out)
  args : (Tensor x, Tensor[] indices, Tensor value, Tensor out_grad, bool accumulate=false)
@@ -1564,6 +1576,7 @@
    data_type : out_grad
  data_transform :
    skip_transform : indices
  backward : index_put_double_grad

- backward_op : index_sample_grad
  forward : index_sample (Tensor x, Tensor index) -> Tensor(out)
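With backward : index_put_double_grad registered, a second application of paddle.grad through index_put resolves to the composite rule added above. A minimal dygraph sketch of reaching that path (shapes are arbitrary, and running it assumes a Paddle build that contains this patch); the parameterized test below exercises the same path more thoroughly:

import paddle

x = paddle.randn([8, 4])
x.stop_gradient = False
value = paddle.randn([3, 4])
value.stop_gradient = False
indices = (paddle.to_tensor([0, 2, 5]),)

out = paddle.tanh(paddle.index_put(x, indices, value))  # tanh keeps the grads non-constant
(dx,) = paddle.grad(out, x, create_graph=True)           # first-order gradient
(ddx,) = paddle.grad(dx, x)                              # second-order gradient, hits index_put_double_grad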
102 changes: 102 additions & 0 deletions test/prim/prim/vjp/test_comp_high_grad.py
@@ -1110,5 +1110,107 @@ def test_high_grad(self):
                        self.func_triple(p, x_stop, y_stop)


@param.parameterized_class(
    ('x_shape', 'indices_shape', 'value_shape'),
    [
        ([16], [10], [10]),
        ([16, 16], [20, 2], [20]),
        ([12, 13, 14], [88, 1], [88, 13, 14]),
        ([12, 13, 14], [88, 2], [88, 14]),
        ([12, 13, 14], [88, 3], [88]),
    ],
)
class TestIndexPutHighGradCheck(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.x_shape = cls.x_shape
        cls.indices_shape = cls.indices_shape
        cls.value_shape = cls.value_shape

    def _grad(self, y, x, order):
        u = y
        dx = paddle.ones_like(x)
        for _ in range(order):
            dx = paddle.grad(u, x, create_graph=True)[0]
            u = dx
        return dx

    def func_double(self, place, x_stop, y_stop):
        x = paddle.randn(self.x_shape).astype("float32").to(device=place)
        x.stop_gradient = x_stop
        n_indices = self.indices_shape[0]
        index_dim_size = (
            self.indices_shape[1] if len(self.indices_shape) > 1 else 1
        )
        self.assertEqual(n_indices, self.value_shape[0])
        indices = tuple(
            [
                paddle.randint(0, self.x_shape[i], shape=[n_indices]).to(place)
                for i in range(max(index_dim_size, 1))
            ]
        )
        value = (
            paddle.randn(self.value_shape).astype("float32").to(device=place)
        )
        value.stop_gradient = y_stop

        z = paddle.index_put(x, indices, value)
        z = paddle.tanh(z)

        if not x.stop_gradient:
            dzdx = self._grad(z, x, 2)
        if not value.stop_gradient:
            dzdy = self._grad(z, value, 2)

    def func_triple(self, place, x_stop, y_stop):
        x = paddle.randn(self.x_shape).astype("float32").to(device=place)
        x.stop_gradient = x_stop
        n_indices = self.indices_shape[0]
        index_dim_size = (
            self.indices_shape[1] if len(self.indices_shape) > 1 else 1
        )
        self.assertEqual(n_indices, self.value_shape[0])
        indices = tuple(
            [
                paddle.randint(
                    0,
                    self.x_shape[i],
                    shape=[n_indices],
                ).to(place)
                for i in range(max(index_dim_size, 1))
            ]
        )
        value = (
            paddle.randn(self.value_shape).astype("float32").to(device=place)
        )
        value.stop_gradient = y_stop

        # wrapping with tanh to enable high-order gradients
        z = paddle.index_put(paddle.tanh(x), indices, paddle.tanh(value))
        z = paddle.tanh(z)

        if not x.stop_gradient:
            dzdx = self._grad(z, x, 3)
        if not value.stop_gradient:
            dzdy = self._grad(z, value, 3)

    def test_high_grad(self):
        places = []
        if (
            os.environ.get('FLAGS_CI_both_cpu_and_gpu', 'False').lower()
            in ['1', 'true', 'on']
            or not core.is_compiled_with_cuda()
        ):
            places.append(base.CPUPlace())
        if core.is_compiled_with_cuda():
            places.append(base.CUDAPlace(0))
        for p in places:
            for x_stop in [False, True]:
                for y_stop in [False, True]:
                    with dygraph_guard():
                        self.func_double(p, x_stop, y_stop)
                        self.func_triple(p, x_stop, y_stop)


if __name__ == '__main__':
unittest.main()