From edcd968b51f8a9f8faba806fdd91d34210e672b2 Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Tue, 23 Apr 2024 06:26:11 +0000 Subject: [PATCH] Add out wrappers to some decompositions (#115437) Pull Request resolved: https://github.com/pytorch/pytorch/pull/115437 Approved by: https://github.com/lezcano --- ...DecompTest.test_aten_core_operators.expect | 7 +++++ ...asDecompTest.test_has_decomposition.expect | 7 ----- test/test_torch.py | 4 +-- torch/_decomp/decompositions.py | 27 ++++++++++++++----- torch/_prims_common/wrappers.py | 24 +++++++++++++---- 5 files changed, 48 insertions(+), 21 deletions(-) diff --git a/test/expect/HasDecompTest.test_aten_core_operators.expect b/test/expect/HasDecompTest.test_aten_core_operators.expect index b330aa787c1..dc3d8cc389a 100644 --- a/test/expect/HasDecompTest.test_aten_core_operators.expect +++ b/test/expect/HasDecompTest.test_aten_core_operators.expect @@ -22,6 +22,9 @@ aten::_softmax aten::_softmax.out aten::_to_copy aten::_to_copy.out +aten::_upsample_nearest_exact1d.out +aten::_upsample_nearest_exact2d.out +aten::_upsample_nearest_exact3d.out aten::abs aten::abs.out aten::abs_ @@ -508,6 +511,10 @@ aten::uniform.out aten::uniform_ aten::unsqueeze aten::upsample_bicubic2d +aten::upsample_bicubic2d.out +aten::upsample_nearest1d.out +aten::upsample_nearest2d.out +aten::upsample_nearest3d.out aten::var.correction aten::var.correction_out aten::var_mean.correction diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index 8fbdc431f4d..2fc26d1a326 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -609,13 +609,10 @@ aten::_upsample_bilinear2d_aa aten::_upsample_bilinear2d_aa.out aten::_upsample_bilinear2d_aa_backward aten::_upsample_bilinear2d_aa_backward.grad_input -aten::_upsample_nearest_exact1d.out aten::_upsample_nearest_exact1d_backward aten::_upsample_nearest_exact1d_backward.grad_input -aten::_upsample_nearest_exact2d.out aten::_upsample_nearest_exact2d_backward aten::_upsample_nearest_exact2d_backward.grad_input -aten::_upsample_nearest_exact3d.out aten::_upsample_nearest_exact3d_backward aten::_upsample_nearest_exact3d_backward.grad_input aten::_use_cudnn_ctc_loss @@ -1331,20 +1328,16 @@ aten::unsafe_split_with_sizes.out aten::unsqueeze_ aten::unsqueeze_copy aten::unsqueeze_copy.out -aten::upsample_bicubic2d.out aten::upsample_bicubic2d_backward aten::upsample_bicubic2d_backward.grad_input aten::upsample_bilinear2d_backward aten::upsample_bilinear2d_backward.grad_input aten::upsample_linear1d_backward aten::upsample_linear1d_backward.grad_input -aten::upsample_nearest1d.out aten::upsample_nearest1d_backward aten::upsample_nearest1d_backward.grad_input -aten::upsample_nearest2d.out aten::upsample_nearest2d_backward aten::upsample_nearest2d_backward.grad_input -aten::upsample_nearest3d.out aten::upsample_nearest3d_backward aten::upsample_nearest3d_backward.grad_input aten::upsample_trilinear3d_backward diff --git a/test/test_torch.py b/test/test_torch.py index 735a4f447ae..25d1cc14eda 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -8847,7 +8847,7 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j], out = torch.empty(4, 3, 16, 16, device='meta', dtype=torch.double) self.assertExpectedRaisesInline( RuntimeError, lambda: torch._C._nn.upsample_nearest2d(x, (16, 16), out=out), - """Expected out tensor to have dtype float, but got double instead""" + """Expected out tensor to have dtype torch.float32 but got torch.float64 instead""" ) # Complain if out device mismatch @@ -8857,7 +8857,7 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j], if not TEST_WITH_TORCHINDUCTOR: self.assertExpectedRaisesInline( RuntimeError, lambda: torch._C._nn.upsample_nearest2d(x, (16, 16), out=out), - """Expected out tensor to have device meta, but got cpu instead""" + """Attempting to copy from device meta to device cpu, but cross-device copies are not allowed!""" ) def test_add_meta_scalar(self): diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index 3ef43ad4b13..3b69cc5b91c 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -2720,9 +2720,10 @@ def _compute_upsample_nearest_indices(input, output_size, scales, exact=False): return indices -@register_decomposition(aten.upsample_nearest1d.default) +@register_decomposition([aten.upsample_nearest1d.default, aten.upsample_nearest1d.out]) @aten.upsample_nearest1d.default.py_impl(DispatchKey.CompositeImplicitAutograd) @aten.upsample_nearest1d.default.py_impl(DispatchKey.Autograd) +@out_wrapper(preserve_memory_format=True, exact_dtype=True) def upsample_nearest1d( input: Tensor, output_size: List[int], @@ -2731,9 +2732,12 @@ def upsample_nearest1d( return _upsample_nearest(input, output_size, [scales]) -@register_decomposition(aten._upsample_nearest_exact1d.default) +@register_decomposition( + [aten._upsample_nearest_exact1d.default, aten._upsample_nearest_exact1d.out] +) @aten._upsample_nearest_exact1d.default.py_impl(DispatchKey.CompositeImplicitAutograd) @aten._upsample_nearest_exact1d.default.py_impl(DispatchKey.Autograd) +@out_wrapper(preserve_memory_format=True, exact_dtype=True) def upsample_nearest_exact1d( input: Tensor, output_size: List[int], @@ -2742,9 +2746,10 @@ def upsample_nearest_exact1d( return _upsample_nearest(input, output_size, [scales], exact=True) -@register_decomposition(aten.upsample_nearest2d.default) +@register_decomposition([aten.upsample_nearest2d.default, aten.upsample_nearest2d.out]) @aten.upsample_nearest2d.default.py_impl(DispatchKey.CompositeImplicitAutograd) @aten.upsample_nearest2d.default.py_impl(DispatchKey.Autograd) +@out_wrapper(preserve_memory_format=True, exact_dtype=True) def upsample_nearest2d( input: Tensor, output_size: List[int], @@ -2754,9 +2759,12 @@ def upsample_nearest2d( return _upsample_nearest(input, output_size, [scales_h, scales_w]) -@register_decomposition(aten._upsample_nearest_exact2d.default) +@register_decomposition( + [aten._upsample_nearest_exact2d.default, aten._upsample_nearest_exact2d.out] +) @aten._upsample_nearest_exact2d.default.py_impl(DispatchKey.CompositeImplicitAutograd) @aten._upsample_nearest_exact2d.default.py_impl(DispatchKey.Autograd) +@out_wrapper(preserve_memory_format=True, exact_dtype=True) def _upsample_nearest_exact2d( input: Tensor, output_size: List[int], @@ -2766,9 +2774,10 @@ def _upsample_nearest_exact2d( return _upsample_nearest(input, output_size, [scales_h, scales_w], exact=True) -@register_decomposition(aten.upsample_nearest3d.default) +@register_decomposition([aten.upsample_nearest3d.default, aten.upsample_nearest3d.out]) @aten.upsample_nearest3d.default.py_impl(DispatchKey.CompositeImplicitAutograd) @aten.upsample_nearest3d.default.py_impl(DispatchKey.Autograd) +@out_wrapper(preserve_memory_format=True, exact_dtype=True) def upsample_nearest3d( input: Tensor, output_size: List[int], @@ -2779,9 +2788,12 @@ def upsample_nearest3d( return _upsample_nearest(input, output_size, [scales_d, scales_h, scales_w]) -@register_decomposition(aten._upsample_nearest_exact3d.default) +@register_decomposition( + [aten._upsample_nearest_exact3d.default, aten._upsample_nearest_exact3d.out] +) @aten._upsample_nearest_exact3d.default.py_impl(DispatchKey.CompositeImplicitAutograd) @aten._upsample_nearest_exact3d.default.py_impl(DispatchKey.Autograd) +@out_wrapper(preserve_memory_format=True, exact_dtype=True) def _upsample_nearest_exact3d( input: Tensor, output_size: List[int], @@ -4251,8 +4263,9 @@ def matmul(tensor1, tensor2, *, is_out=False): torch._check(False, lambda: "both arguments to matmul need to be at least 1D") -@register_decomposition(aten.upsample_bicubic2d.default) +@register_decomposition([aten.upsample_bicubic2d.default, aten.upsample_bicubic2d.out]) @aten.upsample_bicubic2d.default.py_impl(DispatchKey.Autograd) +@out_wrapper() @pw_cast_for_opmath def upsample_bicubic2d_default( input: Tensor, diff --git a/torch/_prims_common/wrappers.py b/torch/_prims_common/wrappers.py index 8b7515bbca5..9057edc8759 100644 --- a/torch/_prims_common/wrappers.py +++ b/torch/_prims_common/wrappers.py @@ -170,9 +170,13 @@ def _resize_output_check(out: TensorLikeType, shape: ShapeType): # TODO: handle tuples of tensors -def _maybe_resize_out(out: TensorLikeType, shape: ShapeType): +def _maybe_resize_out( + out: TensorLikeType, + shape: ShapeType, + memory_format: Optional[torch.memory_format] = None, +): if _resize_output_check(out, shape): - return out.resize_(shape) + return out.resize_(shape, memory_format=memory_format) else: return out @@ -205,7 +209,12 @@ def _safe_copy_out( return copy_to.copy_(copy_from) -def out_wrapper(*out_names: str, exact_dtype: bool = False, pass_is_out: bool = False): +def out_wrapper( + *out_names: str, + exact_dtype: bool = False, + pass_is_out: bool = False, + preserve_memory_format=False, +): # The wrapped function needs to convert the output parameters to ensure # compatibility between the Python API (which always uses "out" as the # parameter name and may be a tuple) and the Aten API (which may have @@ -219,6 +228,9 @@ def out_wrapper(*out_names: str, exact_dtype: bool = False, pass_is_out: bool = is_tensor = len(out_names) == 1 + def maybe_compute_memory_format(t): + return utils.suggest_memory_format(t) if preserve_memory_format else None + def _out_wrapper(fn: Callable) -> Callable: """ Adds the out parameter to a Python reference. @@ -277,7 +289,9 @@ def out_wrapper(*out_names: str, exact_dtype: bool = False, pass_is_out: bool = if is_tensor: assert isinstance(out, TensorLike) # These two operations are done in-place - _maybe_resize_out(out, result.shape) + _maybe_resize_out( + out, result.shape, maybe_compute_memory_format(result) + ) _safe_copy_out(copy_from=result, copy_to=out, exact_dtype=exact_dtype) # type: ignore[arg-type] else: assert isinstance(out, Tuple) # type: ignore[arg-type] @@ -287,7 +301,7 @@ def out_wrapper(*out_names: str, exact_dtype: bool = False, pass_is_out: bool = ) for r, o in zip(result, out): # These two operations are done in-place - _maybe_resize_out(o, r.shape) + _maybe_resize_out(o, r.shape, maybe_compute_memory_format(r)) _safe_copy_out(copy_from=r, copy_to=o, exact_dtype=exact_dtype) # type: ignore[arg-type] else: out = result