[BE]: FURB142 - Remove set mutations. Use set update (#124551)
Use set mutation methods (update, difference_update, etc.) instead of reimplementing them by hand with per-element loops.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124551
Approved by: https://github.com/ezyang
commit 29cc293725
parent 5a1216bb2e
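For context, the FURB142 rule flags loops whose only effect is to call set.add (or a similar mutator) on each element. A minimal illustrative sketch of the rewrite applied throughout this commit, using made-up names rather than anything from the diff below:

    # Hypothetical example of the FURB142 refactor; "names" and "ops" are illustrative.
    names = ["relu", "gelu", "relu"]

    # Before: mutate the set one element at a time in a loop.
    ops = set()
    for name in names:
        ops.add(name)

    # After: a single bulk mutation ...
    ops = set()
    ops.update(name for name in names)

    # ... or, when the set is built from scratch, a comprehension / constructor.
    ops = {name for name in names}  # equivalently: set(names)

    assert ops == {"relu", "gelu"}

The bulk forms typically avoid repeated attribute lookups and method calls, and they make the intent (build a set from an iterable) explicit.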
@@ -167,11 +167,9 @@ def refresh_model_names():
             del all_models_family[key]

     chosen_models = set()
-    for value in docs_models_family.values():
-        chosen_models.add(value[0])
+    chosen_models.update(value[0] for value in docs_models_family.values())

-    for key, value in all_models_family.items():
-        chosen_models.add(value[0])
+    chosen_models.update(value[0] for key, value in all_models_family.items())

     filename = "timm_models_list.txt"
     if os.path.exists("benchmarks"):
@@ -345,8 +345,9 @@ def get_operator_range(chars_range):
             ops_start_chars_set.add(item.lower())
             continue
         start, end = item.split("-")
-        for c in range(ord(start), ord(end) + 1):
-            ops_start_chars_set.add(chr(c).lower())
+        ops_start_chars_set.update(
+            chr(c).lower() for c in range(ord(start), ord(end) + 1)
+        )
     return ops_start_chars_set
@@ -144,10 +144,12 @@ class TestInitialization(FSDPTest):
         # Check that the composable module does not add any wrapper class
         local_module_classes = set()
         composable_module_classes = set()
-        for submodule in local_model.modules():
-            local_module_classes.add(type(submodule))
-        for submodule in composable_module.modules():
-            composable_module_classes.add(type(submodule))
+        local_module_classes.update(
+            type(submodule) for submodule in local_model.modules()
+        )
+        composable_module_classes.update(
+            type(submodule) for submodule in composable_module.modules()
+        )
         self.assertEqual(local_module_classes, composable_module_classes)

         # Check that the composable module has the same FSDP states with the
@@ -310,14 +312,14 @@ class TestInitialization(FSDPTest):
         ]
         for data_structure_name in data_structure_names:
             all_structures = set()
-            for module in (
-                composable_module.u1,
-                composable_module.u2,
-                composable_module,
-            ):
-                all_structures.add(
-                    id(getattr(fully_shard.state(module), data_structure_name))
-                )
+            all_structures.update(
+                id(getattr(fully_shard.state(module), data_structure_name))
+                for module in (
+                    composable_module.u1,
+                    composable_module.u2,
+                    composable_module,
+                )
+            )
             self.assertEqual(len(all_structures), 1)
@@ -945,8 +945,7 @@ class TestWrapUtils(TestCase):
         ignored_params = set()
         for module_name, module in model.named_modules():
             if "lora_A" in module_name:
-                for param in module.parameters():
-                    ignored_params.add(param)
+                ignored_params.update(module.parameters())
         _validate_frozen_params(model, modules_to_wrap, ignored_params, use_orig_params)
@@ -1375,8 +1375,7 @@ def forward(self, getitem, const):

         cond_gm = backend.graphs[0]
         name_set = set()
-        for name, _ in cond_gm.named_modules():
-            name_set.add(name)
+        name_set.update(name for name, _ in cond_gm.named_modules())
         self.assertEqual(
             name_set,
             {
@@ -1735,8 +1734,7 @@ def forward(self):
         self.assertEqual(result, x + y + x)
         wrap_gm = backend.graphs[0]
         names = set()
-        for mod_name, _ in wrap_gm.named_modules():
-            names.add(mod_name)
+        names.update(mod_name for mod_name, _ in wrap_gm.named_modules())
         self.assertEqual(
             names,
             {
@@ -365,8 +365,7 @@ def get_all_tested_ops():
     result = set({})
     for op in get_covered_ops(overridable_outplace_we_care_about).values():
         opinfos = op_to_opinfo[op]
-        for opinfo in opinfos:
-            result.add(opinfo.name)
+        result.update(opinfo.name for opinfo in opinfos)
     return result
@@ -79,8 +79,7 @@ class TestDiGraph(PackageTestCase):
         g.add_node(3)

         nodes = set()
-        for n in g:
-            nodes.add(n)
+        nodes.update(g)

         self.assertEqual(nodes, {1, 2, 3})
@@ -1617,8 +1617,7 @@ except RuntimeError as e:
         dataset = SynchronizedSeedDataset(num_workers, batch_size, num_workers)
         dataloader = self._get_data_loader(dataset, batch_size=batch_size, num_workers=num_workers)
         seeds = set()
-        for batch in dataloader:
-            seeds.add(batch[0])
+        seeds.update(batch[0] for batch in dataloader)
         self.assertEqual(len(seeds), num_workers)

     def test_worker_seed_reproducibility(self):
@@ -9523,8 +9523,7 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],

         device_set = {'cpu', 'cpu:0', 'cuda', 'cuda:0', 'cuda:1', 'cuda:10', 'cuda:100'}
         device_hash_set = set()
-        for device in device_set:
-            device_hash_set.add(hash(torch.device(device)))
+        device_hash_set.update(hash(torch.device(device)) for device in device_set)
         self.assertEqual(len(device_set), len(device_hash_set))

         def get_expected_device_repr(device):
@@ -3233,17 +3233,19 @@ if torch.distributed.is_available():

 @functools.lru_cache(None)
 def get_legacy_mod_inlinelist():
-    inlinelist = set()
-    for m in LEGACY_MOD_INLINELIST:
-        inlinelist.add(_module_dir(torch) + m[len("torch.") :].replace(".", "/"))
+    inlinelist = {
+        _module_dir(torch) + m[len("torch.") :].replace(".", "/")
+        for m in LEGACY_MOD_INLINELIST
+    }
     return inlinelist


 @functools.lru_cache(None)
 def get_mod_inlinelist():
-    inlinelist = set()
-    for m in MOD_INLINELIST:
-        inlinelist.add(_module_dir(torch) + m[len("torch.") :].replace(".", "/"))
+    inlinelist = {
+        _module_dir(torch) + m[len("torch.") :].replace(".", "/")
+        for m in MOD_INLINELIST
+    }
     return inlinelist
@@ -744,8 +744,7 @@ def min_cut_rematerialization_partition(
         if node.op == "placeholder" and "tangents" in node.target:
             required_bw_nodes.add(node)
         if node in required_bw_nodes:
-            for user in node.users:
-                required_bw_nodes.add(user)
+            required_bw_nodes.update(node.users)

     primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
     fwd_seed_offset_inputs = list(
@@ -3623,8 +3623,7 @@ class CppScheduling(BaseScheduling):
                 if var_ranges is None:
                     var_ranges = v
                 assert var_ranges == v, (var_ranges, v, node.snodes)
-                for expr in exprs:
-                    indexing_exprs.add(expr)
+                indexing_exprs.update(exprs)
             return var_ranges, list(indexing_exprs)
         else:
             assert isinstance(node, SchedulerNode)
@@ -635,8 +635,7 @@ class GraphLowering(torch.fx.Interpreter):
         # - sebotnet33ts_256
         for n in self.module.graph.nodes:
             if n in output_set:
-                for child in n.users:
-                    output_set.add(child)
+                output_set.update(n.users)

         return output_set
@@ -89,8 +89,9 @@ def add_needs_realized_inputs(fn):
         return [add_needs_realized_inputs(x) for x in fn]
     needs_realized_inputs.add(fn)
     if isinstance(fn, torch._ops.OpOverloadPacket):
-        for overload in fn.overloads():
-            needs_realized_inputs.add(getattr(fn, overload))
+        needs_realized_inputs.update(
+            getattr(fn, overload) for overload in fn.overloads()
+        )


 def add_layout_constraint(fn, constraint):
@@ -2292,9 +2292,7 @@ class Scheduler:
         Populate node.last_usage recursively (also for the nodes within a FusedSchedulerNode)
         """

-        future_used_buffers = set()
-        for node_name in V.graph.get_output_names():
-            future_used_buffers.add(node_name)
+        future_used_buffers = set(V.graph.get_output_names())

         for node in reversed(self.nodes):
             node.set_last_usage(future_used_buffers, self.mutation_real_name)
@@ -223,9 +223,10 @@ class CustomOpDef:
         def backend_impl(*args, **kwargs):
             # Checks the assumption that outputs cannot alias
             # inputs or other outputs.
-            storages = set()
-            for tensor in iter_tensors(args, kwargs):
-                storages.add(id(tensor.untyped_storage()))
+            storages = {
+                id(tensor.untyped_storage())
+                for tensor in iter_tensors(args, kwargs)
+            }

             result = self._backend_fns[device_type](*args, **kwargs)
@@ -742,8 +742,7 @@ def create_add_loggers_graph(
         insert_submodule_copy = False
         if maybe_subgraph is not None:
             first_node, last_node = maybe_subgraph[0], maybe_subgraph[-1]
-            for node_to_skip in maybe_subgraph:
-                nodes_to_skip.add(node_to_skip)
+            nodes_to_skip.update(maybe_subgraph)
             qconfig = node_name_to_qconfig[first_node.name]
             if qconfig is not None:
                 insert_submodule_copy = True
@@ -873,8 +872,7 @@ def create_add_loggers_graph(
         maybe_subgraph = _get_subgraph_containing_node(n, subgraphs_dedup)
         if maybe_subgraph is not None:
             first_node, last_node = maybe_subgraph[0], maybe_subgraph[-1]
-            for node_to_skip in maybe_subgraph:
-                nodes_to_skip.add(node_to_skip)
+            nodes_to_skip.update(maybe_subgraph)
         else:
             first_node, last_node = n, n
@@ -45,9 +45,9 @@ class EmbeddingQuantizer(Quantizer):

     @classmethod
     def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
-        op_configs: Set[QuantizationConfig] = set({})
-        for spec, _ in cls.get_supported_operators():
-            op_configs.add(spec)
+        op_configs: Set[QuantizationConfig] = {
+            spec for spec, _ in cls.get_supported_operators()
+        }
         return list(op_configs)

     @classmethod
@@ -286,9 +286,9 @@ class X86InductorQuantizer(Quantizer):

     @classmethod
     def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
-        op_configs: Set[QuantizationConfig] = set({})
-        for spec, _ in cls.supported_config_and_operators:
-            op_configs.add(spec)
+        op_configs: Set[QuantizationConfig] = {
+            spec for spec, _ in cls.supported_config_and_operators
+        }
         return list(op_configs)

     @classmethod
@@ -305,9 +305,9 @@ class XNNPACKQuantizer(Quantizer):

     @classmethod
     def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
-        op_configs: Set[QuantizationConfig] = set({})
-        for spec, _ in cls.supported_config_and_operators:
-            op_configs.add(spec)
+        op_configs: Set[QuantizationConfig] = {
+            spec for spec, _ in cls.supported_config_and_operators
+        }
        return list(op_configs)

     @classmethod
@@ -1,7 +1,7 @@
 import itertools
 from dataclasses import dataclass

-from typing import List, Tuple
+from typing import List, Set, Tuple

 from torch.distributed._tensor.op_schema import OpStrategy, PlacementStrategy
 from torch.distributed._tensor.placement_types import (
@@ -44,10 +44,9 @@ class EinsumDims:
         Parse the dims and extract the contracting, batch, and free dimensions
         for the left and right hand sides.
         """
-        dim_char_set = set()
+        dim_char_set: Set[str] = set()
         for input_dim in input_dims:
-            for input_char in list(input_dim):
-                dim_char_set.add(input_char)
+            dim_char_set.update(input_dim)

         # get a determinisitc order of all dim chars
         all_dim_chars = sorted(dim_char_set)
@@ -218,7 +218,7 @@ def _verify_options(
         fqn_param_mapping[fqn] = param
         all_fqns.add(fqn)

-    submodule_prefixes = set()
+    submodule_prefixes: Set[str] = set()
     if submodules:
         submodules = set(submodules)
         for name, module in model.named_modules():
@@ -226,8 +226,7 @@ def _verify_options(
                 continue
             fqns = _get_fqns(model, name)
             assert len(fqns) == 1, "Submodule FQN should only have 1 instance"
-            for fqn in fqns:
-                submodule_prefixes.add(f"{fqn}.")
+            submodule_prefixes.update(f"{fqn}." for fqn in fqns)

     fsdp_modules = FSDP.fsdp_modules(model)
     state_dict_config: StateDictConfig
@@ -112,9 +112,7 @@ def _format_import_statement(name: str, obj: Any, importer: Importer) -> str:


 def _format_import_block(globals: Dict[str, Any], importer: Importer):
-    import_strs: Set[str] = set()
-    for name, obj in globals.items():
-        import_strs.add(_format_import_statement(name, obj, importer))
+    import_strs: Set[str] = {_format_import_statement(name, obj, importer) for name, obj in globals.items()}
     # Sort the imports so we have a stable import block that allows us to
     # hash the graph module and get a consistent key for use in a cache.
     return "\n".join(sorted(import_strs))
@@ -294,8 +294,7 @@ def _replace_pattern(
         # Copy the replacement graph over
         user_nodes: Set[Node] = set()
         for n in match.returning_nodes:
-            for user in n.users:
-                user_nodes.add(user)
+            user_nodes.update(n.users)
         assert user_nodes, "The returning_nodes should have at least one user node"

         if len(user_nodes) == 1:
@@ -930,8 +930,9 @@ class MemoryProfile:
                 self._is_gradient(*i) or i in used_for_gradient
                 for i in node.outputs.items()
             ):
-                for key, (_, version) in node.inputs.items():
-                    used_for_gradient.add((key, version))
+                used_for_gradient.update(
+                    (key, version) for key, (_, version) in node.inputs.items()
+                )
         candidate_parameters.intersection_update(used_for_gradient)

         # and depends on a gradient.
@@ -34,9 +34,7 @@ def _strip_datapipe_from_name(name: str) -> str:
 def _generate_input_args_string(obj):
     """Generate a string for the input arguments of an object."""
     signature = inspect.signature(obj.__class__)
-    input_param_names = set()
-    for param_name in signature.parameters.keys():
-        input_param_names.add(param_name)
+    input_param_names = set(signature.parameters.keys())
     result = []
     for name, value in inspect.getmembers(obj):
         if name in input_param_names:
@@ -578,10 +578,8 @@ def _compute_in_out(ops):
     out_blobs = set()

     for op in ops:
-        for input_blob in op.input:
-            in_blobs.add(input_blob)
-        for output_blob in op.output:
-            out_blobs.add(output_blob)
+        in_blobs.update(op.input)
+        out_blobs.update(op.output)

     input_blobs = list(in_blobs.difference(out_blobs))
     output_blobs = list(out_blobs.difference(in_blobs))
@@ -700,8 +698,7 @@ def _operators_to_graph_def(
             else [_operator_to_node(shapes, op)]
         )  # .extend() expects an iterable
         current_graph.node.extend(nodes_from_op)
-        for input_blob in op.input:
-            blobs.add(input_blob)
+        blobs.update(op.input)
         for i, output_blob in enumerate(op.output):
             blobs.add(output_blob)
             producing_ops.setdefault(output_blob, []).append((op, i))
@@ -2125,7 +2125,7 @@ def gen_headers(
 )

 def gen_aten_interned_strings() -> Dict[str, str]:
-    attrs = set()  # All function argument names
+    attrs: Set[str] = set()  # All function argument names
     names = set()  # All ATen function names
     for func in native_functions:
         names.add(str(func.func.name.name))
@@ -2133,8 +2133,7 @@ def gen_headers(
             # symbol without the underscore
             names.add(func.func.name.name.base)

-        for arg in func.func.schema_order_arguments():
-            attrs.add(arg.name)
+        attrs.update(arg.name for arg in func.func.schema_order_arguments())

     # These are keywords in C++, so aren't valid symbol names
     # https://en.cppreference.com/w/cpp/language/operator_alternative