mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[inductor] Improve type annotations in _inductor/pattern_matcher.py (#146626)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146626 Approved by: https://github.com/Skylion007
This commit is contained in:
parent
d0f08dc3eb
commit
80d3afc698
|
|
@ -1,4 +1,3 @@
|
|||
# mypy: allow-untyped-decorators
|
||||
"""
|
||||
# Inductor Pattern Matcher
|
||||
|
||||
|
|
@ -50,7 +49,7 @@ import textwrap
|
|||
import typing
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from collections.abc import Generator, Iterable, Mapping, Sequence
|
||||
from collections.abc import Collection, Generator, Iterable, Mapping, Sequence
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, NoReturn, Optional, Protocol, TypeVar, Union
|
||||
from typing_extensions import Self, TypeIs
|
||||
|
|
@ -261,7 +260,7 @@ class Match:
|
|||
fwd_only, run_functional_passes=run_functional_passes
|
||||
)
|
||||
replacement = trace_fn(
|
||||
replacement_fn, torch.fx.map_arg(args, lambda arg: arg.meta["val"]) # type: ignore[arg-type]
|
||||
replacement_fn, torch.fx.map_arg(args, lambda arg: arg.meta["val"])
|
||||
)
|
||||
if len(self.nodes) == 1:
|
||||
for n in replacement.graph.nodes:
|
||||
|
|
@ -652,8 +651,9 @@ class _TargetArgsExpr(_TargetExpr):
|
|||
if len(_kwargs) < len(self.kwargs):
|
||||
from torch.fx.operator_schemas import normalize_function
|
||||
|
||||
assert callable(node.target)
|
||||
normalized_args_and_kwargs = normalize_function(
|
||||
node.target, node.args, node.kwargs # type: ignore[arg-type]
|
||||
node.target, node.args, node.kwargs
|
||||
)
|
||||
|
||||
if normalized_args_and_kwargs is None:
|
||||
|
|
@ -1080,7 +1080,8 @@ class ReplacementPatternEntry(PatternEntry):
|
|||
if node.op == "call_function":
|
||||
target = node.target
|
||||
args, kwargs = self.fetch_args_kwargs_from_env(node)
|
||||
result = graph.call_function(target, args, kwargs) # type: ignore[arg-type]
|
||||
assert callable(target)
|
||||
result = graph.call_function(target, args, kwargs)
|
||||
_transfer_meta(
|
||||
new_meta=result.meta,
|
||||
old_node=node,
|
||||
|
|
@ -1129,7 +1130,8 @@ class ReplacementPatternEntry(PatternEntry):
|
|||
queue.extend(arg.all_input_nodes)
|
||||
|
||||
with graph.inserting_before(last_node):
|
||||
replacement = Replacer(replacement_graph).run(*args) # type: ignore[arg-type]
|
||||
assert isinstance(replacement_graph, torch.fx.GraphModule)
|
||||
replacement = Replacer(replacement_graph).run(*args)
|
||||
if isinstance(replacement, torch.fx.Node):
|
||||
replacement = [replacement]
|
||||
|
||||
|
|
@ -1207,7 +1209,7 @@ class ReplacementPatternEntry(PatternEntry):
|
|||
idx = maybe_getitem(user)
|
||||
if idx is None:
|
||||
raise AssertionError("can't handle")
|
||||
replace(user, new[idx]) # type: ignore[index]
|
||||
replace(user, new[idx])
|
||||
graph.erase_node(old)
|
||||
|
||||
if len(output_nodes) == len(replacement):
|
||||
|
|
@ -1326,10 +1328,11 @@ def register_replacement(
|
|||
)
|
||||
|
||||
args = list(
|
||||
torch.fx.map_arg( # type: ignore[arg-type]
|
||||
torch.fx.map_arg(
|
||||
[match.kwargs[name] for name in argnames], lambda n: n.meta["val"]
|
||||
)
|
||||
)
|
||||
|
||||
sym_args: list[torch.SymInt] = []
|
||||
with torch._dynamo.utils.detect_fake_mode(args):
|
||||
for i, grad in enumerate(requires_grad):
|
||||
|
|
@ -1625,7 +1628,7 @@ def gen_register_replacement(
|
|||
)
|
||||
|
||||
|
||||
@functorch_config.patch(functionalize_rng_ops=False)
|
||||
@functorch_config.patch(functionalize_rng_ops=False) # type: ignore[misc]
|
||||
def gen_pattern_and_search_gm(
|
||||
search_fn: SearchFn,
|
||||
example_inputs: Sequence[Any],
|
||||
|
|
@ -1750,10 +1753,12 @@ def is_mutation_op(node: torch.fx.Node) -> bool:
|
|||
):
|
||||
return False
|
||||
if node.op == "call_function":
|
||||
if _mutation_op_re.search(node.target.__name__): # type: ignore[union-attr]
|
||||
assert callable(node.target)
|
||||
if _mutation_op_re.search(node.target.__name__):
|
||||
return True
|
||||
elif node.op == "call_method":
|
||||
if _mutation_op_re.search(node.target): # type: ignore[union-attr, arg-type]
|
||||
assert isinstance(node.target, str)
|
||||
if _mutation_op_re.search(node.target):
|
||||
return True
|
||||
return node.kwargs.get("out") is not None
|
||||
|
||||
|
|
@ -1777,13 +1782,13 @@ def get_mutation_region_id(graph: torch.fx.Graph, node: torch.fx.Node) -> int:
|
|||
return mutation_region_id
|
||||
|
||||
|
||||
def should_compute_mutation_region_ids(graph: torch.fx.GraphModule) -> bool:
|
||||
return "mutation_region_id" not in next(iter(graph.nodes)).meta # type: ignore[arg-type]
|
||||
def should_compute_mutation_region_ids(graph: torch.fx.Graph) -> bool:
|
||||
return "mutation_region_id" not in next(iter(graph.nodes)).meta
|
||||
|
||||
|
||||
def compute_mutation_region_ids(graph: torch.fx.GraphModule) -> None:
|
||||
def compute_mutation_region_ids(graph: torch.fx.Graph) -> None:
|
||||
mutation_region_id = 0
|
||||
for nd in graph.nodes: # type: ignore[union-attr]
|
||||
for nd in graph.nodes:
|
||||
if is_mutation_op(nd):
|
||||
mutation_region_id += 1
|
||||
nd.meta["mutation_region_id"] = mutation_region_id
|
||||
|
|
@ -1821,8 +1826,8 @@ class PatternMatcherPass:
|
|||
raise RuntimeError(
|
||||
f"The input to PatternMatcherPass must be a GraphModule or a Graph, but got {type(gm)}"
|
||||
)
|
||||
if should_compute_mutation_region_ids(graph): # type: ignore[arg-type]
|
||||
compute_mutation_region_ids(graph) # type: ignore[arg-type]
|
||||
if should_compute_mutation_region_ids(graph):
|
||||
compute_mutation_region_ids(graph)
|
||||
get_mutation_region_id_partial = functools.partial(
|
||||
get_mutation_region_id, graph
|
||||
)
|
||||
|
|
@ -1858,14 +1863,17 @@ class PatternMatcherPass:
|
|||
# pattern match crosses mutation barrier - discard
|
||||
if (
|
||||
is_match(m)
|
||||
and len(OrderedSet(map(get_mutation_region_id_partial, m.nodes))) != 1 # type: ignore[possibly-undefined]
|
||||
and len(
|
||||
OrderedSet(map(get_mutation_region_id_partial, m.nodes))
|
||||
)
|
||||
!= 1
|
||||
):
|
||||
continue
|
||||
if os.environ.get("TORCHINDUCTOR_PATTERN_MATCH_DEBUG") == node.name:
|
||||
log.warning("%s%s %s %s", node, node.args, m, entry.pattern)
|
||||
if is_match(m) and entry.extra_check(m):
|
||||
count += 1
|
||||
entry.apply(m, graph, node) # type: ignore[arg-type]
|
||||
entry.apply(m, graph, node)
|
||||
counters["inductor"]["pattern_matcher_count"] += 1
|
||||
counters["inductor"]["pattern_matcher_nodes"] += len(m.nodes)
|
||||
return count
|
||||
|
|
@ -1960,14 +1968,17 @@ def fx_to_pattern(
|
|||
def run_node(self, n: torch.fx.Node) -> Any:
|
||||
rv = super().run_node(n)
|
||||
if n.op == "output" and isinstance(rv, tuple):
|
||||
assert len(rv) == len(n.args[0]) # type: ignore[arg-type]
|
||||
for r, arg in zip(rv, n.args[0]): # type: ignore[arg-type]
|
||||
args = n.args[0]
|
||||
assert isinstance(args, Collection)
|
||||
assert len(rv) == len(args)
|
||||
for r, arg in zip(rv, args):
|
||||
r.users = len(arg.users)
|
||||
else:
|
||||
rv.users = len(n.users)
|
||||
return rv
|
||||
|
||||
pattern = Converter(gm).run() # type: ignore[arg-type]
|
||||
assert isinstance(gm, torch.fx.GraphModule)
|
||||
pattern = Converter(gm).run()
|
||||
if not isinstance(pattern, PatternExpr):
|
||||
return MultiOutputPattern(pytree.tree_leaves(pattern))
|
||||
return pattern
|
||||
|
|
@ -2037,7 +2048,7 @@ def joint_fwd_bwd(fn: Callable[..., Any], args: Sequence[Any]) -> torch.fx.Graph
|
|||
GraphPatternEntry(
|
||||
pattern=pattern, handler=pointless_view, extra_check=_return_true
|
||||
).register(matcher_pass.patterns)
|
||||
matcher_pass.apply(gm.graph) # type: ignore[arg-type]
|
||||
matcher_pass.apply(gm.graph)
|
||||
|
||||
# remove in/out specs
|
||||
gm.graph._codegen = torch.fx.graph.CodeGen()
|
||||
|
|
@ -2137,11 +2148,12 @@ _seen_patterns = OrderedSet[str]()
|
|||
def get_arg_value(
|
||||
node: torch.fx.Node, arg_number: int, kwarg_name: Optional[str] = None
|
||||
) -> Any:
|
||||
return (
|
||||
node.args[arg_number]
|
||||
if len(node.args) > arg_number
|
||||
else node.kwargs.get(kwarg_name) # type: ignore[arg-type]
|
||||
)
|
||||
if len(node.args) > arg_number:
|
||||
return node.args[arg_number]
|
||||
elif kwarg_name is None:
|
||||
return None
|
||||
else:
|
||||
return node.kwargs.get(kwarg_name)
|
||||
|
||||
|
||||
def filter_nodes(nodes: Iterable[torch.fx.Node], fn: Any) -> list[torch.fx.Node]:
|
||||
|
|
@ -2158,5 +2170,6 @@ def extract_target(node: torch.fx.Node) -> torch.fx.node.Target:
|
|||
as a function.
|
||||
"""
|
||||
if node.op == "call_module":
|
||||
return _get_attr(node.graph.owning_module, node.target).__class__ # type: ignore[arg-type]
|
||||
assert isinstance(node.target, str)
|
||||
return _get_attr(node.graph.owning_module, node.target).__class__
|
||||
return node.target
|
||||
|
|
|
|||
|
|
@ -335,7 +335,7 @@ def type_matches(signature_type: Any, argument_type: Any):
|
|||
@compatibility(is_backward_compatible=False)
|
||||
def normalize_function(
|
||||
target: Callable,
|
||||
args: tuple[Any],
|
||||
args: tuple[Any, ...],
|
||||
kwargs: Optional[dict[str, Any]] = None,
|
||||
arg_types: Optional[tuple[Any]] = None,
|
||||
kwarg_types: Optional[dict[str, Any]] = None,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user