Enable typechecking for _inductor/fx_passes/group_batch_fusion.py (#110111)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110111
Approved by: https://github.com/eellison, https://github.com/Skylion007
ghstack dependencies: #110109
This commit is contained in:
Jez Ng 2023-09-27 17:53:31 -07:00 committed by PyTorch MergeBot
parent 3e7f23e04f
commit fc1fcc4d17
3 changed files with 43 additions and 31 deletions

View File

@ -200,7 +200,6 @@ exclude_patterns = [
'torch/_inductor/index_propagation.py',
'torch/_inductor/ir.py',
'torch/_inductor/scheduler.py',
'torch/_inductor/fx_passes/group_batch_fusion.py',
]
command = [
'python3',

View File

@ -23,6 +23,9 @@ files =
# in Python 3.8
python_version = 3.8
[mypy-deeplearning.*]
ignore_missing_imports = True
[mypy-sympy]
ignore_missing_imports = True

View File

@ -1,6 +1,7 @@
import collections
import logging
import operator
from typing import Any, DefaultDict, Deque, Iterator, List, Optional, Set, Tuple
import torch
from torch._dynamo.utils import counters
@ -57,7 +58,7 @@ class BatchFusion(GroupBatchFusionBase):
class GroupLinearFusion(GroupFusion):
def _addmm_node_can_be_fused(self, node):
def _addmm_node_can_be_fused(self, node: torch.fx.Node):
input_shape = node.args[1].meta["tensor_meta"].shape
weight_shape = node.args[2].meta["tensor_meta"].shape
return (
@ -70,7 +71,7 @@ class GroupLinearFusion(GroupFusion):
for shape in input_shape + weight_shape
)
def _mm_node_can_be_fused(self, node):
def _mm_node_can_be_fused(self, node: torch.fx.Node):
input_shape = node.args[0].meta["tensor_meta"].shape
weight_shape = node.args[1].meta["tensor_meta"].shape
return (
@ -81,7 +82,7 @@ class GroupLinearFusion(GroupFusion):
for shape in input_shape + weight_shape
)
def match(self, node):
def match(self, node: torch.fx.Node) -> Optional[Tuple[str, bool]]:
if CallFunctionVarArgs(aten.mm.default).match(
node
) and self._mm_node_can_be_fused(node):
@ -95,7 +96,7 @@ class GroupLinearFusion(GroupFusion):
group_key = None
return group_key
def fuse(self, graph, subset):
def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
group_inputs = []
group_weights = []
group_biases = []
@ -114,7 +115,8 @@ class GroupLinearFusion(GroupFusion):
group_biases.append(bias)
if all(bias is None for bias in group_biases):
group_biases = None
group_biases = None # type: ignore[assignment]
group_biases: Optional[List[Any]]
with graph.inserting_before(subset[0]):
fused_mm = graph.call_function(
@ -140,7 +142,7 @@ class BatchLinearLHSFusion(BatchFusion):
We have a separate pass to eliminate contiguous transpose in a generic way.
"""
def match(self, node):
def match(self, node: torch.fx.Node) -> Optional[Tuple[str, bool, Any]]:
if CallFunctionVarArgs(torch.nn.functional.linear).match(
node
) and is_linear_node_can_be_fused(node):
@ -151,7 +153,7 @@ class BatchLinearLHSFusion(BatchFusion):
group_key = None
return group_key
def fuse(self, graph, subset):
def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
batch_nodes = []
batch_input = None
batch_weights = []
@ -201,7 +203,7 @@ class BatchLinearLHSFusion(BatchFusion):
graph.erase_node(node)
def is_node_meta_valid(node):
def is_node_meta_valid(node: Optional[torch.fx.Node]):
if node is None:
return True
if "example_value" not in node.meta:
@ -209,7 +211,7 @@ def is_node_meta_valid(node):
return True
def is_linear_node_can_be_fused(node):
def is_linear_node_can_be_fused(node: torch.fx.Node):
input = get_arg_value(node, 0, "input")
weight = get_arg_value(node, 1, "weight")
return (
@ -234,7 +236,7 @@ class BatchLinearFusion(BatchFusion):
return None
return getitem_node.args[0]
def match(self, node):
def match(self, node: torch.fx.Node):
if CallFunctionVarArgs(torch.nn.functional.linear).match(
node
) and is_linear_node_can_be_fused(node):
@ -252,7 +254,7 @@ class BatchLinearFusion(BatchFusion):
group_key = None
return group_key
def fuse(self, graph, subset):
def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
batch_nodes = []
batch_inputs = []
batch_weights = []
@ -306,7 +308,7 @@ class BatchTanhFusion(BatchFusion):
return None
return getitem_node.args[0]
def match(self, node):
def match(self, node: torch.fx.Node):
input = get_arg_value(node, 0, "input")
if (
CallFunctionVarArgs(torch.tanh).match(node)
@ -322,7 +324,7 @@ class BatchTanhFusion(BatchFusion):
group_key = None
return group_key
def fuse(self, graph, subset):
def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
batch_nodes = []
batch_inputs = []
@ -355,7 +357,7 @@ class BatchLayernormFusion(BatchFusion):
Batch layer norm fusion in pre grad pass
"""
def match(self, node):
def match(self, node: torch.fx.Node):
if CallFunctionVarArgs(torch.nn.functional.layer_norm).match(node):
input = get_arg_value(node, 0, "input")
weight = get_arg_value(node, 2, "weight")
@ -380,7 +382,7 @@ class BatchLayernormFusion(BatchFusion):
group_key = None
return group_key
def fuse(self, graph, subset):
def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
group_inputs = []
group_shapes = []
group_weights = []
@ -400,9 +402,11 @@ class BatchLayernormFusion(BatchFusion):
stack_dim = -1 - len(group_shapes[-1])
if all(bias is None for bias in group_biases):
group_biases = None
group_biases = None # type: ignore[assignment]
group_biases: Optional[List[Any]]
if all(weight is None for weight in group_weights):
group_weights = None
group_weights = None # type: ignore[assignment]
group_weights: Optional[List[Any]]
assert all(
eps == group_epss[0] for eps in group_epss
), "all epsilon values must be equal"
@ -455,13 +459,15 @@ class BatchLayernormFusion(BatchFusion):
graph.erase_node(node)
def find_independent_subset_greedy(node_list):
def find_independent_subset_greedy(
node_list: List[torch.fx.Node],
) -> Iterator[List[torch.fx.Node]]:
"""
Return a list of subset from node_list, all nodes in each subset are independent with each other and can be fused together.
The type of subset is list, so we can preserve node's order and benefit from split-cat elimination in later pass.
"""
visited_node_set = set()
dep_set = set()
visited_node_set: Set[torch.fx.Node] = set()
dep_set: Set[torch.fx.Node] = set()
def find_dependent_nodes(src_node, cur_node):
for input_node in cur_node.all_input_nodes:
@ -473,8 +479,8 @@ def find_independent_subset_greedy(node_list):
find_dependent_nodes(src_node, input_node)
while len(node_list) > 0:
subset = []
subset_deps = set()
subset: List[torch.fx.Node] = []
subset_deps: Set[torch.fx.Node] = set()
for node in node_list:
if len(subset) >= MAX_FUSE_SET_SIZE:
@ -495,15 +501,19 @@ def find_independent_subset_greedy(node_list):
node_list = next_round_node_list
def get_fusion_candidates(rule, root_node, fused_set):
def get_fusion_candidates(
rule: GroupBatchFusionBase, root_node: torch.fx.Node, fused_set: Set[torch.fx.Node]
) -> DefaultDict[Any, List[torch.fx.Node]]:
"""
Search fusion candidates for a specific rule using BFS starting from the root node.
We only search the subgraph within MAX_FUSE_SEARCH_DEPTH.
"""
q = collections.deque()
q: Deque[Tuple[int, torch.fx.Node]] = collections.deque()
candidate_dict = collections.defaultdict(list)
visited_set = set()
candidate_dict: DefaultDict[Any, List[torch.fx.Node]] = collections.defaultdict(
list
)
visited_set: Set[torch.fx.Node] = set()
for next_node in root_node.all_input_nodes:
q.append((1, next_node))
@ -530,9 +540,9 @@ def get_fusion_candidates(rule, root_node, fused_set):
return candidate_dict
def apply_group_batch_fusion(graph, rule):
def apply_group_batch_fusion(graph: torch.fx.GraphModule, rule: GroupBatchFusionBase):
stable_topological_sort(graph)
fused_set = set()
fused_set: Set[torch.fx.Node] = set()
for node in reversed(graph.nodes):
candidates = get_fusion_candidates(rule, node, fused_set)
@ -555,7 +565,7 @@ def apply_group_batch_fusion(graph, rule):
def group_batch_fusion_post_grad_passes(graph: torch.fx.Graph):
fusions = []
fusions: List[GroupBatchFusionBase] = []
if config.group_fusion and has_fbgemm:
fusions += [GroupLinearFusion()]
@ -565,7 +575,7 @@ def group_batch_fusion_post_grad_passes(graph: torch.fx.Graph):
def group_batch_fusion_pre_grad_passes(graph: torch.fx.Graph):
fusions = []
fusions: List[GroupBatchFusionBase] = []
if config.batch_fusion:
fusions += [
BatchLinearFusion(),