Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00.

Commit a56a8c0fc0 (parent 03d8ab4dec): Add meta support for _adaptive_avg_pool2d_backward (#86359).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86359.
Approved by: https://github.com/ezyang, https://github.com/albanD.
@ -597,6 +597,26 @@ def meta_adaptive_avg_pool3d(self, output_size):
|
|||
return self.new_empty(self.shape[:-3] + tuple(output_size))
|
||||
|
||||
|
||||
@register_meta(aten._adaptive_avg_pool2d_backward.default)
def meta__adaptive_avg_pool2d_backward(grad_out, self):
    """Meta registration for aten._adaptive_avg_pool2d_backward.

    Runs the same validation as the eager kernel (non-empty non-batch
    dimensions, 3D/4D rank, matching dtypes) and returns an uninitialized
    tensor shaped like ``self`` — the gradient w.r.t. the pooling input.
    """
    ndim = grad_out.ndim
    # Every dimension of grad_out past the batch dimension must be non-empty.
    for dim in range(1, ndim):
        check(
            grad_out.size(dim) > 0,
            lambda: f"adaptive_avg_pool2d_backward(): Expected grad_output to have non-zero \
size for non-batch dimensions, {grad_out.shape} with dimension {dim} being empty",
        )
    # Only unbatched (3D) and batched (4D) layouts are supported.
    check(
        ndim in (3, 4),
        lambda: f"adaptive_avg_pool2d_backward(): Expected 3D or 4D tensor, but got {self.shape}",
    )
    check(
        self.dtype == grad_out.dtype,
        lambda: f"expected dtype {self.dtype} for `grad_output` but got dtype {grad_out.dtype}",
    )
    # The input gradient always has the input's own shape.
    return self.new_empty(self.shape)
|
||||
|
||||
|
||||
@register_meta(aten.repeat_interleave.Tensor)
|
||||
def meta_repeat_interleave_Tensor(repeats, output_size=None):
|
||||
if output_size is None:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user