Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
Remove unused Python variables in torch/[b-z]* (#136963)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136963
Approved by: https://github.com/ezyang
parent fb44658415
commit c0582fd0f8
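The change pattern throughout the diff below is mechanical: a local that is bound but never read is either deleted, rebound to an underscore-prefixed name, or kept with a `# noqa: F841` suppression, so unused-variable lints (pyflakes/flake8/ruff F841) stop firing while runtime behavior is unchanged. A minimal, hypothetical sketch of the three variants (illustrative only, not code taken from this PR):

```python
# Hypothetical illustration of the cleanup pattern applied by this PR;
# the names below are invented, not taken from the touched PyTorch files.

def conv_output_channels(weight_shape):
    # Unused parts of a tuple unpack get a leading underscore instead of a name.
    out_c, _in_c, _kern_h, _kern_w = weight_shape
    return out_c


def warm_up(fn, iters=3):
    # An unused loop variable becomes "_"; a value intentionally kept despite
    # being unused is suppressed with "# noqa: F841" instead of being deleted.
    for _ in range(iters):
        result = fn()  # noqa: F841
```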
@@ -266,7 +266,7 @@ def broadcast_shapes(shape1, shape2):
def get_conv_pool_shape(image_shape, args, out_ch, transpose):
-batch, in_c, in_h, in_w = image_shape
+batch, _in_c, in_h, in_w = image_shape
# TODO: Handle dilation
if args.dilation_h != 1 or args.dilation_w != 1:

@@ -443,7 +443,6 @@ class _NnapiSerializer:
operand_id = len(self.operands)
self.operands.append(toper)
tsize = tensor_size(toper.op_type, toper.shape)
psize = ((tsize - 1) | 0x3) + 1
self.values.append((operand_id, OperandValueSourceType.NUMBERED_BUFFER))
buf_num = len(self.used_weights)
offset = 0

@@ -917,7 +916,7 @@ class _NnapiSerializer:
adder(self, node)
def _identity(self, node):
-in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
+in_id, _in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
jitval = node.outputsAt(0)
self.jitval_operand_map[jitval] = in_id

@@ -1039,8 +1038,8 @@ class _NnapiSerializer:
in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
-start_ctype, start_dim = self.get_constant_value(node.inputsAt(1), "IntType")
-end_ctype, end_dim = self.get_constant_value(node.inputsAt(2), "IntType")
+_start_ctype, start_dim = self.get_constant_value(node.inputsAt(1), "IntType")
+_end_ctype, end_dim = self.get_constant_value(node.inputsAt(2), "IntType")
# channels last with channels == 1 or (height & width both 1)
is_trivial_flatten = len(in_oper.shape) == 4 and (

@@ -1526,7 +1525,7 @@ class _NnapiSerializer:
def add_pool2d_node(self, node, opcode):
assert node.inputsSize() == 6
assert node.outputsSize() == 1
-image, kernel, stride, padding, dilation, ceil_mode = node.inputs()
+image, kernel, stride, padding, dilation, _ceil_mode = node.inputs()
stride = stride or kernel

@@ -1574,7 +1573,7 @@ class _NnapiSerializer:
kernel,
stride,
padding,
-ceil_mode,
+_ceil_mode,
count_include_pad,
divisor_override,
) = node.inputs()

@@ -1673,7 +1672,7 @@ class _NnapiSerializer:
scale_ctype, scale_arg = self.get_constant_value(scale_jit) # type: ignore[possibly-undefined]
else:
scale_h_ctype, scale_h_arg = self.get_constant_value(scale_h_jit) # type: ignore[possibly-undefined]
-scale_w_ctype, scale_w_arg = self.get_constant_value(scale_w_jit) # type: ignore[possibly-undefined]
+scale_w_ctype, _scale_w_arg = self.get_constant_value(scale_w_jit) # type: ignore[possibly-undefined]
# The only way for the 4-argument overload of upsample_nearest2d to
# have been added to the graph without error is if the scale_h and

@@ -1892,7 +1891,7 @@ class _NnapiSerializer:
self.add_operation(NNAPI_OperationCode.FULLY_CONNECTED, inputs, outputs)
def get_optional_bias(self, jit_bias, weight_tensor, transpose=False):
-ctype, value = self.get_constant_value(jit_bias)
+ctype, _value = self.get_constant_value(jit_bias)
if ctype.kind() == "NoneType":
bias_idx = 1 if transpose else 0
nnapi_bias_tensor = torch.zeros(

@@ -1919,7 +1918,7 @@ class _NnapiSerializer:
) = node.inputs()
_, weight_tensor = self.get_constant_value(jit_weight, "TensorType")
-bias_id, bias_oper = self.get_optional_bias(jit_bias, weight_tensor)
+bias_id, _bias_oper = self.get_optional_bias(jit_bias, weight_tensor)
args = self.get_conv_pool_args_2d_from_jit(
weight_tensor.shape[2:4], jit_stride, jit_pad, jit_dilation, jit_groups
)

@@ -1958,7 +1957,7 @@ class _NnapiSerializer:
_, weight_tensor = self.get_constant_value(jit_weight, "TensorType")
_, transpose = self.get_constant_value(jit_transpose)
-bias_id, bias_oper = self.get_optional_bias(jit_bias, weight_tensor, transpose)
+bias_id, _bias_oper = self.get_optional_bias(jit_bias, weight_tensor, transpose)
args = self.get_conv_pool_args_2d_from_jit(
weight_tensor.shape[2:4], jit_stride, jit_pad, jit_dilation, jit_groups
)

@@ -1979,7 +1978,7 @@ class _NnapiSerializer:
assert node.inputsSize() == 3
assert node.outputsSize() == 1
-(jit_input, jit_dim, jit_half_to_float) = node.inputs()
+jit_input, jit_dim, _jit_half_to_float = node.inputs()
input_id, input_oper = self.get_tensor_operand_by_jitval_fixed_size(jit_input)
_, dim = self.get_constant_value(jit_dim, "IntType")

@@ -2117,7 +2116,7 @@ class _NnapiSerializer:
if depthwise:
# Depthwise convolution
-one, kern_h, kern_w, out_c = weight_oper.shape
+one, _kern_h, _kern_w, out_c = weight_oper.shape
assert one == 1
assert out_c % in_c == 0
channel_multiplier = out_c // in_c

@@ -2125,7 +2124,7 @@ class _NnapiSerializer:
assert out_c == in_c
else:
# Full convolution
-out_c, kern_h, kern_w, kern_d = weight_oper.shape
+out_c, _kern_h, _kern_w, kern_d = weight_oper.shape
assert kern_d == in_c
assert out_c == bias_oper.shape[0]
@@ -37,7 +37,7 @@ def visualize(graph, name_prefix='', pb_graph=None, executors_it=None):
return pb_graph
# Set up an input node
-input_node = pb_graph.node.add(op='input', name=name_prefix + 'input')
+pb_graph.node.add(op='input', name=name_prefix + 'input')
for i, value in enumerate(graph.param_node().outputs()):
value_map[value.unique()] = name_prefix + 'input:' + str(i)
@@ -186,7 +186,7 @@ def _check_capability():
work properly, but your PyTorch was compiled
with CUDA_VERSION %d. Please install the correct PyTorch binary
using instructions from https://pytorch.org
-"""
+""" # noqa: F841
old_gpu_warn = """
Found GPU%d %s which is of cuda capability %d.%d.

@@ -195,7 +195,7 @@ def _check_capability():
"""
if torch.version.cuda is not None: # on ROCm we don't want this check
-CUDA_VERSION = torch._C._cuda_getCompiledVersion()
+CUDA_VERSION = torch._C._cuda_getCompiledVersion() # noqa: F841
for d in range(device_count()):
capability = get_device_capability(d)
major = capability[0]
@@ -213,7 +213,6 @@ def segsum(data):
Args:
data: snapshot dictionary created from _snapshot()
"""
-segments = []
out = io.StringIO()
out.write(f"Summary of segments >= {Bytes(PAGE_SIZE)} in size\n")
total_reserved = 0

@@ -272,7 +271,6 @@ def segsum(data):
out.write(f'segments: {len(data["segments"])}\n')
out.write(f'total_reserved: {Bytes(total_reserved)}\n')
out.write(f'total_allocated: {Bytes(total_allocated)}\n')
-internal_external = f' ({Bytes(free_internal)} internal + {Bytes(free_external)} external)' if free_internal else ''
out.write(f'total_free: {_report_free(free_external, free_internal)}\n')
out.write(legend)
assert free_internal + free_external + total_allocated == total_reserved

@@ -478,10 +476,8 @@ def _profile_to_snapshot(profile):
kv_to_elem = {}
# create the device trace
-for time, action, (tensor_key, version), size in memory_profile.timeline:
+for _time, action, (tensor_key, version), size in memory_profile.timeline:
if not isinstance(tensor_key, TensorKey):
continue
if action == Action.CREATE:
@@ -357,11 +357,10 @@ def make_graphed_callables(
# Capture backward graphs in reverse order
per_callable_static_grad_outputs = []
per_callable_static_grad_inputs = []
-for static_input_surface, static_outputs, bwd_graph, module_params in zip(
+for static_input_surface, static_outputs, bwd_graph in zip(
reversed(per_callable_static_input_surfaces),
reversed(per_callable_static_outputs),
reversed(bwd_graphs),
-reversed(per_callable_module_params),
):
# For now, assumes all static_outputs require grad
# assert all(o.requires_grad for o in static_outputs), "Outputs of graphed callables must require grad."
@@ -245,7 +245,7 @@ def foreach_all_gather_copy_out(
param_all_gather_input_numels,
all_gather_input_split_sizes,
) = all_gather_result
-dtype, device = all_gather_output.dtype, all_gather_output.device
+_dtype, device = all_gather_output.dtype, all_gather_output.device
device_handle = _get_device_handle(device.type)
if all_gather_event is not None: # sync op
device_handle.current_stream().wait_event(all_gather_event)
@@ -82,8 +82,6 @@ class _ReplicateState(_State):
return
self.has_initialized = True
-device_mesh = kwargs.get("device_mesh", None)
self.module = module
ignored_params = {p for m in ignored_modules for p in m.parameters()}
for submodule in module.modules():
@@ -274,7 +274,7 @@ def shard_module(module: nn.Module, plan: ShardingPlan, src_rank=0, process_grou
mod, param_name, spec, src_rank=src_rank, process_group=process_group
)
elif isinstance(spec, Sharder):
-parent_mod_path, _, mod_name = name.rpartition(".")
+parent_mod_path, _, _mod_name = name.rpartition(".")
if name == "":
raise KeyError("Module path must not be empty for custom sharder!")
mod = module.get_submodule(name)
@@ -25,7 +25,6 @@ def binary_cmp(cmp_fun, types, args, kwargs=None, process_group=None):
if len(args) != 2:
raise ValueError(f"Expected two arguments for torch.{cmp_fun.__name__}")
-result = True
st1 = args[0]
st2 = args[1]
if not (isinstance(st1, ShardedTensor) and isinstance(st2, ShardedTensor)):
@@ -857,7 +857,7 @@ class ShardedTensor(ShardedTensorBase):
local_shards: List[Shard] = []
for shard_metadata in sharded_tensor_metadata.shards_metadata:
-rank, device = _parse_and_validate_remote_device(
+rank, _device = _parse_and_validate_remote_device(
process_group, shard_metadata.placement
)
if rank == current_rank:
@@ -34,7 +34,6 @@ def _reducer_allreduce_and_upcast_hook(
"""
ddp_weakref = hook_state.ddp_weakref
reducer, process_group = ddp_weakref().reducer, ddp_weakref().process_group
-gradient_is_bucket_view = ddp_weakref().gradient_as_bucket_view
# Cast bucket if different than param_dtype.
if (
ddp_weakref().mixed_precision.param_dtype

@@ -53,8 +52,7 @@ def _reducer_allreduce_and_upcast_hook(
ret_fut.set_result(bucket.buffer())
# Upcast parameters and gradients so optimizer step can run in fp32.
-params, grads = bucket.parameters(), bucket.gradients()
-for p, g in zip(params, grads):
+for p in bucket.parameters():
p.data = p._fp_param
# free storage for mp param as it will be allocated again in next
# forward pass.

@@ -70,7 +68,7 @@ def _reducer_allreduce_and_upcast_hook(
# they may participate in computation. However, they would not be recast
# by hook above as they don't have a grad hook installed, so cast them
# back here.
-for n, p in ddp_weakref().module.named_parameters():
+for _, p in ddp_weakref().module.named_parameters():
if hasattr(p, "_ddp_mp_hook_state"):
p._ddp_mp_hook_state[1].remove()
delattr(p, "_ddp_mp_hook_state")
@@ -87,7 +87,7 @@ def _retrieve_embedding_parameters(emb_rref):
def _print_header():
_print_cont("\n")
_print_cont("%10s" % "")
-for p in [50, 75, 90, 95]:
+for _ in [50, 75, 90, 95]:
_print_cont("%14s%10s" % ("sec/epoch", "epoch/sec"))
_print_cont("\n")

@@ -112,7 +112,6 @@ def _run_printable(cmd):
buffer = io.BytesIO()
torch.save(proc.stdout.decode("utf-8"), buffer)
input_tensor = torch.ByteTensor(list(buffer.getvalue()))
-input_length = torch.IntTensor([input_tensor.size(0)])
output = []
buffer = io.BytesIO(np.asarray(input_tensor).tobytes())

@@ -173,7 +172,7 @@ def _run_trainer(emb_rref_list, rank):
measurements = []
# Include warm-up cycles during training
-for epoch in range(100 + WARMUP_CYCLES):
+for _ in range(100 + WARMUP_CYCLES):
start = time.time()
batch_size = 0
@@ -178,7 +178,7 @@ def create_read_items_for_chunk_list(
dest_offsets = []
lengths = []
for (
-dim,
+_dim,
offset_for_saved_tensor,
offset_for_current_tensor,
length,
@@ -4581,7 +4581,7 @@ def split_group(
raise RuntimeError(
"No device associated with the default pg, not safe to split any process groups"
)
-default_backend, default_store = _world.pg_map[default_pg]
+_default_backend, default_store = _world.pg_map[default_pg]
global_rank = default_pg.rank()
global_world_size = default_pg.size()
@@ -245,7 +245,7 @@ class ChildFailedError(Exception):
def format_msg(self, boarder_delim="=", section_delim="-"):
title = f"{self.name} FAILED"
-root_rank, root_failure = self.get_first_failure()
+root_rank, _root_failure = self.get_first_failure()
root_failure_fmt: str = ""
other_failures_fmt: List[str] = []
@@ -38,7 +38,7 @@ def get_socket_with_port() -> socket.socket:
s.bind(("localhost", 0))
s.listen(0)
return s
-except OSError as e:
+except OSError:
s.close()
raise RuntimeError("Failed to create a socket")
@@ -274,7 +274,7 @@ def _named_parameters_with_duplicates(
kwargs["remove_duplicate"] = False
try:
ret = list(module.named_parameters(**kwargs))
-except AssertionError as e:
+except AssertionError:
kwargs.pop("remove_duplicate")
ret = list(module.named_parameters(**kwargs))
return ret
@@ -1017,10 +1017,10 @@ class FlatParamHandle:
sharded_flat_param_numel = unsharded_end_idx - unsharded_start_idx + 1
# `unsharded_param_start_idx` and `unsharded_param_end_idx` are indices
# into the unsharded flat parameter (inclusive) of the given parameter
-for i, (
+for (
(unsharded_param_start_idx, unsharded_param_end_idx),
is_padding,
-) in enumerate(zip(flat_param_offsets, self.flat_param._is_padding_mask)):
+) in zip(flat_param_offsets, self.flat_param._is_padding_mask):
if is_padding:
continue
in_sharded_flat_param = (

@@ -2201,8 +2201,8 @@ class FlatParamHandle:
else:
param.grad = None
assert flat_param._shared_params is not None
-for i, (param, (_, _, _, prim_param_name, prim_module, _)) in enumerate(
-zip(flat_param._shared_params, flat_param._shared_param_infos)
+for param, (_, _, _, prim_param_name, prim_module, _) in zip(
+flat_param._shared_params, flat_param._shared_param_infos
):
in_sharded_flat_param = hasattr(prim_module, prim_param_name)
if in_sharded_flat_param and param.requires_grad:
@@ -536,9 +536,7 @@ def _flatten_optim_state_dict(
else:
# Move the tensor in the original osd back to CPU to make the
# original osd unaffected.
-unflat_osd_state[fqn][state_name] = unflat_osd_state[fqn][
-state_name
-].cpu()
+unflat_osd_state[fqn][state_name] = param_state.cpu()
# Handle user-defined state, states that are not associated with parameters.
for key in all_state_keys:

@@ -1457,7 +1455,7 @@ def _unflatten_orig_param_states(
# gather the tensor on its TP dimension before chunking them into DTensor again.
if placement != Replicate():
placement_dim = placement.dim # type: ignore[attr-defined]
-value_local = value.redistribute(placements=(Replicate(),))
+value.redistribute(placements=(Replicate(),))
reshape_size = list(flat_param._shapes[param_idx])
reshape_size[placement_dim] *= value.device_mesh.size(0)
reshape_size = torch.Size(reshape_size)
@@ -297,7 +297,7 @@ def _full_pre_state_dict_hook(
``nn.Module``.
"""
if getattr(fsdp_state, "_device_mesh", False):
-root_mesh = _mesh_resources.get_root_mesh(fsdp_state._device_mesh)
+_mesh_resources.get_root_mesh(fsdp_state._device_mesh)
_common_pre_state_dict_hook(module, fsdp_state)
_common_unshard_pre_state_dict_hook(
@@ -104,7 +104,7 @@ def get_param_groups(
# but omits weights and any subgraphs connecting weights to this closure
inputs_closure, _ = reverse_closure(inputs, set(), reverse_edges_dict)
param_groups: Dict[Node, Dict[str, Set]] = dict() # keyed on intermediates
-for i, param in enumerate(params):
+for param in params:
closure, intersected = reverse_closure(
[param], inputs_closure, reverse_edges_dict
)
@@ -802,7 +802,6 @@ class Schedule1F1B(PipelineScheduleSingle):
# Chunk counters
fwd_mb_index = 0
bwd_mb_index = 0
weight_stage_mb_index = 0
# Warmup phase
send_work = None
@@ -719,7 +719,7 @@ class _PipelineStageBase(ABC):
"param_groups": param_groups,
"full_backward": False,
}
-weight_grads, _ = self.backward_maybe_with_nosync("weight", bwd_kwargs)
+self.backward_maybe_with_nosync("weight", bwd_kwargs)
else:
# TODO: figure out a better way to do this:
# if inputs does not require gradient,

@@ -1603,13 +1603,13 @@ def _validate_stage_shapes(pipeline_stages: List[PipelineStage]):
]
logger.debug(
-f"Rank: {pg_rank}" # noqa: G004
-f"Stage id: {stage_id}"
-f"Stage num stages: {stage.num_stages}"
-f"Stage rank: {rank}"
-f"Stage world size: {world_size}"
-f"Stage {virtual_id * world_size}-{(virtual_id + 1) * world_size - 1} input shapes: {stage_input_shapes}" # noqa: G003
-f"Stage {virtual_id * world_size}-{(virtual_id + 1) * world_size - 1} output shapes: {stage_output_shapes}" # noqa: G003
+f"Rank: {pg_rank}", # noqa: G004
+f"Stage id: {stage_id}",
+f"Stage num stages: {stage.num_stages}",
+f"Stage rank: {rank}",
+f"Stage world size: {world_size}",
+f"Stage {virtual_id * world_size}-{(virtual_id + 1) * world_size - 1} input shapes: {stage_input_shapes}", # noqa: G003
+f"Stage {virtual_id * world_size}-{(virtual_id + 1) * world_size - 1} output shapes: {stage_output_shapes}", # noqa: G003
)
all_inputs.extend(stage_input_shapes)
@@ -77,7 +77,7 @@ def found_inf_reduce_handler(
cast(List[object], op_info.local_args), op_info.args_tree_spec
)
local_tensor_args = cast(Tuple[object, ...], local_tensor_args)
-local_results = op_call(*local_tensor_args, **op_info.local_kwargs)
+op_call(*local_tensor_args, **op_info.local_kwargs)
grad_dtensor = cast(list[dtensor.DTensor], args[0])[0]
grad_placements = grad_dtensor.placements
@@ -21,9 +21,9 @@ def convolution_rules(op_schema: OpSchema) -> OutputSharding:
stride,
padding,
dilation,
-transposed,
-output_padding,
-groups,
+_transposed,
+_output_padding,
+_groups,
) = op_schema.args_schema
assert isinstance(input_spec, DTensorSpec)

@@ -37,7 +37,7 @@ def convolution_rules(op_schema: OpSchema) -> OutputSharding:
assert isinstance(padding, List)
assert isinstance(dilation, List)
assert isinstance(weight_shape, torch.Size)
-N, C_in, H_in, W_in = in_shape[0], in_shape[1], in_shape[2], in_shape[3]
+N, H_in, W_in = in_shape[0], in_shape[2], in_shape[3]
C_out = weight_shape[0]
H_out = (H_in + 2 * padding[0] - dilation[0] * (weight_shape[2] - 1) - 1) // stride[
0

@@ -73,13 +73,13 @@ def convolution_backward_rules(op_schema: OpSchema) -> OutputSharding:
input_spec,
weight_spec,
bias_shape_opt,
-stride,
-padding,
-dilation,
-transposed,
-output_padding,
-groups,
-output_mask,
+_stride,
+_padding,
+_dilation,
+_transposed,
+_output_padding,
+_groups,
+_output_mask,
) = op_schema.args_schema
assert isinstance(grad_output_spec, DTensorSpec)
@@ -479,7 +479,7 @@ def softmax_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
softmax_dim = normalize_dim(softmax_dim, input_strategy.ndim)
output_strategy = OpStrategy([])
-for idx, input_placement_strategy in enumerate(input_strategy.strategies):
+for input_placement_strategy in input_strategy.strategies:
redistribute_costs = []
input_src_spec = input_placement_strategy.output_spec

@@ -1038,8 +1038,6 @@ def layer_norm_bwd_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy
)
def topk_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
input_strategy = cast(OpStrategy, op_schema.args_schema[0])
-k = cast(int, op_schema.args_schema[1])
-input_shape = input_strategy.shape
topk_dim = (
cast(int, op_schema.args_schema[2]) if len(op_schema.args_schema) > 2 else -1
)
@@ -171,7 +171,6 @@ def scaled_dot_product_flash_attention_strategy(
q_input_strategy = op_schema.args_schema[0]
assert isinstance(q_input_strategy, OpStrategy)
# assuming q/k/v have the same shape
-qkv_shape = q_input_strategy.shape
single_mesh_dim_strategies = []

@@ -250,7 +249,6 @@ def scaled_dot_product_flash_attention_backward_strategy(
q_input_strategy = op_schema.args_schema[1]
assert isinstance(q_input_strategy, OpStrategy)
# assuming q/k/v have the same shape
-qkv_shape = q_input_strategy.shape
tensor_input_indices = [
i

@@ -344,7 +342,7 @@ def scaled_dot_product_efficient_attention_strategy(
q_input_strategy = op_schema.args_schema[0]
assert isinstance(q_input_strategy, OpStrategy)
# assuming q/k/v have the same shape
qkv_shape = q_input_strategy.shape
has_attn_bias = op_schema.args_schema[3] is not None
compute_log_sumexp = op_schema.args_schema[4]

@@ -418,15 +416,8 @@ def scaled_dot_product_efficient_attention_backward_strategy(
q_input_strategy = op_schema.args_schema[1]
assert isinstance(q_input_strategy, OpStrategy)
# assuming q/k/v have the same shape
-qkv_shape = q_input_strategy.shape
has_attn_bias = op_schema.args_schema[4] is not None
-tensor_input_indices = [
-i
-for i, arg_spec in enumerate(op_schema.args_schema)
-if isinstance(arg_spec, OpStrategy)
-]
single_mesh_dim_strategies = []
# placement list stores placements of [outputs, inputs]
@@ -367,7 +367,6 @@ def replica_only_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType
schema_info=RuntimeSchemaInfo(1),
)
def scatter_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
-input_strategy = cast(OpStrategy, op_schema.args_schema[0])
single_mesh_dim_strategies = []
# placement list stores placements of [output, input, index, src]
@@ -67,7 +67,7 @@ def _gen_transform_infos_non_cached(
# Handle multi-dim device mesh placement redistribution
# First, we need to build the logical shape for each mesh dim
# for correct allgathering uneven shards on each mesh dim (with dynamic padding)
-for i, (src, dst) in enumerate(zip(src_spec.placements, dst_spec.placements)):
+for i, src in enumerate(src_spec.placements):
current_logical_shape = mesh_dims_to_logical_shape[i]
if isinstance(src, Shard):
if i < device_mesh.ndim - 1:

@@ -192,7 +192,7 @@ def redistribute_local_tensor(
for transform_info in transform_infos:
i = transform_info.mesh_dim
current, target = transform_info.src_dst_placements
-num_chunks = device_mesh.size(mesh_dim=i)
+device_mesh.size(mesh_dim=i)
if current == target:
# short cut, just use the original local tensor

@@ -220,7 +220,6 @@ def redistribute_local_tensor(
elif target.is_shard():
# Case 2: target is Shard
target_placement = cast(Shard, target)
target_dim = target_placement.dim
if current.is_partial():
partial_spec = cast(Partial, current)
new_local_tensor = partial_spec._reduce_shard_value(
@@ -192,7 +192,6 @@ def tp_convolution_backward(
)
# step2 reconstruct local gradient output tensor
N, C_out, H_out, _ = grad_out_tensor.shape
padding_w = padding[1]
if rank == 0:
grad_out_tensor = torch.nn.functional.pad(
@@ -269,7 +269,7 @@ class CommDebugModeExample:
comm_mode = CommDebugMode()
with comm_mode:
-output = model(inp)
+model(inp)
# print the module level collective tracing information
print(comm_mode.generate_comm_debug_tracing_table(noise_level=0))

@@ -592,7 +592,7 @@ class CommDebugModeExample:
comm_mode = CommDebugMode()
with comm_mode:
-output = model(inp)
+model(inp)
# print the operation level collective tracing information
print(comm_mode.generate_comm_debug_tracing_table(noise_level=2))

@@ -628,7 +628,7 @@ class CommDebugModeExample:
comm_mode = CommDebugMode()
with comm_mode:
-output = model(inp)
+model(inp)
comm_mode.generate_json_dump(file_name="transformer_log.json", noise_level=1)
comm_mode.generate_json_dump(file_name="transformer_log_2.json", noise_level=2)
@@ -220,7 +220,7 @@ def train_convnext_example():
forward_time = 0.0
backward_time = 0.0
start = time.time()
-for i in range(ITER_TIME):
+for _ in range(ITER_TIME):
t1 = time.time()
y = model(x)
torch.cuda.synchronize()
@@ -130,7 +130,6 @@ def run_torchrec_row_wise_even_sharding_example(rank, world_size):
# manually create the embedding table's local shards
num_embeddings = 8
embedding_dim = 16
-emb_table_shape = torch.Size([num_embeddings, embedding_dim])
# tensor shape
local_shard_shape = torch.Size(
[num_embeddings // world_size, embedding_dim] # (local_rows, local_cols)

@@ -270,7 +269,7 @@ def run_torchrec_table_wise_sharding_example(rank, world_size):
device = torch.device(device_type)
# note: without initializing this mesh, the following local_tensor will be put on
# device cuda:0.
-device_mesh = init_device_mesh(device_type=device_type, mesh_shape=(world_size,))
+init_device_mesh(device_type=device_type, mesh_shape=(world_size,))
# manually create the embedding table's local shards
num_embeddings = 8

@@ -293,8 +292,6 @@ def run_torchrec_table_wise_sharding_example(rank, world_size):
else torch.empty(0, device=device)
)
table_to_local_tensor[i] = local_tensor
# tensor shape
local_shard_shape = local_tensor.shape
# tensor offset
local_shard_offset = torch.Size((0, 0))
# wrap local shards into a wrapper
@@ -1087,7 +1087,6 @@ class _RoundRobinLoadBalancer(_LoadBalancer):
), "The current implementation only works if ROUND_ROBIN_CYCLE is 2."
buffer = buffer.contiguous()
cp_world_size = mesh.size()
-cp_rank = mesh.get_local_rank()
all_buffers = [torch.empty_like(buffer) for _ in range(cp_world_size)]
ft_c.all_gather_inplace(all_buffers, buffer, mesh)
@@ -279,7 +279,6 @@ def _nll_loss_forward_handler(
ignore_index = cast(int, args[4])
channel_dim = 1 if x.dim() >= 2 else 0
channel_dim_size = x.shape[channel_dim]
spec = x._spec
mesh_dim = _find_all_reduce_mesh_dim(spec.placements, channel_dim)
@@ -48,7 +48,6 @@ class Kumaraswamy(TransformedDistribution):
self.concentration1, self.concentration0 = broadcast_all(
concentration1, concentration0
)
-finfo = torch.finfo(self.concentration0.dtype)
base_dist = Uniform(
torch.full_like(self.concentration0, 0),
torch.full_like(self.concentration0, 1),
@@ -312,7 +312,6 @@ class Wishart(ExponentialFamily):
def entropy(self):
nu = self.df # has shape (batch_shape)
p = self._event_shape[-1] # has singleton shape
V = self.covariance_matrix # has shape (batch_shape x event_shape)
return (
(p + 1)
* (
@@ -921,7 +921,7 @@ def _verify_stack_trace(graph_module: torch.fx.GraphModule) -> None:
- None or non-empty str for 'call_function', 'get_attr'
- None for 'placeholder', 'output'
"""
-for i, mod in enumerate([graph_module] + list(graph_module.modules())):
+for mod in [graph_module, *graph_module.modules()]:
if not isinstance(mod, torch.fx.GraphModule):
continue
for node in graph_module.graph.nodes:
@@ -55,7 +55,7 @@ def get_node_context(node, num_nodes=2) -> str:
"""
node_contexts = []
cur = node
-for i in range(num_nodes):
+for _ in range(num_nodes):
node_contexts.append(cur.format_node())
if cur.op == "root":
break
@@ -463,8 +463,6 @@ class Partitioner:
# Check if no device is left
if len(self.partitions) == len(self.devices):
# No device is left
-# Put the previous partitions into a list (non_single_node_partitions)
-non_single_node_partitions = self.partitions[:]
# Create the first single node partition for the current node
self.create_single_node_partition(node)
continue
@@ -175,7 +175,6 @@ def get_attr_inference_rule(n: Node, traced):
The most representitive type we have is "Dyn" but the system
can be extended with more types, such as a type to represent shapes
"""
-attr_node = n.args[0]
attr_name = n.args[1]
if attr_name == "shape":
@@ -227,7 +227,7 @@ class MetaTracer(torch.fx.Tracer):
def path_of_module(self, mod: torch.nn.Module) -> str:
try:
return super().path_of_module(mod)
-except NameError as e:
+except NameError:
if self.allow_insert_stateless_mods and len(list(mod.parameters())) == 0 and len(list(mod.buffers())) == 0:
path = self._insert_module_as_submodule(mod)
self.prev_module = path
@@ -251,7 +251,7 @@ def generate_binconstraint_t(constraint, counter):
disj = [BinConstraintT(constraint.lhs, Dyn, op_eq)]
for i in range(1, constraint.rhs + 1):
dims = []
-for j in range(1, i + 1):
+for _ in range(1, i + 1):
dim_var, counter = gen_dvar(counter)
dims.append(dim_var)
disj.append(BinConstraintT(constraint.lhs, TensorType(dims), op_eq))
@@ -193,7 +193,7 @@ def gen_mkl_autotuner(example_inputs, iters=10, warmup=1):
f()
begin = time.time()
for _ in range(iters):
-out = f()
+f()
return time.time() - begin
mkl_time = benchmark(lambda: [i.to_dense() for i in submodule(*[i.to_mkldnn() for i in sample_inputs])])

@@ -278,7 +278,7 @@ def optimize_for_inference(
cur_tracer = tracer()
fx_graph = cur_tracer.trace(copy.deepcopy(model))
-fx_model = fx.GraphModule(cur_tracer.root, fx_graph)
+fx.GraphModule(cur_tracer.root, fx_graph)
modules: Dict[str, nn.Module] = dict(model.named_modules())
class MklSupport(Enum):
@@ -279,10 +279,8 @@ def get_latency_of_partitioned_graph(
def dfs_helper(partition: Partition, latency_so_far_sec: float) -> float:
"""This function helps to recursively get the latency of a path of partitions"""
# Update latency by adding current partition's latency
-latency_so_far_sec += partition_to_latency_mapping[
-partition
-].overall_latency_sec
-children = partition.children
+latency_so_far_sec += partition_to_latency_mapping[partition].overall_latency_sec
if partition.children:
max_latency_sec = 0.0
for child in partition.children:
@@ -1176,11 +1176,11 @@ def dispatch_trace(
def wrap_key(
f: Callable[_P, R], tensors: _P.args, tracer: _ProxyTracer, pre_dispatch: bool
) -> Callable[_P, R]:
-flat_tensors, tensors_spec = pytree.tree_flatten(tensors)
+flat_tensors, _tensors_spec = pytree.tree_flatten(tensors)
@functools.wraps(f)
def wrapped(*proxies: _P.args, **_unused: _P.kwargs) -> R:
-flat_proxies, proxies_spec = pytree.tree_flatten(proxies)
+flat_proxies, _proxies_spec = pytree.tree_flatten(proxies)
assert len(flat_proxies) == len(flat_tensors)
with disable_proxy_modes_tracing() as m:
assert isinstance(m, ProxyTorchDispatchMode)

@@ -1733,7 +1733,7 @@ class _ModuleStackTracer(PythonKeyTracer):
try:
return Tracer.call_module(self, m, forward, args, kwargs)
-except _ModuleNotInstalledAsSubmoduleError as e:
+except _ModuleNotInstalledAsSubmoduleError:
warnings.warn(
f"Unable to find the path of the module {m}. "
"This might be because the module was not properly registered "
@@ -331,7 +331,7 @@ def replay_shape_env_events(events):
# We need to call create_mapping_fn every time, since the node list might
# change after each event is replayed.
event.run(shape_env)
-except Exception as e:
+except Exception:
log.error("failed when running event: %s", event)
raise
@@ -3879,8 +3879,6 @@ class ShapeEnv:
guess
"""
-source_name = source.name() if source else None
if self._translation_validation_enabled and source is not None:
# Create a new symbol for this source.
symbol = self._create_symbol_for_source(source)

@@ -3919,8 +3917,6 @@ class ShapeEnv:
source: Optional[Source] = None,
) -> Union[float, SymFloat]:
"""Create a SymFloat value from a symbolic expression"""
-source_name = source.name() if source else None
if self._translation_validation_enabled and source is not None:
# Create a new symbol for this source.
symbol = self._create_symbol_for_source(source)

@@ -4808,7 +4804,7 @@ class ShapeEnv:
res = f"{source_ref(source)} == {sexpr}"
exprs.append(res)
if (s0 := self.source_to_var.get(srcname)) is not None:
-if source != (canonical_source := self.var_to_sources[s0][0]):
+if source != self.var_to_sources[s0][0]:
verbose_exprs.append(
f"{res} # duck sizing added this equality because these "
f"variables had the same size {self.var_to_val[s0]} "

@@ -6140,7 +6136,6 @@ class ShapeEnv:
# If an error is raised before the end of this function, we remove the FX node
# inserted, and re-raise the error.
guard = None
tb = None
try:
if orig_expr.is_number:
@@ -16,7 +16,7 @@ class Dispatcher:
self.ordering = ordering(self.funcs)
def __call__(self, *args, **kwargs):
-func, s = self.resolve(args)
+func, _ = self.resolve(args)
return func(*args, **kwargs)
def resolve(self, args):
@@ -310,7 +310,6 @@ def _parse_stack_trace(stack_trace: str):
# stacktrace should have innermost frame last, so we
# iterate backwards to find the first line that starts
# with 'File '
-summary_str = ""
for idx in range(len(lines) - 2, -1, -1):
line = lines[idx].strip()
matches = pattern.match(line)

@@ -463,10 +462,10 @@ class CodeGen:
return s
return f
-yellow = make_wrapper_func("yellow")
-cyan = make_wrapper_func("cyan")
+yellow = make_wrapper_func("yellow") # noqa: F841
+cyan = make_wrapper_func("cyan") # noqa: F841
red = make_wrapper_func("red")
-green = make_wrapper_func("green")
+green = make_wrapper_func("green") # noqa: F841
dim_green = make_wrapper_func("dim_green")
dim = make_wrapper_func("dim")
dim_blue = make_wrapper_func("dim_blue")
@@ -126,7 +126,7 @@ def check_for_mutable_operation(target : Callable, args : Tuple['Argument', ...]
try:
candidate_signature.bind(*args, **kwargs)
matched_schemas.append((candidate_signature, schema))
-except TypeError as e:
+except TypeError:
continue
def throw_if_mutable(schema):

@@ -214,7 +214,7 @@ def create_type_hint(x):
else:
return ret_type(Any)
return ret_type(base_type)
-except Exception as e:
+except Exception:
# We tried to create a type hint for list but failed.
warnings.warn(f"We were not able to successfully create type hint from the type {x}")
return x

@@ -328,7 +328,7 @@ def normalize_function(
try:
candidate_signature.bind(*args, **kwargs)
matched_schemas.append(candidate_signature)
-except TypeError as e:
+except TypeError:
continue
if len(matched_schemas) == 0:

@@ -349,7 +349,7 @@ def normalize_function(
for arg_name, arg_type in bound_types.arguments.items():
param = candidate_signature.parameters[arg_name]
sig_matches = sig_matches and type_matches(param.annotation, arg_type)
-except TypeError as e:
+except TypeError:
sig_matches = False
if sig_matches:
new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(candidate_signature, args, kwargs,
@@ -107,7 +107,6 @@ def tensorify_python_scalars(
placeholders = set()
for node in graph.nodes:
if node.op != "placeholder":
-first_none_placeholder = node
break
else:
placeholders.add(node)
@@ -58,7 +58,6 @@ def get_size_of_all_nodes(
# Mark shape and dtype for each node (node.shape and node.dtype)
ShapeProp(fx_module).propagate(*args)
# Calculate the total size of the whole fx graph
-total_size_of_graph = 0.0
for node in fx_module.graph.nodes:
if node.op == "output":
break

@@ -92,7 +91,7 @@ def get_size_of_node(fx_module: GraphModule, node: Node) -> size_bytes:
submodule = submodule_dict[node.target]
parameters = submodule.named_parameters()
# Parameters are named tuples
-for name, p in parameters:
+for _name, p in parameters:
total_num_of_elems += p.numel()
# Don't forget the output size
# node.shape is the shape of this node's output
@@ -31,7 +31,7 @@ def inplace_wrapper(fn: Callable) -> Callable:
@wraps(fn)
def wrapped_fn(gm):
-val = fn(gm)
+fn(gm)
return gm
return wrapped_fn
@@ -487,7 +487,7 @@ def reinplace(gm, *sample_args):
# inplace-ify functional ops, subject to the constraints written below.
all_later_view_inverse_nodes_to_delete = set()
-for idx, node in enumerate(gm.graph.nodes):
+for node in gm.graph.nodes:
if node.op == 'call_function':
# Today, the re-inplace pass on directly acts on:

@@ -532,7 +532,6 @@ def reinplace(gm, *sample_args):
continue
# Step 1b: ensure that the op we're trying to re-inplace isn't a program input
-self_arg_name = self_arg.name
self_arg_storage = StorageWeakRef(self_arg.meta['fake_result']._typed_storage())
if self_arg_storage in input_storages:
# TODO: later, add the optimization for handling `copy_()` calls in the graph.

@@ -578,7 +577,7 @@ def reinplace(gm, *sample_args):
remaining_slice_args = node.args[2:]
slice_node = gm.graph.create_node(
'call_function', view_op, (self_arg,) + tuple(remaining_slice_args), node.kwargs)
-copy_node = gm.graph.create_node(
+gm.graph.create_node(
'call_function', torch.ops.aten.copy_.default, (slice_node, mutated_slice_node,), {})
# Add the slice_scatter node to our "nodes to delete" list.
all_later_view_inverse_nodes_to_delete.add(node)

@@ -612,8 +611,6 @@ def reinplace(gm, *sample_args):
new = old.args[0]
nodes_to_update = [n for n in old.users if n.meta['node_idx'] > node.meta['node_idx']]
for node_to_update in nodes_to_update:
-new_args = []
-args = node_to_update.args
def replace_arg(a):
if a == old:

@@ -654,7 +651,6 @@ def reinplace(gm, *sample_args):
x._typed_storage()
) for x in new_flattened_res if isinstance(x, FakeTensor)}
assert len(new_res_storage) == 1
(old_ref,) = old_res_storage
(new_ref,) = new_res_storage
(node_ref,) = node_res_storage
# Technically, "old_ref" and all its aliases will remain
@@ -526,7 +526,7 @@ def split_module(
for node in original_order:
if node in already_constructed_attr_nodes:
continue # already added this attr to the base graph
-base_mod_env, based_mod_attrs = construct_graph(
+base_mod_env, _based_mod_attrs = construct_graph(
node, base_mod_env, base_mod_attrs
)
already_constructed_attr_nodes.add(node)
@@ -80,7 +80,7 @@ class _SplitterSettingBase:
"we might not care about non-tensor data flow and we can set this option "
"to true to disable the functionality that prevent non-tensor data flow.",
)
-args, unknown = parser.parse_known_args()
+args, _unknown = parser.parse_known_args()
self.min_acc_module_size: int = args.min_acc_module_size if args.min_acc_module_size else min_acc_module_size
self.skip_fusion: bool = args.skip_fusion if args.skip_fusion else skip_fusion

@@ -882,7 +882,7 @@ class _SplitterBase:
def generate_split_results(self) -> SplitResult:
split_module = self()
submodule_names = []
-for name, mod in split_module.named_children():
+for name, _mod in split_module.named_children():
submodule_names.append(name)
if (
self.settings.max_acc_splits > 0
@@ -279,7 +279,7 @@ def _replace_pattern(
match_changed_node: Dict[Node, Node] = {}
match_and_replacements = []
-for i, match in enumerate(_matches):
+for match in _matches:
if replacement_callback is not None:
replacement_graph = replacement_callback(match, original_graph, pattern_graph)
else:
@@ -720,7 +720,7 @@ def download_url_to_file(
# We deliberately do not use NamedTemporaryFile to avoid restrictive
# file permissions being applied to the downloaded file.
dst = os.path.expanduser(dst)
-for seq in range(tempfile.TMP_MAX):
+for _ in range(tempfile.TMP_MAX):
tmp_dst = dst + "." + uuid.uuid4().hex + ".partial"
try:
f = open(tmp_dst, "w+b")
@@ -570,10 +570,6 @@ def create_script_module_impl(nn_module, concrete_type, stubs_fn):
method_stubs = stubs_fn(nn_module)
property_stubs = get_property_stubs(nn_module)
hook_stubs, pre_hook_stubs = get_hook_stubs(nn_module)
-user_annotated_ignored_attributes = getattr(
-nn_module, "__jit_ignored_attributes__", []
-)
-ignored_properties = jit_ignored_properties(nn_module)
def init_fn(script_module):

@@ -838,9 +834,6 @@ def infer_methods_to_compile(nn_module):
(TODO add a link when the rules are published).
"""
check_module_initialized(nn_module)
-user_annotated_ignored_attributes = getattr(
-nn_module, "__jit_ignored_attributes__", []
-)
ignored_properties = jit_ignored_properties(nn_module)
methods: List[str] = []
@@ -1600,7 +1600,7 @@ def _recursive_compile_class(obj, loc):
_qual_name = _qualified_name(obj)
# We're starting a new compilation, so update the error call stack in
# case it fails
-error_stack = torch._C.CallStack(_qual_name, loc)
+error_stack = torch._C.CallStack(_qual_name, loc) # noqa: F841
rcb = _jit_internal.createResolutionCallbackForClassMethods(obj)
return _compile_and_register_class(obj, rcb, _qual_name)
@@ -255,7 +255,6 @@ def pool2d_shape_check(
outputWidth: int,
):
ndim = len(input)
-nOutputPlane = nInputPlane
assert kW > 0 and kH > 0
assert dW > 0 and dH > 0

@@ -608,12 +607,10 @@ def matmul(tensor1: List[int], tensor2: List[int]):
# We are multiplying b1 x n x m1 by x2 x m2 x p (where b1 can be a list);
# we track m1 vs m2 separately even though they must match for nicer error messages
n = tensor1[-2] if dim_tensor1 > 1 else 1
m1 = tensor1[-1]
batch_tensor1: List[int] = []
# TODO: handling of slice
for i in range(dim_tensor1 - 2):
batch_tensor1.append(tensor1[i])
m2 = tensor2[-1] if dim_tensor2 > 1 else 1
p = tensor2[-1]
batch_tensor2: List[int] = []
# TODO: handling of slice
@@ -55,7 +55,6 @@ def _create_interpreter_name_lookup_fn(frames_up=1):
i += 1
f_locals = frame.f_locals
-f_globals = frame.f_globals
for k, v in f_locals.items():
if isinstance(v, torch.Tensor) and var is v:

@@ -136,7 +135,7 @@ class ONNXTracedModule(torch.nn.Module):
else:
return tuple(out_vars)
-graph, out = torch._C._create_graph_by_tracing(
+graph, _out = torch._C._create_graph_by_tracing(
wrapper,
in_vars + module_state,
_create_interpreter_name_lookup_fn(),

@@ -241,7 +240,6 @@ def verify(model, args, loss_fn=torch.sum, devices=None):
if not isinstance(args, tuple):
args = (args,)
-saved_args = _clone_inputs(args)
if is_module:
saved_state = copy.deepcopy(model.state_dict())
@@ -40,7 +40,7 @@ def _gen_unsupported_methods_properties():
scope: Dict[str, Any] = {}
execWrapper(funcs_str, globals(), scope)
try:
-cu = torch.jit.CompilationUnit(funcs_str)
+torch.jit.CompilationUnit(funcs_str)
except Exception as e:
if "nonexistent attribute" not in repr(e):
continue
@@ -1784,7 +1784,6 @@ def normalize(
) -> Tensor:
if dtype is None:
dtype = input.dtype
-dim_ = _canonical_dim(dim, input.ndim)[0]
# TODO: eliminate mask_input as unnecessary when using masked divide.
mask_input = _combine_input_and_mask(sum, input, mask)
if mask_input.layout == torch.strided:
@@ -375,7 +375,7 @@ def ones_like(func, *args, **kwargs):
@register_dispatch_func([torch.ops.aten._softmax_backward_data])
def _softmax_backward_data(func, *args, **kwargs):
_check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=4)
-grad, output, dim, input_dtype = args
+grad, output, dim, _input_dtype = args
if is_masked_tensor(grad) and is_masked_tensor(output):
if not _masks_match(grad, output):
raise ValueError(
@@ -96,8 +96,8 @@ def _binary_helper(fn, args, kwargs, inplace):
"Input masks must match. If you need support for this, please open an issue on Github."
)
-data_args, data_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_data())
-mask_args, mask_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_mask())
+data_args, _data_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_data())
+mask_args, _mask_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_mask())
args0_layout = data_args[0].layout
same_layout = (
@@ -120,8 +120,12 @@ def _unary_helper(fn, args, kwargs, inplace):
"MaskedTensor unary ops do not support additional Tensor arguments"
)
-mask_args, mask_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x._masked_mask)
-data_args, data_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x._masked_data)
+mask_args, _mask_kwargs = _map_mt_args_kwargs(
+args, kwargs, lambda x: x._masked_mask
+)
+data_args, _data_kwargs = _map_mt_args_kwargs(
+args, kwargs, lambda x: x._masked_data
+)
if args[0].layout == torch.sparse_coo:
data_args[0] = data_args[0].coalesce()
@@ -33,7 +33,7 @@ class Pool(multiprocessing.pool.Pool):
Bring the number of pool processes up to the specified number, for use after
reaping workers which have exited.
"""
-for i in range(self._processes - len(self._pool)):
+for _ in range(self._processes - len(self._pool)):
# changed worker -> clean_worker
args = (
self._inqueue,
@@ -664,7 +664,7 @@ def _softmax_default(func, *args, **kwargs):
new_kwargs["dim"],
reduce_on_batch,
reduce_on_ragged,
-reduce_on_non_batch,
+_reduce_on_non_batch,
) = _wrap_jagged_dims(
inp.dim(),
(new_kwargs["dim"],),

@@ -985,7 +985,7 @@ def matmul_default(func, *args, **kwargs):
def _padded_impl(a, b):
assert a.is_nested and not b.is_nested
-nt, t = a, b
+nt = a
from .nested_tensor import nested_from_padded

@@ -1588,7 +1588,7 @@ def mean_dim(func, *args, **kwargs):
new_kwargs["dim"],
reduce_on_batch,
reduce_on_ragged,
-reduce_on_non_batch,
+_reduce_on_non_batch,
) = _wrap_jagged_dims(
inp.dim(),
new_kwargs["dim"],
@@ -323,7 +323,6 @@ def _cumulative_and_max_seq_len_nnz(qkv: torch.Tensor) -> Tuple[torch.Tensor, in
cumulative_seqlen = (
qkv.lengths().cumsum(0).to(dtype=torch.int32, device=qkv.device)
)
-batch_size = qkv.size(0)
max_seqlen = qkv._get_max_seqlen()
# TODO: Explore performance impact when compiling
n_elem = int(cumulative_seqlen[-1].item())

@@ -750,10 +749,10 @@ def jagged_scaled_dot_product_attention(
(
attention,
-logsumexp,
-philox_seed,
-philox_offset,
-debug_attn_mask,
+_logsumexp,
+_philox_seed,
+_philox_offset,
+_debug_attn_mask,
) = torch.ops.aten._flash_attention_forward(
query_buffer_reshaped,
key_buffer_reshaped,
@@ -193,8 +193,6 @@ def _adjust_num_blocks_and_indices(
new_num_rows: int,
new_num_cols: int,
):
-num_rows = indices.shape[-2]
-num_columns = indices.shape[-1]
indices = indices[:, :, :new_num_rows, :new_num_cols]
num_blocks = num_blocks[:, :, :new_num_rows]
num_blocks = torch.where(num_blocks < new_num_cols, num_blocks, new_num_cols)
@@ -6236,7 +6236,7 @@ def multi_head_attention_forward(
#
if need_weights:
-B, Nt, E = q.shape
+_B, _Nt, E = q.shape
q_scaled = q * math.sqrt(1.0 / float(E))
assert not (
@@ -224,11 +224,7 @@ class CrossMapLRN2d(Function):
ctx.scale = ctx.scale or input.new()
output = input.new()
-batch_size = input.size(0)
-channels = input.size(1)
-input_height = input.size(2)
-input_width = input.size(3)
output.resize_as_(input)
ctx.scale.resize_as_(input)
@@ -2630,7 +2630,7 @@ class Module:
<class 'torch.Tensor'> (20L, 1L, 5L, 5L)
"""
-for name, param in self.named_parameters(recurse=recurse):
+for _name, param in self.named_parameters(recurse=recurse):
yield param
def named_parameters(

@@ -2725,7 +2725,7 @@ class Module:
Yields:
Module: a child module
"""
-for name, module in self.named_children():
+for _name, module in self.named_children():
yield module
def named_children(self) -> Iterator[Tuple[str, "Module"]]:
@@ -1043,7 +1043,6 @@ class LSTM(RNNBase):
orig_input = input
# xxx: isinstance check needs to be in conditional for TorchScript to compile
batch_sizes = None
-do_permute = False
num_directions = 2 if self.bidirectional else 1
real_hidden_size = self.proj_size if self.proj_size > 0 else self.hidden_size
if isinstance(orig_input, PackedSequence):
@@ -1496,7 +1496,7 @@ class DistributedDataParallel(Module, Joinable):
# Disable the python reducer if compiled_autograd is not enabled.
if self._accum_grad_hooks:
-for index, h in enumerate(self._accum_grad_hooks):
+for h in self._accum_grad_hooks:
h.remove()
self._accum_grad_hooks.clear()
@@ -68,7 +68,7 @@ class _PyTreeExtensionContext:
def _register_huggingface_model_output_extension(self):
try:
from transformers import modeling_outputs # type: ignore[import]
-except ImportError as e:
+except ImportError:
return
def model_output_flatten(
@@ -61,7 +61,6 @@ def set_node_name(
new_name: The new name to use.
name_to_node_cache: A cache of node names to nodes.
"""
-module = node.graph.owning_module
node_name_to_set = collections.deque([(node, new_name)])
while node_name_to_set:
@@ -84,12 +84,11 @@ class Functionalize(_pass.Transform):
out = function(*inputs_functional)
finally:
torch._disable_functionalization()
-flat_inputs = pytree.tree_leaves(inputs)
flat_inputs_functional = pytree.tree_leaves(inputs_functional)
-for inpt, input_functional in zip(flat_inputs, flat_inputs_functional):
+for input_functional in flat_inputs_functional:
if isinstance(input_functional, torch.Tensor):
torch._sync(input_functional)
inpt_new = torch._from_functional_tensor(input_functional)
pytree.tree_map(torch._sync, out)
out_unwrapped = pytree.tree_map(torch._from_functional_tensor, out)
return out_unwrapped
@@ -139,7 +139,7 @@ class _ModuleMeta:
         cls, raw_meta: _DYNAMO_NN_MODULE_META_TYPE
     ) -> _ModuleMeta:
         """Create a module meta from raw meta produced by FX dynamo tracer."""
-        module_name, (qualified_name, module_class) = raw_meta
+        module_name, (_qualified_name, module_class) = raw_meta
         return _ModuleMeta(module_name.split("@")[0], module_class, raw_meta)

     @classmethod
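The same underscore convention applies inside nested tuple unpacking: the slot has to exist for the unpack to succeed, but the name signals that the value is deliberately ignored. Illustrative sketch with made-up data and `object` standing in for a module class:

    raw_meta = ("encoder@0", ("mypkg.Encoder", object))
    module_name, (_qualified_name, module_class) = raw_meta
    print(module_name.split("@")[0], module_class)  # encoder <class 'object'>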
@@ -43,7 +43,7 @@ _SCALAR_TYPE_TENSOR_DTYPE_MAP: Mapping[type, torch.dtype] = {
 def _try_getclosurevars(func):
     try:
         return inspect.getclosurevars(func)
-    except TypeError as e:
+    except TypeError:
         return None


@@ -1988,7 +1988,7 @@ def _embedding_bag_helper(

     # FIXME(justinchuby): We need to handle what happens when we call b.op on a node return
     block_input_iter = utils._add_input_to_block(loop_block)
-    cond = utils._add_input_to_block(loop_block)
+    utils._add_input_to_block(loop_block)

     indices_start = loop_context.op(
         "Gather", offsets_starts, block_input_iter, axis_i=0
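When a call is needed only for its side effect, the assignment can be dropped and the call kept as a bare expression statement. Sketch with a hypothetical stand-in for `utils._add_input_to_block`:

    block_inputs = []

    def add_input_to_block(block):
        # Registers a new input and returns it; the caller may ignore the result.
        block_inputs.append(f"input_{len(block_inputs)}")
        return block_inputs[-1]

    block = object()
    block_input_iter = add_input_to_block(block)  # return value used below
    add_input_to_block(block)                     # before: cond = add_input_to_block(block)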
@@ -534,7 +534,7 @@ def stack(g: jit_utils.GraphContext, tensor_list, dim):
 @_onnx_symbolic("aten::_unique2")
 @symbolic_helper.parse_args("v", "i", "i", "i")
 def _unique2(g: jit_utils.GraphContext, self, sorted, return_inverse, return_counts):
-    u, indices, inverse_indices, counts = g.op(
+    u, _indices, inverse_indices, counts = g.op(
         "Unique", self, sorted_i=sorted, outputs=4
     )
     return u, inverse_indices, counts
@@ -545,7 +545,7 @@ def _unique2(g: jit_utils.GraphContext, self, sorted, return_inverse, return_counts):
 def unique_dim(
     g: jit_utils.GraphContext, self, dim, sorted, return_inverse, return_counts
 ):
-    u, indices, inverse_indices, counts = g.op(
+    u, _indices, inverse_indices, counts = g.op(
         "Unique", self, axis_i=dim, sorted_i=sorted, outputs=4
     )
     return u, inverse_indices, counts
@@ -945,7 +945,6 @@ def index(g: jit_utils.GraphContext, self, index):

 @_onnx_symbolic("aten::index_fill")
 def index_fill(g: jit_utils.GraphContext, self, dim, index, value):
-    dim_value = symbolic_helper._parse_arg(dim, "i")
     expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper(
         g, self, dim, index
     )
@@ -957,8 +956,7 @@ def index_fill(g: jit_utils.GraphContext, self, dim, index, value):

 @_onnx_symbolic("aten::index_copy")
 def index_copy(g: jit_utils.GraphContext, self, dim, index, source):
-    dim_value = symbolic_helper._parse_arg(dim, "i")
-    expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper(
+    _expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper(
         g, self, dim, index
     )
     return scatter(g, self, dim, expanded_index, source)
@@ -346,8 +346,7 @@ def unfold(g: jit_utils.GraphContext, input, dimension, size, step):

         loop_block = loop_context.block
         block_input_iter = utils._add_input_to_block(loop_block)
-        # FIXME(justinchuby): cond is unused?
-        cond = utils._add_input_to_block(loop_block)
+        cond = utils._add_input_to_block(loop_block)  # noqa: F841

         starts = loop_context.op("Gather", low_indices, block_input_iter)
         ends = loop_context.op("Gather", hi_indices, block_input_iter)
@@ -211,7 +211,7 @@ def tensor_split(

         loop_block = loop_context.block
         block_input_iter = utils._add_input_to_block(loop_block)
-        cond = utils._add_input_to_block(loop_block)
+        cond = utils._add_input_to_block(loop_block)  # noqa: F841
         final_splits = utils._add_input_to_block(loop_block)

         start = loop_context.op(
@@ -689,7 +689,7 @@ def repeat_interleave(

         loop_block = loop_context.block
         block_input_iter = utils._add_input_to_block(loop_block)
-        cond = utils._add_input_to_block(loop_block)
+        cond = utils._add_input_to_block(loop_block)  # noqa: F841
         final_splits = utils._add_input_to_block(loop_block)

         r_split = loop_context.op("SequenceAt", r_splits, block_input_iter)
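The unfold, tensor_split, and repeat_interleave hunks take the opposite route: the binding is kept because the name documents what the extra Loop-block input is, and the unused-variable warning is suppressed explicitly. A short sketch of the idiom (hypothetical `add_input` callable):

    def build_loop_block(add_input):
        iter_input = add_input()
        # Kept for readability even though it is never read; the block still
        # needs the input to be registered.
        cond = add_input()  # noqa: F841
        return iter_input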
@@ -2955,7 +2955,6 @@ def index_put(g: jit_utils.GraphContext, self, indices_list_value, values, accumulate):

 @_onnx_symbolic("aten::index_fill")
 def index_fill(g: jit_utils.GraphContext, self, dim, index, value):
-    dim_value = symbolic_helper._parse_arg(dim, "i")
     expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper(
         g, self, dim, index
     )
@@ -2968,8 +2967,7 @@ def index_fill(g: jit_utils.GraphContext, self, dim, index, value):

 @_onnx_symbolic("aten::index_copy")
 def index_copy(g: jit_utils.GraphContext, self, dim, index, source):
-    dim_value = symbolic_helper._parse_arg(dim, "i")
-    expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper(
+    _expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper(
         g, self, dim, index
     )
     return scatter(g, self, dim, expanded_index, source)
@@ -3674,14 +3672,14 @@ def new_full(
 def eye(g: jit_utils.GraphContext, *args):
     if len(args) == 5:
         # aten::eye(n, dtype, layout, device, pin_memory)
-        n, dtype, layout, device, pin_memory = args
+        n, dtype, layout, device, _pin_memory = args
         dim_size = symbolic_helper._unsqueeze_helper(g, n, [0])
         shape = g.op("Concat", dim_size, dim_size, axis_i=0)
         tensor = zeros(g, shape, dtype, layout, device)
         return g.op("EyeLike", tensor)
     if len(args) == 6:
         # aten::eye(n, m, dtype, layout, device, pin_memory)
-        n, m, dtype, layout, device, pin_memory = args
+        n, m, dtype, layout, device, _pin_memory = args
         shape = g.op(
             "Concat",
             symbolic_helper._unsqueeze_helper(g, n, [0]),
@@ -5567,14 +5565,14 @@ def linalg_matrix_norm(
             g, g.op("Abs", self), axes_i=[dim[0]], keepdims_i=keepdim
         )
         if ord_value > 0:
-            result, indices = max(
+            result, _indices = max(
                 g,
                 sum,
                 dim_or_y=g.op("Constant", value_t=torch.LongTensor([dim[1]])),
                 keepdim=keepdim,
             )
         else:
-            result, indices = min(
+            result, _indices = min(
                 g,
                 sum,
                 dim_or_y=g.op("Constant", value_t=torch.LongTensor([dim[1]])),
@@ -6391,7 +6389,7 @@ def prim_loop(g: jit_utils.GraphContext, *inputs, **attrs) -> list[_C.Value]:
     opset_version = GLOBALS.export_onnx_opset_version

     old_blocks = tuple(node.blocks())
-    new_op_outputs, new_block_contexts, new_node = jit_utils.add_op_with_blocks(
+    _new_op_outputs, new_block_contexts, new_node = jit_utils.add_op_with_blocks(
         g, "Loop", *inputs, outputs=node.outputsSize(), n_blocks=len(old_blocks)
     )

@@ -6500,7 +6498,7 @@ def prim_if(g: jit_utils.GraphContext, *inputs, **attrs) -> list[_C.Value]:
         return final_b_list
     else:
         old_blocks = tuple(n.blocks())
-        new_op_outputs, new_block_contexts, new_node = jit_utils.add_op_with_blocks(
+        _new_op_outputs, new_block_contexts, new_node = jit_utils.add_op_with_blocks(
             g, "If", *inputs, outputs=n.outputsSize(), n_blocks=len(old_blocks)
         )

@@ -1066,7 +1066,7 @@ def _model_to_graph(
             input_names=input_names,
             module=module,
         )
-    except Exception as e:
+    except Exception:
         _C._jit_onnx_log("Torch IR graph at exception: ", graph)
         raise

@@ -1544,8 +1544,8 @@ def _export(
                 (
                     proto,
                     export_map,
-                    val_use_external_data_format,
-                    node_names,
+                    _val_use_external_data_format,
+                    _node_names,
                 ) = graph._export_onnx(  # type: ignore[attr-defined]
                     params_dict,
                     opset_version,
@@ -1563,8 +1563,8 @@ def _export(
                 (
                     proto,
                     export_map,
-                    val_use_external_data_format,
-                    node_names,
+                    _,
+                    _,
                 ) = graph._export_onnx(  # type: ignore[attr-defined]
                     {},
                     opset_version,
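For multi-value returns, slots that are not needed can be unpacked into `_`, or into a descriptive `_name` as in the first hunk above when self-documentation matters; repeating `_` is legal because each binding simply overwrites the previous one. A sketch with a hypothetical stand-in for `graph._export_onnx`:

    def export_onnx():
        # Hypothetical 4-tuple mirroring (proto, export_map, external_data, node_names).
        return b"proto-bytes", {}, False, ["node0"]

    proto, export_map, _, _ = export_onnx()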
@@ -1783,7 +1783,7 @@ def find_mismatch(
     args = utils._decide_input_format(model, inputs_for_export)

     model = utils._pre_trace_quant_model(model, args)
-    graph, params, torch_out, module = utils._create_jit_graph(model, args)
+    graph, params, _torch_out, _module = utils._create_jit_graph(model, args)
     params_dict = utils._get_named_param_dict(graph, params)

     utils._apply_friendly_debug_names(graph, params_dict)
@@ -247,8 +247,7 @@ class LRScheduler:
         else:
             values = self.get_lr()

-        for i, data in enumerate(zip(self.optimizer.param_groups, values)):
-            param_group, lr = data
+        for param_group, lr in zip(self.optimizer.param_groups, values):
             if isinstance(param_group["lr"], Tensor):
                 param_group["lr"].fill_(lr)
             else:
@@ -1865,8 +1864,7 @@ class CosineAnnealingWarmRestarts(LRScheduler):
             self.last_epoch = math.floor(epoch)

         with _enable_get_lr_call(self):
-            for i, data in enumerate(zip(self.optimizer.param_groups, self.get_lr())):
-                param_group, lr = data
+            for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
                 param_group["lr"] = lr

         self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
@@ -53,7 +53,7 @@ def demangle(name: str) -> str:
     mangled name, irrespective of which PackageMangler created it.
     """
     if is_mangled(name):
-        first, sep, last = name.partition(".")
+        _first, sep, last = name.partition(".")
         # If there is only a base mangle prefix, e.g. '<torch_package_0>',
         # then return an empty string.
         return last if len(sep) != 0 else ""
@@ -49,6 +49,7 @@ class PackagePickler(_PyTorchLegacyPickler):
         self.dispatch[FunctionType] = PackagePickler.save_global  # type: ignore[assignment]

     def save_global(self, obj, name=None):
+        # ruff: noqa: F841
         # unfortunately the pickler code is factored in a way that
         # forces us to copy/paste this function. The only change is marked
         # CHANGED below.
@@ -89,7 +89,7 @@ class _ExtractModuleReferences(ast.NodeVisitor):
                     self.references[(name, alias)] = True
                 else:
                     self.references[(name, None)] = True
-        except Exception as e:
+        except Exception:
            return


@@ -427,7 +427,7 @@ class PackageExporter:
     def _import_module(self, module_name: str):
         try:
             return self.importer.import_module(module_name)
-        except ModuleNotFoundError as e:
+        except ModuleNotFoundError:
             if not is_mangled(module_name):
                 raise
             msg = (
@@ -662,7 +662,7 @@ class PackageExporter:
         memo: DefaultDict[int, str] = defaultdict(None)
         memo_count = 0
         # pickletools.dis(data_value)
-        for opcode, arg, pos in pickletools.genops(data_value):
+        for opcode, arg, _pos in pickletools.genops(data_value):
             if pickle_protocol == 4:
                 if (
                     opcode.name == "SHORT_BINUNICODE"
@@ -463,7 +463,6 @@ class PackageImporter(Importer):

     # note: copied from cpython's import code, with call to create module replaced with _make_module
     def _do_find_and_load(self, name):
-        path = None
         parent = name.rpartition(".")[0]
         module_name_no_parent = name.rpartition(".")[-1]
         if parent:
@@ -475,7 +474,7 @@ class PackageImporter(Importer):
             parent_module = self.modules[parent]

             try:
-                path = parent_module.__path__  # type: ignore[attr-defined]
+                parent_module.__path__  # type: ignore[attr-defined]

             except AttributeError:
                 # when we attempt to import a package only containing pybinded files,
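Here the assignment existed only to trigger an AttributeError when the attribute is missing, so the change keeps the bare attribute access as the probe and drops the unused binding. Sketch:

    class NotAPackage:
        pass

    mod = NotAPackage()
    try:
        mod.__path__  # probe only; the value itself is never used
    except AttributeError:
        print("not a package")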
@@ -192,7 +192,7 @@ def _extract_parameters_and_gradients(


 def extract_parameters(node: _ProfilerEvent) -> Iterator[TensorKey]:
-    for p, p_grad in _extract_parameters_and_gradients(node):
+    for p, _p_grad in _extract_parameters_and_gradients(node):
         if p is not None:
             yield p

@@ -884,7 +884,7 @@ class ExecutionTraceObserver(_ITraceObserver):
         for kernel_file in kernel_files:
             if kernel_file is None:
                 continue
-            path, name = os.path.split(kernel_file)
+            name = os.path.basename(kernel_file)
             dst = os.path.join(resource_dir, name)
             shutil.copyfile(kernel_file, dst)

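`os.path.split` returns a (head, tail) pair; when only the file name is needed, `os.path.basename` expresses that directly and leaves no unused variable behind. Sketch with a made-up path:

    import os

    kernel_file = "/tmp/kernels/add_kernel.cu"
    # Before: path, name = os.path.split(kernel_file)  -- `path` was never read.
    name = os.path.basename(kernel_file)
    print(os.path.join("resources", name))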
@@ -16,7 +16,7 @@ def default_eval_fn(model, calib_data):
     Default evaluation function takes a torch.utils.data.Dataset or a list of
     input Tensors and run the model on the dataset
     """
-    for data, target in calib_data:
+    for data, _target in calib_data:
         model(data)


@@ -1493,7 +1493,7 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
             tar.extract("storages", path=tmpdir)
             with open(os.path.join(tmpdir, "storages"), "rb", 0) as f:
                 num_storages = pickle_module.load(f, **pickle_load_args)
-                for i in range(num_storages):
+                for _ in range(num_storages):
                     args = pickle_module.load(f, **pickle_load_args)
                     key, location, storage_type = args
                     dtype = storage_type._dtype
@@ -1527,7 +1527,7 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
                 num_tensors = pickle_module.load(f, **pickle_load_args)
                 for _ in range(num_tensors):
                     args = pickle_module.load(f, **pickle_load_args)
-                    key, storage_id, original_tensor_type = args
+                    key, storage_id, _original_tensor_type = args
                     storage = deserialized_objects[storage_id]
                     (ndim,) = struct.unpack("<i", f.read(4))
                     # skip next 4 bytes; legacy encoding treated ndim as 8 bytes
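A loop that only needs to run a fixed number of times conventionally uses `_` as the counter. A self-contained sketch using plain pickle instead of the serialization internals above:

    import io
    import pickle

    buf = io.BytesIO()
    for record in (("k0", "cpu"), ("k1", "cpu")):
        pickle.dump(record, buf)
    buf.seek(0)

    for _ in range(2):  # the counter itself is never read
        key, _location = pickle.load(buf)
        print(key)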
@@ -87,11 +87,11 @@ def sparse_semi_structured_from_dense_cutlass(dense):
     if dense.dtype != torch.float:
         ksparse = 4
         dense_4 = dense.view(-1, k // ksparse, ksparse)
-        m0, m1, m2, m3 = (dense_4 != 0).unbind(-1)
+        m0, m1, _m2, m3 = (dense_4 != 0).unbind(-1)
     else:
         ksparse = 2
         dense_2 = dense.view(-1, k // ksparse, ksparse)
-        m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1)
+        m0, _m2 = m1, m3 = (dense_2 != 0).unbind(-1)
     meta_ncols = k // (ksparse * quadbits_per_meta_elem)

     # Encoding quadruples of True/False values as follows:
Some files were not shown because too many files have changed in this diff.