This is the entrypoint for defining an opaque/blackbox custom op (i.e., PyTorch will never peek into it). In this PR, you can specify backend impls and the abstract impl for this op. NB: most of this PR is docstrings; please don't be intimidated by the line count. There are a number of interesting features:

- We infer the schema from type hints. In a followup I add the ability to manually specify a schema.
- Name inference: the user needs to manually specify an op name for now. In a followup we add the ability to automatically infer a name (this is a little tricky).
- custom_op registrations can override each other. This makes them more pleasant to work with in environments like colab.
- We require that the outputs of the custom op do not alias any inputs or each other. We enforce this via a runtime check, but can relax it into an opcheck test if it really matters in the future.

Test Plan:
- new tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122344
Approved by: https://github.com/ezyang, https://github.com/albanD
import inspect
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union

from torch.utils._exposed_in import exposed_in

from .. import _C, _library, library, Tensor


device_types_t = Optional[Union[str, Sequence[str]]]
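# Either None (register for all device types), a single device type string
# like "cpu", or a sequence of device types like ("cpu", "cuda").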


@exposed_in("torch.library")
def custom_op(
    name: str,
    /,
    *,
    mutated_args: Sequence[str],
    device_types: device_types_t = None,
    qualname: Optional[str] = None,
) -> Callable:
    """Wraps a function into a custom operator.

    Reasons why you may want to create a custom op include:

    - Wrapping a third-party library or custom kernel to work with PyTorch
      subsystems like Autograd.
    - Preventing torch.compile/export/FX tracing from peeking inside your function.

    This API is used as a decorator around a function (please see examples).
    The provided function must have type hints; these are needed to interface
    with PyTorch's various subsystems.

    Args:
        name (str): A name for the custom op that looks like "{namespace}::{name}",
            e.g. "mylib::my_linear". The name is used as a stable identifier
            if you wish to serialize the custom op, e.g., via torch.save/torch.export.
            To avoid name collisions, please use your project name as the namespace.
        mutated_args (Sequence[str]): The names of args that the function mutates.
            This MUST be accurate, otherwise, the behavior is undefined.
        device_types (None | str | Sequence[str]): The device type(s) the function
            is valid for. If no device type is provided, then the function
            is used as the default implementation for all device types.
            Examples: "cpu", "cuda".

    Examples::
        >>> import torch
        >>> from torch import Tensor
        >>> from torch.library import custom_op
        >>> import numpy as np
        >>>
        >>> @custom_op("mylib::numpy_sin", mutated_args=())
        >>> def numpy_sin(x: Tensor) -> Tensor:
        >>>     x_np = x.cpu().numpy()
        >>>     y_np = np.sin(x_np)
        >>>     return torch.from_numpy(y_np).to(device=x.device)
        >>>
        >>> x = torch.randn(3)
        >>> y = numpy_sin(x)
        >>> assert torch.allclose(y, x.sin())
        >>>
        >>> # Example of a custom op that only works for one device type.
        >>> @custom_op("mylib::numpy_sin_cpu", mutated_args=(), device_types="cpu")
        >>> def numpy_sin_cpu(x: Tensor) -> Tensor:
        >>>     x_np = x.numpy()
        >>>     y_np = np.sin(x_np)
        >>>     return torch.from_numpy(y_np)
        >>>
        >>> x = torch.randn(3)
        >>> y = numpy_sin_cpu(x)
        >>> assert torch.allclose(y, x.sin())

    """
    assert len(mutated_args) == 0, "NYI"
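    # "NYI" (not yet implemented): mutating custom ops are not supported yet,
    # so the assert above restricts usage to functional ops (empty mutated_args).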

    def inner(fn):
        import torch

        schema = torch._custom_op.impl.infer_schema(fn, mutated_args)
        namespace, opname = name.split("::")
        result = CustomOpDef(namespace, opname, schema, fn)
        result.register_impl(device_types)(fn)
        return result

    return inner


class CustomOpDef:
    """CustomOpDef is a wrapper around a function that turns it into a custom op.

    It has various methods for registering additional behavior for this
    custom op.

    You should not instantiate CustomOpDef directly; instead, use the
    :func:`torch.library.custom_op` API.
    """

    def __init__(self, namespace: str, name: str, schema: str, fn: Callable) -> None:
        # Fields used to interface with the PyTorch dispatcher
        self._namespace = namespace
        self._name = name
        self._schema = schema

        self._init_fn = fn

        self._backend_fns: Dict[Union[str, None], Callable] = {}
        self._abstract_fn: Optional[Callable] = None

        self._lib = get_library_allowing_overwrite(self._namespace, self._name)
        self._register_to_dispatcher()

    @property
    def _qualname(self) -> str:
        return f"{self._namespace}::{self._name}"

    def __repr__(self) -> str:
        return f"<CustomOpDef({self._qualname})>"

    def register_impl(
        self, device_types: device_types_t, fn: Optional[Callable] = None
    ) -> Callable:
        """Register an implementation for a device type for this operator.

        Some valid device_types are: "cpu", "cuda", "xla", "mps", "ipu", "xpu".
        This API may be used as a decorator.

        Args:
            fn (Callable): The function to register as the implementation for
                the given device types.
            device_types (str | Sequence[str]): The device type(s) to register
                an impl to.

        Examples::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
            >>> import torch
            >>> from torch import Tensor
            >>> from torch.library import custom_op
            >>> import numpy as np
            >>>
            >>> # Example of split cpu and cuda definitions
            >>> @custom_op("mylib::numpy_sin", mutated_args=(), device_types="cpu")
            >>> def numpy_sin(x: Tensor) -> Tensor:
            >>>     x_np = x.numpy()
            >>>     y_np = np.sin(x_np)
            >>>     return torch.from_numpy(y_np)
            >>>
            >>> # Add implementations for the cuda device
            >>> @numpy_sin.register_impl("cuda")
            >>> def _(x):
            >>>     x_np = x.cpu().numpy()
            >>>     y_np = np.sin(x_np)
            >>>     return torch.from_numpy(y_np).to(device=x.device)
            >>>
            >>> x_cpu = torch.randn(3)
            >>> x_cuda = x_cpu.cuda()
            >>> assert torch.allclose(numpy_sin(x_cpu), x_cpu.sin())
            >>> assert torch.allclose(numpy_sin(x_cuda), x_cuda.sin())

        """

        def inner(fn):
            if device_types is None or isinstance(device_types, str):
                dtypes: List[Union[str, None]] = [device_types]
            else:
                dtypes = list(device_types)
            for device_type in dtypes:
                if device_type not in self._backend_fns:

                    def backend_impl(*args, **kwargs):
                        # Checks the assumption that outputs cannot alias
                        # inputs or other outputs.
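                        # For example (hypothetical impls), `def f(x): return x`
                        # returns an input unchanged, and
                        # `def f(x): y = x.sin(); return y, y` returns the same
                        # tensor twice; both would trip this check.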
                        storages = set()
                        for tensor in iter_tensors(args, kwargs):
                            storages.add(id(tensor.untyped_storage()))

                        result = self._backend_fns[device_type](*args, **kwargs)

                        tuple_result = result
                        if not isinstance(result, tuple):
                            tuple_result = (result,)
                        for tensor in iter_tensors(tuple_result, {}):
                            key = id(tensor.untyped_storage())
                            if key in storages:
                                fn = self._backend_fns[device_type]
                                module = inspect.getmodule(fn)
                                raise RuntimeError(
                                    f"Tensors returned from custom ops (1) must not "
                                    f"be inputs to the custom op and (2) may not alias "
                                    f"any inputs or other returns. Please clone the "
                                    f"offending output tensors (e.g. output.clone()) "
                                    f"or refactor your code. "
                                    f"Offending op: {self._name} (with implementation in {module})"
                                )
                            storages.add(key)
                        return result

                    if device_type is None:
                        self._lib.impl(
                            self._name, backend_impl, "CompositeExplicitAutograd"
                        )
                    else:
                        self._lib.impl(
                            self._name,
                            backend_impl,
                            _C._dispatch_key_for_device(device_type),
                        )
                self._backend_fns[device_type] = fn
            return fn

        # See NOTE: [Supporting decorator and non-decorator usage]
        if fn is None:
            return inner
        return inner(fn)

    def register_fake(self, fn: Callable) -> Callable:
        r"""Register a FakeTensor implementation for this custom op.

        This is necessary to get the operator to work efficiently with torch.compile.

        The Fake impl (sometimes also known as a meta kernel or abstract impl)
        specifies the behavior of this operator on Tensors that carry no data.
        Given some input Tensors with certain properties
        (sizes/strides/storage_offset/device), it specifies what the properties of
        the output Tensors are.

        Please see :func:`torch.library.impl_abstract` for more details.

        Args:
            fn (Callable): The function to register as the FakeTensor
                implementation.

        Examples:
            >>> import torch
            >>> import numpy as np
            >>> from torch import Tensor
            >>>
            >>> # Example 1: an operator without data-dependent output shape
            >>> @torch.library.custom_op("mylib::linear", mutated_args=())
            >>> def linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
            >>>     return (x @ weight.t()) + bias
            >>>
            >>> @linear.register_fake
            >>> def _(x, weight, bias):
            >>>     assert x.dim() == 2
            >>>     assert weight.dim() == 2
            >>>     assert bias.dim() == 1
            >>>     assert x.shape[1] == weight.shape[1]
            >>>     assert weight.shape[0] == bias.shape[0]
            >>>     assert x.device == weight.device
            >>>     return x.new_empty(x.size(0), weight.size(0))
            >>>
            >>> x = torch.randn(2, 2)
            >>> weight = torch.randn(2, 2)
            >>> bias = torch.randn(2)
            >>> # xdoctest: +SKIP("Requires Python <= 3.11")
            >>> out = torch.compile(linear, fullgraph=True)(x, weight, bias)
            >>> # xdoctest: +SKIP("Requires Python <= 3.11")
            >>> assert torch.allclose(out, torch.nn.functional.linear(x, weight, bias))
            >>>
            >>> # Example 2: an operator with data-dependent output shape
            >>> @torch.library.custom_op("mylib::nonzero", mutated_args=())
            >>> def nonzero(x: Tensor) -> Tensor:
            >>>     x_np = x.cpu().numpy()
            >>>     res = np.stack(np.nonzero(x_np), axis=1)
            >>>     return torch.tensor(res, device=x.device)
            >>>
            >>> @nonzero.register_fake
            >>> def _(x):
            >>>     # Number of nonzero-elements is data-dependent.
            >>>     # Since we cannot peek at the data in an abstract impl,
            >>>     # we use the ctx object to construct a new symint that
            >>>     # represents the data-dependent size.
            >>>     ctx = torch.library.get_ctx()
            >>>     nnz = ctx.new_dynamic_size()
            >>>     shape = [nnz, x.dim()]
            >>>     result = x.new_empty(shape, dtype=torch.int64)
            >>>     return result
            >>>
            >>> x = torch.tensor([0, 1, 2, 0, 0, 1])
            >>> # xdoctest: +SKIP("Requires Python <= 3.11")
            >>> out = torch.compile(nonzero, fullgraph=True)(x)
            >>> # xdoctest: +SKIP("Requires Python <= 3.11")
            >>> assert torch.allclose(out, x.nonzero())

        """
        self._abstract_fn = fn
        return fn

    def _register_to_dispatcher(self) -> None:
        lib = self._lib
        lib.define(f"{self._name}{self._schema}")
        self._opoverload = _library.utils.lookup_op(self._qualname)

        def fake_impl(*args, **kwargs):
            if self._abstract_fn is None:
                raise RuntimeError(
                    f"There was no fake impl registered for {self}. "
                    f"This is necessary for torch.compile/export/fx tracing to work. "
                    f"Please use `{self._init_fn.__name__}.register_fake` to add a "
                    f"fake impl."
                )
            return self._abstract_fn(*args, **kwargs)

        library.impl_abstract(self._qualname, lib=lib)(fake_impl)
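
        # Note: we register the fake_impl wrapper (rather than self._abstract_fn
        # directly) so that register_fake may be called at any point after the
        # op is defined; the RuntimeError above only fires if tracing actually
        # invokes the op while no fake impl is registered.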

    def __call__(self, *args, **kwargs):
        return self._opoverload(*args, **kwargs)


# NOTE: [Supporting decorator and non-decorator usage]
#
# Some APIs may be used both as a decorator and not as a decorator.
# For example:
#
# >>> def fn(x):
# >>>     return x.sin()
# >>>
# >>> # Usage 1: not as a decorator
# >>> numpy_sin.register_impl("cuda", fn)
# >>>
# >>> # Usage 2: as a decorator
# >>> @numpy_sin.register_impl("cuda")
# >>> def fn2(x):
# >>>     return x.sin()
#
# The way we support this is that `register_impl` accepts an optional `fn`.
# If `fn` is provided (Usage 1), then we know that the user is using it not
# as a decorator.
# If `fn` is not provided (Usage 2), then `register_impl` needs to return a
# decorator.


OPDEF_TO_LIB: Dict[str, "library.Library"] = {}


def get_library_allowing_overwrite(namespace: str, name: str) -> "library.Library":
    qualname = f"{namespace}::{name}"

    if qualname in OPDEF_TO_LIB:
        OPDEF_TO_LIB[qualname]._destroy()
        del OPDEF_TO_LIB[qualname]

    lib = library.Library(namespace, "FRAGMENT")
    OPDEF_TO_LIB[qualname] = lib
    return lib
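
# A minimal sketch of the overwrite behavior enabled above (hypothetical usage):
# re-running a notebook cell like
#
#   @custom_op("mylib::foo", mutated_args=())
#   def foo(x: Tensor) -> Tensor:
#       return x.clone()
#
# replaces the previous "mylib::foo" registration instead of raising a
# duplicate-registration error, because the old Library fragment is destroyed
# here before a new one is created.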


def iter_tensors(
    args: Tuple[Any], kwargs: Dict[str, Any], allowed_nesting: int = 1
) -> Iterator[Tensor]:
    def check(arg):
        if isinstance(arg, Tensor):
            yield arg
        elif allowed_nesting > 0 and isinstance(arg, (tuple, list)):
            yield from iter_tensors(tuple(arg), {}, allowed_nesting - 1)

    for arg in args:
        yield from check(arg)
    for kwarg in kwargs.values():
        yield from check(kwarg)
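
# A minimal usage sketch of iter_tensors (assuming the definitions above):
#
#   a, b = torch.randn(2), torch.randn(3)
#   tensors = list(iter_tensors((a, [b], "not a tensor"), {"c": a}))
#   assert [id(t) for t in tensors] == [id(a), id(b), id(a)]
#
# Tensors nested more than `allowed_nesting` levels deep inside lists/tuples
# are not yielded.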