[inductor] Improve type annotations in _inductor/pattern_matcher.py (#146626)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146626
Approved by: https://github.com/Skylion007
Tom Ritchford 2025-02-14 20:14:03 +00:00 committed by PyTorch MergeBot
parent d0f08dc3eb
commit 80d3afc698
2 changed files with 43 additions and 30 deletions

--- a/torch/_inductor/pattern_matcher.py
+++ b/torch/_inductor/pattern_matcher.py

@@ -1,4 +1,3 @@
-# mypy: allow-untyped-decorators
 """
 # Inductor Pattern Matcher
@@ -50,7 +49,7 @@ import textwrap
 import typing
 from abc import ABC, abstractmethod
 from collections import defaultdict
-from collections.abc import Generator, Iterable, Mapping, Sequence
+from collections.abc import Collection, Generator, Iterable, Mapping, Sequence
 from pathlib import Path
 from typing import Any, Callable, NoReturn, Optional, Protocol, TypeVar, Union
 from typing_extensions import Self, TypeIs
@@ -261,7 +260,7 @@ class Match:
                 fwd_only, run_functional_passes=run_functional_passes
             )
         replacement = trace_fn(
-            replacement_fn, torch.fx.map_arg(args, lambda arg: arg.meta["val"])  # type: ignore[arg-type]
+            replacement_fn, torch.fx.map_arg(args, lambda arg: arg.meta["val"])
         )
         if len(self.nodes) == 1:
             for n in replacement.graph.nodes:
@@ -652,8 +651,9 @@ class _TargetArgsExpr(_TargetExpr):
         if len(_kwargs) < len(self.kwargs):
             from torch.fx.operator_schemas import normalize_function

+            assert callable(node.target)
             normalized_args_and_kwargs = normalize_function(
-                node.target, node.args, node.kwargs  # type: ignore[arg-type]
+                node.target, node.args, node.kwargs
             )
             if normalized_args_and_kwargs is None:
@@ -1080,7 +1080,8 @@ class ReplacementPatternEntry(PatternEntry):
                 if node.op == "call_function":
                     target = node.target
                     args, kwargs = self.fetch_args_kwargs_from_env(node)
-                    result = graph.call_function(target, args, kwargs)  # type: ignore[arg-type]
+                    assert callable(target)
+                    result = graph.call_function(target, args, kwargs)
                     _transfer_meta(
                         new_meta=result.meta,
                         old_node=node,
@@ -1129,7 +1130,8 @@ class ReplacementPatternEntry(PatternEntry):
                     queue.extend(arg.all_input_nodes)

         with graph.inserting_before(last_node):
-            replacement = Replacer(replacement_graph).run(*args)  # type: ignore[arg-type]
+            assert isinstance(replacement_graph, torch.fx.GraphModule)
+            replacement = Replacer(replacement_graph).run(*args)
             if isinstance(replacement, torch.fx.Node):
                 replacement = [replacement]
@@ -1207,7 +1209,7 @@ class ReplacementPatternEntry(PatternEntry):
                 idx = maybe_getitem(user)
                 if idx is None:
                     raise AssertionError("can't handle")
-                replace(user, new[idx])  # type: ignore[index]
+                replace(user, new[idx])
             graph.erase_node(old)

         if len(output_nodes) == len(replacement):
@@ -1326,10 +1328,11 @@ def register_replacement(
                 )
         args = list(
-            torch.fx.map_arg(  # type: ignore[arg-type]
+            torch.fx.map_arg(
                 [match.kwargs[name] for name in argnames], lambda n: n.meta["val"]
             )
         )
         sym_args: list[torch.SymInt] = []
         with torch._dynamo.utils.detect_fake_mode(args):
             for i, grad in enumerate(requires_grad):
@@ -1625,7 +1628,7 @@ def gen_register_replacement(
     )


-@functorch_config.patch(functionalize_rng_ops=False)
+@functorch_config.patch(functionalize_rng_ops=False)  # type: ignore[misc]
 def gen_pattern_and_search_gm(
     search_fn: SearchFn,
     example_inputs: Sequence[Any],
@@ -1750,10 +1753,12 @@ def is_mutation_op(node: torch.fx.Node) -> bool:
     ):
         return False
     if node.op == "call_function":
-        if _mutation_op_re.search(node.target.__name__):  # type: ignore[union-attr]
+        assert callable(node.target)
+        if _mutation_op_re.search(node.target.__name__):
             return True
     elif node.op == "call_method":
-        if _mutation_op_re.search(node.target):  # type: ignore[union-attr, arg-type]
+        assert isinstance(node.target, str)
+        if _mutation_op_re.search(node.target):
             return True
     return node.kwargs.get("out") is not None
@@ -1777,13 +1782,13 @@ def get_mutation_region_id(graph: torch.fx.Graph, node: torch.fx.Node) -> int:
     return mutation_region_id


-def should_compute_mutation_region_ids(graph: torch.fx.GraphModule) -> bool:
-    return "mutation_region_id" not in next(iter(graph.nodes)).meta  # type: ignore[arg-type]
+def should_compute_mutation_region_ids(graph: torch.fx.Graph) -> bool:
+    return "mutation_region_id" not in next(iter(graph.nodes)).meta


-def compute_mutation_region_ids(graph: torch.fx.GraphModule) -> None:
+def compute_mutation_region_ids(graph: torch.fx.Graph) -> None:
     mutation_region_id = 0
-    for nd in graph.nodes:  # type: ignore[union-attr]
+    for nd in graph.nodes:
         if is_mutation_op(nd):
             mutation_region_id += 1
         nd.meta["mutation_region_id"] = mutation_region_id
@@ -1821,8 +1826,8 @@ class PatternMatcherPass:
             raise RuntimeError(
                 f"The input to PatternMatcherPass must be a GraphModule or a Graph, but got {type(gm)}"
             )
-        if should_compute_mutation_region_ids(graph):  # type: ignore[arg-type]
-            compute_mutation_region_ids(graph)  # type: ignore[arg-type]
+        if should_compute_mutation_region_ids(graph):
+            compute_mutation_region_ids(graph)
         get_mutation_region_id_partial = functools.partial(
             get_mutation_region_id, graph
         )
@@ -1858,14 +1863,17 @@ class PatternMatcherPass:
                 # pattern match crosses mutation barrier - discard
                 if (
                     is_match(m)
-                    and len(OrderedSet(map(get_mutation_region_id_partial, m.nodes))) != 1  # type: ignore[possibly-undefined]
+                    and len(
+                        OrderedSet(map(get_mutation_region_id_partial, m.nodes))
+                    )
+                    != 1
                 ):
                     continue
                 if os.environ.get("TORCHINDUCTOR_PATTERN_MATCH_DEBUG") == node.name:
                     log.warning("%s%s %s %s", node, node.args, m, entry.pattern)
                 if is_match(m) and entry.extra_check(m):
                     count += 1
-                    entry.apply(m, graph, node)  # type: ignore[arg-type]
+                    entry.apply(m, graph, node)
                     counters["inductor"]["pattern_matcher_count"] += 1
                     counters["inductor"]["pattern_matcher_nodes"] += len(m.nodes)
         return count
@@ -1960,14 +1968,17 @@ def fx_to_pattern(
         def run_node(self, n: torch.fx.Node) -> Any:
             rv = super().run_node(n)
             if n.op == "output" and isinstance(rv, tuple):
-                assert len(rv) == len(n.args[0])  # type: ignore[arg-type]
-                for r, arg in zip(rv, n.args[0]):  # type: ignore[arg-type]
+                args = n.args[0]
+                assert isinstance(args, Collection)
+                assert len(rv) == len(args)
+                for r, arg in zip(rv, args):
                     r.users = len(arg.users)
             else:
                 rv.users = len(n.users)
             return rv

-    pattern = Converter(gm).run()  # type: ignore[arg-type]
+    assert isinstance(gm, torch.fx.GraphModule)
+    pattern = Converter(gm).run()
     if not isinstance(pattern, PatternExpr):
         return MultiOutputPattern(pytree.tree_leaves(pattern))
     return pattern
@@ -2037,7 +2048,7 @@ def joint_fwd_bwd(fn: Callable[..., Any], args: Sequence[Any]) -> torch.fx.Graph
     GraphPatternEntry(
         pattern=pattern, handler=pointless_view, extra_check=_return_true
     ).register(matcher_pass.patterns)
-    matcher_pass.apply(gm.graph)  # type: ignore[arg-type]
+    matcher_pass.apply(gm.graph)

     # remove in/out specs
     gm.graph._codegen = torch.fx.graph.CodeGen()
@@ -2137,11 +2148,12 @@ _seen_patterns = OrderedSet[str]()
 def get_arg_value(
     node: torch.fx.Node, arg_number: int, kwarg_name: Optional[str] = None
 ) -> Any:
-    return (
-        node.args[arg_number]
-        if len(node.args) > arg_number
-        else node.kwargs.get(kwarg_name)  # type: ignore[arg-type]
-    )
+    if len(node.args) > arg_number:
+        return node.args[arg_number]
+    elif kwarg_name is None:
+        return None
+    else:
+        return node.kwargs.get(kwarg_name)


 def filter_nodes(nodes: Iterable[torch.fx.Node], fn: Any) -> list[torch.fx.Node]:
@@ -2158,5 +2170,6 @@ def extract_target(node: torch.fx.Node) -> torch.fx.node.Target:
     as a function.
     """
     if node.op == "call_module":
-        return _get_attr(node.graph.owning_module, node.target).__class__  # type: ignore[arg-type]
+        assert isinstance(node.target, str)
+        return _get_attr(node.graph.owning_module, node.target).__class__
     return node.target
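
Most of the pattern_matcher.py hunks above follow one idiom: a `# type: ignore[...]` comment is replaced by a runtime `assert` that narrows `node.target`, which is a callable for `call_function` nodes but a string for `call_method` and `call_module` nodes, so mypy can check the call site. A minimal, self-contained sketch of that narrowing idiom, separate from the PR's code (the `_name_re`, `op_name`, and `is_interesting` names are made up for illustration):

import re
from typing import Any, Callable

import torch
import torch.fx

# Illustrative pattern only; not the regex used in pattern_matcher.py.
_name_re = re.compile(r"relu|add")


def op_name(fn: Callable[..., Any]) -> str:
    # Requires a callable; passing an unnarrowed Node.target here would be
    # a mypy arg-type error.
    return getattr(fn, "__name__", repr(fn))


def is_interesting(node: torch.fx.Node) -> bool:
    # Node.target is a callable on call_function nodes and a str on
    # call_method nodes; the asserts narrow the union per branch instead of
    # relying on `# type: ignore[union-attr, arg-type]` comments.
    if node.op == "call_function":
        assert callable(node.target)
        return bool(_name_re.search(op_name(node.target)))
    if node.op == "call_method":
        assert isinstance(node.target, str)
        return bool(_name_re.search(node.target))
    return False


gm = torch.fx.symbolic_trace(lambda x: torch.relu(x) + 1)
print([node.name for node in gm.graph.nodes if is_interesting(node)])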

--- a/torch/fx/operator_schemas.py
+++ b/torch/fx/operator_schemas.py

@@ -335,7 +335,7 @@ def type_matches(signature_type: Any, argument_type: Any):
 @compatibility(is_backward_compatible=False)
 def normalize_function(
     target: Callable,
-    args: tuple[Any],
+    args: tuple[Any, ...],
     kwargs: Optional[dict[str, Any]] = None,
     arg_types: Optional[tuple[Any]] = None,
     kwarg_types: Optional[dict[str, Any]] = None,
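
The operator_schemas.py fix matters because `tuple[Any]` means a tuple of exactly one element, while `tuple[Any, ...]` is a variable-length tuple, which matches the `node.args` that pattern_matcher.py passes to `normalize_function`. A small sketch, unrelated to the PyTorch sources, showing the difference under mypy (function names are hypothetical):

from typing import Any


def takes_one(args: tuple[Any]) -> int:
    # tuple[Any] is a fixed-length, single-element tuple.
    return len(args)


def takes_many(args: tuple[Any, ...]) -> int:
    # tuple[Any, ...] accepts any length, including empty.
    return len(args)


takes_many(())             # ok
takes_many((1, "x", 3.0))  # ok
takes_one((1, "x", 3.0))   # mypy arg-type error: expected a 1-tuple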