From cad79bd0bb3bb8191913e9f02fedac749a8e23fc Mon Sep 17 00:00:00 2001
From: "Edward Z. Yang"
Date: Sat, 27 Jan 2024 19:39:22 -0800
Subject: [PATCH] Remove follow_imports = skip from sympy (#118469)

dmypy silently ignores follow_imports = skip, so to get parity between
dmypy and mypy we have to suck it up and type: ignore all of the sympy
typing problems. The suppressions were added automatically with the
following script generated by GPT-4:

```
import re

# Read the error file
with open("error_file.txt", "r") as f:
    errors = f.readlines()

# Parse the lines with errors and error types
error_lines = {}
for error in errors:
    match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
    if match:
        file_path, line_number, error_type = match.groups()
        if file_path not in error_lines:
            error_lines[file_path] = {}
        error_lines[file_path][int(line_number)] = error_type

# Insert ignore comments in the source files
for file_path, lines in error_lines.items():
    with open(file_path, "r") as f:
        code = f.readlines()
    for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
        code[line_number - 1] = code[line_number - 1].rstrip() + f" # type: ignore[{error_type}]\n"
    with open(file_path, "w") as f:
        f.writelines(code)
```

Signed-off-by: Edward Z. Yang

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118469
Approved by: https://github.com/Skylion007
ghstack dependencies: #118414, #118418, #118432, #118467, #118468
---
 .lintrunner.toml | 1 +
 mypy.ini | 1 -
 torch/_inductor/codegen/common.py | 8 ++--
 torch/_inductor/codegen/cpp.py | 24 +++++-----
 torch/_inductor/codegen/memory_planning.py | 2 +-
 torch/_inductor/codegen/triton.py | 38 +++++++--------
 torch/_inductor/codegen/triton_utils.py | 2 +-
 torch/_inductor/codegen/wrapper.py | 12 ++---
 torch/_inductor/dependencies.py | 10 ++--
 torch/_inductor/fx_passes/post_grad.py | 2 +-
 torch/_inductor/ir.py | 46 +++++++++----------
 torch/_inductor/kernel/conv.py | 8 ++--
 torch/_inductor/kernel/mm_common.py | 6 +--
 torch/_inductor/lowering.py | 14 +++---
 torch/_inductor/sizevars.py | 34 +++++++-------
 torch/_inductor/utils.py | 10 ++--
 .../ao/nn/intrinsic/qat/modules/conv_fused.py | 6 +--
 17 files changed, 112 insertions(+), 112 deletions(-)

diff --git a/.lintrunner.toml b/.lintrunner.toml index a3c1a15531a..a4dbd24d2c9 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -138,6 +138,7 @@ init_command = [ 'numpy==1.26.0 ; python_version >= "3.9"', 'expecttest==0.1.6', 'mypy==1.7.0', + 'sympy==1.11.1', 'types-requests==2.27.25', 'types-PyYAML==6.0.7', 'types-tabulate==0.8.8', diff --git a/mypy.ini b/mypy.ini index 545d06d1b4d..4895da84b77 100644 --- a/mypy.ini +++ b/mypy.ini @@ -179,7 +179,6 @@ ignore_missing_imports = True [mypy-sympy.*] ignore_missing_imports = True -follow_imports = skip [mypy-hypothesis.*] ignore_missing_imports = True diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 3de41f4be6d..22c6d6b8f44 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -1143,7 +1143,7 @@ class Kernel(CodeGen): ): # Skip CSE since this doesn't return an expression - if var.bounds.lower < 0: + if var.bounds.lower < 0: # type: ignore[operator] new_bounds = ValueRanges.unknown() if var.bounds != ValueRanges.unknown() and isinstance( size, sympy.Number @@ -1154,13 +1154,13 @@ ): neg = var.bounds & ValueRanges(-sympy.oo, -1) new_bounds = ValueRanges(neg.lower + size, neg.upper + size) # We don't have a good way of representing the empty range - if 
var.bounds.upper >= 0: + if var.bounds.upper >= 0: # type: ignore[operator] pos = var.bounds & ValueRanges(0, sympy.oo) new_bounds = new_bounds | pos stm = ops.add(var, self.rename_indexing(size)) # Mixed negative and non-negative - if var.bounds.upper >= 0: + if var.bounds.upper >= 0: # type: ignore[operator] lt = ops.lt(var, "0") stm = ops.where(lt, stm, var) new_var = self.cse.generate(self.compute, stm, bounds=new_bounds) @@ -1310,7 +1310,7 @@ class Kernel(CodeGen): # adds the necessary kernel args for index expressions # and renames variables in index expressions to kernel arg names if isinstance(index, (list, tuple)): - return [self.rename_indexing(x) for x in index] + return [self.rename_indexing(x) for x in index] # type: ignore[return-value] index = V.graph.sizevars.simplify(index) sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name) replacements = { diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 4cb8c1e1426..6b6c94aa9e3 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -311,7 +311,7 @@ def argmax_argmin_prefix(reduction_type, src_dtype, tmpvar): @functools.lru_cache def stride_at(index: sympy.Expr, var: sympy.Symbol): replacement = {var: var + 1} - new_index = sympy_subs(index, replacement) + new_index = sympy_subs(index, replacement) # type: ignore[arg-type] return sympy.simplify(new_index - index) @@ -620,10 +620,10 @@ class CppCSEVariable(CSEVariable): """ for s in index.free_symbols: if s in V.kernel.itervars: - self.dependent_itervars.add(s) - elif s.name in V.kernel.cse.varname_map: + self.dependent_itervars.add(s) # type: ignore[arg-type] + elif s.name in V.kernel.cse.varname_map: # type: ignore[attr-defined] self.dependent_itervars.update( - V.kernel.cse.varname_map[s.name].dependent_itervars + V.kernel.cse.varname_map[s.name].dependent_itervars # type: ignore[attr-defined] ) def depends_on(self, itervar: sympy.Symbol): @@ -1512,10 +1512,10 @@ class CppKernel(Kernel): Check if an index has free symbol CppCSEVariable that depends on `itervar`. 
""" return any( - self.cse.varname_map[s.name].depends_on(itervar) + self.cse.varname_map[s.name].depends_on(itervar) # type: ignore[attr-defined] for s in index.free_symbols - if s.name in self.cse.varname_map - and isinstance(self.cse.varname_map[s.name], CppCSEVariable) + if s.name in self.cse.varname_map # type: ignore[attr-defined] + and isinstance(self.cse.varname_map[s.name], CppCSEVariable) # type: ignore[attr-defined] ) def index_depends_on(self, index: sympy.Expr, itervar: sympy.Symbol): @@ -1894,9 +1894,9 @@ class CppVecKernel(CppKernel): ) replacements = {} for indirect_var in ( - self.cse.varname_map[s.name] + self.cse.varname_map[s.name] # type: ignore[attr-defined] for s in index.free_symbols - if s.name.startswith("tmp") + if s.name.startswith("tmp") # type: ignore[attr-defined] ): assert isinstance(indirect_var, CppCSEVariable) if indirect_var.is_vec: @@ -1911,7 +1911,7 @@ class CppVecKernel(CppKernel): ) else: load_mask = f"{self._load_mask} != 0" - index = sympy_subs(index, replacements) + index = sympy_subs(index, replacements) # type: ignore[arg-type] index = self.scale_index_with_offset( index, itervar_idx=self.tiling_idx, offset=itervar_inner ) @@ -1934,7 +1934,7 @@ class CppVecKernel(CppKernel): code.writeline(f"if ({load_mask})") stack.enter_context(code.indent()) code.writeline(f"tmpbuf[{itervar_inner}] = {rhs};") - load_line = self._get_vec_load_line("tmpbuf.data()", 0, dtype) + load_line = self._get_vec_load_line("tmpbuf.data()", 0, dtype) # type: ignore[arg-type] code.writeline(f"return {load_line};") code.writeline("()") csevar = self.cse.generate(buffer, code) @@ -2296,7 +2296,7 @@ class CppTile2DKernel(CppVecKernel): # vector load inside the kernel inner loop loadbuf = f"{tile_var} + {cexpr_index(inner * self.tiling_factor)}" dtype = V.graph.get_dtype(name) - line = self._get_vec_load_line(loadbuf, 0, dtype) + line = self._get_vec_load_line(loadbuf, 0, dtype) # type: ignore[arg-type] csevar = self.cse.generate(self.loads, line) csevar.update_on_args("load", (name, index), {}) assert isinstance(csevar, CppCSEVariable) diff --git a/torch/_inductor/codegen/memory_planning.py b/torch/_inductor/codegen/memory_planning.py index 0299941b40e..d94e4723dba 100644 --- a/torch/_inductor/codegen/memory_planning.py +++ b/torch/_inductor/codegen/memory_planning.py @@ -326,7 +326,7 @@ class TemporalSplit(ClearCacheOnAllocateMixin, AllocationTreeNode): @cache_on_self def get_symbolic_size(self) -> sympy.Expr: if not self.allocations: - return 0 + return 0 # type: ignore[return-value] return sympy.Max(*[x.get_symbolic_size() for x in self.allocations]) def is_empty(self): diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 4bfddda6bd7..12c32bda0e2 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -197,10 +197,10 @@ class BlockPtrOptions: for i in range(len(self.shape)): if ( self.block_shape[i] != "1" - and not V.graph.sizevars.statically_known_equals(self.strides[i], 0) + and not V.graph.sizevars.statically_known_equals(self.strides[i], 0) # type: ignore[arg-type] and not V.graph.sizevars.statically_known_multiple_of( self.shape[i], - config.triton.max_block[self.block_shape[i][0]], + config.triton.max_block[self.block_shape[i][0]], # type: ignore[arg-type] ) and not (V.kernel.no_x_dim and self.block_shape[i] == "XBLOCK") ): @@ -1280,7 +1280,7 @@ class TritonKernel(Kernel): if hint > threshold: return False # will need to recompile if we cross a larger power of 2 boundary - 
V.graph.sizevars.guard_leq(self.numels[-1], next_power_of_2(hint)) + V.graph.sizevars.guard_leq(self.numels[-1], next_power_of_2(hint)) # type: ignore[arg-type] return True def set_last_usage(self, nodes): @@ -1370,7 +1370,7 @@ class TritonKernel(Kernel): for length_group in lengths: return_getters = [] for size in length_group: - if sv.statically_known_equals(size, 1): + if sv.statically_known_equals(size, 1): # type: ignore[arg-type] return_getters.append(lambda _: sympy.Integer(0)) continue @@ -1461,7 +1461,7 @@ class TritonKernel(Kernel): if symbol not in self.range_tree_nodes: # Non-iterated variables, e.g. strides continue - entry = self.range_tree_nodes[symbol] + entry = self.range_tree_nodes[symbol] # type: ignore[index] assert isinstance(entry.parent, IterationRangesRoot) index_numels[entry.parent.index] *= entry.length @@ -1469,7 +1469,7 @@ class TritonKernel(Kernel): # numels, then it must be broadcasted. simplify = V.graph.sizevars.simplify return any( - simplify(idx_range) != simplify(iter_range) + simplify(idx_range) != simplify(iter_range) # type: ignore[arg-type] for idx_range, iter_range in zip(index_numels, self.numels) ) @@ -1603,7 +1603,7 @@ class TritonKernel(Kernel): [m[s] for s in strides], m[offset], range_trees, - mask_vars, + mask_vars, # type: ignore[arg-type] ) expand_str = None @@ -1630,7 +1630,7 @@ class TritonKernel(Kernel): self.filter_masks(mask_vars) mask_str = " & ".join(sorted(map(str, mask_vars))) if mask_vars else "None" - return IndexingOptions(index_str, mask_vars, mask_str, expand_str, has_rindex) + return IndexingOptions(index_str, mask_vars, mask_str, expand_str, has_rindex) # type: ignore[arg-type] def active_range_trees(self, reorder=False): trees = [ @@ -1647,7 +1647,7 @@ class TritonKernel(Kernel): def filter_masks(self, mask_vars): for tree in self.range_trees: # Masks are superfluous if we only have one element - if V.graph.sizevars.statically_known_equals(tree.numel, 1): + if V.graph.sizevars.statically_known_equals(tree.numel, 1): # type: ignore[arg-type] mask_vars.discard(f"{tree.prefix}mask") continue # Masks are superfluous if numel is a multiple of BLOCK @@ -1659,7 +1659,7 @@ class TritonKernel(Kernel): # never need to do a masked load to handle stragglers at the end. # It's faster to avoid masking at all. But it is sound to always # mask. 
- if V.graph.sizevars.statically_known_multiple_of(tree.numel, max_block): + if V.graph.sizevars.statically_known_multiple_of(tree.numel, max_block): # type: ignore[arg-type] mask_vars.discard(f"{tree.prefix}mask") def var_ranges(self): @@ -1676,13 +1676,13 @@ class TritonKernel(Kernel): # if indexing expression is complicated, we precompute it on the host side # and send the result as a kernel argument replacements = {} - for ps in self.range_tree_nodes[sym].precomputed_args(): + for ps in self.range_tree_nodes[sym].precomputed_args(): # type: ignore[index] replacements[ps] = V.graph.sizevars.lookup_precomputed_size(ps) if len(replacements) > 0: - self.range_tree_nodes[sym].expr = sympy_subs( - self.range_tree_nodes[sym].expr, replacements + self.range_tree_nodes[sym].expr = sympy_subs( # type: ignore[index] + self.range_tree_nodes[sym].expr, replacements # type: ignore[index] ) - self.range_tree_nodes[sym].codegen() + self.range_tree_nodes[sym].codegen() # type: ignore[index] return expr @contextlib.contextmanager @@ -1735,7 +1735,7 @@ class TritonKernel(Kernel): {xindex: 512, rindex: 1024} """ index_to_tile_indexes = {k: v.expr for k, v in self.range_tree_nodes.items()} - index_in_tile_vars = sympy_subs(index, index_to_tile_indexes) + index_in_tile_vars = sympy_subs(index, index_to_tile_indexes) # type: ignore[arg-type] strides = {} for range_tree in self.range_trees: s = sympy_index_symbol(range_tree.name) @@ -1883,7 +1883,7 @@ class TritonKernel(Kernel): result_var = self.cse.generate(load_buffer, line) assert isinstance(result_var, TritonCSEVariable) - result_var.mask_vars = indexing.mask_vars + result_var.mask_vars = indexing.mask_vars # type: ignore[assignment] if append_broadcast: line = f"tl.broadcast_to({result_var}, {append_broadcast})" @@ -2410,7 +2410,7 @@ class TritonKernel(Kernel): # note that random seed is put in V.graph.constants const_tensor = V.graph.constants[arg_name] result.writeline( - f"{var_name} = rand_strided({V.graph.sizevars.size_hints(const_tensor.size())}, {V.graph.sizevars.size_hints(const_tensor.stride())}, device='{const_tensor.device}', dtype={const_tensor.dtype})" # noqa: B950 line too long + f"{var_name} = rand_strided({V.graph.sizevars.size_hints(const_tensor.size())}, {V.graph.sizevars.size_hints(const_tensor.stride())}, device='{const_tensor.device}', dtype={const_tensor.dtype})" # type: ignore[arg-type] # noqa: B950 line too long ) elif isinstance(arg_sig, SizeArg): symval_hint = V.graph.sizevars.size_hint(arg_sig.expr) @@ -3128,9 +3128,9 @@ class TritonScheduling(BaseScheduling): # Only install guards for 32-bit indexing as there is no correctness # issue with using 64-bit for everything - V.graph.sizevars.guard_leq(numel, int_max) + V.graph.sizevars.guard_leq(numel, int_max) # type: ignore[arg-type] for size in buf_sizes: - V.graph.sizevars.guard_leq(size, int_max) + V.graph.sizevars.guard_leq(size, int_max) # type: ignore[arg-type] return True @staticmethod diff --git a/torch/_inductor/codegen/triton_utils.py b/torch/_inductor/codegen/triton_utils.py index 43f044abdc8..9509bf0e935 100644 --- a/torch/_inductor/codegen/triton_utils.py +++ b/torch/_inductor/codegen/triton_utils.py @@ -81,7 +81,7 @@ def config_of(args: List[Union[TensorArg, SizeArg]]) -> instance_descriptor: return False if isinstance(x.expr, float): return False - return V.graph.sizevars.statically_known_multiple_of(x.expr, alignment) + return V.graph.sizevars.statically_known_multiple_of(x.expr, alignment) # type: ignore[arg-type] raise NotImplementedError(f"unhandled {type(x)}: 
{x}") if config.triton.divisible_by_16: diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 3cde373bebd..e0ffef72be0 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -757,17 +757,17 @@ class WrapperCodeGen(CodeGen): ) for name, shape in graph_inputs_expr: - shape = V.graph.sizevars.simplify(shape) + shape = V.graph.sizevars.simplify(shape) # type: ignore[arg-type] if shape in needed: - needed.remove(shape) + needed.remove(shape) # type: ignore[arg-type] code.writeline(f"{self.declare}{shape} = {name}{self.ending}") for name, value in graph_inputs_tensors: shapes = value.get_size() for dim, shape in enumerate(shapes): - shape = V.graph.sizevars.simplify(shape) + shape = V.graph.sizevars.simplify(shape) # type: ignore[arg-type] if shape in needed: - needed.remove(shape) + needed.remove(shape) # type: ignore[arg-type] code.writeline( f"{self.declare}{shape} = {sizeof(name)}[{dim}]{self.ending}" ) @@ -775,9 +775,9 @@ class WrapperCodeGen(CodeGen): for name, value in graph_inputs_tensors: shapes = value.get_stride() for dim, shape in enumerate(shapes): - shape = V.graph.sizevars.simplify(shape) + shape = V.graph.sizevars.simplify(shape) # type: ignore[arg-type] if shape in needed: - needed.remove(shape) + needed.remove(shape) # type: ignore[arg-type] code.writeline( f"{self.declare}{shape} = {strideof(name)}[{dim}]{self.ending}" ) diff --git a/torch/_inductor/dependencies.py b/torch/_inductor/dependencies.py index c36d7979b56..20085a39320 100644 --- a/torch/_inductor/dependencies.py +++ b/torch/_inductor/dependencies.py @@ -30,7 +30,7 @@ Dep = Union["MemoryDep", "StarDep", "WeakDep"] class MemoryDep(typing.NamedTuple): name: str - index: sympy.Expr + index: sympy.Expr # type: ignore[assignment] var_names: Tuple[sympy.Symbol, ...] size: Tuple[sympy.Expr, ...] @@ -77,7 +77,7 @@ class MemoryDep(typing.NamedTuple): return isinstance(self.index, (int, sympy.Integer)) def is_indirect(self) -> bool: - return any(is_indirect(v.name) for v in self.index.free_symbols) + return any(is_indirect(v.name) for v in self.index.free_symbols) # type: ignore[attr-defined] class StarDep(typing.NamedTuple): @@ -146,7 +146,7 @@ class WeakDep(typing.NamedTuple): class IndexExprDep(typing.NamedTuple): - index: sympy.Expr + index: sympy.Expr # type: ignore[assignment] var_names: Tuple[sympy.Symbol, ...] size: Tuple[sympy.Expr, ...] @@ -235,7 +235,7 @@ class _RecordLoadStoreInner(V.MockHandler): # type: ignore[name-defined] k for k, v in zip(self._var_ranges.keys(), sizes) if v != 1 ) sizes = tuple(v for v in sizes if v != 1) - return index, var_names, sizes + return index, var_names, sizes # type: ignore[return-value] # Try to further simplify the indexes even if simplify_loops didn't # convert it to the simplest form because of the interference from @@ -269,7 +269,7 @@ class _RecordLoadStoreInner(V.MockHandler): # type: ignore[name-defined] # downstream users won't. Normalize this away. 
new_vars.pop() new_sizes.pop() - return index, tuple(new_vars), tuple(new_sizes) + return index, tuple(new_vars), tuple(new_sizes) # type: ignore[arg-type] def load(self, name: str, index: sympy.Expr) -> str: self._reads.add(MemoryDep(name, *self.canonicalize(index))) diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index b8a0a28f670..ecd64669fd1 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -368,7 +368,7 @@ def cat_tuned_op(match, inputs, dim, *, op, shape_of): if new_size is None: new_size = shape else: - new_size[notdim] = V.graph.sizevars.guard_equals( + new_size[notdim] = V.graph.sizevars.guard_equals( # type: ignore[call-overload] shape[notdim], new_size[notdim] ) new_size[dim] += shape[dim] diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 37f1a998752..d410e885797 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -208,9 +208,9 @@ def ir_node_to_tensor(x, guard_shape=True): size = [shape_fn(s) for s in x.get_size()] stride: StrideType if is_storage_and_layout(x): - stride = [shape_fn(s) for s in x.get_layout().stride] + stride = [shape_fn(s) for s in x.get_layout().stride] # type: ignore[misc] else: - stride = make_contiguous_strides_for(size) + stride = make_contiguous_strides_for(size) # type: ignore[arg-type] dtype = x.get_dtype() device = x.get_device() size = convert_shape_to_symint(size) @@ -294,7 +294,7 @@ class IRNode: return sympy_product(self.get_size()) def is_zero_elements(self): - return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0)) + return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0)) # type: ignore[arg-type] def realize(self): """ @@ -1150,7 +1150,7 @@ class Reduction(Loops): ): reindex = View.dynamic_reshape_indexer(reduction_ranges, [reduction_numel]) need_mask = not V.graph.sizevars.is_expr_static_and_true( - sympy.Eq(reduction_numel % split, 0) + sympy.Eq(reduction_numel % split, 0) # type: ignore[arg-type] ) def wrapper_fn(index, reduction_index): @@ -1286,7 +1286,7 @@ class Reduction(Loops): wrapper_fn, ranges, reduction_ranges, - [*ranges, split], + [*ranges, split], # type: ignore[list-item] [block_size], reduction_type, split, @@ -1522,7 +1522,7 @@ class WelfordReduction(Reduction): """ reduction_numel = sympy_product(reduction_ranges) need_mask = not V.graph.sizevars.is_expr_static_and_true( - sympy.Eq(reduction_numel % split, 0) + sympy.Eq(reduction_numel % split, 0) # type: ignore[arg-type] ) if need_mask and reduction_type != "welford_combine": @@ -1562,7 +1562,7 @@ class WelfordReduction(Reduction): ) for loader in inner_fns ), - [*ranges, split], + [*ranges, split], # type: ignore[list-item] [block_size], reduction_type, reduction_hint, @@ -1587,7 +1587,7 @@ class WelfordReduction(Reduction): for i in intermediates ), ranges, - [split], + [split], # type: ignore[list-item] # welford_reduce turns one input into three outputs, which are combined with welford_combine "welford_combine", reduction_hint, @@ -1680,7 +1680,7 @@ class Scan(Loops): scan_numel = sizevars.simplify(sympy_product(scan_ranges)) # Scan with a single element is just a copy - if sizevars.is_expr_static_and_true(sympy.Le(scan_numel, 1)): + if sizevars.is_expr_static_and_true(sympy.Le(scan_numel, 1)): # type: ignore[arg-type] return Pointwise.create( device=device, dtype=dtype, @@ -1993,7 +1993,7 @@ class PermuteView(BaseView): def make_reindexer(self): inv = {j: i for i, j in enumerate(self.dims)} - inv = 
[inv[i] for i in range(len(self.dims))] + inv = [inv[i] for i in range(len(self.dims))] # type: ignore[index] assert set(inv) == set(range(len(self.dims))) def reindex(index): @@ -2214,12 +2214,12 @@ class View(GenericView): while stack_old: size_old = stack_old.pop() - V.graph.sizevars.guard_equals(size_old, 1) + V.graph.sizevars.guard_equals(size_old, 1) # type: ignore[arg-type] view_expr.append(sympy.Integer(0)) while stack_new: var, size_new = stack_new.pop() - V.graph.sizevars.guard_equals(size_new, 1) + V.graph.sizevars.guard_equals(size_new, 1) # type: ignore[arg-type] view_expr = list(reversed(view_expr)) assert len(view_expr) == len(old_size) @@ -2227,7 +2227,7 @@ class View(GenericView): def reindex(index): assert len(index) == len(vars), (len(index), len(vars)) replacements = dict(zip(vars, index)) - return tuple(sympy_subs(x, replacements) for x in view_expr) + return tuple(sympy_subs(x, replacements) for x in view_expr) # type: ignore[arg-type] return reindex @@ -2487,7 +2487,7 @@ class Layout(IRNode): if ndim not in [4, 5]: return False for left, right, size in zip( - self.stride, make_channels_last_strides_for(self.size), self.size + self.stride, make_channels_last_strides_for(self.size), self.size # type: ignore[arg-type] ): if size != 1 and left != right: return False @@ -2564,7 +2564,7 @@ class Layout(IRNode): ) def storage_size(self) -> sympy.Expr: - return compute_required_storage_length(self.size, self.stride, self.offset) + return compute_required_storage_length(self.size, self.stride, self.offset) # type: ignore[arg-type, return-value] class FixedLayout(Layout): @@ -2583,9 +2583,9 @@ class FixedLayout(Layout): super().__init__( device, dtype, - size, + size, # type: ignore[arg-type] stride, - offset, + offset, # type: ignore[arg-type] ) def make_indexer(self): @@ -2715,7 +2715,7 @@ class AliasedLayout(Layout): return True from .compile_fx import ALIGNMENT - return V.graph.sizevars.statically_known_multiple_of(offset, ALIGNMENT) + return V.graph.sizevars.statically_known_multiple_of(offset, ALIGNMENT) # type: ignore[arg-type] class NoneLayout(IRNode): @@ -2882,7 +2882,7 @@ class Buffer(IRNode): self.layout = self.layout.as_same_order(stride) def is_zero_elements(self): - return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0)) + return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0)) # type: ignore[arg-type] def make_loader(self): # Loading from a zero-element buffer is a no-op @@ -3177,7 +3177,7 @@ class ComputedBuffer(Buffer): else: indices = index_vars stride_lengths = [ - V.graph.sizevars.stride_hints(expr, indices) for expr in reads + V.graph.sizevars.stride_hints(expr, indices) for expr in reads # type: ignore[arg-type] ] from .scheduler import pick_loop_order @@ -3907,13 +3907,13 @@ class ExternKernel(InputsKernel): else: type_ = None kwargs.append( - V.graph.wrapper_code.val_to_cpp_arg_str( + V.graph.wrapper_code.val_to_cpp_arg_str( # type: ignore[arg-type] type_, v, self.is_legacy_abi_kernel() ) ) else: kwargs = [ - f"{k}={V.graph.wrapper_code.val_to_arg_str(v)}" + f"{k}={V.graph.wrapper_code.val_to_arg_str(v)}" # type: ignore[misc] for k, v in self.kwargs.items() ] return kwargs @@ -3962,7 +3962,7 @@ class ExternKernel(InputsKernel): _, add_var = var_builder("c") replacement = dict(zip(index_vars, reindex([add_var(x) for x in new_sizes]))) - index = sympy_subs(sympy.expand(index), replacement) + index = sympy_subs(sympy.expand(index), replacement) # type: ignore[arg-type] return index, tuple(new_sizes) def 
get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: diff --git a/torch/_inductor/kernel/conv.py b/torch/_inductor/kernel/conv.py index 84905834977..f75c8e9054c 100644 --- a/torch/_inductor/kernel/conv.py +++ b/torch/_inductor/kernel/conv.py @@ -243,14 +243,14 @@ def conv_layout( ir.ir_node_to_tensor(weight, guard_shape=True), ir.ir_node_to_tensor(bias, guard_shape=True), stride, - tuple(V.graph.sizevars.size_hint(p) for p in padding), + tuple(V.graph.sizevars.size_hint(p) for p in padding), # type: ignore[arg-type] dilation, transposed, - tuple(V.graph.sizevars.size_hint(p) for p in output_padding), + tuple(V.graph.sizevars.size_hint(p) for p in output_padding), # type: ignore[arg-type] groups, ) sizes = ir.convert_shape_to_inductor(output.size()) - stride = ir.convert_shape_to_inductor(output.stride()) + stride = ir.convert_shape_to_inductor(output.stride()) # type: ignore[assignment] return ir.FixedLayout( x.get_device(), @@ -419,7 +419,7 @@ def convolution( and not transposed and is_zeros(output_padding) # there are some odd models where this check fails (e.g. shufflenet_v2_x1_0) - and V.graph.sizevars.statically_known_equals(in_chan, x.get_size()[1]) + and V.graph.sizevars.statically_known_equals(in_chan, x.get_size()[1]) # type: ignore[arg-type] ): if ( is_ones(kernel_shape) diff --git a/torch/_inductor/kernel/mm_common.py b/torch/_inductor/kernel/mm_common.py index 03a8cde3ae9..9c7e23316c8 100644 --- a/torch/_inductor/kernel/mm_common.py +++ b/torch/_inductor/kernel/mm_common.py @@ -36,7 +36,7 @@ def filtered_configs( m = max( next_power_of_2( V.graph.sizevars.size_hint( - m, fallback=torch._inductor.config.unbacked_symint_fallback + m, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type] ) ), min_block_size, @@ -44,7 +44,7 @@ def filtered_configs( n = max( next_power_of_2( V.graph.sizevars.size_hint( - n, fallback=torch._inductor.config.unbacked_symint_fallback + n, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type] ) ), min_block_size, @@ -52,7 +52,7 @@ def filtered_configs( k = max( next_power_of_2( V.graph.sizevars.size_hint( - k, fallback=torch._inductor.config.unbacked_symint_fallback + k, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type] ) ), min_block_size, diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 5ab0d73215a..4b6e619591a 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -983,8 +983,8 @@ def pointwise_cat(inputs, dim=0): inputs_ranges: List[Tuple[sympy.Expr, sympy.Expr]] = [] prev_end = 0 for inp in inputs: - inputs_ranges.append((prev_end, prev_end + inp.get_size()[dim])) - prev_end = inputs_ranges[-1][-1] + inputs_ranges.append((prev_end, prev_end + inp.get_size()[dim])) # type: ignore[arg-type] + prev_end = inputs_ranges[-1][-1] # type: ignore[assignment] inputs_loaders = [inp.make_loader() for inp in inputs] @@ -1215,7 +1215,7 @@ def unfold(x, dimension, size, step): dim_size = sizes[dim] sizevars = V.graph.sizevars sizevars.guard_leq(size, dim_size) - sizevars.guard_lt(0, step) + sizevars.guard_lt(0, step) # type: ignore[arg-type] new_dim_size = FloorDiv(dim_size - size, step) + 1 if sizevars.size_hint(dim_size) > 0: @@ -2371,8 +2371,8 @@ def select_scatter(x, src, dim: int, index: int): dim = _validate_dim(x, dim, 0) if V.graph.sizevars.evaluate_expr(sympy.Lt(index, 0)): index = index + x.get_size()[dim] - V.graph.sizevars.guard_leq(0, index) - V.graph.sizevars.guard_lt(index, x.get_size()[dim]) + 
V.graph.sizevars.guard_leq(0, index) # type: ignore[arg-type] + V.graph.sizevars.guard_lt(index, x.get_size()[dim]) # type: ignore[arg-type] src = expand(unsqueeze(src, dim), x.get_size()) src_loader = src.make_loader() @@ -3673,7 +3673,7 @@ def constant_pad_nd(x, padding, fill_value=0): # if padding is a complicated expression, hoist it bounds_precomp: List[Tuple[sympy.Symbol, Any]] = [] for l, h in bounds: - bounds_precomp.append((V.graph.sizevars.lookup_precomputed_size(l), h)) + bounds_precomp.append((V.graph.sizevars.lookup_precomputed_size(l), h)) # type: ignore[arg-type] output_size = list(sizes[:n]) mask_sizes = [] @@ -3770,7 +3770,7 @@ def pooling_size(x, i, kernel_size, stride, padding, ceil_mode): if V.graph.sizevars.size_hint((x_alt - 1) * stride[i] - x - padding[i]) >= 0: # Sliding windows must start within the input or left padding x_alt -= 1 # type: ignore[assignment] - V.graph.sizevars.guard_leq(0, x_alt * stride[i] - x - padding[i]) + V.graph.sizevars.guard_leq(0, x_alt * stride[i] - x - padding[i]) # type: ignore[arg-type] if V.graph.sizevars.size_hint(x_out - x_alt) == 0: # ceil mode is actually a no-op, lets guard on that V.graph.sizevars.guard_equals(x_out, x_alt) diff --git a/torch/_inductor/sizevars.py b/torch/_inductor/sizevars.py index 41968b3bbd5..ceff1bddc91 100644 --- a/torch/_inductor/sizevars.py +++ b/torch/_inductor/sizevars.py @@ -128,7 +128,7 @@ class SizeVarAllocator: # actual iteration range is to size-1 iter_ranges_zero = {k: 0 for k, v in var_ranges.items()} base_lowest = sympy_subs(base, iter_ranges_zero) - if self.statically_known_leq(0, base_lowest): + if self.statically_known_leq(0, base_lowest): # type: ignore[arg-type] # can't replace with indexing div if base can be negative base_pos = True else: @@ -272,7 +272,7 @@ class SizeVarAllocator: """ Returns a bool indicating if it is sound to optimize as if left and right are equal. """ - return self.is_expr_static_and_true(sympy.Eq(left, right)) + return self.is_expr_static_and_true(sympy.Eq(left, right)) # type: ignore[arg-type] # See Note - [On Statically Known] def statically_known_list_equals(self, left: List[Expr], right: List[Expr]) -> bool: @@ -307,7 +307,7 @@ class SizeVarAllocator: Return a bool indicating if it is sound to optimize for the numerator being a multiple of the denominator. """ expr = sympy.Eq(numerator % denominator, 0) - return self.is_expr_static_and_true(expr) + return self.is_expr_static_and_true(expr) # type: ignore[arg-type] # The guard functions require you to ALREADY KNOW that a particular # condition holds. 
If you don't know (you want to guard on an expression @@ -316,9 +316,9 @@ class SizeVarAllocator: def guard_equals(self, left: Expr, right: Expr) -> Expr: if isinstance(left, Expr): - left = sympy_subs(left, self.inv_precomputed_replacements) + left = sympy_subs(left, self.inv_precomputed_replacements) # type: ignore[arg-type] if isinstance(right, Expr): - right = sympy_subs(right, self.inv_precomputed_replacements) + right = sympy_subs(right, self.inv_precomputed_replacements) # type: ignore[arg-type] assert self.shape_env.evaluate_expr(sympy.Eq(left, right)) return left @@ -329,16 +329,16 @@ class SizeVarAllocator: assert self.shape_env.evaluate_expr(sympy.Lt(left, right)) def expect_true(self, expr: Expr, *, msg: str) -> None: - expr = sympy_subs(expr, self.inv_precomputed_replacements) + expr = sympy_subs(expr, self.inv_precomputed_replacements) # type: ignore[arg-type] self.shape_env.defer_runtime_assert(expr, msg, fx_node=None) def expect_equals(self, left: Expr, right: Expr, *, msg: str) -> Expr: # Prefer returning the expression without unbacked symints if self.shape_env.is_unbacked_symint(left): - self.expect_true(sympy.Eq(left, right), msg=msg) + self.expect_true(sympy.Eq(left, right), msg=msg) # type: ignore[arg-type] return right elif self.shape_env.is_unbacked_symint(right): - self.expect_true(sympy.Eq(left, right), msg=msg) + self.expect_true(sympy.Eq(left, right), msg=msg) # type: ignore[arg-type] return left else: return self.guard_equals(left, right) @@ -399,8 +399,8 @@ class SizeVarAllocator: return [self.evaluate_static_shape(x) for x in left] def remove_precomputed_replacements(self, expr: Expr) -> Expr: - if any(s.name.startswith("ps") for s in expr.free_symbols): - return sympy_subs(expr, self.inv_precomputed_replacements) + if any(s.name.startswith("ps") for s in expr.free_symbols): # type: ignore[attr-defined] + return sympy_subs(expr, self.inv_precomputed_replacements) # type: ignore[arg-type] return expr def symbolic_hint(self, expr: Expr) -> Expr: @@ -410,7 +410,7 @@ class SizeVarAllocator: return expr free_symbols = expr.free_symbols if not free_symbols: - return int(expr) + return int(expr) # type: ignore[return-value] expr = self.remove_precomputed_replacements(expr) return sympy_subs(expr, self.var_to_val) @@ -422,9 +422,9 @@ class SizeVarAllocator: s: self.shape_env.var_to_range.get(s, None) for s in expr.free_symbols } if all(vr is not None for vr in sym_vrs.values()): - expr_vr = bound_sympy(expr, sym_vrs) - lower = self.size_hint(expr_vr.lower) - upper = self.size_hint(expr_vr.upper) + expr_vr = bound_sympy(expr, sym_vrs) # type: ignore[arg-type] + lower = self.size_hint(expr_vr.lower) # type: ignore[arg-type] + upper = self.size_hint(expr_vr.upper) # type: ignore[arg-type] fallback = min(max(fallback, lower), upper) return fallback try: @@ -522,8 +522,8 @@ class SizeVarAllocator: support_vars: Optional[List[sympy.Symbol]] = None, ) -> List[int]: for v in index.free_symbols: - if v.name.startswith("indirect"): - index = sympy_subs(index, {v: 0}) + if v.name.startswith("indirect"): # type: ignore[attr-defined] + index = sympy_subs(index, {v: 0}) # type: ignore[dict-item] result = [] for s in self.stride_vars(index, vars, support_vars): try: @@ -538,7 +538,7 @@ class SizeVarAllocator: order.sort(key=lambda x: (strides[x] == 0, strides[x])) return order - def lookup_precomputed_size(self, expr: Expr) -> sympy.Symbol: + def lookup_precomputed_size(self, expr: Expr) -> Expr: if ( isinstance(expr, (int, sympy.Symbol, sympy.Number)) or expr.is_number diff --git 
a/torch/_inductor/utils.py b/torch/_inductor/utils.py index d89d95c9da4..8c36dc01029 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -569,8 +569,8 @@ def sympy_subs(expr: sympy.Expr, replacements: Dict[sympy.Expr, Any]) -> sympy.E if isinstance(replacement, str): return sympy.Symbol( replacement, - integer=replaced.is_integer, - nonnegative=replaced.is_nonnegative, + integer=replaced.is_integer, # type: ignore[attr-defined] + nonnegative=replaced.is_nonnegative, # type: ignore[attr-defined] ) else: return replacement @@ -582,11 +582,11 @@ def sympy_subs(expr: sympy.Expr, replacements: Dict[sympy.Expr, Any]) -> sympy.E def free_symbol_startswith(index: sympy.Expr, prefix: str): - return any(v.name.startswith(prefix) for v in index.free_symbols) + return any(v.name.startswith(prefix) for v in index.free_symbols) # type: ignore[attr-defined] def free_symbol_has(index: sympy.Expr, pattern: str): - return any(pattern in v.name for v in index.free_symbols) + return any(pattern in v.name for v in index.free_symbols) # type: ignore[attr-defined] def is_symbolic(a: Any) -> bool: @@ -1081,7 +1081,7 @@ def get_sympy_Expr_dtype(val: sympy.Expr) -> torch.dtype: assert isinstance( val, sympy.Expr ), "only support sympy.Expr as input to get_sympy_Expr_dtype" - if val.is_integer: + if val.is_integer: # type: ignore[attr-defined] return torch.int64 else: return torch.float64 diff --git a/torch/ao/nn/intrinsic/qat/modules/conv_fused.py b/torch/ao/nn/intrinsic/qat/modules/conv_fused.py index 7041713ad06..00d454e70a4 100644 --- a/torch/ao/nn/intrinsic/qat/modules/conv_fused.py +++ b/torch/ao/nn/intrinsic/qat/modules/conv_fused.py @@ -468,7 +468,7 @@ class ConvReLU1d(nnqat.Conv1d, nni._FusedModule): weight_fake_quant: fake quant module for weight """ - _FLOAT_MODULE = nni.ConvReLU1d + _FLOAT_MODULE = nni.ConvReLU1d # type: ignore[assignment] _FLOAT_CONV_MODULE = nn.Conv1d _FLOAT_BN_MODULE: None = None _FLOAT_RELU_MODULE = nn.ReLU @@ -600,7 +600,7 @@ class ConvReLU2d(nnqat.Conv2d, nni._FusedModule): weight_fake_quant: fake quant module for weight """ - _FLOAT_MODULE = nni.ConvReLU2d + _FLOAT_MODULE = nni.ConvReLU2d # type: ignore[assignment] _FLOAT_CONV_MODULE = nn.Conv2d _FLOAT_BN_MODULE: None = None _FLOAT_RELU_MODULE = nn.ReLU @@ -773,7 +773,7 @@ class ConvReLU3d(nnqat.Conv3d, nni._FusedModule): weight_fake_quant: fake quant module for weight """ - _FLOAT_MODULE = nni.ConvReLU3d + _FLOAT_MODULE = nni.ConvReLU3d # type: ignore[assignment] _FLOAT_CONV_MODULE = nn.Conv3d _FLOAT_BN_MODULE: None = None _FLOAT_RELU_MODULE = nn.ReLU
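
For anyone who wants to sanity-check the suppression script in the commit message before pointing it at a real mypy log, here is a minimal sketch. It is not part of the original patch: the sample error line below is made up (though it follows the `file:line:col: error: message [code]` shape mypy emits), while the regex, the rstrip-and-append step, and the source line come from the script and the ir.py hunk above.

```
import re

# Hypothetical mypy error line; the path appears in this patch, but the
# line/column numbers and message text are illustrative only.
sample_error = 'torch/_inductor/ir.py:294:16: error: Argument 1 has incompatible type "Eq"; expected "bool"  [arg-type]'

# Same regex the commit's script uses to pull out the file, line, and error code.
match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", sample_error)
assert match is not None
file_path, line_number, error_type = match.groups()
print(file_path, line_number, error_type)  # torch/_inductor/ir.py 294 arg-type

# Same rstrip-and-append step the script applies to the offending source line
# (this one is the pre-patch line from the ir.py hunk; indentation assumed).
source_line = "        return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0))\n"
print(source_line.rstrip() + f" # type: ignore[{error_type}]\n")
```

Because the script only appends to existing lines rather than inserting new ones, later line numbers never shift; iterating over each file's error lines in reverse order is simply a conservative choice when editing files by line number.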