Enable possibly-undefined error code (#118533)

Fixes https://github.com/pytorch/pytorch/issues/118129

Suppressions automatically added with

```
import re

with open("error_file.txt", "r") as f:
    errors = f.readlines()

error_lines = {}
for error in errors:
    match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
    if match:
        file_path, line_number, error_type = match.groups()
        if file_path not in error_lines:
            error_lines[file_path] = {}
        error_lines[file_path][int(line_number)] = error_type

for file_path, lines in error_lines.items():
    with open(file_path, "r") as f:
        code = f.readlines()
    for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
        code[line_number - 1] = code[line_number - 1].rstrip() + f"  # type: ignore[{error_type}]\n"
    with open(file_path, "w") as f:
        f.writelines(code)
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Co-authored-by: Catherine Lee <csl@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118533
Approved by: https://github.com/Skylion007, https://github.com/zou3519
This commit is contained in:
Catherine Lee 2024-01-30 11:36:08 -08:00 committed by PyTorch MergeBot
parent e332653eb3
commit 4f5785b6b3
94 changed files with 200 additions and 197 deletions

View File

@ -13,6 +13,7 @@ show_column_numbers = True
check_untyped_defs = True
follow_imports = normal
local_partial_types = True
enable_error_code = possibly-undefined
# do not reenable this:
# https://github.com/pytorch/pytorch/pull/60006#issuecomment-866130657

View File

@ -1,3 +1,4 @@
# mypy: disable-error-code="possibly-undefined"
# flake8: noqa
import torch
from torch.testing._internal.common_utils import TEST_NUMPY

View File

@ -1,3 +1,4 @@
# mypy: disable-error-code="possibly-undefined"
# flake8: noqa
import torch
from torch.testing._internal.common_utils import TEST_NUMPY

View File

@ -515,7 +515,7 @@ for name in ("sqrt", "cos", "cosh", "sin", "sinh", "tan", "tanh", "asin", "acos"
sym_sqrt = current_module._sym_sqrt
__all__.append("sym_sqrt")
del fn, name, sym_name, current_module
del fn, name, sym_name, current_module # type: ignore[possibly-undefined]
def sym_ite(b, t, f):

View File

@ -2832,7 +2832,7 @@ def _rnn_helper(
final_hiddens.append(bwd_hidden)
if bidirectional:
input = torch.cat([fwd_inp, bwd_inp], fwd_inp.dim() - 1)
input = torch.cat([fwd_inp, bwd_inp], fwd_inp.dim() - 1) # type: ignore[possibly-undefined]
else:
input = fwd_inp

View File

@ -163,7 +163,7 @@ def preserve_global_state(fn):
random.setstate(py_rng_state)
torch.random.set_rng_state(torch_rng_state)
if torch.cuda.is_available():
torch.cuda.set_rng_state(cuda_rng_state)
torch.cuda.set_rng_state(cuda_rng_state) # type: ignore[possibly-undefined]
torch.fx.graph_module._forward_from_src = prior_fwd_from_src
assert (
guards.check()
@ -568,7 +568,7 @@ def _compile(
code.co_name,
code.co_filename,
code.co_firstlineno,
out_code,
out_code, # type: ignore[possibly-undefined]
)
for hook in _bytecode_hooks.values():

View File

@ -46,7 +46,7 @@ if use_buck:
"//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu",
"//deeplearning/fbgemm/fbgemm_gpu:sparse_ops",
]
cur_target = libfb.py.build_info.BuildInfo.get_build_rule().replace("fbcode:", "//")
cur_target = libfb.py.build_info.BuildInfo.get_build_rule().replace("fbcode:", "//") # type: ignore[possibly-undefined]
extra_imports = "\n".join([f'torch.ops.load_library("{x}")' for x in extra_deps])

View File

@ -1430,7 +1430,7 @@ def export(
example_fake_inputs,
graph_captured_input,
graph_captured_result,
result_traced,
result_traced, # type: ignore[possibly-undefined]
flat_args_dynamic_dims,
)
# Store constraints and inputs as metadata for user passes, e.g. turn constraints to runtime check

View File

@ -1115,7 +1115,7 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
tos = self.pop()
_ = self.pop()
if preserve_tos:
self.push(tos)
self.push(tos) # type: ignore[possibly-undefined]
def FOR_ITER(self, inst):
it = self.pop().realize()

View File

@ -118,7 +118,7 @@ torch._inductor.config.{"cpp" if device == "cpu" else "triton"}.inject_relu_bug_
finally:
log.removeHandler(log_handler)
if cwd is not None:
os.chdir(prev_cwd)
os.chdir(prev_cwd) # type: ignore[possibly-undefined]
# Make sure we don't leave buggy compiled frames lying
# around
torch._dynamo.reset()

View File

@ -773,7 +773,7 @@ def preserve_rng_state():
with torch.utils._python_dispatch._disable_current_modes():
torch.random.set_rng_state(rng_state)
if torch.cuda.is_available():
torch.cuda.set_rng_state(cuda_rng_state)
torch.cuda.set_rng_state(cuda_rng_state) # type: ignore[possibly-undefined]
def is_jit_model(model0):
@ -892,7 +892,7 @@ def timed(model, example_inputs, times=1):
result = model(*example_inputs)
synchronize()
t1 = time.perf_counter()
return result, t1 - t0
return result, t1 - t0 # type: ignore[possibly-undefined]
def check_is_cuda(gm, example_inputs):

View File

@ -199,7 +199,7 @@ def wrap_outputs_maintaining_identity(
result.append(unwrapped_input_to_orig_input[id(output)])
continue
if out_dims_specified:
result.append(wrap_fn(output, flat_out_dims[i])) # type: ignore[index]
result.append(wrap_fn(output, flat_out_dims[i])) # type: ignore[possibly-undefined, index]
else:
result.append(wrap_fn(output))

View File

@ -163,7 +163,7 @@ def parse_ttir(ttir, kwargs):
return None
try:
import lark
import lark # type: ignore[import-not-found]
from lark import Lark, Transformer, v_args
except ModuleNotFoundError:
warnings.warn(

View File

@ -440,25 +440,25 @@ class BenchmarkRequest:
output_tensor = self.output_tensor_meta.to_tensor()
if debug:
create_tensor_elapse = time.time() - start_ts
create_tensor_elapse = time.time() - start_ts # type: ignore[possibly-undefined]
start_ts = time.time()
fn = self.make_run_fn(*input_tensors, output_tensor=output_tensor)
if debug:
load_elapse = time.time() - start_ts
load_elapse = time.time() - start_ts # type: ignore[possibly-undefined]
start_ts = time.time()
out = do_bench(fn)
torch.cuda.synchronize() # shake out any CUDA errors
if debug:
bench_elapse = time.time() - start_ts
bench_elapse = time.time() - start_ts # type: ignore[possibly-undefined]
log.debug(
"InChildProcess %s: load %f, create tensor %f, bench %f",
str(self),
load_elapse,
create_tensor_elapse,
load_elapse, # type: ignore[possibly-undefined]
create_tensor_elapse, # type: ignore[possibly-undefined]
bench_elapse,
)
self.cleanup_run_fn()

View File

@ -99,7 +99,7 @@ class CutlassEVTEpilogueTypeFormatter:
result = pnode.inner_fn(index)
# each epilogue node results in a single "using" statement and may refer to the previous steps by name
formatter.aliases[node.name] = result
res = formatter.getvalue(result)
res = formatter.getvalue(result) # type: ignore[possibly-undefined]
if _MAGIC_SYMPY_ERROR_STRING in res:
raise CUTLASSEVTOpNotImplementedError(
"sympy / indexing expressions not yet supported in EVT fusion"
@ -266,7 +266,7 @@ class CutlassEVTEpilogueArgumentFormatter:
if node.name is not None:
formatter.aliases[node.name] = result
res: str = formatter.getvalue(result)
res: str = formatter.getvalue(result) # type: ignore[possibly-undefined]
if _MAGIC_SYMPY_ERROR_STRING in res:
raise CUTLASSEVTOpNotImplementedError(
"sympy / indexing expressions not yet supported in EVT fusion"

View File

@ -155,7 +155,7 @@ def get_cpp_op_schema(kernel: torch._ops.OpOverload) -> str:
cpp_return_value = f"std::tuple<{tuple_returns}>"
cpp_arg_type = [f"{convert_arg_type(arg)} {arg.name}" for arg in args]
return f"{cpp_return_value}({', '.join(cpp_arg_type)})"
return f"{cpp_return_value}({', '.join(cpp_arg_type)})" # type: ignore[possibly-undefined]
# TODO: Move to a well known place

View File

@ -209,7 +209,7 @@ def estimate_nccl_collective_runtime(snode: "BaseSchedulerNode") -> float:
nsteps = nRanks - 1
# Convert bus BW to algorithm BW (tensor bytes / algoBW = actual execution time)
ratio = (1.0 * nRanks) / nsteps
ratio = (1.0 * nRanks) / nsteps # type: ignore[possibly-undefined]
bandwidth = busBw * ratio
# Convert GB/s to GB/ns
bandwidth_GB_per_ns = bandwidth / 1e9
@ -236,7 +236,7 @@ def estimate_nccl_collective_runtime(snode: "BaseSchedulerNode") -> float:
if nNodes > 1:
netOverhead = 1.0 # getNetOverhead(comm);
intraLat = max(intraLat, netOverhead)
latency += (nsteps - nInterSteps) * intraLat + nInterSteps * interLat
latency += (nsteps - nInterSteps) * intraLat + nInterSteps * interLat # type: ignore[possibly-undefined]
# Convert us to ns
latency_ns = latency * 1e3

View File

@ -170,9 +170,9 @@ class PostGradBatchLinearFusion(BatchFusion):
input, weight = node.args
bias = None
batch_nodes.append(node)
batch_inputs.append(input)
batch_weights.append(weight)
batch_biases.append(bias)
batch_inputs.append(input) # type: ignore[possibly-undefined]
batch_weights.append(weight) # type: ignore[possibly-undefined]
batch_biases.append(bias) # type: ignore[possibly-undefined]
with graph.inserting_before(subset[-1]):
fused_inputs = decompose_stack(graph, batch_inputs)
@ -191,7 +191,7 @@ class PostGradBatchLinearFusion(BatchFusion):
new_bias_add = graph.call_function(
aten.add, args=((batch_biases[i], new_mm))
)
new_mm_cont = new_bias_add if has_bias else new_mm
new_mm_cont = new_bias_add if has_bias else new_mm # type: ignore[possibly-undefined]
original_mm.replace_all_uses_with(new_mm_cont)
new_mm_cont.meta.update(original_mm.meta)
graph.erase_node(original_mm)

View File

@ -283,7 +283,7 @@ if torch._C._has_mkldnn:
L[aten.mul](out, negative_slope),
)
if lowp_dtype:
out = L[prims.convert_element_type.default](out, dtype=dtype2)
out = L[prims.convert_element_type.default](out, dtype=dtype2) # type: ignore[possibly-undefined]
return out
return fn
@ -324,7 +324,7 @@ if torch._C._has_mkldnn:
out = L[prims.convert_element_type.default](out, dtype=torch.float)
out = L[aten.clamp_max](L[aten.clamp_min](out, min_value), max_value)
if lowp_dtype:
out = L[prims.convert_element_type.default](out, dtype=dtype2)
out = L[prims.convert_element_type.default](out, dtype=dtype2) # type: ignore[possibly-undefined]
return out
return fn

View File

@ -105,7 +105,7 @@ def pre_grad_passes(gm: torch.fx.GraphModule, example_inputs):
gm_after_fx_passes = gm.__copy__()
numeric_check_if_enabled(
gm_before_fx_passes,
gm_before_fx_passes, # type: ignore[possibly-undefined]
gm_after_fx_passes,
example_inputs,
config.fx_passes_numeric_check.get("num_iterations", 1),

View File

@ -1360,7 +1360,7 @@ def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float3
graph.erase_node(conv_node)
# Erase the dequant pattern
if dtype == torch.bfloat16:
graph.erase_node(convert_to_bf16)
graph.erase_node(convert_to_bf16) # type: ignore[possibly-undefined]
# Erase the dequant pattern
graph.erase_node(mul_node)
graph.erase_node(sub_node)
@ -1369,7 +1369,7 @@ def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float3
if clone_node is not None:
graph.erase_node(clone_node)
if dtype == torch.bfloat16:
graph.erase_node(weight_to_bf16_node)
graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined]
graph.erase_node(dequant_per_channel)
counters["inductor"]["qconv2d_weight_prepack_matcher_count"] += 1
counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"] += len(
@ -1697,14 +1697,14 @@ def _register_qlinear_weight_prepack_pass(
if input_contiguous:
graph.erase_node(output_reshape_node)
elif not input_contiguous and bias:
graph.erase_node(output_add_node_for_bias)
graph.erase_node(output_add_node_for_bias) # type: ignore[possibly-undefined]
graph.erase_node(linear_node)
if input_dim_exceeds_two:
if input_contiguous:
graph.erase_node(act_reshape_node)
else:
graph.erase_node(act_expand_node)
graph.erase_node(wgt_expand_node)
graph.erase_node(wgt_expand_node) # type: ignore[possibly-undefined]
if dtype == torch.bfloat16:
graph.erase_node(activation_to_bf16_node)
# Erase the dequant pattern
@ -1714,7 +1714,7 @@ def _register_qlinear_weight_prepack_pass(
# Erase the dequant per channel pattern
graph.erase_node(t_node)
if dtype == torch.bfloat16:
graph.erase_node(weight_to_bf16_node)
graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined]
graph.erase_node(dequant_per_channel)
counters["inductor"]["qlinear_weight_prepack_matcher_count"] += 1

View File

@ -857,7 +857,7 @@ class GraphLowering(torch.fx.Interpreter):
):
debug("fallback_handler")
result = fallback_handler(n.target, add_to_fallback_set=False)(
*args, **kwargs
*args, **kwargs # type: ignore[possibly-undefined]
)
elif n.op == "call_function" and n.target in layout_constraints:
debug("layout_constraints")

View File

@ -607,7 +607,7 @@ def register_pointwise(
fn,
override_return_dtype=override_return_dtype,
override_fn_when_input_bool=override_fn_when_input_bool,
override_fn_when_cuda_float64=fn_libdevice if use_libdevice_for_f64 else None,
override_fn_when_cuda_float64=fn_libdevice if use_libdevice_for_f64 else None, # type: ignore[possibly-undefined]
allow_alpha=allow_alpha,
)
fn = register_lowering(
@ -3630,8 +3630,8 @@ def _reflection_padnd_backward(grad_output, x, padding):
out = right_reflect[i]
index_range = (xyz[i], dhw[i] - padding_right[i], dhw[i] - 1)
outs.append(out)
index_ranges.append(index_range)
outs.append(out) # type: ignore[possibly-undefined]
index_ranges.append(index_range) # type: ignore[possibly-undefined]
grad = accumulate(grad, outs, index_ranges)

View File

@ -1196,7 +1196,7 @@ class PatternMatcherPass:
if (
self.prevent_match_across_mutations
and is_match(m)
and len(set(map(get_mutation_region_id_partial, m.nodes))) != 1
and len(set(map(get_mutation_region_id_partial, m.nodes))) != 1 # type: ignore[possibly-undefined]
):
continue
if os.environ.get("TORCHINDUCTOR_PATTERN_MATCH_DEBUG") == node.name:

View File

@ -1038,7 +1038,7 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
else:
fused_nodes.append(node)
return cls(producer.scheduler, fused_nodes, prev_node_1, prev_node_2)
return cls(producer.scheduler, fused_nodes, prev_node_1, prev_node_2) # type: ignore[possibly-undefined]
def __init__(
self,
@ -2256,13 +2256,13 @@ class Scheduler:
if node.is_template():
node, *epilogue = node.get_nodes()
self.get_backend(device).codegen_template(node, epilogue)
self.get_backend(device).codegen_template(node, epilogue) # type: ignore[possibly-undefined]
elif node.is_extern():
self.codegen_extern_call(node)
elif node.is_foreach():
self.get_backend(device).codegen_foreach(node)
self.get_backend(device).codegen_foreach(node) # type: ignore[possibly-undefined]
elif isinstance(node, (FusedSchedulerNode, SchedulerNode)):
self.get_backend(device).codegen_nodes(node.get_nodes())
self.get_backend(device).codegen_nodes(node.get_nodes()) # type: ignore[possibly-undefined]
else:
assert isinstance(node, NopKernelSchedulerNode)
node.allocate()
@ -2271,7 +2271,7 @@ class Scheduler:
V.graph.wrapper_code.generate_inf_and_nan_checker(node)
if config.triton.debug_sync_kernel:
self.get_backend(device).codegen_sync()
self.get_backend(device).codegen_sync() # type: ignore[possibly-undefined]
self.available_buffer_names.update(node.get_names())

View File

@ -331,7 +331,7 @@ def timed(
synchronize(device)
t1 = time.perf_counter()
# GC the result after timing
assert result is not None
assert result is not None # type: ignore[possibly-undefined]
return t1 - t0

View File

@ -1147,7 +1147,7 @@ def _parse_qr_mode(mode: str) -> Tuple[bool, bool]:
f"but expected one of 'reduced' (default), 'r', or 'complete'"
),
)
return compute_q, reduced
return compute_q, reduced # type: ignore[possibly-undefined]
@register_meta([aten.linalg_qr.default, aten.linalg_qr.out])
@ -1412,7 +1412,7 @@ def triangular_solve_meta(
cloned_coefficient = self.new_empty([0])
else:
torch._check(False, lambda: "triangular_solve: Got an unexpected layout.")
return solution, cloned_coefficient
return solution, cloned_coefficient # type: ignore[possibly-undefined]
# From aten/src/ATen/native/LinearAlgebra.cpp
@ -1809,7 +1809,7 @@ def _pad3d_common(input, padding, *, is_reflection):
)
if batch_mode:
return input.new_empty((nbatch, nplane, output_d, output_h, output_w))
return input.new_empty((nbatch, nplane, output_d, output_h, output_w)) # type: ignore[possibly-undefined]
else:
return input.new_empty((nplane, output_d, output_h, output_w))

View File

@ -246,10 +246,10 @@ def TensorMeta(
assert dtype is not None
assert device is not None
shape = inferred_shape if shape is None else tuple(shape)
strides = inferred_strides if strides is None else tuple(strides)
dtype = inferred_dtype if dtype is None else dtype
device = inferred_device if device is None else device
shape = inferred_shape if shape is None else tuple(shape) # type: ignore[possibly-undefined]
strides = inferred_strides if strides is None else tuple(strides) # type: ignore[possibly-undefined]
dtype = inferred_dtype if dtype is None else dtype # type: ignore[possibly-undefined]
device = inferred_device if device is None else device # type: ignore[possibly-undefined]
if isinstance(device, str):
device = torch.device(device)

View File

@ -4875,16 +4875,16 @@ def arange(
# other integral dtypes we don't. Weird... but needed to match ATen shapes.
if dtype == torch.int64:
# Uses floordiv to avoid ceil in inductor.
sgn = bool(xstep > 0) - bool(xstep < 0)
length = (xend - xstart + xstep - sgn) // xstep
sgn = bool(xstep > 0) - bool(xstep < 0) # type: ignore[possibly-undefined]
length = (xend - xstart + xstep - sgn) // xstep # type: ignore[possibly-undefined]
else:
length = math.ceil((end - start) / step)
if is_integer:
return prims.iota(
length,
start=xstart,
step=xstep,
start=xstart, # type: ignore[possibly-undefined]
step=xstep, # type: ignore[possibly-undefined]
dtype=dtype,
device=device,
requires_grad=requires_grad,

View File

@ -312,7 +312,7 @@ def _canonicalize_fft_shape_and_dim_args(
# Translate any -1 values in shape to the default length
ret_shape = tuple(
s if s != -1 else input_sizes[d] for (s, d) in zip(shape, ret_dims)
s if s != -1 else input_sizes[d] for (s, d) in zip(shape, ret_dims) # type: ignore[possibly-undefined]
)
elif dim is None:
# No shape, no dim
@ -320,12 +320,12 @@ def _canonicalize_fft_shape_and_dim_args(
ret_shape = tuple(input_sizes)
else:
# No shape, has dim
ret_shape = tuple(input_sizes[d] for d in ret_dims)
ret_shape = tuple(input_sizes[d] for d in ret_dims) # type: ignore[possibly-undefined]
for n in ret_shape:
torch._check(n > 0, lambda: f"Invalid number of data points ({n}) specified")
return _ShapeAndDims(shape=ret_shape, dims=ret_dims)
return _ShapeAndDims(shape=ret_shape, dims=ret_dims) # type: ignore[possibly-undefined]
def _prod(xs: Iterable[int]) -> int:

View File

@ -610,7 +610,7 @@ def _str_intern(inp, *, tensor_contents=None):
# no-grad mode. See: https://github.com/pytorch/pytorch/issues/99968
grad_fn_name = "Invalid"
if grad_fn_name is None and grad_fn is not None:
if grad_fn_name is None and grad_fn is not None: # type: ignore[possibly-undefined]
grad_fn_name = type(grad_fn).__name__
if grad_fn_name == "CppFunction":
grad_fn_name = grad_fn.name().rsplit("::", 1)[-1]
@ -627,7 +627,7 @@ def _str_intern(inp, *, tensor_contents=None):
suffixes.append(f"tangent={tangent}")
string_repr = _add_suffixes(
prefix + tensor_str, suffixes, indent, force_newline=self.is_sparse
prefix + tensor_str, suffixes, indent, force_newline=self.is_sparse # type: ignore[possibly-undefined]
)
# Check if this instance is flagged as a parameter and change the repr accordingly.

View File

@ -188,7 +188,7 @@ class _ConvBnNd(nn.modules.conv._ConvNd, nni._FusedModule):
if self.bn.training:
avg_dims = [0] + list(range(2, len(self.weight.shape)))
batch_mean = conv_out.mean(avg_dims)
batch_mean = conv_out.mean(avg_dims) # type: ignore[possibly-undefined]
batch_var = torch.square(conv_out - batch_mean.reshape(bias_shape)).mean(
avg_dims
)

View File

@ -231,7 +231,7 @@ class MultiheadAttention(torch.ao.nn.quantizable.MultiheadAttention):
if converted.bias_v is not None:
bias_v = converted._parameters.pop('bias_v')
sc, zp = torch._choose_qparams_per_tensor(bias_k,
sc, zp = torch._choose_qparams_per_tensor(bias_k, # type: ignore[possibly-undefined]
reduce_range=False)
bias_v = torch.quantize_per_tensor(bias_v, sc, zp, torch.quint8)
setattr(converted, 'bias_v', bias_v) # noqa: B010

View File

@ -24,7 +24,7 @@ class LinearPackedParams(torch.nn.Module):
wq = torch._empty_affine_quantized([1, 1], scale=1.0, zero_point=0, dtype=torch.qint8)
elif self.dtype == torch.float16:
wq = torch.zeros([1, 1], dtype=torch.float)
self.set_weight_bias(wq, None)
self.set_weight_bias(wq, None) # type: ignore[possibly-undefined]
@torch.jit.export
def set_weight_bias(self, weight: torch.Tensor, bias: Optional[torch.Tensor]) -> None:

View File

@ -435,7 +435,7 @@ class LSTM(RNNBase):
hx = (h_zeros, c_zeros)
else:
if batch_sizes is None: # If not PackedSequence input.
if is_batched:
if is_batched: # type: ignore[possibly-undefined]
if (hx[0].dim() != 3 or hx[1].dim() != 3):
msg = ("For batched 3-D input, hx and cx should "
f"also be 3-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors")
@ -465,8 +465,8 @@ class LSTM(RNNBase):
output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
return output_packed, self.permute_hidden(hidden, unsorted_indices)
else:
if not is_batched:
output = output.squeeze(batch_dim)
if not is_batched: # type: ignore[possibly-undefined]
output = output.squeeze(batch_dim) # type: ignore[possibly-undefined]
hidden = (hidden[0].squeeze(1), hidden[1].squeeze(1))
return output, self.permute_hidden(hidden, unsorted_indices)
@ -589,8 +589,8 @@ class GRU(RNNBase):
output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
return output_packed, self.permute_hidden(hidden, unsorted_indices)
else:
if not is_batched:
output = output.squeeze(batch_dim)
if not is_batched: # type: ignore[possibly-undefined]
output = output.squeeze(batch_dim) # type: ignore[possibly-undefined]
hidden = hidden.squeeze(1)
return output, self.permute_hidden(hidden, unsorted_indices)

View File

@ -759,7 +759,7 @@ def create_a_shadows_b(
continue
fqn_base_a = _maybe_get_fqn(subgraph_a.base_op_node, gm_a)
fqn_base_b = _maybe_get_fqn(subgraph_b.base_op_node, gm_b)
fqn_base_b = _maybe_get_fqn(subgraph_b.base_op_node, gm_b) # type: ignore[possibly-undefined]
if node_b_is_start_node:
@ -817,7 +817,7 @@ def create_a_shadows_b(
# cast dtype from the dtype of node_c's input to the dtype of
# node_a's input (dequant, etc)
# prev_node_c = node_c.args[0]
prev_node_c = get_normalized_nth_input(node_c, gm_b, 0)
prev_node_c = get_normalized_nth_input(node_c, gm_b, 0) # type: ignore[possibly-undefined]
if should_log_inputs:
# skip the input logger when inserting a dtype cast
if isinstance(prev_node_c, Node):
@ -901,7 +901,7 @@ def create_a_shadows_b(
# input_logger = env_c[dtype_cast_node.name]
# Find the first node in the subgraph
cur_node = node_a_shadows_c
while get_normalized_nth_input(cur_node, gm_b, 0) != input_logger:
while get_normalized_nth_input(cur_node, gm_b, 0) != input_logger: # type: ignore[possibly-undefined]
cur_node = get_normalized_nth_input(cur_node, gm_b, 0) # type: ignore[assignment]
if isinstance(input_logger, Node):
input_logger_mod = getattr(gm_b, input_logger.name)

View File

@ -92,7 +92,7 @@ class OutputProp:
elif node.op == 'call_module':
result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))
if isinstance(result, torch.Tensor):
if isinstance(result, torch.Tensor): # type: ignore[possibly-undefined]
node.traced_result = result
env[node.name] = result
@ -375,7 +375,7 @@ def create_submodule_from_subgraph(
# TODO(future PR): this is ignoring kwargs, will need to support kwargs
# for any fusion pattern which has them for a node that is not the
# first node.
cur_args_copy = [cur_node_copy] # type: ignore[has-type] # noqa: F821
cur_args_copy = [cur_node_copy] # type: ignore[has-type, possibly-undefined] # noqa: F821
if len(cur_node_orig.args) > 1:
for arg in cur_node_orig.args[1:]:
@ -399,15 +399,15 @@ def create_submodule_from_subgraph(
mod_name = f"mod_{cur_name_idx}"
setattr(gm, mod_name, orig_mod_copy)
cur_name_idx += 1
cur_node_copy = g.call_module(mod_name, cur_args_copy, cur_kwargs_copy)
cur_node_copy = g.call_module(mod_name, cur_args_copy, cur_kwargs_copy) # type: ignore[possibly-undefined]
elif cur_node_orig.op == 'call_function':
cur_node_copy = g.call_function(
cur_node_orig.target, cur_args_copy, cur_kwargs_copy)
cur_node_orig.target, cur_args_copy, cur_kwargs_copy) # type: ignore[possibly-undefined]
elif cur_node_orig.op == 'call_method':
cur_node_copy = g.call_method(
cur_node_orig.target, cur_args_copy, cur_kwargs_copy)
cur_node_orig.target, cur_args_copy, cur_kwargs_copy) # type: ignore[possibly-undefined]
else:
raise AssertionError(f'{cur_node_orig.op} not supported yet')

View File

@ -402,7 +402,7 @@ class ActivationSparsifier:
hook = layer.register_forward_pre_hook(self._sparsify_hook(name))
config['layer'] = layer
config['hook'] = hook
config['hook'] = hook # type: ignore[possibly-undefined]
def __repr__(self):
format_string = self.__class__.__name__ + ' ('

View File

@ -117,7 +117,7 @@ def _prune_linear_helper(linear: nn.Linear) -> Tensor:
with torch.no_grad():
parametrize.remove_parametrizations(linear, "weight", leave_parametrized=True)
linear.weight = nn.Parameter(linear.weight[mask])
linear.weight = nn.Parameter(linear.weight[mask]) # type: ignore[possibly-undefined]
linear.out_features = linear.weight.shape[0]
_remove_bias_handles(linear)
@ -175,7 +175,7 @@ def _prune_conv2d_helper(conv2d: nn.Conv2d) -> Tensor:
with torch.no_grad():
parametrize.remove_parametrizations(conv2d, "weight", leave_parametrized=True)
conv2d.weight = nn.Parameter(conv2d.weight[mask])
conv2d.weight = nn.Parameter(conv2d.weight[mask]) # type: ignore[possibly-undefined]
conv2d.out_channels = conv2d.weight.shape[0]
_remove_bias_handles(conv2d)
@ -197,7 +197,7 @@ def prune_conv2d_padded(conv2d_1: nn.Conv2d) -> None:
conv2d_1.bias is not None
): # conv2d_1 has original bias and bias propagated from previous layer
new_bias = torch.zeros(conv2d_1.bias.shape)
new_bias[mask] = conv2d_1.bias[mask]
new_bias[mask] = conv2d_1.bias[mask] # type: ignore[possibly-undefined]
# adjusted bias that to keep in conv2d_1
new_bias[~mask] = cast(Tensor, conv2d_1._bias)[~mask]
# pruned biases that are kept instead of propagated
@ -209,7 +209,7 @@ def prune_conv2d_padded(conv2d_1: nn.Conv2d) -> None:
if (
conv2d_1.bias is not None
): # conv2d_1 has bias propagated from previous layer
conv2d_1.bias.data[~mask] = 0
conv2d_1.bias.data[~mask] = 0 # type: ignore[possibly-undefined]
if hasattr(conv2d_1, "_bias"):
delattr(conv2d_1, "_bias")

View File

@ -835,7 +835,7 @@ def _maybe_insert_input_observer_for_arg_or_kwarg(
maybe_obs_mod = named_modules[maybe_obs_node.target] # type: ignore[index]
if (
type(maybe_obs_mod) == type(arg_as_input_act_obs_or_fq) and
maybe_obs_mod.dtype == arg_as_input_target_dtype
maybe_obs_mod.dtype == arg_as_input_target_dtype # type: ignore[possibly-undefined]
):
arg_as_input_act_obs_or_fq = maybe_obs_mod # type: ignore[assignment]
existing_obs_node = maybe_obs_node

View File

@ -516,7 +516,7 @@ def register_multi_grad_hook(
if tensor.requires_grad
)
return Handle(handles)
return Handle(handles) # type: ignore[possibly-undefined]
# NOTE [Allow mutation on tensors saved for backward]
@ -746,4 +746,4 @@ def _engine_run_backward(t_outputs, *args, **kwargs):
) # Calls into the C++ engine to run the backward pass
finally:
if attach_logging_hooks:
unregister_hooks()
unregister_hooks() # type: ignore[possibly-undefined]

View File

@ -1148,7 +1148,7 @@ def _build_table(
if evt.flops <= 0:
row_values.append("--")
else:
row_values.append(f"{evt.flops * flops_scale:8.3f}")
row_values.append(f"{evt.flops * flops_scale:8.3f}") # type: ignore[possibly-undefined]
if has_stack:
src_field = ""
if len(evt.stack) > 0:

View File

@ -1176,7 +1176,7 @@ class _NnapiSerializer:
shape=change_element(out_oper.shape, dim, out_dim_size)
)
if in_oper.dim_order == DimOrder.CHANNELS_LAST:
if in_oper.dim_order == DimOrder.CHANNELS_LAST: # type: ignore[possibly-undefined]
assert len(out_oper.shape) == 4
nnapi_dim = [0, 3, 1, 2][dim]
else:
@ -1633,10 +1633,10 @@ class _NnapiSerializer:
size_ctype, size_arg = self.get_constant_value(size_jit)
if node.inputsSize() == 3:
scale_ctype, scale_arg = self.get_constant_value(scale_jit)
scale_ctype, scale_arg = self.get_constant_value(scale_jit) # type: ignore[possibly-undefined]
else:
scale_h_ctype, scale_h_arg = self.get_constant_value(scale_h_jit)
scale_w_ctype, scale_w_arg = self.get_constant_value(scale_w_jit)
scale_h_ctype, scale_h_arg = self.get_constant_value(scale_h_jit) # type: ignore[possibly-undefined]
scale_w_ctype, scale_w_arg = self.get_constant_value(scale_w_jit) # type: ignore[possibly-undefined]
# The only way for the 4-argument overload of upsample_nearest2d to
# have been added to the graph without error is if the scale_h and

View File

@ -325,7 +325,7 @@ def make_graphed_callables(
only_inputs=True,
allow_unused=allow_unused_input,
)
del outputs, grad_inputs
del outputs, grad_inputs # type: ignore[possibly-undefined]
torch.cuda.synchronize()
# All captures here share a mempool. To avoid replays corrupting each other's memory,

View File

@ -206,4 +206,4 @@ def get_chunk_sharding_params(sharding_dim_size, world_size, spec, rank):
start_pos = current_offsets
break
current_offsets += chunk_size
return start_pos, chunk_size
return start_pos, chunk_size # type: ignore[possibly-undefined]

View File

@ -395,7 +395,7 @@ def _handle_row_wise_sharding(
result = torch.nn.functional.embedding_bag(
lookup_input,
torch.cat([local_shard, padding_row]),
offsets=offsets_list if offsets is not None else offsets,
offsets=offsets_list if offsets is not None else offsets, # type: ignore[possibly-undefined]
mode=mode if mode != "mean" else "sum",
per_sample_weights=per_sample_weights,
max_norm=max_norm,

View File

@ -541,7 +541,7 @@ def mark_data_parallel_shardings(
# mark activation as sharded on batch dim
node_sharding = node_strategies[0]
node.meta["sharding"] = node_sharding
node.meta["sharding"] = node_sharding # type: ignore[possibly-undefined]
placeholder_idx += 1
elif node.op == "call_function":

View File

@ -45,7 +45,7 @@ class GraphModuleTransformation:
"iter_graph_main_gm": iter_gm.main_gm.print_readable(False),
"iter_graph_cleanup_gm": iter_gm.cleanup_gm.print_readable(False),
},
graph_folder,
graph_folder, # type: ignore[possibly-undefined]
)
return iter_gm

View File

@ -353,7 +353,7 @@ def _scatter_wait_result(
gm.graph.node_replace_all_uses_with(orig_wait, wait_output_node)
if last_split_reshape_node == split_node:
last_split_reshape_node = wait_output_node
last_split_reshape_node = wait_output_node # type: ignore[possibly-undefined]
need_sort_nodes = sorted(need_sort_nodes, key=lambda node: node_indices[node])
gm.graph.move_after(need_sort_nodes, last_split_reshape_node)

View File

@ -561,7 +561,7 @@ class IterGraph(fx.Graph):
delete_user_cb,
propagate_meta=propagate_meta,
)
return ret
return ret # type: ignore[possibly-undefined]
def node_add_user(self, node: fx.Node, user: Any) -> None:
for graph in self._all_graphs:
@ -607,8 +607,8 @@ class IterGraph(fx.Graph):
"_foreach_add_",
):
step_node = node
self.node_add_user(optim_node, output_node)
self.node_add_user(step_node, optim_node)
self.node_add_user(optim_node, output_node) # type: ignore[possibly-undefined]
self.node_add_user(step_node, optim_node) # type: ignore[possibly-undefined]
def defunctionalize_optim(self) -> None:
# TODO: remove this API after DCE is not used with IterGraph
@ -624,8 +624,8 @@ class IterGraph(fx.Graph):
"_foreach_add_",
):
step_node = node
optim_node.users.pop(output_node, None)
step_node.users.pop(optim_node, None)
optim_node.users.pop(output_node, None) # type: ignore[possibly-undefined]
step_node.users.pop(optim_node, None) # type: ignore[possibly-undefined]
def freeze_cross_iter_movement(self) -> None:
self._freeze_cross_iter_movement = True

View File

@ -199,7 +199,7 @@ class OpDispatcher:
if output_sharding.output_spec is None:
if op_call == aten.equal.default:
obj_list = [None for _ in range(dist.get_world_size())]
dist.all_gather_object(obj_list, local_results)
dist.all_gather_object(obj_list, local_results) # type: ignore[possibly-undefined]
obj_list = list(filter(lambda x: x is not None, obj_list))
# perform reduce on the collection with AND op
local_results = functools.reduce(operator.and_, obj_list, True)
@ -229,7 +229,7 @@ class OpDispatcher:
assert len(out_dts) >= 1, "out variant should have at least one out arg"
return tuple(out_dts) if len(out_dts) > 1 else out_dts[0]
else:
return self.wrap(local_results, output_sharding.output_spec)
return self.wrap(local_results, output_sharding.output_spec) # type: ignore[possibly-undefined]
@staticmethod
def redistribute_local_args(

View File

@ -201,7 +201,7 @@ class Shard(Placement):
)
if is_padded:
output = self._unpad_tensor(output, pad_sizes[my_coordinate[mesh_dim]])
output = self._unpad_tensor(output, pad_sizes[my_coordinate[mesh_dim]]) # type: ignore[possibly-undefined]
return output
def _to_replicate_tensor(
@ -236,7 +236,7 @@ class Shard(Placement):
group=(mesh, mesh_dim),
)
if is_padded:
unpad_size = full_chunk_size * num_chunks - logical_dim_size
unpad_size = full_chunk_size * num_chunks - logical_dim_size # type: ignore[possibly-undefined]
result = self._unpad_tensor(result, unpad_size)
return result

View File

@ -69,7 +69,7 @@ class HybridModel(torch.nn.Module):
# Make sure combined PS dimension is always bigger or equal than the FC input
assert NUM_PS * EMBEDDING_DIM >= 512
dim_normalizer = int(NUM_PS * EMBEDDING_DIM / 512)
emb_lookups_reshaped = emb_lookups_cat.reshape(
emb_lookups_reshaped = emb_lookups_cat.reshape( # type: ignore[possibly-undefined]
[emb_lookups_cat.shape[0] * dim_normalizer, 512]
)
@ -195,7 +195,7 @@ def _run_trainer(emb_rref_list, rank):
# Throw away warm-up measurements
measurements = measurements[WARMUP_CYCLES:]
return rank, measurements, batch_size
return rank, measurements, batch_size # type: ignore[possibly-undefined]
def run_worker(rank, world_size):

View File

@ -85,7 +85,7 @@ else:
if cur_rank in mesh_1d:
res_sub_mesh = sub_mesh
res_sub_mesh._dim_group_infos = [device_mesh._dim_group_infos[mesh_dim]]
res_sub_mesh._dim_group_infos = [device_mesh._dim_group_infos[mesh_dim]] # type: ignore[possibly-undefined]
# Assign the current DeviceMesh as the parent of the child DeviceMesh.
self.child_to_parent_mapping[res_sub_mesh] = device_mesh
return res_sub_mesh

View File

@ -1943,9 +1943,9 @@ def _coalescing_manager(
work = group._end_coalescing(device)
if async_ops:
cm.append(work)
cm.append(work) # type: ignore[possibly-undefined]
else:
work.wait()
work.wait() # type: ignore[possibly-undefined]
def batch_isend_irecv(p2p_op_list):
@ -2458,7 +2458,7 @@ def gather_object(obj, object_gather_list=None, dst=0, group=None):
# All ranks call gather with equal-sized tensors.
gather(
input_tensor,
gather_list=output_tensors if my_rank == dst else None,
gather_list=output_tensors if my_rank == dst else None, # type: ignore[possibly-undefined]
dst=dst,
group=group,
)
@ -2558,7 +2558,7 @@ def broadcast_object_list(object_list, src=0, group=None, device=None):
# Note: torch.cat will do an extra memory copy to the current device, if the tensor_list
# has only one element, we can skip the copy.
if my_rank == src:
if len(tensor_list) == 1:
if len(tensor_list) == 1: # type: ignore[possibly-undefined]
object_tensor = tensor_list[0]
else:
object_tensor = torch.cat(tensor_list)
@ -2661,8 +2661,8 @@ def scatter_object_list(
# Src rank broadcasts the maximum tensor size. This is because all ranks are
# expected to call into scatter() with equal-sized tensors.
if my_rank == src:
max_tensor_size = max(tensor_sizes)
for tensor in tensor_list:
max_tensor_size = max(tensor_sizes) # type: ignore[possibly-undefined]
for tensor in tensor_list: # type: ignore[possibly-undefined]
tensor.resize_(max_tensor_size)
else:
max_tensor_size = torch.tensor([0], dtype=torch.long, device=pg_device)
@ -2672,7 +2672,7 @@ def scatter_object_list(
output_tensor = torch.empty(max_tensor_size.item(), dtype=torch.uint8, device=pg_device)
scatter(
output_tensor,
scatter_list=None if my_rank != src else tensor_list,
scatter_list=None if my_rank != src else tensor_list, # type: ignore[possibly-undefined]
src=src,
group=group,
)
@ -2681,7 +2681,7 @@ def scatter_object_list(
obj_tensor_size = torch.tensor([0], dtype=torch.long, device=pg_device)
scatter(
obj_tensor_size,
scatter_list=None if my_rank != src else tensor_sizes,
scatter_list=None if my_rank != src else tensor_sizes, # type: ignore[possibly-undefined]
src=src,
group=group,
)

View File

@ -51,7 +51,7 @@ class Event:
return data
if isinstance(data, str):
data_dict = json.loads(data)
data_dict["source"] = EventSource[data_dict["source"]]
data_dict["source"] = EventSource[data_dict["source"]] # type: ignore[possibly-undefined]
return Event(**data_dict)
def serialize(self) -> str:
@ -105,7 +105,7 @@ class RdzvEvent:
return data
if isinstance(data, str):
data_dict = json.loads(data)
data_dict["node_state"] = NodeState[data_dict["node_state"]]
data_dict["node_state"] = NodeState[data_dict["node_state"]] # type: ignore[possibly-undefined]
return RdzvEvent(**data_dict)
def serialize(self) -> str:

View File

@ -126,7 +126,7 @@ def prof(fn=None, group: str = "torchelastic"):
put_metric(f"{key}.failure", 1, group)
raise
finally:
put_metric(f"{key}.duration.ms", get_elapsed_time_ms(start), group)
put_metric(f"{key}.duration.ms", get_elapsed_time_ms(start), group) # type: ignore[possibly-undefined]
return result
return wrapper
@ -164,7 +164,7 @@ def profile(group=None):
publish_metric(
group,
f"{func.__name__}.duration.ms",
get_elapsed_time_ms(start_time),
get_elapsed_time_ms(start_time), # type: ignore[possibly-undefined]
)
return result

View File

@ -176,7 +176,7 @@ def _create_tcp_store(params: RendezvousParameters) -> TCPStore:
"The connection to the C10d store has failed. See inner exception for details."
) from exc
return store
return store # type: ignore[possibly-undefined]
def _create_file_store(params: RendezvousParameters) -> FileStore:

View File

@ -57,7 +57,7 @@ def find_free_port():
s.listen(0)
return s
except OSError as e:
s.close()
s.close() # type: ignore[possibly-undefined]
print(f"Socket creation attempt failed: {e}")
raise RuntimeError("Failed to create a socket")

View File

@ -1767,8 +1767,8 @@ class FlatParamHandle:
)
flat_param.data = flat_param._local_shard # type: ignore[attr-defined]
if self._use_orig_params:
if skip_use_sharded_views:
self._unsharded_flat_param_for_skipped_views = unsharded_flat_param
if skip_use_sharded_views: # type: ignore[possibly-undefined]
self._unsharded_flat_param_for_skipped_views = unsharded_flat_param # type: ignore[possibly-undefined]
else:
self._use_sharded_views()
# For the post-forward reshard, we may try to use sharded gradient
@ -1776,7 +1776,7 @@ class FlatParamHandle:
# in `no_sync()`), but for the post-backward reshard, we delay the
# call to after the reduce-scatter.
if (
in_forward
in_forward # type: ignore[possibly-undefined]
# Skip using gradient views if skipped using sharded views
# since exposing unsharded parameters with sharded gradients
# may be confusing to the user

View File

@ -885,7 +885,7 @@ def _materialize_meta_module(
warnings.warn(
"Unable to call `reset_parameters()` for module on meta "
f"device with error {str(e)}. Please ensure that your module of"
f"type {type(module)} implements a `reset_parameters()` method."
f"type {type(module)} implements a `reset_parameters()` method." # type: ignore[possibly-undefined]
)
raise e
@ -994,7 +994,7 @@ def _move_states_to_device(
param.grad.data = param.grad.to(device_from_device_id)
for buffer in buffers:
buffer.data = buffer.to(device_from_device_id)
elif current_device == cpu_device:
elif current_device == cpu_device: # type: ignore[possibly-undefined]
_warn_cpu_init()

View File

@ -1419,7 +1419,7 @@ def _convert_all_state_info(
)
gathered_state[name] = scalar_tensor_value
return dtype, state_buffers
return dtype, state_buffers # type: ignore[possibly-undefined]
def _unflatten_orig_param_states(

View File

@ -171,7 +171,7 @@ def _create_1d_device_mesh(device_mesh: DeviceMesh, tp_mesh_dim: int = 0) -> Dev
if cur_rank in mesh_1d:
res_sub_mesh = sub_mesh
res_sub_mesh._dim_group_infos = [device_mesh._dim_group_infos[tp_mesh_dim]]
res_sub_mesh._dim_group_infos = [device_mesh._dim_group_infos[tp_mesh_dim]] # type: ignore[possibly-undefined]
return res_sub_mesh

View File

@ -252,7 +252,7 @@ class ExportedProgram:
flat_args_with_path, received_spec = pytree.tree_flatten_with_path(
(args, kwargs)
)
) # type: ignore[possibly-undefined]
if in_spec is not None and received_spec != in_spec:
raise ValueError(

View File

@ -998,7 +998,7 @@ class Partitioner:
if cost < min_cost:
node_pair = [node, n1]
min_cost = cost
return cost, node_pair
return cost, node_pair # type: ignore[possibly-undefined]
# First use size_base_partition
self.size_based_partition()

View File

@ -263,7 +263,7 @@ def split_const_subgraphs(
setattr(
split,
fx_const_folded_attrs_name,
torch.nn.ParameterList() if multiple_outputs else torch.nn.Parameter(),
torch.nn.ParameterList() if multiple_outputs else torch.nn.Parameter(), # type: ignore[possibly-undefined]
)
for node in split.graph.nodes:
if node.op == "call_module" and node.target == const_mod_name:

View File

@ -694,7 +694,7 @@ for name in math_op_names:
fn.__qualname__ = fn.__name__ = priv_sympy_name
setattr(current_module, priv_sympy_name, fn)
del fn, name, priv_sympy_name
del fn, name, priv_sympy_name # type: ignore[possibly-undefined]
def _sympy_abs(a):
@ -753,7 +753,7 @@ for name in math_op_names:
sym_name = f"sym_{name}"
magic_methods[sym_name] = getattr(current_module, f"_sympy_{name}")
del name, sym_name, math_op_names, current_module
del name, sym_name, math_op_names, current_module # type: ignore[possibly-undefined]
def sympy_is_contiguous(sizes, strides):

View File

@ -68,8 +68,8 @@ def consistent(a, b):
p1 += 1
# We only need to check for variadic ends
# Variadic types are guaranteed to be the last element
return (isvariadic(cur_a) and p2 == len(b) or
isvariadic(cur_b) and p1 == len(a))
return (isvariadic(cur_a) and p2 == len(b) or # type: ignore[possibly-undefined]
isvariadic(cur_b) and p1 == len(a)) # type: ignore[possibly-undefined]
def ambiguous(a, b):

View File

@ -371,11 +371,11 @@ class _MinimizerBase:
# Compare results
names: Names = output_names
if output_names is None:
names = [str(v) for v in result_key]
names = [str(v) for v in result_key] # type: ignore[possibly-undefined]
numeric_result, bool_result = self.compare_fn(a_result, b_result, names)
self.results[result_key] = numeric_result
self.results[result_key] = numeric_result # type: ignore[possibly-undefined]
report.append(f"Numerical accuracy = {numeric_result}")
if not bool_result:
report.append(f"Result mismatch for {result_key}")

View File

@ -575,7 +575,7 @@ class _SplitterBase:
else:
total_output_bytes += get_size_of_node(submod, node)[0]
map_arg(output_node.args, get_bytes)
map_arg(output_node.args, get_bytes) # type: ignore[possibly-undefined]
qps = self.PCIe_BW / max(total_input_bytes, total_output_bytes)
reports += f"Total input size in bytes is {total_input_bytes}, total output size in bytes is {total_output_bytes},"
reports += f" theoretical max qps (bounds by PCIe bandwidth) for this submodule is {qps}.\n"

View File

@ -305,7 +305,7 @@ def _replace_pattern(
first_user_node = n
break
with original_graph.inserting_before(first_user_node):
with original_graph.inserting_before(first_user_node): # type: ignore[possibly-undefined]
copied_returning_nodes = original_graph.graph_copy(replacement_graph, val_map)
if isinstance(copied_returning_nodes, Node):

View File

@ -650,14 +650,14 @@ def download_url_to_file(url: str, dst: str, hash_prefix: Optional[str] = None,
buffer = u.read(READ_DATA_CHUNK)
if len(buffer) == 0:
break
f.write(buffer)
f.write(buffer) # type: ignore[possibly-undefined]
if hash_prefix is not None:
sha256.update(buffer)
sha256.update(buffer) # type: ignore[possibly-undefined]
pbar.update(len(buffer))
f.close()
if hash_prefix is not None:
digest = sha256.hexdigest()
digest = sha256.hexdigest() # type: ignore[possibly-undefined]
if digest[:len(hash_prefix)] != hash_prefix:
raise RuntimeError(f'invalid hash value (expected "{hash_prefix}", got "{digest}")')
shutil.move(f.name, dst)

View File

@ -70,8 +70,8 @@ def fuser(name):
yield
finally:
if name in ["fuser1", "fuser3"]: # NNC or oneDNN Graph
torch._C._jit_set_profiling_executor(old_profiling_executor)
torch._C._get_graph_executor_optimize(old_profiling_mode)
torch._C._jit_set_profiling_executor(old_profiling_executor) # type: ignore[possibly-undefined]
torch._C._get_graph_executor_optimize(old_profiling_mode) # type: ignore[possibly-undefined]
# recover the previous values
torch._C._jit_override_can_fuse_on_cpu(old_cpu_fuse)
torch._C._jit_override_can_fuse_on_gpu(old_gpu_fuse)

View File

@ -254,7 +254,7 @@ def verify(model, args, loss_fn=torch.sum, devices=None):
if assert_compiled:
hits = compiled_fn.hits
out = model(*args)
if assert_compiled and compiled_fn.hits == hits:
if assert_compiled and compiled_fn.hits == hits: # type: ignore[possibly-undefined]
raise RuntimeError("failed to use the compiled function")
if not isinstance(out, tuple):
out = (out,)
@ -280,7 +280,7 @@ def verify(model, args, loss_fn=torch.sum, devices=None):
assert model.has_trace_for(*args)
if is_module:
model.load_state_dict(saved_state)
model.load_state_dict(saved_state) # type: ignore[possibly-undefined]
compiled_outs, compiled_grads = run_fwd_bwd(args, assert_compiled=True)
_verify_equal(uncompiled_outs, compiled_outs)

View File

@ -1627,7 +1627,7 @@ def _std_var(
total = sum(x * x.conj(), dim, keepdim=keepdim, dtype=compute_dtype)
else:
total = sum(
x * x.conj(), dim, keepdim=keepdim, dtype=compute_dtype, mask=inmask
x * x.conj(), dim, keepdim=keepdim, dtype=compute_dtype, mask=inmask # type: ignore[possibly-undefined]
)
if not keepdim:
count = count.reshape(total.shape)

View File

@ -781,8 +781,8 @@ class SyncBatchNorm(_BatchNorm):
running_var,
self.eps,
exponential_average_factor,
process_group,
world_size,
process_group, # type: ignore[possibly-undefined]
world_size, # type: ignore[possibly-undefined]
)
@classmethod

View File

@ -1604,9 +1604,9 @@ class Module:
# For now only forward hooks have the always_call option but perhaps
# this functionality should be added to full backward hooks as well.
for hook_id, hook in _global_forward_hooks.items():
if hook_id in _global_forward_hooks_always_called and hook_id not in called_always_called_hooks:
if hook_id in _global_forward_hooks_always_called and hook_id not in called_always_called_hooks: # type: ignore[possibly-undefined]
try:
hook_result = hook(self, args, result)
hook_result = hook(self, args, result) # type: ignore[possibly-undefined]
if hook_result is not None:
result = hook_result
except Exception as e:
@ -1615,12 +1615,12 @@ class Module:
continue
for hook_id, hook in self._forward_hooks.items():
if hook_id in self._forward_hooks_always_called and hook_id not in called_always_called_hooks:
if hook_id in self._forward_hooks_always_called and hook_id not in called_always_called_hooks: # type: ignore[possibly-undefined]
try:
if hook_id in self._forward_hooks_with_kwargs:
hook_result = hook(self, args, kwargs, result)
hook_result = hook(self, args, kwargs, result) # type: ignore[possibly-undefined]
else:
hook_result = hook(self, args, result)
hook_result = hook(self, args, result) # type: ignore[possibly-undefined]
if hook_result is not None:
result = hook_result
except Exception as e:

View File

@ -575,8 +575,8 @@ class RNN(RNNBase):
output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
return output_packed, self.permute_hidden(hidden, unsorted_indices)
if not is_batched:
output = output.squeeze(batch_dim)
if not is_batched: # type: ignore[possibly-undefined]
output = output.squeeze(batch_dim) # type: ignore[possibly-undefined]
hidden = hidden.squeeze(1)
return output, self.permute_hidden(hidden, unsorted_indices)
@ -888,8 +888,8 @@ class LSTM(RNNBase):
output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
return output_packed, self.permute_hidden(hidden, unsorted_indices)
else:
if not is_batched:
output = output.squeeze(batch_dim)
if not is_batched: # type: ignore[possibly-undefined]
output = output.squeeze(batch_dim) # type: ignore[possibly-undefined]
hidden = (hidden[0].squeeze(1), hidden[1].squeeze(1))
return output, self.permute_hidden(hidden, unsorted_indices)
@ -1111,8 +1111,8 @@ class GRU(RNNBase):
output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
return output_packed, self.permute_hidden(hidden, unsorted_indices)
else:
if not is_batched:
output = output.squeeze(batch_dim)
if not is_batched: # type: ignore[possibly-undefined]
output = output.squeeze(batch_dim) # type: ignore[possibly-undefined]
hidden = hidden.squeeze(1)
return output, self.permute_hidden(hidden, unsorted_indices)

View File

@ -105,7 +105,7 @@ class _Orthogonal(Module):
Q = self.base @ Q
if transposed:
Q = Q.mT
return Q
return Q # type: ignore[possibly-undefined]
@torch.autograd.no_grad()
def right_inverse(self, Q: torch.Tensor) -> torch.Tensor:

View File

@ -293,7 +293,7 @@ def _create_node(
for _ in range(1, n_outputs):
node.addOutput()
node_ouputs = tuple(node.outputs())
node_ouputs = tuple(node.outputs()) # type: ignore[possibly-undefined]
assert len(node_ouputs) == n_outputs
aten = domain_op.startswith("aten::")

View File

@ -1529,7 +1529,7 @@ def softmax(g: jit_utils.GraphContext, input, dim, dtype=None):
)
if is_transpose_required:
softmax = g.op("Transpose", softmax, perm_i=axes)
softmax = g.op("Transpose", softmax, perm_i=axes) # type: ignore[possibly-undefined]
return softmax
# Apply max normalization.
@ -2467,7 +2467,7 @@ def log_softmax(g: jit_utils.GraphContext, input, dim, dtype=None):
"Cast", return_op, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type()
)
if is_transpose_required:
return_op = g.op("Transpose", return_op, perm_i=axes)
return_op = g.op("Transpose", return_op, perm_i=axes) # type: ignore[possibly-undefined]
return return_op
@ -2978,7 +2978,7 @@ def native_layer_norm(
# mean and normalized, so we need to Cast it back
if is_type_half:
denominator = g.op(
"Cast", denominator, to_i=_type_utils.JitScalarType(input_dtype).onnx_type()
"Cast", denominator, to_i=_type_utils.JitScalarType(input_dtype).onnx_type() # type: ignore[possibly-undefined]
)
rdenominator = g.op("Reciprocal", denominator)
else:
@ -4754,7 +4754,7 @@ def _generic_rnn(
reform_weights(g, w, hidden_size, reform_permutation) for w in weights
)
return tuple(
symbolic_helper._unsqueeze_helper(g, x, [0]) for x in (weight_ih, weight_hh)
symbolic_helper._unsqueeze_helper(g, x, [0]) for x in (weight_ih, weight_hh) # type: ignore[possibly-undefined]
)
@_beartype.beartype
@ -4766,10 +4766,10 @@ def _generic_rnn(
weight_ih, weight_hh, bias_ih, bias_hh = (
reform_weights(g, w, hidden_size, reform_permutation) for w in weights
)
bias_concat = g.op("Concat", bias_ih, bias_hh, axis_i=0)
bias_concat = g.op("Concat", bias_ih, bias_hh, axis_i=0) # type: ignore[possibly-undefined]
return tuple(
symbolic_helper._unsqueeze_helper(g, x, [0])
for x in (weight_ih, weight_hh, bias_concat)
for x in (weight_ih, weight_hh, bias_concat) # type: ignore[possibly-undefined]
)
@_beartype.beartype
@ -4808,16 +4808,16 @@ def _generic_rnn(
inputs = [prev_output, weight_ih, weight_hh, bias_concat, sequence_lens]
inputs.append(retrieve_state(h0, *state_indices))
inputs.append(retrieve_state(h0, *state_indices)) # type: ignore[possibly-undefined]
if variant == "LSTM":
inputs.append(retrieve_state(c0, *state_indices))
inputs.append(retrieve_state(c0, *state_indices)) # type: ignore[possibly-undefined]
extra_kwargs = {} if unidirectional else {"direction_s": "bidirectional"}
if variant == "RNN":
if bidirectional:
activation = [nonlinearity, nonlinearity]
activation = [nonlinearity, nonlinearity] # type: ignore[possibly-undefined]
else:
activation = [nonlinearity]
activation = [nonlinearity] # type: ignore[possibly-undefined]
prev_output, h_out = g.op(
"RNN",
@ -4859,17 +4859,17 @@ def _generic_rnn(
else:
prev_output = symbolic_helper._squeeze_helper(g, prev_output, [1])
h_outs.append(h_out)
h_outs.append(h_out) # type: ignore[possibly-undefined]
if variant == "LSTM":
c_outs.append(c_out)
c_outs.append(c_out) # type: ignore[possibly-undefined]
if batch_first:
# seq, batch, num_directions * hidden_size -> batch, seq, num_directions * hidden_size
prev_output = g.op("Transpose", prev_output, perm_i=[1, 0, 2])
h_outs = h_out if num_layers == 1 else g.op("Concat", *h_outs, axis_i=0)
h_outs = h_out if num_layers == 1 else g.op("Concat", *h_outs, axis_i=0) # type: ignore[possibly-undefined]
if variant == "RNN" or variant == "GRU":
return prev_output, h_outs
elif variant == "LSTM":
c_outs = c_out if num_layers == 1 else g.op("Concat", *c_outs, axis_i=0)
c_outs = c_out if num_layers == 1 else g.op("Concat", *c_outs, axis_i=0) # type: ignore[possibly-undefined]
return prev_output, h_outs, c_outs

View File

@ -199,7 +199,7 @@ class BasicEvaluation:
while (
current_kernel_index < len(cuda_kernel_events)
and (cuda_kernel_events[current_kernel_index].start_us()) * 1000
<= start_time
<= start_time # type: ignore[possibly-undefined]
):
current_kernel_index += 1
current_queue_depth = spawned_kernel_index - current_kernel_index + 1
@ -207,7 +207,7 @@ class BasicEvaluation:
if hasattr(event, "start_us"):
queue_depth_list.append(
Interval(start_time, end_time, current_queue_depth)
Interval(start_time, end_time, current_queue_depth) # type: ignore[possibly-undefined]
)
elif hasattr(event, "start_time_ns"):
self.metrics[EventKey(event)].queue_depth = current_queue_depth

View File

@ -44,7 +44,7 @@ def quantized_weight_reorder_for_mixed_dtypes_linear_cutlass(
else:
outp = weight
ncols, nrows = outp.shape
ncols, nrows = outp.shape # type: ignore[possibly-undefined]
assert nrows % (32 if dtypeq == torch.quint4x2 else 64) == 0
assert ncols % 64 == 0

View File

@ -134,11 +134,11 @@ def sparse_semi_structured_from_dense_cutlass(dense):
idxs1 = bit2 | (bit3.to(torch.int64) << 1)
if dense.dtype != torch.float:
sparse0 = dense_4.gather(-1, idxs0.unsqueeze(-1))
sparse0 = dense_4.gather(-1, idxs0.unsqueeze(-1)) # type: ignore[possibly-undefined]
sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
else:
sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(m, k // 2)
sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(m, k // 2) # type: ignore[possibly-undefined]
meta_4 = idxs0 | (idxs1 << 2)
meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)
@ -163,7 +163,7 @@ def sparse_semi_structured_from_dense_cutlass(dense):
)
# Reorder meta tensor elements.
meta_reordered = meta.new_empty((m * meta_ncols,))
meta_reordered = meta.new_empty((m * meta_ncols,)) # type: ignore[possibly-undefined]
meta_offsets = _calculate_meta_reordering_scatter_offsets(
m, meta_ncols, meta_dtype, device
)

View File

@ -1662,7 +1662,7 @@ if has_triton():
acc_block = tl.zeros((TILE_M, TILE_N), dtype=dot_out_dtype)
if is_compressed:
A_ptr += r0 * blocks_stride_P
A_ptr += r0 * blocks_stride_P # type: ignore[possibly-undefined]
for _ in range(nnz):
q = tl.load(q_ptr)
B = tl.load(B_ptr + q)
@ -1889,7 +1889,7 @@ if has_triton():
# alpha is never 0
if beta_is_nonzero:
output_acc_block = tl.load(input_ptrs).to(acc_dtype)
output_acc_block = tl.load(input_ptrs).to(acc_dtype) # type: ignore[possibly-undefined]
if not (beta_is_one and alpha_is_one):
beta_alpha = beta / alpha
output_acc_block *= beta_alpha

View File

@ -76,7 +76,7 @@ def _output_csv(file, results):
dim_str = str(dim)
shape_str = 'x'.join(str(s) for s in shape)
print(name, device, measurement.task_spec.num_threads, numel, shape_str, contiguous, dim_str,
print(name, device, measurement.task_spec.num_threads, numel, shape_str, contiguous, dim_str, # type: ignore[possibly-undefined]
measurement.mean * 1e6, measurement.median * 1e6, measurement.iqr * 1e6,
sep=',', file=file)

View File

@ -701,7 +701,7 @@ class _ValgrindWrapper:
if fn_match:
ir_str, file_function = fn_match.groups()
ir = int(ir_str.replace(",", ""))
if ir == program_totals:
if ir == program_totals: # type: ignore[possibly-undefined]
# Callgrind includes some top level red herring symbols when
# a program dumps multiple profiles.
continue

View File

@ -1427,7 +1427,7 @@ def _checkpoint_without_reentrant_generator(
new_frame.forward_completed = True
if getattr(device_module, "_initialized", False) and \
preserve_rng_state and not had_device_in_fwd:
preserve_rng_state and not had_device_in_fwd: # type: ignore[possibly-undefined]
# Device was not initialized before running the forward, so we didn't
# stash the device state.
raise RuntimeError(

View File

@ -2391,7 +2391,7 @@ def _write_ninja_file(path,
# 'Blocks' should be separated by newlines, for visual benefit.
blocks = [config, flags, compile_rule]
if with_cuda:
blocks.append(cuda_compile_rule)
blocks.append(cuda_compile_rule) # type: ignore[possibly-undefined]
blocks += [devlink_rule, link_rule, build, devlink, link, default]
content = "\n\n".join("\n".join(b) for b in blocks)
# Ninja requires a new lines at the end of the .ninja file

View File

@ -305,7 +305,7 @@ def _worker_loop(dataset_kind, dataset, index_queue, data_queue, done_event,
init_exception = None
else:
try:
data = fetcher.fetch(index)
data = fetcher.fetch(index) # type: ignore[possibly-undefined]
except Exception as e:
if isinstance(e, StopIteration) and dataset_kind == _DatasetKind.Iterable:
data = _IterableDatasetStopIteration(worker_id)

View File

@ -1360,7 +1360,7 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
# not found (i.e., didn't break)
return
self._index_queues[worker_queue_idx].put((self._send_idx, index))
self._index_queues[worker_queue_idx].put((self._send_idx, index)) # type: ignore[possibly-undefined]
self._task_info[self._send_idx] = (worker_queue_idx,)
self._tasks_outstanding += 1
self._send_idx += 1

View File

@ -210,7 +210,7 @@ class _ForkerIterDataPipe(IterDataPipe, _ContainerTemplate):
raise BufferError("ForkerIterDataPipe buffer overflow," +
f"buffer size {self.buffer_size} is insufficient.")
yield self.copy_fn(return_val)
yield self.copy_fn(return_val) # type: ignore[possibly-undefined]
finally:
self._child_stop[instance_id] = True
# Cleanup _datapipe_iterator for the case that fork exits earlier

View File

@ -907,7 +907,7 @@ class SummaryWriter:
else:
# Handles cnn.CNNModelHelper, model_helper.ModelHelper
current_graph = model_to_graph_def(model)
event = event_pb2.Event(graph_def=current_graph.SerializeToString())
event = event_pb2.Event(graph_def=current_graph.SerializeToString()) # type: ignore[possibly-undefined]
self._get_file_writer().add_event(event)
@staticmethod

View File

@ -717,10 +717,10 @@ resize_out(out, sizes, strides, options);
f"{textwrap.indent(class_ctor_str, indent)}",
f"{textwrap.indent(self.gen_class_set_output_functions(k, parent_class, generate_super), indent)}",
" const Tensor& maybe_get_output(int64_t output_idx) override {",
f" return {output_value};\n",
f" return {output_value};\n", # type: ignore[possibly-undefined] # TODO: audit
" }",
f" std::array<{output_type}, {len(f.func.returns)}> outputs_;",
f"{textwrap.indent(proxy_field, indent)}",
f" std::array<{output_type}, {len(f.func.returns)}> outputs_;", # type: ignore[possibly-undefined] # TODO: audit
f"{textwrap.indent(proxy_field, indent)}", # type: ignore[possibly-undefined] # TODO: audit
f"{textwrap.indent(guard_field, indent)}",
"};",
)
@ -962,7 +962,7 @@ return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), si
else:
refs = ", ".join(a.name for a in f.func.arguments.out)
ret_expr = f"std::forward_as_tuple({refs})"
sig_body.append(f"return {ret_expr};")
sig_body.append(f"return {ret_expr};") # type: ignore[possibly-undefined] # TODO: audit
sig_body_str = "\n".join(sig_body)