Remove unused Python variables in torch/[b-z]* (#136963)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136963
Approved by: https://github.com/ezyang
Authored by Tom Ritchford on 2024-10-19 13:25:28 +00:00; committed by PyTorch MergeBot
parent fb44658415
commit c0582fd0f8
152 changed files with 376 additions and 528 deletions
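
Most hunks in this commit follow a handful of mechanical patterns for clearing "assigned but never used" lint findings (flake8/Ruff F841 and related checks): prefix an unused binding with an underscore, delete the assignment outright, keep only the call when its side effect matters, or suppress the finding with # noqa: F841 when the binding is deliberate. The sketch below shows the underscore-prefix rename; it is illustrative only, with a hypothetical get_conv_shape function, and is not a line from this diff.

def get_conv_shape(image_shape):
    # Before: `batch, in_c, in_h, in_w = image_shape` left `in_c` unread.
    # The leading underscore marks that slot as intentionally unused, so
    # linters stop reporting it and readers know the value is discarded.
    batch, _in_c, in_h, in_w = image_shape
    return batch, in_h, in_w

print(get_conv_shape((8, 3, 224, 224)))  # prints (8, 224, 224)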

View File

@ -266,7 +266,7 @@ def broadcast_shapes(shape1, shape2):
def get_conv_pool_shape(image_shape, args, out_ch, transpose):
batch, in_c, in_h, in_w = image_shape
batch, _in_c, in_h, in_w = image_shape
# TODO: Handle dilation
if args.dilation_h != 1 or args.dilation_w != 1:
@ -443,7 +443,6 @@ class _NnapiSerializer:
operand_id = len(self.operands)
self.operands.append(toper)
tsize = tensor_size(toper.op_type, toper.shape)
psize = ((tsize - 1) | 0x3) + 1
self.values.append((operand_id, OperandValueSourceType.NUMBERED_BUFFER))
buf_num = len(self.used_weights)
offset = 0
@ -917,7 +916,7 @@ class _NnapiSerializer:
adder(self, node)
def _identity(self, node):
in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
in_id, _in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
jitval = node.outputsAt(0)
self.jitval_operand_map[jitval] = in_id
@ -1039,8 +1038,8 @@ class _NnapiSerializer:
in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
start_ctype, start_dim = self.get_constant_value(node.inputsAt(1), "IntType")
end_ctype, end_dim = self.get_constant_value(node.inputsAt(2), "IntType")
_start_ctype, start_dim = self.get_constant_value(node.inputsAt(1), "IntType")
_end_ctype, end_dim = self.get_constant_value(node.inputsAt(2), "IntType")
# channels last with channels == 1 or (height & width both 1)
is_trivial_flatten = len(in_oper.shape) == 4 and (
@ -1526,7 +1525,7 @@ class _NnapiSerializer:
def add_pool2d_node(self, node, opcode):
assert node.inputsSize() == 6
assert node.outputsSize() == 1
image, kernel, stride, padding, dilation, ceil_mode = node.inputs()
image, kernel, stride, padding, dilation, _ceil_mode = node.inputs()
stride = stride or kernel
@ -1574,7 +1573,7 @@ class _NnapiSerializer:
kernel,
stride,
padding,
ceil_mode,
_ceil_mode,
count_include_pad,
divisor_override,
) = node.inputs()
@ -1673,7 +1672,7 @@ class _NnapiSerializer:
scale_ctype, scale_arg = self.get_constant_value(scale_jit) # type: ignore[possibly-undefined]
else:
scale_h_ctype, scale_h_arg = self.get_constant_value(scale_h_jit) # type: ignore[possibly-undefined]
scale_w_ctype, scale_w_arg = self.get_constant_value(scale_w_jit) # type: ignore[possibly-undefined]
scale_w_ctype, _scale_w_arg = self.get_constant_value(scale_w_jit) # type: ignore[possibly-undefined]
# The only way for the 4-argument overload of upsample_nearest2d to
# have been added to the graph without error is if the scale_h and
@ -1892,7 +1891,7 @@ class _NnapiSerializer:
self.add_operation(NNAPI_OperationCode.FULLY_CONNECTED, inputs, outputs)
def get_optional_bias(self, jit_bias, weight_tensor, transpose=False):
ctype, value = self.get_constant_value(jit_bias)
ctype, _value = self.get_constant_value(jit_bias)
if ctype.kind() == "NoneType":
bias_idx = 1 if transpose else 0
nnapi_bias_tensor = torch.zeros(
@ -1919,7 +1918,7 @@ class _NnapiSerializer:
) = node.inputs()
_, weight_tensor = self.get_constant_value(jit_weight, "TensorType")
bias_id, bias_oper = self.get_optional_bias(jit_bias, weight_tensor)
bias_id, _bias_oper = self.get_optional_bias(jit_bias, weight_tensor)
args = self.get_conv_pool_args_2d_from_jit(
weight_tensor.shape[2:4], jit_stride, jit_pad, jit_dilation, jit_groups
)
@ -1958,7 +1957,7 @@ class _NnapiSerializer:
_, weight_tensor = self.get_constant_value(jit_weight, "TensorType")
_, transpose = self.get_constant_value(jit_transpose)
bias_id, bias_oper = self.get_optional_bias(jit_bias, weight_tensor, transpose)
bias_id, _bias_oper = self.get_optional_bias(jit_bias, weight_tensor, transpose)
args = self.get_conv_pool_args_2d_from_jit(
weight_tensor.shape[2:4], jit_stride, jit_pad, jit_dilation, jit_groups
)
@ -1979,7 +1978,7 @@ class _NnapiSerializer:
assert node.inputsSize() == 3
assert node.outputsSize() == 1
(jit_input, jit_dim, jit_half_to_float) = node.inputs()
jit_input, jit_dim, _jit_half_to_float = node.inputs()
input_id, input_oper = self.get_tensor_operand_by_jitval_fixed_size(jit_input)
_, dim = self.get_constant_value(jit_dim, "IntType")
@ -2117,7 +2116,7 @@ class _NnapiSerializer:
if depthwise:
# Depthwise convolution
one, kern_h, kern_w, out_c = weight_oper.shape
one, _kern_h, _kern_w, out_c = weight_oper.shape
assert one == 1
assert out_c % in_c == 0
channel_multiplier = out_c // in_c
@ -2125,7 +2124,7 @@ class _NnapiSerializer:
assert out_c == in_c
else:
# Full convolution
out_c, kern_h, kern_w, kern_d = weight_oper.shape
out_c, _kern_h, _kern_w, kern_d = weight_oper.shape
assert kern_d == in_c
assert out_c == bias_oper.shape[0]

View File

@ -37,7 +37,7 @@ def visualize(graph, name_prefix='', pb_graph=None, executors_it=None):
return pb_graph
# Set up an input node
input_node = pb_graph.node.add(op='input', name=name_prefix + 'input')
pb_graph.node.add(op='input', name=name_prefix + 'input')
for i, value in enumerate(graph.param_node().outputs()):
value_map[value.unique()] = name_prefix + 'input:' + str(i)
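
In the hunk above, the pb_graph.node.add(...) call is kept purely for its side effect (adding the input node to the graph) and the unused input_node binding is dropped. A minimal sketch of that pattern, using a hypothetical Registry class rather than the protobuf graph API:

class Registry:
    def __init__(self):
        self.items = []

    def add(self, name):
        # The method both mutates the registry and returns the new entry.
        self.items.append(name)
        return name

reg = Registry()
# Before: `node = reg.add("input")` bound a value that was never read.
# After: only the call is kept, since the registration is the point.
reg.add("input")
print(reg.items)  # ['input']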

View File

@ -186,7 +186,7 @@ def _check_capability():
work properly, but your PyTorch was compiled
with CUDA_VERSION %d. Please install the correct PyTorch binary
using instructions from https://pytorch.org
"""
""" # noqa: F841
old_gpu_warn = """
Found GPU%d %s which is of cuda capability %d.%d.
@ -195,7 +195,7 @@ def _check_capability():
"""
if torch.version.cuda is not None: # on ROCm we don't want this check
CUDA_VERSION = torch._C._cuda_getCompiledVersion()
CUDA_VERSION = torch._C._cuda_getCompiledVersion() # noqa: F841
for d in range(device_count()):
capability = get_device_capability(d)
major = capability[0]
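
_check_capability keeps its warning-template string and the CUDA_VERSION binding but tags them with # noqa: F841 instead of deleting them, presumably because they still document intent even though nothing reads them. An illustrative sketch of that suppression (the version number below is made up):

def check_compiled_version():
    # Kept on purpose: the binding records what was queried even though the
    # surrounding checks no longer read it, so the lint finding is suppressed
    # rather than the line removed.
    compiled_cuda_version = 12040  # noqa: F841
    return True

check_compiled_version()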

View File

@ -213,7 +213,6 @@ def segsum(data):
Args:
data: snapshot dictionary created from _snapshot()
"""
segments = []
out = io.StringIO()
out.write(f"Summary of segments >= {Bytes(PAGE_SIZE)} in size\n")
total_reserved = 0
@ -272,7 +271,6 @@ def segsum(data):
out.write(f'segments: {len(data["segments"])}\n')
out.write(f'total_reserved: {Bytes(total_reserved)}\n')
out.write(f'total_allocated: {Bytes(total_allocated)}\n')
internal_external = f' ({Bytes(free_internal)} internal + {Bytes(free_external)} external)' if free_internal else ''
out.write(f'total_free: {_report_free(free_external, free_internal)}\n')
out.write(legend)
assert free_internal + free_external + total_allocated == total_reserved
@ -478,10 +476,8 @@ def _profile_to_snapshot(profile):
kv_to_elem = {}
# create the device trace
for time, action, (tensor_key, version), size in memory_profile.timeline:
for _time, action, (tensor_key, version), size in memory_profile.timeline:
if not isinstance(tensor_key, TensorKey):
continue
if action == Action.CREATE:

View File

@ -357,11 +357,10 @@ def make_graphed_callables(
# Capture backward graphs in reverse order
per_callable_static_grad_outputs = []
per_callable_static_grad_inputs = []
for static_input_surface, static_outputs, bwd_graph, module_params in zip(
for static_input_surface, static_outputs, bwd_graph in zip(
reversed(per_callable_static_input_surfaces),
reversed(per_callable_static_outputs),
reversed(bwd_graphs),
reversed(per_callable_module_params),
):
# For now, assumes all static_outputs require grad
# assert all(o.requires_grad for o in static_outputs), "Outputs of graphed callables must require grad."
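
Here an entire iterable (per_callable_module_params) is removed from the zip because the loop never uses its items. One caveat with this pattern: zip() stops at the shortest input, so dropping an iterable is only behavior-preserving when it was not the shortest one. A standalone sketch with hypothetical lists:

graphs = ["bwd_graph_0", "bwd_graph_1"]
outputs = ["out_0", "out_1"]
module_params = ["params_0", "params_1"]  # items were never used in the loop

# Before: `for g, o, p in zip(graphs, outputs, module_params): ...`
# After: the unused iterable is dropped; zip still pairs the remaining two.
for g, o in zip(reversed(graphs), reversed(outputs)):
    print(g, o)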

View File

@ -245,7 +245,7 @@ def foreach_all_gather_copy_out(
param_all_gather_input_numels,
all_gather_input_split_sizes,
) = all_gather_result
dtype, device = all_gather_output.dtype, all_gather_output.device
_dtype, device = all_gather_output.dtype, all_gather_output.device
device_handle = _get_device_handle(device.type)
if all_gather_event is not None: # sync op
device_handle.current_stream().wait_event(all_gather_event)

View File

@ -82,8 +82,6 @@ class _ReplicateState(_State):
return
self.has_initialized = True
device_mesh = kwargs.get("device_mesh", None)
self.module = module
ignored_params = {p for m in ignored_modules for p in m.parameters()}
for submodule in module.modules():

View File

@ -274,7 +274,7 @@ def shard_module(module: nn.Module, plan: ShardingPlan, src_rank=0, process_grou
mod, param_name, spec, src_rank=src_rank, process_group=process_group
)
elif isinstance(spec, Sharder):
parent_mod_path, _, mod_name = name.rpartition(".")
parent_mod_path, _, _mod_name = name.rpartition(".")
if name == "":
raise KeyError("Module path must not be empty for custom sharder!")
mod = module.get_submodule(name)

View File

@ -25,7 +25,6 @@ def binary_cmp(cmp_fun, types, args, kwargs=None, process_group=None):
if len(args) != 2:
raise ValueError(f"Expected two arguments for torch.{cmp_fun.__name__}")
result = True
st1 = args[0]
st2 = args[1]
if not (isinstance(st1, ShardedTensor) and isinstance(st2, ShardedTensor)):

View File

@ -857,7 +857,7 @@ class ShardedTensor(ShardedTensorBase):
local_shards: List[Shard] = []
for shard_metadata in sharded_tensor_metadata.shards_metadata:
rank, device = _parse_and_validate_remote_device(
rank, _device = _parse_and_validate_remote_device(
process_group, shard_metadata.placement
)
if rank == current_rank:

View File

@ -34,7 +34,6 @@ def _reducer_allreduce_and_upcast_hook(
"""
ddp_weakref = hook_state.ddp_weakref
reducer, process_group = ddp_weakref().reducer, ddp_weakref().process_group
gradient_is_bucket_view = ddp_weakref().gradient_as_bucket_view
# Cast bucket if different than param_dtype.
if (
ddp_weakref().mixed_precision.param_dtype
@ -53,8 +52,7 @@ def _reducer_allreduce_and_upcast_hook(
ret_fut.set_result(bucket.buffer())
# Upcast parameters and gradients so optimizer step can run in fp32.
params, grads = bucket.parameters(), bucket.gradients()
for p, g in zip(params, grads):
for p in bucket.parameters():
p.data = p._fp_param
# free storage for mp param as it will be allocated again in next
# forward pass.
@ -70,7 +68,7 @@ def _reducer_allreduce_and_upcast_hook(
# they may participate in computation. However, they would not be recast
# by hook above as they don't have a grad hook installed, so cast them
# back here.
for n, p in ddp_weakref().module.named_parameters():
for _, p in ddp_weakref().module.named_parameters():
if hasattr(p, "_ddp_mp_hook_state"):
p._ddp_mp_hook_state[1].remove()
delattr(p, "_ddp_mp_hook_state")

View File

@ -87,7 +87,7 @@ def _retrieve_embedding_parameters(emb_rref):
def _print_header():
_print_cont("\n")
_print_cont("%10s" % "")
for p in [50, 75, 90, 95]:
for _ in [50, 75, 90, 95]:
_print_cont("%14s%10s" % ("sec/epoch", "epoch/sec"))
_print_cont("\n")
@ -112,7 +112,6 @@ def _run_printable(cmd):
buffer = io.BytesIO()
torch.save(proc.stdout.decode("utf-8"), buffer)
input_tensor = torch.ByteTensor(list(buffer.getvalue()))
input_length = torch.IntTensor([input_tensor.size(0)])
output = []
buffer = io.BytesIO(np.asarray(input_tensor).tobytes())
@ -173,7 +172,7 @@ def _run_trainer(emb_rref_list, rank):
measurements = []
# Include warm-up cycles during training
for epoch in range(100 + WARMUP_CYCLES):
for _ in range(100 + WARMUP_CYCLES):
start = time.time()
batch_size = 0
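
The benchmark loops above only need to run a fixed number of times, so counters that are never read become `_`. A trivial standalone sketch (WARMUP_CYCLES is a stand-in value):

WARMUP_CYCLES = 2
measurements = []
# Before: `for epoch in range(3 + WARMUP_CYCLES):` with `epoch` never read.
for _ in range(3 + WARMUP_CYCLES):
    measurements.append(len(measurements))
print(measurements)  # [0, 1, 2, 3, 4]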

View File

@ -178,7 +178,7 @@ def create_read_items_for_chunk_list(
dest_offsets = []
lengths = []
for (
dim,
_dim,
offset_for_saved_tensor,
offset_for_current_tensor,
length,

View File

@ -4581,7 +4581,7 @@ def split_group(
raise RuntimeError(
"No device associated with the default pg, not safe to split any process groups"
)
default_backend, default_store = _world.pg_map[default_pg]
_default_backend, default_store = _world.pg_map[default_pg]
global_rank = default_pg.rank()
global_world_size = default_pg.size()

View File

@ -245,7 +245,7 @@ class ChildFailedError(Exception):
def format_msg(self, boarder_delim="=", section_delim="-"):
title = f"{self.name} FAILED"
root_rank, root_failure = self.get_first_failure()
root_rank, _root_failure = self.get_first_failure()
root_failure_fmt: str = ""
other_failures_fmt: List[str] = []

View File

@ -38,7 +38,7 @@ def get_socket_with_port() -> socket.socket:
s.bind(("localhost", 0))
s.listen(0)
return s
except OSError as e:
except OSError:
s.close()
raise RuntimeError("Failed to create a socket")
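
When the caught exception object is never referenced, the `as e` clause is dropped. The original OSError is still attached to the new RuntimeError as implicit context, so the traceback keeps both errors. A self-contained version of the same function:

import socket

def get_socket_with_port() -> socket.socket:
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        s.bind(("localhost", 0))
        s.listen(0)
        return s
    except OSError:
        # No `as e`: the handler never reads the exception object, and the
        # OSError still shows up as the __context__ of the RuntimeError.
        s.close()
        raise RuntimeError("Failed to create a socket")

s = get_socket_with_port()
print(s.getsockname())
s.close()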

View File

@ -274,7 +274,7 @@ def _named_parameters_with_duplicates(
kwargs["remove_duplicate"] = False
try:
ret = list(module.named_parameters(**kwargs))
except AssertionError as e:
except AssertionError:
kwargs.pop("remove_duplicate")
ret = list(module.named_parameters(**kwargs))
return ret

View File

@ -1017,10 +1017,10 @@ class FlatParamHandle:
sharded_flat_param_numel = unsharded_end_idx - unsharded_start_idx + 1
# `unsharded_param_start_idx` and `unsharded_param_end_idx` are indices
# into the unsharded flat parameter (inclusive) of the given parameter
for i, (
for (
(unsharded_param_start_idx, unsharded_param_end_idx),
is_padding,
) in enumerate(zip(flat_param_offsets, self.flat_param._is_padding_mask)):
) in zip(flat_param_offsets, self.flat_param._is_padding_mask):
if is_padding:
continue
in_sharded_flat_param = (
@ -2201,8 +2201,8 @@ class FlatParamHandle:
else:
param.grad = None
assert flat_param._shared_params is not None
for i, (param, (_, _, _, prim_param_name, prim_module, _)) in enumerate(
zip(flat_param._shared_params, flat_param._shared_param_infos)
for param, (_, _, _, prim_param_name, prim_module, _) in zip(
flat_param._shared_params, flat_param._shared_param_infos
):
in_sharded_flat_param = hasattr(prim_module, prim_param_name)
if in_sharded_flat_param and param.requires_grad:
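
Both hunks above delete an enumerate() whose index was never consulted, unpacking directly from zip() instead. A minimal sketch with hypothetical offset and padding lists:

flat_param_offsets = [(0, 3), (4, 7), (8, 11)]
is_padding_mask = [False, True, False]

# Before: `for i, ((start, end), is_padding) in enumerate(zip(...)):`
# After: the unused index `i` is gone and the unpacking shape is unchanged.
for (start, end), is_padding in zip(flat_param_offsets, is_padding_mask):
    if is_padding:
        continue
    print(start, end)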

View File

@ -536,9 +536,7 @@ def _flatten_optim_state_dict(
else:
# Move the tensor in the original osd back to CPU to make the
# original osd unaffected.
unflat_osd_state[fqn][state_name] = unflat_osd_state[fqn][
state_name
].cpu()
unflat_osd_state[fqn][state_name] = param_state.cpu()
# Handle user-defined state, states that are not associated with parameters.
for key in all_state_keys:
@ -1457,7 +1455,7 @@ def _unflatten_orig_param_states(
# gather the tensor on its TP dimension before chunking them into DTensor again.
if placement != Replicate():
placement_dim = placement.dim # type: ignore[attr-defined]
value_local = value.redistribute(placements=(Replicate(),))
value.redistribute(placements=(Replicate(),))
reshape_size = list(flat_param._shapes[param_idx])
reshape_size[placement_dim] *= value.device_mesh.size(0)
reshape_size = torch.Size(reshape_size)

View File

@ -297,7 +297,7 @@ def _full_pre_state_dict_hook(
``nn.Module``.
"""
if getattr(fsdp_state, "_device_mesh", False):
root_mesh = _mesh_resources.get_root_mesh(fsdp_state._device_mesh)
_mesh_resources.get_root_mesh(fsdp_state._device_mesh)
_common_pre_state_dict_hook(module, fsdp_state)
_common_unshard_pre_state_dict_hook(

View File

@ -104,7 +104,7 @@ def get_param_groups(
# but omits weights and any subgraphs connecting weights to this closure
inputs_closure, _ = reverse_closure(inputs, set(), reverse_edges_dict)
param_groups: Dict[Node, Dict[str, Set]] = dict() # keyed on intermediates
for i, param in enumerate(params):
for param in params:
closure, intersected = reverse_closure(
[param], inputs_closure, reverse_edges_dict
)

View File

@ -802,7 +802,6 @@ class Schedule1F1B(PipelineScheduleSingle):
# Chunk counters
fwd_mb_index = 0
bwd_mb_index = 0
weight_stage_mb_index = 0
# Warmup phase
send_work = None

View File

@ -719,7 +719,7 @@ class _PipelineStageBase(ABC):
"param_groups": param_groups,
"full_backward": False,
}
weight_grads, _ = self.backward_maybe_with_nosync("weight", bwd_kwargs)
self.backward_maybe_with_nosync("weight", bwd_kwargs)
else:
# TODO: figure out a better way to do this:
# if inputs does not require gradient,
@ -1603,13 +1603,13 @@ def _validate_stage_shapes(pipeline_stages: List[PipelineStage]):
]
logger.debug(
f"Rank: {pg_rank}" # noqa: G004
f"Stage id: {stage_id}"
f"Stage num stages: {stage.num_stages}"
f"Stage rank: {rank}"
f"Stage world size: {world_size}"
f"Stage {virtual_id * world_size}-{(virtual_id + 1) * world_size - 1} input shapes: {stage_input_shapes}" # noqa: G003
f"Stage {virtual_id * world_size}-{(virtual_id + 1) * world_size - 1} output shapes: {stage_output_shapes}" # noqa: G003
f"Rank: {pg_rank}", # noqa: G004
f"Stage id: {stage_id}",
f"Stage num stages: {stage.num_stages}",
f"Stage rank: {rank}",
f"Stage world size: {world_size}",
f"Stage {virtual_id * world_size}-{(virtual_id + 1) * world_size - 1} input shapes: {stage_input_shapes}", # noqa: G003
f"Stage {virtual_id * world_size}-{(virtual_id + 1) * world_size - 1} output shapes: {stage_output_shapes}", # noqa: G003
)
all_inputs.extend(stage_input_shapes)

View File

@ -77,7 +77,7 @@ def found_inf_reduce_handler(
cast(List[object], op_info.local_args), op_info.args_tree_spec
)
local_tensor_args = cast(Tuple[object, ...], local_tensor_args)
local_results = op_call(*local_tensor_args, **op_info.local_kwargs)
op_call(*local_tensor_args, **op_info.local_kwargs)
grad_dtensor = cast(list[dtensor.DTensor], args[0])[0]
grad_placements = grad_dtensor.placements

View File

@ -21,9 +21,9 @@ def convolution_rules(op_schema: OpSchema) -> OutputSharding:
stride,
padding,
dilation,
transposed,
output_padding,
groups,
_transposed,
_output_padding,
_groups,
) = op_schema.args_schema
assert isinstance(input_spec, DTensorSpec)
@ -37,7 +37,7 @@ def convolution_rules(op_schema: OpSchema) -> OutputSharding:
assert isinstance(padding, List)
assert isinstance(dilation, List)
assert isinstance(weight_shape, torch.Size)
N, C_in, H_in, W_in = in_shape[0], in_shape[1], in_shape[2], in_shape[3]
N, H_in, W_in = in_shape[0], in_shape[2], in_shape[3]
C_out = weight_shape[0]
H_out = (H_in + 2 * padding[0] - dilation[0] * (weight_shape[2] - 1) - 1) // stride[
0
@ -73,13 +73,13 @@ def convolution_backward_rules(op_schema: OpSchema) -> OutputSharding:
input_spec,
weight_spec,
bias_shape_opt,
stride,
padding,
dilation,
transposed,
output_padding,
groups,
output_mask,
_stride,
_padding,
_dilation,
_transposed,
_output_padding,
_groups,
_output_mask,
) = op_schema.args_schema
assert isinstance(grad_output_spec, DTensorSpec)

View File

@ -479,7 +479,7 @@ def softmax_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
softmax_dim = normalize_dim(softmax_dim, input_strategy.ndim)
output_strategy = OpStrategy([])
for idx, input_placement_strategy in enumerate(input_strategy.strategies):
for input_placement_strategy in input_strategy.strategies:
redistribute_costs = []
input_src_spec = input_placement_strategy.output_spec
@ -1038,8 +1038,6 @@ def layer_norm_bwd_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy
)
def topk_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
input_strategy = cast(OpStrategy, op_schema.args_schema[0])
k = cast(int, op_schema.args_schema[1])
input_shape = input_strategy.shape
topk_dim = (
cast(int, op_schema.args_schema[2]) if len(op_schema.args_schema) > 2 else -1
)

View File

@ -171,7 +171,6 @@ def scaled_dot_product_flash_attention_strategy(
q_input_strategy = op_schema.args_schema[0]
assert isinstance(q_input_strategy, OpStrategy)
# assuming q/k/v have the same shape
qkv_shape = q_input_strategy.shape
single_mesh_dim_strategies = []
@ -250,7 +249,6 @@ def scaled_dot_product_flash_attention_backward_strategy(
q_input_strategy = op_schema.args_schema[1]
assert isinstance(q_input_strategy, OpStrategy)
# assuming q/k/v have the same shape
qkv_shape = q_input_strategy.shape
tensor_input_indices = [
i
@ -344,7 +342,7 @@ def scaled_dot_product_efficient_attention_strategy(
q_input_strategy = op_schema.args_schema[0]
assert isinstance(q_input_strategy, OpStrategy)
# assuming q/k/v have the same shape
qkv_shape = q_input_strategy.shape
has_attn_bias = op_schema.args_schema[3] is not None
compute_log_sumexp = op_schema.args_schema[4]
@ -418,15 +416,8 @@ def scaled_dot_product_efficient_attention_backward_strategy(
q_input_strategy = op_schema.args_schema[1]
assert isinstance(q_input_strategy, OpStrategy)
# assuming q/k/v have the same shape
qkv_shape = q_input_strategy.shape
has_attn_bias = op_schema.args_schema[4] is not None
tensor_input_indices = [
i
for i, arg_spec in enumerate(op_schema.args_schema)
if isinstance(arg_spec, OpStrategy)
]
single_mesh_dim_strategies = []
# placement list stores placements of [outputs, inputs]

View File

@ -367,7 +367,6 @@ def replica_only_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType
schema_info=RuntimeSchemaInfo(1),
)
def scatter_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
input_strategy = cast(OpStrategy, op_schema.args_schema[0])
single_mesh_dim_strategies = []
# placement list stores placements of [output, input, index, src]

View File

@ -67,7 +67,7 @@ def _gen_transform_infos_non_cached(
# Handle multi-dim device mesh placement redistribution
# First, we need to build the logical shape for each mesh dim
# for correct allgathering uneven shards on each mesh dim (with dynamic padding)
for i, (src, dst) in enumerate(zip(src_spec.placements, dst_spec.placements)):
for i, src in enumerate(src_spec.placements):
current_logical_shape = mesh_dims_to_logical_shape[i]
if isinstance(src, Shard):
if i < device_mesh.ndim - 1:
@ -192,7 +192,7 @@ def redistribute_local_tensor(
for transform_info in transform_infos:
i = transform_info.mesh_dim
current, target = transform_info.src_dst_placements
num_chunks = device_mesh.size(mesh_dim=i)
device_mesh.size(mesh_dim=i)
if current == target:
# short cut, just use the original local tensor
@ -220,7 +220,6 @@ def redistribute_local_tensor(
elif target.is_shard():
# Case 2: target is Shard
target_placement = cast(Shard, target)
target_dim = target_placement.dim
if current.is_partial():
partial_spec = cast(Partial, current)
new_local_tensor = partial_spec._reduce_shard_value(

View File

@ -192,7 +192,6 @@ def tp_convolution_backward(
)
# step2 reconstruct local gradient output tensor
N, C_out, H_out, _ = grad_out_tensor.shape
padding_w = padding[1]
if rank == 0:
grad_out_tensor = torch.nn.functional.pad(

View File

@ -269,7 +269,7 @@ class CommDebugModeExample:
comm_mode = CommDebugMode()
with comm_mode:
output = model(inp)
model(inp)
# print the module level collective tracing information
print(comm_mode.generate_comm_debug_tracing_table(noise_level=0))
@ -592,7 +592,7 @@ class CommDebugModeExample:
comm_mode = CommDebugMode()
with comm_mode:
output = model(inp)
model(inp)
# print the operation level collective tracing information
print(comm_mode.generate_comm_debug_tracing_table(noise_level=2))
@ -628,7 +628,7 @@ class CommDebugModeExample:
comm_mode = CommDebugMode()
with comm_mode:
output = model(inp)
model(inp)
comm_mode.generate_json_dump(file_name="transformer_log.json", noise_level=1)
comm_mode.generate_json_dump(file_name="transformer_log_2.json", noise_level=2)

View File

@ -220,7 +220,7 @@ def train_convnext_example():
forward_time = 0.0
backward_time = 0.0
start = time.time()
for i in range(ITER_TIME):
for _ in range(ITER_TIME):
t1 = time.time()
y = model(x)
torch.cuda.synchronize()

View File

@ -130,7 +130,6 @@ def run_torchrec_row_wise_even_sharding_example(rank, world_size):
# manually create the embedding table's local shards
num_embeddings = 8
embedding_dim = 16
emb_table_shape = torch.Size([num_embeddings, embedding_dim])
# tensor shape
local_shard_shape = torch.Size(
[num_embeddings // world_size, embedding_dim] # (local_rows, local_cols)
@ -270,7 +269,7 @@ def run_torchrec_table_wise_sharding_example(rank, world_size):
device = torch.device(device_type)
# note: without initializing this mesh, the following local_tensor will be put on
# device cuda:0.
device_mesh = init_device_mesh(device_type=device_type, mesh_shape=(world_size,))
init_device_mesh(device_type=device_type, mesh_shape=(world_size,))
# manually create the embedding table's local shards
num_embeddings = 8
@ -293,8 +292,6 @@ def run_torchrec_table_wise_sharding_example(rank, world_size):
else torch.empty(0, device=device)
)
table_to_local_tensor[i] = local_tensor
# tensor shape
local_shard_shape = local_tensor.shape
# tensor offset
local_shard_offset = torch.Size((0, 0))
# wrap local shards into a wrapper

View File

@ -1087,7 +1087,6 @@ class _RoundRobinLoadBalancer(_LoadBalancer):
), "The current implementation only works if ROUND_ROBIN_CYCLE is 2."
buffer = buffer.contiguous()
cp_world_size = mesh.size()
cp_rank = mesh.get_local_rank()
all_buffers = [torch.empty_like(buffer) for _ in range(cp_world_size)]
ft_c.all_gather_inplace(all_buffers, buffer, mesh)

View File

@ -279,7 +279,6 @@ def _nll_loss_forward_handler(
ignore_index = cast(int, args[4])
channel_dim = 1 if x.dim() >= 2 else 0
channel_dim_size = x.shape[channel_dim]
spec = x._spec
mesh_dim = _find_all_reduce_mesh_dim(spec.placements, channel_dim)

View File

@ -48,7 +48,6 @@ class Kumaraswamy(TransformedDistribution):
self.concentration1, self.concentration0 = broadcast_all(
concentration1, concentration0
)
finfo = torch.finfo(self.concentration0.dtype)
base_dist = Uniform(
torch.full_like(self.concentration0, 0),
torch.full_like(self.concentration0, 1),

View File

@ -312,7 +312,6 @@ class Wishart(ExponentialFamily):
def entropy(self):
nu = self.df # has shape (batch_shape)
p = self._event_shape[-1] # has singleton shape
V = self.covariance_matrix # has shape (batch_shape x event_shape)
return (
(p + 1)
* (

View File

@ -921,7 +921,7 @@ def _verify_stack_trace(graph_module: torch.fx.GraphModule) -> None:
- None or non-empty str for 'call_function', 'get_attr'
- None for 'placeholder', 'output'
"""
for i, mod in enumerate([graph_module] + list(graph_module.modules())):
for mod in [graph_module, *graph_module.modules()]:
if not isinstance(mod, torch.fx.GraphModule):
continue
for node in graph_module.graph.nodes:

View File

@ -55,7 +55,7 @@ def get_node_context(node, num_nodes=2) -> str:
"""
node_contexts = []
cur = node
for i in range(num_nodes):
for _ in range(num_nodes):
node_contexts.append(cur.format_node())
if cur.op == "root":
break

View File

@ -463,8 +463,6 @@ class Partitioner:
# Check if no device is left
if len(self.partitions) == len(self.devices):
# No device is left
# Put the previous partitions into a list (non_single_node_partitions)
non_single_node_partitions = self.partitions[:]
# Create the first single node partition for the current node
self.create_single_node_partition(node)
continue

View File

@ -175,7 +175,6 @@ def get_attr_inference_rule(n: Node, traced):
The most representitive type we have is "Dyn" but the system
can be extended with more types, such as a type to represent shapes
"""
attr_node = n.args[0]
attr_name = n.args[1]
if attr_name == "shape":

View File

@ -227,7 +227,7 @@ class MetaTracer(torch.fx.Tracer):
def path_of_module(self, mod: torch.nn.Module) -> str:
try:
return super().path_of_module(mod)
except NameError as e:
except NameError:
if self.allow_insert_stateless_mods and len(list(mod.parameters())) == 0 and len(list(mod.buffers())) == 0:
path = self._insert_module_as_submodule(mod)
self.prev_module = path

View File

@ -251,7 +251,7 @@ def generate_binconstraint_t(constraint, counter):
disj = [BinConstraintT(constraint.lhs, Dyn, op_eq)]
for i in range(1, constraint.rhs + 1):
dims = []
for j in range(1, i + 1):
for _ in range(1, i + 1):
dim_var, counter = gen_dvar(counter)
dims.append(dim_var)
disj.append(BinConstraintT(constraint.lhs, TensorType(dims), op_eq))

View File

@ -193,7 +193,7 @@ def gen_mkl_autotuner(example_inputs, iters=10, warmup=1):
f()
begin = time.time()
for _ in range(iters):
out = f()
f()
return time.time() - begin
mkl_time = benchmark(lambda: [i.to_dense() for i in submodule(*[i.to_mkldnn() for i in sample_inputs])])
@ -278,7 +278,7 @@ def optimize_for_inference(
cur_tracer = tracer()
fx_graph = cur_tracer.trace(copy.deepcopy(model))
fx_model = fx.GraphModule(cur_tracer.root, fx_graph)
fx.GraphModule(cur_tracer.root, fx_graph)
modules: Dict[str, nn.Module] = dict(model.named_modules())
class MklSupport(Enum):

View File

@ -279,10 +279,8 @@ def get_latency_of_partitioned_graph(
def dfs_helper(partition: Partition, latency_so_far_sec: float) -> float:
"""This function helps to recursively get the latency of a path of partitions"""
# Update latency by adding current partition's latency
latency_so_far_sec += partition_to_latency_mapping[
partition
].overall_latency_sec
children = partition.children
latency_so_far_sec += partition_to_latency_mapping[partition].overall_latency_sec
if partition.children:
max_latency_sec = 0.0
for child in partition.children:

View File

@ -1176,11 +1176,11 @@ def dispatch_trace(
def wrap_key(
f: Callable[_P, R], tensors: _P.args, tracer: _ProxyTracer, pre_dispatch: bool
) -> Callable[_P, R]:
flat_tensors, tensors_spec = pytree.tree_flatten(tensors)
flat_tensors, _tensors_spec = pytree.tree_flatten(tensors)
@functools.wraps(f)
def wrapped(*proxies: _P.args, **_unused: _P.kwargs) -> R:
flat_proxies, proxies_spec = pytree.tree_flatten(proxies)
flat_proxies, _proxies_spec = pytree.tree_flatten(proxies)
assert len(flat_proxies) == len(flat_tensors)
with disable_proxy_modes_tracing() as m:
assert isinstance(m, ProxyTorchDispatchMode)
@ -1733,7 +1733,7 @@ class _ModuleStackTracer(PythonKeyTracer):
try:
return Tracer.call_module(self, m, forward, args, kwargs)
except _ModuleNotInstalledAsSubmoduleError as e:
except _ModuleNotInstalledAsSubmoduleError:
warnings.warn(
f"Unable to find the path of the module {m}. "
"This might be because the module was not properly registered "

View File

@ -331,7 +331,7 @@ def replay_shape_env_events(events):
# We need to call create_mapping_fn every time, since the node list might
# change after each event is replayed.
event.run(shape_env)
except Exception as e:
except Exception:
log.error("failed when running event: %s", event)
raise

View File

@ -3879,8 +3879,6 @@ class ShapeEnv:
guess
"""
source_name = source.name() if source else None
if self._translation_validation_enabled and source is not None:
# Create a new symbol for this source.
symbol = self._create_symbol_for_source(source)
@ -3919,8 +3917,6 @@ class ShapeEnv:
source: Optional[Source] = None,
) -> Union[float, SymFloat]:
"""Create a SymFloat value from a symbolic expression"""
source_name = source.name() if source else None
if self._translation_validation_enabled and source is not None:
# Create a new symbol for this source.
symbol = self._create_symbol_for_source(source)
@ -4808,7 +4804,7 @@ class ShapeEnv:
res = f"{source_ref(source)} == {sexpr}"
exprs.append(res)
if (s0 := self.source_to_var.get(srcname)) is not None:
if source != (canonical_source := self.var_to_sources[s0][0]):
if source != self.var_to_sources[s0][0]:
verbose_exprs.append(
f"{res} # duck sizing added this equality because these "
f"variables had the same size {self.var_to_val[s0]} "
@ -6140,7 +6136,6 @@ class ShapeEnv:
# If an error is raised before the end of this function, we remove the FX node
# inserted, and re-raise the error.
guard = None
tb = None
try:
if orig_expr.is_number:

View File

@ -16,7 +16,7 @@ class Dispatcher:
self.ordering = ordering(self.funcs)
def __call__(self, *args, **kwargs):
func, s = self.resolve(args)
func, _ = self.resolve(args)
return func(*args, **kwargs)
def resolve(self, args):

View File

@ -310,7 +310,6 @@ def _parse_stack_trace(stack_trace: str):
# stacktrace should have innermost frame last, so we
# iterate backwards to find the first line that starts
# with 'File '
summary_str = ""
for idx in range(len(lines) - 2, -1, -1):
line = lines[idx].strip()
matches = pattern.match(line)
@ -463,10 +462,10 @@ class CodeGen:
return s
return f
yellow = make_wrapper_func("yellow")
cyan = make_wrapper_func("cyan")
yellow = make_wrapper_func("yellow") # noqa: F841
cyan = make_wrapper_func("cyan") # noqa: F841
red = make_wrapper_func("red")
green = make_wrapper_func("green")
green = make_wrapper_func("green") # noqa: F841
dim_green = make_wrapper_func("dim_green")
dim = make_wrapper_func("dim")
dim_blue = make_wrapper_func("dim_blue")

View File

@ -126,7 +126,7 @@ def check_for_mutable_operation(target : Callable, args : Tuple['Argument', ...]
try:
candidate_signature.bind(*args, **kwargs)
matched_schemas.append((candidate_signature, schema))
except TypeError as e:
except TypeError:
continue
def throw_if_mutable(schema):
@ -214,7 +214,7 @@ def create_type_hint(x):
else:
return ret_type(Any)
return ret_type(base_type)
except Exception as e:
except Exception:
# We tried to create a type hint for list but failed.
warnings.warn(f"We were not able to successfully create type hint from the type {x}")
return x
@ -328,7 +328,7 @@ def normalize_function(
try:
candidate_signature.bind(*args, **kwargs)
matched_schemas.append(candidate_signature)
except TypeError as e:
except TypeError:
continue
if len(matched_schemas) == 0:
@ -349,7 +349,7 @@ def normalize_function(
for arg_name, arg_type in bound_types.arguments.items():
param = candidate_signature.parameters[arg_name]
sig_matches = sig_matches and type_matches(param.annotation, arg_type)
except TypeError as e:
except TypeError:
sig_matches = False
if sig_matches:
new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(candidate_signature, args, kwargs,

View File

@ -107,7 +107,6 @@ def tensorify_python_scalars(
placeholders = set()
for node in graph.nodes:
if node.op != "placeholder":
first_none_placeholder = node
break
else:
placeholders.add(node)

View File

@ -58,7 +58,6 @@ def get_size_of_all_nodes(
# Mark shape and dtype for each node (node.shape and node.dtype)
ShapeProp(fx_module).propagate(*args)
# Calculate the total size of the whole fx graph
total_size_of_graph = 0.0
for node in fx_module.graph.nodes:
if node.op == "output":
break
@ -92,7 +91,7 @@ def get_size_of_node(fx_module: GraphModule, node: Node) -> size_bytes:
submodule = submodule_dict[node.target]
parameters = submodule.named_parameters()
# Parameters are named tuples
for name, p in parameters:
for _name, p in parameters:
total_num_of_elems += p.numel()
# Don't forget the output size
# node.shape is the shape of this node's output

View File

@ -31,7 +31,7 @@ def inplace_wrapper(fn: Callable) -> Callable:
@wraps(fn)
def wrapped_fn(gm):
val = fn(gm)
fn(gm)
return gm
return wrapped_fn

View File

@ -487,7 +487,7 @@ def reinplace(gm, *sample_args):
# inplace-ify functional ops, subject to the constraints written below.
all_later_view_inverse_nodes_to_delete = set()
for idx, node in enumerate(gm.graph.nodes):
for node in gm.graph.nodes:
if node.op == 'call_function':
# Today, the re-inplace pass on directly acts on:
@ -532,7 +532,6 @@ def reinplace(gm, *sample_args):
continue
# Step 1b: ensure that the op we're trying to re-inplace isn't a program input
self_arg_name = self_arg.name
self_arg_storage = StorageWeakRef(self_arg.meta['fake_result']._typed_storage())
if self_arg_storage in input_storages:
# TODO: later, add the optimization for handling `copy_()` calls in the graph.
@ -578,7 +577,7 @@ def reinplace(gm, *sample_args):
remaining_slice_args = node.args[2:]
slice_node = gm.graph.create_node(
'call_function', view_op, (self_arg,) + tuple(remaining_slice_args), node.kwargs)
copy_node = gm.graph.create_node(
gm.graph.create_node(
'call_function', torch.ops.aten.copy_.default, (slice_node, mutated_slice_node,), {})
# Add the slice_scatter node to our "nodes to delete" list.
all_later_view_inverse_nodes_to_delete.add(node)
@ -612,8 +611,6 @@ def reinplace(gm, *sample_args):
new = old.args[0]
nodes_to_update = [n for n in old.users if n.meta['node_idx'] > node.meta['node_idx']]
for node_to_update in nodes_to_update:
new_args = []
args = node_to_update.args
def replace_arg(a):
if a == old:
@ -654,7 +651,6 @@ def reinplace(gm, *sample_args):
x._typed_storage()
) for x in new_flattened_res if isinstance(x, FakeTensor)}
assert len(new_res_storage) == 1
(old_ref,) = old_res_storage
(new_ref,) = new_res_storage
(node_ref,) = node_res_storage
# Technically, "old_ref" and all its aliases will remain

View File

@ -526,7 +526,7 @@ def split_module(
for node in original_order:
if node in already_constructed_attr_nodes:
continue # already added this attr to the base graph
base_mod_env, based_mod_attrs = construct_graph(
base_mod_env, _based_mod_attrs = construct_graph(
node, base_mod_env, base_mod_attrs
)
already_constructed_attr_nodes.add(node)

View File

@ -80,7 +80,7 @@ class _SplitterSettingBase:
"we might not care about non-tensor data flow and we can set this option "
"to true to disable the functionality that prevent non-tensor data flow.",
)
args, unknown = parser.parse_known_args()
args, _unknown = parser.parse_known_args()
self.min_acc_module_size: int = args.min_acc_module_size if args.min_acc_module_size else min_acc_module_size
self.skip_fusion: bool = args.skip_fusion if args.skip_fusion else skip_fusion
@ -882,7 +882,7 @@ class _SplitterBase:
def generate_split_results(self) -> SplitResult:
split_module = self()
submodule_names = []
for name, mod in split_module.named_children():
for name, _mod in split_module.named_children():
submodule_names.append(name)
if (
self.settings.max_acc_splits > 0

View File

@ -279,7 +279,7 @@ def _replace_pattern(
match_changed_node: Dict[Node, Node] = {}
match_and_replacements = []
for i, match in enumerate(_matches):
for match in _matches:
if replacement_callback is not None:
replacement_graph = replacement_callback(match, original_graph, pattern_graph)
else:

View File

@ -720,7 +720,7 @@ def download_url_to_file(
# We deliberately do not use NamedTemporaryFile to avoid restrictive
# file permissions being applied to the downloaded file.
dst = os.path.expanduser(dst)
for seq in range(tempfile.TMP_MAX):
for _ in range(tempfile.TMP_MAX):
tmp_dst = dst + "." + uuid.uuid4().hex + ".partial"
try:
f = open(tmp_dst, "w+b")

View File

@ -570,10 +570,6 @@ def create_script_module_impl(nn_module, concrete_type, stubs_fn):
method_stubs = stubs_fn(nn_module)
property_stubs = get_property_stubs(nn_module)
hook_stubs, pre_hook_stubs = get_hook_stubs(nn_module)
user_annotated_ignored_attributes = getattr(
nn_module, "__jit_ignored_attributes__", []
)
ignored_properties = jit_ignored_properties(nn_module)
def init_fn(script_module):
@ -838,9 +834,6 @@ def infer_methods_to_compile(nn_module):
(TODO add a link when the rules are published).
"""
check_module_initialized(nn_module)
user_annotated_ignored_attributes = getattr(
nn_module, "__jit_ignored_attributes__", []
)
ignored_properties = jit_ignored_properties(nn_module)
methods: List[str] = []

View File

@ -1600,7 +1600,7 @@ def _recursive_compile_class(obj, loc):
_qual_name = _qualified_name(obj)
# We're starting a new compilation, so update the error call stack in
# case it fails
error_stack = torch._C.CallStack(_qual_name, loc)
error_stack = torch._C.CallStack(_qual_name, loc) # noqa: F841
rcb = _jit_internal.createResolutionCallbackForClassMethods(obj)
return _compile_and_register_class(obj, rcb, _qual_name)
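
The error_stack binding above is kept with a suppression rather than removed, presumably because torch._C.CallStack must stay alive while the class is compiled: the object's lifetime, not its value, is what matters. A sketch of that kind of "unused but load-bearing" binding, using a hypothetical ScopedTimer guard:

import time

class ScopedTimer:
    # Hypothetical RAII-style guard: the work happens in __init__/__del__.
    def __init__(self, label):
        self.label = label
        self.start = time.perf_counter()

    def __del__(self):
        print(f"{self.label} took {time.perf_counter() - self.start:.6f}s")

def compile_class():
    timer = ScopedTimer("compile")  # noqa: F841  (kept alive for its scope)
    return sum(i * i for i in range(1000))

compile_class()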

View File

@ -255,7 +255,6 @@ def pool2d_shape_check(
outputWidth: int,
):
ndim = len(input)
nOutputPlane = nInputPlane
assert kW > 0 and kH > 0
assert dW > 0 and dH > 0
@ -608,12 +607,10 @@ def matmul(tensor1: List[int], tensor2: List[int]):
# We are multiplying b1 x n x m1 by x2 x m2 x p (where b1 can be a list);
# we track m1 vs m2 separately even though they must match for nicer error messages
n = tensor1[-2] if dim_tensor1 > 1 else 1
m1 = tensor1[-1]
batch_tensor1: List[int] = []
# TODO: handling of slice
for i in range(dim_tensor1 - 2):
batch_tensor1.append(tensor1[i])
m2 = tensor2[-1] if dim_tensor2 > 1 else 1
p = tensor2[-1]
batch_tensor2: List[int] = []
# TODO: handling of slice

View File

@ -55,7 +55,6 @@ def _create_interpreter_name_lookup_fn(frames_up=1):
i += 1
f_locals = frame.f_locals
f_globals = frame.f_globals
for k, v in f_locals.items():
if isinstance(v, torch.Tensor) and var is v:
@ -136,7 +135,7 @@ class ONNXTracedModule(torch.nn.Module):
else:
return tuple(out_vars)
graph, out = torch._C._create_graph_by_tracing(
graph, _out = torch._C._create_graph_by_tracing(
wrapper,
in_vars + module_state,
_create_interpreter_name_lookup_fn(),
@ -241,7 +240,6 @@ def verify(model, args, loss_fn=torch.sum, devices=None):
if not isinstance(args, tuple):
args = (args,)
saved_args = _clone_inputs(args)
if is_module:
saved_state = copy.deepcopy(model.state_dict())

View File

@ -40,7 +40,7 @@ def _gen_unsupported_methods_properties():
scope: Dict[str, Any] = {}
execWrapper(funcs_str, globals(), scope)
try:
cu = torch.jit.CompilationUnit(funcs_str)
torch.jit.CompilationUnit(funcs_str)
except Exception as e:
if "nonexistent attribute" not in repr(e):
continue

View File

@ -1784,7 +1784,6 @@ def normalize(
) -> Tensor:
if dtype is None:
dtype = input.dtype
dim_ = _canonical_dim(dim, input.ndim)[0]
# TODO: eliminate mask_input as unnecessary when using masked divide.
mask_input = _combine_input_and_mask(sum, input, mask)
if mask_input.layout == torch.strided:

View File

@ -375,7 +375,7 @@ def ones_like(func, *args, **kwargs):
@register_dispatch_func([torch.ops.aten._softmax_backward_data])
def _softmax_backward_data(func, *args, **kwargs):
_check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=4)
grad, output, dim, input_dtype = args
grad, output, dim, _input_dtype = args
if is_masked_tensor(grad) and is_masked_tensor(output):
if not _masks_match(grad, output):
raise ValueError(

View File

@ -96,8 +96,8 @@ def _binary_helper(fn, args, kwargs, inplace):
"Input masks must match. If you need support for this, please open an issue on Github."
)
data_args, data_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_data())
mask_args, mask_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_mask())
data_args, _data_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_data())
mask_args, _mask_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_mask())
args0_layout = data_args[0].layout
same_layout = (

View File

@ -120,8 +120,12 @@ def _unary_helper(fn, args, kwargs, inplace):
"MaskedTensor unary ops do not support additional Tensor arguments"
)
mask_args, mask_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x._masked_mask)
data_args, data_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x._masked_data)
mask_args, _mask_kwargs = _map_mt_args_kwargs(
args, kwargs, lambda x: x._masked_mask
)
data_args, _data_kwargs = _map_mt_args_kwargs(
args, kwargs, lambda x: x._masked_data
)
if args[0].layout == torch.sparse_coo:
data_args[0] = data_args[0].coalesce()

View File

@ -33,7 +33,7 @@ class Pool(multiprocessing.pool.Pool):
Bring the number of pool processes up to the specified number, for use after
reaping workers which have exited.
"""
for i in range(self._processes - len(self._pool)):
for _ in range(self._processes - len(self._pool)):
# changed worker -> clean_worker
args = (
self._inqueue,

View File

@ -664,7 +664,7 @@ def _softmax_default(func, *args, **kwargs):
new_kwargs["dim"],
reduce_on_batch,
reduce_on_ragged,
reduce_on_non_batch,
_reduce_on_non_batch,
) = _wrap_jagged_dims(
inp.dim(),
(new_kwargs["dim"],),
@ -985,7 +985,7 @@ def matmul_default(func, *args, **kwargs):
def _padded_impl(a, b):
assert a.is_nested and not b.is_nested
nt, t = a, b
nt = a
from .nested_tensor import nested_from_padded
@ -1588,7 +1588,7 @@ def mean_dim(func, *args, **kwargs):
new_kwargs["dim"],
reduce_on_batch,
reduce_on_ragged,
reduce_on_non_batch,
_reduce_on_non_batch,
) = _wrap_jagged_dims(
inp.dim(),
new_kwargs["dim"],

View File

@ -323,7 +323,6 @@ def _cumulative_and_max_seq_len_nnz(qkv: torch.Tensor) -> Tuple[torch.Tensor, in
cumulative_seqlen = (
qkv.lengths().cumsum(0).to(dtype=torch.int32, device=qkv.device)
)
batch_size = qkv.size(0)
max_seqlen = qkv._get_max_seqlen()
# TODO: Explore performance impact when compiling
n_elem = int(cumulative_seqlen[-1].item())
@ -750,10 +749,10 @@ def jagged_scaled_dot_product_attention(
(
attention,
logsumexp,
philox_seed,
philox_offset,
debug_attn_mask,
_logsumexp,
_philox_seed,
_philox_offset,
_debug_attn_mask,
) = torch.ops.aten._flash_attention_forward(
query_buffer_reshaped,
key_buffer_reshaped,

View File

@ -193,8 +193,6 @@ def _adjust_num_blocks_and_indices(
new_num_rows: int,
new_num_cols: int,
):
num_rows = indices.shape[-2]
num_columns = indices.shape[-1]
indices = indices[:, :, :new_num_rows, :new_num_cols]
num_blocks = num_blocks[:, :, :new_num_rows]
num_blocks = torch.where(num_blocks < new_num_cols, num_blocks, new_num_cols)

View File

@ -6236,7 +6236,7 @@ def multi_head_attention_forward(
#
if need_weights:
B, Nt, E = q.shape
_B, _Nt, E = q.shape
q_scaled = q * math.sqrt(1.0 / float(E))
assert not (

View File

@ -224,11 +224,7 @@ class CrossMapLRN2d(Function):
ctx.scale = ctx.scale or input.new()
output = input.new()
batch_size = input.size(0)
channels = input.size(1)
input_height = input.size(2)
input_width = input.size(3)
output.resize_as_(input)
ctx.scale.resize_as_(input)

View File

@ -2630,7 +2630,7 @@ class Module:
<class 'torch.Tensor'> (20L, 1L, 5L, 5L)
"""
for name, param in self.named_parameters(recurse=recurse):
for _name, param in self.named_parameters(recurse=recurse):
yield param
def named_parameters(
@ -2725,7 +2725,7 @@ class Module:
Yields:
Module: a child module
"""
for name, module in self.named_children():
for _name, module in self.named_children():
yield module
def named_children(self) -> Iterator[Tuple[str, "Module"]]:

View File

@ -1043,7 +1043,6 @@ class LSTM(RNNBase):
orig_input = input
# xxx: isinstance check needs to be in conditional for TorchScript to compile
batch_sizes = None
do_permute = False
num_directions = 2 if self.bidirectional else 1
real_hidden_size = self.proj_size if self.proj_size > 0 else self.hidden_size
if isinstance(orig_input, PackedSequence):

View File

@ -1496,7 +1496,7 @@ class DistributedDataParallel(Module, Joinable):
# Disable the python reducer if compiled_autograd is not enabled.
if self._accum_grad_hooks:
for index, h in enumerate(self._accum_grad_hooks):
for h in self._accum_grad_hooks:
h.remove()
self._accum_grad_hooks.clear()

View File

@ -68,7 +68,7 @@ class _PyTreeExtensionContext:
def _register_huggingface_model_output_extension(self):
try:
from transformers import modeling_outputs # type: ignore[import]
except ImportError as e:
except ImportError:
return
def model_output_flatten(

View File

@ -61,7 +61,6 @@ def set_node_name(
new_name: The new name to use.
name_to_node_cache: A cache of node names to nodes.
"""
module = node.graph.owning_module
node_name_to_set = collections.deque([(node, new_name)])
while node_name_to_set:

View File

@ -84,12 +84,11 @@ class Functionalize(_pass.Transform):
out = function(*inputs_functional)
finally:
torch._disable_functionalization()
flat_inputs = pytree.tree_leaves(inputs)
flat_inputs_functional = pytree.tree_leaves(inputs_functional)
for inpt, input_functional in zip(flat_inputs, flat_inputs_functional):
for input_functional in flat_inputs_functional:
if isinstance(input_functional, torch.Tensor):
torch._sync(input_functional)
inpt_new = torch._from_functional_tensor(input_functional)
pytree.tree_map(torch._sync, out)
out_unwrapped = pytree.tree_map(torch._from_functional_tensor, out)
return out_unwrapped

View File

@ -139,7 +139,7 @@ class _ModuleMeta:
cls, raw_meta: _DYNAMO_NN_MODULE_META_TYPE
) -> _ModuleMeta:
"""Create a module meta from raw meta produced by FX dynamo tracer."""
module_name, (qualified_name, module_class) = raw_meta
module_name, (_qualified_name, module_class) = raw_meta
return _ModuleMeta(module_name.split("@")[0], module_class, raw_meta)
@classmethod

View File

@ -43,7 +43,7 @@ _SCALAR_TYPE_TENSOR_DTYPE_MAP: Mapping[type, torch.dtype] = {
def _try_getclosurevars(func):
try:
return inspect.getclosurevars(func)
except TypeError as e:
except TypeError:
return None

View File

@ -1988,7 +1988,7 @@ def _embedding_bag_helper(
# FIXME(justinchuby): We need to handle what happens when we call b.op on a node return
block_input_iter = utils._add_input_to_block(loop_block)
cond = utils._add_input_to_block(loop_block)
utils._add_input_to_block(loop_block)
indices_start = loop_context.op(
"Gather", offsets_starts, block_input_iter, axis_i=0

View File

@ -534,7 +534,7 @@ def stack(g: jit_utils.GraphContext, tensor_list, dim):
@_onnx_symbolic("aten::_unique2")
@symbolic_helper.parse_args("v", "i", "i", "i")
def _unique2(g: jit_utils.GraphContext, self, sorted, return_inverse, return_counts):
u, indices, inverse_indices, counts = g.op(
u, _indices, inverse_indices, counts = g.op(
"Unique", self, sorted_i=sorted, outputs=4
)
return u, inverse_indices, counts
@ -545,7 +545,7 @@ def _unique2(g: jit_utils.GraphContext, self, sorted, return_inverse, return_cou
def unique_dim(
g: jit_utils.GraphContext, self, dim, sorted, return_inverse, return_counts
):
u, indices, inverse_indices, counts = g.op(
u, _indices, inverse_indices, counts = g.op(
"Unique", self, axis_i=dim, sorted_i=sorted, outputs=4
)
return u, inverse_indices, counts
@ -945,7 +945,6 @@ def index(g: jit_utils.GraphContext, self, index):
@_onnx_symbolic("aten::index_fill")
def index_fill(g: jit_utils.GraphContext, self, dim, index, value):
dim_value = symbolic_helper._parse_arg(dim, "i")
expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper(
g, self, dim, index
)
@ -957,8 +956,7 @@ def index_fill(g: jit_utils.GraphContext, self, dim, index, value):
@_onnx_symbolic("aten::index_copy")
def index_copy(g: jit_utils.GraphContext, self, dim, index, source):
dim_value = symbolic_helper._parse_arg(dim, "i")
expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper(
_expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper(
g, self, dim, index
)
return scatter(g, self, dim, expanded_index, source)

View File

@ -346,8 +346,7 @@ def unfold(g: jit_utils.GraphContext, input, dimension, size, step):
loop_block = loop_context.block
block_input_iter = utils._add_input_to_block(loop_block)
# FIXME(justinchuby): cond is unused?
cond = utils._add_input_to_block(loop_block)
cond = utils._add_input_to_block(loop_block) # noqa: F841
starts = loop_context.op("Gather", low_indices, block_input_iter)
ends = loop_context.op("Gather", hi_indices, block_input_iter)

View File

@ -211,7 +211,7 @@ def tensor_split(
loop_block = loop_context.block
block_input_iter = utils._add_input_to_block(loop_block)
cond = utils._add_input_to_block(loop_block)
cond = utils._add_input_to_block(loop_block) # noqa: F841
final_splits = utils._add_input_to_block(loop_block)
start = loop_context.op(
@ -689,7 +689,7 @@ def repeat_interleave(
loop_block = loop_context.block
block_input_iter = utils._add_input_to_block(loop_block)
cond = utils._add_input_to_block(loop_block)
cond = utils._add_input_to_block(loop_block) # noqa: F841
final_splits = utils._add_input_to_block(loop_block)
r_split = loop_context.op("SequenceAt", r_splits, block_input_iter)

View File

@ -2955,7 +2955,6 @@ def index_put(g: jit_utils.GraphContext, self, indices_list_value, values, accum
@_onnx_symbolic("aten::index_fill")
def index_fill(g: jit_utils.GraphContext, self, dim, index, value):
dim_value = symbolic_helper._parse_arg(dim, "i")
expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper(
g, self, dim, index
)
@ -2968,8 +2967,7 @@ def index_fill(g: jit_utils.GraphContext, self, dim, index, value):
@_onnx_symbolic("aten::index_copy")
def index_copy(g: jit_utils.GraphContext, self, dim, index, source):
dim_value = symbolic_helper._parse_arg(dim, "i")
expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper(
_expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper(
g, self, dim, index
)
return scatter(g, self, dim, expanded_index, source)
@ -3674,14 +3672,14 @@ def new_full(
def eye(g: jit_utils.GraphContext, *args):
if len(args) == 5:
# aten::eye(n, dtype, layout, device, pin_memory)
n, dtype, layout, device, pin_memory = args
n, dtype, layout, device, _pin_memory = args
dim_size = symbolic_helper._unsqueeze_helper(g, n, [0])
shape = g.op("Concat", dim_size, dim_size, axis_i=0)
tensor = zeros(g, shape, dtype, layout, device)
return g.op("EyeLike", tensor)
if len(args) == 6:
# aten::eye(n, m, dtype, layout, device, pin_memory)
n, m, dtype, layout, device, pin_memory = args
n, m, dtype, layout, device, _pin_memory = args
shape = g.op(
"Concat",
symbolic_helper._unsqueeze_helper(g, n, [0]),
@ -5567,14 +5565,14 @@ def linalg_matrix_norm(
g, g.op("Abs", self), axes_i=[dim[0]], keepdims_i=keepdim
)
if ord_value > 0:
result, indices = max(
result, _indices = max(
g,
sum,
dim_or_y=g.op("Constant", value_t=torch.LongTensor([dim[1]])),
keepdim=keepdim,
)
else:
result, indices = min(
result, _indices = min(
g,
sum,
dim_or_y=g.op("Constant", value_t=torch.LongTensor([dim[1]])),
@ -6391,7 +6389,7 @@ def prim_loop(g: jit_utils.GraphContext, *inputs, **attrs) -> list[_C.Value]:
opset_version = GLOBALS.export_onnx_opset_version
old_blocks = tuple(node.blocks())
new_op_outputs, new_block_contexts, new_node = jit_utils.add_op_with_blocks(
_new_op_outputs, new_block_contexts, new_node = jit_utils.add_op_with_blocks(
g, "Loop", *inputs, outputs=node.outputsSize(), n_blocks=len(old_blocks)
)
@ -6500,7 +6498,7 @@ def prim_if(g: jit_utils.GraphContext, *inputs, **attrs) -> list[_C.Value]:
return final_b_list
else:
old_blocks = tuple(n.blocks())
new_op_outputs, new_block_contexts, new_node = jit_utils.add_op_with_blocks(
_new_op_outputs, new_block_contexts, new_node = jit_utils.add_op_with_blocks(
g, "If", *inputs, outputs=n.outputsSize(), n_blocks=len(old_blocks)
)

View File

@ -1066,7 +1066,7 @@ def _model_to_graph(
input_names=input_names,
module=module,
)
except Exception as e:
except Exception:
_C._jit_onnx_log("Torch IR graph at exception: ", graph)
raise
@ -1544,8 +1544,8 @@ def _export(
(
proto,
export_map,
val_use_external_data_format,
node_names,
_val_use_external_data_format,
_node_names,
) = graph._export_onnx( # type: ignore[attr-defined]
params_dict,
opset_version,
@ -1563,8 +1563,8 @@ def _export(
(
proto,
export_map,
val_use_external_data_format,
node_names,
_,
_,
) = graph._export_onnx( # type: ignore[attr-defined]
{},
opset_version,

View File

@ -1783,7 +1783,7 @@ def find_mismatch(
args = utils._decide_input_format(model, inputs_for_export)
model = utils._pre_trace_quant_model(model, args)
graph, params, torch_out, module = utils._create_jit_graph(model, args)
graph, params, _torch_out, _module = utils._create_jit_graph(model, args)
params_dict = utils._get_named_param_dict(graph, params)
utils._apply_friendly_debug_names(graph, params_dict)

View File

@ -247,8 +247,7 @@ class LRScheduler:
else:
values = self.get_lr()
for i, data in enumerate(zip(self.optimizer.param_groups, values)):
param_group, lr = data
for param_group, lr in zip(self.optimizer.param_groups, values):
if isinstance(param_group["lr"], Tensor):
param_group["lr"].fill_(lr)
else:
@ -1865,8 +1864,7 @@ class CosineAnnealingWarmRestarts(LRScheduler):
self.last_epoch = math.floor(epoch)
with _enable_get_lr_call(self):
for i, data in enumerate(zip(self.optimizer.param_groups, self.get_lr())):
param_group, lr = data
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
param_group["lr"] = lr
self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
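
The scheduler hunks fold the intermediate `data` tuple away by unpacking directly in the for target, which also removes the unused enumerate index. A standalone sketch with made-up param groups:

param_groups = [{"lr": 0.1}, {"lr": 0.01}]
values = [0.05, 0.005]

# Before:
#     for i, data in enumerate(zip(param_groups, values)):
#         param_group, lr = data
# After: one line, no unused `i`, no intermediate `data` tuple.
for param_group, lr in zip(param_groups, values):
    param_group["lr"] = lr

print(param_groups)  # [{'lr': 0.05}, {'lr': 0.005}]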

View File

@ -53,7 +53,7 @@ def demangle(name: str) -> str:
mangled name, irrespective of which PackageMangler created it.
"""
if is_mangled(name):
first, sep, last = name.partition(".")
_first, sep, last = name.partition(".")
# If there is only a base mangle prefix, e.g. '<torch_package_0>',
# then return an empty string.
return last if len(sep) != 0 else ""

View File

@ -49,6 +49,7 @@ class PackagePickler(_PyTorchLegacyPickler):
self.dispatch[FunctionType] = PackagePickler.save_global # type: ignore[assignment]
def save_global(self, obj, name=None):
# ruff: noqa: F841
# unfortunately the pickler code is factored in a way that
# forces us to copy/paste this function. The only change is marked
# CHANGED below.

View File

@ -89,7 +89,7 @@ class _ExtractModuleReferences(ast.NodeVisitor):
self.references[(name, alias)] = True
else:
self.references[(name, None)] = True
except Exception as e:
except Exception:
return

View File

@ -427,7 +427,7 @@ class PackageExporter:
def _import_module(self, module_name: str):
try:
return self.importer.import_module(module_name)
except ModuleNotFoundError as e:
except ModuleNotFoundError:
if not is_mangled(module_name):
raise
msg = (
@ -662,7 +662,7 @@ class PackageExporter:
memo: DefaultDict[int, str] = defaultdict(None)
memo_count = 0
# pickletools.dis(data_value)
for opcode, arg, pos in pickletools.genops(data_value):
for opcode, arg, _pos in pickletools.genops(data_value):
if pickle_protocol == 4:
if (
opcode.name == "SHORT_BINUNICODE"
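
pickletools.genops() yields (opcode, argument, position) triples; the position is unused in the loop above, so it gets an underscore-prefixed name. A runnable sketch over a small pickle:

import pickle
import pickletools

data_value = pickle.dumps({"key": "value"})
for opcode, arg, _pos in pickletools.genops(data_value):
    if opcode.name == "SHORT_BINUNICODE":
        print(arg)  # prints 'key' then 'value'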

View File

@ -463,7 +463,6 @@ class PackageImporter(Importer):
# note: copied from cpython's import code, with call to create module replaced with _make_module
def _do_find_and_load(self, name):
path = None
parent = name.rpartition(".")[0]
module_name_no_parent = name.rpartition(".")[-1]
if parent:
@ -475,7 +474,7 @@ class PackageImporter(Importer):
parent_module = self.modules[parent]
try:
path = parent_module.__path__ # type: ignore[attr-defined]
parent_module.__path__ # type: ignore[attr-defined]
except AttributeError:
# when we attempt to import a package only containing pybinded files,

View File

@ -192,7 +192,7 @@ def _extract_parameters_and_gradients(
def extract_parameters(node: _ProfilerEvent) -> Iterator[TensorKey]:
for p, p_grad in _extract_parameters_and_gradients(node):
for p, _p_grad in _extract_parameters_and_gradients(node):
if p is not None:
yield p

View File

@ -884,7 +884,7 @@ class ExecutionTraceObserver(_ITraceObserver):
for kernel_file in kernel_files:
if kernel_file is None:
continue
path, name = os.path.split(kernel_file)
name = os.path.basename(kernel_file)
dst = os.path.join(resource_dir, name)
shutil.copyfile(kernel_file, dst)
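
os.path.split() returns a (head, tail) pair, so when only the file name is needed, os.path.basename() expresses the intent without an unused head binding. A sketch with a hypothetical path:

import os

kernel_file = "/tmp/resources/add_kernel.cu"  # hypothetical path
# Before: `path, name = os.path.split(kernel_file)` left `path` unused.
name = os.path.basename(kernel_file)
dst = os.path.join("/tmp/trace_resources", name)
print(name, dst)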

View File

@ -16,7 +16,7 @@ def default_eval_fn(model, calib_data):
Default evaluation function takes a torch.utils.data.Dataset or a list of
input Tensors and run the model on the dataset
"""
for data, target in calib_data:
for data, _target in calib_data:
model(data)

View File

@ -1493,7 +1493,7 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
tar.extract("storages", path=tmpdir)
with open(os.path.join(tmpdir, "storages"), "rb", 0) as f:
num_storages = pickle_module.load(f, **pickle_load_args)
for i in range(num_storages):
for _ in range(num_storages):
args = pickle_module.load(f, **pickle_load_args)
key, location, storage_type = args
dtype = storage_type._dtype
@ -1527,7 +1527,7 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
num_tensors = pickle_module.load(f, **pickle_load_args)
for _ in range(num_tensors):
args = pickle_module.load(f, **pickle_load_args)
key, storage_id, original_tensor_type = args
key, storage_id, _original_tensor_type = args
storage = deserialized_objects[storage_id]
(ndim,) = struct.unpack("<i", f.read(4))
# skip next 4 bytes; legacy encoding treated ndim as 8 bytes

View File

@ -87,11 +87,11 @@ def sparse_semi_structured_from_dense_cutlass(dense):
if dense.dtype != torch.float:
ksparse = 4
dense_4 = dense.view(-1, k // ksparse, ksparse)
m0, m1, m2, m3 = (dense_4 != 0).unbind(-1)
m0, m1, _m2, m3 = (dense_4 != 0).unbind(-1)
else:
ksparse = 2
dense_2 = dense.view(-1, k // ksparse, ksparse)
m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1)
m0, _m2 = m1, m3 = (dense_2 != 0).unbind(-1)
meta_ncols = k // (ksparse * quadbits_per_meta_elem)
# Encoding quadruples of True/False values as follows:

Some files were not shown because too many files have changed in this diff.