mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 12:21:27 +01:00)
reland "python functionalization: add helpers, functionalize_sync and mirror_autograd_meta (#107917)" (#109518)
Reland: the previous PR was reverted internally with this error:
```
File "/data/sandcastle/boxes/eden-trunk-hg-fbcode-fbsource/buck-out/v2/gen/fbcode/363cd7e240f5d021/caffe2/torch/fb/trainer/data_modules/tests/__test_dataloader__/test_dataloader#link-tree/torch/__init__.py", line 29, in <module>
from ._utils_internal import _functionalize_sync as _sync
ImportError: cannot import name '_functionalize_sync' from 'torch._utils_internal'
```
I couldn't figure out why the internal build was unhappy with the import. One potential reason is that there is a build rule for *another* `_utils_internal.py` in the fb folder ([link](https://www.internalfb.com/code/fbsource/[30ed85cd88409af98b7490be137aaa5dfd7afd01]/fbcode/caffe2/TARGETS?lines=444)).
Rather than burn more time investigating, I confirmed internally that the error goes away if I move the util from `torch/_utils_internal.py` to `torch/_utils.py`.
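
In short, the reland splits the old `mirror_autograd_meta=True` keyword out of `_to_functional_tensor` and renames the C++ `_sync` binding, with `torch._sync` now re-exported from `torch/_utils.py`. A condensed before/after sketch assembled from the diff below (the tensor `t` is illustrative, not from the PR):

```python
import torch

t = torch.randn(3)

# Before (reverted version):
#   x = torch._to_functional_tensor(t, mirror_autograd_meta=True)

# After this PR: wrapping and metadata mirroring are two explicit steps,
x = torch._to_functional_tensor(t)
torch._mirror_autograd_meta_to(t, x)
# and torch._sync is now the python helper _functionalize_sync from torch/_utils.py.
torch._sync(x)
```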
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109518
Approved by: https://github.com/albanD
This commit is contained in:
parent 677a1010e6
commit 25e81f19f3
```diff
@@ -10,7 +10,6 @@ import torch._functorch.config
 import torch.utils._pytree as pytree
 import torch.utils.checkpoint
 from torch._dynamo.testing import normalize_gm
-from torch._functorch.aot_autograd import to_fun
 from torch._higher_order_ops.wrap import wrap
 
 from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv
```
```diff
@@ -240,6 +239,12 @@ class GraphModule(torch.nn.Module):
         self.assertEqual(actual, expected)
         self.assertTrue(torch._is_functional_tensor(backend.example_inputs[1][0]))
 
+    # Cannot re-use the version from AOTAutograd, since that uses python functional tensors.
+    def to_fun(x):
+        x_functional = torch._to_functional_tensor(x)
+        torch._mirror_autograd_meta_to(x, x_functional)
+        return x_functional
+
     def aot_f_wrapper(func):
         @functools.wraps(func)
         def wrapper(*args, **kwargs):
```
```diff
@@ -322,7 +327,8 @@ class GraphModule(torch.nn.Module):
         check_count_and_graph(2, 2, 2, expected_graph)
 
         try:
-            x = torch._to_functional_tensor(t_clone2, mirror_autograd_meta=True)
+            x = torch._to_functional_tensor(t_clone2)
+            torch._mirror_autograd_meta_to(t_clone2, x)
             torch._enable_functionalization(reapply_views=False)
             aot_f_out = f(x)
         finally:
```
```diff
@@ -24,6 +24,7 @@ def _running_with_deploy():
     return sys.modules.get("torch._meta_registrations", None) is object
 
 from ._utils import _import_dotted_name, classproperty
+from ._utils import _functionalize_sync as _sync
 from ._utils_internal import get_file_path, prepare_multiprocessing_environment, \
     USE_RTLD_GLOBAL_WITH_LIBTORCH, USE_GLOBAL_DEPS
 
```
```diff
@@ -676,7 +676,9 @@ def gen_alias_from_base(aliased_base_tensor, target_meta_tensor, target_requires
 
 def to_fun(t):
     if isinstance(t, Tensor):
-        return torch._to_functional_tensor(t, mirror_autograd_meta=True)
+        out = torch._to_functional_tensor(t)
+        torch._mirror_autograd_meta_to(t, out)
+        return out
     else:
         return t
 
```
```diff
@@ -727,7 +729,8 @@ def run_functionalized_fw_and_collect_metadata(
         if isinstance(t, Tensor):
             if t in memo:
                 return memo[t]
-            r = torch._to_functional_tensor(t, mirror_autograd_meta=True)
+            r = torch._to_functional_tensor(t)
+            torch._mirror_autograd_meta_to(t, r)
             memo[t] = r
             return r
         else:
```
```diff
@@ -119,7 +119,9 @@ class FunctionalTensor(torch.Tensor):
         # - is_leaf (so that mutations on graph inputs that are not leaves are allowed by the autograd engine)
         # this is handled by FunctionalTensor.to_functional
         x_functional = torch._to_functional_tensor(x)
+        torch._mirror_autograd_meta_to(x, x_functional)
         out = FunctionalTensor(x_functional)
+        torch._mirror_autograd_meta_to(x_functional, out)
         return out
 
     def from_functional(self):
```
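Note the double mirroring in `FunctionalTensor.to_functional` above: metadata is copied once from the plain tensor onto the inner C++ `FunctionalTensorWrapper`, and again from the wrapper onto the outer python `FunctionalTensor`, so `requires_grad` and `is_leaf` agree at every layer of wrapping.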
```diff
@@ -600,9 +600,9 @@ class MetaConverter:
                     dynamic_dims=dynamic_dims,
                     constraint_dims=constraint_dims,
                 )
-                return torch._to_functional_tensor(
-                    fake_t, mirror_autograd_meta=True
-                )
+                out = torch._to_functional_tensor(fake_t)
+                torch._mirror_autograd_meta_to(fake_t, out)
+                return out
             else:
                 # torch.func.functionalize
                 reapply_views = torch._C._functionalization_reapply_views_tls()
```
```diff
@@ -4,6 +4,7 @@ import sys
 import traceback
 import warnings
 from collections import defaultdict
+from contextlib import nullcontext
 from typing import Any, DefaultDict, List, Optional
 
 import torch
```
```diff
@@ -842,6 +843,44 @@ def is_compiling():
     return False
 
 
+def _functionalize_sync(t):
+    # This code lives in python instead of C++ since conditioning on a certain python subclass
+    # is much more of a pain in C++.
+    from torch._subclasses.functional_tensor import (
+        FunctionalTensor,
+        maybe_disable_functional_mode,
+    )
+
+    ctx = (
+        maybe_disable_functional_mode
+        if isinstance(t, FunctionalTensor)
+        else nullcontext
+    )
+    if isinstance(t, FunctionalTensor):
+        # If a FunctionalTensorMode is active while syncing, we don't want it to intercept any ops that get called
+        # when we sync our inner tensor.
+        # Why?
+        # (1) If there are input mutations in the graph, then they will be re-applied during
+        #     AOTAutograd when we call _sync() from inside of our functionalization kernels.
+        # (2) _sync() causes us to regenerate our updated tensor from the updated base,
+        #     which dispatches to a bunch of view ops
+        # (3) The input to these view ops is our inner FunctionalTensorWrapper
+        #     (since the sync was called from C++), not the python FunctionalTensor
+        # (4) if a python FunctionalTensorMode is active, it will complain when it intercepts
+        #     the view op, since it will see an input that is a C++ FunctionalTensorWrapper
+        #     (aka a normal torch.Tensor) instead of a python `FunctionalTensor`).
+        maybe_functional_mode = torch._C._unset_dispatch_mode(
+            torch._C._TorchDispatchModeKey.FUNCTIONAL
+        )
+        try:
+            torch._functionalize_sync(t.elem)
+        finally:
+            if maybe_functional_mode is not None:
+                torch._C._set_dispatch_mode(maybe_functional_mode)
+    else:
+        torch._functionalize_sync(t)
+
+
 @functools.lru_cache(2)
 def _get_device_module(device_type: str):
     device_module = getattr(torch, device_type, None)
```
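As a quick sanity check of the new plumbing (a hedged sketch, not a test from this PR; it assumes a build that includes this change), the `torch._sync` alias added to `torch/__init__.py` above now routes through this helper:

```python
import torch

t = torch.ones(2)
x = torch._to_functional_tensor(t)
torch._mirror_autograd_meta_to(t, x)
torch._enable_functionalization(reapply_views=False)
try:
    x.add_(1)        # mutation is recorded by the C++ FunctionalTensorWrapper
    torch._sync(x)   # alias for _functionalize_sync; x is not a python
                     # FunctionalTensor, so this takes the plain else-branch
finally:
    torch._disable_functionalization()
print(torch._from_functional_tensor(x))  # expected: tensor([2., 2.])
```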
```diff
@@ -362,31 +362,52 @@ static PyObject* THPVariable__to_functional_tensor(
     PyObject* kwargs) {
   HANDLE_TH_ERRORS
   static PythonArgParser parser(
-      {"_to_functional_tensor(Tensor t, *, bool mirror_autograd_meta=False)"},
+      {"_to_functional_tensor(Tensor t)"},
       /*traceable=*/true);
 
   ParsedArgs<2> parsed_args;
   auto r = parser.parse(args, kwargs, parsed_args);
   auto self_ = r.tensor(0);
-  auto mirror_autograd_meta = r.toBool(1);
   auto wrapped = at::functionalization::impl::to_functional_tensor(self_);
-  if (mirror_autograd_meta) {
-    // Here, we unsafely set the grad function on the wrapper to be the same as
-    // the inner. We expect this grad_fn to NEVER be used. It's needed so that
-    // .is_leaf metadata is accurate on the wrapper
-    auto inner_autograd_meta = impl::get_autograd_meta(self_);
-    if (inner_autograd_meta) {
-      wrapped.set_requires_grad(self_.requires_grad());
-      if (wrapped.requires_grad()) {
-        auto new_grad_fn = std::shared_ptr<torch::autograd::Error>(
-            new torch::autograd::Error(
-                "Cannot backprop through mirrored meta, file a bug in PyTorch"),
-            torch::autograd::deleteNode);
-        torch::autograd::set_history(wrapped, new_grad_fn);
-      }
-    }
-  }
   return wrap(std::move(wrapped));
   END_HANDLE_TH_ERRORS
 }
 
+// Given source and dest tensors,
+// Sets **some** (but not all) autograd metadata on dest, according to source:
+// - requires_grad
+// - grad_fn
+//   (If src has a grad_fn, we install an error grad_fn on dest to avoid
+//   difficult bugs.
+//   The main purpose is to ensure that dst.is_leaf == src.is_leaf)
+static PyObject* THPVariable__mirror_autograd_meta_to(
+    PyObject* self,
+    PyObject* args,
+    PyObject* kwargs) {
+  HANDLE_TH_ERRORS
+  static PythonArgParser parser(
+      {"_mirror_autograd_meta_to(Tensor source, Tensor dest)"},
+      /*traceable=*/true);
+
+  ParsedArgs<2> parsed_args;
+  auto r = parser.parse(args, kwargs, parsed_args);
+  auto src_ = r.tensor(0);
+  auto dst_ = r.tensor(1);
+  // Here, we unsafely set the grad function on the wrapper to be the same as
+  // the inner. We expect this grad_fn to NEVER be used. It's needed so that
+  // .is_leaf metadata is accurate on the wrapper
+  auto inner_autograd_meta = impl::get_autograd_meta(src_);
+  if (inner_autograd_meta) {
+    dst_.set_requires_grad(src_.requires_grad());
+    if (dst_.requires_grad()) {
+      auto new_grad_fn = std::shared_ptr<torch::autograd::Error>(
+          new torch::autograd::Error(
+              "Cannot backprop through mirrored meta, file a bug in PyTorch"),
+          torch::autograd::deleteNode);
+      torch::autograd::set_history(dst_, new_grad_fn);
+    }
+  }
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
 
```
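To make the comment block above concrete, here is a hedged sketch of what `_mirror_autograd_meta_to` does and does not copy (illustrative tensors, not from the PR):

```python
import torch

base = torch.randn(2, requires_grad=True)
src = base * 2                          # non-leaf: has a real grad_fn
dst = torch._to_functional_tensor(src)

torch._mirror_autograd_meta_to(src, dst)
assert dst.requires_grad == src.requires_grad  # requires_grad is copied
assert dst.is_leaf == src.is_leaf              # both False, thanks to the
                                               # Error grad_fn installed on dst
# Backprop through dst is intentionally impossible; the installed node raises:
# "Cannot backprop through mirrored meta, file a bug in PyTorch"
```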
```diff
@@ -526,12 +547,13 @@ static PyObject* THPVariable__disable_functionalization(
   END_HANDLE_TH_ERRORS
 }
 
-static PyObject* THPVariable__sync(
+static PyObject* THPVariable__functionalize_sync(
     PyObject* self,
     PyObject* args,
     PyObject* kwargs) {
   HANDLE_TH_ERRORS
-  static PythonArgParser parser({"_sync(Tensor t)"}, /*traceable=*/true);
+  static PythonArgParser parser(
+      {"_functionalize_sync(Tensor t)"}, /*traceable=*/true);
 
   ParsedArgs<1> parsed_args;
   auto r = parser.parse(args, kwargs, parsed_args);
```
```diff
@@ -568,6 +590,10 @@ static PyMethodDef torch_functions_manual[] = {
      castPyCFunctionWithKeywords(THPVariable__to_functional_tensor),
      METH_VARARGS | METH_KEYWORDS | METH_STATIC,
      nullptr},
+    {"_mirror_autograd_meta_to",
+     castPyCFunctionWithKeywords(THPVariable__mirror_autograd_meta_to),
+     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
+     nullptr},
     {"_from_functional_tensor",
      castPyCFunctionWithKeywords(THPVariable__from_functional_tensor),
      METH_VARARGS | METH_KEYWORDS | METH_STATIC,
```
```diff
@@ -576,8 +602,8 @@ static PyMethodDef torch_functions_manual[] = {
      castPyCFunctionWithKeywords(THPVariable__freeze_functional_tensor),
      METH_VARARGS | METH_KEYWORDS | METH_STATIC,
      nullptr},
-    {"_sync",
-     castPyCFunctionWithKeywords(THPVariable__sync),
+    {"_functionalize_sync",
+     castPyCFunctionWithKeywords(THPVariable__functionalize_sync),
      METH_VARARGS | METH_KEYWORDS | METH_STATIC,
      nullptr},
     {"_enable_functionalization",
```