mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
fix regression which creates a new fake tensor (#111864)
Fixes regression identified here: ccd6b373b5 (r1369334484)
Now that `get_fake_value` will identify aliases, we should not try to wrap the fake value again.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111864
Approved by: https://github.com/eellison
This commit is contained in:
parent
0e0f6a248d
commit
6d78f34a06
|
|
@ -2221,10 +2221,10 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
|
|||
|
||||
mod = Mod()
|
||||
foo(mod, torch.rand([4]))
|
||||
self.assertEqual(compiles_without_buffers, 1)
|
||||
self.assertEqual(compiles_without_buffers, 0)
|
||||
|
||||
foo(mod, torch.rand([4], dtype=torch.half))
|
||||
self.assertEqual(compiles_without_buffers, 2)
|
||||
self.assertEqual(compiles_without_buffers, 1)
|
||||
|
||||
class Mod2(Mod):
|
||||
def __setattr__(self, name, value):
|
||||
|
|
@ -2232,7 +2232,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
|
|||
|
||||
foo(Mod2(), torch.rand([4]))
|
||||
# causes two compilations, bc unimplemented custom setattr
|
||||
self.assertTrue(compiles_without_buffers >= 4)
|
||||
self.assertTrue(compiles_without_buffers >= 2)
|
||||
|
||||
def test_unspec_non_inlinable_module(self):
|
||||
mod = UnspecNonInlinableModule()
|
||||
|
|
|
|||
|
|
@ -1360,7 +1360,7 @@ def get_fake_value(node, tx):
|
|||
|
||||
op = node.op
|
||||
|
||||
# FX Node should always return the same value
|
||||
# FX Node should always return the same fake value
|
||||
if "example_value" in node.meta and is_fake(node.meta["example_value"]):
|
||||
return node.meta["example_value"]
|
||||
|
||||
|
|
|
|||
|
|
@ -1238,13 +1238,8 @@ class BuiltinVariable(VariableTracker):
|
|||
getattr_var = None
|
||||
|
||||
if isinstance(getattr_var, variables.TensorVariable):
|
||||
# get_fake_val will return a real tensor here because it's an attribute on the module (get_attr node)
|
||||
existing_attr = get_fake_value(getattr_var.as_proxy().node, tx)
|
||||
existing_fake_attr = (
|
||||
variables.builder.wrap_to_fake_tensor_and_record(
|
||||
existing_attr, tx, source=getattr_var.source, is_tensor=True
|
||||
)
|
||||
)
|
||||
# get_fake_val will get the same fake tensor
|
||||
existing_fake_attr = get_fake_value(getattr_var.as_proxy().node, tx)
|
||||
|
||||
# same tensor identiy, setattr is a no-op
|
||||
mod_setattr = inspect.getattr_static(obj.module_type, "__setattr__")
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user