Summary: When we export in non-strict mode with preserve_module_call_signature turned on, the following assertion error occurs today:

```
child_split[: len(parent_split)] == parent_split
```

This is because we monkey-patch the forward call directly, which breaks attribute propagation in the tracer. It is better to implement this with forward hooks, since then we don't have to alter the original module structure at all during export.

Test Plan: CI

Differential Revision: D54102714

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120468
Approved by: https://github.com/ydwu4
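For reference, a minimal sketch of the flow this change affects (illustrative only: the Parent/Child modules are made up, and it assumes the torch.export.export entry point with its preserve_module_call_signature keyword):

```
import torch


class Child(torch.nn.Module):
    def forward(self, x, scale: float = 1.0):
        return x * scale


class Parent(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.child = Child()

    def forward(self, x):
        return self.child(x, scale=2.0)


# Ask export to keep the (args, kwargs) / output structure of Parent.child.
# With this patch the submodule is instrumented via forward hooks instead of
# monkey-patching Child.forward.
ep = torch.export.export(
    Parent(), (torch.randn(3),), preserve_module_call_signature=("child",)
)
unflat = torch.export.unflatten(ep)  # "child" keeps its original call signature
```

The recorded input/output tree specs end up in ep.module_call_graph, which torch.export.unflatten can use to rebuild a submodule with the original call signature.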
115 lines · 3.9 KiB · Python
from contextlib import contextmanager

import torch
import torch._custom_ops
from torch._C import DispatchKey
from torch._higher_order_ops.strict_mode import strict_mode
from torch._higher_order_ops.utils import autograd_not_implemented
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
from torch.utils import _pytree as pytree
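
# _export_tracepoint is a marker higher-order op: export inserts it around the
# flattened inputs/outputs of submodules whose call signature should be
# preserved. It has no runtime effect (the inputs are passed through
# unchanged); its only job is to appear as a node in the traced graph so the
# module-call boundaries can be identified later.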
_export_tracepoint = HigherOrderOperator("_export_tracepoint")
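
# Under ProxyTorchDispatchMode the tracepoint is recorded as a call_function
# node in the graph being traced; the original tensors are returned and
# associated with the new proxy via track_tensor_tree.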
@_export_tracepoint.py_impl(ProxyTorchDispatchMode)
def export_tracepoint_dispatch_mode(mode, *args, **kwargs):
    if not mode.enable_tracing:
        return _export_tracepoint(*args, **kwargs)
    p_args, p_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, (args, kwargs))
    proxy = mode.tracer.create_proxy(
        "call_function", _export_tracepoint, p_args, p_kwargs
    )
    return track_tensor_tree(args, proxy, constant=None, tracer=mode.tracer)
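
# Under FakeTensorMode the tracepoint is a no-op and simply returns its inputs.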
@_export_tracepoint.py_impl(FakeTensorMode)
def export_tracepoint_fake_tensor_mode(mode, *args, **kwargs):
    with mode:
        return args
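
# Functionalization: unwrap the functional tensor args/kwargs, redispatch the
# tracepoint to the next key, and wrap the result again.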
@_export_tracepoint.py_functionalize_impl
def export_tracepoint_functional(ctx, *args, **kwargs):
    unwrapped_args = ctx.unwrap_tensors(args)
    unwrapped_kwargs = ctx.unwrap_tensors(kwargs)

    with ctx.redispatch_to_next():
        out = _export_tracepoint(*unwrapped_args, **unwrapped_kwargs)
        return ctx.wrap_tensors(out)
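
# No autograd formula is defined for the tracepoint; differentiating through
# it raises a deferred "not implemented" error only if a backward pass
# actually reaches it.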
_export_tracepoint.py_impl(DispatchKey.Autograd)(
    autograd_not_implemented(_export_tracepoint, deferred_error=True)
)
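
# At runtime on real CPU tensors the tracepoint is an identity function.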
@_export_tracepoint.py_impl(DispatchKey.CPU)
def export_tracepoint_cpu(*args, **kwargs):
    return args
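
# Attach forward pre/post hooks to the submodule at `path` (e.g. "a.b.c")
# inside `mod`. The hooks flatten the call's inputs/outputs with pytree, route
# them through _export_tracepoint, and record the input/output TreeSpecs in
# `module_call_specs` keyed by the submodule path. Using hooks (rather than
# monkey-patching forward) leaves the module structure untouched during export.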
def _wrap_submodule(mod, path, module_call_specs):
    assert isinstance(mod, torch.nn.Module)
    assert path != ""
    submodule = mod
    for name in path.split("."):
        if not hasattr(submodule, name):
            raise RuntimeError(f"Couldn't find submodule at path {path}")
        submodule = getattr(submodule, name)

    def update_module_call_signatures(path, in_spec, out_spec):
        if path in module_call_specs:
            assert module_call_specs[path]["in_spec"] == in_spec
            assert module_call_specs[path]["out_spec"] == out_spec
        module_call_specs[path] = {"in_spec": in_spec, "out_spec": out_spec}

    def check_flattened(flat_args):
        for a in flat_args:
            if not (isinstance(a, (torch.Tensor, str, int, float, bool)) or a is None):
                raise AssertionError(
                    f"Only Tensors or scalars are supported as pytree flattened inputs, got: {a}"
                )
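
    # Pre-hook: flatten (args, kwargs), tag the flattened values as
    # module_call_inputs, and rebuild the original structure before the
    # submodule's forward runs.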
    def pre_hook(module, args, kwargs):
        flat_args, in_spec = pytree.tree_flatten((args, kwargs))
        check_flattened(flat_args)
        flat_args = _export_tracepoint(*flat_args, kind="module_call_inputs", path=path)
        args, kwargs = pytree.tree_unflatten(flat_args, in_spec)
        return args, kwargs
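
    # Post-hook: flatten the outputs, tag them as module_call_outputs, record
    # the input/output specs for this path, and rebuild the outputs with the
    # same structure.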
    def post_hook(module, args, kwargs, res):
        _, in_spec = pytree.tree_flatten((args, kwargs))
        flat_res, out_spec = pytree.tree_flatten(res)
        check_flattened(flat_res)
        flat_res = _export_tracepoint(*flat_res, kind="module_call_outputs", path=path)
        update_module_call_signatures(path, in_spec, out_spec)
        return pytree.tree_unflatten(flat_res, out_spec)

    pre_handle = submodule.register_forward_pre_hook(pre_hook, with_kwargs=True)
    post_handle = submodule.register_forward_hook(post_hook, with_kwargs=True)
    return pre_handle, post_handle
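
# Context manager used by export: install the signature-preserving hooks for
# every path in `preserve_signature` on the module `f`, run the traced region,
# and always remove the hooks afterwards. A rough sketch of how it is used
# (illustrative only):
#
#     module_call_signatures = {}
#     with _wrap_submodules(mod, ("child",), module_call_signatures):
#         gm = trace(mod, example_inputs)   # hypothetical tracing step
#     # module_call_signatures["child"] now holds {"in_spec", "out_spec"}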
@contextmanager
def _wrap_submodules(f, preserve_signature, module_call_signatures):
    handles = []

    try:
        for path in preserve_signature:
            handles.extend(_wrap_submodule(f, path, module_call_signatures))
        yield
    finally:
        for handle in handles:
            handle.remove()
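
# Experimental class decorator: replace __call__ so instances are invoked
# through the strict_mode higher-order op (note that keyword arguments are not
# forwarded by the replacement __call__).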
def _mark_strict_experimental(cls):
    def call(self, *args):
        return strict_mode(self, args)

    cls.__call__ = call
    return cls