[1/N] Remove unused loop variables (#166258)

This PR removes loop variables that are never used inside their loop bodies; the typical patterns are sketched below, after the change summary.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166258
Approved by: https://github.com/Lucaskabela, https://github.com/mlazos
Yuanyuan Chen 2025-10-30 12:22:25 +00:00 committed by PyTorch MergeBot
parent 369f2d6951
commit 2de4cf2102
47 changed files with 73 additions and 101 deletions
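
The refactor follows a few recurring patterns, shown here as a minimal sketch on a hypothetical `items` dict (illustrative only, not code taken from this diff): drop `enumerate()`/`.items()` when the index or key is unused, switch to `.values()`/`.keys()` when only one side of a dict is needed, or rename the variable with a leading underscore when it must stay in place.

    items = {"a": 1, "b": 2}

    # Before: neither the index i nor the key k is used in the loop body.
    for i, (k, v) in enumerate(items.items()):
        print(v)

    # After: iterate only over what is actually needed ...
    for v in items.values():
        print(v)

    # ... or keep the position but mark the variable as intentionally unused.
    for _idx, v in enumerate(items.values()):
        print(v)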

View File

@@ -1791,7 +1791,7 @@ def rewrite_signature(
     for i, val in enumerate(sources):
         dict_of_source_vals[id(val)] = i
-    for i, val in enumerate(candidates):
+    for val in candidates:
         if isinstance(val, tuple(common_constant_types)):
             matched_elements_positions.append(None)
         elif id(val) not in dict_of_source_vals:

View File

@@ -319,7 +319,7 @@ class GuardManagerWrapper:
             is_diff_guard_node = (
                 node.get_source() in self.diff_guard_sources or node.fail_count() > 0
             )
-            for idx, (key_mgr, val_mgr) in sorted(
+            for _idx, (key_mgr, val_mgr) in sorted(
                 node.get_key_value_managers().items()
             ):
                 is_diff_guard_node |= visit(key_mgr) | visit(val_mgr)
@@ -442,7 +442,7 @@ class GuardManagerWrapper:
             is_subtree_tag_safe = True
             # Recurse to get the tag safe roots from subtree.
-            for idx, (key_mgr, val_mgr) in sorted(
+            for _idx, (key_mgr, val_mgr) in sorted(
                 node.get_key_value_managers().items()
             ):
                 if key_mgr is not None:
@@ -450,9 +450,7 @@ class GuardManagerWrapper:
                 if val_mgr is not None:
                     tag_safe_roots.extend(visit(val_mgr))
-            for idx, (key_mgr, val_mgr) in sorted(
-                node.get_key_value_managers().items()
-            ):
+            for key_mgr, val_mgr in node.get_key_value_managers().values():
                 if key_mgr:
                     is_subtree_tag_safe &= key_mgr.is_tag_safe()

View File

@@ -289,9 +289,7 @@ class OptimizerVariable(UserDefinedObjectVariable):
         params_vt = group_vt.getitem_const(tx, ConstantVariable.create("params"))
         all_static = True
         non_static_grads = []
-        for p_ind, (p, p_vt) in enumerate(
-            zip(group["params"], params_vt.unpack_var_sequence(tx))
-        ):
+        for p, p_vt in zip(group["params"], params_vt.unpack_var_sequence(tx)):
             param_source = p_vt.source
             self.tensor_to_source[p] = param_source
             grad_source = GradSource(
@@ -322,12 +320,12 @@ class OptimizerVariable(UserDefinedObjectVariable):
         # We have to again iterate over the state dict to collect the
         # tensor_to_source dict. This is used for the finalizer.
-        for idx, (p, value) in enumerate(self.value.state.items()):
+        for idx, value in enumerate(self.value.state.values()):
             p_state_source = DictGetItemSource(
                 state_source, ConstDictKeySource(state_source, idx)
             )
             tx.output.guard_on_key_order.add(p_state_source)
-            for inner_idx, (k, v) in enumerate(value.items()):
+            for inner_idx, v in enumerate(value.values()):
                 if (
                     isinstance(v, torch.Tensor)
                     and v not in self.grad_to_source

View File

@@ -240,7 +240,7 @@ def run_functionalized_fw_and_collect_metadata(
     # Inspect the state of the input tensor functional wrapper to detect input mutation info
     # If inp[i] has a metadata-only mutation, then maybe_inputs_with_mutated_metadata[i] contains the updated version
-    for i, (arg, f_arg) in enumerate(zip(flat_args, flat_f_args)):
+    for arg, f_arg in zip(flat_args, flat_f_args):
         # NB: Mutation of non-contiguous tensor subclass input can result in a mismatch in
         # strides between the functionalized arg inner tensors and non-functionalized arg inner
         # tensors. This is a problem as the inner tensor stride change may not be reflected

View File

@@ -2041,7 +2041,7 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
         assert len(meta.attrs) == len(runtime_subclass_keys)
         leaves = []
-        for i, (attr, attr_meta) in enumerate(meta.attrs.items()):
+        for attr, attr_meta in meta.attrs.items():
             elem = getattr(x, attr)
             new_elem, elem_leaves = AOTDispatchAutograd.process_runtime_tangent(
                 elem, attr_meta

View File

@@ -98,7 +98,7 @@ def unwrap_tensor_subclass_parameters(module: torch.nn.Module) -> torch.nn.Modul
             module, name, UnwrapTensorSubclass()
         )
-    for name, child in module.named_children():
+    for child in module.children():
         unwrap_tensor_subclass_parameters(child)
     return module

View File

@@ -1481,9 +1481,7 @@ def functionalize_rng_ops(
         )
     )
-    for rng_count, (base_node, node_pair) in enumerate(
-        recomputable_rng_ops_map.items()
-    ):
+    for rng_count, node_pair in enumerate(recomputable_rng_ops_map.values()):
         # Step 2 - Modify the fwd pass such that
         fw_node = node_pair["fwd"]
         bw_node = node_pair["bwd"]
@@ -2714,9 +2712,7 @@ def thread_graphsafe_rng_from_hops(module, is_backward):
         subgraph = getattr(module, hop_node.args[0].target)
         if isinstance(subgraph, fx.GraphModule):
             new_rng_inputs = []
-            for idx, placeholder_node in enumerate(
-                subgraph.graph.find_nodes(op="placeholder")
-            ):
+            for placeholder_node in subgraph.graph.find_nodes(op="placeholder"):
                 if rng_string in placeholder_node.name:
                     # Found a rng state placeholder in the hop graph, lets add
                     # the corresponding node in the outer graph

View File

@@ -116,7 +116,7 @@ def temporarily_restore_interpreter_stack(stack):
             pushed.append(s)
         yield
     finally:
-        for s in reversed(pushed):
+        for _ in reversed(pushed):
             # TODO: would be nice to assert that the layers are the same, but
             # Python object identity is not preserved
             pop_dynamic_layer_stack()

View File

@@ -907,7 +907,7 @@ def diff_tensor_meta(
         try:
             if val1 != val2:
                 pair_diffs.append(f"'{meta_name}: {val1} vs {val2}'")
-        except GuardOnDataDependentSymNode as _:
+        except GuardOnDataDependentSymNode:
             pair_diffs.append(f"'{meta_name}: {val1} vs {val2}'")
             continue
     return pair_diffs
@@ -1197,7 +1197,7 @@ def materialize_callable_in_args(op: HopInstance, args, kwargs):
     # call_op preserves ordering of proxies via schema
     materialized_args = []
-    for i, (proxy, arg) in enumerate(zip(arg_proxies, schema.arguments)):
+    for i, proxy in enumerate(arg_proxies):
         if (
             isinstance(proxy, torch.fx.Node)
             and proxy.op == "get_attr"

View File

@@ -316,7 +316,7 @@ def while_loop_dense(
     if stack_output:
         outs: list[torch.Tensor] = []
-        for i, out in enumerate(outputs):
+        for out in outputs:
             outs.append(torch.stack(out, dim=0))
         return tuple(outs)

View File

@@ -2606,7 +2606,7 @@ def custom_op_wrapper(op: str, *args: Any) -> list[c_void_p] | c_void_p | None:
     if isinstance(result, (list, tuple)):
         # unsafe_alloc_void_ptrs_from_tensors expects result contains tensor only
         result = [torch.tensor([]) if r is None else r for r in result]
-        for i, r in enumerate(result):
+        for r in result:
             assert isinstance(r, torch.Tensor), op + " returns a list of non-tensors"
         return torch._C._aoti.unsafe_alloc_void_ptrs_from_tensors(result)  # type: ignore[arg-type]

View File

@@ -895,7 +895,7 @@ class MetalKernel(SIMDKernel):
             else:
                 dtype_str = self.dtype_to_str(dtype)
                 code.writeline(f"constant {dtype_str}* {inner},")
-        for outer, inner in self.args.sizevars.items():
+        for inner in self.args.sizevars.values():
             code.writeline(f"constant long& {inner},")
         # Write dynamic values as inputs

View File

@@ -218,7 +218,7 @@ class MultiKernel:
         # the multi call kernel.
         multi_call_args = call_args
         multi_call_arg_types = arg_types
-        for i, kernel in enumerate(self.kernels):
+        for kernel in self.kernels:
             additional_call_args, additional_arg_types = (
                 kernel.additional_call_args_and_types()
             )

View File

@@ -717,7 +717,7 @@ class ComboKernel(Kernel):
         self, name: str, call_args: list[Any], arg_types: list[Any]
     ) -> None:
         for num, sub_kernel in enumerate(self.sub_kernels):
-            for i, tree in enumerate(sub_kernel.range_trees):
+            for tree in sub_kernel.range_trees:
                 numel_name = f"{tree.prefix}numel_{num}"
                 if numel_name not in self.dynamic_shape_args:
                     continue
@@ -735,7 +735,7 @@ class ComboKernel(Kernel):
     def kernel_benchmark_extra_args(self) -> list[str]:
         extra_args = []
         for num, sub_kernel in enumerate(self.sub_kernels):
-            for i, tree in enumerate(sub_kernel.range_trees):
+            for tree in sub_kernel.range_trees:
                 numel_name = f"{tree.prefix}numel_{num}"
                 if numel_name not in self.dynamic_shape_args:
                     continue
@@ -1018,7 +1018,7 @@ class ComboKernel(Kernel):
         for num, sub_kernel in enumerate(self.sub_kernels):
             meta[f"no_x_dim_{num}"] = sub_kernel.no_x_dim
-            for i, tree in enumerate(sub_kernel.range_trees):
+            for tree in sub_kernel.range_trees:
                 # pyrefly: ignore [missing-argument]
                 if not tree.is_reduction:
                     numel_name = f"{tree.prefix}numel_{num}"

View File

@@ -3604,16 +3604,12 @@ class PythonWrapperCodegen(CodeGen):
         self.writeline("if not should_loop:")
         if stack_output:
             # Handle the case when loop never executes
-            for i, (carried_input, carried_buf) in enumerate(
-                zip(outer_carried_inputs, while_loop.carried_inputs)
-            ):
+            for i, carried_input in enumerate(outer_carried_inputs):
                 self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph))
                 self.writeline(f"{name}[{i}] = {carried_input}.unsqueeze(0).clone()")
                 self.writeline(ExitSubgraphLine(self))
         else:
-            for i, (carried_input, carried_buf) in enumerate(
-                zip(outer_carried_inputs, while_loop.carried_inputs)
-            ):
+            for i, carried_input in enumerate(outer_carried_inputs):
                 self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph))
                 self.writeline(f"{name}[{i}] = {carried_input}.clone()")
                 self.writeline(ExitSubgraphLine(self))

View File

@@ -424,10 +424,7 @@ def _reorder_communication_preserving_peak_memory_internal(
             return
         # Candidate becomes last use of some bufs
-        for (
-            gn,
-            bufs,
-        ) in group_n_to_bufs_after_swap_dealloc_by_candidate.items():
+        for bufs in group_n_to_bufs_after_swap_dealloc_by_candidate.values():
             for buf in bufs:
                 buf_to_snode_last_use[buf] = candidate
@@ -840,7 +837,7 @@ def _schedule_for_comm(
         else:
             schedule(snode)
-    for snode, deps in unmet_deps.items():
+    for deps in unmet_deps.values():
         assert len(deps) == 0, (
             f"Detected unscheduled nodes. Nodes with unmet dependencies: {unmet_deps}"
         )
@@ -1552,11 +1549,8 @@ Graph: {graph}
             node.args = new_args
    # Delete `fsdp.copy_(unsharded_param, Y)` nodes
-    for (
-        unsharded_param,
-        fsdp_copy_node_idxes,
-    ) in unsharded_param_to_fsdp_copy_node_idxes.items():
-        for i, fsdp_copy_node_idx in enumerate(fsdp_copy_node_idxes):
+    for fsdp_copy_node_idxes in unsharded_param_to_fsdp_copy_node_idxes.values():
+        for fsdp_copy_node_idx in fsdp_copy_node_idxes:
             fsdp_copy_node = node_list[fsdp_copy_node_idx]
             graph.erase_node(fsdp_copy_node)

View File

@@ -46,7 +46,7 @@ def _debug_iterative_memory_recompute(
     if iter_cm != new_cm:
         log = "ITERATIVE CURR MEMORY CANDIDATE DOES NOT MATCH"
         iterative_recompute_error = True
-        for i, gn in enumerate(gns):
+        for gn in gns:
             iter_gnm = iter_curr_memory[gn]
             new_gnm = est_curr_memory[gn]
             if iter_gnm != new_gnm:
@@ -65,7 +65,7 @@ def _debug_iterative_memory_recompute(
                 f"\nCANDIDATE_NEW_ALLOCFREE:{snodes_allocfree[candidate]}"
             )
     peak_log = ""
-    for i, (pre, post) in enumerate(snodes_curr_memory):
+    for i, (pre, _post) in enumerate(snodes_curr_memory):
         if est_peak_memory == pre:
             n = snodes[i]
             peak_log = (

View File

@@ -454,7 +454,7 @@ def decompose_map_to_while_loop(gm: torch.fx.GraphModule):
     graph_pass.apply(gm)
-    for node in gm.graph.find_nodes(
+    for _node in gm.graph.find_nodes(
         op="call_function", target=torch.ops.higher_order.map_impl
     ):
         raise AssertionError("map is not lowered to while_loop")
@@ -666,7 +666,7 @@ def decompose_scan_to_while_loop(gm: torch.fx.GraphModule):
     graph_pass.apply(gm)
-    for node in gm.graph.find_nodes(
+    for _node in gm.graph.find_nodes(
         op="call_function", target=torch.ops.higher_order.scan
     ):
         raise AssertionError("scan is not lowered to while_loop")
@@ -1265,7 +1265,7 @@ def decompose_triton_kernel_wrapper_functional(graph):
     graph_pass.apply(graph)
-    for node in graph.find_nodes(
+    for _ in graph.find_nodes(
         op="call_function",
         target=torch.ops.higher_order.triton_kernel_wrapper_functional,
     ):

View File

@@ -8770,9 +8770,7 @@ class WhileLoop(ExternKernel):
         seen_buffers: OrderedSet[int] = OrderedSet()
         result: list[Union[IRNode, TensorBox, ShapeAsConstantBuffer]] = []
-        for i, (original_input, unwrapped_buffer) in enumerate(
-            zip(carried_inputs, unwrapped_buffers)
-        ):
+        for original_input, unwrapped_buffer in zip(carried_inputs, unwrapped_buffers):
             if id(unwrapped_buffer) in seen_buffers:
                 result.append(ExternKernel.copy_input(original_input))
             else:

View File

@@ -743,7 +743,7 @@ class _TargetArgsExpr(_TargetExpr):
         assert len(node_items) == len(self_items)
         m = Match(ctx, self)
-        for i, pattern, child_node in zip(itertools.count(), self_items, node_items):
+        for pattern, child_node in zip(self_items, node_items):
             if isinstance(pattern, PatternExpr):
                 child_match = ctx.match(pattern, child_node)
                 if not is_match(child_match):

View File

@@ -2850,7 +2850,7 @@ class Scheduler:
         # NB: None means that the dependency is on an input. Don't actually
         # generate a dependency because if we do, Inductor will start trying
         # to free the unbacked int but that's pointless
-        for name, val in V.graph.graph_inputs.items():
+        for val in V.graph.graph_inputs.values():
             if isinstance(val, sympy.Expr):
                 for fs in val.free_symbols:
                     unbacked_symbol_to_origin_node[fs] = None
@@ -3550,9 +3550,7 @@ class Scheduler:
         future_choices: list[tuple[Any, Optional[LambdaFuture], ModuleType]] = []
         for hint_override in config.multi_kernel_hints:
             choice_timings = multi_node.choice_timings(hint_override)
-            for choice, unfused_time in sorted(
-                choice_timings.items(), key=lambda x: x[1]
-            ):
+            for choice, _ in sorted(choice_timings.items(), key=lambda x: x[1]):
                 if not isinstance(
                     choice, torch._inductor.select_algorithm.TritonTemplateCaller
                 ):

View File

@@ -425,7 +425,7 @@ def apply_var_mapping(
         new_ranges, norm_pw_vars + norm_red_vars, strict=True
     ):
         range_vars = []
-        for i in range(len(new_range)):
+        for _ in range(len(new_range)):
            range_vars.append(flat_vars[count])
            count += 1

View File

@@ -348,7 +348,7 @@ def _do_bench_using_profiling(
         ]
     ) as p:
         # Benchmark
-        for i in range(n_repeat):
+        for _ in range(n_repeat):
             # we clear the L2 cache before each run
             cache.zero_()
             # record time of `fn`

View File

@@ -3118,7 +3118,7 @@ def _validate_symbolic_output_for_caching(
     if is_tracing:
         # Check for SymNode types in PROXY mode - this should bypass caching
         # regardless of whether symbols are known or not
-        for node in _iterate_nodes(output):
+        for _ in _iterate_nodes(output):
             raise _BypassDispatchCache("Proxy mode with SymNode output")
     else:
         # Check for unrepresented symbols in tensor expressions

View File

@@ -137,7 +137,7 @@ def _get_logger_dict_helper(
     def get_prefix(prefix):
         return prefix if prefix == "" else prefix + "."
-    for name, child in mod.named_children():
+    for child in mod.children():
         if isinstance(child, Logger):
             target_dict[get_prefix(prefix) + "stats"] = child.stats
             break

View File

@@ -909,8 +909,7 @@ def create_a_shadows_b(
     # is added
     prev_node_c_list = [env_c[arg.name] for arg in prev_node_b]
-    for arg_idx, arg in enumerate(prev_node_b):
-        prev_node_c = prev_node_c_list[arg_idx]
+    for arg_idx, prev_node_c in enumerate(prev_node_c_list):
         env_c[prev_node_c.name] = _insert_logger_after_node(
             prev_node_c,
             gm_b,

View File

@@ -151,6 +151,6 @@ def bias_correction(
             bias.data = updated_bias
     # Resets the data contained in the loggers
-    for name, submodule in quantized_model.named_modules():
+    for submodule in quantized_model.modules():
         if isinstance(submodule, MeanShadowLogger):
             submodule.clear()

View File

@@ -297,7 +297,7 @@ def _get_numerical_jacobian(
     inp_indices = [
         i for i, a in enumerate(target) if is_tensor_like(a) and a.requires_grad
     ]
-    for i, (inp, inp_idx) in enumerate(zip(_iter_tensors(target, True), inp_indices)):
+    for inp, inp_idx in zip(_iter_tensors(target, True), inp_indices):
         jacobians += [
             get_numerical_jacobian_wrt_specific_input(
                 fn,
@@ -549,7 +549,7 @@ def _get_analytical_jacobian_forward_ad(
     with fwAD.dual_level():
         fw_grads = []
         dual_inputs = []
-        for i, inp in enumerate(inputs):
+        for inp in inputs:
            if is_tensor_like(inp) and inp.requires_grad:
                if inp.layout == torch._mkldnn:  # type: ignore[attr-defined]
                    raise ValueError(
@@ -1275,7 +1275,7 @@ def _test_undefined_forward_mode(func, outputs, inputs):
            tensor_indices.add(i)
            dual_inputs.append(inp)
-    for i, (fw_grad, u) in enumerate(zip(fw_grads, all_u)):
+    for fw_grad, u in zip(fw_grads, all_u):
        fw_grad.copy_(u.view_as(fw_grad))
    for idx, inp in enumerate(inputs):

View File

@@ -41,7 +41,7 @@ class _PseudoZipFile:
         pickle.dump(entries, f, protocol=DEFAULT_PROTOCOL)
-        for key, (data, length) in self.records.items():
+        for data, _ in self.records.values():
             if isinstance(data, bytes):
                 f.write(data)
             elif isinstance(data, str):

View File

@@ -578,7 +578,7 @@ def _load_model_state_dict(
     assign = False
     if info.broadcast_from_rank0 or info.full_state_dict:
         devices = set()
-        for key, value in local_state_dict.items():
+        for value in local_state_dict.values():
             if torch.is_tensor(value) and value.dim() > 0:
                 devices.add(value.device)
         # In lora state_dict, there could be multiple devices, with meta device inside.

View File

@@ -2087,14 +2087,14 @@ class FlatParamHandle:
                 param.grad.data = view
             else:
                 param.grad = view
-        for i, (
+        for (
             param_name,
             module,
             module_name,
             prim_param_name,
             prim_module,
             _,
-        ) in enumerate(self.flat_param._shared_param_infos):
+        ) in self.flat_param._shared_param_infos:
             _p_assert(
                 hasattr(module, param_name),
                 f"{module_name + '.' + param_name if module_name else param_name} is missing",
@@ -2171,11 +2171,8 @@ class FlatParamHandle:
             param.data = flat_param[offset : offset + numel_in_shard]
         if self.flat_param._shared_params is None:
             raise AssertionError("Expected _shared_params to be not None")
-        for i, (
-            param,
-            (param_name, module, _, prim_param_name, prim_module, _),
-        ) in enumerate(
-            zip(self.flat_param._shared_params, self.flat_param._shared_param_infos)
+        for param, (param_name, module, _, prim_param_name, prim_module, _) in zip(
+            self.flat_param._shared_params, self.flat_param._shared_param_infos
         ):
             self._setattr_param(module, param_name, param)
             prim_param = getattr(prim_module, prim_param_name)
@@ -2388,14 +2385,14 @@ class FlatParamHandle:
         # TODO: If we want to handle shared parameters, we need to re-generate
         # the shared parameter data structures in case sharedness changed.
-        for i, (
+        for (
             param_name,
             module,
             _,
             prim_param_name,
             prim_module,
             _,
-        ) in enumerate(flat_param._shared_param_infos):
+        ) in flat_param._shared_param_infos:
             if getattr(module, param_name) is not getattr(prim_module, prim_param_name):
                 raise NotImplementedError(
                     "Changing shared parameters is not supported yet"

View File

@@ -924,7 +924,7 @@ class Pipe(torch.nn.Module):
             pass
         # This is done by (1) `_sink_params` at each submodule;
-        for name, submod in split.named_children():
+        for submod in split.children():
             if isinstance(submod, fx.GraphModule):
                 _sink_params(submod, inputs_to_state, [])
                 submod.graph.lint()

View File

@@ -967,7 +967,7 @@ def distribute_module(
     if partition_fn is None:
         # if partition_fn not specified, we by default replicate
         # all module params/buffers
-        for name, submod in module.named_modules():
+        for submod in module.modules():
             replicate_module_params_buffers(submod, device_mesh)
     else:
         # apply partition_fun to submodules

View File

@@ -170,7 +170,7 @@ def gen_einsum_strategies(
     # linearity strategy
     if linearity:
         linearity_placement_list: list[Placement] = [Partial()]
-        for input_dim in input_dims:
+        for _ in input_dims:
             linearity_placement_list.append(Partial())
         strategies_over_one_mesh_dim.append(linearity_placement_list)

View File

@@ -1332,7 +1332,7 @@ def refine_dynamic_shapes_from_suggested_fixes(
             roots.add(c.root.__name__)  # type: ignore[attr-defined]
    # check keys are existing dims or new roots
-    for k, c in shape_fixes.items():
+    for k in shape_fixes.keys():
        assert k in name_to_dim or k in roots
    # cache so we don't produce multiple derived dim objects

View File

@@ -101,11 +101,11 @@ def broadcast_types(t1, t2):
     # We make the types the same length which is the first requirement
     # for consistency
     if s1 > s2:
-        for i in range(s1 - s2):
+        for _ in range(s1 - s2):
             new_t2.insert(0, 1)
     elif s2 > s1:
-        for i in range(s2 - s1):
+        for _ in range(s2 - s1):
             new_t1.insert(0, 1)
     # we replace occurrences of "1" with each tensor with

View File

@@ -1871,7 +1871,7 @@ def _make_user_magic(method, user_type):
         setattrs(user_type, f"__r{method_name}__", rbinary_magic_impl)
-for method, func in magic_methods.items():  # type: ignore[assignment]
+for method in magic_methods.keys():  # type: ignore[assignment]
    if method in only_bool_magic_methods:
        _make_user_magic(method, SymBool)
        continue

View File

@@ -3342,7 +3342,7 @@ class DimConstraints:
         # alter derivations that depend on old root, to unify to new root
         # e.g. dx=3*_dx+1, dy=dx+1 -> dy=3*_dx+2
         for old_root in introduced_roots.values():
-            for k, c in list(results.items()):
+            for c in results.values():
                 if (
                     "eq" in c
                     and isinstance(c["eq"], sympy.Expr)

View File

@@ -1066,7 +1066,7 @@ def call_prepare_scriptable_func_impl(obj, memo):
         else:
             new_obj_dict[name] = sub_module
-    for k, v in new_obj_dict.items():
+    for v in new_obj_dict.values():
         obj.__dict__[name] = v
     return obj

View File

@@ -6099,7 +6099,7 @@ def index_add(g: jit_utils.GraphContext, self, dim, index, other, alpha=None):
     if other_dim_rank != self_dim_rank:
         delta = self_dim_rank - other_dim_rank
-        for i in range(delta):
+        for _ in range(delta):
             other = symbolic_helper._unsqueeze_helper(
                 g, other, [symbolic_helper._get_tensor_rank(other)]
             )
@@ -6126,10 +6126,10 @@ def index_add(g: jit_utils.GraphContext, self, dim, index, other, alpha=None):
         )
         other = expand_as(g, other, new_shape)
-    for i in range(dim):
+    for _ in range(dim):
         index = symbolic_helper._unsqueeze_helper(g, index, [0])
-    for i in range(self_dim_rank - dim - 1):
+    for _ in range(self_dim_rank - dim - 1):
         index = symbolic_helper._unsqueeze_helper(
             g, index, [symbolic_helper._get_tensor_rank(index)]
         )

View File

@@ -78,7 +78,7 @@ class EncodedAttrs:
             attr_floats=[],
             attr_strs=[],
         )
-        for i, (k, v) in enumerate(attrs.items()):
+        for k, v in attrs.items():
             encoded.attr_keys.append(k)
             if isinstance(v, int):
                 start_pos = len(encoded.attr_ints)

View File

@@ -445,11 +445,9 @@ def sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs):
         )
         # Checking for permutations of weights and biases as `None`
-        weights = [channels, None, None]
-        biases = [None, channels, None]
         is_training = [True, False, False]
-        for weight, bias, training in zip(weights, biases, is_training, strict=True):
+        for training in is_training:
             yield SampleInput(
                 make_arg(input_shape),
                 args=(

View File

@@ -465,7 +465,7 @@ class DdpUnderDistAutogradTest(RpcAgentTestFixture):
         )
         # Destroy process groups
-        for idx, trainer_rref in enumerate(trainer_rrefs):
+        for trainer_rref in trainer_rrefs:
             _remote_method_async(Trainer.destroy_pg, trainer_rref).wait()
         # Send shutdown signals.

View File

@@ -6094,7 +6094,7 @@ class DistributedTest:
             dim=1,
         ).cuda(rank)
-        for i in range(100):
+        for _ in range(100):
             y = model(input_var[rank].cuda(rank))
             y.mean().backward()

View File

@@ -1988,7 +1988,7 @@ class DistAutogradTest(CommonDistAutogradTest):
         self.assertEqual(self.world_size - 1, len(known_context_ids))
         t1 = torch.rand((3, 3), requires_grad=True)
-        for i in range(100):
+        for _ in range(100):
             dst = self._next_rank()
             t1 = rpc.rpc_sync(worker_name(dst), torch.add, args=(t1, t1))

View File

@@ -823,7 +823,7 @@ if has_triton():
         mask = offsets < n_elements
         x = tl.load(in_ptr0 + offsets, mask=mask)
         y = tl.load(in_ptr1 + offsets, mask=mask)
-        for i in range(2):
+        for _ in range(2):
            output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)
        i = 2

View File

@@ -355,7 +355,7 @@
     "dp = dp.shuffle()\n",
     "dp = dp.batch(2)\n",
     "print(\"Iterate over DataFrame batches\")\n",
-    "for i,v in enumerate(dp):\n",
+    "for v in dp:\n",
     " print(v)\n",
     "\n",
     "# this is similar to batching of regular DataPipe\n",