Revert "Enable possibly-undefined error code (#118533)"
This reverts commit 4f13f69a45.
Reverted https://github.com/pytorch/pytorch/pull/118533 on behalf of https://github.com/clee2000 due to sorry i'm trying to figure out a codev merge conflict, if this works i'll be back to rebase and merge ([comment](https://github.com/pytorch/pytorch/pull/118533#issuecomment-1917695185))
This commit is contained in:
parent 6511811ebb
commit 40ece2e579
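For context, and not part of the commit itself: mypy's possibly-undefined error code, which #118533 turned on through enable_error_code = possibly-undefined in mypy.ini, flags names that are bound on only some control-flow paths. A minimal, hypothetical sketch of the pattern behind most of the "# type: ignore[possibly-undefined]" suppressions this revert strips out:

def report(values, debug=False):
    if debug:
        start = len(values)  # "start" is bound only on the debug path
    total = sum(values)
    if debug:
        # mypy cannot prove that this branch runs only when the one above did,
        # so with the error code enabled it reports:
        #   Name "start" may be undefined  [possibly-undefined]
        print(f"summed {start} values")  # type: ignore[possibly-undefined]
    return total

Removing enable_error_code = possibly-undefined makes those suppression comments unnecessary, which is presumably why the revert deletes them together with the config line.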
mypy.ini

@@ -13,7 +13,6 @@ show_column_numbers = True
 check_untyped_defs = True
 follow_imports = normal
 local_partial_types = True
-enable_error_code = possibly-undefined

 # do not reenable this:
 # https://github.com/pytorch/pytorch/pull/60006#issuecomment-866130657

@@ -1,4 +1,3 @@
-# mypy: disable-error-code="possibly-undefined"
 # flake8: noqa
 import torch
 from torch.testing._internal.common_utils import TEST_NUMPY

@@ -1,4 +1,3 @@
-# mypy: disable-error-code="possibly-undefined"
 # flake8: noqa
 import torch
 from torch.testing._internal.common_utils import TEST_NUMPY
@@ -515,7 +515,7 @@ for name in ("sqrt", "cos", "cosh", "sin", "sinh", "tan", "tanh", "asin", "acos"
 sym_sqrt = current_module._sym_sqrt
 __all__.append("sym_sqrt")

-del fn, name, sym_name, current_module # type: ignore[possibly-undefined]
+del fn, name, sym_name, current_module


 def sym_ite(b, t, f):

@@ -2832,7 +2832,7 @@ def _rnn_helper(
 final_hiddens.append(bwd_hidden)

 if bidirectional:
-input = torch.cat([fwd_inp, bwd_inp], fwd_inp.dim() - 1) # type: ignore[possibly-undefined]
+input = torch.cat([fwd_inp, bwd_inp], fwd_inp.dim() - 1)
 else:
 input = fwd_inp


@@ -163,7 +163,7 @@ def preserve_global_state(fn):
 random.setstate(py_rng_state)
 torch.random.set_rng_state(torch_rng_state)
 if torch.cuda.is_available():
-torch.cuda.set_rng_state(cuda_rng_state) # type: ignore[possibly-undefined]
+torch.cuda.set_rng_state(cuda_rng_state)
 torch.fx.graph_module._forward_from_src = prior_fwd_from_src
 assert (
 guards.check()

@@ -568,7 +568,7 @@ def _compile(
 code.co_name,
 code.co_filename,
 code.co_firstlineno,
-out_code, # type: ignore[possibly-undefined]
+out_code,
 )

 for hook in _bytecode_hooks.values():

@@ -46,7 +46,7 @@ if use_buck:
 "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu",
 "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops",
 ]
-cur_target = libfb.py.build_info.BuildInfo.get_build_rule().replace("fbcode:", "//") # type: ignore[possibly-undefined]
+cur_target = libfb.py.build_info.BuildInfo.get_build_rule().replace("fbcode:", "//")
 extra_imports = "\n".join([f'torch.ops.load_library("{x}")' for x in extra_deps])



@@ -1430,7 +1430,7 @@ def export(
 example_fake_inputs,
 graph_captured_input,
 graph_captured_result,
-result_traced, # type: ignore[possibly-undefined]
+result_traced,
 flat_args_dynamic_dims,
 )
 # Store constraints and inputs as metadata for user passes, e.g. turn constraints to runtime check

@@ -1115,7 +1115,7 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
 tos = self.pop()
 _ = self.pop()
 if preserve_tos:
-self.push(tos) # type: ignore[possibly-undefined]
+self.push(tos)

 def FOR_ITER(self, inst):
 it = self.pop().realize()

@@ -118,7 +118,7 @@ torch._inductor.config.{"cpp" if device == "cpu" else "triton"}.inject_relu_bug_
 finally:
 log.removeHandler(log_handler)
 if cwd is not None:
-os.chdir(prev_cwd) # type: ignore[possibly-undefined]
+os.chdir(prev_cwd)
 # Make sure we don't leave buggy compiled frames lying
 # around
 torch._dynamo.reset()

@@ -773,7 +773,7 @@ def preserve_rng_state():
 with torch.utils._python_dispatch._disable_current_modes():
 torch.random.set_rng_state(rng_state)
 if torch.cuda.is_available():
-torch.cuda.set_rng_state(cuda_rng_state) # type: ignore[possibly-undefined]
+torch.cuda.set_rng_state(cuda_rng_state)


 def is_jit_model(model0):

@@ -892,7 +892,7 @@ def timed(model, example_inputs, times=1):
 result = model(*example_inputs)
 synchronize()
 t1 = time.perf_counter()
-return result, t1 - t0 # type: ignore[possibly-undefined]
+return result, t1 - t0


 def check_is_cuda(gm, example_inputs):

@@ -199,7 +199,7 @@ def wrap_outputs_maintaining_identity(
 result.append(unwrapped_input_to_orig_input[id(output)])
 continue
 if out_dims_specified:
-result.append(wrap_fn(output, flat_out_dims[i])) # type: ignore[possibly-undefined, index]
+result.append(wrap_fn(output, flat_out_dims[i])) # type: ignore[index]
 else:
 result.append(wrap_fn(output))


@@ -163,7 +163,7 @@ def parse_ttir(ttir, kwargs):
 return None

 try:
-import lark # type: ignore[import-not-found]
+import lark
 from lark import Lark, Transformer, v_args
 except ModuleNotFoundError:
 warnings.warn(
@@ -440,25 +440,25 @@ class BenchmarkRequest:
 output_tensor = self.output_tensor_meta.to_tensor()

 if debug:
-create_tensor_elapse = time.time() - start_ts # type: ignore[possibly-undefined]
+create_tensor_elapse = time.time() - start_ts
 start_ts = time.time()

 fn = self.make_run_fn(*input_tensors, output_tensor=output_tensor)

 if debug:
-load_elapse = time.time() - start_ts # type: ignore[possibly-undefined]
+load_elapse = time.time() - start_ts
 start_ts = time.time()

 out = do_bench(fn)
 torch.cuda.synchronize() # shake out any CUDA errors

 if debug:
-bench_elapse = time.time() - start_ts # type: ignore[possibly-undefined]
+bench_elapse = time.time() - start_ts
 log.debug(
 "InChildProcess %s: load %f, create tensor %f, bench %f",
 str(self),
-load_elapse, # type: ignore[possibly-undefined]
+load_elapse,
-create_tensor_elapse, # type: ignore[possibly-undefined]
+create_tensor_elapse,
 bench_elapse,
 )
 self.cleanup_run_fn()

@@ -99,7 +99,7 @@ class CutlassEVTEpilogueTypeFormatter:
 result = pnode.inner_fn(index)
 # each epilogue node results in a single "using" statement and may refer to the previous steps by name
 formatter.aliases[node.name] = result
-res = formatter.getvalue(result) # type: ignore[possibly-undefined]
+res = formatter.getvalue(result)
 if _MAGIC_SYMPY_ERROR_STRING in res:
 raise CUTLASSEVTOpNotImplementedError(
 "sympy / indexing expressions not yet supported in EVT fusion"

@@ -266,7 +266,7 @@ class CutlassEVTEpilogueArgumentFormatter:
 if node.name is not None:
 formatter.aliases[node.name] = result

-res: str = formatter.getvalue(result) # type: ignore[possibly-undefined]
+res: str = formatter.getvalue(result)
 if _MAGIC_SYMPY_ERROR_STRING in res:
 raise CUTLASSEVTOpNotImplementedError(
 "sympy / indexing expressions not yet supported in EVT fusion"

@@ -155,7 +155,7 @@ def get_cpp_op_schema(kernel: torch._ops.OpOverload) -> str:
 cpp_return_value = f"std::tuple<{tuple_returns}>"

 cpp_arg_type = [f"{convert_arg_type(arg)} {arg.name}" for arg in args]
-return f"{cpp_return_value}({', '.join(cpp_arg_type)})" # type: ignore[possibly-undefined]
+return f"{cpp_return_value}({', '.join(cpp_arg_type)})"


 # TODO: Move to a well known place

@@ -209,7 +209,7 @@ def estimate_nccl_collective_runtime(snode: "BaseSchedulerNode") -> float:
 nsteps = nRanks - 1

 # Convert bus BW to algorithm BW (tensor bytes / algoBW = actual execution time)
-ratio = (1.0 * nRanks) / nsteps # type: ignore[possibly-undefined]
+ratio = (1.0 * nRanks) / nsteps
 bandwidth = busBw * ratio
 # Convert GB/s to GB/ns
 bandwidth_GB_per_ns = bandwidth / 1e9

@@ -236,7 +236,7 @@ def estimate_nccl_collective_runtime(snode: "BaseSchedulerNode") -> float:
 if nNodes > 1:
 netOverhead = 1.0 # getNetOverhead(comm);
 intraLat = max(intraLat, netOverhead)
-latency += (nsteps - nInterSteps) * intraLat + nInterSteps * interLat # type: ignore[possibly-undefined]
+latency += (nsteps - nInterSteps) * intraLat + nInterSteps * interLat
 # Convert us to ns
 latency_ns = latency * 1e3


@@ -170,9 +170,9 @@ class PostGradBatchLinearFusion(BatchFusion):
 input, weight = node.args
 bias = None
 batch_nodes.append(node)
-batch_inputs.append(input) # type: ignore[possibly-undefined]
+batch_inputs.append(input)
-batch_weights.append(weight) # type: ignore[possibly-undefined]
+batch_weights.append(weight)
-batch_biases.append(bias) # type: ignore[possibly-undefined]
+batch_biases.append(bias)

 with graph.inserting_before(subset[-1]):
 fused_inputs = decompose_stack(graph, batch_inputs)

@@ -191,7 +191,7 @@ class PostGradBatchLinearFusion(BatchFusion):
 new_bias_add = graph.call_function(
 aten.add, args=((batch_biases[i], new_mm))
 )
-new_mm_cont = new_bias_add if has_bias else new_mm # type: ignore[possibly-undefined]
+new_mm_cont = new_bias_add if has_bias else new_mm
 original_mm.replace_all_uses_with(new_mm_cont)
 new_mm_cont.meta.update(original_mm.meta)
 graph.erase_node(original_mm)

@@ -283,7 +283,7 @@ if torch._C._has_mkldnn:
 L[aten.mul](out, negative_slope),
 )
 if lowp_dtype:
-out = L[prims.convert_element_type.default](out, dtype=dtype2) # type: ignore[possibly-undefined]
+out = L[prims.convert_element_type.default](out, dtype=dtype2)
 return out

 return fn

@@ -324,7 +324,7 @@ if torch._C._has_mkldnn:
 out = L[prims.convert_element_type.default](out, dtype=torch.float)
 out = L[aten.clamp_max](L[aten.clamp_min](out, min_value), max_value)
 if lowp_dtype:
-out = L[prims.convert_element_type.default](out, dtype=dtype2) # type: ignore[possibly-undefined]
+out = L[prims.convert_element_type.default](out, dtype=dtype2)
 return out

 return fn

@@ -105,7 +105,7 @@ def pre_grad_passes(gm: torch.fx.GraphModule, example_inputs):

 gm_after_fx_passes = gm.__copy__()
 numeric_check_if_enabled(
-gm_before_fx_passes, # type: ignore[possibly-undefined]
+gm_before_fx_passes,
 gm_after_fx_passes,
 example_inputs,
 config.fx_passes_numeric_check.get("num_iterations", 1),
@@ -1360,7 +1360,7 @@ def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float3
 graph.erase_node(conv_node)
 # Erase the dequant pattern
 if dtype == torch.bfloat16:
-graph.erase_node(convert_to_bf16) # type: ignore[possibly-undefined]
+graph.erase_node(convert_to_bf16)
 # Erase the dequant pattern
 graph.erase_node(mul_node)
 graph.erase_node(sub_node)

@@ -1369,7 +1369,7 @@ def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float3
 if clone_node is not None:
 graph.erase_node(clone_node)
 if dtype == torch.bfloat16:
-graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined]
+graph.erase_node(weight_to_bf16_node)
 graph.erase_node(dequant_per_channel)
 counters["inductor"]["qconv2d_weight_prepack_matcher_count"] += 1
 counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"] += len(

@@ -1697,14 +1697,14 @@ def _register_qlinear_weight_prepack_pass(
 if input_contiguous:
 graph.erase_node(output_reshape_node)
 elif not input_contiguous and bias:
-graph.erase_node(output_add_node_for_bias) # type: ignore[possibly-undefined]
+graph.erase_node(output_add_node_for_bias)
 graph.erase_node(linear_node)
 if input_dim_exceeds_two:
 if input_contiguous:
 graph.erase_node(act_reshape_node)
 else:
 graph.erase_node(act_expand_node)
-graph.erase_node(wgt_expand_node) # type: ignore[possibly-undefined]
+graph.erase_node(wgt_expand_node)
 if dtype == torch.bfloat16:
 graph.erase_node(activation_to_bf16_node)
 # Erase the dequant pattern

@@ -1714,7 +1714,7 @@ def _register_qlinear_weight_prepack_pass(
 # Erase the dequant per channel pattern
 graph.erase_node(t_node)
 if dtype == torch.bfloat16:
-graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined]
+graph.erase_node(weight_to_bf16_node)
 graph.erase_node(dequant_per_channel)

 counters["inductor"]["qlinear_weight_prepack_matcher_count"] += 1

@@ -845,7 +845,7 @@ class GraphLowering(torch.fx.Interpreter):
 ):
 debug("fallback_handler")
 result = fallback_handler(n.target, add_to_fallback_set=False)(
-*args, **kwargs # type: ignore[possibly-undefined]
+*args, **kwargs
 )
 elif n.op == "call_function" and n.target in layout_constraints:
 debug("layout_constraints")

@@ -607,7 +607,7 @@ def register_pointwise(
 fn,
 override_return_dtype=override_return_dtype,
 override_fn_when_input_bool=override_fn_when_input_bool,
-override_fn_when_cuda_float64=fn_libdevice if use_libdevice_for_f64 else None, # type: ignore[possibly-undefined]
+override_fn_when_cuda_float64=fn_libdevice if use_libdevice_for_f64 else None,
 allow_alpha=allow_alpha,
 )
 fn = register_lowering(

@@ -3630,8 +3630,8 @@ def _reflection_padnd_backward(grad_output, x, padding):
 out = right_reflect[i]
 index_range = (xyz[i], dhw[i] - padding_right[i], dhw[i] - 1)

-outs.append(out) # type: ignore[possibly-undefined]
+outs.append(out)
-index_ranges.append(index_range) # type: ignore[possibly-undefined]
+index_ranges.append(index_range)

 grad = accumulate(grad, outs, index_ranges)


@@ -1196,7 +1196,7 @@ class PatternMatcherPass:
 if (
 self.prevent_match_across_mutations
 and is_match(m)
-and len(set(map(get_mutation_region_id_partial, m.nodes))) != 1 # type: ignore[possibly-undefined]
+and len(set(map(get_mutation_region_id_partial, m.nodes))) != 1
 ):
 continue
 if os.environ.get("TORCHINDUCTOR_PATTERN_MATCH_DEBUG") == node.name:

@@ -1038,7 +1038,7 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
 else:
 fused_nodes.append(node)

-return cls(producer.scheduler, fused_nodes, prev_node_1, prev_node_2) # type: ignore[possibly-undefined]
+return cls(producer.scheduler, fused_nodes, prev_node_1, prev_node_2)

 def __init__(
 self,

@@ -2256,13 +2256,13 @@ class Scheduler:

 if node.is_template():
 node, *epilogue = node.get_nodes()
-self.get_backend(device).codegen_template(node, epilogue) # type: ignore[possibly-undefined]
+self.get_backend(device).codegen_template(node, epilogue)
 elif node.is_extern():
 self.codegen_extern_call(node)
 elif node.is_foreach():
-self.get_backend(device).codegen_foreach(node) # type: ignore[possibly-undefined]
+self.get_backend(device).codegen_foreach(node)
 elif isinstance(node, (FusedSchedulerNode, SchedulerNode)):
-self.get_backend(device).codegen_nodes(node.get_nodes()) # type: ignore[possibly-undefined]
+self.get_backend(device).codegen_nodes(node.get_nodes())
 else:
 assert isinstance(node, NopKernelSchedulerNode)
 node.allocate()

@@ -2271,7 +2271,7 @@ class Scheduler:
 V.graph.wrapper_code.generate_inf_and_nan_checker(node)

 if config.triton.debug_sync_kernel:
-self.get_backend(device).codegen_sync() # type: ignore[possibly-undefined]
+self.get_backend(device).codegen_sync()

 self.available_buffer_names.update(node.get_names())

@@ -331,7 +331,7 @@ def timed(
 synchronize(device)
 t1 = time.perf_counter()
 # GC the result after timing
-assert result is not None # type: ignore[possibly-undefined]
+assert result is not None
 return t1 - t0



@@ -1147,7 +1147,7 @@ def _parse_qr_mode(mode: str) -> Tuple[bool, bool]:
 f"but expected one of 'reduced' (default), 'r', or 'complete'"
 ),
 )
-return compute_q, reduced # type: ignore[possibly-undefined]
+return compute_q, reduced


 @register_meta([aten.linalg_qr.default, aten.linalg_qr.out])

@@ -1412,7 +1412,7 @@ def triangular_solve_meta(
 cloned_coefficient = self.new_empty([0])
 else:
 torch._check(False, lambda: "triangular_solve: Got an unexpected layout.")
-return solution, cloned_coefficient # type: ignore[possibly-undefined]
+return solution, cloned_coefficient


 # From aten/src/ATen/native/LinearAlgebra.cpp

@@ -1809,7 +1809,7 @@ def _pad3d_common(input, padding, *, is_reflection):
 )

 if batch_mode:
-return input.new_empty((nbatch, nplane, output_d, output_h, output_w)) # type: ignore[possibly-undefined]
+return input.new_empty((nbatch, nplane, output_d, output_h, output_w))
 else:
 return input.new_empty((nplane, output_d, output_h, output_w))


@@ -246,10 +246,10 @@ def TensorMeta(
 assert dtype is not None
 assert device is not None

-shape = inferred_shape if shape is None else tuple(shape) # type: ignore[possibly-undefined]
+shape = inferred_shape if shape is None else tuple(shape)
-strides = inferred_strides if strides is None else tuple(strides) # type: ignore[possibly-undefined]
+strides = inferred_strides if strides is None else tuple(strides)
-dtype = inferred_dtype if dtype is None else dtype # type: ignore[possibly-undefined]
+dtype = inferred_dtype if dtype is None else dtype
-device = inferred_device if device is None else device # type: ignore[possibly-undefined]
+device = inferred_device if device is None else device

 if isinstance(device, str):
 device = torch.device(device)

@@ -4875,16 +4875,16 @@ def arange(
 # other integral dtypes we don't. Weird... but needed to match ATen shapes.
 if dtype == torch.int64:
 # Uses floordiv to avoid ceil in inductor.
-sgn = bool(xstep > 0) - bool(xstep < 0) # type: ignore[possibly-undefined]
+sgn = bool(xstep > 0) - bool(xstep < 0)
-length = (xend - xstart + xstep - sgn) // xstep # type: ignore[possibly-undefined]
+length = (xend - xstart + xstep - sgn) // xstep
 else:
 length = math.ceil((end - start) / step)

 if is_integer:
 return prims.iota(
 length,
-start=xstart, # type: ignore[possibly-undefined]
+start=xstart,
-step=xstep, # type: ignore[possibly-undefined]
+step=xstep,
 dtype=dtype,
 device=device,
 requires_grad=requires_grad,

@@ -312,7 +312,7 @@ def _canonicalize_fft_shape_and_dim_args(

 # Translate any -1 values in shape to the default length
 ret_shape = tuple(
-s if s != -1 else input_sizes[d] for (s, d) in zip(shape, ret_dims) # type: ignore[possibly-undefined]
+s if s != -1 else input_sizes[d] for (s, d) in zip(shape, ret_dims)
 )
 elif dim is None:
 # No shape, no dim

@@ -320,12 +320,12 @@ def _canonicalize_fft_shape_and_dim_args(
 ret_shape = tuple(input_sizes)
 else:
 # No shape, has dim
-ret_shape = tuple(input_sizes[d] for d in ret_dims) # type: ignore[possibly-undefined]
+ret_shape = tuple(input_sizes[d] for d in ret_dims)

 for n in ret_shape:
 torch._check(n > 0, lambda: f"Invalid number of data points ({n}) specified")

-return _ShapeAndDims(shape=ret_shape, dims=ret_dims) # type: ignore[possibly-undefined]
+return _ShapeAndDims(shape=ret_shape, dims=ret_dims)


 def _prod(xs: Iterable[int]) -> int:

@@ -610,7 +610,7 @@ def _str_intern(inp, *, tensor_contents=None):
 # no-grad mode. See: https://github.com/pytorch/pytorch/issues/99968
 grad_fn_name = "Invalid"

-if grad_fn_name is None and grad_fn is not None: # type: ignore[possibly-undefined]
+if grad_fn_name is None and grad_fn is not None:
 grad_fn_name = type(grad_fn).__name__
 if grad_fn_name == "CppFunction":
 grad_fn_name = grad_fn.name().rsplit("::", 1)[-1]

@@ -627,7 +627,7 @@ def _str_intern(inp, *, tensor_contents=None):
 suffixes.append(f"tangent={tangent}")

 string_repr = _add_suffixes(
-prefix + tensor_str, suffixes, indent, force_newline=self.is_sparse # type: ignore[possibly-undefined]
+prefix + tensor_str, suffixes, indent, force_newline=self.is_sparse
 )

 # Check if this instance is flagged as a parameter and change the repr accordingly.
@@ -188,7 +188,7 @@ class _ConvBnNd(nn.modules.conv._ConvNd, nni._FusedModule):

 if self.bn.training:
 avg_dims = [0] + list(range(2, len(self.weight.shape)))
-batch_mean = conv_out.mean(avg_dims) # type: ignore[possibly-undefined]
+batch_mean = conv_out.mean(avg_dims)
 batch_var = torch.square(conv_out - batch_mean.reshape(bias_shape)).mean(
 avg_dims
 )

@@ -231,7 +231,7 @@ class MultiheadAttention(torch.ao.nn.quantizable.MultiheadAttention):

 if converted.bias_v is not None:
 bias_v = converted._parameters.pop('bias_v')
-sc, zp = torch._choose_qparams_per_tensor(bias_k, # type: ignore[possibly-undefined]
+sc, zp = torch._choose_qparams_per_tensor(bias_k,
 reduce_range=False)
 bias_v = torch.quantize_per_tensor(bias_v, sc, zp, torch.quint8)
 setattr(converted, 'bias_v', bias_v) # noqa: B010

@@ -24,7 +24,7 @@ class LinearPackedParams(torch.nn.Module):
 wq = torch._empty_affine_quantized([1, 1], scale=1.0, zero_point=0, dtype=torch.qint8)
 elif self.dtype == torch.float16:
 wq = torch.zeros([1, 1], dtype=torch.float)
-self.set_weight_bias(wq, None) # type: ignore[possibly-undefined]
+self.set_weight_bias(wq, None)

 @torch.jit.export
 def set_weight_bias(self, weight: torch.Tensor, bias: Optional[torch.Tensor]) -> None:

@@ -435,7 +435,7 @@ class LSTM(RNNBase):
 hx = (h_zeros, c_zeros)
 else:
 if batch_sizes is None: # If not PackedSequence input.
-if is_batched: # type: ignore[possibly-undefined]
+if is_batched:
 if (hx[0].dim() != 3 or hx[1].dim() != 3):
 msg = ("For batched 3-D input, hx and cx should "
 f"also be 3-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors")

@@ -465,8 +465,8 @@ class LSTM(RNNBase):
 output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
 return output_packed, self.permute_hidden(hidden, unsorted_indices)
 else:
-if not is_batched: # type: ignore[possibly-undefined]
+if not is_batched:
-output = output.squeeze(batch_dim) # type: ignore[possibly-undefined]
+output = output.squeeze(batch_dim)
 hidden = (hidden[0].squeeze(1), hidden[1].squeeze(1))
 return output, self.permute_hidden(hidden, unsorted_indices)


@@ -589,8 +589,8 @@ class GRU(RNNBase):
 output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
 return output_packed, self.permute_hidden(hidden, unsorted_indices)
 else:
-if not is_batched: # type: ignore[possibly-undefined]
+if not is_batched:
-output = output.squeeze(batch_dim) # type: ignore[possibly-undefined]
+output = output.squeeze(batch_dim)
 hidden = hidden.squeeze(1)

 return output, self.permute_hidden(hidden, unsorted_indices)

@@ -759,7 +759,7 @@ def create_a_shadows_b(
 continue

 fqn_base_a = _maybe_get_fqn(subgraph_a.base_op_node, gm_a)
-fqn_base_b = _maybe_get_fqn(subgraph_b.base_op_node, gm_b) # type: ignore[possibly-undefined]
+fqn_base_b = _maybe_get_fqn(subgraph_b.base_op_node, gm_b)

 if node_b_is_start_node:


@@ -817,7 +817,7 @@ def create_a_shadows_b(
 # cast dtype from the dtype of node_c's input to the dtype of
 # node_a's input (dequant, etc)
 # prev_node_c = node_c.args[0]
-prev_node_c = get_normalized_nth_input(node_c, gm_b, 0) # type: ignore[possibly-undefined]
+prev_node_c = get_normalized_nth_input(node_c, gm_b, 0)
 if should_log_inputs:
 # skip the input logger when inserting a dtype cast
 if isinstance(prev_node_c, Node):

@@ -901,7 +901,7 @@ def create_a_shadows_b(
 # input_logger = env_c[dtype_cast_node.name]
 # Find the first node in the subgraph
 cur_node = node_a_shadows_c
-while get_normalized_nth_input(cur_node, gm_b, 0) != input_logger: # type: ignore[possibly-undefined]
+while get_normalized_nth_input(cur_node, gm_b, 0) != input_logger:
 cur_node = get_normalized_nth_input(cur_node, gm_b, 0) # type: ignore[assignment]
 if isinstance(input_logger, Node):
 input_logger_mod = getattr(gm_b, input_logger.name)

@@ -92,7 +92,7 @@ class OutputProp:
 elif node.op == 'call_module':
 result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))

-if isinstance(result, torch.Tensor): # type: ignore[possibly-undefined]
+if isinstance(result, torch.Tensor):
 node.traced_result = result

 env[node.name] = result

@@ -375,7 +375,7 @@ def create_submodule_from_subgraph(
 # TODO(future PR): this is ignoring kwargs, will need to support kwargs
 # for any fusion pattern which has them for a node that is not the
 # first node.
-cur_args_copy = [cur_node_copy] # type: ignore[has-type, possibly-undefined] # noqa: F821
+cur_args_copy = [cur_node_copy] # type: ignore[has-type] # noqa: F821

 if len(cur_node_orig.args) > 1:
 for arg in cur_node_orig.args[1:]:

@@ -399,15 +399,15 @@ def create_submodule_from_subgraph(
 mod_name = f"mod_{cur_name_idx}"
 setattr(gm, mod_name, orig_mod_copy)
 cur_name_idx += 1
-cur_node_copy = g.call_module(mod_name, cur_args_copy, cur_kwargs_copy) # type: ignore[possibly-undefined]
+cur_node_copy = g.call_module(mod_name, cur_args_copy, cur_kwargs_copy)

 elif cur_node_orig.op == 'call_function':
 cur_node_copy = g.call_function(
-cur_node_orig.target, cur_args_copy, cur_kwargs_copy) # type: ignore[possibly-undefined]
+cur_node_orig.target, cur_args_copy, cur_kwargs_copy)

 elif cur_node_orig.op == 'call_method':
 cur_node_copy = g.call_method(
-cur_node_orig.target, cur_args_copy, cur_kwargs_copy) # type: ignore[possibly-undefined]
+cur_node_orig.target, cur_args_copy, cur_kwargs_copy)

 else:
 raise AssertionError(f'{cur_node_orig.op} not supported yet')
@@ -402,7 +402,7 @@ class ActivationSparsifier:
 hook = layer.register_forward_pre_hook(self._sparsify_hook(name))

 config['layer'] = layer
-config['hook'] = hook # type: ignore[possibly-undefined]
+config['hook'] = hook

 def __repr__(self):
 format_string = self.__class__.__name__ + ' ('

@@ -117,7 +117,7 @@ def _prune_linear_helper(linear: nn.Linear) -> Tensor:

 with torch.no_grad():
 parametrize.remove_parametrizations(linear, "weight", leave_parametrized=True)
-linear.weight = nn.Parameter(linear.weight[mask]) # type: ignore[possibly-undefined]
+linear.weight = nn.Parameter(linear.weight[mask])
 linear.out_features = linear.weight.shape[0]
 _remove_bias_handles(linear)


@@ -175,7 +175,7 @@ def _prune_conv2d_helper(conv2d: nn.Conv2d) -> Tensor:

 with torch.no_grad():
 parametrize.remove_parametrizations(conv2d, "weight", leave_parametrized=True)
-conv2d.weight = nn.Parameter(conv2d.weight[mask]) # type: ignore[possibly-undefined]
+conv2d.weight = nn.Parameter(conv2d.weight[mask])
 conv2d.out_channels = conv2d.weight.shape[0]

 _remove_bias_handles(conv2d)

@@ -197,7 +197,7 @@ def prune_conv2d_padded(conv2d_1: nn.Conv2d) -> None:
 conv2d_1.bias is not None
 ): # conv2d_1 has original bias and bias propagated from previous layer
 new_bias = torch.zeros(conv2d_1.bias.shape)
-new_bias[mask] = conv2d_1.bias[mask] # type: ignore[possibly-undefined]
+new_bias[mask] = conv2d_1.bias[mask]
 # adjusted bias that to keep in conv2d_1
 new_bias[~mask] = cast(Tensor, conv2d_1._bias)[~mask]
 # pruned biases that are kept instead of propagated

@@ -209,7 +209,7 @@ def prune_conv2d_padded(conv2d_1: nn.Conv2d) -> None:
 if (
 conv2d_1.bias is not None
 ): # conv2d_1 has bias propagated from previous layer
-conv2d_1.bias.data[~mask] = 0 # type: ignore[possibly-undefined]
+conv2d_1.bias.data[~mask] = 0

 if hasattr(conv2d_1, "_bias"):
 delattr(conv2d_1, "_bias")

@@ -835,7 +835,7 @@ def _maybe_insert_input_observer_for_arg_or_kwarg(
 maybe_obs_mod = named_modules[maybe_obs_node.target] # type: ignore[index]
 if (
 type(maybe_obs_mod) == type(arg_as_input_act_obs_or_fq) and
-maybe_obs_mod.dtype == arg_as_input_target_dtype # type: ignore[possibly-undefined]
+maybe_obs_mod.dtype == arg_as_input_target_dtype
 ):
 arg_as_input_act_obs_or_fq = maybe_obs_mod # type: ignore[assignment]
 existing_obs_node = maybe_obs_node

@@ -516,7 +516,7 @@ def register_multi_grad_hook(
 if tensor.requires_grad
 )

-return Handle(handles) # type: ignore[possibly-undefined]
+return Handle(handles)


 # NOTE [Allow mutation on tensors saved for backward]

@@ -746,4 +746,4 @@ def _engine_run_backward(t_outputs, *args, **kwargs):
 ) # Calls into the C++ engine to run the backward pass
 finally:
 if attach_logging_hooks:
-unregister_hooks() # type: ignore[possibly-undefined]
+unregister_hooks()

@@ -1148,7 +1148,7 @@ def _build_table(
 if evt.flops <= 0:
 row_values.append("--")
 else:
-row_values.append(f"{evt.flops * flops_scale:8.3f}") # type: ignore[possibly-undefined]
+row_values.append(f"{evt.flops * flops_scale:8.3f}")
 if has_stack:
 src_field = ""
 if len(evt.stack) > 0:

@@ -1176,7 +1176,7 @@ class _NnapiSerializer:
 shape=change_element(out_oper.shape, dim, out_dim_size)
 )

-if in_oper.dim_order == DimOrder.CHANNELS_LAST: # type: ignore[possibly-undefined]
+if in_oper.dim_order == DimOrder.CHANNELS_LAST:
 assert len(out_oper.shape) == 4
 nnapi_dim = [0, 3, 1, 2][dim]
 else:

@@ -1633,10 +1633,10 @@ class _NnapiSerializer:
 size_ctype, size_arg = self.get_constant_value(size_jit)

 if node.inputsSize() == 3:
-scale_ctype, scale_arg = self.get_constant_value(scale_jit) # type: ignore[possibly-undefined]
+scale_ctype, scale_arg = self.get_constant_value(scale_jit)
 else:
-scale_h_ctype, scale_h_arg = self.get_constant_value(scale_h_jit) # type: ignore[possibly-undefined]
+scale_h_ctype, scale_h_arg = self.get_constant_value(scale_h_jit)
-scale_w_ctype, scale_w_arg = self.get_constant_value(scale_w_jit) # type: ignore[possibly-undefined]
+scale_w_ctype, scale_w_arg = self.get_constant_value(scale_w_jit)

 # The only way for the 4-argument overload of upsample_nearest2d to
 # have been added to the graph without error is if the scale_h and

@@ -325,7 +325,7 @@ def make_graphed_callables(
 only_inputs=True,
 allow_unused=allow_unused_input,
 )
-del outputs, grad_inputs # type: ignore[possibly-undefined]
+del outputs, grad_inputs
 torch.cuda.synchronize()

 # All captures here share a mempool. To avoid replays corrupting each other's memory,

@@ -206,4 +206,4 @@ def get_chunk_sharding_params(sharding_dim_size, world_size, spec, rank):
 start_pos = current_offsets
 break
 current_offsets += chunk_size
-return start_pos, chunk_size # type: ignore[possibly-undefined]
+return start_pos, chunk_size
@@ -395,7 +395,7 @@ def _handle_row_wise_sharding(
 result = torch.nn.functional.embedding_bag(
 lookup_input,
 torch.cat([local_shard, padding_row]),
-offsets=offsets_list if offsets is not None else offsets, # type: ignore[possibly-undefined]
+offsets=offsets_list if offsets is not None else offsets,
 mode=mode if mode != "mean" else "sum",
 per_sample_weights=per_sample_weights,
 max_norm=max_norm,

@@ -541,7 +541,7 @@ def mark_data_parallel_shardings(
 # mark activation as sharded on batch dim
 node_sharding = node_strategies[0]

-node.meta["sharding"] = node_sharding # type: ignore[possibly-undefined]
+node.meta["sharding"] = node_sharding

 placeholder_idx += 1
 elif node.op == "call_function":

@@ -45,7 +45,7 @@ class GraphModuleTransformation:
 "iter_graph_main_gm": iter_gm.main_gm.print_readable(False),
 "iter_graph_cleanup_gm": iter_gm.cleanup_gm.print_readable(False),
 },
-graph_folder, # type: ignore[possibly-undefined]
+graph_folder,
 )

 return iter_gm

@@ -353,7 +353,7 @@ def _scatter_wait_result(
 gm.graph.node_replace_all_uses_with(orig_wait, wait_output_node)

 if last_split_reshape_node == split_node:
-last_split_reshape_node = wait_output_node # type: ignore[possibly-undefined]
+last_split_reshape_node = wait_output_node

 need_sort_nodes = sorted(need_sort_nodes, key=lambda node: node_indices[node])
 gm.graph.move_after(need_sort_nodes, last_split_reshape_node)

@@ -561,7 +561,7 @@ class IterGraph(fx.Graph):
 delete_user_cb,
 propagate_meta=propagate_meta,
 )
-return ret # type: ignore[possibly-undefined]
+return ret

 def node_add_user(self, node: fx.Node, user: Any) -> None:
 for graph in self._all_graphs:

@@ -607,8 +607,8 @@ class IterGraph(fx.Graph):
 "_foreach_add_",
 ):
 step_node = node
-self.node_add_user(optim_node, output_node) # type: ignore[possibly-undefined]
+self.node_add_user(optim_node, output_node)
-self.node_add_user(step_node, optim_node) # type: ignore[possibly-undefined]
+self.node_add_user(step_node, optim_node)

 def defunctionalize_optim(self) -> None:
 # TODO: remove this API after DCE is not used with IterGraph

@@ -624,8 +624,8 @@ class IterGraph(fx.Graph):
 "_foreach_add_",
 ):
 step_node = node
-optim_node.users.pop(output_node, None) # type: ignore[possibly-undefined]
+optim_node.users.pop(output_node, None)
-step_node.users.pop(optim_node, None) # type: ignore[possibly-undefined]
+step_node.users.pop(optim_node, None)

 def freeze_cross_iter_movement(self) -> None:
 self._freeze_cross_iter_movement = True

@@ -199,7 +199,7 @@ class OpDispatcher:
 if output_sharding.output_spec is None:
 if op_call == aten.equal.default:
 obj_list = [None for _ in range(dist.get_world_size())]
-dist.all_gather_object(obj_list, local_results) # type: ignore[possibly-undefined]
+dist.all_gather_object(obj_list, local_results)
 obj_list = list(filter(lambda x: x is not None, obj_list))
 # perform reduce on the collection with AND op
 local_results = functools.reduce(operator.and_, obj_list, True)

@@ -229,7 +229,7 @@ class OpDispatcher:
 assert len(out_dts) >= 1, "out variant should have at least one out arg"
 return tuple(out_dts) if len(out_dts) > 1 else out_dts[0]
 else:
-return self.wrap(local_results, output_sharding.output_spec) # type: ignore[possibly-undefined]
+return self.wrap(local_results, output_sharding.output_spec)

 @staticmethod
 def redistribute_local_args(

@@ -201,7 +201,7 @@ class Shard(Placement):
 )

 if is_padded:
-output = self._unpad_tensor(output, pad_sizes[my_coordinate[mesh_dim]]) # type: ignore[possibly-undefined]
+output = self._unpad_tensor(output, pad_sizes[my_coordinate[mesh_dim]])
 return output

 def _to_replicate_tensor(

@@ -236,7 +236,7 @@ class Shard(Placement):
 group=(mesh, mesh_dim),
 )
 if is_padded:
-unpad_size = full_chunk_size * num_chunks - logical_dim_size # type: ignore[possibly-undefined]
+unpad_size = full_chunk_size * num_chunks - logical_dim_size
 result = self._unpad_tensor(result, unpad_size)
 return result


@@ -69,7 +69,7 @@ class HybridModel(torch.nn.Module):
 # Make sure combined PS dimension is always bigger or equal than the FC input
 assert NUM_PS * EMBEDDING_DIM >= 512
 dim_normalizer = int(NUM_PS * EMBEDDING_DIM / 512)
-emb_lookups_reshaped = emb_lookups_cat.reshape( # type: ignore[possibly-undefined]
+emb_lookups_reshaped = emb_lookups_cat.reshape(
 [emb_lookups_cat.shape[0] * dim_normalizer, 512]
 )


@@ -195,7 +195,7 @@ def _run_trainer(emb_rref_list, rank):

 # Throw away warm-up measurements
 measurements = measurements[WARMUP_CYCLES:]
-return rank, measurements, batch_size # type: ignore[possibly-undefined]
+return rank, measurements, batch_size


 def run_worker(rank, world_size):

@@ -85,7 +85,7 @@ else:
 if cur_rank in mesh_1d:
 res_sub_mesh = sub_mesh

-res_sub_mesh._dim_group_infos = [device_mesh._dim_group_infos[mesh_dim]] # type: ignore[possibly-undefined]
+res_sub_mesh._dim_group_infos = [device_mesh._dim_group_infos[mesh_dim]]
 # Assign the current DeviceMesh as the parent of the child DeviceMesh.
 self.child_to_parent_mapping[res_sub_mesh] = device_mesh
 return res_sub_mesh
@ -1945,9 +1945,9 @@ def _coalescing_manager(
work = group._end_coalescing(device)

if async_ops:
cm.append(work)  # type: ignore[possibly-undefined]
cm.append(work)
else:
work.wait()  # type: ignore[possibly-undefined]
work.wait()


def batch_isend_irecv(p2p_op_list):

@ -2460,7 +2460,7 @@ def gather_object(obj, object_gather_list=None, dst=0, group=None):
# All ranks call gather with equal-sized tensors.
gather(
input_tensor,
gather_list=output_tensors if my_rank == dst else None,  # type: ignore[possibly-undefined]
gather_list=output_tensors if my_rank == dst else None,
dst=dst,
group=group,
)

@ -2560,7 +2560,7 @@ def broadcast_object_list(object_list, src=0, group=None, device=None):
# Note: torch.cat will do an extra memory copy to the current device, if the tensor_list
# has only one element, we can skip the copy.
if my_rank == src:
if len(tensor_list) == 1:  # type: ignore[possibly-undefined]
if len(tensor_list) == 1:
object_tensor = tensor_list[0]
else:
object_tensor = torch.cat(tensor_list)

@ -2663,8 +2663,8 @@ def scatter_object_list(
# Src rank broadcasts the maximum tensor size. This is because all ranks are
# expected to call into scatter() with equal-sized tensors.
if my_rank == src:
max_tensor_size = max(tensor_sizes)  # type: ignore[possibly-undefined]
max_tensor_size = max(tensor_sizes)
for tensor in tensor_list:  # type: ignore[possibly-undefined]
for tensor in tensor_list:
tensor.resize_(max_tensor_size)
else:
max_tensor_size = torch.tensor([0], dtype=torch.long, device=pg_device)

@ -2674,7 +2674,7 @@ def scatter_object_list(
output_tensor = torch.empty(max_tensor_size.item(), dtype=torch.uint8, device=pg_device)
scatter(
output_tensor,
scatter_list=None if my_rank != src else tensor_list,  # type: ignore[possibly-undefined]
scatter_list=None if my_rank != src else tensor_list,
src=src,
group=group,
)

@ -2683,7 +2683,7 @@ def scatter_object_list(
obj_tensor_size = torch.tensor([0], dtype=torch.long, device=pg_device)
scatter(
obj_tensor_size,
scatter_list=None if my_rank != src else tensor_sizes,  # type: ignore[possibly-undefined]
scatter_list=None if my_rank != src else tensor_sizes,
src=src,
group=group,
)
@ -51,7 +51,7 @@ class Event:
return data
if isinstance(data, str):
data_dict = json.loads(data)
data_dict["source"] = EventSource[data_dict["source"]]  # type: ignore[possibly-undefined]
data_dict["source"] = EventSource[data_dict["source"]]
return Event(**data_dict)

def serialize(self) -> str:

@ -105,7 +105,7 @@ class RdzvEvent:
return data
if isinstance(data, str):
data_dict = json.loads(data)
data_dict["node_state"] = NodeState[data_dict["node_state"]]  # type: ignore[possibly-undefined]
data_dict["node_state"] = NodeState[data_dict["node_state"]]
return RdzvEvent(**data_dict)

def serialize(self) -> str:

@ -126,7 +126,7 @@ def prof(fn=None, group: str = "torchelastic"):
put_metric(f"{key}.failure", 1, group)
raise
finally:
put_metric(f"{key}.duration.ms", get_elapsed_time_ms(start), group)  # type: ignore[possibly-undefined]
put_metric(f"{key}.duration.ms", get_elapsed_time_ms(start), group)
return result

return wrapper

@ -164,7 +164,7 @@ def profile(group=None):
publish_metric(
group,
f"{func.__name__}.duration.ms",
get_elapsed_time_ms(start_time),  # type: ignore[possibly-undefined]
get_elapsed_time_ms(start_time),
)
return result

@ -176,7 +176,7 @@ def _create_tcp_store(params: RendezvousParameters) -> TCPStore:
"The connection to the C10d store has failed. See inner exception for details."
) from exc

return store  # type: ignore[possibly-undefined]
return store


def _create_file_store(params: RendezvousParameters) -> FileStore:
@ -57,7 +57,7 @@ def find_free_port():
s.listen(0)
return s
except OSError as e:
s.close()  # type: ignore[possibly-undefined]
s.close()
print(f"Socket creation attempt failed: {e}")
raise RuntimeError("Failed to create a socket")

@ -1767,8 +1767,8 @@ class FlatParamHandle:
)
flat_param.data = flat_param._local_shard  # type: ignore[attr-defined]
if self._use_orig_params:
if skip_use_sharded_views:  # type: ignore[possibly-undefined]
if skip_use_sharded_views:
self._unsharded_flat_param_for_skipped_views = unsharded_flat_param  # type: ignore[possibly-undefined]
self._unsharded_flat_param_for_skipped_views = unsharded_flat_param
else:
self._use_sharded_views()
# For the post-forward reshard, we may try to use sharded gradient

@ -1776,7 +1776,7 @@ class FlatParamHandle:
# in `no_sync()`), but for the post-backward reshard, we delay the
# call to after the reduce-scatter.
if (
in_forward  # type: ignore[possibly-undefined]
in_forward
# Skip using gradient views if skipped using sharded views
# since exposing unsharded parameters with sharded gradients
# may be confusing to the user

@ -885,7 +885,7 @@ def _materialize_meta_module(
warnings.warn(
"Unable to call `reset_parameters()` for module on meta "
f"device with error {str(e)}. Please ensure that your module of"
f"type {type(module)} implements a `reset_parameters()` method."  # type: ignore[possibly-undefined]
f"type {type(module)} implements a `reset_parameters()` method."
)
raise e

@ -994,7 +994,7 @@ def _move_states_to_device(
param.grad.data = param.grad.to(device_from_device_id)
for buffer in buffers:
buffer.data = buffer.to(device_from_device_id)
elif current_device == cpu_device:  # type: ignore[possibly-undefined]
elif current_device == cpu_device:
_warn_cpu_init()

@ -1419,7 +1419,7 @@ def _convert_all_state_info(
)
gathered_state[name] = scalar_tensor_value

return dtype, state_buffers  # type: ignore[possibly-undefined]
return dtype, state_buffers


def _unflatten_orig_param_states(

@ -171,7 +171,7 @@ def _create_1d_device_mesh(device_mesh: DeviceMesh, tp_mesh_dim: int = 0) -> Dev
if cur_rank in mesh_1d:
res_sub_mesh = sub_mesh

res_sub_mesh._dim_group_infos = [device_mesh._dim_group_infos[tp_mesh_dim]]  # type: ignore[possibly-undefined]
res_sub_mesh._dim_group_infos = [device_mesh._dim_group_infos[tp_mesh_dim]]
return res_sub_mesh
@ -253,7 +253,7 @@ class ExportedProgram:
user_args, self.call_spec.in_spec, exact_structural_match=True
)  # type: ignore[assignment]
except Exception:
_, received_spec = pytree.tree_flatten(user_args)  # type: ignore[possibly-undefined]
_, received_spec = pytree.tree_flatten(user_args)
raise TypeError(  # noqa: TRY200
"Trying to flatten user inputs with exported input tree spec: \n"
f"{self.call_spec.in_spec}\n"

@ -998,7 +998,7 @@ class Partitioner:
if cost < min_cost:
node_pair = [node, n1]
min_cost = cost
return cost, node_pair  # type: ignore[possibly-undefined]
return cost, node_pair

# First use size_base_partition
self.size_based_partition()

@ -263,7 +263,7 @@ def split_const_subgraphs(
setattr(
split,
fx_const_folded_attrs_name,
torch.nn.ParameterList() if multiple_outputs else torch.nn.Parameter(),  # type: ignore[possibly-undefined]
torch.nn.ParameterList() if multiple_outputs else torch.nn.Parameter(),
)
for node in split.graph.nodes:
if node.op == "call_module" and node.target == const_mod_name:

@ -694,7 +694,7 @@ for name in math_op_names:
fn.__qualname__ = fn.__name__ = priv_sympy_name
setattr(current_module, priv_sympy_name, fn)

del fn, name, priv_sympy_name  # type: ignore[possibly-undefined]
del fn, name, priv_sympy_name


def _sympy_abs(a):

@ -753,7 +753,7 @@ for name in math_op_names:
sym_name = f"sym_{name}"
magic_methods[sym_name] = getattr(current_module, f"_sympy_{name}")

del name, sym_name, math_op_names, current_module  # type: ignore[possibly-undefined]
del name, sym_name, math_op_names, current_module


def sympy_is_contiguous(sizes, strides):

@ -68,8 +68,8 @@ def consistent(a, b):
p1 += 1
# We only need to check for variadic ends
# Variadic types are guaranteed to be the last element
return (isvariadic(cur_a) and p2 == len(b) or  # type: ignore[possibly-undefined]
return (isvariadic(cur_a) and p2 == len(b) or
isvariadic(cur_b) and p1 == len(a))  # type: ignore[possibly-undefined]
isvariadic(cur_b) and p1 == len(a))


def ambiguous(a, b):

@ -371,11 +371,11 @@ class _MinimizerBase:
# Compare results
names: Names = output_names
if output_names is None:
names = [str(v) for v in result_key]  # type: ignore[possibly-undefined]
names = [str(v) for v in result_key]

numeric_result, bool_result = self.compare_fn(a_result, b_result, names)

self.results[result_key] = numeric_result  # type: ignore[possibly-undefined]
self.results[result_key] = numeric_result
report.append(f"Numerical accuracy = {numeric_result}")
if not bool_result:
report.append(f"Result mismatch for {result_key}")

@ -575,7 +575,7 @@ class _SplitterBase:
else:
total_output_bytes += get_size_of_node(submod, node)[0]

map_arg(output_node.args, get_bytes)  # type: ignore[possibly-undefined]
map_arg(output_node.args, get_bytes)
qps = self.PCIe_BW / max(total_input_bytes, total_output_bytes)
reports += f"Total input size in bytes is {total_input_bytes}, total output size in bytes is {total_output_bytes},"
reports += f" theoretical max qps (bounds by PCIe bandwidth) for this submodule is {qps}.\n"

@ -305,7 +305,7 @@ def _replace_pattern(
first_user_node = n
break

with original_graph.inserting_before(first_user_node):  # type: ignore[possibly-undefined]
with original_graph.inserting_before(first_user_node):
copied_returning_nodes = original_graph.graph_copy(replacement_graph, val_map)

if isinstance(copied_returning_nodes, Node):
@ -650,14 +650,14 @@ def download_url_to_file(url: str, dst: str, hash_prefix: Optional[str] = None,
buffer = u.read(READ_DATA_CHUNK)
if len(buffer) == 0:
break
f.write(buffer)  # type: ignore[possibly-undefined]
f.write(buffer)
if hash_prefix is not None:
sha256.update(buffer)  # type: ignore[possibly-undefined]
sha256.update(buffer)
pbar.update(len(buffer))

f.close()
if hash_prefix is not None:
digest = sha256.hexdigest()  # type: ignore[possibly-undefined]
digest = sha256.hexdigest()
if digest[:len(hash_prefix)] != hash_prefix:
raise RuntimeError(f'invalid hash value (expected "{hash_prefix}", got "{digest}")')
shutil.move(f.name, dst)

@ -70,8 +70,8 @@ def fuser(name):
yield
finally:
if name in ["fuser1", "fuser3"]:  # NNC or oneDNN Graph
torch._C._jit_set_profiling_executor(old_profiling_executor)  # type: ignore[possibly-undefined]
torch._C._jit_set_profiling_executor(old_profiling_executor)
torch._C._get_graph_executor_optimize(old_profiling_mode)  # type: ignore[possibly-undefined]
torch._C._get_graph_executor_optimize(old_profiling_mode)
# recover the previous values
torch._C._jit_override_can_fuse_on_cpu(old_cpu_fuse)
torch._C._jit_override_can_fuse_on_gpu(old_gpu_fuse)

@ -254,7 +254,7 @@ def verify(model, args, loss_fn=torch.sum, devices=None):
if assert_compiled:
hits = compiled_fn.hits
out = model(*args)
if assert_compiled and compiled_fn.hits == hits:  # type: ignore[possibly-undefined]
if assert_compiled and compiled_fn.hits == hits:
raise RuntimeError("failed to use the compiled function")
if not isinstance(out, tuple):
out = (out,)

@ -280,7 +280,7 @@ def verify(model, args, loss_fn=torch.sum, devices=None):
assert model.has_trace_for(*args)

if is_module:
model.load_state_dict(saved_state)  # type: ignore[possibly-undefined]
model.load_state_dict(saved_state)
compiled_outs, compiled_grads = run_fwd_bwd(args, assert_compiled=True)

_verify_equal(uncompiled_outs, compiled_outs)

@ -1627,7 +1627,7 @@ def _std_var(
total = sum(x * x.conj(), dim, keepdim=keepdim, dtype=compute_dtype)
else:
total = sum(
x * x.conj(), dim, keepdim=keepdim, dtype=compute_dtype, mask=inmask  # type: ignore[possibly-undefined]
x * x.conj(), dim, keepdim=keepdim, dtype=compute_dtype, mask=inmask
)
if not keepdim:
count = count.reshape(total.shape)

@ -781,8 +781,8 @@ class SyncBatchNorm(_BatchNorm):
running_var,
self.eps,
exponential_average_factor,
process_group,  # type: ignore[possibly-undefined]
process_group,
world_size,  # type: ignore[possibly-undefined]
world_size,
)

@classmethod
@ -1604,9 +1604,9 @@ class Module:
# For now only forward hooks have the always_call option but perhaps
# this functionality should be added to full backward hooks as well.
for hook_id, hook in _global_forward_hooks.items():
if hook_id in _global_forward_hooks_always_called and hook_id not in called_always_called_hooks:  # type: ignore[possibly-undefined]
if hook_id in _global_forward_hooks_always_called and hook_id not in called_always_called_hooks:
try:
hook_result = hook(self, args, result)  # type: ignore[possibly-undefined]
hook_result = hook(self, args, result)
if hook_result is not None:
result = hook_result
except Exception as e:

@ -1615,12 +1615,12 @@ class Module:
continue

for hook_id, hook in self._forward_hooks.items():
if hook_id in self._forward_hooks_always_called and hook_id not in called_always_called_hooks:  # type: ignore[possibly-undefined]
if hook_id in self._forward_hooks_always_called and hook_id not in called_always_called_hooks:
try:
if hook_id in self._forward_hooks_with_kwargs:
hook_result = hook(self, args, kwargs, result)  # type: ignore[possibly-undefined]
hook_result = hook(self, args, kwargs, result)
else:
hook_result = hook(self, args, result)  # type: ignore[possibly-undefined]
hook_result = hook(self, args, result)
if hook_result is not None:
result = hook_result
except Exception as e:

@ -575,8 +575,8 @@ class RNN(RNNBase):
output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
return output_packed, self.permute_hidden(hidden, unsorted_indices)

if not is_batched:  # type: ignore[possibly-undefined]
if not is_batched:
output = output.squeeze(batch_dim)  # type: ignore[possibly-undefined]
output = output.squeeze(batch_dim)
hidden = hidden.squeeze(1)

return output, self.permute_hidden(hidden, unsorted_indices)

@ -888,8 +888,8 @@ class LSTM(RNNBase):
output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
return output_packed, self.permute_hidden(hidden, unsorted_indices)
else:
if not is_batched:  # type: ignore[possibly-undefined]
if not is_batched:
output = output.squeeze(batch_dim)  # type: ignore[possibly-undefined]
output = output.squeeze(batch_dim)
hidden = (hidden[0].squeeze(1), hidden[1].squeeze(1))
return output, self.permute_hidden(hidden, unsorted_indices)

@ -1111,8 +1111,8 @@ class GRU(RNNBase):
output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
return output_packed, self.permute_hidden(hidden, unsorted_indices)
else:
if not is_batched:  # type: ignore[possibly-undefined]
if not is_batched:
output = output.squeeze(batch_dim)  # type: ignore[possibly-undefined]
output = output.squeeze(batch_dim)
hidden = hidden.squeeze(1)

return output, self.permute_hidden(hidden, unsorted_indices)
@ -105,7 +105,7 @@ class _Orthogonal(Module):
Q = self.base @ Q
if transposed:
Q = Q.mT
return Q  # type: ignore[possibly-undefined]
return Q

@torch.autograd.no_grad()
def right_inverse(self, Q: torch.Tensor) -> torch.Tensor:

@ -293,7 +293,7 @@ def _create_node(
for _ in range(1, n_outputs):
node.addOutput()

node_ouputs = tuple(node.outputs())  # type: ignore[possibly-undefined]
node_ouputs = tuple(node.outputs())
assert len(node_ouputs) == n_outputs

aten = domain_op.startswith("aten::")

@ -1529,7 +1529,7 @@ def softmax(g: jit_utils.GraphContext, input, dim, dtype=None):
)

if is_transpose_required:
softmax = g.op("Transpose", softmax, perm_i=axes)  # type: ignore[possibly-undefined]
softmax = g.op("Transpose", softmax, perm_i=axes)
return softmax

# Apply max normalization.

@ -2467,7 +2467,7 @@ def log_softmax(g: jit_utils.GraphContext, input, dim, dtype=None):
"Cast", return_op, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type()
)
if is_transpose_required:
return_op = g.op("Transpose", return_op, perm_i=axes)  # type: ignore[possibly-undefined]
return_op = g.op("Transpose", return_op, perm_i=axes)
return return_op

@ -2978,7 +2978,7 @@ def native_layer_norm(
# mean and normalized, so we need to Cast it back
if is_type_half:
denominator = g.op(
"Cast", denominator, to_i=_type_utils.JitScalarType(input_dtype).onnx_type()  # type: ignore[possibly-undefined]
"Cast", denominator, to_i=_type_utils.JitScalarType(input_dtype).onnx_type()
)
rdenominator = g.op("Reciprocal", denominator)
else:

@ -4754,7 +4754,7 @@ def _generic_rnn(
reform_weights(g, w, hidden_size, reform_permutation) for w in weights
)
return tuple(
symbolic_helper._unsqueeze_helper(g, x, [0]) for x in (weight_ih, weight_hh)  # type: ignore[possibly-undefined]
symbolic_helper._unsqueeze_helper(g, x, [0]) for x in (weight_ih, weight_hh)
)

@_beartype.beartype

@ -4766,10 +4766,10 @@ def _generic_rnn(
weight_ih, weight_hh, bias_ih, bias_hh = (
reform_weights(g, w, hidden_size, reform_permutation) for w in weights
)
bias_concat = g.op("Concat", bias_ih, bias_hh, axis_i=0)  # type: ignore[possibly-undefined]
bias_concat = g.op("Concat", bias_ih, bias_hh, axis_i=0)
return tuple(
symbolic_helper._unsqueeze_helper(g, x, [0])
for x in (weight_ih, weight_hh, bias_concat)  # type: ignore[possibly-undefined]
for x in (weight_ih, weight_hh, bias_concat)
)

@_beartype.beartype

@ -4808,16 +4808,16 @@ def _generic_rnn(

inputs = [prev_output, weight_ih, weight_hh, bias_concat, sequence_lens]

inputs.append(retrieve_state(h0, *state_indices))  # type: ignore[possibly-undefined]
inputs.append(retrieve_state(h0, *state_indices))
if variant == "LSTM":
inputs.append(retrieve_state(c0, *state_indices))  # type: ignore[possibly-undefined]
inputs.append(retrieve_state(c0, *state_indices))

extra_kwargs = {} if unidirectional else {"direction_s": "bidirectional"}
if variant == "RNN":
if bidirectional:
activation = [nonlinearity, nonlinearity]  # type: ignore[possibly-undefined]
activation = [nonlinearity, nonlinearity]
else:
activation = [nonlinearity]  # type: ignore[possibly-undefined]
activation = [nonlinearity]

prev_output, h_out = g.op(
"RNN",

@ -4859,17 +4859,17 @@ def _generic_rnn(
else:
prev_output = symbolic_helper._squeeze_helper(g, prev_output, [1])

h_outs.append(h_out)  # type: ignore[possibly-undefined]
h_outs.append(h_out)
if variant == "LSTM":
c_outs.append(c_out)  # type: ignore[possibly-undefined]
c_outs.append(c_out)
if batch_first:
# seq, batch, num_directions * hidden_size -> batch, seq, num_directions * hidden_size
prev_output = g.op("Transpose", prev_output, perm_i=[1, 0, 2])
h_outs = h_out if num_layers == 1 else g.op("Concat", *h_outs, axis_i=0)  # type: ignore[possibly-undefined]
h_outs = h_out if num_layers == 1 else g.op("Concat", *h_outs, axis_i=0)
if variant == "RNN" or variant == "GRU":
return prev_output, h_outs
elif variant == "LSTM":
c_outs = c_out if num_layers == 1 else g.op("Concat", *c_outs, axis_i=0)  # type: ignore[possibly-undefined]
c_outs = c_out if num_layers == 1 else g.op("Concat", *c_outs, axis_i=0)
return prev_output, h_outs, c_outs
@ -199,7 +199,7 @@ class BasicEvaluation:
while (
current_kernel_index < len(cuda_kernel_events)
and (cuda_kernel_events[current_kernel_index].start_us()) * 1000
<= start_time  # type: ignore[possibly-undefined]
<= start_time
):
current_kernel_index += 1
current_queue_depth = spawned_kernel_index - current_kernel_index + 1

@ -207,7 +207,7 @@ class BasicEvaluation:

if hasattr(event, "start_us"):
queue_depth_list.append(
Interval(start_time, end_time, current_queue_depth)  # type: ignore[possibly-undefined]
Interval(start_time, end_time, current_queue_depth)
)
elif hasattr(event, "start_time_ns"):
self.metrics[EventKey(event)].queue_depth = current_queue_depth

@ -44,7 +44,7 @@ def quantized_weight_reorder_for_mixed_dtypes_linear_cutlass(
else:
outp = weight

ncols, nrows = outp.shape  # type: ignore[possibly-undefined]
ncols, nrows = outp.shape
assert nrows % (32 if dtypeq == torch.quint4x2 else 64) == 0
assert ncols % 64 == 0

@ -134,11 +134,11 @@ def sparse_semi_structured_from_dense_cutlass(dense):
idxs1 = bit2 | (bit3.to(torch.int64) << 1)

if dense.dtype != torch.float:
sparse0 = dense_4.gather(-1, idxs0.unsqueeze(-1))  # type: ignore[possibly-undefined]
sparse0 = dense_4.gather(-1, idxs0.unsqueeze(-1))
sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
else:
sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(m, k // 2)  # type: ignore[possibly-undefined]
sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(m, k // 2)

meta_4 = idxs0 | (idxs1 << 2)
meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)

@ -163,7 +163,7 @@ def sparse_semi_structured_from_dense_cutlass(dense):
)

# Reorder meta tensor elements.
meta_reordered = meta.new_empty((m * meta_ncols,))  # type: ignore[possibly-undefined]
meta_reordered = meta.new_empty((m * meta_ncols,))
meta_offsets = _calculate_meta_reordering_scatter_offsets(
m, meta_ncols, meta_dtype, device
)
@ -1662,7 +1662,7 @@ if has_triton():
acc_block = tl.zeros((TILE_M, TILE_N), dtype=dot_out_dtype)

if is_compressed:
A_ptr += r0 * blocks_stride_P  # type: ignore[possibly-undefined]
A_ptr += r0 * blocks_stride_P
for _ in range(nnz):
q = tl.load(q_ptr)
B = tl.load(B_ptr + q)

@ -1889,7 +1889,7 @@ if has_triton():

# alpha is never 0
if beta_is_nonzero:
output_acc_block = tl.load(input_ptrs).to(acc_dtype)  # type: ignore[possibly-undefined]
output_acc_block = tl.load(input_ptrs).to(acc_dtype)
if not (beta_is_one and alpha_is_one):
beta_alpha = beta / alpha
output_acc_block *= beta_alpha

@ -76,7 +76,7 @@ def _output_csv(file, results):
dim_str = str(dim)
shape_str = 'x'.join(str(s) for s in shape)

print(name, device, measurement.task_spec.num_threads, numel, shape_str, contiguous, dim_str,  # type: ignore[possibly-undefined]
print(name, device, measurement.task_spec.num_threads, numel, shape_str, contiguous, dim_str,
measurement.mean * 1e6, measurement.median * 1e6, measurement.iqr * 1e6,
sep=',', file=file)

@ -701,7 +701,7 @@ class _ValgrindWrapper:
if fn_match:
ir_str, file_function = fn_match.groups()
ir = int(ir_str.replace(",", ""))
if ir == program_totals:  # type: ignore[possibly-undefined]
if ir == program_totals:
# Callgrind includes some top level red herring symbols when
# a program dumps multiple profiles.
continue

@ -1427,7 +1427,7 @@ def _checkpoint_without_reentrant_generator(
new_frame.forward_completed = True

if getattr(device_module, "_initialized", False) and \
preserve_rng_state and not had_device_in_fwd:  # type: ignore[possibly-undefined]
preserve_rng_state and not had_device_in_fwd:
# Device was not initialized before running the forward, so we didn't
# stash the device state.
raise RuntimeError(

@ -2391,7 +2391,7 @@ def _write_ninja_file(path,
# 'Blocks' should be separated by newlines, for visual benefit.
blocks = [config, flags, compile_rule]
if with_cuda:
blocks.append(cuda_compile_rule)  # type: ignore[possibly-undefined]
blocks.append(cuda_compile_rule)
blocks += [devlink_rule, link_rule, build, devlink, link, default]
content = "\n\n".join("\n".join(b) for b in blocks)
# Ninja requires a new lines at the end of the .ninja file
@ -305,7 +305,7 @@ def _worker_loop(dataset_kind, dataset, index_queue, data_queue, done_event,
init_exception = None
else:
try:
data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
data = fetcher.fetch(index)
except Exception as e:
if isinstance(e, StopIteration) and dataset_kind == _DatasetKind.Iterable:
data = _IterableDatasetStopIteration(worker_id)

@ -1360,7 +1360,7 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
# not found (i.e., didn't break)
return

self._index_queues[worker_queue_idx].put((self._send_idx, index))  # type: ignore[possibly-undefined]
self._index_queues[worker_queue_idx].put((self._send_idx, index))
self._task_info[self._send_idx] = (worker_queue_idx,)
self._tasks_outstanding += 1
self._send_idx += 1

@ -210,7 +210,7 @@ class _ForkerIterDataPipe(IterDataPipe, _ContainerTemplate):
raise BufferError("ForkerIterDataPipe buffer overflow," +
f"buffer size {self.buffer_size} is insufficient.")

yield self.copy_fn(return_val)  # type: ignore[possibly-undefined]
yield self.copy_fn(return_val)
finally:
self._child_stop[instance_id] = True
# Cleanup _datapipe_iterator for the case that fork exits earlier

@ -907,7 +907,7 @@ class SummaryWriter:
else:
# Handles cnn.CNNModelHelper, model_helper.ModelHelper
current_graph = model_to_graph_def(model)
event = event_pb2.Event(graph_def=current_graph.SerializeToString())  # type: ignore[possibly-undefined]
event = event_pb2.Event(graph_def=current_graph.SerializeToString())
self._get_file_writer().add_event(event)

@staticmethod

@ -717,10 +717,10 @@ resize_out(out, sizes, strides, options);
f"{textwrap.indent(class_ctor_str, indent)}",
f"{textwrap.indent(self.gen_class_set_output_functions(k, parent_class, generate_super), indent)}",
" const Tensor& maybe_get_output(int64_t output_idx) override {",
f" return {output_value};\n",  # type: ignore[possibly-undefined]  # TODO: audit
f" return {output_value};\n",
" }",
f" std::array<{output_type}, {len(f.func.returns)}> outputs_;",  # type: ignore[possibly-undefined]  # TODO: audit
f" std::array<{output_type}, {len(f.func.returns)}> outputs_;",
f"{textwrap.indent(proxy_field, indent)}",  # type: ignore[possibly-undefined]  # TODO: audit
f"{textwrap.indent(proxy_field, indent)}",
f"{textwrap.indent(guard_field, indent)}",
"};",
)

@ -962,7 +962,7 @@ return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), si
else:
refs = ", ".join(a.name for a in f.func.arguments.out)
ret_expr = f"std::forward_as_tuple({refs})"
sig_body.append(f"return {ret_expr};")  # type: ignore[possibly-undefined]  # TODO: audit
sig_body.append(f"return {ret_expr};")

sig_body_str = "\n".join(sig_body)
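For context, a minimal illustrative sketch (not taken from the PyTorch sources; the function name is hypothetical) of the pattern these removed comments were suppressing: mypy's optional possibly-undefined error code flags names that are assigned on only some branches, and a per-line `# type: ignore[possibly-undefined]` silences that single report.

# Hypothetical example for illustration only; `maybe_pick` is not a PyTorch function.
def maybe_pick(flag: bool) -> int:
    if flag:
        value = 1
    # With `enable_error_code = possibly-undefined`, mypy reports here that
    # `value` may be undefined when `flag` is False; the per-line suppression
    # removed throughout this revert looks like the following:
    return value  # type: ignore[possibly-undefined]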