pytorch/torch/_functorch/utils.py
rzou 44c0c0fc0f Add torch.library.custom_op (#122344)
This is the entrypoint for defining an opaque/blackbox (i.e. PyTorch will
never peek into it) custom op. In this PR, you can specify backend impls
and the abstract impl for this op.

NB: most of this PR is docstrings, please don't be intimidated by the
line count.

There are a number of interesting features:
- we infer the schema from type hints. In a followup I add the ability
  to manually specify a schema.
- name inference. The user needs to manually specify an op name for now.
  In a followup we add the ability to automatically infer a name (this
  is a little tricky).
- custom_op registrations can override each other. This makes them
  more pleasant to work with in environments like Colab.
- we require that the outputs of the custom_op do not alias any inputs
  or each other. We enforce this via a runtime check, but can relax this
  into an opcheck test if it really matters in the future.
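The schema-inference step above can be sketched in plain Python. The helper below is illustrative only, built on `inspect.signature` and a hypothetical hint-to-schema-type table; it is not PyTorch's actual inference logic, which handles Tensors, optionals, defaults, and more.

```python
import inspect

# Hypothetical mapping from Python type hints to schema type names.
# Simplified sketch; not PyTorch's actual inference logic.
_TYPE_NAMES = {int: "int", float: "float", bool: "bool", str: "str"}


def infer_schema(fn):
    """Build a schema string like "(int x, float y) -> float" from type hints."""
    sig = inspect.signature(fn)
    params = []
    for name, param in sig.parameters.items():
        hint = param.annotation
        if hint is inspect.Parameter.empty:
            raise TypeError(f"parameter {name!r} is missing a type hint")
        params.append(f"{_TYPE_NAMES[hint]} {name}")
    return f"({', '.join(params)}) -> {_TYPE_NAMES[sig.return_annotation]}"


def scale(x: int, factor: float) -> float:
    return x * factor


print(infer_schema(scale))  # (int x, float factor) -> float
```

Because the schema is derived entirely from the signature, a missing hint can be reported eagerly at registration time rather than failing later at dispatch.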

Test Plan:
- new tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122344
Approved by: https://github.com/ezyang, https://github.com/albanD
2024-04-03 18:36:17 +00:00


import contextlib

import torch
from torch._C._functorch import (
    set_single_level_autograd_function_allowed,
    get_single_level_autograd_function_allowed,
    unwrap_if_dead,
)
from typing import Union, Tuple
from torch.utils._exposed_in import exposed_in

__all__ = [
    "exposed_in",
    "argnums_t",
    "enable_single_level_autograd_function",
    "unwrap_dead_wrappers",
]


@contextlib.contextmanager
def enable_single_level_autograd_function():
    try:
        prev_state = get_single_level_autograd_function_allowed()
        set_single_level_autograd_function_allowed(True)
        yield
    finally:
        set_single_level_autograd_function_allowed(prev_state)


def unwrap_dead_wrappers(args):
    # NB: doesn't use tree_map_only for performance reasons
    result = tuple(
        unwrap_if_dead(arg) if isinstance(arg, torch.Tensor) else arg
        for arg in args
    )
    return result


argnums_t = Union[int, Tuple[int, ...]]