functionalization: update is_aliased() logic (#68881)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/68881

Test Plan: Imported from OSS

Reviewed By: zou3519

Differential Revision: D32647614

Pulled By: bdhirsh

fbshipit-source-id: 6bec50d3e54419d1707d0b6c0c6729bcc1ced1f0
This commit is contained in:
parent 4484c04513
commit 0de7a618a3
@@ -130,15 +130,7 @@ void FunctionalTensorWrapper::commit_update() {
   generation_ = storage_impl->generation();
 }
 
-bool FunctionalTensorWrapper::is_aliased() const {
-  // Two FunctionalTensorWrapper objects are aliased if they share storage.
-  // That means that we can check if a given FunctionalTensorWrapper is aliased
-  // by checking the reference count on its storage.
-  return storage_.use_count() > 1;
-}
-
 bool FunctionalTensorWrapper::is_up_to_date() const {
-  if (!is_aliased()) return true;
   auto alias_generation = functional_storage_impl()->generation();
   return generation_ == alias_generation;
 }
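For context on what the deleted helper did: per its own comments, two FunctionalTensorWrapper objects count as aliased when they share a storage, so a storage use count above one meant some other wrapper pointed at the same storage. The same storage-sharing notion is visible from ordinary eager tensors; the snippet below is an illustrative sketch using public PyTorch APIs, not code from this patch:

import torch

x = torch.ones(2, 2)
y = x[0]       # y is a view of x: no copy, same underlying storage
z = x.clone()  # z is a fresh tensor with its own storage

# Aliased tensors share one storage, so their storage data pointers match.
assert x.storage().data_ptr() == y.storage().data_ptr()
assert x.storage().data_ptr() != z.storage().data_ptr()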
@@ -95,9 +95,6 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
 
  private:
   const char* tensorimpl_type_name() const override;
-  // Returns true if this FunctionalTensorWrapper is aliased with any other FunctionalTensorWrapper objects.
-  // During a functionalization pass, if we have `b = a.view()`, then a and b should both report as aliased.
-  bool is_aliased() const;
   void set_constructor_metadata();
   functionalization::FunctionalStorageImpl* functional_storage_impl() const;
 
@@ -306,5 +306,21 @@ $1 = torch._ops.aten._to_copy($0, dtype=6, layout=0, device=device(type='cpu'),
 $2 = torch._ops.aten.expand($1, [2])
 $3 = torch._ops.aten.add($2, $0)""")
 
+    def test_nested_functions_propagate_updates(self):
+        def g(x):
+            # Create a view of x
+            y = x[0]
+            y.add_(1)
+            # The view, y, gets deallocated at the end of this function
+
+        def f(x):
+            # Calling g(x) should mutate x
+            g(x)
+            # We expect x to be synced here, even though the alias created in g() has been deallocated!
+            y = x + x
+            return y
+
+        self.assert_functionalization(f, torch.ones(2, 2))
+
 if __name__ == '__main__':
     run_tests()
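The new test pins down the eager-mode behavior that functionalization has to reproduce: a mutation made through a view must reach the base tensor even after the view itself has been deallocated. Here is a small standalone demonstration of that behavior in plain eager mode (my own sketch, not part of the test file):

import torch

def g(x):
    y = x[0]   # temporary view of x's first row
    y.add_(1)  # in-place update through the view writes into x's storage
    # y goes out of scope here, but the write has already landed in x

x = torch.ones(2, 2)
g(x)
print(x[0])  # tensor([2., 2.]) -- the mutation survived the view's death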
@@ -653,9 +653,7 @@ static PyObject * THPVariable__sync(PyObject *self, PyObject* args, PyObject* kw
   auto r = parser.parse(args, kwargs, parsed_args);
   auto self_ = r.tensor(0);
   TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self_));
-  auto wrapped_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(self_);
-  wrapped_impl->apply_updates();
-  wrapped_impl->regenerate_from_base();
+  at::functionalization::impl::sync(self_);
   Py_RETURN_NONE;
   END_HANDLE_TH_ERRORS
 }
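With this change the Python binding funnels everything through a single at::functionalization::impl::sync(self_) call instead of manually fetching the wrapper and invoking apply_updates() and regenerate_from_base(). As a rough mental model only (a toy Python sketch with invented names such as ToyStorage and ToyWrapper, not PyTorch's actual implementation), sync can be thought of as the generation-counter check from is_up_to_date() above, replaying pending updates only when the wrapper is stale:

class ToyStorage:
    # Shared buffer plus a generation counter, bumped on every mutation.
    def __init__(self, data):
        self.data = list(data)
        self.generation = 0

    def write(self, i, value):
        self.data[i] = value
        self.generation += 1

class ToyWrapper:
    # Caches a snapshot of the storage and the generation it was taken at.
    def __init__(self, storage):
        self.storage = storage
        self.generation = storage.generation
        self.snapshot = list(storage.data)

    def is_up_to_date(self):
        return self.generation == self.storage.generation

    def sync(self):
        # One entry point replacing the manual apply_updates() +
        # regenerate_from_base() sequence: refresh only if stale.
        if not self.is_up_to_date():
            self.snapshot = list(self.storage.data)  # "regenerate from base"
            self.generation = self.storage.generation

storage = ToyStorage([1, 1])
w = ToyWrapper(storage)
storage.write(0, 2)            # some alias mutates the shared storage
assert not w.is_up_to_date()
w.sync()                       # replays the update into the wrapper
assert w.snapshot == [2, 1] and w.is_up_to_date()

Bundling the two steps behind one entry point means callers like this binding cannot forget either half of the sequence.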