[BE]: FURB142 - Remove set mutations. Use set update (#124551)

Uses set mutation methods such as update() and difference_update() instead of manually reimplementing them with element-by-element loops.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124551
Approved by: https://github.com/ezyang
Aaron Gokaslan authored on 2024-04-21 14:12:30 +00:00; committed by PyTorch MergeBot
parent 5a1216bb2e
commit 29cc293725
28 changed files with 71 additions and 90 deletions
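For readers unfamiliar with FURB142, the pattern is sketched below on illustrative names only (SimpleNamespace stands in for arbitrary objects; nothing here is taken from the PyTorch codebase): a loop that adds elements one at a time becomes a single bulk update() call or a set comprehension.

    from types import SimpleNamespace

    # Illustrative data only; not from the PyTorch codebase.
    modules = [SimpleNamespace(name="conv1"), SimpleNamespace(name="relu"), SimpleNamespace(name="conv1")]

    # Before: element-by-element mutation in a loop
    names = set()
    for m in modules:
        names.add(m.name)

    # After: one bulk update() call with a generator expression
    names_updated = set()
    names_updated.update(m.name for m in modules)

    # When the set is built in one shot, a set comprehension (or the
    # set() constructor) is equivalent and shorter
    names_comprehension = {m.name for m in modules}

    assert names == names_updated == names_comprehension == {"conv1", "relu"}

All three forms produce the same set; the rewrite only removes the per-element method calls and makes the intent explicit.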


@@ -167,11 +167,9 @@ def refresh_model_names():
         del all_models_family[key]
     chosen_models = set()
-    for value in docs_models_family.values():
-        chosen_models.add(value[0])
-    for key, value in all_models_family.items():
-        chosen_models.add(value[0])
+    chosen_models.update(value[0] for value in docs_models_family.values())
+    chosen_models.update(value[0] for key, value in all_models_family.items())
     filename = "timm_models_list.txt"
     if os.path.exists("benchmarks"):


@@ -345,8 +345,9 @@ def get_operator_range(chars_range):
             ops_start_chars_set.add(item.lower())
             continue
         start, end = item.split("-")
-        for c in range(ord(start), ord(end) + 1):
-            ops_start_chars_set.add(chr(c).lower())
+        ops_start_chars_set.update(
+            chr(c).lower() for c in range(ord(start), ord(end) + 1)
+        )
     return ops_start_chars_set


@@ -144,10 +144,12 @@ class TestInitialization(FSDPTest):
         # Check that the composable module does not add any wrapper class
         local_module_classes = set()
         composable_module_classes = set()
-        for submodule in local_model.modules():
-            local_module_classes.add(type(submodule))
-        for submodule in composable_module.modules():
-            composable_module_classes.add(type(submodule))
+        local_module_classes.update(
+            type(submodule) for submodule in local_model.modules()
+        )
+        composable_module_classes.update(
+            type(submodule) for submodule in composable_module.modules()
+        )
         self.assertEqual(local_module_classes, composable_module_classes)

         # Check that the composable module has the same FSDP states with the
@@ -310,14 +312,14 @@ class TestInitialization(FSDPTest):
         ]
         for data_structure_name in data_structure_names:
             all_structures = set()
-            for module in (
-                composable_module.u1,
-                composable_module.u2,
-                composable_module,
-            ):
-                all_structures.add(
-                    id(getattr(fully_shard.state(module), data_structure_name))
-                )
+            all_structures.update(
+                id(getattr(fully_shard.state(module), data_structure_name))
+                for module in (
+                    composable_module.u1,
+                    composable_module.u2,
+                    composable_module,
+                )
+            )
             self.assertEqual(len(all_structures), 1)


@@ -945,8 +945,7 @@ class TestWrapUtils(TestCase):
         ignored_params = set()
         for module_name, module in model.named_modules():
             if "lora_A" in module_name:
-                for param in module.parameters():
-                    ignored_params.add(param)
+                ignored_params.update(module.parameters())
         _validate_frozen_params(model, modules_to_wrap, ignored_params, use_orig_params)


@@ -1375,8 +1375,7 @@ def forward(self, getitem, const):
         cond_gm = backend.graphs[0]
         name_set = set()
-        for name, _ in cond_gm.named_modules():
-            name_set.add(name)
+        name_set.update(name for name, _ in cond_gm.named_modules())
         self.assertEqual(
             name_set,
             {
@@ -1735,8 +1734,7 @@ def forward(self):
         self.assertEqual(result, x + y + x)
         wrap_gm = backend.graphs[0]
         names = set()
-        for mod_name, _ in wrap_gm.named_modules():
-            names.add(mod_name)
+        names.update(mod_name for mod_name, _ in wrap_gm.named_modules())
         self.assertEqual(
             names,
             {


@@ -365,8 +365,7 @@ def get_all_tested_ops():
     result = set({})
     for op in get_covered_ops(overridable_outplace_we_care_about).values():
         opinfos = op_to_opinfo[op]
-        for opinfo in opinfos:
-            result.add(opinfo.name)
+        result.update(opinfo.name for opinfo in opinfos)
     return result


@@ -79,8 +79,7 @@ class TestDiGraph(PackageTestCase):
         g.add_node(3)
         nodes = set()
-        for n in g:
-            nodes.add(n)
+        nodes.update(g)
         self.assertEqual(nodes, {1, 2, 3})


@@ -1617,8 +1617,7 @@ except RuntimeError as e:
         dataset = SynchronizedSeedDataset(num_workers, batch_size, num_workers)
         dataloader = self._get_data_loader(dataset, batch_size=batch_size, num_workers=num_workers)
         seeds = set()
-        for batch in dataloader:
-            seeds.add(batch[0])
+        seeds.update(batch[0] for batch in dataloader)
         self.assertEqual(len(seeds), num_workers)

     def test_worker_seed_reproducibility(self):


@@ -9523,8 +9523,7 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
         device_set = {'cpu', 'cpu:0', 'cuda', 'cuda:0', 'cuda:1', 'cuda:10', 'cuda:100'}
         device_hash_set = set()
-        for device in device_set:
-            device_hash_set.add(hash(torch.device(device)))
+        device_hash_set.update(hash(torch.device(device)) for device in device_set)
         self.assertEqual(len(device_set), len(device_hash_set))

     def get_expected_device_repr(device):


@@ -3233,17 +3233,19 @@ if torch.distributed.is_available():
 @functools.lru_cache(None)
 def get_legacy_mod_inlinelist():
-    inlinelist = set()
-    for m in LEGACY_MOD_INLINELIST:
-        inlinelist.add(_module_dir(torch) + m[len("torch.") :].replace(".", "/"))
+    inlinelist = {
+        _module_dir(torch) + m[len("torch.") :].replace(".", "/")
+        for m in LEGACY_MOD_INLINELIST
+    }
     return inlinelist


 @functools.lru_cache(None)
 def get_mod_inlinelist():
-    inlinelist = set()
-    for m in MOD_INLINELIST:
-        inlinelist.add(_module_dir(torch) + m[len("torch.") :].replace(".", "/"))
+    inlinelist = {
+        _module_dir(torch) + m[len("torch.") :].replace(".", "/")
+        for m in MOD_INLINELIST
+    }
     return inlinelist


@@ -744,8 +744,7 @@ def min_cut_rematerialization_partition(
         if node.op == "placeholder" and "tangents" in node.target:
             required_bw_nodes.add(node)
         if node in required_bw_nodes:
-            for user in node.users:
-                required_bw_nodes.add(user)
+            required_bw_nodes.update(node.users)
     primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
     fwd_seed_offset_inputs = list(


@@ -3623,8 +3623,7 @@ class CppScheduling(BaseScheduling):
                 if var_ranges is None:
                     var_ranges = v
                 assert var_ranges == v, (var_ranges, v, node.snodes)
-                for expr in exprs:
-                    indexing_exprs.add(expr)
+                indexing_exprs.update(exprs)
             return var_ranges, list(indexing_exprs)
         else:
             assert isinstance(node, SchedulerNode)


@@ -635,8 +635,7 @@ class GraphLowering(torch.fx.Interpreter):
         #   - sebotnet33ts_256
         for n in self.module.graph.nodes:
             if n in output_set:
-                for child in n.users:
-                    output_set.add(child)
+                output_set.update(n.users)
         return output_set


@@ -89,8 +89,9 @@ def add_needs_realized_inputs(fn):
         return [add_needs_realized_inputs(x) for x in fn]
     needs_realized_inputs.add(fn)
     if isinstance(fn, torch._ops.OpOverloadPacket):
-        for overload in fn.overloads():
-            needs_realized_inputs.add(getattr(fn, overload))
+        needs_realized_inputs.update(
+            getattr(fn, overload) for overload in fn.overloads()
+        )


 def add_layout_constraint(fn, constraint):


@@ -2292,9 +2292,7 @@ class Scheduler:
         Populate node.last_usage recursively (also for the nodes within a FusedSchedulerNode)
         """
-        future_used_buffers = set()
-        for node_name in V.graph.get_output_names():
-            future_used_buffers.add(node_name)
+        future_used_buffers = set(V.graph.get_output_names())
         for node in reversed(self.nodes):
             node.set_last_usage(future_used_buffers, self.mutation_real_name)


@@ -223,9 +223,10 @@ class CustomOpDef:
         def backend_impl(*args, **kwargs):
             # Checks the assumption that outputs cannot alias
             # inputs or other outputs.
-            storages = set()
-            for tensor in iter_tensors(args, kwargs):
-                storages.add(id(tensor.untyped_storage()))
+            storages = {
+                id(tensor.untyped_storage())
+                for tensor in iter_tensors(args, kwargs)
+            }
             result = self._backend_fns[device_type](*args, **kwargs)


@@ -742,8 +742,7 @@ def create_add_loggers_graph(
         insert_submodule_copy = False
         if maybe_subgraph is not None:
             first_node, last_node = maybe_subgraph[0], maybe_subgraph[-1]
-            for node_to_skip in maybe_subgraph:
-                nodes_to_skip.add(node_to_skip)
+            nodes_to_skip.update(maybe_subgraph)
             qconfig = node_name_to_qconfig[first_node.name]
             if qconfig is not None:
                 insert_submodule_copy = True
@@ -873,8 +872,7 @@ def create_add_loggers_graph(
             maybe_subgraph = _get_subgraph_containing_node(n, subgraphs_dedup)
             if maybe_subgraph is not None:
                 first_node, last_node = maybe_subgraph[0], maybe_subgraph[-1]
-                for node_to_skip in maybe_subgraph:
-                    nodes_to_skip.add(node_to_skip)
+                nodes_to_skip.update(maybe_subgraph)
             else:
                 first_node, last_node = n, n


@@ -45,9 +45,9 @@ class EmbeddingQuantizer(Quantizer):
     @classmethod
     def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
-        op_configs: Set[QuantizationConfig] = set({})
-        for spec, _ in cls.get_supported_operators():
-            op_configs.add(spec)
+        op_configs: Set[QuantizationConfig] = {
+            spec for spec, _ in cls.get_supported_operators()
+        }
         return list(op_configs)

     @classmethod


@@ -286,9 +286,9 @@ class X86InductorQuantizer(Quantizer):
     @classmethod
     def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
-        op_configs: Set[QuantizationConfig] = set({})
-        for spec, _ in cls.supported_config_and_operators:
-            op_configs.add(spec)
+        op_configs: Set[QuantizationConfig] = {
+            spec for spec, _ in cls.supported_config_and_operators
+        }
         return list(op_configs)

     @classmethod


@@ -305,9 +305,9 @@ class XNNPACKQuantizer(Quantizer):
     @classmethod
     def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
-        op_configs: Set[QuantizationConfig] = set({})
-        for spec, _ in cls.supported_config_and_operators:
-            op_configs.add(spec)
+        op_configs: Set[QuantizationConfig] = {
+            spec for spec, _ in cls.supported_config_and_operators
+        }
         return list(op_configs)

     @classmethod


@@ -1,7 +1,7 @@
 import itertools
 from dataclasses import dataclass
-from typing import List, Tuple
+from typing import List, Set, Tuple

 from torch.distributed._tensor.op_schema import OpStrategy, PlacementStrategy
 from torch.distributed._tensor.placement_types import (
@@ -44,10 +44,9 @@ class EinsumDims:
         Parse the dims and extract the contracting, batch, and free dimensions
         for the left and right hand sides.
         """
-        dim_char_set = set()
+        dim_char_set: Set[str] = set()
         for input_dim in input_dims:
-            for input_char in list(input_dim):
-                dim_char_set.add(input_char)
+            dim_char_set.update(input_dim)

         # get a determinisitc order of all dim chars
         all_dim_chars = sorted(dim_char_set)


@@ -218,7 +218,7 @@ def _verify_options(
         fqn_param_mapping[fqn] = param
         all_fqns.add(fqn)

-    submodule_prefixes = set()
+    submodule_prefixes: Set[str] = set()
     if submodules:
         submodules = set(submodules)
         for name, module in model.named_modules():
@@ -226,8 +226,7 @@ def _verify_options(
                 continue
             fqns = _get_fqns(model, name)
             assert len(fqns) == 1, "Submodule FQN should only have 1 instance"
-            for fqn in fqns:
-                submodule_prefixes.add(f"{fqn}.")
+            submodule_prefixes.update(f"{fqn}." for fqn in fqns)

     fsdp_modules = FSDP.fsdp_modules(model)
     state_dict_config: StateDictConfig


@@ -112,9 +112,7 @@ def _format_import_statement(name: str, obj: Any, importer: Importer) -> str:
 def _format_import_block(globals: Dict[str, Any], importer: Importer):
-    import_strs: Set[str] = set()
-    for name, obj in globals.items():
-        import_strs.add(_format_import_statement(name, obj, importer))
+    import_strs: Set[str] = {_format_import_statement(name, obj, importer) for name, obj in globals.items()}
     # Sort the imports so we have a stable import block that allows us to
     # hash the graph module and get a consistent key for use in a cache.
     return "\n".join(sorted(import_strs))


@@ -294,8 +294,7 @@ def _replace_pattern(
         # Copy the replacement graph over
         user_nodes: Set[Node] = set()
         for n in match.returning_nodes:
-            for user in n.users:
-                user_nodes.add(user)
+            user_nodes.update(n.users)
         assert user_nodes, "The returning_nodes should have at least one user node"

         if len(user_nodes) == 1:


@@ -930,8 +930,9 @@ class MemoryProfile:
                     self._is_gradient(*i) or i in used_for_gradient
                     for i in node.outputs.items()
                 ):
-                    for key, (_, version) in node.inputs.items():
-                        used_for_gradient.add((key, version))
+                    used_for_gradient.update(
+                        (key, version) for key, (_, version) in node.inputs.items()
+                    )
             candidate_parameters.intersection_update(used_for_gradient)

         # and depends on a gradient.


@@ -34,9 +34,7 @@ def _strip_datapipe_from_name(name: str) -> str:
 def _generate_input_args_string(obj):
     """Generate a string for the input arguments of an object."""
     signature = inspect.signature(obj.__class__)
-    input_param_names = set()
-    for param_name in signature.parameters.keys():
-        input_param_names.add(param_name)
+    input_param_names = set(signature.parameters.keys())
     result = []
     for name, value in inspect.getmembers(obj):
         if name in input_param_names:


@@ -578,10 +578,8 @@ def _compute_in_out(ops):
     out_blobs = set()

     for op in ops:
-        for input_blob in op.input:
-            in_blobs.add(input_blob)
-        for output_blob in op.output:
-            out_blobs.add(output_blob)
+        in_blobs.update(op.input)
+        out_blobs.update(op.output)

     input_blobs = list(in_blobs.difference(out_blobs))
     output_blobs = list(out_blobs.difference(in_blobs))
@@ -700,8 +698,7 @@ def _operators_to_graph_def(
             else [_operator_to_node(shapes, op)]
         )  # .extend() expects an iterable
         current_graph.node.extend(nodes_from_op)
-        for input_blob in op.input:
-            blobs.add(input_blob)
+        blobs.update(op.input)
         for i, output_blob in enumerate(op.output):
             blobs.add(output_blob)
             producing_ops.setdefault(output_blob, []).append((op, i))


@@ -2125,7 +2125,7 @@ def gen_headers(
     )


 def gen_aten_interned_strings() -> Dict[str, str]:
-    attrs = set()  # All function argument names
+    attrs: Set[str] = set()  # All function argument names
     names = set()  # All ATen function names

     for func in native_functions:
         names.add(str(func.func.name.name))
@@ -2133,8 +2133,7 @@ def gen_headers(
         # symbol without the underscore
         names.add(func.func.name.name.base)
-        for arg in func.func.schema_order_arguments():
-            attrs.add(arg.name)
+        attrs.update(arg.name for arg in func.func.schema_order_arguments())

     # These are keywords in C++, so aren't valid symbol names
     # https://en.cppreference.com/w/cpp/language/operator_alternative