pytorch/torch/_export/wrappers.py
Tugsbayasgalan Manlaibaatar 81f98f1082 Experimental non-strict mode (#114658)
This is a proof-of-concept implementation of how a `mark_strict` marker can be used to enable torchdynamo while exporting under non-strict mode. The main idea is that `mark_strict` turns into an HOO, which then uses dynamo to do correctness analysis in the same way that torch.cond works today (see the usage sketch after the list below). There are some notable limitations:
1. This API is not meant for public use yet.
2. The strict region can't work with arbitrary container inputs.
3. We don't preserve `nn_module_stack` and other node metadata for the strict region.
4. The strict_mode HOO will show up in the final graph. This is undesirable in the long term, but it is good enough for short-term experiments; it will be fixed in a follow-up PR.
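
Roughly, the intended usage looks like the minimal sketch below. This is not the final API: the module names are made up and the `strict=False` export call is assumed; only the `_mark_strict_DO_NOT_USE` decorator defined in this file is real.

```python
import torch
from torch._export.wrappers import _mark_strict_DO_NOT_USE


# Hypothetical submodule whose body should be analyzed by dynamo.
# The decorator replaces __call__, so every call to an instance is
# routed through the strict_mode HOO.
@_mark_strict_DO_NOT_USE
class StrictBlock(torch.nn.Module):
    def forward(self, x):
        return x.sin() + 1


class Outer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.block = StrictBlock()

    def forward(self, x):
        # Only self.block is traced under strict_mode; the rest of the
        # model goes through non-strict export.
        return self.block(x).cos()


# Assumed entry point: non-strict export of the whole model.
ep = torch.export.export(Outer(), (torch.randn(3),), strict=False)
print(ep)  # the strict_mode HOO appears in the graph (limitation 4)
```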

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114658
Approved by: https://github.com/ydwu4
2024-01-04 12:24:58 +00:00


from contextlib import contextmanager

import torch
import torch._custom_ops
from torch._C import DispatchKey
from torch._higher_order_ops.strict_mode import strict_mode
from torch._higher_order_ops.utils import autograd_not_implemented
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
    ProxyTorchDispatchMode,
    track_tensor_tree,
)
from torch.utils import _pytree as pytree


# Higher-order op that marks module-call input/output boundaries in the
# exported graph; at runtime it behaves as an identity on its arguments.
_export_tracepoint = HigherOrderOperator("_export_tracepoint")


@_export_tracepoint.py_impl(ProxyTorchDispatchMode)
def export_tracepoint_dispatch_mode(mode, *args, **kwargs):
    if not mode.enable_tracing:
        return _export_tracepoint(*args, **kwargs)
    # Record the tracepoint as a call_function node in the traced graph.
    p_args, p_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, (args, kwargs))
    proxy = mode.tracer.create_proxy(
        "call_function", _export_tracepoint, p_args, p_kwargs
    )
    return track_tensor_tree(args, proxy, constant=None, tracer=mode.tracer)


@_export_tracepoint.py_impl(FakeTensorMode)
def export_tracepoint_fake_tensor_mode(mode, *args, **kwargs):
    with mode:
        return args


@_export_tracepoint.py_functionalize_impl
def export_tracepoint_functional(ctx, *args, **kwargs):
    unwrapped_args = ctx.unwrap_tensors(args)
    unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
    with ctx.redispatch_to_next():
        out = _export_tracepoint(*unwrapped_args, **unwrapped_kwargs)
        return ctx.wrap_tensors(out)


_export_tracepoint.py_impl(DispatchKey.Autograd)(
    autograd_not_implemented(_export_tracepoint, deferred_error=True)
)


@_export_tracepoint.py_impl(DispatchKey.CPU)
def export_tracepoint_cpu(*args, **kwargs):
    return args


def _wrap_submodule(mod, path, module_call_specs):
    assert isinstance(mod, torch.nn.Module)
    assert path != ""

    # Walk the dotted path to find the target submodule.
    submodule = mod
    for name in path.split("."):
        if not hasattr(submodule, name):
            raise RuntimeError(f"Couldn't find submodule at path {path}")
        submodule = getattr(submodule, name)

    def update_module_call_signatures(path, in_spec, out_spec):
        assert path not in module_call_specs
        module_call_specs[path] = {"in_spec": in_spec, "out_spec": out_spec}

    assert "forward" not in submodule.__dict__
    wrapped_forward = submodule.forward

    def check_flattened(flat_args):
        for a in flat_args:
            if not (isinstance(a, (torch.Tensor, str, int, float, bool)) or a is None):
                raise AssertionError(
                    f"Only Tensors or scalars are supported as pytree flattened inputs, got: {a}"
                )

    def wrapper(self, *args, **kwargs):
        # Flatten the inputs, route them through the tracepoint, and call the
        # original forward; do the same for the outputs, recording the in/out
        # pytree specs for this module-call path.
        flat_args, in_spec = pytree.tree_flatten((args, kwargs))
        check_flattened(flat_args)
        flat_args = _export_tracepoint(*flat_args, kind="module_call_inputs", path=path)
        args, kwargs = pytree.tree_unflatten(flat_args, in_spec)
        res = wrapped_forward(*args, **kwargs)
        flat_res, out_spec = pytree.tree_flatten(res)
        check_flattened(flat_res)
        flat_res = _export_tracepoint(*flat_res, kind="module_call_outputs", path=path)
        update_module_call_signatures(path, in_spec, out_spec)
        return pytree.tree_unflatten(flat_res, out_spec)

    # Install the wrapper as an instance-level forward override.
    submodule.forward = wrapper.__get__(submodule, type(submodule))
    return submodule


@contextmanager
def _wrap_submodules(f, preserve_signature, module_call_signatures):
    # Temporarily wrap the submodules named in preserve_signature so their
    # call signatures are captured during export, then restore the original
    # forward methods on exit.
    tasks = []
    try:
        for path in preserve_signature:
            tasks.append(_wrap_submodule(f, path, module_call_signatures))
        yield
    finally:
        for submodule in tasks:
            del submodule.__dict__["forward"]


def _mark_strict_DO_NOT_USE(cls):
    # Class decorator: route calls to instances of cls through the
    # strict_mode higher-order op so that the body is analyzed by dynamo.
    def call(self, *args):
        return strict_mode(self, args)

    cls.__call__ = call
    return cls