[BE] Apply almost all remaining flake8-comprehension checks (#94676)

Applies the remaining flake8-comprehensions fixes and checks. This change replaces all remaining unnecessary generator expressions with list/dict/set comprehensions, which are more succinct, more performant, and better supported by our torch.jit compiler. It also removes useless generators such as `set(a for a in b)`, collapsing them into a direct `set(b)` call.
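
For reference, a minimal sketch of the rewrite patterns applied throughout this commit (the variable names below are placeholders for illustration, not code taken from the diff):

```python
items = ["a", "b", "b"]
pairs = [("x", 1), ("y", 2)]

# C400/C401/C402: unnecessary generator -> list/set/dict comprehension
lengths = [len(s) for s in items]     # was: list(len(s) for s in items)
uniques = {s for s in items}          # was: set(s for s in items)
lookup = {k: v for k, v in pairs}     # was: dict((k, v) for k, v in pairs)

# C405: set() over a list literal -> set literal
flags = {"fast", "safe"}              # was: set(["fast", "safe"])

# Useless identity generators collapse into the bare constructor call
levels = list(range(-1, 100))         # was: list(x for x in range(-1, 100))
names = set(items)                    # was: set(s for s in items)
```

Besides being shorter, the literal and comprehension forms avoid building a throwaway list argument and skip the extra name lookup and call of the `list`/`set`/`dict` constructors, which is where the small performance win comes from.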

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94676
Approved by: https://github.com/ezyang
Aaron Gokaslan 2023-02-12 01:01:21 +00:00 committed by PyTorch MergeBot
parent 54c0f37646
commit 67d9790985
113 changed files with 500 additions and 526 deletions

View File

@ -11,7 +11,7 @@ ignore =
# these ignores are from flake8-bugbear; please fix!
B007,B008,
# these ignores are from flake8-comprehensions; please fix!
C400,C401,C402,C405,C407
C407
per-file-ignores =
__init__.py: F401
torch/utils/cpp_extension.py: B950

View File

@ -25,7 +25,7 @@ def main():
ja = load(args.file[0])
jb = load(args.file[1])
keys = (set(ja.keys()) | set(jb.keys())) - set(["benchmark_results"])
keys = (set(ja.keys()) | set(jb.keys())) - {"benchmark_results"}
print("{:20s} {:>20s} {:>20s}".format("", "baseline", "test"))
print("{:20s} {:>20s} {:>20s}".format("", "-" * 20, "-" * 20))
for key in sorted(keys):

View File

@ -39,7 +39,7 @@ def get_content(submod):
return content
def namespace_filter(data):
out = set(d for d in data if d[0] != "_")
out = {d for d in data if d[0] != "_"}
return out
def run(args, submod):

View File

@ -417,7 +417,7 @@ class TestNearlyDiagonalSparsifier(TestCase):
assert torch.all(weights == torch.eye(height, width) * weights) # only diagonal to be present
def test_sparsity_levels(self):
nearliness_levels = list(nearliness for nearliness in range(-1, 100))
nearliness_levels = list(range(-1, 100))
model = nn.Sequential()
p = re.compile(r'[-\.\s]')

View File

@ -244,9 +244,9 @@ class TestFSDPIgnoredModules(FSDPTest):
{"ignored_modules": layer1_ignored_modules}
if ignore_modules
else {
"ignored_parameters": set(
"ignored_parameters": {
p for m in layer1_ignored_modules for p in m.parameters()
)
}
}
)
model.layer1 = FSDP(model.layer1, **ignore_kwargs)
@ -260,9 +260,9 @@ class TestFSDPIgnoredModules(FSDPTest):
{"ignored_modules": model_ignored_modules}
if ignore_modules
else {
"ignored_parameters": set(
"ignored_parameters": {
p for m in model_ignored_modules for p in m.parameters()
)
}
}
)
wrapped_model = FSDP(model, **ignore_kwargs_top)
@ -279,9 +279,9 @@ class TestFSDPIgnoredModules(FSDPTest):
{"ignored_modules": ignored_modules}
if ignore_modules
else {
"ignored_parameters": set(
"ignored_parameters": {
p for m in ignored_modules for p in m.parameters()
)
}
}
)

View File

@ -783,9 +783,7 @@ class TestFSDPStateDict(FSDPTest):
def test_fsdp_state_dict_keys(self, state_dict_type):
state_dict = self._state_dict(self._initialize_model(True), state_dict_type)
if state_dict_type == "local_state_dict":
self.assertEqual(
set([FLAT_PARAM, f"inner.{FLAT_PARAM}"]), state_dict.keys()
)
self.assertEqual({FLAT_PARAM, f"inner.{FLAT_PARAM}"}, state_dict.keys())
elif state_dict_type in ("state_dict", "sharded_state_dict"):
# Keys should match local model.
local_model = self._initialize_model(wrap_fsdp=False, wrap_ddp=False)

View File

@ -66,8 +66,8 @@ class TestUtils(TestCase):
# create a mixed bag of data.
data = [1, "str"]
data.append({"key1": get_a_tensor(), "key2": {1: get_a_tensor()}, "key3": 3})
data.insert(0, set(["x", get_a_tensor(), get_a_tensor()]))
data.append(([1], get_a_tensor(), (1), [get_a_tensor()], set((1, 2))))
data.insert(0, {"x", get_a_tensor(), get_a_tensor()})
data.append(([1], get_a_tensor(), (1), [get_a_tensor()], {1, 2}))
data.append({"abc": SomeDataClass("some_key", 1.0, [get_a_tensor()])})
od = OrderedDict()
od["k"] = "value"

View File

@ -662,7 +662,7 @@ def test_named_children(setup_rpc):
model = nn.Sequential(OrderedDict([("a", a), ("b", b)]))
model = Pipe(model)
names = set(n for n, _ in model.named_modules())
names = {n for n, _ in model.named_modules()}
assert "partitions.0.0" in names
assert "partitions.1.0" in names

View File

@ -1120,7 +1120,7 @@ class AbstractCommTest:
)
self._test_sequence_num_incremented(
c10d._get_default_group(),
ranks=list(i for i in range(dist.get_world_size())),
ranks=list(range(dist.get_world_size())),
)
def _test_sequence_num_incremented_subgroup(self, backend_name):

View File

@ -2296,9 +2296,9 @@ class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
# The tensors to pass to broadcast are identical to the target
# only on the process that is the root of the broadcast.
if self.rank == root_rank:
tensors = list(tensor.clone() for tensor in target)
tensors = [tensor.clone() for tensor in target]
else:
tensors = list(torch.zeros_like(tensor) for tensor in target)
tensors = [torch.zeros_like(tensor) for tensor in target]
if self.rank != root_rank:
self.assertNotEqual(tensors, target)

View File

@ -2623,9 +2623,9 @@ class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
# The tensors to pass to broadcast are idential to the target
# only on the process that is the root of the broadcast.
if self.rank == root_rank:
tensors = list(tensor.clone() for tensor in target)
tensors = [tensor.clone() for tensor in target]
else:
tensors = list(torch.zeros_like(tensor) for tensor in target)
tensors = [torch.zeros_like(tensor) for tensor in target]
if self.rank != root_rank:
self.assertNotEqual(tensors, target)

View File

@ -55,15 +55,13 @@ class OptimizerTests(torch._dynamo.test_case.TestCase):
# exclude SparseAdam because other areas of the stack don't support it yet
# the others are handled specially above
exclude = set(
[
"SGD", # Handled above
"Optimizer",
"SparseAdam", # Unsupported
"LBFGS", # Unsupported
"RAdam", # Has data dependent control for rectification (needs symint)
]
)
exclude = {
"SGD", # Handled above
"Optimizer",
"SparseAdam", # Unsupported
"LBFGS", # Unsupported
"RAdam", # Has data dependent control for rectification (needs symint)
}
optimizers = [
opt

View File

@ -649,7 +649,9 @@ def _get_min_chunk_len(config):
return config.lsh_attn_chunk_length
elif len(attn_types_set) == 1 and attn_types[0] == "local":
return config.local_attn_chunk_length
elif len(attn_types_set) == 2 and attn_types_set == set(["lsh", "local"]):
elif len(attn_types_set) == 2 and attn_types_set == set( # noqa: C405
["lsh", "local"]
):
return min(config.lsh_attn_chunk_length, config.local_attn_chunk_length)
else:
raise NotImplementedError(

View File

@ -803,7 +803,7 @@ class OperatorSet:
def query(self, operator_method, filter=(Support.NO, Support.YES, Support.UNKNOWN)):
result = {}
for key in filter:
result[key] = set([])
result[key] = set()
for op in self.data:
support_status = operator_method(op)
if support_status in filter:

View File

@ -158,20 +158,20 @@ class TestTensorBuiltins(JitTestCase):
return x.{}
"""
EQUALITY_MISMATCH = set([
EQUALITY_MISMATCH = {
# TorchScript doesn't have real enums so they return an int instead
# of the actual value
'dtype',
'layout',
])
MISSING_PROPERTIES = set([
}
MISSING_PROPERTIES = {
'grad_fn',
# This is an undocumented property so it's not included
"output_nr",
# This has a longer implementation, maybe not worth copying to
# TorchScript if named tensors don't work there anyways
'names',
])
}
for p in properties:
if p in MISSING_PROPERTIES:

View File

@ -1516,7 +1516,7 @@ class TestDict(JitTestCase):
li.append(3)
return li
self.assertTrue(set(specialized_list()) == set([1, 2, 3]))
self.assertTrue(set(specialized_list()) == {1, 2, 3})
@skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason")
def test_values(self):

View File

@ -221,11 +221,11 @@ class TestMisc(JitTestCase):
torch._C._enable_mobile_interface_call_export()
scripted_M_mod = torch.jit.script(M())
self.assertTrue(set(['aten::mul.Scalar', 'aten::mul.Tensor', 'aten::reciprocal']).issubset(
self.assertTrue({'aten::mul.Scalar', 'aten::mul.Tensor', 'aten::reciprocal'}.issubset(
set(torch.jit.export_opnames(scripted_M_mod))))
scripted_M_mod.sub = torch.jit.script(FooMod())
self.assertTrue(set(['aten::add.Tensor', 'aten::mul.Scalar']).issubset(
self.assertTrue({'aten::add.Tensor', 'aten::mul.Scalar'}.issubset(
set(torch.jit.export_opnames(scripted_M_mod))))
def test_math_inf(self):

View File

@ -525,8 +525,8 @@ class TestSaveLoad(JitTestCase):
len(list(m.named_modules())), len(list(m_loaded.named_modules()))
)
self.assertEqual(
set(name for name, _ in m.named_modules()),
set(name for name, _ in m_loaded.named_modules()),
{name for name, _ in m.named_modules()},
{name for name, _ in m_loaded.named_modules()},
)
# Check parameters.
m_params = dict(m.named_parameters())

View File

@ -133,7 +133,7 @@ class TestSlice(JitTestCase):
self.assertEqual(scripted_fn(torch.tensor(1)), (2, 3))
tuple_graph = scripted_fn.graph
slices = tuple_graph.findAllNodes("prim::TupleConstruct")
num_outputs = set(len(x.output().type().elements()) for x in slices)
num_outputs = {len(x.output().type().elements()) for x in slices}
# there should be only one tupleSlice with length of 2
self.assertTrue(num_outputs == {2})
self.run_pass('lower_all_tuples', tuple_graph)

View File

@ -34,8 +34,8 @@ def init_lists():
yaml_ts = yaml.load(f, yaml.Loader)
LAZY_OPS_LIST = set(remove_suffixes(itertools.chain(yaml_ts["full_codegen"], yaml_ts["supported"], yaml_ts["autograd"])))
HAS_SYMINT_SUFFIX = yaml_ts["symint"]
FALLBACK_LIST = set(["clamp"])
SKIP_RUNTIME_ERROR_LIST = set([
FALLBACK_LIST = {"clamp"}
SKIP_RUNTIME_ERROR_LIST = {
'index_select', # Empty output_sizes is not supported
'clone', # is clone decomposed?
@ -46,19 +46,19 @@ def init_lists():
'all', # ASAN failure
'any', # ASAN failure
'logdet', # ASAN failure
])
SKIP_INCORRECT_RESULTS_LIST = set([
}
SKIP_INCORRECT_RESULTS_LIST = {
'squeeze', # Value out of range
't', # Value out of range
'transpose', # Value out of range
'bernoulli', # incorrect results
'pow', # incorrect results
'addcdiv', # incorrect results (on CI not locally?)
])
}
# The following ops all show up directly in ts_native_functions.yaml,
# but run functionalized versions of the composite kernels in core.
# This means that we don't expect the ops to show directly in the LTC metrics.
FUNCTIONAL_DECOMPOSE_LIST = set([
FUNCTIONAL_DECOMPOSE_LIST = {
'diag_embed',
'block_diag',
'new_empty_strided',
@ -70,13 +70,13 @@ def init_lists():
'linalg_inv_ex',
'linalg_pinv.atol_rtol_tensor',
'logsumexp',
])
}
# For some ops, we don't support all variants. Here we use formatted_name
# to uniquely identify the variant.
SKIP_VARIANT_LIST = set([
SKIP_VARIANT_LIST = {
'norm_nuc',
'min_reduction_with_dim'
])
}
return (LAZY_OPS_LIST,
FALLBACK_LIST,

View File

@ -31,7 +31,7 @@ class TestDependencyHooks(PackageTestCase):
exporter.register_extern_hook(my_extern_hook)
exporter.save_source_string("foo", "import module_a")
self.assertEqual(my_externs, set(["module_a"]))
self.assertEqual(my_externs, {"module_a"})
def test_multiple_extern_hooks(self):
buffer = BytesIO()
@ -93,7 +93,7 @@ class TestDependencyHooks(PackageTestCase):
exporter.save_source_string("foo", "import module_a")
self.assertEqual(my_externs, set())
self.assertEqual(my_externs2, set(["module_a"]))
self.assertEqual(my_externs2, {"module_a"})
def test_extern_and_mock_hook(self):
buffer = BytesIO()
@ -114,8 +114,8 @@ class TestDependencyHooks(PackageTestCase):
exporter.register_mock_hook(my_mock_hook)
exporter.save_source_string("foo", "import module_a; import package_a")
self.assertEqual(my_externs, set(["module_a"]))
self.assertEqual(my_mocks, set(["package_a"]))
self.assertEqual(my_externs, {"module_a"})
self.assertEqual(my_mocks, {"package_a"})
if __name__ == "__main__":

View File

@ -82,7 +82,7 @@ class TestDiGraph(PackageTestCase):
for n in g:
nodes.add(n)
self.assertEqual(nodes, set([1, 2, 3]))
self.assertEqual(nodes, {1, 2, 3})
def test_contains(self):
g = DiGraph()
@ -101,8 +101,8 @@ class TestDiGraph(PackageTestCase):
g.add_edge("2", "3")
g.add_edge("5", "4")
g.add_edge("4", "3")
self.assertTrue(g.forward_transitive_closure("1") == set(["1", "2", "3"]))
self.assertTrue(g.forward_transitive_closure("4") == set(["4", "3"]))
self.assertTrue(g.forward_transitive_closure("1") == {"1", "2", "3"})
self.assertTrue(g.forward_transitive_closure("4") == {"4", "3"})
def test_all_paths(self):
g = DiGraph()

View File

@ -2443,7 +2443,7 @@ class TestQuantizedOps(TestCase):
affine_list = (True, False)
combined = [shape_list, torch_types, y_scales, y_zero_points, channels_last_list, affine_list]
test_cases_product = itertools.product(*combined)
test_cases = list(test_case for test_case in test_cases_product)
test_cases = list(test_cases_product)
# add just one test case to test overflow
test_cases.append([
[1, 4, 224, 224, 160], # shape,

View File

@ -95,8 +95,8 @@ class TestModelNumericsEager(QuantizationTestCase):
torch.manual_seed(67)
calib_data = torch.rand(2048, 3, 15, 15, dtype=torch.float32)
eval_data = torch.rand(10, 3, 15, 15, dtype=torch.float32)
qconfigset = set([torch.ao.quantization.default_weight_only_qconfig,
torch.ao.quantization.default_activation_only_qconfig])
qconfigset = {torch.ao.quantization.default_weight_only_qconfig,
torch.ao.quantization.default_activation_only_qconfig}
SQNRTarget = [35, 45]
for idx, qconfig in enumerate(qconfigset):
my_model = ModelMultipleOpsNoAvgPool().to(torch.float32)

View File

@ -1120,7 +1120,7 @@ class TestQuantizeEagerPTQDynamic(QuantizationTestCase):
# Test set qconfig
model = SingleLayerLinearDynamicModel()
quantize_dynamic(model, set([nn.Linear]), inplace=True, dtype=dtype)
quantize_dynamic(model, {nn.Linear}, inplace=True, dtype=dtype)
checkQuantized(model)
def test_two_layers(self):

View File

@ -895,7 +895,7 @@ class TestFxModelReportClass(QuantizationTestCase):
model_prep = quantize_fx.prepare_fx(model, q_config_mapping, model.get_example_inputs()[0])
# make an example set of detectors
test_detector_set = set([DynamicStaticDetector(), PerChannelDetector(backend)])
test_detector_set = {DynamicStaticDetector(), PerChannelDetector(backend)}
# initialize with an empty detector
model_report = ModelReport(model_prep, test_detector_set)
@ -905,7 +905,7 @@ class TestFxModelReportClass(QuantizationTestCase):
# now attempt with no valid reports, should raise error
with self.assertRaises(ValueError):
model_report = ModelReport(model, set([]))
model_report = ModelReport(model, set())
# number of expected obs of interest entries
num_expected_entries = len(test_detector_set)
@ -932,7 +932,7 @@ class TestFxModelReportClass(QuantizationTestCase):
# make an example set of detectors
torch.backends.quantized.engine = "fbgemm"
backend = torch.backends.quantized.engine
test_detector_set = set([DynamicStaticDetector(), PerChannelDetector(backend)])
test_detector_set = {DynamicStaticDetector(), PerChannelDetector(backend)}
# initialize with an empty detector
# prepare the model
@ -1029,8 +1029,8 @@ class TestFxModelReportClass(QuantizationTestCase):
torch.backends.quantized.engine = "fbgemm"
# check whether the correct number of reports are being generated
filled_detector_set = set([DynamicStaticDetector(), PerChannelDetector(torch.backends.quantized.engine)])
single_detector_set = set([DynamicStaticDetector()])
filled_detector_set = {DynamicStaticDetector(), PerChannelDetector(torch.backends.quantized.engine)}
single_detector_set = {DynamicStaticDetector()}
# create our models
model_full = TwoThreeOps()
@ -1316,7 +1316,7 @@ class TestFxDetectInputWeightEqualization(QuantizationTestCase):
# then create model report instance with detector
with override_quantized_engine('fbgemm'):
detector_set = set([InputWeightEqualizationDetector(0.5)])
detector_set = {InputWeightEqualizationDetector(0.5)}
# get tst model and callibrate
non_fused = self._get_prepped_for_calibration_model(self.TwoBlockComplexNet(), detector_set)
@ -1326,7 +1326,7 @@ class TestFxDetectInputWeightEqualization(QuantizationTestCase):
for prepared_for_callibrate_model, mod_report in [non_fused, fused]:
# supported modules to check
mods_to_check = set([nn.Linear, nn.Conv2d])
mods_to_check = {nn.Linear, nn.Conv2d}
# get the set of all nodes in the graph their fqns
node_fqns = {node.target for node in prepared_for_callibrate_model.graph.nodes}
@ -1362,7 +1362,7 @@ class TestFxDetectInputWeightEqualization(QuantizationTestCase):
with override_quantized_engine('fbgemm'):
test_input_weight_detector = InputWeightEqualizationDetector(0.4)
detector_set = set([test_input_weight_detector])
detector_set = {test_input_weight_detector}
model = self.TwoBlockComplexNet()
# prepare the model for callibration
prepared_for_callibrate_model, model_report = self._get_prepped_for_calibration_model(
@ -1471,7 +1471,7 @@ class TestFxDetectInputWeightEqualization(QuantizationTestCase):
# then create model report instance with detector
with override_quantized_engine('fbgemm'):
test_input_weight_detector = InputWeightEqualizationDetector(0.4)
detector_set = set([test_input_weight_detector])
detector_set = {test_input_weight_detector}
model = self.ReluOnly()
# prepare the model for callibration
prepared_for_callibrate_model, model_report = self._get_prepped_for_calibration_model(model, detector_set)
@ -1547,7 +1547,7 @@ class TestFxDetectOutliers(QuantizationTestCase):
# not explicitly testing fusion because fx workflow automatically
with override_quantized_engine('fbgemm'):
detector_set = set([OutlierDetector(reference_percentile=0.95)])
detector_set = {OutlierDetector(reference_percentile=0.95)}
# get tst model and callibrate
prepared_for_callibrate_model, mod_report = self._get_prepped_for_calibration_model(
@ -1555,7 +1555,7 @@ class TestFxDetectOutliers(QuantizationTestCase):
)
# supported modules to check
mods_to_check = set([nn.Linear, nn.Conv2d, nn.ReLU])
mods_to_check = {nn.Linear, nn.Conv2d, nn.ReLU}
# there should be 4 node fqns that have the observer inserted
correct_number_of_obs_inserted = 4
@ -1590,7 +1590,7 @@ class TestFxDetectOutliers(QuantizationTestCase):
dynamic_static_detector = DynamicStaticDetector(tolerance=0.5)
param_size: int = 4
detector_set = set([outlier_detector, dynamic_static_detector])
detector_set = {outlier_detector, dynamic_static_detector}
model = self.LargeBatchModel(param_size=param_size)
# get tst model and callibrate
@ -1640,7 +1640,7 @@ class TestFxDetectOutliers(QuantizationTestCase):
outlier_detector = OutlierDetector(ratio_threshold=1, reference_percentile=0)
param_size: int = 16
detector_set = set([outlier_detector])
detector_set = {outlier_detector}
model = self.LargeBatchModel(param_size=param_size)
# get tst model and callibrate
@ -1690,7 +1690,7 @@ class TestFxDetectOutliers(QuantizationTestCase):
outlier_detector = OutlierDetector(reference_percentile=0.95)
param_size: int = 8
detector_set = set([outlier_detector])
detector_set = {outlier_detector}
model = self.LargeBatchModel(param_size=param_size)
# get tst model and callibrate
@ -1874,8 +1874,8 @@ class TestFxModelReportVisualizer(QuantizationTestCase):
channel_headers, channel_table = table_dict[ModelReportVisualizer.TABLE_CHANNEL_KEY]
# these two together should be the same as the generated report info in terms of keys
tensor_info_modules = set(row[1] for row in tensor_table)
channel_info_modules = set(row[1] for row in channel_table)
tensor_info_modules = {row[1] for row in tensor_table}
channel_info_modules = {row[1] for row in channel_table}
combined_modules: Set = tensor_info_modules.union(channel_info_modules)
generated_report_keys: Set = set(mod_rep_visualizer.generated_reports.keys())
@ -1901,8 +1901,8 @@ class TestFxModelReportVisualizer(QuantizationTestCase):
tensor_headers, tensor_table = empty_tables_dict[ModelReportVisualizer.TABLE_TENSOR_KEY]
channel_headers, channel_table = empty_tables_dict[ModelReportVisualizer.TABLE_CHANNEL_KEY]
tensor_info_modules = set(row[1] for row in tensor_table)
channel_info_modules = set(row[1] for row in channel_table)
tensor_info_modules = {row[1] for row in tensor_table}
channel_info_modules = {row[1] for row in channel_table}
combined_modules: Set = tensor_info_modules.union(channel_info_modules)
self.assertEqual(len(combined_modules), 0) # should be no matching modules

View File

@ -660,16 +660,16 @@ class TestQuantizeJitPasses(QuantizationTestCase):
m = torch.jit.script(M())
qconfig_dict = {"": default_qconfig}
m = prepare_jit(m, qconfig_dict)
activation_dtypes = set(
activation_dtypes = {
obs.getattr("dtype")
for x, obs in m._modules._c.items()
if x.startswith("_observer_")
)
weight_dtypes = set(
}
weight_dtypes = {
obs.getattr("dtype")
for x, obs in m.conv._modules._c.items()
if x.startswith("_observer_")
)
}
assert len(activation_dtypes) == 1, "Expected to have 1 activation dtype"
assert len(weight_dtypes) == 1, "Expected to have 1 weight dtype"
assert (

View File

@ -1557,7 +1557,7 @@ class TestBinaryUfuncs(TestCase):
((2, 1), (2, 2)),
((2, 2), (2, 1, 1)),
)
test_inputs = list(
test_inputs = [
(
make_tensor(
base_size, dtype=torch.float64, device=device, high=10.0, low=0.0
@ -1567,7 +1567,7 @@ class TestBinaryUfuncs(TestCase):
),
)
for base_size, exp_size in test_cases
)
]
for base, exponent in test_inputs:
regex = "doesn't match the broadcast shape"
self.assertRaisesRegex(RuntimeError, regex, base.pow_, exponent)
@ -1605,10 +1605,10 @@ class TestBinaryUfuncs(TestCase):
(2, 1),
(2, 2, 2),
)
tensors = list(
tensors = [
make_tensor(shape, dtype=dtype, device=device, low=0)
for shape in exponent_shapes
)
]
floats_tensor = torch.tensor(floats, dtype=dtype, device=device)
for base in floats:
self._test_pow(base, floats_tensor)

View File

@ -194,7 +194,7 @@ class TestBundledInputs(TestCase):
# Check helper that work on all functions
all_info = loaded.get_bundled_inputs_functions_and_info()
self.assertEqual(set(all_info.keys()), set(['forward', 'foo']))
self.assertEqual(set(all_info.keys()), {'forward', 'foo'})
self.assertEqual(all_info['forward']['get_inputs_function_name'], ['get_all_bundled_inputs_for_forward'])
self.assertEqual(all_info['foo']['get_inputs_function_name'], ['get_all_bundled_inputs_for_foo'])
self.assertEqual(all_info['forward']['info'], info)

View File

@ -191,7 +191,7 @@ class TestPybindTypeCasters(common.TestCase):
In these cases we expect to get exactly one function per python type.
"""
# Verify that all functions have the same return type.
union_type = set(self.expected_return_type(f) for f in funcs)
union_type = {self.expected_return_type(f) for f in funcs}
assert len(union_type) == 1
union_type = union_type.pop()
self.assertIs(Union, get_origin(union_type))

View File

@ -1361,7 +1361,7 @@ except RuntimeError as e:
dataloader_iter = iter(dataloader)
fetched = list(dataloader_iter)
self.assertEqual(len(fetched), 4)
fetched = set(tuple(t.tolist()) for t in fetched)
fetched = {tuple(t.tolist()) for t in fetched}
self.assertEqual(fetched, {tuple(range(4)), tuple(range(7)), tuple(range(7, 14)), tuple(range(14, 20))})
# [auto-batching] test that workers exit gracefully
@ -1399,7 +1399,7 @@ except RuntimeError as e:
dataloader_iter = iter(dataloader)
fetched = list(dataloader_iter)
self.assertEqual(len(fetched), 2)
fetched = set(tuple(t.tolist()) for t in fetched)
fetched = {tuple(t.tolist()) for t in fetched}
self.assertEqual(fetched, {tuple(range(7)), tuple(range(7, 14))})
# [auto-batching & drop_last] test that workers exit gracefully
@ -1500,7 +1500,7 @@ except RuntimeError as e:
num_workers = 6
batch_size = 1
dataset = SynchronizedSeedDataset(num_workers, batch_size, num_workers)
self.assertEqual(set(int(batch) for batch in get_dataloader()), set(int(batch) for batch in get_dataloader()))
self.assertEqual({int(batch) for batch in get_dataloader()}, {int(batch) for batch in get_dataloader()})
def test_multi_epochs_reproducibility(self):
num_workers = 2

View File

@ -1755,7 +1755,7 @@ class TestFunctionalIterDataPipe(TestCase):
len(zipped_dp)
# Functional Test: zips the results properly
exp = list((i, i) for i in range(5))
exp = [(i, i) for i in range(5)]
self.assertEqual(list(zipped_dp), exp)
# Functional Test: zips the inputs properly even when lengths are different (zips to the shortest)
@ -2364,7 +2364,7 @@ class TestTyping(TestCase):
# Context Manager to disable the runtime validation
with runtime_validation_disabled():
self.assertEqual(list(d for d in dp3), ds)
self.assertEqual(list(dp3), ds)
class NumbersDataset(IterDataPipe):

View File

@ -739,9 +739,9 @@ class HasDecompTest(TestCase):
# This is for operators that are only registered in some CI
# configurations, so would cause the test to fail
allow_list = set([aten.get_gradients.default])
allow_list = {aten.get_gradients.default}
overloads_wanting_decomp = set(op for op in all_aten_overloads() if can_appear_in_trace(op))
overloads_wanting_decomp = {op for op in all_aten_overloads() if can_appear_in_trace(op)}
ops_missing_decomp = overloads_wanting_decomp - decomposition_table.keys()
ops_missing_decomp -= allow_list
self.assertExpected("".join(sorted(op.name() + "\n" for op in ops_missing_decomp)))

View File

@ -466,7 +466,7 @@ class TestForeach(TestCase):
# `tensors2`: ['cuda', 'cpu']
_cuda_tensors = list(op.sample_inputs(device, dtype, num_input_tensors=[2], same_size=True))[0].input
_cpu_tensors = list(op.sample_inputs("cpu", dtype, num_input_tensors=[2], same_size=True))[0].input
tensors1, tensors2 = list(tensors for tensors in zip(_cuda_tensors, _cpu_tensors))
tensors1, tensors2 = list(zip(_cuda_tensors, _cpu_tensors))
foreach_op, foreach_op_ = op.method_variant, op.inplace_variant
native_op, native_op_ = op.ref, op.ref_inplace
@ -494,7 +494,7 @@ class TestForeach(TestCase):
# tensors3: ['cuda', 'cpu]
_cuda_tensors = list(op.sample_inputs(device, dtype, num_input_tensors=[3], same_size=True))[0].input
_cpu_tensors = list(op.sample_inputs("cpu", dtype, num_input_tensors=[3], same_size=True))[0].input
tensors1, tensors2, tensors3 = list(tensors for tensors in zip(_cuda_tensors, _cpu_tensors))
tensors1, tensors2, tensors3 = list(zip(_cuda_tensors, _cpu_tensors))
foreach_op, foreach_op_, native_op = op.method_variant, op.inplace_variant, op.ref
actual = foreach_op(tensors1, tensors2, tensors3)

View File

@ -1598,8 +1598,8 @@ class TestFX(JitTestCase):
if node.op == 'output':
output_shape = node.args[0].meta['tensor_meta'].shape
output_stride = node.args[0].meta['tensor_meta'].stride
self.assertEqual(opcodes, set(['placeholder', 'get_attr', 'call_function', 'call_method',
'call_module', 'output']))
self.assertEqual(opcodes, {'placeholder', 'get_attr', 'call_function', 'call_method',
'call_module', 'output'})
# Test shape propagation and make sure results match actual
self.assertEqual(output_shape, ref_out.shape)
@ -1832,8 +1832,8 @@ class TestFX(JitTestCase):
interp = Interpreter(symbolic_trace(rn18))
inp = torch.rand(5, 3, 224, 224)
out = interp.run(inp)
env_key_names = set(n.name for n in interp.env.keys())
self.assertEqual(env_key_names, set(['output']))
env_key_names = {n.name for n in interp.env.keys()}
self.assertEqual(env_key_names, {'output'})
def test_interpreter_default_args(self):
class Model(torch.nn.Module):
@ -2052,7 +2052,7 @@ class TestFX(JitTestCase):
for orig_node, new_node in zip(g.nodes, copied_graph.nodes):
orig_users = set(orig_node.users.keys())
orig_users_equiv = set(val_map[u] for u in orig_users)
orig_users_equiv = {val_map[u] for u in orig_users}
new_users = set(new_node.users.keys())
self.assertEqual(orig_users_equiv, new_users)
@ -2230,7 +2230,7 @@ class TestFX(JitTestCase):
users_of_x = x.node.users
self.assertEqual(len(users_of_x), 3)
expected_ops = set(['relu', 'add', 'neg'])
expected_ops = {'relu', 'add', 'neg'}
for use in users_of_x:
assert any(use.name.startswith(prefix) for prefix in expected_ops)

View File

@ -873,7 +873,7 @@ terrible spacing
) -> bool:
# `leaves` contains the set of standard `nn.Modules` that are not
# currently symbolically traceable. Ideally this set would be empty
leaves = set([torch.nn.BatchNorm2d])
leaves = {torch.nn.BatchNorm2d}
return type(m) in leaves
traced = torch.fx.GraphModule(m, FunctionalTracer().trace(m))
@ -1057,7 +1057,7 @@ class {test_classname}(torch.nn.Module):
) -> bool:
# `leaves` contains the set of standard `nn.Modules` that are not
# currently symbolically traceable. Ideally this set would be empty
leaves = set([torch.nn.BatchNorm2d])
leaves = {torch.nn.BatchNorm2d}
return type(m) in leaves
traced_functionals = torch.fx.GraphModule(m, FunctionalTracer().trace(m))

View File

@ -3743,7 +3743,7 @@ class TestCudaFuser(JitTestCase):
result += 1
return result
complete_views = set([tuple(original_view)])
complete_views = {tuple(original_view)}
to_visit = []
# empty new view, curent originaal view, start pos=0, move count = 0, last_move

View File

@ -27,7 +27,7 @@ if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
def strip_profiling_nodes(nodes):
profiling_opcodes = set(['prim::BailoutTemplate', 'prim::BailOut'])
profiling_opcodes = {'prim::BailoutTemplate', 'prim::BailOut'}
return [n for n in nodes if n.kind() not in profiling_opcodes]

View File

@ -46,7 +46,7 @@ LLVM_ENABLED = torch._C._llvm_enabled()
autograd_check_set = {'aten::__is__', 'prim::AutogradAllNonZero', 'prim::AutogradAllZero', 'prim::ListConstruct'}
def strip_profiling_nodes(nodes):
profiling_opcodes = set(['prim::BailoutTemplate', 'prim::BailOut'])
profiling_opcodes = {'prim::BailoutTemplate', 'prim::BailOut'}
return [n for n in nodes if n.kind() not in profiling_opcodes]
def warmup_forward(f, *args, profiling_count=2):
@ -189,7 +189,7 @@ class TestTEFuser(JitTestCase):
return x2.sum()
with texpr_reductions_enabled():
a = torch.tensor(list(x for x in range(0, 15)), dtype=torch.float, device='cpu')
a = torch.tensor(list(range(0, 15)), dtype=torch.float, device='cpu')
a = a.reshape(5, 3)
scripted = self.checkScript(func, (a,))
self.assertLastGraphAllFused()
@ -205,7 +205,7 @@ class TestTEFuser(JitTestCase):
return x.sum((-2, )) * 2
with texpr_reductions_enabled():
a = torch.tensor(list(x for x in range(0, 15)), dtype=torch.float, device='cpu')
a = torch.tensor(list(range(0, 15)), dtype=torch.float, device='cpu')
a = a.reshape(5, 3)
scripted = self.checkScript(func, (a,))
self.assertLastGraphAllFused()
@ -217,7 +217,7 @@ class TestTEFuser(JitTestCase):
return x.sum((0, ), keepdim=True, dtype=torch.double) * 2
with texpr_reductions_enabled():
a = torch.tensor(list(x for x in range(0, 15)), dtype=torch.float, device='cpu')
a = torch.tensor(list(range(0, 15)), dtype=torch.float, device='cpu')
a = a.reshape(5, 3)
self.checkScript(func, (a,))

View File

@ -498,7 +498,7 @@ class TestModule(TestCase):
# TODO: RNN / GRU / LSTM don't support backwards on eval mode for cuDNN; skip this in a
# nicer way for eval mode only.
# See https://github.com/pytorch/pytorch/issues/79161
rnn_modules = set([torch.nn.RNN, torch.nn.GRU, torch.nn.LSTM])
rnn_modules = {torch.nn.RNN, torch.nn.GRU, torch.nn.LSTM}
if (module_info.module_cls in rnn_modules
and not training
and 'cuda' in device

View File

@ -1719,7 +1719,7 @@ class TestRefsOpsInfo(TestCase):
module_alls = [(path, import_module(f"torch.{path}").__all__) for path in import_paths]
ref_ops_names = tuple(itertools.chain.from_iterable(
[f"{path}.{op}" for op in module_all] for path, module_all in module_alls))
ref_db_names = set(ref_op.name for ref_op in python_ref_db)
ref_db_names = {ref_op.name for ref_op in python_ref_db}
# TODO: References that do not have an entry in python_ref_db
skip_ref_ops = {
@ -1910,9 +1910,7 @@ fake_skips = (
fake_autocast_device_skips = defaultdict(dict)
# TODO: investigate/fix
fake_autocast_device_skips["cpu"] = set(
("linalg.pinv",)
)
fake_autocast_device_skips["cpu"] = {"linalg.pinv"}
dynamic_output_op_tests = (

View File

@ -145,8 +145,8 @@ class TestOptim(TestCase):
constructor_accepts_maximize=True,
constructor_accepts_foreach=False,
):
maximize_options = set([False, constructor_accepts_maximize])
foreach_options = set([False, constructor_accepts_foreach])
maximize_options = {False, constructor_accepts_maximize}
foreach_options = {False, constructor_accepts_foreach}
four_arg_constructor = constructor
if constructor_accepts_maximize and constructor_accepts_foreach:
@ -317,7 +317,7 @@ class TestOptim(TestCase):
# validate deepcopy() copies all public attributes
def getPublicAttr(obj):
return set(k for k in obj.__dict__ if not k.startswith("_"))
return {k for k in obj.__dict__ if not k.startswith("_")}
self.assertEqual(getPublicAttr(optimizer), getPublicAttr(deepcopy(optimizer)))
@ -346,8 +346,8 @@ class TestOptim(TestCase):
return constructor
for maximize, foreach in itertools.product(
set([False, constructor_accepts_maximize]),
set([False, constructor_accepts_foreach]),
{False, constructor_accepts_maximize},
{False, constructor_accepts_foreach},
):
self._test_state_dict(
torch.randn(10, 5),

View File

@ -80,7 +80,7 @@ def _reduced_shape(shape, dim=None, keepdim=False):
# Wrap negative dims
dim = dim if isinstance(dim, Sequence) else [dim]
dim = set(i if i >= 0 else len(shape) + i for i in dim)
dim = {i if i >= 0 else len(shape) + i for i in dim}
result = []
for i, size in enumerate(shape):

View File

@ -19,33 +19,29 @@ from torchgen.utils import FileManager
# - all ops below are part of MANUAL_TRACER to skip codegen Tracer kernel registration
# Note: we still register to dispatch key Profiler for these ops, keeping it untouched for now.
# You can find the manual registration in torch/csrc/autograd/VariableTypeManual.cpp
MANUAL_BACKEND = set(
[
"options",
"data",
"set_data",
"is_leaf",
"output_nr",
"_version",
"retain_grad",
"_backward",
"requires_grad_",
]
)
MANUAL_BACKEND = {
"options",
"data",
"set_data",
"is_leaf",
"output_nr",
"_version",
"retain_grad",
"_backward",
"requires_grad_",
}
# For these ops we want to skip the codegen-ed registration to both Autograd and Tracer keys.
# You can find the manual registration in torch/csrc/autograd/VariableTypeManual.cpp
MANUAL_AUTOGRAD_AND_TRACER = set(
[
"resize_",
"resize_as_",
"detach",
"detach_",
"copy_",
"_fw_primal",
"_make_dual",
]
)
MANUAL_AUTOGRAD_AND_TRACER = {
"resize_",
"resize_as_",
"detach",
"detach_",
"copy_",
"_fw_primal",
"_make_dual",
}
# Currently MANUAL_AUTOGRAD and MANUAL_TRACER share the same set of ops:
# union(MANUAL_BACKEND, MANUAL_AUTOGRAD_AND_TRACER)

View File

@ -968,10 +968,10 @@ def emit_body(
"""Find arguments that have derivative definitions"""
if info is None or not info.has_derivatives:
return differentiable_inputs
names = set(name for d in info.derivatives for name in d.var_names)
names = {name for d in info.derivatives for name in d.var_names}
differentiable = [arg for arg in differentiable_inputs if arg.name in names]
if len(differentiable) != len(names):
missing = names - set(arg.name for arg in differentiable)
missing = names - {arg.name for arg in differentiable}
raise RuntimeError(
f"Missing arguments for derivatives: {missing} in {info.name}"
)

View File

@ -1408,7 +1408,7 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
assert isinstance(tos1, ConstDictVariable)
match_obj = tos1.items
if all(key in match_obj for key in keys):
self.push(TupleVariable(list(match_obj[key] for key in keys)))
self.push(TupleVariable([match_obj[key] for key in keys]))
self.push(ConstantVariable(True))
else:
self.push(ConstantVariable(None))

View File

@ -764,14 +764,14 @@ def dict_param_key_ids(value):
def dict_const_keys(value):
return set(k for k in value.keys() if not isinstance(k, torch.nn.Parameter))
return {k for k in value.keys() if not isinstance(k, torch.nn.Parameter)}
def dict_const_keys_repr(const_keys):
if any(isinstance(k, enum.Enum) for k in const_keys):
# To workaround repr(Enum) returning invalid global reference before python 3.11
# by calling enum_repr and removing quotes to render enum in guard code.
const_keys_str = f"{set(enum_repr(k) if isinstance(k, enum.Enum) else repr(k) for k in const_keys)}".replace(
const_keys_str = f"{ {enum_repr(k) if isinstance(k, enum.Enum) else repr(k) for k in const_keys} }".replace(
"'", ""
)
else:

View File

@ -451,7 +451,7 @@ class TorchVariable(VariableTracker):
for x in args
]
)
bin_ops = set(["add", "sub", "mul", "div", "sqrt"])
bin_ops = {"add", "sub", "mul", "div", "sqrt"}
if (
getattr(self.value, "__module__", "") == "torch"
and self.value.__name__ in bin_ops
@ -903,7 +903,7 @@ class TorchPyOperator(VariableTracker):
args[0].as_proxy(),
true_node,
false_node,
list(a.as_proxy() for a in sub_args),
[a.as_proxy() for a in sub_args],
)
# TODO: assert that the true/false return values are
# consistent

View File

@ -388,11 +388,11 @@ def min_cut_rematerialization_partition(
fusible_ops = recomputable_ops | set(random_ops)
if AOT_PARTITIONER_DEBUG:
joint_module_ops = set(
joint_module_ops = {
str(node.target._overloadpacket)
for node in joint_module.graph.nodes
if node.op == "call_function" and hasattr(node.target, "_overloadpacket")
)
}
ops_ignored = joint_module_ops - {str(i) for i in recomputable_ops}
print("Ops banned from rematerialization: ", ops_ignored)
print()
@ -400,7 +400,7 @@ def min_cut_rematerialization_partition(
AGGRESSIVE_RECOMPUTATION = False
def is_materialized_backwards(node):
cur_nodes = set([node])
cur_nodes = {node}
while len(cur_nodes) > 0:
cur = cur_nodes.pop()
for user in cur.users:

View File

@ -949,7 +949,7 @@ class TritonKernel(Kernel):
dim = len(self.range_trees) - 1
result_var = self.cse.newvar()
result_var.mask_vars = set(var for var in masks if var[0] != "r")
result_var.mask_vars = {var for var in masks if var[0] != "r"}
if (src_dtype, reduction_type, value) not in self.cse.reduction_cache:
self.cse.reduction_cache[(src_dtype, reduction_type, value)] = result_var
accumulator = f"_{result_var}"

View File

@ -531,15 +531,15 @@ class GraphLowering(torch.fx.Interpreter):
def get_read_write_buffers_sizes(node):
if isinstance(node, NopKernelSchedulerNode):
return 0
reads = set(dep.name for dep in node.read_writes.reads)
writes = set(dep.name for dep in node.read_writes.writes)
reads = {dep.name for dep in node.read_writes.reads}
writes = {dep.name for dep in node.read_writes.writes}
def is_materialized(buf):
buf_uses = {user.node for user in scheduler.name_to_node[buf].users}
return len(buf_uses - set(node.snodes)) > 0
if isinstance(node, FusedSchedulerNode):
removed_buffers = set(dep for dep in writes if not is_materialized(dep))
removed_buffers = {dep for dep in writes if not is_materialized(dep)}
writes = writes - removed_buffers
reads = reads - removed_buffers
node_bytes = 0

View File

@ -2995,7 +2995,7 @@ class FallbackKernel(ExternKernelAlloc):
tensor_args = [Shim(x.codegen_reference()) for x in self.inputs]
constant_args = [Shim(repr(x)) for x in self.constant_args]
args, kwargs = self.unflatten_args(tensor_args, constant_args)
return list(map(repr, args)) + list(gen_kwarg(k, v) for k, v in kwargs.items())
return list(map(repr, args)) + [gen_kwarg(k, v) for k, v in kwargs.items()]
@classmethod
def create(cls, kernel, *args, **kwargs):

View File

@ -177,7 +177,7 @@ class BaseSchedulerNode:
return self.get_name()
def get_names(self) -> Set[str]:
return set([self.get_name()])
return {self.get_name()}
def get_nodes(self) -> List["BaseSchedulerNode"]:
return [self]

View File

@ -295,14 +295,12 @@ def free_symbol_startswith(index: sympy.Expr, prefix: str):
def has_incompatible_cudagraph_ops(gm):
forbidden_list = set(
[
"aten._fused_moving_avg_obs_fq_helper.default",
"aten._fused_moving_avg_obs_fq_helper_functional.default",
"fbgemm.dense_to_jagged.default",
"fbgemm.jagged_to_padded_dense.default",
]
)
forbidden_list = {
"aten._fused_moving_avg_obs_fq_helper.default",
"aten._fused_moving_avg_obs_fq_helper_functional.default",
"fbgemm.dense_to_jagged.default",
"fbgemm.jagged_to_padded_dense.default",
}
for node in gm.graph.nodes:
if str(node.target) in forbidden_list:
return True

View File

@ -243,14 +243,12 @@ def is_channels_last_contiguous_3d(a: Tensor) -> bool:
return True
_memory_formats = set(
(
torch.contiguous_format,
torch.preserve_format,
torch.channels_last,
torch.channels_last_3d,
)
)
_memory_formats = {
torch.contiguous_format,
torch.preserve_format,
torch.channels_last,
torch.channels_last_3d,
}
def validate_memory_format(memory_format: torch.memory_format):

View File

@ -2956,7 +2956,7 @@ def native_group_norm(
out, mean, rstd = _normalize(input_reshaped, reduction_dims, eps)
out = out.view(input.shape)
broadcast_dims = [0] + list(dim for dim in range(2, input.ndim))
broadcast_dims = [0] + list(range(2, input.ndim))
unsqueeze_bias = None
if bias is not None:
unsqueeze_bias = _unsqueeze_multiple(bias, broadcast_dims)

View File

@ -816,13 +816,9 @@ class ConvReLU3d(nnqat.Conv3d, nni._FusedModule):
return super(ConvReLU3d, cls).from_float(mod)
def update_bn_stats(mod):
if type(mod) in set(
[ConvBnReLU1d, ConvBnReLU2d, ConvBnReLU3d, ConvBn1d, ConvBn2d, ConvBn3d]
):
if type(mod) in {ConvBnReLU1d, ConvBnReLU2d, ConvBnReLU3d, ConvBn1d, ConvBn2d, ConvBn3d}:
mod.update_bn_stats()
def freeze_bn_stats(mod):
if type(mod) in set(
[ConvBnReLU1d, ConvBnReLU2d, ConvBnReLU3d, ConvBn1d, ConvBn2d, ConvBn3d]
):
if type(mod) in {ConvBnReLU1d, ConvBnReLU2d, ConvBnReLU3d, ConvBn1d, ConvBn2d, ConvBn3d}:
mod.freeze_bn_stats()

View File

@ -267,10 +267,8 @@ class RNNBase(torch.nn.Module):
@classmethod
def from_float(cls, mod):
assert type(mod) in set(
[torch.nn.LSTM,
torch.nn.GRU]
), 'nn.quantized.dynamic.RNNBase.from_float only works for nn.LSTM and nn.GRU'
assert type(mod) in {torch.nn.LSTM,
torch.nn.GRU}, 'nn.quantized.dynamic.RNNBase.from_float only works for nn.LSTM and nn.GRU'
assert hasattr(
mod,
'qconfig'
@ -823,9 +821,9 @@ class RNNCellBase(torch.nn.Module):
@classmethod
def from_float(cls, mod):
assert type(mod) in set([torch.nn.LSTMCell,
torch.nn.GRUCell,
torch.nn.RNNCell]), 'nn.quantized.dynamic.RNNCellBase.from_float \
assert type(mod) in {torch.nn.LSTMCell,
torch.nn.GRUCell,
torch.nn.RNNCell}, 'nn.quantized.dynamic.RNNCellBase.from_float \
only works for nn.LSTMCell, nn.GRUCell and nn.RNNCell'
assert hasattr(
mod, 'qconfig'), 'Input float module must have qconfig defined'

View File

@ -222,12 +222,12 @@ class OutputLogger(Logger):
def _convert_tuple_to_list(t: Any) -> Any:
return list(_convert_tuple_to_list(x) for x in t) if type(t) is tuple else t
return [_convert_tuple_to_list(x) for x in t] if type(t) is tuple else t
def _dequantize_tensor_list(t: Any) -> Any:
return (
list(_dequantize_tensor_list(x) for x in t)
[_dequantize_tensor_list(x) for x in t]
if type(t) is list
else t.dequantize()
if t.is_quantized

View File

@ -27,303 +27,303 @@ def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]:
# note: this set is modified below by items from backend_config
sets_of_related_ops: List[Set[NSNodeTargetType]] = [
# conv modules
set([
{
nn.Conv1d,
]),
set([
},
{
nn.Conv2d,
]),
set([
},
{
nn.Conv3d,
]),
},
# conv functionals
set([
{
F.conv1d,
]),
set([
},
{
F.conv2d,
]),
set([
},
{
F.conv3d,
]),
},
# linear modules
set([
{
nn.Linear,
]),
},
# linear functionals
set([
{
F.linear,
]),
},
# average pool
set([
{
nn.AvgPool1d,
torch.avg_pool1d,
]),
set([
},
{
nn.AvgPool2d,
torch._C._nn.avg_pool2d,
]),
set([
},
{
nn.AvgPool3d,
torch._C._nn.avg_pool3d,
]),
},
# adaptive average pool
set([
{
nn.AdaptiveAvgPool1d,
F.adaptive_avg_pool1d,
]),
set([
},
{
nn.AdaptiveAvgPool2d,
F.adaptive_avg_pool2d,
]),
set([
},
{
nn.AdaptiveAvgPool3d,
F.adaptive_avg_pool3d,
]),
},
# LSTM
set([
{
nn.LSTM,
]),
},
# add
set([
{
torch.add,
operator.add, # x + y
]),
},
# cat
set([
{
torch.cat,
]),
},
# mul
set([
{
torch.mul,
operator.mul,
]),
},
# relu
set([
{
F.relu,
nn.ReLU,
'relu',
'relu_',
torch.relu,
]),
},
# maxpool
set([
{
nn.MaxPool1d,
F.max_pool1d,
]),
set([
},
{
nn.MaxPool2d,
F.max_pool2d,
]),
set([
},
{
nn.MaxPool3d,
F.max_pool3d,
]),
},
# sigmoid
set([
{
torch.sigmoid,
'sigmoid',
'sigmoid_',
nn.Sigmoid,
F.sigmoid,
]),
},
# BatchNorm
set([
{
nn.BatchNorm2d,
]),
set([
},
{
nn.BatchNorm3d,
]),
},
# ConvTranspose
set([
{
nn.ConvTranspose1d,
]),
set([
},
{
nn.ConvTranspose2d,
]),
set([
},
{
nn.ConvTranspose3d,
]),
},
# ELU
set([
{
nn.ELU,
]),
},
# Embedding
set([
{
nn.Embedding,
]),
},
# EmbeddingBag
set([
{
nn.EmbeddingBag,
]),
},
# GroupNorm
set([
{
nn.GroupNorm,
]),
},
# Hardswish
set([
{
nn.Hardswish,
]),
},
# InstanceNorm
set([
{
nn.InstanceNorm1d,
]),
set([
},
{
nn.InstanceNorm2d,
]),
set([
},
{
nn.InstanceNorm3d,
]),
},
# LayerNorm
set([
{
nn.LayerNorm,
]),
},
# LeakyReLU
set([
{
nn.LeakyReLU,
]),
},
# ReLU6
set([
{
nn.ReLU6,
F.relu6,
]),
},
# F.elu
set([
{
F.elu,
]),
},
# F.hardswish
set([
{
F.hardswish,
]),
},
# F.group_norm
set([
{
F.group_norm,
]),
},
# F.instance_norm
set([
{
F.instance_norm,
]),
},
# F.layer_norm
set([
{
F.layer_norm,
]),
},
# F.leaky_relu
set([
{
F.leaky_relu,
]),
},
# F.silu
set([
{
nn.SiLU,
F.silu,
]),
},
# F.mish
set([
{
nn.Mish,
F.mish,
]),
},
# F.tanh
set([
{
nn.Tanh,
F.tanh,
torch.tanh,
'tanh_',
'tanh',
]),
},
# F.hardsigmoid
set([
{
'hardsigmoid_',
'hardsigmoid',
F.hardsigmoid,
nn.Hardsigmoid,
]),
},
# F.hardtanh
set([
{
nn.Hardtanh,
F.hardtanh,
F.hardtanh_,
]),
},
# floordiv
set([
{
operator.floordiv,
]),
},
# unsqueeze
set([
{
torch.unsqueeze,
]),
},
# stack
set([
{
torch.stack,
]),
},
# squeeze
set([
{
torch.squeeze,
]),
},
# sort
set([
{
torch.sort,
]),
},
# repeat_interleave
set([
{
torch.repeat_interleave,
]),
},
# min
set([
{
torch.min,
]),
},
# mean
set([
{
torch.mean,
]),
},
# max
set([
{
torch.max,
]),
},
# transpose
set([
{
torch.transpose,
]),
},
# flatten
set([
{
torch.flatten,
]),
},
# clamp
set([
{
torch.clamp,
]),
},
# chunk
set([
{
torch.chunk,
]),
},
# interpolate
set([
{
torch.nn.functional.interpolate,
]),
},
# dropout
set([
{
nn.Dropout,
]),
},
# F.dropout
set([
{
F.dropout,
]),
},
# matmul
set([
{
torch.matmul,
]),
},
# Softmax
set([
{
nn.Softmax,
]),
},
# PReLU
set([
{
nn.PReLU,
nnq.PReLU,
]),
},
# F.prelu
set([
{
F.prelu,
toq.prelu,
]),
},
]
# for each floating point op, add versions of the op added by
@ -453,12 +453,12 @@ def add_op_to_sets_of_related_ops(
counter = 0
while str(counter) in base_name_to_sets_of_related_ops:
counter += 1
base_name_to_sets_of_related_ops[str(counter)] = set([op])
base_name_to_sets_of_related_ops[str(counter)] = {op}
# TODO(future PR): clean this up
def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]:
FUNS_IO_TYPE_FP32: Set[NSNodeTargetType] = set([
FUNS_IO_TYPE_FP32: Set[NSNodeTargetType] = {
F.linear,
F.conv1d,
F.conv2d,
@ -478,11 +478,11 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]:
torch.mul,
torch.sum,
F.prelu,
])
}
FUNS_IO_TYPE_FP16: Set[NSNodeTargetType] = set()
FUNS_IO_TYPE_INT8: Set[NSNodeTargetType] = set([
FUNS_IO_TYPE_INT8: Set[NSNodeTargetType] = {
toq.linear,
toq.linear_relu,
toq.conv1d,
@ -503,9 +503,9 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]:
# uncomment below
# toq.add,
# toq.mul,
])
}
FUNS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = set([
FUNS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = {
F.relu,
F.tanh,
torch.tanh,
@ -541,9 +541,9 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]:
torch.stack,
torch.unsqueeze,
operator.add,
])
}
MODS_IO_TYPE_FP32: Set[NSNodeTargetType] = set([
MODS_IO_TYPE_FP32: Set[NSNodeTargetType] = {
nn.Linear,
nnqat.Linear,
nnqatd.Linear,
@ -606,9 +606,9 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]:
nni.LinearTanh,
nni.ConvAdd2d,
nni.ConvAddReLU2d,
])
}
MODS_IO_TYPE_INT8: Set[NSNodeTargetType] = set([
MODS_IO_TYPE_INT8: Set[NSNodeTargetType] = {
nnq.Linear,
nnq.Conv1d,
nnq.Conv2d,
@ -640,9 +640,9 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]:
nniq.LinearTanh,
nniq.ConvAdd2d,
nniq.ConvAddReLU2d,
])
}
MODS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = set([
MODS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = {
nn.ReLU,
nn.Tanh,
nn.Sigmoid,
@ -660,9 +660,9 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]:
nn.MaxPool2d,
nn.MaxPool3d,
nn.ReLU6,
])
}
METHS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = set([
METHS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = {
'sigmoid_',
'sigmoid',
'tanh_',
@ -671,7 +671,7 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]:
'hardsigmoid',
'relu_',
'relu',
])
}
return {
'funs_io_type_fp32': FUNS_IO_TYPE_FP32,
@ -687,16 +687,16 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]:
def get_unmatchable_types_map() -> Dict[str, Set[NSNodeTargetType]]:
FUNS_UNMATCHABLE: Set[NSNodeTargetType] = set([
FUNS_UNMATCHABLE: Set[NSNodeTargetType] = {
torch.quantize_per_tensor,
operator.getitem,
])
}
MODS_UNMATCHABLE: Set[NSNodeTargetType] = set([
MODS_UNMATCHABLE: Set[NSNodeTargetType] = {
nn.Identity,
])
}
METHS_UNMATCHABLE: Set[NSNodeTargetType] = set([
METHS_UNMATCHABLE: Set[NSNodeTargetType] = {
'to',
'dequantize',
'reshape',
@ -719,7 +719,7 @@ def get_unmatchable_types_map() -> Dict[str, Set[NSNodeTargetType]]:
'contiguous',
'clamp',
'chunk',
])
}
return {
'funs_unmatchable': FUNS_UNMATCHABLE,

View File

@ -991,9 +991,9 @@ def extract_weight_comparison(m: GraphModule) -> NSResultsType:
# use functions.
# TODO(future PR): move this to config
weighted_ops = set([
weighted_ops = {
torch.nn.functional.linear,
])
}
results: NSResultsType = {
'model': {NSSingleResultValuesType.WEIGHT.value: {}}

View File

@ -219,10 +219,10 @@ class PerChannelDetector(DetectorBase):
# Default map for representing supported per channel quantization modules for different backends
DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES: Dict[str, Set[Any]] = {
"fbgemm": set([nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d]),
"qnnpack": set([nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d]),
"onednn": set([nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d]),
"x86": set([nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d]),
"fbgemm": {nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d},
"qnnpack": {nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d},
"onednn": {nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d},
"x86": {nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d},
}
def __init__(self, backend: str = torch.backends.quantized.engine):
@ -230,7 +230,7 @@ class PerChannelDetector(DetectorBase):
# store the backend information
self.backend_chosen = backend
self.supported_modules = set([])
self.supported_modules = set()
if self.backend_chosen in self.DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES:
self.supported_modules = self.DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES[self.backend_chosen]
else:
@ -413,17 +413,17 @@ class DynamicStaticDetector(DetectorBase):
IS_CURRENTLY_SUPPORTED_KEY = "is_dynamic_supported"
# modules that are supported both dynamic and static for this report function
DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED = set([nn.Linear])
DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED = {nn.Linear}
# modules that will be supported soon for both
DEFAULT_DYNAMIC_STATIC_FUTURE_SUPPORTED = set([nn.Conv1d, nn.Conv2d, nn.Conv3d])
DEFAULT_DYNAMIC_STATIC_FUTURE_SUPPORTED = {nn.Conv1d, nn.Conv2d, nn.Conv3d}
def __init__(self, tolerance=0.5):
super().__init__()
# set tolerance level and initialize a set to keep track of useful fqn locations
self.tolerance = tolerance
self.useful_observer_fqns: Set[str] = set([])
self.useful_observer_fqns: Set[str] = set()
def determine_observer_insert_points(self, prepared_fx_model: GraphModule) -> Dict[str, Dict[str, Any]]:
r"""
@ -737,9 +737,14 @@ class InputWeightEqualizationDetector(DetectorBase):
* :attr:`DEFAULT_PRE_OBSERVER_NAME`: The name of the pre-observer to be inserted for this detector
"""
SUPPORTED_MODULES: Set[Callable] = set(
[nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d]
)
SUPPORTED_MODULES: Set[Callable] = {nn.Linear,
nn.Conv1d,
nn.Conv2d,
nn.Conv3d,
nnqat.Linear,
nnqat.Conv1d,
nnqat.Conv2d,
nnqat.Conv3d}
# names for the pre and post observers that are inserted
DEFAULT_PRE_OBSERVER_NAME: str = "model_report_pre_observer"

View File

@ -129,7 +129,7 @@ class ModelReport:
# initialize each report to have empty set of observers of interest
for desired_report in self._desired_detector_names:
self._detector_name_to_observer_fqns[desired_report] = set([])
self._detector_name_to_observer_fqns[desired_report] = set()
# flags to ensure that we can only prepare and remove observers once
self._prepared_flag = False
@ -287,7 +287,7 @@ class ModelReport:
if remove_inserted_observers:
self._removed_observers = True
# get the set of all Observers inserted by this instance of ModelReport
all_observers_of_interest: Set[str] = set([])
all_observers_of_interest: Set[str] = set()
for desired_report in self._detector_name_to_observer_fqns:
observers_of_interest = self._detector_name_to_observer_fqns[desired_report]
all_observers_of_interest.update(observers_of_interest)

View File

@ -30,7 +30,7 @@ class FusedGraphModule(GraphModule):
class ObservedGraphModule(GraphModule):
def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, preserved_attr_names: Set[str]):
self.preserved_attr_names = set([
self.preserved_attr_names = {
'_activation_post_process_map',
'_activation_post_process_indexes',
'_patterns',
@ -40,7 +40,7 @@ class ObservedGraphModule(GraphModule):
'_node_name_to_scope',
'_qconfig_mapping',
'_is_qat',
'_observed_node_names']).union(preserved_attr_names)
'_observed_node_names'}.union(preserved_attr_names)
preserved_attrs = {attr: getattr(root, attr) for attr in self.preserved_attr_names if hasattr(root, attr)}
super().__init__(root, graph)
for attr in preserved_attrs:
@ -64,9 +64,9 @@ def _get_observed_graph_module_attr(model: Union[torch.nn.Module, GraphModule],
class ObservedStandaloneGraphModule(ObservedGraphModule):
def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, preserved_attr_names: Set[str]):
preserved_attr_names = preserved_attr_names.union(set([
preserved_attr_names = preserved_attr_names.union({
"_standalone_module_input_quantized_idxs",
"_standalone_module_output_quantized_idxs"]))
"_standalone_module_output_quantized_idxs"})
super().__init__(root, graph, preserved_attr_names)
def __deepcopy__(self, memo):

View File

@ -208,10 +208,10 @@ DEFAULT_DYNAMIC_SPARSE_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = {
def no_observer_set() -> Set[Any]:
r"""These modules cannot have observers inserted by default."""
no_observers = set([
no_observers = {
nn.quantizable.LSTM,
nn.quantizable.MultiheadAttention
])
}
return no_observers
def get_default_static_quant_module_mappings() -> Dict[Callable, Any]:

View File

@ -1609,8 +1609,8 @@ def gradgradcheck(
# NB: We need to save the requires_grad information about the inputs here because gradcheck detaches inputs
# before running forward mode AD
diff_input_args_indices = set(i for i, x in enumerate(tupled_inputs) if is_tensor_like(x) and x.requires_grad)
diff_grad_output_indices = set(i for i, x in enumerate(tupled_grad_outputs) if x.requires_grad)
diff_input_args_indices = {i for i, x in enumerate(tupled_inputs) if is_tensor_like(x) and x.requires_grad}
diff_grad_output_indices = {i for i, x in enumerate(tupled_grad_outputs) if x.requires_grad}
def new_func(*args):
# Restore the requires_grad information

View File

@ -491,7 +491,7 @@ def _parse_visible_devices() -> Set[int]:
"""Parse CUDA_VISIBLE_DEVICES environment variable."""
var = os.getenv("CUDA_VISIBLE_DEVICES")
if var is None:
return set(x for x in range(64))
return set(range(64))
def _strtoul(s: str) -> int:
"""Return -1 or positive integer sequence string starts with,"""

View File

@ -85,11 +85,11 @@ def compare(before, after, format_flamegraph=format_flamegraph):
f = io.StringIO()
before_segs = set(_seg_key(seg) for seg in before)
after_segs = set(_seg_key(seg) for seg in after)
before_segs = {_seg_key(seg) for seg in before}
after_segs = {_seg_key(seg) for seg in after}
print(f'only_before = {list(a for a,_ in (before_segs - after_segs))}')
print(f'only_after = {list(a for a,_ in (after_segs - before_segs))}')
print(f'only_before = {[a for a,_ in (before_segs - after_segs)]}')
print(f'only_after = {[a for a,_ in (after_segs - before_segs)]}')
for seg in before:
if _seg_key(seg) not in after_segs:

View File

@ -383,7 +383,7 @@ class DistributedDataParallel(Module):
]
# Build list of parameters.
parameters = list(parameter for _, parameter in modules_and_parameters)
parameters = [parameter for _, parameter in modules_and_parameters]
# Checks if a module will produce a sparse gradient.
def produces_sparse_gradient(module):
@ -393,9 +393,9 @@ class DistributedDataParallel(Module):
# Build list of booleans indicating whether or not to expect sparse
# gradients for the corresponding parameters.
expect_sparse_gradient = list(
expect_sparse_gradient = [
produces_sparse_gradient(module) for module, _ in modules_and_parameters
)
]
self._assign_modules_buffers()


@ -281,7 +281,7 @@ def _handle_row_wise_sharding_tensor(
indices[placement.rank()] = list(
range(offset_start_idx, offset_start_idx + split_size)
)
indices_flatten = list(idx for indice in indices for idx in indice)
indices_flatten = [idx for indice in indices for idx in indice]
input_t = input_t.index_select(
0, torch.tensor(indices_flatten, device=input_t.device)


@ -38,10 +38,10 @@ def wrap(res: object, spec: OutputSpecType) -> object:
assert spec is not None and isinstance(
spec, list
), f"output spec does not match with output! Expected list, got {spec}."
return list(
return [
dtensor.DTensor(e, s.mesh, s.placements, size=s.shape)
for e, s in zip(res, spec)
)
]
elif isinstance(res, tuple):
assert spec is not None and isinstance(
spec, tuple


@ -397,7 +397,7 @@ def prop_index(op_schema: OpSchema) -> OutputSharding:
assert isinstance(indices_output_spec, DTensorSpec)
indices_spec = indices_output_spec
lookup_dims = set(v[0] for v in valid_indices_spec)
lookup_dims = {v[0] for v in valid_indices_spec}
need_reshard_on_values = tuple(
(isinstance(vp, Shard) and (vp.dim in lookup_dims or isinstance(ip, Shard)))


@ -370,7 +370,7 @@ def dim_transpose(ndim: int, dim1: int, dim2: int) -> DimMap:
dim2 = normalize_dim(dim2, ndim)
assert dim1 < ndim
assert dim2 < ndim
dimmap = list(InputDim(i) for i in range(ndim))
dimmap = [InputDim(i) for i in range(ndim)]
swapdim = dimmap[dim1]
dimmap[dim1] = dimmap[dim2]
dimmap[dim2] = swapdim
@ -480,7 +480,7 @@ def propagate_shape_and_sharding(
if the leftmost split size is divisible by the mesh dimension
"""
assert len(in_shard) == len(mesh_sizes)
sharded_in_dims: Set[int] = set(s.dim for s in in_shard if isinstance(s, Shard))
sharded_in_dims: Set[int] = {s.dim for s in in_shard if isinstance(s, Shard)}
# for each input dim, for each mesh dim, provides a list of possible shardable dimensions
shardable_dims: torch.Tensor = torch.ones(
(len(local_in_shape), len(mesh_sizes)), dtype=torch.bool


@ -567,12 +567,12 @@ def _get_ignored_modules(
# that this FSDP instance can get any ignored modules from its children.
# Include child modules and exclude nested FSDP modules themselves
ignored_modules = set(
ignored_modules = {
child
for module in ignored_root_modules
for child in module.modules()
if not isinstance(child, fsdp_file.FullyShardedDataParallel)
)
}
if root_module in ignored_modules:
warnings.warn(
"Trying to ignore the top-level module passed into the FSDP "
@ -599,16 +599,16 @@ def _get_ignored_params(
"""
all_ignored_params: Set[torch.nn.Parameter] = set()
params_in_ignored_modules = set(
params_in_ignored_modules = {
p for m in ignored_modules for p in m.parameters() if not _is_fsdp_flattened(p)
)
}
all_ignored_params.update(params_in_ignored_modules)
if ignored_parameters is not None:
params_in_ignored_parameters = set(
params_in_ignored_parameters = {
p for p in ignored_parameters if not _is_fsdp_flattened(p)
)
}
all_ignored_params.update(params_in_ignored_parameters)
# Include nested FSDP modules' ignored parameters
@ -626,9 +626,9 @@ def _get_buffer_names(root_module: nn.Module) -> Set[str]:
Returns the fully prefixed names of all buffers in the module hierarchy
rooted at ``root_module`` as a class:`set`.
"""
return set(
return {
clean_tensor_name(buffer_name) for buffer_name, _ in root_module.named_buffers()
)
}
def _check_single_device_module(
@ -640,7 +640,7 @@ def _check_single_device_module(
ignoring the parameters in ``ignored_params``. Thus, after this method, the
module must be either fully on the CPU or fully on a non-CPU device.
"""
devices = set(param.device for param in _get_orig_params(module, ignored_params))
devices = {param.device for param in _get_orig_params(module, ignored_params)}
if len(devices) > 1:
raise RuntimeError(
f"FSDP only supports single device modules but got params on {devices}"


@ -485,7 +485,7 @@ def _flatten_optim_state(
are_pos_dim_tensors &= torch.is_tensor(v) and v.dim() > 0
are_zero_dim_tensors &= _is_zero_dim_tensor(v)
are_non_tensors &= not torch.is_tensor(v)
types = set(type(v) for v in non_none_state_values)
types = {type(v) for v in non_none_state_values}
if len(types) != 1 or not (
are_pos_dim_tensors or are_zero_dim_tensors or are_non_tensors
):
@ -570,7 +570,7 @@ def _flatten_tensor_optim_state(
"""
non_none_tensors = [t for t in pos_dim_tensors if t is not None]
# Check that all are tensors with the same dtype
dtypes = set(t.dtype for t in non_none_tensors)
dtypes = {t.dtype for t in non_none_tensors}
if len(dtypes) != 1:
raise ValueError(
"All unflattened parameters comprising a single flattened "
@ -648,8 +648,8 @@ def _flatten_zero_dim_tensor_optim_state(
"""
non_none_tensors = [t for t in zero_dim_tensors if t is not None]
# Enforce that all have the same value and dtype
values_set = set(t.item() if t is not None else None for t in zero_dim_tensors)
dtypes = set(t.dtype if t is not None else None for t in zero_dim_tensors)
values_set = {t.item() if t is not None else None for t in zero_dim_tensors}
dtypes = {t.dtype if t is not None else None for t in zero_dim_tensors}
if (
len(non_none_tensors) != len(zero_dim_tensors)
or len(values_set) != 1
@ -1004,10 +1004,10 @@ def _rekey_sharded_optim_state_dict(
for unflat_param_group in sharded_osd["param_groups"]:
flat_param_group = copy.deepcopy(unflat_param_group)
flat_param_keys = sorted(
set(
{
unflat_param_name_to_flat_param_key[unflat_param_name]
for unflat_param_name in unflat_param_group["params"]
)
}
)
flat_param_group["params"] = flat_param_keys
rekeyed_osd_param_groups.append(flat_param_group)


@ -1068,7 +1068,7 @@ def _get_training_state(
) -> HandleTrainingState:
"""Returns the training state of the handles in ``handles_key``."""
p_assert(len(handles_key) > 0, "Expects a non-empty handles key")
training_states = set(handle._training_state for handle in handles_key)
training_states = {handle._training_state for handle in handles_key}
p_assert(
len(training_states) == 1,
f"Expects uniform training state but got {training_states}",


@ -274,8 +274,8 @@ class FlatParameter(nn.Parameter):
self._fqns = tuple(fqns)
self._shared_param_infos = tuple(shared_param_infos)
self._param_extensions = tuple(param_extensions)
self._modules = set(pi.module for pi in self._param_infos).union(
set(spi.module for spi in self._shared_param_infos)
self._modules = {pi.module for pi in self._param_infos}.union(
{spi.module for spi in self._shared_param_infos}
)
assert (params is None) == (shared_params is None)
if params is not None:
@ -1857,8 +1857,8 @@ class FlatParamHandle:
def _get_modules(self) -> Set[nn.Module]:
"""Returns a :class:`set` of the modules whose parameters are included
in this handle's flattened parameter."""
return set(pi.module for pi in self.flat_param._param_infos).union(
set(spi.module for spi in self.flat_param._shared_param_infos)
return {pi.module for pi in self.flat_param._param_infos}.union(
{spi.module for spi in self.flat_param._shared_param_infos}
)
def is_sharded(self, tensor: Tensor) -> bool:


@ -1968,7 +1968,7 @@ def _get_grad_norm(
if len(params_with_grad) == 0:
return torch.tensor(0.0)
grads = [param.grad for param in params_with_grad]
grad_dtypes = set(grad.dtype for grad in grads)
grad_dtypes = {grad.dtype for grad in grads}
if len(grad_dtypes) != 1:
raise ValueError(
f"Requires uniform dtype across all gradients but got {grad_dtypes}"


@ -54,7 +54,7 @@ def register_rendezvous_handler(scheme, handler):
# Query will have format "rank=0&world_size=1" and is
# converted into {"rank": 0, "world_size": 1}
def _query_to_dict(query: str) -> Dict[str, str]:
return dict((pair[0], pair[1]) for pair in (pair.split("=") for pair in filter(None, query.split("&"))))
return {pair[0]: pair[1] for pair in (pair.split("=") for pair in filter(None, query.split("&")))}
def _rendezvous_helper(url: str, rank: int, world_size_opt: Optional[int], **kwargs):
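
A minimal sketch of the query parsing shown in this hunk, runnable on its own; note the parsed values stay strings, callers convert them as needed:

    def _query_to_dict(query: str) -> dict:
        # "rank=0&world_size=1" -> {"rank": "0", "world_size": "1"}; empty parts are skipped.
        return {pair[0]: pair[1]
                for pair in (pair.split("=") for pair in filter(None, query.split("&")))}

    assert _query_to_dict("rank=0&world_size=1") == {"rank": "0", "world_size": "1"}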


@ -275,7 +275,7 @@ def check_dependency(partition):
"""Given a partition,check if there is a circular dependency on
this partition using bfs
"""
visited: Set[Partition] = set([partition])
visited: Set[Partition] = {partition}
queue: Deque[Partition] = deque([partition])
while queue:
p = queue.popleft()
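
A hedged sketch of the BFS cycle check the docstring describes, written against a toy Partition class with a children set rather than the real fx partition type:

    from collections import deque

    class Partition:
        def __init__(self, name):
            self.name = name
            self.children = set()

    def check_dependency(partition):
        # Walk children breadth-first; a path back to the start is a circular dependency.
        visited = {partition}
        queue = deque([partition])
        while queue:
            p = queue.popleft()
            for child in p.children:
                if child is partition:
                    return True
                if child not in visited:
                    visited.add(child)
                    queue.append(child)
        return False

    a, b = Partition("a"), Partition("b")
    a.children.add(b)
    b.children.add(a)  # introduce a cycle a -> b -> a
    assert check_dependency(a)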


@ -30,7 +30,7 @@ _reify
@dispatch(dict, dict) # type: ignore[no-redef]
def _reify(d, s):
return dict((k, reify(v, s)) for k, v in d.items())
return {k: reify(v, s) for k, v in d.items()}
_reify
@dispatch(object, dict) # type: ignore[no-redef]


@ -55,7 +55,7 @@ class VarDispatcher(Dispatcher):
"""
def __call__(self, *args, **kwargs):
func, s = self.resolve(args)
d = dict((k.token, v) for k, v in s.items())
d = {k.token: v for k, v in s.items()}
return func(**d)
@ -86,7 +86,7 @@ def supercedes(a, b):
s = unify(a, b)
if s is False:
return False
s = dict((k, v) for k, v in s.items() if not isvar(k) or not isvar(v))
s = {k: v for k, v in s.items() if not isvar(k) or not isvar(v)}
if reify(a, s) == a:
return True
if reify(b, s) == b:
@ -117,5 +117,5 @@ def ordering(signatures):
for s in signatures:
if s not in edges:
edges[s] = []
edges = dict((k, [b for a, b in v]) for k, v in edges.items()) # type: ignore[attr-defined, assignment]
edges = {k: [b for a, b in v] for k, v in edges.items()} # type: ignore[attr-defined, assignment]
return _toposort(edges)


@ -80,11 +80,11 @@ def ambiguous(a, b):
def ambiguities(signatures):
""" All signature pairs such that A is ambiguous with B """
signatures = list(map(tuple, signatures))
return set((a, b) for a in signatures for b in signatures
if hash(a) < hash(b)
and ambiguous(a, b)
and not any(supercedes(c, a) and supercedes(c, b)
for c in signatures))
return {(a, b) for a in signatures for b in signatures
if hash(a) < hash(b)
and ambiguous(a, b)
and not any(supercedes(c, a) and supercedes(c, b)
for c in signatures)}
def super_signature(signatures):
@ -92,7 +92,7 @@ def super_signature(signatures):
n = len(signatures[0])
assert all(len(s) == n for s in signatures)
return [max([type.mro(sig[i]) for sig in signatures], key=len)[0]
return [max((type.mro(sig[i]) for sig in signatures), key=len)[0]
for i in range(n)]
@ -115,5 +115,5 @@ def ordering(signatures):
for s in signatures:
if s not in edges:
edges[s] = []
edges = dict((k, [b for a, b in v]) for k, v in edges.items()) # type: ignore[assignment, attr-defined]
edges = {k: [b for a, b in v] for k, v in edges.items()} # type: ignore[assignment, attr-defined]
return _toposort(edges)


@ -45,8 +45,8 @@ def _toposort(edges):
[2] http://en.wikipedia.org/wiki/Toposort#Algorithms
"""
incoming_edges = reverse_dict(edges)
incoming_edges = dict((k, set(val)) for k, val in incoming_edges.items())
S = set((v for v in edges if v not in incoming_edges))
incoming_edges = {k: set(val) for k, val in incoming_edges.items()}
S = ({v for v in edges if v not in incoming_edges})
L = []
while S:
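
For context, a compact sketch of the Kahn-style topological sort this hunk belongs to, using the same comprehension idioms; reverse_dict is reimplemented here only for the example:

    def reverse_dict(d):
        # {"a": ["b", "c"]} -> {"b": ("a",), "c": ("a",)}
        result = {}
        for key, vals in d.items():
            for val in vals:
                result[val] = result.get(val, ()) + (key,)
        return result

    def _toposort(edges):
        incoming_edges = {k: set(val) for k, val in reverse_dict(edges).items()}
        S = {v for v in edges if v not in incoming_edges}
        L = []
        while S:
            n = S.pop()
            L.append(n)
            for m in edges.get(n, ()):
                incoming_edges[m].discard(n)
                if not incoming_edges[m]:
                    S.add(m)
        if any(incoming_edges.values()):
            raise ValueError("input has cycles")
        return L

    assert _toposort({1: (2, 3), 2: (3,)}) == [1, 2, 3]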


@ -11,9 +11,9 @@ aten = torch.ops.aten
# stateful ops are banned from CSE
rand_ops = set([aten.dropout, aten._fused_dropout, aten._standard_gamma, aten.bernoulli, aten.multinomial, aten.native_dropout, aten.normal, aten.poisson, aten.binomial, aten.rrelu, aten.rand_like, aten.rand, aten.randint, aten.randn, aten.randperm]) # noqa: E501
rand_ops = {aten.dropout, aten._fused_dropout, aten._standard_gamma, aten.bernoulli, aten.multinomial, aten.native_dropout, aten.normal, aten.poisson, aten.binomial, aten.rrelu, aten.rand_like, aten.rand, aten.randint, aten.randn, aten.randperm} # noqa: E501
inplace_ops = set([aten.add_, aten.sub_, aten.mul_, aten.div_, aten.pow_, aten.lerp_, aten.relu_, aten.sigmoid_, aten.tanh_]) # noqa: E501
inplace_ops = {aten.add_, aten.sub_, aten.mul_, aten.div_, aten.pow_, aten.lerp_, aten.relu_, aten.sigmoid_, aten.tanh_} # noqa: E501
@torch.fx._compatibility.compatibility(is_backward_compatible=False)


@ -468,10 +468,10 @@ def reinplace(gm, *sample_args):
# so we know not to re-inplace them.
# NOTE: later, we'll need to add an optimization for fully recovering performance
# on programs that mutate inputs.
input_storages = set(
input_storages = {
StorageWeakRef(
node.meta['fake_result']._typed_storage()
) for node in gm.graph.nodes if node.op == 'placeholder')
) for node in gm.graph.nodes if node.op == 'placeholder'}
# We also need to know for a given node, what are all of its aliasing nodes.
@ -627,14 +627,14 @@ def reinplace(gm, *sample_args):
old_flattened_res, _ = tree_flatten(old.meta['fake_result'])
node_flattened_res, _ = tree_flatten(node_to_update.meta['fake_result'])
old_res_storage = set(
old_res_storage = {
StorageWeakRef(
x._typed_storage()
) for x in old_flattened_res if isinstance(x, FakeTensor))
node_res_storage = set(
) for x in old_flattened_res if isinstance(x, FakeTensor)}
node_res_storage = {
StorageWeakRef(
x._typed_storage()
) for x in node_flattened_res if isinstance(x, FakeTensor))
) for x in node_flattened_res if isinstance(x, FakeTensor)}
# This will happen if we're updating a view op, e.g.
# e.g. replacing
@ -648,10 +648,10 @@ def reinplace(gm, *sample_args):
# We can't just check equality because we might encounter FX nodes that return zero tensor outputs.
if len(old_res_storage) == 1 and len(node_res_storage) == 1 and old_res_storage == node_res_storage:
new_flattened_res, _ = tree_flatten(new.meta['fake_result'])
new_res_storage = set(
new_res_storage = {
StorageWeakRef(
x._typed_storage()
) for x in new_flattened_res if isinstance(x, FakeTensor))
) for x in new_flattened_res if isinstance(x, FakeTensor)}
assert len(new_res_storage) == 1
(old_ref,) = old_res_storage
(new_ref,) = new_res_storage


@ -229,7 +229,7 @@ def generate_inputs_for_submodules(
handles = []
results = {}
submodule_to_names = dict((mod, name) for name, mod in model.named_modules())
submodule_to_names = {mod: name for name, mod in model.named_modules()}
def pre_forward(module, module_inputs):
results[submodule_to_names[module]] = copy.deepcopy(module_inputs) if deepcopy else module_inputs


@ -117,7 +117,7 @@ def _gen_torch_functional_registered_ops():
# some functions directly map to their aten:: implementations.
# TODO: add support for more ops
ops = ["stft", "istft", "lu", "cdist", "norm", "unique", "unique_consecutive", "tensordot"]
return set(getattr(torch.functional, name) for name in ops)
return {getattr(torch.functional, name) for name in ops}
_functional_registered_ops = _gen_torch_functional_registered_ops()


@ -89,7 +89,7 @@ def jit_ignored_properties(module):
user_annotated_ignored_attributes = getattr(module, "__jit_ignored_attributes__", list())
def get_properties_names(module):
return set(k for k, v in vars(module).items() if isinstance(v, property))
return {k for k, v in vars(module).items() if isinstance(v, property)}
properties = get_properties_names(type(module))
user_annoted_ignored_properties = set()


@ -352,7 +352,7 @@ def try_ann_to_type(ann, loc):
return OptionalType(valid_type)
if is_union(ann):
# TODO: this is hack to recognize NumberType
if set(ann.__args__) == set([int, float, complex]):
if set(ann.__args__) == {int, float, complex}:
return NumberType.get()
inner: List = []
# We need these extra checks because both `None` and invalid


@ -14,7 +14,7 @@ def _gen_unsupported_methods_properties():
return x.{op}()
''')
deprecated_apis = set(["volatile", "resize", "reinforce", "new", "name", "map2_", "has_names", "grad_fn", "resize_as"])
deprecated_apis = {"volatile", "resize", "reinforce", "new", "name", "map2_", "has_names", "grad_fn", "resize_as"}
tensor_attrs = tensor_attrs - deprecated_apis
properties = []


@ -378,11 +378,11 @@ defined as ``prod(x[:i])``.""",
)
# Apply function name info to docstring templates:
templates = dict(
(k, v.format_map(template_data))
templates = {
k: v.format_map(template_data)
for k, v in docstring_templates.items()
if k.startswith(op_kind)
)
}
templates.update(
(k, v.format_map(template_data) if isinstance(v, str) else v)
for k, v in template_data.items()


@ -90,7 +90,7 @@ def _map_mt_args_kwargs(args, kwargs, map_fn):
def _wrap_result(result_data, result_mask):
if isinstance(result_data, list):
return list(_wrap_result(r, m) for (r, m) in zip(result_data, result_mask))
return [_wrap_result(r, m) for (r, m) in zip(result_data, result_mask)]
if isinstance(result_data, tuple):
return tuple(_wrap_result(r, m) for (r, m) in zip(result_data, result_mask))
if torch.is_tensor(result_data):


@ -173,7 +173,7 @@ class RNNBase(Module):
# a sufficient check, because overlapping parameter buffers that don't completely
# alias would break the assumptions of the uniqueness check in
# Module.named_parameters().
unique_data_ptrs = set(p.data_ptr() for p in self._flat_weights)
unique_data_ptrs = {p.data_ptr() for p in self._flat_weights}
if len(unique_data_ptrs) != len(self._flat_weights):
return
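
A tiny sketch of the aliasing check described above, applied to plain tensors instead of the RNN's _flat_weights:

    import torch

    weights = [torch.randn(4), torch.randn(4)]
    weights.append(weights[0])  # an aliased entry sharing storage with weights[0]

    # Distinct data pointers mean no two entries share the same underlying memory.
    unique_data_ptrs = {w.data_ptr() for w in weights}
    print(len(unique_data_ptrs) != len(weights))  # True: aliasing detected, so the fast path bails out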


@ -929,7 +929,7 @@ class DistributedDataParallel(Module, Joinable):
]
# Build list of parameters.
parameters = list(parameter for _, parameter in modules_and_parameters)
parameters = [parameter for _, parameter in modules_and_parameters]
# Checks if a module will produce a sparse gradient.
def produces_sparse_gradient(module):
@ -939,10 +939,10 @@ class DistributedDataParallel(Module, Joinable):
# Build list of booleans indicating whether or not to expect sparse
# gradients for the corresponding parameters.
expect_sparse_gradient = list(
expect_sparse_gradient = [
produces_sparse_gradient(module)
for module, _ in modules_and_parameters
)
]
self._assign_modules_buffers()


@ -296,7 +296,7 @@ class NamedMemberAccessor:
Check that the given keys are valid.
"""
keys = set(keys)
valid_keys = set(name for name, _ in self.named_tensors(remove_duplicate=False))
valid_keys = {name for name, _ in self.named_tensors(remove_duplicate=False)}
missing_keys = valid_keys - keys
unexpected_keys = keys - valid_keys
return sorted(missing_keys), sorted(unexpected_keys)
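
A minimal standalone sketch of the missing/unexpected key computation above, with a hand-written valid key set in place of named_tensors():

    def check_keys(keys, valid_keys):
        # Report which expected names are absent and which supplied names are unknown.
        keys = set(keys)
        valid_keys = set(valid_keys)
        missing_keys = valid_keys - keys
        unexpected_keys = keys - valid_keys
        return sorted(missing_keys), sorted(unexpected_keys)

    assert check_keys(["weight", "extra"], ["weight", "bias"]) == (["bias"], ["extra"])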


@ -197,7 +197,7 @@ class DiagnosticContext:
def sarif(self) -> sarif.Run:
"""Returns the SARIF Run object."""
unique_rules = set(diagnostic.rule for diagnostic in self.diagnostics)
unique_rules = {diagnostic.rule for diagnostic in self.diagnostics}
return sarif.Run(
tool=sarif.Tool(
driver=sarif.ToolComponent(


@ -914,7 +914,7 @@ def verify_aten_graph(
graph = graph.copy()
# Execute aten graph and get reference torch jit outputs.
graph_inputs = list(v for v in graph.inputs())
graph_inputs = list(graph.inputs())
jit_inputs = tuple([arg for arg in input_args if arg is not None])
weights = [params_dict[v.debugName()] for v in graph_inputs[len(jit_inputs) :]]
assert all([w is not None for w in weights])
@ -940,7 +940,7 @@ def verify_aten_graph(
# NOTE: Verification is unstable. Try catch to emit information for debugging.
try:
# NOTE: Input might be dce'ed, so we need to remove those from the input args.
new_input_names = set(v.debugName() for v in graph.inputs())
new_input_names = {v.debugName() for v in graph.inputs()}
new_input_args = []
for v, arg in zip(original_jit_graph.inputs(), input_args):
if v.debugName() in new_input_names:


@ -7919,9 +7919,7 @@ def sample_inputs_max_unpool(op_info, device, dtype, requires_grad, **kwargs):
'nn.functional.max_unpool3d': 3
}
unpool_to_pool_name_dict = dict((
(k, f'nn.functional.{v.__name__}') for k, v in unpool_name_to_pool_method_dict.items()
))
unpool_to_pool_name_dict = {k: f'nn.functional.{v.__name__}' for k, v in unpool_name_to_pool_method_dict.items()}
pool_dim = unpool_name_to_dim[op_info.name]
pool_method = unpool_name_to_pool_method_dict[op_info.name]


@ -507,7 +507,7 @@ def check_forward_ad_formula(op: Callable, args, kwargs, gradcheck_wrapper=None,
if isinstance(t, torch.Tensor) and t.requires_grad:
return torch.randn_like(t)
elif is_tensorlist(t):
return list(torch.randn_like(e) if e.requires_grad else None for e in t)
return [torch.randn_like(e) if e.requires_grad else None for e in t]
return None
tangent_args = tuple(maybe_tangent(arg) for arg in args)

Some files were not shown because too many files have changed in this diff.