[pt2] add metas for adaptive_max_pool ops (#104167)
Fixes #103892.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104167
Approved by: https://github.com/ezyang
This commit is contained in:
parent 54e320d4d1
commit ad58aba932
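
With these metas registered, adaptive max pooling participates in shape inference on the meta device (and, by extension, under fake tensors and symbolic shapes) without allocating real data. A quick sanity check, not part of the diff, assuming a build that includes this change:

    import torch

    # Illustrative only: shape propagation with no real storage.
    x = torch.empty(2, 3, 16, 16, device="meta")
    out, indices = torch.ops.aten.adaptive_max_pool2d(x, (4, 4))
    print(out.shape)      # torch.Size([2, 3, 4, 4])
    print(indices.dtype)  # torch.int64
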
@@ -2819,7 +2819,6 @@ symbolic_aot_autograd_failures = {
    xfail('median', ''),  # could not find kernel
    xfail('mode', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
    xfail('nn.functional.adaptive_avg_pool3d', ''),  # aten._adaptive_avg_pool3d_backward.default - couldn't ...
    xfail('nn.functional.adaptive_max_pool1d', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
    xfail('nn.functional.adaptive_max_pool2d', ''),  # aten.adaptive_max_pool2d.default - couldn't find symbo...
    xfail('nn.functional.adaptive_max_pool3d', ''),  # argument 'output_size' (position 2...
    skip('nn.functional.batch_norm', ''),  # '0 is not tracked with proxy for <torch.fx.experimental.proxy_te..

@@ -2973,7 +2972,6 @@ symbolic_aot_autograd_module_failures = {
    torch.nn.GaussianNLLLoss,  # NotImplementedError: local_scalar_dense/item NYI for torch.bool
    torch.nn.AdaptiveAvgPool3d,  # could not find kernel for aten._adaptive_avg_pool3d_backward.default at dispatch key
    # DispatchKey.Meta
    torch.nn.AdaptiveMaxPool1d,  # Cannot call sizes() on tensor with symbolic sizes/strides
    torch.nn.AdaptiveMaxPool2d,  # Cannot call sizes() on tensor with symbolic sizes/strides
    torch.nn.AdaptiveMaxPool3d,  # Cannot call sizes() on tensor with symbolic sizes/strides
    torch.nn.GroupNorm,  # in native_group_norm_backward cpg, _rem = divmod(C, group)

@@ -1529,7 +1529,6 @@ symbolic_tensor_failures = {
    xfail('mode', ''),  # aten.mode.default - couldn't find symbolic meta function/decomposition
    xfail('nanquantile', ''),  # Could not run 'aten::equal' with arguments from the 'Meta' backend.
    xfail('narrow', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
    xfail('nn.functional.adaptive_max_pool1d', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
    xfail('nn.functional.adaptive_max_pool2d', ''),  # aten.adaptive_max_pool2d.default - couldn't find symbolic meta funct...
    xfail('nn.functional.adaptive_max_pool3d', ''),  # argument 'output_size' (position 2) must be tupl...
    xfail('nn.functional.binary_cross_entropy', ''),  # aten.new_empty.default - couldn't find symbolic meta function/decom...

@@ -2207,6 +2207,137 @@ def meta__adaptive_avg_pool2d_backward(grad_out, self):
    return self.new_empty(self.shape).to(memory_format=memory_format)


def _adaptive_pool_empty_output_check(grad_output: Tensor, arg_name: str):
    ndim = grad_output.ndim
    for i in range(1, ndim):
        torch._check(
            grad_output.size(i) > 0,
            lambda: (
                f"{arg_name}(): Expected grad_output to have non-zero size for non-batch dimensions, "
                f"but grad_output has sizes {grad_output.shape} with dimension {i} being empty"
            ),
        )
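
The torch._check(cond, lambda: msg) pattern above keeps message construction lazy: the callback runs only when the check fails, at which point a RuntimeError is raised. A minimal standalone illustration, not part of this diff:

    import torch

    # Illustrative only; mirrors the empty-dimension check in the helper above.
    grad_output = torch.empty(2, 0, 4, device="meta")
    try:
        torch._check(
            grad_output.size(1) > 0,
            lambda: f"expected non-zero size, but got sizes {grad_output.shape}",
        )
    except RuntimeError as err:
        print(err)  # expected non-zero size, but got sizes torch.Size([2, 0, 4])
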
@register_meta(aten.adaptive_max_pool2d)
@out_wrapper("out", "indices")
def meta_adaptive_max_pool2d(input, output_size):
    ndim = input.ndim
    torch._check(
        ndim in (3, 4),
        lambda: f"adaptive_max_pool2d(): Expected 3D or 4D tensor, but got: {input.shape}",
    )
    for i in range(1, ndim):
        torch._check(
            input.size(i) > 0,
            lambda: (
                f"adaptive_max_pool2d(): Expected input to have non-zero size for non-batch dimensions, "
                f"but input has sizes {input.shape} with dimension {i} being empty"
            ),
        )

    torch._check(
        len(output_size) == 2,
        lambda: "adaptive_max_pool2d(): internal error: output_size.size() must be 2",
    )

    dimH = 1
    sizeB = 1
    sizeD = 0

    if input.ndim == 4:
        sizeB = input.size(0)
        dimH += 1

    sizeD = input.size(dimH - 1)
    osizeH, osizeW = output_size

    if input.ndim == 3:
        out_shape = (sizeD, osizeH, osizeW)
        out = input.new_empty(out_shape)
        indices = input.new_empty(out_shape, dtype=torch.int64)
        return out, indices
    else:
        out_shape = (sizeB, sizeD, osizeH, osizeW)  # type: ignore[assignment]
        memory_format = utils.suggest_memory_format(input)
        out = input.new_empty(out_shape).to(memory_format=memory_format)
        indices = input.new_empty(out_shape, dtype=torch.int64).to(
            memory_format=memory_format
        )
        return out, indices
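
Note the design choice in the 4D branch above: both out and indices follow utils.suggest_memory_format(input), so a channels_last input yields channels_last outputs. A hedged sketch, again assuming this change is in the build:

    import torch

    # Illustrative only: memory-format propagation in the 4D case.
    x = torch.empty(2, 3, 16, 16, device="meta").to(memory_format=torch.channels_last)
    out, indices = torch.ops.aten.adaptive_max_pool2d(x, (4, 4))
    print(out.is_contiguous(memory_format=torch.channels_last))  # True
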
@register_meta(aten.adaptive_max_pool2d_backward)
@out_wrapper()
def meta_adaptive_max_pool2d_backward(grad_output, input, indices):
    ndim = grad_output.ndim
    torch._check(
        ndim in (3, 4),
        lambda: f"adaptive_max_pooling2d_backward(): Expected 3D or 4D grad_output, but got: {grad_output.shape}",
    )

    _adaptive_pool_empty_output_check(grad_output, "adaptive_max_pool2d_backward")

    torch._check(
        input.dtype == grad_output.dtype,
        lambda: f"expected dtype {input.dtype} for `grad_output` but got dtype {grad_output.dtype}",
    )

    memory_format = utils.suggest_memory_format(input)
    return input.new_empty(input.shape).to(memory_format=memory_format)
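
The backward meta's contract is that grad_input matches the input: same shape, dtype checked against grad_output, and the input's suggested memory format. A sketch of the round trip, illustrative and not part of the diff:

    import torch

    # Illustrative only: the backward meta returns an input-shaped tensor.
    x = torch.empty(2, 3, 16, 16, device="meta")
    out, indices = torch.ops.aten.adaptive_max_pool2d(x, (4, 4))
    grad_in = torch.ops.aten.adaptive_max_pool2d_backward(
        torch.empty_like(out), x, indices
    )
    print(grad_in.shape)  # torch.Size([2, 3, 16, 16])
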
@register_meta(aten.adaptive_max_pool3d)
@out_wrapper("out", "indices")
def meta_adaptive_max_pool3d(input, output_size):
    ndim = input.ndim
    torch._check(
        ndim in (4, 5),
        lambda: f"adaptive_max_pool3d(): Expected 4D or 5D tensor, but got: {input.shape}",
    )
    for i in range(1, ndim):
        torch._check(
            input.size(i) > 0,
            lambda: (
                f"adaptive_max_pool3d(): Expected input to have non-zero size for non-batch dimensions, "
                f"but input has sizes {input.shape} with dimension {i} being empty"
            ),
        )

    torch._check(
        len(output_size) == 3,
        lambda: "adaptive_max_pool3d(): internal error: output_size.size() must be 3",
    )

    dimD = 0
    sizeB = 1
    sizeD = 0

    if ndim == 5:
        sizeB = input.size(0)
        dimD += 1

    sizeD = input.size(dimD)
    osizeT, osizeH, osizeW = output_size

    if ndim == 4:
        out_shape = (sizeD, osizeT, osizeH, osizeW)
    else:
        out_shape = (sizeB, sizeD, osizeT, osizeH, osizeW)  # type: ignore[assignment]

    out = input.new_empty(out_shape)
    indices = input.new_empty(out_shape, dtype=torch.int64)

    return out, indices
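
The 3D variant follows the same scheme with an extra time dimension: 4D inputs produce (C, T, H, W) outputs and 5D inputs produce (N, C, T, H, W). For example (illustrative, assuming this change):

    import torch

    # Illustrative only: 5D shape propagation for the 3D op.
    x = torch.empty(2, 3, 8, 16, 16, device="meta")
    out, indices = torch.ops.aten.adaptive_max_pool3d(x, (4, 4, 4))
    print(out.shape)  # torch.Size([2, 3, 4, 4, 4])
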
@register_meta(aten.adaptive_max_pool3d_backward)
@out_wrapper()
def meta_adaptive_max_pool3d_backward(grad_output, input, indices):
    _adaptive_pool_empty_output_check(grad_output, "adaptive_max_pool3d_backward")
    return input.new_empty(input.shape)
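
Unlike its 2D counterpart, the 3D backward meta does not propagate a memory format; it only mirrors the input's sizes. At the module level, the practical effect of these metas is that the previously xfailed modules now trace on the meta device. A sketch, illustrative and not part of the diff:

    import torch

    # Illustrative only: nn.Module-level check on the meta device.
    pool = torch.nn.AdaptiveMaxPool3d((2, 4, 4), return_indices=True)
    x = torch.empty(2, 3, 8, 16, 16, device="meta")
    out, indices = pool(x)
    print(out.shape)  # torch.Size([2, 3, 2, 4, 4])
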
@register_meta(aten.repeat_interleave.Tensor)
def meta_repeat_interleave_Tensor(repeats, output_size=None):
    if output_size is None: