Commit Graph

2 Commits

Author SHA1 Message Date
Edward Yang
e42360d56f Remove default arguments before calling to __torch_dispatch__ (#61123)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61123

This applies the design pattern of removing explicit arguments when they
coincide with the default arguments.  This simplifies argument patterns
that dispatch kernels receive and make it easier for us to maintain BC
(as addition of a new default argument isn't immediately BC-breaking
for dispatch implementors).

There is an important extra API which I haven't implemented here yet,
which is to take an incomplete sequence of arguments and fill out their
defaults (in case the user did want normalization).  I plan on adding
that in a future PR.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Reviewed By: saketh-are

Differential Revision: D29853616

Pulled By: ezyang

fbshipit-source-id: 71c672cb3a7d4d01f838a1c7fcdb75a8ce7d058e
2021-07-23 10:41:35 -07:00
Edward Yang
aacc722aec Dispatch to Python via __torch_dispatch__ (#59760)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59760

See https://github.com/pytorch/pytorch/issues/59049

There are some moving parts to this PR, I'll structure this explanation so the straightforward parts go first, and then the less straightforward parts.

**The actual dispatch to Python.** The core logic of dispatch to Python lives in `concrete_dispatch_fn` in `torch/csrc/autograd/python_variable.cpp`. It takes the input IValue stack, scans all the arguments for Tensor arguments, and defers most of the heavy lifting to `handle_torch_function_no_python_arg_parser` which actually does all of the logic for calling out to torch dispatch (in particular, this function handles multiple dispatch situations for you). Because we have a different function name than regular `__torch_function__` handling, `handle_torch_function_no_python_arg_parser` is generalized to accept a magic method name to look for when testing if Tensors have custom handling or not. Unlike `__torch_function__`, by default there is no `__torch_dispatch__` on Tensor classes.

**Maintaining the Python dispatch key.** In order to get to the dispatch to Python logic, we must tag Tensors with the `__torch_dispatch__` magic method with the newly added Python dispatch key (separated from PythonFuncTorch to allow for a transitional period while they migrate to this mechanism). We expose a new private property `_is_python_dispatch` that assists in debugging if a Tensor is participating in Python dispatch or not. We apply the Python dispatch key the first time a PyObject for a Tensor is constructed (THPVariable_NewWithVar), testing if `__torch_dispatch__` exists with  then newly added `check_has_torch_dispatch`.

**Shallow copy and detach.** For the simple examples tested in this PR, most creations of Tensor route through the dispatcher. The exception to this is `shallow_copy_and_detach`, which bypasses the dispatcher and is used when saving tensors for backwards. When a Tensor is Python dispatch, we override the behavior of `shallow_copy_and_detach` to instead directly call into `__torch_dispatch__` to perform a `detach` operation (in the same way it would be invoked if you called `detach` directly). Because this Python call is triggered directly from c10::TensorImpl, it must be indirected through `PyInterpreter::detach`, which is the general mechanism for dynamic dispatching to the Python interpreter associated with a TensorImpl.

**torchdeploy compatibility.** The dispatch to Python logic cannot be directly registered to the dispatcher as it is compiled in the Python library, which will get loaded multiple times per torchdeploy interpreter. Thus, we must employ a two phase process. First, we register a fallback inside a non-Python library (aten/src/ATen/core/PythonFallbackKernel.cpp). Its job is to determine the appropriate PyInterpreter to handle the Python dispatch by going through all of the arguments and finding the first argument that has a PyObject/PyInterpreter. With this PyInterpreter, it makes another dynamic dispatch via "dispatch" which will go to the correct torchdeploy interpreter to handle dispatching to actual Python.

**Testing.** We provide a simple example of a LoggingTensor for testing, which can be used to generate TorchScript-like traces to observe what operations are being called when a Tensor is invoked. Although a LoggingTensor would be better implemented via an is-a relationship rather than a has-a relationship (as is done in the test), we've done it this way to show that arbitrarily complex compositions of tensors inside a tensor work properly.

**Known limitations.**

* We haven't adjusted any operator code, so some patterns may not work (as they lose the Python subclass in an unrecoverable way)
* `__torch_function__` must be explicitly disabled with `_disabled_torch_function_impl` otherwise things don't work quite correctly (in particular, what is being disabled is default subclass preservation behavior.)
* We don't ever populate kwargs, even when an argument is kwarg-only

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Differential Revision:
D29017912
D29017912

Test Plan: Imported from OSS

Reviewed By: bdhirsh

Pulled By: ezyang

fbshipit-source-id: a67714d9e541d09203a8cfc85345b8967db86238
2021-06-25 11:50:32 -07:00