pytorch/torch/fx/experimental/proxy_tensor.py
Horace He 50cadfae10 Add strictness check and made tensors into leaves if input tensors were leaves (#77474)
I think this makes sense to do? Otherwise, if you call `backward()` in your traced function, you can't get gradients out of any tensors that should have been leaves.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77474
Approved by: https://github.com/ezyang
2022-05-21 01:16:39 +00:00

233 lines
8.3 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import functools
from typing import Any, Dict, Optional, Tuple, Callable, Union
import torch
from torch._C import _disabled_torch_function_impl
import torch.utils._pytree as pytree
from torch.fx import Tracer, GraphModule
import torch.fx as fx
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from contextlib import contextmanager
from torch.utils._python_dispatch import push_torch_dispatch_mode, TorchDispatchMode
# Public API of this module.
__all__ = ["ProxyTensor", "PythonKeyTracer", "dispatch_trace", "make_fx", "enable_strict"]
# Shorthand for the ATen operator namespace (e.g. `aten._local_scalar_dense.default`).
aten = torch.ops.aten
# Decomposition table in effect for the current trace: maps an OpOverload to a
# Python callable invoked in place of the op. Swapped in and restored by
# `decompose()` and looked up by `proxy_call()`.
CURRENT_DECOMPOSITION_TABLE: Dict[torch._ops.OpOverload, Callable] = {}
@contextmanager
def no_dispatch():
    """Run the ``with`` body with ``__torch_dispatch__`` handling disabled.

    A ``torch._C._DisableTorchDispatch`` guard object is created on entry and
    explicitly deleted on exit (even if the body raises); its lifetime delimits
    the region in which dispatch is disabled.
    """
    _dispatch_guard = torch._C._DisableTorchDispatch()  # type: ignore[attr-defined]
    try:
        yield
    finally:
        # Dropping the guard ends the disabled region.
        del _dispatch_guard
@contextmanager
def decompose(decomposition_table):
    """Make ``decomposition_table`` the active decomposition table.

    The table is stored in the module-level ``CURRENT_DECOMPOSITION_TABLE`` for
    the duration of the ``with`` block and yielded to the caller. The previously
    active table is put back on exit, including when the body raises.
    """
    global CURRENT_DECOMPOSITION_TABLE
    saved_table = CURRENT_DECOMPOSITION_TABLE
    CURRENT_DECOMPOSITION_TABLE = decomposition_table
    try:
        yield decomposition_table
    finally:
        CURRENT_DECOMPOSITION_TABLE = saved_table
# Strictness flag for reading Python scalars out of traced tensors. The error
# that proxy_call raises for `aten._local_scalar_dense` tells users to call
# `enable_strict(False)` to disable that check.
# NOTE(review): confirm proxy_call actually consults this flag.
IS_STRICT = True


def enable_strict(val):
    """Set the module-wide ``IS_STRICT`` flag to ``val``."""
    global IS_STRICT
    IS_STRICT = val
def wrap_output(real_out, proxy_out):
    """Pair the real result of an op with the FX proxy recorded for it.

    Values whose type is exactly ``torch.Tensor`` are wrapped into
    ``ProxyTensor``s carrying the matching proxy. Tuples and lists are handled
    one level deep by indexing ``proxy_out`` positionally and come back as a
    plain ``tuple`` / ``list``. Any other value is returned unchanged.

    Args:
        real_out: the value produced by running the op on real tensors.
        proxy_out: the proxy recorded for the same op (indexable when
            ``real_out`` is a tuple or list).

    Returns:
        ``real_out`` with exact-``torch.Tensor`` entries replaced by ``ProxyTensor``.
    """
    def _attach(value, value_proxy):
        # Exact type check on purpose: tensor subclasses (ProxyTensor included)
        # are passed through untouched.
        if type(value) is not torch.Tensor:
            return value
        with no_dispatch():
            return ProxyTensor(value, value_proxy)

    # pytree.tree_map cannot be used directly: proxy_out may be a single proxy
    # that stands for a whole tuple/list, so its elements are reached by
    # indexing it in step with the real output.
    if isinstance(real_out, (tuple, list)):
        attached = [_attach(item, proxy_out[pos]) for pos, item in enumerate(real_out)]
        return tuple(attached) if isinstance(real_out, tuple) else attached
    if isinstance(real_out, torch.Tensor):
        return _attach(real_out, proxy_out)
    return real_out
def proxy_call(func_overload, args, kwargs=None):
    """Record one ATen op call in the FX graph and run it on the real data.

    If a decomposition for ``func_overload`` is active, it is called instead.
    Otherwise the op's overload packet is called on the proxies unwrapped from
    ``args``/``kwargs`` (recording it in the graph), the op is executed with
    ``__torch_dispatch__`` disabled on the real tensors, and the real result is
    re-wrapped with the recorded proxy via ``wrap_output``.

    Args:
        func_overload: the ``OpOverload`` being dispatched.
        args: positional arguments, possibly containing ``ProxyTensor``s.
        kwargs: keyword arguments; ``None`` is treated as "no keyword arguments".

    Returns:
        The op output with plain tensors wrapped as ``ProxyTensor``s, or the
        return value of the active decomposition.

    Raises:
        RuntimeError: if strict mode (``IS_STRICT``) is on and the op is
            ``aten._local_scalar_dense`` (a value is being read out of a
            tracing tensor).
    """
    # __torch_dispatch__ callers may pass kwargs=None; `**None` would raise.
    if kwargs is None:
        kwargs = {}
    func = func_overload.overloadpacket
    if func_overload in CURRENT_DECOMPOSITION_TABLE:
        return CURRENT_DECOMPOSITION_TABLE[func_overload](*args, **kwargs)
    # Honor enable_strict(False), exactly as the error message advertises.
    if IS_STRICT and func_overload == aten._local_scalar_dense.default:
        raise RuntimeError("It appears that you're trying to get value out of a tracing tensor - erroring out! "
                           "It's likely that this is caused by data-dependent control flow or similar. "
                           "Try torch.fx.experimental.proxy_tensor.enable_strict(False) to disable this check")

    def unwrap_proxy(e):
        return e.proxy if isinstance(e, ProxyTensor) else e

    proxy_args = pytree.tree_map(unwrap_proxy, args)
    proxy_kwargs = pytree.tree_map(unwrap_proxy, kwargs)
    proxy_out = func(*proxy_args, **proxy_kwargs)

    # Kind of a hacky way to test if an op is in-place or not: names ending in
    # "_" (e.g. add_) are treated as in-place, names starting with "_" are not.
    if func.__name__[-1] == "_" and func.__name__[0] != "_":
        # The mutated first argument is now represented by this op's node.
        args[0].proxy = proxy_out
        proxy_out.node.meta['tensor_meta'] = _extract_tensor_metadata(args[0])

    with no_dispatch():
        real_out = func_overload(*args, **kwargs)

    return wrap_output(real_out, proxy_out)
class ProxyTensor(torch.Tensor):
    """``torch.Tensor`` subclass that carries the FX proxy it is traced as.

    Every ATen op involving a ``ProxyTensor`` reaches ``__torch_dispatch__`` and
    is handed to ``proxy_call``, which records the op in the graph and runs it
    on the real data.
    """

    # Graph-side handle for this tensor. Set in __new__; proxy_call reassigns it
    # when this tensor is the first argument of an in-place op.
    proxy: fx.Proxy

    @staticmethod
    def __new__(cls, elem, proxy, *, requires_grad=None):
        """Create a ProxyTensor sharing ``elem``'s data and attach ``proxy``.

        Also stores the new tensor's metadata in
        ``proxy.node.meta['tensor_meta']`` (an empty dict for sparse tensors).

        Args:
            elem: the real tensor being wrapped.
            proxy: the ``fx.Proxy`` that stands for ``elem`` in the graph.
            requires_grad: when given, the result is built with
                ``torch.Tensor._make_subclass`` using this flag. ``None`` keeps
                the default construction path for dense tensors and means
                "does not require grad" for sparse ones.
        """
        # Hack to deal with super().__new__ not working for sparse tensors
        if elem.is_sparse or requires_grad is not None:
            # _make_subclass takes a plain bool for requires_grad and rejects
            # None, which is what sparse inputs arrive with by default; fall
            # back to False (_make_subclass's own default) in that case.
            grad_flag = False if requires_grad is None else requires_grad
            r = torch.Tensor._make_subclass(cls, elem, grad_flag)
        else:
            r = super().__new__(cls, elem)  # type: ignore[call-arg]
        if elem.is_sparse:
            proxy.node.meta['tensor_meta'] = {}
        else:
            proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(r)
        r.proxy = proxy  # type: ignore[attr-defined]
        return r

    def __deepcopy__(self, memo):
        # clone() is an ATen op, so on a ProxyTensor it goes through
        # __torch_dispatch__ (and hence proxy_call).
        return self.clone()

    def __repr__(self):
        # Disable dispatch so printing the underlying data is not itself traced.
        with no_dispatch():
            return f"ProxyTensor({self.as_subclass(torch.Tensor)}, proxy={self.proxy})"  # type: ignore[arg-type]

    # Opt out of __torch_function__ so only __torch_dispatch__ intercepts ops.
    __torch_function__ = _disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None):
        # Every ATen op involving a ProxyTensor is traced and executed by proxy_call.
        return proxy_call(func_overload, args, kwargs)
class PythonKeyTracer(Tracer):
    """FX ``Tracer`` used by ``dispatch_trace``.

    Unlike the base tracer, it inlines submodule calls instead of emitting
    ``call_module`` nodes, and lowers ``nn.Parameter`` arguments to ``get_attr``
    nodes on the traced root.
    """

    def __init__(self):
        super().__init__()

    # In general, we don't want to make modules leaves. In principle, users of
    # this tracer might want to override this in order to turn a couple specific
    # modules into leaves in the traced graph.
    def call_module(
        self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...], kwargs: Dict[str, Any]
    ) -> Any:
        # Trace straight through the module's forward.
        return forward(*args, **kwargs)

    def create_arg(self, a: Any):
        """Lower ``a`` to a graph argument.

        A ``torch.nn.Parameter`` that is (by identity) already a parameter of
        ``self.root`` becomes a ``get_attr`` node for its qualified name. Any
        other Parameter is attached to ``self.root`` under the first free name
        ``_param_constant<i>`` and referenced by a ``get_attr`` node. Everything
        else is delegated to ``Tracer.create_arg``.
        """
        if not isinstance(a, torch.nn.Parameter):
            return super().create_arg(a)

        for param_name, param in self.root.named_parameters():
            if param is a:
                return self.create_node('get_attr', param_name, (), {})

        suffix = 0
        while hasattr(self.root, f'_param_constant{suffix}'):
            suffix += 1
        fresh_name = f'_param_constant{suffix}'
        setattr(self.root, fresh_name, a)
        return self.create_node('get_attr', fresh_name, (), {})
def dispatch_trace(
    root: Union[torch.nn.Module, Callable],
    concrete_args: Optional[Tuple[Any, ...]] = None,
    trace_factory_functions: bool = False,
) -> GraphModule:
    """Trace ``root`` with a fresh ``PythonKeyTracer`` and build a GraphModule.

    Args:
        root: module or function to trace.
        concrete_args: forwarded to ``Tracer.trace``.
        trace_factory_functions: if True, a ``ProxyTorchDispatchMode`` bound to
            the tracer is pushed while tracing, so ops with no ``ProxyTensor``
            inputs (e.g. factory functions) are recorded as well.

    Returns:
        A ``GraphModule`` over the tracer's root, named after ``root`` (the class
        name for modules, ``__name__`` otherwise).
    """
    tracer = PythonKeyTracer()

    def run_trace():
        return tracer.trace(root, concrete_args)

    if not trace_factory_functions:
        graph = run_trace()
    else:
        mode_factory = functools.partial(ProxyTorchDispatchMode, tracer)
        with push_torch_dispatch_mode(mode_factory):
            graph = run_trace()

    if isinstance(root, torch.nn.Module):
        name = root.__class__.__name__
    else:
        name = root.__name__
    return GraphModule(tracer.root, graph, name)
def wrap_key(f, inps):
    """Build the function that ``dispatch_trace`` actually traces.

    The returned ``wrapped`` is expected to receive one tracer placeholder per
    flattened leaf of ``inps``. Each tensor leaf of ``inps`` is paired with its
    placeholder as a ``ProxyTensor``; each non-tensor leaf is replaced by the
    concrete value from ``inps``. ``f`` is called on the rebuilt structure, and
    any ``ProxyTensor`` in its output is swapped back for its proxy so the
    tracer can record it as a graph output.

    Args:
        f: the user function to trace.
        inps: the example inputs ``f`` is traced with (any pytree).
    """
    flat_inps, _ = pytree.tree_flatten(inps)

    @functools.wraps(f)
    def wrapped(*args):
        flat_args, args_spec = pytree.tree_flatten(args)
        assert len(flat_args) == len(flat_inps), (
            f"expected {len(flat_inps)} flattened inputs, got {len(flat_args)}")
        traced_args = []
        for inp, arg in zip(flat_inps, flat_args):
            if isinstance(inp, torch.Tensor):
                # Only leaves that actually require grad become grad-requiring
                # leaves of the trace. `is_leaf` alone would also flag every
                # tensor that does not require grad (those are all leaves),
                # which raises for integer dtypes and forbids in-place ops on
                # the inputs.
                with no_dispatch():
                    traced_args.append(ProxyTensor(
                        inp, arg, requires_grad=inp.is_leaf and inp.requires_grad))
            else:
                traced_args.append(inp)
        tree_args = pytree.tree_unflatten(traced_args, args_spec)
        out = f(*tree_args)
        flat_outs, out_spec = pytree.tree_flatten(out)
        # ProxyTensor is a torch.Tensor subclass, so one isinstance check suffices.
        flat_outs = [o.proxy if isinstance(o, ProxyTensor) else o for o in flat_outs]
        return pytree.tree_unflatten(flat_outs, out_spec)

    return wrapped
class ProxyTorchDispatchMode(TorchDispatchMode):
    """Dispatch mode that records every ATen op into ``tracer``'s graph.

    Ops with a ``ProxyTensor`` anywhere in their (possibly nested) positional or
    keyword arguments are handled by ``proxy_call``. Every other op (e.g. a
    factory function) is recorded as a ``call_function`` node with its arguments
    as given, executed for real with dispatch disabled, and its result wrapped
    via ``wrap_output``.
    """

    def __init__(self, tracer):
        super().__init__()
        self.tracer = tracer

    def __torch_dispatch__(self, func_overload, types, args=(), kwargs=None):
        # Normalize so the `**kwargs` calls below never see None.
        if kwargs is None:
            kwargs = {}
        func = func_overload.overloadpacket
        # Look inside containers and kwargs too: e.g. aten.cat receives its
        # tensors in a list, and a top-level-only scan would hand those
        # ProxyTensors to create_proxy as ordinary arguments instead of tracing
        # through them.
        flat_operands, _ = pytree.tree_flatten((args, kwargs))
        if any(isinstance(operand, ProxyTensor) for operand in flat_operands):
            return proxy_call(func_overload, args, kwargs)

        proxy_out = self.tracer.create_proxy('call_function', func, args, kwargs,
                                             name=self.tracer.graph._target_to_str(func.__name__))
        with no_dispatch():
            real_out = func_overload(*args, **kwargs)
        return wrap_output(real_out, proxy_out)
def make_fx(f, decomposition_table=None, trace_factory_functions=False):
    """Return a function that traces ``f`` into a ``GraphModule`` when called.

    Calling the result with example arguments traces ``f`` on them through
    ``dispatch_trace``. The concrete args are ``args`` with every pytree leaf
    replaced by ``fx.PH``, and ``decomposition_table`` is active (via
    ``decompose``) while tracing.

    Args:
        f: the function to trace.
        decomposition_table: optional mapping from ``OpOverload`` to a Python
            decomposition; ``None`` means an empty table.
        trace_factory_functions: forwarded to ``dispatch_trace``.
    """
    active_table = {} if decomposition_table is None else decomposition_table

    @functools.wraps(f)
    def wrapped(*args):
        placeholders = tuple(pytree.tree_map(lambda _leaf: fx.PH, args))  # type: ignore[attr-defined]
        with decompose(active_table):
            return dispatch_trace(
                wrap_key(f, args),
                concrete_args=placeholders,
                trace_factory_functions=trace_factory_functions,
            )

    return wrapped