mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[BE]: FURB142 - Remove set mutations. Use set update (#124551)
Uses set mutation methods instead of manually reimplementing (update, set_difference etc). Pull Request resolved: https://github.com/pytorch/pytorch/pull/124551 Approved by: https://github.com/ezyang
This commit is contained in:
parent
5a1216bb2e
commit
29cc293725
|
|
@ -167,11 +167,9 @@ def refresh_model_names():
|
||||||
del all_models_family[key]
|
del all_models_family[key]
|
||||||
|
|
||||||
chosen_models = set()
|
chosen_models = set()
|
||||||
for value in docs_models_family.values():
|
chosen_models.update(value[0] for value in docs_models_family.values())
|
||||||
chosen_models.add(value[0])
|
|
||||||
|
|
||||||
for key, value in all_models_family.items():
|
chosen_models.update(value[0] for key, value in all_models_family.items())
|
||||||
chosen_models.add(value[0])
|
|
||||||
|
|
||||||
filename = "timm_models_list.txt"
|
filename = "timm_models_list.txt"
|
||||||
if os.path.exists("benchmarks"):
|
if os.path.exists("benchmarks"):
|
||||||
|
|
|
||||||
|
|
@ -345,8 +345,9 @@ def get_operator_range(chars_range):
|
||||||
ops_start_chars_set.add(item.lower())
|
ops_start_chars_set.add(item.lower())
|
||||||
continue
|
continue
|
||||||
start, end = item.split("-")
|
start, end = item.split("-")
|
||||||
for c in range(ord(start), ord(end) + 1):
|
ops_start_chars_set.update(
|
||||||
ops_start_chars_set.add(chr(c).lower())
|
chr(c).lower() for c in range(ord(start), ord(end) + 1)
|
||||||
|
)
|
||||||
return ops_start_chars_set
|
return ops_start_chars_set
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -144,10 +144,12 @@ class TestInitialization(FSDPTest):
|
||||||
# Check that the composable module does not add any wrapper class
|
# Check that the composable module does not add any wrapper class
|
||||||
local_module_classes = set()
|
local_module_classes = set()
|
||||||
composable_module_classes = set()
|
composable_module_classes = set()
|
||||||
for submodule in local_model.modules():
|
local_module_classes.update(
|
||||||
local_module_classes.add(type(submodule))
|
type(submodule) for submodule in local_model.modules()
|
||||||
for submodule in composable_module.modules():
|
)
|
||||||
composable_module_classes.add(type(submodule))
|
composable_module_classes.update(
|
||||||
|
type(submodule) for submodule in composable_module.modules()
|
||||||
|
)
|
||||||
self.assertEqual(local_module_classes, composable_module_classes)
|
self.assertEqual(local_module_classes, composable_module_classes)
|
||||||
|
|
||||||
# Check that the composable module has the same FSDP states with the
|
# Check that the composable module has the same FSDP states with the
|
||||||
|
|
@ -310,14 +312,14 @@ class TestInitialization(FSDPTest):
|
||||||
]
|
]
|
||||||
for data_structure_name in data_structure_names:
|
for data_structure_name in data_structure_names:
|
||||||
all_structures = set()
|
all_structures = set()
|
||||||
for module in (
|
all_structures.update(
|
||||||
composable_module.u1,
|
id(getattr(fully_shard.state(module), data_structure_name))
|
||||||
composable_module.u2,
|
for module in (
|
||||||
composable_module,
|
composable_module.u1,
|
||||||
):
|
composable_module.u2,
|
||||||
all_structures.add(
|
composable_module,
|
||||||
id(getattr(fully_shard.state(module), data_structure_name))
|
|
||||||
)
|
)
|
||||||
|
)
|
||||||
self.assertEqual(len(all_structures), 1)
|
self.assertEqual(len(all_structures), 1)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -945,8 +945,7 @@ class TestWrapUtils(TestCase):
|
||||||
ignored_params = set()
|
ignored_params = set()
|
||||||
for module_name, module in model.named_modules():
|
for module_name, module in model.named_modules():
|
||||||
if "lora_A" in module_name:
|
if "lora_A" in module_name:
|
||||||
for param in module.parameters():
|
ignored_params.update(module.parameters())
|
||||||
ignored_params.add(param)
|
|
||||||
_validate_frozen_params(model, modules_to_wrap, ignored_params, use_orig_params)
|
_validate_frozen_params(model, modules_to_wrap, ignored_params, use_orig_params)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1375,8 +1375,7 @@ def forward(self, getitem, const):
|
||||||
|
|
||||||
cond_gm = backend.graphs[0]
|
cond_gm = backend.graphs[0]
|
||||||
name_set = set()
|
name_set = set()
|
||||||
for name, _ in cond_gm.named_modules():
|
name_set.update(name for name, _ in cond_gm.named_modules())
|
||||||
name_set.add(name)
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
name_set,
|
name_set,
|
||||||
{
|
{
|
||||||
|
|
@ -1735,8 +1734,7 @@ def forward(self):
|
||||||
self.assertEqual(result, x + y + x)
|
self.assertEqual(result, x + y + x)
|
||||||
wrap_gm = backend.graphs[0]
|
wrap_gm = backend.graphs[0]
|
||||||
names = set()
|
names = set()
|
||||||
for mod_name, _ in wrap_gm.named_modules():
|
names.update(mod_name for mod_name, _ in wrap_gm.named_modules())
|
||||||
names.add(mod_name)
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
names,
|
names,
|
||||||
{
|
{
|
||||||
|
|
|
||||||
|
|
@ -365,8 +365,7 @@ def get_all_tested_ops():
|
||||||
result = set({})
|
result = set({})
|
||||||
for op in get_covered_ops(overridable_outplace_we_care_about).values():
|
for op in get_covered_ops(overridable_outplace_we_care_about).values():
|
||||||
opinfos = op_to_opinfo[op]
|
opinfos = op_to_opinfo[op]
|
||||||
for opinfo in opinfos:
|
result.update(opinfo.name for opinfo in opinfos)
|
||||||
result.add(opinfo.name)
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -79,8 +79,7 @@ class TestDiGraph(PackageTestCase):
|
||||||
g.add_node(3)
|
g.add_node(3)
|
||||||
|
|
||||||
nodes = set()
|
nodes = set()
|
||||||
for n in g:
|
nodes.update(g)
|
||||||
nodes.add(n)
|
|
||||||
|
|
||||||
self.assertEqual(nodes, {1, 2, 3})
|
self.assertEqual(nodes, {1, 2, 3})
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1617,8 +1617,7 @@ except RuntimeError as e:
|
||||||
dataset = SynchronizedSeedDataset(num_workers, batch_size, num_workers)
|
dataset = SynchronizedSeedDataset(num_workers, batch_size, num_workers)
|
||||||
dataloader = self._get_data_loader(dataset, batch_size=batch_size, num_workers=num_workers)
|
dataloader = self._get_data_loader(dataset, batch_size=batch_size, num_workers=num_workers)
|
||||||
seeds = set()
|
seeds = set()
|
||||||
for batch in dataloader:
|
seeds.update(batch[0] for batch in dataloader)
|
||||||
seeds.add(batch[0])
|
|
||||||
self.assertEqual(len(seeds), num_workers)
|
self.assertEqual(len(seeds), num_workers)
|
||||||
|
|
||||||
def test_worker_seed_reproducibility(self):
|
def test_worker_seed_reproducibility(self):
|
||||||
|
|
|
||||||
|
|
@ -9523,8 +9523,7 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
|
||||||
|
|
||||||
device_set = {'cpu', 'cpu:0', 'cuda', 'cuda:0', 'cuda:1', 'cuda:10', 'cuda:100'}
|
device_set = {'cpu', 'cpu:0', 'cuda', 'cuda:0', 'cuda:1', 'cuda:10', 'cuda:100'}
|
||||||
device_hash_set = set()
|
device_hash_set = set()
|
||||||
for device in device_set:
|
device_hash_set.update(hash(torch.device(device)) for device in device_set)
|
||||||
device_hash_set.add(hash(torch.device(device)))
|
|
||||||
self.assertEqual(len(device_set), len(device_hash_set))
|
self.assertEqual(len(device_set), len(device_hash_set))
|
||||||
|
|
||||||
def get_expected_device_repr(device):
|
def get_expected_device_repr(device):
|
||||||
|
|
|
||||||
|
|
@ -3233,17 +3233,19 @@ if torch.distributed.is_available():
|
||||||
|
|
||||||
@functools.lru_cache(None)
|
@functools.lru_cache(None)
|
||||||
def get_legacy_mod_inlinelist():
|
def get_legacy_mod_inlinelist():
|
||||||
inlinelist = set()
|
inlinelist = {
|
||||||
for m in LEGACY_MOD_INLINELIST:
|
_module_dir(torch) + m[len("torch.") :].replace(".", "/")
|
||||||
inlinelist.add(_module_dir(torch) + m[len("torch.") :].replace(".", "/"))
|
for m in LEGACY_MOD_INLINELIST
|
||||||
|
}
|
||||||
return inlinelist
|
return inlinelist
|
||||||
|
|
||||||
|
|
||||||
@functools.lru_cache(None)
|
@functools.lru_cache(None)
|
||||||
def get_mod_inlinelist():
|
def get_mod_inlinelist():
|
||||||
inlinelist = set()
|
inlinelist = {
|
||||||
for m in MOD_INLINELIST:
|
_module_dir(torch) + m[len("torch.") :].replace(".", "/")
|
||||||
inlinelist.add(_module_dir(torch) + m[len("torch.") :].replace(".", "/"))
|
for m in MOD_INLINELIST
|
||||||
|
}
|
||||||
return inlinelist
|
return inlinelist
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -744,8 +744,7 @@ def min_cut_rematerialization_partition(
|
||||||
if node.op == "placeholder" and "tangents" in node.target:
|
if node.op == "placeholder" and "tangents" in node.target:
|
||||||
required_bw_nodes.add(node)
|
required_bw_nodes.add(node)
|
||||||
if node in required_bw_nodes:
|
if node in required_bw_nodes:
|
||||||
for user in node.users:
|
required_bw_nodes.update(node.users)
|
||||||
required_bw_nodes.add(user)
|
|
||||||
|
|
||||||
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
|
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
|
||||||
fwd_seed_offset_inputs = list(
|
fwd_seed_offset_inputs = list(
|
||||||
|
|
|
||||||
|
|
@ -3623,8 +3623,7 @@ class CppScheduling(BaseScheduling):
|
||||||
if var_ranges is None:
|
if var_ranges is None:
|
||||||
var_ranges = v
|
var_ranges = v
|
||||||
assert var_ranges == v, (var_ranges, v, node.snodes)
|
assert var_ranges == v, (var_ranges, v, node.snodes)
|
||||||
for expr in exprs:
|
indexing_exprs.update(exprs)
|
||||||
indexing_exprs.add(expr)
|
|
||||||
return var_ranges, list(indexing_exprs)
|
return var_ranges, list(indexing_exprs)
|
||||||
else:
|
else:
|
||||||
assert isinstance(node, SchedulerNode)
|
assert isinstance(node, SchedulerNode)
|
||||||
|
|
|
||||||
|
|
@ -635,8 +635,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||||
# - sebotnet33ts_256
|
# - sebotnet33ts_256
|
||||||
for n in self.module.graph.nodes:
|
for n in self.module.graph.nodes:
|
||||||
if n in output_set:
|
if n in output_set:
|
||||||
for child in n.users:
|
output_set.update(n.users)
|
||||||
output_set.add(child)
|
|
||||||
|
|
||||||
return output_set
|
return output_set
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -89,8 +89,9 @@ def add_needs_realized_inputs(fn):
|
||||||
return [add_needs_realized_inputs(x) for x in fn]
|
return [add_needs_realized_inputs(x) for x in fn]
|
||||||
needs_realized_inputs.add(fn)
|
needs_realized_inputs.add(fn)
|
||||||
if isinstance(fn, torch._ops.OpOverloadPacket):
|
if isinstance(fn, torch._ops.OpOverloadPacket):
|
||||||
for overload in fn.overloads():
|
needs_realized_inputs.update(
|
||||||
needs_realized_inputs.add(getattr(fn, overload))
|
getattr(fn, overload) for overload in fn.overloads()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def add_layout_constraint(fn, constraint):
|
def add_layout_constraint(fn, constraint):
|
||||||
|
|
|
||||||
|
|
@ -2292,9 +2292,7 @@ class Scheduler:
|
||||||
Populate node.last_usage recursively (also for the nodes within a FusedSchedulerNode)
|
Populate node.last_usage recursively (also for the nodes within a FusedSchedulerNode)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
future_used_buffers = set()
|
future_used_buffers = set(V.graph.get_output_names())
|
||||||
for node_name in V.graph.get_output_names():
|
|
||||||
future_used_buffers.add(node_name)
|
|
||||||
|
|
||||||
for node in reversed(self.nodes):
|
for node in reversed(self.nodes):
|
||||||
node.set_last_usage(future_used_buffers, self.mutation_real_name)
|
node.set_last_usage(future_used_buffers, self.mutation_real_name)
|
||||||
|
|
|
||||||
|
|
@ -223,9 +223,10 @@ class CustomOpDef:
|
||||||
def backend_impl(*args, **kwargs):
|
def backend_impl(*args, **kwargs):
|
||||||
# Checks the assumption that outputs cannot alias
|
# Checks the assumption that outputs cannot alias
|
||||||
# inputs or other outputs.
|
# inputs or other outputs.
|
||||||
storages = set()
|
storages = {
|
||||||
for tensor in iter_tensors(args, kwargs):
|
id(tensor.untyped_storage())
|
||||||
storages.add(id(tensor.untyped_storage()))
|
for tensor in iter_tensors(args, kwargs)
|
||||||
|
}
|
||||||
|
|
||||||
result = self._backend_fns[device_type](*args, **kwargs)
|
result = self._backend_fns[device_type](*args, **kwargs)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -742,8 +742,7 @@ def create_add_loggers_graph(
|
||||||
insert_submodule_copy = False
|
insert_submodule_copy = False
|
||||||
if maybe_subgraph is not None:
|
if maybe_subgraph is not None:
|
||||||
first_node, last_node = maybe_subgraph[0], maybe_subgraph[-1]
|
first_node, last_node = maybe_subgraph[0], maybe_subgraph[-1]
|
||||||
for node_to_skip in maybe_subgraph:
|
nodes_to_skip.update(maybe_subgraph)
|
||||||
nodes_to_skip.add(node_to_skip)
|
|
||||||
qconfig = node_name_to_qconfig[first_node.name]
|
qconfig = node_name_to_qconfig[first_node.name]
|
||||||
if qconfig is not None:
|
if qconfig is not None:
|
||||||
insert_submodule_copy = True
|
insert_submodule_copy = True
|
||||||
|
|
@ -873,8 +872,7 @@ def create_add_loggers_graph(
|
||||||
maybe_subgraph = _get_subgraph_containing_node(n, subgraphs_dedup)
|
maybe_subgraph = _get_subgraph_containing_node(n, subgraphs_dedup)
|
||||||
if maybe_subgraph is not None:
|
if maybe_subgraph is not None:
|
||||||
first_node, last_node = maybe_subgraph[0], maybe_subgraph[-1]
|
first_node, last_node = maybe_subgraph[0], maybe_subgraph[-1]
|
||||||
for node_to_skip in maybe_subgraph:
|
nodes_to_skip.update(maybe_subgraph)
|
||||||
nodes_to_skip.add(node_to_skip)
|
|
||||||
else:
|
else:
|
||||||
first_node, last_node = n, n
|
first_node, last_node = n, n
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -45,9 +45,9 @@ class EmbeddingQuantizer(Quantizer):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
|
def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
|
||||||
op_configs: Set[QuantizationConfig] = set({})
|
op_configs: Set[QuantizationConfig] = {
|
||||||
for spec, _ in cls.get_supported_operators():
|
spec for spec, _ in cls.get_supported_operators()
|
||||||
op_configs.add(spec)
|
}
|
||||||
return list(op_configs)
|
return list(op_configs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
||||||
|
|
@ -286,9 +286,9 @@ class X86InductorQuantizer(Quantizer):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
|
def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
|
||||||
op_configs: Set[QuantizationConfig] = set({})
|
op_configs: Set[QuantizationConfig] = {
|
||||||
for spec, _ in cls.supported_config_and_operators:
|
spec for spec, _ in cls.supported_config_and_operators
|
||||||
op_configs.add(spec)
|
}
|
||||||
return list(op_configs)
|
return list(op_configs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
||||||
|
|
@ -305,9 +305,9 @@ class XNNPACKQuantizer(Quantizer):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
|
def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
|
||||||
op_configs: Set[QuantizationConfig] = set({})
|
op_configs: Set[QuantizationConfig] = {
|
||||||
for spec, _ in cls.supported_config_and_operators:
|
spec for spec, _ in cls.supported_config_and_operators
|
||||||
op_configs.add(spec)
|
}
|
||||||
return list(op_configs)
|
return list(op_configs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import itertools
|
import itertools
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from typing import List, Tuple
|
from typing import List, Set, Tuple
|
||||||
|
|
||||||
from torch.distributed._tensor.op_schema import OpStrategy, PlacementStrategy
|
from torch.distributed._tensor.op_schema import OpStrategy, PlacementStrategy
|
||||||
from torch.distributed._tensor.placement_types import (
|
from torch.distributed._tensor.placement_types import (
|
||||||
|
|
@ -44,10 +44,9 @@ class EinsumDims:
|
||||||
Parse the dims and extract the contracting, batch, and free dimensions
|
Parse the dims and extract the contracting, batch, and free dimensions
|
||||||
for the left and right hand sides.
|
for the left and right hand sides.
|
||||||
"""
|
"""
|
||||||
dim_char_set = set()
|
dim_char_set: Set[str] = set()
|
||||||
for input_dim in input_dims:
|
for input_dim in input_dims:
|
||||||
for input_char in list(input_dim):
|
dim_char_set.update(input_dim)
|
||||||
dim_char_set.add(input_char)
|
|
||||||
|
|
||||||
# get a determinisitc order of all dim chars
|
# get a determinisitc order of all dim chars
|
||||||
all_dim_chars = sorted(dim_char_set)
|
all_dim_chars = sorted(dim_char_set)
|
||||||
|
|
|
||||||
|
|
@ -218,7 +218,7 @@ def _verify_options(
|
||||||
fqn_param_mapping[fqn] = param
|
fqn_param_mapping[fqn] = param
|
||||||
all_fqns.add(fqn)
|
all_fqns.add(fqn)
|
||||||
|
|
||||||
submodule_prefixes = set()
|
submodule_prefixes: Set[str] = set()
|
||||||
if submodules:
|
if submodules:
|
||||||
submodules = set(submodules)
|
submodules = set(submodules)
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
|
|
@ -226,8 +226,7 @@ def _verify_options(
|
||||||
continue
|
continue
|
||||||
fqns = _get_fqns(model, name)
|
fqns = _get_fqns(model, name)
|
||||||
assert len(fqns) == 1, "Submodule FQN should only have 1 instance"
|
assert len(fqns) == 1, "Submodule FQN should only have 1 instance"
|
||||||
for fqn in fqns:
|
submodule_prefixes.update(f"{fqn}." for fqn in fqns)
|
||||||
submodule_prefixes.add(f"{fqn}.")
|
|
||||||
|
|
||||||
fsdp_modules = FSDP.fsdp_modules(model)
|
fsdp_modules = FSDP.fsdp_modules(model)
|
||||||
state_dict_config: StateDictConfig
|
state_dict_config: StateDictConfig
|
||||||
|
|
|
||||||
|
|
@ -112,9 +112,7 @@ def _format_import_statement(name: str, obj: Any, importer: Importer) -> str:
|
||||||
|
|
||||||
|
|
||||||
def _format_import_block(globals: Dict[str, Any], importer: Importer):
|
def _format_import_block(globals: Dict[str, Any], importer: Importer):
|
||||||
import_strs: Set[str] = set()
|
import_strs: Set[str] = {_format_import_statement(name, obj, importer) for name, obj in globals.items()}
|
||||||
for name, obj in globals.items():
|
|
||||||
import_strs.add(_format_import_statement(name, obj, importer))
|
|
||||||
# Sort the imports so we have a stable import block that allows us to
|
# Sort the imports so we have a stable import block that allows us to
|
||||||
# hash the graph module and get a consistent key for use in a cache.
|
# hash the graph module and get a consistent key for use in a cache.
|
||||||
return "\n".join(sorted(import_strs))
|
return "\n".join(sorted(import_strs))
|
||||||
|
|
|
||||||
|
|
@ -294,8 +294,7 @@ def _replace_pattern(
|
||||||
# Copy the replacement graph over
|
# Copy the replacement graph over
|
||||||
user_nodes: Set[Node] = set()
|
user_nodes: Set[Node] = set()
|
||||||
for n in match.returning_nodes:
|
for n in match.returning_nodes:
|
||||||
for user in n.users:
|
user_nodes.update(n.users)
|
||||||
user_nodes.add(user)
|
|
||||||
assert user_nodes, "The returning_nodes should have at least one user node"
|
assert user_nodes, "The returning_nodes should have at least one user node"
|
||||||
|
|
||||||
if len(user_nodes) == 1:
|
if len(user_nodes) == 1:
|
||||||
|
|
|
||||||
|
|
@ -930,8 +930,9 @@ class MemoryProfile:
|
||||||
self._is_gradient(*i) or i in used_for_gradient
|
self._is_gradient(*i) or i in used_for_gradient
|
||||||
for i in node.outputs.items()
|
for i in node.outputs.items()
|
||||||
):
|
):
|
||||||
for key, (_, version) in node.inputs.items():
|
used_for_gradient.update(
|
||||||
used_for_gradient.add((key, version))
|
(key, version) for key, (_, version) in node.inputs.items()
|
||||||
|
)
|
||||||
candidate_parameters.intersection_update(used_for_gradient)
|
candidate_parameters.intersection_update(used_for_gradient)
|
||||||
|
|
||||||
# and depends on a gradient.
|
# and depends on a gradient.
|
||||||
|
|
|
||||||
|
|
@ -34,9 +34,7 @@ def _strip_datapipe_from_name(name: str) -> str:
|
||||||
def _generate_input_args_string(obj):
|
def _generate_input_args_string(obj):
|
||||||
"""Generate a string for the input arguments of an object."""
|
"""Generate a string for the input arguments of an object."""
|
||||||
signature = inspect.signature(obj.__class__)
|
signature = inspect.signature(obj.__class__)
|
||||||
input_param_names = set()
|
input_param_names = set(signature.parameters.keys())
|
||||||
for param_name in signature.parameters.keys():
|
|
||||||
input_param_names.add(param_name)
|
|
||||||
result = []
|
result = []
|
||||||
for name, value in inspect.getmembers(obj):
|
for name, value in inspect.getmembers(obj):
|
||||||
if name in input_param_names:
|
if name in input_param_names:
|
||||||
|
|
|
||||||
|
|
@ -578,10 +578,8 @@ def _compute_in_out(ops):
|
||||||
out_blobs = set()
|
out_blobs = set()
|
||||||
|
|
||||||
for op in ops:
|
for op in ops:
|
||||||
for input_blob in op.input:
|
in_blobs.update(op.input)
|
||||||
in_blobs.add(input_blob)
|
out_blobs.update(op.output)
|
||||||
for output_blob in op.output:
|
|
||||||
out_blobs.add(output_blob)
|
|
||||||
|
|
||||||
input_blobs = list(in_blobs.difference(out_blobs))
|
input_blobs = list(in_blobs.difference(out_blobs))
|
||||||
output_blobs = list(out_blobs.difference(in_blobs))
|
output_blobs = list(out_blobs.difference(in_blobs))
|
||||||
|
|
@ -700,8 +698,7 @@ def _operators_to_graph_def(
|
||||||
else [_operator_to_node(shapes, op)]
|
else [_operator_to_node(shapes, op)]
|
||||||
) # .extend() expects an iterable
|
) # .extend() expects an iterable
|
||||||
current_graph.node.extend(nodes_from_op)
|
current_graph.node.extend(nodes_from_op)
|
||||||
for input_blob in op.input:
|
blobs.update(op.input)
|
||||||
blobs.add(input_blob)
|
|
||||||
for i, output_blob in enumerate(op.output):
|
for i, output_blob in enumerate(op.output):
|
||||||
blobs.add(output_blob)
|
blobs.add(output_blob)
|
||||||
producing_ops.setdefault(output_blob, []).append((op, i))
|
producing_ops.setdefault(output_blob, []).append((op, i))
|
||||||
|
|
|
||||||
|
|
@ -2125,7 +2125,7 @@ def gen_headers(
|
||||||
)
|
)
|
||||||
|
|
||||||
def gen_aten_interned_strings() -> Dict[str, str]:
|
def gen_aten_interned_strings() -> Dict[str, str]:
|
||||||
attrs = set() # All function argument names
|
attrs: Set[str] = set() # All function argument names
|
||||||
names = set() # All ATen function names
|
names = set() # All ATen function names
|
||||||
for func in native_functions:
|
for func in native_functions:
|
||||||
names.add(str(func.func.name.name))
|
names.add(str(func.func.name.name))
|
||||||
|
|
@ -2133,8 +2133,7 @@ def gen_headers(
|
||||||
# symbol without the underscore
|
# symbol without the underscore
|
||||||
names.add(func.func.name.name.base)
|
names.add(func.func.name.name.base)
|
||||||
|
|
||||||
for arg in func.func.schema_order_arguments():
|
attrs.update(arg.name for arg in func.func.schema_order_arguments())
|
||||||
attrs.add(arg.name)
|
|
||||||
|
|
||||||
# These are keywords in C++, so aren't valid symbol names
|
# These are keywords in C++, so aren't valid symbol names
|
||||||
# https://en.cppreference.com/w/cpp/language/operator_alternative
|
# https://en.cppreference.com/w/cpp/language/operator_alternative
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user