mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[FX] Experimental type annotation pass using Python signatures (#53831)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/53831 Test Plan: Imported from OSS Reviewed By: suo Differential Revision: D26982804 Pulled By: jamesr66a fbshipit-source-id: 17db9f71e729206f29ee231e34723d9616f128b7
This commit is contained in:
parent
255b103c1b
commit
a27f46bbe3
|
|
@ -1,4 +1,5 @@
|
|||
import torch
|
||||
import operator
|
||||
import unittest
|
||||
import sys
|
||||
from typing import Callable, Dict, Union, List
|
||||
|
|
@ -23,6 +24,7 @@ from torch.fx.experimental.partitioner_utils import (
|
|||
from torch.fx.experimental.fuser import fuse
|
||||
from torch.fx.experimental import merge_matmul
|
||||
from torch.fx.experimental.normalize import NormalizeArgs
|
||||
from torch.fx.experimental.schema_type_annotation import AnnotateTypesWithSchema
|
||||
from torch.testing._internal.common_nn import module_tests, new_module_tests
|
||||
|
||||
try:
|
||||
|
|
@ -854,6 +856,47 @@ class {test_classname}(torch.nn.Module):
|
|||
if submod_class == nn_class:
|
||||
self.assertEqual(len(node.args), 0)
|
||||
|
||||
@skipIfNoTorchVision
|
||||
def test_annotate_returns_with_schema(self):
|
||||
m = resnet18()
|
||||
|
||||
traced_modules = symbolic_trace(m)
|
||||
traced_modules_annotated = AnnotateTypesWithSchema(traced_modules).transform()
|
||||
for node in traced_modules_annotated.graph.nodes:
|
||||
if node.type is None:
|
||||
check = (node.op, node.target)
|
||||
self.assertTrue(check in {('placeholder', 'x'), ('call_function', operator.add),
|
||||
('call_function', torch.flatten), ('output', 'output')})
|
||||
|
||||
# Smoke test torchscript compilation since now we're emitting type annotations
|
||||
torch.jit.script(traced_modules_annotated)
|
||||
|
||||
class FunctionalTracer(torch.fx.Tracer):
|
||||
def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
|
||||
# `leaves` contains the set of standard `nn.Modules` that are not
|
||||
# currently symbolically traceable. Ideally this set would be empty
|
||||
leaves = set([torch.nn.BatchNorm2d])
|
||||
return type(m) in leaves
|
||||
|
||||
traced_functionals = torch.fx.GraphModule(m, FunctionalTracer().trace(m))
|
||||
|
||||
traced_functionals_annotated = AnnotateTypesWithSchema(traced_functionals).transform()
|
||||
for node in traced_functionals_annotated.graph.nodes:
|
||||
if node.type is None:
|
||||
check = (node.op, node.target)
|
||||
excluded_nodes = {
|
||||
('placeholder', 'x'),
|
||||
('call_function', torch.conv2d),
|
||||
# Return type differs based on boolean dispatch :(
|
||||
('call_function', torch.nn.functional.max_pool2d),
|
||||
('call_function', operator.add),
|
||||
('call_function', torch.flatten),
|
||||
('output', 'output'),
|
||||
}
|
||||
self.assertTrue(check in excluded_nodes)
|
||||
|
||||
# Smoke test torchscript compilation since now we're emitting type annotations
|
||||
torch.jit.script(traced_functionals_annotated)
|
||||
|
||||
def test_subgraph_uniquename(self):
|
||||
class MyModule(torch.nn.Module):
|
||||
|
|
|
|||
111
torch/fx/experimental/schema_type_annotation.py
Normal file
111
torch/fx/experimental/schema_type_annotation.py
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
import torch
|
||||
import torch.fx
|
||||
import inspect
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from torch.fx.node import Argument, Target
|
||||
from torch._jit_internal import boolean_dispatched
|
||||
from torch.fx.operator_schemas import _torchscript_type_to_python_type
|
||||
|
||||
from torch.fx import Transformer
|
||||
|
||||
class AnnotateTypesWithSchema(Transformer):
|
||||
"""
|
||||
Use Python function signatures to annotate types for `Nodes` within an FX graph.
|
||||
This pulls out Python function signatures for:
|
||||
|
||||
1. Standard `torch.nn` Module calls
|
||||
2. `torch.nn.functional` calls
|
||||
3. Attribute fetches via `get_attr`
|
||||
|
||||
Example usage:
|
||||
|
||||
m = torchvision.models.resnet18()
|
||||
|
||||
traced = torch.fx.symbolic_trace(m)
|
||||
|
||||
traced = AnnotateTypesWithSchema(traced).transform()
|
||||
|
||||
"""
|
||||
def __init__(self, module : torch.nn.Module, annotate_functionals : bool = True,
|
||||
annotate_modules : bool = True, annotate_get_attrs : bool = True):
|
||||
super().__init__(module)
|
||||
self.annotate_functionals = annotate_functionals
|
||||
self.annotate_modules = annotate_modules
|
||||
self.annotate_get_attrs = annotate_get_attrs
|
||||
|
||||
def call_function(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]):
|
||||
python_ret_type = None
|
||||
if self.annotate_functionals and target.__module__ == 'torch.nn.functional':
|
||||
target_for_analysis = target
|
||||
if target in boolean_dispatched:
|
||||
# HACK: `boolean_dispatch` as used in `torch.nn.functional` makes it so that we have
|
||||
# a 2-way dispatch based on a boolean value. Here we check that the `true` and `false`
|
||||
# branches of the dispatch have exactly the same signature. If they do, use the `true`
|
||||
# branch signature for analysis. Otherwise, leave this un-normalized
|
||||
assert not isinstance(target, str)
|
||||
dispatched = boolean_dispatched[target]
|
||||
if_true, if_false = dispatched['if_true'], dispatched['if_false']
|
||||
# TODO: can we emit the union of these? What are the implications on TorchScript
|
||||
# compilation?
|
||||
if inspect.signature(if_true).return_annotation != inspect.signature(if_false).return_annotation:
|
||||
return super().call_function(target, args, kwargs)
|
||||
target_for_analysis = if_true
|
||||
|
||||
python_ret_type = self._extract_python_return_type(target_for_analysis)
|
||||
|
||||
return_proxy = super().call_function(target, args, kwargs)
|
||||
return_proxy.node.type = return_proxy.node.type if return_proxy.node.type else python_ret_type
|
||||
return return_proxy
|
||||
|
||||
def call_module(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]):
|
||||
python_ret_type = None
|
||||
assert isinstance(target, str)
|
||||
submod = self.fetch_attr(target)
|
||||
if self.annotate_modules and hasattr(submod.__class__, '__name__'):
|
||||
classname = submod.__class__.__name__
|
||||
if getattr(torch.nn, classname, None) == submod.__class__:
|
||||
python_ret_type = self._extract_python_return_type(submod.forward)
|
||||
return_proxy = super().call_module(target, args, kwargs)
|
||||
return_proxy.node.type = return_proxy.node.type if return_proxy.node.type else python_ret_type
|
||||
return return_proxy
|
||||
|
||||
def get_attr(self, target : torch.fx.node.Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]):
|
||||
attr_proxy = super().get_attr(target, args, kwargs)
|
||||
|
||||
if self.annotate_get_attrs:
|
||||
module_itr = self.module
|
||||
assert isinstance(target, str)
|
||||
atoms = target.split('.')
|
||||
for i, atom in enumerate(atoms):
|
||||
if not hasattr(module_itr, atom):
|
||||
raise RuntimeError(f'Node referenced nonextent target {".".join(atoms[:i])}!')
|
||||
module_itr = getattr(module_itr, atom)
|
||||
|
||||
maybe_inferred_ts_type = torch._C._jit_try_infer_type(module_itr)
|
||||
if maybe_inferred_ts_type.success():
|
||||
python_type = _torchscript_type_to_python_type(maybe_inferred_ts_type.type())
|
||||
attr_proxy.node.type = python_type if not attr_proxy.node.type else attr_proxy.node.type
|
||||
|
||||
return attr_proxy
|
||||
|
||||
def _extract_python_return_type(self, target : Target) -> Optional[Any]:
|
||||
"""
|
||||
Given a Python call target, try to extract the Python return annotation
|
||||
if it is available, otherwise return None
|
||||
|
||||
Args:
|
||||
|
||||
target (Callable): Python callable to get return annotation for
|
||||
|
||||
Returns:
|
||||
|
||||
Optional[Any]: Return annotation from the `target`, or None if it was
|
||||
not available.
|
||||
"""
|
||||
assert callable(target)
|
||||
try:
|
||||
sig = inspect.signature(target)
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
||||
return sig.return_annotation if sig.return_annotation is not inspect.Signature.empty else None
|
||||
|
|
@ -12,7 +12,7 @@ import warnings
|
|||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .graph_module import GraphModule
|
||||
from .graph_module import GraphModule # noqa
|
||||
|
||||
|
||||
# Mapping of builtins to their `typing` equivalent.
|
||||
|
|
@ -851,9 +851,9 @@ class Graph:
|
|||
|
||||
|
||||
def emit_node(node : Node):
|
||||
maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
|
||||
if node.op == 'placeholder':
|
||||
assert isinstance(node.target, str)
|
||||
maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
|
||||
maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}'
|
||||
free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}')
|
||||
raw_name = node.target.replace('*', '')
|
||||
|
|
@ -863,7 +863,7 @@ class Graph:
|
|||
elif node.op == 'call_method':
|
||||
assert isinstance(node.target, str)
|
||||
body.append(
|
||||
f'{repr(node)} = {_format_target(repr(node.args[0]), node.target)}'
|
||||
f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}'
|
||||
f'({_format_args(node.args[1:], node.kwargs)})')
|
||||
return
|
||||
elif node.op == 'call_function':
|
||||
|
|
@ -871,7 +871,8 @@ class Graph:
|
|||
# pretty print operators
|
||||
if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods:
|
||||
assert isinstance(node.args, tuple)
|
||||
body.append(f'{repr(node)} = {magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}')
|
||||
body.append(f'{repr(node)}{maybe_type_annotation} = '
|
||||
f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}')
|
||||
return
|
||||
qualified_name = _get_qualified_name(node.target)
|
||||
global_name = add_global(qualified_name, node.target)
|
||||
|
|
@ -880,17 +881,18 @@ class Graph:
|
|||
isinstance(node.args[1], str) and \
|
||||
node.args[1].isidentifier():
|
||||
# pretty print attribute access
|
||||
body.append(f'{repr(node)} = {_format_target(repr(node.args[0]), node.args[1])}')
|
||||
body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}')
|
||||
return
|
||||
body.append(f'{repr(node)} = {global_name}({_format_args(node.args, node.kwargs)})')
|
||||
body.append(f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})')
|
||||
return
|
||||
elif node.op == 'call_module':
|
||||
assert isinstance(node.target, str)
|
||||
body.append(f'{repr(node)} = {_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})')
|
||||
body.append(f'{repr(node)}{maybe_type_annotation} = '
|
||||
f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})')
|
||||
return
|
||||
elif node.op == 'get_attr':
|
||||
assert isinstance(node.target, str)
|
||||
body.append(f'{repr(node)} = {_format_target(root_module, node.target)}')
|
||||
body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}')
|
||||
return
|
||||
elif node.op == 'output':
|
||||
if node.type is not None:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user