Add out wrappers to some decompositions (#115437)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115437
Approved by: https://github.com/lezcano
This commit is contained in:
Isuru Fernando 2024-04-23 06:26:11 +00:00 committed by PyTorch MergeBot
parent e0c5113dec
commit edcd968b51
5 changed files with 48 additions and 21 deletions

View File

@ -22,6 +22,9 @@ aten::_softmax
aten::_softmax.out
aten::_to_copy
aten::_to_copy.out
aten::_upsample_nearest_exact1d.out
aten::_upsample_nearest_exact2d.out
aten::_upsample_nearest_exact3d.out
aten::abs
aten::abs.out
aten::abs_
@ -508,6 +511,10 @@ aten::uniform.out
aten::uniform_
aten::unsqueeze
aten::upsample_bicubic2d
aten::upsample_bicubic2d.out
aten::upsample_nearest1d.out
aten::upsample_nearest2d.out
aten::upsample_nearest3d.out
aten::var.correction
aten::var.correction_out
aten::var_mean.correction

View File

@ -609,13 +609,10 @@ aten::_upsample_bilinear2d_aa
aten::_upsample_bilinear2d_aa.out
aten::_upsample_bilinear2d_aa_backward
aten::_upsample_bilinear2d_aa_backward.grad_input
aten::_upsample_nearest_exact1d.out
aten::_upsample_nearest_exact1d_backward
aten::_upsample_nearest_exact1d_backward.grad_input
aten::_upsample_nearest_exact2d.out
aten::_upsample_nearest_exact2d_backward
aten::_upsample_nearest_exact2d_backward.grad_input
aten::_upsample_nearest_exact3d.out
aten::_upsample_nearest_exact3d_backward
aten::_upsample_nearest_exact3d_backward.grad_input
aten::_use_cudnn_ctc_loss
@ -1331,20 +1328,16 @@ aten::unsafe_split_with_sizes.out
aten::unsqueeze_
aten::unsqueeze_copy
aten::unsqueeze_copy.out
aten::upsample_bicubic2d.out
aten::upsample_bicubic2d_backward
aten::upsample_bicubic2d_backward.grad_input
aten::upsample_bilinear2d_backward
aten::upsample_bilinear2d_backward.grad_input
aten::upsample_linear1d_backward
aten::upsample_linear1d_backward.grad_input
aten::upsample_nearest1d.out
aten::upsample_nearest1d_backward
aten::upsample_nearest1d_backward.grad_input
aten::upsample_nearest2d.out
aten::upsample_nearest2d_backward
aten::upsample_nearest2d_backward.grad_input
aten::upsample_nearest3d.out
aten::upsample_nearest3d_backward
aten::upsample_nearest3d_backward.grad_input
aten::upsample_trilinear3d_backward

View File

@ -8847,7 +8847,7 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
out = torch.empty(4, 3, 16, 16, device='meta', dtype=torch.double)
self.assertExpectedRaisesInline(
RuntimeError, lambda: torch._C._nn.upsample_nearest2d(x, (16, 16), out=out),
"""Expected out tensor to have dtype float, but got double instead"""
"""Expected out tensor to have dtype torch.float32 but got torch.float64 instead"""
)
# Complain if out device mismatch
@ -8857,7 +8857,7 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
if not TEST_WITH_TORCHINDUCTOR:
self.assertExpectedRaisesInline(
RuntimeError, lambda: torch._C._nn.upsample_nearest2d(x, (16, 16), out=out),
"""Expected out tensor to have device meta, but got cpu instead"""
"""Attempting to copy from device meta to device cpu, but cross-device copies are not allowed!"""
)
def test_add_meta_scalar(self):

View File

@ -2720,9 +2720,10 @@ def _compute_upsample_nearest_indices(input, output_size, scales, exact=False):
return indices
@register_decomposition(aten.upsample_nearest1d.default)
@register_decomposition([aten.upsample_nearest1d.default, aten.upsample_nearest1d.out])
@aten.upsample_nearest1d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten.upsample_nearest1d.default.py_impl(DispatchKey.Autograd)
@out_wrapper(preserve_memory_format=True, exact_dtype=True)
def upsample_nearest1d(
input: Tensor,
output_size: List[int],
@ -2731,9 +2732,12 @@ def upsample_nearest1d(
return _upsample_nearest(input, output_size, [scales])
@register_decomposition(aten._upsample_nearest_exact1d.default)
@register_decomposition(
[aten._upsample_nearest_exact1d.default, aten._upsample_nearest_exact1d.out]
)
@aten._upsample_nearest_exact1d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten._upsample_nearest_exact1d.default.py_impl(DispatchKey.Autograd)
@out_wrapper(preserve_memory_format=True, exact_dtype=True)
def upsample_nearest_exact1d(
input: Tensor,
output_size: List[int],
@ -2742,9 +2746,10 @@ def upsample_nearest_exact1d(
return _upsample_nearest(input, output_size, [scales], exact=True)
@register_decomposition(aten.upsample_nearest2d.default)
@register_decomposition([aten.upsample_nearest2d.default, aten.upsample_nearest2d.out])
@aten.upsample_nearest2d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten.upsample_nearest2d.default.py_impl(DispatchKey.Autograd)
@out_wrapper(preserve_memory_format=True, exact_dtype=True)
def upsample_nearest2d(
input: Tensor,
output_size: List[int],
@ -2754,9 +2759,12 @@ def upsample_nearest2d(
return _upsample_nearest(input, output_size, [scales_h, scales_w])
@register_decomposition(aten._upsample_nearest_exact2d.default)
@register_decomposition(
[aten._upsample_nearest_exact2d.default, aten._upsample_nearest_exact2d.out]
)
@aten._upsample_nearest_exact2d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten._upsample_nearest_exact2d.default.py_impl(DispatchKey.Autograd)
@out_wrapper(preserve_memory_format=True, exact_dtype=True)
def _upsample_nearest_exact2d(
input: Tensor,
output_size: List[int],
@ -2766,9 +2774,10 @@ def _upsample_nearest_exact2d(
return _upsample_nearest(input, output_size, [scales_h, scales_w], exact=True)
@register_decomposition(aten.upsample_nearest3d.default)
@register_decomposition([aten.upsample_nearest3d.default, aten.upsample_nearest3d.out])
@aten.upsample_nearest3d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten.upsample_nearest3d.default.py_impl(DispatchKey.Autograd)
@out_wrapper(preserve_memory_format=True, exact_dtype=True)
def upsample_nearest3d(
input: Tensor,
output_size: List[int],
@ -2779,9 +2788,12 @@ def upsample_nearest3d(
return _upsample_nearest(input, output_size, [scales_d, scales_h, scales_w])
@register_decomposition(aten._upsample_nearest_exact3d.default)
@register_decomposition(
[aten._upsample_nearest_exact3d.default, aten._upsample_nearest_exact3d.out]
)
@aten._upsample_nearest_exact3d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten._upsample_nearest_exact3d.default.py_impl(DispatchKey.Autograd)
@out_wrapper(preserve_memory_format=True, exact_dtype=True)
def _upsample_nearest_exact3d(
input: Tensor,
output_size: List[int],
@ -4251,8 +4263,9 @@ def matmul(tensor1, tensor2, *, is_out=False):
torch._check(False, lambda: "both arguments to matmul need to be at least 1D")
@register_decomposition(aten.upsample_bicubic2d.default)
@register_decomposition([aten.upsample_bicubic2d.default, aten.upsample_bicubic2d.out])
@aten.upsample_bicubic2d.default.py_impl(DispatchKey.Autograd)
@out_wrapper()
@pw_cast_for_opmath
def upsample_bicubic2d_default(
input: Tensor,

View File

@ -170,9 +170,13 @@ def _resize_output_check(out: TensorLikeType, shape: ShapeType):
# TODO: handle tuples of tensors
def _maybe_resize_out(out: TensorLikeType, shape: ShapeType):
def _maybe_resize_out(
out: TensorLikeType,
shape: ShapeType,
memory_format: Optional[torch.memory_format] = None,
):
if _resize_output_check(out, shape):
return out.resize_(shape)
return out.resize_(shape, memory_format=memory_format)
else:
return out
@ -205,7 +209,12 @@ def _safe_copy_out(
return copy_to.copy_(copy_from)
def out_wrapper(*out_names: str, exact_dtype: bool = False, pass_is_out: bool = False):
def out_wrapper(
*out_names: str,
exact_dtype: bool = False,
pass_is_out: bool = False,
preserve_memory_format=False,
):
# The wrapped function needs to convert the output parameters to ensure
# compatibility between the Python API (which always uses "out" as the
# parameter name and may be a tuple) and the Aten API (which may have
@ -219,6 +228,9 @@ def out_wrapper(*out_names: str, exact_dtype: bool = False, pass_is_out: bool =
is_tensor = len(out_names) == 1
def maybe_compute_memory_format(t):
return utils.suggest_memory_format(t) if preserve_memory_format else None
def _out_wrapper(fn: Callable) -> Callable:
"""
Adds the out parameter to a Python reference.
@ -277,7 +289,9 @@ def out_wrapper(*out_names: str, exact_dtype: bool = False, pass_is_out: bool =
if is_tensor:
assert isinstance(out, TensorLike)
# These two operations are done in-place
_maybe_resize_out(out, result.shape)
_maybe_resize_out(
out, result.shape, maybe_compute_memory_format(result)
)
_safe_copy_out(copy_from=result, copy_to=out, exact_dtype=exact_dtype) # type: ignore[arg-type]
else:
assert isinstance(out, Tuple) # type: ignore[arg-type]
@ -287,7 +301,7 @@ def out_wrapper(*out_names: str, exact_dtype: bool = False, pass_is_out: bool =
)
for r, o in zip(result, out):
# These two operations are done in-place
_maybe_resize_out(o, r.shape)
_maybe_resize_out(o, r.shape, maybe_compute_memory_format(r))
_safe_copy_out(copy_from=r, copy_to=o, exact_dtype=exact_dtype) # type: ignore[arg-type]
else:
out = result