Extending PyTorch
=================

In this note we'll cover ways of extending :mod:`torch.nn`,
:mod:`torch.autograd`, :mod:`torch`, and writing custom C extensions utilizing our C
libraries.

Extending :mod:`torch.autograd`
-------------------------------

.. currentmodule:: torch.autograd

Adding operations to :mod:`~torch.autograd` requires implementing a new
:class:`Function` subclass for each operation. Recall that :class:`Function` s
are what :mod:`~torch.autograd` uses to compute the results and gradients, and
encode the operation history. Every new function requires you to implement two
methods:

- :meth:`~Function.forward` - the code that performs the operation. It can take
  as many arguments as you want, with some of them being optional, if you
  specify the default values. All kinds of Python objects are accepted here.
  :class:`Tensor` arguments that track history (i.e., with
  ``requires_grad=True``) will be converted to ones that don't track history
  before the call, and their use will be registered in the graph. Note that this
  logic won't traverse lists/dicts/any other data structures and will only
  consider :class:`Tensor` s that are direct arguments to the call. You can
  return either a single :class:`Tensor` output, or a :class:`tuple` of
  :class:`Tensor` s if there are multiple outputs. Also, please refer to the
  docs of :class:`Function` to find descriptions of useful methods that can be
  called only from :meth:`~Function.forward`.
- :meth:`~Function.backward` - gradient formula. It will be given
  as many :class:`Tensor` arguments as there were outputs, with each of them
  representing the gradient w.r.t. that output. It should return as many
  :class:`Tensor` s as there were inputs, with each of them containing the
  gradient w.r.t. its corresponding input. If your inputs didn't require
  gradient (:attr:`~ctx.needs_input_grad` is a tuple of booleans indicating
  whether each input needs gradient computation), or were non-:class:`Tensor`
  objects, you can return :class:`python:None`. Also, if you have optional
  arguments to :meth:`~Function.forward` you can return more gradients than there
  were inputs, as long as they're all :any:`python:None`.

.. note::

    It's the user's responsibility to use the special functions in the forward's `ctx`
    properly in order to ensure that the new :class:`Function` works properly with
    the autograd engine.

    - :meth:`~torch.autograd.function._ContextMethodMixin.save_for_backward` must be
      used when saving input or output of the forward to be used later in the backward.
    - :meth:`~torch.autograd.function._ContextMethodMixin.mark_dirty` must be used to
      mark any input that is modified in-place by the forward function.
    - :meth:`~torch.autograd.function._ContextMethodMixin.mark_non_differentiable` must
      be used to tell the engine if an output is not differentiable.

Below you can find code for a ``Linear`` function from :mod:`torch.nn`, with
additional comments::

    # Inherit from Function
    class LinearFunction(Function):

        # Note that both forward and backward are @staticmethods
        @staticmethod
        # bias is an optional argument
        def forward(ctx, input, weight, bias=None):
            ctx.save_for_backward(input, weight, bias)
            output = input.mm(weight.t())
            if bias is not None:
                output += bias.unsqueeze(0).expand_as(output)
            return output

        # This function has only a single output, so it gets only one gradient
        @staticmethod
        def backward(ctx, grad_output):
            # This is a pattern that is very convenient - at the top of backward
            # unpack saved_tensors and initialize all gradients w.r.t. inputs to
            # None. Thanks to the fact that additional trailing Nones are
            # ignored, the return statement is simple even when the function has
            # optional inputs.
            input, weight, bias = ctx.saved_tensors
            grad_input = grad_weight = grad_bias = None

            # These needs_input_grad checks are optional and are only there to
            # improve efficiency. If you want to make your code simpler, you can
            # skip them. Returning gradients for inputs that don't require it is
            # not an error.
            if ctx.needs_input_grad[0]:
                grad_input = grad_output.mm(weight)
            if ctx.needs_input_grad[1]:
                grad_weight = grad_output.t().mm(input)
            if bias is not None and ctx.needs_input_grad[2]:
                grad_bias = grad_output.sum(0)

            return grad_input, grad_weight, grad_bias

Now, to make it easier to use these custom ops, we recommend aliasing their
``apply`` method::

    linear = LinearFunction.apply

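For example, the alias can then be called like a plain function (a minimal
usage sketch, assuming suitably shaped tensors)::

    input = torch.randn(20, 20, requires_grad=True)
    weight = torch.randn(30, 20, requires_grad=True)
    output = linear(input, weight)
    # Gradients flow back through LinearFunction.backward
    output.sum().backward()
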
Here, we give an additional example of a function that is parametrized by
non-Tensor arguments::

    class MulConstant(Function):
        @staticmethod
        def forward(ctx, tensor, constant):
            # ctx is a context object that can be used to stash information
            # for backward computation
            ctx.constant = constant
            return tensor * constant

        @staticmethod
        def backward(ctx, grad_output):
            # We return as many input gradients as there were arguments.
            # Gradients of non-Tensor arguments to forward must be None.
            return grad_output * ctx.constant, None

.. note::

    Inputs to ``backward``, i.e., :attr:`grad_output`, can also be Tensors that
    track history. So if ``backward`` is implemented with differentiable
    operations (e.g., invocation of another custom
    :class:`~torch.autograd.Function`), higher order derivatives will work.

You probably want to check if the backward method you implemented actually
computes the derivatives of your function. It is possible by comparing with
numerical approximations using small finite differences::

    from torch.autograd import gradcheck

    # gradcheck takes a tuple of tensors as input, checks if your gradient
    # evaluated with these tensors is close enough to numerical
    # approximations, and returns True if they all verify this condition.
    input = (torch.randn(20, 20, dtype=torch.double, requires_grad=True),
             torch.randn(30, 20, dtype=torch.double, requires_grad=True))
    test = gradcheck(linear, input, eps=1e-6, atol=1e-4)
    print(test)

See :ref:`grad-check` for more details on finite-difference gradient comparisons.

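If ``backward`` is implemented with differentiable operations (as discussed in
the note above), you can check second derivatives the same way with
:func:`~torch.autograd.gradgradcheck` (a minimal sketch, reusing ``linear`` and
``input`` from the example above)::

    from torch.autograd import gradgradcheck

    # Compares gradients of gradients against finite-difference
    # approximations, analogously to gradcheck.
    test = gradgradcheck(linear, input, eps=1e-6, atol=1e-4)
    print(test)
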
Extending :mod:`torch.nn`
-------------------------

.. currentmodule:: torch.nn

:mod:`~torch.nn` exports two kinds of interfaces - modules and their functional
versions. You can extend it in both ways, but we recommend using modules for
all kinds of layers that hold any parameters or buffers, and recommend using
the functional form for parameter-less operations like activation functions,
pooling, etc.

Adding a functional version of an operation is already fully covered in the
section above.

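For instance, the functional version of the ``Linear`` operation is just a thin
wrapper around the :class:`~torch.autograd.Function` defined earlier (a sketch,
reusing ``LinearFunction`` from the previous section)::

    def linear(input, weight, bias=None):
        # A parameter-less functional interface; all state is passed in
        # explicitly by the caller.
        return LinearFunction.apply(input, weight, bias)
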
Adding a :class:`Module`
^^^^^^^^^^^^^^^^^^^^^^^^

Since :mod:`~torch.nn` heavily utilizes :mod:`~torch.autograd`, adding a new
:class:`Module` requires implementing a :class:`~torch.autograd.Function`
that performs the operation and can compute the gradient. From now on let's
assume that we want to implement a ``Linear`` module and we have the function
implemented as in the listing above. There's very little code required to
add this. Now, there are two functions that need to be implemented:

- ``__init__`` (*optional*) - takes in arguments such as kernel sizes, numbers
  of features, etc. and initializes parameters and buffers.
- :meth:`~Module.forward` - instantiates a :class:`~torch.autograd.Function` and
  uses it to perform the operation. It's very similar to a functional wrapper
  shown above.

This is how a ``Linear`` module can be implemented::

    class Linear(nn.Module):
        def __init__(self, input_features, output_features, bias=True):
            super(Linear, self).__init__()
            self.input_features = input_features
            self.output_features = output_features

            # nn.Parameter is a special kind of Tensor, that will get
            # automatically registered as Module's parameter once it's assigned
            # as an attribute. Parameters and buffers need to be registered, or
            # they won't appear in .parameters() (doesn't apply to buffers), and
            # won't be converted when e.g. .cuda() is called. You can use
            # .register_buffer() to register buffers.
            # nn.Parameters require gradients by default.
            self.weight = nn.Parameter(torch.Tensor(output_features, input_features))
            if bias:
                self.bias = nn.Parameter(torch.Tensor(output_features))
            else:
                # You should always register all possible parameters, but the
                # optional ones can be None if you want.
                self.register_parameter('bias', None)

            # Not a very smart way to initialize weights
            self.weight.data.uniform_(-0.1, 0.1)
            if self.bias is not None:
                self.bias.data.uniform_(-0.1, 0.1)

        def forward(self, input):
            # See the autograd section for explanation of what happens here.
            return LinearFunction.apply(input, self.weight, self.bias)

        def extra_repr(self):
            # (Optional) Set the extra information about this module. You can test
            # it by printing an object of this class.
            return 'input_features={}, output_features={}, bias={}'.format(
                self.input_features, self.output_features, self.bias is not None
            )

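A minimal usage sketch, assuming the ``LinearFunction`` and ``Linear``
definitions above are in scope::

    module = Linear(20, 30)
    print(module)  # extra_repr output appears in the printed representation
    output = module(torch.randn(128, 20))
    output.sum().backward()  # populates module.weight.grad and module.bias.grad
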
Extending :mod:`torch`
----------------------

You can create custom types that emulate :class:`Tensor` by defining a custom
class with methods that match :class:`Tensor`. But what if you want to be able
to pass these types to functions like :func:`torch.add` in the top-level
:mod:`torch` namespace that accept :class:`Tensor` operands?

If your custom Python type defines a method named ``__torch_function__``, PyTorch
will invoke your ``__torch_function__`` implementation when an instance of your
custom class is passed to a function in the :mod:`torch` namespace. This makes
it possible to define custom implementations for any of the functions in the
:mod:`torch` namespace which your ``__torch_function__`` implementation can call,
allowing your users to make use of your custom type with existing PyTorch
workflows that they have already written for :class:`Tensor`. This works with
"duck" types that are unrelated to :class:`Tensor` as well as user-defined
subclasses of :class:`Tensor`.

Extending :mod:`torch` with a :class:`Tensor`-like type
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. note:: This functionality is inspired by the NumPy ``__array_function__``
          protocol. See `the NumPy documentation
          <https://docs.scipy.org/doc/numpy/user/basics.dispatch.html#basics-dispatch>`_
          and `NEP-0018
          <https://numpy.org/neps/nep-0018-array-function-protocol.html>`_ for
          more details.

To make this concrete, let's begin with a simple example that illustrates the
API dispatch mechanism. We'll create a custom type that represents a 2D scalar
tensor, parametrized by the order ``N`` and the value along the diagonal
entries, ``value``::

    class ScalarTensor(object):
        def __init__(self, N, value):
            self._N = N
            self._value = value

        def __repr__(self):
            return "ScalarTensor(N={}, value={})".format(self._N, self._value)

        def tensor(self):
            return self._value * torch.eye(self._N)

This first iteration of the design isn't very useful. The main functionality of
``ScalarTensor`` is to provide a more compact string representation of a scalar
tensor than in the base tensor class::

    >>> d = ScalarTensor(5, 2)
    >>> d
    ScalarTensor(N=5, value=2)
    >>> d.tensor()
    tensor([[2., 0., 0., 0., 0.],
            [0., 2., 0., 0., 0.],
            [0., 0., 2., 0., 0.],
            [0., 0., 0., 2., 0.],
            [0., 0., 0., 0., 2.]])

If we try to use this object with the :mod:`torch` API, we will run
into issues::

    >>> import torch
    >>> torch.mean(d)
    TypeError: mean(): argument 'input' (position 1) must be Tensor, not ScalarTensor

Adding a ``__torch_function__`` implementation to ``ScalarTensor`` makes it
possible for the above operation to succeed. Let's re-do our implementation,
this time adding a ``__torch_function__`` implementation::

    HANDLED_FUNCTIONS = {}
    class ScalarTensor(object):
        def __init__(self, N, value):
            self._N = N
            self._value = value

        def __repr__(self):
            return "ScalarTensor(N={}, value={})".format(self._N, self._value)

        def tensor(self):
            return self._value * torch.eye(self._N)

        def __torch_function__(self, func, args=(), kwargs=None):
            if kwargs is None:
                kwargs = {}
            if func not in HANDLED_FUNCTIONS:
                return NotImplemented
            return HANDLED_FUNCTIONS[func](*args, **kwargs)

The ``__torch_function__`` method takes three arguments: ``func``, a reference to
the torch API function that is being overridden, ``args``, the tuple of arguments
passed to the function, and ``kwargs``, the dict of keyword arguments passed to
the function. It uses a global dispatch table named ``HANDLED_FUNCTIONS`` to
store custom implementations. The keys of this dictionary are functions in the
``torch`` namespace and the values are implementations for ``ScalarTensor``.

.. note:: Using a global dispatch table is not a mandated part of the
          ``__torch_function__`` API, it is just a useful design pattern for
          structuring your override implementations.

This class definition isn't quite enough to make ``torch.mean`` do the right
thing when we pass it a ``ScalarTensor`` -- we also need to define an
implementation for ``torch.mean`` for ``ScalarTensor`` operands and add the
implementation to the ``HANDLED_FUNCTIONS`` dispatch table dictionary. One way
of doing this is to define a decorator::

    import functools
    def implements(torch_function):
        """Register a torch function override for ScalarTensor"""
        def decorator(func):
            functools.wraps(torch_function)(func)
            HANDLED_FUNCTIONS[torch_function] = func
            return func
        return decorator

which can be applied to the implementation of our override::

    @implements(torch.mean)
    def mean(input):
        return float(input._value) / input._N

With this change we can now use ``torch.mean`` with ``ScalarTensor``::

    >>> d = ScalarTensor(5, 2)
    >>> torch.mean(d)
    0.4

Of course ``torch.mean`` is an example of the simplest kind of function to
override since it only takes one operand. We can use the same machinery to
override a function that takes more than one operand, any one of which might be
a tensor or tensor-like that defines ``__torch_function__``, for example for
:func:`torch.add`::

    def ensure_tensor(data):
        if isinstance(data, ScalarTensor):
            return data.tensor()
        return torch.as_tensor(data)

    @implements(torch.add)
    def add(input, other):
        try:
            if input._N == other._N:
                return ScalarTensor(input._N, input._value + other._value)
            else:
                raise ValueError("Shape mismatch!")
        except AttributeError:
            return torch.add(ensure_tensor(input), ensure_tensor(other))

This version has a fast path for when both operands are ``ScalarTensor``
instances and also a slower path which degrades to converting the data to
tensors when either operand is not a ``ScalarTensor``. That makes the override
behave correctly when either operand is a ``ScalarTensor`` or a regular
:class:`Tensor`::

    >>> s = ScalarTensor(2, 2)
    >>> torch.add(s, s)
    ScalarTensor(N=2, value=4)
    >>> t = torch.tensor([[1, 1], [1, 1]])
    >>> torch.add(s, t)
    tensor([[3., 1.],
            [1., 3.]])

Note that our implementation of ``add`` does not take ``alpha`` or ``out`` as
keyword arguments like :func:`torch.add` does::

    >>> torch.add(s, s, alpha=2)
    TypeError: add() got an unexpected keyword argument 'alpha'

For speed and flexibility the ``__torch_function__`` dispatch mechanism does not
check that the signature of an override function matches the signature of the
function being overridden in the :mod:`torch` API. For some applications ignoring
optional arguments would be fine, but to ensure full compatibility with
:class:`Tensor`, user implementations of torch API functions should take care to
exactly emulate the API of the function that is being overridden.

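For example, a version of the ``add`` override that emulates the full
:func:`torch.add` signature might look like the sketch below; the treatment of
``alpha`` and ``out`` here is our own assumption about reasonable semantics for
``ScalarTensor``::

    @implements(torch.add)
    def add(input, other, alpha=1, out=None):
        if out is not None:
            # Fall back to real tensors so torch.add can fill ``out`` in-place.
            return torch.add(ensure_tensor(input), ensure_tensor(other),
                             alpha=alpha, out=out)
        try:
            if input._N == other._N:
                return ScalarTensor(input._N, input._value + alpha * other._value)
            else:
                raise ValueError("Shape mismatch!")
        except AttributeError:
            return torch.add(ensure_tensor(input), ensure_tensor(other), alpha=alpha)
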
Functions in the :mod:`torch` API that do not have explicit overrides will
return ``NotImplemented`` from ``__torch_function__``. If all operands with
``__torch_function__`` defined on them return ``NotImplemented``, PyTorch will
raise a ``TypeError``. This means that most of the time operations that do not
have explicit overrides for a type will raise a ``TypeError`` when an instance
of such a type is passed::

    >>> torch.mul(s, 3)
    TypeError: no implementation found for 'torch.mul' on types that
    implement __torch_function__: [ScalarTensor]

In practice this means that if you would like to implement your overrides using
a ``__torch_function__`` implementation along these lines, you will need to
explicitly implement the full :mod:`torch` API or the entire subset of the API
that you care about for your use case. This may be a tall order as the full
:mod:`torch` API is quite extensive.

Another option is to not return ``NotImplemented`` for operations that are not
handled but to instead pass a :class:`Tensor` to the original :mod:`torch`
function when no override is available. For example, if we change our
implementation of ``__torch_function__`` for ``ScalarTensor`` to the one below::

    def __torch_function__(self, func, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func not in HANDLED_FUNCTIONS:
            args = [a.tensor() if hasattr(a, 'tensor') else a for a in args]
            return func(*args, **kwargs)
        return HANDLED_FUNCTIONS[func](*args, **kwargs)

Then :func:`torch.mul` will work correctly, although the return type will always
be a :class:`Tensor` rather than a ``ScalarTensor``, even if both operands
are ``ScalarTensor`` instances::

    >>> s = ScalarTensor(2, 2)
    >>> torch.mul(s, s)
    tensor([[4., 0.],
            [0., 4.]])

Also see the ``MetadataTensor`` example below for another variation on this
pattern, which instead always returns a ``MetadataTensor`` to propagate metadata
through operations in the :mod:`torch` API.

Extending :mod:`torch` with a :class:`Tensor` wrapper type
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Another useful case is a type that wraps a :class:`Tensor`, either as an
attribute or via subclassing. Below we implement a special case of this sort of
type, a ``MetadataTensor`` that attaches a dictionary of metadata to a
:class:`Tensor` that is propagated through :mod:`torch` operations. Since this
is a generic sort of wrapping for the full :mod:`torch` API, we do not need to
individually implement each override, so we can make the ``__torch_function__``
implementation more permissive about what operations are allowed::

    class MetadataTensor(object):
        def __init__(self, data, metadata=None, **kwargs):
            self._t = torch.as_tensor(data, **kwargs)
            self._metadata = metadata

        def __repr__(self):
            return "Metadata:\n{}\n\ndata:\n{}".format(self._metadata, self._t)

        def __torch_function__(self, func, args=(), kwargs=None):
            if kwargs is None:
                kwargs = {}
            args = [a._t if hasattr(a, '_t') else a for a in args]
            ret = func(*args, **kwargs)
            return MetadataTensor(ret, metadata=self._metadata)

This simple implementation won't necessarily work with every function in the
:mod:`torch` API but it is good enough to capture most common operations::

    >>> metadata = {'owner': 'Ministry of Silly Walks'}
    >>> m = MetadataTensor([[1, 2], [3, 4]], metadata=metadata)
    >>> t = torch.tensor([[1, 2], [1, 2]])
    >>> torch.add(t, m)
    Metadata:
    {'owner': 'Ministry of Silly Walks'}

    data:
    tensor([[2, 4],
            [4, 6]])
    >>> torch.mul(t, m)
    Metadata:
    {'owner': 'Ministry of Silly Walks'}

    data:
    tensor([[1, 4],
            [3, 8]])

Operations on multiple types that define ``__torch_function__``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

It is possible to use the torch API with multiple distinct types that each have
a ``__torch_function__`` implementation, but special care must be taken. In such
a case the rules are:

* The dispatch operation gathers all distinct implementations of
  ``__torch_function__`` for each operand and calls them in order: subclasses
  before superclasses, and otherwise left to right in the operator expression.
* If any value other than ``NotImplemented`` is returned, that value is
  returned as the result. Implementations can register that they do not
  implement an operation by returning ``NotImplemented``.
* If all of the ``__torch_function__`` implementations return
  ``NotImplemented``, PyTorch raises a ``TypeError``.

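To illustrate these rules, consider a minimal sketch with two hypothetical duck
types (the class names and behavior here are illustrative, not part of the
API)::

    class Base(object):
        def __torch_function__(self, func, args=(), kwargs=None):
            return "Base handled {}".format(func.__name__)

    class Derived(Base):
        def __torch_function__(self, func, args=(), kwargs=None):
            # Defer to the next implementation in the dispatch order by
            # returning NotImplemented.
            return NotImplemented

Because ``Derived`` subclasses ``Base``, its implementation is tried first even
when a ``Base`` instance appears earlier in the expression; since it returns
``NotImplemented``, dispatch falls through to ``Base``'s implementation::

    >>> torch.add(Base(), Derived())
    'Base handled add'
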
Writing custom C++ extensions
-----------------------------

See this
`PyTorch tutorial <https://pytorch.org/tutorials/advanced/cpp_extension.html>`_
for a detailed explanation and examples.

Documentation is available at :doc:`../cpp_extension`.

Writing custom C extensions
---------------------------

An example is available in
`this GitHub repository <https://github.com/pytorch/extension-ffi>`_.