[FX] Experimental type annotation pass using Python signatures (#53831)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/53831

Test Plan: Imported from OSS

Reviewed By: suo

Differential Revision: D26982804

Pulled By: jamesr66a

fbshipit-source-id: 17db9f71e729206f29ee231e34723d9616f128b7
This commit is contained in:
James Reed 2021-03-17 20:39:16 -07:00 committed by Facebook GitHub Bot
parent 255b103c1b
commit a27f46bbe3
3 changed files with 164 additions and 8 deletions

View File

@ -1,4 +1,5 @@
import torch
import operator
import unittest
import sys
from typing import Callable, Dict, Union, List
@ -23,6 +24,7 @@ from torch.fx.experimental.partitioner_utils import (
from torch.fx.experimental.fuser import fuse
from torch.fx.experimental import merge_matmul
from torch.fx.experimental.normalize import NormalizeArgs
from torch.fx.experimental.schema_type_annotation import AnnotateTypesWithSchema
from torch.testing._internal.common_nn import module_tests, new_module_tests
try:
@ -854,6 +856,47 @@ class {test_classname}(torch.nn.Module):
if submod_class == nn_class:
self.assertEqual(len(node.args), 0)
@skipIfNoTorchVision
def test_annotate_returns_with_schema(self):
m = resnet18()
traced_modules = symbolic_trace(m)
traced_modules_annotated = AnnotateTypesWithSchema(traced_modules).transform()
for node in traced_modules_annotated.graph.nodes:
if node.type is None:
check = (node.op, node.target)
self.assertTrue(check in {('placeholder', 'x'), ('call_function', operator.add),
('call_function', torch.flatten), ('output', 'output')})
# Smoke test torchscript compilation since now we're emitting type annotations
torch.jit.script(traced_modules_annotated)
class FunctionalTracer(torch.fx.Tracer):
def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
# `leaves` contains the set of standard `nn.Modules` that are not
# currently symbolically traceable. Ideally this set would be empty
leaves = set([torch.nn.BatchNorm2d])
return type(m) in leaves
traced_functionals = torch.fx.GraphModule(m, FunctionalTracer().trace(m))
traced_functionals_annotated = AnnotateTypesWithSchema(traced_functionals).transform()
for node in traced_functionals_annotated.graph.nodes:
if node.type is None:
check = (node.op, node.target)
excluded_nodes = {
('placeholder', 'x'),
('call_function', torch.conv2d),
# Return type differs based on boolean dispatch :(
('call_function', torch.nn.functional.max_pool2d),
('call_function', operator.add),
('call_function', torch.flatten),
('output', 'output'),
}
self.assertTrue(check in excluded_nodes)
# Smoke test torchscript compilation since now we're emitting type annotations
torch.jit.script(traced_functionals_annotated)
def test_subgraph_uniquename(self):
class MyModule(torch.nn.Module):

View File

@ -0,0 +1,111 @@
import torch
import torch.fx
import inspect
from typing import Any, Dict, Optional, Tuple
from torch.fx.node import Argument, Target
from torch._jit_internal import boolean_dispatched
from torch.fx.operator_schemas import _torchscript_type_to_python_type
from torch.fx import Transformer
class AnnotateTypesWithSchema(Transformer):
"""
Use Python function signatures to annotate types for `Nodes` within an FX graph.
This pulls out Python function signatures for:
1. Standard `torch.nn` Module calls
2. `torch.nn.functional` calls
3. Attribute fetches via `get_attr`
Example usage:
m = torchvision.models.resnet18()
traced = torch.fx.symbolic_trace(m)
traced = AnnotateTypesWithSchema(traced).transform()
"""
def __init__(self, module : torch.nn.Module, annotate_functionals : bool = True,
annotate_modules : bool = True, annotate_get_attrs : bool = True):
super().__init__(module)
self.annotate_functionals = annotate_functionals
self.annotate_modules = annotate_modules
self.annotate_get_attrs = annotate_get_attrs
def call_function(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]):
python_ret_type = None
if self.annotate_functionals and target.__module__ == 'torch.nn.functional':
target_for_analysis = target
if target in boolean_dispatched:
# HACK: `boolean_dispatch` as used in `torch.nn.functional` makes it so that we have
# a 2-way dispatch based on a boolean value. Here we check that the `true` and `false`
# branches of the dispatch have exactly the same signature. If they do, use the `true`
# branch signature for analysis. Otherwise, leave this un-normalized
assert not isinstance(target, str)
dispatched = boolean_dispatched[target]
if_true, if_false = dispatched['if_true'], dispatched['if_false']
# TODO: can we emit the union of these? What are the implications on TorchScript
# compilation?
if inspect.signature(if_true).return_annotation != inspect.signature(if_false).return_annotation:
return super().call_function(target, args, kwargs)
target_for_analysis = if_true
python_ret_type = self._extract_python_return_type(target_for_analysis)
return_proxy = super().call_function(target, args, kwargs)
return_proxy.node.type = return_proxy.node.type if return_proxy.node.type else python_ret_type
return return_proxy
def call_module(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]):
python_ret_type = None
assert isinstance(target, str)
submod = self.fetch_attr(target)
if self.annotate_modules and hasattr(submod.__class__, '__name__'):
classname = submod.__class__.__name__
if getattr(torch.nn, classname, None) == submod.__class__:
python_ret_type = self._extract_python_return_type(submod.forward)
return_proxy = super().call_module(target, args, kwargs)
return_proxy.node.type = return_proxy.node.type if return_proxy.node.type else python_ret_type
return return_proxy
def get_attr(self, target : torch.fx.node.Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]):
attr_proxy = super().get_attr(target, args, kwargs)
if self.annotate_get_attrs:
module_itr = self.module
assert isinstance(target, str)
atoms = target.split('.')
for i, atom in enumerate(atoms):
if not hasattr(module_itr, atom):
raise RuntimeError(f'Node referenced nonextent target {".".join(atoms[:i])}!')
module_itr = getattr(module_itr, atom)
maybe_inferred_ts_type = torch._C._jit_try_infer_type(module_itr)
if maybe_inferred_ts_type.success():
python_type = _torchscript_type_to_python_type(maybe_inferred_ts_type.type())
attr_proxy.node.type = python_type if not attr_proxy.node.type else attr_proxy.node.type
return attr_proxy
def _extract_python_return_type(self, target : Target) -> Optional[Any]:
"""
Given a Python call target, try to extract the Python return annotation
if it is available, otherwise return None
Args:
target (Callable): Python callable to get return annotation for
Returns:
Optional[Any]: Return annotation from the `target`, or None if it was
not available.
"""
assert callable(target)
try:
sig = inspect.signature(target)
except (ValueError, TypeError):
return None
return sig.return_annotation if sig.return_annotation is not inspect.Signature.empty else None

View File

@ -12,7 +12,7 @@ import warnings
if TYPE_CHECKING:
from .graph_module import GraphModule
from .graph_module import GraphModule # noqa
# Mapping of builtins to their `typing` equivalent.
@ -851,9 +851,9 @@ class Graph:
def emit_node(node : Node):
maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
if node.op == 'placeholder':
assert isinstance(node.target, str)
maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}'
free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}')
raw_name = node.target.replace('*', '')
@ -863,7 +863,7 @@ class Graph:
elif node.op == 'call_method':
assert isinstance(node.target, str)
body.append(
f'{repr(node)} = {_format_target(repr(node.args[0]), node.target)}'
f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}'
f'({_format_args(node.args[1:], node.kwargs)})')
return
elif node.op == 'call_function':
@ -871,7 +871,8 @@ class Graph:
# pretty print operators
if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods:
assert isinstance(node.args, tuple)
body.append(f'{repr(node)} = {magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}')
body.append(f'{repr(node)}{maybe_type_annotation} = '
f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}')
return
qualified_name = _get_qualified_name(node.target)
global_name = add_global(qualified_name, node.target)
@ -880,17 +881,18 @@ class Graph:
isinstance(node.args[1], str) and \
node.args[1].isidentifier():
# pretty print attribute access
body.append(f'{repr(node)} = {_format_target(repr(node.args[0]), node.args[1])}')
body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}')
return
body.append(f'{repr(node)} = {global_name}({_format_args(node.args, node.kwargs)})')
body.append(f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})')
return
elif node.op == 'call_module':
assert isinstance(node.target, str)
body.append(f'{repr(node)} = {_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})')
body.append(f'{repr(node)}{maybe_type_annotation} = '
f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})')
return
elif node.op == 'get_attr':
assert isinstance(node.target, str)
body.append(f'{repr(node)} = {_format_target(root_module, node.target)}')
body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}')
return
elif node.op == 'output':
if node.type is not None: