[BE] Apply almost all remaining flake8-comprehension checks (#94676)
Applies the remaining flake8-comprehensions fixes and checks. This change replaces all remaining unnecessary generator expressions with list/dict/set comprehensions, which are more succinct, more performant, and better supported by the torch.jit compiler. It also removes useless generators such as `set(a for a in b)`, resolving them into the direct call.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94676
Approved by: https://github.com/ezyang
parent 54c0f37646
commit 67d9790985
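The rewrites in the diff below are mechanical and fall into a handful of flake8-comprehensions categories. A minimal sketch of each pattern follows; the names `items` and `pairs` are illustrative and do not come from the diff:

items = range(5)
pairs = [("a", 1), ("b", 2)]

# C400: unnecessary generator passed to list()
assert list(x * 2 for x in items) == [x * 2 for x in items]

# C401: unnecessary generator passed to set()
assert set(x * 2 for x in items) == {x * 2 for x in items}

# C402: unnecessary generator passed to dict()
assert dict((k, v) for k, v in pairs) == {k: v for k, v in pairs}

# C405: set() called on a list literal
assert set(["a", "b"]) == {"a", "b"}

# useless inner generator: the callee accepts the iterable directly
assert list(x for x in range(15)) == list(range(15))

The comprehension and literal forms skip creating and driving a generator object, which is where the modest speedup claimed in the commit message comes from; they are also parsed directly by the TorchScript compiler.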
.flake8 (2 lines changed)
@@ -11,7 +11,7 @@ ignore =
     # these ignores are from flake8-bugbear; please fix!
     B007,B008,
     # these ignores are from flake8-comprehensions; please fix!
-    C400,C401,C402,C405,C407
+    C407
 per-file-ignores =
     __init__.py: F401
     torch/utils/cpp_extension.py: B950
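To reproduce the comprehension lint locally, a sketch along these lines should work; it assumes flake8 and the flake8-comprehensions plugin are installed in the environment (e.g. via pip), which this page does not show:

import subprocess

# Restrict the run to the flake8-comprehensions codes (they share the C4 prefix);
# the repository's .flake8 ignore list above now keeps only C407.
subprocess.run(["flake8", "--select=C4", "torch", "test"], check=False)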
@@ -25,7 +25,7 @@ def main():
     ja = load(args.file[0])
     jb = load(args.file[1])

-    keys = (set(ja.keys()) | set(jb.keys())) - set(["benchmark_results"])
+    keys = (set(ja.keys()) | set(jb.keys())) - {"benchmark_results"}
     print("{:20s} {:>20s} {:>20s}".format("", "baseline", "test"))
     print("{:20s} {:>20s} {:>20s}".format("", "-" * 20, "-" * 20))
     for key in sorted(keys):

@@ -39,7 +39,7 @@ def get_content(submod):
     return content

 def namespace_filter(data):
-    out = set(d for d in data if d[0] != "_")
+    out = {d for d in data if d[0] != "_"}
     return out

 def run(args, submod):
@@ -417,7 +417,7 @@ class TestNearlyDiagonalSparsifier(TestCase):
             assert torch.all(weights == torch.eye(height, width) * weights) # only diagonal to be present

     def test_sparsity_levels(self):
-        nearliness_levels = list(nearliness for nearliness in range(-1, 100))
+        nearliness_levels = list(range(-1, 100))
         model = nn.Sequential()

         p = re.compile(r'[-\.\s]')
@@ -244,9 +244,9 @@ class TestFSDPIgnoredModules(FSDPTest):
             {"ignored_modules": layer1_ignored_modules}
             if ignore_modules
             else {
-                "ignored_parameters": set(
+                "ignored_parameters": {
                     p for m in layer1_ignored_modules for p in m.parameters()
-                )
+                }
             }
         )
         model.layer1 = FSDP(model.layer1, **ignore_kwargs)

@@ -260,9 +260,9 @@ class TestFSDPIgnoredModules(FSDPTest):
             {"ignored_modules": model_ignored_modules}
             if ignore_modules
             else {
-                "ignored_parameters": set(
+                "ignored_parameters": {
                     p for m in model_ignored_modules for p in m.parameters()
-                )
+                }
             }
         )
         wrapped_model = FSDP(model, **ignore_kwargs_top)

@@ -279,9 +279,9 @@ class TestFSDPIgnoredModules(FSDPTest):
             {"ignored_modules": ignored_modules}
             if ignore_modules
             else {
-                "ignored_parameters": set(
+                "ignored_parameters": {
                     p for m in ignored_modules for p in m.parameters()
-                )
+                }
             }
         )

@@ -783,9 +783,7 @@ class TestFSDPStateDict(FSDPTest):
     def test_fsdp_state_dict_keys(self, state_dict_type):
         state_dict = self._state_dict(self._initialize_model(True), state_dict_type)
         if state_dict_type == "local_state_dict":
-            self.assertEqual(
-                set([FLAT_PARAM, f"inner.{FLAT_PARAM}"]), state_dict.keys()
-            )
+            self.assertEqual({FLAT_PARAM, f"inner.{FLAT_PARAM}"}, state_dict.keys())
         elif state_dict_type in ("state_dict", "sharded_state_dict"):
             # Keys should match local model.
             local_model = self._initialize_model(wrap_fsdp=False, wrap_ddp=False)
@@ -66,8 +66,8 @@ class TestUtils(TestCase):
         # create a mixed bag of data.
         data = [1, "str"]
         data.append({"key1": get_a_tensor(), "key2": {1: get_a_tensor()}, "key3": 3})
-        data.insert(0, set(["x", get_a_tensor(), get_a_tensor()]))
-        data.append(([1], get_a_tensor(), (1), [get_a_tensor()], set((1, 2))))
+        data.insert(0, {"x", get_a_tensor(), get_a_tensor()})
+        data.append(([1], get_a_tensor(), (1), [get_a_tensor()], {1, 2}))
         data.append({"abc": SomeDataClass("some_key", 1.0, [get_a_tensor()])})
         od = OrderedDict()
         od["k"] = "value"
@@ -662,7 +662,7 @@ def test_named_children(setup_rpc):
     model = nn.Sequential(OrderedDict([("a", a), ("b", b)]))
     model = Pipe(model)

-    names = set(n for n, _ in model.named_modules())
+    names = {n for n, _ in model.named_modules()}
     assert "partitions.0.0" in names
     assert "partitions.1.0" in names

@@ -1120,7 +1120,7 @@ class AbstractCommTest:
         )
         self._test_sequence_num_incremented(
             c10d._get_default_group(),
-            ranks=list(i for i in range(dist.get_world_size())),
+            ranks=list(range(dist.get_world_size())),
         )

     def _test_sequence_num_incremented_subgroup(self, backend_name):
@@ -2296,9 +2296,9 @@ class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
         # The tensors to pass to broadcast are identical to the target
         # only on the process that is the root of the broadcast.
         if self.rank == root_rank:
-            tensors = list(tensor.clone() for tensor in target)
+            tensors = [tensor.clone() for tensor in target]
         else:
-            tensors = list(torch.zeros_like(tensor) for tensor in target)
+            tensors = [torch.zeros_like(tensor) for tensor in target]

         if self.rank != root_rank:
             self.assertNotEqual(tensors, target)

@@ -2623,9 +2623,9 @@ class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
         # The tensors to pass to broadcast are idential to the target
         # only on the process that is the root of the broadcast.
         if self.rank == root_rank:
-            tensors = list(tensor.clone() for tensor in target)
+            tensors = [tensor.clone() for tensor in target]
         else:
-            tensors = list(torch.zeros_like(tensor) for tensor in target)
+            tensors = [torch.zeros_like(tensor) for tensor in target]

         if self.rank != root_rank:
             self.assertNotEqual(tensors, target)
@@ -55,15 +55,13 @@ class OptimizerTests(torch._dynamo.test_case.TestCase):

 # exclude SparseAdam because other areas of the stack don't support it yet
 # the others are handled specially above
-exclude = set(
-    [
-        "SGD",  # Handled above
-        "Optimizer",
-        "SparseAdam",  # Unsupported
-        "LBFGS",  # Unsupported
-        "RAdam",  # Has data dependent control for rectification (needs symint)
-    ]
-)
+exclude = {
+    "SGD",  # Handled above
+    "Optimizer",
+    "SparseAdam",  # Unsupported
+    "LBFGS",  # Unsupported
+    "RAdam",  # Has data dependent control for rectification (needs symint)
+}

 optimizers = [
     opt
@@ -649,7 +649,9 @@ def _get_min_chunk_len(config):
         return config.lsh_attn_chunk_length
     elif len(attn_types_set) == 1 and attn_types[0] == "local":
         return config.local_attn_chunk_length
-    elif len(attn_types_set) == 2 and attn_types_set == set(["lsh", "local"]):
+    elif len(attn_types_set) == 2 and attn_types_set == set(  # noqa: C405
+        ["lsh", "local"]
+    ):
         return min(config.lsh_attn_chunk_length, config.local_attn_chunk_length)
     else:
         raise NotImplementedError(
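The hunk above is the one place that keeps the old `set([...])` spelling and silences the check instead of rewriting it. A minimal sketch of that suppression pattern, reusing the lines from the hunk:

# C405 is the "set() on a literal" check; flake8 honors the suppression on the
# physical line where it reports the error, i.e. the line with the set( call.
attn_types_set = set(  # noqa: C405
    ["lsh", "local"]
)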
@@ -803,7 +803,7 @@ class OperatorSet:
     def query(self, operator_method, filter=(Support.NO, Support.YES, Support.UNKNOWN)):
         result = {}
         for key in filter:
-            result[key] = set([])
+            result[key] = set()
         for op in self.data:
             support_status = operator_method(op)
             if support_status in filter:
@@ -158,20 +158,20 @@ class TestTensorBuiltins(JitTestCase):
             return x.{}
         """

-        EQUALITY_MISMATCH = set([
+        EQUALITY_MISMATCH = {
             # TorchScript doesn't have real enums so they return an int instead
             # of the actual value
             'dtype',
             'layout',
-        ])
-        MISSING_PROPERTIES = set([
+        }
+        MISSING_PROPERTIES = {
             'grad_fn',
             # This is an undocumented property so it's not included
             "output_nr",
             # This has a longer implementation, maybe not worth copying to
             # TorchScript if named tensors don't work there anyways
             'names',
-        ])
+        }

         for p in properties:
             if p in MISSING_PROPERTIES:
@@ -1516,7 +1516,7 @@ class TestDict(JitTestCase):
             li.append(3)
             return li

-        self.assertTrue(set(specialized_list()) == set([1, 2, 3]))
+        self.assertTrue(set(specialized_list()) == {1, 2, 3})

     @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason")
     def test_values(self):
@@ -221,11 +221,11 @@ class TestMisc(JitTestCase):

         torch._C._enable_mobile_interface_call_export()
         scripted_M_mod = torch.jit.script(M())
-        self.assertTrue(set(['aten::mul.Scalar', 'aten::mul.Tensor', 'aten::reciprocal']).issubset(
+        self.assertTrue({'aten::mul.Scalar', 'aten::mul.Tensor', 'aten::reciprocal'}.issubset(
             set(torch.jit.export_opnames(scripted_M_mod))))

         scripted_M_mod.sub = torch.jit.script(FooMod())
-        self.assertTrue(set(['aten::add.Tensor', 'aten::mul.Scalar']).issubset(
+        self.assertTrue({'aten::add.Tensor', 'aten::mul.Scalar'}.issubset(
             set(torch.jit.export_opnames(scripted_M_mod))))

     def test_math_inf(self):
@@ -525,8 +525,8 @@ class TestSaveLoad(JitTestCase):
             len(list(m.named_modules())), len(list(m_loaded.named_modules()))
         )
         self.assertEqual(
-            set(name for name, _ in m.named_modules()),
-            set(name for name, _ in m_loaded.named_modules()),
+            {name for name, _ in m.named_modules()},
+            {name for name, _ in m_loaded.named_modules()},
         )
         # Check parameters.
         m_params = dict(m.named_parameters())
@@ -133,7 +133,7 @@ class TestSlice(JitTestCase):
         self.assertEqual(scripted_fn(torch.tensor(1)), (2, 3))
         tuple_graph = scripted_fn.graph
         slices = tuple_graph.findAllNodes("prim::TupleConstruct")
-        num_outputs = set(len(x.output().type().elements()) for x in slices)
+        num_outputs = {len(x.output().type().elements()) for x in slices}
         # there should be only one tupleSlice with length of 2
         self.assertTrue(num_outputs == {2})
         self.run_pass('lower_all_tuples', tuple_graph)
@@ -34,8 +34,8 @@ def init_lists():
         yaml_ts = yaml.load(f, yaml.Loader)
     LAZY_OPS_LIST = set(remove_suffixes(itertools.chain(yaml_ts["full_codegen"], yaml_ts["supported"], yaml_ts["autograd"])))
     HAS_SYMINT_SUFFIX = yaml_ts["symint"]
-    FALLBACK_LIST = set(["clamp"])
-    SKIP_RUNTIME_ERROR_LIST = set([
+    FALLBACK_LIST = {"clamp"}
+    SKIP_RUNTIME_ERROR_LIST = {
         'index_select', # Empty output_sizes is not supported
         'clone', # is clone decomposed?

@@ -46,19 +46,19 @@ def init_lists():
         'all', # ASAN failure
         'any', # ASAN failure
         'logdet', # ASAN failure
-    ])
-    SKIP_INCORRECT_RESULTS_LIST = set([
+    }
+    SKIP_INCORRECT_RESULTS_LIST = {
         'squeeze', # Value out of range
         't', # Value out of range
         'transpose', # Value out of range
         'bernoulli', # incorrect results
         'pow', # incorrect results
         'addcdiv', # incorrect results (on CI not locally?)
-    ])
+    }
     # The following ops all show up directly in ts_native_functions.yaml,
     # but run functionalized versions of the composite kernels in core.
     # This means that we don't expect the ops to show directly in the LTC metrics.
-    FUNCTIONAL_DECOMPOSE_LIST = set([
+    FUNCTIONAL_DECOMPOSE_LIST = {
         'diag_embed',
         'block_diag',
         'new_empty_strided',

@@ -70,13 +70,13 @@ def init_lists():
         'linalg_inv_ex',
         'linalg_pinv.atol_rtol_tensor',
         'logsumexp',
-    ])
+    }
     # For some ops, we don't support all variants. Here we use formatted_name
     # to uniquely identify the variant.
-    SKIP_VARIANT_LIST = set([
+    SKIP_VARIANT_LIST = {
         'norm_nuc',
         'min_reduction_with_dim'
-    ])
+    }

     return (LAZY_OPS_LIST,
             FALLBACK_LIST,
@@ -31,7 +31,7 @@ class TestDependencyHooks(PackageTestCase):
         exporter.register_extern_hook(my_extern_hook)
         exporter.save_source_string("foo", "import module_a")

-        self.assertEqual(my_externs, set(["module_a"]))
+        self.assertEqual(my_externs, {"module_a"})

     def test_multiple_extern_hooks(self):
         buffer = BytesIO()

@@ -93,7 +93,7 @@ class TestDependencyHooks(PackageTestCase):
         exporter.save_source_string("foo", "import module_a")

         self.assertEqual(my_externs, set())
-        self.assertEqual(my_externs2, set(["module_a"]))
+        self.assertEqual(my_externs2, {"module_a"})

     def test_extern_and_mock_hook(self):
         buffer = BytesIO()

@@ -114,8 +114,8 @@ class TestDependencyHooks(PackageTestCase):
         exporter.register_mock_hook(my_mock_hook)
         exporter.save_source_string("foo", "import module_a; import package_a")

-        self.assertEqual(my_externs, set(["module_a"]))
-        self.assertEqual(my_mocks, set(["package_a"]))
+        self.assertEqual(my_externs, {"module_a"})
+        self.assertEqual(my_mocks, {"package_a"})


 if __name__ == "__main__":
@@ -82,7 +82,7 @@ class TestDiGraph(PackageTestCase):
         for n in g:
             nodes.add(n)

-        self.assertEqual(nodes, set([1, 2, 3]))
+        self.assertEqual(nodes, {1, 2, 3})

     def test_contains(self):
         g = DiGraph()

@@ -101,8 +101,8 @@ class TestDiGraph(PackageTestCase):
         g.add_edge("2", "3")
         g.add_edge("5", "4")
         g.add_edge("4", "3")
-        self.assertTrue(g.forward_transitive_closure("1") == set(["1", "2", "3"]))
-        self.assertTrue(g.forward_transitive_closure("4") == set(["4", "3"]))
+        self.assertTrue(g.forward_transitive_closure("1") == {"1", "2", "3"})
+        self.assertTrue(g.forward_transitive_closure("4") == {"4", "3"})

     def test_all_paths(self):
         g = DiGraph()
@@ -2443,7 +2443,7 @@ class TestQuantizedOps(TestCase):
         affine_list = (True, False)
         combined = [shape_list, torch_types, y_scales, y_zero_points, channels_last_list, affine_list]
         test_cases_product = itertools.product(*combined)
-        test_cases = list(test_case for test_case in test_cases_product)
+        test_cases = list(test_cases_product)
         # add just one test case to test overflow
         test_cases.append([
             [1, 4, 224, 224, 160],  # shape,
@@ -95,8 +95,8 @@ class TestModelNumericsEager(QuantizationTestCase):
         torch.manual_seed(67)
         calib_data = torch.rand(2048, 3, 15, 15, dtype=torch.float32)
         eval_data = torch.rand(10, 3, 15, 15, dtype=torch.float32)
-        qconfigset = set([torch.ao.quantization.default_weight_only_qconfig,
-                          torch.ao.quantization.default_activation_only_qconfig])
+        qconfigset = {torch.ao.quantization.default_weight_only_qconfig,
+                      torch.ao.quantization.default_activation_only_qconfig}
         SQNRTarget = [35, 45]
         for idx, qconfig in enumerate(qconfigset):
             my_model = ModelMultipleOpsNoAvgPool().to(torch.float32)
@@ -1120,7 +1120,7 @@ class TestQuantizeEagerPTQDynamic(QuantizationTestCase):

         # Test set qconfig
         model = SingleLayerLinearDynamicModel()
-        quantize_dynamic(model, set([nn.Linear]), inplace=True, dtype=dtype)
+        quantize_dynamic(model, {nn.Linear}, inplace=True, dtype=dtype)
         checkQuantized(model)

     def test_two_layers(self):
@@ -895,7 +895,7 @@ class TestFxModelReportClass(QuantizationTestCase):
             model_prep = quantize_fx.prepare_fx(model, q_config_mapping, model.get_example_inputs()[0])

             # make an example set of detectors
-            test_detector_set = set([DynamicStaticDetector(), PerChannelDetector(backend)])
+            test_detector_set = {DynamicStaticDetector(), PerChannelDetector(backend)}
             # initialize with an empty detector
             model_report = ModelReport(model_prep, test_detector_set)

@@ -905,7 +905,7 @@ class TestFxModelReportClass(QuantizationTestCase):

             # now attempt with no valid reports, should raise error
             with self.assertRaises(ValueError):
-                model_report = ModelReport(model, set([]))
+                model_report = ModelReport(model, set())

             # number of expected obs of interest entries
             num_expected_entries = len(test_detector_set)

@@ -932,7 +932,7 @@ class TestFxModelReportClass(QuantizationTestCase):
         # make an example set of detectors
         torch.backends.quantized.engine = "fbgemm"
         backend = torch.backends.quantized.engine
-        test_detector_set = set([DynamicStaticDetector(), PerChannelDetector(backend)])
+        test_detector_set = {DynamicStaticDetector(), PerChannelDetector(backend)}
         # initialize with an empty detector

         # prepare the model

@@ -1029,8 +1029,8 @@ class TestFxModelReportClass(QuantizationTestCase):
         torch.backends.quantized.engine = "fbgemm"

         # check whether the correct number of reports are being generated
-        filled_detector_set = set([DynamicStaticDetector(), PerChannelDetector(torch.backends.quantized.engine)])
-        single_detector_set = set([DynamicStaticDetector()])
+        filled_detector_set = {DynamicStaticDetector(), PerChannelDetector(torch.backends.quantized.engine)}
+        single_detector_set = {DynamicStaticDetector()}

         # create our models
         model_full = TwoThreeOps()

@@ -1316,7 +1316,7 @@ class TestFxDetectInputWeightEqualization(QuantizationTestCase):
         # then create model report instance with detector
         with override_quantized_engine('fbgemm'):

-            detector_set = set([InputWeightEqualizationDetector(0.5)])
+            detector_set = {InputWeightEqualizationDetector(0.5)}

             # get tst model and callibrate
             non_fused = self._get_prepped_for_calibration_model(self.TwoBlockComplexNet(), detector_set)

@@ -1326,7 +1326,7 @@ class TestFxDetectInputWeightEqualization(QuantizationTestCase):
         for prepared_for_callibrate_model, mod_report in [non_fused, fused]:

             # supported modules to check
-            mods_to_check = set([nn.Linear, nn.Conv2d])
+            mods_to_check = {nn.Linear, nn.Conv2d}

             # get the set of all nodes in the graph their fqns
             node_fqns = {node.target for node in prepared_for_callibrate_model.graph.nodes}

@@ -1362,7 +1362,7 @@ class TestFxDetectInputWeightEqualization(QuantizationTestCase):
         with override_quantized_engine('fbgemm'):

             test_input_weight_detector = InputWeightEqualizationDetector(0.4)
-            detector_set = set([test_input_weight_detector])
+            detector_set = {test_input_weight_detector}
             model = self.TwoBlockComplexNet()
             # prepare the model for callibration
             prepared_for_callibrate_model, model_report = self._get_prepped_for_calibration_model(

@@ -1471,7 +1471,7 @@ class TestFxDetectInputWeightEqualization(QuantizationTestCase):
         # then create model report instance with detector
         with override_quantized_engine('fbgemm'):
             test_input_weight_detector = InputWeightEqualizationDetector(0.4)
-            detector_set = set([test_input_weight_detector])
+            detector_set = {test_input_weight_detector}
             model = self.ReluOnly()
             # prepare the model for callibration
             prepared_for_callibrate_model, model_report = self._get_prepped_for_calibration_model(model, detector_set)
@@ -1547,7 +1547,7 @@ class TestFxDetectOutliers(QuantizationTestCase):
         # not explicitly testing fusion because fx workflow automatically
         with override_quantized_engine('fbgemm'):

-            detector_set = set([OutlierDetector(reference_percentile=0.95)])
+            detector_set = {OutlierDetector(reference_percentile=0.95)}

             # get tst model and callibrate
             prepared_for_callibrate_model, mod_report = self._get_prepped_for_calibration_model(

@@ -1555,7 +1555,7 @@ class TestFxDetectOutliers(QuantizationTestCase):
         )

         # supported modules to check
-        mods_to_check = set([nn.Linear, nn.Conv2d, nn.ReLU])
+        mods_to_check = {nn.Linear, nn.Conv2d, nn.ReLU}

         # there should be 4 node fqns that have the observer inserted
         correct_number_of_obs_inserted = 4

@@ -1590,7 +1590,7 @@ class TestFxDetectOutliers(QuantizationTestCase):
         dynamic_static_detector = DynamicStaticDetector(tolerance=0.5)

         param_size: int = 4
-        detector_set = set([outlier_detector, dynamic_static_detector])
+        detector_set = {outlier_detector, dynamic_static_detector}
         model = self.LargeBatchModel(param_size=param_size)

         # get tst model and callibrate

@@ -1640,7 +1640,7 @@ class TestFxDetectOutliers(QuantizationTestCase):
         outlier_detector = OutlierDetector(ratio_threshold=1, reference_percentile=0)

         param_size: int = 16
-        detector_set = set([outlier_detector])
+        detector_set = {outlier_detector}
         model = self.LargeBatchModel(param_size=param_size)

         # get tst model and callibrate

@@ -1690,7 +1690,7 @@ class TestFxDetectOutliers(QuantizationTestCase):
         outlier_detector = OutlierDetector(reference_percentile=0.95)

         param_size: int = 8
-        detector_set = set([outlier_detector])
+        detector_set = {outlier_detector}
         model = self.LargeBatchModel(param_size=param_size)

         # get tst model and callibrate
@@ -1874,8 +1874,8 @@ class TestFxModelReportVisualizer(QuantizationTestCase):
         channel_headers, channel_table = table_dict[ModelReportVisualizer.TABLE_CHANNEL_KEY]

         # these two together should be the same as the generated report info in terms of keys
-        tensor_info_modules = set(row[1] for row in tensor_table)
-        channel_info_modules = set(row[1] for row in channel_table)
+        tensor_info_modules = {row[1] for row in tensor_table}
+        channel_info_modules = {row[1] for row in channel_table}
         combined_modules: Set = tensor_info_modules.union(channel_info_modules)

         generated_report_keys: Set = set(mod_rep_visualizer.generated_reports.keys())

@@ -1901,8 +1901,8 @@ class TestFxModelReportVisualizer(QuantizationTestCase):
         tensor_headers, tensor_table = empty_tables_dict[ModelReportVisualizer.TABLE_TENSOR_KEY]
         channel_headers, channel_table = empty_tables_dict[ModelReportVisualizer.TABLE_CHANNEL_KEY]

-        tensor_info_modules = set(row[1] for row in tensor_table)
-        channel_info_modules = set(row[1] for row in channel_table)
+        tensor_info_modules = {row[1] for row in tensor_table}
+        channel_info_modules = {row[1] for row in channel_table}
         combined_modules: Set = tensor_info_modules.union(channel_info_modules)
         self.assertEqual(len(combined_modules), 0) # should be no matching modules

@@ -660,16 +660,16 @@ class TestQuantizeJitPasses(QuantizationTestCase):
         m = torch.jit.script(M())
         qconfig_dict = {"": default_qconfig}
         m = prepare_jit(m, qconfig_dict)
-        activation_dtypes = set(
+        activation_dtypes = {
             obs.getattr("dtype")
             for x, obs in m._modules._c.items()
             if x.startswith("_observer_")
-        )
-        weight_dtypes = set(
+        }
+        weight_dtypes = {
             obs.getattr("dtype")
             for x, obs in m.conv._modules._c.items()
             if x.startswith("_observer_")
-        )
+        }
         assert len(activation_dtypes) == 1, "Expected to have 1 activation dtype"
         assert len(weight_dtypes) == 1, "Expected to have 1 weight dtype"
         assert (
@@ -1557,7 +1557,7 @@ class TestBinaryUfuncs(TestCase):
             ((2, 1), (2, 2)),
             ((2, 2), (2, 1, 1)),
         )
-        test_inputs = list(
+        test_inputs = [
             (
                 make_tensor(
                     base_size, dtype=torch.float64, device=device, high=10.0, low=0.0

@@ -1567,7 +1567,7 @@ class TestBinaryUfuncs(TestCase):
                 ),
             )
             for base_size, exp_size in test_cases
-        )
+        ]
         for base, exponent in test_inputs:
             regex = "doesn't match the broadcast shape"
             self.assertRaisesRegex(RuntimeError, regex, base.pow_, exponent)

@@ -1605,10 +1605,10 @@ class TestBinaryUfuncs(TestCase):
             (2, 1),
             (2, 2, 2),
         )
-        tensors = list(
+        tensors = [
             make_tensor(shape, dtype=dtype, device=device, low=0)
             for shape in exponent_shapes
-        )
+        ]
         floats_tensor = torch.tensor(floats, dtype=dtype, device=device)
         for base in floats:
             self._test_pow(base, floats_tensor)
@@ -194,7 +194,7 @@ class TestBundledInputs(TestCase):

         # Check helper that work on all functions
         all_info = loaded.get_bundled_inputs_functions_and_info()
-        self.assertEqual(set(all_info.keys()), set(['forward', 'foo']))
+        self.assertEqual(set(all_info.keys()), {'forward', 'foo'})
         self.assertEqual(all_info['forward']['get_inputs_function_name'], ['get_all_bundled_inputs_for_forward'])
         self.assertEqual(all_info['foo']['get_inputs_function_name'], ['get_all_bundled_inputs_for_foo'])
         self.assertEqual(all_info['forward']['info'], info)
@@ -191,7 +191,7 @@ class TestPybindTypeCasters(common.TestCase):
         In these cases we expect to get exactly one function per python type.
         """
         # Verify that all functions have the same return type.
-        union_type = set(self.expected_return_type(f) for f in funcs)
+        union_type = {self.expected_return_type(f) for f in funcs}
         assert len(union_type) == 1
         union_type = union_type.pop()
         self.assertIs(Union, get_origin(union_type))
@@ -1361,7 +1361,7 @@ except RuntimeError as e:
             dataloader_iter = iter(dataloader)
             fetched = list(dataloader_iter)
             self.assertEqual(len(fetched), 4)
-            fetched = set(tuple(t.tolist()) for t in fetched)
+            fetched = {tuple(t.tolist()) for t in fetched}
             self.assertEqual(fetched, {tuple(range(4)), tuple(range(7)), tuple(range(7, 14)), tuple(range(14, 20))})

         # [auto-batching] test that workers exit gracefully

@@ -1399,7 +1399,7 @@ except RuntimeError as e:
             dataloader_iter = iter(dataloader)
             fetched = list(dataloader_iter)
             self.assertEqual(len(fetched), 2)
-            fetched = set(tuple(t.tolist()) for t in fetched)
+            fetched = {tuple(t.tolist()) for t in fetched}
             self.assertEqual(fetched, {tuple(range(7)), tuple(range(7, 14))})

         # [auto-batching & drop_last] test that workers exit gracefully

@@ -1500,7 +1500,7 @@ except RuntimeError as e:
         num_workers = 6
         batch_size = 1
         dataset = SynchronizedSeedDataset(num_workers, batch_size, num_workers)
-        self.assertEqual(set(int(batch) for batch in get_dataloader()), set(int(batch) for batch in get_dataloader()))
+        self.assertEqual({int(batch) for batch in get_dataloader()}, {int(batch) for batch in get_dataloader()})

     def test_multi_epochs_reproducibility(self):
         num_workers = 2
@@ -1755,7 +1755,7 @@ class TestFunctionalIterDataPipe(TestCase):
             len(zipped_dp)

         # Functional Test: zips the results properly
-        exp = list((i, i) for i in range(5))
+        exp = [(i, i) for i in range(5)]
         self.assertEqual(list(zipped_dp), exp)

         # Functional Test: zips the inputs properly even when lengths are different (zips to the shortest)

@@ -2364,7 +2364,7 @@ class TestTyping(TestCase):

         # Context Manager to disable the runtime validation
         with runtime_validation_disabled():
-            self.assertEqual(list(d for d in dp3), ds)
+            self.assertEqual(list(dp3), ds)


 class NumbersDataset(IterDataPipe):
@@ -739,9 +739,9 @@ class HasDecompTest(TestCase):

         # This is for operators that are only registered in some CI
         # configurations, so would cause the test to fail
-        allow_list = set([aten.get_gradients.default])
+        allow_list = {aten.get_gradients.default}

-        overloads_wanting_decomp = set(op for op in all_aten_overloads() if can_appear_in_trace(op))
+        overloads_wanting_decomp = {op for op in all_aten_overloads() if can_appear_in_trace(op)}
         ops_missing_decomp = overloads_wanting_decomp - decomposition_table.keys()
         ops_missing_decomp -= allow_list
         self.assertExpected("".join(sorted(op.name() + "\n" for op in ops_missing_decomp)))
@@ -466,7 +466,7 @@ class TestForeach(TestCase):
         # `tensors2`: ['cuda', 'cpu']
         _cuda_tensors = list(op.sample_inputs(device, dtype, num_input_tensors=[2], same_size=True))[0].input
         _cpu_tensors = list(op.sample_inputs("cpu", dtype, num_input_tensors=[2], same_size=True))[0].input
-        tensors1, tensors2 = list(tensors for tensors in zip(_cuda_tensors, _cpu_tensors))
+        tensors1, tensors2 = list(zip(_cuda_tensors, _cpu_tensors))

         foreach_op, foreach_op_ = op.method_variant, op.inplace_variant
         native_op, native_op_ = op.ref, op.ref_inplace

@@ -494,7 +494,7 @@ class TestForeach(TestCase):
         # tensors3: ['cuda', 'cpu]
         _cuda_tensors = list(op.sample_inputs(device, dtype, num_input_tensors=[3], same_size=True))[0].input
         _cpu_tensors = list(op.sample_inputs("cpu", dtype, num_input_tensors=[3], same_size=True))[0].input
-        tensors1, tensors2, tensors3 = list(tensors for tensors in zip(_cuda_tensors, _cpu_tensors))
+        tensors1, tensors2, tensors3 = list(zip(_cuda_tensors, _cpu_tensors))

         foreach_op, foreach_op_, native_op = op.method_variant, op.inplace_variant, op.ref
         actual = foreach_op(tensors1, tensors2, tensors3)
@@ -1598,8 +1598,8 @@ class TestFX(JitTestCase):
             if node.op == 'output':
                 output_shape = node.args[0].meta['tensor_meta'].shape
                 output_stride = node.args[0].meta['tensor_meta'].stride
-        self.assertEqual(opcodes, set(['placeholder', 'get_attr', 'call_function', 'call_method',
-                                       'call_module', 'output']))
+        self.assertEqual(opcodes, {'placeholder', 'get_attr', 'call_function', 'call_method',
+                                   'call_module', 'output'})

         # Test shape propagation and make sure results match actual
         self.assertEqual(output_shape, ref_out.shape)

@@ -1832,8 +1832,8 @@ class TestFX(JitTestCase):
         interp = Interpreter(symbolic_trace(rn18))
         inp = torch.rand(5, 3, 224, 224)
         out = interp.run(inp)
-        env_key_names = set(n.name for n in interp.env.keys())
-        self.assertEqual(env_key_names, set(['output']))
+        env_key_names = {n.name for n in interp.env.keys()}
+        self.assertEqual(env_key_names, {'output'})

     def test_interpreter_default_args(self):
         class Model(torch.nn.Module):

@@ -2052,7 +2052,7 @@ class TestFX(JitTestCase):

         for orig_node, new_node in zip(g.nodes, copied_graph.nodes):
             orig_users = set(orig_node.users.keys())
-            orig_users_equiv = set(val_map[u] for u in orig_users)
+            orig_users_equiv = {val_map[u] for u in orig_users}
             new_users = set(new_node.users.keys())
             self.assertEqual(orig_users_equiv, new_users)

@@ -2230,7 +2230,7 @@ class TestFX(JitTestCase):

         users_of_x = x.node.users
         self.assertEqual(len(users_of_x), 3)
-        expected_ops = set(['relu', 'add', 'neg'])
+        expected_ops = {'relu', 'add', 'neg'}
         for use in users_of_x:
             assert any(use.name.startswith(prefix) for prefix in expected_ops)

@@ -873,7 +873,7 @@ terrible spacing
         ) -> bool:
             # `leaves` contains the set of standard `nn.Modules` that are not
             # currently symbolically traceable. Ideally this set would be empty
-            leaves = set([torch.nn.BatchNorm2d])
+            leaves = {torch.nn.BatchNorm2d}
             return type(m) in leaves

         traced = torch.fx.GraphModule(m, FunctionalTracer().trace(m))

@@ -1057,7 +1057,7 @@ class {test_classname}(torch.nn.Module):
         ) -> bool:
             # `leaves` contains the set of standard `nn.Modules` that are not
             # currently symbolically traceable. Ideally this set would be empty
-            leaves = set([torch.nn.BatchNorm2d])
+            leaves = {torch.nn.BatchNorm2d}
             return type(m) in leaves

         traced_functionals = torch.fx.GraphModule(m, FunctionalTracer().trace(m))
@@ -3743,7 +3743,7 @@ class TestCudaFuser(JitTestCase):
                 result += 1
             return result

-        complete_views = set([tuple(original_view)])
+        complete_views = {tuple(original_view)}

         to_visit = []
         # empty new view, curent originaal view, start pos=0, move count = 0, last_move
@@ -27,7 +27,7 @@ if GRAPH_EXECUTOR == ProfilingMode.PROFILING:


 def strip_profiling_nodes(nodes):
-    profiling_opcodes = set(['prim::BailoutTemplate', 'prim::BailOut'])
+    profiling_opcodes = {'prim::BailoutTemplate', 'prim::BailOut'}
     return [n for n in nodes if n.kind() not in profiling_opcodes]


@@ -46,7 +46,7 @@ LLVM_ENABLED = torch._C._llvm_enabled()
 autograd_check_set = {'aten::__is__', 'prim::AutogradAllNonZero', 'prim::AutogradAllZero', 'prim::ListConstruct'}

 def strip_profiling_nodes(nodes):
-    profiling_opcodes = set(['prim::BailoutTemplate', 'prim::BailOut'])
+    profiling_opcodes = {'prim::BailoutTemplate', 'prim::BailOut'}
     return [n for n in nodes if n.kind() not in profiling_opcodes]

 def warmup_forward(f, *args, profiling_count=2):
@@ -189,7 +189,7 @@ class TestTEFuser(JitTestCase):
             return x2.sum()

         with texpr_reductions_enabled():
-            a = torch.tensor(list(x for x in range(0, 15)), dtype=torch.float, device='cpu')
+            a = torch.tensor(list(range(0, 15)), dtype=torch.float, device='cpu')
             a = a.reshape(5, 3)
             scripted = self.checkScript(func, (a,))
             self.assertLastGraphAllFused()

@@ -205,7 +205,7 @@ class TestTEFuser(JitTestCase):
             return x.sum((-2, )) * 2

         with texpr_reductions_enabled():
-            a = torch.tensor(list(x for x in range(0, 15)), dtype=torch.float, device='cpu')
+            a = torch.tensor(list(range(0, 15)), dtype=torch.float, device='cpu')
             a = a.reshape(5, 3)
             scripted = self.checkScript(func, (a,))
             self.assertLastGraphAllFused()

@@ -217,7 +217,7 @@ class TestTEFuser(JitTestCase):
             return x.sum((0, ), keepdim=True, dtype=torch.double) * 2

         with texpr_reductions_enabled():
-            a = torch.tensor(list(x for x in range(0, 15)), dtype=torch.float, device='cpu')
+            a = torch.tensor(list(range(0, 15)), dtype=torch.float, device='cpu')
             a = a.reshape(5, 3)

             self.checkScript(func, (a,))
@@ -498,7 +498,7 @@ class TestModule(TestCase):
         # TODO: RNN / GRU / LSTM don't support backwards on eval mode for cuDNN; skip this in a
         # nicer way for eval mode only.
         # See https://github.com/pytorch/pytorch/issues/79161
-        rnn_modules = set([torch.nn.RNN, torch.nn.GRU, torch.nn.LSTM])
+        rnn_modules = {torch.nn.RNN, torch.nn.GRU, torch.nn.LSTM}
         if (module_info.module_cls in rnn_modules
                 and not training
                 and 'cuda' in device
@@ -1719,7 +1719,7 @@ class TestRefsOpsInfo(TestCase):
     module_alls = [(path, import_module(f"torch.{path}").__all__) for path in import_paths]
     ref_ops_names = tuple(itertools.chain.from_iterable(
         [f"{path}.{op}" for op in module_all] for path, module_all in module_alls))
-    ref_db_names = set(ref_op.name for ref_op in python_ref_db)
+    ref_db_names = {ref_op.name for ref_op in python_ref_db}

     # TODO: References that do not have an entry in python_ref_db
     skip_ref_ops = {

@@ -1910,9 +1910,7 @@ fake_skips = (
 fake_autocast_device_skips = defaultdict(dict)

 # TODO: investigate/fix
-fake_autocast_device_skips["cpu"] = set(
-    ("linalg.pinv",)
-)
+fake_autocast_device_skips["cpu"] = {"linalg.pinv"}


 dynamic_output_op_tests = (
@@ -145,8 +145,8 @@ class TestOptim(TestCase):
         constructor_accepts_maximize=True,
         constructor_accepts_foreach=False,
     ):
-        maximize_options = set([False, constructor_accepts_maximize])
-        foreach_options = set([False, constructor_accepts_foreach])
+        maximize_options = {False, constructor_accepts_maximize}
+        foreach_options = {False, constructor_accepts_foreach}

         four_arg_constructor = constructor
         if constructor_accepts_maximize and constructor_accepts_foreach:

@@ -317,7 +317,7 @@ class TestOptim(TestCase):

         # validate deepcopy() copies all public attributes
         def getPublicAttr(obj):
-            return set(k for k in obj.__dict__ if not k.startswith("_"))
+            return {k for k in obj.__dict__ if not k.startswith("_")}

         self.assertEqual(getPublicAttr(optimizer), getPublicAttr(deepcopy(optimizer)))

@@ -346,8 +346,8 @@ class TestOptim(TestCase):
         return constructor

     for maximize, foreach in itertools.product(
-        set([False, constructor_accepts_maximize]),
-        set([False, constructor_accepts_foreach]),
+        {False, constructor_accepts_maximize},
+        {False, constructor_accepts_foreach},
     ):
         self._test_state_dict(
             torch.randn(10, 5),
@@ -80,7 +80,7 @@ def _reduced_shape(shape, dim=None, keepdim=False):

     # Wrap negative dims
     dim = dim if isinstance(dim, Sequence) else [dim]
-    dim = set(i if i >= 0 else len(shape) + i for i in dim)
+    dim = {i if i >= 0 else len(shape) + i for i in dim}

     result = []
     for i, size in enumerate(shape):
@@ -19,33 +19,29 @@ from torchgen.utils import FileManager
 # - all ops below are part of MANUAL_TRACER to skip codegen Tracer kernel registration
 # Note: we still register to dispatch key Profiler for these ops, keeping it untouched for now.
 # You can find the manual registration in torch/csrc/autograd/VariableTypeManual.cpp
-MANUAL_BACKEND = set(
-    [
-        "options",
-        "data",
-        "set_data",
-        "is_leaf",
-        "output_nr",
-        "_version",
-        "retain_grad",
-        "_backward",
-        "requires_grad_",
-    ]
-)
+MANUAL_BACKEND = {
+    "options",
+    "data",
+    "set_data",
+    "is_leaf",
+    "output_nr",
+    "_version",
+    "retain_grad",
+    "_backward",
+    "requires_grad_",
+}

 # For these ops we want to skip the codegen-ed registration to both Autograd and Tracer keys.
 # You can find the manual registration in torch/csrc/autograd/VariableTypeManual.cpp
-MANUAL_AUTOGRAD_AND_TRACER = set(
-    [
-        "resize_",
-        "resize_as_",
-        "detach",
-        "detach_",
-        "copy_",
-        "_fw_primal",
-        "_make_dual",
-    ]
-)
+MANUAL_AUTOGRAD_AND_TRACER = {
+    "resize_",
+    "resize_as_",
+    "detach",
+    "detach_",
+    "copy_",
+    "_fw_primal",
+    "_make_dual",
+}

 # Currently MANUAL_AUTOGRAD and MANUAL_TRACER share the same set of ops:
 # union(MANUAL_BACKEND, MANUAL_AUTOGRAD_AND_TRACER)
@@ -968,10 +968,10 @@ def emit_body(
         """Find arguments that have derivative definitions"""
         if info is None or not info.has_derivatives:
             return differentiable_inputs
-        names = set(name for d in info.derivatives for name in d.var_names)
+        names = {name for d in info.derivatives for name in d.var_names}
         differentiable = [arg for arg in differentiable_inputs if arg.name in names]
         if len(differentiable) != len(names):
-            missing = names - set(arg.name for arg in differentiable)
+            missing = names - {arg.name for arg in differentiable}
             raise RuntimeError(
                 f"Missing arguments for derivatives: {missing} in {info.name}"
             )
@@ -1408,7 +1408,7 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
             assert isinstance(tos1, ConstDictVariable)
             match_obj = tos1.items
             if all(key in match_obj for key in keys):
-                self.push(TupleVariable(list(match_obj[key] for key in keys)))
+                self.push(TupleVariable([match_obj[key] for key in keys]))
                 self.push(ConstantVariable(True))
             else:
                 self.push(ConstantVariable(None))
@@ -764,14 +764,14 @@ def dict_param_key_ids(value):


 def dict_const_keys(value):
-    return set(k for k in value.keys() if not isinstance(k, torch.nn.Parameter))
+    return {k for k in value.keys() if not isinstance(k, torch.nn.Parameter)}


 def dict_const_keys_repr(const_keys):
     if any(isinstance(k, enum.Enum) for k in const_keys):
         # To workaround repr(Enum) returning invalid global reference before python 3.11
         # by calling enum_repr and removing quotes to render enum in guard code.
-        const_keys_str = f"{set(enum_repr(k) if isinstance(k, enum.Enum) else repr(k) for k in const_keys)}".replace(
+        const_keys_str = f"{ {enum_repr(k) if isinstance(k, enum.Enum) else repr(k) for k in const_keys} }".replace(
             "'", ""
         )
     else:
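The hunk above carries the one wrinkle worth noting: inside an f-string, doubled braces are escapes, so a set comprehension placed directly against the interpolation braces would not parse as an expression. The added spaces keep the inner braces separate. A small sketch:

s = {1, 2}

# The spaces make the inner braces a set comprehension inside the replacement field.
print(f"{ {x for x in s} }")   # prints: {1, 2}

# Without spaces, doubled braces are escapes, so no interpolation happens.
print(f"{{x for x in s}}")     # prints the literal text: {x for x in s}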
@@ -451,7 +451,7 @@ class TorchVariable(VariableTracker):
                 for x in args
             ]
         )
-        bin_ops = set(["add", "sub", "mul", "div", "sqrt"])
+        bin_ops = {"add", "sub", "mul", "div", "sqrt"}
         if (
             getattr(self.value, "__module__", "") == "torch"
             and self.value.__name__ in bin_ops

@@ -903,7 +903,7 @@ class TorchPyOperator(VariableTracker):
                 args[0].as_proxy(),
                 true_node,
                 false_node,
-                list(a.as_proxy() for a in sub_args),
+                [a.as_proxy() for a in sub_args],
             )
             # TODO: assert that the true/false return values are
             # consistent
@@ -388,11 +388,11 @@ def min_cut_rematerialization_partition(

     fusible_ops = recomputable_ops | set(random_ops)
     if AOT_PARTITIONER_DEBUG:
-        joint_module_ops = set(
+        joint_module_ops = {
             str(node.target._overloadpacket)
             for node in joint_module.graph.nodes
             if node.op == "call_function" and hasattr(node.target, "_overloadpacket")
-        )
+        }
         ops_ignored = joint_module_ops - {str(i) for i in recomputable_ops}
         print("Ops banned from rematerialization: ", ops_ignored)
         print()

@@ -400,7 +400,7 @@ def min_cut_rematerialization_partition(
     AGGRESSIVE_RECOMPUTATION = False

     def is_materialized_backwards(node):
-        cur_nodes = set([node])
+        cur_nodes = {node}
         while len(cur_nodes) > 0:
             cur = cur_nodes.pop()
             for user in cur.users:
@@ -949,7 +949,7 @@ class TritonKernel(Kernel):

         dim = len(self.range_trees) - 1
         result_var = self.cse.newvar()
-        result_var.mask_vars = set(var for var in masks if var[0] != "r")
+        result_var.mask_vars = {var for var in masks if var[0] != "r"}
         if (src_dtype, reduction_type, value) not in self.cse.reduction_cache:
             self.cse.reduction_cache[(src_dtype, reduction_type, value)] = result_var
             accumulator = f"_{result_var}"
@@ -531,15 +531,15 @@ class GraphLowering(torch.fx.Interpreter):
     def get_read_write_buffers_sizes(node):
         if isinstance(node, NopKernelSchedulerNode):
             return 0
-        reads = set(dep.name for dep in node.read_writes.reads)
-        writes = set(dep.name for dep in node.read_writes.writes)
+        reads = {dep.name for dep in node.read_writes.reads}
+        writes = {dep.name for dep in node.read_writes.writes}

         def is_materialized(buf):
             buf_uses = {user.node for user in scheduler.name_to_node[buf].users}
             return len(buf_uses - set(node.snodes)) > 0

         if isinstance(node, FusedSchedulerNode):
-            removed_buffers = set(dep for dep in writes if not is_materialized(dep))
+            removed_buffers = {dep for dep in writes if not is_materialized(dep)}
             writes = writes - removed_buffers
             reads = reads - removed_buffers
         node_bytes = 0
@@ -2995,7 +2995,7 @@ class FallbackKernel(ExternKernelAlloc):
         tensor_args = [Shim(x.codegen_reference()) for x in self.inputs]
         constant_args = [Shim(repr(x)) for x in self.constant_args]
         args, kwargs = self.unflatten_args(tensor_args, constant_args)
-        return list(map(repr, args)) + list(gen_kwarg(k, v) for k, v in kwargs.items())
+        return list(map(repr, args)) + [gen_kwarg(k, v) for k, v in kwargs.items()]

     @classmethod
     def create(cls, kernel, *args, **kwargs):
@@ -177,7 +177,7 @@ class BaseSchedulerNode:
         return self.get_name()

     def get_names(self) -> Set[str]:
-        return set([self.get_name()])
+        return {self.get_name()}

     def get_nodes(self) -> List["BaseSchedulerNode"]:
         return [self]
@@ -295,14 +295,12 @@ def free_symbol_startswith(index: sympy.Expr, prefix: str):


 def has_incompatible_cudagraph_ops(gm):
-    forbidden_list = set(
-        [
-            "aten._fused_moving_avg_obs_fq_helper.default",
-            "aten._fused_moving_avg_obs_fq_helper_functional.default",
-            "fbgemm.dense_to_jagged.default",
-            "fbgemm.jagged_to_padded_dense.default",
-        ]
-    )
+    forbidden_list = {
+        "aten._fused_moving_avg_obs_fq_helper.default",
+        "aten._fused_moving_avg_obs_fq_helper_functional.default",
+        "fbgemm.dense_to_jagged.default",
+        "fbgemm.jagged_to_padded_dense.default",
+    }
     for node in gm.graph.nodes:
         if str(node.target) in forbidden_list:
             return True
@@ -243,14 +243,12 @@ def is_channels_last_contiguous_3d(a: Tensor) -> bool:
     return True


-_memory_formats = set(
-    (
-        torch.contiguous_format,
-        torch.preserve_format,
-        torch.channels_last,
-        torch.channels_last_3d,
-    )
-)
+_memory_formats = {
+    torch.contiguous_format,
+    torch.preserve_format,
+    torch.channels_last,
+    torch.channels_last_3d,
+}


 def validate_memory_format(memory_format: torch.memory_format):
@@ -2956,7 +2956,7 @@ def native_group_norm(
     out, mean, rstd = _normalize(input_reshaped, reduction_dims, eps)
     out = out.view(input.shape)

-    broadcast_dims = [0] + list(dim for dim in range(2, input.ndim))
+    broadcast_dims = [0] + list(range(2, input.ndim))
     unsqueeze_bias = None
     if bias is not None:
         unsqueeze_bias = _unsqueeze_multiple(bias, broadcast_dims)
@@ -816,13 +816,9 @@ class ConvReLU3d(nnqat.Conv3d, nni._FusedModule):
         return super(ConvReLU3d, cls).from_float(mod)

 def update_bn_stats(mod):
-    if type(mod) in set(
-        [ConvBnReLU1d, ConvBnReLU2d, ConvBnReLU3d, ConvBn1d, ConvBn2d, ConvBn3d]
-    ):
+    if type(mod) in {ConvBnReLU1d, ConvBnReLU2d, ConvBnReLU3d, ConvBn1d, ConvBn2d, ConvBn3d}:
         mod.update_bn_stats()

 def freeze_bn_stats(mod):
-    if type(mod) in set(
-        [ConvBnReLU1d, ConvBnReLU2d, ConvBnReLU3d, ConvBn1d, ConvBn2d, ConvBn3d]
-    ):
+    if type(mod) in {ConvBnReLU1d, ConvBnReLU2d, ConvBnReLU3d, ConvBn1d, ConvBn2d, ConvBn3d}:
         mod.freeze_bn_stats()
@@ -267,10 +267,8 @@ class RNNBase(torch.nn.Module):

     @classmethod
     def from_float(cls, mod):
-        assert type(mod) in set(
-            [torch.nn.LSTM,
-             torch.nn.GRU]
-        ), 'nn.quantized.dynamic.RNNBase.from_float only works for nn.LSTM and nn.GRU'
+        assert type(mod) in {torch.nn.LSTM,
+                             torch.nn.GRU}, 'nn.quantized.dynamic.RNNBase.from_float only works for nn.LSTM and nn.GRU'
         assert hasattr(
             mod,
             'qconfig'

@@ -823,9 +821,9 @@ class RNNCellBase(torch.nn.Module):

     @classmethod
     def from_float(cls, mod):
-        assert type(mod) in set([torch.nn.LSTMCell,
-                                 torch.nn.GRUCell,
-                                 torch.nn.RNNCell]), 'nn.quantized.dynamic.RNNCellBase.from_float \
+        assert type(mod) in {torch.nn.LSTMCell,
+                             torch.nn.GRUCell,
+                             torch.nn.RNNCell}, 'nn.quantized.dynamic.RNNCellBase.from_float \
                 only works for nn.LSTMCell, nn.GRUCell and nn.RNNCell'
         assert hasattr(
             mod, 'qconfig'), 'Input float module must have qconfig defined'
@@ -222,12 +222,12 @@ class OutputLogger(Logger):


 def _convert_tuple_to_list(t: Any) -> Any:
-    return list(_convert_tuple_to_list(x) for x in t) if type(t) is tuple else t
+    return [_convert_tuple_to_list(x) for x in t] if type(t) is tuple else t


 def _dequantize_tensor_list(t: Any) -> Any:
     return (
-        list(_dequantize_tensor_list(x) for x in t)
+        [_dequantize_tensor_list(x) for x in t]
         if type(t) is list
         else t.dequantize()
         if t.is_quantized
@ -27,303 +27,303 @@ def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]:
|
|||
# note: this set is modified below by items from backend_config
|
||||
sets_of_related_ops: List[Set[NSNodeTargetType]] = [
|
||||
# conv modules
|
||||
set([
|
||||
{
|
||||
nn.Conv1d,
|
||||
]),
|
||||
set([
|
||||
},
|
||||
{
|
||||
nn.Conv2d,
|
||||
]),
|
||||
set([
|
||||
},
|
||||
{
|
||||
nn.Conv3d,
|
||||
]),
|
||||
},
|
||||
# conv functionals
|
||||
set([
|
||||
{
|
||||
F.conv1d,
|
||||
]),
|
||||
set([
|
||||
},
|
||||
{
|
||||
F.conv2d,
|
||||
]),
|
||||
set([
|
||||
},
|
||||
{
|
||||
F.conv3d,
|
||||
]),
|
||||
},
|
||||
# linear modules
|
||||
set([
|
||||
{
|
||||
nn.Linear,
|
||||
]),
|
||||
},
|
||||
# linear functionals
|
||||
set([
|
||||
{
|
||||
F.linear,
|
||||
]),
|
||||
},
|
||||
# average pool
|
||||
set([
|
||||
{
|
||||
nn.AvgPool1d,
|
||||
torch.avg_pool1d,
|
||||
]),
|
||||
set([
|
||||
},
|
||||
{
|
||||
nn.AvgPool2d,
|
||||
torch._C._nn.avg_pool2d,
|
||||
]),
|
||||
set([
|
||||
},
|
||||
{
|
||||
nn.AvgPool3d,
|
||||
torch._C._nn.avg_pool3d,
|
||||
]),
|
||||
},
|
||||
# adaptive average pool
|
||||
set([
|
||||
{
|
||||
nn.AdaptiveAvgPool1d,
|
||||
F.adaptive_avg_pool1d,
|
||||
]),
|
||||
set([
|
||||
},
|
||||
{
|
||||
nn.AdaptiveAvgPool2d,
|
||||
F.adaptive_avg_pool2d,
|
||||
]),
|
||||
set([
|
||||
},
|
||||
{
|
||||
nn.AdaptiveAvgPool3d,
|
||||
F.adaptive_avg_pool3d,
|
||||
]),
|
||||
},
|
||||
# LSTM
|
||||
set([
|
||||
{
|
||||
nn.LSTM,
|
||||
]),
|
||||
},
|
||||
# add
|
||||
set([
|
||||
{
|
||||
torch.add,
|
||||
operator.add, # x + y
|
||||
]),
|
||||
},
|
||||
# cat
|
||||
set([
|
||||
{
|
||||
torch.cat,
|
||||
]),
|
||||
},
|
||||
# mul
|
||||
set([
|
||||
{
|
||||
torch.mul,
|
||||
operator.mul,
|
||||
]),
|
||||
},
|
||||
# relu
|
||||
set([
|
||||
{
|
||||
F.relu,
|
||||
nn.ReLU,
|
||||
'relu',
|
||||
'relu_',
|
||||
torch.relu,
|
||||
]),
|
||||
},
|
||||
# maxpool
|
||||
set([
|
||||
{
|
||||
nn.MaxPool1d,
|
||||
F.max_pool1d,
|
||||
]),
|
||||
set([
|
||||
},
|
||||
{
|
||||
nn.MaxPool2d,
|
||||
F.max_pool2d,
|
||||
]),
|
||||
set([
|
||||
},
|
||||
{
|
||||
nn.MaxPool3d,
|
||||
F.max_pool3d,
|
||||
]),
|
||||
},
|
||||
# sigmoid
|
||||
set([
|
||||
{
|
||||
torch.sigmoid,
|
||||
'sigmoid',
|
||||
'sigmoid_',
|
||||
nn.Sigmoid,
|
||||
F.sigmoid,
|
||||
]),
|
||||
},
|
||||
# BatchNorm
|
||||
set([
|
||||
{
|
||||
nn.BatchNorm2d,
|
||||
]),
|
||||
set([
|
||||
},
|
||||
{
|
||||
nn.BatchNorm3d,
|
||||
]),
|
||||
},
|
||||
# ConvTranspose
|
||||
set([
|
||||
{
|
||||
nn.ConvTranspose1d,
|
||||
]),
|
||||
set([
|
||||
},
|
||||
{
|
||||
nn.ConvTranspose2d,
|
||||
]),
|
||||
set([
|
||||
},
|
||||
{
|
||||
nn.ConvTranspose3d,
|
||||
]),
|
||||
},
|
||||
# ELU
|
        set([
        {
            nn.ELU,
        ]),
        },
        # Embedding
        set([
        {
            nn.Embedding,
        ]),
        },
        # EmbeddingBag
        set([
        {
            nn.EmbeddingBag,
        ]),
        },
        # GroupNorm
        set([
        {
            nn.GroupNorm,
        ]),
        },
        # Hardswish
        set([
        {
            nn.Hardswish,
        ]),
        },
        # InstanceNorm
        set([
        {
            nn.InstanceNorm1d,
        ]),
        set([
        },
        {
            nn.InstanceNorm2d,
        ]),
        set([
        },
        {
            nn.InstanceNorm3d,
        ]),
        },
        # LayerNorm
        set([
        {
            nn.LayerNorm,
        ]),
        },
        # LeakyReLU
        set([
        {
            nn.LeakyReLU,
        ]),
        },
        # ReLU6
        set([
        {
            nn.ReLU6,
            F.relu6,
        ]),
        },
        # F.elu
        set([
        {
            F.elu,
        ]),
        },
        # F.hardswish
        set([
        {
            F.hardswish,
        ]),
        },
        # F.group_norm
        set([
        {
            F.group_norm,
        ]),
        },
        # F.instance_norm
        set([
        {
            F.instance_norm,
        ]),
        },
        # F.layer_norm
        set([
        {
            F.layer_norm,
        ]),
        },
        # F.leaky_relu
        set([
        {
            F.leaky_relu,
        ]),
        },
        # F.silu
        set([
        {
            nn.SiLU,
            F.silu,
        ]),
        },
        # F.mish
        set([
        {
            nn.Mish,
            F.mish,
        ]),
        },
        # F.tanh
        set([
        {
            nn.Tanh,
            F.tanh,
            torch.tanh,
            'tanh_',
            'tanh',
        ]),
        },
        # F.hardsigmoid
        set([
        {
            'hardsigmoid_',
            'hardsigmoid',
            F.hardsigmoid,
            nn.Hardsigmoid,
        ]),
        },
        # F.hardtanh
        set([
        {
            nn.Hardtanh,
            F.hardtanh,
            F.hardtanh_,
        ]),
        },
        # floordiv
        set([
        {
            operator.floordiv,
        ]),
        },
        # unsqueeze
        set([
        {
            torch.unsqueeze,
        ]),
        },
        # stack
        set([
        {
            torch.stack,
        ]),
        },
        # squeeze
        set([
        {
            torch.squeeze,
        ]),
        },
        # sort
        set([
        {
            torch.sort,
        ]),
        },
        # repeat_interleave
        set([
        {
            torch.repeat_interleave,
        ]),
        },
        # min
        set([
        {
            torch.min,
        ]),
        },
        # mean
        set([
        {
            torch.mean,
        ]),
        },
        # max
        set([
        {
            torch.max,
        ]),
        },
        # transpose
        set([
        {
            torch.transpose,
        ]),
        },
        # flatten
        set([
        {
            torch.flatten,
        ]),
        },
        # clamp
        set([
        {
            torch.clamp,
        ]),
        },
        # chunk
        set([
        {
            torch.chunk,
        ]),
        },
        # interpolate
        set([
        {
            torch.nn.functional.interpolate,
        ]),
        },
        # dropout
        set([
        {
            nn.Dropout,
        ]),
        },
        # F.dropout
        set([
        {
            F.dropout,
        ]),
        },
        # matmul
        set([
        {
            torch.matmul,
        ]),
        },
        # Softmax
        set([
        {
            nn.Softmax,
        ]),
        },
        # PReLU
        set([
        {
            nn.PReLU,
            nnq.PReLU,
        ]),
        },
        # F.prelu
        set([
        {
            F.prelu,
            toq.prelu,
        ]),
        },
    ]

    # for each floating point op, add versions of the op added by
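
The hunk above is one long run of the same flake8-comprehensions C405 fix: every `set([...])` built from a list literal becomes a set literal. A minimal, self-contained sketch of the rewrite, with placeholder strings standing in for the real torch callables:

    # C405: set() over a list literal allocates a throwaway list first;
    # the set literal is equivalent, shorter, and skips the intermediate.
    ops_old = set(["tanh", "relu", "sigmoid"])   # flagged by flake8-comprehensions (C405)
    ops_new = {"tanh", "relu", "sigmoid"}        # preferred form
    assert ops_old == ops_new
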
@@ -453,12 +453,12 @@ def add_op_to_sets_of_related_ops(
    counter = 0
    while str(counter) in base_name_to_sets_of_related_ops:
        counter += 1
    base_name_to_sets_of_related_ops[str(counter)] = set([op])
    base_name_to_sets_of_related_ops[str(counter)] = {op}


# TODO(future PR): clean this up
def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]:
    FUNS_IO_TYPE_FP32: Set[NSNodeTargetType] = set([
    FUNS_IO_TYPE_FP32: Set[NSNodeTargetType] = {
        F.linear,
        F.conv1d,
        F.conv2d,

@@ -478,11 +478,11 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]:
        torch.mul,
        torch.sum,
        F.prelu,
    ])
    }

    FUNS_IO_TYPE_FP16: Set[NSNodeTargetType] = set()

    FUNS_IO_TYPE_INT8: Set[NSNodeTargetType] = set([
    FUNS_IO_TYPE_INT8: Set[NSNodeTargetType] = {
        toq.linear,
        toq.linear_relu,
        toq.conv1d,

@@ -503,9 +503,9 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]:
        # uncomment below
        # toq.add,
        # toq.mul,
    ])
    }

    FUNS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = set([
    FUNS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = {
        F.relu,
        F.tanh,
        torch.tanh,

@@ -541,9 +541,9 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]:
        torch.stack,
        torch.unsqueeze,
        operator.add,
    ])
    }

    MODS_IO_TYPE_FP32: Set[NSNodeTargetType] = set([
    MODS_IO_TYPE_FP32: Set[NSNodeTargetType] = {
        nn.Linear,
        nnqat.Linear,
        nnqatd.Linear,

@@ -606,9 +606,9 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]:
        nni.LinearTanh,
        nni.ConvAdd2d,
        nni.ConvAddReLU2d,
    ])
    }

    MODS_IO_TYPE_INT8: Set[NSNodeTargetType] = set([
    MODS_IO_TYPE_INT8: Set[NSNodeTargetType] = {
        nnq.Linear,
        nnq.Conv1d,
        nnq.Conv2d,

@@ -640,9 +640,9 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]:
        nniq.LinearTanh,
        nniq.ConvAdd2d,
        nniq.ConvAddReLU2d,
    ])
    }

    MODS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = set([
    MODS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = {
        nn.ReLU,
        nn.Tanh,
        nn.Sigmoid,

@@ -660,9 +660,9 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]:
        nn.MaxPool2d,
        nn.MaxPool3d,
        nn.ReLU6,
    ])
    }

    METHS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = set([
    METHS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = {
        'sigmoid_',
        'sigmoid',
        'tanh_',

@@ -671,7 +671,7 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]:
        'hardsigmoid',
        'relu_',
        'relu',
    ])
    }

    return {
        'funs_io_type_fp32': FUNS_IO_TYPE_FP32,

@@ -687,16 +687,16 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]:

def get_unmatchable_types_map() -> Dict[str, Set[NSNodeTargetType]]:

    FUNS_UNMATCHABLE: Set[NSNodeTargetType] = set([
    FUNS_UNMATCHABLE: Set[NSNodeTargetType] = {
        torch.quantize_per_tensor,
        operator.getitem,
    ])
    }

    MODS_UNMATCHABLE: Set[NSNodeTargetType] = set([
    MODS_UNMATCHABLE: Set[NSNodeTargetType] = {
        nn.Identity,
    ])
    }

    METHS_UNMATCHABLE: Set[NSNodeTargetType] = set([
    METHS_UNMATCHABLE: Set[NSNodeTargetType] = {
        'to',
        'dequantize',
        'reshape',

@@ -719,7 +719,7 @@ def get_unmatchable_types_map() -> Dict[str, Set[NSNodeTargetType]]:
        'contiguous',
        'clamp',
        'chunk',
    ])
    }

    return {
        'funs_unmatchable': FUNS_UNMATCHABLE,


@@ -991,9 +991,9 @@ def extract_weight_comparison(m: GraphModule) -> NSResultsType:
    # use functions.

    # TODO(future PR): move this to config
    weighted_ops = set([
    weighted_ops = {
        torch.nn.functional.linear,
    ])
    }

    results: NSResultsType = {
        'model': {NSSingleResultValuesType.WEIGHT.value: {}}


@@ -219,10 +219,10 @@ class PerChannelDetector(DetectorBase):

    # Default map for representing supported per channel quantization modules for different backends
    DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES: Dict[str, Set[Any]] = {
        "fbgemm": set([nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d]),
        "qnnpack": set([nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d]),
        "onednn": set([nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d]),
        "x86": set([nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d]),
        "fbgemm": {nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d},
        "qnnpack": {nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d},
        "onednn": {nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d},
        "x86": {nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d},
    }

    def __init__(self, backend: str = torch.backends.quantized.engine):

@@ -230,7 +230,7 @@ class PerChannelDetector(DetectorBase):

        # store the backend information
        self.backend_chosen = backend
        self.supported_modules = set([])
        self.supported_modules = set()
        if self.backend_chosen in self.DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES:
            self.supported_modules = self.DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES[self.backend_chosen]
        else:
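
Note why the empty case above becomes `set()` rather than a brace literal: empty braces denote a dict in Python, so there is no empty-set literal and the bare constructor is the only correct spelling. A two-line illustration:

    assert isinstance(set(), set) and len(set()) == 0
    assert isinstance({}, dict)  # {} is an empty dict, not an empty set
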
@@ -413,17 +413,17 @@ class DynamicStaticDetector(DetectorBase):
    IS_CURRENTLY_SUPPORTED_KEY = "is_dynamic_supported"

    # modules that are supported both dynamic and static for this report function
    DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED = set([nn.Linear])
    DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED = {nn.Linear}

    # modules that will be supported soon for both
    DEFAULT_DYNAMIC_STATIC_FUTURE_SUPPORTED = set([nn.Conv1d, nn.Conv2d, nn.Conv3d])
    DEFAULT_DYNAMIC_STATIC_FUTURE_SUPPORTED = {nn.Conv1d, nn.Conv2d, nn.Conv3d}

    def __init__(self, tolerance=0.5):
        super().__init__()

        # set tolerance level and initialize a set to keep track of useful fqn locations
        self.tolerance = tolerance
        self.useful_observer_fqns: Set[str] = set([])
        self.useful_observer_fqns: Set[str] = set()

    def determine_observer_insert_points(self, prepared_fx_model: GraphModule) -> Dict[str, Dict[str, Any]]:
        r"""

@@ -737,9 +737,14 @@ class InputWeightEqualizationDetector(DetectorBase):
    * :attr:`DEFAULT_PRE_OBSERVER_NAME`: The name of the pre-observer to be inserted for this detector
    """

    SUPPORTED_MODULES: Set[Callable] = set(
        [nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d]
    )
    SUPPORTED_MODULES: Set[Callable] = {nn.Linear,
                                        nn.Conv1d,
                                        nn.Conv2d,
                                        nn.Conv3d,
                                        nnqat.Linear,
                                        nnqat.Conv1d,
                                        nnqat.Conv2d,
                                        nnqat.Conv3d}

    # names for the pre and post observers that are inserted
    DEFAULT_PRE_OBSERVER_NAME: str = "model_report_pre_observer"


@@ -129,7 +129,7 @@ class ModelReport:

        # initialize each report to have empty set of observers of interest
        for desired_report in self._desired_detector_names:
            self._detector_name_to_observer_fqns[desired_report] = set([])
            self._detector_name_to_observer_fqns[desired_report] = set()

        # flags to ensure that we can only prepare and remove observers once
        self._prepared_flag = False

@@ -287,7 +287,7 @@ class ModelReport:
        if remove_inserted_observers:
            self._removed_observers = True
            # get the set of all Observers inserted by this instance of ModelReport
            all_observers_of_interest: Set[str] = set([])
            all_observers_of_interest: Set[str] = set()
            for desired_report in self._detector_name_to_observer_fqns:
                observers_of_interest = self._detector_name_to_observer_fqns[desired_report]
                all_observers_of_interest.update(observers_of_interest)


@@ -30,7 +30,7 @@ class FusedGraphModule(GraphModule):
class ObservedGraphModule(GraphModule):

    def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, preserved_attr_names: Set[str]):
        self.preserved_attr_names = set([
        self.preserved_attr_names = {
            '_activation_post_process_map',
            '_activation_post_process_indexes',
            '_patterns',

@@ -40,7 +40,7 @@ class ObservedGraphModule(GraphModule):
            '_node_name_to_scope',
            '_qconfig_mapping',
            '_is_qat',
            '_observed_node_names']).union(preserved_attr_names)
            '_observed_node_names'}.union(preserved_attr_names)
        preserved_attrs = {attr: getattr(root, attr) for attr in self.preserved_attr_names if hasattr(root, attr)}
        super().__init__(root, graph)
        for attr in preserved_attrs:

@@ -64,9 +64,9 @@ def _get_observed_graph_module_attr(model: Union[torch.nn.Module, GraphModule],

class ObservedStandaloneGraphModule(ObservedGraphModule):
    def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, preserved_attr_names: Set[str]):
        preserved_attr_names = preserved_attr_names.union(set([
        preserved_attr_names = preserved_attr_names.union({
            "_standalone_module_input_quantized_idxs",
            "_standalone_module_output_quantized_idxs"]))
            "_standalone_module_output_quantized_idxs"})
        super().__init__(root, graph, preserved_attr_names)

    def __deepcopy__(self, memo):


@@ -208,10 +208,10 @@ DEFAULT_DYNAMIC_SPARSE_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = {

def no_observer_set() -> Set[Any]:
    r"""These modules cannot have observers inserted by default."""
    no_observers = set([
    no_observers = {
        nn.quantizable.LSTM,
        nn.quantizable.MultiheadAttention
    ])
    }
    return no_observers

def get_default_static_quant_module_mappings() -> Dict[Callable, Any]:
@@ -1609,8 +1609,8 @@ def gradgradcheck(

    # NB: We need to save the requires_grad information about the inputs here because gradcheck detaches inputs
    # before running forward mode AD
    diff_input_args_indices = set(i for i, x in enumerate(tupled_inputs) if is_tensor_like(x) and x.requires_grad)
    diff_grad_output_indices = set(i for i, x in enumerate(tupled_grad_outputs) if x.requires_grad)
    diff_input_args_indices = {i for i, x in enumerate(tupled_inputs) if is_tensor_like(x) and x.requires_grad}
    diff_grad_output_indices = {i for i, x in enumerate(tupled_grad_outputs) if x.requires_grad}

    def new_func(*args):
        # Restore the requires_grad information
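
The gradcheck hunk above is the C401 pattern: `set(<generator>)` becomes a set comprehension. A runnable sketch with a toy list standing in for `tupled_inputs`:

    xs = [3.0, None, 7.0, None]
    old_style = set(i for i, x in enumerate(xs) if x is not None)  # C401
    new_style = {i for i, x in enumerate(xs) if x is not None}
    assert old_style == new_style == {0, 2}
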
@@ -491,7 +491,7 @@ def _parse_visible_devices() -> Set[int]:
    """Parse CUDA_VISIBLE_DEVICES environment variable."""
    var = os.getenv("CUDA_VISIBLE_DEVICES")
    if var is None:
        return set(x for x in range(64))
        return set(range(64))

    def _strtoul(s: str) -> int:
        """Return -1 or positive integer sequence string starts with,"""
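
This one is the "useless generator" case named in the commit message: wrapping an iterable in a pass-through generator adds nothing, so the iterable is handed to set() directly:

    assert set(x for x in range(64)) == set(range(64))
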
@@ -85,11 +85,11 @@ def compare(before, after, format_flamegraph=format_flamegraph):

    f = io.StringIO()

    before_segs = set(_seg_key(seg) for seg in before)
    after_segs = set(_seg_key(seg) for seg in after)
    before_segs = {_seg_key(seg) for seg in before}
    after_segs = {_seg_key(seg) for seg in after}

    print(f'only_before = {list(a for a,_ in (before_segs - after_segs))}')
    print(f'only_after = {list(a for a,_ in (after_segs - before_segs))}')
    print(f'only_before = {[a for a,_ in (before_segs - after_segs)]}')
    print(f'only_after = {[a for a,_ in (after_segs - before_segs)]}')

    for seg in before:
        if _seg_key(seg) not in after_segs:


@@ -383,7 +383,7 @@ class DistributedDataParallel(Module):
        ]

        # Build list of parameters.
        parameters = list(parameter for _, parameter in modules_and_parameters)
        parameters = [parameter for _, parameter in modules_and_parameters]

        # Checks if a module will produce a sparse gradient.
        def produces_sparse_gradient(module):

@@ -393,9 +393,9 @@ class DistributedDataParallel(Module):

        # Build list of booleans indicating whether or not to expect sparse
        # gradients for the corresponding parameters.
        expect_sparse_gradient = list(
        expect_sparse_gradient = [
            produces_sparse_gradient(module) for module, _ in modules_and_parameters
        )
        ]

        self._assign_modules_buffers()
@@ -281,7 +281,7 @@ def _handle_row_wise_sharding_tensor(
            indices[placement.rank()] = list(
                range(offset_start_idx, offset_start_idx + split_size)
            )
        indices_flatten = list(idx for indice in indices for idx in indice)
        indices_flatten = [idx for indice in indices for idx in indice]

        input_t = input_t.index_select(
            0, torch.tensor(indices_flatten, device=input_t.device)


@@ -38,10 +38,10 @@ def wrap(res: object, spec: OutputSpecType) -> object:
        assert spec is not None and isinstance(
            spec, list
        ), f"output spec does not match with output! Expected list, got {spec}."
        return list(
        return [
            dtensor.DTensor(e, s.mesh, s.placements, size=s.shape)
            for e, s in zip(res, spec)
        )
        ]
    elif isinstance(res, tuple):
        assert spec is not None and isinstance(
            spec, tuple


@@ -397,7 +397,7 @@ def prop_index(op_schema: OpSchema) -> OutputSharding:
    assert isinstance(indices_output_spec, DTensorSpec)
    indices_spec = indices_output_spec

    lookup_dims = set(v[0] for v in valid_indices_spec)
    lookup_dims = {v[0] for v in valid_indices_spec}

    need_reshard_on_values = tuple(
        (isinstance(vp, Shard) and (vp.dim in lookup_dims or isinstance(ip, Shard)))


@@ -370,7 +370,7 @@ def dim_transpose(ndim: int, dim1: int, dim2: int) -> DimMap:
    dim2 = normalize_dim(dim2, ndim)
    assert dim1 < ndim
    assert dim2 < ndim
    dimmap = list(InputDim(i) for i in range(ndim))
    dimmap = [InputDim(i) for i in range(ndim)]
    swapdim = dimmap[dim1]
    dimmap[dim1] = dimmap[dim2]
    dimmap[dim2] = swapdim

@@ -480,7 +480,7 @@ def propagate_shape_and_sharding(
    if the leftmost split size is divisible by the mesh dimension
    """
    assert len(in_shard) == len(mesh_sizes)
    sharded_in_dims: Set[int] = set(s.dim for s in in_shard if isinstance(s, Shard))
    sharded_in_dims: Set[int] = {s.dim for s in in_shard if isinstance(s, Shard)}
    # for each input dim, for each mesh dim, provides a list of possible shardable dimensions
    shardable_dims: torch.Tensor = torch.ones(
        (len(local_in_shape), len(mesh_sizes)), dtype=torch.bool
@@ -567,12 +567,12 @@ def _get_ignored_modules(
    # that this FSDP instance can get any ignored modules from its children.

    # Include child modules and exclude nested FSDP modules themselves
    ignored_modules = set(
    ignored_modules = {
        child
        for module in ignored_root_modules
        for child in module.modules()
        if not isinstance(child, fsdp_file.FullyShardedDataParallel)
    )
    }
    if root_module in ignored_modules:
        warnings.warn(
            "Trying to ignore the top-level module passed into the FSDP "

@@ -599,16 +599,16 @@ def _get_ignored_params(
    """
    all_ignored_params: Set[torch.nn.Parameter] = set()

    params_in_ignored_modules = set(
    params_in_ignored_modules = {
        p for m in ignored_modules for p in m.parameters() if not _is_fsdp_flattened(p)
    )
    }

    all_ignored_params.update(params_in_ignored_modules)

    if ignored_parameters is not None:
        params_in_ignored_parameters = set(
        params_in_ignored_parameters = {
            p for p in ignored_parameters if not _is_fsdp_flattened(p)
        )
        }
        all_ignored_params.update(params_in_ignored_parameters)

    # Include nested FSDP modules' ignored parameters

@@ -626,9 +626,9 @@ def _get_buffer_names(root_module: nn.Module) -> Set[str]:
    Returns the fully prefixed names of all buffers in the module hierarchy
    rooted at ``root_module`` as a class:`set`.
    """
    return set(
    return {
        clean_tensor_name(buffer_name) for buffer_name, _ in root_module.named_buffers()
    )
    }


def _check_single_device_module(

@@ -640,7 +640,7 @@ def _check_single_device_module(
    ignoring the parameters in ``ignored_params``. Thus, after this method, the
    module must be either fully on the CPU or fully on a non-CPU device.
    """
    devices = set(param.device for param in _get_orig_params(module, ignored_params))
    devices = {param.device for param in _get_orig_params(module, ignored_params)}
    if len(devices) > 1:
        raise RuntimeError(
            f"FSDP only supports single device modules but got params on {devices}"
@@ -485,7 +485,7 @@ def _flatten_optim_state(
        are_pos_dim_tensors &= torch.is_tensor(v) and v.dim() > 0
        are_zero_dim_tensors &= _is_zero_dim_tensor(v)
        are_non_tensors &= not torch.is_tensor(v)
    types = set(type(v) for v in non_none_state_values)
    types = {type(v) for v in non_none_state_values}
    if len(types) != 1 or not (
        are_pos_dim_tensors or are_zero_dim_tensors or are_non_tensors
    ):

@@ -570,7 +570,7 @@ def _flatten_tensor_optim_state(
    """
    non_none_tensors = [t for t in pos_dim_tensors if t is not None]
    # Check that all are tensors with the same dtype
    dtypes = set(t.dtype for t in non_none_tensors)
    dtypes = {t.dtype for t in non_none_tensors}
    if len(dtypes) != 1:
        raise ValueError(
            "All unflattened parameters comprising a single flattened "

@@ -648,8 +648,8 @@ def _flatten_zero_dim_tensor_optim_state(
    """
    non_none_tensors = [t for t in zero_dim_tensors if t is not None]
    # Enforce that all have the same value and dtype
    values_set = set(t.item() if t is not None else None for t in zero_dim_tensors)
    dtypes = set(t.dtype if t is not None else None for t in zero_dim_tensors)
    values_set = {t.item() if t is not None else None for t in zero_dim_tensors}
    dtypes = {t.dtype if t is not None else None for t in zero_dim_tensors}
    if (
        len(non_none_tensors) != len(zero_dim_tensors)
        or len(values_set) != 1

@@ -1004,10 +1004,10 @@ def _rekey_sharded_optim_state_dict(
    for unflat_param_group in sharded_osd["param_groups"]:
        flat_param_group = copy.deepcopy(unflat_param_group)
        flat_param_keys = sorted(
            set(
            {
                unflat_param_name_to_flat_param_key[unflat_param_name]
                for unflat_param_name in unflat_param_group["params"]
            )
            }
        )
        flat_param_group["params"] = flat_param_keys
        rekeyed_osd_param_groups.append(flat_param_group)


@@ -1068,7 +1068,7 @@ def _get_training_state(
) -> HandleTrainingState:
    """Returns the training state of the handles in ``handles_key``."""
    p_assert(len(handles_key) > 0, "Expects a non-empty handles key")
    training_states = set(handle._training_state for handle in handles_key)
    training_states = {handle._training_state for handle in handles_key}
    p_assert(
        len(training_states) == 1,
        f"Expects uniform training state but got {training_states}",


@@ -274,8 +274,8 @@ class FlatParameter(nn.Parameter):
        self._fqns = tuple(fqns)
        self._shared_param_infos = tuple(shared_param_infos)
        self._param_extensions = tuple(param_extensions)
        self._modules = set(pi.module for pi in self._param_infos).union(
            set(spi.module for spi in self._shared_param_infos)
        self._modules = {pi.module for pi in self._param_infos}.union(
            {spi.module for spi in self._shared_param_infos}
        )
        assert (params is None) == (shared_params is None)
        if params is not None:

@@ -1857,8 +1857,8 @@ class FlatParamHandle:
    def _get_modules(self) -> Set[nn.Module]:
        """Returns a :class:`set` of the modules whose parameters are included
        in this handle's flattened parameter."""
        return set(pi.module for pi in self.flat_param._param_infos).union(
            set(spi.module for spi in self.flat_param._shared_param_infos)
        return {pi.module for pi in self.flat_param._param_infos}.union(
            {spi.module for spi in self.flat_param._shared_param_infos}
        )

    def is_sharded(self, tensor: Tensor) -> bool:


@@ -1968,7 +1968,7 @@ def _get_grad_norm(
    if len(params_with_grad) == 0:
        return torch.tensor(0.0)
    grads = [param.grad for param in params_with_grad]
    grad_dtypes = set(grad.dtype for grad in grads)
    grad_dtypes = {grad.dtype for grad in grads}
    if len(grad_dtypes) != 1:
        raise ValueError(
            f"Requires uniform dtype across all gradients but got {grad_dtypes}"
@@ -54,7 +54,7 @@ def register_rendezvous_handler(scheme, handler):
# Query will have format "rank=0&world_size=1" and is
# converted into {"rank": 0, "world_size": 1}
def _query_to_dict(query: str) -> Dict[str, str]:
    return dict((pair[0], pair[1]) for pair in (pair.split("=") for pair in filter(None, query.split("&"))))
    return {pair[0]: pair[1] for pair in (pair.split("=") for pair in filter(None, query.split("&")))}


def _rendezvous_helper(url: str, rank: int, world_size_opt: Optional[int], **kwargs):
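
Tracing the rewritten dict comprehension above on the sample input from its own comment shows the shape of the result; note the parsed values stay strings (the comment's `{"rank": 0, ...}` describes intent, not types):

    query = "rank=0&world_size=1"
    parsed = {pair[0]: pair[1] for pair in (pair.split("=") for pair in filter(None, query.split("&")))}
    assert parsed == {"rank": "0", "world_size": "1"}
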
@@ -275,7 +275,7 @@ def check_dependency(partition):
    """Given a partition,check if there is a circular dependency on
    this partition using bfs
    """
    visited: Set[Partition] = set([partition])
    visited: Set[Partition] = {partition}
    queue: Deque[Partition] = deque([partition])
    while queue:
        p = queue.popleft()
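
For context on the hunk above, a self-contained sketch of the BFS cycle check that `check_dependency` performs, with a hypothetical Node class standing in for the real Partition (only a `children` set is assumed):

    from collections import deque

    class Node:
        def __init__(self):
            self.children = set()

    def has_circular_dependency(start):
        visited = {start}  # the set literal introduced by this diff
        queue = deque([start])
        while queue:
            node = queue.popleft()
            for child in node.children:
                if child is start:
                    return True  # walked back to the start: a cycle
                if child not in visited:
                    visited.add(child)
                    queue.append(child)
        return False

    a, b = Node(), Node()
    a.children.add(b)
    b.children.add(a)
    assert has_circular_dependency(a)
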
@@ -30,7 +30,7 @@ _reify
@dispatch(dict, dict)  # type: ignore[no-redef]
def _reify(d, s):
    return dict((k, reify(v, s)) for k, v in d.items())
    return {k: reify(v, s) for k, v in d.items()}
_reify

@dispatch(object, dict)  # type: ignore[no-redef]


@@ -55,7 +55,7 @@ class VarDispatcher(Dispatcher):
    """
    def __call__(self, *args, **kwargs):
        func, s = self.resolve(args)
        d = dict((k.token, v) for k, v in s.items())
        d = {k.token: v for k, v in s.items()}
        return func(**d)


@@ -86,7 +86,7 @@ def supercedes(a, b):
    s = unify(a, b)
    if s is False:
        return False
    s = dict((k, v) for k, v in s.items() if not isvar(k) or not isvar(v))
    s = {k: v for k, v in s.items() if not isvar(k) or not isvar(v)}
    if reify(a, s) == a:
        return True
    if reify(b, s) == b:

@@ -117,5 +117,5 @@ def ordering(signatures):
    for s in signatures:
        if s not in edges:
            edges[s] = []
    edges = dict((k, [b for a, b in v]) for k, v in edges.items())  # type: ignore[attr-defined, assignment]
    edges = {k: [b for a, b in v] for k, v in edges.items()}  # type: ignore[attr-defined, assignment]
    return _toposort(edges)


@@ -80,11 +80,11 @@ def ambiguous(a, b):
def ambiguities(signatures):
    """ All signature pairs such that A is ambiguous with B """
    signatures = list(map(tuple, signatures))
    return set((a, b) for a in signatures for b in signatures
               if hash(a) < hash(b)
               and ambiguous(a, b)
               and not any(supercedes(c, a) and supercedes(c, b)
                           for c in signatures))
    return {(a, b) for a in signatures for b in signatures
            if hash(a) < hash(b)
            and ambiguous(a, b)
            and not any(supercedes(c, a) and supercedes(c, b)
                        for c in signatures)}


def super_signature(signatures):

@@ -92,7 +92,7 @@ def super_signature(signatures):
    n = len(signatures[0])
    assert all(len(s) == n for s in signatures)

    return [max([type.mro(sig[i]) for sig in signatures], key=len)[0]
    return [max((type.mro(sig[i]) for sig in signatures), key=len)[0]
            for i in range(n)]


@@ -115,5 +115,5 @@ def ordering(signatures):
    for s in signatures:
        if s not in edges:
            edges[s] = []
    edges = dict((k, [b for a, b in v]) for k, v in edges.items())  # type: ignore[assignment, attr-defined]
    edges = {k: [b for a, b in v] for k, v in edges.items()}  # type: ignore[assignment, attr-defined]
    return _toposort(edges)
@@ -45,8 +45,8 @@ def _toposort(edges):
    [2] http://en.wikipedia.org/wiki/Toposort#Algorithms
    """
    incoming_edges = reverse_dict(edges)
    incoming_edges = dict((k, set(val)) for k, val in incoming_edges.items())
    S = set((v for v in edges if v not in incoming_edges))
    incoming_edges = {k: set(val) for k, val in incoming_edges.items()}
    S = ({v for v in edges if v not in incoming_edges})
    L = []

    while S:
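
The `_toposort` hunk above touches the setup of a Kahn-style topological sort. A runnable sketch of the whole algorithm under simplified assumptions (edges maps each node to the nodes it points at; `reverse_dict` and the surrounding module are elided):

    def toposort(edges):
        incoming = {v: 0 for v in edges}  # count incoming edges per node
        for v in edges:
            for w in edges[v]:
                incoming[w] = incoming.get(w, 0) + 1
        ready = {v for v in edges if incoming[v] == 0}  # analogous to the set comprehension in the hunk
        order = []
        while ready:
            v = ready.pop()
            order.append(v)
            for w in edges.get(v, ()):
                incoming[w] -= 1
                if incoming[w] == 0:
                    ready.add(w)
        return order

    assert toposort({1: [2], 2: [3], 3: []}) == [1, 2, 3]
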
@@ -11,9 +11,9 @@ aten = torch.ops.aten


# stateful ops are banned from CSE
rand_ops = set([aten.dropout, aten._fused_dropout, aten._standard_gamma, aten.bernoulli, aten.multinomial, aten.native_dropout, aten.normal, aten.poisson, aten.binomial, aten.rrelu, aten.rand_like, aten.rand, aten.randint, aten.randn, aten.randperm])  # noqa: E501
rand_ops = {aten.dropout, aten._fused_dropout, aten._standard_gamma, aten.bernoulli, aten.multinomial, aten.native_dropout, aten.normal, aten.poisson, aten.binomial, aten.rrelu, aten.rand_like, aten.rand, aten.randint, aten.randn, aten.randperm}  # noqa: E501

inplace_ops = set([aten.add_, aten.sub_, aten.mul_, aten.div_, aten.pow_, aten.lerp_, aten.relu_, aten.sigmoid_, aten.tanh_])  # noqa: E501
inplace_ops = {aten.add_, aten.sub_, aten.mul_, aten.div_, aten.pow_, aten.lerp_, aten.relu_, aten.sigmoid_, aten.tanh_}  # noqa: E501


@torch.fx._compatibility.compatibility(is_backward_compatible=False)


@@ -468,10 +468,10 @@ def reinplace(gm, *sample_args):
    # so we know not to re-inplace them.
    # NOTE: later, we'll need to add an optimization for fully recovering performance
    # on programs that mutate inputs.
    input_storages = set(
    input_storages = {
        StorageWeakRef(
            node.meta['fake_result']._typed_storage()
        ) for node in gm.graph.nodes if node.op == 'placeholder')
        ) for node in gm.graph.nodes if node.op == 'placeholder'}


    # We also need to know for a given node, what are all of its aliasing nodes.

@@ -627,14 +627,14 @@ def reinplace(gm, *sample_args):
            old_flattened_res, _ = tree_flatten(old.meta['fake_result'])
            node_flattened_res, _ = tree_flatten(node_to_update.meta['fake_result'])

            old_res_storage = set(
            old_res_storage = {
                StorageWeakRef(
                    x._typed_storage()
                ) for x in old_flattened_res if isinstance(x, FakeTensor))
            node_res_storage = set(
                ) for x in old_flattened_res if isinstance(x, FakeTensor)}
            node_res_storage = {
                StorageWeakRef(
                    x._typed_storage()
                ) for x in node_flattened_res if isinstance(x, FakeTensor))
                ) for x in node_flattened_res if isinstance(x, FakeTensor)}

            # This will happen if we're updating a view op, e.g.
            # e.g. replacing

@@ -648,10 +648,10 @@ def reinplace(gm, *sample_args):
            # We can't just check equality because we might encounter FX nodes that return zero tensor outputs.
            if len(old_res_storage) == 1 and len(node_res_storage) == 1 and old_res_storage == node_res_storage:
                new_flattened_res, _ = tree_flatten(new.meta['fake_result'])
                new_res_storage = set(
                new_res_storage = {
                    StorageWeakRef(
                        x._typed_storage()
                    ) for x in new_flattened_res if isinstance(x, FakeTensor))
                    ) for x in new_flattened_res if isinstance(x, FakeTensor)}
                assert len(new_res_storage) == 1
                (old_ref,) = old_res_storage
                (new_ref,) = new_res_storage
@@ -229,7 +229,7 @@ def generate_inputs_for_submodules(

    handles = []
    results = {}
    submodule_to_names = dict((mod, name) for name, mod in model.named_modules())
    submodule_to_names = {mod: name for name, mod in model.named_modules()}

    def pre_forward(module, module_inputs):
        results[submodule_to_names[module]] = copy.deepcopy(module_inputs) if deepcopy else module_inputs


@@ -117,7 +117,7 @@ def _gen_torch_functional_registered_ops():
    # some functions directly map to their aten:: implementations.
    # TODO: add support for more ops
    ops = ["stft", "istft", "lu", "cdist", "norm", "unique", "unique_consecutive", "tensordot"]
    return set(getattr(torch.functional, name) for name in ops)
    return {getattr(torch.functional, name) for name in ops}

_functional_registered_ops = _gen_torch_functional_registered_ops()


@@ -89,7 +89,7 @@ def jit_ignored_properties(module):
    user_annotated_ignored_attributes = getattr(module, "__jit_ignored_attributes__", list())

    def get_properties_names(module):
        return set(k for k, v in vars(module).items() if isinstance(v, property))
        return {k for k, v in vars(module).items() if isinstance(v, property)}

    properties = get_properties_names(type(module))
    user_annoted_ignored_properties = set()


@@ -352,7 +352,7 @@ def try_ann_to_type(ann, loc):
        return OptionalType(valid_type)
    if is_union(ann):
        # TODO: this is hack to recognize NumberType
        if set(ann.__args__) == set([int, float, complex]):
        if set(ann.__args__) == {int, float, complex}:
            return NumberType.get()
        inner: List = []
        # We need these extra checks because both `None` and invalid


@@ -14,7 +14,7 @@ def _gen_unsupported_methods_properties():
                return x.{op}()
            ''')

    deprecated_apis = set(["volatile", "resize", "reinforce", "new", "name", "map2_", "has_names", "grad_fn", "resize_as"])
    deprecated_apis = {"volatile", "resize", "reinforce", "new", "name", "map2_", "has_names", "grad_fn", "resize_as"}
    tensor_attrs = tensor_attrs - deprecated_apis

    properties = []


@@ -378,11 +378,11 @@ defined as ``prod(x[:i])``.""",
    )

    # Apply function name info to docstring templates:
    templates = dict(
        (k, v.format_map(template_data))
    templates = {
        k: v.format_map(template_data)
        for k, v in docstring_templates.items()
        if k.startswith(op_kind)
    )
    }
    templates.update(
        (k, v.format_map(template_data) if isinstance(v, str) else v)
        for k, v in template_data.items()


@@ -90,7 +90,7 @@ def _map_mt_args_kwargs(args, kwargs, map_fn):

def _wrap_result(result_data, result_mask):
    if isinstance(result_data, list):
        return list(_wrap_result(r, m) for (r, m) in zip(result_data, result_mask))
        return [_wrap_result(r, m) for (r, m) in zip(result_data, result_mask)]
    if isinstance(result_data, tuple):
        return tuple(_wrap_result(r, m) for (r, m) in zip(result_data, result_mask))
    if torch.is_tensor(result_data):


@@ -173,7 +173,7 @@ class RNNBase(Module):
        # a sufficient check, because overlapping parameter buffers that don't completely
        # alias would break the assumptions of the uniqueness check in
        # Module.named_parameters().
        unique_data_ptrs = set(p.data_ptr() for p in self._flat_weights)
        unique_data_ptrs = {p.data_ptr() for p in self._flat_weights}
        if len(unique_data_ptrs) != len(self._flat_weights):
            return
@@ -929,7 +929,7 @@ class DistributedDataParallel(Module, Joinable):
        ]

        # Build list of parameters.
        parameters = list(parameter for _, parameter in modules_and_parameters)
        parameters = [parameter for _, parameter in modules_and_parameters]

        # Checks if a module will produce a sparse gradient.
        def produces_sparse_gradient(module):

@@ -939,10 +939,10 @@ class DistributedDataParallel(Module, Joinable):

        # Build list of booleans indicating whether or not to expect sparse
        # gradients for the corresponding parameters.
        expect_sparse_gradient = list(
        expect_sparse_gradient = [
            produces_sparse_gradient(module)
            for module, _ in modules_and_parameters
        )
        ]

        self._assign_modules_buffers()


@@ -296,7 +296,7 @@ class NamedMemberAccessor:
        Check that the given keys are valid.
        """
        keys = set(keys)
        valid_keys = set(name for name, _ in self.named_tensors(remove_duplicate=False))
        valid_keys = {name for name, _ in self.named_tensors(remove_duplicate=False)}
        missing_keys = valid_keys - keys
        unexpected_keys = keys - valid_keys
        return sorted(missing_keys), sorted(unexpected_keys)


@@ -197,7 +197,7 @@ class DiagnosticContext:

    def sarif(self) -> sarif.Run:
        """Returns the SARIF Run object."""
        unique_rules = set(diagnostic.rule for diagnostic in self.diagnostics)
        unique_rules = {diagnostic.rule for diagnostic in self.diagnostics}
        return sarif.Run(
            tool=sarif.Tool(
                driver=sarif.ToolComponent(


@@ -914,7 +914,7 @@ def verify_aten_graph(
    graph = graph.copy()

    # Execute aten graph and get reference torch jit outputs.
    graph_inputs = list(v for v in graph.inputs())
    graph_inputs = list(graph.inputs())
    jit_inputs = tuple([arg for arg in input_args if arg is not None])
    weights = [params_dict[v.debugName()] for v in graph_inputs[len(jit_inputs) :]]
    assert all([w is not None for w in weights])

@@ -940,7 +940,7 @@ def verify_aten_graph(
    # NOTE: Verification is unstable. Try catch to emit information for debugging.
    try:
        # NOTE: Input might be dce'ed, so we need to remove those from the input args.
        new_input_names = set(v.debugName() for v in graph.inputs())
        new_input_names = {v.debugName() for v in graph.inputs()}
        new_input_args = []
        for v, arg in zip(original_jit_graph.inputs(), input_args):
            if v.debugName() in new_input_names:


@@ -7919,9 +7919,7 @@ def sample_inputs_max_unpool(op_info, device, dtype, requires_grad, **kwargs):
        'nn.functional.max_unpool3d': 3
    }

    unpool_to_pool_name_dict = dict((
        (k, f'nn.functional.{v.__name__}') for k, v in unpool_name_to_pool_method_dict.items()
    ))
    unpool_to_pool_name_dict = {k: f'nn.functional.{v.__name__}' for k, v in unpool_name_to_pool_method_dict.items()}

    pool_dim = unpool_name_to_dim[op_info.name]
    pool_method = unpool_name_to_pool_method_dict[op_info.name]


@@ -507,7 +507,7 @@ def check_forward_ad_formula(op: Callable, args, kwargs, gradcheck_wrapper=None,
        if isinstance(t, torch.Tensor) and t.requires_grad:
            return torch.randn_like(t)
        elif is_tensorlist(t):
            return list(torch.randn_like(e) if e.requires_grad else None for e in t)
            return [torch.randn_like(e) if e.requires_grad else None for e in t]
        return None

    tangent_args = tuple(maybe_tangent(arg) for arg in args)