From a56a8c0fc0251bb4cd24b366a290db2e4beea747 Mon Sep 17 00:00:00 2001 From: anjali411 Date: Mon, 10 Oct 2022 20:28:32 +0000 Subject: [PATCH] Add meta support for _adaptive_avg_pool2d_backward (#86359) Pull Request resolved: https://github.com/pytorch/pytorch/pull/86359 Approved by: https://github.com/ezyang, https://github.com/albanD --- torch/_meta_registrations.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index b10b279a6ce..ccfa1c2c57d 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -597,6 +597,26 @@ def meta_adaptive_avg_pool3d(self, output_size): return self.new_empty(self.shape[:-3] + tuple(output_size)) +@register_meta(aten._adaptive_avg_pool2d_backward.default) +def meta__adaptive_avg_pool2d_backward(grad_out, self): + ndim = grad_out.ndim + for i in range(1, ndim): + check( + grad_out.size(i) > 0, + lambda: f"adaptive_avg_pool2d_backward(): Expected grad_output to have non-zero \ + size for non-batch dimensions, {grad_out.shape} with dimension {i} being empty", + ) + check( + ndim == 3 or ndim == 4, + lambda: f"adaptive_avg_pool2d_backward(): Expected 3D or 4D tensor, but got {self.shape}", + ) + check( + self.dtype == grad_out.dtype, + lambda: f"expected dtype {self.dtype} for `grad_output` but got dtype {grad_out.dtype}", + ) + return self.new_empty(self.shape) + + @register_meta(aten.repeat_interleave.Tensor) def meta_repeat_interleave_Tensor(repeats, output_size=None): if output_size is None: