reland "python functionalization: add helpers, functionalize_sync and mirror_autograd_meta (#107917)" (#109518)

Reland - the previous PR was reverted by internal with this error:
```
  File "/data/sandcastle/boxes/eden-trunk-hg-fbcode-fbsource/buck-out/v2/gen/fbcode/363cd7e240f5d021/caffe2/torch/fb/trainer/data_modules/tests/__test_dataloader__/test_dataloader#link-tree/torch/__init__.py", line 29, in <module>
    from ._utils_internal import _functionalize_sync as _sync
ImportError: cannot import name '_functionalize_sync' from 'torch._utils_internal'
```

I couldn't figure out why internal was unhappy with the import. One potential reason is that I see a build rule for *another* `_utils_internal.py` in the fb folder here ([link](https://www.internalfb.com/code/fbsource/[30ed85cd88409af98b7490be137aaa5dfd7afd01]/fbcode/caffe2/TARGETS?lines=444))

Rather than burn more time investigating, I confirmed internally that the error goes away if I move the util from `torch/_utils_internal.py` to `torch/_utils.py`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109518
Approved by: https://github.com/albanD
This commit is contained in:
Brian Hirsh 2023-09-18 14:32:40 -07:00 committed by PyTorch MergeBot
parent 677a1010e6
commit 25e81f19f3
7 changed files with 105 additions and 28 deletions

View File

@ -10,7 +10,6 @@ import torch._functorch.config
import torch.utils._pytree as pytree
import torch.utils.checkpoint
from torch._dynamo.testing import normalize_gm
from torch._functorch.aot_autograd import to_fun
from torch._higher_order_ops.wrap import wrap
from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv
@ -240,6 +239,12 @@ class GraphModule(torch.nn.Module):
self.assertEqual(actual, expected)
self.assertTrue(torch._is_functional_tensor(backend.example_inputs[1][0]))
# Cannot re-use the version from AOTAutograd, since that uses python functional tensors.
def to_fun(x):
x_functional = torch._to_functional_tensor(x)
torch._mirror_autograd_meta_to(x, x_functional)
return x_functional
def aot_f_wrapper(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
@ -322,7 +327,8 @@ class GraphModule(torch.nn.Module):
check_count_and_graph(2, 2, 2, expected_graph)
try:
x = torch._to_functional_tensor(t_clone2, mirror_autograd_meta=True)
x = torch._to_functional_tensor(t_clone2)
torch._mirror_autograd_meta_to(t_clone2, x)
torch._enable_functionalization(reapply_views=False)
aot_f_out = f(x)
finally:

View File

@ -24,6 +24,7 @@ def _running_with_deploy():
return sys.modules.get("torch._meta_registrations", None) is object
from ._utils import _import_dotted_name, classproperty
from ._utils import _functionalize_sync as _sync
from ._utils_internal import get_file_path, prepare_multiprocessing_environment, \
USE_RTLD_GLOBAL_WITH_LIBTORCH, USE_GLOBAL_DEPS

View File

@ -676,7 +676,9 @@ def gen_alias_from_base(aliased_base_tensor, target_meta_tensor, target_requires
def to_fun(t):
if isinstance(t, Tensor):
return torch._to_functional_tensor(t, mirror_autograd_meta=True)
out = torch._to_functional_tensor(t)
torch._mirror_autograd_meta_to(t, out)
return out
else:
return t
@ -727,7 +729,8 @@ def run_functionalized_fw_and_collect_metadata(
if isinstance(t, Tensor):
if t in memo:
return memo[t]
r = torch._to_functional_tensor(t, mirror_autograd_meta=True)
r = torch._to_functional_tensor(t)
torch._mirror_autograd_meta_to(t, r)
memo[t] = r
return r
else:

View File

@ -119,7 +119,9 @@ class FunctionalTensor(torch.Tensor):
# - is_leaf (so that mutations on graph inputs that are not leaves are allowed by the autograd engine)
# this is handled by FunctionalTensor.to_functional
x_functional = torch._to_functional_tensor(x)
torch._mirror_autograd_meta_to(x, x_functional)
out = FunctionalTensor(x_functional)
torch._mirror_autograd_meta_to(x_functional, out)
return out
def from_functional(self):

View File

@ -600,9 +600,9 @@ class MetaConverter:
dynamic_dims=dynamic_dims,
constraint_dims=constraint_dims,
)
return torch._to_functional_tensor(
fake_t, mirror_autograd_meta=True
)
out = torch._to_functional_tensor(fake_t)
torch._mirror_autograd_meta_to(fake_t, out)
return out
else:
# torch.func.functionalize
reapply_views = torch._C._functionalization_reapply_views_tls()

View File

@ -4,6 +4,7 @@ import sys
import traceback
import warnings
from collections import defaultdict
from contextlib import nullcontext
from typing import Any, DefaultDict, List, Optional
import torch
@ -842,6 +843,44 @@ def is_compiling():
return False
def _functionalize_sync(t):
# This code lives in python instead of C++ since conditioning on a certain python subclass
# is much more of a pain in C++.
from torch._subclasses.functional_tensor import (
FunctionalTensor,
maybe_disable_functional_mode,
)
ctx = (
maybe_disable_functional_mode
if isinstance(t, FunctionalTensor)
else nullcontext
)
if isinstance(t, FunctionalTensor):
# If a FunctionalTensorMode is active while syncing, we don't want it to intercept any ops that get called
# when we sync our inner tensor.
# Why?
# (1) If there are input mutations in the graph, then they will be re-applied during
# AOTAutograd when we call _sync() from inside of our functionalization kernels.
# (2) _sync() causes us to regenerate our updated the tensor from the updated base,
# which dispatches to a bunch of view ops
# (3) The input to these view ops is our inner FunctionalTensorWrapper
# (since the sync was called from C++), not the python FunctionalTensor
# (4) if a python FunctionalTensorMode is active, it will complain when it intercepts
# the view op, since it will see an input that is a C++ FunctionalTensorWrapper
# (aka a normal torch.Tensor) instead of a python `FunctionalTensor).
maybe_functional_mode = torch._C._unset_dispatch_mode(
torch._C._TorchDispatchModeKey.FUNCTIONAL
)
try:
torch._functionalize_sync(t.elem)
finally:
if maybe_functional_mode is not None:
torch._C._set_dispatch_mode(maybe_functional_mode)
else:
torch._functionalize_sync(t)
@functools.lru_cache(2)
def _get_device_module(device_type: str):
device_module = getattr(torch, device_type, None)

View File

@ -362,31 +362,52 @@ static PyObject* THPVariable__to_functional_tensor(
PyObject* kwargs) {
HANDLE_TH_ERRORS
static PythonArgParser parser(
{"_to_functional_tensor(Tensor t, *, bool mirror_autograd_meta=False)"},
{"_to_functional_tensor(Tensor t)"},
/*traceable=*/true);
ParsedArgs<2> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);
auto self_ = r.tensor(0);
auto mirror_autograd_meta = r.toBool(1);
auto wrapped = at::functionalization::impl::to_functional_tensor(self_);
if (mirror_autograd_meta) {
return wrap(std::move(wrapped));
END_HANDLE_TH_ERRORS
}
// Given source and dest tensors,
// Sets **some** (but not all) autograd metadata on dest, according to source:
// - requires_grad
// - grad_fn
// (If src has a grad_fn, we install an error grad_fn on dest to avoid
// difficult bugs.
// The main purpose is to ensure that dst.is_leaf == src.is_leaf)
static PyObject* THPVariable__mirror_autograd_meta_to(
PyObject* self,
PyObject* args,
PyObject* kwargs) {
HANDLE_TH_ERRORS
static PythonArgParser parser(
{"_mirror_autograd_meta_to(Tensor source, Tensor dest)"},
/*traceable=*/true);
ParsedArgs<2> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);
auto src_ = r.tensor(0);
auto dst_ = r.tensor(1);
// Here, we unsafely set the grad function on the wrapper to be the same as
// the inner. We expect this grad_fn to NEVER be used. It's needed so that
// .is_leaf metadata is accurate on the wrapper
auto inner_autograd_meta = impl::get_autograd_meta(self_);
auto inner_autograd_meta = impl::get_autograd_meta(src_);
if (inner_autograd_meta) {
wrapped.set_requires_grad(self_.requires_grad());
if (wrapped.requires_grad()) {
dst_.set_requires_grad(src_.requires_grad());
if (dst_.requires_grad()) {
auto new_grad_fn = std::shared_ptr<torch::autograd::Error>(
new torch::autograd::Error(
"Cannot backprop through mirrored meta, file a bug in PyTorch"),
torch::autograd::deleteNode);
torch::autograd::set_history(wrapped, new_grad_fn);
torch::autograd::set_history(dst_, new_grad_fn);
}
}
}
return wrap(std::move(wrapped));
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
@ -526,12 +547,13 @@ static PyObject* THPVariable__disable_functionalization(
END_HANDLE_TH_ERRORS
}
static PyObject* THPVariable__sync(
static PyObject* THPVariable__functionalize_sync(
PyObject* self,
PyObject* args,
PyObject* kwargs) {
HANDLE_TH_ERRORS
static PythonArgParser parser({"_sync(Tensor t)"}, /*traceable=*/true);
static PythonArgParser parser(
{"_functionalize_sync(Tensor t)"}, /*traceable=*/true);
ParsedArgs<1> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);
@ -568,6 +590,10 @@ static PyMethodDef torch_functions_manual[] = {
castPyCFunctionWithKeywords(THPVariable__to_functional_tensor),
METH_VARARGS | METH_KEYWORDS | METH_STATIC,
nullptr},
{"_mirror_autograd_meta_to",
castPyCFunctionWithKeywords(THPVariable__mirror_autograd_meta_to),
METH_VARARGS | METH_KEYWORDS | METH_STATIC,
nullptr},
{"_from_functional_tensor",
castPyCFunctionWithKeywords(THPVariable__from_functional_tensor),
METH_VARARGS | METH_KEYWORDS | METH_STATIC,
@ -576,8 +602,8 @@ static PyMethodDef torch_functions_manual[] = {
castPyCFunctionWithKeywords(THPVariable__freeze_functional_tensor),
METH_VARARGS | METH_KEYWORDS | METH_STATIC,
nullptr},
{"_sync",
castPyCFunctionWithKeywords(THPVariable__sync),
{"_functionalize_sync",
castPyCFunctionWithKeywords(THPVariable__functionalize_sync),
METH_VARARGS | METH_KEYWORDS | METH_STATIC,
nullptr},
{"_enable_functionalization",