mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
fx.replace_pattern accepts pattern/replacement as GraphModule (#88479)
Symbolic tracer is no longer the default tracer to produce fx graph. SubgraphRewriter should thus accept a raw GraphModule, rather than use symbolic tracer by default. Pull Request resolved: https://github.com/pytorch/pytorch/pull/88479 Approved by: https://github.com/jerryzh168
This commit is contained in:
parent
4bb5c2c205
commit
957a9b63c5
|
|
@ -71,4 +71,4 @@ torch.fx.proxy.TracerBase.iter(self, obj: 'Proxy') -> Iterator
|
|||
torch.fx.proxy.TracerBase.keys(self, obj: 'Proxy') -> Any
|
||||
torch.fx.proxy.TracerBase.proxy(self, node: torch.fx.node.Node) -> 'Proxy'
|
||||
torch.fx.proxy.TracerBase.to_bool(self, obj: 'Proxy') -> bool
|
||||
torch.fx.subgraph_rewriter.replace_pattern(gm: torch.fx.graph_module.GraphModule, pattern: Callable, replacement: Callable) -> List[torch.fx.subgraph_rewriter.Match]
|
||||
torch.fx.subgraph_rewriter.replace_pattern(gm: torch.fx.graph_module.GraphModule, pattern: Union[Callable, torch.fx.graph_module.GraphModule], replacement: Union[Callable, torch.fx.graph_module.GraphModule]) -> List[torch.fx.subgraph_rewriter.Match]
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from ._symbolic_trace import symbolic_trace
|
|||
from ._compatibility import compatibility
|
||||
|
||||
import copy
|
||||
from typing import Callable, Dict, List, NamedTuple, Optional, Set
|
||||
from typing import Callable, Dict, List, NamedTuple, Optional, Set, Union
|
||||
import torch
|
||||
|
||||
__all__ = ['Match', 'replace_pattern', 'replace_pattern_with_filters']
|
||||
|
|
@ -65,7 +65,11 @@ def _replace_submodules(gm: GraphModule, replacement: torch.nn.Module) -> None:
|
|||
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def replace_pattern(gm: GraphModule, pattern: Callable, replacement: Callable) -> List[Match]:
|
||||
def replace_pattern(
|
||||
gm: GraphModule,
|
||||
pattern: Union[Callable, GraphModule],
|
||||
replacement: Union[Callable, GraphModule]
|
||||
) -> List[Match]:
|
||||
"""
|
||||
Matches all possible non-overlapping sets of operators and their
|
||||
data dependencies (``pattern``) in the Graph of a GraphModule
|
||||
|
|
@ -187,8 +191,8 @@ def replace_pattern(gm: GraphModule, pattern: Callable, replacement: Callable) -
|
|||
@compatibility(is_backward_compatible=False)
|
||||
def replace_pattern_with_filters(
|
||||
gm: GraphModule,
|
||||
pattern: Callable,
|
||||
replacement: Callable,
|
||||
pattern: Union[Callable, GraphModule],
|
||||
replacement: Union[Callable, GraphModule],
|
||||
match_filters: List[Callable[["InternalMatch", Graph, Graph], bool]], # type: ignore[name-defined]
|
||||
) -> List[Match]:
|
||||
"""
|
||||
|
|
@ -205,8 +209,8 @@ def replace_pattern_with_filters(
|
|||
|
||||
def _replace_pattern(
|
||||
gm: GraphModule,
|
||||
pattern: Callable,
|
||||
replacement: Callable,
|
||||
pattern: Union[Callable, GraphModule],
|
||||
replacement: Union[Callable, GraphModule],
|
||||
match_filters: List[Callable[["InternalMatch", Graph, Graph], bool]] = None # type: ignore[name-defined]
|
||||
) -> List[Match]:
|
||||
|
||||
|
|
@ -217,8 +221,16 @@ def _replace_pattern(
|
|||
|
||||
# Get the graphs for `gm`, `pattern`, `replacement`
|
||||
original_graph: Graph = gm.graph
|
||||
pattern_graph: Graph = symbolic_trace(pattern).graph
|
||||
replacement_graph: Graph = symbolic_trace(replacement).graph
|
||||
|
||||
if isinstance(pattern, GraphModule):
|
||||
pattern_graph = pattern.graph
|
||||
else:
|
||||
pattern_graph = symbolic_trace(pattern).graph
|
||||
|
||||
if isinstance(replacement, GraphModule):
|
||||
replacement_graph = replacement.graph
|
||||
else:
|
||||
replacement_graph = symbolic_trace(replacement).graph
|
||||
|
||||
matcher = SubgraphMatcher(pattern_graph, match_output=False, match_placeholder=False,
|
||||
remove_overlapping_matches=True)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user