Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 00:20:18 +01:00
Unify MYPYINDUCTOR and MYPY (#118432)
The original motivation for MYPYINDUCTOR was a faster type-checking configuration that only checked a subset of files. With the removal of `follow_imports = ignore`, we are now able to use dmypy to do fast incremental typechecking, eliminating the need for this. Perhaps erroneously, when I teed up this PR I elected to delete the `follow_imports = skip` designations in mypy-inductor.ini. This led to a number of extra type error suppressions that I manually edited. You will need to review.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118432
Approved by: https://github.com/Skylion007
ghstack dependencies: #118414, #118418
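The commit message relies on dmypy, the mypy daemon, for speed. As a rough illustration only, not a command taken from this PR, incremental checking with dmypy looks roughly like this; the config file and target paths are assumptions:

```sh
# Hedged sketch: typical dmypy usage. Paths and flags are illustrative, not from the PR.
dmypy run -- --config-file mypy.ini torch/_dynamo torch/_inductor  # first run warms the daemon
dmypy run -- --config-file mypy.ini torch/_dynamo torch/_inductor  # later runs re-check only what changed
dmypy status  # confirm the daemon is still running
dmypy stop    # shut the daemon down when finished
```

Because the daemon keeps analysis state in memory between runs, repeated checks avoid re-analyzing unchanged modules, which is what removes the need for a separate trimmed-down MYPYINDUCTOR configuration.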
This commit is contained in:
parent 42062e2622
commit d03173e88c
@@ -150,34 +150,6 @@ init_command = [
 'optree==0.10.0',
 ]
 
-[[linter]]
-code = 'MYPYINDUCTOR'
-include_patterns = [
-'torch/_dynamo/**/*.py',
-'torch/_inductor/**/*.py',
-]
-exclude_patterns = [
-'**/fb/**',
-'torch/_dynamo/backends/**/*.py',
-'torch/_dynamo/variables/**/*.py',
-'torch/_dynamo/polyfill.py',
-'torch/_inductor/fx_passes/serialized_patterns/**',
-]
-command = [
-'python3',
-'tools/linter/adapters/mypy_linter.py',
-'--config=mypy-inductor.ini',
-'--code=MYPYINDUCTOR',
-'--',
-'@{{PATHSFILE}}'
-]
-init_command = [
-'python3',
-'tools/linter/adapters/pip_init.py',
-'--dry-run={{DRYRUN}}',
-'types-colorama==0.4.6',
-]
-
 [[linter]]
 code = 'MYPYSTRICT'
 include_patterns = [
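With the MYPYINDUCTOR block gone, the dynamo and inductor sources fall under the remaining MYPY linter. As a hedged example, not a command from this PR, and assuming lintrunner's usual `--take` selector, running the unified check on a single file might look like:

```sh
# Hedged sketch: selecting only the MYPY linter with lintrunner; flag and path are illustrative.
lintrunner --take MYPY torch/_inductor/scheduler.py
```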
@@ -1,79 +0,0 @@
-[mypy]
-plugins = mypy_plugins/check_mypy_version.py, numpy.typing.mypy_plugin
-
-cache_dir = .mypy_cache/inductor
-allow_redefinition = True
-warn_unused_configs = True
-warn_redundant_casts = True
-show_error_codes = True
-show_column_numbers = True
-check_untyped_defs = True
-follow_imports = silent
-
-# do not reenable this:
-# https://github.com/pytorch/pytorch/pull/60006#issuecomment-866130657
-warn_unused_ignores = False
-disallow_any_generics = True
-
-files =
-    torch/_dynamo,
-    torch/_inductor
-
-# We access some Python runtime classes / class members that are only available
-# in 3.11. These accesses are gated by runtime checks that cannot always be
-# understood by mypy.
-python_version = 3.11
-
-[mypy-colorama.*]
-ignore_missing_imports = True
-
-[mypy-cutlass_library.*]
-ignore_missing_imports = True
-
-[mypy-deeplearning.*]
-ignore_missing_imports = True
-
-[mypy-dill.*]
-ignore_missing_imports = True
-
-[mypy-einops.*]
-ignore_missing_imports = True
-
-[mypy-libfb.*]
-ignore_missing_imports = True
-
-# sympy is too dynamic, hard to type properly
-[mypy-sympy.*]
-ignore_missing_imports = True
-follow_imports = skip
-
-[mypy-torch.*.fb.*]
-ignore_missing_imports = True
-
-# FIXME: importing this creates lots of type errors
-[mypy-torch._dynamo.variables.*]
-follow_imports = skip
-
-[mypy-torch.fb.*]
-ignore_missing_imports = True
-
-# FIXME: importing this creates lots of type errors
-[mypy-torch.fx.*]
-follow_imports = skip
-
-# FIXME: importing this creates lots of type errors
-[mypy-torch.testing._internal.*]
-follow_imports = skip
-
-# sympy is too dynamic, hard to type properly
-[mypy-torch.utils._sympy.*]
-follow_imports = skip
-
-[mypy-torch_xla.*]
-ignore_missing_imports = True
-
-[mypy-torchvision.*]
-ignore_missing_imports = True
-
-[mypy-triton.*]
-ignore_missing_imports = True
mypy.ini (29 changed lines)

@@ -178,6 +178,7 @@ ignore_missing_imports = True
 
 [mypy-sympy.*]
 ignore_missing_imports = True
+follow_imports = skip
 
 [mypy-hypothesis.*]
 ignore_missing_imports = True

@@ -270,7 +271,31 @@ ignore_missing_imports = True
 ignore_missing_imports = True
 
 [mypy-torch._inductor.*]
-ignore_errors = True
+disallow_any_generics = True
 
 [mypy-torch._dynamo.*]
-ignore_errors = True
+disallow_any_generics = True
+
+[mypy-colorama.*]
+ignore_missing_imports = True
+
+[mypy-cutlass_library.*]
+ignore_missing_imports = True
+
+[mypy-deeplearning.*]
+ignore_missing_imports = True
+
+[mypy-einops.*]
+ignore_missing_imports = True
+
+[mypy-libfb.*]
+ignore_missing_imports = True
+
+[mypy-torch.*.fb.*]
+ignore_missing_imports = True
+
+[mypy-torch.fb.*]
+ignore_missing_imports = True
+
+[mypy-torch_xla.*]
+ignore_missing_imports = True
@@ -1,3 +1,5 @@
+# mypy: ignore-errors
+
 import contextlib
 import functools
 import logging

@@ -1,3 +1,5 @@
+# mypy: ignore-errors
+
 import logging
 import operator
 from collections import defaultdict

@@ -1,3 +1,5 @@
+# mypy: ignore-errors
+
 import dataclasses
 import functools
 from importlib import import_module

@@ -1,3 +1,5 @@
+# mypy: ignore-errors
+
 import logging
 import traceback
 from dataclasses import dataclass, field

@@ -1,3 +1,5 @@
+# mypy: ignore-errors
+
 from torch._dynamo import register_backend
 
 

@@ -1,3 +1,5 @@
+# mypy: ignore-errors
+
 # This backend is maintained by ONNX team. To direct issues
 # to the right people, please tag related GitHub issues with `module: onnx`.
 #

@@ -1,3 +1,5 @@
+# mypy: ignore-errors
+
 import functools
 import sys
 from typing import Callable, Dict, List, Optional, Protocol, Sequence, Tuple

@@ -1,3 +1,5 @@
+# mypy: ignore-errors
+
 # import torch # type: ignore[import]
 # from .common import device_from_inputs, fake_tensor_unsupported # type: ignore[import]
 # from .registry import register_backend # type: ignore[import]

@@ -1,3 +1,5 @@
+# mypy: ignore-errors
+
 import logging
 import warnings
 

@@ -1,3 +1,5 @@
+# mypy: ignore-errors
+
 import functools
 import importlib
 import logging
@@ -99,7 +99,7 @@ class AutogradCompilerInstance:
 backward_idx: int,
 ):
 assert self.hooks_proxy is not None
-backward_fn = self.hooks_proxy[backward_idx]
+backward_fn = self.hooks_proxy[backward_idx] # type: ignore[index]
 proxies = self.fx_tracer.create_proxy(
 kind="call_function",
 target=call_backward,

@@ -139,7 +139,7 @@ class AutogradCompilerInstance:
 
 def tensor_pre_hook(self, inputs, hook_id, i: int):
 assert self.hooks_proxy is not None
-hook = self.hooks_proxy[hook_id]
+hook = self.hooks_proxy[hook_id] # type: ignore[index]
 proxy = self.proxy_call_hook(
 hook,
 inputs[i],

@@ -151,7 +151,7 @@ class AutogradCompilerInstance:
 
 def pre_hook(self, inputs, hook_id):
 assert self.hooks_proxy is not None
-hook = self.hooks_proxy[hook_id]
+hook = self.hooks_proxy[hook_id] # type: ignore[index]
 proxies = self.proxy_call_hook(
 hook,
 inputs,

@@ -163,7 +163,7 @@ class AutogradCompilerInstance:
 
 def post_hook(self, outputs, inputs, hook_id):
 assert self.hooks_proxy is not None
-hook = self.hooks_proxy[hook_id]
+hook = self.hooks_proxy[hook_id] # type: ignore[index]
 proxies = self.proxy_call_hook(
 hook,
 outputs,

@@ -177,7 +177,7 @@ class AutogradCompilerInstance:
 def post_acc_grad_hook(self, input, hook_id):
 assert isinstance(input, torch.Tensor)
 assert self.hooks_proxy is not None
-hook = self.hooks_proxy[hook_id]
+hook = self.hooks_proxy[hook_id] # type: ignore[index]
 proxies = self.proxy_call_hook(
 hook,
 input,
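Most of the manual suppressions the commit message mentions are code-qualified ignores like the `# type: ignore[index]` lines above. As a small illustration, not code from this PR, a qualified ignore silences only the named error code on that line, while a bare `# type: ignore` hides everything:

```python
# Illustrative sketch only; the variable and error are made up for the example.
from typing import Optional

maybe_name: Optional[str] = None

# Without the trailing comment, mypy reports:
#   Item "None" of "Optional[str]" has no attribute "upper"  [union-attr]
shout = maybe_name.upper()  # type: ignore[union-attr]
```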
@@ -250,6 +250,7 @@ def generate_config_string(*, stable_output=False):
 if stable_output:
 return "# config omitted due to stable_output=True"
 
+experimental_config = torch.fx.experimental._config.codegen_config() # type: ignore[attr-defined]
 return f"""\
 import torch._dynamo.config
 import torch._inductor.config

@@ -258,7 +259,7 @@ import torch.fx.experimental._config
 {torch._dynamo.config.codegen_config()}
 {torch._inductor.config.codegen_config()}
 {torch._functorch.config.codegen_config()}
-{torch.fx.experimental._config.codegen_config()}
+{experimental_config}
 """
 
 
@@ -1398,6 +1398,7 @@ def export(
 case_name="cond_operands",
 )
 
+assert graph is not None
 for node in graph.graph.nodes:
 if node.op == "get_attr" and isinstance(
 getattr(graph, node.target), torch.Tensor

@@ -1424,6 +1425,7 @@ def export(
 flat_args_dynamic_dims,
 )
 # Store constraints and inputs as metadata for user passes, e.g. turn constraints to runtime check
+assert graph is not None
 graph.meta["input_shape_constraints"] = (
 [constraint.serializable_spec for constraint in constraints]
 if constraints
@@ -1127,8 +1127,8 @@ class OutputGraph(Checkpointable[OutputGraphState]):
 # TODO: Why isn't this stored in meta :think:
 pl._dynamo_source = arg.source
 
-gm._param_name_to_source = self.param_name_to_source
-gm._source_to_user_stacks = self.source_to_user_stacks
+gm._param_name_to_source = self.param_name_to_source # type: ignore[assignment]
+gm._source_to_user_stacks = self.source_to_user_stacks # type: ignore[assignment]
 
 try:
 name = (
@@ -1,3 +1,5 @@
+# mypy: ignore-errors
+
 """
 Python polyfills for common builtins.
 """
@ -338,7 +338,7 @@ class SideEffects:
|
|||
cg.extend_output(create_call_function(0, True))
|
||||
cg.add_cache(var)
|
||||
if isinstance(var.mutable_local, AttributeMutationNew):
|
||||
var.mutable_local.source = LocalSource(cg.tempvars[var])
|
||||
var.mutable_local.source = LocalSource(cg.tempvars[var]) # type: ignore[attr-defined]
|
||||
elif isinstance(var.mutable_local, AttributeMutationNew):
|
||||
if isinstance(var, variables.AutogradFunctionContextVariable):
|
||||
unimplemented("AutogradFunctionContextVariable escaped")
|
||||
|
|
@ -455,7 +455,7 @@ class SideEffects:
|
|||
if isinstance(var, variables.ListVariable):
|
||||
# old[:] = new
|
||||
cg(var, allow_cache=False)
|
||||
cg(var.mutable_local.source)
|
||||
cg(var.mutable_local.source) # type: ignore[attr-defined]
|
||||
cg.extend_output(
|
||||
[
|
||||
cg.create_load_const(None),
|
||||
|
|
@ -468,11 +468,11 @@ class SideEffects:
|
|||
cg.tx.output.update_co_names("clear")
|
||||
cg.tx.output.update_co_names("update")
|
||||
|
||||
cg(var.mutable_local.source)
|
||||
cg(var.mutable_local.source) # type: ignore[attr-defined]
|
||||
cg.extend_output([create_instruction("LOAD_METHOD", argval="update")])
|
||||
cg(var, allow_cache=False)
|
||||
|
||||
cg(var.mutable_local.source)
|
||||
cg(var.mutable_local.source) # type: ignore[attr-defined]
|
||||
cg.extend_output([create_instruction("LOAD_METHOD", argval="clear")])
|
||||
|
||||
suffixes.append(
|
||||
|
|
@ -512,7 +512,7 @@ class SideEffects:
|
|||
elif isinstance(var, variables.TupleIteratorVariable):
|
||||
for _ in range(var.index):
|
||||
cg.load_import_from(utils.__name__, "iter_next")
|
||||
cg(var.mutable_local.source)
|
||||
cg(var.mutable_local.source) # type: ignore[attr-defined]
|
||||
cg.extend_output(create_call_function(1, True))
|
||||
cg.append_output(create_instruction("POP_TOP"))
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -436,7 +436,7 @@ def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool):
|
|||
from .source import is_constant_source
|
||||
|
||||
if value.source is not None and is_constant_source(value.source):
|
||||
if truth_fn(value.get_real_value()):
|
||||
if truth_fn(value.get_real_value()): # type: ignore[attr-defined]
|
||||
push and self.push(value)
|
||||
self.jump(inst)
|
||||
else:
|
||||
|
|
@ -807,7 +807,7 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
|
|||
assert val is None or isinstance(
|
||||
val, VariableTracker
|
||||
), f"push expects VariableTracker, got {typestr(val)}"
|
||||
self.stack.append(val)
|
||||
self.stack.append(val) # type: ignore[arg-type]
|
||||
|
||||
def push_many(self, vals: List[VariableTracker]):
|
||||
for val in vals:
|
||||
|
|
@ -921,7 +921,7 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
|
|||
name = inst.argval
|
||||
source = self.get_global_source(name)
|
||||
if name not in self.symbolic_globals:
|
||||
self.symbolic_globals[name] = object() # sentinel object
|
||||
self.symbolic_globals[name] = object() # type: ignore[assignment] # sentinel object
|
||||
variable = self.output.side_effects.track_global_existing(
|
||||
source, self.symbolic_globals[name]
|
||||
)
|
||||
|
|
@ -1424,7 +1424,7 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
|
|||
assert inst.argval > 0
|
||||
obj = self.stack[-inst.arg].realize()
|
||||
assert isinstance(obj, ConstDictVariable)
|
||||
obj.call_method(self, "__setitem__", (k, v), {})
|
||||
obj.call_method(self, "__setitem__", (k, v), {}) # type: ignore[arg-type]
|
||||
|
||||
def SET_ADD(self, inst):
|
||||
v = self.pop()
|
||||
|
|
@ -1452,8 +1452,8 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
|
|||
if sys.version_info >= (3, 11):
|
||||
# MAKE_FUNCTION behavior actually changed in 3.11, see
|
||||
# https://github.com/python/cpython/pull/93189/
|
||||
assert hasattr(code.value, "co_qualname")
|
||||
fn_name = ConstantVariable.create(value=code.value.co_qualname)
|
||||
assert hasattr(code.value, "co_qualname") # type: ignore[attr-defined]
|
||||
fn_name = ConstantVariable.create(value=code.value.co_qualname) # type: ignore[attr-defined]
|
||||
defaults = None
|
||||
closure = None
|
||||
annotations = None
|
||||
|
|
@ -1676,8 +1676,8 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
|
|||
tos1 = self.stack[-2]
|
||||
assert isinstance(tos1, ConstDictVariable)
|
||||
|
||||
if all(k in tos1 for k in tos):
|
||||
self.push(TupleVariable([tos1.getitem_const(k) for k in tos]))
|
||||
if all(k in tos1 for k in tos): # type: ignore[attr-defined]
|
||||
self.push(TupleVariable([tos1.getitem_const(k) for k in tos])) # type: ignore[attr-defined]
|
||||
if sys.version_info < (3, 11):
|
||||
self.push(ConstantVariable.create(True))
|
||||
else:
|
||||
|
|
@ -1750,7 +1750,7 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
|
|||
for name in kw_names:
|
||||
assert isinstance(name, str)
|
||||
assert self.kw_names is None
|
||||
self.kw_names = ConstantVariable.create(value=kw_names)
|
||||
self.kw_names = ConstantVariable.create(value=kw_names) # type: ignore[assignment]
|
||||
|
||||
def PUSH_NULL(self, inst):
|
||||
self.push(NullVariable())
|
||||
|
|
@ -2405,7 +2405,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
|
|||
closure_cells: Dict[str, VariableTracker],
|
||||
funcvar: BaseUserFunctionVariable,
|
||||
):
|
||||
f_globals = funcvar.get_globals()
|
||||
f_globals = funcvar.get_globals() # type: ignore[attr-defined]
|
||||
f_builtins = f_globals["__builtins__"]
|
||||
if not isinstance(f_builtins, dict):
|
||||
f_builtins = f_builtins.__dict__
|
||||
|
|
@ -2510,7 +2510,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
|
|||
unimplemented("cant resume while inlining")
|
||||
|
||||
def RETURN_VALUE(self, inst):
|
||||
self.symbolic_result = self.pop()
|
||||
self.symbolic_result = self.pop() # type: ignore[assignment]
|
||||
self.instruction_pointer = None
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import sys
|
|||
|
||||
import torch
|
||||
import torch.testing
|
||||
from torch.testing._internal.common_utils import (
|
||||
from torch.testing._internal.common_utils import ( # type: ignore[attr-defined]
|
||||
IS_WINDOWS,
|
||||
TEST_WITH_CROSSREF,
|
||||
TEST_WITH_TORCHDYNAMO,
|
||||
|
|
@ -44,14 +44,14 @@ def run_tests(needs=()):
|
|||
class TestCase(TorchTestCase):
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls._exit_stack.close()
|
||||
cls._exit_stack.close() # type: ignore[attr-defined]
|
||||
super().tearDownClass()
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
cls._exit_stack = contextlib.ExitStack()
|
||||
cls._exit_stack.enter_context(
|
||||
cls._exit_stack = contextlib.ExitStack() # type: ignore[attr-defined]
|
||||
cls._exit_stack.enter_context( # type: ignore[attr-defined]
|
||||
config.patch(
|
||||
raise_on_ctx_manager_usage=True,
|
||||
suppress_errors=False,
|
||||
|
|
|
|||
|
|
@ -43,12 +43,12 @@ class MinifierTestBase(torch._dynamo.test_case.TestCase):
|
|||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
cls._exit_stack.enter_context(
|
||||
cls._exit_stack.enter_context( # type: ignore[attr-defined]
|
||||
torch._dynamo.config.patch(debug_dir_root=cls.DEBUG_DIR)
|
||||
)
|
||||
# These configurations make new process startup slower. Disable them
|
||||
# for the minification tests to speed them up.
|
||||
cls._exit_stack.enter_context(
|
||||
cls._exit_stack.enter_context( # type: ignore[attr-defined]
|
||||
torch._inductor.config.patch(
|
||||
{
|
||||
# https://github.com/pytorch/pytorch/issues/100376
|
||||
|
|
@ -67,7 +67,7 @@ class MinifierTestBase(torch._dynamo.test_case.TestCase):
|
|||
shutil.rmtree(cls.DEBUG_DIR)
|
||||
else:
|
||||
print(f"test_minifier_common tmpdir kept at: {cls.DEBUG_DIR}")
|
||||
cls._exit_stack.close()
|
||||
cls._exit_stack.close() # type: ignore[attr-defined]
|
||||
|
||||
def _gen_codegen_fn_patch_code(self, device, bug_type):
|
||||
assert bug_type in ("compile_error", "runtime_error", "accuracy")
|
||||
|
|
|
|||
|
|
@ -2737,7 +2737,7 @@ Generate the torch object - Dynamo tracing rule (the wrapping variable) map.
|
|||
def get_torch_obj_rule_map():
|
||||
d: Dict[Any, VariableTracker] = dict()
|
||||
for m in torch_name_rule_map:
|
||||
for k, v in m.items():
|
||||
for k, v in m.items(): # type: ignore[attr-defined]
|
||||
obj = load_object(k)
|
||||
if obj is not None:
|
||||
if obj in d and d[obj] != v:
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# mypy: ignore-errors
|
||||
|
||||
from .base import VariableTracker
|
||||
from .builtin import BuiltinVariable
|
||||
from .constant import ConstantVariable, EnumVariable
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# mypy: ignore-errors
|
||||
|
||||
import collections
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, List
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# mypy: ignore-errors
|
||||
|
||||
import abc
|
||||
import collections
|
||||
import contextlib
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# mypy: ignore-errors
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
import inspect
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# mypy: ignore-errors
|
||||
|
||||
import operator
|
||||
from typing import Dict, List
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# mypy: ignore-errors
|
||||
|
||||
import dataclasses
|
||||
import inspect
|
||||
from typing import Callable, Dict, List, Optional
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# mypy: ignore-errors
|
||||
|
||||
import collections
|
||||
import dataclasses
|
||||
import enum
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# mypy: ignore-errors
|
||||
|
||||
import inspect
|
||||
from typing import Dict, List
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# mypy: ignore-errors
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
import itertools
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# mypy: ignore-errors
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
import logging
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# mypy: ignore-errors
|
||||
|
||||
MAX_CYCLE = 3000
|
||||
|
||||
import itertools
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# mypy: ignore-errors
|
||||
|
||||
import functools
|
||||
from typing import Optional
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# mypy: ignore-errors
|
||||
|
||||
import collections
|
||||
import functools
|
||||
import inspect
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# mypy: ignore-errors
|
||||
|
||||
import collections
|
||||
import dataclasses
|
||||
import functools
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# mypy: ignore-errors
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
import itertools
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# mypy: ignore-errors
|
||||
|
||||
import weakref
|
||||
from typing import Dict, List
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# mypy: ignore-errors
|
||||
|
||||
from inspect import getattr_static
|
||||
|
||||
from ..bytecode_transformation import create_call_function
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# mypy: ignore-errors
|
||||
|
||||
import functools
|
||||
|
||||
import inspect
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# mypy: ignore-errors
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
import logging
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# mypy: ignore-errors
|
||||
|
||||
import inspect
|
||||
from typing import Dict, List
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# mypy: ignore-errors
|
||||
|
||||
import collections
|
||||
import contextlib
|
||||
import functools
|
||||
|
|
|
|||
|
|
@ -224,10 +224,10 @@ class DataTypePropagation:
|
|||
"store_reduction",
|
||||
):
|
||||
buf_name = node.args[1]
|
||||
return V.graph.get_dtype(buf_name)
|
||||
return V.graph.get_dtype(buf_name) # type: ignore[arg-type]
|
||||
|
||||
if node.target == operator.getitem:
|
||||
return self.deduce_node_dtype(node.args[0])
|
||||
return self.deduce_node_dtype(node.args[0]) # type: ignore[arg-type]
|
||||
|
||||
assert isinstance(node.target, str)
|
||||
|
||||
|
|
@ -235,7 +235,7 @@ class DataTypePropagation:
|
|||
return node.args[1]
|
||||
|
||||
if node.target == "constant":
|
||||
return DTYPE_TO_COMPUTATION_DTYPE[node.args[-1]]
|
||||
return DTYPE_TO_COMPUTATION_DTYPE[node.args[-1]] # type: ignore[index]
|
||||
|
||||
if node.target.startswith("masked_subblock"):
|
||||
return self.deduce_node_dtype_by_subgraph(node)
|
||||
|
|
|
|||
|
|
@ -2768,7 +2768,7 @@ class CppVecKernelChecker(CppVecKernel):
|
|||
# Support masked_load for BF16/FP16. Because the legalization will
|
||||
# insert to_dtype to convert the BF16/FP16 input to FP32.
|
||||
dtype = (
|
||||
V.graph.get_dtype(input_value.args[1])
|
||||
V.graph.get_dtype(input_value.args[1]) # type: ignore[arg-type]
|
||||
if input_value.target == "load"
|
||||
else input_value.args[-1]
|
||||
)
|
||||
|
|
@ -2784,7 +2784,7 @@ class CppVecKernelChecker(CppVecKernel):
|
|||
dtype in [torch.int32, torch.int64]
|
||||
and input_value.target == "load"
|
||||
):
|
||||
buffer = V.graph.get_buffer(input_value.args[1])
|
||||
buffer = V.graph.get_buffer(input_value.args[1]) # type: ignore[arg-type]
|
||||
# Check if load of a scalar tensor of integer
|
||||
if not (
|
||||
isinstance(buffer, TensorBox)
|
||||
|
|
@ -2907,14 +2907,14 @@ class CppKernelProxy(CppKernel):
|
|||
if node.target not in ["load"]:
|
||||
return False
|
||||
assert len(node.args) == 3
|
||||
load_dtype = V.graph.get_dtype(node.args[1])
|
||||
load_dtype = V.graph.get_dtype(node.args[1]) # type: ignore[arg-type]
|
||||
return load_dtype in DTYPE_LOWP_FP
|
||||
|
||||
def is_lowp_fp_store(node: torch.fx.Node):
|
||||
if node.target != "store":
|
||||
return False
|
||||
_, store_var, _, _, _ = node.args
|
||||
store_dtype = V.graph.get_dtype(store_var)
|
||||
store_dtype = V.graph.get_dtype(store_var) # type: ignore[arg-type]
|
||||
return store_dtype in DTYPE_LOWP_FP
|
||||
|
||||
sub_graph_nodes = list(sub_graph.nodes)
|
||||
|
|
|
|||
|
|
@ -80,7 +80,7 @@ def draw_buffers(nodes: List[BaseSchedulerNode], print_graph=False, fname=None):
|
|||
if isinstance(node, ir.ComputedBuffer):
|
||||
dtype = node.data.dtype
|
||||
|
||||
metadata = TensorMetadata(group, dtype, None, None, None, None, None)
|
||||
metadata = TensorMetadata(group, dtype, None, None, None, None, None) # type: ignore[arg-type]
|
||||
node.meta["tensor_meta"] = metadata
|
||||
|
||||
if print_graph:
|
||||
|
|
|
|||
|
|
@ -89,7 +89,7 @@ def efficient_conv_bn_eval_graph_transform(match: Match, *args, **kwargs):
|
|||
bn_node = match.nodes[0]
|
||||
graph = match.graph
|
||||
gm = graph.owning_module
|
||||
bn_mod = getattr(gm, bn_node.target)
|
||||
bn_mod = getattr(gm, bn_node.target) # type: ignore[arg-type]
|
||||
|
||||
# We can only use efficient conv-bn for eval mode with track_running_stats
|
||||
if not bn_mod.track_running_stats or bn_mod.training:
|
||||
|
|
@ -100,11 +100,11 @@ def efficient_conv_bn_eval_graph_transform(match: Match, *args, **kwargs):
|
|||
input_node = bn_node.args[0]
|
||||
else:
|
||||
input_node = bn_node.kwargs["input"]
|
||||
if input_node.op != "call_module":
|
||||
if input_node.op != "call_module": # type: ignore[union-attr]
|
||||
return
|
||||
if not hasattr(gm, input_node.target):
|
||||
if not hasattr(gm, input_node.target): # type: ignore[arg-type, union-attr]
|
||||
return
|
||||
input_mod = getattr(gm, input_node.target)
|
||||
input_mod = getattr(gm, input_node.target) # type: ignore[arg-type, union-attr]
|
||||
supported_convs = [
|
||||
nn.Linear,
|
||||
nn.Conv1d,
|
||||
|
|
@ -118,7 +118,7 @@ def efficient_conv_bn_eval_graph_transform(match: Match, *args, **kwargs):
|
|||
return
|
||||
conv_node = input_node
|
||||
# Output of conv is used by other nodes, cannot optimize
|
||||
if len(conv_node.users) > 1:
|
||||
if len(conv_node.users) > 1: # type: ignore[union-attr]
|
||||
return
|
||||
|
||||
# Find a pair of conv and bn computation nodes to optimize.
|
||||
|
|
@ -130,15 +130,15 @@ def efficient_conv_bn_eval_graph_transform(match: Match, *args, **kwargs):
|
|||
# argument. `graph.get_attr` and
|
||||
# `graph.call_function` does not allow the `name` argument.
|
||||
conv_get_node = graph.create_node(
|
||||
op="get_attr", target=conv_node.target, name="get_conv"
|
||||
op="get_attr", target=conv_node.target, name="get_conv" # type: ignore[union-attr]
|
||||
)
|
||||
bn_get_node = graph.create_node(
|
||||
op="get_attr", target=bn_node.target, name="get_bn"
|
||||
)
|
||||
if conv_node.args:
|
||||
conv_input = conv_node.args[0]
|
||||
if conv_node.args: # type: ignore[union-attr]
|
||||
conv_input = conv_node.args[0] # type: ignore[union-attr]
|
||||
else:
|
||||
conv_input = conv_node.kwargs["input"]
|
||||
conv_input = conv_node.kwargs["input"] # type: ignore[union-attr]
|
||||
# prepare args for the fused function
|
||||
args = (bn_get_node, conv_get_node, conv_input)
|
||||
# create a new node
|
||||
|
|
|
|||
|
|
@ -50,7 +50,7 @@ def freezing_passes(gm: torch.fx.GraphModule, aot_example_inputs):
|
|||
constant_fold(gm)
|
||||
# Make sure meta['val'] is properly set for all nodes
|
||||
fake_tensor_prop(gm, aot_example_inputs, True)
|
||||
binary_folding_pass.apply(gm.graph)
|
||||
binary_folding_pass.apply(gm.graph) # type: ignore[arg-type]
|
||||
# If we don't have binary folding, we don't need to run the pass again.
|
||||
# TODO: remove the need to run fake_tensor_prop on the whole model.
|
||||
if counters["inductor"]["binary_folding"] == binary_folding:
|
||||
|
|
@ -63,7 +63,7 @@ def freezing_passes(gm: torch.fx.GraphModule, aot_example_inputs):
|
|||
fake_tensor_prop(gm, aot_example_inputs, True)
|
||||
|
||||
for pattern in pass_patterns:
|
||||
pattern.apply(gm.graph)
|
||||
pattern.apply(gm.graph) # type: ignore[arg-type]
|
||||
|
||||
# The CPU weight packing always assume the conv's weight is channels last,
|
||||
# So make sure the layout_optimization is on when doing it.
|
||||
|
|
|
|||
|
|
@ -465,7 +465,7 @@ def _sfdp_params_check(match):
|
|||
# attn_mask_node may be a float/int number.
|
||||
if not hasattr(attn_mask_node, "meta"):
|
||||
return False
|
||||
attn_mask = attn_mask_node.meta["val"]
|
||||
attn_mask = attn_mask_node.meta["val"] # type: ignore[union-attr]
|
||||
# Make sure attn_mask.dtype == query.dtype or attn_mask.dtype == torch.bool
|
||||
if (
|
||||
not isinstance(attn_mask, torch.Tensor)
|
||||
|
|
|
|||
|
|
@ -126,7 +126,7 @@ class PostGradBatchLinearFusion(BatchFusion):
|
|||
|
||||
def _addmm_node_can_be_fused(self, node: torch.fx.Node) -> bool:
|
||||
return (
|
||||
node.kwargs.get("beta", 1.0) == 1.0 and node.kwargs.get("alpha", 1.0) == 1.0
|
||||
node.kwargs.get("beta", 1.0) == 1.0 and node.kwargs.get("alpha", 1.0) == 1.0 # type: ignore[return-value]
|
||||
)
|
||||
|
||||
def _is_input_2d(self, input: torch.fx.Node) -> bool:
|
||||
|
|
@ -150,10 +150,10 @@ class PostGradBatchLinearFusion(BatchFusion):
|
|||
return None
|
||||
|
||||
# only handle the cases where inputs are 2D tensors
|
||||
if not self._is_input_2d(input_m) or not self._is_input_2d(weight_m):
|
||||
if not self._is_input_2d(input_m) or not self._is_input_2d(weight_m): # type: ignore[arg-type]
|
||||
return None
|
||||
m, k = input_m.meta["tensor_meta"].shape
|
||||
n = weight_m.meta["tensor_meta"].shape[1]
|
||||
m, k = input_m.meta["tensor_meta"].shape # type: ignore[union-attr]
|
||||
n = weight_m.meta["tensor_meta"].shape[1] # type: ignore[union-attr]
|
||||
batch_key = ("batch_linear", m, k, n, bias_m is not None)
|
||||
return batch_key
|
||||
|
||||
|
|
@ -200,8 +200,8 @@ class PostGradBatchLinearFusion(BatchFusion):
|
|||
@register_fusion("group_linear", pre_grad=False)
|
||||
class GroupLinearFusion(GroupFusion):
|
||||
def _addmm_node_can_be_fused(self, node: torch.fx.Node):
|
||||
input_shape = node.args[1].meta["tensor_meta"].shape
|
||||
weight_shape = node.args[2].meta["tensor_meta"].shape
|
||||
input_shape = node.args[1].meta["tensor_meta"].shape # type: ignore[union-attr]
|
||||
weight_shape = node.args[2].meta["tensor_meta"].shape # type: ignore[union-attr]
|
||||
return (
|
||||
node.kwargs.get("beta", 1.0) == 1.0
|
||||
and node.kwargs.get("alpha", 1.0) == 1.0
|
||||
|
|
@ -215,8 +215,8 @@ class GroupLinearFusion(GroupFusion):
|
|||
)
|
||||
|
||||
def _mm_node_can_be_fused(self, node: torch.fx.Node):
|
||||
input_shape = node.args[0].meta["tensor_meta"].shape
|
||||
weight_shape = node.args[1].meta["tensor_meta"].shape
|
||||
input_shape = node.args[0].meta["tensor_meta"].shape # type: ignore[union-attr]
|
||||
weight_shape = node.args[1].meta["tensor_meta"].shape # type: ignore[union-attr]
|
||||
return (
|
||||
len(input_shape) == 2
|
||||
and len(weight_shape) == 2
|
||||
|
|
@ -295,11 +295,11 @@ class BatchPointwiseOpsPostGradFusion(BatchPointwiseOpsFusionFactory):
|
|||
# its inputs, and cause dtype not same error in mm or addmm
|
||||
input, other = node.args
|
||||
return (
|
||||
input.meta["tensor_meta"].shape == other.meta["tensor_meta"].shape
|
||||
input.meta["tensor_meta"].shape == other.meta["tensor_meta"].shape # type: ignore[union-attr]
|
||||
if hasattr(input, "meta")
|
||||
and hasattr(other, "meta")
|
||||
and "tensor_meta" in input.meta
|
||||
and "tensor_meta" in other.meta
|
||||
and "tensor_meta" in input.meta # type: ignore[union-attr]
|
||||
and "tensor_meta" in other.meta # type: ignore[union-attr]
|
||||
else False
|
||||
)
|
||||
|
||||
|
|
@ -310,12 +310,12 @@ class BatchPointwiseOpsPostGradFusion(BatchPointwiseOpsFusionFactory):
|
|||
alpha = node.kwargs.get("alpha", 1.0)
|
||||
rounding_mode = node.kwargs.get("rounding_mode", None)
|
||||
input, other = node.args
|
||||
shape = list(input.meta["tensor_meta"].shape)
|
||||
shape = list(input.meta["tensor_meta"].shape) # type: ignore[union-attr]
|
||||
group_key = (
|
||||
"batch_" + self.op.__name__.lower() + "_post_grad",
|
||||
str(shape),
|
||||
str(input.meta["tensor_meta"].dtype),
|
||||
str(other.meta["tensor_meta"].dtype),
|
||||
str(input.meta["tensor_meta"].dtype), # type: ignore[union-attr]
|
||||
str(other.meta["tensor_meta"].dtype), # type: ignore[union-attr]
|
||||
str(alpha),
|
||||
str(rounding_mode),
|
||||
)
|
||||
|
|
@ -827,7 +827,7 @@ def get_fusion_candidates(
|
|||
|
||||
|
||||
def apply_group_batch_fusion(graph: torch.fx.GraphModule, rule: GroupBatchFusionBase):
|
||||
stable_topological_sort(graph)
|
||||
stable_topological_sort(graph) # type: ignore[arg-type]
|
||||
fused_set: Set[torch.fx.Node] = set()
|
||||
|
||||
for node in reversed(graph.nodes):
|
||||
|
|
@ -893,5 +893,5 @@ def group_batch_fusion_passes(graph: torch.fx.Graph, pre_grad=True):
|
|||
fusions += generate_fusion_from_config(fbgemm_fusions, pre_grad=False)
|
||||
|
||||
for rule in fusions:
|
||||
apply_group_batch_fusion(graph, rule)
|
||||
apply_group_batch_fusion(graph, rule) # type: ignore[arg-type]
|
||||
print_graph(graph, f"Apply fusion {rule.__class__.__name__}.")
|
||||
|
|
|
|||
|
|
@ -271,7 +271,7 @@ def joint_graph_passes(graph: torch.fx.GraphModule):
|
|||
constant_fold_uniform_value(graph)
|
||||
|
||||
if config.pattern_matcher:
|
||||
count += patterns.apply(graph.graph)
|
||||
count += patterns.apply(graph.graph) # type: ignore[arg-type]
|
||||
|
||||
if not config.fallback_random:
|
||||
count += replace_random_passes(graph)
|
||||
|
|
@ -317,7 +317,7 @@ def pointless_view(match: Match, arg, size):
|
|||
"""Remove no-op view"""
|
||||
graph = match.graph
|
||||
node = match.output_node()
|
||||
arg_size = list(node.args[0].meta["val"].shape)
|
||||
arg_size = list(node.args[0].meta["val"].shape) # type: ignore[union-attr]
|
||||
if size == arg_size:
|
||||
node.replace_all_uses_with(node.args[0])
|
||||
match.erase_nodes(graph)
|
||||
|
|
|
|||
|
|
@ -343,11 +343,11 @@ if torch._C._has_mkldnn:
|
|||
if any(
|
||||
not (
|
||||
hasattr(n.args[0], "meta")
|
||||
and isinstance(n.args[0].meta.get("val", None), torch.Tensor)
|
||||
and isinstance(n.args[0].meta.get("val", None), torch.Tensor) # type: ignore[union-attr]
|
||||
)
|
||||
or not (
|
||||
hasattr(n.args[1], "meta")
|
||||
and isinstance(n.args[1].meta.get("val", None), torch.Tensor)
|
||||
and isinstance(n.args[1].meta.get("val", None), torch.Tensor) # type: ignore[union-attr]
|
||||
)
|
||||
for n in binary_nodes
|
||||
):
|
||||
|
|
@ -360,9 +360,9 @@ if torch._C._has_mkldnn:
|
|||
):
|
||||
return False
|
||||
if any(
|
||||
n.args[0].meta["val"].size() != n.args[1].meta["val"].size()
|
||||
or n.args[0].meta["val"].device != n.args[1].meta["val"].device
|
||||
or n.args[0].meta["val"].dtype != n.args[1].meta["val"].dtype
|
||||
n.args[0].meta["val"].size() != n.args[1].meta["val"].size() # type: ignore[union-attr]
|
||||
or n.args[0].meta["val"].device != n.args[1].meta["val"].device # type: ignore[union-attr]
|
||||
or n.args[0].meta["val"].dtype != n.args[1].meta["val"].dtype # type: ignore[union-attr]
|
||||
for n in binary_nodes
|
||||
):
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -85,13 +85,13 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
|
|||
remove_noop_ops(gm.graph)
|
||||
print_graph(gm.graph, "Before split cat in post grad pass.")
|
||||
for patterns in pass_patterns:
|
||||
patterns.apply(gm.graph)
|
||||
patterns.apply(gm.graph) # type: ignore[arg-type]
|
||||
print_graph(
|
||||
gm.graph,
|
||||
"Apply split cat pattern matcher PatternMatcherPass in post grad.",
|
||||
)
|
||||
if is_inference:
|
||||
inference_patterns.apply(gm.graph)
|
||||
inference_patterns.apply(gm.graph) # type: ignore[arg-type]
|
||||
|
||||
if config.post_grad_custom_post_pass is not None:
|
||||
config.post_grad_custom_post_pass(gm.graph)
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ def pre_grad_passes(gm: torch.fx.GraphModule, example_inputs):
|
|||
# explicitly run with predispatch atenIR based passes
|
||||
if config.is_predispatch:
|
||||
group_batch_fusion_passes(gm.graph, pre_grad=True)
|
||||
predispatch_pass.apply(gm.graph)
|
||||
predispatch_pass.apply(gm.graph) # type: ignore[arg-type]
|
||||
else:
|
||||
gm = fuse_fx(gm, example_inputs)
|
||||
numpy_compat_normalization(gm.graph)
|
||||
|
|
@ -82,7 +82,7 @@ def pre_grad_passes(gm: torch.fx.GraphModule, example_inputs):
|
|||
group_batch_fusion_passes(gm.graph, pre_grad=True)
|
||||
print_graph(gm.graph, "Before split cat in pre grad pass.")
|
||||
for pattern_matcher_pass in pattern_matcher_passes:
|
||||
pattern_matcher_pass.apply(gm.graph)
|
||||
pattern_matcher_pass.apply(gm.graph) # type: ignore[arg-type]
|
||||
print_graph(
|
||||
gm.graph,
|
||||
"Apply split cat pattern matcher PatternMatcherPass in pre grad.",
|
||||
|
|
@ -245,21 +245,21 @@ class NormalizedLinearNode:
|
|||
|
||||
def get_input(self) -> torch.fx.Node:
|
||||
if len(self.node.args) > 0:
|
||||
return self.node.args[0]
|
||||
return self.node.args[0] # type: ignore[return-value]
|
||||
else:
|
||||
return self.node.kwargs["input"]
|
||||
return self.node.kwargs["input"] # type: ignore[return-value]
|
||||
|
||||
def get_weight(self) -> torch.fx.Node:
|
||||
if len(self.node.args) > 1:
|
||||
return self.node.args[1]
|
||||
return self.node.args[1] # type: ignore[return-value]
|
||||
else:
|
||||
return self.node.kwargs["weight"]
|
||||
return self.node.kwargs["weight"] # type: ignore[return-value]
|
||||
|
||||
def get_bias(self) -> torch.fx.Node:
|
||||
if len(self.node.args) > 2:
|
||||
return self.node.args[2]
|
||||
return self.node.args[2] # type: ignore[return-value]
|
||||
else:
|
||||
return self.node.kwargs["bias"] if "bias" in self.node.kwargs else None
|
||||
return self.node.kwargs["bias"] if "bias" in self.node.kwargs else None # type: ignore[return-value]
|
||||
|
||||
|
||||
class NormalizedMatmulNode:
|
||||
|
|
@ -270,27 +270,27 @@ class NormalizedMatmulNode:
|
|||
|
||||
def get_input(self) -> torch.fx.Node:
|
||||
if len(self.node.args) > 0:
|
||||
return self.node.args[0]
|
||||
return self.node.args[0] # type: ignore[return-value]
|
||||
else:
|
||||
return self.node.kwargs["input"]
|
||||
return self.node.kwargs["input"] # type: ignore[return-value]
|
||||
|
||||
def get_other(self) -> torch.fx.Node:
|
||||
if len(self.node.args) > 1:
|
||||
return self.node.args[1]
|
||||
return self.node.args[1] # type: ignore[return-value]
|
||||
else:
|
||||
return self.node.kwargs["other"]
|
||||
return self.node.kwargs["other"] # type: ignore[return-value]
|
||||
|
||||
|
||||
def check_permute(node: torch.fx.Node) -> bool:
|
||||
ranks = len(node.meta["tensor_meta"].shape)
|
||||
if len(node.args) > 3:
|
||||
permutation = [node.args[i] % ranks for i in range(1, ranks + 1)]
|
||||
permutation = [node.args[i] % ranks for i in range(1, ranks + 1)] # type: ignore[operator]
|
||||
elif (
|
||||
"permutation" in node.kwargs
|
||||
and node.kwargs["permutation"] is not None
|
||||
and len(node.kwargs["permutation"]) > 2
|
||||
and len(node.kwargs["permutation"]) > 2 # type: ignore[arg-type]
|
||||
):
|
||||
permutation = [i % ranks for i in node.kwargs["permutation"]]
|
||||
permutation = [i % ranks for i in node.kwargs["permutation"]] # type: ignore[union-attr]
|
||||
else:
|
||||
return False
|
||||
allowed_permutation = list(range(ranks))
|
||||
|
|
@ -443,9 +443,9 @@ def permute_matmul_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
|
|||
):
|
||||
Atrans = True
|
||||
if len(input_A_node.args) > 0:
|
||||
input_A = input_A_node.args[0]
|
||||
input_A = input_A_node.args[0] # type: ignore[assignment]
|
||||
else:
|
||||
input_A = input_A_node.kwargs["input"]
|
||||
input_A = input_A_node.kwargs["input"] # type: ignore[assignment]
|
||||
|
||||
if (
|
||||
input_B_node.op == "call_method"
|
||||
|
|
@ -454,9 +454,9 @@ def permute_matmul_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
|
|||
):
|
||||
Btrans = True
|
||||
if len(input_B_node.args) > 0:
|
||||
input_B = input_B_node.args[0]
|
||||
input_B = input_B_node.args[0] # type: ignore[assignment]
|
||||
else:
|
||||
input_B = input_B_node.kwargs["input"]
|
||||
input_B = input_B_node.kwargs["input"] # type: ignore[assignment]
|
||||
|
||||
if Atrans or Btrans:
|
||||
with module.graph.inserting_before(node):
|
||||
|
|
|
|||
|
|
@ -484,16 +484,16 @@ def _is_valid_quantized_conv_binary_optimization_pattern(output_dtype):
|
|||
# the two inputs of binary node should have attribute "meta" and should be tensors
|
||||
if not (
|
||||
hasattr(binary_node_inputs[0], "meta")
|
||||
and isinstance(binary_node_inputs[0].meta.get("val", None), torch.Tensor)
|
||||
and isinstance(binary_node_inputs[0].meta.get("val", None), torch.Tensor) # type: ignore[union-attr]
|
||||
) or not (
|
||||
hasattr(binary_node_inputs[1], "meta")
|
||||
and isinstance(binary_node_inputs[1].meta.get("val", None), torch.Tensor)
|
||||
and isinstance(binary_node_inputs[1].meta.get("val", None), torch.Tensor) # type: ignore[union-attr]
|
||||
):
|
||||
return False
|
||||
# the two inputs of binary node should have the same shape
|
||||
if (
|
||||
binary_node_inputs[0].meta["val"].size()
|
||||
!= binary_node_inputs[1].meta["val"].size()
|
||||
binary_node_inputs[0].meta["val"].size() # type: ignore[union-attr]
|
||||
!= binary_node_inputs[1].meta["val"].size() # type: ignore[union-attr]
|
||||
):
|
||||
return False
|
||||
|
||||
|
|
@ -951,12 +951,12 @@ def _is_input_output_same_scale_zp(check_node):
|
|||
scales = [
|
||||
(
|
||||
mul_node.args[1]
|
||||
if mul_node.args[0].target is check_node
|
||||
else 1.0 / mul_node.args[1]
|
||||
if mul_node.args[0].target is check_node # type: ignore[union-attr]
|
||||
else 1.0 / mul_node.args[1] # type: ignore[operator]
|
||||
)
|
||||
for mul_node in mul_nodes
|
||||
]
|
||||
if not all(math.isclose(scale, scales[0], rel_tol=1e-5) for scale in scales):
|
||||
if not all(math.isclose(scale, scales[0], rel_tol=1e-5) for scale in scales): # type: ignore[arg-type]
|
||||
return False
|
||||
|
||||
return True
|
||||
|
|
@ -1178,7 +1178,7 @@ def _register_dequant_promotion_pass(pattern, pass_number, dtype=torch.float32):
|
|||
_user_node = user_node
|
||||
while _source_node != dequant_pattern_start_node.args[0]:
|
||||
_user_node = clone_to_new_node(graph, _source_node, _user_node)
|
||||
_source_node = _source_node.args[0]
|
||||
_source_node = _source_node.args[0] # type: ignore[assignment]
|
||||
|
||||
counters["inductor"]["dequant_promotion_matcher_count"] += 1
|
||||
counters["inductor"]["dequant_promotion_matcher_nodes"] += len(match.nodes)
|
||||
|
|
@ -1256,11 +1256,11 @@ def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float3
|
|||
mul_node = conv_node.args[0]
|
||||
else:
|
||||
convert_to_bf16 = conv_node.args[0]
|
||||
mul_node = convert_to_bf16.args[0]
|
||||
sub_node = mul_node.args[0]
|
||||
to_fp32_node = sub_node.args[0]
|
||||
mul_node = convert_to_bf16.args[0] # type: ignore[union-attr]
|
||||
sub_node = mul_node.args[0] # type: ignore[union-attr]
|
||||
to_fp32_node = sub_node.args[0] # type: ignore[union-attr]
|
||||
has_clone_to_channel_last_node_in_pattern = (
|
||||
conv_node.args[1].target is aten.clone.default
|
||||
conv_node.args[1].target is aten.clone.default # type: ignore[union-attr]
|
||||
)
|
||||
clone_node = (
|
||||
conv_node.args[1] if has_clone_to_channel_last_node_in_pattern else None
|
||||
|
|
@ -1268,20 +1268,20 @@ def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float3
|
|||
|
||||
if dtype == torch.float32:
|
||||
dequant_per_channel = (
|
||||
clone_node.args[0]
|
||||
clone_node.args[0] # type: ignore[union-attr]
|
||||
if has_clone_to_channel_last_node_in_pattern
|
||||
else conv_node.args[1]
|
||||
)
|
||||
else:
|
||||
weight_to_bf16_node = (
|
||||
clone_node.args[0]
|
||||
clone_node.args[0] # type: ignore[union-attr]
|
||||
if has_clone_to_channel_last_node_in_pattern
|
||||
else conv_node.args[1]
|
||||
)
|
||||
dequant_per_channel = weight_to_bf16_node.args[0]
|
||||
dequant_per_channel = weight_to_bf16_node.args[0] # type: ignore[union-attr]
|
||||
|
||||
assert (
|
||||
dequant_per_channel.target
|
||||
dequant_per_channel.target # type: ignore[union-attr]
|
||||
is quantized_decomposed.dequantize_per_channel.default
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -87,7 +87,7 @@ def _decompose_scatter_functional_helper(
|
|||
view = graph_call_function(
|
||||
graph, view_op.target, inp, *view_op.args, **view_op.kwargs
|
||||
)
|
||||
src = _decompose_scatter_functional_helper(graph, view, src, view_ops[1:])
|
||||
src = _decompose_scatter_functional_helper(graph, view, src, view_ops[1:]) # type: ignore[assignment]
|
||||
|
||||
return graph_call_function(
|
||||
graph,
|
||||
|
|
@ -114,7 +114,7 @@ def _decompose_scatter_functional(
|
|||
"""
|
||||
assert node.target is _generalized_scatter
|
||||
inp, src, view_ops = node.args
|
||||
return _decompose_scatter_functional_helper(graph, *node.args)
|
||||
return _decompose_scatter_functional_helper(graph, *node.args) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def _decompose_scatter_mutating(
|
||||
|
|
@ -140,11 +140,11 @@ def _decompose_scatter_mutating(
|
|||
inp = graph_call_function(graph, aten.clone, inp)
|
||||
|
||||
tmp = inp
|
||||
for view in view_ops:
|
||||
tmp = graph_call_function(graph, view.target, tmp, *view.args, **view.kwargs)
|
||||
for view in view_ops: # type: ignore[union-attr]
|
||||
tmp = graph_call_function(graph, view.target, tmp, *view.args, **view.kwargs) # type: ignore[union-attr]
|
||||
|
||||
graph_call_function(graph, aten.copy_.default, tmp, src)
|
||||
return inp
|
||||
return inp # type: ignore[return-value]
|
||||
|
||||
|
||||
# View ops whose view_scatter op is lowered into mutations anyway,
|
||||
|
|
@ -157,7 +157,7 @@ _ALWAYS_MUTATING_SCATTER_OPS = {
|
|||
|
||||
def scatter_always_uses_mutation(node: torch.fx.Node) -> bool:
|
||||
_, _, view_ops = node.args
|
||||
return any(view.target in _ALWAYS_MUTATING_SCATTER_OPS for view in view_ops)
|
||||
return any(view.target in _ALWAYS_MUTATING_SCATTER_OPS for view in view_ops) # type: ignore[union-attr]
|
||||
|
||||
|
||||
def should_reinplace_scatter(node: torch.fx.Node) -> bool:
|
||||
|
|
@ -174,12 +174,12 @@ def should_reinplace_scatter(node: torch.fx.Node) -> bool:
|
|||
if scatter_always_uses_mutation(node):
|
||||
return True
|
||||
|
||||
if is_node_realized(inp) and is_node_realized(node):
|
||||
if is_node_realized(inp) and is_node_realized(node): # type: ignore[arg-type]
|
||||
return True
|
||||
|
||||
# If the output is copied back into the input, this forces both to be
|
||||
# realized as the output is a user of the input
|
||||
if inp.op == "placeholder" and any(
|
||||
if inp.op == "placeholder" and any( # type: ignore[union-attr]
|
||||
user.target is aten.copy_.default and user.args[0] is inp for user in node.users
|
||||
):
|
||||
return True
|
||||
|
|
@ -234,11 +234,11 @@ def canonicalize_view_scatter_ops(graph: torch.fx.Graph) -> None:
|
|||
|
||||
def handle_views(node: torch.fx.Node):
|
||||
inp = node.args[0]
|
||||
node_to_view_base[node] = node_to_view_base.get(inp, inp)
|
||||
node_to_view_base[node] = node_to_view_base.get(inp, inp) # type: ignore[arg-type]
|
||||
node_to_view_op[node] = [
|
||||
*node_to_view_op[inp],
|
||||
*node_to_view_op[inp], # type: ignore[index]
|
||||
ViewOp(
|
||||
node.target,
|
||||
node.target, # type: ignore[arg-type]
|
||||
args=node.args[1:],
|
||||
kwargs=node.kwargs,
|
||||
),
|
||||
|
|
@ -255,14 +255,14 @@ def canonicalize_view_scatter_ops(graph: torch.fx.Graph) -> None:
|
|||
)
|
||||
|
||||
def can_fuse():
|
||||
if src.target is not _generalized_scatter:
|
||||
if src.target is not _generalized_scatter: # type: ignore[union-attr]
|
||||
return False
|
||||
src_inp, src_src, src_scatter_view_op = src.args
|
||||
src_inp, src_src, src_scatter_view_op = src.args # type: ignore[union-attr]
|
||||
|
||||
inp_base = node_to_view_base.get(inp, inp)
|
||||
src_base = node_to_view_base.get(src_inp, src_inp)
|
||||
return inp_base is src_base and node_to_view_op[src_inp] == [
|
||||
*node_to_view_op[inp],
|
||||
inp_base = node_to_view_base.get(inp, inp) # type: ignore[arg-type]
|
||||
src_base = node_to_view_base.get(src_inp, src_inp) # type: ignore[arg-type]
|
||||
return inp_base is src_base and node_to_view_op[src_inp] == [ # type: ignore[index]
|
||||
*node_to_view_op[inp], # type: ignore[index]
|
||||
scatter_view_op,
|
||||
]
|
||||
|
||||
|
|
@ -279,19 +279,19 @@ def canonicalize_view_scatter_ops(graph: torch.fx.Graph) -> None:
|
|||
graph.erase_node(node)
|
||||
return
|
||||
|
||||
src_inp, src_src, src_scatter_view_op = src.args
|
||||
src_inp, src_src, src_scatter_view_op = src.args # type: ignore[union-attr]
|
||||
with graph.inserting_before(src):
|
||||
new_node = graph_call_function(
|
||||
graph,
|
||||
_generalized_scatter,
|
||||
inp,
|
||||
src_src,
|
||||
[scatter_view_op, *src_scatter_view_op],
|
||||
[scatter_view_op, *src_scatter_view_op], # type: ignore[misc]
|
||||
)
|
||||
node.replace_all_uses_with(new_node)
|
||||
graph.erase_node(node)
|
||||
|
||||
if src.users:
|
||||
if src.users: # type: ignore[union-attr]
|
||||
new_src = graph_call_function(
|
||||
graph,
|
||||
_SCATTER_OP_TO_VIEW[node.target],
|
||||
|
|
@ -301,7 +301,7 @@ def canonicalize_view_scatter_ops(graph: torch.fx.Graph) -> None:
|
|||
)
|
||||
|
||||
handle_views(new_src)
|
||||
src.replace_all_uses_with(new_src)
|
||||
src.replace_all_uses_with(new_src) # type: ignore[union-attr]
|
||||
|
||||
graph.erase_node(src)
|
||||
|
||||
|
|
|
|||
|
|
@ -112,7 +112,9 @@ def replace_random(
|
|||
mode = {
|
||||
aten.rand: "rand",
|
||||
aten.randn: "randn",
|
||||
}[match.output_node().target.overloadpacket]
|
||||
}[
|
||||
match.output_node().target.overloadpacket # type: ignore[union-attr]
|
||||
] # type: ignore[union-attr]
|
||||
device = get_device(device)
|
||||
match.replace_by_example(replacement, [size])
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# mypy: ignore-errors
|
||||
|
||||
# noqa: F401, E501
|
||||
# This is an auto-generated file. Please do not modify it by hand.
|
||||
# To re-generate, run:
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# mypy: ignore-errors
|
||||
|
||||
# noqa: F401, E501
|
||||
# This is an auto-generated file. Please do not modify it by hand.
|
||||
# To re-generate, run:
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# mypy: ignore-errors
|
||||
|
||||
# noqa: F401, E501
|
||||
# This is an auto-generated file. Please do not modify it by hand.
|
||||
# To re-generate, run:
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# mypy: ignore-errors
|
||||
|
||||
# noqa: F401, E501
|
||||
# This is an auto-generated file. Please do not modify it by hand.
|
||||
# To re-generate, run:
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# mypy: ignore-errors
|
||||
|
||||
# noqa: F401, E501
|
||||
# This is an auto-generated file. Please do not modify it by hand.
|
||||
# To re-generate, run:
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# mypy: ignore-errors
|
||||
|
||||
# noqa: F401, E501
|
||||
# This is an auto-generated file. Please do not modify it by hand.
|
||||
# To re-generate, run:
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# mypy: ignore-errors
|
||||
|
||||
# noqa: F401, E501
|
||||
# This is an auto-generated file. Please do not modify it by hand.
|
||||
# To re-generate, run:
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# mypy: ignore-errors
|
||||
|
||||
# noqa: F401, E501
|
||||
# This is an auto-generated file. Please do not modify it by hand.
|
||||
# To re-generate, run:
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# mypy: ignore-errors
|
||||
|
||||
# noqa: F401, E501
|
||||
# This is an auto-generated file. Please do not modify it by hand.
|
||||
# To re-generate, run:
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# mypy: ignore-errors
|
||||
|
||||
# noqa: F401, E501
|
||||
# This is an auto-generated file. Please do not modify it by hand.
|
||||
# To re-generate, run:
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# mypy: ignore-errors
|
||||
|
||||
# noqa: F401, E501
|
||||
# This is an auto-generated file. Please do not modify it by hand.
|
||||
# To re-generate, run:
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# mypy: ignore-errors
|
||||
|
||||
# noqa: F401, E501
|
||||
# This is an auto-generated file. Please do not modify it by hand.
|
||||
# To re-generate, run:
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# mypy: ignore-errors
|
||||
|
||||
# noqa: F401, E501
|
||||
# This is an auto-generated file. Please do not modify it by hand.
|
||||
# To re-generate, run:
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# mypy: ignore-errors
|
||||
|
||||
# noqa: F401, E501
|
||||
# This is an auto-generated file. Please do not modify it by hand.
|
||||
# To re-generate, run:
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# mypy: ignore-errors
|
||||
|
||||
# noqa: F401, E501
|
||||
# This is an auto-generated file. Please do not modify it by hand.
|
||||
# To re-generate, run:
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# mypy: ignore-errors
|
||||
|
||||
# noqa: F401, E501
|
||||
# This is an auto-generated file. Please do not modify it by hand.
|
||||
# To re-generate, run:
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# mypy: ignore-errors
|
||||
|
||||
# noqa: F401, E501
|
||||
# This is an auto-generated file. Please do not modify it by hand.
|
||||
# To re-generate, run:
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# mypy: ignore-errors
|
||||
|
||||
# This is an auto-generated file. Please do not modify it by hand.
|
||||
# To re-generate, run:
|
||||
# cd ~/pytorch && python
|
||||
|
|
|
|||
|
|
@ -322,7 +322,7 @@ class TorchSplit(CallFunction):
|
|||
return FailedMatch("only integer getitems are handled")
|
||||
if user.args[1] in seen_idxs:
|
||||
return FailedMatch(f"duplicate getitem {user.args[1]}")
|
||||
if user.args[-1] < 0:
|
||||
if user.args[-1] < 0: # type: ignore[operator]
|
||||
# This shouldn't ideally happen as dynamo normalizes indexes to positive
|
||||
return FailedMatch("negative index")
|
||||
seen_idxs.add(user.args[1])
|
||||
|
|
@ -358,13 +358,13 @@ def merge_splits(
|
|||
if len(node.users.keys()) == 0:
|
||||
return
|
||||
graph = match.graph
|
||||
first_split = node.args[0].args[0]
|
||||
next_split_index = node.args[0].args[1]
|
||||
first_split = node.args[0].args[0] # type: ignore[union-attr]
|
||||
next_split_index = node.args[0].args[1] # type: ignore[union-attr]
|
||||
|
||||
new_split_sections = list(first_split_sections)
|
||||
new_split_sections[next_split_index : next_split_index + 1] = next_split_sections
|
||||
new_split_sections[next_split_index : next_split_index + 1] = next_split_sections # type: ignore[operator, misc]
|
||||
|
||||
first_split_dim = first_split.kwargs["dim"]
|
||||
first_split_dim = first_split.kwargs["dim"] # type: ignore[union-attr]
|
||||
|
||||
to_remove = []
|
||||
|
||||
|
|
@ -376,7 +376,7 @@ def merge_splits(
|
|||
kwargs={"dim": first_split_dim},
|
||||
)
|
||||
first_split_num_to_user = {
|
||||
user.args[1]: user for user in first_split.users.keys()
|
||||
user.args[1]: user for user in first_split.users.keys() # type: ignore[union-attr]
|
||||
}
|
||||
|
||||
new_split_num = 0
|
||||
|
|
@ -406,7 +406,7 @@ def merge_splits(
|
|||
to_remove.append(node)
|
||||
to_remove.append(old_getitem)
|
||||
|
||||
to_remove.append(first_split)
|
||||
to_remove.append(first_split) # type: ignore[arg-type]
|
||||
for node in to_remove:
|
||||
graph.erase_node(node)
|
||||
|
||||
|
|
@ -463,9 +463,9 @@ class SplitCatSimplifier:
|
|||
graph, split_node, split_sections, user_inputs_list, simplified_split_ranges
|
||||
)
|
||||
self.replace_cat(
|
||||
graph, split_node, next_users, user_inputs_list_new, transform_params_list
|
||||
graph, split_node, next_users, user_inputs_list_new, transform_params_list # type: ignore[arg-type]
|
||||
)
|
||||
self.erase_old_nodes(graph, split_node, next_users)
|
||||
self.erase_old_nodes(graph, split_node, next_users) # type: ignore[arg-type]
|
||||
|
||||
def get_user_input_list(
|
||||
self, split_node: torch.fx.Node, next_users: List[torch.fx.Node]
|
||||
|
|
@@ -481,7 +481,7 @@ class SplitCatSimplifier:
             if user.target in {torch.cat, torch.stack}:
                 user_inputs_list.append(self.get_merged_user_inputs(split_node, user))
             else:
-                user_inputs_list.append(self.get_non_cat_node_input(split_node, user))
+                user_inputs_list.append(self.get_non_cat_node_input(split_node, user))  # type: ignore[arg-type]
         return user_inputs_list
 
     def get_merged_user_inputs(
@@ -536,10 +536,10 @@ class SplitCatSimplifier:
                 if cur_range:
                     merged_ranges.append(tuple(cur_range))
                     cur_range = None
-                merged_ranges.append(input_)
+                merged_ranges.append(input_)  # type: ignore[arg-type]
         if cur_range:
             merged_ranges.append(tuple(cur_range))
-        return merged_ranges
+        return merged_ranges  # type: ignore[return-value]
 
     def get_simplified_split_ranges(
         self,
@@ -618,7 +618,7 @@ class SplitCatSimplifier:
                     transform_params.append((None, None, None, None))
                 elif isinstance(user_input, tuple):  # Split being simplified
                     # Verify equal split
-                    subset_split_sections = split_sections[
+                    subset_split_sections = split_sections[  # type: ignore[index]
                         user_input[0] : user_input[1] + 1
                     ]
                     # All sections should be equal
@@ -698,7 +698,7 @@ class SplitCatSimplifier:
                 else:
                     new_user_inputs.append(user_input)
             new_user_inputs_list.append(new_user_inputs)
-        return new_user_inputs_list
+        return new_user_inputs_list  # type: ignore[return-value]
 
     def replace_cat(
         self,
@@ -837,10 +837,10 @@ class UnbindCatRemover(SplitCatSimplifier):
         graph: torch.fx.Graph,
         unbind_node: torch.fx.Node,
     ):
-        num_unbind = (
-            max(getitem_node.args[1] for getitem_node in unbind_node.users.keys()) + 1
+        num_unbind = (  # type: ignore[operator]
+            max(getitem_node.args[1] for getitem_node in unbind_node.users.keys()) + 1  # type: ignore[operator, union-attr, type-var]
         )
-        split_sections = [1 for _ in range(num_unbind)]
+        split_sections = [1 for _ in range(num_unbind)]  # type: ignore[operator, arg-type]
 
         super().simplify(graph, unbind_node, split_sections)
 
@@ -1115,11 +1115,11 @@ def safe_to_abort_node(node: torch.fx.Node):
     2. the user of all the input nodes should be only one
     """
     prev_node = None
-    for arg in node.args[0]:
-        if len(arg.users) != 1 or arg.target != operator.getitem:
+    for arg in node.args[0]:  # type: ignore[union-attr]
+        if len(arg.users) != 1 or arg.target != operator.getitem:  # type: ignore[union-attr]
             return False
         if prev_node is None:
-            prev_node = arg.args[0]
+            prev_node = arg.args[0]  # type: ignore[union-attr]
         else:
             if arg.args[0] != prev_node:
                 return False
@@ -1173,8 +1173,8 @@ def merge_getitem_cat(match: Match, split_sections: List[int], dim: int):
                 continue
             # find the index of getitems to be cated/stacked
             indices = []
-            for arg in cat_user.args[0]:
-                indices.append(arg.args[1])
+            for arg in cat_user.args[0]:  # type: ignore[union-attr]
+                indices.append(arg.args[1])  # type: ignore[union-attr]
             # indices may not be necessarily sorted, we sort them first
             indices.sort()
             # the gettitems to be merged must be consecutive, otherwise
@@ -1182,12 +1182,12 @@ def merge_getitem_cat(match: Match, split_sections: List[int], dim: int):
             if indices[len(indices) - 1] - indices[0] + 1 != len(indices):
                 continue
             # update the arg of cat user, only keep the first getitem
-            cat_user.update_arg(0, cat_user.args[0][0])
+            cat_user.update_arg(0, cat_user.args[0][0])  # type: ignore[index]
             # calculate the fused tensor sizes in the indices
             fused_tensor_size = 0
-            for i in range(len(split_node.args[1])):
+            for i in range(len(split_node.args[1])):  # type: ignore[arg-type]
                 if i in indices:
-                    fused_tensor_size += split_node.args[1][i]
+                    fused_tensor_size += split_node.args[1][i]  # type: ignore[operator, assignment, index]
             # update the split sections
             split_sections[indices[0]] = fused_tensor_size
             # padding others with zeros to keep the same dict size
@@ -1320,9 +1320,9 @@ def merge_stack_tahn_unbind(match: Match, split_sections: List[int], dim: int):
             # find the index of getitems to be stacked
             indices = []
             split_sections_for_unbind = []
-            for arg in user.args[0]:
-                indices.append(arg.args[1])
-                split_sections_for_unbind.append(split_sections[arg.args[1]])
+            for arg in user.args[0]:  # type: ignore[union-attr]
+                indices.append(arg.args[1])  # type: ignore[union-attr]
+                split_sections_for_unbind.append(split_sections[arg.args[1]])  # type: ignore[union-attr]
             # indices may not be necessarily sorted, we sort them first
             indices.sort()
             # the gettitems to be merged must be consecutive, otherwise
@@ -1330,12 +1330,12 @@ def merge_stack_tahn_unbind(match: Match, split_sections: List[int], dim: int):
             if indices[len(indices) - 1] - indices[0] + 1 != len(indices):
                 continue
             # update the arg of stack user, only keep the first getitem
-            user.update_arg(0, user.args[0][0])
+            user.update_arg(0, user.args[0][0])  # type: ignore[index]
             # calculate the fused tensor sizes in the indices
             fused_tensor_size = 0
-            for i in range(len(split_node.args[1])):
+            for i in range(len(split_node.args[1])):  # type: ignore[arg-type]
                 if i in indices:
-                    fused_tensor_size += split_node.args[1][i]
+                    fused_tensor_size += split_node.args[1][i]  # type: ignore[operator, index, assignment]
             # update the split sections
             split_sections[indices[0]] = fused_tensor_size
             # padding others with zeros to keep the same dict size
@@ -204,7 +204,7 @@ def is_node_realized(node: torch.fx.Node) -> bool:
             # getitem = foo[0]
             # getitem_1 = foo[1]
             # where we need to check if foo is a fallback kernel
-            return is_buffer(node.args[0])
+            return is_buffer(node.args[0])  # type: ignore[arg-type]
         return node.op in ("placeholder", "output") or node.target in fallbacks
 
     if is_buffer(node):
@@ -6578,7 +6578,7 @@ class InterpreterShim(torch.fx.Interpreter):
         # call super() with a placeholder to avoid constructing a
         # GraphModule which is very expensive (it does codegen).
         super().__init__(self._dummy_gm(), garbage_collect_values=False)
-        self.module = self
+        self.module = self  # type: ignore[assignment]
         self.graph = graph
         self.submodules = submodules
         self.extra_traceback = False
@@ -3769,7 +3769,7 @@ def pooling_size(x, i, kernel_size, stride, padding, ceil_mode):
         )
         if V.graph.sizevars.size_hint((x_alt - 1) * stride[i] - x - padding[i]) >= 0:
             # Sliding windows must start within the input or left padding
-            x_alt -= 1
+            x_alt -= 1  # type: ignore[assignment]
             V.graph.sizevars.guard_leq(0, x_alt * stride[i] - x - padding[i])
         if V.graph.sizevars.size_hint(x_out - x_alt) == 0:
             # ceil mode is actually a no-op, lets guard on that
@@ -551,7 +551,7 @@ class ListOf(PatternExpr):
     def __repr__(self):
         return f"{self.__class__.__name__}({self.pattern})"
 
-    def _match(self, node: List[torch.fx.Node], ctx: MatchContext):
+    def _match(self, node: List[torch.fx.Node], ctx: MatchContext):  # type: ignore[override]
         if not isinstance(node, (list, tuple)) or len(node) == 0:
             return FailedMatch("non_list")
         m = Match(self)
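A note on the override suppression above (a sketch with made-up classes, not the pattern-matcher API): mypy reports [override] when a subclass method's signature is incompatible with the base class's, as with ListOf._match taking a list of nodes where the PatternExpr base presumably matches a single node. When the looser signature is intentional, the error-code-scoped comment keeps the override while silencing just that diagnostic:

from typing import List

class Expr:
    def _match(self, node: int) -> bool:
        return node >= 0

class ListOfExpr(Expr):
    # Changing the parameter type makes the override incompatible with
    # Expr._match, which mypy reports as an [override] error; the scoped
    # ignore documents that the mismatch is deliberate.
    def _match(self, node: List[int]) -> bool:  # type: ignore[override]
        return bool(node) and all(item >= 0 for item in node)

# Usage: each class is called with the argument shape it expects.
print(Expr()._match(3))             # True
print(ListOfExpr()._match([1, 2]))  # True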
@@ -780,9 +780,9 @@ class ReplacementPatternEntry(PatternEntry):
         first_node = output_nodes[0]
 
         class Replacer(torch.fx.Interpreter):
-            call_method = None
-            call_module = None
-            get_attr = None
+            call_method = None  # type: ignore[assignment]
+            call_module = None  # type: ignore[assignment]
+            get_attr = None  # type: ignore[assignment]
 
             def run_node(self, node) -> Any:
                 if node.op in ("placeholder", "output"):
@@ -861,7 +861,7 @@ class ReplacementPatternEntry(PatternEntry):
         self.replace_with_graph(
             match,
             graph,
-            match.replacement_graph,
+            match.replacement_graph,  # type: ignore[arg-type]
             self.normalize_args(*match.args, **match.kwargs),
         )
 
@@ -991,10 +991,10 @@ def register_replacement(
             exclusive_arg_names=exclusive_arg_names,
             scalar_workaround=scalar_workaround,
         )
-        specific_pattern_match = specific_pattern.match(match.output_nodes()[0])
+        specific_pattern_match = specific_pattern.match(match.output_nodes()[0])  # type: ignore[arg-type]
         if specific_pattern_match and extra_check(specific_pattern_match):
             # trace the pattern using the shapes from the user program
-            match.replacement_graph = trace_fn(replace_fn, args)
+            match.replacement_graph = trace_fn(replace_fn, args)  # type: ignore[assignment]
             return True
         return False
 
@@ -1120,10 +1120,10 @@ _mutation_op_re = re.compile(r"_$|(\b|_)(set|enter|exit|seed)(\b|_)")
 
 def is_mutation_op(node: torch.fx.Node) -> bool:
     if node.op == "call_function":
-        if _mutation_op_re.search(node.target.__name__):
+        if _mutation_op_re.search(node.target.__name__):  # type: ignore[union-attr]
             return True
     elif node.op == "call_method":
-        if _mutation_op_re.search(node.target):
+        if _mutation_op_re.search(node.target):  # type: ignore[union-attr, arg-type]
             return True
     return node.kwargs.get("out") is not None
 
@@ -1203,7 +1203,7 @@ class PatternMatcherPass:
                         log.warning("%s%s %s %s", node, node.args, m, entry.pattern)
                     if is_match(m) and entry.extra_check(m):
                         count += 1
-                        entry.apply(m, graph, node)
+                        entry.apply(m, graph, node)  # type: ignore[arg-type]
                         counters["inductor"]["pattern_matcher_count"] += 1
                         counters["inductor"]["pattern_matcher_nodes"] += len(m.nodes)
         return count
@@ -1334,7 +1334,7 @@ def joint_fwd_bwd(fn, args) -> torch.fx.GraphModule:
     GraphPatternEntry(
         pattern=pattern, handler=pointless_view, extra_check=_return_true
     ).register(matcher_pass.patterns)
-    matcher_pass.apply(gm.graph)
+    matcher_pass.apply(gm.graph)  # type: ignore[arg-type]
 
     # remove in/out specs
     gm.graph._codegen = torch.fx.graph.CodeGen()
@@ -1438,7 +1438,7 @@ def get_arg_value(
     return (
         node.args[arg_number]
         if len(node.args) > arg_number
-        else node.kwargs.get(kwarg_name)
+        else node.kwargs.get(kwarg_name)  # type: ignore[arg-type]
     )
 
 
@@ -1456,5 +1456,5 @@ def extract_target(node: Node):
     as a function.
     """
     if node.op == "call_module":
-        return getattr(node.graph.owning_module, node.target).__class__
+        return getattr(node.graph.owning_module, node.target).__class__  # type: ignore[arg-type]
     return node.target
@@ -466,7 +466,7 @@ class SymPyValueRangeAnalysis:
             ndigits = ndigits.lower
         # We can't use functools.partial here since sympy doesn't support keyword arguments, but we have to bind
         # the second parameter.
-        fn = lambda number: RoundDecimal(number, ndigits)  # noqa: E731
+        fn = lambda number: RoundDecimal(number, ndigits)  # type: ignore[misc, assignment]  # noqa: E731
 
         return ValueRanges.increasing_map(number, fn)
 