Commit Graph

3 Commits

Author SHA1 Message Date
PyTorch MergeBot
49b18ae546 Revert "python functionalization: add helpers, functionalize_sync and mirror_autograd_meta (#107917)"
This reverts commit 0ad595954a.

Reverted https://github.com/pytorch/pytorch/pull/107917 on behalf of https://github.com/clee2000 due to breaking internal builds D49346637 ([comment](https://github.com/pytorch/pytorch/pull/107917#issuecomment-1722566885))
2023-09-17 20:57:41 +00:00
Brian Hirsh
0ad595954a python functionalization: add helpers, functionalize_sync and mirror_autograd_meta (#107917)
Added two new utils to help with turning python functionalization on in AOTAutograd (next PR):

(1) updated `torch._sync()`. Previously, this API could only handle `torch.Tensor` instances that had a `FunctionalTensorWrapper` TensorImpl. It now needs to handle python `FunctionalTensor`'s. In theory I can probably break BC and change this API (since it's private?), but I decided not to do it in this PR stack do minimize the chance of reverts. Instead of updating that API directly (which is in C++), I just added a python shim that first tries to unwrap the python `FunctionalTensor` if there is one, then calls the existing C++ logic

(2) `mirror_autograd_meta` is now a standalone API that tries to mirror the `requires_grad` and `is_leaf` autograd metadata from one tensor to another. Previously this was hardcoded into `torch._to_functional_tensor()`. But I now need to use it in a more standalone way: later in AOTAutograd when we unwrap and re-wrap a tensor subclasses, we need to manually mirror the autograd metadata from the original to the updated version of the subclass.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107917
Approved by: https://github.com/ezyang
ghstack dependencies: #106404
2023-09-15 20:19:25 +00:00
Brian Hirsh
f22b303f65 Add TorchDispatch version of functionalization (#106404)
This PR adds a new `FunctionalTensor` subclass, and `FunctionalTensorMode` torch dispatch mode. Together, this class/mode are a lightweight wrapper around our existing C++ functionalization logic.

This idea came from Ed - later in the stack, I want to be able to run functionalization **underneath** torch_dispatch, when performing tracing in AOTAutograd. I can't do this easily with vanilla C++ functionalization, because it has a dedicated dispatch key that always runs before TorchDispatch. However, by adding a torch_dispatch mode shim around functionalization, we can use functionalization as a torch_dispatch mode, which will make it easier to run underneath other modes later.

This PR provides the basic new classes, and some light testing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106404
Approved by: https://github.com/ezyang
2023-09-15 20:19:25 +00:00