mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
fix regression which creates a new fake tensor (#111864)
Fixes regression identified here: ccd6b373b5 (r1369334484)
Now that `get_fake_value` will identify aliases, we should not try to wrap the fake value again.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111864
Approved by: https://github.com/eellison
This commit is contained in:
parent
0e0f6a248d
commit
6d78f34a06
|
|
@ -2221,10 +2221,10 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
|
||||||
|
|
||||||
mod = Mod()
|
mod = Mod()
|
||||||
foo(mod, torch.rand([4]))
|
foo(mod, torch.rand([4]))
|
||||||
self.assertEqual(compiles_without_buffers, 1)
|
self.assertEqual(compiles_without_buffers, 0)
|
||||||
|
|
||||||
foo(mod, torch.rand([4], dtype=torch.half))
|
foo(mod, torch.rand([4], dtype=torch.half))
|
||||||
self.assertEqual(compiles_without_buffers, 2)
|
self.assertEqual(compiles_without_buffers, 1)
|
||||||
|
|
||||||
class Mod2(Mod):
|
class Mod2(Mod):
|
||||||
def __setattr__(self, name, value):
|
def __setattr__(self, name, value):
|
||||||
|
|
@ -2232,7 +2232,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
|
||||||
|
|
||||||
foo(Mod2(), torch.rand([4]))
|
foo(Mod2(), torch.rand([4]))
|
||||||
# causes two compilations, bc unimplemented custom setattr
|
# causes two compilations, bc unimplemented custom setattr
|
||||||
self.assertTrue(compiles_without_buffers >= 4)
|
self.assertTrue(compiles_without_buffers >= 2)
|
||||||
|
|
||||||
def test_unspec_non_inlinable_module(self):
|
def test_unspec_non_inlinable_module(self):
|
||||||
mod = UnspecNonInlinableModule()
|
mod = UnspecNonInlinableModule()
|
||||||
|
|
|
||||||
|
|
@ -1360,7 +1360,7 @@ def get_fake_value(node, tx):
|
||||||
|
|
||||||
op = node.op
|
op = node.op
|
||||||
|
|
||||||
# FX Node should always return the same value
|
# FX Node should always return the same fake value
|
||||||
if "example_value" in node.meta and is_fake(node.meta["example_value"]):
|
if "example_value" in node.meta and is_fake(node.meta["example_value"]):
|
||||||
return node.meta["example_value"]
|
return node.meta["example_value"]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1238,13 +1238,8 @@ class BuiltinVariable(VariableTracker):
|
||||||
getattr_var = None
|
getattr_var = None
|
||||||
|
|
||||||
if isinstance(getattr_var, variables.TensorVariable):
|
if isinstance(getattr_var, variables.TensorVariable):
|
||||||
# get_fake_val will return a real tensor here because it's an attribute on the module (get_attr node)
|
# get_fake_val will get the same fake tensor
|
||||||
existing_attr = get_fake_value(getattr_var.as_proxy().node, tx)
|
existing_fake_attr = get_fake_value(getattr_var.as_proxy().node, tx)
|
||||||
existing_fake_attr = (
|
|
||||||
variables.builder.wrap_to_fake_tensor_and_record(
|
|
||||||
existing_attr, tx, source=getattr_var.source, is_tensor=True
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# same tensor identiy, setattr is a no-op
|
# same tensor identiy, setattr is a no-op
|
||||||
mod_setattr = inspect.getattr_static(obj.module_type, "__setattr__")
|
mod_setattr = inspect.getattr_static(obj.module_type, "__setattr__")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user