Remove follow_imports = skip from sympy (#118469)

dmypy silently ignores follow_imports = skip, so to get parity between
dmypy and mypy we have to suck it up and type: ignore all of the sympy
typing problems.

The suppressions were added automatically with the following script generated by GPT-4:

```
import re

# Read the error file
with open("error_file.txt", "r") as f:
    errors = f.readlines()

# Parse the lines with errors and error types
error_lines = {}
for error in errors:
    match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
    if match:
        file_path, line_number, error_type = match.groups()
        if file_path not in error_lines:
            error_lines[file_path] = {}
        error_lines[file_path][int(line_number)] = error_type

# Insert ignore comments in the source files
for file_path, lines in error_lines.items():
    with open(file_path, "r") as f:
        code = f.readlines()
    for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
        code[line_number - 1] = code[line_number - 1].rstrip() + f"  # type: ignore[{error_type}]\n"
    with open(file_path, "w") as f:
        f.writelines(code)
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118469
Approved by: https://github.com/Skylion007
ghstack dependencies: #118414, #118418, #118432, #118467, #118468
This commit is contained in:
Edward Z. Yang 2024-01-27 19:39:22 -08:00 committed by PyTorch MergeBot
parent 59b4d2cd40
commit cad79bd0bb
17 changed files with 112 additions and 112 deletions

View File

@ -138,6 +138,7 @@ init_command = [
'numpy==1.26.0 ; python_version >= "3.9"', 'numpy==1.26.0 ; python_version >= "3.9"',
'expecttest==0.1.6', 'expecttest==0.1.6',
'mypy==1.7.0', 'mypy==1.7.0',
'sympy==1.11.1',
'types-requests==2.27.25', 'types-requests==2.27.25',
'types-PyYAML==6.0.7', 'types-PyYAML==6.0.7',
'types-tabulate==0.8.8', 'types-tabulate==0.8.8',

View File

@ -179,7 +179,6 @@ ignore_missing_imports = True
[mypy-sympy.*] [mypy-sympy.*]
ignore_missing_imports = True ignore_missing_imports = True
follow_imports = skip
[mypy-hypothesis.*] [mypy-hypothesis.*]
ignore_missing_imports = True ignore_missing_imports = True

View File

@ -1143,7 +1143,7 @@ class Kernel(CodeGen):
): ):
# Skip CSE since this doesn't return an expression # Skip CSE since this doesn't return an expression
if var.bounds.lower < 0: if var.bounds.lower < 0: # type: ignore[operator]
new_bounds = ValueRanges.unknown() new_bounds = ValueRanges.unknown()
if var.bounds != ValueRanges.unknown() and isinstance( if var.bounds != ValueRanges.unknown() and isinstance(
size, sympy.Number size, sympy.Number
@ -1154,13 +1154,13 @@ class Kernel(CodeGen):
neg = var.bounds & ValueRanges(-sympy.oo, -1) neg = var.bounds & ValueRanges(-sympy.oo, -1)
new_bounds = ValueRanges(neg.lower + size, neg.upper + size) new_bounds = ValueRanges(neg.lower + size, neg.upper + size)
# We don't have a good way of representing the empty range # We don't have a good way of representing the empty range
if var.bounds.upper >= 0: if var.bounds.upper >= 0: # type: ignore[operator]
pos = var.bounds & ValueRanges(0, sympy.oo) pos = var.bounds & ValueRanges(0, sympy.oo)
new_bounds = new_bounds | pos new_bounds = new_bounds | pos
stm = ops.add(var, self.rename_indexing(size)) stm = ops.add(var, self.rename_indexing(size))
# Mixed negative and non-negative # Mixed negative and non-negative
if var.bounds.upper >= 0: if var.bounds.upper >= 0: # type: ignore[operator]
lt = ops.lt(var, "0") lt = ops.lt(var, "0")
stm = ops.where(lt, stm, var) stm = ops.where(lt, stm, var)
new_var = self.cse.generate(self.compute, stm, bounds=new_bounds) new_var = self.cse.generate(self.compute, stm, bounds=new_bounds)
@ -1310,7 +1310,7 @@ class Kernel(CodeGen):
# adds the necessary kernel args for index expressions # adds the necessary kernel args for index expressions
# and renames variables in index expressions to kernel arg names # and renames variables in index expressions to kernel arg names
if isinstance(index, (list, tuple)): if isinstance(index, (list, tuple)):
return [self.rename_indexing(x) for x in index] return [self.rename_indexing(x) for x in index] # type: ignore[return-value]
index = V.graph.sizevars.simplify(index) index = V.graph.sizevars.simplify(index)
sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name) sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name)
replacements = { replacements = {

View File

@ -311,7 +311,7 @@ def argmax_argmin_prefix(reduction_type, src_dtype, tmpvar):
@functools.lru_cache @functools.lru_cache
def stride_at(index: sympy.Expr, var: sympy.Symbol): def stride_at(index: sympy.Expr, var: sympy.Symbol):
replacement = {var: var + 1} replacement = {var: var + 1}
new_index = sympy_subs(index, replacement) new_index = sympy_subs(index, replacement) # type: ignore[arg-type]
return sympy.simplify(new_index - index) return sympy.simplify(new_index - index)
@ -620,10 +620,10 @@ class CppCSEVariable(CSEVariable):
""" """
for s in index.free_symbols: for s in index.free_symbols:
if s in V.kernel.itervars: if s in V.kernel.itervars:
self.dependent_itervars.add(s) self.dependent_itervars.add(s) # type: ignore[arg-type]
elif s.name in V.kernel.cse.varname_map: elif s.name in V.kernel.cse.varname_map: # type: ignore[attr-defined]
self.dependent_itervars.update( self.dependent_itervars.update(
V.kernel.cse.varname_map[s.name].dependent_itervars V.kernel.cse.varname_map[s.name].dependent_itervars # type: ignore[attr-defined]
) )
def depends_on(self, itervar: sympy.Symbol): def depends_on(self, itervar: sympy.Symbol):
@ -1512,10 +1512,10 @@ class CppKernel(Kernel):
Check if an index has free symbol CppCSEVariable that depends on `itervar`. Check if an index has free symbol CppCSEVariable that depends on `itervar`.
""" """
return any( return any(
self.cse.varname_map[s.name].depends_on(itervar) self.cse.varname_map[s.name].depends_on(itervar) # type: ignore[attr-defined]
for s in index.free_symbols for s in index.free_symbols
if s.name in self.cse.varname_map if s.name in self.cse.varname_map # type: ignore[attr-defined]
and isinstance(self.cse.varname_map[s.name], CppCSEVariable) and isinstance(self.cse.varname_map[s.name], CppCSEVariable) # type: ignore[attr-defined]
) )
def index_depends_on(self, index: sympy.Expr, itervar: sympy.Symbol): def index_depends_on(self, index: sympy.Expr, itervar: sympy.Symbol):
@ -1894,9 +1894,9 @@ class CppVecKernel(CppKernel):
) )
replacements = {} replacements = {}
for indirect_var in ( for indirect_var in (
self.cse.varname_map[s.name] self.cse.varname_map[s.name] # type: ignore[attr-defined]
for s in index.free_symbols for s in index.free_symbols
if s.name.startswith("tmp") if s.name.startswith("tmp") # type: ignore[attr-defined]
): ):
assert isinstance(indirect_var, CppCSEVariable) assert isinstance(indirect_var, CppCSEVariable)
if indirect_var.is_vec: if indirect_var.is_vec:
@ -1911,7 +1911,7 @@ class CppVecKernel(CppKernel):
) )
else: else:
load_mask = f"{self._load_mask} != 0" load_mask = f"{self._load_mask} != 0"
index = sympy_subs(index, replacements) index = sympy_subs(index, replacements) # type: ignore[arg-type]
index = self.scale_index_with_offset( index = self.scale_index_with_offset(
index, itervar_idx=self.tiling_idx, offset=itervar_inner index, itervar_idx=self.tiling_idx, offset=itervar_inner
) )
@ -1934,7 +1934,7 @@ class CppVecKernel(CppKernel):
code.writeline(f"if ({load_mask})") code.writeline(f"if ({load_mask})")
stack.enter_context(code.indent()) stack.enter_context(code.indent())
code.writeline(f"tmpbuf[{itervar_inner}] = {rhs};") code.writeline(f"tmpbuf[{itervar_inner}] = {rhs};")
load_line = self._get_vec_load_line("tmpbuf.data()", 0, dtype) load_line = self._get_vec_load_line("tmpbuf.data()", 0, dtype) # type: ignore[arg-type]
code.writeline(f"return {load_line};") code.writeline(f"return {load_line};")
code.writeline("()") code.writeline("()")
csevar = self.cse.generate(buffer, code) csevar = self.cse.generate(buffer, code)
@ -2296,7 +2296,7 @@ class CppTile2DKernel(CppVecKernel):
# vector load inside the kernel inner loop # vector load inside the kernel inner loop
loadbuf = f"{tile_var} + {cexpr_index(inner * self.tiling_factor)}" loadbuf = f"{tile_var} + {cexpr_index(inner * self.tiling_factor)}"
dtype = V.graph.get_dtype(name) dtype = V.graph.get_dtype(name)
line = self._get_vec_load_line(loadbuf, 0, dtype) line = self._get_vec_load_line(loadbuf, 0, dtype) # type: ignore[arg-type]
csevar = self.cse.generate(self.loads, line) csevar = self.cse.generate(self.loads, line)
csevar.update_on_args("load", (name, index), {}) csevar.update_on_args("load", (name, index), {})
assert isinstance(csevar, CppCSEVariable) assert isinstance(csevar, CppCSEVariable)

View File

@ -326,7 +326,7 @@ class TemporalSplit(ClearCacheOnAllocateMixin, AllocationTreeNode):
@cache_on_self @cache_on_self
def get_symbolic_size(self) -> sympy.Expr: def get_symbolic_size(self) -> sympy.Expr:
if not self.allocations: if not self.allocations:
return 0 return 0 # type: ignore[return-value]
return sympy.Max(*[x.get_symbolic_size() for x in self.allocations]) return sympy.Max(*[x.get_symbolic_size() for x in self.allocations])
def is_empty(self): def is_empty(self):

View File

@ -197,10 +197,10 @@ class BlockPtrOptions:
for i in range(len(self.shape)): for i in range(len(self.shape)):
if ( if (
self.block_shape[i] != "1" self.block_shape[i] != "1"
and not V.graph.sizevars.statically_known_equals(self.strides[i], 0) and not V.graph.sizevars.statically_known_equals(self.strides[i], 0) # type: ignore[arg-type]
and not V.graph.sizevars.statically_known_multiple_of( and not V.graph.sizevars.statically_known_multiple_of(
self.shape[i], self.shape[i],
config.triton.max_block[self.block_shape[i][0]], config.triton.max_block[self.block_shape[i][0]], # type: ignore[arg-type]
) )
and not (V.kernel.no_x_dim and self.block_shape[i] == "XBLOCK") and not (V.kernel.no_x_dim and self.block_shape[i] == "XBLOCK")
): ):
@ -1280,7 +1280,7 @@ class TritonKernel(Kernel):
if hint > threshold: if hint > threshold:
return False return False
# will need to recompile if we cross a larger power of 2 boundary # will need to recompile if we cross a larger power of 2 boundary
V.graph.sizevars.guard_leq(self.numels[-1], next_power_of_2(hint)) V.graph.sizevars.guard_leq(self.numels[-1], next_power_of_2(hint)) # type: ignore[arg-type]
return True return True
def set_last_usage(self, nodes): def set_last_usage(self, nodes):
@ -1370,7 +1370,7 @@ class TritonKernel(Kernel):
for length_group in lengths: for length_group in lengths:
return_getters = [] return_getters = []
for size in length_group: for size in length_group:
if sv.statically_known_equals(size, 1): if sv.statically_known_equals(size, 1): # type: ignore[arg-type]
return_getters.append(lambda _: sympy.Integer(0)) return_getters.append(lambda _: sympy.Integer(0))
continue continue
@ -1461,7 +1461,7 @@ class TritonKernel(Kernel):
if symbol not in self.range_tree_nodes: if symbol not in self.range_tree_nodes:
# Non-iterated variables, e.g. strides # Non-iterated variables, e.g. strides
continue continue
entry = self.range_tree_nodes[symbol] entry = self.range_tree_nodes[symbol] # type: ignore[index]
assert isinstance(entry.parent, IterationRangesRoot) assert isinstance(entry.parent, IterationRangesRoot)
index_numels[entry.parent.index] *= entry.length index_numels[entry.parent.index] *= entry.length
@ -1469,7 +1469,7 @@ class TritonKernel(Kernel):
# numels, then it must be broadcasted. # numels, then it must be broadcasted.
simplify = V.graph.sizevars.simplify simplify = V.graph.sizevars.simplify
return any( return any(
simplify(idx_range) != simplify(iter_range) simplify(idx_range) != simplify(iter_range) # type: ignore[arg-type]
for idx_range, iter_range in zip(index_numels, self.numels) for idx_range, iter_range in zip(index_numels, self.numels)
) )
@ -1603,7 +1603,7 @@ class TritonKernel(Kernel):
[m[s] for s in strides], [m[s] for s in strides],
m[offset], m[offset],
range_trees, range_trees,
mask_vars, mask_vars, # type: ignore[arg-type]
) )
expand_str = None expand_str = None
@ -1630,7 +1630,7 @@ class TritonKernel(Kernel):
self.filter_masks(mask_vars) self.filter_masks(mask_vars)
mask_str = " & ".join(sorted(map(str, mask_vars))) if mask_vars else "None" mask_str = " & ".join(sorted(map(str, mask_vars))) if mask_vars else "None"
return IndexingOptions(index_str, mask_vars, mask_str, expand_str, has_rindex) return IndexingOptions(index_str, mask_vars, mask_str, expand_str, has_rindex) # type: ignore[arg-type]
def active_range_trees(self, reorder=False): def active_range_trees(self, reorder=False):
trees = [ trees = [
@ -1647,7 +1647,7 @@ class TritonKernel(Kernel):
def filter_masks(self, mask_vars): def filter_masks(self, mask_vars):
for tree in self.range_trees: for tree in self.range_trees:
# Masks are superfluous if we only have one element # Masks are superfluous if we only have one element
if V.graph.sizevars.statically_known_equals(tree.numel, 1): if V.graph.sizevars.statically_known_equals(tree.numel, 1): # type: ignore[arg-type]
mask_vars.discard(f"{tree.prefix}mask") mask_vars.discard(f"{tree.prefix}mask")
continue continue
# Masks are superfluous if numel is a multiple of BLOCK # Masks are superfluous if numel is a multiple of BLOCK
@ -1659,7 +1659,7 @@ class TritonKernel(Kernel):
# never need to do a masked load to handle stragglers at the end. # never need to do a masked load to handle stragglers at the end.
# It's faster to avoid masking at all. But it is sound to always # It's faster to avoid masking at all. But it is sound to always
# mask. # mask.
if V.graph.sizevars.statically_known_multiple_of(tree.numel, max_block): if V.graph.sizevars.statically_known_multiple_of(tree.numel, max_block): # type: ignore[arg-type]
mask_vars.discard(f"{tree.prefix}mask") mask_vars.discard(f"{tree.prefix}mask")
def var_ranges(self): def var_ranges(self):
@ -1676,13 +1676,13 @@ class TritonKernel(Kernel):
# if indexing expression is complicated, we precompute it on the host side # if indexing expression is complicated, we precompute it on the host side
# and send the result as a kernel argument # and send the result as a kernel argument
replacements = {} replacements = {}
for ps in self.range_tree_nodes[sym].precomputed_args(): for ps in self.range_tree_nodes[sym].precomputed_args(): # type: ignore[index]
replacements[ps] = V.graph.sizevars.lookup_precomputed_size(ps) replacements[ps] = V.graph.sizevars.lookup_precomputed_size(ps)
if len(replacements) > 0: if len(replacements) > 0:
self.range_tree_nodes[sym].expr = sympy_subs( self.range_tree_nodes[sym].expr = sympy_subs( # type: ignore[index]
self.range_tree_nodes[sym].expr, replacements self.range_tree_nodes[sym].expr, replacements # type: ignore[index]
) )
self.range_tree_nodes[sym].codegen() self.range_tree_nodes[sym].codegen() # type: ignore[index]
return expr return expr
@contextlib.contextmanager @contextlib.contextmanager
@ -1735,7 +1735,7 @@ class TritonKernel(Kernel):
{xindex: 512, rindex: 1024} {xindex: 512, rindex: 1024}
""" """
index_to_tile_indexes = {k: v.expr for k, v in self.range_tree_nodes.items()} index_to_tile_indexes = {k: v.expr for k, v in self.range_tree_nodes.items()}
index_in_tile_vars = sympy_subs(index, index_to_tile_indexes) index_in_tile_vars = sympy_subs(index, index_to_tile_indexes) # type: ignore[arg-type]
strides = {} strides = {}
for range_tree in self.range_trees: for range_tree in self.range_trees:
s = sympy_index_symbol(range_tree.name) s = sympy_index_symbol(range_tree.name)
@ -1883,7 +1883,7 @@ class TritonKernel(Kernel):
result_var = self.cse.generate(load_buffer, line) result_var = self.cse.generate(load_buffer, line)
assert isinstance(result_var, TritonCSEVariable) assert isinstance(result_var, TritonCSEVariable)
result_var.mask_vars = indexing.mask_vars result_var.mask_vars = indexing.mask_vars # type: ignore[assignment]
if append_broadcast: if append_broadcast:
line = f"tl.broadcast_to({result_var}, {append_broadcast})" line = f"tl.broadcast_to({result_var}, {append_broadcast})"
@ -2410,7 +2410,7 @@ class TritonKernel(Kernel):
# note that random seed is put in V.graph.constants # note that random seed is put in V.graph.constants
const_tensor = V.graph.constants[arg_name] const_tensor = V.graph.constants[arg_name]
result.writeline( result.writeline(
f"{var_name} = rand_strided({V.graph.sizevars.size_hints(const_tensor.size())}, {V.graph.sizevars.size_hints(const_tensor.stride())}, device='{const_tensor.device}', dtype={const_tensor.dtype})" # noqa: B950 line too long f"{var_name} = rand_strided({V.graph.sizevars.size_hints(const_tensor.size())}, {V.graph.sizevars.size_hints(const_tensor.stride())}, device='{const_tensor.device}', dtype={const_tensor.dtype})" # type: ignore[arg-type] # noqa: B950 line too long
) )
elif isinstance(arg_sig, SizeArg): elif isinstance(arg_sig, SizeArg):
symval_hint = V.graph.sizevars.size_hint(arg_sig.expr) symval_hint = V.graph.sizevars.size_hint(arg_sig.expr)
@ -3128,9 +3128,9 @@ class TritonScheduling(BaseScheduling):
# Only install guards for 32-bit indexing as there is no correctness # Only install guards for 32-bit indexing as there is no correctness
# issue with using 64-bit for everything # issue with using 64-bit for everything
V.graph.sizevars.guard_leq(numel, int_max) V.graph.sizevars.guard_leq(numel, int_max) # type: ignore[arg-type]
for size in buf_sizes: for size in buf_sizes:
V.graph.sizevars.guard_leq(size, int_max) V.graph.sizevars.guard_leq(size, int_max) # type: ignore[arg-type]
return True return True
@staticmethod @staticmethod

View File

@ -81,7 +81,7 @@ def config_of(args: List[Union[TensorArg, SizeArg]]) -> instance_descriptor:
return False return False
if isinstance(x.expr, float): if isinstance(x.expr, float):
return False return False
return V.graph.sizevars.statically_known_multiple_of(x.expr, alignment) return V.graph.sizevars.statically_known_multiple_of(x.expr, alignment) # type: ignore[arg-type]
raise NotImplementedError(f"unhandled {type(x)}: {x}") raise NotImplementedError(f"unhandled {type(x)}: {x}")
if config.triton.divisible_by_16: if config.triton.divisible_by_16:

View File

@ -757,17 +757,17 @@ class WrapperCodeGen(CodeGen):
) )
for name, shape in graph_inputs_expr: for name, shape in graph_inputs_expr:
shape = V.graph.sizevars.simplify(shape) shape = V.graph.sizevars.simplify(shape) # type: ignore[arg-type]
if shape in needed: if shape in needed:
needed.remove(shape) needed.remove(shape) # type: ignore[arg-type]
code.writeline(f"{self.declare}{shape} = {name}{self.ending}") code.writeline(f"{self.declare}{shape} = {name}{self.ending}")
for name, value in graph_inputs_tensors: for name, value in graph_inputs_tensors:
shapes = value.get_size() shapes = value.get_size()
for dim, shape in enumerate(shapes): for dim, shape in enumerate(shapes):
shape = V.graph.sizevars.simplify(shape) shape = V.graph.sizevars.simplify(shape) # type: ignore[arg-type]
if shape in needed: if shape in needed:
needed.remove(shape) needed.remove(shape) # type: ignore[arg-type]
code.writeline( code.writeline(
f"{self.declare}{shape} = {sizeof(name)}[{dim}]{self.ending}" f"{self.declare}{shape} = {sizeof(name)}[{dim}]{self.ending}"
) )
@ -775,9 +775,9 @@ class WrapperCodeGen(CodeGen):
for name, value in graph_inputs_tensors: for name, value in graph_inputs_tensors:
shapes = value.get_stride() shapes = value.get_stride()
for dim, shape in enumerate(shapes): for dim, shape in enumerate(shapes):
shape = V.graph.sizevars.simplify(shape) shape = V.graph.sizevars.simplify(shape) # type: ignore[arg-type]
if shape in needed: if shape in needed:
needed.remove(shape) needed.remove(shape) # type: ignore[arg-type]
code.writeline( code.writeline(
f"{self.declare}{shape} = {strideof(name)}[{dim}]{self.ending}" f"{self.declare}{shape} = {strideof(name)}[{dim}]{self.ending}"
) )

View File

@ -30,7 +30,7 @@ Dep = Union["MemoryDep", "StarDep", "WeakDep"]
class MemoryDep(typing.NamedTuple): class MemoryDep(typing.NamedTuple):
name: str name: str
index: sympy.Expr index: sympy.Expr # type: ignore[assignment]
var_names: Tuple[sympy.Symbol, ...] var_names: Tuple[sympy.Symbol, ...]
size: Tuple[sympy.Expr, ...] size: Tuple[sympy.Expr, ...]
@ -77,7 +77,7 @@ class MemoryDep(typing.NamedTuple):
return isinstance(self.index, (int, sympy.Integer)) return isinstance(self.index, (int, sympy.Integer))
def is_indirect(self) -> bool: def is_indirect(self) -> bool:
return any(is_indirect(v.name) for v in self.index.free_symbols) return any(is_indirect(v.name) for v in self.index.free_symbols) # type: ignore[attr-defined]
class StarDep(typing.NamedTuple): class StarDep(typing.NamedTuple):
@ -146,7 +146,7 @@ class WeakDep(typing.NamedTuple):
class IndexExprDep(typing.NamedTuple): class IndexExprDep(typing.NamedTuple):
index: sympy.Expr index: sympy.Expr # type: ignore[assignment]
var_names: Tuple[sympy.Symbol, ...] var_names: Tuple[sympy.Symbol, ...]
size: Tuple[sympy.Expr, ...] size: Tuple[sympy.Expr, ...]
@ -235,7 +235,7 @@ class _RecordLoadStoreInner(V.MockHandler): # type: ignore[name-defined]
k for k, v in zip(self._var_ranges.keys(), sizes) if v != 1 k for k, v in zip(self._var_ranges.keys(), sizes) if v != 1
) )
sizes = tuple(v for v in sizes if v != 1) sizes = tuple(v for v in sizes if v != 1)
return index, var_names, sizes return index, var_names, sizes # type: ignore[return-value]
# Try to further simplify the indexes even if simplify_loops didn't # Try to further simplify the indexes even if simplify_loops didn't
# convert it to the simplest form because of the interference from # convert it to the simplest form because of the interference from
@ -269,7 +269,7 @@ class _RecordLoadStoreInner(V.MockHandler): # type: ignore[name-defined]
# downstream users won't. Normalize this away. # downstream users won't. Normalize this away.
new_vars.pop() new_vars.pop()
new_sizes.pop() new_sizes.pop()
return index, tuple(new_vars), tuple(new_sizes) return index, tuple(new_vars), tuple(new_sizes) # type: ignore[arg-type]
def load(self, name: str, index: sympy.Expr) -> str: def load(self, name: str, index: sympy.Expr) -> str:
self._reads.add(MemoryDep(name, *self.canonicalize(index))) self._reads.add(MemoryDep(name, *self.canonicalize(index)))

View File

@ -368,7 +368,7 @@ def cat_tuned_op(match, inputs, dim, *, op, shape_of):
if new_size is None: if new_size is None:
new_size = shape new_size = shape
else: else:
new_size[notdim] = V.graph.sizevars.guard_equals( new_size[notdim] = V.graph.sizevars.guard_equals( # type: ignore[call-overload]
shape[notdim], new_size[notdim] shape[notdim], new_size[notdim]
) )
new_size[dim] += shape[dim] new_size[dim] += shape[dim]

View File

@ -208,9 +208,9 @@ def ir_node_to_tensor(x, guard_shape=True):
size = [shape_fn(s) for s in x.get_size()] size = [shape_fn(s) for s in x.get_size()]
stride: StrideType stride: StrideType
if is_storage_and_layout(x): if is_storage_and_layout(x):
stride = [shape_fn(s) for s in x.get_layout().stride] stride = [shape_fn(s) for s in x.get_layout().stride] # type: ignore[misc]
else: else:
stride = make_contiguous_strides_for(size) stride = make_contiguous_strides_for(size) # type: ignore[arg-type]
dtype = x.get_dtype() dtype = x.get_dtype()
device = x.get_device() device = x.get_device()
size = convert_shape_to_symint(size) size = convert_shape_to_symint(size)
@ -294,7 +294,7 @@ class IRNode:
return sympy_product(self.get_size()) return sympy_product(self.get_size())
def is_zero_elements(self): def is_zero_elements(self):
return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0)) return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0)) # type: ignore[arg-type]
def realize(self): def realize(self):
""" """
@ -1150,7 +1150,7 @@ class Reduction(Loops):
): ):
reindex = View.dynamic_reshape_indexer(reduction_ranges, [reduction_numel]) reindex = View.dynamic_reshape_indexer(reduction_ranges, [reduction_numel])
need_mask = not V.graph.sizevars.is_expr_static_and_true( need_mask = not V.graph.sizevars.is_expr_static_and_true(
sympy.Eq(reduction_numel % split, 0) sympy.Eq(reduction_numel % split, 0) # type: ignore[arg-type]
) )
def wrapper_fn(index, reduction_index): def wrapper_fn(index, reduction_index):
@ -1286,7 +1286,7 @@ class Reduction(Loops):
wrapper_fn, wrapper_fn,
ranges, ranges,
reduction_ranges, reduction_ranges,
[*ranges, split], [*ranges, split], # type: ignore[list-item]
[block_size], [block_size],
reduction_type, reduction_type,
split, split,
@ -1522,7 +1522,7 @@ class WelfordReduction(Reduction):
""" """
reduction_numel = sympy_product(reduction_ranges) reduction_numel = sympy_product(reduction_ranges)
need_mask = not V.graph.sizevars.is_expr_static_and_true( need_mask = not V.graph.sizevars.is_expr_static_and_true(
sympy.Eq(reduction_numel % split, 0) sympy.Eq(reduction_numel % split, 0) # type: ignore[arg-type]
) )
if need_mask and reduction_type != "welford_combine": if need_mask and reduction_type != "welford_combine":
@ -1562,7 +1562,7 @@ class WelfordReduction(Reduction):
) )
for loader in inner_fns for loader in inner_fns
), ),
[*ranges, split], [*ranges, split], # type: ignore[list-item]
[block_size], [block_size],
reduction_type, reduction_type,
reduction_hint, reduction_hint,
@ -1587,7 +1587,7 @@ class WelfordReduction(Reduction):
for i in intermediates for i in intermediates
), ),
ranges, ranges,
[split], [split], # type: ignore[list-item]
# welford_reduce turns one input into three outputs, which are combined with welford_combine # welford_reduce turns one input into three outputs, which are combined with welford_combine
"welford_combine", "welford_combine",
reduction_hint, reduction_hint,
@ -1680,7 +1680,7 @@ class Scan(Loops):
scan_numel = sizevars.simplify(sympy_product(scan_ranges)) scan_numel = sizevars.simplify(sympy_product(scan_ranges))
# Scan with a single element is just a copy # Scan with a single element is just a copy
if sizevars.is_expr_static_and_true(sympy.Le(scan_numel, 1)): if sizevars.is_expr_static_and_true(sympy.Le(scan_numel, 1)): # type: ignore[arg-type]
return Pointwise.create( return Pointwise.create(
device=device, device=device,
dtype=dtype, dtype=dtype,
@ -1993,7 +1993,7 @@ class PermuteView(BaseView):
def make_reindexer(self): def make_reindexer(self):
inv = {j: i for i, j in enumerate(self.dims)} inv = {j: i for i, j in enumerate(self.dims)}
inv = [inv[i] for i in range(len(self.dims))] inv = [inv[i] for i in range(len(self.dims))] # type: ignore[index]
assert set(inv) == set(range(len(self.dims))) assert set(inv) == set(range(len(self.dims)))
def reindex(index): def reindex(index):
@ -2214,12 +2214,12 @@ class View(GenericView):
while stack_old: while stack_old:
size_old = stack_old.pop() size_old = stack_old.pop()
V.graph.sizevars.guard_equals(size_old, 1) V.graph.sizevars.guard_equals(size_old, 1) # type: ignore[arg-type]
view_expr.append(sympy.Integer(0)) view_expr.append(sympy.Integer(0))
while stack_new: while stack_new:
var, size_new = stack_new.pop() var, size_new = stack_new.pop()
V.graph.sizevars.guard_equals(size_new, 1) V.graph.sizevars.guard_equals(size_new, 1) # type: ignore[arg-type]
view_expr = list(reversed(view_expr)) view_expr = list(reversed(view_expr))
assert len(view_expr) == len(old_size) assert len(view_expr) == len(old_size)
@ -2227,7 +2227,7 @@ class View(GenericView):
def reindex(index): def reindex(index):
assert len(index) == len(vars), (len(index), len(vars)) assert len(index) == len(vars), (len(index), len(vars))
replacements = dict(zip(vars, index)) replacements = dict(zip(vars, index))
return tuple(sympy_subs(x, replacements) for x in view_expr) return tuple(sympy_subs(x, replacements) for x in view_expr) # type: ignore[arg-type]
return reindex return reindex
@ -2487,7 +2487,7 @@ class Layout(IRNode):
if ndim not in [4, 5]: if ndim not in [4, 5]:
return False return False
for left, right, size in zip( for left, right, size in zip(
self.stride, make_channels_last_strides_for(self.size), self.size self.stride, make_channels_last_strides_for(self.size), self.size # type: ignore[arg-type]
): ):
if size != 1 and left != right: if size != 1 and left != right:
return False return False
@ -2564,7 +2564,7 @@ class Layout(IRNode):
) )
def storage_size(self) -> sympy.Expr: def storage_size(self) -> sympy.Expr:
return compute_required_storage_length(self.size, self.stride, self.offset) return compute_required_storage_length(self.size, self.stride, self.offset) # type: ignore[arg-type, return-value]
class FixedLayout(Layout): class FixedLayout(Layout):
@ -2583,9 +2583,9 @@ class FixedLayout(Layout):
super().__init__( super().__init__(
device, device,
dtype, dtype,
size, size, # type: ignore[arg-type]
stride, stride,
offset, offset, # type: ignore[arg-type]
) )
def make_indexer(self): def make_indexer(self):
@ -2715,7 +2715,7 @@ class AliasedLayout(Layout):
return True return True
from .compile_fx import ALIGNMENT from .compile_fx import ALIGNMENT
return V.graph.sizevars.statically_known_multiple_of(offset, ALIGNMENT) return V.graph.sizevars.statically_known_multiple_of(offset, ALIGNMENT) # type: ignore[arg-type]
class NoneLayout(IRNode): class NoneLayout(IRNode):
@ -2882,7 +2882,7 @@ class Buffer(IRNode):
self.layout = self.layout.as_same_order(stride) self.layout = self.layout.as_same_order(stride)
def is_zero_elements(self): def is_zero_elements(self):
return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0)) return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0)) # type: ignore[arg-type]
def make_loader(self): def make_loader(self):
# Loading from a zero-element buffer is a no-op # Loading from a zero-element buffer is a no-op
@ -3177,7 +3177,7 @@ class ComputedBuffer(Buffer):
else: else:
indices = index_vars indices = index_vars
stride_lengths = [ stride_lengths = [
V.graph.sizevars.stride_hints(expr, indices) for expr in reads V.graph.sizevars.stride_hints(expr, indices) for expr in reads # type: ignore[arg-type]
] ]
from .scheduler import pick_loop_order from .scheduler import pick_loop_order
@ -3907,13 +3907,13 @@ class ExternKernel(InputsKernel):
else: else:
type_ = None type_ = None
kwargs.append( kwargs.append(
V.graph.wrapper_code.val_to_cpp_arg_str( V.graph.wrapper_code.val_to_cpp_arg_str( # type: ignore[arg-type]
type_, v, self.is_legacy_abi_kernel() type_, v, self.is_legacy_abi_kernel()
) )
) )
else: else:
kwargs = [ kwargs = [
f"{k}={V.graph.wrapper_code.val_to_arg_str(v)}" f"{k}={V.graph.wrapper_code.val_to_arg_str(v)}" # type: ignore[misc]
for k, v in self.kwargs.items() for k, v in self.kwargs.items()
] ]
return kwargs return kwargs
@ -3962,7 +3962,7 @@ class ExternKernel(InputsKernel):
_, add_var = var_builder("c") _, add_var = var_builder("c")
replacement = dict(zip(index_vars, reindex([add_var(x) for x in new_sizes]))) replacement = dict(zip(index_vars, reindex([add_var(x) for x in new_sizes])))
index = sympy_subs(sympy.expand(index), replacement) index = sympy_subs(sympy.expand(index), replacement) # type: ignore[arg-type]
return index, tuple(new_sizes) return index, tuple(new_sizes)
def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:

View File

@ -243,14 +243,14 @@ def conv_layout(
ir.ir_node_to_tensor(weight, guard_shape=True), ir.ir_node_to_tensor(weight, guard_shape=True),
ir.ir_node_to_tensor(bias, guard_shape=True), ir.ir_node_to_tensor(bias, guard_shape=True),
stride, stride,
tuple(V.graph.sizevars.size_hint(p) for p in padding), tuple(V.graph.sizevars.size_hint(p) for p in padding), # type: ignore[arg-type]
dilation, dilation,
transposed, transposed,
tuple(V.graph.sizevars.size_hint(p) for p in output_padding), tuple(V.graph.sizevars.size_hint(p) for p in output_padding), # type: ignore[arg-type]
groups, groups,
) )
sizes = ir.convert_shape_to_inductor(output.size()) sizes = ir.convert_shape_to_inductor(output.size())
stride = ir.convert_shape_to_inductor(output.stride()) stride = ir.convert_shape_to_inductor(output.stride()) # type: ignore[assignment]
return ir.FixedLayout( return ir.FixedLayout(
x.get_device(), x.get_device(),
@ -419,7 +419,7 @@ def convolution(
and not transposed and not transposed
and is_zeros(output_padding) and is_zeros(output_padding)
# there are some odd models where this check fails (e.g. shufflenet_v2_x1_0) # there are some odd models where this check fails (e.g. shufflenet_v2_x1_0)
and V.graph.sizevars.statically_known_equals(in_chan, x.get_size()[1]) and V.graph.sizevars.statically_known_equals(in_chan, x.get_size()[1]) # type: ignore[arg-type]
): ):
if ( if (
is_ones(kernel_shape) is_ones(kernel_shape)

View File

@ -36,7 +36,7 @@ def filtered_configs(
m = max( m = max(
next_power_of_2( next_power_of_2(
V.graph.sizevars.size_hint( V.graph.sizevars.size_hint(
m, fallback=torch._inductor.config.unbacked_symint_fallback m, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type]
) )
), ),
min_block_size, min_block_size,
@ -44,7 +44,7 @@ def filtered_configs(
n = max( n = max(
next_power_of_2( next_power_of_2(
V.graph.sizevars.size_hint( V.graph.sizevars.size_hint(
n, fallback=torch._inductor.config.unbacked_symint_fallback n, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type]
) )
), ),
min_block_size, min_block_size,
@ -52,7 +52,7 @@ def filtered_configs(
k = max( k = max(
next_power_of_2( next_power_of_2(
V.graph.sizevars.size_hint( V.graph.sizevars.size_hint(
k, fallback=torch._inductor.config.unbacked_symint_fallback k, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type]
) )
), ),
min_block_size, min_block_size,

View File

@ -983,8 +983,8 @@ def pointwise_cat(inputs, dim=0):
inputs_ranges: List[Tuple[sympy.Expr, sympy.Expr]] = [] inputs_ranges: List[Tuple[sympy.Expr, sympy.Expr]] = []
prev_end = 0 prev_end = 0
for inp in inputs: for inp in inputs:
inputs_ranges.append((prev_end, prev_end + inp.get_size()[dim])) inputs_ranges.append((prev_end, prev_end + inp.get_size()[dim])) # type: ignore[arg-type]
prev_end = inputs_ranges[-1][-1] prev_end = inputs_ranges[-1][-1] # type: ignore[assignment]
inputs_loaders = [inp.make_loader() for inp in inputs] inputs_loaders = [inp.make_loader() for inp in inputs]
@ -1215,7 +1215,7 @@ def unfold(x, dimension, size, step):
dim_size = sizes[dim] dim_size = sizes[dim]
sizevars = V.graph.sizevars sizevars = V.graph.sizevars
sizevars.guard_leq(size, dim_size) sizevars.guard_leq(size, dim_size)
sizevars.guard_lt(0, step) sizevars.guard_lt(0, step) # type: ignore[arg-type]
new_dim_size = FloorDiv(dim_size - size, step) + 1 new_dim_size = FloorDiv(dim_size - size, step) + 1
if sizevars.size_hint(dim_size) > 0: if sizevars.size_hint(dim_size) > 0:
@ -2371,8 +2371,8 @@ def select_scatter(x, src, dim: int, index: int):
dim = _validate_dim(x, dim, 0) dim = _validate_dim(x, dim, 0)
if V.graph.sizevars.evaluate_expr(sympy.Lt(index, 0)): if V.graph.sizevars.evaluate_expr(sympy.Lt(index, 0)):
index = index + x.get_size()[dim] index = index + x.get_size()[dim]
V.graph.sizevars.guard_leq(0, index) V.graph.sizevars.guard_leq(0, index) # type: ignore[arg-type]
V.graph.sizevars.guard_lt(index, x.get_size()[dim]) V.graph.sizevars.guard_lt(index, x.get_size()[dim]) # type: ignore[arg-type]
src = expand(unsqueeze(src, dim), x.get_size()) src = expand(unsqueeze(src, dim), x.get_size())
src_loader = src.make_loader() src_loader = src.make_loader()
@ -3673,7 +3673,7 @@ def constant_pad_nd(x, padding, fill_value=0):
# if padding is a complicated expression, hoist it # if padding is a complicated expression, hoist it
bounds_precomp: List[Tuple[sympy.Symbol, Any]] = [] bounds_precomp: List[Tuple[sympy.Symbol, Any]] = []
for l, h in bounds: for l, h in bounds:
bounds_precomp.append((V.graph.sizevars.lookup_precomputed_size(l), h)) bounds_precomp.append((V.graph.sizevars.lookup_precomputed_size(l), h)) # type: ignore[arg-type]
output_size = list(sizes[:n]) output_size = list(sizes[:n])
mask_sizes = [] mask_sizes = []
@ -3770,7 +3770,7 @@ def pooling_size(x, i, kernel_size, stride, padding, ceil_mode):
if V.graph.sizevars.size_hint((x_alt - 1) * stride[i] - x - padding[i]) >= 0: if V.graph.sizevars.size_hint((x_alt - 1) * stride[i] - x - padding[i]) >= 0:
# Sliding windows must start within the input or left padding # Sliding windows must start within the input or left padding
x_alt -= 1 # type: ignore[assignment] x_alt -= 1 # type: ignore[assignment]
V.graph.sizevars.guard_leq(0, x_alt * stride[i] - x - padding[i]) V.graph.sizevars.guard_leq(0, x_alt * stride[i] - x - padding[i]) # type: ignore[arg-type]
if V.graph.sizevars.size_hint(x_out - x_alt) == 0: if V.graph.sizevars.size_hint(x_out - x_alt) == 0:
# ceil mode is actually a no-op, lets guard on that # ceil mode is actually a no-op, lets guard on that
V.graph.sizevars.guard_equals(x_out, x_alt) V.graph.sizevars.guard_equals(x_out, x_alt)

View File

@ -128,7 +128,7 @@ class SizeVarAllocator:
# actual iteration range is to size-1 # actual iteration range is to size-1
iter_ranges_zero = {k: 0 for k, v in var_ranges.items()} iter_ranges_zero = {k: 0 for k, v in var_ranges.items()}
base_lowest = sympy_subs(base, iter_ranges_zero) base_lowest = sympy_subs(base, iter_ranges_zero)
if self.statically_known_leq(0, base_lowest): if self.statically_known_leq(0, base_lowest): # type: ignore[arg-type]
# can't replace with indexing div if base can be negative # can't replace with indexing div if base can be negative
base_pos = True base_pos = True
else: else:
@ -272,7 +272,7 @@ class SizeVarAllocator:
""" """
Returns a bool indicating if it is sound to optimize as if left and right are equal. Returns a bool indicating if it is sound to optimize as if left and right are equal.
""" """
return self.is_expr_static_and_true(sympy.Eq(left, right)) return self.is_expr_static_and_true(sympy.Eq(left, right)) # type: ignore[arg-type]
# See Note - [On Statically Known] # See Note - [On Statically Known]
def statically_known_list_equals(self, left: List[Expr], right: List[Expr]) -> bool: def statically_known_list_equals(self, left: List[Expr], right: List[Expr]) -> bool:
@ -307,7 +307,7 @@ class SizeVarAllocator:
Return a bool indicating if it is sound to optimize for the numerator being a multiple of the denominator. Return a bool indicating if it is sound to optimize for the numerator being a multiple of the denominator.
""" """
expr = sympy.Eq(numerator % denominator, 0) expr = sympy.Eq(numerator % denominator, 0)
return self.is_expr_static_and_true(expr) return self.is_expr_static_and_true(expr) # type: ignore[arg-type]
# The guard functions require you to ALREADY KNOW that a particular # The guard functions require you to ALREADY KNOW that a particular
# condition holds. If you don't know (you want to guard on an expression # condition holds. If you don't know (you want to guard on an expression
@ -316,9 +316,9 @@ class SizeVarAllocator:
def guard_equals(self, left: Expr, right: Expr) -> Expr: def guard_equals(self, left: Expr, right: Expr) -> Expr:
if isinstance(left, Expr): if isinstance(left, Expr):
left = sympy_subs(left, self.inv_precomputed_replacements) left = sympy_subs(left, self.inv_precomputed_replacements) # type: ignore[arg-type]
if isinstance(right, Expr): if isinstance(right, Expr):
right = sympy_subs(right, self.inv_precomputed_replacements) right = sympy_subs(right, self.inv_precomputed_replacements) # type: ignore[arg-type]
assert self.shape_env.evaluate_expr(sympy.Eq(left, right)) assert self.shape_env.evaluate_expr(sympy.Eq(left, right))
return left return left
@ -329,16 +329,16 @@ class SizeVarAllocator:
assert self.shape_env.evaluate_expr(sympy.Lt(left, right)) assert self.shape_env.evaluate_expr(sympy.Lt(left, right))
def expect_true(self, expr: Expr, *, msg: str) -> None: def expect_true(self, expr: Expr, *, msg: str) -> None:
expr = sympy_subs(expr, self.inv_precomputed_replacements) expr = sympy_subs(expr, self.inv_precomputed_replacements) # type: ignore[arg-type]
self.shape_env.defer_runtime_assert(expr, msg, fx_node=None) self.shape_env.defer_runtime_assert(expr, msg, fx_node=None)
def expect_equals(self, left: Expr, right: Expr, *, msg: str) -> Expr: def expect_equals(self, left: Expr, right: Expr, *, msg: str) -> Expr:
# Prefer returning the expression without unbacked symints # Prefer returning the expression without unbacked symints
if self.shape_env.is_unbacked_symint(left): if self.shape_env.is_unbacked_symint(left):
self.expect_true(sympy.Eq(left, right), msg=msg) self.expect_true(sympy.Eq(left, right), msg=msg) # type: ignore[arg-type]
return right return right
elif self.shape_env.is_unbacked_symint(right): elif self.shape_env.is_unbacked_symint(right):
self.expect_true(sympy.Eq(left, right), msg=msg) self.expect_true(sympy.Eq(left, right), msg=msg) # type: ignore[arg-type]
return left return left
else: else:
return self.guard_equals(left, right) return self.guard_equals(left, right)
@ -399,8 +399,8 @@ class SizeVarAllocator:
return [self.evaluate_static_shape(x) for x in left] return [self.evaluate_static_shape(x) for x in left]
def remove_precomputed_replacements(self, expr: Expr) -> Expr: def remove_precomputed_replacements(self, expr: Expr) -> Expr:
if any(s.name.startswith("ps") for s in expr.free_symbols): if any(s.name.startswith("ps") for s in expr.free_symbols): # type: ignore[attr-defined]
return sympy_subs(expr, self.inv_precomputed_replacements) return sympy_subs(expr, self.inv_precomputed_replacements) # type: ignore[arg-type]
return expr return expr
def symbolic_hint(self, expr: Expr) -> Expr: def symbolic_hint(self, expr: Expr) -> Expr:
@ -410,7 +410,7 @@ class SizeVarAllocator:
return expr return expr
free_symbols = expr.free_symbols free_symbols = expr.free_symbols
if not free_symbols: if not free_symbols:
return int(expr) return int(expr) # type: ignore[return-value]
expr = self.remove_precomputed_replacements(expr) expr = self.remove_precomputed_replacements(expr)
return sympy_subs(expr, self.var_to_val) return sympy_subs(expr, self.var_to_val)
@ -422,9 +422,9 @@ class SizeVarAllocator:
s: self.shape_env.var_to_range.get(s, None) for s in expr.free_symbols s: self.shape_env.var_to_range.get(s, None) for s in expr.free_symbols
} }
if all(vr is not None for vr in sym_vrs.values()): if all(vr is not None for vr in sym_vrs.values()):
expr_vr = bound_sympy(expr, sym_vrs) expr_vr = bound_sympy(expr, sym_vrs) # type: ignore[arg-type]
lower = self.size_hint(expr_vr.lower) lower = self.size_hint(expr_vr.lower) # type: ignore[arg-type]
upper = self.size_hint(expr_vr.upper) upper = self.size_hint(expr_vr.upper) # type: ignore[arg-type]
fallback = min(max(fallback, lower), upper) fallback = min(max(fallback, lower), upper)
return fallback return fallback
try: try:
@ -522,8 +522,8 @@ class SizeVarAllocator:
support_vars: Optional[List[sympy.Symbol]] = None, support_vars: Optional[List[sympy.Symbol]] = None,
) -> List[int]: ) -> List[int]:
for v in index.free_symbols: for v in index.free_symbols:
if v.name.startswith("indirect"): if v.name.startswith("indirect"): # type: ignore[attr-defined]
index = sympy_subs(index, {v: 0}) index = sympy_subs(index, {v: 0}) # type: ignore[dict-item]
result = [] result = []
for s in self.stride_vars(index, vars, support_vars): for s in self.stride_vars(index, vars, support_vars):
try: try:
@ -538,7 +538,7 @@ class SizeVarAllocator:
order.sort(key=lambda x: (strides[x] == 0, strides[x])) order.sort(key=lambda x: (strides[x] == 0, strides[x]))
return order return order
def lookup_precomputed_size(self, expr: Expr) -> sympy.Symbol: def lookup_precomputed_size(self, expr: Expr) -> Expr:
if ( if (
isinstance(expr, (int, sympy.Symbol, sympy.Number)) isinstance(expr, (int, sympy.Symbol, sympy.Number))
or expr.is_number or expr.is_number

View File

@ -569,8 +569,8 @@ def sympy_subs(expr: sympy.Expr, replacements: Dict[sympy.Expr, Any]) -> sympy.E
if isinstance(replacement, str): if isinstance(replacement, str):
return sympy.Symbol( return sympy.Symbol(
replacement, replacement,
integer=replaced.is_integer, integer=replaced.is_integer, # type: ignore[attr-defined]
nonnegative=replaced.is_nonnegative, nonnegative=replaced.is_nonnegative, # type: ignore[attr-defined]
) )
else: else:
return replacement return replacement
@ -582,11 +582,11 @@ def sympy_subs(expr: sympy.Expr, replacements: Dict[sympy.Expr, Any]) -> sympy.E
def free_symbol_startswith(index: sympy.Expr, prefix: str): def free_symbol_startswith(index: sympy.Expr, prefix: str):
return any(v.name.startswith(prefix) for v in index.free_symbols) return any(v.name.startswith(prefix) for v in index.free_symbols) # type: ignore[attr-defined]
def free_symbol_has(index: sympy.Expr, pattern: str): def free_symbol_has(index: sympy.Expr, pattern: str):
return any(pattern in v.name for v in index.free_symbols) return any(pattern in v.name for v in index.free_symbols) # type: ignore[attr-defined]
def is_symbolic(a: Any) -> bool: def is_symbolic(a: Any) -> bool:
@ -1081,7 +1081,7 @@ def get_sympy_Expr_dtype(val: sympy.Expr) -> torch.dtype:
assert isinstance( assert isinstance(
val, sympy.Expr val, sympy.Expr
), "only support sympy.Expr as input to get_sympy_Expr_dtype" ), "only support sympy.Expr as input to get_sympy_Expr_dtype"
if val.is_integer: if val.is_integer: # type: ignore[attr-defined]
return torch.int64 return torch.int64
else: else:
return torch.float64 return torch.float64

View File

@ -468,7 +468,7 @@ class ConvReLU1d(nnqat.Conv1d, nni._FusedModule):
weight_fake_quant: fake quant module for weight weight_fake_quant: fake quant module for weight
""" """
_FLOAT_MODULE = nni.ConvReLU1d _FLOAT_MODULE = nni.ConvReLU1d # type: ignore[assignment]
_FLOAT_CONV_MODULE = nn.Conv1d _FLOAT_CONV_MODULE = nn.Conv1d
_FLOAT_BN_MODULE: None = None _FLOAT_BN_MODULE: None = None
_FLOAT_RELU_MODULE = nn.ReLU _FLOAT_RELU_MODULE = nn.ReLU
@ -600,7 +600,7 @@ class ConvReLU2d(nnqat.Conv2d, nni._FusedModule):
weight_fake_quant: fake quant module for weight weight_fake_quant: fake quant module for weight
""" """
_FLOAT_MODULE = nni.ConvReLU2d _FLOAT_MODULE = nni.ConvReLU2d # type: ignore[assignment]
_FLOAT_CONV_MODULE = nn.Conv2d _FLOAT_CONV_MODULE = nn.Conv2d
_FLOAT_BN_MODULE: None = None _FLOAT_BN_MODULE: None = None
_FLOAT_RELU_MODULE = nn.ReLU _FLOAT_RELU_MODULE = nn.ReLU
@ -773,7 +773,7 @@ class ConvReLU3d(nnqat.Conv3d, nni._FusedModule):
weight_fake_quant: fake quant module for weight weight_fake_quant: fake quant module for weight
""" """
_FLOAT_MODULE = nni.ConvReLU3d _FLOAT_MODULE = nni.ConvReLU3d # type: ignore[assignment]
_FLOAT_CONV_MODULE = nn.Conv3d _FLOAT_CONV_MODULE = nn.Conv3d
_FLOAT_BN_MODULE: None = None _FLOAT_BN_MODULE: None = None
_FLOAT_RELU_MODULE = nn.ReLU _FLOAT_RELU_MODULE = nn.ReLU