mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Revert D24269034: [fx] Refactor Tracer so that find_module and root args creation could be overridden by implementations
Test Plan: revert-hammer
Differential Revision:
D24269034 (7b2e8bec85)
Original commit changeset: d7b67f2349dd
fbshipit-source-id: 7dd709b585f82d52d9b9973508137e36d5b5871e
This commit is contained in:
parent
cda88e8e4b
commit
8f12c0e786
|
|
@ -10,6 +10,12 @@ from .proxy import TracerBase
|
|||
|
||||
HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS
|
||||
|
||||
def _find_module(root: torch.nn.Module, m: torch.nn.Module):
|
||||
for n, p in root.named_modules():
|
||||
if m is p:
|
||||
return n
|
||||
raise NameError('module is not installed as a submodule')
|
||||
|
||||
def _patch_function(fn: FunctionType, nargs: int) -> FunctionType:
|
||||
co = fn.__code__
|
||||
co_flags = co.co_flags & ~HAS_VARSTUFF
|
||||
|
|
@ -113,32 +119,34 @@ class Tracer(TracerBase):
|
|||
"""
|
||||
return m.__module__.startswith('torch.nn') and not isinstance(m, torch.nn.Sequential)
|
||||
|
||||
def path_of_module(self, mod):
|
||||
for n, p in self.root.named_modules():
|
||||
if mod is p:
|
||||
return n
|
||||
raise NameError('module is not installed as a submodule')
|
||||
|
||||
def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args, kwargs):
|
||||
module_qualified_name = self.path_of_module(m)
|
||||
def call_module(self, m: torch.nn.Module, module_qualified_name: str, forward: Callable[..., Any], args, kwargs):
|
||||
if not self.is_leaf_module(m, module_qualified_name):
|
||||
return forward(*args, **kwargs)
|
||||
return self.create_proxy('call_module', module_qualified_name, args, kwargs)
|
||||
|
||||
def create_args_for_root(self, root_fn, is_module):
|
||||
co = root_fn.__code__
|
||||
def trace(self, root: Union[torch.nn.Module, Callable]) -> Graph:
|
||||
if isinstance(root, torch.nn.Module):
|
||||
self.root = root
|
||||
fn = type(root).forward
|
||||
else:
|
||||
self.root = torch.nn.Module()
|
||||
fn = root
|
||||
self.graph = Graph()
|
||||
|
||||
assert isinstance(fn, FunctionType)
|
||||
co = fn.__code__
|
||||
total_args = co.co_argcount + co.co_kwonlyargcount
|
||||
names_iter = iter(co.co_varnames)
|
||||
args : List[Any] = []
|
||||
skip_arg_idx = 0
|
||||
if is_module:
|
||||
if isinstance(root, torch.nn.Module):
|
||||
skip_arg_idx = 1
|
||||
next(names_iter) # skip self
|
||||
args.append(self.root)
|
||||
args.append(root)
|
||||
|
||||
def proxy_placeholder(name: str):
|
||||
return self.create_proxy('placeholder', name, (), {},
|
||||
type_expr=root_fn.__annotations__.get(name, None))
|
||||
type_expr=fn.__annotations__.get(name, None))
|
||||
|
||||
args.extend(proxy_placeholder(next(names_iter)) for _ in range(skip_arg_idx, total_args))
|
||||
|
||||
|
|
@ -148,31 +156,17 @@ class Tracer(TracerBase):
|
|||
args.append(proxy_placeholder('*' + next(names_iter)))
|
||||
if co.co_flags & inspect.CO_VARKEYWORDS:
|
||||
args.append(proxy_placeholder('**' + next(names_iter)))
|
||||
root_fn = _patch_function(root_fn, len(args))
|
||||
|
||||
return root_fn, args
|
||||
|
||||
def trace(self, root: Union[torch.nn.Module, Callable]) -> Graph:
|
||||
is_module = isinstance(root, torch.nn.Module)
|
||||
if is_module:
|
||||
self.root = root
|
||||
fn = type(root).forward
|
||||
else:
|
||||
self.root = torch.nn.Module()
|
||||
fn = root
|
||||
self.graph = Graph()
|
||||
|
||||
assert isinstance(fn, FunctionType)
|
||||
|
||||
fn, args = self.create_args_for_root(fn, is_module)
|
||||
fn = _patch_function(fn, len(args))
|
||||
|
||||
orig_call = torch.nn.Module.__call__
|
||||
|
||||
def module_call_wrapper(mod, *args, **kwargs):
|
||||
module_qualified_name = _find_module(self.root, mod)
|
||||
|
||||
def forward(*args, **kwargs):
|
||||
return orig_call(mod, *args, **kwargs)
|
||||
|
||||
return self.call_module(mod, forward, args, kwargs)
|
||||
return self.call_module(mod, module_qualified_name, forward, args, kwargs)
|
||||
|
||||
try:
|
||||
torch.nn.Module.__call__ = module_call_wrapper
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user