mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
This PR introduces a new overload of torch.library.define. Like impl_abstract, and our plans for the rest of the torch.library APIs, we allow it to accept an optional library object to tie the lifetime of the op definition to. Test Plan: - new tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/111307 Approved by: https://github.com/soulitzer, https://github.com/ezyang
52 lines
1.3 KiB
Python
52 lines
1.3 KiB
Python
import dataclasses
|
|
import inspect
|
|
import sys
|
|
from typing import Callable, Tuple
|
|
|
|
|
|
@dataclasses.dataclass
class Kernel:
    """Pair a callable with a string describing its source location.

    Instances are themselves callable: calling one forwards every positional
    and keyword argument to ``func`` and hands back whatever it returns.
    """

    # The wrapped callable that ``__call__`` delegates to.
    func: Callable
    # Source-location string stored alongside the callable.
    source: str

    def __call__(self, *args, **kwargs):
        """Invoke ``func`` with the given arguments and return its result."""
        target = self.func
        return target(*args, **kwargs)
|
|
|
|
|
|
class RegistrationHandle:
    """Handle whose ``destroy()`` runs a callback supplied at construction."""

    def __init__(self, on_destroy: Callable):
        # Only stored here; the callback is not run until destroy() is called.
        self._on_destroy = on_destroy

    def destroy(self) -> None:
        """Run the stored callback, discarding whatever it returns.

        There is no once-only guard: every call invokes the callback again.
        """
        callback = self._on_destroy
        callback()
|
|
|
|
|
|
def get_source(stacklevel: int) -> str:
    """Get a string that represents the caller.

    Example: "/path/to/foo.py:42"

    Use stacklevel=1 to get the caller's source
    Use stacklevel=2 to get the caller's caller's source
    etc.

    Args:
        stacklevel: How many frames above this function to look;
            0 refers to ``get_source`` itself.

    Returns:
        A "<filename>:<lineno>" string for the selected frame.

    Raises:
        ValueError: If the call stack is not ``stacklevel`` frames deep
            (raised by ``sys._getframe``).
    """
    # Only the filename and line number are used, so pass context=0 to keep
    # inspect from loading the frame's source lines on every call.
    info = inspect.getframeinfo(sys._getframe(stacklevel), context=0)
    return f"{info.filename}:{info.lineno}"
|
|
|
|
|
|
def parse_namespace(qualname: str) -> Tuple[str, str]:
    """Split a "namespace::name" qualified name into its two parts.

    Args:
        qualname: A string containing exactly one "::" separator.

    Returns:
        A ``(namespace, name)`` tuple: the text before and after the "::".

    Raises:
        ValueError: If splitting on "::" does not give exactly two pieces.
    """
    pieces = qualname.split("::")
    if len(pieces) == 2:
        namespace, name = pieces
        return namespace, name

    message = (
        f"Expected `qualname` to be of the form "
        f'"namespace::name", but got {qualname}. '
        f"The qualname passed to the torch.library APIs must consist "
        f"of a namespace and a name, e.g. aten::sin"
    )
    raise ValueError(message)
|