diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py index ab32d6af3e6..44bfccd6ee1 100644 --- a/torch/_inductor/pattern_matcher.py +++ b/torch/_inductor/pattern_matcher.py @@ -1,4 +1,3 @@ -# mypy: allow-untyped-decorators """ # Inductor Pattern Matcher @@ -50,7 +49,7 @@ import textwrap import typing from abc import ABC, abstractmethod from collections import defaultdict -from collections.abc import Generator, Iterable, Mapping, Sequence +from collections.abc import Collection, Generator, Iterable, Mapping, Sequence from pathlib import Path from typing import Any, Callable, NoReturn, Optional, Protocol, TypeVar, Union from typing_extensions import Self, TypeIs @@ -261,7 +260,7 @@ class Match: fwd_only, run_functional_passes=run_functional_passes ) replacement = trace_fn( - replacement_fn, torch.fx.map_arg(args, lambda arg: arg.meta["val"]) # type: ignore[arg-type] + replacement_fn, torch.fx.map_arg(args, lambda arg: arg.meta["val"]) ) if len(self.nodes) == 1: for n in replacement.graph.nodes: @@ -652,8 +651,9 @@ class _TargetArgsExpr(_TargetExpr): if len(_kwargs) < len(self.kwargs): from torch.fx.operator_schemas import normalize_function + assert callable(node.target) normalized_args_and_kwargs = normalize_function( - node.target, node.args, node.kwargs # type: ignore[arg-type] + node.target, node.args, node.kwargs ) if normalized_args_and_kwargs is None: @@ -1080,7 +1080,8 @@ class ReplacementPatternEntry(PatternEntry): if node.op == "call_function": target = node.target args, kwargs = self.fetch_args_kwargs_from_env(node) - result = graph.call_function(target, args, kwargs) # type: ignore[arg-type] + assert callable(target) + result = graph.call_function(target, args, kwargs) _transfer_meta( new_meta=result.meta, old_node=node, @@ -1129,7 +1130,8 @@ class ReplacementPatternEntry(PatternEntry): queue.extend(arg.all_input_nodes) with graph.inserting_before(last_node): - replacement = Replacer(replacement_graph).run(*args) # type: ignore[arg-type] + assert isinstance(replacement_graph, torch.fx.GraphModule) + replacement = Replacer(replacement_graph).run(*args) if isinstance(replacement, torch.fx.Node): replacement = [replacement] @@ -1207,7 +1209,7 @@ class ReplacementPatternEntry(PatternEntry): idx = maybe_getitem(user) if idx is None: raise AssertionError("can't handle") - replace(user, new[idx]) # type: ignore[index] + replace(user, new[idx]) graph.erase_node(old) if len(output_nodes) == len(replacement): @@ -1326,10 +1328,11 @@ def register_replacement( ) args = list( - torch.fx.map_arg( # type: ignore[arg-type] + torch.fx.map_arg( [match.kwargs[name] for name in argnames], lambda n: n.meta["val"] ) ) + sym_args: list[torch.SymInt] = [] with torch._dynamo.utils.detect_fake_mode(args): for i, grad in enumerate(requires_grad): @@ -1625,7 +1628,7 @@ def gen_register_replacement( ) -@functorch_config.patch(functionalize_rng_ops=False) +@functorch_config.patch(functionalize_rng_ops=False) # type: ignore[misc] def gen_pattern_and_search_gm( search_fn: SearchFn, example_inputs: Sequence[Any], @@ -1750,10 +1753,12 @@ def is_mutation_op(node: torch.fx.Node) -> bool: ): return False if node.op == "call_function": - if _mutation_op_re.search(node.target.__name__): # type: ignore[union-attr] + assert callable(node.target) + if _mutation_op_re.search(node.target.__name__): return True elif node.op == "call_method": - if _mutation_op_re.search(node.target): # type: ignore[union-attr, arg-type] + assert isinstance(node.target, str) + if _mutation_op_re.search(node.target): return True return node.kwargs.get("out") is not None @@ -1777,13 +1782,13 @@ def get_mutation_region_id(graph: torch.fx.Graph, node: torch.fx.Node) -> int: return mutation_region_id -def should_compute_mutation_region_ids(graph: torch.fx.GraphModule) -> bool: - return "mutation_region_id" not in next(iter(graph.nodes)).meta # type: ignore[arg-type] +def should_compute_mutation_region_ids(graph: torch.fx.Graph) -> bool: + return "mutation_region_id" not in next(iter(graph.nodes)).meta -def compute_mutation_region_ids(graph: torch.fx.GraphModule) -> None: +def compute_mutation_region_ids(graph: torch.fx.Graph) -> None: mutation_region_id = 0 - for nd in graph.nodes: # type: ignore[union-attr] + for nd in graph.nodes: if is_mutation_op(nd): mutation_region_id += 1 nd.meta["mutation_region_id"] = mutation_region_id @@ -1821,8 +1826,8 @@ class PatternMatcherPass: raise RuntimeError( f"The input to PatternMatcherPass must be a GraphModule or a Graph, but got {type(gm)}" ) - if should_compute_mutation_region_ids(graph): # type: ignore[arg-type] - compute_mutation_region_ids(graph) # type: ignore[arg-type] + if should_compute_mutation_region_ids(graph): + compute_mutation_region_ids(graph) get_mutation_region_id_partial = functools.partial( get_mutation_region_id, graph ) @@ -1858,14 +1863,17 @@ class PatternMatcherPass: # pattern match crosses mutation barrier - discard if ( is_match(m) - and len(OrderedSet(map(get_mutation_region_id_partial, m.nodes))) != 1 # type: ignore[possibly-undefined] + and len( + OrderedSet(map(get_mutation_region_id_partial, m.nodes)) + ) + != 1 ): continue if os.environ.get("TORCHINDUCTOR_PATTERN_MATCH_DEBUG") == node.name: log.warning("%s%s %s %s", node, node.args, m, entry.pattern) if is_match(m) and entry.extra_check(m): count += 1 - entry.apply(m, graph, node) # type: ignore[arg-type] + entry.apply(m, graph, node) counters["inductor"]["pattern_matcher_count"] += 1 counters["inductor"]["pattern_matcher_nodes"] += len(m.nodes) return count @@ -1960,14 +1968,17 @@ def fx_to_pattern( def run_node(self, n: torch.fx.Node) -> Any: rv = super().run_node(n) if n.op == "output" and isinstance(rv, tuple): - assert len(rv) == len(n.args[0]) # type: ignore[arg-type] - for r, arg in zip(rv, n.args[0]): # type: ignore[arg-type] + args = n.args[0] + assert isinstance(args, Collection) + assert len(rv) == len(args) + for r, arg in zip(rv, args): r.users = len(arg.users) else: rv.users = len(n.users) return rv - pattern = Converter(gm).run() # type: ignore[arg-type] + assert isinstance(gm, torch.fx.GraphModule) + pattern = Converter(gm).run() if not isinstance(pattern, PatternExpr): return MultiOutputPattern(pytree.tree_leaves(pattern)) return pattern @@ -2037,7 +2048,7 @@ def joint_fwd_bwd(fn: Callable[..., Any], args: Sequence[Any]) -> torch.fx.Graph GraphPatternEntry( pattern=pattern, handler=pointless_view, extra_check=_return_true ).register(matcher_pass.patterns) - matcher_pass.apply(gm.graph) # type: ignore[arg-type] + matcher_pass.apply(gm.graph) # remove in/out specs gm.graph._codegen = torch.fx.graph.CodeGen() @@ -2137,11 +2148,12 @@ _seen_patterns = OrderedSet[str]() def get_arg_value( node: torch.fx.Node, arg_number: int, kwarg_name: Optional[str] = None ) -> Any: - return ( - node.args[arg_number] - if len(node.args) > arg_number - else node.kwargs.get(kwarg_name) # type: ignore[arg-type] - ) + if len(node.args) > arg_number: + return node.args[arg_number] + elif kwarg_name is None: + return None + else: + return node.kwargs.get(kwarg_name) def filter_nodes(nodes: Iterable[torch.fx.Node], fn: Any) -> list[torch.fx.Node]: @@ -2158,5 +2170,6 @@ def extract_target(node: torch.fx.Node) -> torch.fx.node.Target: as a function. """ if node.op == "call_module": - return _get_attr(node.graph.owning_module, node.target).__class__ # type: ignore[arg-type] + assert isinstance(node.target, str) + return _get_attr(node.graph.owning_module, node.target).__class__ return node.target diff --git a/torch/fx/operator_schemas.py b/torch/fx/operator_schemas.py index 45e2bf33483..411daa5aaa1 100644 --- a/torch/fx/operator_schemas.py +++ b/torch/fx/operator_schemas.py @@ -335,7 +335,7 @@ def type_matches(signature_type: Any, argument_type: Any): @compatibility(is_backward_compatible=False) def normalize_function( target: Callable, - args: tuple[Any], + args: tuple[Any, ...], kwargs: Optional[dict[str, Any]] = None, arg_types: Optional[tuple[Any]] = None, kwarg_types: Optional[dict[str, Any]] = None,