mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo] Fix bug in hasattr(tensor, "size") (#152883)
Fixes https://github.com/pytorch/pytorch/issues/135696 Pull Request resolved: https://github.com/pytorch/pytorch/pull/152883 Approved by: https://github.com/StrongerXi
This commit is contained in:
parent
834bc5e414
commit
6f6fac6a41
|
|
@ -6832,6 +6832,19 @@ class ReproTestsDevice(torch._dynamo.test_case.TestCase):
|
|||
out2, _ = torch.compile(moe_mlp, backend="eager")(x)
|
||||
self.assertEqual(out1, out2)
|
||||
|
||||
def test_tensor_size_hasattr(self):
|
||||
def fn(x):
|
||||
if hasattr(x, "size"):
|
||||
x = x * 2
|
||||
if hasattr(x, "stride"):
|
||||
x = x * 3
|
||||
return x * 5
|
||||
|
||||
x = torch.ones(4)
|
||||
|
||||
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||
self.assertEqual(fn(x), opt_fn(x))
|
||||
|
||||
@requires_cuda
|
||||
def test_memleak_when_graph_input_has_tensor_attr(self, device):
|
||||
@torch.compile(backend="eager")
|
||||
|
|
|
|||
|
|
@ -394,6 +394,13 @@ class TensorVariable(VariableTracker):
|
|||
from . import GetAttrVariable
|
||||
from .builtin import BuiltinVariable
|
||||
|
||||
# TODO - This is not a good solution but solves an accuracy issue.
|
||||
# Today, var_getattr returns GetAttrVariable for both non-existent
|
||||
# attributes and existing attributes. This is a bug and requires more
|
||||
# deep dive.
|
||||
if name in ("size", "stride"):
|
||||
return ConstantVariable(True)
|
||||
|
||||
try:
|
||||
var = BuiltinVariable(getattr).call_function(
|
||||
tx, [self, ConstantVariable(name)], {}
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user