Revert "Add decomp for replication_pad2d and use for CUDA deterministic (#111590)"
This reverts commit f1286161a6.
Reverted https://github.com/pytorch/pytorch/pull/111590 on behalf of https://github.com/huydhn due to Sorry for reverting your change, but it is failing the XLA job. The job is also failing on the PR, but the log classifier failed to find the failed test, which led to it being wrongly marked as flaky ([comment](https://github.com/pytorch/pytorch/pull/111590#issuecomment-1833004794))
This commit is contained in:
Parent: 9f3ec2ad45
Commit: 013675ff59
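The practical effect of this revert: differentiating a 2-D replicate pad on a CUDA tensor under deterministic mode goes back to raising an error instead of falling back to a decomposition. A minimal sketch of that restored behavior, assuming a CUDA device is available (illustrative, not part of the diff):

import torch
import torch.nn.functional as F

torch.use_deterministic_algorithms(True)
x = torch.randn(1, 2, 4, 4, device='cuda', requires_grad=True)
y = F.pad(x, (1, 1, 1, 1), mode='replicate')
try:
    # With the revert applied, the nondeterministic CUDA backward kernel is
    # used again, so deterministic mode raises here.
    y.backward(torch.ones_like(y))
except RuntimeError as e:
    print(e)  # mentions replication_pad2d_backward_cuda
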
@@ -1105,6 +1105,7 @@ aten::replication_pad1d
aten::replication_pad1d.out
aten::replication_pad1d_backward
aten::replication_pad1d_backward.grad_input
aten::replication_pad2d
aten::replication_pad2d.out
aten::replication_pad2d_backward
aten::replication_pad2d_backward.grad_input

@@ -199,7 +199,7 @@ graph {
    }
  }
  node {
    input: "input"
    input: "onnx::Pad_0"
    input: "onnx::Pad_22"
    output: "23"
    name: "Pad_23"

@@ -212,7 +212,7 @@ graph {
  }
  name: "main_graph"
  input {
    name: "input"
    name: "onnx::Pad_0"
    type {
      tensor_type {
        elem_type: 1

@@ -4208,7 +4208,7 @@ class TestFunctionalTracing(JitTestCase):
        "linear": BUILT_IN_FUNC,
        "logsigmoid": BUILT_IN_FUNC,
        "one_hot": BUILT_IN_FUNC,
        "pad": ARG_TYPE_MISMATCH,
        "pad": BUILT_IN_FUNC,
        "pairwise_distance": BUILT_IN_FUNC,
        "pdist": BUILT_IN_FUNC,
        "pixel_shuffle": BUILT_IN_FUNC,

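The classification flip above follows from how `pad` is defined after the revert: it is once again the C++ builtin `torch._C._nn.pad` with a docstring attached, rather than a Python-level function, which is why the FX tracing test expects BUILT_IN_FUNC again. A quick, hedged way to observe the difference (not part of the diff):

import torch.nn.functional as F

# With the revert applied, F.pad is the builtin binding again; on the
# reverted-out PR it was a plain Python function.
print(type(F.pad))  # expected: <class 'builtin_function_or_method'>
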
@@ -1514,33 +1514,6 @@ else:
            'upsample_bilinear2d_backward_out_cuda',
            torch.device(device).type == 'cuda')

    @skipIfTorchInductor("aot-autograd issue")
    def test_deterministic_replication_pad2d(self, device):
        test_cases = [
            # size, padding
            [(1, 2, 4, 4), (0, 0, 0, 0)],
            [(1, 2, 4, 4), (3, 4, 5, 6)],
            [(4, 3, 5, 10), (-9, 4, 5, 6)],
            [(3, 8, 7), (0, 0, 0, 0)],
            [(3, 8, 7), (-4, -2, -2, -3)],
            [(3, 8, 7), (4, 3, 2, 7)],
        ]

        for size, padding in test_cases:
            input = torch.randn(*size, device=device, requires_grad=True)
            grad = None
            with DeterministicGuard(True):
                res = torch.nn.functional.pad(
                    input,
                    padding,
                    mode='replicate')
                res.backward(torch.ones_like(res))
                if grad is None:
                    grad = input.grad
                else:
                    self.assertEqual(grad, input.grad, atol=0, rtol=0)
                input.grad = None

    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
    def test_deterministic_interpolate_bilinear(self, device):
        input = torch.randn(1, 2, 4, 4, device=device, requires_grad=True)

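For reference, several of the removed test cases use negative padding, which crops rather than pads; the shape arithmetic can be checked directly with the public API (illustrative, not part of the diff):

import torch
import torch.nn.functional as F

x = torch.randn(3, 8, 7)
# pad is (left, right, top, bottom); negative values remove columns/rows.
out = F.pad(x, (-4, -2, -2, -3), mode='replicate')
print(out.shape)  # torch.Size([3, 3, 1]): height 8 - 2 - 3, width 7 - 4 - 2
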
@@ -1649,25 +1622,11 @@ else:
        res = module(input)
        grad = torch.ones_like(res)

        # Nondeterministic alert should only be raised if the forward call was
        # nondeterministic
        self.check_nondeterministic_alert(
            lambda: res.backward(grad, retain_graph=True),
            'replication_pad2d_backward_cuda',
            torch.device(device).type == 'cuda')

        with DeterministicGuard(True):
            res = module(input)

        grad = torch.ones_like(res)

        # If the forward call was deterministic, nondeterministic alert should
        # not be raised
        self.check_nondeterministic_alert(
            lambda: res.backward(grad, retain_graph=True),
            'replication_pad2d_backward_cuda',
            False)

    @skipIfMps
    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
    def test_nondeterministic_alert_ReplicationPad3d(self, device):

@@ -754,7 +754,6 @@ def use_deterministic_algorithms(mode: builtins.bool, *, warn_only: builtins.boo
        * :class:`torch.nn.ConvTranspose1d` when called on CUDA tensor
        * :class:`torch.nn.ConvTranspose2d` when called on CUDA tensor
        * :class:`torch.nn.ConvTranspose3d` when called on CUDA tensor
        * :class:`torch.nn.ReplicationPad2d` when attempting to differentiate a CUDA tensor
        * :func:`torch.bmm` when called on sparse-dense CUDA tensors
        * :func:`torch.Tensor.__getitem__` when attempting to differentiate a CPU tensor
          and the index is a list of tensors

@@ -797,6 +796,7 @@ def use_deterministic_algorithms(mode: builtins.bool, *, warn_only: builtins.boo
        * :class:`torch.nn.ReflectionPad2d` when attempting to differentiate a CUDA tensor
        * :class:`torch.nn.ReflectionPad3d` when attempting to differentiate a CUDA tensor
        * :class:`torch.nn.ReplicationPad1d` when attempting to differentiate a CUDA tensor
        * :class:`torch.nn.ReplicationPad2d` when attempting to differentiate a CUDA tensor
        * :class:`torch.nn.ReplicationPad3d` when attempting to differentiate a CUDA tensor
        * :class:`torch.nn.NLLLoss` when called on a CUDA tensor
        * :class:`torch.nn.CTCLoss` when attempting to differentiate a CUDA tensor

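As the restored documentation above states, differentiating torch.nn.ReplicationPad2d on a CUDA tensor once again throws a RuntimeError in deterministic mode. A hedged usage sketch of the documented warn_only escape hatch (assumes a CUDA device; not part of the diff):

import torch

# warn_only=True downgrades the deterministic-mode error to a warning,
# which can help while code still hits replication_pad2d_backward_cuda.
torch.use_deterministic_algorithms(True, warn_only=True)
x = torch.randn(2, 3, 4, 4, device='cuda', requires_grad=True)
out = torch.nn.ReplicationPad2d((1, 2, 3, 4))(x)
out.backward(torch.ones_like(out))  # warns instead of raising
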
@@ -375,7 +375,6 @@ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]:
        aten.rad2deg_,
        aten.renorm,
        aten.renorm_,
        aten.replication_pad2d,
        aten.rot90,
        aten.rrelu_with_noise,
        aten.rrelu_with_noise_,

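One way to confirm the table change above took effect is to look up the op in the decomposition dictionary returned by this function (illustrative check, not part of the diff):

import torch
from torch._decomp import core_aten_decompositions

decomps = core_aten_decompositions()
# With the revert applied, the replication_pad2d overload should no longer
# have a registered core ATen decomposition.
print(torch.ops.aten.replication_pad2d.default in decomps)  # expected: False
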
@@ -3344,96 +3344,6 @@ def upsample_bilinear2d(
    return result


@register_decomposition(aten.replication_pad2d.default)
@pw_cast_for_opmath
def replication_pad2d(input: Tensor, padding: List[int]) -> Tensor:
    pad_left = padding[0]
    pad_right = padding[1]
    pad_top = padding[2]
    pad_bottom = padding[3]

    # If all of the padding values are non-negative, then the following tensors
    # are all equal to the input. But if any padding values are negative, we
    # have to remove the appropriate rows and columns from the input.
    # `input_mid` has all negative padding removed from it. `input_mid_tb` has
    # negative left and right padding removed from it. `input_mid_lr` has
    # negative top and bottom padding removed from it.
    input_mid = input
    input_mid_tb = input
    input_mid_lr = input

    if pad_left < 0:
        input_mid = input_mid[..., -pad_left:]
        input_mid_tb = input_mid_tb[..., -pad_left:]
        pad_left = 0

    if pad_right < 0:
        input_mid = input_mid[..., :pad_right]
        input_mid_tb = input_mid_tb[..., :pad_right]
        pad_right = 0

    if pad_top < 0:
        input_mid = input_mid[..., -pad_top:, :]
        input_mid_lr = input_mid_lr[..., -pad_top:, :]
        pad_top = 0

    if pad_bottom < 0:
        input_mid = input_mid[..., :pad_bottom, :]
        input_mid_lr = input_mid_lr[..., :pad_bottom, :]
        pad_bottom = 0

    batch_dims_no_repeat = (1,) * (input.dim() - 2)

    repeat_top_left = batch_dims_no_repeat + (pad_top, pad_left)
    repeat_top_middle = batch_dims_no_repeat + (pad_top, 1)
    repeat_top_right = batch_dims_no_repeat + (pad_top, pad_right)

    top_rows = torch.cat(
        [
            # top left
            input[..., [0], :][..., [0]].repeat(repeat_top_left),
            # top middle
            input_mid_tb[..., [0], :].repeat(repeat_top_middle),
            # top right
            input[..., [0], :][..., [-1]].repeat(repeat_top_right),
        ],
        dim=-1,
    )

    repeat_middle_left = batch_dims_no_repeat + (1, pad_left)
    repeat_middle_right = batch_dims_no_repeat + (1, pad_right)

    middle_rows = torch.cat(
        [
            # middle left
            input_mid_lr[..., [0]].repeat(repeat_middle_left),
            # middle middle
            input_mid,
            # middle right
            input_mid_lr[..., [-1]].repeat(repeat_middle_right),
        ],
        dim=-1,
    )

    repeat_bottom_left = batch_dims_no_repeat + (pad_bottom, pad_left)
    repeat_bottom_middle = batch_dims_no_repeat + (pad_bottom, 1)
    repeat_bottom_right = batch_dims_no_repeat + (pad_bottom, pad_right)

    bottom_rows = torch.cat(
        [
            # bottom left
            input[..., [-1], :][..., [0]].repeat(repeat_bottom_left),
            # bottom middle
            input_mid_tb[..., [-1], :].repeat(repeat_bottom_middle),
            # bottom right
            input[..., [-1], :][..., [-1]].repeat(repeat_bottom_right),
        ],
        dim=-1,
    )

    return torch.cat([top_rows, middle_rows, bottom_rows], dim=-2)


# We should be applying decompositions after all transformations
@register_decomposition(aten.is_same_size.default)
def is_same_size(a: Tensor, b: Tensor) -> bool:

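The removed decomposition above builds the padded output from nine blocks: four replicated corners, four replicated edges, and the (possibly cropped) middle. A small self-contained check of that stitching idea against the native replicate pad, using ad hoc names rather than the deleted function (not part of the diff):

import torch
import torch.nn.functional as F

x = torch.arange(6.0).reshape(1, 1, 2, 3)
pad_left, pad_right, pad_top, pad_bottom = 1, 2, 1, 1

top = torch.cat([
    x[..., [0], :][..., [0]].repeat(1, 1, pad_top, pad_left),    # top-left corner
    x[..., [0], :].repeat(1, 1, pad_top, 1),                     # top edge
    x[..., [0], :][..., [-1]].repeat(1, 1, pad_top, pad_right),  # top-right corner
], dim=-1)
middle = torch.cat([
    x[..., [0]].repeat(1, 1, 1, pad_left),    # left edge
    x,                                        # original image
    x[..., [-1]].repeat(1, 1, 1, pad_right),  # right edge
], dim=-1)
bottom = torch.cat([
    x[..., [-1], :][..., [0]].repeat(1, 1, pad_bottom, pad_left),
    x[..., [-1], :].repeat(1, 1, pad_bottom, 1),
    x[..., [-1], :][..., [-1]].repeat(1, 1, pad_bottom, pad_right),
], dim=-1)
stitched = torch.cat([top, middle, bottom], dim=-2)

assert torch.equal(stitched, F.pad(x, (pad_left, pad_right, pad_top, pad_bottom), mode='replicate'))
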
@@ -2235,7 +2235,6 @@ make_fallback(aten.max_pool3d_with_indices_backward)
make_fallback(aten._pdist_backward)
make_fallback(aten.reflection_pad1d_backward)
make_fallback(aten.replication_pad1d_backward)
make_fallback(aten.replication_pad2d_backward)
make_fallback(aten.soft_margin_loss_backward, warn=False)
make_fallback(aten.linalg_pinv.atol_rtol_tensor)
make_fallback(aten.segment_reduce.default)

@@ -4418,7 +4418,8 @@ def affine_grid(theta: Tensor, size: List[int], align_corners: Optional[bool] =
    return torch.affine_grid_generator(theta, size, align_corners)


def pad(input: Tensor, pad: List[int], mode: str = "constant", value: Optional[float] = None) -> Tensor:
pad = _add_docstr(
    torch._C._nn.pad,
    r"""
pad(input, pad, mode="constant", value=None) -> Tensor

@@ -4479,21 +4480,7 @@ Examples::
    >>> print(out.size())
    torch.Size([3, 9, 7, 3])

    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            torch.nn.functional.pad, (input,), input, pad, mode=mode, value=value)
    if not torch.jit.is_scripting():
        if torch.are_deterministic_algorithms_enabled() and input.is_cuda:
            if len(pad) == 4 and (input.dim() == 3 or input.dim() == 4) and mode == 'replicate':
                # Use slow decomp whose backward will be in terms of index_put.
                # importlib is required because the import cannot be top level
                # (cycle) and cannot be nested (TS doesn't support)
                return importlib.import_module('torch._decomp.decompositions').replication_pad2d(
                    input, pad
                )
    return torch._C._nn.pad(input, pad, mode, value)

""")
# TODO: Fix via https://github.com/pytorch/pytorch/issues/75798
pad.__module__ = "torch.nn.functional"

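For context, the removed Python-level branch above only rerouted a narrow case: 2-D replicate padding (a 4-element pad on a 3-D or 4-D CUDA tensor) while deterministic algorithms are enabled. A paraphrase of that gating as a standalone predicate, with a hypothetical helper name (not part of the diff):

import torch

def _should_use_decomp(input: torch.Tensor, pad: list, mode: str) -> bool:
    # Mirrors the condition in the removed wrapper: CUDA input, deterministic
    # mode on, 2-D replicate padding of a 3-D or 4-D tensor.
    return (
        torch.are_deterministic_algorithms_enabled()
        and input.is_cuda
        and len(pad) == 4
        and input.dim() in (3, 4)
        and mode == 'replicate'
    )
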
@@ -5596,19 +5596,6 @@ def sample_inputs_nn_pad(op_info, device, dtype, requires_grad, mode, **kwargs):
    for shape, pad in cases:
        yield SampleInput(make_inp(shape), args=(pad, mode, pad_value))

def sample_inputs_nn_pad_replicate_negative(op_info, device, dtype, requires_grad, **kwargs):
    cases: tuple = (
        ((5, 3, 4, 4), (-4, 5, 0, 0)),
        ((6, 2, 4, 4), (0, 0, 2, -4)),
        ((5, 6, 4, 4), (5, -4, -4, 3)),
        ((4, 2, 5, 5), (-2, -1, 4, 6)),
        ((2, 6, 5, 5), (8, -1, -1, -3)),
        ((8, 1, 5, 5), (-2, -1, -1, -3)),
    )
    make_inp = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    for shape, pad in cases:
        yield SampleInput(make_inp(shape), args=(pad, 'replicate'))

def sample_inputs_constant_pad_nd(op_info, device, dtype, *args, **kwargs):
    # Inherit sample inputs from nn.pad, but transform them to fit

@@ -13423,23 +13410,6 @@ op_db: List[OpInfo] = [
           ),
           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
           supports_out=False),
    OpInfo('nn.functional.pad',
           variant_test_name='replicate_negative',
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
           dtypes=all_types_and_complex_and(torch.bfloat16),
           dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16),
           sample_inputs_func=sample_inputs_nn_pad_replicate_negative,
           skips=(
               # Doesn't have a corresponding aten operator.
               # RuntimeError: falseINTERNAL ASSERT FAILED at
               # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, please report a bug to PyTorch.
               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
               # Some negative padding cases cause a segfault on MPS
               DecorateInfo(unittest.skip("Not fully supported on MPS"), 'TestConsistency'),
           ),
           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
           supports_out=False),
    OpInfo('nn.functional.pad',
           variant_test_name='circular',
           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half),