functionalization: fix bug with multiple views of same base

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77129

Approved by: https://github.com/ezyang
This commit is contained in:
Brian Hirsh 2022-05-24 08:30:36 -07:00 committed by PyTorch MergeBot
parent 26d9386f67
commit 2eea5eff62
2 changed files with 14 additions and 4 deletions

View File

@ -194,10 +194,8 @@ void FunctionalTensorWrapper::sync_() {
if (is_up_to_date()) { if (is_up_to_date()) {
return; return;
} }
auto any_updates = apply_updates(); apply_updates();
if (any_updates) { regenerate_from_base();
regenerate_from_base();
}
} }
void FunctionalTensorWrapper::regenerate_from_base() { void FunctionalTensorWrapper::regenerate_from_base() {

View File

@ -93,6 +93,18 @@ class TestFunctionalization(TestCase):
self.assertEqual(out_ref, torch._from_functional_tensor(out_functional)) self.assertEqual(out_ref, torch._from_functional_tensor(out_functional))
self.assertEqual(inpt, torch._from_functional_tensor(input_functional)) # input mutations should still occur self.assertEqual(inpt, torch._from_functional_tensor(input_functional)) # input mutations should still occur
def test_multiple_views_of_same_base(self):
def f(x):
y = x.view(-1)
z = x.view(-1)
x.add_(1)
# y should have been updated.
y2 = y + 1
# z should have been updated too.
z2 = z + 1
return z2
self.assert_functionalization(f, torch.ones(4))
def test_simple(self): def test_simple(self):
def f(x): def f(x):
# simple test: 1 view op, 1 inplace op # simple test: 1 view op, 1 inplace op