[1/N] Remove unused loop variables (#166258)
This PR removes unused loop variables.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166258
Approved by: https://github.com/Lucaskabela, https://github.com/mlazos
parent 369f2d6951
commit 2de4cf2102
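The cleanups follow a few mechanical patterns: dropping enumerate() when the index is unused, prefixing still-required but unread variables with an underscore, iterating dict .values()/.keys() instead of .items(), and dropping unused exception aliases. The sketch below is illustrative only; every name in it is hypothetical and none of it is taken from the touched PyTorch files.

# Illustrative, self-contained sketch of the loop-variable cleanups in this PR.
# All names here are hypothetical examples, not PyTorch code.
items = ["a", "b", "c"]
mapping = {"x": 1, "y": 2}

# 1. Drop enumerate() when the index is never read.
#    before: for i, item in enumerate(items): ...
for item in items:
    print(item)

# 2. Keep the binding position but mark it unused with a leading underscore
#    (used where removing enumerate()/sorted() would touch more code than desired).
#    before: for idx, (key, value) in enumerate(sorted(mapping.items())): ...
for _idx, (key, value) in enumerate(sorted(mapping.items())):
    print(key, value)

# 3. Iterate .values() (or .keys()) when only one side of the dict is used.
#    before: for key, value in mapping.items(): ...
for value in mapping.values():
    print(value)

# 4. Use a bare underscore for pure repetition, and drop unused exception aliases.
#    before: for i in range(3): ...    /    except ValueError as _:
for _ in range(3):
    try:
        int("not a number")
    except ValueError:
        pass

In every case the runtime behavior is unchanged; the rewrites only remove bindings that the loop body never reads.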
@@ -1791,7 +1791,7 @@ def rewrite_signature(
     for i, val in enumerate(sources):
         dict_of_source_vals[id(val)] = i

-    for i, val in enumerate(candidates):
+    for val in candidates:
         if isinstance(val, tuple(common_constant_types)):
             matched_elements_positions.append(None)
         elif id(val) not in dict_of_source_vals:
@@ -319,7 +319,7 @@ class GuardManagerWrapper:
             is_diff_guard_node = (
                 node.get_source() in self.diff_guard_sources or node.fail_count() > 0
             )
-            for idx, (key_mgr, val_mgr) in sorted(
+            for _idx, (key_mgr, val_mgr) in sorted(
                 node.get_key_value_managers().items()
             ):
                 is_diff_guard_node |= visit(key_mgr) | visit(val_mgr)

@@ -442,7 +442,7 @@ class GuardManagerWrapper:
             is_subtree_tag_safe = True

             # Recurse to get the tag safe roots from subtree.
-            for idx, (key_mgr, val_mgr) in sorted(
+            for _idx, (key_mgr, val_mgr) in sorted(
                 node.get_key_value_managers().items()
             ):
                 if key_mgr is not None:

@@ -450,9 +450,7 @@ class GuardManagerWrapper:
                 if val_mgr is not None:
                     tag_safe_roots.extend(visit(val_mgr))

-            for idx, (key_mgr, val_mgr) in sorted(
-                node.get_key_value_managers().items()
-            ):
+            for key_mgr, val_mgr in node.get_key_value_managers().values():
                 if key_mgr:
                     is_subtree_tag_safe &= key_mgr.is_tag_safe()
@@ -289,9 +289,7 @@ class OptimizerVariable(UserDefinedObjectVariable):
             params_vt = group_vt.getitem_const(tx, ConstantVariable.create("params"))
             all_static = True
             non_static_grads = []
-            for p_ind, (p, p_vt) in enumerate(
-                zip(group["params"], params_vt.unpack_var_sequence(tx))
-            ):
+            for p, p_vt in zip(group["params"], params_vt.unpack_var_sequence(tx)):
                 param_source = p_vt.source
                 self.tensor_to_source[p] = param_source
                 grad_source = GradSource(

@@ -322,12 +320,12 @@ class OptimizerVariable(UserDefinedObjectVariable):

         # We have to again iterate over the state dict to collect the
         # tensor_to_source dict. This is used for the finalizer.
-        for idx, (p, value) in enumerate(self.value.state.items()):
+        for idx, value in enumerate(self.value.state.values()):
             p_state_source = DictGetItemSource(
                 state_source, ConstDictKeySource(state_source, idx)
             )
             tx.output.guard_on_key_order.add(p_state_source)
-            for inner_idx, (k, v) in enumerate(value.items()):
+            for inner_idx, v in enumerate(value.values()):
                 if (
                     isinstance(v, torch.Tensor)
                     and v not in self.grad_to_source
@@ -240,7 +240,7 @@ def run_functionalized_fw_and_collect_metadata(

         # Inspect the state of the input tensor functional wrapper to detect input mutation info
         # If inp[i] has a metadata-only mutation, then maybe_inputs_with_mutated_metadata[i] contains the updated version
-        for i, (arg, f_arg) in enumerate(zip(flat_args, flat_f_args)):
+        for arg, f_arg in zip(flat_args, flat_f_args):
             # NB: Mutation of non-contiguous tensor subclass input can result in a mismatch in
             # strides between the functionalized arg inner tensors and non-functionalized arg inner
             # tensors. This is a problem as the inner tensor stride change may not be reflected
@@ -2041,7 +2041,7 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa

             assert len(meta.attrs) == len(runtime_subclass_keys)
             leaves = []
-            for i, (attr, attr_meta) in enumerate(meta.attrs.items()):
+            for attr, attr_meta in meta.attrs.items():
                 elem = getattr(x, attr)
                 new_elem, elem_leaves = AOTDispatchAutograd.process_runtime_tangent(
                     elem, attr_meta
@@ -98,7 +98,7 @@ def unwrap_tensor_subclass_parameters(module: torch.nn.Module) -> torch.nn.Modul
                 module, name, UnwrapTensorSubclass()
             )

-    for name, child in module.named_children():
+    for child in module.children():
         unwrap_tensor_subclass_parameters(child)

     return module
@@ -1481,9 +1481,7 @@ def functionalize_rng_ops(
             )
         )

-    for rng_count, (base_node, node_pair) in enumerate(
-        recomputable_rng_ops_map.items()
-    ):
+    for rng_count, node_pair in enumerate(recomputable_rng_ops_map.values()):
         # Step 2 - Modify the fwd pass such that
         fw_node = node_pair["fwd"]
         bw_node = node_pair["bwd"]

@@ -2714,9 +2712,7 @@ def thread_graphsafe_rng_from_hops(module, is_backward):
             subgraph = getattr(module, hop_node.args[0].target)
             if isinstance(subgraph, fx.GraphModule):
                 new_rng_inputs = []
-                for idx, placeholder_node in enumerate(
-                    subgraph.graph.find_nodes(op="placeholder")
-                ):
+                for placeholder_node in subgraph.graph.find_nodes(op="placeholder"):
                     if rng_string in placeholder_node.name:
                         # Found a rng state placeholder in the hop graph, lets add
                         # the corresponding node in the outer graph
@@ -116,7 +116,7 @@ def temporarily_restore_interpreter_stack(stack):
             pushed.append(s)
         yield
     finally:
-        for s in reversed(pushed):
+        for _ in reversed(pushed):
             # TODO: would be nice to assert that the layers are the same, but
             # Python object identity is not preserved
             pop_dynamic_layer_stack()
@@ -907,7 +907,7 @@ def diff_tensor_meta(
         try:
             if val1 != val2:
                 pair_diffs.append(f"'{meta_name}: {val1} vs {val2}'")
-        except GuardOnDataDependentSymNode as _:
+        except GuardOnDataDependentSymNode:
             pair_diffs.append(f"'{meta_name}: {val1} vs {val2}'")
             continue
     return pair_diffs

@@ -1197,7 +1197,7 @@ def materialize_callable_in_args(op: HopInstance, args, kwargs):

     # call_op preserves ordering of proxies via schema
     materialized_args = []
-    for i, (proxy, arg) in enumerate(zip(arg_proxies, schema.arguments)):
+    for i, proxy in enumerate(arg_proxies):
         if (
             isinstance(proxy, torch.fx.Node)
             and proxy.op == "get_attr"
@@ -316,7 +316,7 @@ def while_loop_dense(

     if stack_output:
         outs: list[torch.Tensor] = []
-        for i, out in enumerate(outputs):
+        for out in outputs:
             outs.append(torch.stack(out, dim=0))
         return tuple(outs)
@@ -2606,7 +2606,7 @@ def custom_op_wrapper(op: str, *args: Any) -> list[c_void_p] | c_void_p | None:
     if isinstance(result, (list, tuple)):
         # unsafe_alloc_void_ptrs_from_tensors expects result contains tensor only
         result = [torch.tensor([]) if r is None else r for r in result]
-        for i, r in enumerate(result):
+        for r in result:
             assert isinstance(r, torch.Tensor), op + " returns a list of non-tensors"
         return torch._C._aoti.unsafe_alloc_void_ptrs_from_tensors(result)  # type: ignore[arg-type]
@@ -895,7 +895,7 @@ class MetalKernel(SIMDKernel):
             else:
                 dtype_str = self.dtype_to_str(dtype)
                 code.writeline(f"constant {dtype_str}* {inner},")
-        for outer, inner in self.args.sizevars.items():
+        for inner in self.args.sizevars.values():
             code.writeline(f"constant long& {inner},")

         # Write dynamic values as inputs
@@ -218,7 +218,7 @@ class MultiKernel:
         # the multi call kernel.
         multi_call_args = call_args
         multi_call_arg_types = arg_types
-        for i, kernel in enumerate(self.kernels):
+        for kernel in self.kernels:
             additional_call_args, additional_arg_types = (
                 kernel.additional_call_args_and_types()
             )
@@ -717,7 +717,7 @@ class ComboKernel(Kernel):
         self, name: str, call_args: list[Any], arg_types: list[Any]
     ) -> None:
         for num, sub_kernel in enumerate(self.sub_kernels):
-            for i, tree in enumerate(sub_kernel.range_trees):
+            for tree in sub_kernel.range_trees:
                 numel_name = f"{tree.prefix}numel_{num}"
                 if numel_name not in self.dynamic_shape_args:
                     continue

@@ -735,7 +735,7 @@ class ComboKernel(Kernel):
     def kernel_benchmark_extra_args(self) -> list[str]:
         extra_args = []
         for num, sub_kernel in enumerate(self.sub_kernels):
-            for i, tree in enumerate(sub_kernel.range_trees):
+            for tree in sub_kernel.range_trees:
                 numel_name = f"{tree.prefix}numel_{num}"
                 if numel_name not in self.dynamic_shape_args:
                     continue

@@ -1018,7 +1018,7 @@ class ComboKernel(Kernel):

         for num, sub_kernel in enumerate(self.sub_kernels):
             meta[f"no_x_dim_{num}"] = sub_kernel.no_x_dim
-            for i, tree in enumerate(sub_kernel.range_trees):
+            for tree in sub_kernel.range_trees:
                 # pyrefly: ignore [missing-argument]
                 if not tree.is_reduction:
                     numel_name = f"{tree.prefix}numel_{num}"
@@ -3604,16 +3604,12 @@ class PythonWrapperCodegen(CodeGen):
         self.writeline("if not should_loop:")
         if stack_output:
             # Handle the case when loop never executes
-            for i, (carried_input, carried_buf) in enumerate(
-                zip(outer_carried_inputs, while_loop.carried_inputs)
-            ):
+            for i, carried_input in enumerate(outer_carried_inputs):
                 self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph))
                 self.writeline(f"{name}[{i}] = {carried_input}.unsqueeze(0).clone()")
                 self.writeline(ExitSubgraphLine(self))
         else:
-            for i, (carried_input, carried_buf) in enumerate(
-                zip(outer_carried_inputs, while_loop.carried_inputs)
-            ):
+            for i, carried_input in enumerate(outer_carried_inputs):
                 self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph))
                 self.writeline(f"{name}[{i}] = {carried_input}.clone()")
                 self.writeline(ExitSubgraphLine(self))
@@ -424,10 +424,7 @@ def _reorder_communication_preserving_peak_memory_internal(
                 return

             # Candidate becomes last use of some bufs
-            for (
-                gn,
-                bufs,
-            ) in group_n_to_bufs_after_swap_dealloc_by_candidate.items():
+            for bufs in group_n_to_bufs_after_swap_dealloc_by_candidate.values():
                 for buf in bufs:
                     buf_to_snode_last_use[buf] = candidate

@@ -840,7 +837,7 @@ def _schedule_for_comm(
         else:
             schedule(snode)

-    for snode, deps in unmet_deps.items():
+    for deps in unmet_deps.values():
         assert len(deps) == 0, (
             f"Detected unscheduled nodes. Nodes with unmet dependencies: {unmet_deps}"
         )
@@ -1552,11 +1549,8 @@ Graph: {graph}
             node.args = new_args

     # Delete `fsdp.copy_(unsharded_param, Y)` nodes
-    for (
-        unsharded_param,
-        fsdp_copy_node_idxes,
-    ) in unsharded_param_to_fsdp_copy_node_idxes.items():
-        for i, fsdp_copy_node_idx in enumerate(fsdp_copy_node_idxes):
+    for fsdp_copy_node_idxes in unsharded_param_to_fsdp_copy_node_idxes.values():
+        for fsdp_copy_node_idx in fsdp_copy_node_idxes:
             fsdp_copy_node = node_list[fsdp_copy_node_idx]
             graph.erase_node(fsdp_copy_node)
@@ -46,7 +46,7 @@ def _debug_iterative_memory_recompute(
     if iter_cm != new_cm:
         log = "ITERATIVE CURR MEMORY CANDIDATE DOES NOT MATCH"
         iterative_recompute_error = True
-        for i, gn in enumerate(gns):
+        for gn in gns:
            iter_gnm = iter_curr_memory[gn]
            new_gnm = est_curr_memory[gn]
            if iter_gnm != new_gnm:

@@ -65,7 +65,7 @@ def _debug_iterative_memory_recompute(
             f"\nCANDIDATE_NEW_ALLOCFREE:{snodes_allocfree[candidate]}"
         )
         peak_log = ""
-        for i, (pre, post) in enumerate(snodes_curr_memory):
+        for i, (pre, _post) in enumerate(snodes_curr_memory):
             if est_peak_memory == pre:
                 n = snodes[i]
                 peak_log = (
@@ -454,7 +454,7 @@ def decompose_map_to_while_loop(gm: torch.fx.GraphModule):

     graph_pass.apply(gm)

-    for node in gm.graph.find_nodes(
+    for _node in gm.graph.find_nodes(
         op="call_function", target=torch.ops.higher_order.map_impl
     ):
         raise AssertionError("map is not lowered to while_loop")

@@ -666,7 +666,7 @@ def decompose_scan_to_while_loop(gm: torch.fx.GraphModule):

     graph_pass.apply(gm)

-    for node in gm.graph.find_nodes(
+    for _node in gm.graph.find_nodes(
         op="call_function", target=torch.ops.higher_order.scan
     ):
         raise AssertionError("scan is not lowered to while_loop")

@@ -1265,7 +1265,7 @@ def decompose_triton_kernel_wrapper_functional(graph):

     graph_pass.apply(graph)

-    for node in graph.find_nodes(
+    for _ in graph.find_nodes(
         op="call_function",
         target=torch.ops.higher_order.triton_kernel_wrapper_functional,
     ):
@@ -8770,9 +8770,7 @@ class WhileLoop(ExternKernel):
         seen_buffers: OrderedSet[int] = OrderedSet()
         result: list[Union[IRNode, TensorBox, ShapeAsConstantBuffer]] = []

-        for i, (original_input, unwrapped_buffer) in enumerate(
-            zip(carried_inputs, unwrapped_buffers)
-        ):
+        for original_input, unwrapped_buffer in zip(carried_inputs, unwrapped_buffers):
             if id(unwrapped_buffer) in seen_buffers:
                 result.append(ExternKernel.copy_input(original_input))
             else:
@@ -743,7 +743,7 @@ class _TargetArgsExpr(_TargetExpr):
         assert len(node_items) == len(self_items)

         m = Match(ctx, self)
-        for i, pattern, child_node in zip(itertools.count(), self_items, node_items):
+        for pattern, child_node in zip(self_items, node_items):
             if isinstance(pattern, PatternExpr):
                 child_match = ctx.match(pattern, child_node)
                 if not is_match(child_match):
@@ -2850,7 +2850,7 @@ class Scheduler:
         # NB: None means that the dependency is on an input. Don't actually
         # generate a dependency because if we do, Inductor will start trying
         # to free the unbacked int but that's pointless
-        for name, val in V.graph.graph_inputs.items():
+        for val in V.graph.graph_inputs.values():
             if isinstance(val, sympy.Expr):
                 for fs in val.free_symbols:
                     unbacked_symbol_to_origin_node[fs] = None

@@ -3550,9 +3550,7 @@ class Scheduler:
         future_choices: list[tuple[Any, Optional[LambdaFuture], ModuleType]] = []
         for hint_override in config.multi_kernel_hints:
             choice_timings = multi_node.choice_timings(hint_override)
-            for choice, unfused_time in sorted(
-                choice_timings.items(), key=lambda x: x[1]
-            ):
+            for choice, _ in sorted(choice_timings.items(), key=lambda x: x[1]):
                 if not isinstance(
                     choice, torch._inductor.select_algorithm.TritonTemplateCaller
                 ):
@@ -425,7 +425,7 @@ def apply_var_mapping(
         new_ranges, norm_pw_vars + norm_red_vars, strict=True
     ):
         range_vars = []
-        for i in range(len(new_range)):
+        for _ in range(len(new_range)):
             range_vars.append(flat_vars[count])
             count += 1
@@ -348,7 +348,7 @@ def _do_bench_using_profiling(
         ]
     ) as p:
         # Benchmark
-        for i in range(n_repeat):
+        for _ in range(n_repeat):
             # we clear the L2 cache before each run
             cache.zero_()
             # record time of `fn`
@@ -3118,7 +3118,7 @@ def _validate_symbolic_output_for_caching(
     if is_tracing:
         # Check for SymNode types in PROXY mode - this should bypass caching
         # regardless of whether symbols are known or not
-        for node in _iterate_nodes(output):
+        for _ in _iterate_nodes(output):
             raise _BypassDispatchCache("Proxy mode with SymNode output")
     else:
         # Check for unrepresented symbols in tensor expressions
@@ -137,7 +137,7 @@ def _get_logger_dict_helper(
     def get_prefix(prefix):
         return prefix if prefix == "" else prefix + "."

-    for name, child in mod.named_children():
+    for child in mod.children():
         if isinstance(child, Logger):
             target_dict[get_prefix(prefix) + "stats"] = child.stats
             break
@@ -909,8 +909,7 @@ def create_a_shadows_b(
                 # is added
                 prev_node_c_list = [env_c[arg.name] for arg in prev_node_b]

-                for arg_idx, arg in enumerate(prev_node_b):
-                    prev_node_c = prev_node_c_list[arg_idx]
+                for arg_idx, prev_node_c in enumerate(prev_node_c_list):
                     env_c[prev_node_c.name] = _insert_logger_after_node(
                         prev_node_c,
                         gm_b,
@@ -151,6 +151,6 @@ def bias_correction(
         bias.data = updated_bias

     # Resets the data contained in the loggers
-    for name, submodule in quantized_model.named_modules():
+    for submodule in quantized_model.modules():
         if isinstance(submodule, MeanShadowLogger):
             submodule.clear()
@@ -297,7 +297,7 @@ def _get_numerical_jacobian(
     inp_indices = [
         i for i, a in enumerate(target) if is_tensor_like(a) and a.requires_grad
     ]
-    for i, (inp, inp_idx) in enumerate(zip(_iter_tensors(target, True), inp_indices)):
+    for inp, inp_idx in zip(_iter_tensors(target, True), inp_indices):
         jacobians += [
             get_numerical_jacobian_wrt_specific_input(
                 fn,

@@ -549,7 +549,7 @@ def _get_analytical_jacobian_forward_ad(
     with fwAD.dual_level():
         fw_grads = []
         dual_inputs = []
-        for i, inp in enumerate(inputs):
+        for inp in inputs:
             if is_tensor_like(inp) and inp.requires_grad:
                 if inp.layout == torch._mkldnn:  # type: ignore[attr-defined]
                     raise ValueError(

@@ -1275,7 +1275,7 @@ def _test_undefined_forward_mode(func, outputs, inputs):
                 tensor_indices.add(i)
             dual_inputs.append(inp)

-        for i, (fw_grad, u) in enumerate(zip(fw_grads, all_u)):
+        for fw_grad, u in zip(fw_grads, all_u):
             fw_grad.copy_(u.view_as(fw_grad))

         for idx, inp in enumerate(inputs):
@@ -41,7 +41,7 @@ class _PseudoZipFile:

         pickle.dump(entries, f, protocol=DEFAULT_PROTOCOL)

-        for key, (data, length) in self.records.items():
+        for data, _ in self.records.values():
             if isinstance(data, bytes):
                 f.write(data)
             elif isinstance(data, str):
@@ -578,7 +578,7 @@ def _load_model_state_dict(
     assign = False
     if info.broadcast_from_rank0 or info.full_state_dict:
         devices = set()
-        for key, value in local_state_dict.items():
+        for value in local_state_dict.values():
             if torch.is_tensor(value) and value.dim() > 0:
                 devices.add(value.device)
         # In lora state_dict, there could be multiple devices, with meta device inside.
@@ -2087,14 +2087,14 @@ class FlatParamHandle:
                     param.grad.data = view
                 else:
                     param.grad = view
-        for i, (
+        for (
             param_name,
             module,
             module_name,
             prim_param_name,
             prim_module,
             _,
-        ) in enumerate(self.flat_param._shared_param_infos):
+        ) in self.flat_param._shared_param_infos:
             _p_assert(
                 hasattr(module, param_name),
                 f"{module_name + '.' + param_name if module_name else param_name} is missing",

@@ -2171,11 +2171,8 @@ class FlatParamHandle:
             param.data = flat_param[offset : offset + numel_in_shard]
         if self.flat_param._shared_params is None:
             raise AssertionError("Expected _shared_params to be not None")
-        for i, (
-            param,
-            (param_name, module, _, prim_param_name, prim_module, _),
-        ) in enumerate(
-            zip(self.flat_param._shared_params, self.flat_param._shared_param_infos)
+        for param, (param_name, module, _, prim_param_name, prim_module, _) in zip(
+            self.flat_param._shared_params, self.flat_param._shared_param_infos
         ):
             self._setattr_param(module, param_name, param)
             prim_param = getattr(prim_module, prim_param_name)

@@ -2388,14 +2385,14 @@ class FlatParamHandle:

         # TODO: If we want to handle shared parameters, we need to re-generate
         # the shared parameter data structures in case sharedness changed.
-        for i, (
+        for (
             param_name,
             module,
             _,
             prim_param_name,
             prim_module,
             _,
-        ) in enumerate(flat_param._shared_param_infos):
+        ) in flat_param._shared_param_infos:
             if getattr(module, param_name) is not getattr(prim_module, prim_param_name):
                 raise NotImplementedError(
                     "Changing shared parameters is not supported yet"
@@ -924,7 +924,7 @@ class Pipe(torch.nn.Module):
             pass

         # This is done by (1) `_sink_params` at each submodule;
-        for name, submod in split.named_children():
+        for submod in split.children():
             if isinstance(submod, fx.GraphModule):
                 _sink_params(submod, inputs_to_state, [])
                 submod.graph.lint()
@@ -967,7 +967,7 @@ def distribute_module(
     if partition_fn is None:
         # if partition_fn not specified, we by default replicate
         # all module params/buffers
-        for name, submod in module.named_modules():
+        for submod in module.modules():
             replicate_module_params_buffers(submod, device_mesh)
     else:
         # apply partition_fun to submodules
@@ -170,7 +170,7 @@ def gen_einsum_strategies(
     # linearity strategy
     if linearity:
         linearity_placement_list: list[Placement] = [Partial()]
-        for input_dim in input_dims:
+        for _ in input_dims:
             linearity_placement_list.append(Partial())
         strategies_over_one_mesh_dim.append(linearity_placement_list)
@@ -1332,7 +1332,7 @@ def refine_dynamic_shapes_from_suggested_fixes(
             roots.add(c.root.__name__)  # type: ignore[attr-defined]

     # check keys are existing dims or new roots
-    for k, c in shape_fixes.items():
+    for k in shape_fixes.keys():
         assert k in name_to_dim or k in roots

     # cache so we don't produce multiple derived dim objects
@@ -101,11 +101,11 @@ def broadcast_types(t1, t2):
     # We make the types the same length which is the first requirement
     # for consistency
     if s1 > s2:
-        for i in range(s1 - s2):
+        for _ in range(s1 - s2):
             new_t2.insert(0, 1)

     elif s2 > s1:
-        for i in range(s2 - s1):
+        for _ in range(s2 - s1):
             new_t1.insert(0, 1)

     # we replace occurrences of "1" with each tensor with
@@ -1871,7 +1871,7 @@ def _make_user_magic(method, user_type):
     setattrs(user_type, f"__r{method_name}__", rbinary_magic_impl)


-for method, func in magic_methods.items():  # type: ignore[assignment]
+for method in magic_methods.keys():  # type: ignore[assignment]
     if method in only_bool_magic_methods:
         _make_user_magic(method, SymBool)
         continue
@@ -3342,7 +3342,7 @@ class DimConstraints:
         # alter derivations that depend on old root, to unify to new root
         # e.g. dx=3*_dx+1, dy=dx+1 -> dy=3*_dx+2
         for old_root in introduced_roots.values():
-            for k, c in list(results.items()):
+            for c in results.values():
                 if (
                     "eq" in c
                     and isinstance(c["eq"], sympy.Expr)
@@ -1066,7 +1066,7 @@ def call_prepare_scriptable_func_impl(obj, memo):
         else:
             new_obj_dict[name] = sub_module

-    for k, v in new_obj_dict.items():
+    for v in new_obj_dict.values():
         obj.__dict__[name] = v

     return obj
@@ -6099,7 +6099,7 @@ def index_add(g: jit_utils.GraphContext, self, dim, index, other, alpha=None):

     if other_dim_rank != self_dim_rank:
         delta = self_dim_rank - other_dim_rank
-        for i in range(delta):
+        for _ in range(delta):
             other = symbolic_helper._unsqueeze_helper(
                 g, other, [symbolic_helper._get_tensor_rank(other)]
             )

@@ -6126,10 +6126,10 @@ def index_add(g: jit_utils.GraphContext, self, dim, index, other, alpha=None):
         )
         other = expand_as(g, other, new_shape)

-    for i in range(dim):
+    for _ in range(dim):
         index = symbolic_helper._unsqueeze_helper(g, index, [0])

-    for i in range(self_dim_rank - dim - 1):
+    for _ in range(self_dim_rank - dim - 1):
         index = symbolic_helper._unsqueeze_helper(
             g, index, [symbolic_helper._get_tensor_rank(index)]
         )
@@ -78,7 +78,7 @@ class EncodedAttrs:
             attr_floats=[],
             attr_strs=[],
         )
-        for i, (k, v) in enumerate(attrs.items()):
+        for k, v in attrs.items():
            encoded.attr_keys.append(k)
            if isinstance(v, int):
                start_pos = len(encoded.attr_ints)
@@ -445,11 +445,9 @@ def sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs):
         )

         # Checking for permutations of weights and biases as `None`
-        weights = [channels, None, None]
-        biases = [None, channels, None]
         is_training = [True, False, False]

-        for weight, bias, training in zip(weights, biases, is_training, strict=True):
+        for training in is_training:
             yield SampleInput(
                 make_arg(input_shape),
                 args=(
@@ -465,7 +465,7 @@ class DdpUnderDistAutogradTest(RpcAgentTestFixture):
         )

         # Destroy process groups
-        for idx, trainer_rref in enumerate(trainer_rrefs):
+        for trainer_rref in trainer_rrefs:
             _remote_method_async(Trainer.destroy_pg, trainer_rref).wait()

         # Send shutdown signals.
@@ -6094,7 +6094,7 @@ class DistributedTest:
             dim=1,
         ).cuda(rank)

-        for i in range(100):
+        for _ in range(100):
             y = model(input_var[rank].cuda(rank))
             y.mean().backward()
@@ -1988,7 +1988,7 @@ class DistAutogradTest(CommonDistAutogradTest):
         self.assertEqual(self.world_size - 1, len(known_context_ids))

         t1 = torch.rand((3, 3), requires_grad=True)
-        for i in range(100):
+        for _ in range(100):
             dst = self._next_rank()
             t1 = rpc.rpc_sync(worker_name(dst), torch.add, args=(t1, t1))
@@ -823,7 +823,7 @@ if has_triton():
             mask = offsets < n_elements
             x = tl.load(in_ptr0 + offsets, mask=mask)
             y = tl.load(in_ptr1 + offsets, mask=mask)
-            for i in range(2):
+            for _ in range(2):
                output = x + y
                tl.store(out_ptr + offsets, output, mask=mask)
            i = 2
@@ -355,7 +355,7 @@
     "dp = dp.shuffle()\n",
     "dp = dp.batch(2)\n",
     "print(\"Iterate over DataFrame batches\")\n",
-    "for i,v in enumerate(dp):\n",
+    "for v in dp:\n",
     "  print(v)\n",
     "\n",
     "# this is similar to batching of regular DataPipe\n",