fix regression which creates a new fake tensor (#111864)

Fixes regression identified here: ccd6b373b5 (r1369334484)

Now that `get_fake_value` will identify aliases, we should not try to wrap the fake value again.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111864
Approved by: https://github.com/eellison
This commit is contained in:
Jon Chuang 2023-10-24 05:11:43 +00:00 committed by PyTorch MergeBot
parent 0e0f6a248d
commit 6d78f34a06
3 changed files with 6 additions and 11 deletions

View File

@ -2221,10 +2221,10 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
mod = Mod()
foo(mod, torch.rand([4]))
self.assertEqual(compiles_without_buffers, 1)
self.assertEqual(compiles_without_buffers, 0)
foo(mod, torch.rand([4], dtype=torch.half))
self.assertEqual(compiles_without_buffers, 2)
self.assertEqual(compiles_without_buffers, 1)
class Mod2(Mod):
def __setattr__(self, name, value):
@ -2232,7 +2232,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
foo(Mod2(), torch.rand([4]))
# causes two compilations, bc unimplemented custom setattr
self.assertTrue(compiles_without_buffers >= 4)
self.assertTrue(compiles_without_buffers >= 2)
def test_unspec_non_inlinable_module(self):
mod = UnspecNonInlinableModule()

View File

@ -1360,7 +1360,7 @@ def get_fake_value(node, tx):
op = node.op
# FX Node should always return the same value
# FX Node should always return the same fake value
if "example_value" in node.meta and is_fake(node.meta["example_value"]):
return node.meta["example_value"]

View File

@ -1238,13 +1238,8 @@ class BuiltinVariable(VariableTracker):
getattr_var = None
if isinstance(getattr_var, variables.TensorVariable):
# get_fake_val will return a real tensor here because it's an attribute on the module (get_attr node)
existing_attr = get_fake_value(getattr_var.as_proxy().node, tx)
existing_fake_attr = (
variables.builder.wrap_to_fake_tensor_and_record(
existing_attr, tx, source=getattr_var.source, is_tensor=True
)
)
# get_fake_val will get the same fake tensor
existing_fake_attr = get_fake_value(getattr_var.as_proxy().node, tx)
# same tensor identiy, setattr is a no-op
mod_setattr = inspect.getattr_static(obj.module_type, "__setattr__")