Writing Graph Transformations on ATen IR
========================================

Passes
------

Since the ATen IR sits at the FX Graph/GraphModule level, any
transformations written for FX Graphs can be easily applied onto the
ATen IR. If you’re familiar with writing FX graph transformations, then
this will be the same.

The most direct way of writing transformations is by looping through the
given graph and directly manipulating the nodes within the graph.

For example, let’s say we want to replace
``torch.ops.aten.add.Tensor()`` calls with
``torch.ops.aten.mul.Tensor()`` calls:

.. code:: python

   import torch

   def replace_add_with_mul(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
       for node in gm.graph.nodes:
           if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
               node.target = torch.ops.aten.mul.Tensor
       return gm

We can also delete and append new nodes through FX utility functions
that can be found in the
`Graph <https://pytorch.org/docs/stable/fx.html#torch.fx.Graph>`__
documentation. For example, if we want to insert a
``torch.ops.aten.relu.default()`` after the ``add`` call:

.. code:: python

   import torch

   def insert_relu_after_add(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
       for node in gm.graph.nodes:
           if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:

               # Specifies the insertion point. Any nodes added to the graph within
               # this scope will be inserted after `node`
               with gm.graph.inserting_after(node):
                   # Insert a new `call_function` node with op `torch.ops.aten.relu.default`
                   new_relu_node = gm.graph.call_function(torch.ops.aten.relu.default, args=(node,))
                   # Replace all the places that use `node` to now use the `new_relu_node`
                   node.replace_all_uses_with(new_relu_node)
       return gm
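
After mutating the graph in place like this, it is usually worth validating
the graph and regenerating the module’s code before running it. A minimal
sketch using the standard FX APIs ``Graph.lint()`` and
``GraphModule.recompile()`` (``apply_pass`` is just an illustrative wrapper):

.. code:: python

   import torch

   def apply_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
       gm = insert_relu_after_add(gm)
       # Check that the mutated graph is still well-formed
       gm.graph.lint()
       # Regenerate the GraphModule's forward() from the mutated graph
       gm.recompile()
       return gm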

In general, transformations can be roughly categorized along a few axes:

Axis A:

1. Creating one-to-X mappings (e.g. decomposition)
2. Creating many-to-one mappings (e.g. fusion)

Axis B:

1. Doing forwards iteration (e.g. shape propagation)
2. Doing backwards iteration (e.g. dead code elimination)

Axis C:

1. Dependent on local node information (e.g. out-variant conversion)
2. Dependent on global graph information (e.g. memory planning)

Our projection on the frequency of these use cases is:

1. A.1, B.1, C.1
2. A.2
3. B.2, C.2

Although we can make all graph transformations through directly
manipulating the graph, we also provide some helper utilities to make
the level 1 and 2 use cases easier to write.

Transformer
~~~~~~~~~~~

For level 1 use cases (creating one-to-X mappings, doing forwards
iterations, and looking at local node information), we can utilize the
`Transformer <https://pytorch.org/docs/stable/fx.html#torch.fx.Transformer>`__
class to execute each node and recreate a graph, except with the
transformations specified.

One-to-One Pass
^^^^^^^^^^^^^^^

As an example of a one-to-one mapping, if we wanted to replace an op A with
another op B, we can run the GraphModule, and every time we see op A,
return op B.

An example is:

.. code:: python

   class ReplaceAddWithMul(torch.fx.Transformer):
       def call_function(self, target, args, kwargs):
           if target != torch.ops.aten.add.Tensor:
               return super().call_function(target, args, kwargs)
           return super().call_function(torch.ops.aten.mul.Tensor, args, kwargs)

   transformed_graph_module = ReplaceAddWithMul(graph_module).transform()

The ``super().call_function(target, args, kwargs)`` call creates a
``call_function`` FX node, and returns the result of running the
operator with the given arguments.
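
To try this end to end, a minimal sketch, assuming ``torch.export.export`` is
used to obtain an ATen-level graph and ``ExportedProgram.module()`` to get a
runnable ``GraphModule`` (the module ``M`` below is purely illustrative):

.. code:: python

   import torch
   from torch.export import export

   class M(torch.nn.Module):
       def forward(self, x, y):
           # Lowers to torch.ops.aten.add.Tensor in the exported graph
           return x + y

   example_inputs = (torch.randn(3), torch.randn(3))
   graph_module = export(M(), example_inputs).module()

   # Rewrites every aten.add.Tensor call into aten.mul.Tensor
   transformed_graph_module = ReplaceAddWithMul(graph_module).transform()
   print(transformed_graph_module.code)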

One-to-X Pass
^^^^^^^^^^^^^

If we wanted to do one-to-X mappings, like replacing op A with 2 other
ops B and C, we would then make 2 calls to ``super().call_function`` to
create 2 FX nodes, one with op B and another with op C, and return the
result of running op C.

For example:

.. code:: python

   class ReplaceAddWithMulSub(torch.fx.Transformer):
       """
       Original:
           def f(x, y):
               return x + y

       After pass:
           def f(x, y):
               z = x * y
               return z - y
       """
       def call_function(self, target, args, kwargs):
           if target != torch.ops.aten.add.Tensor:
               return super().call_function(target, args, kwargs)

           x, y = args

           mul_res = super().call_function(torch.ops.aten.mul.Tensor, args, {})
           return super().call_function(torch.ops.aten.sub.Tensor, (mul_res, y), {})

   transformed_graph_module = ReplaceAddWithMulSub(graph_module).transform()

One-to-None Pass
^^^^^^^^^^^^^^^^

If we wanted to remove an op, we can just return the value passed into
the function:

.. code:: python

   class RemoveDetachPass(torch.fx.Transformer):
       def call_function(self, target, args, kwargs):
           if target not in (
               torch.ops.aten.detach.default,
               torch.ops.aten.detach_copy.default,
           ):
               return super().call_function(target, args, kwargs)

           assert len(args) == 1
           return args[0]

   transformed_graph_module = RemoveDetachPass(graph_module).transform()
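
To sanity check the result, a small sketch that asserts no ``detach`` calls
remain in the rewritten graph:

.. code:: python

   detach_ops = (torch.ops.aten.detach.default, torch.ops.aten.detach_copy.default)
   for node in transformed_graph_module.graph.nodes:
       # Every detach call should have been removed by the pass above
       assert not (node.op == "call_function" and node.target in detach_ops)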

Utilizing Local Information
^^^^^^^^^^^^^^^^^^^^^^^^^^^

As an example of utilizing local node information: if we wanted to
convert all the scalars within the graph to tensors, we can run the
given ``fx.GraphModule``, and for every argument that contains a scalar,
convert it to a tensor. It might look something like:

.. code:: python

   def args_map(target, fn, args, kwargs):
       assert isinstance(args, tuple)
       assert isinstance(kwargs, dict)
       args = list(args)
       kwargs = kwargs.copy()

       # Update the argument based on the function passed
       def update(key, args, schema):
           args[key] = fn(args[key], schema)

       # Update each argument in the schema
       for i, schema in enumerate(target._schema.arguments):
           if schema.name in kwargs:
               update(schema.name, kwargs, schema)
           elif not schema.kwarg_only and i < len(args):
               update(i, args, schema)
       return tuple(args), kwargs

   class ScalarToTensorPass(torch.fx.Transformer):
       def call_function(self, target, args, kwargs):
           def try_coerce(value, arg):
               # Turn Python scalars into tensors when the schema expects a Tensor
               return (
                   torch.tensor(value)
                   if isinstance(value, (float, int, bool))
                   and type(arg.type) == torch.TensorType
                   else value
               )

           args, kwargs = args_map(target, try_coerce, args, kwargs)
           return super().call_function(target, args, kwargs)

   transformed_graph_module = ScalarToTensorPass(graph_module).transform()
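
As a usage sketch, assuming the graph comes from ``torch.export.export`` so
that scalar literals can appear in ``Tensor`` argument positions (e.g.
``torch.ops.aten.add.Tensor(x, 2.0)``):

.. code:: python

   class AddConst(torch.nn.Module):
       def forward(self, x):
           return x + 2.0

   graph_module = torch.export.export(AddConst(), (torch.randn(3),)).module()
   transformed_graph_module = ScalarToTensorPass(graph_module).transform()
   # The scalar 2.0 should now be passed to aten.add.Tensor as a constant tensor
   print(transformed_graph_module.code)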

Subgraph Rewriter
~~~~~~~~~~~~~~~~~

For creating many-to-one mappings, we can utilize FX’s `subgraph
rewriter <https://github.com/pytorch/pytorch/blob/main/torch/fx/subgraph_rewriter.py>`__.
Given a ``pattern``, it creates a subgraph of operators matching the
pattern, and then replaces each matched subgraph with the
``replacement``.

Note:

::

   This is an in-place operation.

The ``pattern`` and ``replacement`` inputs must be callable functions or
GraphModules containing the same operators that are used within the
graph (ATen ops) so that the subgraph rewriter can find the correct
pattern in the graph. Inputs to the pattern/replacement callables will
be treated as wildcards when matching.

An example:

.. code:: python

   from torch.fx import subgraph_rewriter

   def replace_patterns(graph_module):
       def pattern(x, y):
           x = torch.ops.aten.add.Tensor(x, y)
           x = torch.ops.aten.mul.Tensor(x, y)
           return x

       def replacement(x, y):
           return torch.ops.aten.sub.Tensor(x, y)

       replaced_patterns = subgraph_rewriter.replace_pattern_with_filters(
           graph_module, pattern, replacement
       )
       return replaced_patterns

The subgraph rewriter returns a list of ``ReplacedPatterns``:

.. code:: python

   @dataclass
   class ReplacedPatterns:
       # Node from which the match was found
       anchor: Node
       # Maps nodes in the pattern subgraph to nodes in the larger graph
       nodes_map: Dict[Node, Node]
       # List of nodes that were added into the graph
       replacements: List[Node]

Note:

::

   The nodes created by the subgraph rewriter will not have the metadata that
   is populated in the matched nodes, but you can use
   `ReplacedPatterns.nodes_map` to find the nodes in the original graph that
   were matched, and `ReplacedPatterns.replacements` to find the new nodes
   that were inserted into the transformed graph.
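
For instance, given the ``replaced_patterns`` list from the example above, a
hedged sketch of copying metadata onto the newly inserted nodes (whether
reusing the matched anchor node’s ``meta`` is appropriate depends on the
pass):

.. code:: python

   for rp in replaced_patterns:
       for new_node in rp.replacements:
           # Assumption: the anchor node's metadata is a reasonable default here
           new_node.meta = dict(rp.anchor.meta)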

Pass Manager
------------

The
`PassManager <https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/infra/pass_manager.py>`__
is a class used to run multiple passes on a given graph module. When
initializing a ``PassManager`` instance, we pass in a list of passes
that we want to run and set a couple of flags. To run the collection of
passes on a graph module, we can pass the graph module directly to the
``PassManager`` instance.

An example:

.. code:: python

   from torch.fx.passes.infra.pass_manager import PassManager

   pm = PassManager(
       passes=[replace_add_with_div, replace_div_with_mul],
       run_checks_after_each_pass=True,
       suppress_check_failures=False,
   )
   graph_module_out = pm(graph_module)
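
The passes ``replace_add_with_div`` and ``replace_div_with_mul`` above are
placeholders. A minimal sketch of what one of them might look like, assuming
it follows the ``PassResult`` convention from
``torch.fx.passes.infra.pass_base``:

.. code:: python

   import torch
   from torch.fx.passes.infra.pass_base import PassResult

   def replace_add_with_div(gm: torch.fx.GraphModule) -> PassResult:
       modified = False
       for node in gm.graph.nodes:
           if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
               # Illustrative only: rewrite the call target in place
               node.target = torch.ops.aten.div.Tensor
               modified = True
       gm.recompile()
       return PassResult(gm, modified)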

To add a common set of checks that are run after each pass, we can call
the function ``add_checks(check: Callable)``, which takes a callable
function as input. If the ``run_checks_after_each_pass`` flag is set,
the ``check`` will be called after each pass is run on the graph module.

An example:

.. code:: python

   pm = PassManager(passes=[replace_add_with_div, replace_div_with_mul])

   def check_div_target(graph_module):
       for node in graph_module.graph.nodes:
           if node.op == "call_function" and node.target != torch.div:
               raise ValueError("Target should be div!")

   pm.add_checks(check_div_target)

   pm(graph_module)  # raises ValueError after replace_div_with_mul pass

Partitioner
-----------

There are a couple of common FX graph based partitioners we can use to
partition the graph.

Subgraph Matcher
~~~~~~~~~~~~~~~~

For finding subgraphs within a graph that match a specific pattern, we
can utilize FX’s
`SubgraphMatcher <https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/utils/matcher_utils.py>`__.

Class Attributes:

- ``pattern (Graph)``: The targeted matching pattern. Placeholder nodes
  in the graph will be treated as wildcards when matching.
- ``match_output (bool)``: If True, the output node in the pattern graph
  will be treated as a part of the targeted pattern. If False, the output
  node is ignored during the match.
- ``match_placeholder (bool)``: If True, the placeholder nodes in the
  pattern graph will be treated as a part of the targeted pattern. If
  False, the placeholder nodes will be used as wildcards.
- ``remove_overlapping_matches (bool)``: If True, in the case of
  overlapping matches, only the first match will be returned.
- ``ignore_literals (bool)``: If True, will not check if literals are
  equal and will instead treat them as wildcards.

An example:

.. code:: python

   from torch.fx.passes.utils.matcher_utils import SubgraphMatcher

   class LargeModel(torch.nn.Module):
       def __init__(self):
           super().__init__()
           self._weight = torch.nn.Parameter(torch.ones(3, 3))
           self._bias = torch.nn.Parameter(torch.ones(3, 3))

       def forward(self, x):
           return torch.ops.aten.addmm.default(self._bias, x, self._weight)

   inputs = (torch.ones(3, 3),)
   large_model_graph = torch.export.export(LargeModel(), inputs).graph

   class PatternModel(torch.nn.Module):
       def __init__(self):
           super().__init__()
           self._weight_1 = torch.nn.Parameter(torch.ones(5, 5))
           self._bias_1 = torch.nn.Parameter(torch.ones(5, 5))

       def forward(self, x):
           return torch.ops.aten.addmm.default(self._bias_1, x, self._weight_1)

   pattern_inputs = (torch.ones(5, 5),)
   pattern_graph = torch.export.export(PatternModel(), pattern_inputs).graph

   subgraph_matcher = SubgraphMatcher(pattern_graph)
   match_result = subgraph_matcher.match(large_model_graph)

The ``match`` function returns a list of ``InternalMatch`` objects:

.. code:: python

   @dataclass
   class InternalMatch:
       # Nodes from which the match was found
       anchors: List[Node]
       # Maps nodes in the pattern subgraph to nodes in the larger graph
       nodes_map: Dict[Node, Node] = field(default_factory=dict)
       # Nodes in the target graph that are matched to placeholders in the pattern
       placeholder_nodes: List[Node] = field(default_factory=list)
       # Nodes in the matched subgraph that are returned by the output
       returning_nodes: List[Node] = field(default_factory=list)
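
A brief sketch of inspecting the returned matches (using only the fields
shown above):

.. code:: python

   for match in match_result:
       # Pattern node -> corresponding node in the larger graph
       for pattern_node, target_node in match.nodes_map.items():
           print(pattern_node, "->", target_node)
       # Nodes in the larger graph corresponding to the pattern's outputs
       print("returning nodes:", match.returning_nodes)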

Capability Based Partitioner
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

To find the largest subgraphs of nodes that support a specific
invariant, we can utilize FX’s
`CapabilityBasedPartitioner <https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/infra/partitioner.py#L34>`__.

Class Attributes:

- ``graph_module (torch.fx.GraphModule)``: The graph module we are
  partitioning on.
- ``operator_support (OperatorSupportBase)``: The object used to
  determine if a node in the graph is supported in the partition.
- ``allows_single_node_partition (bool)``: If True, allows single node
  partitions to be formed.
- ``non_compute_ops (Optional[Sequence[str]])``: A set of ops that are
  considered to be “non-compute” (for example, ``torch.ops.aten.view`` and
  ``_operator.getitem``), so that the partitioner will not create graphs
  that only contain these non-compute ops.
- ``allowed_single_node_partition_ops (Optional[Sequence[str]])``: A
  set of ops that are allowed to be in a single node partition.

The
`OperatorSupportBase <https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/operator_support.py#LL28C1-L28C1>`__
class is used by the partitioner to determine if a specific node in the
graph belongs in the partition. This is done by overriding the
``is_node_supported`` function. You can chain multiple
``OperatorSupportBase`` objects by using
`chain <https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/operator_support.py#L150>`__
(which returns False if any of the ``OperatorSupportBase`` returns False) and
`any_chain <https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/operator_support.py#L164>`__
(which returns True if any of the ``OperatorSupportBase`` returns True).

An example:

.. code:: python

   from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
   from torch.fx.passes.operator_support import any_chain, OperatorSupportBase

   class AddMulOperatorSupport(OperatorSupportBase):
       def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
           return node.op == "call_function" and node.target in [
               torch.ops.aten.add.Tensor, torch.ops.aten.mul.Tensor,
           ]

   op_support = AddMulOperatorSupport()

   capability_partitioner = CapabilityBasedPartitioner(
       graph_module,
       op_support,
   )

   # Returns a list of partitions (list of nodes that belong in each partition)
   partition_list = capability_partitioner.propose_partitions()
   # Fuses the partitions into graph modules and inserts `call_module` nodes in the graph
   fused_graph_module = capability_partitioner.fuse_partitions(partition_list)
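
To combine several support rules, a hedged sketch using ``any_chain`` (the
``ReluOperatorSupport`` class below is an illustrative assumption, not an
existing API):

.. code:: python

   class ReluOperatorSupport(OperatorSupportBase):
       def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
           return node.op == "call_function" and node.target == torch.ops.aten.relu.default

   # A node is supported if *any* of the chained supports accepts it
   combined_support = any_chain(AddMulOperatorSupport(), ReluOperatorSupport())
   capability_partitioner = CapabilityBasedPartitioner(graph_module, combined_support)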