[dynamo] Remove traceable_tensor_subclasses-related code (#151062)

Since #149792 deprecates `traceable_tensor_subclasses` and it's been
landed for over a week, we can safely remove all the old code that uses
`traceable_tensor_subclasses` (they were primarily for testing purposes
and are equivalent to no-ops now).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151062
Approved by: https://github.com/mlazos, https://github.com/anijain2305
ghstack dependencies: #151060, #151061
This commit is contained in:
Ryan Guo 2025-04-11 15:20:15 -07:00 committed by PyTorch MergeBot
parent 6a1499d209
commit f66229de2b
4 changed files with 119 additions and 194 deletions

View File

@ -373,9 +373,7 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
inp = (torch.ones(2, 2) + 1).as_subclass(TestSubclass)
fn_opt = torch.compile(fn, fullgraph=True)
with TestMode(), torch._dynamo.config.patch(
"traceable_tensor_subclasses", {TestSubclass}
):
with TestMode():
with torch._C.DisableTorchFunctionSubclass():
expected = fn(inp)
actual = fn_opt(inp)
@ -404,9 +402,7 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
inp = (torch.ones(2, 2) + 1).as_subclass(TestSubclass)
fn_opt = torch.compile(fn, fullgraph=True)
with TestMode(), torch._dynamo.config.patch(
"traceable_tensor_subclasses", {TestSubclass}
):
with TestMode():
expected = fn(inp)
actual = fn_opt(inp)

View File

@ -2,7 +2,6 @@
# ruff: noqa: F841
import collections
import contextlib
import copy
import itertools
import os
@ -1173,7 +1172,6 @@ def make_test(fn, expected_ops=None):
return test_fn
@contextlib.contextmanager
def temporary_tensor_subclass(torch_function=None):
class TensorProxy(torch.Tensor):
@classmethod
@ -1182,11 +1180,7 @@ def temporary_tensor_subclass(torch_function=None):
torch_function()
return super().__torch_function__(func, types, args, kwargs)
torch._dynamo.config.traceable_tensor_subclasses.add(TensorProxy)
try:
yield TensorProxy
finally:
torch._dynamo.config.traceable_tensor_subclasses.remove(TensorProxy)
return TensorProxy
class NNModuleTests(torch._dynamo.test_case.TestCase):
@ -1377,15 +1371,15 @@ class NNModuleTests(torch._dynamo.test_case.TestCase):
x = x.sigmoid()
return x
with temporary_tensor_subclass() as TensorProxy:
x = torch.randn(1).as_subclass(TensorProxy)
cnt = torch._dynamo.testing.CompileCounter()
out1 = foo(x)
opt_foo = torch.compile(foo, backend=cnt, fullgraph=True)
out2 = opt_foo(x)
TensorProxy = temporary_tensor_subclass()
x = torch.randn(1).as_subclass(TensorProxy)
cnt = torch._dynamo.testing.CompileCounter()
out1 = foo(x)
opt_foo = torch.compile(foo, backend=cnt, fullgraph=True)
out2 = opt_foo(x)
self.assertEqual(cnt.op_count, 4)
self.assertTrue(torch._dynamo.testing.same(out1, out2))
self.assertEqual(cnt.op_count, 4)
self.assertTrue(torch._dynamo.testing.same(out1, out2))
def test_torch_function_with_closure(self):
def run():
@ -1406,16 +1400,16 @@ class NNModuleTests(torch._dynamo.test_case.TestCase):
# TODO(future PR): support writes as well
counter + 1
with temporary_tensor_subclass(function) as TensorProxy:
x = torch.randn(1).as_subclass(TensorProxy)
x = torch.randn(1)
cnt = torch._dynamo.testing.CompileCounter()
out1 = foo(x)
opt_foo = torch.compile(foo, backend=cnt, fullgraph=True)
out2 = opt_foo(x)
TensorProxy = temporary_tensor_subclass(function)
x = torch.randn(1).as_subclass(TensorProxy)
x = torch.randn(1)
cnt = torch._dynamo.testing.CompileCounter()
out1 = foo(x)
opt_foo = torch.compile(foo, backend=cnt, fullgraph=True)
out2 = opt_foo(x)
self.assertEqual(cnt.op_count, 4)
self.assertTrue(torch._dynamo.testing.same(out1, out2))
self.assertEqual(cnt.op_count, 4)
self.assertTrue(torch._dynamo.testing.same(out1, out2))
run()
@ -1437,36 +1431,36 @@ class NNModuleTests(torch._dynamo.test_case.TestCase):
return x
try:
with temporary_tensor_subclass() as TensorProxy:
x = torch.randn(1).as_subclass(TensorProxy)
x1 = one_break(x)
TensorProxy = temporary_tensor_subclass()
x = torch.randn(1).as_subclass(TensorProxy)
x1 = one_break(x)
cnt = torch._dynamo.testing.CompileCounter()
opt_one_break = torch.compile(one_break, backend=cnt)
x2 = opt_one_break(x)
cnt = torch._dynamo.testing.CompileCounter()
opt_one_break = torch.compile(one_break, backend=cnt)
x2 = opt_one_break(x)
self.assertTrue(torch._dynamo.testing.same(x1, x2))
self.assertEqual(cnt.frame_count, 2)
self.assertEqual(cnt.op_count, 2)
self.assertTrue(torch._dynamo.testing.same(x1, x2))
self.assertEqual(cnt.frame_count, 2)
self.assertEqual(cnt.op_count, 2)
compile_ids = set()
for r in results:
# A mangled classname looks like __subclass_TensorProxy_94524181138240_c0
# where the last segment contains the compile_id.
prefix = "__subclass_TensorProxy_"
before, sep, after = r.partition(prefix)
self.assertEqual(before, "")
self.assertEqual(sep, prefix)
compile_ids = set()
for r in results:
# A mangled classname looks like __subclass_TensorProxy_94524181138240_c0
# where the last segment contains the compile_id.
prefix = "__subclass_TensorProxy_"
before, sep, after = r.partition(prefix)
self.assertEqual(before, "")
self.assertEqual(sep, prefix)
class_type_id, compile_id = after.split("_")
self.assertTrue(class_type_id.isnumeric())
self.assertTrue(compile_id.startswith("c"))
class_type_id, compile_id = after.split("_")
self.assertTrue(class_type_id.isnumeric())
self.assertTrue(compile_id.startswith("c"))
cid = compile_id[1:]
self.assertTrue(cid.isnumeric())
compile_ids.add(cid)
cid = compile_id[1:]
self.assertTrue(cid.isnumeric())
compile_ids.add(cid)
self.assertEqual(len(compile_ids), 3)
self.assertEqual(len(compile_ids), 3)
finally:
TensorWithTFOverrideVariable.global_mangled_class_name = original

View File

@ -36,10 +36,6 @@ from torch.testing._internal.two_tensor import TwoTensor
from torch.utils._python_dispatch import return_and_correct_aliasing
def traceable_subclass(c):
return torch._dynamo.config.patch("traceable_tensor_subclasses", {c})
def nontraceable_subclass(c):
return torch._dynamo.config.patch("nontraceable_tensor_subclasses", {c})
@ -417,15 +413,6 @@ def _recompiles_for_inputs(fn, inputs1, inputs2, dynamic=True):
class SubclassTests(torch._dynamo.test_case.TestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls._exit_stack.enter_context(
torch._dynamo.config.patch(
"traceable_tensor_subclasses", GLOBAL_TEST_SUBCLASSES
)
)
@classmethod
def tearDownClass(cls):
cls._exit_stack.close()
@ -444,18 +431,14 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
kwargs = {}
return super().__torch_function__(func, types, args, kwargs)
with torch._dynamo.config.patch(
"traceable_tensor_subclasses", {BadNewTorchFunction}
):
@torch.compile(backend="eager", fullgraph=True)
def fn(x):
return torch.add(x, 1)
@torch.compile(backend="eager", fullgraph=True)
def fn(x):
return torch.add(x, 1)
input = torch.ones(2, 2).as_subclass(BadNewTorchFunction)
input = torch.ones(2, 2).as_subclass(BadNewTorchFunction)
res = fn(input)
self.assertIsInstance(res, BadNewTorchFunction)
res = fn(input)
self.assertIsInstance(res, BadNewTorchFunction)
def test_no_torch_function_recompiles(self):
class NJT:
@ -629,16 +612,14 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
def __torch_function__(cls, func, types, args=(), kwargs=None):
return super().__torch_function__(func, types, args, kwargs)
with torch._dynamo.config.patch("traceable_tensor_subclasses", {LocalSubclass}):
@torch.compile(backend="eager", fullgraph=True)
def fn(x):
return LocalSubclass(torch.add(x, 1.0)) * 2
@torch.compile(backend="eager", fullgraph=True)
def fn(x):
return LocalSubclass(torch.add(x, 1.0)) * 2
input = torch.ones(2, 2)
input = torch.ones(2, 2)
res = fn(input)
self.assertIsInstance(res, LocalSubclass)
res = fn(input)
self.assertIsInstance(res, LocalSubclass)
def test_torch_function_list_args(self):
HANDLED_FUNCTIONS = {}
@ -693,18 +674,16 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
],
)
def test_type_check(self, comparison, input_type):
with torch._dynamo.config.patch("traceable_tensor_subclasses", {DummyNDim}):
def fn(x):
if comparison(x, DummyNDim):
return torch.ones(1, 1)
else:
return torch.zeros(2, 2)
def fn(x):
if comparison(x, DummyNDim):
return torch.ones(1, 1)
else:
return torch.zeros(2, 2)
input = torch.ones(2, 2).as_subclass(input_type)
exp_res = fn(input)
act_res = torch.compile(backend="eager", fullgraph=True)(fn)(input)
self.assertEqual(exp_res, act_res)
input = torch.ones(2, 2).as_subclass(input_type)
exp_res = fn(input)
act_res = torch.compile(backend="eager", fullgraph=True)(fn)(input)
self.assertEqual(exp_res, act_res)
def test_torch_function_call_on_method(self):
x = torch.ones(2, 2)
@ -751,9 +730,8 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
fn_opt = torch.compile(fn)
with torch._dynamo.config.patch("traceable_tensor_subclasses", {LocalSubclass}):
res_exp = fn(x, wrapped)
res_act = fn_opt(y, wrapped2)
res_exp = fn(x, wrapped)
res_act = fn_opt(y, wrapped2)
self.assertEqual(res_exp, res_act)
@ -772,9 +750,8 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
x = torch.ones(2, 2).as_subclass(LocalSubclass)
fn_opt = compile_full_eager(fn)
with torch._dynamo.config.patch("traceable_tensor_subclasses", {LocalSubclass}):
res_exp = fn(x)
res_act = fn_opt(x)
res_exp = fn(x)
res_act = fn_opt(x)
self.assertEqual(res_exp, res_act)
@ -793,9 +770,7 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
return x.ndim
msg = "Currently only support accessing overridden attributes that are functions or properties, but got <class 'int'>"
with torch._dynamo.config.patch(
"traceable_tensor_subclasses", {LocalSubclass}
), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg):
with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg):
x = torch.ones(2, 2).as_subclass(LocalSubclass)
fn(x)
@ -822,9 +797,8 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
x = LocalSubclass(torch.ones(2, 2))
fn_opt = compile_full_eager(fn)
with torch._dynamo.config.patch("traceable_tensor_subclasses", {LocalSubclass}):
res_exp = fn(x)
res_act = fn_opt(x)
res_exp = fn(x)
res_act = fn_opt(x)
self.assertEqual(res_exp, res_act)
@ -840,21 +814,14 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
def fn(x):
return x.sigmoid()
with torch._dynamo.config.patch(
error_on_recompile=True, traceable_tensor_subclasses={LocalSubclass}
):
with torch._dynamo.config.patch(error_on_recompile=True):
x = torch.ones(2, 2).as_subclass(LocalSubclass)
fn(x)
fn(x)
x = torch.ones(2, 2).as_subclass(LocalSubclass)
fn(x)
with torch._dynamo.config.patch(
traceable_tensor_subclasses={LocalSubclass}
), self.assertRaisesRegex(
TypeError,
"'bool' object is not callable",
):
with self.assertRaisesRegex(TypeError, "'bool' object is not callable"):
LocalSubclass.sigmoid = False
fn(x)
@ -903,13 +870,12 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
def fn(x):
return torch.clone(x)
with torch._dynamo.config.patch(traceable_tensor_subclasses={TestTensor}):
inp = torch.ones(4, 4)
x = inp.as_subclass(TestTensor)
torch._dynamo.mark_dynamic(x, 0)
compiled_fn = torch.compile(fn, fullgraph=True)
out = compiled_fn(x)
self.assertEqual(out, torch.ones(4, 4) * 2)
inp = torch.ones(4, 4)
x = inp.as_subclass(TestTensor)
torch._dynamo.mark_dynamic(x, 0)
compiled_fn = torch.compile(fn, fullgraph=True)
out = compiled_fn(x)
self.assertEqual(out, torch.ones(4, 4) * 2)
def test_torch_function_wrapper_class_with_kwargs(self):
x = torch.ones(2, 2)
@ -957,13 +923,12 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
def fn(x):
return x.x + torch.ones(2, 2)
with traceable_subclass(AttrSubclass):
input = torch.ones(2, 2).as_subclass(AttrSubclass)
fn_opt = compile_full_eager(fn)
input = torch.ones(2, 2).as_subclass(AttrSubclass)
fn_opt = compile_full_eager(fn)
res_exp = fn(input)
res_act = fn_opt(input)
self.assertEqual(res_exp, res_act)
res_exp = fn(input)
res_act = fn_opt(input)
self.assertEqual(res_exp, res_act)
def test_make_subclass(self):
# Make sure `torch.Tensor._make_subclass` is traceable, and Dynamo
@ -982,16 +947,15 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
res = x * y + z
return res
with traceable_subclass(MySubclass):
x0 = torch.randn(2, 2)
x1 = x0.clone()
x0 = torch.randn(2, 2)
x1 = x0.clone()
fn_opt = compile_full_eager(fn)
fn_opt = compile_full_eager(fn)
res_exp = fn(x0)
res_act = fn_opt(x1)
self.assertEqual(res_exp, res_act)
self.assertEqual(x0, x1)
res_exp = fn(x0)
res_act = fn_opt(x1)
self.assertEqual(res_exp, res_act)
self.assertEqual(x0, x1)
def test_subclass_override_shape_and_to(self):
# This is a slight variabtion of
@ -1013,17 +977,16 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
y = x.to("cpu")
return x + 1, y + 2, x_shape, x.tensor_shape, y.tensor_shape
with traceable_subclass(MySubclass):
x0 = torch.nn.Parameter(torch.randn(2, 2).as_subclass(MySubclass))
x1 = torch.nn.Parameter(x0.clone().as_subclass(MySubclass))
x0 = torch.nn.Parameter(torch.randn(2, 2).as_subclass(MySubclass))
x1 = torch.nn.Parameter(x0.clone().as_subclass(MySubclass))
fn_opt = compile_full_eager(fn)
fn_opt = compile_full_eager(fn)
res_exp = fn(x0)
res_act = fn_opt(x1)
self.assertEqual(res_exp, res_act)
self.assertEqual(x0, x1)
self.assertEqual(x0.tensor_shape, x1.tensor_shape)
res_exp = fn(x0)
res_act = fn_opt(x1)
self.assertEqual(res_exp, res_act)
self.assertEqual(x0, x1)
self.assertEqual(x0.tensor_shape, x1.tensor_shape)
def test_subclass_dont_invoke_torch_function_on_overriden_method(self):
# We shouldn't fire `__torch_function__` for overriden tensor methods.
@ -1040,14 +1003,13 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
def fn(x):
return x.to("cpu")
with traceable_subclass(MySubclass):
x = torch.nn.Parameter(torch.randn(2, 2).as_subclass(MySubclass))
x = torch.nn.Parameter(torch.randn(2, 2).as_subclass(MySubclass))
fn_opt = compile_full_eager(fn)
fn_opt = compile_full_eager(fn)
res_exp = fn(x)
res_act = fn_opt(x)
self.assertEqual(res_exp, res_act)
res_exp = fn(x)
res_act = fn_opt(x)
self.assertEqual(res_exp, res_act)
def test_subclass_dont_invoke_torch_function_on_overriden_attr(self):
from types import MethodWrapperType
@ -1066,14 +1028,13 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
def fn(x):
return x + x.ndim()
with traceable_subclass(MySubclass):
x = torch.nn.Parameter(torch.randn(2, 2).as_subclass(MySubclass))
x = torch.nn.Parameter(torch.randn(2, 2).as_subclass(MySubclass))
fn_opt = compile_full_eager(fn)
fn_opt = compile_full_eager(fn)
res_exp = fn(x)
res_act = fn_opt(x)
self.assertEqual(res_exp, res_act)
res_exp = fn(x)
res_act = fn_opt(x)
self.assertEqual(res_exp, res_act)
def test_parameter_subclass_with_old_torch_function(self):
class MySubclass(torch.nn.Parameter):
@ -1152,9 +1113,8 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
opt_f = torch.compile(f, backend="eager", fullgraph=True)
x = GGUFParameter(torch.ones(2), quant_type=42)
with traceable_subclass(GGUFParameter):
res = f(x)
ref = opt_f(x)
res = f(x)
ref = opt_f(x)
self.assertEqual(res, ref)
def test_newly_constructed_tensor_subclass_attr_mutation(self):
@ -1171,9 +1131,8 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
opt_f = compile_full_eager(f)
with traceable_subclass(MySubclass):
res = f()
ref = opt_f()
res = f()
ref = opt_f()
self.assertEqual(res, ref)
self.assertEqual(res[0].bar, ref[0].bar)
@ -1192,9 +1151,8 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
opt_f = compile_full_eager(f)
with traceable_subclass(MySubclass):
res = f()
ref = opt_f()
res = f()
ref = opt_f()
self.assertEqual(res, ref)
self.assertEqual(res[0].bar, ref[0].bar)
@ -1216,9 +1174,8 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
opt_f = compile_full_eager(f)
t = MySubclass(torch.ones(2))
with traceable_subclass(MySubclass):
res = f(t)
ref = opt_f(t)
res = f(t)
ref = opt_f(t)
self.assertEqual(res, ref)
self.assertEqual(res.elem, ref.elem)

View File

@ -9,10 +9,9 @@ import pprint
import pickle
import collections
import unittest
import contextlib
import os
from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_CROSSREF, TEST_WITH_TORCHDYNAMO
from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_CROSSREF
from torch.overrides import (
handle_torch_function,
has_torch_function,
@ -382,27 +381,6 @@ class TensorLike:
return HANDLED_FUNCTIONS_TENSOR_LIKE[func](*args, **kwargs)
class TestTorchFunctionOverride(TestCase):
@classmethod
def setUpClass(cls):
cls._stack = contextlib.ExitStack()
if TEST_WITH_TORCHDYNAMO:
# Add classes to the wrapped tensor subclasses
@contextlib.contextmanager
def setup_subclasses():
old = set(torch._dynamo.config.traceable_tensor_subclasses)
torch._dynamo.config.traceable_tensor_subclasses.add(DiagonalTensor)
try:
yield
finally:
torch._dynamo.config.traceable_tensor_subclasses.clear()
torch._dynamo.config.traceable_tensor_subclasses.update(old)
cls._stack.enter_context(setup_subclasses())
@classmethod
def tearDownClass(cls):
cls._stack.close()
def test_dtype_override(self):
class MyDtype:
def __torch_function__(self, *args, **kwargs):