[BE] Enable flake8-comprehension rule C417 (#97880)

Enables flake8-comprehension rule C417. Ruff autogenerated these fixes to the codebase.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97880
Approved by: https://github.com/ezyang, https://github.com/kit1980, https://github.com/albanD
This commit is contained in:
Aaron Gokaslan 2023-03-30 14:34:24 +00:00 committed by PyTorch MergeBot
parent 1d08b5b103
commit 47dca20d80
36 changed files with 108 additions and 124 deletions

View File

@ -14,7 +14,7 @@ ignore =
# these ignores are from flake8-bugbear; please fix! # these ignores are from flake8-bugbear; please fix!
B007,B008,B017,B019,B020,B023,B024,B026,B027,B028,B903,B904,B905,B906,B907 B007,B008,B017,B019,B020,B023,B024,B026,B027,B028,B903,B904,B905,B906,B907
# these ignores are from flake8-comprehensions; please fix! # these ignores are from flake8-comprehensions; please fix!
C407,C417 C407
# these ignores are from flake8-logging-format; please fix! # these ignores are from flake8-logging-format; please fix!
G001,G002,G003,G004,G100,G101,G200,G201,G202 G001,G002,G003,G004,G100,G101,G200,G201,G202
per-file-ignores = per-file-ignores =

View File

@ -119,7 +119,7 @@ def get_valid_userbenchmarks(torchbench_path: str) -> List[str]:
[os.path.join(ub_path, ubdir) for ubdir in os.listdir(ub_path)], [os.path.join(ub_path, ubdir) for ubdir in os.listdir(ub_path)],
) )
) )
valid_ubs = list(map(lambda x: os.path.basename(x), ubs)) valid_ubs = [os.path.basename(x) for x in ubs]
return valid_ubs return valid_ubs
@ -130,13 +130,13 @@ def extract_models_from_pr(
userbenchmark_list = [] userbenchmark_list = []
pr_list = [] pr_list = []
with open(prbody_file, "r") as pf: with open(prbody_file, "r") as pf:
lines = map(lambda x: x.strip(), pf.read().splitlines()) lines = (x.strip() for x in pf.read().splitlines())
magic_lines = list(filter(lambda x: x.startswith(MAGIC_PREFIX), lines)) magic_lines = list(filter(lambda x: x.startswith(MAGIC_PREFIX), lines))
if magic_lines: if magic_lines:
# Only the first magic line will be recognized. # Only the first magic line will be recognized.
pr_list = list( pr_list = [
map(lambda x: x.strip(), magic_lines[0][len(MAGIC_PREFIX) :].split(",")) x.strip() for x in magic_lines[0][len(MAGIC_PREFIX) :].split(",")
) ]
valid_models = get_valid_models(torchbench_path) valid_models = get_valid_models(torchbench_path)
valid_ubs = get_valid_userbenchmarks(torchbench_path) valid_ubs = get_valid_userbenchmarks(torchbench_path)
for pr_bm in pr_list: for pr_bm in pr_list:
@ -158,7 +158,7 @@ def extract_models_from_pr(
def find_torchbench_branch(prbody_file: str) -> str: def find_torchbench_branch(prbody_file: str) -> str:
branch_name: str = "" branch_name: str = ""
with open(prbody_file, "r") as pf: with open(prbody_file, "r") as pf:
lines = map(lambda x: x.strip(), pf.read().splitlines()) lines = (x.strip() for x in pf.read().splitlines())
magic_lines = list( magic_lines = list(
filter(lambda x: x.startswith(MAGIC_TORCHBENCH_PREFIX), lines) filter(lambda x: x.startswith(MAGIC_TORCHBENCH_PREFIX), lines)
) )

View File

@ -423,7 +423,7 @@ def tensor_is_on_xla(tensors):
if not isinstance(tensors, (tuple, list)): if not isinstance(tensors, (tuple, list)):
tensors = [tensors] tensors = [tensors]
tensors = [x for x in tensors if isinstance(x, torch.Tensor)] tensors = [x for x in tensors if isinstance(x, torch.Tensor)]
return any(map(lambda x: x.device.type == "xla", tensors)) return any((x.device.type == "xla" for x in tensors))
def timed( def timed(
@ -757,12 +757,9 @@ def speedup_experiment_ds(args, model_iter_fn, model, example_inputs):
shapes = [x[0].shape for x in example_inputs] shapes = [x[0].shape for x in example_inputs]
shape_keys = sorted(set(shapes)) shape_keys = sorted(set(shapes))
shape_speedups = { shape_speedups = {
shape: list( shape: [
map( it[1] for it in filter(lambda it: it[0] == shape, zip(shapes, speedups))
lambda it: it[1], ]
filter(lambda it: it[0] == shape, zip(shapes, speedups)),
)
)
for shape in shape_keys for shape in shape_keys
} }
output_str = ( output_str = (

View File

@ -351,7 +351,7 @@ def get_skip_tests(suite):
if hasattr(module, "SKIP_TRAIN"): if hasattr(module, "SKIP_TRAIN"):
skip_tests.update(module.SKIP_TRAIN) skip_tests.update(module.SKIP_TRAIN)
skip_tests = map(lambda name: f"-x {name}", skip_tests) skip_tests = (f"-x {name}" for name in skip_tests)
skip_str = " ".join(skip_tests) skip_str = " ".join(skip_tests)
return skip_str return skip_str

View File

@ -23,7 +23,7 @@ def sparse_grad_output(a, b):
def read_matrix_params(path): def read_matrix_params(path):
with open(path, 'r') as file: with open(path, 'r') as file:
line = file.readline() line = file.readline()
nrows, ncols, nnz = map(lambda el: int(el), line.split(', ')) nrows, ncols, nnz = (int(el) for el in line.split(', '))
return (nrows, ncols), nnz return (nrows, ncols), nnz
@ -39,9 +39,9 @@ def csr_to_coo(indices, indptr, shape):
def load_sparse_matrix(path, device): def load_sparse_matrix(path, device):
with open(path, 'r') as file: with open(path, 'r') as file:
nrows, ncols, nnz = map(lambda el: int(el), file.readline().split(', ')) nrows, ncols, nnz = (int(el) for el in file.readline().split(', '))
index_pointers = map(lambda el: int(el), file.readline().split()) index_pointers = (int(el) for el in file.readline().split())
indices = map(lambda el: int(el), file.readline().split()) indices = (int(el) for el in file.readline().split())
index_pointers = list(index_pointers) index_pointers = list(index_pointers)
indices = list(indices) indices = list(indices)
@ -52,17 +52,17 @@ def load_sparse_matrix(path, device):
def gen_vector(path, device): def gen_vector(path, device):
with open(path, 'r') as file: with open(path, 'r') as file:
nrows, ncols, nnz = map(lambda el: int(el), file.readline().split(', ')) nrows, ncols, nnz = (int(el) for el in file.readline().split(', '))
index_pointers = map(lambda el: int(el), file.readline().split()) index_pointers = (int(el) for el in file.readline().split())
indices = map(lambda el: int(el), file.readline().split()) indices = (int(el) for el in file.readline().split())
return torch.randn(nrows, dtype=torch.double, device=device) return torch.randn(nrows, dtype=torch.double, device=device)
def gen_matrix(path, device): def gen_matrix(path, device):
with open(path, 'r') as file: with open(path, 'r') as file:
nrows, ncols, nnz = map(lambda el: int(el), file.readline().split(', ')) nrows, ncols, nnz = (int(el) for el in file.readline().split(', '))
index_pointers = map(lambda el: int(el), file.readline().split()) index_pointers = (int(el) for el in file.readline().split())
indices = map(lambda el: int(el), file.readline().split()) indices = (int(el) for el in file.readline().split())
return torch.randn(nrows, ncols, dtype=torch.double, device=device) return torch.randn(nrows, ncols, dtype=torch.double, device=device)

View File

@ -499,9 +499,9 @@ class PartialMaml(torch.nn.Module):
logits = net(x_spt) logits = net(x_spt)
loss = F.cross_entropy(logits, y_spt) loss = F.cross_entropy(logits, y_spt)
grad = torch.autograd.grad(loss, net.parameters()) grad = torch.autograd.grad(loss, net.parameters())
fast_weights = list( fast_weights = [
map(lambda p: p[1] - self.update_lr * p[0], zip(grad, net.parameters())) p[1] - self.update_lr * p[0] for p in zip(grad, net.parameters())
) ]
# this is the loss and accuracy before first update # this is the loss and accuracy before first update
with torch.no_grad(): with torch.no_grad():

View File

@ -3427,7 +3427,7 @@ class TestVmapOperatorsOpInfo(TestCase):
check_shape_only = op.name in ('empty_like', 'new_empty') check_shape_only = op.name in ('empty_like', 'new_empty')
for sample_input in sample_inputs_itr: for sample_input in sample_inputs_itr:
args = (sample_input.input,) + sample_input.args args = (sample_input.input,) + sample_input.args
if not any(map(lambda arg: isinstance(arg, torch.Tensor), args)): if not any((isinstance(arg, torch.Tensor) for arg in args)):
# Atleast one tensor required for vmap. # Atleast one tensor required for vmap.
continue continue
kwargs = sample_input.kwargs kwargs = sample_input.kwargs

View File

@ -1346,7 +1346,7 @@ class TestFxDetectInputWeightEqualization(QuantizationTestCase):
# assert that each of the desired modules have the observers inserted # assert that each of the desired modules have the observers inserted
for fqn, module in prepared_for_callibrate_model.named_modules(): for fqn, module in prepared_for_callibrate_model.named_modules():
# check if module is a supported module # check if module is a supported module
is_in_include_list = sum(list(map(lambda x: isinstance(module, x), mods_to_check))) > 0 is_in_include_list = sum([isinstance(module, x) for x in mods_to_check]) > 0
if is_in_include_list: if is_in_include_list:
# make sure it has the observer attribute # make sure it has the observer attribute

View File

@ -2398,10 +2398,8 @@ class TestAutograd(TestCase):
y = torch.randn((3, 3), requires_grad=True) y = torch.randn((3, 3), requires_grad=True)
MyFunction.apply(x, y).sum().backward() MyFunction.apply(x, y).sum().backward()
has_deprecated = map(lambda warn: has_deprecated = ('deprecated' in str(warn) and
'deprecated' in str(warn) and 'saved_variables' in str(warn) for warn in warns)
'saved_variables' in str(warn),
warns)
has_deprecated = reduce(lambda x, y: x or y, has_deprecated) has_deprecated = reduce(lambda x, y: x or y, has_deprecated)
self.assertTrue(has_deprecated) self.assertTrue(has_deprecated)

View File

@ -498,7 +498,7 @@ class TestBinaryUfuncs(TestCase):
) )
def _supported(dtypes): def _supported(dtypes):
return all(map(lambda x: x in supported_dtypes, dtypes)) return all((x in supported_dtypes for x in dtypes))
# int x int type promotion # int x int type promotion
if _supported((torch.int16, torch.int32, torch.int64)): if _supported((torch.int16, torch.int32, torch.int64)):

View File

@ -6793,7 +6793,7 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
for fn in [torch.det, torch.logdet, torch.slogdet, torch.linalg.slogdet]: for fn in [torch.det, torch.logdet, torch.slogdet, torch.linalg.slogdet]:
expected_value = [] expected_value = []
actual_value = fn(full_tensor) actual_value = fn(full_tensor)
for full_idx in itertools.product(*map(lambda x: list(range(x)), batchdims)): for full_idx in itertools.product(*(list(range(x)) for x in batchdims)):
expected_value.append(fn(full_tensor[full_idx])) expected_value.append(fn(full_tensor[full_idx]))
if fn == torch.slogdet or fn == torch.linalg.slogdet: if fn == torch.slogdet or fn == torch.linalg.slogdet:

View File

@ -148,7 +148,7 @@ class TestCommon(TestCase):
if isinstance(result, torch.Tensor): if isinstance(result, torch.Tensor):
self.assertTrue(result.device == cuda_device) self.assertTrue(result.device == cuda_device)
elif is_iterable_of_tensors(result): elif is_iterable_of_tensors(result):
self.assertTrue(all(map(lambda t: t.device == cuda_device, result))) self.assertTrue(all((t.device == cuda_device for t in result)))
else: else:
self.skipTest( self.skipTest(
"Skipped! Only supports single tensor or iterable of tensor outputs." "Skipped! Only supports single tensor or iterable of tensor outputs."
@ -711,7 +711,7 @@ class TestCommon(TestCase):
return (out.stride(),) return (out.stride(),)
# assumes (see above) that out is an iterable of tensors # assumes (see above) that out is an iterable of tensors
return tuple(map(lambda t: t.stride(), out)) return tuple((t.stride() for t in out))
# Extracts data pointers from a tensor or iterable of tensors into a tuple # Extracts data pointers from a tensor or iterable of tensors into a tuple
# NOTE: only extracts on the CPU and CUDA device types since some # NOTE: only extracts on the CPU and CUDA device types since some
@ -724,7 +724,7 @@ class TestCommon(TestCase):
return (out.data_ptr(),) return (out.data_ptr(),)
# assumes (see above) that out is an iterable of tensors # assumes (see above) that out is an iterable of tensors
return tuple(map(lambda t: t.data_ptr(), out)) return tuple((t.data_ptr() for t in out))
@suppress_warnings @suppress_warnings
def _compare_out(transform, *, compare_strides_and_data_ptrs=True): def _compare_out(transform, *, compare_strides_and_data_ptrs=True):
@ -831,7 +831,7 @@ class TestCommon(TestCase):
return (out.stride(),) return (out.stride(),)
# assumes (see above) that out is an iterable of tensors # assumes (see above) that out is an iterable of tensors
return tuple(map(lambda t: t.stride(), out)) return tuple((t.stride() for t in out))
# Extracts data pointers from a tensor or iterable of tensors into a tuple # Extracts data pointers from a tensor or iterable of tensors into a tuple
# NOTE: only extracts on the CPU and CUDA device types since some # NOTE: only extracts on the CPU and CUDA device types since some
@ -844,7 +844,7 @@ class TestCommon(TestCase):
return (out.data_ptr(),) return (out.data_ptr(),)
# assumes (see above) that out is an iterable of tensors # assumes (see above) that out is an iterable of tensors
return tuple(map(lambda t: t.data_ptr(), out)) return tuple((t.data_ptr() for t in out))
def _compare_out(transform, *, compare_strides_and_data_ptrs=True): def _compare_out(transform, *, compare_strides_and_data_ptrs=True):
out = _apply_out_transform(transform, expected) out = _apply_out_transform(transform, expected)

View File

@ -1584,7 +1584,7 @@ class TestSparseCSR(TestCase):
return t.cpu().resolve_conj().numpy() return t.cpu().resolve_conj().numpy()
res = _npref_block_addmm_addmv( res = _npref_block_addmm_addmv(
*map(lambda t: prep_input(t), (c, a, b)), *(prep_input(t) for t in (c, a, b)),
alpha, alpha,
beta beta
) )
@ -2406,7 +2406,7 @@ class TestSparseCSR(TestCase):
output.backward(covector) output.backward(covector)
# Compute dense result and compare with sparse result # Compute dense result and compare with sparse result
c1, a1, b1 = map(lambda x: x.detach().to_dense().requires_grad_(True), [c, a, b]) c1, a1, b1 = (x.detach().to_dense().requires_grad_(True) for x in [c, a, b])
dense_output = sample.kwargs['alpha'] * (a1 @ b1) * torch.ones_like(c).to_dense() + sample.kwargs['beta'] * c1 dense_output = sample.kwargs['alpha'] * (a1 @ b1) * torch.ones_like(c).to_dense() + sample.kwargs['beta'] * c1
self.assertEqual(output, dense_output) self.assertEqual(output, dense_output)
dense_covector = covector.to_dense() dense_covector = covector.to_dense()

View File

@ -1138,8 +1138,7 @@ class TestTypePromotion(TestCase):
exp_type = expected_type(inp, min_v, max_v) exp_type = expected_type(inp, min_v, max_v)
if exp_type != torch.bool: if exp_type != torch.bool:
actual = torch.clamp(inp, min_v, max_v) actual = torch.clamp(inp, min_v, max_v)
inps = list(map(lambda x: x.to(exp_type) if isinstance(x, torch.Tensor) else x, inps = [x.to(exp_type) if isinstance(x, torch.Tensor) else x for x in (inp, min_v, max_v)]
(inp, min_v, max_v)))
expected = torch.clamp(inps[0], inps[1], inps[2]) expected = torch.clamp(inps[0], inps[1], inps[2])
self.assertEqual(actual, expected) self.assertEqual(actual, expected)
if inp.dtype in floating_types() or exp_type == inp.dtype: if inp.dtype in floating_types() or exp_type == inp.dtype:
@ -1151,8 +1150,7 @@ class TestTypePromotion(TestCase):
exp_type = expected_type(inp, val) exp_type = expected_type(inp, val)
if exp_type != torch.bool: if exp_type != torch.bool:
actual = torch.clamp_min(inp, val) actual = torch.clamp_min(inp, val)
inps = list(map(lambda x: x.to(exp_type) if isinstance(x, torch.Tensor) else x, inps = [x.to(exp_type) if isinstance(x, torch.Tensor) else x for x in (inp, val)]
(inp, val)))
expected = torch.clamp_min(inps[0], inps[1]) expected = torch.clamp_min(inps[0], inps[1])
self.assertEqual(actual.dtype, exp_type) self.assertEqual(actual.dtype, exp_type)
self.assertEqual(actual, expected) self.assertEqual(actual, expected)

View File

@ -184,7 +184,7 @@ def _parse_reveals(file: IO[str]) -> List[str]:
string = file.read().replace("*", "") string = file.read().replace("*", "")
# Grab all `# E:`-based comments # Grab all `# E:`-based comments
comments_array = list(map(lambda str: str.partition(" # E: ")[2], string.split("\n"))) comments_array = [str.partition(" # E: ")[2] for str in string.split("\n")]
comments = "/n".join(comments_array) comments = "/n".join(comments_array)
# Only search for the `{*}` pattern within comments, # Only search for the `{*}` pattern within comments,

View File

@ -443,8 +443,8 @@ def gen_autograd_functions_lib(
# get a 1D list of diffinfos, we do not need them to be per FunctionSchema/DispatchKey here # get a 1D list of diffinfos, we do not need them to be per FunctionSchema/DispatchKey here
# infos with the diff dispatchkeys but the same name will still be in the same shard. # infos with the diff dispatchkeys but the same name will still be in the same shard.
infos = get_infos_with_derivatives_list(differentiability_infos) infos = get_infos_with_derivatives_list(differentiability_infos)
declarations = list(map(lambda f: process_function(f, FUNCTION_DECLARATION), infos)) declarations = [process_function(f, FUNCTION_DECLARATION) for f in infos]
definitions = list(map(lambda f: process_function(f, FUNCTION_DEFINITION), infos)) definitions = [process_function(f, FUNCTION_DEFINITION) for f in infos]
file_basename = "Functions" file_basename = "Functions"
fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False) fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)

View File

@ -159,9 +159,9 @@ _SKIP_PYTHON_BINDINGS = [
"_nested_view_from_buffer_copy_out", "_nested_view_from_buffer_copy_out",
] ]
SKIP_PYTHON_BINDINGS = list( SKIP_PYTHON_BINDINGS = [
map(lambda pattern: re.compile(rf"^{pattern}$"), _SKIP_PYTHON_BINDINGS) re.compile(rf"^{pattern}$") for pattern in _SKIP_PYTHON_BINDINGS
) ]
# These function signatures are not exposed to Python. Note that this signature # These function signatures are not exposed to Python. Note that this signature
# list does not support regex. # list does not support regex.
@ -864,7 +864,7 @@ def method_impl(
name=name, name=name,
pycname=pycname, pycname=pycname,
method_header=method_header, method_header=method_header,
max_args=max(map(lambda o: o.signature.arguments_count(), overloads)), max_args=max((o.signature.arguments_count() for o in overloads)),
signatures=signatures, signatures=signatures,
traceable=traceable, traceable=traceable,
check_has_torch_function=gen_has_torch_function_check( check_has_torch_function=gen_has_torch_function_check(
@ -1216,7 +1216,7 @@ def sort_overloads(
del larger_than[j] del larger_than[j]
sorted_ids.append(j) sorted_ids.append(j)
return list(map(lambda x: grouped_overloads[x], sorted_ids)) return [grouped_overloads[x] for x in sorted_ids]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
@ -1250,9 +1250,9 @@ def emit_single_dispatch(
# dispatch lambda signature # dispatch lambda signature
name = cpp.name(f.func) name = cpp.name(f.func)
lambda_formals = ", ".join( lambda_formals = ", ".join(
map( (
lambda a: f"{a.type_str} {a.name}", f"{a.type_str} {a.name}"
dispatch_lambda_args(ps, f, symint=symint), for a in dispatch_lambda_args(ps, f, symint=symint)
) )
) )
lambda_return = dispatch_lambda_return_str(f) lambda_return = dispatch_lambda_return_str(f)

View File

@ -149,12 +149,7 @@ def main(argv: List[Any]) -> None:
model_dict = yaml.safe_load(model_file) model_dict = yaml.safe_load(model_file)
model_dicts.append(model_dict) model_dicts.append(model_dict)
selective_builders = list( selective_builders = [SelectiveBuilder.from_yaml_dict(m) for m in model_dicts]
map(
lambda m: SelectiveBuilder.from_yaml_dict(m),
model_dicts,
)
)
# While we have the model_dicts generate the supported mobile models api # While we have the model_dicts generate the supported mobile models api
gen_supported_mobile_models(model_dicts, options.output_dir) gen_supported_mobile_models(model_dicts, options.output_dir)

View File

@ -67,9 +67,7 @@ def get_selected_kernel_dtypes_code(
): ):
body_parts = [] body_parts = []
for kernel_tag, dtypes in selective_builder.kernel_metadata.items(): for kernel_tag, dtypes in selective_builder.kernel_metadata.items():
conditions = list( conditions = ["scalar_type == at::ScalarType::" + x for x in dtypes]
map(lambda x: "scalar_type == at::ScalarType::" + x, dtypes)
)
body_parts.append( body_parts.append(
if_condition_template.substitute( if_condition_template.substitute(
kernel_tag_name=kernel_tag, kernel_tag_name=kernel_tag,

View File

@ -316,11 +316,11 @@ class VariableBuilder:
elif istype( elif istype(
value, (dict, collections.defaultdict, collections.OrderedDict) value, (dict, collections.defaultdict, collections.OrderedDict)
) and all( ) and all(
map( (
lambda k: ConstantVariable.is_literal(k) ConstantVariable.is_literal(k)
or self.tensor_can_be_dict_key(k) or self.tensor_can_be_dict_key(k)
or isinstance(k, enum.Enum), or isinstance(k, enum.Enum)
value.keys(), for k in value.keys()
) )
): ):
if not value and self.get_source().is_nn_module(): if not value and self.get_source().is_nn_module():

View File

@ -2348,7 +2348,7 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Any], aot_config: AOTConfig,
[isinstance(x, torch.Tensor) for x in tensors_saved_for_backwards] [isinstance(x, torch.Tensor) for x in tensors_saved_for_backwards]
) )
# See Note [Detaching saved tensors in AOTAutograd] # See Note [Detaching saved tensors in AOTAutograd]
ctx.save_for_backward(*map(lambda x: x.detach() if x._is_view() else x, tensors_saved_for_backwards)) ctx.save_for_backward(*(x.detach() if x._is_view() else x for x in tensors_saved_for_backwards))
symint_outs = fw_outs[-num_symints_saved_for_bw:] symint_outs = fw_outs[-num_symints_saved_for_bw:]
assert all( assert all(
[ [
@ -2360,7 +2360,7 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Any], aot_config: AOTConfig,
else: else:
tensors_saved_for_backwards = fw_outs[num_forward_returns:] tensors_saved_for_backwards = fw_outs[num_forward_returns:]
# See Note [Detaching saved tensors in AOTAutograd] # See Note [Detaching saved tensors in AOTAutograd]
ctx.save_for_backward(*map(lambda x: x.detach() if x._is_view() else x, tensors_saved_for_backwards)) ctx.save_for_backward(*(x.detach() if x._is_view() else x for x in tensors_saved_for_backwards))
ctx.symints = [] ctx.symints = []
raw_returns = fw_outs[0:num_forward_returns] raw_returns = fw_outs[0:num_forward_returns]

View File

@ -544,7 +544,7 @@ def jacrev(func: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False
# Iterate and concat the jacobians of different # Iterate and concat the jacobians of different
# inputs. # inputs.
for idx in range(len(flat_primals)): for idx in range(len(flat_primals)):
r = tuple(map(lambda r_: r_[idx], chunked_results)) r = tuple((r_[idx] for r_ in chunked_results))
flat_results.append(torch.cat(r, 0)) flat_results.append(torch.cat(r, 0))
return flat_results return flat_results
@ -567,7 +567,7 @@ def jacrev(func: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False
for t in flat_basis_chunk: for t in flat_basis_chunk:
assert t.size(0) == 1 assert t.size(0) == 1
flat_basis_chunk = list(map(lambda t: torch.squeeze(t, 0), flat_basis_chunk)) flat_basis_chunk = [torch.squeeze(t, 0) for t in flat_basis_chunk]
basis = tree_unflatten(flat_basis_chunk, output_spec) basis = tree_unflatten(flat_basis_chunk, output_spec)

View File

@ -397,7 +397,7 @@ class SizeVarAllocator:
def stride_order(self, index: Expr, vars: List[sympy.Symbol]) -> List[int]: def stride_order(self, index: Expr, vars: List[sympy.Symbol]) -> List[int]:
strides = tuple( strides = tuple(
map(lambda x: abs(x), self.stride_hints(index, vars)) map(abs, self.stride_hints(index, vars))
) # lambda to placate mypy ) # lambda to placate mypy
order = list(range(len(strides))) order = list(range(len(strides)))
order.sort(key=lambda x: (strides[x] == 0, strides[x])) order.sort(key=lambda x: (strides[x] == 0, strides[x]))

View File

@ -346,7 +346,7 @@ def _broadcast_shapes(*_shapes):
def _maybe_broadcast(*args, preserve_cpu_scalar_tensors=True): def _maybe_broadcast(*args, preserve_cpu_scalar_tensors=True):
# Computes common shape # Computes common shape
common_shape = _broadcast_shapes( common_shape = _broadcast_shapes(
*map(lambda t: t.shape if isinstance(t, TensorLike) else None, args) *(t.shape if isinstance(t, TensorLike) else None for t in args)
) )
def __maybe_broadcast(x, shape): def __maybe_broadcast(x, shape):

View File

@ -292,7 +292,7 @@ class PerChannelDetector(DetectorBase):
# get the fully qualified name and check if in list of modules to include and list of modules to ignore # get the fully qualified name and check if in list of modules to include and list of modules to ignore
for fqn, module in model.named_modules(): for fqn, module in model.named_modules():
is_in_include_list = sum(list(map(lambda x: isinstance(module, x), self.supported_modules))) > 0 is_in_include_list = sum([isinstance(module, x) for x in self.supported_modules]) > 0
# check if the module per_channel is supported # check if the module per_channel is supported
# based on backend # based on backend
@ -517,10 +517,10 @@ class DynamicStaticDetector(DetectorBase):
Returns True if the module is supported by observer, False otherwise Returns True if the module is supported by observer, False otherwise
""" """
# check to see if module is of a supported type # check to see if module is of a supported type
is_supported_type = sum(list(map(lambda x: isinstance(module, x), self.DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED))) > 0 is_supported_type = sum([isinstance(module, x) for x in self.DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED]) > 0
# check if it will be supported # check if it will be supported
future_supported_type = sum(list(map(lambda x: isinstance(module, x), self.DEFAULT_DYNAMIC_STATIC_FUTURE_SUPPORTED))) > 0 future_supported_type = sum([isinstance(module, x) for x in self.DEFAULT_DYNAMIC_STATIC_FUTURE_SUPPORTED]) > 0
# supported # supported
supported = is_supported_type or future_supported_type supported = is_supported_type or future_supported_type
@ -578,7 +578,7 @@ class DynamicStaticDetector(DetectorBase):
post_obs_dist_classif = self.STATIONARY_STR if post_stat > self.tolerance else self.NON_STATIONARY_STR post_obs_dist_classif = self.STATIONARY_STR if post_stat > self.tolerance else self.NON_STATIONARY_STR
# check if current support or future support # check if current support or future support
is_supported_type = sum(list(map(lambda x: isinstance(module, x), self.DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED))) > 0 is_supported_type = sum([isinstance(module, x) for x in self.DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED]) > 0
# store the set of important information for this module # store the set of important information for this module
module_info = { module_info = {
@ -791,7 +791,7 @@ class InputWeightEqualizationDetector(DetectorBase):
Returns True if the module is supported by observer, False otherwise Returns True if the module is supported by observer, False otherwise
""" """
# check to see if module is of a supported type # check to see if module is of a supported type
is_supported_type = sum(list(map(lambda x: type(module) is x, self.SUPPORTED_MODULES))) > 0 is_supported_type = sum([type(module) is x for x in self.SUPPORTED_MODULES]) > 0
# this is check for observer insertion # this is check for observer insertion
if insert: if insert:

View File

@ -54,7 +54,7 @@ def _cast(value, dtype):
elif isinstance(value, collections.abc.Mapping): elif isinstance(value, collections.abc.Mapping):
return {_cast(k, dtype): _cast(v, dtype) for k, v in value.items()} return {_cast(k, dtype): _cast(v, dtype) for k, v in value.items()}
elif isinstance(value, collections.abc.Iterable): elif isinstance(value, collections.abc.Iterable):
iterable = map(lambda v: _cast(v, dtype), value) iterable = (_cast(v, dtype) for v in value)
if isinstance(value, (list, tuple)): if isinstance(value, (list, tuple)):
return type(value)(iterable) return type(value)(iterable)
else: else:

View File

@ -259,7 +259,7 @@ class _OverlapInfo:
assert ( assert (
len(self.broadcast_handles) == self.num_bucket_assignments len(self.broadcast_handles) == self.num_bucket_assignments
), f"Missing at least one broadcast handle on rank {dist.get_rank()}" ), f"Missing at least one broadcast handle on rank {dist.get_rank()}"
_ = list(map(lambda x: x.wait(), self.broadcast_handles)) _ = [x.wait() for x in self.broadcast_handles]
self.broadcast_handles.clear() self.broadcast_handles.clear()
def clear_per_iter_info(self) -> None: def clear_per_iter_info(self) -> None:
@ -807,7 +807,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
handles = [] handles = []
for rank in range(self.world_size): for rank in range(self.world_size):
handles.extend(self._broadcast_params_from_rank(rank)) handles.extend(self._broadcast_params_from_rank(rank))
_ = list(map(lambda x: x.wait(), handles)) _ = [x.wait() for x in handles]
@property @property
def _device_to_params_per_rank( def _device_to_params_per_rank(

View File

@ -109,7 +109,7 @@ def _masked_tensor_str(data, mask, formatter):
for d in data for d in data
] ]
max_len = max( max_len = max(
map(lambda x: 8 if x[1] else len(x[0]), zip(formatted_elements, ~mask)) (8 if x[1] else len(x[0]) for x in zip(formatted_elements, ~mask))
) )
return ( return (
"[" "["

View File

@ -885,7 +885,7 @@ def arange(g: jit_utils.GraphContext, *args):
dtype = symbolic_helper._maybe_get_const(dtype, "i") dtype = symbolic_helper._maybe_get_const(dtype, "i")
return dtype return dtype
if len(args) == 2 and all(map(lambda val: isinstance(val, int), args)): if len(args) == 2 and all((isinstance(val, int) for val in args)):
# aten::arange(Scalar start, Scalar end) # aten::arange(Scalar start, Scalar end)
dtype = torch.int64 dtype = torch.int64
# Start index. # Start index.

View File

@ -91,7 +91,7 @@ class _StorageBase:
return str(self) return str(self)
def __iter__(self): def __iter__(self):
return iter(map(lambda i: self[i], range(self.size()))) return iter((self[i] for i in range(self.size())))
def __copy__(self): def __copy__(self):
return self.clone() return self.clone()
@ -725,7 +725,7 @@ class TypedStorage:
def __iter__(self): def __iter__(self):
_warn_typed_storage_removal() _warn_typed_storage_removal()
return iter(map(lambda i: self[i], range(self.size()))) return iter((self[i] for i in range(self.size())))
def __copy__(self): def __copy__(self):
_warn_typed_storage_removal() _warn_typed_storage_removal()

View File

@ -6827,7 +6827,7 @@ def sample_inputs_where(op_info, device, dtype, requires_grad, **kwargs):
if mask_t.sum() == 0: if mask_t.sum() == 0:
def random_index(shape): def random_index(shape):
return tuple(map(lambda max_idx: random.randrange(0, max_idx), shape)) return tuple((random.randrange(0, max_idx) for max_idx in shape))
mask_t[random_index(mask_t.shape)] = True mask_t[random_index(mask_t.shape)] = True
return mask_t return mask_t

View File

@ -5214,7 +5214,7 @@ class ModuleTest(TestBase):
type_map = {torch.double: torch.float} type_map = {torch.double: torch.float}
cpu_input_tuple = cpu_input if isinstance(cpu_input, tuple) else (cpu_input,) cpu_input_tuple = cpu_input if isinstance(cpu_input, tuple) else (cpu_input,)
is_any_input_complex = any(map(lambda t: isinstance(t, torch.Tensor) and t.dtype.is_complex, cpu_input_tuple)) is_any_input_complex = any((isinstance(t, torch.Tensor) and t.dtype.is_complex for t in cpu_input_tuple))
gpu_input_tuple = to_gpu(cpu_input_tuple, type_map=type_map) gpu_input_tuple = to_gpu(cpu_input_tuple, type_map=type_map)

View File

@ -869,9 +869,9 @@ class OpInfo:
# Attribute to verify dynamic_dtypes are used. # Attribute to verify dynamic_dtypes are used.
self.dynamic_dtypes = any( self.dynamic_dtypes = any(
map( (
lambda dtypes: isinstance(dtypes, utils._dynamic_dispatch_dtypes), isinstance(dtypes, utils._dynamic_dispatch_dtypes)
dtypes_args, for dtypes in dtypes_args
) )
) )
@ -1661,7 +1661,7 @@ def generate_elementwise_binary_small_value_tensors(
complex_vals = product(_float_vals, _float_vals) complex_vals = product(_float_vals, _float_vals)
# Note the use of list is required here or the map generator will be # Note the use of list is required here or the map generator will be
# emptied by the following product and it won't produce the desired cross-product # emptied by the following product and it won't produce the desired cross-product
complex_vals = list(map(lambda x: complex(*x), complex_vals)) complex_vals = [complex(*x) for x in complex_vals]
prod = product(complex_vals, complex_vals) prod = product(complex_vals, complex_vals)
elif dtype in (torch.int8, torch.int16, torch.int32, torch.int64): elif dtype in (torch.int8, torch.int16, torch.int32, torch.int64):
prod = product(_int_vals, _int_vals) prod = product(_int_vals, _int_vals)
@ -1701,7 +1701,7 @@ def generate_elementwise_binary_large_value_tensors(
complex_vals = product(_large_float_vals, _large_float_vals) complex_vals = product(_large_float_vals, _large_float_vals)
# Note the use of list is required here or the map generator will be # Note the use of list is required here or the map generator will be
# emptied by the following product and it won't produce the desired cross-product # emptied by the following product and it won't produce the desired cross-product
complex_vals = list(map(lambda x: complex(*x), complex_vals)) complex_vals = [complex(*x) for x in complex_vals]
prod = product(complex_vals, complex_vals) prod = product(complex_vals, complex_vals)
elif dtype in (torch.int16, torch.int32, torch.int64): elif dtype in (torch.int16, torch.int32, torch.int64):
prod = product(_large_int_vals, _large_int_vals) prod = product(_large_int_vals, _large_int_vals)
@ -1732,7 +1732,7 @@ def generate_elementwise_binary_extremal_value_tensors(
complex_vals = product(_float_extremals, _float_extremals) complex_vals = product(_float_extremals, _float_extremals)
# Note the use of list is required here or the map generator will be # Note the use of list is required here or the map generator will be
# emptied by the following product and it won't produce the desired cross-product # emptied by the following product and it won't produce the desired cross-product
complex_vals = list(map(lambda x: complex(*x), complex_vals)) complex_vals = [complex(*x) for x in complex_vals]
prod = product(complex_vals, complex_vals) prod = product(complex_vals, complex_vals)
else: else:
raise ValueError("Unsupported dtype!") raise ValueError("Unsupported dtype!")

View File

@ -315,7 +315,7 @@ class PythonOutArgument(PythonArgument):
outputs=outputs, outputs=outputs,
) )
elif size > 1: elif size > 1:
if any(map(lambda a: not a.type.is_tensor_like(), outputs)): if any((not a.type.is_tensor_like() for a in outputs)):
raise RuntimeError(f"Unsupported output type: {outputs}") raise RuntimeError(f"Unsupported output type: {outputs}")
return PythonOutArgument( return PythonOutArgument(
name="out", name="out",
@ -390,9 +390,9 @@ class PythonSignature:
# signature_str_pyi(). # signature_str_pyi().
def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str: def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str:
args = self.arguments(skip_outputs=skip_outputs) args = self.arguments(skip_outputs=skip_outputs)
schema_formals: List[str] = list( schema_formals: List[str] = [
map(lambda a: a.argument_str(method=self.method, symint=symint), args) a.argument_str(method=self.method, symint=symint) for a in args
) ]
positional_argc = len(self.input_args) positional_argc = len(self.input_args)
if len(schema_formals) > positional_argc: if len(schema_formals) > positional_argc:
schema_formals.insert(positional_argc, "*") schema_formals.insert(positional_argc, "*")
@ -401,9 +401,9 @@ class PythonSignature:
def signature_str_pyi(self, *, skip_outputs: bool = False) -> str: def signature_str_pyi(self, *, skip_outputs: bool = False) -> str:
args = self.arguments(skip_outputs=skip_outputs) args = self.arguments(skip_outputs=skip_outputs)
schema_formals: List[str] = list( schema_formals: List[str] = [
map(lambda a: a.argument_str_pyi(method=self.method), args) a.argument_str_pyi(method=self.method) for a in args
) ]
positional_argc = len(self.input_args) positional_argc = len(self.input_args)
if len(schema_formals) > positional_argc: if len(schema_formals) > positional_argc:
schema_formals.insert(positional_argc, "*") schema_formals.insert(positional_argc, "*")
@ -418,9 +418,9 @@ class PythonSignature:
def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> Optional[str]: def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> Optional[str]:
# only pyi uses vararg signatures # only pyi uses vararg signatures
args = self.arguments(skip_outputs=skip_outputs) args = self.arguments(skip_outputs=skip_outputs)
schema_formals: List[str] = list( schema_formals: List[str] = [
map(lambda a: a.argument_str_pyi(method=self.method), args) a.argument_str_pyi(method=self.method) for a in args
) ]
# vararg only applies to pyi signatures. vararg variants are not generated for all signatures # vararg only applies to pyi signatures. vararg variants are not generated for all signatures
num_args = self.arguments_count() num_args = self.arguments_count()
num_positionalargs = len(self.input_args) num_positionalargs = len(self.input_args)
@ -478,9 +478,9 @@ class PythonSignatureDeprecated(PythonSignature):
def signature_str_pyi(self, *, skip_outputs: bool = False) -> str: def signature_str_pyi(self, *, skip_outputs: bool = False) -> str:
args = self.arguments(skip_outputs=skip_outputs) args = self.arguments(skip_outputs=skip_outputs)
schema_formals: List[str] = list( schema_formals: List[str] = [
map(lambda a: a.argument_str_pyi(method=self.method, deprecated=True), args) a.argument_str_pyi(method=self.method, deprecated=True) for a in args
) ]
positional_argc = len(self.input_args) positional_argc = len(self.input_args)
if len(schema_formals) > positional_argc: if len(schema_formals) > positional_argc:
schema_formals.insert(positional_argc, "*") schema_formals.insert(positional_argc, "*")
@ -882,10 +882,10 @@ def signature_from_schema(
def namedtuple_fieldnames(returns: Tuple[Return, ...]) -> List[str]: def namedtuple_fieldnames(returns: Tuple[Return, ...]) -> List[str]:
if len(returns) <= 1 or all(map(lambda r: r.name is None, returns)): if len(returns) <= 1 or all((r.name is None for r in returns)):
return [] return []
else: else:
if any(map(lambda r: r.name is None, returns)): if any((r.name is None for r in returns)):
# When building on Windows, `PyStructSequence_UnnamedField` could not be # When building on Windows, `PyStructSequence_UnnamedField` could not be
# resolved by the linker for some reason, which cause error in building: # resolved by the linker for some reason, which cause error in building:
# #
@ -897,7 +897,7 @@ def namedtuple_fieldnames(returns: Tuple[Return, ...]) -> List[str]:
# or none of them. # or none of them.
raise ValueError("Unnamed field is not supported by codegen") raise ValueError("Unnamed field is not supported by codegen")
return list(map(lambda r: str(r.name), returns)) return [str(r.name) for r in returns]
def argument_type_str_pyi(t: Type) -> str: def argument_type_str_pyi(t: Type) -> str:
@ -1157,7 +1157,7 @@ def dispatch_lambda_return_str(f: NativeFunction) -> str:
# mutable reference to temporary. Maybe we could assign it to a # mutable reference to temporary. Maybe we could assign it to a
# variable itself.) # variable itself.)
returns_without_annotation = tuple( returns_without_annotation = tuple(
map(lambda r: Return(r.name, r.type, None), f.func.returns) (Return(r.name, r.type, None) for r in f.func.returns)
) )
return_str = cpp.returns_type(returns_without_annotation, symint=True).cpp_type() return_str = cpp.returns_type(returns_without_annotation, symint=True).cpp_type()
if return_str not in SUPPORTED_RETURN_TYPES: if return_str not in SUPPORTED_RETURN_TYPES:
@ -1189,7 +1189,7 @@ def cpp_dispatch_exprs(
exprs: Tuple[str, ...] = tuple() exprs: Tuple[str, ...] = tuple()
if not isinstance(python_signature, PythonSignatureDeprecated): if not isinstance(python_signature, PythonSignatureDeprecated):
# By default the exprs are consistent with the C++ signature. # By default the exprs are consistent with the C++ signature.
exprs = tuple(map(lambda a: a.name, cpp_args)) exprs = tuple((a.name for a in cpp_args))
else: else:
# For deprecated python signature we may need fill in some constants. # For deprecated python signature we may need fill in some constants.
exprs = tuple( exprs = tuple(
@ -1415,7 +1415,7 @@ def dispatch_lambda_exprs(
lambda_args_exprs["self"] = "self" lambda_args_exprs["self"] = "self"
# 2. special packing/checking for TensorOptions. # 2. special packing/checking for TensorOptions.
tensor_options_args_names = list(map(lambda a: a.name, ps.tensor_options_args)) tensor_options_args_names = [a.name for a in ps.tensor_options_args]
if has_toptions: if has_toptions:
if f.func.is_out_fn(): if f.func.is_out_fn():
raise RuntimeError(f"{f.func}: tensor options with output arg") raise RuntimeError(f"{f.func}: tensor options with output arg")
@ -1429,7 +1429,7 @@ def dispatch_lambda_exprs(
f"{f.func}: unrecognized type '{str(a.type)}' for tensor options field '{a.name}'" f"{f.func}: unrecognized type '{str(a.type)}' for tensor options field '{a.name}'"
) )
if not all( if not all(
map(lambda a: a in tensor_options_args_names, TENSOR_OPTIONS_FIELDS.keys()) (a in tensor_options_args_names for a in TENSOR_OPTIONS_FIELDS.keys())
): ):
raise RuntimeError( raise RuntimeError(
f"{f.func}: incomplete tensor options args: {tensor_options_args_names}" f"{f.func}: incomplete tensor options args: {tensor_options_args_names}"
@ -1457,9 +1457,7 @@ torch::utils::maybe_initialize_cuda(options);
raise RuntimeError( raise RuntimeError(
f"{f.func}: dtype in tensor_options_args without output arg" f"{f.func}: dtype in tensor_options_args without output arg"
) )
if not all( if not all((a in tensor_options_args_names for a in ("layout", "device"))):
map(lambda a: a in tensor_options_args_names, ("layout", "device"))
):
raise RuntimeError( raise RuntimeError(
f"{f.func}: incomplete tensor options for output check" f"{f.func}: incomplete tensor options for output check"
) )
@ -1478,6 +1476,6 @@ check_out_type_matches({arg_parser_outputs['out'].expr}, {arg_parser_outputs['dt
) )
return DispatchLambdaArgumentExprs( return DispatchLambdaArgumentExprs(
exprs=tuple(map(lambda a: lambda_args_exprs[a.name], lambda_args)), exprs=tuple((lambda_args_exprs[a.name] for a in lambda_args)),
inits=inits, inits=inits,
) )

View File

@ -83,7 +83,7 @@ class SelectiveBuildOperator:
if "debug_info" in op_info: if "debug_info" in op_info:
di_list = op_info["debug_info"] di_list = op_info["debug_info"]
assert isinstance(di_list, list) assert isinstance(di_list, list)
debug_info = tuple(map(lambda x: str(x), di_list)) debug_info = tuple((str(x) for x in di_list))
return SelectiveBuildOperator( return SelectiveBuildOperator(
name=op_name, name=op_name,

View File

@ -85,7 +85,7 @@ class SelectiveBuilder:
di_list = data["debug_info"] di_list = data["debug_info"]
assert isinstance(di_list, list) assert isinstance(di_list, list)
debug_info = tuple(map(lambda x: str(x), di_list)) debug_info = tuple((str(x) for x in di_list))
operators = {} operators = {}
operators_dict = data.get("operators", {}) operators_dict = data.get("operators", {})
@ -99,7 +99,7 @@ class SelectiveBuilder:
assert isinstance(kernel_metadata_dict, dict) assert isinstance(kernel_metadata_dict, dict)
for k, v in kernel_metadata_dict.items(): for k, v in kernel_metadata_dict.items():
kernel_metadata[str(k)] = list(map(lambda dtype: str(dtype), v)) kernel_metadata[str(k)] = [str(dtype) for dtype in v]
custom_classes = data.get("custom_classes", []) custom_classes = data.get("custom_classes", [])
custom_classes = set(custom_classes) # type: ignore[arg-type] custom_classes = set(custom_classes) # type: ignore[arg-type]