Skip to content

Commit

Permalink
MPS: Add adaptive max pool2d op (pytorch#78410)
Browse files Browse the repository at this point in the history
Adaptive max pool 2d forward and backward with test

Pull Request resolved: pytorch#78410
Approved by: https://github.com/albanD
  • Loading branch information
kulinseth authored and pytorchmergebot committed May 27, 2022
1 parent 8ad305f commit 2e32d5f
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
kernel_sizeW = isizeW - (osizeW-1) * strideW;
}

// Adaptive average pooling

Tensor& adaptive_avg_pool2d_out_mps
(const Tensor& input,
IntArrayRef output_size,
Expand Down Expand Up @@ -150,5 +152,93 @@

}

// Adaptive max pooling

// Adaptive max pooling (forward) for the MPS backend.
//
// The adaptive pool is expressed as a regular max pool: a kernel size and
// stride are derived from the input and requested output spatial sizes by
// set_kernel_params(), and the work is delegated to
// at::max_pool2d_with_indices with zero padding, unit dilation and
// ceil_mode == false.
//
// input       - tensor to pool; its last two dimensions are treated as
//               height and width. Every dimension except dim 0 must be
//               non-empty.
// output_size - requested {height, width} of the pooled result.
// output      - destination for the pooled values (filled via copy_).
// indices     - destination for the indices returned by
//               max_pool2d_with_indices (filled via copy_).
//
// Fails through TORCH_CHECK when a non-batch dimension is empty, when a
// channels-last input is not 4D, or when the suggested memory format is
// neither Contiguous nor ChannelsLast.
//
// NOTE(review): set_kernel_params() is defined earlier in this file and is
// not fully visible here; presumably it assumes the input size is a multiple
// of the output size -- confirm before relying on other shapes.
TORCH_IMPL_FUNC(adaptive_max_pool2d_out_mps)
(const Tensor& input,
 IntArrayRef output_size,
 const Tensor& output,
 const Tensor& indices) {

  // Dimension 0 is skipped on purpose: only non-batch dimensions must be
  // non-empty.
  for (int64_t i = 1; i < input.ndimension(); i++) {
    TORCH_CHECK(input.size(i) > 0,
                "adaptive_max_pool2d(): Expected input to have non-zero size for non-batch dimensions, "
                "but input has sizes ", input.sizes(), " with dimension ", i, " being "
                "empty");
  }

  int64_t isizeH = input.size(-2);
  int64_t isizeW = input.size(-1);

  int64_t osizeH = output_size[0];
  int64_t osizeW = output_size[1];

  // Query the memory format once and reuse it for both checks below.
  const auto memory_format = input.suggest_memory_format();

  if (memory_format == at::MemoryFormat::ChannelsLast) {
    // Fixed: this message previously named adaptive_avg_pool2d(), which
    // pointed users at the wrong operator.
    TORCH_CHECK(input.ndimension() == 4,
                "adaptive_max_pool2d(): Expected 4D tensor, but got ",
                input.sizes());
  }

  switch (memory_format) {
    case at::MemoryFormat::Contiguous:
    case at::MemoryFormat::ChannelsLast:
      break;
    default:
      TORCH_CHECK(
          false,
          "Unsupported memory format. Supports only ChannelsLast, Contiguous");
  }

  // Filled in by set_kernel_params() through its output arguments.
  int64_t strideH;
  int64_t strideW;
  int64_t kernel_sizeH;
  int64_t kernel_sizeW;

  set_kernel_params(isizeH, isizeW,
                    osizeH, osizeW,
                    strideH, strideW,
                    kernel_sizeH, kernel_sizeW);

  // The IntArrayRef temporaries only need to live for the duration of this
  // call expression, so building them inline is safe.
  auto outputs = at::max_pool2d_with_indices(input,
                                             IntArrayRef({kernel_sizeH, kernel_sizeW}),
                                             IntArrayRef({strideH, strideW}),
                                             IntArrayRef({0, 0}),
                                             IntArrayRef({1, 1}),
                                             false);

  output.copy_(std::get<0>(outputs));
  indices.copy_(std::get<1>(outputs));
}

// Adaptive max pooling (backward) for the MPS backend.
//
// Rebuilds the kernel size and stride from the spatial sizes of `input` and
// `gradOutput` via set_kernel_params(), forwards to
// at::max_pool2d_with_indices_backward (zero padding, unit dilation,
// ceil_mode == false), and copies the result into `gradInput`.
//
// gradOutput - gradient w.r.t. the pooled output; its last two dimensions
//              give the pooled height and width.
// input      - the tensor that was pooled in the forward pass.
// indices    - indices produced by the forward pass.
// gradInput  - destination for the computed gradient (filled via copy_).
TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_mps)
(const Tensor& gradOutput,
 const Tensor& input,
 const Tensor& indices,
 const Tensor& gradInput) {

  // Spatial extents of the forward input and of the incoming gradient.
  int64_t inH = input.size(-2);
  int64_t inW = input.size(-1);
  int64_t outH = gradOutput.size(-2);
  int64_t outW = gradOutput.size(-1);

  // Filled in by set_kernel_params() through its output arguments.
  int64_t strideH;
  int64_t strideW;
  int64_t kernel_sizeH;
  int64_t kernel_sizeW;
  set_kernel_params(inH, inW, outH, outW,
                    strideH, strideW, kernel_sizeH, kernel_sizeW);

  // Named backing storage for the pooling parameters; these arrays outlive
  // the call below, so the IntArrayRef views built from them stay valid.
  const int64_t kernel[] = {kernel_sizeH, kernel_sizeW};
  const int64_t stride[] = {strideH, strideW};
  const int64_t padding[] = {0, 0};
  const int64_t dilation[] = {1, 1};

  gradInput.copy_(at::max_pool2d_with_indices_backward(gradOutput,
                                                       input,
                                                       IntArrayRef(kernel),
                                                       IntArrayRef(stride),
                                                       IntArrayRef(padding),
                                                       IntArrayRef(dilation),
                                                       false,
                                                       indices));
}

}
}
2 changes: 2 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9791,6 +9791,7 @@
dispatch:
CPU: adaptive_max_pool2d_out_cpu
CUDA: adaptive_max_pool2d_out_cuda
MPS: adaptive_max_pool2d_out_mps

# Return: (Tensor output, Tensor indices)
- func: adaptive_max_pool2d(Tensor self, int[2] output_size) -> (Tensor, Tensor)
Expand All @@ -9803,6 +9804,7 @@
dispatch:
CPU: adaptive_max_pool2d_backward_out_cpu
CUDA: adaptive_max_pool2d_backward_out_cuda
MPS: adaptive_max_pool2d_backward_out_mps

- func: adaptive_max_pool2d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor
python_module: nn
Expand Down
44 changes: 44 additions & 0 deletions test/test_mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -3089,6 +3089,50 @@ def helper(input_shape, out_shape, channels_last):

helper((2, 16, 16), (4, 4), False)

# Test adaptive max pool2d - when the input size is a multiple of output size
# Not testing for channels last right now
# Compares torch.nn.AdaptiveMaxPool2d on the 'mps' device against the CPU
# result: pooled values, (optionally) returned indices, and input gradients.
# NOTE(review): indentation was lost in this view of the file; the comments
# below describe the statements in order and do not assert their nesting.
def test_adaptive_max_pool2d_simple(self):
# helper builds one CPU input and an MPS copy of it, runs forward and
# backward on both, and asserts that the results match.
#   input_shape    - shape of the random input tensor
#   out_shape      - output size passed to AdaptiveMaxPool2d
#   return_indices - forwarded to AdaptiveMaxPool2d; when true the indices
#                    are compared as well
#   dtype          - dtype of the generated input
#   channels_last  - when true the CPU input is converted to
#                    torch.channels_last before being copied to MPS
def helper(input_shape, out_shape, return_indices, dtype, channels_last=False):
cpu_x = None
# Floating-point dtypes get normally distributed data.
if(dtype in [torch.float16, torch.float32]):
cpu_x = torch.randn(input_shape, device='cpu', dtype=dtype, requires_grad=True)
else:
# NOTE(review): integer tensors presumably cannot be created with
# requires_grad=True; this branch is not reached with the dtype list
# used below -- confirm before adding integer dtypes.
cpu_x = torch.randint(50, input_shape, device='cpu', dtype=dtype, requires_grad=True)
if(channels_last):
cpu_x = cpu_x.to(memory_format=torch.channels_last)
cpu_x.retain_grad()
# Independent MPS copy of the same data that tracks its own gradient.
x = cpu_x.detach().clone().to('mps').requires_grad_()

max_result, max_indices = None, None
max_result_cpu, max_indices_cpu = None, None

# Forward pass on both devices, with or without indices.
if(return_indices):
max_result, max_indices = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(x)
max_result_cpu, max_indices_cpu = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(cpu_x)
else:
max_result = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(x)
max_result_cpu = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(cpu_x)

# The same random upstream gradient is fed to both backward passes.
cpu_grad = torch.randn(max_result_cpu.shape)
grad = cpu_grad.to('mps')

max_result.backward(gradient=grad)
max_result_cpu.backward(gradient=cpu_grad)

# MPS results must match the CPU reference.
self.assertEqual(max_result, max_result_cpu)
if(return_indices):
self.assertEqual(max_indices, max_indices_cpu)
self.assertEqual(x.grad, cpu_x.grad)

# Only float32 is exercised; every case uses an input size that is a
# multiple of the output size, and the last case is a 3D (unbatched) input.
for dtype in [torch.float32]:
for return_indices in [False, True]:
helper((2, 2, 4, 4), (2, 2), return_indices, dtype)
helper((2, 2, 9, 9), (3, 3), return_indices, dtype)
helper((2, 2, 9, 9), (9, 9), return_indices, dtype)
helper((2, 2, 16, 16), (2, 2), return_indices, dtype)
helper((2, 2, 16, 16), (2, 16), return_indices, dtype)
helper((2, 16, 16), (4, 4), return_indices, dtype)

def test_gelu_simple(self):
def helper(shape):
cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
Expand Down

0 comments on commit 2e32d5f

Please sign in to comment.