[dynamo] Support tensor subclass with overridden tensor methods and properties (#149484)

This fixes most of the "torch.compile X tensor-subclass" issues
encountered in https://github.com/city96/ComfyUI-GGUF/issues/118. The
relevant tensor subclass definition is here:
298192ed60/ops.py (L18-L65).

A few things to note about the tensor subclass (a condensed, illustrative sketch follows the list):
1. it overrides a lot of the `torch.Tensor` methods (e.g., `to`,
   `clone`), so this patch updates `TensorWithTFOverrideVariable.var_getattr`
   to support that.
2. it overrides the `shape` property, so this patch updates
   `TensorWithTFOverrideVariable.var_getattr` to support properties as well.
3. it has calls to `torch.Tensor.size`, which returns `torch.Size`,
   which gets reconstructed in `torch.Tensor.__torch_function__`, so
   this patch adds support for calling `torch.Size(...)` on non-constant
   inputs.
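
For reference, here is a condensed sketch of the pattern that now compiles. It is
adapted from the new `test_subclass_override_shape_and_to` test below; `MySubclass`
and `fn` are illustrative stand-ins, not the actual GGUF code:

    import torch

    class MySubclass(torch.Tensor):
        # (1) overrides a `torch.Tensor` method
        def to(self, *args, **kwargs):
            new = super().to(*args, **kwargs)
            new.tensor_shape = getattr(self, "tensor_shape", new.data.shape)
            return new

        # (2) overrides the `shape` property
        @property
        def shape(self):
            # (3) `self.size()` returns a `torch.Size`, which tracing rebuilds
            # by calling `torch.Size(...)` on non-constant inputs
            if not hasattr(self, "tensor_shape"):
                self.tensor_shape = self.size()
            return self.tensor_shape

    def fn(x):
        return x + 1, x.shape, x.to("cpu")

    x = torch.randn(2, 2).as_subclass(MySubclass)
    fn_opt = torch.compile(fn, backend="eager", fullgraph=True)
    with torch._dynamo.config.patch("traceable_tensor_subclasses", {MySubclass}):
        fn_opt(x)  # previously hit Unsupported; now traces without graph breaks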

Differential Revision: [D71906137](https://our.internmc.facebook.com/intern/diff/D71906137)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149484
Approved by: https://github.com/jansel, https://github.com/mlazos
ghstack dependencies: #149482, #149483
Author: Ryan Guo
Date: 2025-04-01 17:29:40 -07:00
Committed by: PyTorch MergeBot
Parent: 0d4dbfd9ed
Commit: 3463ea1059

4 changed files with 137 additions and 47 deletions

--- a/test/dynamo/test_subclasses.py
+++ b/test/dynamo/test_subclasses.py

@@ -757,26 +757,22 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
         class LocalSubclass(torch.Tensor):
             @classmethod
             def __torch_function__(cls, func, types, args=(), kwargs=None):
                 if kwargs is None:
                     kwargs = {}
                 return super().__torch_function__(func, types, args, kwargs)

             def sigmoid(self):
                 return None

-        @torch.compile(backend="eager", fullgraph=True)
         def fn(x):
             x.sigmoid()

-        msg = (
-            "Accessing overridden method/attribute sigmoid on a tensor"
-            " subclass with a __torch_function__ override is not supported"
-        )
-        with torch._dynamo.config.patch(
-            "traceable_tensor_subclasses", {LocalSubclass}
-        ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg):
-            x = torch.ones(2, 2).as_subclass(LocalSubclass)
-            fn(x)
+        x = torch.ones(2, 2).as_subclass(LocalSubclass)
+        fn_opt = compile_full_eager(fn)
+        with torch._dynamo.config.patch("traceable_tensor_subclasses", {LocalSubclass}):
+            res_exp = fn(x)
+            res_act = fn_opt(x)
+        self.assertEqual(res_exp, res_act)

     def test_user_overidden_attr_unsupported(self):
         class LocalSubclass(torch.Tensor):
@@ -792,10 +788,7 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
         def fn(x):
             return x.ndim

-        msg = (
-            "Accessing overridden method/attribute ndim on a tensor"
-            " subclass with a __torch_function__ override is not supported"
-        )
+        msg = "Currently only support accessing overridden attributes that are functions or properties, but got <class 'int'>"
         with torch._dynamo.config.patch(
             "traceable_tensor_subclasses", {LocalSubclass}
         ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg):
@@ -804,13 +797,11 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
     def test_user_overidden_property_unsupported(self):
         class LocalSubclass(torch.Tensor):
-            def __init__(self) -> None:
+            def __init__(self, *args, **kwargs) -> None:
                 self._ndim = 10

             @classmethod
             def __torch_function__(cls, func, types, args=(), kwargs=None):
                 if kwargs is None:
                     kwargs = {}
                 return super().__torch_function__(func, types, args, kwargs)

             @property
@@ -821,19 +812,17 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
             def ndim(self, value):
                 self._ndim = value

-        @torch.compile(backend="eager", fullgraph=True)
         def fn(x):
-            return x.ndim
+            return x + x.ndim

-        msg = (
-            "Accessing overridden method/attribute ndim on a tensor"
-            " subclass with a __torch_function__ override is not supported"
-        )
-        with torch._dynamo.config.patch(
-            "traceable_tensor_subclasses", {LocalSubclass}
-        ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg):
-            x = torch.ones(2, 2).as_subclass(LocalSubclass)
-            fn(x)
+        x = LocalSubclass(torch.ones(2, 2))
+        fn_opt = compile_full_eager(fn)
+        with torch._dynamo.config.patch("traceable_tensor_subclasses", {LocalSubclass}):
+            res_exp = fn(x)
+            res_act = fn_opt(x)
+        self.assertEqual(res_exp, res_act)

     def test_overridden_method_guarding(self):
         class LocalSubclass(torch.Tensor):
@@ -982,6 +971,88 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
         self.assertEqual(res_exp, res_act)
         self.assertEqual(x0, x1)

+    def test_subclass_override_shape_and_to(self):
+        # This is a slight variation of
+        # https://github.com/huggingface/diffusers/blob/fbf6b856cc61fd22ad8635547bff4aafe05723f3/src/diffusers/quantizers/gguf/utils.py#L398-L435
+        class MySubclass(torch.Tensor):
+            def to(self, *args, **kwargs):
+                new = super().to(*args, **kwargs)
+                new.tensor_shape = getattr(self, "tensor_shape", new.data.shape)
+                return new
+
+            @property
+            def shape(self):
+                if not hasattr(self, "tensor_shape"):
+                    self.tensor_shape = self.size()
+                return self.tensor_shape
+
+        def fn(x):
+            x_shape = x.shape
+            y = x.to("cpu")
+            return x + 1, y + 2, x_shape, x.tensor_shape, y.tensor_shape
+
+        with traceable_subclass(MySubclass):
+            x0 = torch.nn.Parameter(torch.randn(2, 2).as_subclass(MySubclass))
+            x1 = torch.nn.Parameter(x0.clone().as_subclass(MySubclass))
+            fn_opt = compile_full_eager(fn)
+
+            res_exp = fn(x0)
+            res_act = fn_opt(x1)
+            self.assertEqual(res_exp, res_act)
+            self.assertEqual(x0, x1)
+            self.assertEqual(x0.tensor_shape, x1.tensor_shape)
+
+    def test_subclass_dont_invoke_torch_function_on_overriden_method(self):
+        # We shouldn't fire `__torch_function__` for overridden tensor methods.
+        class MySubclass(torch.Tensor):
+            def to(self, device):
+                return self * len(device)
+
+            @classmethod
+            def __torch_function__(cls, func, types, args=(), kwargs=None):
+                if func is torch.Tensor.to:
+                    torch._dynamo.graph_break()
+                return super().__torch_function__(func, types, args, kwargs)
+
+        def fn(x):
+            return x.to("cpu")
+
+        with traceable_subclass(MySubclass):
+            x = torch.nn.Parameter(torch.randn(2, 2).as_subclass(MySubclass))
+            fn_opt = compile_full_eager(fn)
+
+            res_exp = fn(x)
+            res_act = fn_opt(x)
+            self.assertEqual(res_exp, res_act)
+
+    def test_subclass_dont_invoke_torch_function_on_overriden_attr(self):
+        from types import MethodWrapperType
+
+        # We shouldn't fire `__torch_function__` for overridden tensor attrs.
+        class MySubclass(torch.Tensor):
+            def ndim(self):
+                return 42
+
+            @classmethod
+            def __torch_function__(cls, func, types, args=(), kwargs=None):
+                if type(func) is MethodWrapperType and func.__name__ == "ndim":
+                    torch._dynamo.graph_break()
+                return super().__torch_function__(func, types, args, kwargs)
+
+        def fn(x):
+            return x + x.ndim()
+
+        with traceable_subclass(MySubclass):
+            x = torch.nn.Parameter(torch.randn(2, 2).as_subclass(MySubclass))
+            fn_opt = compile_full_eager(fn)
+
+            res_exp = fn(x)
+            res_act = fn_opt(x)
+            self.assertEqual(res_exp, res_act)
+
     def test_parameter_subclass_custom_torch_func_and_dynamic_attr(self):
         # This is a slight variation of
         # https://github.com/huggingface/diffusers/blob/fbf6b856cc61fd22ad8635547bff4aafe05723f3/src/diffusers/quantizers/gguf/utils.py#L398-L435

--- a/torch/_dynamo/variables/misc.py
+++ b/torch/_dynamo/variables/misc.py

@@ -32,7 +32,7 @@ import torch._C
 import torch._numpy as tnp
 import torch.utils._pytree as pytree

-from .. import config, variables
+from .. import config, trace_rules, variables
 from ..bytecode_transformation import create_call_function, create_instruction
 from ..create_parameter_op import do_not_convert_to_tracable_parameter
 from ..exc import raise_observed_exception, unimplemented, unimplemented_v2
@@ -297,6 +297,14 @@ class SuperVariable(VariableTracker):
                 tx.symbolic_torch_function_state.torch_function_subclass_enabled = (
                     tx_old
                 )
+        elif (
+            isinstance(inner_fn, types.MethodDescriptorType)
+            and inner_fn in trace_rules.get_tensor_method()
+        ):
+            # FunctionType but implementation is in C, we support some of these,
+            # e.g., tensor ops like `torch.Tensor.to`.
+            fn_var = VariableTracker.build(tx, inner_fn, source)
+            return fn_var.call_function(tx, [self.objvar] + args, kwargs)
         unimplemented(f"non-function or method super: {inner_fn}")
@@ -669,11 +677,10 @@ class AutogradFunctionVariable(VariableTracker):
         args: "list[VariableTracker]",
         kwargs: "dict[str, VariableTracker]",
     ):
-        from ..trace_rules import is_callable_allowed
         from .builder import wrap_fx_proxy

         if name == "apply":
-            if is_callable_allowed(self.fn_cls):
+            if trace_rules.is_callable_allowed(self.fn_cls):
                 trampoline_autograd_apply = produce_trampoline_autograd_apply(
                     self.fn_cls
                 )
@@ -691,8 +698,6 @@ class AutogradFunctionVariable(VariableTracker):
         elif name == "backward":
             return self.call_backward(tx, args, kwargs)
         else:
-            from .. import trace_rules
-
             source = AttrSource(self.source, name) if self.source is not None else None
             try:
                 obj = inspect.getattr_static(self.fn_cls, name)

--- a/torch/_dynamo/variables/torch_function.py
+++ b/torch/_dynamo/variables/torch_function.py

@@ -597,8 +597,9 @@ class TensorWithTFOverrideVariable(TensorVariable):
         # This simulates shallow-copying the tensor object.
         kwargs = dict(tensor_var.__dict__)
-        assert kwargs.pop("class_type") is torch.Tensor, (
-            "invalid class type in TensorWithTFOverrideVariable.from_tensor_var"
-        )
+        input_tensor_type = kwargs.pop("class_type")
+        assert input_tensor_type in (torch.Tensor, torch.nn.Parameter), (
+            f"invalid class type {input_tensor_type} in TensorWithTFOverrideVariable.from_tensor_var"
+        )
         torch_fn_var = build_torch_function_fn(tx, class_type, cls_source)
         var = cls(torch_function_fn=torch_fn_var, class_type=class_type, **kwargs)
@@ -638,13 +639,9 @@ class TensorWithTFOverrideVariable(TensorVariable):
                 f"Accessing {name} on a tensor subclass with a __torch_function__ override is not supported"
             )

-        if hasattr(torch.Tensor, name):
-            if _is_attr_overidden(tx, self, name):
-                unimplemented(
-                    f"Accessing overridden method/attribute {name} on a tensor"
-                    " subclass with a __torch_function__ override is not supported"
-                )
+        # Handle non-overridden attributes inherited from `torch.Tensor`.
+        attr_is_overriden = _is_attr_overidden(tx, self, name)
+        if hasattr(torch.Tensor, name) and not attr_is_overriden:
             if tx.output.torch_function_enabled:
                 if self.source:
                     install_guard(
@@ -674,11 +671,23 @@ class TensorWithTFOverrideVariable(TensorVariable):
             else:
                 import types

+                cls_source = GlobalSource(self.global_mangled_class_name(tx))
+                attr_source = AttrSource(cls_source, name)
                 if isinstance(attr, types.FunctionType):
-                    cls_source = GlobalSource(self.global_mangled_class_name(tx))
-                    func_source = AttrSource(cls_source, name)
-                    install_guard(func_source.make_guard(GuardBuilder.FUNCTION_MATCH))
+                    install_guard(attr_source.make_guard(GuardBuilder.FUNCTION_MATCH))
                     return UserMethodVariable(attr, self)
+                elif isinstance(attr, property):
+                    getter_source = AttrSource(attr_source, "fget")
+                    getter = attr.fget
+                    getter_var = UserMethodVariable(getter, self, source=getter_source)
+                    return getter_var.call_function(tx, [], {})
+                elif attr_is_overriden:
+                    unimplemented(
+                        f"Currently only support accessing overridden attributes that are functions or properties, but got {type(attr)}"  # noqa: B950
+                    )

         return super().var_getattr(tx, name)

     def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs):

--- a/torch/_dynamo/variables/user_defined.py
+++ b/torch/_dynamo/variables/user_defined.py

@@ -82,6 +82,7 @@ from ..utils import (
 )
 from .base import AttributeMutationExisting, ValueMutationNew, VariableTracker
 from .dicts import DefaultDictVariable
+from .lists import SizeVariable

 try:
@@ -579,6 +580,10 @@ class UserDefinedClassVariable(UserDefinedVariable):
             assert all(x is not None for x in items)
             return variables.NamedTupleVariable(items, self.value)
+        elif self.value is torch.Size:
+            # This simulates `THPSize_pynew`, the C impl for `Size.__new__`.
+            tup = variables.BuiltinVariable(tuple).call_function(tx, args, kwargs)
+            return SizeVariable(tup.items)
         elif is_frozen_dataclass(self.value) and self.is_standard_new():
             fields = dataclasses.fields(self.value)
             items = list(args)