mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
Remove follow_imports = skip from sympy (#118469)
dmypy silently ignores follow_imports = skip, so to get parity between
dmypy and mypy we have to suck it up and type: ignore all of the sympy
typing problems.
The suppressions were added automatically with the following script generated by GPT-4:
```
import re
# Read the error file
with open("error_file.txt", "r") as f:
errors = f.readlines()
# Parse the lines with errors and error types
error_lines = {}
for error in errors:
match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
if match:
file_path, line_number, error_type = match.groups()
if file_path not in error_lines:
error_lines[file_path] = {}
error_lines[file_path][int(line_number)] = error_type
# Insert ignore comments in the source files
for file_path, lines in error_lines.items():
with open(file_path, "r") as f:
code = f.readlines()
for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
code[line_number - 1] = code[line_number - 1].rstrip() + f" # type: ignore[{error_type}]\n"
with open(file_path, "w") as f:
f.writelines(code)
```
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118469
Approved by: https://github.com/Skylion007
ghstack dependencies: #118414, #118418, #118432, #118467, #118468
This commit is contained in:
parent
59b4d2cd40
commit
cad79bd0bb
|
|
@ -138,6 +138,7 @@ init_command = [
|
|||
'numpy==1.26.0 ; python_version >= "3.9"',
|
||||
'expecttest==0.1.6',
|
||||
'mypy==1.7.0',
|
||||
'sympy==1.11.1',
|
||||
'types-requests==2.27.25',
|
||||
'types-PyYAML==6.0.7',
|
||||
'types-tabulate==0.8.8',
|
||||
|
|
|
|||
1
mypy.ini
1
mypy.ini
|
|
@ -179,7 +179,6 @@ ignore_missing_imports = True
|
|||
|
||||
[mypy-sympy.*]
|
||||
ignore_missing_imports = True
|
||||
follow_imports = skip
|
||||
|
||||
[mypy-hypothesis.*]
|
||||
ignore_missing_imports = True
|
||||
|
|
|
|||
|
|
@ -1143,7 +1143,7 @@ class Kernel(CodeGen):
|
|||
):
|
||||
# Skip CSE since this doesn't return an expression
|
||||
|
||||
if var.bounds.lower < 0:
|
||||
if var.bounds.lower < 0: # type: ignore[operator]
|
||||
new_bounds = ValueRanges.unknown()
|
||||
if var.bounds != ValueRanges.unknown() and isinstance(
|
||||
size, sympy.Number
|
||||
|
|
@ -1154,13 +1154,13 @@ class Kernel(CodeGen):
|
|||
neg = var.bounds & ValueRanges(-sympy.oo, -1)
|
||||
new_bounds = ValueRanges(neg.lower + size, neg.upper + size)
|
||||
# We don't have a good way of representing the empty range
|
||||
if var.bounds.upper >= 0:
|
||||
if var.bounds.upper >= 0: # type: ignore[operator]
|
||||
pos = var.bounds & ValueRanges(0, sympy.oo)
|
||||
new_bounds = new_bounds | pos
|
||||
|
||||
stm = ops.add(var, self.rename_indexing(size))
|
||||
# Mixed negative and non-negative
|
||||
if var.bounds.upper >= 0:
|
||||
if var.bounds.upper >= 0: # type: ignore[operator]
|
||||
lt = ops.lt(var, "0")
|
||||
stm = ops.where(lt, stm, var)
|
||||
new_var = self.cse.generate(self.compute, stm, bounds=new_bounds)
|
||||
|
|
@ -1310,7 +1310,7 @@ class Kernel(CodeGen):
|
|||
# adds the necessary kernel args for index expressions
|
||||
# and renames variables in index expressions to kernel arg names
|
||||
if isinstance(index, (list, tuple)):
|
||||
return [self.rename_indexing(x) for x in index]
|
||||
return [self.rename_indexing(x) for x in index] # type: ignore[return-value]
|
||||
index = V.graph.sizevars.simplify(index)
|
||||
sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name)
|
||||
replacements = {
|
||||
|
|
|
|||
|
|
@ -311,7 +311,7 @@ def argmax_argmin_prefix(reduction_type, src_dtype, tmpvar):
|
|||
@functools.lru_cache
|
||||
def stride_at(index: sympy.Expr, var: sympy.Symbol):
|
||||
replacement = {var: var + 1}
|
||||
new_index = sympy_subs(index, replacement)
|
||||
new_index = sympy_subs(index, replacement) # type: ignore[arg-type]
|
||||
return sympy.simplify(new_index - index)
|
||||
|
||||
|
||||
|
|
@ -620,10 +620,10 @@ class CppCSEVariable(CSEVariable):
|
|||
"""
|
||||
for s in index.free_symbols:
|
||||
if s in V.kernel.itervars:
|
||||
self.dependent_itervars.add(s)
|
||||
elif s.name in V.kernel.cse.varname_map:
|
||||
self.dependent_itervars.add(s) # type: ignore[arg-type]
|
||||
elif s.name in V.kernel.cse.varname_map: # type: ignore[attr-defined]
|
||||
self.dependent_itervars.update(
|
||||
V.kernel.cse.varname_map[s.name].dependent_itervars
|
||||
V.kernel.cse.varname_map[s.name].dependent_itervars # type: ignore[attr-defined]
|
||||
)
|
||||
|
||||
def depends_on(self, itervar: sympy.Symbol):
|
||||
|
|
@ -1512,10 +1512,10 @@ class CppKernel(Kernel):
|
|||
Check if an index has free symbol CppCSEVariable that depends on `itervar`.
|
||||
"""
|
||||
return any(
|
||||
self.cse.varname_map[s.name].depends_on(itervar)
|
||||
self.cse.varname_map[s.name].depends_on(itervar) # type: ignore[attr-defined]
|
||||
for s in index.free_symbols
|
||||
if s.name in self.cse.varname_map
|
||||
and isinstance(self.cse.varname_map[s.name], CppCSEVariable)
|
||||
if s.name in self.cse.varname_map # type: ignore[attr-defined]
|
||||
and isinstance(self.cse.varname_map[s.name], CppCSEVariable) # type: ignore[attr-defined]
|
||||
)
|
||||
|
||||
def index_depends_on(self, index: sympy.Expr, itervar: sympy.Symbol):
|
||||
|
|
@ -1894,9 +1894,9 @@ class CppVecKernel(CppKernel):
|
|||
)
|
||||
replacements = {}
|
||||
for indirect_var in (
|
||||
self.cse.varname_map[s.name]
|
||||
self.cse.varname_map[s.name] # type: ignore[attr-defined]
|
||||
for s in index.free_symbols
|
||||
if s.name.startswith("tmp")
|
||||
if s.name.startswith("tmp") # type: ignore[attr-defined]
|
||||
):
|
||||
assert isinstance(indirect_var, CppCSEVariable)
|
||||
if indirect_var.is_vec:
|
||||
|
|
@ -1911,7 +1911,7 @@ class CppVecKernel(CppKernel):
|
|||
)
|
||||
else:
|
||||
load_mask = f"{self._load_mask} != 0"
|
||||
index = sympy_subs(index, replacements)
|
||||
index = sympy_subs(index, replacements) # type: ignore[arg-type]
|
||||
index = self.scale_index_with_offset(
|
||||
index, itervar_idx=self.tiling_idx, offset=itervar_inner
|
||||
)
|
||||
|
|
@ -1934,7 +1934,7 @@ class CppVecKernel(CppKernel):
|
|||
code.writeline(f"if ({load_mask})")
|
||||
stack.enter_context(code.indent())
|
||||
code.writeline(f"tmpbuf[{itervar_inner}] = {rhs};")
|
||||
load_line = self._get_vec_load_line("tmpbuf.data()", 0, dtype)
|
||||
load_line = self._get_vec_load_line("tmpbuf.data()", 0, dtype) # type: ignore[arg-type]
|
||||
code.writeline(f"return {load_line};")
|
||||
code.writeline("()")
|
||||
csevar = self.cse.generate(buffer, code)
|
||||
|
|
@ -2296,7 +2296,7 @@ class CppTile2DKernel(CppVecKernel):
|
|||
# vector load inside the kernel inner loop
|
||||
loadbuf = f"{tile_var} + {cexpr_index(inner * self.tiling_factor)}"
|
||||
dtype = V.graph.get_dtype(name)
|
||||
line = self._get_vec_load_line(loadbuf, 0, dtype)
|
||||
line = self._get_vec_load_line(loadbuf, 0, dtype) # type: ignore[arg-type]
|
||||
csevar = self.cse.generate(self.loads, line)
|
||||
csevar.update_on_args("load", (name, index), {})
|
||||
assert isinstance(csevar, CppCSEVariable)
|
||||
|
|
|
|||
|
|
@ -326,7 +326,7 @@ class TemporalSplit(ClearCacheOnAllocateMixin, AllocationTreeNode):
|
|||
@cache_on_self
|
||||
def get_symbolic_size(self) -> sympy.Expr:
|
||||
if not self.allocations:
|
||||
return 0
|
||||
return 0 # type: ignore[return-value]
|
||||
return sympy.Max(*[x.get_symbolic_size() for x in self.allocations])
|
||||
|
||||
def is_empty(self):
|
||||
|
|
|
|||
|
|
@ -197,10 +197,10 @@ class BlockPtrOptions:
|
|||
for i in range(len(self.shape)):
|
||||
if (
|
||||
self.block_shape[i] != "1"
|
||||
and not V.graph.sizevars.statically_known_equals(self.strides[i], 0)
|
||||
and not V.graph.sizevars.statically_known_equals(self.strides[i], 0) # type: ignore[arg-type]
|
||||
and not V.graph.sizevars.statically_known_multiple_of(
|
||||
self.shape[i],
|
||||
config.triton.max_block[self.block_shape[i][0]],
|
||||
config.triton.max_block[self.block_shape[i][0]], # type: ignore[arg-type]
|
||||
)
|
||||
and not (V.kernel.no_x_dim and self.block_shape[i] == "XBLOCK")
|
||||
):
|
||||
|
|
@ -1280,7 +1280,7 @@ class TritonKernel(Kernel):
|
|||
if hint > threshold:
|
||||
return False
|
||||
# will need to recompile if we cross a larger power of 2 boundary
|
||||
V.graph.sizevars.guard_leq(self.numels[-1], next_power_of_2(hint))
|
||||
V.graph.sizevars.guard_leq(self.numels[-1], next_power_of_2(hint)) # type: ignore[arg-type]
|
||||
return True
|
||||
|
||||
def set_last_usage(self, nodes):
|
||||
|
|
@ -1370,7 +1370,7 @@ class TritonKernel(Kernel):
|
|||
for length_group in lengths:
|
||||
return_getters = []
|
||||
for size in length_group:
|
||||
if sv.statically_known_equals(size, 1):
|
||||
if sv.statically_known_equals(size, 1): # type: ignore[arg-type]
|
||||
return_getters.append(lambda _: sympy.Integer(0))
|
||||
continue
|
||||
|
||||
|
|
@ -1461,7 +1461,7 @@ class TritonKernel(Kernel):
|
|||
if symbol not in self.range_tree_nodes:
|
||||
# Non-iterated variables, e.g. strides
|
||||
continue
|
||||
entry = self.range_tree_nodes[symbol]
|
||||
entry = self.range_tree_nodes[symbol] # type: ignore[index]
|
||||
assert isinstance(entry.parent, IterationRangesRoot)
|
||||
index_numels[entry.parent.index] *= entry.length
|
||||
|
||||
|
|
@ -1469,7 +1469,7 @@ class TritonKernel(Kernel):
|
|||
# numels, then it must be broadcasted.
|
||||
simplify = V.graph.sizevars.simplify
|
||||
return any(
|
||||
simplify(idx_range) != simplify(iter_range)
|
||||
simplify(idx_range) != simplify(iter_range) # type: ignore[arg-type]
|
||||
for idx_range, iter_range in zip(index_numels, self.numels)
|
||||
)
|
||||
|
||||
|
|
@ -1603,7 +1603,7 @@ class TritonKernel(Kernel):
|
|||
[m[s] for s in strides],
|
||||
m[offset],
|
||||
range_trees,
|
||||
mask_vars,
|
||||
mask_vars, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
expand_str = None
|
||||
|
|
@ -1630,7 +1630,7 @@ class TritonKernel(Kernel):
|
|||
self.filter_masks(mask_vars)
|
||||
|
||||
mask_str = " & ".join(sorted(map(str, mask_vars))) if mask_vars else "None"
|
||||
return IndexingOptions(index_str, mask_vars, mask_str, expand_str, has_rindex)
|
||||
return IndexingOptions(index_str, mask_vars, mask_str, expand_str, has_rindex) # type: ignore[arg-type]
|
||||
|
||||
def active_range_trees(self, reorder=False):
|
||||
trees = [
|
||||
|
|
@ -1647,7 +1647,7 @@ class TritonKernel(Kernel):
|
|||
def filter_masks(self, mask_vars):
|
||||
for tree in self.range_trees:
|
||||
# Masks are superfluous if we only have one element
|
||||
if V.graph.sizevars.statically_known_equals(tree.numel, 1):
|
||||
if V.graph.sizevars.statically_known_equals(tree.numel, 1): # type: ignore[arg-type]
|
||||
mask_vars.discard(f"{tree.prefix}mask")
|
||||
continue
|
||||
# Masks are superfluous if numel is a multiple of BLOCK
|
||||
|
|
@ -1659,7 +1659,7 @@ class TritonKernel(Kernel):
|
|||
# never need to do a masked load to handle stragglers at the end.
|
||||
# It's faster to avoid masking at all. But it is sound to always
|
||||
# mask.
|
||||
if V.graph.sizevars.statically_known_multiple_of(tree.numel, max_block):
|
||||
if V.graph.sizevars.statically_known_multiple_of(tree.numel, max_block): # type: ignore[arg-type]
|
||||
mask_vars.discard(f"{tree.prefix}mask")
|
||||
|
||||
def var_ranges(self):
|
||||
|
|
@ -1676,13 +1676,13 @@ class TritonKernel(Kernel):
|
|||
# if indexing expression is complicated, we precompute it on the host side
|
||||
# and send the result as a kernel argument
|
||||
replacements = {}
|
||||
for ps in self.range_tree_nodes[sym].precomputed_args():
|
||||
for ps in self.range_tree_nodes[sym].precomputed_args(): # type: ignore[index]
|
||||
replacements[ps] = V.graph.sizevars.lookup_precomputed_size(ps)
|
||||
if len(replacements) > 0:
|
||||
self.range_tree_nodes[sym].expr = sympy_subs(
|
||||
self.range_tree_nodes[sym].expr, replacements
|
||||
self.range_tree_nodes[sym].expr = sympy_subs( # type: ignore[index]
|
||||
self.range_tree_nodes[sym].expr, replacements # type: ignore[index]
|
||||
)
|
||||
self.range_tree_nodes[sym].codegen()
|
||||
self.range_tree_nodes[sym].codegen() # type: ignore[index]
|
||||
return expr
|
||||
|
||||
@contextlib.contextmanager
|
||||
|
|
@ -1735,7 +1735,7 @@ class TritonKernel(Kernel):
|
|||
{xindex: 512, rindex: 1024}
|
||||
"""
|
||||
index_to_tile_indexes = {k: v.expr for k, v in self.range_tree_nodes.items()}
|
||||
index_in_tile_vars = sympy_subs(index, index_to_tile_indexes)
|
||||
index_in_tile_vars = sympy_subs(index, index_to_tile_indexes) # type: ignore[arg-type]
|
||||
strides = {}
|
||||
for range_tree in self.range_trees:
|
||||
s = sympy_index_symbol(range_tree.name)
|
||||
|
|
@ -1883,7 +1883,7 @@ class TritonKernel(Kernel):
|
|||
|
||||
result_var = self.cse.generate(load_buffer, line)
|
||||
assert isinstance(result_var, TritonCSEVariable)
|
||||
result_var.mask_vars = indexing.mask_vars
|
||||
result_var.mask_vars = indexing.mask_vars # type: ignore[assignment]
|
||||
|
||||
if append_broadcast:
|
||||
line = f"tl.broadcast_to({result_var}, {append_broadcast})"
|
||||
|
|
@ -2410,7 +2410,7 @@ class TritonKernel(Kernel):
|
|||
# note that random seed is put in V.graph.constants
|
||||
const_tensor = V.graph.constants[arg_name]
|
||||
result.writeline(
|
||||
f"{var_name} = rand_strided({V.graph.sizevars.size_hints(const_tensor.size())}, {V.graph.sizevars.size_hints(const_tensor.stride())}, device='{const_tensor.device}', dtype={const_tensor.dtype})" # noqa: B950 line too long
|
||||
f"{var_name} = rand_strided({V.graph.sizevars.size_hints(const_tensor.size())}, {V.graph.sizevars.size_hints(const_tensor.stride())}, device='{const_tensor.device}', dtype={const_tensor.dtype})" # type: ignore[arg-type] # noqa: B950 line too long
|
||||
)
|
||||
elif isinstance(arg_sig, SizeArg):
|
||||
symval_hint = V.graph.sizevars.size_hint(arg_sig.expr)
|
||||
|
|
@ -3128,9 +3128,9 @@ class TritonScheduling(BaseScheduling):
|
|||
|
||||
# Only install guards for 32-bit indexing as there is no correctness
|
||||
# issue with using 64-bit for everything
|
||||
V.graph.sizevars.guard_leq(numel, int_max)
|
||||
V.graph.sizevars.guard_leq(numel, int_max) # type: ignore[arg-type]
|
||||
for size in buf_sizes:
|
||||
V.graph.sizevars.guard_leq(size, int_max)
|
||||
V.graph.sizevars.guard_leq(size, int_max) # type: ignore[arg-type]
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -81,7 +81,7 @@ def config_of(args: List[Union[TensorArg, SizeArg]]) -> instance_descriptor:
|
|||
return False
|
||||
if isinstance(x.expr, float):
|
||||
return False
|
||||
return V.graph.sizevars.statically_known_multiple_of(x.expr, alignment)
|
||||
return V.graph.sizevars.statically_known_multiple_of(x.expr, alignment) # type: ignore[arg-type]
|
||||
raise NotImplementedError(f"unhandled {type(x)}: {x}")
|
||||
|
||||
if config.triton.divisible_by_16:
|
||||
|
|
|
|||
|
|
@ -757,17 +757,17 @@ class WrapperCodeGen(CodeGen):
|
|||
)
|
||||
|
||||
for name, shape in graph_inputs_expr:
|
||||
shape = V.graph.sizevars.simplify(shape)
|
||||
shape = V.graph.sizevars.simplify(shape) # type: ignore[arg-type]
|
||||
if shape in needed:
|
||||
needed.remove(shape)
|
||||
needed.remove(shape) # type: ignore[arg-type]
|
||||
code.writeline(f"{self.declare}{shape} = {name}{self.ending}")
|
||||
|
||||
for name, value in graph_inputs_tensors:
|
||||
shapes = value.get_size()
|
||||
for dim, shape in enumerate(shapes):
|
||||
shape = V.graph.sizevars.simplify(shape)
|
||||
shape = V.graph.sizevars.simplify(shape) # type: ignore[arg-type]
|
||||
if shape in needed:
|
||||
needed.remove(shape)
|
||||
needed.remove(shape) # type: ignore[arg-type]
|
||||
code.writeline(
|
||||
f"{self.declare}{shape} = {sizeof(name)}[{dim}]{self.ending}"
|
||||
)
|
||||
|
|
@ -775,9 +775,9 @@ class WrapperCodeGen(CodeGen):
|
|||
for name, value in graph_inputs_tensors:
|
||||
shapes = value.get_stride()
|
||||
for dim, shape in enumerate(shapes):
|
||||
shape = V.graph.sizevars.simplify(shape)
|
||||
shape = V.graph.sizevars.simplify(shape) # type: ignore[arg-type]
|
||||
if shape in needed:
|
||||
needed.remove(shape)
|
||||
needed.remove(shape) # type: ignore[arg-type]
|
||||
code.writeline(
|
||||
f"{self.declare}{shape} = {strideof(name)}[{dim}]{self.ending}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ Dep = Union["MemoryDep", "StarDep", "WeakDep"]
|
|||
|
||||
class MemoryDep(typing.NamedTuple):
|
||||
name: str
|
||||
index: sympy.Expr
|
||||
index: sympy.Expr # type: ignore[assignment]
|
||||
var_names: Tuple[sympy.Symbol, ...]
|
||||
size: Tuple[sympy.Expr, ...]
|
||||
|
||||
|
|
@ -77,7 +77,7 @@ class MemoryDep(typing.NamedTuple):
|
|||
return isinstance(self.index, (int, sympy.Integer))
|
||||
|
||||
def is_indirect(self) -> bool:
|
||||
return any(is_indirect(v.name) for v in self.index.free_symbols)
|
||||
return any(is_indirect(v.name) for v in self.index.free_symbols) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
class StarDep(typing.NamedTuple):
|
||||
|
|
@ -146,7 +146,7 @@ class WeakDep(typing.NamedTuple):
|
|||
|
||||
|
||||
class IndexExprDep(typing.NamedTuple):
|
||||
index: sympy.Expr
|
||||
index: sympy.Expr # type: ignore[assignment]
|
||||
var_names: Tuple[sympy.Symbol, ...]
|
||||
size: Tuple[sympy.Expr, ...]
|
||||
|
||||
|
|
@ -235,7 +235,7 @@ class _RecordLoadStoreInner(V.MockHandler): # type: ignore[name-defined]
|
|||
k for k, v in zip(self._var_ranges.keys(), sizes) if v != 1
|
||||
)
|
||||
sizes = tuple(v for v in sizes if v != 1)
|
||||
return index, var_names, sizes
|
||||
return index, var_names, sizes # type: ignore[return-value]
|
||||
|
||||
# Try to further simplify the indexes even if simplify_loops didn't
|
||||
# convert it to the simplest form because of the interference from
|
||||
|
|
@ -269,7 +269,7 @@ class _RecordLoadStoreInner(V.MockHandler): # type: ignore[name-defined]
|
|||
# downstream users won't. Normalize this away.
|
||||
new_vars.pop()
|
||||
new_sizes.pop()
|
||||
return index, tuple(new_vars), tuple(new_sizes)
|
||||
return index, tuple(new_vars), tuple(new_sizes) # type: ignore[arg-type]
|
||||
|
||||
def load(self, name: str, index: sympy.Expr) -> str:
|
||||
self._reads.add(MemoryDep(name, *self.canonicalize(index)))
|
||||
|
|
|
|||
|
|
@ -368,7 +368,7 @@ def cat_tuned_op(match, inputs, dim, *, op, shape_of):
|
|||
if new_size is None:
|
||||
new_size = shape
|
||||
else:
|
||||
new_size[notdim] = V.graph.sizevars.guard_equals(
|
||||
new_size[notdim] = V.graph.sizevars.guard_equals( # type: ignore[call-overload]
|
||||
shape[notdim], new_size[notdim]
|
||||
)
|
||||
new_size[dim] += shape[dim]
|
||||
|
|
|
|||
|
|
@ -208,9 +208,9 @@ def ir_node_to_tensor(x, guard_shape=True):
|
|||
size = [shape_fn(s) for s in x.get_size()]
|
||||
stride: StrideType
|
||||
if is_storage_and_layout(x):
|
||||
stride = [shape_fn(s) for s in x.get_layout().stride]
|
||||
stride = [shape_fn(s) for s in x.get_layout().stride] # type: ignore[misc]
|
||||
else:
|
||||
stride = make_contiguous_strides_for(size)
|
||||
stride = make_contiguous_strides_for(size) # type: ignore[arg-type]
|
||||
dtype = x.get_dtype()
|
||||
device = x.get_device()
|
||||
size = convert_shape_to_symint(size)
|
||||
|
|
@ -294,7 +294,7 @@ class IRNode:
|
|||
return sympy_product(self.get_size())
|
||||
|
||||
def is_zero_elements(self):
|
||||
return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0))
|
||||
return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0)) # type: ignore[arg-type]
|
||||
|
||||
def realize(self):
|
||||
"""
|
||||
|
|
@ -1150,7 +1150,7 @@ class Reduction(Loops):
|
|||
):
|
||||
reindex = View.dynamic_reshape_indexer(reduction_ranges, [reduction_numel])
|
||||
need_mask = not V.graph.sizevars.is_expr_static_and_true(
|
||||
sympy.Eq(reduction_numel % split, 0)
|
||||
sympy.Eq(reduction_numel % split, 0) # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
def wrapper_fn(index, reduction_index):
|
||||
|
|
@ -1286,7 +1286,7 @@ class Reduction(Loops):
|
|||
wrapper_fn,
|
||||
ranges,
|
||||
reduction_ranges,
|
||||
[*ranges, split],
|
||||
[*ranges, split], # type: ignore[list-item]
|
||||
[block_size],
|
||||
reduction_type,
|
||||
split,
|
||||
|
|
@ -1522,7 +1522,7 @@ class WelfordReduction(Reduction):
|
|||
"""
|
||||
reduction_numel = sympy_product(reduction_ranges)
|
||||
need_mask = not V.graph.sizevars.is_expr_static_and_true(
|
||||
sympy.Eq(reduction_numel % split, 0)
|
||||
sympy.Eq(reduction_numel % split, 0) # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
if need_mask and reduction_type != "welford_combine":
|
||||
|
|
@ -1562,7 +1562,7 @@ class WelfordReduction(Reduction):
|
|||
)
|
||||
for loader in inner_fns
|
||||
),
|
||||
[*ranges, split],
|
||||
[*ranges, split], # type: ignore[list-item]
|
||||
[block_size],
|
||||
reduction_type,
|
||||
reduction_hint,
|
||||
|
|
@ -1587,7 +1587,7 @@ class WelfordReduction(Reduction):
|
|||
for i in intermediates
|
||||
),
|
||||
ranges,
|
||||
[split],
|
||||
[split], # type: ignore[list-item]
|
||||
# welford_reduce turns one input into three outputs, which are combined with welford_combine
|
||||
"welford_combine",
|
||||
reduction_hint,
|
||||
|
|
@ -1680,7 +1680,7 @@ class Scan(Loops):
|
|||
scan_numel = sizevars.simplify(sympy_product(scan_ranges))
|
||||
|
||||
# Scan with a single element is just a copy
|
||||
if sizevars.is_expr_static_and_true(sympy.Le(scan_numel, 1)):
|
||||
if sizevars.is_expr_static_and_true(sympy.Le(scan_numel, 1)): # type: ignore[arg-type]
|
||||
return Pointwise.create(
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
|
|
@ -1993,7 +1993,7 @@ class PermuteView(BaseView):
|
|||
|
||||
def make_reindexer(self):
|
||||
inv = {j: i for i, j in enumerate(self.dims)}
|
||||
inv = [inv[i] for i in range(len(self.dims))]
|
||||
inv = [inv[i] for i in range(len(self.dims))] # type: ignore[index]
|
||||
assert set(inv) == set(range(len(self.dims)))
|
||||
|
||||
def reindex(index):
|
||||
|
|
@ -2214,12 +2214,12 @@ class View(GenericView):
|
|||
|
||||
while stack_old:
|
||||
size_old = stack_old.pop()
|
||||
V.graph.sizevars.guard_equals(size_old, 1)
|
||||
V.graph.sizevars.guard_equals(size_old, 1) # type: ignore[arg-type]
|
||||
view_expr.append(sympy.Integer(0))
|
||||
|
||||
while stack_new:
|
||||
var, size_new = stack_new.pop()
|
||||
V.graph.sizevars.guard_equals(size_new, 1)
|
||||
V.graph.sizevars.guard_equals(size_new, 1) # type: ignore[arg-type]
|
||||
|
||||
view_expr = list(reversed(view_expr))
|
||||
assert len(view_expr) == len(old_size)
|
||||
|
|
@ -2227,7 +2227,7 @@ class View(GenericView):
|
|||
def reindex(index):
|
||||
assert len(index) == len(vars), (len(index), len(vars))
|
||||
replacements = dict(zip(vars, index))
|
||||
return tuple(sympy_subs(x, replacements) for x in view_expr)
|
||||
return tuple(sympy_subs(x, replacements) for x in view_expr) # type: ignore[arg-type]
|
||||
|
||||
return reindex
|
||||
|
||||
|
|
@ -2487,7 +2487,7 @@ class Layout(IRNode):
|
|||
if ndim not in [4, 5]:
|
||||
return False
|
||||
for left, right, size in zip(
|
||||
self.stride, make_channels_last_strides_for(self.size), self.size
|
||||
self.stride, make_channels_last_strides_for(self.size), self.size # type: ignore[arg-type]
|
||||
):
|
||||
if size != 1 and left != right:
|
||||
return False
|
||||
|
|
@ -2564,7 +2564,7 @@ class Layout(IRNode):
|
|||
)
|
||||
|
||||
def storage_size(self) -> sympy.Expr:
|
||||
return compute_required_storage_length(self.size, self.stride, self.offset)
|
||||
return compute_required_storage_length(self.size, self.stride, self.offset) # type: ignore[arg-type, return-value]
|
||||
|
||||
|
||||
class FixedLayout(Layout):
|
||||
|
|
@ -2583,9 +2583,9 @@ class FixedLayout(Layout):
|
|||
super().__init__(
|
||||
device,
|
||||
dtype,
|
||||
size,
|
||||
size, # type: ignore[arg-type]
|
||||
stride,
|
||||
offset,
|
||||
offset, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
def make_indexer(self):
|
||||
|
|
@ -2715,7 +2715,7 @@ class AliasedLayout(Layout):
|
|||
return True
|
||||
from .compile_fx import ALIGNMENT
|
||||
|
||||
return V.graph.sizevars.statically_known_multiple_of(offset, ALIGNMENT)
|
||||
return V.graph.sizevars.statically_known_multiple_of(offset, ALIGNMENT) # type: ignore[arg-type]
|
||||
|
||||
|
||||
class NoneLayout(IRNode):
|
||||
|
|
@ -2882,7 +2882,7 @@ class Buffer(IRNode):
|
|||
self.layout = self.layout.as_same_order(stride)
|
||||
|
||||
def is_zero_elements(self):
|
||||
return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0))
|
||||
return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0)) # type: ignore[arg-type]
|
||||
|
||||
def make_loader(self):
|
||||
# Loading from a zero-element buffer is a no-op
|
||||
|
|
@ -3177,7 +3177,7 @@ class ComputedBuffer(Buffer):
|
|||
else:
|
||||
indices = index_vars
|
||||
stride_lengths = [
|
||||
V.graph.sizevars.stride_hints(expr, indices) for expr in reads
|
||||
V.graph.sizevars.stride_hints(expr, indices) for expr in reads # type: ignore[arg-type]
|
||||
]
|
||||
from .scheduler import pick_loop_order
|
||||
|
||||
|
|
@ -3907,13 +3907,13 @@ class ExternKernel(InputsKernel):
|
|||
else:
|
||||
type_ = None
|
||||
kwargs.append(
|
||||
V.graph.wrapper_code.val_to_cpp_arg_str(
|
||||
V.graph.wrapper_code.val_to_cpp_arg_str( # type: ignore[arg-type]
|
||||
type_, v, self.is_legacy_abi_kernel()
|
||||
)
|
||||
)
|
||||
else:
|
||||
kwargs = [
|
||||
f"{k}={V.graph.wrapper_code.val_to_arg_str(v)}"
|
||||
f"{k}={V.graph.wrapper_code.val_to_arg_str(v)}" # type: ignore[misc]
|
||||
for k, v in self.kwargs.items()
|
||||
]
|
||||
return kwargs
|
||||
|
|
@ -3962,7 +3962,7 @@ class ExternKernel(InputsKernel):
|
|||
_, add_var = var_builder("c")
|
||||
replacement = dict(zip(index_vars, reindex([add_var(x) for x in new_sizes])))
|
||||
|
||||
index = sympy_subs(sympy.expand(index), replacement)
|
||||
index = sympy_subs(sympy.expand(index), replacement) # type: ignore[arg-type]
|
||||
return index, tuple(new_sizes)
|
||||
|
||||
def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
|
||||
|
|
|
|||
|
|
@ -243,14 +243,14 @@ def conv_layout(
|
|||
ir.ir_node_to_tensor(weight, guard_shape=True),
|
||||
ir.ir_node_to_tensor(bias, guard_shape=True),
|
||||
stride,
|
||||
tuple(V.graph.sizevars.size_hint(p) for p in padding),
|
||||
tuple(V.graph.sizevars.size_hint(p) for p in padding), # type: ignore[arg-type]
|
||||
dilation,
|
||||
transposed,
|
||||
tuple(V.graph.sizevars.size_hint(p) for p in output_padding),
|
||||
tuple(V.graph.sizevars.size_hint(p) for p in output_padding), # type: ignore[arg-type]
|
||||
groups,
|
||||
)
|
||||
sizes = ir.convert_shape_to_inductor(output.size())
|
||||
stride = ir.convert_shape_to_inductor(output.stride())
|
||||
stride = ir.convert_shape_to_inductor(output.stride()) # type: ignore[assignment]
|
||||
|
||||
return ir.FixedLayout(
|
||||
x.get_device(),
|
||||
|
|
@ -419,7 +419,7 @@ def convolution(
|
|||
and not transposed
|
||||
and is_zeros(output_padding)
|
||||
# there are some odd models where this check fails (e.g. shufflenet_v2_x1_0)
|
||||
and V.graph.sizevars.statically_known_equals(in_chan, x.get_size()[1])
|
||||
and V.graph.sizevars.statically_known_equals(in_chan, x.get_size()[1]) # type: ignore[arg-type]
|
||||
):
|
||||
if (
|
||||
is_ones(kernel_shape)
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ def filtered_configs(
|
|||
m = max(
|
||||
next_power_of_2(
|
||||
V.graph.sizevars.size_hint(
|
||||
m, fallback=torch._inductor.config.unbacked_symint_fallback
|
||||
m, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type]
|
||||
)
|
||||
),
|
||||
min_block_size,
|
||||
|
|
@ -44,7 +44,7 @@ def filtered_configs(
|
|||
n = max(
|
||||
next_power_of_2(
|
||||
V.graph.sizevars.size_hint(
|
||||
n, fallback=torch._inductor.config.unbacked_symint_fallback
|
||||
n, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type]
|
||||
)
|
||||
),
|
||||
min_block_size,
|
||||
|
|
@ -52,7 +52,7 @@ def filtered_configs(
|
|||
k = max(
|
||||
next_power_of_2(
|
||||
V.graph.sizevars.size_hint(
|
||||
k, fallback=torch._inductor.config.unbacked_symint_fallback
|
||||
k, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type]
|
||||
)
|
||||
),
|
||||
min_block_size,
|
||||
|
|
|
|||
|
|
@ -983,8 +983,8 @@ def pointwise_cat(inputs, dim=0):
|
|||
inputs_ranges: List[Tuple[sympy.Expr, sympy.Expr]] = []
|
||||
prev_end = 0
|
||||
for inp in inputs:
|
||||
inputs_ranges.append((prev_end, prev_end + inp.get_size()[dim]))
|
||||
prev_end = inputs_ranges[-1][-1]
|
||||
inputs_ranges.append((prev_end, prev_end + inp.get_size()[dim])) # type: ignore[arg-type]
|
||||
prev_end = inputs_ranges[-1][-1] # type: ignore[assignment]
|
||||
|
||||
inputs_loaders = [inp.make_loader() for inp in inputs]
|
||||
|
||||
|
|
@ -1215,7 +1215,7 @@ def unfold(x, dimension, size, step):
|
|||
dim_size = sizes[dim]
|
||||
sizevars = V.graph.sizevars
|
||||
sizevars.guard_leq(size, dim_size)
|
||||
sizevars.guard_lt(0, step)
|
||||
sizevars.guard_lt(0, step) # type: ignore[arg-type]
|
||||
|
||||
new_dim_size = FloorDiv(dim_size - size, step) + 1
|
||||
if sizevars.size_hint(dim_size) > 0:
|
||||
|
|
@ -2371,8 +2371,8 @@ def select_scatter(x, src, dim: int, index: int):
|
|||
dim = _validate_dim(x, dim, 0)
|
||||
if V.graph.sizevars.evaluate_expr(sympy.Lt(index, 0)):
|
||||
index = index + x.get_size()[dim]
|
||||
V.graph.sizevars.guard_leq(0, index)
|
||||
V.graph.sizevars.guard_lt(index, x.get_size()[dim])
|
||||
V.graph.sizevars.guard_leq(0, index) # type: ignore[arg-type]
|
||||
V.graph.sizevars.guard_lt(index, x.get_size()[dim]) # type: ignore[arg-type]
|
||||
src = expand(unsqueeze(src, dim), x.get_size())
|
||||
src_loader = src.make_loader()
|
||||
|
||||
|
|
@ -3673,7 +3673,7 @@ def constant_pad_nd(x, padding, fill_value=0):
|
|||
# if padding is a complicated expression, hoist it
|
||||
bounds_precomp: List[Tuple[sympy.Symbol, Any]] = []
|
||||
for l, h in bounds:
|
||||
bounds_precomp.append((V.graph.sizevars.lookup_precomputed_size(l), h))
|
||||
bounds_precomp.append((V.graph.sizevars.lookup_precomputed_size(l), h)) # type: ignore[arg-type]
|
||||
|
||||
output_size = list(sizes[:n])
|
||||
mask_sizes = []
|
||||
|
|
@ -3770,7 +3770,7 @@ def pooling_size(x, i, kernel_size, stride, padding, ceil_mode):
|
|||
if V.graph.sizevars.size_hint((x_alt - 1) * stride[i] - x - padding[i]) >= 0:
|
||||
# Sliding windows must start within the input or left padding
|
||||
x_alt -= 1 # type: ignore[assignment]
|
||||
V.graph.sizevars.guard_leq(0, x_alt * stride[i] - x - padding[i])
|
||||
V.graph.sizevars.guard_leq(0, x_alt * stride[i] - x - padding[i]) # type: ignore[arg-type]
|
||||
if V.graph.sizevars.size_hint(x_out - x_alt) == 0:
|
||||
# ceil mode is actually a no-op, lets guard on that
|
||||
V.graph.sizevars.guard_equals(x_out, x_alt)
|
||||
|
|
|
|||
|
|
@ -128,7 +128,7 @@ class SizeVarAllocator:
|
|||
# actual iteration range is to size-1
|
||||
iter_ranges_zero = {k: 0 for k, v in var_ranges.items()}
|
||||
base_lowest = sympy_subs(base, iter_ranges_zero)
|
||||
if self.statically_known_leq(0, base_lowest):
|
||||
if self.statically_known_leq(0, base_lowest): # type: ignore[arg-type]
|
||||
# can't replace with indexing div if base can be negative
|
||||
base_pos = True
|
||||
else:
|
||||
|
|
@ -272,7 +272,7 @@ class SizeVarAllocator:
|
|||
"""
|
||||
Returns a bool indicating if it is sound to optimize as if left and right are equal.
|
||||
"""
|
||||
return self.is_expr_static_and_true(sympy.Eq(left, right))
|
||||
return self.is_expr_static_and_true(sympy.Eq(left, right)) # type: ignore[arg-type]
|
||||
|
||||
# See Note - [On Statically Known]
|
||||
def statically_known_list_equals(self, left: List[Expr], right: List[Expr]) -> bool:
|
||||
|
|
@ -307,7 +307,7 @@ class SizeVarAllocator:
|
|||
Return a bool indicating if it is sound to optimize for the numerator being a multiple of the denominator.
|
||||
"""
|
||||
expr = sympy.Eq(numerator % denominator, 0)
|
||||
return self.is_expr_static_and_true(expr)
|
||||
return self.is_expr_static_and_true(expr) # type: ignore[arg-type]
|
||||
|
||||
# The guard functions require you to ALREADY KNOW that a particular
|
||||
# condition holds. If you don't know (you want to guard on an expression
|
||||
|
|
@ -316,9 +316,9 @@ class SizeVarAllocator:
|
|||
|
||||
def guard_equals(self, left: Expr, right: Expr) -> Expr:
|
||||
if isinstance(left, Expr):
|
||||
left = sympy_subs(left, self.inv_precomputed_replacements)
|
||||
left = sympy_subs(left, self.inv_precomputed_replacements) # type: ignore[arg-type]
|
||||
if isinstance(right, Expr):
|
||||
right = sympy_subs(right, self.inv_precomputed_replacements)
|
||||
right = sympy_subs(right, self.inv_precomputed_replacements) # type: ignore[arg-type]
|
||||
assert self.shape_env.evaluate_expr(sympy.Eq(left, right))
|
||||
return left
|
||||
|
||||
|
|
@ -329,16 +329,16 @@ class SizeVarAllocator:
|
|||
assert self.shape_env.evaluate_expr(sympy.Lt(left, right))
|
||||
|
||||
def expect_true(self, expr: Expr, *, msg: str) -> None:
|
||||
expr = sympy_subs(expr, self.inv_precomputed_replacements)
|
||||
expr = sympy_subs(expr, self.inv_precomputed_replacements) # type: ignore[arg-type]
|
||||
self.shape_env.defer_runtime_assert(expr, msg, fx_node=None)
|
||||
|
||||
def expect_equals(self, left: Expr, right: Expr, *, msg: str) -> Expr:
|
||||
# Prefer returning the expression without unbacked symints
|
||||
if self.shape_env.is_unbacked_symint(left):
|
||||
self.expect_true(sympy.Eq(left, right), msg=msg)
|
||||
self.expect_true(sympy.Eq(left, right), msg=msg) # type: ignore[arg-type]
|
||||
return right
|
||||
elif self.shape_env.is_unbacked_symint(right):
|
||||
self.expect_true(sympy.Eq(left, right), msg=msg)
|
||||
self.expect_true(sympy.Eq(left, right), msg=msg) # type: ignore[arg-type]
|
||||
return left
|
||||
else:
|
||||
return self.guard_equals(left, right)
|
||||
|
|
@ -399,8 +399,8 @@ class SizeVarAllocator:
|
|||
return [self.evaluate_static_shape(x) for x in left]
|
||||
|
||||
def remove_precomputed_replacements(self, expr: Expr) -> Expr:
|
||||
if any(s.name.startswith("ps") for s in expr.free_symbols):
|
||||
return sympy_subs(expr, self.inv_precomputed_replacements)
|
||||
if any(s.name.startswith("ps") for s in expr.free_symbols): # type: ignore[attr-defined]
|
||||
return sympy_subs(expr, self.inv_precomputed_replacements) # type: ignore[arg-type]
|
||||
return expr
|
||||
|
||||
def symbolic_hint(self, expr: Expr) -> Expr:
|
||||
|
|
@ -410,7 +410,7 @@ class SizeVarAllocator:
|
|||
return expr
|
||||
free_symbols = expr.free_symbols
|
||||
if not free_symbols:
|
||||
return int(expr)
|
||||
return int(expr) # type: ignore[return-value]
|
||||
expr = self.remove_precomputed_replacements(expr)
|
||||
return sympy_subs(expr, self.var_to_val)
|
||||
|
||||
|
|
@ -422,9 +422,9 @@ class SizeVarAllocator:
|
|||
s: self.shape_env.var_to_range.get(s, None) for s in expr.free_symbols
|
||||
}
|
||||
if all(vr is not None for vr in sym_vrs.values()):
|
||||
expr_vr = bound_sympy(expr, sym_vrs)
|
||||
lower = self.size_hint(expr_vr.lower)
|
||||
upper = self.size_hint(expr_vr.upper)
|
||||
expr_vr = bound_sympy(expr, sym_vrs) # type: ignore[arg-type]
|
||||
lower = self.size_hint(expr_vr.lower) # type: ignore[arg-type]
|
||||
upper = self.size_hint(expr_vr.upper) # type: ignore[arg-type]
|
||||
fallback = min(max(fallback, lower), upper)
|
||||
return fallback
|
||||
try:
|
||||
|
|
@ -522,8 +522,8 @@ class SizeVarAllocator:
|
|||
support_vars: Optional[List[sympy.Symbol]] = None,
|
||||
) -> List[int]:
|
||||
for v in index.free_symbols:
|
||||
if v.name.startswith("indirect"):
|
||||
index = sympy_subs(index, {v: 0})
|
||||
if v.name.startswith("indirect"): # type: ignore[attr-defined]
|
||||
index = sympy_subs(index, {v: 0}) # type: ignore[dict-item]
|
||||
result = []
|
||||
for s in self.stride_vars(index, vars, support_vars):
|
||||
try:
|
||||
|
|
@ -538,7 +538,7 @@ class SizeVarAllocator:
|
|||
order.sort(key=lambda x: (strides[x] == 0, strides[x]))
|
||||
return order
|
||||
|
||||
def lookup_precomputed_size(self, expr: Expr) -> sympy.Symbol:
|
||||
def lookup_precomputed_size(self, expr: Expr) -> Expr:
|
||||
if (
|
||||
isinstance(expr, (int, sympy.Symbol, sympy.Number))
|
||||
or expr.is_number
|
||||
|
|
|
|||
|
|
@ -569,8 +569,8 @@ def sympy_subs(expr: sympy.Expr, replacements: Dict[sympy.Expr, Any]) -> sympy.E
|
|||
if isinstance(replacement, str):
|
||||
return sympy.Symbol(
|
||||
replacement,
|
||||
integer=replaced.is_integer,
|
||||
nonnegative=replaced.is_nonnegative,
|
||||
integer=replaced.is_integer, # type: ignore[attr-defined]
|
||||
nonnegative=replaced.is_nonnegative, # type: ignore[attr-defined]
|
||||
)
|
||||
else:
|
||||
return replacement
|
||||
|
|
@ -582,11 +582,11 @@ def sympy_subs(expr: sympy.Expr, replacements: Dict[sympy.Expr, Any]) -> sympy.E
|
|||
|
||||
|
||||
def free_symbol_startswith(index: sympy.Expr, prefix: str):
|
||||
return any(v.name.startswith(prefix) for v in index.free_symbols)
|
||||
return any(v.name.startswith(prefix) for v in index.free_symbols) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def free_symbol_has(index: sympy.Expr, pattern: str):
|
||||
return any(pattern in v.name for v in index.free_symbols)
|
||||
return any(pattern in v.name for v in index.free_symbols) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def is_symbolic(a: Any) -> bool:
|
||||
|
|
@ -1081,7 +1081,7 @@ def get_sympy_Expr_dtype(val: sympy.Expr) -> torch.dtype:
|
|||
assert isinstance(
|
||||
val, sympy.Expr
|
||||
), "only support sympy.Expr as input to get_sympy_Expr_dtype"
|
||||
if val.is_integer:
|
||||
if val.is_integer: # type: ignore[attr-defined]
|
||||
return torch.int64
|
||||
else:
|
||||
return torch.float64
|
||||
|
|
|
|||
|
|
@ -468,7 +468,7 @@ class ConvReLU1d(nnqat.Conv1d, nni._FusedModule):
|
|||
weight_fake_quant: fake quant module for weight
|
||||
|
||||
"""
|
||||
_FLOAT_MODULE = nni.ConvReLU1d
|
||||
_FLOAT_MODULE = nni.ConvReLU1d # type: ignore[assignment]
|
||||
_FLOAT_CONV_MODULE = nn.Conv1d
|
||||
_FLOAT_BN_MODULE: None = None
|
||||
_FLOAT_RELU_MODULE = nn.ReLU
|
||||
|
|
@ -600,7 +600,7 @@ class ConvReLU2d(nnqat.Conv2d, nni._FusedModule):
|
|||
weight_fake_quant: fake quant module for weight
|
||||
|
||||
"""
|
||||
_FLOAT_MODULE = nni.ConvReLU2d
|
||||
_FLOAT_MODULE = nni.ConvReLU2d # type: ignore[assignment]
|
||||
_FLOAT_CONV_MODULE = nn.Conv2d
|
||||
_FLOAT_BN_MODULE: None = None
|
||||
_FLOAT_RELU_MODULE = nn.ReLU
|
||||
|
|
@ -773,7 +773,7 @@ class ConvReLU3d(nnqat.Conv3d, nni._FusedModule):
|
|||
weight_fake_quant: fake quant module for weight
|
||||
|
||||
"""
|
||||
_FLOAT_MODULE = nni.ConvReLU3d
|
||||
_FLOAT_MODULE = nni.ConvReLU3d # type: ignore[assignment]
|
||||
_FLOAT_CONV_MODULE = nn.Conv3d
|
||||
_FLOAT_BN_MODULE: None = None
|
||||
_FLOAT_RELU_MODULE = nn.ReLU
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user