diff --git a/test/dynamo/test_aot_autograd.py b/test/dynamo/test_aot_autograd.py index 7f2ca7258cb..bd9077fce32 100644 --- a/test/dynamo/test_aot_autograd.py +++ b/test/dynamo/test_aot_autograd.py @@ -6,9 +6,26 @@ import torch import torch._dynamo import torch._dynamo.test_case from torch._dynamo.testing import CompileCounter, rand_strided +from torch._dynamo.utils import ifdyn, ifdynstaticdefault from torch.testing._internal.common_utils import compare_equal_outs_and_grads +def maybe_dupe_op(x): + y = x + 1 + z = x + 2 + if x.numel() < 5: + return y, y + else: + return y, z + + +aten = torch.ops.aten +lib = torch.library.Library("custom", "DEF") +lib.define("maybe_dupe_op(Tensor a) -> (Tensor, Tensor)") +lib.impl("maybe_dupe_op", maybe_dupe_op, "CPU") +lib.impl("maybe_dupe_op", maybe_dupe_op, "Meta") + + class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase): def test_LSTM(self): # https://github.com/pytorch/torchdynamo/issues/1147 @@ -385,12 +402,13 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase): fxy(x1, y1) fxy(x2, y2) - self.assertTrue(failure_reason is None) + if not torch._dynamo.config.dynamic_shapes: + self.assertTrue(failure_reason is None) # Reset failure reason failure_reason = None - self.assertEqual(cc.frame_count, 1) + self.assertEqual(cc.frame_count, ifdyn(ifdynstaticdefault(1, 2), 1)) torch._dynamo.reset() # for new backend cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager") @@ -424,10 +442,19 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase): f(a) f(a) self.assertEqual(cc.frame_count, 2) - self.assertExpectedInline( - failure_reason, - """tensor 'L['a']' stride mismatch at index 0. expected 3, actual 1""", - ) + if ( + torch._dynamo.config.dynamic_shapes + and not torch._dynamo.config.assume_static_by_default + ): + self.assertExpectedInline( + failure_reason, + """tensor 'L['a']' stride mismatch at index 1. expected 1, actual 3""", + ) + else: + self.assertExpectedInline( + failure_reason, + """tensor 'L['a']' stride mismatch at index 0. expected 3, actual 1""", + ) torch._dynamo.reset() @@ -665,21 +692,8 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase): self.assertExpectedInline(failure_reason, """L['c'] is L['d']""") @patch("torch._functorch.config.debug_assert", True) + @patch("torch._dynamo.config.dynamic_shapes", False) def test_multiple_aot_autograd_calls_dupe_args(self): - def maybe_dupe_op(x): - y = x + 1 - z = x + 2 - if x.numel() < 5: - return y, y - else: - return y, z - - aten = torch.ops.aten - lib = torch.library.Library("custom", "DEF") - lib.define("maybe_dupe_op(Tensor a) -> (Tensor, Tensor)") - lib.impl("maybe_dupe_op", maybe_dupe_op, "CPU") - lib.impl("maybe_dupe_op", maybe_dupe_op, "Meta") - # this is just dealing with the fact that # aot_module_simplified expects submods to always return tuples/lists class WrapperModule(torch.nn.Module): diff --git a/test/dynamo/test_dynamic_shapes.py b/test/dynamo/test_dynamic_shapes.py index 700a3f0fbfc..4d422d71048 100644 --- a/test/dynamo/test_dynamic_shapes.py +++ b/test/dynamo/test_dynamic_shapes.py @@ -4,6 +4,7 @@ from torch._dynamo.testing import make_test_cls_with_patches try: from . import ( + test_aot_autograd, test_ctx_manager, test_export, test_functions, @@ -14,6 +15,7 @@ try: test_subgraphs, ) except ImportError: + import test_aot_autograd import test_ctx_manager import test_export import test_functions @@ -82,6 +84,7 @@ tests = [ test_export.ExportTests, test_subgraphs.SubGraphTests, test_higher_order_ops.HigherOrderOpTests, + test_aot_autograd.AotAutogradFallbackTests, ] for test in tests: make_dynamic_cls(test) diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 46d8f79e8ad..c5836237a90 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1577,9 +1577,6 @@ inplace_symbolic_tensor_failures = { xfail('unique', ''), # in-place has a different signature than out-of-place xfail('uniform', ''), - # Views - xfail('t', ''), - xfail('transpose', ''), } # Copies inputs to inplace operations to avoid inplace modifications diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index 579b5fc39d6..42a38c1b6fd 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -1954,13 +1954,15 @@ def aot_wrapper_dedupe( duped_arg_len = len(flat_args) j = 0 # index into deduped_flat_args - for i, t in enumerate(flat_args): - if t in seen_args: - keep_arg_mask.append(False) - add_dupe_map.append(seen_args[t]) - continue + for t in flat_args: + if isinstance(t, torch.Tensor): + if t in seen_args: + keep_arg_mask.append(False) + add_dupe_map.append(seen_args[t]) + continue + seen_args[t] = j + keep_arg_mask.append(True) - seen_args[t] = j add_dupe_map.append(j) j += 1 assert len(add_dupe_map) == duped_arg_len, ( diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index ea7254e7570..0c255c7e7de 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -3103,6 +3103,51 @@ def nan_to_num(self, nan=None, posinf=None, neginf=None): return self.new_empty(result_size) +@register_meta(torch.ops.aten.transpose_) +def transpose_(self, dim0, dim1): + assert self.layout not in { + torch.sparse_csr, + torch.sparse_csc, + torch.sparse_bsr, + torch.sparse_bsc, + }, f"torch.transpose_: in-place transposition is not supported for {self.layout} layout" + + ndims = self.ndim + + dim0 = maybe_wrap_dim(dim0, ndims) + dim1 = maybe_wrap_dim(dim1, ndims) + + if dim0 == dim1: + return self + + size = list(self.size()) + stride = list(self.stride()) + + stride[dim0], stride[dim1] = stride[dim1], stride[dim0] + size[dim0], size[dim1] = size[dim1], size[dim0] + + self.as_strided_(size, stride) + return self + + +@register_meta(torch.ops.aten.t_) +def t_(self): + ndims = self.ndim + + if self.is_sparse: + sparse_dim = self.sparse_dim() + dense_dim = self.dense_dim() + assert ( + sparse_dim <= 2 and dense_dim == 0 + ), f"t_ expects a tensor with <= 2 sparse and 0 dense dimensions, but got {sparse_dim} sparse and {dense_dim} dense dimensions" # noqa: B950 + else: + assert ( + self.dim() <= 2 + ), f"t_ expects a tensor with <= 2 dimensions, but self is {ndims}D" + + return transpose_(self, 0, 0 if ndims < 2 else 1) + + # We must also trigger meta registrations from PrimTorch ref # decompositions import torch._refs