pytorch/torch/_library/utils.py
Richard Zou 9d9cc67592 Make torch.library.define consistent with the new APIs (#111307)
This PR introduces a new overload of torch.library.define. Like
impl_abstract, and our plans for the rest of the torch.library APIs, we
allow it to accept an optional library object to tie the lifetime of the
op definition to.

Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111307
Approved by: https://github.com/soulitzer, https://github.com/ezyang
2023-10-16 22:32:23 +00:00

52 lines
1.3 KiB
Python

import dataclasses
import inspect
import sys
from typing import Callable, Tuple
@dataclasses.dataclass
class Kernel:
    """Models a (function, source location) pair.

    An instance is itself callable: calling it forwards every positional
    and keyword argument to ``func`` and returns whatever ``func`` returns.
    """

    # The wrapped callable; invoked by ``__call__``.
    func: Callable
    # Source location associated with ``func``; stored as given, not validated.
    source: str

    def __call__(self, *args, **kwargs):
        """Call ``func`` with the given arguments and return its result."""
        return self.func(*args, **kwargs)
class RegistrationHandle:
    """Does something when someone calls .destroy() on it.

    The handle stores a zero-argument callback and runs it each time
    ``destroy()`` is called. There is no guard against repeated calls:
    calling ``destroy()`` twice runs the callback twice.
    """

    def __init__(self, on_destroy: Callable):
        """Store ``on_destroy``, the callback to run on ``destroy()``."""
        self._on_destroy = on_destroy

    def destroy(self) -> None:
        """Run the stored callback; its return value is discarded."""
        self._on_destroy()
def get_source(stacklevel: int) -> str:
    """Get a string that represents the caller.

    Example: "/path/to/foo.py:42"

    Use stacklevel=1 to get the caller's source
    Use stacklevel=2 to get the caller's caller's source
    etc.

    Args:
        stacklevel: How many frames above this function to look at.

    Returns:
        A string of the form "<filename>:<lineno>" for the selected frame.

    Raises:
        ValueError: Propagated from ``sys._getframe`` when ``stacklevel``
            is deeper than the current call stack.
    """
    # Only the filename and line number are used, so pass context=0 to
    # stop ``inspect.getframeinfo`` from loading source lines it would
    # otherwise fetch for the (unused) code_context field.
    frame = inspect.getframeinfo(sys._getframe(stacklevel), context=0)
    return f"{frame.filename}:{frame.lineno}"
def parse_namespace(qualname: str) -> Tuple[str, str]:
    """Split a qualified operator name into its namespace and name.

    Example: "aten::sin" -> ("aten", "sin")

    Args:
        qualname: A string of the form "namespace::name".

    Returns:
        The ``(namespace, name)`` pair: the text before and after "::".

    Raises:
        ValueError: If splitting ``qualname`` on "::" does not give exactly
            two parts (no separator, or more than one separator).

    Note:
        Only the number of parts is checked; empty components (for example
        "::name") are not rejected here.
    """
    splits = qualname.split("::")
    if len(splits) != 2:
        raise ValueError(
            f"Expected `qualname` to be of the form "
            f'"namespace::name", but got {qualname}. '
            f"The qualname passed to the torch.library APIs must consist "
            f"of a namespace and a name, e.g. aten::sin"
        )
    return splits[0], splits[1]