[BE]: Apply PERF401 autofixes from ruff (#140980)

* Automatically applies ruff rule PERF401, turning accumulation loops into equivalent list comprehensions, which are faster and do not leak the loop variable into the enclosing scope (a minimal sketch of the rewrite is shown below).
* List comprehensions not only often type-check better, but also reduce per-iteration overhead by 50% or more compared with appending inside a for loop. They also preserve length information for preallocation and are easier for the interpreter to optimize.
* Manually went back and fixed the mypy errors introduced by the change.
* Also fixed style lints in files covered by flake8 but not by pyfmt.
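
For reference, here is a minimal, hypothetical sketch of the kind of rewrite PERF401 performs; the function and variable names are illustrative and do not come from this PR:

```python
# Before: building a list by appending inside a for loop.
def positive_squares(values):
    result = []
    for v in values:
        if v >= 0:
            result.append(v * v)
    return result


# After: the equivalent list comprehension. The comprehension variable does
# not leak into the enclosing scope, and the list is built with less
# per-iteration interpreter overhead.
def positive_squares_fixed(values):
    return [v * v for v in values if v >= 0]


# When the loop appends to a pre-existing list, the fix uses extend() with a
# generator expression instead of building a new list.
def append_positive_squares(result, values):
    result.extend(v * v for v in values if v >= 0)
```

Assuming a ruff version that ships the perflint rules, such fixes can typically be generated with something like `ruff check --select PERF401 --fix .`.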

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140980
Approved by: https://github.com/justinchuby, https://github.com/malfet
Aaron Gokaslan 2024-11-20 17:52:07 +00:00 committed by PyTorch MergeBot
parent 8d708090c0
commit 12e95aa4ee
133 changed files with 611 additions and 761 deletions

View File

@@ -67,8 +67,7 @@ def pretty_print_buckets(buckets: List[Bucket], bucket_bytes_cap: int):
     for idx, bucket in enumerate(reversed(buckets)):
         if len(bucket.params) > 0:
             rows.append((idx, bucket.size, bucket.params[0]))
-            for param in bucket.params[1:]:
-                rows.append((None, None, param))
+            rows.extend((None, None, param) for param in bucket.params[1:])
         if bucket.opcount_increased_to_capture_external_output > 0:
             extended_buckets.append(
                 (

View File

@@ -1140,11 +1140,12 @@ def update_offsets(instructions) -> None:
 def debug_bytes(*args) -> str:
     index = range(max(map(len, args)))
-    result = []
-    for arg in (
-        [index] + list(args) + [[int(a != b) for a, b in zip(args[-1], args[-2])]]
-    ):
-        result.append(" ".join(f"{x:03}" for x in arg))
+    result = [
+        " ".join(f"{x:03}" for x in arg)
+        for arg in [index]
+        + list(args)
+        + [[int(a != b) for a, b in zip(args[-1], args[-2])]]
+    ]
     return "bytes mismatch\n" + "\n".join(result)

View File

@@ -478,10 +478,8 @@ class AutogradCompilerInstance:
     @staticmethod
     def get_all_nodes(args):
-        nodes = []
-        for n in args:
-            if type(n) is torch.fx.Node:  # filter out non-Node args, like None
-                nodes.append(n)
+        # filter out non-Node args, like None
+        nodes = [n for n in args if type(n) is torch.fx.Node]
         return nodes

     @staticmethod
@@ -671,13 +669,15 @@ class AutogradCompilerInstance:
         input_nodes_and_users = []
         input_nodes_and_users.extend(list(input_nodes))
         for input_node in input_nodes:
-            for user in list(input_node.users.keys()):
-                if not (
-                    user.op == "call_function"
-                    and user.target == call_hook
-                    and node.kwargs.get("hook_type", None) == "post_hook"
-                ):
-                    input_nodes_and_users.append(user)
+            input_nodes_and_users.extend(
+                user
+                for user in list(input_node.users.keys())
+                if not (
+                    user.op == "call_function"
+                    and user.target == call_hook
+                    and node.kwargs.get("hook_type", None) == "post_hook"
+                )
+            )
         arg = max(input_nodes_and_users)  # last input users
         if (

View File

@@ -2526,7 +2526,6 @@ def make_torch_function_mode_stack_guard(intial_stack):

 def recompilation_reason_for_no_tensor_aliasing_guard(guard_manager, scope):
-    duplicate_tensors = []
     global_scope = dict(guard_manager.global_scope)
     ids_to_source = collections.defaultdict(list)
     for tensor_source in guard_manager.no_tensor_aliasing_sources:  # type: ignore[attr-defined]
@@ -2534,9 +2533,9 @@ def recompilation_reason_for_no_tensor_aliasing_guard(guard_manager, scope):
         tensor_id = id(eval(tensor_source, global_scope, scope))
         ids_to_source[tensor_id].append(tensor_source)

-    for key in ids_to_source:
-        if len(ids_to_source[key]) > 1:
-            duplicate_tensors.append(f"{ids_to_source[key]}")
+    duplicate_tensors = [
+        f"{ids_to_source[key]}" for key in ids_to_source if len(ids_to_source[key]) > 1
+    ]
     reason = ", ".join(duplicate_tensors)
     return [f"Duplicate tensors found: {reason}"]

View File

@@ -1463,9 +1463,7 @@ class OutputGraph:
         return compiled_fn

     def example_inputs(self) -> List[torch.Tensor]:
-        result = []
-        for arg in self.graphargs:
-            result.append(arg.example)
+        result = [arg.example for arg in self.graphargs]
         return result

     def remove_unused_graphargs(self) -> None:

View File

@@ -252,8 +252,7 @@ def _filter_iter(l1, l2, cond):
 def _load_tuple_and_call(tup):
     insts: List[Instruction] = []
     _initial_push_null(insts)
-    for val in tup:
-        insts.append(create_instruction("LOAD_CONST", argval=val))
+    insts.extend(create_instruction("LOAD_CONST", argval=val) for val in tup)
     insts.extend(create_call_function(len(tup), False))
     return insts

View File

@@ -95,9 +95,7 @@ def collect_results(
             results.append(buffers)
     for example in example_inputs:
         if isinstance(example, (tuple, list)):
-            for inp in example:
-                if isinstance(inp, torch.Tensor):
-                    results.append(inp.grad)
+            results.extend(inp.grad for inp in example if isinstance(inp, torch.Tensor))
         else:
             if isinstance(example, torch.Tensor):
                 results.append(example.grad)

View File

@@ -1535,9 +1535,10 @@ def checkpoint_params(gm):
     rng_state = torch.clone(torch.random.get_rng_state())
     if torch.cuda.is_available():
         cuda_rng_state = torch.clone(torch.cuda.get_rng_state())
-    saved_state = []
-    for param in itertools.chain(gm.parameters(), gm.buffers()):
-        saved_state.append((param, param._version, torch.clone(param)))
+    saved_state = [
+        (param, param._version, torch.clone(param))
+        for param in itertools.chain(gm.parameters(), gm.buffers())
+    ]

     def restore():
         with torch.no_grad():

View File

@@ -1765,10 +1765,9 @@ class BuiltinVariable(VariableTracker):
                 # tracked fakes to produce incorrect guards. This is sound because the TensorVariable
                 # coming out of set_() below will be a new one, and get
                 # installed in tracked fakes.
-                to_remove = []
-                for tf in tx.output.tracked_fakes:
-                    if tf.source == obj.source:
-                        to_remove.append(tf)
+                to_remove = [
+                    tf for tf in tx.output.tracked_fakes if tf.source == obj.source
+                ]
                 for tf in to_remove:
                     tx.output.tracked_fakes.remove(tf)

View File

@@ -1026,17 +1026,16 @@ class SDPAKernelVariable(ContextWrappingVariable):
     @staticmethod
     def _backends_to_nodes(tx, backends):
-        nodes = []
-        for backend in backends:
-            # convert to/from string in order to bake the backend into FX graph
-            nodes.append(
-                tx.output.create_node(
-                    "call_function",
-                    torch.nn.attention._backend_from_string,
-                    (backend.name,),
-                    {},
-                )
-            )
+        # convert to/from string in order to bake the backend into FX graph
+        nodes = [
+            tx.output.create_node(
+                "call_function",
+                torch.nn.attention._backend_from_string,
+                (backend.name,),
+                {},
+            )
+            for backend in backends
+        ]
         return nodes

     def enter(self, tx):

View File

@@ -48,9 +48,9 @@ class ItertoolsVariable(VariableTracker):
             and all(arg.has_unpack_var_sequence(tx) for arg in args)
         ):
             seqs = [arg.unpack_var_sequence(tx) for arg in args]
-            items = []
-            for item in itertools.product(*seqs):
-                items.append(variables.TupleVariable(list(item)))
+            items = [
+                variables.TupleVariable(list(item)) for item in itertools.product(*seqs)
+            ]
             return variables.ListIteratorVariable(
                 items, mutation_type=ValueMutationNew()
             )

View File

@@ -776,14 +776,13 @@ class UserDefinedObjectVariable(UserDefinedVariable):
         ):
             assert self.source  # OrderedDict, dict subtypes must always have source
             assert not (args or kwargs)
-            items = []
             keys = self.call_method(tx, "keys", [], {})
-            for key in keys.force_unpack_var_sequence(tx):
-                items.append(
                 TupleVariable(
                     [key, self.odict_getitem(tx, key)],
                 )
-                )
+                for key in keys.force_unpack_var_sequence(tx)
+            ]
             tx.output.guard_on_key_order.add(self.source.name())
             return TupleVariable(items)

View File

@@ -833,9 +833,7 @@ class TS2FXGraphConverter:
             self._convert_prim_iterator(node)

     def _convert_prim_iterator(self, node: torch._C.Node):
-        output_list = []
-        for inp in node.inputs():
-            output_list.append(self.get_fx_value_by_ir_value(inp))
+        output_list = [self.get_fx_value_by_ir_value(inp) for inp in node.inputs()]

         output_name = node.output().debugName()
         self.name_to_node[output_name] = output_list

View File

@@ -7,9 +7,8 @@ class StaticForLoop(torch.nn.Module):
     """

     def forward(self, x):
-        ret = []
-        for i in range(10):  # constant
-            ret.append(i + x)
+        # constant
+        ret = [i + x for i in range(10)]
         return ret

     example_args = (torch.randn(3, 2),)

View File

@@ -1780,9 +1780,7 @@ class GraphModuleDeserializer(metaclass=Final):
                 ) from e

         # Outputs: convert to a single `output` node.
-        outputs = []
-        for output in serialized_graph.outputs:
-            outputs.append(self.deserialize_graph_output(output))
+        outputs = [self.deserialize_graph_output(output) for output in serialized_graph.outputs]

         if serialized_graph.is_single_tensor_return:
             assert len(outputs) == 1
@@ -2149,9 +2147,7 @@ class GraphModuleDeserializer(metaclass=Final):
             if len(value) == 0:
                 return []
             elif typ_ == "as_tensors":
-                result = []
-                for arg in value:
-                    result.append(self.serialized_name_to_node[arg.name])
+                result = [self.serialized_name_to_node[arg.name] for arg in value]
                 return result
             elif typ_ in ("as_ints", "as_floats", "as_bools", "as_strings"):
                 # convert from serialized.python.types.List to python list

View File

@@ -1306,8 +1306,9 @@ def merge_view_inputs(
                 mutated_input_info[inpt_idx].mutates_data
                 for inpt_idx in aliased_input_indices
             ):
-                for curr_idx in aliased_input_indices:
-                    other_args.append(fwd_inputs[curr_idx])
+                other_args.extend(
+                    fwd_inputs[curr_idx] for curr_idx in aliased_input_indices
+                )
                 continue

             # Here, we attempt to do a more complicated check to detect false aliasing
@@ -1320,8 +1321,9 @@ def merge_view_inputs(
                 fwd_inputs, aliased_input_indices
             )
             if len(aliased_input_indices_no_false_sharing) <= 1:
-                for curr_idx in aliased_input_indices:
-                    other_args.append(fwd_inputs[curr_idx])
+                other_args.extend(
+                    fwd_inputs[curr_idx] for curr_idx in aliased_input_indices
+                )
                 continue

             # We detected an input that was mutated, AND aliases with another input.

View File

@@ -217,9 +217,7 @@ def trace_map(proxy_mode, func_overload, f, xs, pos_args):

 @map_impl.py_impl(DispatchKey.CompositeExplicitAutograd)
 def map_dense(f, xs, pos_args):
-    pytrees = []
-    for inp in _unstack_pytree(xs):
-        pytrees.append(f(*inp, *pos_args))
+    pytrees = [f(*inp, *pos_args) for inp in _unstack_pytree(xs)]
     return _stack_pytree(pytrees)

View File

@@ -543,8 +543,7 @@ def analyze_kernel_mutations(
             )
             stack.extend(arg for arg, mutated in zip(op.args, mutations) if mutated)
         else:
-            for idx in MUTATION_OPS.get(op.name, []):
-                stack.append(op.args[idx])
+            stack.extend(op.args[idx] for idx in MUTATION_OPS.get(op.name, []))

     # The following is an iterative DFS algorithm
     mutated = [False] * num_args

View File

@@ -388,9 +388,7 @@ def _unstack_pytree(xs):

     a = zip(*flat_xs)
-    pytrees = []
-    for tuple in a:
-        pytrees.append(pytree.tree_unflatten(tuple, inspec))
+    pytrees = [pytree.tree_unflatten(tuple, inspec) for tuple in a]
     return pytrees

View File

@@ -2102,13 +2102,14 @@ class AotCodeCompiler:
             aot_constants = struct.pack("qq", consts_size + 8, magic_number)

             consts_o = _compile_consts(aot_constants, sys.platform)
-            kernels_o = []
             gpu_codecache: Union[ROCmCodeCache, CUDACodeCache] = (
                 ROCmCodeCache() if torch.version.hip else CUDACodeCache()
             )
-            for entry in gpu_codecache.cache.values():
-                if entry.output_path.endswith(".o"):
-                    kernels_o.append(entry.output_path)
+            kernels_o = [
+                entry.output_path
+                for entry in gpu_codecache.cache.values()
+                if entry.output_path.endswith(".o")
+            ]
             kernels_o = " ".join(kernels_o)

             output_name, output_dir = get_name_and_dir_from_output_file_path(output_so)

View File

@@ -4800,10 +4800,9 @@ class LoopLevel:
     def split_with_tiling(self, depth, factor):
         def clone_inner():
-            inner = []
+            inner: List[LoopLevel] = []
             if self.inner:
-                for loop in self.inner:
-                    inner.append(loop.clone())
+                inner.extend(loop.clone() for loop in self.inner)
             return inner

         def do_split_with_tiling():

View File

@@ -150,18 +150,16 @@ class DebugPrinterManager:
         # get the list of args_to_print_or_save
         # TODO: Find a more reliable way to detect kernel args types to print for extern kernel calls
         if kernel_type == "extern":
-            args_to_print_or_save_extern = []
-            for arg in args_to_print_or_save:
-                if arg.startswith(("buf", "arg")):
-                    args_to_print_or_save_extern.append(arg)
+            args_to_print_or_save_extern = [
+                arg for arg in args_to_print_or_save if arg.startswith(("buf", "arg"))
+            ]
             self.args_to_print_or_save = args_to_print_or_save_extern
         elif kernel_type == "cpp":
-            args_to_print_or_save_cpp = []
-            for arg in args_to_print_or_save:
-                if arg.startswith(("buf", "arg")):
-                    args_to_print_or_save_cpp.append(
-                        f"convert_arrayref_tensor_to_tensor({arg})"
-                    )
+            args_to_print_or_save_cpp = [
+                f"convert_arrayref_tensor_to_tensor({arg})"
+                for arg in args_to_print_or_save
+                if arg.startswith(("buf", "arg"))
+            ]
             self.args_to_print_or_save = args_to_print_or_save_cpp
         else:
             self.args_to_print_or_save = args_to_print_or_save

View File

@@ -1392,8 +1392,7 @@ class HalideKernel(SIMDKernel):
                 result.append((call_str, arg))
             if isinstance(arg, TensorArg):
                 assert arg.offset == 0 and arg.alias_of is None
-                for alias in self.buffer_aliases.get(arg.name, ()):
-                    result.append(
+                result.extend(
                     (
                         None,
                         TensorArg(
@@ -1404,6 +1403,7 @@ class HalideKernel(SIMDKernel):
                             alias_of=arg.name,
                         ),
                     )
+                    for alias in self.buffer_aliases.get(arg.name, ())
                 )
         return result

View File

@@ -72,10 +72,11 @@ def get_all_call_args(call_args_list, arg_types_list):

 def get_numel_argdefs(kernel):
-    numel_argdefs = []
-    for tree in kernel.range_trees:
-        if tree.prefix != "r" or kernel.inside_reduction:
-            numel_argdefs.append(f"{tree.prefix}numel")
+    numel_argdefs = [
+        f"{tree.prefix}numel"
+        for tree in kernel.range_trees
+        if tree.prefix != "r" or kernel.inside_reduction
+    ]

     return numel_argdefs

View File

@@ -76,7 +76,7 @@ def _default_custom_combo_kernel_horizontal_partition(
     tilings = [node_info_map[n][1] for n in nodes]

     max_dims = max(len(t) for t in tilings)
-    nodes_per_ndim = []
+    nodes_per_ndim: List[List[BaseSchedulerNode]] = []

     for i in range(2, max_dims + 1):
         group_per_dim = [n for n, t in zip(nodes, tilings) if len(t) == i]
         reduction = [
@@ -111,12 +111,11 @@ def _default_custom_combo_kernel_horizontal_partition(
                 len(large_pointwise),
             )
             not_reduction = [n for n in not_reduction if n not in large_pointwise]
-            for node in large_pointwise:
-                nodes_per_ndim.append([node])
+            nodes_per_ndim.extend([node] for node in large_pointwise)

-        for g in (not_reduction, short_reduction, long_reduction):
-            if g:
-                nodes_per_ndim.append(g)
+        nodes_per_ndim.extend(
+            g for g in (not_reduction, short_reduction, long_reduction) if g
+        )

     assert sum(len(p) for p in nodes_per_ndim) == len(nodes)
     return nodes_per_ndim

View File

@@ -288,8 +288,7 @@ def get_compiler_version_info(compiler: str) -> str:

 # =============================== cpp builder ===============================
 def _append_list(dest_list: List[str], src_list: List[str]) -> None:
-    for item in src_list:
-        dest_list.append(copy.deepcopy(item))
+    dest_list.extend(copy.deepcopy(item) for item in src_list)


 def _remove_duplication_in_list(orig_list: List[str]) -> List[str]:
@@ -742,12 +741,11 @@ def _setup_standard_sys_libs(

 def _get_build_args_of_chosen_isa(vec_isa: VecISA) -> Tuple[List[str], List[str]]:
-    macros = []
-    build_flags = []
+    macros: List[str] = []
+    build_flags: List[str] = []
     if vec_isa != invalid_vec_isa:
         # Add Windows support later.
-        for x in vec_isa.build_macro():
-            macros.append(copy.deepcopy(x))
+        macros.extend(copy.deepcopy(x) for x in vec_isa.build_macro())

         build_flags = [vec_isa.build_arch_flags()]

View File

@@ -398,9 +398,11 @@ def valid_vec_isa_list() -> List[VecISA]:
        arch value is x86_64 on Linux, and the value is AMD64 on Windows.
        """
        _cpu_supported_x86_isa = x86_isa_checker()
-        for isa in supported_vec_isa_list:
-            if all(flag in _cpu_supported_x86_isa for flag in str(isa).split()) and isa:
-                isa_list.append(isa)
+        isa_list.extend(
+            isa
+            for isa in supported_vec_isa_list
+            if all(flag in _cpu_supported_x86_isa for flag in str(isa).split()) and isa
+        )

     return isa_list

View File

@@ -621,13 +621,12 @@ class CUDAWarmupNode:
         }

         def get_non_cudagraph_inps() -> List[weakref.ReferenceType[UntypedStorage]]:
-            non_cudagraph_inps = []
-            for t in itertools.chain(new_inputs, self.wrapped_function.constants):
-                if (
-                    isinstance(t, torch.Tensor)
-                    and t.untyped_storage().data_ptr() not in existing_path_data_ptrs
-                ):
-                    non_cudagraph_inps.append(weakref.ref(t.untyped_storage()))
+            non_cudagraph_inps = [
+                weakref.ref(t.untyped_storage())
+                for t in itertools.chain(new_inputs, self.wrapped_function.constants)
+                if isinstance(t, torch.Tensor)
+                and t.untyped_storage().data_ptr() not in existing_path_data_ptrs
+            ]
             return non_cudagraph_inps

         non_cudagraph_inps_storages = get_non_cudagraph_inps()
@@ -1709,12 +1708,10 @@ def get_block_addrs(pool_id: Tuple[int, int], live_only: bool = True) -> List[in

 def format_tb(frames: List[Any]) -> str:
-    formatted_traceback = []
-    for entry in frames:
-        formatted_traceback.append(
-            traceback.FrameSummary(entry["filename"], entry["line"], entry["name"])
-        )
+    formatted_traceback = [
+        traceback.FrameSummary(entry["filename"], entry["line"], entry["name"])
+        for entry in frames
+    ]

     return "".join(traceback.format_list(formatted_traceback))

View File

@@ -127,9 +127,7 @@ class MemoryDep(Dep):
             return None

         stride_to_index = {s: i for i, s in enumerate(self_strides)}
-        order = []
-        for s in other_strides:
-            order.append(stride_to_index[s])
+        order = [stride_to_index[s] for s in other_strides]

         assert set(order) == set(range(0, self.num_vars))
         return order
@@ -547,9 +545,7 @@ def var_builder(prefix: str) -> Tuple[VarRanges, Callable[[sympy.Expr], sympy.Sy

 def index_vars_no_squeeze(*argsizes: Tuple[sympy.Expr, ...], prefix: str):
     var_ranges, add_var = var_builder(prefix)
-    args: List[List[sympy.Symbol]] = []
-    for size in argsizes:
-        args.append(list(map(add_var, size)))
+    args: List[List[sympy.Symbol]] = [list(map(add_var, size)) for size in argsizes]
     return args, var_ranges

View File

@@ -1120,9 +1120,9 @@ class UnbindCatRemover(SplitCatSimplifier):
             return
         # we need to check if the getitem indices from unbind are consecutive and all go to the same cat node
        # before we do the unbind remove, otherwise it will hit the error when we unbind part of them
-        getitem_indices = []
-        for getitem_node in unbind_node.users.keys():
-            getitem_indices.append(getitem_node.args[1])
+        getitem_indices = [
+            getitem_node.args[1] for getitem_node in unbind_node.users.keys()
+        ]
         if not is_sorted_and_consecutive(getitem_indices) or len(  # type: ignore[arg-type]
             getitem_indices
         ) != len(
@@ -1497,9 +1497,8 @@ def merge_getitem_cat(match: Match, split_sections: List[int], dim: int):
         ):
             continue
         # find the index of getitems to be cated/stacked
-        indices = []
-        for arg in cat_user.args[0]:  # type: ignore[union-attr]
-            indices.append(arg.args[1])  # type: ignore[union-attr]
+        # type: ignore[union-attr]
+        indices = [arg.args[1] for arg in cat_user.args[0]]  # type: ignore[union-attr]
         # the gettitems to be merged must be consecutive, otherwise
         # returned sliced tensor could be wrong
         if not is_sorted_and_consecutive(indices):

View File

@@ -1657,14 +1657,13 @@ class GraphLowering(torch.fx.Interpreter):
             new_unbacked_defs |= op.get_unbacked_symbol_defs()

         def format_new_defs() -> str:
-            r = []
-            for buf in self.buffers[buffer_watermark:]:
-                r.append(
-                    f"unbacked_symbol_defs={buf.get_unbacked_symbol_defs()} in:\n{buf}\n"
-                )
-            for op in self.operations[operation_watermark:]:
-                r.append(
-                    f"unbacked_symbol_defs={op.get_unbacked_symbol_defs()} in:\n{op}\n"
-                )
+            r = [
+                f"unbacked_symbol_defs={buf.get_unbacked_symbol_defs()} in:\n{buf}\n"
+                for buf in self.buffers[buffer_watermark:]
+            ]
+            r.extend(
+                f"unbacked_symbol_defs={op.get_unbacked_symbol_defs()} in:\n{op}\n"
+                for op in self.operations[operation_watermark:]
+            )
             return "***\n".join(r)

View File

@@ -5484,13 +5484,14 @@ class UserDefinedTritonKernel(ExternKernel):
         kernel = kernel_side_table.get_kernel(self.kernel_idx)
         configs = []
-        restore_value_args = []
+        restore_value_args: List[str] = []
         if isinstance(kernel, Autotuner):
             # https://github.com/triton-lang/triton/pull/5083
             # changes kernel.restore_idx to kernel.restore_value
             if hasattr(kernel, "restore_idx"):
-                for i in kernel.restore_idx:
-                    restore_value_args.append(kernel.fn.arg_names[i])
+                restore_value_args.extend(
+                    kernel.fn.arg_names[i] for i in kernel.restore_idx
+                )
             else:
                 assert hasattr(kernel, "restore_value")
                 restore_value_args.extend(kernel.restore_value)

View File

@@ -1691,9 +1691,7 @@ def split_with_sizes(x, sizes, dim=0):
 def unbind(x, dim=0):
     dim = _validate_dim(x, dim, 0)
     x_size = V.graph.sizevars.evaluate_static_shape(x.get_size()[dim])
-    result = []
-    for i in range(x_size):
-        result.append(select(x, dim, i))
+    result = [select(x, dim, i) for i in range(x_size)]
     return result

View File

@@ -87,8 +87,7 @@ def _prepare_convolution_fusion_create(
         weight_size = []
         weight_size.append(prepacked_weight_size[1] * groups)
         weight_size.append(prepacked_weight_size[0] / groups)
-        for d in range(2, dim):
-            weight_size.append(prepacked_weight_size[d])
+        weight_size.extend(prepacked_weight_size[d] for d in range(2, dim))
     else:
         weight_size = prepacked_weight.transpose(0, 1).size()
     return weight_size

View File

@@ -952,9 +952,10 @@ class PatternPrettyPrinter:
         assert hasattr(obj, "pretty_print")
         out_str = obj.pretty_print(pp=pp)

-        output = []
-        for key in pp.memoized_objs_names:
-            output.append(f"{pp.memoized_objs_names[key]} = {pp.memoized_objs_pp[key]}")
+        output = [
+            f"{pp.memoized_objs_names[key]} = {pp.memoized_objs_pp[key]}"
+            for key in pp.memoized_objs_names
+        ]

         output.append(f"{output_name} = {out_str}")
@@ -1361,9 +1362,7 @@ def register_replacement(
         return False

     def normalize_args(**kwargs: Any) -> List[Any]:
-        args = []
-        for name in argnames_static:
-            args.append(kwargs.pop(name))
+        args = [kwargs.pop(name) for name in argnames_static]
         for i in range(1, len(kwargs) + 1):
             if f"tangents_{i}" not in kwargs:
                 break

View File

@@ -120,7 +120,7 @@ def autotune_hints_to_configs(
     configs to try.
     """
     xyz_options: Tuple[Tuple[int, Optional[int], Optional[int]], ...]
-    configs = []
+    configs: List[Config] = []
     warp_size = device_props.warp_size
     # CPU target has no concept of "warp"
     if warp_size is None:
@@ -138,8 +138,7 @@ def autotune_hints_to_configs(
                 (1, block_size // 4, 1),
                 (1, 1, block_size // 4),
             )
-            for xyz in xyz_options:
-                configs.append(
+            configs.extend(
                 triton_config(
                     size_hints,
                     *xyz,
@@ -147,6 +146,7 @@ def autotune_hints_to_configs(
                         device_props.warp_size if device_props.warp_size else 32
                     ),
                 )
+                for xyz in xyz_options
             )

     return configs

View File

@@ -308,9 +308,13 @@ class BaseSchedulerNode:
             dep = deps.pop()
             used_names.add(dep)
             if V.graph.name_to_buffer.get(dep):
-                for alias in V.graph.name_to_buffer[dep].get_inputs_that_alias_output():
-                    if alias not in used_names:
-                        deps.append(alias)
+                deps.extend(
+                    alias
+                    for alias in V.graph.name_to_buffer[
+                        dep
+                    ].get_inputs_that_alias_output()
+                    if alias not in used_names
+                )
         return used_names

     def prune_deps(self) -> None:
@@ -3324,10 +3328,11 @@ class Scheduler:
             node1 = node2
             node2 = tmp

-        deps = []
-        for dep in node1.read_writes.reads | node1.read_writes.writes:
-            if dep in node2.read_writes.reads or dep in node2.read_writes.writes:
-                deps.append(dep)
+        deps = [
+            dep
+            for dep in node1.read_writes.reads | node1.read_writes.writes
+            if dep in node2.read_writes.reads or dep in node2.read_writes.writes
+        ]
         return sum(self.dep_size_hint(dep) for dep in deps)

View File

@@ -176,14 +176,12 @@ def derived_types(
     ]
     if list_base:
-        for seq_typ in derived_seq_types(base_type):
-            result.append((seq_typ, f"{cpp_type}[]"))  # type: ignore[valid-type]
+        result.extend((seq_typ, f"{cpp_type}[]") for seq_typ in derived_seq_types(base_type))  # type: ignore[valid-type]
     if optional_base_list:
-        for seq_typ in derived_seq_types(typing.Optional[base_type]):
-            result.append((seq_typ, f"{cpp_type}?[]"))  # type: ignore[valid-type]
+        result.extend((seq_typ, f"{cpp_type}?[]") for seq_typ in derived_seq_types(typing.Optional[base_type]))  # type: ignore[valid-type]
    if optional_list_base:
-        for seq_typ in derived_seq_types(base_type):  # type: ignore[valid-type]
-            result.append((typing.Optional[seq_typ], f"{cpp_type}[]?"))  # type: ignore[valid-type]
+        # type: ignore[valid-type]
+        result.extend((typing.Optional[seq_typ], f"{cpp_type}[]?") for seq_typ in derived_seq_types(base_type))  # type: ignore[valid-type]
     return result

View File

@@ -279,17 +279,17 @@ def requires_set_python_module() -> bool:
 def handle_dispatch_mode(curr_mode, op_overload, *args, **kwargs):
     assert isinstance(curr_mode, torch.utils._python_dispatch.TorchDispatchMode)
-    overload_types = []
     args_flattened, _ = torch.utils._pytree.tree_flatten((args, kwargs.values()))
-    for a in args_flattened:
-        # TODO: need to double check the semantics of the "types" argument to torch_dispatch.
-        # It's generated in PyInterpreter.cpp, but seems to be generated in two places,
-        # where in one case we only include tensors with the python key, and in another
-        # we include **all** tensors.
-        if isinstance(a, torch.Tensor) and torch._C._dispatch_keys(a).has(
-            torch._C.DispatchKey.Python
-        ):
-            overload_types.append(type(a))
+    # TODO: need to double check the semantics of the "types" argument to torch_dispatch.
+    # It's generated in PyInterpreter.cpp, but seems to be generated in two places,
+    # where in one case we only include tensors with the python key, and in another
+    # we include **all** tensors.
+    overload_types = [
+        type(a)
+        for a in args_flattened
+        if isinstance(a, torch.Tensor)
+        and torch._C._dispatch_keys(a).has(torch._C.DispatchKey.Python)
+    ]
     # TODO: check that I got these args correct (in C++, we pass in "0000"??)
     return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs)

View File

@@ -43,15 +43,14 @@ def dump_file(filename: str) -> None:

 def from_traceback(tb: Sequence[traceback.FrameSummary]) -> List[Dict[str, Any]]:
-    r = []
-    for frame in tb:
-        # dict naming convention here coincides with
-        # python/combined_traceback.cpp
-        r.append(
-            {
-                "line": frame.lineno,
-                "name": frame.name,
-                "filename": intern_string(frame.filename),
-            }
-        )
+    # dict naming convention here coincides with
+    # python/combined_traceback.cpp
+    r = [
+        {
+            "line": frame.lineno,
+            "name": frame.name,
+            "filename": intern_string(frame.filename),
+        }
+        for frame in tb
+    ]
     return r

View File

@@ -312,10 +312,11 @@ def _make_prim(
         prim_autograd_impl.impl(name, _autograd_impl)
         prim_meta_impl.impl(name, meta)
     else:
-        mutates_args = []
-        for arg in cpp_schema.arguments:
-            if arg.alias_info is not None and arg.alias_info.is_write:
-                mutates_args.append(arg.name)
+        mutates_args = [
+            arg.name
+            for arg in cpp_schema.arguments
+            if arg.alias_info is not None and arg.alias_info.is_write
+        ]
         prim_def = torch.library.custom_op(
             "prims::" + name,
             _prim_impl,

View File

@@ -3055,9 +3055,7 @@ def chunk(a: TensorLikeType, chunks: int, dim: int = 0) -> Tuple[TensorLikeType,
     full_chunks = math.floor(length / chunk_size)
     tail_chunk_size = length % chunk_size

-    result = []
-    for i in range(full_chunks):
-        result.append(narrow(a, dim, i * chunk_size, chunk_size))
+    result = [narrow(a, dim, i * chunk_size, chunk_size) for i in range(full_chunks)]

     if tail_chunk_size != 0:
         result.append(narrow(a, dim, full_chunks * chunk_size, tail_chunk_size))

View File

@@ -585,14 +585,13 @@ def has_meta(func):
     lambda func: is_builtin(func) and "foreach" in func.name() and has_meta(func)
 )
 def foreach_run_and_map_input_device(fake_mode, func, *args, **kwargs):
-    tensor_lists = []
-    for arg in itertools.chain(args, kwargs.values()):
-        if (
-            isinstance(arg, (list, tuple))
+    tensor_lists = [
+        arg
+        for arg in itertools.chain(args, kwargs.values())
+        if isinstance(arg, (list, tuple))
         and len(arg)
         and isinstance(arg[0], torch.Tensor)
-        ):
-            tensor_lists.append(arg)
+    ]

     try:
         with in_kernel_invocation_manager(fake_mode):

View File

@@ -171,8 +171,7 @@ def get_plain_tensors(
             continue

         inner_keys, _ = curr.__tensor_flatten__()
-        for key in reversed(inner_keys):
-            todo.append(getattr(curr, key))
+        todo.extend(getattr(curr, key) for key in reversed(inner_keys))

     return out
@@ -1629,13 +1628,12 @@ class FakeTensorMode(TorchDispatchMode):
         )

         if isinstance(output, tuple):
-            output_infos = []
-            for out_elem in output:
-                output_infos.append(
                 self._get_output_info_for_cache_entry(
                     state, key, func, args, kwargs, out_elem
                 )
-                )
+                for out_elem in output
+            ]
             return _DispatchCacheEntry(
                 output_infos=tuple(output_infos), is_output_tuple=True
             )
@@ -1727,9 +1725,7 @@ class FakeTensorMode(TorchDispatchMode):
         """
         if entry.is_output_tuple:
-            outputs = []
-            for output_info in entry.output_infos:
-                outputs.append(
+            outputs = [
                 self._get_output_tensor_from_cache_entry(
                     state,
                     output_info,
@@ -1737,7 +1733,8 @@ class FakeTensorMode(TorchDispatchMode):
                     func,
                     args,
                 )
-                )
+                for output_info in entry.output_infos
+            ]
             return tuple(outputs)
         else:
             return self._get_output_tensor_from_cache_entry(

View File

@@ -389,8 +389,7 @@ class LSTM(torch.nn.Module):
                 **factory_kwargs,
             )
         ]
-        for layer in range(1, num_layers):
-            layers.append(
+        layers.extend(
             _LSTMLayer(
                 self.hidden_size,
                 self.hidden_size,
@@ -399,6 +398,7 @@ class LSTM(torch.nn.Module):
                 bidirectional=self.bidirectional,
                 **factory_kwargs,
             )
+            for layer in range(1, num_layers)
         )
         self.layers = torch.nn.ModuleList(layers)

View File

@@ -32,8 +32,7 @@ def _reverse_repeat_padding(padding: List[int]) -> List[int]:
     _reversed_padding_repeated_twice: List[int] = []
     N = len(padding)
     for idx in range(N):
-        for _ in range(2):
-            _reversed_padding_repeated_twice.append(padding[N - idx - 1])
+        _reversed_padding_repeated_twice.extend(padding[N - idx - 1] for _ in range(2))
     return _reversed_padding_repeated_twice

View File

@@ -568,9 +568,9 @@ def create_one_transformed_and_logged_copy_of_subgraph(
                 and len(arg)
                 and isinstance(arg[0], Node)
             ):
-                for inner_arg in arg:
-                    if isinstance(inner_arg, Node):
-                        new_args.append(inner_arg)
+                new_args.extend(
+                    inner_arg for inner_arg in arg if isinstance(inner_arg, Node)
+                )

     new_kwargs = {}
     for name, old_kwarg in first_node.kwargs.items():

View File

@@ -312,10 +312,7 @@ def get_arg_indices_of_inputs_to_log(node: Node) -> List[int]:
         node.target in (torch.add, torch.ops.quantized.add, operator.add)
         or node.target in (torch.mul, torch.ops.quantized.mul, operator.mul)
     ):
-        result = []
-        for i in range(2):
-            if type(node.args[i]) == Node:
-                result.append(i)
+        result = [i for i in range(2) if type(node.args[i]) == Node]
         return result
     return [0]

View File

@@ -191,8 +191,7 @@ def expand_groups_in_paired_modules_list(paired_modules_list):
         elif len(group) == 2:
             new_list.append(group)
         elif len(group) > 2:
-            for i in range(len(group) - 1):
-                new_list.append([group[i], group[i + 1]])
+            new_list.extend([group[i], group[i + 1]] for i in range(len(group) - 1))

     return new_list

View File

@@ -155,13 +155,13 @@ def _get_binary_op_configs(
             (op_with_quantized_bop_scalar_variant, torch.relu),
             op_with_quantized_bop_scalar_variant,
         ]
-        for bop_pattern in bop_patterns:
-            binary_op_configs.append(
+        binary_op_configs.extend(
             BackendPatternConfig(bop_pattern)
             .set_dtype_configs(dtype_configs)  # noqa: E131
             ._set_num_tensor_args_to_observation_type(
                 num_tensor_args_to_observation_type_mapping
             )
+            for bop_pattern in bop_patterns
         )
     # matmul
     binary_op_configs.append(
@@ -502,7 +502,6 @@ def _get_ln_configs(dtype_configs: List[DTypeConfig]) -> List[BackendPatternConf
 def _get_default_op_configs(
     dtype_configs: List[DTypeConfig],
 ) -> List[BackendPatternConfig]:
-    configs = []
     default_ops = [
         torch.nn.ELU,
         torch.nn.LeakyReLU,
@@ -517,14 +516,14 @@ def _get_default_op_configs(
         torch.nn.functional.leaky_relu,
         torch.nn.functional.dropout,
     ]
-    for op in default_ops:
-        configs.append(
+    configs = [
         BackendPatternConfig(op)
         .set_observation_type(
             ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
         )  # noqa: E131
         .set_dtype_configs(dtype_configs)
-        )
+        for op in default_ops
+    ]
     configs.append(
         BackendPatternConfig(torch.nn.functional.group_norm)

View File

@@ -159,13 +159,13 @@ def get_binary_op_configs():
         # TODO: remove when functionalization is supported in pt2_mode
         (op_with_quantized_bop_scalar_variant, torch.ops.aten.relu_.default),
     ]
-    for bop_pattern in bop_patterns:
-        binary_op_configs.append(
+    binary_op_configs.extend(
         BackendPatternConfig(bop_pattern)
         .set_dtype_configs(dtype_configs)  # noqa: E131
         ._set_num_tensor_args_to_observation_type(
             num_tensor_args_to_observation_type_mapping
         )
+        for bop_pattern in bop_patterns
     )
     return binary_op_configs

View File

@@ -325,13 +325,13 @@ def _get_binary_ops_configs() -> List[BackendPatternConfig]:
             (op, torch.relu),
             op,
         ]
-        for bop_pattern in bop_patterns:
-            binary_op_configs.append(
+        binary_op_configs.extend(
             BackendPatternConfig(bop_pattern)
             .set_dtype_configs(dtype_configs)  # noqa: E131
             ._set_num_tensor_args_to_observation_type(
                 num_tensor_args_to_observation_type_mapping
            )
+            for bop_pattern in bop_patterns
         )
     return binary_op_configs
@@ -385,13 +385,12 @@ def _get_share_qparams_ops_configs() -> List[BackendPatternConfig]:
         "squeeze_",
         "leaky_relu",
     ]
-    share_qparams_op_configs: List[BackendPatternConfig] = []
-    for op in share_qparams_ops:
-        share_qparams_op_configs.append(
+    share_qparams_op_configs: List[BackendPatternConfig] = [
         BackendPatternConfig(op)
         .set_observation_type(observation_type)  # noqa: E131
         .set_dtype_configs(dtype_configs)
-        )
+        for op in share_qparams_ops
+    ]
     return share_qparams_op_configs

View File

@@ -36,18 +36,20 @@ def get_pattern_to_dtype_configs(

 def get_qat_module_classes(backend_config: BackendConfig) -> Tuple[type, ...]:
-    qat_module_classes = []
-    for config in backend_config.configs:
-        if config.qat_module is not None:
-            qat_module_classes.append(config.qat_module)
+    qat_module_classes = [
+        config.qat_module
+        for config in backend_config.configs
+        if config.qat_module is not None
+    ]
     return tuple(set(qat_module_classes))


 def get_fused_module_classes(backend_config: BackendConfig) -> Tuple[type, ...]:
-    fused_module_classes = []
-    for config in backend_config.configs:
-        if config.fused_module is not None:
-            fused_module_classes.append(config.fused_module)
+    fused_module_classes = [
+        config.fused_module
+        for config in backend_config.configs
+        if config.fused_module is not None
+    ]
     return tuple(set(fused_module_classes))

View File

@@ -92,9 +92,7 @@ def _fuse_modules_helper(
         additional_fuser_method_mapping = fuse_custom_config_dict.get(
             "additional_fuser_method_mapping", {}
         )
-        mod_list = []
-        for item in modules_to_fuse:
-            mod_list.append(_get_module(model, item))
+        mod_list = [_get_module(model, item) for item in modules_to_fuse]

         # Fuse list of modules
         new_mod_list = fuser_func(mod_list, is_qat, additional_fuser_method_mapping)

View File

@@ -266,9 +266,7 @@ def _get_valid_patterns(op_pattern):
     """
     result: List[Any]
     if isinstance(op_pattern, (tuple, list)):
-        sub_combs = []
-        for sub_pattern in op_pattern:
-            sub_combs.append(_get_valid_patterns(sub_pattern))
+        sub_combs = [_get_valid_patterns(sub_pattern) for sub_pattern in op_pattern]
         result = list(itertools.product(*sub_combs))
     else:
         result = [op_pattern, MatchAllNode]

View File

@@ -527,8 +527,9 @@ class ModelReportVisualizer:
         # gather the x_data and multiple y_data
         # calculate the number of channels
         num_channels: int = max(row[self.CHANNEL_NUM_INDEX] for row in table) + 1
-        for channel in range(num_channels):
-            y_data.append([])  # separate data list per channel
+        y_data.extend(
+            [] for channel in range(num_channels)
+        )  # separate data list per channel

         for table_row_num, row in enumerate(table):
             # get x_value to append

View File

@@ -1116,10 +1116,9 @@ def convert(
     # for dynamic quant ops or weight only quant ops
     _run_weight_observers(model, backend_config)

-    graph_inputs: List[str] = []
-    for node in model.graph.nodes:
-        if node.op == "placeholder":
-            graph_inputs.append(node.name)
+    graph_inputs: List[str] = [
+        node.name for node in model.graph.nodes if node.op == "placeholder"
+    ]

     # additional state to override inputs to be quantized, if specified
     # by the user

View File

@@ -78,8 +78,7 @@ class DefaultFuseHandler(FuseHandler):
                 n, *args = pattern
                 modules: List[torch.nn.Module] = []
                 modules.append(get_modules(n))
-                for a in args:
-                    modules.append(get_modules(a))
+                modules.extend(get_modules(a) for a in args)
                 return tuple(modules)
             else:
                 n = pattern
@@ -111,9 +110,7 @@ class DefaultFuseHandler(FuseHandler):
        # as input
        fused_module = fuser_method(is_qat, *matched_modules)
        setattr(named_modules[module_parent_name], module_name, fused_module)
-        extra_args = []
-        for input in extra_inputs:
-            extra_args.append(load_arg(input))
+        extra_args = [load_arg(input) for input in extra_inputs]
        node = fused_graph.node_copy(root_node, load_arg)
        args = list(node.args)
        args.extend(extra_args)

View File

@@ -1219,13 +1219,12 @@ def _maybe_insert_observers_before_graph_output(
             else:
                 return maybe_node
         elif isinstance(maybe_node, (list, tuple)):
-            results = []
-            for inner_node in maybe_node:
-                results.append(
-                    _recursive_maybe_replace_node_with_obs(
-                        inner_node, model, named_modules, graph
-                    )
-                )
+            results = [
+                _recursive_maybe_replace_node_with_obs(
+                    inner_node, model, named_modules, graph
+                )
+                for inner_node in maybe_node
+            ]
             if isinstance(maybe_node, list):
                 return results
             else:
@@ -1244,11 +1243,10 @@ def _maybe_insert_observers_before_graph_output(
                 "Unhandled type for returned node:", maybe_node
             )

-    new_args = []
-    for old_arg in graph_output_node.args:
-        new_args.append(
             _recursive_maybe_replace_node_with_obs(old_arg, model, named_modules, graph)
-        )
+        for old_arg in graph_output_node.args
+    ]

     graph_output_node.args = tuple(new_args)  # type: ignore[assignment]

View File

@@ -1,7 +1,7 @@
 # mypy: allow-untyped-defs
 import itertools
 import operator
-from typing import Any, Callable, List, Optional, OrderedDict, Set
+from typing import Any, Callable, List, Optional, OrderedDict, Sequence, Set

 import torch
 from torch.fx import Node
@@ -58,7 +58,7 @@ def update_equivalent_types_dict(customized_equivalent_types=None):
     _EQUIVALENT_TYPES_DICT = _create_equivalent_types_dict()


-def _partitions_sequential(partitions: List[SourcePartition]):
+def _partitions_sequential(partitions: Sequence[SourcePartition]):
     prev_partition = None
     for partition in partitions:
         if prev_partition is not None and not check_subgraphs_connected(
@@ -108,8 +108,9 @@ def find_sequential_partitions(
     typed_partitions_list = list(typed_partitions.values())
     fusion_candidates = itertools.product(*typed_partitions_list)
-    fused_partitions = []
-    for candidate in fusion_candidates:
-        if _partitions_sequential(candidate):  # type: ignore[arg-type]
-            fused_partitions.append(candidate)
+    fused_partitions = [
+        candidate
+        for candidate in fusion_candidates
+        if _partitions_sequential(candidate)
+    ]
     return fused_partitions

View File

@@ -93,9 +93,9 @@ def _get_supported_symmetric_config_and_operators() -> List[OperatorConfig]:
         get_symmetric_quantization_config(is_per_channel=True, is_qat=True),
     ]:
         ops = _supported_symmetric_quantized_operators()
-        for pattern_list in ops.values():
-            supported_config_and_operators.append(
+        supported_config_and_operators.extend(
             OperatorConfig(quantization_config, pattern_list)
+            for pattern_list in ops.values()
         )
     return copy.deepcopy(supported_config_and_operators)

View File

@@ -55,10 +55,11 @@ def _allocate_jacobians_with_inputs(
     # of `t.numel` and width of `numel_output`. Otherwise, for each tensor, returns
     # a 1-d tensor with size `(t.numel,)`. Each new tensor will be strided and have
     # the same dtype and device as those of the corresponding input.
-    out: List[torch.Tensor] = []
-    for t in input_tensors:
-        if _is_float_or_complex_tensor(t) and t.requires_grad:
-            out.append(t.new_zeros((t.numel(), numel_output), layout=torch.strided))
+    out: List[torch.Tensor] = [
+        t.new_zeros((t.numel(), numel_output), layout=torch.strided)
+        for t in input_tensors
+        if _is_float_or_complex_tensor(t) and t.requires_grad
+    ]
     return tuple(out)

@@ -69,11 +70,12 @@ def _allocate_jacobians_with_outputs(
     # in `output_tensors`, returns a new zero-filled tensor with height of `dim` and
     # width of `t.numel`. Otherwise, for each tensor, returns a 1-d tensor with size
     # (t.numel,).
-    out: List[torch.Tensor] = []
     options = {"dtype": dtype, "device": device, "layout": torch.strided}
-    for t in output_tensors:
-        if _is_float_or_complex_tensor(t):
-            out.append(t.new_zeros((numel_input, t.numel()), **options))
+    out: List[torch.Tensor] = [
+        t.new_zeros((numel_input, t.numel()), **options)
+        for t in output_tensors
+        if _is_float_or_complex_tensor(t)
+    ]
     return tuple(out)

@@ -904,10 +906,10 @@ def _compute_analytical_jacobian_rows(

 def _get_analytical_vjps_wrt_specific_output(
     vjp_fn, sample_output, v
 ) -> List[List[Optional[torch.Tensor]]]:
-    vjps: List[List[Optional[torch.Tensor]]] = []
     grad_inputs = vjp_fn(v.reshape(sample_output.shape))
-    for vjp in grad_inputs:
-        vjps.append([vjp.clone() if isinstance(vjp, torch.Tensor) else None])
+    vjps: List[List[Optional[torch.Tensor]]] = [
+        [vjp.clone() if isinstance(vjp, torch.Tensor) else None] for vjp in grad_inputs
+    ]
     return vjps

View File

@@ -802,9 +802,7 @@ def _register_logging_hooks_on_whole_graph(
         log_str = f"Executing: {node} with grad_outputs: {grad_outputs_str}"
         log.debug(log_str)

-    handles = []
-    for node in iter_graph(grad_fns):
-        handles.append(node.register_prehook(prehook))
+    handles = [node.register_prehook(prehook) for node in iter_graph(grad_fns)]

     def unregister_hooks() -> None:
         for handle in handles:

View File

@@ -856,10 +856,9 @@ def _build_table(
     flops_column_width = DEFAULT_COLUMN_WIDTH

     src_column_width = None
-    stacks = []
-    for evt in events:
-        if evt.stack is not None and len(evt.stack) > 0:
-            stacks.append(evt.stack)
+    stacks = [
+        evt.stack for evt in events if evt.stack is not None and len(evt.stack) > 0
+    ]
     has_stack = len(stacks) > 0
     if has_stack:
         src_column_width = (
@@ -947,10 +946,7 @@ def _build_table(

     if with_flops:
         # Auto-scaling of flops header
-        raw_flops = []
-        for evt in events:
-            if evt.flops > 0:
-                raw_flops.append(evt.flops)
+        raw_flops = [evt.flops for evt in events if evt.flops > 0]
         if len(raw_flops) != 0:
             (flops_scale, flops_header) = auto_scale_flops(min(raw_flops))
             headers.append(f"Total {flops_header}")

View File

@@ -322,9 +322,7 @@ def _lazy_init():
         # However, we must not let any *other* threads in!
         _tls.is_initializing = True

-        for calls in _lazy_seed_tracker.get_calls():
-            if calls:
-                _queued_calls.append(calls)
+        _queued_calls.extend(calls for calls in _lazy_seed_tracker.get_calls() if calls)

         try:
             for queued_call, orig_traceback in _queued_calls:

View File

@@ -44,9 +44,7 @@ def get_rng_state(device: Union[int, str, torch.device] = "cuda") -> Tensor:

 def get_rng_state_all() -> List[Tensor]:
     r"""Return a list of ByteTensor representing the random number states of all devices."""
-    results = []
-    for i in range(device_count()):
-        results.append(get_rng_state(i))
+    results = [get_rng_state(i) for i in range(device_count())]
     return results

View File

@@ -31,8 +31,9 @@ class ShardedOptimizer(optim.Optimizer):
         tensors: List[Tensor] = []
         for value in named_params.values():
             if isinstance(value, ShardedTensor):
-                for local_shard in value.local_shards():
-                    tensors.append(local_shard.tensor)
+                tensors.extend(
+                    local_shard.tensor for local_shard in value.local_shards()
+                )
             else:
                 tensors.append(value)

View File

@@ -99,9 +99,10 @@ def sharded_type_as(args, kwargs, pg):
     tensor = args[1]
     if isinstance(tensor, ShardedTensor):
         tensor = tensor.local_tensor()
-    new_local_shards = []
-    for shard in st.local_shards():
-        new_local_shards.append(Shard(shard.tensor.type_as(tensor), shard.metadata))
+    new_local_shards = [
+        Shard(shard.tensor.type_as(tensor), shard.metadata)
+        for shard in st.local_shards()
+    ]
     st_meta = copy.deepcopy(st._metadata)
     st_meta.tensor_properties.dtype = tensor.dtype
     return new_local_shards, st_meta

View File

@@ -191,9 +191,9 @@ def reshard_local_shard(
         )
 
     # Compute expected size
-    input_split_sizes = []
-    for metadata in shards_metadata:
-        input_split_sizes.append(metadata.shard_sizes[reshard_dim])
+    input_split_sizes = [
+        metadata.shard_sizes[reshard_dim] for metadata in shards_metadata
+    ]
 
     rearrange_input = any(ranks[i] > ranks[i + 1] for i in range(len(ranks) - 1))
     if rearrange_input:

View File

@@ -142,8 +142,7 @@ def _run_trainer(emb_rref_list, rank):
     )
 
     # model.parameters() only includes local parameters.
-    for param in model.parameters():
-        model_parameter_rrefs.append(RRef(param))
+    model_parameter_rrefs.extend(RRef(param) for param in model.parameters())
 
     # Setup distributed optimizer
     opt = DistributedOptimizer(optim.SGD, model_parameter_rrefs, lr=0.05)

View File

@@ -207,8 +207,10 @@ def _create_default_metadata_only_plan(state_dict: STATE_DICT_TYPE) -> SavePlan:
         if isinstance(obj, DTensor):
             requests.append(_create_write_items_for_dtensor(fqn, obj))
         elif isinstance(obj, ShardedTensor):
-            for shard_md in obj.metadata().shards_metadata:
-                requests.append(_create_write_item_for_shard(fqn, obj, shard_md))
+            requests.extend(
+                _create_write_item_for_shard(fqn, obj, shard_md)
+                for shard_md in obj.metadata().shards_metadata
+            )
         elif isinstance(obj, torch.Tensor):
             requests.append(_create_write_item_for_tensor(fqn, obj))
         else:

View File

@@ -1371,9 +1371,9 @@ def _get_all_pg_configs() -> List[Dict[str, Any]]:
     Return the pg configuration of all the process groups.
     """
-    config_info: List[Dict[str, Any]] = []
-    for pg in _world.pg_map.keys():
-        config_info.append(_get_pg_config(pg))
+    config_info: List[Dict[str, Any]] = [
+        _get_pg_config(pg) for pg in _world.pg_map.keys()
+    ]
     return config_info
@@ -2508,9 +2508,7 @@ def _coalescing_manager(
         # - coalesced `reduce_scatter_tensor`
         op0 = op_list[0].op
         if op0 == all_reduce:
-            tensors = []
-            for op in op_list:
-                tensors.append(op.tensor)
+            tensors = [op.tensor for op in op_list]
             all_reduce_opts = AllreduceCoalescedOptions()
             all_reduce_opts.reduceOp = not_none(op_list[0].redop)
             work = group.allreduce_coalesced(tensors, all_reduce_opts)

View File

@@ -568,10 +568,11 @@ def _root_pre_forward(
         state._needs_buffer_dtype_restore_check = False
 
     if state.forward_prefetch:
-        handles = []
-        for fsdp_state in state._all_fsdp_states:
-            if fsdp_state._handle:
-                handles.append(fsdp_state._handle)
+        handles = [
+            fsdp_state._handle
+            for fsdp_state in state._all_fsdp_states
+            if fsdp_state._handle
+        ]
         for handle in handles:
             handle._needs_pre_forward_unshard = True
             handle._prefetched = False

View File

@@ -110,9 +110,9 @@ def _create_module_with_interface(
 def _param_rrefs(module_rref, recurse) -> List[rpc.RRef[Parameter]]:
-    ret: List[rpc.RRef[Parameter]] = []
-    for param in module_rref.local_value().parameters(recurse):
-        ret.append(rpc.RRef(param))
+    ret: List[rpc.RRef[Parameter]] = [
+        rpc.RRef(param) for param in module_rref.local_value().parameters(recurse)
+    ]
     return ret

View File

@@ -146,9 +146,7 @@ class _NamedOptimizer(optim.Optimizer):
         ret_groups = []
         for group in param_groups:
-            param_keys = []
-            for param in group["params"]:
-                param_keys.append(self.ordered_param_keys[param])
+            param_keys = [self.ordered_param_keys[param] for param in group["params"]]
             ret_group = {"params": sorted(param_keys)}
             for k, v in group.items():
                 if k != "params":

View File

@@ -245,13 +245,12 @@ class DistributedOptimizer:
             else _local_optimizer_step
         )
 
-        rpc_futs = []
-        for optimizer in self.remote_optimizers:
-            rpc_futs.append(
-                rpc.rpc_async(
-                    optimizer.owner(),
-                    optimizer_step_func,
-                    args=(optimizer, context_id),
-                )
-            )
+        rpc_futs = [
+            rpc.rpc_async(
+                optimizer.owner(),
+                optimizer_step_func,
+                args=(optimizer, context_id),
+            )
+            for optimizer in self.remote_optimizers
+        ]
         _wait_for_all(rpc_futs)

View File

@@ -781,14 +781,14 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
                 self.process_group, rank
             )
             for param_group in param_groups:
-                for param in param_group["params"]:
-                    handles.append(
-                        dist.broadcast(
-                            tensor=param.data,
-                            src=global_rank,
-                            group=self.process_group,
-                            async_op=True,
-                        )
-                    )
+                handles.extend(
+                    dist.broadcast(
+                        tensor=param.data,
+                        src=global_rank,
+                        group=self.process_group,
+                        async_op=True,
+                    )
+                    for param in param_group["params"]
+                )
         return handles

View File

@@ -224,9 +224,7 @@ def _shard_dict_of_args(
     for chunk_idx in range(real_num_chunks):
         chunk_args = {}
         for key, arg in args_sharded_replicated.items():
-            arg_single_chunk = []
-            for v_flat in arg:
-                arg_single_chunk.append(v_flat[chunk_idx])
+            arg_single_chunk = [v_flat[chunk_idx] for v_flat in arg]
             chunk_args[key] = arg_single_chunk
 
         chunks_flat.append(chunk_args)
@@ -340,9 +338,10 @@ def split_args_kwargs_into_chunks(
             f"{len(args_split_dict)}, {len(kwargs_split)}"
         )
 
-    args_split = []
-    for chunk_args in args_split_dict:
-        args_split.append(tuple(chunk_args[i] for i in range(len(chunk_args))))
+    args_split = [
+        tuple(chunk_args[i] for i in range(len(chunk_args)))
+        for chunk_args in args_split_dict
+    ]
 
     return args_split, kwargs_split

View File

@@ -1702,15 +1702,13 @@ class ScheduleLoopedBFS(PipelineScheduleMulti):
         )
 
         # Store the list of operations used for that rank
-        rank_ops: List[Optional[_Action]] = []
         # Pre-padding, rank starts with no-ops based on the warmup.
-        for _ in range(rank):
-            rank_ops.append(None)
+        rank_ops: List[Optional[_Action]] = [None for _ in range(rank)]
 
         for stage_index in stage_indices:
-            for mb_index in range(self._n_microbatches):
-                rank_ops.append(
-                    _Action(stage_index, _ComputationType.FORWARD, mb_index)
-                )
+            rank_ops.extend(
+                _Action(stage_index, _ComputationType.FORWARD, mb_index)
+                for mb_index in range(self._n_microbatches)
+            )
 
         # wait for the first backward to trickle up
@@ -1719,9 +1717,9 @@ class ScheduleLoopedBFS(PipelineScheduleMulti):
         rank_ops.extend([None] * post_warmup_ops)
 
         for stage_index in reversed(stage_indices):
-            for mb_index in reversed(range(self._n_microbatches)):
-                rank_ops.append(
-                    _Action(stage_index, _ComputationType.FULL_BACKWARD, mb_index)
-                )
+            rank_ops.extend(
+                _Action(stage_index, _ComputationType.FULL_BACKWARD, mb_index)
+                for mb_index in reversed(range(self._n_microbatches))
+            )
 
         return rank_ops
@@ -1744,10 +1742,8 @@ def _get_1f1b_rank_ops(
     weight_stage_mb_index: Dict[int, int] = defaultdict(int)
 
     # Store the list of operations used for that rank
-    rank_ops: List[Optional[_Action]] = []
    # Pre-padding, rank starts with no-ops based on the warmup.
-    for _ in range(rank):
-        rank_ops.append(None)
+    rank_ops: List[Optional[_Action]] = [None for _ in range(rank)]
 
     # These are used to calculate the number of slots to fill with no-ops, to account for the delay in warmup
     # when we want to wait for the backward to trickle back up and start 1f1b to align all ranks.
     # Formula:

View File

@@ -195,8 +195,7 @@ def fill_empty_tensor_to_shards(
         size if idx != shard_dim else 0 for idx, size in enumerate(tensor_size)
     ]
     tensor = shards[0].new_zeros(tensor_size)
-    for _ in range(num_empty_tensors):
-        shards.append(tensor)
+    shards.extend(tensor for _ in range(num_empty_tensors))
     return shards

View File

@@ -167,9 +167,7 @@ def gen_einsum_strategies(
     # (i.e. for Shard, tensor dim size must > mesh size)
     all_strategies = []
     for strategy_comb in strategy_combs:
-        spec_list = []
-        for specs in zip(*strategy_comb):
-            spec_list.append(DTensorSpec(mesh, tuple(specs)))
+        spec_list = [DTensorSpec(mesh, tuple(specs)) for specs in zip(*strategy_comb)]
         strat = PlacementStrategy(output_specs=spec_list[0], input_specs=spec_list[1:])
         all_strategies.append(strat)

View File

@@ -43,18 +43,17 @@ def default_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
     # Default strategy by default just propagate the first input strategy
     select_strategy = op_schema.args_schema[0]
     assert isinstance(select_strategy, OpStrategy)
-    default_strategy = []
-    for strategy in select_strategy.strategies:
-        # we create new DTensorSpecs even for default strategy to assure that
-        # the tensor metas are distinct between the arguments and outputs
-        default_strategy.append(
-            PlacementStrategy(
-                output_specs=DTensorSpec(
-                    mesh=strategy.output_spec.mesh,
-                    placements=strategy.output_spec.placements,
-                )
-            )
-        )
+    # we create new DTensorSpecs even for default strategy to assure that
+    # the tensor metas are distinct between the arguments and outputs
+    default_strategy = [
+        PlacementStrategy(
+            output_specs=DTensorSpec(
+                mesh=strategy.output_spec.mesh,
+                placements=strategy.output_spec.placements,
+            )
+        )
+        for strategy in select_strategy.strategies
+    ]
     return OpStrategy(default_strategy)

View File

@@ -209,9 +209,10 @@ def map_placements_after_broadcast(
 def generate_redistribute_costs(
     src_strategy: OpStrategy, dst_spec: DTensorSpec
 ) -> List[float]:
-    redistribute_costs: List[float] = []
-    for strat in src_strategy.strategies:
-        redistribute_costs.append(redistribute_cost(strat.output_spec, dst_spec))
+    redistribute_costs: List[float] = [
+        redistribute_cost(strat.output_spec, dst_spec)
+        for strat in src_strategy.strategies
+    ]
     return redistribute_costs

View File

@@ -152,9 +152,11 @@ class InterpreterModule(torch.nn.Module):
             # the keys in the kwarg dict.
             arg_list = list(args)
             kwarg_names = self.arg_names[len(arg_list) :]
-            for kwarg_name in kwarg_names:
-                if kwarg_name in kwargs:
-                    arg_list.append(kwargs[kwarg_name])
+            arg_list.extend(
+                kwargs[kwarg_name]
+                for kwarg_name in kwarg_names
+                if kwarg_name in kwargs
+            )
 
             # Assert that the kwargs passed in exactly match the positional
             # arguments specified by the GraphModule. This should be
@@ -933,9 +935,10 @@ class _ModuleFrame:
             assert kwargs_spec.context is not None
 
            with self.graph.inserting_after(None):
-                arg_nodes = []
-                for idx in range(args_spec.num_children):
-                    arg_nodes.append(self.graph.placeholder(f"_positional_arg_{idx}"))
+                arg_nodes = [
+                    self.graph.placeholder(f"_positional_arg_{idx}")
+                    for idx in range(args_spec.num_children)
+                ]
                 kwarg_nodes = {}
                 for name in kwargs_spec.context:
                     kwarg_nodes[name] = self.graph.placeholder(name)

View File

@@ -1015,10 +1015,11 @@ class Partitioner:
         # Keep tracking the partition pair of node pair
         partition_pair: List[Partition] = []
         # Collect all the op nodes from the graph
-        op_nodes = []
-        for n in self.graph_module.graph.nodes:
-            if n.op not in {"placeholder", "get_attr", "output"}:
-                op_nodes.append(n)
+        op_nodes = [
+            n
+            for n in self.graph_module.graph.nodes
+            if n.op not in {"placeholder", "get_attr", "output"}
+        ]
         for node in op_nodes:
             # Find which partition the current node belongs
             p0_index = self.node_to_partition[node]

View File

@@ -648,11 +648,7 @@ def generate_reshape(constraint, counter):
     # then there must be exactly one occurrence of dyn
     else:
-        new_target = []
-
-        for n in target:
-            if n != Dyn:
-                new_target.append(n)
+        new_target = [n for n in target if n != Dyn]
 
         # tensor 1
         c3_tensor1 = Disj(
@@ -886,10 +882,8 @@ def generate_all_int_dyn_dim_possibilities(my_list: List[DVar]):
     neq_possibilities = [
        BinConstraintD(my_list[i], Dyn, op_neq) for i in range(len(my_list))
     ]
-    d_possibilities = []
 
-    for i in zip(eq_possibilities, neq_possibilities):
-        d_possibilities.append(list(i))
+    d_possibilities = [list(i) for i in zip(eq_possibilities, neq_possibilities)]
     all_possibilities = list(itertools.product(*d_possibilities))
 
     return all_possibilities
@@ -1043,13 +1037,11 @@ def apply_padding(
         assert len(simulate_padding + d1) == len(d2)
 
-        broadcast_padding = []
-
         # for every padding size, we also consider broadcasting
-        for j in range(len(d2) - i):
-            broadcast_padding.append(
-                broadcast_dim(simulate_padding, d2, d11, d12, j, True)
-            )
+        broadcast_padding = [
+            broadcast_dim(simulate_padding, d2, d11, d12, j, True)
+            for j in range(len(d2) - i)
+        ]
 
         # we consider the possibilities for broadcasting for every dimension. Since we already
         # padded d1, we do not consider it while broadcasting

View File

@@ -302,11 +302,10 @@ def get_latency_of_partitioned_graph(
         """This function is to return all the partitions without parents
         as the starting points of all the paths
         """
-        top_partitions = []
-        for partition in partitions:
-            # If a partition has no parents, then it is a top partition
-            if len(partition.parents) == 0:
-                top_partitions.append(partition)
+        # If a partition has no parents, then it is a top partition
+        top_partitions = [
+            partition for partition in partitions if len(partition.parents) == 0
+        ]
         return top_partitions
 
     top_partitions = get_top_partitions(partitions)

View File

@@ -5201,10 +5201,9 @@ class ShapeEnv:
         symints = {
             s.node.expr for s in symints if isinstance(s.node.expr, sympy.Symbol)
         }
-        guards = []
-        for g in self.guards:
-            if all(s in symints for s in g.expr.free_symbols):
-                guards.append(g)
+        guards = [
+            g for g in self.guards if all(s in symints for s in g.expr.free_symbols)
+        ]
         return guards
 
     def bind_symbols(

View File

@@ -220,10 +220,7 @@ class PassManager:
     def remove_pass(self, _passes: List[str]):
         if _passes is None:
             return
-        passes_left = []
-        for ps in self.passes:
-            if ps.__name__ not in _passes:
-                passes_left.append(ps)
+        passes_left = [ps for ps in self.passes if ps.__name__ not in _passes]
         self.passes = passes_left
         self._validated = False

View File

@@ -881,9 +881,7 @@ def infer_methods_to_compile(nn_module):
             uniqued_methods.append(name)
             uniquer.add(name)
 
-    stubs = []
-    for method in uniqued_methods:
-        stubs.append(make_stub_from_method(nn_module, method))
+    stubs = [make_stub_from_method(nn_module, method) for method in uniqued_methods]
     return overload_stubs + stubs
@@ -959,9 +957,10 @@ def interface_script(mod_interface, nn_module):
         It is used to know which methods need to act as starting points for compilation.
         """
-        stubs = []
-        for method in mod_interface.getMethodNames():
-            stubs.append(make_stub_from_method(nn_module, method))
+        stubs = [
+            make_stub_from_method(nn_module, method)
+            for method in mod_interface.getMethodNames()
+        ]
         return stubs
 
     return create_script_module(nn_module, infer_interface_methods_to_compile)

View File

@@ -1493,11 +1493,10 @@ def _get_overloads(obj):
             _jit_internal.get_overload_no_implementation_error_message("function", obj)
         )
 
-    compiled_fns = []
-    for overload_fn in uncompiled_overloads:
-        compiled_fns.append(
-            _compile_function_with_overload(overload_fn, qual_name, obj)
-        )
+    compiled_fns = [
+        _compile_function_with_overload(overload_fn, qual_name, obj)
+        for overload_fn in uncompiled_overloads
+    ]
 
     if existing_compiled_fns:
         compiled_fns = existing_compiled_fns + compiled_fns

View File

@@ -77,9 +77,7 @@ def _lazy_init() -> None:
         # However, we must not let any *other* threads in!
        _tls.is_initializing = True
 
-        for calls in _lazy_seed_tracker.get_calls():
-            if calls:
-                _queued_calls.append(calls)
+        _queued_calls.extend(calls for calls in _lazy_seed_tracker.get_calls() if calls)
 
         try:
             for queued_call, orig_traceback in _queued_calls:

View File

@@ -23,8 +23,7 @@ class Broadcast(Function):
         non_differentiables = []
         for idx, input_requires_grad in enumerate(ctx.needs_input_grad[1:]):
             if not input_requires_grad:
-                for output in outputs:
-                    non_differentiables.append(output[idx])
+                non_differentiables.extend(output[idx] for output in outputs)
         ctx.mark_non_differentiable(*non_differentiables)
         return tuple([t for tensors in outputs for t in tensors])

View File

@@ -142,9 +142,7 @@ def conv_backward(func, ctx, grad_output):
         ctx.groups,
     )
 
-    kernel_size = []
-    for i in range(2, conv_picker(func, 3, 4, 5)):
-        kernel_size.append(weight_shape[i])
+    kernel_size = [weight_shape[i] for i in range(2, conv_picker(func, 3, 4, 5))]
 
     batch_size = ctx.batch_size
     results: List[Optional[torch.Tensor]] = []

View File

@@ -356,12 +356,14 @@ def quantized_args(
             return descriptor and _is_value(arg) and _is_tuple_construct(arg)
 
         # Run regular symbolic function if none of the argument is QTensor.
-        is_quantized = []
+        is_quantized: typing.List[bool] = []
         for descriptor, arg in descriptor_args:
             # ListConstruct
             if _is_packed_list(arg):
-                for arg_input in arg.node().inputs():
-                    is_quantized.append(_is_arg_quantized(descriptor, arg_input))
+                is_quantized.extend(
+                    _is_arg_quantized(descriptor, arg_input)
+                    for arg_input in arg.node().inputs()
+                )
             else:
                 is_quantized.append(_is_arg_quantized(descriptor, arg))

View File

@@ -67,10 +67,9 @@ def col2im(
     stride: Sequence[int],
 ):
     # convert [i0, i1, ..., in] into [i0, i0, i1, i1, ..., in, in]
-    adjusted_padding = []
+    adjusted_padding: List[int] = []
     for pad in padding:
-        for _ in range(2):
-            adjusted_padding.append(pad)
+        adjusted_padding.extend(pad for _ in range(2))
 
     num_dimensional_axis = symbolic_helper._get_tensor_sizes(output_size)[0]
     if not adjusted_padding:

View File

@@ -1397,8 +1397,6 @@ class GraphInfo:
         original_outputs = list(graph.outputs())
         original_inputs = list(graph.inputs())
 
-        new_outputs = []
-
         def _process_bridge_value_for_lower(
             graph: torch.Graph, bridge_value: torch.Value
         ) -> torch.Value:
@@ -1416,9 +1414,9 @@ class GraphInfo:
             graph, pivot, process_bridge_value_for_lower
         )
 
-        for output in original_outputs:
-            if _produced_by(output, lower_nodes):
-                new_outputs.append(output)
+        new_outputs = [
+            output for output in original_outputs if _produced_by(output, lower_nodes)
+        ]
         for _ in enumerate(original_outputs):
             graph.eraseOutput(0)
         for output in new_outputs:

View File

@@ -50,10 +50,11 @@ class DirectoryReader:
     def get_all_records(
         self,
     ):
-        files = []
-        for filename in glob(f"{self.directory}/**", recursive=True):
-            if not os.path.isdir(filename):
-                files.append(filename[len(self.directory) + 1 :])
+        files = [
+            filename[len(self.directory) + 1 :]
+            for filename in glob(f"{self.directory}/**", recursive=True)
+            if not os.path.isdir(filename)
+        ]
         return files
 
     def serialization_id(

View File

@@ -53,16 +53,16 @@ class _ExtractModuleReferences(ast.NodeVisitor):
         if hasattr(node.func, "id") and node.func.id == "__import__":
             try:
                 name = self._grab_node_str(node.args[0])
-                fromlist = []
+                fromlist: List[str] = []
                 level = 0
                 if len(node.args) > 3:
-                    for v in node.args[3].elts:
-                        fromlist.append(self._grab_node_str(v))
+                    fromlist.extend(self._grab_node_str(v) for v in node.args[3].elts)
                 elif hasattr(node, "keywords"):
                     for keyword in node.keywords:
                         if keyword.arg == "fromlist":
-                            for v in keyword.value.elts:
-                                fromlist.append(self._grab_node_str(v))
+                            fromlist.extend(
+                                self._grab_node_str(v) for v in keyword.value.elts
+                            )
                 if len(node.args) > 4:
                     level = self._grab_node_int(node.args[4])
                 elif hasattr(node, "keywords"):

View File

@@ -97,10 +97,9 @@ class Pattern:
     def matched_events(self):
         if self.skip:
            return []
-        matched_events = []
-        for event in self.eventTreeTraversal():
-            if self.match(event):
-                matched_events.append(event)
+        matched_events = [
+            event for event in self.eventTreeTraversal() if self.match(event)
+        ]
         return matched_events
 
     def root_of(self, event: _ProfilerEvent):

Some files were not shown because too many files have changed in this diff.