Added two new utils to help with turning python functionalization on in AOTAutograd (next PR):
(1) updated `torch._sync()`. Previously, this API could only handle `torch.Tensor` instances that had a `FunctionalTensorWrapper` TensorImpl. It now needs to handle python `FunctionalTensor`'s. In theory I can probably break BC and change this API (since it's private?), but I decided not to do it in this PR stack do minimize the chance of reverts. Instead of updating that API directly (which is in C++), I just added a python shim that first tries to unwrap the python `FunctionalTensor` if there is one, then calls the existing C++ logic
(2) `mirror_autograd_meta` is now a standalone API that tries to mirror the `requires_grad` and `is_leaf` autograd metadata from one tensor to another. Previously this was hardcoded into `torch._to_functional_tensor()`. But I now need to use it in a more standalone way: later in AOTAutograd when we unwrap and re-wrap a tensor subclasses, we need to manually mirror the autograd metadata from the original to the updated version of the subclass.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107917
Approved by: https://github.com/ezyang
ghstack dependencies: #106404
This PR adds a new `FunctionalTensor` subclass, and `FunctionalTensorMode` torch dispatch mode. Together, this class/mode are a lightweight wrapper around our existing C++ functionalization logic.
This idea came from Ed - later in the stack, I want to be able to run functionalization **underneath** torch_dispatch, when performing tracing in AOTAutograd. I can't do this easily with vanilla C++ functionalization, because it has a dedicated dispatch key that always runs before TorchDispatch. However, by adding a torch_dispatch mode shim around functionalization, we can use functionalization as a torch_dispatch mode, which will make it easier to run underneath other modes later.
This PR provides the basic new classes, and some light testing.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106404
Approved by: https://github.com/ezyang