Skip to content

Commit

Permalink
Add aten::grid_sampler_3d, aten::grid_sample_3d_backward (#898)
Browse files Browse the repository at this point in the history
- grid_sampler_3d
- grid_sample_3d_backward

---------

Co-authored-by: Yutao Xu <[email protected]>
  • Loading branch information
hjhee and xytintel authored Oct 29, 2024
1 parent 2339e13 commit 2d43f11
Show file tree
Hide file tree
Showing 8 changed files with 1,065 additions and 5 deletions.
41 changes: 41 additions & 0 deletions src/ATen/native/xpu/GridSampler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,46 @@ std::tuple<Tensor, Tensor> grid_sampler_2d_backward_xpu(
output_mask);
return std::make_tuple(grad_input, grad_grid);
}

Tensor grid_sampler_3d_xpu(
const Tensor& input,
const Tensor& grid,
int64_t interpolation_mode,
int64_t padding_mode,
bool align_corners) {
return xpu::grid_sampler_3d_kernel(
input, grid, interpolation_mode, padding_mode, align_corners);
}

std::tuple<Tensor, Tensor> grid_sampler_3d_backward_xpu(
const Tensor& grad_output,
const Tensor& input,
const Tensor& grid,
int64_t interpolation_mode,
int64_t padding_mode,
bool align_corners,
std::array<bool, 2> output_mask) {
auto input_requires_grad = output_mask[0];
Tensor grad_input = ([&]() {
if (input_requires_grad) {
return at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
} else {
return Tensor();
}
})();
auto grad_grid = at::empty_like(grid, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
xpu::grid_sampler_3d_backward_kernel(
grad_input,
grad_grid,
grad_output,
input,
grid,
interpolation_mode,
padding_mode,
align_corners,
output_mask);
return std::make_tuple(grad_input, grad_grid);
}

} // namespace native
} // namespace at
Loading

0 comments on commit 2d43f11

Please sign in to comment.