Revert "Add decomp for replication_pad2d and use for CUDA deterministic (#111590)"

This reverts commit f1286161a6.

Reverted https://github.com/pytorch/pytorch/pull/111590 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it is failing the XLA job. The job is also failing on the PR, but the log classifier failed to find the failing test, which led to it being wrongly marked as flaky ([comment](https://github.com/pytorch/pytorch/pull/111590#issuecomment-1833004794))
PyTorch MergeBot 2023-11-30 02:28:14 +00:00
parent 9f3ec2ad45
commit 013675ff59
10 changed files with 8 additions and 183 deletions

View File

@@ -1105,6 +1105,7 @@ aten::replication_pad1d
 aten::replication_pad1d.out
 aten::replication_pad1d_backward
 aten::replication_pad1d_backward.grad_input
+aten::replication_pad2d
 aten::replication_pad2d.out
 aten::replication_pad2d_backward
 aten::replication_pad2d_backward.grad_input

View File

@@ -199,7 +199,7 @@ graph {
     }
   }
   node {
-    input: "input"
+    input: "onnx::Pad_0"
     input: "onnx::Pad_22"
     output: "23"
     name: "Pad_23"
@@ -212,7 +212,7 @@ graph {
   }
   name: "main_graph"
   input {
-    name: "input"
+    name: "onnx::Pad_0"
     type {
       tensor_type {
        elem_type: 1

View File

@@ -4208,7 +4208,7 @@ class TestFunctionalTracing(JitTestCase):
         "linear": BUILT_IN_FUNC,
         "logsigmoid": BUILT_IN_FUNC,
         "one_hot": BUILT_IN_FUNC,
-        "pad": ARG_TYPE_MISMATCH,
+        "pad": BUILT_IN_FUNC,
         "pairwise_distance": BUILT_IN_FUNC,
         "pdist": BUILT_IN_FUNC,
         "pixel_shuffle": BUILT_IN_FUNC,

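Note: the flip from ARG_TYPE_MISMATCH back to BUILT_IN_FUNC mirrors the torch.nn.functional.pad hunk further down: pad is again the C builtin rather than a Python function, so TestFunctionalTracing expects the built-in classification. A minimal sketch, not part of the diff, of tracing a concrete module that calls F.pad:

# Minimal sketch: FX records a concrete F.pad call as a single graph node.
import torch
import torch.fx
import torch.nn.functional as F

class PadModule(torch.nn.Module):
    def forward(self, x):
        # replicate-pad the last two dimensions by one element on each side
        return F.pad(x, (1, 1, 1, 1), mode="replicate")

traced = torch.fx.symbolic_trace(PadModule())
print(traced.graph)                           # the pad call shows up as one call_function node
print(traced(torch.randn(1, 2, 4, 4)).shape)  # torch.Size([1, 2, 6, 6])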
View File

@@ -1514,33 +1514,6 @@ else:
                 'upsample_bilinear2d_backward_out_cuda',
                 torch.device(device).type == 'cuda')
 
-    @skipIfTorchInductor("aot-autograd issue")
-    def test_deterministic_replication_pad2d(self, device):
-        test_cases = [
-            # size, padding
-            [(1, 2, 4, 4), (0, 0, 0, 0)],
-            [(1, 2, 4, 4), (3, 4, 5, 6)],
-            [(4, 3, 5, 10), (-9, 4, 5, 6)],
-            [(3, 8, 7), (0, 0, 0, 0)],
-            [(3, 8, 7), (-4, -2, -2, -3)],
-            [(3, 8, 7), (4, 3, 2, 7)],
-        ]
-
-        for size, padding in test_cases:
-            input = torch.randn(*size, device=device, requires_grad=True)
-            grad = None
-            with DeterministicGuard(True):
-                res = torch.nn.functional.pad(
-                    input,
-                    padding,
-                    mode='replicate')
-                res.backward(torch.ones_like(res))
-                if grad is None:
-                    grad = input.grad
-                else:
-                    self.assertEqual(grad, input.grad, atol=0, rtol=0)
-                input.grad = None
-
     @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
     def test_deterministic_interpolate_bilinear(self, device):
         input = torch.randn(1, 2, 4, 4, device=device, requires_grad=True)
@@ -1649,25 +1622,11 @@ else:
         res = module(input)
         grad = torch.ones_like(res)
 
-        # Nondeterministic alert should only be raised if the forward call was
-        # nondeterministic
         self.check_nondeterministic_alert(
             lambda: res.backward(grad, retain_graph=True),
             'replication_pad2d_backward_cuda',
             torch.device(device).type == 'cuda')
 
-        with DeterministicGuard(True):
-            res = module(input)
-            grad = torch.ones_like(res)
-
-            # If the forward call was deterministic, nondeterministic alert should
-            # not be raised
-            self.check_nondeterministic_alert(
-                lambda: res.backward(grad, retain_graph=True),
-                'replication_pad2d_backward_cuda',
-                False)
-
     @skipIfMps
     @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
     def test_nondeterministic_alert_ReplicationPad3d(self, device):

View File

@@ -754,7 +754,6 @@ def use_deterministic_algorithms(mode: builtins.bool, *, warn_only: builtins.boo
         * :class:`torch.nn.ConvTranspose1d` when called on CUDA tensor
         * :class:`torch.nn.ConvTranspose2d` when called on CUDA tensor
         * :class:`torch.nn.ConvTranspose3d` when called on CUDA tensor
-        * :class:`torch.nn.ReplicationPad2d` when attempting to differentiate a CUDA tensor
         * :func:`torch.bmm` when called on sparse-dense CUDA tensors
         * :func:`torch.Tensor.__getitem__` when attempting to differentiate a CPU tensor
           and the index is a list of tensors
@@ -797,6 +796,7 @@ def use_deterministic_algorithms(mode: builtins.bool, *, warn_only: builtins.boo
         * :class:`torch.nn.ReflectionPad2d` when attempting to differentiate a CUDA tensor
         * :class:`torch.nn.ReflectionPad3d` when attempting to differentiate a CUDA tensor
         * :class:`torch.nn.ReplicationPad1d` when attempting to differentiate a CUDA tensor
+        * :class:`torch.nn.ReplicationPad2d` when attempting to differentiate a CUDA tensor
         * :class:`torch.nn.ReplicationPad3d` when attempting to differentiate a CUDA tensor
         * :class:`torch.nn.NLLLoss` when called on a CUDA tensor
         * :class:`torch.nn.CTCLoss` when attempting to differentiate a CUDA tensor
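
Note: with this entry moved back to the "throws a RuntimeError" list, the backward of torch.nn.ReplicationPad2d on a CUDA tensor once again trips the nondeterministic alert instead of taking a deterministic path. A minimal sketch of the user-visible behavior, not part of the diff, assuming a CUDA build:

# Minimal sketch: the nondeterministic alert restored by this revert.
import torch

if torch.cuda.is_available():
    torch.use_deterministic_algorithms(True)
    pad = torch.nn.ReplicationPad2d((1, 1, 1, 1))
    x = torch.randn(1, 2, 4, 4, device="cuda", requires_grad=True)
    out = pad(x)
    try:
        out.backward(torch.ones_like(out))
    except RuntimeError as err:
        # message names replication_pad2d_backward_cuda
        print(err)
    torch.use_deterministic_algorithms(False)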

View File

@@ -375,7 +375,6 @@ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]:
         aten.rad2deg_,
         aten.renorm,
         aten.renorm_,
-        aten.replication_pad2d,
         aten.rot90,
         aten.rrelu_with_noise,
         aten.rrelu_with_noise_,
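
Note: with the entry removed, aten.replication_pad2d no longer appears in the core ATen decomposition table. A quick check, not part of the diff, using the core_aten_decompositions entry point shown in the hunk header:

# Minimal sketch: the overload is no longer in the core ATen decomposition table.
import torch
from torch._decomp import core_aten_decompositions

decomps = core_aten_decompositions()
print(torch.ops.aten.replication_pad2d.default in decomps)  # False after this revert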

View File

@@ -3344,96 +3344,6 @@ def upsample_bilinear2d(
     return result
 
 
-@register_decomposition(aten.replication_pad2d.default)
-@pw_cast_for_opmath
-def replication_pad2d(input: Tensor, padding: List[int]) -> Tensor:
-    pad_left = padding[0]
-    pad_right = padding[1]
-    pad_top = padding[2]
-    pad_bottom = padding[3]
-
-    # If all of the padding values are non-negative, then the following tensors
-    # are all equal to the input. But if any padding values are negative, we
-    # have to remove the appropriate rows and columns from the input.
-    # `input_mid` has all negative padding removed from it. `input_mid_tb` has
-    # negative left and right padding removed from it. `input_mid_lr` has
-    # negative top and bottom padding removed from it.
-    input_mid = input
-    input_mid_tb = input
-    input_mid_lr = input
-
-    if pad_left < 0:
-        input_mid = input_mid[..., -pad_left:]
-        input_mid_tb = input_mid_tb[..., -pad_left:]
-        pad_left = 0
-
-    if pad_right < 0:
-        input_mid = input_mid[..., :pad_right]
-        input_mid_tb = input_mid_tb[..., :pad_right]
-        pad_right = 0
-
-    if pad_top < 0:
-        input_mid = input_mid[..., -pad_top:, :]
-        input_mid_lr = input_mid_lr[..., -pad_top:, :]
-        pad_top = 0
-
-    if pad_bottom < 0:
-        input_mid = input_mid[..., :pad_bottom, :]
-        input_mid_lr = input_mid_lr[..., :pad_bottom, :]
-        pad_bottom = 0
-
-    batch_dims_no_repeat = (1,) * (input.dim() - 2)
-
-    repeat_top_left = batch_dims_no_repeat + (pad_top, pad_left)
-    repeat_top_middle = batch_dims_no_repeat + (pad_top, 1)
-    repeat_top_right = batch_dims_no_repeat + (pad_top, pad_right)
-    top_rows = torch.cat(
-        [
-            # top left
-            input[..., [0], :][..., [0]].repeat(repeat_top_left),
-            # top middle
-            input_mid_tb[..., [0], :].repeat(repeat_top_middle),
-            # top right
-            input[..., [0], :][..., [-1]].repeat(repeat_top_right),
-        ],
-        dim=-1,
-    )
-
-    repeat_middle_left = batch_dims_no_repeat + (1, pad_left)
-    repeat_middle_right = batch_dims_no_repeat + (1, pad_right)
-    middle_rows = torch.cat(
-        [
-            # middle left
-            input_mid_lr[..., [0]].repeat(repeat_middle_left),
-            # middle middle
-            input_mid,
-            # middle right
-            input_mid_lr[..., [-1]].repeat(repeat_middle_right),
-        ],
-        dim=-1,
-    )
-
-    repeat_bottom_left = batch_dims_no_repeat + (pad_bottom, pad_left)
-    repeat_bottom_middle = batch_dims_no_repeat + (pad_bottom, 1)
-    repeat_bottom_right = batch_dims_no_repeat + (pad_bottom, pad_right)
-    bottom_rows = torch.cat(
-        [
-            # bottom left
-            input[..., [-1], :][..., [0]].repeat(repeat_bottom_left),
-            # bottom middle
-            input_mid_tb[..., [-1], :].repeat(repeat_bottom_middle),
-            # bottom right
-            input[..., [-1], :][..., [-1]].repeat(repeat_bottom_right),
-        ],
-        dim=-1,
-    )
-
-    return torch.cat([top_rows, middle_rows, bottom_rows], dim=-2)
-
-
 # We should be applying decompositions after all transformations
 @register_decomposition(aten.is_same_size.default)
 def is_same_size(a: Tensor, b: Tensor) -> bool:
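
Note: the deleted decomposition above builds the output by concatenating repeated edge slices (top, middle, and bottom row blocks). A small sketch, not part of the diff, checking a simpler columns-then-rows construction against the native operator for a non-negative padding case:

# Minimal sketch: replicate padding as concatenation of repeated edge slices.
import torch
import torch.nn.functional as F

x = torch.arange(9.0).reshape(1, 1, 3, 3)
# pad_left=1, pad_right=2, pad_top=1, pad_bottom=0
out = F.pad(x, (1, 2, 1, 0), mode="replicate")

left = x[..., :1]                        # leftmost column, repeated once
right = x[..., -1:].repeat(1, 1, 1, 2)   # rightmost column, repeated twice
cols = torch.cat([left, x, right], dim=-1)
top = cols[..., :1, :]                   # top row of the width-padded tensor
manual = torch.cat([top, cols], dim=-2)

print(torch.equal(out, manual))  # True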

View File

@@ -2235,7 +2235,6 @@ make_fallback(aten.max_pool3d_with_indices_backward)
 make_fallback(aten._pdist_backward)
 make_fallback(aten.reflection_pad1d_backward)
 make_fallback(aten.replication_pad1d_backward)
-make_fallback(aten.replication_pad2d_backward)
 make_fallback(aten.soft_margin_loss_backward, warn=False)
 make_fallback(aten.linalg_pinv.atol_rtol_tensor)
 make_fallback(aten.segment_reduce.default)

View File

@@ -4418,7 +4418,8 @@ def affine_grid(theta: Tensor, size: List[int], align_corners: Optional[bool] =
     return torch.affine_grid_generator(theta, size, align_corners)
 
 
-def pad(input: Tensor, pad: List[int], mode: str = "constant", value: Optional[float] = None) -> Tensor:
+pad = _add_docstr(
+    torch._C._nn.pad,
     r"""
 pad(input, pad, mode="constant", value=None) -> Tensor
@@ -4479,21 +4480,7 @@ Examples::
     >>> print(out.size())
     torch.Size([3, 9, 7, 3])
 
-"""
+""")
-    if has_torch_function_unary(input):
-        return handle_torch_function(
-            torch.nn.functional.pad, (input,), input, pad, mode=mode, value=value)
-    if not torch.jit.is_scripting():
-        if torch.are_deterministic_algorithms_enabled() and input.is_cuda:
-            if len(pad) == 4 and (input.dim() == 3 or input.dim() == 4) and mode == 'replicate':
-                # Use slow decomp whose backward will be in terms of index_put.
-                # importlib is required because the import cannot be top level
-                # (cycle) and cannot be nested (TS doesn't support)
-                return importlib.import_module('torch._decomp.decompositions').replication_pad2d(
-                    input, pad
-                )
-    return torch._C._nn.pad(input, pad, mode, value)
 
 # TODO: Fix via https://github.com/pytorch/pytorch/issues/75798
 pad.__module__ = "torch.nn.functional"
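
Note: with the Python wrapper gone, torch.nn.functional.pad is once again the C builtin torch._C._nn.pad with an attached docstring, which is why the explicit __module__ assignment (and the linked TODO) remains. A quick check, not part of the diff:

# Minimal sketch: after the revert, F.pad is the annotated C builtin.
import torch
import torch.nn.functional as F

print(F.pad is torch._C._nn.pad)  # True: _add_docstr returns the object it annotates
print(F.pad.__module__)           # torch.nn.functional (set explicitly after the docstring)
print(F.pad.__doc__.lstrip().startswith("pad(input, pad"))  # True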

View File

@@ -5596,19 +5596,6 @@ def sample_inputs_nn_pad(op_info, device, dtype, requires_grad, mode, **kwargs):
     for shape, pad in cases:
         yield SampleInput(make_inp(shape), args=(pad, mode, pad_value))
 
-def sample_inputs_nn_pad_replicate_negative(op_info, device, dtype, requires_grad, **kwargs):
-    cases: tuple = (
-        ((5, 3, 4, 4), (-4, 5, 0, 0)),
-        ((6, 2, 4, 4), (0, 0, 2, -4)),
-        ((5, 6, 4, 4), (5, -4, -4, 3)),
-        ((4, 2, 5, 5), (-2, -1, 4, 6)),
-        ((2, 6, 5, 5), (8, -1, -1, -3)),
-        ((8, 1, 5, 5), (-2, -1, -1, -3)),
-    )
-    make_inp = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
-
-    for shape, pad in cases:
-        yield SampleInput(make_inp(shape), args=(pad, 'replicate'))
-
 
 def sample_inputs_constant_pad_nd(op_info, device, dtype, *args, **kwargs):
     # Inherit sample inputs from nn.pad, but transform them to fit
@@ -13423,23 +13410,6 @@ op_db: List[OpInfo] = [
           ),
           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
           supports_out=False),
-    OpInfo('nn.functional.pad',
-           variant_test_name='replicate_negative',
-           supports_forward_ad=True,
-           supports_fwgrad_bwgrad=True,
-           dtypes=all_types_and_complex_and(torch.bfloat16),
-           dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16),
-           sample_inputs_func=sample_inputs_nn_pad_replicate_negative,
-           skips=(
-               # Doesn't have a corresponding aten operator.
-               # RuntimeError: falseINTERNAL ASSERT FAILED at
-               # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, please report a bug to PyTorch.
-               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
-               # Some negative padding cases cause a segfault on MPS
-               DecorateInfo(unittest.skip("Not fully supported on MPS"), 'TestConsistency'),
-           ),
-           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
-           supports_out=False),
     OpInfo('nn.functional.pad',
            variant_test_name='circular',
            dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half),
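
Note: the removed replicate_negative variant exercised F.pad with mixed-sign padding in replicate mode, where a negative entry crops that edge before the positive entries replicate the remaining edges. A small illustration, not part of the diff:

# Minimal sketch: mixed-sign replicate padding crops, then replicates edges.
import torch
import torch.nn.functional as F

x = torch.arange(16.0).reshape(1, 1, 4, 4)
y = F.pad(x, (-2, 1, 0, 0), mode="replicate")  # crop 2 left columns, replicate 1 on the right

print(y.shape)     # torch.Size([1, 1, 4, 3])
print(y[0, 0, 0])  # tensor([2., 3., 3.]) -- columns 2, 3, then column 3 repeated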