The user does not need to return gradients for these args.
We also change how setup_context works to adapt to kwargonly-args. If
the user's op has no kwonly-args, then their setup_context function must
look like `setup_context(ctx, inputs, output)`: we require that the
arguments have the same names.
If the user's op has kwonly-args, then their setup_context function must
look like `setup_context(ctx, inputs, keyword_only_inputs, output)`.
We require that the arguments have the same names.
Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124806
Approved by: https://github.com/albanD, https://github.com/williamwen42
ghstack dependencies: #124637, #124805
old: `register_autograd(setup_context, backward, /)`
new: `register_autograd(backward, /, *, setup_context=None)`
Motivations:
- We introduce these APIs as "give us a backward and use setup_context
to save things for backward".
- setup_context isn't always necessary.
Test Plan:
- tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124403
Approved by: https://github.com/albanD
ghstack dependencies: #124180, #124200, #124299, #124134, #124199
Motivation:
- The API is used for registering an implementation for a specific
device type.
- "impl" is ambiguous and can be confused with Library.impl.
Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124200
Approved by: https://github.com/albanD
ghstack dependencies: #124180
A kernel has "dispatcher convention" if there is an additional keyset
arg at the beginning of the argument list. This PR:
- adds a way to register kernels with dispatcher_convention using
Library.impl (pass dispatcher_convention = True)
- adds OpOverload.redispatch
We use both of the above in the new custom ops API: we register the
autograd kernel in dispatcher convention so that we can actually call
redispatch like how pytorch built-in ops do it.
Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124089
Approved by: https://github.com/albanD
ghstack dependencies: #123937, #124064, #124065, #124066, #124071
We allow it to accept:
- a string with the op name
- an opoverload
- a new-style custom op
If any of these are referring to a new-style custom op (created with the
custom_op decorator), then we dispatch to CustomOpDef.register_fake.
Otherwise, we do what we previously did.
Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124066
Approved by: https://github.com/albanD
ghstack dependencies: #123937, #124064, #124065
Motivations:
- This makes things more consistent: using a Library object, you should
be able to do all of the registration APIs that tie registrations to
the lifetime of the Library.
- I need this for the next PR up in the stack, where we will have
torch.library.register_fake support both CustomOpDef (from the new
custom ops API) and other custom ops.
Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124065
Approved by: https://github.com/albanD
ghstack dependencies: #123937, #124064
The user provides a `setup_context` and a `backward_function`. These
get put into a torch.autograd.Function that gets registered as the
custom op's autograd implementation.
Test Plan:
- we update custom ops in the custom_op_db to use the new
register_autograd API.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123110
Approved by: https://github.com/albanD
ghstack dependencies: #123108, #123109
This is the entrypoint for defining an opaque/blackbox (e.g. PyTorch will
never peek into it) custom op. In this PR, you can specify backend impls
and the abstract impl for this op.
NB: most of this PR is docstrings, please don't be intimidated by the
line count.
There are a number of interesting features:
- we infer the schema from type hints. In a followup I add the ability
to manually specify a schema.
- name inference. The user needs to manually specify an op name for now.
In a followup we add the ability to automatically infer a name (this
is a little tricky).
- custom_op registrations can override each other. This makes them
more pleasant to work with in environments like colab.
- we require that the outputs of the custom_op do not alias any inputs
or each other. We enforce this via a runtime check, but can relax this
into an opcheck test if it really matters in the future.
Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122344
Approved by: https://github.com/ezyang, https://github.com/albanD