forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
GridSampler.cpp
82 lines (75 loc) · 3.11 KB
/
GridSampler.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/cuda/GridSampler.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/grid_sampler_2d_backward_native.h>
#include <ATen/ops/grid_sampler_2d_native.h>
#include <ATen/ops/grid_sampler_3d_backward_native.h>
#include <ATen/ops/grid_sampler_3d_native.h>
#include <ATen/ops/zeros_like.h>
#endif
namespace at::native {

namespace {

// Allocates the `input` gradient buffer shared by both backward entry points.
// Returns a zero-filled tensor like `input` (legacy-contiguous layout) when
// the input gradient is requested, and an undefined Tensor otherwise.
// Zero-filled rather than empty — presumably because the backward kernels
// accumulate into it. NOTE(review): confirm against the .cu launchers.
Tensor allocate_grid_sampler_grad_input(const Tensor& input,
                                        bool input_requires_grad) {
  return input_requires_grad
      ? at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT)
      : Tensor();
}

} // namespace

// CUDA forward implementation of grid_sampler_2d.
//
// Allocates the output with sizes
//   {input.size(0), input.size(1), grid.size(1), grid.size(2)}
// using input's options. Under the usual grid_sample layout this is
// (N, C, H_out, W_out) for a (N, C, H_in, W_in) input and a
// (N, H_out, W_out, 2) grid; only the 4-D rank is checked here. The sampling
// itself is done by launch_grid_sampler_2d_forward_kernel.
//
// @param input              4-D tensor to sample from.
// @param grid               4-D sampling grid.
// @param interpolation_mode forwarded unchanged to the kernel launcher.
// @param padding_mode       forwarded unchanged to the kernel launcher.
// @param align_corners      forwarded unchanged to the kernel launcher.
// @return the newly allocated and filled output tensor.
Tensor grid_sampler_2d_cuda(const Tensor& input, const Tensor& grid,
                            int64_t interpolation_mode, int64_t padding_mode,
                            bool align_corners) {
  // IntArrayRef::operator[] does not bounds-check, so validate ranks before
  // indexing sizes(); otherwise a direct call with lower-rank tensors reads
  // past the end of the size arrays before any later validation runs.
  // Wording intended to match ATen's check_grid_sampler_2d.
  // NOTE(review): keep in sync with GridSamplerUtils.h.
  TORCH_CHECK(
      input.dim() == 4 && input.dim() == grid.dim(),
      "grid_sampler(): expected 4D input and grid with same number of "
      "dimensions, but got input with sizes ", input.sizes(),
      " and grid with sizes ", grid.sizes());
  const auto in_size = input.sizes();
  const auto grid_size = grid.sizes();
  auto output = at::empty(
      {in_size[0], in_size[1], grid_size[1], grid_size[2]}, input.options());
  launch_grid_sampler_2d_forward_kernel(
      output, input, grid, interpolation_mode, padding_mode, align_corners);
  return output;
}

// CUDA forward implementation of grid_sampler_3d.
//
// Allocates the output with sizes
//   {input.size(0), input.size(1), grid.size(1), grid.size(2), grid.size(3)}
// using input's options. Under the usual grid_sample layout this is
// (N, C, D_out, H_out, W_out) for a (N, C, D_in, H_in, W_in) input and a
// (N, D_out, H_out, W_out, 3) grid; only the 5-D rank is checked here. The
// sampling itself is done by launch_grid_sampler_3d_forward_kernel.
//
// @param input              5-D tensor to sample from.
// @param grid               5-D sampling grid.
// @param interpolation_mode forwarded unchanged to the kernel launcher.
// @param padding_mode       forwarded unchanged to the kernel launcher.
// @param align_corners      forwarded unchanged to the kernel launcher.
// @return the newly allocated and filled output tensor.
Tensor grid_sampler_3d_cuda(const Tensor& input, const Tensor& grid,
                            int64_t interpolation_mode, int64_t padding_mode,
                            bool align_corners) {
  // Same rationale as the 2-D variant: grid_size[3] below is an unchecked
  // read, so the ranks must be validated first.
  // Wording intended to match ATen's check_grid_sampler_3d.
  // NOTE(review): keep in sync with GridSamplerUtils.h.
  TORCH_CHECK(
      input.dim() == 5 && input.dim() == grid.dim(),
      "grid_sampler(): expected 5D input and grid with same number of "
      "dimensions, but got input with sizes ", input.sizes(),
      " and grid with sizes ", grid.sizes());
  const auto in_size = input.sizes();
  const auto grid_size = grid.sizes();
  auto output = at::empty(
      {in_size[0], in_size[1], grid_size[1], grid_size[2], grid_size[3]},
      input.options());
  launch_grid_sampler_3d_forward_kernel(
      output, input, grid, interpolation_mode, padding_mode, align_corners);
  return output;
}

// CUDA backward implementation of grid_sampler_2d.
//
// @param grad_output        gradient w.r.t. the forward output.
// @param input, grid        the forward-pass inputs.
// @param interpolation_mode forwarded unchanged to the kernel launcher.
// @param padding_mode       forwarded unchanged to the kernel launcher.
// @param align_corners      forwarded unchanged to the kernel launcher.
// @param output_mask        output_mask[0] selects whether grad_input is
//                           allocated; the whole mask is also forwarded.
// @return (grad_input, grad_grid). grad_input is an undefined Tensor when
//         output_mask[0] is false. grad_grid is always allocated (like
//         `grid`, uninitialised) and is presumably fully written by the kernel.
std::tuple<Tensor, Tensor>
grid_sampler_2d_backward_cuda(const Tensor& grad_output, const Tensor& input,
                              const Tensor& grid, int64_t interpolation_mode, int64_t padding_mode,
                              bool align_corners, std::array<bool, 2> output_mask) {
  Tensor grad_input = allocate_grid_sampler_grad_input(
      input, /*input_requires_grad=*/output_mask[0]);
  auto grad_grid = at::empty_like(grid, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  launch_grid_sampler_2d_backward_kernel(
      grad_input, grad_grid, grad_output, input,
      grid, interpolation_mode, padding_mode, align_corners, output_mask);
  // Move the handles into the tuple to skip two refcount round-trips.
  return std::make_tuple(std::move(grad_input), std::move(grad_grid));
}

// CUDA backward implementation of grid_sampler_3d.
//
// @param grad_output        gradient w.r.t. the forward output.
// @param input, grid        the forward-pass inputs.
// @param interpolation_mode forwarded unchanged to the kernel launcher.
// @param padding_mode       forwarded unchanged to the kernel launcher.
// @param align_corners      forwarded unchanged to the kernel launcher.
// @param output_mask        output_mask[0] selects whether grad_input is
//                           allocated; the whole mask is also forwarded.
// @return (grad_input, grad_grid). grad_input is an undefined Tensor when
//         output_mask[0] is false. grad_grid is always allocated (like
//         `grid`, uninitialised) and is presumably fully written by the kernel.
std::tuple<Tensor, Tensor>
grid_sampler_3d_backward_cuda(const Tensor& grad_output, const Tensor& input,
                              const Tensor& grid, int64_t interpolation_mode, int64_t padding_mode,
                              bool align_corners, std::array<bool, 2> output_mask) {
  Tensor grad_input = allocate_grid_sampler_grad_input(
      input, /*input_requires_grad=*/output_mask[0]);
  auto grad_grid = at::empty_like(grid, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  launch_grid_sampler_3d_backward_kernel(
      grad_input, grad_grid, grad_output, input,
      grid, interpolation_mode, padding_mode, align_corners, output_mask);
  // Move the handles into the tuple to skip two refcount round-trips.
  return std::make_tuple(std::move(grad_input), std::move(grad_grid));
}

} // namespace at::native