Unify MYPYINDUCTOR and MYPY (#118432)

The original motivation for MYPYINDUCTOR was a faster type-checking configuration that only checked a subset of files. With the removal of `follow_imports = ignore`, we are now able to use dmypy to do fast incremental type checking, eliminating the need for a separate configuration.
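
As a rough illustration of what the incremental workflow looks like (the paths below are the trees this PR unifies, but the invocation is only a sketch; the real check goes through the repo's lintrunner adapters, not this script), dmypy keeps a long-lived daemon and reuses its cached analysis between runs, so re-checks after small edits are much faster than a cold `mypy` run:

```python
# Illustrative sketch only: driving an incremental dmypy check from Python.
import subprocess

# "dmypy run" starts the daemon if it is not already running; subsequent
# invocations reuse the daemon's cached analysis state, so only work for
# changed files needs to be redone.
result = subprocess.run(
    ["dmypy", "run", "--", "torch/_dynamo", "torch/_inductor"],
    capture_output=True,
    text=True,
)
print(result.stdout, end="")
```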

Perhaps erroneously, when I teed up this PR I elected to delete the `follow_imports = skip` designations in mypy-inductor.ini. This led to a number of extra type error suppressions that I manually edited. You will need to review them.
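
For reviewers, the suppressions added below take two forms, sketched here with invented names: a whole-module opt-out via a `# mypy: ignore-errors` header (used for files that previously sat behind `follow_imports = skip`, e.g. the dynamo backends/variables and the serialized pattern files), and targeted inline `# type: ignore[...]` comments on individual expressions:

```python
# mypy: ignore-errors
# ^ As the first line of a module, the directive above silences every error
#   mypy would report for this file, mirroring the blanket effect of the old
#   `follow_imports = skip` entries.  The function below is invented purely
#   to show the second style: a targeted ignore scoped to one line.
from typing import Any

def lookup(proxy: Any, key: int) -> Any:
    return proxy[key]  # type: ignore[index]
```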

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118432
Approved by: https://github.com/Skylion007
ghstack dependencies: #118414, #118418
This commit is contained in:
Edward Z. Yang 2024-01-27 06:23:25 -08:00 committed by PyTorch MergeBot
parent 42062e2622
commit d03173e88c
82 changed files with 314 additions and 291 deletions

View File

@ -150,34 +150,6 @@ init_command = [
'optree==0.10.0',
]
[[linter]]
code = 'MYPYINDUCTOR'
include_patterns = [
'torch/_dynamo/**/*.py',
'torch/_inductor/**/*.py',
]
exclude_patterns = [
'**/fb/**',
'torch/_dynamo/backends/**/*.py',
'torch/_dynamo/variables/**/*.py',
'torch/_dynamo/polyfill.py',
'torch/_inductor/fx_passes/serialized_patterns/**',
]
command = [
'python3',
'tools/linter/adapters/mypy_linter.py',
'--config=mypy-inductor.ini',
'--code=MYPYINDUCTOR',
'--',
'@{{PATHSFILE}}'
]
init_command = [
'python3',
'tools/linter/adapters/pip_init.py',
'--dry-run={{DRYRUN}}',
'types-colorama==0.4.6',
]
[[linter]]
code = 'MYPYSTRICT'
include_patterns = [

View File

@ -1,79 +0,0 @@
[mypy]
plugins = mypy_plugins/check_mypy_version.py, numpy.typing.mypy_plugin
cache_dir = .mypy_cache/inductor
allow_redefinition = True
warn_unused_configs = True
warn_redundant_casts = True
show_error_codes = True
show_column_numbers = True
check_untyped_defs = True
follow_imports = silent
# do not reenable this:
# https://github.com/pytorch/pytorch/pull/60006#issuecomment-866130657
warn_unused_ignores = False
disallow_any_generics = True
files =
torch/_dynamo,
torch/_inductor
# We access some Python runtime classes / class members that are only available
# in 3.11. These accesses are gated by runtime checks that cannot always be
# understood by mypy.
python_version = 3.11
[mypy-colorama.*]
ignore_missing_imports = True
[mypy-cutlass_library.*]
ignore_missing_imports = True
[mypy-deeplearning.*]
ignore_missing_imports = True
[mypy-dill.*]
ignore_missing_imports = True
[mypy-einops.*]
ignore_missing_imports = True
[mypy-libfb.*]
ignore_missing_imports = True
# sympy is too dynamic, hard to type properly
[mypy-sympy.*]
ignore_missing_imports = True
follow_imports = skip
[mypy-torch.*.fb.*]
ignore_missing_imports = True
# FIXME: importing this creates lots of type errors
[mypy-torch._dynamo.variables.*]
follow_imports = skip
[mypy-torch.fb.*]
ignore_missing_imports = True
# FIXME: importing this creates lots of type errors
[mypy-torch.fx.*]
follow_imports = skip
# FIXME: importing this creates lots of type errors
[mypy-torch.testing._internal.*]
follow_imports = skip
# sympy is too dynamic, hard to type properly
[mypy-torch.utils._sympy.*]
follow_imports = skip
[mypy-torch_xla.*]
ignore_missing_imports = True
[mypy-torchvision.*]
ignore_missing_imports = True
[mypy-triton.*]
ignore_missing_imports = True

View File

@ -178,6 +178,7 @@ ignore_missing_imports = True
[mypy-sympy.*]
ignore_missing_imports = True
follow_imports = skip
[mypy-hypothesis.*]
ignore_missing_imports = True
@ -270,7 +271,31 @@ ignore_missing_imports = True
ignore_missing_imports = True
[mypy-torch._inductor.*]
ignore_errors = True
disallow_any_generics = True
[mypy-torch._dynamo.*]
ignore_errors = True
disallow_any_generics = True
[mypy-colorama.*]
ignore_missing_imports = True
[mypy-cutlass_library.*]
ignore_missing_imports = True
[mypy-deeplearning.*]
ignore_missing_imports = True
[mypy-einops.*]
ignore_missing_imports = True
[mypy-libfb.*]
ignore_missing_imports = True
[mypy-torch.*.fb.*]
ignore_missing_imports = True
[mypy-torch.fb.*]
ignore_missing_imports = True
[mypy-torch_xla.*]
ignore_missing_imports = True

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
import contextlib
import functools
import logging

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
import logging
import operator
from collections import defaultdict

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
import dataclasses
import functools
from importlib import import_module

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
import logging
import traceback
from dataclasses import dataclass, field

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
from torch._dynamo import register_backend

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
# This backend is maintained by ONNX team. To direct issues
# to the right people, please tag related GitHub issues with `module: onnx`.
#

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
import functools
import sys
from typing import Callable, Dict, List, Optional, Protocol, Sequence, Tuple

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
# import torch # type: ignore[import]
# from .common import device_from_inputs, fake_tensor_unsupported # type: ignore[import]
# from .registry import register_backend # type: ignore[import]

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
import logging
import warnings

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
import functools
import importlib
import logging

View File

@ -99,7 +99,7 @@ class AutogradCompilerInstance:
backward_idx: int,
):
assert self.hooks_proxy is not None
backward_fn = self.hooks_proxy[backward_idx]
backward_fn = self.hooks_proxy[backward_idx] # type: ignore[index]
proxies = self.fx_tracer.create_proxy(
kind="call_function",
target=call_backward,
@ -139,7 +139,7 @@ class AutogradCompilerInstance:
def tensor_pre_hook(self, inputs, hook_id, i: int):
assert self.hooks_proxy is not None
hook = self.hooks_proxy[hook_id]
hook = self.hooks_proxy[hook_id] # type: ignore[index]
proxy = self.proxy_call_hook(
hook,
inputs[i],
@ -151,7 +151,7 @@ class AutogradCompilerInstance:
def pre_hook(self, inputs, hook_id):
assert self.hooks_proxy is not None
hook = self.hooks_proxy[hook_id]
hook = self.hooks_proxy[hook_id] # type: ignore[index]
proxies = self.proxy_call_hook(
hook,
inputs,
@ -163,7 +163,7 @@ class AutogradCompilerInstance:
def post_hook(self, outputs, inputs, hook_id):
assert self.hooks_proxy is not None
hook = self.hooks_proxy[hook_id]
hook = self.hooks_proxy[hook_id] # type: ignore[index]
proxies = self.proxy_call_hook(
hook,
outputs,
@ -177,7 +177,7 @@ class AutogradCompilerInstance:
def post_acc_grad_hook(self, input, hook_id):
assert isinstance(input, torch.Tensor)
assert self.hooks_proxy is not None
hook = self.hooks_proxy[hook_id]
hook = self.hooks_proxy[hook_id] # type: ignore[index]
proxies = self.proxy_call_hook(
hook,
input,

View File

@ -250,6 +250,7 @@ def generate_config_string(*, stable_output=False):
if stable_output:
return "# config omitted due to stable_output=True"
experimental_config = torch.fx.experimental._config.codegen_config() # type: ignore[attr-defined]
return f"""\
import torch._dynamo.config
import torch._inductor.config
@ -258,7 +259,7 @@ import torch.fx.experimental._config
{torch._dynamo.config.codegen_config()}
{torch._inductor.config.codegen_config()}
{torch._functorch.config.codegen_config()}
{torch.fx.experimental._config.codegen_config()}
{experimental_config}
"""

View File

@ -1398,6 +1398,7 @@ def export(
case_name="cond_operands",
)
assert graph is not None
for node in graph.graph.nodes:
if node.op == "get_attr" and isinstance(
getattr(graph, node.target), torch.Tensor
@ -1424,6 +1425,7 @@ def export(
flat_args_dynamic_dims,
)
# Store constraints and inputs as metadata for user passes, e.g. turn constraints to runtime check
assert graph is not None
graph.meta["input_shape_constraints"] = (
[constraint.serializable_spec for constraint in constraints]
if constraints

View File

@ -1127,8 +1127,8 @@ class OutputGraph(Checkpointable[OutputGraphState]):
# TODO: Why isn't this stored in meta :think:
pl._dynamo_source = arg.source
gm._param_name_to_source = self.param_name_to_source
gm._source_to_user_stacks = self.source_to_user_stacks
gm._param_name_to_source = self.param_name_to_source # type: ignore[assignment]
gm._source_to_user_stacks = self.source_to_user_stacks # type: ignore[assignment]
try:
name = (

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
"""
Python polyfills for common builtins.
"""

View File

@ -338,7 +338,7 @@ class SideEffects:
cg.extend_output(create_call_function(0, True))
cg.add_cache(var)
if isinstance(var.mutable_local, AttributeMutationNew):
var.mutable_local.source = LocalSource(cg.tempvars[var])
var.mutable_local.source = LocalSource(cg.tempvars[var]) # type: ignore[attr-defined]
elif isinstance(var.mutable_local, AttributeMutationNew):
if isinstance(var, variables.AutogradFunctionContextVariable):
unimplemented("AutogradFunctionContextVariable escaped")
@ -455,7 +455,7 @@ class SideEffects:
if isinstance(var, variables.ListVariable):
# old[:] = new
cg(var, allow_cache=False)
cg(var.mutable_local.source)
cg(var.mutable_local.source) # type: ignore[attr-defined]
cg.extend_output(
[
cg.create_load_const(None),
@ -468,11 +468,11 @@ class SideEffects:
cg.tx.output.update_co_names("clear")
cg.tx.output.update_co_names("update")
cg(var.mutable_local.source)
cg(var.mutable_local.source) # type: ignore[attr-defined]
cg.extend_output([create_instruction("LOAD_METHOD", argval="update")])
cg(var, allow_cache=False)
cg(var.mutable_local.source)
cg(var.mutable_local.source) # type: ignore[attr-defined]
cg.extend_output([create_instruction("LOAD_METHOD", argval="clear")])
suffixes.append(
@ -512,7 +512,7 @@ class SideEffects:
elif isinstance(var, variables.TupleIteratorVariable):
for _ in range(var.index):
cg.load_import_from(utils.__name__, "iter_next")
cg(var.mutable_local.source)
cg(var.mutable_local.source) # type: ignore[attr-defined]
cg.extend_output(create_call_function(1, True))
cg.append_output(create_instruction("POP_TOP"))
else:

View File

@ -436,7 +436,7 @@ def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool):
from .source import is_constant_source
if value.source is not None and is_constant_source(value.source):
if truth_fn(value.get_real_value()):
if truth_fn(value.get_real_value()): # type: ignore[attr-defined]
push and self.push(value)
self.jump(inst)
else:
@ -807,7 +807,7 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
assert val is None or isinstance(
val, VariableTracker
), f"push expects VariableTracker, got {typestr(val)}"
self.stack.append(val)
self.stack.append(val) # type: ignore[arg-type]
def push_many(self, vals: List[VariableTracker]):
for val in vals:
@ -921,7 +921,7 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
name = inst.argval
source = self.get_global_source(name)
if name not in self.symbolic_globals:
self.symbolic_globals[name] = object() # sentinel object
self.symbolic_globals[name] = object() # type: ignore[assignment] # sentinel object
variable = self.output.side_effects.track_global_existing(
source, self.symbolic_globals[name]
)
@ -1424,7 +1424,7 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
assert inst.argval > 0
obj = self.stack[-inst.arg].realize()
assert isinstance(obj, ConstDictVariable)
obj.call_method(self, "__setitem__", (k, v), {})
obj.call_method(self, "__setitem__", (k, v), {}) # type: ignore[arg-type]
def SET_ADD(self, inst):
v = self.pop()
@ -1452,8 +1452,8 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
if sys.version_info >= (3, 11):
# MAKE_FUNCTION behavior actually changed in 3.11, see
# https://github.com/python/cpython/pull/93189/
assert hasattr(code.value, "co_qualname")
fn_name = ConstantVariable.create(value=code.value.co_qualname)
assert hasattr(code.value, "co_qualname") # type: ignore[attr-defined]
fn_name = ConstantVariable.create(value=code.value.co_qualname) # type: ignore[attr-defined]
defaults = None
closure = None
annotations = None
@ -1676,8 +1676,8 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
tos1 = self.stack[-2]
assert isinstance(tos1, ConstDictVariable)
if all(k in tos1 for k in tos):
self.push(TupleVariable([tos1.getitem_const(k) for k in tos]))
if all(k in tos1 for k in tos): # type: ignore[attr-defined]
self.push(TupleVariable([tos1.getitem_const(k) for k in tos])) # type: ignore[attr-defined]
if sys.version_info < (3, 11):
self.push(ConstantVariable.create(True))
else:
@ -1750,7 +1750,7 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
for name in kw_names:
assert isinstance(name, str)
assert self.kw_names is None
self.kw_names = ConstantVariable.create(value=kw_names)
self.kw_names = ConstantVariable.create(value=kw_names) # type: ignore[assignment]
def PUSH_NULL(self, inst):
self.push(NullVariable())
@ -2405,7 +2405,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
closure_cells: Dict[str, VariableTracker],
funcvar: BaseUserFunctionVariable,
):
f_globals = funcvar.get_globals()
f_globals = funcvar.get_globals() # type: ignore[attr-defined]
f_builtins = f_globals["__builtins__"]
if not isinstance(f_builtins, dict):
f_builtins = f_builtins.__dict__
@ -2510,7 +2510,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
unimplemented("cant resume while inlining")
def RETURN_VALUE(self, inst):
self.symbolic_result = self.pop()
self.symbolic_result = self.pop() # type: ignore[assignment]
self.instruction_pointer = None

View File

@ -5,7 +5,7 @@ import sys
import torch
import torch.testing
from torch.testing._internal.common_utils import (
from torch.testing._internal.common_utils import ( # type: ignore[attr-defined]
IS_WINDOWS,
TEST_WITH_CROSSREF,
TEST_WITH_TORCHDYNAMO,
@ -44,14 +44,14 @@ def run_tests(needs=()):
class TestCase(TorchTestCase):
@classmethod
def tearDownClass(cls):
cls._exit_stack.close()
cls._exit_stack.close() # type: ignore[attr-defined]
super().tearDownClass()
@classmethod
def setUpClass(cls):
super().setUpClass()
cls._exit_stack = contextlib.ExitStack()
cls._exit_stack.enter_context(
cls._exit_stack = contextlib.ExitStack() # type: ignore[attr-defined]
cls._exit_stack.enter_context( # type: ignore[attr-defined]
config.patch(
raise_on_ctx_manager_usage=True,
suppress_errors=False,

View File

@ -43,12 +43,12 @@ class MinifierTestBase(torch._dynamo.test_case.TestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls._exit_stack.enter_context(
cls._exit_stack.enter_context( # type: ignore[attr-defined]
torch._dynamo.config.patch(debug_dir_root=cls.DEBUG_DIR)
)
# These configurations make new process startup slower. Disable them
# for the minification tests to speed them up.
cls._exit_stack.enter_context(
cls._exit_stack.enter_context( # type: ignore[attr-defined]
torch._inductor.config.patch(
{
# https://github.com/pytorch/pytorch/issues/100376
@ -67,7 +67,7 @@ class MinifierTestBase(torch._dynamo.test_case.TestCase):
shutil.rmtree(cls.DEBUG_DIR)
else:
print(f"test_minifier_common tmpdir kept at: {cls.DEBUG_DIR}")
cls._exit_stack.close()
cls._exit_stack.close() # type: ignore[attr-defined]
def _gen_codegen_fn_patch_code(self, device, bug_type):
assert bug_type in ("compile_error", "runtime_error", "accuracy")

View File

@ -2737,7 +2737,7 @@ Generate the torch object - Dynamo tracing rule (the wrapping variable) map.
def get_torch_obj_rule_map():
d: Dict[Any, VariableTracker] = dict()
for m in torch_name_rule_map:
for k, v in m.items():
for k, v in m.items(): # type: ignore[attr-defined]
obj = load_object(k)
if obj is not None:
if obj in d and d[obj] != v:

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
from .base import VariableTracker
from .builtin import BuiltinVariable
from .constant import ConstantVariable, EnumVariable

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
import collections
from enum import Enum
from typing import Any, Callable, Dict, List

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
import abc
import collections
import contextlib

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
import contextlib
import functools
import inspect

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
import operator
from typing import Dict, List

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
import dataclasses
import inspect
from typing import Callable, Dict, List, Optional

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
import collections
import dataclasses
import enum

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
import inspect
from typing import Dict, List

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
import functools
import inspect
import itertools

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
import contextlib
import functools
import logging

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
MAX_CYCLE = 3000
import itertools

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
import functools
from typing import Optional

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
import collections
import functools
import inspect

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
import collections
import dataclasses
import functools

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
import functools
import inspect
import itertools

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
import weakref
from typing import Dict, List

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
from inspect import getattr_static
from ..bytecode_transformation import create_call_function

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
import functools
import inspect

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
import functools
import inspect
import logging

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
import inspect
from typing import Dict, List

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
import collections
import contextlib
import functools

View File

@ -224,10 +224,10 @@ class DataTypePropagation:
"store_reduction",
):
buf_name = node.args[1]
return V.graph.get_dtype(buf_name)
return V.graph.get_dtype(buf_name) # type: ignore[arg-type]
if node.target == operator.getitem:
return self.deduce_node_dtype(node.args[0])
return self.deduce_node_dtype(node.args[0]) # type: ignore[arg-type]
assert isinstance(node.target, str)
@ -235,7 +235,7 @@ class DataTypePropagation:
return node.args[1]
if node.target == "constant":
return DTYPE_TO_COMPUTATION_DTYPE[node.args[-1]]
return DTYPE_TO_COMPUTATION_DTYPE[node.args[-1]] # type: ignore[index]
if node.target.startswith("masked_subblock"):
return self.deduce_node_dtype_by_subgraph(node)

View File

@ -2768,7 +2768,7 @@ class CppVecKernelChecker(CppVecKernel):
# Support masked_load for BF16/FP16. Because the legalization will
# insert to_dtype to convert the BF16/FP16 input to FP32.
dtype = (
V.graph.get_dtype(input_value.args[1])
V.graph.get_dtype(input_value.args[1]) # type: ignore[arg-type]
if input_value.target == "load"
else input_value.args[-1]
)
@ -2784,7 +2784,7 @@ class CppVecKernelChecker(CppVecKernel):
dtype in [torch.int32, torch.int64]
and input_value.target == "load"
):
buffer = V.graph.get_buffer(input_value.args[1])
buffer = V.graph.get_buffer(input_value.args[1]) # type: ignore[arg-type]
# Check if load of a scalar tensor of integer
if not (
isinstance(buffer, TensorBox)
@ -2907,14 +2907,14 @@ class CppKernelProxy(CppKernel):
if node.target not in ["load"]:
return False
assert len(node.args) == 3
load_dtype = V.graph.get_dtype(node.args[1])
load_dtype = V.graph.get_dtype(node.args[1]) # type: ignore[arg-type]
return load_dtype in DTYPE_LOWP_FP
def is_lowp_fp_store(node: torch.fx.Node):
if node.target != "store":
return False
_, store_var, _, _, _ = node.args
store_dtype = V.graph.get_dtype(store_var)
store_dtype = V.graph.get_dtype(store_var) # type: ignore[arg-type]
return store_dtype in DTYPE_LOWP_FP
sub_graph_nodes = list(sub_graph.nodes)

View File

@ -80,7 +80,7 @@ def draw_buffers(nodes: List[BaseSchedulerNode], print_graph=False, fname=None):
if isinstance(node, ir.ComputedBuffer):
dtype = node.data.dtype
metadata = TensorMetadata(group, dtype, None, None, None, None, None)
metadata = TensorMetadata(group, dtype, None, None, None, None, None) # type: ignore[arg-type]
node.meta["tensor_meta"] = metadata
if print_graph:

View File

@ -89,7 +89,7 @@ def efficient_conv_bn_eval_graph_transform(match: Match, *args, **kwargs):
bn_node = match.nodes[0]
graph = match.graph
gm = graph.owning_module
bn_mod = getattr(gm, bn_node.target)
bn_mod = getattr(gm, bn_node.target) # type: ignore[arg-type]
# We can only use efficient conv-bn for eval mode with track_running_stats
if not bn_mod.track_running_stats or bn_mod.training:
@ -100,11 +100,11 @@ def efficient_conv_bn_eval_graph_transform(match: Match, *args, **kwargs):
input_node = bn_node.args[0]
else:
input_node = bn_node.kwargs["input"]
if input_node.op != "call_module":
if input_node.op != "call_module": # type: ignore[union-attr]
return
if not hasattr(gm, input_node.target):
if not hasattr(gm, input_node.target): # type: ignore[arg-type, union-attr]
return
input_mod = getattr(gm, input_node.target)
input_mod = getattr(gm, input_node.target) # type: ignore[arg-type, union-attr]
supported_convs = [
nn.Linear,
nn.Conv1d,
@ -118,7 +118,7 @@ def efficient_conv_bn_eval_graph_transform(match: Match, *args, **kwargs):
return
conv_node = input_node
# Output of conv is used by other nodes, cannot optimize
if len(conv_node.users) > 1:
if len(conv_node.users) > 1: # type: ignore[union-attr]
return
# Find a pair of conv and bn computation nodes to optimize.
@ -130,15 +130,15 @@ def efficient_conv_bn_eval_graph_transform(match: Match, *args, **kwargs):
# argument. `graph.get_attr` and
# `graph.call_function` does not allow the `name` argument.
conv_get_node = graph.create_node(
op="get_attr", target=conv_node.target, name="get_conv"
op="get_attr", target=conv_node.target, name="get_conv" # type: ignore[union-attr]
)
bn_get_node = graph.create_node(
op="get_attr", target=bn_node.target, name="get_bn"
)
if conv_node.args:
conv_input = conv_node.args[0]
if conv_node.args: # type: ignore[union-attr]
conv_input = conv_node.args[0] # type: ignore[union-attr]
else:
conv_input = conv_node.kwargs["input"]
conv_input = conv_node.kwargs["input"] # type: ignore[union-attr]
# prepare args for the fused function
args = (bn_get_node, conv_get_node, conv_input)
# create a new node

View File

@ -50,7 +50,7 @@ def freezing_passes(gm: torch.fx.GraphModule, aot_example_inputs):
constant_fold(gm)
# Make sure meta['val'] is properly set for all nodes
fake_tensor_prop(gm, aot_example_inputs, True)
binary_folding_pass.apply(gm.graph)
binary_folding_pass.apply(gm.graph) # type: ignore[arg-type]
# If we don't have binary folding, we don't need to run the pass again.
# TODO: remove the need to run fake_tensor_prop on the whole model.
if counters["inductor"]["binary_folding"] == binary_folding:
@ -63,7 +63,7 @@ def freezing_passes(gm: torch.fx.GraphModule, aot_example_inputs):
fake_tensor_prop(gm, aot_example_inputs, True)
for pattern in pass_patterns:
pattern.apply(gm.graph)
pattern.apply(gm.graph) # type: ignore[arg-type]
# The CPU weight packing always assume the conv's weight is channels last,
# So make sure the layout_optimization is on when doing it.

View File

@ -465,7 +465,7 @@ def _sfdp_params_check(match):
# attn_mask_node may be a float/int number.
if not hasattr(attn_mask_node, "meta"):
return False
attn_mask = attn_mask_node.meta["val"]
attn_mask = attn_mask_node.meta["val"] # type: ignore[union-attr]
# Make sure attn_mask.dtype == query.dtype or attn_mask.dtype == torch.bool
if (
not isinstance(attn_mask, torch.Tensor)

View File

@ -126,7 +126,7 @@ class PostGradBatchLinearFusion(BatchFusion):
def _addmm_node_can_be_fused(self, node: torch.fx.Node) -> bool:
return (
node.kwargs.get("beta", 1.0) == 1.0 and node.kwargs.get("alpha", 1.0) == 1.0
node.kwargs.get("beta", 1.0) == 1.0 and node.kwargs.get("alpha", 1.0) == 1.0 # type: ignore[return-value]
)
def _is_input_2d(self, input: torch.fx.Node) -> bool:
@ -150,10 +150,10 @@ class PostGradBatchLinearFusion(BatchFusion):
return None
# only handle the cases where inputs are 2D tensors
if not self._is_input_2d(input_m) or not self._is_input_2d(weight_m):
if not self._is_input_2d(input_m) or not self._is_input_2d(weight_m): # type: ignore[arg-type]
return None
m, k = input_m.meta["tensor_meta"].shape
n = weight_m.meta["tensor_meta"].shape[1]
m, k = input_m.meta["tensor_meta"].shape # type: ignore[union-attr]
n = weight_m.meta["tensor_meta"].shape[1] # type: ignore[union-attr]
batch_key = ("batch_linear", m, k, n, bias_m is not None)
return batch_key
@ -200,8 +200,8 @@ class PostGradBatchLinearFusion(BatchFusion):
@register_fusion("group_linear", pre_grad=False)
class GroupLinearFusion(GroupFusion):
def _addmm_node_can_be_fused(self, node: torch.fx.Node):
input_shape = node.args[1].meta["tensor_meta"].shape
weight_shape = node.args[2].meta["tensor_meta"].shape
input_shape = node.args[1].meta["tensor_meta"].shape # type: ignore[union-attr]
weight_shape = node.args[2].meta["tensor_meta"].shape # type: ignore[union-attr]
return (
node.kwargs.get("beta", 1.0) == 1.0
and node.kwargs.get("alpha", 1.0) == 1.0
@ -215,8 +215,8 @@ class GroupLinearFusion(GroupFusion):
)
def _mm_node_can_be_fused(self, node: torch.fx.Node):
input_shape = node.args[0].meta["tensor_meta"].shape
weight_shape = node.args[1].meta["tensor_meta"].shape
input_shape = node.args[0].meta["tensor_meta"].shape # type: ignore[union-attr]
weight_shape = node.args[1].meta["tensor_meta"].shape # type: ignore[union-attr]
return (
len(input_shape) == 2
and len(weight_shape) == 2
@ -295,11 +295,11 @@ class BatchPointwiseOpsPostGradFusion(BatchPointwiseOpsFusionFactory):
# its inputs, and cause dtype not same error in mm or addmm
input, other = node.args
return (
input.meta["tensor_meta"].shape == other.meta["tensor_meta"].shape
input.meta["tensor_meta"].shape == other.meta["tensor_meta"].shape # type: ignore[union-attr]
if hasattr(input, "meta")
and hasattr(other, "meta")
and "tensor_meta" in input.meta
and "tensor_meta" in other.meta
and "tensor_meta" in input.meta # type: ignore[union-attr]
and "tensor_meta" in other.meta # type: ignore[union-attr]
else False
)
@ -310,12 +310,12 @@ class BatchPointwiseOpsPostGradFusion(BatchPointwiseOpsFusionFactory):
alpha = node.kwargs.get("alpha", 1.0)
rounding_mode = node.kwargs.get("rounding_mode", None)
input, other = node.args
shape = list(input.meta["tensor_meta"].shape)
shape = list(input.meta["tensor_meta"].shape) # type: ignore[union-attr]
group_key = (
"batch_" + self.op.__name__.lower() + "_post_grad",
str(shape),
str(input.meta["tensor_meta"].dtype),
str(other.meta["tensor_meta"].dtype),
str(input.meta["tensor_meta"].dtype), # type: ignore[union-attr]
str(other.meta["tensor_meta"].dtype), # type: ignore[union-attr]
str(alpha),
str(rounding_mode),
)
@ -827,7 +827,7 @@ def get_fusion_candidates(
def apply_group_batch_fusion(graph: torch.fx.GraphModule, rule: GroupBatchFusionBase):
stable_topological_sort(graph)
stable_topological_sort(graph) # type: ignore[arg-type]
fused_set: Set[torch.fx.Node] = set()
for node in reversed(graph.nodes):
@ -893,5 +893,5 @@ def group_batch_fusion_passes(graph: torch.fx.Graph, pre_grad=True):
fusions += generate_fusion_from_config(fbgemm_fusions, pre_grad=False)
for rule in fusions:
apply_group_batch_fusion(graph, rule)
apply_group_batch_fusion(graph, rule) # type: ignore[arg-type]
print_graph(graph, f"Apply fusion {rule.__class__.__name__}.")

View File

@ -271,7 +271,7 @@ def joint_graph_passes(graph: torch.fx.GraphModule):
constant_fold_uniform_value(graph)
if config.pattern_matcher:
count += patterns.apply(graph.graph)
count += patterns.apply(graph.graph) # type: ignore[arg-type]
if not config.fallback_random:
count += replace_random_passes(graph)
@ -317,7 +317,7 @@ def pointless_view(match: Match, arg, size):
"""Remove no-op view"""
graph = match.graph
node = match.output_node()
arg_size = list(node.args[0].meta["val"].shape)
arg_size = list(node.args[0].meta["val"].shape) # type: ignore[union-attr]
if size == arg_size:
node.replace_all_uses_with(node.args[0])
match.erase_nodes(graph)

View File

@ -343,11 +343,11 @@ if torch._C._has_mkldnn:
if any(
not (
hasattr(n.args[0], "meta")
and isinstance(n.args[0].meta.get("val", None), torch.Tensor)
and isinstance(n.args[0].meta.get("val", None), torch.Tensor) # type: ignore[union-attr]
)
or not (
hasattr(n.args[1], "meta")
and isinstance(n.args[1].meta.get("val", None), torch.Tensor)
and isinstance(n.args[1].meta.get("val", None), torch.Tensor) # type: ignore[union-attr]
)
for n in binary_nodes
):
@ -360,9 +360,9 @@ if torch._C._has_mkldnn:
):
return False
if any(
n.args[0].meta["val"].size() != n.args[1].meta["val"].size()
or n.args[0].meta["val"].device != n.args[1].meta["val"].device
or n.args[0].meta["val"].dtype != n.args[1].meta["val"].dtype
n.args[0].meta["val"].size() != n.args[1].meta["val"].size() # type: ignore[union-attr]
or n.args[0].meta["val"].device != n.args[1].meta["val"].device # type: ignore[union-attr]
or n.args[0].meta["val"].dtype != n.args[1].meta["val"].dtype # type: ignore[union-attr]
for n in binary_nodes
):
return False

View File

@ -85,13 +85,13 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
remove_noop_ops(gm.graph)
print_graph(gm.graph, "Before split cat in post grad pass.")
for patterns in pass_patterns:
patterns.apply(gm.graph)
patterns.apply(gm.graph) # type: ignore[arg-type]
print_graph(
gm.graph,
"Apply split cat pattern matcher PatternMatcherPass in post grad.",
)
if is_inference:
inference_patterns.apply(gm.graph)
inference_patterns.apply(gm.graph) # type: ignore[arg-type]
if config.post_grad_custom_post_pass is not None:
config.post_grad_custom_post_pass(gm.graph)

View File

@ -74,7 +74,7 @@ def pre_grad_passes(gm: torch.fx.GraphModule, example_inputs):
# explicitly run with predispatch atenIR based passes
if config.is_predispatch:
group_batch_fusion_passes(gm.graph, pre_grad=True)
predispatch_pass.apply(gm.graph)
predispatch_pass.apply(gm.graph) # type: ignore[arg-type]
else:
gm = fuse_fx(gm, example_inputs)
numpy_compat_normalization(gm.graph)
@ -82,7 +82,7 @@ def pre_grad_passes(gm: torch.fx.GraphModule, example_inputs):
group_batch_fusion_passes(gm.graph, pre_grad=True)
print_graph(gm.graph, "Before split cat in pre grad pass.")
for pattern_matcher_pass in pattern_matcher_passes:
pattern_matcher_pass.apply(gm.graph)
pattern_matcher_pass.apply(gm.graph) # type: ignore[arg-type]
print_graph(
gm.graph,
"Apply split cat pattern matcher PatternMatcherPass in pre grad.",
@ -245,21 +245,21 @@ class NormalizedLinearNode:
def get_input(self) -> torch.fx.Node:
if len(self.node.args) > 0:
return self.node.args[0]
return self.node.args[0] # type: ignore[return-value]
else:
return self.node.kwargs["input"]
return self.node.kwargs["input"] # type: ignore[return-value]
def get_weight(self) -> torch.fx.Node:
if len(self.node.args) > 1:
return self.node.args[1]
return self.node.args[1] # type: ignore[return-value]
else:
return self.node.kwargs["weight"]
return self.node.kwargs["weight"] # type: ignore[return-value]
def get_bias(self) -> torch.fx.Node:
if len(self.node.args) > 2:
return self.node.args[2]
return self.node.args[2] # type: ignore[return-value]
else:
return self.node.kwargs["bias"] if "bias" in self.node.kwargs else None
return self.node.kwargs["bias"] if "bias" in self.node.kwargs else None # type: ignore[return-value]
class NormalizedMatmulNode:
@ -270,27 +270,27 @@ class NormalizedMatmulNode:
def get_input(self) -> torch.fx.Node:
if len(self.node.args) > 0:
return self.node.args[0]
return self.node.args[0] # type: ignore[return-value]
else:
return self.node.kwargs["input"]
return self.node.kwargs["input"] # type: ignore[return-value]
def get_other(self) -> torch.fx.Node:
if len(self.node.args) > 1:
return self.node.args[1]
return self.node.args[1] # type: ignore[return-value]
else:
return self.node.kwargs["other"]
return self.node.kwargs["other"] # type: ignore[return-value]
def check_permute(node: torch.fx.Node) -> bool:
ranks = len(node.meta["tensor_meta"].shape)
if len(node.args) > 3:
permutation = [node.args[i] % ranks for i in range(1, ranks + 1)]
permutation = [node.args[i] % ranks for i in range(1, ranks + 1)] # type: ignore[operator]
elif (
"permutation" in node.kwargs
and node.kwargs["permutation"] is not None
and len(node.kwargs["permutation"]) > 2
and len(node.kwargs["permutation"]) > 2 # type: ignore[arg-type]
):
permutation = [i % ranks for i in node.kwargs["permutation"]]
permutation = [i % ranks for i in node.kwargs["permutation"]] # type: ignore[union-attr]
else:
return False
allowed_permutation = list(range(ranks))
@ -443,9 +443,9 @@ def permute_matmul_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
):
Atrans = True
if len(input_A_node.args) > 0:
input_A = input_A_node.args[0]
input_A = input_A_node.args[0] # type: ignore[assignment]
else:
input_A = input_A_node.kwargs["input"]
input_A = input_A_node.kwargs["input"] # type: ignore[assignment]
if (
input_B_node.op == "call_method"
@ -454,9 +454,9 @@ def permute_matmul_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
):
Btrans = True
if len(input_B_node.args) > 0:
input_B = input_B_node.args[0]
input_B = input_B_node.args[0] # type: ignore[assignment]
else:
input_B = input_B_node.kwargs["input"]
input_B = input_B_node.kwargs["input"] # type: ignore[assignment]
if Atrans or Btrans:
with module.graph.inserting_before(node):

View File

@ -484,16 +484,16 @@ def _is_valid_quantized_conv_binary_optimization_pattern(output_dtype):
# the two inputs of binary node should have attribute "meta" and should be tensors
if not (
hasattr(binary_node_inputs[0], "meta")
and isinstance(binary_node_inputs[0].meta.get("val", None), torch.Tensor)
and isinstance(binary_node_inputs[0].meta.get("val", None), torch.Tensor) # type: ignore[union-attr]
) or not (
hasattr(binary_node_inputs[1], "meta")
and isinstance(binary_node_inputs[1].meta.get("val", None), torch.Tensor)
and isinstance(binary_node_inputs[1].meta.get("val", None), torch.Tensor) # type: ignore[union-attr]
):
return False
# the two inputs of binary node should have the same shape
if (
binary_node_inputs[0].meta["val"].size()
!= binary_node_inputs[1].meta["val"].size()
binary_node_inputs[0].meta["val"].size() # type: ignore[union-attr]
!= binary_node_inputs[1].meta["val"].size() # type: ignore[union-attr]
):
return False
@ -951,12 +951,12 @@ def _is_input_output_same_scale_zp(check_node):
scales = [
(
mul_node.args[1]
if mul_node.args[0].target is check_node
else 1.0 / mul_node.args[1]
if mul_node.args[0].target is check_node # type: ignore[union-attr]
else 1.0 / mul_node.args[1] # type: ignore[operator]
)
for mul_node in mul_nodes
]
if not all(math.isclose(scale, scales[0], rel_tol=1e-5) for scale in scales):
if not all(math.isclose(scale, scales[0], rel_tol=1e-5) for scale in scales): # type: ignore[arg-type]
return False
return True
@ -1178,7 +1178,7 @@ def _register_dequant_promotion_pass(pattern, pass_number, dtype=torch.float32):
_user_node = user_node
while _source_node != dequant_pattern_start_node.args[0]:
_user_node = clone_to_new_node(graph, _source_node, _user_node)
_source_node = _source_node.args[0]
_source_node = _source_node.args[0] # type: ignore[assignment]
counters["inductor"]["dequant_promotion_matcher_count"] += 1
counters["inductor"]["dequant_promotion_matcher_nodes"] += len(match.nodes)
@ -1256,11 +1256,11 @@ def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float3
mul_node = conv_node.args[0]
else:
convert_to_bf16 = conv_node.args[0]
mul_node = convert_to_bf16.args[0]
sub_node = mul_node.args[0]
to_fp32_node = sub_node.args[0]
mul_node = convert_to_bf16.args[0] # type: ignore[union-attr]
sub_node = mul_node.args[0] # type: ignore[union-attr]
to_fp32_node = sub_node.args[0] # type: ignore[union-attr]
has_clone_to_channel_last_node_in_pattern = (
conv_node.args[1].target is aten.clone.default
conv_node.args[1].target is aten.clone.default # type: ignore[union-attr]
)
clone_node = (
conv_node.args[1] if has_clone_to_channel_last_node_in_pattern else None
@ -1268,20 +1268,20 @@ def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float3
if dtype == torch.float32:
dequant_per_channel = (
clone_node.args[0]
clone_node.args[0] # type: ignore[union-attr]
if has_clone_to_channel_last_node_in_pattern
else conv_node.args[1]
)
else:
weight_to_bf16_node = (
clone_node.args[0]
clone_node.args[0] # type: ignore[union-attr]
if has_clone_to_channel_last_node_in_pattern
else conv_node.args[1]
)
dequant_per_channel = weight_to_bf16_node.args[0]
dequant_per_channel = weight_to_bf16_node.args[0] # type: ignore[union-attr]
assert (
dequant_per_channel.target
dequant_per_channel.target # type: ignore[union-attr]
is quantized_decomposed.dequantize_per_channel.default
)

View File

@ -87,7 +87,7 @@ def _decompose_scatter_functional_helper(
view = graph_call_function(
graph, view_op.target, inp, *view_op.args, **view_op.kwargs
)
src = _decompose_scatter_functional_helper(graph, view, src, view_ops[1:])
src = _decompose_scatter_functional_helper(graph, view, src, view_ops[1:]) # type: ignore[assignment]
return graph_call_function(
graph,
@ -114,7 +114,7 @@ def _decompose_scatter_functional(
"""
assert node.target is _generalized_scatter
inp, src, view_ops = node.args
return _decompose_scatter_functional_helper(graph, *node.args)
return _decompose_scatter_functional_helper(graph, *node.args) # type: ignore[arg-type]
def _decompose_scatter_mutating(
@ -140,11 +140,11 @@ def _decompose_scatter_mutating(
inp = graph_call_function(graph, aten.clone, inp)
tmp = inp
for view in view_ops:
tmp = graph_call_function(graph, view.target, tmp, *view.args, **view.kwargs)
for view in view_ops: # type: ignore[union-attr]
tmp = graph_call_function(graph, view.target, tmp, *view.args, **view.kwargs) # type: ignore[union-attr]
graph_call_function(graph, aten.copy_.default, tmp, src)
return inp
return inp # type: ignore[return-value]
# View ops whose view_scatter op is lowered into mutations anyway,
@ -157,7 +157,7 @@ _ALWAYS_MUTATING_SCATTER_OPS = {
def scatter_always_uses_mutation(node: torch.fx.Node) -> bool:
_, _, view_ops = node.args
return any(view.target in _ALWAYS_MUTATING_SCATTER_OPS for view in view_ops)
return any(view.target in _ALWAYS_MUTATING_SCATTER_OPS for view in view_ops) # type: ignore[union-attr]
def should_reinplace_scatter(node: torch.fx.Node) -> bool:
@ -174,12 +174,12 @@ def should_reinplace_scatter(node: torch.fx.Node) -> bool:
if scatter_always_uses_mutation(node):
return True
if is_node_realized(inp) and is_node_realized(node):
if is_node_realized(inp) and is_node_realized(node): # type: ignore[arg-type]
return True
# If the output is copied back into the input, this forces both to be
# realized as the output is a user of the input
if inp.op == "placeholder" and any(
if inp.op == "placeholder" and any( # type: ignore[union-attr]
user.target is aten.copy_.default and user.args[0] is inp for user in node.users
):
return True
@ -234,11 +234,11 @@ def canonicalize_view_scatter_ops(graph: torch.fx.Graph) -> None:
def handle_views(node: torch.fx.Node):
inp = node.args[0]
node_to_view_base[node] = node_to_view_base.get(inp, inp)
node_to_view_base[node] = node_to_view_base.get(inp, inp) # type: ignore[arg-type]
node_to_view_op[node] = [
*node_to_view_op[inp],
*node_to_view_op[inp], # type: ignore[index]
ViewOp(
node.target,
node.target, # type: ignore[arg-type]
args=node.args[1:],
kwargs=node.kwargs,
),
@ -255,14 +255,14 @@ def canonicalize_view_scatter_ops(graph: torch.fx.Graph) -> None:
)
def can_fuse():
if src.target is not _generalized_scatter:
if src.target is not _generalized_scatter: # type: ignore[union-attr]
return False
src_inp, src_src, src_scatter_view_op = src.args
src_inp, src_src, src_scatter_view_op = src.args # type: ignore[union-attr]
inp_base = node_to_view_base.get(inp, inp)
src_base = node_to_view_base.get(src_inp, src_inp)
return inp_base is src_base and node_to_view_op[src_inp] == [
*node_to_view_op[inp],
inp_base = node_to_view_base.get(inp, inp) # type: ignore[arg-type]
src_base = node_to_view_base.get(src_inp, src_inp) # type: ignore[arg-type]
return inp_base is src_base and node_to_view_op[src_inp] == [ # type: ignore[index]
*node_to_view_op[inp], # type: ignore[index]
scatter_view_op,
]
@ -279,19 +279,19 @@ def canonicalize_view_scatter_ops(graph: torch.fx.Graph) -> None:
graph.erase_node(node)
return
src_inp, src_src, src_scatter_view_op = src.args
src_inp, src_src, src_scatter_view_op = src.args # type: ignore[union-attr]
with graph.inserting_before(src):
new_node = graph_call_function(
graph,
_generalized_scatter,
inp,
src_src,
[scatter_view_op, *src_scatter_view_op],
[scatter_view_op, *src_scatter_view_op], # type: ignore[misc]
)
node.replace_all_uses_with(new_node)
graph.erase_node(node)
if src.users:
if src.users: # type: ignore[union-attr]
new_src = graph_call_function(
graph,
_SCATTER_OP_TO_VIEW[node.target],
@ -301,7 +301,7 @@ def canonicalize_view_scatter_ops(graph: torch.fx.Graph) -> None:
)
handle_views(new_src)
src.replace_all_uses_with(new_src)
src.replace_all_uses_with(new_src) # type: ignore[union-attr]
graph.erase_node(src)

View File

@ -112,7 +112,9 @@ def replace_random(
mode = {
aten.rand: "rand",
aten.randn: "randn",
}[match.output_node().target.overloadpacket]
}[
match.output_node().target.overloadpacket # type: ignore[union-attr]
] # type: ignore[union-attr]
device = get_device(device)
match.replace_by_example(replacement, [size])

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:
# cd ~/pytorch && python

View File

@ -322,7 +322,7 @@ class TorchSplit(CallFunction):
return FailedMatch("only integer getitems are handled")
if user.args[1] in seen_idxs:
return FailedMatch(f"duplicate getitem {user.args[1]}")
if user.args[-1] < 0:
if user.args[-1] < 0: # type: ignore[operator]
# This shouldn't ideally happen as dynamo normalizes indexes to positive
return FailedMatch("negative index")
seen_idxs.add(user.args[1])
@ -358,13 +358,13 @@ def merge_splits(
if len(node.users.keys()) == 0:
return
graph = match.graph
first_split = node.args[0].args[0]
next_split_index = node.args[0].args[1]
first_split = node.args[0].args[0] # type: ignore[union-attr]
next_split_index = node.args[0].args[1] # type: ignore[union-attr]
new_split_sections = list(first_split_sections)
new_split_sections[next_split_index : next_split_index + 1] = next_split_sections
new_split_sections[next_split_index : next_split_index + 1] = next_split_sections # type: ignore[operator, misc]
first_split_dim = first_split.kwargs["dim"]
first_split_dim = first_split.kwargs["dim"] # type: ignore[union-attr]
to_remove = []
@ -376,7 +376,7 @@ def merge_splits(
kwargs={"dim": first_split_dim},
)
first_split_num_to_user = {
user.args[1]: user for user in first_split.users.keys()
user.args[1]: user for user in first_split.users.keys() # type: ignore[union-attr]
}
new_split_num = 0
@ -406,7 +406,7 @@ def merge_splits(
to_remove.append(node)
to_remove.append(old_getitem)
to_remove.append(first_split)
to_remove.append(first_split) # type: ignore[arg-type]
for node in to_remove:
graph.erase_node(node)
@ -463,9 +463,9 @@ class SplitCatSimplifier:
graph, split_node, split_sections, user_inputs_list, simplified_split_ranges
)
self.replace_cat(
graph, split_node, next_users, user_inputs_list_new, transform_params_list
graph, split_node, next_users, user_inputs_list_new, transform_params_list # type: ignore[arg-type]
)
self.erase_old_nodes(graph, split_node, next_users)
self.erase_old_nodes(graph, split_node, next_users) # type: ignore[arg-type]
def get_user_input_list(
self, split_node: torch.fx.Node, next_users: List[torch.fx.Node]
@ -481,7 +481,7 @@ class SplitCatSimplifier:
if user.target in {torch.cat, torch.stack}:
user_inputs_list.append(self.get_merged_user_inputs(split_node, user))
else:
user_inputs_list.append(self.get_non_cat_node_input(split_node, user))
user_inputs_list.append(self.get_non_cat_node_input(split_node, user)) # type: ignore[arg-type]
return user_inputs_list
def get_merged_user_inputs(
@ -536,10 +536,10 @@ class SplitCatSimplifier:
if cur_range:
merged_ranges.append(tuple(cur_range))
cur_range = None
merged_ranges.append(input_)
merged_ranges.append(input_) # type: ignore[arg-type]
if cur_range:
merged_ranges.append(tuple(cur_range))
return merged_ranges
return merged_ranges # type: ignore[return-value]
def get_simplified_split_ranges(
self,
@ -618,7 +618,7 @@ class SplitCatSimplifier:
transform_params.append((None, None, None, None))
elif isinstance(user_input, tuple): # Split being simplified
# Verify equal split
subset_split_sections = split_sections[
subset_split_sections = split_sections[ # type: ignore[index]
user_input[0] : user_input[1] + 1
]
# All sections should be equal
@ -698,7 +698,7 @@ class SplitCatSimplifier:
else:
new_user_inputs.append(user_input)
new_user_inputs_list.append(new_user_inputs)
return new_user_inputs_list
return new_user_inputs_list # type: ignore[return-value]
def replace_cat(
self,
@ -837,10 +837,10 @@ class UnbindCatRemover(SplitCatSimplifier):
graph: torch.fx.Graph,
unbind_node: torch.fx.Node,
):
num_unbind = (
max(getitem_node.args[1] for getitem_node in unbind_node.users.keys()) + 1
num_unbind = ( # type: ignore[operator]
max(getitem_node.args[1] for getitem_node in unbind_node.users.keys()) + 1 # type: ignore[operator, union-attr, type-var]
)
split_sections = [1 for _ in range(num_unbind)]
split_sections = [1 for _ in range(num_unbind)] # type: ignore[operator, arg-type]
super().simplify(graph, unbind_node, split_sections)
@ -1115,11 +1115,11 @@ def safe_to_abort_node(node: torch.fx.Node):
2. the user of all the input nodes should be only one
"""
prev_node = None
for arg in node.args[0]:
if len(arg.users) != 1 or arg.target != operator.getitem:
for arg in node.args[0]: # type: ignore[union-attr]
if len(arg.users) != 1 or arg.target != operator.getitem: # type: ignore[union-attr]
return False
if prev_node is None:
prev_node = arg.args[0]
prev_node = arg.args[0] # type: ignore[union-attr]
else:
if arg.args[0] != prev_node:
return False
@ -1173,8 +1173,8 @@ def merge_getitem_cat(match: Match, split_sections: List[int], dim: int):
continue
# find the index of getitems to be cated/stacked
indices = []
for arg in cat_user.args[0]:
indices.append(arg.args[1])
for arg in cat_user.args[0]: # type: ignore[union-attr]
indices.append(arg.args[1]) # type: ignore[union-attr]
# indices may not be necessarily sorted, we sort them first
indices.sort()
# the gettitems to be merged must be consecutive, otherwise
@ -1182,12 +1182,12 @@ def merge_getitem_cat(match: Match, split_sections: List[int], dim: int):
if indices[len(indices) - 1] - indices[0] + 1 != len(indices):
continue
# update the arg of cat user, only keep the first getitem
cat_user.update_arg(0, cat_user.args[0][0])
cat_user.update_arg(0, cat_user.args[0][0]) # type: ignore[index]
# calculate the fused tensor sizes in the indices
fused_tensor_size = 0
for i in range(len(split_node.args[1])):
for i in range(len(split_node.args[1])): # type: ignore[arg-type]
if i in indices:
fused_tensor_size += split_node.args[1][i]
fused_tensor_size += split_node.args[1][i] # type: ignore[operator, assignment, index]
# update the split sections
split_sections[indices[0]] = fused_tensor_size
# padding others with zeros to keep the same dict size
@ -1320,9 +1320,9 @@ def merge_stack_tahn_unbind(match: Match, split_sections: List[int], dim: int):
# find the index of getitems to be stacked
indices = []
split_sections_for_unbind = []
for arg in user.args[0]:
indices.append(arg.args[1])
split_sections_for_unbind.append(split_sections[arg.args[1]])
for arg in user.args[0]: # type: ignore[union-attr]
indices.append(arg.args[1]) # type: ignore[union-attr]
split_sections_for_unbind.append(split_sections[arg.args[1]]) # type: ignore[union-attr]
# indices may not be necessarily sorted, we sort them first
indices.sort()
# the gettitems to be merged must be consecutive, otherwise
@ -1330,12 +1330,12 @@ def merge_stack_tahn_unbind(match: Match, split_sections: List[int], dim: int):
if indices[len(indices) - 1] - indices[0] + 1 != len(indices):
continue
# update the arg of stack user, only keep the first getitem
user.update_arg(0, user.args[0][0])
user.update_arg(0, user.args[0][0]) # type: ignore[index]
# calculate the fused tensor sizes in the indices
fused_tensor_size = 0
for i in range(len(split_node.args[1])):
for i in range(len(split_node.args[1])): # type: ignore[arg-type]
if i in indices:
fused_tensor_size += split_node.args[1][i]
fused_tensor_size += split_node.args[1][i] # type: ignore[operator, index, assignment]
# update the split sections
split_sections[indices[0]] = fused_tensor_size
# padding others with zeros to keep the same dict size

View File

@ -204,7 +204,7 @@ def is_node_realized(node: torch.fx.Node) -> bool:
# getitem = foo[0]
# getitem_1 = foo[1]
# where we need to check if foo is a fallback kernel
return is_buffer(node.args[0])
return is_buffer(node.args[0]) # type: ignore[arg-type]
return node.op in ("placeholder", "output") or node.target in fallbacks
if is_buffer(node):

View File

@ -6578,7 +6578,7 @@ class InterpreterShim(torch.fx.Interpreter):
# call super() with a placeholder to avoid constructing a
# GraphModule which is very expensive (it does codegen).
super().__init__(self._dummy_gm(), garbage_collect_values=False)
self.module = self
self.module = self # type: ignore[assignment]
self.graph = graph
self.submodules = submodules
self.extra_traceback = False

View File

@ -3769,7 +3769,7 @@ def pooling_size(x, i, kernel_size, stride, padding, ceil_mode):
)
if V.graph.sizevars.size_hint((x_alt - 1) * stride[i] - x - padding[i]) >= 0:
# Sliding windows must start within the input or left padding
x_alt -= 1
x_alt -= 1 # type: ignore[assignment]
V.graph.sizevars.guard_leq(0, x_alt * stride[i] - x - padding[i])
if V.graph.sizevars.size_hint(x_out - x_alt) == 0:
# ceil mode is actually a no-op, lets guard on that

View File

@ -551,7 +551,7 @@ class ListOf(PatternExpr):
def __repr__(self):
return f"{self.__class__.__name__}({self.pattern})"
def _match(self, node: List[torch.fx.Node], ctx: MatchContext):
def _match(self, node: List[torch.fx.Node], ctx: MatchContext): # type: ignore[override]
if not isinstance(node, (list, tuple)) or len(node) == 0:
return FailedMatch("non_list")
m = Match(self)
@ -780,9 +780,9 @@ class ReplacementPatternEntry(PatternEntry):
first_node = output_nodes[0]
class Replacer(torch.fx.Interpreter):
call_method = None
call_module = None
get_attr = None
call_method = None # type: ignore[assignment]
call_module = None # type: ignore[assignment]
get_attr = None # type: ignore[assignment]
def run_node(self, node) -> Any:
if node.op in ("placeholder", "output"):
@ -861,7 +861,7 @@ class ReplacementPatternEntry(PatternEntry):
self.replace_with_graph(
match,
graph,
match.replacement_graph,
match.replacement_graph, # type: ignore[arg-type]
self.normalize_args(*match.args, **match.kwargs),
)
@ -991,10 +991,10 @@ def register_replacement(
exclusive_arg_names=exclusive_arg_names,
scalar_workaround=scalar_workaround,
)
specific_pattern_match = specific_pattern.match(match.output_nodes()[0])
specific_pattern_match = specific_pattern.match(match.output_nodes()[0]) # type: ignore[arg-type]
if specific_pattern_match and extra_check(specific_pattern_match):
# trace the pattern using the shapes from the user program
match.replacement_graph = trace_fn(replace_fn, args)
match.replacement_graph = trace_fn(replace_fn, args) # type: ignore[assignment]
return True
return False
@ -1120,10 +1120,10 @@ _mutation_op_re = re.compile(r"_$|(\b|_)(set|enter|exit|seed)(\b|_)")
def is_mutation_op(node: torch.fx.Node) -> bool:
if node.op == "call_function":
if _mutation_op_re.search(node.target.__name__):
if _mutation_op_re.search(node.target.__name__): # type: ignore[union-attr]
return True
elif node.op == "call_method":
if _mutation_op_re.search(node.target):
if _mutation_op_re.search(node.target): # type: ignore[union-attr, arg-type]
return True
return node.kwargs.get("out") is not None
@ -1203,7 +1203,7 @@ class PatternMatcherPass:
log.warning("%s%s %s %s", node, node.args, m, entry.pattern)
if is_match(m) and entry.extra_check(m):
count += 1
entry.apply(m, graph, node)
entry.apply(m, graph, node) # type: ignore[arg-type]
counters["inductor"]["pattern_matcher_count"] += 1
counters["inductor"]["pattern_matcher_nodes"] += len(m.nodes)
return count
@ -1334,7 +1334,7 @@ def joint_fwd_bwd(fn, args) -> torch.fx.GraphModule:
GraphPatternEntry(
pattern=pattern, handler=pointless_view, extra_check=_return_true
).register(matcher_pass.patterns)
matcher_pass.apply(gm.graph)
matcher_pass.apply(gm.graph) # type: ignore[arg-type]
# remove in/out specs
gm.graph._codegen = torch.fx.graph.CodeGen()
@ -1438,7 +1438,7 @@ def get_arg_value(
return (
node.args[arg_number]
if len(node.args) > arg_number
else node.kwargs.get(kwarg_name)
else node.kwargs.get(kwarg_name) # type: ignore[arg-type]
)
@ -1456,5 +1456,5 @@ def extract_target(node: Node):
as a function.
"""
if node.op == "call_module":
return getattr(node.graph.owning_module, node.target).__class__
return getattr(node.graph.owning_module, node.target).__class__ # type: ignore[arg-type]
return node.target

View File

@ -466,7 +466,7 @@ class SymPyValueRangeAnalysis:
ndigits = ndigits.lower
# We can't use functools.partial here since sympy doesn't support keyword arguments, but we have to bind
# the second parameter.
fn = lambda number: RoundDecimal(number, ndigits) # noqa: E731
fn = lambda number: RoundDecimal(number, ndigits) # type: ignore[misc, assignment] # noqa: E731
return ValueRanges.increasing_map(number, fn)