[dynamo] Fix bug in hasattr(tensor, "size") (#152883)

Fixes https://github.com/pytorch/pytorch/issues/135696

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152883
Approved by: https://github.com/StrongerXi
Animesh Jain 2025-05-07 12:37:36 -07:00 committed by PyTorch MergeBot
parent 834bc5e414
commit 6f6fac6a41
2 changed files with 20 additions and 0 deletions
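
(Context, not part of the commit.) A standalone sketch of the hasattr-on-tensor pattern this PR targets, adapted from the test added below; per the code comment in the second hunk, the pre-fix behavior was an accuracy issue, i.e. the compiled result could disagree with eager:

import torch

def fn(x):
    # A plain tensor exposes both of these built-in attributes, so both
    # branches should run under torch.compile exactly as they do in eager.
    if hasattr(x, "size"):
        x = x * 2
    if hasattr(x, "stride"):
        x = x * 3
    return x * 5

x = torch.ones(4)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
assert torch.equal(fn(x), opt_fn(x))  # holds once this fix is applied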

@@ -6832,6 +6832,19 @@ class ReproTestsDevice(torch._dynamo.test_case.TestCase):
        out2, _ = torch.compile(moe_mlp, backend="eager")(x)
        self.assertEqual(out1, out2)

    def test_tensor_size_hasattr(self):
        def fn(x):
            if hasattr(x, "size"):
                x = x * 2
            if hasattr(x, "stride"):
                x = x * 3
            return x * 5

        x = torch.ones(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(fn(x), opt_fn(x))

    @requires_cuda
    def test_memleak_when_graph_input_has_tensor_attr(self, device):
        @torch.compile(backend="eager")

@@ -394,6 +394,13 @@ class TensorVariable(VariableTracker):
        from . import GetAttrVariable
        from .builtin import BuiltinVariable

        # TODO - This is not a good solution but solves an accuracy issue.
        # Today, var_getattr returns GetAttrVariable for both non-existent
        # attributes and existing attributes. This is a bug and requires a
        # deeper dive.
        if name in ("size", "stride"):
            return ConstantVariable(True)

        try:
            var = BuiltinVariable(getattr).call_function(
                tx, [self, ConstantVariable(name)], {}
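
For context on the TODO above (illustration only, not Dynamo's implementation): in plain Python, hasattr is essentially getattr wrapped in a try/except, so a tracer that models getattr on tensors has to distinguish attributes that exist from ones that do not in order to answer hasattr correctly; returning a GetAttrVariable in both cases loses that distinction, which is why "size" and "stride" are special-cased to a constant True here. The helper name below is made up for the sketch.

import torch

def hasattr_via_getattr(obj, name):
    # Mirrors the semantics of the built-in hasattr: an existing attribute
    # means getattr succeeds, a missing one raises AttributeError.
    try:
        getattr(obj, name)
        return True
    except AttributeError:
        return False

x = torch.ones(4)
assert hasattr_via_getattr(x, "size") is True           # real tensor method
assert hasattr_via_getattr(x, "not_an_attr") is False   # missing attribute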