Add way to actually delete a torch.library.Library object (#118318)

Relying on object lifetimes in Python is a bad idea due to reference
cycles. Previously, when a torch.library.Library object gets destroyed,
it clears all the registrations associated with it, but it's unclear
when it actually gets destroyed due to the existence of refcycles.

This PR:
- adds torch::Library::reset(), which deterministically releases all of
  the RAII registration handles of the torch::Library object
- adds a new `torch.library._scoped_library` context manager, which creates
  a library and cleans it up at the end of the scope using the previous item.
  All tests (unless they already handle library lifetimes) should use
  this new API
- Rewrites some flaky tests to use `_scoped_library`.

In the future we'll probably migrate all of our torch.library tests to
use `_scoped_library`, but that's kind of annoying because we have
multiple thousands of LOC

I'm hoping this will deflake those tests; we'll see.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118318
Approved by: https://github.com/albanD
This commit is contained in:
rzou 2024-01-26 11:08:49 -08:00 committed by PyTorch MergeBot
parent f129e3fe03
commit b256b7b348
6 changed files with 161 additions and 131 deletions

View File

@ -51,6 +51,10 @@ CppFunction::CppFunction(c10::KernelFunction func, c10::optional<c10::impl::CppS
CppFunction::~CppFunction() = default;
// Drops every entry held in registrars_. Per the declaration's comment in the
// header, this de-registers all registrations created with this Library, so
// callers do not have to wait for the Library object itself to be destroyed.
void Library::reset() {
registrars_.clear();
}
#define ERROR_CONTEXT "(Error occurred while processing ", toString(kind_), " block at ", file_, ":", line_, ")"
Library::Library(Kind kind, std::string ns, c10::optional<c10::DispatchKey> k, const char* file, uint32_t line)

View File

@ -3,7 +3,7 @@
import tempfile
import torch
from copy import deepcopy
from torch.library import Library, impl, fallthrough_kernel
from torch.library import Library, impl, fallthrough_kernel, _scoped_library
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch import SymInt
from torch._subclasses.fake_tensor import FakeTensorMode
@ -48,9 +48,8 @@ class TestPythonRegistration(TestCase):
def test_override_aten_ops_with_multiple_libraries(self) -> None:
x = torch.tensor([1, 2])
my_lib1 = Library("aten", "IMPL")
my_lib2 = Library("aten", "IMPL")
with _scoped_library("aten", "IMPL") as my_lib2:
with _scoped_library("aten", "IMPL") as my_lib1:
# Example 1
def my_neg(*args, **kwargs):
return args[0]._neg_view()
@ -84,13 +83,9 @@ class TestPythonRegistration(TestCase):
with self.assertRaisesRegex(RuntimeError, 'already a kernel registered from python'):
my_lib2.impl(torch.ops.aten.mul.Tensor, my_mul, "ZeroTensor")
del my_lib1
# Validate that lib2 is not affected by removing lib1
self.assertFalse(torch.mul(x, y)._is_zerotensor())
del my_lib2
# Validate that the old behavior is restored for neg and mul
self.assertFalse(torch.neg(x).is_neg())
self.assertTrue(torch.mul(x, y)._is_zerotensor())
@ -419,7 +414,7 @@ class TestPythonRegistration(TestCase):
self.assertEqual(out_val, 13)
def test_register_functional_op_error_cases(self):
lib = Library(self.test_ns, "FRAGMENT")
with _scoped_library(self.test_ns, "FRAGMENT") as lib:
with self.assertRaisesRegex(TypeError, "instance of OpOverload"):
register_functional_op(lib, "abs", torch.ops.aten.abs_)
with self.assertRaisesRegex(RuntimeError, "Expected op to be mutable"):
@ -432,20 +427,15 @@ class TestPythonRegistration(TestCase):
'foo(Tensor x, Tensor(a!) y, Tensor(b) z) -> Tensor(b)',
'foo(Tensor x, Tensor(a!) y) -> (Tensor, Tensor(a))',
]
del lib
for schema in schemas:
lib = Library(self.test_ns, "FRAGMENT")
try:
with _scoped_library(self.test_ns, "FRAGMENT") as lib:
lib.define(schema)
with self.assertRaisesRegex(RuntimeError, "NYI"):
register_functional_op(
lib,
"foo_functional",
getattr(torch.ops, self.test_ns).foo.default)
finally:
del lib
delattr(torch.ops, self.test_ns)
def _check_is_functional_variant(self, mutable_op, functional_op, args):
# functional op should not mutate
@ -483,7 +473,7 @@ class TestPythonRegistration(TestCase):
self.assertTrue(has_functional_op)
def test_register_functional_op_no_returns(self):
lib = Library(self.test_ns, 'FRAGMENT')
with _scoped_library(self.test_ns, "FRAGMENT") as lib:
lib.define('foo(Tensor x, Tensor(a!) y, Tensor z, Tensor(b!) w) -> ()')
def foo_impl(x, y, z, w):
@ -504,7 +494,7 @@ class TestPythonRegistration(TestCase):
getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, w))
def test_register_functional_op_with_optional(self):
lib = Library(self.test_ns, 'FRAGMENT')
with _scoped_library(self.test_ns, "FRAGMENT") as lib:
lib.define('foo(Tensor x, Tensor(a!) y, Tensor (b!) z, Tensor(c!)? w) -> ()')
def foo_impl(x, y, z, w):
@ -529,9 +519,8 @@ class TestPythonRegistration(TestCase):
getattr(torch.ops, self.test_ns).foo.default,
getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, None))
def test_register_functional_op_one_return(self):
lib = Library(self.test_ns, 'FRAGMENT')
with _scoped_library(self.test_ns, "FRAGMENT") as lib:
lib.define('foo(Tensor x, Tensor(a!) y, Tensor(c!) z, Tensor(b!) w) -> Tensor')
def foo_impl(x, y, z, w):
@ -554,7 +543,7 @@ class TestPythonRegistration(TestCase):
getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, w))
def test_register_functional_op_multiple_returns(self):
lib = Library(self.test_ns, 'FRAGMENT')
with _scoped_library(self.test_ns, "FRAGMENT") as lib:
lib.define('foo(Tensor x, Tensor(a!) y, Tensor z, Tensor(b!) w) -> (Tensor, Tensor)')
def foo_impl(x, y, z, w):

View File

@ -247,6 +247,14 @@ void initDispatchBindings(PyObject* module) {
// TODO: figure out how to do chaining
py::class_<torch::Library>(m, "_DispatchModule")
.def(
"reset",
[](const py::object& self) {
TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
self.cast<torch::Library&>().reset();
return;
},
"")
// Some of these APIs are only for testing and do not work in multipy
// environment
.def(

View File

@ -837,6 +837,9 @@ class TORCH_API Library final {
template <class CurClass>
inline detail::ClassNotSelected class_(detail::SelectiveStr<false> className);
// De-registers all registrations created with this Library
void reset();
private:
Kind kind_;
c10::optional<std::string> ns_;

View File

@ -6,6 +6,7 @@ import weakref
import functools
import inspect
import re
import contextlib
import sys
__all__ = [
@ -171,10 +172,27 @@ class Library:
self._op_impls.add(key)
def _destroy(self):
    """Tear down everything this library registered.

    Resets the underlying dispatch module (if one is still attached) and
    drops the reference to it, destroys and forgets every registration
    handle, and finally removes the cached ``torch.ops.<ns>.<op>`` attribute
    for each operator name recorded in ``self._op_defs``.
    """
    module = self.m
    if module is not None:
        module.reset()
    self.m = None

    for registration in self._registration_handles:
        registration.destroy()
    self._registration_handles.clear()

    for qualified_name in self._op_defs:
        # A cached torch.ops.ns.foo must be removed once its definition is
        # gone; touching the stale object afterwards leads to a segfault.
        # This library may have contributed only one overload while another
        # library still owns a live one. That is fine: the packet is rebuilt
        # from the surviving overloads the next time torch.ops.ns.foo is
        # looked up.
        ns, name_with_overload = qualified_name.split("::")
        op_name = name_with_overload.split(".")[0]
        if hasattr(torch.ops, ns):
            namespace = getattr(torch.ops, ns)
            if hasattr(namespace, op_name):
                delattr(namespace, op_name)
def _del_library(captured_impls, op_impls, captured_defs, op_defs, registration_handles):
@ -184,6 +202,15 @@ def _del_library(captured_impls, op_impls, captured_defs, op_defs, registration_
handle.destroy()
@contextlib.contextmanager
def _scoped_library(*args, **kwargs):
    """Context manager that creates a Library and destroys it on exit.

    All positional and keyword arguments are forwarded to ``Library``. The
    new library is yielded to the ``with`` body, and ``_destroy()`` is called
    on it when the body finishes, whether it returns normally or raises.
    """
    # Construct outside the ``try``: if the constructor raises, its exception
    # must propagate as-is. Inside the ``try``, the ``finally`` clause would
    # reference an unbound ``lib`` and raise UnboundLocalError in its place.
    lib = Library(*args, **kwargs)
    try:
        yield lib
    finally:
        lib._destroy()
_keep_alive = []

View File

@ -2745,7 +2745,6 @@ dynamo_expected_failures = {
"TestPythonDispatch.test_subclass_autograd_device_check", # test_python_dispatch
"TestPythonDispatch.test_data_ptr_respects_numel_slow_path", # test_python_dispatch
"TestPythonDispatch.test_make_subclass_with_modes", # test_python_dispatch
"TestPythonRegistration.test_override_aten_ops_with_multiple_libraries", # test_python_dispatch
"TestPythonDispatch.test_dispatch_super_call", # test_python_dispatch
"TestPythonDispatch.test_subclass_priority", # test_python_dispatch
"TestPythonDispatch.test_exception_handling", # test_python_dispatch