fx.replace_pattern accepts pattern/replacement as GraphModule (#88479)

Symbolic tracer is no longer the default tracer to produce fx graph.
SubgraphRewriter should thus accept a raw GraphModule, rather than use symbolic tracer by default.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88479
Approved by: https://github.com/jerryzh168
This commit is contained in:
Sherlock Huang 2022-11-04 05:01:27 +00:00 committed by PyTorch MergeBot
parent 4bb5c2c205
commit 957a9b63c5
2 changed files with 21 additions and 9 deletions

View File

@ -71,4 +71,4 @@ torch.fx.proxy.TracerBase.iter(self, obj: 'Proxy') -> Iterator
torch.fx.proxy.TracerBase.keys(self, obj: 'Proxy') -> Any
torch.fx.proxy.TracerBase.proxy(self, node: torch.fx.node.Node) -> 'Proxy'
torch.fx.proxy.TracerBase.to_bool(self, obj: 'Proxy') -> bool
torch.fx.subgraph_rewriter.replace_pattern(gm: torch.fx.graph_module.GraphModule, pattern: Callable, replacement: Callable) -> List[torch.fx.subgraph_rewriter.Match]
torch.fx.subgraph_rewriter.replace_pattern(gm: torch.fx.graph_module.GraphModule, pattern: Union[Callable, torch.fx.graph_module.GraphModule], replacement: Union[Callable, torch.fx.graph_module.GraphModule]) -> List[torch.fx.subgraph_rewriter.Match]

View File

@ -5,7 +5,7 @@ from ._symbolic_trace import symbolic_trace
from ._compatibility import compatibility
import copy
from typing import Callable, Dict, List, NamedTuple, Optional, Set
from typing import Callable, Dict, List, NamedTuple, Optional, Set, Union
import torch
__all__ = ['Match', 'replace_pattern', 'replace_pattern_with_filters']
@ -65,7 +65,11 @@ def _replace_submodules(gm: GraphModule, replacement: torch.nn.Module) -> None:
@compatibility(is_backward_compatible=True)
def replace_pattern(gm: GraphModule, pattern: Callable, replacement: Callable) -> List[Match]:
def replace_pattern(
gm: GraphModule,
pattern: Union[Callable, GraphModule],
replacement: Union[Callable, GraphModule]
) -> List[Match]:
"""
Matches all possible non-overlapping sets of operators and their
data dependencies (``pattern``) in the Graph of a GraphModule
@ -187,8 +191,8 @@ def replace_pattern(gm: GraphModule, pattern: Callable, replacement: Callable) -
@compatibility(is_backward_compatible=False)
def replace_pattern_with_filters(
gm: GraphModule,
pattern: Callable,
replacement: Callable,
pattern: Union[Callable, GraphModule],
replacement: Union[Callable, GraphModule],
match_filters: List[Callable[["InternalMatch", Graph, Graph], bool]], # type: ignore[name-defined]
) -> List[Match]:
"""
@ -205,8 +209,8 @@ def replace_pattern_with_filters(
def _replace_pattern(
gm: GraphModule,
pattern: Callable,
replacement: Callable,
pattern: Union[Callable, GraphModule],
replacement: Union[Callable, GraphModule],
match_filters: List[Callable[["InternalMatch", Graph, Graph], bool]] = None # type: ignore[name-defined]
) -> List[Match]:
@ -217,8 +221,16 @@ def _replace_pattern(
# Get the graphs for `gm`, `pattern`, `replacement`
original_graph: Graph = gm.graph
pattern_graph: Graph = symbolic_trace(pattern).graph
replacement_graph: Graph = symbolic_trace(replacement).graph
if isinstance(pattern, GraphModule):
pattern_graph = pattern.graph
else:
pattern_graph = symbolic_trace(pattern).graph
if isinstance(replacement, GraphModule):
replacement_graph = replacement.graph
else:
replacement_graph = symbolic_trace(replacement).graph
matcher = SubgraphMatcher(pattern_graph, match_output=False, match_placeholder=False,
remove_overlapping_matches=True)