mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Enable typechecking for _inductor/fx_passes/group_batch_fusion.py (#110111)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110111 Approved by: https://github.com/eellison, https://github.com/Skylion007 ghstack dependencies: #110109
This commit is contained in:
parent
3e7f23e04f
commit
fc1fcc4d17
|
|
@ -200,7 +200,6 @@ exclude_patterns = [
|
|||
'torch/_inductor/index_propagation.py',
|
||||
'torch/_inductor/ir.py',
|
||||
'torch/_inductor/scheduler.py',
|
||||
'torch/_inductor/fx_passes/group_batch_fusion.py',
|
||||
]
|
||||
command = [
|
||||
'python3',
|
||||
|
|
|
|||
|
|
@ -23,6 +23,9 @@ files =
|
|||
# in Python 3.8
|
||||
python_version = 3.8
|
||||
|
||||
[mypy-deeplearning.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-sympy]
|
||||
ignore_missing_imports = True
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import collections
|
||||
import logging
|
||||
import operator
|
||||
from typing import Any, DefaultDict, Deque, Iterator, List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
from torch._dynamo.utils import counters
|
||||
|
|
@ -57,7 +58,7 @@ class BatchFusion(GroupBatchFusionBase):
|
|||
|
||||
|
||||
class GroupLinearFusion(GroupFusion):
|
||||
def _addmm_node_can_be_fused(self, node):
|
||||
def _addmm_node_can_be_fused(self, node: torch.fx.Node):
|
||||
input_shape = node.args[1].meta["tensor_meta"].shape
|
||||
weight_shape = node.args[2].meta["tensor_meta"].shape
|
||||
return (
|
||||
|
|
@ -70,7 +71,7 @@ class GroupLinearFusion(GroupFusion):
|
|||
for shape in input_shape + weight_shape
|
||||
)
|
||||
|
||||
def _mm_node_can_be_fused(self, node):
|
||||
def _mm_node_can_be_fused(self, node: torch.fx.Node):
|
||||
input_shape = node.args[0].meta["tensor_meta"].shape
|
||||
weight_shape = node.args[1].meta["tensor_meta"].shape
|
||||
return (
|
||||
|
|
@ -81,7 +82,7 @@ class GroupLinearFusion(GroupFusion):
|
|||
for shape in input_shape + weight_shape
|
||||
)
|
||||
|
||||
def match(self, node):
|
||||
def match(self, node: torch.fx.Node) -> Optional[Tuple[str, bool]]:
|
||||
if CallFunctionVarArgs(aten.mm.default).match(
|
||||
node
|
||||
) and self._mm_node_can_be_fused(node):
|
||||
|
|
@ -95,7 +96,7 @@ class GroupLinearFusion(GroupFusion):
|
|||
group_key = None
|
||||
return group_key
|
||||
|
||||
def fuse(self, graph, subset):
|
||||
def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
|
||||
group_inputs = []
|
||||
group_weights = []
|
||||
group_biases = []
|
||||
|
|
@ -114,7 +115,8 @@ class GroupLinearFusion(GroupFusion):
|
|||
group_biases.append(bias)
|
||||
|
||||
if all(bias is None for bias in group_biases):
|
||||
group_biases = None
|
||||
group_biases = None # type: ignore[assignment]
|
||||
group_biases: Optional[List[Any]]
|
||||
|
||||
with graph.inserting_before(subset[0]):
|
||||
fused_mm = graph.call_function(
|
||||
|
|
@ -140,7 +142,7 @@ class BatchLinearLHSFusion(BatchFusion):
|
|||
We have a separate pass to eliminate contiguous transpose in a generic way.
|
||||
"""
|
||||
|
||||
def match(self, node):
|
||||
def match(self, node: torch.fx.Node) -> Optional[Tuple[str, bool, Any]]:
|
||||
if CallFunctionVarArgs(torch.nn.functional.linear).match(
|
||||
node
|
||||
) and is_linear_node_can_be_fused(node):
|
||||
|
|
@ -151,7 +153,7 @@ class BatchLinearLHSFusion(BatchFusion):
|
|||
group_key = None
|
||||
return group_key
|
||||
|
||||
def fuse(self, graph, subset):
|
||||
def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
|
||||
batch_nodes = []
|
||||
batch_input = None
|
||||
batch_weights = []
|
||||
|
|
@ -201,7 +203,7 @@ class BatchLinearLHSFusion(BatchFusion):
|
|||
graph.erase_node(node)
|
||||
|
||||
|
||||
def is_node_meta_valid(node):
|
||||
def is_node_meta_valid(node: Optional[torch.fx.Node]):
|
||||
if node is None:
|
||||
return True
|
||||
if "example_value" not in node.meta:
|
||||
|
|
@ -209,7 +211,7 @@ def is_node_meta_valid(node):
|
|||
return True
|
||||
|
||||
|
||||
def is_linear_node_can_be_fused(node):
|
||||
def is_linear_node_can_be_fused(node: torch.fx.Node):
|
||||
input = get_arg_value(node, 0, "input")
|
||||
weight = get_arg_value(node, 1, "weight")
|
||||
return (
|
||||
|
|
@ -234,7 +236,7 @@ class BatchLinearFusion(BatchFusion):
|
|||
return None
|
||||
return getitem_node.args[0]
|
||||
|
||||
def match(self, node):
|
||||
def match(self, node: torch.fx.Node):
|
||||
if CallFunctionVarArgs(torch.nn.functional.linear).match(
|
||||
node
|
||||
) and is_linear_node_can_be_fused(node):
|
||||
|
|
@ -252,7 +254,7 @@ class BatchLinearFusion(BatchFusion):
|
|||
group_key = None
|
||||
return group_key
|
||||
|
||||
def fuse(self, graph, subset):
|
||||
def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
|
||||
batch_nodes = []
|
||||
batch_inputs = []
|
||||
batch_weights = []
|
||||
|
|
@ -306,7 +308,7 @@ class BatchTanhFusion(BatchFusion):
|
|||
return None
|
||||
return getitem_node.args[0]
|
||||
|
||||
def match(self, node):
|
||||
def match(self, node: torch.fx.Node):
|
||||
input = get_arg_value(node, 0, "input")
|
||||
if (
|
||||
CallFunctionVarArgs(torch.tanh).match(node)
|
||||
|
|
@ -322,7 +324,7 @@ class BatchTanhFusion(BatchFusion):
|
|||
group_key = None
|
||||
return group_key
|
||||
|
||||
def fuse(self, graph, subset):
|
||||
def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
|
||||
batch_nodes = []
|
||||
batch_inputs = []
|
||||
|
||||
|
|
@ -355,7 +357,7 @@ class BatchLayernormFusion(BatchFusion):
|
|||
Batch layer norm fusion in pre grad pass
|
||||
"""
|
||||
|
||||
def match(self, node):
|
||||
def match(self, node: torch.fx.Node):
|
||||
if CallFunctionVarArgs(torch.nn.functional.layer_norm).match(node):
|
||||
input = get_arg_value(node, 0, "input")
|
||||
weight = get_arg_value(node, 2, "weight")
|
||||
|
|
@ -380,7 +382,7 @@ class BatchLayernormFusion(BatchFusion):
|
|||
group_key = None
|
||||
return group_key
|
||||
|
||||
def fuse(self, graph, subset):
|
||||
def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
|
||||
group_inputs = []
|
||||
group_shapes = []
|
||||
group_weights = []
|
||||
|
|
@ -400,9 +402,11 @@ class BatchLayernormFusion(BatchFusion):
|
|||
stack_dim = -1 - len(group_shapes[-1])
|
||||
|
||||
if all(bias is None for bias in group_biases):
|
||||
group_biases = None
|
||||
group_biases = None # type: ignore[assignment]
|
||||
group_biases: Optional[List[Any]]
|
||||
if all(weight is None for weight in group_weights):
|
||||
group_weights = None
|
||||
group_weights = None # type: ignore[assignment]
|
||||
group_weights: Optional[List[Any]]
|
||||
assert all(
|
||||
eps == group_epss[0] for eps in group_epss
|
||||
), "all epsilon values must be equal"
|
||||
|
|
@ -455,13 +459,15 @@ class BatchLayernormFusion(BatchFusion):
|
|||
graph.erase_node(node)
|
||||
|
||||
|
||||
def find_independent_subset_greedy(node_list):
|
||||
def find_independent_subset_greedy(
|
||||
node_list: List[torch.fx.Node],
|
||||
) -> Iterator[List[torch.fx.Node]]:
|
||||
"""
|
||||
Return a list of subset from node_list, all nodes in each subset are independent with each other and can be fused together.
|
||||
The type of subset is list, so we can preserve node's order and benefit from split-cat elimination in later pass.
|
||||
"""
|
||||
visited_node_set = set()
|
||||
dep_set = set()
|
||||
visited_node_set: Set[torch.fx.Node] = set()
|
||||
dep_set: Set[torch.fx.Node] = set()
|
||||
|
||||
def find_dependent_nodes(src_node, cur_node):
|
||||
for input_node in cur_node.all_input_nodes:
|
||||
|
|
@ -473,8 +479,8 @@ def find_independent_subset_greedy(node_list):
|
|||
find_dependent_nodes(src_node, input_node)
|
||||
|
||||
while len(node_list) > 0:
|
||||
subset = []
|
||||
subset_deps = set()
|
||||
subset: List[torch.fx.Node] = []
|
||||
subset_deps: Set[torch.fx.Node] = set()
|
||||
|
||||
for node in node_list:
|
||||
if len(subset) >= MAX_FUSE_SET_SIZE:
|
||||
|
|
@ -495,15 +501,19 @@ def find_independent_subset_greedy(node_list):
|
|||
node_list = next_round_node_list
|
||||
|
||||
|
||||
def get_fusion_candidates(rule, root_node, fused_set):
|
||||
def get_fusion_candidates(
|
||||
rule: GroupBatchFusionBase, root_node: torch.fx.Node, fused_set: Set[torch.fx.Node]
|
||||
) -> DefaultDict[Any, List[torch.fx.Node]]:
|
||||
"""
|
||||
Search fusion candidates for a specific rule using BFS starting from the root node.
|
||||
We only search the subgraph within MAX_FUSE_SEARCH_DEPTH.
|
||||
"""
|
||||
q = collections.deque()
|
||||
q: Deque[Tuple[int, torch.fx.Node]] = collections.deque()
|
||||
|
||||
candidate_dict = collections.defaultdict(list)
|
||||
visited_set = set()
|
||||
candidate_dict: DefaultDict[Any, List[torch.fx.Node]] = collections.defaultdict(
|
||||
list
|
||||
)
|
||||
visited_set: Set[torch.fx.Node] = set()
|
||||
|
||||
for next_node in root_node.all_input_nodes:
|
||||
q.append((1, next_node))
|
||||
|
|
@ -530,9 +540,9 @@ def get_fusion_candidates(rule, root_node, fused_set):
|
|||
return candidate_dict
|
||||
|
||||
|
||||
def apply_group_batch_fusion(graph, rule):
|
||||
def apply_group_batch_fusion(graph: torch.fx.GraphModule, rule: GroupBatchFusionBase):
|
||||
stable_topological_sort(graph)
|
||||
fused_set = set()
|
||||
fused_set: Set[torch.fx.Node] = set()
|
||||
|
||||
for node in reversed(graph.nodes):
|
||||
candidates = get_fusion_candidates(rule, node, fused_set)
|
||||
|
|
@ -555,7 +565,7 @@ def apply_group_batch_fusion(graph, rule):
|
|||
|
||||
|
||||
def group_batch_fusion_post_grad_passes(graph: torch.fx.Graph):
|
||||
fusions = []
|
||||
fusions: List[GroupBatchFusionBase] = []
|
||||
|
||||
if config.group_fusion and has_fbgemm:
|
||||
fusions += [GroupLinearFusion()]
|
||||
|
|
@ -565,7 +575,7 @@ def group_batch_fusion_post_grad_passes(graph: torch.fx.Graph):
|
|||
|
||||
|
||||
def group_batch_fusion_pre_grad_passes(graph: torch.fx.Graph):
|
||||
fusions = []
|
||||
fusions: List[GroupBatchFusionBase] = []
|
||||
if config.batch_fusion:
|
||||
fusions += [
|
||||
BatchLinearFusion(),
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user