mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 00:20:18 +01:00
Remove follow_imports = skip from sympy (#118469)
dmypy silently ignores follow_imports = skip, so to get parity between
dmypy and mypy we have to suck it up and type: ignore all of the sympy
typing problems.
The suppressions were added automatically with the following script generated by GPT-4:
```
import re

# Read the error file
with open("error_file.txt", "r") as f:
    errors = f.readlines()

# Parse the lines with errors and error types
error_lines = {}
for error in errors:
    match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
    if match:
        file_path, line_number, error_type = match.groups()
        if file_path not in error_lines:
            error_lines[file_path] = {}
        error_lines[file_path][int(line_number)] = error_type

# Insert ignore comments in the source files
for file_path, lines in error_lines.items():
    with open(file_path, "r") as f:
        code = f.readlines()
    for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
        code[line_number - 1] = code[line_number - 1].rstrip() + f" # type: ignore[{error_type}]\n"
    with open(file_path, "w") as f:
        f.writelines(code)
```
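For reference, the `error_file.txt` consumed by the script is just captured mypy output with error codes shown. A minimal sketch of producing it (the mypy flags and target path here are illustrative assumptions, not the exact invocation used for this commit):

```
import subprocess

# Run mypy and capture its per-line report; --show-error-codes makes mypy
# print the bracketed [error-code] suffix that the regex above extracts.
# The target directory is an assumption for illustration.
result = subprocess.run(
    ["mypy", "--show-error-codes", "torch/_inductor"],
    capture_output=True,
    text=True,
)
with open("error_file.txt", "w") as f:
    f.write(result.stdout)
```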
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118469
Approved by: https://github.com/Skylion007
ghstack dependencies: #118414, #118418, #118432, #118467, #118468
parent 59b4d2cd40
commit cad79bd0bb
```diff
@@ -138,6 +138,7 @@ init_command = [
     'numpy==1.26.0 ; python_version >= "3.9"',
     'expecttest==0.1.6',
     'mypy==1.7.0',
+    'sympy==1.11.1',
     'types-requests==2.27.25',
     'types-PyYAML==6.0.7',
     'types-tabulate==0.8.8',
```
mypy.ini:

```diff
@@ -179,7 +179,6 @@ ignore_missing_imports = True

 [mypy-sympy.*]
 ignore_missing_imports = True
-follow_imports = skip

 [mypy-hypothesis.*]
 ignore_missing_imports = True
```
```diff
@@ -1143,7 +1143,7 @@ class Kernel(CodeGen):
         ):
             # Skip CSE since this doesn't return an expression

-            if var.bounds.lower < 0:
+            if var.bounds.lower < 0: # type: ignore[operator]
                 new_bounds = ValueRanges.unknown()
                 if var.bounds != ValueRanges.unknown() and isinstance(
                     size, sympy.Number
@@ -1154,13 +1154,13 @@ class Kernel(CodeGen):
                     neg = var.bounds & ValueRanges(-sympy.oo, -1)
                     new_bounds = ValueRanges(neg.lower + size, neg.upper + size)
                     # We don't have a good way of representing the empty range
-                    if var.bounds.upper >= 0:
+                    if var.bounds.upper >= 0: # type: ignore[operator]
                         pos = var.bounds & ValueRanges(0, sympy.oo)
                         new_bounds = new_bounds | pos

             stm = ops.add(var, self.rename_indexing(size))
             # Mixed negative and non-negative
-            if var.bounds.upper >= 0:
+            if var.bounds.upper >= 0: # type: ignore[operator]
                 lt = ops.lt(var, "0")
                 stm = ops.where(lt, stm, var)
             new_var = self.cse.generate(self.compute, stm, bounds=new_bounds)
@@ -1310,7 +1310,7 @@ class Kernel(CodeGen):
         # adds the necessary kernel args for index expressions
         # and renames variables in index expressions to kernel arg names
         if isinstance(index, (list, tuple)):
-            return [self.rename_indexing(x) for x in index]
+            return [self.rename_indexing(x) for x in index] # type: ignore[return-value]
         index = V.graph.sizevars.simplify(index)
         sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name)
         replacements = {
```
```diff
@@ -311,7 +311,7 @@ def argmax_argmin_prefix(reduction_type, src_dtype, tmpvar):
 @functools.lru_cache
 def stride_at(index: sympy.Expr, var: sympy.Symbol):
     replacement = {var: var + 1}
-    new_index = sympy_subs(index, replacement)
+    new_index = sympy_subs(index, replacement) # type: ignore[arg-type]
     return sympy.simplify(new_index - index)


@@ -620,10 +620,10 @@ class CppCSEVariable(CSEVariable):
         """
         for s in index.free_symbols:
             if s in V.kernel.itervars:
-                self.dependent_itervars.add(s)
+                self.dependent_itervars.add(s) # type: ignore[arg-type]
-            elif s.name in V.kernel.cse.varname_map:
+            elif s.name in V.kernel.cse.varname_map: # type: ignore[attr-defined]
                 self.dependent_itervars.update(
-                    V.kernel.cse.varname_map[s.name].dependent_itervars
+                    V.kernel.cse.varname_map[s.name].dependent_itervars # type: ignore[attr-defined]
                 )

     def depends_on(self, itervar: sympy.Symbol):
@@ -1512,10 +1512,10 @@ class CppKernel(Kernel):
         Check if an index has free symbol CppCSEVariable that depends on `itervar`.
         """
         return any(
-            self.cse.varname_map[s.name].depends_on(itervar)
+            self.cse.varname_map[s.name].depends_on(itervar) # type: ignore[attr-defined]
             for s in index.free_symbols
-            if s.name in self.cse.varname_map
+            if s.name in self.cse.varname_map # type: ignore[attr-defined]
-            and isinstance(self.cse.varname_map[s.name], CppCSEVariable)
+            and isinstance(self.cse.varname_map[s.name], CppCSEVariable) # type: ignore[attr-defined]
         )

     def index_depends_on(self, index: sympy.Expr, itervar: sympy.Symbol):
@@ -1894,9 +1894,9 @@ class CppVecKernel(CppKernel):
         )
         replacements = {}
         for indirect_var in (
-            self.cse.varname_map[s.name]
+            self.cse.varname_map[s.name] # type: ignore[attr-defined]
             for s in index.free_symbols
-            if s.name.startswith("tmp")
+            if s.name.startswith("tmp") # type: ignore[attr-defined]
         ):
             assert isinstance(indirect_var, CppCSEVariable)
             if indirect_var.is_vec:
@@ -1911,7 +1911,7 @@ class CppVecKernel(CppKernel):
             )
         else:
             load_mask = f"{self._load_mask} != 0"
-        index = sympy_subs(index, replacements)
+        index = sympy_subs(index, replacements) # type: ignore[arg-type]
         index = self.scale_index_with_offset(
             index, itervar_idx=self.tiling_idx, offset=itervar_inner
         )
@@ -1934,7 +1934,7 @@ class CppVecKernel(CppKernel):
             code.writeline(f"if ({load_mask})")
             stack.enter_context(code.indent())
             code.writeline(f"tmpbuf[{itervar_inner}] = {rhs};")
-        load_line = self._get_vec_load_line("tmpbuf.data()", 0, dtype)
+        load_line = self._get_vec_load_line("tmpbuf.data()", 0, dtype) # type: ignore[arg-type]
         code.writeline(f"return {load_line};")
         code.writeline("()")
         csevar = self.cse.generate(buffer, code)
@@ -2296,7 +2296,7 @@ class CppTile2DKernel(CppVecKernel):
             # vector load inside the kernel inner loop
             loadbuf = f"{tile_var} + {cexpr_index(inner * self.tiling_factor)}"
             dtype = V.graph.get_dtype(name)
-            line = self._get_vec_load_line(loadbuf, 0, dtype)
+            line = self._get_vec_load_line(loadbuf, 0, dtype) # type: ignore[arg-type]
             csevar = self.cse.generate(self.loads, line)
             csevar.update_on_args("load", (name, index), {})
             assert isinstance(csevar, CppCSEVariable)
```
```diff
@@ -326,7 +326,7 @@ class TemporalSplit(ClearCacheOnAllocateMixin, AllocationTreeNode):
     @cache_on_self
     def get_symbolic_size(self) -> sympy.Expr:
         if not self.allocations:
-            return 0
+            return 0 # type: ignore[return-value]
         return sympy.Max(*[x.get_symbolic_size() for x in self.allocations])

     def is_empty(self):
```
```diff
@@ -197,10 +197,10 @@ class BlockPtrOptions:
         for i in range(len(self.shape)):
             if (
                 self.block_shape[i] != "1"
-                and not V.graph.sizevars.statically_known_equals(self.strides[i], 0)
+                and not V.graph.sizevars.statically_known_equals(self.strides[i], 0) # type: ignore[arg-type]
                 and not V.graph.sizevars.statically_known_multiple_of(
                     self.shape[i],
-                    config.triton.max_block[self.block_shape[i][0]],
+                    config.triton.max_block[self.block_shape[i][0]], # type: ignore[arg-type]
                 )
                 and not (V.kernel.no_x_dim and self.block_shape[i] == "XBLOCK")
             ):
@@ -1280,7 +1280,7 @@ class TritonKernel(Kernel):
         if hint > threshold:
             return False
         # will need to recompile if we cross a larger power of 2 boundary
-        V.graph.sizevars.guard_leq(self.numels[-1], next_power_of_2(hint))
+        V.graph.sizevars.guard_leq(self.numels[-1], next_power_of_2(hint)) # type: ignore[arg-type]
         return True

     def set_last_usage(self, nodes):
@@ -1370,7 +1370,7 @@ class TritonKernel(Kernel):
         for length_group in lengths:
             return_getters = []
             for size in length_group:
-                if sv.statically_known_equals(size, 1):
+                if sv.statically_known_equals(size, 1): # type: ignore[arg-type]
                     return_getters.append(lambda _: sympy.Integer(0))
                     continue

@@ -1461,7 +1461,7 @@ class TritonKernel(Kernel):
             if symbol not in self.range_tree_nodes:
                 # Non-iterated variables, e.g. strides
                 continue
-            entry = self.range_tree_nodes[symbol]
+            entry = self.range_tree_nodes[symbol] # type: ignore[index]
             assert isinstance(entry.parent, IterationRangesRoot)
             index_numels[entry.parent.index] *= entry.length

@@ -1469,7 +1469,7 @@ class TritonKernel(Kernel):
         # numels, then it must be broadcasted.
         simplify = V.graph.sizevars.simplify
         return any(
-            simplify(idx_range) != simplify(iter_range)
+            simplify(idx_range) != simplify(iter_range) # type: ignore[arg-type]
             for idx_range, iter_range in zip(index_numels, self.numels)
         )

@@ -1603,7 +1603,7 @@ class TritonKernel(Kernel):
                 [m[s] for s in strides],
                 m[offset],
                 range_trees,
-                mask_vars,
+                mask_vars, # type: ignore[arg-type]
             )

         expand_str = None
@@ -1630,7 +1630,7 @@ class TritonKernel(Kernel):
         self.filter_masks(mask_vars)

         mask_str = " & ".join(sorted(map(str, mask_vars))) if mask_vars else "None"
-        return IndexingOptions(index_str, mask_vars, mask_str, expand_str, has_rindex)
+        return IndexingOptions(index_str, mask_vars, mask_str, expand_str, has_rindex) # type: ignore[arg-type]

     def active_range_trees(self, reorder=False):
         trees = [
@@ -1647,7 +1647,7 @@ class TritonKernel(Kernel):
     def filter_masks(self, mask_vars):
         for tree in self.range_trees:
             # Masks are superfluous if we only have one element
-            if V.graph.sizevars.statically_known_equals(tree.numel, 1):
+            if V.graph.sizevars.statically_known_equals(tree.numel, 1): # type: ignore[arg-type]
                 mask_vars.discard(f"{tree.prefix}mask")
                 continue
             # Masks are superfluous if numel is a multiple of BLOCK
@@ -1659,7 +1659,7 @@ class TritonKernel(Kernel):
             # never need to do a masked load to handle stragglers at the end.
             # It's faster to avoid masking at all. But it is sound to always
             # mask.
-            if V.graph.sizevars.statically_known_multiple_of(tree.numel, max_block):
+            if V.graph.sizevars.statically_known_multiple_of(tree.numel, max_block): # type: ignore[arg-type]
                 mask_vars.discard(f"{tree.prefix}mask")

     def var_ranges(self):
@@ -1676,13 +1676,13 @@ class TritonKernel(Kernel):
             # if indexing expression is complicated, we precompute it on the host side
             # and send the result as a kernel argument
             replacements = {}
-            for ps in self.range_tree_nodes[sym].precomputed_args():
+            for ps in self.range_tree_nodes[sym].precomputed_args(): # type: ignore[index]
                 replacements[ps] = V.graph.sizevars.lookup_precomputed_size(ps)
             if len(replacements) > 0:
-                self.range_tree_nodes[sym].expr = sympy_subs(
+                self.range_tree_nodes[sym].expr = sympy_subs( # type: ignore[index]
-                    self.range_tree_nodes[sym].expr, replacements
+                    self.range_tree_nodes[sym].expr, replacements # type: ignore[index]
                 )
-            self.range_tree_nodes[sym].codegen()
+            self.range_tree_nodes[sym].codegen() # type: ignore[index]
         return expr

     @contextlib.contextmanager
@@ -1735,7 +1735,7 @@ class TritonKernel(Kernel):
             {xindex: 512, rindex: 1024}
         """
         index_to_tile_indexes = {k: v.expr for k, v in self.range_tree_nodes.items()}
-        index_in_tile_vars = sympy_subs(index, index_to_tile_indexes)
+        index_in_tile_vars = sympy_subs(index, index_to_tile_indexes) # type: ignore[arg-type]
         strides = {}
         for range_tree in self.range_trees:
             s = sympy_index_symbol(range_tree.name)
@@ -1883,7 +1883,7 @@ class TritonKernel(Kernel):

         result_var = self.cse.generate(load_buffer, line)
         assert isinstance(result_var, TritonCSEVariable)
-        result_var.mask_vars = indexing.mask_vars
+        result_var.mask_vars = indexing.mask_vars # type: ignore[assignment]

         if append_broadcast:
             line = f"tl.broadcast_to({result_var}, {append_broadcast})"
@@ -2410,7 +2410,7 @@ class TritonKernel(Kernel):
             # note that random seed is put in V.graph.constants
             const_tensor = V.graph.constants[arg_name]
             result.writeline(
-                f"{var_name} = rand_strided({V.graph.sizevars.size_hints(const_tensor.size())}, {V.graph.sizevars.size_hints(const_tensor.stride())}, device='{const_tensor.device}', dtype={const_tensor.dtype})"  # noqa: B950 line too long
+                f"{var_name} = rand_strided({V.graph.sizevars.size_hints(const_tensor.size())}, {V.graph.sizevars.size_hints(const_tensor.stride())}, device='{const_tensor.device}', dtype={const_tensor.dtype})"  # type: ignore[arg-type] # noqa: B950 line too long
             )
         elif isinstance(arg_sig, SizeArg):
             symval_hint = V.graph.sizevars.size_hint(arg_sig.expr)
@@ -3128,9 +3128,9 @@ class TritonScheduling(BaseScheduling):

         # Only install guards for 32-bit indexing as there is no correctness
         # issue with using 64-bit for everything
-        V.graph.sizevars.guard_leq(numel, int_max)
+        V.graph.sizevars.guard_leq(numel, int_max) # type: ignore[arg-type]
         for size in buf_sizes:
-            V.graph.sizevars.guard_leq(size, int_max)
+            V.graph.sizevars.guard_leq(size, int_max) # type: ignore[arg-type]
         return True

     @staticmethod
```
```diff
@@ -81,7 +81,7 @@ def config_of(args: List[Union[TensorArg, SizeArg]]) -> instance_descriptor:
             return False
         if isinstance(x.expr, float):
             return False
-        return V.graph.sizevars.statically_known_multiple_of(x.expr, alignment)
+        return V.graph.sizevars.statically_known_multiple_of(x.expr, alignment) # type: ignore[arg-type]
     raise NotImplementedError(f"unhandled {type(x)}: {x}")

 if config.triton.divisible_by_16:
```
```diff
@@ -757,17 +757,17 @@ class WrapperCodeGen(CodeGen):
         )

         for name, shape in graph_inputs_expr:
-            shape = V.graph.sizevars.simplify(shape)
+            shape = V.graph.sizevars.simplify(shape) # type: ignore[arg-type]
             if shape in needed:
-                needed.remove(shape)
+                needed.remove(shape) # type: ignore[arg-type]
                 code.writeline(f"{self.declare}{shape} = {name}{self.ending}")

         for name, value in graph_inputs_tensors:
             shapes = value.get_size()
             for dim, shape in enumerate(shapes):
-                shape = V.graph.sizevars.simplify(shape)
+                shape = V.graph.sizevars.simplify(shape) # type: ignore[arg-type]
                 if shape in needed:
-                    needed.remove(shape)
+                    needed.remove(shape) # type: ignore[arg-type]
                     code.writeline(
                         f"{self.declare}{shape} = {sizeof(name)}[{dim}]{self.ending}"
                     )
@@ -775,9 +775,9 @@ class WrapperCodeGen(CodeGen):
         for name, value in graph_inputs_tensors:
             shapes = value.get_stride()
             for dim, shape in enumerate(shapes):
-                shape = V.graph.sizevars.simplify(shape)
+                shape = V.graph.sizevars.simplify(shape) # type: ignore[arg-type]
                 if shape in needed:
-                    needed.remove(shape)
+                    needed.remove(shape) # type: ignore[arg-type]
                     code.writeline(
                         f"{self.declare}{shape} = {strideof(name)}[{dim}]{self.ending}"
                     )
```
```diff
@@ -30,7 +30,7 @@ Dep = Union["MemoryDep", "StarDep", "WeakDep"]

 class MemoryDep(typing.NamedTuple):
     name: str
-    index: sympy.Expr
+    index: sympy.Expr # type: ignore[assignment]
     var_names: Tuple[sympy.Symbol, ...]
     size: Tuple[sympy.Expr, ...]

@@ -77,7 +77,7 @@ class MemoryDep(typing.NamedTuple):
         return isinstance(self.index, (int, sympy.Integer))

     def is_indirect(self) -> bool:
-        return any(is_indirect(v.name) for v in self.index.free_symbols)
+        return any(is_indirect(v.name) for v in self.index.free_symbols) # type: ignore[attr-defined]


 class StarDep(typing.NamedTuple):
@@ -146,7 +146,7 @@ class WeakDep(typing.NamedTuple):


 class IndexExprDep(typing.NamedTuple):
-    index: sympy.Expr
+    index: sympy.Expr # type: ignore[assignment]
     var_names: Tuple[sympy.Symbol, ...]
     size: Tuple[sympy.Expr, ...]

@@ -235,7 +235,7 @@ class _RecordLoadStoreInner(V.MockHandler): # type: ignore[name-defined]
                 k for k, v in zip(self._var_ranges.keys(), sizes) if v != 1
             )
             sizes = tuple(v for v in sizes if v != 1)
-            return index, var_names, sizes
+            return index, var_names, sizes # type: ignore[return-value]

         # Try to further simplify the indexes even if simplify_loops didn't
         # convert it to the simplest form because of the interference from
@@ -269,7 +269,7 @@ class _RecordLoadStoreInner(V.MockHandler): # type: ignore[name-defined]
             # downstream users won't. Normalize this away.
             new_vars.pop()
             new_sizes.pop()
-        return index, tuple(new_vars), tuple(new_sizes)
+        return index, tuple(new_vars), tuple(new_sizes) # type: ignore[arg-type]

     def load(self, name: str, index: sympy.Expr) -> str:
         self._reads.add(MemoryDep(name, *self.canonicalize(index)))
```
```diff
@@ -368,7 +368,7 @@ def cat_tuned_op(match, inputs, dim, *, op, shape_of):
         if new_size is None:
             new_size = shape
         else:
-            new_size[notdim] = V.graph.sizevars.guard_equals(
+            new_size[notdim] = V.graph.sizevars.guard_equals( # type: ignore[call-overload]
                 shape[notdim], new_size[notdim]
             )
         new_size[dim] += shape[dim]
```
```diff
@@ -208,9 +208,9 @@ def ir_node_to_tensor(x, guard_shape=True):
     size = [shape_fn(s) for s in x.get_size()]
     stride: StrideType
     if is_storage_and_layout(x):
-        stride = [shape_fn(s) for s in x.get_layout().stride]
+        stride = [shape_fn(s) for s in x.get_layout().stride] # type: ignore[misc]
     else:
-        stride = make_contiguous_strides_for(size)
+        stride = make_contiguous_strides_for(size) # type: ignore[arg-type]
     dtype = x.get_dtype()
     device = x.get_device()
     size = convert_shape_to_symint(size)
@@ -294,7 +294,7 @@ class IRNode:
         return sympy_product(self.get_size())

     def is_zero_elements(self):
-        return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0))
+        return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0)) # type: ignore[arg-type]

     def realize(self):
         """
@@ -1150,7 +1150,7 @@ class Reduction(Loops):
         ):
             reindex = View.dynamic_reshape_indexer(reduction_ranges, [reduction_numel])
             need_mask = not V.graph.sizevars.is_expr_static_and_true(
-                sympy.Eq(reduction_numel % split, 0)
+                sympy.Eq(reduction_numel % split, 0) # type: ignore[arg-type]
             )

         def wrapper_fn(index, reduction_index):
@@ -1286,7 +1286,7 @@ class Reduction(Loops):
                 wrapper_fn,
                 ranges,
                 reduction_ranges,
-                [*ranges, split],
+                [*ranges, split], # type: ignore[list-item]
                 [block_size],
                 reduction_type,
                 split,
@@ -1522,7 +1522,7 @@ class WelfordReduction(Reduction):
         """
         reduction_numel = sympy_product(reduction_ranges)
         need_mask = not V.graph.sizevars.is_expr_static_and_true(
-            sympy.Eq(reduction_numel % split, 0)
+            sympy.Eq(reduction_numel % split, 0) # type: ignore[arg-type]
         )

         if need_mask and reduction_type != "welford_combine":
@@ -1562,7 +1562,7 @@ class WelfordReduction(Reduction):
                     )
                     for loader in inner_fns
                 ),
-                [*ranges, split],
+                [*ranges, split], # type: ignore[list-item]
                 [block_size],
                 reduction_type,
                 reduction_hint,
@@ -1587,7 +1587,7 @@ class WelfordReduction(Reduction):
                 for i in intermediates
             ),
             ranges,
-            [split],
+            [split], # type: ignore[list-item]
             # welford_reduce turns one input into three outputs, which are combined with welford_combine
             "welford_combine",
             reduction_hint,
@@ -1680,7 +1680,7 @@ class Scan(Loops):
         scan_numel = sizevars.simplify(sympy_product(scan_ranges))

         # Scan with a single element is just a copy
-        if sizevars.is_expr_static_and_true(sympy.Le(scan_numel, 1)):
+        if sizevars.is_expr_static_and_true(sympy.Le(scan_numel, 1)): # type: ignore[arg-type]
             return Pointwise.create(
                 device=device,
                 dtype=dtype,
@@ -1993,7 +1993,7 @@ class PermuteView(BaseView):

     def make_reindexer(self):
         inv = {j: i for i, j in enumerate(self.dims)}
-        inv = [inv[i] for i in range(len(self.dims))]
+        inv = [inv[i] for i in range(len(self.dims))] # type: ignore[index]
         assert set(inv) == set(range(len(self.dims)))

         def reindex(index):
@@ -2214,12 +2214,12 @@ class View(GenericView):

         while stack_old:
             size_old = stack_old.pop()
-            V.graph.sizevars.guard_equals(size_old, 1)
+            V.graph.sizevars.guard_equals(size_old, 1) # type: ignore[arg-type]
             view_expr.append(sympy.Integer(0))

         while stack_new:
             var, size_new = stack_new.pop()
-            V.graph.sizevars.guard_equals(size_new, 1)
+            V.graph.sizevars.guard_equals(size_new, 1) # type: ignore[arg-type]

         view_expr = list(reversed(view_expr))
         assert len(view_expr) == len(old_size)
@@ -2227,7 +2227,7 @@ class View(GenericView):
         def reindex(index):
             assert len(index) == len(vars), (len(index), len(vars))
             replacements = dict(zip(vars, index))
-            return tuple(sympy_subs(x, replacements) for x in view_expr)
+            return tuple(sympy_subs(x, replacements) for x in view_expr) # type: ignore[arg-type]

         return reindex

@@ -2487,7 +2487,7 @@ class Layout(IRNode):
         if ndim not in [4, 5]:
             return False
         for left, right, size in zip(
-            self.stride, make_channels_last_strides_for(self.size), self.size
+            self.stride, make_channels_last_strides_for(self.size), self.size # type: ignore[arg-type]
         ):
             if size != 1 and left != right:
                 return False
@@ -2564,7 +2564,7 @@ class Layout(IRNode):
         )

     def storage_size(self) -> sympy.Expr:
-        return compute_required_storage_length(self.size, self.stride, self.offset)
+        return compute_required_storage_length(self.size, self.stride, self.offset) # type: ignore[arg-type, return-value]


 class FixedLayout(Layout):
@@ -2583,9 +2583,9 @@ class FixedLayout(Layout):
         super().__init__(
             device,
             dtype,
-            size,
+            size, # type: ignore[arg-type]
             stride,
-            offset,
+            offset, # type: ignore[arg-type]
         )

     def make_indexer(self):
@@ -2715,7 +2715,7 @@ class AliasedLayout(Layout):
             return True
         from .compile_fx import ALIGNMENT

-        return V.graph.sizevars.statically_known_multiple_of(offset, ALIGNMENT)
+        return V.graph.sizevars.statically_known_multiple_of(offset, ALIGNMENT) # type: ignore[arg-type]


 class NoneLayout(IRNode):
@@ -2882,7 +2882,7 @@ class Buffer(IRNode):
             self.layout = self.layout.as_same_order(stride)

     def is_zero_elements(self):
-        return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0))
+        return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0)) # type: ignore[arg-type]

     def make_loader(self):
         # Loading from a zero-element buffer is a no-op
@@ -3177,7 +3177,7 @@ class ComputedBuffer(Buffer):
         else:
             indices = index_vars
         stride_lengths = [
-            V.graph.sizevars.stride_hints(expr, indices) for expr in reads
+            V.graph.sizevars.stride_hints(expr, indices) for expr in reads # type: ignore[arg-type]
         ]
         from .scheduler import pick_loop_order

@@ -3907,13 +3907,13 @@ class ExternKernel(InputsKernel):
                 else:
                     type_ = None
                 kwargs.append(
-                    V.graph.wrapper_code.val_to_cpp_arg_str(
+                    V.graph.wrapper_code.val_to_cpp_arg_str( # type: ignore[arg-type]
                         type_, v, self.is_legacy_abi_kernel()
                     )
                 )
         else:
             kwargs = [
-                f"{k}={V.graph.wrapper_code.val_to_arg_str(v)}"
+                f"{k}={V.graph.wrapper_code.val_to_arg_str(v)}" # type: ignore[misc]
                 for k, v in self.kwargs.items()
             ]
         return kwargs
@@ -3962,7 +3962,7 @@ class ExternKernel(InputsKernel):
         _, add_var = var_builder("c")
         replacement = dict(zip(index_vars, reindex([add_var(x) for x in new_sizes])))

-        index = sympy_subs(sympy.expand(index), replacement)
+        index = sympy_subs(sympy.expand(index), replacement) # type: ignore[arg-type]
         return index, tuple(new_sizes)

     def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
```
```diff
@@ -243,14 +243,14 @@ def conv_layout(
         ir.ir_node_to_tensor(weight, guard_shape=True),
         ir.ir_node_to_tensor(bias, guard_shape=True),
         stride,
-        tuple(V.graph.sizevars.size_hint(p) for p in padding),
+        tuple(V.graph.sizevars.size_hint(p) for p in padding), # type: ignore[arg-type]
         dilation,
         transposed,
-        tuple(V.graph.sizevars.size_hint(p) for p in output_padding),
+        tuple(V.graph.sizevars.size_hint(p) for p in output_padding), # type: ignore[arg-type]
         groups,
     )
     sizes = ir.convert_shape_to_inductor(output.size())
-    stride = ir.convert_shape_to_inductor(output.stride())
+    stride = ir.convert_shape_to_inductor(output.stride()) # type: ignore[assignment]

     return ir.FixedLayout(
         x.get_device(),
@@ -419,7 +419,7 @@ def convolution(
         and not transposed
         and is_zeros(output_padding)
         # there are some odd models where this check fails (e.g. shufflenet_v2_x1_0)
-        and V.graph.sizevars.statically_known_equals(in_chan, x.get_size()[1])
+        and V.graph.sizevars.statically_known_equals(in_chan, x.get_size()[1]) # type: ignore[arg-type]
     ):
         if (
             is_ones(kernel_shape)
```
```diff
@@ -36,7 +36,7 @@ def filtered_configs(
     m = max(
         next_power_of_2(
             V.graph.sizevars.size_hint(
-                m, fallback=torch._inductor.config.unbacked_symint_fallback
+                m, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type]
             )
         ),
         min_block_size,
@@ -44,7 +44,7 @@ def filtered_configs(
     n = max(
         next_power_of_2(
             V.graph.sizevars.size_hint(
-                n, fallback=torch._inductor.config.unbacked_symint_fallback
+                n, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type]
             )
         ),
         min_block_size,
@@ -52,7 +52,7 @@ def filtered_configs(
     k = max(
         next_power_of_2(
             V.graph.sizevars.size_hint(
-                k, fallback=torch._inductor.config.unbacked_symint_fallback
+                k, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type]
             )
         ),
         min_block_size,
```
```diff
@@ -983,8 +983,8 @@ def pointwise_cat(inputs, dim=0):
     inputs_ranges: List[Tuple[sympy.Expr, sympy.Expr]] = []
     prev_end = 0
     for inp in inputs:
-        inputs_ranges.append((prev_end, prev_end + inp.get_size()[dim]))
+        inputs_ranges.append((prev_end, prev_end + inp.get_size()[dim])) # type: ignore[arg-type]
-        prev_end = inputs_ranges[-1][-1]
+        prev_end = inputs_ranges[-1][-1] # type: ignore[assignment]

     inputs_loaders = [inp.make_loader() for inp in inputs]

@@ -1215,7 +1215,7 @@ def unfold(x, dimension, size, step):
     dim_size = sizes[dim]
     sizevars = V.graph.sizevars
     sizevars.guard_leq(size, dim_size)
-    sizevars.guard_lt(0, step)
+    sizevars.guard_lt(0, step) # type: ignore[arg-type]

     new_dim_size = FloorDiv(dim_size - size, step) + 1
     if sizevars.size_hint(dim_size) > 0:
@@ -2371,8 +2371,8 @@ def select_scatter(x, src, dim: int, index: int):
     dim = _validate_dim(x, dim, 0)
     if V.graph.sizevars.evaluate_expr(sympy.Lt(index, 0)):
         index = index + x.get_size()[dim]
-    V.graph.sizevars.guard_leq(0, index)
+    V.graph.sizevars.guard_leq(0, index) # type: ignore[arg-type]
-    V.graph.sizevars.guard_lt(index, x.get_size()[dim])
+    V.graph.sizevars.guard_lt(index, x.get_size()[dim]) # type: ignore[arg-type]
     src = expand(unsqueeze(src, dim), x.get_size())
     src_loader = src.make_loader()

@@ -3673,7 +3673,7 @@ def constant_pad_nd(x, padding, fill_value=0):
     # if padding is a complicated expression, hoist it
     bounds_precomp: List[Tuple[sympy.Symbol, Any]] = []
     for l, h in bounds:
-        bounds_precomp.append((V.graph.sizevars.lookup_precomputed_size(l), h))
+        bounds_precomp.append((V.graph.sizevars.lookup_precomputed_size(l), h)) # type: ignore[arg-type]

     output_size = list(sizes[:n])
     mask_sizes = []
@@ -3770,7 +3770,7 @@ def pooling_size(x, i, kernel_size, stride, padding, ceil_mode):
         if V.graph.sizevars.size_hint((x_alt - 1) * stride[i] - x - padding[i]) >= 0:
             # Sliding windows must start within the input or left padding
             x_alt -= 1 # type: ignore[assignment]
-            V.graph.sizevars.guard_leq(0, x_alt * stride[i] - x - padding[i])
+            V.graph.sizevars.guard_leq(0, x_alt * stride[i] - x - padding[i]) # type: ignore[arg-type]
         if V.graph.sizevars.size_hint(x_out - x_alt) == 0:
             # ceil mode is actually a no-op, lets guard on that
             V.graph.sizevars.guard_equals(x_out, x_alt)
```
```diff
@@ -128,7 +128,7 @@ class SizeVarAllocator:
             # actual iteration range is to size-1
             iter_ranges_zero = {k: 0 for k, v in var_ranges.items()}
             base_lowest = sympy_subs(base, iter_ranges_zero)
-            if self.statically_known_leq(0, base_lowest):
+            if self.statically_known_leq(0, base_lowest): # type: ignore[arg-type]
                 # can't replace with indexing div if base can be negative
                 base_pos = True
             else:
@@ -272,7 +272,7 @@ class SizeVarAllocator:
         """
         Returns a bool indicating if it is sound to optimize as if left and right are equal.
         """
-        return self.is_expr_static_and_true(sympy.Eq(left, right))
+        return self.is_expr_static_and_true(sympy.Eq(left, right)) # type: ignore[arg-type]

     # See Note - [On Statically Known]
     def statically_known_list_equals(self, left: List[Expr], right: List[Expr]) -> bool:
@@ -307,7 +307,7 @@ class SizeVarAllocator:
         Return a bool indicating if it is sound to optimize for the numerator being a multiple of the denominator.
         """
         expr = sympy.Eq(numerator % denominator, 0)
-        return self.is_expr_static_and_true(expr)
+        return self.is_expr_static_and_true(expr) # type: ignore[arg-type]

     # The guard functions require you to ALREADY KNOW that a particular
     # condition holds. If you don't know (you want to guard on an expression
@@ -316,9 +316,9 @@ class SizeVarAllocator:

     def guard_equals(self, left: Expr, right: Expr) -> Expr:
         if isinstance(left, Expr):
-            left = sympy_subs(left, self.inv_precomputed_replacements)
+            left = sympy_subs(left, self.inv_precomputed_replacements) # type: ignore[arg-type]
         if isinstance(right, Expr):
-            right = sympy_subs(right, self.inv_precomputed_replacements)
+            right = sympy_subs(right, self.inv_precomputed_replacements) # type: ignore[arg-type]
         assert self.shape_env.evaluate_expr(sympy.Eq(left, right))
         return left

@@ -329,16 +329,16 @@ class SizeVarAllocator:
         assert self.shape_env.evaluate_expr(sympy.Lt(left, right))

     def expect_true(self, expr: Expr, *, msg: str) -> None:
-        expr = sympy_subs(expr, self.inv_precomputed_replacements)
+        expr = sympy_subs(expr, self.inv_precomputed_replacements) # type: ignore[arg-type]
         self.shape_env.defer_runtime_assert(expr, msg, fx_node=None)

     def expect_equals(self, left: Expr, right: Expr, *, msg: str) -> Expr:
         # Prefer returning the expression without unbacked symints
         if self.shape_env.is_unbacked_symint(left):
-            self.expect_true(sympy.Eq(left, right), msg=msg)
+            self.expect_true(sympy.Eq(left, right), msg=msg) # type: ignore[arg-type]
             return right
         elif self.shape_env.is_unbacked_symint(right):
-            self.expect_true(sympy.Eq(left, right), msg=msg)
+            self.expect_true(sympy.Eq(left, right), msg=msg) # type: ignore[arg-type]
             return left
         else:
             return self.guard_equals(left, right)
@@ -399,8 +399,8 @@ class SizeVarAllocator:
         return [self.evaluate_static_shape(x) for x in left]

     def remove_precomputed_replacements(self, expr: Expr) -> Expr:
-        if any(s.name.startswith("ps") for s in expr.free_symbols):
+        if any(s.name.startswith("ps") for s in expr.free_symbols): # type: ignore[attr-defined]
-            return sympy_subs(expr, self.inv_precomputed_replacements)
+            return sympy_subs(expr, self.inv_precomputed_replacements) # type: ignore[arg-type]
         return expr

     def symbolic_hint(self, expr: Expr) -> Expr:
@@ -410,7 +410,7 @@ class SizeVarAllocator:
             return expr
         free_symbols = expr.free_symbols
         if not free_symbols:
-            return int(expr)
+            return int(expr) # type: ignore[return-value]
         expr = self.remove_precomputed_replacements(expr)
         return sympy_subs(expr, self.var_to_val)

@@ -422,9 +422,9 @@ class SizeVarAllocator:
             s: self.shape_env.var_to_range.get(s, None) for s in expr.free_symbols
         }
         if all(vr is not None for vr in sym_vrs.values()):
-            expr_vr = bound_sympy(expr, sym_vrs)
+            expr_vr = bound_sympy(expr, sym_vrs) # type: ignore[arg-type]
-            lower = self.size_hint(expr_vr.lower)
+            lower = self.size_hint(expr_vr.lower) # type: ignore[arg-type]
-            upper = self.size_hint(expr_vr.upper)
+            upper = self.size_hint(expr_vr.upper) # type: ignore[arg-type]
             fallback = min(max(fallback, lower), upper)
             return fallback
         try:
@@ -522,8 +522,8 @@ class SizeVarAllocator:
         support_vars: Optional[List[sympy.Symbol]] = None,
     ) -> List[int]:
         for v in index.free_symbols:
-            if v.name.startswith("indirect"):
+            if v.name.startswith("indirect"): # type: ignore[attr-defined]
-                index = sympy_subs(index, {v: 0})
+                index = sympy_subs(index, {v: 0}) # type: ignore[dict-item]
         result = []
         for s in self.stride_vars(index, vars, support_vars):
             try:
@@ -538,7 +538,7 @@ class SizeVarAllocator:
         order.sort(key=lambda x: (strides[x] == 0, strides[x]))
         return order

-    def lookup_precomputed_size(self, expr: Expr) -> sympy.Symbol:
+    def lookup_precomputed_size(self, expr: Expr) -> Expr:
         if (
             isinstance(expr, (int, sympy.Symbol, sympy.Number))
             or expr.is_number
```
```diff
@@ -569,8 +569,8 @@ def sympy_subs(expr: sympy.Expr, replacements: Dict[sympy.Expr, Any]) -> sympy.E
         if isinstance(replacement, str):
             return sympy.Symbol(
                 replacement,
-                integer=replaced.is_integer,
+                integer=replaced.is_integer, # type: ignore[attr-defined]
-                nonnegative=replaced.is_nonnegative,
+                nonnegative=replaced.is_nonnegative, # type: ignore[attr-defined]
             )
         else:
             return replacement
@@ -582,11 +582,11 @@ def sympy_subs(expr: sympy.Expr, replacements: Dict[sympy.Expr, Any]) -> sympy.E


 def free_symbol_startswith(index: sympy.Expr, prefix: str):
-    return any(v.name.startswith(prefix) for v in index.free_symbols)
+    return any(v.name.startswith(prefix) for v in index.free_symbols) # type: ignore[attr-defined]


 def free_symbol_has(index: sympy.Expr, pattern: str):
-    return any(pattern in v.name for v in index.free_symbols)
+    return any(pattern in v.name for v in index.free_symbols) # type: ignore[attr-defined]


 def is_symbolic(a: Any) -> bool:
@@ -1081,7 +1081,7 @@ def get_sympy_Expr_dtype(val: sympy.Expr) -> torch.dtype:
     assert isinstance(
         val, sympy.Expr
     ), "only support sympy.Expr as input to get_sympy_Expr_dtype"
-    if val.is_integer:
+    if val.is_integer: # type: ignore[attr-defined]
         return torch.int64
     else:
         return torch.float64
```
```diff
@@ -468,7 +468,7 @@ class ConvReLU1d(nnqat.Conv1d, nni._FusedModule):
         weight_fake_quant: fake quant module for weight

     """
-    _FLOAT_MODULE = nni.ConvReLU1d
+    _FLOAT_MODULE = nni.ConvReLU1d # type: ignore[assignment]
     _FLOAT_CONV_MODULE = nn.Conv1d
     _FLOAT_BN_MODULE: None = None
     _FLOAT_RELU_MODULE = nn.ReLU
@@ -600,7 +600,7 @@ class ConvReLU2d(nnqat.Conv2d, nni._FusedModule):
         weight_fake_quant: fake quant module for weight

     """
-    _FLOAT_MODULE = nni.ConvReLU2d
+    _FLOAT_MODULE = nni.ConvReLU2d # type: ignore[assignment]
     _FLOAT_CONV_MODULE = nn.Conv2d
     _FLOAT_BN_MODULE: None = None
     _FLOAT_RELU_MODULE = nn.ReLU
@@ -773,7 +773,7 @@ class ConvReLU3d(nnqat.Conv3d, nni._FusedModule):
         weight_fake_quant: fake quant module for weight

     """
-    _FLOAT_MODULE = nni.ConvReLU3d
+    _FLOAT_MODULE = nni.ConvReLU3d # type: ignore[assignment]
     _FLOAT_CONV_MODULE = nn.Conv3d
     _FLOAT_BN_MODULE: None = None
     _FLOAT_RELU_MODULE = nn.ReLU
```