[BE] Enable C419 rule for any/all short-circuiting (#99890)

Apparently https://github.com/pytorch/pytorch/pull/78142 made TorchScript (torch.jit) accept simple generator expressions, which lets us enable the C419 rule that replaces unnecessary list comprehensions inside any()/all() calls with generators. This was originally part of #99280, but I split it off into this PR so that it can be easily reverted should anything break.
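
To illustrate the rule (a minimal sketch, not part of this diff; `is_large` and `values` are made-up names): `any()`/`all()` over a list comprehension must build the whole list before the check runs, whereas a generator expression lets them short-circuit at the first decisive element.

```python
# Hypothetical illustration of the C419 rewrite; names are made up.
values = range(1_000_000)

def is_large(x: int) -> bool:
    return x > 10

# Flagged by C419: the full list is materialized before any() inspects it.
found = any([is_large(v) for v in values])

# Preferred form: any() consumes the generator lazily and stops at the
# first True, so the remaining elements are never evaluated.
found_lazy = any(is_large(v) for v in values)

assert found == found_lazy
```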

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99890
Approved by: https://github.com/justinchuby, https://github.com/kit1980, https://github.com/malfet
Aaron Gokaslan 2023-04-25 15:02:13 +00:00 committed by PyTorch MergeBot
parent e43918b93a
commit e2a3817dfd
83 changed files with 175 additions and 203 deletions
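
Most of the 83 files below apply that same mechanical rewrite. As background for the TorchScript point in the description, a minimal, hypothetical sketch (the function is made up, and it assumes simple generator expressions inside `all()` are scriptable after #78142):

```python
import torch
from typing import List

@torch.jit.script
def all_positive(xs: List[int]) -> bool:
    # Before generator-expression support landed, this had to be written as
    # all([x > 0 for x in xs]) to script; the generator form is assumed to
    # compile now, which is what allows C419 to be enforced repo-wide.
    return all(x > 0 for x in xs)
```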

View File

@ -13,5 +13,5 @@ for line in lines:
'CheckSymbolExists.c', 'CheckSymbolExists.c',
'test_compilation_error_formatting', 'test_compilation_error_formatting',
] ]
if all([keyword not in line for keyword in ignored_keywords]): if all(keyword not in line for keyword in ignored_keywords):
print(line) print(line)

View File

@ -16,7 +16,7 @@ ignore =
# these ignores are from flake8-bugbear; please fix! # these ignores are from flake8-bugbear; please fix!
B007,B008,B017,B019,B020,B023,B024,B026,B027,B028,B903,B904,B905,B906,B907 B007,B008,B017,B019,B020,B023,B024,B026,B027,B028,B903,B904,B905,B906,B907
# these ignores are from flake8-comprehensions; please fix! # these ignores are from flake8-comprehensions; please fix!
C407,C419, C407,
# these ignores are from flake8-logging-format; please fix! # these ignores are from flake8-logging-format; please fix!
G100,G101,G200,G201,G202 G100,G101,G200,G201,G202
# these ignores are from flake8-simplify. please fix or ignore with commented reason # these ignores are from flake8-simplify. please fix or ignore with commented reason

View File

@ -364,7 +364,7 @@ class TestTryMerge(TestCase):
lint_checks = [name for name in conclusions.keys() if "Lint" in name] lint_checks = [name for name in conclusions.keys() if "Lint" in name]
self.assertTrue(len(lint_checks) > 0) self.assertTrue(len(lint_checks) > 0)
self.assertTrue( self.assertTrue(
all([conclusions[name].status == "SUCCESS" for name in lint_checks]) all(conclusions[name].status == "SUCCESS" for name in lint_checks)
) )
@mock.patch("trymerge.gh_get_pr_info", return_value=mock_gh_get_info()) @mock.patch("trymerge.gh_get_pr_info", return_value=mock_gh_get_info())

View File

@ -82,10 +82,7 @@ class FlakyRule:
and self.name in job.get("name", "") and self.name in job.get("name", "")
and job.get("failure_captures") is not None and job.get("failure_captures") is not None
and all(
-    [
-        capture in job.get("failure_captures", [])
-        for capture in self.captures
-    ]
+    capture in job.get("failure_captures", []) for capture in self.captures
)
) )
@ -1575,7 +1572,7 @@ def get_classifications(
checks_with_classifications[name] = JobCheckState( checks_with_classifications[name] = JobCheckState(
check.name, check.url, check.status, "BROKEN_TRUNK", check.job_id check.name, check.url, check.status, "BROKEN_TRUNK", check.job_id
) )
elif any([rule.matches(head_sha_job) for rule in flaky_rules]): elif any(rule.matches(head_sha_job) for rule in flaky_rules):
checks_with_classifications[name] = JobCheckState( checks_with_classifications[name] = JobCheckState(
check.name, check.url, check.status, "FLAKY", check.job_id check.name, check.url, check.status, "FLAKY", check.job_id
) )
@ -1710,11 +1707,11 @@ def categorize_checks(
relevant_checknames = [ relevant_checknames = [
name name
for name in check_runs.keys() for name in check_runs.keys()
if not required_checks or any([x in name for x in required_checks]) if not required_checks or any(x in name for x in required_checks)
] ]
for checkname in required_checks: for checkname in required_checks:
if all([checkname not in x for x in check_runs.keys()]): if all(checkname not in x for x in check_runs.keys()):
pending_checks.append((checkname, None, None)) pending_checks.append((checkname, None, None))
for checkname in relevant_checknames: for checkname in relevant_checknames:

View File

@ -1004,7 +1004,7 @@ def find_last_2_with_filenames(lookup_file, dashboard_archive_path, dtype, filen
fullpaths = [ fullpaths = [
os.path.join(dashboard_archive_path, path, name) for name in filenames os.path.join(dashboard_archive_path, path, name) for name in filenames
] ]
if all([os.path.exists(fullpath) for fullpath in fullpaths]): if all(os.path.exists(fullpath) for fullpath in fullpaths):
last2.append(output_dir) last2.append(output_dir)
if len(last2) >= 2: if len(last2) >= 2:
return last2 return last2

View File

@ -40,6 +40,6 @@ class TestConsumeOp(unittest.TestCase):
value = r(x) value = r(x)
self.assertTrue( self.assertTrue(
all([torch.allclose(t1, t2) for t1, t2 in zip(value, torch.chunk(x, 2))]) all(torch.allclose(t1, t2) for t1, t2 in zip(value, torch.chunk(x, 2)))
) )
self.assertEqual(occurance, iters) self.assertEqual(occurance, iters)

View File

@ -126,8 +126,8 @@ def cond_dense(pred, true_fn, false_fn, operands):
def cond_autograd(pred, true_fn, false_fn, *operands): def cond_autograd(pred, true_fn, false_fn, *operands):
# TODO: support autograd # TODO: support autograd
flat_operands, _ = tree_flatten([true_fn, false_fn] + [operands]) flat_operands, _ = tree_flatten([true_fn, false_fn] + [operands])
assert all([not f.requires_grad for f in flat_operands assert all(not f.requires_grad for f in flat_operands
if isinstance(f, torch.Tensor)]) if isinstance(f, torch.Tensor))
guard = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.AutogradCPU)) guard = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.AutogradCPU))
return cond(pred, true_fn, false_fn, *operands) return cond(pred, true_fn, false_fn, *operands)

View File

@ -70,8 +70,8 @@ def map_cpu(f, xs, *args):
def map_autograd(f, xs, *args): def map_autograd(f, xs, *args):
# TODO: support autograd # TODO: support autograd
flat_operands, _ = tree_flatten([f, xs, args]) flat_operands, _ = tree_flatten([f, xs, args])
assert all([not f.requires_grad for f in flat_operands assert all(not f.requires_grad for f in flat_operands
if isinstance(f, torch.Tensor)]) if isinstance(f, torch.Tensor))
_ = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.AutogradCPU)) _ = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.AutogradCPU))
return map(f, xs, *args) return map(f, xs, *args)

View File

@ -36,7 +36,6 @@ ignore = [
"B027", "B904", "B905", "B027", "B904", "B905",
"E402", "E402",
"C408", # C408 ignored because we like the dict keyword argument syntax "C408", # C408 ignored because we like the dict keyword argument syntax
"C419", # generators may not be supported by jit
"E501", # E501 is not flexible enough, we're using B950 instead "E501", # E501 is not flexible enough, we're using B950 instead
"E721", "E721",
"E731", # Assign lambda expression "E731", # Assign lambda expression

View File

@ -265,7 +265,7 @@ class CommitList:
else: else:
# Below are some extra quick checks that aren't necessarily file-path related, # Below are some extra quick checks that aren't necessarily file-path related,
# but I found that to catch a decent number of extra commits. # but I found that to catch a decent number of extra commits.
if len(files_changed) > 0 and all([f_name.endswith(('.cu', '.cuh')) for f_name in files_changed]): if len(files_changed) > 0 and all(f_name.endswith(('.cu', '.cuh')) for f_name in files_changed):
category = 'cuda' category = 'cuda'
elif '[PyTorch Edge]' in title: elif '[PyTorch Edge]' in title:
category = 'mobile' category = 'mobile'

View File

@ -134,7 +134,7 @@ class TestFSDPStateDict(FSDPTest):
self._state_compare(model, model_new, assert_fn) self._state_compare(model, model_new, assert_fn)
if check_buffers: if check_buffers:
has_buffers = any( has_buffers = any(
[len(list(m.buffers())) for m in (model, model_new)] len(list(m.buffers())) for m in (model, model_new)
) )
if has_buffers: if has_buffers:
self._state_compare( self._state_compare(

View File

@ -406,7 +406,7 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
explain_out = torch._dynamo.explain(ddp_m, inputs) explain_out = torch._dynamo.explain(ddp_m, inputs)
break_reasons = explain_out[4] break_reasons = explain_out[4]
self.assertEqual(len(break_reasons), 3) self.assertEqual(len(break_reasons), 3)
self.assertTrue(all(["DDPOptimizer" in r.reason for r in break_reasons])) self.assertTrue(all("DDPOptimizer" in r.reason for r in break_reasons))
@patch.object(config, "optimize_ddp", True) @patch.object(config, "optimize_ddp", True)
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")

View File

@ -554,7 +554,7 @@ class LazyModuleWithListInput(torch.nn.Module):
def requires_grad1(module: torch.nn.Module, recurse: bool = False) -> bool: def requires_grad1(module: torch.nn.Module, recurse: bool = False) -> bool:
requires_grad = any([p.requires_grad for p in module.parameters(recurse)]) requires_grad = any(p.requires_grad for p in module.parameters(recurse))
return requires_grad return requires_grad

View File

@ -138,7 +138,7 @@ def shapes_to_tensor(x, device=None):
return torch.as_tensor(x, device=device) return torch.as_tensor(x, device=device)
if torch.jit.is_tracing(): if torch.jit.is_tracing():
assert all( assert all(
[isinstance(t, torch.Tensor) for t in x] isinstance(t, torch.Tensor) for t in x
), "Shape should be tensor during tracing!" ), "Shape should be tensor during tracing!"
# as_tensor should not be used in tracing because it records a constant # as_tensor should not be used in tracing because it records a constant
ret = torch.stack(x) ret = torch.stack(x)

View File

@ -662,12 +662,12 @@ class Operator:
def any_opinfo_attr(self, attr): def any_opinfo_attr(self, attr):
if not self.has_opinfo(): if not self.has_opinfo():
raise RuntimeError() raise RuntimeError()
return any([getattr(opinfo, attr) for opinfo in self.opinfos]) return any(getattr(opinfo, attr) for opinfo in self.opinfos)
def all_opinfo_attr(self, attr): def all_opinfo_attr(self, attr):
if not self.has_opinfo(): if not self.has_opinfo():
raise RuntimeError() raise RuntimeError()
return all([getattr(opinfo, attr) for opinfo in self.opinfos]) return all(getattr(opinfo, attr) for opinfo in self.opinfos)
def supports_vjp(self): def supports_vjp(self):
if self.name in FACTORY_FNS: if self.name in FACTORY_FNS:

View File

@ -2030,7 +2030,7 @@ class TestPartitioning(AOTTestCase):
self.assertEqual(get_num_ins_outs(fw_graph), (3, 4)) self.assertEqual(get_num_ins_outs(fw_graph), (3, 4))
self.assertEqual(get_num_ins_outs(bw_graph), (4, 3)) self.assertEqual(get_num_ins_outs(bw_graph), (4, 3))
_, outs = get_ins_outs(fw_graph) _, outs = get_ins_outs(fw_graph)
self.assertTrue(all([is_sym_node(n) for n in outs[1:]])) self.assertTrue(all(is_sym_node(n) for n in outs[1:]))
def test_default_partitioner_output_tensor_shape_tensor(self): def test_default_partitioner_output_tensor_shape_tensor(self):
@ -2648,7 +2648,7 @@ def _test_aot_autograd_forwards_backwards_helper(self, f, compiled_f, args):
reset_grads() reset_grads()
# See https://github.com/pytorch/pytorch/pull/98960#issuecomment-1505962215 # See https://github.com/pytorch/pytorch/pull/98960#issuecomment-1505962215
if all([x is None for x in orig_grad]): if all(x is None for x in orig_grad):
with self.assertRaisesRegex(RuntimeError, 'does not require grad and does not have a grad_fn'): with self.assertRaisesRegex(RuntimeError, 'does not require grad and does not have a grad_fn'):
call_forwards_backwards(compiled_f) call_forwards_backwards(compiled_f)
else: else:
@ -2669,7 +2669,7 @@ def _test_aot_autograd_forwards_backwards_helper(self, f, compiled_f, args):
reset_grads() reset_grads()
# See https://github.com/pytorch/pytorch/pull/98960#issuecomment-1505962215 # See https://github.com/pytorch/pytorch/pull/98960#issuecomment-1505962215
if all([x is None for x in orig_grad]): if all(x is None for x in orig_grad):
with self.assertRaisesRegex(RuntimeError, 'does not require grad and does not have a grad_fn'): with self.assertRaisesRegex(RuntimeError, 'does not require grad and does not have a grad_fn'):
call_forwards_backwards(compiled_f) call_forwards_backwards(compiled_f)
else: else:

View File

@ -131,7 +131,7 @@ class TestControlFlowTraced(TestCase):
if node.op == "call_function": if node.op == "call_function":
all_ops_in_true_branch.append(node.target) all_ops_in_true_branch.append(node.target)
self.assertFalse(any([op._schema.is_mutable for op in all_ops_in_true_branch])) self.assertFalse(any(op._schema.is_mutable for op in all_ops_in_true_branch))
graph_module = make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(*example_inputs) graph_module = make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(*example_inputs)
self.assertEqual(graph_module(*example_inputs), f(*example_inputs)) self.assertEqual(graph_module(*example_inputs), f(*example_inputs))
@ -188,7 +188,7 @@ class TestControlFlowTraced(TestCase):
if node.op == "call_function": if node.op == "call_function":
all_ops.append(node.target) all_ops.append(node.target)
self.assertFalse(any([op._schema.is_mutable for op in all_ops])) self.assertFalse(any(op._schema.is_mutable for op in all_ops))
def test_cond_functionalized_data_dependent_pred(self): def test_cond_functionalized_data_dependent_pred(self):
def true_fn(x): def true_fn(x):

View File

@ -101,9 +101,9 @@ def diff_arg(arg, requires_grad=True):
else: else:
return arg.is_floating_point() or arg.is_complex() return arg.is_floating_point() or arg.is_complex()
if is_iterable_of_tensors(arg): if is_iterable_of_tensors(arg):
if all([is_differentiable_arg(a) for a in arg]): if all(is_differentiable_arg(a) for a in arg):
return True return True
if all([not is_differentiable_arg(a) for a in arg]): if all(not is_differentiable_arg(a) for a in arg):
return False return False
raise RuntimeError("NYI: The test runner can't handle this") raise RuntimeError("NYI: The test runner can't handle this")
return isinstance(arg, Tensor) and is_differentiable_arg(arg) return isinstance(arg, Tensor) and is_differentiable_arg(arg)

View File

@ -339,7 +339,7 @@ def check_model(
correct_grad = compute_grads(ref_inputs, ref_kwargs, correct, grads) correct_grad = compute_grads(ref_inputs, ref_kwargs, correct, grads)
flat_grads, _ = tree_flatten(correct_grad) flat_grads, _ = tree_flatten(correct_grad)
all_none_grads = all([x is None for x in flat_grads]) all_none_grads = all(x is None for x in flat_grads)
if all_none_grads: if all_none_grads:
# See Note [Detaching inputs that never need gradients] # See Note [Detaching inputs that never need gradients]
# There are a handful of ops that can return None gradients, into of zero gradients. # There are a handful of ops that can return None gradients, into of zero gradients.

View File

@ -106,7 +106,7 @@ class TestProfilerCUDA(TestCase):
# with CUDA events leaking the increase in memory was ~7 MB between # with CUDA events leaking the increase in memory was ~7 MB between
# profiler invocations above # profiler invocations above
is_increasing = all( is_increasing = all(
[last_rss[idx] > last_rss[idx - 1] for idx in range(1, len(last_rss))]) last_rss[idx] > last_rss[idx - 1] for idx in range(1, len(last_rss)))
max_diff = -1 max_diff = -1
for idx in range(1, len(last_rss)): for idx in range(1, len(last_rss)):
max_diff = max(max_diff, last_rss[idx] - last_rss[idx - 1]) max_diff = max(max_diff, last_rss[idx] - last_rss[idx - 1])
@ -530,11 +530,11 @@ class TestProfiler(TestCase):
for e in p.function_events: for e in p.function_events:
if "aten::add" in e.name or "AddBackward" in e.name: if "aten::add" in e.name or "AddBackward" in e.name:
self.assertTrue(any(["test_profiler" in entry for entry in e.stack])) self.assertTrue(any("test_profiler" in entry for entry in e.stack))
self.assertTrue(any([( self.assertTrue(any((
"test_source" in entry or "test_source" in entry or
"ts_method_1" in entry or "ts_method_1" in entry or
"ts_method_2" in entry) for entry in e.stack])) "ts_method_2" in entry) for entry in e.stack))
# TODO: https://github.com/pytorch/kineto/issues/617 # TODO: https://github.com/pytorch/kineto/issues/617
if kineto_available() and not IS_WINDOWS: if kineto_available() and not IS_WINDOWS:
@ -1429,7 +1429,7 @@ class TestProfiler(TestCase):
f_ts_1 = flow_f_to_ts[1] f_ts_1 = flow_f_to_ts[1]
s_ts_2 = flow_s_to_ts[2] s_ts_2 = flow_s_to_ts[2]
f_ts_2 = flow_f_to_ts[2] f_ts_2 = flow_f_to_ts[2]
self.assertTrue(all([ts in ts_to_name.keys() for ts in [s_ts_1, f_ts_1, s_ts_2, f_ts_2]])) self.assertTrue(all(ts in ts_to_name.keys() for ts in [s_ts_1, f_ts_1, s_ts_2, f_ts_2]))
self.assertTrue(ts_to_name[s_ts_1] == "aten::binary_cross_entropy_with_logits") self.assertTrue(ts_to_name[s_ts_1] == "aten::binary_cross_entropy_with_logits")
self.assertTrue(ts_to_name[s_ts_2] == "aten::add") self.assertTrue(ts_to_name[s_ts_2] == "aten::add")

View File

@ -2218,7 +2218,7 @@ class TestQuantizedOps(TestCase):
test_cases = itertools.product(scale_list, zero_point_list, shapes, dtypes, dims) test_cases = itertools.product(scale_list, zero_point_list, shapes, dtypes, dims)
op = torch.mean op = torch.mean
for scale, zp, shape, dtype, dim in test_cases: for scale, zp, shape, dtype, dim in test_cases:
if not all([d < len(shape) for d in dim]): if not all(d < len(shape) for d in dim):
continue continue
X = torch.randn(*shape) * 10 X = torch.randn(*shape) * 10
qX = torch.quantize_per_tensor(X, scale, zp, dtype) qX = torch.quantize_per_tensor(X, scale, zp, dtype)
@ -2257,7 +2257,7 @@ class TestQuantizedOps(TestCase):
dtypes, dims, unbiased_list, keep_dim_list) dtypes, dims, unbiased_list, keep_dim_list)
op = torch.std op = torch.std
for scale, zp, shape, dtype, dim, unbiased, keep_dim in test_cases: for scale, zp, shape, dtype, dim, unbiased, keep_dim in test_cases:
if not all([d < len(shape) for d in dim]): if not all(d < len(shape) for d in dim):
continue continue
X = torch.randn(*shape) * 10 X = torch.randn(*shape) * 10
qX = torch.quantize_per_tensor(X, scale, zp, dtype) qX = torch.quantize_per_tensor(X, scale, zp, dtype)

View File

@ -4994,7 +4994,7 @@ class TestQuantizeFx(QuantizationTestCase):
continue continue
matched_names.add(node.name) matched_names.add(node.name)
# Match preceding dequantize # Match preceding dequantize
self.assertTrue(all([arg.target == "dequantize" for arg in node.args])) self.assertTrue(all(arg.target == "dequantize" for arg in node.args))
# Match following quantize with the specific qparams and dtypes # Match following quantize with the specific qparams and dtypes
expected_scale, expected_zp, expected_dtype = node_name_to_expected_quantize_args[node.name] expected_scale, expected_zp, expected_dtype = node_name_to_expected_quantize_args[node.name]
for user in node.users.keys(): for user in node.users.keys():
@ -6513,7 +6513,7 @@ class TestQuantizeFx(QuantizationTestCase):
# Match preceding dequantize # Match preceding dequantize
self.assertTrue(len(node.args) == 1 or len(node.args) == 2) self.assertTrue(len(node.args) == 1 or len(node.args) == 2)
self.assertTrue(all([arg.target == "dequantize" for arg in node.args])) self.assertTrue(all(arg.target == "dequantize" for arg in node.args))
# Match following quantize with the specific dtypes # Match following quantize with the specific dtypes
self.assertEqual(len(node.users), 1) self.assertEqual(len(node.users), 1)

View File

@ -1625,7 +1625,7 @@ except RuntimeError as e:
'HSA_STATUS_ERROR_EXCEPTION', # ROCm 'HSA_STATUS_ERROR_EXCEPTION', # ROCm
'Device-side assertion' # ROCm 'Device-side assertion' # ROCm
] ]
self.assertTrue(any([msg in out or msg in err for msg in expected_messages])) self.assertTrue(any(msg in out or msg in err for msg in expected_messages))
@slowTest @slowTest
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support device side asserts") @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support device side asserts")

View File

@ -59,9 +59,9 @@ def diff_arg(arg, requires_grad=True):
return arg.is_floating_point() or arg.is_complex() return arg.is_floating_point() or arg.is_complex()
if is_iterable_of_tensors(arg): if is_iterable_of_tensors(arg):
if all([is_differentiable_arg(a) for a in arg]): if all(is_differentiable_arg(a) for a in arg):
return True return True
if all([not is_differentiable_arg(a) for a in arg]): if all(not is_differentiable_arg(a) for a in arg):
return False return False
raise RuntimeError("NYI: The test runner can't handle this") raise RuntimeError("NYI: The test runner can't handle this")
return isinstance(arg, Tensor) and is_differentiable_arg(arg) return isinstance(arg, Tensor) and is_differentiable_arg(arg)

View File

@ -2756,7 +2756,7 @@ class TestFX(JitTestCase):
mod_false = symbolic_trace(mod, concrete_args={'y': False}) mod_false = symbolic_trace(mod, concrete_args={'y': False})
self.assertEqual(mod_true(3, True), 6) self.assertEqual(mod_true(3, True), 6)
print(mod_true.code) print(mod_true.code)
assert(any([i.target == torch._assert for i in mod_true.graph.nodes])) assert(any(i.target == torch._assert for i in mod_true.graph.nodes))
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
mod_true(3, False) mod_true(3, False)
self.assertEqual(mod_false(3, False), 3) self.assertEqual(mod_false(3, False), 3)

View File

@ -449,7 +449,7 @@ class TestJit(JitTestCase):
buf.seek(0) buf.seek(0)
files = zipfile.ZipFile(buf).filelist files = zipfile.ZipFile(buf).filelist
self.assertTrue(any(['archive/constants.pkl' == f.filename for f in files])) self.assertTrue(any('archive/constants.pkl' == f.filename for f in files))
def test_script_fn_pkl(self): def test_script_fn_pkl(self):
with self.assertRaisesRegex(pickle.PickleError, "ScriptFunction cannot be pickled"): with self.assertRaisesRegex(pickle.PickleError, "ScriptFunction cannot be pickled"):

View File

@ -858,7 +858,7 @@ class TestJitTraceAutocast(JitTestCase):
with torch.cpu.amp.autocast(enabled=True): with torch.cpu.amp.autocast(enabled=True):
self.assertEqual(fn_s(x), fn(x)) self.assertEqual(fn_s(x), fn(x))
self.assertTrue(any(["is_autocast_cpu_enabled" in x.kind() for x in fn_s.graph.nodes()])) self.assertTrue(any("is_autocast_cpu_enabled" in x.kind() for x in fn_s.graph.nodes()))
@unittest.skipIf(not TEST_CUDA, "No cuda") @unittest.skipIf(not TEST_CUDA, "No cuda")
def test_script_autocast_cuda(self): def test_script_autocast_cuda(self):
@ -877,7 +877,7 @@ class TestJitTraceAutocast(JitTestCase):
with torch.cuda.amp.autocast(enabled=True): with torch.cuda.amp.autocast(enabled=True):
self.assertEqual(fn_s(x), fn(x)) self.assertEqual(fn_s(x), fn(x))
self.assertTrue(any(["is_autocast_enabled" in x.kind() for x in fn_s.graph.nodes()])) self.assertTrue(any("is_autocast_enabled" in x.kind() for x in fn_s.graph.nodes()))
def test_scripted_aliasing(self): def test_scripted_aliasing(self):

View File

@ -158,7 +158,7 @@ class TestJit(JitCommonTestCase):
# right now, tuple of outputs and tensor output supported # right now, tuple of outputs and tensor output supported
# TODO: list of tensor outputs # TODO: list of tensor outputs
tuple_of_tensors = isinstance(out, tuple) and all([isinstance(elem, torch.Tensor) for elem in out]) tuple_of_tensors = isinstance(out, tuple) and all(isinstance(elem, torch.Tensor) for elem in out)
if isinstance(out, torch.Tensor) or tuple_of_tensors: if isinstance(out, torch.Tensor) or tuple_of_tensors:
if tuple_of_tensors: if tuple_of_tensors:

View File

@ -428,8 +428,8 @@ def forward(self, x_1):
traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3)) traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3))
self.assertTrue(all([isinstance(node.target, torch._ops.OpOverload) self.assertTrue(all(isinstance(node.target, torch._ops.OpOverload)
for node in traced.graph.nodes if node.op == 'call_function'])) for node in traced.graph.nodes if node.op == 'call_function'))
def test_tensor_constants(self): def test_tensor_constants(self):
def f(): def f():

View File

@ -764,11 +764,11 @@ class TestTensorBoardFigure(BaseTestCase):
figures.append(figure) figures.append(figure)
writer.add_figure("add_figure/figure_list", figures, 0, close=False) writer.add_figure("add_figure/figure_list", figures, 0, close=False)
self.assertTrue(all([plt.fignum_exists(figure.number) is True for figure in figures])) # noqa: F812 self.assertTrue(all(plt.fignum_exists(figure.number) is True for figure in figures)) # noqa: F812
writer.add_figure("add_figure/figure_list", figures, 1) writer.add_figure("add_figure/figure_list", figures, 1)
if matplotlib.__version__ != '3.3.0': if matplotlib.__version__ != '3.3.0':
self.assertTrue(all([plt.fignum_exists(figure.number) is False for figure in figures])) # noqa: F812 self.assertTrue(all(plt.fignum_exists(figure.number) is False for figure in figures)) # noqa: F812
else: else:
print("Skipping fignum_exists, see https://github.com/matplotlib/matplotlib/issues/18163") print("Skipping fignum_exists, see https://github.com/matplotlib/matplotlib/issues/18163")

View File

@ -619,7 +619,7 @@ def create_differentiability_info(
output_differentiability = defn_dict.pop("output_differentiability", None) output_differentiability = defn_dict.pop("output_differentiability", None)
output_differentiability_conditions = None output_differentiability_conditions = None
if output_differentiability and any( if output_differentiability and any(
[isinstance(diff, str) for diff in output_differentiability] isinstance(diff, str) for diff in output_differentiability
): ):
if len(output_differentiability) != 1: if len(output_differentiability) != 1:
raise RuntimeError( raise RuntimeError(

View File

@ -57,7 +57,7 @@ def is_intrested_file(
file_path: str, interested_folders: List[str], platform: TestPlatform file_path: str, interested_folders: List[str], platform: TestPlatform
) -> bool: ) -> bool:
ignored_patterns = ["cuda", "aten/gen_aten", "aten/aten_", "build/"] ignored_patterns = ["cuda", "aten/gen_aten", "aten/aten_", "build/"]
if any([pattern in file_path for pattern in ignored_patterns]): if any(pattern in file_path for pattern in ignored_patterns):
return False return False
# ignore files that are not belong to pytorch # ignore files that are not belong to pytorch

View File

@ -13,7 +13,7 @@ IS_LINUX = platform.system() == "Linux"
IS_CONDA = ( IS_CONDA = (
"conda" in sys.version "conda" in sys.version
or "Continuum" in sys.version or "Continuum" in sys.version
or any([x.startswith("CONDA") for x in os.environ]) or any(x.startswith("CONDA") for x in os.environ)
) )
CONDA_DIR = os.path.join(os.path.dirname(sys.executable), "..") CONDA_DIR = os.path.join(os.path.dirname(sys.executable), "..")

View File

@ -72,7 +72,7 @@ if sys.platform == 'win32':
dll_paths = list(filter(os.path.exists, [th_dll_path, py_dll_path, base_py_dll_path])) dll_paths = list(filter(os.path.exists, [th_dll_path, py_dll_path, base_py_dll_path]))
if all([not os.path.exists(os.path.join(p, 'nvToolsExt64_1.dll')) for p in dll_paths]): if all(not os.path.exists(os.path.join(p, 'nvToolsExt64_1.dll')) for p in dll_paths):
nvtoolsext_dll_path = os.path.join( nvtoolsext_dll_path = os.path.join(
os.getenv('NVTOOLSEXT_PATH', os.path.join(pfiles_path, 'NVIDIA Corporation', 'NvToolsExt')), 'bin', 'x64') os.getenv('NVTOOLSEXT_PATH', os.path.join(pfiles_path, 'NVIDIA Corporation', 'NvToolsExt')), 'bin', 'x64')
else: else:
@ -80,7 +80,7 @@ if sys.platform == 'win32':
from .version import cuda as cuda_version from .version import cuda as cuda_version
import glob import glob
if cuda_version and all([not glob.glob(os.path.join(p, 'cudart64*.dll')) for p in dll_paths]): if cuda_version and all(not glob.glob(os.path.join(p, 'cudart64*.dll')) for p in dll_paths):
cuda_version_1 = cuda_version.replace('.', '_') cuda_version_1 = cuda_version.replace('.', '_')
cuda_path_var = 'CUDA_PATH_V' + cuda_version_1 cuda_path_var = 'CUDA_PATH_V' + cuda_version_1
default_path = os.path.join(pfiles_path, 'NVIDIA GPU Computing Toolkit', 'CUDA', 'v' + cuda_version) default_path = os.path.join(pfiles_path, 'NVIDIA GPU Computing Toolkit', 'CUDA', 'v' + cuda_version)

View File

@ -159,20 +159,20 @@ def has_tensor_in_frame(frame):
seen_ids[obj_id] = True seen_ids[obj_id] = True
return seen_ids[obj_id] return seen_ids[obj_id]
elif istype(obj, (list, tuple)): elif istype(obj, (list, tuple)):
seen_ids[obj_id] = any([has_tensor(v) for v in obj]) seen_ids[obj_id] = any(has_tensor(v) for v in obj)
return seen_ids[obj_id] return seen_ids[obj_id]
elif istype(obj, dict): elif istype(obj, dict):
# Some packages like pytest can be updated during runtime. So, make a # Some packages like pytest can be updated during runtime. So, make a
# copy of values to avoid issues like "RuntimeError: dictionary # copy of values to avoid issues like "RuntimeError: dictionary
# changed size during iteration" # changed size during iteration"
values = list(obj.values()) values = list(obj.values())
seen_ids[obj_id] = any([has_tensor(v) for v in values]) seen_ids[obj_id] = any(has_tensor(v) for v in values)
return seen_ids[obj_id] return seen_ids[obj_id]
elif istype(obj, (str, int, float, type(None), bool)): elif istype(obj, (str, int, float, type(None), bool)):
seen_ids[obj_id] = False seen_ids[obj_id] = False
return seen_ids[obj_id] return seen_ids[obj_id]
elif is_namedtuple(obj): elif is_namedtuple(obj):
seen_ids[obj_id] = any([has_tensor(getattr(obj, v)) for v in obj._fields]) seen_ids[obj_id] = any(has_tensor(getattr(obj, v)) for v in obj._fields)
return seen_ids[obj_id] return seen_ids[obj_id]
else: else:
# if config.debug: # if config.debug:

View File

@ -99,7 +99,7 @@ class ExecutionRecorder:
@classmethod @classmethod
def _is_excl(cls, mod): def _is_excl(cls, mod):
return any([mod.__name__ == excl for excl in cls.MOD_EXCLUDES]) return any(mod.__name__ == excl for excl in cls.MOD_EXCLUDES)
# Convert ModuleRecords -> DummyModule tree # Convert ModuleRecords -> DummyModule tree
@classmethod @classmethod

View File

@ -411,7 +411,7 @@ def inductor_fails(fx_g, args, check_str=None):
try: try:
result = fx_g(*args) result = fx_g(*args)
assert isinstance(result, (tuple, list)) assert isinstance(result, (tuple, list))
assert not any([isinstance(x, (tuple, list)) for x in result]) assert not any(isinstance(x, (tuple, list)) for x in result)
except Exception: except Exception:
return False return False

View File

@ -102,7 +102,7 @@ def requires_bwd_pass(out):
if isinstance(out, torch.Tensor): if isinstance(out, torch.Tensor):
return out.requires_grad return out.requires_grad
elif isinstance(out, (list, tuple)): elif isinstance(out, (list, tuple)):
return any([requires_bwd_pass(x) for x in out]) return any(requires_bwd_pass(x) for x in out)
elif out is None: elif out is None:
return False return False
elif isinstance(out, int): elif isinstance(out, int):

View File

@ -582,7 +582,7 @@ class VariableBuilder:
if ( if (
istype(value, (tuple, list)) istype(value, (tuple, list))
and all( and all(
[isinstance(x, int) or is_numpy_int_type(x) or x is None for x in value] isinstance(x, int) or is_numpy_int_type(x) or x is None for x in value
) )
and not config.dynamic_shapes and not config.dynamic_shapes
): ):
@ -1058,7 +1058,7 @@ def wrap_fx_proxy_cls(
else: else:
return ConstantVariable(example_value, **options) return ConstantVariable(example_value, **options)
elif istype(example_value, torch.Size) and all( elif istype(example_value, torch.Size) and all(
[isinstance(x, int) for x in example_value] isinstance(x, int) for x in example_value
): ):
sizes = [ConstantVariable(x) for x in example_value] sizes = [ConstantVariable(x) for x in example_value]
return SizeVariable(sizes, **options) return SizeVariable(sizes, **options)

View File

@ -491,7 +491,7 @@ class BuiltinVariable(VariableTracker):
fn, fn,
*proxy_args_kwargs(args, kwargs), *proxy_args_kwargs(args, kwargs),
) )
if any([isinstance(arg, FakeItemVariable) for arg in args]): if any(isinstance(arg, FakeItemVariable) for arg in args):
return wrap_fx_proxy_cls( return wrap_fx_proxy_cls(
FakeItemVariable, FakeItemVariable,
tx, tx,
@ -664,7 +664,7 @@ class BuiltinVariable(VariableTracker):
) )
for i in [a, b] for i in [a, b]
): ):
if any([isinstance(val, FakeItemVariable) for val in [a, b]]): if any(isinstance(val, FakeItemVariable) for val in [a, b]):
return variables.FakeItemVariable.from_tensor_variable(result) return variables.FakeItemVariable.from_tensor_variable(result)
if b.is_python_constant(): if b.is_python_constant():
@ -724,8 +724,8 @@ class BuiltinVariable(VariableTracker):
return None return None
def _dynamic_args(self, *args, **kwargs): def _dynamic_args(self, *args, **kwargs):
return any([isinstance(x, SymNodeVariable) for x in args]) or any( return any(isinstance(x, SymNodeVariable) for x in args) or any(
[isinstance(x, SymNodeVariable) for x in kwargs.values()] isinstance(x, SymNodeVariable) for x in kwargs.values()
) )
def call_slice(self, tx, *args): def call_slice(self, tx, *args):
@ -942,11 +942,9 @@ class BuiltinVariable(VariableTracker):
if ( if (
isinstance(seq, (variables.ListVariable, variables.TupleVariable)) isinstance(seq, (variables.ListVariable, variables.TupleVariable))
and all(
-    [
-        isinstance(x, variables.ConstantVariable)
-        and isinstance(x.value, (int, float))
-        for x in seq.items
-    ]
+    isinstance(x, variables.ConstantVariable)
+    and isinstance(x.value, (int, float))
+    for x in seq.items
)
and not kwargs and not kwargs
): ):

View File

@ -86,7 +86,7 @@ class ConstantVariable(VariableTracker):
items=self.unpack_var_sequence(tx), source=self.source, **options items=self.unpack_var_sequence(tx), source=self.source, **options
).call_method(tx, name, args, kwargs) ).call_method(tx, name, args, kwargs)
if any([isinstance(x, SymNodeVariable) for x in args]): if any(isinstance(x, SymNodeVariable) for x in args):
# Promote to SymNodeVariable for operations involving dynamic shapes. # Promote to SymNodeVariable for operations involving dynamic shapes.
return variables.SymNodeVariable(self.as_proxy(), self.value).call_method( return variables.SymNodeVariable(self.as_proxy(), self.value).call_method(
tx, name, args, kwargs tx, name, args, kwargs

View File

@ -112,7 +112,7 @@ class TensorVariable(VariableTracker):
return self.dtype in dtypes return self.dtype in dtypes
if type(tensor_type) is tuple: if type(tensor_type) is tuple:
return any([check_type(ty) for ty in tensor_type]) return any(check_type(ty) for ty in tensor_type)
else: else:
return check_type(tensor_type) return check_type(tensor_type)

View File

@ -472,16 +472,10 @@ class TorchVariable(VariableTracker):
tx, [args[0], result], {} tx, [args[0], result], {}
) )
else: else:
-any_symints_or_symfloats = any(
-    [isinstance(x, SymNodeVariable) for x in args]
-)
+any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args)
all_ints_or_floats = all(
-    [
-        isinstance(
-            x, (variables.ConstantVariable, variables.SymNodeVariable)
-        )
-        for x in args
-    ]
+    isinstance(x, (variables.ConstantVariable, variables.SymNodeVariable))
+    for x in args
)
bin_ops = {"add", "sub", "mul", "div", "sqrt"} bin_ops = {"add", "sub", "mul", "div", "sqrt"}
if ( if (
@ -571,7 +565,7 @@ For now, dynamo will explicitly graph break when it encounters user code with th
# Ideally, we would be able to do this at ctor time, but alas we need a combination # Ideally, we would be able to do this at ctor time, but alas we need a combination
# of value + args to determine this. # of value + args to determine this.
fn_ = self.value fn_ = self.value
if any([isinstance(x, SymNodeVariable) for x in args]): if any(isinstance(x, SymNodeVariable) for x in args):
if self.value == math.sqrt: if self.value == math.sqrt:
from torch.fx.experimental.symbolic_shapes import sym_sqrt from torch.fx.experimental.symbolic_shapes import sym_sqrt

View File

@ -2587,16 +2587,14 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Any], aot_config: AOTConfig,
num_forward_returns:-num_symints_saved_for_bw num_forward_returns:-num_symints_saved_for_bw
] ]
assert all( assert all(
[isinstance(x, torch.Tensor) for x in tensors_saved_for_backwards] isinstance(x, torch.Tensor) for x in tensors_saved_for_backwards
) )
# See Note [Detaching saved tensors in AOTAutograd] # See Note [Detaching saved tensors in AOTAutograd]
ctx.save_for_backward(*(x.detach() if x._is_view() else x for x in tensors_saved_for_backwards)) ctx.save_for_backward(*(x.detach() if x._is_view() else x for x in tensors_saved_for_backwards))
symint_outs = fw_outs[-num_symints_saved_for_bw:] symint_outs = fw_outs[-num_symints_saved_for_bw:]
assert all(
-    [
-        isinstance(x, (int, float, torch.SymInt, torch.SymFloat))
-        for x in symint_outs
-    ]
+    isinstance(x, (int, float, torch.SymInt, torch.SymFloat))
+    for x in symint_outs
)
ctx.symints = symint_outs ctx.symints = symint_outs
else: else:
@ -2904,7 +2902,7 @@ def create_aot_dispatcher_function(
fake_flat_args = process_inputs(flat_args) fake_flat_args = process_inputs(flat_args)
needs_autograd = ( needs_autograd = (
any([x.requires_grad for x in fake_flat_args if isinstance(x, Tensor)]) any(x.requires_grad for x in fake_flat_args if isinstance(x, Tensor))
and torch.is_grad_enabled() and torch.is_grad_enabled()
) )
with enable_python_dispatcher(): with enable_python_dispatcher():

View File

@ -1414,7 +1414,7 @@ def get_block_addrs(pool_id, live_only=True):
def check_memory_pool(pool_id, live_storages_ptrs: List[StorageWeakRefWrapper]): def check_memory_pool(pool_id, live_storages_ptrs: List[StorageWeakRefWrapper]):
assert all([isinstance(elem, StorageWeakRefWrapper) for elem in live_storages_ptrs]) assert all(isinstance(elem, StorageWeakRefWrapper) for elem in live_storages_ptrs)
gc.collect() gc.collect()
unique_storages = {stor.data_ptr() for stor in live_storages_ptrs if stor()} unique_storages = {stor.data_ptr() for stor in live_storages_ptrs if stor()}

View File

@ -127,8 +127,8 @@ def create_fx_from_snodes(snodes: List[BaseSchedulerNode]) -> fx.Graph:
def in_output(snode): def in_output(snode):
if isinstance(snode, FusedSchedulerNode): if isinstance(snode, FusedSchedulerNode):
return any([in_output(x) for x in snode.snodes]) return any(in_output(x) for x in snode.snodes)
return any([isinstance(user.node, OutputNode) for user in snode.users]) return any(isinstance(user.node, OutputNode) for user in snode.users)
if in_output(snode): if in_output(snode):
outputs.append(fx_node) outputs.append(fx_node)

View File

@ -625,7 +625,7 @@ class Reduction(Loops):
indices = [] indices = []
changed = False changed = False
for md in sorted(read_writes.reads, key=lambda x: x.name): for md in sorted(read_writes.reads, key=lambda x: x.name):
if all([r in md.index.free_symbols for r in range_vars]): if all(r in md.index.free_symbols for r in range_vars):
indices.append(md.index) indices.append(md.index)
if md.name in V.graph.name_to_buffer: if md.name in V.graph.name_to_buffer:
buf = V.graph.name_to_buffer[md.name] buf = V.graph.name_to_buffer[md.name]
@ -651,7 +651,7 @@ class Reduction(Loops):
for i in indices: for i in indices:
i = V.graph.sizevars.simplify_with_ranges(i, ranges) i = V.graph.sizevars.simplify_with_ranges(i, ranges)
strides = V.graph.sizevars.stride_hints(i, reduction_vars, ranges.keys()) strides = V.graph.sizevars.stride_hints(i, reduction_vars, ranges.keys())
outer = all([s > 1 for s in strides]) outer = all(s > 1 for s in strides)
if outer: if outer:
num_outer += 1 num_outer += 1
else: else:
@ -837,8 +837,8 @@ class Reduction(Loops):
return isinstance(x, (int, sympy.Integer)) return isinstance(x, (int, sympy.Integer))
sr_qualified = ( sr_qualified = (
all([_is_static(r) for r in ranges]) all(_is_static(r) for r in ranges)
and all([_is_static(r) for r in reduction_ranges]) and all(_is_static(r) for r in reduction_ranges)
and _is_static(reduction_numel) and _is_static(reduction_numel)
) )
@ -3796,7 +3796,7 @@ class StorageBox(MutableBox):
""" """
heavy_ops = ["exp"] # a list of heavy ops heavy_ops = ["exp"] # a list of heavy ops
fn_str = loops.inner_fn_str() fn_str = loops.inner_fn_str()
return any([(op + "(") in fn_str for op in heavy_ops]) return any((op + "(") in fn_str for op in heavy_ops)
if ( if (
users > 1 users > 1

View File

@ -1363,7 +1363,7 @@ def meta__foreach_binop_scalar(self, scalar=1):
) )
def meta__foreach_addcop__scalar(self, tensor1, tensor2, scalar=1): def meta__foreach_addcop__scalar(self, tensor1, tensor2, scalar=1):
check( check(
all([isinstance(l, List) for l in [self, tensor1, tensor2]]), all(isinstance(l, List) for l in [self, tensor1, tensor2]),
lambda: ( lambda: (
"All arguments of _foreach_addc*_ must be List[Tensor], " "All arguments of _foreach_addc*_ must be List[Tensor], "
f"but got {type(self)}, {type(tensor1)}, and {type(tensor2)}" f"but got {type(self)}, {type(tensor1)}, and {type(tensor2)}"
@ -1384,7 +1384,7 @@ def meta__foreach_addcop__scalar(self, tensor1, tensor2, scalar=1):
) )
def meta__foreach_addcop_scalar(self, tensor1, tensor2, scalar=1): def meta__foreach_addcop_scalar(self, tensor1, tensor2, scalar=1):
check( check(
all([isinstance(l, List) for l in [self, tensor1, tensor2]]), all(isinstance(l, List) for l in [self, tensor1, tensor2]),
lambda: ( lambda: (
"All arguments must be List[Tensor], " "All arguments must be List[Tensor], "
f"but got {type(self)}, {type(tensor1)}, and {type(tensor2)}" f"but got {type(self)}, {type(tensor1)}, and {type(tensor2)}"
@ -2631,7 +2631,7 @@ def upsample_common_check(input_size, output_size, num_spatial_dims):
) )
check( check(
all([s > 0 for s in input_size[2:]]) and all([s > 0 for s in output_size]), all(s > 0 for s in input_size[2:]) and all(s > 0 for s in output_size),
lambda: f"Input and output sizes should be greater than 0, but got " lambda: f"Input and output sizes should be greater than 0, but got "
f"input size {input_size} and output size {output_size}", f"input size {input_size} and output size {output_size}",
) )
@ -3062,7 +3062,7 @@ def meta_upsample_bilinear2d_aa(
input.size(), output_size, num_spatial_dims=2 input.size(), output_size, num_spatial_dims=2
) )
check( check(
input.numel() != 0 or all([size > 0 for size in input.size()[1:]]), input.numel() != 0 or all(size > 0 for size in input.size()[1:]),
lambda: f"Non-empty 4D data tensor expected but got a tensor with sizes {input.size()}", lambda: f"Non-empty 4D data tensor expected but got a tensor with sizes {input.size()}",
) )
return input.new_empty(full_output_size).to( return input.new_empty(full_output_size).to(

View File

@ -1140,7 +1140,7 @@ class FakeTensorMode(TorchDispatchMode):
flat_arg_fake_tensors = tree_flatten_only(FakeTensor, (args, kwargs)) flat_arg_fake_tensors = tree_flatten_only(FakeTensor, (args, kwargs))
flat_symints = tree_flatten_only(torch.SymInt, (args, kwargs)) flat_symints = tree_flatten_only(torch.SymInt, (args, kwargs))
has_symbolic_sizes = ( has_symbolic_sizes = (
any([i._has_symbolic_sizes_strides for i in flat_arg_fake_tensors]) any(i._has_symbolic_sizes_strides for i in flat_arg_fake_tensors)
or len(flat_symints) > 0 or len(flat_symints) > 0
) )
@ -1331,7 +1331,7 @@ class FakeTensorMode(TorchDispatchMode):
and type(x) is not torch.nn.Parameter and type(x) is not torch.nn.Parameter
) )
return any([check(x) for x in tree_flatten_only(torch.Tensor, (args, kwargs))]) return any(check(x) for x in tree_flatten_only(torch.Tensor, (args, kwargs)))
def validate_and_convert_non_fake_tensors(self, func, converter, args, kwargs): def validate_and_convert_non_fake_tensors(self, func, converter, args, kwargs):
""" """

View File

@ -19,7 +19,7 @@ def _validate_and_get_batch_size(
for in_dim, arg in zip(flat_in_dims, flat_args) for in_dim, arg in zip(flat_in_dims, flat_args)
if in_dim is not None if in_dim is not None
] ]
if batch_sizes and any([size != batch_sizes[0] for size in batch_sizes]): if batch_sizes and any(size != batch_sizes[0] for size in batch_sizes):
raise ValueError( raise ValueError(
f"vmap: Expected all tensors to have the same size in the mapped " f"vmap: Expected all tensors to have the same size in the mapped "
f"dimension, got sizes {batch_sizes} for the mapped dimension" f"dimension, got sizes {batch_sizes} for the mapped dimension"
@ -168,7 +168,7 @@ def _check_out_dims_is_int_or_int_tuple(out_dims: out_dims_t, func: Callable) ->
if isinstance(out_dims, int): if isinstance(out_dims, int):
return return
if not isinstance(out_dims, tuple) or not all( if not isinstance(out_dims, tuple) or not all(
[isinstance(out_dim, int) for out_dim in out_dims] isinstance(out_dim, int) for out_dim in out_dims
): ):
raise ValueError( raise ValueError(
f"vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must be " f"vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must be "

View File

@ -7,10 +7,8 @@ def module_contains_param(module, parametrization):
if is_parametrized(module): if is_parametrized(module):
# see if any of the module tensors have a parametriztion attached that matches the one passed in # see if any of the module tensors have a parametriztion attached that matches the one passed in
return any(
-    [
-        any(isinstance(param, parametrization) for param in param_list)
-        for key, param_list in module.parametrizations.items()
-    ]
+    any(isinstance(param, parametrization) for param in param_list)
+    for key, param_list in module.parametrizations.items()
)
return False return False

View File

@ -18,10 +18,8 @@ def module_contains_param(module: nn.Module, parametrization: Type[nn.Module]) -
if is_parametrized(module): if is_parametrized(module):
# see if any of the module tensors have a parametriztion attached that matches the one passed in # see if any of the module tensors have a parametriztion attached that matches the one passed in
return any(
-    [
-        any(isinstance(param, parametrization) for param in param_list)
-        for key, param_list in module.parametrizations.items()  # type: ignore[union-attr,operator]
-    ]
+    any(isinstance(param, parametrization) for param in param_list)
+    for key, param_list in module.parametrizations.items()  # type: ignore[union-attr,operator]
)
return False return False

View File

@ -110,9 +110,8 @@ def _replace_observer_with_quantize_dequantize_node_decomposed(
activation_post_process = modules[node.target] activation_post_process = modules[node.target]
# skip replacing observers to quant/dequant nodes if the qconfigs of all # skip replacing observers to quant/dequant nodes if the qconfigs of all
# consumers and producers of this observer are None # consumers and producers of this observer are None
-skip_replacement = all([
-    _has_none_qconfig(n, node_name_to_qconfig) for n in
-    list(node.args) + list(node.users.keys())])
+skip_replacement = all(_has_none_qconfig(n, node_name_to_qconfig) for n in
+                       list(node.args) + list(node.users.keys()))
if skip_replacement or not _is_conversion_supported(activation_post_process): if skip_replacement or not _is_conversion_supported(activation_post_process):
# didn't find corresponding quantize op and info for the activation_post_process # didn't find corresponding quantize op and info for the activation_post_process
# so we just remove the observer # so we just remove the observer
@ -326,9 +325,8 @@ def _replace_observer_with_quantize_dequantize_node(
activation_post_process = modules[node.target] activation_post_process = modules[node.target]
# skip replacing observers to quant/dequant nodes if the qconfigs of all # skip replacing observers to quant/dequant nodes if the qconfigs of all
# consumers and producers of this observer are None # consumers and producers of this observer are None
-skip_replacement = all([
-    _has_none_qconfig(n, node_name_to_qconfig) for n in
-    list(node.args) + list(node.users.keys())])
+skip_replacement = all(_has_none_qconfig(n, node_name_to_qconfig) for n in
+                       list(node.args) + list(node.users.keys()))
if skip_replacement or not _is_conversion_supported(activation_post_process): if skip_replacement or not _is_conversion_supported(activation_post_process):
# didn't find corresponding quantize op and info for the activation_post_process # didn't find corresponding quantize op and info for the activation_post_process
# so we just remove the observer # so we just remove the observer

View File

@ -646,7 +646,7 @@ def _filter_stack_entry(entry):
("_internal/common_utils", "prof_func_call"), ("_internal/common_utils", "prof_func_call"),
("_internal/common_utils", "prof_meth_call"), ("_internal/common_utils", "prof_meth_call"),
] ]
return all([not (f[0] in entry and f[1] in entry) for f in filtered_entries]) return all(not (f[0] in entry and f[1] in entry) for f in filtered_entries)
MEMORY_EVENT_NAME = "[memory]" MEMORY_EVENT_NAME = "[memory]"
OUT_OF_MEMORY_EVENT_NAME = "[OutOfMemory]" OUT_OF_MEMORY_EVENT_NAME = "[OutOfMemory]"
@ -692,10 +692,10 @@ def _build_table(
if len(events) == 0: if len(events) == 0:
return "" return ""
has_cuda_time = any([event.self_cuda_time_total > 0 for event in events]) has_cuda_time = any(event.self_cuda_time_total > 0 for event in events)
has_cuda_mem = any([event.self_cuda_memory_usage > 0 for event in events]) has_cuda_mem = any(event.self_cuda_memory_usage > 0 for event in events)
has_input_shapes = any( has_input_shapes = any(
[(event.input_shapes is not None and len(event.input_shapes) > 0) for event in events]) (event.input_shapes is not None and len(event.input_shapes) > 0) for event in events)
if sort_by is not None: if sort_by is not None:
events = EventList(sorted( events = EventList(sorted(
@ -753,7 +753,7 @@ def _build_table(
'# of Calls' '# of Calls'
) )
# Only append Node ID if any event has a valid (>= 0) Node ID # Only append Node ID if any event has a valid (>= 0) Node ID
append_node_id = any([evt.node_id != -1 for evt in events]) append_node_id = any(evt.node_id != -1 for evt in events)
if append_node_id: if append_node_id:
headers.append('Node ID') headers.append('Node ID')

View File

@ -183,7 +183,7 @@ If you want to use the {} GPU with PyTorch, please check the instructions at htt
for idx in range(device_count()): for idx in range(device_count()):
cap_major, cap_minor = get_device_capability(idx) cap_major, cap_minor = get_device_capability(idx)
# NVIDIA GPU compute architectures are backward compatible within major version # NVIDIA GPU compute architectures are backward compatible within major version
supported = any([sm // 10 == cap_major for sm in supported_sm]) supported = any(sm // 10 == cap_major for sm in supported_sm)
if not supported: if not supported:
device_name = get_device_name(idx) device_name = get_device_name(idx)
capability = cap_major * 10 + cap_minor capability = cap_major * 10 + cap_minor

View File

@ -200,10 +200,8 @@ def _result_distribute_with_col_rearrange(results, input, world_size, weight, pg
# Check if we need to rearrange columns appropriately for output. # Check if we need to rearrange columns appropriately for output.
rearrange_columns = any(
-    [
-        idx != placement.rank()
-        for idx, placement in enumerate(weight._sharding_spec.placements)
-    ]
+    idx != placement.rank()
+    for idx, placement in enumerate(weight._sharding_spec.placements)
)
if not rearrange_columns: if not rearrange_columns:
return output return output

View File

@ -122,7 +122,7 @@ class CommTensor(torch.Tensor):
@classmethod @classmethod
def _is_supported(cls, op_name): def _is_supported(cls, op_name):
return any([comm in op_name for comm in cls._supported_comms]) return any(comm in op_name for comm in cls._supported_comms)
@classmethod @classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None): def __torch_dispatch__(cls, func, types, args=(), kwargs=None):

View File

@ -648,13 +648,11 @@ def _rebuild_graph(
else: else:
value_remap[dtn] = gm.graph.node_copy(dtn, lambda n: value_remap[n]) value_remap[dtn] = gm.graph.node_copy(dtn, lambda n: value_remap[n])
if all(
-    [
-        isinstance(n.target, torch._ops.OpOverload)
-        and n.target._schema.name.startswith(
-            ("aten::_foreach", "aten::_fused_adam")
-        )
-        for n in [dtn, node]
-    ]
+    isinstance(n.target, torch._ops.OpOverload)
+    and n.target._schema.name.startswith(
+        ("aten::_foreach", "aten::_fused_adam")
+    )
+    for n in [dtn, node]
):
# FIXME(@mrshenli): This is a temporary solution enable # FIXME(@mrshenli): This is a temporary solution enable
# foreach ops. The problem is that foreach ops returns # foreach ops. The problem is that foreach ops returns

View File

@ -26,7 +26,7 @@ aten = torch.ops.aten # pyre-ignore
) )
def _prop__foreach_unaop(op_schema: OpSchema) -> OutputSharding: def _prop__foreach_unaop(op_schema: OpSchema) -> OutputSharding:
self = op_schema.args_schema[0] self = op_schema.args_schema[0]
assert isinstance(self, list) and all([isinstance(s, DTensorSpec) for s in self]) assert isinstance(self, list) and all(isinstance(s, DTensorSpec) for s in self)
# FIXME(@mrshenli): for sqrt, this is only mathematically correct for # FIXME(@mrshenli): for sqrt, this is only mathematically correct for
# Replicate and Shard tensor. # Replicate and Shard tensor.
return OutputSharding(output_spec=self) return OutputSharding(output_spec=self)
@ -43,17 +43,17 @@ def _prop__foreach_binop_list(op_schema: OpSchema) -> OutputSharding:
self, other = op_schema.args_schema[:2] self, other = op_schema.args_schema[:2]
scalar = None if len(op_schema.args_schema) < 3 else op_schema.args_schema[2] scalar = None if len(op_schema.args_schema) < 3 else op_schema.args_schema[2]
assert isinstance(self, list) and all( assert isinstance(self, list) and all(
[isinstance(s, DTensorSpec) for s in self] isinstance(s, DTensorSpec) for s in self
), f"Expect a List[DTensorSpec] but got {self}" ), f"Expect a List[DTensorSpec] but got {self}"
assert isinstance(other, list) and all( assert isinstance(other, list) and all(
[isinstance(o, DTensorSpec) for o in other] isinstance(o, DTensorSpec) for o in other
), f"Expect a List[DTensorSpec] but got {other}" ), f"Expect a List[DTensorSpec] but got {other}"
assert len(self) == len(other), ( assert len(self) == len(other), (
"Two tensor lists must match in length, " "Two tensor lists must match in length, "
f"but got {len(self)} and {len(other)}" f"but got {len(self)} and {len(other)}"
) )
if any([s != o for s, o in zip(self, other)]): if any(s != o for s, o in zip(self, other)):
# If DTensorSpec for the two operand do not match, suggest using # If DTensorSpec for the two operand do not match, suggest using
# self's DTensorSpec. This will trigger allreduce if other is partial # self's DTensorSpec. This will trigger allreduce if other is partial
# and self is replicated. # and self is replicated.
@ -83,7 +83,7 @@ def _prop__foreach_binop_list(op_schema: OpSchema) -> OutputSharding:
) )
def _prop__foreach_binop_scalar(op_schema: OpSchema) -> OutputSharding: def _prop__foreach_binop_scalar(op_schema: OpSchema) -> OutputSharding:
self, scalar = op_schema.args_schema self, scalar = op_schema.args_schema
assert isinstance(self, list) and all([isinstance(s, DTensorSpec) for s in self]) assert isinstance(self, list) and all(isinstance(s, DTensorSpec) for s in self)
assert not isinstance(scalar, list) assert not isinstance(scalar, list)
return OutputSharding(output_spec=self) return OutputSharding(output_spec=self)
@ -97,10 +97,10 @@ def _prop__foreach_binop_scalar(op_schema: OpSchema) -> OutputSharding:
def _prop__foreach_addcop_scalar(op_schema: OpSchema): def _prop__foreach_addcop_scalar(op_schema: OpSchema):
self, tensor1, tensor2 = op_schema.args_schema[:3] self, tensor1, tensor2 = op_schema.args_schema[:3]
scalar = None if len(op_schema.args_schema) < 4 else op_schema.args_schema[3] scalar = None if len(op_schema.args_schema) < 4 else op_schema.args_schema[3]
assert isinstance(self, list) and all([isinstance(s, DTensorSpec) for s in self]) assert isinstance(self, list) and all(isinstance(s, DTensorSpec) for s in self)
assert isinstance(tensor1, list) and all([isinstance(s, DTensorSpec) for s in self]) assert isinstance(tensor1, list) and all(isinstance(s, DTensorSpec) for s in self)
assert isinstance(tensor2, list) and all([isinstance(s, DTensorSpec) for s in self]) assert isinstance(tensor2, list) and all(isinstance(s, DTensorSpec) for s in self)
if any([s != t1 or s != t2 for s, t1, t2 in zip(self, tensor1, tensor2)]): if any(s != t1 or s != t2 for s, t1, t2 in zip(self, tensor1, tensor2)):
# If DTensorSpec for the two operand do not match, suggest using # If DTensorSpec for the two operand do not match, suggest using
# self's DTensorSpec. This will trigger allreduce if other is partial # self's DTensorSpec. This will trigger allreduce if other is partial
# and self is replicated. # and self is replicated.
@ -126,7 +126,7 @@ def _prop__foreach_addcop_scalar(op_schema: OpSchema):
def _prop__foreach_pow_scalar_and_tensor(op_schema: OpSchema): def _prop__foreach_pow_scalar_and_tensor(op_schema: OpSchema):
scala, exponent = op_schema.args_schema scala, exponent = op_schema.args_schema
assert isinstance(exponent, list) and all( assert isinstance(exponent, list) and all(
[isinstance(s, DTensorSpec) for s in exponent] isinstance(s, DTensorSpec) for s in exponent
) )
return OutputSharding(output_spec=exponent) return OutputSharding(output_spec=exponent)
@ -136,21 +136,21 @@ def _prop__fused_adam(op_schema: OpSchema):
NT = 5 NT = 5
tesnor_list_args: Tuple[List[DTensorSpec]] = op_schema.args_schema[:NT] # type: ignore[assignment] tesnor_list_args: Tuple[List[DTensorSpec]] = op_schema.args_schema[:NT] # type: ignore[assignment]
assert all([isinstance(schema, list) for schema in tesnor_list_args]) assert all(isinstance(schema, list) for schema in tesnor_list_args)
assert all( assert all(
[isinstance(s, DTensorSpec) for schema in tesnor_list_args for s in schema] isinstance(s, DTensorSpec) for schema in tesnor_list_args for s in schema
) )
tensor_schemas: Tuple[List[DTensorSpec]] = [ # type: ignore[assignment] tensor_schemas: Tuple[List[DTensorSpec]] = [ # type: ignore[assignment]
schema for schema in tesnor_list_args if len(schema) schema for schema in tesnor_list_args if len(schema)
] ]
assert all([len(s) == len(tensor_schemas[0]) for s in tensor_schemas]), ( assert all(len(s) == len(tensor_schemas[0]) for s in tensor_schemas), (
"expect the same number of gradients and states, but got " "expect the same number of gradients and states, but got "
f"{[len(s) for s in tensor_schemas]}." f"{[len(s) for s in tensor_schemas]}."
) )
if any([any([t != ts[0] for t in ts]) for ts in zip(*tensor_schemas)]): if any(any(t != ts[0] for t in ts) for ts in zip(*tensor_schemas)):
new_schemas: Tuple[List[DTensorSpec]] = tuple( # type: ignore[assignment] new_schemas: Tuple[List[DTensorSpec]] = tuple( # type: ignore[assignment]
op_schema.args_schema[0] if len(s) else s for s in tesnor_list_args op_schema.args_schema[0] if len(s) else s for s in tesnor_list_args
) )

View File

@ -1922,7 +1922,7 @@ def all_reduce_coalesced(tensors, op=ReduceOp.SUM, group=None, async_op=False):
_warn_not_in_group("all_reduce_coalesced") _warn_not_in_group("all_reduce_coalesced")
return return
if any([t.is_complex() for t in tensors]) and not supports_complex(op): if any(t.is_complex() for t in tensors) and not supports_complex(op):
raise RuntimeError(f"all_reduce does not support {op} on complex tensors") raise RuntimeError(f"all_reduce does not support {op} on complex tensors")
tensors = [t if not t.is_complex() else torch.view_as_real(t) for t in tensors] tensors = [t if not t.is_complex() else torch.view_as_real(t) for t in tensors]

View File

@ -207,7 +207,7 @@ def generate_binconstraint_t(constraint, counter):
if constraint.lhs == Dyn: if constraint.lhs == Dyn:
return T(), counter return T(), counter
elif isinstance(constraint.lhs, TensorType): elif isinstance(constraint.lhs, TensorType):
is_fully_static = all([d != Dyn for d in constraint.lhs.__args__]) is_fully_static = all(d != Dyn for d in constraint.lhs.__args__)
if is_fully_static: if is_fully_static:
return BinConstraintT(constraint.lhs, constraint.rhs, op_eq), counter return BinConstraintT(constraint.lhs, constraint.rhs, op_eq), counter
else: else:
@ -403,7 +403,7 @@ def generate_calc_product(constraint, counter):
for p in all_possibilities: for p in all_possibilities:
p = list(p) p = list(p)
# this tells us there is a dynamic variable # this tells us there is a dynamic variable
contains_dyn = not(all([constraint.op == op_neq for constraint in p])) contains_dyn = not(all(constraint.op == op_neq for constraint in p))
if contains_dyn: if contains_dyn:
mid_var = [Dyn] mid_var = [Dyn]
total_constraints = lhs + mid_var + rhs total_constraints = lhs + mid_var + rhs
@ -438,7 +438,7 @@ def generate_reshape(constraint, counter):
target = constraint.target.__args__ target = constraint.target.__args__
is_fully_static = all([d != Dyn for d in target]) is_fully_static = all(d != Dyn for d in target)
# dynamic tensor # dynamic tensor
c1_dyn = BinConstraintT(constraint.src, Dyn, op_eq) c1_dyn = BinConstraintT(constraint.src, Dyn, op_eq)

View File

@ -304,7 +304,7 @@ def optimize_for_inference(
if supports_mkldnn != MklSupport.NO: if supports_mkldnn != MklSupport.NO:
if supports_mkldnn == MklSupport.UNKNOWN: if supports_mkldnn == MklSupport.UNKNOWN:
if not any([arg.target == 'to_dense' for arg in node.args]): if not any(arg.target == 'to_dense' for arg in node.args):
continue continue
with fx_graph.inserting_before(node): with fx_graph.inserting_before(node):
mkldnn_args = fx.map_arg(node.args, lambda n: fx_graph.call_method('to_mkldnn', (n, ))) mkldnn_args = fx.map_arg(node.args, lambda n: fx_graph.call_method('to_mkldnn', (n, )))

View File

@ -50,7 +50,7 @@ class Partition:
# the remove this input node # the remove this input node
for input_node in input_nodes: for input_node in input_nodes:
if all( if all(
[n not in self.nodes for n in input_node.users] n not in self.nodes for n in input_node.users
) and input_node.op in {"placeholder", "get_attr"}: ) and input_node.op in {"placeholder", "get_attr"}:
self.nodes.remove(input_node) self.nodes.remove(input_node)
self.recalculate_mem_size() self.recalculate_mem_size()
@ -144,10 +144,8 @@ def get_latency_of_one_partition(
# or its input nodes in this partition are placeholders and get_attrs # or its input nodes in this partition are placeholders and get_attrs
# this node is on the top bfs level in this partition # this node is on the top bfs level in this partition
if not any( if not any(
[ n in partition.nodes and n.op not in {"placeholder", "get_attr"}
n in partition.nodes and n.op not in {"placeholder", "get_attr"}
for n in input_nodes for n in input_nodes
]
): ):
top_nodes.append(node) top_nodes.append(node)
return top_nodes return top_nodes

View File

@ -2613,7 +2613,7 @@ class ShapeEnv:
if not expr.has(sympy.Mod): if not expr.has(sympy.Mod):
try: try:
floor_div_atoms = lhs.atoms(FloorDiv).union(rhs.atoms(FloorDiv)) floor_div_atoms = lhs.atoms(FloorDiv).union(rhs.atoms(FloorDiv))
if len(floor_div_atoms) > 0 and any([a.divisor != 1 for a in floor_div_atoms]): if len(floor_div_atoms) > 0 and any(a.divisor != 1 for a in floor_div_atoms):
raise NotImplementedError raise NotImplementedError
solutions = sympy.solve(lhs - rhs, free[0], dict=True) solutions = sympy.solve(lhs - rhs, free[0], dict=True)
if len(solutions) != 1: if len(solutions) != 1:

View File

@ -747,7 +747,7 @@ class _SplitterBase:
# Determine which subgraph to start from based on which subgraph has # Determine which subgraph to start from based on which subgraph has
# 0-dep node # 0-dep node
acc_subgraph: bool = not any([len(self.deps[n]) == 0 for n in current_cpu_nodes]) acc_subgraph: bool = not any(len(self.deps[n]) == 0 for n in current_cpu_nodes)
current_subgraph_nodes: NodeList = [] current_subgraph_nodes: NodeList = []

View File

@ -73,7 +73,7 @@ def is_consistent(t1, t2):
if isinstance(t1, TensorType) and isinstance(t2, TensorType): if isinstance(t1, TensorType) and isinstance(t2, TensorType):
return len(t1.__args__) == len(t2.__args__) and \ return len(t1.__args__) == len(t2.__args__) and \
all([is_consistent(elem1, elem2) for elem1, elem2 in zip(t1.__args__, t2.__args__)]) all(is_consistent(elem1, elem2) for elem1, elem2 in zip(t1.__args__, t2.__args__))
else: else:
return False return False
@ -98,7 +98,7 @@ def is_more_precise(t1, t2):
if isinstance(t1, TensorType) and isinstance(t2, TensorType): if isinstance(t1, TensorType) and isinstance(t2, TensorType):
return len(t1.__args__) == len(t2.__args__) and \ return len(t1.__args__) == len(t2.__args__) and \
all([is_more_precise(elem1, elem2) for elem1, elem2 in zip(t1.__args__, t2.__args__)]) all(is_more_precise(elem1, elem2) for elem1, elem2 in zip(t1.__args__, t2.__args__))
else: else:
return False return False

View File

@ -50,7 +50,7 @@ def as_nested_tensor(
tensor([0., 0., 0., 0., 0.]) tensor([0., 0., 0., 0., 0.])
""" """
if not isinstance(tensor_list, list) or any( if not isinstance(tensor_list, list) or any(
[not isinstance(t, Tensor) for t in tensor_list] not isinstance(t, Tensor) for t in tensor_list
): ):
raise TypeError( raise TypeError(
"nested_tensor(): Expected first argument to be a list of tensors " "nested_tensor(): Expected first argument to be a list of tensors "

View File

@ -1172,9 +1172,9 @@ class MultiheadAttention(Module):
# generator expressions. # generator expressions.
if torch.overrides.has_torch_function(tensor_args): if torch.overrides.has_torch_function(tensor_args):
why_not_fast_path = "some Tensor argument has_torch_function" why_not_fast_path = "some Tensor argument has_torch_function"
elif not all([_arg_cuda_or_cpu(x) for x in tensor_args]): elif not all(_arg_cuda_or_cpu(x) for x in tensor_args):
why_not_fast_path = "some Tensor argument is neither CUDA nor CPU" why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
elif torch.is_grad_enabled() and any([_arg_requires_grad(x) for x in tensor_args]): elif torch.is_grad_enabled() and any(_arg_requires_grad(x) for x in tensor_args):
why_not_fast_path = ("grad is enabled and at least one of query or the " why_not_fast_path = ("grad is enabled and at least one of query or the "
"input/output projection weights or biases requires_grad") "input/output projection weights or biases requires_grad")
if not why_not_fast_path: if not why_not_fast_path:

View File

@ -129,7 +129,7 @@ class AdaptiveLogSoftmaxWithLoss(Module):
or (min(cutoffs) <= 0) \ or (min(cutoffs) <= 0) \
or (max(cutoffs) > (n_classes - 1)) \ or (max(cutoffs) > (n_classes - 1)) \
or (len(set(cutoffs)) != len(cutoffs)) \ or (len(set(cutoffs)) != len(cutoffs)) \
or any([int(c) != c for c in cutoffs]): or any(int(c) != c for c in cutoffs):
raise ValueError("cutoffs should be a sequence of unique, positive " raise ValueError("cutoffs should be a sequence of unique, positive "
"integers sorted in an increasing order, where " "integers sorted in an increasing order, where "

View File

@ -1407,7 +1407,7 @@ class DistributedDataParallel(Module, Joinable):
# We batch zero_grad for all params by resetting the whole grad # We batch zero_grad for all params by resetting the whole grad
# buffer when the grad of all params is set to None. # buffer when the grad of all params is set to None.
all_param_grad_none = all( all_param_grad_none = all(
[param.grad is None for param in self._delay_all_reduce_params] param.grad is None for param in self._delay_all_reduce_params
) )
for index, param in enumerate(self._delay_all_reduce_params): for index, param in enumerate(self._delay_all_reduce_params):

View File

@ -1513,7 +1513,7 @@ def _optional_input_placeholder_tensor(g):
def _handle_reduce_dim_none(g: jit_utils.GraphContext, self, op_name): def _handle_reduce_dim_none(g: jit_utils.GraphContext, self, op_name):
rank = _get_tensor_rank(self) rank = _get_tensor_rank(self)
if rank is not None and any( if rank is not None and any(
[_get_tensor_dim_size(self, i) == 0 for i in range(rank)] _get_tensor_dim_size(self, i) == 0 for i in range(rank)
): ):
# If input tensor is empty, according to ONNX ReduceSum definition, # If input tensor is empty, according to ONNX ReduceSum definition,
# set keepdims=1 so that the resulted tensor has the same rank as the input. # set keepdims=1 so that the resulted tensor has the same rank as the input.

View File

@ -539,13 +539,11 @@ def cat(g: jit_utils.GraphContext, tensor_list, dim):
nonempty_tensors.append(t) nonempty_tensors.append(t)
assert len(nonempty_tensors) > 0 assert len(nonempty_tensors) > 0
assert all( assert all(
[ symbolic_helper._get_tensor_rank(nonempty_tensors[0]) is None
symbolic_helper._get_tensor_rank(nonempty_tensors[0]) is None or symbolic_helper._get_tensor_rank(t) is None
or symbolic_helper._get_tensor_rank(t) is None or symbolic_helper._get_tensor_rank(t)
or symbolic_helper._get_tensor_rank(t) == symbolic_helper._get_tensor_rank(nonempty_tensors[0])
== symbolic_helper._get_tensor_rank(nonempty_tensors[0]) for t in nonempty_tensors
for t in nonempty_tensors
]
) )
tensor_list.node().removeAllInputs() tensor_list.node().removeAllInputs()
for t in nonempty_tensors: for t in nonempty_tensors:
@ -1499,7 +1497,7 @@ def get_pool_ceil_padding(input, kernel_size, stride, padding):
# TODO(justinchuby): Looks like this op is deprecated in torch # TODO(justinchuby): Looks like this op is deprecated in torch
sizes = symbolic_helper._get_tensor_sizes(input) sizes = symbolic_helper._get_tensor_sizes(input)
dim = sizes[-len(padding) :] if sizes is not None else None dim = sizes[-len(padding) :] if sizes is not None else None
if dim is None or any([i is None for i in dim]): if dim is None or any(i is None for i in dim):
return symbolic_helper._unimplemented( return symbolic_helper._unimplemented(
"get_pool_ceil_padding", "input size not accessible", input "get_pool_ceil_padding", "input size not accessible", input
) )
@ -1824,7 +1822,7 @@ def _adaptive_pool(name, type, tuple_fn, fn=None):
# FIXME(justinchuby): Avoid catching Exception. # FIXME(justinchuby): Avoid catching Exception.
# Catch a more specific exception instead. # Catch a more specific exception instead.
dim = None dim = None
if dim is None or any([i is None for i in dim]): if dim is None or any(i is None for i in dim):
if output_size == [1] * len(output_size): if output_size == [1] * len(output_size):
return g.op("GlobalMaxPool", input), None return g.op("GlobalMaxPool", input), None
return symbolic_helper._unimplemented( return symbolic_helper._unimplemented(
@ -2453,7 +2451,7 @@ def _convolution(
# Catch a more specific exception instead. # Catch a more specific exception instead.
kernel_shape = None kernel_shape = None
if kernel_shape is None or any([i is None for i in kernel_shape]): if kernel_shape is None or any(i is None for i in kernel_shape):
raise errors.SymbolicValueError( raise errors.SymbolicValueError(
"Unsupported: ONNX export of convolution for kernel of unknown shape.", "Unsupported: ONNX export of convolution for kernel of unknown shape.",
input, input,
@ -6072,7 +6070,7 @@ def dim(g: jit_utils.GraphContext, self):
def __contains_(g: jit_utils.GraphContext, self, element): def __contains_(g: jit_utils.GraphContext, self, element):
unpacked_list = symbolic_helper._unpack_list(self) unpacked_list = symbolic_helper._unpack_list(self)
if all( if all(
[symbolic_helper._is_constant(x) for x in unpacked_list] symbolic_helper._is_constant(x) for x in unpacked_list
) and symbolic_helper._is_constant(element): ) and symbolic_helper._is_constant(element):
return g.op( return g.op(
"Constant", "Constant",

View File

@ -917,7 +917,7 @@ def verify_aten_graph(
graph_inputs = list(graph.inputs()) graph_inputs = list(graph.inputs())
jit_inputs = tuple([arg for arg in input_args if arg is not None]) jit_inputs = tuple([arg for arg in input_args if arg is not None])
weights = [params_dict[v.debugName()] for v in graph_inputs[len(jit_inputs) :]] weights = [params_dict[v.debugName()] for v in graph_inputs[len(jit_inputs) :]]
assert all([w is not None for w in weights]) assert all(w is not None for w in weights)
# TODO: Only copy the argument if mutation is detected in Graph. # TODO: Only copy the argument if mutation is detected in Graph.
jit_inputs = copy.deepcopy(jit_inputs) jit_inputs = copy.deepcopy(jit_inputs)
jit_input_and_parameters = jit_inputs + tuple(weights) jit_input_and_parameters = jit_inputs + tuple(weights)

View File

@ -737,7 +737,7 @@ class MultiProcessTestCase(TestCase):
if subprocess_error: if subprocess_error:
break break
# All processes have joined cleanly if they all a valid exitcode # All processes have joined cleanly if they all a valid exitcode
if all([p.exitcode is not None for p in self.processes]): if all(p.exitcode is not None for p in self.processes):
break break
# Check if we should time out the test. If so, we terminate each process. # Check if we should time out the test. If so, we terminate each process.
elapsed = time.time() - start_time elapsed = time.time() - start_time

View File

@ -1789,7 +1789,7 @@ def check_if_enable(test: unittest.TestCase):
# Sanitize the platforms list so that we continue to disable the test for any valid platforms given # Sanitize the platforms list so that we continue to disable the test for any valid platforms given
platforms = list(filter(lambda p: p in platform_to_conditional, platforms)) platforms = list(filter(lambda p: p in platform_to_conditional, platforms))
if platforms == [] or any([platform_to_conditional[platform] for platform in platforms]): if platforms == [] or any(platform_to_conditional[platform] for platform in platforms):
should_skip = True should_skip = True
skip_msg = f"Test is disabled because an issue exists disabling it: {issue_url}" \ skip_msg = f"Test is disabled because an issue exists disabling it: {issue_url}" \
f" for {'all' if platforms == [] else ''}platform(s) {', '.join(platforms)}. " \ f" for {'all' if platforms == [] else ''}platform(s) {', '.join(platforms)}. " \
@ -3159,9 +3159,9 @@ class TestCase(expecttest.TestCase):
yield yield
if len(ws) == 0: if len(ws) == 0:
self.fail('no warning caught') self.fail('no warning caught')
self.assertTrue(any([type(w.message) is category for w in ws])) self.assertTrue(any(type(w.message) is category for w in ws))
self.assertTrue( self.assertTrue(
any([re.match(pattern, str(w.message)) for w in ws]), any(re.match(pattern, str(w.message)) for w in ws),
f'{pattern}, {[w.message for w in ws if type(w.message) is category]}') f'{pattern}, {[w.message for w in ws if type(w.message) is category]}')
def assertExpected(self, s, subname=None): def assertExpected(self, s, subname=None):

View File

@ -189,7 +189,7 @@ def generate_cct_and_mode(autograd_view_consistency=True):
# then the first argument is being written to. Introspection please save us! # then the first argument is being written to. Introspection please save us!
mutated_argument = args[0] mutated_argument = args[0]
if not isinstance(mutated_argument, CompositeCompliantTensor) and \ if not isinstance(mutated_argument, CompositeCompliantTensor) and \
any([isinstance(a, CompositeCompliantTensor) for a in args[1:]]): any(isinstance(a, CompositeCompliantTensor) for a in args[1:]):
raise RuntimeError( raise RuntimeError(
'Not composite compliant: performing in-place operation ' 'Not composite compliant: performing in-place operation '
f'{func.__name__} where the Tensor being written to is ' f'{func.__name__} where the Tensor being written to is '
@ -249,10 +249,10 @@ def is_tensorlist(lst):
return False return False
if len(lst) == 0: if len(lst) == 0:
return False return False
all_tensors = all([isinstance(elt, torch.Tensor) for elt in lst]) all_tensors = all(isinstance(elt, torch.Tensor) for elt in lst)
if all_tensors: if all_tensors:
return True return True
exists_one_tensor = all([isinstance(elt, torch.Tensor) for elt in lst]) exists_one_tensor = all(isinstance(elt, torch.Tensor) for elt in lst)
if exists_one_tensor: if exists_one_tensor:
raise RuntimeError('This test assumes that PyTorch APIs cannot take ' raise RuntimeError('This test assumes that PyTorch APIs cannot take '
'mixed lists of Tensor and other things') 'mixed lists of Tensor and other things')

View File

@ -4666,7 +4666,7 @@ class DistributedTest:
ddp_model().backward(create_graph=True) ddp_model().backward(create_graph=True)
# grad tensors should require grad. # grad tensors should require grad.
self.assertTrue( self.assertTrue(
all([param.requires_grad for param in ddp_model.parameters()]) all(param.requires_grad for param in ddp_model.parameters())
) )
@skip_but_pass_in_sandcastle_if( @skip_but_pass_in_sandcastle_if(

View File

@ -1839,7 +1839,7 @@ class RpcTest(RpcAgentTestFixture, RpcTestCommon):
trace = json.load(f) trace = json.load(f)
event_names = [event['name'] for event in trace] event_names = [event['name'] for event in trace]
for expected_event_name in EXPECTED_REMOTE_EVENTS + [RPCExecMode.ASYNC.value]: for expected_event_name in EXPECTED_REMOTE_EVENTS + [RPCExecMode.ASYNC.value]:
event_exists = any([expected_event_name in event_name for event_name in event_names]) event_exists = any(expected_event_name in event_name for event_name in event_names)
self.assertTrue(event_exists) self.assertTrue(event_exists)
@dist_init @dist_init
@ -2903,7 +2903,7 @@ class RpcTest(RpcAgentTestFixture, RpcTestCommon):
def launched_rpc(events): def launched_rpc(events):
expected_name = f"rpc_{RPCExecMode.ASYNC.value}#_rref_typeof_on_owner" expected_name = f"rpc_{RPCExecMode.ASYNC.value}#_rref_typeof_on_owner"
return any([e.name.startswith(expected_name) for e in events]) return any(e.name.startswith(expected_name) for e in events)
dst = worker_name((self.rank + 1) % self.world_size) dst = worker_name((self.rank + 1) % self.world_size)
rref = rpc.remote(dst, torch.add, args=(torch.ones(2), 1)) rref = rpc.remote(dst, torch.add, args=(torch.ones(2), 1))

View File

@ -21,7 +21,7 @@ from torch.autograd.grad_mode import no_grad
def _group_tensors_by_device_and_dtype(tensorlistlist: List[List[Tensor]], def _group_tensors_by_device_and_dtype(tensorlistlist: List[List[Tensor]],
with_indices: Optional[bool] = False) -> \ with_indices: Optional[bool] = False) -> \
Dict[Tuple[torch.device, torch.dtype], List[List[Union[Tensor, int]]]]: Dict[Tuple[torch.device, torch.dtype], List[List[Union[Tensor, int]]]]:
assert all([not x or len(x) == len(tensorlistlist[0]) for x in tensorlistlist]), ( assert all(not x or len(x) == len(tensorlistlist[0]) for x in tensorlistlist), (
"all specified tensorlists must match in length") "all specified tensorlists must match in length")
per_device_and_dtype_tensors: Dict[Tuple[torch.device, torch.dtype], List[List[Union[Tensor, int]]]] = defaultdict( per_device_and_dtype_tensors: Dict[Tuple[torch.device, torch.dtype], List[List[Union[Tensor, int]]]] = defaultdict(
lambda: [[] for _ in range(len(tensorlistlist) + (1 if with_indices else 0))]) lambda: [[] for _ in range(len(tensorlistlist) + (1 if with_indices else 0))])
@ -39,4 +39,4 @@ def _group_tensors_by_device_and_dtype(tensorlistlist: List[List[Tensor]],
def _has_foreach_support(tensors: List[Tensor], device: torch.device) -> bool: def _has_foreach_support(tensors: List[Tensor], device: torch.device) -> bool:
if device.type not in ['cpu', 'cuda'] or torch.jit.is_scripting(): if device.type not in ['cpu', 'cuda'] or torch.jit.is_scripting():
return False return False
return all([t is None or type(t) == torch.Tensor for t in tensors]) return all(t is None or type(t) == torch.Tensor for t in tensors)

View File

@ -176,7 +176,7 @@ class Freezer:
return return
# Python packages are directories that have __init__.py in them. # Python packages are directories that have __init__.py in them.
is_package_dir = any([child.name == "__init__.py" for child in path.iterdir()]) is_package_dir = any(child.name == "__init__.py" for child in path.iterdir())
if not is_package_dir: if not is_package_dir:
self.msg(path, "S") self.msg(path, "S")
return return

View File

@ -553,7 +553,7 @@ class BuildExtension(build_ext):
_ccbin = os.getenv("CC") _ccbin = os.getenv("CC")
if ( if (
_ccbin is not None _ccbin is not None
and not any([flag.startswith(('-ccbin', '--compiler-bindir')) for flag in cflags]) and not any(flag.startswith(('-ccbin', '--compiler-bindir')) for flag in cflags)
): ):
cflags.extend(['-ccbin', _ccbin]) cflags.extend(['-ccbin', _ccbin])
@ -1462,7 +1462,7 @@ def _jit_compile(name,
if with_cuda is None: if with_cuda is None:
with_cuda = any(map(_is_cuda_file, sources)) with_cuda = any(map(_is_cuda_file, sources))
with_cudnn = any(['cudnn' in f for f in extra_ldflags or []]) with_cudnn = any('cudnn' in f for f in extra_ldflags or [])
old_version = JIT_EXTENSION_VERSIONER.get_version(name) old_version = JIT_EXTENSION_VERSIONER.get_version(name)
version = JIT_EXTENSION_VERSIONER.bump_version_if_changed( version = JIT_EXTENSION_VERSIONER.bump_version_if_changed(
name, name,

View File

@ -45,7 +45,7 @@ def optimize_for_mobile(
preserved_methods_str: List[str] = [str(method) for method in preserved_methods] preserved_methods_str: List[str] = [str(method) for method in preserved_methods]
bundled_inputs_attributes = _get_bundled_inputs_preserved_attributes(script_module, preserved_methods_str) bundled_inputs_attributes = _get_bundled_inputs_preserved_attributes(script_module, preserved_methods_str)
if all([hasattr(script_module, method) for method in bundled_inputs_attributes]): if all(hasattr(script_module, method) for method in bundled_inputs_attributes):
preserved_methods_str = list(set(preserved_methods_str + bundled_inputs_attributes)) preserved_methods_str = list(set(preserved_methods_str + bundled_inputs_attributes))
non_exist_methods = [] non_exist_methods = []

View File

@ -1373,7 +1373,7 @@ class FunctionSchema:
len(mutable_returns) == 0 or len(immutable_returns) == 0 len(mutable_returns) == 0 or len(immutable_returns) == 0
), f"NativeFunctions must have either only mutable returns, or only immutable returns. Found: {str(self)}" ), f"NativeFunctions must have either only mutable returns, or only immutable returns. Found: {str(self)}"
for ret in mutable_returns: for ret in mutable_returns:
assert any([ret.annotation == arg.annotation for arg in out_and_self]), ( assert any(ret.annotation == arg.annotation for arg in out_and_self), (
'All mutable returns must be aliased either to a keyword argument, or to "self". ' 'All mutable returns must be aliased either to a keyword argument, or to "self". '
"Did you forget to mark an out argument as keyword-only?" "Did you forget to mark an out argument as keyword-only?"
) )
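
A minimal sketch (not part of the patch) of the behavior the C419 rewrites above rely on: when `any()`/`all()` receive a generator expression they stop consuming it at the first decisive element, whereas a list comprehension is always materialized in full before the check runs. The function and variable names below are illustrative only.

    # Sketch: generator expressions let any()/all() short-circuit;
    # list comprehensions evaluate every element up front.
    def is_positive(x):
        print(f"checking {x}")
        return x > 0

    values = [1, -2, -3]

    # List comprehension: all three predicates run before any() sees the list.
    any([is_positive(v) for v in values])   # prints three "checking" lines

    # Generator expression: any() stops after the first truthy result.
    any(is_positive(v) for v in values)     # prints only "checking 1"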