[BE]: FURB142 - Remove set mutations. Use set update (#124551)

Uses the built-in set mutation methods (update, difference_update, etc.) instead of manually reimplementing them with element-by-element loops.
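As a rough illustration of the pattern this rule targets (the snippet below uses made-up data, not code from the changed files), a loop that adds elements one at a time can be collapsed into a single bulk update, or into a set comprehension when the set is newly created:

    # Before: element-by-element mutation (the pattern FURB142 flags)
    chars = set()
    for word in ["spam", "ham"]:
        for ch in word:
            chars.add(ch)

    # After: one bulk mutation via set.update()
    chars = set()
    chars.update(ch for word in ["spam", "ham"] for ch in word)

    # Or, when the set is being created, an equivalent set comprehension
    chars = {ch for word in ["spam", "ham"] for ch in word}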

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124551
Approved by: https://github.com/ezyang
Aaron Gokaslan 2024-04-21 14:12:30 +00:00 committed by PyTorch MergeBot
parent 5a1216bb2e
commit 29cc293725
28 changed files with 71 additions and 90 deletions

View File

@@ -167,11 +167,9 @@ def refresh_model_names():
             del all_models_family[key]
     chosen_models = set()
-    for value in docs_models_family.values():
-        chosen_models.add(value[0])
+    chosen_models.update(value[0] for value in docs_models_family.values())
-    for key, value in all_models_family.items():
-        chosen_models.add(value[0])
+    chosen_models.update(value[0] for key, value in all_models_family.items())
     filename = "timm_models_list.txt"
     if os.path.exists("benchmarks"):

View File

@@ -345,8 +345,9 @@ def get_operator_range(chars_range):
             ops_start_chars_set.add(item.lower())
             continue
         start, end = item.split("-")
-        for c in range(ord(start), ord(end) + 1):
-            ops_start_chars_set.add(chr(c).lower())
+        ops_start_chars_set.update(
+            chr(c).lower() for c in range(ord(start), ord(end) + 1)
+        )
     return ops_start_chars_set

View File

@@ -144,10 +144,12 @@ class TestInitialization(FSDPTest):
         # Check that the composable module does not add any wrapper class
         local_module_classes = set()
         composable_module_classes = set()
-        for submodule in local_model.modules():
-            local_module_classes.add(type(submodule))
-        for submodule in composable_module.modules():
-            composable_module_classes.add(type(submodule))
+        local_module_classes.update(
+            type(submodule) for submodule in local_model.modules()
+        )
+        composable_module_classes.update(
+            type(submodule) for submodule in composable_module.modules()
+        )
         self.assertEqual(local_module_classes, composable_module_classes)
         # Check that the composable module has the same FSDP states with the
@@ -310,14 +312,14 @@ class TestInitialization(FSDPTest):
         ]
         for data_structure_name in data_structure_names:
             all_structures = set()
-            for module in (
-                composable_module.u1,
-                composable_module.u2,
-                composable_module,
-            ):
-                all_structures.add(
-                    id(getattr(fully_shard.state(module), data_structure_name))
-                )
+            all_structures.update(
+                id(getattr(fully_shard.state(module), data_structure_name))
+                for module in (
+                    composable_module.u1,
+                    composable_module.u2,
+                    composable_module,
+                )
+            )
             self.assertEqual(len(all_structures), 1)

View File

@@ -945,8 +945,7 @@ class TestWrapUtils(TestCase):
         ignored_params = set()
         for module_name, module in model.named_modules():
             if "lora_A" in module_name:
-                for param in module.parameters():
-                    ignored_params.add(param)
+                ignored_params.update(module.parameters())
         _validate_frozen_params(model, modules_to_wrap, ignored_params, use_orig_params)

View File

@@ -1375,8 +1375,7 @@ def forward(self, getitem, const):
         cond_gm = backend.graphs[0]
         name_set = set()
-        for name, _ in cond_gm.named_modules():
-            name_set.add(name)
+        name_set.update(name for name, _ in cond_gm.named_modules())
         self.assertEqual(
             name_set,
             {
@@ -1735,8 +1734,7 @@ def forward(self):
         self.assertEqual(result, x + y + x)
         wrap_gm = backend.graphs[0]
         names = set()
-        for mod_name, _ in wrap_gm.named_modules():
-            names.add(mod_name)
+        names.update(mod_name for mod_name, _ in wrap_gm.named_modules())
         self.assertEqual(
             names,
             {

View File

@@ -365,8 +365,7 @@ def get_all_tested_ops():
     result = set({})
     for op in get_covered_ops(overridable_outplace_we_care_about).values():
         opinfos = op_to_opinfo[op]
-        for opinfo in opinfos:
-            result.add(opinfo.name)
+        result.update(opinfo.name for opinfo in opinfos)
     return result

View File

@@ -79,8 +79,7 @@ class TestDiGraph(PackageTestCase):
         g.add_node(3)
         nodes = set()
-        for n in g:
-            nodes.add(n)
+        nodes.update(g)
         self.assertEqual(nodes, {1, 2, 3})

View File

@@ -1617,8 +1617,7 @@ except RuntimeError as e:
         dataset = SynchronizedSeedDataset(num_workers, batch_size, num_workers)
         dataloader = self._get_data_loader(dataset, batch_size=batch_size, num_workers=num_workers)
         seeds = set()
-        for batch in dataloader:
-            seeds.add(batch[0])
+        seeds.update(batch[0] for batch in dataloader)
         self.assertEqual(len(seeds), num_workers)
     def test_worker_seed_reproducibility(self):

View File

@@ -9523,8 +9523,7 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
         device_set = {'cpu', 'cpu:0', 'cuda', 'cuda:0', 'cuda:1', 'cuda:10', 'cuda:100'}
         device_hash_set = set()
-        for device in device_set:
-            device_hash_set.add(hash(torch.device(device)))
+        device_hash_set.update(hash(torch.device(device)) for device in device_set)
         self.assertEqual(len(device_set), len(device_hash_set))
         def get_expected_device_repr(device):

View File

@@ -3233,17 +3233,19 @@ if torch.distributed.is_available():
 @functools.lru_cache(None)
 def get_legacy_mod_inlinelist():
-    inlinelist = set()
-    for m in LEGACY_MOD_INLINELIST:
-        inlinelist.add(_module_dir(torch) + m[len("torch.") :].replace(".", "/"))
+    inlinelist = {
+        _module_dir(torch) + m[len("torch.") :].replace(".", "/")
+        for m in LEGACY_MOD_INLINELIST
+    }
     return inlinelist
 @functools.lru_cache(None)
 def get_mod_inlinelist():
-    inlinelist = set()
-    for m in MOD_INLINELIST:
-        inlinelist.add(_module_dir(torch) + m[len("torch.") :].replace(".", "/"))
+    inlinelist = {
+        _module_dir(torch) + m[len("torch.") :].replace(".", "/")
+        for m in MOD_INLINELIST
+    }
     return inlinelist

View File

@@ -744,8 +744,7 @@ def min_cut_rematerialization_partition(
         if node.op == "placeholder" and "tangents" in node.target:
             required_bw_nodes.add(node)
         if node in required_bw_nodes:
-            for user in node.users:
-                required_bw_nodes.add(user)
+            required_bw_nodes.update(node.users)
     primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
     fwd_seed_offset_inputs = list(

View File

@@ -3623,8 +3623,7 @@ class CppScheduling(BaseScheduling):
                 if var_ranges is None:
                     var_ranges = v
                 assert var_ranges == v, (var_ranges, v, node.snodes)
-                for expr in exprs:
-                    indexing_exprs.add(expr)
+                indexing_exprs.update(exprs)
             return var_ranges, list(indexing_exprs)
         else:
             assert isinstance(node, SchedulerNode)

View File

@@ -635,8 +635,7 @@ class GraphLowering(torch.fx.Interpreter):
         # - sebotnet33ts_256
         for n in self.module.graph.nodes:
             if n in output_set:
-                for child in n.users:
-                    output_set.add(child)
+                output_set.update(n.users)
         return output_set

View File

@@ -89,8 +89,9 @@ def add_needs_realized_inputs(fn):
         return [add_needs_realized_inputs(x) for x in fn]
     needs_realized_inputs.add(fn)
     if isinstance(fn, torch._ops.OpOverloadPacket):
-        for overload in fn.overloads():
-            needs_realized_inputs.add(getattr(fn, overload))
+        needs_realized_inputs.update(
+            getattr(fn, overload) for overload in fn.overloads()
+        )
 def add_layout_constraint(fn, constraint):

View File

@@ -2292,9 +2292,7 @@ class Scheduler:
         Populate node.last_usage recursively (also for the nodes within a FusedSchedulerNode)
         """
-        future_used_buffers = set()
-        for node_name in V.graph.get_output_names():
-            future_used_buffers.add(node_name)
+        future_used_buffers = set(V.graph.get_output_names())
         for node in reversed(self.nodes):
             node.set_last_usage(future_used_buffers, self.mutation_real_name)

View File

@@ -223,9 +223,10 @@ class CustomOpDef:
         def backend_impl(*args, **kwargs):
             # Checks the assumption that outputs cannot alias
             # inputs or other outputs.
-            storages = set()
-            for tensor in iter_tensors(args, kwargs):
-                storages.add(id(tensor.untyped_storage()))
+            storages = {
+                id(tensor.untyped_storage())
+                for tensor in iter_tensors(args, kwargs)
+            }
             result = self._backend_fns[device_type](*args, **kwargs)

View File

@@ -742,8 +742,7 @@ def create_add_loggers_graph(
         insert_submodule_copy = False
         if maybe_subgraph is not None:
             first_node, last_node = maybe_subgraph[0], maybe_subgraph[-1]
-            for node_to_skip in maybe_subgraph:
-                nodes_to_skip.add(node_to_skip)
+            nodes_to_skip.update(maybe_subgraph)
             qconfig = node_name_to_qconfig[first_node.name]
             if qconfig is not None:
                 insert_submodule_copy = True
@@ -873,8 +872,7 @@ def create_add_loggers_graph(
         maybe_subgraph = _get_subgraph_containing_node(n, subgraphs_dedup)
         if maybe_subgraph is not None:
             first_node, last_node = maybe_subgraph[0], maybe_subgraph[-1]
-            for node_to_skip in maybe_subgraph:
-                nodes_to_skip.add(node_to_skip)
+            nodes_to_skip.update(maybe_subgraph)
         else:
             first_node, last_node = n, n
View File

@@ -45,9 +45,9 @@ class EmbeddingQuantizer(Quantizer):
     @classmethod
     def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
-        op_configs: Set[QuantizationConfig] = set({})
-        for spec, _ in cls.get_supported_operators():
-            op_configs.add(spec)
+        op_configs: Set[QuantizationConfig] = {
+            spec for spec, _ in cls.get_supported_operators()
+        }
         return list(op_configs)
     @classmethod

View File

@@ -286,9 +286,9 @@ class X86InductorQuantizer(Quantizer):
     @classmethod
     def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
-        op_configs: Set[QuantizationConfig] = set({})
-        for spec, _ in cls.supported_config_and_operators:
-            op_configs.add(spec)
+        op_configs: Set[QuantizationConfig] = {
+            spec for spec, _ in cls.supported_config_and_operators
+        }
         return list(op_configs)
     @classmethod

View File

@@ -305,9 +305,9 @@ class XNNPACKQuantizer(Quantizer):
     @classmethod
    def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
-        op_configs: Set[QuantizationConfig] = set({})
-        for spec, _ in cls.supported_config_and_operators:
-            op_configs.add(spec)
+        op_configs: Set[QuantizationConfig] = {
+            spec for spec, _ in cls.supported_config_and_operators
+        }
         return list(op_configs)
     @classmethod

View File

@@ -1,7 +1,7 @@
 import itertools
 from dataclasses import dataclass
-from typing import List, Tuple
+from typing import List, Set, Tuple
 from torch.distributed._tensor.op_schema import OpStrategy, PlacementStrategy
 from torch.distributed._tensor.placement_types import (
@@ -44,10 +44,9 @@ class EinsumDims:
         Parse the dims and extract the contracting, batch, and free dimensions
         for the left and right hand sides.
         """
-        dim_char_set = set()
+        dim_char_set: Set[str] = set()
         for input_dim in input_dims:
-            for input_char in list(input_dim):
-                dim_char_set.add(input_char)
+            dim_char_set.update(input_dim)
         # get a determinisitc order of all dim chars
         all_dim_chars = sorted(dim_char_set)

View File

@@ -218,7 +218,7 @@ def _verify_options(
         fqn_param_mapping[fqn] = param
         all_fqns.add(fqn)
-    submodule_prefixes = set()
+    submodule_prefixes: Set[str] = set()
     if submodules:
         submodules = set(submodules)
         for name, module in model.named_modules():
@@ -226,8 +226,7 @@ def _verify_options(
                 continue
             fqns = _get_fqns(model, name)
             assert len(fqns) == 1, "Submodule FQN should only have 1 instance"
-            for fqn in fqns:
-                submodule_prefixes.add(f"{fqn}.")
+            submodule_prefixes.update(f"{fqn}." for fqn in fqns)
     fsdp_modules = FSDP.fsdp_modules(model)
     state_dict_config: StateDictConfig

View File

@@ -112,9 +112,7 @@ def _format_import_statement(name: str, obj: Any, importer: Importer) -> str:
 def _format_import_block(globals: Dict[str, Any], importer: Importer):
-    import_strs: Set[str] = set()
-    for name, obj in globals.items():
-        import_strs.add(_format_import_statement(name, obj, importer))
+    import_strs: Set[str] = {_format_import_statement(name, obj, importer) for name, obj in globals.items()}
     # Sort the imports so we have a stable import block that allows us to
     # hash the graph module and get a consistent key for use in a cache.
     return "\n".join(sorted(import_strs))

View File

@@ -294,8 +294,7 @@ def _replace_pattern(
         # Copy the replacement graph over
         user_nodes: Set[Node] = set()
         for n in match.returning_nodes:
-            for user in n.users:
-                user_nodes.add(user)
+            user_nodes.update(n.users)
         assert user_nodes, "The returning_nodes should have at least one user node"
         if len(user_nodes) == 1:

View File

@@ -930,8 +930,9 @@ class MemoryProfile:
                     self._is_gradient(*i) or i in used_for_gradient
                     for i in node.outputs.items()
                 ):
-                    for key, (_, version) in node.inputs.items():
-                        used_for_gradient.add((key, version))
+                    used_for_gradient.update(
+                        (key, version) for key, (_, version) in node.inputs.items()
+                    )
             candidate_parameters.intersection_update(used_for_gradient)
         # and depends on a gradient.

View File

@@ -34,9 +34,7 @@ def _strip_datapipe_from_name(name: str) -> str:
 def _generate_input_args_string(obj):
     """Generate a string for the input arguments of an object."""
     signature = inspect.signature(obj.__class__)
-    input_param_names = set()
-    for param_name in signature.parameters.keys():
-        input_param_names.add(param_name)
+    input_param_names = set(signature.parameters.keys())
     result = []
     for name, value in inspect.getmembers(obj):
         if name in input_param_names:

View File

@@ -578,10 +578,8 @@ def _compute_in_out(ops):
     out_blobs = set()
     for op in ops:
-        for input_blob in op.input:
-            in_blobs.add(input_blob)
-        for output_blob in op.output:
-            out_blobs.add(output_blob)
+        in_blobs.update(op.input)
+        out_blobs.update(op.output)
     input_blobs = list(in_blobs.difference(out_blobs))
     output_blobs = list(out_blobs.difference(in_blobs))
@@ -700,8 +698,7 @@ def _operators_to_graph_def(
             else [_operator_to_node(shapes, op)]
         )  # .extend() expects an iterable
         current_graph.node.extend(nodes_from_op)
-        for input_blob in op.input:
-            blobs.add(input_blob)
+        blobs.update(op.input)
         for i, output_blob in enumerate(op.output):
             blobs.add(output_blob)
             producing_ops.setdefault(output_blob, []).append((op, i))

View File

@@ -2125,7 +2125,7 @@ def gen_headers(
     )
     def gen_aten_interned_strings() -> Dict[str, str]:
-        attrs = set()  # All function argument names
+        attrs: Set[str] = set()  # All function argument names
         names = set()  # All ATen function names
         for func in native_functions:
            names.add(str(func.func.name.name))
@@ -2133,8 +2133,7 @@ def gen_headers(
             # symbol without the underscore
             names.add(func.func.name.name.base)
-            for arg in func.func.schema_order_arguments():
-                attrs.add(arg.name)
+            attrs.update(arg.name for arg in func.func.schema_order_arguments())
         # These are keywords in C++, so aren't valid symbol names
         # https://en.cppreference.com/w/cpp/language/operator_alternative