Revert "Enable possibly-undefined error code (#118533)"

This reverts commit 4f13f69a45.

Reverted https://github.com/pytorch/pytorch/pull/118533 on behalf of https://github.com/clee2000 due to sorry i'm trying to figure out a codev merge conflict, if this works i'll be back to rebase and merge ([comment](https://github.com/pytorch/pytorch/pull/118533#issuecomment-1917695185))
PyTorch MergeBot 2024-01-30 19:00:34 +00:00
parent 6511811ebb
commit 40ece2e579
94 changed files with 197 additions and 200 deletions
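For context, the check being turned off again here is mypy's `possibly-undefined` error code, which reports reads of a variable that is only assigned on some control-flow paths. A minimal sketch of the pattern it flags (illustrative only, not taken from the PyTorch sources):

# Illustrative sketch: the kind of code mypy's possibly-undefined check reports.
def describe(flag: bool) -> str:
    if flag:
        message = "enabled"
    # When flag is False, `message` is never assigned, so mypy reports
    # something like: Name "message" may be undefined  [possibly-undefined]
    return message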

View File

@@ -13,7 +13,6 @@ show_column_numbers = True
check_untyped_defs = True
follow_imports = normal
local_partial_types = True
-enable_error_code = possibly-undefined
# do not reenable this:
# https://github.com/pytorch/pytorch/pull/60006#issuecomment-866130657
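The remaining hunks drop the suppressions that had been added alongside the config line above. They use the two standard mypy mechanisms, shown here as a sketch on a hypothetical module rather than an actual PyTorch file:

# mypy: disable-error-code="possibly-undefined"  # file-level opt-out, placed at the top of a module

def pick(flag: bool) -> int:
    if flag:
        value = 1
    return value  # type: ignore[possibly-undefined]  # per-line suppression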

View File

@@ -1,4 +1,3 @@
-# mypy: disable-error-code="possibly-undefined"
# flake8: noqa
import torch
from torch.testing._internal.common_utils import TEST_NUMPY

View File

@@ -1,4 +1,3 @@
-# mypy: disable-error-code="possibly-undefined"
# flake8: noqa
import torch
from torch.testing._internal.common_utils import TEST_NUMPY

View File

@@ -515,7 +515,7 @@ for name in ("sqrt", "cos", "cosh", "sin", "sinh", "tan", "tanh", "asin", "acos"
sym_sqrt = current_module._sym_sqrt
__all__.append("sym_sqrt")
-del fn, name, sym_name, current_module  # type: ignore[possibly-undefined]
+del fn, name, sym_name, current_module
def sym_ite(b, t, f):

View File

@@ -2832,7 +2832,7 @@ def _rnn_helper(
final_hiddens.append(bwd_hidden)
if bidirectional:
-input = torch.cat([fwd_inp, bwd_inp], fwd_inp.dim() - 1)  # type: ignore[possibly-undefined]
+input = torch.cat([fwd_inp, bwd_inp], fwd_inp.dim() - 1)
else:
input = fwd_inp

View File

@@ -163,7 +163,7 @@ def preserve_global_state(fn):
random.setstate(py_rng_state)
torch.random.set_rng_state(torch_rng_state)
if torch.cuda.is_available():
-torch.cuda.set_rng_state(cuda_rng_state)  # type: ignore[possibly-undefined]
+torch.cuda.set_rng_state(cuda_rng_state)
torch.fx.graph_module._forward_from_src = prior_fwd_from_src
assert (
guards.check()
@@ -568,7 +568,7 @@ def _compile(
code.co_name,
code.co_filename,
code.co_firstlineno,
-out_code,  # type: ignore[possibly-undefined]
+out_code,
)
for hook in _bytecode_hooks.values():

View File

@@ -46,7 +46,7 @@ if use_buck:
"//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu",
"//deeplearning/fbgemm/fbgemm_gpu:sparse_ops",
]
-cur_target = libfb.py.build_info.BuildInfo.get_build_rule().replace("fbcode:", "//")  # type: ignore[possibly-undefined]
+cur_target = libfb.py.build_info.BuildInfo.get_build_rule().replace("fbcode:", "//")
extra_imports = "\n".join([f'torch.ops.load_library("{x}")' for x in extra_deps])

View File

@@ -1430,7 +1430,7 @@ def export(
example_fake_inputs,
graph_captured_input,
graph_captured_result,
-result_traced,  # type: ignore[possibly-undefined]
+result_traced,
flat_args_dynamic_dims,
)
# Store constraints and inputs as metadata for user passes, e.g. turn constraints to runtime check

View File

@@ -1115,7 +1115,7 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
tos = self.pop()
_ = self.pop()
if preserve_tos:
-self.push(tos)  # type: ignore[possibly-undefined]
+self.push(tos)
def FOR_ITER(self, inst):
it = self.pop().realize()

View File

@@ -118,7 +118,7 @@ torch._inductor.config.{"cpp" if device == "cpu" else "triton"}.inject_relu_bug_
finally:
log.removeHandler(log_handler)
if cwd is not None:
-os.chdir(prev_cwd)  # type: ignore[possibly-undefined]
+os.chdir(prev_cwd)
# Make sure we don't leave buggy compiled frames lying
# around
torch._dynamo.reset()

View File

@@ -773,7 +773,7 @@ def preserve_rng_state():
with torch.utils._python_dispatch._disable_current_modes():
torch.random.set_rng_state(rng_state)
if torch.cuda.is_available():
-torch.cuda.set_rng_state(cuda_rng_state)  # type: ignore[possibly-undefined]
+torch.cuda.set_rng_state(cuda_rng_state)
def is_jit_model(model0):
@@ -892,7 +892,7 @@ def timed(model, example_inputs, times=1):
result = model(*example_inputs)
synchronize()
t1 = time.perf_counter()
-return result, t1 - t0  # type: ignore[possibly-undefined]
+return result, t1 - t0
def check_is_cuda(gm, example_inputs):

View File

@@ -199,7 +199,7 @@ def wrap_outputs_maintaining_identity(
result.append(unwrapped_input_to_orig_input[id(output)])
continue
if out_dims_specified:
-result.append(wrap_fn(output, flat_out_dims[i]))  # type: ignore[possibly-undefined, index]
+result.append(wrap_fn(output, flat_out_dims[i]))  # type: ignore[index]
else:
result.append(wrap_fn(output))

View File

@@ -163,7 +163,7 @@ def parse_ttir(ttir, kwargs):
return None
try:
-import lark  # type: ignore[import-not-found]
+import lark
from lark import Lark, Transformer, v_args
except ModuleNotFoundError:
warnings.warn(

View File

@@ -440,25 +440,25 @@ class BenchmarkRequest:
output_tensor = self.output_tensor_meta.to_tensor()
if debug:
-create_tensor_elapse = time.time() - start_ts  # type: ignore[possibly-undefined]
+create_tensor_elapse = time.time() - start_ts
start_ts = time.time()
fn = self.make_run_fn(*input_tensors, output_tensor=output_tensor)
if debug:
-load_elapse = time.time() - start_ts  # type: ignore[possibly-undefined]
+load_elapse = time.time() - start_ts
start_ts = time.time()
out = do_bench(fn)
torch.cuda.synchronize()  # shake out any CUDA errors
if debug:
-bench_elapse = time.time() - start_ts  # type: ignore[possibly-undefined]
+bench_elapse = time.time() - start_ts
log.debug(
"InChildProcess %s: load %f, create tensor %f, bench %f",
str(self),
-load_elapse,  # type: ignore[possibly-undefined]
-create_tensor_elapse,  # type: ignore[possibly-undefined]
+load_elapse,
+create_tensor_elapse,
bench_elapse,
)
self.cleanup_run_fn()

View File

@@ -99,7 +99,7 @@ class CutlassEVTEpilogueTypeFormatter:
result = pnode.inner_fn(index)
# each epilogue node results in a single "using" statement and may refer to the previous steps by name
formatter.aliases[node.name] = result
-res = formatter.getvalue(result)  # type: ignore[possibly-undefined]
+res = formatter.getvalue(result)
if _MAGIC_SYMPY_ERROR_STRING in res:
raise CUTLASSEVTOpNotImplementedError(
"sympy / indexing expressions not yet supported in EVT fusion"
@@ -266,7 +266,7 @@ class CutlassEVTEpilogueArgumentFormatter:
if node.name is not None:
formatter.aliases[node.name] = result
-res: str = formatter.getvalue(result)  # type: ignore[possibly-undefined]
+res: str = formatter.getvalue(result)
if _MAGIC_SYMPY_ERROR_STRING in res:
raise CUTLASSEVTOpNotImplementedError(
"sympy / indexing expressions not yet supported in EVT fusion"

View File

@@ -155,7 +155,7 @@ def get_cpp_op_schema(kernel: torch._ops.OpOverload) -> str:
cpp_return_value = f"std::tuple<{tuple_returns}>"
cpp_arg_type = [f"{convert_arg_type(arg)} {arg.name}" for arg in args]
-return f"{cpp_return_value}({', '.join(cpp_arg_type)})"  # type: ignore[possibly-undefined]
+return f"{cpp_return_value}({', '.join(cpp_arg_type)})"
# TODO: Move to a well known place

View File

@@ -209,7 +209,7 @@ def estimate_nccl_collective_runtime(snode: "BaseSchedulerNode") -> float:
nsteps = nRanks - 1
# Convert bus BW to algorithm BW (tensor bytes / algoBW = actual execution time)
-ratio = (1.0 * nRanks) / nsteps  # type: ignore[possibly-undefined]
+ratio = (1.0 * nRanks) / nsteps
bandwidth = busBw * ratio
# Convert GB/s to GB/ns
bandwidth_GB_per_ns = bandwidth / 1e9
@@ -236,7 +236,7 @@ def estimate_nccl_collective_runtime(snode: "BaseSchedulerNode") -> float:
if nNodes > 1:
netOverhead = 1.0  # getNetOverhead(comm);
intraLat = max(intraLat, netOverhead)
-latency += (nsteps - nInterSteps) * intraLat + nInterSteps * interLat  # type: ignore[possibly-undefined]
+latency += (nsteps - nInterSteps) * intraLat + nInterSteps * interLat
# Convert us to ns
latency_ns = latency * 1e3

View File

@@ -170,9 +170,9 @@ class PostGradBatchLinearFusion(BatchFusion):
input, weight = node.args
bias = None
batch_nodes.append(node)
-batch_inputs.append(input)  # type: ignore[possibly-undefined]
-batch_weights.append(weight)  # type: ignore[possibly-undefined]
-batch_biases.append(bias)  # type: ignore[possibly-undefined]
+batch_inputs.append(input)
+batch_weights.append(weight)
+batch_biases.append(bias)
with graph.inserting_before(subset[-1]):
fused_inputs = decompose_stack(graph, batch_inputs)
@@ -191,7 +191,7 @@ class PostGradBatchLinearFusion(BatchFusion):
new_bias_add = graph.call_function(
aten.add, args=((batch_biases[i], new_mm))
)
-new_mm_cont = new_bias_add if has_bias else new_mm  # type: ignore[possibly-undefined]
+new_mm_cont = new_bias_add if has_bias else new_mm
original_mm.replace_all_uses_with(new_mm_cont)
new_mm_cont.meta.update(original_mm.meta)
graph.erase_node(original_mm)

View File

@@ -283,7 +283,7 @@ if torch._C._has_mkldnn:
L[aten.mul](out, negative_slope),
)
if lowp_dtype:
-out = L[prims.convert_element_type.default](out, dtype=dtype2)  # type: ignore[possibly-undefined]
+out = L[prims.convert_element_type.default](out, dtype=dtype2)
return out
return fn
@@ -324,7 +324,7 @@ if torch._C._has_mkldnn:
out = L[prims.convert_element_type.default](out, dtype=torch.float)
out = L[aten.clamp_max](L[aten.clamp_min](out, min_value), max_value)
if lowp_dtype:
-out = L[prims.convert_element_type.default](out, dtype=dtype2)  # type: ignore[possibly-undefined]
+out = L[prims.convert_element_type.default](out, dtype=dtype2)
return out
return fn

View File

@@ -105,7 +105,7 @@ def pre_grad_passes(gm: torch.fx.GraphModule, example_inputs):
gm_after_fx_passes = gm.__copy__()
numeric_check_if_enabled(
-gm_before_fx_passes,  # type: ignore[possibly-undefined]
+gm_before_fx_passes,
gm_after_fx_passes,
example_inputs,
config.fx_passes_numeric_check.get("num_iterations", 1),

View File

@@ -1360,7 +1360,7 @@ def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float3
graph.erase_node(conv_node)
# Erase the dequant pattern
if dtype == torch.bfloat16:
-graph.erase_node(convert_to_bf16)  # type: ignore[possibly-undefined]
+graph.erase_node(convert_to_bf16)
# Erase the dequant pattern
graph.erase_node(mul_node)
graph.erase_node(sub_node)
@@ -1369,7 +1369,7 @@ def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float3
if clone_node is not None:
graph.erase_node(clone_node)
if dtype == torch.bfloat16:
-graph.erase_node(weight_to_bf16_node)  # type: ignore[possibly-undefined]
+graph.erase_node(weight_to_bf16_node)
graph.erase_node(dequant_per_channel)
counters["inductor"]["qconv2d_weight_prepack_matcher_count"] += 1
counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"] += len(
@@ -1697,14 +1697,14 @@ def _register_qlinear_weight_prepack_pass(
if input_contiguous:
graph.erase_node(output_reshape_node)
elif not input_contiguous and bias:
-graph.erase_node(output_add_node_for_bias)  # type: ignore[possibly-undefined]
+graph.erase_node(output_add_node_for_bias)
graph.erase_node(linear_node)
if input_dim_exceeds_two:
if input_contiguous:
graph.erase_node(act_reshape_node)
else:
graph.erase_node(act_expand_node)
-graph.erase_node(wgt_expand_node)  # type: ignore[possibly-undefined]
+graph.erase_node(wgt_expand_node)
if dtype == torch.bfloat16:
graph.erase_node(activation_to_bf16_node)
# Erase the dequant pattern
@@ -1714,7 +1714,7 @@ def _register_qlinear_weight_prepack_pass(
# Erase the dequant per channel pattern
graph.erase_node(t_node)
if dtype == torch.bfloat16:
-graph.erase_node(weight_to_bf16_node)  # type: ignore[possibly-undefined]
+graph.erase_node(weight_to_bf16_node)
graph.erase_node(dequant_per_channel)
counters["inductor"]["qlinear_weight_prepack_matcher_count"] += 1

View File

@@ -845,7 +845,7 @@ class GraphLowering(torch.fx.Interpreter):
):
debug("fallback_handler")
result = fallback_handler(n.target, add_to_fallback_set=False)(
-*args, **kwargs  # type: ignore[possibly-undefined]
+*args, **kwargs
)
elif n.op == "call_function" and n.target in layout_constraints:
debug("layout_constraints")

View File

@@ -607,7 +607,7 @@ def register_pointwise(
fn,
override_return_dtype=override_return_dtype,
override_fn_when_input_bool=override_fn_when_input_bool,
-override_fn_when_cuda_float64=fn_libdevice if use_libdevice_for_f64 else None,  # type: ignore[possibly-undefined]
+override_fn_when_cuda_float64=fn_libdevice if use_libdevice_for_f64 else None,
allow_alpha=allow_alpha,
)
fn = register_lowering(
@@ -3630,8 +3630,8 @@ def _reflection_padnd_backward(grad_output, x, padding):
out = right_reflect[i]
index_range = (xyz[i], dhw[i] - padding_right[i], dhw[i] - 1)
-outs.append(out)  # type: ignore[possibly-undefined]
-index_ranges.append(index_range)  # type: ignore[possibly-undefined]
+outs.append(out)
+index_ranges.append(index_range)
grad = accumulate(grad, outs, index_ranges)

View File

@@ -1196,7 +1196,7 @@ class PatternMatcherPass:
if (
self.prevent_match_across_mutations
and is_match(m)
-and len(set(map(get_mutation_region_id_partial, m.nodes))) != 1  # type: ignore[possibly-undefined]
+and len(set(map(get_mutation_region_id_partial, m.nodes))) != 1
):
continue
if os.environ.get("TORCHINDUCTOR_PATTERN_MATCH_DEBUG") == node.name:

View File

@@ -1038,7 +1038,7 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
else:
fused_nodes.append(node)
-return cls(producer.scheduler, fused_nodes, prev_node_1, prev_node_2)  # type: ignore[possibly-undefined]
+return cls(producer.scheduler, fused_nodes, prev_node_1, prev_node_2)
def __init__(
self,
@@ -2256,13 +2256,13 @@ class Scheduler:
if node.is_template():
node, *epilogue = node.get_nodes()
-self.get_backend(device).codegen_template(node, epilogue)  # type: ignore[possibly-undefined]
+self.get_backend(device).codegen_template(node, epilogue)
elif node.is_extern():
self.codegen_extern_call(node)
elif node.is_foreach():
-self.get_backend(device).codegen_foreach(node)  # type: ignore[possibly-undefined]
+self.get_backend(device).codegen_foreach(node)
elif isinstance(node, (FusedSchedulerNode, SchedulerNode)):
-self.get_backend(device).codegen_nodes(node.get_nodes())  # type: ignore[possibly-undefined]
+self.get_backend(device).codegen_nodes(node.get_nodes())
else:
assert isinstance(node, NopKernelSchedulerNode)
node.allocate()
@@ -2271,7 +2271,7 @@ class Scheduler:
V.graph.wrapper_code.generate_inf_and_nan_checker(node)
if config.triton.debug_sync_kernel:
-self.get_backend(device).codegen_sync()  # type: ignore[possibly-undefined]
+self.get_backend(device).codegen_sync()
self.available_buffer_names.update(node.get_names())

View File

@@ -331,7 +331,7 @@ def timed(
synchronize(device)
t1 = time.perf_counter()
# GC the result after timing
-assert result is not None  # type: ignore[possibly-undefined]
+assert result is not None
return t1 - t0

View File

@@ -1147,7 +1147,7 @@ def _parse_qr_mode(mode: str) -> Tuple[bool, bool]:
f"but expected one of 'reduced' (default), 'r', or 'complete'"
),
)
-return compute_q, reduced  # type: ignore[possibly-undefined]
+return compute_q, reduced
@register_meta([aten.linalg_qr.default, aten.linalg_qr.out])
@@ -1412,7 +1412,7 @@ def triangular_solve_meta(
cloned_coefficient = self.new_empty([0])
else:
torch._check(False, lambda: "triangular_solve: Got an unexpected layout.")
-return solution, cloned_coefficient  # type: ignore[possibly-undefined]
+return solution, cloned_coefficient
# From aten/src/ATen/native/LinearAlgebra.cpp
@@ -1809,7 +1809,7 @@ def _pad3d_common(input, padding, *, is_reflection):
)
if batch_mode:
-return input.new_empty((nbatch, nplane, output_d, output_h, output_w))  # type: ignore[possibly-undefined]
+return input.new_empty((nbatch, nplane, output_d, output_h, output_w))
else:
return input.new_empty((nplane, output_d, output_h, output_w))

View File

@@ -246,10 +246,10 @@ def TensorMeta(
assert dtype is not None
assert device is not None
-shape = inferred_shape if shape is None else tuple(shape)  # type: ignore[possibly-undefined]
-strides = inferred_strides if strides is None else tuple(strides)  # type: ignore[possibly-undefined]
-dtype = inferred_dtype if dtype is None else dtype  # type: ignore[possibly-undefined]
-device = inferred_device if device is None else device  # type: ignore[possibly-undefined]
+shape = inferred_shape if shape is None else tuple(shape)
+strides = inferred_strides if strides is None else tuple(strides)
+dtype = inferred_dtype if dtype is None else dtype
+device = inferred_device if device is None else device
if isinstance(device, str):
device = torch.device(device)

View File

@@ -4875,16 +4875,16 @@ def arange(
# other integral dtypes we don't. Weird... but needed to match ATen shapes.
if dtype == torch.int64:
# Uses floordiv to avoid ceil in inductor.
-sgn = bool(xstep > 0) - bool(xstep < 0)  # type: ignore[possibly-undefined]
-length = (xend - xstart + xstep - sgn) // xstep  # type: ignore[possibly-undefined]
+sgn = bool(xstep > 0) - bool(xstep < 0)
+length = (xend - xstart + xstep - sgn) // xstep
else:
length = math.ceil((end - start) / step)
if is_integer:
return prims.iota(
length,
-start=xstart,  # type: ignore[possibly-undefined]
-step=xstep,  # type: ignore[possibly-undefined]
+start=xstart,
+step=xstep,
dtype=dtype,
device=device,
requires_grad=requires_grad,

View File

@@ -312,7 +312,7 @@ def _canonicalize_fft_shape_and_dim_args(
# Translate any -1 values in shape to the default length
ret_shape = tuple(
-s if s != -1 else input_sizes[d] for (s, d) in zip(shape, ret_dims)  # type: ignore[possibly-undefined]
+s if s != -1 else input_sizes[d] for (s, d) in zip(shape, ret_dims)
)
elif dim is None:
# No shape, no dim
@@ -320,12 +320,12 @@ def _canonicalize_fft_shape_and_dim_args(
ret_shape = tuple(input_sizes)
else:
# No shape, has dim
-ret_shape = tuple(input_sizes[d] for d in ret_dims)  # type: ignore[possibly-undefined]
+ret_shape = tuple(input_sizes[d] for d in ret_dims)
for n in ret_shape:
torch._check(n > 0, lambda: f"Invalid number of data points ({n}) specified")
-return _ShapeAndDims(shape=ret_shape, dims=ret_dims)  # type: ignore[possibly-undefined]
+return _ShapeAndDims(shape=ret_shape, dims=ret_dims)
def _prod(xs: Iterable[int]) -> int:

View File

@@ -610,7 +610,7 @@ def _str_intern(inp, *, tensor_contents=None):
# no-grad mode. See: https://github.com/pytorch/pytorch/issues/99968
grad_fn_name = "Invalid"
-if grad_fn_name is None and grad_fn is not None:  # type: ignore[possibly-undefined]
+if grad_fn_name is None and grad_fn is not None:
grad_fn_name = type(grad_fn).__name__
if grad_fn_name == "CppFunction":
grad_fn_name = grad_fn.name().rsplit("::", 1)[-1]
@@ -627,7 +627,7 @@ def _str_intern(inp, *, tensor_contents=None):
suffixes.append(f"tangent={tangent}")
string_repr = _add_suffixes(
-prefix + tensor_str, suffixes, indent, force_newline=self.is_sparse  # type: ignore[possibly-undefined]
+prefix + tensor_str, suffixes, indent, force_newline=self.is_sparse
)
# Check if this instance is flagged as a parameter and change the repr accordingly.

View File

@@ -188,7 +188,7 @@ class _ConvBnNd(nn.modules.conv._ConvNd, nni._FusedModule):
if self.bn.training:
avg_dims = [0] + list(range(2, len(self.weight.shape)))
-batch_mean = conv_out.mean(avg_dims)  # type: ignore[possibly-undefined]
+batch_mean = conv_out.mean(avg_dims)
batch_var = torch.square(conv_out - batch_mean.reshape(bias_shape)).mean(
avg_dims
)

View File

@@ -231,7 +231,7 @@ class MultiheadAttention(torch.ao.nn.quantizable.MultiheadAttention):
if converted.bias_v is not None:
bias_v = converted._parameters.pop('bias_v')
-sc, zp = torch._choose_qparams_per_tensor(bias_k,  # type: ignore[possibly-undefined]
+sc, zp = torch._choose_qparams_per_tensor(bias_k,
reduce_range=False)
bias_v = torch.quantize_per_tensor(bias_v, sc, zp, torch.quint8)
setattr(converted, 'bias_v', bias_v)  # noqa: B010

View File

@@ -24,7 +24,7 @@ class LinearPackedParams(torch.nn.Module):
wq = torch._empty_affine_quantized([1, 1], scale=1.0, zero_point=0, dtype=torch.qint8)
elif self.dtype == torch.float16:
wq = torch.zeros([1, 1], dtype=torch.float)
-self.set_weight_bias(wq, None)  # type: ignore[possibly-undefined]
+self.set_weight_bias(wq, None)
@torch.jit.export
def set_weight_bias(self, weight: torch.Tensor, bias: Optional[torch.Tensor]) -> None:

View File

@@ -435,7 +435,7 @@ class LSTM(RNNBase):
hx = (h_zeros, c_zeros)
else:
if batch_sizes is None:  # If not PackedSequence input.
-if is_batched:  # type: ignore[possibly-undefined]
+if is_batched:
if (hx[0].dim() != 3 or hx[1].dim() != 3):
msg = ("For batched 3-D input, hx and cx should "
f"also be 3-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors")
@@ -465,8 +465,8 @@ class LSTM(RNNBase):
output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
return output_packed, self.permute_hidden(hidden, unsorted_indices)
else:
-if not is_batched:  # type: ignore[possibly-undefined]
-output = output.squeeze(batch_dim)  # type: ignore[possibly-undefined]
+if not is_batched:
+output = output.squeeze(batch_dim)
hidden = (hidden[0].squeeze(1), hidden[1].squeeze(1))
return output, self.permute_hidden(hidden, unsorted_indices)
@@ -589,8 +589,8 @@ class GRU(RNNBase):
output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
return output_packed, self.permute_hidden(hidden, unsorted_indices)
else:
-if not is_batched:  # type: ignore[possibly-undefined]
-output = output.squeeze(batch_dim)  # type: ignore[possibly-undefined]
+if not is_batched:
+output = output.squeeze(batch_dim)
hidden = hidden.squeeze(1)
return output, self.permute_hidden(hidden, unsorted_indices)

View File

@@ -759,7 +759,7 @@ def create_a_shadows_b(
continue
fqn_base_a = _maybe_get_fqn(subgraph_a.base_op_node, gm_a)
-fqn_base_b = _maybe_get_fqn(subgraph_b.base_op_node, gm_b)  # type: ignore[possibly-undefined]
+fqn_base_b = _maybe_get_fqn(subgraph_b.base_op_node, gm_b)
if node_b_is_start_node:
@@ -817,7 +817,7 @@ def create_a_shadows_b(
# cast dtype from the dtype of node_c's input to the dtype of
# node_a's input (dequant, etc)
# prev_node_c = node_c.args[0]
-prev_node_c = get_normalized_nth_input(node_c, gm_b, 0)  # type: ignore[possibly-undefined]
+prev_node_c = get_normalized_nth_input(node_c, gm_b, 0)
if should_log_inputs:
# skip the input logger when inserting a dtype cast
if isinstance(prev_node_c, Node):
@@ -901,7 +901,7 @@ def create_a_shadows_b(
# input_logger = env_c[dtype_cast_node.name]
# Find the first node in the subgraph
cur_node = node_a_shadows_c
-while get_normalized_nth_input(cur_node, gm_b, 0) != input_logger:  # type: ignore[possibly-undefined]
+while get_normalized_nth_input(cur_node, gm_b, 0) != input_logger:
cur_node = get_normalized_nth_input(cur_node, gm_b, 0)  # type: ignore[assignment]
if isinstance(input_logger, Node):
input_logger_mod = getattr(gm_b, input_logger.name)

View File

@@ -92,7 +92,7 @@ class OutputProp:
elif node.op == 'call_module':
result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))
-if isinstance(result, torch.Tensor):  # type: ignore[possibly-undefined]
+if isinstance(result, torch.Tensor):
node.traced_result = result
env[node.name] = result
@@ -375,7 +375,7 @@ def create_submodule_from_subgraph(
# TODO(future PR): this is ignoring kwargs, will need to support kwargs
# for any fusion pattern which has them for a node that is not the
# first node.
-cur_args_copy = [cur_node_copy]  # type: ignore[has-type, possibly-undefined]  # noqa: F821
+cur_args_copy = [cur_node_copy]  # type: ignore[has-type]  # noqa: F821
if len(cur_node_orig.args) > 1:
for arg in cur_node_orig.args[1:]:
@@ -399,15 +399,15 @@ def create_submodule_from_subgraph(
mod_name = f"mod_{cur_name_idx}"
setattr(gm, mod_name, orig_mod_copy)
cur_name_idx += 1
-cur_node_copy = g.call_module(mod_name, cur_args_copy, cur_kwargs_copy)  # type: ignore[possibly-undefined]
+cur_node_copy = g.call_module(mod_name, cur_args_copy, cur_kwargs_copy)
elif cur_node_orig.op == 'call_function':
cur_node_copy = g.call_function(
-cur_node_orig.target, cur_args_copy, cur_kwargs_copy)  # type: ignore[possibly-undefined]
+cur_node_orig.target, cur_args_copy, cur_kwargs_copy)
elif cur_node_orig.op == 'call_method':
cur_node_copy = g.call_method(
-cur_node_orig.target, cur_args_copy, cur_kwargs_copy)  # type: ignore[possibly-undefined]
+cur_node_orig.target, cur_args_copy, cur_kwargs_copy)
else:
raise AssertionError(f'{cur_node_orig.op} not supported yet')

View File

@@ -402,7 +402,7 @@ class ActivationSparsifier:
hook = layer.register_forward_pre_hook(self._sparsify_hook(name))
config['layer'] = layer
-config['hook'] = hook  # type: ignore[possibly-undefined]
+config['hook'] = hook
def __repr__(self):
format_string = self.__class__.__name__ + ' ('

View File

@@ -117,7 +117,7 @@ def _prune_linear_helper(linear: nn.Linear) -> Tensor:
with torch.no_grad():
parametrize.remove_parametrizations(linear, "weight", leave_parametrized=True)
-linear.weight = nn.Parameter(linear.weight[mask])  # type: ignore[possibly-undefined]
+linear.weight = nn.Parameter(linear.weight[mask])
linear.out_features = linear.weight.shape[0]
_remove_bias_handles(linear)
@@ -175,7 +175,7 @@ def _prune_conv2d_helper(conv2d: nn.Conv2d) -> Tensor:
with torch.no_grad():
parametrize.remove_parametrizations(conv2d, "weight", leave_parametrized=True)
-conv2d.weight = nn.Parameter(conv2d.weight[mask])  # type: ignore[possibly-undefined]
+conv2d.weight = nn.Parameter(conv2d.weight[mask])
conv2d.out_channels = conv2d.weight.shape[0]
_remove_bias_handles(conv2d)
@@ -197,7 +197,7 @@ def prune_conv2d_padded(conv2d_1: nn.Conv2d) -> None:
conv2d_1.bias is not None
):  # conv2d_1 has original bias and bias propagated from previous layer
new_bias = torch.zeros(conv2d_1.bias.shape)
-new_bias[mask] = conv2d_1.bias[mask]  # type: ignore[possibly-undefined]
+new_bias[mask] = conv2d_1.bias[mask]
# adjusted bias that to keep in conv2d_1
new_bias[~mask] = cast(Tensor, conv2d_1._bias)[~mask]
# pruned biases that are kept instead of propagated
@@ -209,7 +209,7 @@ def prune_conv2d_padded(conv2d_1: nn.Conv2d) -> None:
if (
conv2d_1.bias is not None
):  # conv2d_1 has bias propagated from previous layer
-conv2d_1.bias.data[~mask] = 0  # type: ignore[possibly-undefined]
+conv2d_1.bias.data[~mask] = 0
if hasattr(conv2d_1, "_bias"):
delattr(conv2d_1, "_bias")

View File

@@ -835,7 +835,7 @@ def _maybe_insert_input_observer_for_arg_or_kwarg(
maybe_obs_mod = named_modules[maybe_obs_node.target]  # type: ignore[index]
if (
type(maybe_obs_mod) == type(arg_as_input_act_obs_or_fq) and
-maybe_obs_mod.dtype == arg_as_input_target_dtype  # type: ignore[possibly-undefined]
+maybe_obs_mod.dtype == arg_as_input_target_dtype
):
arg_as_input_act_obs_or_fq = maybe_obs_mod  # type: ignore[assignment]
existing_obs_node = maybe_obs_node

View File

@@ -516,7 +516,7 @@ def register_multi_grad_hook(
if tensor.requires_grad
)
-return Handle(handles)  # type: ignore[possibly-undefined]
+return Handle(handles)
# NOTE [Allow mutation on tensors saved for backward]
@@ -746,4 +746,4 @@ def _engine_run_backward(t_outputs, *args, **kwargs):
)  # Calls into the C++ engine to run the backward pass
finally:
if attach_logging_hooks:
-unregister_hooks()  # type: ignore[possibly-undefined]
+unregister_hooks()

View File

@@ -1148,7 +1148,7 @@ def _build_table(
if evt.flops <= 0:
row_values.append("--")
else:
-row_values.append(f"{evt.flops * flops_scale:8.3f}")  # type: ignore[possibly-undefined]
+row_values.append(f"{evt.flops * flops_scale:8.3f}")
if has_stack:
src_field = ""
if len(evt.stack) > 0:

View File

@@ -1176,7 +1176,7 @@ class _NnapiSerializer:
shape=change_element(out_oper.shape, dim, out_dim_size)
)
-if in_oper.dim_order == DimOrder.CHANNELS_LAST:  # type: ignore[possibly-undefined]
+if in_oper.dim_order == DimOrder.CHANNELS_LAST:
assert len(out_oper.shape) == 4
nnapi_dim = [0, 3, 1, 2][dim]
else:
@@ -1633,10 +1633,10 @@ class _NnapiSerializer:
size_ctype, size_arg = self.get_constant_value(size_jit)
if node.inputsSize() == 3:
-scale_ctype, scale_arg = self.get_constant_value(scale_jit)  # type: ignore[possibly-undefined]
+scale_ctype, scale_arg = self.get_constant_value(scale_jit)
else:
-scale_h_ctype, scale_h_arg = self.get_constant_value(scale_h_jit)  # type: ignore[possibly-undefined]
-scale_w_ctype, scale_w_arg = self.get_constant_value(scale_w_jit)  # type: ignore[possibly-undefined]
+scale_h_ctype, scale_h_arg = self.get_constant_value(scale_h_jit)
+scale_w_ctype, scale_w_arg = self.get_constant_value(scale_w_jit)
# The only way for the 4-argument overload of upsample_nearest2d to
# have been added to the graph without error is if the scale_h and

View File

@@ -325,7 +325,7 @@ def make_graphed_callables(
only_inputs=True,
allow_unused=allow_unused_input,
)
-del outputs, grad_inputs  # type: ignore[possibly-undefined]
+del outputs, grad_inputs
torch.cuda.synchronize()
# All captures here share a mempool. To avoid replays corrupting each other's memory,

View File

@@ -206,4 +206,4 @@ def get_chunk_sharding_params(sharding_dim_size, world_size, spec, rank):
start_pos = current_offsets
break
current_offsets += chunk_size
-return start_pos, chunk_size  # type: ignore[possibly-undefined]
+return start_pos, chunk_size

View File

@@ -395,7 +395,7 @@ def _handle_row_wise_sharding(
result = torch.nn.functional.embedding_bag(
lookup_input,
torch.cat([local_shard, padding_row]),
-offsets=offsets_list if offsets is not None else offsets,  # type: ignore[possibly-undefined]
+offsets=offsets_list if offsets is not None else offsets,
mode=mode if mode != "mean" else "sum",
per_sample_weights=per_sample_weights,
max_norm=max_norm,

View File

@@ -541,7 +541,7 @@ def mark_data_parallel_shardings(
# mark activation as sharded on batch dim
node_sharding = node_strategies[0]
-node.meta["sharding"] = node_sharding  # type: ignore[possibly-undefined]
+node.meta["sharding"] = node_sharding
placeholder_idx += 1
elif node.op == "call_function":

View File

@@ -45,7 +45,7 @@ class GraphModuleTransformation:
"iter_graph_main_gm": iter_gm.main_gm.print_readable(False),
"iter_graph_cleanup_gm": iter_gm.cleanup_gm.print_readable(False),
},
-graph_folder,  # type: ignore[possibly-undefined]
+graph_folder,
)
return iter_gm

View File

@@ -353,7 +353,7 @@ def _scatter_wait_result(
gm.graph.node_replace_all_uses_with(orig_wait, wait_output_node)
if last_split_reshape_node == split_node:
-last_split_reshape_node = wait_output_node  # type: ignore[possibly-undefined]
+last_split_reshape_node = wait_output_node
need_sort_nodes = sorted(need_sort_nodes, key=lambda node: node_indices[node])
gm.graph.move_after(need_sort_nodes, last_split_reshape_node)

View File

@@ -561,7 +561,7 @@ class IterGraph(fx.Graph):
delete_user_cb,
propagate_meta=propagate_meta,
)
-return ret  # type: ignore[possibly-undefined]
+return ret
def node_add_user(self, node: fx.Node, user: Any) -> None:
for graph in self._all_graphs:
@@ -607,8 +607,8 @@ class IterGraph(fx.Graph):
"_foreach_add_",
):
step_node = node
-self.node_add_user(optim_node, output_node)  # type: ignore[possibly-undefined]
-self.node_add_user(step_node, optim_node)  # type: ignore[possibly-undefined]
+self.node_add_user(optim_node, output_node)
+self.node_add_user(step_node, optim_node)
def defunctionalize_optim(self) -> None:
# TODO: remove this API after DCE is not used with IterGraph
@@ -624,8 +624,8 @@ class IterGraph(fx.Graph):
"_foreach_add_",
):
step_node = node
-optim_node.users.pop(output_node, None)  # type: ignore[possibly-undefined]
-step_node.users.pop(optim_node, None)  # type: ignore[possibly-undefined]
+optim_node.users.pop(output_node, None)
+step_node.users.pop(optim_node, None)
def freeze_cross_iter_movement(self) -> None:
self._freeze_cross_iter_movement = True

View File

@@ -199,7 +199,7 @@ class OpDispatcher:
if output_sharding.output_spec is None:
if op_call == aten.equal.default:
obj_list = [None for _ in range(dist.get_world_size())]
-dist.all_gather_object(obj_list, local_results)  # type: ignore[possibly-undefined]
+dist.all_gather_object(obj_list, local_results)
obj_list = list(filter(lambda x: x is not None, obj_list))
# perform reduce on the collection with AND op
local_results = functools.reduce(operator.and_, obj_list, True)
@@ -229,7 +229,7 @@ class OpDispatcher:
assert len(out_dts) >= 1, "out variant should have at least one out arg"
return tuple(out_dts) if len(out_dts) > 1 else out_dts[0]
else:
-return self.wrap(local_results, output_sharding.output_spec)  # type: ignore[possibly-undefined]
+return self.wrap(local_results, output_sharding.output_spec)
@staticmethod
def redistribute_local_args(

View File

@@ -201,7 +201,7 @@ class Shard(Placement):
)
if is_padded:
-output = self._unpad_tensor(output, pad_sizes[my_coordinate[mesh_dim]])  # type: ignore[possibly-undefined]
+output = self._unpad_tensor(output, pad_sizes[my_coordinate[mesh_dim]])
return output
def _to_replicate_tensor(
@@ -236,7 +236,7 @@ class Shard(Placement):
group=(mesh, mesh_dim),
)
if is_padded:
-unpad_size = full_chunk_size * num_chunks - logical_dim_size  # type: ignore[possibly-undefined]
+unpad_size = full_chunk_size * num_chunks - logical_dim_size
result = self._unpad_tensor(result, unpad_size)
return result

View File

@@ -69,7 +69,7 @@ class HybridModel(torch.nn.Module):
# Make sure combined PS dimension is always bigger or equal than the FC input
assert NUM_PS * EMBEDDING_DIM >= 512
dim_normalizer = int(NUM_PS * EMBEDDING_DIM / 512)
-emb_lookups_reshaped = emb_lookups_cat.reshape(  # type: ignore[possibly-undefined]
+emb_lookups_reshaped = emb_lookups_cat.reshape(
[emb_lookups_cat.shape[0] * dim_normalizer, 512]
)
@@ -195,7 +195,7 @@ def _run_trainer(emb_rref_list, rank):
# Throw away warm-up measurements
measurements = measurements[WARMUP_CYCLES:]
-return rank, measurements, batch_size  # type: ignore[possibly-undefined]
+return rank, measurements, batch_size
def run_worker(rank, world_size):

View File

@@ -85,7 +85,7 @@ else:
if cur_rank in mesh_1d:
res_sub_mesh = sub_mesh
-res_sub_mesh._dim_group_infos = [device_mesh._dim_group_infos[mesh_dim]]  # type: ignore[possibly-undefined]
+res_sub_mesh._dim_group_infos = [device_mesh._dim_group_infos[mesh_dim]]
# Assign the current DeviceMesh as the parent of the child DeviceMesh.
self.child_to_parent_mapping[res_sub_mesh] = device_mesh
return res_sub_mesh

View File

@@ -1945,9 +1945,9 @@ def _coalescing_manager(
work = group._end_coalescing(device)
if async_ops:
-cm.append(work)  # type: ignore[possibly-undefined]
+cm.append(work)
else:
-work.wait()  # type: ignore[possibly-undefined]
+work.wait()
def batch_isend_irecv(p2p_op_list):
@@ -2460,7 +2460,7 @@ def gather_object(obj, object_gather_list=None, dst=0, group=None):
# All ranks call gather with equal-sized tensors.
gather(
input_tensor,
-gather_list=output_tensors if my_rank == dst else None,  # type: ignore[possibly-undefined]
+gather_list=output_tensors if my_rank == dst else None,
dst=dst,
group=group,
)
@@ -2560,7 +2560,7 @@ def broadcast_object_list(object_list, src=0, group=None, device=None):
# Note: torch.cat will do an extra memory copy to the current device, if the tensor_list
# has only one element, we can skip the copy.
if my_rank == src:
-if len(tensor_list) == 1:  # type: ignore[possibly-undefined]
+if len(tensor_list) == 1:
object_tensor = tensor_list[0]
else:
object_tensor = torch.cat(tensor_list)
@@ -2663,8 +2663,8 @@ def scatter_object_list(
# Src rank broadcasts the maximum tensor size. This is because all ranks are
# expected to call into scatter() with equal-sized tensors.
if my_rank == src:
-max_tensor_size = max(tensor_sizes)  # type: ignore[possibly-undefined]
-for tensor in tensor_list:  # type: ignore[possibly-undefined]
+max_tensor_size = max(tensor_sizes)
+for tensor in tensor_list:
tensor.resize_(max_tensor_size)
else:
max_tensor_size = torch.tensor([0], dtype=torch.long, device=pg_device)
@@ -2674,7 +2674,7 @@ def scatter_object_list(
output_tensor = torch.empty(max_tensor_size.item(), dtype=torch.uint8, device=pg_device)
scatter(
output_tensor,
-scatter_list=None if my_rank != src else tensor_list,  # type: ignore[possibly-undefined]
+scatter_list=None if my_rank != src else tensor_list,
src=src,
group=group,
)
@@ -2683,7 +2683,7 @@ def scatter_object_list(
obj_tensor_size = torch.tensor([0], dtype=torch.long, device=pg_device)
scatter(
obj_tensor_size,
-scatter_list=None if my_rank != src else tensor_sizes,  # type: ignore[possibly-undefined]
+scatter_list=None if my_rank != src else tensor_sizes,
src=src,
group=group,
)

View File

@@ -51,7 +51,7 @@ class Event:
return data
if isinstance(data, str):
data_dict = json.loads(data)
-data_dict["source"] = EventSource[data_dict["source"]]  # type: ignore[possibly-undefined]
+data_dict["source"] = EventSource[data_dict["source"]]
return Event(**data_dict)
def serialize(self) -> str:
@@ -105,7 +105,7 @@ class RdzvEvent:
return data
if isinstance(data, str):
data_dict = json.loads(data)
-data_dict["node_state"] = NodeState[data_dict["node_state"]]  # type: ignore[possibly-undefined]
+data_dict["node_state"] = NodeState[data_dict["node_state"]]
return RdzvEvent(**data_dict)
def serialize(self) -> str:

View File

@@ -126,7 +126,7 @@ def prof(fn=None, group: str = "torchelastic"):
put_metric(f"{key}.failure", 1, group)
raise
finally:
-put_metric(f"{key}.duration.ms", get_elapsed_time_ms(start), group)  # type: ignore[possibly-undefined]
+put_metric(f"{key}.duration.ms", get_elapsed_time_ms(start), group)
return result
return wrapper
@@ -164,7 +164,7 @@ def profile(group=None):
publish_metric(
group,
f"{func.__name__}.duration.ms",
-get_elapsed_time_ms(start_time),  # type: ignore[possibly-undefined]
+get_elapsed_time_ms(start_time),
)
return result

View File

@@ -176,7 +176,7 @@ def _create_tcp_store(params: RendezvousParameters) -> TCPStore:
"The connection to the C10d store has failed. See inner exception for details."
) from exc
-return store  # type: ignore[possibly-undefined]
+return store
def _create_file_store(params: RendezvousParameters) -> FileStore:

View File

@@ -57,7 +57,7 @@ def find_free_port():
s.listen(0)
return s
except OSError as e:
-s.close()  # type: ignore[possibly-undefined]
+s.close()
print(f"Socket creation attempt failed: {e}")
raise RuntimeError("Failed to create a socket")

View File

@@ -1767,8 +1767,8 @@ class FlatParamHandle:
)
flat_param.data = flat_param._local_shard  # type: ignore[attr-defined]
if self._use_orig_params:
-if skip_use_sharded_views:  # type: ignore[possibly-undefined]
-self._unsharded_flat_param_for_skipped_views = unsharded_flat_param  # type: ignore[possibly-undefined]
+if skip_use_sharded_views:
+self._unsharded_flat_param_for_skipped_views = unsharded_flat_param
else:
self._use_sharded_views()
# For the post-forward reshard, we may try to use sharded gradient
@@ -1776,7 +1776,7 @@ class FlatParamHandle:
# in `no_sync()`), but for the post-backward reshard, we delay the
# call to after the reduce-scatter.
if (
-in_forward  # type: ignore[possibly-undefined]
+in_forward
# Skip using gradient views if skipped using sharded views
# since exposing unsharded parameters with sharded gradients
# may be confusing to the user

View File

@@ -885,7 +885,7 @@ def _materialize_meta_module(
warnings.warn(
"Unable to call `reset_parameters()` for module on meta "
f"device with error {str(e)}. Please ensure that your module of"
-f"type {type(module)} implements a `reset_parameters()` method."  # type: ignore[possibly-undefined]
+f"type {type(module)} implements a `reset_parameters()` method."
)
raise e
@@ -994,7 +994,7 @@ def _move_states_to_device(
param.grad.data = param.grad.to(device_from_device_id)
for buffer in buffers:
buffer.data = buffer.to(device_from_device_id)
-elif current_device == cpu_device:  # type: ignore[possibly-undefined]
+elif current_device == cpu_device:
_warn_cpu_init()

View File

@@ -1419,7 +1419,7 @@ def _convert_all_state_info(
)
gathered_state[name] = scalar_tensor_value
-return dtype, state_buffers  # type: ignore[possibly-undefined]
+return dtype, state_buffers
def _unflatten_orig_param_states(

View File

@@ -171,7 +171,7 @@ def _create_1d_device_mesh(device_mesh: DeviceMesh, tp_mesh_dim: int = 0) -> Dev
if cur_rank in mesh_1d:
res_sub_mesh = sub_mesh
-res_sub_mesh._dim_group_infos = [device_mesh._dim_group_infos[tp_mesh_dim]]  # type: ignore[possibly-undefined]
+res_sub_mesh._dim_group_infos = [device_mesh._dim_group_infos[tp_mesh_dim]]
return res_sub_mesh

@@ -253,7 +253,7 @@ class ExportedProgram:
                 user_args, self.call_spec.in_spec, exact_structural_match=True
             )  # type: ignore[assignment]
         except Exception:
-            _, received_spec = pytree.tree_flatten(user_args)  # type: ignore[possibly-undefined]
+            _, received_spec = pytree.tree_flatten(user_args)
             raise TypeError(  # noqa: TRY200
                 "Trying to flatten user inputs with exported input tree spec: \n"
                 f"{self.call_spec.in_spec}\n"

@@ -998,7 +998,7 @@ class Partitioner:
                 if cost < min_cost:
                     node_pair = [node, n1]
                     min_cost = cost
-            return cost, node_pair  # type: ignore[possibly-undefined]
+            return cost, node_pair
         # First use size_base_partition
         self.size_based_partition()

@@ -263,7 +263,7 @@ def split_const_subgraphs(
     setattr(
         split,
         fx_const_folded_attrs_name,
-        torch.nn.ParameterList() if multiple_outputs else torch.nn.Parameter(),  # type: ignore[possibly-undefined]
+        torch.nn.ParameterList() if multiple_outputs else torch.nn.Parameter(),
     )
     for node in split.graph.nodes:
         if node.op == "call_module" and node.target == const_mod_name:

@@ -694,7 +694,7 @@ for name in math_op_names:
     fn.__qualname__ = fn.__name__ = priv_sympy_name
     setattr(current_module, priv_sympy_name, fn)

-del fn, name, priv_sympy_name  # type: ignore[possibly-undefined]
+del fn, name, priv_sympy_name


 def _sympy_abs(a):
@@ -753,7 +753,7 @@ for name in math_op_names:
     sym_name = f"sym_{name}"
     magic_methods[sym_name] = getattr(current_module, f"_sympy_{name}")

-del name, sym_name, math_op_names, current_module  # type: ignore[possibly-undefined]
+del name, sym_name, math_op_names, current_module


 def sympy_is_contiguous(sizes, strides):

@@ -68,8 +68,8 @@ def consistent(a, b):
             p1 += 1

     # We only need to check for variadic ends
     # Variadic types are guaranteed to be the last element
-    return (isvariadic(cur_a) and p2 == len(b) or  # type: ignore[possibly-undefined]
-            isvariadic(cur_b) and p1 == len(a))  # type: ignore[possibly-undefined]
+    return (isvariadic(cur_a) and p2 == len(b) or
+            isvariadic(cur_b) and p1 == len(a))


 def ambiguous(a, b):

@@ -371,11 +371,11 @@ class _MinimizerBase:
         # Compare results
         names: Names = output_names
         if output_names is None:
-            names = [str(v) for v in result_key]  # type: ignore[possibly-undefined]
+            names = [str(v) for v in result_key]

         numeric_result, bool_result = self.compare_fn(a_result, b_result, names)
-        self.results[result_key] = numeric_result  # type: ignore[possibly-undefined]
+        self.results[result_key] = numeric_result
         report.append(f"Numerical accuracy = {numeric_result}")
         if not bool_result:
             report.append(f"Result mismatch for {result_key}")

@@ -575,7 +575,7 @@ class _SplitterBase:
             else:
                 total_output_bytes += get_size_of_node(submod, node)[0]

-        map_arg(output_node.args, get_bytes)  # type: ignore[possibly-undefined]
+        map_arg(output_node.args, get_bytes)
         qps = self.PCIe_BW / max(total_input_bytes, total_output_bytes)
         reports += f"Total input size in bytes is {total_input_bytes}, total output size in bytes is {total_output_bytes},"
         reports += f" theoretical max qps (bounds by PCIe bandwidth) for this submodule is {qps}.\n"

@@ -305,7 +305,7 @@ def _replace_pattern(
                     first_user_node = n
                     break

-            with original_graph.inserting_before(first_user_node):  # type: ignore[possibly-undefined]
+            with original_graph.inserting_before(first_user_node):
                 copied_returning_nodes = original_graph.graph_copy(replacement_graph, val_map)

             if isinstance(copied_returning_nodes, Node):

@@ -650,14 +650,14 @@ def download_url_to_file(url: str, dst: str, hash_prefix: Optional[str] = None,
             buffer = u.read(READ_DATA_CHUNK)
             if len(buffer) == 0:
                 break
-            f.write(buffer)  # type: ignore[possibly-undefined]
+            f.write(buffer)
             if hash_prefix is not None:
-                sha256.update(buffer)  # type: ignore[possibly-undefined]
+                sha256.update(buffer)
             pbar.update(len(buffer))

         f.close()
         if hash_prefix is not None:
-            digest = sha256.hexdigest()  # type: ignore[possibly-undefined]
+            digest = sha256.hexdigest()
             if digest[:len(hash_prefix)] != hash_prefix:
                 raise RuntimeError(f'invalid hash value (expected "{hash_prefix}", got "{digest}")')
         shutil.move(f.name, dst)

@@ -70,8 +70,8 @@ def fuser(name):
         yield
     finally:
         if name in ["fuser1", "fuser3"]:  # NNC or oneDNN Graph
-            torch._C._jit_set_profiling_executor(old_profiling_executor)  # type: ignore[possibly-undefined]
-            torch._C._get_graph_executor_optimize(old_profiling_mode)  # type: ignore[possibly-undefined]
+            torch._C._jit_set_profiling_executor(old_profiling_executor)
+            torch._C._get_graph_executor_optimize(old_profiling_mode)
         # recover the previous values
         torch._C._jit_override_can_fuse_on_cpu(old_cpu_fuse)
         torch._C._jit_override_can_fuse_on_gpu(old_gpu_fuse)

@@ -254,7 +254,7 @@ def verify(model, args, loss_fn=torch.sum, devices=None):
         if assert_compiled:
             hits = compiled_fn.hits
         out = model(*args)
-        if assert_compiled and compiled_fn.hits == hits:  # type: ignore[possibly-undefined]
+        if assert_compiled and compiled_fn.hits == hits:
             raise RuntimeError("failed to use the compiled function")
         if not isinstance(out, tuple):
             out = (out,)
@@ -280,7 +280,7 @@ def verify(model, args, loss_fn=torch.sum, devices=None):
     assert model.has_trace_for(*args)

     if is_module:
-        model.load_state_dict(saved_state)  # type: ignore[possibly-undefined]
+        model.load_state_dict(saved_state)
     compiled_outs, compiled_grads = run_fwd_bwd(args, assert_compiled=True)

     _verify_equal(uncompiled_outs, compiled_outs)

@@ -1627,7 +1627,7 @@ def _std_var(
         total = sum(x * x.conj(), dim, keepdim=keepdim, dtype=compute_dtype)
     else:
         total = sum(
-            x * x.conj(), dim, keepdim=keepdim, dtype=compute_dtype, mask=inmask  # type: ignore[possibly-undefined]
+            x * x.conj(), dim, keepdim=keepdim, dtype=compute_dtype, mask=inmask
         )
     if not keepdim:
         count = count.reshape(total.shape)

@@ -781,8 +781,8 @@ class SyncBatchNorm(_BatchNorm):
                 running_var,
                 self.eps,
                 exponential_average_factor,
-                process_group,  # type: ignore[possibly-undefined]
-                world_size,  # type: ignore[possibly-undefined]
+                process_group,
+                world_size,
             )

     @classmethod

@@ -1604,9 +1604,9 @@ class Module:
             # For now only forward hooks have the always_call option but perhaps
             # this functionality should be added to full backward hooks as well.
             for hook_id, hook in _global_forward_hooks.items():
-                if hook_id in _global_forward_hooks_always_called and hook_id not in called_always_called_hooks:  # type: ignore[possibly-undefined]
+                if hook_id in _global_forward_hooks_always_called and hook_id not in called_always_called_hooks:
                     try:
-                        hook_result = hook(self, args, result)  # type: ignore[possibly-undefined]
+                        hook_result = hook(self, args, result)
                         if hook_result is not None:
                             result = hook_result
                     except Exception as e:
@@ -1615,12 +1615,12 @@ class Module:
                         continue

             for hook_id, hook in self._forward_hooks.items():
-                if hook_id in self._forward_hooks_always_called and hook_id not in called_always_called_hooks:  # type: ignore[possibly-undefined]
+                if hook_id in self._forward_hooks_always_called and hook_id not in called_always_called_hooks:
                     try:
                         if hook_id in self._forward_hooks_with_kwargs:
-                            hook_result = hook(self, args, kwargs, result)  # type: ignore[possibly-undefined]
+                            hook_result = hook(self, args, kwargs, result)
                         else:
-                            hook_result = hook(self, args, result)  # type: ignore[possibly-undefined]
+                            hook_result = hook(self, args, result)
                         if hook_result is not None:
                             result = hook_result
                     except Exception as e:

@@ -575,8 +575,8 @@ class RNN(RNNBase):
             output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
             return output_packed, self.permute_hidden(hidden, unsorted_indices)

-        if not is_batched:  # type: ignore[possibly-undefined]
-            output = output.squeeze(batch_dim)  # type: ignore[possibly-undefined]
+        if not is_batched:
+            output = output.squeeze(batch_dim)
             hidden = hidden.squeeze(1)

         return output, self.permute_hidden(hidden, unsorted_indices)
@@ -888,8 +888,8 @@ class LSTM(RNNBase):
             output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
             return output_packed, self.permute_hidden(hidden, unsorted_indices)
         else:
-            if not is_batched:  # type: ignore[possibly-undefined]
-                output = output.squeeze(batch_dim)  # type: ignore[possibly-undefined]
+            if not is_batched:
+                output = output.squeeze(batch_dim)
                 hidden = (hidden[0].squeeze(1), hidden[1].squeeze(1))
             return output, self.permute_hidden(hidden, unsorted_indices)
@@ -1111,8 +1111,8 @@ class GRU(RNNBase):
             output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
             return output_packed, self.permute_hidden(hidden, unsorted_indices)
         else:
-            if not is_batched:  # type: ignore[possibly-undefined]
-                output = output.squeeze(batch_dim)  # type: ignore[possibly-undefined]
+            if not is_batched:
+                output = output.squeeze(batch_dim)
                 hidden = hidden.squeeze(1)
             return output, self.permute_hidden(hidden, unsorted_indices)

@@ -105,7 +105,7 @@ class _Orthogonal(Module):
             Q = self.base @ Q
             if transposed:
                 Q = Q.mT
-        return Q  # type: ignore[possibly-undefined]
+        return Q

     @torch.autograd.no_grad()
     def right_inverse(self, Q: torch.Tensor) -> torch.Tensor:

@@ -293,7 +293,7 @@ def _create_node(
     for _ in range(1, n_outputs):
         node.addOutput()

-    node_ouputs = tuple(node.outputs())  # type: ignore[possibly-undefined]
+    node_ouputs = tuple(node.outputs())
     assert len(node_ouputs) == n_outputs

     aten = domain_op.startswith("aten::")

@@ -1529,7 +1529,7 @@ def softmax(g: jit_utils.GraphContext, input, dim, dtype=None):
         )

     if is_transpose_required:
-        softmax = g.op("Transpose", softmax, perm_i=axes)  # type: ignore[possibly-undefined]
+        softmax = g.op("Transpose", softmax, perm_i=axes)
     return softmax

     # Apply max normalization.
@@ -2467,7 +2467,7 @@ def log_softmax(g: jit_utils.GraphContext, input, dim, dtype=None):
             "Cast", return_op, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type()
         )

     if is_transpose_required:
-        return_op = g.op("Transpose", return_op, perm_i=axes)  # type: ignore[possibly-undefined]
+        return_op = g.op("Transpose", return_op, perm_i=axes)
     return return_op
@@ -2978,7 +2978,7 @@ def native_layer_norm(
     # mean and normalized, so we need to Cast it back
     if is_type_half:
         denominator = g.op(
-            "Cast", denominator, to_i=_type_utils.JitScalarType(input_dtype).onnx_type()  # type: ignore[possibly-undefined]
+            "Cast", denominator, to_i=_type_utils.JitScalarType(input_dtype).onnx_type()
         )
         rdenominator = g.op("Reciprocal", denominator)
     else:
@@ -4754,7 +4754,7 @@ def _generic_rnn(
             reform_weights(g, w, hidden_size, reform_permutation) for w in weights
         )
         return tuple(
-            symbolic_helper._unsqueeze_helper(g, x, [0]) for x in (weight_ih, weight_hh)  # type: ignore[possibly-undefined]
+            symbolic_helper._unsqueeze_helper(g, x, [0]) for x in (weight_ih, weight_hh)
         )

     @_beartype.beartype
@@ -4766,10 +4766,10 @@ def _generic_rnn(
         weight_ih, weight_hh, bias_ih, bias_hh = (
             reform_weights(g, w, hidden_size, reform_permutation) for w in weights
        )
-        bias_concat = g.op("Concat", bias_ih, bias_hh, axis_i=0)  # type: ignore[possibly-undefined]
+        bias_concat = g.op("Concat", bias_ih, bias_hh, axis_i=0)
         return tuple(
             symbolic_helper._unsqueeze_helper(g, x, [0])
-            for x in (weight_ih, weight_hh, bias_concat)  # type: ignore[possibly-undefined]
+            for x in (weight_ih, weight_hh, bias_concat)
         )

     @_beartype.beartype
@@ -4808,16 +4808,16 @@ def _generic_rnn(
         inputs = [prev_output, weight_ih, weight_hh, bias_concat, sequence_lens]

-        inputs.append(retrieve_state(h0, *state_indices))  # type: ignore[possibly-undefined]
+        inputs.append(retrieve_state(h0, *state_indices))
         if variant == "LSTM":
-            inputs.append(retrieve_state(c0, *state_indices))  # type: ignore[possibly-undefined]
+            inputs.append(retrieve_state(c0, *state_indices))

         extra_kwargs = {} if unidirectional else {"direction_s": "bidirectional"}
         if variant == "RNN":
             if bidirectional:
-                activation = [nonlinearity, nonlinearity]  # type: ignore[possibly-undefined]
+                activation = [nonlinearity, nonlinearity]
             else:
-                activation = [nonlinearity]  # type: ignore[possibly-undefined]
+                activation = [nonlinearity]

             prev_output, h_out = g.op(
                 "RNN",
@@ -4859,17 +4859,17 @@ def _generic_rnn(
             else:
                 prev_output = symbolic_helper._squeeze_helper(g, prev_output, [1])

-        h_outs.append(h_out)  # type: ignore[possibly-undefined]
+        h_outs.append(h_out)
         if variant == "LSTM":
-            c_outs.append(c_out)  # type: ignore[possibly-undefined]
+            c_outs.append(c_out)
     if batch_first:
         # seq, batch, num_directions * hidden_size -> batch, seq, num_directions * hidden_size
         prev_output = g.op("Transpose", prev_output, perm_i=[1, 0, 2])
-    h_outs = h_out if num_layers == 1 else g.op("Concat", *h_outs, axis_i=0)  # type: ignore[possibly-undefined]
+    h_outs = h_out if num_layers == 1 else g.op("Concat", *h_outs, axis_i=0)
     if variant == "RNN" or variant == "GRU":
         return prev_output, h_outs
     elif variant == "LSTM":
-        c_outs = c_out if num_layers == 1 else g.op("Concat", *c_outs, axis_i=0)  # type: ignore[possibly-undefined]
+        c_outs = c_out if num_layers == 1 else g.op("Concat", *c_outs, axis_i=0)
         return prev_output, h_outs, c_outs

@@ -199,7 +199,7 @@ class BasicEvaluation:
             while (
                 current_kernel_index < len(cuda_kernel_events)
                 and (cuda_kernel_events[current_kernel_index].start_us()) * 1000
-                <= start_time  # type: ignore[possibly-undefined]
+                <= start_time
             ):
                 current_kernel_index += 1
             current_queue_depth = spawned_kernel_index - current_kernel_index + 1
@@ -207,7 +207,7 @@ class BasicEvaluation:
             if hasattr(event, "start_us"):
                 queue_depth_list.append(
-                    Interval(start_time, end_time, current_queue_depth)  # type: ignore[possibly-undefined]
+                    Interval(start_time, end_time, current_queue_depth)
                 )
             elif hasattr(event, "start_time_ns"):
                 self.metrics[EventKey(event)].queue_depth = current_queue_depth

@@ -44,7 +44,7 @@ def quantized_weight_reorder_for_mixed_dtypes_linear_cutlass(
     else:
         outp = weight

-    ncols, nrows = outp.shape  # type: ignore[possibly-undefined]
+    ncols, nrows = outp.shape
     assert nrows % (32 if dtypeq == torch.quint4x2 else 64) == 0
     assert ncols % 64 == 0

@@ -134,11 +134,11 @@ def sparse_semi_structured_from_dense_cutlass(dense):
         idxs1 = bit2 | (bit3.to(torch.int64) << 1)

         if dense.dtype != torch.float:
-            sparse0 = dense_4.gather(-1, idxs0.unsqueeze(-1))  # type: ignore[possibly-undefined]
+            sparse0 = dense_4.gather(-1, idxs0.unsqueeze(-1))
             sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
             sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
         else:
-            sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(m, k // 2)  # type: ignore[possibly-undefined]
+            sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(m, k // 2)

         meta_4 = idxs0 | (idxs1 << 2)
         meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)
@@ -163,7 +163,7 @@ def sparse_semi_structured_from_dense_cutlass(dense):
     )

     # Reorder meta tensor elements.
-    meta_reordered = meta.new_empty((m * meta_ncols,))  # type: ignore[possibly-undefined]
+    meta_reordered = meta.new_empty((m * meta_ncols,))
     meta_offsets = _calculate_meta_reordering_scatter_offsets(
         m, meta_ncols, meta_dtype, device
     )

@@ -1662,7 +1662,7 @@ if has_triton():
         acc_block = tl.zeros((TILE_M, TILE_N), dtype=dot_out_dtype)

         if is_compressed:
-            A_ptr += r0 * blocks_stride_P  # type: ignore[possibly-undefined]
+            A_ptr += r0 * blocks_stride_P
         for _ in range(nnz):
             q = tl.load(q_ptr)
             B = tl.load(B_ptr + q)
@@ -1889,7 +1889,7 @@ if has_triton():
         # alpha is never 0
         if beta_is_nonzero:
-            output_acc_block = tl.load(input_ptrs).to(acc_dtype)  # type: ignore[possibly-undefined]
+            output_acc_block = tl.load(input_ptrs).to(acc_dtype)
             if not (beta_is_one and alpha_is_one):
                 beta_alpha = beta / alpha
                 output_acc_block *= beta_alpha

@@ -76,7 +76,7 @@ def _output_csv(file, results):
            dim_str = str(dim)
            shape_str = 'x'.join(str(s) for s in shape)

-        print(name, device, measurement.task_spec.num_threads, numel, shape_str, contiguous, dim_str,  # type: ignore[possibly-undefined]
+        print(name, device, measurement.task_spec.num_threads, numel, shape_str, contiguous, dim_str,
              measurement.mean * 1e6, measurement.median * 1e6, measurement.iqr * 1e6,
              sep=',', file=file)

@@ -701,7 +701,7 @@ class _ValgrindWrapper:
                 if fn_match:
                     ir_str, file_function = fn_match.groups()
                     ir = int(ir_str.replace(",", ""))
-                    if ir == program_totals:  # type: ignore[possibly-undefined]
+                    if ir == program_totals:
                         # Callgrind includes some top level red herring symbols when
                         # a program dumps multiple profiles.
                         continue

@@ -1427,7 +1427,7 @@ def _checkpoint_without_reentrant_generator(
     new_frame.forward_completed = True

     if getattr(device_module, "_initialized", False) and \
-       preserve_rng_state and not had_device_in_fwd:  # type: ignore[possibly-undefined]
+       preserve_rng_state and not had_device_in_fwd:
         # Device was not initialized before running the forward, so we didn't
         # stash the device state.
         raise RuntimeError(

@@ -2391,7 +2391,7 @@ def _write_ninja_file(path,
     # 'Blocks' should be separated by newlines, for visual benefit.
     blocks = [config, flags, compile_rule]
     if with_cuda:
-        blocks.append(cuda_compile_rule)  # type: ignore[possibly-undefined]
+        blocks.append(cuda_compile_rule)
     blocks += [devlink_rule, link_rule, build, devlink, link, default]
     content = "\n\n".join("\n".join(b) for b in blocks)
     # Ninja requires a new lines at the end of the .ninja file

@@ -305,7 +305,7 @@ def _worker_loop(dataset_kind, dataset, index_queue, data_queue, done_event,
                 init_exception = None
             else:
                 try:
-                    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
+                    data = fetcher.fetch(index)
                 except Exception as e:
                     if isinstance(e, StopIteration) and dataset_kind == _DatasetKind.Iterable:
                         data = _IterableDatasetStopIteration(worker_id)

@@ -1360,7 +1360,7 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
                 # not found (i.e., didn't break)
                 return

-        self._index_queues[worker_queue_idx].put((self._send_idx, index))  # type: ignore[possibly-undefined]
+        self._index_queues[worker_queue_idx].put((self._send_idx, index))
         self._task_info[self._send_idx] = (worker_queue_idx,)
         self._tasks_outstanding += 1
         self._send_idx += 1

@@ -210,7 +210,7 @@ class _ForkerIterDataPipe(IterDataPipe, _ContainerTemplate):
                         raise BufferError("ForkerIterDataPipe buffer overflow," +
                                           f"buffer size {self.buffer_size} is insufficient.")

-                yield self.copy_fn(return_val)  # type: ignore[possibly-undefined]
+                yield self.copy_fn(return_val)
         finally:
             self._child_stop[instance_id] = True
             # Cleanup _datapipe_iterator for the case that fork exits earlier

@@ -907,7 +907,7 @@ class SummaryWriter:
         else:
             # Handles cnn.CNNModelHelper, model_helper.ModelHelper
             current_graph = model_to_graph_def(model)
-        event = event_pb2.Event(graph_def=current_graph.SerializeToString())  # type: ignore[possibly-undefined]
+        event = event_pb2.Event(graph_def=current_graph.SerializeToString())
         self._get_file_writer().add_event(event)

     @staticmethod

@@ -717,10 +717,10 @@ resize_out(out, sizes, strides, options);
             f"{textwrap.indent(class_ctor_str, indent)}",
             f"{textwrap.indent(self.gen_class_set_output_functions(k, parent_class, generate_super), indent)}",
             " const Tensor& maybe_get_output(int64_t output_idx) override {",
-            f" return {output_value};\n",  # type: ignore[possibly-undefined]  # TODO: audit
+            f" return {output_value};\n",
             " }",
-            f" std::array<{output_type}, {len(f.func.returns)}> outputs_;",  # type: ignore[possibly-undefined]  # TODO: audit
-            f"{textwrap.indent(proxy_field, indent)}",  # type: ignore[possibly-undefined]  # TODO: audit
+            f" std::array<{output_type}, {len(f.func.returns)}> outputs_;",
+            f"{textwrap.indent(proxy_field, indent)}",
             f"{textwrap.indent(guard_field, indent)}",
             "};",
         )
@@ -962,7 +962,7 @@ return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), si
         else:
             refs = ", ".join(a.name for a in f.func.arguments.out)
             ret_expr = f"std::forward_as_tuple({refs})"
-        sig_body.append(f"return {ret_expr};")  # type: ignore[possibly-undefined]  # TODO: audit
+        sig_body.append(f"return {ret_expr};")
         sig_body_str = "\n".join(sig_body)