Remove follow_imports = skip from sympy (#118469)

dmypy silently ignores `follow_imports = skip`, so to get parity between
dmypy and mypy we have to suck it up and `# type: ignore` all of the
sympy typing problems.
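
For context, this is the per-module stanza at issue (it is the one edited in the mypy config below). Plain mypy honors the `follow_imports = skip` line and never looks inside sympy, while dmypy drops it and type-checks every sympy-touching expression:

```
[mypy-sympy.*]
ignore_missing_imports = True
follow_imports = skip
```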

The suppressions were added automatically with the following script generated by GPT-4:

```
import re

# Read the mypy error report
with open("error_file.txt", "r") as f:
    errors = f.readlines()

# Parse "path:line:col: error: message [code]" lines into {file: {line: code}}.
# If one line carries several errors, only the last parsed code is kept.
error_lines = {}
for error in errors:
    match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
    if match:
        file_path, line_number, error_type = match.groups()
        if file_path not in error_lines:
            error_lines[file_path] = {}
        error_lines[file_path][int(line_number)] = error_type

# Append an ignore comment to each offending source line
for file_path, lines in error_lines.items():
    with open(file_path, "r") as f:
        code = f.readlines()
    # Line numbers are 1-indexed; edits are in place, so order is immaterial,
    # but descending order would keep this safe if lines were ever inserted.
    for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
        code[line_number - 1] = code[line_number - 1].rstrip() + f"  # type: ignore[{error_type}]\n"
    with open(file_path, "w") as f:
        f.writelines(code)
```
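
The script expects `error_file.txt` to contain ordinary mypy output with column numbers and error codes, i.e. lines of the form `path/to/file.py:123:4: error: message [error-code]`. An invocation along these lines would produce it (hypothetical; the exact command is not recorded in this commit):

```
mypy --show-column-numbers --show-error-codes torch > error_file.txt
```

For each matching line, the script then appends `  # type: ignore[error-code]` to the named line of the named file, which is exactly the pattern visible in the hunks below.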

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118469
Approved by: https://github.com/Skylion007
ghstack dependencies: #118414, #118418, #118432, #118467, #118468
Edward Z. Yang, 2024-01-27 19:39:22 -08:00 (committed by PyTorch MergeBot)
parent 59b4d2cd40
commit cad79bd0bb
17 changed files with 112 additions and 112 deletions

View File

@@ -138,6 +138,7 @@ init_command = [
     'numpy==1.26.0 ; python_version >= "3.9"',
     'expecttest==0.1.6',
     'mypy==1.7.0',
+    'sympy==1.11.1',
     'types-requests==2.27.25',
     'types-PyYAML==6.0.7',
     'types-tabulate==0.8.8',

View File

@@ -179,7 +179,6 @@ ignore_missing_imports = True
 [mypy-sympy.*]
 ignore_missing_imports = True
-follow_imports = skip
 [mypy-hypothesis.*]
 ignore_missing_imports = True

View File

@@ -1143,7 +1143,7 @@ class Kernel(CodeGen):
 ):
 # Skip CSE since this doesn't return an expression
-if var.bounds.lower < 0:
+if var.bounds.lower < 0:  # type: ignore[operator]
 new_bounds = ValueRanges.unknown()
 if var.bounds != ValueRanges.unknown() and isinstance(
 size, sympy.Number
@@ -1154,13 +1154,13 @@ class Kernel(CodeGen):
 neg = var.bounds & ValueRanges(-sympy.oo, -1)
 new_bounds = ValueRanges(neg.lower + size, neg.upper + size)
 # We don't have a good way of representing the empty range
-if var.bounds.upper >= 0:
+if var.bounds.upper >= 0:  # type: ignore[operator]
 pos = var.bounds & ValueRanges(0, sympy.oo)
 new_bounds = new_bounds | pos
 stm = ops.add(var, self.rename_indexing(size))
 # Mixed negative and non-negative
-if var.bounds.upper >= 0:
+if var.bounds.upper >= 0:  # type: ignore[operator]
 lt = ops.lt(var, "0")
 stm = ops.where(lt, stm, var)
 new_var = self.cse.generate(self.compute, stm, bounds=new_bounds)
@@ -1310,7 +1310,7 @@ class Kernel(CodeGen):
 # adds the necessary kernel args for index expressions
 # and renames variables in index expressions to kernel arg names
 if isinstance(index, (list, tuple)):
-return [self.rename_indexing(x) for x in index]
+return [self.rename_indexing(x) for x in index]  # type: ignore[return-value]
 index = V.graph.sizevars.simplify(index)
 sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name)
 replacements = {

View File

@@ -311,7 +311,7 @@ def argmax_argmin_prefix(reduction_type, src_dtype, tmpvar):
 @functools.lru_cache
 def stride_at(index: sympy.Expr, var: sympy.Symbol):
 replacement = {var: var + 1}
-new_index = sympy_subs(index, replacement)
+new_index = sympy_subs(index, replacement)  # type: ignore[arg-type]
 return sympy.simplify(new_index - index)
@@ -620,10 +620,10 @@ class CppCSEVariable(CSEVariable):
 """
 for s in index.free_symbols:
 if s in V.kernel.itervars:
-self.dependent_itervars.add(s)
-elif s.name in V.kernel.cse.varname_map:
+self.dependent_itervars.add(s)  # type: ignore[arg-type]
+elif s.name in V.kernel.cse.varname_map:  # type: ignore[attr-defined]
 self.dependent_itervars.update(
-V.kernel.cse.varname_map[s.name].dependent_itervars
+V.kernel.cse.varname_map[s.name].dependent_itervars  # type: ignore[attr-defined]
 )
 def depends_on(self, itervar: sympy.Symbol):
@@ -1512,10 +1512,10 @@ class CppKernel(Kernel):
 Check if an index has free symbol CppCSEVariable that depends on `itervar`.
 """
 return any(
-self.cse.varname_map[s.name].depends_on(itervar)
+self.cse.varname_map[s.name].depends_on(itervar)  # type: ignore[attr-defined]
 for s in index.free_symbols
-if s.name in self.cse.varname_map
-and isinstance(self.cse.varname_map[s.name], CppCSEVariable)
+if s.name in self.cse.varname_map  # type: ignore[attr-defined]
+and isinstance(self.cse.varname_map[s.name], CppCSEVariable)  # type: ignore[attr-defined]
 )
 def index_depends_on(self, index: sympy.Expr, itervar: sympy.Symbol):
@@ -1894,9 +1894,9 @@ class CppVecKernel(CppKernel):
 )
 replacements = {}
 for indirect_var in (
-self.cse.varname_map[s.name]
+self.cse.varname_map[s.name]  # type: ignore[attr-defined]
 for s in index.free_symbols
-if s.name.startswith("tmp")
+if s.name.startswith("tmp")  # type: ignore[attr-defined]
 ):
 assert isinstance(indirect_var, CppCSEVariable)
 if indirect_var.is_vec:
@@ -1911,7 +1911,7 @@ class CppVecKernel(CppKernel):
 )
 else:
 load_mask = f"{self._load_mask} != 0"
-index = sympy_subs(index, replacements)
+index = sympy_subs(index, replacements)  # type: ignore[arg-type]
 index = self.scale_index_with_offset(
 index, itervar_idx=self.tiling_idx, offset=itervar_inner
 )
@@ -1934,7 +1934,7 @@ class CppVecKernel(CppKernel):
 code.writeline(f"if ({load_mask})")
 stack.enter_context(code.indent())
 code.writeline(f"tmpbuf[{itervar_inner}] = {rhs};")
-load_line = self._get_vec_load_line("tmpbuf.data()", 0, dtype)
+load_line = self._get_vec_load_line("tmpbuf.data()", 0, dtype)  # type: ignore[arg-type]
 code.writeline(f"return {load_line};")
 code.writeline("()")
 csevar = self.cse.generate(buffer, code)
@@ -2296,7 +2296,7 @@ class CppTile2DKernel(CppVecKernel):
 # vector load inside the kernel inner loop
 loadbuf = f"{tile_var} + {cexpr_index(inner * self.tiling_factor)}"
 dtype = V.graph.get_dtype(name)
-line = self._get_vec_load_line(loadbuf, 0, dtype)
+line = self._get_vec_load_line(loadbuf, 0, dtype)  # type: ignore[arg-type]
 csevar = self.cse.generate(self.loads, line)
 csevar.update_on_args("load", (name, index), {})
 assert isinstance(csevar, CppCSEVariable)

View File

@@ -326,7 +326,7 @@ class TemporalSplit(ClearCacheOnAllocateMixin, AllocationTreeNode):
 @cache_on_self
 def get_symbolic_size(self) -> sympy.Expr:
 if not self.allocations:
-return 0
+return 0  # type: ignore[return-value]
 return sympy.Max(*[x.get_symbolic_size() for x in self.allocations])
 def is_empty(self):

View File

@@ -197,10 +197,10 @@ class BlockPtrOptions:
 for i in range(len(self.shape)):
 if (
 self.block_shape[i] != "1"
-and not V.graph.sizevars.statically_known_equals(self.strides[i], 0)
+and not V.graph.sizevars.statically_known_equals(self.strides[i], 0)  # type: ignore[arg-type]
 and not V.graph.sizevars.statically_known_multiple_of(
 self.shape[i],
-config.triton.max_block[self.block_shape[i][0]],
+config.triton.max_block[self.block_shape[i][0]],  # type: ignore[arg-type]
 )
 and not (V.kernel.no_x_dim and self.block_shape[i] == "XBLOCK")
 ):
@@ -1280,7 +1280,7 @@ class TritonKernel(Kernel):
 if hint > threshold:
 return False
 # will need to recompile if we cross a larger power of 2 boundary
-V.graph.sizevars.guard_leq(self.numels[-1], next_power_of_2(hint))
+V.graph.sizevars.guard_leq(self.numels[-1], next_power_of_2(hint))  # type: ignore[arg-type]
 return True
 def set_last_usage(self, nodes):
@@ -1370,7 +1370,7 @@ class TritonKernel(Kernel):
 for length_group in lengths:
 return_getters = []
 for size in length_group:
-if sv.statically_known_equals(size, 1):
+if sv.statically_known_equals(size, 1):  # type: ignore[arg-type]
 return_getters.append(lambda _: sympy.Integer(0))
 continue
@@ -1461,7 +1461,7 @@ class TritonKernel(Kernel):
 if symbol not in self.range_tree_nodes:
 # Non-iterated variables, e.g. strides
 continue
-entry = self.range_tree_nodes[symbol]
+entry = self.range_tree_nodes[symbol]  # type: ignore[index]
 assert isinstance(entry.parent, IterationRangesRoot)
 index_numels[entry.parent.index] *= entry.length
@@ -1469,7 +1469,7 @@ class TritonKernel(Kernel):
 # numels, then it must be broadcasted.
 simplify = V.graph.sizevars.simplify
 return any(
-simplify(idx_range) != simplify(iter_range)
+simplify(idx_range) != simplify(iter_range)  # type: ignore[arg-type]
 for idx_range, iter_range in zip(index_numels, self.numels)
 )
@@ -1603,7 +1603,7 @@ class TritonKernel(Kernel):
 [m[s] for s in strides],
 m[offset],
 range_trees,
-mask_vars,
+mask_vars,  # type: ignore[arg-type]
 )
 expand_str = None
@@ -1630,7 +1630,7 @@ class TritonKernel(Kernel):
 self.filter_masks(mask_vars)
 mask_str = " & ".join(sorted(map(str, mask_vars))) if mask_vars else "None"
-return IndexingOptions(index_str, mask_vars, mask_str, expand_str, has_rindex)
+return IndexingOptions(index_str, mask_vars, mask_str, expand_str, has_rindex)  # type: ignore[arg-type]
 def active_range_trees(self, reorder=False):
 trees = [
@@ -1647,7 +1647,7 @@ class TritonKernel(Kernel):
 def filter_masks(self, mask_vars):
 for tree in self.range_trees:
 # Masks are superfluous if we only have one element
-if V.graph.sizevars.statically_known_equals(tree.numel, 1):
+if V.graph.sizevars.statically_known_equals(tree.numel, 1):  # type: ignore[arg-type]
 mask_vars.discard(f"{tree.prefix}mask")
 continue
 # Masks are superfluous if numel is a multiple of BLOCK
@@ -1659,7 +1659,7 @@ class TritonKernel(Kernel):
 # never need to do a masked load to handle stragglers at the end.
 # It's faster to avoid masking at all. But it is sound to always
 # mask.
-if V.graph.sizevars.statically_known_multiple_of(tree.numel, max_block):
+if V.graph.sizevars.statically_known_multiple_of(tree.numel, max_block):  # type: ignore[arg-type]
 mask_vars.discard(f"{tree.prefix}mask")
 def var_ranges(self):
@@ -1676,13 +1676,13 @@ class TritonKernel(Kernel):
 # if indexing expression is complicated, we precompute it on the host side
 # and send the result as a kernel argument
 replacements = {}
-for ps in self.range_tree_nodes[sym].precomputed_args():
+for ps in self.range_tree_nodes[sym].precomputed_args():  # type: ignore[index]
 replacements[ps] = V.graph.sizevars.lookup_precomputed_size(ps)
 if len(replacements) > 0:
-self.range_tree_nodes[sym].expr = sympy_subs(
-self.range_tree_nodes[sym].expr, replacements
+self.range_tree_nodes[sym].expr = sympy_subs(  # type: ignore[index]
+self.range_tree_nodes[sym].expr, replacements  # type: ignore[index]
 )
-self.range_tree_nodes[sym].codegen()
+self.range_tree_nodes[sym].codegen()  # type: ignore[index]
 return expr
 @contextlib.contextmanager
@@ -1735,7 +1735,7 @@ class TritonKernel(Kernel):
 {xindex: 512, rindex: 1024}
 """
 index_to_tile_indexes = {k: v.expr for k, v in self.range_tree_nodes.items()}
-index_in_tile_vars = sympy_subs(index, index_to_tile_indexes)
+index_in_tile_vars = sympy_subs(index, index_to_tile_indexes)  # type: ignore[arg-type]
 strides = {}
 for range_tree in self.range_trees:
 s = sympy_index_symbol(range_tree.name)
@@ -1883,7 +1883,7 @@ class TritonKernel(Kernel):
 result_var = self.cse.generate(load_buffer, line)
 assert isinstance(result_var, TritonCSEVariable)
-result_var.mask_vars = indexing.mask_vars
+result_var.mask_vars = indexing.mask_vars  # type: ignore[assignment]
 if append_broadcast:
 line = f"tl.broadcast_to({result_var}, {append_broadcast})"
@@ -2410,7 +2410,7 @@ class TritonKernel(Kernel):
 # note that random seed is put in V.graph.constants
 const_tensor = V.graph.constants[arg_name]
 result.writeline(
-f"{var_name} = rand_strided({V.graph.sizevars.size_hints(const_tensor.size())}, {V.graph.sizevars.size_hints(const_tensor.stride())}, device='{const_tensor.device}', dtype={const_tensor.dtype})"  # noqa: B950 line too long
+f"{var_name} = rand_strided({V.graph.sizevars.size_hints(const_tensor.size())}, {V.graph.sizevars.size_hints(const_tensor.stride())}, device='{const_tensor.device}', dtype={const_tensor.dtype})"  # type: ignore[arg-type]  # noqa: B950 line too long
 )
 elif isinstance(arg_sig, SizeArg):
 symval_hint = V.graph.sizevars.size_hint(arg_sig.expr)
@@ -3128,9 +3128,9 @@ class TritonScheduling(BaseScheduling):
 # Only install guards for 32-bit indexing as there is no correctness
 # issue with using 64-bit for everything
-V.graph.sizevars.guard_leq(numel, int_max)
+V.graph.sizevars.guard_leq(numel, int_max)  # type: ignore[arg-type]
 for size in buf_sizes:
-V.graph.sizevars.guard_leq(size, int_max)
+V.graph.sizevars.guard_leq(size, int_max)  # type: ignore[arg-type]
 return True
 @staticmethod

View File

@@ -81,7 +81,7 @@ def config_of(args: List[Union[TensorArg, SizeArg]]) -> instance_descriptor:
 return False
 if isinstance(x.expr, float):
 return False
-return V.graph.sizevars.statically_known_multiple_of(x.expr, alignment)
+return V.graph.sizevars.statically_known_multiple_of(x.expr, alignment)  # type: ignore[arg-type]
 raise NotImplementedError(f"unhandled {type(x)}: {x}")
 if config.triton.divisible_by_16:

View File

@@ -757,17 +757,17 @@ class WrapperCodeGen(CodeGen):
 )
 for name, shape in graph_inputs_expr:
-shape = V.graph.sizevars.simplify(shape)
+shape = V.graph.sizevars.simplify(shape)  # type: ignore[arg-type]
 if shape in needed:
-needed.remove(shape)
+needed.remove(shape)  # type: ignore[arg-type]
 code.writeline(f"{self.declare}{shape} = {name}{self.ending}")
 for name, value in graph_inputs_tensors:
 shapes = value.get_size()
 for dim, shape in enumerate(shapes):
-shape = V.graph.sizevars.simplify(shape)
+shape = V.graph.sizevars.simplify(shape)  # type: ignore[arg-type]
 if shape in needed:
-needed.remove(shape)
+needed.remove(shape)  # type: ignore[arg-type]
 code.writeline(
 f"{self.declare}{shape} = {sizeof(name)}[{dim}]{self.ending}"
 )
@@ -775,9 +775,9 @@ class WrapperCodeGen(CodeGen):
 for name, value in graph_inputs_tensors:
 shapes = value.get_stride()
 for dim, shape in enumerate(shapes):
-shape = V.graph.sizevars.simplify(shape)
+shape = V.graph.sizevars.simplify(shape)  # type: ignore[arg-type]
 if shape in needed:
-needed.remove(shape)
+needed.remove(shape)  # type: ignore[arg-type]
 code.writeline(
 f"{self.declare}{shape} = {strideof(name)}[{dim}]{self.ending}"
 )

View File

@@ -30,7 +30,7 @@ Dep = Union["MemoryDep", "StarDep", "WeakDep"]
 class MemoryDep(typing.NamedTuple):
 name: str
-index: sympy.Expr
+index: sympy.Expr  # type: ignore[assignment]
 var_names: Tuple[sympy.Symbol, ...]
 size: Tuple[sympy.Expr, ...]
@@ -77,7 +77,7 @@ class MemoryDep(typing.NamedTuple):
 return isinstance(self.index, (int, sympy.Integer))
 def is_indirect(self) -> bool:
-return any(is_indirect(v.name) for v in self.index.free_symbols)
+return any(is_indirect(v.name) for v in self.index.free_symbols)  # type: ignore[attr-defined]
 class StarDep(typing.NamedTuple):
@@ -146,7 +146,7 @@ class WeakDep(typing.NamedTuple):
 class IndexExprDep(typing.NamedTuple):
-index: sympy.Expr
+index: sympy.Expr  # type: ignore[assignment]
 var_names: Tuple[sympy.Symbol, ...]
 size: Tuple[sympy.Expr, ...]
@@ -235,7 +235,7 @@ class _RecordLoadStoreInner(V.MockHandler):  # type: ignore[name-defined]
 k for k, v in zip(self._var_ranges.keys(), sizes) if v != 1
 )
 sizes = tuple(v for v in sizes if v != 1)
-return index, var_names, sizes
+return index, var_names, sizes  # type: ignore[return-value]
 # Try to further simplify the indexes even if simplify_loops didn't
 # convert it to the simplest form because of the interference from
@@ -269,7 +269,7 @@ class _RecordLoadStoreInner(V.MockHandler):  # type: ignore[name-defined]
 # downstream users won't. Normalize this away.
 new_vars.pop()
 new_sizes.pop()
-return index, tuple(new_vars), tuple(new_sizes)
+return index, tuple(new_vars), tuple(new_sizes)  # type: ignore[arg-type]
 def load(self, name: str, index: sympy.Expr) -> str:
 self._reads.add(MemoryDep(name, *self.canonicalize(index)))

View File

@@ -368,7 +368,7 @@ def cat_tuned_op(match, inputs, dim, *, op, shape_of):
 if new_size is None:
 new_size = shape
 else:
-new_size[notdim] = V.graph.sizevars.guard_equals(
+new_size[notdim] = V.graph.sizevars.guard_equals(  # type: ignore[call-overload]
 shape[notdim], new_size[notdim]
 )
 new_size[dim] += shape[dim]

View File

@@ -208,9 +208,9 @@ def ir_node_to_tensor(x, guard_shape=True):
 size = [shape_fn(s) for s in x.get_size()]
 stride: StrideType
 if is_storage_and_layout(x):
-stride = [shape_fn(s) for s in x.get_layout().stride]
+stride = [shape_fn(s) for s in x.get_layout().stride]  # type: ignore[misc]
 else:
-stride = make_contiguous_strides_for(size)
+stride = make_contiguous_strides_for(size)  # type: ignore[arg-type]
 dtype = x.get_dtype()
 device = x.get_device()
 size = convert_shape_to_symint(size)
@@ -294,7 +294,7 @@ class IRNode:
 return sympy_product(self.get_size())
 def is_zero_elements(self):
-return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0))
+return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0))  # type: ignore[arg-type]
 def realize(self):
 """
@@ -1150,7 +1150,7 @@ class Reduction(Loops):
 ):
 reindex = View.dynamic_reshape_indexer(reduction_ranges, [reduction_numel])
 need_mask = not V.graph.sizevars.is_expr_static_and_true(
-sympy.Eq(reduction_numel % split, 0)
+sympy.Eq(reduction_numel % split, 0)  # type: ignore[arg-type]
 )
 def wrapper_fn(index, reduction_index):
@@ -1286,7 +1286,7 @@ class Reduction(Loops):
 wrapper_fn,
 ranges,
 reduction_ranges,
-[*ranges, split],
+[*ranges, split],  # type: ignore[list-item]
 [block_size],
 reduction_type,
 split,
@@ -1522,7 +1522,7 @@ class WelfordReduction(Reduction):
 """
 reduction_numel = sympy_product(reduction_ranges)
 need_mask = not V.graph.sizevars.is_expr_static_and_true(
-sympy.Eq(reduction_numel % split, 0)
+sympy.Eq(reduction_numel % split, 0)  # type: ignore[arg-type]
 )
 if need_mask and reduction_type != "welford_combine":
@@ -1562,7 +1562,7 @@ class WelfordReduction(Reduction):
 )
 for loader in inner_fns
 ),
-[*ranges, split],
+[*ranges, split],  # type: ignore[list-item]
 [block_size],
 reduction_type,
 reduction_hint,
@@ -1587,7 +1587,7 @@ class WelfordReduction(Reduction):
 for i in intermediates
 ),
 ranges,
-[split],
+[split],  # type: ignore[list-item]
 # welford_reduce turns one input into three outputs, which are combined with welford_combine
 "welford_combine",
 reduction_hint,
@@ -1680,7 +1680,7 @@ class Scan(Loops):
 scan_numel = sizevars.simplify(sympy_product(scan_ranges))
 # Scan with a single element is just a copy
-if sizevars.is_expr_static_and_true(sympy.Le(scan_numel, 1)):
+if sizevars.is_expr_static_and_true(sympy.Le(scan_numel, 1)):  # type: ignore[arg-type]
 return Pointwise.create(
 device=device,
 dtype=dtype,
@@ -1993,7 +1993,7 @@ class PermuteView(BaseView):
 def make_reindexer(self):
 inv = {j: i for i, j in enumerate(self.dims)}
-inv = [inv[i] for i in range(len(self.dims))]
+inv = [inv[i] for i in range(len(self.dims))]  # type: ignore[index]
 assert set(inv) == set(range(len(self.dims)))
 def reindex(index):
@@ -2214,12 +2214,12 @@ class View(GenericView):
 while stack_old:
 size_old = stack_old.pop()
-V.graph.sizevars.guard_equals(size_old, 1)
+V.graph.sizevars.guard_equals(size_old, 1)  # type: ignore[arg-type]
 view_expr.append(sympy.Integer(0))
 while stack_new:
 var, size_new = stack_new.pop()
-V.graph.sizevars.guard_equals(size_new, 1)
+V.graph.sizevars.guard_equals(size_new, 1)  # type: ignore[arg-type]
 view_expr = list(reversed(view_expr))
 assert len(view_expr) == len(old_size)
@@ -2227,7 +2227,7 @@ class View(GenericView):
 def reindex(index):
 assert len(index) == len(vars), (len(index), len(vars))
 replacements = dict(zip(vars, index))
-return tuple(sympy_subs(x, replacements) for x in view_expr)
+return tuple(sympy_subs(x, replacements) for x in view_expr)  # type: ignore[arg-type]
 return reindex
@@ -2487,7 +2487,7 @@ class Layout(IRNode):
 if ndim not in [4, 5]:
 return False
 for left, right, size in zip(
-self.stride, make_channels_last_strides_for(self.size), self.size
+self.stride, make_channels_last_strides_for(self.size), self.size  # type: ignore[arg-type]
 ):
 if size != 1 and left != right:
 return False
@@ -2564,7 +2564,7 @@ class Layout(IRNode):
 )
 def storage_size(self) -> sympy.Expr:
-return compute_required_storage_length(self.size, self.stride, self.offset)
+return compute_required_storage_length(self.size, self.stride, self.offset)  # type: ignore[arg-type, return-value]
 class FixedLayout(Layout):
@@ -2583,9 +2583,9 @@ class FixedLayout(Layout):
 super().__init__(
 device,
 dtype,
-size,
+size,  # type: ignore[arg-type]
 stride,
-offset,
+offset,  # type: ignore[arg-type]
 )
 def make_indexer(self):
@@ -2715,7 +2715,7 @@ class AliasedLayout(Layout):
 return True
 from .compile_fx import ALIGNMENT
-return V.graph.sizevars.statically_known_multiple_of(offset, ALIGNMENT)
+return V.graph.sizevars.statically_known_multiple_of(offset, ALIGNMENT)  # type: ignore[arg-type]
 class NoneLayout(IRNode):
@@ -2882,7 +2882,7 @@ class Buffer(IRNode):
 self.layout = self.layout.as_same_order(stride)
 def is_zero_elements(self):
-return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0))
+return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0))  # type: ignore[arg-type]
 def make_loader(self):
 # Loading from a zero-element buffer is a no-op
@@ -3177,7 +3177,7 @@ class ComputedBuffer(Buffer):
 else:
 indices = index_vars
 stride_lengths = [
-V.graph.sizevars.stride_hints(expr, indices) for expr in reads
+V.graph.sizevars.stride_hints(expr, indices) for expr in reads  # type: ignore[arg-type]
 ]
 from .scheduler import pick_loop_order
@@ -3907,13 +3907,13 @@ class ExternKernel(InputsKernel):
 else:
 type_ = None
 kwargs.append(
-V.graph.wrapper_code.val_to_cpp_arg_str(
+V.graph.wrapper_code.val_to_cpp_arg_str(  # type: ignore[arg-type]
 type_, v, self.is_legacy_abi_kernel()
 )
 )
 else:
 kwargs = [
-f"{k}={V.graph.wrapper_code.val_to_arg_str(v)}"
+f"{k}={V.graph.wrapper_code.val_to_arg_str(v)}"  # type: ignore[misc]
 for k, v in self.kwargs.items()
 ]
 return kwargs
@@ -3962,7 +3962,7 @@ class ExternKernel(InputsKernel):
 _, add_var = var_builder("c")
 replacement = dict(zip(index_vars, reindex([add_var(x) for x in new_sizes])))
-index = sympy_subs(sympy.expand(index), replacement)
+index = sympy_subs(sympy.expand(index), replacement)  # type: ignore[arg-type]
 return index, tuple(new_sizes)
 def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:

View File

@@ -243,14 +243,14 @@ def conv_layout(
 ir.ir_node_to_tensor(weight, guard_shape=True),
 ir.ir_node_to_tensor(bias, guard_shape=True),
 stride,
-tuple(V.graph.sizevars.size_hint(p) for p in padding),
+tuple(V.graph.sizevars.size_hint(p) for p in padding),  # type: ignore[arg-type]
 dilation,
 transposed,
-tuple(V.graph.sizevars.size_hint(p) for p in output_padding),
+tuple(V.graph.sizevars.size_hint(p) for p in output_padding),  # type: ignore[arg-type]
 groups,
 )
 sizes = ir.convert_shape_to_inductor(output.size())
-stride = ir.convert_shape_to_inductor(output.stride())
+stride = ir.convert_shape_to_inductor(output.stride())  # type: ignore[assignment]
 return ir.FixedLayout(
 x.get_device(),
@@ -419,7 +419,7 @@ def convolution(
 and not transposed
 and is_zeros(output_padding)
 # there are some odd models where this check fails (e.g. shufflenet_v2_x1_0)
-and V.graph.sizevars.statically_known_equals(in_chan, x.get_size()[1])
+and V.graph.sizevars.statically_known_equals(in_chan, x.get_size()[1])  # type: ignore[arg-type]
 ):
 if (
 is_ones(kernel_shape)

View File

@@ -36,7 +36,7 @@ def filtered_configs(
 m = max(
 next_power_of_2(
 V.graph.sizevars.size_hint(
-m, fallback=torch._inductor.config.unbacked_symint_fallback
+m, fallback=torch._inductor.config.unbacked_symint_fallback  # type: ignore[arg-type]
 )
 ),
 min_block_size,
@@ -44,7 +44,7 @@ def filtered_configs(
 n = max(
 next_power_of_2(
 V.graph.sizevars.size_hint(
-n, fallback=torch._inductor.config.unbacked_symint_fallback
+n, fallback=torch._inductor.config.unbacked_symint_fallback  # type: ignore[arg-type]
 )
 ),
 min_block_size,
@@ -52,7 +52,7 @@ def filtered_configs(
 k = max(
 next_power_of_2(
 V.graph.sizevars.size_hint(
-k, fallback=torch._inductor.config.unbacked_symint_fallback
+k, fallback=torch._inductor.config.unbacked_symint_fallback  # type: ignore[arg-type]
 )
 ),
 min_block_size,

View File

@@ -983,8 +983,8 @@ def pointwise_cat(inputs, dim=0):
 inputs_ranges: List[Tuple[sympy.Expr, sympy.Expr]] = []
 prev_end = 0
 for inp in inputs:
-inputs_ranges.append((prev_end, prev_end + inp.get_size()[dim]))
-prev_end = inputs_ranges[-1][-1]
+inputs_ranges.append((prev_end, prev_end + inp.get_size()[dim]))  # type: ignore[arg-type]
+prev_end = inputs_ranges[-1][-1]  # type: ignore[assignment]
 inputs_loaders = [inp.make_loader() for inp in inputs]
@@ -1215,7 +1215,7 @@ def unfold(x, dimension, size, step):
 dim_size = sizes[dim]
 sizevars = V.graph.sizevars
 sizevars.guard_leq(size, dim_size)
-sizevars.guard_lt(0, step)
+sizevars.guard_lt(0, step)  # type: ignore[arg-type]
 new_dim_size = FloorDiv(dim_size - size, step) + 1
 if sizevars.size_hint(dim_size) > 0:
@@ -2371,8 +2371,8 @@ def select_scatter(x, src, dim: int, index: int):
 dim = _validate_dim(x, dim, 0)
 if V.graph.sizevars.evaluate_expr(sympy.Lt(index, 0)):
 index = index + x.get_size()[dim]
-V.graph.sizevars.guard_leq(0, index)
-V.graph.sizevars.guard_lt(index, x.get_size()[dim])
+V.graph.sizevars.guard_leq(0, index)  # type: ignore[arg-type]
+V.graph.sizevars.guard_lt(index, x.get_size()[dim])  # type: ignore[arg-type]
 src = expand(unsqueeze(src, dim), x.get_size())
 src_loader = src.make_loader()
@@ -3673,7 +3673,7 @@ def constant_pad_nd(x, padding, fill_value=0):
 # if padding is a complicated expression, hoist it
 bounds_precomp: List[Tuple[sympy.Symbol, Any]] = []
 for l, h in bounds:
-bounds_precomp.append((V.graph.sizevars.lookup_precomputed_size(l), h))
+bounds_precomp.append((V.graph.sizevars.lookup_precomputed_size(l), h))  # type: ignore[arg-type]
 output_size = list(sizes[:n])
 mask_sizes = []
@@ -3770,7 +3770,7 @@ def pooling_size(x, i, kernel_size, stride, padding, ceil_mode):
 if V.graph.sizevars.size_hint((x_alt - 1) * stride[i] - x - padding[i]) >= 0:
 # Sliding windows must start within the input or left padding
 x_alt -= 1  # type: ignore[assignment]
-V.graph.sizevars.guard_leq(0, x_alt * stride[i] - x - padding[i])
+V.graph.sizevars.guard_leq(0, x_alt * stride[i] - x - padding[i])  # type: ignore[arg-type]
 if V.graph.sizevars.size_hint(x_out - x_alt) == 0:
 # ceil mode is actually a no-op, lets guard on that
 V.graph.sizevars.guard_equals(x_out, x_alt)

View File

@@ -128,7 +128,7 @@ class SizeVarAllocator:
 # actual iteration range is to size-1
 iter_ranges_zero = {k: 0 for k, v in var_ranges.items()}
 base_lowest = sympy_subs(base, iter_ranges_zero)
-if self.statically_known_leq(0, base_lowest):
+if self.statically_known_leq(0, base_lowest):  # type: ignore[arg-type]
 # can't replace with indexing div if base can be negative
 base_pos = True
 else:
@@ -272,7 +272,7 @@ class SizeVarAllocator:
 """
 Returns a bool indicating if it is sound to optimize as if left and right are equal.
 """
-return self.is_expr_static_and_true(sympy.Eq(left, right))
+return self.is_expr_static_and_true(sympy.Eq(left, right))  # type: ignore[arg-type]
 # See Note - [On Statically Known]
 def statically_known_list_equals(self, left: List[Expr], right: List[Expr]) -> bool:
@@ -307,7 +307,7 @@ class SizeVarAllocator:
 Return a bool indicating if it is sound to optimize for the numerator being a multiple of the denominator.
 """
 expr = sympy.Eq(numerator % denominator, 0)
-return self.is_expr_static_and_true(expr)
+return self.is_expr_static_and_true(expr)  # type: ignore[arg-type]
 # The guard functions require you to ALREADY KNOW that a particular
 # condition holds. If you don't know (you want to guard on an expression
@@ -316,9 +316,9 @@ class SizeVarAllocator:
 def guard_equals(self, left: Expr, right: Expr) -> Expr:
 if isinstance(left, Expr):
-left = sympy_subs(left, self.inv_precomputed_replacements)
+left = sympy_subs(left, self.inv_precomputed_replacements)  # type: ignore[arg-type]
 if isinstance(right, Expr):
-right = sympy_subs(right, self.inv_precomputed_replacements)
+right = sympy_subs(right, self.inv_precomputed_replacements)  # type: ignore[arg-type]
 assert self.shape_env.evaluate_expr(sympy.Eq(left, right))
 return left
@@ -329,16 +329,16 @@ class SizeVarAllocator:
 assert self.shape_env.evaluate_expr(sympy.Lt(left, right))
 def expect_true(self, expr: Expr, *, msg: str) -> None:
-expr = sympy_subs(expr, self.inv_precomputed_replacements)
+expr = sympy_subs(expr, self.inv_precomputed_replacements)  # type: ignore[arg-type]
 self.shape_env.defer_runtime_assert(expr, msg, fx_node=None)
 def expect_equals(self, left: Expr, right: Expr, *, msg: str) -> Expr:
 # Prefer returning the expression without unbacked symints
 if self.shape_env.is_unbacked_symint(left):
-self.expect_true(sympy.Eq(left, right), msg=msg)
+self.expect_true(sympy.Eq(left, right), msg=msg)  # type: ignore[arg-type]
 return right
 elif self.shape_env.is_unbacked_symint(right):
-self.expect_true(sympy.Eq(left, right), msg=msg)
+self.expect_true(sympy.Eq(left, right), msg=msg)  # type: ignore[arg-type]
 return left
 else:
 return self.guard_equals(left, right)
@@ -399,8 +399,8 @@ class SizeVarAllocator:
 return [self.evaluate_static_shape(x) for x in left]
 def remove_precomputed_replacements(self, expr: Expr) -> Expr:
-if any(s.name.startswith("ps") for s in expr.free_symbols):
-return sympy_subs(expr, self.inv_precomputed_replacements)
+if any(s.name.startswith("ps") for s in expr.free_symbols):  # type: ignore[attr-defined]
+return sympy_subs(expr, self.inv_precomputed_replacements)  # type: ignore[arg-type]
 return expr
 def symbolic_hint(self, expr: Expr) -> Expr:
@@ -410,7 +410,7 @@ class SizeVarAllocator:
 return expr
 free_symbols = expr.free_symbols
 if not free_symbols:
-return int(expr)
+return int(expr)  # type: ignore[return-value]
 expr = self.remove_precomputed_replacements(expr)
 return sympy_subs(expr, self.var_to_val)
@@ -422,9 +422,9 @@ class SizeVarAllocator:
 s: self.shape_env.var_to_range.get(s, None) for s in expr.free_symbols
 }
 if all(vr is not None for vr in sym_vrs.values()):
-expr_vr = bound_sympy(expr, sym_vrs)
-lower = self.size_hint(expr_vr.lower)
-upper = self.size_hint(expr_vr.upper)
+expr_vr = bound_sympy(expr, sym_vrs)  # type: ignore[arg-type]
+lower = self.size_hint(expr_vr.lower)  # type: ignore[arg-type]
+upper = self.size_hint(expr_vr.upper)  # type: ignore[arg-type]
 fallback = min(max(fallback, lower), upper)
 return fallback
 try:
@@ -522,8 +522,8 @@ class SizeVarAllocator:
 support_vars: Optional[List[sympy.Symbol]] = None,
 ) -> List[int]:
 for v in index.free_symbols:
-if v.name.startswith("indirect"):
-index = sympy_subs(index, {v: 0})
+if v.name.startswith("indirect"):  # type: ignore[attr-defined]
+index = sympy_subs(index, {v: 0})  # type: ignore[dict-item]
 result = []
 for s in self.stride_vars(index, vars, support_vars):
 try:
@@ -538,7 +538,7 @@ class SizeVarAllocator:
 order.sort(key=lambda x: (strides[x] == 0, strides[x]))
 return order
-def lookup_precomputed_size(self, expr: Expr) -> sympy.Symbol:
+def lookup_precomputed_size(self, expr: Expr) -> Expr:
 if (
 isinstance(expr, (int, sympy.Symbol, sympy.Number))
 or expr.is_number

View File

@@ -569,8 +569,8 @@ def sympy_subs(expr: sympy.Expr, replacements: Dict[sympy.Expr, Any]) -> sympy.E
 if isinstance(replacement, str):
 return sympy.Symbol(
 replacement,
-integer=replaced.is_integer,
-nonnegative=replaced.is_nonnegative,
+integer=replaced.is_integer,  # type: ignore[attr-defined]
+nonnegative=replaced.is_nonnegative,  # type: ignore[attr-defined]
 )
 else:
 return replacement
@@ -582,11 +582,11 @@ def sympy_subs(expr: sympy.Expr, replacements: Dict[sympy.Expr, Any]) -> sympy.E
 def free_symbol_startswith(index: sympy.Expr, prefix: str):
-return any(v.name.startswith(prefix) for v in index.free_symbols)
+return any(v.name.startswith(prefix) for v in index.free_symbols)  # type: ignore[attr-defined]
 def free_symbol_has(index: sympy.Expr, pattern: str):
-return any(pattern in v.name for v in index.free_symbols)
+return any(pattern in v.name for v in index.free_symbols)  # type: ignore[attr-defined]
 def is_symbolic(a: Any) -> bool:
@@ -1081,7 +1081,7 @@ def get_sympy_Expr_dtype(val: sympy.Expr) -> torch.dtype:
 assert isinstance(
 val, sympy.Expr
 ), "only support sympy.Expr as input to get_sympy_Expr_dtype"
-if val.is_integer:
+if val.is_integer:  # type: ignore[attr-defined]
 return torch.int64
 else:
 return torch.float64

View File

@@ -468,7 +468,7 @@ class ConvReLU1d(nnqat.Conv1d, nni._FusedModule):
 weight_fake_quant: fake quant module for weight
 """
-_FLOAT_MODULE = nni.ConvReLU1d
+_FLOAT_MODULE = nni.ConvReLU1d  # type: ignore[assignment]
 _FLOAT_CONV_MODULE = nn.Conv1d
 _FLOAT_BN_MODULE: None = None
 _FLOAT_RELU_MODULE = nn.ReLU
@@ -600,7 +600,7 @@ class ConvReLU2d(nnqat.Conv2d, nni._FusedModule):
 weight_fake_quant: fake quant module for weight
 """
-_FLOAT_MODULE = nni.ConvReLU2d
+_FLOAT_MODULE = nni.ConvReLU2d  # type: ignore[assignment]
 _FLOAT_CONV_MODULE = nn.Conv2d
 _FLOAT_BN_MODULE: None = None
 _FLOAT_RELU_MODULE = nn.ReLU
@@ -773,7 +773,7 @@ class ConvReLU3d(nnqat.Conv3d, nni._FusedModule):
 weight_fake_quant: fake quant module for weight
 """
-_FLOAT_MODULE = nni.ConvReLU3d
+_FLOAT_MODULE = nni.ConvReLU3d  # type: ignore[assignment]
 _FLOAT_CONV_MODULE = nn.Conv3d
 _FLOAT_BN_MODULE: None = None
 _FLOAT_RELU_MODULE = nn.ReLU