mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
reland "python functionalization: add helpers, functionalize_sync and mirror_autograd_meta (#107917)" (#109518)
Reland - the previous PR was reverted by internal with this error:
```
File "/data/sandcastle/boxes/eden-trunk-hg-fbcode-fbsource/buck-out/v2/gen/fbcode/363cd7e240f5d021/caffe2/torch/fb/trainer/data_modules/tests/__test_dataloader__/test_dataloader#link-tree/torch/__init__.py", line 29, in <module>
from ._utils_internal import _functionalize_sync as _sync
ImportError: cannot import name '_functionalize_sync' from 'torch._utils_internal'
```
I couldn't figure out why internal was unhappy with the import. One potential reason is that I see a build rule for *another* `_utils_internal.py` in the fb folder here ([link](https://www.internalfb.com/code/fbsource/[30ed85cd88409af98b7490be137aaa5dfd7afd01]/fbcode/caffe2/TARGETS?lines=444))
Rather than burn more time investigating, I confirmed internally that the error goes away if I move the util from `torch/_utils_internal.py` to `torch/_utils.py`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109518
Approved by: https://github.com/albanD
This commit is contained in:
parent
677a1010e6
commit
25e81f19f3
|
|
@ -10,7 +10,6 @@ import torch._functorch.config
|
|||
import torch.utils._pytree as pytree
|
||||
import torch.utils.checkpoint
|
||||
from torch._dynamo.testing import normalize_gm
|
||||
from torch._functorch.aot_autograd import to_fun
|
||||
from torch._higher_order_ops.wrap import wrap
|
||||
|
||||
from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv
|
||||
|
|
@ -240,6 +239,12 @@ class GraphModule(torch.nn.Module):
|
|||
self.assertEqual(actual, expected)
|
||||
self.assertTrue(torch._is_functional_tensor(backend.example_inputs[1][0]))
|
||||
|
||||
# Cannot re-use the version from AOTAutograd, since that uses python functional tensors.
|
||||
def to_fun(x):
|
||||
x_functional = torch._to_functional_tensor(x)
|
||||
torch._mirror_autograd_meta_to(x, x_functional)
|
||||
return x_functional
|
||||
|
||||
def aot_f_wrapper(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
|
|
@ -322,7 +327,8 @@ class GraphModule(torch.nn.Module):
|
|||
check_count_and_graph(2, 2, 2, expected_graph)
|
||||
|
||||
try:
|
||||
x = torch._to_functional_tensor(t_clone2, mirror_autograd_meta=True)
|
||||
x = torch._to_functional_tensor(t_clone2)
|
||||
torch._mirror_autograd_meta_to(t_clone2, x)
|
||||
torch._enable_functionalization(reapply_views=False)
|
||||
aot_f_out = f(x)
|
||||
finally:
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ def _running_with_deploy():
|
|||
return sys.modules.get("torch._meta_registrations", None) is object
|
||||
|
||||
from ._utils import _import_dotted_name, classproperty
|
||||
from ._utils import _functionalize_sync as _sync
|
||||
from ._utils_internal import get_file_path, prepare_multiprocessing_environment, \
|
||||
USE_RTLD_GLOBAL_WITH_LIBTORCH, USE_GLOBAL_DEPS
|
||||
|
||||
|
|
|
|||
|
|
@ -676,7 +676,9 @@ def gen_alias_from_base(aliased_base_tensor, target_meta_tensor, target_requires
|
|||
|
||||
def to_fun(t):
|
||||
if isinstance(t, Tensor):
|
||||
return torch._to_functional_tensor(t, mirror_autograd_meta=True)
|
||||
out = torch._to_functional_tensor(t)
|
||||
torch._mirror_autograd_meta_to(t, out)
|
||||
return out
|
||||
else:
|
||||
return t
|
||||
|
||||
|
|
@ -727,7 +729,8 @@ def run_functionalized_fw_and_collect_metadata(
|
|||
if isinstance(t, Tensor):
|
||||
if t in memo:
|
||||
return memo[t]
|
||||
r = torch._to_functional_tensor(t, mirror_autograd_meta=True)
|
||||
r = torch._to_functional_tensor(t)
|
||||
torch._mirror_autograd_meta_to(t, r)
|
||||
memo[t] = r
|
||||
return r
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -119,7 +119,9 @@ class FunctionalTensor(torch.Tensor):
|
|||
# - is_leaf (so that mutations on graph inputs that are not leaves are allowed by the autograd engine)
|
||||
# this is handled by FunctionalTensor.to_functional
|
||||
x_functional = torch._to_functional_tensor(x)
|
||||
torch._mirror_autograd_meta_to(x, x_functional)
|
||||
out = FunctionalTensor(x_functional)
|
||||
torch._mirror_autograd_meta_to(x_functional, out)
|
||||
return out
|
||||
|
||||
def from_functional(self):
|
||||
|
|
|
|||
|
|
@ -600,9 +600,9 @@ class MetaConverter:
|
|||
dynamic_dims=dynamic_dims,
|
||||
constraint_dims=constraint_dims,
|
||||
)
|
||||
return torch._to_functional_tensor(
|
||||
fake_t, mirror_autograd_meta=True
|
||||
)
|
||||
out = torch._to_functional_tensor(fake_t)
|
||||
torch._mirror_autograd_meta_to(fake_t, out)
|
||||
return out
|
||||
else:
|
||||
# torch.func.functionalize
|
||||
reapply_views = torch._C._functionalization_reapply_views_tls()
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import sys
|
|||
import traceback
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from contextlib import nullcontext
|
||||
from typing import Any, DefaultDict, List, Optional
|
||||
|
||||
import torch
|
||||
|
|
@ -842,6 +843,44 @@ def is_compiling():
|
|||
return False
|
||||
|
||||
|
||||
def _functionalize_sync(t):
|
||||
# This code lives in python instead of C++ since conditioning on a certain python subclass
|
||||
# is much more of a pain in C++.
|
||||
from torch._subclasses.functional_tensor import (
|
||||
FunctionalTensor,
|
||||
maybe_disable_functional_mode,
|
||||
)
|
||||
|
||||
ctx = (
|
||||
maybe_disable_functional_mode
|
||||
if isinstance(t, FunctionalTensor)
|
||||
else nullcontext
|
||||
)
|
||||
if isinstance(t, FunctionalTensor):
|
||||
# If a FunctionalTensorMode is active while syncing, we don't want it to intercept any ops that get called
|
||||
# when we sync our inner tensor.
|
||||
# Why?
|
||||
# (1) If there are input mutations in the graph, then they will be re-applied during
|
||||
# AOTAutograd when we call _sync() from inside of our functionalization kernels.
|
||||
# (2) _sync() causes us to regenerate our updated the tensor from the updated base,
|
||||
# which dispatches to a bunch of view ops
|
||||
# (3) The input to these view ops is our inner FunctionalTensorWrapper
|
||||
# (since the sync was called from C++), not the python FunctionalTensor
|
||||
# (4) if a python FunctionalTensorMode is active, it will complain when it intercepts
|
||||
# the view op, since it will see an input that is a C++ FunctionalTensorWrapper
|
||||
# (aka a normal torch.Tensor) instead of a python `FunctionalTensor).
|
||||
maybe_functional_mode = torch._C._unset_dispatch_mode(
|
||||
torch._C._TorchDispatchModeKey.FUNCTIONAL
|
||||
)
|
||||
try:
|
||||
torch._functionalize_sync(t.elem)
|
||||
finally:
|
||||
if maybe_functional_mode is not None:
|
||||
torch._C._set_dispatch_mode(maybe_functional_mode)
|
||||
else:
|
||||
torch._functionalize_sync(t)
|
||||
|
||||
|
||||
@functools.lru_cache(2)
|
||||
def _get_device_module(device_type: str):
|
||||
device_module = getattr(torch, device_type, None)
|
||||
|
|
|
|||
|
|
@ -362,31 +362,52 @@ static PyObject* THPVariable__to_functional_tensor(
|
|||
PyObject* kwargs) {
|
||||
HANDLE_TH_ERRORS
|
||||
static PythonArgParser parser(
|
||||
{"_to_functional_tensor(Tensor t, *, bool mirror_autograd_meta=False)"},
|
||||
{"_to_functional_tensor(Tensor t)"},
|
||||
/*traceable=*/true);
|
||||
|
||||
ParsedArgs<2> parsed_args;
|
||||
auto r = parser.parse(args, kwargs, parsed_args);
|
||||
auto self_ = r.tensor(0);
|
||||
auto mirror_autograd_meta = r.toBool(1);
|
||||
auto wrapped = at::functionalization::impl::to_functional_tensor(self_);
|
||||
if (mirror_autograd_meta) {
|
||||
return wrap(std::move(wrapped));
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
// Given source and dest tensors,
|
||||
// Sets **some** (but not all) autograd metadata on dest, according to source:
|
||||
// - requires_grad
|
||||
// - grad_fn
|
||||
// (If src has a grad_fn, we install an error grad_fn on dest to avoid
|
||||
// difficult bugs.
|
||||
// The main purpose is to ensure that dst.is_leaf == src.is_leaf)
|
||||
static PyObject* THPVariable__mirror_autograd_meta_to(
|
||||
PyObject* self,
|
||||
PyObject* args,
|
||||
PyObject* kwargs) {
|
||||
HANDLE_TH_ERRORS
|
||||
static PythonArgParser parser(
|
||||
{"_mirror_autograd_meta_to(Tensor source, Tensor dest)"},
|
||||
/*traceable=*/true);
|
||||
|
||||
ParsedArgs<2> parsed_args;
|
||||
auto r = parser.parse(args, kwargs, parsed_args);
|
||||
auto src_ = r.tensor(0);
|
||||
auto dst_ = r.tensor(1);
|
||||
// Here, we unsafely set the grad function on the wrapper to be the same as
|
||||
// the inner. We expect this grad_fn to NEVER be used. It's needed so that
|
||||
// .is_leaf metadata is accurate on the wrapper
|
||||
auto inner_autograd_meta = impl::get_autograd_meta(self_);
|
||||
auto inner_autograd_meta = impl::get_autograd_meta(src_);
|
||||
if (inner_autograd_meta) {
|
||||
wrapped.set_requires_grad(self_.requires_grad());
|
||||
if (wrapped.requires_grad()) {
|
||||
dst_.set_requires_grad(src_.requires_grad());
|
||||
if (dst_.requires_grad()) {
|
||||
auto new_grad_fn = std::shared_ptr<torch::autograd::Error>(
|
||||
new torch::autograd::Error(
|
||||
"Cannot backprop through mirrored meta, file a bug in PyTorch"),
|
||||
torch::autograd::deleteNode);
|
||||
torch::autograd::set_history(wrapped, new_grad_fn);
|
||||
torch::autograd::set_history(dst_, new_grad_fn);
|
||||
}
|
||||
}
|
||||
}
|
||||
return wrap(std::move(wrapped));
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
|
|
@ -526,12 +547,13 @@ static PyObject* THPVariable__disable_functionalization(
|
|||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static PyObject* THPVariable__sync(
|
||||
static PyObject* THPVariable__functionalize_sync(
|
||||
PyObject* self,
|
||||
PyObject* args,
|
||||
PyObject* kwargs) {
|
||||
HANDLE_TH_ERRORS
|
||||
static PythonArgParser parser({"_sync(Tensor t)"}, /*traceable=*/true);
|
||||
static PythonArgParser parser(
|
||||
{"_functionalize_sync(Tensor t)"}, /*traceable=*/true);
|
||||
|
||||
ParsedArgs<1> parsed_args;
|
||||
auto r = parser.parse(args, kwargs, parsed_args);
|
||||
|
|
@ -568,6 +590,10 @@ static PyMethodDef torch_functions_manual[] = {
|
|||
castPyCFunctionWithKeywords(THPVariable__to_functional_tensor),
|
||||
METH_VARARGS | METH_KEYWORDS | METH_STATIC,
|
||||
nullptr},
|
||||
{"_mirror_autograd_meta_to",
|
||||
castPyCFunctionWithKeywords(THPVariable__mirror_autograd_meta_to),
|
||||
METH_VARARGS | METH_KEYWORDS | METH_STATIC,
|
||||
nullptr},
|
||||
{"_from_functional_tensor",
|
||||
castPyCFunctionWithKeywords(THPVariable__from_functional_tensor),
|
||||
METH_VARARGS | METH_KEYWORDS | METH_STATIC,
|
||||
|
|
@ -576,8 +602,8 @@ static PyMethodDef torch_functions_manual[] = {
|
|||
castPyCFunctionWithKeywords(THPVariable__freeze_functional_tensor),
|
||||
METH_VARARGS | METH_KEYWORDS | METH_STATIC,
|
||||
nullptr},
|
||||
{"_sync",
|
||||
castPyCFunctionWithKeywords(THPVariable__sync),
|
||||
{"_functionalize_sync",
|
||||
castPyCFunctionWithKeywords(THPVariable__functionalize_sync),
|
||||
METH_VARARGS | METH_KEYWORDS | METH_STATIC,
|
||||
nullptr},
|
||||
{"_enable_functionalization",
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user