From 0de7a618a311a21b0976b7261f6489f9f6eb703c Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Thu, 2 Dec 2021 09:17:36 -0800 Subject: [PATCH] functionalization: update is_aliased() logic (#68881) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/68881 Test Plan: Imported from OSS Reviewed By: zou3519 Differential Revision: D32647614 Pulled By: bdhirsh fbshipit-source-id: 6bec50d3e54419d1707d0b6c0c6729bcc1ced1f0 --- aten/src/ATen/FunctionalTensorWrapper.cpp | 8 -------- aten/src/ATen/FunctionalTensorWrapper.h | 3 --- test/test_functionalization.py | 16 ++++++++++++++++ .../autograd/python_torch_functions_manual.cpp | 4 +--- 4 files changed, 17 insertions(+), 14 deletions(-) diff --git a/aten/src/ATen/FunctionalTensorWrapper.cpp b/aten/src/ATen/FunctionalTensorWrapper.cpp index c8c7d021b63..5f99e377479 100644 --- a/aten/src/ATen/FunctionalTensorWrapper.cpp +++ b/aten/src/ATen/FunctionalTensorWrapper.cpp @@ -130,15 +130,7 @@ void FunctionalTensorWrapper::commit_update() { generation_ = storage_impl->generation(); } -bool FunctionalTensorWrapper::is_aliased() const { - // Two FunctionalTensorWrapper objects are aliased if they share storage. - // That means that we can check if a given FunctionalTensorWrapper is aliased - // by checking the reference count on its storage. - return storage_.use_count() > 1; -} - bool FunctionalTensorWrapper::is_up_to_date() const { - if (!is_aliased()) return true; auto alias_generation = functional_storage_impl()->generation(); return generation_ == alias_generation; } diff --git a/aten/src/ATen/FunctionalTensorWrapper.h b/aten/src/ATen/FunctionalTensorWrapper.h index c3820743cb4..1696b41f154 100644 --- a/aten/src/ATen/FunctionalTensorWrapper.h +++ b/aten/src/ATen/FunctionalTensorWrapper.h @@ -95,9 +95,6 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl { private: const char* tensorimpl_type_name() const override; - // Returns true if this FunctionalTensorWrapper is aliased with any other FunctionalTensorWrapper objects. - // During a functionalization pass, if we have `b = a.view()`, then a and b should both report as aliased. - bool is_aliased() const; void set_constructor_metadata(); functionalization::FunctionalStorageImpl* functional_storage_impl() const; diff --git a/test/test_functionalization.py b/test/test_functionalization.py index b8c5f669c8f..6e967fad83b 100644 --- a/test/test_functionalization.py +++ b/test/test_functionalization.py @@ -306,5 +306,21 @@ $1 = torch._ops.aten._to_copy($0, dtype=6, layout=0, device=device(type='cpu'), $2 = torch._ops.aten.expand($1, [2]) $3 = torch._ops.aten.add($2, $0)""") + def test_nested_functions_propagate_updates(self): + def g(x): + # Create a view of x + y = x[0] + y.add_(1) + # The view, y, gets deallocated at the end of this function + + def f(x): + # Calling g(x) should mutate x + g(x) + # We expect x to be synced here, even though the alias created in g() has been deallocated! + y = x + x + return y + + self.assert_functionalization(f, torch.ones(2, 2)) + if __name__ == '__main__': run_tests() diff --git a/torch/csrc/autograd/python_torch_functions_manual.cpp b/torch/csrc/autograd/python_torch_functions_manual.cpp index 7e2db564719..4c96cf86f5d 100644 --- a/torch/csrc/autograd/python_torch_functions_manual.cpp +++ b/torch/csrc/autograd/python_torch_functions_manual.cpp @@ -653,9 +653,7 @@ static PyObject * THPVariable__sync(PyObject *self, PyObject* args, PyObject* kw auto r = parser.parse(args, kwargs, parsed_args); auto self_ = r.tensor(0); TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self_)); - auto wrapped_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(self_); - wrapped_impl->apply_updates(); - wrapped_impl->regenerate_from_base(); + at::functionalization::impl::sync(self_); Py_RETURN_NONE; END_HANDLE_TH_ERRORS }