[dynamo] Harden torch function dispatchability check for attributes and methods access (#153082)

See more details in
https://github.com/pytorch/pytorch/issues/151771#issuecomment-2836372110.

Fixes #151771.
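
For context, a minimal repro sketch of the failure mode, adapted from the regression test added in this PR. The `torch.compile(backend="eager", fullgraph=True)` call is an assumed stand-in for the test helper `compile_full_eager`:

import torch

class MySubclass(torch.Tensor):
    # The subclass opts out of __torch_function__ dispatch entirely.
    __torch_function__ = torch._C._disabled_torch_function_impl

def fn(x):
    x = x.t()  # method call resolved on the subclass
    x = x.T    # attribute access inherited from torch.Tensor
    return x + 1

# Assumed equivalent of the test helper `compile_full_eager`.
fn_opt = torch.compile(fn, backend="eager", fullgraph=True)

x = torch.randn(2, 2).as_subclass(MySubclass)
res_exp = fn(x)      # eager reference
res_act = fn_opt(x)  # before this fix, dynamo still tried to dispatch these
                     # accesses through the (disabled) __torch_function__
torch.testing.assert_close(res_exp, res_act)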

Differential Revision: [D74342291](https://our.internmc.facebook.com/intern/diff/D74342291)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153082
Approved by: https://github.com/mlazos

commit 18e13a67ce
parent c227865720
Author:    Ryan Guo
Date:      2025-05-07 11:54:22 -07:00
Committer: PyTorch MergeBot
2 changed files with 25 additions and 5 deletions

test/dynamo/test_subclasses.py:

@@ -1041,6 +1041,24 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
            pass

        def fn(x):
            x = x.t()
            x = x.T
            return x + 1

        fn_opt = compile_full_eager(fn)

        x = torch.randn(2, 2).as_subclass(MySubclass)
        res_exp = fn(x)
        res_act = fn_opt(x)
        self.assertEqual(res_exp, res_act)

    def test_subclass_with_disabled_torch_function(self):
        class MySubclass(torch.Tensor):
            __torch_function__ = torch._C._disabled_torch_function_impl

        def fn(x):
            x = x.t()
            x = x.T
            return x + 1

        fn_opt = compile_full_eager(fn)

torch/_dynamo/variables/torch_function.py:

@@ -643,7 +643,8 @@ class TensorWithTFOverrideVariable(TensorVariable):
         # Handle non-overriden attributes inherited from `torch.Tensor`.
         attr_is_overriden = _is_attr_overidden(tx, self, name)
         if hasattr(torch.Tensor, name) and not attr_is_overriden:
-            if tx.symbolic_torch_function_state.torch_function_subclass_enabled:
+            args, kwargs = [self], {}
+            if can_dispatch_torch_function(tx, args, kwargs):
                 if self.source:
                     install_guard(
                         AttrSource(
@@ -656,8 +657,8 @@ class TensorWithTFOverrideVariable(TensorVariable):
                     tx,
                     get_fn,
                     TupleVariable([self.class_type_var(tx)]),
-                    [self],
-                    {},
+                    args,
+                    kwargs,
                 )
             else:
                 # `TensorVariable.var_getattr` doesn't handle user-defined
@@ -726,7 +727,8 @@ class TensorWithTFOverrideVariable(TensorVariable):
     ) -> "VariableTracker":
         # This code block implements inlining the __torch_function__ override
         # of `call_method`.
-        if tx.symbolic_torch_function_state.torch_function_subclass_enabled:
+        tf_args = [self] + args
+        if can_dispatch_torch_function(tx, tf_args, kwargs):
             import torch

             if _is_attr_overidden(tx, self, name):
@@ -752,6 +754,6 @@ class TensorWithTFOverrideVariable(TensorVariable):
                 source = None
                 value = getattr(torch.Tensor, name)
             func_var = VariableTracker.build(tx, value, source)
-            return dispatch_torch_function(tx, func_var, [self] + args, kwargs)
+            return dispatch_torch_function(tx, func_var, tf_args, kwargs)
         else:
             return super().call_method(tx, name, args, kwargs)
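
For reference, a rough sketch of the dispatchability check this change mirrors, written against the public torch.overrides API in eager mode. This illustrates the concept only and is not the dynamo-internal implementation of can_dispatch_torch_function:

import torch
from torch.overrides import has_torch_function, handle_torch_function

def get_tensor_attr(tensor, name):
    # Only route the attribute get through __torch_function__ when the
    # argument is actually dispatchable; a subclass that sets
    # __torch_function__ = torch._C._disabled_torch_function_impl is not.
    getter = getattr(torch.Tensor, name).__get__
    if has_torch_function((tensor,)):
        return handle_torch_function(getter, (tensor,), tensor)
    return getter(tensor)

With the disabled impl, has_torch_function reports False and the plain getter runs directly, which appears to be the behavior the hardened check restores for attribute and method access during compilation.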