Add out wrappers to some decompositions (#115437)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115437
Approved by: https://github.com/lezcano
Author:    Isuru Fernando
Date:      2024-04-23 06:26:11 +00:00
Committer: PyTorch MergeBot
Parent:    e0c5113dec
Commit:    edcd968b51

5 changed files with 48 additions and 21 deletions
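
In short, the Python decompositions for the nearest-neighbor and bicubic upsampling ops now also cover their out= overloads: each .out overload is registered alongside .default, and an out_wrapper handles resizing and copying into the user-supplied tensor. A minimal sketch of the user-visible behavior (input and sizes invented for illustration; torch._C._nn.upsample_nearest2d is the same private binding the updated test below exercises):

    import torch

    x = torch.arange(16.0).reshape(1, 1, 4, 4)
    out = torch.empty(1, 1, 8, 8)

    # upsample_nearest2d.out now has a registered decomposition, so code paths
    # that rely on decompositions (meta tensors, compiled graphs) no longer
    # need to special-case the out= variant.
    torch._C._nn.upsample_nearest2d(x, (8, 8), out=out)
    assert out.shape == (1, 1, 8, 8)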

File 1 of 5

@@ -22,6 +22,9 @@ aten::_softmax
 aten::_softmax.out
 aten::_to_copy
 aten::_to_copy.out
+aten::_upsample_nearest_exact1d.out
+aten::_upsample_nearest_exact2d.out
+aten::_upsample_nearest_exact3d.out
 aten::abs
 aten::abs.out
 aten::abs_
@@ -508,6 +511,10 @@ aten::uniform.out
 aten::uniform_
 aten::unsqueeze
 aten::upsample_bicubic2d
+aten::upsample_bicubic2d.out
+aten::upsample_nearest1d.out
+aten::upsample_nearest2d.out
+aten::upsample_nearest3d.out
 aten::var.correction
 aten::var.correction_out
 aten::var_mean.correction

File 2 of 5

@@ -609,13 +609,10 @@ aten::_upsample_bilinear2d_aa
 aten::_upsample_bilinear2d_aa.out
 aten::_upsample_bilinear2d_aa_backward
 aten::_upsample_bilinear2d_aa_backward.grad_input
-aten::_upsample_nearest_exact1d.out
 aten::_upsample_nearest_exact1d_backward
 aten::_upsample_nearest_exact1d_backward.grad_input
-aten::_upsample_nearest_exact2d.out
 aten::_upsample_nearest_exact2d_backward
 aten::_upsample_nearest_exact2d_backward.grad_input
-aten::_upsample_nearest_exact3d.out
 aten::_upsample_nearest_exact3d_backward
 aten::_upsample_nearest_exact3d_backward.grad_input
 aten::_use_cudnn_ctc_loss
@@ -1331,20 +1328,16 @@ aten::unsafe_split_with_sizes.out
 aten::unsqueeze_
 aten::unsqueeze_copy
 aten::unsqueeze_copy.out
-aten::upsample_bicubic2d.out
 aten::upsample_bicubic2d_backward
 aten::upsample_bicubic2d_backward.grad_input
 aten::upsample_bilinear2d_backward
 aten::upsample_bilinear2d_backward.grad_input
 aten::upsample_linear1d_backward
 aten::upsample_linear1d_backward.grad_input
-aten::upsample_nearest1d.out
 aten::upsample_nearest1d_backward
 aten::upsample_nearest1d_backward.grad_input
-aten::upsample_nearest2d.out
 aten::upsample_nearest2d_backward
 aten::upsample_nearest2d_backward.grad_input
-aten::upsample_nearest3d.out
 aten::upsample_nearest3d_backward
 aten::upsample_nearest3d_backward.grad_input
 aten::upsample_trilinear3d_backward
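
Taken together, these two expect-file changes (they appear to be the HasDecompTest expect lists) are two sides of the same move: the first list, of ops covered by core decompositions, gains the seven .out overloads, while the second list, of ops still lacking a decomposition, drops the same seven entries.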

File 3 of 5

@@ -8847,7 +8847,7 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
         out = torch.empty(4, 3, 16, 16, device='meta', dtype=torch.double)
         self.assertExpectedRaisesInline(
             RuntimeError, lambda: torch._C._nn.upsample_nearest2d(x, (16, 16), out=out),
-            """Expected out tensor to have dtype float, but got double instead"""
+            """Expected out tensor to have dtype torch.float32 but got torch.float64 instead"""
         )

         # Complain if out device mismatch
@@ -8857,7 +8857,7 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
         if not TEST_WITH_TORCHINDUCTOR:
             self.assertExpectedRaisesInline(
                 RuntimeError, lambda: torch._C._nn.upsample_nearest2d(x, (16, 16), out=out),
-                """Expected out tensor to have device meta, but got cpu instead"""
+                """Attempting to copy from device meta to device cpu, but cross-device copies are not allowed!"""
             )

     def test_add_meta_scalar(self):
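
Both expected strings change because the out= validation now happens in the Python out_wrapper instead of the structured C++ kernel: the dtype check comes from _safe_copy_out with exact_dtype=True (which prints torch.float32/torch.float64-style names), and the device mismatch now surfaces as the generic cross-device copy error raised by Tensor.copy_. A small sketch reproducing both messages on the meta device (the construction of x is assumed here; in the test it is created earlier in the file):

    import torch

    x = torch.empty(4, 3, 8, 8, device="meta")

    # dtype mismatch: rejected when the wrapper copies the result into out
    out = torch.empty(4, 3, 16, 16, device="meta", dtype=torch.double)
    try:
        torch._C._nn.upsample_nearest2d(x, (16, 16), out=out)
    except RuntimeError as e:
        print(e)  # Expected out tensor to have dtype torch.float32 but got torch.float64 instead

    # device mismatch: surfaces as a cross-device copy error
    out = torch.empty(4, 3, 16, 16, device="cpu")
    try:
        torch._C._nn.upsample_nearest2d(x, (16, 16), out=out)
    except RuntimeError as e:
        print(e)  # Attempting to copy from device meta to device cpu, but cross-device copies are not allowed!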

File 4 of 5

@@ -2720,9 +2720,10 @@ def _compute_upsample_nearest_indices(input, output_size, scales, exact=False):
     return indices


-@register_decomposition(aten.upsample_nearest1d.default)
+@register_decomposition([aten.upsample_nearest1d.default, aten.upsample_nearest1d.out])
 @aten.upsample_nearest1d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
 @aten.upsample_nearest1d.default.py_impl(DispatchKey.Autograd)
+@out_wrapper(preserve_memory_format=True, exact_dtype=True)
 def upsample_nearest1d(
     input: Tensor,
     output_size: List[int],
@@ -2731,9 +2732,12 @@ def upsample_nearest1d(
     return _upsample_nearest(input, output_size, [scales])


-@register_decomposition(aten._upsample_nearest_exact1d.default)
+@register_decomposition(
+    [aten._upsample_nearest_exact1d.default, aten._upsample_nearest_exact1d.out]
+)
 @aten._upsample_nearest_exact1d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
 @aten._upsample_nearest_exact1d.default.py_impl(DispatchKey.Autograd)
+@out_wrapper(preserve_memory_format=True, exact_dtype=True)
 def upsample_nearest_exact1d(
     input: Tensor,
     output_size: List[int],
@@ -2742,9 +2746,10 @@ def upsample_nearest_exact1d(
     return _upsample_nearest(input, output_size, [scales], exact=True)


-@register_decomposition(aten.upsample_nearest2d.default)
+@register_decomposition([aten.upsample_nearest2d.default, aten.upsample_nearest2d.out])
 @aten.upsample_nearest2d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
 @aten.upsample_nearest2d.default.py_impl(DispatchKey.Autograd)
+@out_wrapper(preserve_memory_format=True, exact_dtype=True)
 def upsample_nearest2d(
     input: Tensor,
     output_size: List[int],
@@ -2754,9 +2759,12 @@ def upsample_nearest2d(
     return _upsample_nearest(input, output_size, [scales_h, scales_w])


-@register_decomposition(aten._upsample_nearest_exact2d.default)
+@register_decomposition(
+    [aten._upsample_nearest_exact2d.default, aten._upsample_nearest_exact2d.out]
+)
 @aten._upsample_nearest_exact2d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
 @aten._upsample_nearest_exact2d.default.py_impl(DispatchKey.Autograd)
+@out_wrapper(preserve_memory_format=True, exact_dtype=True)
 def _upsample_nearest_exact2d(
     input: Tensor,
     output_size: List[int],
@@ -2766,9 +2774,10 @@ def _upsample_nearest_exact2d(
     return _upsample_nearest(input, output_size, [scales_h, scales_w], exact=True)


-@register_decomposition(aten.upsample_nearest3d.default)
+@register_decomposition([aten.upsample_nearest3d.default, aten.upsample_nearest3d.out])
 @aten.upsample_nearest3d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
 @aten.upsample_nearest3d.default.py_impl(DispatchKey.Autograd)
+@out_wrapper(preserve_memory_format=True, exact_dtype=True)
 def upsample_nearest3d(
     input: Tensor,
     output_size: List[int],
@@ -2779,9 +2788,12 @@ def upsample_nearest3d(
     return _upsample_nearest(input, output_size, [scales_d, scales_h, scales_w])


-@register_decomposition(aten._upsample_nearest_exact3d.default)
+@register_decomposition(
+    [aten._upsample_nearest_exact3d.default, aten._upsample_nearest_exact3d.out]
+)
 @aten._upsample_nearest_exact3d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
 @aten._upsample_nearest_exact3d.default.py_impl(DispatchKey.Autograd)
+@out_wrapper(preserve_memory_format=True, exact_dtype=True)
 def _upsample_nearest_exact3d(
     input: Tensor,
     output_size: List[int],
@@ -4251,8 +4263,9 @@ def matmul(tensor1, tensor2, *, is_out=False):
     torch._check(False, lambda: "both arguments to matmul need to be at least 1D")


-@register_decomposition(aten.upsample_bicubic2d.default)
+@register_decomposition([aten.upsample_bicubic2d.default, aten.upsample_bicubic2d.out])
 @aten.upsample_bicubic2d.default.py_impl(DispatchKey.Autograd)
+@out_wrapper()
 @pw_cast_for_opmath
 def upsample_bicubic2d_default(
     input: Tensor,
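
The pattern is the same for every op in this file (likely torch/_decomp/decompositions.py, inferred from the function names): the .out overload is registered to the same decomposition, and out_wrapper turns the purely functional body into one that also accepts out=. A self-contained toy version of that idea (toy_out_wrapper, twice, and buf are made up for illustration; the real wrapper also handles tuple outputs, resize warnings, and memory formats):

    import torch
    from typing import Callable, Optional

    def toy_out_wrapper(fn: Callable) -> Callable:
        # Toy version of out_wrapper: run the functional op, then validate
        # dtype, resize if needed, and copy the result into out in-place.
        def wrapped(*args, out: Optional[torch.Tensor] = None, **kwargs):
            result = fn(*args, **kwargs)
            if out is None:
                return result
            if out.dtype != result.dtype:  # mirrors exact_dtype=True
                raise RuntimeError(
                    f"Expected out tensor to have dtype {result.dtype} "
                    f"but got {out.dtype} instead"
                )
            if out.shape != result.shape:
                out.resize_(result.shape)
            return out.copy_(result)

        return wrapped

    @toy_out_wrapper
    def twice(t: torch.Tensor) -> torch.Tensor:
        return t * 2

    buf = torch.empty(3)
    twice(torch.ones(3), out=buf)
    print(buf)  # tensor([2., 2., 2.])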

File 5 of 5

@@ -170,9 +170,13 @@ def _resize_output_check(out: TensorLikeType, shape: ShapeType):


 # TODO: handle tuples of tensors
-def _maybe_resize_out(out: TensorLikeType, shape: ShapeType):
+def _maybe_resize_out(
+    out: TensorLikeType,
+    shape: ShapeType,
+    memory_format: Optional[torch.memory_format] = None,
+):
     if _resize_output_check(out, shape):
-        return out.resize_(shape)
+        return out.resize_(shape, memory_format=memory_format)
     else:
         return out

@@ -205,7 +209,12 @@ def _safe_copy_out(
     return copy_to.copy_(copy_from)


-def out_wrapper(*out_names: str, exact_dtype: bool = False, pass_is_out: bool = False):
+def out_wrapper(
+    *out_names: str,
+    exact_dtype: bool = False,
+    pass_is_out: bool = False,
+    preserve_memory_format=False,
+):
     # The wrapped function needs to convert the output parameters to ensure
     # compatibility between the Python API (which always uses "out" as the
     # parameter name and may be a tuple) and the Aten API (which may have
@@ -219,6 +228,9 @@ def out_wrapper(*out_names: str, exact_dtype: bool = False, pass_is_out: bool =

     is_tensor = len(out_names) == 1

+    def maybe_compute_memory_format(t):
+        return utils.suggest_memory_format(t) if preserve_memory_format else None
+
     def _out_wrapper(fn: Callable) -> Callable:
         """
         Adds the out parameter to a Python reference.
@@ -277,7 +289,9 @@ def out_wrapper(*out_names: str, exact_dtype: bool = False, pass_is_out: bool =
                 if is_tensor:
                     assert isinstance(out, TensorLike)
                     # These two operations are done in-place
-                    _maybe_resize_out(out, result.shape)
+                    _maybe_resize_out(
+                        out, result.shape, maybe_compute_memory_format(result)
+                    )
                     _safe_copy_out(copy_from=result, copy_to=out, exact_dtype=exact_dtype)  # type: ignore[arg-type]
                 else:
                     assert isinstance(out, Tuple)  # type: ignore[arg-type]
@@ -287,7 +301,7 @@ def out_wrapper(*out_names: str, exact_dtype: bool = False, pass_is_out: bool =
                     )
                     for r, o in zip(result, out):
                         # These two operations are done in-place
-                        _maybe_resize_out(o, r.shape)
+                        _maybe_resize_out(o, r.shape, maybe_compute_memory_format(r))
                         _safe_copy_out(copy_from=r, copy_to=o, exact_dtype=exact_dtype)  # type: ignore[arg-type]
             else:
                 out = result
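
The preserve_memory_format flag exists because a bare resize_ gives out contiguous strides even when the decomposed result is channels_last; the wrapper now consults utils.suggest_memory_format on the result and forwards the answer to resize_. A quick illustration of the Tensor.resize_ behavior this relies on (values chosen for illustration):

    import torch

    # A channels_last result should land in a channels_last out tensor.
    out = torch.empty(0)
    out.resize_((2, 3, 4, 4), memory_format=torch.channels_last)
    print(out.is_contiguous(memory_format=torch.channels_last))  # True

    # Without the memory_format argument, resize_ produces row-major strides,
    # dropping the layout carried by the decomposed result.
    out2 = torch.empty(0)
    out2.resize_((2, 3, 4, 4))
    print(out2.is_contiguous())  # True (contiguous, not channels_last)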