mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
This is the entrypoint for defining an opaque/blackbox custom op (i.e. PyTorch will never peek into it). In this PR, you can specify backend impls and the abstract impl for this op. NB: most of this PR is docstrings, so please don't be intimidated by the line count. There are a number of interesting features: - we infer the schema from type hints. In a followup I add the ability to manually specify a schema. - name inference. The user needs to manually specify an op name for now. In a followup we add the ability to automatically infer a name (this is a little tricky). - custom_op registrations can override each other. This makes them more pleasant to work with in environments like Colab. - we require that the outputs of the custom_op do not alias any inputs or each other. We enforce this via a runtime check, but can relax this into an opcheck test if it really matters in the future. Test Plan: - new tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/122344 Approved by: https://github.com/ezyang, https://github.com/albanD
31 lines
933 B
Python
31 lines
933 B
Python
import contextlib
|
|
import torch
|
|
from torch._C._functorch import (
|
|
set_single_level_autograd_function_allowed,
|
|
get_single_level_autograd_function_allowed,
|
|
unwrap_if_dead,
|
|
)
|
|
from typing import Union, Tuple
|
|
from torch.utils._exposed_in import exposed_in
|
|
|
|
# Public names of this module. ``exposed_in`` is not defined here; it is
# re-exported from ``torch.utils._exposed_in`` so callers can import it from this module.
__all__ = ["exposed_in", "argnums_t", "enable_single_level_autograd_function", "unwrap_dead_wrappers"]
|
|
|
|
@contextlib.contextmanager
def enable_single_level_autograd_function():
    """Temporarily set the functorch ``single_level_autograd_function_allowed`` flag to ``True``.

    On entry, the current value of the flag is read and the flag is then set
    to ``True``. On exit, whether normal or via an exception raised in the
    ``with`` body, the previously read value is restored.

    Yields:
        None.
    """
    # Read the previous state *before* entering ``try``. If this call were
    # inside the ``try`` and raised, the ``finally`` clause would reference an
    # unbound ``prev_state`` and mask the original error with UnboundLocalError.
    prev_state = get_single_level_autograd_function_allowed()
    try:
        set_single_level_autograd_function_allowed(True)
        yield
    finally:
        # Restore the caller's setting even if the body raised.
        set_single_level_autograd_function_allowed(prev_state)
|
|
|
|
def unwrap_dead_wrappers(args):
    """Return ``args`` as a tuple, with every tensor passed through ``unwrap_if_dead``.

    Only the top level of ``args`` is inspected. Elements that are
    ``torch.Tensor`` instances are replaced by ``unwrap_if_dead(element)``;
    all other elements, including nested containers, are returned unchanged
    and in their original order.

    NOTE(review): judging by its name, ``unwrap_if_dead`` presumably strips
    functorch wrappers whose transform level is no longer active. Confirm
    against the ``torch._C._functorch`` implementation.
    """
    # NB: a flat loop rather than tree_map_only, for performance reasons.
    unwrapped = []
    for item in args:
        if isinstance(item, torch.Tensor):
            item = unwrap_if_dead(item)
        unwrapped.append(item)
    return tuple(unwrapped)
|
|
|
|
# Type alias for an ``argnums`` value: either a single int or a tuple of ints.
# NOTE(review): presumably these are indices of positional arguments; confirm at call sites.
argnums_t = Union[int, Tuple[int, ...]]
|