Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 00:20:18 +01:00
[dynamo] Remove traceable_tensor_subclasses-related code (#151062)
Since #149792 deprecated `traceable_tensor_subclasses` and has been landed for over a week, we can safely remove all the old code that uses it (that code was primarily for testing purposes and is now equivalent to a no-op).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151062
Approved by: https://github.com/mlazos, https://github.com/anijain2305
ghstack dependencies: #151060, #151061
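For context, a minimal sketch of what the deprecation means for user code (not taken from this PR; the subclass and function names below are illustrative): Dynamo now traces `__torch_function__` tensor subclasses by default, so the `traceable_tensor_subclasses` config patch that the removed test code relied on can simply be dropped.

    import torch

    class MySubclass(torch.Tensor):  # illustrative subclass, not from this PR
        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            return super().__torch_function__(func, types, args, kwargs or {})

    @torch.compile(backend="eager", fullgraph=True)
    def fn(x):
        return torch.add(x, 1)

    x = torch.ones(2, 2).as_subclass(MySubclass)

    # Before #149792, tracing a __torch_function__ subclass required opting in:
    #     with torch._dynamo.config.patch("traceable_tensor_subclasses", {MySubclass}):
    #         out = fn(x)
    # After the deprecation the flag is a no-op, so the patch is unnecessary:
    out = fn(x)
    assert isinstance(out, MySubclass)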
This commit is contained in: parent 6a1499d209, commit f66229de2b
@@ -373,9 +373,7 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
         inp = (torch.ones(2, 2) + 1).as_subclass(TestSubclass)

         fn_opt = torch.compile(fn, fullgraph=True)
-        with TestMode(), torch._dynamo.config.patch(
-            "traceable_tensor_subclasses", {TestSubclass}
-        ):
+        with TestMode():
             with torch._C.DisableTorchFunctionSubclass():
                 expected = fn(inp)
                 actual = fn_opt(inp)
@@ -404,9 +402,7 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
         inp = (torch.ones(2, 2) + 1).as_subclass(TestSubclass)

         fn_opt = torch.compile(fn, fullgraph=True)
-        with TestMode(), torch._dynamo.config.patch(
-            "traceable_tensor_subclasses", {TestSubclass}
-        ):
+        with TestMode():
             expected = fn(inp)
             actual = fn_opt(inp)

@@ -2,7 +2,6 @@
 # ruff: noqa: F841

 import collections
-import contextlib
 import copy
 import itertools
 import os
@@ -1173,7 +1172,6 @@ def make_test(fn, expected_ops=None):
     return test_fn


-@contextlib.contextmanager
 def temporary_tensor_subclass(torch_function=None):
     class TensorProxy(torch.Tensor):
         @classmethod
@@ -1182,11 +1180,7 @@ def temporary_tensor_subclass(torch_function=None):
                 torch_function()
             return super().__torch_function__(func, types, args, kwargs)

-    torch._dynamo.config.traceable_tensor_subclasses.add(TensorProxy)
-    try:
-        yield TensorProxy
-    finally:
-        torch._dynamo.config.traceable_tensor_subclasses.remove(TensorProxy)
+    return TensorProxy


 class NNModuleTests(torch._dynamo.test_case.TestCase):
@@ -1377,15 +1371,15 @@ class NNModuleTests(torch._dynamo.test_case.TestCase):
             x = x.sigmoid()
             return x

-        with temporary_tensor_subclass() as TensorProxy:
-            x = torch.randn(1).as_subclass(TensorProxy)
-            cnt = torch._dynamo.testing.CompileCounter()
-            out1 = foo(x)
-            opt_foo = torch.compile(foo, backend=cnt, fullgraph=True)
-            out2 = opt_foo(x)
+        TensorProxy = temporary_tensor_subclass()
+        x = torch.randn(1).as_subclass(TensorProxy)
+        cnt = torch._dynamo.testing.CompileCounter()
+        out1 = foo(x)
+        opt_foo = torch.compile(foo, backend=cnt, fullgraph=True)
+        out2 = opt_foo(x)

-            self.assertEqual(cnt.op_count, 4)
-            self.assertTrue(torch._dynamo.testing.same(out1, out2))
+        self.assertEqual(cnt.op_count, 4)
+        self.assertTrue(torch._dynamo.testing.same(out1, out2))

     def test_torch_function_with_closure(self):
         def run():
@@ -1406,16 +1400,16 @@ class NNModuleTests(torch._dynamo.test_case.TestCase):
                 # TODO(future PR): support writes as well
                 counter + 1

-            with temporary_tensor_subclass(function) as TensorProxy:
-                x = torch.randn(1).as_subclass(TensorProxy)
-                x = torch.randn(1)
-                cnt = torch._dynamo.testing.CompileCounter()
-                out1 = foo(x)
-                opt_foo = torch.compile(foo, backend=cnt, fullgraph=True)
-                out2 = opt_foo(x)
+            TensorProxy = temporary_tensor_subclass(function)
+            x = torch.randn(1).as_subclass(TensorProxy)
+            x = torch.randn(1)
+            cnt = torch._dynamo.testing.CompileCounter()
+            out1 = foo(x)
+            opt_foo = torch.compile(foo, backend=cnt, fullgraph=True)
+            out2 = opt_foo(x)

-                self.assertEqual(cnt.op_count, 4)
-                self.assertTrue(torch._dynamo.testing.same(out1, out2))
+            self.assertEqual(cnt.op_count, 4)
+            self.assertTrue(torch._dynamo.testing.same(out1, out2))

         run()

@@ -1437,36 +1431,36 @@ class NNModuleTests(torch._dynamo.test_case.TestCase):
             return x

         try:
-            with temporary_tensor_subclass() as TensorProxy:
-                x = torch.randn(1).as_subclass(TensorProxy)
-                x1 = one_break(x)
+            TensorProxy = temporary_tensor_subclass()
+            x = torch.randn(1).as_subclass(TensorProxy)
+            x1 = one_break(x)

-                cnt = torch._dynamo.testing.CompileCounter()
-                opt_one_break = torch.compile(one_break, backend=cnt)
-                x2 = opt_one_break(x)
+            cnt = torch._dynamo.testing.CompileCounter()
+            opt_one_break = torch.compile(one_break, backend=cnt)
+            x2 = opt_one_break(x)

-                self.assertTrue(torch._dynamo.testing.same(x1, x2))
-                self.assertEqual(cnt.frame_count, 2)
-                self.assertEqual(cnt.op_count, 2)
+            self.assertTrue(torch._dynamo.testing.same(x1, x2))
+            self.assertEqual(cnt.frame_count, 2)
+            self.assertEqual(cnt.op_count, 2)

-                compile_ids = set()
-                for r in results:
-                    # A mangled classname looks like __subclass_TensorProxy_94524181138240_c0
-                    # where the last segment contains the compile_id.
-                    prefix = "__subclass_TensorProxy_"
-                    before, sep, after = r.partition(prefix)
-                    self.assertEqual(before, "")
-                    self.assertEqual(sep, prefix)
+            compile_ids = set()
+            for r in results:
+                # A mangled classname looks like __subclass_TensorProxy_94524181138240_c0
+                # where the last segment contains the compile_id.
+                prefix = "__subclass_TensorProxy_"
+                before, sep, after = r.partition(prefix)
+                self.assertEqual(before, "")
+                self.assertEqual(sep, prefix)

-                    class_type_id, compile_id = after.split("_")
-                    self.assertTrue(class_type_id.isnumeric())
-                    self.assertTrue(compile_id.startswith("c"))
+                class_type_id, compile_id = after.split("_")
+                self.assertTrue(class_type_id.isnumeric())
+                self.assertTrue(compile_id.startswith("c"))

-                    cid = compile_id[1:]
-                    self.assertTrue(cid.isnumeric())
-                    compile_ids.add(cid)
+                cid = compile_id[1:]
+                self.assertTrue(cid.isnumeric())
+                compile_ids.add(cid)

-                self.assertEqual(len(compile_ids), 3)
+            self.assertEqual(len(compile_ids), 3)

         finally:
             TensorWithTFOverrideVariable.global_mangled_class_name = original
@@ -36,10 +36,6 @@ from torch.testing._internal.two_tensor import TwoTensor
 from torch.utils._python_dispatch import return_and_correct_aliasing


-def traceable_subclass(c):
-    return torch._dynamo.config.patch("traceable_tensor_subclasses", {c})
-
-
 def nontraceable_subclass(c):
     return torch._dynamo.config.patch("nontraceable_tensor_subclasses", {c})

@@ -417,15 +413,6 @@ def _recompiles_for_inputs(fn, inputs1, inputs2, dynamic=True):


 class SubclassTests(torch._dynamo.test_case.TestCase):
-    @classmethod
-    def setUpClass(cls):
-        super().setUpClass()
-        cls._exit_stack.enter_context(
-            torch._dynamo.config.patch(
-                "traceable_tensor_subclasses", GLOBAL_TEST_SUBCLASSES
-            )
-        )
-
     @classmethod
     def tearDownClass(cls):
         cls._exit_stack.close()
@@ -444,18 +431,14 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
                 kwargs = {}
                 return super().__torch_function__(func, types, args, kwargs)

-        with torch._dynamo.config.patch(
-            "traceable_tensor_subclasses", {BadNewTorchFunction}
-        ):
-
-            @torch.compile(backend="eager", fullgraph=True)
-            def fn(x):
-                return torch.add(x, 1)
+        @torch.compile(backend="eager", fullgraph=True)
+        def fn(x):
+            return torch.add(x, 1)

-            input = torch.ones(2, 2).as_subclass(BadNewTorchFunction)
+        input = torch.ones(2, 2).as_subclass(BadNewTorchFunction)

-            res = fn(input)
-            self.assertIsInstance(res, BadNewTorchFunction)
+        res = fn(input)
+        self.assertIsInstance(res, BadNewTorchFunction)

     def test_no_torch_function_recompiles(self):
         class NJT:
@@ -629,16 +612,14 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
             def __torch_function__(cls, func, types, args=(), kwargs=None):
                 return super().__torch_function__(func, types, args, kwargs)

-        with torch._dynamo.config.patch("traceable_tensor_subclasses", {LocalSubclass}):
-
-            @torch.compile(backend="eager", fullgraph=True)
-            def fn(x):
-                return LocalSubclass(torch.add(x, 1.0)) * 2
+        @torch.compile(backend="eager", fullgraph=True)
+        def fn(x):
+            return LocalSubclass(torch.add(x, 1.0)) * 2

-            input = torch.ones(2, 2)
+        input = torch.ones(2, 2)

-            res = fn(input)
-            self.assertIsInstance(res, LocalSubclass)
+        res = fn(input)
+        self.assertIsInstance(res, LocalSubclass)

     def test_torch_function_list_args(self):
         HANDLED_FUNCTIONS = {}
@@ -693,18 +674,16 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
         ],
     )
     def test_type_check(self, comparison, input_type):
-        with torch._dynamo.config.patch("traceable_tensor_subclasses", {DummyNDim}):
-
-            def fn(x):
-                if comparison(x, DummyNDim):
-                    return torch.ones(1, 1)
-                else:
-                    return torch.zeros(2, 2)
+        def fn(x):
+            if comparison(x, DummyNDim):
+                return torch.ones(1, 1)
+            else:
+                return torch.zeros(2, 2)

-            input = torch.ones(2, 2).as_subclass(input_type)
-            exp_res = fn(input)
-            act_res = torch.compile(backend="eager", fullgraph=True)(fn)(input)
-            self.assertEqual(exp_res, act_res)
+        input = torch.ones(2, 2).as_subclass(input_type)
+        exp_res = fn(input)
+        act_res = torch.compile(backend="eager", fullgraph=True)(fn)(input)
+        self.assertEqual(exp_res, act_res)

     def test_torch_function_call_on_method(self):
         x = torch.ones(2, 2)
@@ -751,9 +730,8 @@ class SubclassTests(torch._dynamo.test_case.TestCase):

         fn_opt = torch.compile(fn)

-        with torch._dynamo.config.patch("traceable_tensor_subclasses", {LocalSubclass}):
-            res_exp = fn(x, wrapped)
-            res_act = fn_opt(y, wrapped2)
+        res_exp = fn(x, wrapped)
+        res_act = fn_opt(y, wrapped2)

         self.assertEqual(res_exp, res_act)

@@ -772,9 +750,8 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
         x = torch.ones(2, 2).as_subclass(LocalSubclass)
         fn_opt = compile_full_eager(fn)

-        with torch._dynamo.config.patch("traceable_tensor_subclasses", {LocalSubclass}):
-            res_exp = fn(x)
-            res_act = fn_opt(x)
+        res_exp = fn(x)
+        res_act = fn_opt(x)

         self.assertEqual(res_exp, res_act)

@@ -793,9 +770,7 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
                 return x.ndim

         msg = "Currently only support accessing overridden attributes that are functions or properties, but got <class 'int'>"
-        with torch._dynamo.config.patch(
-            "traceable_tensor_subclasses", {LocalSubclass}
-        ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg):
+        with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg):
             x = torch.ones(2, 2).as_subclass(LocalSubclass)
             fn(x)

@@ -822,9 +797,8 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
         x = LocalSubclass(torch.ones(2, 2))
         fn_opt = compile_full_eager(fn)

-        with torch._dynamo.config.patch("traceable_tensor_subclasses", {LocalSubclass}):
-            res_exp = fn(x)
-            res_act = fn_opt(x)
+        res_exp = fn(x)
+        res_act = fn_opt(x)

         self.assertEqual(res_exp, res_act)

@@ -840,21 +814,14 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
         def fn(x):
             return x.sigmoid()

-        with torch._dynamo.config.patch(
-            error_on_recompile=True, traceable_tensor_subclasses={LocalSubclass}
-        ):
+        with torch._dynamo.config.patch(error_on_recompile=True):
             x = torch.ones(2, 2).as_subclass(LocalSubclass)
             fn(x)
             fn(x)

-        with torch._dynamo.config.patch(
-            traceable_tensor_subclasses={LocalSubclass}
-        ), self.assertRaisesRegex(
-            TypeError,
-            "'bool' object is not callable",
-        ):
+        with self.assertRaisesRegex(TypeError, "'bool' object is not callable"):
             LocalSubclass.sigmoid = False
             fn(x)

@@ -903,13 +870,12 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
         def fn(x):
             return torch.clone(x)

-        with torch._dynamo.config.patch(traceable_tensor_subclasses={TestTensor}):
-            inp = torch.ones(4, 4)
-            x = inp.as_subclass(TestTensor)
-            torch._dynamo.mark_dynamic(x, 0)
-            compiled_fn = torch.compile(fn, fullgraph=True)
-            out = compiled_fn(x)
-            self.assertEqual(out, torch.ones(4, 4) * 2)
+        inp = torch.ones(4, 4)
+        x = inp.as_subclass(TestTensor)
+        torch._dynamo.mark_dynamic(x, 0)
+        compiled_fn = torch.compile(fn, fullgraph=True)
+        out = compiled_fn(x)
+        self.assertEqual(out, torch.ones(4, 4) * 2)

     def test_torch_function_wrapper_class_with_kwargs(self):
         x = torch.ones(2, 2)
@@ -957,13 +923,12 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
         def fn(x):
             return x.x + torch.ones(2, 2)

-        with traceable_subclass(AttrSubclass):
-            input = torch.ones(2, 2).as_subclass(AttrSubclass)
-            fn_opt = compile_full_eager(fn)
+        input = torch.ones(2, 2).as_subclass(AttrSubclass)
+        fn_opt = compile_full_eager(fn)

-            res_exp = fn(input)
-            res_act = fn_opt(input)
-            self.assertEqual(res_exp, res_act)
+        res_exp = fn(input)
+        res_act = fn_opt(input)
+        self.assertEqual(res_exp, res_act)

     def test_make_subclass(self):
         # Make sure `torch.Tensor._make_subclass` is traceable, and Dynamo
@@ -982,16 +947,15 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
             res = x * y + z
             return res

-        with traceable_subclass(MySubclass):
-            x0 = torch.randn(2, 2)
-            x1 = x0.clone()
+        x0 = torch.randn(2, 2)
+        x1 = x0.clone()

-            fn_opt = compile_full_eager(fn)
+        fn_opt = compile_full_eager(fn)

-            res_exp = fn(x0)
-            res_act = fn_opt(x1)
-            self.assertEqual(res_exp, res_act)
-            self.assertEqual(x0, x1)
+        res_exp = fn(x0)
+        res_act = fn_opt(x1)
+        self.assertEqual(res_exp, res_act)
+        self.assertEqual(x0, x1)

     def test_subclass_override_shape_and_to(self):
         # This is a slight variabtion of
@@ -1013,17 +977,16 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
             y = x.to("cpu")
             return x + 1, y + 2, x_shape, x.tensor_shape, y.tensor_shape

-        with traceable_subclass(MySubclass):
-            x0 = torch.nn.Parameter(torch.randn(2, 2).as_subclass(MySubclass))
-            x1 = torch.nn.Parameter(x0.clone().as_subclass(MySubclass))
+        x0 = torch.nn.Parameter(torch.randn(2, 2).as_subclass(MySubclass))
+        x1 = torch.nn.Parameter(x0.clone().as_subclass(MySubclass))

-            fn_opt = compile_full_eager(fn)
+        fn_opt = compile_full_eager(fn)

-            res_exp = fn(x0)
-            res_act = fn_opt(x1)
-            self.assertEqual(res_exp, res_act)
-            self.assertEqual(x0, x1)
-            self.assertEqual(x0.tensor_shape, x1.tensor_shape)
+        res_exp = fn(x0)
+        res_act = fn_opt(x1)
+        self.assertEqual(res_exp, res_act)
+        self.assertEqual(x0, x1)
+        self.assertEqual(x0.tensor_shape, x1.tensor_shape)

     def test_subclass_dont_invoke_torch_function_on_overriden_method(self):
         # We shouldn't fire `__torch_function__` for overriden tensor methods.
@@ -1040,14 +1003,13 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
         def fn(x):
             return x.to("cpu")

-        with traceable_subclass(MySubclass):
-            x = torch.nn.Parameter(torch.randn(2, 2).as_subclass(MySubclass))
+        x = torch.nn.Parameter(torch.randn(2, 2).as_subclass(MySubclass))

-            fn_opt = compile_full_eager(fn)
+        fn_opt = compile_full_eager(fn)

-            res_exp = fn(x)
-            res_act = fn_opt(x)
-            self.assertEqual(res_exp, res_act)
+        res_exp = fn(x)
+        res_act = fn_opt(x)
+        self.assertEqual(res_exp, res_act)

     def test_subclass_dont_invoke_torch_function_on_overriden_attr(self):
         from types import MethodWrapperType
@@ -1066,14 +1028,13 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
         def fn(x):
             return x + x.ndim()

-        with traceable_subclass(MySubclass):
-            x = torch.nn.Parameter(torch.randn(2, 2).as_subclass(MySubclass))
+        x = torch.nn.Parameter(torch.randn(2, 2).as_subclass(MySubclass))

-            fn_opt = compile_full_eager(fn)
+        fn_opt = compile_full_eager(fn)

-            res_exp = fn(x)
-            res_act = fn_opt(x)
-            self.assertEqual(res_exp, res_act)
+        res_exp = fn(x)
+        res_act = fn_opt(x)
+        self.assertEqual(res_exp, res_act)

     def test_parameter_subclass_with_old_torch_function(self):
         class MySubclass(torch.nn.Parameter):
@@ -1152,9 +1113,8 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
         opt_f = torch.compile(f, backend="eager", fullgraph=True)

         x = GGUFParameter(torch.ones(2), quant_type=42)
-        with traceable_subclass(GGUFParameter):
-            res = f(x)
-            ref = opt_f(x)
+        res = f(x)
+        ref = opt_f(x)
         self.assertEqual(res, ref)

     def test_newly_constructed_tensor_subclass_attr_mutation(self):
@@ -1171,9 +1131,8 @@ class SubclassTests(torch._dynamo.test_case.TestCase):

         opt_f = compile_full_eager(f)

-        with traceable_subclass(MySubclass):
-            res = f()
-            ref = opt_f()
+        res = f()
+        ref = opt_f()

         self.assertEqual(res, ref)
         self.assertEqual(res[0].bar, ref[0].bar)
@@ -1192,9 +1151,8 @@ class SubclassTests(torch._dynamo.test_case.TestCase):

         opt_f = compile_full_eager(f)

-        with traceable_subclass(MySubclass):
-            res = f()
-            ref = opt_f()
+        res = f()
+        ref = opt_f()

         self.assertEqual(res, ref)
         self.assertEqual(res[0].bar, ref[0].bar)
@@ -1216,9 +1174,8 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
         opt_f = compile_full_eager(f)

         t = MySubclass(torch.ones(2))
-        with traceable_subclass(MySubclass):
-            res = f(t)
-            ref = opt_f(t)
+        res = f(t)
+        ref = opt_f(t)

         self.assertEqual(res, ref)
         self.assertEqual(res.elem, ref.elem)
@@ -9,10 +9,9 @@ import pprint
 import pickle
 import collections
 import unittest
-import contextlib
 import os

-from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_CROSSREF, TEST_WITH_TORCHDYNAMO
+from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_CROSSREF
 from torch.overrides import (
     handle_torch_function,
     has_torch_function,
@@ -382,27 +381,6 @@ class TensorLike:
         return HANDLED_FUNCTIONS_TENSOR_LIKE[func](*args, **kwargs)

 class TestTorchFunctionOverride(TestCase):
-    @classmethod
-    def setUpClass(cls):
-        cls._stack = contextlib.ExitStack()
-        if TEST_WITH_TORCHDYNAMO:
-            # Add classes to the wrapped tensor subclasses
-            @contextlib.contextmanager
-            def setup_subclasses():
-                old = set(torch._dynamo.config.traceable_tensor_subclasses)
-                torch._dynamo.config.traceable_tensor_subclasses.add(DiagonalTensor)
-                try:
-                    yield
-                finally:
-                    torch._dynamo.config.traceable_tensor_subclasses.clear()
-                    torch._dynamo.config.traceable_tensor_subclasses.update(old)
-
-            cls._stack.enter_context(setup_subclasses())
-
-    @classmethod
-    def tearDownClass(cls):
-        cls._stack.close()
-
     def test_dtype_override(self):
         class MyDtype:
             def __torch_function__(self, *args, **kwargs):