mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Fixes https://github.com/pytorch/pytorch/issues/112446. This is a doozy of a PR; there are a few important things to keep in mind here: 1) We MUST lift all tensors accessed via attrs to inputs; getattr is a no-go in the graph because it violates the aot_autograd contract. Furthermore, aot_autograd does not know how to apply in-place ops to intermediary tensors that are attributes (aka from getattr) anyway. Views from ops are fine. 2) `.grad` access handling in dynamo peeks at the underlying value, the real tensor, because re-piping FakeTensors already made with this fake_mode through builder anew is a no-go. 3) We have no proper mechanism for updating the hint / grapharg.example (the real value in (2) above) midway through a trace. Therefore, what we need to do is reconcile the difference in grad stashed on grapharg.example. The easiest way to do this is lazily, upon .grad access, by reading the new value off the right fake tensors. We can then make a tensor using that data as a hint to VariableBuilder to make the right VariableTracker. Note that the example value used here (torch.zeros) in the PR is a dummy value only used as a tracing hint; it does not leak out into real runtime code. Alternatively, we could implement accumulate_grad_ in Python... Pull Request resolved: https://github.com/pytorch/pytorch/pull/112811 Approved by: https://github.com/jansel
218 lines
7.2 KiB
Python
218 lines
7.2 KiB
Python
import contextlib
|
|
import functools
|
|
from typing import List, Optional
|
|
|
|
import torch
|
|
from torch._dynamo.external_utils import call_hook
|
|
from torch._dynamo.source import GetItemSource, LocalSource
|
|
from torch._dynamo.utils import counters, lazy_format_graph_code
|
|
from torch._logging import getArtifactLogger
|
|
from torch._prims_common import clone_preserve_strides
|
|
from torch._subclasses import FakeTensorMode
|
|
from torch.fx import GraphModule
|
|
from torch.fx.experimental.proxy_tensor import (
|
|
decompose,
|
|
disable_autocast_cache,
|
|
disable_proxy_modes_tracing,
|
|
fetch_tensor_proxy,
|
|
ProxyTorchDispatchMode,
|
|
PythonKeyTracer,
|
|
track_tensor_tree,
|
|
)
|
|
from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv
|
|
from torch.fx.proxy import Proxy
|
|
|
|
# Artifact logger for the "compiled_autograd" artifact; end_capture() emits the
# captured FX graph through it at INFO level.
compiled_autograd_log = getArtifactLogger(__name__, "compiled_autograd")
|
|
|
|
|
|
def maybe_clone(x):
    """Copy ``x`` while keeping its strides, or pass ``None`` straight through.

    Args:
        x: A tensor, or ``None``.

    Returns:
        ``None`` when ``x`` is ``None``; otherwise the result of
        ``clone_preserve_strides(x)`` (a fresh copy with the same size,
        strides and storage offset).
    """
    if x is None:
        return None
    return clone_preserve_strides(x)
|
|
|
|
|
|
class AutogradCompilerInstance:
|
|
def __init__(self, compiler_fn) -> None:
|
|
self.compiler_fn = compiler_fn
|
|
self.stack = contextlib.ExitStack()
|
|
self.close = self.stack.close
|
|
self.shape_env = ShapeEnv()
|
|
self.fake_tensor_mode = FakeTensorMode(
|
|
allow_fallback_kernels=True,
|
|
allow_non_fake_inputs=True,
|
|
shape_env=self.shape_env,
|
|
)
|
|
self.fx_tracer = PythonKeyTracer()
|
|
self.proxy_mode = ProxyTorchDispatchMode(self.fx_tracer, "symbolic")
|
|
self.hooks_proxy: Optional[Proxy] = None
|
|
|
|
def wrap_fake(self, x, source):
|
|
assert isinstance(x, torch.Tensor)
|
|
return self.fake_tensor_mode.from_tensor(x, source=source)
|
|
|
|
@staticmethod
|
|
def source(name, idx) -> GetItemSource:
|
|
return GetItemSource(LocalSource(name), idx)
|
|
|
|
def begin_capture(self, inputs: List[torch.Tensor], sizes: List[int]):
|
|
counters["compiled_autograd"]["captures"] += 1
|
|
self.fx_tracer.root = torch.nn.Module()
|
|
self.fx_tracer.graph = torch.fx.Graph(tracer_cls=PythonKeyTracer)
|
|
self.fx_tracer.tensor_attrs = {}
|
|
args_proxy = self.fx_tracer.create_proxy("placeholder", "inputs", (), {})
|
|
sizes_proxy = self.fx_tracer.create_proxy("placeholder", "sizes", (), {})
|
|
self.hooks_proxy = self.fx_tracer.create_proxy("placeholder", "hooks", (), {})
|
|
|
|
# tensor inputs to fake tensors
|
|
inputs = [
|
|
self.wrap_fake(x, self.source("inputs", idx))
|
|
for idx, x in enumerate(inputs)
|
|
]
|
|
proxies = [args_proxy[i] for i in range(len(inputs))]
|
|
self.bind_tensors_to_proxies(inputs, proxies)
|
|
|
|
# size inputs to symints
|
|
sizes = [
|
|
self.shape_env.create_unspecified_symint_and_symbol(
|
|
val,
|
|
self.source("sizes", idx),
|
|
DimDynamic.DYNAMIC,
|
|
)
|
|
for idx, val in enumerate(sizes)
|
|
]
|
|
self.bind_tensors_to_proxies(sizes, sizes_proxy)
|
|
|
|
# TODO(jansel): are all these modes needed?
|
|
self.stack.enter_context(decompose({}))
|
|
self.stack.enter_context(self.fake_tensor_mode)
|
|
self.stack.enter_context(self.proxy_mode.sym_mode)
|
|
self.stack.enter_context(self.proxy_mode)
|
|
self.stack.enter_context(disable_autocast_cache())
|
|
self.stack.enter_context(disable_proxy_modes_tracing(enable_current=True))
|
|
return inputs, sizes
|
|
|
|
def proxy_call_hook(self, hook, *args):
|
|
return self.fx_tracer.create_proxy(
|
|
"call_function",
|
|
call_hook,
|
|
(
|
|
hook,
|
|
*[self.to_proxy(x) for x in args],
|
|
),
|
|
{},
|
|
)
|
|
|
|
def tensor_pre_hook(self, inputs, hook_id, i: int):
|
|
assert self.hooks_proxy is not None
|
|
hook = self.hooks_proxy[hook_id]
|
|
proxy = self.proxy_call_hook(
|
|
hook,
|
|
inputs[i],
|
|
)
|
|
with disable_proxy_modes_tracing():
|
|
inputs[i] = maybe_clone(inputs[i])
|
|
self.bind_tensors_to_proxies([inputs[i]], [proxy])
|
|
return inputs
|
|
|
|
def pre_hook(self, inputs, hook_id):
|
|
assert self.hooks_proxy is not None
|
|
hook = self.hooks_proxy[hook_id]
|
|
proxies = self.proxy_call_hook(
|
|
hook,
|
|
inputs,
|
|
)
|
|
with disable_proxy_modes_tracing():
|
|
inputs = [maybe_clone(x) for x in inputs]
|
|
self.bind_tensors_to_proxies(inputs, proxies)
|
|
return inputs
|
|
|
|
def post_hook(self, outputs, inputs, hook_id):
|
|
assert self.hooks_proxy is not None
|
|
hook = self.hooks_proxy[hook_id]
|
|
proxies = self.proxy_call_hook(
|
|
hook,
|
|
outputs,
|
|
inputs,
|
|
)
|
|
with disable_proxy_modes_tracing():
|
|
outputs = [maybe_clone(x) for x in outputs]
|
|
self.bind_tensors_to_proxies(outputs, proxies)
|
|
return outputs
|
|
|
|
def post_acc_grad_hook(self, input, hook_id):
|
|
assert isinstance(input, torch.Tensor)
|
|
assert self.hooks_proxy is not None
|
|
hook = self.hooks_proxy[hook_id]
|
|
proxies = self.proxy_call_hook(
|
|
hook,
|
|
input,
|
|
)
|
|
with disable_proxy_modes_tracing():
|
|
input = [maybe_clone(input)]
|
|
self.bind_tensors_to_proxies(input, proxies)
|
|
return input
|
|
|
|
def end_capture(self, outputs):
|
|
self.stack.close()
|
|
self.fx_tracer.create_node(
|
|
"output",
|
|
"output",
|
|
(self.fx_tracer.create_arg(self.to_proxy(outputs)),),
|
|
{},
|
|
)
|
|
graph = GraphModule(
|
|
self.fx_tracer.root, self.fx_tracer.graph, "CompiledAutograd"
|
|
)
|
|
compiled_autograd_log.info(
|
|
"%s", lazy_format_graph_code("Compiled autograd graph", graph)
|
|
)
|
|
return self.compiler_fn(graph)
|
|
|
|
def to_proxy(self, t):
|
|
if t is None:
|
|
return None
|
|
if isinstance(t, list):
|
|
return [self.to_proxy(x) for x in t]
|
|
if isinstance(t, tuple):
|
|
return tuple(self.to_proxy(x) for x in t)
|
|
assert isinstance(t, (torch.Tensor, torch.SymInt))
|
|
return fetch_tensor_proxy(self.fx_tracer)(t).proxy
|
|
|
|
def bind_tensors_to_proxies(self, tensors, proxies):
|
|
if isinstance(proxies, torch.fx.Proxy):
|
|
proxies = [proxies[i] for i in range(len(tensors))]
|
|
assert len(tensors) == len(proxies)
|
|
track_tensor_tree(tensors, proxies, constant=None, tracer=self.fx_tracer)
|
|
|
|
|
|
# Module-level flag tracking whether a compiled-autograd compiler is currently
# installed; set/cleared by the enable() and disable() context managers.
compiled_autograd_enabled = False
|
|
|
|
|
|
@contextlib.contextmanager
def enable(compiler_fn):
    """Install compiled autograd with ``compiler_fn`` for the duration of the block.

    A factory binding ``compiler_fn`` to ``AutogradCompilerInstance`` is
    registered as the autograd compiler, and autograd multithreading is
    disabled while inside the block. On exit, the previously registered
    compiler is restored; the module flag ``compiled_autograd_enabled`` is
    cleared only if no compiler was installed before.
    """
    global compiled_autograd_enabled
    set_compiler = torch._C._dynamo.compiled_autograd.set_autograd_compiler
    factory = functools.partial(AutogradCompilerInstance, compiler_fn)
    prior = set_compiler(factory)
    compiled_autograd_enabled = True
    try:
        with torch.autograd.set_multithreading_enabled(False):
            yield
    finally:
        # Nested inside another enable(): the outer one still owns the flag.
        if not prior:
            compiled_autograd_enabled = False
        set_compiler(prior)
|
|
|
|
|
|
@contextlib.contextmanager
def disable():
    """Temporarily uninstall the compiled-autograd compiler.

    Clears ``compiled_autograd_enabled`` while inside the block. On exit, the
    previously registered compiler (if any) is re-installed, and the flag is
    set back to ``True`` only when such a compiler existed.
    """
    global compiled_autograd_enabled
    set_compiler = torch._C._dynamo.compiled_autograd.set_autograd_compiler
    prior = set_compiler(None)
    compiled_autograd_enabled = False
    try:
        yield
    finally:
        if prior:
            compiled_autograd_enabled = True
        set_compiler(prior)
|