reland "python functionalization: add helpers, functionalize_sync and mirror_autograd_meta (#107917)" (#109518)

Reland: the previous PR was reverted internally with this error:
```
  File "/data/sandcastle/boxes/eden-trunk-hg-fbcode-fbsource/buck-out/v2/gen/fbcode/363cd7e240f5d021/caffe2/torch/fb/trainer/data_modules/tests/__test_dataloader__/test_dataloader#link-tree/torch/__init__.py", line 29, in <module>
    from ._utils_internal import _functionalize_sync as _sync
ImportError: cannot import name '_functionalize_sync' from 'torch._utils_internal'
```

I couldn't figure out why internal was unhappy with the import. One potential reason is that I see a build rule for *another* `_utils_internal.py` in the fb folder here ([link](https://www.internalfb.com/code/fbsource/[30ed85cd88409af98b7490be137aaa5dfd7afd01]/fbcode/caffe2/TARGETS?lines=444)).

Rather than burn more time investigating, I confirmed internally that the error goes away if I move the util from `torch/_utils_internal.py` to `torch/_utils.py`.
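
Concretely, the API change this stack lands looks like this (a sketch distilled from the diff below; all of these are private torch bindings):

```python
import torch

t = torch.randn(3, requires_grad=True)

# Before this stack: one binding with a keyword argument.
# x = torch._to_functional_tensor(t, mirror_autograd_meta=True)

# After: wrapping and autograd-metadata mirroring are two separate helpers.
x = torch._to_functional_tensor(t)
torch._mirror_autograd_meta_to(t, x)
```

The C++ `_sync` binding is also renamed to `_functionalize_sync`, with a python wrapper in `torch/_utils.py` that `torch/__init__.py` re-exports as `torch._sync`.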

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109518
Approved by: https://github.com/albanD
Brian Hirsh 2023-09-18 14:32:40 -07:00 committed by PyTorch MergeBot
parent 677a1010e6
commit 25e81f19f3
7 changed files with 105 additions and 28 deletions

View File

```diff
@@ -10,7 +10,6 @@ import torch._functorch.config
 import torch.utils._pytree as pytree
 import torch.utils.checkpoint
 from torch._dynamo.testing import normalize_gm
-from torch._functorch.aot_autograd import to_fun
 from torch._higher_order_ops.wrap import wrap
 from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv

@@ -240,6 +239,12 @@ class GraphModule(torch.nn.Module):
         self.assertEqual(actual, expected)
         self.assertTrue(torch._is_functional_tensor(backend.example_inputs[1][0]))

+        # Cannot re-use the version from AOTAutograd, since that uses python functional tensors.
+        def to_fun(x):
+            x_functional = torch._to_functional_tensor(x)
+            torch._mirror_autograd_meta_to(x, x_functional)
+            return x_functional
+
         def aot_f_wrapper(func):
             @functools.wraps(func)
             def wrapper(*args, **kwargs):
@@ -322,7 +327,8 @@ class GraphModule(torch.nn.Module):
         check_count_and_graph(2, 2, 2, expected_graph)

         try:
-            x = torch._to_functional_tensor(t_clone2, mirror_autograd_meta=True)
+            x = torch._to_functional_tensor(t_clone2)
+            torch._mirror_autograd_meta_to(t_clone2, x)
             torch._enable_functionalization(reapply_views=False)
             aot_f_out = f(x)
         finally:
```

View File: `torch/__init__.py`

```diff
@@ -24,6 +24,7 @@ def _running_with_deploy():
     return sys.modules.get("torch._meta_registrations", None) is object

 from ._utils import _import_dotted_name, classproperty
+from ._utils import _functionalize_sync as _sync
 from ._utils_internal import get_file_path, prepare_multiprocessing_environment, \
     USE_RTLD_GLOBAL_WITH_LIBTORCH, USE_GLOBAL_DEPS
```
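
With the import above, `torch._sync` now resolves to the python helper in `torch/_utils.py` rather than a C++ binding directly. A minimal sketch of how the renamed pieces fit together (private APIs, shown for illustration only):

```python
import torch

t = torch.randn(3)
ft = torch._to_functional_tensor(t)

# torch._sync is torch._utils._functionalize_sync, which in turn calls the
# renamed C++ binding torch._functionalize_sync (a no-op here, since ft has
# no pending updates to propagate).
torch._sync(ft)
out = torch._from_functional_tensor(ft)
```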

View File

```diff
@@ -676,7 +676,9 @@ def gen_alias_from_base(aliased_base_tensor, target_meta_tensor, target_requires
     def to_fun(t):
         if isinstance(t, Tensor):
-            return torch._to_functional_tensor(t, mirror_autograd_meta=True)
+            out = torch._to_functional_tensor(t)
+            torch._mirror_autograd_meta_to(t, out)
+            return out
         else:
             return t

@@ -727,7 +729,8 @@ def run_functionalized_fw_and_collect_metadata(
         if isinstance(t, Tensor):
             if t in memo:
                 return memo[t]
-            r = torch._to_functional_tensor(t, mirror_autograd_meta=True)
+            r = torch._to_functional_tensor(t)
+            torch._mirror_autograd_meta_to(t, r)
             memo[t] = r
             return r
         else:
```

View File

```diff
@@ -119,7 +119,9 @@ class FunctionalTensor(torch.Tensor):
         # - is_leaf (so that mutations on graph inputs that are not leaves are allowed by the autograd engine)
         # this is handled by FunctionalTensor.to_functional
         x_functional = torch._to_functional_tensor(x)
+        torch._mirror_autograd_meta_to(x, x_functional)
         out = FunctionalTensor(x_functional)
+        torch._mirror_autograd_meta_to(x_functional, out)
         return out

     def from_functional(self):
```

View File

```diff
@@ -600,9 +600,9 @@ class MetaConverter:
                 dynamic_dims=dynamic_dims,
                 constraint_dims=constraint_dims,
             )
-            return torch._to_functional_tensor(
-                fake_t, mirror_autograd_meta=True
-            )
+            out = torch._to_functional_tensor(fake_t)
+            torch._mirror_autograd_meta_to(fake_t, out)
+            return out
         else:
             # torch.func.functionalize
             reapply_views = torch._C._functionalization_reapply_views_tls()
```

View File: `torch/_utils.py`

```diff
@@ -4,6 +4,7 @@ import sys
 import traceback
 import warnings
 from collections import defaultdict
+from contextlib import nullcontext
 from typing import Any, DefaultDict, List, Optional

 import torch
@@ -842,6 +843,44 @@ def is_compiling():
     return False


+def _functionalize_sync(t):
+    # This code lives in python instead of C++ since conditioning on a certain python subclass
+    # is much more of a pain in C++.
+    from torch._subclasses.functional_tensor import (
+        FunctionalTensor,
+        maybe_disable_functional_mode,
+    )
+
+    ctx = (
+        maybe_disable_functional_mode
+        if isinstance(t, FunctionalTensor)
+        else nullcontext
+    )
+    if isinstance(t, FunctionalTensor):
+        # If a FunctionalTensorMode is active while syncing, we don't want it to intercept any ops that get called
+        # when we sync our inner tensor.
+        # Why?
+        # (1) If there are input mutations in the graph, then they will be re-applied during
+        #     AOTAutograd when we call _sync() from inside of our functionalization kernels.
+        # (2) _sync() causes us to regenerate the updated tensor from the updated base,
+        #     which dispatches to a bunch of view ops
+        # (3) The input to these view ops is our inner FunctionalTensorWrapper
+        #     (since the sync was called from C++), not the python FunctionalTensor
+        # (4) if a python FunctionalTensorMode is active, it will complain when it intercepts
+        #     the view op, since it will see an input that is a C++ FunctionalTensorWrapper
+        #     (aka a normal torch.Tensor) instead of a python `FunctionalTensor`.
+        maybe_functional_mode = torch._C._unset_dispatch_mode(
+            torch._C._TorchDispatchModeKey.FUNCTIONAL
+        )
+        try:
+            torch._functionalize_sync(t.elem)
+        finally:
+            if maybe_functional_mode is not None:
+                torch._C._set_dispatch_mode(maybe_functional_mode)
+    else:
+        torch._functionalize_sync(t)
+
+
 @functools.lru_cache(2)
 def _get_device_module(device_type: str):
     device_module = getattr(torch, device_type, None)
```
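
The try/finally in `_functionalize_sync` is the general pattern for temporarily hiding a python dispatch mode from the dispatcher. The same pattern as a context manager (a sketch; the helper name is hypothetical, but the `torch._C` hooks are the ones used above):

```python
import contextlib

import torch


@contextlib.contextmanager
def _hide_functional_mode():
    # Pop the python FunctionalTensorMode (if any) so that ops dispatched
    # inside the block see the inner C++ FunctionalTensorWrapper directly,
    # instead of being intercepted by the python mode.
    mode = torch._C._unset_dispatch_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL)
    try:
        yield mode
    finally:
        if mode is not None:
            torch._C._set_dispatch_mode(mode)
```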

View File

```diff
@@ -362,31 +362,52 @@ static PyObject* THPVariable__to_functional_tensor(
     PyObject* kwargs) {
   HANDLE_TH_ERRORS
   static PythonArgParser parser(
-      {"_to_functional_tensor(Tensor t, *, bool mirror_autograd_meta=False)"},
+      {"_to_functional_tensor(Tensor t)"},
       /*traceable=*/true);
   ParsedArgs<2> parsed_args;
   auto r = parser.parse(args, kwargs, parsed_args);
   auto self_ = r.tensor(0);
-  auto mirror_autograd_meta = r.toBool(1);
   auto wrapped = at::functionalization::impl::to_functional_tensor(self_);
-  if (mirror_autograd_meta) {
-    // Here, we unsafely set the grad function on the wrapper to be the same as
-    // the inner. We expect this grad_fn to NEVER be used. It's needed so that
-    // .is_leaf metadata is accurate on the wrapper
-    auto inner_autograd_meta = impl::get_autograd_meta(self_);
-    if (inner_autograd_meta) {
-      wrapped.set_requires_grad(self_.requires_grad());
-      if (wrapped.requires_grad()) {
-        auto new_grad_fn = std::shared_ptr<torch::autograd::Error>(
-            new torch::autograd::Error(
-                "Cannot backprop through mirrored meta, file a bug in PyTorch"),
-            torch::autograd::deleteNode);
-        torch::autograd::set_history(wrapped, new_grad_fn);
-      }
-    }
-  }
   return wrap(std::move(wrapped));
   END_HANDLE_TH_ERRORS
 }
+
+// Given source and dest tensors,
+// Sets **some** (but not all) autograd metadata on dest, according to source:
+// - requires_grad
+// - grad_fn
+//   (If src has a grad_fn, we install an error grad_fn on dest to avoid
+//   difficult bugs. The main purpose is to ensure that
+//   dst.is_leaf == src.is_leaf)
+static PyObject* THPVariable__mirror_autograd_meta_to(
+    PyObject* self,
+    PyObject* args,
+    PyObject* kwargs) {
+  HANDLE_TH_ERRORS
+  static PythonArgParser parser(
+      {"_mirror_autograd_meta_to(Tensor source, Tensor dest)"},
+      /*traceable=*/true);
+  ParsedArgs<2> parsed_args;
+  auto r = parser.parse(args, kwargs, parsed_args);
+  auto src_ = r.tensor(0);
+  auto dst_ = r.tensor(1);
+  // Here, we unsafely set the grad function on the wrapper to be the same as
+  // the inner. We expect this grad_fn to NEVER be used. It's needed so that
+  // .is_leaf metadata is accurate on the wrapper
+  auto inner_autograd_meta = impl::get_autograd_meta(src_);
+  if (inner_autograd_meta) {
+    dst_.set_requires_grad(src_.requires_grad());
+    if (dst_.requires_grad()) {
+      auto new_grad_fn = std::shared_ptr<torch::autograd::Error>(
+          new torch::autograd::Error(
+              "Cannot backprop through mirrored meta, file a bug in PyTorch"),
+          torch::autograd::deleteNode);
+      torch::autograd::set_history(dst_, new_grad_fn);
+    }
+  }
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
@@ -526,12 +547,13 @@ static PyObject* THPVariable__disable_functionalization(
   END_HANDLE_TH_ERRORS
 }

-static PyObject* THPVariable__sync(
+static PyObject* THPVariable__functionalize_sync(
     PyObject* self,
     PyObject* args,
     PyObject* kwargs) {
   HANDLE_TH_ERRORS
-  static PythonArgParser parser({"_sync(Tensor t)"}, /*traceable=*/true);
+  static PythonArgParser parser(
+      {"_functionalize_sync(Tensor t)"}, /*traceable=*/true);
   ParsedArgs<1> parsed_args;
   auto r = parser.parse(args, kwargs, parsed_args);
@@ -568,6 +590,10 @@ static PyMethodDef torch_functions_manual[] = {
      castPyCFunctionWithKeywords(THPVariable__to_functional_tensor),
      METH_VARARGS | METH_KEYWORDS | METH_STATIC,
      nullptr},
+    {"_mirror_autograd_meta_to",
+     castPyCFunctionWithKeywords(THPVariable__mirror_autograd_meta_to),
+     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
+     nullptr},
     {"_from_functional_tensor",
      castPyCFunctionWithKeywords(THPVariable__from_functional_tensor),
      METH_VARARGS | METH_KEYWORDS | METH_STATIC,
@@ -576,8 +602,8 @@ static PyMethodDef torch_functions_manual[] = {
      castPyCFunctionWithKeywords(THPVariable__freeze_functional_tensor),
      METH_VARARGS | METH_KEYWORDS | METH_STATIC,
      nullptr},
-    {"_sync",
-     castPyCFunctionWithKeywords(THPVariable__sync),
+    {"_functionalize_sync",
+     castPyCFunctionWithKeywords(THPVariable__functionalize_sync),
      METH_VARARGS | METH_KEYWORDS | METH_STATIC,
      nullptr},
     {"_enable_functionalization",
```