[BE]: Apply PERF401 autofixes from ruff (#140980)

* Automatically applies ruff rule PERF401. Turns loops into equivalent list comprehensions, which are faster and do not leak loop variables into the enclosing scope (a minimal sketch of the rewrite is shown after this list).
* List comprehensions not only often type-check better, they also cut the per-iteration overhead of a for loop with repeated append by 50% or more. They preserve length information and are easier for the interpreter to optimize.
* Manually went back after the change and fixed the annotations needed to keep mypy happy.
* Also fixed style lints in files covered by flake8 but not by pyfmt.
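
For readers unfamiliar with the rule, here is a minimal sketch of the kind of rewrite PERF401 performs; the functions and data below are hypothetical examples, not code from this PR.

```python
# Before: building a list by repeated append inside a loop.
def even_squares_loop(values):
    result = []
    for v in values:
        if v % 2 == 0:
            result.append(v * v)
    return result

# After: the equivalent list comprehension. The loop variable no longer
# leaks into the enclosing scope, and per-iteration overhead is lower.
def even_squares_comprehension(values):
    return [v * v for v in values if v % 2 == 0]

# When appending to a list that already exists, the fix uses extend() with
# a generator expression instead, as many of the hunks below do.
rows = [("header",)]
rows.extend((v,) for v in range(3))

assert even_squares_loop(range(6)) == even_squares_comprehension(range(6)) == [0, 4, 16]
assert rows == [("header",), (0,), (1,), (2,)]
```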

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140980
Approved by: https://github.com/justinchuby, https://github.com/malfet
Aaron Gokaslan 2024-11-20 17:52:07 +00:00 committed by PyTorch MergeBot
parent 8d708090c0
commit 12e95aa4ee
133 changed files with 611 additions and 761 deletions

View File

@ -67,8 +67,7 @@ def pretty_print_buckets(buckets: List[Bucket], bucket_bytes_cap: int):
for idx, bucket in enumerate(reversed(buckets)):
if len(bucket.params) > 0:
rows.append((idx, bucket.size, bucket.params[0]))
for param in bucket.params[1:]:
rows.append((None, None, param))
rows.extend((None, None, param) for param in bucket.params[1:])
if bucket.opcount_increased_to_capture_external_output > 0:
extended_buckets.append(
(

View File

@ -1140,11 +1140,12 @@ def update_offsets(instructions) -> None:
def debug_bytes(*args) -> str:
index = range(max(map(len, args)))
result = []
for arg in (
[index] + list(args) + [[int(a != b) for a, b in zip(args[-1], args[-2])]]
):
result.append(" ".join(f"{x:03}" for x in arg))
result = [
" ".join(f"{x:03}" for x in arg)
for arg in [index]
+ list(args)
+ [[int(a != b) for a, b in zip(args[-1], args[-2])]]
]
return "bytes mismatch\n" + "\n".join(result)

View File

@ -478,10 +478,8 @@ class AutogradCompilerInstance:
@staticmethod
def get_all_nodes(args):
nodes = []
for n in args:
if type(n) is torch.fx.Node: # filter out non-Node args, like None
nodes.append(n)
# filter out non-Node args, like None
nodes = [n for n in args if type(n) is torch.fx.Node]
return nodes
@staticmethod
@ -671,13 +669,15 @@ class AutogradCompilerInstance:
input_nodes_and_users = []
input_nodes_and_users.extend(list(input_nodes))
for input_node in input_nodes:
for user in list(input_node.users.keys()):
input_nodes_and_users.extend(
user
for user in list(input_node.users.keys())
if not (
user.op == "call_function"
and user.target == call_hook
and node.kwargs.get("hook_type", None) == "post_hook"
):
input_nodes_and_users.append(user)
)
)
arg = max(input_nodes_and_users) # last input users
if (

View File

@ -2526,7 +2526,6 @@ def make_torch_function_mode_stack_guard(intial_stack):
def recompilation_reason_for_no_tensor_aliasing_guard(guard_manager, scope):
duplicate_tensors = []
global_scope = dict(guard_manager.global_scope)
ids_to_source = collections.defaultdict(list)
for tensor_source in guard_manager.no_tensor_aliasing_sources: # type: ignore[attr-defined]
@ -2534,9 +2533,9 @@ def recompilation_reason_for_no_tensor_aliasing_guard(guard_manager, scope):
tensor_id = id(eval(tensor_source, global_scope, scope))
ids_to_source[tensor_id].append(tensor_source)
for key in ids_to_source:
if len(ids_to_source[key]) > 1:
duplicate_tensors.append(f"{ids_to_source[key]}")
duplicate_tensors = [
f"{ids_to_source[key]}" for key in ids_to_source if len(ids_to_source[key]) > 1
]
reason = ", ".join(duplicate_tensors)
return [f"Duplicate tensors found: {reason}"]

View File

@ -1463,9 +1463,7 @@ class OutputGraph:
return compiled_fn
def example_inputs(self) -> List[torch.Tensor]:
result = []
for arg in self.graphargs:
result.append(arg.example)
result = [arg.example for arg in self.graphargs]
return result
def remove_unused_graphargs(self) -> None:

View File

@ -252,8 +252,7 @@ def _filter_iter(l1, l2, cond):
def _load_tuple_and_call(tup):
insts: List[Instruction] = []
_initial_push_null(insts)
for val in tup:
insts.append(create_instruction("LOAD_CONST", argval=val))
insts.extend(create_instruction("LOAD_CONST", argval=val) for val in tup)
insts.extend(create_call_function(len(tup), False))
return insts

View File

@ -95,9 +95,7 @@ def collect_results(
results.append(buffers)
for example in example_inputs:
if isinstance(example, (tuple, list)):
for inp in example:
if isinstance(inp, torch.Tensor):
results.append(inp.grad)
results.extend(inp.grad for inp in example if isinstance(inp, torch.Tensor))
else:
if isinstance(example, torch.Tensor):
results.append(example.grad)

View File

@ -1535,9 +1535,10 @@ def checkpoint_params(gm):
rng_state = torch.clone(torch.random.get_rng_state())
if torch.cuda.is_available():
cuda_rng_state = torch.clone(torch.cuda.get_rng_state())
saved_state = []
for param in itertools.chain(gm.parameters(), gm.buffers()):
saved_state.append((param, param._version, torch.clone(param)))
saved_state = [
(param, param._version, torch.clone(param))
for param in itertools.chain(gm.parameters(), gm.buffers())
]
def restore():
with torch.no_grad():

View File

@ -1765,10 +1765,9 @@ class BuiltinVariable(VariableTracker):
# tracked fakes to produce incorrect guards. This is sound because the TensorVariable
# coming out of set_() below will be a new one, and get
# installed in tracked fakes.
to_remove = []
for tf in tx.output.tracked_fakes:
if tf.source == obj.source:
to_remove.append(tf)
to_remove = [
tf for tf in tx.output.tracked_fakes if tf.source == obj.source
]
for tf in to_remove:
tx.output.tracked_fakes.remove(tf)

View File

@ -1026,17 +1026,16 @@ class SDPAKernelVariable(ContextWrappingVariable):
@staticmethod
def _backends_to_nodes(tx, backends):
nodes = []
for backend in backends:
# convert to/from string in order to bake the backend into FX graph
nodes.append(
tx.output.create_node(
"call_function",
torch.nn.attention._backend_from_string,
(backend.name,),
{},
)
# convert to/from string in order to bake the backend into FX graph
nodes = [
tx.output.create_node(
"call_function",
torch.nn.attention._backend_from_string,
(backend.name,),
{},
)
for backend in backends
]
return nodes
def enter(self, tx):

View File

@ -48,9 +48,9 @@ class ItertoolsVariable(VariableTracker):
and all(arg.has_unpack_var_sequence(tx) for arg in args)
):
seqs = [arg.unpack_var_sequence(tx) for arg in args]
items = []
for item in itertools.product(*seqs):
items.append(variables.TupleVariable(list(item)))
items = [
variables.TupleVariable(list(item)) for item in itertools.product(*seqs)
]
return variables.ListIteratorVariable(
items, mutation_type=ValueMutationNew()
)

View File

@ -776,14 +776,13 @@ class UserDefinedObjectVariable(UserDefinedVariable):
):
assert self.source # OrderedDict, dict subtypes must always have source
assert not (args or kwargs)
items = []
keys = self.call_method(tx, "keys", [], {})
for key in keys.force_unpack_var_sequence(tx):
items.append(
TupleVariable(
[key, self.odict_getitem(tx, key)],
)
items = [
TupleVariable(
[key, self.odict_getitem(tx, key)],
)
for key in keys.force_unpack_var_sequence(tx)
]
tx.output.guard_on_key_order.add(self.source.name())
return TupleVariable(items)

View File

@ -833,9 +833,7 @@ class TS2FXGraphConverter:
self._convert_prim_iterator(node)
def _convert_prim_iterator(self, node: torch._C.Node):
output_list = []
for inp in node.inputs():
output_list.append(self.get_fx_value_by_ir_value(inp))
output_list = [self.get_fx_value_by_ir_value(inp) for inp in node.inputs()]
output_name = node.output().debugName()
self.name_to_node[output_name] = output_list

View File

@ -7,9 +7,8 @@ class StaticForLoop(torch.nn.Module):
"""
def forward(self, x):
ret = []
for i in range(10): # constant
ret.append(i + x)
# constant
ret = [i + x for i in range(10)]
return ret
example_args = (torch.randn(3, 2),)

View File

@ -1780,9 +1780,7 @@ class GraphModuleDeserializer(metaclass=Final):
) from e
# Outputs: convert to a single `output` node.
outputs = []
for output in serialized_graph.outputs:
outputs.append(self.deserialize_graph_output(output))
outputs = [self.deserialize_graph_output(output) for output in serialized_graph.outputs]
if serialized_graph.is_single_tensor_return:
assert len(outputs) == 1
@ -2149,9 +2147,7 @@ class GraphModuleDeserializer(metaclass=Final):
if len(value) == 0:
return []
elif typ_ == "as_tensors":
result = []
for arg in value:
result.append(self.serialized_name_to_node[arg.name])
result = [self.serialized_name_to_node[arg.name] for arg in value]
return result
elif typ_ in ("as_ints", "as_floats", "as_bools", "as_strings"):
# convert from serialized.python.types.List to python list

View File

@ -1306,8 +1306,9 @@ def merge_view_inputs(
mutated_input_info[inpt_idx].mutates_data
for inpt_idx in aliased_input_indices
):
for curr_idx in aliased_input_indices:
other_args.append(fwd_inputs[curr_idx])
other_args.extend(
fwd_inputs[curr_idx] for curr_idx in aliased_input_indices
)
continue
# Here, we attempt to do a more complicated check to detect false aliasing
@ -1320,8 +1321,9 @@ def merge_view_inputs(
fwd_inputs, aliased_input_indices
)
if len(aliased_input_indices_no_false_sharing) <= 1:
for curr_idx in aliased_input_indices:
other_args.append(fwd_inputs[curr_idx])
other_args.extend(
fwd_inputs[curr_idx] for curr_idx in aliased_input_indices
)
continue
# We detected an input that was mutated, AND aliases with another input.

View File

@ -217,9 +217,7 @@ def trace_map(proxy_mode, func_overload, f, xs, pos_args):
@map_impl.py_impl(DispatchKey.CompositeExplicitAutograd)
def map_dense(f, xs, pos_args):
pytrees = []
for inp in _unstack_pytree(xs):
pytrees.append(f(*inp, *pos_args))
pytrees = [f(*inp, *pos_args) for inp in _unstack_pytree(xs)]
return _stack_pytree(pytrees)

View File

@ -543,8 +543,7 @@ def analyze_kernel_mutations(
)
stack.extend(arg for arg, mutated in zip(op.args, mutations) if mutated)
else:
for idx in MUTATION_OPS.get(op.name, []):
stack.append(op.args[idx])
stack.extend(op.args[idx] for idx in MUTATION_OPS.get(op.name, []))
# The following is an iterative DFS algorithm
mutated = [False] * num_args

View File

@ -388,9 +388,7 @@ def _unstack_pytree(xs):
a = zip(*flat_xs)
pytrees = []
for tuple in a:
pytrees.append(pytree.tree_unflatten(tuple, inspec))
pytrees = [pytree.tree_unflatten(tuple, inspec) for tuple in a]
return pytrees

View File

@ -2102,13 +2102,14 @@ class AotCodeCompiler:
aot_constants = struct.pack("qq", consts_size + 8, magic_number)
consts_o = _compile_consts(aot_constants, sys.platform)
kernels_o = []
gpu_codecache: Union[ROCmCodeCache, CUDACodeCache] = (
ROCmCodeCache() if torch.version.hip else CUDACodeCache()
)
for entry in gpu_codecache.cache.values():
if entry.output_path.endswith(".o"):
kernels_o.append(entry.output_path)
kernels_o = [
entry.output_path
for entry in gpu_codecache.cache.values()
if entry.output_path.endswith(".o")
]
kernels_o = " ".join(kernels_o)
output_name, output_dir = get_name_and_dir_from_output_file_path(output_so)

View File

@ -4800,10 +4800,9 @@ class LoopLevel:
def split_with_tiling(self, depth, factor):
def clone_inner():
inner = []
inner: List[LoopLevel] = []
if self.inner:
for loop in self.inner:
inner.append(loop.clone())
inner.extend(loop.clone() for loop in self.inner)
return inner
def do_split_with_tiling():

View File

@ -150,18 +150,16 @@ class DebugPrinterManager:
# get the list of args_to_print_or_save
# TODO: Find a more reliable way to detect kernel args types to print for extern kernel calls
if kernel_type == "extern":
args_to_print_or_save_extern = []
for arg in args_to_print_or_save:
if arg.startswith(("buf", "arg")):
args_to_print_or_save_extern.append(arg)
args_to_print_or_save_extern = [
arg for arg in args_to_print_or_save if arg.startswith(("buf", "arg"))
]
self.args_to_print_or_save = args_to_print_or_save_extern
elif kernel_type == "cpp":
args_to_print_or_save_cpp = []
for arg in args_to_print_or_save:
if arg.startswith(("buf", "arg")):
args_to_print_or_save_cpp.append(
f"convert_arrayref_tensor_to_tensor({arg})"
)
args_to_print_or_save_cpp = [
f"convert_arrayref_tensor_to_tensor({arg})"
for arg in args_to_print_or_save
if arg.startswith(("buf", "arg"))
]
self.args_to_print_or_save = args_to_print_or_save_cpp
else:
self.args_to_print_or_save = args_to_print_or_save

View File

@ -1392,19 +1392,19 @@ class HalideKernel(SIMDKernel):
result.append((call_str, arg))
if isinstance(arg, TensorArg):
assert arg.offset == 0 and arg.alias_of is None
for alias in self.buffer_aliases.get(arg.name, ()):
result.append(
(
None,
TensorArg(
alias,
arg.buffer,
arg.dtype,
arg.offset,
alias_of=arg.name,
),
)
result.extend(
(
None,
TensorArg(
alias,
arg.buffer,
arg.dtype,
arg.offset,
alias_of=arg.name,
),
)
for alias in self.buffer_aliases.get(arg.name, ())
)
return result
def halide_kernel_meta(self) -> HalideMeta:

View File

@ -72,10 +72,11 @@ def get_all_call_args(call_args_list, arg_types_list):
def get_numel_argdefs(kernel):
numel_argdefs = []
for tree in kernel.range_trees:
if tree.prefix != "r" or kernel.inside_reduction:
numel_argdefs.append(f"{tree.prefix}numel")
numel_argdefs = [
f"{tree.prefix}numel"
for tree in kernel.range_trees
if tree.prefix != "r" or kernel.inside_reduction
]
return numel_argdefs

View File

@ -76,7 +76,7 @@ def _default_custom_combo_kernel_horizontal_partition(
tilings = [node_info_map[n][1] for n in nodes]
max_dims = max(len(t) for t in tilings)
nodes_per_ndim = []
nodes_per_ndim: List[List[BaseSchedulerNode]] = []
for i in range(2, max_dims + 1):
group_per_dim = [n for n, t in zip(nodes, tilings) if len(t) == i]
reduction = [
@ -111,12 +111,11 @@ def _default_custom_combo_kernel_horizontal_partition(
len(large_pointwise),
)
not_reduction = [n for n in not_reduction if n not in large_pointwise]
for node in large_pointwise:
nodes_per_ndim.append([node])
nodes_per_ndim.extend([node] for node in large_pointwise)
for g in (not_reduction, short_reduction, long_reduction):
if g:
nodes_per_ndim.append(g)
nodes_per_ndim.extend(
g for g in (not_reduction, short_reduction, long_reduction) if g
)
assert sum(len(p) for p in nodes_per_ndim) == len(nodes)
return nodes_per_ndim

View File

@ -288,8 +288,7 @@ def get_compiler_version_info(compiler: str) -> str:
# =============================== cpp builder ===============================
def _append_list(dest_list: List[str], src_list: List[str]) -> None:
for item in src_list:
dest_list.append(copy.deepcopy(item))
dest_list.extend(copy.deepcopy(item) for item in src_list)
def _remove_duplication_in_list(orig_list: List[str]) -> List[str]:
@ -742,12 +741,11 @@ def _setup_standard_sys_libs(
def _get_build_args_of_chosen_isa(vec_isa: VecISA) -> Tuple[List[str], List[str]]:
macros = []
build_flags = []
macros: List[str] = []
build_flags: List[str] = []
if vec_isa != invalid_vec_isa:
# Add Windows support later.
for x in vec_isa.build_macro():
macros.append(copy.deepcopy(x))
macros.extend(copy.deepcopy(x) for x in vec_isa.build_macro())
build_flags = [vec_isa.build_arch_flags()]

View File

@ -398,9 +398,11 @@ def valid_vec_isa_list() -> List[VecISA]:
arch value is x86_64 on Linux, and the value is AMD64 on Windows.
"""
_cpu_supported_x86_isa = x86_isa_checker()
for isa in supported_vec_isa_list:
if all(flag in _cpu_supported_x86_isa for flag in str(isa).split()) and isa:
isa_list.append(isa)
isa_list.extend(
isa
for isa in supported_vec_isa_list
if all(flag in _cpu_supported_x86_isa for flag in str(isa).split()) and isa
)
return isa_list

View File

@ -621,13 +621,12 @@ class CUDAWarmupNode:
}
def get_non_cudagraph_inps() -> List[weakref.ReferenceType[UntypedStorage]]:
non_cudagraph_inps = []
for t in itertools.chain(new_inputs, self.wrapped_function.constants):
if (
isinstance(t, torch.Tensor)
and t.untyped_storage().data_ptr() not in existing_path_data_ptrs
):
non_cudagraph_inps.append(weakref.ref(t.untyped_storage()))
non_cudagraph_inps = [
weakref.ref(t.untyped_storage())
for t in itertools.chain(new_inputs, self.wrapped_function.constants)
if isinstance(t, torch.Tensor)
and t.untyped_storage().data_ptr() not in existing_path_data_ptrs
]
return non_cudagraph_inps
non_cudagraph_inps_storages = get_non_cudagraph_inps()
@ -1709,12 +1708,10 @@ def get_block_addrs(pool_id: Tuple[int, int], live_only: bool = True) -> List[in
def format_tb(frames: List[Any]) -> str:
formatted_traceback = []
for entry in frames:
formatted_traceback.append(
traceback.FrameSummary(entry["filename"], entry["line"], entry["name"])
)
formatted_traceback = [
traceback.FrameSummary(entry["filename"], entry["line"], entry["name"])
for entry in frames
]
return "".join(traceback.format_list(formatted_traceback))

View File

@ -127,9 +127,7 @@ class MemoryDep(Dep):
return None
stride_to_index = {s: i for i, s in enumerate(self_strides)}
order = []
for s in other_strides:
order.append(stride_to_index[s])
order = [stride_to_index[s] for s in other_strides]
assert set(order) == set(range(0, self.num_vars))
return order
@ -547,9 +545,7 @@ def var_builder(prefix: str) -> Tuple[VarRanges, Callable[[sympy.Expr], sympy.Sy
def index_vars_no_squeeze(*argsizes: Tuple[sympy.Expr, ...], prefix: str):
var_ranges, add_var = var_builder(prefix)
args: List[List[sympy.Symbol]] = []
for size in argsizes:
args.append(list(map(add_var, size)))
args: List[List[sympy.Symbol]] = [list(map(add_var, size)) for size in argsizes]
return args, var_ranges

View File

@ -1120,9 +1120,9 @@ class UnbindCatRemover(SplitCatSimplifier):
return
# we need to check if the getitem indices from unbind are consecutive and all go to the same cat node
# before we do the unbind remove, otherwise it will hit the error when we unbind part of them
getitem_indices = []
for getitem_node in unbind_node.users.keys():
getitem_indices.append(getitem_node.args[1])
getitem_indices = [
getitem_node.args[1] for getitem_node in unbind_node.users.keys()
]
if not is_sorted_and_consecutive(getitem_indices) or len( # type: ignore[arg-type]
getitem_indices
) != len(
@ -1497,9 +1497,8 @@ def merge_getitem_cat(match: Match, split_sections: List[int], dim: int):
):
continue
# find the index of getitems to be cated/stacked
indices = []
for arg in cat_user.args[0]: # type: ignore[union-attr]
indices.append(arg.args[1]) # type: ignore[union-attr]
# type: ignore[union-attr]
indices = [arg.args[1] for arg in cat_user.args[0]] # type: ignore[union-attr]
# the gettitems to be merged must be consecutive, otherwise
# returned sliced tensor could be wrong
if not is_sorted_and_consecutive(indices):

View File

@ -1657,15 +1657,14 @@ class GraphLowering(torch.fx.Interpreter):
new_unbacked_defs |= op.get_unbacked_symbol_defs()
def format_new_defs() -> str:
r = []
for buf in self.buffers[buffer_watermark:]:
r.append(
f"unbacked_symbol_defs={buf.get_unbacked_symbol_defs()} in:\n{buf}\n"
)
for op in self.operations[operation_watermark:]:
r.append(
f"unbacked_symbol_defs={op.get_unbacked_symbol_defs()} in:\n{op}\n"
)
r = [
f"unbacked_symbol_defs={buf.get_unbacked_symbol_defs()} in:\n{buf}\n"
for buf in self.buffers[buffer_watermark:]
]
r.extend(
f"unbacked_symbol_defs={op.get_unbacked_symbol_defs()} in:\n{op}\n"
for op in self.operations[operation_watermark:]
)
return "***\n".join(r)
if n.op != "placeholder":

View File

@ -5484,13 +5484,14 @@ class UserDefinedTritonKernel(ExternKernel):
kernel = kernel_side_table.get_kernel(self.kernel_idx)
configs = []
restore_value_args = []
restore_value_args: List[str] = []
if isinstance(kernel, Autotuner):
# https://github.com/triton-lang/triton/pull/5083
# changes kernel.restore_idx to kernel.restore_value
if hasattr(kernel, "restore_idx"):
for i in kernel.restore_idx:
restore_value_args.append(kernel.fn.arg_names[i])
restore_value_args.extend(
kernel.fn.arg_names[i] for i in kernel.restore_idx
)
else:
assert hasattr(kernel, "restore_value")
restore_value_args.extend(kernel.restore_value)

View File

@ -1691,9 +1691,7 @@ def split_with_sizes(x, sizes, dim=0):
def unbind(x, dim=0):
dim = _validate_dim(x, dim, 0)
x_size = V.graph.sizevars.evaluate_static_shape(x.get_size()[dim])
result = []
for i in range(x_size):
result.append(select(x, dim, i))
result = [select(x, dim, i) for i in range(x_size)]
return result

View File

@ -87,8 +87,7 @@ def _prepare_convolution_fusion_create(
weight_size = []
weight_size.append(prepacked_weight_size[1] * groups)
weight_size.append(prepacked_weight_size[0] / groups)
for d in range(2, dim):
weight_size.append(prepacked_weight_size[d])
weight_size.extend(prepacked_weight_size[d] for d in range(2, dim))
else:
weight_size = prepacked_weight.transpose(0, 1).size()
return weight_size

View File

@ -952,9 +952,10 @@ class PatternPrettyPrinter:
assert hasattr(obj, "pretty_print")
out_str = obj.pretty_print(pp=pp)
output = []
for key in pp.memoized_objs_names:
output.append(f"{pp.memoized_objs_names[key]} = {pp.memoized_objs_pp[key]}")
output = [
f"{pp.memoized_objs_names[key]} = {pp.memoized_objs_pp[key]}"
for key in pp.memoized_objs_names
]
output.append(f"{output_name} = {out_str}")
@ -1361,9 +1362,7 @@ def register_replacement(
return False
def normalize_args(**kwargs: Any) -> List[Any]:
args = []
for name in argnames_static:
args.append(kwargs.pop(name))
args = [kwargs.pop(name) for name in argnames_static]
for i in range(1, len(kwargs) + 1):
if f"tangents_{i}" not in kwargs:
break

View File

@ -120,7 +120,7 @@ def autotune_hints_to_configs(
configs to try.
"""
xyz_options: Tuple[Tuple[int, Optional[int], Optional[int]], ...]
configs = []
configs: List[Config] = []
warp_size = device_props.warp_size
# CPU target has no concept of "warp"
if warp_size is None:
@ -138,16 +138,16 @@ def autotune_hints_to_configs(
(1, block_size // 4, 1),
(1, 1, block_size // 4),
)
for xyz in xyz_options:
configs.append(
triton_config(
size_hints,
*xyz,
num_elements_per_warp=(
device_props.warp_size if device_props.warp_size else 32
),
)
configs.extend(
triton_config(
size_hints,
*xyz,
num_elements_per_warp=(
device_props.warp_size if device_props.warp_size else 32
),
)
for xyz in xyz_options
)
return configs

View File

@ -308,9 +308,13 @@ class BaseSchedulerNode:
dep = deps.pop()
used_names.add(dep)
if V.graph.name_to_buffer.get(dep):
for alias in V.graph.name_to_buffer[dep].get_inputs_that_alias_output():
if alias not in used_names:
deps.append(alias)
deps.extend(
alias
for alias in V.graph.name_to_buffer[
dep
].get_inputs_that_alias_output()
if alias not in used_names
)
return used_names
def prune_deps(self) -> None:
@ -3324,10 +3328,11 @@ class Scheduler:
node1 = node2
node2 = tmp
deps = []
for dep in node1.read_writes.reads | node1.read_writes.writes:
if dep in node2.read_writes.reads or dep in node2.read_writes.writes:
deps.append(dep)
deps = [
dep
for dep in node1.read_writes.reads | node1.read_writes.writes
if dep in node2.read_writes.reads or dep in node2.read_writes.writes
]
return sum(self.dep_size_hint(dep) for dep in deps)

View File

@ -176,14 +176,12 @@ def derived_types(
]
if list_base:
for seq_typ in derived_seq_types(base_type):
result.append((seq_typ, f"{cpp_type}[]")) # type: ignore[valid-type]
result.extend((seq_typ, f"{cpp_type}[]") for seq_typ in derived_seq_types(base_type)) # type: ignore[valid-type]
if optional_base_list:
for seq_typ in derived_seq_types(typing.Optional[base_type]):
result.append((seq_typ, f"{cpp_type}?[]")) # type: ignore[valid-type]
result.extend((seq_typ, f"{cpp_type}?[]") for seq_typ in derived_seq_types(typing.Optional[base_type])) # type: ignore[valid-type]
if optional_list_base:
for seq_typ in derived_seq_types(base_type): # type: ignore[valid-type]
result.append((typing.Optional[seq_typ], f"{cpp_type}[]?")) # type: ignore[valid-type]
# type: ignore[valid-type]
result.extend((typing.Optional[seq_typ], f"{cpp_type}[]?") for seq_typ in derived_seq_types(base_type)) # type: ignore[valid-type]
return result

View File

@ -279,17 +279,17 @@ def requires_set_python_module() -> bool:
def handle_dispatch_mode(curr_mode, op_overload, *args, **kwargs):
assert isinstance(curr_mode, torch.utils._python_dispatch.TorchDispatchMode)
overload_types = []
args_flattened, _ = torch.utils._pytree.tree_flatten((args, kwargs.values()))
for a in args_flattened:
# TODO: need to double check the semantics of the "types" argument to torch_dispatch.
# It's generated in PyInterpreter.cpp, but seems to be generated in two places,
# where in one case we only include tensors with the python key, and in another
# we include **all** tensors.
if isinstance(a, torch.Tensor) and torch._C._dispatch_keys(a).has(
torch._C.DispatchKey.Python
):
overload_types.append(type(a))
# TODO: need to double check the semantics of the "types" argument to torch_dispatch.
# It's generated in PyInterpreter.cpp, but seems to be generated in two places,
# where in one case we only include tensors with the python key, and in another
# we include **all** tensors.
overload_types = [
type(a)
for a in args_flattened
if isinstance(a, torch.Tensor)
and torch._C._dispatch_keys(a).has(torch._C.DispatchKey.Python)
]
# TODO: check that I got these args correct (in C++, we pass in "0000"??)
return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs)

View File

@ -43,15 +43,14 @@ def dump_file(filename: str) -> None:
def from_traceback(tb: Sequence[traceback.FrameSummary]) -> List[Dict[str, Any]]:
r = []
for frame in tb:
# dict naming convention here coincides with
# python/combined_traceback.cpp
r.append(
{
"line": frame.lineno,
"name": frame.name,
"filename": intern_string(frame.filename),
}
)
# dict naming convention here coincides with
# python/combined_traceback.cpp
r = [
{
"line": frame.lineno,
"name": frame.name,
"filename": intern_string(frame.filename),
}
for frame in tb
]
return r

View File

@ -312,10 +312,11 @@ def _make_prim(
prim_autograd_impl.impl(name, _autograd_impl)
prim_meta_impl.impl(name, meta)
else:
mutates_args = []
for arg in cpp_schema.arguments:
if arg.alias_info is not None and arg.alias_info.is_write:
mutates_args.append(arg.name)
mutates_args = [
arg.name
for arg in cpp_schema.arguments
if arg.alias_info is not None and arg.alias_info.is_write
]
prim_def = torch.library.custom_op(
"prims::" + name,
_prim_impl,

View File

@ -3055,9 +3055,7 @@ def chunk(a: TensorLikeType, chunks: int, dim: int = 0) -> Tuple[TensorLikeType,
full_chunks = math.floor(length / chunk_size)
tail_chunk_size = length % chunk_size
result = []
for i in range(full_chunks):
result.append(narrow(a, dim, i * chunk_size, chunk_size))
result = [narrow(a, dim, i * chunk_size, chunk_size) for i in range(full_chunks)]
if tail_chunk_size != 0:
result.append(narrow(a, dim, full_chunks * chunk_size, tail_chunk_size))

View File

@ -585,14 +585,13 @@ def has_meta(func):
lambda func: is_builtin(func) and "foreach" in func.name() and has_meta(func)
)
def foreach_run_and_map_input_device(fake_mode, func, *args, **kwargs):
tensor_lists = []
for arg in itertools.chain(args, kwargs.values()):
if (
isinstance(arg, (list, tuple))
and len(arg)
and isinstance(arg[0], torch.Tensor)
):
tensor_lists.append(arg)
tensor_lists = [
arg
for arg in itertools.chain(args, kwargs.values())
if isinstance(arg, (list, tuple))
and len(arg)
and isinstance(arg[0], torch.Tensor)
]
try:
with in_kernel_invocation_manager(fake_mode):

View File

@ -171,8 +171,7 @@ def get_plain_tensors(
continue
inner_keys, _ = curr.__tensor_flatten__()
for key in reversed(inner_keys):
todo.append(getattr(curr, key))
todo.extend(getattr(curr, key) for key in reversed(inner_keys))
return out
@ -1629,13 +1628,12 @@ class FakeTensorMode(TorchDispatchMode):
)
if isinstance(output, tuple):
output_infos = []
for out_elem in output:
output_infos.append(
self._get_output_info_for_cache_entry(
state, key, func, args, kwargs, out_elem
)
output_infos = [
self._get_output_info_for_cache_entry(
state, key, func, args, kwargs, out_elem
)
for out_elem in output
]
return _DispatchCacheEntry(
output_infos=tuple(output_infos), is_output_tuple=True
)
@ -1727,17 +1725,16 @@ class FakeTensorMode(TorchDispatchMode):
"""
if entry.is_output_tuple:
outputs = []
for output_info in entry.output_infos:
outputs.append(
self._get_output_tensor_from_cache_entry(
state,
output_info,
key,
func,
args,
)
outputs = [
self._get_output_tensor_from_cache_entry(
state,
output_info,
key,
func,
args,
)
for output_info in entry.output_infos
]
return tuple(outputs)
else:
return self._get_output_tensor_from_cache_entry(

View File

@ -389,17 +389,17 @@ class LSTM(torch.nn.Module):
**factory_kwargs,
)
]
for layer in range(1, num_layers):
layers.append(
_LSTMLayer(
self.hidden_size,
self.hidden_size,
self.bias,
batch_first=False,
bidirectional=self.bidirectional,
**factory_kwargs,
)
layers.extend(
_LSTMLayer(
self.hidden_size,
self.hidden_size,
self.bias,
batch_first=False,
bidirectional=self.bidirectional,
**factory_kwargs,
)
for layer in range(1, num_layers)
)
self.layers = torch.nn.ModuleList(layers)
def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):

View File

@ -32,8 +32,7 @@ def _reverse_repeat_padding(padding: List[int]) -> List[int]:
_reversed_padding_repeated_twice: List[int] = []
N = len(padding)
for idx in range(N):
for _ in range(2):
_reversed_padding_repeated_twice.append(padding[N - idx - 1])
_reversed_padding_repeated_twice.extend(padding[N - idx - 1] for _ in range(2))
return _reversed_padding_repeated_twice

View File

@ -568,9 +568,9 @@ def create_one_transformed_and_logged_copy_of_subgraph(
and len(arg)
and isinstance(arg[0], Node)
):
for inner_arg in arg:
if isinstance(inner_arg, Node):
new_args.append(inner_arg)
new_args.extend(
inner_arg for inner_arg in arg if isinstance(inner_arg, Node)
)
new_kwargs = {}
for name, old_kwarg in first_node.kwargs.items():

View File

@ -312,10 +312,7 @@ def get_arg_indices_of_inputs_to_log(node: Node) -> List[int]:
node.target in (torch.add, torch.ops.quantized.add, operator.add)
or node.target in (torch.mul, torch.ops.quantized.mul, operator.mul)
):
result = []
for i in range(2):
if type(node.args[i]) == Node:
result.append(i)
result = [i for i in range(2) if type(node.args[i]) == Node]
return result
return [0]

View File

@ -191,8 +191,7 @@ def expand_groups_in_paired_modules_list(paired_modules_list):
elif len(group) == 2:
new_list.append(group)
elif len(group) > 2:
for i in range(len(group) - 1):
new_list.append([group[i], group[i + 1]])
new_list.extend([group[i], group[i + 1]] for i in range(len(group) - 1))
return new_list

View File

@ -155,14 +155,14 @@ def _get_binary_op_configs(
(op_with_quantized_bop_scalar_variant, torch.relu),
op_with_quantized_bop_scalar_variant,
]
for bop_pattern in bop_patterns:
binary_op_configs.append(
BackendPatternConfig(bop_pattern)
.set_dtype_configs(dtype_configs) # noqa: E131
._set_num_tensor_args_to_observation_type(
num_tensor_args_to_observation_type_mapping
)
binary_op_configs.extend(
BackendPatternConfig(bop_pattern)
.set_dtype_configs(dtype_configs) # noqa: E131
._set_num_tensor_args_to_observation_type(
num_tensor_args_to_observation_type_mapping
)
for bop_pattern in bop_patterns
)
# matmul
binary_op_configs.append(
BackendPatternConfig(torch.matmul).set_dtype_configs(
@ -502,7 +502,6 @@ def _get_ln_configs(dtype_configs: List[DTypeConfig]) -> List[BackendPatternConf
def _get_default_op_configs(
dtype_configs: List[DTypeConfig],
) -> List[BackendPatternConfig]:
configs = []
default_ops = [
torch.nn.ELU,
torch.nn.LeakyReLU,
@ -517,14 +516,14 @@ def _get_default_op_configs(
torch.nn.functional.leaky_relu,
torch.nn.functional.dropout,
]
for op in default_ops:
configs.append(
BackendPatternConfig(op)
.set_observation_type(
ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
) # noqa: E131
.set_dtype_configs(dtype_configs)
)
configs = [
BackendPatternConfig(op)
.set_observation_type(
ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
) # noqa: E131
.set_dtype_configs(dtype_configs)
for op in default_ops
]
configs.append(
BackendPatternConfig(torch.nn.functional.group_norm)

View File

@ -159,14 +159,14 @@ def get_binary_op_configs():
# TODO: remove when functionalization is supported in pt2_mode
(op_with_quantized_bop_scalar_variant, torch.ops.aten.relu_.default),
]
for bop_pattern in bop_patterns:
binary_op_configs.append(
BackendPatternConfig(bop_pattern)
.set_dtype_configs(dtype_configs) # noqa: E131
._set_num_tensor_args_to_observation_type(
num_tensor_args_to_observation_type_mapping
)
binary_op_configs.extend(
BackendPatternConfig(bop_pattern)
.set_dtype_configs(dtype_configs) # noqa: E131
._set_num_tensor_args_to_observation_type(
num_tensor_args_to_observation_type_mapping
)
for bop_pattern in bop_patterns
)
return binary_op_configs

View File

@ -325,14 +325,14 @@ def _get_binary_ops_configs() -> List[BackendPatternConfig]:
(op, torch.relu),
op,
]
for bop_pattern in bop_patterns:
binary_op_configs.append(
BackendPatternConfig(bop_pattern)
.set_dtype_configs(dtype_configs) # noqa: E131
._set_num_tensor_args_to_observation_type(
num_tensor_args_to_observation_type_mapping
)
binary_op_configs.extend(
BackendPatternConfig(bop_pattern)
.set_dtype_configs(dtype_configs) # noqa: E131
._set_num_tensor_args_to_observation_type(
num_tensor_args_to_observation_type_mapping
)
for bop_pattern in bop_patterns
)
return binary_op_configs
@ -385,13 +385,12 @@ def _get_share_qparams_ops_configs() -> List[BackendPatternConfig]:
"squeeze_",
"leaky_relu",
]
share_qparams_op_configs: List[BackendPatternConfig] = []
for op in share_qparams_ops:
share_qparams_op_configs.append(
BackendPatternConfig(op)
.set_observation_type(observation_type) # noqa: E131
.set_dtype_configs(dtype_configs)
)
share_qparams_op_configs: List[BackendPatternConfig] = [
BackendPatternConfig(op)
.set_observation_type(observation_type) # noqa: E131
.set_dtype_configs(dtype_configs)
for op in share_qparams_ops
]
return share_qparams_op_configs

View File

@ -36,18 +36,20 @@ def get_pattern_to_dtype_configs(
def get_qat_module_classes(backend_config: BackendConfig) -> Tuple[type, ...]:
qat_module_classes = []
for config in backend_config.configs:
if config.qat_module is not None:
qat_module_classes.append(config.qat_module)
qat_module_classes = [
config.qat_module
for config in backend_config.configs
if config.qat_module is not None
]
return tuple(set(qat_module_classes))
def get_fused_module_classes(backend_config: BackendConfig) -> Tuple[type, ...]:
fused_module_classes = []
for config in backend_config.configs:
if config.fused_module is not None:
fused_module_classes.append(config.fused_module)
fused_module_classes = [
config.fused_module
for config in backend_config.configs
if config.fused_module is not None
]
return tuple(set(fused_module_classes))

View File

@ -92,9 +92,7 @@ def _fuse_modules_helper(
additional_fuser_method_mapping = fuse_custom_config_dict.get(
"additional_fuser_method_mapping", {}
)
mod_list = []
for item in modules_to_fuse:
mod_list.append(_get_module(model, item))
mod_list = [_get_module(model, item) for item in modules_to_fuse]
# Fuse list of modules
new_mod_list = fuser_func(mod_list, is_qat, additional_fuser_method_mapping)

View File

@ -266,9 +266,7 @@ def _get_valid_patterns(op_pattern):
"""
result: List[Any]
if isinstance(op_pattern, (tuple, list)):
sub_combs = []
for sub_pattern in op_pattern:
sub_combs.append(_get_valid_patterns(sub_pattern))
sub_combs = [_get_valid_patterns(sub_pattern) for sub_pattern in op_pattern]
result = list(itertools.product(*sub_combs))
else:
result = [op_pattern, MatchAllNode]

View File

@ -527,8 +527,9 @@ class ModelReportVisualizer:
# gather the x_data and multiple y_data
# calculate the number of channels
num_channels: int = max(row[self.CHANNEL_NUM_INDEX] for row in table) + 1
for channel in range(num_channels):
y_data.append([]) # separate data list per channel
y_data.extend(
[] for channel in range(num_channels)
) # separate data list per channel
for table_row_num, row in enumerate(table):
# get x_value to append

View File

@ -1116,10 +1116,9 @@ def convert(
# for dynamic quant ops or weight only quant ops
_run_weight_observers(model, backend_config)
graph_inputs: List[str] = []
for node in model.graph.nodes:
if node.op == "placeholder":
graph_inputs.append(node.name)
graph_inputs: List[str] = [
node.name for node in model.graph.nodes if node.op == "placeholder"
]
# additional state to override inputs to be quantized, if specified
# by the user

View File

@ -78,8 +78,7 @@ class DefaultFuseHandler(FuseHandler):
n, *args = pattern
modules: List[torch.nn.Module] = []
modules.append(get_modules(n))
for a in args:
modules.append(get_modules(a))
modules.extend(get_modules(a) for a in args)
return tuple(modules)
else:
n = pattern
@ -111,9 +110,7 @@ class DefaultFuseHandler(FuseHandler):
# as input
fused_module = fuser_method(is_qat, *matched_modules)
setattr(named_modules[module_parent_name], module_name, fused_module)
extra_args = []
for input in extra_inputs:
extra_args.append(load_arg(input))
extra_args = [load_arg(input) for input in extra_inputs]
node = fused_graph.node_copy(root_node, load_arg)
args = list(node.args)
args.extend(extra_args)

View File

@ -1219,13 +1219,12 @@ def _maybe_insert_observers_before_graph_output(
else:
return maybe_node
elif isinstance(maybe_node, (list, tuple)):
results = []
for inner_node in maybe_node:
results.append(
_recursive_maybe_replace_node_with_obs(
inner_node, model, named_modules, graph
)
results = [
_recursive_maybe_replace_node_with_obs(
inner_node, model, named_modules, graph
)
for inner_node in maybe_node
]
if isinstance(maybe_node, list):
return results
else:
@ -1244,11 +1243,10 @@ def _maybe_insert_observers_before_graph_output(
"Unhandled type for returned node:", maybe_node
)
new_args = []
for old_arg in graph_output_node.args:
new_args.append(
_recursive_maybe_replace_node_with_obs(old_arg, model, named_modules, graph)
)
new_args = [
_recursive_maybe_replace_node_with_obs(old_arg, model, named_modules, graph)
for old_arg in graph_output_node.args
]
graph_output_node.args = tuple(new_args) # type: ignore[assignment]

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import itertools
import operator
from typing import Any, Callable, List, Optional, OrderedDict, Set
from typing import Any, Callable, List, Optional, OrderedDict, Sequence, Set
import torch
from torch.fx import Node
@ -58,7 +58,7 @@ def update_equivalent_types_dict(customized_equivalent_types=None):
_EQUIVALENT_TYPES_DICT = _create_equivalent_types_dict()
def _partitions_sequential(partitions: List[SourcePartition]):
def _partitions_sequential(partitions: Sequence[SourcePartition]):
prev_partition = None
for partition in partitions:
if prev_partition is not None and not check_subgraphs_connected(
@ -108,8 +108,9 @@ def find_sequential_partitions(
typed_partitions_list = list(typed_partitions.values())
fusion_candidates = itertools.product(*typed_partitions_list)
fused_partitions = []
for candidate in fusion_candidates:
if _partitions_sequential(candidate): # type: ignore[arg-type]
fused_partitions.append(candidate)
fused_partitions = [
candidate
for candidate in fusion_candidates
if _partitions_sequential(candidate)
]
return fused_partitions

View File

@ -93,10 +93,10 @@ def _get_supported_symmetric_config_and_operators() -> List[OperatorConfig]:
get_symmetric_quantization_config(is_per_channel=True, is_qat=True),
]:
ops = _supported_symmetric_quantized_operators()
for pattern_list in ops.values():
supported_config_and_operators.append(
OperatorConfig(quantization_config, pattern_list)
)
supported_config_and_operators.extend(
OperatorConfig(quantization_config, pattern_list)
for pattern_list in ops.values()
)
return copy.deepcopy(supported_config_and_operators)

View File

@ -55,10 +55,11 @@ def _allocate_jacobians_with_inputs(
# of `t.numel` and width of `numel_output`. Otherwise, for each tensor, returns
# a 1-d tensor with size `(t.numel,)`. Each new tensor will be strided and have
# the same dtype and device as those of the corresponding input.
out: List[torch.Tensor] = []
for t in input_tensors:
if _is_float_or_complex_tensor(t) and t.requires_grad:
out.append(t.new_zeros((t.numel(), numel_output), layout=torch.strided))
out: List[torch.Tensor] = [
t.new_zeros((t.numel(), numel_output), layout=torch.strided)
for t in input_tensors
if _is_float_or_complex_tensor(t) and t.requires_grad
]
return tuple(out)
@ -69,11 +70,12 @@ def _allocate_jacobians_with_outputs(
# in `output_tensors`, returns a new zero-filled tensor with height of `dim` and
# width of `t.numel`. Otherwise, for each tensor, returns a 1-d tensor with size
# (t.numel,).
out: List[torch.Tensor] = []
options = {"dtype": dtype, "device": device, "layout": torch.strided}
for t in output_tensors:
if _is_float_or_complex_tensor(t):
out.append(t.new_zeros((numel_input, t.numel()), **options))
out: List[torch.Tensor] = [
t.new_zeros((numel_input, t.numel()), **options)
for t in output_tensors
if _is_float_or_complex_tensor(t)
]
return tuple(out)
@ -904,10 +906,10 @@ def _compute_analytical_jacobian_rows(
def _get_analytical_vjps_wrt_specific_output(
vjp_fn, sample_output, v
) -> List[List[Optional[torch.Tensor]]]:
vjps: List[List[Optional[torch.Tensor]]] = []
grad_inputs = vjp_fn(v.reshape(sample_output.shape))
for vjp in grad_inputs:
vjps.append([vjp.clone() if isinstance(vjp, torch.Tensor) else None])
vjps: List[List[Optional[torch.Tensor]]] = [
[vjp.clone() if isinstance(vjp, torch.Tensor) else None] for vjp in grad_inputs
]
return vjps

View File

@ -802,9 +802,7 @@ def _register_logging_hooks_on_whole_graph(
log_str = f"Executing: {node} with grad_outputs: {grad_outputs_str}"
log.debug(log_str)
handles = []
for node in iter_graph(grad_fns):
handles.append(node.register_prehook(prehook))
handles = [node.register_prehook(prehook) for node in iter_graph(grad_fns)]
def unregister_hooks() -> None:
for handle in handles:

View File

@ -856,10 +856,9 @@ def _build_table(
flops_column_width = DEFAULT_COLUMN_WIDTH
src_column_width = None
stacks = []
for evt in events:
if evt.stack is not None and len(evt.stack) > 0:
stacks.append(evt.stack)
stacks = [
evt.stack for evt in events if evt.stack is not None and len(evt.stack) > 0
]
has_stack = len(stacks) > 0
if has_stack:
src_column_width = (
@ -947,10 +946,7 @@ def _build_table(
if with_flops:
# Auto-scaling of flops header
raw_flops = []
for evt in events:
if evt.flops > 0:
raw_flops.append(evt.flops)
raw_flops = [evt.flops for evt in events if evt.flops > 0]
if len(raw_flops) != 0:
(flops_scale, flops_header) = auto_scale_flops(min(raw_flops))
headers.append(f"Total {flops_header}")

View File

@ -322,9 +322,7 @@ def _lazy_init():
# However, we must not let any *other* threads in!
_tls.is_initializing = True
for calls in _lazy_seed_tracker.get_calls():
if calls:
_queued_calls.append(calls)
_queued_calls.extend(calls for calls in _lazy_seed_tracker.get_calls() if calls)
try:
for queued_call, orig_traceback in _queued_calls:

View File

@ -44,9 +44,7 @@ def get_rng_state(device: Union[int, str, torch.device] = "cuda") -> Tensor:
def get_rng_state_all() -> List[Tensor]:
r"""Return a list of ByteTensor representing the random number states of all devices."""
results = []
for i in range(device_count()):
results.append(get_rng_state(i))
results = [get_rng_state(i) for i in range(device_count())]
return results

View File

@ -31,8 +31,9 @@ class ShardedOptimizer(optim.Optimizer):
tensors: List[Tensor] = []
for value in named_params.values():
if isinstance(value, ShardedTensor):
for local_shard in value.local_shards():
tensors.append(local_shard.tensor)
tensors.extend(
local_shard.tensor for local_shard in value.local_shards()
)
else:
tensors.append(value)

View File

@ -99,9 +99,10 @@ def sharded_type_as(args, kwargs, pg):
tensor = args[1]
if isinstance(tensor, ShardedTensor):
tensor = tensor.local_tensor()
new_local_shards = []
for shard in st.local_shards():
new_local_shards.append(Shard(shard.tensor.type_as(tensor), shard.metadata))
new_local_shards = [
Shard(shard.tensor.type_as(tensor), shard.metadata)
for shard in st.local_shards()
]
st_meta = copy.deepcopy(st._metadata)
st_meta.tensor_properties.dtype = tensor.dtype
return new_local_shards, st_meta

View File

@ -191,9 +191,9 @@ def reshard_local_shard(
)
# Compute expected size
input_split_sizes = []
for metadata in shards_metadata:
input_split_sizes.append(metadata.shard_sizes[reshard_dim])
input_split_sizes = [
metadata.shard_sizes[reshard_dim] for metadata in shards_metadata
]
rearrange_input = any(ranks[i] > ranks[i + 1] for i in range(len(ranks) - 1))
if rearrange_input:

View File

@ -142,8 +142,7 @@ def _run_trainer(emb_rref_list, rank):
)
# model.parameters() only includes local parameters.
for param in model.parameters():
model_parameter_rrefs.append(RRef(param))
model_parameter_rrefs.extend(RRef(param) for param in model.parameters())
# Setup distributed optimizer
opt = DistributedOptimizer(optim.SGD, model_parameter_rrefs, lr=0.05)

View File

@ -207,8 +207,10 @@ def _create_default_metadata_only_plan(state_dict: STATE_DICT_TYPE) -> SavePlan:
if isinstance(obj, DTensor):
requests.append(_create_write_items_for_dtensor(fqn, obj))
elif isinstance(obj, ShardedTensor):
for shard_md in obj.metadata().shards_metadata:
requests.append(_create_write_item_for_shard(fqn, obj, shard_md))
requests.extend(
_create_write_item_for_shard(fqn, obj, shard_md)
for shard_md in obj.metadata().shards_metadata
)
elif isinstance(obj, torch.Tensor):
requests.append(_create_write_item_for_tensor(fqn, obj))
else:

View File

@ -1371,9 +1371,9 @@ def _get_all_pg_configs() -> List[Dict[str, Any]]:
Return the pg configuration of all the process groups.
"""
config_info: List[Dict[str, Any]] = []
for pg in _world.pg_map.keys():
config_info.append(_get_pg_config(pg))
config_info: List[Dict[str, Any]] = [
_get_pg_config(pg) for pg in _world.pg_map.keys()
]
return config_info
@ -2508,9 +2508,7 @@ def _coalescing_manager(
# - coalesced `reduce_scatter_tensor`
op0 = op_list[0].op
if op0 == all_reduce:
tensors = []
for op in op_list:
tensors.append(op.tensor)
tensors = [op.tensor for op in op_list]
all_reduce_opts = AllreduceCoalescedOptions()
all_reduce_opts.reduceOp = not_none(op_list[0].redop)
work = group.allreduce_coalesced(tensors, all_reduce_opts)

View File

@ -568,10 +568,11 @@ def _root_pre_forward(
state._needs_buffer_dtype_restore_check = False
if state.forward_prefetch:
handles = []
for fsdp_state in state._all_fsdp_states:
if fsdp_state._handle:
handles.append(fsdp_state._handle)
handles = [
fsdp_state._handle
for fsdp_state in state._all_fsdp_states
if fsdp_state._handle
]
for handle in handles:
handle._needs_pre_forward_unshard = True
handle._prefetched = False

View File

@ -110,9 +110,9 @@ def _create_module_with_interface(
def _param_rrefs(module_rref, recurse) -> List[rpc.RRef[Parameter]]:
ret: List[rpc.RRef[Parameter]] = []
for param in module_rref.local_value().parameters(recurse):
ret.append(rpc.RRef(param))
ret: List[rpc.RRef[Parameter]] = [
rpc.RRef(param) for param in module_rref.local_value().parameters(recurse)
]
return ret

View File

@ -146,9 +146,7 @@ class _NamedOptimizer(optim.Optimizer):
ret_groups = []
for group in param_groups:
param_keys = []
for param in group["params"]:
param_keys.append(self.ordered_param_keys[param])
param_keys = [self.ordered_param_keys[param] for param in group["params"]]
ret_group = {"params": sorted(param_keys)}
for k, v in group.items():
if k != "params":

View File

@ -245,13 +245,12 @@ class DistributedOptimizer:
else _local_optimizer_step
)
rpc_futs = []
for optimizer in self.remote_optimizers:
rpc_futs.append(
rpc.rpc_async(
optimizer.owner(),
optimizer_step_func,
args=(optimizer, context_id),
)
rpc_futs = [
rpc.rpc_async(
optimizer.owner(),
optimizer_step_func,
args=(optimizer, context_id),
)
for optimizer in self.remote_optimizers
]
_wait_for_all(rpc_futs)

View File

@ -781,15 +781,15 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
self.process_group, rank
)
for param_group in param_groups:
for param in param_group["params"]:
handles.append(
dist.broadcast(
tensor=param.data,
src=global_rank,
group=self.process_group,
async_op=True,
)
handles.extend(
dist.broadcast(
tensor=param.data,
src=global_rank,
group=self.process_group,
async_op=True,
)
for param in param_group["params"]
)
return handles
def _sync_params(self):

View File

@ -224,9 +224,7 @@ def _shard_dict_of_args(
for chunk_idx in range(real_num_chunks):
chunk_args = {}
for key, arg in args_sharded_replicated.items():
arg_single_chunk = []
for v_flat in arg:
arg_single_chunk.append(v_flat[chunk_idx])
arg_single_chunk = [v_flat[chunk_idx] for v_flat in arg]
chunk_args[key] = arg_single_chunk
chunks_flat.append(chunk_args)
@ -340,9 +338,10 @@ def split_args_kwargs_into_chunks(
f"{len(args_split_dict)}, {len(kwargs_split)}"
)
args_split = []
for chunk_args in args_split_dict:
args_split.append(tuple(chunk_args[i] for i in range(len(chunk_args))))
args_split = [
tuple(chunk_args[i] for i in range(len(chunk_args)))
for chunk_args in args_split_dict
]
return args_split, kwargs_split

View File

@ -1702,16 +1702,14 @@ class ScheduleLoopedBFS(PipelineScheduleMulti):
)
# Store the list of operations used for that rank
rank_ops: List[Optional[_Action]] = []
# Pre-padding, rank starts with no-ops based on the warmup.
for _ in range(rank):
rank_ops.append(None)
rank_ops: List[Optional[_Action]] = [None for _ in range(rank)]
for stage_index in stage_indices:
for mb_index in range(self._n_microbatches):
rank_ops.append(
_Action(stage_index, _ComputationType.FORWARD, mb_index)
)
rank_ops.extend(
_Action(stage_index, _ComputationType.FORWARD, mb_index)
for mb_index in range(self._n_microbatches)
)
# wait for the first backward to trickle up
# which is 2 for every hop away
@ -1719,10 +1717,10 @@ class ScheduleLoopedBFS(PipelineScheduleMulti):
rank_ops.extend([None] * post_warmup_ops)
for stage_index in reversed(stage_indices):
for mb_index in reversed(range(self._n_microbatches)):
rank_ops.append(
_Action(stage_index, _ComputationType.FULL_BACKWARD, mb_index)
)
rank_ops.extend(
_Action(stage_index, _ComputationType.FULL_BACKWARD, mb_index)
for mb_index in reversed(range(self._n_microbatches))
)
return rank_ops
@ -1744,10 +1742,8 @@ def _get_1f1b_rank_ops(
weight_stage_mb_index: Dict[int, int] = defaultdict(int)
# Store the list of operations used for that rank
rank_ops: List[Optional[_Action]] = []
# Pre-padding, rank starts with no-ops based on the warmup.
for _ in range(rank):
rank_ops.append(None)
rank_ops: List[Optional[_Action]] = [None for _ in range(rank)]
# These are used to calculate the number of slots to fill with no-ops, to account for the delay in warmup
# when we want to wait for the backward to trickle back up and start 1f1b to align all ranks.
# Formula:

View File

@ -195,8 +195,7 @@ def fill_empty_tensor_to_shards(
size if idx != shard_dim else 0 for idx, size in enumerate(tensor_size)
]
tensor = shards[0].new_zeros(tensor_size)
for _ in range(num_empty_tensors):
shards.append(tensor)
shards.extend(tensor for _ in range(num_empty_tensors))
return shards

View File

@ -167,9 +167,7 @@ def gen_einsum_strategies(
# (i.e. for Shard, tensor dim size must > mesh size)
all_strategies = []
for strategy_comb in strategy_combs:
spec_list = []
for specs in zip(*strategy_comb):
spec_list.append(DTensorSpec(mesh, tuple(specs)))
spec_list = [DTensorSpec(mesh, tuple(specs)) for specs in zip(*strategy_comb)]
strat = PlacementStrategy(output_specs=spec_list[0], input_specs=spec_list[1:])
all_strategies.append(strat)

View File

@ -43,18 +43,17 @@ def default_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
# Default strategy by default just propagate the first input strategy
select_strategy = op_schema.args_schema[0]
assert isinstance(select_strategy, OpStrategy)
default_strategy = []
for strategy in select_strategy.strategies:
# we create new DTensorSpecs even for default strategy to assure that
# the tensor metas are distinct between the arguments and outputs
default_strategy.append(
PlacementStrategy(
output_specs=DTensorSpec(
mesh=strategy.output_spec.mesh,
placements=strategy.output_spec.placements,
)
# we create new DTensorSpecs even for default strategy to assure that
# the tensor metas are distinct between the arguments and outputs
default_strategy = [
PlacementStrategy(
output_specs=DTensorSpec(
mesh=strategy.output_spec.mesh,
placements=strategy.output_spec.placements,
)
)
for strategy in select_strategy.strategies
]
return OpStrategy(default_strategy)

View File

@ -209,9 +209,10 @@ def map_placements_after_broadcast(
def generate_redistribute_costs(
src_strategy: OpStrategy, dst_spec: DTensorSpec
) -> List[float]:
redistribute_costs: List[float] = []
for strat in src_strategy.strategies:
redistribute_costs.append(redistribute_cost(strat.output_spec, dst_spec))
redistribute_costs: List[float] = [
redistribute_cost(strat.output_spec, dst_spec)
for strat in src_strategy.strategies
]
return redistribute_costs

View File

@ -152,9 +152,11 @@ class InterpreterModule(torch.nn.Module):
# the keys in the kwarg dict.
arg_list = list(args)
kwarg_names = self.arg_names[len(arg_list) :]
for kwarg_name in kwarg_names:
if kwarg_name in kwargs:
arg_list.append(kwargs[kwarg_name])
arg_list.extend(
kwargs[kwarg_name]
for kwarg_name in kwarg_names
if kwarg_name in kwargs
)
# Assert that the kwargs passed in exactly match the positional
# arguments specified by the GraphModule. This should be
@ -933,9 +935,10 @@ class _ModuleFrame:
assert kwargs_spec.context is not None
with self.graph.inserting_after(None):
arg_nodes = []
for idx in range(args_spec.num_children):
arg_nodes.append(self.graph.placeholder(f"_positional_arg_{idx}"))
arg_nodes = [
self.graph.placeholder(f"_positional_arg_{idx}")
for idx in range(args_spec.num_children)
]
kwarg_nodes = {}
for name in kwargs_spec.context:
kwarg_nodes[name] = self.graph.placeholder(name)

View File

@ -1015,10 +1015,11 @@ class Partitioner:
# Keep tracking the partition pair of node pair
partition_pair: List[Partition] = []
# Collect all the op nodes from the graph
op_nodes = []
for n in self.graph_module.graph.nodes:
if n.op not in {"placeholder", "get_attr", "output"}:
op_nodes.append(n)
op_nodes = [
n
for n in self.graph_module.graph.nodes
if n.op not in {"placeholder", "get_attr", "output"}
]
for node in op_nodes:
# Find which partition the current node belongs
p0_index = self.node_to_partition[node]

View File

@ -648,11 +648,7 @@ def generate_reshape(constraint, counter):
# then there must be exactly one occurrence of dyn
else:
new_target = []
for n in target:
if n != Dyn:
new_target.append(n)
new_target = [n for n in target if n != Dyn]
# tensor 1
c3_tensor1 = Disj(
@ -886,10 +882,8 @@ def generate_all_int_dyn_dim_possibilities(my_list: List[DVar]):
neq_possibilities = [
BinConstraintD(my_list[i], Dyn, op_neq) for i in range(len(my_list))
]
d_possibilities = []
for i in zip(eq_possibilities, neq_possibilities):
d_possibilities.append(list(i))
d_possibilities = [list(i) for i in zip(eq_possibilities, neq_possibilities)]
all_possibilities = list(itertools.product(*d_possibilities))
return all_possibilities
@ -1043,13 +1037,11 @@ def apply_padding(
assert len(simulate_padding + d1) == len(d2)
broadcast_padding = []
# for every padding size, we also consider broadcasting
for j in range(len(d2) - i):
broadcast_padding.append(
broadcast_dim(simulate_padding, d2, d11, d12, j, True)
)
broadcast_padding = [
broadcast_dim(simulate_padding, d2, d11, d12, j, True)
for j in range(len(d2) - i)
]
# we consider the possibilities for broadcasting for every dimension. Since we already
# padded d1, we do not consider it while broadcasting

View File

@ -302,11 +302,10 @@ def get_latency_of_partitioned_graph(
"""This function is to return all the partitions without parents
as the starting points of all the paths
"""
top_partitions = []
for partition in partitions:
# If a partition has no parents, then it is a top partition
if len(partition.parents) == 0:
top_partitions.append(partition)
# If a partition has no parents, then it is a top partition
top_partitions = [
partition for partition in partitions if len(partition.parents) == 0
]
return top_partitions
top_partitions = get_top_partitions(partitions)

View File

@ -5201,10 +5201,9 @@ class ShapeEnv:
symints = {
s.node.expr for s in symints if isinstance(s.node.expr, sympy.Symbol)
}
guards = []
for g in self.guards:
if all(s in symints for s in g.expr.free_symbols):
guards.append(g)
guards = [
g for g in self.guards if all(s in symints for s in g.expr.free_symbols)
]
return guards
def bind_symbols(

View File

@ -220,10 +220,7 @@ class PassManager:
def remove_pass(self, _passes: List[str]):
if _passes is None:
return
passes_left = []
for ps in self.passes:
if ps.__name__ not in _passes:
passes_left.append(ps)
passes_left = [ps for ps in self.passes if ps.__name__ not in _passes]
self.passes = passes_left
self._validated = False

View File

@ -881,9 +881,7 @@ def infer_methods_to_compile(nn_module):
uniqued_methods.append(name)
uniquer.add(name)
stubs = []
for method in uniqued_methods:
stubs.append(make_stub_from_method(nn_module, method))
stubs = [make_stub_from_method(nn_module, method) for method in uniqued_methods]
return overload_stubs + stubs
@ -959,9 +957,10 @@ def interface_script(mod_interface, nn_module):
It is used to know which methods need to act as starting points for compilation.
"""
stubs = []
for method in mod_interface.getMethodNames():
stubs.append(make_stub_from_method(nn_module, method))
stubs = [
make_stub_from_method(nn_module, method)
for method in mod_interface.getMethodNames()
]
return stubs
return create_script_module(nn_module, infer_interface_methods_to_compile)

View File

@ -1493,11 +1493,10 @@ def _get_overloads(obj):
_jit_internal.get_overload_no_implementation_error_message("function", obj)
)
compiled_fns = []
for overload_fn in uncompiled_overloads:
compiled_fns.append(
_compile_function_with_overload(overload_fn, qual_name, obj)
)
compiled_fns = [
_compile_function_with_overload(overload_fn, qual_name, obj)
for overload_fn in uncompiled_overloads
]
if existing_compiled_fns:
compiled_fns = existing_compiled_fns + compiled_fns

View File

@ -77,9 +77,7 @@ def _lazy_init() -> None:
# However, we must not let any *other* threads in!
_tls.is_initializing = True
for calls in _lazy_seed_tracker.get_calls():
if calls:
_queued_calls.append(calls)
_queued_calls.extend(calls for calls in _lazy_seed_tracker.get_calls() if calls)
try:
for queued_call, orig_traceback in _queued_calls:

View File

@ -23,8 +23,7 @@ class Broadcast(Function):
non_differentiables = []
for idx, input_requires_grad in enumerate(ctx.needs_input_grad[1:]):
if not input_requires_grad:
for output in outputs:
non_differentiables.append(output[idx])
non_differentiables.extend(output[idx] for output in outputs)
ctx.mark_non_differentiable(*non_differentiables)
return tuple([t for tensors in outputs for t in tensors])

View File

@ -142,9 +142,7 @@ def conv_backward(func, ctx, grad_output):
ctx.groups,
)
kernel_size = []
for i in range(2, conv_picker(func, 3, 4, 5)):
kernel_size.append(weight_shape[i])
kernel_size = [weight_shape[i] for i in range(2, conv_picker(func, 3, 4, 5))]
batch_size = ctx.batch_size
results: List[Optional[torch.Tensor]] = []

View File

@ -356,12 +356,14 @@ def quantized_args(
return descriptor and _is_value(arg) and _is_tuple_construct(arg)
# Run regular symbolic function if none of the argument is QTensor.
is_quantized = []
is_quantized: typing.List[bool] = []
for descriptor, arg in descriptor_args:
# ListConstruct
if _is_packed_list(arg):
for arg_input in arg.node().inputs():
is_quantized.append(_is_arg_quantized(descriptor, arg_input))
is_quantized.extend(
_is_arg_quantized(descriptor, arg_input)
for arg_input in arg.node().inputs()
)
else:
is_quantized.append(_is_arg_quantized(descriptor, arg))

View File

@ -67,10 +67,9 @@ def col2im(
stride: Sequence[int],
):
# convert [i0, i1, ..., in] into [i0, i0, i1, i1, ..., in, in]
adjusted_padding = []
adjusted_padding: List[int] = []
for pad in padding:
for _ in range(2):
adjusted_padding.append(pad)
adjusted_padding.extend(pad for _ in range(2))
num_dimensional_axis = symbolic_helper._get_tensor_sizes(output_size)[0]
if not adjusted_padding:

View File

@ -1397,8 +1397,6 @@ class GraphInfo:
original_outputs = list(graph.outputs())
original_inputs = list(graph.inputs())
new_outputs = []
def _process_bridge_value_for_lower(
graph: torch.Graph, bridge_value: torch.Value
) -> torch.Value:
@ -1416,9 +1414,9 @@ class GraphInfo:
graph, pivot, process_bridge_value_for_lower
)
for output in original_outputs:
if _produced_by(output, lower_nodes):
new_outputs.append(output)
new_outputs = [
output for output in original_outputs if _produced_by(output, lower_nodes)
]
for _ in enumerate(original_outputs):
graph.eraseOutput(0)
for output in new_outputs:

View File

@ -50,10 +50,11 @@ class DirectoryReader:
def get_all_records(
self,
):
files = []
for filename in glob(f"{self.directory}/**", recursive=True):
if not os.path.isdir(filename):
files.append(filename[len(self.directory) + 1 :])
files = [
filename[len(self.directory) + 1 :]
for filename in glob(f"{self.directory}/**", recursive=True)
if not os.path.isdir(filename)
]
return files
def serialization_id(

View File

@ -53,16 +53,16 @@ class _ExtractModuleReferences(ast.NodeVisitor):
if hasattr(node.func, "id") and node.func.id == "__import__":
try:
name = self._grab_node_str(node.args[0])
fromlist = []
fromlist: List[str] = []
level = 0
if len(node.args) > 3:
for v in node.args[3].elts:
fromlist.append(self._grab_node_str(v))
fromlist.extend(self._grab_node_str(v) for v in node.args[3].elts)
elif hasattr(node, "keywords"):
for keyword in node.keywords:
if keyword.arg == "fromlist":
for v in keyword.value.elts:
fromlist.append(self._grab_node_str(v))
fromlist.extend(
self._grab_node_str(v) for v in keyword.value.elts
)
if len(node.args) > 4:
level = self._grab_node_int(node.args[4])
elif hasattr(node, "keywords"):

View File

@ -97,10 +97,9 @@ class Pattern:
def matched_events(self):
if self.skip:
return []
matched_events = []
for event in self.eventTreeTraversal():
if self.match(event):
matched_events.append(event)
matched_events = [
event for event in self.eventTreeTraversal() if self.match(event)
]
return matched_events
def root_of(self, event: _ProfilerEvent):

Some files were not shown because too many files have changed in this diff.