diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect
index 64223bc644d..5694725a291 100644
--- a/test/expect/HasDecompTest.test_has_decomposition.expect
+++ b/test/expect/HasDecompTest.test_has_decomposition.expect
@@ -1105,6 +1105,7 @@ aten::replication_pad1d
 aten::replication_pad1d.out
 aten::replication_pad1d_backward
 aten::replication_pad1d_backward.grad_input
+aten::replication_pad2d
 aten::replication_pad2d.out
 aten::replication_pad2d_backward
 aten::replication_pad2d_backward.grad_input
diff --git a/test/onnx/expect/TestOperators.test_pad.expect b/test/onnx/expect/TestOperators.test_pad.expect
index 862e80061d2..1319a22713d 100644
--- a/test/onnx/expect/TestOperators.test_pad.expect
+++ b/test/onnx/expect/TestOperators.test_pad.expect
@@ -199,7 +199,7 @@ graph {
     }
   }
   node {
-    input: "input"
+    input: "onnx::Pad_0"
     input: "onnx::Pad_22"
     output: "23"
     name: "Pad_23"
@@ -212,7 +212,7 @@ graph {
   }
   name: "main_graph"
   input {
-    name: "input"
+    name: "onnx::Pad_0"
     type {
       tensor_type {
         elem_type: 1
diff --git a/test/test_fx.py b/test/test_fx.py
index f06aba8cc0c..fa63e79cb46 100644
--- a/test/test_fx.py
+++ b/test/test_fx.py
@@ -4208,7 +4208,7 @@ class TestFunctionalTracing(JitTestCase):
         "linear": BUILT_IN_FUNC,
         "logsigmoid": BUILT_IN_FUNC,
         "one_hot": BUILT_IN_FUNC,
-        "pad": ARG_TYPE_MISMATCH,
+        "pad": BUILT_IN_FUNC,
         "pairwise_distance": BUILT_IN_FUNC,
         "pdist": BUILT_IN_FUNC,
         "pixel_shuffle": BUILT_IN_FUNC,
diff --git a/test/test_torch.py b/test/test_torch.py
index 1c3ea80a9aa..e563ee2815e 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -1514,33 +1514,6 @@ else:
             'upsample_bilinear2d_backward_out_cuda',
             torch.device(device).type == 'cuda')
 
-    @skipIfTorchInductor("aot-autograd issue")
-    def test_deterministic_replication_pad2d(self, device):
-        test_cases = [
-            # size, padding
-            [(1, 2, 4, 4), (0, 0, 0, 0)],
-            [(1, 2, 4, 4), (3, 4, 5, 6)],
-            [(4, 3, 5, 10), (-9, 4, 5, 6)],
-            [(3, 8, 7), (0, 0, 0, 0)],
-            [(3, 8, 7), (-4, -2, -2, -3)],
-            [(3, 8, 7), (4, 3, 2, 7)],
-        ]
-
-        for size, padding in test_cases:
-            input = torch.randn(*size, device=device, requires_grad=True)
-            grad = None
-            with DeterministicGuard(True):
-                res = torch.nn.functional.pad(
-                    input,
-                    padding,
-                    mode='replicate')
-                res.backward(torch.ones_like(res))
-                if grad is None:
-                    grad = input.grad
-                else:
-                    self.assertEqual(grad, input.grad, atol=0, rtol=0)
-                input.grad = None
-
     @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
     def test_deterministic_interpolate_bilinear(self, device):
         input = torch.randn(1, 2, 4, 4, device=device, requires_grad=True)
@@ -1649,25 +1622,11 @@ else:
         res = module(input)
 
         grad = torch.ones_like(res)
-        # Nondeterministic alert should only be raised if the forward call was
-        # nondeterministic
         self.check_nondeterministic_alert(
             lambda: res.backward(grad, retain_graph=True),
             'replication_pad2d_backward_cuda',
             torch.device(device).type == 'cuda')
 
-        with DeterministicGuard(True):
-            res = module(input)
-
-            grad = torch.ones_like(res)
-
-            # If the forward call was deterministic, nondeterministic alert should
-            # not be raised
-            self.check_nondeterministic_alert(
-                lambda: res.backward(grad, retain_graph=True),
-                'replication_pad2d_backward_cuda',
-                False)
-
     @skipIfMps
     @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
     def test_nondeterministic_alert_ReplicationPad3d(self, device):
diff --git a/torch/__init__.py b/torch/__init__.py
index 0eaf86277b5..a29668a5673 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -754,7 +754,6 @@ def use_deterministic_algorithms(mode: builtins.bool, *, warn_only: builtins.boo
         * :class:`torch.nn.ConvTranspose1d` when called on CUDA tensor
         * :class:`torch.nn.ConvTranspose2d` when called on CUDA tensor
         * :class:`torch.nn.ConvTranspose3d` when called on CUDA tensor
-        * :class:`torch.nn.ReplicationPad2d` when attempting to differentiate a CUDA tensor
         * :func:`torch.bmm` when called on sparse-dense CUDA tensors
         * :func:`torch.Tensor.__getitem__` when attempting to differentiate a CPU tensor
           and the index is a list of tensors
@@ -797,6 +796,7 @@ def use_deterministic_algorithms(mode: builtins.bool, *, warn_only: builtins.boo
         * :class:`torch.nn.ReflectionPad2d` when attempting to differentiate a CUDA tensor
         * :class:`torch.nn.ReflectionPad3d` when attempting to differentiate a CUDA tensor
         * :class:`torch.nn.ReplicationPad1d` when attempting to differentiate a CUDA tensor
+        * :class:`torch.nn.ReplicationPad2d` when attempting to differentiate a CUDA tensor
         * :class:`torch.nn.ReplicationPad3d` when attempting to differentiate a CUDA tensor
         * :class:`torch.nn.NLLLoss` when called on a CUDA tensor
         * :class:`torch.nn.CTCLoss` when attempting to differentiate a CUDA tensor
diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py
index 2ab501c853c..38c94fb8eae 100644
--- a/torch/_decomp/__init__.py
+++ b/torch/_decomp/__init__.py
@@ -375,7 +375,6 @@ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]:
             aten.rad2deg_,
             aten.renorm,
             aten.renorm_,
-            aten.replication_pad2d,
             aten.rot90,
             aten.rrelu_with_noise,
             aten.rrelu_with_noise_,
diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py
index 5bb882cbc62..645bc3259cd 100644
--- a/torch/_decomp/decompositions.py
+++ b/torch/_decomp/decompositions.py
@@ -3344,96 +3344,6 @@ def upsample_bilinear2d(
     return result
 
 
-@register_decomposition(aten.replication_pad2d.default)
-@pw_cast_for_opmath
-def replication_pad2d(input: Tensor, padding: List[int]) -> Tensor:
-    pad_left = padding[0]
-    pad_right = padding[1]
-    pad_top = padding[2]
-    pad_bottom = padding[3]
-
-    # If all of the padding values are non-negative, then the following tensors
-    # are all equal to the input. But if any padding values are negative, we
-    # have to remove the appropriate rows and columns from the input.
-    # `input_mid` has all negative padding removed from it. `input_mid_tb` has
-    # negative left and right padding removed from it. `input_mid_lr` has
-    # negative top and bottom padding removed from it.
-    input_mid = input
-    input_mid_tb = input
-    input_mid_lr = input
-
-    if pad_left < 0:
-        input_mid = input_mid[..., -pad_left:]
-        input_mid_tb = input_mid_tb[..., -pad_left:]
-        pad_left = 0
-
-    if pad_right < 0:
-        input_mid = input_mid[..., :pad_right]
-        input_mid_tb = input_mid_tb[..., :pad_right]
-        pad_right = 0
-
-    if pad_top < 0:
-        input_mid = input_mid[..., -pad_top:, :]
-        input_mid_lr = input_mid_lr[..., -pad_top:, :]
-        pad_top = 0
-
-    if pad_bottom < 0:
-        input_mid = input_mid[..., :pad_bottom, :]
-        input_mid_lr = input_mid_lr[..., :pad_bottom, :]
-        pad_bottom = 0
-
-    batch_dims_no_repeat = (1,) * (input.dim() - 2)
-
-    repeat_top_left = batch_dims_no_repeat + (pad_top, pad_left)
-    repeat_top_middle = batch_dims_no_repeat + (pad_top, 1)
-    repeat_top_right = batch_dims_no_repeat + (pad_top, pad_right)
-
-    top_rows = torch.cat(
-        [
-            # top left
-            input[..., [0], :][..., [0]].repeat(repeat_top_left),
-            # top middle
-            input_mid_tb[..., [0], :].repeat(repeat_top_middle),
-            # top right
-            input[..., [0], :][..., [-1]].repeat(repeat_top_right),
-        ],
-        dim=-1,
-    )
-
-    repeat_middle_left = batch_dims_no_repeat + (1, pad_left)
-    repeat_middle_right = batch_dims_no_repeat + (1, pad_right)
-
-    middle_rows = torch.cat(
-        [
-            # middle left
-            input_mid_lr[..., [0]].repeat(repeat_middle_left),
-            # middle middle
-            input_mid,
-            # middle right
-            input_mid_lr[..., [-1]].repeat(repeat_middle_right),
-        ],
-        dim=-1,
-    )
-
-    repeat_bottom_left = batch_dims_no_repeat + (pad_bottom, pad_left)
-    repeat_bottom_middle = batch_dims_no_repeat + (pad_bottom, 1)
-    repeat_bottom_right = batch_dims_no_repeat + (pad_bottom, pad_right)
-
-    bottom_rows = torch.cat(
-        [
-            # bottom left
-            input[..., [-1], :][..., [0]].repeat(repeat_bottom_left),
-            # bottom middle
-            input_mid_tb[..., [-1], :].repeat(repeat_bottom_middle),
-            # bottom right
-            input[..., [-1], :][..., [-1]].repeat(repeat_bottom_right),
-        ],
-        dim=-1,
-    )
-
-    return torch.cat([top_rows, middle_rows, bottom_rows], dim=-2)
-
-
 # We should be applying decompositions after all transformations
 @register_decomposition(aten.is_same_size.default)
 def is_same_size(a: Tensor, b: Tensor) -> bool:
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 32b07280e58..b1789decf16 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -2235,7 +2235,6 @@ make_fallback(aten.max_pool3d_with_indices_backward)
 make_fallback(aten._pdist_backward)
 make_fallback(aten.reflection_pad1d_backward)
 make_fallback(aten.replication_pad1d_backward)
-make_fallback(aten.replication_pad2d_backward)
 make_fallback(aten.soft_margin_loss_backward, warn=False)
 make_fallback(aten.linalg_pinv.atol_rtol_tensor)
 make_fallback(aten.segment_reduce.default)
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index 707fc04f7e7..f2c888f2cb2 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -4418,7 +4418,8 @@ def affine_grid(theta: Tensor, size: List[int], align_corners: Optional[bool] =
     return torch.affine_grid_generator(theta, size, align_corners)
 
 
-def pad(input: Tensor, pad: List[int], mode: str = "constant", value: Optional[float] = None) -> Tensor:
+pad = _add_docstr(
+    torch._C._nn.pad,
     r"""
 pad(input, pad, mode="constant", value=None) -> Tensor
 
@@ -4479,21 +4480,7 @@ Examples::
     >>> print(out.size())
     torch.Size([3, 9, 7, 3])
 
-"""
-    if has_torch_function_unary(input):
-        return handle_torch_function(
-            torch.nn.functional.pad, (input,), input, pad, mode=mode, value=value)
-    if not torch.jit.is_scripting():
-        if torch.are_deterministic_algorithms_enabled() and input.is_cuda:
-            if len(pad) == 4 and (input.dim() == 3 or input.dim() == 4) and mode == 'replicate':
-                # Use slow decomp whose backward will be in terms of index_put.
-                # importlib is required because the import cannot be top level
-                # (cycle) and cannot be nested (TS doesn't support)
-                return importlib.import_module('torch._decomp.decompositions').replication_pad2d(
-                    input, pad
-                )
-    return torch._C._nn.pad(input, pad, mode, value)
-
+""")
 # TODO: Fix via https://github.com/pytorch/pytorch/issues/75798
 pad.__module__ = "torch.nn.functional"
 
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 8274f9e3741..d704534b00b 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -5596,19 +5596,6 @@ def sample_inputs_nn_pad(op_info, device, dtype, requires_grad, mode, **kwargs):
     for shape, pad in cases:
         yield SampleInput(make_inp(shape), args=(pad, mode, pad_value))
 
-def sample_inputs_nn_pad_replicate_negative(op_info, device, dtype, requires_grad, **kwargs):
-    cases: tuple = (
-        ((5, 3, 4, 4), (-4, 5, 0, 0)),
-        ((6, 2, 4, 4), (0, 0, 2, -4)),
-        ((5, 6, 4, 4), (5, -4, -4, 3)),
-        ((4, 2, 5, 5), (-2, -1, 4, 6)),
-        ((2, 6, 5, 5), (8, -1, -1, -3)),
-        ((8, 1, 5, 5), (-2, -1, -1, -3)),
-    )
-    make_inp = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
-
-    for shape, pad in cases:
-        yield SampleInput(make_inp(shape), args=(pad, 'replicate'))
 
 def sample_inputs_constant_pad_nd(op_info, device, dtype, *args, **kwargs):
     # Inherit sample inputs from nn.pad, but transform them to fit
@@ -13423,23 +13410,6 @@ op_db: List[OpInfo] = [
            ),
            gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
            supports_out=False),
-    OpInfo('nn.functional.pad',
-           variant_test_name='replicate_negative',
-           supports_forward_ad=True,
-           supports_fwgrad_bwgrad=True,
-           dtypes=all_types_and_complex_and(torch.bfloat16),
-           dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16),
-           sample_inputs_func=sample_inputs_nn_pad_replicate_negative,
-           skips=(
-               # Doesn't have a corresponding aten operator.
-               # RuntimeError: falseINTERNAL ASSERT FAILED at
-               # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, please report a bug to PyTorch.
-               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
-               # Some negative padding cases cause a segfault on MPS
-               DecorateInfo(unittest.skip("Not fully supported on MPS"), 'TestConsistency'),
-           ),
-           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
-           supports_out=False),
     OpInfo('nn.functional.pad',
            variant_test_name='circular',
            dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half),