mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Add AotAutogradFallbackTests to dynamic suite (#100454)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100454 Approved by: https://github.com/ezyang
This commit is contained in:
parent
2dca418112
commit
fe3ecfe0cf
|
|
@ -6,9 +6,26 @@ import torch
|
||||||
import torch._dynamo
|
import torch._dynamo
|
||||||
import torch._dynamo.test_case
|
import torch._dynamo.test_case
|
||||||
from torch._dynamo.testing import CompileCounter, rand_strided
|
from torch._dynamo.testing import CompileCounter, rand_strided
|
||||||
|
from torch._dynamo.utils import ifdyn, ifdynstaticdefault
|
||||||
from torch.testing._internal.common_utils import compare_equal_outs_and_grads
|
from torch.testing._internal.common_utils import compare_equal_outs_and_grads
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_dupe_op(x):
|
||||||
|
y = x + 1
|
||||||
|
z = x + 2
|
||||||
|
if x.numel() < 5:
|
||||||
|
return y, y
|
||||||
|
else:
|
||||||
|
return y, z
|
||||||
|
|
||||||
|
|
||||||
|
aten = torch.ops.aten
|
||||||
|
lib = torch.library.Library("custom", "DEF")
|
||||||
|
lib.define("maybe_dupe_op(Tensor a) -> (Tensor, Tensor)")
|
||||||
|
lib.impl("maybe_dupe_op", maybe_dupe_op, "CPU")
|
||||||
|
lib.impl("maybe_dupe_op", maybe_dupe_op, "Meta")
|
||||||
|
|
||||||
|
|
||||||
class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
|
class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
|
||||||
def test_LSTM(self):
|
def test_LSTM(self):
|
||||||
# https://github.com/pytorch/torchdynamo/issues/1147
|
# https://github.com/pytorch/torchdynamo/issues/1147
|
||||||
|
|
@ -385,12 +402,13 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
|
||||||
fxy(x1, y1)
|
fxy(x1, y1)
|
||||||
fxy(x2, y2)
|
fxy(x2, y2)
|
||||||
|
|
||||||
self.assertTrue(failure_reason is None)
|
if not torch._dynamo.config.dynamic_shapes:
|
||||||
|
self.assertTrue(failure_reason is None)
|
||||||
|
|
||||||
# Reset failure reason
|
# Reset failure reason
|
||||||
failure_reason = None
|
failure_reason = None
|
||||||
|
|
||||||
self.assertEqual(cc.frame_count, 1)
|
self.assertEqual(cc.frame_count, ifdyn(ifdynstaticdefault(1, 2), 1))
|
||||||
|
|
||||||
torch._dynamo.reset() # for new backend
|
torch._dynamo.reset() # for new backend
|
||||||
cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
|
cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
|
||||||
|
|
@ -424,10 +442,19 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
|
||||||
f(a)
|
f(a)
|
||||||
f(a)
|
f(a)
|
||||||
self.assertEqual(cc.frame_count, 2)
|
self.assertEqual(cc.frame_count, 2)
|
||||||
self.assertExpectedInline(
|
if (
|
||||||
failure_reason,
|
torch._dynamo.config.dynamic_shapes
|
||||||
"""tensor 'L['a']' stride mismatch at index 0. expected 3, actual 1""",
|
and not torch._dynamo.config.assume_static_by_default
|
||||||
)
|
):
|
||||||
|
self.assertExpectedInline(
|
||||||
|
failure_reason,
|
||||||
|
"""tensor 'L['a']' stride mismatch at index 1. expected 1, actual 3""",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.assertExpectedInline(
|
||||||
|
failure_reason,
|
||||||
|
"""tensor 'L['a']' stride mismatch at index 0. expected 3, actual 1""",
|
||||||
|
)
|
||||||
|
|
||||||
torch._dynamo.reset()
|
torch._dynamo.reset()
|
||||||
|
|
||||||
|
|
@ -665,21 +692,8 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
|
||||||
self.assertExpectedInline(failure_reason, """L['c'] is L['d']""")
|
self.assertExpectedInline(failure_reason, """L['c'] is L['d']""")
|
||||||
|
|
||||||
@patch("torch._functorch.config.debug_assert", True)
|
@patch("torch._functorch.config.debug_assert", True)
|
||||||
|
@patch("torch._dynamo.config.dynamic_shapes", False)
|
||||||
def test_multiple_aot_autograd_calls_dupe_args(self):
|
def test_multiple_aot_autograd_calls_dupe_args(self):
|
||||||
def maybe_dupe_op(x):
|
|
||||||
y = x + 1
|
|
||||||
z = x + 2
|
|
||||||
if x.numel() < 5:
|
|
||||||
return y, y
|
|
||||||
else:
|
|
||||||
return y, z
|
|
||||||
|
|
||||||
aten = torch.ops.aten
|
|
||||||
lib = torch.library.Library("custom", "DEF")
|
|
||||||
lib.define("maybe_dupe_op(Tensor a) -> (Tensor, Tensor)")
|
|
||||||
lib.impl("maybe_dupe_op", maybe_dupe_op, "CPU")
|
|
||||||
lib.impl("maybe_dupe_op", maybe_dupe_op, "Meta")
|
|
||||||
|
|
||||||
# this is just dealing with the fact that
|
# this is just dealing with the fact that
|
||||||
# aot_module_simplified expects submods to always return tuples/lists
|
# aot_module_simplified expects submods to always return tuples/lists
|
||||||
class WrapperModule(torch.nn.Module):
|
class WrapperModule(torch.nn.Module):
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ from torch._dynamo.testing import make_test_cls_with_patches
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from . import (
|
from . import (
|
||||||
|
test_aot_autograd,
|
||||||
test_ctx_manager,
|
test_ctx_manager,
|
||||||
test_export,
|
test_export,
|
||||||
test_functions,
|
test_functions,
|
||||||
|
|
@ -14,6 +15,7 @@ try:
|
||||||
test_subgraphs,
|
test_subgraphs,
|
||||||
)
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
import test_aot_autograd
|
||||||
import test_ctx_manager
|
import test_ctx_manager
|
||||||
import test_export
|
import test_export
|
||||||
import test_functions
|
import test_functions
|
||||||
|
|
@ -82,6 +84,7 @@ tests = [
|
||||||
test_export.ExportTests,
|
test_export.ExportTests,
|
||||||
test_subgraphs.SubGraphTests,
|
test_subgraphs.SubGraphTests,
|
||||||
test_higher_order_ops.HigherOrderOpTests,
|
test_higher_order_ops.HigherOrderOpTests,
|
||||||
|
test_aot_autograd.AotAutogradFallbackTests,
|
||||||
]
|
]
|
||||||
for test in tests:
|
for test in tests:
|
||||||
make_dynamic_cls(test)
|
make_dynamic_cls(test)
|
||||||
|
|
|
||||||
|
|
@ -1577,9 +1577,6 @@ inplace_symbolic_tensor_failures = {
|
||||||
xfail('unique', ''),
|
xfail('unique', ''),
|
||||||
# in-place has a different signature than out-of-place
|
# in-place has a different signature than out-of-place
|
||||||
xfail('uniform', ''),
|
xfail('uniform', ''),
|
||||||
# Views
|
|
||||||
xfail('t', ''),
|
|
||||||
xfail('transpose', ''),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Copies inputs to inplace operations to avoid inplace modifications
|
# Copies inputs to inplace operations to avoid inplace modifications
|
||||||
|
|
|
||||||
|
|
@ -1954,13 +1954,15 @@ def aot_wrapper_dedupe(
|
||||||
duped_arg_len = len(flat_args)
|
duped_arg_len = len(flat_args)
|
||||||
|
|
||||||
j = 0 # index into deduped_flat_args
|
j = 0 # index into deduped_flat_args
|
||||||
for i, t in enumerate(flat_args):
|
for t in flat_args:
|
||||||
if t in seen_args:
|
if isinstance(t, torch.Tensor):
|
||||||
keep_arg_mask.append(False)
|
if t in seen_args:
|
||||||
add_dupe_map.append(seen_args[t])
|
keep_arg_mask.append(False)
|
||||||
continue
|
add_dupe_map.append(seen_args[t])
|
||||||
|
continue
|
||||||
|
seen_args[t] = j
|
||||||
|
|
||||||
keep_arg_mask.append(True)
|
keep_arg_mask.append(True)
|
||||||
seen_args[t] = j
|
|
||||||
add_dupe_map.append(j)
|
add_dupe_map.append(j)
|
||||||
j += 1
|
j += 1
|
||||||
assert len(add_dupe_map) == duped_arg_len, (
|
assert len(add_dupe_map) == duped_arg_len, (
|
||||||
|
|
|
||||||
|
|
@ -3103,6 +3103,51 @@ def nan_to_num(self, nan=None, posinf=None, neginf=None):
|
||||||
return self.new_empty(result_size)
|
return self.new_empty(result_size)
|
||||||
|
|
||||||
|
|
||||||
|
@register_meta(torch.ops.aten.transpose_)
|
||||||
|
def transpose_(self, dim0, dim1):
|
||||||
|
assert self.layout not in {
|
||||||
|
torch.sparse_csr,
|
||||||
|
torch.sparse_csc,
|
||||||
|
torch.sparse_bsr,
|
||||||
|
torch.sparse_bsc,
|
||||||
|
}, f"torch.transpose_: in-place transposition is not supported for {self.layout} layout"
|
||||||
|
|
||||||
|
ndims = self.ndim
|
||||||
|
|
||||||
|
dim0 = maybe_wrap_dim(dim0, ndims)
|
||||||
|
dim1 = maybe_wrap_dim(dim1, ndims)
|
||||||
|
|
||||||
|
if dim0 == dim1:
|
||||||
|
return self
|
||||||
|
|
||||||
|
size = list(self.size())
|
||||||
|
stride = list(self.stride())
|
||||||
|
|
||||||
|
stride[dim0], stride[dim1] = stride[dim1], stride[dim0]
|
||||||
|
size[dim0], size[dim1] = size[dim1], size[dim0]
|
||||||
|
|
||||||
|
self.as_strided_(size, stride)
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
@register_meta(torch.ops.aten.t_)
|
||||||
|
def t_(self):
|
||||||
|
ndims = self.ndim
|
||||||
|
|
||||||
|
if self.is_sparse:
|
||||||
|
sparse_dim = self.sparse_dim()
|
||||||
|
dense_dim = self.dense_dim()
|
||||||
|
assert (
|
||||||
|
sparse_dim <= 2 and dense_dim == 0
|
||||||
|
), f"t_ expects a tensor with <= 2 sparse and 0 dense dimensions, but got {sparse_dim} sparse and {dense_dim} dense dimensions" # noqa: B950
|
||||||
|
else:
|
||||||
|
assert (
|
||||||
|
self.dim() <= 2
|
||||||
|
), f"t_ expects a tensor with <= 2 dimensions, but self is {ndims}D"
|
||||||
|
|
||||||
|
return transpose_(self, 0, 0 if ndims < 2 else 1)
|
||||||
|
|
||||||
|
|
||||||
# We must also trigger meta registrations from PrimTorch ref
|
# We must also trigger meta registrations from PrimTorch ref
|
||||||
# decompositions
|
# decompositions
|
||||||
import torch._refs
|
import torch._refs
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user