[dynamo] Support tensor subclass with overridden tensor methods and properties (#149484)
This fixes most of the "torch.compile X tensor-subclass" issues
encountered in https://github.com/city96/ComfyUI-GGUF/issues/118. The
relevant tensor subclass definition is here:
298192ed60/ops.py (L18-L65).
A few things to note about the tensor subclass:
1. it overrides many of the `torch.Tensor` methods (e.g., `to`,
`clone`), so this patch updates `TensorWithTFOverrideVariable.var_getattr`
to support overridden methods.
2. it overrides the `shape` property, so this patch updates
`TensorWithTFOverrideVariable.var_getattr` to support properties as well.
3. it calls `torch.Tensor.size`, which returns a `torch.Size` that gets
reconstructed in `torch.Tensor.__torch_function__`, so this patch adds
support for calling `torch.Size(...)` on non-constant inputs.
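
To make the three points concrete, here is a minimal sketch of such a subclass (illustrative only — the class name and the `tensor_shape` attribute are made up for this example, not taken from the linked ops.py) that this patch allows to compile:

    import torch

    class GGUFLikeTensor(torch.Tensor):
        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            if kwargs is None:
                kwargs = {}
            return super().__torch_function__(func, types, args, kwargs)

        def to(self, *args, **kwargs):
            # Case 1: overridden tensor method.
            new = super().to(*args, **kwargs)
            new.tensor_shape = getattr(self, "tensor_shape", new.size())
            return new

        @property
        def shape(self):
            # Case 2: overridden `shape` property. It calls `size()`, whose
            # `torch.Size` result is rebuilt inside `__torch_function__` (case 3).
            return torch.Size(self.size())

    def fn(x):
        return x.to("cpu") + x.shape[0]

    x = torch.randn(3, 3).as_subclass(GGUFLikeTensor)
    with torch._dynamo.config.patch("traceable_tensor_subclasses", {GGUFLikeTensor}):
        out = torch.compile(fn, backend="eager", fullgraph=True)(x)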
Differential Revision: [D71906137](https://our.internmc.facebook.com/intern/diff/D71906137)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149484
Approved by: https://github.com/jansel, https://github.com/mlazos
ghstack dependencies: #149482, #149483
This commit is contained in:
parent 0d4dbfd9ed
commit 3463ea1059
@@ -757,26 +757,22 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
         class LocalSubclass(torch.Tensor):
             @classmethod
             def __torch_function__(cls, func, types, args=(), kwargs=None):
                 if kwargs is None:
                     kwargs = {}
                 return super().__torch_function__(func, types, args, kwargs)
 
             def sigmoid(self):
                 return None
 
-        @torch.compile(backend="eager", fullgraph=True)
         def fn(x):
             x.sigmoid()
 
-        msg = (
-            "Accessing overridden method/attribute sigmoid on a tensor"
-            " subclass with a __torch_function__ override is not supported"
-        )
-        with torch._dynamo.config.patch(
-            "traceable_tensor_subclasses", {LocalSubclass}
-        ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg):
-            x = torch.ones(2, 2).as_subclass(LocalSubclass)
-            fn(x)
+        x = torch.ones(2, 2).as_subclass(LocalSubclass)
+        fn_opt = compile_full_eager(fn)
+
+        with torch._dynamo.config.patch("traceable_tensor_subclasses", {LocalSubclass}):
+            res_exp = fn(x)
+            res_act = fn_opt(x)
+
+        self.assertEqual(res_exp, res_act)
 
     def test_user_overidden_attr_unsupported(self):
         class LocalSubclass(torch.Tensor):
@@ -792,10 +788,7 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
         def fn(x):
             return x.ndim
 
-        msg = (
-            "Accessing overridden method/attribute ndim on a tensor"
-            " subclass with a __torch_function__ override is not supported"
-        )
+        msg = "Currently only support accessing overridden attributes that are functions or properties, but got <class 'int'>"
         with torch._dynamo.config.patch(
             "traceable_tensor_subclasses", {LocalSubclass}
         ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg):
@@ -804,13 +797,11 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
 
     def test_user_overidden_property_unsupported(self):
         class LocalSubclass(torch.Tensor):
-            def __init__(self) -> None:
+            def __init__(self, *args, **kwargs) -> None:
                 self._ndim = 10
 
             @classmethod
             def __torch_function__(cls, func, types, args=(), kwargs=None):
                 if kwargs is None:
                     kwargs = {}
                 return super().__torch_function__(func, types, args, kwargs)
 
             @property
@@ -821,19 +812,17 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
             def ndim(self, value):
                 self._ndim = value
 
-        @torch.compile(backend="eager", fullgraph=True)
         def fn(x):
-            return x.ndim
+            return x + x.ndim
 
-        msg = (
-            "Accessing overridden method/attribute ndim on a tensor"
-            " subclass with a __torch_function__ override is not supported"
-        )
-        with torch._dynamo.config.patch(
-            "traceable_tensor_subclasses", {LocalSubclass}
-        ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg):
-            x = torch.ones(2, 2).as_subclass(LocalSubclass)
-            fn(x)
+        x = LocalSubclass(torch.ones(2, 2))
+        fn_opt = compile_full_eager(fn)
+
+        with torch._dynamo.config.patch("traceable_tensor_subclasses", {LocalSubclass}):
+            res_exp = fn(x)
+            res_act = fn_opt(x)
+
+        self.assertEqual(res_exp, res_act)
 
     def test_overridden_method_guarding(self):
         class LocalSubclass(torch.Tensor):
@@ -982,6 +971,88 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
         self.assertEqual(res_exp, res_act)
         self.assertEqual(x0, x1)
 
+    def test_subclass_override_shape_and_to(self):
+        # This is a slight variation of
+        # https://github.com/huggingface/diffusers/blob/fbf6b856cc61fd22ad8635547bff4aafe05723f3/src/diffusers/quantizers/gguf/utils.py#L398-L435
+        class MySubclass(torch.Tensor):
+            def to(self, *args, **kwargs):
+                new = super().to(*args, **kwargs)
+                new.tensor_shape = getattr(self, "tensor_shape", new.data.shape)
+                return new
+
+            @property
+            def shape(self):
+                if not hasattr(self, "tensor_shape"):
+                    self.tensor_shape = self.size()
+                return self.tensor_shape
+
+        def fn(x):
+            x_shape = x.shape
+            y = x.to("cpu")
+            return x + 1, y + 2, x_shape, x.tensor_shape, y.tensor_shape
+
+        with traceable_subclass(MySubclass):
+            x0 = torch.nn.Parameter(torch.randn(2, 2).as_subclass(MySubclass))
+            x1 = torch.nn.Parameter(x0.clone().as_subclass(MySubclass))
+
+            fn_opt = compile_full_eager(fn)
+
+            res_exp = fn(x0)
+            res_act = fn_opt(x1)
+            self.assertEqual(res_exp, res_act)
+            self.assertEqual(x0, x1)
+            self.assertEqual(x0.tensor_shape, x1.tensor_shape)
+
+    def test_subclass_dont_invoke_torch_function_on_overriden_method(self):
+        # We shouldn't fire `__torch_function__` for overridden tensor methods.
+        class MySubclass(torch.Tensor):
+            def to(self, device):
+                return self * len(device)
+
+            @classmethod
+            def __torch_function__(cls, func, types, args=(), kwargs=None):
+                if func is torch.Tensor.to:
+                    torch._dynamo.graph_break()
+                return super().__torch_function__(func, types, args, kwargs)
+
+        def fn(x):
+            return x.to("cpu")
+
+        with traceable_subclass(MySubclass):
+            x = torch.nn.Parameter(torch.randn(2, 2).as_subclass(MySubclass))
+
+            fn_opt = compile_full_eager(fn)
+
+            res_exp = fn(x)
+            res_act = fn_opt(x)
+            self.assertEqual(res_exp, res_act)
+
+    def test_subclass_dont_invoke_torch_function_on_overriden_attr(self):
+        from types import MethodWrapperType
+
+        # We shouldn't fire `__torch_function__` for overridden tensor attrs.
+        class MySubclass(torch.Tensor):
+            def ndim(self):
+                return 42
+
+            @classmethod
+            def __torch_function__(cls, func, types, args=(), kwargs=None):
+                if type(func) is MethodWrapperType and func.__name__ == "ndim":
+                    torch._dynamo.graph_break()
+                return super().__torch_function__(func, types, args, kwargs)
+
+        def fn(x):
+            return x + x.ndim()
+
+        with traceable_subclass(MySubclass):
+            x = torch.nn.Parameter(torch.randn(2, 2).as_subclass(MySubclass))
+
+            fn_opt = compile_full_eager(fn)
+
+            res_exp = fn(x)
+            res_act = fn_opt(x)
+            self.assertEqual(res_exp, res_act)
+
     def test_parameter_subclass_custom_torch_func_and_dynamic_attr(self):
         # This is a slight variation of
         # https://github.com/huggingface/diffusers/blob/fbf6b856cc61fd22ad8635547bff4aafe05723f3/src/diffusers/quantizers/gguf/utils.py#L398-L435
@@ -32,7 +32,7 @@ import torch._C
 import torch._numpy as tnp
 import torch.utils._pytree as pytree
 
-from .. import config, variables
+from .. import config, trace_rules, variables
 from ..bytecode_transformation import create_call_function, create_instruction
 from ..create_parameter_op import do_not_convert_to_tracable_parameter
 from ..exc import raise_observed_exception, unimplemented, unimplemented_v2
@@ -297,6 +297,14 @@ class SuperVariable(VariableTracker):
             tx.symbolic_torch_function_state.torch_function_subclass_enabled = (
                 tx_old
             )
+        elif (
+            isinstance(inner_fn, types.MethodDescriptorType)
+            and inner_fn in trace_rules.get_tensor_method()
+        ):
+            # FunctionType but implementation is in C, we support some of these,
+            # e.g., tensor ops like `torch.Tensor.to`.
+            fn_var = VariableTracker.build(tx, inner_fn, source)
+            return fn_var.call_function(tx, [self.objvar] + args, kwargs)
 
         unimplemented(f"non-function or method super: {inner_fn}")
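
For illustration, the kind of user code this branch unlocks (a hedged sketch; the subclass below is made up for this example, not taken from this PR's tests): an overridden `to` that defers to `super().to(...)`, i.e. to the C-implemented method descriptor `torch.Tensor.to`, can now be traced instead of hitting the `unimplemented` fallback:

    import torch

    class MySub(torch.Tensor):
        def to(self, *args, **kwargs):
            # `super().to` is a MethodDescriptorType found in
            # trace_rules.get_tensor_method(), so dynamo can now build a
            # variable for it and trace the call.
            return super().to(*args, **kwargs)

    def fn(x):
        return x.to("cpu") + 1

    with torch._dynamo.config.patch("traceable_tensor_subclasses", {MySub}):
        x = torch.ones(2).as_subclass(MySub)
        out = torch.compile(fn, backend="eager", fullgraph=True)(x)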
@@ -669,11 +677,10 @@ class AutogradFunctionVariable(VariableTracker):
         args: "list[VariableTracker]",
         kwargs: "dict[str, VariableTracker]",
     ):
-        from ..trace_rules import is_callable_allowed
         from .builder import wrap_fx_proxy
 
         if name == "apply":
-            if is_callable_allowed(self.fn_cls):
+            if trace_rules.is_callable_allowed(self.fn_cls):
                 trampoline_autograd_apply = produce_trampoline_autograd_apply(
                     self.fn_cls
                 )
@@ -691,8 +698,6 @@ class AutogradFunctionVariable(VariableTracker):
         elif name == "backward":
             return self.call_backward(tx, args, kwargs)
         else:
-            from .. import trace_rules
-
             source = AttrSource(self.source, name) if self.source is not None else None
             try:
                 obj = inspect.getattr_static(self.fn_cls, name)
@@ -597,8 +597,9 @@ class TensorWithTFOverrideVariable(TensorVariable):
 
         # This simulates shallow-copying the tensor object.
         kwargs = dict(tensor_var.__dict__)
-        assert kwargs.pop("class_type") is torch.Tensor, (
-            "invalid class type in TensorWithTFOverrideVariable.from_tensor_var"
+        input_tensor_type = kwargs.pop("class_type")
+        assert input_tensor_type in (torch.Tensor, torch.nn.Parameter), (
+            f"invalid class type {input_tensor_type} in TensorWithTFOverrideVariable.from_tensor_var"
         )
         torch_fn_var = build_torch_function_fn(tx, class_type, cls_source)
         var = cls(torch_function_fn=torch_fn_var, class_type=class_type, **kwargs)
@@ -638,13 +639,9 @@ class TensorWithTFOverrideVariable(TensorVariable):
                 f"Accessing {name} on a tensor subclass with a __torch_function__ override is not supported"
             )
 
-        if hasattr(torch.Tensor, name):
-            if _is_attr_overidden(tx, self, name):
-                unimplemented(
-                    f"Accessing overridden method/attribute {name} on a tensor"
-                    " subclass with a __torch_function__ override is not supported"
-                )
-
+        # Handle non-overridden attributes inherited from `torch.Tensor`.
+        attr_is_overriden = _is_attr_overidden(tx, self, name)
+        if hasattr(torch.Tensor, name) and not attr_is_overriden:
             if tx.output.torch_function_enabled:
                 if self.source:
                     install_guard(
@@ -674,11 +671,23 @@ class TensorWithTFOverrideVariable(TensorVariable):
         else:
             import types
 
+            cls_source = GlobalSource(self.global_mangled_class_name(tx))
+            attr_source = AttrSource(cls_source, name)
             if isinstance(attr, types.FunctionType):
-                cls_source = GlobalSource(self.global_mangled_class_name(tx))
-                func_source = AttrSource(cls_source, name)
-                install_guard(func_source.make_guard(GuardBuilder.FUNCTION_MATCH))
+                install_guard(attr_source.make_guard(GuardBuilder.FUNCTION_MATCH))
                 return UserMethodVariable(attr, self)
+
+            elif isinstance(attr, property):
+                getter_source = AttrSource(attr_source, "fget")
+                getter = attr.fget
+                getter_var = UserMethodVariable(getter, self, source=getter_source)
+                return getter_var.call_function(tx, [], {})
+
+            elif attr_is_overriden:
+                unimplemented(
+                    f"Currently only support accessing overridden attributes that are functions or properties, but got {type(attr)}"  # noqa: B950
+                )
 
             return super().var_getattr(tx, name)
 
     def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs):
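
In plain Python terms, the new property branch turns an attribute read into a call of the property's `fget` (a sketch of the semantics being simulated, not dynamo code):

    class C:
        @property
        def ndim(self):
            return 42

    c = C()
    attr = type(c).__dict__["ndim"]  # the property object, as getattr_static sees it
    assert isinstance(attr, property)
    print(attr.fget(c))  # 42 -- what dynamo inlines as a method call on the tensor variable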
@@ -82,6 +82,7 @@ from ..utils import (
 )
 from .base import AttributeMutationExisting, ValueMutationNew, VariableTracker
 from .dicts import DefaultDictVariable
+from .lists import SizeVariable
 
 
 try:
@@ -579,6 +580,10 @@ class UserDefinedClassVariable(UserDefinedVariable):
             assert all(x is not None for x in items)
 
             return variables.NamedTupleVariable(items, self.value)
+        elif self.value is torch.Size:
+            # This simulates `THPSize_pynew`, the C impl for `Size.__new__`.
+            tup = variables.BuiltinVariable(tuple).call_function(tx, args, kwargs)
+            return SizeVariable(tup.items)
         elif is_frozen_dataclass(self.value) and self.is_standard_new():
             fields = dataclasses.fields(self.value)
             items = list(args)
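
A usage-level sketch of what this enables (assumed behavior, mirroring the branch above): constructing `torch.Size` from a traced, non-constant value inside a compiled region:

    import torch

    def fn(x):
        # `x.size()` is not a compile-time constant; `torch.Size(...)` is now
        # simulated as tuple(...) wrapped in a SizeVariable.
        return torch.Size(x.size())

    opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
    print(opt_fn(torch.randn(2, 3)))  # torch.Size([2, 3])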