Summary:
The goal of this change is to make FX's codegen extensible. I've refactored it into a class with 5 extensibility points:
```
class CodeGen(object):
    def generate_prologue(self, free_vars: List[str], maybe_return_annotation: str) -> str:
        """
        Given the free variables and a return annotation, generates the beginning of the FX function.
        By default, `generate_prologue(['a', 'b'], '') == 'def forward(a, b):'`
        """
    def generate_output(self, output_args: Argument) -> str:
        """
        Given the output arguments, generates the return statement of the FX function.
        """
    def process_inputs(self, args: Any) -> Any:
        """
        Transforms the inputs so that the graph can take them as arguments, as
        non-default codegen may result in the inputs to the function being
        different from the inputs to the graph.

        If the graph were directly runnable, this invariant should hold:
        `f.process_outputs(f.graph(*f.process_inputs(*inputs))) == f(*inputs)`
        """
    def process_outputs(self, outputs: Any) -> Any:
        """
        Transforms the outputs of the graph to be identical to the codegen.
        See ``process_inputs`` for more details.
        """
    def additional_globals(self) -> List[Tuple[str, Any]]:
        """
        If your codegen uses extra global values, add them here.
        For example, return `[('List', typing.List)]` if you need ``List`` in the global context.
        """
```
So, for example, the `ListCodeGen` we want for AOTAutograd looks like this:
```
class ListCodeGen(CodeGen):
    def generate_prologue(self, free_vars, maybe_return_annotation):
        lst_unpack = f"""
def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
    {', '.join(free_vars)} = args_list"""
        return lst_unpack

    def additional_globals(self):
        return [('List', typing.List)]

    def process_inputs(self, *inputs):
        assert len(inputs) == 1
        return inputs[0]
```
and
```
import torch.fx as fx

def f(a, b):
    return a + b

nf = fx.symbolic_trace(f)
nf.graph.set_codegen(ListCodeGen())
nf.recompile()
print(nf.code)
```
would result in
```
def forward(self, args_list: List[torch.Tensor]):
    a, b = args_list
    add = a + b; a = b = None
    return add
```
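For illustration, here is a minimal sketch of calling the recompiled module (assuming the snippets above have run): the generated `forward` now takes a single list argument, and the unpacking happens inside it.
```
import torch

x, y = torch.randn(3), torch.randn(3)
out = nf([x, y])  # forward(self, args_list) unpacks the list internally
assert torch.equal(out, x + y)
```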
Backwards-compatibility changes: I added `process_outputs` and `process_inputs` to `fx.Graph` while removing `flatten_inputs` and `flatten_outputs`. Those didn't have `backwards_compatibility` on them, so I *think* it's probably fine?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72566
Reviewed By: desertfire
Differential Revision: D34160424
Pulled By: Chillee
fbshipit-source-id: ebf6411312b373e3fbcb13288a34befa449a2375
(cherry picked from commit 13cd12eaa1)
r'''
FX is a toolkit for developers to use to transform ``nn.Module``
instances. FX consists of three main components: a **symbolic tracer**,
an **intermediate representation**, and **Python code generation**. A
demonstration of these components in action:

::

    import torch
    # Simple module for demonstration
    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(3, 4))
            self.linear = torch.nn.Linear(4, 5)

        def forward(self, x):
            return self.linear(x + self.param).clamp(min=0.0, max=1.0)

    module = MyModule()

    from torch.fx import symbolic_trace
    # Symbolic tracing frontend - captures the semantics of the module
    symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)

    # High-level intermediate representation (IR) - Graph representation
    print(symbolic_traced.graph)
    """
    graph():
        %x : [#users=1] = placeholder[target=x]
        %param : [#users=1] = get_attr[target=param]
        %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
        %linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
        %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
        return clamp
    """

    # Code generation - valid Python code
    print(symbolic_traced.code)
    """
    def forward(self, x):
        param = self.param
        add = x + param; x = param = None
        linear = self.linear(add); add = None
        clamp = linear.clamp(min = 0.0, max = 1.0); linear = None
        return clamp
    """

The **symbolic tracer** performs "symbolic execution" of the Python
code. It feeds fake values, called Proxies, through the code. Operations
on these Proxies are recorded. More information about symbolic tracing
can be found in the :func:`symbolic_trace` and :class:`Tracer`
documentation.
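
To use the tracer on its own (a minimal sketch, assuming the ``module``
defined above), :class:`Tracer` can be invoked directly; it returns the
recorded :class:`Graph` rather than a ready-to-run module:

::

    import torch.fx

    tracer = torch.fx.Tracer()
    graph = tracer.trace(module)  # a Graph, not yet a runnable GraphModule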

The **intermediate representation** is the container for the operations
that were recorded during symbolic tracing. It consists of a list of
Nodes that represent function inputs, callsites (to functions, methods,
or :class:`torch.nn.Module` instances), and return values. More information
about the IR can be found in the documentation for :class:`Graph`. The
IR is the format on which transformations are applied.
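
As a short illustration (assuming the ``symbolic_traced`` module from
the demonstration above), each Node's opcode, target, and arguments can
be inspected by iterating over the Graph:

::

    for node in symbolic_traced.graph.nodes:
        # e.g. "call_function <built-in function add> (x, param)"
        print(node.op, node.target, node.args)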

**Python code generation** is what makes FX a Python-to-Python (or
Module-to-Module) transformation toolkit. For each Graph IR, we can
create valid Python code matching the Graph's semantics. This
functionality is wrapped up in :class:`GraphModule`, which is a
:class:`torch.nn.Module` instance that holds a :class:`Graph` as well as a
``forward`` method generated from the Graph.
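
For example, a minimal sketch of generating code from a hand-built
Graph (the names below are illustrative only):

::

    import torch
    import torch.fx

    graph = torch.fx.Graph()
    x = graph.placeholder('x')
    relu = graph.call_function(torch.relu, (x,))
    graph.output(relu)

    # GraphModule generates a ``forward`` from the Graph; ``.code`` is its source
    gm = torch.fx.GraphModule(torch.nn.Module(), graph)
    print(gm.code)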

Taken together, this pipeline of components (symbolic tracing ->
intermediate representation -> transforms -> Python code generation)
constitutes the Python-to-Python transformation pipeline of FX. In
addition, these components can be used separately. For example,
symbolic tracing can be used in isolation to capture a form of
the code for analysis (and not transformation) purposes. Code
generation can be used for programmatically generating models, for
example from a config file. There are many uses for FX!
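
As a sketch of the "transforms" step (assuming ``symbolic_traced`` from
the demonstration above), a transformation can rewrite Nodes in place
and then regenerate the Python code:

::

    import operator

    # Replace every recorded add with mul, then regenerate forward()
    for node in symbolic_traced.graph.nodes:
        if node.op == 'call_function' and node.target == operator.add:
            node.target = operator.mul
    symbolic_traced.recompile()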

Several example transformations can be found at the
`examples <https://github.com/pytorch/examples/tree/master/fx>`__
repository.
'''

from .graph_module import GraphModule
from ._symbolic_trace import symbolic_trace, Tracer, wrap, PH, ProxyableClassMeta
from .graph import Graph, CodeGen
from .node import Node, map_arg
from .proxy import Proxy
from .interpreter import Interpreter as Interpreter, Transformer as Transformer
from .subgraph_rewriter import replace_pattern