mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Changelog: - torch.library.impl_abstract now optionally accepts a torch.library.Library object. If one is passed in, the lifetime of the registration is tied to that Library object. - We've also changed torch.library.impl_abstract to work on all operators, including overloads. - We refactored the `torch._custom_ops.*` and `torch._custom_op.*` impl_abstract APIs and put them under torch._library. This is their final resting place. I will follow up later by deleting all the `torch._custom_ops.*` stuff. - There is a new "SimpleOperatorRegistry" where we actually collect the abstract_impl. We will expand it to also hold the other torch._custom_ops.* APIs when we move those to torch.library. NB: Previously we had designed `impl_abstract` assuming a very high-level, Python-only custom op API. We've since revisited that; now impl_abstract works for all custom ops, whether they are defined in Python or C++ and regardless of their schema. The new refactored design reflects this better. Test Plan: - existing and new tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/109912 Approved by: https://github.com/ezyang
40 lines
886 B
Python
40 lines
886 B
Python
import dataclasses
|
|
import inspect
|
|
import sys
|
|
from typing import Callable
|
|
|
|
|
|
@dataclasses.dataclass
class Kernel:
    """Models a (function, source location) pair.

    Invoking a ``Kernel`` forwards every positional and keyword argument to
    ``func`` unchanged and hands back whatever ``func`` returns.
    """

    # The callable that actually does the work when the Kernel is invoked.
    func: Callable
    # Free-form description of where ``func`` came from; presumably a
    # "file:line" string like the ones get_source produces (verify at callers).
    source: str

    def __call__(self, *args, **kwargs):
        # Thin pass-through: no argument rewriting, no result post-processing.
        target = self.func
        return target(*args, **kwargs)
|
|
|
|
|
|
class RegistrationHandle:
    """Does something when someone calls .destroy() on it.

    The "something" is the zero-argument callback given at construction.
    It runs on every ``destroy()`` call; nothing here prevents it from
    running more than once.
    """

    def __init__(self, on_destroy: Callable):
        # Stored untouched; it is only ever invoked from destroy().
        self._on_destroy = on_destroy

    def destroy(self) -> None:
        """Run the stored callback and discard whatever it returns."""
        callback = self._on_destroy
        callback()
|
|
|
|
|
|
def get_source(stacklevel: int) -> str:
    """Get a string that represents the caller.

    Example: "/path/to/foo.py:42"

    Use stacklevel=1 to get the caller's source
    Use stacklevel=2 to get the caller's caller's source
    etc.

    Args:
        stacklevel: Number of frames above ``get_source`` to inspect;
            1 selects the direct caller of ``get_source``.

    Returns:
        ``"<filename>:<lineno>"`` describing the selected frame.

    Raises:
        ValueError: if ``stacklevel`` is deeper than the current call stack
            (raised by ``sys._getframe``).
    """
    # sys._getframe(0) is get_source's own frame, so stacklevel counts the
    # frames above it.
    frame = sys._getframe(stacklevel)
    # context=0: only the filename and line number are needed. With the
    # default context=1, getframeinfo also looks up the surrounding source
    # lines, which can read the source file on every call for no benefit.
    info = inspect.getframeinfo(frame, context=0)
    return f"{info.filename}:{info.lineno}"
|