mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo] Harden torch function dispatchability check for attributes and methods access (#153082)
See more details in https://github.com/pytorch/pytorch/issues/151771#issuecomment-2836372110. Fixes #151771. Differential Revision: [D74342291](https://our.internmc.facebook.com/intern/diff/D74342291) Pull Request resolved: https://github.com/pytorch/pytorch/pull/153082 Approved by: https://github.com/mlazos
This commit is contained in:
parent
c227865720
commit
18e13a67ce
|
|
@ -1041,6 +1041,24 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
|
|||
pass
|
||||
|
||||
def fn(x):
|
||||
x = x.t()
|
||||
x = x.T
|
||||
return x + 1
|
||||
|
||||
fn_opt = compile_full_eager(fn)
|
||||
|
||||
x = torch.randn(2, 2).as_subclass(MySubclass)
|
||||
res_exp = fn(x)
|
||||
res_act = fn_opt(x)
|
||||
self.assertEqual(res_exp, res_act)
|
||||
|
||||
def test_subclass_with_disabled_torch_function(self):
|
||||
class MySubclass(torch.Tensor):
|
||||
__torch_function__ = torch._C._disabled_torch_function_impl
|
||||
|
||||
def fn(x):
|
||||
x = x.t()
|
||||
x = x.T
|
||||
return x + 1
|
||||
|
||||
fn_opt = compile_full_eager(fn)
|
||||
|
|
|
|||
|
|
@ -643,7 +643,8 @@ class TensorWithTFOverrideVariable(TensorVariable):
|
|||
# Handle non-overriden attributes inherited from `torch.Tensor`.
|
||||
attr_is_overriden = _is_attr_overidden(tx, self, name)
|
||||
if hasattr(torch.Tensor, name) and not attr_is_overriden:
|
||||
if tx.symbolic_torch_function_state.torch_function_subclass_enabled:
|
||||
args, kwargs = [self], {}
|
||||
if can_dispatch_torch_function(tx, args, kwargs):
|
||||
if self.source:
|
||||
install_guard(
|
||||
AttrSource(
|
||||
|
|
@ -656,8 +657,8 @@ class TensorWithTFOverrideVariable(TensorVariable):
|
|||
tx,
|
||||
get_fn,
|
||||
TupleVariable([self.class_type_var(tx)]),
|
||||
[self],
|
||||
{},
|
||||
args,
|
||||
kwargs,
|
||||
)
|
||||
else:
|
||||
# `TensorVariable.var_getattr` doesn't handle user-defined
|
||||
|
|
@ -726,7 +727,8 @@ class TensorWithTFOverrideVariable(TensorVariable):
|
|||
) -> "VariableTracker":
|
||||
# This code block implements inlining the __torch_function__ override
|
||||
# of `call_method`.
|
||||
if tx.symbolic_torch_function_state.torch_function_subclass_enabled:
|
||||
tf_args = [self] + args
|
||||
if can_dispatch_torch_function(tx, tf_args, kwargs):
|
||||
import torch
|
||||
|
||||
if _is_attr_overidden(tx, self, name):
|
||||
|
|
@ -752,6 +754,6 @@ class TensorWithTFOverrideVariable(TensorVariable):
|
|||
source = None
|
||||
value = getattr(torch.Tensor, name)
|
||||
func_var = VariableTracker.build(tx, value, source)
|
||||
return dispatch_torch_function(tx, func_var, [self] + args, kwargs)
|
||||
return dispatch_torch_function(tx, func_var, tf_args, kwargs)
|
||||
else:
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user