Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
Add way to actually delete a torch.library.Library object (#118318)
Relying on object lifetimes in Python is a bad idea because of reference cycles. Previously, when a torch.library.Library object was destroyed it cleared all of the registrations associated with it, but reference cycles make it unclear when that destruction actually happens. This PR:
- adds torch::Library::clear(), which deterministically releases all of the RAII registration handles held by the torch::Library object
- adds a new `torch.library._scoped_library` context manager, which creates a library and, using the previous item, cleans it up at the end of the scope. All tests (unless they already handle library lifetimes) should use this new API
- rewrites some flaky tests to use `_scoped_library`

In the future we'll probably migrate all of our torch.library tests to use `_scoped_library`, but that's kind of annoying because they span multiple thousands of lines of code. I'm hoping this will deflake those tests; we'll see.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118318
Approved by: https://github.com/albanD
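For orientation, here is a minimal usage sketch of the new context manager (hedged: the "mylib" namespace and the toy op are illustrative, not part of the PR):

    import torch
    from torch.library import _scoped_library

    # Previously, cleanup rode on garbage collection:
    #   lib = torch.library.Library("mylib", "DEF")
    #   ...
    #   del lib  # registrations linger until the object is actually
    #            # collected, which refcycles can postpone indefinitely

    # With this PR, registrations are torn down deterministically at scope exit:
    with _scoped_library("mylib", "DEF") as lib:
        lib.define("add_one(Tensor x) -> Tensor")
        lib.impl("add_one", lambda x: x + 1, "CPU")
        assert torch.ops.mylib.add_one(torch.zeros(3)).sum().item() == 3.0
    # Here the definition, the kernel, and the cached torch.ops.mylib.add_one
    # have all been removed.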
This commit is contained in:
parent f129e3fe03
commit b256b7b348
@@ -51,6 +51,10 @@ CppFunction::CppFunction(c10::KernelFunction func, c10::optional<c10::impl::CppS
 
 CppFunction::~CppFunction() = default;
 
+void Library::reset() {
+  registrars_.clear();
+}
+
 #define ERROR_CONTEXT "(Error occurred while processing ", toString(kind_), " block at ", file_, ":", line_, ")"
 
 Library::Library(Kind kind, std::string ns, c10::optional<c10::DispatchKey> k, const char* file, uint32_t line)
@@ -3,7 +3,7 @@
 import tempfile
 import torch
 from copy import deepcopy
-from torch.library import Library, impl, fallthrough_kernel
+from torch.library import Library, impl, fallthrough_kernel, _scoped_library
 from torch.fx.experimental.symbolic_shapes import ShapeEnv
 from torch import SymInt
 from torch._subclasses.fake_tensor import FakeTensorMode
@@ -48,48 +48,43 @@ class TestPythonRegistration(TestCase):
 
     def test_override_aten_ops_with_multiple_libraries(self) -> None:
         x = torch.tensor([1, 2])
-        my_lib1 = Library("aten", "IMPL")
-        my_lib2 = Library("aten", "IMPL")
-
-        # Example 1
-        def my_neg(*args, **kwargs):
-            return args[0]._neg_view()
-
-        # Now we are secretly making the operator a view op so autograd needs to know how
-        # to handle it
-        my_lib1.impl('neg', my_neg, "AutogradCPU")
-
-        self.assertTrue(torch.neg(x).is_neg())
-
-        # RuntimeError: impl("aten::neg", ...):
-        # Explicitly provided namespace (aten) in operator name does not match ...
-        with self.assertRaisesRegex(RuntimeError, "operator name does not match namespace"):
-            my_lib3 = Library("foo", "DEF")
-            my_lib3.define("neg(Tensor self) -> Tensor")
-            my_lib3.impl(torch.ops.aten.neg.default, my_neg, "AutogradCPU")
-            del my_lib3
-
-        # Example 2
-        def my_mul(*args, **kwargs):
-            return torch.zeros_like(args[0])
-
-        # torch.ops.aten.mul.Tensor
-        my_lib2.impl("aten::mul.Tensor", my_mul, "ZeroTensor")
-
-        y = torch._efficientzerotensor(2)
-        self.assertFalse(torch.mul(x, y)._is_zerotensor())
-
-        # Assert that a user can't override the behavior of a (ns, op, dispatch_key)
-        # combination if someone overrided the behavior for the same before them
-        with self.assertRaisesRegex(RuntimeError, 'already a kernel registered from python'):
-            my_lib2.impl(torch.ops.aten.mul.Tensor, my_mul, "ZeroTensor")
-
-        del my_lib1
-
-        # Validate that lib2 is not affected by removing lib1
-        self.assertFalse(torch.mul(x, y)._is_zerotensor())
-
-        del my_lib2
+        with _scoped_library("aten", "IMPL") as my_lib2:
+            with _scoped_library("aten", "IMPL") as my_lib1:
+                # Example 1
+                def my_neg(*args, **kwargs):
+                    return args[0]._neg_view()
+
+                # Now we are secretly making the operator a view op so autograd needs to know how
+                # to handle it
+                my_lib1.impl('neg', my_neg, "AutogradCPU")
+
+                self.assertTrue(torch.neg(x).is_neg())
+
+                # RuntimeError: impl("aten::neg", ...):
+                # Explicitly provided namespace (aten) in operator name does not match ...
+                with self.assertRaisesRegex(RuntimeError, "operator name does not match namespace"):
+                    my_lib3 = Library("foo", "DEF")
+                    my_lib3.define("neg(Tensor self) -> Tensor")
+                    my_lib3.impl(torch.ops.aten.neg.default, my_neg, "AutogradCPU")
+                    del my_lib3
+
+                # Example 2
+                def my_mul(*args, **kwargs):
+                    return torch.zeros_like(args[0])
+
+                # torch.ops.aten.mul.Tensor
+                my_lib2.impl("aten::mul.Tensor", my_mul, "ZeroTensor")
+
+                y = torch._efficientzerotensor(2)
+                self.assertFalse(torch.mul(x, y)._is_zerotensor())
+
+                # Assert that a user can't override the behavior of a (ns, op, dispatch_key)
+                # combination if someone overrided the behavior for the same before them
+                with self.assertRaisesRegex(RuntimeError, 'already a kernel registered from python'):
+                    my_lib2.impl(torch.ops.aten.mul.Tensor, my_mul, "ZeroTensor")
+
+            # Validate that lib2 is not affected by removing lib1
+            self.assertFalse(torch.mul(x, y)._is_zerotensor())
 
         # Validate that the old behavior is restored for neg and mul
         self.assertFalse(torch.neg(x).is_neg())
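The rewritten test above leans on `_scoped_library` restoring prior behavior on exit. A condensed, standalone sketch of that property, reusing the same neg/AutogradCPU override the test registers (the variable names here are mine, not the test's):

    import torch
    from torch.library import _scoped_library

    x = torch.tensor([1, 2])

    def my_neg(*args, **kwargs):
        return args[0]._neg_view()

    with _scoped_library("aten", "IMPL") as my_lib:
        # While the scope is alive, neg dispatches to the override,
        # which returns a lazy negative view.
        my_lib.impl("neg", my_neg, "AutogradCPU")
        assert torch.neg(x).is_neg()
    # On exit the registration is dropped and the stock kernel is back.
    assert not torch.neg(x).is_neg()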
@@ -419,33 +414,28 @@ class TestPythonRegistration(TestCase):
         self.assertEqual(out_val, 13)
 
     def test_register_functional_op_error_cases(self):
-        lib = Library(self.test_ns, "FRAGMENT")
-        with self.assertRaisesRegex(TypeError, "instance of OpOverload"):
-            register_functional_op(lib, "abs", torch.ops.aten.abs_)
-        with self.assertRaisesRegex(RuntimeError, "Expected op to be mutable"):
-            register_functional_op(lib, "abs", torch.ops.aten.abs_.default)
-        with self.assertRaisesRegex(RuntimeError, "Expected op to be mutable"):
-            register_functional_op(lib, "abs", torch.ops.aten.abs.out)
-        del lib
+        with _scoped_library(self.test_ns, "FRAGMENT") as lib:
+            with self.assertRaisesRegex(TypeError, "instance of OpOverload"):
+                register_functional_op(lib, "abs", torch.ops.aten.abs_)
+            with self.assertRaisesRegex(RuntimeError, "Expected op to be mutable"):
+                register_functional_op(lib, "abs", torch.ops.aten.abs_.default)
+            with self.assertRaisesRegex(RuntimeError, "Expected op to be mutable"):
+                register_functional_op(lib, "abs", torch.ops.aten.abs.out)
 
         schemas = [
             'foo(Tensor x, Tensor(a!)[] y) -> ()',
             'foo(Tensor x, Tensor(a!) y, Tensor(b) z) -> Tensor(b)',
             'foo(Tensor x, Tensor(a!) y) -> (Tensor, Tensor(a))',
         ]
 
         for schema in schemas:
-            lib = Library(self.test_ns, "FRAGMENT")
-            try:
+            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                 lib.define(schema)
                 with self.assertRaisesRegex(RuntimeError, "NYI"):
                     register_functional_op(
                         lib,
                         "foo_functional",
                         getattr(torch.ops, self.test_ns).foo.default)
-            finally:
-                del lib
-                delattr(torch.ops, self.test_ns)
+            delattr(torch.ops, self.test_ns)
 
     def _check_is_functional_variant(self, mutable_op, functional_op, args):
         # functional op should not mutate
@@ -483,98 +473,97 @@
         self.assertTrue(has_functional_op)
 
     def test_register_functional_op_no_returns(self):
-        lib = Library(self.test_ns, 'FRAGMENT')
-        lib.define('foo(Tensor x, Tensor(a!) y, Tensor z, Tensor(b!) w) -> ()')
-
-        def foo_impl(x, y, z, w):
-            y.fill_(3.14)
-            w.fill_(2.71)
-
-        lib.impl('foo', foo_impl, 'CPU')
-        register_functional_op(
-            lib,
-            'foo_functional',
-            getattr(torch.ops, self.test_ns).foo.default)
-        x = torch.randn([])
-        y = torch.randn([])
-        z = torch.randn([])
-        w = torch.randn([])
-        self._check_is_functional_variant(
-            getattr(torch.ops, self.test_ns).foo.default,
-            getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, w))
+        with _scoped_library(self.test_ns, "FRAGMENT") as lib:
+            lib.define('foo(Tensor x, Tensor(a!) y, Tensor z, Tensor(b!) w) -> ()')
+
+            def foo_impl(x, y, z, w):
+                y.fill_(3.14)
+                w.fill_(2.71)
+
+            lib.impl('foo', foo_impl, 'CPU')
+            register_functional_op(
+                lib,
+                'foo_functional',
+                getattr(torch.ops, self.test_ns).foo.default)
+            x = torch.randn([])
+            y = torch.randn([])
+            z = torch.randn([])
+            w = torch.randn([])
+            self._check_is_functional_variant(
+                getattr(torch.ops, self.test_ns).foo.default,
+                getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, w))
 
     def test_register_functional_op_with_optional(self):
-        lib = Library(self.test_ns, 'FRAGMENT')
-        lib.define('foo(Tensor x, Tensor(a!) y, Tensor (b!) z, Tensor(c!)? w) -> ()')
-
-        def foo_impl(x, y, z, w):
-            y.fill_(3.14)
-            z.fill_(2.71)
-            if w is not None:
-                w.fill_(1.618)
-
-        lib.impl('foo', foo_impl, 'CPU')
-        register_functional_op(
-            lib,
-            'foo_functional',
-            getattr(torch.ops, self.test_ns).foo.default)
-        x = torch.randn([])
-        y = torch.randn([])
-        z = torch.randn([])
-        w = torch.randn([])
-        self._check_is_functional_variant(
-            getattr(torch.ops, self.test_ns).foo.default,
-            getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, w))
-        self._check_is_functional_variant(
-            getattr(torch.ops, self.test_ns).foo.default,
-            getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, None))
+        with _scoped_library(self.test_ns, "FRAGMENT") as lib:
+            lib.define('foo(Tensor x, Tensor(a!) y, Tensor (b!) z, Tensor(c!)? w) -> ()')
+
+            def foo_impl(x, y, z, w):
+                y.fill_(3.14)
+                z.fill_(2.71)
+                if w is not None:
+                    w.fill_(1.618)
+
+            lib.impl('foo', foo_impl, 'CPU')
+            register_functional_op(
+                lib,
+                'foo_functional',
+                getattr(torch.ops, self.test_ns).foo.default)
+            x = torch.randn([])
+            y = torch.randn([])
+            z = torch.randn([])
+            w = torch.randn([])
+            self._check_is_functional_variant(
+                getattr(torch.ops, self.test_ns).foo.default,
+                getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, w))
+            self._check_is_functional_variant(
+                getattr(torch.ops, self.test_ns).foo.default,
+                getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, None))
 
     def test_register_functional_op_one_return(self):
-        lib = Library(self.test_ns, 'FRAGMENT')
-        lib.define('foo(Tensor x, Tensor(a!) y, Tensor(c!) z, Tensor(b!) w) -> Tensor')
-
-        def foo_impl(x, y, z, w):
-            y.fill_(3.14)
-            w.fill_(2.71)
-            z.fill_(0.99)
-            return x.clone()
-
-        lib.impl('foo', foo_impl, 'CPU')
-        register_functional_op(
-            lib,
-            "foo_functional",
-            getattr(torch.ops, self.test_ns).foo.default)
-        x = torch.randn([])
-        y = torch.randn([])
-        z = torch.randn([])
-        w = torch.randn([])
-        self._check_is_functional_variant(
-            getattr(torch.ops, self.test_ns).foo.default,
-            getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, w))
+        with _scoped_library(self.test_ns, "FRAGMENT") as lib:
+            lib.define('foo(Tensor x, Tensor(a!) y, Tensor(c!) z, Tensor(b!) w) -> Tensor')
+
+            def foo_impl(x, y, z, w):
+                y.fill_(3.14)
+                w.fill_(2.71)
+                z.fill_(0.99)
+                return x.clone()
+
+            lib.impl('foo', foo_impl, 'CPU')
+            register_functional_op(
+                lib,
+                "foo_functional",
+                getattr(torch.ops, self.test_ns).foo.default)
+            x = torch.randn([])
+            y = torch.randn([])
+            z = torch.randn([])
+            w = torch.randn([])
+            self._check_is_functional_variant(
+                getattr(torch.ops, self.test_ns).foo.default,
+                getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, w))
 
     def test_register_functional_op_multiple_returns(self):
-        lib = Library(self.test_ns, 'FRAGMENT')
-        lib.define('foo(Tensor x, Tensor(a!) y, Tensor z, Tensor(b!) w) -> (Tensor, Tensor)')
-
-        def foo_impl(x, y, z, w):
-            y.fill_(3.14)
-            w.fill_(2.71)
-            return x.clone(), z.clone()
-
-        lib.impl('foo', foo_impl, 'CPU')
-        register_functional_op(
-            lib,
-            'foo_functional',
-            getattr(torch.ops, self.test_ns).foo.default)
-
-        x = torch.randn([])
-        y = torch.randn([])
-        z = torch.randn([])
-        w = torch.randn([])
-        self._check_is_functional_variant(
-            getattr(torch.ops, self.test_ns).foo.default,
-            getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, w))
+        with _scoped_library(self.test_ns, "FRAGMENT") as lib:
+            lib.define('foo(Tensor x, Tensor(a!) y, Tensor z, Tensor(b!) w) -> (Tensor, Tensor)')
+
+            def foo_impl(x, y, z, w):
+                y.fill_(3.14)
+                w.fill_(2.71)
+                return x.clone(), z.clone()
+
+            lib.impl('foo', foo_impl, 'CPU')
+            register_functional_op(
+                lib,
+                'foo_functional',
+                getattr(torch.ops, self.test_ns).foo.default)
+
+            x = torch.randn([])
+            y = torch.randn([])
+            z = torch.randn([])
+            w = torch.randn([])
+            self._check_is_functional_variant(
+                getattr(torch.ops, self.test_ns).foo.default,
+                getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, w))
 
     def test_register_fallthrough(self):
         try:
@@ -247,6 +247,14 @@ void initDispatchBindings(PyObject* module) {
 
   // TODO: figure out how to do chaining
   py::class_<torch::Library>(m, "_DispatchModule")
+      .def(
+          "reset",
+          [](const py::object& self) {
+            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
+            self.cast<torch::Library&>().reset();
+            return;
+          },
+          "")
       // Some of these APIs are only for testing and do not work in multipy
       // environment
       .def(
@@ -837,6 +837,9 @@ class TORCH_API Library final {
   template <class CurClass>
   inline detail::ClassNotSelected class_(detail::SelectiveStr<false> className);
 
+  // De-registers all registrations created with this Library
+  void reset();
+
  private:
   Kind kind_;
   c10::optional<std::string> ns_;
@@ -6,6 +6,7 @@ import weakref
 import functools
 import inspect
 import re
+import contextlib
 import sys
 
 __all__ = [
@@ -171,10 +172,27 @@ class Library:
         self._op_impls.add(key)
 
+    def _destroy(self):
+        if self.m is not None:
+            self.m.reset()
+        self.m = None
+        for handle in self._registration_handles:
+            handle.destroy()
+        self._registration_handles.clear()
+        for name in self._op_defs:
+            # Delete the cached torch.ops.ns.foo if it was registered.
+            # Otherwise, accessing it leads to a segfault.
+            # It's possible that we only registered an overload in this Library
+            # and another library owns an alive overload.
+            # That's OK - the next time torch.ops.ns.foo gets called, it'll be
+            # recomputed to point at the right collection of overloads.
+            ns, name_with_overload = name.split("::")
+            name = name_with_overload.split(".")[0]
+            if not hasattr(torch.ops, ns):
+                continue
+            namespace = getattr(torch.ops, ns)
+            if not hasattr(namespace, name):
+                continue
+            delattr(namespace, name)
+
 
 def _del_library(captured_impls, op_impls, captured_defs, op_defs, registration_handles):
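To make the comment in `_destroy` concrete: `torch.ops.<ns>.<name>` is cached on first access, so once the underlying registration is gone the stale attribute has to be dropped, or a later access would hit a dangling op packet. A hedged sketch (the "mylib" namespace is illustrative, and the AttributeError at the end assumes torch.ops' usual failed-lookup behavior):

    import torch
    from torch.library import _scoped_library

    with _scoped_library("mylib", "DEF") as lib:
        lib.define("foo(Tensor x) -> Tensor")
        lib.impl("foo", lambda x: x.clone(), "CPU")
        op = torch.ops.mylib.foo   # first access caches the op packet
        op(torch.ones(2))
    # _destroy() ran delattr(torch.ops.mylib, "foo"); re-resolving the name
    # now fails cleanly instead of dereferencing the stale cached packet.
    try:
        torch.ops.mylib.foo
    except AttributeError:
        pass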
@@ -184,6 +202,15 @@ def _del_library(captured_impls, op_impls, captured_defs, op_defs, registration_handles):
         handle.destroy()
 
 
+@contextlib.contextmanager
+def _scoped_library(*args, **kwargs):
+    try:
+        lib = Library(*args, **kwargs)
+        yield lib
+    finally:
+        lib._destroy()
+
+
 _keep_alive = []
 
 
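Because the cleanup sits in a `finally`, a raising test body still tears the library down, which is exactly what deflaking needs: a rerun can redefine the same schema without tripping a duplicate-registration error. A small sketch (the "mylib" namespace and toy schema are hypothetical):

    from torch.library import _scoped_library

    for _ in range(2):  # the second iteration only works thanks to the cleanup
        try:
            with _scoped_library("mylib", "DEF") as lib:
                lib.define("bar(Tensor x) -> Tensor")
                raise RuntimeError("simulated test failure")
        except RuntimeError:
            pass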
@@ -2745,7 +2745,6 @@ dynamo_expected_failures = {
     "TestPythonDispatch.test_subclass_autograd_device_check",  # test_python_dispatch
     "TestPythonDispatch.test_data_ptr_respects_numel_slow_path",  # test_python_dispatch
     "TestPythonDispatch.test_make_subclass_with_modes",  # test_python_dispatch
-    "TestPythonRegistration.test_override_aten_ops_with_multiple_libraries",  # test_python_dispatch
     "TestPythonDispatch.test_dispatch_super_call",  # test_python_dispatch
     "TestPythonDispatch.test_subclass_priority",  # test_python_dispatch
     "TestPythonDispatch.test_exception_handling",  # test_python_dispatch