[pt2] add metas for adaptive_max_pool ops (#104167)

Fixes #103892.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104167
Approved by: https://github.com/ezyang
This commit is contained in:
Nikita Karetnikov 2023-07-05 05:53:04 +02:00 committed by PyTorch MergeBot
parent 54e320d4d1
commit ad58aba932
3 changed files with 131 additions and 3 deletions

View File

@ -2819,7 +2819,6 @@ symbolic_aot_autograd_failures = {
xfail('median', ''), # could not find kernel
xfail('mode', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('nn.functional.adaptive_avg_pool3d', ''), # aten._adaptive_avg_pool3d_backward.default - couldn't ...
xfail('nn.functional.adaptive_max_pool1d', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('nn.functional.adaptive_max_pool2d', ''), # aten.adaptive_max_pool2d.default - couldn't find symbo...
xfail('nn.functional.adaptive_max_pool3d', ''), # argument 'output_size' (position 2...
skip('nn.functional.batch_norm', ''), # '0 is not tracked with proxy for <torch.fx.experimental.proxy_te..
@ -2973,7 +2972,6 @@ symbolic_aot_autograd_module_failures = {
torch.nn.GaussianNLLLoss, # NotImplementedError: local_scalar_dense/item NYI for torch.bool
torch.nn.AdaptiveAvgPool3d, # could not find kernel for aten._adaptive_avg_pool3d_backward.default at dispatch key
# DispatchKey.Meta
torch.nn.AdaptiveMaxPool1d, # Cannot call sizes() on tensor with symbolic sizes/strides
torch.nn.AdaptiveMaxPool2d, # Cannot call sizes() on tensor with symbolic sizes/strides
torch.nn.AdaptiveMaxPool3d, # Cannot call sizes() on tensor with symbolic sizes/strides
torch.nn.GroupNorm, # in native_group_norm_backward cpg, _rem = divmod(C, group)

View File

@ -1529,7 +1529,6 @@ symbolic_tensor_failures = {
xfail('mode', ''), # aten.mode.default - couldn't find symbolic meta function/decomposition
xfail('nanquantile', ''), # Could not run 'aten::equal' with arguments from the 'Meta' backend.
xfail('narrow', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('nn.functional.adaptive_max_pool1d', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('nn.functional.adaptive_max_pool2d', ''), # aten.adaptive_max_pool2d.default - couldn't find symbolic meta funct...
xfail('nn.functional.adaptive_max_pool3d', ''), # argument 'output_size' (position 2) must be tupl...
xfail('nn.functional.binary_cross_entropy', ''), # aten.new_empty.default - couldn't find symbolic meta function/decom...

View File

@ -2207,6 +2207,137 @@ def meta__adaptive_avg_pool2d_backward(grad_out, self):
return self.new_empty(self.shape).to(memory_format=memory_format)
def _adaptive_pool_empty_output_check(grad_output: Tensor, arg_name: str):
ndim = grad_output.ndim
for i in range(1, ndim):
torch._check(
grad_output.size(i) > 0,
lambda: (
f"{arg_name}(): Expected grad_output to have non-zero size for non-batch dimensions, "
f"but grad_output has sizes {grad_output.shape} with dimension {i} being empty"
),
)
@register_meta(aten.adaptive_max_pool2d)
@out_wrapper("out", "indices")
def meta_adaptive_max_pool2d(input, output_size):
ndim = input.ndim
torch._check(
ndim in (3, 4),
lambda: f"adaptive_max_pool2d(): Expected 3D or 4D tensor, but got: {input.shape}",
)
for i in range(1, ndim):
torch._check(
input.size(i) > 0,
lambda: (
f"adaptive_max_pool2d(): Expected input to have non-zero size for non-batch dimensions, "
f"but input has sizes {input.shape} with dimension {i} being empty"
),
)
torch._check(
len(output_size) == 2,
lambda: "adaptive_max_pool2d(): internal error: output_size.size() must be 2",
)
dimH = 1
sizeB = 1
sizeD = 0
if input.ndim == 4:
sizeB = input.size(0)
dimH += 1
sizeD = input.size(dimH - 1)
osizeH, osizeW = output_size
if input.ndim == 3:
out_shape = (sizeD, osizeH, osizeW)
out = input.new_empty(out_shape)
indices = input.new_empty(out_shape, dtype=torch.int64)
return out, indices
else:
out_shape = (sizeB, sizeD, osizeH, osizeW) # type: ignore[assignment]
memory_format = utils.suggest_memory_format(input)
out = input.new_empty(out_shape).to(memory_format=memory_format)
indices = input.new_empty(out_shape, dtype=torch.int64).to(
memory_format=memory_format
)
return out, indices
@register_meta(aten.adaptive_max_pool2d_backward)
@out_wrapper()
def meta_adaptive_max_pool2d_backward(grad_output, input, indices):
ndim = grad_output.ndim
torch._check(
ndim in (3, 4),
lambda: f"adaptive_max_pooling2d_backward(): Expected 3D or 4D grad_output, but got: {grad_output.shape}",
)
_adaptive_pool_empty_output_check(grad_output, "adaptive_max_pool2d_backward")
torch._check(
input.dtype == grad_output.dtype,
lambda: f"expected dtype {input.dtype} for `grad_output` but got dtype {grad_output.dtype}",
)
memory_format = utils.suggest_memory_format(input)
return input.new_empty(input.shape).to(memory_format=memory_format)
@register_meta(aten.adaptive_max_pool3d)
@out_wrapper("out", "indices")
def meta_adaptive_max_pool3d(input, output_size):
ndim = input.ndim
torch._check(
ndim in (4, 5),
lambda: f"adaptive_max_pool3d(): Expected 4D or 5D tensor, but got: {input.shape}",
)
for i in range(1, ndim):
torch._check(
input.size(i) > 0,
lambda: (
f"adaptive_max_pool3d(): Expected input to have non-zero size for non-batch dimensions, "
f"but input has sizes {input.shape} with dimension {i} being empty"
),
)
torch._check(
len(output_size) == 3,
lambda: "adaptive_max_pool3d(): internal error: output_size.size() must be 3",
)
dimD = 0
sizeB = 1
sizeD = 0
if ndim == 5:
sizeB = input.size(0)
dimD += 1
sizeD = input.size(dimD)
osizeT, osizeH, osizeW = output_size
if ndim == 4:
out_shape = (sizeD, osizeT, osizeH, osizeW)
else:
out_shape = (sizeB, sizeD, osizeT, osizeH, osizeW) # type: ignore[assignment]
out = input.new_empty(out_shape)
indices = input.new_empty(out_shape, dtype=torch.int64)
return out, indices
@register_meta(aten.adaptive_max_pool3d_backward)
@out_wrapper()
def meta_adaptive_max_pool3d_backward(grad_output, input, indices):
_adaptive_pool_empty_output_check(grad_output, "adaptive_max_pool3d_backward")
return input.new_empty(input.shape)
@register_meta(aten.repeat_interleave.Tensor)
def meta_repeat_interleave_Tensor(repeats, output_size=None):
if output_size is None: